rasa-pro: rasa_pro-3.9.18-py3-none-any.whl → rasa_pro-3.10.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. See the original diff page for more details.

Files changed (189)
  1. README.md +26 -57
  2. rasa/__init__.py +1 -2
  3. rasa/__main__.py +5 -0
  4. rasa/anonymization/anonymization_rule_executor.py +2 -2
  5. rasa/api.py +26 -22
  6. rasa/cli/arguments/data.py +27 -2
  7. rasa/cli/arguments/default_arguments.py +25 -3
  8. rasa/cli/arguments/run.py +9 -9
  9. rasa/cli/arguments/train.py +2 -0
  10. rasa/cli/data.py +70 -8
  11. rasa/cli/e2e_test.py +108 -433
  12. rasa/cli/interactive.py +1 -0
  13. rasa/cli/llm_fine_tuning.py +395 -0
  14. rasa/cli/project_templates/calm/endpoints.yml +1 -1
  15. rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
  16. rasa/cli/run.py +14 -13
  17. rasa/cli/scaffold.py +10 -8
  18. rasa/cli/train.py +8 -7
  19. rasa/cli/utils.py +15 -0
  20. rasa/constants.py +7 -1
  21. rasa/core/actions/action.py +98 -49
  22. rasa/core/actions/action_run_slot_rejections.py +4 -1
  23. rasa/core/actions/custom_action_executor.py +9 -6
  24. rasa/core/actions/direct_custom_actions_executor.py +80 -0
  25. rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
  26. rasa/core/actions/grpc_custom_action_executor.py +2 -2
  27. rasa/core/actions/http_custom_action_executor.py +6 -5
  28. rasa/core/agent.py +21 -17
  29. rasa/core/channels/__init__.py +2 -0
  30. rasa/core/channels/audiocodes.py +1 -16
  31. rasa/core/channels/inspector/dist/index.html +0 -2
  32. rasa/core/channels/inspector/index.html +0 -2
  33. rasa/core/channels/voice_aware/__init__.py +0 -0
  34. rasa/core/channels/voice_aware/jambonz.py +103 -0
  35. rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
  36. rasa/core/channels/voice_aware/utils.py +20 -0
  37. rasa/core/channels/voice_native/__init__.py +0 -0
  38. rasa/core/constants.py +6 -1
  39. rasa/core/featurizers/single_state_featurizer.py +1 -22
  40. rasa/core/featurizers/tracker_featurizers.py +18 -115
  41. rasa/core/information_retrieval/faiss.py +7 -4
  42. rasa/core/information_retrieval/information_retrieval.py +8 -0
  43. rasa/core/information_retrieval/milvus.py +9 -2
  44. rasa/core/information_retrieval/qdrant.py +1 -1
  45. rasa/core/nlg/contextual_response_rephraser.py +32 -10
  46. rasa/core/nlg/summarize.py +4 -3
  47. rasa/core/policies/enterprise_search_policy.py +100 -44
  48. rasa/core/policies/flows/flow_executor.py +130 -94
  49. rasa/core/policies/intentless_policy.py +52 -28
  50. rasa/core/policies/ted_policy.py +33 -58
  51. rasa/core/policies/unexpected_intent_policy.py +7 -15
  52. rasa/core/processor.py +20 -53
  53. rasa/core/run.py +5 -4
  54. rasa/core/tracker_store.py +8 -4
  55. rasa/core/utils.py +45 -56
  56. rasa/dialogue_understanding/coexistence/llm_based_router.py +45 -12
  57. rasa/dialogue_understanding/commands/__init__.py +4 -0
  58. rasa/dialogue_understanding/commands/change_flow_command.py +0 -6
  59. rasa/dialogue_understanding/commands/session_start_command.py +59 -0
  60. rasa/dialogue_understanding/commands/set_slot_command.py +1 -5
  61. rasa/dialogue_understanding/commands/utils.py +38 -0
  62. rasa/dialogue_understanding/generator/constants.py +10 -3
  63. rasa/dialogue_understanding/generator/flow_retrieval.py +14 -5
  64. rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -2
  65. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +106 -87
  66. rasa/dialogue_understanding/generator/nlu_command_adapter.py +28 -6
  67. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +90 -37
  68. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +15 -15
  69. rasa/dialogue_understanding/patterns/session_start.py +37 -0
  70. rasa/dialogue_understanding/processor/command_processor.py +13 -14
  71. rasa/e2e_test/aggregate_test_stats_calculator.py +124 -0
  72. rasa/e2e_test/assertions.py +1181 -0
  73. rasa/e2e_test/assertions_schema.yml +106 -0
  74. rasa/e2e_test/constants.py +20 -0
  75. rasa/e2e_test/e2e_config.py +220 -0
  76. rasa/e2e_test/e2e_config_schema.yml +26 -0
  77. rasa/e2e_test/e2e_test_case.py +131 -8
  78. rasa/e2e_test/e2e_test_converter.py +363 -0
  79. rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
  80. rasa/e2e_test/e2e_test_coverage_report.py +364 -0
  81. rasa/e2e_test/e2e_test_result.py +26 -6
  82. rasa/e2e_test/e2e_test_runner.py +491 -72
  83. rasa/e2e_test/e2e_test_schema.yml +96 -0
  84. rasa/e2e_test/pykwalify_extensions.py +39 -0
  85. rasa/e2e_test/stub_custom_action.py +70 -0
  86. rasa/e2e_test/utils/__init__.py +0 -0
  87. rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
  88. rasa/e2e_test/utils/io.py +596 -0
  89. rasa/e2e_test/utils/validation.py +80 -0
  90. rasa/engine/recipes/default_components.py +0 -2
  91. rasa/engine/storage/local_model_storage.py +0 -1
  92. rasa/env.py +9 -0
  93. rasa/llm_fine_tuning/__init__.py +0 -0
  94. rasa/llm_fine_tuning/annotation_module.py +241 -0
  95. rasa/llm_fine_tuning/conversations.py +144 -0
  96. rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
  97. rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
  98. rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
  99. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
  100. rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
  101. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
  102. rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
  103. rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
  104. rasa/llm_fine_tuning/storage.py +174 -0
  105. rasa/llm_fine_tuning/train_test_split_module.py +441 -0
  106. rasa/model_training.py +48 -16
  107. rasa/nlu/classifiers/diet_classifier.py +25 -38
  108. rasa/nlu/classifiers/logistic_regression_classifier.py +9 -44
  109. rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
  110. rasa/nlu/extractors/crf_entity_extractor.py +50 -93
  111. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -78
  112. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +17 -52
  113. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
  114. rasa/nlu/persistor.py +129 -32
  115. rasa/server.py +45 -10
  116. rasa/shared/constants.py +63 -15
  117. rasa/shared/core/domain.py +15 -12
  118. rasa/shared/core/events.py +28 -2
  119. rasa/shared/core/flows/flow.py +208 -13
  120. rasa/shared/core/flows/flow_path.py +84 -0
  121. rasa/shared/core/flows/flows_list.py +28 -10
  122. rasa/shared/core/flows/flows_yaml_schema.json +269 -193
  123. rasa/shared/core/flows/validation.py +112 -25
  124. rasa/shared/core/flows/yaml_flows_io.py +149 -10
  125. rasa/shared/core/trackers.py +6 -0
  126. rasa/shared/core/training_data/visualization.html +2 -2
  127. rasa/shared/exceptions.py +4 -0
  128. rasa/shared/importers/importer.py +60 -11
  129. rasa/shared/importers/remote_importer.py +196 -0
  130. rasa/shared/nlu/constants.py +2 -0
  131. rasa/shared/nlu/training_data/features.py +2 -120
  132. rasa/shared/providers/_configs/__init__.py +0 -0
  133. rasa/shared/providers/_configs/azure_openai_client_config.py +181 -0
  134. rasa/shared/providers/_configs/client_config.py +57 -0
  135. rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
  136. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
  137. rasa/shared/providers/_configs/openai_client_config.py +175 -0
  138. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +171 -0
  139. rasa/shared/providers/_configs/utils.py +101 -0
  140. rasa/shared/providers/_ssl_verification_utils.py +124 -0
  141. rasa/shared/providers/embedding/__init__.py +0 -0
  142. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +254 -0
  143. rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
  144. rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
  145. rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
  146. rasa/shared/providers/embedding/embedding_client.py +90 -0
  147. rasa/shared/providers/embedding/embedding_response.py +41 -0
  148. rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
  149. rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
  150. rasa/shared/providers/llm/__init__.py +0 -0
  151. rasa/shared/providers/llm/_base_litellm_client.py +227 -0
  152. rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
  153. rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
  154. rasa/shared/providers/llm/llm_client.py +76 -0
  155. rasa/shared/providers/llm/llm_response.py +50 -0
  156. rasa/shared/providers/llm/openai_llm_client.py +155 -0
  157. rasa/shared/providers/llm/self_hosted_llm_client.py +169 -0
  158. rasa/shared/providers/mappings.py +75 -0
  159. rasa/shared/utils/cli.py +30 -0
  160. rasa/shared/utils/io.py +65 -3
  161. rasa/shared/utils/llm.py +223 -200
  162. rasa/shared/utils/yaml.py +122 -7
  163. rasa/studio/download.py +19 -13
  164. rasa/studio/train.py +2 -3
  165. rasa/studio/upload.py +2 -3
  166. rasa/telemetry.py +113 -58
  167. rasa/tracing/config.py +2 -3
  168. rasa/tracing/instrumentation/attribute_extractors.py +29 -17
  169. rasa/tracing/instrumentation/instrumentation.py +4 -47
  170. rasa/utils/common.py +18 -19
  171. rasa/utils/endpoints.py +7 -4
  172. rasa/utils/io.py +66 -0
  173. rasa/utils/json_utils.py +60 -0
  174. rasa/utils/licensing.py +9 -1
  175. rasa/utils/ml_utils.py +4 -2
  176. rasa/utils/tensorflow/model_data.py +193 -2
  177. rasa/validator.py +195 -1
  178. rasa/version.py +1 -1
  179. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/METADATA +47 -72
  180. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/RECORD +185 -121
  181. rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
  182. rasa/shared/providers/openai/clients.py +0 -43
  183. rasa/shared/providers/openai/session_handler.py +0 -110
  184. rasa/utils/tensorflow/feature_array.py +0 -366
  185. /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
  186. /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
  187. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/NOTICE +0 -0
  188. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/WHEEL +0 -0
  189. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.3.dist-info}/entry_points.txt +0 -0
