rasa-pro 3.12.6.dev2__py3-none-any.whl → 3.13.0.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (92)
  1. rasa/__init__.py +0 -6
  2. rasa/cli/scaffold.py +1 -1
  3. rasa/core/actions/action.py +38 -34
  4. rasa/core/actions/action_run_slot_rejections.py +1 -1
  5. rasa/core/channels/studio_chat.py +16 -43
  6. rasa/core/channels/voice_ready/audiocodes.py +46 -17
  7. rasa/core/information_retrieval/faiss.py +68 -7
  8. rasa/core/information_retrieval/information_retrieval.py +40 -2
  9. rasa/core/information_retrieval/milvus.py +7 -2
  10. rasa/core/information_retrieval/qdrant.py +7 -2
  11. rasa/core/nlg/contextual_response_rephraser.py +11 -27
  12. rasa/core/nlg/generator.py +5 -21
  13. rasa/core/nlg/response.py +6 -43
  14. rasa/core/nlg/summarize.py +1 -15
  15. rasa/core/nlg/translate.py +0 -8
  16. rasa/core/policies/enterprise_search_policy.py +64 -316
  17. rasa/core/policies/flows/flow_executor.py +3 -38
  18. rasa/core/policies/intentless_policy.py +4 -17
  19. rasa/core/policies/policy.py +0 -2
  20. rasa/core/processor.py +27 -6
  21. rasa/core/utils.py +53 -0
  22. rasa/dialogue_understanding/coexistence/llm_based_router.py +4 -18
  23. rasa/dialogue_understanding/commands/cancel_flow_command.py +4 -59
  24. rasa/dialogue_understanding/commands/knowledge_answer_command.py +2 -2
  25. rasa/dialogue_understanding/commands/start_flow_command.py +0 -41
  26. rasa/dialogue_understanding/generator/command_generator.py +67 -0
  27. rasa/dialogue_understanding/generator/command_parser.py +1 -1
  28. rasa/dialogue_understanding/generator/llm_based_command_generator.py +7 -23
  29. rasa/dialogue_understanding/generator/llm_command_generator.py +1 -3
  30. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_template.jinja2 +1 -1
  31. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +1 -1
  32. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +24 -2
  33. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +8 -12
  34. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +0 -61
  35. rasa/dialogue_understanding/processor/command_processor.py +7 -65
  36. rasa/dialogue_understanding/stack/utils.py +0 -38
  37. rasa/dialogue_understanding_test/command_metric_calculation.py +7 -40
  38. rasa/dialogue_understanding_test/command_metrics.py +38 -0
  39. rasa/dialogue_understanding_test/du_test_case.py +58 -25
  40. rasa/dialogue_understanding_test/du_test_result.py +228 -132
  41. rasa/dialogue_understanding_test/du_test_runner.py +10 -1
  42. rasa/dialogue_understanding_test/io.py +48 -16
  43. rasa/document_retrieval/__init__.py +0 -0
  44. rasa/document_retrieval/constants.py +32 -0
  45. rasa/document_retrieval/document_post_processor.py +351 -0
  46. rasa/document_retrieval/document_post_processor_prompt_template.jinja2 +0 -0
  47. rasa/document_retrieval/document_retriever.py +333 -0
  48. rasa/document_retrieval/knowledge_base_connectors/__init__.py +0 -0
  49. rasa/document_retrieval/knowledge_base_connectors/api_connector.py +39 -0
  50. rasa/document_retrieval/knowledge_base_connectors/knowledge_base_connector.py +34 -0
  51. rasa/document_retrieval/knowledge_base_connectors/vector_store_connector.py +226 -0
  52. rasa/document_retrieval/query_rewriter.py +234 -0
  53. rasa/document_retrieval/query_rewriter_prompt_template.jinja2 +8 -0
  54. rasa/engine/recipes/default_components.py +2 -0
  55. rasa/hooks.py +0 -55
  56. rasa/model_manager/model_api.py +1 -1
  57. rasa/model_manager/socket_bridge.py +0 -7
  58. rasa/shared/constants.py +0 -5
  59. rasa/shared/core/constants.py +0 -8
  60. rasa/shared/core/domain.py +12 -3
  61. rasa/shared/core/flows/flow.py +0 -17
  62. rasa/shared/core/flows/flows_yaml_schema.json +3 -38
  63. rasa/shared/core/flows/steps/collect.py +5 -18
  64. rasa/shared/core/flows/utils.py +1 -16
  65. rasa/shared/core/slot_mappings.py +11 -5
  66. rasa/shared/core/slots.py +1 -1
  67. rasa/shared/core/trackers.py +4 -10
  68. rasa/shared/nlu/constants.py +0 -1
  69. rasa/shared/providers/constants.py +0 -9
  70. rasa/shared/providers/llm/_base_litellm_client.py +4 -14
  71. rasa/shared/providers/llm/default_litellm_llm_client.py +2 -2
  72. rasa/shared/providers/llm/litellm_router_llm_client.py +7 -17
  73. rasa/shared/providers/llm/llm_client.py +15 -24
  74. rasa/shared/providers/llm/self_hosted_llm_client.py +2 -10
  75. rasa/shared/utils/common.py +11 -1
  76. rasa/shared/utils/health_check/health_check.py +1 -7
  77. rasa/shared/utils/llm.py +1 -1
  78. rasa/tracing/instrumentation/attribute_extractors.py +50 -17
  79. rasa/tracing/instrumentation/instrumentation.py +12 -12
  80. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +1 -2
  81. rasa/utils/licensing.py +0 -15
  82. rasa/validator.py +1 -123
  83. rasa/version.py +1 -1
  84. {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.13.0.dev2.dist-info}/METADATA +2 -3
  85. {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.13.0.dev2.dist-info}/RECORD +88 -80
  86. rasa/core/actions/action_handle_digressions.py +0 -164
  87. rasa/dialogue_understanding/commands/handle_digressions_command.py +0 -144
  88. rasa/dialogue_understanding/patterns/handle_digressions.py +0 -81
  89. rasa/monkey_patches.py +0 -91
  90. {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.13.0.dev2.dist-info}/NOTICE +0 -0
  91. {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.13.0.dev2.dist-info}/WHEEL +0 -0
  92. {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.13.0.dev2.dist-info}/entry_points.txt +0 -0
