diff --git a/backend/apps/chat/task/llm.py b/backend/apps/chat/task/llm.py index 85d89dba..26af4845 100644 --- a/backend/apps/chat/task/llm.py +++ b/backend/apps/chat/task/llm.py @@ -41,6 +41,7 @@ from apps.datasource.embedding.ds_embedding import get_ds_embedding from apps.datasource.models.datasource import CoreDatasource from apps.db.db import exec_sql, get_version, check_connection +from apps.system.crud.aimodel_manage import get_ai_model_list_by_workspace from apps.system.crud.assistant import AssistantOutDs, AssistantOutDsFactory, get_assistant_ds from apps.system.crud.parameter_manage import get_groups from apps.system.schemas.system_schema import AssistantOutDsSchema @@ -176,11 +177,16 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C @classmethod async def create(cls, *args, **kwargs): specialized_model_id = None + _ai_model_list = [] if args[3]: + if args[1]: + ws_id = args[1].oid + _ai_model_list = get_ai_model_list_by_workspace(args[0], ws_id) if args[3].enable_custom_model: if args[3].custom_model: - specialized_model_id = args[3].custom_model - print("use custom model: id[" + args[3].custom_model + "]") + if any(str(model.id) == str(args[3].custom_model) for model in _ai_model_list): + specialized_model_id = args[3].custom_model + print("use custom model: id[" + specialized_model_id + "]") config: LLMConfig = await get_default_config(specialized_model_id) instance = cls(*args, **kwargs, config=config) diff --git a/backend/apps/system/api/aimodel.py b/backend/apps/system/api/aimodel.py index 1cd146cb..1a198e5c 100644 --- a/backend/apps/system/api/aimodel.py +++ b/backend/apps/system/api/aimodel.py @@ -249,11 +249,11 @@ async def update_model_ws_mapping_by_id( return [str(ws_id) for ws_id in ws_ids] -@router.get("/list_by_ws", response_model=AiModelBrief, summary=f"{PLACEHOLDER_PREFIX}system_model_query", +@router.get("/list/by_ws", response_model=List[AiModelBrief], summary=f"{PLACEHOLDER_PREFIX}system_model_query", description=f"{PLACEHOLDER_PREFIX}system_model_query") @require_permissions(permission=SqlbotPermission(role=['ws_admin'])) async def get_model_by_ws( session: SessionDep, current_user: CurrentUser ): - return get_ai_model_list_by_workspace(session, current_user.workspace_id) + return get_ai_model_list_by_workspace(session, current_user.oid) diff --git a/backend/apps/system/models/system_model.py b/backend/apps/system/models/system_model.py index 2bc52780..6b360232 100644 --- a/backend/apps/system/models/system_model.py +++ b/backend/apps/system/models/system_model.py @@ -1,6 +1,8 @@ - from typing import Optional + +from pydantic import field_serializer from sqlmodel import BigInteger, Field, Text, SQLModel + from common.core.models import SnowflakeBase from common.core.schemas import BaseCreatorDTO @@ -9,84 +11,98 @@ class AiModelBase: supplier: int = Field(nullable=False) name: str = Field(max_length=255, nullable=False) model_type: int = Field(nullable=False) - base_model: str = Field(max_length = 255, nullable=False) + base_model: str = Field(max_length=255, nullable=False) default_model: bool = Field(default=False, nullable=False) + class AiModelDetail(SnowflakeBase, AiModelBase, table=True): - __tablename__ = "ai_model" - api_key: str | None = Field(default=None, nullable=True, sa_type=Text()) - api_domain: str = Field(nullable=False, sa_type=Text()) - protocol: int = Field(nullable=False, default = 1) - config: str = Field(sa_type = Text()) - status: int = Field(nullable=False, default = 1) - create_time: int = Field(default=0, sa_type=BigInteger()) - + __tablename__ = "ai_model" + api_key: str | None = Field(default=None, nullable=True, sa_type=Text()) + api_domain: str = Field(nullable=False, sa_type=Text()) + protocol: int = Field(nullable=False, default=1) + config: str = Field(sa_type=Text()) + status: int = Field(nullable=False, default=1) + create_time: int = Field(default=0, sa_type=BigInteger()) + + class AiModelWorkspaceMapping(SnowflakeBase, table=True): __tablename__ = "ai_model_workspace_mapping" ai_model_id: int = Field(default=None, nullable=True, sa_type=BigInteger()) workspace_id: int = Field(default=None, nullable=True, sa_type=BigInteger()) + class AiModelBrief(SQLModel): id: int name: str default_model: bool supplier: int + @field_serializer("id") + def id_to_str(self, v: int) -> str: + return str(v) + + class WorkspaceBase(SQLModel): name: str = Field(max_length=255, nullable=False) + class WorkspaceEditor(WorkspaceBase, BaseCreatorDTO): pass - + + class WorkspaceModel(SnowflakeBase, WorkspaceBase, table=True): __tablename__ = "sys_workspace" create_time: int = Field(default=0, sa_type=BigInteger()) - + + class UserWsBaseModel(SQLModel): uid: int = Field(nullable=False, sa_type=BigInteger()) oid: int = Field(nullable=False, sa_type=BigInteger()) - weight: int = Field(default=0, nullable=False) - + weight: int = Field(default=0, nullable=False) + + class UserWsModel(SnowflakeBase, UserWsBaseModel, table=True): __tablename__ = "sys_user_ws" - + class AssistantBaseModel(SQLModel): name: str = Field(max_length=255, nullable=False) type: int = Field(nullable=False, default=0) domain: str = Field(max_length=255, nullable=False) - description: Optional[str] = Field(sa_type = Text(), nullable=True) - configuration: Optional[str] = Field(sa_type = Text(), nullable=True) + description: Optional[str] = Field(sa_type=Text(), nullable=True) + configuration: Optional[str] = Field(sa_type=Text(), nullable=True) create_time: int = Field(default=0, sa_type=BigInteger()) - app_id: Optional[str] = Field(default=None, max_length=255, nullable=True) + app_id: Optional[str] = Field(default=None, max_length=255, nullable=True) app_secret: Optional[str] = Field(default=None, max_length=255, nullable=True) oid: Optional[int] = Field(nullable=True, sa_type=BigInteger(), default=1) enable_custom_model: Optional[bool] = Field(default=False, nullable=True) custom_model: Optional[str] = Field(default=None, max_length=255, nullable=True) + class AssistantModel(SnowflakeBase, AssistantBaseModel, table=True): __tablename__ = "sys_assistant" - + class AuthenticationBaseModel(SQLModel): name: str = Field(max_length=255, nullable=False) type: int = Field(nullable=False, default=0) - config: Optional[str] = Field(sa_type = Text(), nullable=True) - - + config: Optional[str] = Field(sa_type=Text(), nullable=True) + + class AuthenticationModel(SnowflakeBase, AuthenticationBaseModel, table=True): __tablename__ = "sys_authentication" create_time: Optional[int] = Field(default=0, sa_type=BigInteger()) enable: bool = Field(default=False, nullable=False) valid: bool = Field(default=False, nullable=False) - + class ApiKeyBaseModel(SQLModel): access_key: str = Field(max_length=255, nullable=False) secret_key: str = Field(max_length=255, nullable=False) create_time: int = Field(default=0, sa_type=BigInteger()) - uid: int = Field(default=0,nullable=False, sa_type=BigInteger()) + uid: int = Field(default=0, nullable=False, sa_type=BigInteger()) status: bool = Field(default=True, nullable=False) - + + class ApiKeyModel(SnowflakeBase, ApiKeyBaseModel, table=True): - __tablename__ = "sys_apikey" \ No newline at end of file + __tablename__ = "sys_apikey" diff --git a/frontend/src/api/system.ts b/frontend/src/api/system.ts index f854c608..4d56fecb 100644 --- a/frontend/src/api/system.ts +++ b/frontend/src/api/system.ts @@ -30,4 +30,5 @@ export const modelApi = { platform: (id: number, lazy?: number, pid?: string) => request.post(`/system/platform/org/${id}`, { lazy, pid }), userSync: (data: any) => request.post(`/system/platform/user/sync`, data), + list_by_ws: () => request.get(`/system/aimodel/list/by_ws`), } diff --git a/frontend/src/views/system/embedded/iframe.vue b/frontend/src/views/system/embedded/iframe.vue index cf2fcc8a..292c72e1 100644 --- a/frontend/src/views/system/embedded/iframe.vue +++ b/frontend/src/views/system/embedded/iframe.vue @@ -101,25 +101,23 @@ const dsListOptions = ref([]) const embeddedListWithSearch = computed(() => { if (!keywords.value) return embeddedList.value return embeddedList.value.filter((ele: any) => - ele.name.toLowerCase().includes(keywords.value.toLowerCase()), + ele.name.toLowerCase().includes(keywords.value.toLowerCase()) ) }) interface Model { + id: number name: string - model_type: string - base_model: string - id: string default_model: boolean supplier: number } -const modelList =ref>([]) +const modelList = ref>([]) const searchModels = () => { searchLoading.value = true modelApi - .queryAll() + .list_by_ws() .then((res: any) => { modelList.value = res }) @@ -307,8 +305,8 @@ const validateUrl = (_: any, value: any, callback: any) => { if (value === '') { callback( new Error( - t('datasource.please_enter') + t('common.empty') + t('embedded.cross_domain_settings'), - ), + t('datasource.please_enter') + t('common.empty') + t('embedded.cross_domain_settings') + ) ) } else { // var Expression = /(https?:\/\/)?([\da-z\.-]+)\.([a-z]{2,6})(:\d{1,5})?([\/\w\.-]*)*\/?(#[\S]+)?/ // eslint-disable-line @@ -352,7 +350,7 @@ const dsRules = { const validatePass = (_: any, value: any, callback: any) => { if (value === '') { callback( - new Error(t('datasource.please_enter') + t('common.empty') + t('embedded.interface_url')), + new Error(t('datasource.please_enter') + t('common.empty') + t('embedded.interface_url')) ) } else { // var Expression = /(https?:\/\/)?([\da-z\.-]+)\.([a-z]{2,6})(:\d{1,5})?([\/\w\.-]*)*\/?(#[\S]+)?/ // eslint-disable-line @@ -470,7 +468,7 @@ const saveEmbedded = () => { if (!currentEmbedded.id) { delete obj.id } - if (obj.custom_model == undefined){ + if (obj.custom_model == undefined) { obj.custom_model = '' } req(obj).then(() => { @@ -526,29 +524,29 @@ const handleEmbedded = (row: any) => { } const copyJsCode = () => { copy(jsCodeElement.value) - .then(function() { + .then(function () { ElMessage.success(t('embedded.copy_successful')) }) - .catch(function() { + .catch(function () { ElMessage.error(t('embedded.copy_failed')) }) } const copyJsCodeFull = () => { copy(jsCodeElementFull.value) - .then(function() { + .then(function () { ElMessage.success(t('embedded.copy_successful')) }) - .catch(function() { + .catch(function () { ElMessage.error(t('embedded.copy_failed')) }) } const copyCode = () => { copy(scriptElement.value) - .then(function() { + .then(function () { ElMessage.success(t('embedded.copy_successful')) }) - .catch(function() { + .catch(function () { ElMessage.error(t('embedded.copy_failed')) }) } @@ -821,13 +819,16 @@ const saveHandler = () => { - - {{t('embedded.enableCustomModel')}} + + {{ t('embedded.enableCustomModel') }} - + { /> - @@ -1013,7 +1013,7 @@ const saveHandler = () => {
{{ t('embedded.set_data_source') }} {{ $t('embedded.open_the_query') }} + >{{ $t('embedded.open_the_query') }}