rasa-pro 3.9.18__py3-none-any.whl → 3.10.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (183) hide show
  1. README.md +0 -374
  2. rasa/__init__.py +1 -2
  3. rasa/__main__.py +5 -0
  4. rasa/anonymization/anonymization_rule_executor.py +2 -2
  5. rasa/api.py +27 -23
  6. rasa/cli/arguments/data.py +27 -2
  7. rasa/cli/arguments/default_arguments.py +25 -3
  8. rasa/cli/arguments/run.py +9 -9
  9. rasa/cli/arguments/train.py +11 -3
  10. rasa/cli/data.py +70 -8
  11. rasa/cli/e2e_test.py +104 -431
  12. rasa/cli/evaluate.py +1 -1
  13. rasa/cli/interactive.py +1 -0
  14. rasa/cli/llm_fine_tuning.py +398 -0
  15. rasa/cli/project_templates/calm/endpoints.yml +1 -1
  16. rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
  17. rasa/cli/run.py +15 -14
  18. rasa/cli/scaffold.py +10 -8
  19. rasa/cli/studio/studio.py +35 -5
  20. rasa/cli/train.py +56 -8
  21. rasa/cli/utils.py +22 -5
  22. rasa/cli/x.py +1 -1
  23. rasa/constants.py +7 -1
  24. rasa/core/actions/action.py +98 -49
  25. rasa/core/actions/action_run_slot_rejections.py +4 -1
  26. rasa/core/actions/custom_action_executor.py +9 -6
  27. rasa/core/actions/direct_custom_actions_executor.py +80 -0
  28. rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
  29. rasa/core/actions/grpc_custom_action_executor.py +2 -2
  30. rasa/core/actions/http_custom_action_executor.py +6 -5
  31. rasa/core/agent.py +21 -17
  32. rasa/core/channels/__init__.py +2 -0
  33. rasa/core/channels/audiocodes.py +1 -16
  34. rasa/core/channels/voice_aware/__init__.py +0 -0
  35. rasa/core/channels/voice_aware/jambonz.py +103 -0
  36. rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
  37. rasa/core/channels/voice_aware/utils.py +20 -0
  38. rasa/core/channels/voice_native/__init__.py +0 -0
  39. rasa/core/constants.py +6 -1
  40. rasa/core/information_retrieval/faiss.py +7 -4
  41. rasa/core/information_retrieval/information_retrieval.py +8 -0
  42. rasa/core/information_retrieval/milvus.py +9 -2
  43. rasa/core/information_retrieval/qdrant.py +1 -1
  44. rasa/core/nlg/contextual_response_rephraser.py +32 -10
  45. rasa/core/nlg/summarize.py +4 -3
  46. rasa/core/policies/enterprise_search_policy.py +113 -45
  47. rasa/core/policies/flows/flow_executor.py +122 -76
  48. rasa/core/policies/intentless_policy.py +83 -29
  49. rasa/core/processor.py +72 -54
  50. rasa/core/run.py +5 -4
  51. rasa/core/tracker_store.py +8 -4
  52. rasa/core/training/interactive.py +1 -1
  53. rasa/core/utils.py +56 -57
  54. rasa/dialogue_understanding/coexistence/llm_based_router.py +53 -13
  55. rasa/dialogue_understanding/commands/__init__.py +6 -0
  56. rasa/dialogue_understanding/commands/restart_command.py +58 -0
  57. rasa/dialogue_understanding/commands/session_start_command.py +59 -0
  58. rasa/dialogue_understanding/commands/utils.py +40 -0
  59. rasa/dialogue_understanding/generator/constants.py +10 -3
  60. rasa/dialogue_understanding/generator/flow_retrieval.py +21 -5
  61. rasa/dialogue_understanding/generator/llm_based_command_generator.py +13 -3
  62. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +134 -90
  63. rasa/dialogue_understanding/generator/nlu_command_adapter.py +47 -7
  64. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +127 -41
  65. rasa/dialogue_understanding/patterns/restart.py +37 -0
  66. rasa/dialogue_understanding/patterns/session_start.py +37 -0
  67. rasa/dialogue_understanding/processor/command_processor.py +16 -3
  68. rasa/dialogue_understanding/processor/command_processor_component.py +6 -2
  69. rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
  70. rasa/e2e_test/assertions.py +1223 -0
  71. rasa/e2e_test/assertions_schema.yml +106 -0
  72. rasa/e2e_test/constants.py +20 -0
  73. rasa/e2e_test/e2e_config.py +220 -0
  74. rasa/e2e_test/e2e_config_schema.yml +26 -0
  75. rasa/e2e_test/e2e_test_case.py +131 -8
  76. rasa/e2e_test/e2e_test_converter.py +363 -0
  77. rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
  78. rasa/e2e_test/e2e_test_coverage_report.py +364 -0
  79. rasa/e2e_test/e2e_test_result.py +26 -6
  80. rasa/e2e_test/e2e_test_runner.py +493 -71
  81. rasa/e2e_test/e2e_test_schema.yml +96 -0
  82. rasa/e2e_test/pykwalify_extensions.py +39 -0
  83. rasa/e2e_test/stub_custom_action.py +70 -0
  84. rasa/e2e_test/utils/__init__.py +0 -0
  85. rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
  86. rasa/e2e_test/utils/io.py +598 -0
  87. rasa/e2e_test/utils/validation.py +80 -0
  88. rasa/engine/graph.py +9 -3
  89. rasa/engine/recipes/default_components.py +0 -2
  90. rasa/engine/recipes/default_recipe.py +10 -2
  91. rasa/engine/storage/local_model_storage.py +40 -12
  92. rasa/engine/validation.py +78 -1
  93. rasa/env.py +9 -0
  94. rasa/graph_components/providers/story_graph_provider.py +59 -6
  95. rasa/llm_fine_tuning/__init__.py +0 -0
  96. rasa/llm_fine_tuning/annotation_module.py +241 -0
  97. rasa/llm_fine_tuning/conversations.py +144 -0
  98. rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
  99. rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
  100. rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
  101. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
  102. rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
  103. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
  104. rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
  105. rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
  106. rasa/llm_fine_tuning/storage.py +174 -0
  107. rasa/llm_fine_tuning/train_test_split_module.py +441 -0
  108. rasa/model_training.py +56 -16
  109. rasa/nlu/persistor.py +157 -36
  110. rasa/server.py +45 -10
  111. rasa/shared/constants.py +76 -16
  112. rasa/shared/core/domain.py +27 -19
  113. rasa/shared/core/events.py +28 -2
  114. rasa/shared/core/flows/flow.py +208 -13
  115. rasa/shared/core/flows/flow_path.py +84 -0
  116. rasa/shared/core/flows/flows_list.py +33 -11
  117. rasa/shared/core/flows/flows_yaml_schema.json +269 -193
  118. rasa/shared/core/flows/validation.py +112 -25
  119. rasa/shared/core/flows/yaml_flows_io.py +149 -10
  120. rasa/shared/core/trackers.py +6 -0
  121. rasa/shared/core/training_data/structures.py +20 -0
  122. rasa/shared/core/training_data/visualization.html +2 -2
  123. rasa/shared/exceptions.py +4 -0
  124. rasa/shared/importers/importer.py +64 -16
  125. rasa/shared/nlu/constants.py +2 -0
  126. rasa/shared/providers/_configs/__init__.py +0 -0
  127. rasa/shared/providers/_configs/azure_openai_client_config.py +183 -0
  128. rasa/shared/providers/_configs/client_config.py +57 -0
  129. rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
  130. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
  131. rasa/shared/providers/_configs/openai_client_config.py +175 -0
  132. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +176 -0
  133. rasa/shared/providers/_configs/utils.py +101 -0
  134. rasa/shared/providers/_ssl_verification_utils.py +124 -0
  135. rasa/shared/providers/embedding/__init__.py +0 -0
  136. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +259 -0
  137. rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
  138. rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
  139. rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
  140. rasa/shared/providers/embedding/embedding_client.py +90 -0
  141. rasa/shared/providers/embedding/embedding_response.py +41 -0
  142. rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
  143. rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
  144. rasa/shared/providers/llm/__init__.py +0 -0
  145. rasa/shared/providers/llm/_base_litellm_client.py +251 -0
  146. rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
  147. rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
  148. rasa/shared/providers/llm/llm_client.py +76 -0
  149. rasa/shared/providers/llm/llm_response.py +50 -0
  150. rasa/shared/providers/llm/openai_llm_client.py +155 -0
  151. rasa/shared/providers/llm/self_hosted_llm_client.py +293 -0
  152. rasa/shared/providers/mappings.py +75 -0
  153. rasa/shared/utils/cli.py +30 -0
  154. rasa/shared/utils/io.py +65 -2
  155. rasa/shared/utils/llm.py +246 -200
  156. rasa/shared/utils/yaml.py +121 -15
  157. rasa/studio/auth.py +6 -4
  158. rasa/studio/config.py +13 -4
  159. rasa/studio/constants.py +1 -0
  160. rasa/studio/data_handler.py +10 -3
  161. rasa/studio/download.py +19 -13
  162. rasa/studio/train.py +2 -3
  163. rasa/studio/upload.py +19 -11
  164. rasa/telemetry.py +113 -58
  165. rasa/tracing/instrumentation/attribute_extractors.py +32 -17
  166. rasa/utils/common.py +18 -19
  167. rasa/utils/endpoints.py +7 -4
  168. rasa/utils/json_utils.py +60 -0
  169. rasa/utils/licensing.py +9 -1
  170. rasa/utils/ml_utils.py +4 -2
  171. rasa/validator.py +213 -3
  172. rasa/version.py +1 -1
  173. rasa_pro-3.10.16.dist-info/METADATA +196 -0
  174. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/RECORD +179 -113
  175. rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
  176. rasa/shared/providers/openai/clients.py +0 -43
  177. rasa/shared/providers/openai/session_handler.py +0 -110
  178. rasa_pro-3.9.18.dist-info/METADATA +0 -563
  179. /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
  180. /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
  181. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/NOTICE +0 -0
  182. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/WHEEL +0 -0
  183. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,1223 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import json