@@ -52,8 +52,6 @@ ACTION_TRIGGER_CHITCHAT = "action_trigger_chitchat"
52
52
  ACTION_RESET_ROUTING = "action_reset_routing"
53
53
  ACTION_HANGUP = "action_hangup"
54
54
  ACTION_REPEAT_BOT_MESSAGES = "action_repeat_bot_messages"
55
- ACTION_BLOCK_DIGRESSION = "action_block_digression"
56
- ACTION_CONTINUE_DIGRESSION = "action_continue_digression"
57
55
 
58
56
  ACTION_METADATA_EXECUTION_SUCCESS = "execution_success"
59
57
  ACTION_METADATA_EXECUTION_ERROR_MESSAGE = "execution_error_message"
@@ -84,8 +82,6 @@ DEFAULT_ACTION_NAMES = [
84
82
  ACTION_RESET_ROUTING,
85
83
  ACTION_HANGUP,
86
84
  ACTION_REPEAT_BOT_MESSAGES,
87
- ACTION_BLOCK_DIGRESSION,
88
- ACTION_CONTINUE_DIGRESSION,
89
85
  ]
90
86
 
91
87
  ACTION_SHOULD_SEND_DOMAIN = "send_domain"
@@ -205,8 +201,4 @@ CLASSIFIER_NAME_FALLBACK = "FallbackClassifier"
205
201
 
206
202
  POLICIES_THAT_EXTRACT_ENTITIES = {"TEDPolicy"}
207
203
 
208
- # digression constants
209
- KEY_ASK_CONFIRM_DIGRESSIONS = "ask_confirm_digressions"
210
- KEY_BLOCK_DIGRESSIONS = "block_digressions"
211
-
212
204
  ERROR_CODE_KEY = "error_code"
@@ -1678,6 +1678,14 @@ class Domain:
1678
1678
  """Write domain to a file."""
1679
1679
  as_yaml = self.as_yaml()
1680
1680
  rasa.shared.utils.io.write_text_file(as_yaml, filename)
1681
+ # run the check again on the written domain to catch any errors
1682
+ # that may have been missed in the user defined domain files
1683
+ structlogger.info(
1684
+ "domain.persist.domain_written_to_file",
1685
+ event_info="The entire domain content has been written to file.",
1686
+ filename=filename,
1687
+ )
1688
+ Domain.is_domain_file(filename)
1681
1689
 
1682
1690
  def as_yaml(self) -> Text:
1683
1691
  """Dump the `Domain` object as a YAML string.
@@ -1972,17 +1980,18 @@ class Domain:
1972
1980
 
1973
1981
  try:
1974
1982
  content = read_yaml_file(filename, expand_env_vars=cls.expand_env_vars)
1975
- except (RasaException, YamlSyntaxException):
1976
- structlogger.warning(
1983
+ except (RasaException, YamlSyntaxException) as error:
1984
+ structlogger.error(
1977
1985
  "domain.cannot_load_domain_file",
1978
1986
  file=filename,
1987
+ error=error,
1979
1988
  event_info=(
1980
1989
  f"The file {filename} could not be loaded as domain file. "
1981
1990
  f"You can use https://yamlchecker.com/ to validate "
1982
1991
  f"the YAML syntax of your file."
1983
1992
  ),
1984
1993
  )
1985
- return False
1994
+ raise RasaException(f"Domain could not be loaded: {error}")
1986
1995
 
1987
1996
  return any(key in content for key in ALL_DOMAIN_KEYS)
1988
1997
 
@@ -13,10 +13,6 @@ from pypred import Predicate
13
13
  import rasa.shared.utils.io
14
14
  from rasa.engine.language import Language
15
15
  from rasa.shared.constants import RASA_DEFAULT_FLOW_PATTERN_PREFIX
16
- from rasa.shared.core.constants import (
17
- KEY_ASK_CONFIRM_DIGRESSIONS,
18
- KEY_BLOCK_DIGRESSIONS,
19
- )
20
16
  from rasa.shared.core.flows.constants import (
21
17
  KEY_ALWAYS_INCLUDE_IN_PROMPT,
22
18
  KEY_DESCRIPTION,
@@ -52,7 +48,6 @@ from rasa.shared.core.flows.steps.constants import (
52
48
  START_STEP,
53
49
  )
54
50
  from rasa.shared.core.flows.steps.continuation import ContinueFlowStep
55
- from rasa.shared.core.flows.utils import extract_digression_prop
56
51
  from rasa.shared.core.slots import Slot
57
52
 
58
53
  structlogger = structlog.get_logger()
@@ -94,10 +89,6 @@ class Flow:
94
89
  """The path to the file where the flow is stored."""
95
90
  persisted_slots: List[str] = field(default_factory=list)
96
91
  """The list of slots that should be persisted after the flow ends."""
97
- ask_confirm_digressions: List[str] = field(default_factory=list)
98
- """The flow ids for which the assistant should ask for confirmation."""
99
- block_digressions: List[str] = field(default_factory=list)
100
- """The flow ids that the assistant should block from digressing to."""
101
92
  run_pattern_completed: bool = True
102
93
  """Whether the pattern_completed flow should be run after the flow ends."""
103
94
 
@@ -138,10 +129,6 @@ class Flow:
138
129
  # data. When the model is trained, take the provided file_path.
139
130
  file_path=data.get(KEY_FILE_PATH) if KEY_FILE_PATH in data else file_path,
140
131
  persisted_slots=data.get(KEY_PERSISTED_SLOTS, []),
141
- ask_confirm_digressions=extract_digression_prop(
142
- KEY_ASK_CONFIRM_DIGRESSIONS, data
143
- ),
144
- block_digressions=extract_digression_prop(KEY_BLOCK_DIGRESSIONS, data),
145
132
  run_pattern_completed=data.get(KEY_RUN_PATTERN_COMPLETED, True),
146
133
  translation=extract_translations(
147
134
  translation_data=data.get(KEY_TRANSLATION, {})
@@ -220,10 +207,6 @@ class Flow:
220
207
  data[KEY_FILE_PATH] = self.file_path
221
208
  if self.persisted_slots:
222
209
  data[KEY_PERSISTED_SLOTS] = self.persisted_slots
223
- if self.ask_confirm_digressions:
224
- data[KEY_ASK_CONFIRM_DIGRESSIONS] = self.ask_confirm_digressions
225
- if self.block_digressions:
226
- data[KEY_BLOCK_DIGRESSIONS] = self.block_digressions
227
210
  if self.run_pattern_completed is not None:
228
211
  data["run_pattern_completed"] = self.run_pattern_completed
229
212
  if self.translation:
@@ -217,15 +217,12 @@
217
217
  "reset_after_flow_ends": {
218
218
  "type": "boolean"
219
219
  },
220
- "ask_confirm_digressions": {
221
- "$ref": "#/$defs/ask_confirm_digressions"
222
- },
223
- "block_digressions": {
224
- "$ref": "#/$defs/block_digressions"
225
- },
226
220
  "utter": {
227
221
  "type": "string"
228
222
  },
223
+ "force_slot_filling": {
224
+ "type": "boolean"
225
+ },
229
226
  "rejections": {
230
227
  "type": "array",
231
228
  "schema_name": "list of rejections",
@@ -253,32 +250,6 @@
253
250
  }
254
251
  }
255
252
  },
256
- "ask_confirm_digressions": {
257
- "oneOf": [
258
- {
259
- "type": "boolean"
260
- },
261
- {
262
- "type": "array",
263
- "items": {
264
- "type": "string"
265
- }
266
- }
267
- ]
268
- },
269
- "block_digressions": {
270
- "oneOf": [
271
- {
272
- "type": "boolean"
273
- },
274
- {
275
- "type": "array",
276
- "items": {
277
- "type": "string"
278
- }
279
- }
280
- ]
281
- },
282
253
  "flow": {
283
254
  "required": [
284
255
  "steps",
@@ -340,12 +311,6 @@
340
311
  "persisted_slots": {
341
312
  "$ref": "#/$defs/persisted_slots"
342
313
  },
343
- "ask_confirm_digressions": {
344
- "$ref": "#/$defs/ask_confirm_digressions"
345
- },
346
- "block_digressions": {
347
- "$ref": "#/$defs/block_digressions"
348
- },
349
314
  "run_pattern_completed": {
350
315
  "type": "boolean"
351
316
  }
@@ -1,15 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
- from dataclasses import dataclass, field
3
+ from dataclasses import dataclass
4
4
  from typing import Any, Dict, List, Set, Text
5
5
 
6
6
  from rasa.shared.constants import ACTION_ASK_PREFIX, UTTER_ASK_PREFIX
7
- from rasa.shared.core.constants import (
8
- KEY_ASK_CONFIRM_DIGRESSIONS,
9
- KEY_BLOCK_DIGRESSIONS,
10
- )
11
7
  from rasa.shared.core.flows.flow_step import FlowStep
12
- from rasa.shared.core.flows.utils import extract_digression_prop
13
8
  from rasa.shared.core.slots import SlotRejection
14
9
 
15
10
 
@@ -29,10 +24,8 @@ class CollectInformationFlowStep(FlowStep):
29
24
  """Whether to always ask the question even if the slot is already filled."""
30
25
  reset_after_flow_ends: bool = True
31
26
  """Whether to reset the slot value at the end of the flow."""
32
- ask_confirm_digressions: List[str] = field(default_factory=list)
33
- """The flow id digressions for which the assistant should ask for confirmation."""
34
- block_digressions: List[str] = field(default_factory=list)
35
- """The flow id digressions that should be blocked during the flow step."""
27
+ force_slot_filling: bool = False
28
+ """Whether to keep only the SetSlot command for the collected slot."""
36
29
 
37
30
  @classmethod
38
31
  def from_json(
@@ -60,10 +53,7 @@ class CollectInformationFlowStep(FlowStep):
60
53
  SlotRejection.from_dict(rejection)
61
54
  for rejection in data.get("rejections", [])
62
55
  ],
63
- ask_confirm_digressions=extract_digression_prop(
64
- KEY_ASK_CONFIRM_DIGRESSIONS, data
65
- ),
66
- block_digressions=extract_digression_prop(KEY_BLOCK_DIGRESSIONS, data),
56
+ force_slot_filling=data.get("force_slot_filling", False),
67
57
  **base.__dict__,
68
58
  )
69
59
 
@@ -79,10 +69,7 @@ class CollectInformationFlowStep(FlowStep):
79
69
  data["ask_before_filling"] = self.ask_before_filling
80
70
  data["reset_after_flow_ends"] = self.reset_after_flow_ends
81
71
  data["rejections"] = [rejection.as_dict() for rejection in self.rejections]
82
- data["ask_confirm_digressions"] = self.ask_confirm_digressions
83
- data["block_digressions"] = (
84
- self.block_digressions if self.block_digressions else False
85
- )
72
+ data["force_slot_filling"] = self.force_slot_filling
86
73
 
87
74
  return data
88
75
 
@@ -1,4 +1,4 @@
1
- from typing import TYPE_CHECKING, Any, Dict, List, Set, Text
1
+ from typing import TYPE_CHECKING, Any, Dict, Set, Text
2
2
 
3
3
  from rasa.shared.utils.io import raise_deprecation_warning
4
4
 
@@ -8,7 +8,6 @@ if TYPE_CHECKING:
8
8
 
9
9
  RESET_PROPERTY_NAME = "reset_after_flow_ends"
10
10
  PERSIST_PROPERTY_NAME = "persisted_slots"
11
- ALL_LABEL = "ALL"
12
11
 
13
12
 
14
13
  def warn_deprecated_collect_step_config() -> None:
@@ -45,20 +44,6 @@ def get_invalid_slot_persistence_config_error_message(
45
44
  )
46
45
 
47
46
 
48
- def extract_digression_prop(prop: str, data: Dict[str, Any]) -> List[str]:
49
- """Extracts the digression property from the data.
50
-
51
- There can be two types of properties: ask_confirm_digressions and
52
- block_digressions.
53
- """
54
- digression_property = data.get(prop, [])
55
-
56
- if isinstance(digression_property, bool):
57
- digression_property = [ALL_LABEL] if digression_property else []
58
-
59
- return digression_property
60
-
61
-
62
47
  def extract_translations(
63
48
  translation_data: Dict[Text, Any],
64
49
  ) -> Dict[Text, "FlowLanguageTranslation"]:
@@ -648,12 +648,14 @@ class SlotFillingManager:
648
648
  output_channel: "OutputChannel",
649
649
  nlg: "NaturalLanguageGenerator",
650
650
  recreate_tracker: bool = False,
651
+ slot_events: Optional[List[Event]] = None,
651
652
  ) -> List[Event]:
652
653
  from rasa.core.actions.action import RemoteAction
653
654
  from rasa.shared.core.trackers import DialogueStateTracker
654
655
  from rasa.utils.endpoints import ClientResponseError
655
656
 
656
- slot_events: List[Event] = []
657
+ validated_slot_events: List[Event] = []
658
+ slot_events = slot_events if slot_events is not None else []
657
659
  remote_action = RemoteAction(custom_action, self._action_endpoint)
658
660
  disallowed_types = set()
659
661
 
@@ -673,9 +675,9 @@ class SlotFillingManager:
673
675
  )
674
676
  for event in custom_events:
675
677
  if isinstance(event, SlotSet):
676
- slot_events.append(event)
678
+ validated_slot_events.append(event)
677
679
  elif isinstance(event, BotUttered):
678
- slot_events.append(event)
680
+ validated_slot_events.append(event)
679
681
  else:
680
682
  disallowed_types.add(event.type_name)
681
683
  except (RasaException, ClientResponseError) as e:
@@ -699,7 +701,7 @@ class SlotFillingManager:
699
701
  f"updated with this event.",
700
702
  )
701
703
 
702
- return slot_events
704
+ return validated_slot_events
703
705
 
704
706
  async def execute_validation_action(
705
707
  self,
@@ -722,7 +724,11 @@ class SlotFillingManager:
722
724
  return cast(List[Event], slot_events)
723
725
 
724
726
  validate_events = await self._run_custom_action(
725
- ACTION_VALIDATE_SLOT_MAPPINGS, output_channel, nlg, recreate_tracker=True
727
+ ACTION_VALIDATE_SLOT_MAPPINGS,
728
+ output_channel,
729
+ nlg,
730
+ recreate_tracker=True,
731
+ slot_events=cast(List[Event], slot_events),
726
732
  )
727
733
  validated_slot_names = [
728
734
  event.key for event in validate_events if isinstance(event, SlotSet)
rasa/shared/core/slots.py CHANGED
@@ -787,7 +787,7 @@ class StrictCategoricalSlot(CategoricalSlot):
787
787
  def coerce_value(self, value: Any) -> Any:
788
788
  """Coerce the value to one of the allowed ones or raise an error if invalid."""
789
789
  if value is None:
790
- return self.initial_value
790
+ return value
791
791
 
792
792
  for allowed_value in self.values:
793
793
  # Allowed values are always stored as strings, so we can use casefold().
@@ -1123,16 +1123,10 @@ class DialogueStateTracker:
1123
1123
  f"Please update the slot configuration accordingly."
1124
1124
  )
1125
1125
 
1126
- supported_languages = []
1127
- for language_code in language_slot.values:
1128
- is_default = language_code == language_slot.initial_value
1129
- language = Language.from_language_code(
1130
- language_code=language_code,
1131
- is_default=is_default,
1132
- )
1133
- supported_languages.append(language)
1134
-
1135
- return supported_languages
1126
+ return [
1127
+ Language.from_language_code(language_code)
1128
+ for language_code in language_slot.values
1129
+ ]
1136
1130
 
1137
1131
  @property
1138
1132
  def current_language(self) -> Optional[Language]:
@@ -55,4 +55,3 @@ SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE = True
55
55
  SINGLE_ENTITY_ALLOWED_INTERLEAVING_CHARSET = {".", ",", " ", ";"}
56
56
 
57
57
  SET_SLOT_COMMAND = "set slot"
58
- HANDLE_DIGRESSIONS_COMMAND = "handle digressions"
@@ -4,12 +4,3 @@ LITE_LLM_API_KEY_FIELD = "api_key"
4
4
  LITE_LLM_API_VERSION_FIELD = "api_version"
5
5
  LITE_LLM_MODEL_FIELD = "model"
6
6
  LITE_LLM_AZURE_AD_TOKEN = "azure_ad_token"
7
-
8
- # Enable or disable Langfuse integration
9
- RASA_LANGFUSE_INTEGRATION_ENABLED_ENV_VAR = "RASA_LANGFUSE_INTEGRATION_ENABLED"
10
- # Langfuse configuration
11
- LANGFUSE_CALLBACK_NAME = "langfuse"
12
- LANGFUSE_HOST_ENV_VAR = "LANGFUSE_HOST"
13
- LANGFUSE_PROJECT_ID_ENV_VAR = "LANGFUSE_PROJECT_ID"
14
- LANGFUSE_PUBLIC_KEY_ENV_VAR = "LANGFUSE_PUBLIC_KEY"
15
- LANGFUSE_SECRET_KEY_ENV_VAR = "LANGFUSE_SECRET_KEY"
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import logging
4
4
  from abc import abstractmethod
5
- from typing import Any, Dict, List, Optional, Union, cast
5
+ from typing import Any, Dict, List, Union, cast
6
6
 
7
7
  import structlog
8
8
  from litellm import acompletion, completion, validate_environment
@@ -120,11 +120,7 @@ class _BaseLiteLLMClient:
120
120
  raise ProviderClientValidationError(event_info)
121
121
 
122
122
  @suppress_logs(log_level=logging.WARNING)
123
- def completion(
124
- self,
125
- messages: Union[List[dict], List[str], str],
126
- metadata: Optional[Dict[str, Any]] = None,
127
- ) -> LLMResponse:
123
+ def completion(self, messages: Union[List[dict], List[str], str]) -> LLMResponse:
128
124
  """Synchronously generate completions for given list of messages.
