rasa人机对话脚本生成

rasa_generator.py 17KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
import ast
import json
import logging
import os
import textwrap
from pathlib import Path

from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session

from config import config
from db import (
    CustomAction,
    Intent,
    IntentExample,
    Story,
    StoryStep,
    SystemConfig,
)
  12. # 初始化日志
  13. logger = logging.getLogger(__name__)
class RasaFileGenerator:
    """Service that generates Rasa configuration files, with exception handling.

    Files are written below ``self.output_dir``: ``domain.yml``,
    ``nlu/nlu.yml``, ``stories/*.yml`` and ``actions/actions.py``.
    """
  16. def __init__(self):
  17. """初始化生成器,从配置获取输出目录"""
  18. try:
  19. self.output_dir = config.get("file_storage.rasa_files_dir", "./rasa_files")
  20. # 创建输出目录
  21. self._create_directories()
  22. logger.info(f"Rasa文件生成器初始化成功,输出目录: {self.output_dir}")
  23. except Exception as e:
  24. logger.error(f"初始化Rasa文件生成器失败: {str(e)}")
  25. raise Exception(f"初始化Rasa文件生成器失败: {str(e)}")
  26. def _create_directories(self) -> None:
  27. """创建必要的目录结构"""
  28. try:
  29. Path(self.output_dir).mkdir(parents=True, exist_ok=True)
  30. Path(f"{self.output_dir}/nlu").mkdir(parents=True, exist_ok=True)
  31. Path(f"{self.output_dir}/stories").mkdir(parents=True, exist_ok=True)
  32. Path(f"{self.output_dir}/actions").mkdir(parents=True, exist_ok=True)
  33. except OSError as e:
  34. raise Exception(f"创建目录失败: {str(e)}")
  35. def generate_all_files(self, session: Session) -> dict:
  36. """
  37. 生成所有Rasa配置文件
  38. Args:
  39. session: 数据库会话
  40. Returns:
  41. 生成结果,包含文件路径和状态
  42. """
  43. try:
  44. # 生成各文件
  45. domain_path = self.generate_domain_file(session)
  46. nlu_path = self.generate_nlu_file(session)
  47. story_paths = self.generate_stories_files(session)
  48. action_path = self.generate_actions_file(session)
  49. return {
  50. "success": True,
  51. "message": "所有文件生成成功",
  52. "output_dir": self.output_dir,
  53. "files": {
  54. "domain": domain_path,
  55. "nlu": nlu_path,
  56. "stories": story_paths,
  57. "actions": action_path
  58. }
  59. }
  60. except SQLAlchemyError as e:
  61. logger.error(f"数据库错误导致文件生成失败: {str(e)}")
  62. return {
  63. "success": False,
  64. "message": f"数据库错误: {str(e)}",
  65. }
  66. except OSError as e:
  67. logger.error(f"文件系统错误导致文件生成失败: {str(e)}")
  68. return {
  69. "success": False,
  70. "message": f"文件系统错误: {str(e)}",
  71. }
  72. except Exception as e:
  73. logger.error(f"文件生成失败: {str(e)}")
  74. return {
  75. "success": False,
  76. "message": f"生成文件时发生错误: {str(e)}",
  77. }
  78. def generate_domain_file(self, session: Session) -> str:
  79. """生成domain.yml文件"""
  80. try:
  81. domain_data = {
  82. "version": "3.1",
  83. "intents": [],
  84. "slots": {},
  85. "forms": {},
  86. "responses": {},
  87. "actions": []
  88. }
  89. # 获取指定版本的意图
  90. intents = session.query(Intent).all()
  91. if not intents:
  92. logger.warning(f"未找到任何意图")
  93. # 添加意图
  94. domain_data["intents"] = [intent.name for intent in intents]
  95. # 获取指定版本的自定义动作
  96. actions = session.query(CustomAction).all()
  97. # 添加动作
  98. domain_data["actions"] = [action.name for action in actions]
  99. # 实际应用中还需要添加槽位、表单和响应
  100. # 写入文件
  101. file_path = f"{self.output_dir}/domain.yml"
  102. with open(file_path, "w", encoding="utf-8") as f:
  103. self._write_yaml(f, domain_data)
  104. logger.info(f"生成domain文件: {file_path}")
  105. return file_path
  106. except Exception as e:
  107. logger.error(f"生成domain文件失败: {str(e)}")
  108. raise Exception(f"生成domain文件失败: {str(e)}")
  109. def generate_nlu_file(self, session: Session) -> str:
  110. """生成nlu.yml文件"""
  111. try:
  112. nlu_data = {
  113. "version": "3.1",
  114. "nlu": []
  115. }
  116. # 获取指定版本的意图及样本
  117. intents = session.query(Intent).all()
  118. for intent in intents:
  119. # 获取该意图指定版本的样本
  120. examples = session.query(IntentExample).filter(
  121. IntentExample.intent_id == intent.id,
  122. ).all()
  123. # 构建examples字符串,使用正确的YAML多行文本格式
  124. examples_str = "\n".join([f" - {ex.text}" for ex in examples])
  125. intent_entry = {
  126. "intent": intent.name,
  127. "examples": examples_str
  128. }
  129. if intent.description:
  130. intent_entry["intent"] = f"{intent.name} # {intent.description}"
  131. nlu_data["nlu"].append(intent_entry)
  132. # 写入文件
  133. file_path = f"{self.output_dir}/nlu/nlu.yml"
  134. with open(file_path, "w", encoding="utf-8") as f:
  135. self._write_yaml(f, nlu_data)
  136. logger.info(f"生成nlu文件: {file_path}")
  137. return file_path
  138. except Exception as e:
  139. logger.error(f"生成nlu文件失败: {str(e)}")
  140. raise Exception(f"生成nlu文件失败: {str(e)}")
  141. def generate_stories_files(self, session: Session) -> list[str]:
  142. """生成stories文件"""
  143. try:
  144. files = []
  145. # 获取指定版本的故事
  146. stories = session.query(Story).all()
  147. if not stories:
  148. logger.warning(f"未找到任何故事")
  149. # 生成合并的stories.yml
  150. all_stories = {
  151. "version": "3.1",
  152. "stories": []
  153. }
  154. for story in stories:
  155. # 获取故事步骤
  156. steps = session.query(StoryStep).filter(
  157. StoryStep.story_id == story.id
  158. ).order_by(StoryStep.step_order).all()
  159. # 构建故事内容
  160. story_entry = {
  161. "story": story.name,
  162. "steps": []
  163. }
  164. if story.description:
  165. story_entry["story"] = f"{story.name} # {story.description}"
  166. # 处理步骤
  167. for step in steps:
  168. content = step.content_dict
  169. step_type = step.step_type
  170. if step_type == "intent":
  171. story_entry["steps"].append({"intent": content.get("name")})
  172. elif step_type == "action":
  173. story_entry["steps"].append({"action": content.get("name")})
  174. elif step_type == "form":
  175. if content.get("activate", True):
  176. story_entry["steps"].append({"action": content.get("name")})
  177. else:
  178. story_entry["steps"].append({"action": f"form_deactivate_{content.get('name')}"})
  179. # 处理其他类型的步骤...
  180. all_stories["stories"].append(story_entry)
  181. # 生成单个故事文件
  182. story_file_name = story.name.lower().replace(" ", "_") + ".yml"
  183. story_file_path = f"{self.output_dir}/stories/{story_file_name}"
  184. with open(story_file_path, "w", encoding="utf-8") as f:
  185. self._write_yaml(f, {
  186. "version": "3.1",
  187. "stories": [story_entry]
  188. })
  189. files.append(story_file_path)
  190. # 写入合并的stories.yml
  191. merged_file_path = f"{self.output_dir}/stories/stories.yml"
  192. with open(merged_file_path, "w", encoding="utf-8") as f:
  193. self._write_yaml(f, all_stories)
  194. files.append(merged_file_path)
  195. logger.info(f"生成stories文件: {len(files)} 个文件")
  196. return files
  197. except Exception as e:
  198. logger.error(f"生成stories文件失败: {str(e)}")
  199. raise Exception(f"生成stories文件失败: {str(e)}")
  200. def generate_actions_file(self, session: Session) -> str:
  201. """生成自定义动作Python文件"""
  202. try:
  203. actions = session.query(CustomAction).all()
  204. if not actions:
  205. logger.warning(f"未找到任何自定义动作")
  206. # 生成actions.py内容
  207. code = [
  208. "from rasa_sdk import Action, Tracker",
  209. "from rasa_sdk.executor import CollectingDispatcher",
  210. "from rasa_sdk.events import SlotSet",
  211. "import requests",
  212. "import json\n"
  213. ]
  214. for action in actions:
  215. class_name = f"Action{''.join(word.capitalize() for word in action.name.split('_'))}"
  216. # 生成类定义
  217. code.append(f"class {class_name}(Action):")
  218. code.append(f" def name(self) -> str:")
  219. code.append(f" return \"{action.name}\"\n")
  220. # 生成run方法
  221. code.append(f" def run(self, dispatcher: CollectingDispatcher, tracker: Tracker, domain: dict) -> list:")
  222. code.append(" # 构建请求头")
  223. code.append(" headers = {}")
  224. # 添加Token
  225. if action.token:
  226. code.append(f" headers[\"Authorization\"] = \"Bearer {action.token}\"")
  227. # 添加自定义请求头
  228. headers = action.headers_dict
  229. for key, value in headers.items():
  230. code.append(f" headers[\"{key}\"] = \"{value}\"")
  231. # 添加内容类型
  232. code.append(" headers[\"Content-Type\"] = \"application/json\"\n")
  233. # 构建请求体
  234. if action.http_method in ["POST", "PUT"] and action.request_body:
  235. code.append(" # 构建请求体")
  236. code.append(f" payload = {action.request_body}")
  237. # 替换模板变量
  238. code.append(" # 替换模板变量")
  239. code.append(" from jinja2 import Template")
  240. code.append(" payload_str = json.dumps(payload)")
  241. code.append(" template = Template(payload_str)")
  242. code.append(" payload = json.loads(template.render(tracker.slots))\n")
  243. # 发送请求
  244. code.append(" # 发送请求")
  245. code.append(" try:")
  246. if action.http_method == "GET":
  247. code.append(f" response = requests.get(\"{action.api_url}\", headers=headers)")
  248. elif action.http_method == "POST":
  249. code.append(f" response = requests.post(\"{action.api_url}\", json=payload, headers=headers)")
  250. elif action.http_method == "PUT":
  251. code.append(f" response = requests.put(\"{action.api_url}\", json=payload, headers=headers)")
  252. elif action.http_method == "DELETE":
  253. code.append(f" response = requests.delete(\"{action.api_url}\", headers=headers)")
  254. else:
  255. code.append(f" dispatcher.utter_message(text=f\"不支持的HTTP方法: {action.http_method}\")")
  256. code.append(" return []")
  257. # 处理响应
  258. code.append(" if response.status_code == 200:")
  259. code.append(" result = response.json()")
  260. # 处理响应映射
  261. if action.response_mapping:
  262. code.append(" # 处理响应映射")
  263. code.append(f" mappings = {action.response_mapping}")
  264. code.append(" slot_events = []")
  265. code.append(" for slot, path in mappings.items():")
  266. code.append(" # 简化的路径解析")
  267. code.append(" value = result")
  268. code.append(" for part in path.split('.'):")
  269. code.append(" if isinstance(value, dict) and part in value:")
  270. code.append(" value = value[part]")
  271. code.append(" else:")
  272. code.append(" value = None")
  273. code.append(" break")
  274. code.append(" if value is not None:")
  275. code.append(" slot_events.append(SlotSet(slot, value))")
  276. code.append(" dispatcher.utter_message(text=str(result))")
  277. code.append(" return slot_events")
  278. else:
  279. code.append(" dispatcher.utter_message(text=str(result))")
  280. code.append(" else:")
  281. code.append(" dispatcher.utter_message(text=f\"API调用失败,状态码: {response.status_code}\")")
  282. code.append(" except Exception as e:")
  283. code.append(" dispatcher.utter_message(text=f\"调用API时发生错误: {str(e)}\")\n")
  284. code.append(" return []\n")
  285. # 写入文件
  286. file_path = f"{self.output_dir}/actions/actions.py"
  287. with open(file_path, "w", encoding="utf-8") as f:
  288. f.write("\n".join(code))
  289. logger.info(f"生成actions文件: {file_path}")
  290. return file_path
  291. except Exception as e:
  292. logger.error(f"生成actions文件失败: {str(e)}")
  293. raise Exception(f"生成actions文件失败: {str(e)}")
  294. def _write_yaml(self, file, data, indent: int = 0, reset: bool = False) -> None:
  295. """
  296. 简单的YAML写入函数
  297. Args:
  298. file: 文件对象
  299. data: 要写入的数据
  300. indent: 当前缩进
  301. """
  302. try:
  303. indent_str = ""
  304. if (indent != 0):
  305. indent_str = " " * indent
  306. if isinstance(data, dict):
  307. for key, value in data.items():
  308. if isinstance(value, (dict, list)):
  309. file.write(f"{indent_str}{key}:\n")
  310. self._write_yaml(file, value, indent + 1)
  311. else:
  312. if (key != 'examples'):
  313. file.write(f"{key}: {self._format_yaml_value(value)}\n")
  314. else:
  315. file.write(f"{indent_str}{key}: {self._format_yaml_value(value)}\n")
  316. elif isinstance(data, list):
  317. for item in data:
  318. if isinstance(item, (dict, list)):
  319. file.write(f"{indent_str}- ")
  320. self._write_yaml(file, item, indent + 1, True)
  321. else:
  322. file.write(f"{indent_str}- {self._format_yaml_value(item)}\n")
  323. else:
  324. file.write(f"{self._format_yaml_value(data)}\n")
  325. except Exception as e:
  326. raise Exception(f"写入YAML数据失败: {str(e)}")
  327. def _format_yaml_value(self, value) -> str:
  328. """格式化YAML值"""
  329. if isinstance(value, str) and (":" in value or "\n" in value):
  330. return f"|{chr(10)}{value}"
  331. elif isinstance(value, str):
  332. return f'{value}'
  333. elif isinstance(value, bool):
  334. return "true" if value else "false"
  335. elif value is None:
  336. return "null"
  337. else:
  338. return str(value)