rasa-pro 3.10.13a1__py3-none-any.whl → 3.10.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic; see the package registry's release page for more details.

Files changed (41)
  1. rasa/api.py +1 -1
  2. rasa/cli/e2e_test.py +1 -1
  3. rasa/cli/evaluate.py +1 -1
  4. rasa/cli/llm_fine_tuning.py +1 -1
  5. rasa/cli/run.py +1 -1
  6. rasa/cli/studio/studio.py +18 -8
  7. rasa/cli/train.py +9 -0
  8. rasa/cli/x.py +1 -1
  9. rasa/core/policies/enterprise_search_policy.py +13 -1
  10. rasa/core/policies/flows/flow_executor.py +18 -8
  11. rasa/core/policies/intentless_policy.py +13 -1
  12. rasa/core/processor.py +7 -5
  13. rasa/core/training/interactive.py +1 -1
  14. rasa/core/utils.py +11 -1
  15. rasa/dialogue_understanding/coexistence/llm_based_router.py +8 -1
  16. rasa/dialogue_understanding/generator/flow_retrieval.py +7 -0
  17. rasa/dialogue_understanding/generator/llm_based_command_generator.py +1 -1
  18. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +8 -0
  19. rasa/dialogue_understanding/generator/nlu_command_adapter.py +19 -1
  20. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +8 -0
  21. rasa/e2e_test/aggregate_test_stats_calculator.py +11 -1
  22. rasa/e2e_test/assertions.py +48 -6
  23. rasa/e2e_test/e2e_test_runner.py +4 -3
  24. rasa/engine/validation.py +78 -1
  25. rasa/model_training.py +1 -0
  26. rasa/shared/constants.py +5 -0
  27. rasa/shared/core/flows/flows_list.py +5 -1
  28. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +6 -1
  29. rasa/shared/providers/llm/_base_litellm_client.py +5 -1
  30. rasa/shared/utils/llm.py +28 -7
  31. rasa/studio/auth.py +3 -5
  32. rasa/studio/config.py +13 -4
  33. rasa/studio/constants.py +1 -0
  34. rasa/studio/data_handler.py +10 -3
  35. rasa/studio/upload.py +17 -8
  36. rasa/version.py +1 -1
  37. {rasa_pro-3.10.13a1.dist-info → rasa_pro-3.10.15.dist-info}/METADATA +2 -2
  38. {rasa_pro-3.10.13a1.dist-info → rasa_pro-3.10.15.dist-info}/RECORD +41 -41
  39. {rasa_pro-3.10.13a1.dist-info → rasa_pro-3.10.15.dist-info}/NOTICE +0 -0
  40. {rasa_pro-3.10.13a1.dist-info → rasa_pro-3.10.15.dist-info}/WHEEL +0 -0
  41. {rasa_pro-3.10.13a1.dist-info → rasa_pro-3.10.15.dist-info}/entry_points.txt +0 -0
