rasa-pro 3.8.18__py3-none-any.whl → 3.9.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. More details are available from the package registry.

Files changed (278)
  1. README.md +6 -42
  2. rasa/__main__.py +14 -9
  3. rasa/anonymization/anonymization_pipeline.py +0 -1
  4. rasa/anonymization/anonymization_rule_executor.py +3 -3
  5. rasa/anonymization/utils.py +4 -3
  6. rasa/api.py +2 -2
  7. rasa/cli/arguments/default_arguments.py +1 -1
  8. rasa/cli/arguments/run.py +2 -2
  9. rasa/cli/arguments/test.py +1 -1
  10. rasa/cli/arguments/train.py +10 -10
  11. rasa/cli/e2e_test.py +27 -7
  12. rasa/cli/export.py +0 -1
  13. rasa/cli/license.py +3 -3
  14. rasa/cli/project_templates/calm/actions/action_template.py +1 -1
  15. rasa/cli/project_templates/calm/config.yml +1 -1
  16. rasa/cli/project_templates/calm/credentials.yml +1 -1
  17. rasa/cli/project_templates/calm/data/flows/add_contact.yml +1 -1
  18. rasa/cli/project_templates/calm/data/flows/remove_contact.yml +1 -1
  19. rasa/cli/project_templates/calm/domain/add_contact.yml +8 -2
  20. rasa/cli/project_templates/calm/domain/list_contacts.yml +3 -0
  21. rasa/cli/project_templates/calm/domain/remove_contact.yml +9 -2
  22. rasa/cli/project_templates/calm/domain/shared.yml +5 -0
  23. rasa/cli/project_templates/calm/endpoints.yml +4 -4
  24. rasa/cli/project_templates/default/actions/actions.py +1 -1
  25. rasa/cli/project_templates/default/config.yml +5 -5
  26. rasa/cli/project_templates/default/credentials.yml +1 -1
  27. rasa/cli/project_templates/default/endpoints.yml +4 -4
  28. rasa/cli/project_templates/default/tests/test_stories.yml +1 -1
  29. rasa/cli/project_templates/tutorial/config.yml +1 -1
  30. rasa/cli/project_templates/tutorial/credentials.yml +1 -1
  31. rasa/cli/project_templates/tutorial/data/patterns.yml +6 -0
  32. rasa/cli/project_templates/tutorial/domain.yml +4 -0
  33. rasa/cli/project_templates/tutorial/endpoints.yml +6 -6
  34. rasa/cli/run.py +0 -1
  35. rasa/cli/scaffold.py +3 -2
  36. rasa/cli/studio/download.py +11 -0
  37. rasa/cli/studio/studio.py +180 -24
  38. rasa/cli/studio/upload.py +0 -8
  39. rasa/cli/telemetry.py +18 -6
  40. rasa/cli/utils.py +21 -10
  41. rasa/cli/x.py +3 -2
  42. rasa/constants.py +1 -1
  43. rasa/core/actions/action.py +90 -315
  44. rasa/core/actions/action_exceptions.py +24 -0
  45. rasa/core/actions/constants.py +3 -0
  46. rasa/core/actions/custom_action_executor.py +188 -0
  47. rasa/core/actions/forms.py +11 -7
  48. rasa/core/actions/grpc_custom_action_executor.py +251 -0
  49. rasa/core/actions/http_custom_action_executor.py +140 -0
  50. rasa/core/actions/loops.py +3 -0
  51. rasa/core/actions/two_stage_fallback.py +1 -1
  52. rasa/core/agent.py +2 -4
  53. rasa/core/brokers/pika.py +1 -2
  54. rasa/core/channels/audiocodes.py +1 -1
  55. rasa/core/channels/botframework.py +0 -1
  56. rasa/core/channels/callback.py +0 -1
  57. rasa/core/channels/console.py +6 -8
  58. rasa/core/channels/development_inspector.py +1 -1
  59. rasa/core/channels/facebook.py +0 -3
  60. rasa/core/channels/hangouts.py +0 -6
  61. rasa/core/channels/inspector/dist/assets/{arc-5623b6dc.js → arc-b6e548fe.js} +1 -1
  62. rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-685c106a.js → c4Diagram-d0fbc5ce-fa03ac9e.js} +1 -1
  63. rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-8cbed007.js → classDiagram-936ed81e-ee67392a.js} +1 -1
  64. rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-5889cf12.js → classDiagram-v2-c3cb15f1-9b283fae.js} +1 -1
  65. rasa/core/channels/inspector/dist/assets/{createText-62fc7601-24c249d7.js → createText-62fc7601-8b6fcc2a.js} +1 -1
  66. rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-7dd06a75.js → edges-f2ad444c-22e77f4f.js} +1 -1
  67. rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-62c1e54c.js → erDiagram-9d236eb7-60ffc87f.js} +1 -1
  68. rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-ce49b86f.js → flowDb-1972c806-9dd802e4.js} +1 -1
  69. rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-4067e48f.js → flowDiagram-7ea5b25a-5fa1912f.js} +1 -1
  70. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +1 -0
  71. rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-59fe4051.js → flowchart-elk-definition-abe16c3d-622a1fd2.js} +1 -1
  72. rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-47e3a43b.js → ganttDiagram-9b5ea136-e285a63a.js} +1 -1
  73. rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-5a2ac0d9.js → gitGraphDiagram-99d0ae7c-f237bdca.js} +1 -1
  74. rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-dfb8efc4.js → index-2c4b9a3b-4b03d70e.js} +1 -1
  75. rasa/core/channels/inspector/dist/assets/{index-268a75c0.js → index-a5d3e69d.js} +4 -4
  76. rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-b0c470f2.js → infoDiagram-736b4530-72a0fa5f.js} +1 -1
  77. rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-2edb829a.js → journeyDiagram-df861f2b-82218c41.js} +1 -1
  78. rasa/core/channels/inspector/dist/assets/{layout-b6873d69.js → layout-78cff630.js} +1 -1
  79. rasa/core/channels/inspector/dist/assets/{line-1efc5781.js → line-5038b469.js} +1 -1
  80. rasa/core/channels/inspector/dist/assets/{linear-661e9b94.js → linear-c4fc4098.js} +1 -1
  81. rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-2d2e727f.js → mindmap-definition-beec6740-c33c8ea6.js} +1 -1
  82. rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-9d3ea93d.js → pieDiagram-dbbf0591-a8d03059.js} +1 -1
  83. rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-06a178a2.js → quadrantDiagram-4d7f4fd6-6a0e56b2.js} +1 -1
  84. rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-0bfedffc.js → requirementDiagram-6fc4c22a-2dc7c7bd.js} +1 -1
  85. rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-d76d0a04.js → sankeyDiagram-8f13d901-2360fe39.js} +1 -1
  86. rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-37bb4341.js → sequenceDiagram-b655622a-41b9f9ad.js} +1 -1
  87. rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-f52f7f57.js → stateDiagram-59f0c015-0aad326f.js} +1 -1
  88. rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-4a986a20.js → stateDiagram-v2-2b26beab-9847d984.js} +1 -1
  89. rasa/core/channels/inspector/dist/assets/{styles-080da4f6-7dd9ae12.js → styles-080da4f6-564d890e.js} +1 -1
  90. rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-46e1ca14.js → styles-3dcbcfbf-38957613.js} +1 -1
  91. rasa/core/channels/inspector/dist/assets/{styles-9c745c82-4a97439a.js → styles-9c745c82-f0fc6921.js} +1 -1
  92. rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-823917a3.js → svgDrawCommon-4835440b-ef3c5a77.js} +1 -1
  93. rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-9ea72896.js → timeline-definition-5b62e21b-bf3e91c1.js} +1 -1
  94. rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-b631a8b6.js → xychartDiagram-2b33534f-4d4026c0.js} +1 -1
  95. rasa/core/channels/inspector/dist/index.html +1 -1
  96. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +10 -0
  97. rasa/core/channels/inspector/src/helpers/formatters.test.ts +4 -7
  98. rasa/core/channels/inspector/src/helpers/formatters.ts +3 -2
  99. rasa/core/channels/rest.py +36 -21
  100. rasa/core/channels/rocketchat.py +0 -1
  101. rasa/core/channels/socketio.py +1 -1
  102. rasa/core/channels/telegram.py +3 -3
  103. rasa/core/channels/webexteams.py +0 -1
  104. rasa/core/concurrent_lock_store.py +1 -1
  105. rasa/core/evaluation/marker_base.py +1 -3
  106. rasa/core/evaluation/marker_stats.py +1 -2
  107. rasa/core/featurizers/single_state_featurizer.py +3 -26
  108. rasa/core/featurizers/tracker_featurizers.py +18 -122
  109. rasa/core/information_retrieval/__init__.py +7 -0
  110. rasa/core/information_retrieval/faiss.py +9 -4
  111. rasa/core/information_retrieval/information_retrieval.py +64 -7
  112. rasa/core/information_retrieval/milvus.py +7 -14
  113. rasa/core/information_retrieval/qdrant.py +8 -15
  114. rasa/core/lock_store.py +0 -1
  115. rasa/core/migrate.py +1 -2
  116. rasa/core/nlg/callback.py +3 -4
  117. rasa/core/policies/enterprise_search_policy.py +86 -22
  118. rasa/core/policies/enterprise_search_prompt_template.jinja2 +4 -41
  119. rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
  120. rasa/core/policies/flows/flow_executor.py +104 -2
  121. rasa/core/policies/intentless_policy.py +7 -9
  122. rasa/core/policies/memoization.py +3 -3
  123. rasa/core/policies/policy.py +18 -9
  124. rasa/core/policies/rule_policy.py +8 -11
  125. rasa/core/policies/ted_policy.py +61 -88
  126. rasa/core/policies/unexpected_intent_policy.py +8 -17
  127. rasa/core/processor.py +136 -47
  128. rasa/core/run.py +41 -25
  129. rasa/core/secrets_manager/endpoints.py +2 -2
  130. rasa/core/secrets_manager/vault.py +6 -8
  131. rasa/core/test.py +3 -5
  132. rasa/core/tracker_store.py +49 -14
  133. rasa/core/train.py +1 -3
  134. rasa/core/training/interactive.py +9 -6
  135. rasa/core/utils.py +5 -10
  136. rasa/dialogue_understanding/coexistence/intent_based_router.py +11 -4
  137. rasa/dialogue_understanding/coexistence/llm_based_router.py +2 -3
  138. rasa/dialogue_understanding/commands/__init__.py +4 -0
  139. rasa/dialogue_understanding/commands/can_not_handle_command.py +9 -0
  140. rasa/dialogue_understanding/commands/cancel_flow_command.py +9 -0
  141. rasa/dialogue_understanding/commands/change_flow_command.py +38 -0
  142. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +9 -0
  143. rasa/dialogue_understanding/commands/clarify_command.py +9 -0
  144. rasa/dialogue_understanding/commands/correct_slots_command.py +9 -0
  145. rasa/dialogue_understanding/commands/error_command.py +12 -0
  146. rasa/dialogue_understanding/commands/handle_code_change_command.py +9 -0
  147. rasa/dialogue_understanding/commands/human_handoff_command.py +9 -0
  148. rasa/dialogue_understanding/commands/knowledge_answer_command.py +9 -0
  149. rasa/dialogue_understanding/commands/noop_command.py +9 -0
  150. rasa/dialogue_understanding/commands/set_slot_command.py +34 -3
  151. rasa/dialogue_understanding/commands/skip_question_command.py +9 -0
  152. rasa/dialogue_understanding/commands/start_flow_command.py +9 -0
  153. rasa/dialogue_understanding/generator/__init__.py +16 -1
  154. rasa/dialogue_understanding/generator/command_generator.py +92 -6
  155. rasa/dialogue_understanding/generator/constants.py +18 -0
  156. rasa/dialogue_understanding/generator/flow_retrieval.py +7 -5
  157. rasa/dialogue_understanding/generator/llm_based_command_generator.py +467 -0
  158. rasa/dialogue_understanding/generator/llm_command_generator.py +39 -609
  159. rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
  160. rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
  161. rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
  162. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +827 -0
  163. rasa/dialogue_understanding/generator/nlu_command_adapter.py +69 -8
  164. rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
  165. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +345 -0
  166. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +44 -39
  167. rasa/dialogue_understanding/processor/command_processor.py +111 -3
  168. rasa/e2e_test/constants.py +1 -0
  169. rasa/e2e_test/e2e_test_case.py +44 -0
  170. rasa/e2e_test/e2e_test_runner.py +114 -11
  171. rasa/e2e_test/e2e_test_schema.yml +18 -0
  172. rasa/engine/caching.py +0 -1
  173. rasa/engine/graph.py +18 -6
  174. rasa/engine/recipes/config_files/default_config.yml +3 -3
  175. rasa/engine/recipes/default_components.py +1 -1
  176. rasa/engine/recipes/default_recipe.py +4 -5
  177. rasa/engine/recipes/recipe.py +1 -1
  178. rasa/engine/runner/dask.py +3 -9
  179. rasa/engine/storage/local_model_storage.py +0 -2
  180. rasa/engine/validation.py +179 -145
  181. rasa/exceptions.py +2 -2
  182. rasa/graph_components/validators/default_recipe_validator.py +3 -5
  183. rasa/hooks.py +0 -1
  184. rasa/model.py +1 -1
  185. rasa/model_training.py +1 -0
  186. rasa/nlu/classifiers/diet_classifier.py +33 -52
  187. rasa/nlu/classifiers/logistic_regression_classifier.py +9 -22
  188. rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
  189. rasa/nlu/extractors/crf_entity_extractor.py +54 -97
  190. rasa/nlu/extractors/duckling_entity_extractor.py +1 -1
  191. rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +1 -5
  192. rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +0 -4
  193. rasa/nlu/featurizers/featurizer.py +1 -1
  194. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +18 -49
  195. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +26 -64
  196. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
  197. rasa/nlu/persistor.py +68 -26
  198. rasa/nlu/selectors/response_selector.py +7 -10
  199. rasa/nlu/test.py +0 -3
  200. rasa/nlu/utils/hugging_face/registry.py +1 -1
  201. rasa/nlu/utils/spacy_utils.py +1 -3
  202. rasa/server.py +22 -7
  203. rasa/shared/constants.py +12 -1
  204. rasa/shared/core/command_payload_reader.py +109 -0
  205. rasa/shared/core/constants.py +4 -5
  206. rasa/shared/core/domain.py +57 -56
  207. rasa/shared/core/events.py +4 -7
  208. rasa/shared/core/flows/flow.py +9 -0
  209. rasa/shared/core/flows/flows_list.py +12 -0
  210. rasa/shared/core/flows/steps/action.py +7 -2
  211. rasa/shared/core/generator.py +12 -11
  212. rasa/shared/core/slot_mappings.py +315 -24
  213. rasa/shared/core/slots.py +4 -2
  214. rasa/shared/core/trackers.py +32 -14
  215. rasa/shared/core/training_data/loading.py +0 -1
  216. rasa/shared/core/training_data/story_reader/story_reader.py +3 -3
  217. rasa/shared/core/training_data/story_reader/yaml_story_reader.py +11 -11
  218. rasa/shared/core/training_data/story_writer/yaml_story_writer.py +5 -3
  219. rasa/shared/core/training_data/structures.py +1 -1
  220. rasa/shared/core/training_data/visualization.py +1 -1
  221. rasa/shared/data.py +58 -1
  222. rasa/shared/exceptions.py +36 -2
  223. rasa/shared/importers/importer.py +1 -2
  224. rasa/shared/importers/rasa.py +0 -1
  225. rasa/shared/nlu/constants.py +2 -0
  226. rasa/shared/nlu/training_data/entities_parser.py +1 -2
  227. rasa/shared/nlu/training_data/features.py +2 -120
  228. rasa/shared/nlu/training_data/formats/dialogflow.py +3 -2
  229. rasa/shared/nlu/training_data/formats/rasa_yaml.py +3 -5
  230. rasa/shared/nlu/training_data/formats/readerwriter.py +0 -1
  231. rasa/shared/nlu/training_data/message.py +13 -0
  232. rasa/shared/nlu/training_data/training_data.py +0 -2
  233. rasa/shared/providers/openai/session_handler.py +2 -2
  234. rasa/shared/utils/constants.py +3 -0
  235. rasa/shared/utils/io.py +11 -1
  236. rasa/shared/utils/llm.py +1 -2
  237. rasa/shared/utils/pykwalify_extensions.py +1 -0
  238. rasa/shared/utils/schemas/domain.yml +3 -0
  239. rasa/shared/utils/yaml.py +44 -35
  240. rasa/studio/auth.py +26 -10
  241. rasa/studio/constants.py +2 -0
  242. rasa/studio/data_handler.py +114 -107
  243. rasa/studio/download.py +160 -27
  244. rasa/studio/results_logger.py +137 -0
  245. rasa/studio/train.py +6 -7
  246. rasa/studio/upload.py +159 -134
  247. rasa/telemetry.py +188 -34
  248. rasa/tracing/config.py +18 -3
  249. rasa/tracing/constants.py +26 -2
  250. rasa/tracing/instrumentation/attribute_extractors.py +50 -41
  251. rasa/tracing/instrumentation/instrumentation.py +290 -44
  252. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +7 -5
  253. rasa/tracing/instrumentation/metrics.py +109 -21
  254. rasa/tracing/metric_instrument_provider.py +83 -3
  255. rasa/utils/cli.py +2 -1
  256. rasa/utils/common.py +1 -1
  257. rasa/utils/endpoints.py +1 -2
  258. rasa/utils/io.py +72 -6
  259. rasa/utils/licensing.py +246 -31
  260. rasa/utils/ml_utils.py +1 -1
  261. rasa/utils/tensorflow/data_generator.py +1 -1
  262. rasa/utils/tensorflow/environment.py +1 -1
  263. rasa/utils/tensorflow/model_data.py +201 -12
  264. rasa/utils/tensorflow/model_data_utils.py +499 -500
  265. rasa/utils/tensorflow/models.py +5 -6
  266. rasa/utils/tensorflow/rasa_layers.py +15 -15
  267. rasa/utils/train_utils.py +1 -1
  268. rasa/utils/url_tools.py +53 -0
  269. rasa/validator.py +305 -3
  270. rasa/version.py +1 -1
  271. {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.14.dist-info}/METADATA +25 -61
  272. {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.14.dist-info}/RECORD +276 -259
  273. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-85583a23.js +0 -1
  274. rasa/utils/tensorflow/feature_array.py +0 -370
  275. /rasa/dialogue_understanding/generator/{command_prompt_template.jinja2 → single_step/command_prompt_template.jinja2} +0 -0
  276. {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.14.dist-info}/NOTICE +0 -0
  277. {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.14.dist-info}/WHEEL +0 -0
  278. {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.14.dist-info}/entry_points.txt +0 -0
