Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 109 additions & 28 deletions backend/apps/system/api/aimodel.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import json
from typing import List, Union

from fastapi import APIRouter, Path, Query, Body
from fastapi.responses import StreamingResponse
from sqlmodel import func, select, update, delete

from apps.ai_model.model_factory import LLMConfig, LLMFactory
from apps.swagger.i18n import PLACEHOLDER_PREFIX
from apps.system.crud.aimodel_manage import get_ai_model_list_by_workspace
from apps.system.models.system_model import AiModelDetail, AiModelWorkspaceMapping
from apps.system.schemas.ai_model_schema import AiModelConfigItem, AiModelCreator, AiModelEditor, AiModelGridItem
from fastapi import APIRouter, Path, Query
from sqlmodel import func, select, update

from apps.system.models.system_model import AiModelDetail
from apps.system.schemas.permission import SqlbotPermission, require_permissions
from common.core.deps import SessionDep, Trans
from common.core.deps import SessionDep, Trans, CurrentUser
from common.utils.crypto import sqlbot_decrypt
from common.utils.time import get_timestamp
from common.utils.utils import SQLBotLogUtil, prepare_model_arg
Expand All @@ -19,12 +20,14 @@
from common.audit.models.log_model import OperationType, OperationModules
from common.audit.schemas.logger_decorator import LogConfig, system_log


