| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- import logging
- from sqlalchemy import create_engine
- from sqlalchemy.orm import sessionmaker, Session
- from sqlalchemy.exc import SQLAlchemyError
- from db.models import Base
- from config import config
- # 初始化日志
- logger = logging.getLogger(__name__)
- class Database:
- """数据库连接管理类"""
-
- def __init__(self):
- """初始化数据库连接"""
- self.connection_string = config.get("database.connection_string")
- self.pool_size = config.get("database.pool_size", 10)
- self.timeout = config.get("database.timeout", 30)
-
- if not self.connection_string:
- raise ValueError("数据库连接字符串未配置")
-
- try:
- # 创建数据库引擎
- self.engine = create_engine(
- self.connection_string,
- pool_size=self.pool_size,
- pool_recycle=3600, # 1小时后回收连接
- connect_args={"connect_timeout": self.timeout}
- )
-
- # 创建会话工厂
- self.SessionLocal = sessionmaker(
- autocommit=False,
- autoflush=False,
- bind=self.engine
- )
-
- # 验证数据库连接
- self._test_connection()
-
- logger.info("数据库连接初始化成功")
-
- except SQLAlchemyError as e:
- logger.error(f"数据库连接初始化失败: {str(e)}")
- raise Exception(f"数据库连接初始化失败: {str(e)}")
- except Exception as e:
- logger.error(f"初始化数据库时发生错误: {str(e)}")
- raise Exception(f"初始化数据库时发生错误: {str(e)}")
-
- def _test_connection(self) -> None:
- """测试数据库连接是否有效"""
- try:
- with self.engine.connect():
- pass # 连接成功
- except SQLAlchemyError as e:
- raise Exception(f"数据库连接测试失败: {str(e)}")
-
- def create_tables(self) -> None:
- """创建数据库表"""
- try:
- Base.metadata.create_all(bind=self.engine)
- logger.info("数据库表创建/验证成功")
- except SQLAlchemyError as e:
- logger.error(f"创建数据库表失败: {str(e)}")
- raise Exception(f"创建数据库表失败: {str(e)}")
-
- def get_session(self) -> Session:
- """获取数据库会话"""
- session = self.SessionLocal()
- try:
- yield session
- except SQLAlchemyError as e:
- logger.error(f"数据库会话操作失败: {str(e)}")
- session.rollback()
- raise
- finally:
- session.close()
- # 全局数据库实例
- db = Database()
- # 用于FastAPI依赖的会话获取函数
- def get_db():
- """FastAPI依赖项:获取数据库会话"""
- yield from db.get_session()
|