rasa/shared/utils/llm.py CHANGED
@@ -1,46 +1,60 @@
1
- import os
2
- import warnings
3
- from typing import Any, Dict, Optional, Text, Type, TYPE_CHECKING, Union
4
-
1
+ from functools import wraps
2
+ from typing import (
3
+ Any,
4
+ Callable,
5
+ Dict,
6
+ Optional,
7
+ Text,
8
+ Type,
9
+ TypeVar,
10
+ TYPE_CHECKING,
11
+ Union,
12
+ cast,
13
+ )
14
+ import json
5
15
  import structlog
6
16
 
7
17
  import rasa.shared.utils.io
8
18
  from rasa.shared.constants import (
9
19
  RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_TOO_LONG,
10
20
  RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_EMPTY,
11
- OPENAI_API_TYPE_ENV_VAR,
12
- OPENAI_API_VERSION_ENV_VAR,
13
- OPENAI_API_BASE_ENV_VAR,
14
- REQUESTS_CA_BUNDLE_ENV_VAR,
15
- OPENAI_API_BASE_NO_PREFIX_CONFIG_KEY,
16
- OPENAI_API_TYPE_NO_PREFIX_CONFIG_KEY,
17
- OPENAI_API_VERSION_CONFIG_KEY,
18
- OPENAI_API_VERSION_NO_PREFIX_CONFIG_KEY,
19
- OPENAI_API_TYPE_CONFIG_KEY,
20
- OPENAI_API_BASE_CONFIG_KEY,
21
- OPENAI_DEPLOYMENT_NAME_CONFIG_KEY,
22
- OPENAI_DEPLOYMENT_CONFIG_KEY,
23
- OPENAI_ENGINE_CONFIG_KEY,
24
- LANGCHAIN_TYPE_CONFIG_KEY,
25
- RASA_TYPE_CONFIG_KEY,
21
+ PROVIDER_CONFIG_KEY,
26
22
  )
