feat: Add model type support

Adds model type support for chat and embedding models. This allows users to specify which type of model they want to use when adding custom models.

Additionally, this commit introduces a more descriptive interface for adding custom models, enhancing the clarity of the model selection process.
This commit is contained in:
n4ze3m
2024-10-13 18:22:16 +05:30
parent 4e04155471
commit ff4473c35b
9 changed files with 277 additions and 43 deletions

View File

@@ -0,0 +1,129 @@
import { createModel } from "@/db/models"
import { getAllOpenAIConfig } from "@/db/openai"
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"
import { Input, Modal, Form, Select, Radio } from "antd"
import { Loader2 } from "lucide-react"
import { useTranslation } from "react-i18next"
type Props = {
open: boolean
setOpen: (open: boolean) => void
}
export const AddCustomModelModal: React.FC<Props> = ({ open, setOpen }) => {
const { t } = useTranslation(["openai"])
const [form] = Form.useForm()
const queryClient = useQueryClient()
const { data, isPending } = useQuery({
queryKey: ["fetchProviders"],
queryFn: async () => {
const providers = await getAllOpenAIConfig()
return providers.filter((provider) => provider.provider !== "lmstudio")
}
})
const onFinish = async (values: {
model_id: string
model_type: "chat" | "embedding"
provider_id: string
}) => {
await createModel(
values.model_id,
values.model_id,
values.provider_id,
values.model_type
)
return true
}
const { mutate: createModelMutation, isPending: isSaving } = useMutation({
mutationFn: onFinish,
onSuccess: () => {
queryClient.invalidateQueries({
queryKey: ["fetchCustomModels"]
})
queryClient.invalidateQueries({
queryKey: ["fetchModel"]
})
setOpen(false)
form.resetFields()
}
})
return (
<Modal
footer={null}
open={open}
title={t("manageModels.modal.title")}
onCancel={() => setOpen(false)}>
<Form form={form} onFinish={createModelMutation} layout="vertical">
<Form.Item
name="model_id"
label={t("manageModels.modal.form.name.label")}
rules={[
{
required: true,
message: t("manageModels.modal.form.name.required")
}
]}>
<Input
placeholder={t("manageModels.modal.form.name.placeholder")}
size="large"
/>
</Form.Item>
<Form.Item
name="provider_id"
label={t("manageModels.modal.form.provider.label")}
rules={[
{
required: true,
message: t("manageModels.modal.form.provider.required")
}
]}>
<Select
placeholder={t("manageModels.modal.form.provider.placeholder")}
size="large"
loading={isPending}>
{data?.map((provider: any) => (
<Select.Option key={provider.id} value={provider.id}>
{provider.name}
</Select.Option>
))}
</Select>
</Form.Item>
<Form.Item
name="model_type"
label={t("manageModels.modal.form.type.label")}
initialValue="chat"
rules={[
{
required: true,
message: t("manageModels.modal.form.type.required")
}
]}>
<Radio.Group>
<Radio value="chat">{t("radio.chat")}</Radio>
<Radio value="embedding">{t("radio.embedding")}</Radio>
</Radio.Group>
</Form.Item>
<Form.Item>
<button
type="submit"
disabled={isSaving}
className="inline-flex justify-center w-full text-center mt-4 items-center rounded-md border border-transparent bg-black px-2 py-2 text-sm font-medium leading-4 text-white shadow-sm hover:bg-gray-700 focus:outline-none focus:ring-2 focus:ring-indigo-500 focus:ring-offset-2 dark:bg-white dark:text-gray-800 dark:hover:bg-gray-100 dark:focus:ring-gray-500 dark:focus:ring-offset-gray-100 disabled:opacity-50 ">
{!isSaving ? (
t("common:save")
) : (
<Loader2 className="w-5 h-5 animate-spin" />
)}
</button>
</Form.Item>
</Form>
</Modal>
)
}

View File

@@ -1,5 +1,5 @@
import { useForm } from "@mantine/form"
import { useMutation } from "@tanstack/react-query"
import { useMutation, useQueryClient } from "@tanstack/react-query"
import { Input, Modal, notification } from "antd"
import { Download } from "lucide-react"
import { useTranslation } from "react-i18next"
@@ -11,6 +11,7 @@ type Props = {
export const AddOllamaModelModal: React.FC<Props> = ({ open, setOpen }) => {
const { t } = useTranslation(["settings", "common", "openai"])
const queryClient = useQueryClient()
const form = useForm({
initialValues: {

View File

@@ -1,7 +1,7 @@
import { getAllCustomModels, deleteModel } from "@/db/models"
import { useStorage } from "@plasmohq/storage/hook"
import { useQuery, useQueryClient, useMutation } from "@tanstack/react-query"
import { Skeleton, Table, Tooltip } from "antd"
import { Skeleton, Table, Tag, Tooltip } from "antd"
import { Trash2 } from "lucide-react"
import { useTranslation } from "react-i18next"
@@ -10,7 +10,6 @@ export const CustomModelsTable = () => {
const { t } = useTranslation(["openai", "common"])
const queryClient = useQueryClient()
const { data, status } = useQuery({
@@ -27,7 +26,6 @@ export const CustomModelsTable = () => {
}
})
return (
<div>
<div>
@@ -37,16 +35,20 @@ export const CustomModelsTable = () => {
<div className="overflow-x-auto">
<Table
columns={[
{
title: t("manageModels.columns.name"),
dataIndex: "name",
key: "name"
},
{
title: t("manageModels.columns.model_id"),
dataIndex: "model_id",
key: "model_id"
},
{
title: t("manageModels.columns.model_type"),
dataIndex: "model_type",
render: (txt) => (
<Tag color={txt === "chat" ? "green" : "blue"}>
{t(`radio.${txt}`)}
</Tag>
)
},
{
title: t("manageModels.columns.provider"),
dataIndex: "provider",

View File

@@ -6,11 +6,13 @@ import { useTranslation } from "react-i18next"
import { OllamaModelsTable } from "./OllamaModelsTable"
import { CustomModelsTable } from "./CustomModelsTable"
import { AddOllamaModelModal } from "./AddOllamaModelModal"
import { AddCustomModelModal } from "./AddCustomModelModal"
dayjs.extend(relativeTime)
export const ModelsBody = () => {
const [open, setOpen] = useState(false)
const [openAddModelModal, setOpenAddModelModal] = useState(false)
const [segmented, setSegmented] = useState<string>("ollama")
const { t } = useTranslation(["settings", "common", "openai"])
@@ -26,6 +28,8 @@ export const ModelsBody = () => {
onClick={() => {
if (segmented === "ollama") {
setOpen(true)
} else {
setOpenAddModelModal(true)
}
}}
className="inline-flex items-center rounded-md border border-transparent bg-black px-2 py-2 text-md font-medium leading-4 text-white shadow-sm hover:bg-gray-800 focus:outline-none focus:ring-2 focus:ring-indigo-500 focus:ring-offset-2 dark:bg-white dark:text-gray-800 dark:hover:bg-gray-100 dark:focus:ring-gray-500 dark:focus:ring-offset-gray-100 disabled:opacity-50">
@@ -56,6 +60,11 @@ export const ModelsBody = () => {
</div>
<AddOllamaModelModal open={open} setOpen={setOpen} />
<AddCustomModelModal
open={openAddModelModal}
setOpen={setOpenAddModelModal}
/>
</div>
)
}