rasa-pro 3.12.0.dev13__py3-none-any.whl → 3.12.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +10 -13
- rasa/anonymization/anonymization_rule_executor.py +16 -10
- rasa/cli/data.py +16 -0
- rasa/cli/project_templates/calm/config.yml +2 -2
- rasa/cli/project_templates/calm/domain/list_contacts.yml +1 -2
- rasa/cli/project_templates/calm/domain/remove_contact.yml +1 -2
- rasa/cli/project_templates/calm/domain/shared.yml +1 -4
- rasa/cli/project_templates/calm/endpoints.yml +2 -2
- rasa/cli/utils.py +12 -0
- rasa/core/actions/action.py +84 -191
- rasa/core/actions/action_handle_digressions.py +35 -13
- rasa/core/actions/action_run_slot_rejections.py +16 -4
- rasa/core/channels/__init__.py +2 -0
- rasa/core/channels/studio_chat.py +19 -0
- rasa/core/channels/telegram.py +42 -24
- rasa/core/channels/voice_ready/utils.py +1 -1
- rasa/core/channels/voice_stream/asr/asr_engine.py +10 -4
- rasa/core/channels/voice_stream/asr/azure.py +14 -1
- rasa/core/channels/voice_stream/asr/deepgram.py +20 -4
- rasa/core/channels/voice_stream/audiocodes.py +264 -0
- rasa/core/channels/voice_stream/browser_audio.py +4 -1
- rasa/core/channels/voice_stream/call_state.py +3 -0
- rasa/core/channels/voice_stream/genesys.py +6 -2
- rasa/core/channels/voice_stream/tts/azure.py +9 -1
- rasa/core/channels/voice_stream/tts/cartesia.py +14 -8
- rasa/core/channels/voice_stream/voice_channel.py +23 -2
- rasa/core/constants.py +2 -0
- rasa/core/nlg/contextual_response_rephraser.py +18 -1
- rasa/core/nlg/generator.py +83 -15
- rasa/core/nlg/response.py +6 -3
- rasa/core/nlg/translate.py +55 -0
- rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +1 -1
- rasa/core/policies/flows/flow_executor.py +19 -7
- rasa/core/processor.py +71 -9
- rasa/dialogue_understanding/commands/can_not_handle_command.py +20 -2
- rasa/dialogue_understanding/commands/cancel_flow_command.py +24 -6
- rasa/dialogue_understanding/commands/change_flow_command.py +20 -2
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +20 -2
- rasa/dialogue_understanding/commands/clarify_command.py +29 -3
- rasa/dialogue_understanding/commands/command.py +1 -16
- rasa/dialogue_understanding/commands/command_syntax_manager.py +55 -0
- rasa/dialogue_understanding/commands/handle_digressions_command.py +1 -7
- rasa/dialogue_understanding/commands/human_handoff_command.py +20 -2
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +20 -2
- rasa/dialogue_understanding/commands/prompt_command.py +94 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +20 -2
- rasa/dialogue_understanding/commands/set_slot_command.py +24 -2
- rasa/dialogue_understanding/commands/skip_question_command.py +20 -2
- rasa/dialogue_understanding/commands/start_flow_command.py +22 -2
- rasa/dialogue_understanding/commands/utils.py +71 -4
- rasa/dialogue_understanding/generator/__init__.py +2 -0
- rasa/dialogue_understanding/generator/command_parser.py +15 -12
- rasa/dialogue_understanding/generator/constants.py +3 -0
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -5
- rasa/dialogue_understanding/generator/llm_command_generator.py +5 -3
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +17 -3
- rasa/dialogue_understanding/generator/prompt_templates/__init__.py +0 -0
- rasa/dialogue_understanding/generator/{single_step → prompt_templates}/command_prompt_template.jinja2 +2 -0
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +77 -0
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_default.jinja2 +68 -0
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +84 -0
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +522 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +12 -310
- rasa/dialogue_understanding/patterns/collect_information.py +1 -1
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +16 -0
- rasa/dialogue_understanding/patterns/validate_slot.py +65 -0
- rasa/dialogue_understanding/processor/command_processor.py +39 -0
- rasa/dialogue_understanding/stack/utils.py +38 -0
- rasa/dialogue_understanding_test/du_test_case.py +58 -18
- rasa/dialogue_understanding_test/du_test_result.py +14 -10
- rasa/dialogue_understanding_test/io.py +14 -0
- rasa/e2e_test/assertions.py +6 -8
- rasa/e2e_test/llm_judge_prompts/answer_relevance_prompt_template.jinja2 +5 -1
- rasa/e2e_test/llm_judge_prompts/groundedness_prompt_template.jinja2 +4 -0
- rasa/e2e_test/utils/io.py +0 -37
- rasa/engine/graph.py +1 -0
- rasa/engine/language.py +140 -0
- rasa/engine/recipes/config_files/default_config.yml +4 -0
- rasa/engine/recipes/default_recipe.py +2 -0
- rasa/engine/recipes/graph_recipe.py +2 -0
- rasa/engine/storage/local_model_storage.py +1 -0
- rasa/engine/storage/storage.py +4 -1
- rasa/llm_fine_tuning/conversations.py +1 -1
- rasa/model_manager/runner_service.py +7 -4
- rasa/model_manager/socket_bridge.py +7 -6
- rasa/shared/constants.py +15 -13
- rasa/shared/core/constants.py +2 -0
- rasa/shared/core/flows/constants.py +11 -0
- rasa/shared/core/flows/flow.py +83 -19
- rasa/shared/core/flows/flows_yaml_schema.json +31 -3
- rasa/shared/core/flows/steps/collect.py +1 -36
- rasa/shared/core/flows/utils.py +28 -4
- rasa/shared/core/flows/validation.py +1 -1
- rasa/shared/core/slot_mappings.py +208 -5
- rasa/shared/core/slots.py +137 -1
- rasa/shared/core/trackers.py +74 -1
- rasa/shared/importers/importer.py +50 -2
- rasa/shared/nlu/training_data/schemas/responses.yml +19 -12
- rasa/shared/providers/_configs/azure_entra_id_config.py +541 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +138 -3
- rasa/shared/providers/_configs/client_config.py +3 -1
- rasa/shared/providers/_configs/default_litellm_client_config.py +3 -1
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +3 -1
- rasa/shared/providers/_configs/litellm_router_client_config.py +3 -1
- rasa/shared/providers/_configs/model_group_config.py +4 -2
- rasa/shared/providers/_configs/oauth_config.py +33 -0
- rasa/shared/providers/_configs/openai_client_config.py +3 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +3 -1
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +3 -1
- rasa/shared/providers/constants.py +6 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +28 -3
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +3 -1
- rasa/shared/providers/llm/_base_litellm_client.py +42 -17
- rasa/shared/providers/llm/azure_openai_llm_client.py +81 -25
- rasa/shared/providers/llm/default_litellm_llm_client.py +3 -1
- rasa/shared/providers/llm/litellm_router_llm_client.py +29 -8
- rasa/shared/providers/llm/llm_client.py +23 -7
- rasa/shared/providers/llm/openai_llm_client.py +9 -3
- rasa/shared/providers/llm/rasa_llm_client.py +11 -2
- rasa/shared/providers/llm/self_hosted_llm_client.py +30 -11
- rasa/shared/providers/router/_base_litellm_router_client.py +3 -1
- rasa/shared/providers/router/router_client.py +3 -1
- rasa/shared/utils/constants.py +3 -0
- rasa/shared/utils/llm.py +33 -7
- rasa/shared/utils/pykwalify_extensions.py +24 -0
- rasa/shared/utils/schemas/domain.yml +26 -0
- rasa/telemetry.py +2 -1
- rasa/tracing/config.py +2 -0
- rasa/tracing/constants.py +12 -0
- rasa/tracing/instrumentation/instrumentation.py +36 -0
- rasa/tracing/instrumentation/metrics.py +41 -0
- rasa/tracing/metric_instrument_provider.py +40 -0
- rasa/validator.py +372 -7
- rasa/version.py +1 -1
- {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc2.dist-info}/METADATA +13 -14
- {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc2.dist-info}/RECORD +139 -124
- {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc2.dist-info}/entry_points.txt +0 -0
|
@@ -20,16 +20,19 @@ import rasa.shared.constants
|
|
|
20
20
|
import rasa.shared.core.constants
|
|
21
21
|
import rasa.shared.utils.common
|
|
22
22
|
import rasa.shared.utils.io
|
|
23
|
+
from rasa.shared.constants import CONFIG_ADDITIONAL_LANGUAGES_KEY, CONFIG_LANGUAGE_KEY
|
|
23
24
|
from rasa.shared.core.domain import (
|
|
24
25
|
IS_RETRIEVAL_INTENT_KEY,
|
|
25
26
|
KEY_ACTIONS,
|
|
26
27
|
KEY_E2E_ACTIONS,
|
|
27
28
|
KEY_INTENTS,
|
|
28
29
|
KEY_RESPONSES,
|
|
30
|
+
KEY_SLOTS,
|
|
29
31
|
Domain,
|
|
30
32
|
)
|
|
31
33
|
from rasa.shared.core.events import ActionExecuted, UserUttered
|
|
32
34
|
from rasa.shared.core.flows import FlowsList
|
|
35
|
+
from rasa.shared.core.slots import StrictCategoricalSlot
|
|
33
36
|
from rasa.shared.core.training_data.structures import StoryGraph
|
|
34
37
|
from rasa.shared.nlu.constants import ACTION_NAME, ENTITIES
|
|
35
38
|
from rasa.shared.nlu.training_data.message import Message
|
|
@@ -202,8 +205,10 @@ class TrainingDataImporter(ABC):
|
|
|
202
205
|
)
|
|
203
206
|
]
|
|
204
207
|
|
|
205
|
-
return
|
|
206
|
-
|
|
208
|
+
return LanguageImporter(
|
|
209
|
+
E2EImporter(
|
|
210
|
+
FlowSyncImporter(ResponsesSyncImporter(CombinedDataImporter(importers)))
|
|
211
|
+
)
|
|
207
212
|
)
|
|
208
213
|
|
|
209
214
|
@staticmethod
|
|
@@ -522,6 +527,49 @@ class FlowSyncImporter(PassThroughImporter):
|
|
|
522
527
|
return self._importer.get_domain()
|
|
523
528
|
|
|
524
529
|
|
|
530
|
+
class LanguageImporter(PassThroughImporter):
    """Importer that adds the built-in language slot to the domain.

    The slot's allowed values are the configured additional languages,
    followed by the primary language when it is not already among them.
    Its initial value is the primary language from the config.
    """

    @cached_method
    def get_domain(self) -> Domain:
        """Return the wrapped importer's domain merged with the language slot.

        An empty domain is handed back untouched.
        """
        base_domain = self._importer.get_domain()
        if base_domain.is_empty():
            return base_domain

        settings = self._importer.get_config()
        primary_language = settings.get(CONFIG_LANGUAGE_KEY)
        # Copy so the list stored in the config is never mutated by the
        # append below.
        allowed_languages = (settings.get(CONFIG_ADDITIONAL_LANGUAGES_KEY) or []).copy()
        if primary_language and primary_language not in allowed_languages:
            allowed_languages.append(primary_language)

        # Serialized form of the built-in language slot, in the same shape a
        # user-written domain file would use.
        language_slot_definition: Dict[Text, Any] = {
            "type": StrictCategoricalSlot.type_name,
            "initial_value": primary_language,
            "values": allowed_languages,
            "mappings": [],
            "is_builtin": True,
        }
        slot_only_domain = Domain.from_dict(
            {
                KEY_SLOTS: {
                    rasa.shared.core.constants.LANGUAGE_SLOT: language_slot_definition
                }
            }
        )
        return base_domain.merge(slot_only_domain)

    @cached_method
    def get_user_domain(self) -> Domain:
        """Return the wrapped importer's user domain unchanged."""
        return self._importer.get_user_domain()

    @cached_method
    def get_user_flows(self) -> FlowsList:
        """Return the wrapped importer's user flows unchanged."""
        return self._importer.get_user_flows()
