rasa-pro 3.14.0.dev20250922__py3-none-any.whl → 3.14.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/__main__.py +15 -3
- rasa/agents/__init__.py +0 -0
- rasa/agents/agent_factory.py +122 -0
- rasa/agents/agent_manager.py +211 -0
- rasa/agents/constants.py +43 -0
- rasa/agents/core/__init__.py +0 -0
- rasa/agents/core/agent_protocol.py +107 -0
- rasa/agents/core/types.py +81 -0
- rasa/agents/exceptions.py +38 -0
- rasa/agents/protocol/__init__.py +5 -0
- rasa/agents/protocol/a2a/__init__.py +0 -0
- rasa/agents/protocol/a2a/a2a_agent.py +879 -0
- rasa/agents/protocol/mcp/__init__.py +0 -0
- rasa/agents/protocol/mcp/mcp_base_agent.py +726 -0
- rasa/agents/protocol/mcp/mcp_open_agent.py +327 -0
- rasa/agents/protocol/mcp/mcp_task_agent.py +522 -0
- rasa/agents/schemas/__init__.py +13 -0
- rasa/agents/schemas/agent_input.py +38 -0
- rasa/agents/schemas/agent_output.py +26 -0
- rasa/agents/schemas/agent_tool_result.py +65 -0
- rasa/agents/schemas/agent_tool_schema.py +186 -0
- rasa/agents/templates/__init__.py +0 -0
- rasa/agents/templates/mcp_open_agent_prompt_template.jinja2 +20 -0
- rasa/agents/templates/mcp_task_agent_prompt_template.jinja2 +22 -0
- rasa/agents/utils.py +206 -0
- rasa/agents/validation.py +485 -0
- rasa/api.py +24 -9
- rasa/builder/config.py +6 -2
- rasa/builder/guardrails/{lakera.py → clients.py} +55 -5
- rasa/builder/guardrails/constants.py +3 -0
- rasa/builder/guardrails/models.py +45 -10
- rasa/builder/guardrails/policy_checker.py +324 -0
- rasa/builder/guardrails/utils.py +42 -276
- rasa/builder/llm_service.py +32 -5
- rasa/builder/models.py +1 -0
- rasa/builder/project_generator.py +6 -1
- rasa/builder/service.py +16 -13
- rasa/builder/training_service.py +18 -24
- rasa/builder/validation_service.py +1 -1
- rasa/cli/arguments/default_arguments.py +12 -0
- rasa/cli/arguments/run.py +2 -0
- rasa/cli/arguments/train.py +2 -0
- rasa/cli/data.py +10 -8
- rasa/cli/dialogue_understanding_test.py +10 -7
- rasa/cli/e2e_test.py +9 -6
- rasa/cli/evaluate.py +4 -2
- rasa/cli/export.py +5 -2
- rasa/cli/inspect.py +8 -4
- rasa/cli/interactive.py +5 -4
- rasa/cli/llm_fine_tuning.py +11 -6
- rasa/cli/project_templates/tutorial/credentials.yml +10 -0
- rasa/cli/run.py +12 -10
- rasa/cli/scaffold.py +4 -4
- rasa/cli/shell.py +9 -5
- rasa/cli/studio/studio.py +1 -1
- rasa/cli/test.py +34 -14
- rasa/cli/train.py +41 -28
- rasa/cli/utils.py +1 -393
- rasa/cli/validation/__init__.py +0 -0
- rasa/cli/validation/bot_config.py +223 -0
- rasa/cli/validation/config_path_validation.py +257 -0
- rasa/cli/x.py +8 -4
- rasa/constants.py +7 -1
- rasa/core/actions/action.py +51 -10
- rasa/core/actions/grpc_custom_action_executor.py +1 -1
- rasa/core/agent.py +19 -2
- rasa/core/available_agents.py +229 -0
- rasa/core/channels/__init__.py +82 -35
- rasa/core/channels/development_inspector.py +3 -3
- rasa/core/channels/inspector/README.md +25 -13
- rasa/core/channels/inspector/dist/assets/{arc-35222594.js → arc-6177260a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{blockDiagram-38ab4fdb-a0efbfd3.js → blockDiagram-38ab4fdb-b054f038.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-3d4e48cf-0584c0f2.js → c4Diagram-3d4e48cf-f25427d5.js} +1 -1
- rasa/core/channels/inspector/dist/assets/channel-bf9cbb34.js +1 -0
- rasa/core/channels/inspector/dist/assets/{classDiagram-70f12bd4-39f40dbe.js → classDiagram-70f12bd4-c7a2af53.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-f2320105-1ad755f3.js → classDiagram-v2-f2320105-58db65c0.js} +1 -1
- rasa/core/channels/inspector/dist/assets/clone-8f9083bb.js +1 -0
- rasa/core/channels/inspector/dist/assets/{createText-2e5e7dd3-b0f4f0fe.js → createText-2e5e7dd3-088372e2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-e0da2a9e-9039bff9.js → edges-e0da2a9e-58676240.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9861fffd-65c9b127.js → erDiagram-9861fffd-0c14d7c6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-956e92f1-4f08b38e.js → flowDb-956e92f1-ea63f85c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-66a62f08-e95c362a.js → flowDiagram-66a62f08-a2af48cd.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-9ecd5b59.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-4a651766-703c3015.js → flowchart-elk-definition-4a651766-6937abe7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-c361ad54-699328ea.js → ganttDiagram-c361ad54-7473f357.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-72cf32ee-04cf4b05.js → gitGraphDiagram-72cf32ee-d0c9405e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{graph-ee94449e.js → graph-0a6f8466.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-3862675e-940162b4.js → index-3862675e-7610671a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/index-74e01d94.js +1354 -0
- rasa/core/channels/inspector/dist/assets/{infoDiagram-f8f76790-c79c2866.js → infoDiagram-f8f76790-be397dc7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-49397b02-84489d30.js → journeyDiagram-49397b02-4cefbf62.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-a9aa9858.js → layout-e7fbc2bf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-eb73cf26.js → line-a8aa457c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-b3399f9a.js → linear-3351e0d2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-fc14e90a-b095bf1a.js → mindmap-definition-fc14e90a-b8cbf605.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-8a3498a8-07644b66.js → pieDiagram-8a3498a8-f327f774.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-120e2f19-573a3f9c.js → quadrantDiagram-120e2f19-2854c591.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-deff3bca-d457e1e1.js → requirementDiagram-deff3bca-964985d5.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-04a897e0-9d26e1a2.js → sankeyDiagram-04a897e0-edeb4f33.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-704730f1-3a9cde10.js → sequenceDiagram-704730f1-fcf70125.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-587899a1-4f3e8cec.js → stateDiagram-587899a1-0e770395.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-d93cdb3a-e617e5bf.js → stateDiagram-v2-d93cdb3a-af8dcd22.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-6aaf32cf-eab30d2f.js → styles-6aaf32cf-36a9e70d.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9a916d00-09994be2.js → styles-9a916d00-884a8b5b.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-c10674c1-b7110364.js → styles-c10674c1-dc097813.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-08f97a94-3ebc92ad.js → svgDrawCommon-08f97a94-5a2c7eed.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-85554ec2-7d13d2f2.js → timeline-definition-85554ec2-e89c4f6e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-e933f94c-488385e1.js → xychartDiagram-e933f94c-afb6fe56.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/package.json +18 -18
- rasa/core/channels/inspector/src/App.tsx +29 -4
- rasa/core/channels/inspector/src/components/DialogueAgentStack.tsx +108 -0
- rasa/core/channels/inspector/src/components/{DialogueStack.tsx → DialogueHistoryStack.tsx} +4 -2
- rasa/core/channels/inspector/src/helpers/audio/audiostream.ts +7 -4
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +4 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +24 -3
- rasa/core/channels/inspector/src/helpers/utils.test.ts +127 -0
- rasa/core/channels/inspector/src/helpers/utils.ts +66 -1
- rasa/core/channels/inspector/src/theme/base/styles.ts +19 -1
- rasa/core/channels/inspector/src/types.ts +21 -0
- rasa/core/channels/inspector/yarn.lock +336 -189
- rasa/core/channels/studio_chat.py +6 -6
- rasa/core/channels/telegram.py +4 -9
- rasa/core/channels/voice_stream/genesys.py +1 -1
- rasa/core/channels/voice_stream/tts/deepgram.py +140 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +5 -1
- rasa/core/channels/voice_stream/voice_channel.py +3 -0
- rasa/core/config/__init__.py +0 -0
- rasa/core/{available_endpoints.py → config/available_endpoints.py} +51 -16
- rasa/core/config/configuration.py +260 -0
- rasa/core/config/credentials.py +19 -0
- rasa/core/config/message_procesing_config.py +34 -0
- rasa/core/constants.py +4 -0
- rasa/core/policies/enterprise_search_policy.py +5 -3
- rasa/core/policies/flow_policy.py +4 -4
- rasa/core/policies/flows/agent_executor.py +632 -0
- rasa/core/policies/flows/flow_executor.py +136 -75
- rasa/core/policies/flows/mcp_tool_executor.py +298 -0
- rasa/core/policies/intentless_policy.py +1 -1
- rasa/core/policies/ted_policy.py +20 -12
- rasa/core/policies/unexpected_intent_policy.py +6 -0
- rasa/core/processor.py +68 -44
- rasa/core/run.py +37 -8
- rasa/core/test.py +4 -0
- rasa/core/tracker_stores/tracker_store.py +3 -7
- rasa/core/train.py +1 -1
- rasa/core/training/interactive.py +20 -18
- rasa/core/training/story_conflict.py +5 -5
- rasa/core/utils.py +22 -23
- rasa/dialogue_understanding/commands/__init__.py +8 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +19 -5
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +21 -2
- rasa/dialogue_understanding/commands/clarify_command.py +20 -2
- rasa/dialogue_understanding/commands/continue_agent_command.py +91 -0
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +21 -2
- rasa/dialogue_understanding/commands/restart_agent_command.py +162 -0
- rasa/dialogue_understanding/commands/start_flow_command.py +68 -7
- rasa/dialogue_understanding/commands/utils.py +124 -2
- rasa/dialogue_understanding/generator/command_parser.py +4 -0
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +50 -12
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -1
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +1 -1
- rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +66 -0
- rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +66 -0
- rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v3_claude_3_5_sonnet_20240620_template.jinja2 +89 -0
- rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v3_gpt_4o_2024_11_20_template.jinja2 +88 -0
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +42 -7
- rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +40 -3
- rasa/dialogue_understanding/generator/single_step/single_step_based_llm_command_generator.py +20 -3
- rasa/dialogue_understanding/patterns/cancel.py +27 -6
- rasa/dialogue_understanding/patterns/clarify.py +3 -14
- rasa/dialogue_understanding/patterns/continue_interrupted.py +239 -6
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +46 -8
- rasa/dialogue_understanding/processor/command_processor.py +136 -15
- rasa/dialogue_understanding/stack/dialogue_stack.py +98 -2
- rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +57 -0
- rasa/dialogue_understanding/stack/utils.py +57 -3
- rasa/dialogue_understanding/utils.py +24 -4
- rasa/dialogue_understanding_test/du_test_runner.py +8 -3
- rasa/e2e_test/e2e_test_runner.py +13 -3
- rasa/engine/caching.py +2 -2
- rasa/engine/constants.py +1 -1
- rasa/engine/recipes/default_components.py +138 -49
- rasa/engine/recipes/default_recipe.py +108 -11
- rasa/engine/runner/dask.py +8 -5
- rasa/engine/validation.py +19 -6
- rasa/graph_components/validators/default_recipe_validator.py +86 -28
- rasa/hooks.py +5 -5
- rasa/llm_fine_tuning/utils.py +2 -2
- rasa/model_training.py +60 -47
- rasa/nlu/classifiers/diet_classifier.py +198 -98
- rasa/nlu/classifiers/logistic_regression_classifier.py +1 -4
- rasa/nlu/classifiers/mitie_intent_classifier.py +3 -0
- rasa/nlu/classifiers/sklearn_intent_classifier.py +1 -3
- rasa/nlu/extractors/crf_entity_extractor.py +9 -10
- rasa/nlu/extractors/mitie_entity_extractor.py +3 -0
- rasa/nlu/extractors/spacy_entity_extractor.py +3 -0
- rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +4 -0
- rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +5 -0
- rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +2 -0
- rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +3 -0
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +4 -2
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +4 -0
- rasa/nlu/selectors/response_selector.py +10 -2
- rasa/nlu/tokenizers/jieba_tokenizer.py +3 -4
- rasa/nlu/tokenizers/mitie_tokenizer.py +3 -2
- rasa/nlu/tokenizers/spacy_tokenizer.py +3 -2
- rasa/nlu/utils/mitie_utils.py +3 -0
- rasa/nlu/utils/spacy_utils.py +3 -2
- rasa/plugin.py +8 -8
- rasa/privacy/privacy_manager.py +12 -3
- rasa/server.py +15 -3
- rasa/shared/agents/__init__.py +0 -0
- rasa/shared/agents/auth/__init__.py +0 -0
- rasa/shared/agents/auth/agent_auth_factory.py +105 -0
- rasa/shared/agents/auth/agent_auth_manager.py +92 -0
- rasa/shared/agents/auth/auth_strategy/__init__.py +19 -0
- rasa/shared/agents/auth/auth_strategy/agent_auth_strategy.py +52 -0
- rasa/shared/agents/auth/auth_strategy/api_key_auth_strategy.py +42 -0
- rasa/shared/agents/auth/auth_strategy/bearer_token_auth_strategy.py +28 -0
- rasa/shared/agents/auth/auth_strategy/oauth2_auth_strategy.py +167 -0
- rasa/shared/agents/auth/constants.py +12 -0
- rasa/shared/agents/auth/types.py +12 -0
- rasa/shared/agents/utils.py +35 -0
- rasa/shared/constants.py +8 -0
- rasa/shared/core/constants.py +16 -1
- rasa/shared/core/domain.py +0 -7
- rasa/shared/core/events.py +327 -0
- rasa/shared/core/flows/constants.py +5 -0
- rasa/shared/core/flows/flows_list.py +21 -5
- rasa/shared/core/flows/flows_yaml_schema.json +119 -184
- rasa/shared/core/flows/steps/call.py +49 -5
- rasa/shared/core/flows/steps/collect.py +98 -13
- rasa/shared/core/flows/validation.py +372 -8
- rasa/shared/core/flows/yaml_flows_io.py +3 -2
- rasa/shared/core/slots.py +2 -2
- rasa/shared/core/trackers.py +5 -2
- rasa/shared/exceptions.py +16 -0
- rasa/shared/importers/rasa.py +1 -1
- rasa/shared/importers/utils.py +9 -3
- rasa/shared/providers/llm/_base_litellm_client.py +41 -9
- rasa/shared/providers/llm/litellm_router_llm_client.py +8 -4
- rasa/shared/providers/llm/llm_client.py +7 -3
- rasa/shared/providers/llm/llm_response.py +66 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +8 -4
- rasa/shared/utils/common.py +24 -0
- rasa/shared/utils/health_check/health_check.py +7 -3
- rasa/shared/utils/llm.py +39 -16
- rasa/shared/utils/mcp/__init__.py +0 -0
- rasa/shared/utils/mcp/server_connection.py +247 -0
- rasa/shared/utils/mcp/utils.py +20 -0
- rasa/shared/utils/schemas/events.py +42 -0
- rasa/shared/utils/yaml.py +3 -1
- rasa/studio/pull/pull.py +3 -2
- rasa/studio/train.py +8 -7
- rasa/studio/upload.py +3 -6
- rasa/telemetry.py +69 -5
- rasa/tracing/config.py +45 -12
- rasa/tracing/constants.py +14 -0
- rasa/tracing/instrumentation/attribute_extractors.py +142 -9
- rasa/tracing/instrumentation/instrumentation.py +626 -21
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +4 -4
- rasa/tracing/instrumentation/metrics.py +32 -0
- rasa/tracing/metric_instrument_provider.py +68 -0
- rasa/utils/common.py +92 -1
- rasa/utils/endpoints.py +11 -2
- rasa/utils/log_utils.py +96 -5
- rasa/utils/ml_utils.py +1 -1
- rasa/utils/tensorflow/__init__.py +7 -0
- rasa/utils/tensorflow/callback.py +136 -101
- rasa/utils/tensorflow/crf.py +1 -1
- rasa/utils/tensorflow/data_generator.py +21 -8
- rasa/utils/tensorflow/layers.py +21 -11
- rasa/utils/tensorflow/metrics.py +7 -3
- rasa/utils/tensorflow/models.py +56 -8
- rasa/utils/tensorflow/rasa_layers.py +8 -6
- rasa/utils/tensorflow/transformer.py +2 -3
- rasa/utils/train_utils.py +54 -24
- rasa/validator.py +5 -5
- rasa/version.py +1 -1
- {rasa_pro-3.14.0.dev20250922.dist-info → rasa_pro-3.14.0rc1.dist-info}/METADATA +46 -41
- {rasa_pro-3.14.0.dev20250922.dist-info → rasa_pro-3.14.0rc1.dist-info}/RECORD +285 -226
- rasa/builder/scrape_rasa_docs.py +0 -97
- rasa/core/channels/inspector/dist/assets/channel-8e08bed9.js +0 -1
- rasa/core/channels/inspector/dist/assets/clone-78c82dea.js +0 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-2b08f601.js +0 -1
- rasa/core/channels/inspector/dist/assets/index-c941dcb3.js +0 -1336
- {rasa_pro-3.14.0.dev20250922.dist-info → rasa_pro-3.14.0rc1.dist-info}/NOTICE +0 -0
- {rasa_pro-3.14.0.dev20250922.dist-info → rasa_pro-3.14.0rc1.dist-info}/WHEEL +0 -0
- {rasa_pro-3.14.0.dev20250922.dist-info → rasa_pro-3.14.0rc1.dist-info}/entry_points.txt +0 -0
rasa/shared/exceptions.py
CHANGED
|
@@ -163,9 +163,25 @@ class ProviderClientAPIException(RasaException):
|
|
|
163
163
|
return s
|
|
164
164
|
|
|
165
165
|
|
|
166
|
+
class LLMToolResponseDecodeError(ProviderClientAPIException):
|
|
167
|
+
"""Raised when a JSON decoding error occurs in LLM tool response."""
|
|
168
|
+
|
|
169
|
+
|
|
166
170
|
class ProviderClientValidationError(RasaException):
|
|
167
171
|
"""Raised for errors that occur during validation of the API client."""
|
|
168
172
|
|
|
169
173
|
|
|
170
174
|
class FinetuningDataPreparationException(RasaException):
|
|
171
175
|
"""Raised when there is an error in data preparation for fine-tuning."""
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
class AgentInitializationException(RasaException):
|
|
179
|
+
"""Raised when there is an error during the initialization of an agent."""
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class AgentAuthInitializationException(RasaException):
|
|
183
|
+
"""Raised when there is an error during the initialization of agent auth client."""
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
class AuthenticationError(RasaException):
|
|
187
|
+
"""Raised when there is an authentication error."""
|
rasa/shared/importers/rasa.py
CHANGED
|
@@ -74,7 +74,7 @@ class RasaFileImporter(TrainingDataImporter):
|
|
|
74
74
|
@cached_method
|
|
75
75
|
def get_flows(self) -> FlowsList:
|
|
76
76
|
"""Retrieves training stories / rules (see parent class for full docstring)."""
|
|
77
|
-
return utils.flows_from_paths(self._flow_files)
|
|
77
|
+
return utils.flows_from_paths(self._flow_files, self.get_domain())
|
|
78
78
|
|
|
79
79
|
@cached_method
|
|
80
80
|
def get_conversation_tests(self) -> StoryGraph:
|
rasa/shared/importers/utils.py
CHANGED
|
@@ -51,8 +51,14 @@ def story_graph_from_paths(
|
|
|
51
51
|
return StoryGraph(story_steps)
|
|
52
52
|
|
|
53
53
|
|
|
54
|
-
def flows_from_paths(files: List[Text]) -> FlowsList:
|
|
55
|
-
"""Returns the flows from paths.
|
|
54
|
+
def flows_from_paths(files: List[Text], domain: Optional[Domain] = None) -> FlowsList:
|
|
55
|
+
"""Returns the flows from paths.
|
|
56
|
+
|
|
57
|
+
Args:
|
|
58
|
+
files: List of flow file paths to load.
|
|
59
|
+
domain: Optional domain for validation. If provided, exit_if conditions
|
|
60
|
+
will be validated against defined slots.
|
|
61
|
+
"""
|
|
56
62
|
from rasa.shared.core.flows.yaml_flows_io import YAMLFlowsReader
|
|
57
63
|
|
|
58
64
|
flows = FlowsList(underlying_flows=[])
|
|
@@ -60,7 +66,7 @@ def flows_from_paths(files: List[Text]) -> FlowsList:
|
|
|
60
66
|
flows = flows.merge(
|
|
61
67
|
YAMLFlowsReader.read_from_file(file), ignore_duplicates=False
|
|
62
68
|
)
|
|
63
|
-
flows.validate()
|
|
69
|
+
flows.validate(domain)
|
|
64
70
|
return flows
|
|
65
71
|
|
|
66
72
|
|
|
@@ -21,7 +21,7 @@ from rasa.shared.providers._ssl_verification_utils import (
|
|
|
21
21
|
ensure_ssl_certificates_for_litellm_non_openai_based_clients,
|
|
22
22
|
ensure_ssl_certificates_for_litellm_openai_based_clients,
|
|
23
23
|
)
|
|
24
|
-
from rasa.shared.providers.llm.llm_response import LLMResponse, LLMUsage
|
|
24
|
+
from rasa.shared.providers.llm.llm_response import LLMResponse, LLMToolCall, LLMUsage
|
|
25
25
|
from rasa.shared.utils.io import resolve_environment_variables, suppress_logs
|
|
26
26
|
|
|
27
27
|
structlogger = structlog.get_logger()
|
|
@@ -126,7 +126,9 @@ class _BaseLiteLLMClient:
|
|
|
126
126
|
raise ProviderClientValidationError(event_info)
|
|
127
127
|
|
|
128
128
|
@suppress_logs(log_level=logging.WARNING)
|
|
129
|
-
def completion(
|
|
129
|
+
def completion(
|
|
130
|
+
self, messages: Union[List[dict], List[str], str], **kwargs: Any
|
|
131
|
+
) -> LLMResponse:
|
|
130
132
|
"""Synchronously generate completions for given list of messages.
|
|
131
133
|
|
|
132
134
|
Args:
|
|
@@ -138,6 +140,7 @@ class _BaseLiteLLMClient:
|
|
|
138
140
|
- a list of messages. Each message is a string and will be formatted
|
|
139
141
|
as a user message.
|
|
140
142
|
- a single message as a string which will be formatted as user message.
|
|
143
|
+
**kwargs: Additional parameters to pass to the completion call.
|
|
141
144
|
|
|
142
145
|
Returns:
|
|
143
146
|
List of message completions.
|
|
@@ -147,15 +150,19 @@ class _BaseLiteLLMClient:
|
|
|
147
150
|
"""
|
|
148
151
|
try:
|
|
149
152
|
formatted_messages = self._get_formatted_messages(messages)
|
|
150
|
-
arguments =
|
|
151
|
-
|
|
153
|
+
arguments = cast(
|
|
154
|
+
Dict[str, Any], resolve_environment_variables(self._completion_fn_args)
|
|
155
|
+
)
|
|
156
|
+
response = completion(
|
|
157
|
+
messages=formatted_messages, **{**arguments, **kwargs}
|
|
158
|
+
)
|
|
152
159
|
return self._format_response(response)
|
|
153
160
|
except Exception as e:
|
|
154
|
-
raise ProviderClientAPIException(e)
|
|
161
|
+
raise ProviderClientAPIException(e) from e
|
|
155
162
|
|
|
156
163
|
@suppress_logs(log_level=logging.WARNING)
|
|
157
164
|
async def acompletion(
|
|
158
|
-
self, messages: Union[List[dict], List[str], str]
|
|
165
|
+
self, messages: Union[List[dict], List[str], str], **kwargs: Any
|
|
159
166
|
) -> LLMResponse:
|
|
160
167
|
"""Asynchronously generate completions for given list of messages.
|
|
161
168
|
|
|
@@ -168,6 +175,7 @@ class _BaseLiteLLMClient:
|
|
|
168
175
|
- a list of messages. Each message is a string and will be formatted
|
|
169
176
|
as a user message.
|
|
170
177
|
- a single message as a string which will be formatted as user message.
|
|
178
|
+
**kwargs: Additional parameters to pass to the completion call.
|
|
171
179
|
|
|
172
180
|
Returns:
|
|
173
181
|
List of message completions.
|
|
@@ -177,8 +185,12 @@ class _BaseLiteLLMClient:
|
|
|
177
185
|
"""
|
|
178
186
|
try:
|
|
179
187
|
formatted_messages = self._get_formatted_messages(messages)
|
|
180
|
-
arguments =
|
|
181
|
-
|
|
188
|
+
arguments = cast(
|
|
189
|
+
Dict[str, Any], resolve_environment_variables(self._completion_fn_args)
|
|
190
|
+
)
|
|
191
|
+
response = await acompletion(
|
|
192
|
+
messages=formatted_messages, **{**arguments, **kwargs}
|
|
193
|
+
)
|
|
182
194
|
return self._format_response(response)
|
|
183
195
|
except Exception as e:
|
|
184
196
|
message = ""
|
|
@@ -197,7 +209,7 @@ class _BaseLiteLLMClient:
|
|
|
197
209
|
"In case you are getting OpenAI connection errors, such as missing "
|
|
198
210
|
"API key, your configuration is incorrect."
|
|
199
211
|
)
|
|
200
|
-
raise ProviderClientAPIException(e, message)
|
|
212
|
+
raise ProviderClientAPIException(e, message) from e
|
|
201
213
|
|
|
202
214
|
def _get_formatted_messages(
|
|
203
215
|
self, messages: Union[List[dict], List[str], str]
|
|
@@ -246,12 +258,32 @@ class _BaseLiteLLMClient:
|
|
|
246
258
|
else 0
|
|
247
259
|
)
|
|
248
260
|
formatted_response.usage = LLMUsage(prompt_tokens, completion_tokens)
|
|
261
|
+
|
|
262
|
+
# Extract tool calls from all choices
|
|
263
|
+
formatted_response.tool_calls = self._extract_tool_calls(response)
|
|
264
|
+
|
|
249
265
|
structlogger.debug(
|
|
250
266
|
"base_litellm_client.formatted_response",
|
|
251
267
|
formatted_response=formatted_response.to_dict(),
|
|
252
268
|
)
|
|
253
269
|
return formatted_response
|
|
254
270
|
|
|
271
|
+
def _extract_tool_calls(self, response: Any) -> List[LLMToolCall]:
|
|
272
|
+
"""Extract tool calls from response choices.
|
|
273
|
+
|
|
274
|
+
Args:
|
|
275
|
+
response: List of response choices from LiteLLM
|
|
276
|
+
|
|
277
|
+
Returns:
|
|
278
|
+
List of LLMToolCall objects, empty if no tool calls found
|
|
279
|
+
"""
|
|
280
|
+
return [
|
|
281
|
+
LLMToolCall.from_litellm(tool_call)
|
|
282
|
+
for choice in response.choices
|
|
283
|
+
if choice.message.tool_calls
|
|
284
|
+
for tool_call in choice.message.tool_calls
|
|
285
|
+
]
|
|
286
|
+
|
|
255
287
|
def _format_text_completion_response(self, response: Any) -> LLMResponse:
|
|
256
288
|
"""Parses the LiteLLM text completion response to Rasa format."""
|
|
257
289
|
formatted_response = LLMResponse(
|
|
@@ -122,7 +122,9 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
122
122
|
raise ProviderClientAPIException(e)
|
|
123
123
|
|
|
124
124
|
@suppress_logs(log_level=logging.WARNING)
|
|
125
|
-
def completion(
|
|
125
|
+
def completion(
|
|
126
|
+
self, messages: Union[List[dict], List[str], str], **kwargs: Any
|
|
127
|
+
) -> LLMResponse:
|
|
126
128
|
"""
|
|
127
129
|
Synchronously generate completions for given list of messages.
|
|
128
130
|
|
|
@@ -140,6 +142,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
140
142
|
- a list of messages. Each message is a string and will be formatted
|
|
141
143
|
as a user message.
|
|
142
144
|
- a single message as a string which will be formatted as user message.
|
|
145
|
+
**kwargs: Additional parameters to pass to the completion call.
|
|
143
146
|
Returns:
|
|
144
147
|
List of message completions.
|
|
145
148
|
Raises:
|
|
@@ -150,7 +153,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
150
153
|
try:
|
|
151
154
|
formatted_messages = self._format_messages(messages)
|
|
152
155
|
response = self.router_client.completion(
|
|
153
|
-
messages=formatted_messages, **self._completion_fn_args
|
|
156
|
+
messages=formatted_messages, **{**self._completion_fn_args, **kwargs}
|
|
154
157
|
)
|
|
155
158
|
return self._format_response(response)
|
|
156
159
|
except Exception as e:
|
|
@@ -158,7 +161,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
158
161
|
|
|
159
162
|
@suppress_logs(log_level=logging.WARNING)
|
|
160
163
|
async def acompletion(
|
|
161
|
-
self, messages: Union[List[dict], List[str], str]
|
|
164
|
+
self, messages: Union[List[dict], List[str], str], **kwargs: Any
|
|
162
165
|
) -> LLMResponse:
|
|
163
166
|
"""
|
|
164
167
|
Asynchronously generate completions for given list of messages.
|
|
@@ -177,6 +180,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
177
180
|
- a list of messages. Each message is a string and will be formatted
|
|
178
181
|
as a user message.
|
|
179
182
|
- a single message as a string which will be formatted as user message.
|
|
183
|
+
**kwargs: Additional parameters to pass to the completion call.
|
|
180
184
|
Returns:
|
|
181
185
|
List of message completions.
|
|
182
186
|
Raises:
|
|
@@ -187,7 +191,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
187
191
|
try:
|
|
188
192
|
formatted_messages = self._format_messages(messages)
|
|
189
193
|
response = await self.router_client.acompletion(
|
|
190
|
-
messages=formatted_messages, **self._completion_fn_args
|
|
194
|
+
messages=formatted_messages, **{**self._completion_fn_args, **kwargs}
|
|
191
195
|
)
|
|
192
196
|
return self._format_response(response)
|
|
193
197
|
except Exception as e:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from typing import Dict, List, Protocol, Union, runtime_checkable
|
|
3
|
+
from typing import Any, Dict, List, Protocol, Union, runtime_checkable
|
|
4
4
|
|
|
5
5
|
from rasa.shared.providers.llm.llm_response import LLMResponse
|
|
6
6
|
|
|
@@ -32,7 +32,9 @@ class LLMClient(Protocol):
|
|
|
32
32
|
"""
|
|
33
33
|
...
|
|
34
34
|
|
|
35
|
-
def completion(
|
|
35
|
+
def completion(
|
|
36
|
+
self, messages: Union[List[dict], List[str], str], **kwargs: Any
|
|
37
|
+
) -> LLMResponse:
|
|
36
38
|
"""
|
|
37
39
|
Synchronously generate completions for given list of messages.
|
|
38
40
|
|
|
@@ -48,13 +50,14 @@ class LLMClient(Protocol):
|
|
|
48
50
|
- a list of messages. Each message is a string and will be formatted
|
|
49
51
|
as a user message.
|
|
50
52
|
- a single message as a string which will be formatted as user message.
|
|
53
|
+
**kwargs: Additional parameters to pass to the completion call.
|
|
51
54
|
Returns:
|
|
52
55
|
LLMResponse
|
|
53
56
|
"""
|
|
54
57
|
...
|
|
55
58
|
|
|
56
59
|
async def acompletion(
|
|
57
|
-
self, messages: Union[List[dict], List[str], str]
|
|
60
|
+
self, messages: Union[List[dict], List[str], str], **kwargs: Any
|
|
58
61
|
) -> LLMResponse:
|
|
59
62
|
"""
|
|
60
63
|
Asynchronously generate completions for given list of messages.
|
|
@@ -71,6 +74,7 @@ class LLMClient(Protocol):
|
|
|
71
74
|
- a list of messages. Each message is a string and will be formatted
|
|
72
75
|
as a user message.
|
|
73
76
|
- a single message as a string which will be formatted as user message.
|
|
77
|
+
**kwargs: Additional parameters to pass to the completion call.
|
|
74
78
|
Returns:
|
|
75
79
|
LLMResponse
|
|
76
80
|
"""
|
|
@@ -1,9 +1,15 @@
|
|
|
1
1
|
import functools
|
|
2
|
+
import json
|
|
2
3
|
import time
|
|
3
4
|
from dataclasses import asdict, dataclass, field
|
|
4
5
|
from typing import Any, Awaitable, Callable, Dict, List, Optional, Text, Union
|
|
5
6
|
|
|
6
7
|
import structlog
|
|
8
|
+
from litellm.utils import ChatCompletionMessageToolCall
|
|
9
|
+
from pydantic import BaseModel
|
|
10
|
+
|
|
11
|
+
from rasa.shared.constants import KEY_TOOL_CALLS
|
|
12
|
+
from rasa.shared.exceptions import LLMToolResponseDecodeError
|
|
7
13
|
|
|
8
14
|
structlogger = structlog.get_logger()
|
|
9
15
|
|
|
@@ -38,6 +44,53 @@ class LLMUsage:
|
|
|
38
44
|
return asdict(self)
|
|
39
45
|
|
|
40
46
|
|
|
47
|
+
class LLMToolCall(BaseModel):
|
|
48
|
+
"""A class representing a response from an LLM tool call."""
|
|
49
|
+
|
|
50
|
+
id: str
|
|
51
|
+
"""The ID of the tool call."""
|
|
52
|
+
|
|
53
|
+
tool_name: str
|
|
54
|
+
"""The name of the tool that was called."""
|
|
55
|
+
|
|
56
|
+
tool_args: Dict[str, Any]
|
|
57
|
+
"""The arguments passed to the tool call."""
|
|
58
|
+
|
|
59
|
+
type: str = "function"
|
|
60
|
+
"""The type of the tool call."""
|
|
61
|
+
|
|
62
|
+
@classmethod
|
|
63
|
+
def from_dict(cls, data: Dict[Text, Any]) -> "LLMToolCall":
|
|
64
|
+
"""Creates an LLMToolResponse from a dictionary."""
|
|
65
|
+
return cls(**data)
|
|
66
|
+
|
|
67
|
+
@classmethod
|
|
68
|
+
def from_litellm(cls, data: ChatCompletionMessageToolCall) -> "LLMToolCall":
|
|
69
|
+
"""Creates an LLMToolResponse from a dictionary."""
|
|
70
|
+
try:
|
|
71
|
+
tool_args = json.loads(data.function.arguments)
|
|
72
|
+
except json.JSONDecodeError as e:
|
|
73
|
+
structlogger.error(
|
|
74
|
+
"llm_response.litellm_tool_call.invalid_arguments",
|
|
75
|
+
tool_name=data.function.name,
|
|
76
|
+
tool_call=data.function.arguments,
|
|
77
|
+
)
|
|
78
|
+
raise LLMToolResponseDecodeError(
|
|
79
|
+
original_exception=e,
|
|
80
|
+
message=(
|
|
81
|
+
f"Invalid arguments for tool call - `{data.function.name}`: "
|
|
82
|
+
f"`{data.function.arguments}`"
|
|
83
|
+
),
|
|
84
|
+
) from e
|
|
85
|
+
|
|
86
|
+
return cls(
|
|
87
|
+
id=data.id,
|
|
88
|
+
tool_name=data.function.name,
|
|
89
|
+
tool_args=tool_args,
|
|
90
|
+
type=data.type,
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
|
|
41
94
|
@dataclass
|
|
42
95
|
class LLMResponse:
|
|
43
96
|
id: str
|
|
@@ -62,12 +115,22 @@ class LLMResponse:
|
|
|
62
115
|
latency: Optional[float] = None
|
|
63
116
|
"""Optional field to store the latency of the LLM API call."""
|
|
64
117
|
|
|
118
|
+
tool_calls: Optional[List[LLMToolCall]] = None
|
|
119
|
+
"""The list of tool calls the model generated for the input prompt."""
|
|
120
|
+
|
|
65
121
|
@classmethod
|
|
66
122
|
def from_dict(cls, data: Dict[Text, Any]) -> "LLMResponse":
|
|
67
123
|
"""Creates an LLMResponse from a dictionary."""
|
|
68
124
|
usage_data = data.get("usage", {})
|
|
69
125
|
usage_obj = LLMUsage.from_dict(usage_data) if usage_data else None
|
|
70
126
|
|
|
127
|
+
tool_calls_data = data.get(KEY_TOOL_CALLS, [])
|
|
128
|
+
tool_calls_obj = (
|
|
129
|
+
[LLMToolCall.from_dict(tool) for tool in tool_calls_data]
|
|
130
|
+
if tool_calls_data
|
|
131
|
+
else None
|
|
132
|
+
)
|
|
133
|
+
|
|
71
134
|
return cls(
|
|
72
135
|
id=data["id"],
|
|
73
136
|
choices=data["choices"],
|
|
@@ -76,6 +139,7 @@ class LLMResponse:
|
|
|
76
139
|
usage=usage_obj,
|
|
77
140
|
additional_info=data.get("additional_info"),
|
|
78
141
|
latency=data.get("latency"),
|
|
142
|
+
tool_calls=tool_calls_obj,
|
|
79
143
|
)
|
|
80
144
|
|
|
81
145
|
@classmethod
|
|
@@ -92,6 +156,8 @@ class LLMResponse:
|
|
|
92
156
|
result = asdict(self)
|
|
93
157
|
if self.usage:
|
|
94
158
|
result["usage"] = self.usage.to_dict()
|
|
159
|
+
if self.tool_calls:
|
|
160
|
+
result[KEY_TOOL_CALLS] = [tool.model_dump() for tool in self.tool_calls]
|
|
95
161
|
return result
|
|
96
162
|
|
|
97
163
|
|
|
@@ -237,7 +237,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
237
237
|
raise ProviderClientAPIException(e)
|
|
238
238
|
|
|
239
239
|
async def acompletion(
|
|
240
|
-
self, messages: Union[List[dict], List[str], str]
|
|
240
|
+
self, messages: Union[List[dict], List[str], str], **kwargs: Any
|
|
241
241
|
) -> LLMResponse:
|
|
242
242
|
"""Asynchronous completion of the model with the given messages.
|
|
243
243
|
|
|
@@ -255,15 +255,18 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
255
255
|
- a list of messages. Each message is a string and will be formatted
|
|
256
256
|
as a user message.
|
|
257
257
|
- a single message as a string which will be formatted as user message.
|
|
258
|
+
**kwargs: Additional parameters to pass to the completion call.
|
|
258
259
|
|
|
259
260
|
Returns:
|
|
260
261
|
The completion response.
|
|
261
262
|
"""
|
|
262
263
|
if self._use_chat_completions_endpoint:
|
|
263
|
-
return await super().acompletion(messages)
|
|
264
|
+
return await super().acompletion(messages, **kwargs)
|
|
264
265
|
return await self._atext_completion(messages)
|
|
265
266
|
|
|
266
|
-
def completion(
|
|
267
|
+
def completion(
|
|
268
|
+
self, messages: Union[List[dict], List[str], str], **kwargs: Any
|
|
269
|
+
) -> LLMResponse:
|
|
267
270
|
"""Completion of the model with the given messages.
|
|
268
271
|
|
|
269
272
|
Method overrides the base class method to call the appropriate
|
|
@@ -273,12 +276,13 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
273
276
|
|
|
274
277
|
Args:
|
|
275
278
|
messages: The messages to be used for completion.
|
|
279
|
+
**kwargs: Additional parameters to pass to the completion call.
|
|
276
280
|
|
|
277
281
|
Returns:
|
|
278
282
|
The completion response.
|
|
279
283
|
"""
|
|
280
284
|
if self._use_chat_completions_endpoint:
|
|
281
|
-
return super().completion(messages)
|
|
285
|
+
return super().completion(messages, **kwargs)
|
|
282
286
|
return self._text_completion(messages)
|
|
283
287
|
|
|
284
288
|
@staticmethod
|
rasa/shared/utils/common.py
CHANGED
|
@@ -17,6 +17,7 @@ from typing import (
|
|
|
17
17
|
Optional,
|
|
18
18
|
Sequence,
|
|
19
19
|
Text,
|
|
20
|
+
Tuple,
|
|
20
21
|
Type,
|
|
21
22
|
)
|
|
22
23
|
|
|
@@ -102,9 +103,12 @@ def sort_list_of_dicts_by_first_key(dicts: List[Dict]) -> List[Dict]:
|
|
|
102
103
|
|
|
103
104
|
def cached_method(f: Callable[..., Any]) -> Callable[..., Any]:
|
|
104
105
|
"""Caches method calls based on the call's `args` and `kwargs`.
|
|
106
|
+
|
|
105
107
|
Works for `async` and `sync` methods. Don't apply this to functions.
|
|
108
|
+
|
|
106
109
|
Args:
|
|
107
110
|
f: The decorated method whose return value should be cached.
|
|
111
|
+
|
|
108
112
|
Returns:
|
|
109
113
|
The return value which the method gives for the first call with the given
|
|
110
114
|
arguments.
|
|
@@ -358,6 +362,7 @@ def validate_environment(
|
|
|
358
362
|
component_name: str,
|
|
359
363
|
) -> None:
|
|
360
364
|
"""Make sure all needed requirements for a component are met.
|
|
365
|
+
|
|
361
366
|
Args:
|
|
362
367
|
required_env_vars: List of environment variables that should be set
|
|
363
368
|
required_packages: List of packages that should be installed
|
|
@@ -389,3 +394,22 @@ Sign up at: https://feedback.rasa.com
|
|
|
389
394
|
{separator}
|
|
390
395
|
"""
|
|
391
396
|
print_success(message)
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
def conditional_import(module_name: str, class_name: str) -> Tuple[Any, bool]:
|
|
400
|
+
"""Conditionally import a class, returning (class, is_available) tuple.
|
|
401
|
+
|
|
402
|
+
Args:
|
|
403
|
+
module_name: The module path to import from
|
|
404
|
+
class_name: The class name to import
|
|
405
|
+
|
|
406
|
+
Returns:
|
|
407
|
+
A tuple of (class, is_available) where class is the imported class
|
|
408
|
+
or None if import failed, and is_available is a boolean indicating
|
|
409
|
+
whether the import was successful.
|
|
410
|
+
"""
|
|
411
|
+
try:
|
|
412
|
+
module = __import__(module_name, fromlist=[class_name])
|
|
413
|
+
return getattr(module, class_name), True
|
|
414
|
+
except ImportError:
|
|
415
|
+
return None, False
|
|
@@ -1,7 +1,10 @@
|
|
|
1
1
|
import os
|
|
2
|
-
from typing import Any, Dict, Optional
|
|
2
|
+
from typing import TYPE_CHECKING, Any, Dict, Optional
|
|
3
3
|
|
|
4
4
|
from rasa.exceptions import HealthCheckError
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
pass
|
|
5
8
|
from rasa.shared.constants import (
|
|
6
9
|
LLM_API_HEALTH_CHECK_DEFAULT_VALUE,
|
|
7
10
|
LLM_API_HEALTH_CHECK_ENV_VAR,
|
|
@@ -63,6 +66,7 @@ def perform_llm_health_check(
|
|
|
63
66
|
log_source_component: str,
|
|
64
67
|
) -> None:
|
|
65
68
|
"""Try to instantiate the LLM Client to validate the provided config.
|
|
69
|
+
|
|
66
70
|
If the LLM_API_HEALTH_CHECK environment variable is true, perform a test call
|
|
67
71
|
to the LLM API. If config contains multiple models, perform a test call for each
|
|
68
72
|
model in the model group.
|
|
@@ -125,6 +129,7 @@ def perform_embeddings_health_check(
|
|
|
125
129
|
log_source_component: str,
|
|
126
130
|
) -> None:
|
|
127
131
|
"""Try to instantiate the Embedder to validate the provided config.
|
|
132
|
+
|
|
128
133
|
If the LLM_API_HEALTH_CHECK environment variable is true, perform a test call
|
|
129
134
|
to the Embeddings API. If config contains multiple models, perform a test call for
|
|
130
135
|
each model in the model group.
|
|
@@ -240,8 +245,7 @@ def send_test_embeddings_api_request(
|
|
|
240
245
|
|
|
241
246
|
|
|
242
247
|
def is_api_health_check_enabled() -> bool:
|
|
243
|
-
"""Determines whether the API health check is enabled
|
|
244
|
-
variable.
|
|
248
|
+
"""Determines whether the API health check is enabled.
|
|
245
249
|
|
|
246
250
|
Returns:
|
|
247
251
|
bool: True if the API health check is enabled, False otherwise.
|
rasa/shared/utils/llm.py
CHANGED
|
@@ -29,7 +29,8 @@ import rasa.cli.telemetry
|
|
|
29
29
|
import rasa.cli.utils
|
|
30
30
|
import rasa.shared.utils.cli
|
|
31
31
|
import rasa.shared.utils.io
|
|
32
|
-
from rasa.core.available_endpoints import AvailableEndpoints
|
|
32
|
+
from rasa.core.config.available_endpoints import AvailableEndpoints
|
|
33
|
+
from rasa.core.config.configuration import Configuration
|
|
33
34
|
from rasa.shared.constants import (
|
|
34
35
|
CONFIG_NAME_KEY,
|
|
35
36
|
CONFIG_PIPELINE_KEY,
|
|
@@ -49,7 +50,15 @@ from rasa.shared.constants import (
|
|
|
49
50
|
RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_TOO_LONG,
|
|
50
51
|
ROUTER_CONFIG_KEY,
|
|
51
52
|
)
|
|
52
|
-
from rasa.shared.core.events import
|
|
53
|
+
from rasa.shared.core.events import (
|
|
54
|
+
AgentCancelled,
|
|
55
|
+
AgentCompleted,
|
|
56
|
+
AgentInterrupted,
|
|
57
|
+
AgentResumed,
|
|
58
|
+
AgentStarted,
|
|
59
|
+
BotUttered,
|
|
60
|
+
UserUttered,
|
|
61
|
+
)
|
|
53
62
|
from rasa.shared.core.slots import BooleanSlot, CategoricalSlot, Slot
|
|
54
63
|
from rasa.shared.engine.caching import get_local_cache_location
|
|
55
64
|
from rasa.shared.exceptions import (
|
|
@@ -112,7 +121,7 @@ DEPLOYMENT_CENTRIC_PROVIDERS = [AZURE_OPENAI_PROVIDER]
|
|
|
112
121
|
|
|
113
122
|
# Placeholder messages used in the transcript for
|
|
114
123
|
# instances where user input results in an error
|
|
115
|
-
ERROR_PLACEHOLDER = {
|
|
124
|
+
ERROR_PLACEHOLDER: Dict[str, str] = {
|
|
116
125
|
RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_TOO_LONG: "[User sent really long message]",
|
|
117
126
|
RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_EMPTY: "",
|
|
118
127
|
"default": "[User input triggered an error]",
|
|
@@ -225,6 +234,7 @@ def tracker_as_readable_transcript(
|
|
|
225
234
|
ai_prefix: str = AI,
|
|
226
235
|
max_turns: Optional[int] = 20,
|
|
227
236
|
turns_wrapper: Optional[Callable[[List[str]], List[str]]] = None,
|
|
237
|
+
highlight_agent_turns: bool = False,
|
|
228
238
|
) -> str:
|
|
229
239
|
"""Creates a readable dialogue from a tracker.
|
|
230
240
|
|
|
@@ -234,6 +244,7 @@ def tracker_as_readable_transcript(
|
|
|
234
244
|
ai_prefix: the prefix to use for ai utterances
|
|
235
245
|
max_turns: the maximum number of turns to include in the transcript
|
|
236
246
|
turns_wrapper: optional function to wrap the turns in a custom way
|
|
247
|
+
highlight_agent_turns: whether to highlight agent turns in the transcript
|
|
237
248
|
|
|
238
249
|
Example:
|
|
239
250
|
>>> tracker = Tracker(
|
|
@@ -251,7 +262,9 @@ def tracker_as_readable_transcript(
|
|
|
251
262
|
Returns:
|
|
252
263
|
A string representing the transcript of the tracker
|
|
253
264
|
"""
|
|
254
|
-
transcript = []
|
|
265
|
+
transcript: List[str] = []
|
|
266
|
+
|
|
267
|
+
current_ai_prefix = ai_prefix
|
|
255
268
|
|
|
256
269
|
# using `applied_events` rather than `events` means that only events after the
|
|
257
270
|
# most recent `Restart` or `SessionStarted` are included in the transcript
|
|
@@ -266,9 +279,20 @@ def tracker_as_readable_transcript(
|
|
|
266
279
|
else:
|
|
267
280
|
message = sanitize_message_for_prompt(event.text)
|
|
268
281
|
transcript.append(f"{human_prefix}: {message}")
|
|
269
|
-
|
|
270
282
|
elif isinstance(event, BotUttered):
|
|
271
|
-
transcript.append(
|
|
283
|
+
transcript.append(
|
|
284
|
+
f"{current_ai_prefix}: {sanitize_message_for_prompt(event.text)}"
|
|
285
|
+
)
|
|
286
|
+
|
|
287
|
+
if highlight_agent_turns:
|
|
288
|
+
if isinstance(event, AgentStarted) or isinstance(event, AgentResumed):
|
|
289
|
+
current_ai_prefix = event.agent_id
|
|
290
|
+
elif (
|
|
291
|
+
isinstance(event, AgentCompleted)
|
|
292
|
+
or isinstance(event, AgentCancelled)
|
|
293
|
+
or isinstance(event, AgentInterrupted)
|
|
294
|
+
):
|
|
295
|
+
current_ai_prefix = ai_prefix
|
|
272
296
|
|
|
273
297
|
# turns_wrapper to count multiple utterances by bot/user as single turn
|
|
274
298
|
if turns_wrapper:
|
|
@@ -739,7 +763,7 @@ def get_prompt_template(
|
|
|
739
763
|
log_source_method=log_source_method,
|
|
740
764
|
)
|
|
741
765
|
return prompt_template
|
|
742
|
-
except (FileIOException, FileNotFoundException):
|
|
766
|
+
except (FileIOException, FileNotFoundException) as e:
|
|
743
767
|
structlogger.warning(
|
|
744
768
|
"utils.llm.get_prompt_template" ".failed_to_read_custom_prompt_template",
|
|
745
769
|
event_info=(
|
|
@@ -747,6 +771,7 @@ def get_prompt_template(
|
|
|
747
771
|
),
|
|
748
772
|
log_source_component=log_source_component,
|
|
749
773
|
log_source_method=log_source_method,
|
|
774
|
+
error=str(e),
|
|
750
775
|
)
|
|
751
776
|
return default_prompt_template
|
|
752
777
|
|
|
@@ -899,7 +924,7 @@ def resolve_model_client_config(
|
|
|
899
924
|
if model_groups:
|
|
900
925
|
endpoints = AvailableEndpoints(model_groups=model_groups)
|
|
901
926
|
else:
|
|
902
|
-
endpoints =
|
|
927
|
+
endpoints = Configuration.get_instance().endpoints
|
|
903
928
|
if endpoints.model_groups is None:
|
|
904
929
|
_raise_invalid_config_exception(
|
|
905
930
|
reason=(
|
|
@@ -1015,14 +1040,12 @@ def _get_llm_command_generator_config(
|
|
|
1015
1040
|
return None
|
|
1016
1041
|
|
|
1017
1042
|
|
|
1018
|
-
def
|
|
1043
|
+
def _get_compact_llm_command_generator_prompt(
|
|
1019
1044
|
config: Dict[Text, Any], endpoints: Dict[Text, Any]
|
|
1020
1045
|
) -> Text:
|
|
1021
1046
|
"""Get the command generator prompt based on the config."""
|
|
1022
1047
|
from rasa.dialogue_understanding.generator.single_step.compact_llm_command_generator import ( # noqa: E501
|
|
1023
|
-
|
|
1024
|
-
FALLBACK_COMMAND_PROMPT_TEMPLATE_FILE_NAME,
|
|
1025
|
-
MODEL_PROMPT_MAPPER,
|
|
1048
|
+
CompactLLMCommandGenerator,
|
|
1026
1049
|
)
|
|
1027
1050
|
|
|
1028
1051
|
model_config = _get_llm_command_generator_config(config)
|
|
@@ -1032,9 +1055,9 @@ def _get_command_generator_prompt(
|
|
|
1032
1055
|
)
|
|
1033
1056
|
return get_default_prompt_template_based_on_model(
|
|
1034
1057
|
llm_config=llm_config or {},
|
|
1035
|
-
model_prompt_mapping=
|
|
1036
|
-
default_prompt_path=
|
|
1037
|
-
fallback_prompt_path=
|
|
1058
|
+
model_prompt_mapping=CompactLLMCommandGenerator.get_model_prompt_mapper(),
|
|
1059
|
+
default_prompt_path=CompactLLMCommandGenerator.get_default_prompt_template_file_name(),
|
|
1060
|
+
fallback_prompt_path=CompactLLMCommandGenerator.get_fallback_prompt_template_file_name(),
|
|
1038
1061
|
)
|
|
1039
1062
|
|
|
1040
1063
|
|
|
@@ -1073,7 +1096,7 @@ def get_system_default_prompts(
|
|
|
1073
1096
|
)
|
|
1074
1097
|
|
|
1075
1098
|
return SystemPrompts(
|
|
1076
|
-
command_generator=
|
|
1099
|
+
command_generator=_get_compact_llm_command_generator_prompt(config, endpoints),
|
|
1077
1100
|
enterprise_search=_get_enterprise_search_prompt(config),
|
|
1078
1101
|
contextual_response_rephraser=DEFAULT_RESPONSE_VARIATION_PROMPT_TEMPLATE,
|
|
1079
1102
|
)
|
|
File without changes
|