rasa-pro 3.13.0.dev20250612__py3-none-any.whl → 3.13.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/__main__.py +0 -3
- rasa/api.py +1 -1
- rasa/cli/dialogue_understanding_test.py +1 -1
- rasa/cli/e2e_test.py +1 -8
- rasa/cli/evaluate.py +1 -1
- rasa/cli/export.py +3 -1
- rasa/cli/llm_fine_tuning.py +12 -11
- rasa/cli/project_templates/defaults.py +133 -0
- rasa/cli/project_templates/tutorial/config.yml +1 -1
- rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
- rasa/cli/run.py +1 -1
- rasa/cli/studio/download.py +1 -23
- rasa/cli/studio/link.py +52 -0
- rasa/cli/studio/pull.py +79 -0
- rasa/cli/studio/push.py +78 -0
- rasa/cli/studio/studio.py +12 -0
- rasa/cli/studio/train.py +0 -1
- rasa/cli/studio/upload.py +8 -0
- rasa/cli/train.py +1 -1
- rasa/cli/utils.py +1 -1
- rasa/cli/x.py +1 -1
- rasa/constants.py +2 -0
- rasa/core/__init__.py +0 -16
- rasa/core/actions/action.py +5 -1
- rasa/core/actions/action_repeat_bot_messages.py +18 -22
- rasa/core/actions/action_run_slot_rejections.py +0 -1
- rasa/core/agent.py +16 -1
- rasa/core/available_endpoints.py +146 -0
- rasa/core/brokers/pika.py +1 -2
- rasa/core/channels/__init__.py +2 -0
- rasa/core/channels/botframework.py +2 -2
- rasa/core/channels/channel.py +2 -2
- rasa/core/channels/development_inspector.py +1 -1
- rasa/core/channels/facebook.py +1 -4
- rasa/core/channels/hangouts.py +8 -5
- rasa/core/channels/inspector/README.md +3 -3
- rasa/core/channels/inspector/dist/assets/{arc-c4b064fc.js → arc-371401b1.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{blockDiagram-38ab4fdb-215b5026.js → blockDiagram-38ab4fdb-3f126156.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-3d4e48cf-2b54a0a3.js → c4Diagram-3d4e48cf-12f22eb7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/channel-f1efda17.js +1 -0
- rasa/core/channels/inspector/dist/assets/{classDiagram-70f12bd4-daacea5f.js → classDiagram-70f12bd4-03b1d386.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-f2320105-930d4dc2.js → classDiagram-v2-f2320105-84f69d63.js} +1 -1
- rasa/core/channels/inspector/dist/assets/clone-fdf164e2.js +1 -0
- rasa/core/channels/inspector/dist/assets/{createText-2e5e7dd3-83c206ba.js → createText-2e5e7dd3-ca47fd38.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-e0da2a9e-b0eb01d0.js → edges-e0da2a9e-f837ca8a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9861fffd-17586500.js → erDiagram-9861fffd-8717ac54.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-956e92f1-be2a1776.js → flowDb-956e92f1-94f38b83.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-66a62f08-c2120ebd.js → flowDiagram-66a62f08-b616f9fb.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-7d7a1629.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-4a651766-a6ab5c48.js → flowchart-elk-definition-4a651766-f5d24bb8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-c361ad54-ef613457.js → ganttDiagram-c361ad54-b43ba8d9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-72cf32ee-d59185b3.js → gitGraphDiagram-72cf32ee-c3aafaa5.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{graph-0f155405.js → graph-0d0a2c10.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-3862675e-d5f1d1b7.js → index-3862675e-58ea0305.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-47737d3a.js → index-cce6f8a1.js} +3 -3
- rasa/core/channels/inspector/dist/assets/{infoDiagram-f8f76790-b07d141f.js → infoDiagram-f8f76790-b8f60461.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-49397b02-1936d429.js → journeyDiagram-49397b02-95be5545.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-dde8d0f3.js → layout-da885b9b.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-0c2c7ee0.js → line-f1c817d3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-35dd89a4.js → linear-d42801e6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-fc14e90a-56192851.js → mindmap-definition-fc14e90a-a38923a6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-8a3498a8-fc21ed78.js → pieDiagram-8a3498a8-ca6e71e9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-120e2f19-25e98518.js → quadrantDiagram-120e2f19-b290dae9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-deff3bca-546ff1f5.js → requirementDiagram-deff3bca-03f02ceb.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-04a897e0-02d8b82d.js → sankeyDiagram-04a897e0-c49eee40.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-704730f1-3ca5a92e.js → sequenceDiagram-704730f1-b2cd6a3d.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-587899a1-128ea07c.js → stateDiagram-587899a1-e53a2028.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-d93cdb3a-95f290af.js → stateDiagram-v2-d93cdb3a-e1982a03.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-6aaf32cf-4984898a.js → styles-6aaf32cf-d0226ca5.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9a916d00-1bf266ba.js → styles-9a916d00-0e21dc00.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-c10674c1-60521c63.js → styles-c10674c1-9588494e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-08f97a94-a25b6e12.js → svgDrawCommon-08f97a94-be478d4f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-85554ec2-0fc086bf.js → timeline-definition-85554ec2-74631749.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-e933f94c-44ee592e.js → xychartDiagram-e933f94c-a043552f.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/components/RecruitmentPanel.tsx +1 -1
- rasa/core/channels/mattermost.py +1 -1
- rasa/core/channels/rasa_chat.py +2 -4
- rasa/core/channels/rest.py +5 -4
- rasa/core/channels/socketio.py +56 -41
- rasa/core/channels/studio_chat.py +314 -10
- rasa/core/channels/vier_cvg.py +1 -2
- rasa/core/channels/voice_ready/audiocodes.py +2 -9
- rasa/core/channels/voice_stream/asr/azure.py +9 -0
- rasa/core/channels/voice_stream/audiocodes.py +8 -5
- rasa/core/channels/voice_stream/browser_audio.py +1 -1
- rasa/core/channels/voice_stream/genesys.py +2 -2
- rasa/core/channels/voice_stream/jambonz.py +166 -0
- rasa/core/channels/voice_stream/tts/__init__.py +8 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +17 -5
- rasa/core/channels/voice_stream/voice_channel.py +44 -24
- rasa/core/exporter.py +36 -0
- rasa/core/http_interpreter.py +3 -7
- rasa/core/information_retrieval/faiss.py +18 -11
- rasa/core/information_retrieval/ingestion/faq_parser.py +158 -0
- rasa/core/jobs.py +2 -1
- rasa/core/nlg/contextual_response_rephraser.py +48 -12
- rasa/core/nlg/generator.py +0 -1
- rasa/core/nlg/interpolator.py +2 -3
- rasa/core/nlg/summarize.py +39 -5
- rasa/core/policies/enterprise_search_policy.py +298 -184
- rasa/core/policies/enterprise_search_policy_config.py +241 -0
- rasa/core/policies/enterprise_search_prompt_with_relevancy_check_and_citation_template.jinja2 +64 -0
- rasa/core/policies/flow_policy.py +1 -1
- rasa/core/policies/flows/flow_executor.py +96 -17
- rasa/core/policies/intentless_policy.py +71 -26
- rasa/core/processor.py +104 -51
- rasa/core/run.py +33 -11
- rasa/core/tracker_stores/tracker_store.py +1 -1
- rasa/core/training/interactive.py +1 -1
- rasa/core/utils.py +35 -99
- rasa/dialogue_understanding/coexistence/intent_based_router.py +2 -1
- rasa/dialogue_understanding/coexistence/llm_based_router.py +13 -17
- rasa/dialogue_understanding/commands/__init__.py +4 -0
- rasa/dialogue_understanding/commands/can_not_handle_command.py +2 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +6 -2
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +2 -0
- rasa/dialogue_understanding/commands/clarify_command.py +7 -3
- rasa/dialogue_understanding/commands/command_syntax_manager.py +1 -0
- rasa/dialogue_understanding/commands/correct_slots_command.py +5 -6
- rasa/dialogue_understanding/commands/error_command.py +1 -1
- rasa/dialogue_understanding/commands/human_handoff_command.py +3 -3
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +2 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +2 -0
- rasa/dialogue_understanding/commands/set_slot_command.py +15 -5
- rasa/dialogue_understanding/commands/skip_question_command.py +3 -3
- rasa/dialogue_understanding/commands/start_flow_command.py +7 -3
- rasa/dialogue_understanding/commands/utils.py +26 -2
- rasa/dialogue_understanding/generator/__init__.py +7 -1
- rasa/dialogue_understanding/generator/command_generator.py +15 -3
- rasa/dialogue_understanding/generator/command_parser.py +2 -2
- rasa/dialogue_understanding/generator/command_parser_validator.py +63 -0
- rasa/dialogue_understanding/generator/constants.py +2 -2
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +2 -2
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_template.jinja2 +0 -2
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +1 -0
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +1 -0
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v3_claude_3_5_sonnet_20240620_template.jinja2 +79 -0
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v3_gpt_4o_2024_11_20_template.jinja2 +79 -0
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +28 -463
- rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +147 -0
- rasa/dialogue_understanding/generator/single_step/single_step_based_llm_command_generator.py +461 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +11 -64
- rasa/dialogue_understanding/patterns/cancel.py +1 -2
- rasa/dialogue_understanding/patterns/clarify.py +1 -1
- rasa/dialogue_understanding/patterns/correction.py +2 -2
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +37 -25
- rasa/dialogue_understanding/patterns/domain_for_patterns.py +190 -0
- rasa/dialogue_understanding/processor/command_processor.py +11 -12
- rasa/dialogue_understanding/processor/command_processor_component.py +3 -3
- rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +17 -4
- rasa/dialogue_understanding/stack/utils.py +3 -1
- rasa/dialogue_understanding/utils.py +68 -12
- rasa/dialogue_understanding_test/du_test_case.py +1 -1
- rasa/dialogue_understanding_test/du_test_runner.py +4 -22
- rasa/dialogue_understanding_test/test_case_simulation/test_case_tracker_simulator.py +2 -6
- rasa/e2e_test/e2e_test_coverage_report.py +1 -1
- rasa/e2e_test/e2e_test_runner.py +1 -1
- rasa/engine/constants.py +1 -1
- rasa/engine/graph.py +2 -2
- rasa/engine/recipes/default_recipe.py +26 -2
- rasa/engine/validation.py +3 -2
- rasa/hooks.py +0 -28
- rasa/llm_fine_tuning/annotation_module.py +39 -9
- rasa/llm_fine_tuning/conversations.py +3 -0
- rasa/llm_fine_tuning/llm_data_preparation_module.py +66 -49
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +5 -7
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +52 -44
- rasa/llm_fine_tuning/paraphrasing_module.py +10 -12
- rasa/llm_fine_tuning/storage.py +4 -4
- rasa/llm_fine_tuning/utils.py +63 -1
- rasa/model_manager/model_api.py +88 -0
- rasa/model_manager/trainer_service.py +4 -4
- rasa/plugin.py +1 -11
- rasa/privacy/__init__.py +0 -0
- rasa/privacy/constants.py +83 -0
- rasa/privacy/event_broker_utils.py +77 -0
- rasa/privacy/privacy_config.py +281 -0
- rasa/privacy/privacy_config_schema.json +86 -0
- rasa/privacy/privacy_filter.py +340 -0
- rasa/privacy/privacy_manager.py +576 -0
- rasa/server.py +23 -2
- rasa/shared/constants.py +18 -0
- rasa/shared/core/command_payload_reader.py +1 -5
- rasa/shared/core/constants.py +4 -3
- rasa/shared/core/domain.py +7 -0
- rasa/shared/core/events.py +38 -10
- rasa/shared/core/flows/constants.py +2 -0
- rasa/shared/core/flows/flow.py +127 -14
- rasa/shared/core/flows/flows_list.py +18 -1
- rasa/shared/core/flows/flows_yaml_schema.json +3 -0
- rasa/shared/core/flows/steps/collect.py +46 -2
- rasa/shared/core/flows/steps/link.py +7 -2
- rasa/shared/core/flows/validation.py +25 -5
- rasa/shared/core/slots.py +28 -0
- rasa/shared/core/training_data/story_reader/yaml_story_reader.py +1 -4
- rasa/shared/exceptions.py +4 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +6 -2
- rasa/shared/providers/_configs/default_litellm_client_config.py +1 -1
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +1 -1
- rasa/shared/providers/_configs/openai_client_config.py +5 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +1 -1
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -1
- rasa/shared/providers/_configs/utils.py +0 -99
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +3 -0
- rasa/shared/providers/llm/_base_litellm_client.py +5 -2
- rasa/shared/utils/common.py +1 -1
- rasa/shared/utils/configs.py +110 -0
- rasa/shared/utils/constants.py +0 -3
- rasa/shared/utils/llm.py +195 -9
- rasa/shared/utils/pykwalify_extensions.py +0 -9
- rasa/shared/utils/yaml.py +32 -0
- rasa/studio/constants.py +1 -0
- rasa/studio/data_handler.py +11 -4
- rasa/studio/download.py +167 -0
- rasa/studio/link.py +200 -0
- rasa/studio/prompts.py +223 -0
- rasa/studio/pull/__init__.py +0 -0
- rasa/studio/{download/flows.py → pull/data.py} +23 -160
- rasa/studio/{download → pull}/domains.py +1 -1
- rasa/studio/pull/pull.py +235 -0
- rasa/studio/push.py +136 -0
- rasa/studio/train.py +1 -1
- rasa/studio/upload.py +117 -67
- rasa/telemetry.py +82 -25
- rasa/tracing/config.py +3 -4
- rasa/tracing/constants.py +19 -1
- rasa/tracing/instrumentation/attribute_extractors.py +30 -8
- rasa/tracing/instrumentation/instrumentation.py +53 -2
- rasa/tracing/instrumentation/metrics.py +98 -15
- rasa/tracing/metric_instrument_provider.py +75 -3
- rasa/utils/common.py +7 -22
- rasa/utils/log_utils.py +1 -45
- rasa/validator.py +2 -8
- rasa/version.py +1 -1
- {rasa_pro-3.13.0.dev20250612.dist-info → rasa_pro-3.13.0rc1.dist-info}/METADATA +8 -9
- {rasa_pro-3.13.0.dev20250612.dist-info → rasa_pro-3.13.0rc1.dist-info}/RECORD +241 -220
- rasa/anonymization/__init__.py +0 -2
- rasa/anonymization/anonymisation_rule_yaml_reader.py +0 -91
- rasa/anonymization/anonymization_pipeline.py +0 -286
- rasa/anonymization/anonymization_rule_executor.py +0 -266
- rasa/anonymization/anonymization_rule_orchestrator.py +0 -119
- rasa/anonymization/schemas/config.yml +0 -47
- rasa/anonymization/utils.py +0 -118
- rasa/core/channels/inspector/dist/assets/channel-3730f5fd.js +0 -1
- rasa/core/channels/inspector/dist/assets/clone-e847561e.js +0 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-efbbfe00.js +0 -1
- rasa/studio/download/download.py +0 -439
- /rasa/{studio/download → core/information_retrieval/ingestion}/__init__.py +0 -0
- {rasa_pro-3.13.0.dev20250612.dist-info → rasa_pro-3.13.0rc1.dist-info}/NOTICE +0 -0
- {rasa_pro-3.13.0.dev20250612.dist-info → rasa_pro-3.13.0rc1.dist-info}/WHEEL +0 -0
- {rasa_pro-3.13.0.dev20250612.dist-info → rasa_pro-3.13.0rc1.dist-info}/entry_points.txt +0 -0
rasa/shared/utils/llm.py
CHANGED
|
@@ -1,13 +1,17 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import importlib.resources
|
|
2
4
|
import json
|
|
3
5
|
import logging
|
|
4
6
|
from copy import deepcopy
|
|
7
|
+
from datetime import datetime
|
|
5
8
|
from functools import wraps
|
|
6
9
|
from typing import (
|
|
7
10
|
TYPE_CHECKING,
|
|
8
11
|
Any,
|
|
9
12
|
Callable,
|
|
10
13
|
Dict,
|
|
14
|
+
List,
|
|
11
15
|
Literal,
|
|
12
16
|
Optional,
|
|
13
17
|
Text,
|
|
@@ -18,15 +22,23 @@ from typing import (
|
|
|
18
22
|
)
|
|
19
23
|
|
|
20
24
|
import structlog
|
|
25
|
+
from pydantic import BaseModel, Field
|
|
21
26
|
|
|
22
27
|
import rasa.shared.utils.io
|
|
23
|
-
from rasa.core.
|
|
28
|
+
from rasa.core.available_endpoints import AvailableEndpoints
|
|
24
29
|
from rasa.shared.constants import (
|
|
30
|
+
CONFIG_NAME_KEY,
|
|
31
|
+
CONFIG_PIPELINE_KEY,
|
|
32
|
+
CONFIG_POLICIES_KEY,
|
|
25
33
|
DEFAULT_PROMPT_PACKAGE_NAME,
|
|
34
|
+
LLM_CONFIG_KEY,
|
|
26
35
|
MODEL_CONFIG_KEY,
|
|
27
36
|
MODEL_GROUP_CONFIG_KEY,
|
|
28
37
|
MODEL_GROUP_ID_CONFIG_KEY,
|
|
38
|
+
MODEL_GROUPS_CONFIG_KEY,
|
|
29
39
|
MODELS_CONFIG_KEY,
|
|
40
|
+
PROMPT_CONFIG_KEY,
|
|
41
|
+
PROMPT_TEMPLATE_CONFIG_KEY,
|
|
30
42
|
PROVIDER_CONFIG_KEY,
|
|
31
43
|
RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_EMPTY,
|
|
32
44
|
RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_TOO_LONG,
|
|
@@ -61,9 +73,11 @@ from rasa.shared.providers.mappings import (
|
|
|
61
73
|
get_embedding_client_from_provider,
|
|
62
74
|
get_llm_client_from_provider,
|
|
63
75
|
)
|
|
76
|
+
from rasa.shared.utils.common import all_subclasses
|
|
64
77
|
from rasa.shared.utils.constants import LOG_COMPONENT_SOURCE_METHOD_INIT
|
|
65
78
|
|
|
66
79
|
if TYPE_CHECKING:
|
|
80
|
+
from rasa.core.agent import Agent
|
|
67
81
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
68
82
|
|
|
69
83
|
|
|
@@ -73,13 +87,15 @@ USER = "USER"
|
|
|
73
87
|
|
|
74
88
|
AI = "AI"
|
|
75
89
|
|
|
76
|
-
DEFAULT_OPENAI_GENERATE_MODEL_NAME = "gpt-
|
|
90
|
+
DEFAULT_OPENAI_GENERATE_MODEL_NAME = "gpt-4o-2024-11-20"
|
|
91
|
+
|
|
92
|
+
DEFAULT_OPENAI_CHAT_MODEL_NAME = "gpt-4o-2024-11-20"
|
|
77
93
|
|
|
78
|
-
|
|
94
|
+
DEFAULT_ENTERPRISE_SEARCH_POLICY_MODEL_NAME = "gpt-4.1-mini-2025-04-14"
|
|
79
95
|
|
|
80
96
|
DEFAULT_OPENAI_CHAT_MODEL_NAME_ADVANCED = "gpt-4-0613"
|
|
81
97
|
|
|
82
|
-
DEFAULT_OPENAI_EMBEDDING_MODEL_NAME = "text-embedding-
|
|
98
|
+
DEFAULT_OPENAI_EMBEDDING_MODEL_NAME = "text-embedding-3-large"
|
|
83
99
|
|
|
84
100
|
DEFAULT_OPENAI_TEMPERATURE = 0.7
|
|
85
101
|
|
|
@@ -107,6 +123,18 @@ _CombineConfigs_F = TypeVar(
|
|
|
107
123
|
)
|
|
108
124
|
|
|
109
125
|
|
|
126
|
+
class SystemPrompts(BaseModel):
|
|
127
|
+
command_generator: str = Field(
|
|
128
|
+
..., description="Prompt used by the LLM command generator."
|
|
129
|
+
)
|
|
130
|
+
enterprise_search: str = Field(
|
|
131
|
+
..., description="Prompt for standard enterprise search requests."
|
|
132
|
+
)
|
|
133
|
+
contextual_response_rephraser: str = Field(
|
|
134
|
+
..., description="Prompt used for re-phrasing assistant responses."
|
|
135
|
+
)
|
|
136
|
+
|
|
137
|
+
|
|
110
138
|
def _compute_hash_for_cache_from_configs(
|
|
111
139
|
config_x: Dict[str, Any], config_y: Dict[str, Any]
|
|
112
140
|
) -> int:
|
|
@@ -191,6 +219,7 @@ def tracker_as_readable_transcript(
|
|
|
191
219
|
human_prefix: str = USER,
|
|
192
220
|
ai_prefix: str = AI,
|
|
193
221
|
max_turns: Optional[int] = 20,
|
|
222
|
+
turns_wrapper: Optional[Callable[[List[str]], List[str]]] = None,
|
|
194
223
|
) -> str:
|
|
195
224
|
"""Creates a readable dialogue from a tracker.
|
|
196
225
|
|
|
@@ -199,6 +228,7 @@ def tracker_as_readable_transcript(
|
|
|
199
228
|
human_prefix: the prefix to use for human utterances
|
|
200
229
|
ai_prefix: the prefix to use for ai utterances
|
|
201
230
|
max_turns: the maximum number of turns to include in the transcript
|
|
231
|
+
turns_wrapper: optional function to wrap the turns in a custom way
|
|
202
232
|
|
|
203
233
|
Example:
|
|
204
234
|
>>> tracker = Tracker(
|
|
@@ -235,8 +265,11 @@ def tracker_as_readable_transcript(
|
|
|
235
265
|
elif isinstance(event, BotUttered):
|
|
236
266
|
transcript.append(f"{ai_prefix}: {sanitize_message_for_prompt(event.text)}")
|
|
237
267
|
|
|
238
|
-
|
|
239
|
-
|
|
268
|
+
# turns_wrapper to count multiple utterances by bot/user as single turn
|
|
269
|
+
if turns_wrapper:
|
|
270
|
+
transcript = turns_wrapper(transcript)
|
|
271
|
+
# otherwise, just take the last `max_turns` lines of the transcript
|
|
272
|
+
transcript = transcript[-max_turns if max_turns is not None else None :]
|
|
240
273
|
|
|
241
274
|
return "\n".join(transcript)
|
|
242
275
|
|
|
@@ -678,7 +711,6 @@ def get_prompt_template(
|
|
|
678
711
|
Returns:
|
|
679
712
|
The prompt template.
|
|
680
713
|
"""
|
|
681
|
-
|
|
682
714
|
try:
|
|
683
715
|
if jinja_file_path is not None:
|
|
684
716
|
prompt_template = rasa.shared.utils.io.read_file(jinja_file_path)
|
|
@@ -814,7 +846,9 @@ def allowed_values_for_slot(slot: Slot) -> Union[str, None]:
|
|
|
814
846
|
|
|
815
847
|
|
|
816
848
|
def resolve_model_client_config(
|
|
817
|
-
model_config: Optional[Dict[str, Any]],
|
|
849
|
+
model_config: Optional[Dict[str, Any]],
|
|
850
|
+
component_name: Optional[str] = None,
|
|
851
|
+
model_groups: Optional[List[Dict[str, Any]]] = None,
|
|
818
852
|
) -> Optional[Dict[str, Any]]:
|
|
819
853
|
"""Resolve the model group in the model config.
|
|
820
854
|
|
|
@@ -828,6 +862,7 @@ def resolve_model_client_config(
|
|
|
828
862
|
model_config: The model config to be resolved.
|
|
829
863
|
component_name: The name of the component.
|
|
830
864
|
component_name: The method of the component.
|
|
865
|
+
model_groups: Model groups from endpoints.yml.
|
|
831
866
|
|
|
832
867
|
Returns:
|
|
833
868
|
The resolved llm config.
|
|
@@ -854,7 +889,12 @@ def resolve_model_client_config(
|
|
|
854
889
|
|
|
855
890
|
model_group_id = model_config.get(MODEL_GROUP_CONFIG_KEY)
|
|
856
891
|
|
|
857
|
-
|
|
892
|
+
# If `model_groups` is provided, use it to initialise `AvailableEndpoints`,
|
|
893
|
+
# since `get_instance()` reads from the local endpoints file instead.
|
|
894
|
+
if model_groups:
|
|
895
|
+
endpoints = AvailableEndpoints(model_groups=model_groups)
|
|
896
|
+
else:
|
|
897
|
+
endpoints = AvailableEndpoints.get_instance()
|
|
858
898
|
if endpoints.model_groups is None:
|
|
859
899
|
_raise_invalid_config_exception(
|
|
860
900
|
reason=(
|
|
@@ -886,3 +926,149 @@ def resolve_model_client_config(
|
|
|
886
926
|
)
|
|
887
927
|
|
|
888
928
|
return model_group[0]
|
|
929
|
+
|
|
930
|
+
|
|
931
|
+
def generate_sender_id(test_case_name: str) -> str:
|
|
932
|
+
# add timestamp suffix to ensure sender_id is unique
|
|
933
|
+
return f"{test_case_name}_{datetime.now()}"
|
|
934
|
+
|
|
935
|
+
|
|
936
|
+
async def create_tracker_for_user_step(
|
|
937
|
+
step_sender_id: str,
|
|
938
|
+
agent: "Agent",
|
|
939
|
+
test_case_tracker: "DialogueStateTracker",
|
|
940
|
+
index_user_uttered_event: int,
|
|
941
|
+
) -> None:
|
|
942
|
+
"""Creates a tracker for the user step."""
|
|
943
|
+
tracker = test_case_tracker.copy()
|
|
944
|
+
# modify the sender id so that the original tracker is not overwritten
|
|
945
|
+
tracker.sender_id = step_sender_id
|
|
946
|
+
|
|
947
|
+
if tracker.events:
|
|
948
|
+
# get the timestamp of the event just before the user uttered event
|
|
949
|
+
timestamp = tracker.events[index_user_uttered_event - 1].timestamp
|
|
950
|
+
# revert the tracker to the event just before the user uttered event
|
|
951
|
+
tracker = tracker.travel_back_in_time(timestamp)
|
|
952
|
+
|
|
953
|
+
# store the tracker with the unique sender id
|
|
954
|
+
await agent.tracker_store.save(tracker)
|
|
955
|
+
|
|
956
|
+
|
|
957
|
+
def check_prompt_config_keys_and_warn_if_deprecated(
|
|
958
|
+
config: dict, component_source: str
|
|
959
|
+
) -> None:
|
|
960
|
+
"""Checks and warns about deprecated config parameters."""
|
|
961
|
+
if PROMPT_CONFIG_KEY in config and PROMPT_TEMPLATE_CONFIG_KEY in config:
|
|
962
|
+
structlogger.warning(
|
|
963
|
+
f"{component_source}.init"
|
|
964
|
+
".both_deprecated_and_non_deprecated_config_keys_used_at_the_same_time",
|
|
965
|
+
event_info=(
|
|
966
|
+
f"Both '{PROMPT_CONFIG_KEY}' and '{PROMPT_TEMPLATE_CONFIG_KEY}' "
|
|
967
|
+
f"are present in the config. '{PROMPT_CONFIG_KEY}' will be ignored "
|
|
968
|
+
f"in favor of {PROMPT_TEMPLATE_CONFIG_KEY}."
|
|
969
|
+
),
|
|
970
|
+
)
|
|
971
|
+
|
|
972
|
+
# 'prompt' config key is deprecated in favor of 'prompt_template'
|
|
973
|
+
if PROMPT_CONFIG_KEY in config:
|
|
974
|
+
structlogger.warning(
|
|
975
|
+
f"{component_source}.init.deprecated_config_key",
|
|
976
|
+
event_info=(
|
|
977
|
+
f"The config parameter '{PROMPT_CONFIG_KEY}' is deprecated "
|
|
978
|
+
"and will be removed in Rasa 4.0.0. "
|
|
979
|
+
f"Please use the config parameter '{PROMPT_TEMPLATE_CONFIG_KEY}'"
|
|
980
|
+
f" instead. "
|
|
981
|
+
),
|
|
982
|
+
)
|
|
983
|
+
|
|
984
|
+
|
|
985
|
+
def _get_llm_command_generator_config(
|
|
986
|
+
config: Dict[Text, Any],
|
|
987
|
+
) -> Optional[Dict[Text, Any]]:
|
|
988
|
+
"""Get the llm command generator config from config.yml.
|
|
989
|
+
|
|
990
|
+
Args:
|
|
991
|
+
config: The config.yml file data.
|
|
992
|
+
|
|
993
|
+
Returns:
|
|
994
|
+
The llm command generator config.
|
|
995
|
+
"""
|
|
996
|
+
from rasa.dialogue_understanding.generator import LLMBasedCommandGenerator
|
|
997
|
+
|
|
998
|
+
# Collect all LLM based Command Generator class names.
|
|
999
|
+
command_generator_subclasses = all_subclasses(LLMBasedCommandGenerator)
|
|
1000
|
+
command_generator_class_names = [
|
|
1001
|
+
command_generator.__name__ for command_generator in command_generator_subclasses
|
|
1002
|
+
]
|
|
1003
|
+
|
|
1004
|
+
# Read the LLM config of the Command Generator from the config.yml file.
|
|
1005
|
+
pipelines = config.get(CONFIG_PIPELINE_KEY, [])
|
|
1006
|
+
for pipeline in pipelines:
|
|
1007
|
+
if pipeline.get(CONFIG_NAME_KEY) in command_generator_class_names:
|
|
1008
|
+
return pipeline.get(LLM_CONFIG_KEY)
|
|
1009
|
+
|
|
1010
|
+
return None
|
|
1011
|
+
|
|
1012
|
+
|
|
1013
|
+
def _get_command_generator_prompt(
|
|
1014
|
+
config: Dict[Text, Any], endpoints: Dict[Text, Any]
|
|
1015
|
+
) -> Text:
|
|
1016
|
+
"""Get the command generator prompt based on the config."""
|
|
1017
|
+
from rasa.dialogue_understanding.generator.single_step.compact_llm_command_generator import ( # noqa: E501
|
|
1018
|
+
DEFAULT_COMMAND_PROMPT_TEMPLATE_FILE_NAME,
|
|
1019
|
+
FALLBACK_COMMAND_PROMPT_TEMPLATE_FILE_NAME,
|
|
1020
|
+
MODEL_PROMPT_MAPPER,
|
|
1021
|
+
)
|
|
1022
|
+
|
|
1023
|
+
model_config = _get_llm_command_generator_config(config)
|
|
1024
|
+
llm_config = resolve_model_client_config(
|
|
1025
|
+
model_config=model_config,
|
|
1026
|
+
model_groups=endpoints.get(MODEL_GROUPS_CONFIG_KEY),
|
|
1027
|
+
)
|
|
1028
|
+
return get_default_prompt_template_based_on_model(
|
|
1029
|
+
llm_config=llm_config or {},
|
|
1030
|
+
model_prompt_mapping=MODEL_PROMPT_MAPPER,
|
|
1031
|
+
default_prompt_path=DEFAULT_COMMAND_PROMPT_TEMPLATE_FILE_NAME,
|
|
1032
|
+
fallback_prompt_path=FALLBACK_COMMAND_PROMPT_TEMPLATE_FILE_NAME,
|
|
1033
|
+
)
|
|
1034
|
+
|
|
1035
|
+
|
|
1036
|
+
def _get_enterprise_search_prompt(config: Dict[Text, Any]) -> Text:
|
|
1037
|
+
"""Get the enterprise search prompt based on the config."""
|
|
1038
|
+
from rasa.core.policies.enterprise_search_policy import EnterpriseSearchPolicy
|
|
1039
|
+
|
|
1040
|
+
def get_enterprise_search_config() -> Dict[Text, Any]:
|
|
1041
|
+
policies = config.get(CONFIG_POLICIES_KEY, [])
|
|
1042
|
+
for policy in policies:
|
|
1043
|
+
if policy.get(CONFIG_NAME_KEY) == EnterpriseSearchPolicy.__name__:
|
|
1044
|
+
return policy
|
|
1045
|
+
|
|
1046
|
+
return {}
|
|
1047
|
+
|
|
1048
|
+
enterprise_search_config = get_enterprise_search_config()
|
|
1049
|
+
return EnterpriseSearchPolicy.get_system_default_prompt_based_on_config(
|
|
1050
|
+
enterprise_search_config
|
|
1051
|
+
)
|
|
1052
|
+
|
|
1053
|
+
|
|
1054
|
+
def get_system_default_prompts(
|
|
1055
|
+
config: Dict[Text, Any], endpoints: Dict[Text, Any]
|
|
1056
|
+
) -> SystemPrompts:
|
|
1057
|
+
"""Returns the system default prompts for the component.
|
|
1058
|
+
|
|
1059
|
+
Args:
|
|
1060
|
+
config: The config.yml file data.
|
|
1061
|
+
endpoints: The endpoints.yml file data.
|
|
1062
|
+
|
|
1063
|
+
Returns:
|
|
1064
|
+
SystemPrompts: A Pydantic model containing all default prompts.
|
|
1065
|
+
"""
|
|
1066
|
+
from rasa.core.nlg.contextual_response_rephraser import (
|
|
1067
|
+
DEFAULT_RESPONSE_VARIATION_PROMPT_TEMPLATE,
|
|
1068
|
+
)
|
|
1069
|
+
|
|
1070
|
+
return SystemPrompts(
|
|
1071
|
+
command_generator=_get_command_generator_prompt(config, endpoints),
|
|
1072
|
+
enterprise_search=_get_enterprise_search_prompt(config),
|
|
1073
|
+
contextual_response_rephraser=DEFAULT_RESPONSE_VARIATION_PROMPT_TEMPLATE,
|
|
1074
|
+
)
|
|
@@ -8,11 +8,6 @@ from typing import Any, Dict, List, Text, Union
|
|
|
8
8
|
|
|
9
9
|
from pykwalify.errors import SchemaError
|
|
10
10
|
|
|
11
|
-
from rasa.shared.utils.constants import (
|
|
12
|
-
RASA_PRO_BETA_PREDICATES_IN_RESPONSE_CONDITIONS_ENV_VAR_NAME,
|
|
13
|
-
)
|
|
14
|
-
from rasa.utils.beta import ensure_beta_feature_is_enabled
|
|
15
|
-
|
|
16
11
|
|
|
17
12
|
def require_response_keys(
|
|
18
13
|
responses: List[Dict[Text, Any]], _: Dict, __: Text
|
|
@@ -31,10 +26,6 @@ def require_response_keys(
|
|
|
31
26
|
|
|
32
27
|
conditions = response.get("condition", [])
|
|
33
28
|
if isinstance(conditions, str):
|
|
34
|
-
ensure_beta_feature_is_enabled(
|
|
35
|
-
"predicates in response conditions",
|
|
36
|
-
RASA_PRO_BETA_PREDICATES_IN_RESPONSE_CONDITIONS_ENV_VAR_NAME,
|
|
37
|
-
)
|
|
38
29
|
continue
|
|
39
30
|
|
|
40
31
|
for condition in conditions:
|
rasa/shared/utils/yaml.py
CHANGED
|
@@ -21,6 +21,7 @@ from ruamel.yaml import YAML, RoundTripRepresenter, YAMLError
|
|
|
21
21
|
from ruamel.yaml.comments import CommentedMap, CommentedSeq
|
|
22
22
|
from ruamel.yaml.constructor import BaseConstructor, DuplicateKeyError, ScalarNode
|
|
23
23
|
from ruamel.yaml.loader import SafeLoader
|
|
24
|
+
from ruamel.yaml.scalarstring import LiteralScalarString
|
|
24
25
|
|
|
25
26
|
from rasa.shared.constants import (
|
|
26
27
|
ASSERTIONS_SCHEMA_EXTENSIONS_FILE,
|
|
@@ -794,6 +795,25 @@ def write_yaml(
|
|
|
794
795
|
should_preserve_key_order: Whether to force preserve key order in `data`.
|
|
795
796
|
transform: A function to transform the data before writing it to the file.
|
|
796
797
|
"""
|
|
798
|
+
|
|
799
|
+
def multiline_str_representer(self: Any, value: str) -> Any:
|
|
800
|
+
"""Dump multi-line strings as readable YAML block scalars where possible."""
|
|
801
|
+
if "\n" in value:
|
|
802
|
+
# First line after the newline decides: paragraph vs. snippet
|
|
803
|
+
first_line = value.split("\n", 1)[1]
|
|
804
|
+
|
|
805
|
+
# If the first line after the newline is not indented, treat the value
|
|
806
|
+
# as plain text. Indented text is likely pre-formatted YAML/JSON/etc.
|
|
807
|
+
if not first_line.startswith((" ", "\t")):
|
|
808
|
+
return self.represent_scalar(
|
|
809
|
+
"tag:yaml.org,2002:str",
|
|
810
|
+
LiteralScalarString(value),
|
|
811
|
+
style="|",
|
|
812
|
+
)
|
|
813
|
+
|
|
814
|
+
# Fallback: keep default YAML scalar style (plain/quoted)
|
|
815
|
+
return self.represent_scalar("tag:yaml.org,2002:str", value)
|
|
816
|
+
|
|
797
817
|
_enable_ordered_dict_yaml_dumping()
|
|
798
818
|
|
|
799
819
|
if should_preserve_key_order:
|
|
@@ -808,6 +828,7 @@ def write_yaml(
|
|
|
808
828
|
type(None),
|
|
809
829
|
lambda self, _: self.represent_scalar("tag:yaml.org,2002:null", "null"),
|
|
810
830
|
)
|
|
831
|
+
dumper.representer.add_representer(str, multiline_str_representer)
|
|
811
832
|
|
|
812
833
|
if isinstance(target, StringIO):
|
|
813
834
|
dumper.dump(data, target, transform=transform)
|
|
@@ -1025,6 +1046,17 @@ def validate_yaml_with_jsonschema(
|
|
|
1025
1046
|
except (YAMLError, DuplicateKeyError) as e:
|
|
1026
1047
|
raise YamlSyntaxException(underlying_yaml_exception=e)
|
|
1027
1048
|
|
|
1049
|
+
validate_data_with_jsonschema(source_data, schema_content, humanize_error)
|
|
1050
|
+
|
|
1051
|
+
|
|
1052
|
+
def validate_data_with_jsonschema(
|
|
1053
|
+
source_data: Any,
|
|
1054
|
+
schema_content: Any,
|
|
1055
|
+
humanize_error: Callable[
|
|
1056
|
+
[jsonschema.ValidationError], str
|
|
1057
|
+
] = default_error_humanizer,
|
|
1058
|
+
) -> None:
|
|
1059
|
+
"""Validate Python object against the provided jsonschema content."""
|
|
1028
1060
|
try:
|
|
1029
1061
|
jsonschema.validate(source_data, schema_content)
|
|
1030
1062
|
except jsonschema.ValidationError as error:
|
rasa/studio/constants.py
CHANGED
|
@@ -14,6 +14,7 @@ RASA_STUDIO_CLI_DISABLE_VERIFY_KEY_ENV = "RASA_STUDIO_CLI_DISABLE_VERIFY_KEY"
|
|
|
14
14
|
|
|
15
15
|
STUDIO_NLU_FILENAME = "studio_nlu.yml"
|
|
16
16
|
STUDIO_DOMAIN_FILENAME = "studio_domain.yml"
|
|
17
|
+
DOMAIN_FILENAME = "domain.yml"
|
|
17
18
|
STUDIO_FLOWS_FILENAME = "studio_flows.yml"
|
|
18
19
|
STUDIO_CONFIG_FILENAME = "studio_config.yml"
|
|
19
20
|
STUDIO_ENDPOINTS_FILENAME = "studio_endpoints.yml"
|
rasa/studio/data_handler.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import base64
|
|
2
|
+
import json
|
|
2
3
|
import logging
|
|
3
4
|
from pathlib import Path
|
|
4
5
|
from typing import Any, Dict, List, Optional, Tuple
|
|
@@ -50,7 +51,7 @@ class StudioDataHandler:
|
|
|
50
51
|
"query ExportAsEncodedYaml($input: ExportAsEncodedYamlInput!) "
|
|
51
52
|
"{ exportAsEncodedYaml(input: $input) "
|
|
52
53
|
"{ ... on ExportModernAsEncodedYamlOutput "
|
|
53
|
-
"{ nlu flows domain endpoints config } "
|
|
54
|
+
"{ nlu flows domain endpoints config prompts } "
|
|
54
55
|
"... on ExportClassicAsEncodedYamlOutput "
|
|
55
56
|
"{ nlu domain }}}"
|
|
56
57
|
),
|
|
@@ -161,6 +162,9 @@ class StudioDataHandler:
|
|
|
161
162
|
def get_endpoints(self) -> Optional[str]:
|
|
162
163
|
return self.endpoints
|
|
163
164
|
|
|
165
|
+
def get_prompts(self) -> Optional[dict]:
|
|
166
|
+
return self.prompts
|
|
167
|
+
|
|
164
168
|
def _validate_response(self, response: dict) -> bool:
|
|
165
169
|
"""Validates the response from Rasa Studio.
|
|
166
170
|
|
|
@@ -200,6 +204,9 @@ class StudioDataHandler:
|
|
|
200
204
|
self.config = self._decode_response(return_data.get("config"))
|
|
201
205
|
self.endpoints = self._decode_response(return_data.get("endpoints"))
|
|
202
206
|
|
|
207
|
+
prompts_string = self._decode_response(return_data.get("prompts"))
|
|
208
|
+
self.prompts = json.loads(prompts_string) if prompts_string else None
|
|
209
|
+
|
|
203
210
|
if not self.has_nlu() and not self.has_flows():
|
|
204
211
|
raise RasaException("No nlu or flows data in Studio response.")
|
|
205
212
|
|
|
@@ -320,14 +327,14 @@ def create_new_flows_from_diff(
|
|
|
320
327
|
|
|
321
328
|
|
|
322
329
|
def import_data_from_studio(
|
|
323
|
-
handler: StudioDataHandler, domain_path: Path,
|
|
330
|
+
handler: StudioDataHandler, domain_path: Path, data_path: Path
|
|
324
331
|
) -> Tuple[TrainingDataImporter, TrainingDataImporter]:
|
|
325
332
|
"""Construct TrainingDataImporter from Studio data and original data.
|
|
326
333
|
|
|
327
334
|
Args:
|
|
328
335
|
handler (StudioDataHandler): handler with data from studio
|
|
329
336
|
domain_path (Path): Path to a domain file
|
|
330
|
-
|
|
337
|
+
data_path (List[Path]): List of paths to training data files
|
|
331
338
|
|
|
332
339
|
Returns:
|
|
333
340
|
Tuple[TrainingDataImporter, TrainingDataImporter]:
|
|
@@ -335,7 +342,7 @@ def import_data_from_studio(
|
|
|
335
342
|
"""
|
|
336
343
|
tmp_dir = get_temp_dir_name()
|
|
337
344
|
data_original = TrainingDataImporter.load_from_dict(
|
|
338
|
-
domain_path=domain_path, training_data_paths=
|
|
345
|
+
domain_path=str(domain_path), training_data_paths=[str(data_path)]
|
|
339
346
|
)
|
|
340
347
|
|
|
341
348
|
data_paths = []
|
rasa/studio/download.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import shutil
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Dict
|
|
5
|
+
|
|
6
|
+
import questionary
|
|
7
|
+
import structlog
|
|
8
|
+
from ruamel import yaml
|
|
9
|
+
from ruamel.yaml.scalarstring import LiteralScalarString
|
|
10
|
+
|
|
11
|
+
import rasa.cli.utils
|
|
12
|
+
import rasa.shared.utils.cli
|
|
13
|
+
from rasa.shared.constants import (
|
|
14
|
+
DEFAULT_CONFIG_PATH,
|
|
15
|
+
DEFAULT_DATA_PATH,
|
|
16
|
+
DEFAULT_ENDPOINTS_PATH,
|
|
17
|
+
)
|
|
18
|
+
from rasa.shared.core.flows.yaml_flows_io import FlowsList
|
|
19
|
+
from rasa.shared.nlu.training_data.training_data import (
|
|
20
|
+
DEFAULT_TRAINING_DATA_OUTPUT_PATH,
|
|
21
|
+
)
|
|
22
|
+
from rasa.shared.utils.yaml import read_yaml, write_yaml
|
|
23
|
+
from rasa.studio.config import StudioConfig
|
|
24
|
+
from rasa.studio.constants import DOMAIN_FILENAME
|
|
25
|
+
from rasa.studio.data_handler import StudioDataHandler
|
|
26
|
+
from rasa.studio.prompts import handle_prompts
|
|
27
|
+
from rasa.studio.pull.data import _dump_flows_as_separate_files
|
|
28
|
+
|
|
29
|
+
structlogger = structlog.get_logger()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def handle_download(args: argparse.Namespace) -> None:
|
|
33
|
+
"""Download an assistant from Studio and store it in `<assistant_name>/`.
|
|
34
|
+
|
|
35
|
+
Args:
|
|
36
|
+
args: The command line arguments.
|
|
37
|
+
"""
|
|
38
|
+
assistant_name = args.assistant_name
|
|
39
|
+
target_root = _prepare_target_directory(assistant_name)
|
|
40
|
+
|
|
41
|
+
handler = StudioDataHandler(
|
|
42
|
+
studio_config=StudioConfig.read_config(), assistant_name=assistant_name
|
|
43
|
+
)
|
|
44
|
+
handler.request_all_data()
|
|
45
|
+
|
|
46
|
+
_handle_config(handler, target_root)
|
|
47
|
+
_handle_endpoints(handler, target_root)
|
|
48
|
+
_handle_domain(handler, target_root)
|
|
49
|
+
_handle_data(handler, target_root)
|
|
50
|
+
handle_prompts(handler, target_root)
|
|
51
|
+
|
|
52
|
+
structlogger.info(
|
|
53
|
+
"studio.download.success",
|
|
54
|
+
event_info=f"Downloaded assistant '{assistant_name}' from Studio.",
|
|
55
|
+
assistant_name=assistant_name,
|
|
56
|
+
)
|
|
57
|
+
rasa.shared.utils.cli.print_success(
|
|
58
|
+
f"Downloaded assistant '{assistant_name}' from Studio."
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _prepare_target_directory(assistant_name: str) -> Path:
|
|
63
|
+
"""Create (or overwrite) the directory where everything is stored.
|
|
64
|
+
|
|
65
|
+
Args:
|
|
66
|
+
assistant_name: The name of the assistant to download.
|
|
67
|
+
|
|
68
|
+
Returns:
|
|
69
|
+
The path to the target directory where the assistant will be stored.
|
|
70
|
+
"""
|
|
71
|
+
target_root = Path(assistant_name)
|
|
72
|
+
|
|
73
|
+
if target_root.exists():
|
|
74
|
+
overwrite = questionary.confirm(
|
|
75
|
+
f"Directory '{assistant_name}' already exists. Overwrite it?"
|
|
76
|
+
).ask()
|
|
77
|
+
if not overwrite:
|
|
78
|
+
rasa.shared.utils.cli.print_error_and_exit("Download cancelled.")
|
|
79
|
+
|
|
80
|
+
shutil.rmtree(target_root)
|
|
81
|
+
|
|
82
|
+
target_root.mkdir(parents=True, exist_ok=True)
|
|
83
|
+
return target_root
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _handle_config(handler: StudioDataHandler, root: Path) -> None:
|
|
87
|
+
"""Download and persist the assistant’s config file.
|
|
88
|
+
|
|
89
|
+
Args:
|
|
90
|
+
handler: The data handler to retrieve the config from.
|
|
91
|
+
root: The root directory where the config file will be stored.
|
|
92
|
+
"""
|
|
93
|
+
config_data = handler.get_config()
|
|
94
|
+
if not config_data:
|
|
95
|
+
rasa.shared.utils.cli.print_error_and_exit("No config data found.")
|
|
96
|
+
|
|
97
|
+
config_path = root / DEFAULT_CONFIG_PATH
|
|
98
|
+
config_path.write_text(config_data, encoding="utf-8")
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _handle_endpoints(handler: StudioDataHandler, root: Path) -> None:
|
|
102
|
+
"""Download and persist the assistant’s endpoints file.
|
|
103
|
+
|
|
104
|
+
Args:
|
|
105
|
+
handler: The data handler to retrieve the endpoints from.
|
|
106
|
+
root: The root directory where the endpoints file will be stored.
|
|
107
|
+
"""
|
|
108
|
+
endpoints_data = handler.get_endpoints()
|
|
109
|
+
if not endpoints_data:
|
|
110
|
+
rasa.shared.utils.cli.print_error_and_exit("No endpoints data found.")
|
|
111
|
+
|
|
112
|
+
endpoints_path = root / DEFAULT_ENDPOINTS_PATH
|
|
113
|
+
endpoints_path.write_text(endpoints_data, encoding="utf-8")
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _handle_domain(handler: StudioDataHandler, root: Path) -> None:
|
|
117
|
+
"""Persist the assistant’s domain file.
|
|
118
|
+
|
|
119
|
+
Args:
|
|
120
|
+
handler: The data handler to retrieve the domain from.
|
|
121
|
+
root: The root directory where the domain file will be stored.
|
|
122
|
+
"""
|
|
123
|
+
domain_yaml = handler.domain
|
|
124
|
+
data = read_yaml(domain_yaml)
|
|
125
|
+
target = root / DOMAIN_FILENAME
|
|
126
|
+
write_yaml(
|
|
127
|
+
data=data,
|
|
128
|
+
target=target,
|
|
129
|
+
should_preserve_key_order=True,
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def _handle_data(handler: StudioDataHandler, root: Path) -> None:
|
|
134
|
+
"""Persist NLU data and flows.
|
|
135
|
+
|
|
136
|
+
Args:
|
|
137
|
+
handler: The data handler to retrieve the NLU data and flows from.
|
|
138
|
+
root: The root directory where the NLU data and flows will be stored.
|
|
139
|
+
"""
|
|
140
|
+
data_path = root / DEFAULT_DATA_PATH
|
|
141
|
+
data_path.mkdir(parents=True, exist_ok=True)
|
|
142
|
+
|
|
143
|
+
if handler.has_nlu():
|
|
144
|
+
nlu_yaml = handler.nlu
|
|
145
|
+
nlu_data = read_yaml(nlu_yaml)
|
|
146
|
+
if nlu_data.get("nlu"):
|
|
147
|
+
pretty_write_nlu_yaml(
|
|
148
|
+
nlu_data, data_path / DEFAULT_TRAINING_DATA_OUTPUT_PATH
|
|
149
|
+
)
|
|
150
|
+
|
|
151
|
+
if handler.has_flows():
|
|
152
|
+
flows_yaml = handler.flows
|
|
153
|
+
data = read_yaml(flows_yaml)
|
|
154
|
+
flows_data = data.get("flows", {})
|
|
155
|
+
flows_list = FlowsList.from_json(flows_data)
|
|
156
|
+
_dump_flows_as_separate_files(flows_list.underlying_flows, data_path)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def pretty_write_nlu_yaml(data: Dict, file: Path) -> None:
|
|
160
|
+
"""Writes the NLU YAML in a pretty way."""
|
|
161
|
+
dumper = yaml.YAML()
|
|
162
|
+
if nlu_data := data.get("nlu"):
|
|
163
|
+
for item in nlu_data:
|
|
164
|
+
if item.get("examples"):
|
|
165
|
+
item["examples"] = LiteralScalarString(item["examples"])
|
|
166
|
+
with file.open("w", encoding="utf-8") as outfile:
|
|
167
|
+
dumper.dump(data, outfile)
|