rasa-pro 3.11.3a1.dev7__py3-none-any.whl → 3.12.0.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. (The original page linked to further details; that link was not preserved in this copy.)

Files changed (65)
  1. rasa/cli/arguments/default_arguments.py +1 -1
  2. rasa/cli/dialogue_understanding_test.py +251 -0
  3. rasa/core/actions/action.py +7 -16
  4. rasa/core/channels/__init__.py +0 -2
  5. rasa/core/channels/socketio.py +23 -1
  6. rasa/core/nlg/contextual_response_rephraser.py +9 -62
  7. rasa/core/policies/enterprise_search_policy.py +12 -77
  8. rasa/core/policies/flows/flow_executor.py +2 -26
  9. rasa/core/processor.py +8 -11
  10. rasa/dialogue_understanding/generator/command_generator.py +49 -43
  11. rasa/dialogue_understanding/generator/llm_based_command_generator.py +5 -5
  12. rasa/dialogue_understanding/generator/llm_command_generator.py +1 -2
  13. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +15 -34
  14. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +6 -11
  15. rasa/dialogue_understanding/utils.py +1 -8
  16. rasa/dialogue_understanding_test/command_metric_calculation.py +12 -0
  17. rasa/dialogue_understanding_test/constants.py +2 -0
  18. rasa/dialogue_understanding_test/du_test_runner.py +93 -0
  19. rasa/dialogue_understanding_test/io.py +54 -0
  20. rasa/dialogue_understanding_test/validation.py +22 -0
  21. rasa/e2e_test/e2e_test_runner.py +9 -7
  22. rasa/hooks.py +9 -15
  23. rasa/model_manager/socket_bridge.py +2 -7
  24. rasa/model_manager/warm_rasa_process.py +4 -9
  25. rasa/plugin.py +0 -11
  26. rasa/shared/constants.py +2 -21
  27. rasa/shared/core/events.py +8 -8
  28. rasa/shared/nlu/constants.py +0 -3
  29. rasa/shared/providers/_configs/azure_entra_id_client_creds.py +40 -0
  30. rasa/shared/providers/_configs/azure_entra_id_config.py +533 -0
  31. rasa/shared/providers/_configs/azure_openai_client_config.py +131 -15
  32. rasa/shared/providers/_configs/client_config.py +3 -1
  33. rasa/shared/providers/_configs/default_litellm_client_config.py +9 -7
  34. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +13 -11
  35. rasa/shared/providers/_configs/litellm_router_client_config.py +12 -10
  36. rasa/shared/providers/_configs/model_group_config.py +11 -5
  37. rasa/shared/providers/_configs/oauth_config.py +33 -0
  38. rasa/shared/providers/_configs/openai_client_config.py +14 -12
  39. rasa/shared/providers/_configs/rasa_llm_client_config.py +5 -3
  40. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +12 -11
  41. rasa/shared/providers/constants.py +6 -0
  42. rasa/shared/providers/embedding/azure_openai_embedding_client.py +30 -7
  43. rasa/shared/providers/embedding/litellm_router_embedding_client.py +5 -2
  44. rasa/shared/providers/llm/_base_litellm_client.py +6 -4
  45. rasa/shared/providers/llm/azure_openai_llm_client.py +88 -34
  46. rasa/shared/providers/llm/default_litellm_llm_client.py +4 -2
  47. rasa/shared/providers/llm/litellm_router_llm_client.py +23 -3
  48. rasa/shared/providers/llm/llm_client.py +4 -2
  49. rasa/shared/providers/llm/llm_response.py +1 -42
  50. rasa/shared/providers/llm/openai_llm_client.py +11 -5
  51. rasa/shared/providers/llm/rasa_llm_client.py +13 -5
  52. rasa/shared/providers/llm/self_hosted_llm_client.py +17 -10
  53. rasa/shared/providers/router/_base_litellm_router_client.py +10 -8
  54. rasa/shared/providers/router/router_client.py +3 -1
  55. rasa/shared/utils/llm.py +16 -12
  56. rasa/shared/utils/schemas/events.py +1 -1
  57. rasa/tracing/instrumentation/attribute_extractors.py +0 -2
  58. rasa/version.py +1 -1
  59. {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/METADATA +2 -1
  60. {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/RECORD +63 -56
  61. rasa/core/channels/studio_chat.py +0 -192
  62. rasa/dialogue_understanding/constants.py +0 -1
  63. {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/NOTICE +0 -0
  64. {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/WHEEL +0 -0
  65. {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/entry_points.txt +0 -0
@@ -4,23 +4,29 @@ from typing import Any, Dict, List, Optional
4
4
  import structlog
5
5
 
6
6
  from rasa.shared.constants import (
7
+ API_BASE_CONFIG_KEY,
8
+ API_KEY,
9
+ API_VERSION_CONFIG_KEY,
7
10
  AZURE_API_BASE_ENV_VAR,
8
11
  AZURE_API_KEY_ENV_VAR,
9
12
  AZURE_API_TYPE_ENV_VAR,
10
13
  AZURE_API_VERSION_ENV_VAR,
14
+ AZURE_OPENAI_PROVIDER,
11
15
  OPENAI_API_BASE_ENV_VAR,
12
16
  OPENAI_API_KEY_ENV_VAR,
13
17
  OPENAI_API_TYPE_ENV_VAR,
14
18
  OPENAI_API_VERSION_ENV_VAR,
15
- API_BASE_CONFIG_KEY,
16
- API_KEY,
17
- API_VERSION_CONFIG_KEY,
18
- AZURE_OPENAI_PROVIDER,
19
19
  )
20
20
  from rasa.shared.exceptions import ProviderClientValidationError
21
21
  from rasa.shared.providers._configs.azure_openai_client_config import (
22
+ AzureEntraIDOAuthConfig,
22
23
  AzureOpenAIClientConfig,
23
24
  )
25
+ from rasa.shared.providers.constants import (
26
+ DEFAULT_AZURE_API_KEY_NAME,
27
+ LITE_LLM_API_KEY_FIELD,
28
+ LITE_LLM_AZURE_AD_TOKEN,
29
+ )
24
30
  from rasa.shared.providers.embedding._base_litellm_embedding_client import (
25
31
  _BaseLiteLLMEmbeddingClient,
26
32
  )
@@ -41,6 +47,8 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
41
47
  If not provided, it will be set via environment variable.
42
48
  api_version (Optional[str]): The version of the API to use.
43
49
  If not provided, it will be set via environment variable.
50
+ oauth (Optional[AzureEntraIDOAuthConfig]): Optional OAuth configuration. If provided,
51
+ the client will use OAuth for authentication.
44
52
  kwargs (Optional[Dict[str, Any]]): Optional configuration parameters specific
45
53
  to the embedding model deployment.
46
54
 
@@ -57,6 +65,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
57
65
  api_base: Optional[str] = None,
58
66
  api_type: Optional[str] = None,
59
67
  api_version: Optional[str] = None,
68
+ oauth: Optional[AzureEntraIDOAuthConfig] = None,
60
69
  **kwargs: Any,
61
70
  ):
62
71
  super().__init__() # type: ignore
@@ -84,7 +93,11 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
84
93
  # Litellm does not support use of OPENAI_API_KEY, so we need to map it
85
94
  # because of backward compatibility. However, we're first looking at
86
95
  # AZURE_API_KEY.
87
- self._api_key_env_var = self._resolve_api_key_env_var()
96
+
97
+ self._oauth = oauth
98
+ self._api_key_env_var = (
99
+ self._resolve_api_key_env_var() if not self._oauth else None
100
+ )
88
101
 
89
102
  self.validate_client_setup()
90
103
 
@@ -100,7 +113,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
100
113
  return self._extra_parameters[API_KEY]
101
114
 
102
115
  if os.getenv(AZURE_API_KEY_ENV_VAR) is not None:
103
- return "${AZURE_API_KEY}"
116
+ return f"${DEFAULT_AZURE_API_KEY_NAME}"
104
117
 
105
118
  if os.getenv(OPENAI_API_KEY_ENV_VAR) is not None:
106
119
  # API key can be set through OPENAI_API_KEY too,
@@ -163,6 +176,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
163
176
  api_base=azure_openai_config.api_base,
164
177
  api_type=azure_openai_config.api_type,
165
178
  api_version=azure_openai_config.api_version,
179
+ oauth=azure_openai_config.oauth,
166
180
  **azure_openai_config.extra_parameters,
167
181
  )
168
182
 
@@ -177,6 +191,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
177
191
  api_base=self.api_base,
178
192
  api_type=self.api_type,
179
193
  api_version=self.api_version,
194
+ oauth=self._oauth,
180
195
  extra_parameters=self._extra_parameters,
181
196
  )
182
197
  return config.to_dict()
@@ -219,13 +234,21 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
219
234
 
220
235
  @property
221
236
  def _embedding_fn_args(self) -> dict:
237
+ auth_parameter = (
238
+ {
239
+ LITE_LLM_AZURE_AD_TOKEN: self._oauth.get_bearer_token(),
240
+ }
241
+ if self._oauth
242
+ else {LITE_LLM_API_KEY_FIELD: self._api_key_env_var}
243
+ )
244
+
222
245
  return {
223
246
  **self._litellm_extra_parameters,
224
247
  "model": self._litellm_model_name,
225
248
  "api_base": self.api_base,
226
249
  "api_type": self.api_type,
227
250
  "api_version": self.api_version,
228
- "api_key": self._api_key_env_var,
251
+ **auth_parameter,
229
252
  }
230
253
 
231
254
  @property
@@ -1,5 +1,8 @@
1
- from typing import Any, Dict, List
1
+ from __future__ import annotations
2
+
2
3
  import logging
4
+ from typing import Any, Dict, List
5
+
3
6
  import structlog
4
7
 
5
8
  from rasa.shared.exceptions import ProviderClientAPIException
@@ -45,7 +48,7 @@ class LiteLLMRouterEmbeddingClient(
45
48
  )
46
49
 
47
50
  @classmethod
48
- def from_config(cls, config: Dict[str, Any]) -> "LiteLLMRouterEmbeddingClient":
51
+ def from_config(cls, config: Dict[str, Any]) -> LiteLLMRouterEmbeddingClient:
49
52
  """Instantiates a LiteLLM Router Embedding client from a configuration dict.
50
53
 
51
54
  Args:
@@ -1,11 +1,13 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from abc import abstractmethod
3
- from typing import Dict, List, Any, Union
5
+ from typing import Any, Dict, List, Union
4
6
 
5
7
  import structlog
6
8
  from litellm import (
7
- completion,
8
9
  acompletion,
10
+ completion,
9
11
  validate_environment,
10
12
  )
11
13
 
@@ -19,7 +21,7 @@ from rasa.shared.providers._ssl_verification_utils import (
19
21
  ensure_ssl_certificates_for_litellm_openai_based_clients,
20
22
  )
21
23
  from rasa.shared.providers.llm.llm_response import LLMResponse, LLMUsage
22
- from rasa.shared.utils.io import suppress_logs, resolve_environment_variables
24
+ from rasa.shared.utils.io import resolve_environment_variables, suppress_logs
23
25
 
24
26
  structlogger = structlog.get_logger()
25
27
 
@@ -50,7 +52,7 @@ class _BaseLiteLLMClient:
50
52
 
51
53
  @classmethod
52
54
  @abstractmethod
53
- def from_config(cls, config: Dict[str, Any]) -> "_BaseLiteLLMClient":
55
+ def from_config(cls, config: Dict[str, Any]) -> _BaseLiteLLMClient:
54
56
  pass
55
57
 
56
58
  @property
@@ -1,33 +1,60 @@
1
+ from __future__ import annotations
2
+
1
3
  import os
2
4
  import re
3
- from typing import Dict, Any, Optional
5
+ from typing import Any, Dict, Optional
4
6
 
5
7
  import structlog
6
8
 
7
9
  from rasa.shared.constants import (
8
- OPENAI_API_BASE_ENV_VAR,
9
- OPENAI_API_VERSION_ENV_VAR,
10
- AZURE_API_BASE_ENV_VAR,
11
- AZURE_API_VERSION_ENV_VAR,
12
10
  API_BASE_CONFIG_KEY,
11
+ API_KEY,
13
12
  API_VERSION_CONFIG_KEY,
14
- DEPLOYMENT_CONFIG_KEY,
13
+ AZURE_API_BASE_ENV_VAR,
15
14
  AZURE_API_KEY_ENV_VAR,
16
- OPENAI_API_TYPE_ENV_VAR,
17
- OPENAI_API_KEY_ENV_VAR,
18
15
  AZURE_API_TYPE_ENV_VAR,
16
+ AZURE_API_VERSION_ENV_VAR,
19
17
  AZURE_OPENAI_PROVIDER,
20
- API_KEY,
18
+ DEPLOYMENT_CONFIG_KEY,
19
+ OPENAI_API_BASE_ENV_VAR,
20
+ OPENAI_API_KEY_ENV_VAR,
21
+ OPENAI_API_TYPE_ENV_VAR,
22
+ OPENAI_API_VERSION_ENV_VAR,
21
23
  )
22
24
  from rasa.shared.exceptions import ProviderClientValidationError
23
25
  from rasa.shared.providers._configs.azure_openai_client_config import (
26
+ AzureEntraIDOAuthConfig,
24
27
  AzureOpenAIClientConfig,
25
28
  )
29
+ from rasa.shared.providers.constants import (
30
+ DEFAULT_AZURE_API_KEY_NAME,
31
+ LITE_LLM_API_BASE_FIELD,
32
+ LITE_LLM_API_KEY_FIELD,
33
+ LITE_LLM_API_VERSION_FIELD,
34
+ LITE_LLM_AZURE_AD_TOKEN,
35
+ )
26
36
  from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
27
37
  from rasa.shared.utils.io import raise_deprecation_warning
28
38
 
29
39
  structlogger = structlog.get_logger()
30
40
 
41
+ AZURE_CLIENT_ID = "AZURE_CLIENT_ID"
42
+ AZURE_CLIENT_SECRET = "AZURE_CLIENT_SECRET"
43
+ AZURE_TENANT_ID = "AZURE_TENANT_ID"
44
+ CLIENT_SECRET_VARS = (AZURE_CLIENT_ID, AZURE_CLIENT_SECRET, AZURE_TENANT_ID)
45
+
46
+ AZURE_CLIENT_CERTIFICATE_PATH = "AZURE_CLIENT_CERTIFICATE_PATH"
47
+ AZURE_CLIENT_CERTIFICATE_PASSWORD = "AZURE_CLIENT_CERTIFICATE_PASSWORD"
48
+ AZURE_CLIENT_SEND_CERTIFICATE_CHAIN = "AZURE_CLIENT_SEND_CERTIFICATE_CHAIN"
49
+ CERT_VARS = (AZURE_CLIENT_ID, AZURE_CLIENT_CERTIFICATE_PATH, AZURE_TENANT_ID)
50
+
51
+
52
+ class AzureADConfig:
53
+ def __init__(
54
+ self, client_id: str, client_secret: str, tenant_id: str, scopes: str
55
+ ) -> None:
56
+ self.scopes = scopes
57
+
31
58
 
32
59
  class AzureOpenAILLMClient(_BaseLiteLLMClient):
33
60
  """A client for interfacing with Azure's OpenAI LLM deployments.
@@ -41,6 +68,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
41
68
  it will be set via environment variables.
42
69
  api_version (Optional[str]): The version of the API to use. If not provided,
43
70
  it will be set via environment variable.
71
+
44
72
  kwargs (Optional[Dict[str, Any]]): Optional configuration parameters specific
45
73
  to the model deployment.
46
74
 
@@ -57,6 +85,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
57
85
  api_type: Optional[str] = None,
58
86
  api_base: Optional[str] = None,
59
87
  api_version: Optional[str] = None,
88
+ oauth: Optional[AzureEntraIDOAuthConfig] = None,
60
89
  **kwargs: Any,
61
90
  ):
62
91
  super().__init__() # type: ignore
@@ -80,8 +109,6 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
80
109
  or os.getenv(OPENAI_API_VERSION_ENV_VAR)
81
110
  )
