feat: Add model type support
Adds model type support for chat and embedding models. This allows users to specify which type of model they want to use when adding custom models. Additionally, this commit introduces a more descriptive interface for adding custom models, enhancing the clarity of the model selection process.
This commit is contained in:
129
src/components/Option/Models/AddCustomModelModal.tsx
Normal file
129
src/components/Option/Models/AddCustomModelModal.tsx
Normal file
@@ -0,0 +1,129 @@
|
||||
import { createModel } from "@/db/models"
|
||||
import { getAllOpenAIConfig } from "@/db/openai"
|
||||
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"
|
||||
import { Input, Modal, Form, Select, Radio } from "antd"
|
||||
import { Loader2 } from "lucide-react"
|
||||
import { useTranslation } from "react-i18next"
|
||||
|
||||
type Props = {
|
||||
open: boolean
|
||||
setOpen: (open: boolean) => void
|
||||
}
|
||||
|
||||
export const AddCustomModelModal: React.FC<Props> = ({ open, setOpen }) => {
|
||||
const { t } = useTranslation(["openai"])
|
||||
const [form] = Form.useForm()
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
const { data, isPending } = useQuery({
|
||||
queryKey: ["fetchProviders"],
|
||||
queryFn: async () => {
|
||||
const providers = await getAllOpenAIConfig()
|
||||
return providers.filter((provider) => provider.provider !== "lmstudio")
|
||||
}
|
||||
})
|
||||
|
||||
const onFinish = async (values: {
|
||||
model_id: string
|
||||
model_type: "chat" | "embedding"
|
||||
provider_id: string
|
||||
}) => {
|
||||
await createModel(
|
||||
values.model_id,
|
||||
values.model_id,
|
||||
values.provider_id,
|
||||
values.model_type
|
||||
)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
const { mutate: createModelMutation, isPending: isSaving } = useMutation({
|
||||
mutationFn: onFinish,
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["fetchCustomModels"]
|
||||
})
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ["fetchModel"]
|
||||
})
|
||||
setOpen(false)
|
||||
form.resetFields()
|
||||
}
|
||||
})
|
||||
|
||||
return (
|
||||
<Modal
|
||||
footer={null}
|
||||
open={open}
|
||||
title={t("manageModels.modal.title")}
|
||||
onCancel={() => setOpen(false)}>
|
||||
<Form form={form} onFinish={createModelMutation} layout="vertical">
|
||||
<Form.Item
|
||||
name="model_id"
|
||||
label={t("manageModels.modal.form.name.label")}
|
||||
rules={[
|
||||
{
|
||||
required: true,
|
||||
message: t("manageModels.modal.form.name.required")
|
||||
}
|
||||
]}>
|
||||
<Input
|
||||
placeholder={t("manageModels.modal.form.name.placeholder")}
|
||||
size="large"
|
||||
/>
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="provider_id"
|
||||
label={t("manageModels.modal.form.provider.label")}
|
||||
rules={[
|
||||
{
|
||||
required: true,
|
||||
message: t("manageModels.modal.form.provider.required")
|
||||
}
|
||||
]}>
|
||||
<Select
|
||||
placeholder={t("manageModels.modal.form.provider.placeholder")}
|
||||
size="large"
|
||||
loading={isPending}>
|
||||
{data?.map((provider: any) => (
|
||||
<Select.Option key={provider.id} value={provider.id}>
|
||||
{provider.name}
|
||||
</Select.Option>
|
||||
))}
|
||||
</Select>
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="model_type"
|
||||
label={t("manageModels.modal.form.type.label")}
|
||||
initialValue="chat"
|
||||
rules={[
|
||||
{
|
||||
required: true,
|
||||
message: t("manageModels.modal.form.type.required")
|
||||
}
|
||||
]}>
|
||||
<Radio.Group>
|
||||
<Radio value="chat">{t("radio.chat")}</Radio>
|
||||
<Radio value="embedding">{t("radio.embedding")}</Radio>
|
||||
</Radio.Group>
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item>
|
||||
<button
|
||||
type="submit"
|
||||
disabled={isSaving}
|
||||
className="inline-flex justify-center w-full text-center mt-4 items-center rounded-md border border-transparent bg-black px-2 py-2 text-sm font-medium leading-4 text-white shadow-sm hover:bg-gray-700 focus:outline-none focus:ring-2 focus:ring-indigo-500 focus:ring-offset-2 dark:bg-white dark:text-gray-800 dark:hover:bg-gray-100 dark:focus:ring-gray-500 dark:focus:ring-offset-gray-100 disabled:opacity-50 ">
|
||||
{!isSaving ? (
|
||||
t("common:save")
|
||||
) : (
|
||||
<Loader2 className="w-5 h-5 animate-spin" />
|
||||
)}
|
||||
</button>
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</Modal>
|
||||
)
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useForm } from "@mantine/form"
|
||||
import { useMutation } from "@tanstack/react-query"
|
||||
import { useMutation, useQueryClient } from "@tanstack/react-query"
|
||||
import { Input, Modal, notification } from "antd"
|
||||
import { Download } from "lucide-react"
|
||||
import { useTranslation } from "react-i18next"
|
||||
@@ -11,6 +11,7 @@ type Props = {
|
||||
|
||||
export const AddOllamaModelModal: React.FC<Props> = ({ open, setOpen }) => {
|
||||
const { t } = useTranslation(["settings", "common", "openai"])
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
const form = useForm({
|
||||
initialValues: {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { getAllCustomModels, deleteModel } from "@/db/models"
|
||||
import { useStorage } from "@plasmohq/storage/hook"
|
||||
import { useQuery, useQueryClient, useMutation } from "@tanstack/react-query"
|
||||
import { Skeleton, Table, Tooltip } from "antd"
|
||||
import { Skeleton, Table, Tag, Tooltip } from "antd"
|
||||
import { Trash2 } from "lucide-react"
|
||||
import { useTranslation } from "react-i18next"
|
||||
|
||||
@@ -10,7 +10,6 @@ export const CustomModelsTable = () => {
|
||||
|
||||
const { t } = useTranslation(["openai", "common"])
|
||||
|
||||
|
||||
const queryClient = useQueryClient()
|
||||
|
||||
const { data, status } = useQuery({
|
||||
@@ -27,7 +26,6 @@ export const CustomModelsTable = () => {
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div>
|
||||
@@ -37,16 +35,20 @@ export const CustomModelsTable = () => {
|
||||
<div className="overflow-x-auto">
|
||||
<Table
|
||||
columns={[
|
||||
{
|
||||
title: t("manageModels.columns.name"),
|
||||
dataIndex: "name",
|
||||
key: "name"
|
||||
},
|
||||
{
|
||||
title: t("manageModels.columns.model_id"),
|
||||
dataIndex: "model_id",
|
||||
key: "model_id"
|
||||
},
|
||||
{
|
||||
title: t("manageModels.columns.model_type"),
|
||||
dataIndex: "model_type",
|
||||
render: (txt) => (
|
||||
<Tag color={txt === "chat" ? "green" : "blue"}>
|
||||
{t(`radio.${txt}`)}
|
||||
</Tag>
|
||||
)
|
||||
},
|
||||
{
|
||||
title: t("manageModels.columns.provider"),
|
||||
dataIndex: "provider",
|
||||
|
||||
@@ -6,11 +6,13 @@ import { useTranslation } from "react-i18next"
|
||||
import { OllamaModelsTable } from "./OllamaModelsTable"
|
||||
import { CustomModelsTable } from "./CustomModelsTable"
|
||||
import { AddOllamaModelModal } from "./AddOllamaModelModal"
|
||||
import { AddCustomModelModal } from "./AddCustomModelModal"
|
||||
|
||||
dayjs.extend(relativeTime)
|
||||
|
||||
export const ModelsBody = () => {
|
||||
const [open, setOpen] = useState(false)
|
||||
const [openAddModelModal, setOpenAddModelModal] = useState(false)
|
||||
const [segmented, setSegmented] = useState<string>("ollama")
|
||||
|
||||
const { t } = useTranslation(["settings", "common", "openai"])
|
||||
@@ -26,6 +28,8 @@ export const ModelsBody = () => {
|
||||
onClick={() => {
|
||||
if (segmented === "ollama") {
|
||||
setOpen(true)
|
||||
} else {
|
||||
setOpenAddModelModal(true)
|
||||
}
|
||||
}}
|
||||
className="inline-flex items-center rounded-md border border-transparent bg-black px-2 py-2 text-md font-medium leading-4 text-white shadow-sm hover:bg-gray-800 focus:outline-none focus:ring-2 focus:ring-indigo-500 focus:ring-offset-2 dark:bg-white dark:text-gray-800 dark:hover:bg-gray-100 dark:focus:ring-gray-500 dark:focus:ring-offset-gray-100 disabled:opacity-50">
|
||||
@@ -56,6 +60,11 @@ export const ModelsBody = () => {
|
||||
</div>
|
||||
|
||||
<AddOllamaModelModal open={open} setOpen={setOpen} />
|
||||
|
||||
<AddCustomModelModal
|
||||
open={openAddModelModal}
|
||||
setOpen={setOpenAddModelModal}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user