rasa-pro 3.9.17__py3-none-any.whl → 3.10.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic.

Files changed (187)
  1. README.md +5 -37
  2. rasa/__init__.py +1 -2
  3. rasa/__main__.py +5 -0
  4. rasa/anonymization/anonymization_rule_executor.py +2 -2
  5. rasa/api.py +26 -22
  6. rasa/cli/arguments/data.py +27 -2
  7. rasa/cli/arguments/default_arguments.py +25 -3
  8. rasa/cli/arguments/run.py +9 -9
  9. rasa/cli/arguments/train.py +2 -0
  10. rasa/cli/data.py +70 -8
  11. rasa/cli/e2e_test.py +108 -433
  12. rasa/cli/interactive.py +1 -0
  13. rasa/cli/llm_fine_tuning.py +395 -0
  14. rasa/cli/project_templates/calm/endpoints.yml +1 -1
  15. rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
  16. rasa/cli/run.py +14 -13
  17. rasa/cli/scaffold.py +10 -8
  18. rasa/cli/train.py +8 -7
  19. rasa/cli/utils.py +15 -0
  20. rasa/constants.py +7 -1
  21. rasa/core/actions/action.py +98 -49
  22. rasa/core/actions/action_run_slot_rejections.py +4 -1
  23. rasa/core/actions/custom_action_executor.py +9 -6
  24. rasa/core/actions/direct_custom_actions_executor.py +80 -0
  25. rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
  26. rasa/core/actions/grpc_custom_action_executor.py +2 -2
  27. rasa/core/actions/http_custom_action_executor.py +6 -5
  28. rasa/core/agent.py +21 -17
  29. rasa/core/channels/__init__.py +2 -0
  30. rasa/core/channels/audiocodes.py +1 -16
  31. rasa/core/channels/voice_aware/__init__.py +0 -0
  32. rasa/core/channels/voice_aware/jambonz.py +103 -0
  33. rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
  34. rasa/core/channels/voice_aware/utils.py +20 -0
  35. rasa/core/channels/voice_native/__init__.py +0 -0
  36. rasa/core/constants.py +6 -1
  37. rasa/core/featurizers/single_state_featurizer.py +1 -22
  38. rasa/core/featurizers/tracker_featurizers.py +18 -115
  39. rasa/core/information_retrieval/faiss.py +7 -4
  40. rasa/core/information_retrieval/information_retrieval.py +8 -0
  41. rasa/core/information_retrieval/milvus.py +9 -2
  42. rasa/core/information_retrieval/qdrant.py +1 -1
  43. rasa/core/nlg/contextual_response_rephraser.py +32 -10
  44. rasa/core/nlg/summarize.py +4 -3
  45. rasa/core/policies/enterprise_search_policy.py +100 -44
  46. rasa/core/policies/flows/flow_executor.py +155 -98
  47. rasa/core/policies/intentless_policy.py +52 -28
  48. rasa/core/policies/ted_policy.py +33 -58
  49. rasa/core/policies/unexpected_intent_policy.py +7 -15
  50. rasa/core/processor.py +15 -46
  51. rasa/core/run.py +5 -4
  52. rasa/core/tracker_store.py +8 -4
  53. rasa/core/utils.py +45 -56
  54. rasa/dialogue_understanding/coexistence/llm_based_router.py +45 -12
  55. rasa/dialogue_understanding/commands/__init__.py +4 -0
  56. rasa/dialogue_understanding/commands/change_flow_command.py +0 -6
  57. rasa/dialogue_understanding/commands/session_start_command.py +59 -0
  58. rasa/dialogue_understanding/commands/set_slot_command.py +1 -5
  59. rasa/dialogue_understanding/commands/utils.py +38 -0
  60. rasa/dialogue_understanding/generator/constants.py +10 -3
  61. rasa/dialogue_understanding/generator/flow_retrieval.py +14 -5
  62. rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -2
  63. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +106 -87
  64. rasa/dialogue_understanding/generator/nlu_command_adapter.py +28 -6
  65. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +90 -37
  66. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +15 -15
  67. rasa/dialogue_understanding/patterns/session_start.py +37 -0
  68. rasa/dialogue_understanding/processor/command_processor.py +13 -14
  69. rasa/e2e_test/aggregate_test_stats_calculator.py +124 -0
  70. rasa/e2e_test/assertions.py +1181 -0
  71. rasa/e2e_test/assertions_schema.yml +106 -0
  72. rasa/e2e_test/constants.py +20 -0
  73. rasa/e2e_test/e2e_config.py +220 -0
  74. rasa/e2e_test/e2e_config_schema.yml +26 -0
  75. rasa/e2e_test/e2e_test_case.py +131 -8
  76. rasa/e2e_test/e2e_test_converter.py +363 -0
  77. rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
  78. rasa/e2e_test/e2e_test_coverage_report.py +364 -0
  79. rasa/e2e_test/e2e_test_result.py +26 -6
  80. rasa/e2e_test/e2e_test_runner.py +498 -73
  81. rasa/e2e_test/e2e_test_schema.yml +96 -0
  82. rasa/e2e_test/pykwalify_extensions.py +39 -0
  83. rasa/e2e_test/stub_custom_action.py +70 -0
  84. rasa/e2e_test/utils/__init__.py +0 -0
  85. rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
  86. rasa/e2e_test/utils/io.py +596 -0
  87. rasa/e2e_test/utils/validation.py +80 -0
  88. rasa/engine/recipes/default_components.py +0 -2
  89. rasa/engine/storage/local_model_storage.py +0 -1
  90. rasa/env.py +9 -0
  91. rasa/llm_fine_tuning/__init__.py +0 -0
  92. rasa/llm_fine_tuning/annotation_module.py +241 -0
  93. rasa/llm_fine_tuning/conversations.py +144 -0
  94. rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
  95. rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
  96. rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
  97. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
  98. rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
  99. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
  100. rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
  101. rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
  102. rasa/llm_fine_tuning/storage.py +174 -0
  103. rasa/llm_fine_tuning/train_test_split_module.py +441 -0
  104. rasa/model_training.py +48 -16
  105. rasa/nlu/classifiers/diet_classifier.py +25 -38
  106. rasa/nlu/classifiers/logistic_regression_classifier.py +9 -44
  107. rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
  108. rasa/nlu/extractors/crf_entity_extractor.py +50 -93
  109. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -78
  110. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +17 -52
  111. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
  112. rasa/nlu/persistor.py +129 -32
  113. rasa/server.py +45 -10
  114. rasa/shared/constants.py +63 -15
  115. rasa/shared/core/domain.py +15 -12
  116. rasa/shared/core/events.py +28 -2
  117. rasa/shared/core/flows/flow.py +208 -13
  118. rasa/shared/core/flows/flow_path.py +84 -0
  119. rasa/shared/core/flows/flows_list.py +28 -10
  120. rasa/shared/core/flows/flows_yaml_schema.json +269 -193
  121. rasa/shared/core/flows/validation.py +112 -25
  122. rasa/shared/core/flows/yaml_flows_io.py +149 -10
  123. rasa/shared/core/trackers.py +6 -0
  124. rasa/shared/core/training_data/visualization.html +2 -2
  125. rasa/shared/exceptions.py +4 -0
  126. rasa/shared/importers/importer.py +60 -11
  127. rasa/shared/importers/remote_importer.py +196 -0
  128. rasa/shared/nlu/constants.py +2 -0
  129. rasa/shared/nlu/training_data/features.py +2 -120
  130. rasa/shared/providers/_configs/__init__.py +0 -0
  131. rasa/shared/providers/_configs/azure_openai_client_config.py +181 -0
  132. rasa/shared/providers/_configs/client_config.py +57 -0
  133. rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
  134. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
  135. rasa/shared/providers/_configs/openai_client_config.py +175 -0
  136. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +171 -0
  137. rasa/shared/providers/_configs/utils.py +101 -0
  138. rasa/shared/providers/_ssl_verification_utils.py +124 -0
  139. rasa/shared/providers/embedding/__init__.py +0 -0
  140. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +254 -0
  141. rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
  142. rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
  143. rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
  144. rasa/shared/providers/embedding/embedding_client.py +90 -0
  145. rasa/shared/providers/embedding/embedding_response.py +41 -0
  146. rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
  147. rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
  148. rasa/shared/providers/llm/__init__.py +0 -0
  149. rasa/shared/providers/llm/_base_litellm_client.py +227 -0
  150. rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
  151. rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
  152. rasa/shared/providers/llm/llm_client.py +76 -0
  153. rasa/shared/providers/llm/llm_response.py +50 -0
  154. rasa/shared/providers/llm/openai_llm_client.py +155 -0
  155. rasa/shared/providers/llm/self_hosted_llm_client.py +169 -0
  156. rasa/shared/providers/mappings.py +75 -0
  157. rasa/shared/utils/cli.py +30 -0
  158. rasa/shared/utils/io.py +65 -3
  159. rasa/shared/utils/llm.py +223 -200
  160. rasa/shared/utils/yaml.py +122 -7
  161. rasa/studio/download.py +19 -13
  162. rasa/studio/train.py +2 -3
  163. rasa/studio/upload.py +2 -3
  164. rasa/telemetry.py +113 -58
  165. rasa/tracing/config.py +2 -3
  166. rasa/tracing/instrumentation/attribute_extractors.py +29 -17
  167. rasa/tracing/instrumentation/instrumentation.py +4 -47
  168. rasa/utils/common.py +18 -19
  169. rasa/utils/endpoints.py +7 -4
  170. rasa/utils/io.py +66 -0
  171. rasa/utils/json_utils.py +60 -0
  172. rasa/utils/licensing.py +9 -1
  173. rasa/utils/ml_utils.py +4 -2
  174. rasa/utils/tensorflow/model_data.py +193 -2
  175. rasa/validator.py +195 -1
  176. rasa/version.py +1 -1
  177. {rasa_pro-3.9.17.dist-info → rasa_pro-3.10.3.dist-info}/METADATA +25 -51
  178. {rasa_pro-3.9.17.dist-info → rasa_pro-3.10.3.dist-info}/RECORD +183 -119
  179. rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
  180. rasa/shared/providers/openai/clients.py +0 -43
  181. rasa/shared/providers/openai/session_handler.py +0 -110
  182. rasa/utils/tensorflow/feature_array.py +0 -366
  183. /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
  184. /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
  185. {rasa_pro-3.9.17.dist-info → rasa_pro-3.10.3.dist-info}/NOTICE +0 -0
  186. {rasa_pro-3.9.17.dist-info → rasa_pro-3.10.3.dist-info}/WHEEL +0 -0
  187. {rasa_pro-3.9.17.dist-info → rasa_pro-3.10.3.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,130 @@
1
+ from dataclasses import asdict, dataclass, field
2
+ from typing import Any, Dict
3
+
4
+ import structlog
5
+
6
+ from rasa.shared.constants import (
7
+ MODEL_CONFIG_KEY,
8
+ MODEL_NAME_CONFIG_KEY,
9
+ STREAM_CONFIG_KEY,
10
+ N_REPHRASES_CONFIG_KEY,
11
+ PROVIDER_CONFIG_KEY,
12
+ TIMEOUT_CONFIG_KEY,
13
+ REQUEST_TIMEOUT_CONFIG_KEY,
14
+ )
15
+ from rasa.shared.providers._configs.utils import (
16
+ validate_required_keys,
17
+ validate_forbidden_keys,
18
+ resolve_aliases,
19
+ raise_deprecation_warnings,
20
+ )
21
+ import rasa.shared.utils.cli
22
+
23
# Deprecated configuration key -> standard key that replaces it.
# Note: `model_name` is intentionally NOT an alias here; using it is treated
# as an error by `DefaultLiteLLMClientConfig`.
DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING = {
    REQUEST_TIMEOUT_CONFIG_KEY: TIMEOUT_CONFIG_KEY,  # timeout alias
}

# Keys that must be present once aliases have been resolved.
REQUIRED_KEYS = [
    MODEL_CONFIG_KEY,
    PROVIDER_CONFIG_KEY,
]

# Keys this client rejects outright.
FORBIDDEN_KEYS = [STREAM_CONFIG_KEY, N_REPHRASES_CONFIG_KEY]

structlogger = structlog.get_logger()
37
+
38
+
39
@dataclass
class DefaultLiteLLMClientConfig:
    """Parses configuration for default LiteLLM client, resolves aliases and
    raises deprecation warnings.

    Attributes:
        model: Name of the model. Must not be `None`.
        provider: Name of the provider. Must not be `None`.
        extra_parameters: All remaining configuration keys (e.g. model
            parameters). `to_dict` flattens them back to the top level.

    Raises:
        ValueError: Raised in cases of invalid configuration:
            - If any of the required configuration keys are missing.
            - If `model` or `provider` is set to `None`.
    """

    model: str
    provider: str
    extra_parameters: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate that `model` and `provider` are not `None`."""
        if self.model is None:
            message = "Model cannot be set to None."
            structlogger.error(
                "default_litellm_client_config.validation_error",
                message=message,
                model=self.model,
            )
            raise ValueError(message)
        if self.provider is None:
            message = "Provider cannot be set to None."
            structlogger.error(
                "default_litellm_client_config.validation_error",
                message=message,
                provider=self.provider,
            )
            raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "DefaultLiteLLMClientConfig":
        """
        Initializes a dataclass from the passed config.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys, or `model`/`provider`
                is `None`. Forbidden keys are rejected by
                `validate_forbidden_keys`.

        Returns:
            An instance of `cls` (DefaultLiteLLMClientConfig or a subclass).
        """
        # Check for deprecated keys
        raise_deprecation_warnings(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
        # Raise error for using `model_name` instead of `model`; this calls
        # `print_error_and_exit`.
        cls.check_and_error_for_model_name_in_config(config)
        # Resolve any potential aliases.
        config = cls.resolve_config_aliases(config)
        # Validate that the required keys are present
        validate_required_keys(config, REQUIRED_KEYS)
        # Validate that the forbidden keys are not present
        validate_forbidden_keys(config, FORBIDDEN_KEYS)
        # Instantiate via `cls` so subclasses get instances of their own type.
        this = cls(
            # Required parameters
            model=config.pop(MODEL_CONFIG_KEY),
            provider=config.pop(PROVIDER_CONFIG_KEY),
            # The rest of parameters (e.g. model parameters) are considered
            # as extra parameters
            extra_parameters=config,
        )
        return this

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary.

        `extra_parameters` is not kept as a nested key; its entries are merged
        into the top level of the returned dict.
        """
        d = asdict(self)
        # Extra parameters should also be on the top level
        d.pop("extra_parameters", None)
        d.update(self.extra_parameters)
        return d

    @staticmethod
    def resolve_config_aliases(config: Dict[str, Any]) -> Dict[str, Any]:
        """Map deprecated keys in `config` onto their standard names."""
        return resolve_aliases(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)

    @staticmethod
    def check_and_error_for_model_name_in_config(config: Dict[str, Any]) -> None:
        """Check for usage of deprecated model_name and raise an error if found.

        The error is only triggered when `model_name` is set and `model` is
        not. It logs the offending config and calls `print_error_and_exit`.
        """
        if config.get(MODEL_NAME_CONFIG_KEY) and not config.get(MODEL_CONFIG_KEY):
            event_info = (
                f"Unsupported parameter - {MODEL_NAME_CONFIG_KEY} is set. Please use "
                f"{MODEL_CONFIG_KEY} instead."
            )
            structlogger.error(
                "default_litellm_client_config.unsupported_parameter_in_config",
                event_info=event_info,
                config=config,
            )
            rasa.shared.utils.cli.print_error_and_exit(event_info)
@@ -0,0 +1,234 @@
1
+ from dataclasses import asdict, dataclass, field
2
+ from typing import Any, Dict, Optional
3
+
4
+ import structlog
5
+
6
+ from rasa.shared.constants import (
7
+ MODEL_CONFIG_KEY,
8
+ MODEL_NAME_CONFIG_KEY,
9
+ RASA_TYPE_CONFIG_KEY,
10
+ LANGCHAIN_TYPE_CONFIG_KEY,
11
+ HUGGINGFACE_MULTIPROCESS_CONFIG_KEY,
12
+ HUGGINGFACE_CACHE_FOLDER_CONFIG_KEY,
13
+ HUGGINGFACE_SHOW_PROGRESS_CONFIG_KEY,
14
+ HUGGINGFACE_MODEL_KWARGS_CONFIG_KEY,
15
+ HUGGINGFACE_ENCODE_KWARGS_CONFIG_KEY,
16
+ HUGGINGFACE_LOCAL_EMBEDDING_CACHING_FOLDER,
17
+ PROVIDER_CONFIG_KEY,
18
+ HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER,
19
+ TIMEOUT_CONFIG_KEY,
20
+ REQUEST_TIMEOUT_CONFIG_KEY,
21
+ )
22
+ from rasa.shared.providers._configs.utils import (
23
+ resolve_aliases,
24
+ raise_deprecation_warnings,
25
+ validate_required_keys,
26
+ )
27
+ from rasa.shared.utils.io import raise_deprecation_warning
28
+
29
# Legacy value of the `type` / `_type` switch that selected the local
# HuggingFace embedding client.
DEPRECATED_HUGGINGFACE_TYPE = "huggingface"

# Deprecated configuration key -> standard key that replaces it.
DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING = {
    RASA_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,  # provider alias
    LANGCHAIN_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,  # provider alias
    MODEL_NAME_CONFIG_KEY: MODEL_CONFIG_KEY,  # model name alias
    REQUEST_TIMEOUT_CONFIG_KEY: TIMEOUT_CONFIG_KEY,  # timeout alias
}

# Keys that must be present once aliases have been resolved.
REQUIRED_KEYS = [
    MODEL_CONFIG_KEY,
    PROVIDER_CONFIG_KEY,
]

structlogger = structlog.get_logger()
44
+
45
+
46
@dataclass
class HuggingFaceLocalEmbeddingClientConfig:
    """Parses configuration for HuggingFace local embeddings client, resolves
    aliases and raises deprecation warnings.

    Raises:
        ValueError: Raised in cases of invalid configuration:
            - If any of the required configuration keys are missing.
            - If `provider` has a value different from `huggingface_local`.
              (`from_dict` first translates the deprecated `type: huggingface`
              / `_type: huggingface` settings into that value.)
            - If `model` is set to `None`.
    """

    model: str

    multi_process: Optional[bool]
    cache_folder: Optional[str]
    show_progress: Optional[bool]

    # Provider is not actually used by sentence-transformers, but we define
    # it here because it's used as a switch denominator for HuggingFace
    # local embedding client.
    provider: str = HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER

    model_kwargs: dict = field(default_factory=dict)
    encode_kwargs: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate the provider switch value and that `model` is set."""
        if self.provider != HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER:
            # The validated field is `provider`, so the message must name it.
            message = (
                f"Provider must be set to '{HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER}'."
            )
            structlogger.error(
                "huggingface_local_embeddings_client_config.validation_error",
                message=message,
                provider=self.provider,
            )
            raise ValueError(message)
        if self.model is None:
            message = "Model cannot be set to None."
            structlogger.error(
                "huggingface_local_embeddings_client_config.validation_error",
                message=message,
                model=self.model,
            )
            raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "HuggingFaceLocalEmbeddingClientConfig":
        """
        Initializes a dataclass from the passed config.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys, or the resolved values
                fail validation in `__post_init__`.

        Returns:
            An instance of `cls` (HuggingFaceLocalEmbeddingClientConfig or a
            subclass).
        """
        # Check for usage of deprecated switching key and value:
        # 1. type: huggingface
        # 2. _type: huggingface
        _raise_deprecation_warning_for_huggingface_deprecated_switch_value(config)
        # Check for other deprecated keys
        raise_deprecation_warnings(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
        # Resolve any potential aliases
        config = cls.resolve_config_aliases(config)
        # Validate that required keys are set
        validate_required_keys(config, REQUIRED_KEYS)
        # Instantiate via `cls` so subclasses get instances of their own type.
        this = cls(
            # Required parameters
            model=config.pop(MODEL_CONFIG_KEY),
            provider=config.pop(PROVIDER_CONFIG_KEY),
            # Optional
            multi_process=config.pop(HUGGINGFACE_MULTIPROCESS_CONFIG_KEY, False),
            cache_folder=config.pop(
                HUGGINGFACE_CACHE_FOLDER_CONFIG_KEY,
                str(HUGGINGFACE_LOCAL_EMBEDDING_CACHING_FOLDER),
            ),
            show_progress=config.pop(HUGGINGFACE_SHOW_PROGRESS_CONFIG_KEY, False),
            model_kwargs=config.pop(HUGGINGFACE_MODEL_KWARGS_CONFIG_KEY, {}),
            encode_kwargs=config.pop(HUGGINGFACE_ENCODE_KWARGS_CONFIG_KEY, {}),
        )
        return this

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary."""
        return asdict(self)

    @staticmethod
    def resolve_config_aliases(config: Dict[str, Any]) -> Dict[str, Any]:
        """Translate the deprecated HuggingFace switch, then map other
        deprecated keys onto their standard names.
        """
        config = _resolve_huggingface_deprecated_switch_value(config)
        return resolve_aliases(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
140
+
141
+
142
def is_huggingface_local_config(config: dict) -> bool:
    """Tell whether `config` targets the local HuggingFace embedding client.

    The deprecated switches `type: huggingface` and `_type: huggingface` are
    normalised first, so they are recognised as `provider: huggingface_local`
    just like the current form.

    Args:
        config: Raw client configuration.

    Returns:
        True if the resolved provider is the HuggingFace local embedding
        provider, False otherwise.
    """
    resolved = HuggingFaceLocalEmbeddingClientConfig.resolve_config_aliases(config)
    provider = resolved.get(PROVIDER_CONFIG_KEY)
    return provider in (HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER,)
162
+
163
+
164
def _raise_deprecation_warning_for_huggingface_deprecated_switch_value(
    config: dict,
) -> None:
    """Warn once for every deprecated HuggingFace switch found in `config`.

    A key counts when it is one of the `type` / `_type` switch keys and its
    value equals `DEPRECATED_HUGGINGFACE_TYPE`. `config` is only read.
    """
    flagged_keys = [
        switch_key
        for switch_key in (RASA_TYPE_CONFIG_KEY, LANGCHAIN_TYPE_CONFIG_KEY)
        if switch_key in config
        and config[switch_key] == DEPRECATED_HUGGINGFACE_TYPE
    ]
    for switch_key in flagged_keys:
        raise_deprecation_warning(
            message=(
                f"Configuration "
                f"`{switch_key}: {DEPRECATED_HUGGINGFACE_TYPE}` "
                f"is deprecated and will be removed in 4.0.0. "
                f"Please use "
                f"`{PROVIDER_CONFIG_KEY}: {HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER}` "
                f"instead."
            )
        )
186
+
187
+
188
def _resolve_huggingface_deprecated_switch_value(config: dict) -> dict:
    """
    Translate the deprecated HuggingFace switch into the current one.

    Each of `type: huggingface` / `_type: huggingface` that is present is
    removed and replaced by `provider: huggingface_local`. A debug event is
    logged for every key that gets translated.

    Args:
        config: given config; it is not modified, a shallow copy is edited.

    Returns:
        New config with resolved switch mechanism
    """
    resolved = config.copy()

    for switch_key in (RASA_TYPE_CONFIG_KEY, LANGCHAIN_TYPE_CONFIG_KEY):
        is_deprecated_switch = (
            switch_key in resolved
            and resolved[switch_key] == DEPRECATED_HUGGINGFACE_TYPE
        )
        if not is_deprecated_switch:
            continue

        # Replace the legacy switch with the `provider` key.
        resolved[PROVIDER_CONFIG_KEY] = HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER
        resolved.pop(switch_key, None)

        structlogger.debug(
            "HuggingFaceLocalEmbeddingClientConfig"
            "._resolve_huggingface_deprecated_switch_value",
            message=(
                f"Switching "
                f"`{switch_key}: {DEPRECATED_HUGGINGFACE_TYPE}` "
                f"to `{PROVIDER_CONFIG_KEY}: {HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER}`."
            ),
            new_config=resolved,
        )

    return resolved
@@ -0,0 +1,175 @@
1
+ from dataclasses import asdict, dataclass, field
2
+ from typing import Any, Dict, Optional
3
+
4
+ import structlog
5
+
6
+ from rasa.shared.constants import (
7
+ MODEL_CONFIG_KEY,
8
+ MODEL_NAME_CONFIG_KEY,
9
+ OPENAI_API_BASE_CONFIG_KEY,
10
+ API_BASE_CONFIG_KEY,
11
+ OPENAI_API_TYPE_CONFIG_KEY,
12
+ API_TYPE_CONFIG_KEY,
13
+ OPENAI_API_VERSION_CONFIG_KEY,
14
+ API_VERSION_CONFIG_KEY,
15
+ RASA_TYPE_CONFIG_KEY,
16
+ LANGCHAIN_TYPE_CONFIG_KEY,
17
+ STREAM_CONFIG_KEY,
18
+ N_REPHRASES_CONFIG_KEY,
19
+ REQUEST_TIMEOUT_CONFIG_KEY,
20
+ TIMEOUT_CONFIG_KEY,
21
+ PROVIDER_CONFIG_KEY,
22
+ OPENAI_PROVIDER,
23
+ OPENAI_API_TYPE,
24
+ )
25
+ from rasa.shared.providers._configs.utils import (
26
+ resolve_aliases,
27
+ validate_required_keys,
28
+ raise_deprecation_warnings,
29
+ validate_forbidden_keys,
30
+ )
31
+
32
# Deprecated configuration key -> standard key that replaces it.
DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING = {
    MODEL_NAME_CONFIG_KEY: MODEL_CONFIG_KEY,  # model name alias
    RASA_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,  # provider alias
    LANGCHAIN_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,  # provider alias
    OPENAI_API_TYPE_CONFIG_KEY: API_TYPE_CONFIG_KEY,  # API type alias
    OPENAI_API_BASE_CONFIG_KEY: API_BASE_CONFIG_KEY,  # API base alias
    OPENAI_API_VERSION_CONFIG_KEY: API_VERSION_CONFIG_KEY,  # API version alias
    REQUEST_TIMEOUT_CONFIG_KEY: TIMEOUT_CONFIG_KEY,  # timeout alias
}

# Keys that must be present once aliases have been resolved.
REQUIRED_KEYS = [MODEL_CONFIG_KEY]

# Keys this client rejects outright.
FORBIDDEN_KEYS = [STREAM_CONFIG_KEY, N_REPHRASES_CONFIG_KEY]

structlogger = structlog.get_logger()
57
+
58
+
59
@dataclass
class OpenAIClientConfig:
    """Parses configuration for OpenAI client, resolves aliases and
    raises deprecation warnings.

    Raises:
        ValueError: Raised in cases of invalid configuration:
            - If any of the required configuration keys are missing.
            - If `api_type` has a value different from `openai`.
            - If `provider` has a value different from `openai`.
            - If `model` is set to `None`.
    """

    model: str
    api_base: Optional[str]
    api_version: Optional[str]

    # API Type is not actually used by LiteLLM backend, but we define
    # it here for backward compatibility.
    api_type: str = OPENAI_API_TYPE

    # Provider is not used by LiteLLM backend, but we define
    # it here since it's used as switch between different
    # clients
    provider: str = OPENAI_PROVIDER

    extra_parameters: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate `api_type`, `provider` and `model`."""
        # In case of OpenAI hosting, it doesn't make sense
        # for API type to be anything other than 'openai'
        if self.api_type != OPENAI_API_TYPE:
            message = f"API type must be set to '{OPENAI_API_TYPE}'."
            structlogger.error(
                "openai_client_config.validation_error",
                message=message,
                api_type=self.api_type,
            )
            raise ValueError(message)
        if self.provider != OPENAI_PROVIDER:
            message = f"Provider must be set to '{OPENAI_PROVIDER}'."
            structlogger.error(
                "openai_client_config.validation_error",
                message=message,
                provider=self.provider,
            )
            raise ValueError(message)
        if self.model is None:
            message = "Model cannot be set to None."
            structlogger.error(
                "openai_client_config.validation_error",
                message=message,
                model=self.model,
            )
            raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "OpenAIClientConfig":
        """
        Initializes a dataclass from the passed config.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys, or the resolved values
                fail validation in `__post_init__`. Forbidden keys are rejected
                by `validate_forbidden_keys`.

        Returns:
            An instance of `cls` (OpenAIClientConfig or a subclass).
        """
        # Check for deprecated keys
        raise_deprecation_warnings(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
        # Resolve any potential aliases
        config = cls.resolve_config_aliases(config)
        # Validate that the required keys are present
        validate_required_keys(config, REQUIRED_KEYS)
        # Validate that the forbidden keys are not present
        validate_forbidden_keys(config, FORBIDDEN_KEYS)
        # Instantiate via `cls` so subclasses get instances of their own type.
        this = cls(
            # Required parameters
            model=config.pop(MODEL_CONFIG_KEY),
            # Pop the 'provider' key. Currently, it's *optional* because of
            # backward compatibility with older versions.
            provider=config.pop(PROVIDER_CONFIG_KEY, OPENAI_PROVIDER),
            # Optional parameters
            api_base=config.pop(API_BASE_CONFIG_KEY, None),
            api_version=config.pop(API_VERSION_CONFIG_KEY, None),
            api_type=config.pop(API_TYPE_CONFIG_KEY, OPENAI_API_TYPE),
            # The rest of parameters (e.g. model parameters) are considered
            # as extra parameters (this also includes timeout).
            extra_parameters=config,
        )
        return this

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary.

        `extra_parameters` is not kept as a nested key; its entries are merged
        into the top level of the returned dict.
        """
        d = asdict(self)
        # Extra parameters should also be on the top level
        d.pop("extra_parameters", None)
        d.update(self.extra_parameters)
        return d

    @staticmethod
    def resolve_config_aliases(config: Dict[str, Any]) -> Dict[str, Any]:
        """Map deprecated keys in `config` onto their standard names."""
        return resolve_aliases(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
162
+
163
+
164
def is_openai_config(config: dict) -> bool:
    """Report whether `config` selects the OpenAI client.

    Deprecated aliases from `DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING` are
    resolved before the provider is inspected.

    Args:
        config: Raw client configuration.

    Returns:
        True when the resolved provider equals `OPENAI_PROVIDER`.
    """
    resolved = OpenAIClientConfig.resolve_config_aliases(config)
    return bool(resolved.get(PROVIDER_CONFIG_KEY) == OPENAI_PROVIDER)
@@ -0,0 +1,171 @@
1
+ from dataclasses import asdict, dataclass, field
2
+ from typing import Any, Dict, Optional
3
+
4
+ import structlog
5
+
6
+ from rasa.shared.constants import (
7
+ MODEL_CONFIG_KEY,
8
+ MODEL_NAME_CONFIG_KEY,
9
+ OPENAI_API_BASE_CONFIG_KEY,
10
+ API_BASE_CONFIG_KEY,
11
+ OPENAI_API_TYPE_CONFIG_KEY,
12
+ API_TYPE_CONFIG_KEY,
13
+ OPENAI_API_VERSION_CONFIG_KEY,
14
+ API_VERSION_CONFIG_KEY,
15
+ RASA_TYPE_CONFIG_KEY,
16
+ LANGCHAIN_TYPE_CONFIG_KEY,
17
+ STREAM_CONFIG_KEY,
18
+ N_REPHRASES_CONFIG_KEY,
19
+ REQUEST_TIMEOUT_CONFIG_KEY,
20
+ TIMEOUT_CONFIG_KEY,
21
+ PROVIDER_CONFIG_KEY,
22
+ OPENAI_PROVIDER,
23
+ SELF_HOSTED_PROVIDER,
24
+ )
25
+ from rasa.shared.providers._configs.utils import (
26
+ raise_deprecation_warnings,
27
+ resolve_aliases,
28
+ validate_forbidden_keys,
29
+ validate_required_keys,
30
+ )
31
+
32
# Deprecated configuration key -> standard key that replaces it.
DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING = {
    MODEL_NAME_CONFIG_KEY: MODEL_CONFIG_KEY,  # model name alias
    RASA_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,  # provider alias
    LANGCHAIN_TYPE_CONFIG_KEY: PROVIDER_CONFIG_KEY,  # provider alias
    OPENAI_API_TYPE_CONFIG_KEY: API_TYPE_CONFIG_KEY,  # API type alias
    OPENAI_API_BASE_CONFIG_KEY: API_BASE_CONFIG_KEY,  # API base alias
    OPENAI_API_VERSION_CONFIG_KEY: API_VERSION_CONFIG_KEY,  # API version alias
    REQUEST_TIMEOUT_CONFIG_KEY: TIMEOUT_CONFIG_KEY,  # timeout alias
}

# Keys that must be present once aliases have been resolved.
REQUIRED_KEYS = [
    API_BASE_CONFIG_KEY,
    MODEL_CONFIG_KEY,
    PROVIDER_CONFIG_KEY,
]

# Keys this client rejects outright.
FORBIDDEN_KEYS = [STREAM_CONFIG_KEY, N_REPHRASES_CONFIG_KEY]

structlogger = structlog.get_logger()
57
+
58
+
59
@dataclass
class SelfHostedLLMClientConfig:
    """Parses configuration for Self Hosted LiteLLM client, resolves aliases and
    raises deprecation warnings.

    Raises:
        ValueError: Raised in cases of invalid configuration:
            - If any of the required configuration keys are missing.
            - If `model`, `provider` or `api_base` is set to `None`.
            - If `api_type` differs from `OPENAI_PROVIDER` (only
              OpenAI-compatible endpoints are currently supported).
    """

    model: str
    provider: str
    api_base: str
    api_version: Optional[str] = None
    # Only OpenAI-compatible endpoints are supported; see `__post_init__`.
    api_type: Optional[str] = OPENAI_PROVIDER
    extra_parameters: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate required fields and the supported `api_type`."""
        if self.model is None:
            message = "Model cannot be set to None."
            structlogger.error(
                "self_hosted_llm_client_config.validation_error",
                message=message,
                model=self.model,
            )
            raise ValueError(message)
        if self.provider is None:
            message = "Provider cannot be set to None."
            structlogger.error(
                "self_hosted_llm_client_config.validation_error",
                message=message,
                provider=self.provider,
            )
            raise ValueError(message)
        if self.api_base is None:
            message = "API base cannot be set to None."
            # Log the field that actually failed validation.
            structlogger.error(
                "self_hosted_llm_client_config.validation_error",
                message=message,
                api_base=self.api_base,
            )
            raise ValueError(message)
        if self.api_type != OPENAI_PROVIDER:
            message = (
                f"Currently supports only {OPENAI_PROVIDER} endpoints. "
                f"API type must be set to '{OPENAI_PROVIDER}'."
            )
            structlogger.error(
                "self_hosted_llm_client_config.validation_error",
                message=message,
                api_type=self.api_type,
            )
            raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "SelfHostedLLMClientConfig":
        """
        Initializes a dataclass from the passed config.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys, or the resolved values
                fail validation in `__post_init__`. Forbidden keys are rejected
                by `validate_forbidden_keys`.

        Returns:
            An instance of `cls` (SelfHostedLLMClientConfig or a subclass).
        """
        # Check for deprecated keys
        raise_deprecation_warnings(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
        # Resolve any potential aliases
        config = cls.resolve_config_aliases(config)
        # Validate that the required keys are present
        validate_required_keys(config, REQUIRED_KEYS)
        # Validate that the forbidden keys are not present
        validate_forbidden_keys(config, FORBIDDEN_KEYS)
        # Instantiate via `cls` so subclasses get instances of their own type.
        this = cls(
            # Required parameters
            model=config.pop(MODEL_CONFIG_KEY),
            provider=config.pop(PROVIDER_CONFIG_KEY),
            api_base=config.pop(API_BASE_CONFIG_KEY),
            # Optional parameters
            api_type=config.pop(API_TYPE_CONFIG_KEY, OPENAI_PROVIDER),
            api_version=config.pop(API_VERSION_CONFIG_KEY, None),
            # The rest of parameters (e.g. model parameters) are considered
            # as extra parameters
            extra_parameters=config,
        )
        return this

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary.

        `extra_parameters` is not kept as a nested key; its entries are merged
        into the top level of the returned dict.
        """
        d = asdict(self)
        # Extra parameters should also be on the top level
        d.pop("extra_parameters", None)
        d.update(self.extra_parameters)
        return d

    @staticmethod
    def resolve_config_aliases(config: Dict[str, Any]) -> Dict[str, Any]:
        """Map deprecated keys in `config` onto their standard names."""
        return resolve_aliases(config, DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING)
160
+
161
+
162
def is_self_hosted_config(config: dict) -> bool:
    """Report whether `config` selects a self-hosted client.

    Deprecated aliases from `DEPRECATED_ALIASES_TO_STANDARD_KEY_MAPPING` are
    resolved before the provider is inspected.

    Args:
        config: Raw client configuration.

    Returns:
        True when the resolved provider equals `SELF_HOSTED_PROVIDER`.
    """
    resolved = SelfHostedLLMClientConfig.resolve_config_aliases(config)
    return bool(resolved.get(PROVIDER_CONFIG_KEY) == SELF_HOSTED_PROVIDER)