82
111
 
83
- self._api_key_env_var = self._resolve_api_key_env_var()
84
-
85
112
  # Not used by LiteLLM, here for backward compatibility
86
113
  self._api_type = (
87
114
  api_type
@@ -89,6 +116,19 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
89
116
  or os.getenv(OPENAI_API_TYPE_ENV_VAR)
90
117
  )
91
118
 
119
+ os.unsetenv("OPENAI_API_KEY")
120
+ os.unsetenv("AZURE_API_KEY")
121
+
122
+ self._oauth = oauth
123
+
124
+ if self._oauth:
125
+ os.unsetenv(DEFAULT_AZURE_API_KEY_NAME)
126
+ os.unsetenv(AZURE_API_KEY_ENV_VAR)
127
+ os.unsetenv(OPENAI_API_KEY_ENV_VAR)
128
+ self._api_key_env_var = (
129
+ self._resolve_api_key_env_var() if not self._oauth else None
130
+ )
131
+
92
132
  # Run helper function to check and raise deprecation warning if
93
133
  # deprecated environment variables were used for initialization of the
94
134
  # client settings
@@ -157,7 +197,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
157
197
  return self._extra_parameters[API_KEY]
158
198
 
159
199
  if os.getenv(AZURE_API_KEY_ENV_VAR) is not None:
160
- return "${AZURE_API_KEY}"
200
+ return f"${DEFAULT_AZURE_API_KEY_NAME}"
161
201
 
162
202
  if os.getenv(OPENAI_API_KEY_ENV_VAR) is not None:
163
203
  # API key can be set through OPENAI_API_KEY too,
@@ -188,7 +228,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
188
228
  )
189
229
 
190
230
  @classmethod
191
- def from_config(cls, config: Dict[str, Any]) -> "AzureOpenAILLMClient":
231
+ def from_config(cls, config: Dict[str, Any]) -> AzureOpenAILLMClient:
192
232
  """Initializes the client from given configuration.
193
233
 
194
234
  Args:
@@ -215,11 +255,12 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
215
255
  raise
216
256
 
217
257
  return cls(
218
- azure_openai_config.deployment,
219
- azure_openai_config.model,
220
- azure_openai_config.api_type,
221
- azure_openai_config.api_base,
222
- azure_openai_config.api_version,
258
+ deployment=azure_openai_config.deployment,
259
+ model=azure_openai_config.model,
260
+ api_type=azure_openai_config.api_type,
261
+ api_base=azure_openai_config.api_base,
262
+ api_version=azure_openai_config.api_version,
263
+ oauth=azure_openai_config.oauth,
223
264
  **azure_openai_config.extra_parameters,
224
265
  )
225
266
 
@@ -234,6 +275,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
234
275
  api_base=self._api_base,
235
276
  api_version=self._api_version,
236
277
  api_type=self._api_type,
278
+ oauth=self._oauth,
237
279
  extra_parameters=self._extra_parameters,
238
280
  )
239
281
  return config.to_dict()
@@ -282,12 +324,21 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
282
324
  """Returns the completion arguments for invoking a call through