|
|
571
|
+
|
|
572
|
+
|
|
525
573
|
class ResponsesSyncImporter(PassThroughImporter):
|
|
526
574
|
"""Importer that syncs `responses` between Domain and NLU training data.
|
|
527
575
|
|
|
@@ -17,6 +17,12 @@ schema;responses:
|
|
|
17
17
|
required: False
|
|
18
18
|
text:
|
|
19
19
|
type: "str"
|
|
20
|
+
translation:
|
|
21
|
+
type: "map"
|
|
22
|
+
allowempty: True
|
|
23
|
+
mapping:
|
|
24
|
+
regex;(.*):
|
|
25
|
+
type: "str"
|
|
20
26
|
image:
|
|
21
27
|
type: "str"
|
|
22
28
|
custom:
|
|
@@ -32,6 +38,18 @@ schema;responses:
|
|
|
32
38
|
type: "str"
|
|
33
39
|
payload:
|
|
34
40
|
type: "str"
|
|
41
|
+
translation:
|
|
42
|
+
type: "map"
|
|
43
|
+
allowempty: True
|
|
44
|
+
mapping:
|
|
45
|
+
regex;(.*):
|
|
46
|
+
type: "map"
|
|
47
|
+
allowempty: True
|
|
48
|
+
mapping:
|
|
49
|
+
title:
|
|
50
|
+
type: "str"
|
|
51
|
+
payload:
|
|
52
|
+
type: "str"
|
|
35
53
|
button_type:
|
|
36
54
|
type: "str"
|
|
37
55
|
quick_replies:
|
|
@@ -57,15 +75,4 @@ schema;responses:
|
|
|
57
75
|
metadata:
|
|
58
76
|
type: "any"
|
|
59
77
|
condition:
|
|
60
|
-
type: "
|
|
61
|
-
sequence:
|
|
62
|
-
- type: "map"
|
|
63
|
-
allowempty: False
|
|
64
|
-
mapping:
|
|
65
|
-
type:
|
|
66
|
-
type: "str"
|
|
67
|
-
enum: ['slot']
|
|
68
|
-
name:
|
|
69
|
-
type: "str"
|
|
70
|
-
value:
|
|
71
|
-
type: "any"
|
|
78
|
+
type: "any"
|
|
@@ -0,0 +1,541 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import abc
|
|
4
|
+
import logging
|
|
5
|
+
from copy import deepcopy
|
|
6
|
+
from enum import Enum
|
|
7
|
+
from functools import lru_cache
|
|
8
|
+
from typing import Any, Callable, Dict, List, Optional, Set, Type
|
|
9
|
+
|
|
10
|
+
import structlog
|
|
11
|
+
from azure.core.credentials import TokenProvider
|
|
12
|
+
from azure.identity import (
|
|
13
|
+
CertificateCredential,
|
|
14
|
+
ClientSecretCredential,
|
|
15
|
+
DefaultAzureCredential,
|
|
16
|
+
)
|
|
17
|
+
from pydantic import BaseModel, Field, SecretStr
|
|
18
|
+
|
|
19
|
+
from rasa.shared.providers._configs.oauth_config import OAUTH_TYPE_FIELD, OAuth
|
|
20
|
+
|
|
21
|
+
# Keys recognised in an Azure Entra ID OAuth configuration dictionary.

# Application / tenant identity.
AZURE_CLIENT_ID_FIELD = "client_id"
AZURE_TENANT_ID_FIELD = "tenant_id"

# Client-secret authentication.
AZURE_CLIENT_SECRET_FIELD = "client_secret"

# Client-certificate authentication.
AZURE_CERTIFICATE_PATH_FIELD = "certificate_path"
AZURE_CERTIFICATE_PASSWORD_FIELD = "certificate_password"
AZURE_SEND_CERTIFICATE_CHAIN_FIELD = "send_certificate_chain"

# Options shared by all credential types.
AZURE_SCOPES_FIELD = "scopes"
AZURE_AUTHORITY_FIELD = "authority_host"
AZURE_DISABLE_INSTANCE_DISCOVERY_FIELD = "disable_instance_discovery"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# Stdlib logger for the "azure" namespace (presumably the one the Azure SDK
# packages log under).
azure_logger = logging.getLogger(name="azure")
# NOTE(review): this forces the whole "azure" logger hierarchy to DEBUG at
# import time, overriding whatever the application configured - confirm
# this is intended outside of debugging.
azure_logger.setLevel(level=logging.DEBUG)

# Structured logger for this module's own events.
structlogger = structlog.get_logger()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class AzureEntraIDOAuthType(str, Enum):
    """Azure Entra ID OAuth credential types.

    The string value of each member is what users set in the OAuth type field
    of their configuration.
    """

    AZURE_ENTRA_ID_DEFAULT = "azure_entra_id_default"
    AZURE_ENTRA_ID_CLIENT_SECRET = "azure_entra_id_client_secret"
    AZURE_ENTRA_ID_CLIENT_CERTIFICATE = "azure_entra_id_client_certificate"

    # Sentinel meaning the configured type is missing or is not one of the
    # supported Entra ID types above.
    INVALID = "invalid"

    @staticmethod
    def from_string(value: Any) -> AzureEntraIDOAuthType:
        """Converts a raw configuration value to an AzureEntraIDOAuthType.

        Args:
            value: The raw value of the OAuth type field (normally a string).

        Returns:
            The matching member, or ``INVALID`` when ``value`` is ``None``, is
            not a string, or is not one of ``valid_string_values()``.
        """
        # The isinstance check runs first so that unhashable values (e.g. a
        # list coming from a malformed YAML config) map to INVALID instead of
        # raising TypeError in the set-membership test.
        if (
            not isinstance(value, str)
            or value not in AzureEntraIDOAuthType.valid_string_values()
        ):
            return AzureEntraIDOAuthType.INVALID

        return AzureEntraIDOAuthType(value)

    @staticmethod
    def valid_string_values() -> Set[str]:
        """Returns the string values of all supported (non-INVALID) types."""
        return {e.value for e in AzureEntraIDOAuthType.valid_values()}

    @staticmethod
    def valid_values() -> Set[AzureEntraIDOAuthType]:
        """Returns all supported types, i.e. every member except INVALID."""
        return {
            AzureEntraIDOAuthType.AZURE_ENTRA_ID_DEFAULT,
            AzureEntraIDOAuthType.AZURE_ENTRA_ID_CLIENT_SECRET,
            AzureEntraIDOAuthType.AZURE_ENTRA_ID_CLIENT_CERTIFICATE,
        }
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# Type alias: a zero-argument callable that produces a bearer token string.
BearerTokenProvider = Callable[[], str]
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class AzureEntraIDTokenProviderConfig(abc.ABC):
    """Abstract base for Azure Entra ID credential configurations.

    Concrete subclasses know how to build themselves from a configuration
    dictionary and how to turn themselves into an Azure token provider.
    """

    @abc.abstractmethod
    def create_azure_token_provider(self) -> TokenProvider:
        """Builds the Azure token provider described by this configuration."""
        ...

    @classmethod
    @abc.abstractmethod
    def from_dict(
        cls: Type[AzureEntraIDTokenProviderConfig], config: Dict[str, Any]
    ) -> AzureEntraIDTokenProviderConfig:
        """Creates a configuration instance from a dictionary.

        Args:
            config: (dict) The configuration values to read from.

        Returns:
            An instance of the concrete configuration class.
        """
        ...
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class AzureEntraIDClientCredentialsConfig(AzureEntraIDTokenProviderConfig, BaseModel):
    """Azure Entra ID OAuth client credentials (client secret) configuration.

    Attributes:
        client_id: The client ID.
        client_secret: The client secret, stored as a pydantic ``SecretStr``.
        tenant_id: The tenant ID.
        authority_host: The authority host.
        disable_instance_discovery: Whether to disable instance discovery. This is used
            to disable fetching metadata from the Azure Instance Metadata Service.
    """

    client_id: str = Field(min_length=1)
    client_secret: SecretStr = Field(min_length=1)
    tenant_id: str = Field(min_length=1)
    authority_host: Optional[str] = None
    disable_instance_discovery: bool = False

    @staticmethod
    def required_fields() -> Set[str]:
        """Returns the keys that must be present in the configuration."""
        return {AZURE_CLIENT_ID_FIELD, AZURE_TENANT_ID_FIELD, AZURE_CLIENT_SECRET_FIELD}

    @staticmethod
    def config_has_required_fields(config: Dict[str, Any]) -> bool:
        """Check if the configuration has all the required fields."""
        return AzureEntraIDClientCredentialsConfig.required_fields().issubset(
            set(config.keys())
        )

    @classmethod
    def from_dict(cls, config: Dict[str, Any]) -> AzureEntraIDClientCredentialsConfig:
        """Initializes the configuration from the passed config.

        The passed dictionary is only read, never modified.

        Args:
            config: (dict) The config from which to initialize.

        Returns:
            AzureEntraIDClientCredentialsConfig

        Raises:
            ValueError: If any of the required keys is missing.
        """
        if not cls.config_has_required_fields(config):
            message = (
                "A configuration for Azure client credentials "
                f"must contain the following keys: {cls.required_fields()}"
            )
            # Log key names only: the config carries the client secret, which
            # must never be written to the logs.
            structlogger.error(
                "azure_client_credentials_config.missing_required_keys",
                message=message,
                missing_keys=sorted(cls.required_fields() - set(config.keys())),
                config_keys=list(config.keys()),
            )
            raise ValueError(message)

        return cls(
            client_id=config[AZURE_CLIENT_ID_FIELD],
            client_secret=config[AZURE_CLIENT_SECRET_FIELD],
            tenant_id=config[AZURE_TENANT_ID_FIELD],
            authority_host=config.get(AZURE_AUTHORITY_FIELD, None),
            disable_instance_discovery=config.get(
                AZURE_DISABLE_INSTANCE_DISCOVERY_FIELD, False
            ),
        )

    def create_azure_token_provider(self) -> TokenProvider:
        """Create a ClientSecretCredential for Azure Entra ID."""
        return create_azure_entra_id_client_credentials(
            client_id=self.client_id,
            client_secret=self.client_secret.get_secret_value(),
            tenant_id=self.tenant_id,
            authority_host=self.authority_host,
            disable_instance_discovery=self.disable_instance_discovery,
        )
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
# Memoised on purpose: repeated calls with the same arguments hand back the
# *same* credential object. That object keeps its token state internally, so
# reusing one instance lets later requests reuse/refresh tokens instead of
# re-authenticating (re-sending the client secret) every time.
@lru_cache
def create_azure_entra_id_client_credentials(
    client_id: str,
    client_secret: str,
    tenant_id: str,
    authority_host: Optional[str] = None,
    disable_instance_discovery: bool = False,
) -> ClientSecretCredential:
    """Builds (or returns the cached) ClientSecretCredential for Azure Entra ID.

    Args:
        client_id: The client ID.
        client_secret: The client secret.
        tenant_id: The tenant ID.
        authority_host: The authority host; forwarded as ``authority``.
        disable_instance_discovery: Whether to disable instance discovery, i.e.
            skip fetching metadata from the Azure Instance Metadata Service.

    Returns:
        ClientSecretCredential
    """
    credential_kwargs: Dict[str, Any] = {
        "client_id": client_id,
        "client_secret": client_secret,
        "tenant_id": tenant_id,
        "authority": authority_host,
        "disable_instance_discovery": disable_instance_discovery,
    }
    return ClientSecretCredential(**credential_kwargs)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
class AzureEntraIDClientCertificateConfig(AzureEntraIDTokenProviderConfig, BaseModel):
    """Azure Entra ID OAuth client certificate configuration.

    Attributes:
        client_id: The client ID.
        tenant_id: The tenant ID.
        certificate_path: The path to the certificate file.
        certificate_password: The certificate password, stored as a pydantic
            ``SecretStr``.
        send_certificate_chain: Whether to send the certificate chain.
        authority_host: The authority host.
        disable_instance_discovery: Whether to disable instance discovery. This is used
            to disable fetching metadata from the Azure Instance Metadata Service.
    """

    client_id: str = Field(min_length=1)
    tenant_id: str = Field(min_length=1)
    certificate_path: str = Field(min_length=1)
    certificate_password: Optional[SecretStr] = None
    send_certificate_chain: bool = False
    authority_host: Optional[str] = None
    disable_instance_discovery: bool = False

    @staticmethod
    def required_fields() -> Set[str]:
        """Returns the keys that must be present in the configuration."""
        return {
            AZURE_CLIENT_ID_FIELD,
            AZURE_TENANT_ID_FIELD,
            AZURE_CERTIFICATE_PATH_FIELD,
        }

    @staticmethod
    def config_has_required_fields(config: Dict[str, Any]) -> bool:
        """Check if the configuration has all the required fields."""
        return AzureEntraIDClientCertificateConfig.required_fields().issubset(
            set(config.keys())
        )

    @classmethod
    def from_dict(cls, config: Dict[str, Any]) -> AzureEntraIDClientCertificateConfig:
        """Initializes the configuration from the passed config.

        The passed dictionary is only read, never modified.

        Args:
            config: (dict) The config from which to initialize.

        Returns:
            AzureEntraIDClientCertificateConfig

        Raises:
            ValueError: If any of the required keys is missing.
        """
        if not cls.config_has_required_fields(config):
            message = (
                "A configuration for Azure client certificate "
                "must contain "
                f"the following keys: {cls.required_fields()}"
            )
            # Log key names only: the config may carry the certificate
            # password, which must never be written to the logs.
            structlogger.error(
                "azure_client_certificate_config.validation_error",
                message=message,
                missing_keys=sorted(cls.required_fields() - set(config.keys())),
                config_keys=list(config.keys()),
            )
            raise ValueError(message)

        return cls(
            client_id=config[AZURE_CLIENT_ID_FIELD],
            tenant_id=config[AZURE_TENANT_ID_FIELD],
            certificate_path=config[AZURE_CERTIFICATE_PATH_FIELD],
            certificate_password=config.get(AZURE_CERTIFICATE_PASSWORD_FIELD, None),
            authority_host=config.get(AZURE_AUTHORITY_FIELD, None),
            send_certificate_chain=config.get(
                AZURE_SEND_CERTIFICATE_CHAIN_FIELD, False
            ),
            disable_instance_discovery=config.get(
                AZURE_DISABLE_INSTANCE_DISCOVERY_FIELD, False
            ),
        )

    def create_azure_token_provider(self) -> TokenProvider:
        """Creates a CertificateCredential for Azure Entra ID."""
        return create_azure_entra_id_certificate_credentials(
            client_id=self.client_id,
            tenant_id=self.tenant_id,
            certificate_path=self.certificate_path,
            password=self.certificate_password.get_secret_value()
            if self.certificate_password
            else None,
            send_certificate_chain=self.send_certificate_chain,
            authority_host=self.authority_host,
            disable_instance_discovery=self.disable_instance_discovery,
        )
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
# Memoised on purpose: repeated calls with the same arguments hand back the
# *same* credential object. That object keeps its token state internally, so
# reusing one instance lets later requests reuse/refresh tokens instead of
# re-authenticating (re-presenting the client certificate) every time.
@lru_cache
def create_azure_entra_id_certificate_credentials(
    tenant_id: str,
    client_id: str,
    certificate_path: Optional[str] = None,
    password: Optional[str] = None,
    send_certificate_chain: bool = False,
    authority_host: Optional[str] = None,
    disable_instance_discovery: bool = False,
) -> CertificateCredential:
    """Builds (or returns the cached) CertificateCredential for Azure Entra ID.

    Args:
        tenant_id: The tenant ID.
        client_id: The client ID.
        certificate_path: The path to the certificate file.
        password: The certificate password; an empty or missing password is
            passed on as ``None``, otherwise it is passed UTF-8 encoded.
        send_certificate_chain: Whether to send the certificate chain.
        authority_host: The authority host; forwarded as ``authority``.
        disable_instance_discovery: Whether to disable instance discovery, i.e.
            skip fetching metadata from the Azure Instance Metadata Service.

    Returns:
        CertificateCredential
    """
    encoded_password = None if not password else password.encode("utf-8")

    return CertificateCredential(
        tenant_id=tenant_id,
        client_id=client_id,
        certificate_path=certificate_path,
        password=encoded_password,
        send_certificate_chain=send_certificate_chain,
        authority=authority_host,
        disable_instance_discovery=disable_instance_discovery,
    )
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
class AzureEntraIDDefaultCredentialsConfig(AzureEntraIDTokenProviderConfig, BaseModel):
    """Azure Entra ID OAuth default credentials configuration.

    Attributes:
        authority_host: The authority host.
    """

    authority_host: Optional[str] = None

    @classmethod
    def from_dict(cls, config: Dict[str, Any]) -> AzureEntraIDDefaultCredentialsConfig:
        """Initializes the configuration from the passed config.

        The passed dictionary is only read, never modified (consistent with
        the other Entra ID credential configurations).

        Args:
            config: (dict) The config from which to initialize.

        Returns:
            AzureEntraIDDefaultCredentialsConfig
        """
        return cls(authority_host=config.get(AZURE_AUTHORITY_FIELD, None))

    def create_azure_token_provider(self) -> TokenProvider:
        """Creates a DefaultAzureCredential."""
        return create_azure_entra_id_default_credentials(
            authority_host=self.authority_host
        )
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
@lru_cache
def create_azure_entra_id_default_credentials(
    authority_host: Optional[str] = None,
) -> DefaultAzureCredential:
    """Builds (or returns the cached) DefaultAzureCredential.

    The result is memoised so that callers share one credential instance and
    can benefit from the token caching done inside azure-identity.

    Args:
        authority_host: The authority host; forwarded as ``authority``.

    Returns:
        DefaultAzureCredential
    """
    credential = DefaultAzureCredential(authority=authority_host)
    return credential
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
class AzureEntraIDOAuthConfig(OAuth, BaseModel):
    """Azure Entra ID OAuth configuration.

    Bundles the OAuth scopes with the configuration of the Entra ID credential
    that is used to obtain bearer tokens for those scopes.
    """

    # Lets pydantic accept AzureEntraIDTokenProviderConfig (an abc.ABC, not a
    # pydantic model) as a field type.
    class Config:
        arbitrary_types_allowed = True

    scopes: List[str]
    azure_entra_id_token_provider_config: AzureEntraIDTokenProviderConfig

    @staticmethod
    def _supported_azure_oauth() -> (
        Dict[AzureEntraIDOAuthType, Type[AzureEntraIDTokenProviderConfig]]
    ):
        """Maps each supported Entra ID OAuth type to its configuration class."""
        oauth_type = AzureEntraIDOAuthType
        return {
            oauth_type.AZURE_ENTRA_ID_DEFAULT: AzureEntraIDDefaultCredentialsConfig,
            oauth_type.AZURE_ENTRA_ID_CLIENT_SECRET: AzureEntraIDClientCredentialsConfig,  # noqa: E501
            oauth_type.AZURE_ENTRA_ID_CLIENT_CERTIFICATE: AzureEntraIDClientCertificateConfig,  # noqa: E501
        }

    @staticmethod
    def _get_azure_oauth_by_type(
        oauth_type: AzureEntraIDOAuthType,
    ) -> Type[AzureEntraIDTokenProviderConfig]:
        """Looks up the configuration class for an Entra ID OAuth type.

        Args:
            oauth_type: (AzureEntraIDOAuthType) The type of the Azure Entra ID OAuth.

        Returns:
            The matching AzureEntraIDTokenProviderConfig subclass.

        Raises:
            ValueError: If the passed oauth_type is not supported or invalid.
        """
        supported = AzureEntraIDOAuthConfig._supported_azure_oauth()
        if oauth_type in supported:
            return supported[oauth_type]

        message = (
            f"Unsupported Azure Entra ID oauth type: {oauth_type}. "
            f"Supported types are: {AzureEntraIDOAuthType.valid_string_values()}"
        )
        structlogger.error(
            "azure_oauth_config.unsupported_azure_oauth_type",
            message=message,
        )
        raise ValueError(message)

    @classmethod
    def from_dict(cls, oauth_config: Dict[str, Any]) -> AzureEntraIDOAuthConfig:
        """Builds the OAuth configuration from a dictionary.

        The caller's dictionary is left untouched; the helpers below pop keys
        from a deep copy.

        Args:
            oauth_config: (dict) The config from which to initialize.

        Returns:
            AzureEntraIDOAuthConfig

        Raises:
            ValueError: If the scopes are empty or the credential configuration
                is invalid.
        """
        working_config = deepcopy(oauth_config)

        # Scopes are read (and removed) first so that only credential keys
        # remain when the credential configuration is built.
        scopes = AzureEntraIDOAuthConfig._read_scopes_from_config(working_config)
        token_provider_config = (
            AzureEntraIDOAuthConfig._create_azure_entra_id_client_from_config(
                working_config
            )
        )
        return cls(
            azure_entra_id_token_provider_config=token_provider_config,
            scopes=scopes,
        )

    @staticmethod
    def _read_scopes_from_config(oauth_config: Dict[str, Any]) -> List[str]:
        """Removes the scopes from the configuration and returns them.

        A single scope given as a string is wrapped into a one-element list.

        Args:
            oauth_config: (dict) The configuration to read (and pop) the scopes from.

        Returns:
            List[str]: The list of scopes.

        Raises:
            ValueError: If the scopes are missing or empty.
        """
        raw_scopes = oauth_config.pop(AZURE_SCOPES_FIELD, "")

        if not raw_scopes:
            message = "Azure Entra ID scopes cannot be empty."
            structlogger.error("azure_oauth_config.scopes_empty", message=message)
            raise ValueError(message)

        return [raw_scopes] if isinstance(raw_scopes, str) else raw_scopes

    @staticmethod
    def _create_azure_entra_id_client_from_config(
        oauth_config: Dict[str, Any],
    ) -> AzureEntraIDTokenProviderConfig:
        """Builds the credential configuration selected by the OAuth type field.

        The type field is popped from ``oauth_config`` and the remaining keys
        are passed to the selected configuration class.

        Args:
            oauth_config: (dict) The configuration from which to create the credential.

        Returns:
            AzureEntraIDTokenProviderConfig: The Azure Entra ID credential config.

        Raises:
            ValueError: If the type field is missing or not a supported type.
        """
        raw_type = oauth_config.pop(OAUTH_TYPE_FIELD, None)
        oauth_type = AzureEntraIDOAuthType.from_string(raw_type)

        if oauth_type is AzureEntraIDOAuthType.INVALID:
            message = (
                "Azure Entra ID oauth configuration must contain "
                f"'{OAUTH_TYPE_FIELD}' field and it must be set to one of the "
                f"following values: {AzureEntraIDOAuthType.valid_string_values()}, "
            )
            structlogger.error(
                "azure_oauth_config.missing_azure_oauth_type",
                message=message,
            )
            raise ValueError(message)

        config_class = AzureEntraIDOAuthConfig._get_azure_oauth_by_type(oauth_type)
        return config_class.from_dict(oauth_config)

    def _create_azure_credential(
        self,
    ) -> TokenProvider:
        """Create an Azure Entra ID client which can be used to get a bearer token."""
        provider_config = self.azure_entra_id_token_provider_config
        return provider_config.create_azure_token_provider()

    def get_bearer_token(self) -> str:
        """Returns a bearer token for the configured scopes."""
        credential = self._create_azure_credential()
        access_token = credential.get_token(*self.scopes)  # type: ignore
        return access_token.token
|