129
125
 
130
126
  Args:
@@ -136,7 +132,6 @@ class _BaseLiteLLMClient:
136
132
  - a list of messages. Each message is a string and will be formatted
137
133
  as a user message.
138
134
  - a single message as a string which will be formatted as user message.
139
- metadata: Optional metadata to be passed to the LLM call.
140
135
 
141
136
  Returns:
142
137
  List of message completions.
@@ -154,9 +149,7 @@ class _BaseLiteLLMClient:
154
149
 
155
150
  @suppress_logs(log_level=logging.WARNING)
156
151
  async def acompletion(
157
- self,
158
- messages: Union[List[dict], List[str], str],
159
- metadata: Optional[Dict[str, Any]] = None,
152
+ self, messages: Union[List[dict], List[str], str]
160
153
  ) -> LLMResponse:
161
154
  """Asynchronously generate completions for given list of messages.
162
155
 
@@ -169,7 +162,6 @@ class _BaseLiteLLMClient:
169
162
  - a list of messages. Each message is a string and will be formatted
170
163
  as a user message.
171
164
  - a single message as a string which will be formatted as user message.
172
- metadata: Optional metadata to be passed to the LLM call.
173
165
 
174
166
  Returns:
175
167
  List of message completions.
@@ -180,9 +172,7 @@ class _BaseLiteLLMClient:
180
172
  try:
181
173
  formatted_messages = self._get_formatted_messages(messages)
182
174
  arguments = resolve_environment_variables(self._completion_fn_args)
183
- response = await acompletion(
184
- messages=formatted_messages, metadata=metadata, **arguments
185
- )
175
+ response = await acompletion(messages=formatted_messages, **arguments)
186
176
  return self._format_response(response)
187
177
  except Exception as e:
188
178
  message = ""
@@ -101,11 +101,11 @@ class DefaultLiteLLMClient(_BaseLiteLLMClient):
101
101
  # SageMaker) in Rasa by allowing AWS secrets to be provided as extra
102
102
  # parameters without triggering validation errors due to missing AWS
103
103
  # environment variables.
104
- if self.provider.lower() in [
104
+ if self.provider.lower() in {
105
105
  AWS_BEDROCK_PROVIDER,
106
106
  AWS_SAGEMAKER_PROVIDER,
107
107
  AWS_SAGEMAKER_CHAT_PROVIDER,
108
- ]:
108
+ }:
109
109
  validate_aws_setup_for_litellm_clients(
110
110
  self._litellm_model_name,
111
111
  self._litellm_extra_parameters,
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
- from typing import Any, Dict, List, Optional, Union
4
+ from typing import Any, Dict, List, Union
5
5
 
6
6
  import structlog
7
7
 
@@ -122,12 +122,9 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
122
122
  raise ProviderClientAPIException(e)
123
123
 
124
124
  @suppress_logs(log_level=logging.WARNING)
125
- def completion(
126
- self,
127
- messages: Union[List[dict], List[str], str],
128
- metadata: Optional[Dict[str, Any]] = None,
129
- ) -> LLMResponse:
130
- """Synchronously generate completions for given list of messages.
125
+ def completion(self, messages: Union[List[dict], List[str], str]) -> LLMResponse:
126
+ """
127
+ Synchronously generate completions for given list of messages.
131
128
 
132
129
  Method overrides the base class method to call the appropriate
133
130
  completion method based on the configuration. If the chat completions
@@ -143,11 +140,8 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
143
140
  - a list of messages. Each message is a string and will be formatted
144
141
  as a user message.
145
142
  - a single message as a string which will be formatted as user message.
146
- metadata: Optional metadata to be passed to the LLM call.
147
-
148
143
  Returns:
149
144
  List of message completions.
150
-
151
145
  Raises:
152
146
  ProviderClientAPIException: If the API request fails.
153
147
  """