rasa/api.py CHANGED
@@ -41,7 +41,7 @@ def run(
41
41
  from rasa.shared.constants import DOCS_BASE_URL
42
42
  from rasa.shared.utils.cli import print_warning
43
43
 
44
- _endpoints = AvailableEndpoints.read_endpoints(endpoints)
44
+ _endpoints = AvailableEndpoints.get_instance(endpoints)
45
45
 
46
46
  if not connector and not credentials:
47
47
  connector = "rest"
rasa/cli/e2e_test.py CHANGED
@@ -164,7 +164,7 @@ def execute_e2e_tests(args: argparse.Namespace) -> None:
164
164
  args.endpoints = rasa.cli.utils.get_validated_path(
165
165
  args.endpoints, "endpoints", DEFAULT_ENDPOINTS_PATH, True
166
166
  )
167
- endpoints = AvailableEndpoints.read_endpoints(args.endpoints)
167
+ endpoints = AvailableEndpoints.get_instance(args.endpoints)
168
168
 
169
169
  # Ignore all endpoints apart from action server, model, nlu and nlg
170
170
  # to ensure InMemoryTrackerStore is being used instead of production
rasa/cli/evaluate.py CHANGED
@@ -217,6 +217,6 @@ def _create_tracker_loader(
217
217
  A MarkerTrackerLoader object configured with the specified strategy against
218
218
  the configured tracker store.
219
219
  """
220
- endpoints = AvailableEndpoints.read_endpoints(endpoint_config)
220
+ endpoints = AvailableEndpoints.get_instance(endpoint_config)
221
221
  tracker_store = TrackerStore.create(endpoints.tracker_store, domain=domain)
222
222
  return MarkerTrackerLoader(tracker_store, strategy, count, seed)
@@ -352,7 +352,7 @@ def get_valid_endpoints(endpoints_file: str) -> AvailableEndpoints:
352
352
  validated_endpoints_file = rasa.cli.utils.get_validated_path(
353
353
  endpoints_file, "endpoints", DEFAULT_ENDPOINTS_PATH, True
354
354
  )
355
- endpoints = AvailableEndpoints.read_endpoints(validated_endpoints_file)
355
+ endpoints = AvailableEndpoints.get_instance(validated_endpoints_file)
356
356
 
357
357
  # Ignore all endpoints apart from action server, model, nlu and nlg
358
358
  # to ensure InMemoryTrackerStore is being used instead of production
rasa/cli/run.py CHANGED
@@ -106,7 +106,7 @@ def run(args: argparse.Namespace) -> None:
106
106
  return
107
107
 
108
108
  # start server if model server is configured
109
- endpoints = AvailableEndpoints.read_endpoints(args.endpoints)
109
+ endpoints = AvailableEndpoints.get_instance(args.endpoints)
110
110
  model_server = endpoints.model if endpoints and endpoints.model else None
111
111
  if model_server is not None:
112
112
  rasa_run(**vars(args))
rasa/cli/studio/studio.py CHANGED
@@ -1,5 +1,5 @@
1
1
  import argparse
2
- from typing import List, Optional
2
+ from typing import List, Optional, Tuple
3
3
  from urllib.parse import ParseResult, urlparse
4
4
 
5
5
  import questionary
@@ -149,7 +149,7 @@ def _configure_studio_url() -> Optional[str]:
149
149
  return studio_url
150
150
 
151
151
 
152
- def _get_advanced_config(studio_url: str) -> tuple:
152
+ def _get_advanced_config(studio_url: str) -> Tuple:
153
153
  """Get the advanced configuration values for Rasa Studio."""
154
154
  keycloak_url = questionary.text(
155
155
  "Please provide your Rasa Studio Keycloak URL",
@@ -167,7 +167,7 @@ def _get_advanced_config(studio_url: str) -> tuple:
167
167
  return keycloak_url, realm_name, client_id
168
168
 
169
169
 
170
- def _get_default_config(studio_url: str) -> tuple:
170
+ def _get_default_config(studio_url: str) -> Tuple:
171
171
  """Get the default configuration values for Rasa Studio."""
172
172
  keycloak_url = studio_url + "auth/"
173
173
  realm_name = DEFAULT_REALM_NAME
@@ -178,6 +178,7 @@ def _get_default_config(studio_url: str) -> tuple:
178
178
  f"Keycloak URL: {keycloak_url}, "
179
179
  f"Realm Name: '{realm_name}', "
180
180
  f"Client ID: '{client_id}'. "
181
+ f"SSL verification is enabled."
181
182
  f"You can use '--advanced' to configure these settings."
182
183
  )
183
184
 
@@ -185,7 +186,11 @@ def _get_default_config(studio_url: str) -> tuple:
185
186
 
186
187
 
187
188
  def _create_studio_config(
188
- studio_url: str, keycloak_url: str, realm_name: str, client_id: str
189
+ studio_url: str,
190
+ keycloak_url: str,
191
+ realm_name: str,
192
+ client_id: str,
193
+ disable_verify: bool = False,
189
194
  ) -> StudioConfig:
190
195
  """Create a StudioConfig object with the provided parameters."""
191
196
  return StudioConfig(
@@ -193,6 +198,7 @@ def _create_studio_config(
193
198
  studio_url=studio_url + "api/graphql/",
194
199
  client_id=client_id,
195
200
  realm_name=realm_name,
201
+ disable_verify=disable_verify,
196
202
  )
197
203
 
198
204
 
@@ -227,19 +233,23 @@ def _configure_studio_config(args: argparse.Namespace) -> StudioConfig:
227
233
 
228
234
  # create a configuration and auth object to try to reach the studio
229
235
  studio_config = _create_studio_config(
230
- studio_url, keycloak_url, realm_name, client_id
236
+ studio_url,
237
+ keycloak_url,
238
+ realm_name,
239
+ client_id,
240
+ disable_verify=args.disable_verify,
231
241
  )
232
242
 
233
- if args.disable_verify:
243
+ if studio_config.disable_verify:
234
244
  rasa.shared.utils.cli.print_info(
235
245
  "Disabling SSL verification for the Rasa Studio authentication server."
236
246
  )
237
- studio_auth = StudioAuth(studio_config, verify=False)
238
247
  else:
239
248
  rasa.shared.utils.cli.print_info(
240
249
  "Enabling SSL verification for the Rasa Studio authentication server."
241
250
  )
242
- studio_auth = StudioAuth(studio_config, verify=True)
251
+
252
+ studio_auth = StudioAuth(studio_config)
243
253
 
244
254
  if _check_studio_auth(studio_auth):
245
255
  return studio_config
rasa/cli/train.py CHANGED
@@ -14,8 +14,10 @@ import rasa.core.utils
14
14
  import rasa.utils.common
15
15
  from rasa.api import train as train_all
16
16
  from rasa.cli import SubParsersAction
17
+ from rasa.core import ContextualResponseRephraser
17
18
  from rasa.core.nlg.generator import NaturalLanguageGenerator
18
19
  from rasa.core.train import do_compare_training
20
+ from rasa.engine.validation import validate_api_type_config_key_usage
19
21
  from rasa.nlu.persistor import get_persistor
20
22
  from rasa.shared.constants import (
21
23
  CONFIG_MANDATORY_KEYS,
@@ -23,6 +25,7 @@ from rasa.shared.constants import (
23
25
  CONFIG_MANDATORY_KEYS_NLU,
24
26
  DEFAULT_DATA_PATH,
25
27
  DEFAULT_DOMAIN_PATHS,
28
+ LLM_CONFIG_KEY,
26
29
  )
27
30
  from rasa.shared.exceptions import RasaException
28
31
  from rasa.shared.importers.importer import TrainingDataImporter
@@ -75,6 +78,12 @@ def add_subparser(
75
78
  def _check_nlg_endpoint_validity(endpoint: Union[Path, str]) -> None:
76
79
  try:
77
80
  endpoints = rasa.core.utils.read_endpoints_from_path(endpoint)
81
+ if endpoints.nlg is not None:
82
+ validate_api_type_config_key_usage(
83
+ endpoints.nlg.kwargs,
84
+ LLM_CONFIG_KEY,
85
+ ContextualResponseRephraser.__name__,
86
+ )
78
87
  NaturalLanguageGenerator.create(endpoints.nlg)
79
88
  except Exception as e:
80
89
  structlogger.error(
rasa/cli/x.py CHANGED
@@ -179,7 +179,7 @@ def run_in_enterprise_connection_mode(args: argparse.Namespace) -> None:
179
179
  print_success("Starting a Rasa server in Rasa Enterprise connection mode... 🚀")
180
180
 
181
181
  credentials_path, endpoints_path = _get_credentials_and_endpoints_paths(args)
182
- endpoints = AvailableEndpoints.read_endpoints(endpoints_path)
182
+ endpoints = AvailableEndpoints.get_instance(endpoints_path)
183
183
 
184
184
  _rasa_service(args, endpoints, None, credentials_path)
185
185
 
@@ -76,6 +76,7 @@ from rasa.shared.utils.llm import (
76
76
  sanitize_message_for_prompt,
77
77
  tracker_as_readable_transcript,
78
78
  try_instantiate_llm_client,
79
+ try_instantiate_embedder,
79
80
  )
80
81
  from rasa.core.information_retrieval.faiss import FAISS_Store
81
82
  from rasa.core.information_retrieval import (
@@ -661,6 +662,18 @@ class EnterpriseSearchPolicy(Policy):
661
662
  execution_context: ExecutionContext,
662
663
  **kwargs: Any,
663
664
  ) -> "EnterpriseSearchPolicy":
665
+ try_instantiate_llm_client(
666
+ config.get(LLM_CONFIG_KEY),
667
+ DEFAULT_LLM_CONFIG,
668
+ "enterprise_search_policy.load",
669
+ EnterpriseSearchPolicy.__name__,
670
+ )
671
+ try_instantiate_embedder(
672
+ config.get(EMBEDDINGS_CONFIG_KEY),
673
+ DEFAULT_EMBEDDINGS_CONFIG,
674
+ "enterprise_search_policy.load",
675
+ EnterpriseSearchPolicy.__name__,
676
+ )
664
677
  """Loads a trained policy (see parent class for full docstring)."""
665
678
  prompt_template = None
666
679
  store_type = config.get(VECTOR_STORE_PROPERTY, {}).get(
@@ -695,7 +708,6 @@ class EnterpriseSearchPolicy(Policy):
695
708
  logger.warning(
696
709
  "enterprise_search_policy.load.failed", error=e, resource=resource.name
697
710
  )
698
-
699
711
  return cls(
700
712
  config,
701
713
  model_storage,
@@ -484,7 +484,8 @@ def validate_collect_step(
484
484
  step: CollectInformationFlowStep,
485
485
  stack: DialogueStack,
486
486
  available_actions: List[str],
487
- slots: Dict[Text, Slot],
487
+ slots: Dict[str, Slot],
488
+ flow_name: str,
488
489
  ) -> bool:
489
490
  """Validate that a collect step can be executed.
490
491
 
@@ -507,12 +508,12 @@ def validate_collect_step(
507
508
  slot_name=step.collect,
508
509
  )
509
510
 
510
- cancel_flow_and_push_internal_error(stack)
511
+ cancel_flow_and_push_internal_error(stack, flow_name)
511
512
 
512
513
  return False
513
514
 
514
515
 
515
- def cancel_flow_and_push_internal_error(stack: DialogueStack) -> None:
516
+ def cancel_flow_and_push_internal_error(stack: DialogueStack, flow_name: str) -> None:
516
517
  """Cancel the top user flow and push the internal error pattern."""
517
518
  top_frame = stack.top()
518
519
 
@@ -524,7 +525,7 @@ def cancel_flow_and_push_internal_error(stack: DialogueStack) -> None:
524
525
  canceled_frames = CancelFlowCommand.select_canceled_frames(stack)
525
526
  stack.push(
526
527
  CancelPatternFlowStackFrame(
527
- canceled_name=top_frame.flow_id,
528
+ canceled_name=flow_name,
528
529
  canceled_frames=canceled_frames,
529
530
  )
530
531
  )
@@ -536,6 +537,7 @@ def validate_custom_slot_mappings(
536
537
  stack: DialogueStack,
537
538
  tracker: DialogueStateTracker,
538
539
  available_actions: List[str],
540
+ flow_name: str,
539
541
  ) -> bool:
540
542
  """Validate a slot with custom mappings.
541
543
 
@@ -556,7 +558,7 @@ def validate_custom_slot_mappings(
556
558
  action=step.collect_action,
557
559
  collect=step.collect,
558
560
  )
559
- cancel_flow_and_push_internal_error(stack)
561
+ cancel_flow_and_push_internal_error(stack, flow_name)
560
562
  return False
561
563
 
562
564
  return True
@@ -596,7 +598,12 @@ def run_step(
596
598
 
597
599
  if isinstance(step, CollectInformationFlowStep):
598
600
  return _run_collect_information_step(
599
- available_actions, initial_events, stack, step, tracker
601
+ available_actions,
602
+ initial_events,
603
+ stack,
604
+ step,
605
+ tracker,
606
+ flow.readable_name(),
600
607
  )
601
608
 
602
609
  elif isinstance(step, ActionFlowStep):
@@ -716,15 +723,18 @@ def _run_collect_information_step(
716
723
  stack: DialogueStack,
717
724
  step: CollectInformationFlowStep,
718
725
  tracker: DialogueStateTracker,
726
+ flow_name: str,
719
727
  ) -> FlowStepResult:
720
- is_step_valid = validate_collect_step(step, stack, available_actions, tracker.slots)
728
+ is_step_valid = validate_collect_step(
729
+ step, stack, available_actions, tracker.slots, flow_name
730
+ )
721
731
 
722
732
  if not is_step_valid:
723
733
  # if we return any other FlowStepResult, the assistant will stay silent
724
734
  # instead of triggering the internal error pattern
725
735
  return ContinueFlowWithNextStep(events=initial_events)
726
736
  is_mapping_valid = validate_custom_slot_mappings(
727
- step, stack, tracker, available_actions
737
+ step, stack, tracker, available_actions, flow_name
728
738
  )
729
739
 
730
740
  if not is_mapping_valid:
@@ -71,6 +71,7 @@ from rasa.shared.utils.llm import (
71
71
  sanitize_message_for_prompt,
72
72
  tracker_as_readable_transcript,
73
73
  try_instantiate_llm_client,
74
+ try_instantiate_embedder,
74
75
  )
75
76
  from rasa.utils.ml_utils import (
76
77
  extract_ai_response_examples,
@@ -918,6 +919,18 @@ class IntentlessPolicy(Policy):
918
919
  **kwargs: Any,
919
920
  ) -> "IntentlessPolicy":
920
921
  """Loads a trained policy (see parent class for full docstring)."""
922
+ try_instantiate_llm_client(
923
+ config.get(LLM_CONFIG_KEY),
924
+ DEFAULT_LLM_CONFIG,
925
+ "intentless_policy.load",
926
+ IntentlessPolicy.__name__,
927
+ )
928
+ try_instantiate_embedder(
929
+ config.get(EMBEDDINGS_CONFIG_KEY),
930
+ DEFAULT_EMBEDDINGS_CONFIG,
931
+ "intentless_policy.load",
932
+ IntentlessPolicy.__name__,
933
+ )
921
934
  responses_docsearch = None
922
935
  samples_docsearch = None
923
936
  prompt_template = None
@@ -943,7 +956,6 @@ class IntentlessPolicy(Policy):
943
956
  structlogger.warning(
944
957
  "intentless_policy.load.failed", error=e, resource_name=resource.name
945
958
  )
946
-
947
959
  return cls(
948
960
  config,
949
961
  model_storage,
rasa/core/processor.py CHANGED
@@ -1254,11 +1254,13 @@ class MessageProcessor:
1254
1254
  tracker.update(events[0])
1255
1255
  return self.should_predict_another_action(action.name())
1256
1256
  except Exception:
1257
- logger.exception(
1258
- f"Encountered an exception while running action '{action.name()}'."
1259
- "Bot will continue, but the actions events are lost. "
1260
- "Please check the logs of your action server for "
1261
- "more information."
1257
+ structlogger.exception(
1258
+ "rasa.core.processor.run_action.exception",
1259
+ event_info=f"Encountered an exception while "
1260
+ f"running action '{action.name()}'."
1261
+ f"Bot will continue, but the actions events are lost. "
1262
+ f"Please check the logs of your action server for "
1263
+ f"more information.",
1262
1264
  )
1263
1265
  events = []
1264
1266
 
@@ -1688,7 +1688,7 @@ def run_interactive_learning(
1688
1688
  p = None
1689
1689
 
1690
1690
  app = run.configure_app(port=port, conversation_id="default", enable_api=True)
1691
- endpoints = AvailableEndpoints.read_endpoints(server_args.get("endpoints"))
1691
+ endpoints = AvailableEndpoints.get_instance(server_args.get("endpoints"))
1692
1692
 
1693
1693
  # before_server_start handlers make sure the agent is loaded before the
1694
1694
  # interactive learning IO starts
rasa/core/utils.py CHANGED
@@ -171,6 +171,8 @@ def is_limit_reached(num_messages: int, limit: Optional[int]) -> bool:
171
171
  class AvailableEndpoints:
172
172
  """Collection of configured endpoints."""
173
173
 
174
+ _instance = None
175
+
174
176
  @classmethod
175
177
  def read_endpoints(cls, endpoint_file: Text) -> "AvailableEndpoints":
176
178
  """Read the different endpoints from a yaml file."""
@@ -217,6 +219,14 @@ class AvailableEndpoints:
217
219
  self.event_broker = event_broker
218
220
  self.vector_store = vector_store
219
221
 
222
+ @classmethod
223
+ def get_instance(cls, endpoint_file: Optional[Text] = None) -> "AvailableEndpoints":
224
+ """Get the singleton instance of AvailableEndpoints."""
225
+ # Ensure that the instance is initialized only once.
226
+ if cls._instance is None:
227
+ cls._instance = cls.read_endpoints(endpoint_file)
228
+ return cls._instance
229
+
220
230
 
221
231
  def read_endpoints_from_path(
222
232
  endpoints_path: Optional[Union[Path, Text]] = None,
@@ -234,7 +244,7 @@ def read_endpoints_from_path(
234
244
  endpoints_config_path = cli_utils.get_validated_path(
235
245
  endpoints_path, "endpoints", DEFAULT_ENDPOINTS_PATH, True
236
246
  )
237
- return AvailableEndpoints.read_endpoints(endpoints_config_path)
247
+ return AvailableEndpoints.get_instance(endpoints_config_path)
238
248
 
239
249
 
240
250
  def _lock_store_is_multi_worker_compatible(
@@ -161,7 +161,14 @@ class LLMBasedRouter(GraphComponent):
161
161
  "llm_based_router.load.failed", error=e, resource=resource.name
162
162
  )
163
163
 
164
- return cls(config, model_storage, resource, prompt_template=prompt_template)
164
+ router = cls(config, model_storage, resource, prompt_template=prompt_template)
165
+ try_instantiate_llm_client(
166
+ router.config.get(LLM_CONFIG_KEY),
167
+ DEFAULT_LLM_CONFIG,
168
+ "llm_based_router.load",
169
+ LLMBasedRouter.__name__,
170
+ )
171
+ return router
165
172
 
166
173
  @classmethod
167
174
  def create(
@@ -50,6 +50,7 @@ from rasa.shared.utils.llm import (
50
50
  USER,
51
51
  get_prompt_template,
52
52
  allowed_values_for_slot,
53
+ try_instantiate_embedder,
53
54
  )
54
55
 
55
56
  DEFAULT_FLOW_DOCUMENT_TEMPLATE = importlib.resources.read_text(
@@ -142,6 +143,12 @@ class FlowRetrieval:
142
143
  """Load flow retrieval with previously populated FAISS vector store."""
143
144
  # initialize base flow retrieval
144
145
  flow_retrieval = FlowRetrieval(config, model_storage, resource)
146
+ try_instantiate_embedder(
147
+ flow_retrieval.config.get(EMBEDDINGS_CONFIG_KEY),
148
+ DEFAULT_EMBEDDINGS_CONFIG,
149
+ "flow_retrieval.load",
150
+ FlowRetrieval.__name__,
151
+ )
145
152
  # load vector store
146
153
  vector_store = cls._load_vector_store(
147
154
  flow_retrieval.config, model_storage, resource
@@ -183,7 +183,7 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
183
183
  except Exception as e:
184
184
  structlogger.error(
185
185
  "llm_based_command_generator.train.failed",
186
- event_info=("Flow retrieval store isinaccessible."),
186
+ event_info="Flow retrieval store is inaccessible.",
187
187
  error=e,
188
188
  )
189
189
  raise
@@ -24,6 +24,7 @@ from rasa.dialogue_understanding.generator.constants import (
24
24
  LLM_CONFIG_KEY,
25
25
  USER_INPUT_CONFIG_KEY,
26
26
  FLOW_RETRIEVAL_KEY,
27
+ DEFAULT_LLM_CONFIG,
27
28
  )
28
29
  from rasa.dialogue_understanding.generator.flow_retrieval import FlowRetrieval
29
30
  from rasa.dialogue_understanding.generator.llm_based_command_generator import (
@@ -53,6 +54,7 @@ from rasa.shared.utils.llm import (
53
54
  tracker_as_readable_transcript,
54
55
  sanitize_message_for_prompt,
55
56
  allowed_values_for_slot,
57
+ try_instantiate_llm_client,
56
58
  )
57
59
 
58
60
  # multistep template keys
@@ -141,6 +143,12 @@ class MultiStepLLMCommandGenerator(LLMBasedCommandGenerator):
141
143
  prompts = cls._load_prompt_templates(model_storage, resource)
142
144
  # init base command generator
143
145
  command_generator = cls(config, model_storage, resource, prompts)
146
+ try_instantiate_llm_client(
147
+ command_generator.config.get(LLM_CONFIG_KEY),
148
+ DEFAULT_LLM_CONFIG,
149
+ "multi_step_llm_command_generator.load",
150
+ MultiStepLLMCommandGenerator.__name__,
151
+ )
144
152
  # load flow retrieval if enabled
145
153
  if command_generator.enabled_flow_retrieval:
146
154
  command_generator.flow_retrieval = cls.load_flow_retrival(
@@ -19,6 +19,7 @@ from rasa.engine.storage.storage import ModelStorage
19
19
  from rasa.shared.constants import ROUTE_TO_CALM_SLOT
20
20
  from rasa.shared.core.domain import Domain
21
21
  from rasa.shared.core.flows.flows_list import FlowsList
22
+ from rasa.shared.core.flows.steps import CollectInformationFlowStep
22
23
  from rasa.shared.core.slot_mappings import (
23
24
  SlotFillingManager,
24
25
  extract_slot_value,
@@ -217,7 +218,24 @@ def _issue_set_slot_commands(
217
218
  commands: List[Command] = []
218
219
  domain = domain if domain else Domain.empty()
219
220
  slot_filling_manager = SlotFillingManager(domain, tracker, message)
220
- available_slot_names = flows.available_slot_names()
221
+
222
+ # only use slots that don't have ask_before_filling set to True
223
+ available_slot_names = flows.available_slot_names(ask_before_filling=False)
224
+
225
+ # check if the current step is a CollectInformationFlowStep
226
+ # in case it has ask_before_filling set to True, we need to add the
227
+ # slot to the available_slot_names
228
+ if tracker.active_flow:
229
+ flow = flows.flow_by_id(tracker.active_flow)
230
+ step_id = tracker.current_step_id
231
+ if flow is not None:
232
+ current_step = flow.step_by_id(step_id)
233
+ if (
234
+ current_step
235
+ and isinstance(current_step, CollectInformationFlowStep)
236
+ and current_step.ask_before_filling
237
+ ):
238
+ available_slot_names.add(current_step.collect)
221
239
 
222
240
  for _, slot in tracker.slots.items():
223
241
  # if a slot is not collected in available flows,
@@ -21,6 +21,7 @@ from rasa.dialogue_understanding.generator.constants import (
21
21
  LLM_CONFIG_KEY,
22
22
  USER_INPUT_CONFIG_KEY,
23
23
  FLOW_RETRIEVAL_KEY,
24
+ DEFAULT_LLM_CONFIG,
24
25
  )
25
26
  from rasa.dialogue_understanding.generator.flow_retrieval import (
26
27
  FlowRetrieval,
@@ -48,6 +49,7 @@ from rasa.shared.utils.llm import (
48
49
  get_prompt_template,
49
50
  tracker_as_readable_transcript,
50
51
  sanitize_message_for_prompt,
52
+ try_instantiate_llm_client,
51
53
  )
52
54
  from rasa.utils.log_utils import log_llm
53
55
 
@@ -136,6 +138,12 @@ class SingleStepLLMCommandGenerator(LLMBasedCommandGenerator):
136
138
  )
137
139
  # init base command generator
138
140
  command_generator = cls(config, model_storage, resource, prompt_template)
141
+ try_instantiate_llm_client(
142
+ command_generator.config.get(LLM_CONFIG_KEY),
143
+ DEFAULT_LLM_CONFIG,
144
+ "single_step_llm_command_generator.load",
145
+ SingleStepLLMCommandGenerator.__name__,
146
+ )
139
147
  # load flow retrieval if enabled
140
148
  if command_generator.enabled_flow_retrieval:
141
149
  command_generator.flow_retrieval = cls.load_flow_retrival(
@@ -35,6 +35,7 @@ class AggregateTestStatsCalculator:
35
35
  self.test_cases = test_cases
36
36
 
37
37
  self.failed_assertion_set: Set["Assertion"] = set()
38
+ self.failed_test_cases_without_assertion_failure: Set[str] = set()
38
39
  self.passed_count_mapping = {
39
40
  subclass_type: 0
40
41
  for subclass_type in _get_all_assertion_subclasses().keys()
@@ -89,8 +90,14 @@ class AggregateTestStatsCalculator:
89
90
  passed_test_case_names = [
90
91
  passed.test_case.name for passed in self.passed_results
91
92
  ]
93
+ # We filter out test cases that failed without an assertion failure
94
+ filtered_test_cases = [
95
+ test_case
96
+ for test_case in self.test_cases
97
+ if test_case.name not in self.failed_test_cases_without_assertion_failure
98
+ ]
92
99
 
93
- for test_case in self.test_cases:
100
+ for test_case in filtered_test_cases:
94
101
  if test_case.name in passed_test_case_names:
95
102
  for step in test_case.steps:
96
103
  if step.assertions is None:
@@ -118,6 +125,9 @@ class AggregateTestStatsCalculator:
118
125
  "no_assertion_failure_in_failed_result",
119
126
  test_case=failed.test_case.name,
120
127
  )
128
+ self.failed_test_cases_without_assertion_failure.add(
129
+ failed.test_case.name
130
+ )
121
131
  continue
122
132
 
123
133
  self.failed_assertion_set.add(failed.assertion_failure.assertion)