rasa-pro 3.12.0.dev1__py3-none-any.whl → 3.12.0.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

@@ -0,0 +1,40 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, Dict, List
5
+
6
+ from azure.identity import ClientSecretCredential
7
+
8
+ from rasa.shared.providers._configs.oauth_config import OAuth
9
+
10
+
11
@dataclass
class AzureEntraIDClientCreds(OAuth):
    """Azure Entra ID client-secret credentials that can produce bearer tokens.

    Attributes:
        client_id: The client ID.
        client_secret: The client secret.
        tenant_id: The tenant ID.
        scopes: The scopes passed to ``get_token`` when requesting a token.
    """

    client_id: str
    client_secret: str
    tenant_id: str
    scopes: List[str] = field(default_factory=list)

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> AzureEntraIDClientCreds:
        """Build the credentials from a config dict (the dict is not mutated).

        ``scopes`` may be given as a single string or as a list of strings.
        If it is missing, an empty list is used.
        """
        scopes = config.get("scopes")
        if scopes is None:
            # Passing `None` through would override the dataclass default
            # (an empty list) and break `get_token(*self.scopes)` later.
            scopes = []
        elif isinstance(scopes, str):
            scopes = [scopes]

        return cls(
            client_id=config.get("client_id"),
            client_secret=config.get("client_secret"),
            tenant_id=config.get("tenant_id"),
            scopes=scopes,
        )

    def _get_credential(self) -> ClientSecretCredential:
        """Return a ClientSecretCredential, reusing it across calls.

        Reusing one credential instance lets the azure-identity library cache
        and refresh tokens. Otherwise a new token request would be made on
        every call. The credential is rebuilt if the identifying fields change.
        """
        key = (self.client_id, self.client_secret, self.tenant_id)
        cached = getattr(self, "_cached_credential", None)
        if cached is None or cached[0] != key:
            credential = ClientSecretCredential(
                client_id=self.client_id,
                client_secret=self.client_secret,
                tenant_id=self.tenant_id,
            )
            cached = (key, credential)
            # A plain instance attribute, not a dataclass field, so it is
            # excluded from __init__, __repr__, __eq__ and asdict().
            self._cached_credential = cached
        return cached[1]

    def get_bearer_token(self) -> str:
        """Return an access token for the configured scopes."""
        return self._get_credential().get_token(*self.scopes).token
@@ -0,0 +1,533 @@
1
+ from __future__ import annotations
2
+
3
+ import abc
4
+ import logging
5
+ from dataclasses import dataclass
6
+ from enum import Enum
7
+ from functools import lru_cache
8
+ from typing import Any, Callable, Dict, List, Optional, Set, Type, TypeVar
9
+
10
+ import structlog
11
+ from azure.core.credentials import TokenProvider
12
+ from azure.identity import (
13
+ CertificateCredential,
14
+ ClientSecretCredential,
15
+ DefaultAzureCredential,
16
+ )
17
+
18
+ from rasa.shared.providers._configs.oauth_config import OAUTH_TYPE_FIELD, OAuth
19
+
20
# Keys read from the Azure Entra ID OAuth configuration dict.
# Credential identity.
AZURE_CLIENT_ID_FIELD = "client_id"
AZURE_CLIENT_SECRET_FIELD = "client_secret"
AZURE_TENANT_ID_FIELD = "tenant_id"

# Client certificate flow only.
AZURE_CERTIFICATE_PATH_FIELD = "certificate_path"
AZURE_CERTIFICATE_PASSWORD_FIELD = "certificate_password"
AZURE_SEND_CERTIFICATE_CHAIN_FIELD = "send_certificate_chain"

# Token request and authority settings.
AZURE_SCOPES_FIELD = "scopes"
AZURE_AUTHORITY_FIELD = "authority_host"
AZURE_DISABLE_INSTANCE_DISCOVERY_FIELD = "disable_instance_discovery"


# NOTE(review): this forces the Azure SDK's "azure" logger to DEBUG at import
# time, overriding any level the host application configured for it. Kept
# unchanged to preserve behaviour; consider leaving this to the application.
azure_logger: logging.Logger = logging.getLogger("azure")
azure_logger.setLevel(logging.DEBUG)

# Structured logger used for the validation errors raised in this module.
structlogger = structlog.get_logger()
35
+
36
+
37
class AzureEntraIDOAuthType(str, Enum):
    """Kinds of Azure Entra ID OAuth credentials understood by this module.

    ``INVALID`` is not a real credential kind. It is returned by
    ``from_string`` when the configured type is missing or not recognised.
    """

    AZURE_ENTRA_ID_DEFAULT = "azure_entra_id_default"
    AZURE_ENTRA_ID_CLIENT_SECRET = "azure_entra_id_client_secret"
    AZURE_ENTRA_ID_CLIENT_CERTIFICATE = "azure_entra_id_client_certificate"
    INVALID = "invalid"

    @staticmethod
    def from_string(value: Optional[str]) -> AzureEntraIDOAuthType:
        """Parse ``value`` into a member, falling back to ``INVALID``.

        ``None`` and any string outside ``valid_string_values()`` map to
        ``INVALID``. This includes the literal string ``"invalid"``.
        """
        recognised = (
            value is not None
            and value in AzureEntraIDOAuthType.valid_string_values()
        )
        if not recognised:
            return AzureEntraIDOAuthType.INVALID
        return AzureEntraIDOAuthType(value)

    @staticmethod
    def valid_string_values() -> Set[str]:
        """Return the raw string values of all non-``INVALID`` members."""
        return {member.value for member in AzureEntraIDOAuthType.valid_values()}

    @staticmethod
    def valid_values() -> Set[AzureEntraIDOAuthType]:
        """Return every member except the ``INVALID`` sentinel."""
        return set(AzureEntraIDOAuthType) - {AzureEntraIDOAuthType.INVALID}
69
+
70
+
71
# BearerTokenProvider is a callable that returns a bearer token.
BearerTokenProvider = Callable[[], str]

AzureEntraIDTokenProviderConfigType = TypeVar(
    "AzureEntraIDTokenProviderConfigType", bound="AzureEntraIDTokenProviderConfig"
)


class AzureEntraIDTokenProviderConfig(abc.ABC):
    """Interface for Azure Entra ID OAuth credential configuration.

    Each concrete subclass holds the settings of one credential kind. It must
    implement ``create_azure_token_provider`` and override ``from_config``.
    """

    @abc.abstractmethod
    def create_azure_token_provider(self) -> TokenProvider:
        """Create an Azure Entra ID token provider."""
        ...

    @classmethod
    def from_config(
        cls: Type[AzureEntraIDTokenProviderConfigType], config: Dict[str, Any]
    ) -> AzureEntraIDTokenProviderConfigType:
        """Initializes a dataclass from the passed config.

        Args:
            config: (dict) The config from which to initialize.

        Returns:
            An instance of the concrete subclass.

        Raises:
            NotImplementedError: If the subclass does not override this method.
                The base implementation used to return ``None`` silently, which
                only failed later when the provider was used.
        """
        raise NotImplementedError(
            f"{cls.__name__} must implement the `from_config` classmethod."
        )
