diff --git a/plugins/filters/async-context-compression/async_context_compression.py b/plugins/filters/async-context-compression/async_context_compression.py index 0ec2323..d757067 100644 --- a/plugins/filters/async-context-compression/async_context_compression.py +++ b/plugins/filters/async-context-compression/async_context_compression.py @@ -249,6 +249,7 @@ import asyncio import json import hashlib import time +import contextlib # Open WebUI built-in imports from open_webui.utils.chat import generate_chat_completion @@ -257,7 +258,10 @@ from fastapi.requests import Request from open_webui.main import app as webui_app # Open WebUI internal database (re-use shared connection) -import open_webui.internal.db as owui_db +try: + from open_webui.internal import db as owui_db +except Exception: # pragma: no cover - filter runs inside Open WebUI + owui_db = None # Try to import tiktoken try: @@ -267,17 +271,88 @@ except ImportError: # Database imports from sqlalchemy import Column, String, Text, DateTime, Integer, inspect +from sqlalchemy.orm import declarative_base, sessionmaker from datetime import datetime -owui_Base = owui_db.Base +def _discover_owui_engine(db_module) -> Any | None: + if db_module is None: + return None + + db_context = getattr(db_module, "get_db_context", None) or getattr( + db_module, "get_db", None + ) + if callable(db_context): + try: + with db_context() as session: + try: + return session.get_bind() + except Exception: + return getattr(session, "bind", None) or getattr( + session, "engine", None + ) + except Exception: + pass + + for attr in ("engine", "ENGINE", "bind", "BIND"): + candidate = getattr(db_module, attr, None) + if candidate is not None: + return candidate + + return None + + +def _discover_owui_schema(db_module) -> str | None: + if db_module is None: + return None + + try: + base = getattr(db_module, "Base", None) + metadata = getattr(base, "metadata", None) if base is not None else None + candidate = getattr(metadata, "schema", None) if metadata is not None else None + if isinstance(candidate, str) and candidate.strip(): + return candidate.strip() + except Exception: + pass + + try: + metadata_obj = getattr(db_module, "metadata_obj", None) + candidate = ( + getattr(metadata_obj, "schema", None) if metadata_obj is not None else None + ) + if isinstance(candidate, str) and candidate.strip(): + return candidate.strip() + except Exception: + pass + + try: + from open_webui import env as owui_env + + candidate = getattr(owui_env, "DATABASE_SCHEMA", None) + if isinstance(candidate, str) and candidate.strip(): + return candidate.strip() + except Exception: + pass + + return None + + +owui_engine = _discover_owui_engine(owui_db) +owui_schema = _discover_owui_schema(owui_db) +owui_Base = getattr(owui_db, "Base", None) if owui_db is not None else None +if owui_Base is None: + owui_Base = declarative_base() class ChatSummary(owui_Base): """Chat Summary Storage Table""" __tablename__ = "chat_summary" - __table_args__ = {"extend_existing": True} + __table_args__ = ( + {"extend_existing": True, "schema": owui_schema} + if owui_schema + else {"extend_existing": True} + ) id = Column(Integer, primary_key=True, autoincrement=True) chat_id = Column(String(255), unique=True, nullable=False, index=True) @@ -290,20 +365,59 @@ class ChatSummary(owui_Base): class Filter: def __init__(self): self.valves = self.Valves() - self._db_engine = owui_db.engine - self._SessionLocal = ( - getattr(owui_db, "ScopedSession", None) - or getattr(owui_db, "SessionLocal", None) - or getattr(owui_db, "Session", None) - ) - if self._SessionLocal is None: - raise RuntimeError("Open WebUI database session factory unavailable.") + self._owui_db = owui_db + self._db_engine = owui_engine self.temp_state = {} # Used to pass temporary data between inlet and outlet + self._fallback_session_factory = ( + sessionmaker(bind=self._db_engine) if self._db_engine else None + ) self._init_database() + @contextlib.contextmanager + def _db_session(self): + db_module = self._owui_db + db_context = None + if db_module is not None: + db_context = getattr(db_module, "get_db_context", None) or getattr( + db_module, "get_db", None + ) + + if callable(db_context): + with db_context() as session: + yield session + return + + factory = None + if db_module is not None: + factory = getattr(db_module, "SessionLocal", None) or getattr( + db_module, "ScopedSession", None + ) + if callable(factory): + session = factory() + try: + yield session + finally: + close = getattr(session, "close", None) + if callable(close): + close() + return + + if self._fallback_session_factory is None: + raise RuntimeError("Open WebUI database session is unavailable.") + + session = self._fallback_session_factory() + try: + yield session + finally: + with contextlib.suppress(Exception): + session.close() + def _init_database(self): """Initializes the database table using Open WebUI's shared connection.""" try: + if self._db_engine is None: + raise RuntimeError("Open WebUI database engine is unavailable.") + # Check if table exists using SQLAlchemy inspect inspector = inspect(self._db_engine) if not inspector.has_table("chat_summary"): @@ -373,7 +487,7 @@ class Filter: def _save_summary(self, chat_id: str, summary: str, compressed_count: int): """Saves the summary to the database.""" try: - with self._SessionLocal() as session: + with self._db_session() as session: # Find existing record existing = session.query(ChatSummary).filter_by(chat_id=chat_id).first() @@ -413,7 +527,7 @@ class Filter: def _load_summary_record(self, chat_id: str) -> Optional[ChatSummary]: """Loads the summary record object from the database.""" try: - with self._SessionLocal() as session: + with self._db_session() as session: record = session.query(ChatSummary).filter_by(chat_id=chat_id).first() if record: # Detach the object from the session so it can be used after session close