27
23
  from rasa.shared.core.events import BotUttered, UserUttered
28
24
  from rasa.shared.core.slots import Slot, BooleanSlot, CategoricalSlot
29
- from rasa.shared.engine.caching import get_local_cache_location
25
+ from rasa.shared.engine.caching import (
26
+ get_local_cache_location,
27
+ )
30
28
  from rasa.shared.exceptions import (
31
29
  FileIOException,
32
30
  FileNotFoundException,
31
+ ProviderClientValidationError,
32
+ )
33
+ from rasa.shared.providers._configs.azure_openai_client_config import (
34
+ is_azure_openai_config,
35
+ )
36
+ from rasa.shared.providers._configs.huggingface_local_embedding_client_config import (
37
+ is_huggingface_local_config,
33
38
  )
39
+ from rasa.shared.providers._configs.openai_client_config import is_openai_config
40
+ from rasa.shared.providers._configs.self_hosted_llm_client_config import (
41
+ is_self_hosted_config,
42
+ )
43
+ from rasa.shared.providers.embedding.embedding_client import EmbeddingClient
44
+ from rasa.shared.providers.llm.llm_client import LLMClient
45
+ from rasa.shared.providers.mappings import (
46
+ get_llm_client_from_provider,
47
+ AZURE_OPENAI_PROVIDER,
48
+ OPENAI_PROVIDER,
49
+ SELF_HOSTED_PROVIDER,
50
+ get_embedding_client_from_provider,
51
+ HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER,
52
+ get_client_config_class_from_provider,
53
+ )
54
+ from rasa.shared.utils.cli import print_error_and_exit
34
55
 
35
56
  if TYPE_CHECKING:
36
- from langchain.chat_models import AzureChatOpenAI
37
- from langchain.schema.embeddings import Embeddings
38
- from langchain.llms.base import BaseLLM
39
57
  from rasa.shared.core.trackers import DialogueStateTracker
40
- from rasa.shared.providers.openai.clients import (
41
- AioHTTPSessionAzureChatOpenAI,
42
- AioHTTPSessionOpenAIChat,
43
- )
44
58
 
45
59
  structlogger = structlog.get_logger()
46
60
 
@@ -70,6 +84,94 @@ ERROR_PLACEHOLDER = {
70
84
  "default": "[User input triggered an error]",
71
85
  }
72
86
 
