Files
AgentCoord/backend/db/crud.py

578 lines
18 KiB
Python
Raw Normal View History

"""
数据库 CRUD 操作
封装所有数据库操作方法 (基于 DATABASE_DESIGN.md)
"""
import copy
import uuid
from datetime import datetime, timezone
from typing import List, Optional
from sqlalchemy.orm import Session
2026-03-12 13:35:04 +08:00
from .models import MultiAgentTask, UserAgent, ExportRecord, PlanShare
class MultiAgentTaskCRUD:
    """CRUD operations for ``MultiAgentTask`` records.

    Every write commits the session and refreshes the instance before
    returning it. Update methods return ``None`` when no task matches
    ``task_id``.
    """

    @staticmethod
    def create(
        db: Session,
        task_id: Optional[str] = None,  # optional; generated when falsy
        user_id: str = "",
        query: str = "",
        agents_info: Optional[list] = None,
        task_outline: Optional[dict] = None,
        assigned_agents: Optional[list] = None,
        agent_scores: Optional[dict] = None,
        result: Optional[str] = None,
    ) -> MultiAgentTask:
        """Create, commit and return a new task record.

        Args:
            db: Session used for the insert; committed by this call.
            task_id: Task identifier; ``str(uuid.uuid4())`` is used when it
                is ``None`` or empty.
            user_id: Stored in ``user_id`` (the column filtered by
                ``get_by_user_id`` / ``get_recent``).
            query: Stored in the ``query`` column.
            agents_info: Stored in ``agents_info``; an empty list when omitted.
            task_outline: Initial ``task_outline`` value.
            assigned_agents: Initial ``assigned_agents`` value.
            agent_scores: Initial ``agent_scores`` value.
            result: Initial ``result`` value.

        Returns:
            The persisted task, refreshed from the database.
        """
        task = MultiAgentTask(
            task_id=task_id or str(uuid.uuid4()),
            user_id=user_id,
            query=query,
            # Fresh list per call: the former ``agents_info=[]`` default was a
            # single list object shared by every call that omitted it.
            agents_info=agents_info if agents_info is not None else [],
            task_outline=task_outline,
            assigned_agents=assigned_agents,
            agent_scores=agent_scores,
            result=result,
        )
        db.add(task)
        db.commit()
        db.refresh(task)
        return task

    @staticmethod
    def _update_fields(
        db: Session, task_id: str, **values
    ) -> Optional[MultiAgentTask]:
        """Set each ``values`` item as an attribute of the task, then commit.

        Returns the refreshed task, or ``None`` if ``task_id`` does not exist.
        """
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task is None:
            return None
        for field_name, value in values.items():
            setattr(task, field_name, value)
        db.commit()
        db.refresh(task)
        return task

    @staticmethod
    def get_by_id(db: Session, task_id: str) -> Optional[MultiAgentTask]:
        """Return the task with the given ``task_id``, or ``None``."""
        return db.query(MultiAgentTask).filter(MultiAgentTask.task_id == task_id).first()

    @staticmethod
    def get_by_user_id(
        db: Session, user_id: str, limit: int = 50, offset: int = 0
    ) -> List[MultiAgentTask]:
        """Return a page of the user's tasks, newest ``created_at`` first."""
        return (
            db.query(MultiAgentTask)
            .filter(MultiAgentTask.user_id == user_id)
            .order_by(MultiAgentTask.created_at.desc())
            .offset(offset)
            .limit(limit)
            .all()
        )

    @staticmethod
    def get_recent(
        db: Session, limit: int = 20, offset: int = 0, user_id: Optional[str] = None
    ) -> List[MultiAgentTask]:
        """Return a page of recent tasks, pinned tasks first, then newest first.

        When ``user_id`` is falsy, tasks of all users are included.
        """
        query = db.query(MultiAgentTask)
        if user_id:
            query = query.filter(MultiAgentTask.user_id == user_id)
        return (
            query.order_by(
                MultiAgentTask.is_pinned.desc(), MultiAgentTask.created_at.desc()
            )
            .offset(offset)
            .limit(limit)
            .all()
        )

    @staticmethod
    def update_result(
        db: Session, task_id: str, result: list
    ) -> Optional[MultiAgentTask]:
        """Replace the task result; a falsy ``result`` is stored as ``[]``."""
        return MultiAgentTaskCRUD._update_fields(
            db, task_id, result=result if result else []
        )

    @staticmethod
    def update_task_outline(
        db: Session, task_id: str, task_outline: dict
    ) -> Optional[MultiAgentTask]:
        """Replace the task outline."""
        return MultiAgentTaskCRUD._update_fields(db, task_id, task_outline=task_outline)

    @staticmethod
    def update_assigned_agents(
        db: Session, task_id: str, assigned_agents: dict
    ) -> Optional[MultiAgentTask]:
        """Replace the agent assignment (step name -> list of agents)."""
        return MultiAgentTaskCRUD._update_fields(
            db, task_id, assigned_agents=assigned_agents
        )

    @staticmethod
    def update_agent_scores(
        db: Session, task_id: str, agent_scores: dict
    ) -> Optional[MultiAgentTask]:
        """Merge ``agent_scores`` into the stored scores.

        Keys in ``agent_scores`` override or extend the stored ones; stored
        keys that are not mentioned are kept. ``None`` merges nothing.
        """
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task is None:
            return None
        # Assign a NEW dict so the ORM always sees a changed value.
        task.agent_scores = {**(task.agent_scores or {}), **(agent_scores or {})}
        db.commit()
        db.refresh(task)
        return task

    @staticmethod
    def update_status(
        db: Session, task_id: str, status: str
    ) -> Optional[MultiAgentTask]:
        """Set the task status."""
        return MultiAgentTaskCRUD._update_fields(db, task_id, status=status)

    @staticmethod
    def increment_execution_count(db: Session, task_id: str) -> Optional[MultiAgentTask]:
        """Increase ``execution_count`` by one (``None`` is treated as 0).

        NOTE(review): this is a read-modify-write in Python, so concurrent
        calls for the same task can lose increments.
        """
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task is None:
            return None
        task.execution_count = (task.execution_count or 0) + 1
        db.commit()
        db.refresh(task)
        return task

    @staticmethod
    def update_generation_id(
        db: Session, task_id: str, generation_id: str
    ) -> Optional[MultiAgentTask]:
        """Set the generation ID."""
        return MultiAgentTaskCRUD._update_fields(db, task_id, generation_id=generation_id)

    @staticmethod
    def update_execution_id(
        db: Session, task_id: str, execution_id: str
    ) -> Optional[MultiAgentTask]:
        """Set the execution ID."""
        return MultiAgentTaskCRUD._update_fields(db, task_id, execution_id=execution_id)

    @staticmethod
    def update_rehearsal_log(
        db: Session, task_id: str, rehearsal_log: list
    ) -> Optional[MultiAgentTask]:
        """Replace the rehearsal log; a falsy value is stored as ``[]``."""
        return MultiAgentTaskCRUD._update_fields(
            db, task_id, rehearsal_log=rehearsal_log if rehearsal_log else []
        )

    @staticmethod
    def update_is_pinned(
        db: Session, task_id: str, is_pinned: bool
    ) -> Optional[MultiAgentTask]:
        """Set whether the task is pinned (pinned tasks sort first in ``get_recent``)."""
        return MultiAgentTaskCRUD._update_fields(db, task_id, is_pinned=is_pinned)

    @staticmethod
    def append_rehearsal_log(
        db: Session, task_id: str, log_entry: dict
    ) -> Optional[MultiAgentTask]:
        """Append ``log_entry`` to the rehearsal log.

        If the stored log is not a list (e.g. ``None``), it is replaced by a
        one-element list.
        """
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task is None:
            return None
        current_log = task.rehearsal_log
        previous = list(current_log) if isinstance(current_log, list) else []
        # Assign a NEW list rather than appending in place. Re-assigning the
        # same mutated list gives SQLAlchemy no old/new difference on a JSON
        # column without mutation tracking, so the change could be dropped at
        # flush. (The column definition is not visible here.)
        task.rehearsal_log = previous + [log_entry]
        db.commit()
        db.refresh(task)
        return task

    @staticmethod
    def update_branches(
        db: Session, task_id: str, branches
    ) -> Optional[MultiAgentTask]:
        """Store branch data for a task.

        Two formats are accepted:

        - Any non-dict value (legacy list format) replaces the stored value;
          a falsy value is stored as ``[]``.
        - A dict ``{"flow_branches": [...], "task_process_branches": {...}}``
          keeps the two keys independent. If only one key is supplied and the
          stored value is a dict, the other key is carried over from it.

        The caller's ``branches`` dict is never modified.
        """
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task is None:
            return None
        if isinstance(branches, dict):
            # Shallow copy: back-filling keys must not leak into the caller's dict.
            merged = dict(branches)
            # Deep copy so carried-over values do not share nested objects
            # with the previously stored value.
            existing = copy.deepcopy(task.branches) if task.branches else {}
            if isinstance(existing, dict):
                if "flow_branches" in merged and "task_process_branches" not in merged:
                    merged["task_process_branches"] = existing.get(
                        "task_process_branches", {}
                    )
                if "task_process_branches" in merged and "flow_branches" not in merged:
                    merged["flow_branches"] = existing.get("flow_branches", [])
            task.branches = merged
        else:
            task.branches = branches if branches else []
        db.commit()
        db.refresh(task)
        return task

    @staticmethod
    def get_branches(db: Session, task_id: str) -> Optional[list]:
        """Return the stored branch data, or ``[]`` if unset or the task is missing.

        NOTE(review): despite the ``list`` annotation, this returns a dict
        when the data was saved in ``update_branches``' dict format.
        """
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task is None:
            return []
        return task.branches or []

    @staticmethod
    def get_by_status(
        db: Session, status: str, limit: int = 50, offset: int = 0
    ) -> List[MultiAgentTask]:
        """Return a page of tasks with the given status, newest first."""
        return (
            db.query(MultiAgentTask)
            .filter(MultiAgentTask.status == status)
            .order_by(MultiAgentTask.created_at.desc())
            .offset(offset)
            .limit(limit)
            .all()
        )

    @staticmethod
    def get_by_generation_id(
        db: Session, generation_id: str
    ) -> List[MultiAgentTask]:
        """Return all tasks with the given generation ID."""
        return (
            db.query(MultiAgentTask)
            .filter(MultiAgentTask.generation_id == generation_id)
            .all()
        )

    @staticmethod
    def get_by_execution_id(
        db: Session, execution_id: str
    ) -> List[MultiAgentTask]:
        """Return all tasks with the given execution ID."""
        return (
            db.query(MultiAgentTask)
            .filter(MultiAgentTask.execution_id == execution_id)
            .all()
        )

    @staticmethod
    def delete(db: Session, task_id: str) -> bool:
        """Delete the task; return ``True`` if it existed, else ``False``."""
        task = MultiAgentTaskCRUD.get_by_id(db, task_id)
        if task is None:
            return False
        db.delete(task)
        db.commit()
        return True