5
+ import re
6
+ from dataclasses import dataclass
7
+ from enum import Enum
8
+ from functools import lru_cache
9
+ from typing import (
10
+ Any,
11
+ Callable,
12
+ Dict,
13
+ List,
14
+ Optional,
15
+ Set,
16
+ TYPE_CHECKING,
17
+ Text,
18
+ Tuple,
19
+ Type,
20
+ )
21
+
22
+ import pandas as pd
23
+ import structlog
24
+
25
+ import rasa.shared.utils.common
26
+ from rasa.core.constants import (
27
+ DOMAIN_GROUND_TRUTH_METADATA_KEY,
28
+ UTTER_SOURCE_METADATA_KEY,
29
+ )
30
+ from rasa.core.policies.enterprise_search_policy import (
31
+ SEARCH_QUERY_METADATA_KEY,
32
+ SEARCH_RESULTS_METADATA_KEY,
33
+ )
34
+ from rasa.dialogue_understanding.patterns.clarify import FLOW_PATTERN_CLARIFICATION
35
+ from rasa.shared.core.constants import DEFAULT_SLOT_NAMES
36
+ from rasa.shared.core.events import (
37
+ ActionExecuted,
38
+ BotUttered,
39
+ DefinePrevUserUtteredFeaturization,
40
+ DialogueStackUpdated,
41
+ Event,
42
+ FlowCancelled,
43
+ FlowCompleted,
44
+ FlowStarted,
45
+ SlotSet,
46
+ )
47
+ from rasa.shared.exceptions import RasaException
48
+ from rasa.utils.common import update_mlflow_log_level
49
+ from rasa.utils.json_utils import SetEncoder
50
+
51
+ if TYPE_CHECKING:
52
+ from rasa.e2e_test.e2e_config import LLMJudgeConfig
53
+
54
+
55
# Module-level structlog logger used for this module's log events.
structlogger = structlog.get_logger()

# NOTE(review): presumably the default pass/fail threshold used by the
# generative response assertions; no use is visible in this part of the
# file - confirm against the code that reads it.
DEFAULT_THRESHOLD = 0.5
# NOTE(review): looks like the names of components whose bot utterances are
# eligible for the generative assertions (see `UTTER_SOURCE_METADATA_KEY`);
# no use is visible in this part of the file - confirm.
ELIGIBLE_UTTER_SOURCE_METADATA = [
    "EnterpriseSearchPolicy",
    "ContextualResponseRephraser",
    "IntentlessPolicy",
]
63
+
64
+
65
class AssertionType(Enum):
    """Closed set of assertion kinds; each value is the assertion's string key."""

    # Flow lifecycle assertions.
    FLOW_STARTED = "flow_started"
    FLOW_COMPLETED = "flow_completed"
    FLOW_CANCELLED = "flow_cancelled"
    PATTERN_CLARIFICATION_CONTAINS = "pattern_clarification_contains"

    # Action and slot assertions.
    ACTION_EXECUTED = "action_executed"
    SLOT_WAS_SET = "slot_was_set"
    SLOT_WAS_NOT_SET = "slot_was_not_set"

    # Bot response assertions.
    BOT_UTTERED = "bot_uttered"
    GENERATIVE_RESPONSE_IS_RELEVANT = "generative_response_is_relevant"
    GENERATIVE_RESPONSE_IS_GROUNDED = "generative_response_is_grounded"
76
+
77
+
78
@lru_cache(maxsize=1)
def _get_all_assertion_subclasses() -> Dict[str, Type[Assertion]]:
    """Map each assertion type string to the `Assertion` subclass providing it.

    The result is cached after the first call, so the subclass scan runs once;
    later calls return the same dictionary object.

    Returns:
        A dictionary keyed by `sub_class.type()` for every class returned by
        `rasa.shared.utils.common.all_subclasses(Assertion)`.
    """
    registry: Dict[str, Type[Assertion]] = {}
    for assertion_class in rasa.shared.utils.common.all_subclasses(Assertion):
        registry[assertion_class.type()] = assertion_class
    return registry
84
+
85
+
86
class InvalidAssertionType(RasaException):
    """Error signalling that an assertion type is not valid."""

    def __init__(self, assertion_type: str) -> None:
        """Create an `InvalidAssertionType`.

        Args:
            assertion_type: The invalid assertion type; it is embedded in the
                exception message.
        """
        message = f"Invalid assertion type '{assertion_type}'."
        super().__init__(message)