@@ -164,11 +158,10 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
164
158
 
165
159
  @suppress_logs(log_level=logging.WARNING)
166
160
  async def acompletion(
167
- self,
168
- messages: Union[List[dict], List[str], str],
169
- metadata: Optional[Dict[str, Any]] = None,
161
+ self, messages: Union[List[dict], List[str], str]
170
162
  ) -> LLMResponse:
171
- """Asynchronously generate completions for given list of messages.
163
+ """
164
+ Asynchronously generate completions for given list of messages.
172
165
 
173
166
  Method overrides the base class method to call the appropriate
174
167
  completion method based on the configuration. If the chat completions
@@ -184,11 +177,8 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
184
177
  - a list of messages. Each message is a string and will be formatted
185
178
  as a user message.
186
179
  - a single message as a string which will be formatted as user message.
187
- metadata: Optional metadata to be passed to the LLM call.
188
-
189
180
  Returns:
190
181
  List of message completions.
191
-
192
182
  Raises:
193
183
  ProviderClientAPIException: If the API request fails.
194
184
  """
@@ -1,19 +1,21 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Any, Dict, List, Optional, Protocol, Union, runtime_checkable
3
+ from typing import Dict, List, Protocol, Union, runtime_checkable
4
4
 
5
5
  from rasa.shared.providers.llm.llm_response import LLMResponse
6
6
 
7
7
 
8
8
  @runtime_checkable
9
9
  class LLMClient(Protocol):
10
- """Protocol for an LLM client that specifies the interface for interacting
10
+ """
11
+ Protocol for an LLM client that specifies the interface for interacting
11
12
  with the API.
12
13
  """
