rasa-pro 3.12.6.dev2__py3-none-any.whl → 3.12.7.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic.

Files changed (55)
  1. rasa/__init__.py +0 -6
  2. rasa/cli/run.py +10 -6
  3. rasa/cli/utils.py +7 -0
  4. rasa/core/actions/action.py +0 -6
  5. rasa/core/channels/voice_ready/audiocodes.py +46 -17
  6. rasa/core/nlg/contextual_response_rephraser.py +4 -21
  7. rasa/core/nlg/summarize.py +1 -15
  8. rasa/core/policies/enterprise_search_policy.py +3 -16
  9. rasa/core/policies/flows/flow_executor.py +3 -38
  10. rasa/core/policies/intentless_policy.py +4 -17
  11. rasa/core/policies/policy.py +0 -2
  12. rasa/core/processor.py +19 -5
  13. rasa/core/utils.py +53 -0
  14. rasa/dialogue_understanding/coexistence/llm_based_router.py +4 -18
  15. rasa/dialogue_understanding/commands/cancel_flow_command.py +4 -59
  16. rasa/dialogue_understanding/commands/start_flow_command.py +0 -41
  17. rasa/dialogue_understanding/generator/command_generator.py +67 -0
  18. rasa/dialogue_understanding/generator/llm_based_command_generator.py +4 -20
  19. rasa/dialogue_understanding/generator/llm_command_generator.py +1 -3
  20. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +1 -12
  21. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +0 -61
  22. rasa/dialogue_understanding/processor/command_processor.py +7 -65
  23. rasa/dialogue_understanding/stack/utils.py +0 -38
  24. rasa/e2e_test/utils/validation.py +3 -3
  25. rasa/hooks.py +0 -55
  26. rasa/shared/constants.py +0 -5
  27. rasa/shared/core/constants.py +0 -8
  28. rasa/shared/core/domain.py +12 -3
  29. rasa/shared/core/flows/flow.py +0 -17
  30. rasa/shared/core/flows/flows_yaml_schema.json +3 -38
  31. rasa/shared/core/flows/steps/collect.py +5 -18
  32. rasa/shared/core/flows/utils.py +1 -16
  33. rasa/shared/core/slot_mappings.py +11 -5
  34. rasa/shared/nlu/constants.py +0 -1
  35. rasa/shared/providers/constants.py +0 -9
  36. rasa/shared/providers/llm/_base_litellm_client.py +4 -14
  37. rasa/shared/providers/llm/litellm_router_llm_client.py +7 -17
  38. rasa/shared/providers/llm/llm_client.py +15 -24
  39. rasa/shared/providers/llm/self_hosted_llm_client.py +2 -10
  40. rasa/shared/utils/common.py +11 -1
  41. rasa/shared/utils/health_check/health_check.py +1 -7
  42. rasa/tracing/instrumentation/attribute_extractors.py +4 -4
  43. rasa/tracing/instrumentation/intentless_policy_instrumentation.py +1 -2
  44. rasa/utils/licensing.py +0 -15
  45. rasa/validator.py +1 -123
  46. rasa/version.py +1 -1
  47. {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.12.7.dev1.dist-info}/METADATA +3 -4
  48. {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.12.7.dev1.dist-info}/RECORD +51 -55
  49. rasa/core/actions/action_handle_digressions.py +0 -164
  50. rasa/dialogue_understanding/commands/handle_digressions_command.py +0 -144
  51. rasa/dialogue_understanding/patterns/handle_digressions.py +0 -81
  52. rasa/monkey_patches.py +0 -91
  53. {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.12.7.dev1.dist-info}/NOTICE +0 -0
  54. {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.12.7.dev1.dist-info}/WHEEL +0 -0
  55. {rasa_pro-3.12.6.dev2.dist-info → rasa_pro-3.12.7.dev1.dist-info}/entry_points.txt +0 -0
rasa/hooks.py CHANGED
@@ -1,20 +1,8 @@
1
1
  import argparse
2
2
  import logging
3
- import os
4
3
  from typing import TYPE_CHECKING, List, Optional, Text, Union
5
4
 
6
- import litellm
7
5
  import pluggy