96
+
97
+
98
@dataclass
class Assertion:
    """Common base class of all assertions.

    Subclasses override `type`, `from_dict` and `run`; the versions defined
    here only raise `NotImplementedError`.
    """

    @classmethod
    def type(cls) -> str:
        """Return the type identifier of the assertion."""
        raise NotImplementedError

    @staticmethod
    def from_dict(assertion_dict: Dict[Text, Any]) -> Assertion:
        """Build an assertion from its dictionary form."""
        raise NotImplementedError

    def as_dict(self) -> Dict[str, Any]:
        """Serialize the `Assertion`.

        Returns:
            The dataclass fields as a dictionary, plus a `type` entry holding
            the value returned by `type()`.
        """
        serialized = dataclasses.asdict(self)
        serialized["type"] = self.type()
        return serialized

    @staticmethod
    def create_typed_assertion(data: Dict[str, Any]) -> Assertion:
        """Instantiate the concrete `Assertion` described by a dictionary.

        The first key of `data` is taken as the assertion type and looked up
        among the known `Assertion` subclasses.

        Args:
            data: The dictionary to create the `Assertion` from.

        Returns:
            The created `Assertion`.

        Raises:
            InvalidAssertionType: If no subclass matches the type, or the
                matched subclass does not implement `from_dict`.
        """
        assertion_type = next(iter(data.keys()))
        assertion_class = _get_all_assertion_subclasses().get(assertion_type)

        if assertion_class is not None:
            try:
                return assertion_class.from_dict(data)
            except NotImplementedError:
                structlogger.warning("assertion.unknown_type", data=data)
                raise InvalidAssertionType(assertion_type)

        structlogger.warning("assertion.unknown_type", data=data)
        raise InvalidAssertionType(assertion_type)

    def run(
        self,
        turn_events: List[Event],
        prior_events: List[Event],
        assertion_order_error_message: str = "",
        **kwargs: Any,
    ) -> Tuple[Optional[AssertionFailure], Optional[Event]]:
        """Evaluate the assertion against the events of one user turn.

        Args:
            turn_events: The events to run the assertion on.
            prior_events: All events prior to the current turn.
            assertion_order_error_message: Text appended to the error message
                if the assertion order is enabled.
            kwargs: Additional keyword arguments.

        Returns:
            A pair `(failure, matching_event)`. The subclasses in this module
            return `(None, event)` when the assertion passes and
            `(failure, None)` when it fails.
        """
        raise NotImplementedError

    def _generate_assertion_failure(
        self,
        error_message: str,
        prior_events: List[Event],
        turn_events: List[Event],
        line: Optional[int] = None,
    ) -> Tuple[AssertionFailure, None]:
        """Build the failure half of a `run` result.

        Args:
            error_message: The message describing why the assertion failed.
            prior_events: Events before the current turn, used for the
                transcript.
            turn_events: Events of the current turn, used for the transcript.
            line: Line number to report with the failure, if known.

        Returns:
            The `AssertionFailure` paired with `None` (no matching event).
        """
        failure = AssertionFailure(
            assertion=self,
            error_message=error_message,
            actual_events_transcript=create_actual_events_transcript(
                prior_events, turn_events
            ),
            error_line=line,
        )
        return failure, None
185
+
186
+
187
@dataclass
class FlowStartedAssertion(Assertion):
    """Assertion that a flow with a given id started during the user turn."""

    # id of the flow expected to start
    flow_id: str
    # line number reported on failure; `from_dict` derives it as `lc.line + 1`
    line: Optional[int] = None

    @classmethod
    def type(cls) -> str:
        """Return the assertion type key (`flow_started`)."""
        return AssertionType.FLOW_STARTED.value

    @staticmethod
    def from_dict(assertion_dict: Dict[Text, Any]) -> FlowStartedAssertion:
        """Build the assertion from its dictionary form.

        The flow id is read from the `flow_started` key. If the mapping has an
        `lc` attribute, `lc.line + 1` is stored as the line number
        (presumably YAML line info attached by the parser - verify).
        """
        expected_flow_id = assertion_dict.get(AssertionType.FLOW_STARTED.value)
        line_number = None
        if hasattr(assertion_dict, "lc"):
            line_number = assertion_dict.lc.line + 1

        return FlowStartedAssertion(flow_id=expected_flow_id, line=line_number)

    def run(
        self,
        turn_events: List[Event],
        prior_events: List[Event],
        assertion_order_error_message: str = "",
        **kwargs: Any,
    ) -> Tuple[Optional[AssertionFailure], Optional[Event]]:
        """Check that the turn contains a `FlowStarted` event for `flow_id`.

        Returns:
            `(None, event)` with the first matching event on success,
            otherwise `(failure, None)`.
        """
        started_event = next(
            (
                candidate
                for candidate in turn_events
                if isinstance(candidate, FlowStarted)
                and candidate.flow_id == self.flow_id
            ),
            None,
        )
        if started_event is None:
            failure_text = (
                f"Flow with id '{self.flow_id}' did not start."
                + assertion_order_error_message
            )
            return self._generate_assertion_failure(
                failure_text, prior_events, turn_events, self.line
            )

        return None, started_event

    def __hash__(self) -> int:
        """Hash the JSON dump of `as_dict()`."""
        serialized = json.dumps(self.as_dict())
        return hash(serialized)
231
+
232
+
233
@dataclass
class FlowCompletedAssertion(Assertion):
    """Assertion that a flow completed, optionally at a specific step."""

    # id of the flow expected to complete
    flow_id: str
    # when set, the step id the matching event must carry
    flow_step_id: Optional[str] = None
    # line number reported on failure; `from_dict` derives it as `lc.line + 1`
    line: Optional[int] = None

    @classmethod
    def type(cls) -> str:
        """Return the assertion type key (`flow_completed`)."""
        return AssertionType.FLOW_COMPLETED.value

    @staticmethod
    def from_dict(assertion_dict: Dict[Text, Any]) -> FlowCompletedAssertion:
        """Build the assertion from its dictionary form.

        `flow_id` and `flow_step_id` are read from the mapping stored under
        the `flow_completed` key (an empty mapping if the key is absent). The
        line number is taken from the outer mapping's `lc` attribute when
        present.
        """
        line_number = None
        if hasattr(assertion_dict, "lc"):
            line_number = assertion_dict.lc.line + 1

        properties = assertion_dict.get(AssertionType.FLOW_COMPLETED.value, {})
        return FlowCompletedAssertion(
            flow_id=properties.get("flow_id"),
            flow_step_id=properties.get("flow_step_id"),
            line=line_number,
        )

    def run(
        self,
        turn_events: List[Event],
        prior_events: List[Event],
        assertion_order_error_message: str = "",
        **kwargs: Any,
    ) -> Tuple[Optional[AssertionFailure], Optional[Event]]:
        """Check that the turn contains a `FlowCompleted` event for `flow_id`.

        If `flow_step_id` is set, the first matching event must also carry
        that step id.

        Returns:
            `(None, event)` on success, otherwise `(failure, None)`.
        """
        completed_event = next(
            (
                candidate
                for candidate in turn_events
                if isinstance(candidate, FlowCompleted)
                and candidate.flow_id == self.flow_id
            ),
            None,
        )
        if completed_event is None:
            failure_text = (
                f"Flow with id '{self.flow_id}' did not complete."
                + assertion_order_error_message
            )
            return self._generate_assertion_failure(
                failure_text, prior_events, turn_events, self.line
            )

        step_is_asserted = self.flow_step_id is not None
        if step_is_asserted and completed_event.step_id != self.flow_step_id:
            failure_text = (
                f"Flow with id '{self.flow_id}' did not complete "
                f"at expected step id '{self.flow_step_id}'. The actual "
                f"step id was '{completed_event.step_id}'."
            ) + assertion_order_error_message
            return self._generate_assertion_failure(
                failure_text, prior_events, turn_events, self.line
            )

        return None, completed_event

    def __hash__(self) -> int:
        """Hash the JSON dump of `as_dict()`."""
        serialized = json.dumps(self.as_dict())
        return hash(serialized)
