rasa-pro 3.8.18__py3-none-any.whl → 3.9.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (278)
  1. README.md +6 -42
  2. rasa/__main__.py +14 -9
  3. rasa/anonymization/anonymization_pipeline.py +0 -1
  4. rasa/anonymization/anonymization_rule_executor.py +3 -3
  5. rasa/anonymization/utils.py +4 -3
  6. rasa/api.py +2 -2
  7. rasa/cli/arguments/default_arguments.py +1 -1
  8. rasa/cli/arguments/run.py +2 -2
  9. rasa/cli/arguments/test.py +1 -1
  10. rasa/cli/arguments/train.py +10 -10
  11. rasa/cli/e2e_test.py +27 -7
  12. rasa/cli/export.py +0 -1
  13. rasa/cli/license.py +3 -3
  14. rasa/cli/project_templates/calm/actions/action_template.py +1 -1
  15. rasa/cli/project_templates/calm/config.yml +1 -1
  16. rasa/cli/project_templates/calm/credentials.yml +1 -1
  17. rasa/cli/project_templates/calm/data/flows/add_contact.yml +1 -1
  18. rasa/cli/project_templates/calm/data/flows/remove_contact.yml +1 -1
  19. rasa/cli/project_templates/calm/domain/add_contact.yml +8 -2
  20. rasa/cli/project_templates/calm/domain/list_contacts.yml +3 -0
  21. rasa/cli/project_templates/calm/domain/remove_contact.yml +9 -2
  22. rasa/cli/project_templates/calm/domain/shared.yml +5 -0
  23. rasa/cli/project_templates/calm/endpoints.yml +4 -4
  24. rasa/cli/project_templates/default/actions/actions.py +1 -1
  25. rasa/cli/project_templates/default/config.yml +5 -5
  26. rasa/cli/project_templates/default/credentials.yml +1 -1
  27. rasa/cli/project_templates/default/endpoints.yml +4 -4
  28. rasa/cli/project_templates/default/tests/test_stories.yml +1 -1
  29. rasa/cli/project_templates/tutorial/config.yml +1 -1
  30. rasa/cli/project_templates/tutorial/credentials.yml +1 -1
  31. rasa/cli/project_templates/tutorial/data/patterns.yml +6 -0
  32. rasa/cli/project_templates/tutorial/domain.yml +4 -0
  33. rasa/cli/project_templates/tutorial/endpoints.yml +6 -6
  34. rasa/cli/run.py +0 -1
  35. rasa/cli/scaffold.py +3 -2
  36. rasa/cli/studio/download.py +11 -0
  37. rasa/cli/studio/studio.py +180 -24
  38. rasa/cli/studio/upload.py +0 -8
  39. rasa/cli/telemetry.py +18 -6
  40. rasa/cli/utils.py +21 -10
  41. rasa/cli/x.py +3 -2
  42. rasa/constants.py +1 -1
  43. rasa/core/actions/action.py +90 -315
  44. rasa/core/actions/action_exceptions.py +24 -0
  45. rasa/core/actions/constants.py +3 -0
  46. rasa/core/actions/custom_action_executor.py +188 -0
  47. rasa/core/actions/forms.py +11 -7
  48. rasa/core/actions/grpc_custom_action_executor.py +251 -0
  49. rasa/core/actions/http_custom_action_executor.py +140 -0
  50. rasa/core/actions/loops.py +3 -0
  51. rasa/core/actions/two_stage_fallback.py +1 -1
  52. rasa/core/agent.py +2 -4
  53. rasa/core/brokers/pika.py +1 -2
  54. rasa/core/channels/audiocodes.py +1 -1
  55. rasa/core/channels/botframework.py +0 -1
  56. rasa/core/channels/callback.py +0 -1
  57. rasa/core/channels/console.py +6 -8
  58. rasa/core/channels/development_inspector.py +1 -1
  59. rasa/core/channels/facebook.py +0 -3
  60. rasa/core/channels/hangouts.py +0 -6
  61. rasa/core/channels/inspector/dist/assets/{arc-5623b6dc.js → arc-b6e548fe.js} +1 -1
  62. rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-685c106a.js → c4Diagram-d0fbc5ce-fa03ac9e.js} +1 -1
  63. rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-8cbed007.js → classDiagram-936ed81e-ee67392a.js} +1 -1
  64. rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-5889cf12.js → classDiagram-v2-c3cb15f1-9b283fae.js} +1 -1
  65. rasa/core/channels/inspector/dist/assets/{createText-62fc7601-24c249d7.js → createText-62fc7601-8b6fcc2a.js} +1 -1
  66. rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-7dd06a75.js → edges-f2ad444c-22e77f4f.js} +1 -1
  67. rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-62c1e54c.js → erDiagram-9d236eb7-60ffc87f.js} +1 -1
  68. rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-ce49b86f.js → flowDb-1972c806-9dd802e4.js} +1 -1
  69. rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-4067e48f.js → flowDiagram-7ea5b25a-5fa1912f.js} +1 -1
  70. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +1 -0
  71. rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-59fe4051.js → flowchart-elk-definition-abe16c3d-622a1fd2.js} +1 -1
  72. rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-47e3a43b.js → ganttDiagram-9b5ea136-e285a63a.js} +1 -1
  73. rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-5a2ac0d9.js → gitGraphDiagram-99d0ae7c-f237bdca.js} +1 -1
  74. rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-dfb8efc4.js → index-2c4b9a3b-4b03d70e.js} +1 -1
  75. rasa/core/channels/inspector/dist/assets/{index-268a75c0.js → index-a5d3e69d.js} +4 -4
  76. rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-b0c470f2.js → infoDiagram-736b4530-72a0fa5f.js} +1 -1
  77. rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-2edb829a.js → journeyDiagram-df861f2b-82218c41.js} +1 -1
  78. rasa/core/channels/inspector/dist/assets/{layout-b6873d69.js → layout-78cff630.js} +1 -1
  79. rasa/core/channels/inspector/dist/assets/{line-1efc5781.js → line-5038b469.js} +1 -1
  80. rasa/core/channels/inspector/dist/assets/{linear-661e9b94.js → linear-c4fc4098.js} +1 -1
  81. rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-2d2e727f.js → mindmap-definition-beec6740-c33c8ea6.js} +1 -1
  82. rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-9d3ea93d.js → pieDiagram-dbbf0591-a8d03059.js} +1 -1
  83. rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-06a178a2.js → quadrantDiagram-4d7f4fd6-6a0e56b2.js} +1 -1
  84. rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-0bfedffc.js → requirementDiagram-6fc4c22a-2dc7c7bd.js} +1 -1
  85. rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-d76d0a04.js → sankeyDiagram-8f13d901-2360fe39.js} +1 -1
  86. rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-37bb4341.js → sequenceDiagram-b655622a-41b9f9ad.js} +1 -1
  87. rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-f52f7f57.js → stateDiagram-59f0c015-0aad326f.js} +1 -1
  88. rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-4a986a20.js → stateDiagram-v2-2b26beab-9847d984.js} +1 -1
  89. rasa/core/channels/inspector/dist/assets/{styles-080da4f6-7dd9ae12.js → styles-080da4f6-564d890e.js} +1 -1
  90. rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-46e1ca14.js → styles-3dcbcfbf-38957613.js} +1 -1
  91. rasa/core/channels/inspector/dist/assets/{styles-9c745c82-4a97439a.js → styles-9c745c82-f0fc6921.js} +1 -1
  92. rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-823917a3.js → svgDrawCommon-4835440b-ef3c5a77.js} +1 -1
  93. rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-9ea72896.js → timeline-definition-5b62e21b-bf3e91c1.js} +1 -1
  94. rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-b631a8b6.js → xychartDiagram-2b33534f-4d4026c0.js} +1 -1
  95. rasa/core/channels/inspector/dist/index.html +1 -1
  96. rasa/core/channels/inspector/src/components/DiagramFlow.tsx +10 -0
  97. rasa/core/channels/inspector/src/helpers/formatters.test.ts +4 -7
  98. rasa/core/channels/inspector/src/helpers/formatters.ts +3 -2
  99. rasa/core/channels/rest.py +36 -21
  100. rasa/core/channels/rocketchat.py +0 -1
  101. rasa/core/channels/socketio.py +1 -1
  102. rasa/core/channels/telegram.py +3 -3
  103. rasa/core/channels/webexteams.py +0 -1
  104. rasa/core/concurrent_lock_store.py +1 -1
  105. rasa/core/evaluation/marker_base.py +1 -3
  106. rasa/core/evaluation/marker_stats.py +1 -2
  107. rasa/core/featurizers/single_state_featurizer.py +3 -26
  108. rasa/core/featurizers/tracker_featurizers.py +18 -122
  109. rasa/core/information_retrieval/__init__.py +7 -0
  110. rasa/core/information_retrieval/faiss.py +9 -4
  111. rasa/core/information_retrieval/information_retrieval.py +64 -7
  112. rasa/core/information_retrieval/milvus.py +7 -14
  113. rasa/core/information_retrieval/qdrant.py +8 -15
  114. rasa/core/lock_store.py +0 -1
  115. rasa/core/migrate.py +1 -2
  116. rasa/core/nlg/callback.py +3 -4
  117. rasa/core/policies/enterprise_search_policy.py +86 -22
  118. rasa/core/policies/enterprise_search_prompt_template.jinja2 +4 -41
  119. rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
  120. rasa/core/policies/flows/flow_executor.py +104 -2
  121. rasa/core/policies/intentless_policy.py +7 -9
  122. rasa/core/policies/memoization.py +3 -3
  123. rasa/core/policies/policy.py +18 -9
  124. rasa/core/policies/rule_policy.py +8 -11
  125. rasa/core/policies/ted_policy.py +61 -88
  126. rasa/core/policies/unexpected_intent_policy.py +8 -17
  127. rasa/core/processor.py +136 -47
  128. rasa/core/run.py +41 -25
  129. rasa/core/secrets_manager/endpoints.py +2 -2
  130. rasa/core/secrets_manager/vault.py +6 -8
  131. rasa/core/test.py +3 -5
  132. rasa/core/tracker_store.py +49 -14
  133. rasa/core/train.py +1 -3
  134. rasa/core/training/interactive.py +9 -6
  135. rasa/core/utils.py +5 -10
  136. rasa/dialogue_understanding/coexistence/intent_based_router.py +11 -4
  137. rasa/dialogue_understanding/coexistence/llm_based_router.py +2 -3
  138. rasa/dialogue_understanding/commands/__init__.py +4 -0
  139. rasa/dialogue_understanding/commands/can_not_handle_command.py +9 -0
  140. rasa/dialogue_understanding/commands/cancel_flow_command.py +9 -0
  141. rasa/dialogue_understanding/commands/change_flow_command.py +38 -0
  142. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +9 -0
  143. rasa/dialogue_understanding/commands/clarify_command.py +9 -0
  144. rasa/dialogue_understanding/commands/correct_slots_command.py +9 -0
  145. rasa/dialogue_understanding/commands/error_command.py +12 -0
  146. rasa/dialogue_understanding/commands/handle_code_change_command.py +9 -0
  147. rasa/dialogue_understanding/commands/human_handoff_command.py +9 -0
  148. rasa/dialogue_understanding/commands/knowledge_answer_command.py +9 -0
  149. rasa/dialogue_understanding/commands/noop_command.py +9 -0
  150. rasa/dialogue_understanding/commands/set_slot_command.py +38 -3
  151. rasa/dialogue_understanding/commands/skip_question_command.py +9 -0
  152. rasa/dialogue_understanding/commands/start_flow_command.py +9 -0
  153. rasa/dialogue_understanding/generator/__init__.py +16 -1
  154. rasa/dialogue_understanding/generator/command_generator.py +92 -6
  155. rasa/dialogue_understanding/generator/constants.py +18 -0
  156. rasa/dialogue_understanding/generator/flow_retrieval.py +7 -5
  157. rasa/dialogue_understanding/generator/llm_based_command_generator.py +467 -0
  158. rasa/dialogue_understanding/generator/llm_command_generator.py +39 -609
  159. rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
  160. rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
  161. rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
  162. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +827 -0
  163. rasa/dialogue_understanding/generator/nlu_command_adapter.py +69 -8
  164. rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
  165. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +345 -0
  166. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +36 -31
  167. rasa/dialogue_understanding/processor/command_processor.py +112 -3
  168. rasa/e2e_test/constants.py +1 -0
  169. rasa/e2e_test/e2e_test_case.py +44 -0
  170. rasa/e2e_test/e2e_test_runner.py +114 -11
  171. rasa/e2e_test/e2e_test_schema.yml +18 -0
  172. rasa/engine/caching.py +0 -1
  173. rasa/engine/graph.py +18 -6
  174. rasa/engine/recipes/config_files/default_config.yml +3 -3
  175. rasa/engine/recipes/default_components.py +1 -1
  176. rasa/engine/recipes/default_recipe.py +4 -5
  177. rasa/engine/recipes/recipe.py +1 -1
  178. rasa/engine/runner/dask.py +3 -9
  179. rasa/engine/storage/local_model_storage.py +0 -2
  180. rasa/engine/validation.py +179 -145
  181. rasa/exceptions.py +2 -2
  182. rasa/graph_components/validators/default_recipe_validator.py +3 -5
  183. rasa/hooks.py +0 -1
  184. rasa/model.py +1 -1
  185. rasa/model_training.py +1 -0
  186. rasa/nlu/classifiers/diet_classifier.py +33 -52
  187. rasa/nlu/classifiers/logistic_regression_classifier.py +9 -22
  188. rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
  189. rasa/nlu/extractors/crf_entity_extractor.py +54 -97
  190. rasa/nlu/extractors/duckling_entity_extractor.py +1 -1
  191. rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +1 -5
  192. rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +0 -4
  193. rasa/nlu/featurizers/featurizer.py +1 -1
  194. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +18 -49
  195. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +26 -64
  196. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
  197. rasa/nlu/persistor.py +68 -26
  198. rasa/nlu/selectors/response_selector.py +7 -10
  199. rasa/nlu/test.py +0 -3
  200. rasa/nlu/utils/hugging_face/registry.py +1 -1
  201. rasa/nlu/utils/spacy_utils.py +1 -3
  202. rasa/server.py +22 -7
  203. rasa/shared/constants.py +12 -1
  204. rasa/shared/core/command_payload_reader.py +109 -0
  205. rasa/shared/core/constants.py +4 -5
  206. rasa/shared/core/domain.py +57 -56
  207. rasa/shared/core/events.py +4 -7
  208. rasa/shared/core/flows/flow.py +9 -0
  209. rasa/shared/core/flows/flows_list.py +12 -0
  210. rasa/shared/core/flows/steps/action.py +7 -2
  211. rasa/shared/core/generator.py +12 -11
  212. rasa/shared/core/slot_mappings.py +315 -24
  213. rasa/shared/core/slots.py +4 -2
  214. rasa/shared/core/trackers.py +32 -14
  215. rasa/shared/core/training_data/loading.py +0 -1
  216. rasa/shared/core/training_data/story_reader/story_reader.py +3 -3
  217. rasa/shared/core/training_data/story_reader/yaml_story_reader.py +11 -11
  218. rasa/shared/core/training_data/story_writer/yaml_story_writer.py +5 -3
  219. rasa/shared/core/training_data/structures.py +1 -1
  220. rasa/shared/core/training_data/visualization.py +1 -1
  221. rasa/shared/data.py +58 -1
  222. rasa/shared/exceptions.py +36 -2
  223. rasa/shared/importers/importer.py +1 -2
  224. rasa/shared/importers/rasa.py +0 -1
  225. rasa/shared/nlu/constants.py +2 -0
  226. rasa/shared/nlu/training_data/entities_parser.py +1 -2
  227. rasa/shared/nlu/training_data/features.py +2 -120
  228. rasa/shared/nlu/training_data/formats/dialogflow.py +3 -2
  229. rasa/shared/nlu/training_data/formats/rasa_yaml.py +3 -5
  230. rasa/shared/nlu/training_data/formats/readerwriter.py +0 -1
  231. rasa/shared/nlu/training_data/message.py +13 -0
  232. rasa/shared/nlu/training_data/training_data.py +0 -2
  233. rasa/shared/providers/openai/session_handler.py +2 -2
  234. rasa/shared/utils/constants.py +3 -0
  235. rasa/shared/utils/io.py +11 -1
  236. rasa/shared/utils/llm.py +1 -2
  237. rasa/shared/utils/pykwalify_extensions.py +1 -0
  238. rasa/shared/utils/schemas/domain.yml +3 -0
  239. rasa/shared/utils/yaml.py +44 -35
  240. rasa/studio/auth.py +26 -10
  241. rasa/studio/constants.py +2 -0
  242. rasa/studio/data_handler.py +114 -107
  243. rasa/studio/download.py +160 -27
  244. rasa/studio/results_logger.py +137 -0
  245. rasa/studio/train.py +6 -7
  246. rasa/studio/upload.py +159 -134
  247. rasa/telemetry.py +188 -34
  248. rasa/tracing/config.py +18 -3
  249. rasa/tracing/constants.py +26 -2
  250. rasa/tracing/instrumentation/attribute_extractors.py +50 -41
  251. rasa/tracing/instrumentation/instrumentation.py +290 -44
  252. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +7 -5
  253. rasa/tracing/instrumentation/metrics.py +109 -21
  254. rasa/tracing/metric_instrument_provider.py +83 -3
  255. rasa/utils/cli.py +2 -1
  256. rasa/utils/common.py +1 -1
  257. rasa/utils/endpoints.py +1 -2
  258. rasa/utils/io.py +72 -6
  259. rasa/utils/licensing.py +246 -31
  260. rasa/utils/ml_utils.py +1 -1
  261. rasa/utils/tensorflow/data_generator.py +1 -1
  262. rasa/utils/tensorflow/environment.py +1 -1
  263. rasa/utils/tensorflow/model_data.py +201 -12
  264. rasa/utils/tensorflow/model_data_utils.py +499 -500
  265. rasa/utils/tensorflow/models.py +5 -6
  266. rasa/utils/tensorflow/rasa_layers.py +15 -15
  267. rasa/utils/train_utils.py +1 -1
  268. rasa/utils/url_tools.py +53 -0
  269. rasa/validator.py +305 -3
  270. rasa/version.py +1 -1
  271. {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/METADATA +25 -61
  272. {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/RECORD +276 -259
  273. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-85583a23.js +0 -1
  274. rasa/utils/tensorflow/feature_array.py +0 -370
  275. /rasa/dialogue_understanding/generator/{command_prompt_template.jinja2 → single_step/command_prompt_template.jinja2} +0 -0
  276. {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/NOTICE +0 -0
  277. {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/WHEEL +0 -0
  278. {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/entry_points.txt +0 -0
@@ -1,17 +1,18 @@
1
1
  from __future__ import annotations
2
-
3
2
  import copy
4
3
  import logging
5
4
  from collections import defaultdict
6
5
  from pathlib import Path
7
- from typing import Any, Dict, List, Optional, Text, Tuple, Union, TypeVar, Type
6
+
7
+ from rasa.exceptions import ModelNotFound
8
+ from rasa.nlu.featurizers.featurizer import Featurizer
8
9
 
9
10
  import numpy as np
10
11
  import scipy.sparse
11
12
  import tensorflow as tf
12
13
 
13
- from rasa.exceptions import ModelNotFound
14
- from rasa.nlu.featurizers.featurizer import Featurizer
14
+ from typing import Any, Dict, List, Optional, Text, Tuple, Union, TypeVar, Type
15
+
15
16
  from rasa.engine.graph import ExecutionContext, GraphComponent
16
17
  from rasa.engine.recipes.default_recipe import DefaultV1Recipe
17
18
  from rasa.engine.storage.resource import Resource
@@ -19,21 +20,18 @@ from rasa.engine.storage.storage import ModelStorage
19
20
  from rasa.nlu.extractors.extractor import EntityExtractorMixin
20
21
  from rasa.nlu.classifiers.classifier import IntentClassifier
21
22
  import rasa.shared.utils.io
23
+ import rasa.utils.io as io_utils
22
24
  import rasa.nlu.utils.bilou_utils as bilou_utils
23
25
  from rasa.shared.constants import DIAGNOSTIC_DATA
24
26
  from rasa.nlu.extractors.extractor import EntityTagSpec
25
27
  from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
26
28
  from rasa.utils import train_utils
27
29
  from rasa.utils.tensorflow import rasa_layers
28
- from rasa.utils.tensorflow.feature_array import (
29
- FeatureArray,
30
- serialize_nested_feature_arrays,
31
- deserialize_nested_feature_arrays,
32
- )
33
30
  from rasa.utils.tensorflow.models import RasaModel, TransformerRasaModel
34
31
  from rasa.utils.tensorflow.model_data import (
35
32
  RasaModelData,
36
33
  FeatureSignature,
34
+ FeatureArray,
37
35
  )
38
36
  from rasa.nlu.constants import TOKENS_NAMES, DEFAULT_TRANSFORMER_SIZE
39
37
  from rasa.shared.nlu.constants import (
@@ -120,6 +118,7 @@ LABEL_SUB_KEY = IDS
120
118
 
121
119
  POSSIBLE_TAGS = [ENTITY_ATTRIBUTE_TYPE, ENTITY_ATTRIBUTE_ROLE, ENTITY_ATTRIBUTE_GROUP]
122
120
 
121
+
123
122
  DIETClassifierT = TypeVar("DIETClassifierT", bound="DIETClassifier")
124
123
 
125
124
 
@@ -511,7 +510,6 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
511
510
  def _extract_features(
512
511
  self, message: Message, attribute: Text
513
512
  ) -> Dict[Text, Union[scipy.sparse.spmatrix, np.ndarray]]:
514
-
515
513
  (
516
514
  sparse_sequence_features,
517
515
  sparse_sentence_features,
@@ -781,7 +779,6 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
781
779
  sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]],
782
780
  label_attribute: Optional[Text] = None,
783
781
  ) -> Dict[Text, Dict[Text, List[int]]]:
784
-
785
782
  if label_attribute in sparse_feature_sizes:
786
783
  del sparse_feature_sizes[label_attribute]
787
784
  return sparse_feature_sizes
@@ -1086,24 +1083,18 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
1086
1083
 
1087
1084
  self.model.save(str(tf_model_file))
1088
1085
 
1089
- # save data example
1090
- serialize_nested_feature_arrays(
1091
- self._data_example,
1092
- model_path / f"{file_name}.data_example.st",
1093
- model_path / f"{file_name}.data_example_metadata.json",
1086
+ io_utils.pickle_dump(
1087
+ model_path / f"{file_name}.data_example.pkl", self._data_example
1094
1088
  )
1095
- # save label data
1096
- serialize_nested_feature_arrays(
1097
- dict(self._label_data.data) if self._label_data is not None else {},
1098
- model_path / f"{file_name}.label_data.st",
1099
- model_path / f"{file_name}.label_data_metadata.json",
1100
- )
1101
-
1102
- rasa.shared.utils.io.dump_obj_as_json_to_file(
1103
- model_path / f"{file_name}.sparse_feature_sizes.json",
1089
+ io_utils.pickle_dump(
1090
+ model_path / f"{file_name}.sparse_feature_sizes.pkl",
1104
1091
  self._sparse_feature_sizes,
1105
1092
  )
1106
- rasa.shared.utils.io.dump_obj_as_json_to_file(
1093
+ io_utils.pickle_dump(
1094
+ model_path / f"{file_name}.label_data.pkl",
1095
+ dict(self._label_data.data) if self._label_data is not None else {},
1096
+ )
1097
+ io_utils.json_pickle(
1107
1098
  model_path / f"{file_name}.index_label_id_mapping.json",
1108
1099
  self.index_label_id_mapping,
1109
1100
  )
@@ -1192,22 +1183,15 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
1192
1183
  ]:
1193
1184
  file_name = cls.__name__
1194
1185
 
1195
- # load data example
1196
- data_example = deserialize_nested_feature_arrays(
1197
- str(model_path / f"{file_name}.data_example.st"),
1198
- str(model_path / f"{file_name}.data_example_metadata.json"),
1186
+ data_example = io_utils.pickle_load(
1187
+ model_path / f"{file_name}.data_example.pkl"
1199
1188
  )
1200
- # load label data
1201
- loaded_label_data = deserialize_nested_feature_arrays(
1202
- str(model_path / f"{file_name}.label_data.st"),
1203
- str(model_path / f"{file_name}.label_data_metadata.json"),
1204
- )
1205
- label_data = RasaModelData(data=loaded_label_data)
1206
-
1207
- sparse_feature_sizes = rasa.shared.utils.io.read_json_file(
1208
- model_path / f"{file_name}.sparse_feature_sizes.json"
1189
+ label_data = io_utils.pickle_load(model_path / f"{file_name}.label_data.pkl")
1190
+ label_data = RasaModelData(data=label_data)
1191
+ sparse_feature_sizes = io_utils.pickle_load(
1192
+ model_path / f"{file_name}.sparse_feature_sizes.pkl"
1209
1193
  )
1210
- index_label_id_mapping = rasa.shared.utils.io.read_json_file(
1194
+ index_label_id_mapping = io_utils.json_unpickle(
1211
1195
  model_path / f"{file_name}.index_label_id_mapping.json"
1212
1196
  )
1213
1197
  entity_tag_specs = rasa.shared.utils.io.read_json_file(
@@ -1227,6 +1211,7 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
1227
1211
  for tag_spec in entity_tag_specs
1228
1212
  ]
1229
1213
 
1214
+ # jsonpickle converts dictionary keys to strings
1230
1215
  index_label_id_mapping = {
1231
1216
  int(key): value for key, value in index_label_id_mapping.items()
1232
1217
  }
@@ -1280,7 +1265,6 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
1280
1265
  config: Dict[Text, Any],
1281
1266
  finetune_mode: bool,
1282
1267
  ) -> "RasaModel":
1283
-
1284
1268
  predict_data_example = RasaModelData(
1285
1269
  label_key=model_data_example.label_key,
1286
1270
  data={
@@ -1467,10 +1451,10 @@ class DIET(TransformerRasaModel):
1467
1451
  # everything using a transformer and optionally also do masked language
1468
1452
  # modeling.
1469
1453
  self.text_name = TEXT
1470
- self._tf_layers[
1471
- f"sequence_layer.{self.text_name}"
1472
- ] = rasa_layers.RasaSequenceLayer(
1473
- self.text_name, self.data_signature[self.text_name], self.config
1454
+ self._tf_layers[f"sequence_layer.{self.text_name}"] = (
1455
+ rasa_layers.RasaSequenceLayer(
1456
+ self.text_name, self.data_signature[self.text_name], self.config
1457
+ )
1474
1458
  )
1475
1459
  if self.config[MASKED_LM]:
1476
1460
  self._prepare_mask_lm_loss(self.text_name)
@@ -1488,10 +1472,10 @@ class DIET(TransformerRasaModel):
1488
1472
  {SPARSE_INPUT_DROPOUT: False, DENSE_INPUT_DROPOUT: False}
1489
1473
  )
1490
1474
 
1491
- self._tf_layers[
1492
- f"feature_combining_layer.{self.label_name}"
1493
- ] = rasa_layers.RasaFeatureCombiningLayer(
1494
- self.label_name, self.label_signature[self.label_name], label_config
1475
+ self._tf_layers[f"feature_combining_layer.{self.label_name}"] = (
1476
+ rasa_layers.RasaFeatureCombiningLayer(
1477
+ self.label_name, self.label_signature[self.label_name], label_config
1478
+ )
1495
1479
  )
1496
1480
 
1497
1481
  self._prepare_ffnn_layer(
@@ -1523,7 +1507,6 @@ class DIET(TransformerRasaModel):
1523
1507
  sequence_feature_lengths: tf.Tensor,
1524
1508
  name: Text,
1525
1509
  ) -> tf.Tensor:
1526
-
1527
1510
  x, _ = self._tf_layers[f"feature_combining_layer.{name}"](
1528
1511
  (sequence_features, sentence_features, sequence_feature_lengths),
1529
1512
  training=self._training,
@@ -1705,7 +1688,6 @@ class DIET(TransformerRasaModel):
1705
1688
  return loss
1706
1689
 
1707
1690
  def _update_label_metrics(self, loss: tf.Tensor, acc: tf.Tensor) -> None:
1708
-
1709
1691
  self.intent_loss.update_state(loss)
1710
1692
  self.intent_acc.update_state(acc)
1711
1693
 
@@ -1864,7 +1846,6 @@ class DIET(TransformerRasaModel):
1864
1846
  combined_sequence_sentence_feature_lengths: tf.Tensor,
1865
1847
  text_transformed: tf.Tensor,
1866
1848
  ) -> Dict[Text, tf.Tensor]:
1867
-
1868
1849
  if self.all_labels_embed is None:
1869
1850
  raise ValueError(
1870
1851
  "The model was not prepared for prediction. "
@@ -1,21 +1,22 @@
1
1
  from typing import Any, Text, Dict, List, Type, Tuple
2
2
 
3
+ import joblib
3
4
  import structlog
4
5
  from scipy.sparse import hstack, vstack, csr_matrix
5
6
  from sklearn.exceptions import NotFittedError
6
7
  from sklearn.linear_model import LogisticRegression
7
8
  from sklearn.utils.validation import check_is_fitted
8
9
 
9
- from rasa.engine.graph import ExecutionContext, GraphComponent
10
- from rasa.engine.recipes.default_recipe import DefaultV1Recipe
11
10
  from rasa.engine.storage.resource import Resource
12
11
  from rasa.engine.storage.storage import ModelStorage
12
+ from rasa.engine.recipes.default_recipe import DefaultV1Recipe
13
+ from rasa.engine.graph import ExecutionContext, GraphComponent
13
14
  from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
14
- from rasa.nlu.classifiers.classifier import IntentClassifier
15
15
  from rasa.nlu.featurizers.featurizer import Featurizer
16
- from rasa.shared.nlu.constants import TEXT, INTENT
17
- from rasa.shared.nlu.training_data.message import Message
16
+ from rasa.nlu.classifiers.classifier import IntentClassifier
18
17
  from rasa.shared.nlu.training_data.training_data import TrainingData
18
+ from rasa.shared.nlu.training_data.message import Message
19
+ from rasa.shared.nlu.constants import TEXT, INTENT
19
20
  from rasa.utils.tensorflow.constants import RANKING_LENGTH
20
21
 
21
22
  structlogger = structlog.get_logger()
@@ -183,11 +184,9 @@ class LogisticRegressionClassifier(IntentClassifier, GraphComponent):
183
184
 
184
185
  def persist(self) -> None:
185
186
  """Persist this model into the passed directory."""
186
- import skops.io as sio
187
-
188
187
  with self._model_storage.write_to(self._resource) as model_dir:
189
- path = model_dir / f"{self._resource.name}.skops"
190
- sio.dump(self.clf, path)
188
+ path = model_dir / f"{self._resource.name}.joblib"
189
+ joblib.dump(self.clf, path)
191
190
  structlogger.debug(
192
191
  "logistic_regression_classifier.persist",
193
192
  event_info=f"Saved intent classifier to '{path}'.",
@@ -203,21 +202,9 @@ class LogisticRegressionClassifier(IntentClassifier, GraphComponent):
203
202
  **kwargs: Any,
204
203
  ) -> "LogisticRegressionClassifier":
205
204
  """Loads trained component (see parent class for full docstring)."""
206
- import skops.io as sio
207
-
208
205
  try:
209
206
  with model_storage.read_from(resource) as model_dir:
210
- classifier_file = model_dir / f"{resource.name}.skops"
211
- unknown_types = sio.get_untrusted_types(file=classifier_file)
212
-
213
- if unknown_types:
214
- structlogger.error(
215
- f"Untrusted types found when loading {classifier_file}!",
216
- unknown_types=unknown_types,
217
- )
218
- raise ValueError()
219
-
220
- classifier = sio.load(classifier_file, trusted=unknown_types)
207
+ classifier = joblib.load(model_dir / f"{resource.name}.joblib")
221
208
  component = cls(
222
209
  config, execution_context.node_name, model_storage, resource
223
210
  )
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
-
3
2
  import logging
3
+ from rasa.nlu.featurizers.dense_featurizer.dense_featurizer import DenseFeaturizer
4
4
  import typing
5
5
  import warnings
6
6
  from typing import Any, Dict, List, Optional, Text, Tuple, Type
@@ -8,18 +8,18 @@ from typing import Any, Dict, List, Optional, Text, Tuple, Type
8
8
  import numpy as np
9
9
 
10
10
  import rasa.shared.utils.io
11
+ import rasa.utils.io as io_utils
11
12
  from rasa.engine.graph import GraphComponent, ExecutionContext
12
13
  from rasa.engine.recipes.default_recipe import DefaultV1Recipe
13
14
  from rasa.engine.storage.resource import Resource
14
15
  from rasa.engine.storage.storage import ModelStorage
15
- from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
16
- from rasa.nlu.classifiers.classifier import IntentClassifier
17
- from rasa.nlu.featurizers.dense_featurizer.dense_featurizer import DenseFeaturizer
18
16
  from rasa.shared.constants import DOCS_URL_TRAINING_DATA_NLU
17
+ from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
19
18
  from rasa.shared.exceptions import RasaException
20
19
  from rasa.shared.nlu.constants import TEXT
21
- from rasa.shared.nlu.training_data.message import Message
20
+ from rasa.nlu.classifiers.classifier import IntentClassifier
22
21
  from rasa.shared.nlu.training_data.training_data import TrainingData
22
+ from rasa.shared.nlu.training_data.message import Message
23
23
  from rasa.utils.tensorflow.constants import FEATURIZERS
24
24
 
25
25
  logger = logging.getLogger(__name__)
@@ -266,20 +266,14 @@ class SklearnIntentClassifier(GraphComponent, IntentClassifier):
266
266
 
267
267
  def persist(self) -> None:
268
268
  """Persist this model into the passed directory."""
269
- import skops.io as sio
270
-
271
269
  with self._model_storage.write_to(self._resource) as model_dir:
272
270
  file_name = self.__class__.__name__
273
- classifier_file_name = model_dir / f"{file_name}_classifier.skops"
274
- encoder_file_name = model_dir / f"{file_name}_encoder.json"
271
+ classifier_file_name = model_dir / f"{file_name}_classifier.pkl"
272
+ encoder_file_name = model_dir / f"{file_name}_encoder.pkl"
275
273
 
276
274
  if self.clf and self.le:
277
- # convert self.le.classes_ (numpy array of strings) to a list in order
278
- # to use json dump
279
- rasa.shared.utils.io.dump_obj_as_json_to_file(
280
- encoder_file_name, list(self.le.classes_)
281
- )
282
- sio.dump(self.clf.best_estimator_, classifier_file_name)
275
+ io_utils.json_pickle(encoder_file_name, self.le.classes_)
276
+ io_utils.json_pickle(classifier_file_name, self.clf.best_estimator_)
283
277
 
284
278
  @classmethod
285
279
  def load(
@@ -292,36 +286,21 @@ class SklearnIntentClassifier(GraphComponent, IntentClassifier):
292
286
  ) -> SklearnIntentClassifier:
293
287
  """Loads trained component (see parent class for full docstring)."""
294
288
  from sklearn.preprocessing import LabelEncoder
295
- import skops.io as sio
296
289
 
297
290
  try:
298
291
  with model_storage.read_from(resource) as model_dir:
299
292
  file_name = cls.__name__
300
- classifier_file = model_dir / f"{file_name}_classifier.skops"
293
+ classifier_file = model_dir / f"{file_name}_classifier.pkl"
301
294
 
302
295
  if classifier_file.exists():
303
- unknown_types = sio.get_untrusted_types(file=classifier_file)
304
-
305
- if unknown_types:
306
- logger.error(
307
- f"Untrusted types ({unknown_types}) found when "
308
- f"loading {classifier_file}!"
309
- )
310
- raise ValueError()
311
- else:
312
- classifier = sio.load(classifier_file, trusted=unknown_types)
313
-
314
- encoder_file = model_dir / f"{file_name}_encoder.json"
315
- classes = rasa.shared.utils.io.read_json_file(encoder_file)
296
+ classifier = io_utils.json_unpickle(classifier_file)
316
297
 
298
+ encoder_file = model_dir / f"{file_name}_encoder.pkl"
299
+ classes = io_utils.json_unpickle(encoder_file)
317
300
  encoder = LabelEncoder()
318
- intent_classifier = cls(
319
- config, model_storage, resource, classifier, encoder
320
- )
321
- # convert list of strings (class labels) back to numpy array of
322
- # strings
323
- intent_classifier.transform_labels_str2num(classes)
324
- return intent_classifier
301
+ encoder.classes_ = classes
302
+
303
+ return cls(config, model_storage, resource, classifier, encoder)
325
304
  except ValueError:
326
305
  logger.debug(
327
306
  f"Failed to load '{cls.__name__}' from model storage. Resource "
@@ -4,9 +4,9 @@ from collections import OrderedDict
4
4
  from enum import Enum
5
5
  import logging
6
6
  import typing
7
- from typing import Any, Dict, List, Optional, Text, Tuple, Callable, Type
8
7
 
9
8
  import numpy as np
9
+ from typing import Any, Dict, List, Optional, Text, Tuple, Callable, Type
10
10
 
11
11
  import rasa.nlu.utils.bilou_utils as bilou_utils
12
12
  import rasa.shared.utils.io
@@ -41,9 +41,6 @@ if typing.TYPE_CHECKING:
41
41
  from sklearn_crfsuite import CRF
42
42
 
43
43
 
44
- CONFIG_FEATURES = "features"
45
-
46
-
47
44
  class CRFToken:
48
45
  def __init__(
49
46
  self,
@@ -63,29 +60,6 @@ class CRFToken:
63
60
  self.entity_role_tag = entity_role_tag
64
61
  self.entity_group_tag = entity_group_tag
65
62
 
66
- def to_dict(self) -> Dict[str, Any]:
67
- return {
68
- "text": self.text,
69
- "pos_tag": self.pos_tag,
70
- "pattern": self.pattern,
71
- "dense_features": [str(x) for x in list(self.dense_features)],
72
- "entity_tag": self.entity_tag,
73
- "entity_role_tag": self.entity_role_tag,
74
- "entity_group_tag": self.entity_group_tag,
75
- }
76
-
77
- @classmethod
78
- def create_from_dict(cls, data: Dict[str, Any]) -> "CRFToken":
79
- return cls(
80
- data["text"],
81
- data["pos_tag"],
82
- data["pattern"],
83
- np.array([float(x) for x in data["dense_features"]]),
84
- data["entity_tag"],
85
- data["entity_role_tag"],
86
- data["entity_group_tag"],
87
- )
88
-
89
63
 
90
64
  class CRFEntityExtractorOptions(str, Enum):
91
65
  """Features that can be used for the 'CRFEntityExtractor'."""
@@ -114,6 +88,8 @@ class CRFEntityExtractorOptions(str, Enum):
114
88
  class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
115
89
  """Implements conditional random fields (CRF) to do named entity recognition."""
116
90
 
91
+ CONFIG_FEATURES = "features"
92
+
117
93
  function_dict: Dict[Text, Callable[[CRFToken], Any]] = { # noqa: RUF012
118
94
  CRFEntityExtractorOptions.LOW: lambda crf_token: crf_token.text.lower(),
119
95
  CRFEntityExtractorOptions.TITLE: lambda crf_token: crf_token.text.istitle(),
@@ -132,7 +108,7 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
132
108
  CRFEntityExtractorOptions.DIGIT: lambda crf_token: crf_token.text.isdigit(),
133
109
  CRFEntityExtractorOptions.PATTERN: lambda crf_token: crf_token.pattern,
134
110
  CRFEntityExtractorOptions.TEXT_DENSE_FEATURES: (
135
- lambda crf_token: CRFEntityExtractor._convert_dense_features_for_crfsuite( # noqa: E501
111
+ lambda crf_token: CRFEntityExtractor._convert_dense_features_for_crfsuite(
136
112
  crf_token
137
113
  )
138
114
  ),
@@ -161,7 +137,7 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
161
137
  # "is the preceding token in title case?"
162
138
  # POS features require SpacyTokenizer
163
139
  # pattern feature require RegexFeaturizer
164
- CONFIG_FEATURES: [
140
+ CRFEntityExtractor.CONFIG_FEATURES: [
165
141
  [
166
142
  CRFEntityExtractorOptions.LOW,
167
143
  CRFEntityExtractorOptions.TITLE,
@@ -224,7 +200,7 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
224
200
  )
225
201
 
226
202
  def _validate_configuration(self) -> None:
227
- if len(self.component_config.get(CONFIG_FEATURES, [])) % 2 != 1:
203
+ if len(self.component_config.get(self.CONFIG_FEATURES, [])) % 2 != 1:
228
204
  raise ValueError(
229
205
  "Need an odd number of crf feature lists to have a center word."
230
206
  )
@@ -275,11 +251,9 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
275
251
  ]
276
252
  dataset = [self._convert_to_crf_tokens(example) for example in entity_examples]
277
253
 
278
- self.entity_taggers = self.train_model(
279
- dataset, self.component_config, self.crf_order
280
- )
254
+ self._train_model(dataset)
281
255
 
282
- self.persist(dataset)
256
+ self.persist()
283
257
 
284
258
  return self._resource
285
259
 
@@ -325,9 +299,7 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
325
299
  if include_tag_features:
326
300
  self._add_tag_to_crf_token(crf_tokens, predictions)
327
301
 
328
- features = self._crf_tokens_to_features(
329
- crf_tokens, self.component_config, include_tag_features
330
- )
302
+ features = self._crf_tokens_to_features(crf_tokens, include_tag_features)
331
303
  predictions[tag_name] = entity_tagger.predict_marginals_single(features)
332
304
 
333
305
  # convert predictions into a list of tags and a list of confidences
@@ -417,25 +389,27 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
417
389
  **kwargs: Any,
418
390
  ) -> CRFEntityExtractor:
419
391
  """Loads trained component (see parent class for full docstring)."""
392
+ import joblib
393
+
420
394
  try:
395
+ entity_taggers = OrderedDict()
421
396
  with model_storage.read_from(resource) as model_dir:
422
- dataset = rasa.shared.utils.io.read_json_file(
423
- model_dir / "crf_dataset.json"
424
- )
425
- crf_order = rasa.shared.utils.io.read_json_file(
426
- model_dir / "crf_order.json"
427
- )
428
-
429
- dataset = [
430
- [CRFToken.create_from_dict(token_data) for token_data in sub_list]
431
- for sub_list in dataset
432
- ]
397
+ # We have to load in the same order as we persisted things as otherwise
398
+ # the predictions might be off
399
+ file_names = sorted(model_dir.glob("**/*.pkl"))
400
+ if not file_names:
401
+ logger.debug(
402
+ "Failed to load model for 'CRFEntityExtractor'. "
403
+ "Maybe you did not provide enough training data and "
404
+ "no model was trained."
405
+ )
406
+ return cls(config, model_storage, resource)
433
407
 
434
- entity_taggers = cls.train_model(dataset, config, crf_order)
408
+ for file_name in file_names:
409
+ name = file_name.stem[1:]
410
+ entity_taggers[name] = joblib.load(file_name)
435
411
 
436
- entity_extractor = cls(config, model_storage, resource, entity_taggers)
437
- entity_extractor.crf_order = crf_order
438
- return entity_extractor
412
+ return cls(config, model_storage, resource, entity_taggers)
439
413
  except ValueError:
440
414
  logger.warning(
441
415
  f"Failed to load {cls.__name__} from model storage. Resource "
@@ -443,29 +417,23 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
443
417
  )
444
418
  return cls(config, model_storage, resource)
445
419
 
446
- def persist(self, dataset: List[List[CRFToken]]) -> None:
420
+ def persist(self) -> None:
447
421
  """Persist this model into the passed directory."""
448
- with self._model_storage.write_to(self._resource) as model_dir:
449
- data_to_store = [
450
- [token.to_dict() for token in sub_list] for sub_list in dataset
451
- ]
422
+ import joblib
452
423
 
453
- rasa.shared.utils.io.dump_obj_as_json_to_file(
454
- model_dir / "crf_dataset.json", data_to_store
455
- )
456
- rasa.shared.utils.io.dump_obj_as_json_to_file(
457
- model_dir / "crf_order.json", self.crf_order
458
- )
424
+ with self._model_storage.write_to(self._resource) as model_dir:
425
+ if self.entity_taggers:
426
+ for idx, (name, entity_tagger) in enumerate(
427
+ self.entity_taggers.items()
428
+ ):
429
+ model_file_name = model_dir / f"{idx}{name}.pkl"
430
+ joblib.dump(entity_tagger, model_file_name)
459
431
 
460
- @classmethod
461
432
  def _crf_tokens_to_features(
462
- cls,
463
- crf_tokens: List[CRFToken],
464
- config: Dict[str, Any],
465
- include_tag_features: bool = False,
433
+ self, crf_tokens: List[CRFToken], include_tag_features: bool = False
466
434
  ) -> List[Dict[Text, Any]]:
467
435
  """Convert the list of tokens into discrete features."""
468
- configured_features = config[CONFIG_FEATURES]
436
+ configured_features = self.component_config[self.CONFIG_FEATURES]
469
437
  sentence_features = []
470
438
 
471
439
  for token_idx in range(len(crf_tokens)):
@@ -476,31 +444,28 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
476
444
  half_window_size = window_size // 2
477
445
  window_range = range(-half_window_size, half_window_size + 1)
478
446
 
479
- token_features = cls._create_features_for_token(
447
+ token_features = self._create_features_for_token(
480
448
  crf_tokens,
481
449
  token_idx,
482
450
  half_window_size,
483
451
  window_range,
484
452
  include_tag_features,
485
- config,
486
453
  )
487
454
 
488
455
  sentence_features.append(token_features)
489
456
 
490
457
  return sentence_features
491
458
 
492
- @classmethod
493
459
  def _create_features_for_token(
494
- cls,
460
+ self,
495
461
  crf_tokens: List[CRFToken],
496
462
  token_idx: int,
497
463
  half_window_size: int,
498
464
  window_range: range,
499
465
  include_tag_features: bool,
500
- config: Dict[str, Any],
501
466
  ) -> Dict[Text, Any]:
502
467
  """Convert a token into discrete features including words before and after."""
503
- configured_features = config[CONFIG_FEATURES]
468
+ configured_features = self.component_config[self.CONFIG_FEATURES]
504
469
  prefixes = [str(i) for i in window_range]
505
470
 
506
471
  token_features = {}
@@ -540,13 +505,13 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
540
505
  # set in the training data, 'matched' is either 'True' or
541
506
  # 'False' depending on whether the token actually matches the
542
507
  # pattern or not
543
- regex_patterns = cls.function_dict[feature](token)
508
+ regex_patterns = self.function_dict[feature](token)
544
509
  for pattern_name, matched in regex_patterns.items():
545
- token_features[
546
- f"{prefix}:{feature}:{pattern_name}"
547
- ] = matched
510
+ token_features[f"{prefix}:{feature}:{pattern_name}"] = (
511
+ matched
512
+ )
548
513
  else:
549
- value = cls.function_dict[feature](token)
514
+ value = self.function_dict[feature](token)
550
515
  token_features[f"{prefix}:{feature}"] = value
551
516
 
552
517
  return token_features
@@ -670,46 +635,38 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
670
635
 
671
636
  return tags
672
637
 
673
- @classmethod
674
- def train_model(
675
- cls,
676
- df_train: List[List[CRFToken]],
677
- config: Dict[str, Any],
678
- crf_order: List[str],
679
- ) -> OrderedDict[str, CRF]:
638
+ def _train_model(self, df_train: List[List[CRFToken]]) -> None:
680
639
  """Train the crf tagger based on the training data."""
681
640
  import sklearn_crfsuite
682
641
 
683
- entity_taggers = OrderedDict()
642
+ self.entity_taggers = OrderedDict()
684
643
 
685
- for tag_name in crf_order:
644
+ for tag_name in self.crf_order:
686
645
  logger.debug(f"Training CRF for '{tag_name}'.")
687
646
 
688
647
  # add entity tag features for second level CRFs
689
648
  include_tag_features = tag_name != ENTITY_ATTRIBUTE_TYPE
690
649
  X_train = (
691
- cls._crf_tokens_to_features(sentence, config, include_tag_features)
650
+ self._crf_tokens_to_features(sentence, include_tag_features)
692
651
  for sentence in df_train
693
652
  )
694
653
  y_train = (
695
- cls._crf_tokens_to_tags(sentence, tag_name) for sentence in df_train
654
+ self._crf_tokens_to_tags(sentence, tag_name) for sentence in df_train
696
655
  )
697
656
 
698
657
  entity_tagger = sklearn_crfsuite.CRF(
699
658
  algorithm="lbfgs",
700
659
  # coefficient for L1 penalty
701
- c1=config["L1_c"],
660
+ c1=self.component_config["L1_c"],
702
661
  # coefficient for L2 penalty
703
- c2=config["L2_c"],
662
+ c2=self.component_config["L2_c"],
704
663
  # stop earlier
705
- max_iterations=config["max_iterations"],
664
+ max_iterations=self.component_config["max_iterations"],
706
665
  # include transitions that are possible, but not observed
707
666
  all_possible_transitions=True,
708
667
  )
709
668
  entity_tagger.fit(X_train, y_train)
710
669
 
711
- entity_taggers[tag_name] = entity_tagger
670
+ self.entity_taggers[tag_name] = entity_tagger
712
671
 
713
672
  logger.debug("Training finished.")
714
-
715
- return entity_taggers