class UserAgentCRUD:
    """CRUD operations for per-user agent configurations (``UserAgent`` rows)."""

    @staticmethod
    def create(
        db: Session,
        user_id: str,
        agent_name: str,
        agent_config: dict,
    ) -> UserAgent:
        """Insert a new agent configuration whose ID is a random UUID4 string.

        The session is committed and the returned instance is refreshed.
        """
        agent = UserAgent(
            id=str(uuid.uuid4()),
            user_id=user_id,
            agent_name=agent_name,
            agent_config=agent_config,
        )
        db.add(agent)
        db.commit()
        db.refresh(agent)
        return agent

    @staticmethod
    def get_by_id(db: Session, agent_id: str) -> Optional[UserAgent]:
        """Return the configuration whose ``id`` is ``agent_id``, or ``None``."""
        return db.query(UserAgent).filter(UserAgent.id == agent_id).first()

    @staticmethod
    def get_by_user_id(
        db: Session, user_id: str, limit: int = 50
    ) -> List[UserAgent]:
        """Return up to ``limit`` configurations of a user, newest first."""
        return (
            db.query(UserAgent)
            .filter(UserAgent.user_id == user_id)
            .order_by(UserAgent.created_at.desc())
            .limit(limit)
            .all()
        )

    @staticmethod
    def get_by_name(
        db: Session, user_id: str, agent_name: str
    ) -> List[UserAgent]:
        """Return all configurations matching both ``user_id`` and ``agent_name``."""
        return (
            db.query(UserAgent)
            .filter(
                UserAgent.user_id == user_id,
                UserAgent.agent_name == agent_name,
            )
            .all()
        )

    @staticmethod
    def update_config(
        db: Session, agent_id: str, agent_config: dict
    ) -> Optional[UserAgent]:
        """Replace ``agent_config``; return ``None`` if ``agent_id`` does not exist."""
        agent = UserAgentCRUD.get_by_id(db, agent_id)
        if agent is None:
            return None
        agent.agent_config = agent_config
        db.commit()
        db.refresh(agent)
        return agent

    @staticmethod
    def delete(db: Session, agent_id: str) -> bool:
        """Delete the configuration; return ``True`` if it existed, else ``False``."""
        agent = UserAgentCRUD.get_by_id(db, agent_id)
        if agent is None:
            return False
        db.delete(agent)
        db.commit()
        return True

    @staticmethod
    def upsert(
        db: Session,
        user_id: str,
        agent_name: str,
        agent_config: dict,
    ) -> UserAgent:
        """Update or insert the configuration identified by (user_id, agent_name).

        If a row with this ``user_id`` and ``agent_name`` exists, its
        ``agent_config`` is replaced. Otherwise a new row is created through
        ``create``, so both paths share the same ID generation.

        NOTE(review): this is read-then-write, not an atomic upsert. Two
        concurrent calls may both insert unless the table enforces uniqueness
        on (user_id, agent_name), which is not visible here.
        """
        existing = (
            db.query(UserAgent)
            .filter(
                UserAgent.user_id == user_id,
                UserAgent.agent_name == agent_name,
            )
            .first()
        )
        if existing is None:
            return UserAgentCRUD.create(db, user_id, agent_name, agent_config)
        existing.agent_config = agent_config
        db.commit()
        db.refresh(existing)
        return existing