296
+
297
+
298
@dataclass
class FlowCancelledAssertion(Assertion):
    """Assertion that a flow was cancelled, optionally at a specific step."""

    # id of the flow expected to be cancelled
    flow_id: str
    # when set, the step id the matching event must carry
    flow_step_id: Optional[str] = None
    # line number reported on failure; `from_dict` derives it as `lc.line + 1`
    line: Optional[int] = None

    @classmethod
    def type(cls) -> str:
        """Return the assertion type key (`flow_cancelled`)."""
        return AssertionType.FLOW_CANCELLED.value

    @staticmethod
    def from_dict(assertion_dict: Dict[Text, Any]) -> FlowCancelledAssertion:
        """Build the assertion from its dictionary form.

        `flow_id` and `flow_step_id` are read from the mapping stored under
        the `flow_cancelled` key (an empty mapping if the key is absent). The
        line number is taken from the outer mapping's `lc` attribute when
        present.
        """
        line_number = None
        if hasattr(assertion_dict, "lc"):
            line_number = assertion_dict.lc.line + 1

        properties = assertion_dict.get(AssertionType.FLOW_CANCELLED.value, {})
        return FlowCancelledAssertion(
            flow_id=properties.get("flow_id"),
            flow_step_id=properties.get("flow_step_id"),
            line=line_number,
        )

    def run(
        self,
        turn_events: List[Event],
        prior_events: List[Event],
        assertion_order_error_message: str = "",
        **kwargs: Any,
    ) -> Tuple[Optional[AssertionFailure], Optional[Event]]:
        """Check that the turn contains a `FlowCancelled` event for `flow_id`.

        If `flow_step_id` is set, the first matching event must also carry
        that step id.

        Returns:
            `(None, event)` on success, otherwise `(failure, None)`.
        """
        cancelled_event = next(
            (
                candidate
                for candidate in turn_events
                if isinstance(candidate, FlowCancelled)
                and candidate.flow_id == self.flow_id
            ),
            None,
        )
        if cancelled_event is None:
            failure_text = (
                f"Flow with id '{self.flow_id}' was not cancelled."
                + assertion_order_error_message
            )
            return self._generate_assertion_failure(
                failure_text, prior_events, turn_events, self.line
            )

        step_is_asserted = self.flow_step_id is not None
        if step_is_asserted and cancelled_event.step_id != self.flow_step_id:
            failure_text = (
                f"Flow with id '{self.flow_id}' was not cancelled "
                f"at expected step id '{self.flow_step_id}'. The actual "
                f"step id was '{cancelled_event.step_id}'."
            ) + assertion_order_error_message
            return self._generate_assertion_failure(
                failure_text, prior_events, turn_events, self.line
            )

        return None, cancelled_event

    def __hash__(self) -> int:
        """Hash the JSON dump of `as_dict()`."""
        serialized = json.dumps(self.as_dict())
        return hash(serialized)
362
+
363
+
364
@dataclass
class PatternClarificationContainsAssertion(Assertion):
    """Assertion on the flow names offered by the clarification pattern."""

    # flow names the clarification pattern is expected to offer
    flow_names: Set[str]
    # line number reported on failure; `from_dict` derives it as `lc.line + 1`
    line: Optional[int] = None

    @classmethod
    def type(cls) -> str:
        """Return the assertion type key (`pattern_clarification_contains`)."""
        return AssertionType.PATTERN_CLARIFICATION_CONTAINS.value

    @staticmethod
    def from_dict(
        assertion_dict: Dict[Text, Any],
    ) -> PatternClarificationContainsAssertion:
        """Build the assertion from its dictionary form.

        The names listed under the `pattern_clarification_contains` key
        (an empty list if the key is absent) are stored as a set.
        """
        expected_names = set(
            assertion_dict.get(AssertionType.PATTERN_CLARIFICATION_CONTAINS.value, [])
        )
        line_number = None
        if hasattr(assertion_dict, "lc"):
            line_number = assertion_dict.lc.line + 1

        return PatternClarificationContainsAssertion(
            flow_names=expected_names,
            line=line_number,
        )

    def run(
        self,
        turn_events: List[Event],
        prior_events: List[Event],
        assertion_order_error_message: str = "",
        **kwargs: Any,
    ) -> Tuple[Optional[AssertionFailure], Optional[Event]]:
        """Run the pattern clarification assertion on the events of the turn.

        The turn must contain a `FlowStarted` event for the clarification
        pattern, and the `names` entry of that event's metadata must equal
        `flow_names` exactly (set equality).

        Returns:
            `(None, event)` on success, otherwise `(failure, None)`.
        """
        clarification_event = next(
            (
                candidate
                for candidate in turn_events
                if isinstance(candidate, FlowStarted)
                and candidate.flow_id == FLOW_PATTERN_CLARIFICATION
            ),
            None,
        )
        if clarification_event is None:
            failure_text = (
                f"'{FLOW_PATTERN_CLARIFICATION}' pattern did not trigger."
                + assertion_order_error_message
            )
            return self._generate_assertion_failure(
                failure_text, prior_events, turn_events, self.line
            )

        offered_names = set(clarification_event.metadata.get("names", set()))
        if offered_names == self.flow_names:
            return None, clarification_event

        failure_text = (
            f"'{FLOW_PATTERN_CLARIFICATION}' pattern did not contain "
            f"the expected options. Expected options: {self.flow_names}. "
        ) + assertion_order_error_message
        return self._generate_assertion_failure(
            failure_text, prior_events, turn_events, self.line
        )

    def __hash__(self) -> int:
        """Hash the JSON dump of `as_dict()`, encoded with `SetEncoder`."""
        # `flow_names` is a set, which the default JSON encoder rejects.
        serialized = json.dumps(self.as_dict(), cls=SetEncoder)
        return hash(serialized)
427
+
428
+
429
@dataclass
class ActionExecutedAssertion(Assertion):
    """Assertion that an action with a given name was executed."""

    # name of the action expected to have been executed
    action_name: str
    # line number reported on failure; `from_dict` derives it as `lc.line + 1`
    line: Optional[int] = None

    @classmethod
    def type(cls) -> str:
        """Return the assertion type key (`action_executed`)."""
        return AssertionType.ACTION_EXECUTED.value

    @staticmethod
    def from_dict(assertion_dict: Dict[Text, Any]) -> ActionExecutedAssertion:
        """Build the assertion from its dictionary form.

        The action name is read from the `action_executed` key. If the mapping
        has an `lc` attribute, `lc.line + 1` is stored as the line number.
        """
        expected_action = assertion_dict.get(AssertionType.ACTION_EXECUTED.value)
        line_number = None
        if hasattr(assertion_dict, "lc"):
            line_number = assertion_dict.lc.line + 1

        return ActionExecutedAssertion(action_name=expected_action, line=line_number)

    def run(
        self,
        turn_events: List[Event],
        prior_events: List[Event],
        assertion_order_error_message: str = "",
        **kwargs: Any,
    ) -> Tuple[Optional[AssertionFailure], Optional[Event]]:
        """Check that an `ActionExecuted` event for `action_name` occurred.

        The optional `step_index` keyword argument is forwarded to
        `_get_turn_events_based_on_step_index`, which yields the events used
        for the failure transcript and the events that are searched.

        Returns:
            `(None, event)` with the first matching event on success,
            otherwise `(failure, None)`.
        """
        events_for_report, events_to_search = _get_turn_events_based_on_step_index(
            kwargs.get("step_index"), turn_events, prior_events
        )

        executed_event = next(
            (
                candidate
                for candidate in events_to_search
                if isinstance(candidate, ActionExecuted)
                and candidate.action_name == self.action_name
            ),
            None,
        )
        if executed_event is None:
            failure_text = (
                f"Action '{self.action_name}' did not execute."
                + assertion_order_error_message
            )
            return self._generate_assertion_failure(
                failure_text, prior_events, events_for_report, self.line
            )

        return None, executed_event

    def __hash__(self) -> int:
        """Hash the JSON dump of `as_dict()`."""
        serialized = json.dumps(self.as_dict())
        return hash(serialized)