@router.post("/status", include_in_schema=False)
@require_permissions(permission=SqlbotPermission(role=['admin']))
@require_permissions(permission=SqlbotPermission(role=['admin']))
async def check_llm(info: AiModelCreator, trans: Trans):
async def generate():
try:
additional_params = {item.key: prepare_model_arg(item.val) for item in info.config_list if item.key and item.val}
additional_params = {item.key: prepare_model_arg(item.val) for item in info.config_list if
item.key and item.val}
config = LLMConfig(
model_type="openai" if info.protocol == 1 else "vllm",
model_name=info.base_model,
Expand All @@ -39,23 +42,26 @@ async def generate():
yield json.dumps({"content": chunk}) + "\n"
if chunk and isinstance(chunk, dict) and chunk.content:
yield json.dumps({"content": chunk.content}) + "\n"

except Exception as e:
SQLBotLogUtil.error(f"Error checking LLM: {e}")
error_msg = trans('i18n_llm.validate_error', msg=str(e))
yield json.dumps({"error": error_msg}) + "\n"

return StreamingResponse(generate(), media_type="application/x-ndjson")


@router.get("/default", include_in_schema=False)
async def check_default(session: SessionDep, trans: Trans):
db_model = session.exec(
select(AiModelDetail).where(AiModelDetail.default_model == True)
).first()
if not db_model:
raise Exception(trans('i18n_llm.miss_default'))

@router.put("/default/{id}", summary=f"{PLACEHOLDER_PREFIX}system_model_default", description=f"{PLACEHOLDER_PREFIX}system_model_default")


@router.put("/default/{id}", summary=f"{PLACEHOLDER_PREFIX}system_model_default",
description=f"{PLACEHOLDER_PREFIX}system_model_default")
@require_permissions(permission=SqlbotPermission(role=['admin']))
@system_log(LogConfig(operation_type=OperationType.UPDATE, module=OperationModules.AI_MODEL, resource_id_expr="id"))
async def set_default(session: SessionDep, id: int = Path(description="ID")):
Expand All @@ -76,27 +82,31 @@ async def set_default(session: SessionDep, id: int = Path(description="ID")):
session.rollback()
raise e

@router.get("", response_model=list[AiModelGridItem], summary=f"{PLACEHOLDER_PREFIX}system_model_grid", description=f"{PLACEHOLDER_PREFIX}system_model_grid")
@require_permissions(permission=SqlbotPermission(role=['admin']))

@router.get("", response_model=list[AiModelGridItem], summary=f"{PLACEHOLDER_PREFIX}system_model_grid",
description=f"{PLACEHOLDER_PREFIX}system_model_grid")
@require_permissions(permission=SqlbotPermission(role=['admin']))
async def query(
session: SessionDep,
keyword: Union[str, None] = Query(default=None, max_length=255, description=f"{PLACEHOLDER_PREFIX}keyword")
):
statement = select(AiModelDetail.id,
AiModelDetail.name,
AiModelDetail.model_type,
AiModelDetail.base_model,
statement = select(AiModelDetail.id,
AiModelDetail.name,
AiModelDetail.model_type,
AiModelDetail.base_model,
AiModelDetail.supplier,
AiModelDetail.protocol,
AiModelDetail.protocol,
AiModelDetail.default_model)
if keyword is not None:
statement = statement.where(AiModelDetail.name.like(f"%{keyword}%"))
statement = statement.order_by(AiModelDetail.default_model.desc(), AiModelDetail.name, AiModelDetail.create_time)
items = session.exec(statement).all()
return items

@router.get("/{id}", response_model=AiModelEditor, summary=f"{PLACEHOLDER_PREFIX}system_model_query", description=f"{PLACEHOLDER_PREFIX}system_model_query")
@require_permissions(permission=SqlbotPermission(role=['admin']))

@router.get("/{id}", response_model=AiModelEditor, summary=f"{PLACEHOLDER_PREFIX}system_model_query",
description=f"{PLACEHOLDER_PREFIX}system_model_query")
@require_permissions(permission=SqlbotPermission(role=['admin']))
async def get_model_by_id(
session: SessionDep,
id: int = Path(description="ID")
Expand Down Expand Up @@ -124,7 +134,9 @@ async def get_model_by_id(
data["config_list"] = config_list
return AiModelEditor(**data)

@router.post("", summary=f"{PLACEHOLDER_PREFIX}system_model_create", description=f"{PLACEHOLDER_PREFIX}system_model_create")

@router.post("", summary=f"{PLACEHOLDER_PREFIX}system_model_create",
description=f"{PLACEHOLDER_PREFIX}system_model_create")
@require_permissions(permission=SqlbotPermission(role=['admin']))
@system_log(LogConfig(operation_type=OperationType.CREATE, module=OperationModules.AI_MODEL, result_id_expr="id"))
async def add_model(
Expand All @@ -143,9 +155,12 @@ async def add_model(
session.commit()
return detail

@router.put("", summary=f"{PLACEHOLDER_PREFIX}system_model_update", description=f"{PLACEHOLDER_PREFIX}system_model_update")

@router.put("", summary=f"{PLACEHOLDER_PREFIX}system_model_update",
description=f"{PLACEHOLDER_PREFIX}system_model_update")
@require_permissions(permission=SqlbotPermission(role=['admin']))
@system_log(LogConfig(operation_type=OperationType.UPDATE, module=OperationModules.AI_MODEL, resource_id_expr="editor.id"))
@system_log(
LogConfig(operation_type=OperationType.UPDATE, module=OperationModules.AI_MODEL, resource_id_expr="editor.id"))
async def update_model(
session: SessionDep,
editor: AiModelEditor
Expand All @@ -155,12 +170,14 @@ async def update_model(
data["config"] = json.dumps([item.model_dump(exclude_unset=True) for item in editor.config_list])
data.pop("config_list", None)
db_model = session.get(AiModelDetail, id)
#update_data = AiModelDetail.model_validate(data)
# update_data = AiModelDetail.model_validate(data)
db_model.sqlmodel_update(data)
session.add(db_model)
session.commit()

@router.delete("/{id}", summary=f"{PLACEHOLDER_PREFIX}system_model_del", description=f"{PLACEHOLDER_PREFIX}system_model_del")

@router.delete("/{id}", summary=f"{PLACEHOLDER_PREFIX}system_model_del",
description=f"{PLACEHOLDER_PREFIX}system_model_del")
@require_permissions(permission=SqlbotPermission(role=['admin']))
@system_log(LogConfig(operation_type=OperationType.DELETE, module=OperationModules.AI_MODEL, resource_id_expr="id"))
async def delete_model(
Expand All @@ -170,9 +187,73 @@ async def delete_model(
):
item = session.get(AiModelDetail, id)
if item.default_model:
raise Exception(trans('i18n_llm.delete_default_error', key = item.name))
raise Exception(trans('i18n_llm.delete_default_error', key=item.name))
session.delete(item)
session.commit()




@router.get("/{id}/ws_mapping", response_model=AiModelEditor, summary=f"{PLACEHOLDER_PREFIX}system_model_query",
description=f"{PLACEHOLDER_PREFIX}system_model_query")
@require_permissions(permission=SqlbotPermission(role=['admin']))
async def get_model_ws_mapping_by_id(
session: SessionDep,
id: int = Path(description="ID")
):
db_model = session.get(AiModelDetail, id)
if not db_model:
raise ValueError(f"AiModelDetail with id {id} not found")

# 根据 ai_model_id 查询关联的 workspace_id 列表
stmt = (
select(AiModelWorkspaceMapping.workspace_id)
.where(AiModelWorkspaceMapping.ai_model_id == id)
.distinct()
)
ws_ids: List[int] = session.exec(stmt).all()

return ws_ids


@router.put("/{id}/ws_mapping", response_model=AiModelEditor, summary=f"{PLACEHOLDER_PREFIX}system_model_query",
description=f"{PLACEHOLDER_PREFIX}system_model_query")
@require_permissions(permission=SqlbotPermission(role=['admin']))
async def update_model_ws_mapping_by_id(
session: SessionDep,
id: int = Path(description="ID"),
ws_ids: List[int] = Body(description="workspace id list"),
):
if ws_ids is None:
ws_ids = []
# 提前去重
ws_ids = list(set(ws_ids))

db_model = session.get(AiModelDetail, id)
if not db_model:
raise ValueError(f"AiModelDetail with id {id} not found")

# 根据 ai_model_id 更新关联的 workspace_id 列表
# 1. 批量删除旧映射
session.execute(
delete(AiModelWorkspaceMapping)
.where(AiModelWorkspaceMapping.ai_model_id == id)
)

# 2. 插入去重后的映射关系
for ws_id in ws_ids:
session.add(
AiModelWorkspaceMapping(ai_model_id=id, workspace_id=ws_id)
)

session.commit()

return ws_ids


@router.get("/list_by_ws", response_model=AiModelEditor, summary=f"{PLACEHOLDER_PREFIX}system_model_query",
description=f"{PLACEHOLDER_PREFIX}system_model_query")
@require_permissions(permission=SqlbotPermission(role=['admin']))
async def get_model_by_ws(
session: SessionDep,
current_user: CurrentUser
):
return get_ai_model_list_by_workspace(session, current_user.workspace_id)
46 changes: 40 additions & 6 deletions backend/apps/system/crud/aimodel_manage.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from sqlmodel import Session, select, or_

from apps.system.models.system_model import AiModelDetail
from apps.system.models.system_model import AiModelDetail, AiModelBrief, AiModelWorkspaceMapping
from common.core.db import engine
from sqlmodel import Session, select
from common.utils.crypto import sqlbot_encrypt
from common.utils.utils import SQLBotLogUtil


async def async_model_info():
with Session(engine) as session:
model_list = session.exec(select(AiModelDetail)).all()
Expand All @@ -27,7 +28,40 @@ async def async_model_info():
session.add(model)
if any_model_change:
session.commit()
SQLBotLogUtil.info("✅ 异步加密已有模型的密钥和地址完成")



SQLBotLogUtil.info("✅ 异步加密已有模型的密钥和地址完成")


def get_ai_model_list_by_workspace(session: Session, workspace_id: int):
sub_stmt = (
select(AiModelWorkspaceMapping.ai_model_id)
.where(AiModelWorkspaceMapping.workspace_id == workspace_id)
.distinct()
)

# 查询:关联的模型 + default_model 为 True 的模型,默认模型排第一
stmt = (
select(
AiModelDetail.id,
AiModelDetail.name,
AiModelDetail.default_model,
AiModelDetail.supplier,
)
.where(
or_(
AiModelDetail.id.in_(sub_stmt),
AiModelDetail.default_model == True,
)
)
.order_by(AiModelDetail.default_model.desc())
)
rows = session.exec(stmt).all()

return [
AiModelBrief(
id=row[0],
name=row[1],
default_model=row[2],
supplier=row[3],
)
for row in rows
]
10 changes: 9 additions & 1 deletion backend/apps/system/models/system_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,16 @@ class AiModelDetail(SnowflakeBase, AiModelBase, table=True):
status: int = Field(nullable=False, default = 1)
create_time: int = Field(default=0, sa_type=BigInteger())

class AiModelWorkspaceMapping(SnowflakeBase, table=True):
__tablename__ = "ai_model_workspace_mapping"
ai_model_id: int = Field(default=None, nullable=True, sa_type=BigInteger())
workspace_id: int = Field(default=None, nullable=True, sa_type=BigInteger())


class AiModelBrief(SQLModel):
id: int
name: str
default_model: bool
supplier: int

class WorkspaceBase(SQLModel):
name: str = Field(max_length=255, nullable=False)
Expand Down
Loading