87
+ _Factory_F = TypeVar(
88
+ "_Factory_F",
89
+ bound=Callable[[Dict[str, Any], Dict[str, Any]], Union[EmbeddingClient, LLMClient]],
90
+ )
91
+ _CombineConfigs_F = TypeVar(
92
+ "_CombineConfigs_F",
93
+ bound=Callable[[Dict[str, Any], Dict[str, Any]], Dict[str, Any]],
94
+ )
95
+
96
+
97
+ def _compute_hash_for_cache_from_configs(
98
+ config_x: Dict[str, Any], config_y: Dict[str, Any]
99
+ ) -> int:
100
+ """Get a unique hash of the default and custom configs."""
101
+ return hash(
102
+ json.dumps(config_x, sort_keys=True) + json.dumps(config_y, sort_keys=True)
103
+ )
104
+
105
+
106
+ def _retrieve_from_cache(
107
+ cache: Dict[int, Any], unique_hash: int, function: Callable, function_kwargs: dict
108
+ ) -> Any:
109
+ """Retrieve the value from the cache if it exists. If it does not exist, cache it"""
110
+ if unique_hash in cache:
111
+ return cache[unique_hash]
112
+ else:
113
+ return_value = function(**function_kwargs)
114
+ cache[unique_hash] = return_value
115
+ return return_value
116
+
117
+
118
+ def _cache_factory(function: _Factory_F) -> _Factory_F:
119
+ """Memoize the factory methods based on the arguments."""
120
+ cache: Dict[int, Union[EmbeddingClient, LLMClient]] = {}
121
+
122
+ @wraps(function)
123
+ def factory_method_wrapper(
124
+ config_x: Dict[str, Any], config_y: Dict[str, Any]
125
+ ) -> Union[EmbeddingClient, LLMClient]:
126
+ # Get a unique hash of the default and custom configs.
127
+ unique_hash = _compute_hash_for_cache_from_configs(config_x, config_y)
128
+ return _retrieve_from_cache(
129
+ cache=cache,
130
+ unique_hash=unique_hash,
131
+ function=function,
132
+ function_kwargs={"custom_config": config_x, "default_config": config_y},
133
+ )
134
+
135
+ def clear_cache() -> None:
136
+ cache.clear()
137
+ structlogger.debug(
138
+ "Cleared cache for factory method",
139
+ function_name=function.__name__,
140
+ )
141
+
142
+ setattr(factory_method_wrapper, "clear_cache", clear_cache)
143
+ return cast(_Factory_F, factory_method_wrapper)
144
+
145
+
146
+ def _cache_combine_custom_and_default_configs(
147
+ function: _CombineConfigs_F,
148
+ ) -> _CombineConfigs_F:
149
+ """Memoize the combine_custom_and_default_config method based on the arguments."""
150
+ cache: Dict[int, dict] = {}
151
+
152
+ @wraps(function)
153
+ def combine_configs_wrapper(
154
+ config_x: Dict[str, Any], config_y: Dict[str, Any]
155
+ ) -> dict:
156
+ # Get a unique hash of the default and custom configs.
157
+ unique_hash = _compute_hash_for_cache_from_configs(config_x, config_y)
158
+ return _retrieve_from_cache(
159
+ cache=cache,
160
+ unique_hash=unique_hash,
161
+ function=function,
162
+ function_kwargs={"custom_config": config_x, "default_config": config_y},
163
+ )
164
+
165
+ def clear_cache() -> None:
166
+ cache.clear()
167
+ structlogger.debug(
168
+ "Cleared cache for combine_custom_and_default_config method",
169
+ function_name=function.__name__,
170
+ )
171
+
172
+ setattr(combine_configs_wrapper, "clear_cache", clear_cache)
173
+ return cast(_CombineConfigs_F, combine_configs_wrapper)
174
+
73
175
 