479
+
480
+
481
@dataclass
class AssertedSlot:
    """Name, expected value and source line of a slot referenced by an assertion."""

    # slot name
    name: str
    # expected value; the string "value key is undefined" marks "no value given"
    value: Any
    # line number, derived in `from_dict` as `lc.line + 1` when available
    line: Optional[int] = None

    @staticmethod
    def from_dict(slot_dict: Dict[Text, Any]) -> AssertedSlot:
        """Build an `AssertedSlot` from its dictionary form.

        A missing `value` key is recorded as the placeholder string
        "value key is undefined", which keeps it distinguishable from an
        explicit `None` value.
        """
        line_number = None
        if hasattr(slot_dict, "lc"):
            line_number = slot_dict.lc.line + 1

        slot_name = slot_dict.get("name")
        slot_value = slot_dict.get("value", "value key is undefined")
        return AssertedSlot(name=slot_name, value=slot_value, line=line_number)
496
+
497
+
498
@dataclass
class SlotWasSetAssertion(Assertion):
    """Assertion that one or more slots were set during the user turn."""

    # slots expected to be set, each optionally with an expected value
    slots: List[AssertedSlot]

    @classmethod
    def type(cls) -> str:
        """Return the assertion type key (`slot_was_set`)."""
        return AssertionType.SLOT_WAS_SET.value

    @staticmethod
    def from_dict(assertion_dict: Dict[Text, Any]) -> SlotWasSetAssertion:
        """Build the assertion from its dictionary form.

        Every entry listed under the `slot_was_set` key (an empty list if the
        key is absent) is converted with `AssertedSlot.from_dict`.
        """
        return SlotWasSetAssertion(
            slots=[
                AssertedSlot.from_dict(slot)
                for slot in assertion_dict.get(AssertionType.SLOT_WAS_SET.value, [])
            ],
        )

    def run(
        self,
        turn_events: List[Event],
        prior_events: List[Event],
        assertion_order_error_message: str = "",
        **kwargs: Any,
    ) -> Tuple[Optional[AssertionFailure], Optional[Event]]:
        """Run the slot_was_set assertion on the given events for that user turn.

        Each asserted slot needs at least one `SlotSet` event with its name.
        If the slot carries an expected value, one of those events must have
        that value; otherwise the first event for the slot is accepted.

        The optional `step_index` keyword argument is forwarded to
        `_get_turn_events_based_on_step_index`, which yields the events used
        for the failure transcript and the events that are searched.

        Returns:
            `(None, event)` on success, where `event` is the match for the
            last asserted slot (`None` if no slots are asserted), otherwise
            `(failure, None)` for the first slot that fails.
        """
        matching_event = None

        step_index = kwargs.get("step_index")
        original_turn_events, turn_events = _get_turn_events_based_on_step_index(
            step_index, turn_events, prior_events
        )

        for slot in self.slots:
            matching_events = [
                event
                for event in turn_events
                if isinstance(event, SlotSet) and event.key == slot.name
            ]
            if not matching_events:
                error_message = f"Slot '{slot.name}' was not set."
                error_message += assertion_order_error_message

                # Build the transcript from the original turn events, exactly
                # as the value-mismatch branch below does, so both failure
                # kinds report the same events.
                return self._generate_assertion_failure(
                    error_message, prior_events, original_turn_events, slot.line
                )

            # Placeholder stored by `AssertedSlot.from_dict` when the test
            # case did not specify a value for the slot.
            if slot.value == "value key is undefined":
                matching_event = matching_events[0]
                structlogger.debug(
                    "slot_was_set_assertion.run",
                    last_event_seen=matching_event,
                    event_info="Slot value is not asserted and we have "
                    "multiple events for the same slot. "
                    "We will mark the first event as last event seen.",
                )
                continue

            event_with_value = next(
                (event for event in matching_events if event.value == slot.value),
                None,
            )
            if event_with_value is None:
                error_message = (
                    f"Slot '{slot.name}' was set to a different value "
                    f"'{matching_events[-1].value}' than the "
                    f"expected '{slot.value}' value."
                )
                error_message += assertion_order_error_message

                return self._generate_assertion_failure(
                    error_message, prior_events, original_turn_events, slot.line
                )
            matching_event = event_with_value

        return None, matching_event

    def __hash__(self) -> int:
        """Hash the JSON dump of `as_dict()`."""
        return hash(json.dumps(self.as_dict()))
577
+
578
+
579
@dataclass
class SlotWasNotSetAssertion(Assertion):
    """Assertion that one or more slots were not set during the user turn."""

    # slots expected not to be set, each optionally with a forbidden value
    slots: List[AssertedSlot]

    @classmethod
    def type(cls) -> str:
        """Return the assertion type key (`slot_was_not_set`)."""
        return AssertionType.SLOT_WAS_NOT_SET.value

    @staticmethod
    def from_dict(assertion_dict: Dict[Text, Any]) -> SlotWasNotSetAssertion:
        """Build the assertion from its dictionary form.

        Every entry listed under the `slot_was_not_set` key (an empty list if
        the key is absent) is converted with `AssertedSlot.from_dict`.
        """
        return SlotWasNotSetAssertion(
            slots=[
                AssertedSlot.from_dict(slot)
                for slot in assertion_dict.get(AssertionType.SLOT_WAS_NOT_SET.value, [])
            ]
        )

    def run(
        self,
        turn_events: List[Event],
        prior_events: List[Event],
        assertion_order_error_message: str = "",
        **kwargs: Any,
    ) -> Tuple[Optional[AssertionFailure], Optional[Event]]:
        """Run the slot_was_not_set assertion on the given events for that user turn.

        For each asserted slot only the last `SlotSet` event with its name is
        inspected. The assertion fails if no value was specified and that
        event's value is not `None`, or if the event's value equals the
        specified value.

        The optional `step_index` keyword argument is forwarded to
        `_get_turn_events_based_on_step_index`, which yields the events used
        for the failure transcript and the events that are searched.

        Returns:
            `(None, event)` on success, where `event` is the last inspected
            `SlotSet` event (`None` if none of the slots had events),
            otherwise `(failure, None)` for the first slot that fails.
        """
        matching_event = None

        step_index = kwargs.get("step_index")
        original_turn_events, turn_events = _get_turn_events_based_on_step_index(
            step_index, turn_events, prior_events
        )

        for slot in self.slots:
            matching_events = [
                event
                for event in turn_events
                if isinstance(event, SlotSet) and event.key == slot.name
            ]
            if not matching_events:
                continue

            # take the most recent event in the list of matching events
            # since that is the final value in the tracker for that user turn
            matching_event = matching_events[-1]

            # "value key is undefined" is the placeholder stored by
            # `AssertedSlot.from_dict` when no value was specified.
            if (
                slot.value == "value key is undefined"
                and matching_event.value is not None
            ):
                error_message = (
                    f"Slot '{slot.name}' was set to '{matching_event.value}' but "
                    f"it should not have been set."
                )
                error_message += assertion_order_error_message

                # Build the transcript from the original turn events, exactly
                # as the value-match branch below does, so both failure kinds
                # report the same events.
                return self._generate_assertion_failure(
                    error_message, prior_events, original_turn_events, slot.line
                )

            if matching_event.value == slot.value:
                error_message = (
                    f"Slot '{slot.name}' was set to '{slot.value}' "
                    f"but it should not have been set."
                )
                error_message += assertion_order_error_message

                return self._generate_assertion_failure(
                    error_message, prior_events, original_turn_events, slot.line
                )

        return None, matching_event

    def __hash__(self) -> int:
        """Hash the JSON dump of `as_dict()`."""
        return hash(json.dumps(self.as_dict()))