100
+
101
+
102
@dataclass
class AzureEntraIDClientCredentialsConfig(AzureEntraIDTokenProviderConfig):
    """Azure Entra ID OAuth client credentials configuration.

    Attributes:
        client_id: The client ID.
        client_secret: The client secret.
        tenant_id: The tenant ID.
        authority_host: The authority host.
        disable_instance_discovery: Whether to disable instance discovery. This is used
            to disable fetching metadata from the Azure Instance Metadata Service.
    """

    client_id: str
    client_secret: str
    tenant_id: str
    authority_host: Optional[str] = None
    disable_instance_discovery: bool = False

    @staticmethod
    def required_fields() -> Set[str]:
        """Returns the required fields for the configuration."""
        return {AZURE_CLIENT_ID_FIELD, AZURE_TENANT_ID_FIELD, AZURE_CLIENT_SECRET_FIELD}

    @staticmethod
    def config_has_required_fields(config: Dict[str, Any]) -> bool:
        """Check if the configuration has all the required fields."""
        return AzureEntraIDClientCredentialsConfig.required_fields().issubset(
            set(config.keys())
        )

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> AzureEntraIDClientCredentialsConfig:
        """Initializes a dataclass from the passed config.

        The recognised keys are popped from ``config``, so the caller's dict
        is consumed.

        Args:
            config: (dict) The config from which to initialize.

        Returns:
            AzureEntraIDClientCredentialsConfig

        Raises:
            ValueError: If any of the required keys is missing.
        """
        if not cls.config_has_required_fields(config):
            message = (
                f"A configuration for Azure client credentials "
                f"must contain the following keys: {cls.required_fields()}"
            )
            # Log key names only: the config holds `client_secret`, which
            # must never be written to the logs.
            structlogger.error(
                "azure_client_credentials_config.missing_required_keys",
                message=message,
                missing_keys=sorted(cls.required_fields().difference(config)),
                provided_keys=sorted(str(key) for key in config),
            )
            raise ValueError(message)

        return cls(
            client_id=config.pop(AZURE_CLIENT_ID_FIELD),
            client_secret=config.pop(AZURE_CLIENT_SECRET_FIELD),
            tenant_id=config.pop(AZURE_TENANT_ID_FIELD),
            authority_host=config.pop(AZURE_AUTHORITY_FIELD, None),
            disable_instance_discovery=config.pop(
                AZURE_DISABLE_INSTANCE_DISCOVERY_FIELD, False
            ),
        )

    def create_azure_token_provider(self) -> TokenProvider:
        """Create a ClientSecretCredential for Azure Entra ID."""
        return create_azure_entra_id_client_credentials(
            client_id=self.client_id,
            client_secret=self.client_secret,
            tenant_id=self.tenant_id,
            authority_host=self.authority_host,
            disable_instance_discovery=self.disable_instance_discovery,
        )
174
+
175
+
176
@lru_cache
def create_azure_entra_id_client_credentials(
    client_id: str,
    client_secret: str,
    tenant_id: str,
    authority_host: Optional[str] = None,
    disable_instance_discovery: bool = False,
) -> ClientSecretCredential:
    """Build a ClientSecretCredential for Azure Entra ID (memoised).

    Identical arguments return the same credential object. This lets the
    azure-identity library's token caching and token refreshing work across
    calls.

    Args:
        client_id: The client ID.
        client_secret: The client secret.
        tenant_id: The tenant ID.
        authority_host: The authority host, forwarded as ``authority``.
        disable_instance_discovery: Whether to disable instance discovery. This is used
            to disable fetching metadata from the Azure Instance Metadata Service.

    Returns:
        ClientSecretCredential
    """
    credential_kwargs: Dict[str, Any] = dict(
        tenant_id=tenant_id,
        client_id=client_id,
        client_secret=client_secret,
        authority=authority_host,
        disable_instance_discovery=disable_instance_discovery,
    )
    return ClientSecretCredential(**credential_kwargs)
208
+
209
+
210
@dataclass
class AzureEntraIDClientCertificateConfig(AzureEntraIDTokenProviderConfig):
    """Azure Entra ID OAuth client certificate configuration.

    Attributes:
        client_id: The client ID.
        tenant_id: The tenant ID.
        certificate_path: The path to the certificate file.
        certificate_password: The certificate password (optional).
        send_certificate_chain: Whether to send the certificate chain.
        authority_host: The authority host.
        disable_instance_discovery: Whether to disable instance discovery. This is used
            to disable fetching metadata from the Azure Instance Metadata Service.
    """

    client_id: str
    tenant_id: str
    certificate_path: str
    certificate_password: Optional[str] = None
    send_certificate_chain: bool = False
    authority_host: Optional[str] = None
    disable_instance_discovery: bool = False

    @staticmethod
    def required_fields() -> Set[str]:
        """Returns the required fields for the configuration.

        The certificate password is deliberately not required. The field
        defaults to ``None`` and ``from_config`` reads it with ``.get``, so
        configurations that omit it must pass validation.
        """
        return {
            AZURE_CLIENT_ID_FIELD,
            AZURE_TENANT_ID_FIELD,
            AZURE_CERTIFICATE_PATH_FIELD,
        }

    @staticmethod
    def config_has_required_fields(config: Dict[str, Any]) -> bool:
        """Check if the configuration has all the required fields."""
        return AzureEntraIDClientCertificateConfig.required_fields().issubset(
            set(config.keys())
        )

    @classmethod
    def from_config(
        cls, config: Dict[str, Any]
    ) -> AzureEntraIDClientCertificateConfig:
        """Initializes a dataclass from the passed config.

        The passed ``config`` is only read, never modified.

        Args:
            config: (dict) The config from which to initialize.

        Returns:
            AzureEntraIDClientCertificateConfig

        Raises:
            ValueError: If any of the required keys is missing.
        """
        if not cls.config_has_required_fields(config):
            message = (
                f"A configuration for Azure client certificate "
                f"must contain "
                f"the following keys: {cls.required_fields()}"
            )
            # Log key names only: the config may hold `certificate_password`,
            # which must never be written to the logs.
            structlogger.error(
                "azure_client_certificate_config.validation_error",
                message=message,
                missing_keys=sorted(cls.required_fields().difference(config)),
                provided_keys=sorted(str(key) for key in config),
            )
            raise ValueError(message)

        return cls(
            client_id=config[AZURE_CLIENT_ID_FIELD],
            tenant_id=config[AZURE_TENANT_ID_FIELD],
            certificate_path=config[AZURE_CERTIFICATE_PATH_FIELD],
            certificate_password=config.get(AZURE_CERTIFICATE_PASSWORD_FIELD, None),
            authority_host=config.get(AZURE_AUTHORITY_FIELD, None),
            send_certificate_chain=config.get(
                AZURE_SEND_CERTIFICATE_CHAIN_FIELD, False
            ),
            disable_instance_discovery=config.get(
                AZURE_DISABLE_INSTANCE_DISCOVERY_FIELD, False
            ),
        )

    def create_azure_token_provider(self) -> TokenProvider:
        """Creates a CertificateCredential for Azure Entra ID."""
        return create_azure_entra_id_certificate_credentials(
            client_id=self.client_id,
            tenant_id=self.tenant_id,
            certificate_path=self.certificate_path,
            password=self.certificate_password,
            send_certificate_chain=self.send_certificate_chain,
            authority_host=self.authority_host,
            disable_instance_discovery=self.disable_instance_discovery,
        )
300
+
301
+
302
@lru_cache
def create_azure_entra_id_certificate_credentials(
    tenant_id: str,
    client_id: str,
    certificate_path: Optional[str] = None,
    password: Optional[str] = None,
    send_certificate_chain: bool = False,
    authority_host: Optional[str] = None,
    disable_instance_discovery: bool = False,
) -> CertificateCredential:
    """Creates a CertificateCredential for Azure Entra ID.

    We cache the result of this function to avoid creating multiple instances
    of the same credential. This makes it possible to utilise the token caching
    and token refreshing functionality of the azure-identity library.

    Args:
        tenant_id: The tenant ID.
        client_id: The client ID.
        certificate_path: The path to the certificate file.
        password: The certificate password. It is UTF-8 encoded before being
            passed on; an empty or missing password is passed as ``None``.
        send_certificate_chain: Whether to send the certificate chain.
        authority_host: The authority host, forwarded as ``authority``.
        disable_instance_discovery: Whether to disable instance discovery. This is used
            to disable fetching metadata from the Azure Instance Metadata Service.

    Returns:
        CertificateCredential
    """
    return CertificateCredential(
        client_id=client_id,
        tenant_id=tenant_id,
        certificate_path=certificate_path,
        password=password.encode("utf-8") if password else None,
        send_certificate_chain=send_certificate_chain,
        authority=authority_host,
        disable_instance_discovery=disable_instance_discovery,
    )