13
14
 
14
15
  @classmethod
15
16
  def from_config(cls, config: dict) -> LLMClient:
16
- """Initializes the llm client with the given configuration.
17
+ """
18
+ Initializes the llm client with the given configuration.
17
19
 
18
20
  This class method should be implemented to parse the given
19
21
  configuration and create an instance of an llm client.
@@ -22,24 +24,17 @@ class LLMClient(Protocol):
22
24
 
23
25
  @property
24
26
  def config(self) -> Dict:
25
- """Returns the configuration for that the llm client is initialized with.
27
+ """
28
+ Returns the configuration for that the llm client is initialized with.
26
29
 
27
30
  This property should be implemented to return a dictionary containing
28
31
  the configuration settings for the llm client.
29
32
  """
30
33
  ...
31
34
 
32
- def completion(
33
- self,
34
- messages: Union[List[dict], List[str], str],
35
- metadata: Optional[Dict[str, Any]] = None,
36
- ) -> LLMResponse:
37
- """Synchronously generate completions for given list of messages.
38
- def completion(
39
- self,
40
- messages: Union[List[dict], List[str], str],
41
- metadata: Optional[Dict[str, Any]] = None,
42
- ) -> LLMResponse:
35
+ def completion(self, messages: Union[List[dict], List[str], str]) -> LLMResponse:
36
+ """
37
+ Synchronously generate completions for given list of messages.
43
38
 
44
39
  This method should be implemented to take a list of messages (as
45
40
  strings) and return a list of completions (as strings).
@@ -53,19 +48,16 @@ class LLMClient(Protocol):
53
48
  - a list of messages. Each message is a string and will be formatted
54
49
  as a user message.
55
50
  - a single message as a string which will be formatted as user message.
56
- metadata: Optional metadata to be passed to the LLM call.
57
-
58
51
  Returns:
59
52
  LLMResponse
60
53
  """