rasa/engine/validation.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import dataclasses
2
2
  import inspect
3
+ import re
3
4
  import logging
4
5
  import sys
5
6
  import typing
@@ -30,6 +31,9 @@ from rasa.dialogue_understanding.coexistence.constants import (
30
31
  STICKY,
31
32
  NON_STICKY,
32
33
  )
34
+ from rasa.dialogue_understanding.generator import (
35
+ LLMBasedCommandGenerator,
36
+ )
33
37
  from rasa.dialogue_understanding.patterns.chitchat import FLOW_PATTERN_CHITCHAT
34
38
  from rasa.engine.constants import RESERVED_PLACEHOLDERS
35
39
  from rasa.engine.exceptions import GraphSchemaValidationException
@@ -51,6 +55,10 @@ from rasa.shared.core.slots import Slot
51
55
  from rasa.shared.exceptions import RasaException
52
56
  from rasa.shared.nlu.training_data.message import Message
53
57
 
58
+ from rasa.dialogue_understanding.coexistence.intent_based_router import (
59
+ IntentBasedRouter,
60
+ )
61
+ from rasa.dialogue_understanding.coexistence.llm_based_router import LLMBasedRouter
54
62
 
55
63
  TypeAnnotation = Union[TypeVar, Text, Type, Optional[AvailableEndpoints]]