74
176
  def tracker_as_readable_transcript(
75
177
  tracker: "DialogueStateTracker",
@@ -138,11 +240,15 @@ def sanitize_message_for_prompt(text: Optional[str]) -> str:
138
240
  return text.replace("\n", " ") if text else ""
139
241
 
140
242
 
243
+ @_cache_combine_custom_and_default_configs
141
244
  def combine_custom_and_default_config(
142
- custom_config: Optional[Dict[Text, Any]], default_config: Dict[Text, Any]
245
+ custom_config: Optional[Dict[str, Any]], default_config: Dict[str, Any]
143
246
  ) -> Dict[Text, Any]:
144
247
  """Merges the given llm config with the default config.
145
248
 
249
+ This method guarantees that the provider is set and all the deprecated keys are
250
+ resolved. Hence, produces only a valid client config.
251
+
146
252
  Only uses the default configuration arguments, if the type set in the
147
253
  custom config matches the type in the default config. Otherwise, only
148
254
  the custom config is used.
@@ -155,155 +261,96 @@ def combine_custom_and_default_config(
155
261
  The merged config.
156
262
  """
157
263
  if custom_config is None:
158
- return default_config
159
-
160
- if RASA_TYPE_CONFIG_KEY in custom_config:
161
- # rename type to _type as "type" is the convention we use
162
- # across the different components in config files.
163
- # langchain expects "_type" as the key though
164
- custom_config[LANGCHAIN_TYPE_CONFIG_KEY] = custom_config.pop(
165
- RASA_TYPE_CONFIG_KEY
264
+ return default_config.copy()
265
+
266
+ # Get the provider from the custom config.
267
+ custom_config_provider = get_provider_from_config(custom_config)
268
+ # We expect the provider to be set in the default configs of all Rasa components.
269
+ default_config_provider = default_config[PROVIDER_CONFIG_KEY]
270
+
271
+ if (
272
+ custom_config_provider is not None
273
+ and custom_config_provider != default_config_provider
274
+ ):
275
+ # Get the provider-specific config class
276
+ client_config_clazz = get_client_config_class_from_provider(
277
+ custom_config_provider
166
278
  )
279
+ # Checks for deprecated keys, resolves aliases and returns a valid config.
280
+ # This is done to ensure that the custom config is valid.
281
+ return client_config_clazz.from_dict(custom_config).to_dict()
282
+
283
+ # If the provider is the same in both configs
284
+ # OR provider is not specified in the custom config
285
+ # perform MERGE by overriding the default config keys and values
286
+ # with custom config keys and values.
287
+ merged_config = {**default_config.copy(), **custom_config.copy()}
288
+ # Check for deprecated keys, resolve aliases and return a valid config.
289
+ # This is done to ensure that the merged config is valid.
290
+ default_config_clazz = get_client_config_class_from_provider(
291
+ default_config_provider
292
+ )
293
+ return default_config_clazz.from_dict(merged_config).to_dict()
167
294
 
168
- if LANGCHAIN_TYPE_CONFIG_KEY in custom_config and custom_config[
169
- LANGCHAIN_TYPE_CONFIG_KEY
170
- ] != default_config.get(LANGCHAIN_TYPE_CONFIG_KEY):
171
- return custom_config
172
- return {**default_config, **custom_config}
295
+
296
+ def get_provider_from_config(config: dict) -> Optional[str]:
297
+ """Try to get the provider from the passed llm/embeddings configuration.
298
+ If no provider can be found, return None.
299
+ """
300
+ if not config:
301
+ return None
302
+ if is_self_hosted_config(config):
303
+ return SELF_HOSTED_PROVIDER
304
+ elif is_azure_openai_config(config):
305
+ return AZURE_OPENAI_PROVIDER
306
+ elif is_openai_config(config):
307
+ return OPENAI_PROVIDER
308
+ elif is_huggingface_local_config(config):
309
+ return HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER
310
+ else:
311
+ return config.get(PROVIDER_CONFIG_KEY)
173
312
 
174
313
 
175
314
  def ensure_cache() -> None:
176
315
  """Ensures that the cache is initialized."""
177
- import langchain
178
- from langchain.cache import SQLiteCache
316
+ import litellm
179
317
 
180
- # ensure the cache directory exists
181
- cache_location = get_local_cache_location()
318
+ # Ensure the cache directory exists
319
+ cache_location = get_local_cache_location() / "rasa-llm-cache"
182
320
  cache_location.mkdir(parents=True, exist_ok=True)
183
321
 
184
- db_location = cache_location / "rasa-llm-cache.db"
185
- langchain.llm_cache = SQLiteCache(database_path=str(db_location))
186
-
187
-
188
- def preprocess_config_for_azure(config: Dict[str, Any]) -> Dict[str, Any]:
189
- """Preprocesses the config for Azure deployments.
190
-
191
- This function is used to preprocess the config for Azure deployments.
192
- AzureChatOpenAI does not expect the _type key, as it is not a defined parameter
193
- in the class. So we need to remove it before passing the config to the class.
194
- AzureChatOpenAI expects the openai_api_type key to be set instead.
195
-
196
- Args:
197
- config: The config to preprocess.
198
-
199
- Returns:
200
- The preprocessed config.
201
- """
202
- config["deployment_name"] = (
203
- config.get(OPENAI_DEPLOYMENT_NAME_CONFIG_KEY)
204
- or config.get(OPENAI_DEPLOYMENT_CONFIG_KEY)
205
- or config.get(OPENAI_ENGINE_CONFIG_KEY)
206
- )
207
- config["openai_api_base"] = (
208
- config.get(OPENAI_API_BASE_CONFIG_KEY)
209
- or config.get(OPENAI_API_BASE_NO_PREFIX_CONFIG_KEY)
210
- or os.environ.get(OPENAI_API_BASE_ENV_VAR)
211
- )
212
- config["openai_api_type"] = (
213
- config.get(OPENAI_API_TYPE_CONFIG_KEY)
214
- or config.get(OPENAI_API_TYPE_NO_PREFIX_CONFIG_KEY)
215
- or os.environ.get(OPENAI_API_TYPE_ENV_VAR)
216
- )
217
- config["openai_api_version"] = (
218
- config.get(OPENAI_API_VERSION_CONFIG_KEY)
219
- or config.get(OPENAI_API_VERSION_NO_PREFIX_CONFIG_KEY)
220
- or os.environ.get(OPENAI_API_VERSION_ENV_VAR)
221
- )
222
- for keys in [
223
- OPENAI_API_BASE_NO_PREFIX_CONFIG_KEY,
224
- OPENAI_API_TYPE_NO_PREFIX_CONFIG_KEY,
225
- OPENAI_API_VERSION_NO_PREFIX_CONFIG_KEY,
226
- OPENAI_DEPLOYMENT_CONFIG_KEY,
227
- OPENAI_ENGINE_CONFIG_KEY,
228
- LANGCHAIN_TYPE_CONFIG_KEY,
229
- ]:
230
- config.pop(keys, None)
231
-
232
- return config
233
-
234
-
235
- def process_config_for_aiohttp_chat_openai(config: Dict[str, Any]) -> Dict[str, Any]:
236
- config = config.copy()
237
- config.pop(LANGCHAIN_TYPE_CONFIG_KEY)
238
- return config
322
+ # Set diskcache as a caching option
323
+ litellm.cache = litellm.Cache(type="disk", disk_cache_dir=cache_location)
239
324
 
240
325
 
326
+ @_cache_factory
241
327
  def llm_factory(
242
328
  custom_config: Optional[Dict[str, Any]], default_config: Dict[str, Any]
243
- ) -> Union[
244
- "BaseLLM",
245
- "AzureChatOpenAI",
246
- "AioHTTPSessionAzureChatOpenAI",
247
- "AioHTTPSessionOpenAIChat",
248
- ]:
329
+ ) -> LLMClient:
249
330
  """Creates an LLM from the given config.
250
331
 
251
332
  Args:
252
333
  custom_config: The custom config containing values to overwrite defaults
253
334
  default_config: The default config.
254
335
 
255
-
256
336
  Returns:
257
- Instantiated LLM based on the configuration.
337
+ Instantiated LLM based on the configuration.
258
338
  """
259
- from langchain.llms.loading import load_llm_from_config
260
-
261
- ensure_cache()
262
-
263
339
  config = combine_custom_and_default_config(custom_config, default_config)
264
340
 
265
- # need to create a copy as the langchain function modifies the
266
- # config in place...
267
- structlogger.debug("llmfactory.create.llm", config=config)
268
- # langchain issues a user warning when using chat models. at the same time
269
- # it doesn't provide a way to instantiate a chat model directly using the
270
- # config. so for now, we need to suppress the warning here. Original
271
- # warning:
272
- # packages/langchain/llms/openai.py:189: UserWarning: You are trying to
273
- # use a chat model. This way of initializing it is no longer supported.
274
- # Instead, please use: `from langchain.chat_models import ChatOpenAI
275
- with warnings.catch_warnings():
276
- warnings.simplefilter("ignore", category=UserWarning)
277
- if is_azure_config(config):
278
- # Azure deployments are treated differently. This is done as the
279
- # GPT-3.5 Turbo newer versions 0613 and 1106 only support the
280
- # Chat Completions API.
281
- from langchain.chat_models import AzureChatOpenAI
282
- from rasa.shared.providers.openai.clients import (
283
- AioHTTPSessionAzureChatOpenAI,
284
- )
285
-
286
- transformed_config = preprocess_config_for_azure(config.copy())
287
- if os.environ.get(REQUESTS_CA_BUNDLE_ENV_VAR) is None:
288
- return AzureChatOpenAI(**transformed_config)
289
- else:
290
- return AioHTTPSessionAzureChatOpenAI(**transformed_config)
291
-
292
- if (
293
- os.environ.get(REQUESTS_CA_BUNDLE_ENV_VAR) is not None
294
- and config.get(LANGCHAIN_TYPE_CONFIG_KEY) == "openai"
295
- ):
296
- from rasa.shared.providers.openai.clients import AioHTTPSessionOpenAIChat
297
-
298
- config = process_config_for_aiohttp_chat_openai(config)
299
- return AioHTTPSessionOpenAIChat(**config.copy())
341
+ ensure_cache()
300
342
 
301
- return load_llm_from_config(config.copy())
343
+ client_clazz: Type[LLMClient] = get_llm_client_from_provider(
344
+ config[PROVIDER_CONFIG_KEY]
345
+ )
346
+ client = client_clazz.from_config(config)
347
+ return client
302
348
 
303
349
 
350
+ @_cache_factory
304
351
  def embedder_factory(
305
352
  custom_config: Optional[Dict[str, Any]], default_config: Dict[str, Any]
306
- ) -> "Embeddings":
353
+ ) -> EmbeddingClient:
307
354
  """Creates an Embedder from the given config.
308
355
 
309
356
  Args:
@@ -312,55 +359,17 @@ def embedder_factory(
312
359
 
313
360
 
314
361
  Returns:
315
- Instantiated Embedder based on the configuration.
362
+ Instantiated Embedder based on the configuration.
316
363
  """