61
54
  ...
62
55
 
63
56
  async def acompletion(
64
- self,
65
- messages: Union[List[dict], List[str], str],
66
- metadata: Optional[Dict[str, Any]] = None,
57
+ self, messages: Union[List[dict], List[str], str]
67
58
  ) -> LLMResponse:
68
- """Asynchronously generate completions for given list of messages.
59
+ """
60
+ Asynchronously generate completions for given list of messages.
69
61
 
70
62
  This method should be implemented to take a list of messages (as
71
63
  strings) and return a list of completions (as strings).
@@ -79,15 +71,14 @@ class LLMClient(Protocol):
79
71
  - a list of messages. Each message is a string and will be formatted
80
72
  as a user message.
81
73
  - a single message as a string which will be formatted as user message.
82
- metadata: Optional metadata to be passed to the LLM call.
83
-
84
74
  Returns:
85
75
  LLMResponse
86
76
  """
87
77
  ...
88
78
 
89
79
  def validate_client_setup(self, *args, **kwargs) -> None: # type: ignore
90
- """Perform client setup validation.
80
+ """
81
+ Perform client setup validation.
91
82
 
92
83
  This method should be implemented to validate whether the client can be
93
84
  used with the parameters provided through configuration or environment
@@ -237,9 +237,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
237
237
  raise ProviderClientAPIException(e)
238
238
 
239
239
  async def acompletion(
240
- self,
241
- messages: Union[List[dict], List[str], str],
242
- metadata: Optional[Dict[str, Any]] = None,
240
+ self, messages: Union[List[dict], List[str], str]
243
241
  ) -> LLMResponse:
244
242
  """Asynchronous completion of the model with the given messages.
