fix(async-context-compression): sync CN version with EN version logic

- Add missing imports (contextlib, sessionmaker, Engine)
- Add database engine discovery functions (_discover_owui_engine, _discover_owui_schema)
- Fix ChatSummary table to support schema configuration
- Fix duplicate code in __init__ method
- Add _db_session context manager for robust session handling
- Fix inlet method signature (add __request__, __model__ parameters)
- Fix tool output trimming to check native function calling
- Add chat_id empty check in outlet method
This commit is contained in:
fujie
2026-01-19 20:37:37 +08:00
parent e7de80a059
commit db1a1e7ef0

View File

@@ -254,23 +254,25 @@ show_debug_log (前端调试日志)
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
from typing import Optional, Dict, Any, List, Union, Callable, Awaitable from typing import Optional, Dict, Any, List, Union, Callable, Awaitable
import re
import asyncio import asyncio
import json import json
import hashlib import hashlib
import time import time
import re import contextlib
# Open WebUI 内置导入 # Open WebUI 内置导入
from open_webui.utils.chat import generate_chat_completion from open_webui.utils.chat import generate_chat_completion
from open_webui.models.models import Models
from open_webui.models.users import Users from open_webui.models.users import Users
from open_webui.models.models import Models
from fastapi.requests import Request from fastapi.requests import Request
from open_webui.main import app as webui_app from open_webui.main import app as webui_app
# Open WebUI 内部数据库 (复用共享连接) # Open WebUI 内部数据库 (复用共享连接)
from open_webui.internal.db import engine as owui_engine try:
from open_webui.internal.db import Session as owui_Session from open_webui.internal import db as owui_db
from open_webui.internal.db import Base as owui_Base except ModuleNotFoundError: # pragma: no cover - filter runs inside Open WebUI
owui_db = None
# 尝试导入 tiktoken # 尝试导入 tiktoken
try: try:
@@ -280,14 +282,91 @@ except ImportError:
# 数据库导入 # 数据库导入
from sqlalchemy import Column, String, Text, DateTime, Integer, inspect from sqlalchemy import Column, String, Text, DateTime, Integer, inspect
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.engine import Engine
from datetime import datetime from datetime import datetime
def _discover_owui_engine(db_module: Any) -> Optional[Engine]:
"""Discover the Open WebUI SQLAlchemy engine via provided db module helpers."""
if db_module is None:
return None
db_context = getattr(db_module, "get_db_context", None) or getattr(
db_module, "get_db", None
)
if callable(db_context):
try:
with db_context() as session:
try:
return session.get_bind()
except AttributeError:
return getattr(session, "bind", None) or getattr(
session, "engine", None
)
except Exception as exc:
print(f"[DB Discover] get_db_context failed: {exc}")
for attr in ("engine", "ENGINE", "bind", "BIND"):
candidate = getattr(db_module, attr, None)
if candidate is not None:
return candidate
return None
def _discover_owui_schema(db_module: Any) -> Optional[str]:
"""Discover the Open WebUI database schema name if configured."""
if db_module is None:
return None
try:
base = getattr(db_module, "Base", None)
metadata = getattr(base, "metadata", None) if base is not None else None
candidate = getattr(metadata, "schema", None) if metadata is not None else None
if isinstance(candidate, str) and candidate.strip():
return candidate.strip()
except Exception as exc:
print(f"[DB Discover] Base metadata schema lookup failed: {exc}")
try:
metadata_obj = getattr(db_module, "metadata_obj", None)
candidate = (
getattr(metadata_obj, "schema", None) if metadata_obj is not None else None
)
if isinstance(candidate, str) and candidate.strip():
return candidate.strip()
except Exception as exc:
print(f"[DB Discover] metadata_obj schema lookup failed: {exc}")
try:
from open_webui import env as owui_env
candidate = getattr(owui_env, "DATABASE_SCHEMA", None)
if isinstance(candidate, str) and candidate.strip():
return candidate.strip()
except Exception as exc:
print(f"[DB Discover] env schema lookup failed: {exc}")
return None
owui_engine = _discover_owui_engine(owui_db)
owui_schema = _discover_owui_schema(owui_db)
owui_Base = getattr(owui_db, "Base", None) if owui_db is not None else None
if owui_Base is None:
owui_Base = declarative_base()
class ChatSummary(owui_Base): class ChatSummary(owui_Base):
"""对话摘要存储表""" """对话摘要存储表"""
__tablename__ = "chat_summary" __tablename__ = "chat_summary"
__table_args__ = {"extend_existing": True} __table_args__ = (
{"extend_existing": True, "schema": owui_schema}
if owui_schema
else {"extend_existing": True}
)
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
chat_id = Column(String(255), unique=True, nullable=False, index=True) chat_id = Column(String(255), unique=True, nullable=False, index=True)
@@ -300,15 +379,65 @@ class ChatSummary(owui_Base):
class Filter: class Filter:
def __init__(self): def __init__(self):
self.valves = self.Valves() self.valves = self.Valves()
self._owui_db = owui_db
self._db_engine = owui_engine self._db_engine = owui_engine
self._SessionLocal = owui_Session self._fallback_session_factory = (
self._SessionLocal = owui_Session sessionmaker(bind=self._db_engine) if self._db_engine else None
self._init_database() )
self._init_database() self._init_database()
@contextlib.contextmanager
def _db_session(self):
"""Yield a database session using Open WebUI helpers with graceful fallbacks."""
db_module = self._owui_db
db_context = None
if db_module is not None:
db_context = getattr(db_module, "get_db_context", None) or getattr(
db_module, "get_db", None
)
if callable(db_context):
with db_context() as session:
yield session
return
factory = None
if db_module is not None:
factory = getattr(db_module, "SessionLocal", None) or getattr(
db_module, "ScopedSession", None
)
if callable(factory):
session = factory()
try:
yield session
finally:
close = getattr(session, "close", None)
if callable(close):
close()
return
if self._fallback_session_factory is None:
raise RuntimeError(
"Open WebUI database session is unavailable. Ensure Open WebUI's database layer is initialized."
)
session = self._fallback_session_factory()
try:
yield session
finally:
try:
session.close()
except Exception as exc: # pragma: no cover - best-effort cleanup
print(f"[Database] ⚠️ Failed to close fallback session: {exc}")
def _init_database(self): def _init_database(self):
"""使用 Open WebUI 的共享连接初始化数据库表""" """使用 Open WebUI 的共享连接初始化数据库表"""
try: try:
if self._db_engine is None:
raise RuntimeError(
"Open WebUI database engine is unavailable. Ensure Open WebUI is configured with a valid DATABASE_URL."
)
# 使用 SQLAlchemy inspect 检查表是否存在 # 使用 SQLAlchemy inspect 检查表是否存在
inspector = inspect(self._db_engine) inspector = inspect(self._db_engine)
if not inspector.has_table("chat_summary"): if not inspector.has_table("chat_summary"):
@@ -376,7 +505,7 @@ class Filter:
def _save_summary(self, chat_id: str, summary: str, compressed_count: int): def _save_summary(self, chat_id: str, summary: str, compressed_count: int):
"""保存摘要到数据库""" """保存摘要到数据库"""
try: try:
with self._SessionLocal() as session: with self._db_session() as session:
# 查找现有记录 # 查找现有记录
existing = session.query(ChatSummary).filter_by(chat_id=chat_id).first() existing = session.query(ChatSummary).filter_by(chat_id=chat_id).first()
@@ -414,7 +543,7 @@ class Filter:
def _load_summary_record(self, chat_id: str) -> Optional[ChatSummary]: def _load_summary_record(self, chat_id: str) -> Optional[ChatSummary]:
"""从数据库加载摘要记录对象""" """从数据库加载摘要记录对象"""
try: try:
with self._SessionLocal() as session: with self._db_session() as session:
record = session.query(ChatSummary).filter_by(chat_id=chat_id).first() record = session.query(ChatSummary).filter_by(chat_id=chat_id).first()
if record: if record:
# Detach the object from the session so it can be used after session close # Detach the object from the session so it can be used after session close
@@ -628,6 +757,8 @@ class Filter:
body: dict, body: dict,
__user__: Optional[dict] = None, __user__: Optional[dict] = None,
__metadata__: dict = None, __metadata__: dict = None,
__request__: Request = None,
__model__: dict = None,
__event_emitter__: Callable[[Any], Awaitable[None]] = None, __event_emitter__: Callable[[Any], Awaitable[None]] = None,
__event_call__: Callable[[Any], Awaitable[None]] = None, __event_call__: Callable[[Any], Awaitable[None]] = None,
) -> dict: ) -> dict:
@@ -641,8 +772,10 @@ class Filter:
messages = body.get("messages", []) messages = body.get("messages", [])
# --- 原生工具输出裁剪 (Native Tool Output Trimming) --- # --- 原生工具输出裁剪 (Native Tool Output Trimming) ---
# 即使未启用压缩,也始终检查并裁剪过长的工具输出,以节省 Token metadata = body.get("metadata", {})
if self.valves.enable_tool_output_trimming: is_native_func_calling = metadata.get("function_calling") == "native"
if self.valves.enable_tool_output_trimming and is_native_func_calling:
trimmed_count = 0 trimmed_count = 0
for msg in messages: for msg in messages:
content = msg.get("content", "") content = msg.get("content", "")
@@ -1207,6 +1340,13 @@ class Filter:
""" """
chat_ctx = self._get_chat_context(body, __metadata__) chat_ctx = self._get_chat_context(body, __metadata__)
chat_id = chat_ctx["chat_id"] chat_id = chat_ctx["chat_id"]
if not chat_id:
await self._log(
"[Outlet] ❌ metadata 中缺少 chat_id跳过压缩",
type="error",
event_call=__event_call__,
)
return body
model = body.get("model") or "" model = body.get("model") or ""
# 直接计算目标压缩进度 # 直接计算目标压缩进度