rasa-pro 3.11.0__py3-none-any.whl → 3.11.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of rasa-pro might be problematic.

Files changed (220)
  1. README.md +396 -17
  2. rasa/__main__.py +15 -31
  3. rasa/api.py +1 -5
  4. rasa/cli/arguments/default_arguments.py +2 -1
  5. rasa/cli/arguments/shell.py +1 -5
  6. rasa/cli/arguments/train.py +0 -14
  7. rasa/cli/e2e_test.py +1 -1
  8. rasa/cli/evaluate.py +8 -8
  9. rasa/cli/inspect.py +7 -15
  10. rasa/cli/interactive.py +0 -1
  11. rasa/cli/llm_fine_tuning.py +1 -1
  12. rasa/cli/project_templates/calm/config.yml +7 -5
  13. rasa/cli/project_templates/calm/endpoints.yml +2 -15
  14. rasa/cli/project_templates/tutorial/config.yml +5 -8
  15. rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
  16. rasa/cli/project_templates/tutorial/data/patterns.yml +0 -5
  17. rasa/cli/project_templates/tutorial/domain.yml +0 -14
  18. rasa/cli/project_templates/tutorial/endpoints.yml +0 -5
  19. rasa/cli/run.py +1 -1
  20. rasa/cli/scaffold.py +2 -4
  21. rasa/cli/studio/studio.py +8 -18
  22. rasa/cli/studio/upload.py +15 -0
  23. rasa/cli/train.py +0 -3
  24. rasa/cli/utils.py +1 -6
  25. rasa/cli/x.py +8 -8
  26. rasa/constants.py +1 -3
  27. rasa/core/actions/action.py +33 -75
  28. rasa/core/actions/e2e_stub_custom_action_executor.py +1 -5
  29. rasa/core/actions/http_custom_action_executor.py +0 -4
  30. rasa/core/channels/__init__.py +0 -2
  31. rasa/core/channels/channel.py +0 -20
  32. rasa/core/channels/development_inspector.py +3 -10
  33. rasa/core/channels/inspector/dist/assets/{arc-bc141fb2.js → arc-86942a71.js} +1 -1
  34. rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-be2db283.js → c4Diagram-d0fbc5ce-b0290676.js} +1 -1
  35. rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-55366915.js → classDiagram-936ed81e-f6405f6e.js} +1 -1
  36. rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-bb529518.js → classDiagram-v2-c3cb15f1-ef61ac77.js} +1 -1
  37. rasa/core/channels/inspector/dist/assets/{createText-62fc7601-b0ec81d6.js → createText-62fc7601-f0411e58.js} +1 -1
  38. rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-6166330c.js → edges-f2ad444c-7dcc4f3b.js} +1 -1
  39. rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-5ccc6a8e.js → erDiagram-9d236eb7-e0c092d7.js} +1 -1
  40. rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-fca3bfe4.js → flowDb-1972c806-fba2e3ce.js} +1 -1
  41. rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-4739080f.js → flowDiagram-7ea5b25a-7a70b71a.js} +1 -1
  42. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-24a5f41a.js +1 -0
  43. rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-7c1b0e0f.js → flowchart-elk-definition-abe16c3d-00a59b68.js} +1 -1
  44. rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-772fd050.js → ganttDiagram-9b5ea136-293c91fa.js} +1 -1
  45. rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-8eae1dc9.js → gitGraphDiagram-99d0ae7c-07b2d68c.js} +1 -1
  46. rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-f55afcdf.js → index-2c4b9a3b-bc959fbd.js} +1 -1
  47. rasa/core/channels/inspector/dist/assets/{index-e7cef9de.js → index-3a8a5a28.js} +143 -143
  48. rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-124d4a14.js → infoDiagram-736b4530-4a350f72.js} +1 -1
  49. rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-7c4fae44.js → journeyDiagram-df861f2b-af464fb7.js} +1 -1
  50. rasa/core/channels/inspector/dist/assets/{layout-b9885fb6.js → layout-0071f036.js} +1 -1
  51. rasa/core/channels/inspector/dist/assets/{line-7c59abb6.js → line-2f73cc83.js} +1 -1
  52. rasa/core/channels/inspector/dist/assets/{linear-4776f780.js → linear-f014b4cc.js} +1 -1
  53. rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-2332c46c.js → mindmap-definition-beec6740-d2426fb6.js} +1 -1
  54. rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-8fb39303.js → pieDiagram-dbbf0591-776f01a2.js} +1 -1
  55. rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-3c7180a2.js → quadrantDiagram-4d7f4fd6-82e00b57.js} +1 -1
  56. rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-e910bcb8.js → requirementDiagram-6fc4c22a-ea13c6bb.js} +1 -1
  57. rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-ead16c89.js → sankeyDiagram-8f13d901-1feca7e9.js} +1 -1
  58. rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-29a02a19.js → sequenceDiagram-b655622a-070c61d2.js} +1 -1
  59. rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-042b3137.js → stateDiagram-59f0c015-24f46263.js} +1 -1
  60. rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-2178c0f3.js → stateDiagram-v2-2b26beab-c9056051.js} +1 -1
  61. rasa/core/channels/inspector/dist/assets/{styles-080da4f6-23ffa4fc.js → styles-080da4f6-08abc34a.js} +1 -1
  62. rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-94f59763.js → styles-3dcbcfbf-bc74c25a.js} +1 -1
  63. rasa/core/channels/inspector/dist/assets/{styles-9c745c82-78a6bebc.js → styles-9c745c82-4e5d66de.js} +1 -1
  64. rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-eae2a6f6.js → svgDrawCommon-4835440b-849c4517.js} +1 -1
  65. rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-5c968d92.js → timeline-definition-5b62e21b-d0fb1598.js} +1 -1
  66. rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-fd3db0d5.js → xychartDiagram-2b33534f-04d115e2.js} +1 -1
  67. rasa/core/channels/inspector/dist/index.html +1 -1
  68. rasa/core/channels/inspector/src/App.tsx +1 -1
  69. rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +3 -6
  70. rasa/core/channels/socketio.py +2 -7
  71. rasa/core/channels/telegram.py +1 -1
  72. rasa/core/channels/twilio.py +1 -1
  73. rasa/core/channels/voice_ready/audiocodes.py +4 -15
  74. rasa/core/channels/voice_ready/jambonz.py +4 -15
  75. rasa/core/channels/voice_ready/twilio_voice.py +21 -6
  76. rasa/core/channels/voice_ready/utils.py +5 -6
  77. rasa/core/channels/voice_stream/asr/asr_engine.py +1 -19
  78. rasa/core/channels/voice_stream/asr/asr_event.py +0 -5
  79. rasa/core/channels/voice_stream/asr/deepgram.py +15 -28
  80. rasa/core/channels/voice_stream/audio_bytes.py +0 -1
  81. rasa/core/channels/voice_stream/tts/azure.py +3 -9
  82. rasa/core/channels/voice_stream/tts/cartesia.py +8 -12
  83. rasa/core/channels/voice_stream/tts/tts_engine.py +1 -11
  84. rasa/core/channels/voice_stream/twilio_media_streams.py +19 -28
  85. rasa/core/channels/voice_stream/util.py +4 -4
  86. rasa/core/channels/voice_stream/voice_channel.py +42 -222
  87. rasa/core/featurizers/single_state_featurizer.py +1 -22
  88. rasa/core/featurizers/tracker_featurizers.py +18 -115
  89. rasa/core/information_retrieval/qdrant.py +0 -1
  90. rasa/core/nlg/contextual_response_rephraser.py +25 -44
  91. rasa/core/persistor.py +34 -191
  92. rasa/core/policies/enterprise_search_policy.py +60 -119
  93. rasa/core/policies/flows/flow_executor.py +4 -7
  94. rasa/core/policies/intentless_policy.py +22 -82
  95. rasa/core/policies/ted_policy.py +33 -58
  96. rasa/core/policies/unexpected_intent_policy.py +7 -15
  97. rasa/core/processor.py +13 -89
  98. rasa/core/run.py +2 -2
  99. rasa/core/training/interactive.py +35 -34
  100. rasa/core/utils.py +22 -58
  101. rasa/dialogue_understanding/coexistence/llm_based_router.py +12 -39
  102. rasa/dialogue_understanding/commands/__init__.py +0 -4
  103. rasa/dialogue_understanding/commands/change_flow_command.py +0 -6
  104. rasa/dialogue_understanding/commands/utils.py +0 -5
  105. rasa/dialogue_understanding/generator/constants.py +0 -2
  106. rasa/dialogue_understanding/generator/flow_retrieval.py +4 -49
  107. rasa/dialogue_understanding/generator/llm_based_command_generator.py +23 -37
  108. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +10 -57
  109. rasa/dialogue_understanding/generator/nlu_command_adapter.py +1 -19
  110. rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +0 -3
  111. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +10 -90
  112. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +0 -53
  113. rasa/dialogue_understanding/processor/command_processor.py +1 -21
  114. rasa/e2e_test/assertions.py +16 -133
  115. rasa/e2e_test/assertions_schema.yml +0 -23
  116. rasa/e2e_test/e2e_test_case.py +6 -85
  117. rasa/e2e_test/e2e_test_runner.py +4 -6
  118. rasa/e2e_test/utils/io.py +1 -3
  119. rasa/engine/loader.py +0 -12
  120. rasa/engine/validation.py +11 -541
  121. rasa/keys +1 -0
  122. rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
  123. rasa/model_training.py +7 -29
  124. rasa/nlu/classifiers/diet_classifier.py +25 -38
  125. rasa/nlu/classifiers/logistic_regression_classifier.py +9 -22
  126. rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
  127. rasa/nlu/extractors/crf_entity_extractor.py +50 -93
  128. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +16 -45
  129. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +17 -52
  130. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
  131. rasa/nlu/tokenizers/whitespace_tokenizer.py +14 -3
  132. rasa/server.py +1 -3
  133. rasa/shared/constants.py +0 -61
  134. rasa/shared/core/constants.py +0 -9
  135. rasa/shared/core/domain.py +5 -8
  136. rasa/shared/core/flows/flow.py +0 -5
  137. rasa/shared/core/flows/flows_list.py +1 -5
  138. rasa/shared/core/flows/flows_yaml_schema.json +0 -10
  139. rasa/shared/core/flows/validation.py +0 -96
  140. rasa/shared/core/flows/yaml_flows_io.py +4 -13
  141. rasa/shared/core/slots.py +0 -5
  142. rasa/shared/importers/importer.py +2 -19
  143. rasa/shared/importers/rasa.py +1 -5
  144. rasa/shared/nlu/training_data/features.py +2 -120
  145. rasa/shared/nlu/training_data/formats/rasa_yaml.py +3 -18
  146. rasa/shared/providers/_configs/azure_openai_client_config.py +3 -5
  147. rasa/shared/providers/_configs/openai_client_config.py +1 -1
  148. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +0 -1
  149. rasa/shared/providers/_configs/utils.py +0 -16
  150. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +29 -18
  151. rasa/shared/providers/embedding/azure_openai_embedding_client.py +21 -54
  152. rasa/shared/providers/embedding/default_litellm_embedding_client.py +0 -24
  153. rasa/shared/providers/llm/_base_litellm_client.py +31 -63
  154. rasa/shared/providers/llm/azure_openai_llm_client.py +29 -50
  155. rasa/shared/providers/llm/default_litellm_llm_client.py +0 -24
  156. rasa/shared/providers/llm/self_hosted_llm_client.py +29 -17
  157. rasa/shared/providers/mappings.py +0 -19
  158. rasa/shared/utils/common.py +2 -37
  159. rasa/shared/utils/io.py +6 -28
  160. rasa/shared/utils/llm.py +46 -353
  161. rasa/shared/utils/yaml.py +82 -181
  162. rasa/studio/auth.py +5 -3
  163. rasa/studio/config.py +4 -13
  164. rasa/studio/constants.py +0 -1
  165. rasa/studio/data_handler.py +4 -13
  166. rasa/studio/upload.py +80 -175
  167. rasa/telemetry.py +17 -94
  168. rasa/tracing/config.py +1 -3
  169. rasa/tracing/instrumentation/attribute_extractors.py +17 -94
  170. rasa/tracing/instrumentation/instrumentation.py +0 -121
  171. rasa/utils/common.py +0 -5
  172. rasa/utils/endpoints.py +1 -27
  173. rasa/utils/io.py +81 -7
  174. rasa/utils/log_utils.py +2 -9
  175. rasa/utils/tensorflow/model_data.py +193 -2
  176. rasa/validator.py +4 -110
  177. rasa/version.py +1 -1
  178. rasa_pro-3.11.0a1.dist-info/METADATA +576 -0
  179. {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a1.dist-info}/RECORD +182 -216
  180. rasa/core/actions/action_repeat_bot_messages.py +0 -89
  181. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +0 -1
  182. rasa/core/channels/inspector/src/helpers/audiostream.ts +0 -165
  183. rasa/core/channels/voice_stream/asr/azure.py +0 -129
  184. rasa/core/channels/voice_stream/browser_audio.py +0 -107
  185. rasa/core/channels/voice_stream/call_state.py +0 -23
  186. rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +0 -60
  187. rasa/dialogue_understanding/commands/user_silence_command.py +0 -59
  188. rasa/dialogue_understanding/patterns/repeat.py +0 -37
  189. rasa/dialogue_understanding/patterns/user_silence.py +0 -37
  190. rasa/model_manager/__init__.py +0 -0
  191. rasa/model_manager/config.py +0 -40
  192. rasa/model_manager/model_api.py +0 -559
  193. rasa/model_manager/runner_service.py +0 -286
  194. rasa/model_manager/socket_bridge.py +0 -146
  195. rasa/model_manager/studio_jwt_auth.py +0 -86
  196. rasa/model_manager/trainer_service.py +0 -325
  197. rasa/model_manager/utils.py +0 -87
  198. rasa/model_manager/warm_rasa_process.py +0 -187
  199. rasa/model_service.py +0 -112
  200. rasa/shared/core/flows/utils.py +0 -39
  201. rasa/shared/providers/_configs/litellm_router_client_config.py +0 -220
  202. rasa/shared/providers/_configs/model_group_config.py +0 -167
  203. rasa/shared/providers/_configs/rasa_llm_client_config.py +0 -73
  204. rasa/shared/providers/_utils.py +0 -79
  205. rasa/shared/providers/embedding/litellm_router_embedding_client.py +0 -135
  206. rasa/shared/providers/llm/litellm_router_llm_client.py +0 -182
  207. rasa/shared/providers/llm/rasa_llm_client.py +0 -112
  208. rasa/shared/providers/router/__init__.py +0 -0
  209. rasa/shared/providers/router/_base_litellm_router_client.py +0 -183
  210. rasa/shared/providers/router/router_client.py +0 -73
  211. rasa/shared/utils/health_check/__init__.py +0 -0
  212. rasa/shared/utils/health_check/embeddings_health_check_mixin.py +0 -31
  213. rasa/shared/utils/health_check/health_check.py +0 -258
  214. rasa/shared/utils/health_check/llm_health_check_mixin.py +0 -31
  215. rasa/utils/sanic_error_handler.py +0 -32
  216. rasa/utils/tensorflow/feature_array.py +0 -366
  217. rasa_pro-3.11.0.dist-info/METADATA +0 -198
  218. {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a1.dist-info}/NOTICE +0 -0
  219. {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a1.dist-info}/WHEEL +0 -0
  220. {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a1.dist-info}/entry_points.txt +0 -0
rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb ADDED
@@ -0,0 +1,407 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Fine-tuning a Hugging Face base model using Unsloth and TRL\n",
+ "\n",
+ "This is a worked example of how to efficiently fine-tune a base language model from [Hugging Face Hub](https://huggingface.co/models) using the [Unsloth](https://docs.unsloth.ai) and [TRL](https://huggingface.co/docs/trl/en/index) libraries on an instruction-based task.\n",
+ "\n",
+ "Unsloth integrates with TRL in order to reduce the time and GPU memory required to fine-tune LLMs, when compared to using TRL exclusively.\n",
+ "\n",
+ "To run fine-tuning, you must have first [generated the dataset](https://rasa.com/rasa-pro/docs/operating/fine-tuning-recipe) files `train.jsonl` and `val.jsonl`, which must be in the [TRL instruction format](https://huggingface.co/docs/trl/en/sft_trainer#dataset-format-support)."
+ ]
+ },
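For reference, each line of `train.jsonl` and `val.jsonl` in the TRL instruction format is a single JSON object with `prompt` and `completion` keys, roughly like the following (illustrative values, not taken from the Rasa-generated files):

```json
{"prompt": "Here is the conversation so far ...", "completion": "StartFlow(transfer_money)"}
{"prompt": "Here is the conversation so far ...", "completion": "SetSlot(amount, 50)"}
```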
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. Configure fine-tuning environment\n",
+ "\n",
+ "This notebook has been tested on a [GCP Vertex AI Workbench instance](https://cloud.google.com/vertex-ai/docs/workbench/instances/introduction) with machine type `a2-highgpu-1g` in the `asia-southeast1-b` zone, which had the following hardware:\n",
+ "- Single NVIDIA A100 GPU (40GB VRAM)\n",
+ "- 12-core CPU with 85GB RAM\n",
+ "- 256GB disk\n",
+ "\n",
+ "It has also been tested on an [AWS SageMaker Notebook instance](https://docs.aws.amazon.com/sagemaker/latest/dg/nbi.html) with machine type `ml.g4dn.4xlarge` in the `eu-central-1` region, which had the following hardware:\n",
+ "- Single NVIDIA T4 GPU (16GB VRAM)\n",
+ "- 16-core CPU with 64GB RAM\n",
+ "- 256GB disk\n",
+ "\n",
+ "In both setups, the notebook was executed in a Linux environment with the following software already installed:\n",
+ "- Python 3.10\n",
+ "- CUDA Toolkit 12.1\n",
+ "- PyTorch 2.2\n",
+ "\n",
+ "To run fine-tuning yourself, upload this notebook along with your dataset files to your own environment.\n",
+ "\n",
+ "It is highly recommended that you restart your notebook kernel and re-run all notebook cells every time you wish to perform fine-tuning.\n",
+ "\n",
+ "> Although this notebook will work with a relatively underpowered GPU, such as the NVIDIA T4, LLM fine-tuning and inference will be very slow on such hardware.\n",
+ ">\n",
+ "> It is highly recommended that you use an NVIDIA A100 or a similar GPU.\n",
+ ">\n",
+ "> The code presented here has been configured for use with an NVIDIA A100; please take note of later comments on the changes you should make when using a different type of GPU.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. Install Python requirements\n",
+ "\n",
+ "The following `pip` commands will install Unsloth and other required Python packages when [Conda](https://anaconda.org/anaconda/conda) is used to manage your Python environment, which is the case for many Jupyter notebook runtimes.\n",
+ "\n",
+ "If Conda is not used in your environment, please follow [these alternative instructions](https://github.com/unslothai/unsloth?tab=readme-ov-file#pip-installation) for installing Unsloth."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%sh\n",
+ "# install unsloth and other dependencies\n",
+ "pip install \"unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git\"\n",
+ "pip install --no-deps \"xformers<=0.0.26\" \"trl<0.9.0\" peft accelerate bitsandbytes huggingface_hub[cli]\n",
+ "# remove tpu-only package that is installed by default on gcp runtimes, even when only using gpu\n",
+ "pip uninstall torch-xla -y"
+ ]
+ },
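Before moving on, it can be worth a quick sanity check that the environment actually exposes a CUDA GPU to PyTorch; a minimal snippet using standard PyTorch calls (not part of the original notebook):

```python
import torch

# confirm a CUDA-capable GPU is visible before attempting fine-tuning
assert torch.cuda.is_available(), "no CUDA device found"
print(torch.cuda.get_device_name(0))  # e.g. "NVIDIA A100-SXM4-40GB"
print(torch.version.cuda)             # CUDA version PyTorch was built against
```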
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. Download base model\n",
+ "\n",
+ "You can download the model you want to fine-tune from Hugging Face Hub using the [official CLI](https://huggingface.co/docs/huggingface_hub/en/guides/cli) with an [API access token](https://huggingface.co/docs/transformers.js/en/guides/private#step-1-generating-a-user-access-token) as per the code below. Make sure you first update the `HUGGINGFACE_TOKEN` and `BASE_MODEL` environment variables with your own values.\n",
+ "\n",
+ "When testing this notebook, the [Llama 3.1 8B Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) model was used. Note that `meta-llama/Meta-Llama-3.1-8B-Instruct` is a [gated model](https://huggingface.co/docs/hub/en/models-gated) that you must first request access to before using.\n",
+ "\n",
+ "You can use any other PyTorch model available on [Hugging Face Hub](https://huggingface.co/models). It is recommended that you use a model that has been pre-trained on instructional tasks, such as the [CodeLlama 13B Instruct](https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf) model.\n",
+ "\n",
+ "Pre-trained models with more parameters will generally perform better at tasks than models with fewer parameters. However, the size of model you can use is limited by how much memory your GPU has.\n",
+ "\n",
+ "Alternatively, if you already have a PyTorch model directory to hand, you can upload it to your notebook environment manually."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%sh\n",
+ "# TODO: update with your values\n",
+ "export HUGGINGFACE_TOKEN=\"CHANGEME\"\n",
+ "export BASE_MODEL=\"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n",
+ "\n",
+ "# download model\n",
+ "huggingface-cli download \"${BASE_MODEL}\" \\\n",
+ " --token \"${HUGGINGFACE_TOKEN}\" \\\n",
+ " --local-dir \"./base_model\" \\\n",
+ " --exclude \"*.bin*\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 4. Load and quantize base model\n",
+ "\n",
+ "The [quantization of model parameters](https://huggingface.co/docs/optimum/en/concept_guides/quantization) can significantly reduce the GPU memory required to run model fine-tuning and inference, at the cost of model accuracy.\n",
+ "\n",
+ "Here, the base model is loaded from disk and quantized into an 8-bit representation on the fly using the [BitsAndBytes](https://huggingface.co/docs/transformers/main/en/quantization/bitsandbytes) library.\n",
+ "\n",
+ "If you are using a GPU with relatively little memory, such as the NVIDIA T4, or if you are using a base model larger than `meta-llama/Meta-Llama-3.1-8B-Instruct`, you may be required to use 4-bit quantization (e.g. `load_in_4bit = True`) in order to avoid \n",
+ "out of memory (OOM) errors."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import BitsAndBytesConfig\n",
+ "from unsloth import FastLanguageModel\n",
+ "\n",
+ "# configure quantization method for base model\n",
+ "quantization_config = BitsAndBytesConfig(\n",
+ " load_in_8bit=True,\n",
+ ")\n",
+ "\n",
+ "# load quantized model and tokenizer from disk\n",
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ " model_name=\"./base_model\",\n",
+ " max_seq_length=2048,\n",
+ " quantization_config=quantization_config,\n",
+ ")"
+ ]
+ },
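For the 4-bit fallback mentioned above, the quantization config would change roughly as follows (a sketch using standard `BitsAndBytesConfig` options; the nf4 and compute-dtype choices are common defaults, not values prescribed by this notebook):

```python
import torch
from transformers import BitsAndBytesConfig

# 4-bit alternative for smaller GPUs (e.g. NVIDIA T4) or larger base models
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",             # normalized-float-4 quantization
    bnb_4bit_compute_dtype=torch.float16,  # dtype used for matrix multiplications
)
```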
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 5. Load training and validation datasets\n",
+ "\n",
+ "The following code loads the training and validation datasets from the `train.jsonl` and `val.jsonl` files, respectively.\n",
+ "\n",
+ "As the files use the TRL instruction format, the TRL trainer used later will be able to [automatically parse](https://huggingface.co/docs/trl/en/sft_trainer#dataset-format-support) the datasets and [generate the prompts from a template](https://huggingface.co/docs/transformers/en/chat_templating) configured in the tokenizer.\n",
+ "\n",
+ "Prompt templates vary between models, and TRL will infer the correct template from your base model. If a template is not available in your base model, or if you wish to change it, you can set your own [template string](https://huggingface.co/docs/transformers/en/chat_templating#advanced-adding-and-editing-chat-templates).\n",
+ "\n",
+ "You can also define your own [prompt formatting function](https://huggingface.co/docs/trl/en/sft_trainer#format-your-input-prompts) in order to have full control over how the prompts are constructed."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import datasets\n",
+ "from trl.extras.dataset_formatting import get_formatting_func_from_dataset\n",
+ "\n",
+ "# load datasets from disk\n",
+ "train_dataset = datasets.load_dataset(\n",
+ " \"json\", data_files={\"train\": \"train.jsonl\"}, split=\"train\"\n",
+ ")\n",
+ "eval_dataset = datasets.load_dataset(\n",
+ " \"json\", data_files={\"eval\": \"val.jsonl\"}, split=\"eval\"\n",
+ ")\n",
+ "\n",
+ "# test prompt templating on example from dataset\n",
+ "print(get_formatting_func_from_dataset(train_dataset, tokenizer)(eval_dataset[0]))"
+ ]
+ },
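To make the two TRL dataset formats concrete: an instruction-format record maps onto the conversational format roughly as follows (an illustrative sketch of the conversion TRL performs internally, not code from the notebook):

```python
# {"prompt": ..., "completion": ...}  ->  {"messages": [...]}
def to_conversational(example: dict) -> dict:
    return {
        "messages": [
            {"role": "user", "content": example["prompt"]},
            {"role": "assistant", "content": example["completion"]},
        ]
    }
```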
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 6. Configure trainer\n",
+ "\n",
+ "Below, the arguments for the supervised fine-tuning (SFT) trainer are configured. Their values were chosen somewhat arbitrarily and produced satisfactory results during testing.\n",
+ "\n",
+ "It is recommended that you read the official documentation and experiment with the arguments passed to `SFTConfig` (see [here](https://huggingface.co/docs/trl/main/en/sft_trainer#trl.SFTConfig)) and `SFTTrainer` (see [here](https://huggingface.co/docs/trl/main/en/sft_trainer#trl.SFTTrainer)).\n",
+ "\n",
+ "For example:\n",
+ "- If you get an OOM error when running fine-tuning, you can reduce `per_device_train_batch_size` to reduce the memory footprint. However, if your GPU has sufficient memory, you can try increasing it to reduce the total number of training steps (see the gradient accumulation sketch after the next cell).\n",
+ "- Consider setting `max_steps`, as you may not need to perform all epochs to achieve optimal model accuracy. Conversely, you may see better model accuracy by increasing `num_train_epochs`.\n",
+ "- If fine-tuning is taking too long, you can increase `eval_steps` to reduce how often validation is performed."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from transformers import TrainingArguments\n",
+ "from trl import SFTTrainer\n",
+ "\n",
+ "# configure training args\n",
+ "args = TrainingArguments(\n",
+ " # training\n",
+ " per_device_train_batch_size=8,\n",
+ " warmup_steps=50,\n",
+ " num_train_epochs=4,\n",
+ " learning_rate=0.0001,\n",
+ " lr_scheduler_type=\"cosine\",\n",
+ " optim=\"adamw_bnb_8bit\",\n",
+ " weight_decay=0.0,\n",
+ " logging_steps=1,\n",
+ " # datatypes\n",
+ " fp16=not torch.cuda.is_bf16_supported(),\n",
+ " bf16=torch.cuda.is_bf16_supported(),\n",
+ " # evaluation\n",
+ " eval_strategy=\"steps\",\n",
+ " eval_steps=50,\n",
+ " per_device_eval_batch_size=8,\n",
+ " output_dir=\"outputs\",\n",
+ ")\n",
+ "\n",
+ "# setup trainer\n",
+ "trainer = SFTTrainer(\n",
+ " model=model,\n",
+ " tokenizer=tokenizer,\n",
+ " train_dataset=train_dataset,\n",
+ " eval_dataset=eval_dataset,\n",
+ " args=args,\n",
+ ")"
+ ]
+ },
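As noted in the bullet list above, one standard way to trade throughput for memory is to shrink the per-device batch size and compensate with gradient accumulation so the effective batch size stays the same; a sketch (`gradient_accumulation_steps` is a regular `TrainingArguments` option, its value here is illustrative):

```python
# effective batch size = per_device_train_batch_size * gradient_accumulation_steps
args = TrainingArguments(
    per_device_train_batch_size=2,  # small enough to fit on a 16GB GPU
    gradient_accumulation_steps=4,  # 2 * 4 keeps the effective batch size at 8
    # ... remaining arguments as in the cell above ...
    output_dir="outputs",
)
```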
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 7. Perform supervised fine-tuning\n",
+ "\n",
+ "In the code below, fine-tuning is performed using the previously configured trainer.\n",
+ "\n",
+ "When testing this step on an NVIDIA A100 using the configuration defined above, it took around 12 minutes to perform fine-tuning with a training dataset containing around 500 examples.\n",
+ "\n",
+ "After fine-tuning, the base model and fine-tuned adapters are [merged together and saved to disk](https://docs.unsloth.ai/basics/saving-models/saving-to-vllm) in 16-bit for future compatibility with the [vLLM](https://github.com/vllm-project/vllm) model serving library.\n",
+ "\n",
+ "If you are using a relatively small GPU, such as the NVIDIA T4, you may have to save the model in 4-bit instead (e.g. `save_method = \"merged_4bit_forced\"`); see the sketch after the next cell."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# run fine-tuning\n",
+ "finetune_metrics = trainer.train()\n",
+ "\n",
+ "# save model to disk in 16-bit\n",
+ "model.save_pretrained_merged(\n",
+ " \"./finetune_model\", tokenizer, save_method=\"merged_16bit\"\n",
+ ")"
+ ]
+ },
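The 4-bit saving fallback mentioned above uses the same Unsloth call with a different `save_method`:

```python
# fallback for smaller GPUs: save merged weights in 4-bit instead of 16-bit
model.save_pretrained_merged(
    "./finetune_model", tokenizer, save_method="merged_4bit_forced"
)
```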
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 8. Visualize fine-tuning metrics\n",
+ "\n",
+ "Some of the metrics collected during fine-tuning are visualised below so that you can diagnose any potential issues with the model.\n",
+ "\n",
+ "Specifically, the training and validation losses are plotted against the training step number. Please check the plot for the following:\n",
+ "- Ideally, as the fine-tuning steps increase, the training and validation losses should decrease and converge.\n",
+ "- If the loss curves do not converge, it may be worth performing more fine-tuning steps or epochs. This is known as [underfitting](https://www.ibm.com/topics/underfitting).\n",
+ "- If the validation loss suddenly starts to increase while the training loss continues to decrease or converge, you should decrease your total number of steps or epochs. This is known as [overfitting](https://www.ibm.com/topics/overfitting)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "# plot step against train and val losses\n",
+ "fig, ax = plt.subplots()\n",
+ "log_history = pd.DataFrame(trainer.state.log_history)\n",
+ "eval_loss = (\n",
+ " log_history[[\"step\", \"eval_loss\"]]\n",
+ " .dropna()\n",
+ " .plot(x=\"step\", ax=ax)\n",
+ ")\n",
+ "train_loss = (\n",
+ " log_history[[\"step\", \"loss\"]]\n",
+ " .dropna()\n",
+ " .plot(x=\"step\", ax=ax)\n",
+ ")\n",
+ "fig.show()"
+ ]
+ },
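If the plot shows overfitting, an alternative to manually lowering the step count is early stopping; a sketch using the standard `transformers` callback (note this also requires checkpointing to be aligned with evaluation, which the configuration above does not set up):

```python
from transformers import EarlyStoppingCallback

# requires in TrainingArguments: load_best_model_at_end=True,
# metric_for_best_model="eval_loss", save_strategy="steps",
# and save_steps set to a multiple of eval_steps
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=args,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)
```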
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 9. Run ad hoc inference\n",
+ "\n",
+ "You can load your fine-tuned model from disk using Unsloth and use it to run optimized inference on individual inputs of your choosing using the code below.\n",
+ "\n",
+ "Note that the inputs passed to the model are in the [TRL conversational format](https://huggingface.co/docs/trl/en/sft_trainer#dataset-format-support), as the Hugging Face [chat template requires them to be](https://huggingface.co/docs/transformers/main/en/chat_templating#how-do-i-use-chat-templates). During training, TRL will [automatically convert the instruction format to the conversational format](https://github.com/huggingface/trl/blob/main/trl/extras/dataset_formatting.py). However, you have to do this yourself when applying chat templates manually for inference."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import TextStreamer\n",
+ "from unsloth import FastLanguageModel\n",
+ "\n",
+ "model, tokenizer = FastLanguageModel.from_pretrained(\"./finetune_model\")\n",
+ "FastLanguageModel.for_inference(model) # enable inference optimizations\n",
+ "streamer = TextStreamer(tokenizer) # stream model outputs as they are generated\n",
+ "\n",
+ "# the content to include in the input prompt\n",
+ "# by default, a value from the validation dataset as an example\n",
+ "content = eval_dataset[\"prompt\"][0]\n",
+ "\n",
+ "# apply prompt template and tokenize\n",
+ "input_ids = tokenizer.apply_chat_template(\n",
+ " [{\"role\": \"user\", \"content\": content}], # in the TRL conversational format\n",
+ " tokenize=True,\n",
+ " add_generation_prompt=True,\n",
+ " return_tensors=\"pt\",\n",
+ ").to(\"cuda\")\n",
+ "\n",
+ "# generate model output from user input\n",
+ "_ = model.generate(\n",
+ " input_ids=input_ids,\n",
+ " streamer=streamer, # remove streamer if you want the whole output at the end\n",
+ " max_new_tokens=64, # set the limit on how many tokens are generated\n",
+ " do_sample=False, # disable random sampling for deterministic outputs\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 10. Export fine-tuned model\n",
+ "\n",
+ "Lastly, export your fine-tuned model directory to an appropriate storage location that can be easily accessed later for [deployment](https://rasa.com/rasa-pro/docs/building-assistants/self-hosted-llm).\n",
+ "\n",
+ "It is recommended that you use a cloud object store, such as [Amazon S3](https://aws.amazon.com/s3/) or [Google Cloud Storage](https://cloud.google.com/storage).\n",
+ "\n",
+ "Uncomment and run the corresponding commands below for your cloud provider, making sure to first update the environment variables with your own values. It is assumed that:\n",
+ "- your bucket already exists\n",
+ "- you have already installed the CLI tool for your cloud provider\n",
+ "- you have already authenticated with your cloud provider and have sufficient permissions to write to your bucket"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%sh\n",
+ "export LOCAL_MODEL_PATH=\"./finetune_model\"\n",
+ "\n",
+ "# if using amazon\n",
+ "# export S3_MODEL_URI=\"s3://CHANGEME\" # update with your value\n",
+ "# aws s3 cp \"${LOCAL_MODEL_PATH}\" \"${S3_MODEL_URI}\" --recursive\n",
+ "\n",
+ "# if using google\n",
+ "# export GCS_MODEL_URI=\"gs://CHANGEME\" # update with your value\n",
+ "# gsutil cp -r \"${LOCAL_MODEL_PATH}\" \"${GCS_MODEL_URI}\""
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }
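Once exported, the merged 16-bit directory from section 10 should be directly loadable by vLLM; a minimal smoke test with vLLM's offline Python API (assumes `vllm` is installed in the serving environment, which is separate from this notebook's requirements):

```python
from vllm import LLM, SamplingParams

# load the merged fine-tuned model and run a single generation
llm = LLM(model="./finetune_model")
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```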
rasa/model_training.py CHANGED
@@ -157,7 +157,6 @@ async def train(
  finetuning_epoch_fraction: float = 1.0,
  remote_storage: Optional[StorageType] = None,
  file_importer: Optional[TrainingDataImporter] = None,
- keep_local_model_copy: bool = False,
  ) -> TrainingResult:
  """Trains a Rasa model (Core and NLU).
 
@@ -183,8 +182,6 @@ async def train(
  use for storing the model.
  file_importer: Instance of `TrainingDataImporter` to use for training.
  If it is not provided, a new instance will be created.
- keep_local_model_copy: If `True` the model will be stored locally even if
- remote storage is configured.
 
  Returns:
  An instance of `TrainingResult`.
@@ -266,7 +263,6 @@ async def train(
  finetuning_epoch_fraction=finetuning_epoch_fraction,
  dry_run=dry_run,
  remote_storage=remote_storage,
- keep_local_model_copy=keep_local_model_copy,
  **(core_additional_arguments or {}),
  **(nlu_additional_arguments or {}),
  )
@@ -281,7 +277,6 @@ async def _train_graph(
  force_full_training: bool = False,
  dry_run: bool = False,
  remote_storage: Optional[StorageType] = None,
- keep_local_model_copy: bool = False,
  **kwargs: Any,
  ) -> TrainingResult:
  if model_to_finetune:
@@ -322,10 +317,6 @@ async def _train_graph(
  rasa.engine.validation.validate_coexistance_routing_setup(
  domain, model_configuration, flows
  )
- rasa.engine.validation.validate_model_group_configuration_setup()
- rasa.engine.validation.validate_model_client_configuration_setup_during_training_time(
- config
- )
  rasa.engine.validation.validate_flow_component_dependencies(
  flows, model_configuration
  )
@@ -348,7 +339,7 @@ async def _train_graph(
  )
  return _dry_run_result(fingerprint_status, force_full_training)
 
- model_name = determine_model_name(fixed_model_name, training_type)
+ model_name = _determine_model_name(fixed_model_name, training_type)
  full_model_path = Path(output_path, model_name)
 
  with telemetry.track_model_training(
@@ -363,8 +354,7 @@ async def _train_graph(
  )
  if remote_storage:
  push_model_to_remote_storage(full_model_path, remote_storage)
- if not keep_local_model_copy:
- full_model_path.unlink()
+ full_model_path.unlink()
  structlogger.info(
  "model_training.train.finished_training",
  event_info=(
@@ -396,14 +386,9 @@ def _create_model_storage(
  return model_storage
 
 
- def generate_random_model_name() -> str:
- time_format = "%Y%m%d-%H%M%S"
- return f"{time.strftime(time_format)}-{randomname.get_name()}"
-
-
- def determine_model_name(
+ def _determine_model_name(
  fixed_model_name: Optional[Text], training_type: TrainingType
- ) -> str:
+ ) -> Text:
  if fixed_model_name:
  if not fixed_model_name.endswith(".tar.gz"):
  return f"{fixed_model_name}.tar.gz"
@@ -413,7 +398,8 @@ def determine_model_name(
  if training_type in [TrainingType.CORE, TrainingType.NLU]:
  prefix = f"{training_type.model_type}-"
 
- return f"{prefix}{generate_random_model_name()}.tar.gz"
+ time_format = "%Y%m%d-%H%M%S"
+ return f"{prefix}{time.strftime(time_format)}-{randomname.get_name()}.tar.gz"
 
 
  async def train_core(
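For context, the reverted `_determine_model_name` helper above produces archive names of the following shape (illustrative values; the random suffix comes from the `randomname` package):

```python
# fixed_model_name="my-model"       ->  "my-model.tar.gz"
# NLU-only training, no fixed name  ->  "nlu-20240918-134512-<random-name>.tar.gz"
#                                        prefix, %Y%m%d-%H%M%S timestamp, randomname.get_name()
```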
@@ -425,7 +411,6 @@ async def train_core(
  additional_arguments: Optional[Dict] = None,
  model_to_finetune: Optional[Text] = None,
  finetuning_epoch_fraction: float = 1.0,
- keep_local_model_copy: bool = False,
  ) -> Optional[Text]:
  """Trains a Core model.
 
@@ -440,8 +425,6 @@ async def train_core(
  a directory in case the latest trained model should be used.
  finetuning_epoch_fraction: The fraction currently specified training epochs
  in the model configuration which should be used for finetuning.
- keep_local_model_copy: If `True` the model will be stored locally even if
- remote storage is configured.
 
  Returns:
  Path to the model archive.
@@ -497,7 +480,6 @@ async def train_core(
  model_to_finetune=model_to_finetune,
  fixed_model_name=fixed_model_name,
  finetuning_epoch_fraction=finetuning_epoch_fraction,
- keep_local_model_copy=keep_local_model_copy,
  **(additional_arguments or {}),
  )
  ).model
@@ -513,7 +495,6 @@ async def train_nlu(
  domain: Optional[Union[Domain, Text]] = None,
  model_to_finetune: Optional[Text] = None,
  finetuning_epoch_fraction: float = 1.0,
- keep_local_model_copy: bool = False,
  ) -> Optional[Text]:
  """Trains an NLU model.
 
@@ -531,8 +512,6 @@ async def train_nlu(
  a directory in case the latest trained model should be used.
  finetuning_epoch_fraction: The fraction currently specified training epochs
  in the model configuration which should be used for finetuning.
- keep_local_model_copy: If `True` the model will be stored locally even if
- remote storage is configured.
 
  Returns:
  Path to the model archive.
@@ -574,14 +553,13 @@ async def train_nlu(
  fixed_model_name=fixed_model_name,
  finetuning_epoch_fraction=finetuning_epoch_fraction,
  persist_nlu_training_data=persist_nlu_training_data,
- keep_local_model_copy=keep_local_model_copy,
  **(additional_arguments or {}),
  )
  ).model
 
 
  def push_model_to_remote_storage(model_path: Path, remote_storage: StorageType) -> None:
- """Push model to remote storage."""
+ """push model to remote storage"""
  from rasa.core.persistor import get_persistor
 
  persistor = get_persistor(remote_storage)
rasa/nlu/classifiers/diet_classifier.py CHANGED
@@ -1,17 +1,18 @@
  from __future__ import annotations
-
  import copy
  import logging
  from collections import defaultdict
  from pathlib import Path
- from typing import Any, Dict, List, Optional, Text, Tuple, Union, TypeVar, Type
+
+ from rasa.exceptions import ModelNotFound
+ from rasa.nlu.featurizers.featurizer import Featurizer
 
  import numpy as np
  import scipy.sparse
  import tensorflow as tf
 
- from rasa.exceptions import ModelNotFound
- from rasa.nlu.featurizers.featurizer import Featurizer
+ from typing import Any, Dict, List, Optional, Text, Tuple, Union, TypeVar, Type
+
  from rasa.engine.graph import ExecutionContext, GraphComponent
  from rasa.engine.recipes.default_recipe import DefaultV1Recipe
  from rasa.engine.storage.resource import Resource
@@ -19,21 +20,18 @@ from rasa.engine.storage.storage import ModelStorage
  from rasa.nlu.extractors.extractor import EntityExtractorMixin
  from rasa.nlu.classifiers.classifier import IntentClassifier
  import rasa.shared.utils.io
+ import rasa.utils.io as io_utils
  import rasa.nlu.utils.bilou_utils as bilou_utils
  from rasa.shared.constants import DIAGNOSTIC_DATA
  from rasa.nlu.extractors.extractor import EntityTagSpec
  from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
  from rasa.utils import train_utils
  from rasa.utils.tensorflow import rasa_layers
- from rasa.utils.tensorflow.feature_array import (
- FeatureArray,
- serialize_nested_feature_arrays,
- deserialize_nested_feature_arrays,
- )
  from rasa.utils.tensorflow.models import RasaModel, TransformerRasaModel
  from rasa.utils.tensorflow.model_data import (
  RasaModelData,
  FeatureSignature,
+ FeatureArray,
  )
  from rasa.nlu.constants import TOKENS_NAMES, DEFAULT_TRANSFORMER_SIZE
  from rasa.shared.nlu.constants import (
@@ -120,6 +118,7 @@ LABEL_SUB_KEY = IDS
 
  POSSIBLE_TAGS = [ENTITY_ATTRIBUTE_TYPE, ENTITY_ATTRIBUTE_ROLE, ENTITY_ATTRIBUTE_GROUP]
 
+
  DIETClassifierT = TypeVar("DIETClassifierT", bound="DIETClassifier")
 
 
@@ -1084,24 +1083,18 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
 
  self.model.save(str(tf_model_file))
 
- # save data example
- serialize_nested_feature_arrays(
- self._data_example,
- model_path / f"{file_name}.data_example.st",
- model_path / f"{file_name}.data_example_metadata.json",
- )
- # save label data
- serialize_nested_feature_arrays(
- dict(self._label_data.data) if self._label_data is not None else {},
- model_path / f"{file_name}.label_data.st",
- model_path / f"{file_name}.label_data_metadata.json",
+ io_utils.pickle_dump(
+ model_path / f"{file_name}.data_example.pkl", self._data_example
  )
-
- rasa.shared.utils.io.dump_obj_as_json_to_file(
- model_path / f"{file_name}.sparse_feature_sizes.json",
+ io_utils.pickle_dump(
+ model_path / f"{file_name}.sparse_feature_sizes.pkl",
  self._sparse_feature_sizes,
  )
- rasa.shared.utils.io.dump_obj_as_json_to_file(
+ io_utils.pickle_dump(
+ model_path / f"{file_name}.label_data.pkl",
+ dict(self._label_data.data) if self._label_data is not None else {},
+ )
+ io_utils.json_pickle(
  model_path / f"{file_name}.index_label_id_mapping.json",
  self.index_label_id_mapping,
  )
@@ -1190,22 +1183,15 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
  ]:
  file_name = cls.__name__
 
- # load data example
- data_example = deserialize_nested_feature_arrays(
- str(model_path / f"{file_name}.data_example.st"),
- str(model_path / f"{file_name}.data_example_metadata.json"),
+ data_example = io_utils.pickle_load(
+ model_path / f"{file_name}.data_example.pkl"
  )
- # load label data
- loaded_label_data = deserialize_nested_feature_arrays(
- str(model_path / f"{file_name}.label_data.st"),
- str(model_path / f"{file_name}.label_data_metadata.json"),
- )
- label_data = RasaModelData(data=loaded_label_data)
-
- sparse_feature_sizes = rasa.shared.utils.io.read_json_file(
- model_path / f"{file_name}.sparse_feature_sizes.json"
+ label_data = io_utils.pickle_load(model_path / f"{file_name}.label_data.pkl")
+ label_data = RasaModelData(data=label_data)
+ sparse_feature_sizes = io_utils.pickle_load(
+ model_path / f"{file_name}.sparse_feature_sizes.pkl"
  )
- index_label_id_mapping = rasa.shared.utils.io.read_json_file(
+ index_label_id_mapping = io_utils.json_unpickle(
  model_path / f"{file_name}.index_label_id_mapping.json"
  )
  entity_tag_specs = rasa.shared.utils.io.read_json_file(
@@ -1225,6 +1211,7 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
  for tag_spec in entity_tag_specs
  ]
 
+ # jsonpickle converts dictionary keys to strings
  index_label_id_mapping = {
  int(key): value for key, value in index_label_id_mapping.items()
  }
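For context on the persistence swap in the hunks above: the alpha build falls back to Rasa's pickle/jsonpickle helpers instead of safetensors-based serialization. Their round-trip usage looks roughly like this (helper names taken from the diff; treat the exact signatures as an assumption):

```python
from pathlib import Path
import rasa.utils.io as io_utils

# pickle round-trip, as used above for the data example and label data
io_utils.pickle_dump(Path("model/DIETClassifier.data_example.pkl"), {"text": []})
data_example = io_utils.pickle_load(Path("model/DIETClassifier.data_example.pkl"))

# jsonpickle round-trip; dictionary keys come back as strings,
# which is why load() re-casts them to int afterwards
io_utils.json_pickle(Path("model/mapping.json"), {0: "greet"})
mapping = io_utils.json_unpickle(Path("model/mapping.json"))
mapping = {int(key): value for key, value in mapping.items()}
```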