"""
数据库连接管理模块

使用 SQLAlchemy ORM，支持同步操作
"""

import json
import logging
import os
from contextlib import contextmanager
from typing import Generator, Optional

import yaml
from sqlalchemy import create_engine, text
from sqlalchemy.dialects.postgresql import dialect as pg_dialect
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.pool import QueuePool


# Load the "database" section of config/config.yaml, resolved against the
# current working directory. The file is optional: get_database_url() falls
# back to environment variables and built-in defaults for every setting.
yaml_file = os.path.join(os.getcwd(), "config", "config.yaml")
try:
    with open(yaml_file, "r", encoding="utf-8") as file:
        # safe_load() returns None for an empty document, and a bare
        # "database:" key also yields None; normalise both to {} so that the
        # later config.get(...) calls cannot fail with AttributeError.
        config = (yaml.safe_load(file) or {}).get("database") or {}
except FileNotFoundError:
    # No config file at all: rely on environment variables / defaults.
    config = {}
except Exception:
    # Unreadable or malformed config: keep the best-effort fallback, but make
    # the reason visible instead of silently swallowing it.
    logging.getLogger(__name__).warning(
        "Failed to load database config from %s; using defaults",
        yaml_file,
        exc_info=True,
    )
    config = {}


def get_database_url(db_config: Optional[dict] = None) -> str:
    """Build the PostgreSQL connection URL.

    Each setting is resolved in priority order: environment variable, then
    the key in the database config mapping, then a built-in default:

    - host:     DB_HOST     / "host"     / "localhost"
    - port:     DB_PORT     / "port"     / "5432"
    - user:     DB_USER     / "username" / "postgres"
    - password: DB_PASSWORD / "password" / ""
    - database: DB_NAME     / "name"     / "agentcoord"

    Args:
        db_config: Mapping to read fallback values from. When omitted, the
            module-level ``config`` (the "database" section of config.yaml)
            is used; a ``None``/empty mapping means "defaults only".

    Returns:
        A ``postgresql://user:password@host:port/dbname`` URL string. The user
        name and password are percent-encoded, so they must be supplied raw
        (not pre-encoded); characters such as ``@``, ``:``, ``/`` or ``%`` in a
        password no longer corrupt the URL.
    """
    from urllib.parse import quote

    cfg = (config if db_config is None else db_config) or {}

    # Environment variables take precedence over the config file.
    host = os.getenv("DB_HOST", cfg.get("host", "localhost"))
    port = os.getenv("DB_PORT", cfg.get("port", "5432"))
    user = os.getenv("DB_USER", cfg.get("username", "postgres"))
    password = os.getenv("DB_PASSWORD", cfg.get("password", ""))
    dbname = os.getenv("DB_NAME", cfg.get("name", "agentcoord"))

    # Credentials are percent-encoded (safe="" also escapes "/"); str() keeps
    # numeric YAML values such as `password: 1234` working as before.
    user_enc = quote(str(user), safe="")
    password_enc = quote(str(password), safe="")
    return f"postgresql://{user_enc}:{password_enc}@{host}:{port}/{dbname}"


# Engine construction: the URL comes from get_database_url(); pool sizing may
# be overridden by the "pool_size" / "max_overflow" config keys.
DATABASE_URL = get_database_url()

engine = create_engine(
    DATABASE_URL,
    echo=False,
    poolclass=QueuePool,
    pool_pre_ping=True,  # check a pooled connection is alive before reusing it
    max_overflow=config.get("max_overflow", 20),
    pool_size=config.get("pool_size", 10),
    # JSON/JSONB hooks: serialize without escaping non-ASCII characters;
    # parse string payloads and pass already-decoded values through as-is.
    json_deserializer=lambda payload: payload if not isinstance(payload, str) else json.loads(payload),
    json_serializer=lambda value: json.dumps(value, ensure_ascii=False),
)


# Session factory bound to the module engine, with autoflush and autocommit
# both disabled.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Declarative base class that ORM models inherit from.
Base = declarative_base()


def get_db() -> Generator:
    """Yield one database session, closing it when the generator finishes.

    Usage: ``for db in get_db(): ...``

    Nothing is committed here; any commit is the caller's responsibility.
    """
    # Leaving the Session's own context manager calls close(), including when
    # the generator is closed early.
    with SessionLocal() as session:
        yield session


@contextmanager
def get_db_context() -> Generator:
    """Provide a session that commits on success and rolls back on error.

    Usage: ``with get_db_context() as db: ...``

    If the ``with`` body (or the commit itself) raises an ``Exception``, the
    session is rolled back and the exception re-raised. The session is always
    closed on exit.
    """
    with SessionLocal() as session:  # exiting this block closes the session
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise


def test_connection() -> bool:
    """Check whether the database is reachable.

    Opens a connection from the module ``engine`` and executes ``SELECT 1``.

    Returns:
        True if the query succeeds, False if any exception is raised. The
        failure is logged with its traceback rather than silently discarded.
    """
    try:
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
    except Exception:
        # Best-effort health check: report failure through the return value,
        # but keep the underlying reason visible in the logs.
        logging.getLogger(__name__).warning(
            "Database connection test failed", exc_info=True
        )
        return False
    return True