fix: make async compression db session discovery robust
Co-authored-by: Fu-Jie <33599649+Fu-Jie@users.noreply.github.com>
This commit is contained in:
@@ -249,6 +249,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import hashlib
|
import hashlib
|
||||||
import time
|
import time
|
||||||
|
import contextlib
|
||||||
|
|
||||||
# Open WebUI built-in imports
|
# Open WebUI built-in imports
|
||||||
from open_webui.utils.chat import generate_chat_completion
|
from open_webui.utils.chat import generate_chat_completion
|
||||||
@@ -257,7 +258,10 @@ from fastapi.requests import Request
|
|||||||
from open_webui.main import app as webui_app
|
from open_webui.main import app as webui_app
|
||||||
|
|
||||||
# Open WebUI internal database (re-use shared connection)
|
# Open WebUI internal database (re-use shared connection)
|
||||||
import open_webui.internal.db as owui_db
|
try:
|
||||||
|
from open_webui.internal import db as owui_db
|
||||||
|
except Exception: # pragma: no cover - filter runs inside Open WebUI
|
||||||
|
owui_db = None
|
||||||
|
|
||||||
# Try to import tiktoken
|
# Try to import tiktoken
|
||||||
try:
|
try:
|
||||||
@@ -267,17 +271,88 @@ except ImportError:
|
|||||||
|
|
||||||
# Database imports
|
# Database imports
|
||||||
from sqlalchemy import Column, String, Text, DateTime, Integer, inspect
|
from sqlalchemy import Column, String, Text, DateTime, Integer, inspect
|
||||||
|
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
owui_Base = owui_db.Base
|
def _discover_owui_engine(db_module) -> Any | None:
|
||||||
|
if db_module is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
db_context = getattr(db_module, "get_db_context", None) or getattr(
|
||||||
|
db_module, "get_db", None
|
||||||
|
)
|
||||||
|
if callable(db_context):
|
||||||
|
try:
|
||||||
|
with db_context() as session:
|
||||||
|
try:
|
||||||
|
return session.get_bind()
|
||||||
|
except Exception:
|
||||||
|
return getattr(session, "bind", None) or getattr(
|
||||||
|
session, "engine", None
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
for attr in ("engine", "ENGINE", "bind", "BIND"):
|
||||||
|
candidate = getattr(db_module, attr, None)
|
||||||
|
if candidate is not None:
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _discover_owui_schema(db_module) -> str | None:
|
||||||
|
if db_module is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
base = getattr(db_module, "Base", None)
|
||||||
|
metadata = getattr(base, "metadata", None) if base is not None else None
|
||||||
|
candidate = getattr(metadata, "schema", None) if metadata is not None else None
|
||||||
|
if isinstance(candidate, str) and candidate.strip():
|
||||||
|
return candidate.strip()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
metadata_obj = getattr(db_module, "metadata_obj", None)
|
||||||
|
candidate = (
|
||||||
|
getattr(metadata_obj, "schema", None) if metadata_obj is not None else None
|
||||||
|
)
|
||||||
|
if isinstance(candidate, str) and candidate.strip():
|
||||||
|
return candidate.strip()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
from open_webui import env as owui_env
|
||||||
|
|
||||||
|
candidate = getattr(owui_env, "DATABASE_SCHEMA", None)
|
||||||
|
if isinstance(candidate, str) and candidate.strip():
|
||||||
|
return candidate.strip()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
owui_engine = _discover_owui_engine(owui_db)
|
||||||
|
owui_schema = _discover_owui_schema(owui_db)
|
||||||
|
owui_Base = getattr(owui_db, "Base", None) if owui_db is not None else None
|
||||||
|
if owui_Base is None:
|
||||||
|
owui_Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
class ChatSummary(owui_Base):
|
class ChatSummary(owui_Base):
|
||||||
"""Chat Summary Storage Table"""
|
"""Chat Summary Storage Table"""
|
||||||
|
|
||||||
__tablename__ = "chat_summary"
|
__tablename__ = "chat_summary"
|
||||||
__table_args__ = {"extend_existing": True}
|
__table_args__ = (
|
||||||
|
{"extend_existing": True, "schema": owui_schema}
|
||||||
|
if owui_schema
|
||||||
|
else {"extend_existing": True}
|
||||||
|
)
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
chat_id = Column(String(255), unique=True, nullable=False, index=True)
|
chat_id = Column(String(255), unique=True, nullable=False, index=True)
|
||||||
@@ -290,20 +365,59 @@ class ChatSummary(owui_Base):
|
|||||||
class Filter:
|
class Filter:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.valves = self.Valves()
|
self.valves = self.Valves()
|
||||||
self._db_engine = owui_db.engine
|
self._owui_db = owui_db
|
||||||
self._SessionLocal = (
|
self._db_engine = owui_engine
|
||||||
getattr(owui_db, "ScopedSession", None)
|
|
||||||
or getattr(owui_db, "SessionLocal", None)
|
|
||||||
or getattr(owui_db, "Session", None)
|
|
||||||
)
|
|
||||||
if self._SessionLocal is None:
|
|
||||||
raise RuntimeError("Open WebUI database session factory unavailable.")
|
|
||||||
self.temp_state = {} # Used to pass temporary data between inlet and outlet
|
self.temp_state = {} # Used to pass temporary data between inlet and outlet
|
||||||
|
self._fallback_session_factory = (
|
||||||
|
sessionmaker(bind=self._db_engine) if self._db_engine else None
|
||||||
|
)
|
||||||
self._init_database()
|
self._init_database()
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def _db_session(self):
|
||||||
|
db_module = self._owui_db
|
||||||
|
db_context = None
|
||||||
|
if db_module is not None:
|
||||||
|
db_context = getattr(db_module, "get_db_context", None) or getattr(
|
||||||
|
db_module, "get_db", None
|
||||||
|
)
|
||||||
|
|
||||||
|
if callable(db_context):
|
||||||
|
with db_context() as session:
|
||||||
|
yield session
|
||||||
|
return
|
||||||
|
|
||||||
|
factory = None
|
||||||
|
if db_module is not None:
|
||||||
|
factory = getattr(db_module, "SessionLocal", None) or getattr(
|
||||||
|
db_module, "ScopedSession", None
|
||||||
|
)
|
||||||
|
if callable(factory):
|
||||||
|
session = factory()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
close = getattr(session, "close", None)
|
||||||
|
if callable(close):
|
||||||
|
close()
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._fallback_session_factory is None:
|
||||||
|
raise RuntimeError("Open WebUI database session is unavailable.")
|
||||||
|
|
||||||
|
session = self._fallback_session_factory()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
session.close()
|
||||||
|
|
||||||
def _init_database(self):
|
def _init_database(self):
|
||||||
"""Initializes the database table using Open WebUI's shared connection."""
|
"""Initializes the database table using Open WebUI's shared connection."""
|
||||||
try:
|
try:
|
||||||
|
if self._db_engine is None:
|
||||||
|
raise RuntimeError("Open WebUI database engine is unavailable.")
|
||||||
|
|
||||||
# Check if table exists using SQLAlchemy inspect
|
# Check if table exists using SQLAlchemy inspect
|
||||||
inspector = inspect(self._db_engine)
|
inspector = inspect(self._db_engine)
|
||||||
if not inspector.has_table("chat_summary"):
|
if not inspector.has_table("chat_summary"):
|
||||||
@@ -373,7 +487,7 @@ class Filter:
|
|||||||
def _save_summary(self, chat_id: str, summary: str, compressed_count: int):
|
def _save_summary(self, chat_id: str, summary: str, compressed_count: int):
|
||||||
"""Saves the summary to the database."""
|
"""Saves the summary to the database."""
|
||||||
try:
|
try:
|
||||||
with self._SessionLocal() as session:
|
with self._db_session() as session:
|
||||||
# Find existing record
|
# Find existing record
|
||||||
existing = session.query(ChatSummary).filter_by(chat_id=chat_id).first()
|
existing = session.query(ChatSummary).filter_by(chat_id=chat_id).first()
|
||||||
|
|
||||||
@@ -413,7 +527,7 @@ class Filter:
|
|||||||
def _load_summary_record(self, chat_id: str) -> Optional[ChatSummary]:
|
def _load_summary_record(self, chat_id: str) -> Optional[ChatSummary]:
|
||||||
"""Loads the summary record object from the database."""
|
"""Loads the summary record object from the database."""
|
||||||
try:
|
try:
|
||||||
with self._SessionLocal() as session:
|
with self._db_session() as session:
|
||||||
record = session.query(ChatSummary).filter_by(chat_id=chat_id).first()
|
record = session.query(ChatSummary).filter_by(chat_id=chat_id).first()
|
||||||
if record:
|
if record:
|
||||||
# Detach the object from the session so it can be used after session close
|
# Detach the object from the session so it can be used after session close
|
||||||
|
|||||||
Reference in New Issue
Block a user