317
- from langchain.schema.embeddings import Embeddings
318
- from langchain.embeddings import (
319
- CohereEmbeddings,
320
- HuggingFaceHubEmbeddings,
321
- HuggingFaceInstructEmbeddings,
322
- HuggingFaceEmbeddings,
323
- HuggingFaceBgeEmbeddings,
324
- LlamaCppEmbeddings,
325
- OpenAIEmbeddings,
326
- SpacyEmbeddings,
327
- VertexAIEmbeddings,
328
- )
329
- from rasa.shared.providers.openai.clients import AioHTTPSessionOpenAIEmbeddings
330
-
331
- type_to_embedding_cls_dict: Dict[str, Type[Embeddings]] = {
332
- "azure": OpenAIEmbeddings,
333
- "openai": OpenAIEmbeddings,
334
- "openai-aiohttp-session": AioHTTPSessionOpenAIEmbeddings,
335
- "cohere": CohereEmbeddings,
336
- "spacy": SpacyEmbeddings,
337
- "vertexai": VertexAIEmbeddings,
338
- "huggingface_instruct": HuggingFaceInstructEmbeddings,
339
- "huggingface_hub": HuggingFaceHubEmbeddings,
340
- "huggingface_bge": HuggingFaceBgeEmbeddings,
341
- "huggingface": HuggingFaceEmbeddings,
342
- "llamacpp": LlamaCppEmbeddings,
343
- }
344
-
345
364
  config = combine_custom_and_default_config(custom_config, default_config)
346
- embedding_type = config.get(LANGCHAIN_TYPE_CONFIG_KEY)
347
-
348
- if (
349
- os.environ.get(REQUESTS_CA_BUNDLE_ENV_VAR) is not None
350
- and embedding_type is not None
351
- ):
352
- embedding_type = f"{embedding_type}-aiohttp-session"
353
365
 