8
- import structlog
9
-
10
- from rasa.shared.providers.constants import (
11
- LANGFUSE_CALLBACK_NAME,
12
- LANGFUSE_HOST_ENV_VAR,
13
- LANGFUSE_PROJECT_ID_ENV_VAR,
14
- LANGFUSE_PUBLIC_KEY_ENV_VAR,
15
- LANGFUSE_SECRET_KEY_ENV_VAR,
16
- RASA_LANGFUSE_INTEGRATION_ENABLED_ENV_VAR,
17
- )
18
6
 
19
7
  # IMPORTANT: do not import anything from rasa here - use scoped imports
20
8
  # this avoids circular imports, as the hooks are used in different places
@@ -30,7 +18,6 @@ if TYPE_CHECKING:
30
18
 
31
19
  hookimpl = pluggy.HookimplMarker("rasa")
32
20
  logger = logging.getLogger(__name__)
33
- structlogger = structlog.get_logger()
34
21
 
35
22
 
36
23
  @hookimpl # type: ignore[misc]
@@ -70,8 +57,6 @@ def configure_commandline(cmdline_arguments: argparse.Namespace) -> Optional[Tex
70
57
  config.configure_tracing(tracer_provider)
71
58
  config.configure_metrics(endpoints_file)
72
59
 
73
- _init_langfuse_integration()
74
-
75
60
  return endpoints_file
76
61
 
77
62
 
@@ -130,43 +115,3 @@ def after_server_stop() -> None:
130
115
 
131
116
  if anon_pipeline is not None:
132
117
  anon_pipeline.stop()
133
-
134
-
135
- def _is_langfuse_integration_enabled() -> bool:
136
- return (
137
- os.environ.get(RASA_LANGFUSE_INTEGRATION_ENABLED_ENV_VAR, "false").lower()
138
- == "true"
139
- )
140
-
141
-
142
- def _init_langfuse_integration() -> None:
143
- if not _is_langfuse_integration_enabled():
144
- structlogger.info(
145
- "hooks._init_langfuse_integration.disabled",
146
- event_info="Langfuse integration is disabled.",
147
- )
148
- return
149
-
150
- if (
151
- not os.environ.get(LANGFUSE_HOST_ENV_VAR)
152
- or not os.environ.get(LANGFUSE_PROJECT_ID_ENV_VAR)
153
- or not os.environ.get(LANGFUSE_PUBLIC_KEY_ENV_VAR)
154
- or not os.environ.get(LANGFUSE_SECRET_KEY_ENV_VAR)
155
- ):
156
- structlogger.warning(
157
- "hooks._init_langfuse_integration.missing_langfuse_keys",
158
- event_info=(
159
- "Langfuse integration is enabled, but some environment variables"
160
- "are missing. Please set LANGFUSE_HOST, LANGFUSE_PROJECT_ID, "
161
- "LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY environment "
162
- "variables to use Langfuse integration."
163
- ),
164
- )
165
- return
166
-
167
- litellm.success_callback = [LANGFUSE_CALLBACK_NAME]
168
- litellm.failure_callback = [LANGFUSE_CALLBACK_NAME]
169
- structlogger.info(
170
- "hooks.langfuse_callbacks_initialized",
171
- event_info="Langfuse integration initialized.",
172
- )
rasa/shared/constants.py CHANGED
@@ -338,8 +338,3 @@ ROLE_SYSTEM = "system"
338
338
  # Used for key values in ValidateSlotPatternFlowStackFrame
339
339
  REFILL_UTTER = "refill_utter"
340
340
  REJECTIONS = "rejections"
341
-
342
- LANGFUSE_METADATA_USER_ID = "trace_user_id"
343
- LANGFUSE_METADATA_SESSION_ID = "session_id"
344
- LANGFUSE_CUSTOM_METADATA_DICT = "trace_metadata"
345
- LANGFUSE_TAGS = "tags"
@@ -52,8 +52,6 @@ ACTION_TRIGGER_CHITCHAT = "action_trigger_chitchat"
52
52
  ACTION_RESET_ROUTING = "action_reset_routing"
53
53
  ACTION_HANGUP = "action_hangup"
54
54
  ACTION_REPEAT_BOT_MESSAGES = "action_repeat_bot_messages"
55
- ACTION_BLOCK_DIGRESSION = "action_block_digression"
56
- ACTION_CONTINUE_DIGRESSION = "action_continue_digression"
57
55
 
58
56
  ACTION_METADATA_EXECUTION_SUCCESS = "execution_success"
59
57
  ACTION_METADATA_EXECUTION_ERROR_MESSAGE = "execution_error_message"
@@ -84,8 +82,6 @@ DEFAULT_ACTION_NAMES = [
84
82
  ACTION_RESET_ROUTING,
85
83
  ACTION_HANGUP,
86
84
  ACTION_REPEAT_BOT_MESSAGES,
87
- ACTION_BLOCK_DIGRESSION,
88
- ACTION_CONTINUE_DIGRESSION,
89
85
  ]
90
86
 
91
87
  ACTION_SHOULD_SEND_DOMAIN = "send_domain"
@@ -205,8 +201,4 @@ CLASSIFIER_NAME_FALLBACK = "FallbackClassifier"
205
201
 
206
202
  POLICIES_THAT_EXTRACT_ENTITIES = {"TEDPolicy"}
207
203
 
208
- # digression constants
209
- KEY_ASK_CONFIRM_DIGRESSIONS = "ask_confirm_digressions"
210
- KEY_BLOCK_DIGRESSIONS = "block_digressions"
211
-
212
204
  ERROR_CODE_KEY = "error_code"
@@ -1678,6 +1678,14 @@ class Domain:
1678
1678
  """Write domain to a file."""
1679
1679
  as_yaml = self.as_yaml()
1680
1680
  rasa.shared.utils.io.write_text_file(as_yaml, filename)
1681
+ # run the check again on the written domain to catch any errors
1682
+ # that may have been missed in the user defined domain files
1683
+ structlogger.info(
1684
+ "domain.persist.domain_written_to_file",
1685
+ event_info="The entire domain content has been written to file.",
1686
+ filename=filename,
1687
+ )
1688
+ Domain.is_domain_file(filename)
1681
1689
 
1682
1690
  def as_yaml(self) -> Text:
1683
1691
  """Dump the `Domain` object as a YAML string.
@@ -1972,17 +1980,18 @@ class Domain:
1972
1980
 
1973
1981
  try:
1974
1982
  content = read_yaml_file(filename, expand_env_vars=cls.expand_env_vars)
1975
- except (RasaException, YamlSyntaxException):
1976
- structlogger.warning(
1983
+ except (RasaException, YamlSyntaxException) as error:
1984
+ structlogger.error(
1977
1985
  "domain.cannot_load_domain_file",
1978
1986
  file=filename,
1987
+ error=error,
1979
1988
  event_info=(
1980
1989
  f"The file {filename} could not be loaded as domain file. "
1981
1990
  f"You can use https://yamlchecker.com/ to validate "
1982
1991
  f"the YAML syntax of your file."
1983
1992
  ),
1984
1993
  )
1985
- return False
1994
+ raise RasaException(f"Domain could not be loaded: {error}")
1986
1995
 
1987
1996
  return any(key in content for key in ALL_DOMAIN_KEYS)
1988
1997
 
@@ -13,10 +13,6 @@ from pypred import Predicate
13
13
  import rasa.shared.utils.io
14
14
  from rasa.engine.language import Language
15
15
  from rasa.shared.constants import RASA_DEFAULT_FLOW_PATTERN_PREFIX
16
- from rasa.shared.core.constants import (
17
- KEY_ASK_CONFIRM_DIGRESSIONS,
18
- KEY_BLOCK_DIGRESSIONS,
19
- )
20
16
  from rasa.shared.core.flows.constants import (
21
17
  KEY_ALWAYS_INCLUDE_IN_PROMPT,
22
18
  KEY_DESCRIPTION,
@@ -52,7 +48,6 @@ from rasa.shared.core.flows.steps.constants import (
52
48
  START_STEP,
53
49
  )
54
50
  from rasa.shared.core.flows.steps.continuation import ContinueFlowStep
55
- from rasa.shared.core.flows.utils import extract_digression_prop
56
51
  from rasa.shared.core.slots import Slot
57
52
 
58
53
  structlogger = structlog.get_logger()
@@ -94,10 +89,6 @@ class Flow:
94
89
  """The path to the file where the flow is stored."""
95
90
  persisted_slots: List[str] = field(default_factory=list)
96
91
  """The list of slots that should be persisted after the flow ends."""
97
- ask_confirm_digressions: List[str] = field(default_factory=list)
98
- """The flow ids for which the assistant should ask for confirmation."""
99
- block_digressions: List[str] = field(default_factory=list)
100
- """The flow ids that the assistant should block from digressing to."""
101
92
  run_pattern_completed: bool = True
102
93
  """Whether the pattern_completed flow should be run after the flow ends."""
103
94
 
@@ -138,10 +129,6 @@ class Flow:
138
129
  # data. When the model is trained, take the provided file_path.
139
130
  file_path=data.get(KEY_FILE_PATH) if KEY_FILE_PATH in data else file_path,
140
131
  persisted_slots=data.get(KEY_PERSISTED_SLOTS, []),
141
- ask_confirm_digressions=extract_digression_prop(
142
- KEY_ASK_CONFIRM_DIGRESSIONS, data
143
- ),
144
- block_digressions=extract_digression_prop(KEY_BLOCK_DIGRESSIONS, data),
145
132
  run_pattern_completed=data.get(KEY_RUN_PATTERN_COMPLETED, True),
146
133
  translation=extract_translations(
147
134
  translation_data=data.get(KEY_TRANSLATION, {})
@@ -220,10 +207,6 @@ class Flow:
220
207
  data[KEY_FILE_PATH] = self.file_path
221
208
  if self.persisted_slots:
222
209
  data[KEY_PERSISTED_SLOTS] = self.persisted_slots
223
- if self.ask_confirm_digressions:
224
- data[KEY_ASK_CONFIRM_DIGRESSIONS] = self.ask_confirm_digressions
225
- if self.block_digressions:
226
- data[KEY_BLOCK_DIGRESSIONS] = self.block_digressions
227
210
  if self.run_pattern_completed is not None:
228
211
  data["run_pattern_completed"] = self.run_pattern_completed
229
212
  if self.translation:
@@ -217,15 +217,12 @@
217
217
  "reset_after_flow_ends": {
218
218
  "type": "boolean"
219
219
  },
220
- "ask_confirm_digressions": {
221
- "$ref": "#/$defs/ask_confirm_digressions"
222
- },
223
- "block_digressions": {
224
- "$ref": "#/$defs/block_digressions"
225
- },
226
220
  "utter": {
227
221
  "type": "string"
228
222
  },
223
+ "force_slot_filling": {
224
+ "type": "boolean"
225
+ },
229
226
  "rejections": {
230
227
  "type": "array",
231
228
  "schema_name": "list of rejections",
@@ -253,32 +250,6 @@
253
250
  }
254
251
  }
255
252
  },
256
- "ask_confirm_digressions": {
257
- "oneOf": [
258
- {
259
- "type": "boolean"
260
- },
261
- {
262
- "type": "array",
263
- "items": {
264
- "type": "string"
265
- }
266
- }
267
- ]
268
- },
269
- "block_digressions": {
270
- "oneOf": [
271
- {
272
- "type": "boolean"
273
- },
274
- {
275
- "type": "array",
276
- "items": {
277
- "type": "string"
278
- }
279
- }
280
- ]
281
- },
282
253
  "flow": {
283
254
  "required": [
284
255
  "steps",
@@ -340,12 +311,6 @@
340
311
  "persisted_slots": {
341
312
  "$ref": "#/$defs/persisted_slots"
342
313
  },
343
- "ask_confirm_digressions": {
344
- "$ref": "#/$defs/ask_confirm_digressions"
345
- },
346
- "block_digressions": {
347
- "$ref": "#/$defs/block_digressions"
348
- },
349
314
  "run_pattern_completed": {
350
315
  "type": "boolean"
351
316
  }
@@ -1,15 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
- from dataclasses import dataclass, field
3
+ from dataclasses import dataclass
4
4
  from typing import Any, Dict, List, Set, Text
5
5
 
6
6
  from rasa.shared.constants import ACTION_ASK_PREFIX, UTTER_ASK_PREFIX
7
- from rasa.shared.core.constants import (
8
- KEY_ASK_CONFIRM_DIGRESSIONS,
9
- KEY_BLOCK_DIGRESSIONS,
10
- )
11
7
  from rasa.shared.core.flows.flow_step import FlowStep
12
- from rasa.shared.core.flows.utils import extract_digression_prop
13
8
  from rasa.shared.core.slots import SlotRejection
14
9
 
15
10
 
@@ -29,10 +24,8 @@ class CollectInformationFlowStep(FlowStep):
29
24
  """Whether to always ask the question even if the slot is already filled."""
30
25
  reset_after_flow_ends: bool = True
31
26
  """Whether to reset the slot value at the end of the flow."""
32
- ask_confirm_digressions: List[str] = field(default_factory=list)
33
- """The flow id digressions for which the assistant should ask for confirmation."""
34
- block_digressions: List[str] = field(default_factory=list)
35
- """The flow id digressions that should be blocked during the flow step."""
27
+ force_slot_filling: bool = False
28
+ """Whether to keep only the SetSlot command for the collected slot."""
36
29
 
37
30
  @classmethod
38
31
  def from_json(
@@ -60,10 +53,7 @@ class CollectInformationFlowStep(FlowStep):
60
53
  SlotRejection.from_dict(rejection)
61
54
  for rejection in data.get("rejections", [])
62
55
  ],
63
- ask_confirm_digressions=extract_digression_prop(
64
- KEY_ASK_CONFIRM_DIGRESSIONS, data
65
- ),
66
- block_digressions=extract_digression_prop(KEY_BLOCK_DIGRESSIONS, data),
56
+ force_slot_filling=data.get("force_slot_filling", False),
67
57
  **base.__dict__,
68
58
  )
69
59
 
@@ -79,10 +69,7 @@ class CollectInformationFlowStep(FlowStep):
79
69
  data["ask_before_filling"] = self.ask_before_filling
80
70
  data["reset_after_flow_ends"] = self.reset_after_flow_ends
81
71
  data["rejections"] = [rejection.as_dict() for rejection in self.rejections]
82
- data["ask_confirm_digressions"] = self.ask_confirm_digressions
83
- data["block_digressions"] = (
84
- self.block_digressions if self.block_digressions else False
85
- )
72
+ data["force_slot_filling"] = self.force_slot_filling
86
73
 
87
74
  return data
88
75
 
@@ -1,4 +1,4 @@
1
- from typing import TYPE_CHECKING, Any, Dict, List, Set, Text
1
+ from typing import TYPE_CHECKING, Any, Dict, Set, Text
2
2
 
3
3
  from rasa.shared.utils.io import raise_deprecation_warning
4
4
 
@@ -8,7 +8,6 @@ if TYPE_CHECKING:
8
8
 
9
9
  RESET_PROPERTY_NAME = "reset_after_flow_ends"
10
10
  PERSIST_PROPERTY_NAME = "persisted_slots"
11
- ALL_LABEL = "ALL"
12
11
 
13
12
 
14
13
  def warn_deprecated_collect_step_config() -> None:
@@ -45,20 +44,6 @@ def get_invalid_slot_persistence_config_error_message(
45
44
  )
46
45
 
47
46
 
48
- def extract_digression_prop(prop: str, data: Dict[str, Any]) -> List[str]:
49
- """Extracts the digression property from the data.
50
-
51
- There can be two types of properties: ask_confirm_digressions and
52
- block_digressions.
53
- """
54
- digression_property = data.get(prop, [])
55
-
56
- if isinstance(digression_property, bool):
57
- digression_property = [ALL_LABEL] if digression_property else []
58
-
59
- return digression_property
60
-
61
-
62
47
  def extract_translations(
63
48
  translation_data: Dict[Text, Any],
64
49
  ) -> Dict[Text, "FlowLanguageTranslation"]:
@@ -648,12 +648,14 @@ class SlotFillingManager:
648
648
  output_channel: "OutputChannel",
649
649
  nlg: "NaturalLanguageGenerator",
650
650
  recreate_tracker: bool = False,
651
+ slot_events: Optional[List[Event]] = None,
651
652
  ) -> List[Event]:
652
653
  from rasa.core.actions.action import RemoteAction
653
654
  from rasa.shared.core.trackers import DialogueStateTracker
654
655
  from rasa.utils.endpoints import ClientResponseError
655
656
 
656
- slot_events: List[Event] = []
657
+ validated_slot_events: List[Event] = []
658
+ slot_events = slot_events if slot_events is not None else []
657
659
  remote_action = RemoteAction(custom_action, self._action_endpoint)
658
660
  disallowed_types = set()
659
661
 
@@ -673,9 +675,9 @@ class SlotFillingManager:
673
675
  )
674
676
  for event in custom_events:
675
677
  if isinstance(event, SlotSet):
676
- slot_events.append(event)
678
+ validated_slot_events.append(event)
677
679
  elif isinstance(event, BotUttered):
678
- slot_events.append(event)
680
+ validated_slot_events.append(event)
679
681
  else:
680
682
  disallowed_types.add(event.type_name)
681
683
  except (RasaException, ClientResponseError) as e:
@@ -699,7 +701,7 @@ class SlotFillingManager:
699
701
  f"updated with this event.",
700
702
  )
701
703
 
702
- return slot_events
704
+ return validated_slot_events
703
705
 
704
706
  async def execute_validation_action(
705
707
  self,
@@ -722,7 +724,11 @@ class SlotFillingManager:
722
724
  return cast(List[Event], slot_events)
723
725
 
724
726
  validate_events = await self._run_custom_action(
725
- ACTION_VALIDATE_SLOT_MAPPINGS, output_channel, nlg, recreate_tracker=True
727
+ ACTION_VALIDATE_SLOT_MAPPINGS,
728
+ output_channel,
729
+ nlg,
730
+ recreate_tracker=True,
731
+ slot_events=cast(List[Event], slot_events),
726
732
  )
727
733
  validated_slot_names = [
728
734
  event.key for event in validate_events if isinstance(event, SlotSet)
@@ -55,4 +55,3 @@ SPLIT_ENTITIES_BY_COMMA_DEFAULT_VALUE = True
55
55
  SINGLE_ENTITY_ALLOWED_INTERLEAVING_CHARSET = {".", ",", " ", ";"}
56
56
 
57
57
  SET_SLOT_COMMAND = "set slot"
58
- HANDLE_DIGRESSIONS_COMMAND = "handle digressions"
@@ -4,12 +4,3 @@ LITE_LLM_API_KEY_FIELD = "api_key"
4
4
  LITE_LLM_API_VERSION_FIELD = "api_version"
5
5
  LITE_LLM_MODEL_FIELD = "model"
6
6
  LITE_LLM_AZURE_AD_TOKEN = "azure_ad_token"
7
-
8
- # Enable or disable Langfuse integration
9
- RASA_LANGFUSE_INTEGRATION_ENABLED_ENV_VAR = "RASA_LANGFUSE_INTEGRATION_ENABLED"
10
- # Langfuse configuration
11
- LANGFUSE_CALLBACK_NAME = "langfuse"
12
- LANGFUSE_HOST_ENV_VAR = "LANGFUSE_HOST"
13
- LANGFUSE_PROJECT_ID_ENV_VAR = "LANGFUSE_PROJECT_ID"
14
- LANGFUSE_PUBLIC_KEY_ENV_VAR = "LANGFUSE_PUBLIC_KEY"
15
- LANGFUSE_SECRET_KEY_ENV_VAR = "LANGFUSE_SECRET_KEY"
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import logging
4
4
  from abc import abstractmethod
5
- from typing import Any, Dict, List, Optional, Union, cast
5
+ from typing import Any, Dict, List, Union, cast
6
6
 
7
7
  import structlog
8
8
  from litellm import acompletion, completion, validate_environment
@@ -120,11 +120,7 @@ class _BaseLiteLLMClient:
120
120
  raise ProviderClientValidationError(event_info)
121
121
 
122
122
  @suppress_logs(log_level=logging.WARNING)
123
- def completion(
124
- self,
125
- messages: Union[List[dict], List[str], str],
126
- metadata: Optional[Dict[str, Any]] = None,
127
- ) -> LLMResponse:
123
+ def completion(self, messages: Union[List[dict], List[str], str]) -> LLMResponse:
128
124
  """Synchronously generate completions for given list of messages.
129
125
 
130
126
  Args:
@@ -136,7 +132,6 @@ class _BaseLiteLLMClient:
136
132
  - a list of messages. Each message is a string and will be formatted
137
133
  as a user message.
138
134
  - a single message as a string which will be formatted as user message.
139
- metadata: Optional metadata to be passed to the LLM call.
140
135
 
141
136
  Returns:
142
137
  List of message completions.
@@ -154,9 +149,7 @@ class _BaseLiteLLMClient:
154
149
 
155
150
  @suppress_logs(log_level=logging.WARNING)
156
151
  async def acompletion(
157
- self,
158
- messages: Union[List[dict], List[str], str],
159
- metadata: Optional[Dict[str, Any]] = None,
152
+ self, messages: Union[List[dict], List[str], str]
160
153
  ) -> LLMResponse:
161
154
  """Asynchronously generate completions for given list of messages.
162
155
 
@@ -169,7 +162,6 @@ class _BaseLiteLLMClient:
169
162
  - a list of messages. Each message is a string and will be formatted
170
163
  as a user message.
171
164
  - a single message as a string which will be formatted as user message.
172
- metadata: Optional metadata to be passed to the LLM call.
173
165
 
174
166
  Returns:
175
167
  List of message completions.
@@ -180,9 +172,7 @@ class _BaseLiteLLMClient:
180
172
  try:
181
173
  formatted_messages = self._get_formatted_messages(messages)
182
174
  arguments = resolve_environment_variables(self._completion_fn_args)
183
- response = await acompletion(
184
- messages=formatted_messages, metadata=metadata, **arguments
185
- )
175
+ response = await acompletion(messages=formatted_messages, **arguments)
186
176
  return self._format_response(response)
187
177
  except Exception as e:
188
178
  message = ""
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
- from typing import Any, Dict, List, Optional, Union
4
+ from typing import Any, Dict, List, Union
5
5
 
6
6
  import structlog
7
7
 
@@ -122,12 +122,9 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
122
122
  raise ProviderClientAPIException(e)
123
123
 
124
124
  @suppress_logs(log_level=logging.WARNING)
125
- def completion(
126
- self,
127
- messages: Union[List[dict], List[str], str],
128
- metadata: Optional[Dict[str, Any]] = None,
129
- ) -> LLMResponse:
130
- """Synchronously generate completions for given list of messages.
125
+ def completion(self, messages: Union[List[dict], List[str], str]) -> LLMResponse:
126
+ """
127
+ Synchronously generate completions for given list of messages.
131
128
 
132
129
  Method overrides the base class method to call the appropriate
133
130
  completion method based on the configuration. If the chat completions
@@ -143,11 +140,8 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
143
140
  - a list of messages. Each message is a string and will be formatted
144
141
  as a user message.
145
142
  - a single message as a string which will be formatted as user message.
146
- metadata: Optional metadata to be passed to the LLM call.
147
-
148
143
  Returns:
149
144
  List of message completions.
150
-
151
145
  Raises:
152
146
  ProviderClientAPIException: If the API request fails.
153
147
  """
@@ -164,11 +158,10 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
164
158
 
165
159
  @suppress_logs(log_level=logging.WARNING)
166
160
  async def acompletion(
167
- self,
168
- messages: Union[List[dict], List[str], str],
169
- metadata: Optional[Dict[str, Any]] = None,
161
+ self, messages: Union[List[dict], List[str], str]
170
162
  ) -> LLMResponse:
171
- """Asynchronously generate completions for given list of messages.
163
+ """
164
+ Asynchronously generate completions for given list of messages.
172
165
 
173
166
  Method overrides the base class method to call the appropriate
174
167
  completion method based on the configuration. If the chat completions
@@ -184,11 +177,8 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
184
177
  - a list of messages. Each message is a string and will be formatted
185
178
  as a user message.
186
179
  - a single message as a string which will be formatted as user message.
187
- metadata: Optional metadata to be passed to the LLM call.
188
-
189
180
  Returns:
190
181
  List of message completions.
191
-
192
182
  Raises:
193
183
  ProviderClientAPIException: If the API request fails.
194
184
  """
@@ -1,19 +1,21 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Any, Dict, List, Optional, Protocol, Union, runtime_checkable
3
+ from typing import Dict, List, Protocol, Union, runtime_checkable
4
4
 
5
5
  from rasa.shared.providers.llm.llm_response import LLMResponse
6
6
 
7
7
 
8
8
  @runtime_checkable
9
9
  class LLMClient(Protocol):
10
- """Protocol for an LLM client that specifies the interface for interacting
10
+ """
11
+ Protocol for an LLM client that specifies the interface for interacting
11
12
  with the API.
12
13
  """
13
14
 
14
15
  @classmethod
15
16
  def from_config(cls, config: dict) -> LLMClient:
16
- """Initializes the llm client with the given configuration.
17
+ """
18
+ Initializes the llm client with the given configuration.
17
19
 
18
20
  This class method should be implemented to parse the given
19
21
  configuration and create an instance of an llm client.
@@ -22,24 +24,17 @@ class LLMClient(Protocol):
22
24
 
23
25
  @property
24
26
  def config(self) -> Dict:
25
- """Returns the configuration for that the llm client is initialized with.
27
+ """
28
+ Returns the configuration for that the llm client is initialized with.
26
29
 
27
30
  This property should be implemented to return a dictionary containing
28
31
  the configuration settings for the llm client.
29
32
  """
30
33
  ...
31
34
 
32
- def completion(
33
- self,
34
- messages: Union[List[dict], List[str], str],
35
- metadata: Optional[Dict[str, Any]] = None,
36
- ) -> LLMResponse:
37
- """Synchronously generate completions for given list of messages.
38
- def completion(
39
- self,
40
- messages: Union[List[dict], List[str], str],
41
- metadata: Optional[Dict[str, Any]] = None,
42
- ) -> LLMResponse:
35
+ def completion(self, messages: Union[List[dict], List[str], str]) -> LLMResponse:
36
+ """
37
+ Synchronously generate completions for given list of messages.
43
38
 
44
39
  This method should be implemented to take a list of messages (as
45
40
  strings) and return a list of completions (as strings).
@@ -53,19 +48,16 @@ class LLMClient(Protocol):
53
48
  - a list of messages. Each message is a string and will be formatted
54
49
  as a user message.
55
50
  - a single message as a string which will be formatted as user message.
56
- metadata: Optional metadata to be passed to the LLM call.
57
-
58
51
  Returns:
59
52
  LLMResponse
60
53
  """
61
54
  ...
62
55
 
63
56
  async def acompletion(
64
- self,
65
- messages: Union[List[dict], List[str], str],
66
- metadata: Optional[Dict[str, Any]] = None,
57
+ self, messages: Union[List[dict], List[str], str]
67
58
  ) -> LLMResponse:
68
- """Asynchronously generate completions for given list of messages.
59
+ """
60
+ Asynchronously generate completions for given list of messages.
69
61
 
70
62
  This method should be implemented to take a list of messages (as
71
63
  strings) and return a list of completions (as strings).
@@ -79,15 +71,14 @@ class LLMClient(Protocol):
79
71
  - a list of messages. Each message is a string and will be formatted
80
72
  as a user message.
81
73
  - a single message as a string which will be formatted as user message.
82
- metadata: Optional metadata to be passed to the LLM call.
83
-
84
74
  Returns:
85
75
  LLMResponse
86
76
  """
87
77
  ...
88
78
 
89
79
  def validate_client_setup(self, *args, **kwargs) -> None: # type: ignore
90
- """Perform client setup validation.
80
+ """
81
+ Perform client setup validation.
91
82
 
92
83
  This method should be implemented to validate whether the client can be
93
84
  used with the parameters provided through configuration or environment