rasa-pro 3.14.0.dev20250922__py3-none-any.whl → 3.14.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/__main__.py +15 -3
- rasa/agents/__init__.py +0 -0
- rasa/agents/agent_factory.py +122 -0
- rasa/agents/agent_manager.py +211 -0
- rasa/agents/constants.py +43 -0
- rasa/agents/core/__init__.py +0 -0
- rasa/agents/core/agent_protocol.py +107 -0
- rasa/agents/core/types.py +81 -0
- rasa/agents/exceptions.py +38 -0
- rasa/agents/protocol/__init__.py +5 -0
- rasa/agents/protocol/a2a/__init__.py +0 -0
- rasa/agents/protocol/a2a/a2a_agent.py +879 -0
- rasa/agents/protocol/mcp/__init__.py +0 -0
- rasa/agents/protocol/mcp/mcp_base_agent.py +726 -0
- rasa/agents/protocol/mcp/mcp_open_agent.py +327 -0
- rasa/agents/protocol/mcp/mcp_task_agent.py +522 -0
- rasa/agents/schemas/__init__.py +13 -0
- rasa/agents/schemas/agent_input.py +38 -0
- rasa/agents/schemas/agent_output.py +26 -0
- rasa/agents/schemas/agent_tool_result.py +65 -0
- rasa/agents/schemas/agent_tool_schema.py +186 -0
- rasa/agents/templates/__init__.py +0 -0
- rasa/agents/templates/mcp_open_agent_prompt_template.jinja2 +20 -0
- rasa/agents/templates/mcp_task_agent_prompt_template.jinja2 +22 -0
- rasa/agents/utils.py +206 -0
- rasa/agents/validation.py +485 -0
- rasa/api.py +24 -9
- rasa/builder/config.py +6 -2
- rasa/builder/copilot/constants.py +4 -1
- rasa/builder/copilot/copilot.py +155 -79
- rasa/builder/copilot/models.py +304 -108
- rasa/builder/copilot/prompts/copilot_training_error_handler_prompt.jinja2 +53 -0
- rasa/builder/guardrails/{lakera.py → clients.py} +55 -5
- rasa/builder/guardrails/constants.py +3 -0
- rasa/builder/guardrails/models.py +45 -10
- rasa/builder/guardrails/policy_checker.py +324 -0
- rasa/builder/guardrails/utils.py +42 -276
- rasa/builder/jobs.py +182 -12
- rasa/builder/llm_service.py +32 -5
- rasa/builder/models.py +13 -3
- rasa/builder/project_generator.py +6 -1
- rasa/builder/service.py +31 -15
- rasa/builder/training_service.py +18 -24
- rasa/builder/validation_service.py +1 -1
- rasa/cli/arguments/default_arguments.py +12 -0
- rasa/cli/arguments/run.py +2 -0
- rasa/cli/arguments/train.py +2 -0
- rasa/cli/data.py +10 -8
- rasa/cli/dialogue_understanding_test.py +10 -7
- rasa/cli/e2e_test.py +9 -6
- rasa/cli/evaluate.py +4 -2
- rasa/cli/export.py +5 -2
- rasa/cli/inspect.py +8 -4
- rasa/cli/interactive.py +5 -4
- rasa/cli/llm_fine_tuning.py +11 -6
- rasa/cli/project_templates/finance/domain/general/help.yml +0 -0
- rasa/cli/project_templates/tutorial/credentials.yml +10 -0
- rasa/cli/run.py +12 -10
- rasa/cli/scaffold.py +4 -4
- rasa/cli/shell.py +9 -5
- rasa/cli/studio/studio.py +1 -1
- rasa/cli/test.py +34 -14
- rasa/cli/train.py +41 -28
- rasa/cli/utils.py +1 -393
- rasa/cli/validation/__init__.py +0 -0
- rasa/cli/validation/bot_config.py +223 -0
- rasa/cli/validation/config_path_validation.py +257 -0
- rasa/cli/x.py +8 -4
- rasa/constants.py +7 -1
- rasa/core/actions/action.py +51 -10
- rasa/core/actions/grpc_custom_action_executor.py +1 -1
- rasa/core/agent.py +19 -2
- rasa/core/available_agents.py +229 -0
- rasa/core/brokers/kafka.py +5 -1
- rasa/core/channels/__init__.py +82 -35
- rasa/core/channels/development_inspector.py +3 -3
- rasa/core/channels/inspector/README.md +25 -13
- rasa/core/channels/inspector/dist/assets/{arc-35222594.js → arc-6177260a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{blockDiagram-38ab4fdb-a0efbfd3.js → blockDiagram-38ab4fdb-b054f038.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-3d4e48cf-0584c0f2.js → c4Diagram-3d4e48cf-f25427d5.js} +1 -1
- rasa/core/channels/inspector/dist/assets/channel-bf9cbb34.js +1 -0
- rasa/core/channels/inspector/dist/assets/{classDiagram-70f12bd4-39f40dbe.js → classDiagram-70f12bd4-c7a2af53.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-f2320105-1ad755f3.js → classDiagram-v2-f2320105-58db65c0.js} +1 -1
- rasa/core/channels/inspector/dist/assets/clone-8f9083bb.js +1 -0
- rasa/core/channels/inspector/dist/assets/{createText-2e5e7dd3-b0f4f0fe.js → createText-2e5e7dd3-088372e2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-e0da2a9e-9039bff9.js → edges-e0da2a9e-58676240.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9861fffd-65c9b127.js → erDiagram-9861fffd-0c14d7c6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-956e92f1-4f08b38e.js → flowDb-956e92f1-ea63f85c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-66a62f08-e95c362a.js → flowDiagram-66a62f08-a2af48cd.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-9ecd5b59.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-4a651766-703c3015.js → flowchart-elk-definition-4a651766-6937abe7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-c361ad54-699328ea.js → ganttDiagram-c361ad54-7473f357.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-72cf32ee-04cf4b05.js → gitGraphDiagram-72cf32ee-d0c9405e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{graph-ee94449e.js → graph-0a6f8466.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-3862675e-940162b4.js → index-3862675e-7610671a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/index-74e01d94.js +1354 -0
- rasa/core/channels/inspector/dist/assets/{infoDiagram-f8f76790-c79c2866.js → infoDiagram-f8f76790-be397dc7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-49397b02-84489d30.js → journeyDiagram-49397b02-4cefbf62.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-a9aa9858.js → layout-e7fbc2bf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-eb73cf26.js → line-a8aa457c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-b3399f9a.js → linear-3351e0d2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-fc14e90a-b095bf1a.js → mindmap-definition-fc14e90a-b8cbf605.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-8a3498a8-07644b66.js → pieDiagram-8a3498a8-f327f774.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-120e2f19-573a3f9c.js → quadrantDiagram-120e2f19-2854c591.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-deff3bca-d457e1e1.js → requirementDiagram-deff3bca-964985d5.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-04a897e0-9d26e1a2.js → sankeyDiagram-04a897e0-edeb4f33.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-704730f1-3a9cde10.js → sequenceDiagram-704730f1-fcf70125.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-587899a1-4f3e8cec.js → stateDiagram-587899a1-0e770395.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-d93cdb3a-e617e5bf.js → stateDiagram-v2-d93cdb3a-af8dcd22.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-6aaf32cf-eab30d2f.js → styles-6aaf32cf-36a9e70d.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9a916d00-09994be2.js → styles-9a916d00-884a8b5b.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-c10674c1-b7110364.js → styles-c10674c1-dc097813.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-08f97a94-3ebc92ad.js → svgDrawCommon-08f97a94-5a2c7eed.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-85554ec2-7d13d2f2.js → timeline-definition-85554ec2-e89c4f6e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-e933f94c-488385e1.js → xychartDiagram-e933f94c-afb6fe56.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/package.json +18 -18
- rasa/core/channels/inspector/src/App.tsx +29 -4
- rasa/core/channels/inspector/src/components/DialogueAgentStack.tsx +108 -0
- rasa/core/channels/inspector/src/components/{DialogueStack.tsx → DialogueHistoryStack.tsx} +4 -2
- rasa/core/channels/inspector/src/helpers/audio/audiostream.ts +7 -4
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +4 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +24 -3
- rasa/core/channels/inspector/src/helpers/utils.test.ts +127 -0
- rasa/core/channels/inspector/src/helpers/utils.ts +66 -1
- rasa/core/channels/inspector/src/theme/base/styles.ts +19 -1
- rasa/core/channels/inspector/src/types.ts +21 -0
- rasa/core/channels/inspector/yarn.lock +336 -189
- rasa/core/channels/studio_chat.py +6 -6
- rasa/core/channels/telegram.py +4 -9
- rasa/core/channels/voice_stream/genesys.py +1 -1
- rasa/core/channels/voice_stream/tts/deepgram.py +140 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +5 -1
- rasa/core/channels/voice_stream/voice_channel.py +3 -0
- rasa/core/concurrent_lock_store.py +38 -21
- rasa/core/config/__init__.py +0 -0
- rasa/core/{available_endpoints.py → config/available_endpoints.py} +51 -16
- rasa/core/config/configuration.py +260 -0
- rasa/core/config/credentials.py +19 -0
- rasa/core/config/message_procesing_config.py +34 -0
- rasa/core/constants.py +10 -0
- rasa/core/iam_credentials_providers/aws_iam_credentials_providers.py +69 -4
- rasa/core/iam_credentials_providers/credentials_provider_protocol.py +2 -1
- rasa/core/lock_store.py +4 -0
- rasa/core/policies/enterprise_search_policy.py +5 -3
- rasa/core/policies/flow_policy.py +4 -4
- rasa/core/policies/flows/agent_executor.py +632 -0
- rasa/core/policies/flows/flow_executor.py +136 -75
- rasa/core/policies/flows/mcp_tool_executor.py +298 -0
- rasa/core/policies/intentless_policy.py +1 -1
- rasa/core/policies/ted_policy.py +20 -12
- rasa/core/policies/unexpected_intent_policy.py +6 -0
- rasa/core/processor.py +68 -44
- rasa/core/redis_connection_factory.py +7 -2
- rasa/core/run.py +37 -8
- rasa/core/test.py +4 -0
- rasa/core/tracker_stores/redis_tracker_store.py +4 -0
- rasa/core/tracker_stores/sql_tracker_store.py +3 -1
- rasa/core/tracker_stores/tracker_store.py +3 -7
- rasa/core/train.py +1 -1
- rasa/core/training/interactive.py +20 -18
- rasa/core/training/story_conflict.py +5 -5
- rasa/core/utils.py +22 -23
- rasa/dialogue_understanding/commands/__init__.py +8 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +19 -5
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +21 -2
- rasa/dialogue_understanding/commands/clarify_command.py +20 -2
- rasa/dialogue_understanding/commands/continue_agent_command.py +91 -0
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +21 -2
- rasa/dialogue_understanding/commands/restart_agent_command.py +162 -0
- rasa/dialogue_understanding/commands/start_flow_command.py +68 -7
- rasa/dialogue_understanding/commands/utils.py +124 -2
- rasa/dialogue_understanding/generator/command_parser.py +4 -0
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +50 -12
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -1
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +1 -1
- rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +66 -0
- rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +66 -0
- rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v3_claude_3_5_sonnet_20240620_template.jinja2 +89 -0
- rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v3_gpt_4o_2024_11_20_template.jinja2 +88 -0
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +42 -7
- rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +40 -3
- rasa/dialogue_understanding/generator/single_step/single_step_based_llm_command_generator.py +20 -3
- rasa/dialogue_understanding/patterns/cancel.py +27 -6
- rasa/dialogue_understanding/patterns/clarify.py +3 -14
- rasa/dialogue_understanding/patterns/continue_interrupted.py +239 -6
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +46 -8
- rasa/dialogue_understanding/processor/command_processor.py +136 -15
- rasa/dialogue_understanding/stack/dialogue_stack.py +98 -2
- rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +57 -0
- rasa/dialogue_understanding/stack/utils.py +57 -3
- rasa/dialogue_understanding/utils.py +24 -4
- rasa/dialogue_understanding_test/du_test_runner.py +8 -3
- rasa/e2e_test/e2e_test_runner.py +13 -3
- rasa/engine/caching.py +2 -2
- rasa/engine/constants.py +1 -1
- rasa/engine/recipes/default_components.py +138 -49
- rasa/engine/recipes/default_recipe.py +108 -11
- rasa/engine/runner/dask.py +8 -5
- rasa/engine/validation.py +19 -6
- rasa/graph_components/validators/default_recipe_validator.py +86 -28
- rasa/hooks.py +5 -5
- rasa/llm_fine_tuning/utils.py +2 -2
- rasa/model_training.py +60 -47
- rasa/nlu/classifiers/diet_classifier.py +198 -98
- rasa/nlu/classifiers/logistic_regression_classifier.py +1 -4
- rasa/nlu/classifiers/mitie_intent_classifier.py +3 -0
- rasa/nlu/classifiers/sklearn_intent_classifier.py +1 -3
- rasa/nlu/extractors/crf_entity_extractor.py +9 -10
- rasa/nlu/extractors/mitie_entity_extractor.py +3 -0
- rasa/nlu/extractors/spacy_entity_extractor.py +3 -0
- rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +4 -0
- rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +5 -0
- rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +2 -0
- rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +3 -0
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +4 -2
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +4 -0
- rasa/nlu/selectors/response_selector.py +10 -2
- rasa/nlu/tokenizers/jieba_tokenizer.py +3 -4
- rasa/nlu/tokenizers/mitie_tokenizer.py +3 -2
- rasa/nlu/tokenizers/spacy_tokenizer.py +3 -2
- rasa/nlu/utils/mitie_utils.py +3 -0
- rasa/nlu/utils/spacy_utils.py +3 -2
- rasa/plugin.py +8 -8
- rasa/privacy/privacy_manager.py +12 -3
- rasa/server.py +15 -3
- rasa/shared/agents/__init__.py +0 -0
- rasa/shared/agents/auth/__init__.py +0 -0
- rasa/shared/agents/auth/agent_auth_factory.py +105 -0
- rasa/shared/agents/auth/agent_auth_manager.py +92 -0
- rasa/shared/agents/auth/auth_strategy/__init__.py +19 -0
- rasa/shared/agents/auth/auth_strategy/agent_auth_strategy.py +52 -0
- rasa/shared/agents/auth/auth_strategy/api_key_auth_strategy.py +42 -0
- rasa/shared/agents/auth/auth_strategy/bearer_token_auth_strategy.py +28 -0
- rasa/shared/agents/auth/auth_strategy/oauth2_auth_strategy.py +167 -0
- rasa/shared/agents/auth/constants.py +12 -0
- rasa/shared/agents/auth/types.py +12 -0
- rasa/shared/agents/utils.py +35 -0
- rasa/shared/constants.py +8 -0
- rasa/shared/core/constants.py +16 -1
- rasa/shared/core/domain.py +0 -7
- rasa/shared/core/events.py +327 -0
- rasa/shared/core/flows/constants.py +5 -0
- rasa/shared/core/flows/flows_list.py +21 -5
- rasa/shared/core/flows/flows_yaml_schema.json +119 -184
- rasa/shared/core/flows/steps/call.py +49 -5
- rasa/shared/core/flows/steps/collect.py +98 -13
- rasa/shared/core/flows/validation.py +372 -8
- rasa/shared/core/flows/yaml_flows_io.py +3 -2
- rasa/shared/core/slots.py +2 -2
- rasa/shared/core/trackers.py +5 -2
- rasa/shared/exceptions.py +16 -0
- rasa/shared/importers/rasa.py +1 -1
- rasa/shared/importers/utils.py +9 -3
- rasa/shared/providers/llm/_base_litellm_client.py +41 -9
- rasa/shared/providers/llm/litellm_router_llm_client.py +8 -4
- rasa/shared/providers/llm/llm_client.py +7 -3
- rasa/shared/providers/llm/llm_response.py +66 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +8 -4
- rasa/shared/utils/common.py +24 -0
- rasa/shared/utils/health_check/health_check.py +7 -3
- rasa/shared/utils/llm.py +39 -16
- rasa/shared/utils/mcp/__init__.py +0 -0
- rasa/shared/utils/mcp/server_connection.py +247 -0
- rasa/shared/utils/mcp/utils.py +20 -0
- rasa/shared/utils/schemas/events.py +42 -0
- rasa/shared/utils/yaml.py +3 -1
- rasa/studio/pull/pull.py +3 -2
- rasa/studio/train.py +8 -7
- rasa/studio/upload.py +3 -6
- rasa/telemetry.py +69 -5
- rasa/tracing/config.py +45 -12
- rasa/tracing/constants.py +14 -0
- rasa/tracing/instrumentation/attribute_extractors.py +142 -9
- rasa/tracing/instrumentation/instrumentation.py +626 -21
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +4 -4
- rasa/tracing/instrumentation/metrics.py +32 -0
- rasa/tracing/metric_instrument_provider.py +68 -0
- rasa/utils/common.py +92 -1
- rasa/utils/endpoints.py +11 -2
- rasa/utils/log_utils.py +96 -5
- rasa/utils/ml_utils.py +1 -1
- rasa/utils/tensorflow/__init__.py +7 -0
- rasa/utils/tensorflow/callback.py +136 -101
- rasa/utils/tensorflow/crf.py +1 -1
- rasa/utils/tensorflow/data_generator.py +21 -8
- rasa/utils/tensorflow/layers.py +21 -11
- rasa/utils/tensorflow/metrics.py +7 -3
- rasa/utils/tensorflow/models.py +56 -8
- rasa/utils/tensorflow/rasa_layers.py +8 -6
- rasa/utils/tensorflow/transformer.py +2 -3
- rasa/utils/train_utils.py +54 -24
- rasa/validator.py +5 -5
- rasa/version.py +1 -1
- {rasa_pro-3.14.0.dev20250922.dist-info → rasa_pro-3.14.0rc2.dist-info}/METADATA +47 -41
- {rasa_pro-3.14.0.dev20250922.dist-info → rasa_pro-3.14.0rc2.dist-info}/RECORD +299 -238
- rasa/builder/scrape_rasa_docs.py +0 -97
- rasa/core/channels/inspector/dist/assets/channel-8e08bed9.js +0 -1
- rasa/core/channels/inspector/dist/assets/clone-78c82dea.js +0 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-2b08f601.js +0 -1
- rasa/core/channels/inspector/dist/assets/index-c941dcb3.js +0 -1336
- {rasa_pro-3.14.0.dev20250922.dist-info → rasa_pro-3.14.0rc2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.14.0.dev20250922.dist-info → rasa_pro-3.14.0rc2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.14.0.dev20250922.dist-info → rasa_pro-3.14.0rc2.dist-info}/entry_points.txt +0 -0
rasa/core/policies/ted_policy.py
CHANGED
|
@@ -58,6 +58,7 @@ from rasa.shared.nlu.training_data.features import (
|
|
|
58
58
|
save_features,
|
|
59
59
|
)
|
|
60
60
|
from rasa.shared.nlu.training_data.message import Message
|
|
61
|
+
from rasa.shared.utils.io import raise_deprecation_warning
|
|
61
62
|
from rasa.utils import train_utils
|
|
62
63
|
from rasa.utils.tensorflow import rasa_layers
|
|
63
64
|
from rasa.utils.tensorflow.constants import (
|
|
@@ -80,7 +81,6 @@ from rasa.utils.tensorflow.constants import (
|
|
|
80
81
|
EMBEDDING_DIMENSION,
|
|
81
82
|
ENCODING_DIMENSION,
|
|
82
83
|
ENTITY_RECOGNITION,
|
|
83
|
-
EPOCH_OVERRIDE,
|
|
84
84
|
EPOCHS,
|
|
85
85
|
EVAL_NUM_EPOCHS,
|
|
86
86
|
EVAL_NUM_EXAMPLES,
|
|
@@ -363,6 +363,9 @@ class TEDPolicy(Policy):
|
|
|
363
363
|
entity_tag_specs: Optional[List[EntityTagSpec]] = None,
|
|
364
364
|
) -> None:
|
|
365
365
|
"""Declares instance variables with default values."""
|
|
366
|
+
raise_deprecation_warning(
|
|
367
|
+
"TEDPolicy is deprecated and will be removed in a future version."
|
|
368
|
+
)
|
|
366
369
|
super().__init__(
|
|
367
370
|
config, model_storage, resource, execution_context, featurizer=featurizer
|
|
368
371
|
)
|
|
@@ -668,6 +671,7 @@ class TEDPolicy(Policy):
|
|
|
668
671
|
self.model.compile(
|
|
669
672
|
optimizer=tf.keras.optimizers.Adam(self.config[LEARNING_RATE])
|
|
670
673
|
)
|
|
674
|
+
|
|
671
675
|
(
|
|
672
676
|
data_generator,
|
|
673
677
|
validation_data_generator,
|
|
@@ -943,14 +947,16 @@ class TEDPolicy(Policy):
|
|
|
943
947
|
|
|
944
948
|
with self._model_storage.write_to(self._resource) as model_path:
|
|
945
949
|
model_filename = self._metadata_filename()
|
|
946
|
-
tf_model_file = model_path / f"{model_filename}.
|
|
950
|
+
tf_model_file = model_path / f"{model_filename}.weights.h5"
|
|
947
951
|
|
|
948
952
|
rasa.shared.utils.io.create_directory_for_file(tf_model_file)
|
|
949
953
|
|
|
950
954
|
self.featurizer.persist(model_path)
|
|
951
955
|
|
|
952
956
|
if self.config[CHECKPOINT_MODEL] and self.tmp_checkpoint_dir:
|
|
953
|
-
self.model.load_weights(
|
|
957
|
+
self.model.load_weights(
|
|
958
|
+
self.tmp_checkpoint_dir / "checkpoint.weights.h5"
|
|
959
|
+
)
|
|
954
960
|
# Save an empty file to flag that this model has been
|
|
955
961
|
# produced using checkpointing
|
|
956
962
|
checkpoint_marker = model_path / f"{model_filename}.from_checkpoint.pkl"
|
|
@@ -1009,7 +1015,7 @@ class TEDPolicy(Policy):
|
|
|
1009
1015
|
Args:
|
|
1010
1016
|
model_path: Path where model is to be persisted.
|
|
1011
1017
|
"""
|
|
1012
|
-
tf_model_file = model_path / f"{cls._metadata_filename()}.
|
|
1018
|
+
tf_model_file = model_path / f"{cls._metadata_filename()}.weights.h5"
|
|
1013
1019
|
|
|
1014
1020
|
# load data example
|
|
1015
1021
|
loaded_data = deserialize_nested_feature_arrays(
|
|
@@ -1109,8 +1115,6 @@ class TEDPolicy(Policy):
|
|
|
1109
1115
|
model_utilities = cls._load_model_utilities(model_path)
|
|
1110
1116
|
|
|
1111
1117
|
config = cls._update_loaded_params(config)
|
|
1112
|
-
if execution_context.is_finetuning and EPOCH_OVERRIDE in config:
|
|
1113
|
-
config[EPOCHS] = config.get(EPOCH_OVERRIDE)
|
|
1114
1118
|
|
|
1115
1119
|
(
|
|
1116
1120
|
model_data_example,
|
|
@@ -1125,7 +1129,6 @@ class TEDPolicy(Policy):
|
|
|
1125
1129
|
model_data_example,
|
|
1126
1130
|
predict_data_example,
|
|
1127
1131
|
featurizer,
|
|
1128
|
-
execution_context.is_finetuning,
|
|
1129
1132
|
)
|
|
1130
1133
|
|
|
1131
1134
|
return cls._load_policy_with_model(
|
|
@@ -1167,7 +1170,6 @@ class TEDPolicy(Policy):
|
|
|
1167
1170
|
model_data_example: RasaModelData,
|
|
1168
1171
|
predict_data_example: RasaModelData,
|
|
1169
1172
|
featurizer: TrackerFeaturizer,
|
|
1170
|
-
should_finetune: bool,
|
|
1171
1173
|
) -> TED:
|
|
1172
1174
|
model = cls.model_class().load(
|
|
1173
1175
|
str(model_utilities["tf_model_file"]),
|
|
@@ -1180,7 +1182,9 @@ class TEDPolicy(Policy):
|
|
|
1180
1182
|
),
|
|
1181
1183
|
label_data=model_utilities["label_data"],
|
|
1182
1184
|
entity_tag_specs=model_utilities["entity_tag_specs"],
|
|
1183
|
-
|
|
1185
|
+
# This feature is no longer supported as the updated version
|
|
1186
|
+
# of Keras does not allow updating a compiled model anymore.
|
|
1187
|
+
finetune_mode=False,
|
|
1184
1188
|
)
|
|
1185
1189
|
return model
|
|
1186
1190
|
|
|
@@ -1463,7 +1467,7 @@ class TED(TransformerRasaModel):
|
|
|
1463
1467
|
|
|
1464
1468
|
dialogue_transformed, attention_weights = self._tf_layers[
|
|
1465
1469
|
f"transformer.{DIALOGUE}"
|
|
1466
|
-
](dialogue_in, 1 - mask, self._training)
|
|
1470
|
+
](dialogue_in, 1 - mask, training=self._training)
|
|
1467
1471
|
dialogue_transformed = tf.nn.gelu(dialogue_transformed)
|
|
1468
1472
|
|
|
1469
1473
|
if self.max_history_featurizer_is_used:
|
|
@@ -1708,7 +1712,7 @@ class TED(TransformerRasaModel):
|
|
|
1708
1712
|
|
|
1709
1713
|
if attribute in SENTENCE_FEATURES_TO_ENCODE + LABEL_FEATURES_TO_ENCODE:
|
|
1710
1714
|
attribute_features = self._tf_layers[f"encoding_layer.{attribute}"](
|
|
1711
|
-
attribute_features, self._training
|
|
1715
|
+
attribute_features, training=self._training
|
|
1712
1716
|
)
|
|
1713
1717
|
|
|
1714
1718
|
# attribute features have shape
|
|
@@ -2102,7 +2106,11 @@ class TED(TransformerRasaModel):
|
|
|
2102
2106
|
predictions = {
|
|
2103
2107
|
"scores": scores,
|
|
2104
2108
|
"similarities": sim_all,
|
|
2105
|
-
DIAGNOSTIC_DATA: {
|
|
2109
|
+
DIAGNOSTIC_DATA: {
|
|
2110
|
+
"attention_weights": attention_weights.numpy()
|
|
2111
|
+
if attention_weights is not None and hasattr(attention_weights, "numpy")
|
|
2112
|
+
else attention_weights,
|
|
2113
|
+
},
|
|
2106
2114
|
}
|
|
2107
2115
|
|
|
2108
2116
|
if (
|
|
@@ -54,6 +54,7 @@ from rasa.shared.nlu.constants import (
|
|
|
54
54
|
)
|
|
55
55
|
from rasa.shared.nlu.training_data.features import Features
|
|
56
56
|
from rasa.shared.utils import common
|
|
57
|
+
from rasa.shared.utils.io import raise_deprecation_warning
|
|
57
58
|
from rasa.utils import train_utils
|
|
58
59
|
from rasa.utils.tensorflow import layers
|
|
59
60
|
from rasa.utils.tensorflow.constants import (
|
|
@@ -300,6 +301,10 @@ class UnexpecTEDIntentPolicy(TEDPolicy):
|
|
|
300
301
|
label_quantiles: Optional[Dict[int, List[float]]] = None,
|
|
301
302
|
):
|
|
302
303
|
"""Declares instance variables with default values."""
|
|
304
|
+
raise_deprecation_warning(
|
|
305
|
+
"UnexpecTEDIntentPolicy is deprecated and "
|
|
306
|
+
"will be removed in a future version."
|
|
307
|
+
)
|
|
303
308
|
# Set all invalid / non configurable parameters
|
|
304
309
|
config[ENTITY_RECOGNITION] = False
|
|
305
310
|
config[BILOU_FLAG] = False
|
|
@@ -624,6 +629,7 @@ class UnexpecTEDIntentPolicy(TEDPolicy):
|
|
|
624
629
|
query_intent = (
|
|
625
630
|
last_user_uttered_event.intent_name
|
|
626
631
|
if last_user_uttered_event is not None
|
|
632
|
+
and isinstance(last_user_uttered_event, UserUttered)
|
|
627
633
|
else ""
|
|
628
634
|
)
|
|
629
635
|
is_unlikely_intent = self._check_unlikely_intent(
|
rasa/core/processor.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import copy
|
|
2
2
|
import inspect
|
|
3
|
-
import logging
|
|
4
3
|
import os
|
|
5
4
|
import re
|
|
6
5
|
import tarfile
|
|
@@ -69,6 +68,7 @@ from rasa.shared.constants import (
|
|
|
69
68
|
UTTER_PREFIX,
|
|
70
69
|
)
|
|
71
70
|
from rasa.shared.core.constants import (
|
|
71
|
+
ACTION_AGENT_REQUEST_USER_INPUT_NAME,
|
|
72
72
|
ACTION_CORRECT_FLOW_SLOT,
|
|
73
73
|
ACTION_EXTRACT_SLOTS,
|
|
74
74
|
ACTION_LISTEN_NAME,
|
|
@@ -113,10 +113,9 @@ from rasa.utils.common import TempDirectoryPath, get_temp_dir_name
|
|
|
113
113
|
from rasa.utils.endpoints import EndpointConfig
|
|
114
114
|
|
|
115
115
|
if TYPE_CHECKING:
|
|
116
|
-
from rasa.core.available_endpoints import AvailableEndpoints
|
|
116
|
+
from rasa.core.config.available_endpoints import AvailableEndpoints
|
|
117
117
|
from rasa.privacy.privacy_manager import BackgroundPrivacyManager
|
|
118
118
|
|
|
119
|
-
logger = logging.getLogger(__name__)
|
|
120
119
|
structlogger = structlog.get_logger()
|
|
121
120
|
|
|
122
121
|
MAX_NUMBER_OF_PREDICTIONS = int(os.environ.get("MAX_NUMBER_OF_PREDICTIONS", "10"))
|
|
@@ -190,7 +189,11 @@ class MessageProcessor:
|
|
|
190
189
|
except TypeError:
|
|
191
190
|
raise ModelNotFound(f"Model {model_path} can not be loaded.")
|
|
192
191
|
|
|
193
|
-
|
|
192
|
+
structlogger.info(
|
|
193
|
+
"rasa.core.processor.load_model",
|
|
194
|
+
event_info="Loading model.",
|
|
195
|
+
model_path=model_tar,
|
|
196
|
+
)
|
|
194
197
|
with TempDirectoryPath(get_temp_dir_name()) as temporary_directory:
|
|
195
198
|
try:
|
|
196
199
|
metadata, runner = loader.load_predict_graph_runner(
|
|
@@ -365,8 +368,10 @@ class MessageProcessor:
|
|
|
365
368
|
`ActionSessionStart`.
|
|
366
369
|
"""
|
|
367
370
|
if not tracker.applied_events() or self._has_session_expired(tracker):
|
|
368
|
-
|
|
369
|
-
|
|
371
|
+
structlogger.debug(
|
|
372
|
+
"rasa.core.processor._update_tracker_session",
|
|
373
|
+
event_info="Starting a new session.",
|
|
374
|
+
sender_id=tracker.sender_id,
|
|
370
375
|
)
|
|
371
376
|
|
|
372
377
|
action_session_start = self._get_action(ACTION_SESSION_START_NAME)
|
|
@@ -598,9 +603,11 @@ class MessageProcessor:
|
|
|
598
603
|
prediction.max_confidence_index, self.domain, self.action_endpoint
|
|
599
604
|
)
|
|
600
605
|
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
606
|
+
structlogger.debug(
|
|
607
|
+
"rasa.core.processor.predict_next_with_tracker_if_should",
|
|
608
|
+
event_info="Predicted next action.",
|
|
609
|
+
action=action.name(),
|
|
610
|
+
confidence=prediction.max_confidence,
|
|
604
611
|
)
|
|
605
612
|
|
|
606
613
|
return action, prediction
|
|
@@ -650,8 +657,10 @@ class MessageProcessor:
|
|
|
650
657
|
and self._has_message_after_reminder(tracker, reminder_event)
|
|
651
658
|
or not self._is_reminder_still_valid(tracker, reminder_event)
|
|
652
659
|
):
|
|
653
|
-
|
|
654
|
-
|
|
660
|
+
structlogger.debug(
|
|
661
|
+
"rasa.core.processor.handle_reminder",
|
|
662
|
+
event_info="Canceled reminder because it is outdated.",
|
|
663
|
+
reminder_event=reminder_event,
|
|
655
664
|
)
|
|
656
665
|
else:
|
|
657
666
|
intent = reminder_event.intent
|
|
@@ -731,7 +740,7 @@ class MessageProcessor:
|
|
|
731
740
|
if not self.domain or self.domain.is_empty():
|
|
732
741
|
return
|
|
733
742
|
|
|
734
|
-
intent = parse_data[
|
|
743
|
+
intent = parse_data[INTENT][INTENT_NAME_KEY]
|
|
735
744
|
if intent and intent not in self.domain.intents:
|
|
736
745
|
rasa.shared.utils.io.raise_warning(
|
|
737
746
|
f"Parsed an intent '{intent}' "
|
|
@@ -740,7 +749,7 @@ class MessageProcessor:
|
|
|
740
749
|
docs=DOCS_URL_DOMAINS,
|
|
741
750
|
)
|
|
742
751
|
|
|
743
|
-
entities = parse_data[
|
|
752
|
+
entities = parse_data[ENTITIES] or []
|
|
744
753
|
for element in entities:
|
|
745
754
|
entity = element["entity"]
|
|
746
755
|
if entity and entity not in self.domain.entities:
|
|
@@ -824,9 +833,9 @@ class MessageProcessor:
|
|
|
824
833
|
self._update_full_retrieval_intent(parse_data)
|
|
825
834
|
structlogger.debug(
|
|
826
835
|
"processor.message.parse",
|
|
827
|
-
parse_data_text=copy.deepcopy(parse_data[
|
|
828
|
-
parse_data_intent=parse_data[
|
|
829
|
-
parse_data_entities=copy.deepcopy(parse_data[
|
|
836
|
+
parse_data_text=copy.deepcopy(parse_data[TEXT]),
|
|
837
|
+
parse_data_intent=parse_data[INTENT],
|
|
838
|
+
parse_data_entities=copy.deepcopy(parse_data[ENTITIES]),
|
|
830
839
|
)
|
|
831
840
|
|
|
832
841
|
self._check_for_unseen_features(parse_data)
|
|
@@ -975,7 +984,7 @@ class MessageProcessor:
|
|
|
975
984
|
f"invalid intent: {parse_data[INTENT]['name']}. "
|
|
976
985
|
f"Returning CannotHandleCommand() as a fallback."
|
|
977
986
|
),
|
|
978
|
-
invalid_intent=parse_data[INTENT][
|
|
987
|
+
invalid_intent=parse_data[INTENT][INTENT_NAME_KEY],
|
|
979
988
|
)
|
|
980
989
|
commands.append(
|
|
981
990
|
CannotHandleCommand(RASA_PATTERN_CANNOT_HANDLE_INVALID_INTENT)
|
|
@@ -985,7 +994,7 @@ class MessageProcessor:
|
|
|
985
994
|
|
|
986
995
|
def _contains_undefined_intent(self, message: Message) -> bool:
|
|
987
996
|
"""Checks if the message contains an undefined intent."""
|
|
988
|
-
intent_name = message.get(INTENT, {}).get(
|
|
997
|
+
intent_name = message.get(INTENT, {}).get(INTENT_NAME_KEY)
|
|
989
998
|
return intent_name is not None and intent_name not in self.domain.intents
|
|
990
999
|
|
|
991
1000
|
async def _parse_message_with_graph(
|
|
@@ -1035,8 +1044,8 @@ class MessageProcessor:
|
|
|
1035
1044
|
tracker.update(
|
|
1036
1045
|
UserUttered(
|
|
1037
1046
|
message.text,
|
|
1038
|
-
parse_data[
|
|
1039
|
-
parse_data[
|
|
1047
|
+
parse_data[INTENT],
|
|
1048
|
+
parse_data[ENTITIES],
|
|
1040
1049
|
parse_data,
|
|
1041
1050
|
input_channel=message.input_channel,
|
|
1042
1051
|
message_id=message.message_id,
|
|
@@ -1045,13 +1054,16 @@ class MessageProcessor:
|
|
|
1045
1054
|
self.domain,
|
|
1046
1055
|
)
|
|
1047
1056
|
|
|
1048
|
-
if parse_data[
|
|
1057
|
+
if parse_data[ENTITIES]:
|
|
1049
1058
|
self._log_slots(tracker)
|
|
1050
1059
|
|
|
1051
1060
|
plugin_manager().hook.after_new_user_message(tracker=tracker)
|
|
1052
1061
|
|
|
1053
|
-
|
|
1054
|
-
|
|
1062
|
+
structlogger.debug(
|
|
1063
|
+
"rasa.core.processor.handle_message_with_tracker",
|
|
1064
|
+
event_info="Logged UserUtterance.",
|
|
1065
|
+
user_message=message.text,
|
|
1066
|
+
number_of_events=len(tracker.events),
|
|
1055
1067
|
)
|
|
1056
1068
|
|
|
1057
1069
|
@staticmethod
|
|
@@ -1166,9 +1178,11 @@ class MessageProcessor:
|
|
|
1166
1178
|
tracker
|
|
1167
1179
|
)
|
|
1168
1180
|
except ActionLimitReached:
|
|
1169
|
-
|
|
1170
|
-
"
|
|
1171
|
-
|
|
1181
|
+
structlogger.warning(
|
|
1182
|
+
"rasa.core.processor.run_prediction_loop",
|
|
1183
|
+
event_info="Circuit breaker tripped. Stopped predicting more "
|
|
1184
|
+
"actions.",
|
|
1185
|
+
sender_id=tracker.sender_id,
|
|
1172
1186
|
)
|
|
1173
1187
|
if self.on_circuit_break:
|
|
1174
1188
|
# call a registered callback
|
|
@@ -1176,9 +1190,11 @@ class MessageProcessor:
|
|
|
1176
1190
|
break
|
|
1177
1191
|
|
|
1178
1192
|
if prediction.is_end_to_end_prediction:
|
|
1179
|
-
|
|
1180
|
-
|
|
1181
|
-
|
|
1193
|
+
structlogger.debug(
|
|
1194
|
+
"rasa.core.processor.run_prediction_loop",
|
|
1195
|
+
event_info="An end-to-end prediction was made which has "
|
|
1196
|
+
"triggered the 2nd execution of the default action.",
|
|
1197
|
+
action=ACTION_EXTRACT_SLOTS,
|
|
1182
1198
|
)
|
|
1183
1199
|
tracker = await self.run_action_extract_slots(output_channel, tracker)
|
|
1184
1200
|
|
|
@@ -1197,7 +1213,11 @@ class MessageProcessor:
|
|
|
1197
1213
|
`False` if `action_name` is `ACTION_LISTEN_NAME` or
|
|
1198
1214
|
`ACTION_SESSION_START_NAME`, otherwise `True`.
|
|
1199
1215
|
"""
|
|
1200
|
-
return action_name not in (
|
|
1216
|
+
return action_name not in (
|
|
1217
|
+
ACTION_LISTEN_NAME,
|
|
1218
|
+
ACTION_SESSION_START_NAME,
|
|
1219
|
+
ACTION_AGENT_REQUEST_USER_INPUT_NAME,
|
|
1220
|
+
)
|
|
1201
1221
|
|
|
1202
1222
|
async def execute_side_effects(
|
|
1203
1223
|
self,
|
|
@@ -1390,10 +1410,11 @@ class MessageProcessor:
|
|
|
1390
1410
|
)
|
|
1391
1411
|
|
|
1392
1412
|
if any(isinstance(e, UserUttered) for e in events):
|
|
1393
|
-
|
|
1394
|
-
|
|
1413
|
+
structlogger.debug(
|
|
1414
|
+
"rasa.core.processor.run_action",
|
|
1415
|
+
message="A `UserUttered` event was returned by executing "
|
|
1395
1416
|
f"action '{action.name()}'. This will run the default action "
|
|
1396
|
-
f"'{ACTION_EXTRACT_SLOTS}'."
|
|
1417
|
+
f"'{ACTION_EXTRACT_SLOTS}'.",
|
|
1397
1418
|
)
|
|
1398
1419
|
tracker = await self.run_action_extract_slots(output_channel, tracker)
|
|
1399
1420
|
|
|
@@ -1499,11 +1520,9 @@ class MessageProcessor:
|
|
|
1499
1520
|
# tracker has never expired if sessions are disabled
|
|
1500
1521
|
return False
|
|
1501
1522
|
|
|
1502
|
-
user_uttered_event
|
|
1503
|
-
UserUttered
|
|
1504
|
-
)
|
|
1523
|
+
user_uttered_event = tracker.get_last_event_for(UserUttered)
|
|
1505
1524
|
|
|
1506
|
-
if not user_uttered_event:
|
|
1525
|
+
if not user_uttered_event or not isinstance(user_uttered_event, UserUttered):
|
|
1507
1526
|
# there is no user event so far so the session should not be considered
|
|
1508
1527
|
# expired
|
|
1509
1528
|
return False
|
|
@@ -1514,9 +1533,10 @@ class MessageProcessor:
|
|
|
1514
1533
|
> self.domain.session_config.session_expiration_time
|
|
1515
1534
|
)
|
|
1516
1535
|
if has_expired:
|
|
1517
|
-
|
|
1518
|
-
|
|
1519
|
-
|
|
1536
|
+
structlogger.debug(
|
|
1537
|
+
"rasa.core.processor.has_session_expired",
|
|
1538
|
+
event_info="The latest session has expired.",
|
|
1539
|
+
sender_id=tracker.sender_id,
|
|
1520
1540
|
)
|
|
1521
1541
|
|
|
1522
1542
|
return has_expired
|
|
@@ -1542,10 +1562,14 @@ class MessageProcessor:
|
|
|
1542
1562
|
)
|
|
1543
1563
|
return prediction
|
|
1544
1564
|
|
|
1545
|
-
|
|
1546
|
-
|
|
1547
|
-
"
|
|
1548
|
-
|
|
1565
|
+
structlogger.error(
|
|
1566
|
+
"rasa.core.processor.predict_next_with_tracker",
|
|
1567
|
+
event_info="Trying to run unknown follow-up action.",
|
|
1568
|
+
message=(
|
|
1569
|
+
"Trying to run unknown follow-up action. Instead of running "
|
|
1570
|
+
"that, Rasa Pro will ignore the action and predict the next action."
|
|
1571
|
+
),
|
|
1572
|
+
followup_action=followup_action,
|
|
1549
1573
|
)
|
|
1550
1574
|
|
|
1551
1575
|
target = self.model_metadata.core_target
|
|
@@ -6,7 +6,10 @@ import redis
|
|
|
6
6
|
import structlog
|
|
7
7
|
from pydantic import BaseModel, ConfigDict
|
|
8
8
|
|
|
9
|
-
from rasa.core.constants import
|
|
9
|
+
from rasa.core.constants import (
|
|
10
|
+
AWS_ELASTICACHE_CLUSTER_NAME_ENV_VAR_NAME,
|
|
11
|
+
REDIS_SERVICE_NAME,
|
|
12
|
+
)
|
|
10
13
|
from rasa.core.iam_credentials_providers.credentials_provider_protocol import (
|
|
11
14
|
IAMCredentialsProvider,
|
|
12
15
|
IAMCredentialsProviderInput,
|
|
@@ -65,6 +68,7 @@ class RedisConfig(BaseModel):
|
|
|
65
68
|
|
|
66
69
|
host: Text = "localhost"
|
|
67
70
|
port: int = 6379
|
|
71
|
+
service_type: SupportedServiceType
|
|
68
72
|
username: Optional[Text] = None
|
|
69
73
|
password: Optional[Text] = None
|
|
70
74
|
use_ssl: bool = False
|
|
@@ -117,7 +121,8 @@ class RedisConnectionFactory:
|
|
|
117
121
|
|
|
118
122
|
iam_credentials_provider = create_iam_credentials_provider(
|
|
119
123
|
IAMCredentialsProviderInput(
|
|
120
|
-
|
|
124
|
+
service_type=config.service_type,
|
|
125
|
+
service_name=REDIS_SERVICE_NAME,
|
|
121
126
|
username=config.username,
|
|
122
127
|
cluster_name=os.getenv(AWS_ELASTICACHE_CLUSTER_NAME_ENV_VAR_NAME),
|
|
123
128
|
)
|
rasa/core/run.py
CHANGED
|
@@ -30,34 +30,35 @@ from rasa import server, telemetry
|
|
|
30
30
|
from rasa.constants import ENV_SANIC_BACKLOG
|
|
31
31
|
from rasa.core import agent, channels, constants
|
|
32
32
|
from rasa.core.agent import Agent
|
|
33
|
-
from rasa.core.
|
|
33
|
+
from rasa.core.available_agents import AvailableAgents
|
|
34
34
|
from rasa.core.channels import console
|
|
35
35
|
from rasa.core.channels.channel import InputChannel
|
|
36
36
|
from rasa.core.channels.development_inspector import DevelopmentInspectProxy
|
|
37
|
+
from rasa.core.config.available_endpoints import AvailableEndpoints
|
|
38
|
+
from rasa.core.config.credentials import CredentialsConfig
|
|
37
39
|
from rasa.core.persistor import StorageType
|
|
38
40
|
from rasa.shared.exceptions import RasaException
|
|
39
|
-
from rasa.shared.utils.yaml import read_config_file
|
|
40
41
|
from rasa.utils import licensing
|
|
41
42
|
|
|
42
43
|
logger = logging.getLogger() # get the root logger
|
|
43
44
|
|
|
44
45
|
|
|
45
46
|
def create_input_channels(
|
|
46
|
-
channel: Optional[Text],
|
|
47
|
+
channel: Optional[Text], credentials_config: Optional[CredentialsConfig]
|
|
47
48
|
) -> List[InputChannel]:
|
|
48
49
|
"""Instantiate the chosen input channel.
|
|
49
50
|
|
|
50
51
|
Args:
|
|
51
52
|
channel (optional): The name of the specific input channel to create.
|
|
52
|
-
|
|
53
|
+
credentials_config: CredentialsConfig object containing channel credentials.
|
|
53
54
|
|
|
54
55
|
Returns:
|
|
55
56
|
A list of instantiated input channels. If a specific channel is provided,
|
|
56
57
|
it returns a list with that single channel. If no channel is specified,
|
|
57
58
|
it returns a list of all channels defined in the credentials file.
|
|
58
59
|
"""
|
|
59
|
-
if
|
|
60
|
-
all_credentials =
|
|
60
|
+
if credentials_config:
|
|
61
|
+
all_credentials = credentials_config.channels
|
|
61
62
|
else:
|
|
62
63
|
all_credentials = {}
|
|
63
64
|
if channel:
|
|
@@ -96,10 +97,32 @@ def _create_single_channel(
|
|
|
96
97
|
"""
|
|
97
98
|
from rasa.core.channels import BUILTIN_CHANNELS
|
|
98
99
|
|
|
100
|
+
# Channels that have optional dependencies
|
|
101
|
+
channels_with_optional_deps = {
|
|
102
|
+
"facebook": "fbmessenger",
|
|
103
|
+
"slack": "slack-sdk",
|
|
104
|
+
"telegram": "aiogram",
|
|
105
|
+
"twilio": "twilio",
|
|
106
|
+
"twilio_voice": "twilio",
|
|
107
|
+
"twilio_media_streams": "twilio",
|
|
108
|
+
"webexteams": "webexteamssdk",
|
|
109
|
+
"vier_cvg": "cvg_sdk",
|
|
110
|
+
}
|
|
111
|
+
|
|
99
112
|
if channel in BUILTIN_CHANNELS:
|
|
100
113
|
channel_class = BUILTIN_CHANNELS[channel]
|
|
101
114
|
|
|
102
115
|
return channel_class.from_credentials(credentials)
|
|
116
|
+
elif channel in channels_with_optional_deps:
|
|
117
|
+
# Channel is known but not available due to missing dependency
|
|
118
|
+
dependency = channels_with_optional_deps[channel]
|
|
119
|
+
raise RasaException(
|
|
120
|
+
f"Channel '{channel}' is not available "
|
|
121
|
+
f"due to missing '{dependency}' dependency. "
|
|
122
|
+
f"Please install the required extra by running: "
|
|
123
|
+
f"pip install 'rasa-pro[channels]' OR "
|
|
124
|
+
f"poetry add 'rasa-pro[channels]'"
|
|
125
|
+
)
|
|
103
126
|
else:
|
|
104
127
|
# try to load channel based on class name
|
|
105
128
|
try:
|
|
@@ -151,6 +174,7 @@ def configure_app(
|
|
|
151
174
|
route: Optional[Text] = "/webhooks/",
|
|
152
175
|
port: int = constants.DEFAULT_SERVER_PORT,
|
|
153
176
|
endpoints: Optional[AvailableEndpoints] = None,
|
|
177
|
+
sub_agents: Optional[AvailableAgents] = None,
|
|
154
178
|
log_file: Optional[Text] = None,
|
|
155
179
|
conversation_id: Optional[Text] = uuid.uuid4().hex,
|
|
156
180
|
use_syslog: bool = False,
|
|
@@ -179,6 +203,7 @@ def configure_app(
|
|
|
179
203
|
jwt_private_key=jwt_private_key,
|
|
180
204
|
jwt_method=jwt_method,
|
|
181
205
|
endpoints=endpoints,
|
|
206
|
+
sub_agents=sub_agents,
|
|
182
207
|
is_inspector_enabled=is_inspector_enabled,
|
|
183
208
|
)
|
|
184
209
|
)
|
|
@@ -234,7 +259,7 @@ def serve_application(
|
|
|
234
259
|
channel: Optional[Text] = None,
|
|
235
260
|
interface: Optional[Text] = constants.DEFAULT_SERVER_INTERFACE,
|
|
236
261
|
port: int = constants.DEFAULT_SERVER_PORT,
|
|
237
|
-
credentials: Optional[
|
|
262
|
+
credentials: Optional[CredentialsConfig] = None,
|
|
238
263
|
cors: Optional[Union[Text, List[Text]]] = None,
|
|
239
264
|
auth_token: Optional[Text] = None,
|
|
240
265
|
enable_api: bool = True,
|
|
@@ -243,6 +268,7 @@ def serve_application(
|
|
|
243
268
|
jwt_private_key: Optional[Text] = None,
|
|
244
269
|
jwt_method: Optional[Text] = None,
|
|
245
270
|
endpoints: Optional[AvailableEndpoints] = None,
|
|
271
|
+
sub_agents: Optional[AvailableAgents] = None,
|
|
246
272
|
remote_storage: Optional[StorageType] = None,
|
|
247
273
|
log_file: Optional[Text] = None,
|
|
248
274
|
ssl_certificate: Optional[Text] = None,
|
|
@@ -283,6 +309,7 @@ def serve_application(
|
|
|
283
309
|
jwt_method,
|
|
284
310
|
port=port,
|
|
285
311
|
endpoints=endpoints,
|
|
312
|
+
sub_agents=sub_agents,
|
|
286
313
|
log_file=log_file,
|
|
287
314
|
conversation_id=conversation_id,
|
|
288
315
|
use_syslog=use_syslog,
|
|
@@ -302,7 +329,7 @@ def serve_application(
|
|
|
302
329
|
logger.info(f"Starting Rasa server on {protocol}://{interface}:{port}")
|
|
303
330
|
|
|
304
331
|
app.register_listener(
|
|
305
|
-
partial(load_agent_on_start, model_path, endpoints, remote_storage),
|
|
332
|
+
partial(load_agent_on_start, model_path, endpoints, remote_storage, sub_agents),
|
|
306
333
|
"before_server_start",
|
|
307
334
|
)
|
|
308
335
|
|
|
@@ -340,6 +367,7 @@ async def load_agent_on_start(
|
|
|
340
367
|
model_path: Text,
|
|
341
368
|
endpoints: AvailableEndpoints,
|
|
342
369
|
remote_storage: Optional[StorageType],
|
|
370
|
+
sub_agents: Optional[AvailableAgents],
|
|
343
371
|
app: Sanic,
|
|
344
372
|
loop: AbstractEventLoop,
|
|
345
373
|
) -> Agent:
|
|
@@ -352,6 +380,7 @@ async def load_agent_on_start(
|
|
|
352
380
|
model_path=model_path,
|
|
353
381
|
remote_storage=remote_storage,
|
|
354
382
|
endpoints=endpoints,
|
|
383
|
+
sub_agents=sub_agents,
|
|
355
384
|
loop=loop,
|
|
356
385
|
)
|
|
357
386
|
|
rasa/core/test.py
CHANGED
|
@@ -1281,6 +1281,10 @@ async def compare_models_in_dir(
|
|
|
1281
1281
|
for k, v in number_correct_in_run.items():
|
|
1282
1282
|
number_correct[k].append(v)
|
|
1283
1283
|
|
|
1284
|
+
logger.info("current working directory: " + os.getcwd())
|
|
1285
|
+
logger.info(
|
|
1286
|
+
f"Writing model comparison results to '{os.path.join(output, RESULTS_FILE)}'"
|
|
1287
|
+
)
|
|
1284
1288
|
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
1285
1289
|
os.path.join(output, RESULTS_FILE), number_correct
|
|
1286
1290
|
)
|
|
@@ -8,6 +8,9 @@ from pydantic import ValidationError
|
|
|
8
8
|
|
|
9
9
|
import rasa.shared
|
|
10
10
|
from rasa.core.brokers.broker import EventBroker
|
|
11
|
+
from rasa.core.iam_credentials_providers.credentials_provider_protocol import (
|
|
12
|
+
SupportedServiceType,
|
|
13
|
+
)
|
|
11
14
|
from rasa.core.redis_connection_factory import (
|
|
12
15
|
DeploymentMode,
|
|
13
16
|
RedisConfig,
|
|
@@ -54,6 +57,7 @@ class RedisTrackerStore(TrackerStore, SerializedTrackerAsText):
|
|
|
54
57
|
config = RedisConfig(
|
|
55
58
|
host=host,
|
|
56
59
|
port=port,
|
|
60
|
+
service_type=SupportedServiceType.TRACKER_STORE,
|
|
57
61
|
db=db,
|
|
58
62
|
username=username,
|
|
59
63
|
password=password,
|
|
@@ -27,6 +27,7 @@ from rasa.core.constants import (
|
|
|
27
27
|
POSTGRESQL_MAX_OVERFLOW,
|
|
28
28
|
POSTGRESQL_POOL_SIZE,
|
|
29
29
|
POSTGRESQL_SCHEMA,
|
|
30
|
+
SQL_SERVICE_NAME,
|
|
30
31
|
SQL_TRACKER_STORE_SSL_MODE_ENV_VAR_NAME,
|
|
31
32
|
SQL_TRACKER_STORE_SSL_ROOT_CERTIFICATE_ENV_VAR_NAME,
|
|
32
33
|
)
|
|
@@ -229,7 +230,8 @@ class SQLTrackerStore(TrackerStore, SerializedTrackerAsText):
|
|
|
229
230
|
|
|
230
231
|
iam_credentials_provider = create_iam_credentials_provider(
|
|
231
232
|
IAMCredentialsProviderInput(
|
|
232
|
-
|
|
233
|
+
service_type=SupportedServiceType.TRACKER_STORE,
|
|
234
|
+
service_name=SQL_SERVICE_NAME,
|
|
233
235
|
username=username,
|
|
234
236
|
host=host,
|
|
235
237
|
port=port,
|
|
@@ -542,7 +542,7 @@ class FailSafeTrackerStore(TrackerStore):
|
|
|
542
542
|
return self._tracker_store.domain
|
|
543
543
|
|
|
544
544
|
@domain.setter
|
|
545
|
-
def domain(self, domain: Domain) -> None:
|
|
545
|
+
def domain(self, domain: Optional[Domain]) -> None:
|
|
546
546
|
self._tracker_store.domain = domain
|
|
547
547
|
|
|
548
548
|
if self._fallback_tracker_store:
|
|
@@ -805,9 +805,7 @@ class AwaitableTrackerStore(TrackerStore):
|
|
|
805
805
|
async def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
|
|
806
806
|
"""Wrapper to call `retrieve` method of primary tracker store."""
|
|
807
807
|
result = self._tracker_store.retrieve(sender_id)
|
|
808
|
-
return (
|
|
809
|
-
await result if isawaitable(result) else result # type: ignore[return-value, misc]
|
|
810
|
-
)
|
|
808
|
+
return await result if isawaitable(result) else result
|
|
811
809
|
|
|
812
810
|
async def keys(self) -> Iterable[Text]:
|
|
813
811
|
"""Wrapper to call `keys` method of primary tracker store."""
|
|
@@ -834,6 +832,4 @@ class AwaitableTrackerStore(TrackerStore):
|
|
|
834
832
|
) -> Optional[DialogueStateTracker]:
|
|
835
833
|
"""Wrapper to call `retrieve_full_tracker` method of primary tracker store."""
|
|
836
834
|
result = self._tracker_store.retrieve_full_tracker(conversation_id)
|
|
837
|
-
return (
|
|
838
|
-
await result if isawaitable(result) else result # type: ignore[return-value, misc]
|
|
839
|
-
)
|
|
835
|
+
return await result if isawaitable(result) else result
|
rasa/core/train.py
CHANGED
|
@@ -46,7 +46,7 @@ async def train_comparison_models(
|
|
|
46
46
|
output=str(Path(output_path, f"run_{r +1}")),
|
|
47
47
|
fixed_model_name=config_name + PERCENTAGE_KEY + str(percentage),
|
|
48
48
|
additional_arguments={
|
|
49
|
-
**additional_arguments,
|
|
49
|
+
**(additional_arguments or {}),
|
|
50
50
|
"exclusion_percentage": percentage,
|
|
51
51
|
},
|
|
52
52
|
)
|