655
+
656
+
657
@dataclass
class AssertedButton:
    """Title and optional payload of a button referenced by an assertion."""

    # button title
    title: str
    # button payload; `None` when the dictionary has no `payload` key
    payload: Optional[str] = None

    @staticmethod
    def from_dict(button_dict: Dict[Text, Any]) -> AssertedButton:
        """Build an `AssertedButton` from the `title` and `payload` keys."""
        button_title = button_dict.get("title")
        button_payload = button_dict.get("payload")
        return AssertedButton(title=button_title, payload=button_payload)
670
+
671
+
672
@dataclass
class BotUtteredAssertion(Assertion):
    """Class for storing the bot uttered assertion.

    The expectation can target the name of the uttered response
    (`utter_name`), a regular expression searched in the response text
    (`text_matches`) and/or the buttons attached to the response (`buttons`).
    Each configured property is checked independently against the bot
    utterances of the turn.
    """

    utter_name: Optional[str] = None
    text_matches: Optional[str] = None
    buttons: Optional[List[AssertedButton]] = None
    line: Optional[int] = None

    @classmethod
    def type(cls) -> str:
        """Return the identifier of the bot uttered assertion type."""
        return AssertionType.BOT_UTTERED.value

    @staticmethod
    def from_dict(assertion_dict: Dict[Text, Any]) -> BotUtteredAssertion:
        """Create the assertion from its dictionary form.

        Raises:
            RasaException: if none of 'utter_name', 'text_matches' or
                'buttons' is provided.
        """
        utter_name, text_matches, buttons = (
            BotUtteredAssertion._extract_assertion_properties(assertion_dict)
        )

        if BotUtteredAssertion._assertion_is_empty(utter_name, text_matches, buttons):
            raise RasaException(
                "A 'bot_uttered' assertion is empty, it should contain at least one "
                "of the allowed properties: 'utter_name', 'text_matches', 'buttons'."
            )

        return BotUtteredAssertion(
            utter_name=utter_name,
            text_matches=text_matches,
            buttons=buttons,
            # NOTE(review): `lc` is presumably the position info attached by the
            # YAML parser, with a zero-based line number - confirm.
            line=assertion_dict.lc.line + 1 if hasattr(assertion_dict, "lc") else None,
        )

    @staticmethod
    def _extract_assertion_properties(
        assertion_dict: Dict[Text, Any],
    ) -> Tuple[Optional[str], Optional[str], List[AssertedButton]]:
        """Extracts the assertion properties from a dictionary."""
        assertion_dict = assertion_dict.get(AssertionType.BOT_UTTERED.value, {})
        utter_name = assertion_dict.get("utter_name")
        text_matches = assertion_dict.get("text_matches")
        buttons = [
            AssertedButton.from_dict(button)
            for button in assertion_dict.get("buttons", [])
        ]

        return utter_name, text_matches, buttons

    @staticmethod
    def _assertion_is_empty(
        utter_name: Optional[str],
        text_matches: Optional[str],
        buttons: List[AssertedButton],
    ) -> bool:
        """Validate if the bot uttered assertion is empty."""
        return not (utter_name or text_matches or buttons)

    def run(
        self,
        turn_events: List[Event],
        prior_events: List[Event],
        assertion_order_error_message: str = "",
        **kwargs: Any,
    ) -> Tuple[Optional[AssertionFailure], Optional[Event]]:
        """Run the bot_uttered assertion on the given events for that user turn.

        Args:
            turn_events: events of the user turn under test.
            prior_events: events that happened before the turn.
            assertion_order_error_message: text appended to every failure
                message.
            **kwargs: may contain `step_index`, the index of the test step.

        Returns:
            `(None, matching_event)` when every configured property matched,
            otherwise the result of `_generate_assertion_failure`.
        """
        matching_event = None

        step_index = kwargs.get("step_index")
        original_turn_events, turn_events = _get_turn_events_based_on_step_index(
            step_index, turn_events, prior_events
        )
        bot_events = [event for event in turn_events if isinstance(event, BotUttered)]

        if self.utter_name is not None:
            matching_event = next(
                (
                    event
                    for event in bot_events
                    if event.metadata.get("utter_action") == self.utter_name
                ),
                None,
            )
            if matching_event is None:
                error_message = f"Bot did not utter '{self.utter_name}' response."
                error_message += assertion_order_error_message

                return self._generate_assertion_failure(
                    error_message, prior_events, original_turn_events, self.line
                )

        if self.text_matches is not None:
            pattern = re.compile(self.text_matches)
            # a bot message may carry no text at all (e.g. only buttons or an
            # image); such a message can never match and must not crash the
            # regex search
            matching_event = next(
                (
                    event
                    for event in bot_events
                    if event.text is not None and pattern.search(event.text)
                ),
                None,
            )
            if matching_event is None:
                error_message = (
                    f"Bot did not utter any response which "
                    f"matches the provided text pattern "
                    f"'{self.text_matches}'."
                )
                error_message += assertion_order_error_message

                return self._generate_assertion_failure(
                    error_message, prior_events, original_turn_events, self.line
                )

        if self.buttons:
            matching_event = next(
                (event for event in bot_events if self._buttons_match(event)),
                None,
            )
            if matching_event is None:
                error_message = (
                    "Bot did not utter any response with the expected buttons."
                )
                error_message += assertion_order_error_message
                return self._generate_assertion_failure(
                    error_message, prior_events, original_turn_events, self.line
                )

        return None, matching_event

    def _buttons_match(self, event: BotUttered) -> bool:
        """Check if the bot response contains the expected buttons.

        Expected buttons are compared position by position with the buttons of
        the response. A response with fewer buttons than expected never
        matches.
        """
        # a button is a dictionary with keys 'title' and 'payload'
        actual_buttons = event.data.get("buttons", [])
        if not actual_buttons:
            return False

        expected_buttons = self.buttons or []
        # `zip` stops at the shortest input, so without this guard a response
        # missing some of the expected buttons would be reported as a match
        if len(actual_buttons) < len(expected_buttons):
            return False

        return all(
            self._button_matches(actual_button, expected_button)
            for actual_button, expected_button in zip(actual_buttons, expected_buttons)
        )

    @staticmethod
    def _button_matches(
        actual_button: Dict[str, Any], expected_button: AssertedButton
    ) -> bool:
        """Check if the actual button matches the expected button."""
        return (
            actual_button.get("title") == expected_button.title
            and actual_button.get("payload") == expected_button.payload
        )

    def __hash__(self) -> int:
        """Hash the assertion through the JSON dump of its dictionary form."""
        return hash(json.dumps(self.as_dict()))