class ExportRecordCRUD:
    """Create, query and delete ``ExportRecord`` rows (export history of tasks)."""

    @staticmethod
    def create(
        db: Session,
        task_id: str,
        user_id: str,
        export_type: str,
        file_name: str,
        file_path: str,
        file_url: str = "",
        file_size: int = 0,
    ) -> ExportRecord:
        """Persist one export record and return it after commit and refresh."""
        columns = dict(
            task_id=task_id,
            user_id=user_id,
            export_type=export_type,
            file_name=file_name,
            file_path=file_path,
            file_url=file_url,
            file_size=file_size,
        )
        new_record = ExportRecord(**columns)
        db.add(new_record)
        db.commit()
        db.refresh(new_record)
        return new_record

    @staticmethod
    def get_by_id(db: Session, record_id: int) -> Optional[ExportRecord]:
        """Look up a record by primary key; ``None`` when absent."""
        matching = db.query(ExportRecord).filter(ExportRecord.id == record_id)
        return matching.first()

    @staticmethod
    def get_by_task_id(
        db: Session, task_id: str, limit: int = 50
    ) -> List[ExportRecord]:
        """List at most ``limit`` records of one task, most recent first."""
        stmt = db.query(ExportRecord).filter(ExportRecord.task_id == task_id)
        stmt = stmt.order_by(ExportRecord.created_at.desc()).limit(limit)
        return stmt.all()

    @staticmethod
    def get_by_user_id(
        db: Session, user_id: str, limit: int = 50
    ) -> List[ExportRecord]:
        """List at most ``limit`` records of one user, most recent first."""
        stmt = db.query(ExportRecord).filter(ExportRecord.user_id == user_id)
        stmt = stmt.order_by(ExportRecord.created_at.desc()).limit(limit)
        return stmt.all()

    @staticmethod
    def delete(db: Session, record_id: int) -> bool:
        """Remove one record; report whether anything was deleted."""
        target = ExportRecordCRUD.get_by_id(db, record_id)
        if not target:
            return False
        db.delete(target)
        db.commit()
        return True

    @staticmethod
    def delete_by_task_id(db: Session, task_id: str) -> bool:
        """Remove every record of a task; ``False`` when the task had none."""
        doomed = db.query(ExportRecord).filter(ExportRecord.task_id == task_id).all()
        if not doomed:
            return False
        for row in doomed:
            db.delete(row)
        db.commit()
        return True


