rasa-pro 3.8.18__py3-none-any.whl → 3.9.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +6 -42
- rasa/__main__.py +14 -9
- rasa/anonymization/anonymization_pipeline.py +0 -1
- rasa/anonymization/anonymization_rule_executor.py +3 -3
- rasa/anonymization/utils.py +4 -3
- rasa/api.py +2 -2
- rasa/cli/arguments/default_arguments.py +1 -1
- rasa/cli/arguments/run.py +2 -2
- rasa/cli/arguments/test.py +1 -1
- rasa/cli/arguments/train.py +10 -10
- rasa/cli/e2e_test.py +27 -7
- rasa/cli/export.py +0 -1
- rasa/cli/license.py +3 -3
- rasa/cli/project_templates/calm/actions/action_template.py +1 -1
- rasa/cli/project_templates/calm/config.yml +1 -1
- rasa/cli/project_templates/calm/credentials.yml +1 -1
- rasa/cli/project_templates/calm/data/flows/add_contact.yml +1 -1
- rasa/cli/project_templates/calm/data/flows/remove_contact.yml +1 -1
- rasa/cli/project_templates/calm/domain/add_contact.yml +8 -2
- rasa/cli/project_templates/calm/domain/list_contacts.yml +3 -0
- rasa/cli/project_templates/calm/domain/remove_contact.yml +9 -2
- rasa/cli/project_templates/calm/domain/shared.yml +5 -0
- rasa/cli/project_templates/calm/endpoints.yml +4 -4
- rasa/cli/project_templates/default/actions/actions.py +1 -1
- rasa/cli/project_templates/default/config.yml +5 -5
- rasa/cli/project_templates/default/credentials.yml +1 -1
- rasa/cli/project_templates/default/endpoints.yml +4 -4
- rasa/cli/project_templates/default/tests/test_stories.yml +1 -1
- rasa/cli/project_templates/tutorial/config.yml +1 -1
- rasa/cli/project_templates/tutorial/credentials.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +6 -0
- rasa/cli/project_templates/tutorial/domain.yml +4 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +6 -6
- rasa/cli/run.py +0 -1
- rasa/cli/scaffold.py +3 -2
- rasa/cli/studio/download.py +11 -0
- rasa/cli/studio/studio.py +180 -24
- rasa/cli/studio/upload.py +0 -8
- rasa/cli/telemetry.py +18 -6
- rasa/cli/utils.py +21 -10
- rasa/cli/x.py +3 -2
- rasa/constants.py +1 -1
- rasa/core/actions/action.py +90 -315
- rasa/core/actions/action_exceptions.py +24 -0
- rasa/core/actions/constants.py +3 -0
- rasa/core/actions/custom_action_executor.py +188 -0
- rasa/core/actions/forms.py +11 -7
- rasa/core/actions/grpc_custom_action_executor.py +251 -0
- rasa/core/actions/http_custom_action_executor.py +140 -0
- rasa/core/actions/loops.py +3 -0
- rasa/core/actions/two_stage_fallback.py +1 -1
- rasa/core/agent.py +2 -4
- rasa/core/brokers/pika.py +1 -2
- rasa/core/channels/audiocodes.py +1 -1
- rasa/core/channels/botframework.py +0 -1
- rasa/core/channels/callback.py +0 -1
- rasa/core/channels/console.py +6 -8
- rasa/core/channels/development_inspector.py +1 -1
- rasa/core/channels/facebook.py +0 -3
- rasa/core/channels/hangouts.py +0 -6
- rasa/core/channels/inspector/dist/assets/{arc-5623b6dc.js → arc-b6e548fe.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-685c106a.js → c4Diagram-d0fbc5ce-fa03ac9e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-8cbed007.js → classDiagram-936ed81e-ee67392a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-5889cf12.js → classDiagram-v2-c3cb15f1-9b283fae.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-24c249d7.js → createText-62fc7601-8b6fcc2a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-7dd06a75.js → edges-f2ad444c-22e77f4f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-62c1e54c.js → erDiagram-9d236eb7-60ffc87f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-ce49b86f.js → flowDb-1972c806-9dd802e4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-4067e48f.js → flowDiagram-7ea5b25a-5fa1912f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-59fe4051.js → flowchart-elk-definition-abe16c3d-622a1fd2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-47e3a43b.js → ganttDiagram-9b5ea136-e285a63a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-5a2ac0d9.js → gitGraphDiagram-99d0ae7c-f237bdca.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-dfb8efc4.js → index-2c4b9a3b-4b03d70e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-268a75c0.js → index-a5d3e69d.js} +4 -4
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-b0c470f2.js → infoDiagram-736b4530-72a0fa5f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-2edb829a.js → journeyDiagram-df861f2b-82218c41.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-b6873d69.js → layout-78cff630.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-1efc5781.js → line-5038b469.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-661e9b94.js → linear-c4fc4098.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-2d2e727f.js → mindmap-definition-beec6740-c33c8ea6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-9d3ea93d.js → pieDiagram-dbbf0591-a8d03059.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-06a178a2.js → quadrantDiagram-4d7f4fd6-6a0e56b2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-0bfedffc.js → requirementDiagram-6fc4c22a-2dc7c7bd.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-d76d0a04.js → sankeyDiagram-8f13d901-2360fe39.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-37bb4341.js → sequenceDiagram-b655622a-41b9f9ad.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-f52f7f57.js → stateDiagram-59f0c015-0aad326f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-4a986a20.js → stateDiagram-v2-2b26beab-9847d984.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-7dd9ae12.js → styles-080da4f6-564d890e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-46e1ca14.js → styles-3dcbcfbf-38957613.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-4a97439a.js → styles-9c745c82-f0fc6921.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-823917a3.js → svgDrawCommon-4835440b-ef3c5a77.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-9ea72896.js → timeline-definition-5b62e21b-bf3e91c1.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-b631a8b6.js → xychartDiagram-2b33534f-4d4026c0.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +10 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +4 -7
- rasa/core/channels/inspector/src/helpers/formatters.ts +3 -2
- rasa/core/channels/rest.py +36 -21
- rasa/core/channels/rocketchat.py +0 -1
- rasa/core/channels/socketio.py +1 -1
- rasa/core/channels/telegram.py +3 -3
- rasa/core/channels/webexteams.py +0 -1
- rasa/core/concurrent_lock_store.py +1 -1
- rasa/core/evaluation/marker_base.py +1 -3
- rasa/core/evaluation/marker_stats.py +1 -2
- rasa/core/featurizers/single_state_featurizer.py +3 -26
- rasa/core/featurizers/tracker_featurizers.py +18 -122
- rasa/core/information_retrieval/__init__.py +7 -0
- rasa/core/information_retrieval/faiss.py +9 -4
- rasa/core/information_retrieval/information_retrieval.py +64 -7
- rasa/core/information_retrieval/milvus.py +7 -14
- rasa/core/information_retrieval/qdrant.py +8 -15
- rasa/core/lock_store.py +0 -1
- rasa/core/migrate.py +1 -2
- rasa/core/nlg/callback.py +3 -4
- rasa/core/policies/enterprise_search_policy.py +86 -22
- rasa/core/policies/enterprise_search_prompt_template.jinja2 +4 -41
- rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
- rasa/core/policies/flows/flow_executor.py +104 -2
- rasa/core/policies/intentless_policy.py +7 -9
- rasa/core/policies/memoization.py +3 -3
- rasa/core/policies/policy.py +18 -9
- rasa/core/policies/rule_policy.py +8 -11
- rasa/core/policies/ted_policy.py +61 -88
- rasa/core/policies/unexpected_intent_policy.py +8 -17
- rasa/core/processor.py +136 -47
- rasa/core/run.py +41 -25
- rasa/core/secrets_manager/endpoints.py +2 -2
- rasa/core/secrets_manager/vault.py +6 -8
- rasa/core/test.py +3 -5
- rasa/core/tracker_store.py +49 -14
- rasa/core/train.py +1 -3
- rasa/core/training/interactive.py +9 -6
- rasa/core/utils.py +5 -10
- rasa/dialogue_understanding/coexistence/intent_based_router.py +11 -4
- rasa/dialogue_understanding/coexistence/llm_based_router.py +2 -3
- rasa/dialogue_understanding/commands/__init__.py +4 -0
- rasa/dialogue_understanding/commands/can_not_handle_command.py +9 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +9 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +38 -0
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/clarify_command.py +9 -0
- rasa/dialogue_understanding/commands/correct_slots_command.py +9 -0
- rasa/dialogue_understanding/commands/error_command.py +12 -0
- rasa/dialogue_understanding/commands/handle_code_change_command.py +9 -0
- rasa/dialogue_understanding/commands/human_handoff_command.py +9 -0
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/noop_command.py +9 -0
- rasa/dialogue_understanding/commands/set_slot_command.py +34 -3
- rasa/dialogue_understanding/commands/skip_question_command.py +9 -0
- rasa/dialogue_understanding/commands/start_flow_command.py +9 -0
- rasa/dialogue_understanding/generator/__init__.py +16 -1
- rasa/dialogue_understanding/generator/command_generator.py +92 -6
- rasa/dialogue_understanding/generator/constants.py +18 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +7 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +467 -0
- rasa/dialogue_understanding/generator/llm_command_generator.py +39 -609
- rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
- rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +827 -0
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +69 -8
- rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +345 -0
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +44 -39
- rasa/dialogue_understanding/processor/command_processor.py +111 -3
- rasa/e2e_test/constants.py +1 -0
- rasa/e2e_test/e2e_test_case.py +44 -0
- rasa/e2e_test/e2e_test_runner.py +114 -11
- rasa/e2e_test/e2e_test_schema.yml +18 -0
- rasa/engine/caching.py +0 -1
- rasa/engine/graph.py +18 -6
- rasa/engine/recipes/config_files/default_config.yml +3 -3
- rasa/engine/recipes/default_components.py +1 -1
- rasa/engine/recipes/default_recipe.py +4 -5
- rasa/engine/recipes/recipe.py +1 -1
- rasa/engine/runner/dask.py +3 -9
- rasa/engine/storage/local_model_storage.py +0 -2
- rasa/engine/validation.py +179 -145
- rasa/exceptions.py +2 -2
- rasa/graph_components/validators/default_recipe_validator.py +3 -5
- rasa/hooks.py +0 -1
- rasa/model.py +1 -1
- rasa/model_training.py +1 -0
- rasa/nlu/classifiers/diet_classifier.py +33 -52
- rasa/nlu/classifiers/logistic_regression_classifier.py +9 -22
- rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
- rasa/nlu/extractors/crf_entity_extractor.py +54 -97
- rasa/nlu/extractors/duckling_entity_extractor.py +1 -1
- rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +1 -5
- rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +0 -4
- rasa/nlu/featurizers/featurizer.py +1 -1
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +18 -49
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +26 -64
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
- rasa/nlu/persistor.py +68 -26
- rasa/nlu/selectors/response_selector.py +7 -10
- rasa/nlu/test.py +0 -3
- rasa/nlu/utils/hugging_face/registry.py +1 -1
- rasa/nlu/utils/spacy_utils.py +1 -3
- rasa/server.py +22 -7
- rasa/shared/constants.py +12 -1
- rasa/shared/core/command_payload_reader.py +109 -0
- rasa/shared/core/constants.py +4 -5
- rasa/shared/core/domain.py +57 -56
- rasa/shared/core/events.py +4 -7
- rasa/shared/core/flows/flow.py +9 -0
- rasa/shared/core/flows/flows_list.py +12 -0
- rasa/shared/core/flows/steps/action.py +7 -2
- rasa/shared/core/generator.py +12 -11
- rasa/shared/core/slot_mappings.py +315 -24
- rasa/shared/core/slots.py +4 -2
- rasa/shared/core/trackers.py +32 -14
- rasa/shared/core/training_data/loading.py +0 -1
- rasa/shared/core/training_data/story_reader/story_reader.py +3 -3
- rasa/shared/core/training_data/story_reader/yaml_story_reader.py +11 -11
- rasa/shared/core/training_data/story_writer/yaml_story_writer.py +5 -3
- rasa/shared/core/training_data/structures.py +1 -1
- rasa/shared/core/training_data/visualization.py +1 -1
- rasa/shared/data.py +58 -1
- rasa/shared/exceptions.py +36 -2
- rasa/shared/importers/importer.py +1 -2
- rasa/shared/importers/rasa.py +0 -1
- rasa/shared/nlu/constants.py +2 -0
- rasa/shared/nlu/training_data/entities_parser.py +1 -2
- rasa/shared/nlu/training_data/features.py +2 -120
- rasa/shared/nlu/training_data/formats/dialogflow.py +3 -2
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +3 -5
- rasa/shared/nlu/training_data/formats/readerwriter.py +0 -1
- rasa/shared/nlu/training_data/message.py +13 -0
- rasa/shared/nlu/training_data/training_data.py +0 -2
- rasa/shared/providers/openai/session_handler.py +2 -2
- rasa/shared/utils/constants.py +3 -0
- rasa/shared/utils/io.py +11 -1
- rasa/shared/utils/llm.py +1 -2
- rasa/shared/utils/pykwalify_extensions.py +1 -0
- rasa/shared/utils/schemas/domain.yml +3 -0
- rasa/shared/utils/yaml.py +44 -35
- rasa/studio/auth.py +26 -10
- rasa/studio/constants.py +2 -0
- rasa/studio/data_handler.py +114 -107
- rasa/studio/download.py +160 -27
- rasa/studio/results_logger.py +137 -0
- rasa/studio/train.py +6 -7
- rasa/studio/upload.py +159 -134
- rasa/telemetry.py +188 -34
- rasa/tracing/config.py +18 -3
- rasa/tracing/constants.py +26 -2
- rasa/tracing/instrumentation/attribute_extractors.py +50 -41
- rasa/tracing/instrumentation/instrumentation.py +290 -44
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +7 -5
- rasa/tracing/instrumentation/metrics.py +109 -21
- rasa/tracing/metric_instrument_provider.py +83 -3
- rasa/utils/cli.py +2 -1
- rasa/utils/common.py +1 -1
- rasa/utils/endpoints.py +1 -2
- rasa/utils/io.py +72 -6
- rasa/utils/licensing.py +246 -31
- rasa/utils/ml_utils.py +1 -1
- rasa/utils/tensorflow/data_generator.py +1 -1
- rasa/utils/tensorflow/environment.py +1 -1
- rasa/utils/tensorflow/model_data.py +201 -12
- rasa/utils/tensorflow/model_data_utils.py +499 -500
- rasa/utils/tensorflow/models.py +5 -6
- rasa/utils/tensorflow/rasa_layers.py +15 -15
- rasa/utils/train_utils.py +1 -1
- rasa/utils/url_tools.py +53 -0
- rasa/validator.py +305 -3
- rasa/version.py +1 -1
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.14.dist-info}/METADATA +25 -61
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.14.dist-info}/RECORD +276 -259
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-85583a23.js +0 -1
- rasa/utils/tensorflow/feature_array.py +0 -370
- /rasa/dialogue_understanding/generator/{command_prompt_template.jinja2 → single_step/command_prompt_template.jinja2} +0 -0
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.14.dist-info}/NOTICE +0 -0
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.14.dist-info}/WHEEL +0 -0
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.14.dist-info}/entry_points.txt +0 -0
rasa/engine/validation.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import dataclasses
|
|
2
2
|
import inspect
|
|
3
|
+
import re
|
|
3
4
|
import logging
|
|
4
5
|
import sys
|
|
5
6
|
import typing
|
|
@@ -30,6 +31,9 @@ from rasa.dialogue_understanding.coexistence.constants import (
|
|
|
30
31
|
STICKY,
|
|
31
32
|
NON_STICKY,
|
|
32
33
|
)
|
|
34
|
+
from rasa.dialogue_understanding.generator import (
|
|
35
|
+
LLMBasedCommandGenerator,
|
|
36
|
+
)
|
|
33
37
|
from rasa.dialogue_understanding.patterns.chitchat import FLOW_PATTERN_CHITCHAT
|
|
34
38
|
from rasa.engine.constants import RESERVED_PLACEHOLDERS
|
|
35
39
|
from rasa.engine.exceptions import GraphSchemaValidationException
|
|
@@ -51,6 +55,10 @@ from rasa.shared.core.slots import Slot
|
|
|
51
55
|
from rasa.shared.exceptions import RasaException
|
|
52
56
|
from rasa.shared.nlu.training_data.message import Message
|
|
53
57
|
|
|
58
|
+
from rasa.dialogue_understanding.coexistence.intent_based_router import (
|
|
59
|
+
IntentBasedRouter,
|
|
60
|
+
)
|
|
61
|
+
from rasa.dialogue_understanding.coexistence.llm_based_router import LLMBasedRouter
|
|
54
62
|
|
|
55
63
|
TypeAnnotation = Union[TypeVar, Text, Type, Optional[AvailableEndpoints]]
|
|
56
64
|
|
|
@@ -209,7 +217,7 @@ def _validate_interface_usage(node: SchemaNode) -> None:
|
|
|
209
217
|
raise GraphSchemaValidationException(
|
|
210
218
|
f"Your model uses a component with class '{node.uses.__name__}'. "
|
|
211
219
|
f"This class does not implement the '{GraphComponent.__name__}' interface "
|
|
212
|
-
f"and can hence not be run within Rasa
|
|
220
|
+
f"and can hence not be run within Rasa Pro. Please use a different "
|
|
213
221
|
f"component or implement the '{GraphComponent}' interface in class "
|
|
214
222
|
f"'{node.uses.__name__}'. "
|
|
215
223
|
f"See {DOCS_URL_GRAPH_COMPONENTS} for more information."
|
|
@@ -503,7 +511,6 @@ def _validate_parent_return_type(
|
|
|
503
511
|
parent_return_type: TypeAnnotation,
|
|
504
512
|
required_type: TypeAnnotation,
|
|
505
513
|
) -> None:
|
|
506
|
-
|
|
507
514
|
if not typing_utils.issubtype(parent_return_type, required_type):
|
|
508
515
|
parent_node_text = ""
|
|
509
516
|
if parent_node:
|
|
@@ -606,7 +613,6 @@ def _recursively_check_required_components(
|
|
|
606
613
|
def validate_flow_component_dependencies(
|
|
607
614
|
flows: FlowsList, model_configuration: GraphModelConfiguration
|
|
608
615
|
) -> None:
|
|
609
|
-
|
|
610
616
|
if (pattern_chitchat := flows.flow_by_id(FLOW_PATTERN_CHITCHAT)) is not None:
|
|
611
617
|
_validate_chitchat_dependencies(pattern_chitchat, model_configuration)
|
|
612
618
|
|
|
@@ -637,166 +643,169 @@ def _validate_chitchat_dependencies(
|
|
|
637
643
|
)
|
|
638
644
|
|
|
639
645
|
|
|
640
|
-
def
|
|
641
|
-
|
|
646
|
+
def get_component_index(schema: GraphSchema, component_class: Type) -> Optional[int]:
|
|
647
|
+
"""Extracts the index of a component of the given class in the schema.
|
|
648
|
+
This function assumes that each component's node name is stored in a way
|
|
649
|
+
that includes the index as part of the name, formatted as
|
|
650
|
+
"run_{ComponentName}{Index}", which is how it's created by the recipe.
|
|
651
|
+
"""
|
|
652
|
+
# the index of the component is at the end of the node name
|
|
653
|
+
pattern = re.compile(r"\d+$")
|
|
654
|
+
for node_name, node in schema.nodes.items():
|
|
655
|
+
if issubclass(node.uses, component_class):
|
|
656
|
+
match = pattern.search(node_name)
|
|
657
|
+
if match:
|
|
658
|
+
index = int(match.group())
|
|
659
|
+
return index
|
|
660
|
+
# index is not found or there is no component with the given class
|
|
661
|
+
return None
|
|
662
|
+
|
|
663
|
+
|
|
664
|
+
def get_component_config(
|
|
665
|
+
schema: GraphSchema, component_class: Type
|
|
666
|
+
) -> Optional[Dict[str, Any]]:
|
|
667
|
+
"""Extracts the config of a component of the given class in the schema."""
|
|
668
|
+
for node_name, node in schema.nodes.items():
|
|
669
|
+
if issubclass(node.uses, component_class):
|
|
670
|
+
return node.config
|
|
671
|
+
return None
|
|
672
|
+
|
|
673
|
+
|
|
674
|
+
def validate_router_exclusivity(schema: GraphSchema) -> None:
|
|
675
|
+
"""Validate that intent-based and llm-based routers are not
|
|
676
|
+
defined at the same time.
|
|
677
|
+
"""
|
|
678
|
+
if schema.has_node(IntentBasedRouter) and schema.has_node(LLMBasedRouter):
|
|
679
|
+
structlogger.error(
|
|
680
|
+
"validation.coexistance.both_routers_defined",
|
|
681
|
+
event_info=(
|
|
682
|
+
"Both LLMBasedRouter and IntentBasedRouter are in the config. "
|
|
683
|
+
"Please use only one of them."
|
|
684
|
+
),
|
|
685
|
+
)
|
|
686
|
+
sys.exit(1)
|
|
687
|
+
|
|
688
|
+
|
|
689
|
+
def validate_intent_based_router_position(schema: GraphSchema) -> None:
|
|
690
|
+
"""Validate that if intent-based router is defined, it is positioned before
|
|
691
|
+
the llm command generator.
|
|
692
|
+
"""
|
|
693
|
+
intent_based_router_pos = get_component_index(schema, IntentBasedRouter)
|
|
694
|
+
llm_command_generator_pos = get_component_index(schema, LLMBasedCommandGenerator)
|
|
695
|
+
if (
|
|
696
|
+
intent_based_router_pos is not None
|
|
697
|
+
and llm_command_generator_pos is not None
|
|
698
|
+
and intent_based_router_pos > llm_command_generator_pos
|
|
699
|
+
):
|
|
700
|
+
structlogger.error(
|
|
701
|
+
"validation.coexistance.wrong_order_of_components",
|
|
702
|
+
event_info=(
|
|
703
|
+
"IntentBasedRouter should come before "
|
|
704
|
+
"a LLMBasedCommandGenerator in the pipeline."
|
|
705
|
+
),
|
|
706
|
+
)
|
|
707
|
+
sys.exit(1)
|
|
708
|
+
|
|
709
|
+
|
|
710
|
+
def validate_that_slots_are_defined_if_router_is_defined(
|
|
711
|
+
schema: GraphSchema, routing_slots: List[Slot]
|
|
642
712
|
) -> None:
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
from rasa.dialogue_understanding.generator import LLMCommandGenerator
|
|
649
|
-
|
|
650
|
-
def get_component_index(
|
|
651
|
-
schema: GraphSchema, component_class: Type
|
|
652
|
-
) -> Optional[int]:
|
|
653
|
-
"""Extracts the index of a component of the given class in the schema.
|
|
654
|
-
This function assumes that each component's node name is stored in a way
|
|
655
|
-
that includes the index as part of the name, formatted as
|
|
656
|
-
"run_{ComponentName}{Index}", which is how it's created by the recipe.
|
|
657
|
-
"""
|
|
658
|
-
# the index of the component is at the end of the node name
|
|
659
|
-
pattern = re.compile(r"\d+$")
|
|
660
|
-
for node_name, node in schema.nodes.items():
|
|
661
|
-
if issubclass(node.uses, component_class):
|
|
662
|
-
match = pattern.search(node_name)
|
|
663
|
-
if match:
|
|
664
|
-
index = int(match.group())
|
|
665
|
-
return index
|
|
666
|
-
# index is not found or there is no component with the given class
|
|
667
|
-
return None
|
|
668
|
-
|
|
669
|
-
def get_component_config(
|
|
670
|
-
schema: GraphSchema, component_class: Type
|
|
671
|
-
) -> Optional[Dict[str, Any]]:
|
|
672
|
-
"""Extracts the config of a component of the given class in the schema."""
|
|
673
|
-
for node_name, node in schema.nodes.items():
|
|
674
|
-
if issubclass(node.uses, component_class):
|
|
675
|
-
return node.config
|
|
676
|
-
return None
|
|
677
|
-
|
|
678
|
-
def validate_router_exclusivity(schema: GraphSchema) -> None:
|
|
679
|
-
"""Validate that intent-based and llm-based routers are not
|
|
680
|
-
defined at the same time.
|
|
681
|
-
"""
|
|
682
|
-
if schema.has_node(IntentBasedRouter) and schema.has_node(LLMBasedRouter):
|
|
713
|
+
# check whether intent-based or llm-based type of router is present
|
|
714
|
+
for router_type in [IntentBasedRouter, LLMBasedRouter]:
|
|
715
|
+
router_present = schema.has_node(router_type)
|
|
716
|
+
slot_has_issue = len(routing_slots) == 0 or routing_slots[0].type_name != "bool"
|
|
717
|
+
if router_present and slot_has_issue:
|
|
683
718
|
structlogger.error(
|
|
684
|
-
"validation.coexistance.
|
|
719
|
+
f"validation.coexistance.{ROUTE_TO_CALM_SLOT}_not_in_domain",
|
|
685
720
|
event_info=(
|
|
686
|
-
"
|
|
687
|
-
"
|
|
721
|
+
f"{router_type.__name__} is in the config, but the slot "
|
|
722
|
+
f"{ROUTE_TO_CALM_SLOT} is not in the domain or not of "
|
|
723
|
+
f"type bool."
|
|
688
724
|
),
|
|
689
725
|
)
|
|
690
726
|
sys.exit(1)
|
|
691
727
|
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
728
|
+
|
|
729
|
+
def validate_that_router_is_defined_if_router_slots_are_in_domain(
|
|
730
|
+
schema: GraphSchema,
|
|
731
|
+
routing_slots: List[Slot],
|
|
732
|
+
) -> None:
|
|
733
|
+
defined_router_slots = len(routing_slots) > 0
|
|
734
|
+
router_present = schema.has_node(IntentBasedRouter) or schema.has_node(
|
|
735
|
+
LLMBasedRouter
|
|
736
|
+
)
|
|
737
|
+
if defined_router_slots and (
|
|
738
|
+
not router_present or routing_slots[0].type_name != "bool"
|
|
739
|
+
):
|
|
740
|
+
structlogger.error(
|
|
741
|
+
f"validation.coexistance"
|
|
742
|
+
f".{ROUTE_TO_CALM_SLOT}_in_domain_with_no_router_defined",
|
|
743
|
+
event_info=(
|
|
744
|
+
f"The slot {ROUTE_TO_CALM_SLOT} is in the domain but the "
|
|
745
|
+
f"LLMBasedRouter or the IntentBasedRouter is not in the config or "
|
|
746
|
+
f"the type of the slot is not bool."
|
|
747
|
+
),
|
|
748
|
+
)
|
|
749
|
+
sys.exit(1)
|
|
750
|
+
|
|
751
|
+
|
|
752
|
+
def valid_nlu_entry_config(config: Optional[Dict[str, Any]]) -> bool:
|
|
753
|
+
return (
|
|
754
|
+
config is not None
|
|
755
|
+
and NLU_ENTRY in config
|
|
756
|
+
and isinstance(config[NLU_ENTRY], dict)
|
|
757
|
+
and STICKY in config[NLU_ENTRY]
|
|
758
|
+
and NON_STICKY in config[NLU_ENTRY]
|
|
759
|
+
)
|
|
760
|
+
|
|
761
|
+
|
|
762
|
+
def valid_calm_entry_config(config: Optional[Dict[str, Any]]) -> bool:
|
|
763
|
+
return (
|
|
764
|
+
config is not None
|
|
765
|
+
and CALM_ENTRY in config
|
|
766
|
+
and isinstance(config[CALM_ENTRY], dict)
|
|
767
|
+
and STICKY in config[CALM_ENTRY]
|
|
768
|
+
)
|
|
769
|
+
|
|
770
|
+
|
|
771
|
+
def validate_configuration(
|
|
772
|
+
schema: GraphSchema,
|
|
773
|
+
) -> None:
|
|
774
|
+
"""Validate the configuration of the existing coexistence routers."""
|
|
775
|
+
if schema.has_node(IntentBasedRouter, include_subtypes=False):
|
|
776
|
+
config = get_component_config(schema, IntentBasedRouter)
|
|
777
|
+
if not valid_calm_entry_config(config) or not valid_nlu_entry_config(config):
|
|
703
778
|
structlogger.error(
|
|
704
|
-
"validation.coexistance.
|
|
779
|
+
"validation.coexistance.invalid_configuration",
|
|
705
780
|
event_info=(
|
|
706
|
-
"
|
|
707
|
-
"
|
|
781
|
+
"The configuration of the IntentBasedRouter is invalid. "
|
|
782
|
+
"Please check the documentation.",
|
|
708
783
|
),
|
|
709
784
|
)
|
|
710
785
|
sys.exit(1)
|
|
711
786
|
|
|
712
|
-
|
|
713
|
-
schema
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
slot_has_issue = (
|
|
719
|
-
len(routing_slots) == 0 or routing_slots[0].type_name != "bool"
|
|
720
|
-
)
|
|
721
|
-
if router_present and slot_has_issue:
|
|
722
|
-
structlogger.error(
|
|
723
|
-
f"validation.coexistance.{ROUTE_TO_CALM_SLOT}_not_in_domain",
|
|
724
|
-
event_info=(
|
|
725
|
-
f"{router_type.__name__} is in the config, but the slot "
|
|
726
|
-
f"{ROUTE_TO_CALM_SLOT} is not in the domain or not of "
|
|
727
|
-
f"type bool."
|
|
728
|
-
),
|
|
729
|
-
)
|
|
730
|
-
sys.exit(1)
|
|
731
|
-
|
|
732
|
-
def validate_that_router_is_defined_if_router_slots_are_in_domain(
|
|
733
|
-
schema: GraphSchema,
|
|
734
|
-
routing_slots: List[Slot],
|
|
735
|
-
) -> None:
|
|
736
|
-
defined_router_slots = len(routing_slots) > 0
|
|
737
|
-
router_present = schema.has_node(IntentBasedRouter) or schema.has_node(
|
|
738
|
-
LLMBasedRouter
|
|
739
|
-
)
|
|
740
|
-
if defined_router_slots and (
|
|
741
|
-
not router_present or routing_slots[0].type_name != "bool"
|
|
787
|
+
if schema.has_node(LLMBasedRouter, include_subtypes=False):
|
|
788
|
+
config = get_component_config(schema, LLMBasedRouter)
|
|
789
|
+
if not valid_calm_entry_config(config) or (
|
|
790
|
+
config is not None
|
|
791
|
+
and NLU_ENTRY in config
|
|
792
|
+
and not valid_nlu_entry_config(config)
|
|
742
793
|
):
|
|
743
794
|
structlogger.error(
|
|
744
|
-
|
|
745
|
-
f".{ROUTE_TO_CALM_SLOT}_in_domain_with_no_router_defined",
|
|
795
|
+
"validation.coexistance.invalid_configuration",
|
|
746
796
|
event_info=(
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
f"the type of the slot is not bool."
|
|
797
|
+
"The configuration of the LLMBasedRouter is invalid. "
|
|
798
|
+
"Please check the documentation.",
|
|
750
799
|
),
|
|
751
800
|
)
|
|
752
801
|
sys.exit(1)
|
|
753
802
|
|
|
754
|
-
def valid_nlu_entry_config(config: Optional[Dict[str, Any]]) -> bool:
|
|
755
|
-
return (
|
|
756
|
-
config is not None
|
|
757
|
-
and NLU_ENTRY in config
|
|
758
|
-
and isinstance(config[NLU_ENTRY], dict)
|
|
759
|
-
and STICKY in config[NLU_ENTRY]
|
|
760
|
-
and NON_STICKY in config[NLU_ENTRY]
|
|
761
|
-
)
|
|
762
803
|
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
and STICKY in config[CALM_ENTRY]
|
|
769
|
-
)
|
|
770
|
-
|
|
771
|
-
def validate_configuration(
|
|
772
|
-
schema: GraphSchema,
|
|
773
|
-
) -> None:
|
|
774
|
-
"""Validate the configuration of the existing coexistence routers."""
|
|
775
|
-
if schema.has_node(IntentBasedRouter, include_subtypes=False):
|
|
776
|
-
config = get_component_config(schema, IntentBasedRouter)
|
|
777
|
-
if not valid_calm_entry_config(config) or not valid_nlu_entry_config(
|
|
778
|
-
config
|
|
779
|
-
):
|
|
780
|
-
structlogger.error(
|
|
781
|
-
"validation.coexistance.invalid_configuration",
|
|
782
|
-
event_info=(
|
|
783
|
-
"The configuration of the IntentBasedRouter is invalid. "
|
|
784
|
-
"Please check the documentation.",
|
|
785
|
-
),
|
|
786
|
-
)
|
|
787
|
-
sys.exit(1)
|
|
788
|
-
|
|
789
|
-
if schema.has_node(LLMBasedRouter, include_subtypes=False):
|
|
790
|
-
config = get_component_config(schema, LLMBasedRouter)
|
|
791
|
-
if not valid_calm_entry_config(config):
|
|
792
|
-
structlogger.error(
|
|
793
|
-
"validation.coexistance.invalid_configuration",
|
|
794
|
-
event_info=(
|
|
795
|
-
"The configuration of the LLMBasedRouter is invalid. "
|
|
796
|
-
"Please check the documentation.",
|
|
797
|
-
),
|
|
798
|
-
)
|
|
799
|
-
sys.exit(1)
|
|
804
|
+
def validate_coexistance_routing_setup(
|
|
805
|
+
domain: Domain, model_configuration: GraphModelConfiguration, flows: FlowsList
|
|
806
|
+
) -> None:
|
|
807
|
+
schema = model_configuration.predict_schema
|
|
808
|
+
routing_slots = [s for s in domain.slots if s.name == ROUTE_TO_CALM_SLOT]
|
|
800
809
|
|
|
801
810
|
def validate_that_router_or_router_slot_are_defined_if_action_reset_routing_is_used(
|
|
802
811
|
schema: GraphSchema, flows: FlowsList, routing_slots: List[Slot]
|
|
@@ -826,9 +835,6 @@ def validate_coexistance_routing_setup(
|
|
|
826
835
|
)
|
|
827
836
|
sys.exit(1)
|
|
828
837
|
|
|
829
|
-
schema = model_configuration.predict_schema
|
|
830
|
-
routing_slots = [s for s in domain.slots if s.name == ROUTE_TO_CALM_SLOT]
|
|
831
|
-
|
|
832
838
|
validate_router_exclusivity(schema)
|
|
833
839
|
validate_intent_based_router_position(schema)
|
|
834
840
|
validate_that_slots_are_defined_if_router_is_defined(schema, routing_slots)
|
|
@@ -837,3 +843,31 @@ def validate_coexistance_routing_setup(
|
|
|
837
843
|
validate_that_router_or_router_slot_are_defined_if_action_reset_routing_is_used(
|
|
838
844
|
schema, flows, routing_slots
|
|
839
845
|
)
|
|
846
|
+
|
|
847
|
+
|
|
848
|
+
def validate_command_generator_exclusivity(schema: GraphSchema) -> None:
|
|
849
|
+
"""Validate that multiple command generators are not defined at same time."""
|
|
850
|
+
from rasa.dialogue_understanding.generator import (
|
|
851
|
+
LLMBasedCommandGenerator,
|
|
852
|
+
)
|
|
853
|
+
|
|
854
|
+
count = schema.count_nodes_of_a_given_type(
|
|
855
|
+
LLMBasedCommandGenerator, include_subtypes=True
|
|
856
|
+
)
|
|
857
|
+
|
|
858
|
+
if count > 1:
|
|
859
|
+
structlogger.error(
|
|
860
|
+
"validation.command_generator.multiple_command_generator_defined",
|
|
861
|
+
event_info=(
|
|
862
|
+
"Multiple LLM based command generators are defined in the config. "
|
|
863
|
+
"Please use only one LLM based command generator."
|
|
864
|
+
),
|
|
865
|
+
)
|
|
866
|
+
sys.exit(1)
|
|
867
|
+
|
|
868
|
+
|
|
869
|
+
def validate_command_generator_setup(
|
|
870
|
+
model_configuration: GraphModelConfiguration,
|
|
871
|
+
) -> None:
|
|
872
|
+
schema = model_configuration.predict_schema
|
|
873
|
+
validate_command_generator_exclusivity(schema)
|
rasa/exceptions.py
CHANGED
|
@@ -20,9 +20,9 @@ class UnsupportedModelVersionError(RasaException):
|
|
|
20
20
|
def __str__(self) -> Text:
|
|
21
21
|
minimum_version = version.parse(MINIMUM_COMPATIBLE_VERSION)
|
|
22
22
|
return (
|
|
23
|
-
f"The model version is trained using Rasa
|
|
23
|
+
f"The model version is trained using Rasa Pro {self.model_version} "
|
|
24
24
|
f"and is not compatible with your current installation "
|
|
25
|
-
f"which supports models build with Rasa
|
|
25
|
+
f"which supports models build with Rasa Pro {minimum_version} "
|
|
26
26
|
f"or higher. "
|
|
27
27
|
f"This means that you either need to retrain your model "
|
|
28
28
|
f"or revert back to the Rasa version that trained the model "
|
|
@@ -203,7 +203,6 @@ class DefaultV1RecipeValidator(GraphComponent):
|
|
|
203
203
|
)
|
|
204
204
|
|
|
205
205
|
if training_data.lookup_tables:
|
|
206
|
-
|
|
207
206
|
if self._component_types.isdisjoint([CRFEntityExtractor, DIETClassifier]):
|
|
208
207
|
rasa.shared.utils.io.raise_warning(
|
|
209
208
|
f"You have defined training data consisting of lookup tables, but "
|
|
@@ -219,7 +218,6 @@ class DefaultV1RecipeValidator(GraphComponent):
|
|
|
219
218
|
)
|
|
220
219
|
|
|
221
220
|
elif CRFEntityExtractor in self._component_types:
|
|
222
|
-
|
|
223
221
|
crf_schema_nodes = [
|
|
224
222
|
schema_node
|
|
225
223
|
for schema_node in self._graph_schema.nodes.values()
|
|
@@ -295,9 +293,9 @@ class DefaultV1RecipeValidator(GraphComponent):
|
|
|
295
293
|
Both of these look for the same entities based on the same training data
|
|
296
294
|
leading to ambiguity in the results.
|
|
297
295
|
"""
|
|
298
|
-
extractors_in_configuration: Set[
|
|
299
|
-
|
|
300
|
-
|
|
296
|
+
extractors_in_configuration: Set[Type[GraphComponent]] = (
|
|
297
|
+
self._component_types.intersection(TRAINABLE_EXTRACTORS)
|
|
298
|
+
)
|
|
301
299
|
if len(extractors_in_configuration) > 1:
|
|
302
300
|
rasa.shared.utils.io.raise_warning(
|
|
303
301
|
f"You have defined multiple entity extractors that do the same job "
|
rasa/hooks.py
CHANGED
|
@@ -78,7 +78,6 @@ def create_tracker_store(
|
|
|
78
78
|
domain: "Domain",
|
|
79
79
|
event_broker: Optional["EventBroker"],
|
|
80
80
|
) -> "TrackerStore":
|
|
81
|
-
|
|
82
81
|
if isinstance(endpoint_config, EndpointConfig):
|
|
83
82
|
return AuthRetryTrackerStore(
|
|
84
83
|
endpoint_config=endpoint_config, domain=domain, event_broker=event_broker
|
rasa/model.py
CHANGED
|
@@ -74,7 +74,7 @@ def get_latest_model(model_path: Text = DEFAULT_MODELS_PATH) -> Optional[Text]:
|
|
|
74
74
|
|
|
75
75
|
|
|
76
76
|
def get_model_for_finetuning(
|
|
77
|
-
previous_model_file_or_dir: Union[Path, Text]
|
|
77
|
+
previous_model_file_or_dir: Union[Path, Text],
|
|
78
78
|
) -> Optional[Path]:
|
|
79
79
|
"""Gets validated path for model to finetune.
|
|
80
80
|
|
rasa/model_training.py
CHANGED
|
@@ -309,6 +309,7 @@ async def _train_graph(
|
|
|
309
309
|
rasa.engine.validation.validate_flow_component_dependencies(
|
|
310
310
|
flows, model_configuration
|
|
311
311
|
)
|
|
312
|
+
rasa.engine.validation.validate_command_generator_setup(model_configuration)
|
|
312
313
|
|
|
313
314
|
tempdir_name = rasa.utils.common.get_temp_dir_name()
|
|
314
315
|
# Use `TempDirectoryPath` instead of `tempfile.TemporaryDirectory` as this
|
|
@@ -1,17 +1,18 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
|
|
3
2
|
import copy
|
|
4
3
|
import logging
|
|
5
4
|
from collections import defaultdict
|
|
6
5
|
from pathlib import Path
|
|
7
|
-
|
|
6
|
+
|
|
7
|
+
from rasa.exceptions import ModelNotFound
|
|
8
|
+
from rasa.nlu.featurizers.featurizer import Featurizer
|
|
8
9
|
|
|
9
10
|
import numpy as np
|
|
10
11
|
import scipy.sparse
|
|
11
12
|
import tensorflow as tf
|
|
12
13
|
|
|
13
|
-
from
|
|
14
|
-
|
|
14
|
+
from typing import Any, Dict, List, Optional, Text, Tuple, Union, TypeVar, Type
|
|
15
|
+
|
|
15
16
|
from rasa.engine.graph import ExecutionContext, GraphComponent
|
|
16
17
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
17
18
|
from rasa.engine.storage.resource import Resource
|
|
@@ -19,21 +20,18 @@ from rasa.engine.storage.storage import ModelStorage
|
|
|
19
20
|
from rasa.nlu.extractors.extractor import EntityExtractorMixin
|
|
20
21
|
from rasa.nlu.classifiers.classifier import IntentClassifier
|
|
21
22
|
import rasa.shared.utils.io
|
|
23
|
+
import rasa.utils.io as io_utils
|
|
22
24
|
import rasa.nlu.utils.bilou_utils as bilou_utils
|
|
23
25
|
from rasa.shared.constants import DIAGNOSTIC_DATA
|
|
24
26
|
from rasa.nlu.extractors.extractor import EntityTagSpec
|
|
25
27
|
from rasa.nlu.classifiers import LABEL_RANKING_LENGTH
|
|
26
28
|
from rasa.utils import train_utils
|
|
27
29
|
from rasa.utils.tensorflow import rasa_layers
|
|
28
|
-
from rasa.utils.tensorflow.feature_array import (
|
|
29
|
-
FeatureArray,
|
|
30
|
-
serialize_nested_feature_arrays,
|
|
31
|
-
deserialize_nested_feature_arrays,
|
|
32
|
-
)
|
|
33
30
|
from rasa.utils.tensorflow.models import RasaModel, TransformerRasaModel
|
|
34
31
|
from rasa.utils.tensorflow.model_data import (
|
|
35
32
|
RasaModelData,
|
|
36
33
|
FeatureSignature,
|
|
34
|
+
FeatureArray,
|
|
37
35
|
)
|
|
38
36
|
from rasa.nlu.constants import TOKENS_NAMES, DEFAULT_TRANSFORMER_SIZE
|
|
39
37
|
from rasa.shared.nlu.constants import (
|
|
@@ -120,6 +118,7 @@ LABEL_SUB_KEY = IDS
|
|
|
120
118
|
|
|
121
119
|
POSSIBLE_TAGS = [ENTITY_ATTRIBUTE_TYPE, ENTITY_ATTRIBUTE_ROLE, ENTITY_ATTRIBUTE_GROUP]
|
|
122
120
|
|
|
121
|
+
|
|
123
122
|
DIETClassifierT = TypeVar("DIETClassifierT", bound="DIETClassifier")
|
|
124
123
|
|
|
125
124
|
|
|
@@ -511,7 +510,6 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
|
|
|
511
510
|
def _extract_features(
|
|
512
511
|
self, message: Message, attribute: Text
|
|
513
512
|
) -> Dict[Text, Union[scipy.sparse.spmatrix, np.ndarray]]:
|
|
514
|
-
|
|
515
513
|
(
|
|
516
514
|
sparse_sequence_features,
|
|
517
515
|
sparse_sentence_features,
|
|
@@ -781,7 +779,6 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
|
|
|
781
779
|
sparse_feature_sizes: Dict[Text, Dict[Text, List[int]]],
|
|
782
780
|
label_attribute: Optional[Text] = None,
|
|
783
781
|
) -> Dict[Text, Dict[Text, List[int]]]:
|
|
784
|
-
|
|
785
782
|
if label_attribute in sparse_feature_sizes:
|
|
786
783
|
del sparse_feature_sizes[label_attribute]
|
|
787
784
|
return sparse_feature_sizes
|
|
@@ -1086,24 +1083,18 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
|
|
|
1086
1083
|
|
|
1087
1084
|
self.model.save(str(tf_model_file))
|
|
1088
1085
|
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
self._data_example,
|
|
1092
|
-
model_path / f"{file_name}.data_example.st",
|
|
1093
|
-
model_path / f"{file_name}.data_example_metadata.json",
|
|
1086
|
+
io_utils.pickle_dump(
|
|
1087
|
+
model_path / f"{file_name}.data_example.pkl", self._data_example
|
|
1094
1088
|
)
|
|
1095
|
-
|
|
1096
|
-
|
|
1097
|
-
dict(self._label_data.data) if self._label_data is not None else {},
|
|
1098
|
-
model_path / f"{file_name}.label_data.st",
|
|
1099
|
-
model_path / f"{file_name}.label_data_metadata.json",
|
|
1100
|
-
)
|
|
1101
|
-
|
|
1102
|
-
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
1103
|
-
model_path / f"{file_name}.sparse_feature_sizes.json",
|
|
1089
|
+
io_utils.pickle_dump(
|
|
1090
|
+
model_path / f"{file_name}.sparse_feature_sizes.pkl",
|
|
1104
1091
|
self._sparse_feature_sizes,
|
|
1105
1092
|
)
|
|
1106
|
-
|
|
1093
|
+
io_utils.pickle_dump(
|
|
1094
|
+
model_path / f"{file_name}.label_data.pkl",
|
|
1095
|
+
dict(self._label_data.data) if self._label_data is not None else {},
|
|
1096
|
+
)
|
|
1097
|
+
io_utils.json_pickle(
|
|
1107
1098
|
model_path / f"{file_name}.index_label_id_mapping.json",
|
|
1108
1099
|
self.index_label_id_mapping,
|
|
1109
1100
|
)
|
|
@@ -1192,22 +1183,15 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
|
|
|
1192
1183
|
]:
|
|
1193
1184
|
file_name = cls.__name__
|
|
1194
1185
|
|
|
1195
|
-
|
|
1196
|
-
|
|
1197
|
-
str(model_path / f"{file_name}.data_example.st"),
|
|
1198
|
-
str(model_path / f"{file_name}.data_example_metadata.json"),
|
|
1186
|
+
data_example = io_utils.pickle_load(
|
|
1187
|
+
model_path / f"{file_name}.data_example.pkl"
|
|
1199
1188
|
)
|
|
1200
|
-
|
|
1201
|
-
|
|
1202
|
-
|
|
1203
|
-
|
|
1204
|
-
)
|
|
1205
|
-
label_data = RasaModelData(data=loaded_label_data)
|
|
1206
|
-
|
|
1207
|
-
sparse_feature_sizes = rasa.shared.utils.io.read_json_file(
|
|
1208
|
-
model_path / f"{file_name}.sparse_feature_sizes.json"
|
|
1189
|
+
label_data = io_utils.pickle_load(model_path / f"{file_name}.label_data.pkl")
|
|
1190
|
+
label_data = RasaModelData(data=label_data)
|
|
1191
|
+
sparse_feature_sizes = io_utils.pickle_load(
|
|
1192
|
+
model_path / f"{file_name}.sparse_feature_sizes.pkl"
|
|
1209
1193
|
)
|
|
1210
|
-
index_label_id_mapping =
|
|
1194
|
+
index_label_id_mapping = io_utils.json_unpickle(
|
|
1211
1195
|
model_path / f"{file_name}.index_label_id_mapping.json"
|
|
1212
1196
|
)
|
|
1213
1197
|
entity_tag_specs = rasa.shared.utils.io.read_json_file(
|
|
@@ -1227,6 +1211,7 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
|
|
|
1227
1211
|
for tag_spec in entity_tag_specs
|
|
1228
1212
|
]
|
|
1229
1213
|
|
|
1214
|
+
# jsonpickle converts dictionary keys to strings
|
|
1230
1215
|
index_label_id_mapping = {
|
|
1231
1216
|
int(key): value for key, value in index_label_id_mapping.items()
|
|
1232
1217
|
}
|
|
@@ -1280,7 +1265,6 @@ class DIETClassifier(GraphComponent, IntentClassifier, EntityExtractorMixin):
|
|
|
1280
1265
|
config: Dict[Text, Any],
|
|
1281
1266
|
finetune_mode: bool,
|
|
1282
1267
|
) -> "RasaModel":
|
|
1283
|
-
|
|
1284
1268
|
predict_data_example = RasaModelData(
|
|
1285
1269
|
label_key=model_data_example.label_key,
|
|
1286
1270
|
data={
|
|
@@ -1467,10 +1451,10 @@ class DIET(TransformerRasaModel):
|
|
|
1467
1451
|
# everything using a transformer and optionally also do masked language
|
|
1468
1452
|
# modeling.
|
|
1469
1453
|
self.text_name = TEXT
|
|
1470
|
-
self._tf_layers[
|
|
1471
|
-
|
|
1472
|
-
|
|
1473
|
-
|
|
1454
|
+
self._tf_layers[f"sequence_layer.{self.text_name}"] = (
|
|
1455
|
+
rasa_layers.RasaSequenceLayer(
|
|
1456
|
+
self.text_name, self.data_signature[self.text_name], self.config
|
|
1457
|
+
)
|
|
1474
1458
|
)
|
|
1475
1459
|
if self.config[MASKED_LM]:
|
|
1476
1460
|
self._prepare_mask_lm_loss(self.text_name)
|
|
@@ -1488,10 +1472,10 @@ class DIET(TransformerRasaModel):
|
|
|
1488
1472
|
{SPARSE_INPUT_DROPOUT: False, DENSE_INPUT_DROPOUT: False}
|
|
1489
1473
|
)
|
|
1490
1474
|
|
|
1491
|
-
self._tf_layers[
|
|
1492
|
-
|
|
1493
|
-
|
|
1494
|
-
|
|
1475
|
+
self._tf_layers[f"feature_combining_layer.{self.label_name}"] = (
|
|
1476
|
+
rasa_layers.RasaFeatureCombiningLayer(
|
|
1477
|
+
self.label_name, self.label_signature[self.label_name], label_config
|
|
1478
|
+
)
|
|
1495
1479
|
)
|
|
1496
1480
|
|
|
1497
1481
|
self._prepare_ffnn_layer(
|
|
@@ -1523,7 +1507,6 @@ class DIET(TransformerRasaModel):
|
|
|
1523
1507
|
sequence_feature_lengths: tf.Tensor,
|
|
1524
1508
|
name: Text,
|
|
1525
1509
|
) -> tf.Tensor:
|
|
1526
|
-
|
|
1527
1510
|
x, _ = self._tf_layers[f"feature_combining_layer.{name}"](
|
|
1528
1511
|
(sequence_features, sentence_features, sequence_feature_lengths),
|
|
1529
1512
|
training=self._training,
|
|
@@ -1705,7 +1688,6 @@ class DIET(TransformerRasaModel):
|
|
|
1705
1688
|
return loss
|
|
1706
1689
|
|
|
1707
1690
|
def _update_label_metrics(self, loss: tf.Tensor, acc: tf.Tensor) -> None:
|
|
1708
|
-
|
|
1709
1691
|
self.intent_loss.update_state(loss)
|
|
1710
1692
|
self.intent_acc.update_state(acc)
|
|
1711
1693
|
|
|
@@ -1864,7 +1846,6 @@ class DIET(TransformerRasaModel):
|
|
|
1864
1846
|
combined_sequence_sentence_feature_lengths: tf.Tensor,
|
|
1865
1847
|
text_transformed: tf.Tensor,
|
|
1866
1848
|
) -> Dict[Text, tf.Tensor]:
|
|
1867
|
-
|
|
1868
1849
|
if self.all_labels_embed is None:
|
|
1869
1850
|
raise ValueError(
|
|
1870
1851
|
"The model was not prepared for prediction. "
|