rasa-pro 3.12.0.dev13__py3-none-any.whl → 3.12.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (128) hide show
  1. rasa/anonymization/anonymization_rule_executor.py +16 -10
  2. rasa/cli/data.py +16 -0
  3. rasa/cli/project_templates/calm/config.yml +2 -2
  4. rasa/cli/project_templates/calm/endpoints.yml +2 -2
  5. rasa/cli/utils.py +12 -0
  6. rasa/core/actions/action.py +84 -191
  7. rasa/core/actions/action_run_slot_rejections.py +16 -4
  8. rasa/core/channels/__init__.py +2 -0
  9. rasa/core/channels/studio_chat.py +19 -0
  10. rasa/core/channels/telegram.py +42 -24
  11. rasa/core/channels/voice_ready/utils.py +1 -1
  12. rasa/core/channels/voice_stream/asr/asr_engine.py +10 -4
  13. rasa/core/channels/voice_stream/asr/azure.py +14 -1
  14. rasa/core/channels/voice_stream/asr/deepgram.py +20 -4
  15. rasa/core/channels/voice_stream/audiocodes.py +264 -0
  16. rasa/core/channels/voice_stream/browser_audio.py +4 -1
  17. rasa/core/channels/voice_stream/call_state.py +3 -0
  18. rasa/core/channels/voice_stream/genesys.py +6 -2
  19. rasa/core/channels/voice_stream/tts/azure.py +9 -1
  20. rasa/core/channels/voice_stream/tts/cartesia.py +14 -8
  21. rasa/core/channels/voice_stream/voice_channel.py +23 -2
  22. rasa/core/constants.py +2 -0
  23. rasa/core/nlg/contextual_response_rephraser.py +18 -1
  24. rasa/core/nlg/generator.py +83 -15
  25. rasa/core/nlg/response.py +6 -3
  26. rasa/core/nlg/translate.py +55 -0
  27. rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +1 -1
  28. rasa/core/policies/flows/flow_executor.py +12 -5
  29. rasa/core/processor.py +72 -9
  30. rasa/dialogue_understanding/commands/can_not_handle_command.py +20 -2
  31. rasa/dialogue_understanding/commands/cancel_flow_command.py +24 -6
  32. rasa/dialogue_understanding/commands/change_flow_command.py +20 -2
  33. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +20 -2
  34. rasa/dialogue_understanding/commands/clarify_command.py +29 -3
  35. rasa/dialogue_understanding/commands/command.py +1 -16
  36. rasa/dialogue_understanding/commands/command_syntax_manager.py +55 -0
  37. rasa/dialogue_understanding/commands/human_handoff_command.py +20 -2
  38. rasa/dialogue_understanding/commands/knowledge_answer_command.py +20 -2
  39. rasa/dialogue_understanding/commands/prompt_command.py +94 -0
  40. rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +20 -2
  41. rasa/dialogue_understanding/commands/set_slot_command.py +24 -2
  42. rasa/dialogue_understanding/commands/skip_question_command.py +20 -2
  43. rasa/dialogue_understanding/commands/start_flow_command.py +20 -2
  44. rasa/dialogue_understanding/commands/utils.py +98 -4
  45. rasa/dialogue_understanding/generator/__init__.py +2 -0
  46. rasa/dialogue_understanding/generator/command_parser.py +15 -12
  47. rasa/dialogue_understanding/generator/constants.py +3 -0
  48. rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -5
  49. rasa/dialogue_understanding/generator/llm_command_generator.py +5 -3
  50. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +16 -2
  51. rasa/dialogue_understanding/generator/prompt_templates/__init__.py +0 -0
  52. rasa/dialogue_understanding/generator/{single_step → prompt_templates}/command_prompt_template.jinja2 +2 -0
  53. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +77 -0
  54. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_default.jinja2 +68 -0
  55. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +84 -0
  56. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +460 -0
  57. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +12 -310
  58. rasa/dialogue_understanding/patterns/collect_information.py +1 -1
  59. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +16 -0
  60. rasa/dialogue_understanding/patterns/validate_slot.py +65 -0
  61. rasa/dialogue_understanding/processor/command_processor.py +39 -0
  62. rasa/dialogue_understanding_test/du_test_case.py +28 -8
  63. rasa/dialogue_understanding_test/du_test_result.py +13 -9
  64. rasa/dialogue_understanding_test/io.py +14 -0
  65. rasa/e2e_test/utils/io.py +0 -37
  66. rasa/engine/graph.py +1 -0
  67. rasa/engine/language.py +140 -0
  68. rasa/engine/recipes/config_files/default_config.yml +4 -0
  69. rasa/engine/recipes/default_recipe.py +2 -0
  70. rasa/engine/recipes/graph_recipe.py +2 -0
  71. rasa/engine/storage/local_model_storage.py +1 -0
  72. rasa/engine/storage/storage.py +4 -1
  73. rasa/model_manager/runner_service.py +7 -4
  74. rasa/model_manager/socket_bridge.py +7 -6
  75. rasa/shared/constants.py +15 -13
  76. rasa/shared/core/constants.py +2 -0
  77. rasa/shared/core/flows/constants.py +11 -0
  78. rasa/shared/core/flows/flow.py +83 -19
  79. rasa/shared/core/flows/flows_yaml_schema.json +31 -3
  80. rasa/shared/core/flows/steps/collect.py +1 -36
  81. rasa/shared/core/flows/utils.py +28 -4
  82. rasa/shared/core/flows/validation.py +1 -1
  83. rasa/shared/core/slot_mappings.py +208 -5
  84. rasa/shared/core/slots.py +131 -1
  85. rasa/shared/core/trackers.py +74 -1
  86. rasa/shared/importers/importer.py +50 -2
  87. rasa/shared/nlu/training_data/schemas/responses.yml +19 -12
  88. rasa/shared/providers/_configs/azure_entra_id_config.py +541 -0
  89. rasa/shared/providers/_configs/azure_openai_client_config.py +138 -3
  90. rasa/shared/providers/_configs/client_config.py +3 -1
  91. rasa/shared/providers/_configs/default_litellm_client_config.py +3 -1
  92. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +3 -1
  93. rasa/shared/providers/_configs/litellm_router_client_config.py +3 -1
  94. rasa/shared/providers/_configs/model_group_config.py +4 -2
  95. rasa/shared/providers/_configs/oauth_config.py +33 -0
  96. rasa/shared/providers/_configs/openai_client_config.py +3 -1
  97. rasa/shared/providers/_configs/rasa_llm_client_config.py +3 -1
  98. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +3 -1
  99. rasa/shared/providers/constants.py +6 -0
  100. rasa/shared/providers/embedding/azure_openai_embedding_client.py +28 -3
  101. rasa/shared/providers/embedding/litellm_router_embedding_client.py +3 -1
  102. rasa/shared/providers/llm/_base_litellm_client.py +42 -17
  103. rasa/shared/providers/llm/azure_openai_llm_client.py +81 -25
  104. rasa/shared/providers/llm/default_litellm_llm_client.py +3 -1
  105. rasa/shared/providers/llm/litellm_router_llm_client.py +29 -8
  106. rasa/shared/providers/llm/llm_client.py +23 -7
  107. rasa/shared/providers/llm/openai_llm_client.py +9 -3
  108. rasa/shared/providers/llm/rasa_llm_client.py +11 -2
  109. rasa/shared/providers/llm/self_hosted_llm_client.py +30 -11
  110. rasa/shared/providers/router/_base_litellm_router_client.py +3 -1
  111. rasa/shared/providers/router/router_client.py +3 -1
  112. rasa/shared/utils/constants.py +3 -0
  113. rasa/shared/utils/llm.py +30 -7
  114. rasa/shared/utils/pykwalify_extensions.py +24 -0
  115. rasa/shared/utils/schemas/domain.yml +26 -0
  116. rasa/telemetry.py +2 -1
  117. rasa/tracing/config.py +2 -0
  118. rasa/tracing/constants.py +12 -0
  119. rasa/tracing/instrumentation/instrumentation.py +36 -0
  120. rasa/tracing/instrumentation/metrics.py +41 -0
  121. rasa/tracing/metric_instrument_provider.py +40 -0
  122. rasa/validator.py +372 -7
  123. rasa/version.py +1 -1
  124. {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/METADATA +2 -1
  125. {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/RECORD +128 -113
  126. {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/NOTICE +0 -0
  127. {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/WHEEL +0 -0
  128. {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/entry_points.txt +0 -0
@@ -1,10 +1,20 @@
1
+ from __future__ import annotations
2
+
3
+ from copy import deepcopy
1
4
  from dataclasses import asdict, dataclass, field
2
- from typing import Any, Dict, Optional
5
+ from typing import (
6
+ Any,
7
+ Dict,
8
+ Optional,
9
+ Set,
10
+ )
3
11
 
4
12
  import structlog
13
+ from pydantic import BaseModel
5
14
 
6
15
  from rasa.shared.constants import (
7
16
  API_BASE_CONFIG_KEY,
17
+ API_KEY,
8
18
  API_TYPE_CONFIG_KEY,
9
19
  API_VERSION_CONFIG_KEY,
10
20
  AZURE_API_TYPE,
@@ -25,12 +35,22 @@ from rasa.shared.constants import (
25
35
  STREAM_CONFIG_KEY,
26
36
  TIMEOUT_CONFIG_KEY,
27
37
  )
38
+ from rasa.shared.providers._configs.azure_entra_id_config import (
39
+ AzureEntraIDOAuthConfig,
40
+ AzureEntraIDOAuthType,
41
+ )
42
+ from rasa.shared.providers._configs.oauth_config import (
43
+ OAUTH_KEY,
44
+ OAUTH_TYPE_FIELD,
45
+ OAuth,
46
+ )
28
47
  from rasa.shared.providers._configs.utils import (
29
48
  raise_deprecation_warnings,
30
49
  resolve_aliases,
31
50
  validate_forbidden_keys,
32
51
  validate_required_keys,
33
52
  )
53
+ from rasa.shared.utils.common import class_from_module_path
34
54
 
35
55
  structlogger = structlog.get_logger()
36
56
 
@@ -61,6 +81,86 @@ FORBIDDEN_KEYS = [
61
81
  ]
62
82
 
63
83
 
84
+ class OAuthConfigWrapper(OAuth, BaseModel):
85
+ """Wrapper for OAuth configuration.
86
+
87
+ Its main purpose is to provide a to_dict method which is used to serialize
88
+ the oauth configuration to the original format.
89
+
90
+ """
91
+
92
+ # Pydantic configuration to allow arbitrary user defined types
93
+ class Config:
94
+ arbitrary_types_allowed = True
95
+
96
+ oauth: OAuth
97
+ original_config: Dict[str, Any]
98
+
99
+ def get_bearer_token(self) -> str:
100
+ """Returns a bearer token."""
101
+ return self.oauth.get_bearer_token()
102
+
103
+ def to_dict(self) -> Dict[str, Any]:
104
+ """Converts the OAuth configuration to the original format."""
105
+ return self.original_config
106
+
107
+ @staticmethod
108
+ def _valid_type_values() -> Set[str]:
109
+ """Returns the valid built-in values for the `type` field in the `oauth`."""
110
+ return AzureEntraIDOAuthType.valid_string_values()
111
+
112
+ @classmethod
113
+ def from_dict(cls, oauth_config: Dict[str, Any]) -> OAuthConfigWrapper:
114
+ """Initializes the OAuth config wrapper from the passed config.
115
+
116
+ Args:
117
+ oauth_config: (dict) The config from which to initialize.
118
+
119
+ Returns:
120
+ OAuthConfigWrapper
121
+ """
122
+ original_config = deepcopy(oauth_config)
123
+
124
+ oauth_type: Optional[str] = oauth_config.get(OAUTH_TYPE_FIELD, None)
125
+
126
+ if oauth_type is None:
127
+ message = (
128
+ "Oauth configuration must contain "
129
+ f"'{OAUTH_TYPE_FIELD}' field and it must be set to one of the "
130
+ f"following values: {OAuthConfigWrapper._valid_type_values()}, "
131
+ f"or to the path of module which is "
132
+ f"implementing {OAuth.__name__} protocol."
133
+ )
134
+ structlogger.error(
135
+ "azure_oauth_config.missing_oauth_type",
136
+ message=message,
137
+ )
138
+ raise ValueError(message)
139
+
140
+ if oauth_type in AzureEntraIDOAuthType.valid_string_values():
141
+ return cls(
142
+ oauth=AzureEntraIDOAuthConfig.from_dict(oauth_config),
143
+ original_config=original_config,
144
+ )
145
+
146
+ module = class_from_module_path(oauth_type)
147
+
148
+ if not issubclass(module, OAuth):
149
+ message = (
150
+ f"Module {oauth_type} does not implement "
151
+ f"{OAuth.__name__} interface."
152
+ )
153
+ structlogger.error(
154
+ "azure_oauth_config.invalid_oauth_module",
155
+ message=message,
156
+ )
157
+ raise ValueError(message)
158
+
159
+ return cls(
160
+ oauth=module.from_dict(oauth_config), original_config=original_config
161
+ )
162
+
163
+
64
164
  @dataclass
65
165
  class AzureOpenAIClientConfig:
66
166
  """Parses configuration for Azure OpenAI client, resolves aliases and
@@ -80,11 +180,13 @@ class AzureOpenAIClientConfig:
80
180
  # API Type is not used by LiteLLM backend, but we define
81
181
  # it here for backward compatibility.
82
182
  api_type: Optional[str] = AZURE_API_TYPE
83
-
84
183
  # Provider is not used by LiteLLM backend, but we define it here since it's
85
184
  # used as switch between different clients.
86
185
  provider: str = AZURE_OPENAI_PROVIDER
87
186
 
187
+ # OAuth related parameters
188
+ oauth: Optional[OAuthConfigWrapper] = None
189
+
88
190
  extra_parameters: dict = field(default_factory=dict)
89
191
 
90
192
  def __post_init__(self) -> None:
@@ -106,7 +208,7 @@ class AzureOpenAIClientConfig:
106
208
  raise ValueError(message)
107
209
 
108
210
  @classmethod
109
- def from_dict(cls, config: dict) -> "AzureOpenAIClientConfig":
211
+ def from_dict(cls, config: dict) -> AzureOpenAIClientConfig:
110
212
  """Initializes a dataclass from the passed config.
111
213
 
112
214
  Args:
@@ -129,6 +231,16 @@ class AzureOpenAIClientConfig:
129
231
  # Validate that the forbidden keys are not present
130
232
  validate_forbidden_keys(config, FORBIDDEN_KEYS)
131
233
  # Init client config
234
+
235
+ cls._validate_authentication_configuration(config)
236
+
237
+ has_oauth_key = config.get(OAUTH_KEY, None) is not None
238
+ oauth = (
239
+ OAuthConfigWrapper.from_dict(config.pop(OAUTH_KEY))
240
+ if has_oauth_key
241
+ else None
242
+ )
243
+
132
244
  this = AzureOpenAIClientConfig(
133
245
  # Required parameters
134
246
  deployment=config.pop(DEPLOYMENT_CONFIG_KEY),
@@ -142,6 +254,8 @@ class AzureOpenAIClientConfig:
142
254
  # in clients.
143
255
  api_base=config.pop(API_BASE_CONFIG_KEY, None),
144
256
  api_version=config.pop(API_VERSION_CONFIG_KEY, None),
257
+ # OAuth related parameters, set only if the 'oauth' key is present in the config
258
+ oauth=oauth,
145
259
  # The rest of parameters (e.g. model parameters) are considered
146
260
  # as extra parameters (this also includes timeout).
147
261
  extra_parameters=config,
@@ -154,12 +268,33 @@ class AzureOpenAIClientConfig:
154
268
  # Extra parameters should also be on the top level
155
269
  d.pop("extra_parameters", None)
156
270
  d.update(self.extra_parameters)
271
+
272
+ d.pop("oauth", None)
273
+ d.update({"oauth": self.oauth.to_dict()} if self.oauth else {})
157
274
  return d
158
275
 
159
276
  @staticmethod
160
277
  def resolve_config_aliases(config: Dict[str, Any]) -> Dict[str, Any]:
161
278
  return resolve_aliases(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
162
279
 
280
+ @staticmethod
281
+ def _validate_authentication_configuration(config: Dict[str, Any]) -> None:
282
+ """Validates the authentication configuration."""
283
+ has_api_key = config.get(API_KEY, None) is not None
284
+ has_oauth_key = config.get(OAUTH_KEY, None) is not None
285
+
286
+ if has_api_key and has_oauth_key:
287
+ message = (
288
+ "Azure OpenAI client configuration cannot contain "
289
+ f"both '{API_KEY}' and '{OAUTH_KEY}' fields. Please provide either "
290
+ f"'{API_KEY}' or '{OAUTH_KEY}' fields."
291
+ )
292
+ structlogger.error(
293
+ "azure_openai_client_config.multiple_auth_types_specified",
294
+ message=message,
295
+ )
296
+ raise ValueError(message)
297
+
163
298
 
164
299
  def is_azure_openai_config(config: dict) -> bool:
165
300
  """Check whether the configuration is meant to configure
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from typing import Protocol, runtime_checkable
2
4
 
3
5
 
@@ -9,7 +11,7 @@ class ClientConfig(Protocol):
9
11
  """
10
12
 
11
13
  @classmethod
12
- def from_dict(cls, config: dict) -> "ClientConfig":
14
+ def from_dict(cls, config: dict) -> ClientConfig:
13
15
  """
14
16
  Initializes the client config with the given configuration.
15
17
 
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from dataclasses import asdict, dataclass, field
2
4
  from typing import Any, Dict
3
5
 
@@ -69,7 +71,7 @@ class DefaultLiteLLMClientConfig:
69
71
  raise ValueError(message)
70
72
 
71
73
  @classmethod
72
- def from_dict(cls, config: dict) -> "DefaultLiteLLMClientConfig":
74
+ def from_dict(cls, config: dict) -> DefaultLiteLLMClientConfig:
73
75
  """
74
76
  Initializes a dataclass from the passed config.
75
77
 
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from dataclasses import asdict, dataclass, field
2
4
  from typing import Any, Dict, Optional
3
5
 
@@ -90,7 +92,7 @@ class HuggingFaceLocalEmbeddingClientConfig:
90
92
  raise ValueError(message)
91
93
 
92
94
  @classmethod
93
- def from_dict(cls, config: dict) -> "HuggingFaceLocalEmbeddingClientConfig":
95
+ def from_dict(cls, config: dict) -> HuggingFaceLocalEmbeddingClientConfig:
94
96
  """
95
97
  Initializes a dataclass from the passed config.
96
98
 
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import copy
2
4
  from dataclasses import dataclass, field
3
5
  from typing import Any, Dict, List
@@ -120,7 +122,7 @@ class LiteLLMRouterClientConfig:
120
122
  raise ValueError(message)
121
123
 
122
124
  @classmethod
123
- def from_dict(cls, config: dict) -> "LiteLLMRouterClientConfig":
125
+ def from_dict(cls, config: dict) -> LiteLLMRouterClientConfig:
124
126
  """Initializes a dataclass from the passed config.
125
127
 
126
128
  Args:
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from dataclasses import asdict, dataclass, field
2
4
  from typing import List, Optional
3
5
 
@@ -41,7 +43,7 @@ class ModelConfig:
41
43
  api_type: Optional[str] = None
42
44
 
43
45
  @classmethod
44
- def from_dict(cls, config: dict) -> "ModelConfig":
46
+ def from_dict(cls, config: dict) -> ModelConfig:
45
47
  """Initializes a dataclass from the passed config. The provider config param is
46
48
  used to determine the client config class to use. The client config class takes
47
49
  care of resolving config aliases and throwing deprecation warnings.
@@ -131,7 +133,7 @@ class ModelGroupConfig:
131
133
  raise ValueError(message)
132
134
 
133
135
  @classmethod
134
- def from_dict(cls, config: dict) -> "ModelGroupConfig":
136
+ def from_dict(cls, config: dict) -> ModelGroupConfig:
135
137
  """Initializes a dataclass from the passed config.
136
138
 
137
139
  Args:
@@ -0,0 +1,33 @@
1
+ import abc
2
+ from typing import Any, Dict, Type, TypeVar
3
+
4
+ OAUTH_TYPE_FIELD = "type"
5
+ OAUTH_KEY = "oauth"
6
+
7
+ OAuthType = TypeVar("OAuthType", bound="OAuth")
8
+
9
+
10
+ class OAuth(abc.ABC):
11
+ """Interface for OAuth configuration."""
12
+
13
+ @classmethod
14
+ @abc.abstractmethod
15
+ def from_dict(
16
+ cls: Type[OAuthType], config: Dict[str, Any]
17
+ ) -> OAuthType: # ignore[type]
18
+ """Initializes a dataclass from the passed config.
19
+
20
+ Args:
21
+ config: (dict) The config from which to initialize.
22
+
23
+ Returns:
24
+ OAuth
25
+ """
26
+
27
+ @abc.abstractmethod
28
+ def get_bearer_token(self) -> str:
29
+ """Returns a bearer token.
30
+
31
+ The bearer token is used to authenticate requests to the Azure
32
+ OpenAI instance's API protected by the Gateway.
33
+ """
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from dataclasses import asdict, dataclass, field
2
4
  from typing import Any, Dict, Optional
3
5
 
@@ -111,7 +113,7 @@ class OpenAIClientConfig:
111
113
  raise ValueError(message)
112
114
 
113
115
  @classmethod
114
- def from_dict(cls, config: dict) -> "OpenAIClientConfig":
116
+ def from_dict(cls, config: dict) -> OpenAIClientConfig:
115
117
  """
116
118
  Initializes a dataclass from the passed config.
117
119
 
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from dataclasses import asdict, dataclass, field
2
4
  from typing import Optional
3
5
 
@@ -37,7 +39,7 @@ class RasaLLMClientConfig:
37
39
  extra_parameters: dict = field(default_factory=dict)
38
40
 
39
41
  @classmethod
40
- def from_dict(cls, config: dict) -> "RasaLLMClientConfig":
42
+ def from_dict(cls, config: dict) -> RasaLLMClientConfig:
41
43
  """
42
44
  Initializes a dataclass from the passed config.
43
45
 
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from dataclasses import asdict, dataclass, field
2
4
  from typing import Any, Dict, Optional
3
5
 
@@ -113,7 +115,7 @@ class SelfHostedLLMClientConfig:
113
115
  raise ValueError(message)
114
116
 
115
117
  @classmethod
116
- def from_dict(cls, config: dict) -> "SelfHostedLLMClientConfig":
118
+ def from_dict(cls, config: dict) -> SelfHostedLLMClientConfig:
117
119
  """
118
120
  Initializes a dataclass from the passed config.
119
121
 
@@ -0,0 +1,6 @@
1
+ DEFAULT_AZURE_API_KEY_NAME = "AZURE_API_KEY"
2
+ LITE_LLM_API_BASE_FIELD = "api_base"
3
+ LITE_LLM_API_KEY_FIELD = "api_key"
4
+ LITE_LLM_API_VERSION_FIELD = "api_version"
5
+ LITE_LLM_MODEL_FIELD = "model"
6
+ LITE_LLM_AZURE_AD_TOKEN = "azure_ad_token"
@@ -19,8 +19,14 @@ from rasa.shared.constants import (
19
19
  )
20
20
  from rasa.shared.exceptions import ProviderClientValidationError
21
21
  from rasa.shared.providers._configs.azure_openai_client_config import (
22
+ AzureEntraIDOAuthConfig,
22
23
  AzureOpenAIClientConfig,
23
24
  )
25
+ from rasa.shared.providers.constants import (
26
+ DEFAULT_AZURE_API_KEY_NAME,
27
+ LITE_LLM_API_KEY_FIELD,
28
+ LITE_LLM_AZURE_AD_TOKEN,
29
+ )
24
30
  from rasa.shared.providers.embedding._base_litellm_embedding_client import (
25
31
  _BaseLiteLLMEmbeddingClient,
26
32
  )
@@ -41,6 +47,8 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
41
47
  If not provided, it will be set via environment variable.
42
48
  api_version (Optional[str]): The version of the API to use.
43
49
  If not provided, it will be set via environment variable.
50
+ oauth (Optional[AzureEntraIDOAuthConfig]): Optional OAuth configuration.
51
+ If provided, the client will use OAuth for authentication.
44
52
  kwargs (Optional[Dict[str, Any]]): Optional configuration parameters specific
45
53
  to the embedding model deployment.
46
54
 
@@ -57,6 +65,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
57
65
  api_base: Optional[str] = None,
58
66
  api_type: Optional[str] = None,
59
67
  api_version: Optional[str] = None,
68
+ oauth: Optional[AzureEntraIDOAuthConfig] = None,
60
69
  **kwargs: Any,
61
70
  ):
62
71
  super().__init__() # type: ignore
@@ -84,7 +93,11 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
84
93
  # Litellm does not support use of OPENAI_API_KEY, so we need to map it
85
94
  # because of backward compatibility. However, we're first looking at
86
95
  # AZURE_API_KEY.
87
- self._api_key_env_var = self._resolve_api_key_env_var()
96
+
97
+ self._oauth = oauth
98
+ self._api_key_env_var = (
99
+ self._resolve_api_key_env_var() if not self._oauth else None
100
+ )
88
101
 
89
102
  self.validate_client_setup()
90
103
 
@@ -100,7 +113,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
100
113
  return self._extra_parameters[API_KEY]
101
114
 
102
115
  if os.getenv(AZURE_API_KEY_ENV_VAR) is not None:
103
- return "${AZURE_API_KEY}"
116
+ return f"${{{DEFAULT_AZURE_API_KEY_NAME}}}"
104
117
 
105
118
  if os.getenv(OPENAI_API_KEY_ENV_VAR) is not None:
106
119
  # API key can be set through OPENAI_API_KEY too,
@@ -163,6 +176,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
163
176
  api_base=azure_openai_config.api_base,
164
177
  api_type=azure_openai_config.api_type,
165
178
  api_version=azure_openai_config.api_version,
179
+ oauth=azure_openai_config.oauth,
166
180
  **azure_openai_config.extra_parameters,
167
181
  )
168
182
 
@@ -177,6 +191,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
177
191
  api_base=self.api_base,
178
192
  api_type=self.api_type,
179
193
  api_version=self.api_version,
194
+ oauth=self._oauth,
180
195
  extra_parameters=self._extra_parameters,
181
196
  )
182
197
  return config.to_dict()
@@ -219,13 +234,23 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
219
234
 
220
235
  @property
221
236
  def _embedding_fn_args(self) -> dict:
237
+ auth_parameter: Dict[str, str] = {}
238
+
239
+ if self._oauth:
240
+ auth_parameter = {
241
+ **auth_parameter,
242
+ LITE_LLM_AZURE_AD_TOKEN: self._oauth.get_bearer_token(),
243
+ }
244
+ elif self._api_key_env_var:
245
+ auth_parameter = {LITE_LLM_API_KEY_FIELD: self._api_key_env_var}
246
+
222
247
  return {
223
248
  **self._litellm_extra_parameters,
224
249
  "model": self._litellm_model_name,
225
250
  "api_base": self.api_base,
226
251
  "api_type": self.api_type,
227
252
  "api_version": self.api_version,
228
- "api_key": self._api_key_env_var,
253
+ **auth_parameter,
229
254
  }
230
255
 
231
256
  @property
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from typing import Any, Dict, List
3
5
 
@@ -46,7 +48,7 @@ class LiteLLMRouterEmbeddingClient(
46
48
  )
47
49
 
48
50
  @classmethod
49
- def from_config(cls, config: Dict[str, Any]) -> "LiteLLMRouterEmbeddingClient":
51
+ def from_config(cls, config: Dict[str, Any]) -> LiteLLMRouterEmbeddingClient:
50
52
  """Instantiates a LiteLLM Router Embedding client from a configuration dict.
51
53
 
52
54
  Args:
@@ -1,15 +1,13 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from abc import abstractmethod
3
- from typing import Any, Dict, List, Union
5
+ from typing import Any, Dict, List, Union, cast
4
6
 
5
7
  import structlog
6
- from litellm import (
7
- acompletion,
8
- completion,
9
- validate_environment,
10
- )
8
+ from litellm import acompletion, completion, validate_environment
11
9
 
12
- from rasa.shared.constants import API_BASE_CONFIG_KEY, API_KEY
10
+ from rasa.shared.constants import API_BASE_CONFIG_KEY, API_KEY, ROLE_USER
13
11
  from rasa.shared.exceptions import (
14
12
  ProviderClientAPIException,
15
13
  ProviderClientValidationError,
@@ -50,7 +48,7 @@ class _BaseLiteLLMClient:
50
48
 
51
49
  @classmethod
52
50
  @abstractmethod
53
- def from_config(cls, config: Dict[str, Any]) -> "_BaseLiteLLMClient":
51
+ def from_config(cls, config: Dict[str, Any]) -> _BaseLiteLLMClient:
54
52
  pass
55
53
 
56
54
  @property
@@ -122,12 +120,18 @@ class _BaseLiteLLMClient:
122
120
  raise ProviderClientValidationError(event_info)
123
121
 
124
122
  @suppress_logs(log_level=logging.WARNING)
125
- def completion(self, messages: Union[List[str], str]) -> LLMResponse:
123
+ def completion(self, messages: Union[List[dict], List[str], str]) -> LLMResponse:
126
124
  """Synchronously generate completions for given list of messages.
127
125
 
128
126
  Args:
129
- messages: List of messages or a single message to generate the
130
- completion for.
127
+ messages: The message can be,
128
+ - a list of preformatted messages. Each message should be a dictionary
129
+ with the following keys:
130
+ - content: The message content.
131
+ - role: The role of the message (e.g. user or system).
132
+ - a list of messages. Each message is a string and will be formatted
133
+ as a user message.
134
+ - a single message as a string which will be formatted as user message.
131
135
 
132
136
  Returns:
133
137
  List of message completions.
@@ -136,7 +140,7 @@ class _BaseLiteLLMClient:
136
140
  ProviderClientAPIException: If the API request fails.
137
141
  """
138
142
  try:
139
- formatted_messages = self._format_messages(messages)
143
+ formatted_messages = self._get_formatted_messages(messages)
140
144
  arguments = resolve_environment_variables(self._completion_fn_args)
141
145
  response = completion(messages=formatted_messages, **arguments)
142
146
  return self._format_response(response)
@@ -144,12 +148,20 @@ class _BaseLiteLLMClient:
144
148
  raise ProviderClientAPIException(e)
145
149
 
146
150
  @suppress_logs(log_level=logging.WARNING)
147
- async def acompletion(self, messages: Union[List[str], str]) -> LLMResponse:
151
+ async def acompletion(
152
+ self, messages: Union[List[dict], List[str], str]
153
+ ) -> LLMResponse:
148
154
  """Asynchronously generate completions for given list of messages.
149
155
 
150
156
  Args:
151
- messages: List of messages or a single message to generate the
152
- completion for.
157
+ messages: The message can be,
158
+ - a list of preformatted messages. Each message should be a dictionary
159
+ with the following keys:
160
+ - content: The message content.
161
+ - role: The role of the message (e.g. user or system).
162
+ - a list of messages. Each message is a string and will be formatted
163
+ as a user message.
164
+ - a single message as a string which will be formatted as user message.
153
165
 
154
166
  Returns:
155
167
  List of message completions.
@@ -158,7 +170,7 @@ class _BaseLiteLLMClient:
158
170
  ProviderClientAPIException: If the API request fails.
159
171
  """
160
172
  try:
161
- formatted_messages = self._format_messages(messages)
173
+ formatted_messages = self._get_formatted_messages(messages)
162
174
  arguments = resolve_environment_variables(self._completion_fn_args)
163
175
  response = await acompletion(messages=formatted_messages, **arguments)
164
176
  return self._format_response(response)
@@ -181,11 +193,24 @@ class _BaseLiteLLMClient:
181
193
  )
182
194
  raise ProviderClientAPIException(e, message)
183
195
 
196
+ def _get_formatted_messages(
197
+ self, messages: Union[List[dict], List[str], str]
198
+ ) -> List[Dict[str, str]]:
199
+ """Returns a list of formatted messages."""
200
+ if (
201
+ isinstance(messages, list)
202
+ and len(messages) > 0
203
+ and isinstance(messages[0], dict)
204
+ ):
205
+ # Check if the messages are already formatted. If so, return them as is.
206
+ return cast(List[Dict[str, str]], messages)
207
+ return self._format_messages(messages)
208
+
184
209
  def _format_messages(self, messages: Union[List[str], str]) -> List[Dict[str, str]]:
185
210
  """Formats messages (or a single message) to OpenAI format."""
186
211
  if isinstance(messages, str):
187
212
  messages = [messages]
188
- return [{"content": message, "role": "user"} for message in messages]
213
+ return [{"content": message, "role": ROLE_USER} for message in messages]
189
214
 
190
215
  def _format_response(self, response: Any) -> LLMResponse:
191
216
  """Parses the LiteLLM response to Rasa format."""