class PlanShareCRUD:
    """Create, query and delete ``PlanShare`` rows (shared task snapshots)."""

    @staticmethod
    def create(
        db: Session,
        share_token: str,
        task_id: str,
        task_data: dict,
        expires_at: Optional[datetime] = None,
        extraction_code: Optional[str] = None,
    ) -> PlanShare:
        """Persist a share record keyed by ``share_token`` and return it refreshed."""
        columns = dict(
            share_token=share_token,
            extraction_code=extraction_code,
            task_id=task_id,
            task_data=task_data,
            expires_at=expires_at,
        )
        new_share = PlanShare(**columns)
        db.add(new_share)
        db.commit()
        db.refresh(new_share)
        return new_share

    @staticmethod
    def get_by_token(db: Session, share_token: str) -> Optional[PlanShare]:
        """Look up a share by its token; ``None`` when absent."""
        matching = db.query(PlanShare).filter(PlanShare.share_token == share_token)
        return matching.first()

    @staticmethod
    def get_by_task_id(
        db: Session, task_id: str, limit: int = 10
    ) -> List[PlanShare]:
        """List at most ``limit`` shares of one task, most recent first."""
        stmt = db.query(PlanShare).filter(PlanShare.task_id == task_id)
        stmt = stmt.order_by(PlanShare.created_at.desc()).limit(limit)
        return stmt.all()

    @staticmethod
    def increment_view_count(db: Session, share_token: str) -> Optional[PlanShare]:
        """Add one to ``view_count`` (``None`` counts as 0) and return the share.

        NOTE(review): the increment is computed in Python (read-modify-write),
        so simultaneous views of the same share can lose counts.
        """
        found = PlanShareCRUD.get_by_token(db, share_token)
        if not found:
            return found
        found.view_count = (found.view_count or 0) + 1
        db.commit()
        db.refresh(found)
        return found

    @staticmethod
    def delete(db: Session, share_token: str) -> bool:
        """Remove one share; report whether anything was deleted."""
        target = PlanShareCRUD.get_by_token(db, share_token)
        if not target:
            return False
        db.delete(target)
        db.commit()
        return True

    @staticmethod
    def delete_by_task_id(db: Session, task_id: str) -> bool:
        """Remove every share of a task; ``False`` when the task had none."""
        doomed = db.query(PlanShare).filter(PlanShare.task_id == task_id).all()
        if not doomed:
            return False
        for row in doomed:
            db.delete(row)
        db.commit()
        return True