283
325
  LiteLLM's completion functions.
284
326
  """
327
+ # Set the API key env var to None if OAuth is used
328
+ auth_parameter = (
329
+ {
330
+ LITE_LLM_AZURE_AD_TOKEN: self._oauth.get_bearer_token(),
331
+ }
332
+ if self._oauth
333
+ else {LITE_LLM_API_KEY_FIELD: self._api_key_env_var}
334
+ )
335
+
285
336
  fn_args = super()._completion_fn_args
286
337
  fn_args.update(
287
338
  {
288
- "api_base": self.api_base,
289
- "api_version": self.api_version,
290
- "api_key": self._api_key_env_var,
339
+ LITE_LLM_API_BASE_FIELD: self.api_base,
340
+ LITE_LLM_API_VERSION_FIELD: self.api_version,
341
+ **auth_parameter,
291
342
  }
292
343
  )
293
344
  return fn_args
@@ -314,41 +365,44 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
314
365
 
315
366
  return info.format(setting=setting, options=options)
316
367
 
368
+ env_var_field = "env_var"
369
+ config_key_field = "config_key"
370
+ current_value_field = "current_value"
317
371
  # All required settings for Azure OpenAI client
318
372
  settings: Dict[str, Dict[str, Any]] = {
319
373
  "API Base": {
320
- "current_value": self.api_base,
321
- "env_var": AZURE_API_BASE_ENV_VAR,
322
- "config_key": API_BASE_CONFIG_KEY,
374
+ current_value_field: self.api_base,
375
+ env_var_field: AZURE_API_BASE_ENV_VAR,
376
+ config_key_field: API_BASE_CONFIG_KEY,
323
377
  },
324
378
  "API Version": {
325
- "current_value": self.api_version,
326
- "env_var": AZURE_API_VERSION_ENV_VAR,
327
- "config_key": API_VERSION_CONFIG_KEY,
379
+ current_value_field: self.api_version,
380
+ env_var_field: AZURE_API_VERSION_ENV_VAR,
381
+ config_key_field: API_VERSION_CONFIG_KEY,
328
382
  },
329
383
  "Deployment Name": {
330
- "current_value": self.deployment,
331
- "env_var": None,
332
- "config_key": DEPLOYMENT_CONFIG_KEY,
384
+ current_value_field: self.deployment,
385
+ env_var_field: None,
386
+ config_key_field: DEPLOYMENT_CONFIG_KEY,
333
387
  },
334
388
  }
335
389
 
336
390
  missing_settings = [
337
391
  setting_name
338
392
  for setting_name, setting_info in settings.items()
339
- if setting_info["current_value"] is None
393
+ if setting_info[current_value_field] is None
340
394
  ]
341
395
 
342
396
  if missing_settings:
343
397
  event_info = f"Client settings not set: " f"{', '.join(missing_settings)}. "
344
398
 
345
399
  for missing_setting in missing_settings:
346
- if settings[missing_setting]["current_value"] is not None:
400
+ if settings[missing_setting][current_value_field] is not None:
347
401
  continue
348
402
  event_info += generate_event_info_for_missing_setting(
349
403
  missing_setting,
350
- settings[missing_setting]["env_var"],
351
- settings[missing_setting]["config_key"],
404
+ settings[missing_setting][env_var_field],
405
+ settings[missing_setting][config_key_field],
352
406
  )
353
407
 
354
408
  structlogger.error(
@@ -1,4 +1,6 @@
1
- from typing import Dict, Any
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict
2
4
 
3
5
  from rasa.shared.constants import (
4
6
  AWS_BEDROCK_PROVIDER,
@@ -35,7 +37,7 @@ class DefaultLiteLLMClient(_BaseLiteLLMClient):
35
37
  self.validate_client_setup()
36
38
 
37
39
  @classmethod
38
- def from_config(cls, config: Dict[str, Any]) -> "DefaultLiteLLMClient":
40
+ def from_config(cls, config: Dict[str, Any]) -> DefaultLiteLLMClient:
39
41
  default_config = DefaultLiteLLMClientConfig.from_dict(config)
40
42
  return cls(
41
43
  model=default_config.model,
@@ -1,11 +1,15 @@
1
- from typing import Any, Dict, List, Union
1
+ from __future__ import annotations
2
+
2
3
  import logging
4
+ from typing import Any, Dict, List, Union
5
+
3
6
  import structlog
4
7
 
5
8
  from rasa.shared.exceptions import ProviderClientAPIException
6
9
  from rasa.shared.providers._configs.litellm_router_client_config import (
7
10
  LiteLLMRouterClientConfig,
8
11
  )
12
+ from rasa.shared.providers.constants import LITE_LLM_MODEL_FIELD
9
13
  from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
10
14
  from rasa.shared.providers.llm.llm_response import LLMResponse
11
15
  from rasa.shared.providers.router._base_litellm_router_client import (
@@ -41,7 +45,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
41
45
  )
42
46
 
43
47
  @classmethod
44
- def from_config(cls, config: Dict[str, Any]) -> "LiteLLMRouterLLMClient":
48
+ def from_config(cls, config: Dict[str, Any]) -> LiteLLMRouterLLMClient:
45
49
  """Instantiates a LiteLLM Router LLM client from a configuration dict.
