rasa-pro 3.11.0a4.dev3__py3-none-any.whl → 3.11.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (163)
  1. rasa/__main__.py +22 -12
  2. rasa/api.py +1 -1
  3. rasa/cli/arguments/default_arguments.py +1 -2
  4. rasa/cli/arguments/shell.py +5 -1
  5. rasa/cli/e2e_test.py +1 -1
  6. rasa/cli/evaluate.py +8 -8
  7. rasa/cli/inspect.py +4 -4
  8. rasa/cli/llm_fine_tuning.py +1 -1
  9. rasa/cli/project_templates/calm/config.yml +5 -7
  10. rasa/cli/project_templates/calm/endpoints.yml +8 -0
  11. rasa/cli/project_templates/tutorial/config.yml +8 -5
  12. rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
  13. rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
  14. rasa/cli/project_templates/tutorial/domain.yml +14 -0
  15. rasa/cli/project_templates/tutorial/endpoints.yml +7 -7
  16. rasa/cli/run.py +1 -1
  17. rasa/cli/scaffold.py +4 -2
  18. rasa/cli/utils.py +5 -0
  19. rasa/cli/x.py +8 -8
  20. rasa/constants.py +1 -1
  21. rasa/core/channels/channel.py +3 -0
  22. rasa/core/channels/inspector/dist/assets/{arc-6852c607.js → arc-bc141fb2.js} +1 -1
  23. rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-acc952b2.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
  24. rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-848a7597.js → classDiagram-936ed81e-55366915.js} +1 -1
  25. rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-a73d3e68.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
  26. rasa/core/channels/inspector/dist/assets/{createText-62fc7601-e5ee049d.js → createText-62fc7601-b0ec81d6.js} +1 -1
  27. rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-771e517e.js → edges-f2ad444c-6166330c.js} +1 -1
  28. rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-aa347178.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
  29. rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-651fc57d.js → flowDb-1972c806-fca3bfe4.js} +1 -1
  30. rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-ca67804f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
  31. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
  32. rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-2dbc568d.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
  33. rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-25a65bd8.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
  34. rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-fdc7378d.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
  35. rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-6f1fd606.js → index-2c4b9a3b-f55afcdf.js} +1 -1
  36. rasa/core/channels/inspector/dist/assets/{index-efdd30c1.js → index-e7cef9de.js} +68 -68
  37. rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-cb1a041a.js → infoDiagram-736b4530-124d4a14.js} +1 -1
  38. rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-14609879.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
  39. rasa/core/channels/inspector/dist/assets/{layout-2490f52b.js → layout-b9885fb6.js} +1 -1
  40. rasa/core/channels/inspector/dist/assets/{line-40186f1f.js → line-7c59abb6.js} +1 -1
  41. rasa/core/channels/inspector/dist/assets/{linear-08814e93.js → linear-4776f780.js} +1 -1
  42. rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-1a534584.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
  43. rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-72397b61.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
  44. rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-3bb0b6a3.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
  45. rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-57334f61.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
  46. rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-111e1297.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
  47. rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-10bcfe62.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
  48. rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-acaf7513.js → stateDiagram-59f0c015-042b3137.js} +1 -1
  49. rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-3ec2a235.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
  50. rasa/core/channels/inspector/dist/assets/{styles-080da4f6-62730289.js → styles-080da4f6-23ffa4fc.js} +1 -1
  51. rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-5284ee76.js → styles-3dcbcfbf-94f59763.js} +1 -1
  52. rasa/core/channels/inspector/dist/assets/{styles-9c745c82-642435e3.js → styles-9c745c82-78a6bebc.js} +1 -1
  53. rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-b250a350.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
  54. rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-c2b147ed.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
  55. rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-f92cfea9.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
  56. rasa/core/channels/inspector/dist/index.html +1 -1
  57. rasa/core/channels/inspector/src/App.tsx +1 -1
  58. rasa/core/channels/inspector/src/helpers/audiostream.ts +77 -16
  59. rasa/core/channels/socketio.py +2 -1
  60. rasa/core/channels/telegram.py +1 -1
  61. rasa/core/channels/twilio.py +1 -1
  62. rasa/core/channels/voice_ready/jambonz.py +2 -2
  63. rasa/core/channels/voice_stream/asr/asr_event.py +5 -0
  64. rasa/core/channels/voice_stream/asr/azure.py +122 -0
  65. rasa/core/channels/voice_stream/asr/deepgram.py +16 -6
  66. rasa/core/channels/voice_stream/audio_bytes.py +1 -0
  67. rasa/core/channels/voice_stream/browser_audio.py +31 -8
  68. rasa/core/channels/voice_stream/call_state.py +23 -0
  69. rasa/core/channels/voice_stream/tts/azure.py +6 -2
  70. rasa/core/channels/voice_stream/tts/cartesia.py +10 -6
  71. rasa/core/channels/voice_stream/tts/tts_engine.py +1 -0
  72. rasa/core/channels/voice_stream/twilio_media_streams.py +27 -18
  73. rasa/core/channels/voice_stream/util.py +4 -4
  74. rasa/core/channels/voice_stream/voice_channel.py +177 -39
  75. rasa/core/featurizers/single_state_featurizer.py +22 -1
  76. rasa/core/featurizers/tracker_featurizers.py +115 -18
  77. rasa/core/nlg/contextual_response_rephraser.py +16 -22
  78. rasa/core/persistor.py +86 -39
  79. rasa/core/policies/enterprise_search_policy.py +159 -60
  80. rasa/core/policies/flows/flow_executor.py +7 -4
  81. rasa/core/policies/intentless_policy.py +120 -22
  82. rasa/core/policies/ted_policy.py +58 -33
  83. rasa/core/policies/unexpected_intent_policy.py +15 -7
  84. rasa/core/processor.py +25 -0
  85. rasa/core/training/interactive.py +34 -35
  86. rasa/core/utils.py +8 -3
  87. rasa/dialogue_understanding/coexistence/llm_based_router.py +58 -16
  88. rasa/dialogue_understanding/commands/change_flow_command.py +6 -0
  89. rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
  90. rasa/dialogue_understanding/commands/utils.py +5 -0
  91. rasa/dialogue_understanding/generator/constants.py +4 -0
  92. rasa/dialogue_understanding/generator/flow_retrieval.py +65 -3
  93. rasa/dialogue_understanding/generator/llm_based_command_generator.py +68 -26
  94. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +57 -8
  95. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +64 -7
  96. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +39 -0
  97. rasa/dialogue_understanding/patterns/user_silence.py +37 -0
  98. rasa/e2e_test/e2e_test_runner.py +4 -2
  99. rasa/e2e_test/utils/io.py +1 -1
  100. rasa/engine/validation.py +297 -7
  101. rasa/model_manager/config.py +15 -3
  102. rasa/model_manager/model_api.py +15 -7
  103. rasa/model_manager/runner_service.py +8 -6
  104. rasa/model_manager/socket_bridge.py +6 -3
  105. rasa/model_manager/trainer_service.py +7 -5
  106. rasa/model_manager/utils.py +28 -7
  107. rasa/model_service.py +6 -2
  108. rasa/model_training.py +2 -0
  109. rasa/nlu/classifiers/diet_classifier.py +38 -25
  110. rasa/nlu/classifiers/logistic_regression_classifier.py +22 -9
  111. rasa/nlu/classifiers/sklearn_intent_classifier.py +37 -16
  112. rasa/nlu/extractors/crf_entity_extractor.py +93 -50
  113. rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -16
  114. rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +52 -17
  115. rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +5 -3
  116. rasa/shared/constants.py +36 -3
  117. rasa/shared/core/constants.py +7 -0
  118. rasa/shared/core/domain.py +26 -0
  119. rasa/shared/core/flows/flow.py +5 -0
  120. rasa/shared/core/flows/flows_yaml_schema.json +10 -0
  121. rasa/shared/core/flows/utils.py +39 -0
  122. rasa/shared/core/flows/validation.py +96 -0
  123. rasa/shared/core/slots.py +5 -0
  124. rasa/shared/nlu/training_data/features.py +120 -2
  125. rasa/shared/providers/_configs/azure_openai_client_config.py +5 -3
  126. rasa/shared/providers/_configs/litellm_router_client_config.py +200 -0
  127. rasa/shared/providers/_configs/model_group_config.py +167 -0
  128. rasa/shared/providers/_configs/openai_client_config.py +1 -1
  129. rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
  130. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
  131. rasa/shared/providers/_configs/utils.py +16 -0
  132. rasa/shared/providers/embedding/_base_litellm_embedding_client.py +12 -15
  133. rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
  134. rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
  135. rasa/shared/providers/llm/_base_litellm_client.py +31 -30
  136. rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
  137. rasa/shared/providers/llm/litellm_router_llm_client.py +127 -0
  138. rasa/shared/providers/llm/rasa_llm_client.py +112 -0
  139. rasa/shared/providers/llm/self_hosted_llm_client.py +1 -1
  140. rasa/shared/providers/mappings.py +19 -0
  141. rasa/shared/providers/router/__init__.py +0 -0
  142. rasa/shared/providers/router/_base_litellm_router_client.py +149 -0
  143. rasa/shared/providers/router/router_client.py +73 -0
  144. rasa/shared/utils/common.py +8 -0
  145. rasa/shared/utils/health_check.py +533 -0
  146. rasa/shared/utils/io.py +28 -6
  147. rasa/shared/utils/llm.py +350 -46
  148. rasa/shared/utils/yaml.py +11 -13
  149. rasa/studio/upload.py +64 -20
  150. rasa/telemetry.py +80 -17
  151. rasa/tracing/instrumentation/attribute_extractors.py +74 -17
  152. rasa/utils/io.py +0 -66
  153. rasa/utils/log_utils.py +9 -2
  154. rasa/utils/tensorflow/feature_array.py +366 -0
  155. rasa/utils/tensorflow/model_data.py +2 -193
  156. rasa/validator.py +70 -0
  157. rasa/version.py +1 -1
  158. {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/METADATA +10 -10
  159. {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/RECORD +162 -146
  160. rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-587d82d8.js +0 -1
  161. {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/NOTICE +0 -0
  162. {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/WHEEL +0 -0
  163. {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/entry_points.txt +0 -0
@@ -1,15 +1,133 @@
1
1
  from __future__ import annotations
2
- from typing import Iterable, Union, Text, Optional, List, Any, Tuple, Dict, Set
2
+
3
3
  import itertools
4
+ from dataclasses import dataclass
5
+ from typing import Iterable, Union, Text, Optional, List, Any, Tuple, Dict, Set
4
6
 
5
7
  import numpy as np
6
8
  import scipy.sparse
9
+ from safetensors.numpy import save_file, load_file
7
10
 
8
- import rasa.shared.utils.io
9
11
  import rasa.shared.nlu.training_data.util
12
+ import rasa.shared.utils.io
10
13
  from rasa.shared.nlu.constants import FEATURE_TYPE_SEQUENCE, FEATURE_TYPE_SENTENCE
11
14
 
12
15
 
16
@dataclass
class FeatureMetadata:
    """Non-tensor description of one saved `Features` object.

    `save_features` creates one instance per feature and stores its `vars()`
    form in the returned metadata; `load_features` reads those dict entries
    back to rebuild the `Features` objects.
    """

    # Value of `Features.type`; passed back as `feature_type` on load.
    data_type: str
    # Value of `Features.attribute`.
    attribute: str
    # Value of `Features.origin` (a single name or a list of names).
    origin: Union[str, List[str]]
    # Result of `Features.is_sparse()`; selects dense vs. COO storage.
    is_sparse: bool
    # Shape of the feature matrix; required to rebuild sparse matrices.
    shape: Tuple[int, ...]
    # Base key of the tensor(s) for this feature in the safetensors file.
    safetensors_key: str
24
+
25
+
26
def save_features(
    features_dict: Dict[Text, List[Features]], file_name: str
) -> Dict[str, Any]:
    """Write every `Features` object in `features_dict` to one safetensors file.

    Each feature gets the base key ``"<dict key>_<list index>"``. A dense
    feature is stored as one tensor under that key. A sparse feature is
    converted to COO form and stored as three tensors, with the suffixes
    ``_data``, ``_row`` and ``_col``. The feature is NOT densified.
    Everything else needed to rebuild the objects is returned rather than
    written to the file.

    Args:
        features_dict: Mapping from a string key to a list of `Features`.
        file_name: Path of the safetensors file to write.

    Returns:
        A mapping with the same keys as `features_dict`. Each value is a list
        of `FeatureMetadata` field dicts, one per feature and in list order.
    """
    tensors: Dict[str, Any] = {}
    metadata: Dict[str, Any] = {}

    for dict_key, feature_list in features_dict.items():
        entries = []
        for position, feature in enumerate(feature_list):
            tensor_key = f"{dict_key}_{position}"
            sparse = feature.is_sparse()

            if sparse:
                # Keep the sparse matrix as its COO components, each of which
                # is a plain 1-D array.
                as_coo = feature.features.tocoo()  # type:ignore[union-attr]
                tensors[f"{tensor_key}_data"] = as_coo.data
                tensors[f"{tensor_key}_row"] = as_coo.row
                tensors[f"{tensor_key}_col"] = as_coo.col
            else:
                tensors[tensor_key] = feature.features

            description = FeatureMetadata(
                data_type=feature.type,
                attribute=feature.attribute,
                origin=feature.origin,
                is_sparse=sparse,
                shape=feature.features.shape,
                safetensors_key=tensor_key,
            )
            entries.append(vars(description))

        metadata[dict_key] = entries

    save_file(tensors, file_name)
    return metadata
78
+
79
+
80
def load_features(
    filename: str, metadata: Dict[str, Any]
) -> Dict[Text, List[Features]]:
    """Rebuild the `Features` objects written by `save_features`.

    Args:
        filename: Path of the safetensors file produced by `save_features`.
        metadata: The mapping returned by `save_features`. Each entry's
            `shape` is passed through `tuple()`, so a list is accepted too.

    Returns:
        A mapping from each metadata key to its list of `Features`, in the
        order the metadata lists them. Sparse features are always returned
        as CSR matrices.
    """
    stored = load_file(filename)

    def _rebuild_matrix(entry: Dict[str, Any]) -> Any:
        # Dense tensors were stored as-is; sparse ones as a COO triplet.
        base = entry["safetensors_key"]
        if not entry["is_sparse"]:
            return stored[base]
        triplet = (
            stored[f"{base}_data"],
            (stored[f"{base}_row"], stored[f"{base}_col"]),
        )
        coo = scipy.sparse.coo_matrix(triplet, shape=tuple(entry["shape"]))
        return coo.tocsr()

    return {
        key: [
            Features(
                features=_rebuild_matrix(entry),
                feature_type=entry["data_type"],
                attribute=entry["attribute"],
                origin=entry["origin"],
            )
            for entry in entries
        ]
        for key, entries in metadata.items()
    }
129
+
130
+
13
131
  class Features:
14
132
  """Stores the features produced by any featurizer."""
15
133
 
@@ -107,8 +107,7 @@ class AzureOpenAIClientConfig:
107
107
 
108
108
  @classmethod
109
109
  def from_dict(cls, config: dict) -> "AzureOpenAIClientConfig":
110
- """
111
- Initializes a dataclass from the passed config.
110
+ """Initializes a dataclass from the passed config.
112
111
 
113
112
  Args:
114
113
  config: (dict) The config from which to initialize.
@@ -175,7 +174,10 @@ def is_azure_openai_config(config: dict) -> bool:
175
174
 
176
175
  # Case: Configuration contains `deployment` key
177
176
  # (specific to Azure OpenAI configuration)
178
- if config.get(DEPLOYMENT_CONFIG_KEY) is not None:
177
+ if (
178
+ config.get(DEPLOYMENT_CONFIG_KEY) is not None
179
+ and config.get(PROVIDER_CONFIG_KEY) is None
180
+ ):
179
181
  return True
180
182
 
181
183
  return False
@@ -0,0 +1,200 @@
1
+ import copy
2
+ from dataclasses import dataclass, field
3
+ from typing import Any, Dict, List
4
+
5
+ import structlog
6
+ from rasa.shared.constants import (
7
+ ROUTER_CONFIG_KEY,
8
+ MODELS_CONFIG_KEY,
9
+ MODEL_GROUP_ID_CONFIG_KEY,
10
+ MODEL_NAME_CONFIG_KEY,
11
+ LITELLM_PARAMS_KEY,
12
+ PROVIDER_CONFIG_KEY,
13
+ DEPLOYMENT_CONFIG_KEY,
14
+ API_TYPE_CONFIG_KEY,
15
+ MODEL_CONFIG_KEY,
16
+ MODEL_LIST_KEY,
17
+ )
18
+ from rasa.shared.providers._configs.model_group_config import (
19
+ ModelGroupConfig,
20
+ ModelConfig,
21
+ )
22
+ from rasa.shared.providers.mappings import get_prefix_from_provider
23
+ from rasa.shared.utils.llm import DEPLOYMENT_CENTRIC_PROVIDERS
24
+
25
+
26
# Module-level structured logger.
structlogger = structlog.get_logger()

# Model-config keys stripped from each model entry before it is handed to
# LiteLLM as `litellm_params`.
_LITELLM_UNSUPPORTED_KEYS = list(
    (PROVIDER_CONFIG_KEY, DEPLOYMENT_CONFIG_KEY, API_TYPE_CONFIG_KEY)
)
33
+
34
+
35
@dataclass
class LiteLLMRouterClientConfig:
    """Parses configuration for a LiteLLM Router client.

    The configuration is expected to be in the following format:

        {
            "id": "model_group_id",
            "models": [
                {
                    "provider": "provider_name",
                    "model": "model_name",
                    "api_base": "api_base",
                    "api_key": "api_key",
                    "api_version": "api_version",
                },
                {
                    "provider": "provider_name",
                    "model": "model_name",
                },
            ],
            "router": {"<router_setting>": "<value>"},
        }

    `router` must be a non-empty mapping: an empty or missing router is
    rejected in `__post_init__`. Any other top-level key is kept in
    `extra_parameters`.

    This configuration is converted into the LiteLLM required format:

        {
            "id": "model_group_id",
            "model_list": [
                {
                    "model_name": "model_group_id",
                    "litellm_params": {
                        "model": "provider_name/model_name",
                        "api_base": "api_base",
                        "api_key": "api_key",
                        "api_version": "api_version",
                    },
                },
                {
                    "model_name": "model_group_id",
                    "litellm_params": {
                        "model": "provider_name/model_name",
                    },
                },
            ],
            "router": {"<router_setting>": "<value>"},
        }

    The prefix in front of the model name comes from
    `get_prefix_from_provider`. For providers in
    `DEPLOYMENT_CENTRIC_PROVIDERS`, the `deployment` value is prefixed
    instead of `model`.

    Raises:
        ValueError: If the configuration is missing required keys or the
            router settings are empty.
    """

    # Parsed `id` and `models` sections.
    _model_group_config: ModelGroupConfig
    # Router settings; emitted unchanged by `to_dict` / `to_litellm_dict`.
    router: Dict[str, Any]
    # Every top-level key other than the model group id, models and router.
    extra_parameters: dict = field(default_factory=dict)

    @property
    def model_group_id(self) -> str:
        """The model group ID, used as `model_name` for every LiteLLM entry."""
        return self._model_group_config.model_group_id

    @property
    def models(self) -> List[ModelConfig]:
        """The parsed model configurations of this group."""
        return self._model_group_config.models

    @property
    def litellm_model_list(self) -> List[Dict[str, Any]]:
        """The models in LiteLLM `model_list` format (rebuilt on each access)."""
        return self._convert_models_to_litellm_model_list()

    def __post_init__(self) -> None:
        """Reject empty or missing router settings."""
        if not self.router:
            message = "Router cannot be empty."
            structlogger.error(
                "litellm_router_client_config.validation_error",
                message=message,
                model_group_id=self._model_group_config.model_group_id,
            )
            raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "LiteLLMRouterClientConfig":
        """Initializes a dataclass from the passed config.

        The input `config` is not mutated.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys, or the router
                settings are empty.

        Returns:
            LiteLLMRouterClientConfig
        """
        model_group_config = ModelGroupConfig.from_dict(config)

        # Copy config to avoid mutating the original
        config_copy = copy.deepcopy(config)
        # Pop the keys used by ModelGroupConfig
        config_copy.pop(MODEL_GROUP_ID_CONFIG_KEY, None)
        config_copy.pop(MODELS_CONFIG_KEY, None)
        # Get the router settings
        router_settings = config_copy.pop(ROUTER_CONFIG_KEY, None)
        # The rest is considered as extra parameters
        extra_parameters = config_copy

        # Use `cls` so that subclasses get instances of their own type.
        return cls(
            _model_group_config=model_group_config,
            router=router_settings,
            extra_parameters=extra_parameters,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary.

        The result uses the input format: model group id and models, then
        `router`, then the extra parameters merged at the top level.
        """
        d = self._model_group_config.to_dict()
        d[ROUTER_CONFIG_KEY] = self.router
        if self.extra_parameters:
            d.update(self.extra_parameters)
        return d

    def to_litellm_dict(self) -> dict:
        """Converts the config into the LiteLLM Router format.

        Extra parameters come first. The model group id, the converted model
        list and the router settings are written after them, so they take
        precedence on key clashes.
        """
        litellm_model_list = self._convert_models_to_litellm_model_list()
        d = {
            **self.extra_parameters,
            MODEL_GROUP_ID_CONFIG_KEY: self.model_group_id,
            MODEL_LIST_KEY: litellm_model_list,
            ROUTER_CONFIG_KEY: self.router,
        }
        return d

    def _convert_models_to_litellm_model_list(self) -> List[Dict[str, Any]]:
        """Build one LiteLLM `model_list` entry per configured model.

        Each entry's `model_name` is the model group id. Its `litellm_params`
        is the model's dict form with the following changes:
        - `model` is replaced by the provider-prefixed model (or deployment)
          name.
        - The keys in `_LITELLM_UNSUPPORTED_KEYS` and all None values are
          dropped.

        NOTE(review): `ModelConfig.to_dict` omits None values, so a model
        without the key that is read here (`deployment` for
        deployment-centric providers, `model` otherwise) raises KeyError.
        """
        litellm_model_list = []

        for model_config_object in self.models:
            # Convert the model config to a dict representation
            litellm_model_config = model_config_object.to_dict()

            provider = litellm_model_config[PROVIDER_CONFIG_KEY]

            # Get the litellm prefixing for the provider
            prefix = get_prefix_from_provider(provider)

            # Determine whether to use model or deployment key based on the provider.
            litellm_model_name_without_prefix = (
                litellm_model_config[DEPLOYMENT_CONFIG_KEY]
                if provider in DEPLOYMENT_CENTRIC_PROVIDERS
                else litellm_model_config[MODEL_CONFIG_KEY]
            )

            # Set 'model' to a provider prefixed model name e.g. openai/gpt-4
            litellm_model_config[MODEL_CONFIG_KEY] = (
                f"{prefix}/{litellm_model_name_without_prefix}"
            )

            # Remove parameters that are None and not supported by LiteLLM.
            litellm_model_config = {
                key: value
                for key, value in litellm_model_config.items()
                if key not in _LITELLM_UNSUPPORTED_KEYS and value is not None
            }

            litellm_model_list_item = {
                MODEL_NAME_CONFIG_KEY: self.model_group_id,
                LITELLM_PARAMS_KEY: litellm_model_config,
            }

            litellm_model_list.append(litellm_model_list_item)

        return litellm_model_list
@@ -0,0 +1,167 @@
1
+ from dataclasses import asdict, dataclass, field
2
+ from typing import List, Optional
3
+
4
+ import structlog
5
+ from rasa.shared.constants import (
6
+ API_BASE_CONFIG_KEY,
7
+ API_KEY,
8
+ API_TYPE_CONFIG_KEY,
9
+ API_VERSION_CONFIG_KEY,
10
+ DEPLOYMENT_CONFIG_KEY,
11
+ PROVIDER_CONFIG_KEY,
12
+ MODEL_CONFIG_KEY,
13
+ MODEL_GROUP_ID_CONFIG_KEY,
14
+ MODELS_CONFIG_KEY,
15
+ MODEL_GROUPS_CONFIG_KEY,
16
+ EXTRA_PARAMETERS_KEY,
17
+ )
18
+ from rasa.shared.providers.mappings import get_client_config_class_from_provider
19
+
20
# Module-level structured logger used for validation errors below.
structlogger = structlog.get_logger()
21
+
22
+
23
@dataclass
class ModelConfig:
    """Parses the config of a single model inside a model group.

    Known keys are stored in dedicated fields. Every other key produced by the
    provider-specific client config ends up in `extra_parameters`.

    Raises:
        ValueError: If the provider config key is missing in the config.
    """

    provider: str
    model: Optional[str] = None
    deployment: Optional[str] = None
    api_base: Optional[str] = None
    api_key: Optional[str] = None
    api_version: Optional[str] = None
    extra_parameters: dict = field(default_factory=dict)
    # Retained for backward compatibility with older configurations,
    # but intentionally not included in extra_parameters
    api_type: Optional[str] = None

    @classmethod
    def from_dict(cls, config: dict) -> "ModelConfig":
        """Initializes a dataclass from the passed config.

        The provider config param is used to determine the client config
        class to use. The client config class takes care of resolving config
        aliases and throwing deprecation warnings.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys. Errors raised by the
                provider's client config class `from_dict` propagate as well.

        Returns:
            ModelConfig
        """
        # Local import; presumably avoids a circular import with
        # rasa.shared.utils.llm (TODO confirm).
        from rasa.shared.utils.llm import get_provider_from_config

        # Get the provider from config; this also infers the provider from
        # deprecated configurations.
        provider = get_provider_from_config(config)

        # Retrieve the client configuration class for the specified provider.
        client_config_clazz = get_client_config_class_from_provider(provider)

        # Try to instantiate the config object in order to resolve deprecated
        # aliases and throw deprecation warnings.
        client_config_obj = client_config_clazz.from_dict(config)

        # Convert back to dictionary and instantiate the ModelConfig object.
        client_config = client_config_obj.to_dict()

        # Check for provider after resolving all aliases
        if PROVIDER_CONFIG_KEY not in client_config:
            raise ValueError(
                f"Missing required key '{PROVIDER_CONFIG_KEY}' in "
                f"'{MODELS_CONFIG_KEY}' config."
            )

        # Pop the known keys; whatever remains becomes `extra_parameters`.
        # Use `cls` so that subclasses get instances of their own type.
        return cls(
            provider=client_config.pop(PROVIDER_CONFIG_KEY, None),
            model=client_config.pop(MODEL_CONFIG_KEY, None),
            deployment=client_config.pop(DEPLOYMENT_CONFIG_KEY, None),
            api_type=client_config.pop(API_TYPE_CONFIG_KEY, None),
            api_base=client_config.pop(API_BASE_CONFIG_KEY, None),
            api_key=client_config.pop(API_KEY, None),
            api_version=client_config.pop(API_VERSION_CONFIG_KEY, None),
            extra_parameters=client_config,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a flat dictionary.

        Extra parameters are merged into the top level. Every entry whose
        value is None is dropped, including entries coming from
        `extra_parameters`.
        """
        d = asdict(self)

        # Extra parameters should also be on the top level
        d.pop(EXTRA_PARAMETERS_KEY, None)
        d.update(self.extra_parameters)

        # Remove keys with None values
        return {key: value for key, value in d.items() if value is not None}
101
+
102
+
103
@dataclass
class ModelGroupConfig:
    """Parses the models config. The models config is a list of model configs.

    Raises:
        ValueError: If the model group ID is None or if the models list is empty.
    """

    model_group_id: str
    models: List[ModelConfig]

    def __post_init__(self) -> None:
        """Validate that the group has an ID and at least one model."""
        if self.model_group_id is None:
            message = "Model group ID cannot be set to None."
            structlogger.error(
                "model_group_config.validation_error",
                message=message,
                model_group_id=self.model_group_id,
            )
            raise ValueError(message)
        if not self.models:
            message = "Models cannot be empty."
            structlogger.error(
                "model_group_config.validation_error",
                message=message,
                model_group_id=self.model_group_id,
            )
            raise ValueError(message)

    @classmethod
    def from_dict(cls, config: dict) -> "ModelGroupConfig":
        """Initializes a dataclass from the passed config.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Config is missing required keys, the models section
                is empty (including an empty/None value), or the model group
                ID is missing.

        Returns:
            ModelGroupConfig
        """
        if MODELS_CONFIG_KEY not in config:
            raise ValueError(
                f"Missing required key '{MODELS_CONFIG_KEY}' in "
                f"'{MODEL_GROUPS_CONFIG_KEY}' config."
            )

        # An empty `models:` entry in YAML parses to None. Treat it like an
        # empty list so that __post_init__ reports the "Models cannot be
        # empty." ValueError instead of a TypeError while iterating.
        models_section = config[MODELS_CONFIG_KEY] or []

        models_config = [
            ModelConfig.from_dict(model_config) for model_config in models_section
        ]

        return cls(
            model_group_id=config.get(MODEL_GROUP_ID_CONFIG_KEY),
            models=models_config,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary."""
        d = {
            MODEL_GROUP_ID_CONFIG_KEY: self.model_group_id,
            MODELS_CONFIG_KEY: [model.to_dict() for model in self.models],
        }
        return d
@@ -19,8 +19,8 @@ from rasa.shared.constants import (
19
19
  REQUEST_TIMEOUT_CONFIG_KEY,
20
20
  TIMEOUT_CONFIG_KEY,
21
21
  PROVIDER_CONFIG_KEY,
22
- OPENAI_PROVIDER,
23
22
  OPENAI_API_TYPE,
23
+ OPENAI_PROVIDER,
24
24
  )
25
25
  from rasa.shared.providers._configs.utils import (
26
26
  resolve_aliases,
@@ -0,0 +1,73 @@
1
+ from dataclasses import asdict, dataclass, field
2
+ from typing import Optional
3
+
4
+ import structlog
5
+
6
+ from rasa.shared.constants import (
7
+ MODEL_CONFIG_KEY,
8
+ RASA_PROVIDER,
9
+ PROVIDER_CONFIG_KEY,
10
+ API_BASE_CONFIG_KEY,
11
+ )
12
+ from rasa.shared.providers._configs.utils import (
13
+ validate_required_keys,
14
+ )
15
+
16
# Keys that `RasaLLMClientConfig.from_dict` requires to be present; the same
# keys are excluded from `extra_parameters`.
REQUIRED_KEYS = [
    MODEL_CONFIG_KEY,
    PROVIDER_CONFIG_KEY,
    API_BASE_CONFIG_KEY,
]

# Module-level structured logger.
structlogger = structlog.get_logger()
19
+
20
+
21
@dataclass
class RasaLLMClientConfig:
    """Parses configuration for a Rasa Hosted LiteLLM client.

    Checks that the required keys are present.

    Raises:
        ValueError: Raised in cases of invalid configuration:
            - If any of the required configuration keys are missing.
    """

    # Optional in the type, but `from_dict` requires both keys to be present.
    model: Optional[str]
    api_base: Optional[str]
    # Provider is not used by LiteLLM backend, but we define it here since it's
    # used as switch between different clients.
    provider: str = RASA_PROVIDER

    # Every config key that is not in REQUIRED_KEYS.
    extra_parameters: dict = field(default_factory=dict)

    @classmethod
    def from_dict(cls, config: dict) -> "RasaLLMClientConfig":
        """Initializes a dataclass from the passed config.

        Args:
            config: (dict) The config from which to initialize.

        Raises:
            ValueError: Raised in cases of invalid configuration:
                - If any of the required configuration keys are missing.

        Returns:
            RasaLLMClientConfig
        """
        # Validate that required keys are set
        validate_required_keys(config, REQUIRED_KEYS)

        extra_parameters = {k: v for k, v in config.items() if k not in REQUIRED_KEYS}

        return cls(
            model=config.get(MODEL_CONFIG_KEY),
            api_base=config.get(API_BASE_CONFIG_KEY),
            provider=config.get(PROVIDER_CONFIG_KEY, RASA_PROVIDER),
            extra_parameters=extra_parameters,
        )

    def to_dict(self) -> dict:
        """Converts the config instance into a dictionary.

        Extra parameters are merged into the top level.
        """
        d = asdict(self)
        # Extra parameters should also be on the top level
        d.pop("extra_parameters", None)
        d.update(self.extra_parameters)
        return d
@@ -23,6 +23,7 @@ from rasa.shared.constants import (
23
23
  SELF_HOSTED_PROVIDER,
24
24
  USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY,
25
25
  )
26
+
26
27
  from rasa.shared.providers._configs.utils import (
27
28
  raise_deprecation_warnings,
28
29
  resolve_aliases,
@@ -99,3 +99,19 @@ def validate_forbidden_keys(config: dict, forbidden_keys: list) -> None:
99
99
  config=config,
100
100
  )
101
101
  raise ValueError(message)
102
+
103
+
104
def get_provider_prefixed_model_name(provider: str, model: str) -> str:
    """Return `model` prefixed with `"<provider>/"` unless it already has it.

    The check is a substring test: a model name that contains `"<provider>/"`
    anywhere (not only at the start) is returned unchanged. A falsy `model`
    is also returned unchanged.

    Args:
        provider: The provider of the model.
        model: The model name.

    Returns:
        The model name with the provider prefixed.
    """
    prefix = f"{provider}/"
    if not model or prefix in model:
        return model
    return prefix + model