diff --git a/backend/apps/system/api/aimodel.py b/backend/apps/system/api/aimodel.py index fe4df740..620adac7 100644 --- a/backend/apps/system/api/aimodel.py +++ b/backend/apps/system/api/aimodel.py @@ -359,4 +359,4 @@ async def get_model_by_ws( session: SessionDep, current_user: CurrentUser ): - return get_ai_model_list_by_workspace(session, current_user.oid) + return get_ai_model_list_by_workspace(session, current_user.oid, False) diff --git a/backend/apps/system/crud/aimodel_manage.py b/backend/apps/system/crud/aimodel_manage.py index 4ff8006c..b0fabdbd 100644 --- a/backend/apps/system/crud/aimodel_manage.py +++ b/backend/apps/system/crud/aimodel_manage.py @@ -31,7 +31,7 @@ async def async_model_info(): SQLBotLogUtil.info("✅ 异步加密已有模型的密钥和地址完成") -def get_ai_model_list_by_workspace(session: Session, workspace_id: int): +def get_ai_model_list_by_workspace(session: Session, workspace_id: int, with_default: bool = True): sub_stmt = ( select(AiModelWorkspaceMapping.ai_model_id) .where(AiModelWorkspaceMapping.workspace_id == workspace_id) @@ -39,6 +39,11 @@ def get_ai_model_list_by_workspace(session: Session, workspace_id: int): ) # 查询:关联的模型 + default_model 为 True 的模型,默认模型排第一 + base_condition = AiModelDetail.id.in_(sub_stmt) + if with_default: + where_condition = or_(base_condition, AiModelDetail.default_model == True) + else: + where_condition = base_condition stmt = ( select( AiModelDetail.id, @@ -46,12 +51,7 @@ def get_ai_model_list_by_workspace(session: Session, workspace_id: int): AiModelDetail.default_model, AiModelDetail.supplier, ) - .where( - or_( - AiModelDetail.id.in_(sub_stmt), - AiModelDetail.default_model == True, - ) - ) + .where(where_condition) .order_by(AiModelDetail.default_model.desc()) ) rows = session.exec(stmt).all()