# rasa人机对话脚本生成 (Rasa human-machine dialogue script generation)
#
# File: rasa_generator_by_json.py (24KB)

import os
import sys
import json
import logging
from pathlib import Path
# Add the parent of this file's directory to the Python search path so the
# ``config`` module imported below can be resolved.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from config import Config
# Module-level logger used by everything in this file.
logger = logging.getLogger(__name__)
  11. class RasaFileGenerator:
  12. """Rasa配置文件生成服务,带完整异常处理"""
  13. def __init__(self):
  14. """初始化生成器,从配置获取输出目录"""
  15. try:
  16. # 创建配置实例
  17. self.config = Config()
  18. self.output_dir = self.config.get("file_storage.rasa_files_dir", "./rasa_files")
  19. # 创建输出目录
  20. self._create_directories()
  21. logger.info(f"Rasa文件生成器初始化成功,输出目录: {self.output_dir}")
  22. except Exception as e:
  23. logger.error(f"初始化Rasa文件生成器失败: {str(e)}")
  24. raise Exception(f"初始化Rasa文件生成器失败: {str(e)}")
  25. def _load_flow_json(self, flow_json_path: str) -> dict:
  26. """加载并解析flow.json文件"""
  27. try:
  28. with open(flow_json_path, 'r', encoding='utf-8') as f:
  29. return json.load(f)
  30. except FileNotFoundError:
  31. raise Exception(f"flow.json文件未找到: {flow_json_path}")
  32. except json.JSONDecodeError:
  33. raise Exception(f"flow.json文件格式无效")
  34. except Exception as e:
  35. raise Exception(f"加载flow.json文件失败: {str(e)}")
  36. def _create_directories(self) -> None:
  37. """创建必要的目录结构"""
  38. try:
  39. Path(self.output_dir).mkdir(parents=True, exist_ok=True)
  40. Path(f"{self.output_dir}/nlu").mkdir(parents=True, exist_ok=True)
  41. Path(f"{self.output_dir}/stories").mkdir(parents=True, exist_ok=True)
  42. Path(f"{self.output_dir}/actions").mkdir(parents=True, exist_ok=True)
  43. except OSError as e:
  44. raise Exception(f"创建目录失败: {str(e)}")
  45. def generate_all_files(self, flow_json_str: str, flow_name: str) -> dict:
  46. """
  47. 生成所有Rasa配置文件
  48. Args:
  49. flow_json_str: flow数据的JSON字符串
  50. Returns:
  51. 生成结果,包含文件路径和状态
  52. """
  53. try:
  54. # 解析flow JSON字符串
  55. flow_data = json.loads(flow_json_str)
  56. # 生成各文件
  57. domain_path = self.generate_domain_file(flow_data)
  58. nlu_path = self.generate_nlu_file(flow_data)
  59. story_paths = self.generate_stories_files(flow_data, flow_name)
  60. action_path = self.generate_actions_file(flow_data)
  61. return {
  62. "success": True,
  63. "message": "所有文件生成成功",
  64. "output_dir": self.output_dir,
  65. "files": {
  66. "domain": domain_path,
  67. "nlu": nlu_path,
  68. "stories": story_paths,
  69. "actions": action_path
  70. }
  71. }
  72. except OSError as e:
  73. logger.error(f"文件系统错误导致文件生成失败: {str(e)}")
  74. return {
  75. "success": False,
  76. "message": f"文件系统错误: {str(e)}",
  77. }
  78. except Exception as e:
  79. logger.error(f"文件生成失败: {str(e)}")
  80. return {
  81. "success": False,
  82. "message": f"生成文件时发生错误: {str(e)}",
  83. }
  84. def generate_domain_file(self, flow_data: dict) -> str:
  85. """生成domain.yml文件"""
  86. try:
  87. domain_data = {
  88. "version": "3.1",
  89. "intents": [],
  90. "slots": {},
  91. "forms": {},
  92. "responses": {},
  93. "actions": []
  94. }
  95. # 从flow_data中提取意图和动作
  96. nodes = flow_data.get("flowJson", {}).get("nodes", [])
  97. intents = []
  98. actions = []
  99. forms = []
  100. for node in nodes:
  101. node_type = node.get("type")
  102. properties = node.get("properties", {})
  103. code = properties.get("code")
  104. if node_type == "intention" and code:
  105. intents.append(code)
  106. elif node_type == "action" and code:
  107. actions.append(code)
  108. elif node_type == "form" and code:
  109. actions.append(code)
  110. elif node_type == "collection" and code:
  111. forms.append(code)
  112. actions.append(f"form_{code}")
  113. # 添加意图
  114. domain_data["intents"] = intents
  115. # 添加动作
  116. domain_data["actions"] = actions
  117. # 添加表单
  118. for form_code in forms:
  119. domain_data["forms"][form_code] = {}
  120. # 添加槽位 - 从表单节点提取
  121. for node in nodes:
  122. if node.get("type") == "collection":
  123. form_code = node.get("properties", {}).get("code")
  124. form_fields = node.get("properties", {}).get("formFields", [])
  125. for field in form_fields:
  126. slot_name = field.get("slotName") or f"{form_code}_{field.get('entityType')}"
  127. domain_data["slots"][slot_name] = {
  128. "type": "text" # 默认为text类型,可根据需要调整
  129. }
  130. # 写入文件
  131. file_path = f"{self.output_dir}/domain.yml"
  132. with open(file_path, "w", encoding="utf-8") as f:
  133. self._write_yaml(f, domain_data)
  134. logger.info(f"生成domain文件: {file_path}")
  135. return file_path
  136. except Exception as e:
  137. logger.error(f"生成domain文件失败: {str(e)}")
  138. raise Exception(f"生成domain文件失败: {str(e)}")
  139. def generate_nlu_file(self, flow_data: dict) -> str:
  140. """生成nlu.yml文件"""
  141. try:
  142. nlu_data = {
  143. "version": "3.1",
  144. "nlu": []
  145. }
  146. # 从flow_data中提取意图及样本
  147. nodes = flow_data.get("flowJson", {}).get("nodes", [])
  148. for node in nodes:
  149. if node.get("type") == "intention":
  150. properties = node.get("properties", {})
  151. intent_name = properties.get("code")
  152. intent_desc = properties.get("desc")
  153. samples = properties.get("samples", [])
  154. if not intent_name:
  155. continue
  156. # {
  157. # "text": "我想定明天的酒店",
  158. # "entities": [
  159. # {
  160. # "text": "明天",
  161. # "label": "DATE",
  162. # "start": 3,
  163. # "end": 4
  164. # }
  165. # ]
  166. # },
  167. # entities = samples.get("entities", [])
  168. # 如果entities存在说明需要在样本中标记实体
  169. # 处理实体并构建examples字符串
  170. def format_sample_with_entities(sample):
  171. text = sample.get('text', '')
  172. entities = sample.get('entities', [])
  173. # 没有实体,直接返回文本
  174. if not entities:
  175. return text
  176. # 按start位置降序排列实体,确保从后向前处理,避免替换位置偏移
  177. sorted_entities = sorted(entities, key=lambda e: e.get('start', 0), reverse=True)
  178. for entity in sorted_entities:
  179. start = entity.get('start', 0)
  180. end = entity.get('end', 0) + 1 # 因为end是 exclusive 的
  181. entity_text = text[start:end]
  182. label = entity.get('label', 'UNKNOWN')
  183. # 替换文本中的实体部分
  184. text = text[:start] + f"[{entity_text}]({label})" + text[end:]
  185. return text
  186. examples_str = "\n".join([f" - {format_sample_with_entities(sample)}" for sample in samples])
  187. intent_entry = {
  188. "intent": intent_name,
  189. "examples": examples_str
  190. }
  191. if intent_desc:
  192. intent_entry["intent"] = f"{intent_name} # {intent_desc}"
  193. nlu_data["nlu"].append(intent_entry)
  194. # 写入文件
  195. file_path = f"{self.output_dir}/nlu/nlu.yml"
  196. with open(file_path, "w", encoding="utf-8") as f:
  197. self._write_yaml(f, nlu_data)
  198. logger.info(f"生成nlu文件: {file_path}")
  199. return file_path
  200. except Exception as e:
  201. logger.error(f"生成nlu文件失败: {str(e)}")
  202. raise Exception(f"生成nlu文件失败: {str(e)}")
  203. def generate_stories_files(self, flow_data: dict, flow_name: str) -> list[str]:
  204. """生成stories文件"""
  205. try:
  206. files = []
  207. # 从flow_data中提取故事信息
  208. # flow_name = flow_data.get("flowName", "default_flow")
  209. nodes = flow_data.get("flowJson", {}).get("nodes", [])
  210. edges = flow_data.get("flowJson", {}).get("edges", [])
  211. # 构建节点ID到节点的映射
  212. node_map = {node.get("id"): node for node in nodes}
  213. # 构建边的映射 (source -> target)
  214. edge_map = {}
  215. for edge in edges:
  216. source = edge.get("sourceNodeId")
  217. target = edge.get("targetNodeId")
  218. if source not in edge_map:
  219. edge_map[source] = []
  220. edge_map[source].append(target)
  221. # 找到开始节点
  222. start_node = None
  223. for node in nodes:
  224. if node.get("type") == "start":
  225. start_node = node
  226. break
  227. if not start_node:
  228. logger.warning("未找到开始节点")
  229. return files
  230. # 构建故事步骤
  231. story_steps = []
  232. current_node = start_node
  233. while current_node:
  234. node_type = current_node.get("type")
  235. properties = current_node.get("properties", {})
  236. code = properties.get("code")
  237. if node_type == "intention" and code:
  238. story_steps.append({"intent": code})
  239. elif node_type == "action" and code:
  240. story_steps.append({"action": code})
  241. elif node_type == "collection" and code:
  242. story_steps.append({"action": f"form_{code}"})
  243. # 添加表单激活后的槽位填充
  244. story_steps.append({"active_loop": {"name": code}})
  245. story_steps.append({"active_loop": None})
  246. elif node_type == "form" and code:
  247. story_steps.append({"action": code})
  248. elif node_type == "condition" and code:
  249. # 简化处理条件节点
  250. story_steps.append({"action": code})
  251. # 这里应该根据条件分支生成不同的故事路径,但为简化起见,我们只取第一条路径
  252. # 找到下一个节点
  253. current_node_id = current_node.get("id")
  254. next_nodes = edge_map.get(current_node_id, [])
  255. if not next_nodes or node_type == "end":
  256. break
  257. # 简单处理,只取第一个下一个节点
  258. current_node = node_map.get(next_nodes[0])
  259. # 构建故事内容
  260. story_entry = {
  261. "story": flow_name,
  262. "steps": story_steps
  263. }
  264. # 生成合并的stories.yml
  265. all_stories = {
  266. "version": "3.1",
  267. "stories": [story_entry]
  268. }
  269. # 写入合并的stories.yml
  270. merged_file_path = f"{self.output_dir}/stories/stories.yml"
  271. with open(merged_file_path, "w", encoding="utf-8") as f:
  272. self._write_yaml(f, all_stories)
  273. files.append(merged_file_path)
  274. logger.info(f"生成stories文件: {len(files)} 个文件")
  275. return files
  276. except Exception as e:
  277. logger.error(f"生成stories文件失败: {str(e)}")
  278. raise Exception(f"生成stories文件失败: {str(e)}")
  279. def generate_actions_file(self, flow_data: dict) -> str:
  280. """生成自定义动作Python文件"""
  281. try:
  282. # 从flow_data中提取动作信息
  283. nodes = flow_data.get("flowJson", {}).get("nodes", [])
  284. actions = []
  285. for node in nodes:
  286. node_type = node.get("type")
  287. if node_type == "action" or node_type == "form" or node_type == "condition":
  288. properties = node.get("properties", {})
  289. if properties.get("code"):
  290. actions.append({
  291. "type": node_type,
  292. "properties": properties
  293. })
  294. if not actions:
  295. logger.warning(f"未找到任何自定义动作")
  296. # 生成actions.py内容
  297. code = [
  298. "from rasa_sdk import Action, Tracker",
  299. "from rasa_sdk.executor import CollectingDispatcher",
  300. "from rasa_sdk.events import SlotSet, ActiveLoop",
  301. "import requests",
  302. "import json\n"
  303. ]
  304. for action in actions:
  305. properties = action.get("properties", {})
  306. action_name = properties.get("code")
  307. action_type = action.get("type")
  308. if not action_name:
  309. continue
  310. class_name = f"Action{''.join(word.capitalize() for word in action_name.split('_'))}"
  311. # 生成类定义
  312. code.append(f"class {class_name}(Action):")
  313. code.append(f" def name(self) -> str:")
  314. code.append(f" return \"{action_name}\"\n")
  315. # 生成run方法
  316. code.append(f" def run(self, dispatcher: CollectingDispatcher, tracker: Tracker, domain: dict) -> list:")
  317. if action_type == "action":
  318. # 普通动作节点
  319. config_text = properties.get("configText", "")
  320. code.append(f" dispatcher.utter_message(text=\"{config_text}\")")
  321. code.append(" return []\n")
  322. elif action_type == "form":
  323. # 表单节点(API调用)
  324. code.append(" # 构建请求头")
  325. code.append(" headers = {}")
  326. # 添加自定义请求头
  327. headers = properties.get("headers", [])
  328. for header in headers:
  329. code.append(f" headers[\"{header.get('key')}\"] = \"{header.get('value')}\"")
  330. # 添加内容类型
  331. code.append(" headers[\"Content-Type\"] = \"application/json\"\n")
  332. # 构建请求参数
  333. params = properties.get("params", [])
  334. if params:
  335. code.append(" # 构建请求参数")
  336. code.append(" params = {}")
  337. for param in params:
  338. entity = param.get("entity")
  339. param_name = param.get("name")
  340. code.append(f" params[\"{param_name}\"] = tracker.get_slot(\"{entity}\")")
  341. code.append("")
  342. # 发送请求
  343. code.append(" # 发送请求")
  344. code.append(" try:")
  345. method = properties.get("requestMethod")
  346. url = properties.get("requestUrl")
  347. if method == "GET":
  348. code.append(f" response = requests.get(\"{url}\", headers=headers, params=params)")
  349. elif method == "POST":
  350. code.append(f" response = requests.post(\"{url}\", headers=headers, params=params)")
  351. elif method == "PUT":
  352. code.append(f" response = requests.put(\"{url}\", headers=headers, params=params)")
  353. elif method == "DELETE":
  354. code.append(f" response = requests.delete(\"{url}\", headers=headers, params=params)")
  355. else:
  356. code.append(f" dispatcher.utter_message(text=f\"不支持的HTTP方法: {method}\")")
  357. code.append(" return []")
  358. # 处理响应
  359. code.append(" if response.status_code == 200:")
  360. code.append(" result = response.json()")
  361. # 处理响应映射
  362. response_mappings = properties.get("responseMappings", [])
  363. if response_mappings:
  364. code.append(" # 处理响应映射")
  365. code.append(" slot_events = []")
  366. for mapping in response_mappings:
  367. field = mapping.get("responseField")
  368. target = mapping.get("targetVar")
  369. default = mapping.get("defaultValue")
  370. code.append(f" # 提取 {field}")
  371. code.append(f" value = {default}")
  372. code.append(f" try:")
  373. code.append(f" # 简化的路径解析")
  374. code.append(f" parts = \"{field}\".split('.')")
  375. code.append(f" temp = result")
  376. code.append(f" for part in parts:")
  377. code.append(f" if isinstance(temp, dict) and part in temp:")
  378. code.append(f" temp = temp[part]")
  379. code.append(f" elif isinstance(temp, list) and part.isdigit() and int(part) < len(temp):")
  380. code.append(f" temp = temp[int(part)]")
  381. code.append(f" else:")
  382. code.append(f" raise Exception(\"路径无效\")")
  383. code.append(f" value = temp")
  384. code.append(f" except:")
  385. code.append(f" pass")
  386. code.append(f" slot_events.append(SlotSet(\"{target}\", value))")
  387. code.append(" dispatcher.utter_message(text=str(result))")
  388. code.append(" return slot_events")
  389. else:
  390. code.append(" dispatcher.utter_message(text=str(result))")
  391. code.append(" else:")
  392. code.append(" dispatcher.utter_message(text=f\"API调用失败,状态码: {response.status_code}\")")
  393. code.append(" except Exception as e:")
  394. code.append(" dispatcher.utter_message(text=f\"调用API时发生错误: {str(e)}\")\n")
  395. code.append(" return []\n")
  396. elif action_type == "condition":
  397. # 条件节点
  398. code.append(" # 处理条件逻辑")
  399. code.append(" # 这里简化处理,实际应用中应根据条件执行不同的逻辑")
  400. code.append(" return []\n")
  401. # 写入文件
  402. file_path = f"{self.output_dir}/actions/actions.py"
  403. with open(file_path, "w", encoding="utf-8") as f:
  404. f.write("\n".join(code))
  405. logger.info(f"生成actions文件: {file_path}")
  406. return file_path
  407. except Exception as e:
  408. logger.error(f"生成actions文件失败: {str(e)}")
  409. raise Exception(f"生成actions文件失败: {str(e)}")
  410. def _write_yaml(self, file, data, indent: int = 0, reset: bool = False) -> None:
  411. """
  412. 简单的YAML写入函数
  413. Args:
  414. file: 文件对象
  415. data: 要写入的数据
  416. indent: 当前缩进
  417. """
  418. try:
  419. indent_str = ""
  420. if (indent != 0):
  421. indent_str = " " * indent
  422. if isinstance(data, dict):
  423. for key, value in data.items():
  424. if isinstance(value, (dict, list)):
  425. file.write(f"{indent_str}{key}:\n")
  426. self._write_yaml(file, value, indent + 1)
  427. else:
  428. if (key != 'examples'):
  429. file.write(f"{key}: {self._format_yaml_value(value)}\n")
  430. else:
  431. file.write(f"{indent_str}{key}: {self._format_yaml_value(value)}\n")
  432. elif isinstance(data, list):
  433. for item in data:
  434. if isinstance(item, (dict, list)):
  435. file.write(f"{indent_str}- ")
  436. self._write_yaml(file, item, indent + 1, True)
  437. else:
  438. file.write(f"{indent_str}- {self._format_yaml_value(item)}\n")
  439. else:
  440. file.write(f"{self._format_yaml_value(data)}\n")
  441. except Exception as e:
  442. raise Exception(f"写入YAML数据失败: {str(e)}")
  443. def _format_yaml_value(self, value) -> str:
  444. """格式化YAML值"""
  445. if isinstance(value, str) and (":" in value or "\n" in value):
  446. return f"|{chr(10)}{value}"
  447. elif isinstance(value, str):
  448. return f'{value}'
  449. elif isinstance(value, bool):
  450. return "true" if value else "false"
  451. elif value is None:
  452. return "null"
  453. else:
  454. return str(value)