245
243
 
@@ -257,7 +255,6 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
257
255
  - a list of messages. Each message is a string and will be formatted
258
256
  as a user message.
259
257
  - a single message as a string which will be formatted as user message.
260
- metadata: Optional metadata to be passed to the LLM call.
261
258
 
262
259
  Returns:
263
260
  The completion response.
@@ -266,11 +263,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
266
263
  return await super().acompletion(messages)
267
264
  return await self._atext_completion(messages)
268
265
 
269
- def completion(
270
- self,
271
- messages: Union[List[dict], List[str], str],
272
- metadata: Optional[Dict[str, Any]] = None,
273
- ) -> LLMResponse:
266
+ def completion(self, messages: Union[List[dict], List[str], str]) -> LLMResponse:
274
267
  """Completion of the model with the given messages.
275
268
 
276
269
  Method overrides the base class method to call the appropriate
@@ -280,7 +273,6 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
280
273
 
281
274
  Args:
282
275
  messages: The messages to be used for completion.
283
- metadata: Optional metadata to be passed to the LLM call.
284
276
 
285
277
  Returns:
286
278
  The completion response.
@@ -7,7 +7,17 @@ import os
7
7
  import pkgutil
8
8
  import sys
9
9
  from types import ModuleType
10
- from typing import Any, Callable, Collection, Dict, List, Optional, Sequence, Text, Type
10
+ from typing import (
11
+ Any,
12
+ Callable,
13
+ Collection,
14
+ Dict,
15
+ List,
16
+ Optional,
17
+ Sequence,
18
+ Text,
19
+ Type,
20
+ )
11
21
 
12
22
  import rasa.shared.utils.io
13
23
  from rasa.exceptions import MissingDependencyException