import logging from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker, Session from sqlalchemy.exc import SQLAlchemyError from db.models import Base from config import config # 初始化日志 logger = logging.getLogger(__name__) class Database: """数据库连接管理类""" def __init__(self): """初始化数据库连接""" self.connection_string = config.get("database.connection_string") self.pool_size = config.get("database.pool_size", 10) self.timeout = config.get("database.timeout", 30) if not self.connection_string: raise ValueError("数据库连接字符串未配置") try: # 创建数据库引擎 self.engine = create_engine( self.connection_string, pool_size=self.pool_size, pool_recycle=3600, # 1小时后回收连接 connect_args={"connect_timeout": self.timeout} ) # 创建会话工厂 self.SessionLocal = sessionmaker( autocommit=False, autoflush=False, bind=self.engine ) # 验证数据库连接 self._test_connection() logger.info("数据库连接初始化成功") except SQLAlchemyError as e: logger.error(f"数据库连接初始化失败: {str(e)}") raise Exception(f"数据库连接初始化失败: {str(e)}") except Exception as e: logger.error(f"初始化数据库时发生错误: {str(e)}") raise Exception(f"初始化数据库时发生错误: {str(e)}") def _test_connection(self) -> None: """测试数据库连接是否有效""" try: with self.engine.connect(): pass # 连接成功 except SQLAlchemyError as e: raise Exception(f"数据库连接测试失败: {str(e)}") def create_tables(self) -> None: """创建数据库表""" try: Base.metadata.create_all(bind=self.engine) logger.info("数据库表创建/验证成功") except SQLAlchemyError as e: logger.error(f"创建数据库表失败: {str(e)}") raise Exception(f"创建数据库表失败: {str(e)}") def get_session(self) -> Session: """获取数据库会话""" session = self.SessionLocal() try: yield session except SQLAlchemyError as e: logger.error(f"数据库会话操作失败: {str(e)}") session.rollback() raise finally: session.close() # 全局数据库实例 db = Database() # 用于FastAPI依赖的会话获取函数 def get_db(): """FastAPI依赖项:获取数据库会话""" yield from db.get_session()