rasa-pro 3.11.0a4.dev3__py3-none-any.whl → 3.11.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/__main__.py +22 -12
- rasa/api.py +1 -1
- rasa/cli/arguments/default_arguments.py +1 -2
- rasa/cli/arguments/shell.py +5 -1
- rasa/cli/e2e_test.py +1 -1
- rasa/cli/evaluate.py +8 -8
- rasa/cli/inspect.py +6 -4
- rasa/cli/llm_fine_tuning.py +1 -1
- rasa/cli/project_templates/calm/config.yml +5 -7
- rasa/cli/project_templates/calm/endpoints.yml +8 -0
- rasa/cli/project_templates/tutorial/config.yml +8 -5
- rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
- rasa/cli/project_templates/tutorial/domain.yml +14 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +7 -7
- rasa/cli/run.py +1 -1
- rasa/cli/scaffold.py +4 -2
- rasa/cli/studio/studio.py +18 -8
- rasa/cli/utils.py +5 -0
- rasa/cli/x.py +8 -8
- rasa/constants.py +1 -1
- rasa/core/actions/action_repeat_bot_messages.py +17 -0
- rasa/core/channels/channel.py +20 -0
- rasa/core/channels/inspector/dist/assets/{arc-6852c607.js → arc-bc141fb2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-acc952b2.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-848a7597.js → classDiagram-936ed81e-55366915.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-a73d3e68.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-e5ee049d.js → createText-62fc7601-b0ec81d6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-771e517e.js → edges-f2ad444c-6166330c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-aa347178.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-651fc57d.js → flowDb-1972c806-fca3bfe4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-ca67804f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-2dbc568d.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-25a65bd8.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-fdc7378d.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-6f1fd606.js → index-2c4b9a3b-f55afcdf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-efdd30c1.js → index-e7cef9de.js} +68 -68
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-cb1a041a.js → infoDiagram-736b4530-124d4a14.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-14609879.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-2490f52b.js → layout-b9885fb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-40186f1f.js → line-7c59abb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-08814e93.js → linear-4776f780.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-1a534584.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-72397b61.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-3bb0b6a3.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-57334f61.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-111e1297.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-10bcfe62.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-acaf7513.js → stateDiagram-59f0c015-042b3137.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-3ec2a235.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-62730289.js → styles-080da4f6-23ffa4fc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-5284ee76.js → styles-3dcbcfbf-94f59763.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-642435e3.js → styles-9c745c82-78a6bebc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-b250a350.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-c2b147ed.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-f92cfea9.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/App.tsx +1 -1
- rasa/core/channels/inspector/src/helpers/audiostream.ts +77 -16
- rasa/core/channels/socketio.py +2 -1
- rasa/core/channels/telegram.py +1 -1
- rasa/core/channels/twilio.py +1 -1
- rasa/core/channels/voice_ready/audiocodes.py +12 -0
- rasa/core/channels/voice_ready/jambonz.py +15 -4
- rasa/core/channels/voice_ready/twilio_voice.py +6 -21
- rasa/core/channels/voice_stream/asr/asr_event.py +5 -0
- rasa/core/channels/voice_stream/asr/azure.py +122 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +16 -6
- rasa/core/channels/voice_stream/audio_bytes.py +1 -0
- rasa/core/channels/voice_stream/browser_audio.py +31 -8
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/azure.py +6 -2
- rasa/core/channels/voice_stream/tts/cartesia.py +10 -6
- rasa/core/channels/voice_stream/tts/tts_engine.py +1 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +27 -18
- rasa/core/channels/voice_stream/util.py +4 -4
- rasa/core/channels/voice_stream/voice_channel.py +189 -39
- rasa/core/featurizers/single_state_featurizer.py +22 -1
- rasa/core/featurizers/tracker_featurizers.py +115 -18
- rasa/core/nlg/contextual_response_rephraser.py +32 -30
- rasa/core/persistor.py +86 -39
- rasa/core/policies/enterprise_search_policy.py +119 -60
- rasa/core/policies/flows/flow_executor.py +7 -4
- rasa/core/policies/intentless_policy.py +78 -22
- rasa/core/policies/ted_policy.py +58 -33
- rasa/core/policies/unexpected_intent_policy.py +15 -7
- rasa/core/processor.py +25 -0
- rasa/core/training/interactive.py +34 -35
- rasa/core/utils.py +8 -3
- rasa/dialogue_understanding/coexistence/llm_based_router.py +39 -12
- rasa/dialogue_understanding/commands/change_flow_command.py +6 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +5 -0
- rasa/dialogue_understanding/generator/constants.py +2 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +49 -4
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +37 -23
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +57 -10
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +19 -1
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +71 -11
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +39 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +21 -1
- rasa/e2e_test/e2e_test_case.py +85 -6
- rasa/e2e_test/e2e_test_runner.py +4 -2
- rasa/e2e_test/utils/io.py +1 -1
- rasa/engine/validation.py +316 -10
- rasa/model_manager/config.py +15 -3
- rasa/model_manager/model_api.py +15 -7
- rasa/model_manager/runner_service.py +8 -6
- rasa/model_manager/socket_bridge.py +6 -3
- rasa/model_manager/trainer_service.py +7 -5
- rasa/model_manager/utils.py +28 -7
- rasa/model_service.py +9 -2
- rasa/model_training.py +2 -0
- rasa/nlu/classifiers/diet_classifier.py +38 -25
- rasa/nlu/classifiers/logistic_regression_classifier.py +22 -9
- rasa/nlu/classifiers/sklearn_intent_classifier.py +37 -16
- rasa/nlu/extractors/crf_entity_extractor.py +93 -50
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -16
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +52 -17
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +5 -3
- rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
- rasa/server.py +3 -1
- rasa/shared/constants.py +36 -3
- rasa/shared/core/constants.py +7 -0
- rasa/shared/core/domain.py +26 -0
- rasa/shared/core/flows/flow.py +5 -0
- rasa/shared/core/flows/flows_list.py +5 -1
- rasa/shared/core/flows/flows_yaml_schema.json +10 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +96 -0
- rasa/shared/core/slots.py +5 -0
- rasa/shared/nlu/training_data/features.py +120 -2
- rasa/shared/providers/_configs/azure_openai_client_config.py +5 -3
- rasa/shared/providers/_configs/litellm_router_client_config.py +200 -0
- rasa/shared/providers/_configs/model_group_config.py +167 -0
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
- rasa/shared/providers/_configs/utils.py +16 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +18 -29
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
- rasa/shared/providers/llm/_base_litellm_client.py +37 -31
- rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
- rasa/shared/providers/llm/litellm_router_llm_client.py +127 -0
- rasa/shared/providers/llm/rasa_llm_client.py +112 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +1 -1
- rasa/shared/providers/mappings.py +19 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +149 -0
- rasa/shared/providers/router/router_client.py +73 -0
- rasa/shared/utils/common.py +8 -0
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +256 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/io.py +28 -6
- rasa/shared/utils/llm.py +353 -46
- rasa/shared/utils/yaml.py +111 -73
- rasa/studio/auth.py +3 -5
- rasa/studio/config.py +13 -4
- rasa/studio/constants.py +1 -0
- rasa/studio/data_handler.py +10 -3
- rasa/studio/upload.py +81 -26
- rasa/telemetry.py +92 -17
- rasa/tracing/config.py +2 -0
- rasa/tracing/instrumentation/attribute_extractors.py +94 -17
- rasa/tracing/instrumentation/instrumentation.py +121 -0
- rasa/utils/common.py +5 -0
- rasa/utils/io.py +7 -81
- rasa/utils/log_utils.py +9 -2
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/utils/tensorflow/feature_array.py +366 -0
- rasa/utils/tensorflow/model_data.py +2 -193
- rasa/validator.py +70 -0
- rasa/version.py +1 -1
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/METADATA +11 -10
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/RECORD +183 -163
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-587d82d8.js +0 -1
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/entry_points.txt +0 -0
rasa/model_manager/config.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
|
1
1
|
import sys
|
|
2
2
|
import os
|
|
3
3
|
|
|
4
|
+
DEFAULT_SERVER_BASE_WORKING_DIRECTORY = "working-data"
|
|
5
|
+
|
|
4
6
|
SERVER_BASE_WORKING_DIRECTORY = os.environ.get(
|
|
5
|
-
"RASA_MODEL_SERVER_BASE_DIRECTORY",
|
|
7
|
+
"RASA_MODEL_SERVER_BASE_DIRECTORY", DEFAULT_SERVER_BASE_WORKING_DIRECTORY
|
|
6
8
|
)
|
|
7
9
|
|
|
8
10
|
SERVER_BASE_URL = os.environ.get("RASA_MODEL_SERVER_BASE_URL", None)
|
|
@@ -14,7 +16,17 @@ SERVER_MODEL_REMOTE_STORAGE = os.environ.get("RASA_REMOTE_STORAGE", None)
|
|
|
14
16
|
# we will use the same python to run training / bots
|
|
15
17
|
RASA_PYTHON_PATH = sys.executable
|
|
16
18
|
|
|
17
|
-
# the max limit for parallel training
|
|
18
|
-
|
|
19
|
+
# the max limit for parallel training requests
|
|
20
|
+
DEFAULT_MAX_PARALLEL_TRAININGS = 10
|
|
21
|
+
|
|
22
|
+
MAX_PARALLEL_TRAININGS = os.getenv(
|
|
23
|
+
"MAX_PARALLEL_TRAININGS", DEFAULT_MAX_PARALLEL_TRAININGS
|
|
24
|
+
)
|
|
25
|
+
# the max limit for parallel running bots
|
|
26
|
+
DEFAULT_MAX_PARALLEL_BOT_RUNS = 10
|
|
27
|
+
|
|
28
|
+
MAX_PARALLEL_BOT_RUNS = os.getenv(
|
|
29
|
+
"MAX_PARALLEL_BOT_RUNS", DEFAULT_MAX_PARALLEL_BOT_RUNS
|
|
30
|
+
)
|
|
19
31
|
|
|
20
32
|
DEFAULT_SERVER_PATH_PREFIX = "talk"
|
rasa/model_manager/model_api.py
CHANGED
|
@@ -16,6 +16,7 @@ from rasa.model_manager.config import SERVER_BASE_URL
|
|
|
16
16
|
from rasa.constants import MODEL_ARCHIVE_EXTENSION
|
|
17
17
|
from rasa.model_manager.runner_service import (
|
|
18
18
|
BotSession,
|
|
19
|
+
BotSessionStatus,
|
|
19
20
|
fetch_remote_model_to_dir,
|
|
20
21
|
run_bot,
|
|
21
22
|
terminate_bot,
|
|
@@ -24,11 +25,13 @@ from rasa.model_manager.runner_service import (
|
|
|
24
25
|
from rasa.model_manager.socket_bridge import create_bridge_server
|
|
25
26
|
from rasa.model_manager.trainer_service import (
|
|
26
27
|
TrainingSession,
|
|
28
|
+
TrainingSessionStatus,
|
|
27
29
|
run_training,
|
|
28
30
|
terminate_training,
|
|
29
31
|
update_training_status,
|
|
30
32
|
)
|
|
31
33
|
from rasa.model_manager.utils import (
|
|
34
|
+
InvalidPathException,
|
|
32
35
|
get_logs_content,
|
|
33
36
|
logs_base_path,
|
|
34
37
|
models_base_path,
|
|
@@ -134,7 +137,8 @@ def internal_blueprint() -> Blueprint:
|
|
|
134
137
|
[
|
|
135
138
|
training
|
|
136
139
|
for training in trainings.values()
|
|
137
|
-
if training.status ==
|
|
140
|
+
if training.status == TrainingSessionStatus.RUNNING
|
|
141
|
+
and training.process.poll() is None
|
|
138
142
|
]
|
|
139
143
|
)
|
|
140
144
|
|
|
@@ -152,7 +156,7 @@ def internal_blueprint() -> Blueprint:
|
|
|
152
156
|
@bp.on_request # type: ignore[misc]
|
|
153
157
|
async def limit_parallel_bot_runs(request: Request) -> Any:
|
|
154
158
|
"""Limit the number of parallel bot runs."""
|
|
155
|
-
from rasa.model_manager.config import
|
|
159
|
+
from rasa.model_manager.config import MAX_PARALLEL_BOT_RUNS
|
|
156
160
|
|
|
157
161
|
if not request.url.endswith("/bot"):
|
|
158
162
|
return None
|
|
@@ -161,15 +165,15 @@ def internal_blueprint() -> Blueprint:
|
|
|
161
165
|
[
|
|
162
166
|
bot
|
|
163
167
|
for bot in running_bots.values()
|
|
164
|
-
if bot.status in {
|
|
168
|
+
if bot.status in {BotSessionStatus.RUNNING, BotSessionStatus.QUEUED}
|
|
165
169
|
]
|
|
166
170
|
)
|
|
167
171
|
|
|
168
|
-
if running_requests >= int(
|
|
172
|
+
if running_requests >= int(MAX_PARALLEL_BOT_RUNS):
|
|
169
173
|
return response.json(
|
|
170
174
|
{
|
|
171
175
|
"message": f"Too many parallel bot runs, above "
|
|
172
|
-
f"the limit of {
|
|
176
|
+
f"the limit of {MAX_PARALLEL_BOT_RUNS}. "
|
|
173
177
|
f"Retry later or increase your server's "
|
|
174
178
|
f"memory and CPU resources."
|
|
175
179
|
},
|
|
@@ -244,6 +248,8 @@ def internal_blueprint() -> Blueprint:
|
|
|
244
248
|
return json({"message": "Training id is required"}, status=400)
|
|
245
249
|
|
|
246
250
|
try:
|
|
251
|
+
# file deepcode ignore PT: path traversal is prevented
|
|
252
|
+
# by the `subpath` function found in the `rasa.model_manager.utils` module
|
|
247
253
|
training_session = run_training(
|
|
248
254
|
training_id=training_id,
|
|
249
255
|
assistant_id=assistant_id,
|
|
@@ -254,8 +260,10 @@ def internal_blueprint() -> Blueprint:
|
|
|
254
260
|
return json(
|
|
255
261
|
{"training_id": training_id, "model_name": training_session.model_name}
|
|
256
262
|
)
|
|
257
|
-
except
|
|
258
|
-
return json({"message": str(
|
|
263
|
+
except InvalidPathException as exc:
|
|
264
|
+
return json({"message": str(exc)}, status=403)
|
|
265
|
+
except Exception as exc:
|
|
266
|
+
return json({"message": str(exc)}, status=500)
|
|
259
267
|
|
|
260
268
|
@bp.get("/training/<training_id>")
|
|
261
269
|
async def get_training(request: Request, training_id: str) -> response.HTTPResponse:
|
|
@@ -87,7 +87,7 @@ async def is_bot_startup_finished(bot: BotSession) -> bool:
|
|
|
87
87
|
return False
|
|
88
88
|
|
|
89
89
|
|
|
90
|
-
def
|
|
90
|
+
def set_bot_status_to_stopped(bot: BotSession) -> None:
|
|
91
91
|
"""Set a bots state to stopped."""
|
|
92
92
|
structlogger.info(
|
|
93
93
|
"model_runner.bot_stopped",
|
|
@@ -97,7 +97,7 @@ def update_bot_to_stopped(bot: BotSession) -> None:
|
|
|
97
97
|
bot.status = BotSessionStatus.STOPPED
|
|
98
98
|
|
|
99
99
|
|
|
100
|
-
def
|
|
100
|
+
def set_bot_status_to_running(bot: BotSession) -> None:
|
|
101
101
|
"""Set a bots state to running."""
|
|
102
102
|
structlogger.info(
|
|
103
103
|
"model_runner.bot_running",
|
|
@@ -119,7 +119,9 @@ def get_open_port() -> int:
|
|
|
119
119
|
return port
|
|
120
120
|
|
|
121
121
|
|
|
122
|
-
def
|
|
122
|
+
def write_encoded_config_data_to_files(
|
|
123
|
+
encoded_configs: Dict[str, bytes], base_path: str
|
|
124
|
+
) -> None:
|
|
123
125
|
"""Write the encoded config data to files."""
|
|
124
126
|
for key, value in encoded_configs.items():
|
|
125
127
|
write_encoded_data_to_file(value, subpath(base_path, f"{key}.yml"))
|
|
@@ -155,7 +157,7 @@ def prepare_bot_directory(
|
|
|
155
157
|
dst=subpath(bot_base_path, "models"),
|
|
156
158
|
)
|
|
157
159
|
|
|
158
|
-
|
|
160
|
+
write_encoded_config_data_to_files(encoded_configs, bot_base_path)
|
|
159
161
|
|
|
160
162
|
|
|
161
163
|
def fetch_remote_model_to_dir(
|
|
@@ -253,9 +255,9 @@ def run_bot(
|
|
|
253
255
|
async def update_bot_status(bot: BotSession) -> None:
|
|
254
256
|
"""Update the status of a bot based on the process return code."""
|
|
255
257
|
if bot.has_died_recently():
|
|
256
|
-
|
|
258
|
+
set_bot_status_to_stopped(bot)
|
|
257
259
|
elif await bot.completed_startup_recently():
|
|
258
|
-
|
|
260
|
+
set_bot_status_to_running(bot)
|
|
259
261
|
|
|
260
262
|
|
|
261
263
|
def terminate_bot(bot: BotSession) -> None:
|
|
@@ -84,6 +84,7 @@ def create_bridge_server(sio: AsyncServer, running_bots: Dict[str, BotSession])
|
|
|
84
84
|
|
|
85
85
|
@sio.on("disconnect")
|
|
86
86
|
async def disconnect(sid: str) -> None:
|
|
87
|
+
"""Disconnect the bot connection."""
|
|
87
88
|
structlogger.debug("model_runner.bot_disconnect", sid=sid)
|
|
88
89
|
if sid in socket_proxy_clients:
|
|
89
90
|
await socket_proxy_clients[sid].disconnect()
|
|
@@ -91,10 +92,12 @@ def create_bridge_server(sio: AsyncServer, running_bots: Dict[str, BotSession])
|
|
|
91
92
|
|
|
92
93
|
@sio.on("*")
|
|
93
94
|
async def handle_message(event: str, sid: str, data: Dict[str, Any]) -> None:
|
|
94
|
-
|
|
95
|
-
# send the response back to the client. both need to happen
|
|
96
|
-
# in parallel in an async way
|
|
95
|
+
""" "Bridge messages between user and bot.
|
|
97
96
|
|
|
97
|
+
Both incoming user messages to the bot_url and
|
|
98
|
+
bot responses sent back to the client need to
|
|
99
|
+
happen in parallel in an async way.
|
|
100
|
+
"""
|
|
98
101
|
client = socket_proxy_clients.get(sid)
|
|
99
102
|
if client is None:
|
|
100
103
|
structlogger.error("model_runner.bot_not_connected", sid=sid)
|
|
@@ -52,12 +52,12 @@ class TrainingSession(BaseModel):
|
|
|
52
52
|
|
|
53
53
|
def train_path(training_id: str) -> str:
|
|
54
54
|
"""Return the path to the training directory for a given training id."""
|
|
55
|
-
return subpath(config.SERVER_BASE_WORKING_DIRECTORY
|
|
55
|
+
return subpath(config.SERVER_BASE_WORKING_DIRECTORY + "/trainings", training_id)
|
|
56
56
|
|
|
57
57
|
|
|
58
58
|
def cache_for_assistant_path(assistant_id: str) -> str:
|
|
59
59
|
"""Return the path to the cache directory for a given assistant id."""
|
|
60
|
-
return subpath(config.SERVER_BASE_WORKING_DIRECTORY
|
|
60
|
+
return subpath(config.SERVER_BASE_WORKING_DIRECTORY + "/caches", assistant_id)
|
|
61
61
|
|
|
62
62
|
|
|
63
63
|
def terminate_training(training: TrainingSession) -> None:
|
|
@@ -132,8 +132,8 @@ def move_model_to_local_storage(training: TrainingSession) -> None:
|
|
|
132
132
|
ensure_base_directory_exists(models_base_path())
|
|
133
133
|
|
|
134
134
|
model_path = subpath(
|
|
135
|
-
train_path(training.training_id),
|
|
136
|
-
f"
|
|
135
|
+
train_path(training.training_id) + "/models",
|
|
136
|
+
f"{training.model_name}.{MODEL_ARCHIVE_EXTENSION}",
|
|
137
137
|
)
|
|
138
138
|
|
|
139
139
|
if os.path.exists(model_path):
|
|
@@ -217,9 +217,11 @@ def write_training_data_to_files(
|
|
|
217
217
|
}
|
|
218
218
|
|
|
219
219
|
for key, file in data_to_be_written_to_files.items():
|
|
220
|
+
parent_path, file_name = os.path.split(file)
|
|
221
|
+
|
|
220
222
|
write_encoded_data_to_file(
|
|
221
223
|
encoded_training_data.get(key, ""),
|
|
222
|
-
subpath(training_base_path,
|
|
224
|
+
subpath(training_base_path + "/" + parent_path, file_name),
|
|
223
225
|
)
|
|
224
226
|
|
|
225
227
|
|
rasa/model_manager/utils.py
CHANGED
|
@@ -2,7 +2,16 @@ import os
|
|
|
2
2
|
import base64
|
|
3
3
|
from typing import Optional
|
|
4
4
|
|
|
5
|
+
import structlog
|
|
6
|
+
|
|
5
7
|
from rasa.model_manager import config
|
|
8
|
+
from rasa.shared.exceptions import RasaException
|
|
9
|
+
|
|
10
|
+
structlogger = structlog.get_logger()
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class InvalidPathException(RasaException):
|
|
14
|
+
"""Raised if a path is invalid - e.g. path traversal is detected."""
|
|
6
15
|
|
|
7
16
|
|
|
8
17
|
def write_encoded_data_to_file(encoded_data: bytes, file: str) -> None:
|
|
@@ -17,7 +26,7 @@ def write_encoded_data_to_file(encoded_data: bytes, file: str) -> None:
|
|
|
17
26
|
|
|
18
27
|
|
|
19
28
|
def logs_base_path() -> str:
|
|
20
|
-
"""Return the path to the logs directory."""
|
|
29
|
+
"""Return the path to the logs' directory."""
|
|
21
30
|
return subpath(config.SERVER_BASE_WORKING_DIRECTORY, "logs")
|
|
22
31
|
|
|
23
32
|
|
|
@@ -31,7 +40,7 @@ def ensure_base_directory_exists(directory: str) -> None:
|
|
|
31
40
|
|
|
32
41
|
|
|
33
42
|
def models_base_path() -> str:
|
|
34
|
-
"""Return the path to the models directory."""
|
|
43
|
+
"""Return the path to the models' directory."""
|
|
35
44
|
return subpath(config.SERVER_BASE_WORKING_DIRECTORY, "models")
|
|
36
45
|
|
|
37
46
|
|
|
@@ -48,13 +57,24 @@ def subpath(parent: str, child: str) -> str:
|
|
|
48
57
|
"""Return the path to the child directory of the parent directory.
|
|
49
58
|
|
|
50
59
|
Ensures, that child doesn't navigate to parent directories. Prevents
|
|
51
|
-
path traversal.
|
|
60
|
+
path traversal. Raises an InvalidPathException if the path is invalid.
|
|
61
|
+
|
|
62
|
+
Based on Snyk's directory traversal mitigation:
|
|
63
|
+
https://learn.snyk.io/lesson/directory-traversal/
|
|
52
64
|
"""
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
65
|
+
safe_path = os.path.abspath(os.path.join(parent, child))
|
|
66
|
+
parent = os.path.abspath(parent)
|
|
67
|
+
|
|
68
|
+
common_base = os.path.commonpath([parent, safe_path])
|
|
69
|
+
if common_base != parent:
|
|
70
|
+
raise InvalidPathException(f"Invalid path: {safe_path}")
|
|
71
|
+
|
|
72
|
+
if os.path.basename(safe_path) != child:
|
|
73
|
+
raise InvalidPathException(
|
|
74
|
+
f"Invalid path - path traversal detected: {safe_path}"
|
|
75
|
+
)
|
|
56
76
|
|
|
57
|
-
return
|
|
77
|
+
return safe_path
|
|
58
78
|
|
|
59
79
|
|
|
60
80
|
def get_logs_content(action_id: str) -> Optional[str]:
|
|
@@ -63,4 +83,5 @@ def get_logs_content(action_id: str) -> Optional[str]:
|
|
|
63
83
|
with open(logs_path(action_id), "r") as file:
|
|
64
84
|
return file.read()
|
|
65
85
|
except FileNotFoundError:
|
|
86
|
+
structlogger.debug("model_service.logs.not_found", action_id=action_id)
|
|
66
87
|
return None
|
rasa/model_service.py
CHANGED
|
@@ -13,6 +13,9 @@ from rasa.utils.common import configure_logging_and_warnings
|
|
|
13
13
|
import rasa.utils.licensing
|
|
14
14
|
from urllib.parse import urlparse
|
|
15
15
|
|
|
16
|
+
from rasa.utils.log_utils import configure_structlog
|
|
17
|
+
from rasa.utils.sanic_error_handler import register_custom_sanic_error_handler
|
|
18
|
+
|
|
16
19
|
structlogger = structlog.get_logger()
|
|
17
20
|
|
|
18
21
|
MODEL_SERVICE_PORT = 8000
|
|
@@ -63,12 +66,14 @@ def main() -> None:
|
|
|
63
66
|
The API server can receive requests to train models, run bots, and manage
|
|
64
67
|
the lifecycle of models and bots.
|
|
65
68
|
"""
|
|
69
|
+
log_level = logging.DEBUG
|
|
66
70
|
configure_logging_and_warnings(
|
|
67
|
-
log_level=
|
|
71
|
+
log_level=log_level,
|
|
68
72
|
logging_config_file=None,
|
|
69
73
|
warn_only_once=True,
|
|
70
74
|
filter_repeated_logs=True,
|
|
71
75
|
)
|
|
76
|
+
configure_structlog(log_level, include_time=True)
|
|
72
77
|
|
|
73
78
|
rasa.utils.licensing.validate_license_from_env()
|
|
74
79
|
|
|
@@ -100,7 +105,9 @@ def main() -> None:
|
|
|
100
105
|
# list all routes
|
|
101
106
|
list_routes(app)
|
|
102
107
|
|
|
103
|
-
app
|
|
108
|
+
register_custom_sanic_error_handler(app)
|
|
109
|
+
|
|
110
|
+
app.run(host="0.0.0.0", port=MODEL_SERVICE_PORT, legacy=True, motd=False)
|
|
104
111
|
|
|
105
112
|
|
|
106
113
|
if __name__ == "__main__":
|
rasa/model_training.py
CHANGED
|
@@ -322,6 +322,8 @@ async def _train_graph(
|
|
|
322
322
|
rasa.engine.validation.validate_coexistance_routing_setup(
|
|
323
323
|
domain, model_configuration, flows
|
|
324
324
|
)
|
|
325
|
+
rasa.engine.validation.validate_model_client_configuration_setup(config)
|
|
326
|
+
rasa.engine.validation.validate_model_group_configuration_setup()
|
|
325
327
|
rasa.engine.validation.validate_flow_component_dependencies(
|
|
326
328
|
flows, model_configuration
|
|
327
329
|
)
|
|
@@ -1,18 +1,17 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
|
|
2
3
|
import copy
|
|
3
4
|
import logging
|
|
4
5
|
from collections import defaultdict
|
|
5
6
|
from pathlib import Path
|
|
6
|
-
|
|
7
|
-
from rasa.exceptions import ModelNotFound
|
|
8
|
-
from rasa.nlu.featurizers.featurizer import Featurizer
|
|
7
|
+
from typing import Any, Dict, List, Optional, Text, Tuple, Union, TypeVar, Type
|
|
9
8
|
|
|
10
9
|
import numpy as np
|
|
11
10
|
import scipy.sparse
|
|
12
11
|
import tensorflow as tf
|
|
13
12
|
|
|
14
|
-
from
|
|
15
|
-
|
|
13
|
+
from rasa.exceptions import ModelNotFound
|
|
14
|
+
from rasa.nlu.featurizers.featurizer import Featurizer
|
|
16
15
|
from rasa.engine.graph import ExecutionContext, GraphComponent
|
|
17
16
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
18
17
|
from rasa.engine.storage.resource import Resource
|
|
@@ -20,18 +19,21 @@ from rasa.engine.storage.storage import ModelStorage
|
|
|
20
19
|
from rasa.nlu.extractors.extractor import EntityExtractorMixin
|
|
21
20
|
from rasa.nlu.classifiers.classifier import IntentClassifier
|
|
22
21
|
import rasa.shared.utils.io
|
|
23
|
-
import rasa.utils.io as io_utils
|
|
24
22
|
import rasa.nlu.utils.bilou_utils as bilou_utils
|
|
25
23
|
from rasa.shared.constants import DIAGNOSTIC_DATA
|
|
26
24
|
from rasa.nlu.extractors.extractor import EntityTagSpec
|
|
27
25
|
from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
|
|
28
26
|
from rasa.utils import train_utils
|
|
29
27
|
from rasa.utils.tensorflow import rasa_layers
|
|
28
|
+
from rasa.utils.tensorflow.feature_array import (
|
|
29
|
+
FeatureArray,
|
|
30
|
+
serialize_nested_feature_arrays,
|
|
31
|
+
deserialize_nested_feature_arrays,
|
|
32
|
+
)
|
|
30
33
|
from rasa.utils.tensorflow.models import RasaModel, TransformerRasaModel
|
|
31
34
|
from rasa.utils.tensorflow.model_data import (
|
|
32
35
|
RasaModelData,
|
|
33
36
|
FeatureSignature,
|
|
34
|
-
FeatureArray,
|
|
35
37
|
)
|
|
36
38
|
from rasa.nlu.constants import TOKENS_NAMES, DEFAULT_TRANSFORMER_SIZE
|
|
37
39
|
from rasa.shared.nlu.constants import (
|
|
@@ -118,7 +120,6 @@ LABEL_SUB_KEY = IDS
|
|
|
118
120
|
|
|
119
121
|
POSSIBLE_TAGS = [ENTITY_ATTRIBUTE_TYPE, ENTITY_ATTRIBUTE_ROLE, ENTITY_ATTRIBUTE_GROUP]
|
|
120
122
|
|
|
121
|
-
|
|
122
123
|
DIETClassifierT = TypeVar("DIETClassifierT", bound="DIETClassifier")
|
|
123
124
|
|
|
124
125
|
|
|
@@ -1083,18 +1084,24 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
|
|
|
1083
1084
|
|
|
1084
1085
|
self.model.save(str(tf_model_file))
|
|
1085
1086
|
|
|
1086
|
-
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
model_path / f"{file_name}.
|
|
1091
|
-
self._sparse_feature_sizes,
|
|
1087
|
+
# save data example
|
|
1088
|
+
serialize_nested_feature_arrays(
|
|
1089
|
+
self._data_example,
|
|
1090
|
+
model_path / f"{file_name}.data_example.st",
|
|
1091
|
+
model_path / f"{file_name}.data_example_metadata.json",
|
|
1092
1092
|
)
|
|
1093
|
-
|
|
1094
|
-
|
|
1093
|
+
# save label data
|
|
1094
|
+
serialize_nested_feature_arrays(
|
|
1095
1095
|
dict(self._label_data.data) if self._label_data is not None else {},
|
|
1096
|
+
model_path / f"{file_name}.label_data.st",
|
|
1097
|
+
model_path / f"{file_name}.label_data_metadata.json",
|
|
1096
1098
|
)
|
|
1097
|
-
|
|
1099
|
+
|
|
1100
|
+
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
1101
|
+
model_path / f"{file_name}.sparse_feature_sizes.json",
|
|
1102
|
+
self._sparse_feature_sizes,
|
|
1103
|
+
)
|
|
1104
|
+
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
1098
1105
|
model_path / f"{file_name}.index_label_id_mapping.json",
|
|
1099
1106
|
self.index_label_id_mapping,
|
|
1100
1107
|
)
|
|
@@ -1183,15 +1190,22 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
|
|
|
1183
1190
|
]:
|
|
1184
1191
|
file_name = cls.__name__
|
|
1185
1192
|
|
|
1186
|
-
|
|
1187
|
-
|
|
1193
|
+
# load data example
|
|
1194
|
+
data_example = deserialize_nested_feature_arrays(
|
|
1195
|
+
str(model_path / f"{file_name}.data_example.st"),
|
|
1196
|
+
str(model_path / f"{file_name}.data_example_metadata.json"),
|
|
1188
1197
|
)
|
|
1189
|
-
|
|
1190
|
-
|
|
1191
|
-
|
|
1192
|
-
model_path / f"{file_name}.
|
|
1198
|
+
# load label data
|
|
1199
|
+
loaded_label_data = deserialize_nested_feature_arrays(
|
|
1200
|
+
str(model_path / f"{file_name}.label_data.st"),
|
|
1201
|
+
str(model_path / f"{file_name}.label_data_metadata.json"),
|
|
1202
|
+
)
|
|
1203
|
+
label_data = RasaModelData(data=loaded_label_data)
|
|
1204
|
+
|
|
1205
|
+
sparse_feature_sizes = rasa.shared.utils.io.read_json_file(
|
|
1206
|
+
model_path / f"{file_name}.sparse_feature_sizes.json"
|
|
1193
1207
|
)
|
|
1194
|
-
index_label_id_mapping =
|
|
1208
|
+
index_label_id_mapping = rasa.shared.utils.io.read_json_file(
|
|
1195
1209
|
model_path / f"{file_name}.index_label_id_mapping.json"
|
|
1196
1210
|
)
|
|
1197
1211
|
entity_tag_specs = rasa.shared.utils.io.read_json_file(
|
|
@@ -1211,7 +1225,6 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
|
|
|
1211
1225
|
for tag_spec in entity_tag_specs
|
|
1212
1226
|
]
|
|
1213
1227
|
|
|
1214
|
-
# jsonpickle converts dictionary keys to strings
|
|
1215
1228
|
index_label_id_mapping = {
|
|
1216
1229
|
int(key): value for key, value in index_label_id_mapping.items()
|
|
1217
1230
|
}
|
|
@@ -1,22 +1,21 @@
|
|
|
1
1
|
from typing import Any, Text, Dict, List, Type, Tuple
|
|
2
2
|
|
|
3
|
-
import joblib
|
|
4
3
|
import structlog
|
|
5
4
|
from scipy.sparse import hstack, vstack, csr_matrix
|
|
6
5
|
from sklearn.exceptions import NotFittedError
|
|
7
6
|
from sklearn.linear_model import LogisticRegression
|
|
8
7
|
from sklearn.utils.validation import check_is_fitted
|
|
9
8
|
|
|
9
|
+
from rasa.engine.graph import ExecutionContext, GraphComponent
|
|
10
|
+
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
10
11
|
from rasa.engine.storage.resource import Resource
|
|
11
12
|
from rasa.engine.storage.storage import ModelStorage
|
|
12
|
-
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
13
|
-
from rasa.engine.graph import ExecutionContext, GraphComponent
|
|
14
13
|
from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
|
|
15
|
-
from rasa.nlu.featurizers.featurizer import Featurizer
|
|
16
14
|
from rasa.nlu.classifiers.classifier import IntentClassifier
|
|
17
|
-
from rasa.
|
|
18
|
-
from rasa.shared.nlu.training_data.message import Message
|
|
15
|
+
from rasa.nlu.featurizers.featurizer import Featurizer
|
|
19
16
|
from rasa.shared.nlu.constants import TEXT, INTENT
|
|
17
|
+
from rasa.shared.nlu.training_data.message import Message
|
|
18
|
+
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
20
19
|
from rasa.utils.tensorflow.constants import RANKING_LENGTH
|
|
21
20
|
|
|
22
21
|
structlogger = structlog.get_logger()
|
|
@@ -184,9 +183,11 @@ class LogisticRegressionClassifier(IntentClassifier, GraphComponent):
|
|
|
184
183
|
|
|
185
184
|
def persist(self) -> None:
|
|
186
185
|
"""Persist this model into the passed directory."""
|
|
186
|
+
import skops.io as sio
|
|
187
|
+
|
|
187
188
|
with self._model_storage.write_to(self._resource) as model_dir:
|
|
188
|
-
path = model_dir / f"{self._resource.name}.
|
|
189
|
-
|
|
189
|
+
path = model_dir / f"{self._resource.name}.skops"
|
|
190
|
+
sio.dump(self.clf, path)
|
|
190
191
|
structlogger.debug(
|
|
191
192
|
"logistic_regression_classifier.persist",
|
|
192
193
|
event_info=f"Saved intent classifier to '{path}'.",
|
|
@@ -202,9 +203,21 @@ class LogisticRegressionClassifier(IntentClassifier, GraphComponent):
|
|
|
202
203
|
**kwargs: Any,
|
|
203
204
|
) -> "LogisticRegressionClassifier":
|
|
204
205
|
"""Loads trained component (see parent class for full docstring)."""
|
|
206
|
+
import skops.io as sio
|
|
207
|
+
|
|
205
208
|
try:
|
|
206
209
|
with model_storage.read_from(resource) as model_dir:
|
|
207
|
-
|
|
210
|
+
classifier_file = model_dir / f"{resource.name}.skops"
|
|
211
|
+
unknown_types = sio.get_untrusted_types(file=classifier_file)
|
|
212
|
+
|
|
213
|
+
if unknown_types:
|
|
214
|
+
structlogger.error(
|
|
215
|
+
f"Untrusted types found when loading {classifier_file}!",
|
|
216
|
+
unknown_types=unknown_types,
|
|
217
|
+
)
|
|
218
|
+
raise ValueError()
|
|
219
|
+
|
|
220
|
+
classifier = sio.load(classifier_file, trusted=unknown_types)
|
|
208
221
|
component = cls(
|
|
209
222
|
config, execution_context.node_name, model_storage, resource
|
|
210
223
|
)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
|
|
2
3
|
import logging
|
|
3
|
-
from rasa.nlu.featurizers.dense_featurizer.dense_featurizer import DenseFeaturizer
|
|
4
4
|
import typing
|
|
5
5
|
import warnings
|
|
6
6
|
from typing import Any, Dict, List, Optional, Text, Tuple, Type
|
|
@@ -8,18 +8,18 @@ from typing import Any, Dict, List, Optional, Text, Tuple, Type
|
|
|
8
8
|
import numpy as np
|
|
9
9
|
|
|
10
10
|
import rasa.shared.utils.io
|
|
11
|
-
import rasa.utils.io as io_utils
|
|
12
11
|
from rasa.engine.graph import GraphComponent, ExecutionContext
|
|
13
12
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
14
13
|
from rasa.engine.storage.resource import Resource
|
|
15
14
|
from rasa.engine.storage.storage import ModelStorage
|
|
16
|
-
from rasa.shared.constants import DOCS_URL_TRAINING_DATA_NLU
|
|
17
15
|
from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
|
|
16
|
+
from rasa.nlu.classifiers.classifier import IntentClassifier
|
|
17
|
+
from rasa.nlu.featurizers.dense_featurizer.dense_featurizer import DenseFeaturizer
|
|
18
|
+
from rasa.shared.constants import DOCS_URL_TRAINING_DATA_NLU
|
|
18
19
|
from rasa.shared.exceptions import RasaException
|
|
19
20
|
from rasa.shared.nlu.constants import TEXT
|
|
20
|
-
from rasa.nlu.classifiers.classifier import IntentClassifier
|
|
21
|
-
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
22
21
|
from rasa.shared.nlu.training_data.message import Message
|
|
22
|
+
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
23
23
|
from rasa.utils.tensorflow.constants import FEATURIZERS
|
|
24
24
|
|
|
25
25
|
logger = logging.getLogger(__name__)
|
|
@@ -266,14 +266,20 @@ class SklearnIntentClassifier(GraphComponent, IntentClassifier):
|
|
|
266
266
|
|
|
267
267
|
def persist(self) -> None:
|
|
268
268
|
"""Persist this model into the passed directory."""
|
|
269
|
+
import skops.io as sio
|
|
270
|
+
|
|
269
271
|
with self._model_storage.write_to(self._resource) as model_dir:
|
|
270
272
|
file_name = self.__class__.__name__
|
|
271
|
-
classifier_file_name = model_dir / f"{file_name}_classifier.
|
|
272
|
-
encoder_file_name = model_dir / f"{file_name}_encoder.
|
|
273
|
+
classifier_file_name = model_dir / f"{file_name}_classifier.skops"
|
|
274
|
+
encoder_file_name = model_dir / f"{file_name}_encoder.json"
|
|
273
275
|
|
|
274
276
|
if self.clf and self.le:
|
|
275
|
-
|
|
276
|
-
|
|
277
|
+
# convert self.le.classes_ (numpy array of strings) to a list in order
|
|
278
|
+
# to use json dump
|
|
279
|
+
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
280
|
+
encoder_file_name, list(self.le.classes_)
|
|
281
|
+
)
|
|
282
|
+
sio.dump(self.clf.best_estimator_, classifier_file_name)
|
|
277
283
|
|
|
278
284
|
@classmethod
|
|
279
285
|
def load(
|
|
@@ -286,21 +292,36 @@ class SklearnIntentClassifier(GraphComponent, IntentClassifier):
|
|
|
286
292
|
) -> SklearnIntentClassifier:
|
|
287
293
|
"""Loads trained component (see parent class for full docstring)."""
|
|
288
294
|
from sklearn.preprocessing import LabelEncoder
|
|
295
|
+
import skops.io as sio
|
|
289
296
|
|
|
290
297
|
try:
|
|
291
298
|
with model_storage.read_from(resource) as model_dir:
|
|
292
299
|
file_name = cls.__name__
|
|
293
|
-
classifier_file = model_dir / f"{file_name}_classifier.
|
|
300
|
+
classifier_file = model_dir / f"{file_name}_classifier.skops"
|
|
294
301
|
|
|
295
302
|
if classifier_file.exists():
|
|
296
|
-
|
|
303
|
+
unknown_types = sio.get_untrusted_types(file=classifier_file)
|
|
297
304
|
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
305
|
+
if unknown_types:
|
|
306
|
+
logger.error(
|
|
307
|
+
f"Untrusted types ({unknown_types}) found when "
|
|
308
|
+
f"loading {classifier_file}!"
|
|
309
|
+
)
|
|
310
|
+
raise ValueError()
|
|
311
|
+
else:
|
|
312
|
+
classifier = sio.load(classifier_file, trusted=unknown_types)
|
|
313
|
+
|
|
314
|
+
encoder_file = model_dir / f"{file_name}_encoder.json"
|
|
315
|
+
classes = rasa.shared.utils.io.read_json_file(encoder_file)
|
|
302
316
|
|
|
303
|
-
|
|
317
|
+
encoder = LabelEncoder()
|
|
318
|
+
intent_classifier = cls(
|
|
319
|
+
config, model_storage, resource, classifier, encoder
|
|
320
|
+
)
|
|
321
|
+
# convert list of strings (class labels) back to numpy array of
|
|
322
|
+
# strings
|
|
323
|
+
intent_classifier.transform_labels_str2num(classes)
|
|
324
|
+
return intent_classifier
|
|
304
325
|
except ValueError:
|
|
305
326
|
logger.debug(
|
|
306
327
|
f"Failed to load '{cls.__name__}' from model storage. Resource "
|