rasa-pro 3.8.18__py3-none-any.whl → 3.9.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of rasa-pro has been flagged as possibly problematic (further details were linked from the original registry page).
- README.md +6 -42
- rasa/__main__.py +14 -9
- rasa/anonymization/anonymization_pipeline.py +0 -1
- rasa/anonymization/anonymization_rule_executor.py +3 -3
- rasa/anonymization/utils.py +4 -3
- rasa/api.py +2 -2
- rasa/cli/arguments/default_arguments.py +1 -1
- rasa/cli/arguments/run.py +2 -2
- rasa/cli/arguments/test.py +1 -1
- rasa/cli/arguments/train.py +10 -10
- rasa/cli/e2e_test.py +27 -7
- rasa/cli/export.py +0 -1
- rasa/cli/license.py +3 -3
- rasa/cli/project_templates/calm/actions/action_template.py +1 -1
- rasa/cli/project_templates/calm/config.yml +1 -1
- rasa/cli/project_templates/calm/credentials.yml +1 -1
- rasa/cli/project_templates/calm/data/flows/add_contact.yml +1 -1
- rasa/cli/project_templates/calm/data/flows/remove_contact.yml +1 -1
- rasa/cli/project_templates/calm/domain/add_contact.yml +8 -2
- rasa/cli/project_templates/calm/domain/list_contacts.yml +3 -0
- rasa/cli/project_templates/calm/domain/remove_contact.yml +9 -2
- rasa/cli/project_templates/calm/domain/shared.yml +5 -0
- rasa/cli/project_templates/calm/endpoints.yml +4 -4
- rasa/cli/project_templates/default/actions/actions.py +1 -1
- rasa/cli/project_templates/default/config.yml +5 -5
- rasa/cli/project_templates/default/credentials.yml +1 -1
- rasa/cli/project_templates/default/endpoints.yml +4 -4
- rasa/cli/project_templates/default/tests/test_stories.yml +1 -1
- rasa/cli/project_templates/tutorial/config.yml +1 -1
- rasa/cli/project_templates/tutorial/credentials.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +6 -0
- rasa/cli/project_templates/tutorial/domain.yml +4 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +6 -6
- rasa/cli/run.py +0 -1
- rasa/cli/scaffold.py +3 -2
- rasa/cli/studio/download.py +11 -0
- rasa/cli/studio/studio.py +180 -24
- rasa/cli/studio/upload.py +0 -8
- rasa/cli/telemetry.py +18 -6
- rasa/cli/utils.py +21 -10
- rasa/cli/x.py +3 -2
- rasa/constants.py +1 -1
- rasa/core/actions/action.py +90 -315
- rasa/core/actions/action_exceptions.py +24 -0
- rasa/core/actions/constants.py +3 -0
- rasa/core/actions/custom_action_executor.py +188 -0
- rasa/core/actions/forms.py +11 -7
- rasa/core/actions/grpc_custom_action_executor.py +251 -0
- rasa/core/actions/http_custom_action_executor.py +140 -0
- rasa/core/actions/loops.py +3 -0
- rasa/core/actions/two_stage_fallback.py +1 -1
- rasa/core/agent.py +2 -4
- rasa/core/brokers/pika.py +1 -2
- rasa/core/channels/audiocodes.py +1 -1
- rasa/core/channels/botframework.py +0 -1
- rasa/core/channels/callback.py +0 -1
- rasa/core/channels/console.py +6 -8
- rasa/core/channels/development_inspector.py +1 -1
- rasa/core/channels/facebook.py +0 -3
- rasa/core/channels/hangouts.py +0 -6
- rasa/core/channels/inspector/dist/assets/{arc-5623b6dc.js → arc-b6e548fe.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-685c106a.js → c4Diagram-d0fbc5ce-fa03ac9e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-8cbed007.js → classDiagram-936ed81e-ee67392a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-5889cf12.js → classDiagram-v2-c3cb15f1-9b283fae.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-24c249d7.js → createText-62fc7601-8b6fcc2a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-7dd06a75.js → edges-f2ad444c-22e77f4f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-62c1e54c.js → erDiagram-9d236eb7-60ffc87f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-ce49b86f.js → flowDb-1972c806-9dd802e4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-4067e48f.js → flowDiagram-7ea5b25a-5fa1912f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-59fe4051.js → flowchart-elk-definition-abe16c3d-622a1fd2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-47e3a43b.js → ganttDiagram-9b5ea136-e285a63a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-5a2ac0d9.js → gitGraphDiagram-99d0ae7c-f237bdca.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-dfb8efc4.js → index-2c4b9a3b-4b03d70e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-268a75c0.js → index-a5d3e69d.js} +4 -4
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-b0c470f2.js → infoDiagram-736b4530-72a0fa5f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-2edb829a.js → journeyDiagram-df861f2b-82218c41.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-b6873d69.js → layout-78cff630.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-1efc5781.js → line-5038b469.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-661e9b94.js → linear-c4fc4098.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-2d2e727f.js → mindmap-definition-beec6740-c33c8ea6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-9d3ea93d.js → pieDiagram-dbbf0591-a8d03059.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-06a178a2.js → quadrantDiagram-4d7f4fd6-6a0e56b2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-0bfedffc.js → requirementDiagram-6fc4c22a-2dc7c7bd.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-d76d0a04.js → sankeyDiagram-8f13d901-2360fe39.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-37bb4341.js → sequenceDiagram-b655622a-41b9f9ad.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-f52f7f57.js → stateDiagram-59f0c015-0aad326f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-4a986a20.js → stateDiagram-v2-2b26beab-9847d984.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-7dd9ae12.js → styles-080da4f6-564d890e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-46e1ca14.js → styles-3dcbcfbf-38957613.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-4a97439a.js → styles-9c745c82-f0fc6921.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-823917a3.js → svgDrawCommon-4835440b-ef3c5a77.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-9ea72896.js → timeline-definition-5b62e21b-bf3e91c1.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-b631a8b6.js → xychartDiagram-2b33534f-4d4026c0.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +10 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +4 -7
- rasa/core/channels/inspector/src/helpers/formatters.ts +3 -2
- rasa/core/channels/rest.py +36 -21
- rasa/core/channels/rocketchat.py +0 -1
- rasa/core/channels/socketio.py +1 -1
- rasa/core/channels/telegram.py +3 -3
- rasa/core/channels/webexteams.py +0 -1
- rasa/core/concurrent_lock_store.py +1 -1
- rasa/core/evaluation/marker_base.py +1 -3
- rasa/core/evaluation/marker_stats.py +1 -2
- rasa/core/featurizers/single_state_featurizer.py +3 -26
- rasa/core/featurizers/tracker_featurizers.py +18 -122
- rasa/core/information_retrieval/__init__.py +7 -0
- rasa/core/information_retrieval/faiss.py +9 -4
- rasa/core/information_retrieval/information_retrieval.py +64 -7
- rasa/core/information_retrieval/milvus.py +7 -14
- rasa/core/information_retrieval/qdrant.py +8 -15
- rasa/core/lock_store.py +0 -1
- rasa/core/migrate.py +1 -2
- rasa/core/nlg/callback.py +3 -4
- rasa/core/policies/enterprise_search_policy.py +86 -22
- rasa/core/policies/enterprise_search_prompt_template.jinja2 +4 -41
- rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
- rasa/core/policies/flows/flow_executor.py +104 -2
- rasa/core/policies/intentless_policy.py +7 -9
- rasa/core/policies/memoization.py +3 -3
- rasa/core/policies/policy.py +18 -9
- rasa/core/policies/rule_policy.py +8 -11
- rasa/core/policies/ted_policy.py +61 -88
- rasa/core/policies/unexpected_intent_policy.py +8 -17
- rasa/core/processor.py +136 -47
- rasa/core/run.py +41 -25
- rasa/core/secrets_manager/endpoints.py +2 -2
- rasa/core/secrets_manager/vault.py +6 -8
- rasa/core/test.py +3 -5
- rasa/core/tracker_store.py +49 -14
- rasa/core/train.py +1 -3
- rasa/core/training/interactive.py +9 -6
- rasa/core/utils.py +5 -10
- rasa/dialogue_understanding/coexistence/intent_based_router.py +11 -4
- rasa/dialogue_understanding/coexistence/llm_based_router.py +2 -3
- rasa/dialogue_understanding/commands/__init__.py +4 -0
- rasa/dialogue_understanding/commands/can_not_handle_command.py +9 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +9 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +38 -0
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/clarify_command.py +9 -0
- rasa/dialogue_understanding/commands/correct_slots_command.py +9 -0
- rasa/dialogue_understanding/commands/error_command.py +12 -0
- rasa/dialogue_understanding/commands/handle_code_change_command.py +9 -0
- rasa/dialogue_understanding/commands/human_handoff_command.py +9 -0
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/noop_command.py +9 -0
- rasa/dialogue_understanding/commands/set_slot_command.py +38 -3
- rasa/dialogue_understanding/commands/skip_question_command.py +9 -0
- rasa/dialogue_understanding/commands/start_flow_command.py +9 -0
- rasa/dialogue_understanding/generator/__init__.py +16 -1
- rasa/dialogue_understanding/generator/command_generator.py +92 -6
- rasa/dialogue_understanding/generator/constants.py +18 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +7 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +467 -0
- rasa/dialogue_understanding/generator/llm_command_generator.py +39 -609
- rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
- rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +827 -0
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +69 -8
- rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +345 -0
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +36 -31
- rasa/dialogue_understanding/processor/command_processor.py +112 -3
- rasa/e2e_test/constants.py +1 -0
- rasa/e2e_test/e2e_test_case.py +44 -0
- rasa/e2e_test/e2e_test_runner.py +114 -11
- rasa/e2e_test/e2e_test_schema.yml +18 -0
- rasa/engine/caching.py +0 -1
- rasa/engine/graph.py +18 -6
- rasa/engine/recipes/config_files/default_config.yml +3 -3
- rasa/engine/recipes/default_components.py +1 -1
- rasa/engine/recipes/default_recipe.py +4 -5
- rasa/engine/recipes/recipe.py +1 -1
- rasa/engine/runner/dask.py +3 -9
- rasa/engine/storage/local_model_storage.py +0 -2
- rasa/engine/validation.py +179 -145
- rasa/exceptions.py +2 -2
- rasa/graph_components/validators/default_recipe_validator.py +3 -5
- rasa/hooks.py +0 -1
- rasa/model.py +1 -1
- rasa/model_training.py +1 -0
- rasa/nlu/classifiers/diet_classifier.py +33 -52
- rasa/nlu/classifiers/logistic_regression_classifier.py +9 -22
- rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
- rasa/nlu/extractors/crf_entity_extractor.py +54 -97
- rasa/nlu/extractors/duckling_entity_extractor.py +1 -1
- rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +1 -5
- rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +0 -4
- rasa/nlu/featurizers/featurizer.py +1 -1
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +18 -49
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +26 -64
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
- rasa/nlu/persistor.py +68 -26
- rasa/nlu/selectors/response_selector.py +7 -10
- rasa/nlu/test.py +0 -3
- rasa/nlu/utils/hugging_face/registry.py +1 -1
- rasa/nlu/utils/spacy_utils.py +1 -3
- rasa/server.py +22 -7
- rasa/shared/constants.py +12 -1
- rasa/shared/core/command_payload_reader.py +109 -0
- rasa/shared/core/constants.py +4 -5
- rasa/shared/core/domain.py +57 -56
- rasa/shared/core/events.py +4 -7
- rasa/shared/core/flows/flow.py +9 -0
- rasa/shared/core/flows/flows_list.py +12 -0
- rasa/shared/core/flows/steps/action.py +7 -2
- rasa/shared/core/generator.py +12 -11
- rasa/shared/core/slot_mappings.py +315 -24
- rasa/shared/core/slots.py +4 -2
- rasa/shared/core/trackers.py +32 -14
- rasa/shared/core/training_data/loading.py +0 -1
- rasa/shared/core/training_data/story_reader/story_reader.py +3 -3
- rasa/shared/core/training_data/story_reader/yaml_story_reader.py +11 -11
- rasa/shared/core/training_data/story_writer/yaml_story_writer.py +5 -3
- rasa/shared/core/training_data/structures.py +1 -1
- rasa/shared/core/training_data/visualization.py +1 -1
- rasa/shared/data.py +58 -1
- rasa/shared/exceptions.py +36 -2
- rasa/shared/importers/importer.py +1 -2
- rasa/shared/importers/rasa.py +0 -1
- rasa/shared/nlu/constants.py +2 -0
- rasa/shared/nlu/training_data/entities_parser.py +1 -2
- rasa/shared/nlu/training_data/features.py +2 -120
- rasa/shared/nlu/training_data/formats/dialogflow.py +3 -2
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +3 -5
- rasa/shared/nlu/training_data/formats/readerwriter.py +0 -1
- rasa/shared/nlu/training_data/message.py +13 -0
- rasa/shared/nlu/training_data/training_data.py +0 -2
- rasa/shared/providers/openai/session_handler.py +2 -2
- rasa/shared/utils/constants.py +3 -0
- rasa/shared/utils/io.py +11 -1
- rasa/shared/utils/llm.py +1 -2
- rasa/shared/utils/pykwalify_extensions.py +1 -0
- rasa/shared/utils/schemas/domain.yml +3 -0
- rasa/shared/utils/yaml.py +44 -35
- rasa/studio/auth.py +26 -10
- rasa/studio/constants.py +2 -0
- rasa/studio/data_handler.py +114 -107
- rasa/studio/download.py +160 -27
- rasa/studio/results_logger.py +137 -0
- rasa/studio/train.py +6 -7
- rasa/studio/upload.py +159 -134
- rasa/telemetry.py +188 -34
- rasa/tracing/config.py +18 -3
- rasa/tracing/constants.py +26 -2
- rasa/tracing/instrumentation/attribute_extractors.py +50 -41
- rasa/tracing/instrumentation/instrumentation.py +290 -44
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +7 -5
- rasa/tracing/instrumentation/metrics.py +109 -21
- rasa/tracing/metric_instrument_provider.py +83 -3
- rasa/utils/cli.py +2 -1
- rasa/utils/common.py +1 -1
- rasa/utils/endpoints.py +1 -2
- rasa/utils/io.py +72 -6
- rasa/utils/licensing.py +246 -31
- rasa/utils/ml_utils.py +1 -1
- rasa/utils/tensorflow/data_generator.py +1 -1
- rasa/utils/tensorflow/environment.py +1 -1
- rasa/utils/tensorflow/model_data.py +201 -12
- rasa/utils/tensorflow/model_data_utils.py +499 -500
- rasa/utils/tensorflow/models.py +5 -6
- rasa/utils/tensorflow/rasa_layers.py +15 -15
- rasa/utils/train_utils.py +1 -1
- rasa/utils/url_tools.py +53 -0
- rasa/validator.py +305 -3
- rasa/version.py +1 -1
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/METADATA +25 -61
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/RECORD +276 -259
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-85583a23.js +0 -1
- rasa/utils/tensorflow/feature_array.py +0 -370
- /rasa/dialogue_understanding/generator/{command_prompt_template.jinja2 → single_step/command_prompt_template.jinja2} +0 -0
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/NOTICE +0 -0
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/WHEEL +0 -0
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/entry_points.txt +0 -0
|
@@ -3,6 +3,8 @@ from __future__ import annotations
|
|
|
3
3
|
from typing import Any, Dict, Text, List, Optional
|
|
4
4
|
|
|
5
5
|
from jinja2 import Template
|
|
6
|
+
from rasa.dialogue_understanding.commands import CancelFlowCommand
|
|
7
|
+
from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
|
|
6
8
|
from structlog.contextvars import (
|
|
7
9
|
bound_contextvars,
|
|
8
10
|
)
|
|
@@ -17,6 +19,9 @@ from rasa.core.policies.flows.flow_step_result import (
|
|
|
17
19
|
FlowStepResult,
|
|
18
20
|
PauseFlowReturnPrediction,
|
|
19
21
|
)
|
|
22
|
+
from rasa.dialogue_understanding.patterns.internal_error import (
|
|
23
|
+
InternalErrorPatternFlowStackFrame,
|
|
24
|
+
)
|
|
20
25
|
from rasa.dialogue_understanding.patterns.search import SearchPatternFlowStackFrame
|
|
21
26
|
from rasa.dialogue_understanding.stack.dialogue_stack import DialogueStack
|
|
22
27
|
from rasa.dialogue_understanding.stack.frames import (
|
|
@@ -42,7 +47,7 @@ from rasa.dialogue_understanding.stack.utils import (
|
|
|
42
47
|
|
|
43
48
|
from pypred import Predicate
|
|
44
49
|
|
|
45
|
-
from rasa.shared.core.constants import ACTION_LISTEN_NAME
|
|
50
|
+
from rasa.shared.core.constants import ACTION_LISTEN_NAME, SlotMappingType
|
|
46
51
|
from rasa.shared.core.events import (
|
|
47
52
|
Event,
|
|
48
53
|
FlowCompleted,
|
|
@@ -72,6 +77,7 @@ from rasa.shared.core.flows.flow import (
|
|
|
72
77
|
FlowStep,
|
|
73
78
|
)
|
|
74
79
|
from rasa.shared.core.flows.steps.collect import SlotRejection
|
|
80
|
+
from rasa.shared.core.slots import Slot
|
|
75
81
|
from rasa.shared.core.trackers import (
|
|
76
82
|
DialogueStateTracker,
|
|
77
83
|
)
|
|
@@ -397,7 +403,6 @@ def advance_flows_until_next_action(
|
|
|
397
403
|
number_of_steps_taken = 0
|
|
398
404
|
|
|
399
405
|
while isinstance(step_result, ContinueFlowWithNextStep):
|
|
400
|
-
|
|
401
406
|
number_of_steps_taken += 1
|
|
402
407
|
if number_of_steps_taken > MAX_NUMBER_OF_STEPS:
|
|
403
408
|
raise FlowCircuitBreakerTrippedException(
|
|
@@ -467,6 +472,87 @@ def advance_flows_until_next_action(
|
|
|
467
472
|
return FlowActionPrediction(None, 0.0, events=gathered_events)
|
|
468
473
|
|
|
469
474
|
|
|
475
|
+
def validate_collect_step(
|
|
476
|
+
step: CollectInformationFlowStep,
|
|
477
|
+
stack: DialogueStack,
|
|
478
|
+
available_actions: List[str],
|
|
479
|
+
slots: Dict[Text, Slot],
|
|
480
|
+
) -> bool:
|
|
481
|
+
"""Validate that a collect step can be executed.
|
|
482
|
+
|
|
483
|
+
A collect step can be executed if either the `utter_ask` or the `action_ask` is
|
|
484
|
+
defined in the domain. If neither is defined, the collect step can still be
|
|
485
|
+
executed if the slot has an initial value defined in the domain, which would cause
|
|
486
|
+
the step to be skipped."""
|
|
487
|
+
slot = slots.get(step.collect)
|
|
488
|
+
slot_has_initial_value_defined = slot and slot.initial_value is not None
|
|
489
|
+
if (
|
|
490
|
+
slot_has_initial_value_defined
|
|
491
|
+
or step.utter in available_actions
|
|
492
|
+
or step.collect_action in available_actions
|
|
493
|
+
):
|
|
494
|
+
return True
|
|
495
|
+
|
|
496
|
+
structlogger.error(
|
|
497
|
+
"flow.step.run.collect_missing_utter_or_collect_action",
|
|
498
|
+
slot_name=step.collect,
|
|
499
|
+
)
|
|
500
|
+
|
|
501
|
+
cancel_flow_and_push_internal_error(stack)
|
|
502
|
+
|
|
503
|
+
return False
|
|
504
|
+
|
|
505
|
+
|
|
506
|
+
def cancel_flow_and_push_internal_error(stack: DialogueStack) -> None:
|
|
507
|
+
"""Cancel the top user flow and push the internal error pattern."""
|
|
508
|
+
top_frame = stack.top()
|
|
509
|
+
|
|
510
|
+
if isinstance(top_frame, BaseFlowStackFrame):
|
|
511
|
+
# we need to first cancel the top user flow
|
|
512
|
+
# because we cannot collect one of its slots
|
|
513
|
+
# and therefore should not proceed with the flow
|
|
514
|
+
# after triggering pattern_internal_error
|
|
515
|
+
canceled_frames = CancelFlowCommand.select_canceled_frames(stack)
|
|
516
|
+
stack.push(
|
|
517
|
+
CancelPatternFlowStackFrame(
|
|
518
|
+
canceled_name=top_frame.flow_id,
|
|
519
|
+
canceled_frames=canceled_frames,
|
|
520
|
+
)
|
|
521
|
+
)
|
|
522
|
+
stack.push(InternalErrorPatternFlowStackFrame())
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
def validate_custom_slot_mappings(
|
|
526
|
+
step: CollectInformationFlowStep,
|
|
527
|
+
stack: DialogueStack,
|
|
528
|
+
tracker: DialogueStateTracker,
|
|
529
|
+
available_actions: List[str],
|
|
530
|
+
) -> bool:
|
|
531
|
+
"""Validate a slot with custom mappings.
|
|
532
|
+
|
|
533
|
+
If invalid, trigger pattern_internal_error and return False.
|
|
534
|
+
"""
|
|
535
|
+
slot = tracker.slots.get(step.collect, None)
|
|
536
|
+
slot_mappings = slot.mappings if slot else []
|
|
537
|
+
for mapping in slot_mappings:
|
|
538
|
+
if (
|
|
539
|
+
mapping.get("type") == SlotMappingType.CUSTOM.value
|
|
540
|
+
and mapping.get("action") is None
|
|
541
|
+
):
|
|
542
|
+
# this is a slot that must be filled by a custom action
|
|
543
|
+
# check if collect_action exists
|
|
544
|
+
if step.collect_action not in available_actions:
|
|
545
|
+
structlogger.error(
|
|
546
|
+
"flow.step.run.collect_action_not_found_for_custom_slot_mapping",
|
|
547
|
+
action=step.collect_action,
|
|
548
|
+
collect=step.collect,
|
|
549
|
+
)
|
|
550
|
+
cancel_flow_and_push_internal_error(stack)
|
|
551
|
+
return False
|
|
552
|
+
|
|
553
|
+
return True
|
|
554
|
+
|
|
555
|
+
|
|
470
556
|
def run_step(
|
|
471
557
|
step: FlowStep,
|
|
472
558
|
flow: Flow,
|
|
@@ -500,6 +586,22 @@ def run_step(
|
|
|
500
586
|
initial_events.append(FlowStarted(flow.id))
|
|
501
587
|
|
|
502
588
|
if isinstance(step, CollectInformationFlowStep):
|
|
589
|
+
is_step_valid = validate_collect_step(
|
|
590
|
+
step, stack, available_actions, tracker.slots
|
|
591
|
+
)
|
|
592
|
+
if not is_step_valid:
|
|
593
|
+
# if we return any other FlowStepResult, the assistant will stay silent
|
|
594
|
+
# instead of triggering the internal error pattern
|
|
595
|
+
return ContinueFlowWithNextStep(events=initial_events)
|
|
596
|
+
|
|
597
|
+
is_mapping_valid = validate_custom_slot_mappings(
|
|
598
|
+
step, stack, tracker, available_actions
|
|
599
|
+
)
|
|
600
|
+
if not is_mapping_valid:
|
|
601
|
+
# if we return any other FlowStepResult, the assistant will stay silent
|
|
602
|
+
# instead of triggering the internal error pattern
|
|
603
|
+
return ContinueFlowWithNextStep(events=initial_events)
|
|
604
|
+
|
|
503
605
|
structlogger.debug("flow.step.run.collect")
|
|
504
606
|
trigger_pattern_ask_collect_information(
|
|
505
607
|
step.collect, stack, step.rejections, step.utter, step.collect_action
|
|
@@ -3,7 +3,6 @@ import math
|
|
|
3
3
|
from dataclasses import dataclass, field
|
|
4
4
|
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Text, Tuple
|
|
5
5
|
|
|
6
|
-
import rasa.shared.utils.io
|
|
7
6
|
import structlog
|
|
8
7
|
import tiktoken
|
|
9
8
|
from jinja2 import Template
|
|
@@ -11,6 +10,7 @@ from langchain.docstore.document import Document
|
|
|
11
10
|
from langchain.schema.embeddings import Embeddings
|
|
12
11
|
from langchain.vectorstores import FAISS
|
|
13
12
|
|
|
13
|
+
import rasa.shared.utils.io
|
|
14
14
|
from rasa import telemetry
|
|
15
15
|
from rasa.core.constants import (
|
|
16
16
|
CHAT_POLICY_PRIORITY,
|
|
@@ -56,7 +56,6 @@ from rasa.shared.utils.llm import (
|
|
|
56
56
|
sanitize_message_for_prompt,
|
|
57
57
|
tracker_as_readable_transcript,
|
|
58
58
|
)
|
|
59
|
-
|
|
60
59
|
from rasa.utils.ml_utils import (
|
|
61
60
|
extract_ai_response_examples,
|
|
62
61
|
extract_participant_messages_from_transcript,
|
|
@@ -65,7 +64,6 @@ from rasa.utils.ml_utils import (
|
|
|
65
64
|
persist_faiss_vector_store,
|
|
66
65
|
response_for_template,
|
|
67
66
|
)
|
|
68
|
-
|
|
69
67
|
from rasa.utils.log_utils import log_llm
|
|
70
68
|
|
|
71
69
|
if TYPE_CHECKING:
|
|
@@ -543,7 +541,9 @@ class IntentlessPolicy(Policy):
|
|
|
543
541
|
Returns:
|
|
544
542
|
The prediction.
|
|
545
543
|
"""
|
|
546
|
-
if not self.supports_current_stack_frame(
|
|
544
|
+
if not self.supports_current_stack_frame(
|
|
545
|
+
tracker
|
|
546
|
+
) or self.should_abstain_in_coexistence(tracker, True):
|
|
547
547
|
return self._prediction(self._default_predictions(domain))
|
|
548
548
|
|
|
549
549
|
if tracker.has_bot_message_after_latest_user_message():
|
|
@@ -670,7 +670,7 @@ class IntentlessPolicy(Policy):
|
|
|
670
670
|
if tracker.latest_message.text.startswith("/"):
|
|
671
671
|
# we don't want to generate a response if the user is trying to
|
|
672
672
|
# execute a "command" - this should be handled by the regex
|
|
673
|
-
# intent classifier in rasa
|
|
673
|
+
# intent classifier in rasa pro.
|
|
674
674
|
structlogger.debug("intentless_policy.prediction.skip_slash")
|
|
675
675
|
return None, 0.0
|
|
676
676
|
|
|
@@ -863,7 +863,7 @@ class IntentlessPolicy(Policy):
|
|
|
863
863
|
"""
|
|
864
864
|
result = self._default_predictions(domain)
|
|
865
865
|
if action_name:
|
|
866
|
-
result[domain.index_for_action(action_name)] = score # type: ignore[assignment]
|
|
866
|
+
result[domain.index_for_action(action_name)] = score # type: ignore[assignment]
|
|
867
867
|
return result
|
|
868
868
|
|
|
869
869
|
@classmethod
|
|
@@ -892,9 +892,7 @@ class IntentlessPolicy(Policy):
|
|
|
892
892
|
# normalized. unfortunatley langchain doesn't persist / load
|
|
893
893
|
# this parameter.
|
|
894
894
|
if responses_docsearch:
|
|
895
|
-
responses_docsearch._normalize_L2 =
|
|
896
|
-
True # pylint: disable=protected-access
|
|
897
|
-
)
|
|
895
|
+
responses_docsearch._normalize_L2 = True # pylint: disable=protected-access
|
|
898
896
|
prompt_template = rasa.shared.utils.io.read_file(
|
|
899
897
|
path / INTENTLESS_PROMPT_TEMPLATE_FILE_NAME
|
|
900
898
|
)
|
|
@@ -419,9 +419,9 @@ class AugmentedMemoizationPolicy(MemoizationPolicy):
|
|
|
419
419
|
logger.debug("Launch DeLorean...")
|
|
420
420
|
|
|
421
421
|
# Truncate the tracker based on `max_history`
|
|
422
|
-
truncated_tracker: Optional[
|
|
423
|
-
|
|
424
|
-
|
|
422
|
+
truncated_tracker: Optional[DialogueStateTracker] = (
|
|
423
|
+
_trim_tracker_by_max_history(tracker, self.config[POLICY_MAX_HISTORY])
|
|
424
|
+
)
|
|
425
425
|
truncated_tracker = self._strip_leading_events_until_action_executed(
|
|
426
426
|
truncated_tracker
|
|
427
427
|
)
|
rasa/core/policies/policy.py
CHANGED
|
@@ -1,12 +1,10 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
|
|
2
3
|
import abc
|
|
3
4
|
import copy
|
|
4
5
|
import logging
|
|
5
6
|
from enum import Enum
|
|
6
7
|
from pathlib import Path
|
|
7
|
-
|
|
8
|
-
from rasa.shared.constants import ROUTE_TO_CALM_SLOT
|
|
9
|
-
from rasa.shared.core.events import Event
|
|
10
8
|
from typing import (
|
|
11
9
|
Any,
|
|
12
10
|
List,
|
|
@@ -21,6 +19,8 @@ from typing import (
|
|
|
21
19
|
|
|
22
20
|
import numpy as np
|
|
23
21
|
|
|
22
|
+
from rasa.shared.constants import ROUTE_TO_CALM_SLOT
|
|
23
|
+
from rasa.shared.core.events import Event
|
|
24
24
|
from rasa.engine.graph import GraphComponent, ExecutionContext
|
|
25
25
|
from rasa.engine.storage.resource import Resource
|
|
26
26
|
from rasa.engine.storage.storage import ModelStorage
|
|
@@ -40,14 +40,12 @@ from rasa.core.constants import (
|
|
|
40
40
|
from rasa.shared.core.constants import USER, SLOTS, PREVIOUS_ACTION, ACTIVE_LOOP
|
|
41
41
|
import rasa.shared.utils.common
|
|
42
42
|
|
|
43
|
-
|
|
44
43
|
if TYPE_CHECKING:
|
|
45
44
|
from rasa.shared.nlu.training_data.features import Features
|
|
46
45
|
from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer
|
|
47
46
|
from rasa.core.featurizers.tracker_featurizers import MaxHistoryTrackerFeaturizer
|
|
48
47
|
from rasa.dialogue_understanding.stack.frames import DialogueStackFrame
|
|
49
48
|
|
|
50
|
-
|
|
51
49
|
logger = logging.getLogger(__name__)
|
|
52
50
|
|
|
53
51
|
TrackerListTypeVar = TypeVar(
|
|
@@ -137,10 +135,20 @@ class Policy(GraphComponent):
|
|
|
137
135
|
def should_abstain_in_coexistence(
|
|
138
136
|
self, tracker: DialogueStateTracker, is_calm_policy: bool
|
|
139
137
|
) -> bool:
|
|
140
|
-
"""Whether a policy should abstain making predictions in coexistence.
|
|
138
|
+
"""Whether a policy should abstain making predictions in coexistence.
|
|
139
|
+
|
|
140
|
+
A calm policy should run when the routing slot is set to True.
|
|
141
|
+
A nlu-based policy should run when the routing slot is set to False or None.
|
|
142
|
+
"""
|
|
143
|
+
if is_calm_policy:
|
|
144
|
+
return tracker.has_coexistence_routing_slot and (
|
|
145
|
+
tracker.get_slot(ROUTE_TO_CALM_SLOT) is False
|
|
146
|
+
or tracker.get_slot(ROUTE_TO_CALM_SLOT) is None
|
|
147
|
+
)
|
|
148
|
+
|
|
141
149
|
return (
|
|
142
150
|
tracker.has_coexistence_routing_slot
|
|
143
|
-
and tracker.get_slot(ROUTE_TO_CALM_SLOT)
|
|
151
|
+
and tracker.get_slot(ROUTE_TO_CALM_SLOT) is True
|
|
144
152
|
)
|
|
145
153
|
|
|
146
154
|
def __init__(
|
|
@@ -299,8 +307,9 @@ class Policy(GraphComponent):
|
|
|
299
307
|
max_training_samples = kwargs.get("max_training_samples")
|
|
300
308
|
if max_training_samples is not None:
|
|
301
309
|
logger.debug(
|
|
302
|
-
"Limit training data to {} training samples."
|
|
303
|
-
|
|
310
|
+
"Limit training data to {} training samples.".format(
|
|
311
|
+
max_training_samples
|
|
312
|
+
)
|
|
304
313
|
)
|
|
305
314
|
state_features = state_features[:max_training_samples]
|
|
306
315
|
label_ids = label_ids[:max_training_samples]
|
|
@@ -60,7 +60,7 @@ logger = logging.getLogger(__name__)
|
|
|
60
60
|
structlogger = structlog.get_logger()
|
|
61
61
|
|
|
62
62
|
|
|
63
|
-
# These are Rasa
|
|
63
|
+
# These are Rasa Pro default actions and overrule everything at any time.
|
|
64
64
|
DEFAULT_ACTION_MAPPINGS = {
|
|
65
65
|
USER_INTENT_RESTART: ACTION_RESTART_NAME,
|
|
66
66
|
USER_INTENT_BACK: ACTION_BACK_NAME,
|
|
@@ -271,15 +271,13 @@ class RulePolicy(MemoizationPolicy):
|
|
|
271
271
|
if (
|
|
272
272
|
# loop is predicted after action_listen in unhappy path,
|
|
273
273
|
# therefore no validation is needed
|
|
274
|
-
is_prev_action_listen_in_state(states[-1])
|
|
275
|
-
and action == active_loop
|
|
274
|
+
is_prev_action_listen_in_state(states[-1]) and action == active_loop
|
|
276
275
|
):
|
|
277
276
|
lookup[feature_key] = LOOP_WAS_INTERRUPTED
|
|
278
277
|
elif (
|
|
279
278
|
# some action other than active_loop is predicted in unhappy path,
|
|
280
279
|
# therefore active_loop shouldn't be predicted by the rule
|
|
281
|
-
not is_prev_action_listen_in_state(states[-1])
|
|
282
|
-
and action != active_loop
|
|
280
|
+
not is_prev_action_listen_in_state(states[-1]) and action != active_loop
|
|
283
281
|
):
|
|
284
282
|
lookup[feature_key] = DO_NOT_PREDICT_LOOP_ACTION
|
|
285
283
|
return lookup
|
|
@@ -777,10 +775,10 @@ class RulePolicy(MemoizationPolicy):
|
|
|
777
775
|
trackers_as_actions = rule_trackers_as_actions + story_trackers_as_actions
|
|
778
776
|
|
|
779
777
|
# negative rules are not anti-rules, they are auxiliary to actual rules
|
|
780
|
-
self.lookup[
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
778
|
+
self.lookup[RULES_FOR_LOOP_UNHAPPY_PATH] = (
|
|
779
|
+
self._create_loop_unhappy_lookup_from_states(
|
|
780
|
+
trackers_as_states, trackers_as_actions
|
|
781
|
+
)
|
|
784
782
|
)
|
|
785
783
|
|
|
786
784
|
def train(
|
|
@@ -955,7 +953,6 @@ class RulePolicy(MemoizationPolicy):
|
|
|
955
953
|
def _find_action_from_loop_happy_path(
|
|
956
954
|
tracker: DialogueStateTracker,
|
|
957
955
|
) -> Tuple[Optional[Text], Optional[Text]]:
|
|
958
|
-
|
|
959
956
|
active_loop_name = tracker.active_loop_name
|
|
960
957
|
if active_loop_name is None:
|
|
961
958
|
return None, None
|
|
@@ -1132,7 +1129,7 @@ class RulePolicy(MemoizationPolicy):
|
|
|
1132
1129
|
tracker, domain, use_text_for_last_user_input=True
|
|
1133
1130
|
)
|
|
1134
1131
|
|
|
1135
|
-
# Rasa
|
|
1132
|
+
# Rasa Pro default actions overrule anything. If users want to achieve
|
|
1136
1133
|
# the same, they need to write a rule or make sure that their loop rejects
|
|
1137
1134
|
# accordingly.
|
|
1138
1135
|
(
|
rasa/core/policies/ted_policy.py
CHANGED
|
@@ -1,15 +1,15 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
|
|
3
2
|
import logging
|
|
3
|
+
|
|
4
|
+
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
4
5
|
from pathlib import Path
|
|
5
6
|
from collections import defaultdict
|
|
6
7
|
import contextlib
|
|
7
|
-
from typing import Any, List, Optional, Text, Dict, Tuple, Union, Type
|
|
8
8
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
import tensorflow as tf
|
|
11
|
+
from typing import Any, List, Optional, Text, Dict, Tuple, Union, Type
|
|
11
12
|
|
|
12
|
-
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
13
13
|
from rasa.engine.graph import ExecutionContext
|
|
14
14
|
from rasa.engine.storage.resource import Resource
|
|
15
15
|
from rasa.engine.storage.storage import ModelStorage
|
|
@@ -49,22 +49,18 @@ from rasa.shared.core.generator import TrackerWithCachedStates
|
|
|
49
49
|
from rasa.shared.core.events import EntitiesAdded, Event
|
|
50
50
|
from rasa.shared.core.domain import Domain
|
|
51
51
|
from rasa.shared.nlu.training_data.message import Message
|
|
52
|
-
from rasa.shared.nlu.training_data.features import
|
|
53
|
-
Features,
|
|
54
|
-
save_features,
|
|
55
|
-
load_features,
|
|
56
|
-
)
|
|
52
|
+
from rasa.shared.nlu.training_data.features import Features
|
|
57
53
|
import rasa.shared.utils.io
|
|
58
54
|
import rasa.utils.io
|
|
59
55
|
from rasa.utils import train_utils
|
|
60
|
-
from rasa.utils.tensorflow.feature_array import (
|
|
61
|
-
FeatureArray,
|
|
62
|
-
serialize_nested_feature_arrays,
|
|
63
|
-
deserialize_nested_feature_arrays,
|
|
64
|
-
)
|
|
65
56
|
from rasa.utils.tensorflow.models import RasaModel, TransformerRasaModel
|
|
66
57
|
from rasa.utils.tensorflow import rasa_layers
|
|
67
|
-
from rasa.utils.tensorflow.model_data import
|
|
58
|
+
from rasa.utils.tensorflow.model_data import (
|
|
59
|
+
RasaModelData,
|
|
60
|
+
FeatureSignature,
|
|
61
|
+
FeatureArray,
|
|
62
|
+
Data,
|
|
63
|
+
)
|
|
68
64
|
from rasa.utils.tensorflow.model_data_utils import convert_to_data_format
|
|
69
65
|
from rasa.utils.tensorflow.constants import (
|
|
70
66
|
LABEL,
|
|
@@ -472,7 +468,7 @@ class TEDPolicy(Policy):
|
|
|
472
468
|
|
|
473
469
|
@staticmethod
|
|
474
470
|
def _should_extract_entities(
|
|
475
|
-
entity_tags: List[List[Dict[Text, List[Features]]]]
|
|
471
|
+
entity_tags: List[List[Dict[Text, List[Features]]]],
|
|
476
472
|
) -> bool:
|
|
477
473
|
for turns_tags in entity_tags:
|
|
478
474
|
for turn_tags in turns_tags:
|
|
@@ -965,32 +961,22 @@ class TEDPolicy(Policy):
|
|
|
965
961
|
model_path: Path where model is to be persisted
|
|
966
962
|
"""
|
|
967
963
|
model_filename = self._metadata_filename()
|
|
968
|
-
rasa.
|
|
969
|
-
model_path / f"{model_filename}.priority.
|
|
964
|
+
rasa.utils.io.json_pickle(
|
|
965
|
+
model_path / f"{model_filename}.priority.pkl", self.priority
|
|
970
966
|
)
|
|
971
|
-
rasa.
|
|
972
|
-
model_path / f"{model_filename}.meta.
|
|
973
|
-
)
|
|
974
|
-
# save data example
|
|
975
|
-
serialize_nested_feature_arrays(
|
|
976
|
-
self.data_example,
|
|
977
|
-
str(model_path / f"{model_filename}.data_example.st"),
|
|
978
|
-
str(model_path / f"{model_filename}.data_example_metadata.json"),
|
|
967
|
+
rasa.utils.io.pickle_dump(
|
|
968
|
+
model_path / f"{model_filename}.meta.pkl", self.config
|
|
979
969
|
)
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
dict(self._label_data.data) if self._label_data is not None else {},
|
|
983
|
-
str(model_path / f"{model_filename}.label_data.st"),
|
|
984
|
-
str(model_path / f"{model_filename}.label_data_metadata.json"),
|
|
970
|
+
rasa.utils.io.pickle_dump(
|
|
971
|
+
model_path / f"{model_filename}.data_example.pkl", self.data_example
|
|
985
972
|
)
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
self.fake_features, str(model_path / f"{model_filename}.fake_features.st")
|
|
973
|
+
rasa.utils.io.pickle_dump(
|
|
974
|
+
model_path / f"{model_filename}.fake_features.pkl", self.fake_features
|
|
989
975
|
)
|
|
990
|
-
rasa.
|
|
991
|
-
model_path / f"{model_filename}.
|
|
976
|
+
rasa.utils.io.pickle_dump(
|
|
977
|
+
model_path / f"{model_filename}.label_data.pkl",
|
|
978
|
+
dict(self._label_data.data) if self._label_data is not None else {},
|
|
992
979
|
)
|
|
993
|
-
|
|
994
980
|
entity_tag_specs = (
|
|
995
981
|
[tag_spec._asdict() for tag_spec in self._entity_tag_specs]
|
|
996
982
|
if self._entity_tag_specs
|
|
@@ -1008,29 +994,18 @@ class TEDPolicy(Policy):
|
|
|
1008
994
|
model_path: Path where model is to be persisted.
|
|
1009
995
|
"""
|
|
1010
996
|
tf_model_file = model_path / f"{cls._metadata_filename()}.tf_model"
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
loaded_data = deserialize_nested_feature_arrays(
|
|
1014
|
-
str(model_path / f"{cls._metadata_filename()}.data_example.st"),
|
|
1015
|
-
str(model_path / f"{cls._metadata_filename()}.data_example_metadata.json"),
|
|
997
|
+
loaded_data = rasa.utils.io.pickle_load(
|
|
998
|
+
model_path / f"{cls._metadata_filename()}.data_example.pkl"
|
|
1016
999
|
)
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
str(model_path / f"{cls._metadata_filename()}.label_data.st"),
|
|
1020
|
-
str(model_path / f"{cls._metadata_filename()}.label_data_metadata.json"),
|
|
1000
|
+
label_data = rasa.utils.io.pickle_load(
|
|
1001
|
+
model_path / f"{cls._metadata_filename()}.label_data.pkl"
|
|
1021
1002
|
)
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
# load fake features
|
|
1025
|
-
metadata = rasa.shared.utils.io.read_json_file(
|
|
1026
|
-
model_path / f"{cls._metadata_filename()}.fake_features_metadata.json"
|
|
1027
|
-
)
|
|
1028
|
-
fake_features = load_features(
|
|
1029
|
-
str(model_path / f"{cls._metadata_filename()}.fake_features.st"), metadata
|
|
1003
|
+
fake_features = rasa.utils.io.pickle_load(
|
|
1004
|
+
model_path / f"{cls._metadata_filename()}.fake_features.pkl"
|
|
1030
1005
|
)
|
|
1031
|
-
|
|
1032
|
-
priority = rasa.
|
|
1033
|
-
model_path / f"{cls._metadata_filename()}.priority.
|
|
1006
|
+
label_data = RasaModelData(data=label_data)
|
|
1007
|
+
priority = rasa.utils.io.json_unpickle(
|
|
1008
|
+
model_path / f"{cls._metadata_filename()}.priority.pkl"
|
|
1034
1009
|
)
|
|
1035
1010
|
entity_tag_specs = rasa.shared.utils.io.read_json_file(
|
|
1036
1011
|
model_path / f"{cls._metadata_filename()}.entity_tag_specs.json"
|
|
@@ -1048,8 +1023,8 @@ class TEDPolicy(Policy):
|
|
|
1048
1023
|
)
|
|
1049
1024
|
for tag_spec in entity_tag_specs
|
|
1050
1025
|
]
|
|
1051
|
-
model_config = rasa.
|
|
1052
|
-
model_path / f"{cls._metadata_filename()}.meta.
|
|
1026
|
+
model_config = rasa.utils.io.pickle_load(
|
|
1027
|
+
model_path / f"{cls._metadata_filename()}.meta.pkl"
|
|
1053
1028
|
)
|
|
1054
1029
|
|
|
1055
1030
|
return {
|
|
@@ -1095,7 +1070,7 @@ class TEDPolicy(Policy):
|
|
|
1095
1070
|
) -> TEDPolicy:
|
|
1096
1071
|
featurizer = TrackerFeaturizer.load(model_path)
|
|
1097
1072
|
|
|
1098
|
-
if not (model_path / f"{cls._metadata_filename()}.data_example.
|
|
1073
|
+
if not (model_path / f"{cls._metadata_filename()}.data_example.pkl").is_file():
|
|
1099
1074
|
return cls(
|
|
1100
1075
|
config,
|
|
1101
1076
|
model_storage,
|
|
@@ -1117,7 +1092,7 @@ class TEDPolicy(Policy):
|
|
|
1117
1092
|
|
|
1118
1093
|
model = None
|
|
1119
1094
|
|
|
1120
|
-
with
|
|
1095
|
+
with contextlib.nullcontext() if config["use_gpu"] else tf.device("/cpu:0"):
|
|
1121
1096
|
model = cls._load_tf_model(
|
|
1122
1097
|
model_utilities,
|
|
1123
1098
|
model_data_example,
|
|
@@ -1291,19 +1266,19 @@ class TED(TransformerRasaModel):
|
|
|
1291
1266
|
)
|
|
1292
1267
|
self._prepare_encoding_layers(name)
|
|
1293
1268
|
|
|
1294
|
-
self._tf_layers[
|
|
1295
|
-
|
|
1296
|
-
|
|
1297
|
-
|
|
1298
|
-
|
|
1299
|
-
|
|
1300
|
-
|
|
1301
|
-
|
|
1302
|
-
|
|
1303
|
-
|
|
1304
|
-
|
|
1305
|
-
|
|
1306
|
-
|
|
1269
|
+
self._tf_layers[f"transformer.{DIALOGUE}"] = (
|
|
1270
|
+
rasa_layers.prepare_transformer_layer(
|
|
1271
|
+
attribute_name=DIALOGUE,
|
|
1272
|
+
config=self.config,
|
|
1273
|
+
num_layers=self.config[NUM_TRANSFORMER_LAYERS][DIALOGUE],
|
|
1274
|
+
units=self.config[TRANSFORMER_SIZE][DIALOGUE],
|
|
1275
|
+
drop_rate=self.config[DROP_RATE_DIALOGUE],
|
|
1276
|
+
# use bidirectional transformer, because
|
|
1277
|
+
# we will invert dialogue sequence so that the last turn is located
|
|
1278
|
+
# at the first position and would always have
|
|
1279
|
+
# exactly the same positional encoding
|
|
1280
|
+
unidirectional=not self.max_history_featurizer_is_used,
|
|
1281
|
+
)
|
|
1307
1282
|
)
|
|
1308
1283
|
|
|
1309
1284
|
self._prepare_label_classification_layers(DIALOGUE)
|
|
@@ -1333,23 +1308,23 @@ class TED(TransformerRasaModel):
|
|
|
1333
1308
|
# Attributes with sequence-level features also have sentence-level features,
|
|
1334
1309
|
# all these need to be combined and further processed.
|
|
1335
1310
|
if attribute_name in SEQUENCE_FEATURES_TO_ENCODE:
|
|
1336
|
-
self._tf_layers[
|
|
1337
|
-
|
|
1338
|
-
|
|
1339
|
-
|
|
1311
|
+
self._tf_layers[f"sequence_layer.{attribute_name}"] = (
|
|
1312
|
+
rasa_layers.RasaSequenceLayer(
|
|
1313
|
+
attribute_name, attribute_signature, config_to_use
|
|
1314
|
+
)
|
|
1340
1315
|
)
|
|
1341
1316
|
# Attributes without sequence-level features require some actual feature
|
|
1342
1317
|
# processing only if they have sentence-level features. Attributes with no
|
|
1343
1318
|
# sequence- and sentence-level features (dialogue, entity_tags, label) are
|
|
1344
1319
|
# skipped here.
|
|
1345
1320
|
elif SENTENCE in attribute_signature:
|
|
1346
|
-
self._tf_layers[
|
|
1347
|
-
|
|
1348
|
-
|
|
1349
|
-
|
|
1350
|
-
|
|
1351
|
-
|
|
1352
|
-
|
|
1321
|
+
self._tf_layers[f"sparse_dense_concat_layer.{attribute_name}"] = (
|
|
1322
|
+
rasa_layers.ConcatenateSparseDenseFeatures(
|
|
1323
|
+
attribute=attribute_name,
|
|
1324
|
+
feature_type=SENTENCE,
|
|
1325
|
+
feature_type_signature=attribute_signature[SENTENCE],
|
|
1326
|
+
config=config_to_use,
|
|
1327
|
+
)
|
|
1353
1328
|
)
|
|
1354
1329
|
|
|
1355
1330
|
def _prepare_encoding_layers(self, name: Text) -> None:
|
|
@@ -1385,7 +1360,7 @@ class TED(TransformerRasaModel):
|
|
|
1385
1360
|
|
|
1386
1361
|
@staticmethod
|
|
1387
1362
|
def _compute_dialogue_indices(
|
|
1388
|
-
tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]]
|
|
1363
|
+
tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
|
|
1389
1364
|
) -> None:
|
|
1390
1365
|
dialogue_lengths = tf.cast(tf_batch_data[DIALOGUE][LENGTH][0], dtype=tf.int32)
|
|
1391
1366
|
# wrap in a list, because that's the structure of tf_batch_data
|
|
@@ -1424,7 +1399,7 @@ class TED(TransformerRasaModel):
|
|
|
1424
1399
|
|
|
1425
1400
|
@staticmethod
|
|
1426
1401
|
def _collect_label_attribute_encodings(
|
|
1427
|
-
all_labels_encoded: Dict[Text, tf.Tensor]
|
|
1402
|
+
all_labels_encoded: Dict[Text, tf.Tensor],
|
|
1428
1403
|
) -> tf.Tensor:
|
|
1429
1404
|
# Initialize with at least one attribute first
|
|
1430
1405
|
# so that the subsequent TF ops are simplified.
|
|
@@ -1953,7 +1928,6 @@ class TED(TransformerRasaModel):
|
|
|
1953
1928
|
text_output: tf.Tensor,
|
|
1954
1929
|
text_sequence_lengths: tf.Tensor,
|
|
1955
1930
|
) -> tf.Tensor:
|
|
1956
|
-
|
|
1957
1931
|
text_transformed, text_mask, text_sequence_lengths = self._reshape_for_entities(
|
|
1958
1932
|
tf_batch_data,
|
|
1959
1933
|
dialogue_transformer_output,
|
|
@@ -2156,7 +2130,6 @@ class TED(TransformerRasaModel):
|
|
|
2156
2130
|
text_output: tf.Tensor,
|
|
2157
2131
|
text_sequence_lengths: tf.Tensor,
|
|
2158
2132
|
) -> Tuple[tf.Tensor, tf.Tensor]:
|
|
2159
|
-
|
|
2160
2133
|
text_transformed, _, text_sequence_lengths = self._reshape_for_entities(
|
|
2161
2134
|
tf_batch_data,
|
|
2162
2135
|
dialogue_transformer_output,
|