340
+
341
+
342
@dataclass
class AzureEntraIDDefaultCredentialsConfig(AzureEntraIDTokenProviderConfig):
    """Configuration for the ``DefaultAzureCredential`` chain.

    Attributes:
        authority_host: The authority host, or ``None`` to use the library default.
    """

    authority_host: Optional[str] = None

    @classmethod
    def from_config(
        cls, config: Dict[str, Any]
    ) -> AzureEntraIDDefaultCredentialsConfig:
        """Build the configuration from a config dict.

        The authority host key is popped from ``config`` if present.

        Args:
            config: (dict) The config from which to initialize.

        Returns:
            AzureEntraIDDefaultCredentialsConfig
        """
        authority = config.pop(AZURE_AUTHORITY_FIELD, None)
        return cls(authority_host=authority)

    def create_azure_token_provider(self) -> TokenProvider:
        """Return the (cached) DefaultAzureCredential for this authority host."""
        authority = self.authority_host
        return create_azure_entra_id_default_credentials(authority_host=authority)
371
+
372
+
373
@lru_cache
def create_azure_entra_id_default_credentials(
    authority_host: Optional[str] = None,
) -> DefaultAzureCredential:
    """Build a DefaultAzureCredential (memoised per ``authority_host``).

    Returning the same instance for the same argument lets the
    azure-identity library reuse its token cache.

    Args:
        authority_host: The authority host, forwarded as ``authority``.

    Returns:
        DefaultAzureCredential
    """
    credential = DefaultAzureCredential(authority=authority_host)
    return credential
390
+
391
+
392
@dataclass
class AzureEntraIDOAuthConfig(OAuth):
    """Azure Entra ID OAuth configuration.

    It consists of the scopes to request and the configuration of the Azure
    Entra ID token provider used to obtain bearer tokens.

    Attributes:
        scopes: The scopes passed to ``get_token`` when requesting a token.
        azure_entra_id_token_provider_config: The token provider configuration.
            It must be set before a bearer token can be requested.
    """

    scopes: List[str]
    azure_entra_id_token_provider_config: Optional[AzureEntraIDTokenProviderConfig] = (
        None
    )

    @staticmethod
    def _supported_azure_oauth() -> (
        Dict[AzureEntraIDOAuthType, Type[AzureEntraIDTokenProviderConfig]]
    ):
        """Returns the supported OAuth types mapped to their config classes."""
        return {
            AzureEntraIDOAuthType.AZURE_ENTRA_ID_DEFAULT: (
                AzureEntraIDDefaultCredentialsConfig
            ),
            AzureEntraIDOAuthType.AZURE_ENTRA_ID_CLIENT_SECRET: (
                AzureEntraIDClientCredentialsConfig
            ),
            AzureEntraIDOAuthType.AZURE_ENTRA_ID_CLIENT_CERTIFICATE: (
                AzureEntraIDClientCertificateConfig
            ),
        }

    @staticmethod
    def _get_azure_oauth_by_type(
        oauth_type: AzureEntraIDOAuthType,
    ) -> Type[AzureEntraIDTokenProviderConfig]:
        """Returns the Azure Entra ID OAuth class based on the type.

        Args:
            oauth_type: (AzureEntraIDOAuthType) The type of the Azure Entra ID OAuth.

        Returns:
            The Azure Entra ID OAuth class

        Raises:
            ValueError: If the passed oauth_type is not supported or invalid.
        """
        azure_oauth_types = AzureEntraIDOAuthConfig._supported_azure_oauth()
        azure_oauth_class = azure_oauth_types.get(oauth_type)

        if azure_oauth_class is None:
            message = (
                f"Unsupported Azure Entra ID oauth type: {oauth_type}. "
                f"Supported types are: {AzureEntraIDOAuthType.valid_string_values()}"
            )
            structlogger.error(
                "azure_oauth_config.unsupported_azure_oauth_type",
                message=message,
            )
            raise ValueError(message)

        return azure_oauth_class

    @classmethod
    def from_config(cls, oauth_config: Dict[str, Any]) -> AzureEntraIDOAuthConfig:
        """Initializes a dataclass from the passed config.

        The scopes and type keys are popped from ``oauth_config``. The
        selected credential configuration class may pop further keys.

        Args:
            oauth_config: (dict) The config from which to initialize.

        Returns:
            AzureEntraIDOAuthConfig

        Raises:
            ValueError: If the scopes are empty, the type is missing or
                invalid, or the credential configuration is incomplete.
        """
        scopes = AzureEntraIDOAuthConfig._read_scopes_from_config(oauth_config)
        azure_credentials = (
            AzureEntraIDOAuthConfig._create_azure_entra_id_client_from_config(
                oauth_config
            )
        )
        return cls(
            azure_entra_id_token_provider_config=azure_credentials, scopes=scopes
        )

    @staticmethod
    def _read_scopes_from_config(oauth_config: Dict[str, Any]) -> List[str]:
        """Reads scopes from the configuration.

        The original scopes are removed from the configuration.

        Args:
            oauth_config: (dict) The configuration from which to read the scopes.

        Returns:
            List[str]: The list of scopes. A single string is wrapped in a list.

        Raises:
            ValueError: If the scopes are missing or empty.
        """
        scopes = oauth_config.pop(AZURE_SCOPES_FIELD, "")

        if not scopes:
            message = "Azure Entra ID scopes cannot be empty."
            structlogger.error(
                "azure_oauth_config.scopes_empty",
                message=message,
            )
            raise ValueError(message)

        if isinstance(scopes, str):
            scopes = [scopes]

        return scopes

    @staticmethod
    def _create_azure_entra_id_client_from_config(
        oauth_config: Dict[str, Any],
    ) -> AzureEntraIDTokenProviderConfig:
        """Creates an Azure Entra ID client from the configuration.

        The type key is popped from ``oauth_config``.

        Args:
            oauth_config: (dict) The configuration from which to create the credential.

        Returns:
            AzureEntraIDTokenProviderConfig: The Azure OAuth credential.

        Raises:
            ValueError: If the type is missing or not a valid Entra ID type.
        """
        oauth_type = AzureEntraIDOAuthType.from_string(
            oauth_config.pop(OAUTH_TYPE_FIELD, None)
        )

        if oauth_type == AzureEntraIDOAuthType.INVALID:
            message = (
                "Azure Entra ID oauth configuration must contain "
                f"'{OAUTH_TYPE_FIELD}' field and it must be set to one of the "
                f"following values: {AzureEntraIDOAuthType.valid_string_values()}."
            )
            structlogger.error(
                "azure_oauth_config.missing_azure_oauth_type",
                message=message,
            )
            raise ValueError(message)

        azure_oauth_class = AzureEntraIDOAuthConfig._get_azure_oauth_by_type(oauth_type)
        return azure_oauth_class.from_config(oauth_config)

    def _create_azure_credential(
        self,
    ) -> TokenProvider:
        """Create an Azure Entra ID client which can be used to get a bearer token.

        Raises:
            ValueError: If no token provider configuration is set. Without this
                check the failure would be an opaque AttributeError on ``None``.
        """
        provider_config = self.azure_entra_id_token_provider_config
        if provider_config is None:
            message = (
                "Azure Entra ID oauth configuration has no token provider "
                "configuration, so a bearer token cannot be created."
            )
            structlogger.error(
                "azure_oauth_config.missing_token_provider_config",
                message=message,
            )
            raise ValueError(message)
        return provider_config.create_azure_token_provider()

    def get_bearer_token(self) -> str:
        """Returns a bearer token."""
        return self._create_azure_credential().get_token(*self.scopes).token
