rasa人机对话脚本生成

connection.py 2.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import logging
  2. from sqlalchemy import create_engine
  3. from sqlalchemy.orm import sessionmaker, Session
  4. from sqlalchemy.exc import SQLAlchemyError
  5. from db.models import Base
  6. from config import config
  7. # 初始化日志
  8. logger = logging.getLogger(__name__)
  9. class Database:
  10. """数据库连接管理类"""
  11. def __init__(self):
  12. """初始化数据库连接"""
  13. self.connection_string = config.get("database.connection_string")
  14. self.pool_size = config.get("database.pool_size", 10)
  15. self.timeout = config.get("database.timeout", 30)
  16. if not self.connection_string:
  17. raise ValueError("数据库连接字符串未配置")
  18. try:
  19. # 创建数据库引擎
  20. self.engine = create_engine(
  21. self.connection_string,
  22. pool_size=self.pool_size,
  23. pool_recycle=3600, # 1小时后回收连接
  24. connect_args={"connect_timeout": self.timeout}
  25. )
  26. # 创建会话工厂
  27. self.SessionLocal = sessionmaker(
  28. autocommit=False,
  29. autoflush=False,
  30. bind=self.engine
  31. )
  32. # 验证数据库连接
  33. self._test_connection()
  34. logger.info("数据库连接初始化成功")
  35. except SQLAlchemyError as e:
  36. logger.error(f"数据库连接初始化失败: {str(e)}")
  37. raise Exception(f"数据库连接初始化失败: {str(e)}")
  38. except Exception as e:
  39. logger.error(f"初始化数据库时发生错误: {str(e)}")
  40. raise Exception(f"初始化数据库时发生错误: {str(e)}")
  41. def _test_connection(self) -> None:
  42. """测试数据库连接是否有效"""
  43. try:
  44. with self.engine.connect():
  45. pass # 连接成功
  46. except SQLAlchemyError as e:
  47. raise Exception(f"数据库连接测试失败: {str(e)}")
  48. def create_tables(self) -> None:
  49. """创建数据库表"""
  50. try:
  51. Base.metadata.create_all(bind=self.engine)
  52. logger.info("数据库表创建/验证成功")
  53. except SQLAlchemyError as e:
  54. logger.error(f"创建数据库表失败: {str(e)}")
  55. raise Exception(f"创建数据库表失败: {str(e)}")
  56. def get_session(self) -> Session:
  57. """获取数据库会话"""
  58. session = self.SessionLocal()
  59. try:
  60. yield session
  61. except SQLAlchemyError as e:
  62. logger.error(f"数据库会话操作失败: {str(e)}")
  63. session.rollback()
  64. raise
  65. finally:
  66. session.close()
  67. # 全局数据库实例
  68. db = Database()
  69. # 用于FastAPI依赖的会话获取函数
  70. def get_db():
  71. """FastAPI依赖项:获取数据库会话"""
  72. yield from db.get_session()