feat: Add model type support

Adds model type support for chat and embedding models. This allows users to specify which type of model they want to use when adding custom models.

Additionally, this commit introduces a more descriptive interface for adding custom models, enhancing the clarity of the model selection process.
This commit is contained in:
n4ze3m
2024-10-13 18:22:16 +05:30
parent 4e04155471
commit ff4473c35b
9 changed files with 277 additions and 43 deletions

View File

@@ -1,10 +1,12 @@
import { getOpenAIConfigById } from "@/db/openai"
import { getAllOpenAIModels } from "@/libs/openai"
import { useMutation, useQuery } from "@tanstack/react-query"
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"
import { useTranslation } from "react-i18next"
import { Checkbox, Input, Spin, message } from "antd"
import { Checkbox, Input, Spin, message, Radio } from "antd"
import { useState, useMemo } from "react"
import { createManyModels } from "@/db/models"
import { Popover } from "antd"
import { InfoIcon } from "lucide-react"
type Props = {
openaiId: string
@@ -15,6 +17,8 @@ export const OpenAIFetchModel = ({ openaiId, setOpenModelModal }: Props) => {
const { t } = useTranslation(["openai"])
const [selectedModels, setSelectedModels] = useState<string[]>([])
const [searchTerm, setSearchTerm] = useState("")
const [modelType, setModelType] = useState("chat")
const queryClient = useQueryClient()
const { data, status } = useQuery({
queryKey: ["openAIConfigs", openaiId],
@@ -56,7 +60,8 @@ export const OpenAIFetchModel = ({ openaiId, setOpenModelModal }: Props) => {
const payload = models.map((id) => ({
model_id: id,
name: filteredModels.find((model) => model.id === id)?.name ?? id,
provider_id: openaiId
provider_id: openaiId,
model_type: modelType
}))
await createManyModels(payload)
@@ -68,6 +73,9 @@ export const OpenAIFetchModel = ({ openaiId, setOpenModelModal }: Props) => {
mutationFn: onSave,
onSuccess: () => {
setOpenModelModal(false)
queryClient.invalidateQueries({
queryKey: ["fetchModel"]
})
message.success(t("modal.model.success"))
}
})
@@ -97,6 +105,7 @@ export const OpenAIFetchModel = ({ openaiId, setOpenModelModal }: Props) => {
<p className="text-sm text-gray-500 dark:text-gray-400">
{t("modal.model.subheading")}
</p>
<Input
placeholder={t("searchModel")}
value={searchTerm}
@@ -134,6 +143,35 @@ export const OpenAIFetchModel = ({ openaiId, setOpenModelModal }: Props) => {
))}
</div>
</div>
<div className="flex items-center">
<Radio.Group
onChange={(e) => setModelType(e.target.value)}
value={modelType}>
<Radio value="chat">{t("radio.chat")}</Radio>
<Radio value="embedding">{t("radio.embedding")}</Radio>
</Radio.Group>
<Popover
content={
<div>
<p>
<b className="text-gray-800 dark:text-gray-100">
{t("radio.chat")}
</b>{" "}
{t("radio.chatInfo")}
</p>
<p>
<b className="text-gray-800 dark:text-gray-100">
{t("radio.embedding")}
</b>{" "}
{t("radio.embeddingInfo")}
</p>
</div>
}>
<InfoIcon className="ml-2 h-4 w-4 text-gray-500 cursor-pointer" />
</Popover>
</div>
<button
onClick={handleSave}
disabled={isSaving}

View File

@@ -14,7 +14,13 @@ import {
updateOpenAIConfig
} from "@/db/openai"
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"
import { Pencil, Trash2, RotateCwIcon, DownloadIcon, AlertTriangle } from "lucide-react"
import {
Pencil,
Trash2,
RotateCwIcon,
DownloadIcon,
AlertTriangle
} from "lucide-react"
import { OpenAIFetchModel } from "./openai-fetch-model"
import { OAI_API_PROVIDERS } from "@/utils/oai-api-providers"
@@ -149,17 +155,23 @@ export const OpenAIApp = () => {
</button>
</Tooltip>
<Tooltip title={t("refetch")}>
<Tooltip
title={
record.provider !== "lmstudio"
? t("newModel")
: t("noNewModel")
}>
<button
className="text-gray-700 dark:text-gray-400"
className="text-gray-700 dark:text-gray-400 disabled:opacity-50"
onClick={() => {
setOpenModelModal(true)
setOpenaiId(record.id)
}}
disabled={!record.id}>
disabled={!record.id || record.provider === "lmstudio"}>
<DownloadIcon className="size-4" />
</button>
</Tooltip>
<Tooltip title={t("delete")}>
<button
className="text-red-500 dark:text-red-400"
@@ -251,11 +263,11 @@ export const OpenAIApp = () => {
placeholder={t("modal.apiKey.placeholder")}
/>
</Form.Item>
{
provider === "lmstudio" && <div className="text-xs text-gray-600 dark:text-gray-400 mb-4">
{t("modal.tipLMStudio")}
</div>
}
{provider === "lmstudio" && (
<div className="text-xs text-gray-600 dark:text-gray-400 mb-4">
{t("modal.tipLMStudio")}
</div>
)}
<button
type="submit"
className="inline-flex justify-center w-full text-center mt-4 items-center rounded-md border border-transparent bg-black px-2 py-2 text-sm font-medium leading-4 text-white shadow-sm hover:bg-gray-700 focus:outline-none focus:ring-2 focus:ring-indigo-500 focus:ring-offset-2 dark:bg-white dark:text-gray-800 dark:hover:bg-gray-100 dark:focus:ring-gray-500 dark:focus:ring-offset-gray-100 disabled:opacity-50">