diff --git a/plugins/filters/async-context-compression/async_context_compression_cn.py b/plugins/filters/async-context-compression/async_context_compression_cn.py index 74748be..1d970e6 100644 --- a/plugins/filters/async-context-compression/async_context_compression_cn.py +++ b/plugins/filters/async-context-compression/async_context_compression_cn.py @@ -254,23 +254,25 @@ show_debug_log (前端调试日志) from pydantic import BaseModel, Field, model_validator from typing import Optional, Dict, Any, List, Union, Callable, Awaitable +import re import asyncio import json import hashlib import time -import re +import contextlib # Open WebUI 内置导入 from open_webui.utils.chat import generate_chat_completion -from open_webui.models.models import Models from open_webui.models.users import Users +from open_webui.models.models import Models from fastapi.requests import Request from open_webui.main import app as webui_app # Open WebUI 内部数据库 (复用共享连接) -from open_webui.internal.db import engine as owui_engine -from open_webui.internal.db import Session as owui_Session -from open_webui.internal.db import Base as owui_Base +try: + from open_webui.internal import db as owui_db +except ModuleNotFoundError: # pragma: no cover - filter runs inside Open WebUI + owui_db = None # 尝试导入 tiktoken try: @@ -280,14 +282,91 @@ except ImportError: # 数据库导入 from sqlalchemy import Column, String, Text, DateTime, Integer, inspect +from sqlalchemy.orm import declarative_base, sessionmaker +from sqlalchemy.engine import Engine from datetime import datetime +def _discover_owui_engine(db_module: Any) -> Optional[Engine]: + """Discover the Open WebUI SQLAlchemy engine via provided db module helpers.""" + if db_module is None: + return None + + db_context = getattr(db_module, "get_db_context", None) or getattr( + db_module, "get_db", None + ) + if callable(db_context): + try: + with db_context() as session: + try: + return session.get_bind() + except AttributeError: + return getattr(session, "bind", None) or getattr( + session, "engine", None + ) + except Exception as exc: + print(f"[DB Discover] get_db_context failed: {exc}") + + for attr in ("engine", "ENGINE", "bind", "BIND"): + candidate = getattr(db_module, attr, None) + if candidate is not None: + return candidate + + return None + + +def _discover_owui_schema(db_module: Any) -> Optional[str]: + """Discover the Open WebUI database schema name if configured.""" + if db_module is None: + return None + + try: + base = getattr(db_module, "Base", None) + metadata = getattr(base, "metadata", None) if base is not None else None + candidate = getattr(metadata, "schema", None) if metadata is not None else None + if isinstance(candidate, str) and candidate.strip(): + return candidate.strip() + except Exception as exc: + print(f"[DB Discover] Base metadata schema lookup failed: {exc}") + + try: + metadata_obj = getattr(db_module, "metadata_obj", None) + candidate = ( + getattr(metadata_obj, "schema", None) if metadata_obj is not None else None + ) + if isinstance(candidate, str) and candidate.strip(): + return candidate.strip() + except Exception as exc: + print(f"[DB Discover] metadata_obj schema lookup failed: {exc}") + + try: + from open_webui import env as owui_env + + candidate = getattr(owui_env, "DATABASE_SCHEMA", None) + if isinstance(candidate, str) and candidate.strip(): + return candidate.strip() + except Exception as exc: + print(f"[DB Discover] env schema lookup failed: {exc}") + + return None + + +owui_engine = _discover_owui_engine(owui_db) +owui_schema = _discover_owui_schema(owui_db) +owui_Base = getattr(owui_db, "Base", None) if owui_db is not None else None +if owui_Base is None: + owui_Base = declarative_base() + + class ChatSummary(owui_Base): """对话摘要存储表""" __tablename__ = "chat_summary" - __table_args__ = {"extend_existing": True} + __table_args__ = ( + {"extend_existing": True, "schema": owui_schema} + if owui_schema + else {"extend_existing": True} + ) id = Column(Integer, primary_key=True, autoincrement=True) chat_id = Column(String(255), unique=True, nullable=False, index=True) @@ -300,15 +379,65 @@ class ChatSummary(owui_Base): class Filter: def __init__(self): self.valves = self.Valves() + self._owui_db = owui_db self._db_engine = owui_engine - self._SessionLocal = owui_Session - self._SessionLocal = owui_Session - self._init_database() + self._fallback_session_factory = ( + sessionmaker(bind=self._db_engine) if self._db_engine else None + ) self._init_database() + @contextlib.contextmanager + def _db_session(self): + """Yield a database session using Open WebUI helpers with graceful fallbacks.""" + db_module = self._owui_db + db_context = None + if db_module is not None: + db_context = getattr(db_module, "get_db_context", None) or getattr( + db_module, "get_db", None + ) + + if callable(db_context): + with db_context() as session: + yield session + return + + factory = None + if db_module is not None: + factory = getattr(db_module, "SessionLocal", None) or getattr( + db_module, "ScopedSession", None + ) + if callable(factory): + session = factory() + try: + yield session + finally: + close = getattr(session, "close", None) + if callable(close): + close() + return + + if self._fallback_session_factory is None: + raise RuntimeError( + "Open WebUI database session is unavailable. Ensure Open WebUI's database layer is initialized." + ) + + session = self._fallback_session_factory() + try: + yield session + finally: + try: + session.close() + except Exception as exc: # pragma: no cover - best-effort cleanup + print(f"[Database] ⚠️ Failed to close fallback session: {exc}") + def _init_database(self): """使用 Open WebUI 的共享连接初始化数据库表""" try: + if self._db_engine is None: + raise RuntimeError( + "Open WebUI database engine is unavailable. Ensure Open WebUI is configured with a valid DATABASE_URL." + ) + # 使用 SQLAlchemy inspect 检查表是否存在 inspector = inspect(self._db_engine) if not inspector.has_table("chat_summary"): @@ -376,7 +505,7 @@ class Filter: def _save_summary(self, chat_id: str, summary: str, compressed_count: int): """保存摘要到数据库""" try: - with self._SessionLocal() as session: + with self._db_session() as session: # 查找现有记录 existing = session.query(ChatSummary).filter_by(chat_id=chat_id).first() @@ -414,7 +543,7 @@ class Filter: def _load_summary_record(self, chat_id: str) -> Optional[ChatSummary]: """从数据库加载摘要记录对象""" try: - with self._SessionLocal() as session: + with self._db_session() as session: record = session.query(ChatSummary).filter_by(chat_id=chat_id).first() if record: # Detach the object from the session so it can be used after session close @@ -628,6 +757,8 @@ class Filter: body: dict, __user__: Optional[dict] = None, __metadata__: dict = None, + __request__: Request = None, + __model__: dict = None, __event_emitter__: Callable[[Any], Awaitable[None]] = None, __event_call__: Callable[[Any], Awaitable[None]] = None, ) -> dict: @@ -641,8 +772,10 @@ class Filter: messages = body.get("messages", []) # --- 原生工具输出裁剪 (Native Tool Output Trimming) --- - # 即使未启用压缩,也始终检查并裁剪过长的工具输出,以节省 Token - if self.valves.enable_tool_output_trimming: + metadata = body.get("metadata", {}) + is_native_func_calling = metadata.get("function_calling") == "native" + + if self.valves.enable_tool_output_trimming and is_native_func_calling: trimmed_count = 0 for msg in messages: content = msg.get("content", "") @@ -1207,6 +1340,13 @@ class Filter: """ chat_ctx = self._get_chat_context(body, __metadata__) chat_id = chat_ctx["chat_id"] + if not chat_id: + await self._log( + "[Outlet] ❌ metadata 中缺少 chat_id,跳过压缩", + type="error", + event_call=__event_call__, + ) + return body model = body.get("model") or "" # 直接计算目标压缩进度