@@ -1,20 +1,15 @@
1
1
  from __future__ import annotations
2
2
 
3
- import logging
4
- from abc import abstractmethod
3
+ from copy import deepcopy
5
4
  from dataclasses import asdict, dataclass, field
6
- from enum import Enum
7
- from functools import lru_cache
8
- from typing import Any, Callable, Coroutine, Dict, List, Optional, Protocol, Set
5
+ from typing import (
6
+ Any,
7
+ Dict,
8
+ Optional,
9
+ Set,
10
+ )
9
11
 
10
12
  import structlog
11
- from azure.core.credentials import TokenProvider
12
- from azure.identity import (
13
- CertificateCredential,
14
- ClientSecretCredential,
15
- DefaultAzureCredential,
16
- get_bearer_token_provider,
17
- )
18
13
 
19
14
  from rasa.shared.constants import (
20
15
  API_BASE_CONFIG_KEY,
@@ -39,12 +34,22 @@ from rasa.shared.constants import (
39
34
  STREAM_CONFIG_KEY,
40
35
  TIMEOUT_CONFIG_KEY,
41
36
  )
37
+ from rasa.shared.providers._configs.azure_entra_id_config import (
38
+ AzureEntraIDOAuthConfig,
39
+ AzureEntraIDOAuthType,
40
+ )
41
+ from rasa.shared.providers._configs.oauth_config import (
42
+ OAUTH_KEY,
43
+ OAUTH_TYPE_FIELD,
44
+ OAuth,
45
+ )
42
46
  from rasa.shared.providers._configs.utils import (
43
47
  raise_deprecation_warnings,
44
48
  resolve_aliases,
45
49
  validate_forbidden_keys,
46
50
  validate_required_keys,
47
51
  )
52
+ from rasa.shared.utils.common import class_from_module_path
48
53
 
49
54
  structlogger = structlog.get_logger()
50
55
 
@@ -75,387 +80,33 @@ FORBIDDEN_KEYS = [
75
80
  ]
76
81
 
77
82
 
78
- AZURE_CLIENT_ID_FIELD = "client_id"
79
- AZURE_CLIENT_SECRET_FIELD = "client_secret"
80
- AZURE_TENANT_ID_FIELD = "tenant_id"
81
- AZURE_CERTIFICATE_PATH_FIELD = "certificate_path"
82
- AZURE_CERTIFICATE_PASSWORD_FIELD = "certificate_password"
83
- AZURE_SEND_CERTIFICATE_CHAIN_FIELD = "send_certificate_chain"
84
- AZURE_SCOPES_FIELD = "scopes"
85
- AZURE_AUTHORITY_FIELD = "authority_host"
86
- AZURE_DISABLE_INSTANCE_DISCOVERY_FIELD = "disable_instance_discovery"
87
- OAUTH_TYPE_FIELD = "type"
88
- AZURE_OAUTH_KEY = "oauth"
89
-
90
-
91
- azure_logger = logging.getLogger("azure")
92
- azure_logger.setLevel(logging.DEBUG)
93
-
94
-
95
- class AzureOAuthType(str, Enum):
96
- AZURE_DEFAULT = "default"
97
- AZURE_CLIENT_SECRET = "client_secret"
98
- AZURE_CLIENT_CERTIFICATE = "client_certificate"
99
- # Invalid type is used to indicate that the type
100
- # configuration is invalid or not set.
101
- INVALID = "invalid"
102
-
103
- @staticmethod
104
- def from_string(value: Optional[str]) -> AzureOAuthType:
105
- if value is None or value not in AzureOAuthType.valid_string_values():
106
- return AzureOAuthType.INVALID
107
-
108
- return AzureOAuthType(value)
109
-
110
- @staticmethod
111
- def valid_string_values() -> Set[str]:
112
- return {e.value for e in AzureOAuthType.valid_values()}
113
-
114
- @staticmethod
115
- def valid_values() -> Set[AzureOAuthType]:
116
- return {
117
- AzureOAuthType.AZURE_DEFAULT,
118
- AzureOAuthType.AZURE_CLIENT_SECRET,
119
- AzureOAuthType.AZURE_CLIENT_CERTIFICATE,
120
- }
121
-
122
-
123
- class AzureAuthType(str, Enum):
124
- API_KEY = "api_key"
125
- OAUTH = "oauth"
126
-
127
- @staticmethod
128
- def from_string(value: str) -> AzureAuthType:
129
- try:
130
- return AzureAuthType(value)
131
- except ValueError:
132
- raise ValueError(f"Invalid AzureAuthType value: {value}")
133
-
134
- def __str__(self) -> str:
135
- return self.value
136
-
137
-
138
- DEFAULT_AUTH_TYPE = AzureAuthType.API_KEY
139
-
140
-
141
- BearerTokenProvider = Callable[[], Coroutine[Any, Any, str]]
142
-
143
-
144
- class AzureEntraIDCredential(Protocol):
145
- @abstractmethod
146
- def create_azure_credential(self) -> TokenProvider: ...
147
- @abstractmethod
148
- def to_dict(self) -> dict: ...
149
-
150
-
151
83
  @dataclass
152
- class AzureClientCredentialsConfig:
153
- """Azure OAuth client credentials configuration.
154
-
155
- Attributes:
156
- client_id: The client ID.
157
- client_secret: The client secret.
158
- tenant_id: The tenant ID.
159
- authority_host: The authority host.
160
- disable_instance_discovery: Whether to disable instance discovery. This is used
161
- to disable fetching metadata from the Azure Instance Metadata Service.
162
- """
163
-
164
- client_id: str
165
- client_secret: str
166
- tenant_id: str
167
- authority_host: Optional[str] = None
168
- disable_instance_discovery: bool = False
169
-
170
- @staticmethod
171
- def required_fields() -> Set[str]:
172
- """Returns the required fields for the configuration."""
173
- return {AZURE_CLIENT_ID_FIELD, AZURE_TENANT_ID_FIELD, AZURE_CLIENT_SECRET_FIELD}
174
-
175
- @staticmethod
176
- def config_has_required_fields(config: Dict[str, Any]) -> bool:
177
- """Check if the configuration has all the required fields."""
178
- return AzureClientCredentialsConfig.required_fields().issubset(
179
- set(config.keys())
180
- )
84
+ class OAuthConfigWrapper(OAuth):
85
+ """Wrapper for OAuth configuration.
181
86
 
182
- @classmethod
183
- def from_config(cls, config: Dict[str, Any]) -> AzureClientCredentialsConfig:
184
- """Initializes a dataclass from the passed config.
87
+ It's main purpose is to provide to_dict method which is used to serialize
88
+ the oauth configuration to the original format.
185
89
 
186
- Args:
187
- config: (dict) The config from which to initialize.
188
-
189
- Returns:
190
- AzureClientCredentialsConfig
191
- """
192
- if not cls.config_has_required_fields(config):
193
- message = (
194
- f"A configuration for Azure client credentials "
195
- f"must contain the following keys: {cls.required_fields()}"
196
- )
197
- structlogger.error(
198
- "azure_client_credentials_config.missing_required_keys",
199
- message=message,
200
- config=config,
201
- )
202
- raise ValueError(message)
203
-
204
- return cls(
205
- client_id=config.pop(AZURE_CLIENT_ID_FIELD),
206
- client_secret=config.pop(AZURE_CLIENT_SECRET_FIELD),
207
- tenant_id=config.pop(AZURE_TENANT_ID_FIELD),
208
- authority_host=config.pop(AZURE_AUTHORITY_FIELD, None),
209
- disable_instance_discovery=config.pop(
210
- AZURE_DISABLE_INSTANCE_DISCOVERY_FIELD, False
211
- ),
212
- )
213
-
214
- def to_dict(self) -> dict:
215
- """Converts the config instance into a dictionary."""
216
- result = asdict(self)
217
- result[OAUTH_TYPE_FIELD] = AzureOAuthType.AZURE_CLIENT_SECRET.value
218
- return result
219
-
220
- def create_azure_credential(self) -> TokenProvider:
221
- return create_client_credentials(
222
- client_id=self.client_id,
223
- client_secret=self.client_secret,
224
- tenant_id=self.tenant_id,
225
- authority_host=self.authority_host,
226
- disable_instance_discovery=self.disable_instance_discovery,
227
- )
228
-
229
-
230
- @lru_cache
231
- def create_client_credentials(
232
- client_id: str,
233
- client_secret: str,
234
- tenant_id: str,
235
- authority_host: Optional[str] = None,
236
- disable_instance_discovery: bool = False,
237
- ) -> ClientSecretCredential:
238
- """Create a ClientSecretCredential.
239
-
240
- We cache the result of this function to avoid creating multiple instances
241
- of the same credential. This makes it possible to utilise the token caching
242
- functionality of the azure-identity library.
243
-
244
- Args:
245
- client_id: The client ID.
246
- client_secret: The client secret.
247
- tenant_id: The tenant ID.
248
- authority_host: The authority host.
249
- disable_instance_discovery: Whether to disable instance discovery. This is used
250
- to disable fetching metadata from the Azure Instance Metadata Service.
251
-
252
- Returns:
253
- ClientSecretCredential
254
90
  """
255
- return ClientSecretCredential(
256
- client_id=client_id,
257
- client_secret=client_secret,
258
- tenant_id=tenant_id,
259
- authority=authority_host,
260
- disable_instance_discovery=disable_instance_discovery,
261
- )
262
91
 
92
+ oauth: OAuth
93
+ original_config: Dict[str, Any]
263
94
 
264
- @dataclass
265
- class AzureClientCertificateConfig:
266
- """Azure OAuth client certificate configuration.
267
-
268
- Attributes:
269
- client_id: The client ID.
270
- tenant_id: The tenant ID.
271
- certificate_path: The path to the certificate file.
272
- certificate_password: The certificate password.
273
- send_certificate_chain: Whether to send the certificate chain.
274
- authority_host: The authority host.
275
- disable_instance_discovery: Whether to disable instance discovery. This is used
276
- to disable fetching metadata from the Azure Instance Metadata Service.
277
- """
278
-
279
- client_id: str
280
- tenant_id: str
281
- certificate_path: str
282
- certificate_password: Optional[str] = None
283
- send_certificate_chain: bool = False
284
- authority_host: Optional[str] = None
285
- disable_instance_discovery: bool = False
95
+ def get_bearer_token(self) -> str:
96
+ """Returns a bearer token."""
97
+ return self.oauth.get_bearer_token()
286
98
 
287
- @staticmethod
288
- def required_fields() -> Set[str]:
289
- """Returns the required fields for the configuration."""
290
- return {
291
- AZURE_CLIENT_ID_FIELD,
292
- AZURE_TENANT_ID_FIELD,
293
- AZURE_CERTIFICATE_PATH_FIELD,
294
- AZURE_CERTIFICATE_PASSWORD_FIELD,
295
- }
99
+ def to_dict(self) -> Dict[str, Any]:
100
+ """Converts the OAuth configuration to the original format."""
101
+ return self.original_config
296
102
 
297
103
  @staticmethod
298
- def config_has_required_fields(config: Dict[str, Any]) -> bool:
299
- """Check if the configuration has all the required fields."""
300
- return AzureClientCertificateConfig.required_fields().issubset(
301
- set(config.keys())
302
- )
104
+ def _valid_type_values() -> Set[str]:
105
+ """Returns the valid built-in values for the `type` field in the `oauth`."""
106
+ return AzureEntraIDOAuthType.valid_string_values()
303
107
 
304
108
  @classmethod
305
- def from_config(
306
- cls, config: Dict[str, Any]
307
- ) -> Optional[AzureClientCertificateConfig]:
308
- """Initializes a dataclass from the passed config.
309
-
310
- Args:
311
- config: (dict) The config from which to initialize.
312
-
313
- Returns:
314
- AzureClientCertificateConfig
315
- """
316
- if not cls.config_has_required_fields(config):
317
- message = (
318
- f"A configuration for Azure client certificate "
319
- f"must contain "
320
- f"the following keys: {cls.required_fields()}"
321
- )
322
- structlogger.error(
323
- "azure_client_certificate_config.validation_error",
324
- message=message,
325
- config=config,
326
- )
327
- raise ValueError(message)
328
-
329
- return cls(
330
- client_id=config[AZURE_CLIENT_ID_FIELD],
331
- tenant_id=config[AZURE_TENANT_ID_FIELD],
332
- certificate_path=config[AZURE_CERTIFICATE_PATH_FIELD],
333
- certificate_password=config.get(AZURE_CERTIFICATE_PASSWORD_FIELD, None),
334
- authority_host=config.get(AZURE_AUTHORITY_FIELD, None),
335
- send_certificate_chain=config.get(
336
- AZURE_SEND_CERTIFICATE_CHAIN_FIELD, False
337
- ),
338
- disable_instance_discovery=config.get(
339
- AZURE_DISABLE_INSTANCE_DISCOVERY_FIELD, False
340
- ),
341
- )
342
-
343
- def to_dict(self) -> dict:
344
- """Converts the config instance into a dictionary."""
345
- result = asdict(self)
346
- result[OAUTH_TYPE_FIELD] = AzureOAuthType.AZURE_CLIENT_CERTIFICATE.value
347
- return result
348
-
349
- def create_azure_credential(self) -> TokenProvider:
350
- return create_certificate_credentials(
351
- client_id=self.client_id,
352
- tenant_id=self.tenant_id,
353
- certificate_path=self.certificate_path,
354
- password=self.certificate_password,
355
- send_certificate_chain=self.send_certificate_chain,
356
- authority_host=self.authority_host,
357
- disable_instance_discovery=self.disable_instance_discovery,
358
- )
359
-
360
-
361
- @lru_cache
362
- def create_certificate_credentials(
363
- tenant_id: str,
364
- client_id: str,
365
- certificate_path: Optional[str] = None,
366
- password: Optional[str] = None,
367
- send_certificate_chain: bool = False,
368
- authority_host: Optional[str] = None,
369
- disable_instance_discovery: bool = False,
370
- ) -> CertificateCredential:
371
- """Create a CertificateCredential.
372
-
373
- We cache the result of this function to avoid creating multiple instances
374
- of the same credential. This makes it possible to utilise the token caching
375
- functionality of the azure-identity library.
376
-
377
- Args:
378
- tenant_id: The tenant ID.
379
- client_id: The client ID.
380
- certificate_path: The path to the certificate file.
381
- password: The certificate password.
382
- send_certificate_chain: Whether to send the certificate chain.
383
- authority_host: The authority host.
384
- disable_instance_discovery: Whether to disable instance discovery. This is used
385
-
386
- Returns:
387
- CertificateCredential
388
- """
389
-
390
- return CertificateCredential(
391
- client_id=client_id,
392
- tenant_id=tenant_id,
393
- certificate_path=certificate_path,
394
- password=password.encode("utf-8") if password else None,
395
- send_certificate_chain=send_certificate_chain,
396
- authority=authority_host,
397
- disable_instance_discovery=disable_instance_discovery,
398
- )
399
-
400
-
401
- @dataclass
402
- class AzureOAuthDefaultCredentialsConfig:
403
- """Azure OAuth default credentials configuration.
404
-
405
- Attributes:
406
- authority_host: The authority host.
407
- """
408
-
409
- authority_host: Optional[str] = None
410
-
411
- @classmethod
412
- def from_config(cls, config: Dict[str, Any]) -> AzureOAuthDefaultCredentialsConfig:
413
- """Initializes a dataclass from the passed config.
414
-
415
- Args:
416
- config: (dict) The config from which to initialize.
417
-
418
- Returns:
419
- AzureOAuthDefaultCredentialsConfig
420
- """
421
- return cls(authority_host=config.pop(AZURE_AUTHORITY_FIELD, None))
422
-
423
- def to_dict(self) -> dict:
424
- """Converts the config instance into a dictionary."""
425
- result = asdict(self)
426
- result[OAUTH_TYPE_FIELD] = AzureOAuthType.AZURE_DEFAULT.value
427
- return result
428
-
429
- def create_azure_credential(self) -> TokenProvider:
430
- return create_default_credentials(authority_host=self.authority_host)
431
-
432
-
433
- @lru_cache
434
- def create_default_credentials(
435
- authority_host: Optional[str] = None,
436
- ) -> DefaultAzureCredential:
437
- """Create a DefaultAzureCredential.
438
-
439
- We cache the result of this function to avoid creating multiple instances
440
- of the same credential. This makes it possible to utilise the token caching
441
- functionality of the azure-identity library.
442
-
443
- Args:
444
- authority_host: The authority host.
445
-
446
- Returns:
447
- DefaultAzureCredential
448
- """
449
- return DefaultAzureCredential(authority=authority_host)
450
-
451
-
452
- @dataclass
453
- class AzureOAuthConfig:
454
- scopes: List[str]
455
- azure_credentials: Optional[AzureEntraIDCredential] = None
456
-
457
- @classmethod
458
- def from_config(cls, oauth_config: Dict[str, Any]) -> AzureOAuthConfig:
109
+ def from_config(cls, oauth_config: Dict[str, Any]) -> OAuth:
459
110
  """Initializes a dataclass from the passed config.
460
111
 
461
112
  Args:
@@ -464,15 +115,16 @@ class AzureOAuthConfig:
464
115
  Returns:
465
116
  AzureOAuthConfig
466
117
  """
467
- oauth_type = AzureOAuthType.from_string(
468
- oauth_config.pop(OAUTH_TYPE_FIELD, None)
469
- )
118
+ original_config = deepcopy(oauth_config)
119
+
120
+ oauth_type: Optional[str] = oauth_config.get(OAUTH_TYPE_FIELD, None)
470
121
 
471
- if oauth_type == AzureOAuthType.INVALID:
122
+ if oauth_type is None:
472
123
  message = (
473
- "Azure Entra ID oauth configuration must contain "
124
+ "Oauth configuration must contain "
474
125
  f"'{OAUTH_TYPE_FIELD}' field and it must be set to one of the "
475
- f"following values: {AzureOAuthType.valid_string_values()}, "
126
+ f"following values: {OAuthConfigWrapper._valid_type_values()}, "
127
+ f"or to the path of module which is implementing {OAuth.__name__} protocol."
476
128
  )
477
129
  structlogger.error(
478
130
  "azure_oauth_config.missing_oauth_type",
@@ -480,52 +132,22 @@ class AzureOAuthConfig:
480
132
  )
481
133
  raise ValueError(message)
482
134
 
483
- azure_credentials = None
484
- if oauth_type == AzureOAuthType.AZURE_CLIENT_SECRET:
485
- azure_credentials = AzureClientCredentialsConfig.from_config(oauth_config)
486
- elif oauth_type == AzureOAuthType.AZURE_CLIENT_CERTIFICATE:
487
- azure_credentials = AzureClientCertificateConfig.from_config(oauth_config)
488
- elif oauth_type == AzureOAuthType.AZURE_DEFAULT:
489
- azure_credentials = AzureOAuthDefaultCredentialsConfig.from_config(
490
- oauth_config
491
- )
492
-
493
- scopes = oauth_config.pop(AZURE_SCOPES_FIELD, "")
494
-
495
- if not scopes:
496
- message = "Azure Entra ID scopes cannot be empty."
497
- structlogger.error(
498
- "azure_oauth_config.scopes_empty",
499
- message=message,
500
- )
501
- raise ValueError(message)
502
-
503
- if isinstance(scopes, str):
504
- scopes = [scopes]
505
-
506
- return cls(azure_credentials=azure_credentials, scopes=scopes)
507
-
508
- def create_azure_credential(
509
- self,
510
- ) -> TokenProvider:
511
- return self.azure_credentials.create_azure_credential()
135
+ if oauth_type in AzureEntraIDOAuthType.valid_string_values():
136
+ oauth = AzureEntraIDOAuthConfig.from_config(oauth_config)
137
+ else:
138
+ module = class_from_module_path(oauth_type)
512
139
 
513
- def to_dict(self) -> dict:
514
- """Converts the config instance into a dictionary."""
515
- credentials_dict = (
516
- self.azure_credentials.to_dict() if self.azure_credentials else {}
517
- )
518
- result = asdict(self)
519
- result.update(credentials_dict)
520
- result.pop("azure_credentials", None)
521
- return result
140
+ if not issubclass(module, OAuth):
141
+ message = f"Module {oauth_type} does not implement {OAuth.__name__} interface."
142
+ structlogger.error(
143
+ "azure_oauth_config.invalid_oauth_module",
144
+ message=message,
145
+ )
146
+ raise ValueError(message)
522
147
 
523
- def get_bearer_token_provider(self) -> BearerTokenProvider:
524
- return get_bearer_token_provider(self.create_azure_credential(), *self.scopes)
148
+ oauth = module.from_config(oauth_config)
525
149
 
526
- def get_bearer_token(self) -> str:
527
- token = self.create_azure_credential().get_token(*self.scopes).token
528
- return token
150
+ return cls(oauth=oauth, original_config=original_config)
529
151
 
530
152
 
531
153
  @dataclass
@@ -552,7 +174,7 @@ class AzureOpenAIClientConfig:
552
174
  provider: str = AZURE_OPENAI_PROVIDER
553
175
 
554
176
  # OAuth related parameters
555
- oauth: Optional[AzureOAuthConfig] = None
177
+ oauth: Optional[OAuthConfigWrapper] = None
556
178
 
557
179
  extra_parameters: dict = field(default_factory=dict)
558
180
 
@@ -600,13 +222,13 @@ class AzureOpenAIClientConfig:
600
222
  # Init client config
601
223
 
602
224
  has_api_key = config.get(API_KEY, None) is not None
603
- has_oauth_key = config.get(AZURE_OAUTH_KEY, None) is not None
225
+ has_oauth_key = config.get(OAUTH_KEY, None) is not None
604
226
 
605
227
  if has_api_key and has_oauth_key:
606
228
  message = (
607
229
  "Azure OpenAI client configuration cannot contain "
608
- "both 'api_key' and 'oauth' fields. Please provide either "
609
- "'api_key' or 'oauth' fields."
230
+ f"both '{API_KEY}' and '{OAUTH_KEY}' fields. Please provide either "
231
+ f"'{API_KEY}' or '{OAUTH_KEY}' fields."
610
232
  )
611
233
  structlogger.error(
612
234
  "azure_openai_client_config.multiple_auth_types_specified",
@@ -616,7 +238,7 @@ class AzureOpenAIClientConfig:
616
238
 
617
239
  oauth = None
618
240
  if has_oauth_key:
619
- oauth = AzureOAuthConfig.from_config(config.pop(AZURE_OAUTH_KEY))
241
+ oauth = OAuthConfigWrapper.from_config(config.pop(OAUTH_KEY))
620
242
 
621
243
  this = AzureOpenAIClientConfig(
622
244
  # Required parameters
@@ -0,0 +1,33 @@
1
+ import abc
2
+ from typing import Any, Dict, TypeVar
3
+
4
+ OAUTH_TYPE_FIELD = "type"
5
+ OAUTH_KEY = "oauth"
6
+
7
+ OAuthType = TypeVar("OAuthType", bound="OAuth")
8
+
9
+
10
+ class OAuth(abc.ABC):
11
+ """Interface for OAuth configuration."""
12
+
13
+ @classmethod
14
+ @abc.abstractmethod
15
+ def from_config(cls: OAuthType, config: Dict[str, Any]) -> OAuthType:
16
+ """Initializes a dataclass from the passed config.
17
+
18
+ Args:
19
+ config: (dict) The config from which to initialize.
20
+
21
+ Returns:
22
+ OAuth
23
+ """
24
+ ...
25
+
26
+ @abc.abstractmethod
27
+ def get_bearer_token(self) -> str:
28
+ """Returns a bearer token.
29
+
30
+ Bear token is used to authenticate requests to the Azure Oopen AI instance's API protected
31
+ by the Gateway.
32
+ """
33
+ ...
@@ -3,5 +3,4 @@ LITE_LLM_API_BASE_FIELD = "api_base"
3
3
  LITE_LLM_API_KEY_FIELD = "api_key"
4
4
  LITE_LLM_API_VERSION_FIELD = "api_version"
5
5
  LITE_LLM_MODEL_FIELD = "model"
6
- LITE_LLM_AZURE_AD_TOKEN_PROVIDER = "azure_ad_token_provider"
7
6
  LITE_LLM_AZURE_AD_TOKEN = "azure_ad_token"
@@ -19,14 +19,13 @@ from rasa.shared.constants import (
19
19
  )
20
20
  from rasa.shared.exceptions import ProviderClientValidationError
21
21
  from rasa.shared.providers._configs.azure_openai_client_config import (
22
- AzureOAuthConfig,
22
+ AzureEntraIDOAuthConfig,
23
23
  AzureOpenAIClientConfig,
24
24
  )
25
25
  from rasa.shared.providers.constants import (
26
26
  DEFAULT_AZURE_API_KEY_NAME,
27
27
  LITE_LLM_API_KEY_FIELD,
28
28
  LITE_LLM_AZURE_AD_TOKEN,
29
- LITE_LLM_AZURE_AD_TOKEN_PROVIDER,
30
29
  )
31
30
  from rasa.shared.providers.embedding._base_litellm_embedding_client import (
32
31
  _BaseLiteLLMEmbeddingClient,
@@ -48,7 +47,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
48
47
  If not provided, it will be set via environment variable.
49
48
  api_version (Optional[str]): The version of the API to use.
50
49
  If not provided, it will be set via environment variable.
51
- oauth (Optional[AzureOAuthConfig]): Optional OAuth configuration. If provided,
50
+ oauth (Optional[AzureEntraIDOAuthConfig]): Optional OAuth configuration. If provided,
52
51
  the client will use OAuth for authentication.
53
52
  kwargs (Optional[Dict[str, Any]]): Optional configuration parameters specific
54
53
  to the embedding model deployment.
@@ -66,7 +65,7 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
66
65
  api_base: Optional[str] = None,
67
66
  api_type: Optional[str] = None,
68
67
  api_version: Optional[str] = None,
69
- oauth: Optional[AzureOAuthConfig] = None,
68
+ oauth: Optional[AzureEntraIDOAuthConfig] = None,
70
69
  **kwargs: Any,
71
70
  ):
72
71
  super().__init__() # type: ignore
@@ -238,7 +237,6 @@ class AzureOpenAIEmbeddingClient(_BaseLiteLLMEmbeddingClient):
238
237
  auth_parameter = (
239
238
  {
240
239
  LITE_LLM_AZURE_AD_TOKEN: self._oauth.get_bearer_token(),
241
- LITE_LLM_AZURE_AD_TOKEN_PROVIDER: self._oauth.get_bearer_token_provider(), # noqa: E501
242
240
  }
243
241
  if self._oauth
244
242
  else {LITE_LLM_API_KEY_FIELD: self._api_key_env_var}
@@ -23,7 +23,7 @@ from rasa.shared.constants import (
23
23
  )
24
24
  from rasa.shared.exceptions import ProviderClientValidationError
25
25
  from rasa.shared.providers._configs.azure_openai_client_config import (
26
- AzureOAuthConfig,
26
+ AzureEntraIDOAuthConfig,
27
27
  AzureOpenAIClientConfig,
28
28
  )
29
29
  from rasa.shared.providers.constants import (
@@ -32,7 +32,6 @@ from rasa.shared.providers.constants import (
32
32
  LITE_LLM_API_KEY_FIELD,
33
33
  LITE_LLM_API_VERSION_FIELD,
34
34
  LITE_LLM_AZURE_AD_TOKEN,
35
- LITE_LLM_AZURE_AD_TOKEN_PROVIDER,
36
35
  )
37
36
  from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
38
37
  from rasa.shared.utils.io import raise_deprecation_warning
@@ -86,7 +85,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
86
85
  api_type: Optional[str] = None,
87
86
  api_base: Optional[str] = None,
88
87
  api_version: Optional[str] = None,
89
- oauth: Optional[AzureOAuthConfig] = None,
88
+ oauth: Optional[AzureEntraIDOAuthConfig] = None,
90
89
  **kwargs: Any,
91
90
  ):
92
91
  super().__init__() # type: ignore
@@ -329,7 +328,6 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
329
328
  auth_parameter = (
330
329
  {
331
330
  LITE_LLM_AZURE_AD_TOKEN: self._oauth.get_bearer_token(),
332
- LITE_LLM_AZURE_AD_TOKEN_PROVIDER: self._oauth.get_bearer_token_provider(), # noqa: E501
333
331
  }
334
332
  if self._oauth
335
333
  else {LITE_LLM_API_KEY_FIELD: self._api_key_env_var}
rasa/version.py CHANGED
@@ -1,3 +1,3 @@
1
1
  # this file will automatically be changed,
2
2
  # do not add anything but the version number here!
3
- __version__ = "3.12.0dev1"
3
+ __version__ = "3.12.0dev2"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: rasa-pro
3
- Version: 3.12.0.dev1
3
+ Version: 3.12.0.dev2
4
4
  Summary: State-of-the-art open-core Conversational AI framework for Enterprises that natively leverages generative AI for effortless assistant development.
5
5
  Home-page: https://rasa.com
6
6
  Keywords: nlp,machine-learning,machine-learning-library,bot,bots,botkit,rasa conversational-agents,conversational-ai,chatbot,chatbot-framework,bot-framework
@@ -672,23 +672,26 @@ rasa/shared/nlu/training_data/training_data.py,sha256=KY51CJD9NY6vs9Zs7e-ivtyIYJ
672
672
  rasa/shared/nlu/training_data/util.py,sha256=mom7CxLKI5RfOpsJrAKL281a_b01sIcQsr04gSmEEbU,7049
673
673
  rasa/shared/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
674
674
  rasa/shared/providers/_configs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
675
- rasa/shared/providers/_configs/azure_openai_client_config.py,sha256=4CwnOv5X0VQ7WbfGdT_x_Fkx5WHBYRDIUPKxH8iC6TY,22824
675
+ rasa/shared/providers/_configs/azure_entra_id_client_creds.py,sha256=S6slRcUERNAbkA2sp-Ykk3Y6wl6EnvnHYLY_ILsveVc,1103
676
+ rasa/shared/providers/_configs/azure_entra_id_config.py,sha256=bScClwyvCW7EO0GCvgUOS77CscFG2JrkH2zsojm4L1g,18391
677
+ rasa/shared/providers/_configs/azure_openai_client_config.py,sha256=01kg6TqjfjkZ6q2zASs-jp7j39adyLp6F37yi2fVpDo,10076
676
678
  rasa/shared/providers/_configs/client_config.py,sha256=nQ469h1XI970_7Vs49hNIpBIwlAeiAg-cwV0JFp7Hg0,1618
677
679
  rasa/shared/providers/_configs/default_litellm_client_config.py,sha256=tViurJ1NDbiBn9b5DbzhFHO1pJM889MC-GakWhEX07E,4352
678
680
  rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py,sha256=q8ddTFwddDhx654ZQmg9eP_yo77N3Xg77hAmfXOmzPg,8200
679
681
  rasa/shared/providers/_configs/litellm_router_client_config.py,sha256=OX7egiQXkGSYxIfEOFrGFwCIKFJc3IgBKrZGqdjeMVQ,7265
680
682
  rasa/shared/providers/_configs/model_group_config.py,sha256=E1_hjP1p9b08m-28EPNNCwoFUJ1MMdkmC54Bai6JK0A,5730
683
+ rasa/shared/providers/_configs/oauth_config.py,sha256=OPwignGb4xGI7GCTsfVUCM57fQYNDk-_IVXkt-OHyw0,775
681
684
  rasa/shared/providers/_configs/openai_client_config.py,sha256=tKCQSjtpVmPO_30sRmcFFDk0tNFs5bVseyI7iBU6ZOY,5839
682
685
  rasa/shared/providers/_configs/rasa_llm_client_config.py,sha256=elpbqVNSgkAiM0Dg-0N3ayVkSi6TAERepdZG7Bv8NdI,2245
683
686
  rasa/shared/providers/_configs/self_hosted_llm_client_config.py,sha256=l2JnypPXFL6KVxhftKTYvh-NqpXJ8--pjbJ-IQHoPRs,5963
684
687
  rasa/shared/providers/_configs/utils.py,sha256=-1WxEcrV5WHv3Q6GVGTJJZcdBe_p8NU4ArVspTTa8mg,3731
685
688
  rasa/shared/providers/_ssl_verification_utils.py,sha256=4tujCOjg0KKX2_DzOb7lZTdsUXtzRB4UkfhkC3W0jO0,4166
686
689
  rasa/shared/providers/_utils.py,sha256=JW2A1FM4QsIWNUH3QoMo_DRiZ90UW9tiu9n9RR1Oih4,3090
687
- rasa/shared/providers/constants.py,sha256=LYnqUf511ydlQcC7Ka6J_g5DGpi_fN48AVlEY577zJQ,295
690
+ rasa/shared/providers/constants.py,sha256=hgV8yNGxIbID_2h65OoSfSjIE4UkazrsqRg4SdkPAmI,234
688
691
  rasa/shared/providers/embedding/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
689
692
  rasa/shared/providers/embedding/_base_litellm_embedding_client.py,sha256=EW1AdTklBDJTVDGKslBZ-IMi7WgRzBdKbLIn9BiuXX4,8796
690
693
  rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py,sha256=IR2Rb3ReJ9C9sxOoOGRXgtz8STWdMREs_4AeSMKFjl4,2135
691
- rasa/shared/providers/embedding/azure_openai_embedding_client.py,sha256=EIezqXhrMb6bEg49OkClIXGb-H-0o7CNnap3Nf9pcLw,12483
694
+ rasa/shared/providers/embedding/azure_openai_embedding_client.py,sha256=QvFnnu-Jmcq9ABnNbxc4sWyi1mwVz7nWtuzpjIccokI,12361
692
695
  rasa/shared/providers/embedding/default_litellm_embedding_client.py,sha256=da17WeHjZp95Uv9jmTKxklNRcNpn-qRsRPcwDQusElg,4397
693
696
  rasa/shared/providers/embedding/embedding_client.py,sha256=rmFBKSKSihqmzpuZ-I0zVm1BBqTjL6V-K65gefoI35o,2839
694
697
  rasa/shared/providers/embedding/embedding_response.py,sha256=H55mSAL3LfVvDlBklaCCQ4AnNwCsQSQ1f2D0oPrx3FY,1204
@@ -697,7 +700,7 @@ rasa/shared/providers/embedding/litellm_router_embedding_client.py,sha256=eafDk6
697
700
  rasa/shared/providers/embedding/openai_embedding_client.py,sha256=XNRGE7apo2v3kWRrtgxE-Gq4rvNko3IiXtvgC4krDYE,5429
698
701
  rasa/shared/providers/llm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
699
702
  rasa/shared/providers/llm/_base_litellm_client.py,sha256=ZoLGPfFxa2YhA7JuEk8W2E_an1m_eK5MXxdFLwj026M,10038
700
- rasa/shared/providers/llm/azure_openai_llm_client.py,sha256=HXuUX0eNH5gy-0Nmm5YuqGPlLZvRrUV8Qs43pM30zvE,15010
703
+ rasa/shared/providers/llm/azure_openai_llm_client.py,sha256=59rJi3IzznDkPkFwfJX3fCGPt0veUpJMqcPdBQLF4pw,14881
701
704
  rasa/shared/providers/llm/default_litellm_llm_client.py,sha256=e3f-YMS7-aariB5erRot7NReD-eaVPgeD45rypF-sUw,3974
702
705
  rasa/shared/providers/llm/litellm_router_llm_client.py,sha256=e9OIQrXH80G_JxzOn3rIATAHl4nmN-kTK2RLnWpAuGQ,7592
703
706
  rasa/shared/providers/llm/llm_client.py,sha256=c2pYAS-0AmZKe_gLTptwmBspXFaCUDqK0qLTqo3l2Ok,2392
@@ -782,9 +785,9 @@ rasa/utils/train_utils.py,sha256=f1NWpp5y6al0dzoQyyio4hc4Nf73DRoRSHDzEK6-C4E,212
782
785
  rasa/utils/url_tools.py,sha256=JQcHL2aLqLHu82k7_d9imUoETCm2bmlHaDpOJ-dKqBc,1218
783
786
  rasa/utils/yaml.py,sha256=KjbZq5C94ZP7Jdsw8bYYF7HASI6K4-C_kdHfrnPLpSI,2000
784
787
  rasa/validator.py,sha256=wl5IKiyDmk6FlDcGO2Js-H-gHPeqVqUJ6hB4fgN0xjI,66796
785
- rasa/version.py,sha256=GJdudg_5YuWRFetde9-Wwcs5ARtFvjExGxcxDgvlQxw,121
786
- rasa_pro-3.12.0.dev1.dist-info/METADATA,sha256=HGhiL9O-I1zYt-UDu0OvXiUhPSTRhCocizU1ZuZ6fPo,10844
787
- rasa_pro-3.12.0.dev1.dist-info/NOTICE,sha256=7HlBoMHJY9CL2GlYSfTQ-PZsVmLmVkYmMiPlTjhuCqA,218
788
- rasa_pro-3.12.0.dev1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
789
- rasa_pro-3.12.0.dev1.dist-info/entry_points.txt,sha256=ckJ2SfEyTPgBqj_I6vm_tqY9dZF_LAPJZA335Xp0Q9U,43
790
- rasa_pro-3.12.0.dev1.dist-info/RECORD,,
788
+ rasa/version.py,sha256=Qk-yx_nkp859YygvuBicHLhsbvQxDrL2J8-SMq3DiIw,121
789
+ rasa_pro-3.12.0.dev2.dist-info/METADATA,sha256=ugLXE3EIrVQcn9W5z_5PTSKMLm5G0N7lTh4y9wrL8nA,10844
790
+ rasa_pro-3.12.0.dev2.dist-info/NOTICE,sha256=7HlBoMHJY9CL2GlYSfTQ-PZsVmLmVkYmMiPlTjhuCqA,218
791
+ rasa_pro-3.12.0.dev2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
792
+ rasa_pro-3.12.0.dev2.dist-info/entry_points.txt,sha256=ckJ2SfEyTPgBqj_I6vm_tqY9dZF_LAPJZA335Xp0Q9U,43
793
+ rasa_pro-3.12.0.dev2.dist-info/RECORD,,