rasa-pro 3.8.17__py3-none-any.whl → 3.9.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +5 -5
- rasa/__main__.py +14 -9
- rasa/anonymization/anonymization_pipeline.py +0 -1
- rasa/anonymization/anonymization_rule_executor.py +3 -3
- rasa/anonymization/utils.py +4 -3
- rasa/api.py +2 -2
- rasa/cli/arguments/default_arguments.py +1 -1
- rasa/cli/arguments/run.py +2 -2
- rasa/cli/arguments/test.py +1 -1
- rasa/cli/arguments/train.py +10 -10
- rasa/cli/e2e_test.py +27 -7
- rasa/cli/export.py +0 -1
- rasa/cli/license.py +3 -3
- rasa/cli/project_templates/calm/actions/action_template.py +1 -1
- rasa/cli/project_templates/calm/config.yml +1 -1
- rasa/cli/project_templates/calm/credentials.yml +1 -1
- rasa/cli/project_templates/calm/data/flows/add_contact.yml +1 -1
- rasa/cli/project_templates/calm/data/flows/remove_contact.yml +1 -1
- rasa/cli/project_templates/calm/domain/add_contact.yml +8 -2
- rasa/cli/project_templates/calm/domain/list_contacts.yml +3 -0
- rasa/cli/project_templates/calm/domain/remove_contact.yml +9 -2
- rasa/cli/project_templates/calm/domain/shared.yml +5 -0
- rasa/cli/project_templates/calm/endpoints.yml +4 -4
- rasa/cli/project_templates/default/actions/actions.py +1 -1
- rasa/cli/project_templates/default/config.yml +5 -5
- rasa/cli/project_templates/default/credentials.yml +1 -1
- rasa/cli/project_templates/default/endpoints.yml +4 -4
- rasa/cli/project_templates/default/tests/test_stories.yml +1 -1
- rasa/cli/project_templates/tutorial/config.yml +1 -1
- rasa/cli/project_templates/tutorial/credentials.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +6 -0
- rasa/cli/project_templates/tutorial/domain.yml +4 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +6 -6
- rasa/cli/run.py +0 -1
- rasa/cli/scaffold.py +3 -2
- rasa/cli/studio/download.py +11 -0
- rasa/cli/studio/studio.py +180 -24
- rasa/cli/studio/upload.py +0 -8
- rasa/cli/telemetry.py +18 -6
- rasa/cli/utils.py +21 -10
- rasa/cli/x.py +3 -2
- rasa/core/actions/action.py +90 -315
- rasa/core/actions/action_exceptions.py +24 -0
- rasa/core/actions/constants.py +3 -0
- rasa/core/actions/custom_action_executor.py +188 -0
- rasa/core/actions/forms.py +11 -7
- rasa/core/actions/grpc_custom_action_executor.py +251 -0
- rasa/core/actions/http_custom_action_executor.py +140 -0
- rasa/core/actions/loops.py +3 -0
- rasa/core/actions/two_stage_fallback.py +1 -1
- rasa/core/agent.py +2 -4
- rasa/core/brokers/pika.py +1 -2
- rasa/core/channels/audiocodes.py +1 -1
- rasa/core/channels/botframework.py +0 -1
- rasa/core/channels/callback.py +0 -1
- rasa/core/channels/console.py +6 -8
- rasa/core/channels/development_inspector.py +1 -1
- rasa/core/channels/facebook.py +0 -3
- rasa/core/channels/hangouts.py +0 -6
- rasa/core/channels/inspector/dist/assets/{arc-5623b6dc.js → arc-b6e548fe.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-685c106a.js → c4Diagram-d0fbc5ce-fa03ac9e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-8cbed007.js → classDiagram-936ed81e-ee67392a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-5889cf12.js → classDiagram-v2-c3cb15f1-9b283fae.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-24c249d7.js → createText-62fc7601-8b6fcc2a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-7dd06a75.js → edges-f2ad444c-22e77f4f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-62c1e54c.js → erDiagram-9d236eb7-60ffc87f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-ce49b86f.js → flowDb-1972c806-9dd802e4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-4067e48f.js → flowDiagram-7ea5b25a-5fa1912f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-59fe4051.js → flowchart-elk-definition-abe16c3d-622a1fd2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-47e3a43b.js → ganttDiagram-9b5ea136-e285a63a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-5a2ac0d9.js → gitGraphDiagram-99d0ae7c-f237bdca.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-dfb8efc4.js → index-2c4b9a3b-4b03d70e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-268a75c0.js → index-a5d3e69d.js} +4 -4
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-b0c470f2.js → infoDiagram-736b4530-72a0fa5f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-2edb829a.js → journeyDiagram-df861f2b-82218c41.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-b6873d69.js → layout-78cff630.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-1efc5781.js → line-5038b469.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-661e9b94.js → linear-c4fc4098.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-2d2e727f.js → mindmap-definition-beec6740-c33c8ea6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-9d3ea93d.js → pieDiagram-dbbf0591-a8d03059.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-06a178a2.js → quadrantDiagram-4d7f4fd6-6a0e56b2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-0bfedffc.js → requirementDiagram-6fc4c22a-2dc7c7bd.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-d76d0a04.js → sankeyDiagram-8f13d901-2360fe39.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-37bb4341.js → sequenceDiagram-b655622a-41b9f9ad.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-f52f7f57.js → stateDiagram-59f0c015-0aad326f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-4a986a20.js → stateDiagram-v2-2b26beab-9847d984.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-7dd9ae12.js → styles-080da4f6-564d890e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-46e1ca14.js → styles-3dcbcfbf-38957613.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-4a97439a.js → styles-9c745c82-f0fc6921.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-823917a3.js → svgDrawCommon-4835440b-ef3c5a77.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-9ea72896.js → timeline-definition-5b62e21b-bf3e91c1.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-b631a8b6.js → xychartDiagram-2b33534f-4d4026c0.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +10 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +4 -7
- rasa/core/channels/inspector/src/helpers/formatters.ts +3 -2
- rasa/core/channels/rest.py +36 -21
- rasa/core/channels/rocketchat.py +0 -1
- rasa/core/channels/socketio.py +1 -1
- rasa/core/channels/telegram.py +3 -3
- rasa/core/channels/webexteams.py +0 -1
- rasa/core/concurrent_lock_store.py +1 -1
- rasa/core/evaluation/marker_base.py +1 -3
- rasa/core/evaluation/marker_stats.py +1 -2
- rasa/core/featurizers/single_state_featurizer.py +2 -4
- rasa/core/featurizers/tracker_featurizers.py +0 -7
- rasa/core/information_retrieval/__init__.py +7 -0
- rasa/core/information_retrieval/faiss.py +9 -4
- rasa/core/information_retrieval/information_retrieval.py +64 -7
- rasa/core/information_retrieval/milvus.py +7 -14
- rasa/core/information_retrieval/qdrant.py +8 -15
- rasa/core/lock_store.py +0 -1
- rasa/core/migrate.py +1 -2
- rasa/core/nlg/callback.py +3 -4
- rasa/core/policies/enterprise_search_policy.py +86 -22
- rasa/core/policies/enterprise_search_prompt_template.jinja2 +4 -41
- rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
- rasa/core/policies/flows/flow_executor.py +104 -2
- rasa/core/policies/intentless_policy.py +7 -9
- rasa/core/policies/memoization.py +3 -3
- rasa/core/policies/policy.py +18 -9
- rasa/core/policies/rule_policy.py +8 -11
- rasa/core/policies/ted_policy.py +28 -30
- rasa/core/policies/unexpected_intent_policy.py +1 -2
- rasa/core/processor.py +136 -47
- rasa/core/run.py +41 -25
- rasa/core/secrets_manager/endpoints.py +2 -2
- rasa/core/secrets_manager/vault.py +6 -8
- rasa/core/test.py +3 -5
- rasa/core/tracker_store.py +49 -14
- rasa/core/train.py +1 -3
- rasa/core/training/interactive.py +9 -6
- rasa/core/utils.py +5 -10
- rasa/dialogue_understanding/coexistence/intent_based_router.py +11 -4
- rasa/dialogue_understanding/coexistence/llm_based_router.py +2 -3
- rasa/dialogue_understanding/commands/__init__.py +4 -0
- rasa/dialogue_understanding/commands/can_not_handle_command.py +9 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +9 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +38 -0
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/clarify_command.py +9 -0
- rasa/dialogue_understanding/commands/correct_slots_command.py +9 -0
- rasa/dialogue_understanding/commands/error_command.py +12 -0
- rasa/dialogue_understanding/commands/handle_code_change_command.py +9 -0
- rasa/dialogue_understanding/commands/human_handoff_command.py +9 -0
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/noop_command.py +9 -0
- rasa/dialogue_understanding/commands/set_slot_command.py +34 -3
- rasa/dialogue_understanding/commands/skip_question_command.py +9 -0
- rasa/dialogue_understanding/commands/start_flow_command.py +9 -0
- rasa/dialogue_understanding/generator/__init__.py +16 -1
- rasa/dialogue_understanding/generator/command_generator.py +92 -6
- rasa/dialogue_understanding/generator/constants.py +18 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +7 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +467 -0
- rasa/dialogue_understanding/generator/llm_command_generator.py +39 -609
- rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
- rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +827 -0
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +69 -8
- rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +345 -0
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +44 -39
- rasa/dialogue_understanding/processor/command_processor.py +111 -3
- rasa/e2e_test/constants.py +1 -0
- rasa/e2e_test/e2e_test_case.py +44 -0
- rasa/e2e_test/e2e_test_runner.py +114 -11
- rasa/e2e_test/e2e_test_schema.yml +18 -0
- rasa/engine/caching.py +0 -1
- rasa/engine/graph.py +18 -6
- rasa/engine/recipes/config_files/default_config.yml +3 -3
- rasa/engine/recipes/default_components.py +1 -1
- rasa/engine/recipes/default_recipe.py +4 -5
- rasa/engine/recipes/recipe.py +1 -1
- rasa/engine/runner/dask.py +3 -9
- rasa/engine/storage/local_model_storage.py +0 -2
- rasa/engine/validation.py +179 -145
- rasa/exceptions.py +2 -2
- rasa/graph_components/validators/default_recipe_validator.py +3 -5
- rasa/hooks.py +0 -1
- rasa/model.py +1 -1
- rasa/model_training.py +1 -0
- rasa/nlu/classifiers/diet_classifier.py +8 -14
- rasa/nlu/extractors/crf_entity_extractor.py +4 -4
- rasa/nlu/extractors/duckling_entity_extractor.py +1 -1
- rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +1 -5
- rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +0 -4
- rasa/nlu/featurizers/featurizer.py +1 -1
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +2 -4
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +9 -12
- rasa/nlu/persistor.py +68 -26
- rasa/nlu/selectors/response_selector.py +7 -10
- rasa/nlu/test.py +0 -3
- rasa/nlu/utils/hugging_face/registry.py +1 -1
- rasa/nlu/utils/spacy_utils.py +1 -3
- rasa/server.py +22 -7
- rasa/shared/constants.py +12 -1
- rasa/shared/core/command_payload_reader.py +109 -0
- rasa/shared/core/constants.py +4 -5
- rasa/shared/core/domain.py +57 -56
- rasa/shared/core/events.py +4 -7
- rasa/shared/core/flows/flow.py +9 -0
- rasa/shared/core/flows/flows_list.py +12 -0
- rasa/shared/core/flows/steps/action.py +7 -2
- rasa/shared/core/generator.py +12 -11
- rasa/shared/core/slot_mappings.py +315 -24
- rasa/shared/core/slots.py +4 -2
- rasa/shared/core/trackers.py +32 -14
- rasa/shared/core/training_data/loading.py +0 -1
- rasa/shared/core/training_data/story_reader/story_reader.py +3 -3
- rasa/shared/core/training_data/story_reader/yaml_story_reader.py +11 -11
- rasa/shared/core/training_data/story_writer/yaml_story_writer.py +5 -3
- rasa/shared/core/training_data/structures.py +1 -1
- rasa/shared/core/training_data/visualization.py +1 -1
- rasa/shared/data.py +58 -1
- rasa/shared/exceptions.py +36 -2
- rasa/shared/importers/importer.py +1 -2
- rasa/shared/importers/rasa.py +0 -1
- rasa/shared/nlu/constants.py +2 -0
- rasa/shared/nlu/training_data/entities_parser.py +1 -2
- rasa/shared/nlu/training_data/formats/dialogflow.py +3 -2
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +3 -5
- rasa/shared/nlu/training_data/formats/readerwriter.py +0 -1
- rasa/shared/nlu/training_data/message.py +13 -0
- rasa/shared/nlu/training_data/training_data.py +0 -2
- rasa/shared/providers/openai/session_handler.py +2 -2
- rasa/shared/utils/constants.py +3 -0
- rasa/shared/utils/io.py +11 -0
- rasa/shared/utils/llm.py +1 -2
- rasa/shared/utils/pykwalify_extensions.py +1 -0
- rasa/shared/utils/schemas/domain.yml +3 -0
- rasa/shared/utils/yaml.py +44 -35
- rasa/studio/auth.py +26 -10
- rasa/studio/constants.py +2 -0
- rasa/studio/data_handler.py +114 -107
- rasa/studio/download.py +160 -27
- rasa/studio/results_logger.py +137 -0
- rasa/studio/train.py +6 -7
- rasa/studio/upload.py +159 -134
- rasa/telemetry.py +188 -34
- rasa/tracing/config.py +18 -3
- rasa/tracing/constants.py +26 -2
- rasa/tracing/instrumentation/attribute_extractors.py +50 -41
- rasa/tracing/instrumentation/instrumentation.py +290 -44
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +7 -5
- rasa/tracing/instrumentation/metrics.py +109 -21
- rasa/tracing/metric_instrument_provider.py +83 -3
- rasa/utils/cli.py +2 -1
- rasa/utils/common.py +1 -1
- rasa/utils/endpoints.py +1 -2
- rasa/utils/io.py +6 -6
- rasa/utils/licensing.py +246 -31
- rasa/utils/ml_utils.py +1 -1
- rasa/utils/tensorflow/data_generator.py +1 -1
- rasa/utils/tensorflow/environment.py +1 -1
- rasa/utils/tensorflow/model_data.py +9 -11
- rasa/utils/tensorflow/model_data_utils.py +499 -500
- rasa/utils/tensorflow/models.py +5 -6
- rasa/utils/tensorflow/rasa_layers.py +15 -15
- rasa/utils/train_utils.py +1 -1
- rasa/utils/url_tools.py +53 -0
- rasa/validator.py +305 -3
- rasa/version.py +1 -1
- {rasa_pro-3.8.17.dist-info → rasa_pro-3.9.14.dist-info}/METADATA +22 -22
- {rasa_pro-3.8.17.dist-info → rasa_pro-3.9.14.dist-info}/RECORD +271 -253
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-85583a23.js +0 -1
- /rasa/dialogue_understanding/generator/{command_prompt_template.jinja2 → single_step/command_prompt_template.jinja2} +0 -0
- {rasa_pro-3.8.17.dist-info → rasa_pro-3.9.14.dist-info}/NOTICE +0 -0
- {rasa_pro-3.8.17.dist-info → rasa_pro-3.9.14.dist-info}/WHEEL +0 -0
- {rasa_pro-3.8.17.dist-info → rasa_pro-3.9.14.dist-info}/entry_points.txt +0 -0
rasa/core/run.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import logging
|
|
3
|
-
import uuid
|
|
4
|
-
import platform
|
|
5
3
|
import os
|
|
4
|
+
import platform
|
|
5
|
+
import uuid
|
|
6
|
+
import warnings
|
|
7
|
+
from asyncio import AbstractEventLoop
|
|
6
8
|
from functools import partial
|
|
7
9
|
from typing import (
|
|
8
10
|
Any,
|
|
@@ -15,11 +17,14 @@ from typing import (
|
|
|
15
17
|
Dict,
|
|
16
18
|
)
|
|
17
19
|
|
|
20
|
+
from sanic import Sanic
|
|
21
|
+
from sanic.worker.loader import AppLoader
|
|
22
|
+
|
|
18
23
|
import rasa.core.utils
|
|
19
|
-
from rasa.plugin import plugin_manager
|
|
20
|
-
from rasa.shared.exceptions import RasaException
|
|
21
24
|
import rasa.shared.utils.common
|
|
25
|
+
import rasa.shared.utils.io
|
|
22
26
|
import rasa.utils
|
|
27
|
+
from rasa.utils import licensing
|
|
23
28
|
import rasa.utils.common
|
|
24
29
|
import rasa.utils.io
|
|
25
30
|
from rasa import server, telemetry
|
|
@@ -29,10 +34,8 @@ from rasa.core.agent import Agent
|
|
|
29
34
|
from rasa.core.channels import console
|
|
30
35
|
from rasa.core.channels.channel import InputChannel
|
|
31
36
|
from rasa.core.utils import AvailableEndpoints
|
|
32
|
-
|
|
33
|
-
from
|
|
34
|
-
from asyncio import AbstractEventLoop
|
|
35
|
-
|
|
37
|
+
from rasa.plugin import plugin_manager
|
|
38
|
+
from rasa.shared.exceptions import RasaException
|
|
36
39
|
from rasa.shared.utils.yaml import read_config_file
|
|
37
40
|
|
|
38
41
|
logger = logging.getLogger() # get the root logger
|
|
@@ -84,6 +87,10 @@ def _create_single_channel(channel: Text, credentials: Dict[Text, Any]) -> Any:
|
|
|
84
87
|
|
|
85
88
|
def _create_app_without_api(cors: Optional[Union[Text, List[Text]]] = None) -> Sanic:
|
|
86
89
|
app = Sanic("rasa_core_no_api", configure_logging=False)
|
|
90
|
+
|
|
91
|
+
# Reset Sanic warnings filter that allows the triggering of Sanic warnings
|
|
92
|
+
warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"sanic.*")
|
|
93
|
+
|
|
87
94
|
server.add_root_route(app)
|
|
88
95
|
server.configure_cors(app, cors)
|
|
89
96
|
return app
|
|
@@ -126,19 +133,24 @@ def configure_app(
|
|
|
126
133
|
)
|
|
127
134
|
|
|
128
135
|
if enable_api:
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
136
|
+
loader = AppLoader(
|
|
137
|
+
factory=partial(
|
|
138
|
+
server.create_app,
|
|
139
|
+
cors_origins=cors,
|
|
140
|
+
auth_token=auth_token,
|
|
141
|
+
response_timeout=response_timeout,
|
|
142
|
+
jwt_secret=jwt_secret,
|
|
143
|
+
jwt_private_key=jwt_private_key,
|
|
144
|
+
jwt_method=jwt_method,
|
|
145
|
+
endpoints=endpoints,
|
|
146
|
+
)
|
|
137
147
|
)
|
|
138
148
|
else:
|
|
139
|
-
|
|
149
|
+
loader = AppLoader(factory=partial(_create_app_without_api, cors))
|
|
140
150
|
|
|
151
|
+
app = loader.load()
|
|
141
152
|
app.config.KEEP_ALIVE_TIMEOUT = keep_alive_timeout
|
|
153
|
+
|
|
142
154
|
if _is_apple_silicon_system() or not use_uvloop:
|
|
143
155
|
app.config.USE_UVLOOP = False
|
|
144
156
|
# some library still sets the loop to uvloop, even if disabled for sanic
|
|
@@ -154,18 +166,16 @@ def configure_app(
|
|
|
154
166
|
if logger.isEnabledFor(logging.DEBUG):
|
|
155
167
|
rasa.core.utils.list_routes(app)
|
|
156
168
|
|
|
157
|
-
|
|
169
|
+
@app.main_process_start
|
|
170
|
+
async def configure_async_logging(running_app: Sanic) -> None:
|
|
158
171
|
if logger.isEnabledFor(logging.DEBUG):
|
|
159
172
|
rasa.utils.io.enable_async_loop_debugging(asyncio.get_event_loop())
|
|
160
173
|
|
|
161
|
-
app.add_task(configure_async_logging)
|
|
162
|
-
|
|
163
174
|
if "cmdline" in {c.name() for c in input_channels}:
|
|
164
175
|
|
|
176
|
+
@app.after_server_start
|
|
165
177
|
async def run_cmdline_io(running_app: Sanic) -> None:
|
|
166
178
|
"""Small wrapper to shut down the server once cmd io is done."""
|
|
167
|
-
await asyncio.sleep(1) # allow server to start
|
|
168
|
-
|
|
169
179
|
await console.record_messages(
|
|
170
180
|
server_url=constants.DEFAULT_SERVER_FORMAT.format("http", port),
|
|
171
181
|
sender_id=conversation_id,
|
|
@@ -174,12 +184,13 @@ def configure_app(
|
|
|
174
184
|
|
|
175
185
|
logger.info("Killing Sanic server now.")
|
|
176
186
|
running_app.stop() # kill the sanic server
|
|
177
|
-
plugin_manager().hook.after_server_stop()
|
|
178
187
|
|
|
179
|
-
|
|
188
|
+
@app.after_server_stop
|
|
189
|
+
async def after_server_stop(running_app: Sanic) -> None:
|
|
190
|
+
plugin_manager().hook.after_server_stop()
|
|
180
191
|
|
|
181
192
|
if server_listeners:
|
|
182
|
-
for
|
|
193
|
+
for listener, event in server_listeners:
|
|
183
194
|
app.register_listener(listener, event)
|
|
184
195
|
|
|
185
196
|
return app
|
|
@@ -252,6 +263,10 @@ def serve_application(
|
|
|
252
263
|
"before_server_start",
|
|
253
264
|
)
|
|
254
265
|
|
|
266
|
+
app.register_listener(
|
|
267
|
+
licensing.validate_limited_server_license, "after_server_start"
|
|
268
|
+
)
|
|
269
|
+
|
|
255
270
|
app.register_listener(close_resources, "after_server_stop")
|
|
256
271
|
|
|
257
272
|
number_of_workers = rasa.core.utils.number_of_sanic_workers(
|
|
@@ -272,6 +287,7 @@ def serve_application(
|
|
|
272
287
|
ssl=ssl_context,
|
|
273
288
|
backlog=int(os.environ.get(ENV_SANIC_BACKLOG, "100")),
|
|
274
289
|
workers=number_of_workers,
|
|
290
|
+
legacy=True,
|
|
275
291
|
)
|
|
276
292
|
|
|
277
293
|
|
|
@@ -47,7 +47,7 @@ class CredentialsLocation:
|
|
|
47
47
|
|
|
48
48
|
@staticmethod
|
|
49
49
|
def is_credentials_location_instance(
|
|
50
|
-
value: Union[Text, "CredentialsLocation"]
|
|
50
|
+
value: Union[Text, "CredentialsLocation"],
|
|
51
51
|
) -> bool:
|
|
52
52
|
"""Check if the value is a CredentialsLocation.
|
|
53
53
|
|
|
@@ -85,7 +85,7 @@ class CredentialsLocation:
|
|
|
85
85
|
|
|
86
86
|
@staticmethod
|
|
87
87
|
def is_credentials_location_valid(
|
|
88
|
-
raw_credentials_location: Dict[Text, Text]
|
|
88
|
+
raw_credentials_location: Dict[Text, Text],
|
|
89
89
|
) -> bool:
|
|
90
90
|
"""Check if the configuration is a secret manager configuration.
|
|
91
91
|
|
|
@@ -124,7 +124,6 @@ class VaultEndpointConfigReader:
|
|
|
124
124
|
credentials_location.get_secret_manager_name()
|
|
125
125
|
== VAULT_SECRET_MANAGER_NAME
|
|
126
126
|
):
|
|
127
|
-
|
|
128
127
|
return VaultCredentialsLocation.from_credentials_location(
|
|
129
128
|
credentials_location=credentials_location
|
|
130
129
|
)
|
|
@@ -161,11 +160,10 @@ class VaultEndpointConfigReader:
|
|
|
161
160
|
credentials_location
|
|
162
161
|
)
|
|
163
162
|
):
|
|
164
|
-
|
|
165
163
|
if credentials_location.transit_key:
|
|
166
|
-
transit_keys[
|
|
167
|
-
credentials_location.
|
|
168
|
-
|
|
164
|
+
transit_keys[credentials_location.secret_key] = (
|
|
165
|
+
credentials_location.transit_key
|
|
166
|
+
)
|
|
169
167
|
|
|
170
168
|
return transit_keys if transit_keys else None
|
|
171
169
|
|
|
@@ -357,9 +355,9 @@ class VaultTokenManager:
|
|
|
357
355
|
|
|
358
356
|
def start(self) -> None:
|
|
359
357
|
"""Start refreshing the token if it is expiring."""
|
|
360
|
-
renew_response: Dict[
|
|
361
|
-
|
|
362
|
-
|
|
358
|
+
renew_response: Dict[Text, Dict[Text, Any]] = (
|
|
359
|
+
self.client.auth.token.lookup_self()
|
|
360
|
+
)
|
|
363
361
|
is_token_expiring = renew_response["data"]["renewable"]
|
|
364
362
|
if is_token_expiring:
|
|
365
363
|
refresh_interval_in_seconds = renew_response["data"]["creation_ttl"]
|
rasa/core/test.py
CHANGED
|
@@ -300,14 +300,14 @@ class EvaluationStore:
|
|
|
300
300
|
filter(
|
|
301
301
|
lambda x: x.get(ENTITY_ATTRIBUTE_TEXT) == text, self.entity_targets
|
|
302
302
|
),
|
|
303
|
-
key=lambda x: x[ENTITY_ATTRIBUTE_START], # type: ignore[literal-required]
|
|
303
|
+
key=lambda x: x[ENTITY_ATTRIBUTE_START], # type: ignore[literal-required]
|
|
304
304
|
)
|
|
305
305
|
entity_predictions = sorted(
|
|
306
306
|
filter(
|
|
307
307
|
lambda x: x.get(ENTITY_ATTRIBUTE_TEXT) == text,
|
|
308
308
|
self.entity_predictions,
|
|
309
309
|
),
|
|
310
|
-
key=lambda x: x[ENTITY_ATTRIBUTE_START], # type: ignore[literal-required]
|
|
310
|
+
key=lambda x: x[ENTITY_ATTRIBUTE_START], # type: ignore[literal-required]
|
|
311
311
|
)
|
|
312
312
|
|
|
313
313
|
i_pred, i_target = 0, 0
|
|
@@ -461,7 +461,7 @@ def _clean_entity_results(
|
|
|
461
461
|
cleaned_entities = []
|
|
462
462
|
|
|
463
463
|
for r in tuple(entity_results):
|
|
464
|
-
cleaned_entity: EntityPrediction = {ENTITY_ATTRIBUTE_TEXT: text} # type: ignore[misc]
|
|
464
|
+
cleaned_entity: EntityPrediction = {ENTITY_ATTRIBUTE_TEXT: text} # type: ignore[misc]
|
|
465
465
|
for k in (
|
|
466
466
|
ENTITY_ATTRIBUTE_START,
|
|
467
467
|
ENTITY_ATTRIBUTE_END,
|
|
@@ -706,7 +706,6 @@ async def _collect_action_executed_predictions(
|
|
|
706
706
|
event: ActionExecuted,
|
|
707
707
|
fail_on_prediction_errors: bool,
|
|
708
708
|
) -> Tuple[EvaluationStore, PolicyPrediction, Optional[EntityEvaluationResult]]:
|
|
709
|
-
|
|
710
709
|
action_executed_eval_store = EvaluationStore()
|
|
711
710
|
|
|
712
711
|
expected_action_name = event.action_name
|
|
@@ -825,7 +824,6 @@ async def _predict_tracker_actions(
|
|
|
825
824
|
List[Dict[Text, Any]],
|
|
826
825
|
List[EntityEvaluationResult],
|
|
827
826
|
]:
|
|
828
|
-
|
|
829
827
|
processor = agent.processor
|
|
830
828
|
if agent.processor is not None:
|
|
831
829
|
processor = agent.processor
|
rasa/core/tracker_store.py
CHANGED
|
@@ -276,10 +276,10 @@ class TrackerStore:
|
|
|
276
276
|
async def retrieve_full_tracker(
|
|
277
277
|
self, conversation_id: Text
|
|
278
278
|
) -> Optional[DialogueStateTracker]:
|
|
279
|
-
"""Retrieve method for fetching all tracker events
|
|
280
|
-
that may be overridden by specific tracker.
|
|
279
|
+
"""Retrieve method for fetching all tracker events.
|
|
281
280
|
|
|
282
|
-
The default implementation
|
|
281
|
+
Fetches events across conversation sessions. The default implementation
|
|
282
|
+
uses `self.retrieve()`.
|
|
283
283
|
|
|
284
284
|
Args:
|
|
285
285
|
conversation_id: The conversation ID to retrieve the tracker for.
|
|
@@ -339,6 +339,28 @@ class TrackerStore:
|
|
|
339
339
|
"""Returns the set of values for the tracker store's primary key."""
|
|
340
340
|
raise NotImplementedError()
|
|
341
341
|
|
|
342
|
+
async def count_conversations(self, after_timestamp: float = 0.0) -> int:
|
|
343
|
+
"""Returns the number of conversations that have occurred after a timestamp.
|
|
344
|
+
|
|
345
|
+
By default, this method returns the number of conversations that
|
|
346
|
+
have occurred after the Unix epoch (i.e. timestamp 0). A conversation
|
|
347
|
+
is considered to have occurred after a timestamp if at least one event
|
|
348
|
+
happened after that timestamp.
|
|
349
|
+
"""
|
|
350
|
+
tracker_keys = await self.keys()
|
|
351
|
+
|
|
352
|
+
conversation_count = 0
|
|
353
|
+
for key in tracker_keys:
|
|
354
|
+
tracker = await self.retrieve(key)
|
|
355
|
+
if tracker is None or not tracker.events:
|
|
356
|
+
continue
|
|
357
|
+
|
|
358
|
+
last_event = tracker.events[-1]
|
|
359
|
+
if last_event.timestamp >= after_timestamp:
|
|
360
|
+
conversation_count += 1
|
|
361
|
+
|
|
362
|
+
return conversation_count
|
|
363
|
+
|
|
342
364
|
def deserialise_tracker(
|
|
343
365
|
self, sender_id: Text, serialised_tracker: Union[Text, bytes]
|
|
344
366
|
) -> Optional[DialogueStateTracker]:
|
|
@@ -930,7 +952,7 @@ def _create_sequence(table_name: Text) -> "Sequence":
|
|
|
930
952
|
"""Creates a sequence object for a specific table name.
|
|
931
953
|
|
|
932
954
|
If using Oracle you will need to create a sequence in your database,
|
|
933
|
-
as described here: https://rasa.com/docs/rasa/tracker-stores#sqltrackerstore
|
|
955
|
+
as described here: https://rasa.com/docs/rasa-pro/production/tracker-stores#sqltrackerstore
|
|
934
956
|
Args:
|
|
935
957
|
table_name: The name of the table, which gets a Sequence assigned
|
|
936
958
|
|
|
@@ -1045,6 +1067,8 @@ class SQLTrackerStore(TrackerStore, SerializedTrackerAsText):
|
|
|
1045
1067
|
from sqlalchemy.orm import DeclarativeBase
|
|
1046
1068
|
|
|
1047
1069
|
class Base(DeclarativeBase):
|
|
1070
|
+
"""Base class for all tracker store tables."""
|
|
1071
|
+
|
|
1048
1072
|
pass
|
|
1049
1073
|
|
|
1050
1074
|
class SQLEvent(Base):
|
|
@@ -1113,7 +1137,6 @@ class SQLTrackerStore(TrackerStore, SerializedTrackerAsText):
|
|
|
1113
1137
|
sqlalchemy.exc.OperationalError,
|
|
1114
1138
|
sqlalchemy.exc.IntegrityError,
|
|
1115
1139
|
) as error:
|
|
1116
|
-
|
|
1117
1140
|
logger.warning(error)
|
|
1118
1141
|
sleep(5)
|
|
1119
1142
|
|
|
@@ -1132,8 +1155,10 @@ class SQLTrackerStore(TrackerStore, SerializedTrackerAsText):
|
|
|
1132
1155
|
login_db: Optional[Text] = None,
|
|
1133
1156
|
query: Optional[Dict] = None,
|
|
1134
1157
|
) -> Union[Text, "URL"]:
|
|
1135
|
-
"""Build an SQLAlchemy `URL` object
|
|
1136
|
-
|
|
1158
|
+
"""Build an SQLAlchemy `URL` object.
|
|
1159
|
+
|
|
1160
|
+
The URL object represents the parameters needed to connect to an
|
|
1161
|
+
SQL database.
|
|
1137
1162
|
|
|
1138
1163
|
Args:
|
|
1139
1164
|
dialect: SQL database type.
|
|
@@ -1260,11 +1285,24 @@ class SQLTrackerStore(TrackerStore, SerializedTrackerAsText):
|
|
|
1260
1285
|
conversation_id, fetch_events_from_all_sessions=True
|
|
1261
1286
|
)
|
|
1262
1287
|
|
|
1288
|
+
async def count_conversations(self, after_timestamp: float = 0.0) -> int:
|
|
1289
|
+
"""Returns the number of conversations that have occurred after a timestamp.
|
|
1290
|
+
|
|
1291
|
+
By default, this method returns the number of conversations that
|
|
1292
|
+
have occurred after the Unix epoch (i.e. timestamp 0).
|
|
1293
|
+
"""
|
|
1294
|
+
with self.session_scope() as session:
|
|
1295
|
+
query = (
|
|
1296
|
+
session.query(self.SQLEvent.sender_id)
|
|
1297
|
+
.distinct()
|
|
1298
|
+
.filter(self.SQLEvent.timestamp >= after_timestamp)
|
|
1299
|
+
)
|
|
1300
|
+
return query.count()
|
|
1301
|
+
|
|
1263
1302
|
async def _retrieve(
|
|
1264
1303
|
self, sender_id: Text, fetch_events_from_all_sessions: bool
|
|
1265
1304
|
) -> Optional[DialogueStateTracker]:
|
|
1266
1305
|
with self.session_scope() as session:
|
|
1267
|
-
|
|
1268
1306
|
serialised_events = self._event_query(
|
|
1269
1307
|
session,
|
|
1270
1308
|
sender_id,
|
|
@@ -1290,6 +1328,7 @@ class SQLTrackerStore(TrackerStore, SerializedTrackerAsText):
|
|
|
1290
1328
|
self, session: "Session", sender_id: Text, fetch_events_from_all_sessions: bool
|
|
1291
1329
|
) -> "Query":
|
|
1292
1330
|
"""Provide the query to retrieve the conversation events for a specific sender.
|
|
1331
|
+
|
|
1293
1332
|
The events are ordered by ID to ensure correct sequence of events.
|
|
1294
1333
|
As `timestamp` is not guaranteed to be unique and low-precision (float), it
|
|
1295
1334
|
cannot be used to order the events.
|
|
@@ -1637,9 +1676,7 @@ class AwaitableTrackerStore(TrackerStore):
|
|
|
1637
1676
|
"""Wrapper to call `retrieve` method of primary tracker store."""
|
|
1638
1677
|
result = self._tracker_store.retrieve(sender_id)
|
|
1639
1678
|
return (
|
|
1640
|
-
await result
|
|
1641
|
-
if isawaitable(result)
|
|
1642
|
-
else result # type: ignore[return-value]
|
|
1679
|
+
await result if isawaitable(result) else result # type: ignore[return-value, misc]
|
|
1643
1680
|
)
|
|
1644
1681
|
|
|
1645
1682
|
async def keys(self) -> Iterable[Text]:
|
|
@@ -1658,7 +1695,5 @@ class AwaitableTrackerStore(TrackerStore):
|
|
|
1658
1695
|
"""Wrapper to call `retrieve_full_tracker` method of primary tracker store."""
|
|
1659
1696
|
result = self._tracker_store.retrieve_full_tracker(conversation_id)
|
|
1660
1697
|
return (
|
|
1661
|
-
await result
|
|
1662
|
-
if isawaitable(result)
|
|
1663
|
-
else result # type: ignore[return-value]
|
|
1698
|
+
await result if isawaitable(result) else result # type: ignore[return-value, misc]
|
|
1664
1699
|
)
|
rasa/core/train.py
CHANGED
|
@@ -34,9 +34,7 @@ async def train_comparison_models(
|
|
|
34
34
|
for policy_config in policy_configs:
|
|
35
35
|
config_name = os.path.splitext(os.path.basename(policy_config))[0]
|
|
36
36
|
logging.info(
|
|
37
|
-
"Starting to train {} round {}/{}"
|
|
38
|
-
" with {}% exclusion"
|
|
39
|
-
"".format(
|
|
37
|
+
"Starting to train {} round {}/{} with {}% exclusion".format(
|
|
40
38
|
config_name, current_run, len(exclusion_percentages), percentage
|
|
41
39
|
)
|
|
42
40
|
)
|
|
@@ -3,6 +3,7 @@ import logging
|
|
|
3
3
|
import os
|
|
4
4
|
import textwrap
|
|
5
5
|
import uuid
|
|
6
|
+
import warnings
|
|
6
7
|
from functools import partial
|
|
7
8
|
from multiprocessing import Process
|
|
8
9
|
from typing import (
|
|
@@ -336,7 +337,7 @@ async def _ask_questions(
|
|
|
336
337
|
|
|
337
338
|
|
|
338
339
|
def _selection_choices_from_intent_prediction(
|
|
339
|
-
predictions: List[Dict[Text, Any]]
|
|
340
|
+
predictions: List[Dict[Text, Any]],
|
|
340
341
|
) -> List[Dict[Text, Any]]:
|
|
341
342
|
"""Given a list of ML predictions create a UI choice list."""
|
|
342
343
|
sorted_intents = sorted(
|
|
@@ -762,7 +763,7 @@ async def _request_export_info() -> Tuple[Text, Text, Text]:
|
|
|
762
763
|
|
|
763
764
|
|
|
764
765
|
def _split_conversation_at_restarts(
|
|
765
|
-
events: List[Dict[Text, Any]]
|
|
766
|
+
events: List[Dict[Text, Any]],
|
|
766
767
|
) -> List[List[Dict[Text, Any]]]:
|
|
767
768
|
"""Split a conversation at restart events.
|
|
768
769
|
|
|
@@ -1600,6 +1601,7 @@ def _serve_application(
|
|
|
1600
1601
|
"""Start a core server and attach the interactive learning IO."""
|
|
1601
1602
|
endpoint = EndpointConfig(url=DEFAULT_SERVER_FORMAT.format("http", port))
|
|
1602
1603
|
|
|
1604
|
+
@app.after_server_start
|
|
1603
1605
|
async def run_interactive_io(running_app: Sanic) -> None:
|
|
1604
1606
|
"""Small wrapper to shut down the server once cmd io is done."""
|
|
1605
1607
|
await record_messages(
|
|
@@ -1613,11 +1615,9 @@ def _serve_application(
|
|
|
1613
1615
|
|
|
1614
1616
|
running_app.stop() # kill the sanic server
|
|
1615
1617
|
|
|
1616
|
-
app.add_task(run_interactive_io)
|
|
1617
|
-
|
|
1618
1618
|
update_sanic_log_level()
|
|
1619
1619
|
|
|
1620
|
-
app.run(host="0.0.0.0", port=port)
|
|
1620
|
+
app.run(host="0.0.0.0", port=port, legacy=True)
|
|
1621
1621
|
|
|
1622
1622
|
return app
|
|
1623
1623
|
|
|
@@ -1626,6 +1626,9 @@ def start_visualization(image_path: Text, port: int) -> None:
|
|
|
1626
1626
|
"""Add routes to serve the conversation visualization files."""
|
|
1627
1627
|
app = Sanic("rasa_interactive")
|
|
1628
1628
|
|
|
1629
|
+
# Reset Sanic warnings filter that allows the triggering of Sanic warnings
|
|
1630
|
+
warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"sanic.*")
|
|
1631
|
+
|
|
1629
1632
|
# noinspection PyUnusedLocal
|
|
1630
1633
|
@app.exception(NotFound)
|
|
1631
1634
|
async def ignore_404s(request: Request, exception: Exception) -> HTTPResponse:
|
|
@@ -1647,7 +1650,7 @@ def start_visualization(image_path: Text, port: int) -> None:
|
|
|
1647
1650
|
|
|
1648
1651
|
update_sanic_log_level()
|
|
1649
1652
|
|
|
1650
|
-
app.run(host="0.0.0.0", port=port, access_log=False)
|
|
1653
|
+
app.run(host="0.0.0.0", port=port, access_log=False, legacy=True)
|
|
1651
1654
|
|
|
1652
1655
|
|
|
1653
1656
|
def run_interactive_learning(
|
rasa/core/utils.py
CHANGED
|
@@ -73,8 +73,9 @@ def one_hot(hot_idx: int, length: int, dtype: Optional[Text] = None) -> np.ndarr
|
|
|
73
73
|
"""
|
|
74
74
|
if hot_idx >= length:
|
|
75
75
|
raise ValueError(
|
|
76
|
-
"Can't create one hot. Index '{}' is out "
|
|
77
|
-
|
|
76
|
+
"Can't create one hot. Index '{}' is out of range (length '{}')".format(
|
|
77
|
+
hot_idx, length
|
|
78
|
+
)
|
|
78
79
|
)
|
|
79
80
|
r = np.zeros(length, dtype)
|
|
80
81
|
r[hot_idx] = 1
|
|
@@ -159,12 +160,6 @@ def is_limit_reached(num_messages: int, limit: Optional[int]) -> bool:
|
|
|
159
160
|
return limit is not None and num_messages >= limit
|
|
160
161
|
|
|
161
162
|
|
|
162
|
-
def file_as_bytes(path: Text) -> bytes:
|
|
163
|
-
"""Read in a file as a byte array."""
|
|
164
|
-
with open(path, "rb") as f:
|
|
165
|
-
return f.read()
|
|
166
|
-
|
|
167
|
-
|
|
168
163
|
class AvailableEndpoints:
|
|
169
164
|
"""Collection of configured endpoints."""
|
|
170
165
|
|
|
@@ -216,7 +211,7 @@ class AvailableEndpoints:
|
|
|
216
211
|
|
|
217
212
|
|
|
218
213
|
def read_endpoints_from_path(
|
|
219
|
-
endpoints_path: Optional[Union[Path, Text]] = None
|
|
214
|
+
endpoints_path: Optional[Union[Path, Text]] = None,
|
|
220
215
|
) -> AvailableEndpoints:
|
|
221
216
|
"""Get `AvailableEndpoints` object from specified path.
|
|
222
217
|
|
|
@@ -283,7 +278,7 @@ def replace_decimals_with_floats(obj: Any) -> Any:
|
|
|
283
278
|
|
|
284
279
|
|
|
285
280
|
def _lock_store_is_multi_worker_compatible(
|
|
286
|
-
lock_store: Union[EndpointConfig, LockStore, None]
|
|
281
|
+
lock_store: Union[EndpointConfig, LockStore, None],
|
|
287
282
|
) -> bool:
|
|
288
283
|
if isinstance(lock_store, InMemoryLockStore):
|
|
289
284
|
return False
|
|
@@ -19,6 +19,7 @@ from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
|
19
19
|
from rasa.engine.storage.resource import Resource
|
|
20
20
|
from rasa.engine.storage.storage import ModelStorage
|
|
21
21
|
from rasa.shared.constants import ROUTE_TO_CALM_SLOT
|
|
22
|
+
from rasa.shared.core.domain import Domain
|
|
22
23
|
from rasa.shared.core.flows import FlowsList
|
|
23
24
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
24
25
|
from rasa.shared.exceptions import InvalidConfigException
|
|
@@ -111,6 +112,7 @@ class IntentBasedRouter(GraphComponent):
|
|
|
111
112
|
messages: List[Message],
|
|
112
113
|
flows: FlowsList,
|
|
113
114
|
tracker: Optional[DialogueStateTracker] = None,
|
|
115
|
+
domain: Optional[Domain] = None,
|
|
114
116
|
) -> List[Message]:
|
|
115
117
|
"""Process a list of messages."""
|
|
116
118
|
if tracker is None:
|
|
@@ -118,7 +120,7 @@ class IntentBasedRouter(GraphComponent):
|
|
|
118
120
|
return messages
|
|
119
121
|
|
|
120
122
|
for message in messages:
|
|
121
|
-
commands = await self.predict_commands(message, flows, tracker)
|
|
123
|
+
commands = await self.predict_commands(message, flows, tracker, domain)
|
|
122
124
|
commands_dicts = [command.as_dict() for command in commands]
|
|
123
125
|
message.set(COMMANDS, commands_dicts, add_to_output=True)
|
|
124
126
|
|
|
@@ -129,6 +131,7 @@ class IntentBasedRouter(GraphComponent):
|
|
|
129
131
|
message: Message,
|
|
130
132
|
flows: FlowsList,
|
|
131
133
|
tracker: DialogueStateTracker,
|
|
134
|
+
domain: Optional[Domain] = None,
|
|
132
135
|
) -> List[Command]:
|
|
133
136
|
if not tracker.has_coexistence_routing_slot:
|
|
134
137
|
raise InvalidConfigException(
|
|
@@ -144,8 +147,8 @@ class IntentBasedRouter(GraphComponent):
|
|
|
144
147
|
)
|
|
145
148
|
return commands
|
|
146
149
|
elif route_session_to_calm is True:
|
|
147
|
-
# don't set any commands so that
|
|
148
|
-
# and can predict the actual commands.
|
|
150
|
+
# don't set any commands so that a `LLMBasedCommandGenerator` is
|
|
151
|
+
# triggered and can predict the actual commands.
|
|
149
152
|
return []
|
|
150
153
|
else:
|
|
151
154
|
# If the session is assigned to DM1 add a `NoopCommand` to silence
|
|
@@ -156,7 +159,11 @@ class IntentBasedRouter(GraphComponent):
|
|
|
156
159
|
self, message: Message, tracker: DialogueStateTracker, flows: FlowsList
|
|
157
160
|
) -> bool:
|
|
158
161
|
"""Check if the intent is part of a nlu trigger."""
|
|
159
|
-
commands = NLUCommandAdapter.convert_nlu_to_commands(
|
|
162
|
+
commands = NLUCommandAdapter.convert_nlu_to_commands(
|
|
163
|
+
message,
|
|
164
|
+
tracker,
|
|
165
|
+
flows,
|
|
166
|
+
)
|
|
160
167
|
return len(commands) > 0
|
|
161
168
|
|
|
162
169
|
def _generate_command_using_intent(
|
|
@@ -13,7 +13,7 @@ from rasa.dialogue_understanding.coexistence.constants import (
|
|
|
13
13
|
)
|
|
14
14
|
from rasa.dialogue_understanding.commands import Command, SetSlotCommand
|
|
15
15
|
from rasa.dialogue_understanding.commands.noop_command import NoopCommand
|
|
16
|
-
from rasa.dialogue_understanding.generator.
|
|
16
|
+
from rasa.dialogue_understanding.generator.constants import LLM_CONFIG_KEY
|
|
17
17
|
from rasa.engine.graph import ExecutionContext, GraphComponent
|
|
18
18
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
19
19
|
from rasa.engine.storage.resource import Resource
|
|
@@ -196,7 +196,7 @@ class LLMBasedRouter(GraphComponent):
|
|
|
196
196
|
structlogger.info("llm_based_router.predicated_commands", commands=commands)
|
|
197
197
|
return commands
|
|
198
198
|
elif route_session_to_calm is True:
|
|
199
|
-
# don't set any commands so that
|
|
199
|
+
# don't set any commands so that a `LLMBasedCommandGenerator` is triggered
|
|
200
200
|
# and can predict the actual commands.
|
|
201
201
|
return []
|
|
202
202
|
else:
|
|
@@ -231,7 +231,6 @@ class LLMBasedRouter(GraphComponent):
|
|
|
231
231
|
return [SetSlotCommand(ROUTE_TO_CALM_SLOT, False)]
|
|
232
232
|
|
|
233
233
|
def render_template(self, message: Message) -> str:
|
|
234
|
-
|
|
235
234
|
inputs = {
|
|
236
235
|
"user_message": message.get(TEXT),
|
|
237
236
|
f"{CALM_ENTRY}_{STICKY}": self.config[CALM_ENTRY][STICKY],
|
|
@@ -26,6 +26,8 @@ from rasa.dialogue_understanding.commands.correct_slots_command import (
|
|
|
26
26
|
CorrectSlotsCommand,
|
|
27
27
|
CorrectedSlot,
|
|
28
28
|
)
|
|
29
|
+
from rasa.dialogue_understanding.commands.noop_command import NoopCommand
|
|
30
|
+
from rasa.dialogue_understanding.commands.change_flow_command import ChangeFlowCommand
|
|
29
31
|
|
|
30
32
|
__all__ = [
|
|
31
33
|
"Command",
|
|
@@ -42,4 +44,6 @@ __all__ = [
|
|
|
42
44
|
"CorrectSlotsCommand",
|
|
43
45
|
"CorrectedSlot",
|
|
44
46
|
"ErrorCommand",
|
|
47
|
+
"NoopCommand",
|
|
48
|
+
"ChangeFlowCommand",
|
|
45
49
|
]
|
|
@@ -59,3 +59,12 @@ class CannotHandleCommand(Command):
|
|
|
59
59
|
else:
|
|
60
60
|
stack.push(CannotHandlePatternFlowStackFrame())
|
|
61
61
|
return tracker.create_stack_updated_events(stack)
|
|
62
|
+
|
|
63
|
+
def __hash__(self) -> int:
|
|
64
|
+
return hash(self.reason)
|
|
65
|
+
|
|
66
|
+
def __eq__(self, other: object) -> bool:
|
|
67
|
+
if not isinstance(other, CannotHandleCommand):
|
|
68
|
+
return False
|
|
69
|
+
|
|
70
|
+
return other.reason == self.reason
|
|
@@ -114,3 +114,12 @@ class CancelFlowCommand(Command):
|
|
|
114
114
|
applied_events.append(FlowCancelled(user_frame.flow_id, user_frame.step_id))
|
|
115
115
|
|
|
116
116
|
return applied_events + tracker.create_stack_updated_events(stack)
|
|
117
|
+
|
|
118
|
+
def __hash__(self) -> int:
|
|
119
|
+
return hash(self.command())
|
|
120
|
+
|
|
121
|
+
def __eq__(self, other: object) -> bool:
|
|
122
|
+
if not isinstance(other, CancelFlowCommand):
|
|
123
|
+
return False
|
|
124
|
+
|
|
125
|
+
return True
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Dict, List
|
|
5
|
+
from rasa.dialogue_understanding.commands import Command
|
|
6
|
+
from rasa.shared.core.events import Event
|
|
7
|
+
from rasa.shared.core.flows import FlowsList
|
|
8
|
+
from rasa.shared.core.trackers import DialogueStateTracker
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
|
|
12
|
+
class ChangeFlowCommand(Command):
|
|
13
|
+
"""A command to indicate a change of flows was requested by the command
|
|
14
|
+
generator."""
|
|
15
|
+
|
|
16
|
+
@classmethod
|
|
17
|
+
def command(cls) -> str:
|
|
18
|
+
"""Returns the command type."""
|
|
19
|
+
return "change_flow"
|
|
20
|
+
|
|
21
|
+
@classmethod
|
|
22
|
+
def from_dict(cls, data: Dict[str, Any]) -> ChangeFlowCommand:
|
|
23
|
+
"""Converts the dictionary to a command.
|
|
24
|
+
|
|
25
|
+
Returns:
|
|
26
|
+
The converted dictionary.
|
|
27
|
+
"""
|
|
28
|
+
return ChangeFlowCommand()
|
|
29
|
+
|
|
30
|
+
def run_command_on_tracker(
|
|
31
|
+
self,
|
|
32
|
+
tracker: DialogueStateTracker,
|
|
33
|
+
all_flows: FlowsList,
|
|
34
|
+
original_tracker: DialogueStateTracker,
|
|
35
|
+
) -> List[Event]:
|
|
36
|
+
# the change flow command is not actually pushing anything to the tracker,
|
|
37
|
+
# but it is predicted by the MultiStepLLMCommandGenerator and used internally
|
|
38
|
+
return []
|
|
@@ -46,3 +46,12 @@ class ChitChatAnswerCommand(FreeFormAnswerCommand):
|
|
|
46
46
|
stack = tracker.stack
|
|
47
47
|
stack.push(ChitchatPatternFlowStackFrame())
|
|
48
48
|
return tracker.create_stack_updated_events(stack)
|
|
49
|
+
|
|
50
|
+
def __hash__(self) -> int:
|
|
51
|
+
return hash(self.command())
|
|
52
|
+
|
|
53
|
+
def __eq__(self, other: object) -> bool:
|
|
54
|
+
if not isinstance(other, ChitChatAnswerCommand):
|
|
55
|
+
return False
|
|
56
|
+
|
|
57
|
+
return True
|