rasa-pro 3.10.16__py3-none-any.whl → 3.11.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +396 -17
- rasa/api.py +9 -3
- rasa/cli/arguments/default_arguments.py +23 -2
- rasa/cli/arguments/run.py +15 -0
- rasa/cli/arguments/train.py +3 -9
- rasa/cli/e2e_test.py +1 -1
- rasa/cli/evaluate.py +1 -1
- rasa/cli/inspect.py +8 -4
- rasa/cli/llm_fine_tuning.py +12 -15
- rasa/cli/run.py +8 -1
- rasa/cli/studio/studio.py +8 -18
- rasa/cli/train.py +11 -53
- rasa/cli/utils.py +8 -10
- rasa/cli/x.py +1 -1
- rasa/constants.py +1 -1
- rasa/core/actions/action.py +2 -0
- rasa/core/actions/action_hangup.py +29 -0
- rasa/core/agent.py +2 -2
- rasa/core/brokers/kafka.py +3 -1
- rasa/core/brokers/pika.py +3 -1
- rasa/core/channels/__init__.py +8 -6
- rasa/core/channels/channel.py +21 -4
- rasa/core/channels/development_inspector.py +143 -46
- rasa/core/channels/inspector/README.md +1 -1
- rasa/core/channels/inspector/dist/assets/{arc-b6e548fe.js → arc-86942a71.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-fa03ac9e.js → c4Diagram-d0fbc5ce-b0290676.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-ee67392a.js → classDiagram-936ed81e-f6405f6e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-9b283fae.js → classDiagram-v2-c3cb15f1-ef61ac77.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-8b6fcc2a.js → createText-62fc7601-f0411e58.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-22e77f4f.js → edges-f2ad444c-7dcc4f3b.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-60ffc87f.js → erDiagram-9d236eb7-e0c092d7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-9dd802e4.js → flowDb-1972c806-fba2e3ce.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-5fa1912f.js → flowDiagram-7ea5b25a-7a70b71a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-24a5f41a.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-622a1fd2.js → flowchart-elk-definition-abe16c3d-00a59b68.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-e285a63a.js → ganttDiagram-9b5ea136-293c91fa.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-f237bdca.js → gitGraphDiagram-99d0ae7c-07b2d68c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-4b03d70e.js → index-2c4b9a3b-bc959fbd.js} +1 -1
- rasa/core/channels/inspector/dist/assets/index-3a8a5a28.js +1317 -0
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-72a0fa5f.js → infoDiagram-736b4530-4a350f72.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-82218c41.js → journeyDiagram-df861f2b-af464fb7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-78cff630.js → layout-0071f036.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-5038b469.js → line-2f73cc83.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-c4fc4098.js → linear-f014b4cc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-c33c8ea6.js → mindmap-definition-beec6740-d2426fb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-a8d03059.js → pieDiagram-dbbf0591-776f01a2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-6a0e56b2.js → quadrantDiagram-4d7f4fd6-82e00b57.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-2dc7c7bd.js → requirementDiagram-6fc4c22a-ea13c6bb.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-2360fe39.js → sankeyDiagram-8f13d901-1feca7e9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-41b9f9ad.js → sequenceDiagram-b655622a-070c61d2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-0aad326f.js → stateDiagram-59f0c015-24f46263.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-9847d984.js → stateDiagram-v2-2b26beab-c9056051.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-564d890e.js → styles-080da4f6-08abc34a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-38957613.js → styles-3dcbcfbf-bc74c25a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-f0fc6921.js → styles-9c745c82-4e5d66de.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-ef3c5a77.js → svgDrawCommon-4835440b-849c4517.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-bf3e91c1.js → timeline-definition-5b62e21b-d0fb1598.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-4d4026c0.js → xychartDiagram-2b33534f-04d115e2.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +18 -17
- rasa/core/channels/inspector/index.html +17 -16
- rasa/core/channels/inspector/package.json +5 -1
- rasa/core/channels/inspector/src/App.tsx +117 -67
- rasa/core/channels/inspector/src/components/Chat.tsx +95 -0
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +11 -10
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +10 -25
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +1 -1
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +10 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +107 -41
- rasa/core/channels/inspector/src/helpers/utils.ts +92 -7
- rasa/core/channels/inspector/src/types.ts +21 -1
- rasa/core/channels/inspector/yarn.lock +94 -1
- rasa/core/channels/rest.py +51 -46
- rasa/core/channels/socketio.py +22 -0
- rasa/core/channels/{audiocodes.py → voice_ready/audiocodes.py} +110 -68
- rasa/core/channels/{voice_aware → voice_ready}/jambonz.py +11 -4
- rasa/core/channels/{voice_aware → voice_ready}/jambonz_protocol.py +57 -5
- rasa/core/channels/{twilio_voice.py → voice_ready/twilio_voice.py} +58 -7
- rasa/core/channels/{voice_aware → voice_ready}/utils.py +16 -0
- rasa/core/channels/voice_stream/asr/__init__.py +0 -0
- rasa/core/channels/voice_stream/asr/asr_engine.py +71 -0
- rasa/core/channels/voice_stream/asr/asr_event.py +13 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +77 -0
- rasa/core/channels/voice_stream/audio_bytes.py +7 -0
- rasa/core/channels/voice_stream/tts/__init__.py +0 -0
- rasa/core/channels/voice_stream/tts/azure.py +100 -0
- rasa/core/channels/voice_stream/tts/cartesia.py +114 -0
- rasa/core/channels/voice_stream/tts/tts_cache.py +27 -0
- rasa/core/channels/voice_stream/tts/tts_engine.py +48 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +164 -0
- rasa/core/channels/voice_stream/util.py +57 -0
- rasa/core/channels/voice_stream/voice_channel.py +247 -0
- rasa/core/featurizers/single_state_featurizer.py +1 -22
- rasa/core/featurizers/tracker_featurizers.py +18 -115
- rasa/core/nlg/contextual_response_rephraser.py +11 -2
- rasa/{nlu → core}/persistor.py +16 -38
- rasa/core/policies/enterprise_search_policy.py +12 -15
- rasa/core/policies/flows/flow_executor.py +8 -18
- rasa/core/policies/intentless_policy.py +10 -15
- rasa/core/policies/ted_policy.py +33 -58
- rasa/core/policies/unexpected_intent_policy.py +7 -15
- rasa/core/processor.py +13 -64
- rasa/core/run.py +11 -1
- rasa/core/secrets_manager/constants.py +4 -0
- rasa/core/secrets_manager/factory.py +8 -0
- rasa/core/secrets_manager/vault.py +11 -1
- rasa/core/training/interactive.py +1 -1
- rasa/core/utils.py +1 -11
- rasa/dialogue_understanding/coexistence/llm_based_router.py +10 -10
- rasa/dialogue_understanding/commands/__init__.py +2 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +0 -6
- rasa/dialogue_understanding/commands/session_end_command.py +61 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +0 -7
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -3
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -1
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +3 -28
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +1 -19
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +4 -37
- rasa/e2e_test/aggregate_test_stats_calculator.py +1 -11
- rasa/e2e_test/assertions.py +6 -48
- rasa/e2e_test/e2e_test_runner.py +6 -9
- rasa/e2e_test/utils/e2e_yaml_utils.py +1 -1
- rasa/e2e_test/utils/io.py +1 -3
- rasa/engine/graph.py +3 -10
- rasa/engine/recipes/config_files/default_config.yml +0 -3
- rasa/engine/recipes/default_recipe.py +0 -1
- rasa/engine/recipes/graph_recipe.py +0 -1
- rasa/engine/runner/dask.py +2 -2
- rasa/engine/storage/local_model_storage.py +12 -42
- rasa/engine/storage/storage.py +1 -5
- rasa/engine/validation.py +1 -78
- rasa/keys +1 -0
- rasa/model_training.py +13 -16
- rasa/nlu/classifiers/diet_classifier.py +25 -38
- rasa/nlu/classifiers/logistic_regression_classifier.py +9 -22
- rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
- rasa/nlu/extractors/crf_entity_extractor.py +50 -93
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +16 -45
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +17 -52
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
- rasa/server.py +1 -1
- rasa/shared/constants.py +3 -12
- rasa/shared/core/constants.py +4 -0
- rasa/shared/core/domain.py +101 -47
- rasa/shared/core/events.py +29 -0
- rasa/shared/core/flows/flows_list.py +20 -11
- rasa/shared/core/flows/validation.py +25 -0
- rasa/shared/core/flows/yaml_flows_io.py +3 -24
- rasa/shared/importers/importer.py +40 -39
- rasa/shared/importers/multi_project.py +23 -11
- rasa/shared/importers/rasa.py +7 -2
- rasa/shared/importers/remote_importer.py +196 -0
- rasa/shared/importers/utils.py +3 -1
- rasa/shared/nlu/training_data/features.py +2 -120
- rasa/shared/nlu/training_data/training_data.py +18 -19
- rasa/shared/providers/_configs/azure_openai_client_config.py +3 -5
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +1 -6
- rasa/shared/providers/llm/_base_litellm_client.py +11 -31
- rasa/shared/providers/llm/self_hosted_llm_client.py +3 -15
- rasa/shared/utils/common.py +3 -22
- rasa/shared/utils/io.py +0 -1
- rasa/shared/utils/llm.py +30 -27
- rasa/shared/utils/schemas/events.py +2 -0
- rasa/shared/utils/schemas/model_config.yml +0 -10
- rasa/shared/utils/yaml.py +44 -0
- rasa/studio/auth.py +5 -3
- rasa/studio/config.py +4 -13
- rasa/studio/constants.py +0 -1
- rasa/studio/data_handler.py +3 -10
- rasa/studio/upload.py +8 -17
- rasa/tracing/instrumentation/attribute_extractors.py +1 -1
- rasa/utils/io.py +66 -0
- rasa/utils/tensorflow/model_data.py +193 -2
- rasa/validator.py +0 -12
- rasa/version.py +1 -1
- rasa_pro-3.11.0a1.dist-info/METADATA +576 -0
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0a1.dist-info}/RECORD +181 -164
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +0 -1
- rasa/core/channels/inspector/dist/assets/index-a5d3e69d.js +0 -1040
- rasa/utils/tensorflow/feature_array.py +0 -366
- rasa_pro-3.10.16.dist-info/METADATA +0 -196
- /rasa/core/channels/{voice_aware → voice_ready}/__init__.py +0 -0
- /rasa/core/channels/{voice_native → voice_stream}/__init__.py +0 -0
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0a1.dist-info}/NOTICE +0 -0
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0a1.dist-info}/WHEEL +0 -0
- {rasa_pro-3.10.16.dist-info → rasa_pro-3.11.0a1.dist-info}/entry_points.txt +0 -0
rasa/engine/runner/dask.py
CHANGED
|
@@ -220,7 +220,7 @@ async def execute_dask_graph(dsk: Dict[str, Any], result: List[str]) -> Any:
|
|
|
220
220
|
# if start_state_from_dask fails, we will have something
|
|
221
221
|
# to pass to the final block.
|
|
222
222
|
state = {}
|
|
223
|
-
keyorder = dask.local.order(dsk)
|
|
223
|
+
keyorder = dask.local.order(dsk)
|
|
224
224
|
|
|
225
225
|
state = dask.local.start_state_from_dask(dsk, cache=cache, sortkey=keyorder.get) # type:ignore[no-untyped-call]
|
|
226
226
|
|
|
@@ -235,7 +235,7 @@ async def execute_dask_graph(dsk: Dict[str, Any], result: List[str]) -> Any:
|
|
|
235
235
|
# Notify task is running
|
|
236
236
|
state["running"].add(key)
|
|
237
237
|
|
|
238
|
-
dependencies = dask.local.get_dependencies(dsk, key)
|
|
238
|
+
dependencies = dask.local.get_dependencies(dsk, key)
|
|
239
239
|
# Prep args to send
|
|
240
240
|
data = {dep: state["cache"][dep] for dep in dependencies}
|
|
241
241
|
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import logging
|
|
3
4
|
import shutil
|
|
4
5
|
import sys
|
|
5
6
|
import tempfile
|
|
@@ -7,21 +8,19 @@ import uuid
|
|
|
7
8
|
from contextlib import contextmanager
|
|
8
9
|
from datetime import datetime
|
|
9
10
|
from pathlib import Path
|
|
10
|
-
from typing import Generator, Optional, Text, Tuple, Union
|
|
11
|
-
|
|
12
|
-
import structlog
|
|
13
11
|
from tarsafe import TarSafe
|
|
12
|
+
from typing import Generator, Optional, Text, Tuple, Union
|
|
14
13
|
|
|
15
|
-
import rasa.model
|
|
16
|
-
import rasa.shared.utils.io
|
|
17
14
|
import rasa.utils.common
|
|
15
|
+
import rasa.shared.utils.io
|
|
16
|
+
from rasa.engine.storage.storage import ModelMetadata, ModelStorage
|
|
18
17
|
from rasa.engine.graph import GraphModelConfiguration
|
|
19
18
|
from rasa.engine.storage.resource import Resource
|
|
20
|
-
from rasa.engine.storage.storage import ModelMetadata, ModelStorage
|
|
21
19
|
from rasa.exceptions import UnsupportedModelVersionError
|
|
22
20
|
from rasa.shared.core.domain import Domain
|
|
21
|
+
import rasa.model
|
|
23
22
|
|
|
24
|
-
|
|
23
|
+
logger = logging.getLogger(__name__)
|
|
25
24
|
|
|
26
25
|
# Paths within model archive
|
|
27
26
|
MODEL_ARCHIVE_COMPONENTS_DIR = "components"
|
|
@@ -87,14 +86,7 @@ class LocalModelStorage(ModelStorage):
|
|
|
87
86
|
cls._extract_archive_to_directory(
|
|
88
87
|
model_archive_path, temporary_directory_path
|
|
89
88
|
)
|
|
90
|
-
|
|
91
|
-
"local_model_storage.from_model_archive",
|
|
92
|
-
event_info=(
|
|
93
|
-
f"Extracted model '{model_archive_path}' to "
|
|
94
|
-
f"'{temporary_directory_path}'."
|
|
95
|
-
),
|
|
96
|
-
)
|
|
97
|
-
|
|
89
|
+
logger.debug(f"Extracted model to '{temporary_directory_path}'.")
|
|
98
90
|
cls._initialize_model_storage_from_model_archive(
|
|
99
91
|
temporary_directory_path, storage_path
|
|
100
92
|
)
|
|
@@ -150,10 +142,6 @@ class LocalModelStorage(ModelStorage):
|
|
|
150
142
|
temporary_directory: Path, storage_path: Path
|
|
151
143
|
) -> None:
|
|
152
144
|
for path in (temporary_directory / MODEL_ARCHIVE_COMPONENTS_DIR).glob("*"):
|
|
153
|
-
structlogger.debug(
|
|
154
|
-
"local_model_storage._initialize_model_storage_from_model_archive",
|
|
155
|
-
event_info=f"Moving '{path}' to '{storage_path}'.",
|
|
156
|
-
)
|
|
157
145
|
shutil.move(str(path), str(storage_path))
|
|
158
146
|
|
|
159
147
|
@staticmethod
|
|
@@ -167,10 +155,7 @@ class LocalModelStorage(ModelStorage):
|
|
|
167
155
|
@contextmanager
|
|
168
156
|
def write_to(self, resource: Resource) -> Generator[Path, None, None]:
|
|
169
157
|
"""Persists data for a resource (see parent class for full docstring)."""
|
|
170
|
-
|
|
171
|
-
"local_model_storage.write_to.resource_write_requested",
|
|
172
|
-
event_info=f"Resource '{resource.name}' was requested for writing.",
|
|
173
|
-
)
|
|
158
|
+
logger.debug(f"Resource '{resource.name}' was requested for writing.")
|
|
174
159
|
directory = self._directory_for_resource(resource)
|
|
175
160
|
|
|
176
161
|
if not directory.exists():
|
|
@@ -178,10 +163,7 @@ class LocalModelStorage(ModelStorage):
|
|
|
178
163
|
|
|
179
164
|
yield directory
|
|
180
165
|
|
|
181
|
-
|
|
182
|
-
"local_model_storage.write_to.resource_persisted",
|
|
183
|
-
event_info=f"Resource '{resource.name}' was persisted.",
|
|
184
|
-
)
|
|
166
|
+
logger.debug(f"Resource '{resource.name}' was persisted.")
|
|
185
167
|
|
|
186
168
|
def _directory_for_resource(self, resource: Resource) -> Path:
|
|
187
169
|
return self._storage_path / resource.name
|
|
@@ -189,10 +171,7 @@ class LocalModelStorage(ModelStorage):
|
|
|
189
171
|
@contextmanager
|
|
190
172
|
def read_from(self, resource: Resource) -> Generator[Path, None, None]:
|
|
191
173
|
"""Provides the data of a `Resource` (see parent class for full docstring)."""
|
|
192
|
-
|
|
193
|
-
"local_model_storage.read_from",
|
|
194
|
-
event_info=f"Resource '{resource.name}' was requested for reading.",
|
|
195
|
-
)
|
|
174
|
+
logger.debug(f"Resource '{resource.name}' was requested for reading.")
|
|
196
175
|
directory = self._directory_for_resource(resource)
|
|
197
176
|
|
|
198
177
|
if not directory.exists():
|
|
@@ -214,12 +193,7 @@ class LocalModelStorage(ModelStorage):
|
|
|
214
193
|
domain: Domain,
|
|
215
194
|
) -> ModelMetadata:
|
|
216
195
|
"""Creates model package (see parent class for full docstring)."""
|
|
217
|
-
|
|
218
|
-
"local_model_storage.create_model_package.started",
|
|
219
|
-
event_info=(
|
|
220
|
-
f"Start to created model " f"package for path '{model_archive_path}'.",
|
|
221
|
-
),
|
|
222
|
-
)
|
|
196
|
+
logger.debug(f"Start to created model package for path '{model_archive_path}'.")
|
|
223
197
|
|
|
224
198
|
with windows_safe_temporary_directory() as temp_dir:
|
|
225
199
|
temporary_directory = Path(temp_dir)
|
|
@@ -240,10 +214,7 @@ class LocalModelStorage(ModelStorage):
|
|
|
240
214
|
with TarSafe.open(model_archive_path, "w:gz") as tar:
|
|
241
215
|
tar.add(temporary_directory, arcname="")
|
|
242
216
|
|
|
243
|
-
|
|
244
|
-
"local_model_storage.create_model_package.finished",
|
|
245
|
-
event_info=f"Model package created in path '{model_archive_path}'.",
|
|
246
|
-
)
|
|
217
|
+
logger.debug(f"Model package created in path '{model_archive_path}'.")
|
|
247
218
|
|
|
248
219
|
return model_metadata
|
|
249
220
|
|
|
@@ -267,7 +238,6 @@ class LocalModelStorage(ModelStorage):
|
|
|
267
238
|
predict_schema=model_configuration.predict_schema,
|
|
268
239
|
training_type=model_configuration.training_type,
|
|
269
240
|
project_fingerprint=rasa.model.project_fingerprint(),
|
|
270
|
-
spaces=model_configuration.spaces,
|
|
271
241
|
language=model_configuration.language,
|
|
272
242
|
core_target=model_configuration.core_target,
|
|
273
243
|
nlu_target=model_configuration.nlu_target,
|
rasa/engine/storage/storage.py
CHANGED
|
@@ -6,7 +6,7 @@ from contextlib import contextmanager
|
|
|
6
6
|
from dataclasses import dataclass
|
|
7
7
|
from datetime import datetime
|
|
8
8
|
from pathlib import Path
|
|
9
|
-
from typing import
|
|
9
|
+
from typing import Tuple, Union, Text, Generator, Dict, Any, Optional
|
|
10
10
|
from packaging import version
|
|
11
11
|
|
|
12
12
|
from rasa.constants import MINIMUM_COMPATIBLE_VERSION
|
|
@@ -140,7 +140,6 @@ class ModelMetadata:
|
|
|
140
140
|
core_target: Optional[Text]
|
|
141
141
|
nlu_target: Text
|
|
142
142
|
language: Optional[Text]
|
|
143
|
-
spaces: Optional[List[Dict[Text, Any]]] = None
|
|
144
143
|
training_type: TrainingType = TrainingType.BOTH
|
|
145
144
|
|
|
146
145
|
def __post_init__(self) -> None:
|
|
@@ -170,7 +169,6 @@ class ModelMetadata:
|
|
|
170
169
|
"core_target": self.core_target,
|
|
171
170
|
"nlu_target": self.nlu_target,
|
|
172
171
|
"language": self.language,
|
|
173
|
-
"spaces": self.spaces,
|
|
174
172
|
}
|
|
175
173
|
|
|
176
174
|
@classmethod
|
|
@@ -198,6 +196,4 @@ class ModelMetadata:
|
|
|
198
196
|
core_target=serialized["core_target"],
|
|
199
197
|
nlu_target=serialized["nlu_target"],
|
|
200
198
|
language=serialized["language"],
|
|
201
|
-
# optional, since introduced later
|
|
202
|
-
spaces=serialized.get("spaces"),
|
|
203
199
|
)
|
rasa/engine/validation.py
CHANGED
|
@@ -16,7 +16,6 @@ from typing import (
|
|
|
16
16
|
Union,
|
|
17
17
|
TypeVar,
|
|
18
18
|
List,
|
|
19
|
-
Literal,
|
|
20
19
|
)
|
|
21
20
|
|
|
22
21
|
import structlog
|
|
@@ -35,7 +34,6 @@ from rasa.dialogue_understanding.coexistence.constants import (
|
|
|
35
34
|
from rasa.dialogue_understanding.generator import (
|
|
36
35
|
LLMBasedCommandGenerator,
|
|
37
36
|
)
|
|
38
|
-
from rasa.dialogue_understanding.generator.constants import FLOW_RETRIEVAL_KEY
|
|
39
37
|
from rasa.dialogue_understanding.patterns.chitchat import FLOW_PATTERN_CHITCHAT
|
|
40
38
|
from rasa.engine.constants import RESERVED_PLACEHOLDERS
|
|
41
39
|
from rasa.engine.exceptions import GraphSchemaValidationException
|
|
@@ -49,15 +47,7 @@ from rasa.engine.graph import (
|
|
|
49
47
|
from rasa.engine.storage.resource import Resource
|
|
50
48
|
from rasa.engine.storage.storage import ModelStorage
|
|
51
49
|
from rasa.engine.training.fingerprinting import Fingerprintable
|
|
52
|
-
from rasa.shared.constants import
|
|
53
|
-
DOCS_URL_GRAPH_COMPONENTS,
|
|
54
|
-
ROUTE_TO_CALM_SLOT,
|
|
55
|
-
API_TYPE_CONFIG_KEY,
|
|
56
|
-
VALID_PROVIDERS_FOR_API_TYPE_CONFIG_KEY,
|
|
57
|
-
PROVIDER_CONFIG_KEY,
|
|
58
|
-
LLM_CONFIG_KEY,
|
|
59
|
-
EMBEDDINGS_CONFIG_KEY,
|
|
60
|
-
)
|
|
50
|
+
from rasa.shared.constants import DOCS_URL_GRAPH_COMPONENTS, ROUTE_TO_CALM_SLOT
|
|
61
51
|
from rasa.shared.core.constants import ACTION_RESET_ROUTING, ACTION_TRIGGER_CHITCHAT
|
|
62
52
|
from rasa.shared.core.domain import Domain
|
|
63
53
|
from rasa.shared.core.flows import FlowsList, Flow
|
|
@@ -881,70 +871,3 @@ def validate_command_generator_setup(
|
|
|
881
871
|
) -> None:
|
|
882
872
|
schema = model_configuration.predict_schema
|
|
883
873
|
validate_command_generator_exclusivity(schema)
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
def validate_model_client_configuration_setup(config: Dict[str, Any]) -> None:
|
|
887
|
-
"""Validates the model client configuration setup.
|
|
888
|
-
|
|
889
|
-
Validation fails, if
|
|
890
|
-
- the LLM/embeddings provider is defined using 'api_type' key for providers other
|
|
891
|
-
than 'openai' or 'azure'
|
|
892
|
-
|
|
893
|
-
Args:
|
|
894
|
-
config: The config dictionary
|
|
895
|
-
"""
|
|
896
|
-
for outer_key in ["pipeline", "policies"]:
|
|
897
|
-
if outer_key not in config or config[outer_key] is None:
|
|
898
|
-
continue
|
|
899
|
-
|
|
900
|
-
for component_config in config[outer_key]:
|
|
901
|
-
for key in [LLM_CONFIG_KEY, EMBEDDINGS_CONFIG_KEY]:
|
|
902
|
-
validate_api_type_config_key_usage(component_config, key)
|
|
903
|
-
|
|
904
|
-
# as flow retrieval is not a component itself, we need to
|
|
905
|
-
# check it separately
|
|
906
|
-
if (
|
|
907
|
-
FLOW_RETRIEVAL_KEY in component_config
|
|
908
|
-
and EMBEDDINGS_CONFIG_KEY in component_config[FLOW_RETRIEVAL_KEY]
|
|
909
|
-
):
|
|
910
|
-
validate_api_type_config_key_usage(
|
|
911
|
-
component_config[FLOW_RETRIEVAL_KEY],
|
|
912
|
-
EMBEDDINGS_CONFIG_KEY,
|
|
913
|
-
component_config["name"] + "." + FLOW_RETRIEVAL_KEY,
|
|
914
|
-
)
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
def validate_api_type_config_key_usage(
|
|
918
|
-
component_config: Dict[str, Any],
|
|
919
|
-
key: Literal["llm", "embeddings"],
|
|
920
|
-
component_name: Optional[str] = None,
|
|
921
|
-
) -> None:
|
|
922
|
-
"""Validate the LLM/embeddings configuration of a component.
|
|
923
|
-
|
|
924
|
-
Validation fails, if
|
|
925
|
-
- the LLM/embeddings provider is defined using 'api_type' key for providers other
|
|
926
|
-
than 'openai' or 'azure'
|
|
927
|
-
|
|
928
|
-
Args:
|
|
929
|
-
component_config: The config of the component
|
|
930
|
-
key: either 'llm' or 'embeddings'
|
|
931
|
-
component_name: the name of the component
|
|
932
|
-
"""
|
|
933
|
-
if component_config is None or key not in component_config:
|
|
934
|
-
return
|
|
935
|
-
|
|
936
|
-
if API_TYPE_CONFIG_KEY in component_config[key]:
|
|
937
|
-
api_type = component_config[key][API_TYPE_CONFIG_KEY]
|
|
938
|
-
if api_type not in VALID_PROVIDERS_FOR_API_TYPE_CONFIG_KEY:
|
|
939
|
-
structlogger.error(
|
|
940
|
-
"validation.component.api_type_config_key_invalid",
|
|
941
|
-
event_info=(
|
|
942
|
-
f"You specified '{API_TYPE_CONFIG_KEY}: {api_type}' for "
|
|
943
|
-
f"'{component_name or component_config['name']}', which is not "
|
|
944
|
-
f"allowed. "
|
|
945
|
-
f"The '{API_TYPE_CONFIG_KEY}' key can only be used for the "
|
|
946
|
-
f"following providers: {VALID_PROVIDERS_FOR_API_TYPE_CONFIG_KEY}. "
|
|
947
|
-
f"For other providers, please use the '{PROVIDER_CONFIG_KEY}' key."
|
|
948
|
-
),
|
|
949
|
-
)
|
|
950
|
-
sys.exit(1)
|
rasa/keys
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"segment": "CcvVD1I68Nkkxrv93cIqv1twIwrwG8nz", "sentry": "a283f1fde04347b099c8d729109dd450@o251570"}
|
rasa/model_training.py
CHANGED
|
@@ -15,6 +15,7 @@ import rasa.shared.utils.common
|
|
|
15
15
|
import rasa.shared.utils.io
|
|
16
16
|
import rasa.utils.common
|
|
17
17
|
from rasa import telemetry
|
|
18
|
+
from rasa.core.persistor import StorageType
|
|
18
19
|
from rasa.engine.caching import LocalTrainingCache
|
|
19
20
|
from rasa.engine.recipes.recipe import Recipe
|
|
20
21
|
from rasa.engine.runner.dask import DaskGraphRunner
|
|
@@ -22,7 +23,6 @@ from rasa.engine.storage.local_model_storage import LocalModelStorage
|
|
|
22
23
|
from rasa.engine.storage.storage import ModelStorage
|
|
23
24
|
from rasa.engine.training.components import FingerprintStatus
|
|
24
25
|
from rasa.engine.training.graph_trainer import GraphTrainer
|
|
25
|
-
from rasa.nlu.persistor import RemoteStorageType, StorageType
|
|
26
26
|
from rasa.shared.core.domain import Domain
|
|
27
27
|
from rasa.shared.core.events import SlotSet
|
|
28
28
|
from rasa.shared.core.training_data.structures import StoryGraph
|
|
@@ -156,6 +156,7 @@ async def train(
|
|
|
156
156
|
model_to_finetune: Optional[Text] = None,
|
|
157
157
|
finetuning_epoch_fraction: float = 1.0,
|
|
158
158
|
remote_storage: Optional[StorageType] = None,
|
|
159
|
+
file_importer: Optional[TrainingDataImporter] = None,
|
|
159
160
|
) -> TrainingResult:
|
|
160
161
|
"""Trains a Rasa model (Core and NLU).
|
|
161
162
|
|
|
@@ -177,14 +178,18 @@ async def train(
|
|
|
177
178
|
a directory in case the latest trained model should be used.
|
|
178
179
|
finetuning_epoch_fraction: The fraction currently specified training epochs
|
|
179
180
|
in the model configuration which should be used for finetuning.
|
|
180
|
-
remote_storage:
|
|
181
|
+
remote_storage: Optional name of the remote storage to
|
|
182
|
+
use for storing the model.
|
|
183
|
+
file_importer: Instance of `TrainingDataImporter` to use for training.
|
|
184
|
+
If it is not provided, a new instance will be created.
|
|
181
185
|
|
|
182
186
|
Returns:
|
|
183
187
|
An instance of `TrainingResult`.
|
|
184
188
|
"""
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
189
|
+
if not file_importer:
|
|
190
|
+
file_importer = TrainingDataImporter.load_from_config(
|
|
191
|
+
config, domain, training_files, core_additional_arguments
|
|
192
|
+
)
|
|
188
193
|
|
|
189
194
|
stories = file_importer.get_stories()
|
|
190
195
|
flows = file_importer.get_flows()
|
|
@@ -312,7 +317,6 @@ async def _train_graph(
|
|
|
312
317
|
rasa.engine.validation.validate_coexistance_routing_setup(
|
|
313
318
|
domain, model_configuration, flows
|
|
314
319
|
)
|
|
315
|
-
rasa.engine.validation.validate_model_client_configuration_setup(config)
|
|
316
320
|
rasa.engine.validation.validate_flow_component_dependencies(
|
|
317
321
|
flows, model_configuration
|
|
318
322
|
)
|
|
@@ -351,25 +355,18 @@ async def _train_graph(
|
|
|
351
355
|
if remote_storage:
|
|
352
356
|
push_model_to_remote_storage(full_model_path, remote_storage)
|
|
353
357
|
full_model_path.unlink()
|
|
354
|
-
remote_storage_string = (
|
|
355
|
-
remote_storage.value
|
|
356
|
-
if isinstance(remote_storage, RemoteStorageType)
|
|
357
|
-
else remote_storage
|
|
358
|
-
)
|
|
359
358
|
structlogger.info(
|
|
360
359
|
"model_training.train.finished_training",
|
|
361
360
|
event_info=(
|
|
362
361
|
f"Your Rasa model {model_name} is trained "
|
|
363
|
-
f"and saved at remote storage provider "
|
|
364
|
-
f"'{remote_storage_string}'."
|
|
362
|
+
f"and saved at remote storage provider '{remote_storage}'."
|
|
365
363
|
),
|
|
366
364
|
)
|
|
367
365
|
else:
|
|
368
366
|
structlogger.info(
|
|
369
367
|
"model_training.train.finished_training",
|
|
370
368
|
event_info=(
|
|
371
|
-
f"Your Rasa model is trained and saved at "
|
|
372
|
-
f"'{full_model_path}'."
|
|
369
|
+
f"Your Rasa model is trained and saved at '{full_model_path}'."
|
|
373
370
|
),
|
|
374
371
|
)
|
|
375
372
|
|
|
@@ -563,7 +560,7 @@ async def train_nlu(
|
|
|
563
560
|
|
|
564
561
|
def push_model_to_remote_storage(model_path: Path, remote_storage: StorageType) -> None:
|
|
565
562
|
"""push model to remote storage"""
|
|
566
|
-
from rasa.
|
|
563
|
+
from rasa.core.persistor import get_persistor
|
|
567
564
|
|
|
568
565
|
persistor = get_persistor(remote_storage)
|
|
569
566
|
|
|
@@ -1,17 +1,18 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
|
|
3
2
|
import copy
|
|
4
3
|
import logging
|
|
5
4
|
from collections import defaultdict
|
|
6
5
|
from pathlib import Path
|
|
7
|
-
|
|
6
|
+
|
|
7
|
+
from rasa.exceptions import ModelNotFound
|
|
8
|
+
from rasa.nlu.featurizers.featurizer import Featurizer
|
|
8
9
|
|
|
9
10
|
import numpy as np
|
|
10
11
|
import scipy.sparse
|
|
11
12
|
import tensorflow as tf
|
|
12
13
|
|
|
13
|
-
from
|
|
14
|
-
|
|
14
|
+
from typing import Any, Dict, List, Optional, Text, Tuple, Union, TypeVar, Type
|
|
15
|
+
|
|
15
16
|
from rasa.engine.graph import ExecutionContext, GraphComponent
|
|
16
17
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
17
18
|
from rasa.engine.storage.resource import Resource
|
|
@@ -19,21 +20,18 @@ from rasa.engine.storage.storage import ModelStorage
|
|
|
19
20
|
from rasa.nlu.extractors.extractor import EntityExtractorMixin
|
|
20
21
|
from rasa.nlu.classifiers.classifier import IntentClassifier
|
|
21
22
|
import rasa.shared.utils.io
|
|
23
|
+
import rasa.utils.io as io_utils
|
|
22
24
|
import rasa.nlu.utils.bilou_utils as bilou_utils
|
|
23
25
|
from rasa.shared.constants import DIAGNOSTIC_DATA
|
|
24
26
|
from rasa.nlu.extractors.extractor import EntityTagSpec
|
|
25
27
|
from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
|
|
26
28
|
from rasa.utils import train_utils
|
|
27
29
|
from rasa.utils.tensorflow import rasa_layers
|
|
28
|
-
from rasa.utils.tensorflow.feature_array import (
|
|
29
|
-
FeatureArray,
|
|
30
|
-
serialize_nested_feature_arrays,
|
|
31
|
-
deserialize_nested_feature_arrays,
|
|
32
|
-
)
|
|
33
30
|
from rasa.utils.tensorflow.models import RasaModel, TransformerRasaModel
|
|
34
31
|
from rasa.utils.tensorflow.model_data import (
|
|
35
32
|
RasaModelData,
|
|
36
33
|
FeatureSignature,
|
|
34
|
+
FeatureArray,
|
|
37
35
|
)
|
|
38
36
|
from rasa.nlu.constants import TOKENS_NAMES, DEFAULT_TRANSFORMER_SIZE
|
|
39
37
|
from rasa.shared.nlu.constants import (
|
|
@@ -120,6 +118,7 @@ LABEL_SUB_KEY = IDS
|
|
|
120
118
|
|
|
121
119
|
POSSIBLE_TAGS = [ENTITY_ATTRIBUTE_TYPE, ENTITY_ATTRIBUTE_ROLE, ENTITY_ATTRIBUTE_GROUP]
|
|
122
120
|
|
|
121
|
+
|
|
123
122
|
DIETClassifierT = TypeVar("DIETClassifierT", bound="DIETClassifier")
|
|
124
123
|
|
|
125
124
|
|
|
@@ -1084,24 +1083,18 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
|
|
|
1084
1083
|
|
|
1085
1084
|
self.model.save(str(tf_model_file))
|
|
1086
1085
|
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
self._data_example,
|
|
1090
|
-
model_path / f"{file_name}.data_example.st",
|
|
1091
|
-
model_path / f"{file_name}.data_example_metadata.json",
|
|
1092
|
-
)
|
|
1093
|
-
# save label data
|
|
1094
|
-
serialize_nested_feature_arrays(
|
|
1095
|
-
dict(self._label_data.data) if self._label_data is not None else {},
|
|
1096
|
-
model_path / f"{file_name}.label_data.st",
|
|
1097
|
-
model_path / f"{file_name}.label_data_metadata.json",
|
|
1086
|
+
io_utils.pickle_dump(
|
|
1087
|
+
model_path / f"{file_name}.data_example.pkl", self._data_example
|
|
1098
1088
|
)
|
|
1099
|
-
|
|
1100
|
-
|
|
1101
|
-
model_path / f"{file_name}.sparse_feature_sizes.json",
|
|
1089
|
+
io_utils.pickle_dump(
|
|
1090
|
+
model_path / f"{file_name}.sparse_feature_sizes.pkl",
|
|
1102
1091
|
self._sparse_feature_sizes,
|
|
1103
1092
|
)
|
|
1104
|
-
|
|
1093
|
+
io_utils.pickle_dump(
|
|
1094
|
+
model_path / f"{file_name}.label_data.pkl",
|
|
1095
|
+
dict(self._label_data.data) if self._label_data is not None else {},
|
|
1096
|
+
)
|
|
1097
|
+
io_utils.json_pickle(
|
|
1105
1098
|
model_path / f"{file_name}.index_label_id_mapping.json",
|
|
1106
1099
|
self.index_label_id_mapping,
|
|
1107
1100
|
)
|
|
@@ -1190,22 +1183,15 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
|
|
|
1190
1183
|
]:
|
|
1191
1184
|
file_name = cls.__name__
|
|
1192
1185
|
|
|
1193
|
-
|
|
1194
|
-
|
|
1195
|
-
str(model_path / f"{file_name}.data_example.st"),
|
|
1196
|
-
str(model_path / f"{file_name}.data_example_metadata.json"),
|
|
1186
|
+
data_example = io_utils.pickle_load(
|
|
1187
|
+
model_path / f"{file_name}.data_example.pkl"
|
|
1197
1188
|
)
|
|
1198
|
-
|
|
1199
|
-
|
|
1200
|
-
|
|
1201
|
-
|
|
1202
|
-
)
|
|
1203
|
-
label_data = RasaModelData(data=loaded_label_data)
|
|
1204
|
-
|
|
1205
|
-
sparse_feature_sizes = rasa.shared.utils.io.read_json_file(
|
|
1206
|
-
model_path / f"{file_name}.sparse_feature_sizes.json"
|
|
1189
|
+
label_data = io_utils.pickle_load(model_path / f"{file_name}.label_data.pkl")
|
|
1190
|
+
label_data = RasaModelData(data=label_data)
|
|
1191
|
+
sparse_feature_sizes = io_utils.pickle_load(
|
|
1192
|
+
model_path / f"{file_name}.sparse_feature_sizes.pkl"
|
|
1207
1193
|
)
|
|
1208
|
-
index_label_id_mapping =
|
|
1194
|
+
index_label_id_mapping = io_utils.json_unpickle(
|
|
1209
1195
|
model_path / f"{file_name}.index_label_id_mapping.json"
|
|
1210
1196
|
)
|
|
1211
1197
|
entity_tag_specs = rasa.shared.utils.io.read_json_file(
|
|
@@ -1225,6 +1211,7 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
|
|
|
1225
1211
|
for tag_spec in entity_tag_specs
|
|
1226
1212
|
]
|
|
1227
1213
|
|
|
1214
|
+
# jsonpickle converts dictionary keys to strings
|
|
1228
1215
|
index_label_id_mapping = {
|
|
1229
1216
|
int(key): value for key, value in index_label_id_mapping.items()
|
|
1230
1217
|
}
|
|
@@ -1,21 +1,22 @@
|
|
|
1
1
|
from typing import Any, Text, Dict, List, Type, Tuple
|
|
2
2
|
|
|
3
|
+
import joblib
|
|
3
4
|
import structlog
|
|
4
5
|
from scipy.sparse import hstack, vstack, csr_matrix
|
|
5
6
|
from sklearn.exceptions import NotFittedError
|
|
6
7
|
from sklearn.linear_model import LogisticRegression
|
|
7
8
|
from sklearn.utils.validation import check_is_fitted
|
|
8
9
|
|
|
9
|
-
from rasa.engine.graph import ExecutionContext, GraphComponent
|
|
10
|
-
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
11
10
|
from rasa.engine.storage.resource import Resource
|
|
12
11
|
from rasa.engine.storage.storage import ModelStorage
|
|
12
|
+
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
13
|
+
from rasa.engine.graph import ExecutionContext, GraphComponent
|
|
13
14
|
from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
|
|
14
|
-
from rasa.nlu.classifiers.classifier import IntentClassifier
|
|
15
15
|
from rasa.nlu.featurizers.featurizer import Featurizer
|
|
16
|
-
from rasa.
|
|
17
|
-
from rasa.shared.nlu.training_data.message import Message
|
|
16
|
+
from rasa.nlu.classifiers.classifier import IntentClassifier
|
|
18
17
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
18
|
+
from rasa.shared.nlu.training_data.message import Message
|
|
19
|
+
from rasa.shared.nlu.constants import TEXT, INTENT
|
|
19
20
|
from rasa.utils.tensorflow.constants import RANKING_LENGTH
|
|
20
21
|
|
|
21
22
|
structlogger = structlog.get_logger()
|
|
@@ -183,11 +184,9 @@ class LogisticRegressionClassifier(IntentClassifier, GraphComponent):
|
|
|
183
184
|
|
|
184
185
|
def persist(self) -> None:
|
|
185
186
|
"""Persist this model into the passed directory."""
|
|
186
|
-
import skops.io as sio
|
|
187
|
-
|
|
188
187
|
with self._model_storage.write_to(self._resource) as model_dir:
|
|
189
|
-
path = model_dir / f"{self._resource.name}.
|
|
190
|
-
|
|
188
|
+
path = model_dir / f"{self._resource.name}.joblib"
|
|
189
|
+
joblib.dump(self.clf, path)
|
|
191
190
|
structlogger.debug(
|
|
192
191
|
"logistic_regression_classifier.persist",
|
|
193
192
|
event_info=f"Saved intent classifier to '{path}'.",
|
|
@@ -203,21 +202,9 @@ class LogisticRegressionClassifier(IntentClassifier, GraphComponent):
|
|
|
203
202
|
**kwargs: Any,
|
|
204
203
|
) -> "LogisticRegressionClassifier":
|
|
205
204
|
"""Loads trained component (see parent class for full docstring)."""
|
|
206
|
-
import skops.io as sio
|
|
207
|
-
|
|
208
205
|
try:
|
|
209
206
|
with model_storage.read_from(resource) as model_dir:
|
|
210
|
-
|
|
211
|
-
unknown_types = sio.get_untrusted_types(file=classifier_file)
|
|
212
|
-
|
|
213
|
-
if unknown_types:
|
|
214
|
-
structlogger.error(
|
|
215
|
-
f"Untrusted types found when loading {classifier_file}!",
|
|
216
|
-
unknown_types=unknown_types,
|
|
217
|
-
)
|
|
218
|
-
raise ValueError()
|
|
219
|
-
|
|
220
|
-
classifier = sio.load(classifier_file, trusted=unknown_types)
|
|
207
|
+
classifier = joblib.load(model_dir / f"{resource.name}.joblib")
|
|
221
208
|
component = cls(
|
|
222
209
|
config, execution_context.node_name, model_storage, resource
|
|
223
210
|
)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
|
|
3
2
|
import logging
|
|
3
|
+
from rasa.nlu.featurizers.dense_featurizer.dense_featurizer import DenseFeaturizer
|
|
4
4
|
import typing
|
|
5
5
|
import warnings
|
|
6
6
|
from typing import Any, Dict, List, Optional, Text, Tuple, Type
|
|
@@ -8,18 +8,18 @@ from typing import Any, Dict, List, Optional, Text, Tuple, Type
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
|
|
10
10
|
import rasa.shared.utils.io
|
|
11
|
+
import rasa.utils.io as io_utils
|
|
11
12
|
from rasa.engine.graph import GraphComponent, ExecutionContext
|
|
12
13
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
13
14
|
from rasa.engine.storage.resource import Resource
|
|
14
15
|
from rasa.engine.storage.storage import ModelStorage
|
|
15
|
-
from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
|
|
16
|
-
from rasa.nlu.classifiers.classifier import IntentClassifier
|
|
17
|
-
from rasa.nlu.featurizers.dense_featurizer.dense_featurizer import DenseFeaturizer
|
|
18
16
|
from rasa.shared.constants import DOCS_URL_TRAINING_DATA_NLU
|
|
17
|
+
from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
|
|
19
18
|
from rasa.shared.exceptions import RasaException
|
|
20
19
|
from rasa.shared.nlu.constants import TEXT
|
|
21
|
-
from rasa.
|
|
20
|
+
from rasa.nlu.classifiers.classifier import IntentClassifier
|
|
22
21
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
22
|
+
from rasa.shared.nlu.training_data.message import Message
|
|
23
23
|
from rasa.utils.tensorflow.constants import FEATURIZERS
|
|
24
24
|
|
|
25
25
|
logger = logging.getLogger(__name__)
|
|
@@ -266,20 +266,14 @@ class SklearnIntentClassifier(GraphComponent, IntentClassifier):
|
|
|
266
266
|
|
|
267
267
|
def persist(self) -> None:
|
|
268
268
|
"""Persist this model into the passed directory."""
|
|
269
|
-
import skops.io as sio
|
|
270
|
-
|
|
271
269
|
with self._model_storage.write_to(self._resource) as model_dir:
|
|
272
270
|
file_name = self.__class__.__name__
|
|
273
|
-
classifier_file_name = model_dir / f"{file_name}_classifier.
|
|
274
|
-
encoder_file_name = model_dir / f"{file_name}_encoder.
|
|
271
|
+
classifier_file_name = model_dir / f"{file_name}_classifier.pkl"
|
|
272
|
+
encoder_file_name = model_dir / f"{file_name}_encoder.pkl"
|
|
275
273
|
|
|
276
274
|
if self.clf and self.le:
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
280
|
-
encoder_file_name, list(self.le.classes_)
|
|
281
|
-
)
|
|
282
|
-
sio.dump(self.clf.best_estimator_, classifier_file_name)
|
|
275
|
+
io_utils.json_pickle(encoder_file_name, self.le.classes_)
|
|
276
|
+
io_utils.json_pickle(classifier_file_name, self.clf.best_estimator_)
|
|
283
277
|
|
|
284
278
|
@classmethod
|
|
285
279
|
def load(
|
|
@@ -292,36 +286,21 @@ class SklearnIntentClassifier(GraphComponent, IntentClassifier):
|
|
|
292
286
|
) -> SklearnIntentClassifier:
|
|
293
287
|
"""Loads trained component (see parent class for full docstring)."""
|
|
294
288
|
from sklearn.preprocessing import LabelEncoder
|
|
295
|
-
import skops.io as sio
|
|
296
289
|
|
|
297
290
|
try:
|
|
298
291
|
with model_storage.read_from(resource) as model_dir:
|
|
299
292
|
file_name = cls.__name__
|
|
300
|
-
classifier_file = model_dir / f"{file_name}_classifier.
|
|
293
|
+
classifier_file = model_dir / f"{file_name}_classifier.pkl"
|
|
301
294
|
|
|
302
295
|
if classifier_file.exists():
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
if unknown_types:
|
|
306
|
-
logger.error(
|
|
307
|
-
f"Untrusted types ({unknown_types}) found when "
|
|
308
|
-
f"loading {classifier_file}!"
|
|
309
|
-
)
|
|
310
|
-
raise ValueError()
|
|
311
|
-
else:
|
|
312
|
-
classifier = sio.load(classifier_file, trusted=unknown_types)
|
|
313
|
-
|
|
314
|
-
encoder_file = model_dir / f"{file_name}_encoder.json"
|
|
315
|
-
classes = rasa.shared.utils.io.read_json_file(encoder_file)
|
|
296
|
+
classifier = io_utils.json_unpickle(classifier_file)
|
|
316
297
|
|
|
298
|
+
encoder_file = model_dir / f"{file_name}_encoder.pkl"
|
|
299
|
+
classes = io_utils.json_unpickle(encoder_file)
|
|
317
300
|
encoder = LabelEncoder()
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
)
|
|
321
|
-
# convert list of strings (class labels) back to numpy array of
|
|
322
|
-
# strings
|
|
323
|
-
intent_classifier.transform_labels_str2num(classes)
|
|
324
|
-
return intent_classifier
|
|
301
|
+
encoder.classes_ = classes
|
|
302
|
+
|
|
303
|
+
return cls(config, model_storage, resource, classifier, encoder)
|
|
325
304
|
except ValueError:
|
|
326
305
|
logger.debug(
|
|
327
306
|
f"Failed to load '{cls.__name__}' from model storage. Resource "
|