fix(async-context-compression): sync CN version with EN version logic
- Add missing imports (contextlib, sessionmaker, Engine) - Add database engine discovery functions (_discover_owui_engine, _discover_owui_schema) - Fix ChatSummary table to support schema configuration - Fix duplicate code in __init__ method - Add _db_session context manager for robust session handling - Fix inlet method signature (add __request__, __model__ parameters) - Fix tool output trimming to check native function calling - Add chat_id empty check in outlet method
This commit is contained in:
@@ -254,23 +254,25 @@ show_debug_log (前端调试日志)
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
from typing import Optional, Dict, Any, List, Union, Callable, Awaitable
|
from typing import Optional, Dict, Any, List, Union, Callable, Awaitable
|
||||||
|
import re
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import hashlib
|
import hashlib
|
||||||
import time
|
import time
|
||||||
import re
|
import contextlib
|
||||||
|
|
||||||
# Open WebUI 内置导入
|
# Open WebUI 内置导入
|
||||||
from open_webui.utils.chat import generate_chat_completion
|
from open_webui.utils.chat import generate_chat_completion
|
||||||
from open_webui.models.models import Models
|
|
||||||
from open_webui.models.users import Users
|
from open_webui.models.users import Users
|
||||||
|
from open_webui.models.models import Models
|
||||||
from fastapi.requests import Request
|
from fastapi.requests import Request
|
||||||
from open_webui.main import app as webui_app
|
from open_webui.main import app as webui_app
|
||||||
|
|
||||||
# Open WebUI 内部数据库 (复用共享连接)
|
# Open WebUI 内部数据库 (复用共享连接)
|
||||||
from open_webui.internal.db import engine as owui_engine
|
try:
|
||||||
from open_webui.internal.db import Session as owui_Session
|
from open_webui.internal import db as owui_db
|
||||||
from open_webui.internal.db import Base as owui_Base
|
except ModuleNotFoundError: # pragma: no cover - filter runs inside Open WebUI
|
||||||
|
owui_db = None
|
||||||
|
|
||||||
# 尝试导入 tiktoken
|
# 尝试导入 tiktoken
|
||||||
try:
|
try:
|
||||||
@@ -280,14 +282,91 @@ except ImportError:
|
|||||||
|
|
||||||
# 数据库导入
|
# 数据库导入
|
||||||
from sqlalchemy import Column, String, Text, DateTime, Integer, inspect
|
from sqlalchemy import Column, String, Text, DateTime, Integer, inspect
|
||||||
|
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
def _discover_owui_engine(db_module: Any) -> Optional[Engine]:
|
||||||
|
"""Discover the Open WebUI SQLAlchemy engine via provided db module helpers."""
|
||||||
|
if db_module is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
db_context = getattr(db_module, "get_db_context", None) or getattr(
|
||||||
|
db_module, "get_db", None
|
||||||
|
)
|
||||||
|
if callable(db_context):
|
||||||
|
try:
|
||||||
|
with db_context() as session:
|
||||||
|
try:
|
||||||
|
return session.get_bind()
|
||||||
|
except AttributeError:
|
||||||
|
return getattr(session, "bind", None) or getattr(
|
||||||
|
session, "engine", None
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
print(f"[DB Discover] get_db_context failed: {exc}")
|
||||||
|
|
||||||
|
for attr in ("engine", "ENGINE", "bind", "BIND"):
|
||||||
|
candidate = getattr(db_module, attr, None)
|
||||||
|
if candidate is not None:
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _discover_owui_schema(db_module: Any) -> Optional[str]:
|
||||||
|
"""Discover the Open WebUI database schema name if configured."""
|
||||||
|
if db_module is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
base = getattr(db_module, "Base", None)
|
||||||
|
metadata = getattr(base, "metadata", None) if base is not None else None
|
||||||
|
candidate = getattr(metadata, "schema", None) if metadata is not None else None
|
||||||
|
if isinstance(candidate, str) and candidate.strip():
|
||||||
|
return candidate.strip()
|
||||||
|
except Exception as exc:
|
||||||
|
print(f"[DB Discover] Base metadata schema lookup failed: {exc}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
metadata_obj = getattr(db_module, "metadata_obj", None)
|
||||||
|
candidate = (
|
||||||
|
getattr(metadata_obj, "schema", None) if metadata_obj is not None else None
|
||||||
|
)
|
||||||
|
if isinstance(candidate, str) and candidate.strip():
|
||||||
|
return candidate.strip()
|
||||||
|
except Exception as exc:
|
||||||
|
print(f"[DB Discover] metadata_obj schema lookup failed: {exc}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from open_webui import env as owui_env
|
||||||
|
|
||||||
|
candidate = getattr(owui_env, "DATABASE_SCHEMA", None)
|
||||||
|
if isinstance(candidate, str) and candidate.strip():
|
||||||
|
return candidate.strip()
|
||||||
|
except Exception as exc:
|
||||||
|
print(f"[DB Discover] env schema lookup failed: {exc}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
owui_engine = _discover_owui_engine(owui_db)
|
||||||
|
owui_schema = _discover_owui_schema(owui_db)
|
||||||
|
owui_Base = getattr(owui_db, "Base", None) if owui_db is not None else None
|
||||||
|
if owui_Base is None:
|
||||||
|
owui_Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
class ChatSummary(owui_Base):
|
class ChatSummary(owui_Base):
|
||||||
"""对话摘要存储表"""
|
"""对话摘要存储表"""
|
||||||
|
|
||||||
__tablename__ = "chat_summary"
|
__tablename__ = "chat_summary"
|
||||||
__table_args__ = {"extend_existing": True}
|
__table_args__ = (
|
||||||
|
{"extend_existing": True, "schema": owui_schema}
|
||||||
|
if owui_schema
|
||||||
|
else {"extend_existing": True}
|
||||||
|
)
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
chat_id = Column(String(255), unique=True, nullable=False, index=True)
|
chat_id = Column(String(255), unique=True, nullable=False, index=True)
|
||||||
@@ -300,15 +379,65 @@ class ChatSummary(owui_Base):
|
|||||||
class Filter:
|
class Filter:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.valves = self.Valves()
|
self.valves = self.Valves()
|
||||||
|
self._owui_db = owui_db
|
||||||
self._db_engine = owui_engine
|
self._db_engine = owui_engine
|
||||||
self._SessionLocal = owui_Session
|
self._fallback_session_factory = (
|
||||||
self._SessionLocal = owui_Session
|
sessionmaker(bind=self._db_engine) if self._db_engine else None
|
||||||
self._init_database()
|
)
|
||||||
self._init_database()
|
self._init_database()
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def _db_session(self):
|
||||||
|
"""Yield a database session using Open WebUI helpers with graceful fallbacks."""
|
||||||
|
db_module = self._owui_db
|
||||||
|
db_context = None
|
||||||
|
if db_module is not None:
|
||||||
|
db_context = getattr(db_module, "get_db_context", None) or getattr(
|
||||||
|
db_module, "get_db", None
|
||||||
|
)
|
||||||
|
|
||||||
|
if callable(db_context):
|
||||||
|
with db_context() as session:
|
||||||
|
yield session
|
||||||
|
return
|
||||||
|
|
||||||
|
factory = None
|
||||||
|
if db_module is not None:
|
||||||
|
factory = getattr(db_module, "SessionLocal", None) or getattr(
|
||||||
|
db_module, "ScopedSession", None
|
||||||
|
)
|
||||||
|
if callable(factory):
|
||||||
|
session = factory()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
close = getattr(session, "close", None)
|
||||||
|
if callable(close):
|
||||||
|
close()
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._fallback_session_factory is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Open WebUI database session is unavailable. Ensure Open WebUI's database layer is initialized."
|
||||||
|
)
|
||||||
|
|
||||||
|
session = self._fallback_session_factory()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
session.close()
|
||||||
|
except Exception as exc: # pragma: no cover - best-effort cleanup
|
||||||
|
print(f"[Database] ⚠️ Failed to close fallback session: {exc}")
|
||||||
|
|
||||||
def _init_database(self):
|
def _init_database(self):
|
||||||
"""使用 Open WebUI 的共享连接初始化数据库表"""
|
"""使用 Open WebUI 的共享连接初始化数据库表"""
|
||||||
try:
|
try:
|
||||||
|
if self._db_engine is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Open WebUI database engine is unavailable. Ensure Open WebUI is configured with a valid DATABASE_URL."
|
||||||
|
)
|
||||||
|
|
||||||
# 使用 SQLAlchemy inspect 检查表是否存在
|
# 使用 SQLAlchemy inspect 检查表是否存在
|
||||||
inspector = inspect(self._db_engine)
|
inspector = inspect(self._db_engine)
|
||||||
if not inspector.has_table("chat_summary"):
|
if not inspector.has_table("chat_summary"):
|
||||||
@@ -376,7 +505,7 @@ class Filter:
|
|||||||
def _save_summary(self, chat_id: str, summary: str, compressed_count: int):
|
def _save_summary(self, chat_id: str, summary: str, compressed_count: int):
|
||||||
"""保存摘要到数据库"""
|
"""保存摘要到数据库"""
|
||||||
try:
|
try:
|
||||||
with self._SessionLocal() as session:
|
with self._db_session() as session:
|
||||||
# 查找现有记录
|
# 查找现有记录
|
||||||
existing = session.query(ChatSummary).filter_by(chat_id=chat_id).first()
|
existing = session.query(ChatSummary).filter_by(chat_id=chat_id).first()
|
||||||
|
|
||||||
@@ -414,7 +543,7 @@ class Filter:
|
|||||||
def _load_summary_record(self, chat_id: str) -> Optional[ChatSummary]:
|
def _load_summary_record(self, chat_id: str) -> Optional[ChatSummary]:
|
||||||
"""从数据库加载摘要记录对象"""
|
"""从数据库加载摘要记录对象"""
|
||||||
try:
|
try:
|
||||||
with self._SessionLocal() as session:
|
with self._db_session() as session:
|
||||||
record = session.query(ChatSummary).filter_by(chat_id=chat_id).first()
|
record = session.query(ChatSummary).filter_by(chat_id=chat_id).first()
|
||||||
if record:
|
if record:
|
||||||
# Detach the object from the session so it can be used after session close
|
# Detach the object from the session so it can be used after session close
|
||||||
@@ -628,6 +757,8 @@ class Filter:
|
|||||||
body: dict,
|
body: dict,
|
||||||
__user__: Optional[dict] = None,
|
__user__: Optional[dict] = None,
|
||||||
__metadata__: dict = None,
|
__metadata__: dict = None,
|
||||||
|
__request__: Request = None,
|
||||||
|
__model__: dict = None,
|
||||||
__event_emitter__: Callable[[Any], Awaitable[None]] = None,
|
__event_emitter__: Callable[[Any], Awaitable[None]] = None,
|
||||||
__event_call__: Callable[[Any], Awaitable[None]] = None,
|
__event_call__: Callable[[Any], Awaitable[None]] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
@@ -641,8 +772,10 @@ class Filter:
|
|||||||
messages = body.get("messages", [])
|
messages = body.get("messages", [])
|
||||||
|
|
||||||
# --- 原生工具输出裁剪 (Native Tool Output Trimming) ---
|
# --- 原生工具输出裁剪 (Native Tool Output Trimming) ---
|
||||||
# 即使未启用压缩,也始终检查并裁剪过长的工具输出,以节省 Token
|
metadata = body.get("metadata", {})
|
||||||
if self.valves.enable_tool_output_trimming:
|
is_native_func_calling = metadata.get("function_calling") == "native"
|
||||||
|
|
||||||
|
if self.valves.enable_tool_output_trimming and is_native_func_calling:
|
||||||
trimmed_count = 0
|
trimmed_count = 0
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
content = msg.get("content", "")
|
content = msg.get("content", "")
|
||||||
@@ -1207,6 +1340,13 @@ class Filter:
|
|||||||
"""
|
"""
|
||||||
chat_ctx = self._get_chat_context(body, __metadata__)
|
chat_ctx = self._get_chat_context(body, __metadata__)
|
||||||
chat_id = chat_ctx["chat_id"]
|
chat_id = chat_ctx["chat_id"]
|
||||||
|
if not chat_id:
|
||||||
|
await self._log(
|
||||||
|
"[Outlet] ❌ metadata 中缺少 chat_id,跳过压缩",
|
||||||
|
type="error",
|
||||||
|
event_call=__event_call__,
|
||||||
|
)
|
||||||
|
return body
|
||||||
model = body.get("model") or ""
|
model = body.get("model") or ""
|
||||||
|
|
||||||
# 直接计算目标压缩进度
|
# 直接计算目标压缩进度
|
||||||
|
|||||||
Reference in New Issue
Block a user