fix(async-context-compression): sync CN version with EN version logic
- Add missing imports (contextlib, sessionmaker, Engine) - Add database engine discovery functions (_discover_owui_engine, _discover_owui_schema) - Fix ChatSummary table to support schema configuration - Fix duplicate code in __init__ method - Add _db_session context manager for robust session handling - Fix inlet method signature (add __request__, __model__ parameters) - Fix tool output trimming to check native function calling - Add chat_id empty check in outlet method
This commit is contained in:
@@ -254,23 +254,25 @@ show_debug_log (前端调试日志)
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing import Optional, Dict, Any, List, Union, Callable, Awaitable
|
||||
import re
|
||||
import asyncio
|
||||
import json
|
||||
import hashlib
|
||||
import time
|
||||
import re
|
||||
import contextlib
|
||||
|
||||
# Open WebUI 内置导入
|
||||
from open_webui.utils.chat import generate_chat_completion
|
||||
from open_webui.models.models import Models
|
||||
from open_webui.models.users import Users
|
||||
from open_webui.models.models import Models
|
||||
from fastapi.requests import Request
|
||||
from open_webui.main import app as webui_app
|
||||
|
||||
# Open WebUI 内部数据库 (复用共享连接)
|
||||
from open_webui.internal.db import engine as owui_engine
|
||||
from open_webui.internal.db import Session as owui_Session
|
||||
from open_webui.internal.db import Base as owui_Base
|
||||
try:
|
||||
from open_webui.internal import db as owui_db
|
||||
except ModuleNotFoundError: # pragma: no cover - filter runs inside Open WebUI
|
||||
owui_db = None
|
||||
|
||||
# 尝试导入 tiktoken
|
||||
try:
|
||||
@@ -280,14 +282,91 @@ except ImportError:
|
||||
|
||||
# 数据库导入
|
||||
from sqlalchemy import Column, String, Text, DateTime, Integer, inspect
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
from sqlalchemy.engine import Engine
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def _discover_owui_engine(db_module: Any) -> Optional[Engine]:
|
||||
"""Discover the Open WebUI SQLAlchemy engine via provided db module helpers."""
|
||||
if db_module is None:
|
||||
return None
|
||||
|
||||
db_context = getattr(db_module, "get_db_context", None) or getattr(
|
||||
db_module, "get_db", None
|
||||
)
|
||||
if callable(db_context):
|
||||
try:
|
||||
with db_context() as session:
|
||||
try:
|
||||
return session.get_bind()
|
||||
except AttributeError:
|
||||
return getattr(session, "bind", None) or getattr(
|
||||
session, "engine", None
|
||||
)
|
||||
except Exception as exc:
|
||||
print(f"[DB Discover] get_db_context failed: {exc}")
|
||||
|
||||
for attr in ("engine", "ENGINE", "bind", "BIND"):
|
||||
candidate = getattr(db_module, attr, None)
|
||||
if candidate is not None:
|
||||
return candidate
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _discover_owui_schema(db_module: Any) -> Optional[str]:
|
||||
"""Discover the Open WebUI database schema name if configured."""
|
||||
if db_module is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
base = getattr(db_module, "Base", None)
|
||||
metadata = getattr(base, "metadata", None) if base is not None else None
|
||||
candidate = getattr(metadata, "schema", None) if metadata is not None else None
|
||||
if isinstance(candidate, str) and candidate.strip():
|
||||
return candidate.strip()
|
||||
except Exception as exc:
|
||||
print(f"[DB Discover] Base metadata schema lookup failed: {exc}")
|
||||
|
||||
try:
|
||||
metadata_obj = getattr(db_module, "metadata_obj", None)
|
||||
candidate = (
|
||||
getattr(metadata_obj, "schema", None) if metadata_obj is not None else None
|
||||
)
|
||||
if isinstance(candidate, str) and candidate.strip():
|
||||
return candidate.strip()
|
||||
except Exception as exc:
|
||||
print(f"[DB Discover] metadata_obj schema lookup failed: {exc}")
|
||||
|
||||
try:
|
||||
from open_webui import env as owui_env
|
||||
|
||||
candidate = getattr(owui_env, "DATABASE_SCHEMA", None)
|
||||
if isinstance(candidate, str) and candidate.strip():
|
||||
return candidate.strip()
|
||||
except Exception as exc:
|
||||
print(f"[DB Discover] env schema lookup failed: {exc}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
owui_engine = _discover_owui_engine(owui_db)
|
||||
owui_schema = _discover_owui_schema(owui_db)
|
||||
owui_Base = getattr(owui_db, "Base", None) if owui_db is not None else None
|
||||
if owui_Base is None:
|
||||
owui_Base = declarative_base()
|
||||
|
||||
|
||||
class ChatSummary(owui_Base):
|
||||
"""对话摘要存储表"""
|
||||
|
||||
__tablename__ = "chat_summary"
|
||||
__table_args__ = {"extend_existing": True}
|
||||
__table_args__ = (
|
||||
{"extend_existing": True, "schema": owui_schema}
|
||||
if owui_schema
|
||||
else {"extend_existing": True}
|
||||
)
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
chat_id = Column(String(255), unique=True, nullable=False, index=True)
|
||||
@@ -300,15 +379,65 @@ class ChatSummary(owui_Base):
|
||||
class Filter:
|
||||
def __init__(self):
|
||||
self.valves = self.Valves()
|
||||
self._owui_db = owui_db
|
||||
self._db_engine = owui_engine
|
||||
self._SessionLocal = owui_Session
|
||||
self._SessionLocal = owui_Session
|
||||
self._init_database()
|
||||
self._fallback_session_factory = (
|
||||
sessionmaker(bind=self._db_engine) if self._db_engine else None
|
||||
)
|
||||
self._init_database()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _db_session(self):
|
||||
"""Yield a database session using Open WebUI helpers with graceful fallbacks."""
|
||||
db_module = self._owui_db
|
||||
db_context = None
|
||||
if db_module is not None:
|
||||
db_context = getattr(db_module, "get_db_context", None) or getattr(
|
||||
db_module, "get_db", None
|
||||
)
|
||||
|
||||
if callable(db_context):
|
||||
with db_context() as session:
|
||||
yield session
|
||||
return
|
||||
|
||||
factory = None
|
||||
if db_module is not None:
|
||||
factory = getattr(db_module, "SessionLocal", None) or getattr(
|
||||
db_module, "ScopedSession", None
|
||||
)
|
||||
if callable(factory):
|
||||
session = factory()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
close = getattr(session, "close", None)
|
||||
if callable(close):
|
||||
close()
|
||||
return
|
||||
|
||||
if self._fallback_session_factory is None:
|
||||
raise RuntimeError(
|
||||
"Open WebUI database session is unavailable. Ensure Open WebUI's database layer is initialized."
|
||||
)
|
||||
|
||||
session = self._fallback_session_factory()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
try:
|
||||
session.close()
|
||||
except Exception as exc: # pragma: no cover - best-effort cleanup
|
||||
print(f"[Database] ⚠️ Failed to close fallback session: {exc}")
|
||||
|
||||
def _init_database(self):
|
||||
"""使用 Open WebUI 的共享连接初始化数据库表"""
|
||||
try:
|
||||
if self._db_engine is None:
|
||||
raise RuntimeError(
|
||||
"Open WebUI database engine is unavailable. Ensure Open WebUI is configured with a valid DATABASE_URL."
|
||||
)
|
||||
|
||||
# 使用 SQLAlchemy inspect 检查表是否存在
|
||||
inspector = inspect(self._db_engine)
|
||||
if not inspector.has_table("chat_summary"):
|
||||
@@ -376,7 +505,7 @@ class Filter:
|
||||
def _save_summary(self, chat_id: str, summary: str, compressed_count: int):
|
||||
"""保存摘要到数据库"""
|
||||
try:
|
||||
with self._SessionLocal() as session:
|
||||
with self._db_session() as session:
|
||||
# 查找现有记录
|
||||
existing = session.query(ChatSummary).filter_by(chat_id=chat_id).first()
|
||||
|
||||
@@ -414,7 +543,7 @@ class Filter:
|
||||
def _load_summary_record(self, chat_id: str) -> Optional[ChatSummary]:
|
||||
"""从数据库加载摘要记录对象"""
|
||||
try:
|
||||
with self._SessionLocal() as session:
|
||||
with self._db_session() as session:
|
||||
record = session.query(ChatSummary).filter_by(chat_id=chat_id).first()
|
||||
if record:
|
||||
# Detach the object from the session so it can be used after session close
|
||||
@@ -628,6 +757,8 @@ class Filter:
|
||||
body: dict,
|
||||
__user__: Optional[dict] = None,
|
||||
__metadata__: dict = None,
|
||||
__request__: Request = None,
|
||||
__model__: dict = None,
|
||||
__event_emitter__: Callable[[Any], Awaitable[None]] = None,
|
||||
__event_call__: Callable[[Any], Awaitable[None]] = None,
|
||||
) -> dict:
|
||||
@@ -641,8 +772,10 @@ class Filter:
|
||||
messages = body.get("messages", [])
|
||||
|
||||
# --- 原生工具输出裁剪 (Native Tool Output Trimming) ---
|
||||
# 即使未启用压缩,也始终检查并裁剪过长的工具输出,以节省 Token
|
||||
if self.valves.enable_tool_output_trimming:
|
||||
metadata = body.get("metadata", {})
|
||||
is_native_func_calling = metadata.get("function_calling") == "native"
|
||||
|
||||
if self.valves.enable_tool_output_trimming and is_native_func_calling:
|
||||
trimmed_count = 0
|
||||
for msg in messages:
|
||||
content = msg.get("content", "")
|
||||
@@ -1207,6 +1340,13 @@ class Filter:
|
||||
"""
|
||||
chat_ctx = self._get_chat_context(body, __metadata__)
|
||||
chat_id = chat_ctx["chat_id"]
|
||||
if not chat_id:
|
||||
await self._log(
|
||||
"[Outlet] ❌ metadata 中缺少 chat_id,跳过压缩",
|
||||
type="error",
|
||||
event_call=__event_call__,
|
||||
)
|
||||
return body
|
||||
model = body.get("model") or ""
|
||||
|
||||
# 直接计算目标压缩进度
|
||||
|
||||
Reference in New Issue
Block a user