rasa-pro 3.14.0.dev20250922__py3-none-any.whl → 3.14.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/__main__.py +15 -3
- rasa/agents/__init__.py +0 -0
- rasa/agents/agent_factory.py +122 -0
- rasa/agents/agent_manager.py +211 -0
- rasa/agents/constants.py +43 -0
- rasa/agents/core/__init__.py +0 -0
- rasa/agents/core/agent_protocol.py +107 -0
- rasa/agents/core/types.py +81 -0
- rasa/agents/exceptions.py +38 -0
- rasa/agents/protocol/__init__.py +5 -0
- rasa/agents/protocol/a2a/__init__.py +0 -0
- rasa/agents/protocol/a2a/a2a_agent.py +879 -0
- rasa/agents/protocol/mcp/__init__.py +0 -0
- rasa/agents/protocol/mcp/mcp_base_agent.py +726 -0
- rasa/agents/protocol/mcp/mcp_open_agent.py +327 -0
- rasa/agents/protocol/mcp/mcp_task_agent.py +522 -0
- rasa/agents/schemas/__init__.py +13 -0
- rasa/agents/schemas/agent_input.py +38 -0
- rasa/agents/schemas/agent_output.py +26 -0
- rasa/agents/schemas/agent_tool_result.py +65 -0
- rasa/agents/schemas/agent_tool_schema.py +186 -0
- rasa/agents/templates/__init__.py +0 -0
- rasa/agents/templates/mcp_open_agent_prompt_template.jinja2 +20 -0
- rasa/agents/templates/mcp_task_agent_prompt_template.jinja2 +22 -0
- rasa/agents/utils.py +206 -0
- rasa/agents/validation.py +485 -0
- rasa/api.py +24 -9
- rasa/builder/config.py +6 -2
- rasa/builder/guardrails/{lakera.py → clients.py} +55 -5
- rasa/builder/guardrails/constants.py +3 -0
- rasa/builder/guardrails/models.py +45 -10
- rasa/builder/guardrails/policy_checker.py +324 -0
- rasa/builder/guardrails/utils.py +42 -276
- rasa/builder/llm_service.py +32 -5
- rasa/builder/models.py +1 -0
- rasa/builder/project_generator.py +6 -1
- rasa/builder/service.py +16 -13
- rasa/builder/training_service.py +18 -24
- rasa/builder/validation_service.py +1 -1
- rasa/cli/arguments/default_arguments.py +12 -0
- rasa/cli/arguments/run.py +2 -0
- rasa/cli/arguments/train.py +2 -0
- rasa/cli/data.py +10 -8
- rasa/cli/dialogue_understanding_test.py +10 -7
- rasa/cli/e2e_test.py +9 -6
- rasa/cli/evaluate.py +4 -2
- rasa/cli/export.py +5 -2
- rasa/cli/inspect.py +8 -4
- rasa/cli/interactive.py +5 -4
- rasa/cli/llm_fine_tuning.py +11 -6
- rasa/cli/project_templates/tutorial/credentials.yml +10 -0
- rasa/cli/run.py +12 -10
- rasa/cli/scaffold.py +4 -4
- rasa/cli/shell.py +9 -5
- rasa/cli/studio/studio.py +1 -1
- rasa/cli/test.py +34 -14
- rasa/cli/train.py +41 -28
- rasa/cli/utils.py +1 -393
- rasa/cli/validation/__init__.py +0 -0
- rasa/cli/validation/bot_config.py +223 -0
- rasa/cli/validation/config_path_validation.py +257 -0
- rasa/cli/x.py +8 -4
- rasa/constants.py +7 -1
- rasa/core/actions/action.py +51 -10
- rasa/core/actions/grpc_custom_action_executor.py +1 -1
- rasa/core/agent.py +19 -2
- rasa/core/available_agents.py +229 -0
- rasa/core/channels/__init__.py +82 -35
- rasa/core/channels/development_inspector.py +3 -3
- rasa/core/channels/inspector/README.md +25 -13
- rasa/core/channels/inspector/dist/assets/{arc-35222594.js → arc-6177260a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{blockDiagram-38ab4fdb-a0efbfd3.js → blockDiagram-38ab4fdb-b054f038.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-3d4e48cf-0584c0f2.js → c4Diagram-3d4e48cf-f25427d5.js} +1 -1
- rasa/core/channels/inspector/dist/assets/channel-bf9cbb34.js +1 -0
- rasa/core/channels/inspector/dist/assets/{classDiagram-70f12bd4-39f40dbe.js → classDiagram-70f12bd4-c7a2af53.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-f2320105-1ad755f3.js → classDiagram-v2-f2320105-58db65c0.js} +1 -1
- rasa/core/channels/inspector/dist/assets/clone-8f9083bb.js +1 -0
- rasa/core/channels/inspector/dist/assets/{createText-2e5e7dd3-b0f4f0fe.js → createText-2e5e7dd3-088372e2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-e0da2a9e-9039bff9.js → edges-e0da2a9e-58676240.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9861fffd-65c9b127.js → erDiagram-9861fffd-0c14d7c6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-956e92f1-4f08b38e.js → flowDb-956e92f1-ea63f85c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-66a62f08-e95c362a.js → flowDiagram-66a62f08-a2af48cd.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-9ecd5b59.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-4a651766-703c3015.js → flowchart-elk-definition-4a651766-6937abe7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-c361ad54-699328ea.js → ganttDiagram-c361ad54-7473f357.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-72cf32ee-04cf4b05.js → gitGraphDiagram-72cf32ee-d0c9405e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{graph-ee94449e.js → graph-0a6f8466.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-3862675e-940162b4.js → index-3862675e-7610671a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/index-74e01d94.js +1354 -0
- rasa/core/channels/inspector/dist/assets/{infoDiagram-f8f76790-c79c2866.js → infoDiagram-f8f76790-be397dc7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-49397b02-84489d30.js → journeyDiagram-49397b02-4cefbf62.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-a9aa9858.js → layout-e7fbc2bf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-eb73cf26.js → line-a8aa457c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-b3399f9a.js → linear-3351e0d2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-fc14e90a-b095bf1a.js → mindmap-definition-fc14e90a-b8cbf605.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-8a3498a8-07644b66.js → pieDiagram-8a3498a8-f327f774.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-120e2f19-573a3f9c.js → quadrantDiagram-120e2f19-2854c591.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-deff3bca-d457e1e1.js → requirementDiagram-deff3bca-964985d5.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-04a897e0-9d26e1a2.js → sankeyDiagram-04a897e0-edeb4f33.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-704730f1-3a9cde10.js → sequenceDiagram-704730f1-fcf70125.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-587899a1-4f3e8cec.js → stateDiagram-587899a1-0e770395.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-d93cdb3a-e617e5bf.js → stateDiagram-v2-d93cdb3a-af8dcd22.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-6aaf32cf-eab30d2f.js → styles-6aaf32cf-36a9e70d.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9a916d00-09994be2.js → styles-9a916d00-884a8b5b.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-c10674c1-b7110364.js → styles-c10674c1-dc097813.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-08f97a94-3ebc92ad.js → svgDrawCommon-08f97a94-5a2c7eed.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-85554ec2-7d13d2f2.js → timeline-definition-85554ec2-e89c4f6e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-e933f94c-488385e1.js → xychartDiagram-e933f94c-afb6fe56.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/package.json +18 -18
- rasa/core/channels/inspector/src/App.tsx +29 -4
- rasa/core/channels/inspector/src/components/DialogueAgentStack.tsx +108 -0
- rasa/core/channels/inspector/src/components/{DialogueStack.tsx → DialogueHistoryStack.tsx} +4 -2
- rasa/core/channels/inspector/src/helpers/audio/audiostream.ts +7 -4
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +4 -0
- rasa/core/channels/inspector/src/helpers/formatters.ts +24 -3
- rasa/core/channels/inspector/src/helpers/utils.test.ts +127 -0
- rasa/core/channels/inspector/src/helpers/utils.ts +66 -1
- rasa/core/channels/inspector/src/theme/base/styles.ts +19 -1
- rasa/core/channels/inspector/src/types.ts +21 -0
- rasa/core/channels/inspector/yarn.lock +336 -189
- rasa/core/channels/studio_chat.py +6 -6
- rasa/core/channels/telegram.py +4 -9
- rasa/core/channels/voice_stream/genesys.py +1 -1
- rasa/core/channels/voice_stream/tts/deepgram.py +140 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +5 -1
- rasa/core/channels/voice_stream/voice_channel.py +3 -0
- rasa/core/config/__init__.py +0 -0
- rasa/core/{available_endpoints.py → config/available_endpoints.py} +51 -16
- rasa/core/config/configuration.py +260 -0
- rasa/core/config/credentials.py +19 -0
- rasa/core/config/message_procesing_config.py +34 -0
- rasa/core/constants.py +4 -0
- rasa/core/policies/enterprise_search_policy.py +5 -3
- rasa/core/policies/flow_policy.py +4 -4
- rasa/core/policies/flows/agent_executor.py +632 -0
- rasa/core/policies/flows/flow_executor.py +136 -75
- rasa/core/policies/flows/mcp_tool_executor.py +298 -0
- rasa/core/policies/intentless_policy.py +1 -1
- rasa/core/policies/ted_policy.py +20 -12
- rasa/core/policies/unexpected_intent_policy.py +6 -0
- rasa/core/processor.py +68 -44
- rasa/core/run.py +37 -8
- rasa/core/test.py +4 -0
- rasa/core/tracker_stores/tracker_store.py +3 -7
- rasa/core/train.py +1 -1
- rasa/core/training/interactive.py +20 -18
- rasa/core/training/story_conflict.py +5 -5
- rasa/core/utils.py +22 -23
- rasa/dialogue_understanding/commands/__init__.py +8 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +19 -5
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +21 -2
- rasa/dialogue_understanding/commands/clarify_command.py +20 -2
- rasa/dialogue_understanding/commands/continue_agent_command.py +91 -0
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +21 -2
- rasa/dialogue_understanding/commands/restart_agent_command.py +162 -0
- rasa/dialogue_understanding/commands/start_flow_command.py +68 -7
- rasa/dialogue_understanding/commands/utils.py +124 -2
- rasa/dialogue_understanding/generator/command_parser.py +4 -0
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +50 -12
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -1
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +1 -1
- rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +66 -0
- rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +66 -0
- rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v3_claude_3_5_sonnet_20240620_template.jinja2 +89 -0
- rasa/dialogue_understanding/generator/prompt_templates/agent_command_prompt_v3_gpt_4o_2024_11_20_template.jinja2 +88 -0
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +42 -7
- rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +40 -3
- rasa/dialogue_understanding/generator/single_step/single_step_based_llm_command_generator.py +20 -3
- rasa/dialogue_understanding/patterns/cancel.py +27 -6
- rasa/dialogue_understanding/patterns/clarify.py +3 -14
- rasa/dialogue_understanding/patterns/continue_interrupted.py +239 -6
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +46 -8
- rasa/dialogue_understanding/processor/command_processor.py +136 -15
- rasa/dialogue_understanding/stack/dialogue_stack.py +98 -2
- rasa/dialogue_understanding/stack/frames/flow_stack_frame.py +57 -0
- rasa/dialogue_understanding/stack/utils.py +57 -3
- rasa/dialogue_understanding/utils.py +24 -4
- rasa/dialogue_understanding_test/du_test_runner.py +8 -3
- rasa/e2e_test/e2e_test_runner.py +13 -3
- rasa/engine/caching.py +2 -2
- rasa/engine/constants.py +1 -1
- rasa/engine/recipes/default_components.py +138 -49
- rasa/engine/recipes/default_recipe.py +108 -11
- rasa/engine/runner/dask.py +8 -5
- rasa/engine/validation.py +19 -6
- rasa/graph_components/validators/default_recipe_validator.py +86 -28
- rasa/hooks.py +5 -5
- rasa/llm_fine_tuning/utils.py +2 -2
- rasa/model_training.py +60 -47
- rasa/nlu/classifiers/diet_classifier.py +198 -98
- rasa/nlu/classifiers/logistic_regression_classifier.py +1 -4
- rasa/nlu/classifiers/mitie_intent_classifier.py +3 -0
- rasa/nlu/classifiers/sklearn_intent_classifier.py +1 -3
- rasa/nlu/extractors/crf_entity_extractor.py +9 -10
- rasa/nlu/extractors/mitie_entity_extractor.py +3 -0
- rasa/nlu/extractors/spacy_entity_extractor.py +3 -0
- rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +4 -0
- rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +5 -0
- rasa/nlu/featurizers/dense_featurizer/mitie_featurizer.py +2 -0
- rasa/nlu/featurizers/dense_featurizer/spacy_featurizer.py +3 -0
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +4 -2
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +4 -0
- rasa/nlu/selectors/response_selector.py +10 -2
- rasa/nlu/tokenizers/jieba_tokenizer.py +3 -4
- rasa/nlu/tokenizers/mitie_tokenizer.py +3 -2
- rasa/nlu/tokenizers/spacy_tokenizer.py +3 -2
- rasa/nlu/utils/mitie_utils.py +3 -0
- rasa/nlu/utils/spacy_utils.py +3 -2
- rasa/plugin.py +8 -8
- rasa/privacy/privacy_manager.py +12 -3
- rasa/server.py +15 -3
- rasa/shared/agents/__init__.py +0 -0
- rasa/shared/agents/auth/__init__.py +0 -0
- rasa/shared/agents/auth/agent_auth_factory.py +105 -0
- rasa/shared/agents/auth/agent_auth_manager.py +92 -0
- rasa/shared/agents/auth/auth_strategy/__init__.py +19 -0
- rasa/shared/agents/auth/auth_strategy/agent_auth_strategy.py +52 -0
- rasa/shared/agents/auth/auth_strategy/api_key_auth_strategy.py +42 -0
- rasa/shared/agents/auth/auth_strategy/bearer_token_auth_strategy.py +28 -0
- rasa/shared/agents/auth/auth_strategy/oauth2_auth_strategy.py +167 -0
- rasa/shared/agents/auth/constants.py +12 -0
- rasa/shared/agents/auth/types.py +12 -0
- rasa/shared/agents/utils.py +35 -0
- rasa/shared/constants.py +8 -0
- rasa/shared/core/constants.py +16 -1
- rasa/shared/core/domain.py +0 -7
- rasa/shared/core/events.py +327 -0
- rasa/shared/core/flows/constants.py +5 -0
- rasa/shared/core/flows/flows_list.py +21 -5
- rasa/shared/core/flows/flows_yaml_schema.json +119 -184
- rasa/shared/core/flows/steps/call.py +49 -5
- rasa/shared/core/flows/steps/collect.py +98 -13
- rasa/shared/core/flows/validation.py +372 -8
- rasa/shared/core/flows/yaml_flows_io.py +3 -2
- rasa/shared/core/slots.py +2 -2
- rasa/shared/core/trackers.py +5 -2
- rasa/shared/exceptions.py +16 -0
- rasa/shared/importers/rasa.py +1 -1
- rasa/shared/importers/utils.py +9 -3
- rasa/shared/providers/llm/_base_litellm_client.py +41 -9
- rasa/shared/providers/llm/litellm_router_llm_client.py +8 -4
- rasa/shared/providers/llm/llm_client.py +7 -3
- rasa/shared/providers/llm/llm_response.py +66 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +8 -4
- rasa/shared/utils/common.py +24 -0
- rasa/shared/utils/health_check/health_check.py +7 -3
- rasa/shared/utils/llm.py +39 -16
- rasa/shared/utils/mcp/__init__.py +0 -0
- rasa/shared/utils/mcp/server_connection.py +247 -0
- rasa/shared/utils/mcp/utils.py +20 -0
- rasa/shared/utils/schemas/events.py +42 -0
- rasa/shared/utils/yaml.py +3 -1
- rasa/studio/pull/pull.py +3 -2
- rasa/studio/train.py +8 -7
- rasa/studio/upload.py +3 -6
- rasa/telemetry.py +69 -5
- rasa/tracing/config.py +45 -12
- rasa/tracing/constants.py +14 -0
- rasa/tracing/instrumentation/attribute_extractors.py +142 -9
- rasa/tracing/instrumentation/instrumentation.py +626 -21
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +4 -4
- rasa/tracing/instrumentation/metrics.py +32 -0
- rasa/tracing/metric_instrument_provider.py +68 -0
- rasa/utils/common.py +92 -1
- rasa/utils/endpoints.py +11 -2
- rasa/utils/log_utils.py +96 -5
- rasa/utils/ml_utils.py +1 -1
- rasa/utils/tensorflow/__init__.py +7 -0
- rasa/utils/tensorflow/callback.py +136 -101
- rasa/utils/tensorflow/crf.py +1 -1
- rasa/utils/tensorflow/data_generator.py +21 -8
- rasa/utils/tensorflow/layers.py +21 -11
- rasa/utils/tensorflow/metrics.py +7 -3
- rasa/utils/tensorflow/models.py +56 -8
- rasa/utils/tensorflow/rasa_layers.py +8 -6
- rasa/utils/tensorflow/transformer.py +2 -3
- rasa/utils/train_utils.py +54 -24
- rasa/validator.py +5 -5
- rasa/version.py +1 -1
- {rasa_pro-3.14.0.dev20250922.dist-info → rasa_pro-3.14.0rc1.dist-info}/METADATA +46 -41
- {rasa_pro-3.14.0.dev20250922.dist-info → rasa_pro-3.14.0rc1.dist-info}/RECORD +285 -226
- rasa/builder/scrape_rasa_docs.py +0 -97
- rasa/core/channels/inspector/dist/assets/channel-8e08bed9.js +0 -1
- rasa/core/channels/inspector/dist/assets/clone-78c82dea.js +0 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-2b08f601.js +0 -1
- rasa/core/channels/inspector/dist/assets/index-c941dcb3.js +0 -1336
- {rasa_pro-3.14.0.dev20250922.dist-info → rasa_pro-3.14.0rc1.dist-info}/NOTICE +0 -0
- {rasa_pro-3.14.0.dev20250922.dist-info → rasa_pro-3.14.0rc1.dist-info}/WHEEL +0 -0
- {rasa_pro-3.14.0.dev20250922.dist-info → rasa_pro-3.14.0rc1.dist-info}/entry_points.txt +0 -0
rasa/builder/guardrails/utils.py
CHANGED
|
@@ -1,31 +1,13 @@
|
|
|
1
|
-
import
|
|
2
|
-
from functools import lru_cache
|
|
3
|
-
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
|
|
1
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type
|
|
4
2
|
|
|
5
3
|
import structlog
|
|
6
4
|
|
|
7
|
-
from rasa.builder.
|
|
8
|
-
ASSISTANT_HISTORY_GUARDRAIL_PROJECT_ID,
|
|
9
|
-
COPILOT_HISTORY_GUARDRAIL_PROJECT_ID,
|
|
10
|
-
)
|
|
11
|
-
from rasa.builder.copilot.constants import ROLE_COPILOT, ROLE_USER
|
|
12
|
-
from rasa.builder.copilot.copilot_response_handler import CopilotResponseHandler
|
|
13
|
-
from rasa.builder.copilot.models import (
|
|
14
|
-
CopilotChatMessage,
|
|
15
|
-
CopilotContext,
|
|
16
|
-
GeneratedContent,
|
|
17
|
-
ResponseCategory,
|
|
18
|
-
)
|
|
5
|
+
from rasa.builder.guardrails.clients import GuardrailsClient, LakeraAIGuardrails
|
|
19
6
|
from rasa.builder.guardrails.models import (
|
|
20
|
-
|
|
21
|
-
GuardrailResponse,
|
|
7
|
+
GuardrailRequest,
|
|
22
8
|
LakeraGuardrailRequest,
|
|
23
9
|
)
|
|
24
|
-
from rasa.
|
|
25
|
-
from rasa.builder.shared.tracker_context import (
|
|
26
|
-
AssistantConversationTurn,
|
|
27
|
-
TrackerContext,
|
|
28
|
-
)
|
|
10
|
+
from rasa.shared.constants import ROLE_USER
|
|
29
11
|
|
|
30
12
|
if TYPE_CHECKING:
|
|
31
13
|
from rasa.builder.guardrails.models import GuardrailType
|
|
@@ -64,265 +46,49 @@ def map_lakera_detector_type_to_guardrail_type(
|
|
|
64
46
|
return GuardrailType.OTHER
|
|
65
47
|
|
|
66
48
|
|
|
67
|
-
|
|
68
|
-
|
|
49
|
+
def create_guardrail_request(
|
|
50
|
+
client_type: Type[GuardrailsClient],
|
|
69
51
|
user_text: str,
|
|
70
52
|
hello_rasa_user_id: str,
|
|
71
53
|
hello_rasa_project_id: str,
|
|
72
|
-
|
|
73
|
-
) ->
|
|
74
|
-
"""
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
user_text:
|
|
78
|
-
hello_rasa_user_id:
|
|
79
|
-
hello_rasa_project_id:
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
request = LakeraGuardrailRequest(
|
|
89
|
-
lakera_project_id=lakera_project_id,
|
|
90
|
-
hello_rasa_user_id=hello_rasa_user_id,
|
|
91
|
-
hello_rasa_project_id=hello_rasa_project_id,
|
|
92
|
-
messages=[{"role": ROLE_USER, "content": user_text}],
|
|
93
|
-
)
|
|
94
|
-
|
|
95
|
-
return loop.create_task(llm_service.guardrails.send_request(request))
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
async def _detect_flagged_user_indices(
|
|
99
|
-
items: List[Tuple[int, str]],
|
|
100
|
-
*,
|
|
101
|
-
hello_rasa_user_id: Optional[str],
|
|
102
|
-
hello_rasa_project_id: Optional[str],
|
|
103
|
-
lakera_project_id: str,
|
|
104
|
-
log_prefix: str,
|
|
105
|
-
) -> Set[int]:
|
|
106
|
-
"""Run guardrail checks for provided (index, user_text) pairs.
|
|
107
|
-
|
|
108
|
-
Args:
|
|
109
|
-
items: List of tuples containing (index, user_text) to check.
|
|
110
|
-
hello_rasa_user_id: The user ID for the conversation.
|
|
111
|
-
hello_rasa_project_id: The project ID for the conversation.
|
|
112
|
-
lakera_project_id: The Lakera project ID to use for this check.
|
|
113
|
-
log_prefix: Prefix for logging messages.
|
|
114
|
-
|
|
115
|
-
Returns:
|
|
116
|
-
A set of indices that were flagged by the guardrails.
|
|
117
|
-
"""
|
|
118
|
-
if not items:
|
|
119
|
-
return set()
|
|
120
|
-
|
|
121
|
-
# 1) Group indices by logical request key (hashable by value)
|
|
122
|
-
indices_by_key: Dict[GuardrailRequestKey, List[int]] = {}
|
|
123
|
-
for idx, text in items:
|
|
124
|
-
key = GuardrailRequestKey(
|
|
125
|
-
user_text=(text or "").strip(),
|
|
126
|
-
hello_rasa_user_id=hello_rasa_user_id or "",
|
|
127
|
-
hello_rasa_project_id=hello_rasa_project_id or "",
|
|
128
|
-
lakera_project_id=lakera_project_id,
|
|
54
|
+
**kwargs: Any,
|
|
55
|
+
) -> GuardrailRequest:
|
|
56
|
+
"""Create a guardrail request."""
|
|
57
|
+
|
|
58
|
+
def _create_lakera_guardrail_request(
|
|
59
|
+
user_text: str,
|
|
60
|
+
hello_rasa_user_id: str,
|
|
61
|
+
hello_rasa_project_id: str,
|
|
62
|
+
**kwargs: Any,
|
|
63
|
+
) -> LakeraGuardrailRequest:
|
|
64
|
+
"""Create a Lakera guardrail request."""
|
|
65
|
+
return LakeraGuardrailRequest(
|
|
66
|
+
hello_rasa_user_id=hello_rasa_user_id,
|
|
67
|
+
hello_rasa_project_id=hello_rasa_project_id,
|
|
68
|
+
messages=[{"role": ROLE_USER, "content": user_text}],
|
|
69
|
+
**kwargs,
|
|
129
70
|
)
|
|
130
|
-
if not key.user_text:
|
|
131
|
-
continue
|
|
132
|
-
indices_by_key.setdefault(key, []).append(idx)
|
|
133
71
|
|
|
134
|
-
|
|
135
|
-
|
|
72
|
+
map_client_to_request: Dict[
|
|
73
|
+
Type[GuardrailsClient], Callable[..., GuardrailRequest]
|
|
74
|
+
] = {
|
|
75
|
+
LakeraAIGuardrails: _create_lakera_guardrail_request,
|
|
76
|
+
}
|
|
136
77
|
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
hello_rasa_project_id=key.hello_rasa_project_id,
|
|
144
|
-
lakera_project_id=key.lakera_project_id,
|
|
78
|
+
if client_type in map_client_to_request:
|
|
79
|
+
return map_client_to_request[client_type](
|
|
80
|
+
user_text,
|
|
81
|
+
hello_rasa_user_id,
|
|
82
|
+
hello_rasa_project_id,
|
|
83
|
+
**kwargs,
|
|
145
84
|
)
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
for key, response in zip(keys, responses):
|
|
155
|
-
if isinstance(response, Exception):
|
|
156
|
-
structlogger.warning(f"{log_prefix}.request_failed", error=str(response))
|
|
157
|
-
continue
|
|
158
|
-
if response.flagged:
|
|
159
|
-
flagged.update(indices_by_key.get(key, []))
|
|
160
|
-
|
|
161
|
-
return flagged
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
async def check_assistant_chat_for_policy_violations(
|
|
165
|
-
tracker_context: TrackerContext,
|
|
166
|
-
hello_rasa_user_id: Optional[str],
|
|
167
|
-
hello_rasa_project_id: Optional[str],
|
|
168
|
-
) -> TrackerContext:
|
|
169
|
-
"""Return a sanitised TrackerContext with unsafe turns removed.
|
|
170
|
-
|
|
171
|
-
Only user messages are moderated – assistant messages are assumed safe.
|
|
172
|
-
LRU cache is used, so each unique user text is checked once.
|
|
173
|
-
|
|
174
|
-
Args:
|
|
175
|
-
tracker_context: The TrackerContext containing conversation turns.
|
|
176
|
-
hello_rasa_user_id: The user ID for the conversation.
|
|
177
|
-
hello_rasa_project_id: The project ID for the conversation.
|
|
178
|
-
|
|
179
|
-
Returns:
|
|
180
|
-
TrackerContext with unsafe turns removed.
|
|
181
|
-
"""
|
|
182
|
-
# Collect (turn_index, user_text) for all turns with a user message
|
|
183
|
-
items: List[Tuple[int, str]] = []
|
|
184
|
-
for idx, turn in enumerate(tracker_context.conversation_turns):
|
|
185
|
-
user_message = turn.user_message
|
|
186
|
-
if not user_message:
|
|
187
|
-
continue
|
|
188
|
-
|
|
189
|
-
text = (user_message.text or "").strip()
|
|
190
|
-
if not text:
|
|
191
|
-
continue
|
|
192
|
-
|
|
193
|
-
items.append((idx, text))
|
|
194
|
-
|
|
195
|
-
flagged_turns = await _detect_flagged_user_indices(
|
|
196
|
-
items,
|
|
197
|
-
hello_rasa_user_id=hello_rasa_user_id,
|
|
198
|
-
hello_rasa_project_id=hello_rasa_project_id,
|
|
199
|
-
lakera_project_id=ASSISTANT_HISTORY_GUARDRAIL_PROJECT_ID,
|
|
200
|
-
log_prefix="assistant_guardrails",
|
|
201
|
-
)
|
|
202
|
-
|
|
203
|
-
if not flagged_turns:
|
|
204
|
-
return tracker_context
|
|
205
|
-
|
|
206
|
-
structlogger.info(
|
|
207
|
-
"assistant_guardrails.turns_flagged",
|
|
208
|
-
count=len(flagged_turns),
|
|
209
|
-
turn_indices=sorted(flagged_turns),
|
|
210
|
-
)
|
|
211
|
-
|
|
212
|
-
# Build a filtered TrackerContext
|
|
213
|
-
safe_turns: List[AssistantConversationTurn] = [
|
|
214
|
-
turn
|
|
215
|
-
for idx, turn in enumerate(tracker_context.conversation_turns)
|
|
216
|
-
if idx not in flagged_turns
|
|
217
|
-
]
|
|
218
|
-
|
|
219
|
-
new_tracker_context = tracker_context.copy(deep=True)
|
|
220
|
-
new_tracker_context.conversation_turns = safe_turns
|
|
221
|
-
return new_tracker_context
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
def _annotate_flagged_user_messages(
|
|
225
|
-
history: List[CopilotChatMessage], flagged_user_indices: Set[int]
|
|
226
|
-
) -> None:
|
|
227
|
-
"""Mark flagged user messages in-place on the original history.
|
|
228
|
-
|
|
229
|
-
Args:
|
|
230
|
-
history: The copilot chat history containing messages.
|
|
231
|
-
flagged_user_indices: Set of indices of user messages that were flagged.
|
|
232
|
-
"""
|
|
233
|
-
if not flagged_user_indices:
|
|
234
|
-
return
|
|
235
|
-
|
|
236
|
-
total = len(history)
|
|
237
|
-
for uidx in flagged_user_indices:
|
|
238
|
-
if 0 <= uidx < total and history[uidx].role == ROLE_USER:
|
|
239
|
-
history[
|
|
240
|
-
uidx
|
|
241
|
-
].response_category = ResponseCategory.GUARDRAILS_POLICY_VIOLATION
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
async def check_copilot_chat_for_policy_violations(
|
|
245
|
-
context: CopilotContext,
|
|
246
|
-
hello_rasa_user_id: Optional[str],
|
|
247
|
-
hello_rasa_project_id: Optional[str],
|
|
248
|
-
) -> Optional[GeneratedContent]:
|
|
249
|
-
"""Check the copilot chat history for guardrail policy violations.
|
|
250
|
-
|
|
251
|
-
Only user messages are moderated – assistant messages are assumed safe.
|
|
252
|
-
LRU cache is used, so each unique user text is checked once.
|
|
253
|
-
|
|
254
|
-
Args:
|
|
255
|
-
context: The CopilotContext containing the copilot chat history.
|
|
256
|
-
hello_rasa_user_id: The user ID for the conversation.
|
|
257
|
-
hello_rasa_project_id: The project ID for the conversation.
|
|
258
|
-
|
|
259
|
-
Returns:
|
|
260
|
-
Returns a default violation response if the system flags any user message,
|
|
261
|
-
otherwise return None.
|
|
262
|
-
"""
|
|
263
|
-
history = context.copilot_chat_history
|
|
264
|
-
|
|
265
|
-
# Collect (index, text) for user messages; skip ones already marked as violations
|
|
266
|
-
items: List[Tuple[int, str]] = []
|
|
267
|
-
for idx, message in enumerate(history):
|
|
268
|
-
if message.response_category == ResponseCategory.GUARDRAILS_POLICY_VIOLATION:
|
|
269
|
-
continue
|
|
270
|
-
if message.role != ROLE_USER:
|
|
271
|
-
continue
|
|
272
|
-
formatted_message = message.to_openai_format()
|
|
273
|
-
text = (formatted_message.get("content") or "").strip()
|
|
274
|
-
if not text:
|
|
275
|
-
continue
|
|
276
|
-
items.append((idx, text))
|
|
277
|
-
|
|
278
|
-
flagged_user_indices = await _detect_flagged_user_indices(
|
|
279
|
-
items,
|
|
280
|
-
hello_rasa_user_id=hello_rasa_user_id,
|
|
281
|
-
hello_rasa_project_id=hello_rasa_project_id,
|
|
282
|
-
lakera_project_id=COPILOT_HISTORY_GUARDRAIL_PROJECT_ID,
|
|
283
|
-
log_prefix="copilot_guardrails",
|
|
284
|
-
)
|
|
285
|
-
|
|
286
|
-
_annotate_flagged_user_messages(history, flagged_user_indices)
|
|
287
|
-
|
|
288
|
-
if not flagged_user_indices:
|
|
289
|
-
return None
|
|
290
|
-
|
|
291
|
-
# Identify the latest user message index in the current request
|
|
292
|
-
last_user_idx: Optional[int] = None
|
|
293
|
-
for i in range(len(history) - 1, -1, -1):
|
|
294
|
-
if getattr(history[i], "role", None) == ROLE_USER:
|
|
295
|
-
last_user_idx = i
|
|
296
|
-
break
|
|
297
|
-
|
|
298
|
-
# Remove flagged user messages and their next copilot messages
|
|
299
|
-
indices_to_remove: Set[int] = set()
|
|
300
|
-
total = len(history)
|
|
301
|
-
for uidx in flagged_user_indices:
|
|
302
|
-
indices_to_remove.add(uidx)
|
|
303
|
-
next_idx = uidx + 1
|
|
304
|
-
if (
|
|
305
|
-
next_idx < total
|
|
306
|
-
and getattr(history[next_idx], "role", None) == ROLE_COPILOT
|
|
307
|
-
):
|
|
308
|
-
indices_to_remove.add(next_idx)
|
|
309
|
-
|
|
310
|
-
# Apply sanitization
|
|
311
|
-
filtered_history = [
|
|
312
|
-
msg for i, msg in enumerate(history) if i not in indices_to_remove
|
|
313
|
-
]
|
|
314
|
-
if len(filtered_history) != len(history):
|
|
315
|
-
structlogger.info(
|
|
316
|
-
"copilot_guardrails.history_sanitized",
|
|
317
|
-
removed_indices=sorted(indices_to_remove),
|
|
318
|
-
removed_messages=len(history) - len(filtered_history),
|
|
319
|
-
kept_messages=len(filtered_history),
|
|
85
|
+
else:
|
|
86
|
+
message = f"Unsupported guardrail client: {type(client_type)}"
|
|
87
|
+
structlogger.error(
|
|
88
|
+
"guardrails_policy_checker"
|
|
89
|
+
".create_guardrail_request"
|
|
90
|
+
".unsupported_guardrail_client",
|
|
91
|
+
message=message,
|
|
92
|
+
guardrail_client=client_type,
|
|
320
93
|
)
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
# Block only if the latest user message in this request was flagged
|
|
324
|
-
if last_user_idx is not None and last_user_idx in flagged_user_indices:
|
|
325
|
-
return CopilotResponseHandler.respond_to_guardrail_policy_violations()
|
|
326
|
-
|
|
327
|
-
# Otherwise proceed (following messages are respected)
|
|
328
|
-
return None
|
|
94
|
+
raise ValueError(message)
|
rasa/builder/llm_service.py
CHANGED
|
@@ -19,7 +19,11 @@ from rasa.builder.copilot.copilot_templated_message_provider import (
|
|
|
19
19
|
load_copilot_internal_message_templates,
|
|
20
20
|
)
|
|
21
21
|
from rasa.builder.exceptions import LLMGenerationError
|
|
22
|
-
from rasa.builder.guardrails.
|
|
22
|
+
from rasa.builder.guardrails.clients import (
|
|
23
|
+
GuardrailsClient,
|
|
24
|
+
LakeraAIGuardrails,
|
|
25
|
+
)
|
|
26
|
+
from rasa.builder.guardrails.policy_checker import GuardrailsPolicyChecker
|
|
23
27
|
from rasa.constants import PACKAGE_NAME
|
|
24
28
|
from rasa.shared.constants import DOMAIN_SCHEMA_FILE, RESPONSES_SCHEMA_FILE
|
|
25
29
|
from rasa.shared.core.flows.yaml_flows_io import FLOWS_SCHEMA_FILE
|
|
@@ -37,7 +41,8 @@ class LLMService:
|
|
|
37
41
|
self._domain_schema: Optional[Dict[str, Any]] = None
|
|
38
42
|
self._flows_schema: Optional[Dict[str, Any]] = None
|
|
39
43
|
self._copilot: Optional[Copilot] = None
|
|
40
|
-
self._guardrails: Optional[
|
|
44
|
+
self._guardrails: Optional[GuardrailsClient] = None
|
|
45
|
+
self._guardrails_policy_checker: Optional[GuardrailsPolicyChecker] = None
|
|
41
46
|
self._copilot_response_handler: Optional[CopilotResponseHandler] = None
|
|
42
47
|
self._copilot_internal_message_templates: Optional[Dict[str, str]] = None
|
|
43
48
|
|
|
@@ -77,11 +82,14 @@ class LLMService:
|
|
|
77
82
|
raise
|
|
78
83
|
|
|
79
84
|
@property
|
|
80
|
-
def guardrails(self) ->
|
|
85
|
+
def guardrails(self) -> Optional[GuardrailsClient]:
|
|
81
86
|
"""Get or lazy create guardrails instance."""
|
|
82
|
-
if
|
|
83
|
-
|
|
87
|
+
if not config.ENABLE_GUARDRAILS:
|
|
88
|
+
return None
|
|
89
|
+
# TODO: Replace with Open Source guardrails implementation once it's ready
|
|
84
90
|
try:
|
|
91
|
+
if self._guardrails is None:
|
|
92
|
+
self._guardrails = LakeraAIGuardrails()
|
|
85
93
|
return self._guardrails
|
|
86
94
|
except Exception as e:
|
|
87
95
|
structlogger.error(
|
|
@@ -91,6 +99,25 @@ class LLMService:
|
|
|
91
99
|
)
|
|
92
100
|
raise
|
|
93
101
|
|
|
102
|
+
@property
|
|
103
|
+
def guardrails_policy_checker(self) -> Optional[GuardrailsPolicyChecker]:
|
|
104
|
+
"""Get or lazy create guardrails policy checker instance."""
|
|
105
|
+
try:
|
|
106
|
+
if self._guardrails_policy_checker is None and self.guardrails is not None:
|
|
107
|
+
self._guardrails_policy_checker = GuardrailsPolicyChecker(
|
|
108
|
+
self.guardrails
|
|
109
|
+
)
|
|
110
|
+
return self._guardrails_policy_checker
|
|
111
|
+
except Exception as e:
|
|
112
|
+
structlogger.error(
|
|
113
|
+
"llm_service.guardrails_policy_checker.error",
|
|
114
|
+
event_info=(
|
|
115
|
+
"LLM Service: Error getting guardrails policy checker instance."
|
|
116
|
+
),
|
|
117
|
+
error=str(e),
|
|
118
|
+
)
|
|
119
|
+
raise
|
|
120
|
+
|
|
94
121
|
@property
|
|
95
122
|
def copilot_internal_message_templates(self) -> Dict[str, str]:
|
|
96
123
|
"""Get or lazy load copilot internal message templates."""
|
rasa/builder/models.py
CHANGED
|
@@ -206,6 +206,7 @@ class TrainingInput(BaseModel):
|
|
|
206
206
|
|
|
207
207
|
importer: TrainingDataImporter = Field(..., description="Training data importer")
|
|
208
208
|
endpoints_file: Path = Field(..., description="Path to the endpoints file")
|
|
209
|
+
config_file: Path = Field(..., description="Path to the config file")
|
|
209
210
|
|
|
210
211
|
|
|
211
212
|
class AgentStatus(str, Enum):
|
|
@@ -192,11 +192,16 @@ class ProjectGenerator:
|
|
|
192
192
|
"""Get the endpoints file."""
|
|
193
193
|
return self.project_folder / "endpoints.yml"
|
|
194
194
|
|
|
195
|
+
def _get_config_file(self) -> Path:
|
|
196
|
+
"""Get the config file."""
|
|
197
|
+
return self.project_folder / "config.yml"
|
|
198
|
+
|
|
195
199
|
def get_training_input(self) -> TrainingInput:
|
|
196
200
|
"""Get the training input."""
|
|
197
201
|
return TrainingInput(
|
|
198
202
|
importer=self._create_importer(),
|
|
199
203
|
endpoints_file=self._get_endpoints_file(),
|
|
204
|
+
config_file=self._get_config_file(),
|
|
200
205
|
)
|
|
201
206
|
|
|
202
207
|
def _create_importer(self) -> TrainingDataImporter:
|
|
@@ -208,7 +213,7 @@ class ProjectGenerator:
|
|
|
208
213
|
domain_path = self.project_folder / "domain"
|
|
209
214
|
|
|
210
215
|
return TrainingDataImporter.load_from_config(
|
|
211
|
-
config_path=str(self.
|
|
216
|
+
config_path=str(self._get_config_file()),
|
|
212
217
|
domain_path=str(domain_path),
|
|
213
218
|
training_data_paths=[str(self.project_folder / "data")],
|
|
214
219
|
args={},
|
rasa/builder/service.py
CHANGED
|
@@ -17,6 +17,8 @@ from rasa.builder.config import (
|
|
|
17
17
|
COPILOT_HANDLER_ROLLING_BUFFER_SIZE,
|
|
18
18
|
GUARDRAILS_ENABLE_BLOCKING,
|
|
19
19
|
HELLO_RASA_PROJECT_ID,
|
|
20
|
+
LAKERA_ASSISTANT_HISTORY_GUARDRAIL_PROJECT_ID,
|
|
21
|
+
LAKERA_COPILOT_HISTORY_GUARDRAIL_PROJECT_ID,
|
|
20
22
|
)
|
|
21
23
|
from rasa.builder.copilot.constants import ROLE_USER, SIGNATURE_VERSION_V1
|
|
22
24
|
from rasa.builder.copilot.copilot_response_handler import CopilotResponseHandler
|
|
@@ -47,10 +49,6 @@ from rasa.builder.guardrails.constants import (
|
|
|
47
49
|
BlockScope,
|
|
48
50
|
)
|
|
49
51
|
from rasa.builder.guardrails.store import guardrails_store
|
|
50
|
-
from rasa.builder.guardrails.utils import (
|
|
51
|
-
check_assistant_chat_for_policy_violations,
|
|
52
|
-
check_copilot_chat_for_policy_violations,
|
|
53
|
-
)
|
|
54
52
|
from rasa.builder.job_manager import job_manager
|
|
55
53
|
from rasa.builder.jobs import (
|
|
56
54
|
run_prompt_to_bot_job,
|
|
@@ -1081,11 +1079,15 @@ async def copilot(request: Request) -> None:
|
|
|
1081
1079
|
tracker_context = TrackerContext.from_tracker(
|
|
1082
1080
|
tracker, max_turns=COPILOT_ASSISTANT_TRACKER_MAX_TURNS
|
|
1083
1081
|
)
|
|
1084
|
-
if
|
|
1085
|
-
tracker_context
|
|
1082
|
+
if (
|
|
1083
|
+
tracker_context is not None
|
|
1084
|
+
and llm_service.guardrails_policy_checker is not None
|
|
1085
|
+
):
|
|
1086
|
+
tracker_context = await llm_service.guardrails_policy_checker.check_assistant_chat_for_policy_violations( # noqa: E501
|
|
1086
1087
|
tracker_context=tracker_context,
|
|
1087
1088
|
hello_rasa_user_id=user_id,
|
|
1088
1089
|
hello_rasa_project_id=HELLO_RASA_PROJECT_ID,
|
|
1090
|
+
lakera_project_id=LAKERA_ASSISTANT_HISTORY_GUARDRAIL_PROJECT_ID,
|
|
1089
1091
|
)
|
|
1090
1092
|
|
|
1091
1093
|
# Copilot doesn't need to know about the docs and any file that is not a core
|
|
@@ -1103,13 +1105,14 @@ async def copilot(request: Request) -> None:
|
|
|
1103
1105
|
|
|
1104
1106
|
# 5. Run guardrail policy checks. If any policy violations are detected,
|
|
1105
1107
|
# send a response and end the stream.
|
|
1106
|
-
guardrail_response: Optional[
|
|
1107
|
-
|
|
1108
|
-
|
|
1109
|
-
|
|
1110
|
-
|
|
1111
|
-
|
|
1112
|
-
|
|
1108
|
+
guardrail_response: Optional[GeneratedContent] = None
|
|
1109
|
+
if llm_service.guardrails_policy_checker is not None:
|
|
1110
|
+
guardrail_response = await llm_service.guardrails_policy_checker.check_copilot_chat_for_policy_violations( # noqa: E501
|
|
1111
|
+
context=context,
|
|
1112
|
+
hello_rasa_user_id=user_id,
|
|
1113
|
+
hello_rasa_project_id=HELLO_RASA_PROJECT_ID,
|
|
1114
|
+
lakera_project_id=LAKERA_COPILOT_HISTORY_GUARDRAIL_PROJECT_ID,
|
|
1115
|
+
)
|
|
1113
1116
|
if guardrail_response is not None:
|
|
1114
1117
|
blocked_or_violation_message = (
|
|
1115
1118
|
await _handle_guardrail_violation_and_maybe_block(
|
rasa/builder/training_service.py
CHANGED
|
@@ -11,9 +11,10 @@ from rasa.builder.exceptions import AgentLoadError, TrainingError
|
|
|
11
11
|
from rasa.builder.models import TrainingInput
|
|
12
12
|
from rasa.core.agent import Agent, load_agent
|
|
13
13
|
from rasa.core.channels.studio_chat import StudioChatInput
|
|
14
|
-
from rasa.core.
|
|
14
|
+
from rasa.core.config.configuration import Configuration
|
|
15
15
|
from rasa.model import get_latest_model
|
|
16
16
|
from rasa.model_training import TrainingResult, train
|
|
17
|
+
from rasa.shared.constants import DEFAULT_ENDPOINTS_PATH
|
|
17
18
|
from rasa.shared.importers.importer import TrainingDataImporter
|
|
18
19
|
|
|
19
20
|
structlogger = structlog.get_logger()
|
|
@@ -42,14 +43,13 @@ async def train_and_load_agent(input: TrainingInput) -> Agent:
|
|
|
42
43
|
AgentLoadError: If agent loading fails
|
|
43
44
|
"""
|
|
44
45
|
try:
|
|
45
|
-
# Setup endpoints for training validation
|
|
46
|
-
await _setup_endpoints(input.endpoints_file)
|
|
47
|
-
|
|
48
46
|
# Train the model
|
|
49
|
-
training_result = await _train_model(
|
|
47
|
+
training_result = await _train_model(
|
|
48
|
+
input.importer, input.endpoints_file, input.config_file
|
|
49
|
+
)
|
|
50
50
|
|
|
51
51
|
# Load the agent
|
|
52
|
-
agent_instance = await _load_agent(training_result.model)
|
|
52
|
+
agent_instance = await _load_agent(training_result.model, input.endpoints_file)
|
|
53
53
|
|
|
54
54
|
# Verify agent is ready
|
|
55
55
|
if not agent_instance.is_ready():
|
|
@@ -96,7 +96,9 @@ async def try_load_existing_agent(project_folder: str) -> Optional[Agent]:
|
|
|
96
96
|
)
|
|
97
97
|
|
|
98
98
|
# Get available endpoints for agent loading
|
|
99
|
-
available_endpoints =
|
|
99
|
+
available_endpoints = Configuration.initialise_endpoints(
|
|
100
|
+
endpoints_path=Path(project_folder) / DEFAULT_ENDPOINTS_PATH
|
|
101
|
+
).endpoints
|
|
100
102
|
|
|
101
103
|
# Load the agent
|
|
102
104
|
agent = await load_agent(
|
|
@@ -124,27 +126,17 @@ async def try_load_existing_agent(project_folder: str) -> Optional[Agent]:
|
|
|
124
126
|
return None
|
|
125
127
|
|
|
126
128
|
|
|
127
|
-
async def
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
# Reset and load endpoints
|
|
131
|
-
AvailableEndpoints.reset_instance()
|
|
132
|
-
read_endpoints_from_path(endpoints_file)
|
|
133
|
-
|
|
134
|
-
structlogger.debug("training.endpoints_setup", endpoints_file=endpoints_file)
|
|
135
|
-
|
|
136
|
-
except Exception as e:
|
|
137
|
-
raise TrainingError(f"Failed to setup endpoints: {e}")
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
async def _train_model(importer: TrainingDataImporter) -> TrainingResult:
|
|
129
|
+
async def _train_model(
|
|
130
|
+
importer: TrainingDataImporter, endpoints_file: Path, config_file: Path
|
|
131
|
+
) -> TrainingResult:
|
|
141
132
|
"""Train the Rasa model."""
|
|
142
133
|
try:
|
|
143
134
|
structlogger.info("training.started")
|
|
144
135
|
|
|
145
136
|
training_result = await train(
|
|
146
137
|
domain="",
|
|
147
|
-
config=
|
|
138
|
+
config=str(config_file),
|
|
139
|
+
endpoints=str(endpoints_file),
|
|
148
140
|
training_files=None,
|
|
149
141
|
file_importer=importer,
|
|
150
142
|
)
|
|
@@ -160,12 +152,14 @@ async def _train_model(importer: TrainingDataImporter) -> TrainingResult:
|
|
|
160
152
|
raise TrainingError(f"Model training failed: {e}")
|
|
161
153
|
|
|
162
154
|
|
|
163
|
-
async def _load_agent(model_path: str) -> Agent:
|
|
155
|
+
async def _load_agent(model_path: str, endpoints_file: Path) -> Agent:
|
|
164
156
|
"""Load the trained agent."""
|
|
165
157
|
try:
|
|
166
158
|
structlogger.info("training.loading_agent", model_path=model_path)
|
|
167
159
|
|
|
168
|
-
available_endpoints =
|
|
160
|
+
available_endpoints = Configuration.initialise_endpoints(
|
|
161
|
+
endpoints_path=endpoints_file
|
|
162
|
+
).endpoints
|
|
169
163
|
if available_endpoints is None:
|
|
170
164
|
raise AgentLoadError("No endpoints available for agent loading")
|
|
171
165
|
|
|
@@ -9,7 +9,7 @@ import structlog
|
|
|
9
9
|
from rasa.builder import config
|
|
10
10
|
from rasa.builder.exceptions import ValidationError
|
|
11
11
|
from rasa.builder.logging_utils import capture_validation_logs
|
|
12
|
-
from rasa.cli.
|
|
12
|
+
from rasa.cli.validation.bot_config import validate_files
|
|
13
13
|
from rasa.shared.importers.importer import TrainingDataImporter
|
|
14
14
|
|
|
15
15
|
structlogger = structlog.get_logger()
|
|
@@ -3,6 +3,7 @@ import logging
|
|
|
3
3
|
from enum import Enum
|
|
4
4
|
from typing import List, Optional, Text, Union
|
|
5
5
|
|
|
6
|
+
from rasa.core.constants import DEFAULT_SUB_AGENTS
|
|
6
7
|
from rasa.core.persistor import RemoteStorageType, StorageType, parse_remote_storage
|
|
7
8
|
from rasa.shared.constants import (
|
|
8
9
|
DEFAULT_CONFIG_PATH,
|
|
@@ -217,3 +218,14 @@ def add_skip_validation_flag(
|
|
|
217
218
|
action="append",
|
|
218
219
|
help="Skip YAML validation for selected parts of the training data.",
|
|
219
220
|
)
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def add_sub_agents_param(
|
|
224
|
+
parser: Union[argparse.ArgumentParser, argparse._ActionsContainer],
|
|
225
|
+
) -> None:
|
|
226
|
+
parser.add_argument(
|
|
227
|
+
"--sub-agents",
|
|
228
|
+
type=str,
|
|
229
|
+
default=DEFAULT_SUB_AGENTS,
|
|
230
|
+
help="Directory that specifies sub-agents to use (default: %(default)s).",
|
|
231
|
+
)
|
rasa/cli/arguments/run.py
CHANGED
|
@@ -7,6 +7,7 @@ from rasa.cli.arguments.default_arguments import (
|
|
|
7
7
|
add_model_param,
|
|
8
8
|
add_remote_storage_param,
|
|
9
9
|
add_skip_validation_flag,
|
|
10
|
+
add_sub_agents_param,
|
|
10
11
|
)
|
|
11
12
|
from rasa.core import constants
|
|
12
13
|
from rasa.env import (
|
|
@@ -24,6 +25,7 @@ def set_run_arguments(parser: argparse.ArgumentParser) -> None:
|
|
|
24
25
|
add_server_arguments(parser)
|
|
25
26
|
add_inspect_argument(parser)
|
|
26
27
|
add_skip_validation_flag(parser)
|
|
28
|
+
add_sub_agents_param(parser)
|
|
27
29
|
|
|
28
30
|
|
|
29
31
|
def set_run_action_arguments(parser: argparse.ArgumentParser) -> None:
|
rasa/cli/arguments/train.py
CHANGED
|
@@ -10,6 +10,7 @@ from rasa.cli.arguments.default_arguments import (
|
|
|
10
10
|
add_remote_root_only_param,
|
|
11
11
|
add_remote_storage_param,
|
|
12
12
|
add_stories_param,
|
|
13
|
+
add_sub_agents_param,
|
|
13
14
|
)
|
|
14
15
|
from rasa.graph_components.providers.training_tracker_provider import (
|
|
15
16
|
TrainingTrackerProvider,
|
|
@@ -43,6 +44,7 @@ def set_train_arguments(parser: argparse.ArgumentParser) -> None:
|
|
|
43
44
|
)
|
|
44
45
|
add_remote_storage_param(parser)
|
|
45
46
|
add_remote_root_only_param(parser)
|
|
47
|
+
add_sub_agents_param(parser)
|
|
46
48
|
|
|
47
49
|
|
|
48
50
|
def set_train_core_arguments(parser: argparse.ArgumentParser) -> None:
|