46
50
 
47
51
  Args:
@@ -86,6 +90,10 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
86
90
  ProviderClientAPIException: If the API request fails.
87
91
  """
88
92
  try:
93
+ structlogger.info(
94
+ "litellm_router_llm_client.text_completion",
95
+ _completion_fn_args=self._completion_fn_args,
96
+ )
89
97
  response = self.router_client.text_completion(
90
98
  prompt=prompt, **self._completion_fn_args
91
99
  )
@@ -106,6 +114,10 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
106
114
  ProviderClientAPIException: If the API request fails.
107
115
  """
108
116
  try:
117
+ structlogger.info(
118
+ "litellm_router_llm_client.atext_completion",
119
+ _completion_fn_args=self._completion_fn_args,
120
+ )
109
121
  response = await self.router_client.atext_completion(
110
122
  prompt=prompt, **self._completion_fn_args
111
123
  )
@@ -135,6 +147,10 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
135
147
  return self._text_completion(messages)
136
148
  try:
137
149
  formatted_messages = self._format_messages(messages)
150
+ structlogger.info(
151
+ "litellm_router_llm_client.completion",
152
+ _completion_fn_args=self._completion_fn_args,
153
+ )
138
154
  response = self.router_client.completion(
139
155
  messages=formatted_messages, **self._completion_fn_args
140
156
  )
