rasa-pro: rasa_pro-3.9.18-py3-none-any.whl → rasa_pro-3.10.16-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of rasa-pro might be problematic.

Files changed (183)
  1. README.md +0 -374
  2. rasa/__init__.py +1 -2
  3. rasa/__main__.py +5 -0
  4. rasa/anonymization/anonymization_rule_executor.py +2 -2
  5. rasa/api.py +27 -23
  6. rasa/cli/arguments/data.py +27 -2
  7. rasa/cli/arguments/default_arguments.py +25 -3
  8. rasa/cli/arguments/run.py +9 -9
  9. rasa/cli/arguments/train.py +11 -3
  10. rasa/cli/data.py +70 -8
  11. rasa/cli/e2e_test.py +104 -431
  12. rasa/cli/evaluate.py +1 -1
  13. rasa/cli/interactive.py +1 -0
  14. rasa/cli/llm_fine_tuning.py +398 -0
  15. rasa/cli/project_templates/calm/endpoints.yml +1 -1
  16. rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
  17. rasa/cli/run.py +15 -14
  18. rasa/cli/scaffold.py +10 -8
  19. rasa/cli/studio/studio.py +35 -5
  20. rasa/cli/train.py +56 -8
  21. rasa/cli/utils.py +22 -5
  22. rasa/cli/x.py +1 -1
  23. rasa/constants.py +7 -1
  24. rasa/core/actions/action.py +98 -49
  25. rasa/core/actions/action_run_slot_rejections.py +4 -1
  26. rasa/core/actions/custom_action_executor.py +9 -6
  27. rasa/core/actions/direct_custom_actions_executor.py +80 -0
  28. rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
  29. rasa/core/actions/grpc_custom_action_executor.py +2 -2
  30. rasa/core/actions/http_custom_action_executor.py +6 -5
  31. rasa/core/agent.py +21 -17
  32. rasa/core/channels/__init__.py +2 -0
  33. rasa/core/channels/audiocodes.py +1 -16
  34. rasa/core/channels/voice_aware/__init__.py +0 -0
  35. rasa/core/channels/voice_aware/jambonz.py +103 -0
  36. rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
  37. rasa/core/channels/voice_aware/utils.py +20 -0
  38. rasa/core/channels/voice_native/__init__.py +0 -0
  39. rasa/core/constants.py +6 -1
  40. rasa/core/information_retrieval/faiss.py +7 -4
  41. rasa/core/information_retrieval/information_retrieval.py +8 -0
  42. rasa/core/information_retrieval/milvus.py +9 -2
  43. rasa/core/information_retrieval/qdrant.py +1 -1
  44. rasa/core/nlg/contextual_response_rephraser.py +32 -10
  45. rasa/core/nlg/summarize.py +4 -3
  46. rasa/core/policies/enterprise_search_policy.py +113 -45
  47. rasa/core/policies/flows/flow_executor.py +122 -76
  48. rasa/core/policies/intentless_policy.py +83 -29
  49. rasa/core/processor.py +72 -54
  50. rasa/core/run.py +5 -4
  51. rasa/core/tracker_store.py +8 -4
  52. rasa/core/training/interactive.py +1 -1
  53. rasa/core/utils.py +56 -57
  54. rasa/dialogue_understanding/coexistence/llm_based_router.py +53 -13
  55. rasa/dialogue_understanding/commands/__init__.py +6 -0
  56. rasa/dialogue_understanding/commands/restart_command.py +58 -0
  57. rasa/dialogue_understanding/commands/session_start_command.py +59 -0
  58. rasa/dialogue_understanding/commands/utils.py +40 -0
  59. rasa/dialogue_understanding/generator/constants.py +10 -3
  60. rasa/dialogue_understanding/generator/flow_retrieval.py +21 -5
  61. rasa/dialogue_understanding/generator/llm_based_command_generator.py +13 -3
  62. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +134 -90
  63. rasa/dialogue_understanding/generator/nlu_command_adapter.py +47 -7
  64. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +127 -41
  65. rasa/dialogue_understanding/patterns/restart.py +37 -0
  66. rasa/dialogue_understanding/patterns/session_start.py +37 -0
  67. rasa/dialogue_understanding/processor/command_processor.py +16 -3
  68. rasa/dialogue_understanding/processor/command_processor_component.py +6 -2
  69. rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
  70. rasa/e2e_test/assertions.py +1223 -0
  71. rasa/e2e_test/assertions_schema.yml +106 -0
  72. rasa/e2e_test/constants.py +20 -0
  73. rasa/e2e_test/e2e_config.py +220 -0
  74. rasa/e2e_test/e2e_config_schema.yml +26 -0
  75. rasa/e2e_test/e2e_test_case.py +131 -8
  76. rasa/e2e_test/e2e_test_converter.py +363 -0
  77. rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
  78. rasa/e2e_test/e2e_test_coverage_report.py +364 -0
  79. rasa/e2e_test/e2e_test_result.py +26 -6
  80. rasa/e2e_test/e2e_test_runner.py +493 -71
  81. rasa/e2e_test/e2e_test_schema.yml +96 -0
  82. rasa/e2e_test/pykwalify_extensions.py +39 -0
  83. rasa/e2e_test/stub_custom_action.py +70 -0
  84. rasa/e2e_test/utils/__init__.py +0 -0
  85. rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
  86. rasa/e2e_test/utils/io.py +598 -0
  87. rasa/e2e_test/utils/validation.py +80 -0
  88. rasa/engine/graph.py +9 -3
  89. rasa/engine/recipes/default_components.py +0 -2
  90. rasa/engine/recipes/default_recipe.py +10 -2
  91. rasa/engine/storage/local_model_storage.py +40 -12
  92. rasa/engine/validation.py +78 -1
  93. rasa/env.py +9 -0
  94. rasa/graph_components/providers/story_graph_provider.py +59 -6
  95. rasa/llm_fine_tuning/__init__.py +0 -0
  96. rasa/llm_fine_tuning/annotation_module.py +241 -0
  97. rasa/llm_fine_tuning/conversations.py +144 -0
  98. rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
  99. rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
  100. rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
  101. rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
  102. rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
  103. rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
  104. rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
  105. rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
  106. rasa/llm_fine_tuning/storage.py +174 -0
  107. rasa/llm_fine_tuning/train_test_split_module.py +441 -0
  108. rasa/model_training.py +56 -16
  109. rasa/nlu/persistor.py +157 -36
  110. rasa/server.py +45 -10
  111. rasa/shared/constants.py +76 -16
  112. rasa/shared/core/domain.py +27 -19
  113. rasa/shared/core/events.py +28 -2
  114. rasa/shared/core/flows/flow.py +208 -13
  115. rasa/shared/core/flows/flow_path.py +84 -0
  116. rasa/shared/core/flows/flows_list.py +33 -11
  117. rasa/shared/core/flows/flows_yaml_schema.json +269 -193
  118. rasa/shared/core/flows/validation.py +112 -25
  119. rasa/shared/core/flows/yaml_flows_io.py +149 -10
  120. rasa/shared/core/trackers.py +6 -0
  121. rasa/shared/core/training_data/structures.py +20 -0
  122. rasa/shared/core/training_data/visualization.html +2 -2
  123. rasa/shared/exceptions.py +4 -0
  124. rasa/shared/importers/importer.py +64 -16
  125. rasa/shared/nlu/constants.py +2 -0
  126. rasa/shared/providers/_configs/__init__.py +0 -0
  127. rasa/shared/providers/_configs/azure_openai_client_config.py +183 -0
  128. rasa/shared/providers/_configs/client_config.py +57 -0
  129. rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
  130. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
  131. rasa/shared/providers/_configs/openai_client_config.py +175 -0
  132. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +176 -0
  133. rasa/shared/providers/_configs/utils.py +101 -0
  134. rasa/shared/providers/_ssl_verification_utils.py +124 -0
  135. rasa/shared/providers/embedding/__init__.py +0 -0
  136. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +259 -0
  137. rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
  138. rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
  139. rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
  140. rasa/shared/providers/embedding/embedding_client.py +90 -0
  141. rasa/shared/providers/embedding/embedding_response.py +41 -0
  142. rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
  143. rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
  144. rasa/shared/providers/llm/__init__.py +0 -0
  145. rasa/shared/providers/llm/_base_litellm_client.py +251 -0
  146. rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
  147. rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
  148. rasa/shared/providers/llm/llm_client.py +76 -0
  149. rasa/shared/providers/llm/llm_response.py +50 -0
  150. rasa/shared/providers/llm/openai_llm_client.py +155 -0
  151. rasa/shared/providers/llm/self_hosted_llm_client.py +293 -0
  152. rasa/shared/providers/mappings.py +75 -0
  153. rasa/shared/utils/cli.py +30 -0
  154. rasa/shared/utils/io.py +65 -2
  155. rasa/shared/utils/llm.py +246 -200
  156. rasa/shared/utils/yaml.py +121 -15
  157. rasa/studio/auth.py +6 -4
  158. rasa/studio/config.py +13 -4
  159. rasa/studio/constants.py +1 -0
  160. rasa/studio/data_handler.py +10 -3
  161. rasa/studio/download.py +19 -13
  162. rasa/studio/train.py +2 -3
  163. rasa/studio/upload.py +19 -11
  164. rasa/telemetry.py +113 -58
  165. rasa/tracing/instrumentation/attribute_extractors.py +32 -17
  166. rasa/utils/common.py +18 -19
  167. rasa/utils/endpoints.py +7 -4
  168. rasa/utils/json_utils.py +60 -0
  169. rasa/utils/licensing.py +9 -1
  170. rasa/utils/ml_utils.py +4 -2
  171. rasa/validator.py +213 -3
  172. rasa/version.py +1 -1
  173. rasa_pro-3.10.16.dist-info/METADATA +196 -0
  174. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/RECORD +179 -113
  175. rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
  176. rasa/shared/providers/openai/clients.py +0 -43
  177. rasa/shared/providers/openai/session_handler.py +0 -110
  178. rasa_pro-3.9.18.dist-info/METADATA +0 -563
  179. /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
  180. /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
  181. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/NOTICE +0 -0
  182. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/WHEEL +0 -0
  183. {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/entry_points.txt +0 -0
rasa/shared/utils/llm.py CHANGED
@@ -1,6 +1,18 @@
-import os
-import warnings
-from typing import Any, Dict, Optional, Text, Type, TYPE_CHECKING, Union
+import sys
+from functools import wraps
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Optional,
+    Text,
+    Type,
+    TypeVar,
+    TYPE_CHECKING,
+    Union,
+    cast,
+)
+import json
 
 import structlog
 
@@ -8,39 +20,42 @@ import rasa.shared.utils.io
 from rasa.shared.constants import (
     RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_TOO_LONG,
     RASA_PATTERN_INTERNAL_ERROR_USER_INPUT_EMPTY,
-    OPENAI_API_TYPE_ENV_VAR,
-    OPENAI_API_VERSION_ENV_VAR,
-    OPENAI_API_BASE_ENV_VAR,
-    REQUESTS_CA_BUNDLE_ENV_VAR,
-    OPENAI_API_BASE_NO_PREFIX_CONFIG_KEY,
-    OPENAI_API_TYPE_NO_PREFIX_CONFIG_KEY,
-    OPENAI_API_VERSION_CONFIG_KEY,
-    OPENAI_API_VERSION_NO_PREFIX_CONFIG_KEY,
-    OPENAI_API_TYPE_CONFIG_KEY,
-    OPENAI_API_BASE_CONFIG_KEY,
-    OPENAI_DEPLOYMENT_NAME_CONFIG_KEY,
-    OPENAI_DEPLOYMENT_CONFIG_KEY,
-    OPENAI_ENGINE_CONFIG_KEY,
-    LANGCHAIN_TYPE_CONFIG_KEY,
-    RASA_TYPE_CONFIG_KEY,
+    PROVIDER_CONFIG_KEY,
 )
 from rasa.shared.core.events import BotUttered, UserUttered
 from rasa.shared.core.slots import Slot, BooleanSlot, CategoricalSlot
-from rasa.shared.engine.caching import get_local_cache_location
+from rasa.shared.engine.caching import (
+    get_local_cache_location,
+)
 from rasa.shared.exceptions import (
     FileIOException,
     FileNotFoundException,
+    ProviderClientValidationError,
+)
+from rasa.shared.providers._configs.azure_openai_client_config import (
+    is_azure_openai_config,
+)
+from rasa.shared.providers._configs.huggingface_local_embedding_client_config import (
+    is_huggingface_local_config,
+)
+from rasa.shared.providers._configs.openai_client_config import is_openai_config
+from rasa.shared.providers._configs.self_hosted_llm_client_config import (
+    is_self_hosted_config,
+)
+from rasa.shared.providers.embedding.embedding_client import EmbeddingClient
+from rasa.shared.providers.llm.llm_client import LLMClient
+from rasa.shared.providers.mappings import (
+    get_llm_client_from_provider,
+    AZURE_OPENAI_PROVIDER,
+    OPENAI_PROVIDER,
+    SELF_HOSTED_PROVIDER,
+    get_embedding_client_from_provider,
+    HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER,
+    get_client_config_class_from_provider,
 )
 
 if TYPE_CHECKING:
-    from langchain.chat_models import AzureChatOpenAI
-    from langchain.schema.embeddings import Embeddings
-    from langchain.llms.base import BaseLLM
     from rasa.shared.core.trackers import DialogueStateTracker
-    from rasa.shared.providers.openai.clients import (
-        AioHTTPSessionAzureChatOpenAI,
-        AioHTTPSessionOpenAIChat,
-    )
 
 structlogger = structlog.get_logger()
 
@@ -52,7 +67,7 @@ DEFAULT_OPENAI_GENERATE_MODEL_NAME = "gpt-3.5-turbo"
 
 DEFAULT_OPENAI_CHAT_MODEL_NAME = "gpt-3.5-turbo"
 
-DEFAULT_OPENAI_CHAT_MODEL_NAME_ADVANCED = "gpt-4"
+DEFAULT_OPENAI_CHAT_MODEL_NAME_ADVANCED = "gpt-4-0613"
 
 DEFAULT_OPENAI_EMBEDDING_MODEL_NAME = "text-embedding-ada-002"
 
@@ -70,6 +85,94 @@ ERROR_PLACEHOLDER = {
     "default": "[User input triggered an error]",
 }
 
+_Factory_F = TypeVar(
+    "_Factory_F",
+    bound=Callable[[Dict[str, Any], Dict[str, Any]], Union[EmbeddingClient, LLMClient]],
+)
+_CombineConfigs_F = TypeVar(
+    "_CombineConfigs_F",
+    bound=Callable[[Dict[str, Any], Dict[str, Any]], Dict[str, Any]],
+)
+
+
+def _compute_hash_for_cache_from_configs(
+    config_x: Dict[str, Any], config_y: Dict[str, Any]
+) -> int:
+    """Get a unique hash of the default and custom configs."""
+    return hash(
+        json.dumps(config_x, sort_keys=True) + json.dumps(config_y, sort_keys=True)
+    )
+
+
+def _retrieve_from_cache(
+    cache: Dict[int, Any], unique_hash: int, function: Callable, function_kwargs: dict
+) -> Any:
+    """Retrieve the value from the cache if it exists. If it does not exist, cache it"""
+    if unique_hash in cache:
+        return cache[unique_hash]
+    else:
+        return_value = function(**function_kwargs)
+        cache[unique_hash] = return_value
+        return return_value
+
+
+def _cache_factory(function: _Factory_F) -> _Factory_F:
+    """Memoize the factory methods based on the arguments."""
+    cache: Dict[int, Union[EmbeddingClient, LLMClient]] = {}
+
+    @wraps(function)
+    def factory_method_wrapper(
+        config_x: Dict[str, Any], config_y: Dict[str, Any]
+    ) -> Union[EmbeddingClient, LLMClient]:
+        # Get a unique hash of the default and custom configs.
+        unique_hash = _compute_hash_for_cache_from_configs(config_x, config_y)
+        return _retrieve_from_cache(
+            cache=cache,
+            unique_hash=unique_hash,
+            function=function,
+            function_kwargs={"custom_config": config_x, "default_config": config_y},
+        )
+
+    def clear_cache() -> None:
+        cache.clear()
+        structlogger.debug(
+            "Cleared cache for factory method",
+            function_name=function.__name__,
+        )
+
+    setattr(factory_method_wrapper, "clear_cache", clear_cache)
+    return cast(_Factory_F, factory_method_wrapper)
+
+
+def _cache_combine_custom_and_default_configs(
+    function: _CombineConfigs_F,
+) -> _CombineConfigs_F:
+    """Memoize the combine_custom_and_default_config method based on the arguments."""
+    cache: Dict[int, dict] = {}
+
+    @wraps(function)
+    def combine_configs_wrapper(
+        config_x: Dict[str, Any], config_y: Dict[str, Any]
+    ) -> dict:
+        # Get a unique hash of the default and custom configs.
+        unique_hash = _compute_hash_for_cache_from_configs(config_x, config_y)
+        return _retrieve_from_cache(
+            cache=cache,
+            unique_hash=unique_hash,
+            function=function,
+            function_kwargs={"custom_config": config_x, "default_config": config_y},
+        )
+
+    def clear_cache() -> None:
+        cache.clear()
+        structlogger.debug(
+            "Cleared cache for combine_custom_and_default_config method",
+            function_name=function.__name__,
+        )
+
+    setattr(combine_configs_wrapper, "clear_cache", clear_cache)
+    return cast(_CombineConfigs_F, combine_configs_wrapper)
+
 
 def tracker_as_readable_transcript(
     tracker: "DialogueStateTracker",
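The hunk above introduces a small memoization layer keyed on a hash of the JSON-serialized config pair. A standalone sketch of the pattern (all names here are illustrative, not rasa's):

```python
import json
from functools import wraps
from typing import Any, Callable, Dict


def memoize_by_configs(function: Callable[..., Any]) -> Callable[..., Any]:
    """Cache results keyed by a hash of the JSON-serialized config pair."""
    cache: Dict[int, Any] = {}

    @wraps(function)
    def wrapper(custom: Dict[str, Any], default: Dict[str, Any]) -> Any:
        # sort_keys normalizes dict ordering; hash() is only stable within a
        # single process, which is all an in-memory cache needs.
        key = hash(
            json.dumps(custom, sort_keys=True) + json.dumps(default, sort_keys=True)
        )
        if key not in cache:
            cache[key] = function(custom, default)
        return cache[key]

    wrapper.clear_cache = cache.clear  # mirrors the clear_cache hook in the diff
    return wrapper


@memoize_by_configs
def build_client(custom: Dict[str, Any], default: Dict[str, Any]) -> object:
    return object()  # stands in for an LLMClient / EmbeddingClient


a = build_client({"model": "gpt-4"}, {"provider": "openai"})
b = build_client({"model": "gpt-4"}, {"provider": "openai"})
assert a is b  # equal configs -> the same cached instance
```

One consequence of this keying is that both configs must be JSON-serializable, and callers passing equal configs share one client instance for the life of the process unless `clear_cache()` is invoked.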
@@ -138,11 +241,15 @@ def sanitize_message_for_prompt(text: Optional[str]) -> str:
     return text.replace("\n", " ") if text else ""
 
 
+@_cache_combine_custom_and_default_configs
 def combine_custom_and_default_config(
-    custom_config: Optional[Dict[Text, Any]], default_config: Dict[Text, Any]
+    custom_config: Optional[Dict[str, Any]], default_config: Dict[str, Any]
 ) -> Dict[Text, Any]:
     """Merges the given llm config with the default config.
 
+    This method guarantees that the provider is set and all the deprecated keys are
+    resolved. Hence, produces only a valid client config.
+
     Only uses the default configuration arguments, if the type set in the
     custom config matches the type in the default config. Otherwise, only
     the custom config is used.
@@ -155,155 +262,96 @@ def combine_custom_and_default_config(
         The merged config.
     """
     if custom_config is None:
-        return default_config
-
-    if RASA_TYPE_CONFIG_KEY in custom_config:
-        # rename type to _type as "type" is the convention we use
-        # across the different components in config files.
-        # langchain expects "_type" as the key though
-        custom_config[LANGCHAIN_TYPE_CONFIG_KEY] = custom_config.pop(
-            RASA_TYPE_CONFIG_KEY
+        return default_config.copy()
+
+    # Get the provider from the custom config.
+    custom_config_provider = get_provider_from_config(custom_config)
+    # We expect the provider to be set in the default configs of all Rasa components.
+    default_config_provider = default_config[PROVIDER_CONFIG_KEY]
+
+    if (
+        custom_config_provider is not None
+        and custom_config_provider != default_config_provider
+    ):
+        # Get the provider-specific config class
+        client_config_clazz = get_client_config_class_from_provider(
+            custom_config_provider
         )
+        # Checks for deprecated keys, resolves aliases and returns a valid config.
+        # This is done to ensure that the custom config is valid.
+        return client_config_clazz.from_dict(custom_config).to_dict()
+
+    # If the provider is the same in both configs
+    # OR provider is not specified in the custom config
+    # perform MERGE by overriding the default config keys and values
+    # with custom config keys and values.
+    merged_config = {**default_config.copy(), **custom_config.copy()}
+    # Check for deprecated keys, resolve aliases and return a valid config.
+    # This is done to ensure that the merged config is valid.
+    default_config_clazz = get_client_config_class_from_provider(
+        default_config_provider
+    )
+    return default_config_clazz.from_dict(merged_config).to_dict()
 
-    if LANGCHAIN_TYPE_CONFIG_KEY in custom_config and custom_config[
-        LANGCHAIN_TYPE_CONFIG_KEY
-    ] != default_config.get(LANGCHAIN_TYPE_CONFIG_KEY):
-        return custom_config
-    return {**default_config, **custom_config}
+
+def get_provider_from_config(config: dict) -> Optional[str]:
+    """Try to get the provider from the passed llm/embeddings configuration.
+    If no provider can be found, return None.
+    """
+    if not config:
+        return None
+    if is_self_hosted_config(config):
+        return SELF_HOSTED_PROVIDER
+    elif is_azure_openai_config(config):
+        return AZURE_OPENAI_PROVIDER
+    elif is_openai_config(config):
+        return OPENAI_PROVIDER
+    elif is_huggingface_local_config(config):
+        return HUGGINGFACE_LOCAL_EMBEDDING_PROVIDER
+    else:
+        return config.get(PROVIDER_CONFIG_KEY)
 
 
 def ensure_cache() -> None:
     """Ensures that the cache is initialized."""
-    import langchain
-    from langchain.cache import SQLiteCache
+    import litellm
 
-    # ensure the cache directory exists
-    cache_location = get_local_cache_location()
+    # Ensure the cache directory exists
+    cache_location = get_local_cache_location() / "rasa-llm-cache"
     cache_location.mkdir(parents=True, exist_ok=True)
 
-    db_location = cache_location / "rasa-llm-cache.db"
-    langchain.llm_cache = SQLiteCache(database_path=str(db_location))
-
-
-def preprocess_config_for_azure(config: Dict[str, Any]) -> Dict[str, Any]:
-    """Preprocesses the config for Azure deployments.
-
-    This function is used to preprocess the config for Azure deployments.
-    AzureChatOpenAI does not expect the _type key, as it is not a defined parameter
-    in the class. So we need to remove it before passing the config to the class.
-    AzureChatOpenAI expects the openai_api_type key to be set instead.
-
-    Args:
-        config: The config to preprocess.
-
-    Returns:
-        The preprocessed config.
-    """
-    config["deployment_name"] = (
-        config.get(OPENAI_DEPLOYMENT_NAME_CONFIG_KEY)
-        or config.get(OPENAI_DEPLOYMENT_CONFIG_KEY)
-        or config.get(OPENAI_ENGINE_CONFIG_KEY)
-    )
-    config["openai_api_base"] = (
-        config.get(OPENAI_API_BASE_CONFIG_KEY)
-        or config.get(OPENAI_API_BASE_NO_PREFIX_CONFIG_KEY)
-        or os.environ.get(OPENAI_API_BASE_ENV_VAR)
-    )
-    config["openai_api_type"] = (
-        config.get(OPENAI_API_TYPE_CONFIG_KEY)
-        or config.get(OPENAI_API_TYPE_NO_PREFIX_CONFIG_KEY)
-        or os.environ.get(OPENAI_API_TYPE_ENV_VAR)
-    )
-    config["openai_api_version"] = (
-        config.get(OPENAI_API_VERSION_CONFIG_KEY)
-        or config.get(OPENAI_API_VERSION_NO_PREFIX_CONFIG_KEY)
-        or os.environ.get(OPENAI_API_VERSION_ENV_VAR)
-    )
-    for keys in [
-        OPENAI_API_BASE_NO_PREFIX_CONFIG_KEY,
-        OPENAI_API_TYPE_NO_PREFIX_CONFIG_KEY,
-        OPENAI_API_VERSION_NO_PREFIX_CONFIG_KEY,
-        OPENAI_DEPLOYMENT_CONFIG_KEY,
-        OPENAI_ENGINE_CONFIG_KEY,
-        LANGCHAIN_TYPE_CONFIG_KEY,
-    ]:
-        config.pop(keys, None)
-
-    return config
-
-
-def process_config_for_aiohttp_chat_openai(config: Dict[str, Any]) -> Dict[str, Any]:
-    config = config.copy()
-    config.pop(LANGCHAIN_TYPE_CONFIG_KEY)
-    return config
+    # Set diskcache as a caching option
+    litellm.cache = litellm.Cache(type="disk", disk_cache_dir=cache_location)
 
 
+@_cache_factory
 def llm_factory(
     custom_config: Optional[Dict[str, Any]], default_config: Dict[str, Any]
-) -> Union[
-    "BaseLLM",
-    "AzureChatOpenAI",
-    "AioHTTPSessionAzureChatOpenAI",
-    "AioHTTPSessionOpenAIChat",
-]:
+) -> LLMClient:
     """Creates an LLM from the given config.
 
     Args:
         custom_config: The custom config containing values to overwrite defaults
         default_config: The default config.
 
-
     Returns:
-        Instantiated LLM based on the configuration.
+        Instantiated LLM based on the configuration.
     """
-    from langchain.llms.loading import load_llm_from_config
-
-    ensure_cache()
-
     config = combine_custom_and_default_config(custom_config, default_config)
 
-    # need to create a copy as the langchain function modifies the
-    # config in place...
-    structlogger.debug("llmfactory.create.llm", config=config)
-    # langchain issues a user warning when using chat models. at the same time
-    # it doesn't provide a way to instantiate a chat model directly using the
-    # config. so for now, we need to suppress the warning here. Original
-    # warning:
-    #   packages/langchain/llms/openai.py:189: UserWarning: You are trying to
-    #   use a chat model. This way of initializing it is no longer supported.
-    #   Instead, please use: `from langchain.chat_models import ChatOpenAI
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore", category=UserWarning)
-        if is_azure_config(config):
-            # Azure deployments are treated differently. This is done as the
-            # GPT-3.5 Turbo newer versions 0613 and 1106 only support the
-            # Chat Completions API.
-            from langchain.chat_models import AzureChatOpenAI
-            from rasa.shared.providers.openai.clients import (
-                AioHTTPSessionAzureChatOpenAI,
-            )
-
-            transformed_config = preprocess_config_for_azure(config.copy())
-            if os.environ.get(REQUESTS_CA_BUNDLE_ENV_VAR) is None:
-                return AzureChatOpenAI(**transformed_config)
-            else:
-                return AioHTTPSessionAzureChatOpenAI(**transformed_config)
-
-        if (
-            os.environ.get(REQUESTS_CA_BUNDLE_ENV_VAR) is not None
-            and config.get(LANGCHAIN_TYPE_CONFIG_KEY) == "openai"
-        ):
-            from rasa.shared.providers.openai.clients import AioHTTPSessionOpenAIChat
-
-            config = process_config_for_aiohttp_chat_openai(config)
-            return AioHTTPSessionOpenAIChat(**config.copy())
+    ensure_cache()
 
-    return load_llm_from_config(config.copy())
+    client_clazz: Type[LLMClient] = get_llm_client_from_provider(
+        config[PROVIDER_CONFIG_KEY]
+    )
+    client = client_clazz.from_config(config)
+    return client
 
 
+@_cache_factory
 def embedder_factory(
     custom_config: Optional[Dict[str, Any]], default_config: Dict[str, Any]
-) -> "Embeddings":
+) -> EmbeddingClient:
     """Creates an Embedder from the given config.
 
     Args:
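Taken together, this hunk swaps the langchain-specific branching for a provider lookup: `combine_custom_and_default_config` resolves a `provider` key, and the factory instantiates whichever client class the mappings module associates with it. A hedged usage sketch (the "provider" and "model" keys follow the diff above; credentials such as OPENAI_API_KEY are assumed to be set in the environment):

```python
from rasa.shared.utils.llm import llm_factory

default_config = {"provider": "openai", "model": "gpt-3.5-turbo"}
client = llm_factory({"model": "gpt-4"}, default_config)

# Because llm_factory is wrapped in @_cache_factory, an equal config pair
# yields the same client object on repeat calls.
assert llm_factory({"model": "gpt-4"}, default_config) is client
```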
@@ -312,55 +360,17 @@ def embedder_factory(
 
 
     Returns:
-        Instantiated Embedder based on the configuration.
+        Instantiated Embedder based on the configuration.
     """
-    from langchain.schema.embeddings import Embeddings
-    from langchain.embeddings import (
-        CohereEmbeddings,
-        HuggingFaceHubEmbeddings,
-        HuggingFaceInstructEmbeddings,
-        HuggingFaceEmbeddings,
-        HuggingFaceBgeEmbeddings,
-        LlamaCppEmbeddings,
-        OpenAIEmbeddings,
-        SpacyEmbeddings,
-        VertexAIEmbeddings,
-    )
-    from rasa.shared.providers.openai.clients import AioHTTPSessionOpenAIEmbeddings
-
-    type_to_embedding_cls_dict: Dict[str, Type[Embeddings]] = {
-        "azure": OpenAIEmbeddings,
-        "openai": OpenAIEmbeddings,
-        "openai-aiohttp-session": AioHTTPSessionOpenAIEmbeddings,
-        "cohere": CohereEmbeddings,
-        "spacy": SpacyEmbeddings,
-        "vertexai": VertexAIEmbeddings,
-        "huggingface_instruct": HuggingFaceInstructEmbeddings,
-        "huggingface_hub": HuggingFaceHubEmbeddings,
-        "huggingface_bge": HuggingFaceBgeEmbeddings,
-        "huggingface": HuggingFaceEmbeddings,
-        "llamacpp": LlamaCppEmbeddings,
-    }
-
     config = combine_custom_and_default_config(custom_config, default_config)
-    embedding_type = config.get(LANGCHAIN_TYPE_CONFIG_KEY)
 
-    if (
-        os.environ.get(REQUESTS_CA_BUNDLE_ENV_VAR) is not None
-        and embedding_type is not None
-    ):
-        embedding_type = f"{embedding_type}-aiohttp-session"
-
-    structlogger.debug("llmfactory.create.embedder", config=config)
+    ensure_cache()
 
-    if not embedding_type:
-        return OpenAIEmbeddings()
-    elif embeddings_cls := type_to_embedding_cls_dict.get(embedding_type):
-        parameters = config.copy()
-        parameters.pop(LANGCHAIN_TYPE_CONFIG_KEY)
-        return embeddings_cls(**parameters)
-    else:
-        raise ValueError(f"Unsupported embeddings type '{embedding_type}'")
+    client_clazz: Type[EmbeddingClient] = get_embedding_client_from_provider(
+        config[PROVIDER_CONFIG_KEY]
+    )
+    client = client_clazz.from_config(config)
+    return client
 
 
 def get_prompt_template(
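`embedder_factory` now has the same shape as `llm_factory`: the hard-coded langchain class table is gone, and the provider mapping picks the client. A minimal call, assuming an OpenAI-style embedding config (the model name matches DEFAULT_OPENAI_EMBEDDING_MODEL_NAME above):

```python
from rasa.shared.utils.llm import embedder_factory

embedder = embedder_factory(
    None,  # no custom overrides; the default config is used as-is
    {"provider": "openai", "model": "text-embedding-ada-002"},
)
```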
@@ -396,9 +406,45 @@ def allowed_values_for_slot(slot: Slot) -> Union[str, None]:
     return None
 
 
-def is_azure_config(config: Dict) -> bool:
-    return (
-        config.get(OPENAI_API_TYPE_CONFIG_KEY) == "azure"
-        or config.get(OPENAI_API_TYPE_NO_PREFIX_CONFIG_KEY) == "azure"
-        or os.environ.get(OPENAI_API_TYPE_ENV_VAR) == "azure"
-    )
+def try_instantiate_llm_client(
+    custom_llm_config: Optional[Dict],
+    default_llm_config: Optional[Dict],
+    log_source_function: str,
+    log_source_component: str,
+) -> None:
+    """Validate llm configuration."""
+    try:
+        llm_factory(custom_llm_config, default_llm_config)
+    except (ProviderClientValidationError, ValueError) as e:
+        structlogger.error(
+            f"{log_source_function}.llm_instantiation_failed",
+            event_info=(
+                f"Unable to create the LLM client for component - "
+                f"{log_source_component}. Please make sure you specified the required "
+                f"environment variables and configuration keys."
+            ),
+            error=e,
+        )
+        sys.exit(1)
+
+
+def try_instantiate_embedder(
+    custom_embeddings_config: Optional[Dict],
+    default_embeddings_config: Optional[Dict],
+    log_source_function: str,
+    log_source_component: str,
+) -> EmbeddingClient:
+    """Validate embeddings configuration."""
+    try:
+        return embedder_factory(custom_embeddings_config, default_embeddings_config)
+    except (ProviderClientValidationError, ValueError) as e:
+        structlogger.error(
+            f"{log_source_function}.embedder_instantiation_failed",
+            event_info=(
+                f"Unable to create the Embedding client for component - "
+                f"{log_source_component}. Please make sure you specified the required "
+                f"environment variables and configuration keys."
+            ),
+            error=e,
+        )
+        sys.exit(1)
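The two new `try_instantiate_*` helpers wrap the factories for fail-fast validation. A hypothetical component check (the component name and configs are illustrative, not taken from rasa):

```python
from rasa.shared.utils.llm import try_instantiate_llm_client

try_instantiate_llm_client(
    custom_llm_config={"model": "gpt-4"},
    default_llm_config={"provider": "openai", "model": "gpt-3.5-turbo"},
    log_source_function="my_component.validate_config",
    log_source_component="MyComponent",
)
# On failure the helper logs a structured error and calls sys.exit(1), so it
# belongs in start-up validation paths rather than per-request code.
```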