fix: make async compression db session discovery robust
Co-authored-by: Fu-Jie <33599649+Fu-Jie@users.noreply.github.com>
This commit is contained in:
@@ -249,6 +249,7 @@ import asyncio
|
||||
import json
|
||||
import hashlib
|
||||
import time
|
||||
import contextlib
|
||||
|
||||
# Open WebUI built-in imports
|
||||
from open_webui.utils.chat import generate_chat_completion
|
||||
@@ -257,7 +258,10 @@ from fastapi.requests import Request
|
||||
from open_webui.main import app as webui_app
|
||||
|
||||
# Open WebUI internal database (re-use shared connection)
|
||||
import open_webui.internal.db as owui_db
|
||||
try:
|
||||
from open_webui.internal import db as owui_db
|
||||
except Exception: # pragma: no cover - filter runs inside Open WebUI
|
||||
owui_db = None
|
||||
|
||||
# Try to import tiktoken
|
||||
try:
|
||||
@@ -267,17 +271,88 @@ except ImportError:
|
||||
|
||||
# Database imports
|
||||
from sqlalchemy import Column, String, Text, DateTime, Integer, inspect
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
owui_Base = owui_db.Base
|
||||
def _discover_owui_engine(db_module) -> Any | None:
|
||||
if db_module is None:
|
||||
return None
|
||||
|
||||
db_context = getattr(db_module, "get_db_context", None) or getattr(
|
||||
db_module, "get_db", None
|
||||
)
|
||||
if callable(db_context):
|
||||
try:
|
||||
with db_context() as session:
|
||||
try:
|
||||
return session.get_bind()
|
||||
except Exception:
|
||||
return getattr(session, "bind", None) or getattr(
|
||||
session, "engine", None
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for attr in ("engine", "ENGINE", "bind", "BIND"):
|
||||
candidate = getattr(db_module, attr, None)
|
||||
if candidate is not None:
|
||||
return candidate
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _discover_owui_schema(db_module) -> str | None:
|
||||
if db_module is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
base = getattr(db_module, "Base", None)
|
||||
metadata = getattr(base, "metadata", None) if base is not None else None
|
||||
candidate = getattr(metadata, "schema", None) if metadata is not None else None
|
||||
if isinstance(candidate, str) and candidate.strip():
|
||||
return candidate.strip()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
metadata_obj = getattr(db_module, "metadata_obj", None)
|
||||
candidate = (
|
||||
getattr(metadata_obj, "schema", None) if metadata_obj is not None else None
|
||||
)
|
||||
if isinstance(candidate, str) and candidate.strip():
|
||||
return candidate.strip()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
from open_webui import env as owui_env
|
||||
|
||||
candidate = getattr(owui_env, "DATABASE_SCHEMA", None)
|
||||
if isinstance(candidate, str) and candidate.strip():
|
||||
return candidate.strip()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
owui_engine = _discover_owui_engine(owui_db)
|
||||
owui_schema = _discover_owui_schema(owui_db)
|
||||
owui_Base = getattr(owui_db, "Base", None) if owui_db is not None else None
|
||||
if owui_Base is None:
|
||||
owui_Base = declarative_base()
|
||||
|
||||
|
||||
class ChatSummary(owui_Base):
|
||||
"""Chat Summary Storage Table"""
|
||||
|
||||
__tablename__ = "chat_summary"
|
||||
__table_args__ = {"extend_existing": True}
|
||||
__table_args__ = (
|
||||
{"extend_existing": True, "schema": owui_schema}
|
||||
if owui_schema
|
||||
else {"extend_existing": True}
|
||||
)
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
chat_id = Column(String(255), unique=True, nullable=False, index=True)
|
||||
@@ -290,20 +365,59 @@ class ChatSummary(owui_Base):
|
||||
class Filter:
|
||||
def __init__(self):
|
||||
self.valves = self.Valves()
|
||||
self._db_engine = owui_db.engine
|
||||
self._SessionLocal = (
|
||||
getattr(owui_db, "ScopedSession", None)
|
||||
or getattr(owui_db, "SessionLocal", None)
|
||||
or getattr(owui_db, "Session", None)
|
||||
)
|
||||
if self._SessionLocal is None:
|
||||
raise RuntimeError("Open WebUI database session factory unavailable.")
|
||||
self._owui_db = owui_db
|
||||
self._db_engine = owui_engine
|
||||
self.temp_state = {} # Used to pass temporary data between inlet and outlet
|
||||
self._fallback_session_factory = (
|
||||
sessionmaker(bind=self._db_engine) if self._db_engine else None
|
||||
)
|
||||
self._init_database()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _db_session(self):
|
||||
db_module = self._owui_db
|
||||
db_context = None
|
||||
if db_module is not None:
|
||||
db_context = getattr(db_module, "get_db_context", None) or getattr(
|
||||
db_module, "get_db", None
|
||||
)
|
||||
|
||||
if callable(db_context):
|
||||
with db_context() as session:
|
||||
yield session
|
||||
return
|
||||
|
||||
factory = None
|
||||
if db_module is not None:
|
||||
factory = getattr(db_module, "SessionLocal", None) or getattr(
|
||||
db_module, "ScopedSession", None
|
||||
)
|
||||
if callable(factory):
|
||||
session = factory()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
close = getattr(session, "close", None)
|
||||
if callable(close):
|
||||
close()
|
||||
return
|
||||
|
||||
if self._fallback_session_factory is None:
|
||||
raise RuntimeError("Open WebUI database session is unavailable.")
|
||||
|
||||
session = self._fallback_session_factory()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
with contextlib.suppress(Exception):
|
||||
session.close()
|
||||
|
||||
def _init_database(self):
|
||||
"""Initializes the database table using Open WebUI's shared connection."""
|
||||
try:
|
||||
if self._db_engine is None:
|
||||
raise RuntimeError("Open WebUI database engine is unavailable.")
|
||||
|
||||
# Check if table exists using SQLAlchemy inspect
|
||||
inspector = inspect(self._db_engine)
|
||||
if not inspector.has_table("chat_summary"):
|
||||
@@ -373,7 +487,7 @@ class Filter:
|
||||
def _save_summary(self, chat_id: str, summary: str, compressed_count: int):
|
||||
"""Saves the summary to the database."""
|
||||
try:
|
||||
with self._SessionLocal() as session:
|
||||
with self._db_session() as session:
|
||||
# Find existing record
|
||||
existing = session.query(ChatSummary).filter_by(chat_id=chat_id).first()
|
||||
|
||||
@@ -413,7 +527,7 @@ class Filter:
|
||||
def _load_summary_record(self, chat_id: str) -> Optional[ChatSummary]:
|
||||
"""Loads the summary record object from the database."""
|
||||
try:
|
||||
with self._SessionLocal() as session:
|
||||
with self._db_session() as session:
|
||||
record = session.query(ChatSummary).filter_by(chat_id=chat_id).first()
|
||||
if record:
|
||||
# Detach the object from the session so it can be used after session close
|
||||
|
||||
Reference in New Issue
Block a user