824
+
825
+
826
@dataclass
class GenerativeResponseMixin(Assertion):
    """Mixin class for storing generative response assertions.

    Subclasses provide the ground truth through `_get_ground_truth` and set
    the metric attributes used to score a generative bot response with an
    LLM judge.
    """

    # minimum score, on a 0-1 scale, for the assertion to pass
    threshold: float = DEFAULT_THRESHOLD
    # when set, only the response with this utter action is evaluated
    utter_name: Optional[str] = None
    line: Optional[int] = None
    # adjective used in log event names and error messages
    metric_adjective: Optional[str] = None
    # prefix of the score / justification columns in the evaluation table
    metric_name: Optional[str] = None
    # factory called with the judge model uri to build the mlflow metric
    mlflow_metric: Callable = print

    @classmethod
    def type(cls) -> str:
        """Return the assertion type identifier; empty for the mixin."""
        return ""

    def _get_ground_truth(self, matching_event: BotUttered) -> str:
        """Return the ground truth for the evaluated event (set by subclasses)."""
        raise NotImplementedError

    def as_dict(self) -> Dict[str, Any]:
        """Return the assertion as a dictionary without the metric attributes."""
        data = super().as_dict()
        for internal_key in ("metric_name", "metric_adjective", "mlflow_metric"):
            data.pop(internal_key)

        return data

    def _run_llm_evaluation(
        self,
        matching_event: BotUttered,
        step_text: str,
        llm_judge_config: "LLMJudgeConfig",
        assertion_order_error_message: str,
        prior_events: List[Event],
        turn_events: List[Event],
    ) -> Tuple[Optional[AssertionFailure], Optional[Event]]:
        """Run the LLM evaluation on the given event.

        Returns:
            `(None, matching_event)` when the score reaches the threshold,
            otherwise the result of `_generate_assertion_failure`.
        """
        import mlflow

        # we need to configure the log level for mlflow
        # after a local import to avoid unnecessary logs
        update_mlflow_log_level()

        # extract user question from event if available
        user_question_from_event = matching_event.metadata.get(
            SEARCH_QUERY_METADATA_KEY
        )
        user_question = (
            user_question_from_event if user_question_from_event else step_text
        )

        ground_truth = self._get_ground_truth(matching_event)

        eval_data = pd.DataFrame(
            {
                "inputs": [user_question],
                "ground_truth": [ground_truth],
                "predictions": [matching_event.text],
            }
        )

        model_uri = llm_judge_config.get_model_uri()

        structlogger.debug(
            f"generative_response_is_{self.metric_adjective}_assertion.run_llm_evaluation",
            model_uri=model_uri,
        )

        with mlflow.start_run():
            results = mlflow.evaluate(
                data=eval_data,
                targets="ground_truth",
                predictions="predictions",
                model_type="question-answering",
                evaluators="default",
                extra_metrics=[
                    self.mlflow_metric(model_uri),
                ],
            )

        # Evaluation result for each data record is available in `results.tables`.
        eval_table = results.tables["eval_results_table"]
        raw_score = eval_table.iloc[0][f"{self.metric_name}/v1/score"]
        justification = eval_table.iloc[0][f"{self.metric_name}/v1/justification"]

        # convert 1-5 score to 0-1 float; a missing score (None or NaN) counts
        # as 0, otherwise NaN would compare as "not below the threshold" and
        # the assertion would silently pass
        score = 0 if pd.isna(raw_score) else raw_score * 20 / 100

        structlogger.debug(
            f"generative_response_is_{self.metric_adjective}_assertion.run_results",
            matching_event=repr(matching_event),
            score=score,
            justification=justification,
        )

        if score < self.threshold:
            error_message = (
                f"Generative response '{matching_event.text}' "
                f"given to the user input '{user_question}' "
                f"was not {self.metric_adjective}. "
                f"Expected score to be above '{self.threshold}' threshold, "
                f"but was '{score}'. The explanation for this score is: "
                f"{justification}."
            )
            error_message += assertion_order_error_message

            return self._generate_assertion_failure(
                error_message, prior_events, turn_events, self.line
            )

        return None, matching_event

    def _run_assertion_with_utter_name(
        self,
        matching_events: List[BotUttered],
        step_text: str,
        llm_judge_config: "LLMJudgeConfig",
        assertion_order_error_message: str,
        prior_events: List[Event],
        turn_events: List[Event],
    ) -> Tuple[Optional[AssertionFailure], Optional[Event]]:
        """Assert metric for the given utter name."""
        matching_event = next(
            (
                event
                for event in matching_events
                if event.metadata.get("utter_action") == self.utter_name
            ),
            None,
        )
        if matching_event is None:
            error_message = f"Bot did not utter '{self.utter_name}' response."
            error_message += assertion_order_error_message

            return self._generate_assertion_failure(
                error_message, prior_events, turn_events, self.line
            )

        return self._run_llm_evaluation(
            matching_event,
            step_text,
            llm_judge_config,
            assertion_order_error_message,
            prior_events,
            turn_events,
        )

    def _run_assertion_for_multiple_generative_responses(
        self,
        matching_events: List[BotUttered],
        step_text: str,
        llm_judge_config: "LLMJudgeConfig",
        assertion_order_error_message: str,
        prior_events: List[Event],
        turn_events: List[Event],
    ) -> Tuple[Optional[AssertionFailure], Optional[Event]]:
        """Run LLM evaluation for multiple bot utterances.

        The assertion passes when at least one response passes the
        evaluation; the last passing response, in utterance order, is
        returned as the matching event.
        """
        structlogger.debug(
            f"generative_response_is_{self.metric_adjective}_assertion.run",
            event_info="Multiple generative responses found, "
            "we will evaluate each of the responses.",
        )

        # track the last passing event directly rather than through a set,
        # whose iteration order would make the returned event arbitrary
        last_passing_event: Optional[Event] = None
        for event in matching_events:
            _, event_result = self._run_llm_evaluation(
                event,
                step_text,
                llm_judge_config,
                assertion_order_error_message,
                prior_events,
                turn_events,
            )
            if event_result is not None:
                last_passing_event = event_result

        if last_passing_event is None:
            error_message = (
                f"None of the generative responses issued by either the "
                f"Enterprise Search Policy, IntentlessPolicy or the "
                f"Contextual Response Rephraser were {self.metric_adjective}."
            )
            error_message += assertion_order_error_message

            return self._generate_assertion_failure(
                error_message, prior_events, turn_events, self.line
            )

        return None, last_passing_event

    def run(
        self,
        turn_events: List[Event],
        prior_events: List[Event],
        assertion_order_error_message: str = "",
        llm_judge_config: Optional["LLMJudgeConfig"] = None,
        step_text: Optional[str] = None,
        **kwargs: Any,
    ) -> Tuple[Optional[AssertionFailure], Optional[Event]]:
        """Run the LLM evaluation on the given events for that user turn.

        Args:
            turn_events: events of the user turn under test.
            prior_events: events that happened before the turn.
            assertion_order_error_message: text appended to every failure
                message.
            llm_judge_config: configuration of the LLM judge; its
                `get_model_uri()` is called for every evaluation.
            step_text: text of the test step, used as the user question when
                the evaluated event carries no search query metadata.
            **kwargs: unused by this method.

        Returns:
            `(None, matching_event)` on success, otherwise the result of
            `_generate_assertion_failure`.
        """
        matching_events: List[BotUttered] = _find_matching_generative_events(
            turn_events
        )

        if not matching_events:
            error_message = (
                "No generative response issued by either the Enterprise Search Policy, "
                "IntentlessPolicy or the Contextual Response Rephraser was found, "
                "but one was expected."
            )
            error_message += assertion_order_error_message

            return self._generate_assertion_failure(
                error_message, prior_events, turn_events, self.line
            )

        if self.utter_name is not None:
            return self._run_assertion_with_utter_name(
                matching_events,
                step_text,
                llm_judge_config,
                assertion_order_error_message,
                prior_events,
                turn_events,
            )

        if len(matching_events) > 1:
            return self._run_assertion_for_multiple_generative_responses(
                matching_events,
                step_text,
                llm_judge_config,
                assertion_order_error_message,
                prior_events,
                turn_events,
            )

        matching_event = matching_events[0]

        return self._run_llm_evaluation(
            matching_event,
            step_text,
            llm_judge_config,
            assertion_order_error_message,
            prior_events,
            turn_events,
        )