@@ -164,6 +180,10 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
164
180
  return await self._atext_completion(messages)
165
181
  try:
166
182
  formatted_messages = self._format_messages(messages)
183
+ structlogger.info(
184
+ "litellm_router_llm_client.acompletion",
185
+ _completion_fn_args=self._completion_fn_args,
186
+ )
167
187
  response = await self.router_client.acompletion(
168
188
  messages=formatted_messages, **self._completion_fn_args
169
189
  )
@@ -178,5 +198,5 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
178
198
  """
179
199
  return {
180
200
  **self._litellm_extra_parameters,
181
- "model": self.model_group_id,
201
+ LITE_LLM_MODEL_FIELD: self.model_group_id,
182
202
  }
@@ -1,4 +1,6 @@
1
- from typing import Protocol, Dict, List, runtime_checkable, Union
1
+ from __future__ import annotations
2
+
3
+ from typing import Dict, List, Protocol, Union, runtime_checkable
2
4
 
3
5
  from rasa.shared.providers.llm.llm_response import LLMResponse
4
6
 
@@ -11,7 +13,7 @@ class LLMClient(Protocol):
11
13
  """
12
14
 
13
15
  @classmethod
14
- def from_config(cls, config: dict) -> "LLMClient":
16
+ def from_config(cls, config: dict) -> LLMClient:
15
17
  """
16
18
  Initializes the llm client with the given configuration.
17
19
 
@@ -1,8 +1,5 @@
1
1
  from dataclasses import dataclass, field, asdict
2
- from typing import Dict, List, Optional, Text, Any, Union
3
- import structlog
4
-
5
- structlogger = structlog.get_logger()
2
+ from typing import Dict, List, Optional
6
3
 
7
4
 
8
5
  @dataclass
@@ -19,18 +16,6 @@ class LLMUsage:
19
16
  def __post_init__(self) -> None:
20
17
  self.total_tokens = self.prompt_tokens + self.completion_tokens
21
18
 
22
- @classmethod
23
- def from_dict(cls, data: Dict[Text, Any]) -> "LLMUsage":
24
- """
25
- Creates an LLMUsage object from a dictionary.
26
- If any keys are missing, they will default to zero
27
- or whatever default you prefer.
28
- """
29
- return cls(
30
- prompt_tokens=data.get("prompt_tokens"),
31
- completion_tokens=data.get("completion_tokens"),
32
- )
33
-
34
19
  def to_dict(self) -> dict:
35
20
  """Converts the LLMUsage dataclass instance into a dictionary."""
36
21
  return asdict(self)
@@ -57,32 +42,6 @@ class LLMResponse:
57
42
  """Optional dictionary for storing additional information related to the