354
- structlogger.debug("llmfactory.create.embedder", config=config)
366
+ ensure_cache()
355
367
 
356
- if not embedding_type:
357
- return OpenAIEmbeddings()
358
- elif embeddings_cls := type_to_embedding_cls_dict.get(embedding_type):
359
- parameters = config.copy()
360
- parameters.pop(LANGCHAIN_TYPE_CONFIG_KEY)
361
- return embeddings_cls(**parameters)
362
- else:
363
- raise ValueError(f"Unsupported embeddings type '{embedding_type}'")
368
+ client_clazz: Type[EmbeddingClient] = get_embedding_client_from_provider(
369
+ config[PROVIDER_CONFIG_KEY]
370
+ )
371
+ client = client_clazz.from_config(config)
372
+ return client
364
373
 
365
374
 
366
375
  def get_prompt_template(
@@ -396,9 +405,23 @@ def allowed_values_for_slot(slot: Slot) -> Union[str, None]:
396
405
  return None
397
406
 
398
407
 
399
- def is_azure_config(config: Dict) -> bool:
400
- return (
401
- config.get(OPENAI_API_TYPE_CONFIG_KEY) == "azure"
402
- or config.get(OPENAI_API_TYPE_NO_PREFIX_CONFIG_KEY) == "azure"
403
- or os.environ.get(OPENAI_API_TYPE_ENV_VAR) == "azure"
404
- )
408
+ def try_instantiate_llm_client(
409
+ custom_llm_config: Optional[Dict],
410
+ default_llm_config: Optional[Dict],
411
+ log_source_function: str,
412
+ log_source_component: str,
413
+ ) -> None:
414
+ """Validate llm configuration."""
415
+ try:
416
+ llm_factory(custom_llm_config, default_llm_config)
417
+ except (ProviderClientValidationError, ValueError) as e:
418
+ structlogger.error(
419
+ f"{log_source_function}.llm_instantiation_failed",
420
+ message="Unable to instantiate LLM client.",
421
+ error=e,
422
+ )
423
+ print_error_and_exit(
424
+ f"Unable to create the LLM client for component - {log_source_component}. "
425
+ f"Please make sure you specified the required environment variables. "
426
+ f"Error: {e}"
427
+ )
rasa/shared/utils/yaml.py CHANGED
@@ -1,3 +1,4 @@
1
+ import datetime
1
2
  import logging
2
3
  import os
3
4
  import re
@@ -12,15 +13,17 @@ from typing import Dict, List, Optional, Any, Callable, Tuple, Union
12
13
  import jsonschema
13
14
  from importlib_resources import files
14
15
  from packaging import version
15
- from packaging.version import LegacyVersion
16
16
  from pykwalify.core import Core
17
17
  from pykwalify.errors import SchemaError
18
18
  from ruamel import yaml as yaml
19
19
  from ruamel.yaml import RoundTripRepresenter, YAMLError
20
20
  from ruamel.yaml.constructor import DuplicateKeyError, BaseConstructor, ScalarNode
21
21
  from ruamel.yaml.comments import CommentedSeq, CommentedMap
22
+ from ruamel.yaml.loader import SafeLoader
22
23
 
23
24
  from rasa.shared.constants import (
25
+ ASSERTIONS_SCHEMA_EXTENSIONS_FILE,
26
+ ASSERTIONS_SCHEMA_FILE,
24
27
  MODEL_CONFIG_SCHEMA_FILE,
25
28
  CONFIG_SCHEMA_FILE,
26
29
  DOCS_URL_TRAINING_DATA,
@@ -413,12 +416,17 @@ def validate_raw_yaml_using_schema_file_with_responses(
413
416
  )
414
417
 
415
418
 
416
- def read_yaml(content: str, reader_type: Union[str, List[str]] = "safe") -> Any:
419
+ def read_yaml(
420
+ content: str,
421
+ reader_type: Union[str, List[str]] = "safe",
422
+ **kwargs: Any,
423
+ ) -> Any:
417
424
  """Parses yaml from a text.
418
425
 
419
426
  Args:
420
427
  content: A text containing yaml content.
421
428
  reader_type: Reader type to use. By default, "safe" will be used.
429
+ **kwargs: Any
422
430
 
423
431
  Raises:
424
432
  ruamel.yaml.parser.ParserError: If there was an error when parsing the YAML.
@@ -432,11 +440,93 @@ def read_yaml(content: str, reader_type: Union[str, List[str]] = "safe") -> Any:
432
440
  .decode("utf-16")
433
441
  )
434
442
 
443
+ custom_constructor = kwargs.get("custom_constructor", None)
444
+
445
+ # Create YAML parser with custom constructor
446
+ yaml_parser, reset_constructors = create_yaml_parser(
447
+ reader_type, custom_constructor
448
+ )
449
+ yaml_content = yaml_parser.load(content) or {}
450
+
451
+ # Reset to default constructors
452
+ reset_constructors()
453
+
454
+ return yaml_content
455
+
456
+
457
def create_yaml_parser(
    reader_type: str,
    custom_constructor: Optional[Callable] = None,
) -> Tuple[yaml.YAML, Callable[[], None]]:
    """Create a YAML parser with an optional custom constructor.

    Constructors are registered through ``add_constructor``, which in
    ruamel.yaml is a classmethod. The registration therefore lands on the
    constructor *class* used for ``reader_type`` and affects every parser of
    that type until it is undone. The returned reset function undoes every
    registration made here: the optional mapping/sequence constructor AND
    the date constructor. Callers should invoke it as soon as loading is
    finished, preferably from a ``finally`` block, so that a parse error
    does not leave the overrides installed.

    Args:
        reader_type (str): The type of the reader
            (e.g., 'safe', 'rt', 'unsafe').
        custom_constructor (Optional[Callable]):
            A custom constructor function used for both mapping and
            sequence nodes while the overrides are installed.

    Returns:
        Tuple[yaml.YAML, Callable[[], None]]: A tuple containing
        the YAML parser and a function that resets the constructors to
        the state they had before this call.
    """
    yaml_parser = yaml.YAML(typ=reader_type)
    yaml_parser.version = YAML_VERSION  # type: ignore[assignment]
    yaml_parser.preserve_quotes = True  # type: ignore[assignment]

    def custom_date_constructor(loader: BaseConstructor, node: ScalarNode) -> str:
        """Custom constructor for parsing dates in the format '%Y-%m-%d'.

        This constructor parses dates in the '%Y-%m-%d' format and returns them as
        strings instead of datetime objects. This change was introduced because the
        default timestamp constructor in ruamel.yaml returns datetime objects, which
        caused issues in our use case where the `api_version` in the LLM config must
        be a string, but was being interpreted as a datetime object.
        """
        value = loader.construct_scalar(node)
        try:
            # Attempt to parse the date
            date_obj = datetime.datetime.strptime(value, "%Y-%m-%d").date()
            # Return the date as a string instead of a datetime object
            return date_obj.strftime("%Y-%m-%d")
        except ValueError:
            # If the date is not in the correct format, return the original value
            return value

    constructor = yaml_parser.constructor
    mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
    sequence_tag = yaml.resolver.BaseResolver.DEFAULT_SEQUENCE_TAG
    timestamp_tag = "tag:yaml.org,2002:timestamp"

    # Every tag whose constructor this call replaces, with its replacement.
    overrides: Dict[str, Callable] = {}
    if custom_constructor is not None:
        overrides[mapping_tag] = custom_constructor
        overrides[sequence_tag] = custom_constructor
    overrides[timestamp_tag] = custom_date_constructor

    # Snapshot the current registrations BEFORE overriding anything, so that
    # reset_constructors can put back exactly what was there.
    originals = {tag: constructor.yaml_constructors.get(tag) for tag in overrides}
    for tag, replacement in overrides.items():
        constructor.add_constructor(tag, replacement)

    def reset_constructors() -> None:
        """Reset the constructors back to their original state."""
        for tag, original in originals.items():
            if original is None:
                # Nothing was registered for this tag before; remove the
                # override rather than registering `None` as its constructor.
                constructor.yaml_constructors.pop(tag, None)
            else:
                constructor.add_constructor(tag, original)

    return yaml_parser, reset_constructors
440
530
 
441
531
 
442
532
  def _is_ascii(text: str) -> bool:
@@ -684,9 +774,6 @@ def validate_training_data_format_version(
684
774
  parsed_version = version.parse(version_value)
685
775
  latest_version = version.parse(LATEST_TRAINING_DATA_FORMAT_VERSION)
686
776
 
687
- if isinstance(parsed_version, LegacyVersion):
688
- raise TypeError
689
-
690
777
  if parsed_version < latest_version:
691
778
  raise_warning(
692
779
  f"Training data file {filename} has a lower "
@@ -702,7 +789,7 @@ def validate_training_data_format_version(
702
789
  if latest_version >= parsed_version:
703
790
  return True
704
791
 
705
- except TypeError:
792
+ except (TypeError, version.InvalidVersion):
706
793
  raise_warning(
707
794
  f"Training data file {filename} must specify "
708
795
  f"'{KEY_TRAINING_DATA_FORMAT_VERSION}' as string, for example:\n"
@@ -784,3 +871,31 @@ def validate_yaml_with_jsonschema(
784
871
  errors,
785
872
  content=source_data,
786
873
  )
874
+
875
+
876
def validate_yaml_data_using_schema_with_assertions(
    yaml_data: Any,
    schema_content: Union[List[Any], Dict[str, Any]],
    package_name: str = PACKAGE_NAME,
) -> None:
    """Validate already-parsed YAML data against a schema plus the assertions schema.

    The assertions schema read from ``ASSERTIONS_SCHEMA_FILE`` is merged into
    ``schema_content``. Its keys take precedence on conflicts. The assertions
    schema extension file is handed to the validator so that a schema which
    explicitly references it with ``include: assertions`` picks it up.

    Args:
        yaml_data: the parsed yaml data to be validated.
        schema_content: the content of the YAML schema.
        package_name: the name of the package the schema files are located in.
            Defaults to `rasa`.
    """
    assertions_schema = read_schema_file(ASSERTIONS_SCHEMA_FILE, package_name)

    # Start from the caller's schema, then let the assertions schema override
    # any identically named top-level keys.
    merged_schema = dict(schema_content)
    merged_schema.update(**assertions_schema)

    extension_path = files(package_name).joinpath(ASSERTIONS_SCHEMA_EXTENSIONS_FILE)

    validate_yaml_content_using_schema(yaml_data, merged_schema, [str(extension_path)])