56
64
 
@@ -209,7 +217,7 @@ def _validate_interface_usage(node: SchemaNode) -> None:
209
217
  raise GraphSchemaValidationException(
210
218
  f"Your model uses a component with class '{node.uses.__name__}'. "
211
219
  f"This class does not implement the '{GraphComponent.__name__}' interface "
212
- f"and can hence not be run within Rasa Open Source. Please use a different "
220
+ f"and can hence not be run within Rasa Pro. Please use a different "
213
221
  f"component or implement the '{GraphComponent}' interface in class "
214
222
  f"'{node.uses.__name__}'. "
215
223
  f"See {DOCS_URL_GRAPH_COMPONENTS} for more information."
@@ -503,7 +511,6 @@ def _validate_parent_return_type(
503
511
  parent_return_type: TypeAnnotation,
504
512
  required_type: TypeAnnotation,
505
513
  ) -> None:
506
-
507
514
  if not typing_utils.issubtype(parent_return_type, required_type):
508
515
  parent_node_text = ""
509
516
  if parent_node:
@@ -606,7 +613,6 @@ def _recursively_check_required_components(
606
613
  def validate_flow_component_dependencies(
607
614
  flows: FlowsList, model_configuration: GraphModelConfiguration
608
615
  ) -> None:
609
-
610
616
  if (pattern_chitchat := flows.flow_by_id(FLOW_PATTERN_CHITCHAT)) is not None:
611
617
  _validate_chitchat_dependencies(pattern_chitchat, model_configuration)
612
618
 
@@ -637,166 +643,169 @@ def _validate_chitchat_dependencies(
637
643
  )
638
644
 
639
645
 
640
- def validate_coexistance_routing_setup(
641
- domain: Domain, model_configuration: GraphModelConfiguration, flows: FlowsList
646
+ def get_component_index(schema: GraphSchema, component_class: Type) -> Optional[int]:
647
+ """Extracts the index of a component of the given class in the schema.
648
+ This function assumes that each component's node name is stored in a way
649
+ that includes the index as part of the name, formatted as
650
+ "run_{ComponentName}{Index}", which is how it's created by the recipe.
651
+ """
652
+ # the index of the component is at the end of the node name
653
+ pattern = re.compile(r"\d+$")
654
+ for node_name, node in schema.nodes.items():
655
+ if issubclass(node.uses, component_class):
656
+ match = pattern.search(node_name)
657
+ if match:
658
+ index = int(match.group())
659
+ return index
660
+ # index is not found or there is no component with the given class
661
+ return None
662
+
663
+
664
+ def get_component_config(
665
+ schema: GraphSchema, component_class: Type
666
+ ) -> Optional[Dict[str, Any]]:
667
+ """Extracts the config of a component of the given class in the schema."""
668
+ for node_name, node in schema.nodes.items():
669
+ if issubclass(node.uses, component_class):
670
+ return node.config
671
+ return None
672
+
673
+
674
+ def validate_router_exclusivity(schema: GraphSchema) -> None:
675
+ """Validate that intent-based and llm-based routers are not
676
+ defined at the same time.
677
+ """
678
+ if schema.has_node(IntentBasedRouter) and schema.has_node(LLMBasedRouter):
679
+ structlogger.error(
680
+ "validation.coexistance.both_routers_defined",
681
+ event_info=(
682
+ "Both LLMBasedRouter and IntentBasedRouter are in the config. "
683
+ "Please use only one of them."
684
+ ),
685
+ )
686
+ sys.exit(1)
687
+
688
+
689
+ def validate_intent_based_router_position(schema: GraphSchema) -> None:
690
+ """Validate that if intent-based router is defined, it is positioned before
691
+ the llm command generator.
692
+ """
693
+ intent_based_router_pos = get_component_index(schema, IntentBasedRouter)
694
+ llm_command_generator_pos = get_component_index(schema, LLMBasedCommandGenerator)
695
+ if (
696
+ intent_based_router_pos is not None
697
+ and llm_command_generator_pos is not None
698
+ and intent_based_router_pos > llm_command_generator_pos
699
+ ):
700
+ structlogger.error(
701
+ "validation.coexistance.wrong_order_of_components",
702
+ event_info=(
703
+ "IntentBasedRouter should come before "
704
+ "a LLMBasedCommandGenerator in the pipeline."
705
+ ),
706
+ )
707
+ sys.exit(1)
708
+
709
+
710
+ def validate_that_slots_are_defined_if_router_is_defined(
711
+ schema: GraphSchema, routing_slots: List[Slot]
642
712
  ) -> None:
643
- import re
644
- from rasa.dialogue_understanding.coexistence.intent_based_router import (
645
- IntentBasedRouter,
646
- )
647
- from rasa.dialogue_understanding.coexistence.llm_based_router import LLMBasedRouter
648
- from rasa.dialogue_understanding.generator import LLMCommandGenerator
649
-
650
- def get_component_index(
651
- schema: GraphSchema, component_class: Type
652
- ) -> Optional[int]:
653
- """Extracts the index of a component of the given class in the schema.
654
- This function assumes that each component's node name is stored in a way
655
- that includes the index as part of the name, formatted as
656
- "run_{ComponentName}{Index}", which is how it's created by the recipe.
657
- """
658
- # the index of the component is at the end of the node name
659
- pattern = re.compile(r"\d+$")
660
- for node_name, node in schema.nodes.items():
661
- if issubclass(node.uses, component_class):
662
- match = pattern.search(node_name)
663
- if match:
664
- index = int(match.group())
665
- return index
666
- # index is not found or there is no component with the given class
667
- return None
668
-
669
- def get_component_config(
670
- schema: GraphSchema, component_class: Type
671
- ) -> Optional[Dict[str, Any]]:
672
- """Extracts the config of a component of the given class in the schema."""
673
- for node_name, node in schema.nodes.items():
674
- if issubclass(node.uses, component_class):
675
- return node.config
676
- return None
677
-
678
- def validate_router_exclusivity(schema: GraphSchema) -> None:
679
- """Validate that intent-based and llm-based routers are not
680
- defined at the same time.
681
- """
682
- if schema.has_node(IntentBasedRouter) and schema.has_node(LLMBasedRouter):
713
+ # check whether intent-based or llm-based type of router is present
714
+ for router_type in [IntentBasedRouter, LLMBasedRouter]:
715
+ router_present = schema.has_node(router_type)
716
+ slot_has_issue = len(routing_slots) == 0 or routing_slots[0].type_name != "bool"
717
+ if router_present and slot_has_issue:
683
718
  structlogger.error(
684
- "validation.coexistance.both_routers_defined",
719
+ f"validation.coexistance.{ROUTE_TO_CALM_SLOT}_not_in_domain",
685
720
  event_info=(
686
- "Both LLMBasedRouter and IntentBasedRouter are in the config. "
687
- "Please use only one of them."
721
+ f"{router_type.__name__} is in the config, but the slot "
722
+ f"{ROUTE_TO_CALM_SLOT} is not in the domain or not of "
723
+ f"type bool."
688
724
  ),
689
725
  )
690
726
  sys.exit(1)
691
727
 
692
- def validate_intent_based_router_position(schema: GraphSchema) -> None:
693
- """Validate that if intent-based router is defined, it is positioned before
694
- the llm command generator.
695
- """
696
- intent_based_router_pos = get_component_index(schema, IntentBasedRouter)
697
- llm_command_generator_pos = get_component_index(schema, LLMCommandGenerator)
698
- if (
699
- intent_based_router_pos is not None
700
- and llm_command_generator_pos is not None
701
- and intent_based_router_pos > llm_command_generator_pos
702
- ):
728
+
729
+ def validate_that_router_is_defined_if_router_slots_are_in_domain(
730
+ schema: GraphSchema,
731
+ routing_slots: List[Slot],
732
+ ) -> None:
733
+ defined_router_slots = len(routing_slots) > 0
734
+ router_present = schema.has_node(IntentBasedRouter) or schema.has_node(
735
+ LLMBasedRouter
736
+ )
737
+ if defined_router_slots and (
738
+ not router_present or routing_slots[0].type_name != "bool"
739
+ ):
740
+ structlogger.error(
741
+ f"validation.coexistance"
742
+ f".{ROUTE_TO_CALM_SLOT}_in_domain_with_no_router_defined",
743
+ event_info=(
744
+ f"The slot {ROUTE_TO_CALM_SLOT} is in the domain but the "
745
+ f"LLMBasedRouter or the IntentBasedRouter is not in the config or "
746
+ f"the type of the slot is not bool."
747
+ ),
748
+ )
749
+ sys.exit(1)
750
+
751
+
752
+ def valid_nlu_entry_config(config: Optional[Dict[str, Any]]) -> bool:
753
+ return (
754
+ config is not None
755
+ and NLU_ENTRY in config
756
+ and isinstance(config[NLU_ENTRY], dict)
757
+ and STICKY in config[NLU_ENTRY]
758
+ and NON_STICKY in config[NLU_ENTRY]
759
+ )
760
+
761
+
762
+ def valid_calm_entry_config(config: Optional[Dict[str, Any]]) -> bool:
763
+ return (
764
+ config is not None
765
+ and CALM_ENTRY in config
766
+ and isinstance(config[CALM_ENTRY], dict)
767
+ and STICKY in config[CALM_ENTRY]
768
+ )
769
+
770
+
771
+ def validate_configuration(
772
+ schema: GraphSchema,
773
+ ) -> None:
774
+ """Validate the configuration of the existing coexistence routers."""
775
+ if schema.has_node(IntentBasedRouter, include_subtypes=False):
776
+ config = get_component_config(schema, IntentBasedRouter)
777
+ if not valid_calm_entry_config(config) or not valid_nlu_entry_config(config):
703
778
  structlogger.error(
704
- "validation.coexistance.wrong_order_of_components",
779
+ "validation.coexistance.invalid_configuration",
705
780
  event_info=(
706
- "IntentBasedRouter should come before LLMCommandGenerator "
707
- "in the pipeline."
781
+ "The configuration of the IntentBasedRouter is invalid. "
782
+ "Please check the documentation.",
708
783
  ),
709
784
  )
710
785
  sys.exit(1)
711
786
 
712
- def validate_that_slots_are_defined_if_router_is_defined(
713
- schema: GraphSchema, routing_slots: List[Slot]
714
- ) -> None:
715
- # check whether intent-based or llm-based type of router is present
716
- for router_type in [IntentBasedRouter, LLMBasedRouter]:
717
- router_present = schema.has_node(router_type)
718
- slot_has_issue = (
719
- len(routing_slots) == 0 or routing_slots[0].type_name != "bool"
720
- )
721
- if router_present and slot_has_issue:
722
- structlogger.error(
723
- f"validation.coexistance.{ROUTE_TO_CALM_SLOT}_not_in_domain",
724
- event_info=(
725
- f"{router_type.__name__} is in the config, but the slot "
726
- f"{ROUTE_TO_CALM_SLOT} is not in the domain or not of "
727
- f"type bool."
728
- ),
729
- )
730
- sys.exit(1)
731
-
732
- def validate_that_router_is_defined_if_router_slots_are_in_domain(
733
- schema: GraphSchema,
734
- routing_slots: List[Slot],
735
- ) -> None:
736
- defined_router_slots = len(routing_slots) > 0
737
- router_present = schema.has_node(IntentBasedRouter) or schema.has_node(
738
- LLMBasedRouter
739
- )
740
- if defined_router_slots and (
741
- not router_present or routing_slots[0].type_name != "bool"
787
+ if schema.has_node(LLMBasedRouter, include_subtypes=False):
788
+ config = get_component_config(schema, LLMBasedRouter)
789
+ if not valid_calm_entry_config(config) or (
790
+ config is not None
791
+ and NLU_ENTRY in config
792
+ and not valid_nlu_entry_config(config)
742
793
  ):
743
794
  structlogger.error(
744
- f"validation.coexistance"
745
- f".{ROUTE_TO_CALM_SLOT}_in_domain_with_no_router_defined",
795
+ "validation.coexistance.invalid_configuration",
746
796
  event_info=(
747
- f"The slot {ROUTE_TO_CALM_SLOT} is in the domain but the "
748
- f"LLMBasedRouter or the IntentBasedRouter is not in the config or "
749
- f"the type of the slot is not bool."
797
+ "The configuration of the LLMBasedRouter is invalid. "
798
+ "Please check the documentation.",
750
799
  ),
751
800
  )
752
801
  sys.exit(1)
753
802
 
754
- def valid_nlu_entry_config(config: Optional[Dict[str, Any]]) -> bool:
755
- return (
756
- config is not None
757
- and NLU_ENTRY in config
758
- and isinstance(config[NLU_ENTRY], dict)
759
- and STICKY in config[NLU_ENTRY]
760
- and NON_STICKY in config[NLU_ENTRY]
761
- )
762
803
 
763
- def valid_calm_entry_config(config: Optional[Dict[str, Any]]) -> bool:
764
- return (
765
- config is not None
766
- and CALM_ENTRY in config
767
- and isinstance(config[CALM_ENTRY], dict)
768
- and STICKY in config[CALM_ENTRY]
769
- )
770
-
771
- def validate_configuration(
772
- schema: GraphSchema,
773
- ) -> None:
774
- """Validate the configuration of the existing coexistence routers."""
775
- if schema.has_node(IntentBasedRouter, include_subtypes=False):
776
- config = get_component_config(schema, IntentBasedRouter)
777
- if not valid_calm_entry_config(config) or not valid_nlu_entry_config(
778
- config
779
- ):
780
- structlogger.error(
781
- "validation.coexistance.invalid_configuration",
782
- event_info=(
783
- "The configuration of the IntentBasedRouter is invalid. "
784
- "Please check the documentation.",
785
- ),
786
- )
787
- sys.exit(1)
788
-
789
- if schema.has_node(LLMBasedRouter, include_subtypes=False):
790
- config = get_component_config(schema, LLMBasedRouter)
791
- if not valid_calm_entry_config(config):
792
- structlogger.error(
793
- "validation.coexistance.invalid_configuration",
794
- event_info=(
795
- "The configuration of the LLMBasedRouter is invalid. "
796
- "Please check the documentation.",
797
- ),
798
- )
799
- sys.exit(1)
804
+ def validate_coexistance_routing_setup(
805
+ domain: Domain, model_configuration: GraphModelConfiguration, flows: FlowsList
806
+ ) -> None:
807
+ schema = model_configuration.predict_schema
808
+ routing_slots = [s for s in domain.slots if s.name == ROUTE_TO_CALM_SLOT]
800
809
 
801
810
  def validate_that_router_or_router_slot_are_defined_if_action_reset_routing_is_used(
802
811
  schema: GraphSchema, flows: FlowsList, routing_slots: List[Slot]
@@ -826,9 +835,6 @@ def validate_coexistance_routing_setup(
826
835
  )
827
836
  sys.exit(1)
828
837
 
829
- schema = model_configuration.predict_schema
830
- routing_slots = [s for s in domain.slots if s.name == ROUTE_TO_CALM_SLOT]
831
-
832
838
  validate_router_exclusivity(schema)
833
839
  validate_intent_based_router_position(schema)
834
840
  validate_that_slots_are_defined_if_router_is_defined(schema, routing_slots)
@@ -837,3 +843,31 @@ def validate_coexistance_routing_setup(
837
843
  validate_that_router_or_router_slot_are_defined_if_action_reset_routing_is_used(
838
844
  schema, flows, routing_slots
839
845
  )
846
+
847
+
848
+ def validate_command_generator_exclusivity(schema: GraphSchema) -> None:
849
+ """Validate that multiple command generators are not defined at same time."""
850
+ from rasa.dialogue_understanding.generator import (
851
+ LLMBasedCommandGenerator,
852
+ )
853
+
854
+ count = schema.count_nodes_of_a_given_type(
855
+ LLMBasedCommandGenerator, include_subtypes=True
856
+ )
857
+
858
+ if count > 1:
859
+ structlogger.error(
860
+ "validation.command_generator.multiple_command_generator_defined",
861
+ event_info=(
862
+ "Multiple LLM based command generators are defined in the config. "
863
+ "Please use only one LLM based command generator."
864
+ ),
865
+ )
866
+ sys.exit(1)
867
+
868
+
869
+ def validate_command_generator_setup(
870
+ model_configuration: GraphModelConfiguration,
871
+ ) -> None:
872
+ schema = model_configuration.predict_schema
873
+ validate_command_generator_exclusivity(schema)
rasa/exceptions.py CHANGED
@@ -20,9 +20,9 @@ class UnsupportedModelVersionError(RasaException):
20
20
  def __str__(self) -> Text:
21
21
  minimum_version = version.parse(MINIMUM_COMPATIBLE_VERSION)
22
22
  return (
23
- f"The model version is trained using Rasa Open Source {self.model_version} "
23
+ f"The model version is trained using Rasa Pro {self.model_version} "
24
24
  f"and is not compatible with your current installation "
25
- f"which supports models build with Rasa Open Source {minimum_version} "
25
+ f"which supports models build with Rasa Pro {minimum_version} "
26
26
  f"or higher. "
27
27
  f"This means that you either need to retrain your model "
28
28
  f"or revert back to the Rasa version that trained the model "
@@ -203,7 +203,6 @@ class DefaultV1RecipeValidator(GraphComponent):
203
203
  )
204
204
 
205
205
  if training_data.lookup_tables:
206
-
207
206
  if self._component_types.isdisjoint([CRFEntityExtractor, DIETClassifier]):
208
207
  rasa.shared.utils.io.raise_warning(
209
208
  f"You have defined training data consisting of lookup tables, but "
@@ -219,7 +218,6 @@ class DefaultV1RecipeValidator(GraphComponent):
219
218
  )
220
219
 
221
220
  elif CRFEntityExtractor in self._component_types:
222
-
223
221
  crf_schema_nodes = [
224
222
  schema_node
225
223
  for schema_node in self._graph_schema.nodes.values()
@@ -295,9 +293,9 @@ class DefaultV1RecipeValidator(GraphComponent):
295
293
  Both of these look for the same entities based on the same training data
296
294
  leading to ambiguity in the results.
297
295
  """
298
- extractors_in_configuration: Set[
299
- Type[GraphComponent]
300
- ] = self._component_types.intersection(TRAINABLE_EXTRACTORS)
296
+ extractors_in_configuration: Set[Type[GraphComponent]] = (
297
+ self._component_types.intersection(TRAINABLE_EXTRACTORS)
298
+ )
301
299
  if len(extractors_in_configuration) > 1:
302
300
  rasa.shared.utils.io.raise_warning(
303
301
  f"You have defined multiple entity extractors that do the same job "
rasa/hooks.py CHANGED
@@ -78,7 +78,6 @@ def create_tracker_store(
78
78
  domain: "Domain",
79
79
  event_broker: Optional["EventBroker"],
80
80
  ) -> "TrackerStore":
81
-
82
81
  if isinstance(endpoint_config, EndpointConfig):
83
82
  return AuthRetryTrackerStore(
84
83
  endpoint_config=endpoint_config, domain=domain, event_broker=event_broker
rasa/model.py CHANGED
@@ -74,7 +74,7 @@ def get_latest_model(model_path: Text = DEFAULT_MODELS_PATH) -> Optional[Text]:
74
74
 
75
75
 
76
76
  def get_model_for_finetuning(
77
- previous_model_file_or_dir: Union[Path, Text]
77
+ previous_model_file_or_dir: Union[Path, Text],
78
78
  ) -> Optional[Path]:
79
79
  """Gets validated path for model to finetune.
80
80
 
rasa/model_training.py CHANGED
@@ -309,6 +309,7 @@ async def _train_graph(
309
309
  rasa.engine.validation.validate_flow_component_dependencies(
310
310
  flows, model_configuration
311
311
  )
312
+ rasa.engine.validation.validate_command_generator_setup(model_configuration)
312
313
 
313
314
  tempdir_name = rasa.utils.common.get_temp_dir_name()
314
315
  # Use `TempDirectoryPath` instead of `tempfile.TemporaryDirectory` as this
@@ -1,17 +1,18 @@
1
1
  from __future__ import annotations
2
-
3
2
  import copy
4
3
  import logging
5
4
  from collections import defaultdict
6
5
  from pathlib import Path
7
- from typing import Any, Dict, List, Optional, Text, Tuple, Union, TypeVar, Type
6
+
7
+ from rasa.exceptions import ModelNotFound
8
+ from rasa.nlu.featurizers.featurizer import Featurizer
8
9
 
9
10
  import numpy as np
10
11
  import scipy.sparse
11
12
  import tensorflow as tf
12
13
 
13
- from rasa.exceptions import ModelNotFound
14
- from rasa.nlu.featurizers.featurizer import Featurizer
14
+ from typing import Any, Dict, List, Optional, Text, Tuple, Union, TypeVar, Type
15
+
15
16
  from rasa.engine.graph import ExecutionContext, GraphComponent
16
17
  from rasa.engine.recipes.default_recipe import DefaultV1Recipe
17
18
  from rasa.engine.storage.resource import Resource
@@ -19,21 +20,18 @@ from rasa.engine.storage.storage import ModelStorage
19
20
  from rasa.nlu.extractors.extractor import EntityExtractorMixin
20
21
  from rasa.nlu.classifiers.classifier import IntentClassifier
21
22
  import rasa.shared.utils.io
23
+ import rasa.utils.io as io_utils
22
24
  import rasa.nlu.utils.bilou_utils as bilou_utils
23
25
  from rasa.shared.constants import DIAGNOSTIC_DATA
24
26
  from rasa.nlu.extractors.extractor import EntityTagSpec
25
27
  from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
26
28
  from rasa.utils import train_utils
27
29
  from rasa.utils.tensorflow import rasa_layers
28
- from rasa.utils.tensorflow.feature_array import (
29
- FeatureArray,
30
- serialize_nested_feature_arrays,
31
- deserialize_nested_feature_arrays,
32
- )
33
30
  from rasa.utils.tensorflow.models import RasaModel, TransformerRasaModel
34
31
  from rasa.utils.tensorflow.model_data import (
35
32
  RasaModelData,
36
33
  FeatureSignature,
34
+ FeatureArray,
37
35
  )
38
36
  from rasa.nlu.constants import TOKENS_NAMES, DEFAULT_TRANSFORMER_SIZE
39
37
  from rasa.shared.nlu.constants import (
@@ -120,6 +118,7 @@ LABEL_SUB_KEY = IDS
120
118
 
121
119
  POSSIBLE_TAGS = [ENTITY_ATTRIBUTE_TYPE, ENTITY_ATTRIBUTE_ROLE, ENTITY_ATTRIBUTE_GROUP]
122
120
 
121
+
123
122
  DIETClassifierT = TypeVar("DIETClassifierT", bound="DIETClassifier")
124
123
 
125
124
 
@@ -511,7 +510,6 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
511
510
  def _extract_features(
512
511
  self, message: Message, attribute: Text
513
512
  ) -> Dict[Text, Union[scipy.sparse.spmatrix, np.ndarray]]:
514
-
515
513
  (
516
514
  sparse_sequence_features,
517
515
  sparse_sentence_features,
@@ -781,7 +779,6 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
781
779
  sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]],
782
780
  label_attribute: Optional[Text] = None,
783
781
  ) -> Dict[Text, Dict[Text, List[int]]]:
784
-
785
782
  if label_attribute in sparse_feature_sizes:
786
783
  del sparse_feature_sizes[label_attribute]
787
784
  return sparse_feature_sizes
@@ -1086,24 +1083,18 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
1086
1083
 
1087
1084
  self.model.save(str(tf_model_file))
1088
1085
 
1089
- # save data example
1090
- serialize_nested_feature_arrays(
1091
- self._data_example,
1092
- model_path / f"{file_name}.data_example.st",
1093
- model_path / f"{file_name}.data_example_metadata.json",
1086
+ io_utils.pickle_dump(
1087
+ model_path / f"{file_name}.data_example.pkl", self._data_example
1094
1088
  )
1095
- # save label data
1096
- serialize_nested_feature_arrays(
1097
- dict(self._label_data.data) if self._label_data is not None else {},
1098
- model_path / f"{file_name}.label_data.st",
1099
- model_path / f"{file_name}.label_data_metadata.json",
1100
- )
1101
-
1102
- rasa.shared.utils.io.dump_obj_as_json_to_file(
1103
- model_path / f"{file_name}.sparse_feature_sizes.json",
1089
+ io_utils.pickle_dump(
1090
+ model_path / f"{file_name}.sparse_feature_sizes.pkl",
1104
1091
  self._sparse_feature_sizes,
1105
1092
  )
1106
- rasa.shared.utils.io.dump_obj_as_json_to_file(
1093
+ io_utils.pickle_dump(
1094
+ model_path / f"{file_name}.label_data.pkl",
1095
+ dict(self._label_data.data) if self._label_data is not None else {},
1096
+ )
1097
+ io_utils.json_pickle(
1107
1098
  model_path / f"{file_name}.index_label_id_mapping.json",
1108
1099
  self.index_label_id_mapping,
1109
1100
  )
@@ -1192,22 +1183,15 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
1192
1183
  ]:
1193
1184
  file_name = cls.__name__
1194
1185
 
1195
- # load data example
1196
- data_example = deserialize_nested_feature_arrays(
1197
- str(model_path / f"{file_name}.data_example.st"),
1198
- str(model_path / f"{file_name}.data_example_metadata.json"),
1186
+ data_example = io_utils.pickle_load(
1187
+ model_path / f"{file_name}.data_example.pkl"
1199
1188
  )
1200
- # load label data
1201
- loaded_label_data = deserialize_nested_feature_arrays(
1202
- str(model_path / f"{file_name}.label_data.st"),
1203
- str(model_path / f"{file_name}.label_data_metadata.json"),
1204
- )
1205
- label_data = RasaModelData(data=loaded_label_data)
1206
-
1207
- sparse_feature_sizes = rasa.shared.utils.io.read_json_file(
1208
- model_path / f"{file_name}.sparse_feature_sizes.json"
1189
+ label_data = io_utils.pickle_load(model_path / f"{file_name}.label_data.pkl")
1190
+ label_data = RasaModelData(data=label_data)
1191
+ sparse_feature_sizes = io_utils.pickle_load(
1192
+ model_path / f"{file_name}.sparse_feature_sizes.pkl"
1209
1193
  )
1210
- index_label_id_mapping = rasa.shared.utils.io.read_json_file(
1194
+ index_label_id_mapping = io_utils.json_unpickle(
1211
1195
  model_path / f"{file_name}.index_label_id_mapping.json"
1212
1196
  )
1213
1197
  entity_tag_specs = rasa.shared.utils.io.read_json_file(
@@ -1227,6 +1211,7 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
1227
1211
  for tag_spec in entity_tag_specs
1228
1212
  ]
1229
1213
 
1214
+ # jsonpickle converts dictionary keys to strings
1230
1215
  index_label_id_mapping = {
1231
1216
  int(key): value for key, value in index_label_id_mapping.items()
1232
1217
  }
@@ -1280,7 +1265,6 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
1280
1265
  config: Dict[Text, Any],
1281
1266
  finetune_mode: bool,
1282
1267
  ) -> "RasaModel":
1283
-
1284
1268
  predict_data_example = RasaModelData(
1285
1269
  label_key=model_data_example.label_key,
1286
1270
  data={
@@ -1467,10 +1451,10 @@ class DIET(TransformerRasaModel):
1467
1451
  # everything using a transformer and optionally also do masked language
1468
1452
  # modeling.
1469
1453
  self.text_name = TEXT
1470
- self._tf_layers[
1471
- f"sequence_layer.{self.text_name}"
1472
- ] = rasa_layers.RasaSequenceLayer(
1473
- self.text_name, self.data_signature[self.text_name], self.config
1454
+ self._tf_layers[f"sequence_layer.{self.text_name}"] = (
1455
+ rasa_layers.RasaSequenceLayer(
1456
+ self.text_name, self.data_signature[self.text_name], self.config
1457
+ )
1474
1458
  )
1475
1459
  if self.config[MASKED_LM]:
1476
1460
  self._prepare_mask_lm_loss(self.text_name)
@@ -1488,10 +1472,10 @@ class DIET(TransformerRasaModel):
1488
1472
  {SPARSE_INPUT_DROPOUT: False, DENSE_INPUT_DROPOUT: False}
1489
1473
  )
1490
1474
 
1491
- self._tf_layers[
1492
- f"feature_combining_layer.{self.label_name}"
1493
- ] = rasa_layers.RasaFeatureCombiningLayer(
1494
- self.label_name, self.label_signature[self.label_name], label_config
1475
+ self._tf_layers[f"feature_combining_layer.{self.label_name}"] = (
1476
+ rasa_layers.RasaFeatureCombiningLayer(
1477
+ self.label_name, self.label_signature[self.label_name], label_config
1478
+ )
1495
1479
  )
1496
1480
 
1497
1481
  self._prepare_ffnn_layer(
@@ -1523,7 +1507,6 @@ class DIET(TransformerRasaModel):
1523
1507
  sequence_feature_lengths: tf.Tensor,
1524
1508
  name: Text,
1525
1509
  ) -> tf.Tensor:
1526
-
1527
1510
  x, _ = self._tf_layers[f"feature_combining_layer.{name}"](
1528
1511
  (sequence_features, sentence_features, sequence_feature_lengths),
1529
1512
  training=self._training,
@@ -1705,7 +1688,6 @@ class DIET(TransformerRasaModel):
1705
1688
  return loss
1706
1689
 
1707
1690
  def _update_label_metrics(self, loss: tf.Tensor, acc: tf.Tensor) -> None:
1708
-
1709
1691
  self.intent_loss.update_state(loss)
1710
1692
  self.intent_acc.update_state(acc)
1711
1693
 
@@ -1864,7 +1846,6 @@ class DIET(TransformerRasaModel):
1864
1846
  combined_sequence_sentence_feature_lengths: tf.Tensor,
1865
1847
  text_transformed: tf.Tensor,
1866
1848
  ) -> Dict[Text, tf.Tensor]:
1867
-
1868
1849
  if self.all_labels_embed is None:
1869
1850
  raise ValueError(
1870
1851
  "The model was not prepared for prediction. "