rasa-pro 3.8.18__py3-none-any.whl → 3.9.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic.
- README.md +6 -42
- rasa/__main__.py +14 -9
- rasa/anonymization/anonymization_pipeline.py +0 -1
- rasa/anonymization/anonymization_rule_executor.py +3 -3
- rasa/anonymization/utils.py +4 -3
- rasa/api.py +2 -2
- rasa/cli/arguments/default_arguments.py +1 -1
- rasa/cli/arguments/run.py +2 -2
- rasa/cli/arguments/test.py +1 -1
- rasa/cli/arguments/train.py +10 -10
- rasa/cli/e2e_test.py +27 -7
- rasa/cli/export.py +0 -1
- rasa/cli/license.py +3 -3
- rasa/cli/project_templates/calm/actions/action_template.py +1 -1
- rasa/cli/project_templates/calm/config.yml +1 -1
- rasa/cli/project_templates/calm/credentials.yml +1 -1
- rasa/cli/project_templates/calm/data/flows/add_contact.yml +1 -1
- rasa/cli/project_templates/calm/data/flows/remove_contact.yml +1 -1
- rasa/cli/project_templates/calm/domain/add_contact.yml +8 -2
- rasa/cli/project_templates/calm/domain/list_contacts.yml +3 -0
- rasa/cli/project_templates/calm/domain/remove_contact.yml +9 -2
- rasa/cli/project_templates/calm/domain/shared.yml +5 -0
- rasa/cli/project_templates/calm/endpoints.yml +4 -4
- rasa/cli/project_templates/default/actions/actions.py +1 -1
- rasa/cli/project_templates/default/config.yml +5 -5
- rasa/cli/project_templates/default/credentials.yml +1 -1
- rasa/cli/project_templates/default/endpoints.yml +4 -4
- rasa/cli/project_templates/default/tests/test_stories.yml +1 -1
- rasa/cli/project_templates/tutorial/config.yml +1 -1
- rasa/cli/project_templates/tutorial/credentials.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +6 -0
- rasa/cli/project_templates/tutorial/domain.yml +4 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +6 -6
- rasa/cli/run.py +0 -1
- rasa/cli/scaffold.py +3 -2
- rasa/cli/studio/download.py +11 -0
- rasa/cli/studio/studio.py +180 -24
- rasa/cli/studio/upload.py +0 -8
- rasa/cli/telemetry.py +18 -6
- rasa/cli/utils.py +21 -10
- rasa/cli/x.py +3 -2
- rasa/constants.py +1 -1
- rasa/core/actions/action.py +90 -315
- rasa/core/actions/action_exceptions.py +24 -0
- rasa/core/actions/constants.py +3 -0
- rasa/core/actions/custom_action_executor.py +188 -0
- rasa/core/actions/forms.py +11 -7
- rasa/core/actions/grpc_custom_action_executor.py +251 -0
- rasa/core/actions/http_custom_action_executor.py +140 -0
- rasa/core/actions/loops.py +3 -0
- rasa/core/actions/two_stage_fallback.py +1 -1
- rasa/core/agent.py +2 -4
- rasa/core/brokers/pika.py +1 -2
- rasa/core/channels/audiocodes.py +1 -1
- rasa/core/channels/botframework.py +0 -1
- rasa/core/channels/callback.py +0 -1
- rasa/core/channels/console.py +6 -8
- rasa/core/channels/development_inspector.py +1 -1
- rasa/core/channels/facebook.py +0 -3
- rasa/core/channels/hangouts.py +0 -6
- rasa/core/channels/inspector/dist/assets/{arc-5623b6dc.js → arc-b6e548fe.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-685c106a.js → c4Diagram-d0fbc5ce-fa03ac9e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-8cbed007.js → classDiagram-936ed81e-ee67392a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-5889cf12.js → classDiagram-v2-c3cb15f1-9b283fae.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-24c249d7.js → createText-62fc7601-8b6fcc2a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-7dd06a75.js → edges-f2ad444c-22e77f4f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-62c1e54c.js → erDiagram-9d236eb7-60ffc87f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-ce49b86f.js → flowDb-1972c806-9dd802e4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-4067e48f.js → flowDiagram-7ea5b25a-5fa1912f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-59fe4051.js → flowchart-elk-definition-abe16c3d-622a1fd2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-47e3a43b.js → ganttDiagram-9b5ea136-e285a63a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-5a2ac0d9.js → gitGraphDiagram-99d0ae7c-f237bdca.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-dfb8efc4.js → index-2c4b9a3b-4b03d70e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-268a75c0.js → index-a5d3e69d.js} +4 -4
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-b0c470f2.js → infoDiagram-736b4530-72a0fa5f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-2edb829a.js → journeyDiagram-df861f2b-82218c41.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-b6873d69.js → layout-78cff630.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-1efc5781.js → line-5038b469.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-661e9b94.js → linear-c4fc4098.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-2d2e727f.js → mindmap-definition-beec6740-c33c8ea6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-9d3ea93d.js → pieDiagram-dbbf0591-a8d03059.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-06a178a2.js → quadrantDiagram-4d7f4fd6-6a0e56b2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-0bfedffc.js → requirementDiagram-6fc4c22a-2dc7c7bd.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-d76d0a04.js → sankeyDiagram-8f13d901-2360fe39.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-37bb4341.js → sequenceDiagram-b655622a-41b9f9ad.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-f52f7f57.js → stateDiagram-59f0c015-0aad326f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-4a986a20.js → stateDiagram-v2-2b26beab-9847d984.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-7dd9ae12.js → styles-080da4f6-564d890e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-46e1ca14.js → styles-3dcbcfbf-38957613.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-4a97439a.js → styles-9c745c82-f0fc6921.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-823917a3.js → svgDrawCommon-4835440b-ef3c5a77.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-9ea72896.js → timeline-definition-5b62e21b-bf3e91c1.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-b631a8b6.js → xychartDiagram-2b33534f-4d4026c0.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +10 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +4 -7
- rasa/core/channels/inspector/src/helpers/formatters.ts +3 -2
- rasa/core/channels/rest.py +36 -21
- rasa/core/channels/rocketchat.py +0 -1
- rasa/core/channels/socketio.py +1 -1
- rasa/core/channels/telegram.py +3 -3
- rasa/core/channels/webexteams.py +0 -1
- rasa/core/concurrent_lock_store.py +1 -1
- rasa/core/evaluation/marker_base.py +1 -3
- rasa/core/evaluation/marker_stats.py +1 -2
- rasa/core/featurizers/single_state_featurizer.py +3 -26
- rasa/core/featurizers/tracker_featurizers.py +18 -122
- rasa/core/information_retrieval/__init__.py +7 -0
- rasa/core/information_retrieval/faiss.py +9 -4
- rasa/core/information_retrieval/information_retrieval.py +64 -7
- rasa/core/information_retrieval/milvus.py +7 -14
- rasa/core/information_retrieval/qdrant.py +8 -15
- rasa/core/lock_store.py +0 -1
- rasa/core/migrate.py +1 -2
- rasa/core/nlg/callback.py +3 -4
- rasa/core/policies/enterprise_search_policy.py +86 -22
- rasa/core/policies/enterprise_search_prompt_template.jinja2 +4 -41
- rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
- rasa/core/policies/flows/flow_executor.py +104 -2
- rasa/core/policies/intentless_policy.py +7 -9
- rasa/core/policies/memoization.py +3 -3
- rasa/core/policies/policy.py +18 -9
- rasa/core/policies/rule_policy.py +8 -11
- rasa/core/policies/ted_policy.py +61 -88
- rasa/core/policies/unexpected_intent_policy.py +8 -17
- rasa/core/processor.py +136 -47
- rasa/core/run.py +41 -25
- rasa/core/secrets_manager/endpoints.py +2 -2
- rasa/core/secrets_manager/vault.py +6 -8
- rasa/core/test.py +3 -5
- rasa/core/tracker_store.py +49 -14
- rasa/core/train.py +1 -3
- rasa/core/training/interactive.py +9 -6
- rasa/core/utils.py +5 -10
- rasa/dialogue_understanding/coexistence/intent_based_router.py +11 -4
- rasa/dialogue_understanding/coexistence/llm_based_router.py +2 -3
- rasa/dialogue_understanding/commands/__init__.py +4 -0
- rasa/dialogue_understanding/commands/can_not_handle_command.py +9 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +9 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +38 -0
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/clarify_command.py +9 -0
- rasa/dialogue_understanding/commands/correct_slots_command.py +9 -0
- rasa/dialogue_understanding/commands/error_command.py +12 -0
- rasa/dialogue_understanding/commands/handle_code_change_command.py +9 -0
- rasa/dialogue_understanding/commands/human_handoff_command.py +9 -0
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/noop_command.py +9 -0
- rasa/dialogue_understanding/commands/set_slot_command.py +34 -3
- rasa/dialogue_understanding/commands/skip_question_command.py +9 -0
- rasa/dialogue_understanding/commands/start_flow_command.py +9 -0
- rasa/dialogue_understanding/generator/__init__.py +16 -1
- rasa/dialogue_understanding/generator/command_generator.py +92 -6
- rasa/dialogue_understanding/generator/constants.py +18 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +7 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +467 -0
- rasa/dialogue_understanding/generator/llm_command_generator.py +39 -609
- rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
- rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +827 -0
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +69 -8
- rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +345 -0
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +44 -39
- rasa/dialogue_understanding/processor/command_processor.py +111 -3
- rasa/e2e_test/constants.py +1 -0
- rasa/e2e_test/e2e_test_case.py +44 -0
- rasa/e2e_test/e2e_test_runner.py +114 -11
- rasa/e2e_test/e2e_test_schema.yml +18 -0
- rasa/engine/caching.py +0 -1
- rasa/engine/graph.py +18 -6
- rasa/engine/recipes/config_files/default_config.yml +3 -3
- rasa/engine/recipes/default_components.py +1 -1
- rasa/engine/recipes/default_recipe.py +4 -5
- rasa/engine/recipes/recipe.py +1 -1
- rasa/engine/runner/dask.py +3 -9
- rasa/engine/storage/local_model_storage.py +0 -2
- rasa/engine/validation.py +179 -145
- rasa/exceptions.py +2 -2
- rasa/graph_components/validators/default_recipe_validator.py +3 -5
- rasa/hooks.py +0 -1
- rasa/model.py +1 -1
- rasa/model_training.py +1 -0
- rasa/nlu/classifiers/diet_classifier.py +33 -52
- rasa/nlu/classifiers/logistic_regression_classifier.py +9 -22
- rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
- rasa/nlu/extractors/crf_entity_extractor.py +54 -97
- rasa/nlu/extractors/duckling_entity_extractor.py +1 -1
- rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +1 -5
- rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +0 -4
- rasa/nlu/featurizers/featurizer.py +1 -1
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +18 -49
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +26 -64
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
- rasa/nlu/persistor.py +68 -26
- rasa/nlu/selectors/response_selector.py +7 -10
- rasa/nlu/test.py +0 -3
- rasa/nlu/utils/hugging_face/registry.py +1 -1
- rasa/nlu/utils/spacy_utils.py +1 -3
- rasa/server.py +22 -7
- rasa/shared/constants.py +12 -1
- rasa/shared/core/command_payload_reader.py +109 -0
- rasa/shared/core/constants.py +4 -5
- rasa/shared/core/domain.py +57 -56
- rasa/shared/core/events.py +4 -7
- rasa/shared/core/flows/flow.py +9 -0
- rasa/shared/core/flows/flows_list.py +12 -0
- rasa/shared/core/flows/steps/action.py +7 -2
- rasa/shared/core/generator.py +12 -11
- rasa/shared/core/slot_mappings.py +315 -24
- rasa/shared/core/slots.py +4 -2
- rasa/shared/core/trackers.py +32 -14
- rasa/shared/core/training_data/loading.py +0 -1
- rasa/shared/core/training_data/story_reader/story_reader.py +3 -3
- rasa/shared/core/training_data/story_reader/yaml_story_reader.py +11 -11
- rasa/shared/core/training_data/story_writer/yaml_story_writer.py +5 -3
- rasa/shared/core/training_data/structures.py +1 -1
- rasa/shared/core/training_data/visualization.py +1 -1
- rasa/shared/data.py +58 -1
- rasa/shared/exceptions.py +36 -2
- rasa/shared/importers/importer.py +1 -2
- rasa/shared/importers/rasa.py +0 -1
- rasa/shared/nlu/constants.py +2 -0
- rasa/shared/nlu/training_data/entities_parser.py +1 -2
- rasa/shared/nlu/training_data/features.py +2 -120
- rasa/shared/nlu/training_data/formats/dialogflow.py +3 -2
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +3 -5
- rasa/shared/nlu/training_data/formats/readerwriter.py +0 -1
- rasa/shared/nlu/training_data/message.py +13 -0
- rasa/shared/nlu/training_data/training_data.py +0 -2
- rasa/shared/providers/openai/session_handler.py +2 -2
- rasa/shared/utils/constants.py +3 -0
- rasa/shared/utils/io.py +11 -1
- rasa/shared/utils/llm.py +1 -2
- rasa/shared/utils/pykwalify_extensions.py +1 -0
- rasa/shared/utils/schemas/domain.yml +3 -0
- rasa/shared/utils/yaml.py +44 -35
- rasa/studio/auth.py +26 -10
- rasa/studio/constants.py +2 -0
- rasa/studio/data_handler.py +114 -107
- rasa/studio/download.py +160 -27
- rasa/studio/results_logger.py +137 -0
- rasa/studio/train.py +6 -7
- rasa/studio/upload.py +159 -134
- rasa/telemetry.py +188 -34
- rasa/tracing/config.py +18 -3
- rasa/tracing/constants.py +26 -2
- rasa/tracing/instrumentation/attribute_extractors.py +50 -41
- rasa/tracing/instrumentation/instrumentation.py +290 -44
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +7 -5
- rasa/tracing/instrumentation/metrics.py +109 -21
- rasa/tracing/metric_instrument_provider.py +83 -3
- rasa/utils/cli.py +2 -1
- rasa/utils/common.py +1 -1
- rasa/utils/endpoints.py +1 -2
- rasa/utils/io.py +72 -6
- rasa/utils/licensing.py +246 -31
- rasa/utils/ml_utils.py +1 -1
- rasa/utils/tensorflow/data_generator.py +1 -1
- rasa/utils/tensorflow/environment.py +1 -1
- rasa/utils/tensorflow/model_data.py +201 -12
- rasa/utils/tensorflow/model_data_utils.py +499 -500
- rasa/utils/tensorflow/models.py +5 -6
- rasa/utils/tensorflow/rasa_layers.py +15 -15
- rasa/utils/train_utils.py +1 -1
- rasa/utils/url_tools.py +53 -0
- rasa/validator.py +305 -3
- rasa/version.py +1 -1
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.14.dist-info}/METADATA +25 -61
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.14.dist-info}/RECORD +276 -259
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-85583a23.js +0 -1
- rasa/utils/tensorflow/feature_array.py +0 -370
- /rasa/dialogue_understanding/generator/{command_prompt_template.jinja2 → single_step/command_prompt_template.jinja2} +0 -0
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.14.dist-info}/NOTICE +0 -0
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.14.dist-info}/WHEEL +0 -0
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.14.dist-info}/entry_points.txt +0 -0
rasa/nlu/classifiers/logistic_regression_classifier.py

@@ -1,21 +1,22 @@
 from typing import Any, Text, Dict, List, Type, Tuple

+import joblib
 import structlog
 from scipy.sparse import hstack, vstack, csr_matrix
 from sklearn.exceptions import NotFittedError
 from sklearn.linear_model import LogisticRegression
 from sklearn.utils.validation import check_is_fitted

-from rasa.engine.graph import ExecutionContext, GraphComponent
-from rasa.engine.recipes.default_recipe import DefaultV1Recipe
 from rasa.engine.storage.resource import Resource
 from rasa.engine.storage.storage import ModelStorage
+from rasa.engine.recipes.default_recipe import DefaultV1Recipe
+from rasa.engine.graph import ExecutionContext, GraphComponent
 from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
-from rasa.nlu.classifiers.classifier import IntentClassifier
 from rasa.nlu.featurizers.featurizer import Featurizer
-from rasa.
-from rasa.shared.nlu.training_data.message import Message
+from rasa.nlu.classifiers.classifier import IntentClassifier
 from rasa.shared.nlu.training_data.training_data import TrainingData
+from rasa.shared.nlu.training_data.message import Message
+from rasa.shared.nlu.constants import TEXT, INTENT
 from rasa.utils.tensorflow.constants import RANKING_LENGTH

 structlogger = structlog.get_logger()

@@ -183,11 +184,9 @@ class LogisticRegressionClassifier(IntentClassifier, GraphComponent):

     def persist(self) -> None:
         """Persist this model into the passed directory."""
-        import skops.io as sio
-
         with self._model_storage.write_to(self._resource) as model_dir:
-            path = model_dir / f"{self._resource.name}.
-
+            path = model_dir / f"{self._resource.name}.joblib"
+            joblib.dump(self.clf, path)
             structlogger.debug(
                 "logistic_regression_classifier.persist",
                 event_info=f"Saved intent classifier to '{path}'.",

@@ -203,21 +202,9 @@ class LogisticRegressionClassifier(IntentClassifier, GraphComponent):
         **kwargs: Any,
     ) -> "LogisticRegressionClassifier":
         """Loads trained component (see parent class for full docstring)."""
-        import skops.io as sio
-
         try:
             with model_storage.read_from(resource) as model_dir:
-
-                unknown_types = sio.get_untrusted_types(file=classifier_file)
-
-                if unknown_types:
-                    structlogger.error(
-                        f"Untrusted types found when loading {classifier_file}!",
-                        unknown_types=unknown_types,
-                    )
-                    raise ValueError()
-
-                classifier = sio.load(classifier_file, trusted=unknown_types)
+                classifier = joblib.load(model_dir / f"{resource.name}.joblib")
             component = cls(
                 config, execution_context.node_name, model_storage, resource
             )
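The hunks above drop the skops-based serialization in LogisticRegressionClassifier in favour of plain joblib. Below is a minimal sketch of that persist/load round-trip, assuming a generic `model_dir` path and resource name; `persist_classifier` and `load_classifier` are illustrative helpers, not Rasa's API.

```python
# Minimal sketch of the joblib round-trip the new persist/load rely on.
# `model_dir` and `name` are illustrative stand-ins, not Rasa component arguments.
from pathlib import Path

import joblib
from sklearn.linear_model import LogisticRegression


def persist_classifier(clf: LogisticRegression, model_dir: Path, name: str) -> Path:
    """Serialize the fitted estimator to '<name>.joblib' inside model_dir."""
    path = model_dir / f"{name}.joblib"
    joblib.dump(clf, path)
    return path


def load_classifier(model_dir: Path, name: str) -> LogisticRegression:
    """Deserialize the estimator written by persist_classifier."""
    return joblib.load(model_dir / f"{name}.joblib")
```

Because joblib is pickle-based, the explicit untrusted-type checks that skops required disappear from the load path, which is why the diff removes them.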
rasa/nlu/classifiers/sklearn_intent_classifier.py

@@ -1,6 +1,6 @@
 from __future__ import annotations
-
 import logging
+from rasa.nlu.featurizers.dense_featurizer.dense_featurizer import DenseFeaturizer
 import typing
 import warnings
 from typing import Any, Dict, List, Optional, Text, Tuple, Type

@@ -8,18 +8,18 @@ from typing import Any, Dict, List, Optional, Text, Tuple, Type
 import numpy as np

 import rasa.shared.utils.io
+import rasa.utils.io as io_utils
 from rasa.engine.graph import GraphComponent, ExecutionContext
 from rasa.engine.recipes.default_recipe import DefaultV1Recipe
 from rasa.engine.storage.resource import Resource
 from rasa.engine.storage.storage import ModelStorage
-from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
-from rasa.nlu.classifiers.classifier import IntentClassifier
-from rasa.nlu.featurizers.dense_featurizer.dense_featurizer import DenseFeaturizer
 from rasa.shared.constants import DOCS_URL_TRAINING_DATA_NLU
+from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
 from rasa.shared.exceptions import RasaException
 from rasa.shared.nlu.constants import TEXT
-from rasa.
+from rasa.nlu.classifiers.classifier import IntentClassifier
 from rasa.shared.nlu.training_data.training_data import TrainingData
+from rasa.shared.nlu.training_data.message import Message
 from rasa.utils.tensorflow.constants import FEATURIZERS

 logger = logging.getLogger(__name__)

@@ -266,20 +266,14 @@ class SklearnIntentClassifier(GraphComponent, IntentClassifier):

     def persist(self) -> None:
         """Persist this model into the passed directory."""
-        import skops.io as sio
-
         with self._model_storage.write_to(self._resource) as model_dir:
             file_name = self.__class__.__name__
-            classifier_file_name = model_dir / f"{file_name}_classifier.
-            encoder_file_name = model_dir / f"{file_name}_encoder.
+            classifier_file_name = model_dir / f"{file_name}_classifier.pkl"
+            encoder_file_name = model_dir / f"{file_name}_encoder.pkl"

             if self.clf and self.le:
-
-
-                rasa.shared.utils.io.dump_obj_as_json_to_file(
-                    encoder_file_name, list(self.le.classes_)
-                )
-                sio.dump(self.clf.best_estimator_, classifier_file_name)
+                io_utils.json_pickle(encoder_file_name, self.le.classes_)
+                io_utils.json_pickle(classifier_file_name, self.clf.best_estimator_)

     @classmethod
     def load(

@@ -292,36 +286,21 @@ class SklearnIntentClassifier(GraphComponent, IntentClassifier):
     ) -> SklearnIntentClassifier:
         """Loads trained component (see parent class for full docstring)."""
         from sklearn.preprocessing import LabelEncoder
-        import skops.io as sio

         try:
             with model_storage.read_from(resource) as model_dir:
                 file_name = cls.__name__
-                classifier_file = model_dir / f"{file_name}_classifier.
+                classifier_file = model_dir / f"{file_name}_classifier.pkl"

                 if classifier_file.exists():
-
-
-                    if unknown_types:
-                        logger.error(
-                            f"Untrusted types ({unknown_types}) found when "
-                            f"loading {classifier_file}!"
-                        )
-                        raise ValueError()
-                    else:
-                        classifier = sio.load(classifier_file, trusted=unknown_types)
-
-                    encoder_file = model_dir / f"{file_name}_encoder.json"
-                    classes = rasa.shared.utils.io.read_json_file(encoder_file)
+                    classifier = io_utils.json_unpickle(classifier_file)

+                    encoder_file = model_dir / f"{file_name}_encoder.pkl"
+                    classes = io_utils.json_unpickle(encoder_file)
                     encoder = LabelEncoder()
-
-
-                    )
-                    # convert list of strings (class labels) back to numpy array of
-                    # strings
-                    intent_classifier.transform_labels_str2num(classes)
-                    return intent_classifier
+                    encoder.classes_ = classes
+
+                    return cls(config, model_storage, resource, classifier, encoder)
         except ValueError:
             logger.debug(
                 f"Failed to load '{cls.__name__}' from model storage. Resource "
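SklearnIntentClassifier switches from skops plus a JSON sidecar to the `json_pickle`/`json_unpickle` helpers in `rasa.utils.io` for both the grid-searched estimator and the label-encoder classes. As a rough approximation of what those helpers do (the real implementations live in `rasa/utils/io.py` and may differ in detail), here is a jsonpickle-based sketch:

```python
# Illustrative approximation of json_pickle/json_unpickle; not the actual
# rasa.utils.io implementation.
from pathlib import Path
from typing import Any

import jsonpickle
import jsonpickle.ext.numpy as jsonpickle_numpy

# Register numpy handlers so arrays such as LabelEncoder.classes_ round-trip.
jsonpickle_numpy.register_handlers()


def json_pickle(file_name: Path, obj: Any) -> None:
    """Serialize an arbitrary object to a JSON document on disk."""
    Path(file_name).write_text(jsonpickle.encode(obj), encoding="utf-8")


def json_unpickle(file_name: Path) -> Any:
    """Restore an object previously written by json_pickle."""
    return jsonpickle.decode(Path(file_name).read_text(encoding="utf-8"))
```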
rasa/nlu/extractors/crf_entity_extractor.py

@@ -4,9 +4,9 @@ from collections import OrderedDict
 from enum import Enum
 import logging
 import typing
-from typing import Any, Dict, List, Optional, Text, Tuple, Callable, Type

 import numpy as np
+from typing import Any, Dict, List, Optional, Text, Tuple, Callable, Type

 import rasa.nlu.utils.bilou_utils as bilou_utils
 import rasa.shared.utils.io

@@ -41,9 +41,6 @@ if typing.TYPE_CHECKING:
     from sklearn_crfsuite import CRF


-CONFIG_FEATURES = "features"
-
-
 class CRFToken:
     def __init__(
         self,

@@ -63,29 +60,6 @@ class CRFToken:
         self.entity_role_tag = entity_role_tag
         self.entity_group_tag = entity_group_tag

-    def to_dict(self) -> Dict[str, Any]:
-        return {
-            "text": self.text,
-            "pos_tag": self.pos_tag,
-            "pattern": self.pattern,
-            "dense_features": [str(x) for x in list(self.dense_features)],
-            "entity_tag": self.entity_tag,
-            "entity_role_tag": self.entity_role_tag,
-            "entity_group_tag": self.entity_group_tag,
-        }
-
-    @classmethod
-    def create_from_dict(cls, data: Dict[str, Any]) -> "CRFToken":
-        return cls(
-            data["text"],
-            data["pos_tag"],
-            data["pattern"],
-            np.array([float(x) for x in data["dense_features"]]),
-            data["entity_tag"],
-            data["entity_role_tag"],
-            data["entity_group_tag"],
-        )
-

 class CRFEntityExtractorOptions(str, Enum):
     """Features that can be used for the 'CRFEntityExtractor'."""

@@ -114,6 +88,8 @@ class CRFEntityExtractorOptions(str, Enum):
 class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
     """Implements conditional random fields (CRF) to do named entity recognition."""

+    CONFIG_FEATURES = "features"
+
     function_dict: Dict[Text, Callable[[CRFToken], Any]] = {  # noqa: RUF012
         CRFEntityExtractorOptions.LOW: lambda crf_token: crf_token.text.lower(),
         CRFEntityExtractorOptions.TITLE: lambda crf_token: crf_token.text.istitle(),

@@ -132,7 +108,7 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
         CRFEntityExtractorOptions.DIGIT: lambda crf_token: crf_token.text.isdigit(),
         CRFEntityExtractorOptions.PATTERN: lambda crf_token: crf_token.pattern,
         CRFEntityExtractorOptions.TEXT_DENSE_FEATURES: (
-            lambda crf_token: CRFEntityExtractor._convert_dense_features_for_crfsuite(
+            lambda crf_token: CRFEntityExtractor._convert_dense_features_for_crfsuite(
                 crf_token
             )
         ),

@@ -161,7 +137,7 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
         # "is the preceding token in title case?"
         # POS features require SpacyTokenizer
         # pattern feature require RegexFeaturizer
-        CONFIG_FEATURES: [
+        CRFEntityExtractor.CONFIG_FEATURES: [
            [
                 CRFEntityExtractorOptions.LOW,
                 CRFEntityExtractorOptions.TITLE,

@@ -224,7 +200,7 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
         )

     def _validate_configuration(self) -> None:
-        if len(self.component_config.get(CONFIG_FEATURES, [])) % 2 != 1:
+        if len(self.component_config.get(self.CONFIG_FEATURES, [])) % 2 != 1:
             raise ValueError(
                 "Need an odd number of crf feature lists to have a center word."
             )

@@ -275,11 +251,9 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
         ]
         dataset = [self._convert_to_crf_tokens(example) for example in entity_examples]

-        self.
-            dataset, self.component_config, self.crf_order
-        )
+        self._train_model(dataset)

-        self.persist(
+        self.persist()

         return self._resource


@@ -325,9 +299,7 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
         if include_tag_features:
             self._add_tag_to_crf_token(crf_tokens, predictions)

-        features = self._crf_tokens_to_features(
-            crf_tokens, self.component_config, include_tag_features
-        )
+        features = self._crf_tokens_to_features(crf_tokens, include_tag_features)
         predictions[tag_name] = entity_tagger.predict_marginals_single(features)

         # convert predictions into a list of tags and a list of confidences

@@ -417,25 +389,27 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
         **kwargs: Any,
     ) -> CRFEntityExtractor:
         """Loads trained component (see parent class for full docstring)."""
+        import joblib
+
         try:
+            entity_taggers = OrderedDict()
             with model_storage.read_from(resource) as model_dir:
-
-
-                )
-
-
-
-
-
-
-
-                ]
+                # We have to load in the same order as we persisted things as otherwise
+                # the predictions might be off
+                file_names = sorted(model_dir.glob("**/*.pkl"))
+                if not file_names:
+                    logger.debug(
+                        "Failed to load model for 'CRFEntityExtractor'. "
+                        "Maybe you did not provide enough training data and "
+                        "no model was trained."
+                    )
+                    return cls(config, model_storage, resource)

-
+                for file_name in file_names:
+                    name = file_name.stem[1:]
+                    entity_taggers[name] = joblib.load(file_name)

-
-                entity_extractor.crf_order = crf_order
-                return entity_extractor
+                return cls(config, model_storage, resource, entity_taggers)
         except ValueError:
             logger.warning(
                 f"Failed to load {cls.__name__} from model storage. Resource "

@@ -443,29 +417,23 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
             )
             return cls(config, model_storage, resource)

-    def persist(self
+    def persist(self) -> None:
         """Persist this model into the passed directory."""
-
-        data_to_store = [
-            [token.to_dict() for token in sub_list] for sub_list in dataset
-        ]
+        import joblib

-
-
-
-
-
-
+        with self._model_storage.write_to(self._resource) as model_dir:
+            if self.entity_taggers:
+                for idx, (name, entity_tagger) in enumerate(
+                    self.entity_taggers.items()
+                ):
+                    model_file_name = model_dir / f"{idx}{name}.pkl"
+                    joblib.dump(entity_tagger, model_file_name)

-    @classmethod
     def _crf_tokens_to_features(
-
-        crf_tokens: List[CRFToken],
-        config: Dict[str, Any],
-        include_tag_features: bool = False,
+        self, crf_tokens: List[CRFToken], include_tag_features: bool = False
     ) -> List[Dict[Text, Any]]:
         """Convert the list of tokens into discrete features."""
-        configured_features =
+        configured_features = self.component_config[self.CONFIG_FEATURES]
         sentence_features = []

         for token_idx in range(len(crf_tokens)):

@@ -476,31 +444,28 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
         half_window_size = window_size // 2
         window_range = range(-half_window_size, half_window_size + 1)

-        token_features =
+        token_features = self._create_features_for_token(
             crf_tokens,
             token_idx,
             half_window_size,
             window_range,
             include_tag_features,
-            config,
         )

         sentence_features.append(token_features)

         return sentence_features

-    @classmethod
     def _create_features_for_token(
-
+        self,
         crf_tokens: List[CRFToken],
         token_idx: int,
         half_window_size: int,
         window_range: range,
         include_tag_features: bool,
-        config: Dict[str, Any],
     ) -> Dict[Text, Any]:
         """Convert a token into discrete features including words before and after."""
-        configured_features =
+        configured_features = self.component_config[self.CONFIG_FEATURES]
         prefixes = [str(i) for i in window_range]

         token_features = {}

@@ -540,13 +505,13 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):
                     # set in the training data, 'matched' is either 'True' or
                     # 'False' depending on whether the token actually matches the
                     # pattern or not
-                    regex_patterns =
+                    regex_patterns = self.function_dict[feature](token)
                     for pattern_name, matched in regex_patterns.items():
-                        token_features[
-
-
+                        token_features[f"{prefix}:{feature}:{pattern_name}"] = (
+                            matched
+                        )
                 else:
-                    value =
+                    value = self.function_dict[feature](token)
                     token_features[f"{prefix}:{feature}"] = value

         return token_features

@@ -670,46 +635,38 @@ class CRFEntityExtractor(GraphComponent, EntityExtractorMixin):

         return tags

-
-    def train_model(
-        cls,
-        df_train: List[List[CRFToken]],
-        config: Dict[str, Any],
-        crf_order: List[str],
-    ) -> OrderedDict[str, CRF]:
+    def _train_model(self, df_train: List[List[CRFToken]]) -> None:
         """Train the crf tagger based on the training data."""
         import sklearn_crfsuite

-        entity_taggers = OrderedDict()
+        self.entity_taggers = OrderedDict()

-        for tag_name in crf_order:
+        for tag_name in self.crf_order:
             logger.debug(f"Training CRF for '{tag_name}'.")

             # add entity tag features for second level CRFs
             include_tag_features = tag_name != ENTITY_ATTRIBUTE_TYPE
             X_train = (
-
+                self._crf_tokens_to_features(sentence, include_tag_features)
                 for sentence in df_train
             )
             y_train = (
-
+                self._crf_tokens_to_tags(sentence, tag_name) for sentence in df_train
             )

             entity_tagger = sklearn_crfsuite.CRF(
                 algorithm="lbfgs",
                 # coefficient for L1 penalty
-                c1=
+                c1=self.component_config["L1_c"],
                 # coefficient for L2 penalty
-                c2=
+                c2=self.component_config["L2_c"],
                 # stop earlier
-                max_iterations=
+                max_iterations=self.component_config["max_iterations"],
                 # include transitions that are possible, but not observed
                 all_possible_transitions=True,
             )
             entity_tagger.fit(X_train, y_train)

-            entity_taggers[tag_name] = entity_tagger
+            self.entity_taggers[tag_name] = entity_tagger

         logger.debug("Training finished.")
-
-        return entity_taggers
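CRFEntityExtractor now keeps its per-tag CRF taggers on the instance, writes each one to an index-prefixed `.pkl` file with joblib, and reloads them in sorted filename order so the original ordering (entity type before role/group) is preserved. A sketch of that index-prefixed pattern with generic names; `persist_taggers` and `load_taggers` are illustrative, not the component's real methods.

```python
# Sketch of the index-prefixed persistence pattern shown above, with generic names.
# In the real component the values are sklearn_crfsuite.CRF taggers keyed by entity tag.
from collections import OrderedDict
from pathlib import Path
from typing import Any

import joblib


def persist_taggers(taggers: "OrderedDict[str, Any]", model_dir: Path) -> None:
    # The numeric prefix encodes the original ordering in the file name.
    for idx, (name, tagger) in enumerate(taggers.items()):
        joblib.dump(tagger, model_dir / f"{idx}{name}.pkl")


def load_taggers(model_dir: Path) -> "OrderedDict[str, Any]":
    taggers: "OrderedDict[str, Any]" = OrderedDict()
    # Sorting the file names restores the persisted order; stripping the first
    # character drops the single-digit index prefix added at persist time.
    for file_name in sorted(model_dir.glob("**/*.pkl")):
        taggers[file_name.stem[1:]] = joblib.load(file_name)
    return taggers
```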
rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py

@@ -189,7 +189,7 @@ class ConveRTFeaturizer(DenseFeaturizer, GraphComponent):
             f"Parameter 'model_url' of "
             f"'{ConveRTFeaturizer.__name__}' was "
             f"set to '{model_url}' which is strictly reserved for pytests of "
-            f"Rasa
+            f"Rasa Pro only. Due to licensing issues you are "
             f"not allowed to use the model from this URL. "
             f"You can either use a community hosted URL or if you have a "
             f"local copy of the model, pass the path to the directory "

@@ -323,13 +323,11 @@ class ConveRTFeaturizer(DenseFeaturizer, GraphComponent):
         return texts

     def _sentence_encoding_of_text(self, batch: List[Text]) -> np.ndarray:
-
         return self.sentence_encoding_signature(tf.convert_to_tensor(batch))[
             "default"
         ].numpy()

     def _sequence_encoding_of_text(self, batch: List[Text]) -> np.ndarray:
-
         return self.sequence_encoding_signature(tf.convert_to_tensor(batch))[
             "sequence_encoding"
         ].numpy()

@@ -346,7 +344,6 @@ class ConveRTFeaturizer(DenseFeaturizer, GraphComponent):
         batch_size = 64

         for attribute in DENSE_FEATURIZABLE_ATTRIBUTES:
-
             non_empty_examples = list(
                 filter(lambda x: x.get(attribute), training_data.training_examples)
             )

@@ -410,7 +407,6 @@ class ConveRTFeaturizer(DenseFeaturizer, GraphComponent):
         )

     def _tokenize(self, sentence: Text) -> Any:
-
         return self.tokenize_signature(tf.convert_to_tensor([sentence]))[
             "default"
         ].numpy()
rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py

@@ -316,7 +316,6 @@ class LanguageModelFeaturizer(DenseFeaturizer, GraphComponent):
         batch_token_ids = []
         batch_tokens = []
         for example in batch_examples:
-
             example_tokens, example_token_ids = self._tokenize_example(
                 example, attribute
             )

@@ -416,7 +415,6 @@ class LanguageModelFeaturizer(DenseFeaturizer, GraphComponent):
         # This doesn't affect the computation since we compute an attention mask
         # anyways.
         for example_token_ids in batch_token_ids:
-
             # Truncate any longer sequences so that they can be fed to the model
             if len(example_token_ids) > max_sequence_length_model:
                 example_token_ids = example_token_ids[:max_sequence_length_model]

@@ -710,7 +708,6 @@ class LanguageModelFeaturizer(DenseFeaturizer, GraphComponent):
         batch_size = 64

         for attribute in DENSE_FEATURIZABLE_ATTRIBUTES:
-
             non_empty_examples = list(
                 filter(lambda x: x.get(attribute), training_data.training_examples)
             )

@@ -718,7 +715,6 @@ class LanguageModelFeaturizer(DenseFeaturizer, GraphComponent):
             batch_start_index = 0

             while batch_start_index < len(non_empty_examples):
-
                 batch_end_index = min(
                     batch_start_index + batch_size, len(non_empty_examples)
                 )
rasa/nlu/featurizers/featurizer.py

@@ -64,7 +64,7 @@ class Featurizer(Generic[FeatureType], ABC):

     @staticmethod
     def raise_if_featurizer_configs_are_not_compatible(
-        featurizer_configs: Iterable[Dict[Text, Any]]
+        featurizer_configs: Iterable[Dict[Text, Any]],
     ) -> None:
         """Validates that the given configurations of featurizers can be used together.
