rasa-pro 3.11.0__py3-none-any.whl → 3.11.0a1__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Potentially problematic release: this version of rasa-pro might be problematic.
- README.md +396 -17
- rasa/__main__.py +15 -31
- rasa/api.py +1 -5
- rasa/cli/arguments/default_arguments.py +2 -1
- rasa/cli/arguments/shell.py +1 -5
- rasa/cli/arguments/train.py +0 -14
- rasa/cli/e2e_test.py +1 -1
- rasa/cli/evaluate.py +8 -8
- rasa/cli/inspect.py +7 -15
- rasa/cli/interactive.py +0 -1
- rasa/cli/llm_fine_tuning.py +1 -1
- rasa/cli/project_templates/calm/config.yml +7 -5
- rasa/cli/project_templates/calm/endpoints.yml +2 -15
- rasa/cli/project_templates/tutorial/config.yml +5 -8
- rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +0 -5
- rasa/cli/project_templates/tutorial/domain.yml +0 -14
- rasa/cli/project_templates/tutorial/endpoints.yml +0 -5
- rasa/cli/run.py +1 -1
- rasa/cli/scaffold.py +2 -4
- rasa/cli/studio/studio.py +8 -18
- rasa/cli/studio/upload.py +15 -0
- rasa/cli/train.py +0 -3
- rasa/cli/utils.py +1 -6
- rasa/cli/x.py +8 -8
- rasa/constants.py +1 -3
- rasa/core/actions/action.py +33 -75
- rasa/core/actions/e2e_stub_custom_action_executor.py +1 -5
- rasa/core/actions/http_custom_action_executor.py +0 -4
- rasa/core/channels/__init__.py +0 -2
- rasa/core/channels/channel.py +0 -20
- rasa/core/channels/development_inspector.py +3 -10
- rasa/core/channels/inspector/dist/assets/{arc-bc141fb2.js → arc-86942a71.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-be2db283.js → c4Diagram-d0fbc5ce-b0290676.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-55366915.js → classDiagram-936ed81e-f6405f6e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-bb529518.js → classDiagram-v2-c3cb15f1-ef61ac77.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-b0ec81d6.js → createText-62fc7601-f0411e58.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-6166330c.js → edges-f2ad444c-7dcc4f3b.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-5ccc6a8e.js → erDiagram-9d236eb7-e0c092d7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-fca3bfe4.js → flowDb-1972c806-fba2e3ce.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-4739080f.js → flowDiagram-7ea5b25a-7a70b71a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-24a5f41a.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-7c1b0e0f.js → flowchart-elk-definition-abe16c3d-00a59b68.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-772fd050.js → ganttDiagram-9b5ea136-293c91fa.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-8eae1dc9.js → gitGraphDiagram-99d0ae7c-07b2d68c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-f55afcdf.js → index-2c4b9a3b-bc959fbd.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-e7cef9de.js → index-3a8a5a28.js} +143 -143
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-124d4a14.js → infoDiagram-736b4530-4a350f72.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-7c4fae44.js → journeyDiagram-df861f2b-af464fb7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-b9885fb6.js → layout-0071f036.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-7c59abb6.js → line-2f73cc83.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-4776f780.js → linear-f014b4cc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-2332c46c.js → mindmap-definition-beec6740-d2426fb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-8fb39303.js → pieDiagram-dbbf0591-776f01a2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-3c7180a2.js → quadrantDiagram-4d7f4fd6-82e00b57.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-e910bcb8.js → requirementDiagram-6fc4c22a-ea13c6bb.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-ead16c89.js → sankeyDiagram-8f13d901-1feca7e9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-29a02a19.js → sequenceDiagram-b655622a-070c61d2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-042b3137.js → stateDiagram-59f0c015-24f46263.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-2178c0f3.js → stateDiagram-v2-2b26beab-c9056051.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-23ffa4fc.js → styles-080da4f6-08abc34a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-94f59763.js → styles-3dcbcfbf-bc74c25a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-78a6bebc.js → styles-9c745c82-4e5d66de.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-eae2a6f6.js → svgDrawCommon-4835440b-849c4517.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-5c968d92.js → timeline-definition-5b62e21b-d0fb1598.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-fd3db0d5.js → xychartDiagram-2b33534f-04d115e2.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/App.tsx +1 -1
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +3 -6
- rasa/core/channels/socketio.py +2 -7
- rasa/core/channels/telegram.py +1 -1
- rasa/core/channels/twilio.py +1 -1
- rasa/core/channels/voice_ready/audiocodes.py +4 -15
- rasa/core/channels/voice_ready/jambonz.py +4 -15
- rasa/core/channels/voice_ready/twilio_voice.py +21 -6
- rasa/core/channels/voice_ready/utils.py +5 -6
- rasa/core/channels/voice_stream/asr/asr_engine.py +1 -19
- rasa/core/channels/voice_stream/asr/asr_event.py +0 -5
- rasa/core/channels/voice_stream/asr/deepgram.py +15 -28
- rasa/core/channels/voice_stream/audio_bytes.py +0 -1
- rasa/core/channels/voice_stream/tts/azure.py +3 -9
- rasa/core/channels/voice_stream/tts/cartesia.py +8 -12
- rasa/core/channels/voice_stream/tts/tts_engine.py +1 -11
- rasa/core/channels/voice_stream/twilio_media_streams.py +19 -28
- rasa/core/channels/voice_stream/util.py +4 -4
- rasa/core/channels/voice_stream/voice_channel.py +42 -222
- rasa/core/featurizers/single_state_featurizer.py +1 -22
- rasa/core/featurizers/tracker_featurizers.py +18 -115
- rasa/core/information_retrieval/qdrant.py +0 -1
- rasa/core/nlg/contextual_response_rephraser.py +25 -44
- rasa/core/persistor.py +34 -191
- rasa/core/policies/enterprise_search_policy.py +60 -119
- rasa/core/policies/flows/flow_executor.py +4 -7
- rasa/core/policies/intentless_policy.py +22 -82
- rasa/core/policies/ted_policy.py +33 -58
- rasa/core/policies/unexpected_intent_policy.py +7 -15
- rasa/core/processor.py +13 -89
- rasa/core/run.py +2 -2
- rasa/core/training/interactive.py +35 -34
- rasa/core/utils.py +22 -58
- rasa/dialogue_understanding/coexistence/llm_based_router.py +12 -39
- rasa/dialogue_understanding/commands/__init__.py +0 -4
- rasa/dialogue_understanding/commands/change_flow_command.py +0 -6
- rasa/dialogue_understanding/commands/utils.py +0 -5
- rasa/dialogue_understanding/generator/constants.py +0 -2
- rasa/dialogue_understanding/generator/flow_retrieval.py +4 -49
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +23 -37
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +10 -57
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +1 -19
- rasa/dialogue_understanding/generator/single_step/command_prompt_template.jinja2 +0 -3
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +10 -90
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +0 -53
- rasa/dialogue_understanding/processor/command_processor.py +1 -21
- rasa/e2e_test/assertions.py +16 -133
- rasa/e2e_test/assertions_schema.yml +0 -23
- rasa/e2e_test/e2e_test_case.py +6 -85
- rasa/e2e_test/e2e_test_runner.py +4 -6
- rasa/e2e_test/utils/io.py +1 -3
- rasa/engine/loader.py +0 -12
- rasa/engine/validation.py +11 -541
- rasa/keys +1 -0
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
- rasa/model_training.py +7 -29
- rasa/nlu/classifiers/diet_classifier.py +25 -38
- rasa/nlu/classifiers/logistic_regression_classifier.py +9 -22
- rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
- rasa/nlu/extractors/crf_entity_extractor.py +50 -93
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +16 -45
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +17 -52
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
- rasa/nlu/tokenizers/whitespace_tokenizer.py +14 -3
- rasa/server.py +1 -3
- rasa/shared/constants.py +0 -61
- rasa/shared/core/constants.py +0 -9
- rasa/shared/core/domain.py +5 -8
- rasa/shared/core/flows/flow.py +0 -5
- rasa/shared/core/flows/flows_list.py +1 -5
- rasa/shared/core/flows/flows_yaml_schema.json +0 -10
- rasa/shared/core/flows/validation.py +0 -96
- rasa/shared/core/flows/yaml_flows_io.py +4 -13
- rasa/shared/core/slots.py +0 -5
- rasa/shared/importers/importer.py +2 -19
- rasa/shared/importers/rasa.py +1 -5
- rasa/shared/nlu/training_data/features.py +2 -120
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +3 -18
- rasa/shared/providers/_configs/azure_openai_client_config.py +3 -5
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +0 -1
- rasa/shared/providers/_configs/utils.py +0 -16
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +29 -18
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +21 -54
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +0 -24
- rasa/shared/providers/llm/_base_litellm_client.py +31 -63
- rasa/shared/providers/llm/azure_openai_llm_client.py +29 -50
- rasa/shared/providers/llm/default_litellm_llm_client.py +0 -24
- rasa/shared/providers/llm/self_hosted_llm_client.py +29 -17
- rasa/shared/providers/mappings.py +0 -19
- rasa/shared/utils/common.py +2 -37
- rasa/shared/utils/io.py +6 -28
- rasa/shared/utils/llm.py +46 -353
- rasa/shared/utils/yaml.py +82 -181
- rasa/studio/auth.py +5 -3
- rasa/studio/config.py +4 -13
- rasa/studio/constants.py +0 -1
- rasa/studio/data_handler.py +4 -13
- rasa/studio/upload.py +80 -175
- rasa/telemetry.py +17 -94
- rasa/tracing/config.py +1 -3
- rasa/tracing/instrumentation/attribute_extractors.py +17 -94
- rasa/tracing/instrumentation/instrumentation.py +0 -121
- rasa/utils/common.py +0 -5
- rasa/utils/endpoints.py +1 -27
- rasa/utils/io.py +81 -7
- rasa/utils/log_utils.py +2 -9
- rasa/utils/tensorflow/model_data.py +193 -2
- rasa/validator.py +4 -110
- rasa/version.py +1 -1
- rasa_pro-3.11.0a1.dist-info/METADATA +576 -0
- {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a1.dist-info}/RECORD +182 -216
- rasa/core/actions/action_repeat_bot_messages.py +0 -89
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +0 -1
- rasa/core/channels/inspector/src/helpers/audiostream.ts +0 -165
- rasa/core/channels/voice_stream/asr/azure.py +0 -129
- rasa/core/channels/voice_stream/browser_audio.py +0 -107
- rasa/core/channels/voice_stream/call_state.py +0 -23
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +0 -60
- rasa/dialogue_understanding/commands/user_silence_command.py +0 -59
- rasa/dialogue_understanding/patterns/repeat.py +0 -37
- rasa/dialogue_understanding/patterns/user_silence.py +0 -37
- rasa/model_manager/__init__.py +0 -0
- rasa/model_manager/config.py +0 -40
- rasa/model_manager/model_api.py +0 -559
- rasa/model_manager/runner_service.py +0 -286
- rasa/model_manager/socket_bridge.py +0 -146
- rasa/model_manager/studio_jwt_auth.py +0 -86
- rasa/model_manager/trainer_service.py +0 -325
- rasa/model_manager/utils.py +0 -87
- rasa/model_manager/warm_rasa_process.py +0 -187
- rasa/model_service.py +0 -112
- rasa/shared/core/flows/utils.py +0 -39
- rasa/shared/providers/_configs/litellm_router_client_config.py +0 -220
- rasa/shared/providers/_configs/model_group_config.py +0 -167
- rasa/shared/providers/_configs/rasa_llm_client_config.py +0 -73
- rasa/shared/providers/_utils.py +0 -79
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +0 -135
- rasa/shared/providers/llm/litellm_router_llm_client.py +0 -182
- rasa/shared/providers/llm/rasa_llm_client.py +0 -112
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +0 -183
- rasa/shared/providers/router/router_client.py +0 -73
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +0 -31
- rasa/shared/utils/health_check/health_check.py +0 -258
- rasa/shared/utils/health_check/llm_health_check_mixin.py +0 -31
- rasa/utils/sanic_error_handler.py +0 -32
- rasa/utils/tensorflow/feature_array.py +0 -366
- rasa_pro-3.11.0.dist-info/METADATA +0 -198
- {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a1.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a1.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0.dist-info → rasa_pro-3.11.0a1.dist-info}/entry_points.txt +0 -0
rasa/nlu/classifiers/logistic_regression_classifier.py

@@ -1,21 +1,22 @@
 from typing import Any, Text, Dict, List, Type, Tuple
 
+import joblib
 import structlog
 from scipy.sparse import hstack, vstack, csr_matrix
 from sklearn.exceptions import NotFittedError
 from sklearn.linear_model import LogisticRegression
 from sklearn.utils.validation import check_is_fitted
 
-from rasa.engine.graph import ExecutionContext, GraphComponent
-from rasa.engine.recipes.default_recipe import DefaultV1Recipe
 from rasa.engine.storage.resource import Resource
 from rasa.engine.storage.storage import ModelStorage
+from rasa.engine.recipes.default_recipe import DefaultV1Recipe
+from rasa.engine.graph import ExecutionContext, GraphComponent
 from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
-from rasa.nlu.classifiers.classifier import IntentClassifier
 from rasa.nlu.featurizers.featurizer import Featurizer
-from rasa.
-from rasa.shared.nlu.training_data.message import Message
+from rasa.nlu.classifiers.classifier import IntentClassifier
 from rasa.shared.nlu.training_data.training_data import TrainingData
+from rasa.shared.nlu.training_data.message import Message
+from rasa.shared.nlu.constants import TEXT, INTENT
 from rasa.utils.tensorflow.constants import RANKING_LENGTH
 
 structlogger = structlog.get_logger()

@@ -183,11 +184,9 @@ class LogisticRegressionClassifier(IntentClassifier, GraphComponent):
 
     def persist(self) -> None:
         """Persist this model into the passed directory."""
-        import skops.io as sio
-
         with self._model_storage.write_to(self._resource) as model_dir:
-            path = model_dir / f"{self._resource.name}.
-
+            path = model_dir / f"{self._resource.name}.joblib"
+            joblib.dump(self.clf, path)
             structlogger.debug(
                 "logistic_regression_classifier.persist",
                 event_info=f"Saved intent classifier to '{path}'.",

@@ -203,21 +202,9 @@ class LogisticRegressionClassifier(IntentClassifier, GraphComponent):
         **kwargs: Any,
     ) -> "LogisticRegressionClassifier":
         """Loads trained component (see parent class for full docstring)."""
-        import skops.io as sio
-
         try:
             with model_storage.read_from(resource) as model_dir:
-
-                unknown_types = sio.get_untrusted_types(file=classifier_file)
-
-                if unknown_types:
-                    structlogger.error(
-                        f"Untrusted types found when loading {classifier_file}!",
-                        unknown_types=unknown_types,
-                    )
-                    raise ValueError()
-
-                classifier = sio.load(classifier_file, trusted=unknown_types)
+                classifier = joblib.load(model_dir / f"{resource.name}.joblib")
                 component = cls(
                     config, execution_context.node_name, model_storage, resource
                 )
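The hunks above swap the `skops`-based serialization for plain `joblib`. A minimal sketch of that round-trip, assuming a fitted scikit-learn `LogisticRegression` outside of Rasa's model storage (the temporary directory and file name are illustrative, not taken from the diff):

```python
from pathlib import Path
import tempfile

import joblib
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Fit a small classifier purely for illustration.
X, y = make_classification(n_samples=100, n_features=5, random_state=0)
clf = LogisticRegression().fit(X, y)

# Persist and reload the way the 3.11.0a1 side does: joblib.dump / joblib.load.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "LogisticRegressionClassifier.joblib"
    joblib.dump(clf, path)
    restored = joblib.load(path)

assert (restored.predict(X) == clf.predict(X)).all()
```

Note that `joblib.load` unpickles without the untrusted-type check that the removed `skops.io` path performed, so model files should only be loaded from trusted storage.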
rasa/nlu/classifiers/sklearn_intent_classifier.py

@@ -1,6 +1,6 @@
 from __future__ import annotations
-
 import logging
+from rasa.nlu.featurizers.dense_featurizer.dense_featurizer import DenseFeaturizer
 import typing
 import warnings
 from typing import Any, Dict, List, Optional, Text, Tuple, Type

@@ -8,18 +8,18 @@ from typing import Any, Dict, List, Optional, Text, Tuple, Type
 import numpy as np
 
 import rasa.shared.utils.io
+import rasa.utils.io as io_utils
 from rasa.engine.graph import GraphComponent, ExecutionContext
 from rasa.engine.recipes.default_recipe import DefaultV1Recipe
 from rasa.engine.storage.resource import Resource
 from rasa.engine.storage.storage import ModelStorage
-from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
-from rasa.nlu.classifiers.classifier import IntentClassifier
-from rasa.nlu.featurizers.dense_featurizer.dense_featurizer import DenseFeaturizer
 from rasa.shared.constants import DOCS_URL_TRAINING_DATA_NLU
+from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
 from rasa.shared.exceptions import RasaException
 from rasa.shared.nlu.constants import TEXT
-from rasa.
+from rasa.nlu.classifiers.classifier import IntentClassifier
 from rasa.shared.nlu.training_data.training_data import TrainingData
+from rasa.shared.nlu.training_data.message import Message
 from rasa.utils.tensorflow.constants import FEATURIZERS
 
 logger = logging.getLogger(__name__)

@@ -266,20 +266,14 @@ class SklearnIntentClassifier(GraphComponent, IntentClassifier):
 
     def persist(self) -> None:
         """Persist this model into the passed directory."""
-        import skops.io as sio
-
         with self._model_storage.write_to(self._resource) as model_dir:
             file_name = self.__class__.__name__
-            classifier_file_name = model_dir / f"{file_name}_classifier.
-            encoder_file_name = model_dir / f"{file_name}_encoder.
+            classifier_file_name = model_dir / f"{file_name}_classifier.pkl"
+            encoder_file_name = model_dir / f"{file_name}_encoder.pkl"
 
             if self.clf and self.le:
-
-
-                rasa.shared.utils.io.dump_obj_as_json_to_file(
-                    encoder_file_name, list(self.le.classes_)
-                )
-                sio.dump(self.clf.best_estimator_, classifier_file_name)
+                io_utils.json_pickle(encoder_file_name, self.le.classes_)
+                io_utils.json_pickle(classifier_file_name, self.clf.best_estimator_)
 
     @classmethod
     def load(

@@ -292,36 +286,21 @@ class SklearnIntentClassifier(GraphComponent, IntentClassifier):
     ) -> SklearnIntentClassifier:
         """Loads trained component (see parent class for full docstring)."""
         from sklearn.preprocessing import LabelEncoder
-        import skops.io as sio
 
         try:
             with model_storage.read_from(resource) as model_dir:
                 file_name = cls.__name__
-                classifier_file = model_dir / f"{file_name}_classifier.
+                classifier_file = model_dir / f"{file_name}_classifier.pkl"
 
                 if classifier_file.exists():
-
-
-                    if unknown_types:
-                        logger.error(
-                            f"Untrusted types ({unknown_types}) found when "
-                            f"loading {classifier_file}!"
-                        )
-                        raise ValueError()
-                    else:
-                        classifier = sio.load(classifier_file, trusted=unknown_types)
-
-                    encoder_file = model_dir / f"{file_name}_encoder.json"
-                    classes = rasa.shared.utils.io.read_json_file(encoder_file)
+                    classifier = io_utils.json_unpickle(classifier_file)
 
+                    encoder_file = model_dir / f"{file_name}_encoder.pkl"
+                    classes = io_utils.json_unpickle(encoder_file)
                     encoder = LabelEncoder()
-
-
-                    )
-                    # convert list of strings (class labels) back to numpy array of
-                    # strings
-                    intent_classifier.transform_labels_str2num(classes)
-                    return intent_classifier
+                    encoder.classes_ = classes
+
+                    return cls(config, model_storage, resource, classifier, encoder)
         except ValueError:
             logger.debug(
                 f"Failed to load '{cls.__name__}' from model storage. Resource "
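In the `load` hunk above, the label encoder is no longer rebuilt from a JSON class list; its `classes_` array is unpickled and assigned directly. A minimal sketch of that pattern, assuming plain scikit-learn objects (the intent labels are made up for illustration):

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder

# Fit an encoder, keep only its classes_, then rebuild it the way load() does above.
original = LabelEncoder().fit(["affirm", "goodbye", "greet"])
stored_classes = original.classes_  # this array is what ends up in *_encoder.pkl

restored = LabelEncoder()
restored.classes_ = np.array(stored_classes)

assert list(restored.transform(["goodbye"])) == list(original.transform(["goodbye"]))
```

The classifier itself goes through `io_utils.json_pickle` / `io_utils.json_unpickle` as shown in the diff; this sketch only covers the encoder reconstruction.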
rasa/nlu/extractors/crf_entity_extractor.py

@@ -4,9 +4,9 @@ from collections import OrderedDict
 from enum import Enum
 import logging
 import typing
-from typing import Any, Dict, List, Optional, Text, Tuple, Callable, Type
 
 import numpy as np
+from typing import Any, Dict, List, Optional, Text, Tuple, Callable, Type
 
 import rasa.nlu.utils.bilou_utils as bilou_utils
 import rasa.shared.utils.io

@@ -41,9 +41,6 @@ if typing.TYPE_CHECKING:
     from sklearn_crfsuite import CRF
 
 
-CONFIG_FEATURES = "features"
-
-
 class CRFToken:
     def __init__(
         self,

@@ -63,29 +60,6 @@ class CRFToken:
         self.entity_role_tag = entity_role_tag
         self.entity_group_tag = entity_group_tag
 
-    def to_dict(self) -> Dict[str, Any]:
-        return {
-            "text": self.text,
-            "pos_tag": self.pos_tag,
-            "pattern": self.pattern,
-            "dense_features": [str(x) for x in list(self.dense_features)],
-            "entity_tag": self.entity_tag,
-            "entity_role_tag": self.entity_role_tag,
-            "entity_group_tag": self.entity_group_tag,
-        }
-
-    @classmethod
-    def create_from_dict(cls, data: Dict[str, Any]) -> "CRFToken":
-        return cls(
-            data["text"],
-            data["pos_tag"],
-            data["pattern"],
-            np.array([float(x) for x in data["dense_features"]]),
-            data["entity_tag"],
-            data["entity_role_tag"],
-            data["entity_group_tag"],
-        )
-
 
 class CRFEntityExtractorOptions(str, Enum):
     """Features that can be used for the 'CRFEntityExtractor'."""

@@ -114,6 +88,8 @@ class CRFEntityExtractorOptions(str, Enum):
 class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
     """Implements conditional random fields (CRF) to do named entity recognition."""
 
+    CONFIG_FEATURES = "features"
+
     function_dict: Dict[Text, Callable[[CRFToken], Any]] = {  # noqa: RUF012
         CRFEntityExtractorOptions.LOW: lambda crf_token: crf_token.text.lower(),
         CRFEntityExtractorOptions.TITLE: lambda crf_token: crf_token.text.istitle(),

@@ -161,7 +137,7 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
         # "is the preceding token in title case?"
         # POS features require SpacyTokenizer
         # pattern feature require RegexFeaturizer
-        CONFIG_FEATURES: [
+        CRFEntityExtractor.CONFIG_FEATURES: [
             [
                 CRFEntityExtractorOptions.LOW,
                 CRFEntityExtractorOptions.TITLE,

@@ -224,7 +200,7 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
         )
 
     def _validate_configuration(self) -> None:
-        if len(self.component_config.get(CONFIG_FEATURES, [])) % 2 != 1:
+        if len(self.component_config.get(self.CONFIG_FEATURES, [])) % 2 != 1:
             raise ValueError(
                 "Need an odd number of crf feature lists to have a center word."
             )

@@ -275,11 +251,9 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
         ]
         dataset = [self._convert_to_crf_tokens(example) for example in entity_examples]
 
-        self.
-            dataset, self.component_config, self.crf_order
-        )
+        self._train_model(dataset)
 
-        self.persist(
+        self.persist()
 
         return self._resource
 

@@ -325,9 +299,7 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
             if include_tag_features:
                 self._add_tag_to_crf_token(crf_tokens, predictions)
 
-            features = self._crf_tokens_to_features(
-                crf_tokens, self.component_config, include_tag_features
-            )
+            features = self._crf_tokens_to_features(crf_tokens, include_tag_features)
             predictions[tag_name] = entity_tagger.predict_marginals_single(features)
 
         # convert predictions into a list of tags and a list of confidences

@@ -417,25 +389,27 @@
         **kwargs: Any,
     ) -> CRFEntityExtractor:
         """Loads trained component (see parent class for full docstring)."""
+        import joblib
+
         try:
+            entity_taggers = OrderedDict()
             with model_storage.read_from(resource) as model_dir:
-
-
-                )
-
-
-
-
-
-
-
-                ]
+                # We have to load in the same order as we persisted things as otherwise
+                # the predictions might be off
+                file_names = sorted(model_dir.glob("**/*.pkl"))
+                if not file_names:
+                    logger.debug(
+                        "Failed to load model for 'CRFEntityExtractor'. "
+                        "Maybe you did not provide enough training data and "
+                        "no model was trained."
+                    )
+                    return cls(config, model_storage, resource)
 
-
+                for file_name in file_names:
+                    name = file_name.stem[1:]
+                    entity_taggers[name] = joblib.load(file_name)
 
-
-            entity_extractor.crf_order = crf_order
-            return entity_extractor
+            return cls(config, model_storage, resource, entity_taggers)
         except ValueError:
             logger.warning(
                 f"Failed to load {cls.__name__} from model storage. Resource "

@@ -443,29 +417,23 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
             )
             return cls(config, model_storage, resource)
 
-    def persist(self
+    def persist(self) -> None:
         """Persist this model into the passed directory."""
-
-        data_to_store = [
-            [token.to_dict() for token in sub_list] for sub_list in dataset
-        ]
+        import joblib
 
-
-
-
-
-
-
+        with self._model_storage.write_to(self._resource) as model_dir:
+            if self.entity_taggers:
+                for idx, (name, entity_tagger) in enumerate(
+                    self.entity_taggers.items()
+                ):
+                    model_file_name = model_dir / f"{idx}{name}.pkl"
+                    joblib.dump(entity_tagger, model_file_name)
 
-    @classmethod
     def _crf_tokens_to_features(
-
-        crf_tokens: List[CRFToken],
-        config: Dict[str, Any],
-        include_tag_features: bool = False,
+        self, crf_tokens: List[CRFToken], include_tag_features: bool = False
     ) -> List[Dict[Text, Any]]:
         """Convert the list of tokens into discrete features."""
-        configured_features =
+        configured_features = self.component_config[self.CONFIG_FEATURES]
         sentence_features = []
 
         for token_idx in range(len(crf_tokens)):

@@ -476,31 +444,28 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
             half_window_size = window_size // 2
             window_range = range(-half_window_size, half_window_size + 1)
 
-            token_features =
+            token_features = self._create_features_for_token(
                 crf_tokens,
                 token_idx,
                 half_window_size,
                 window_range,
                 include_tag_features,
-                config,
             )
 
             sentence_features.append(token_features)
 
         return sentence_features
 
-    @classmethod
     def _create_features_for_token(
-
+        self,
         crf_tokens: List[CRFToken],
         token_idx: int,
         half_window_size: int,
         window_range: range,
         include_tag_features: bool,
-        config: Dict[str, Any],
     ) -> Dict[Text, Any]:
         """Convert a token into discrete features including words before and after."""
-        configured_features =
+        configured_features = self.component_config[self.CONFIG_FEATURES]
         prefixes = [str(i) for i in window_range]
 
         token_features = {}

@@ -540,13 +505,13 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
                     # set in the training data, 'matched' is either 'True' or
                     # 'False' depending on whether the token actually matches the
                     # pattern or not
-                    regex_patterns =
+                    regex_patterns = self.function_dict[feature](token)
                     for pattern_name, matched in regex_patterns.items():
                         token_features[f"{prefix}:{feature}:{pattern_name}"] = (
                             matched
                         )
                 else:
-                    value =
+                    value = self.function_dict[feature](token)
                     token_features[f"{prefix}:{feature}"] = value
 
         return token_features

@@ -670,46 +635,38 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
 
         return tags
 
-
-    def train_model(
-        cls,
-        df_train: List[List[CRFToken]],
-        config: Dict[str, Any],
-        crf_order: List[str],
-    ) -> OrderedDict[str, CRF]:
+    def _train_model(self, df_train: List[List[CRFToken]]) -> None:
         """Train the crf tagger based on the training data."""
         import sklearn_crfsuite
 
-        entity_taggers = OrderedDict()
+        self.entity_taggers = OrderedDict()
 
-        for tag_name in crf_order:
+        for tag_name in self.crf_order:
             logger.debug(f"Training CRF for '{tag_name}'.")
 
             # add entity tag features for second level CRFs
             include_tag_features = tag_name != ENTITY_ATTRIBUTE_TYPE
             X_train = (
-
+                self._crf_tokens_to_features(sentence, include_tag_features)
                 for sentence in df_train
             )
             y_train = (
-
+                self._crf_tokens_to_tags(sentence, tag_name) for sentence in df_train
            )
 
             entity_tagger = sklearn_crfsuite.CRF(
                 algorithm="lbfgs",
                 # coefficient for L1 penalty
-                c1=
+                c1=self.component_config["L1_c"],
                 # coefficient for L2 penalty
-                c2=
+                c2=self.component_config["L2_c"],
                 # stop earlier
-                max_iterations=
+                max_iterations=self.component_config["max_iterations"],
                 # include transitions that are possible, but not observed
                 all_possible_transitions=True,
             )
             entity_tagger.fit(X_train, y_train)
 
-            entity_taggers[tag_name] = entity_tagger
+            self.entity_taggers[tag_name] = entity_tagger
 
         logger.debug("Training finished.")
-
-        return entity_taggers
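The persist/load hunks above store one `joblib` file per CRF tagger, prefixed with its position (`{idx}{name}.pkl`) so the original order can be recovered by sorting the file names and stripping the first character of the stem. A minimal sketch of that naming scheme, using plain dicts as stand-ins for the `sklearn_crfsuite.CRF` models:

```python
from collections import OrderedDict
from pathlib import Path
import tempfile

import joblib

# Stand-ins for the per-tag CRF models persisted by CRFEntityExtractor.
entity_taggers = OrderedDict(
    [("entity", {"dummy": 1}), ("role", {"dummy": 2}), ("group", {"dummy": 3})]
)

with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp)

    # Persist: one file per tagger, index-prefixed so the order survives on disk.
    for idx, (name, tagger) in enumerate(entity_taggers.items()):
        joblib.dump(tagger, model_dir / f"{idx}{name}.pkl")

    # Load: sorted() restores the original order, stem[1:] strips the index prefix.
    restored = OrderedDict()
    for file_name in sorted(model_dir.glob("**/*.pkl")):
        restored[file_name.stem[1:]] = joblib.load(file_name)

assert list(restored) == list(entity_taggers)
```

Stripping a single leading character only works while there are fewer than ten taggers, which holds for the handful of tag types (entity, role, group) the extractor trains.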
rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py

@@ -1,32 +1,30 @@
 from __future__ import annotations
-
 import logging
 import re
-from typing import Any, Dict, List, Optional, Text, Tuple, Set, Type, Union
-
-import numpy as np
 import scipy.sparse
-from
-from
+from typing import Any, Dict, List, Optional, Text, Tuple, Set, Type
+from rasa.nlu.tokenizers.tokenizer import Tokenizer
 
 import rasa.shared.utils.io
 from rasa.engine.graph import GraphComponent, ExecutionContext
 from rasa.engine.recipes.default_recipe import DefaultV1Recipe
 from rasa.engine.storage.resource import Resource
 from rasa.engine.storage.storage import ModelStorage
+from rasa.nlu.featurizers.sparse_featurizer.sparse_featurizer import SparseFeaturizer
+from rasa.nlu.utils.spacy_utils import SpacyModel
+from rasa.shared.constants import DOCS_URL_COMPONENTS
+import rasa.utils.io as io_utils
+from sklearn.exceptions import NotFittedError
+from sklearn.feature_extraction.text import CountVectorizer
+from rasa.shared.nlu.training_data.training_data import TrainingData
+from rasa.shared.nlu.training_data.message import Message
+from rasa.shared.exceptions import RasaException, FileIOException
 from rasa.nlu.constants import (
     TOKENS_NAMES,
     MESSAGE_ATTRIBUTES,
     DENSE_FEATURIZABLE_ATTRIBUTES,
 )
-from rasa.nlu.featurizers.sparse_featurizer.sparse_featurizer import SparseFeaturizer
-from rasa.nlu.tokenizers.tokenizer import Tokenizer
-from rasa.nlu.utils.spacy_utils import SpacyModel
-from rasa.shared.constants import DOCS_URL_COMPONENTS
-from rasa.shared.exceptions import RasaException, FileIOException
 from rasa.shared.nlu.constants import TEXT, INTENT, INTENT_RESPONSE_KEY, ACTION_NAME
-from rasa.shared.nlu.training_data.message import Message
-from rasa.shared.nlu.training_data.training_data import TrainingData
 
 BUFFER_SLOTS_PREFIX = "buf_"
 

@@ -690,31 +688,6 @@ class CountVectorsFeaturizer(SparseFeaturizer, GraphComponent):
         """Check if any model got trained."""
         return any(value is not None for value in attribute_vocabularies.values())
 
-    @staticmethod
-    def convert_vocab(
-        vocab: Dict[str, Union[int, Optional[Dict[str, int]]]], to_int: bool
-    ) -> Dict[str, Union[None, int, np.int64, Dict[str, Union[int, np.int64]]]]:
-        """Converts numpy integers in the vocabulary to Python integers."""
-
-        def convert_value(value: int) -> Union[int, np.int64]:
-            """Helper function to convert a single value based on to_int flag."""
-            return int(value) if to_int else np.int64(value)
-
-        result_dict: Dict[
-            str, Union[None, int, np.int64, Dict[str, Union[int, np.int64]]]
-        ] = {}
-        for key, sub_dict in vocab.items():
-            if isinstance(sub_dict, int):
-                result_dict[key] = convert_value(sub_dict)
-            elif not sub_dict:
-                result_dict[key] = None
-            else:
-                result_dict[key] = {
-                    sub_key: convert_value(value) for sub_key, value in sub_dict.items()
-                }
-
-        return result_dict
-
     def persist(self) -> None:
         """Persist this model into the passed directory.
 

@@ -728,18 +701,17 @@ class CountVectorsFeaturizer(SparseFeaturizer, GraphComponent):
             attribute_vocabularies = self._collect_vectorizer_vocabularies()
             if self._is_any_model_trained(attribute_vocabularies):
                 # Definitely need to persist some vocabularies
-                featurizer_file = model_dir / "vocabularies.
+                featurizer_file = model_dir / "vocabularies.pkl"
 
                 # Only persist vocabulary from one attribute if `use_shared_vocab`.
                 # Can be loaded and distributed to all attributes.
-
+                vocab = (
                     attribute_vocabularies[TEXT]
                     if self.use_shared_vocab
                     else attribute_vocabularies
                 )
-                vocab = self.convert_vocab(loaded_vocab, to_int=True)
 
-
+                io_utils.json_pickle(featurizer_file, vocab)
 
                 # Dump OOV words separately as they might have been modified during
                 # training

@@ -814,9 +786,8 @@ class CountVectorsFeaturizer(SparseFeaturizer, GraphComponent):
         """Loads trained component (see parent class for full docstring)."""
         try:
             with model_storage.read_from(resource) as model_dir:
-                featurizer_file = model_dir / "vocabularies.
-                vocabulary =
-                vocabulary = cls.convert_vocab(vocabulary, to_int=False)
+                featurizer_file = model_dir / "vocabularies.pkl"
+                vocabulary = io_utils.json_unpickle(featurizer_file)
 
                 share_vocabulary = config["use_shared_vocab"]
 
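The featurizer hunks above drop the `convert_vocab` helper and persist the raw vectorizer vocabularies to `vocabularies.pkl`. A minimal sketch of rebuilding a vectorizer from a stored vocabulary, using `joblib` in place of Rasa's `io_utils` helpers and a scikit-learn `CountVectorizer` as a stand-in for the per-attribute vectorizers (the file name mirrors the diff, everything else is illustrative):

```python
from pathlib import Path
import tempfile

import joblib
from sklearn.feature_extraction.text import CountVectorizer

# Fit a vectorizer and persist only its vocabulary, as the featurizer above does.
fitted = CountVectorizer().fit(["hello world", "hello rasa"])

with tempfile.TemporaryDirectory() as tmp:
    featurizer_file = Path(tmp) / "vocabularies.pkl"
    joblib.dump(fitted.vocabulary_, featurizer_file)

    # On load, a fresh vectorizer is constructed around the stored vocabulary;
    # with a fixed vocabulary it can transform without being refitted.
    restored = CountVectorizer(vocabulary=joblib.load(featurizer_file))

    assert (
        restored.transform(["hello world"]).toarray()
        == fitted.transform(["hello world"]).toarray()
    ).all()
```

Because pickle-style persistence keeps whatever integer type the vectorizer produced, the numpy-to-int conversion that `convert_vocab` performed for the old JSON-based persistence is no longer needed.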