rasa-pro 3.11.3a1.dev7__py3-none-any.whl → 3.12.0.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (65) hide show
  1. rasa/cli/arguments/default_arguments.py +1 -1
  2. rasa/cli/dialogue_understanding_test.py +251 -0
  3. rasa/core/actions/action.py +7 -16
  4. rasa/core/channels/__init__.py +0 -2
  5. rasa/core/channels/socketio.py +23 -1
  6. rasa/core/nlg/contextual_response_rephraser.py +9 -62
  7. rasa/core/policies/enterprise_search_policy.py +12 -77
  8. rasa/core/policies/flows/flow_executor.py +2 -26
  9. rasa/core/processor.py +8 -11
  10. rasa/dialogue_understanding/generator/command_generator.py +49 -43
  11. rasa/dialogue_understanding/generator/llm_based_command_generator.py +5 -5
  12. rasa/dialogue_understanding/generator/llm_command_generator.py +1 -2
  13. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +15 -34
  14. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +6 -11
  15. rasa/dialogue_understanding/utils.py +1 -8
  16. rasa/dialogue_understanding_test/command_metric_calculation.py +12 -0
  17. rasa/dialogue_understanding_test/constants.py +2 -0
  18. rasa/dialogue_understanding_test/du_test_runner.py +93 -0
  19. rasa/dialogue_understanding_test/io.py +54 -0
  20. rasa/dialogue_understanding_test/validation.py +22 -0
  21. rasa/e2e_test/e2e_test_runner.py +9 -7
  22. rasa/hooks.py +9 -15
  23. rasa/model_manager/socket_bridge.py +2 -7
  24. rasa/model_manager/warm_rasa_process.py +4 -9
  25. rasa/plugin.py +0 -11
  26. rasa/shared/constants.py +2 -21
  27. rasa/shared/core/events.py +8 -8
  28. rasa/shared/nlu/constants.py +0 -3
  29. rasa/shared/providers/_configs/azure_entra_id_client_creds.py +40 -0
  30. rasa/shared/providers/_configs/azure_entra_id_config.py +533 -0
  31. rasa/shared/providers/_configs/azure_openai_client_config.py +131 -15
  32. rasa/shared/providers/_configs/client_config.py +3 -1
  33. rasa/shared/providers/_configs/default_litellm_client_config.py +9 -7
  34. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +13 -11
  35. rasa/shared/providers/_configs/litellm_router_client_config.py +12 -10
  36. rasa/shared/providers/_configs/model_group_config.py +11 -5
  37. rasa/shared/providers/_configs/oauth_config.py +33 -0
  38. rasa/shared/providers/_configs/openai_client_config.py +14 -12
  39. rasa/shared/providers/_configs/rasa_llm_client_config.py +5 -3
  40. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +12 -11
  41. rasa/shared/providers/constants.py +6 -0
  42. rasa/shared/providers/embedding/azure_openai_embedding_client.py +30 -7
  43. rasa/shared/providers/embedding/litellm_router_embedding_client.py +5 -2
  44. rasa/shared/providers/llm/_base_litellm_client.py +6 -4
  45. rasa/shared/providers/llm/azure_openai_llm_client.py +88 -34
  46. rasa/shared/providers/llm/default_litellm_llm_client.py +4 -2
  47. rasa/shared/providers/llm/litellm_router_llm_client.py +23 -3
  48. rasa/shared/providers/llm/llm_client.py +4 -2
  49. rasa/shared/providers/llm/llm_response.py +1 -42
  50. rasa/shared/providers/llm/openai_llm_client.py +11 -5
  51. rasa/shared/providers/llm/rasa_llm_client.py +13 -5
  52. rasa/shared/providers/llm/self_hosted_llm_client.py +17 -10
  53. rasa/shared/providers/router/_base_litellm_router_client.py +10 -8
  54. rasa/shared/providers/router/router_client.py +3 -1
  55. rasa/shared/utils/llm.py +16 -12
  56. rasa/shared/utils/schemas/events.py +1 -1
  57. rasa/tracing/instrumentation/attribute_extractors.py +0 -2
  58. rasa/version.py +1 -1
  59. {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/METADATA +2 -1
  60. {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/RECORD +63 -56
  61. rasa/core/channels/studio_chat.py +0 -192
  62. rasa/dialogue_understanding/constants.py +0 -1
  63. {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/NOTICE +0 -0
  64. {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/WHEEL +0 -0
  65. {rasa_pro-3.11.3a1.dev7.dist-info → rasa_pro-3.12.0.dev2.dist-info}/entry_points.txt +0 -0
@@ -1,36 +1,55 @@
1
+ from __future__ import annotations
2
+
3
+ from copy import deepcopy
1
4
  from dataclasses import asdict, dataclass, field
2
- from typing import Any, Dict, Optional
5
+ from typing import (
6
+ Any,
7
+ Dict,
8
+ Optional,
9
+ Set,
10
+ )
3
11
 
4
12
  import structlog
5
13
 
6
14
  from rasa.shared.constants import (
7
- MODEL_CONFIG_KEY,
8
- MODEL_NAME_CONFIG_KEY,
9
- OPENAI_API_BASE_CONFIG_KEY,
10
15
  API_BASE_CONFIG_KEY,
11
- OPENAI_API_TYPE_CONFIG_KEY,
16
+ API_KEY,
12
17
  API_TYPE_CONFIG_KEY,
13
- OPENAI_API_VERSION_CONFIG_KEY,
14
18
  API_VERSION_CONFIG_KEY,
19
+ AZURE_API_TYPE,
20
+ AZURE_OPENAI_PROVIDER,
15
21
  DEPLOYMENT_CONFIG_KEY,
16
22
  DEPLOYMENT_NAME_CONFIG_KEY,
17
23
  ENGINE_CONFIG_KEY,
18
- RASA_TYPE_CONFIG_KEY,
19
24
  LANGCHAIN_TYPE_CONFIG_KEY,
20
- STREAM_CONFIG_KEY,
25
+ MODEL_CONFIG_KEY,
26
+ MODEL_NAME_CONFIG_KEY,
21
27
  N_REPHRASES_CONFIG_KEY,
28
+ OPENAI_API_BASE_CONFIG_KEY,
29
+ OPENAI_API_TYPE_CONFIG_KEY,
30
+ OPENAI_API_VERSION_CONFIG_KEY,
31
+ PROVIDER_CONFIG_KEY,
32
+ RASA_TYPE_CONFIG_KEY,
22
33
  REQUEST_TIMEOUT_CONFIG_KEY,
34
+ STREAM_CONFIG_KEY,
23
35
  TIMEOUT_CONFIG_KEY,
24
- PROVIDER_CONFIG_KEY,
25
- AZURE_OPENAI_PROVIDER,
26
- AZURE_API_TYPE,
36
+ )
37
+ from rasa.shared.providers._configs.azure_entra_id_config import (
38
+ AzureEntraIDOAuthConfig,
39
+ AzureEntraIDOAuthType,
40
+ )
41
+ from rasa.shared.providers._configs.oauth_config import (
42
+ OAUTH_KEY,
43
+ OAUTH_TYPE_FIELD,
44
+ OAuth,
27
45
  )
28
46
  from rasa.shared.providers._configs.utils import (
29
- resolve_aliases,
30
47
  raise_deprecation_warnings,
31
- validate_required_keys,
48
+ resolve_aliases,
32
49
  validate_forbidden_keys,
50
+ validate_required_keys,
33
51
  )
52
+ from rasa.shared.utils.common import class_from_module_path
34
53
 
35
54
  structlogger = structlog.get_logger()
36
55
 
@@ -61,6 +80,76 @@ FORBIDDEN_KEYS = [
61
80
  ]
62
81
 
63
82
 
83
+ @dataclass
84
+ class OAuthConfigWrapper(OAuth):
85
+ """Wrapper for OAuth configuration.
86
+
87
+ It's main purpose is to provide to_dict method which is used to serialize
88
+ the oauth configuration to the original format.
89
+
90
+ """
91
+
92
+ oauth: OAuth
93
+ original_config: Dict[str, Any]
94
+
95
+ def get_bearer_token(self) -> str:
96
+ """Returns a bearer token."""
97
+ return self.oauth.get_bearer_token()
98
+
99
+ def to_dict(self) -> Dict[str, Any]:
100
+ """Converts the OAuth configuration to the original format."""
101
+ return self.original_config
102
+
103
+ @staticmethod
104
+ def _valid_type_values() -> Set[str]:
105
+ """Returns the valid built-in values for the `type` field in the `oauth`."""
106
+ return AzureEntraIDOAuthType.valid_string_values()
107
+
108
+ @classmethod
109
+ def from_config(cls, oauth_config: Dict[str, Any]) -> OAuth:
110
+ """Initializes a dataclass from the passed config.
111
+
112
+ Args:
113
+ oauth_config: (dict) The config from which to initialize.
114
+
115
+ Returns:
116
+ AzureOAuthConfig
117
+ """
118
+ original_config = deepcopy(oauth_config)
119
+
120
+ oauth_type: Optional[str] = oauth_config.get(OAUTH_TYPE_FIELD, None)
121
+
122
+ if oauth_type is None:
123
+ message = (
124
+ "Oauth configuration must contain "
125
+ f"'{OAUTH_TYPE_FIELD}' field and it must be set to one of the "
126
+ f"following values: {OAuthConfigWrapper._valid_type_values()}, "
127
+ f"or to the path of module which is implementing {OAuth.__name__} protocol."
128
+ )
129
+ structlogger.error(
130
+ "azure_oauth_config.missing_oauth_type",
131
+ message=message,
132
+ )
133
+ raise ValueError(message)
134
+
135
+ if oauth_type in AzureEntraIDOAuthType.valid_string_values():
136
+ oauth = AzureEntraIDOAuthConfig.from_config(oauth_config)
137
+ else:
138
+ module = class_from_module_path(oauth_type)
139
+
140
+ if not issubclass(module, OAuth):
141
+ message = f"Module {oauth_type} does not implement {OAuth.__name__} interface."
142
+ structlogger.error(
143
+ "azure_oauth_config.invalid_oauth_module",
144
+ message=message,
145
+ )
146
+ raise ValueError(message)
147
+
148
+ oauth = module.from_config(oauth_config)
149
+
150
+ return cls(oauth=oauth, original_config=original_config)
151
+
152
+
64
153
  @dataclass
65
154
  class AzureOpenAIClientConfig:
66
155
  """Parses configuration for Azure OpenAI client, resolves aliases and
@@ -80,11 +169,13 @@ class AzureOpenAIClientConfig:
80
169
  # API Type is not used by LiteLLM backend, but we define
81
170
  # it here for backward compatibility.
82
171
  api_type: Optional[str] = AZURE_API_TYPE
83
-
84
172
  # Provider is not used by LiteLLM backend, but we define it here since it's
85
173
  # used as switch between different clients.
86
174
  provider: str = AZURE_OPENAI_PROVIDER
87
175
 
176
+ # OAuth related parameters
177
+ oauth: Optional[OAuthConfigWrapper] = None
178
+
88
179
  extra_parameters: dict = field(default_factory=dict)
89
180
 
90
181
  def __post_init__(self) -> None:
@@ -106,7 +197,7 @@ class AzureOpenAIClientConfig:
106
197
  raise ValueError(message)
107
198
 
108
199
  @classmethod
109
- def from_dict(cls, config: dict) -> "AzureOpenAIClientConfig":
200
+ def from_dict(cls, config: dict) -> AzureOpenAIClientConfig:
110
201
  """Initializes a dataclass from the passed config.
111
202
 
112
203
  Args:
@@ -129,6 +220,26 @@ class AzureOpenAIClientConfig:
129
220
  # Validate that the forbidden keys are not present
130
221
  validate_forbidden_keys(config, FORBIDDEN_KEYS)
131
222
  # Init client config
223
+
224
+ has_api_key = config.get(API_KEY, None) is not None
225
+ has_oauth_key = config.get(OAUTH_KEY, None) is not None
226
+
227
+ if has_api_key and has_oauth_key:
228
+ message = (
229
+ "Azure OpenAI client configuration cannot contain "
230
+ f"both '{API_KEY}' and '{OAUTH_KEY}' fields. Please provide either "
231
+ f"'{API_KEY}' or '{OAUTH_KEY}' fields."
232
+ )
233
+ structlogger.error(
234
+ "azure_openai_client_config.multiple_auth_types_specified",
235
+ message=message,
236
+ )
237
+ raise ValueError(message)
238
+
239
+ oauth = None
240
+ if has_oauth_key:
241
+ oauth = OAuthConfigWrapper.from_config(config.pop(OAUTH_KEY))
242
+
132
243
  this = AzureOpenAIClientConfig(
133
244
  # Required parameters
134
245
  deployment=config.pop(DEPLOYMENT_CONFIG_KEY),
@@ -142,6 +253,8 @@ class AzureOpenAIClientConfig:
142
253
  # in clients.
143
254
  api_base=config.pop(API_BASE_CONFIG_KEY, None),
144
255
  api_version=config.pop(API_VERSION_CONFIG_KEY, None),
256
+ # OAuth related parameters, set only if auth_type is set to 'entra_id'
257
+ oauth=oauth,
145
258
  # The rest of parameters (e.g. model parameters) are considered
146
259
  # as extra parameters (this also includes timeout).
147
260
  extra_parameters=config,
@@ -154,6 +267,9 @@ class AzureOpenAIClientConfig:
154
267
  # Extra parameters should also be on the top level
155
268
  d.pop("extra_parameters", None)
156
269
  d.update(self.extra_parameters)
270
+
271
+ d.pop("oauth", None)
272
+ d.update({"oauth": self.oauth.to_dict()} if self.oauth else {})
157
273
  return d
158
274
 
159
275
  @staticmethod
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from typing import Protocol, runtime_checkable
2
4
 
3
5
 
@@ -9,7 +11,7 @@ class ClientConfig(Protocol):
9
11
  """
10
12
 
11
13
  @classmethod
12
- def from_dict(cls, config: dict) -> "ClientConfig":
14
+ def from_dict(cls, config: dict) -> ClientConfig:
13
15
  """
14
16
  Initializes the client config with the given configuration.
15
17
 
@@ -1,24 +1,26 @@
1
+ from __future__ import annotations
2
+
1
3
  from dataclasses import asdict, dataclass, field
2
4
  from typing import Any, Dict
3
5
 
4
6
  import structlog
5
7
 
8
+ import rasa.shared.utils.cli
6
9
  from rasa.shared.constants import (
7
10
  MODEL_CONFIG_KEY,
8
11
  MODEL_NAME_CONFIG_KEY,
9
- STREAM_CONFIG_KEY,
10
12
  N_REPHRASES_CONFIG_KEY,
11
13
  PROVIDER_CONFIG_KEY,
12
- TIMEOUT_CONFIG_KEY,
13
14
  REQUEST_TIMEOUT_CONFIG_KEY,
15
+ STREAM_CONFIG_KEY,
16
+ TIMEOUT_CONFIG_KEY,
14
17
  )
15
18
  from rasa.shared.providers._configs.utils import (
16
- validate_required_keys,
17
- validate_forbidden_keys,
18
- resolve_aliases,
19
19
  raise_deprecation_warnings,
20
+ resolve_aliases,
21
+ validate_forbidden_keys,
22
+ validate_required_keys,
20
23
  )
21
- import rasa.shared.utils.cli
22
24
 
23
25
  structlogger = structlog.get_logger()
24
26
 
@@ -69,7 +71,7 @@ class DefaultLiteLLMClientConfig:
69
71
  raise ValueError(message)
70
72
 
71
73
  @classmethod
72
- def from_dict(cls, config: dict) -> "DefaultLiteLLMClientConfig":
74
+ def from_dict(cls, config: dict) -> DefaultLiteLLMClientConfig:
73
75
  """
74
76
  Initializes a dataclass from the passed config.
75
77
 
@@ -1,27 +1,29 @@
1
+ from __future__ import annotations
2
+
1
3
  from dataclasses import asdict, dataclass, field
2
4
  from typing import Any, Dict, Optional
3
5
 
4
6
  import structlog
5
7
 
6
8
  from rasa.shared.constants import (
7
- MODEL_CONFIG_KEY,
8
- MODEL_NAME_CONFIG_KEY,
9
- RASA_TYPE_CONFIG_KEY,
10
- LANGCHAIN_TYPE_CONFIG_KEY,
11
- HUGGINGFACE_MULTIPROCESS_CONFIG_KEY,
12
9
  HUGGINGFACE_CACHE_FOLDER_CONFIG_KEY,
13
- HUGGINGFACE_SHOW_PROGRESS_CONFIG_KEY,
14
- HUGGINGFACE_MODEL_KWARGS_CONFIG_KEY,
15
10
  HUGGINGFACE_ENCODE_KWARGS_CONFIG_KEY,
16
11
  HUGGINGFACE_LOCAL_EMBEDDING_CACHING_FOLDER,
17
- PROVIDER_CONFIG_KEY,
18
12
  HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER,
19
- TIMEOUT_CONFIG_KEY,
13
+ HUGGINGFACE_MODEL_KWARGS_CONFIG_KEY,
14
+ HUGGINGFACE_MULTIPROCESS_CONFIG_KEY,
15
+ HUGGINGFACE_SHOW_PROGRESS_CONFIG_KEY,
16
+ LANGCHAIN_TYPE_CONFIG_KEY,
17
+ MODEL_CONFIG_KEY,
18
+ MODEL_NAME_CONFIG_KEY,
19
+ PROVIDER_CONFIG_KEY,
20
+ RASA_TYPE_CONFIG_KEY,
20
21
  REQUEST_TIMEOUT_CONFIG_KEY,
22
+ TIMEOUT_CONFIG_KEY,
21
23
  )
22
24
  from rasa.shared.providers._configs.utils import (
23
- resolve_aliases,
24
25
  raise_deprecation_warnings,
26
+ resolve_aliases,
25
27
  validate_required_keys,
26
28
  )
27
29
  from rasa.shared.utils.io import raise_deprecation_warning
@@ -90,7 +92,7 @@ class HuggingFaceLocalEmbeddingClientConfig:
90
92
  raise ValueError(message)
91
93
 
92
94
  @classmethod
93
- def from_dict(cls, config: dict) -> "HuggingFaceLocalEmbeddingClientConfig":
95
+ def from_dict(cls, config: dict) -> HuggingFaceLocalEmbeddingClientConfig:
94
96
  """
95
97
  Initializes a dataclass from the passed config.
96
98
 
@@ -1,29 +1,31 @@
1
+ from __future__ import annotations
2
+
1
3
  import copy
2
4
  from dataclasses import dataclass, field
3
5
  from typing import Any, Dict, List
4
6
 
5
7
  import structlog
8
+
6
9
  from rasa.shared.constants import (
7
- ROUTER_CONFIG_KEY,
8
- MODELS_CONFIG_KEY,
9
- MODEL_GROUP_ID_CONFIG_KEY,
10
- MODEL_NAME_CONFIG_KEY,
11
- LITELLM_PARAMS_KEY,
12
- PROVIDER_CONFIG_KEY,
13
- DEPLOYMENT_CONFIG_KEY,
14
10
  API_TYPE_CONFIG_KEY,
11
+ DEPLOYMENT_CONFIG_KEY,
12
+ LITELLM_PARAMS_KEY,
15
13
  MODEL_CONFIG_KEY,
14
+ MODEL_GROUP_ID_CONFIG_KEY,
16
15
  MODEL_LIST_KEY,
16
+ MODEL_NAME_CONFIG_KEY,
17
+ MODELS_CONFIG_KEY,
18
+ PROVIDER_CONFIG_KEY,
19
+ ROUTER_CONFIG_KEY,
17
20
  USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY,
18
21
  )
19
22
  from rasa.shared.providers._configs.model_group_config import (
20
- ModelGroupConfig,
21
23
  ModelConfig,
24
+ ModelGroupConfig,
22
25
  )
23
26
  from rasa.shared.providers.mappings import get_prefix_from_provider
24
27
  from rasa.shared.utils.llm import DEPLOYMENT_CENTRIC_PROVIDERS
25
28
 
26
-
27
29
  structlogger = structlog.get_logger()
28
30
 
29
31
  _LITELLM_UNSUPPORTED_KEYS = [
@@ -120,7 +122,7 @@ class LiteLLMRouterClientConfig:
120
122
  raise ValueError(message)
121
123
 
122
124
  @classmethod
123
- def from_dict(cls, config: dict) -> "LiteLLMRouterClientConfig":
125
+ def from_dict(cls, config: dict) -> LiteLLMRouterClientConfig:
124
126
  """Initializes a dataclass from the passed config.
125
127
 
126
128
  Args:
@@ -1,19 +1,25 @@
1
+ from __future__ import annotations
2
+
1
3
  from dataclasses import asdict, dataclass, field
2
4
  from typing import List, Optional
3
5
 
4
6
  import structlog
7
+
5
8
  from rasa.shared.constants import (
6
9
  API_BASE_CONFIG_KEY,
7
10
  API_KEY,
8
11
  API_TYPE_CONFIG_KEY,
9
12
  API_VERSION_CONFIG_KEY,
10
13
  DEPLOYMENT_CONFIG_KEY,
11
- PROVIDER_CONFIG_KEY,
14
+ EXTRA_PARAMETERS_KEY,
12
15
  MODEL_CONFIG_KEY,
13
16
  MODEL_GROUP_ID_CONFIG_KEY,
14
- MODELS_CONFIG_KEY,
15
17
  MODEL_GROUPS_CONFIG_KEY,
16
- EXTRA_PARAMETERS_KEY,
18
+ MODELS_CONFIG_KEY,
19
+ PROVIDER_CONFIG_KEY,
20
+ )
21
+ from rasa.shared.providers._configs._lite_llm_config.lite_llm_config_adapter import (
22
+ to_lite_llm_config,
17
23
  )
18
24
  from rasa.shared.providers.mappings import get_client_config_class_from_provider
19
25
 
@@ -40,7 +46,7 @@ class ModelConfig:
40
46
  api_type: Optional[str] = None
41
47
 
42
48
  @classmethod
43
- def from_dict(cls, config: dict) -> "ModelConfig":
49
+ def from_dict(cls, config: dict) -> ModelConfig:
44
50
  """Initializes a dataclass from the passed config. The provider config param is
45
51
  used to determine the client config class to use. The client config class takes
46
52
  care of resolving config aliases and throwing deprecation warnings.
@@ -130,7 +136,7 @@ class ModelGroupConfig:
130
136
  raise ValueError(message)
131
137
 
132
138
  @classmethod
133
- def from_dict(cls, config: dict) -> "ModelGroupConfig":
139
+ def from_dict(cls, config: dict) -> ModelGroupConfig:
134
140
  """Initializes a dataclass from the passed config.
135
141
 
136
142
  Args:
@@ -0,0 +1,33 @@
1
+ import abc
2
+ from typing import Any, Dict, TypeVar
3
+
4
+ OAUTH_TYPE_FIELD = "type"
5
+ OAUTH_KEY = "oauth"
6
+
7
+ OAuthType = TypeVar("OAuthType", bound="OAuth")
8
+
9
+
10
+ class OAuth(abc.ABC):
11
+ """Interface for OAuth configuration."""
12
+
13
+ @classmethod
14
+ @abc.abstractmethod
15
+ def from_config(cls: OAuthType, config: Dict[str, Any]) -> OAuthType:
16
+ """Initializes a dataclass from the passed config.
17
+
18
+ Args:
19
+ config: (dict) The config from which to initialize.
20
+
21
+ Returns:
22
+ OAuth
23
+ """
24
+ ...
25
+
26
+ @abc.abstractmethod
27
+ def get_bearer_token(self) -> str:
28
+ """Returns a bearer token.
29
+
30
+ Bear token is used to authenticate requests to the Azure Oopen AI instance's API protected
31
+ by the Gateway.
32
+ """
33
+ ...
@@ -1,32 +1,34 @@
1
+ from __future__ import annotations
2
+
1
3
  from dataclasses import asdict, dataclass, field
2
4
  from typing import Any, Dict, Optional
3
5
 
4
6
  import structlog
5
7
 
6
8
  from rasa.shared.constants import (
9
+ API_BASE_CONFIG_KEY,
10
+ API_TYPE_CONFIG_KEY,
11
+ API_VERSION_CONFIG_KEY,
12
+ LANGCHAIN_TYPE_CONFIG_KEY,
7
13
  MODEL_CONFIG_KEY,
8
14
  MODEL_NAME_CONFIG_KEY,
15
+ N_REPHRASES_CONFIG_KEY,
9
16
  OPENAI_API_BASE_CONFIG_KEY,
10
- API_BASE_CONFIG_KEY,
17
+ OPENAI_API_TYPE,
11
18
  OPENAI_API_TYPE_CONFIG_KEY,
12
- API_TYPE_CONFIG_KEY,
13
19
  OPENAI_API_VERSION_CONFIG_KEY,
14
- API_VERSION_CONFIG_KEY,
20
+ OPENAI_PROVIDER,
21
+ PROVIDER_CONFIG_KEY,
15
22
  RASA_TYPE_CONFIG_KEY,
16
- LANGCHAIN_TYPE_CONFIG_KEY,
17
- STREAM_CONFIG_KEY,
18
- N_REPHRASES_CONFIG_KEY,
19
23
  REQUEST_TIMEOUT_CONFIG_KEY,
24
+ STREAM_CONFIG_KEY,
20
25
  TIMEOUT_CONFIG_KEY,
21
- PROVIDER_CONFIG_KEY,
22
- OPENAI_API_TYPE,
23
- OPENAI_PROVIDER,
24
26
  )
25
27
  from rasa.shared.providers._configs.utils import (
26
- resolve_aliases,
27
- validate_required_keys,
28
28
  raise_deprecation_warnings,
29
+ resolve_aliases,
29
30
  validate_forbidden_keys,
31
+ validate_required_keys,
30
32
  )
31
33
 
32
34
  structlogger = structlog.get_logger()
@@ -111,7 +113,7 @@ class OpenAIClientConfig:
111
113
  raise ValueError(message)
112
114
 
113
115
  @classmethod
114
- def from_dict(cls, config: dict) -> "OpenAIClientConfig":
116
+ def from_dict(cls, config: dict) -> OpenAIClientConfig:
115
117
  """
116
118
  Initializes a dataclass from the passed config.
117
119
 
@@ -1,13 +1,15 @@
1
+ from __future__ import annotations
2
+
1
3
  from dataclasses import asdict, dataclass, field
2
4
  from typing import Optional
3
5
 
4
6
  import structlog
5
7
 
6
8
  from rasa.shared.constants import (
9
+ API_BASE_CONFIG_KEY,
7
10
  MODEL_CONFIG_KEY,
8
- RASA_PROVIDER,
9
11
  PROVIDER_CONFIG_KEY,
10
- API_BASE_CONFIG_KEY,
12
+ RASA_PROVIDER,
11
13
  )
12
14
  from rasa.shared.providers._configs.utils import (
13
15
  validate_required_keys,
@@ -37,7 +39,7 @@ class RasaLLMClientConfig:
37
39
  extra_parameters: dict = field(default_factory=dict)
38
40
 
39
41
  @classmethod
40
- def from_dict(cls, config: dict) -> "RasaLLMClientConfig":
42
+ def from_dict(cls, config: dict) -> RasaLLMClientConfig:
41
43
  """
42
44
  Initializes a dataclass from the passed config.
43
45
 
@@ -1,29 +1,30 @@
1
+ from __future__ import annotations
2
+
1
3
  from dataclasses import asdict, dataclass, field
2
4
  from typing import Any, Dict, Optional
3
5
 
4
6
  import structlog
5
7
 
6
8
  from rasa.shared.constants import (
9
+ API_BASE_CONFIG_KEY,
10
+ API_TYPE_CONFIG_KEY,
11
+ API_VERSION_CONFIG_KEY,
12
+ LANGCHAIN_TYPE_CONFIG_KEY,
7
13
  MODEL_CONFIG_KEY,
8
14
  MODEL_NAME_CONFIG_KEY,
15
+ N_REPHRASES_CONFIG_KEY,
9
16
  OPENAI_API_BASE_CONFIG_KEY,
10
- API_BASE_CONFIG_KEY,
11
17
  OPENAI_API_TYPE_CONFIG_KEY,
12
- API_TYPE_CONFIG_KEY,
13
18
  OPENAI_API_VERSION_CONFIG_KEY,
14
- API_VERSION_CONFIG_KEY,
19
+ OPENAI_PROVIDER,
20
+ PROVIDER_CONFIG_KEY,
15
21
  RASA_TYPE_CONFIG_KEY,
16
- LANGCHAIN_TYPE_CONFIG_KEY,
17
- STREAM_CONFIG_KEY,
18
- N_REPHRASES_CONFIG_KEY,
19
22
  REQUEST_TIMEOUT_CONFIG_KEY,
20
- TIMEOUT_CONFIG_KEY,
21
- PROVIDER_CONFIG_KEY,
22
- OPENAI_PROVIDER,
23
23
  SELF_HOSTED_PROVIDER,
24
+ STREAM_CONFIG_KEY,
25
+ TIMEOUT_CONFIG_KEY,
24
26
  USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY,
25
27
  )
26
-
27
28
  from rasa.shared.providers._configs.utils import (
28
29
  raise_deprecation_warnings,
29
30
  resolve_aliases,
@@ -114,7 +115,7 @@ class SelfHostedLLMClientConfig:
114
115
  raise ValueError(message)
115
116
 
116
117
  @classmethod
117
- def from_dict(cls, config: dict) -> "SelfHostedLLMClientConfig":
118
+ def from_dict(cls, config: dict) -> SelfHostedLLMClientConfig:
118
119
  """
119
120
  Initializes a dataclass from the passed config.
120
121
 
@@ -0,0 +1,6 @@
1
+ DEFAULT_AZURE_API_KEY_NAME = "AZURE_API_KEY"
2
+ LITE_LLM_API_BASE_FIELD = "api_base"
3
+ LITE_LLM_API_KEY_FIELD = "api_key"
4
+ LITE_LLM_API_VERSION_FIELD = "api_version"
5
+ LITE_LLM_MODEL_FIELD = "model"
6
+ LITE_LLM_AZURE_AD_TOKEN = "azure_ad_token"