rasa-pro 3.14.0.dev2__py3-none-any.whl → 3.14.0.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/agents/agent_manager.py +2 -2
- rasa/agents/constants.py +9 -0
- rasa/agents/core/agent_protocol.py +1 -2
- rasa/agents/protocol/a2a/a2a_agent.py +628 -17
- rasa/agents/protocol/mcp/mcp_base_agent.py +35 -56
- rasa/agents/protocol/mcp/mcp_open_agent.py +3 -2
- rasa/agents/protocol/mcp/mcp_task_agent.py +38 -16
- rasa/agents/schemas/__init__.py +8 -2
- rasa/agents/schemas/agent_input.py +15 -1
- rasa/agents/schemas/agent_tool_schema.py +23 -1
- rasa/agents/templates/mcp_task_agent_prompt_template.jinja2 +6 -2
- rasa/core/actions/action.py +13 -8
- rasa/core/available_agents.py +3 -0
- rasa/core/channels/development_inspector.py +3 -3
- rasa/core/channels/hangouts.py +2 -2
- rasa/core/channels/inspector/dist/assets/{arc-2e78c586.js → arc-63212852.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{blockDiagram-38ab4fdb-806b712e.js → blockDiagram-38ab4fdb-eecf6b13.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-3d4e48cf-0745efa9.js → c4Diagram-3d4e48cf-8f798a9a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/channel-0cd70adf.js +1 -0
- rasa/core/channels/inspector/dist/assets/{classDiagram-70f12bd4-7bd1082b.js → classDiagram-70f12bd4-df71a04c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-f2320105-d937ba49.js → classDiagram-v2-f2320105-9b275968.js} +1 -1
- rasa/core/channels/inspector/dist/assets/clone-a0f9c4ed.js +1 -0
- rasa/core/channels/inspector/dist/assets/{createText-2e5e7dd3-a2a564ca.js → createText-2e5e7dd3-1c669cad.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-e0da2a9e-b5256940.js → edges-e0da2a9e-b1553799.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9861fffd-e6883ad2.js → erDiagram-9861fffd-112388d6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-956e92f1-e576fc02.js → flowDb-956e92f1-fdebec47.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-66a62f08-2e298d01.js → flowDiagram-66a62f08-6280ede1.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-de9cc4aa.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-4a651766-dd7b150a.js → flowchart-elk-definition-4a651766-e1dc03e5.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-c361ad54-5b79575c.js → ganttDiagram-c361ad54-83f68c51.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-72cf32ee-3016f40a.js → gitGraphDiagram-72cf32ee-22f8666f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{graph-3e19170f.js → graph-ca9e6217.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-3862675e-eb9c86de.js → index-3862675e-c5ceb692.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-1bd9135e.js → index-3e293924.js} +3 -3
- rasa/core/channels/inspector/dist/assets/{infoDiagram-f8f76790-b4280e4d.js → infoDiagram-f8f76790-faa9999b.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-49397b02-556091f8.js → journeyDiagram-49397b02-c4dda8d9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-08436411.js → layout-d4307784.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-683c4f3b.js → line-0567aaa7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-cee6d791.js → linear-c11b95cf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-fc14e90a-a0bf0b1a.js → mindmap-definition-fc14e90a-0c7d3ca9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-8a3498a8-3730d5c4.js → pieDiagram-8a3498a8-34b433fa.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-120e2f19-12a20fed.js → quadrantDiagram-120e2f19-4cab816e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-deff3bca-b9732102.js → requirementDiagram-deff3bca-8c22fa9e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-04a897e0-a2e72776.js → sankeyDiagram-04a897e0-70ce9e8e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-704730f1-8b7a76bb.js → sequenceDiagram-704730f1-fbcd7fc9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-587899a1-e65853ac.js → stateDiagram-587899a1-45f05ea6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-d93cdb3a-6f58a44b.js → stateDiagram-v2-d93cdb3a-beab1ea6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-6aaf32cf-df25b934.js → styles-6aaf32cf-2f29dbd5.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9a916d00-88357141.js → styles-9a916d00-951eac83.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-c10674c1-d600174d.js → styles-c10674c1-897fbfdd.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-08f97a94-4adc3e0b.js → svgDrawCommon-08f97a94-d667fac1.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-85554ec2-42816fa1.js → timeline-definition-85554ec2-e3205144.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-e933f94c-621eb66a.js → xychartDiagram-e933f94c-4abeb0e2.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +1 -1
- rasa/core/channels/studio_chat.py +6 -6
- rasa/core/channels/voice_stream/genesys.py +1 -1
- rasa/core/policies/flow_policy.py +2 -2
- rasa/core/policies/flows/flow_executor.py +96 -17
- rasa/core/policies/flows/mcp_tool_executor.py +48 -11
- rasa/core/policies/intentless_policy.py +1 -1
- rasa/core/policies/unexpected_intent_policy.py +1 -0
- rasa/core/processor.py +12 -14
- rasa/core/tracker_stores/tracker_store.py +3 -7
- rasa/core/train.py +1 -1
- rasa/core/training/interactive.py +16 -16
- rasa/core/training/story_conflict.py +5 -5
- rasa/dialogue_understanding/commands/start_flow_command.py +36 -1
- rasa/dialogue_understanding/generator/llm_command_generator.py +1 -1
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +1 -1
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +5 -5
- rasa/dialogue_understanding/processor/command_processor.py +31 -15
- rasa/dialogue_understanding/stack/utils.py +14 -0
- rasa/e2e_test/e2e_test_runner.py +7 -2
- rasa/engine/caching.py +2 -2
- rasa/engine/recipes/default_components.py +10 -18
- rasa/graph_components/validators/default_recipe_validator.py +134 -134
- rasa/hooks.py +5 -5
- rasa/llm_fine_tuning/utils.py +2 -2
- rasa/model_manager/warm_rasa_process.py +1 -1
- rasa/nlu/extractors/extractor.py +2 -1
- rasa/plugin.py +8 -8
- rasa/privacy/privacy_manager.py +11 -2
- rasa/server.py +4 -2
- rasa/shared/core/events.py +9 -1
- rasa/shared/core/flows/flows_yaml_schema.json +12 -1
- rasa/shared/core/flows/steps/call.py +2 -0
- rasa/shared/core/flows/validation.py +3 -2
- rasa/shared/core/flows/yaml_flows_io.py +1 -1
- rasa/shared/core/slots.py +2 -2
- rasa/shared/core/trackers.py +5 -2
- rasa/shared/exceptions.py +4 -0
- rasa/shared/utils/yaml.py +3 -1
- rasa/tracing/instrumentation/instrumentation.py +8 -8
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +4 -4
- rasa/utils/common.py +26 -0
- rasa/utils/log_utils.py +1 -1
- rasa/utils/ml_utils.py +1 -1
- rasa/utils/tensorflow/rasa_layers.py +1 -1
- rasa/utils/train_utils.py +15 -15
- rasa/validator.py +16 -14
- rasa/version.py +1 -1
- {rasa_pro-3.14.0.dev2.dist-info → rasa_pro-3.14.0.dev4.dist-info}/METADATA +12 -15
- {rasa_pro-3.14.0.dev2.dist-info → rasa_pro-3.14.0.dev4.dist-info}/RECORD +107 -107
- rasa/core/channels/inspector/dist/assets/channel-c436ca7c.js +0 -1
- rasa/core/channels/inspector/dist/assets/clone-50dd656b.js +0 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-2b2aeaf8.js +0 -1
- {rasa_pro-3.14.0.dev2.dist-info → rasa_pro-3.14.0.dev4.dist-info}/NOTICE +0 -0
- {rasa_pro-3.14.0.dev2.dist-info → rasa_pro-3.14.0.dev4.dist-info}/WHEEL +0 -0
- {rasa_pro-3.14.0.dev2.dist-info → rasa_pro-3.14.0.dev4.dist-info}/entry_points.txt +0 -0
|
@@ -100,12 +100,12 @@ class StudioTrackerUpdatePlugin:
|
|
|
100
100
|
"""Remove tasks that have already completed."""
|
|
101
101
|
self.tasks = [task for task in self.tasks if not task.done()]
|
|
102
102
|
|
|
103
|
-
@hookimpl
|
|
103
|
+
@hookimpl
|
|
104
104
|
def after_new_user_message(self, tracker: "DialogueStateTracker") -> None:
|
|
105
105
|
"""Triggers a tracker update notification after a new user message."""
|
|
106
106
|
self.handle_tracker_update(tracker)
|
|
107
107
|
|
|
108
|
-
@hookimpl
|
|
108
|
+
@hookimpl
|
|
109
109
|
def after_action_executed(self, tracker: "DialogueStateTracker") -> None:
|
|
110
110
|
"""Triggers a tracker update notification after an action is executed."""
|
|
111
111
|
self.handle_tracker_update(tracker)
|
|
@@ -125,7 +125,7 @@ class StudioTrackerUpdatePlugin:
|
|
|
125
125
|
self.tasks.append(task)
|
|
126
126
|
self._cleanup_tasks()
|
|
127
127
|
|
|
128
|
-
@hookimpl
|
|
128
|
+
@hookimpl
|
|
129
129
|
def after_server_stop(self) -> None:
|
|
130
130
|
"""Cancels all remaining tasks when the server stops."""
|
|
131
131
|
self._cancel_tasks()
|
|
@@ -446,7 +446,7 @@ class StudioChatInput(SocketIOInput, VoiceInputChannel):
|
|
|
446
446
|
if sid in self.active_connections:
|
|
447
447
|
del self.active_connections[sid]
|
|
448
448
|
|
|
449
|
-
@hookimpl
|
|
449
|
+
@hookimpl
|
|
450
450
|
def after_server_stop(self) -> None:
|
|
451
451
|
"""Cleanup background tasks and active connections when the server stops."""
|
|
452
452
|
structlogger.info("studio_chat.after_server_stop.cleanup")
|
|
@@ -531,7 +531,7 @@ class StudioVoiceOutputChannel(VoiceOutputChannel):
|
|
|
531
531
|
|
|
532
532
|
def create_marker_message(self, recipient_id: str) -> Tuple[str, str]:
|
|
533
533
|
message_id = uuid.uuid4().hex
|
|
534
|
-
marker_data = {"marker": message_id}
|
|
534
|
+
marker_data: Dict[str, Any] = {"marker": message_id}
|
|
535
535
|
|
|
536
536
|
# Include comprehensive latency information if available
|
|
537
537
|
latency_data = {
|
|
@@ -546,7 +546,7 @@ class StudioVoiceOutputChannel(VoiceOutputChannel):
|
|
|
546
546
|
|
|
547
547
|
# Add latency data to marker if any metrics are available
|
|
548
548
|
if latency_data:
|
|
549
|
-
marker_data["latency"] = latency_data
|
|
549
|
+
marker_data["latency"] = latency_data
|
|
550
550
|
|
|
551
551
|
return json.dumps(marker_data), message_id
|
|
552
552
|
|
|
@@ -274,7 +274,7 @@ class GenesysInputChannel(VoiceInputChannel):
|
|
|
274
274
|
|
|
275
275
|
def handle_ping(self, ws: Websocket, message: dict) -> None:
|
|
276
276
|
"""Handle ping message from Genesys."""
|
|
277
|
-
response = {
|
|
277
|
+
response: Dict[str, Any] = {
|
|
278
278
|
"version": "2",
|
|
279
279
|
"type": "pong",
|
|
280
280
|
"seq": self._get_next_sequence(),
|
|
@@ -138,7 +138,7 @@ class FlowPolicy(Policy):
|
|
|
138
138
|
# create executor and predict next action
|
|
139
139
|
try:
|
|
140
140
|
prediction = await flow_executor.advance_flows(
|
|
141
|
-
tracker, domain.action_names_or_texts, flows
|
|
141
|
+
tracker, domain.action_names_or_texts, flows, domain.slots
|
|
142
142
|
)
|
|
143
143
|
return self._create_prediction_result(
|
|
144
144
|
prediction.action_name,
|
|
@@ -165,7 +165,7 @@ class FlowPolicy(Policy):
|
|
|
165
165
|
events = tracker.create_stack_updated_events(updated_stack)
|
|
166
166
|
tracker.update_with_events(events)
|
|
167
167
|
prediction = await flow_executor.advance_flows(
|
|
168
|
-
tracker, domain.action_names_or_texts, flows
|
|
168
|
+
tracker, domain.action_names_or_texts, flows, domain.slots
|
|
169
169
|
)
|
|
170
170
|
collected_events = events + (prediction.events or [])
|
|
171
171
|
return self._create_prediction_result(
|
|
@@ -10,12 +10,15 @@ from structlog.contextvars import bound_contextvars
|
|
|
10
10
|
|
|
11
11
|
from rasa.agents.agent_manager import AgentManager
|
|
12
12
|
from rasa.agents.constants import (
|
|
13
|
+
A2A_AGENT_CONTEXT_ID_KEY,
|
|
13
14
|
AGENT_METADATA_AGENT_RESPONSE_KEY,
|
|
14
15
|
AGENT_METADATA_EXIT_IF_KEY,
|
|
15
16
|
AGENT_METADATA_TOOL_RESULTS_KEY,
|
|
17
|
+
MAX_AGENT_RETRY_DELAY_SECONDS,
|
|
16
18
|
)
|
|
17
19
|
from rasa.agents.core.types import AgentStatus, ProtocolType
|
|
18
20
|
from rasa.agents.schemas import AgentInput, AgentOutput
|
|
21
|
+
from rasa.agents.schemas.agent_input import AgentInputSlot
|
|
19
22
|
from rasa.core.available_agents import AvailableAgents
|
|
20
23
|
from rasa.core.available_endpoints import AvailableEndpoints
|
|
21
24
|
from rasa.core.constants import ACTIVE_FLOW_METADATA_KEY, STEP_ID_METADATA_KEY
|
|
@@ -103,7 +106,7 @@ from rasa.shared.core.flows.steps import (
|
|
|
103
106
|
SetSlotsFlowStep,
|
|
104
107
|
)
|
|
105
108
|
from rasa.shared.core.flows.steps.constants import START_STEP
|
|
106
|
-
from rasa.shared.core.slots import Slot, SlotRejection
|
|
109
|
+
from rasa.shared.core.slots import CategoricalSlot, Slot, SlotRejection
|
|
107
110
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
108
111
|
from rasa.shared.utils.llm import tracker_as_readable_transcript
|
|
109
112
|
|
|
@@ -111,7 +114,6 @@ structlogger = structlog.get_logger()
|
|
|
111
114
|
|
|
112
115
|
MAX_NUMBER_OF_STEPS = 250
|
|
113
116
|
|
|
114
|
-
MAX_AGENT_RETRY_DELAY_SECONDS = 5
|
|
115
117
|
MAX_AGENT_RETRIES = 3
|
|
116
118
|
|
|
117
119
|
# Slots that should not be forwarded to sub-agents via AgentInput
|
|
@@ -387,7 +389,10 @@ def reset_scoped_slots(
|
|
|
387
389
|
|
|
388
390
|
|
|
389
391
|
async def advance_flows(
|
|
390
|
-
tracker: DialogueStateTracker,
|
|
392
|
+
tracker: DialogueStateTracker,
|
|
393
|
+
available_actions: List[str],
|
|
394
|
+
flows: FlowsList,
|
|
395
|
+
slots: List[Slot],
|
|
391
396
|
) -> FlowActionPrediction:
|
|
392
397
|
"""Advance the current flows until the next action.
|
|
393
398
|
|
|
@@ -395,6 +400,7 @@ async def advance_flows(
|
|
|
395
400
|
tracker: The tracker to get the next action for.
|
|
396
401
|
available_actions: The actions that are available in the domain.
|
|
397
402
|
flows: All flows.
|
|
403
|
+
slots: The slots that are available in the domain.
|
|
398
404
|
|
|
399
405
|
Returns:
|
|
400
406
|
The predicted action and the events to run.
|
|
@@ -404,13 +410,16 @@ async def advance_flows(
|
|
|
404
410
|
# if there are no flows, there is nothing to do
|
|
405
411
|
return FlowActionPrediction(None, 0.0)
|
|
406
412
|
|
|
407
|
-
return await advance_flows_until_next_action(
|
|
413
|
+
return await advance_flows_until_next_action(
|
|
414
|
+
tracker, available_actions, flows, slots
|
|
415
|
+
)
|
|
408
416
|
|
|
409
417
|
|
|
410
418
|
async def advance_flows_until_next_action(
|
|
411
419
|
tracker: DialogueStateTracker,
|
|
412
420
|
available_actions: List[str],
|
|
413
421
|
flows: FlowsList,
|
|
422
|
+
slots: List[Slot],
|
|
414
423
|
) -> FlowActionPrediction:
|
|
415
424
|
"""Advance the flow and select the next action to execute.
|
|
416
425
|
|
|
@@ -476,6 +485,7 @@ async def advance_flows_until_next_action(
|
|
|
476
485
|
available_actions,
|
|
477
486
|
flows,
|
|
478
487
|
previous_step_id,
|
|
488
|
+
slots,
|
|
479
489
|
)
|
|
480
490
|
new_events = step_result.events
|
|
481
491
|
if (
|
|
@@ -586,6 +596,7 @@ async def run_step(
|
|
|
586
596
|
available_actions: List[str],
|
|
587
597
|
flows: FlowsList,
|
|
588
598
|
previous_step_id: str,
|
|
599
|
+
slots: List[Slot],
|
|
589
600
|
) -> FlowStepResult:
|
|
590
601
|
"""Run a single step of a flow.
|
|
591
602
|
|
|
@@ -604,6 +615,7 @@ async def run_step(
|
|
|
604
615
|
available_actions: The actions that are available in the domain.
|
|
605
616
|
flows: All flows.
|
|
606
617
|
previous_step_id: The ID of the previous step.
|
|
618
|
+
slots: The slots that are available in the domain.
|
|
607
619
|
|
|
608
620
|
Returns:
|
|
609
621
|
A result of running the step describing where to transition to.
|
|
@@ -643,7 +655,7 @@ async def run_step(
|
|
|
643
655
|
return _run_link_step(initial_events, stack, step)
|
|
644
656
|
|
|
645
657
|
elif isinstance(step, CallFlowStep):
|
|
646
|
-
return await _run_call_step(initial_events, stack, step, tracker)
|
|
658
|
+
return await _run_call_step(initial_events, stack, step, tracker, slots)
|
|
647
659
|
|
|
648
660
|
elif isinstance(step, SetSlotsFlowStep):
|
|
649
661
|
return _run_set_slot_step(initial_events, step)
|
|
@@ -717,12 +729,13 @@ async def _run_call_step(
|
|
|
717
729
|
stack: DialogueStack,
|
|
718
730
|
step: CallFlowStep,
|
|
719
731
|
tracker: DialogueStateTracker,
|
|
732
|
+
slots: List[Slot],
|
|
720
733
|
) -> FlowStepResult:
|
|
721
734
|
structlogger.debug("flow.step.run.call")
|
|
722
735
|
if step.is_calling_mcp_tool():
|
|
723
736
|
return await call_mcp_tool(initial_events, stack, step, tracker)
|
|
724
737
|
elif step.is_calling_agent():
|
|
725
|
-
return await run_agent(initial_events, stack, step, tracker)
|
|
738
|
+
return await run_agent(initial_events, stack, step, tracker, slots)
|
|
726
739
|
else:
|
|
727
740
|
stack.push(
|
|
728
741
|
UserFlowStackFrame(
|
|
@@ -885,6 +898,7 @@ async def run_agent(
|
|
|
885
898
|
stack: DialogueStack,
|
|
886
899
|
step: CallFlowStep,
|
|
887
900
|
tracker: DialogueStateTracker,
|
|
901
|
+
slots: List[Slot],
|
|
888
902
|
) -> FlowStepResult:
|
|
889
903
|
"""Run an agent call step."""
|
|
890
904
|
structlogger.debug(
|
|
@@ -927,6 +941,9 @@ async def run_agent(
|
|
|
927
941
|
if agent_stack_frame and agent_stack_frame.metadata
|
|
928
942
|
else {}
|
|
929
943
|
)
|
|
944
|
+
_update_agent_input_metadata_with_events(
|
|
945
|
+
agent_input_metadata, step.call, step.flow_id, tracker
|
|
946
|
+
)
|
|
930
947
|
if step.exit_if:
|
|
931
948
|
# TODO: this is a temporary fix to reset the slots covered by the exit_if
|
|
932
949
|
if (
|
|
@@ -942,7 +959,7 @@ async def run_agent(
|
|
|
942
959
|
user_message=tracker.latest_message.text or ""
|
|
943
960
|
if tracker.latest_message
|
|
944
961
|
else "",
|
|
945
|
-
slots=
|
|
962
|
+
slots=_prepare_slots_for_agent(tracker.current_slot_values(), slots),
|
|
946
963
|
conversation_history=tracker_as_readable_transcript(tracker),
|
|
947
964
|
events=tracker.current_state().get("events") or [],
|
|
948
965
|
metadata=agent_input_metadata,
|
|
@@ -983,6 +1000,7 @@ async def run_agent(
|
|
|
983
1000
|
output.response_message or ""
|
|
984
1001
|
)
|
|
985
1002
|
output.metadata[AGENT_METADATA_TOOL_RESULTS_KEY] = output.tool_results or []
|
|
1003
|
+
_update_agent_events(final_events, output.metadata)
|
|
986
1004
|
|
|
987
1005
|
top_stack_frame = stack.top()
|
|
988
1006
|
# update the agent stack frame if it is already on the stack
|
|
@@ -1009,6 +1027,8 @@ async def run_agent(
|
|
|
1009
1027
|
)
|
|
1010
1028
|
return PauseFlowReturnPrediction(action_prediction)
|
|
1011
1029
|
elif output.status == AgentStatus.COMPLETED:
|
|
1030
|
+
output.metadata = output.metadata or {}
|
|
1031
|
+
_update_agent_events(final_events, output.metadata)
|
|
1012
1032
|
structlogger.debug(
|
|
1013
1033
|
"flow.step.run_agent.completed",
|
|
1014
1034
|
agent_name=step.call,
|
|
@@ -1027,6 +1047,8 @@ async def run_agent(
|
|
|
1027
1047
|
else:
|
|
1028
1048
|
return ContinueFlowWithNextStep(events=final_events)
|
|
1029
1049
|
elif output.status == AgentStatus.FATAL_ERROR:
|
|
1050
|
+
output.metadata = output.metadata or {}
|
|
1051
|
+
_update_agent_events(final_events, output.metadata)
|
|
1030
1052
|
# the agent failed, trigger pattern_internal_error
|
|
1031
1053
|
structlogger.error(
|
|
1032
1054
|
"flow.step.run_agent.fatal_error",
|
|
@@ -1044,6 +1066,8 @@ async def run_agent(
|
|
|
1044
1066
|
stack.push(InternalErrorPatternFlowStackFrame())
|
|
1045
1067
|
return ContinueFlowWithNextStep(events=final_events)
|
|
1046
1068
|
else:
|
|
1069
|
+
output.metadata = output.metadata or {}
|
|
1070
|
+
_update_agent_events(final_events, output.metadata)
|
|
1047
1071
|
structlogger.error(
|
|
1048
1072
|
"flow.step.run_agent.unknown_status",
|
|
1049
1073
|
agent_name=step.call,
|
|
@@ -1156,18 +1180,73 @@ async def _call_agent_with_retry(
|
|
|
1156
1180
|
)
|
|
1157
1181
|
|
|
1158
1182
|
|
|
1159
|
-
def
|
|
1160
|
-
|
|
1183
|
+
def _prepare_slots_for_agent(
|
|
1184
|
+
slot_values: Dict[str, Any], slot_definitions: List[Slot]
|
|
1185
|
+
) -> List[AgentInputSlot]:
|
|
1186
|
+
"""Prepare the slots for the agent.
|
|
1187
|
+
|
|
1188
|
+
Filter out slots that should not be forwarded to agents.
|
|
1189
|
+
Add the slot type and allowed values to the slot dictionary.
|
|
1161
1190
|
|
|
1162
1191
|
Args:
|
|
1163
|
-
|
|
1192
|
+
slot_values: The full slot dictionary from the tracker.
|
|
1193
|
+
slot_definitions: The slot definitions from the domain.
|
|
1164
1194
|
|
|
1165
1195
|
Returns:
|
|
1166
|
-
A
|
|
1167
|
-
`SLOTS_EXCLUDED_FOR_AGENT`.
|
|
1196
|
+
A list of slots containing the name, current value, type, and allowed values.
|
|
1168
1197
|
"""
|
|
1169
|
-
|
|
1170
|
-
|
|
1171
|
-
for
|
|
1172
|
-
|
|
1173
|
-
|
|
1198
|
+
|
|
1199
|
+
def _get_slot_definition(slot_name: str) -> Optional[Slot]:
|
|
1200
|
+
for slot in slot_definitions:
|
|
1201
|
+
if slot.name == slot_name:
|
|
1202
|
+
return slot
|
|
1203
|
+
return None
|
|
1204
|
+
|
|
1205
|
+
filtered_slots: List[AgentInputSlot] = []
|
|
1206
|
+
for key, value in slot_values.items():
|
|
1207
|
+
if key in SLOTS_EXCLUDED_FOR_AGENT:
|
|
1208
|
+
continue
|
|
1209
|
+
slot_definition = _get_slot_definition(key)
|
|
1210
|
+
if slot_definition:
|
|
1211
|
+
filtered_slots.append(
|
|
1212
|
+
AgentInputSlot(
|
|
1213
|
+
name=key,
|
|
1214
|
+
value=value,
|
|
1215
|
+
type=slot_definition.type_name if slot_definition else "any",
|
|
1216
|
+
allowed_values=slot_definition.values
|
|
1217
|
+
if isinstance(slot_definition, CategoricalSlot)
|
|
1218
|
+
else None,
|
|
1219
|
+
)
|
|
1220
|
+
)
|
|
1221
|
+
|
|
1222
|
+
return filtered_slots
|
|
1223
|
+
|
|
1224
|
+
|
|
1225
|
+
def _update_agent_events(events: List[Event], metadata: Dict[str, Any]) -> None:
|
|
1226
|
+
"""Update the agent events based on the agent output metadata if needed."""
|
|
1227
|
+
if A2A_AGENT_CONTEXT_ID_KEY in metadata:
|
|
1228
|
+
# If the context ID is present, we need to store it in the AgentStarted
|
|
1229
|
+
# event, so that it can be re-used later in case the agent is restarted.
|
|
1230
|
+
for event in events:
|
|
1231
|
+
if isinstance(event, AgentStarted):
|
|
1232
|
+
event.context_id = metadata[A2A_AGENT_CONTEXT_ID_KEY]
|
|
1233
|
+
|
|
1234
|
+
|
|
1235
|
+
def _update_agent_input_metadata_with_events(
|
|
1236
|
+
metadata: Dict[str, Any], agent_id: str, flow_id: str, tracker: DialogueStateTracker
|
|
1237
|
+
) -> None:
|
|
1238
|
+
"""Update the agent input metadata with the events."""
|
|
1239
|
+
agent_started_events = [
|
|
1240
|
+
event
|
|
1241
|
+
for event in tracker.events
|
|
1242
|
+
if type(event) == AgentStarted
|
|
1243
|
+
and event.agent_id == agent_id
|
|
1244
|
+
and event.flow_id == flow_id
|
|
1245
|
+
]
|
|
1246
|
+
if agent_started_events:
|
|
1247
|
+
# If we have context ID from the previous agent run, we want to
|
|
1248
|
+
# include it in the metadata so that the agent can continue the same
|
|
1249
|
+
# context.
|
|
1250
|
+
agent_started_event = agent_started_events[-1]
|
|
1251
|
+
if agent_started_event.context_id:
|
|
1252
|
+
metadata[A2A_AGENT_CONTEXT_ID_KEY] = agent_started_event.context_id
|
|
@@ -1,4 +1,6 @@
|
|
|
1
1
|
import json
|
|
2
|
+
import operator
|
|
3
|
+
from functools import reduce
|
|
2
4
|
from typing import Any, Dict, List, Optional
|
|
3
5
|
|
|
4
6
|
import structlog
|
|
@@ -17,10 +19,12 @@ from rasa.shared.core.events import Event, SlotSet
|
|
|
17
19
|
from rasa.shared.core.flows.steps import CallFlowStep
|
|
18
20
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
19
21
|
from rasa.shared.utils.mcp.server_connection import MCPServerConnection
|
|
22
|
+
from rasa.utils.common import ensure_jsonified_iterable
|
|
20
23
|
|
|
21
24
|
structlogger = structlog.get_logger()
|
|
22
25
|
|
|
23
|
-
|
|
26
|
+
CONFIG_RESULT_KEY = "result_key"
|
|
27
|
+
CONFIG_SLOT = "slot"
|
|
24
28
|
|
|
25
29
|
|
|
26
30
|
async def call_mcp_tool(
|
|
@@ -101,7 +105,7 @@ async def _execute_mcp_tool_call(
|
|
|
101
105
|
result = await mcp_server.call_tool(step.call, arguments)
|
|
102
106
|
|
|
103
107
|
# Handle tool execution result
|
|
104
|
-
if result.isError:
|
|
108
|
+
if result is None or result.isError:
|
|
105
109
|
return _handle_mcp_tool_error(
|
|
106
110
|
stack,
|
|
107
111
|
initial_events,
|
|
@@ -121,12 +125,20 @@ async def _execute_mcp_tool_call(
|
|
|
121
125
|
tool_name=step.call,
|
|
122
126
|
mcp_server=step.mcp_server,
|
|
123
127
|
result_content=result.content,
|
|
128
|
+
result_structured_content=result.structuredContent,
|
|
124
129
|
)
|
|
125
130
|
|
|
126
131
|
# Process successful result
|
|
127
|
-
set_slot_event
|
|
128
|
-
|
|
129
|
-
|
|
132
|
+
if set_slot_event := _process_tool_result(result, step.mapping["output"]):
|
|
133
|
+
initial_events.extend(set_slot_event)
|
|
134
|
+
else:
|
|
135
|
+
return _handle_mcp_tool_error(
|
|
136
|
+
stack,
|
|
137
|
+
initial_events,
|
|
138
|
+
f"Failed to process tool result for '{step.call}'.",
|
|
139
|
+
tool_name=step.call,
|
|
140
|
+
mcp_server=step.mcp_server,
|
|
141
|
+
)
|
|
130
142
|
|
|
131
143
|
return ContinueFlowWithNextStep(events=initial_events)
|
|
132
144
|
|
|
@@ -203,16 +215,41 @@ def _prepare_tool_arguments(
|
|
|
203
215
|
return arguments
|
|
204
216
|
|
|
205
217
|
|
|
218
|
+
def _jsonify_slot_value(value: Any) -> str | int | float | bool | None:
|
|
219
|
+
"""Prepare value for SlotSet: iterables -> JSON string, primitives -> as-is"""
|
|
220
|
+
if isinstance(value, (list, dict)) and len(value):
|
|
221
|
+
return json.dumps(ensure_jsonified_iterable(value))
|
|
222
|
+
return value
|
|
223
|
+
|
|
224
|
+
|
|
206
225
|
def _process_tool_result(
|
|
207
226
|
result: CallToolResult,
|
|
208
|
-
output_mapping: str,
|
|
209
|
-
) -> Optional[SlotSet]:
|
|
227
|
+
output_mapping: List[Dict[str, str]],
|
|
228
|
+
) -> Optional[List[SlotSet]]:
|
|
210
229
|
"""Create a SetSlot event for the tool result."""
|
|
211
230
|
try:
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
231
|
+
_result_as_dict = {"result": result.model_dump()}
|
|
232
|
+
slots = []
|
|
233
|
+
for mapping in output_mapping:
|
|
234
|
+
try:
|
|
235
|
+
# Use reduce to navigate through nested keys in the result
|
|
236
|
+
slot_value = reduce(
|
|
237
|
+
operator.getitem,
|
|
238
|
+
mapping[CONFIG_RESULT_KEY].split("."),
|
|
239
|
+
_result_as_dict,
|
|
240
|
+
)
|
|
241
|
+
slots.append(
|
|
242
|
+
SlotSet(mapping[CONFIG_SLOT], _jsonify_slot_value(slot_value))
|
|
243
|
+
)
|
|
244
|
+
except (KeyError, TypeError):
|
|
245
|
+
structlogger.error(
|
|
246
|
+
"call_mcp_tool.result_key_not_found_in_tool_result",
|
|
247
|
+
slot=mapping[CONFIG_SLOT],
|
|
248
|
+
result_key=mapping[CONFIG_RESULT_KEY],
|
|
249
|
+
result=_result_as_dict,
|
|
250
|
+
)
|
|
251
|
+
return None
|
|
252
|
+
return slots
|
|
216
253
|
except Exception as e:
|
|
217
254
|
structlogger.error(
|
|
218
255
|
"call_mcp_tool.result_processing_failed",
|
|
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text, Tuple
|
|
|
5
5
|
|
|
6
6
|
import structlog
|
|
7
7
|
import tiktoken
|
|
8
|
-
from deprecated import deprecated # type: ignore[import]
|
|
8
|
+
from deprecated import deprecated # type: ignore[import-untyped]
|
|
9
9
|
from jinja2 import Template
|
|
10
10
|
from langchain.docstore.document import Document
|
|
11
11
|
from langchain.schema.embeddings import Embeddings
|
|
@@ -624,6 +624,7 @@ class UnexpecTEDIntentPolicy(TEDPolicy):
|
|
|
624
624
|
query_intent = (
|
|
625
625
|
last_user_uttered_event.intent_name
|
|
626
626
|
if last_user_uttered_event is not None
|
|
627
|
+
and isinstance(last_user_uttered_event, UserUttered)
|
|
627
628
|
else ""
|
|
628
629
|
)
|
|
629
630
|
is_unlikely_intent = self._check_unlikely_intent(
|
rasa/core/processor.py
CHANGED
|
@@ -730,7 +730,7 @@ class MessageProcessor:
|
|
|
730
730
|
if not self.domain or self.domain.is_empty():
|
|
731
731
|
return
|
|
732
732
|
|
|
733
|
-
intent = parse_data[
|
|
733
|
+
intent = parse_data[INTENT][INTENT_NAME_KEY]
|
|
734
734
|
if intent and intent not in self.domain.intents:
|
|
735
735
|
rasa.shared.utils.io.raise_warning(
|
|
736
736
|
f"Parsed an intent '{intent}' "
|
|
@@ -739,7 +739,7 @@ class MessageProcessor:
|
|
|
739
739
|
docs=DOCS_URL_DOMAINS,
|
|
740
740
|
)
|
|
741
741
|
|
|
742
|
-
entities = parse_data[
|
|
742
|
+
entities = parse_data[ENTITIES] or []
|
|
743
743
|
for element in entities:
|
|
744
744
|
entity = element["entity"]
|
|
745
745
|
if entity and entity not in self.domain.entities:
|
|
@@ -823,9 +823,9 @@ class MessageProcessor:
|
|
|
823
823
|
self._update_full_retrieval_intent(parse_data)
|
|
824
824
|
structlogger.debug(
|
|
825
825
|
"processor.message.parse",
|
|
826
|
-
parse_data_text=copy.deepcopy(parse_data[
|
|
827
|
-
parse_data_intent=parse_data[
|
|
828
|
-
parse_data_entities=copy.deepcopy(parse_data[
|
|
826
|
+
parse_data_text=copy.deepcopy(parse_data[TEXT]),
|
|
827
|
+
parse_data_intent=parse_data[INTENT],
|
|
828
|
+
parse_data_entities=copy.deepcopy(parse_data[ENTITIES]),
|
|
829
829
|
)
|
|
830
830
|
|
|
831
831
|
self._check_for_unseen_features(parse_data)
|
|
@@ -974,7 +974,7 @@ class MessageProcessor:
|
|
|
974
974
|
f"invalid intent: {parse_data[INTENT]['name']}. "
|
|
975
975
|
f"Returning CannotHandleCommand() as a fallback."
|
|
976
976
|
),
|
|
977
|
-
invalid_intent=parse_data[INTENT][
|
|
977
|
+
invalid_intent=parse_data[INTENT][INTENT_NAME_KEY],
|
|
978
978
|
)
|
|
979
979
|
commands.append(
|
|
980
980
|
CannotHandleCommand(RASA_PATTERN_CANNOT_HANDLE_INVALID_INTENT)
|
|
@@ -984,7 +984,7 @@ class MessageProcessor:
|
|
|
984
984
|
|
|
985
985
|
def _contains_undefined_intent(self, message: Message) -> bool:
|
|
986
986
|
"""Checks if the message contains an undefined intent."""
|
|
987
|
-
intent_name = message.get(INTENT, {}).get(
|
|
987
|
+
intent_name = message.get(INTENT, {}).get(INTENT_NAME_KEY)
|
|
988
988
|
return intent_name is not None and intent_name not in self.domain.intents
|
|
989
989
|
|
|
990
990
|
async def _parse_message_with_graph(
|
|
@@ -1034,8 +1034,8 @@ class MessageProcessor:
|
|
|
1034
1034
|
tracker.update(
|
|
1035
1035
|
UserUttered(
|
|
1036
1036
|
message.text,
|
|
1037
|
-
parse_data[
|
|
1038
|
-
parse_data[
|
|
1037
|
+
parse_data[INTENT],
|
|
1038
|
+
parse_data[ENTITIES],
|
|
1039
1039
|
parse_data,
|
|
1040
1040
|
input_channel=message.input_channel,
|
|
1041
1041
|
message_id=message.message_id,
|
|
@@ -1044,7 +1044,7 @@ class MessageProcessor:
|
|
|
1044
1044
|
self.domain,
|
|
1045
1045
|
)
|
|
1046
1046
|
|
|
1047
|
-
if parse_data[
|
|
1047
|
+
if parse_data[ENTITIES]:
|
|
1048
1048
|
self._log_slots(tracker)
|
|
1049
1049
|
|
|
1050
1050
|
plugin_manager().hook.after_new_user_message(tracker=tracker)
|
|
@@ -1472,11 +1472,9 @@ class MessageProcessor:
|
|
|
1472
1472
|
# tracker has never expired if sessions are disabled
|
|
1473
1473
|
return False
|
|
1474
1474
|
|
|
1475
|
-
user_uttered_event
|
|
1476
|
-
UserUttered
|
|
1477
|
-
)
|
|
1475
|
+
user_uttered_event = tracker.get_last_event_for(UserUttered)
|
|
1478
1476
|
|
|
1479
|
-
if not user_uttered_event:
|
|
1477
|
+
if not user_uttered_event or not isinstance(user_uttered_event, UserUttered):
|
|
1480
1478
|
# there is no user event so far so the session should not be considered
|
|
1481
1479
|
# expired
|
|
1482
1480
|
return False
|
|
@@ -542,7 +542,7 @@ class FailSafeTrackerStore(TrackerStore):
|
|
|
542
542
|
return self._tracker_store.domain
|
|
543
543
|
|
|
544
544
|
@domain.setter
|
|
545
|
-
def domain(self, domain: Domain) -> None:
|
|
545
|
+
def domain(self, domain: Optional[Domain]) -> None:
|
|
546
546
|
self._tracker_store.domain = domain
|
|
547
547
|
|
|
548
548
|
if self._fallback_tracker_store:
|
|
@@ -805,9 +805,7 @@ class AwaitableTrackerStore(TrackerStore):
|
|
|
805
805
|
async def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
|
|
806
806
|
"""Wrapper to call `retrieve` method of primary tracker store."""
|
|
807
807
|
result = self._tracker_store.retrieve(sender_id)
|
|
808
|
-
return (
|
|
809
|
-
await result if isawaitable(result) else result # type: ignore[return-value, misc]
|
|
810
|
-
)
|
|
808
|
+
return await result if isawaitable(result) else result
|
|
811
809
|
|
|
812
810
|
async def keys(self) -> Iterable[Text]:
|
|
813
811
|
"""Wrapper to call `keys` method of primary tracker store."""
|
|
@@ -834,6 +832,4 @@ class AwaitableTrackerStore(TrackerStore):
|
|
|
834
832
|
) -> Optional[DialogueStateTracker]:
|
|
835
833
|
"""Wrapper to call `retrieve_full_tracker` method of primary tracker store."""
|
|
836
834
|
result = self._tracker_store.retrieve_full_tracker(conversation_id)
|
|
837
|
-
return (
|
|
838
|
-
await result if isawaitable(result) else result # type: ignore[return-value, misc]
|
|
839
|
-
)
|
|
835
|
+
return await result if isawaitable(result) else result
|
rasa/core/train.py
CHANGED
|
@@ -46,7 +46,7 @@ async def train_comparison_models(
|
|
|
46
46
|
output=str(Path(output_path, f"run_{r +1}")),
|
|
47
47
|
fixed_model_name=config_name + PERCENTAGE_KEY + str(percentage),
|
|
48
48
|
additional_arguments={
|
|
49
|
-
**additional_arguments,
|
|
49
|
+
**(additional_arguments or {}),
|
|
50
50
|
"exclusion_percentage": percentage,
|
|
51
51
|
},
|
|
52
52
|
)
|
|
@@ -91,7 +91,7 @@ from rasa.shared.core.training_data.visualization import (
|
|
|
91
91
|
)
|
|
92
92
|
from rasa.shared.exceptions import InvalidConfigException
|
|
93
93
|
from rasa.shared.importers.rasa import TrainingDataImporter
|
|
94
|
-
from rasa.shared.nlu.constants import INTENT_NAME_KEY, TEXT
|
|
94
|
+
from rasa.shared.nlu.constants import ENTITIES, INTENT, INTENT_NAME_KEY, TEXT
|
|
95
95
|
|
|
96
96
|
# noinspection PyProtectedMember
|
|
97
97
|
from rasa.shared.nlu.training_data import loading
|
|
@@ -789,7 +789,7 @@ def _collect_messages(events: List[Dict[Text, Any]]) -> List[Message]:
|
|
|
789
789
|
data = event.get("parse_data", {})
|
|
790
790
|
rasa_nlu_training_data_utils.remove_untrainable_entities_from(data)
|
|
791
791
|
msg = Message.build(
|
|
792
|
-
data[
|
|
792
|
+
data[TEXT], data[INTENT][INTENT_NAME_KEY], data[ENTITIES]
|
|
793
793
|
)
|
|
794
794
|
messages.append(msg)
|
|
795
795
|
elif event.get("event") == UserUtteranceReverted.type_name and messages:
|
|
@@ -901,13 +901,13 @@ def _get_nlu_target_format(export_path: Text) -> Text:
|
|
|
901
901
|
|
|
902
902
|
def _entities_from_messages(messages: List[Message]) -> List[Text]:
|
|
903
903
|
"""Return all entities that occur in at least one of the messages."""
|
|
904
|
-
return list({e["entity"] for m in messages for e in m.data.get(
|
|
904
|
+
return list({e["entity"] for m in messages for e in m.data.get(ENTITIES, [])})
|
|
905
905
|
|
|
906
906
|
|
|
907
907
|
def _intents_from_messages(messages: List[Message]) -> Set[Text]:
|
|
908
908
|
"""Return all intents that occur in at least one of the messages."""
|
|
909
909
|
# set of distinct intents
|
|
910
|
-
distinct_intents = {m.data[
|
|
910
|
+
distinct_intents = {m.data[INTENT] for m in messages if INTENT in m.data}
|
|
911
911
|
|
|
912
912
|
return distinct_intents
|
|
913
913
|
|
|
@@ -1191,11 +1191,11 @@ def _as_md_message(parse_data: Dict[Text, Any]) -> Text:
|
|
|
1191
1191
|
"""Display the parse data of a message in markdown format."""
|
|
1192
1192
|
from rasa.shared.nlu.training_data.formats.readerwriter import TrainingDataWriter
|
|
1193
1193
|
|
|
1194
|
-
if parse_data.get(
|
|
1195
|
-
return parse_data[
|
|
1194
|
+
if parse_data.get(TEXT, "").startswith(INTENT_MESSAGE_PREFIX):
|
|
1195
|
+
return parse_data[TEXT]
|
|
1196
1196
|
|
|
1197
|
-
if not parse_data.get(
|
|
1198
|
-
parse_data[
|
|
1197
|
+
if not parse_data.get(ENTITIES):
|
|
1198
|
+
parse_data[ENTITIES] = []
|
|
1199
1199
|
|
|
1200
1200
|
return TrainingDataWriter.generate_message(parse_data)
|
|
1201
1201
|
|
|
@@ -1207,7 +1207,7 @@ def _validate_user_regex(latest_message: Dict[Text, Any], intents: List[Text]) -
|
|
|
1207
1207
|
`/greet`. Return `True` if the intent is a known one.
|
|
1208
1208
|
"""
|
|
1209
1209
|
parse_data = latest_message.get("parse_data", {})
|
|
1210
|
-
intent = parse_data.get(
|
|
1210
|
+
intent = parse_data.get(INTENT, {}).get(INTENT_NAME_KEY)
|
|
1211
1211
|
|
|
1212
1212
|
if intent in intents:
|
|
1213
1213
|
return True
|
|
@@ -1224,8 +1224,8 @@ async def _validate_user_text(
|
|
|
1224
1224
|
"""
|
|
1225
1225
|
parse_data = latest_message.get("parse_data", {})
|
|
1226
1226
|
text = _as_md_message(parse_data)
|
|
1227
|
-
intent = parse_data.get(
|
|
1228
|
-
entities = parse_data.get(
|
|
1227
|
+
intent = parse_data.get(INTENT, {}).get(INTENT_NAME_KEY)
|
|
1228
|
+
entities = parse_data.get(ENTITIES, [])
|
|
1229
1229
|
if entities:
|
|
1230
1230
|
message = (
|
|
1231
1231
|
f"Is the intent '{intent}' correct for '{text}' and are "
|
|
@@ -1276,9 +1276,9 @@ async def _validate_nlu(
|
|
|
1276
1276
|
|
|
1277
1277
|
entities = await _correct_entities(latest_message, endpoint, conversation_id)
|
|
1278
1278
|
corrected_nlu = {
|
|
1279
|
-
|
|
1280
|
-
|
|
1281
|
-
|
|
1279
|
+
INTENT: corrected_intent,
|
|
1280
|
+
ENTITIES: entities,
|
|
1281
|
+
TEXT: latest_message.get("text"),
|
|
1282
1282
|
}
|
|
1283
1283
|
|
|
1284
1284
|
await _correct_wrong_nlu(corrected_nlu, events, endpoint, conversation_id)
|
|
@@ -1315,9 +1315,9 @@ def _merge_annotated_and_original_entities(
|
|
|
1315
1315
|
# overwrite entities which have already been
|
|
1316
1316
|
# annotated in the original annotation to preserve
|
|
1317
1317
|
# additional entity parser information
|
|
1318
|
-
entities = parse_annotated.get(
|
|
1318
|
+
entities = parse_annotated.get(ENTITIES, [])[:]
|
|
1319
1319
|
for i, entity in enumerate(entities):
|
|
1320
|
-
for original_entity in parse_original.get(
|
|
1320
|
+
for original_entity in parse_original.get(ENTITIES, []):
|
|
1321
1321
|
if _is_same_entity_annotation(entity, original_entity):
|
|
1322
1322
|
entities[i] = original_entity
|
|
1323
1323
|
break
|