rasa-pro 3.12.0.dev13__py3-none-any.whl → 3.12.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (128)
  1. rasa/anonymization/anonymization_rule_executor.py +16 -10
  2. rasa/cli/data.py +16 -0
  3. rasa/cli/project_templates/calm/config.yml +2 -2
  4. rasa/cli/project_templates/calm/endpoints.yml +2 -2
  5. rasa/cli/utils.py +12 -0
  6. rasa/core/actions/action.py +84 -191
  7. rasa/core/actions/action_run_slot_rejections.py +16 -4
  8. rasa/core/channels/__init__.py +2 -0
  9. rasa/core/channels/studio_chat.py +19 -0
  10. rasa/core/channels/telegram.py +42 -24
  11. rasa/core/channels/voice_ready/utils.py +1 -1
  12. rasa/core/channels/voice_stream/asr/asr_engine.py +10 -4
  13. rasa/core/channels/voice_stream/asr/azure.py +14 -1
  14. rasa/core/channels/voice_stream/asr/deepgram.py +20 -4
  15. rasa/core/channels/voice_stream/audiocodes.py +264 -0
  16. rasa/core/channels/voice_stream/browser_audio.py +4 -1
  17. rasa/core/channels/voice_stream/call_state.py +3 -0
  18. rasa/core/channels/voice_stream/genesys.py +6 -2
  19. rasa/core/channels/voice_stream/tts/azure.py +9 -1
  20. rasa/core/channels/voice_stream/tts/cartesia.py +14 -8
  21. rasa/core/channels/voice_stream/voice_channel.py +23 -2
  22. rasa/core/constants.py +2 -0
  23. rasa/core/nlg/contextual_response_rephraser.py +18 -1
  24. rasa/core/nlg/generator.py +83 -15
  25. rasa/core/nlg/response.py +6 -3
  26. rasa/core/nlg/translate.py +55 -0
  27. rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +1 -1
  28. rasa/core/policies/flows/flow_executor.py +12 -5
  29. rasa/core/processor.py +72 -9
  30. rasa/dialogue_understanding/commands/can_not_handle_command.py +20 -2
  31. rasa/dialogue_understanding/commands/cancel_flow_command.py +24 -6
  32. rasa/dialogue_understanding/commands/change_flow_command.py +20 -2
  33. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +20 -2
  34. rasa/dialogue_understanding/commands/clarify_command.py +29 -3
  35. rasa/dialogue_understanding/commands/command.py +1 -16
  36. rasa/dialogue_understanding/commands/command_syntax_manager.py +55 -0
  37. rasa/dialogue_understanding/commands/human_handoff_command.py +20 -2
  38. rasa/dialogue_understanding/commands/knowledge_answer_command.py +20 -2
  39. rasa/dialogue_understanding/commands/prompt_command.py +94 -0
  40. rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +20 -2
  41. rasa/dialogue_understanding/commands/set_slot_command.py +24 -2
  42. rasa/dialogue_understanding/commands/skip_question_command.py +20 -2
  43. rasa/dialogue_understanding/commands/start_flow_command.py +20 -2
  44. rasa/dialogue_understanding/commands/utils.py +98 -4
  45. rasa/dialogue_understanding/generator/__init__.py +2 -0
  46. rasa/dialogue_understanding/generator/command_parser.py +15 -12
  47. rasa/dialogue_understanding/generator/constants.py +3 -0
  48. rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -5
  49. rasa/dialogue_understanding/generator/llm_command_generator.py +5 -3
  50. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +16 -2
  51. rasa/dialogue_understanding/generator/prompt_templates/__init__.py +0 -0
  52. rasa/dialogue_understanding/generator/{single_step → prompt_templates}/command_prompt_template.jinja2 +2 -0
  53. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +77 -0
  54. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_default.jinja2 +68 -0
  55. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +84 -0
  56. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +460 -0
  57. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +12 -310
  58. rasa/dialogue_understanding/patterns/collect_information.py +1 -1
  59. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +16 -0
  60. rasa/dialogue_understanding/patterns/validate_slot.py +65 -0
  61. rasa/dialogue_understanding/processor/command_processor.py +39 -0
  62. rasa/dialogue_understanding_test/du_test_case.py +28 -8
  63. rasa/dialogue_understanding_test/du_test_result.py +13 -9
  64. rasa/dialogue_understanding_test/io.py +14 -0
  65. rasa/e2e_test/utils/io.py +0 -37
  66. rasa/engine/graph.py +1 -0
  67. rasa/engine/language.py +140 -0
  68. rasa/engine/recipes/config_files/default_config.yml +4 -0
  69. rasa/engine/recipes/default_recipe.py +2 -0
  70. rasa/engine/recipes/graph_recipe.py +2 -0
  71. rasa/engine/storage/local_model_storage.py +1 -0
  72. rasa/engine/storage/storage.py +4 -1
  73. rasa/model_manager/runner_service.py +7 -4
  74. rasa/model_manager/socket_bridge.py +7 -6
  75. rasa/shared/constants.py +15 -13
  76. rasa/shared/core/constants.py +2 -0
  77. rasa/shared/core/flows/constants.py +11 -0
  78. rasa/shared/core/flows/flow.py +83 -19
  79. rasa/shared/core/flows/flows_yaml_schema.json +31 -3
  80. rasa/shared/core/flows/steps/collect.py +1 -36
  81. rasa/shared/core/flows/utils.py +28 -4
  82. rasa/shared/core/flows/validation.py +1 -1
  83. rasa/shared/core/slot_mappings.py +208 -5
  84. rasa/shared/core/slots.py +131 -1
  85. rasa/shared/core/trackers.py +74 -1
  86. rasa/shared/importers/importer.py +50 -2
  87. rasa/shared/nlu/training_data/schemas/responses.yml +19 -12
  88. rasa/shared/providers/_configs/azure_entra_id_config.py +541 -0
  89. rasa/shared/providers/_configs/azure_openai_client_config.py +138 -3
  90. rasa/shared/providers/_configs/client_config.py +3 -1
  91. rasa/shared/providers/_configs/default_litellm_client_config.py +3 -1
  92. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +3 -1
  93. rasa/shared/providers/_configs/litellm_router_client_config.py +3 -1
  94. rasa/shared/providers/_configs/model_group_config.py +4 -2
  95. rasa/shared/providers/_configs/oauth_config.py +33 -0
  96. rasa/shared/providers/_configs/openai_client_config.py +3 -1
  97. rasa/shared/providers/_configs/rasa_llm_client_config.py +3 -1
  98. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +3 -1
  99. rasa/shared/providers/constants.py +6 -0
  100. rasa/shared/providers/embedding/azure_openai_embedding_client.py +28 -3
  101. rasa/shared/providers/embedding/litellm_router_embedding_client.py +3 -1
  102. rasa/shared/providers/llm/_base_litellm_client.py +42 -17
  103. rasa/shared/providers/llm/azure_openai_llm_client.py +81 -25
  104. rasa/shared/providers/llm/default_litellm_llm_client.py +3 -1
  105. rasa/shared/providers/llm/litellm_router_llm_client.py +29 -8
  106. rasa/shared/providers/llm/llm_client.py +23 -7
  107. rasa/shared/providers/llm/openai_llm_client.py +9 -3
  108. rasa/shared/providers/llm/rasa_llm_client.py +11 -2
  109. rasa/shared/providers/llm/self_hosted_llm_client.py +30 -11
  110. rasa/shared/providers/router/_base_litellm_router_client.py +3 -1
  111. rasa/shared/providers/router/router_client.py +3 -1
  112. rasa/shared/utils/constants.py +3 -0
  113. rasa/shared/utils/llm.py +30 -7
  114. rasa/shared/utils/pykwalify_extensions.py +24 -0
  115. rasa/shared/utils/schemas/domain.yml +26 -0
  116. rasa/telemetry.py +2 -1
  117. rasa/tracing/config.py +2 -0
  118. rasa/tracing/constants.py +12 -0
  119. rasa/tracing/instrumentation/instrumentation.py +36 -0
  120. rasa/tracing/instrumentation/metrics.py +41 -0
  121. rasa/tracing/metric_instrument_provider.py +40 -0
  122. rasa/validator.py +372 -7
  123. rasa/version.py +1 -1
  124. {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/METADATA +2 -1
  125. {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/RECORD +128 -113
  126. {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/NOTICE +0 -0
  127. {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/WHEEL +0 -0
  128. {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,541 @@
1
+ from __future__ import annotations
2
+
3
+ import abc
4
+ import logging
5
+ from copy import deepcopy
6
+ from enum import Enum
7
+ from functools import lru_cache
8
+ from typing import Any, Callable, Dict, List, Optional, Set, Type
9
+
10
+ import structlog
11
+ from azure.core.credentials import TokenProvider
12
+ from azure.identity import (
13
+ CertificateCredential,
14
+ ClientSecretCredential,
15
+ DefaultAzureCredential,
16
+ )
17
+ from pydantic import BaseModel, Field, SecretStr
18
+
19
+ from rasa.shared.providers._configs.oauth_config import OAUTH_TYPE_FIELD, OAuth
20
+
21
# Keys recognised in an Azure Entra ID OAuth configuration dictionary.

# Application identity (used by the client-secret and client-certificate flows).
AZURE_CLIENT_ID_FIELD = "client_id"
AZURE_TENANT_ID_FIELD = "tenant_id"

# Client-secret flow.
AZURE_CLIENT_SECRET_FIELD = "client_secret"

# Client-certificate flow.
AZURE_CERTIFICATE_PATH_FIELD = "certificate_path"
AZURE_CERTIFICATE_PASSWORD_FIELD = "certificate_password"
AZURE_SEND_CERTIFICATE_CHAIN_FIELD = "send_certificate_chain"

# Options shared by several credential types.
AZURE_AUTHORITY_FIELD = "authority_host"
AZURE_DISABLE_INSTANCE_DISCOVERY_FIELD = "disable_instance_discovery"

# Scopes for which access tokens are requested.
AZURE_SCOPES_FIELD = "scopes"
30
+
31
+
32
# Logger of the Azure SDK packages. Setting its level to DEBUG at import time also
# lowers the effective level of every child logger (e.g. "azure.identity") that
# does not set its own level.
# NOTE(review): this overrides any logging configuration the application applies
# to the Azure SDK and can make Azure debug output reach the application's
# handlers. Confirm this is intended and not a leftover from debugging.
azure_logger = logging.getLogger("azure")
azure_logger.setLevel(logging.DEBUG)

# Structured logger used for the error events emitted by this module.
structlogger = structlog.get_logger()
36
+
37
+
38
class AzureEntraIDOAuthType(str, Enum):
    """Kinds of Azure Entra ID OAuth credentials that can be configured.

    The enum mixes in `str`, so every member compares equal to its string value.
    """

    AZURE_ENTRA_ID_DEFAULT = "azure_entra_id_default"
    AZURE_ENTRA_ID_CLIENT_SECRET = "azure_entra_id_client_secret"
    AZURE_ENTRA_ID_CLIENT_CERTIFICATE = "azure_entra_id_client_certificate"

    # Sentinel: the configured type is missing or is not one of the supported
    # Entra ID types listed above.
    INVALID = "invalid"

    @staticmethod
    def from_string(value: Optional[str]) -> AzureEntraIDOAuthType:
        """Parse a raw configuration string into an `AzureEntraIDOAuthType`.

        Args:
            value: The raw type string taken from the configuration, or `None`.

        Returns:
            The matching supported member. `AzureEntraIDOAuthType.INVALID` is
            returned when `value` is `None` or is not a supported type; this
            includes the literal string "invalid".
        """
        is_supported = (
            value is not None
            and value in AzureEntraIDOAuthType.valid_string_values()
        )
        return (
            AzureEntraIDOAuthType(value)
            if is_supported
            else AzureEntraIDOAuthType.INVALID
        )

    @staticmethod
    def valid_string_values() -> Set[str]:
        """Return the string values of all supported (non-INVALID) members."""
        return {member.value for member in AzureEntraIDOAuthType.valid_values()}

    @staticmethod
    def valid_values() -> Set[AzureEntraIDOAuthType]:
        """Return every member except the `INVALID` sentinel."""
        return {
            member
            for member in AzureEntraIDOAuthType
            if member is not AzureEntraIDOAuthType.INVALID
        }
70
+
71
+
72
# BearerTokenProvider is a zero-argument callable that returns a bearer token string.
# NOTE(review): this alias is not referenced anywhere else in this module;
# presumably it is part of the public API used by other modules. Confirm before
# removing it.
BearerTokenProvider = Callable[[], str]
74
+
75
+
76
class AzureEntraIDTokenProviderConfig(abc.ABC):
    """Abstract interface for Azure Entra ID credential configurations.

    Concrete subclasses parse themselves from a plain dictionary via `from_dict`
    and build an azure-core `TokenProvider` via `create_azure_token_provider`.
    """

    @classmethod
    @abc.abstractmethod
    def from_dict(
        cls: Type[AzureEntraIDTokenProviderConfig], config: Dict[str, Any]
    ) -> AzureEntraIDTokenProviderConfig:
        """Build a configuration instance from a raw dictionary.

        Args:
            config: (dict) The raw configuration to parse.

        Returns:
            An instance of the concrete configuration class.
        """

    @abc.abstractmethod
    def create_azure_token_provider(self) -> TokenProvider:
        """Build the Azure Entra ID token provider described by this config."""
98
+
99
+
100
class AzureEntraIDClientCredentialsConfig(AzureEntraIDTokenProviderConfig, BaseModel):
    """Azure Entra ID OAuth client credentials (client secret) configuration.

    Attributes:
        client_id: The client ID.
        client_secret: The client secret, held as a `SecretStr`.
        tenant_id: The tenant ID.
        authority_host: The authority host.
        disable_instance_discovery: Whether to disable instance discovery. This is
            used to disable fetching metadata from the Azure Instance Metadata
            Service.
    """

    client_id: str = Field(min_length=1)
    client_secret: SecretStr = Field(min_length=1)
    tenant_id: str = Field(min_length=1)
    authority_host: Optional[str] = None
    disable_instance_discovery: bool = False

    @staticmethod
    def required_fields() -> Set[str]:
        """Returns the required fields for the configuration."""
        return {AZURE_CLIENT_ID_FIELD, AZURE_TENANT_ID_FIELD, AZURE_CLIENT_SECRET_FIELD}

    @staticmethod
    def config_has_required_fields(config: Dict[str, Any]) -> bool:
        """Check if the configuration has all the required fields."""
        return AzureEntraIDClientCredentialsConfig.required_fields().issubset(
            set(config.keys())
        )

    @classmethod
    def from_dict(cls, config: Dict[str, Any]) -> AzureEntraIDClientCredentialsConfig:
        """Create the configuration from a dictionary.

        The recognised keys are popped from `config`, so the passed dictionary is
        modified in place. Pass a copy if the caller still needs the original.

        Args:
            config: (dict) The config from which to initialize. It must contain
                every key returned by `required_fields()`.

        Returns:
            AzureEntraIDClientCredentialsConfig

        Raises:
            ValueError: If any of the required keys is missing.
        """
        if not cls.config_has_required_fields(config):
            message = (
                f"A configuration for Azure client credentials "
                f"must contain the following keys: {cls.required_fields()}"
            )
            # Never log the raw config: it may contain the client secret (or a
            # certificate password) in plain text.
            secret_keys = {AZURE_CLIENT_SECRET_FIELD, AZURE_CERTIFICATE_PASSWORD_FIELD}
            redacted_config = {
                key: "***" if key in secret_keys else value
                for key, value in config.items()
            }
            structlogger.error(
                "azure_client_credentials_config.missing_required_keys",
                message=message,
                config=redacted_config,
            )
            raise ValueError(message)

        return cls(
            client_id=config.pop(AZURE_CLIENT_ID_FIELD),
            client_secret=config.pop(AZURE_CLIENT_SECRET_FIELD),
            tenant_id=config.pop(AZURE_TENANT_ID_FIELD),
            authority_host=config.pop(AZURE_AUTHORITY_FIELD, None),
            disable_instance_discovery=config.pop(
                AZURE_DISABLE_INSTANCE_DISCOVERY_FIELD, False
            ),
        )

    def create_azure_token_provider(self) -> TokenProvider:
        """Create a ClientSecretCredential for Azure Entra ID.

        The plain secret value is unwrapped from the `SecretStr` only at this
        point, when it is handed to the (cached) credential factory.
        """
        return create_azure_entra_id_client_credentials(
            client_id=self.client_id,
            client_secret=self.client_secret.get_secret_value(),
            tenant_id=self.tenant_id,
            authority_host=self.authority_host,
            disable_instance_discovery=self.disable_instance_discovery,
        )
171
+
172
+
173
# Memoise credential construction: identical arguments yield the very same
# `ClientSecretCredential` object, so whatever tokens that object caches can be
# reused across requests instead of re-authenticating with the client secret.
# `lru_cache` may treat different argument patterns (positional vs keyword) as
# different calls, so callers should pass arguments consistently; the caller in
# this module always uses keywords. Note that the cache keys include the
# plain-text client secret.
@lru_cache
def create_azure_entra_id_client_credentials(
    client_id: str,
    client_secret: str,
    tenant_id: str,
    authority_host: Optional[str] = None,
    disable_instance_discovery: bool = False,
) -> ClientSecretCredential:
    """Build (or fetch from cache) a ClientSecretCredential for Azure Entra ID.

    Caching returns one credential instance per distinct argument set, so the
    token caching and token refreshing functionality of the azure-identity
    library can be used.

    Args:
        client_id: The client ID.
        client_secret: The plain-text client secret.
        tenant_id: The tenant ID.
        authority_host: The authority host, forwarded as `authority`.
        disable_instance_discovery: Whether to disable instance discovery. This is
            used to disable fetching metadata from the Azure Instance Metadata
            Service.

    Returns:
        ClientSecretCredential
    """
    credential_kwargs: Dict[str, Any] = {
        "client_id": client_id,
        "client_secret": client_secret,
        "tenant_id": tenant_id,
        "authority": authority_host,
        "disable_instance_discovery": disable_instance_discovery,
    }
    return ClientSecretCredential(**credential_kwargs)
212
+
213
+
214
class AzureEntraIDClientCertificateConfig(AzureEntraIDTokenProviderConfig, BaseModel):
    """Azure Entra ID OAuth client certificate configuration.

    Attributes:
        client_id: The client ID.
        tenant_id: The tenant ID.
        certificate_path: The path to the certificate file.
        certificate_password: The certificate password, if any, held as a
            `SecretStr`.
        send_certificate_chain: Whether to send the certificate chain.
        authority_host: The authority host.
        disable_instance_discovery: Whether to disable instance discovery. This is
            used to disable fetching metadata from the Azure Instance Metadata
            Service.
    """

    client_id: str = Field(min_length=1)
    tenant_id: str = Field(min_length=1)
    certificate_path: str = Field(min_length=1)
    certificate_password: Optional[SecretStr] = None
    send_certificate_chain: bool = False
    authority_host: Optional[str] = None
    disable_instance_discovery: bool = False

    @staticmethod
    def required_fields() -> Set[str]:
        """Returns the required fields for the configuration."""
        return {
            AZURE_CLIENT_ID_FIELD,
            AZURE_TENANT_ID_FIELD,
            AZURE_CERTIFICATE_PATH_FIELD,
        }

    @staticmethod
    def config_has_required_fields(config: Dict[str, Any]) -> bool:
        """Check if the configuration has all the required fields."""
        return AzureEntraIDClientCertificateConfig.required_fields().issubset(
            set(config.keys())
        )

    @classmethod
    def from_dict(cls, config: Dict[str, Any]) -> AzureEntraIDClientCertificateConfig:
        """Create the configuration from a dictionary.

        The passed `config` is only read; it is not modified.

        Args:
            config: (dict) The config from which to initialize. It must contain
                every key returned by `required_fields()`.

        Returns:
            AzureEntraIDClientCertificateConfig

        Raises:
            ValueError: If any of the required keys is missing.
        """
        if not cls.config_has_required_fields(config):
            message = (
                f"A configuration for Azure client certificate "
                f"must contain "
                f"the following keys: {cls.required_fields()}"
            )
            # Never log the raw config: it may contain the certificate password
            # (or a client secret) in plain text.
            secret_keys = {AZURE_CLIENT_SECRET_FIELD, AZURE_CERTIFICATE_PASSWORD_FIELD}
            redacted_config = {
                key: "***" if key in secret_keys else value
                for key, value in config.items()
            }
            structlogger.error(
                "azure_client_certificate_config.validation_error",
                message=message,
                config=redacted_config,
            )
            raise ValueError(message)

        return cls(
            client_id=config[AZURE_CLIENT_ID_FIELD],
            tenant_id=config[AZURE_TENANT_ID_FIELD],
            certificate_path=config[AZURE_CERTIFICATE_PATH_FIELD],
            certificate_password=config.get(AZURE_CERTIFICATE_PASSWORD_FIELD, None),
            authority_host=config.get(AZURE_AUTHORITY_FIELD, None),
            send_certificate_chain=config.get(
                AZURE_SEND_CERTIFICATE_CHAIN_FIELD, False
            ),
            disable_instance_discovery=config.get(
                AZURE_DISABLE_INSTANCE_DISCOVERY_FIELD, False
            ),
        )

    def create_azure_token_provider(self) -> TokenProvider:
        """Creates a CertificateCredential for Azure Entra ID.

        The password, if set, is unwrapped from its `SecretStr` only here, when it
        is handed to the (cached) credential factory.
        """
        return create_azure_entra_id_certificate_credentials(
            client_id=self.client_id,
            tenant_id=self.tenant_id,
            certificate_path=self.certificate_path,
            password=self.certificate_password.get_secret_value()
            if self.certificate_password
            else None,
            send_certificate_chain=self.send_certificate_chain,
            authority_host=self.authority_host,
            disable_instance_discovery=self.disable_instance_discovery,
        )
302
+
303
+
304
# Memoise credential construction: identical arguments yield the very same
# `CertificateCredential` object, so whatever tokens that object caches can be
# reused across requests instead of re-authenticating with the certificate.
# `lru_cache` may treat different argument patterns (positional vs keyword) as
# different calls; the caller in this module always passes keywords.
@lru_cache
def create_azure_entra_id_certificate_credentials(
    tenant_id: str,
    client_id: str,
    certificate_path: Optional[str] = None,
    password: Optional[str] = None,
    send_certificate_chain: bool = False,
    authority_host: Optional[str] = None,
    disable_instance_discovery: bool = False,
) -> CertificateCredential:
    """Build (or fetch from cache) a CertificateCredential for Azure Entra ID.

    Caching returns one credential instance per distinct argument set, so the
    token caching and token refreshing functionality of the azure-identity
    library can be used.

    Args:
        tenant_id: The tenant ID.
        client_id: The client ID.
        certificate_path: The path to the certificate file.
        password: The certificate password as text. It is UTF-8 encoded before
            being passed on; an empty or missing password is passed as `None`.
        send_certificate_chain: Whether to send the certificate chain.
        authority_host: The authority host, forwarded as `authority`.
        disable_instance_discovery: Whether to disable instance discovery. This is
            used to disable fetching metadata from the Azure Instance Metadata
            Service.

    Returns:
        CertificateCredential
    """
    encoded_password = password.encode("utf-8") if password else None
    return CertificateCredential(
        tenant_id=tenant_id,
        client_id=client_id,
        certificate_path=certificate_path,
        password=encoded_password,
        send_certificate_chain=send_certificate_chain,
        authority=authority_host,
        disable_instance_discovery=disable_instance_discovery,
    )
349
+
350
+
351
class AzureEntraIDDefaultCredentialsConfig(AzureEntraIDTokenProviderConfig, BaseModel):
    """Configuration for authenticating via Azure's `DefaultAzureCredential`.

    Attributes:
        authority_host: Optional authority host forwarded to the credential.
    """

    authority_host: Optional[str] = None

    @classmethod
    def from_dict(cls, config: Dict[str, Any]) -> AzureEntraIDDefaultCredentialsConfig:
        """Build the configuration from a raw dictionary.

        The authority host key, if present, is popped from `config`; all other
        keys are ignored.

        Args:
            config: (dict) The raw configuration.

        Returns:
            AzureEntraIDDefaultCredentialsConfig
        """
        authority = config.pop(AZURE_AUTHORITY_FIELD, None)
        return cls(authority_host=authority)

    def create_azure_token_provider(self) -> TokenProvider:
        """Return the (cached) DefaultAzureCredential for this authority host."""
        host = self.authority_host
        return create_azure_entra_id_default_credentials(authority_host=host)
377
+
378
+
379
@lru_cache
def create_azure_entra_id_default_credentials(
    authority_host: Optional[str] = None,
) -> DefaultAzureCredential:
    """Build (or fetch from cache) a DefaultAzureCredential.

    The result is memoised per `authority_host`, so repeated calls share one
    credential instance and the token caching functionality of the
    azure-identity library can be used.

    Args:
        authority_host: The authority host, forwarded as `authority`.

    Returns:
        DefaultAzureCredential
    """
    credential = DefaultAzureCredential(authority=authority_host)
    return credential
396
+
397
+
398
class AzureEntraIDOAuthConfig(OAuth, BaseModel):
    """Azure Entra ID OAuth configuration.

    It consists of the scopes for which access tokens are requested and the
    configuration of the Azure Entra ID token provider used to obtain them.
    """

    # pydantic configuration to allow arbitrary user defined types. This is needed
    # because `AzureEntraIDTokenProviderConfig` is a plain `abc.ABC` interface,
    # not a pydantic model.
    class Config:
        arbitrary_types_allowed = True

    scopes: List[str]
    azure_entra_id_token_provider_config: AzureEntraIDTokenProviderConfig

    @staticmethod
    def _supported_azure_oauth() -> (
        Dict[AzureEntraIDOAuthType, Type[AzureEntraIDTokenProviderConfig]]
    ):
        """Map each supported Azure Entra ID OAuth type to its config class."""
        return {
            AzureEntraIDOAuthType.AZURE_ENTRA_ID_DEFAULT: AzureEntraIDDefaultCredentialsConfig,  # noqa: E501
            AzureEntraIDOAuthType.AZURE_ENTRA_ID_CLIENT_SECRET: AzureEntraIDClientCredentialsConfig,  # noqa: E501
            AzureEntraIDOAuthType.AZURE_ENTRA_ID_CLIENT_CERTIFICATE: AzureEntraIDClientCertificateConfig,  # noqa: E501
        }

    @staticmethod
    def _get_azure_oauth_by_type(
        oauth_type: AzureEntraIDOAuthType,
    ) -> Type[AzureEntraIDTokenProviderConfig]:
        """Returns the Azure Entra ID OAuth class based on the type.

        Args:
            oauth_type: (AzureEntraIDOAuthType) The type of the Azure Entra ID OAuth.

        Returns:
            The Azure Entra ID OAuth configuration class.

        Raises:
            ValueError: If the passed oauth_type is not supported or invalid.
        """
        azure_oauth_types = AzureEntraIDOAuthConfig._supported_azure_oauth()
        azure_oauth_class = azure_oauth_types.get(oauth_type)

        if azure_oauth_class is None:
            message = (
                f"Unsupported Azure Entra ID oauth type: {oauth_type}. "
                f"Supported types are: {AzureEntraIDOAuthType.valid_string_values()}"
            )
            structlogger.error(
                "azure_oauth_config.unsupported_azure_oauth_type",
                message=message,
            )
            raise ValueError(message)

        return azure_oauth_class

    @classmethod
    def from_dict(cls, oauth_config: Dict[str, Any]) -> AzureEntraIDOAuthConfig:
        """Create the configuration from a dictionary.

        Args:
            oauth_config: (dict) The config from which to initialize. It must
                contain the scopes, the OAuth type and the fields required by
                that type. It is not modified.

        Returns:
            AzureEntraIDOAuthConfig

        Raises:
            ValueError: If the scopes are missing or empty, the OAuth type is
                missing or unsupported, or required credential fields are missing.
        """
        # Work on a copy: the helpers below consume (pop) keys from the dict.
        config = deepcopy(oauth_config)

        scopes = AzureEntraIDOAuthConfig._read_scopes_from_config(config)
        azure_credentials = (
            AzureEntraIDOAuthConfig._create_azure_entra_id_client_from_config(config)
        )
        return cls(
            azure_entra_id_token_provider_config=azure_credentials, scopes=scopes
        )

    @staticmethod
    def _read_scopes_from_config(oauth_config: Dict[str, Any]) -> List[str]:
        """Reads scopes from the configuration.

        The original scopes are removed from the configuration.

        Args:
            oauth_config: (dict) The configuration from which to read the scopes.

        Returns:
            List[str]: The list of scopes. A single string is wrapped in a list.

        Raises:
            ValueError: If the scopes are missing or empty.
        """
        scopes = oauth_config.pop(AZURE_SCOPES_FIELD, "")

        if not scopes:
            message = "Azure Entra ID scopes cannot be empty."
            structlogger.error(
                "azure_oauth_config.scopes_empty",
                message=message,
            )
            raise ValueError(message)

        if isinstance(scopes, str):
            scopes = [scopes]

        return scopes

    @staticmethod
    def _create_azure_entra_id_client_from_config(
        oauth_config: Dict[str, Any],
    ) -> AzureEntraIDTokenProviderConfig:
        """Creates an Azure Entra ID token provider config from the configuration.

        The OAuth type key is popped from `oauth_config`. The remaining keys are
        handed to the `from_dict` of the config class selected by that type.

        Args:
            oauth_config: (dict) The configuration from which to create the credential.

        Returns:
            AzureEntraIDTokenProviderConfig: The Azure OAuth credential config.

        Raises:
            ValueError: If the OAuth type is missing or not supported, or the
                selected config class rejects the remaining configuration.
        """
        oauth_type = AzureEntraIDOAuthType.from_string(
            oauth_config.pop(OAUTH_TYPE_FIELD, None)
        )

        if oauth_type == AzureEntraIDOAuthType.INVALID:
            message = (
                "Azure Entra ID oauth configuration must contain "
                f"'{OAUTH_TYPE_FIELD}' field and it must be set to one of the "
                f"following values: {AzureEntraIDOAuthType.valid_string_values()}."
            )
            structlogger.error(
                "azure_oauth_config.missing_azure_oauth_type",
                message=message,
            )
            raise ValueError(message)

        azure_oauth_class = AzureEntraIDOAuthConfig._get_azure_oauth_by_type(oauth_type)
        return azure_oauth_class.from_dict(oauth_config)

    def _create_azure_credential(
        self,
    ) -> TokenProvider:
        """Create an Azure Entra ID client which can be used to get a bearer token.

        The module-level credential factories are wrapped in `lru_cache`, so
        repeated calls with the same configuration return the same credential.
        """
        return self.azure_entra_id_token_provider_config.create_azure_token_provider()

    def get_bearer_token(self) -> str:
        """Request an access token for `self.scopes` and return its token string."""
        return self._create_azure_credential().get_token(*self.scopes).token  # type: ignore