1069
+
1070
+
1071
@dataclass
class GenerativeResponseIsRelevantAssertion(GenerativeResponseMixin):
    """Assertion that a generative bot response is relevant."""

    @classmethod
    def type(cls) -> str:
        """Return the identifier of this assertion type."""
        return AssertionType.GENERATIVE_RESPONSE_IS_RELEVANT.value

    def _get_ground_truth(self, matching_event: BotUttered) -> str:
        """Return an empty ground truth for the relevance evaluation."""
        return ""

    @staticmethod
    def from_dict(
        assertion_dict: Dict[Text, Any],
    ) -> GenerativeResponseIsRelevantAssertion:
        """Create the assertion from its dictionary form."""
        import mlflow

        properties = assertion_dict.get(
            AssertionType.GENERATIVE_RESPONSE_IS_RELEVANT.value, {}
        )

        # NOTE(review): `lc` is presumably the position info attached by the
        # YAML parser - confirm.
        if hasattr(properties, "lc"):
            line = properties.lc.line + 1
        else:
            line = None

        return GenerativeResponseIsRelevantAssertion(
            threshold=properties.get("threshold", DEFAULT_THRESHOLD),
            utter_name=properties.get("utter_name"),
            line=line,
            metric_name="answer_relevance",
            metric_adjective="relevant",
            mlflow_metric=mlflow.metrics.genai.answer_relevance,
        )

    def __hash__(self) -> int:
        """Hash the assertion through the JSON dump of its dictionary form."""
        serialized = json.dumps(self.as_dict())
        return hash(serialized)
1102
+
1103
+
1104
@dataclass
class GenerativeResponseIsGroundedAssertion(GenerativeResponseMixin):
    """Assertion that a generative bot response is grounded."""

    # ground truth declared in the assertion; takes precedence over the
    # ground truth found in the event metadata
    ground_truth: Optional[str] = None

    @classmethod
    def type(cls) -> str:
        """Return the identifier of this assertion type."""
        return AssertionType.GENERATIVE_RESPONSE_IS_GROUNDED.value

    @staticmethod
    def from_dict(
        assertion_dict: Dict[Text, Any],
    ) -> GenerativeResponseIsGroundedAssertion:
        """Create the assertion from its dictionary form."""
        import mlflow

        properties = assertion_dict.get(
            AssertionType.GENERATIVE_RESPONSE_IS_GROUNDED.value, {}
        )

        # NOTE(review): `lc` is presumably the position info attached by the
        # YAML parser - confirm.
        if hasattr(properties, "lc"):
            line = properties.lc.line + 1
        else:
            line = None

        return GenerativeResponseIsGroundedAssertion(
            threshold=properties.get("threshold", DEFAULT_THRESHOLD),
            utter_name=properties.get("utter_name"),
            ground_truth=properties.get("ground_truth"),
            line=line,
            metric_name="answer_correctness",
            metric_adjective="grounded",
            mlflow_metric=mlflow.metrics.genai.answer_correctness,
        )

    def __hash__(self) -> int:
        """Hash the assertion through the JSON dump of its dictionary form."""
        serialized = json.dumps(self.as_dict())
        return hash(serialized)

    def _get_ground_truth(self, matching_event: BotUttered) -> str:
        """Return the declared ground truth or the one stored on the event.

        The event metadata is read under the search results key first and
        under the domain ground truth key as a fallback; a list value is
        joined with newlines.
        """
        metadata = matching_event.metadata
        from_event = metadata.get(SEARCH_RESULTS_METADATA_KEY, "")
        if not from_event:
            from_event = metadata.get(DOMAIN_GROUND_TRUTH_METADATA_KEY, "")

        if isinstance(from_event, list):
            from_event = "\n".join(from_event)

        if self.ground_truth is not None:
            return self.ground_truth

        return from_event
1152
+
1153
+
1154
@dataclass
class AssertionFailure:
    """Details of an assertion that did not hold."""

    # the assertion that failed
    assertion: Assertion
    # explanation of the failure
    error_message: Text
    # string representations of the events that actually happened
    actual_events_transcript: List[Text]
    # line attached to the failure, when one is known
    error_line: Optional[int] = None

    def as_dict(self) -> Dict[Text, Any]:
        """Serialize the failure; `error_line` is not part of the result."""
        serialized: Dict[Text, Any] = {}
        serialized["assertion"] = self.assertion.as_dict()
        serialized["error_message"] = self.error_message
        serialized["actual_events_transcript"] = self.actual_events_transcript
        return serialized
1170
+
1171
+
1172
def create_actual_events_transcript(
    prior_events: List[Event], turn_events: List[Event]
) -> List[Text]:
    """Build the transcript of actual events attached to an assertion failure.

    The transcript holds the `repr` of the prior events followed by the turn
    events. Slot sets of default slots, `DefinePrevUserUtteredFeaturization`
    and `DialogueStackUpdated` events are left out.
    """

    def _is_left_out(event: Event) -> bool:
        if isinstance(event, SlotSet) and event.key in DEFAULT_SLOT_NAMES:
            return True
        return isinstance(
            event, (DefinePrevUserUtteredFeaturization, DialogueStackUpdated)
        )

    return [
        repr(event)
        for event in prior_events + turn_events
        if not _is_left_out(event)
    ]
1191
+
1192
+
1193
def _find_matching_generative_events(turn_events: List[Event]) -> List[BotUttered]:
    """Collect the bot utterances eligible for generative response assertions.

    An event is kept when it is a `BotUttered` whose utter source metadata is
    one of the eligible sources.
    """
    generative_events = []
    for event in turn_events:
        if not isinstance(event, BotUttered):
            continue

        utter_source = event.metadata.get(UTTER_SOURCE_METADATA_KEY)
        if utter_source in ELIGIBLE_UTTER_SOURCE_METADATA:
            generative_events.append(event)

    return generative_events
1202
+
1203
+
1204
def _get_turn_events_based_on_step_index(
    step_index: int, turn_events: List[Event], prior_events: List[Event]
) -> Tuple[List[Event], List[Event]]:
    """Select the events an assertion should inspect for the given step.

    For the first step (index 0), the prior events are included in front of
    the turn events. For any other step, only the turn events are used.

    Returns:
        A tuple of two lists: a copy of `turn_events`, and the events to
        inspect for the step.
    """
    turn_events_copy = turn_events[:]
    is_first_step = step_index == 0
    events_to_inspect = prior_events + turn_events if is_first_step else turn_events

    return turn_events_copy, events_to_inspect