rasa-pro 3.11.3a1.dev7__py3-none-any.whl → 3.12.0.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. See the package registry's notice for this release for more details.

Files changed (65):
  1. rasa/cli/arguments/default_arguments.py +1 -1
  2. rasa/cli/dialogue_understanding_test.py +251 -0
  3. rasa/core/actions/action.py +7 -16
  4. rasa/core/channels/__init__.py +0 -2
  5. rasa/core/channels/socketio.py +23 -1
  6. rasa/core/nlg/contextual_response_rephraser.py +9 -62
  7. rasa/core/policies/enterprise_search_policy.py +12 -77
  8. rasa/core/policies/flows/flow_executor.py +2 -26
  9. rasa/core/processor.py +8 -11
  10. rasa/dialogue_understanding/generator/command_generator.py +49 -43
  11. rasa/dialogue_understanding/generator/llm_based_command_generator.py +5 -5
  12. rasa/dialogue_understanding/generator/llm_command_generator.py +1 -2
  13. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +15 -34
  14. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +6 -11
  15. rasa/dialogue_understanding/utils.py +1 -8
  16. rasa/dialogue_understanding_test/command_metric_calculation.py +12 -0
  17. rasa/dialogue_understanding_test/constants.py +2 -0
  18. rasa/dialogue_understanding_test/du_test_runner.py +93 -0
  19. rasa/dialogue_understanding_test/io.py +54 -0
  20. rasa/dialogue_understanding_test/validation.py +22 -0
  21. rasa/e2e_test/e2e_test_runner.py +9 -7
  22. rasa/hooks.py +9 -15
  23. rasa/model_manager/socket_bridge.py +2 -7
  24. rasa/model_manager/warm_rasa_process.py +4 -9
  25. rasa/plugin.py +0 -11
  26. rasa/shared/constants.py +2 -21
  27. rasa/shared/core/events.py +8 -8
  28. rasa/shared/nlu/constants.py +0 -3
  29. rasa/shared/providers/_configs/azure_entra_id_client_creds.py +40 -0
  30. rasa/shared/providers/_configs/azure_entra_id_config.py +533 -0
  31. rasa/shared/providers/_configs/azure_openai_client_config.py +131 -15
  32. rasa/shared/providers/_configs/client_config.py +3 -1
  33. rasa/shared/providers/_configs/default_litellm_client_config.py +9 -7
  34. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +13 -11
  35. rasa/shared/providers/_configs/litellm_router_client_config.py +12 -10
  36. rasa/shared/providers/_configs/model_group_config.py +11 -5
  37. rasa/shared/providers/_configs/oauth_config.py +33 -0
  38. rasa/shared/providers/_configs/openai_client_config.py +14 -12
  39. rasa/shared/providers/_configs/rasa_llm_client_config.py +5 -3
  40. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +12 -11
  41. rasa/shared/providers/constants.py +6 -0
  42. rasa/shared/providers/embedding/azure_openai_embedding_client.py +30 -7
  43. rasa/shared/providers/embedding/litellm_router_embedding_client.py +5 -2
  44. rasa/shared/providers/llm/_base_litellm_client.py +6 -4
  45. rasa/shared/providers/llm/azure_openai_llm_client.py +88 -34
  46. rasa/shared/providers/llm/default_litellm_llm_client.py +4 -2
  47. rasa/shared/providers/llm/litellm_router_llm_client.py +23 -3
  48. rasa/shared/providers/llm/llm_client.py +4 -2
  49. rasa/shared/providers/llm/llm_response.py +1 -42
  50. rasa/shared/providers/llm/openai_llm_client.py +11 -5
  51. rasa/shared/providers/llm/rasa_llm_client.py +13 -5
  52. rasa/shared/providers/llm/self_hosted_llm_client.py +17 -10
  53. rasa/shared/providers/router/_base_litellm_router_client.py +10 -8
  54. rasa/shared/providers/router/router_client.py +3 -1
  55. rasa/shared/utils/llm.py +16 -12
  56. rasa/shared/utils/schemas/events.py +1 -1
  57. rasa/tracing/instrumentation/attribute_extractors.py +0 -2
  58. rasa/version.py +1 -1
  59. {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/METADATA +2 -1
  60. {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/RECORD +63 -56
  61. rasa/core/channels/studio_chat.py +0 -192
  62. rasa/dialogue_understanding/constants.py +0 -1
  63. {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/NOTICE +0 -0
  64. {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/WHEEL +0 -0
  65. {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/entry_points.txt +0 -0
@@ -82,7 +82,6 @@ from rasa.shared.core.flows.steps import (
82
82
  NoOperationFlowStep,
83
83
  )
84
84
  from rasa.shared.core.flows.steps.collect import SlotRejection
85
- from rasa.shared.core.flows.steps.constants import START_STEP
86
85
  from rasa.shared.core.slots import Slot
87
86
  from rasa.shared.core.trackers import (
88
87
  DialogueStateTracker,
@@ -317,7 +316,7 @@ def reset_scoped_slots(
317
316
  def _reset_slot(slot_name: Text, dialogue_tracker: DialogueStateTracker) -> None:
318
317
  slot = dialogue_tracker.slots.get(slot_name, None)
319
318
  initial_value = slot.initial_value if slot else None
320
- events.append(SlotSet(slot_name, initial_value, metadata={"reset": True}))
319
+ events.append(SlotSet(slot_name, initial_value))
321
320
 
322
321
  if (
323
322
  isinstance(current_frame, UserFlowStackFrame)
@@ -449,7 +448,6 @@ def advance_flows_until_next_action(
449
448
  tracker,
450
449
  available_actions,
451
450
  flows,
452
- previous_step_id,
453
451
  )
454
452
  new_events = step_result.events
455
453
  if (
@@ -466,9 +464,6 @@ def advance_flows_until_next_action(
466
464
  new_events.insert(
467
465
  idx, FlowCompleted(active_frame.flow_id, previous_step_id)
468
466
  )
469
- attach_stack_metadata_to_events(
470
- next_step.id, current_flow.id, new_events
471
- )
472
467
  tracker.update_stack(step_stack)
473
468
  tracker.update_with_events(new_events)
474
469
 
@@ -572,17 +567,6 @@ def validate_custom_slot_mappings(
572
567
  return True
573
568
 
574
569
 
575
- def attach_stack_metadata_to_events(
576
- step_id: str,
577
- flow_id: str,
578
- events: List[Event],
579
- ) -> None:
580
- """Attach the stack metadata to the events."""
581
- for event in events:
582
- event.metadata[STEP_ID_METADATA_KEY] = step_id
583
- event.metadata[ACTIVE_FLOW_METADATA_KEY] = flow_id
584
-
585
-
586
570
  def run_step(
587
571
  step: FlowStep,
588
572
  flow: Flow,
@@ -590,7 +574,6 @@ def run_step(
590
574
  tracker: DialogueStateTracker,
591
575
  available_actions: List[str],
592
576
  flows: FlowsList,
593
- previous_step_id: str,
594
577
  ) -> FlowStepResult:
595
578
  """Run a single step of a flow.
596
579
 
@@ -608,19 +591,12 @@ def run_step(
608
591
  tracker: The tracker to run the step on.
609
592
  available_actions: The actions that are available in the domain.
610
593
  flows: All flows.
611
- previous_step_id: The ID of the previous step.
612
594
 
613
595
  Returns:
614
596
  A result of running the step describing where to transition to.
615
597
  """
616
598
  initial_events: List[Event] = []
617
- if previous_step_id == START_STEP:
618
- # if the previous step id is the start step, we need to add a flow
619
- # started event to the initial events.
620
- # we can't use the current step to check this, as the current step is the
621
- # first step in the flow -> other steps might link to this flow, so the
622
- # only reliable way to check if we are starting a new flow is checking for
623
- # the START_STEP meta step
599
+ if step == flow.first_step_in_flow():
624
600
  initial_events.append(FlowStarted(flow.id, metadata=stack.current_context()))
625
601
 
626
602
  if isinstance(step, CollectInformationFlowStep):
rasa/core/processor.py CHANGED
@@ -818,9 +818,8 @@ class MessageProcessor:
818
818
  return parse_data
819
819
 
820
820
  def _sanitize_message(self, message: UserMessage) -> UserMessage:
821
- """Sanitize user messages.
822
-
823
- Removes prepended slashes before the actual content.
821
+ """Sanitize user message by removing prepended slashes before the
822
+ actual content.
824
823
  """
825
824
  # Regex pattern to match leading slashes and any whitespace before
826
825
  # actual content
@@ -922,7 +921,9 @@ class MessageProcessor:
922
921
  return [command.as_dict() for command in commands]
923
922
 
924
923
  def _contains_undefined_intent(self, message: Message) -> bool:
925
- """Checks if the message contains an undefined intent."""
924
+ """Checks if the message contains an intent that is undefined
925
+ in the domain.
926
+ """
926
927
  intent_name = message.get(INTENT, {}).get("name")
927
928
  return intent_name is not None and intent_name not in self.domain.intents
928
929
 
@@ -986,8 +987,6 @@ class MessageProcessor:
986
987
  if parse_data["entities"]:
987
988
  self._log_slots(tracker)
988
989
 
989
- plugin_manager().hook.after_new_user_message(tracker=tracker)
990
-
991
990
  logger.debug(
992
991
  f"Logged UserUtterance - tracker now has {len(tracker.events)} events."
993
992
  )
@@ -1306,7 +1305,7 @@ class MessageProcessor:
1306
1305
  self._log_slots(tracker)
1307
1306
 
1308
1307
  await self.execute_side_effects(events, tracker, output_channel)
1309
- plugin_manager().hook.after_action_executed(tracker=tracker)
1308
+
1310
1309
  return self.should_predict_another_action(action.name())
1311
1310
 
1312
1311
  def _log_action_on_tracker(
@@ -1442,10 +1441,8 @@ class MessageProcessor:
1442
1441
  return len(filtered_commands) > 0
1443
1442
 
1444
1443
  def _is_calm_assistant(self) -> bool:
1445
- """Inspects the nodes of the graph schema to decide if we are in CALM.
1446
-
1447
- To determine whether we are in CALM mode, we check if any node is
1448
- associated with the `FlowPolicy`, which is indicative of a
1444
+ """Inspects the nodes of the graph schema to determine whether
1445
+ any node is associated with the `FlowPolicy`, which is indicative of a
1449
1446
  CALM assistant setup.
1450
1447
 
1451
1448
  Returns:
@@ -26,12 +26,8 @@ from rasa.shared.nlu.constants import (
26
26
  PROMPTS,
27
27
  KEY_USER_PROMPT,
28
28
  KEY_SYSTEM_PROMPT,
29
- KEY_LLM_RESPONSE_METADATA,
30
- KEY_PROMPT_NAME,
31
- KEY_COMPONENT_NAME,
32
29
  )
33
30
  from rasa.shared.nlu.training_data.message import Message
34
- from rasa.shared.providers.llm.llm_response import LLMResponse
35
31
  from rasa.shared.utils.llm import DEFAULT_MAX_USER_INPUT_CHARACTERS
36
32
 
37
33
  structlogger = structlog.get_logger()
@@ -403,56 +399,66 @@ class CommandGenerator:
403
399
  prompt_name: str,
404
400
  user_prompt: str,
405
401
  system_prompt: Optional[str] = None,
406
- llm_response: Optional[LLMResponse] = None,
407
402
  ) -> None:
408
403
  """Add prompt to the message parse data.
409
404
 
410
405
  Prompt is only added in case the flag 'record_commands_and_prompts' is set.
411
406
  Example of prompts in the message parse data:
412
407
  Message(data={
413
- PROMPTS: [
414
- {
415
- "component_name": "MultiStepLLMCommandGenerator",
416
- "prompt_name": "fill_slots_prompt",
417
- "user_prompt": "...",
418
- "system_prompt": "...",
419
- "llm_response_metadata": { ... }
420
- },
421
- {
422
- "component_name": "MultiStepLLMCommandGenerator",
423
- "prompt_name": "handle_flows_prompt",
424
- "user_prompt": "...",
425
- "system_prompt": "...",
426
- "llm_response_metadata": { ... }
427
- },
428
- {
429
- "component_name": "SingleStepLLMCommandGenerator",
430
- "prompt_name": "prompt_template",
431
- "user_prompt": "...",
432
- "system_prompt": "...",
433
- "llm_response_metadata": { ... }
434
- }
408
+ PROMPTS: {
409
+ "MultiStepLLMCommandGenerator": [
410
+ (
411
+ "fill_slots_prompt",
412
+ {
413
+ "user_prompt": <prompt content>",
414
+ "system_prompt": <prompt content>"
415
+ }
416
+ ),
417
+ (
418
+ "handle_flows_prompt",
419
+ {
420
+ "user_prompt": <prompt content>",
421
+ "system_prompt": <prompt content>"
422
+ }
423
+ ),
424
+ ],
425
+ "SingleStepLLMCommandGenerator": [
426
+ (
427
+ "prompt_template",
428
+ {
429
+ "user_prompt": <prompt content>",
430
+ "system_prompt": <prompt content>"
431
+ }
432
+ ),
435
433
  ]
434
+ }
436
435
  })
437
436
  """
438
437
  from rasa.dialogue_understanding.utils import record_commands_and_prompts
439
438
 
440
- # Only set prompt if the flag "record_commands_and_prompts" is set to True.
439
+ # only set prompt if the flag "record_commands_and_prompts" is set to True
441
440
  if not record_commands_and_prompts:
442
441
  return
443
442
 
444
- # Construct the dictionary with prompt details.
445
- prompt_data: Dict[Text, Any] = {
446
- KEY_COMPONENT_NAME: component_name,
447
- KEY_PROMPT_NAME: prompt_name,
448
- KEY_USER_PROMPT: user_prompt,
449
- KEY_LLM_RESPONSE_METADATA: llm_response.to_dict() if llm_response else None,
450
- **({KEY_SYSTEM_PROMPT: system_prompt} if system_prompt else {}),
451
- }
452
-
453
- # Get or create a top-level "prompts" list.
454
- prompts = message.get(PROMPTS) or []
455
- prompts.append(prompt_data)
456
-
457
- # Update the message with the new prompts list.
458
- message.set(PROMPTS, prompts, add_to_output=True)
443
+ prompt_tuple = (
444
+ prompt_name,
445
+ {
446
+ KEY_USER_PROMPT: user_prompt,
447
+ **({KEY_SYSTEM_PROMPT: system_prompt} if system_prompt else {}),
448
+ },
449
+ )
450
+
451
+ if message.get(PROMPTS) is not None:
452
+ prompts = message.get(PROMPTS)
453
+ if component_name in prompts:
454
+ prompts[component_name].append(prompt_tuple)
455
+ else:
456
+ prompts[component_name] = [prompt_tuple]
457
+ else:
458
+ prompts = {component_name: [prompt_tuple]}
459
+
460
+ message.set(
461
+ PROMPTS,
462
+ prompts,
463
+ add_to_output=True,
464
+ )
@@ -32,7 +32,6 @@ from rasa.shared.exceptions import ProviderClientAPIException
32
32
  from rasa.shared.nlu.constants import FLOWS_IN_PROMPT
33
33
  from rasa.shared.nlu.training_data.message import Message
34
34
  from rasa.shared.nlu.training_data.training_data import TrainingData
35
- from rasa.shared.providers.llm.llm_response import LLMResponse
36
35
  from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
37
36
  from rasa.shared.utils.llm import (
38
37
  allowed_values_for_slot,
@@ -305,21 +304,22 @@ class LLMBasedCommandGenerator(
305
304
  )
306
305
  return filtered_flows
307
306
 
308
- async def invoke_llm(self, prompt: Text) -> Optional[LLMResponse]:
307
+ async def invoke_llm(self, prompt: Text) -> Optional[Text]:
309
308
  """Use LLM to generate a response.
310
309
 
311
310
  Args:
312
311
  prompt: The prompt to send to the LLM.
313
312
 
314
313
  Returns:
315
- An LLMResponse object.
314
+ The generated text.
316
315
 
317
316
  Raises:
318
- ProviderClientAPIException: If an error occurs during the LLM API call.
317
+ ProviderClientAPIException if an error during API call.
319
318
  """
320
319
  llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
321
320
  try:
322
- return await llm.acompletion(prompt)
321
+ llm_response = await llm.acompletion(prompt)
322
+ return llm_response.choices[0]
323
323
  except Exception as e:
324
324
  # unfortunately, langchain does not wrap LLM exceptions which means
325
325
  # we have to catch all exceptions here
@@ -10,7 +10,6 @@ from rasa.engine.recipes.default_recipe import DefaultV1Recipe
10
10
  from rasa.engine.storage.resource import Resource
11
11
  from rasa.engine.storage.storage import ModelStorage
12
12
  from rasa.shared.exceptions import ProviderClientAPIException
13
- from rasa.shared.providers.llm.llm_response import LLMResponse
14
13
  from rasa.shared.utils.io import raise_deprecation_warning
15
14
 
16
15
  structlogger = structlog.get_logger()
@@ -54,7 +53,7 @@ class LLMCommandGenerator(SingleStepLLMCommandGenerator):
54
53
  **kwargs,
55
54
  )
56
55
 
57
- async def invoke_llm(self, prompt: Text) -> Optional[LLMResponse]:
56
+ async def invoke_llm(self, prompt: Text) -> Optional[Text]:
58
57
  try:
59
58
  return await super().invoke_llm(prompt)
60
59
  except ProviderClientAPIException:
@@ -51,7 +51,6 @@ from rasa.shared.core.trackers import DialogueStateTracker
51
51
  from rasa.shared.exceptions import ProviderClientAPIException
52
52
  from rasa.shared.nlu.constants import TEXT
53
53
  from rasa.shared.nlu.training_data.message import Message
54
- from rasa.shared.providers.llm.llm_response import LLMResponse
55
54
  from rasa.shared.utils.io import deep_container_fingerprint
56
55
  from rasa.shared.utils.llm import (
57
56
  get_prompt_template,
@@ -536,12 +535,7 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
536
535
  prompt=prompt,
537
536
  )
538
537
 
539
- response = await self.invoke_llm(prompt)
540
- llm_response = LLMResponse.ensure_llm_response(response)
541
- actions = None
542
- if llm_response and llm_response.choices:
543
- actions = llm_response.choices[0]
544
-
538
+ actions = await self.invoke_llm(prompt)
545
539
  structlogger.debug(
546
540
  "multi_step_llm_command_generator"
547
541
  ".predict_commands_for_active_flow"
@@ -553,11 +547,10 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
553
547
 
554
548
  if commands:
555
549
  self._add_prompt_to_message_parse_data(
556
- message=message,
557
- component_name=MultiStepLLMCommandGenerator.__name__,
558
- prompt_name="fill_slots_for_active_flow_prompt",
559
- user_prompt=prompt,
560
- llm_response=llm_response,
550
+ message,
551
+ MultiStepLLMCommandGenerator.__name__,
552
+ "fill_slots_for_active_flow_prompt",
553
+ prompt,
561
554
  )
562
555
 
563
556
  return commands
@@ -591,12 +584,7 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
591
584
  prompt=prompt,
592
585
  )
593
586
 
594
- response = await self.invoke_llm(prompt)
595
- llm_response = LLMResponse.ensure_llm_response(response)
596
- actions = None
597
- if llm_response and llm_response.choices:
598
- actions = llm_response.choices[0]
599
-
587
+ actions = await self.invoke_llm(prompt)
600
588
  structlogger.debug(
601
589
  "multi_step_llm_command_generator"
602
590
  ".predict_commands_for_handling_flows"
@@ -610,11 +598,10 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
610
598
 
611
599
  if commands:
612
600
  self._add_prompt_to_message_parse_data(
613
- message=message,
614
- component_name=MultiStepLLMCommandGenerator.__name__,
615
- prompt_name="handle_flows_prompt",
616
- user_prompt=prompt,
617
- llm_response=llm_response,
601
+ message,
602
+ MultiStepLLMCommandGenerator.__name__,
603
+ "handle_flows_prompt",
604
+ prompt,
618
605
  )
619
606
 
620
607
  return commands
@@ -681,12 +668,7 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
681
668
  prompt=prompt,
682
669
  )
683
670
 
684
- response = await self.invoke_llm(prompt)
685
- llm_response = LLMResponse.ensure_llm_response(response)
686
- actions = None
687
- if llm_response and llm_response.choices:
688
- actions = llm_response.choices[0]
689
-
671
+ actions = await self.invoke_llm(prompt)
690
672
  structlogger.debug(
691
673
  "multi_step_llm_command_generator"
692
674
  ".predict_commands_for_newly_started_flow"
@@ -713,11 +695,10 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
713
695
 
714
696
  if commands:
715
697
  self._add_prompt_to_message_parse_data(
716
- message=message,
717
- component_name=MultiStepLLMCommandGenerator.__name__,
718
- prompt_name="fill_slots_for_new_flow_prompt",
719
- user_prompt=prompt,
720
- llm_response=llm_response,
698
+ message,
699
+ MultiStepLLMCommandGenerator.__name__,
700
+ "fill_slots_for_new_flow_prompt",
701
+ prompt,
721
702
  )
722
703
 
723
704
  return commands
@@ -46,7 +46,6 @@ from rasa.shared.core.trackers import DialogueStateTracker
46
46
  from rasa.shared.exceptions import ProviderClientAPIException
47
47
  from rasa.shared.nlu.constants import TEXT, LLM_COMMANDS, LLM_PROMPT
48
48
  from rasa.shared.nlu.training_data.message import Message
49
- from rasa.shared.providers.llm.llm_response import LLMResponse
50
49
  from rasa.shared.utils.io import deep_container_fingerprint
51
50
  from rasa.shared.utils.llm import (
52
51
  get_prompt_template,
@@ -265,16 +264,13 @@ class SingleStepLLMCommandGenerator(LLMBasedCommandGenerator):
265
264
  prompt=flow_prompt,
266
265
  )
267
266
 
268
- response = await self.invoke_llm(flow_prompt)
269
- llm_response = LLMResponse.ensure_llm_response(response)
267
+ action_list = await self.invoke_llm(flow_prompt)
270
268
  # The check for 'None' maintains compatibility with older versions
271
269
  # of LLMCommandGenerator. In previous implementations, 'invoke_llm'
272
270
  # might return 'None' to indicate a failure to generate actions.
273
- if llm_response is None or not llm_response.choices:
271
+ if action_list is None:
274
272
  return [ErrorCommand()]
275
273
 
276
- action_list = llm_response.choices[0]
277
-
278
274
  log_llm(
279
275
  logger=structlogger,
280
276
  log_module="SingleStepLLMCommandGenerator",
@@ -289,11 +285,10 @@ class SingleStepLLMCommandGenerator(LLMBasedCommandGenerator):
289
285
  message, SingleStepLLMCommandGenerator.__name__, commands
290
286
  )
291
287
  self._add_prompt_to_message_parse_data(
292
- message=message,
293
- component_name=SingleStepLLMCommandGenerator.__name__,
294
- prompt_name="command_generator_prompt",
295
- user_prompt=flow_prompt,
296
- llm_response=llm_response,
288
+ message,
289
+ SingleStepLLMCommandGenerator.__name__,
290
+ "command_generator_prompt",
291
+ flow_prompt,
297
292
  )
298
293
 
299
294
  return commands
@@ -1,14 +1,7 @@
1
1
  from contextlib import contextmanager
2
2
  from typing import Generator
3
3
 
4
- from rasa.dialogue_understanding.constants import (
5
- RASA_RECORD_COMMANDS_AND_PROMPTS_ENV_VAR_NAME,
6
- )
7
- from rasa.utils.common import get_bool_env_variable
8
-
9
- record_commands_and_prompts = get_bool_env_variable(
10
- RASA_RECORD_COMMANDS_AND_PROMPTS_ENV_VAR_NAME, False
11
- )
4
+ record_commands_and_prompts = False
12
5
 
13
6
 
14
7
  @contextmanager
@@ -0,0 +1,12 @@
1
+ from typing import List, Dict
2
+
3
+ from rasa.dialogue_understanding_test.du_test_result import (
4
+ DialogueUnderstandingTestResult,
5
+ )
6
+
7
+
8
+ def calculate_command_metrics(
9
+ test_results: List[DialogueUnderstandingTestResult],
10
+ ) -> Dict[str, Dict[str, float]]:
11
+ """Calculate the metrics for the commands."""
12
+ return {}
@@ -13,3 +13,5 @@ KEY_COMMANDS = "commands"
13
13
 
14
14
  ACTOR_USER = "user"
15
15
  ACTOR_BOT = "bot"
16
+
17
+ DEFAULT_INPUT_TESTS_PATH = "du_tests/"
@@ -0,0 +1,93 @@
1
+ import asyncio
2
+ from typing import Dict, Optional, Text, Union, List
3
+
4
+ import structlog
5
+
6
+ from rasa.core.exceptions import AgentNotReady
7
+ from rasa.core.persistor import StorageType
8
+ from rasa.core.utils import AvailableEndpoints
9
+ from rasa.dialogue_understanding_test.du_test_case import DialogueUnderstandingTestCase
10
+ from rasa.dialogue_understanding_test.du_test_result import (
11
+ DialogueUnderstandingTestResult,
12
+ )
13
+ from rasa.e2e_test.e2e_test_case import (
14
+ KEY_STUB_CUSTOM_ACTIONS,
15
+ ActualStepOutput,
16
+ TestStep,
17
+ Fixture,
18
+ Metadata,
19
+ )
20
+ from rasa.e2e_test.e2e_test_runner import E2ETestRunner
21
+ from rasa.utils.endpoints import EndpointConfig
22
+
23
+ structlogger = structlog.get_logger()
24
+
25
+ TEST_TURNS_TYPE = Dict[int, Union[TestStep, ActualStepOutput]]
26
+
27
+
28
+ class DialogueUnderstandingTestRunner:
29
+ """Dialogue Understanding test suite runner."""
30
+
31
+ def __init__(
32
+ self,
33
+ model_path: Optional[Text] = None,
34
+ model_server: Optional[EndpointConfig] = None,
35
+ remote_storage: Optional[StorageType] = None,
36
+ endpoints: Optional[AvailableEndpoints] = None,
37
+ ) -> None:
38
+ """Initializes the Dialogue Understanding test suite runner.
39
+
40
+ Args:
41
+ model_path: Path to the model.
42
+ model_server: Model server configuration.
43
+ remote_storage: Remote storage to use for model retrieval.
44
+ endpoints: Endpoints configuration.
45
+ """
46
+ import rasa.core.agent
47
+
48
+ self._check_action_server(endpoints)
49
+
50
+ self.agent = asyncio.run(
51
+ rasa.core.agent.load_agent(
52
+ model_path=model_path,
53
+ model_server=model_server,
54
+ remote_storage=remote_storage,
55
+ endpoints=endpoints,
56
+ )
57
+ )
58
+ if not self.agent.is_ready():
59
+ raise AgentNotReady(
60
+ "Agent needs to be prepared before usage. "
61
+ "Please check that the agent was able to "
62
+ "load the trained model."
63
+ )
64
+
65
+ def _check_action_server(self, endpoints: AvailableEndpoints) -> None:
66
+ """Check if the action server is reachable."""
67
+ are_custom_actions_stubbed = (
68
+ endpoints
69
+ and endpoints.action
70
+ and endpoints.action.kwargs.get(KEY_STUB_CUSTOM_ACTIONS)
71
+ )
72
+ if endpoints and not are_custom_actions_stubbed:
73
+ E2ETestRunner._action_server_is_reachable(
74
+ endpoints, "dialogue_understanding_test_runner"
75
+ )
76
+
77
+ async def run_tests(
78
+ self,
79
+ test_cases: List[DialogueUnderstandingTestCase],
80
+ fixtures: List[Fixture],
81
+ metadata: List[Metadata],
82
+ ) -> List[DialogueUnderstandingTestResult]:
83
+ """Run the dialogue understanding tests.
84
+
85
+ Args:
86
+ test_cases: List of test cases.
87
+ fixtures: List of fixtures.
88
+ metadata: List of metadata.
89
+
90
+ Returns:
91
+ List[DialogueUnderstandingTestResult]: List of test results.
92
+ """
93
+ return []
@@ -0,0 +1,54 @@
1
+ from typing import List, Dict
2
+
3
+ from rasa.dialogue_understanding_test.du_test_result import (
4
+ DialogueUnderstandingTestResult,
5
+ )
6
+ from rasa.e2e_test.e2e_test_case import TestSuite
7
+
8
+
9
+ def read_test_suite(test_case_path: str) -> TestSuite:
10
+ """Read the test cases from the given test case path.
11
+
12
+ Args:
13
+ test_case_path: Path to the test cases.
14
+
15
+ Returns:
16
+ TestSuite: Test suite containing the dialogue understanding test cases.
17
+ """
18
+ return TestSuite([], [], [], {})
19
+
20
+
21
+ def write_test_results_to_file(
22
+ failed_tests: List[DialogueUnderstandingTestResult],
23
+ passed_tests: List[DialogueUnderstandingTestResult],
24
+ command_metrics: Dict[str, Dict[str, float]],
25
+ output_file: str,
26
+ output_prompt: bool,
27
+ ) -> None:
28
+ """Write the test results to the given output file.
29
+
30
+ Args:
31
+ failed_tests: Failed test cases.
32
+ passed_tests: Passed test cases.
33
+ command_metrics: Metrics for the commands.
34
+ output_file: Path to the output file.
35
+ output_prompt: Whether to log the prompt or not.
36
+ """
37
+ pass
38
+
39
+
40
+ def print_test_results(
41
+ failed_tests: List[DialogueUnderstandingTestResult],
42
+ passed_tests: List[DialogueUnderstandingTestResult],
43
+ command_metrics: Dict[str, Dict[str, float]],
44
+ output_prompt: bool,
45
+ ) -> None:
46
+ """Print the test results to console.
47
+
48
+ Args:
49
+ failed_tests: Failed test cases.
50
+ passed_tests: Passed test cases.
51
+ command_metrics: Metrics for the commands.
52
+ output_prompt: Whether to log the prompt or not.
53
+ """
54
+ pass
@@ -0,0 +1,22 @@
1
+ import argparse
2
+ from typing import List
3
+
4
+ from rasa.dialogue_understanding_test.du_test_case import DialogueUnderstandingTestCase
5
+
6
+
7
+ def validate_cli_arguments(args: argparse.Namespace) -> None:
8
+ """Validate the CLI arguments for the dialogue understanding test.
9
+
10
+ Args:
11
+ args: Commandline arguments.
12
+ """
13
+ pass
14
+
15
+
16
+ def validate_test_cases(test_cases: List[DialogueUnderstandingTestCase]) -> None:
17
+ """Validate the dialogue understanding test cases.
18
+
19
+ Args:
20
+ test_cases: Test cases to validate.
21
+ """
22
+ pass