58
43
  completion that may not be covered by other fields."""
59
44
 
60
- @classmethod
61
- def from_dict(cls, data: Dict[Text, Any]) -> "LLMResponse":
62
- """
63
- Creates an LLMResponse from a dictionary.
64
- """
65
- usage_data = data.get("usage", {})
66
- usage_obj = LLMUsage.from_dict(usage_data) if usage_data else None
67
-
68
- return cls(
69
- id=data["id"],
70
- choices=data["choices"],
71
- created=data["created"],
72
- model=data.get("model"),
73
- usage=usage_obj,
74
- additional_info=data.get("additional_info"),
75
- )
76
-
77
- @classmethod
78
- def ensure_llm_response(cls, response: Union[str, "LLMResponse"]) -> "LLMResponse":
79
- if isinstance(response, LLMResponse):
80
- return response
81
-
82
- structlogger.warn("llm_response.deprecated_response_type", response=response)
83
- data = {"id": None, "choices": [response], "created": None}
84
- return LLMResponse.from_dict(data)
85
-
86
45
  def to_dict(self) -> dict:
87
46
  """Converts the LLMResponse dataclass instance into a dictionary."""
88
47
  result = asdict(self)
@@ -1,16 +1,22 @@
1
+ from __future__ import annotations
2
+
1
3
  import os
2
4
  import re
3
- from typing import Dict, Any, Optional
5
+ from typing import Any, Dict, Optional
4
6
 
5
7
  import structlog
6
8
 
7
9
  from rasa.shared.constants import (
8
10
  OPENAI_API_BASE_ENV_VAR,
9
- OPENAI_API_VERSION_ENV_VAR,
10
11
  OPENAI_API_TYPE_ENV_VAR,
12
+ OPENAI_API_VERSION_ENV_VAR,
11
13
  OPENAI_PROVIDER,
12
14
  )
13
15
  from rasa.shared.providers._configs.openai_client_config import OpenAIClientConfig
16
+ from rasa.shared.providers.constants import (
17
+ LITE_LLM_API_KEY_FIELD,
18
+ LITE_LLM_API_VERSION_FIELD,
19
+ )
14
20
  from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
15
21
 
16
22
  structlogger = structlog.get_logger()
@@ -57,7 +63,7 @@ class OpenAILLMClient(_BaseLiteLLMClient):
57
63
  self.validate_client_setup()
58
64
 
59
65
  @classmethod
60
- def from_config(cls, config: Dict[str, Any]) -> "OpenAILLMClient":
66
+ def from_config(cls, config: Dict[str, Any]) -> OpenAILLMClient:
61
67
  """
