fix(async-context-compression): sync CN version with EN version logic

- Add missing imports (contextlib, sessionmaker, Engine)
- Add database engine discovery functions (_discover_owui_engine, _discover_owui_schema)
- Fix ChatSummary table to support schema configuration
- Fix duplicate code in __init__ method
- Add _db_session context manager for robust session handling
- Fix inlet method signature (add __request__, __model__ parameters)
- Fix tool output trimming to check native function calling
- Add chat_id empty check in outlet method
This commit is contained in:
fujie
2026-01-19 20:37:37 +08:00
parent e7de80a059
commit db1a1e7ef0

View File

@@ -254,23 +254,25 @@ show_debug_log (前端调试日志)
from pydantic import BaseModel, Field, model_validator
from typing import Optional, Dict, Any, List, Union, Callable, Awaitable
import re
import asyncio
import json
import hashlib
import time
import re
import contextlib
# Open WebUI 内置导入
from open_webui.utils.chat import generate_chat_completion
from open_webui.models.models import Models
from open_webui.models.users import Users
from open_webui.models.models import Models
from fastapi.requests import Request
from open_webui.main import app as webui_app
# Open WebUI 内部数据库 (复用共享连接)
from open_webui.internal.db import engine as owui_engine
from open_webui.internal.db import Session as owui_Session
from open_webui.internal.db import Base as owui_Base
try:
from open_webui.internal import db as owui_db
except ModuleNotFoundError: # pragma: no cover - filter runs inside Open WebUI
owui_db = None
# 尝试导入 tiktoken
try:
@@ -280,14 +282,91 @@ except ImportError:
# 数据库导入
from sqlalchemy import Column, String, Text, DateTime, Integer, inspect
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.engine import Engine
from datetime import datetime
def _discover_owui_engine(db_module: Any) -> Optional[Engine]:
"""Discover the Open WebUI SQLAlchemy engine via provided db module helpers."""
if db_module is None:
return None
db_context = getattr(db_module, "get_db_context", None) or getattr(
db_module, "get_db", None
)
if callable(db_context):
try:
with db_context() as session:
try:
return session.get_bind()
except AttributeError:
return getattr(session, "bind", None) or getattr(
session, "engine", None
)
except Exception as exc:
print(f"[DB Discover] get_db_context failed: {exc}")
for attr in ("engine", "ENGINE", "bind", "BIND"):
candidate = getattr(db_module, attr, None)
if candidate is not None:
return candidate
return None
def _discover_owui_schema(db_module: Any) -> Optional[str]:
"""Discover the Open WebUI database schema name if configured."""
if db_module is None:
return None
try:
base = getattr(db_module, "Base", None)
metadata = getattr(base, "metadata", None) if base is not None else None
candidate = getattr(metadata, "schema", None) if metadata is not None else None
if isinstance(candidate, str) and candidate.strip():
return candidate.strip()
except Exception as exc:
print(f"[DB Discover] Base metadata schema lookup failed: {exc}")
try:
metadata_obj = getattr(db_module, "metadata_obj", None)
candidate = (
getattr(metadata_obj, "schema", None) if metadata_obj is not None else None
)
if isinstance(candidate, str) and candidate.strip():
return candidate.strip()
except Exception as exc:
print(f"[DB Discover] metadata_obj schema lookup failed: {exc}")
try:
from open_webui import env as owui_env
candidate = getattr(owui_env, "DATABASE_SCHEMA", None)
if isinstance(candidate, str) and candidate.strip():
return candidate.strip()
except Exception as exc:
print(f"[DB Discover] env schema lookup failed: {exc}")
return None
owui_engine = _discover_owui_engine(owui_db)
owui_schema = _discover_owui_schema(owui_db)
owui_Base = getattr(owui_db, "Base", None) if owui_db is not None else None
if owui_Base is None:
owui_Base = declarative_base()
class ChatSummary(owui_Base):
"""对话摘要存储表"""
__tablename__ = "chat_summary"
__table_args__ = {"extend_existing": True}
__table_args__ = (
{"extend_existing": True, "schema": owui_schema}
if owui_schema
else {"extend_existing": True}
)
id = Column(Integer, primary_key=True, autoincrement=True)
chat_id = Column(String(255), unique=True, nullable=False, index=True)
@@ -300,15 +379,65 @@ class ChatSummary(owui_Base):
class Filter:
def __init__(self):
self.valves = self.Valves()
self._owui_db = owui_db
self._db_engine = owui_engine
self._SessionLocal = owui_Session
self._SessionLocal = owui_Session
self._init_database()
self._fallback_session_factory = (
sessionmaker(bind=self._db_engine) if self._db_engine else None
)
self._init_database()
@contextlib.contextmanager
def _db_session(self):
"""Yield a database session using Open WebUI helpers with graceful fallbacks."""
db_module = self._owui_db
db_context = None
if db_module is not None:
db_context = getattr(db_module, "get_db_context", None) or getattr(
db_module, "get_db", None
)
if callable(db_context):
with db_context() as session:
yield session
return
factory = None
if db_module is not None:
factory = getattr(db_module, "SessionLocal", None) or getattr(
db_module, "ScopedSession", None
)
if callable(factory):
session = factory()
try:
yield session
finally:
close = getattr(session, "close", None)
if callable(close):
close()
return
if self._fallback_session_factory is None:
raise RuntimeError(
"Open WebUI database session is unavailable. Ensure Open WebUI's database layer is initialized."
)
session = self._fallback_session_factory()
try:
yield session
finally:
try:
session.close()
except Exception as exc: # pragma: no cover - best-effort cleanup
print(f"[Database] ⚠️ Failed to close fallback session: {exc}")
def _init_database(self):
"""使用 Open WebUI 的共享连接初始化数据库表"""
try:
if self._db_engine is None:
raise RuntimeError(
"Open WebUI database engine is unavailable. Ensure Open WebUI is configured with a valid DATABASE_URL."
)
# 使用 SQLAlchemy inspect 检查表是否存在
inspector = inspect(self._db_engine)
if not inspector.has_table("chat_summary"):
@@ -376,7 +505,7 @@ class Filter:
def _save_summary(self, chat_id: str, summary: str, compressed_count: int):
"""保存摘要到数据库"""
try:
with self._SessionLocal() as session:
with self._db_session() as session:
# 查找现有记录
existing = session.query(ChatSummary).filter_by(chat_id=chat_id).first()
@@ -414,7 +543,7 @@ class Filter:
def _load_summary_record(self, chat_id: str) -> Optional[ChatSummary]:
"""从数据库加载摘要记录对象"""
try:
with self._SessionLocal() as session:
with self._db_session() as session:
record = session.query(ChatSummary).filter_by(chat_id=chat_id).first()
if record:
# Detach the object from the session so it can be used after session close
@@ -628,6 +757,8 @@ class Filter:
body: dict,
__user__: Optional[dict] = None,
__metadata__: dict = None,
__request__: Request = None,
__model__: dict = None,
__event_emitter__: Callable[[Any], Awaitable[None]] = None,
__event_call__: Callable[[Any], Awaitable[None]] = None,
) -> dict:
@@ -641,8 +772,10 @@ class Filter:
messages = body.get("messages", [])
# --- 原生工具输出裁剪 (Native Tool Output Trimming) ---
# 即使未启用压缩,也始终检查并裁剪过长的工具输出,以节省 Token
if self.valves.enable_tool_output_trimming:
metadata = body.get("metadata", {})
is_native_func_calling = metadata.get("function_calling") == "native"
if self.valves.enable_tool_output_trimming and is_native_func_calling:
trimmed_count = 0
for msg in messages:
content = msg.get("content", "")
@@ -1207,6 +1340,13 @@ class Filter:
"""
chat_ctx = self._get_chat_context(body, __metadata__)
chat_id = chat_ctx["chat_id"]
if not chat_id:
await self._log(
"[Outlet] ❌ metadata 中缺少 chat_id跳过压缩",
type="error",
event_call=__event_call__,
)
return body
model = body.get("model") or ""
# 直接计算目标压缩进度