rasa-pro 3.12.4__py3-none-any.whl → 3.13.0.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/cli/scaffold.py +1 -1
- rasa/core/actions/action.py +38 -28
- rasa/core/actions/action_run_slot_rejections.py +1 -1
- rasa/core/channels/studio_chat.py +16 -43
- rasa/core/information_retrieval/faiss.py +62 -6
- rasa/core/nlg/contextual_response_rephraser.py +7 -6
- rasa/core/nlg/generator.py +5 -21
- rasa/core/nlg/response.py +6 -43
- rasa/core/nlg/translate.py +0 -8
- rasa/core/policies/enterprise_search_policy.py +1 -0
- rasa/core/policies/intentless_policy.py +6 -59
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +2 -2
- rasa/dialogue_understanding/generator/_jinja_filters.py +9 -0
- rasa/dialogue_understanding/generator/constants.py +4 -0
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +18 -3
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +1 -1
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +3 -3
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +7 -4
- rasa/dialogue_understanding/processor/command_processor.py +20 -5
- rasa/dialogue_understanding/processor/command_processor_component.py +5 -2
- rasa/dialogue_understanding_test/command_metric_calculation.py +7 -40
- rasa/dialogue_understanding_test/command_metrics.py +38 -0
- rasa/dialogue_understanding_test/du_test_case.py +58 -25
- rasa/dialogue_understanding_test/du_test_result.py +228 -132
- rasa/dialogue_understanding_test/du_test_runner.py +10 -1
- rasa/dialogue_understanding_test/io.py +35 -8
- rasa/e2e_test/llm_judge_prompts/answer_relevance_prompt_template.jinja2 +1 -1
- rasa/engine/validation.py +36 -1
- rasa/model_manager/model_api.py +1 -1
- rasa/model_manager/socket_bridge.py +0 -7
- rasa/model_training.py +2 -1
- rasa/shared/constants.py +2 -0
- rasa/shared/core/policies/__init__.py +0 -0
- rasa/shared/core/policies/utils.py +87 -0
- rasa/shared/core/slot_mappings.py +12 -0
- rasa/shared/core/slots.py +1 -1
- rasa/shared/core/trackers.py +4 -10
- rasa/shared/providers/llm/default_litellm_llm_client.py +2 -2
- rasa/tracing/instrumentation/attribute_extractors.py +38 -6
- rasa/version.py +1 -1
- {rasa_pro-3.12.4.dist-info → rasa_pro-3.13.0.dev1.dist-info}/METADATA +5 -6
- {rasa_pro-3.12.4.dist-info → rasa_pro-3.13.0.dev1.dist-info}/RECORD +45 -43
- {rasa_pro-3.12.4.dist-info → rasa_pro-3.13.0.dev1.dist-info}/WHEEL +1 -1
- README.md +0 -38
- rasa/keys +0 -1
- {rasa_pro-3.12.4.dist-info → rasa_pro-3.13.0.dev1.dist-info}/NOTICE +0 -0
- {rasa_pro-3.12.4.dist-info → rasa_pro-3.13.0.dev1.dist-info}/entry_points.txt +0 -0
|
@@ -33,3 +33,7 @@ LLM_BASED_COMMAND_GENERATOR_CONFIG_FILE = "config.json"
|
|
|
33
33
|
|
|
34
34
|
MODEL_NAME_GPT_4O_2024_11_20 = "gpt-4o-2024-11-20"
|
|
35
35
|
MODEL_NAME_CLAUDE_3_5_SONNET_20240620 = "claude-3-5-sonnet-20240620"
|
|
36
|
+
|
|
37
|
+
# JINJA template filters
|
|
38
|
+
|
|
39
|
+
TO_JSON_ESCAPED_STRING_JINJA_FILTER = "to_json_escaped_string"
|
|
@@ -3,7 +3,7 @@ from functools import lru_cache
|
|
|
3
3
|
from typing import Any, Dict, List, Optional, Set, Text, Tuple, Union
|
|
4
4
|
|
|
5
5
|
import structlog
|
|
6
|
-
from jinja2 import Template
|
|
6
|
+
from jinja2 import Environment, Template, select_autoescape
|
|
7
7
|
|
|
8
8
|
import rasa.dialogue_understanding.generator.utils
|
|
9
9
|
import rasa.shared.utils.io
|
|
@@ -17,12 +17,14 @@ from rasa.dialogue_understanding.commands.handle_digressions_command import (
|
|
|
17
17
|
)
|
|
18
18
|
from rasa.dialogue_understanding.constants import KEY_MINIMIZE_NUM_CALLS
|
|
19
19
|
from rasa.dialogue_understanding.generator import CommandGenerator
|
|
20
|
+
from rasa.dialogue_understanding.generator._jinja_filters import to_json_escaped_string
|
|
20
21
|
from rasa.dialogue_understanding.generator.constants import (
|
|
21
22
|
DEFAULT_LLM_CONFIG,
|
|
22
23
|
FLOW_RETRIEVAL_ACTIVE_KEY,
|
|
23
24
|
FLOW_RETRIEVAL_FLOW_THRESHOLD,
|
|
24
25
|
FLOW_RETRIEVAL_KEY,
|
|
25
26
|
LLM_CONFIG_KEY,
|
|
27
|
+
TO_JSON_ESCAPED_STRING_JINJA_FILTER,
|
|
26
28
|
)
|
|
27
29
|
from rasa.dialogue_understanding.generator.flow_retrieval import FlowRetrieval
|
|
28
30
|
from rasa.dialogue_understanding.stack.utils import top_flow_frame
|
|
@@ -226,12 +228,25 @@ class LLMBasedCommandGenerator(
|
|
|
226
228
|
|
|
227
229
|
@lru_cache
|
|
228
230
|
def compile_template(self, template: str) -> Template:
|
|
229
|
-
"""
|
|
231
|
+
"""
|
|
232
|
+
Compile the prompt template and register custom filters.
|
|
230
233
|
|
|
231
234
|
Compiling the template is an expensive operation,
|
|
232
235
|
so we cache the result.
|
|
233
236
|
"""
|
|
234
|
-
|
|
237
|
+
# Create an environment
|
|
238
|
+
# Autoescaping disabled explicitly for LLM prompt templates rendered from
|
|
239
|
+
# strings (safe, not HTML)
|
|
240
|
+
env = Environment(
|
|
241
|
+
autoescape=select_autoescape(
|
|
242
|
+
disabled_extensions=["jinja2"], default_for_string=False, default=True
|
|
243
|
+
)
|
|
244
|
+
)
|
|
245
|
+
# Register filters
|
|
246
|
+
env.filters[TO_JSON_ESCAPED_STRING_JINJA_FILTER] = to_json_escaped_string
|
|
247
|
+
|
|
248
|
+
# Return the template which can leverage registered filters
|
|
249
|
+
return env.from_string(template)
|
|
235
250
|
|
|
236
251
|
@classmethod
|
|
237
252
|
def load_prompt_template_from_model_storage(
|
|
@@ -140,7 +140,7 @@ class NLUCommandAdapter(GraphComponent, CommandGenerator):
|
|
|
140
140
|
|
|
141
141
|
if commands:
|
|
142
142
|
commands = clean_up_commands(
|
|
143
|
-
commands, tracker, flows, self._execution_context
|
|
143
|
+
commands, tracker, flows, self._execution_context, domain
|
|
144
144
|
)
|
|
145
145
|
log_llm(
|
|
146
146
|
logger=structlogger,
|
|
@@ -8,7 +8,7 @@ Your task is to analyze the current conversation context and generate a list of
|
|
|
8
8
|
* `set slot slot_name slot_value`: Slot setting. For example, `set slot transfer_money_recipient Freddy`. Can be used to correct and change previously set values.
|
|
9
9
|
* `cancel flow`: Cancelling the current flow.
|
|
10
10
|
* `disambiguate flows flow_name1 flow_name2 ... flow_name_n`: Disambiguate which flow should be started when user input is ambiguous by listing the potential flows as options. For example, `disambiguate flows list_contacts add_contact remove_contact ...` if the user just wrote "contacts".
|
|
11
|
-
* `
|
|
11
|
+
* `search and reply`: Responding to the user's questions by supplying relevant information, such as answering FAQs or explaining services.
|
|
12
12
|
* `offtopic reply`: Responding to casual or social user messages that are unrelated to any flows, engaging in friendly conversation and addressing off-topic remarks.
|
|
13
13
|
* `hand over`: Handing over to a human, in case the user seems frustrated or explicitly asks to speak to one.
|
|
14
14
|
|
|
@@ -33,7 +33,7 @@ Your task is to analyze the current conversation context and generate a list of
|
|
|
33
33
|
## Available Flows and Slots
|
|
34
34
|
Use the following structured data:
|
|
35
35
|
```json
|
|
36
|
-
{"flows":[{% for flow in available_flows %}{"name":"{{ flow.name }}","description":
|
|
36
|
+
{"flows":[{% for flow in available_flows %}{"name":"{{ flow.name }}","description":{{ flow.description | to_json_escaped_string }}{% if flow.slots %},"slots":[{% for slot in flow.slots %}{"name":"{{ slot.name }}"{% if slot.description %},"description":{{ slot.description | to_json_escaped_string }}{% endif %}{% if slot.allowed_values %},"allowed_values":{{ slot.allowed_values }}{% endif %}}{% if not loop.last %},{% endif %}{% endfor %}]{% endif %}}{% if not loop.last %},{% endif %}{% endfor %}]}
|
|
37
37
|
```
|
|
38
38
|
|
|
39
39
|
--
|
|
@@ -41,7 +41,7 @@ Use the following structured data:
|
|
|
41
41
|
## Current State
|
|
42
42
|
{% if current_flow != None %}Use the following structured data:
|
|
43
43
|
```json
|
|
44
|
-
{"active_flow":"{{ current_flow }}","current_step":{"requested_slot":"{{ current_slot }}","requested_slot_description":
|
|
44
|
+
{"active_flow":"{{ current_flow }}","current_step":{"requested_slot":"{{ current_slot }}","requested_slot_description":{{ current_slot_description | to_json_escaped_string }}},"slots":[{% for slot in flow_slots %}{"name":"{{ slot.name }}","value":"{{ slot.value }}","type":"{{ slot.type }}"{% if slot.description %},"description":{{ slot.description | to_json_escaped_string }}{% endif %}{% if slot.allowed_values %},"allowed_values":"{{ slot.allowed_values }}"{% endif %}}{% if not loop.last %},{% endif %}{% endfor %}]}
|
|
45
45
|
```{% else %}
|
|
46
46
|
You are currently not inside any flow.{% endif %}
|
|
47
47
|
|
|
@@ -6,7 +6,7 @@ Your task is to analyze the current conversation context and generate a list of
|
|
|
6
6
|
## Available Flows and Slots
|
|
7
7
|
Use the following structured data:
|
|
8
8
|
```json
|
|
9
|
-
{"flows":[{% for flow in available_flows %}{"name":"{{ flow.name }}","description":
|
|
9
|
+
{"flows":[{% for flow in available_flows %}{"name":"{{ flow.name }}","description":{{ flow.description | to_json_escaped_string }}{% if flow.slots %},"slots":[{% for slot in flow.slots %}{"name":"{{ slot.name }}"{% if slot.description %},"description":{{ slot.description | to_json_escaped_string }}{% endif %}{% if slot.allowed_values %},"allowed_values":{{ slot.allowed_values }}{% endif %}}{% if not loop.last %},{% endif %}{% endfor %}]{% endif %}}{% if not loop.last %},{% endif %}{% endfor %}]}
|
|
10
10
|
```
|
|
11
11
|
|
|
12
12
|
---
|
|
@@ -16,7 +16,7 @@ Use the following structured data:
|
|
|
16
16
|
* `set slot slot_name slot_value`: Slot setting. For example, `set slot transfer_money_recipient Freddy`. Can be used to correct and change previously set values.
|
|
17
17
|
* `cancel flow`: Cancelling the current flow.
|
|
18
18
|
* `disambiguate flows flow_name1 flow_name2 ... flow_name_n`: Disambiguate which flow should be started when user input is ambiguous by listing the potential flows as options. For example, `disambiguate flows list_contacts add_contact remove_contact ...` if the user just wrote "contacts".
|
|
19
|
-
* `
|
|
19
|
+
* `search and reply`: Responding to the user's message by accessing and supplying relevant information from the knowledge base to address their inquiry effectively.
|
|
20
20
|
* `offtopic reply`: Responding to casual or social user messages that are unrelated to any flows, engaging in friendly conversation and addressing off-topic remarks.
|
|
21
21
|
* `hand over`: Handing over to a human, in case the user seems frustrated or explicitly asks to speak to one.
|
|
22
22
|
|
|
@@ -27,8 +27,11 @@ Use the following structured data:
|
|
|
27
27
|
* For categorical slots try to match the user message with allowed slot values. Use "other" if you cannot match it.
|
|
28
28
|
* Set the boolean slots based on the user response. Map positive responses to `True`, and negative to `False`.
|
|
29
29
|
* Extract text slot values exactly as provided by the user. Avoid assumptions, format changes, or partial extractions.
|
|
30
|
-
* Only use information provided by the user.
|
|
31
30
|
* Use clarification in ambiguous cases.
|
|
31
|
+
* Use `disambiguate flows` only when multiple flows could fit the same message (e.g., "card" could mean `block_card` or `replace_card`).
|
|
32
|
+
* A user asking a question does not automatically imply that they want `search and reply`. The objective is to help them complete a business process if its possible to do so via a flow.
|
|
33
|
+
* **Flow Priority**: If a user message can be addressed by starting a flow (even if it looks like a general question), ALWAYS start the flow first. Example: If the user says "How do I activate my card?", use `start flow activate_card` instead of `search and reply`. Only use `search and reply` if no flow matches the request.
|
|
34
|
+
* Only use information provided by the user.
|
|
32
35
|
* Multiple flows can be started. If a user wants to digress into a second flow, you do not need to cancel the current flow.
|
|
33
36
|
* Do not cancel the flow unless the user explicitly requests it.
|
|
34
37
|
* Strictly adhere to the provided action format.
|
|
@@ -40,7 +43,7 @@ Use the following structured data:
|
|
|
40
43
|
## Current State
|
|
41
44
|
{% if current_flow != None %}Use the following structured data:
|
|
42
45
|
```json
|
|
43
|
-
{"active_flow":"{{ current_flow }}","current_step":{"requested_slot":"{{ current_slot }}","requested_slot_description":
|
|
46
|
+
{"active_flow":"{{ current_flow }}","current_step":{"requested_slot":"{{ current_slot }}","requested_slot_description":{{ current_slot_description | to_json_escaped_string }}},"slots":[{% for slot in flow_slots %}{"name":"{{ slot.name }}","value":"{{ slot.value }}","type":"{{ slot.type }}"{% if slot.description %},"description":{{ slot.description | to_json_escaped_string }}{% endif %}{% if slot.allowed_values %},"allowed_values":"{{ slot.allowed_values }}"{% endif %}}{% if not loop.last %},{% endif %}{% endfor %}]}
|
|
44
47
|
```{% else %}
|
|
45
48
|
You are currently not inside any flow.{% endif %}
|
|
46
49
|
|
|
@@ -54,9 +54,11 @@ from rasa.shared.core.constants import (
|
|
|
54
54
|
FLOW_HASHES_SLOT,
|
|
55
55
|
SlotMappingType,
|
|
56
56
|
)
|
|
57
|
+
from rasa.shared.core.domain import Domain
|
|
57
58
|
from rasa.shared.core.events import Event, SlotSet
|
|
58
59
|
from rasa.shared.core.flows import FlowsList
|
|
59
60
|
from rasa.shared.core.flows.steps.collect import CollectInformationFlowStep
|
|
61
|
+
from rasa.shared.core.policies.utils import contains_intentless_policy_responses
|
|
60
62
|
from rasa.shared.core.slot_mappings import SlotMapping
|
|
61
63
|
from rasa.shared.core.slots import Slot
|
|
62
64
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
@@ -197,6 +199,7 @@ def execute_commands(
|
|
|
197
199
|
all_flows: FlowsList,
|
|
198
200
|
execution_context: ExecutionContext,
|
|
199
201
|
story_graph: Optional[StoryGraph] = None,
|
|
202
|
+
domain: Optional[Domain] = None,
|
|
200
203
|
) -> List[Event]:
|
|
201
204
|
"""Executes a list of commands.
|
|
202
205
|
|
|
@@ -206,6 +209,7 @@ def execute_commands(
|
|
|
206
209
|
all_flows: All flows.
|
|
207
210
|
execution_context: Information about the single graph run.
|
|
208
211
|
story_graph: StoryGraph object with stories available for training.
|
|
212
|
+
domain: The domain of the bot.
|
|
209
213
|
|
|
210
214
|
Returns:
|
|
211
215
|
A list of the events that were created.
|
|
@@ -214,7 +218,7 @@ def execute_commands(
|
|
|
214
218
|
original_tracker = tracker.copy()
|
|
215
219
|
|
|
216
220
|
commands = clean_up_commands(
|
|
217
|
-
commands, tracker, all_flows, execution_context, story_graph
|
|
221
|
+
commands, tracker, all_flows, execution_context, story_graph, domain
|
|
218
222
|
)
|
|
219
223
|
|
|
220
224
|
updated_flows = find_updated_flows(tracker, all_flows)
|
|
@@ -381,6 +385,7 @@ def clean_up_commands(
|
|
|
381
385
|
all_flows: FlowsList,
|
|
382
386
|
execution_context: ExecutionContext,
|
|
383
387
|
story_graph: Optional[StoryGraph] = None,
|
|
388
|
+
domain: Optional[Domain] = None,
|
|
384
389
|
) -> List[Command]:
|
|
385
390
|
"""Clean up a list of commands.
|
|
386
391
|
|
|
@@ -396,10 +401,13 @@ def clean_up_commands(
|
|
|
396
401
|
all_flows: All flows.
|
|
397
402
|
execution_context: Information about a single graph run.
|
|
398
403
|
story_graph: StoryGraph object with stories available for training.
|
|
404
|
+
domain: The domain of the bot.
|
|
399
405
|
|
|
400
406
|
Returns:
|
|
401
407
|
The cleaned up commands.
|
|
402
408
|
"""
|
|
409
|
+
domain = domain if domain else Domain.empty()
|
|
410
|
+
|
|
403
411
|
slots_so_far, active_flow = filled_slots_for_active_flow(tracker, all_flows)
|
|
404
412
|
|
|
405
413
|
clean_commands: List[Command] = []
|
|
@@ -465,7 +473,12 @@ def clean_up_commands(
|
|
|
465
473
|
# handle chitchat command differently from other free-form answer commands
|
|
466
474
|
elif isinstance(command, ChitChatAnswerCommand):
|
|
467
475
|
clean_commands = clean_up_chitchat_command(
|
|
468
|
-
clean_commands,
|
|
476
|
+
clean_commands,
|
|
477
|
+
command,
|
|
478
|
+
all_flows,
|
|
479
|
+
execution_context,
|
|
480
|
+
domain,
|
|
481
|
+
story_graph,
|
|
469
482
|
)
|
|
470
483
|
|
|
471
484
|
elif isinstance(command, FreeFormAnswerCommand):
|
|
@@ -708,6 +721,7 @@ def clean_up_chitchat_command(
|
|
|
708
721
|
command: ChitChatAnswerCommand,
|
|
709
722
|
flows: FlowsList,
|
|
710
723
|
execution_context: ExecutionContext,
|
|
724
|
+
domain: Domain,
|
|
711
725
|
story_graph: Optional[StoryGraph] = None,
|
|
712
726
|
) -> List[Command]:
|
|
713
727
|
"""Clean up a chitchat answer command.
|
|
@@ -721,6 +735,8 @@ def clean_up_chitchat_command(
|
|
|
721
735
|
flows: All flows.
|
|
722
736
|
execution_context: Information about a single graph run.
|
|
723
737
|
story_graph: StoryGraph object with stories available for training.
|
|
738
|
+
domain: The domain of the bot.
|
|
739
|
+
|
|
724
740
|
Returns:
|
|
725
741
|
The cleaned up commands.
|
|
726
742
|
"""
|
|
@@ -746,10 +762,9 @@ def clean_up_chitchat_command(
|
|
|
746
762
|
)
|
|
747
763
|
defines_intentless_policy = execution_context.has_node(IntentlessPolicy)
|
|
748
764
|
|
|
749
|
-
has_e2e_stories = True if (story_graph and story_graph.has_e2e_stories()) else False
|
|
750
|
-
|
|
751
765
|
if (has_action_trigger_chitchat and not defines_intentless_policy) or (
|
|
752
|
-
defines_intentless_policy
|
|
766
|
+
defines_intentless_policy
|
|
767
|
+
and not contains_intentless_policy_responses(flows, domain, story_graph)
|
|
753
768
|
):
|
|
754
769
|
resulting_commands.insert(
|
|
755
770
|
0, CannotHandleCommand(RASA_PATTERN_CANNOT_HANDLE_CHITCHAT)
|
|
@@ -6,6 +6,7 @@ import rasa.dialogue_understanding.processor.command_processor
|
|
|
6
6
|
from rasa.engine.graph import ExecutionContext, GraphComponent
|
|
7
7
|
from rasa.engine.storage.resource import Resource
|
|
8
8
|
from rasa.engine.storage.storage import ModelStorage
|
|
9
|
+
from rasa.shared.core.domain import Domain
|
|
9
10
|
from rasa.shared.core.events import Event
|
|
10
11
|
from rasa.shared.core.flows import FlowsList
|
|
11
12
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
@@ -15,7 +16,8 @@ from rasa.shared.core.training_data.structures import StoryGraph
|
|
|
15
16
|
class CommandProcessorComponent(GraphComponent):
|
|
16
17
|
"""Processes commands by issuing events to modify a tracker.
|
|
17
18
|
|
|
18
|
-
Minimal component that applies commands to a tracker.
|
|
19
|
+
Minimal component that applies commands to a tracker.
|
|
20
|
+
"""
|
|
19
21
|
|
|
20
22
|
def __init__(self, execution_context: ExecutionContext):
|
|
21
23
|
self._execution_context = execution_context
|
|
@@ -36,8 +38,9 @@ class CommandProcessorComponent(GraphComponent):
|
|
|
36
38
|
tracker: DialogueStateTracker,
|
|
37
39
|
flows: FlowsList,
|
|
38
40
|
story_graph: StoryGraph,
|
|
41
|
+
domain: Domain,
|
|
39
42
|
) -> List[Event]:
|
|
40
43
|
"""Execute commands to update tracker state."""
|
|
41
44
|
return rasa.dialogue_understanding.processor.command_processor.execute_commands(
|
|
42
|
-
tracker, flows, self._execution_context, story_graph
|
|
45
|
+
tracker, flows, self._execution_context, story_graph, domain
|
|
43
46
|
)
|
|
@@ -1,54 +1,21 @@
|
|
|
1
|
+
import typing
|
|
1
2
|
from collections import defaultdict
|
|
2
3
|
from typing import Dict, List
|
|
3
4
|
|
|
4
|
-
from pydantic import BaseModel
|
|
5
|
-
|
|
6
5
|
from rasa.dialogue_understanding.commands import Command
|
|
7
6
|
from rasa.dialogue_understanding_test.command_comparison import (
|
|
8
7
|
is_command_present_in_list,
|
|
9
8
|
)
|
|
10
|
-
from rasa.dialogue_understanding_test.
|
|
11
|
-
DialogueUnderstandingTestResult,
|
|
12
|
-
)
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
class CommandMetrics(BaseModel):
|
|
16
|
-
tp: int
|
|
17
|
-
fp: int
|
|
18
|
-
fn: int
|
|
19
|
-
total_count: int
|
|
20
|
-
|
|
21
|
-
@staticmethod
|
|
22
|
-
def _safe_divide(numerator: float, denominator: float) -> float:
|
|
23
|
-
"""Safely perform division, returning 0.0 if the denominator is zero."""
|
|
24
|
-
return numerator / denominator if denominator > 0 else 0.0
|
|
9
|
+
from rasa.dialogue_understanding_test.command_metrics import CommandMetrics
|
|
25
10
|
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
return self._safe_divide(self.tp, self.tp + self.fn)
|
|
31
|
-
|
|
32
|
-
def get_f1_score(self) -> float:
|
|
33
|
-
precision = self.get_precision()
|
|
34
|
-
recall = self.get_recall()
|
|
35
|
-
|
|
36
|
-
return self._safe_divide(2 * precision * recall, precision + recall)
|
|
37
|
-
|
|
38
|
-
def as_dict(self) -> Dict[str, float]:
|
|
39
|
-
return {
|
|
40
|
-
"tp": self.tp,
|
|
41
|
-
"fp": self.fp,
|
|
42
|
-
"fn": self.fn,
|
|
43
|
-
"precision": self.get_precision(),
|
|
44
|
-
"recall": self.get_recall(),
|
|
45
|
-
"f1_score": self.get_f1_score(),
|
|
46
|
-
"total_count": self.total_count,
|
|
47
|
-
}
|
|
11
|
+
if typing.TYPE_CHECKING:
|
|
12
|
+
from rasa.dialogue_understanding_test.du_test_result import (
|
|
13
|
+
DialogueUnderstandingTestResult,
|
|
14
|
+
)
|
|
48
15
|
|
|
49
16
|
|
|
50
17
|
def calculate_command_metrics(
|
|
51
|
-
test_results: List[DialogueUnderstandingTestResult],
|
|
18
|
+
test_results: List["DialogueUnderstandingTestResult"],
|
|
52
19
|
) -> Dict[str, CommandMetrics]:
|
|
53
20
|
"""Calculate the command metrics for the test result."""
|
|
54
21
|
metrics: Dict[str, CommandMetrics] = defaultdict(
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from typing import Dict
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CommandMetrics(BaseModel):
|
|
7
|
+
tp: int
|
|
8
|
+
fp: int
|
|
9
|
+
fn: int
|
|
10
|
+
total_count: int
|
|
11
|
+
|
|
12
|
+
@staticmethod
|
|
13
|
+
def _safe_divide(numerator: float, denominator: float) -> float:
|
|
14
|
+
"""Safely perform division, returning 0.0 if the denominator is zero."""
|
|
15
|
+
return numerator / denominator if denominator > 0 else 0.0
|
|
16
|
+
|
|
17
|
+
def get_precision(self) -> float:
|
|
18
|
+
return self._safe_divide(self.tp, self.tp + self.fp)
|
|
19
|
+
|
|
20
|
+
def get_recall(self) -> float:
|
|
21
|
+
return self._safe_divide(self.tp, self.tp + self.fn)
|
|
22
|
+
|
|
23
|
+
def get_f1_score(self) -> float:
|
|
24
|
+
precision = self.get_precision()
|
|
25
|
+
recall = self.get_recall()
|
|
26
|
+
|
|
27
|
+
return self._safe_divide(2 * precision * recall, precision + recall)
|
|
28
|
+
|
|
29
|
+
def as_dict(self) -> Dict[str, float]:
|
|
30
|
+
return {
|
|
31
|
+
"tp": self.tp,
|
|
32
|
+
"fp": self.fp,
|
|
33
|
+
"fn": self.fn,
|
|
34
|
+
"precision": self.get_precision(),
|
|
35
|
+
"recall": self.get_recall(),
|
|
36
|
+
"f1_score": self.get_f1_score(),
|
|
37
|
+
"total_count": self.total_count,
|
|
38
|
+
}
|
|
@@ -1,7 +1,11 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
1
2
|
from typing import Any, Dict, Iterator, List, Optional, Tuple
|
|
2
3
|
|
|
3
4
|
from pydantic import BaseModel, Field
|
|
4
5
|
|
|
6
|
+
from rasa.core import IntentlessPolicy
|
|
7
|
+
from rasa.core.nlg.contextual_response_rephraser import ContextualResponseRephraser
|
|
8
|
+
from rasa.core.policies.enterprise_search_policy import EnterpriseSearchPolicy
|
|
5
9
|
from rasa.dialogue_understanding.commands.prompt_command import PromptCommand
|
|
6
10
|
from rasa.dialogue_understanding.generator.command_parser import parse_commands
|
|
7
11
|
from rasa.dialogue_understanding_test.command_comparison import are_command_lists_equal
|
|
@@ -69,6 +73,8 @@ class DialogueUnderstandingOutput(BaseModel):
|
|
|
69
73
|
commands: Dict[str, List[PromptCommand]]
|
|
70
74
|
# List of prompts
|
|
71
75
|
prompts: Optional[List[Dict[str, Any]]] = None
|
|
76
|
+
# Latency of the full message roundtrip
|
|
77
|
+
latency: Optional[float] = None
|
|
72
78
|
|
|
73
79
|
class Config:
|
|
74
80
|
"""Skip validation for PromptCommand protocol as pydantic does not know how to
|
|
@@ -88,27 +94,41 @@ class DialogueUnderstandingOutput(BaseModel):
|
|
|
88
94
|
def get_component_names_that_predicted_commands_or_have_llm_response(
|
|
89
95
|
self,
|
|
90
96
|
) -> List[str]:
|
|
91
|
-
"""Get all component names
|
|
97
|
+
"""Get all relevant component names.
|
|
98
|
+
|
|
99
|
+
Components are relevant if they have predicted commands or received a
|
|
92
100
|
non-empty response from LLM.
|
|
93
101
|
"""
|
|
102
|
+
# Exclude components that are not related to Dialogue Understanding
|
|
103
|
+
component_names_to_exclude = [
|
|
104
|
+
EnterpriseSearchPolicy.__name__,
|
|
105
|
+
IntentlessPolicy.__name__,
|
|
106
|
+
ContextualResponseRephraser.__name__,
|
|
107
|
+
]
|
|
108
|
+
|
|
94
109
|
component_names_that_predicted_commands = (
|
|
95
110
|
[
|
|
96
111
|
component_name
|
|
97
112
|
for component_name, predicted_commands in self.commands.items()
|
|
98
113
|
if predicted_commands
|
|
114
|
+
and component_name not in component_names_to_exclude
|
|
99
115
|
]
|
|
100
116
|
if self.commands
|
|
101
117
|
else []
|
|
102
118
|
)
|
|
119
|
+
|
|
103
120
|
components_with_prompts = (
|
|
104
121
|
[
|
|
105
122
|
str(prompt.get(KEY_COMPONENT_NAME, None))
|
|
106
123
|
for prompt in self.prompts
|
|
107
124
|
if prompt.get(KEY_LLM_RESPONSE_METADATA, None)
|
|
125
|
+
and prompt.get(KEY_COMPONENT_NAME, None)
|
|
126
|
+
not in component_names_to_exclude
|
|
108
127
|
]
|
|
109
128
|
if self.prompts
|
|
110
129
|
else []
|
|
111
130
|
)
|
|
131
|
+
|
|
112
132
|
return list(
|
|
113
133
|
set(component_names_that_predicted_commands + components_with_prompts)
|
|
114
134
|
)
|
|
@@ -290,41 +310,54 @@ class DialogueUnderstandingTestStep(BaseModel):
|
|
|
290
310
|
|
|
291
311
|
return ""
|
|
292
312
|
|
|
293
|
-
def get_latencies(self) -> List[float]:
|
|
313
|
+
def get_latencies(self) -> Dict[str, List[float]]:
|
|
294
314
|
if self.dialogue_understanding_output is None:
|
|
295
|
-
return
|
|
315
|
+
return {}
|
|
296
316
|
|
|
297
|
-
|
|
317
|
+
component_name_to_prompt_info = (
|
|
318
|
+
self.dialogue_understanding_output.get_component_name_to_prompt_info()
|
|
319
|
+
)
|
|
298
320
|
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
for
|
|
302
|
-
|
|
303
|
-
]
|
|
321
|
+
latencies = defaultdict(list)
|
|
322
|
+
for component_name, prompt_info_list in component_name_to_prompt_info.items():
|
|
323
|
+
for prompt_info in prompt_info_list:
|
|
324
|
+
latencies[component_name].append(prompt_info.get(KEY_LATENCY, 0.0))
|
|
304
325
|
|
|
305
|
-
|
|
326
|
+
return latencies
|
|
327
|
+
|
|
328
|
+
def get_completion_tokens(self) -> Dict[str, List[float]]:
|
|
306
329
|
if self.dialogue_understanding_output is None:
|
|
307
|
-
return
|
|
330
|
+
return {}
|
|
308
331
|
|
|
309
|
-
|
|
332
|
+
component_name_to_prompt_info = (
|
|
333
|
+
self.dialogue_understanding_output.get_component_name_to_prompt_info()
|
|
334
|
+
)
|
|
310
335
|
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
for
|
|
314
|
-
|
|
315
|
-
|
|
336
|
+
completion_tokens = defaultdict(list)
|
|
337
|
+
for component_name, prompt_info_list in component_name_to_prompt_info.items():
|
|
338
|
+
for prompt_info in prompt_info_list:
|
|
339
|
+
completion_tokens[component_name].append(
|
|
340
|
+
prompt_info.get(KEY_COMPLETION_TOKENS, 0.0)
|
|
341
|
+
)
|
|
342
|
+
|
|
343
|
+
return completion_tokens
|
|
316
344
|
|
|
317
|
-
def get_prompt_tokens(self) -> List[
|
|
345
|
+
def get_prompt_tokens(self) -> Dict[str, List[float]]:
|
|
318
346
|
if self.dialogue_understanding_output is None:
|
|
319
|
-
return
|
|
347
|
+
return {}
|
|
320
348
|
|
|
321
|
-
|
|
349
|
+
component_name_to_prompt_info = (
|
|
350
|
+
self.dialogue_understanding_output.get_component_name_to_prompt_info()
|
|
351
|
+
)
|
|
322
352
|
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
for
|
|
326
|
-
|
|
327
|
-
|
|
353
|
+
prompt_tokens = defaultdict(list)
|
|
354
|
+
for component_name, prompt_info_list in component_name_to_prompt_info.items():
|
|
355
|
+
for prompt_info in prompt_info_list:
|
|
356
|
+
prompt_tokens[component_name].append(
|
|
357
|
+
prompt_info.get(KEY_PROMPT_TOKENS, 0.0)
|
|
358
|
+
)
|
|
359
|
+
|
|
360
|
+
return prompt_tokens
|
|
328
361
|
|
|
329
362
|
|
|
330
363
|
class DialogueUnderstandingTestCase(BaseModel):
|