62
68
  Initializes the client from given configuration.
63
69
 
@@ -148,8 +154,8 @@ class OpenAILLMClient(_BaseLiteLLMClient):
148
154
  fn_args = super()._completion_fn_args
149
155
  fn_args.update(
150
156
  {
151
- "api_base": self.api_base,
152
- "api_version": self.api_version,
157
+ LITE_LLM_API_KEY_FIELD: self.api_base,
158
+ LITE_LLM_API_VERSION_FIELD: self.api_version,
153
159
  }
154
160
  )
155
161
  return fn_args
@@ -1,17 +1,22 @@
1
+ from __future__ import annotations
2
+
1
3
  from typing import Any, Dict, Optional
2
4
 
3
5
  import structlog
4
6
 
5
7
  from rasa.shared.constants import (
6
- RASA_PROVIDER,
7
8
  OPENAI_PROVIDER,
9
+ RASA_PROVIDER,
8
10
  )
9
11
  from rasa.shared.providers._configs.rasa_llm_client_config import (
10
12
  RasaLLMClientConfig,
11
13
  )
12
- from rasa.utils.licensing import retrieve_license_from_env
14
+ from rasa.shared.providers.constants import (
15
+ LITE_LLM_API_BASE_FIELD,
16
+ LITE_LLM_API_KEY_FIELD,
17
+ )
13
18
  from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
14
-
19
+ from rasa.utils.licensing import retrieve_license_from_env
15
20
 
16
21
  structlogger = structlog.get_logger()
17
22
 
@@ -88,12 +93,15 @@ class RasaLLMClient(_BaseLiteLLMClient):
88
93
  """
89
94
  fn_args = super()._completion_fn_args
90
95
  fn_args.update(
91
- {"api_base": self.api_base, "api_key": retrieve_license_from_env()}
96
+ {
97
+ LITE_LLM_API_BASE_FIELD: self.api_base,
98
+ LITE_LLM_API_KEY_FIELD: retrieve_license_from_env(),
99
+ }
92
100
  )
93
101
  return fn_args
94
102
 
95
103
  @classmethod
96
- def from_config(cls, config: Dict[str, Any]) -> "RasaLLMClient":
104
+ def from_config(cls, config: Dict[str, Any]) -> RasaLLMClient:
97
105
  try:
98
106
  client_config = RasaLLMClientConfig.from_dict(config)
99
107
  except ValueError as e: