rasa-pro 3.9.18__py3-none-any.whl → 3.10.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. See the package's page on its registry for more details.

Files changed (183)
  1. README.md +0 -374
  2. rasa/__init__.py +1 -2
  3. rasa/__main__.py +5 -0
  4. rasa/anonymization/anonymization_rule_executor.py +2 -2
  5. rasa/api.py +27 -23
  6. rasa/cli/arguments/data.py +27 -2
  7. rasa/cli/arguments/default_arguments.py +25 -3
  8. rasa/cli/arguments/run.py +9 -9
  9. rasa/cli/arguments/train.py +11 -3
  10. rasa/cli/data.py +70 -8
  11. rasa/cli/e2e_test.py +104 -431
  12. rasa/cli/evaluate.py +1 -1
  13. rasa/cli/interactive.py +1 -0
  14. rasa/cli/llm_fine_tuning.py +398 -0
  15. rasa/cli/project_templates/calm/endpoints.yml +1 -1
  16. rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
  17. rasa/cli/run.py +15 -14
  18. rasa/cli/scaffold.py +10 -8
  19. rasa/cli/studio/studio.py +35 -5
  20. rasa/cli/train.py +56 -8
  21. rasa/cli/utils.py +22 -5
  22. rasa/cli/x.py +1 -1
  23. rasa/constants.py +7 -1
  24. rasa/core/actions/action.py +98 -49
  25. rasa/core/actions/action_run_slot_rejections.py +4 -1
  26. rasa/core/actions/custom_action_executor.py +9 -6
  27. rasa/core/actions/direct_custom_actions_executor.py +80 -0
  28. rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
  29. rasa/core/actions/grpc_custom_action_executor.py +2 -2
  30. rasa/core/actions/http_custom_action_executor.py +6 -5
  31. rasa/core/agent.py +21 -17
  32. rasa/core/channels/__init__.py +2 -0
  33. rasa/core/channels/audiocodes.py +1 -16
  34. rasa/core/channels/voice_aware/__init__.py +0 -0
  35. rasa/core/channels/voice_aware/jambonz.py +103 -0
  36. rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
  37. rasa/core/channels/voice_aware/utils.py +20 -0
  38. rasa/core/channels/voice_native/__init__.py +0 -0
  39. rasa/core/constants.py +6 -1
  40. rasa/core/information_retrieval/faiss.py +7 -4
  41. rasa/core/information_retrieval/information_retrieval.py +8 -0
  42. rasa/core/information_retrieval/milvus.py +9 -2
  43. rasa/core/information_retrieval/qdrant.py +1 -1
  44. rasa/core/nlg/contextual_response_rephraser.py +32 -10
  45. rasa/core/nlg/summarize.py +4 -3
  46. rasa/core/policies/enterprise_search_policy.py +113 -45
  47. rasa/core/policies/flows/flow_executor.py +122 -76
  48. rasa/core/policies/intentless_policy.py +83 -29
  49. rasa/core/processor.py +72 -54
  50. rasa/core/run.py +5 -4
  51. rasa/core/tracker_store.py +8 -4
  52. rasa/core/training/interactive.py +1 -1
  53. rasa/core/utils.py +56 -57
  54. rasa/dialogue_understanding/coexistence/llm_based_router.py +53 -13
  55. rasa/dialogue_understanding/commands/__init__.py +6 -0
  56. rasa/dialogue_understanding/commands/restart_command.py +58 -0
  57. rasa/dialogue_understanding/commands/session_start_command.py +59 -0
  58. rasa/dialogue_understanding/commands/utils.py +40 -0
  59. rasa/dialogue_understanding/generator/constants.py +10 -3
  60. rasa/dialogue_understanding/generator/flow_retrieval.py +21 -5
  61. rasa/dialogue_understanding/generator/llm_based_command_generator.py +13 -3
  62. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +134 -90
  63. rasa/dialogue_understanding/generator/nlu_command_adapter.py +47 -7
  64. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +127 -41
  65. rasa/dialogue_understanding/patterns/restart.py +37 -0
  66. rasa/dialogue_understanding/patterns/session_start.py +37 -0
  67. rasa/dialogue_understanding/processor/command_processor.py +16 -3
  68. rasa/dialogue_understanding/processor/command_processor_component.py +6 -2
  69. rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
  70. rasa/e2e_test/assertions.py +1223 -0
  71. rasa/e2e_test/assertions_schema.yml +106 -0
  72. rasa/e2e_test/constants.py +20 -0
  73. rasa/e2e_test/e2e_config.py +220 -0
  74. rasa/e2e_test/e2e_config_schema.yml +26 -0
  75. rasa/e2e_test/e2e_test_case.py +131 -8
  76. rasa/e2e_test/e2e_test_converter.py +363 -0
  77. rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
  78. rasa/e2e_test/e2e_test_coverage_report.py +364 -0
  79. rasa/e2e_test/e2e_test_result.py +26 -6
  80. rasa/e2e_test/e2e_test_runner.py +493 -71
  81. rasa/e2e_test/e2e_test_schema.yml +96 -0
  82. rasa/e2e_test/pykwalify_extensions.py +39 -0
  83. rasa/e2e_test/stub_custom_action.py +70 -0
  84. rasa/e2e_test/utils/__init__.py +0 -0
  85. rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
  86. rasa/e2e_test/utils/io.py +598 -0
  87. rasa/e2e_test/utils/validation.py +80 -0
  88. rasa/engine/graph.py +9 -3
  89. rasa/engine/recipes/default_components.py +0 -2
  90. rasa/engine/recipes/default_recipe.py +10 -2
  91. rasa/engine/storage/local_model_storage.py +40 -12
  92. rasa/engine/validation.py +78 -1
  93. rasa/env.py +9 -0
  94. rasa/graph_components/providers/story_graph_provider.py +59 -6
  95. rasa/llm_fine_tuning/__init__.py +0 -0
  96. rasa/llm_fine_tuning/annotation_module.py +241 -0
  97. rasa/llm_fine_tuning/conversations.py +144 -0
  98. rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
  99. rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
  100. rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
  101. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
  102. rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
  103. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
  104. rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
  105. rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
  106. rasa/llm_fine_tuning/storage.py +174 -0
  107. rasa/llm_fine_tuning/train_test_split_module.py +441 -0
  108. rasa/model_training.py +56 -16
  109. rasa/nlu/persistor.py +157 -36
  110. rasa/server.py +45 -10
  111. rasa/shared/constants.py +76 -16
  112. rasa/shared/core/domain.py +27 -19
  113. rasa/shared/core/events.py +28 -2
  114. rasa/shared/core/flows/flow.py +208 -13
  115. rasa/shared/core/flows/flow_path.py +84 -0
  116. rasa/shared/core/flows/flows_list.py +33 -11
  117. rasa/shared/core/flows/flows_yaml_schema.json +269 -193
  118. rasa/shared/core/flows/validation.py +112 -25
  119. rasa/shared/core/flows/yaml_flows_io.py +149 -10
  120. rasa/shared/core/trackers.py +6 -0
  121. rasa/shared/core/training_data/structures.py +20 -0
  122. rasa/shared/core/training_data/visualization.html +2 -2
  123. rasa/shared/exceptions.py +4 -0
  124. rasa/shared/importers/importer.py +64 -16
  125. rasa/shared/nlu/constants.py +2 -0
  126. rasa/shared/providers/_configs/__init__.py +0 -0
  127. rasa/shared/providers/_configs/azure_openai_client_config.py +183 -0
  128. rasa/shared/providers/_configs/client_config.py +57 -0
  129. rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
  130. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
  131. rasa/shared/providers/_configs/openai_client_config.py +175 -0
  132. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +176 -0
  133. rasa/shared/providers/_configs/utils.py +101 -0
  134. rasa/shared/providers/_ssl_verification_utils.py +124 -0
  135. rasa/shared/providers/embedding/__init__.py +0 -0
  136. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +259 -0
  137. rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
  138. rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
  139. rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
  140. rasa/shared/providers/embedding/embedding_client.py +90 -0
  141. rasa/shared/providers/embedding/embedding_response.py +41 -0
  142. rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
  143. rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
  144. rasa/shared/providers/llm/__init__.py +0 -0
  145. rasa/shared/providers/llm/_base_litellm_client.py +251 -0
  146. rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
  147. rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
  148. rasa/shared/providers/llm/llm_client.py +76 -0
  149. rasa/shared/providers/llm/llm_response.py +50 -0
  150. rasa/shared/providers/llm/openai_llm_client.py +155 -0
  151. rasa/shared/providers/llm/self_hosted_llm_client.py +293 -0
  152. rasa/shared/providers/mappings.py +75 -0
  153. rasa/shared/utils/cli.py +30 -0
  154. rasa/shared/utils/io.py +65 -2
  155. rasa/shared/utils/llm.py +246 -200
  156. rasa/shared/utils/yaml.py +121 -15
  157. rasa/studio/auth.py +6 -4
  158. rasa/studio/config.py +13 -4
  159. rasa/studio/constants.py +1 -0
  160. rasa/studio/data_handler.py +10 -3
  161. rasa/studio/download.py +19 -13
  162. rasa/studio/train.py +2 -3
  163. rasa/studio/upload.py +19 -11
  164. rasa/telemetry.py +113 -58
  165. rasa/tracing/instrumentation/attribute_extractors.py +32 -17
  166. rasa/utils/common.py +18 -19
  167. rasa/utils/endpoints.py +7 -4
  168. rasa/utils/json_utils.py +60 -0
  169. rasa/utils/licensing.py +9 -1
  170. rasa/utils/ml_utils.py +4 -2
  171. rasa/validator.py +213 -3
  172. rasa/version.py +1 -1
  173. rasa_pro-3.10.16.dist-info/METADATA +196 -0
  174. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/RECORD +179 -113
  175. rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
  176. rasa/shared/providers/openai/clients.py +0 -43
  177. rasa/shared/providers/openai/session_handler.py +0 -110
  178. rasa_pro-3.9.18.dist-info/METADATA +0 -563
  179. /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
  180. /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
  181. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/NOTICE +0 -0
  182. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/WHEEL +0 -0
  183. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/entry_points.txt +0 -0
@@ -1,28 +1,30 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
1
4
  from abc import ABC, abstractmethod
2
5
  from functools import reduce
3
- from typing import Text, Optional, List, Dict, Set, Any, Tuple, Type, Union, cast
4
- import logging
6
+ from typing import Any, Dict, List, Optional, Set, Text, Tuple, Type, Union, cast
5
7
 
6
8
  import importlib_resources
7
9
 
8
10
  import rasa.shared.constants
9
- from rasa.shared.core.flows import FlowsList
10
- import rasa.shared.utils.common
11
11
  import rasa.shared.core.constants
12
+ import rasa.shared.utils.common
12
13
  import rasa.shared.utils.io
13
14
  from rasa.shared.core.domain import (
14
- Domain,
15
+ IS_RETRIEVAL_INTENT_KEY,
16
+ KEY_ACTIONS,
15
17
  KEY_E2E_ACTIONS,
16
18
  KEY_INTENTS,
17
19
  KEY_RESPONSES,
18
- KEY_ACTIONS,
20
+ Domain,
19
21
  )
20
22
  from rasa.shared.core.events import ActionExecuted, UserUttered
23
+ from rasa.shared.core.flows import FlowsList
21
24
  from rasa.shared.core.training_data.structures import StoryGraph
25
+ from rasa.shared.nlu.constants import ACTION_NAME, ENTITIES
22
26
  from rasa.shared.nlu.training_data.message import Message
23
27
  from rasa.shared.nlu.training_data.training_data import TrainingData
24
- from rasa.shared.nlu.constants import ENTITIES, ACTION_NAME
25
- from rasa.shared.core.domain import IS_RETRIEVAL_INTENT_KEY
26
28
  from rasa.shared.utils.yaml import read_config_file
27
29
 
28
30
  logger = logging.getLogger(__name__)
@@ -114,7 +116,7 @@ class TrainingDataImporter(ABC):
114
116
  domain_path: Optional[Text] = None,
115
117
  training_data_paths: Optional[List[Text]] = None,
116
118
  args: Optional[Dict[Text, Any]] = {},
117
- ) -> "TrainingDataImporter":
119
+ ) -> TrainingDataImporter:
118
120
  """Loads a `TrainingDataImporter` instance from a configuration file."""
119
121
  config = read_config_file(config_path)
120
122
  return TrainingDataImporter.load_from_dict(
@@ -127,7 +129,7 @@ class TrainingDataImporter(ABC):
127
129
  domain_path: Optional[Text] = None,
128
130
  training_data_paths: Optional[List[Text]] = None,
129
131
  args: Optional[Dict[Text, Any]] = {},
130
- ) -> "TrainingDataImporter":
132
+ ) -> TrainingDataImporter:
131
133
  """Loads core `TrainingDataImporter` instance.
132
134
 
133
135
  Instance loaded from configuration file will only read Core training data.
@@ -143,7 +145,7 @@ class TrainingDataImporter(ABC):
143
145
  domain_path: Optional[Text] = None,
144
146
  training_data_paths: Optional[List[Text]] = None,
145
147
  args: Optional[Dict[Text, Any]] = {},
146
- ) -> "TrainingDataImporter":
148
+ ) -> TrainingDataImporter:
147
149
  """Loads nlu `TrainingDataImporter` instance.
148
150
 
149
151
  Instance loaded from configuration file will only read NLU training data.
@@ -165,8 +167,8 @@ class TrainingDataImporter(ABC):
165
167
  config_path: Optional[Text] = None,
166
168
  domain_path: Optional[Text] = None,
167
169
  training_data_paths: Optional[List[Text]] = None,
168
- args: Optional[Dict[Text, Any]] = {},
169
- ) -> "TrainingDataImporter":
170
+ args: Optional[Dict[Text, Any]] = None,
171
+ ) -> TrainingDataImporter:
170
172
  """Loads a `TrainingDataImporter` instance from a dictionary."""
171
173
  from rasa.shared.importers.rasa import RasaFileImporter
172
174
 
@@ -194,8 +196,8 @@ class TrainingDataImporter(ABC):
194
196
  config_path: Text,
195
197
  domain_path: Optional[Text] = None,
196
198
  training_data_paths: Optional[List[Text]] = None,
197
- args: Optional[Dict[Text, Any]] = {},
198
- ) -> Optional["TrainingDataImporter"]:
199
+ args: Optional[Dict[Text, Any]] = None,
200
+ ) -> Optional[TrainingDataImporter]:
199
201
  from rasa.shared.importers.multi_project import MultiProjectImporter
200
202
  from rasa.shared.importers.rasa import RasaFileImporter
201
203
 
@@ -216,7 +218,6 @@ class TrainingDataImporter(ABC):
216
218
  constructor_arguments = rasa.shared.utils.common.minimal_kwargs(
217
219
  {**importer_config, **(args or {})}, importer_class
218
220
  )
219
-
220
221
  return importer_class(
221
222
  config_path,
222
223
  domain_path,
@@ -232,6 +233,26 @@ class TrainingDataImporter(ABC):
232
233
  """Returns text representation of object."""
233
234
  return self.__class__.__name__
234
235
 
236
+ def get_user_flows(self) -> FlowsList:
237
+ """Retrieves the user-defined flows that should be used for training.
238
+
239
+ Implemented by FlowSyncImporter and E2EImporter only.
240
+
241
+ Returns:
242
+ `FlowsList` containing all loaded flows.
243
+ """
244
+ raise NotImplementedError
245
+
246
+ def get_user_domain(self) -> Domain:
247
+ """Retrieves the user-defined domain that should be used for training.
248
+
249
+ Implemented by FlowSyncImporter and E2EImporter only.
250
+
251
+ Returns:
252
+ `Domain`.
253
+ """
254
+ raise NotImplementedError
255
+
235
256
 
236
257
  class NluDataImporter(TrainingDataImporter):
237
258
  """Importer that skips any Core-related file reading."""
@@ -448,6 +469,10 @@ class FlowSyncImporter(PassThroughImporter):
448
469
 
449
470
  return self.merge_with_default_flows(flows)
450
471
 
472
+ @rasa.shared.utils.common.cached_method
473
+ def get_user_flows(self) -> FlowsList:
474
+ return self._importer.get_flows()
475
+
451
476
  @rasa.shared.utils.common.cached_method
452
477
  def get_domain(self) -> Domain:
453
478
  """Merge existing domain with properties of flows."""
@@ -476,6 +501,11 @@ class FlowSyncImporter(PassThroughImporter):
476
501
  )
477
502
  return domain
478
503
 
504
+ @rasa.shared.utils.common.cached_method
505
+ def get_user_domain(self) -> Domain:
506
+ """Retrieves only user defined domain."""
507
+ return self._importer.get_domain()
508
+
479
509
 
480
510
  class ResponsesSyncImporter(PassThroughImporter):
481
511
  """Importer that syncs `responses` between Domain and NLU training data.
@@ -602,6 +632,15 @@ class E2EImporter(PassThroughImporter):
602
632
  - adds potential end-to-end bot messages from stories as actions to the domain
603
633
  """
604
634
 
635
+ @rasa.shared.utils.common.cached_method
636
+ def get_user_flows(self) -> FlowsList:
637
+ if not isinstance(self._importer, FlowSyncImporter):
638
+ raise NotImplementedError(
639
+ "Accessing user flows is only supported with FlowSyncImporter."
640
+ )
641
+
642
+ return self._importer.get_user_flows()
643
+
605
644
  @rasa.shared.utils.common.cached_method
606
645
  def get_domain(self) -> Domain:
607
646
  """Retrieves model domain (see parent class for full docstring)."""
@@ -610,6 +649,15 @@ class E2EImporter(PassThroughImporter):
610
649
 
611
650
  return original.merge(e2e_domain)
612
651
 
652
+ @rasa.shared.utils.common.cached_method
653
+ def get_user_domain(self) -> Domain:
654
+ """Retrieves only user defined domain."""
655
+ if not isinstance(self._importer, FlowSyncImporter):
656
+ raise NotImplementedError(
657
+ "Accessing user domain is only supported with FlowSyncImporter."
658
+ )
659
+ return self._importer.get_user_domain()
660
+
613
661
  def _get_domain_with_e2e_actions(self) -> Domain:
614
662
  stories = self.get_stories()
615
663
 
@@ -2,6 +2,8 @@ TEXT = "text"
2
2
  TEXT_TOKENS = "text_tokens"
3
3
  INTENT = "intent"
4
4
  COMMANDS = "commands"
5
+ LLM_COMMANDS = "llm_commands" # needed for fine-tuning
6
+ LLM_PROMPT = "llm_prompt" # needed for fine-tuning
5
7
  FLOWS_FROM_SEMANTIC_SEARCH = "flows_from_semantic_search"
6
8
  FLOWS_IN_PROMPT = "flows_in_prompt"
7
9
  NOT_INTENT = "not_intent"
File without changes
@@ -0,0 +1,183 @@
1
+ from dataclasses import asdict, dataclass, field
2
+ from typing import Any, Dict, Optional
3
+
4
+ import structlog
5
+
6
+ from rasa.shared.constants import (
7
+ MODEL_CONFIG_KEY,
8
+ MODEL_NAME_CONFIG_KEY,
9
+ OPENAI_API_BASE_CONFIG_KEY,
10
+ API_BASE_CONFIG_KEY,
11
+ OPENAI_API_TYPE_CONFIG_KEY,
12
+ API_TYPE_CONFIG_KEY,
13
+ OPENAI_API_VERSION_CONFIG_KEY,
14
+ API_VERSION_CONFIG_KEY,
15
+ DEPLOYMENT_CONFIG_KEY,
16
+ DEPLOYMENT_NAME_CONFIG_KEY,
17
+ ENGINE_CONFIG_KEY,
18
+ RASA_TYPE_CONFIG_KEY,
19
+ LANGCHAIN_TYPE_CONFIG_KEY,
20
+ STREAM_CONFIG_KEY,
21
+ N_REPHRASES_CONFIG_KEY,
22
+ REQUEST_TIMEOUT_CONFIG_KEY,
23
+ TIMEOUT_CONFIG_KEY,
24
+ PROVIDER_CONFIG_KEY,
25
+ AZURE_OPENAI_PROVIDER,
26
+ AZURE_API_TYPE,
27
+ )
28
+ from rasa.shared.providers._configs.utils import (
29
+ resolve_aliases,
30
+ raise_deprecation_warnings,
31
+ validate_required_keys,
32
+ validate_forbidden_keys,
33
+ )
34
+
35
+ structlogger = structlog.get_logger()
36
+
37
+ DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING = {
38
+ # Deployment name aliases
39
+ DEPLOYMENT_NAME_CONFIG_KEY: DEPLOYMENT_CONFIG_KEY,
40
+ ENGINE_CONFIG_KEY: DEPLOYMENT_CONFIG_KEY,
41
+ # Provider aliases
42
+ RASA_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,
43
+ LANGCHAIN_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,
44
+ # API type aliases
45
+ OPENAI_API_TYPE_CONFIG_KEY: API_TYPE_CONFIG_KEY,
46
+ # API base aliases
47
+ OPENAI_API_BASE_CONFIG_KEY: API_BASE_CONFIG_KEY,
48
+ # API version aliases
49
+ OPENAI_API_VERSION_CONFIG_KEY: API_VERSION_CONFIG_KEY,
50
+ # Model name aliases
51
+ MODEL_NAME_CONFIG_KEY: MODEL_CONFIG_KEY,
52
+ # Timeout aliases
53
+ REQUEST_TIMEOUT_CONFIG_KEY: TIMEOUT_CONFIG_KEY,
54
+ }
55
+
56
+ REQUIRED_KEYS = [DEPLOYMENT_CONFIG_KEY]
57
+
58
+ FORBIDDEN_KEYS = [
59
+ STREAM_CONFIG_KEY,
60
+ N_REPHRASES_CONFIG_KEY,
61
+ ]
62
+
63
+
64
+ @dataclass
65
+ class AzureOpenAIClientConfig:
66
+ """Parses configuration for Azure OpenAI client, resolves aliases and
67
+ raises deprecation warnings.
68
+
69
+ Raises:
70
+ ValueError: Raised in cases of invalid configuration:
71
+ - If any of the required configuration keys are missing.
72
+ - If `api_type` has a value different from `azure`.
73
+ """
74
+
75
+ deployment: str
76
+
77
+ model: Optional[str]
78
+ api_base: Optional[str]
79
+ api_version: Optional[str]
80
+ # API Type is not used by LiteLLM backend, but we define
81
+ # it here for backward compatibility.
82
+ api_type: Optional[str] = AZURE_API_TYPE
83
+
84
+ # Provider is not used by LiteLLM backend, but we define it here since it's
85
+ # used as switch between different clients.
86
+ provider: str = AZURE_OPENAI_PROVIDER
87
+
88
+ extra_parameters: dict = field(default_factory=dict)
89
+
90
+ def __post_init__(self) -> None:
91
+ if self.provider != AZURE_OPENAI_PROVIDER:
92
+ message = f"Provider must be set to '{AZURE_OPENAI_PROVIDER}'."
93
+ structlogger.error(
94
+ "azure_openai_client_config.validation_error",
95
+ message=message,
96
+ provider=self.provider,
97
+ )
98
+ raise ValueError(message)
99
+ if self.deployment is None:
100
+ message = "Deployment cannot be set to None."
101
+ structlogger.error(
102
+ "azure_openai_client_config.validation_error",
103
+ message=message,
104
+ deployment=self.deployment,
105
+ )
106
+ raise ValueError(message)
107
+
108
+ @classmethod
109
+ def from_dict(cls, config: dict) -> "AzureOpenAIClientConfig":
110
+ """Initializes a dataclass from the passed config.
111
+
112
+ Args:
113
+ config: (dict) The config from which to initialize.
114
+
115
+ Raises:
116
+ ValueError: Raised in cases of invalid configuration:
117
+ - If any of the required configuration keys are missing.
118
+ - If `api_type` has a value different from `azure`.
119
+
120
+ Returns:
121
+ AzureOpenAIClientConfig
122
+ """
123
+ # Check for deprecated keys
124
+ raise_deprecation_warnings(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
125
+ # Resolve any potential aliases
126
+ config = cls.resolve_config_aliases(config)
127
+ # Validate that required keys are set
128
+ validate_required_keys(config, REQUIRED_KEYS)
129
+ # Validate that the forbidden keys are not present
130
+ validate_forbidden_keys(config, FORBIDDEN_KEYS)
131
+ # Init client config
132
+ this = AzureOpenAIClientConfig(
133
+ # Required parameters
134
+ deployment=config.pop(DEPLOYMENT_CONFIG_KEY),
135
+ # Pop the 'provider' key. Currently, it's *optional* because of
136
+ # backward compatibility with older versions.
137
+ provider=config.pop(PROVIDER_CONFIG_KEY, AZURE_OPENAI_PROVIDER),
138
+ # Optional
139
+ api_type=config.pop(API_TYPE_CONFIG_KEY, AZURE_API_TYPE),
140
+ model=config.pop(MODEL_CONFIG_KEY, None),
141
+ # Optional, can also be set through environment variables
142
+ # in clients.
143
+ api_base=config.pop(API_BASE_CONFIG_KEY, None),
144
+ api_version=config.pop(API_VERSION_CONFIG_KEY, None),
145
+ # The rest of parameters (e.g. model parameters) are considered
146
+ # as extra parameters (this also includes timeout).
147
+ extra_parameters=config,
148
+ )
149
+ return this
150
+
151
+ def to_dict(self) -> dict:
152
+ """Converts the config instance into a dictionary."""
153
+ d = asdict(self)
154
+ # Extra parameters should also be on the top level
155
+ d.pop("extra_parameters", None)
156
+ d.update(self.extra_parameters)
157
+ return d
158
+
159
+ @staticmethod
160
+ def resolve_config_aliases(config: Dict[str, Any]) -> Dict[str, Any]:
161
+ return resolve_aliases(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
162
+
163
+
164
+ def is_azure_openai_config(config: dict) -> bool:
165
+ """Check whether the configuration is meant to configure
166
+ an Azure OpenAI client.
167
+ """
168
+ # Resolve any aliases that are specific to Azure OpenAI configuration
169
+ config = AzureOpenAIClientConfig.resolve_config_aliases(config)
170
+
171
+ # Case: Configuration contains `provider: azure`.
172
+ if config.get(PROVIDER_CONFIG_KEY) == AZURE_OPENAI_PROVIDER:
173
+ return True
174
+
175
+ # Case: Configuration contains `deployment` key
176
+ # (specific to Azure OpenAI configuration)
177
+ if (
178
+ config.get(DEPLOYMENT_CONFIG_KEY) is not None
179
+ and config.get(PROVIDER_CONFIG_KEY) is None
180
+ ):
181
+ return True
182
+
183
+ return False
@@ -0,0 +1,57 @@
1
+ from typing import Protocol, runtime_checkable
2
+
3
+
4
+ @runtime_checkable
5
+ class ClientConfig(Protocol):
6
+ """
7
+ Protocol for the client config that specifies the interface for interacting
8
+ with the API.
9
+ """
10
+
11
+ @classmethod
12
+ def from_dict(cls, config: dict) -> "ClientConfig":
13
+ """
14
+ Initializes the client config with the given configuration.
15
+
16
+ This class method should be implemented to parse the given
17
+ configuration and create an instance of an client config.
18
+
19
+ Args:
20
+ config: (dict) The config from which to initialize.
21
+
22
+ Raises:
23
+ ValueError: Config is missing required keys.
24
+
25
+ Returns:
26
+ ClientConfig
27
+ """
28
+ ...
29
+
30
+ def to_dict(self) -> dict:
31
+ """
32
+ Returns the configuration for that the client config is initialized with.
33
+
34
+ This method should be implemented to return a dictionary containing
35
+ the configuration settings for the client config.
36
+
37
+ Returns:
38
+ dictionary containing the configuration settings for the client config.
39
+ """
40
+ ...
41
+
42
+ @staticmethod
43
+ def resolve_config_aliases(config: dict) -> dict:
44
+ """
45
+ Resolve any potential aliases in the configuration.
46
+
47
+ This method should be implemented to resolve any potential aliases in the
48
+ configuration.
49
+
50
+ Args:
51
+ config: (dict) The config from which to initialize.
52
+
53
+ Returns:
54
+ dictionary containing the resolved configuration settings for the
55
+ client config.
56
+ """
57
+ ...
@@ -0,0 +1,130 @@
1
+ from dataclasses import asdict, dataclass, field
2
+ from typing import Any, Dict
3
+
4
+ import structlog
5
+
6
+ from rasa.shared.constants import (
7
+ MODEL_CONFIG_KEY,
8
+ MODEL_NAME_CONFIG_KEY,
9
+ STREAM_CONFIG_KEY,
10
+ N_REPHRASES_CONFIG_KEY,
11
+ PROVIDER_CONFIG_KEY,
12
+ TIMEOUT_CONFIG_KEY,
13
+ REQUEST_TIMEOUT_CONFIG_KEY,
14
+ )
15
+ from rasa.shared.providers._configs.utils import (
16
+ validate_required_keys,
17
+ validate_forbidden_keys,
18
+ resolve_aliases,
19
+ raise_deprecation_warnings,
20
+ )
21
+ import rasa.shared.utils.cli
22
+
23
+ structlogger = structlog.get_logger()
24
+
25
+
26
+ DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING = {
27
+ # Timeout aliases
28
+ REQUEST_TIMEOUT_CONFIG_KEY: TIMEOUT_CONFIG_KEY,
29
+ }
30
+
31
+ REQUIRED_KEYS = [MODEL_CONFIG_KEY, PROVIDER_CONFIG_KEY]
32
+
33
+ FORBIDDEN_KEYS = [
34
+ STREAM_CONFIG_KEY,
35
+ N_REPHRASES_CONFIG_KEY,
36
+ ]
37
+
38
+
39
+ @dataclass
40
+ class DefaultLiteLLMClientConfig:
41
+ """Parses configuration for default LiteLLM client, resolves aliases and
42
+ raises deprecation warnings.
43
+
44
+ Raises:
45
+ ValueError: Raised in cases of invalid configuration:
46
+ - If any of the required configuration keys are missing.
47
+ """
48
+
49
+ model: str
50
+ provider: str
51
+ extra_parameters: dict = field(default_factory=dict)
52
+
53
+ def __post_init__(self) -> None:
54
+ if self.model is None:
55
+ message = "Model cannot be set to None."
56
+ structlogger.error(
57
+ "default_litellm_client_config.validation_error",
58
+ message=message,
59
+ model=self.model,
60
+ )
61
+ raise ValueError(message)
62
+ if self.provider is None:
63
+ message = "Provider cannot be set to None."
64
+ structlogger.error(
65
+ "default_litellm_client_config.validation_error",
66
+ message=message,
67
+ provider=self.provider,
68
+ )
69
+ raise ValueError(message)
70
+
71
+ @classmethod
72
+ def from_dict(cls, config: dict) -> "DefaultLiteLLMClientConfig":
73
+ """
74
+ Initializes a dataclass from the passed config.
75
+
76
+ Args:
77
+ config: (dict) The config from which to initialize.
78
+
79
+ Raises:
80
+ ValueError: Config is missing required keys.
81
+
82
+ Returns:
83
+ DefaultLiteLLMClientConfig
84
+ """
85
+ # Check for deprecated keys
86
+ raise_deprecation_warnings(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
87
+ # Raise error for using `model_name` instead instead of `model`
88
+ cls.check_and_error_for_model_name_in_config(config)
89
+ # Resolve any potential aliases.
90
+ config = cls.resolve_config_aliases(config)
91
+ # Validate that the required keys are present
92
+ validate_required_keys(config, REQUIRED_KEYS)
93
+ # Validate that the forbidden keys are not present
94
+ validate_forbidden_keys(config, FORBIDDEN_KEYS)
95
+ this = DefaultLiteLLMClientConfig(
96
+ # Required parameters
97
+ model=config.pop(MODEL_CONFIG_KEY),
98
+ provider=config.pop(PROVIDER_CONFIG_KEY),
99
+ # The rest of parameters (e.g. model parameters) are considered
100
+ # as extra parameters
101
+ extra_parameters=config,
102
+ )
103
+ return this
104
+
105
+ def to_dict(self) -> dict:
106
+ """Converts the config instance into a dictionary."""
107
+ d = asdict(self)
108
+ # Extra parameters should also be on the top level
109
+ d.pop("extra_parameters", None)
110
+ d.update(self.extra_parameters)
111
+ return d
112
+
113
+ @staticmethod
114
+ def resolve_config_aliases(config: Dict[str, Any]) -> Dict[str, Any]:
115
+ return resolve_aliases(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
116
+
117
+ @staticmethod
118
+ def check_and_error_for_model_name_in_config(config: Dict[str, Any]) -> None:
119
+ """Check for usage of deprecated model_name and raise an error if found."""
120
+ if config.get(MODEL_NAME_CONFIG_KEY) and not config.get(MODEL_CONFIG_KEY):
121
+ event_info = (
122
+ f"Unsupported parameter - {MODEL_NAME_CONFIG_KEY} is set. Please use "
123
+ f"{MODEL_CONFIG_KEY} instead."
124
+ )
125
+ structlogger.error(
126
+ "default_litellm_client_config.unsupported_parameter_in_config",
127
+ event_info=event_info,
128
+ config=config,
129
+ )
130
+ rasa.shared.utils.cli.print_error_and_exit(event_info)