fix: make async compression db session discovery robust

Co-authored-by: Fu-Jie <33599649+Fu-Jie@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2026-01-11 08:27:36 +00:00
parent 4b8515f682
commit 9e98d55e11

View File

@@ -249,6 +249,7 @@ import asyncio
import json
import hashlib
import time
import contextlib
# Open WebUI built-in imports
from open_webui.utils.chat import generate_chat_completion
@@ -257,7 +258,10 @@ from fastapi.requests import Request
from open_webui.main import app as webui_app
# Open WebUI internal database (re-use shared connection)
import open_webui.internal.db as owui_db
try:
from open_webui.internal import db as owui_db
except Exception: # pragma: no cover - filter runs inside Open WebUI
owui_db = None
# Try to import tiktoken
try:
@@ -267,17 +271,88 @@ except ImportError:
# Database imports
from sqlalchemy import Column, String, Text, DateTime, Integer, inspect
from sqlalchemy.orm import declarative_base, sessionmaker
from datetime import datetime
owui_Base = owui_db.Base
def _discover_owui_engine(db_module) -> Any | None:
if db_module is None:
return None
db_context = getattr(db_module, "get_db_context", None) or getattr(
db_module, "get_db", None
)
if callable(db_context):
try:
with db_context() as session:
try:
return session.get_bind()
except Exception:
return getattr(session, "bind", None) or getattr(
session, "engine", None
)
except Exception:
pass
for attr in ("engine", "ENGINE", "bind", "BIND"):
candidate = getattr(db_module, attr, None)
if candidate is not None:
return candidate
return None
def _discover_owui_schema(db_module) -> str | None:
if db_module is None:
return None
try:
base = getattr(db_module, "Base", None)
metadata = getattr(base, "metadata", None) if base is not None else None
candidate = getattr(metadata, "schema", None) if metadata is not None else None
if isinstance(candidate, str) and candidate.strip():
return candidate.strip()
except Exception:
pass
try:
metadata_obj = getattr(db_module, "metadata_obj", None)
candidate = (
getattr(metadata_obj, "schema", None) if metadata_obj is not None else None
)
if isinstance(candidate, str) and candidate.strip():
return candidate.strip()
except Exception:
pass
try:
from open_webui import env as owui_env
candidate = getattr(owui_env, "DATABASE_SCHEMA", None)
if isinstance(candidate, str) and candidate.strip():
return candidate.strip()
except Exception:
pass
return None
owui_engine = _discover_owui_engine(owui_db)
owui_schema = _discover_owui_schema(owui_db)
owui_Base = getattr(owui_db, "Base", None) if owui_db is not None else None
if owui_Base is None:
owui_Base = declarative_base()
class ChatSummary(owui_Base):
"""Chat Summary Storage Table"""
__tablename__ = "chat_summary"
__table_args__ = {"extend_existing": True}
__table_args__ = (
{"extend_existing": True, "schema": owui_schema}
if owui_schema
else {"extend_existing": True}
)
id = Column(Integer, primary_key=True, autoincrement=True)
chat_id = Column(String(255), unique=True, nullable=False, index=True)
@@ -290,20 +365,59 @@ class ChatSummary(owui_Base):
class Filter:
def __init__(self):
self.valves = self.Valves()
self._db_engine = owui_db.engine
self._SessionLocal = (
getattr(owui_db, "ScopedSession", None)
or getattr(owui_db, "SessionLocal", None)
or getattr(owui_db, "Session", None)
)
if self._SessionLocal is None:
raise RuntimeError("Open WebUI database session factory unavailable.")
self._owui_db = owui_db
self._db_engine = owui_engine
self.temp_state = {} # Used to pass temporary data between inlet and outlet
self._fallback_session_factory = (
sessionmaker(bind=self._db_engine) if self._db_engine else None
)
self._init_database()
@contextlib.contextmanager
def _db_session(self):
db_module = self._owui_db
db_context = None
if db_module is not None:
db_context = getattr(db_module, "get_db_context", None) or getattr(
db_module, "get_db", None
)
if callable(db_context):
with db_context() as session:
yield session
return
factory = None
if db_module is not None:
factory = getattr(db_module, "SessionLocal", None) or getattr(
db_module, "ScopedSession", None
)
if callable(factory):
session = factory()
try:
yield session
finally:
close = getattr(session, "close", None)
if callable(close):
close()
return
if self._fallback_session_factory is None:
raise RuntimeError("Open WebUI database session is unavailable.")
session = self._fallback_session_factory()
try:
yield session
finally:
with contextlib.suppress(Exception):
session.close()
def _init_database(self):
"""Initializes the database table using Open WebUI's shared connection."""
try:
if self._db_engine is None:
raise RuntimeError("Open WebUI database engine is unavailable.")
# Check if table exists using SQLAlchemy inspect
inspector = inspect(self._db_engine)
if not inspector.has_table("chat_summary"):
@@ -373,7 +487,7 @@ class Filter:
def _save_summary(self, chat_id: str, summary: str, compressed_count: int):
"""Saves the summary to the database."""
try:
with self._SessionLocal() as session:
with self._db_session() as session:
# Find existing record
existing = session.query(ChatSummary).filter_by(chat_id=chat_id).first()
@@ -413,7 +527,7 @@ class Filter:
def _load_summary_record(self, chat_id: str) -> Optional[ChatSummary]:
"""Loads the summary record object from the database."""
try:
with self._SessionLocal() as session:
with self._db_session() as session:
record = session.query(ChatSummary).filter_by(chat_id=chat_id).first()
if record:
# Detach the object from the session so it can be used after session close