rasa-pro 3.11.0rc2__py3-none-any.whl → 3.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (65)
  1. rasa/__main__.py +9 -3
  2. rasa/cli/studio/upload.py +0 -15
  3. rasa/cli/utils.py +1 -1
  4. rasa/core/channels/development_inspector.py +8 -2
  5. rasa/core/channels/voice_ready/audiocodes.py +3 -4
  6. rasa/core/channels/voice_stream/asr/asr_engine.py +19 -1
  7. rasa/core/channels/voice_stream/asr/asr_event.py +1 -1
  8. rasa/core/channels/voice_stream/asr/azure.py +16 -9
  9. rasa/core/channels/voice_stream/asr/deepgram.py +17 -14
  10. rasa/core/channels/voice_stream/tts/azure.py +3 -1
  11. rasa/core/channels/voice_stream/tts/cartesia.py +3 -3
  12. rasa/core/channels/voice_stream/tts/tts_engine.py +10 -1
  13. rasa/core/channels/voice_stream/voice_channel.py +48 -18
  14. rasa/core/information_retrieval/qdrant.py +1 -0
  15. rasa/core/nlg/contextual_response_rephraser.py +2 -2
  16. rasa/core/persistor.py +93 -49
  17. rasa/core/policies/enterprise_search_policy.py +5 -5
  18. rasa/core/policies/flows/flow_executor.py +18 -8
  19. rasa/core/policies/intentless_policy.py +9 -5
  20. rasa/core/processor.py +7 -5
  21. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +2 -1
  22. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +9 -0
  23. rasa/e2e_test/aggregate_test_stats_calculator.py +11 -1
  24. rasa/e2e_test/assertions.py +133 -16
  25. rasa/e2e_test/assertions_schema.yml +23 -0
  26. rasa/e2e_test/e2e_test_runner.py +2 -2
  27. rasa/engine/loader.py +12 -0
  28. rasa/engine/validation.py +310 -86
  29. rasa/model_manager/config.py +8 -0
  30. rasa/model_manager/model_api.py +166 -61
  31. rasa/model_manager/runner_service.py +31 -26
  32. rasa/model_manager/trainer_service.py +14 -23
  33. rasa/model_manager/warm_rasa_process.py +187 -0
  34. rasa/model_service.py +3 -5
  35. rasa/model_training.py +3 -1
  36. rasa/shared/constants.py +27 -5
  37. rasa/shared/core/constants.py +1 -1
  38. rasa/shared/core/domain.py +8 -31
  39. rasa/shared/core/flows/yaml_flows_io.py +13 -4
  40. rasa/shared/importers/importer.py +19 -2
  41. rasa/shared/importers/rasa.py +5 -1
  42. rasa/shared/nlu/training_data/formats/rasa_yaml.py +18 -3
  43. rasa/shared/providers/_configs/litellm_router_client_config.py +29 -9
  44. rasa/shared/providers/_utils.py +79 -0
  45. rasa/shared/providers/embedding/default_litellm_embedding_client.py +24 -0
  46. rasa/shared/providers/embedding/litellm_router_embedding_client.py +1 -1
  47. rasa/shared/providers/llm/_base_litellm_client.py +26 -0
  48. rasa/shared/providers/llm/default_litellm_llm_client.py +24 -0
  49. rasa/shared/providers/llm/litellm_router_llm_client.py +56 -1
  50. rasa/shared/providers/llm/self_hosted_llm_client.py +4 -28
  51. rasa/shared/providers/router/_base_litellm_router_client.py +35 -1
  52. rasa/shared/utils/common.py +30 -3
  53. rasa/shared/utils/health_check/health_check.py +26 -24
  54. rasa/shared/utils/yaml.py +116 -31
  55. rasa/studio/data_handler.py +3 -1
  56. rasa/studio/upload.py +119 -57
  57. rasa/telemetry.py +3 -1
  58. rasa/tracing/config.py +1 -1
  59. rasa/validator.py +40 -4
  60. rasa/version.py +1 -1
  61. {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/METADATA +2 -2
  62. {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/RECORD +65 -63
  63. {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/NOTICE +0 -0
  64. {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/WHEEL +0 -0
  65. {rasa_pro-3.11.0rc2.dist-info → rasa_pro-3.11.1.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,187 @@
1
+ import shlex
2
+ import subprocess
3
+ from rasa.__main__ import main
4
+ import os
5
+ from typing import List
6
+ import structlog
7
+ from dataclasses import dataclass
8
+ import uuid
9
+
10
+ from rasa.model_manager import config
11
+ from rasa.model_manager.utils import ensure_base_directory_exists, logs_path
12
+
13
# Structured logger shared by every function in this module.
structlogger = structlog.get_logger(__name__)

# Pool of started processes that are blocked waiting for a command on stdin.
warm_rasa_processes: List["WarmRasaProcess"] = []

# Size of the pool filled by `initialize_warm_rasa_process`.
NUMBER_OF_INITIAL_PROCESSES = 3
18
+
19
+
20
@dataclass
class WarmRasaProcess:
    """A pre-started Rasa process that is waiting for a CLI command.

    "Warm" means the process has already paid the cost of importing the
    heavy modules (e.g. litellm) before any command arrives, so running the
    command later does not incur long import times.

    Attributes:
        process: Handle to the started child process, which is blocked
            reading its command from stdin.
        log_id: Identifier of the log file that receives the process output.
    """

    process: subprocess.Popen
    log_id: str
34
+
35
+
36
def _create_warm_rasa_process() -> WarmRasaProcess:
    """Start a new warm Rasa process and return a handle to it.

    The child runs this module as a script (`python -m
    rasa.model_manager.warm_rasa_process`). It pre-imports the slow modules
    and then blocks reading its working directory and CLI arguments from
    stdin. Its stdout and stderr both go to a fresh log file identified by a
    random `log_id`.

    Returns:
        The `WarmRasaProcess` wrapping the started child and its log id.
    """
    command = [
        config.RASA_PYTHON_PATH,
        "-m",
        "rasa.model_manager.warm_rasa_process",
    ]

    # Child processes run with telemetry disabled.
    envs = os.environ.copy()
    envs["RASA_TELEMETRY_ENABLED"] = "false"

    log_id = uuid.uuid4().hex
    log_path = logs_path(log_id)

    ensure_base_directory_exists(log_path)

    # Popen hands the child its own duplicate of the file descriptor, so the
    # parent's handle can be closed as soon as the child has been spawned.
    # Without this every warm process leaks an open file in the server.
    with open(log_path, "w") as log_file:
        process = subprocess.Popen(
            command,
            stdout=log_file,
            stderr=subprocess.STDOUT,
            stdin=subprocess.PIPE,
            env=envs,
        )

    structlogger.debug(
        "model_trainer.created_warm_rasa_process",
        pid=process.pid,
        command=command,
        log_path=log_path,
    )

    return WarmRasaProcess(process=process, log_id=log_id)
68
+
69
+
70
def initialize_warm_rasa_process() -> None:
    """Add `NUMBER_OF_INITIAL_PROCESSES` freshly started processes to the pool."""
    global warm_rasa_processes
    warm_rasa_processes.extend(
        _create_warm_rasa_process() for _ in range(NUMBER_OF_INITIAL_PROCESSES)
    )
75
+
76
+
77
def shutdown_warm_rasa_processes() -> None:
    """Send a terminate signal to every pooled process, then empty the pool."""
    global warm_rasa_processes
    for pooled in warm_rasa_processes:
        child = pooled.process
        child.terminate()
    warm_rasa_processes = []
83
+
84
+
85
def start_rasa_process(cwd: str, arguments: List[str]) -> WarmRasaProcess:
    """Run a Rasa CLI command in a process taken from the warm pool.

    Args:
        cwd: Working directory the child switches to before running.
        arguments: CLI arguments, e.g. `["train", "--data", "data"]`.

    Returns:
        The warm process that received the command. Its modules were already
        imported before the command was sent.
    """
    handle = _get_warm_rasa_process()
    _pass_arguments_to_process(handle.process, cwd, arguments)
    return handle
95
+
96
+
97
def _get_warm_rasa_process() -> WarmRasaProcess:
    """Take the oldest process out of the pool and start a replacement.

    If the pool is empty, a process is created on demand. If the taken
    process has already exited, a warning is logged and a new process is
    used instead. In every case one new process is appended to the pool so
    that its size is kept up.
    """
    global warm_rasa_processes

    if not warm_rasa_processes:
        warm_rasa_processes = [_create_warm_rasa_process()]

    candidate = warm_rasa_processes.pop(0)

    has_exited = candidate.process.poll() is not None
    if has_exited:
        # Should not happen, but fall back to a brand-new process rather
        # than handing out a dead one.
        structlogger.warning(
            "model_trainer.warm_rasa_process_finished_unexpectedly",
            pid=candidate.process.pid,
        )
        candidate = _create_warm_rasa_process()

    # Refill the slot we just consumed.
    warm_rasa_processes.append(_create_warm_rasa_process())
    return candidate
122
+
123
+
124
def _pass_arguments_to_process(
    process: subprocess.Popen, cwd: str, arguments: List[str]
) -> None:
    """Send the working directory and CLI arguments to a waiting warm process.

    The child (`warm_rasa_main`) reads two lines from stdin. The first is the
    working directory. The second is the command line, which the child
    tokenizes with `shlex.split`. Each argument is shell-quoted via
    `shlex.join`, so values containing spaces or quotes (e.g. file paths)
    survive the round trip as single arguments.

    Note: the protocol is line based, so neither `cwd` nor any argument may
    contain a newline character.

    Args:
        process: A started process whose stdin is a pipe.
        cwd: Working directory for the command.
        arguments: CLI arguments, one list element per argument.
    """
    arguments_string = shlex.join(arguments)
    payload = f"{cwd}\n{arguments_string}\n".encode()
    process.stdin.write(payload)  # type: ignore[union-attr]
    process.stdin.flush()  # type: ignore[union-attr]
139
+
140
+
141
+ def warmup() -> None:
142
+ """Import all necessary modules to warm up the process.
143
+
144
+ This should include all the modules that take a long time to import.
145
+ We import them now, so that the training / deployment can later
146
+ directly start.
147
+ """
148
+ try:
149
+ import presidio_analyzer # noqa: F401
150
+ import litellm # noqa: F401
151
+ import langchain # noqa: F401
152
+ import tensorflow # noqa: F401
153
+ import matplotlib # noqa: F401
154
+ import pandas # noqa: F401
155
+ import numpy # noqa: F401
156
+ import spacy # noqa: F401
157
+ import rasa.validator # noqa: F401
158
+ except ImportError:
159
+ pass
160
+
161
+
162
def warm_rasa_main() -> None:
    """Entry point of a warm process: pre-import modules, then run one command.

    After warming up, the process blocks until the parent sends two lines on
    stdin: the working directory, then the CLI arguments (e.g.
    `train --data ...`). It then switches to that directory and hands the
    arguments to the regular Rasa CLI entry point.
    """
    warmup()

    working_directory = input()
    # Tokenize like a shell would, giving the list `argparse` expects.
    argv = shlex.split(input())

    # Relative paths in the arguments must resolve against the given directory.
    os.chdir(working_directory)

    main(argv)
184
+
185
+
186
if __name__ == "__main__":
    # Reached when started as `python -m rasa.model_manager.warm_rasa_process`,
    # the command `_create_warm_rasa_process` spawns.
    warm_rasa_main()
rasa/model_service.py CHANGED
@@ -8,7 +8,7 @@ from rasa.core.persistor import RemoteStorageType, get_persistor
8
8
  from rasa.core.utils import list_routes
9
9
  from rasa.model_manager import model_api
10
10
  from rasa.model_manager import config
11
- from rasa.model_manager.config import SERVER_BASE_URL
11
+ from rasa.model_manager.config import SERVER_BASE_URL, SERVER_PORT
12
12
  from rasa.utils.common import configure_logging_and_warnings
13
13
  import rasa.utils.licensing
14
14
  from urllib.parse import urlparse
@@ -18,8 +18,6 @@ from rasa.utils.sanic_error_handler import register_custom_sanic_error_handler
18
18
 
19
19
  structlogger = structlog.get_logger()
20
20
 
21
- MODEL_SERVICE_PORT = 8000
22
-
23
21
 
24
22
  def url_prefix_from_base_url() -> str:
25
23
  """Return the path prefix from the base URL."""
@@ -93,7 +91,7 @@ def main() -> None:
93
91
 
94
92
  validate_model_storage_type()
95
93
 
96
- structlogger.debug("model_api.starting_server", port=MODEL_SERVICE_PORT)
94
+ structlogger.debug("model_api.starting_server", port=SERVER_PORT)
97
95
 
98
96
  url_prefix = url_prefix_from_base_url()
99
97
  # configure the sanic application
@@ -107,7 +105,7 @@ def main() -> None:
107
105
 
108
106
  register_custom_sanic_error_handler(app)
109
107
 
110
- app.run(host="0.0.0.0", port=MODEL_SERVICE_PORT, legacy=True, motd=False)
108
+ app.run(host="0.0.0.0", port=SERVER_PORT, legacy=True, motd=False)
111
109
 
112
110
 
113
111
  if __name__ == "__main__":
rasa/model_training.py CHANGED
@@ -322,8 +322,10 @@ async def _train_graph(
322
322
  rasa.engine.validation.validate_coexistance_routing_setup(
323
323
  domain, model_configuration, flows
324
324
  )
325
- rasa.engine.validation.validate_model_client_configuration_setup(config)
326
325
  rasa.engine.validation.validate_model_group_configuration_setup()
326
+ rasa.engine.validation.validate_model_client_configuration_setup_during_training_time(
327
+ config
328
+ )
327
329
  rasa.engine.validation.validate_flow_component_dependencies(
328
330
  flows, model_configuration
329
331
  )
rasa/shared/constants.py CHANGED
@@ -149,6 +149,10 @@ AZURE_AD_TOKEN_ENV_VAR = "AZURE_AD_TOKEN"
149
149
  AZURE_API_BASE_ENV_VAR = "AZURE_API_BASE"
150
150
  AZURE_API_VERSION_ENV_VAR = "AZURE_API_VERSION"
151
151
  AZURE_API_TYPE_ENV_VAR = "AZURE_API_TYPE"
152
+ AZURE_SPEECH_API_KEY_ENV_VAR = "AZURE_SPEECH_API_KEY"
153
+
154
+ DEEPGRAM_API_KEY_ENV_VAR = "DEEPGRAM_API_KEY"
155
+ CARTESIA_API_KEY_ENV_VAR = "CARTESIA_API_KEY"
152
156
 
153
157
  OPENAI_API_KEY_ENV_VAR = "OPENAI_API_KEY"
154
158
  OPENAI_API_TYPE_ENV_VAR = "OPENAI_API_TYPE"
@@ -159,6 +163,9 @@ OPENAI_API_BASE_CONFIG_KEY = "openai_api_base"
159
163
  OPENAI_API_TYPE_CONFIG_KEY = "openai_api_type"
160
164
  OPENAI_API_VERSION_CONFIG_KEY = "openai_api_version"
161
165
 
166
+ AWS_BEDROCK_PROVIDER = "bedrock"
167
+ AWS_SAGEMAKER_PROVIDER = "sagemaker"
168
+
162
169
  API_BASE_CONFIG_KEY = "api_base"
163
170
  API_TYPE_CONFIG_KEY = "api_type"
164
171
  API_VERSION_CONFIG_KEY = "api_version"
@@ -184,19 +191,19 @@ N_REPHRASES_CONFIG_KEY = "n"
184
191
  USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY = "use_chat_completions_endpoint"
185
192
 
186
193
  ROUTER_CONFIG_KEY = "router"
187
- ROUTER_STRATEGY_CONFIG_KEY = "router_strategy"
194
+ ROUTING_STRATEGY_CONFIG_KEY = "routing_strategy"
188
195
  REDIS_HOST_CONFIG_KEY = "redis_host"
189
- ROUTER_STRATEGIES_REQUIRING_REDIS_CACHE = [
196
+ ROUTING_STRATEGIES_REQUIRING_REDIS_CACHE = [
190
197
  "cost-based-routing",
191
198
  "usage-based-routing",
192
199
  ]
193
- ROUTER_STRATEGIES_NOT_REQUIRING_CACHE = [
200
+ ROUTING_STRATEGIES_NOT_REQUIRING_CACHE = [
194
201
  "latency-based-routing",
195
202
  "least-busy",
196
203
  "simple-shuffle",
197
204
  ]
198
- VALID_ROUTER_STRATEGIES = (
199
- ROUTER_STRATEGIES_REQUIRING_REDIS_CACHE + ROUTER_STRATEGIES_NOT_REQUIRING_CACHE
205
+ VALID_ROUTING_STRATEGIES = (
206
+ ROUTING_STRATEGIES_REQUIRING_REDIS_CACHE + ROUTING_STRATEGIES_NOT_REQUIRING_CACHE
200
207
  )
201
208
 
202
209
  MODELS_CONFIG_KEY = "models"
@@ -219,6 +226,14 @@ AZURE_API_VERSION_ENV_VAR = "AZURE_API_VERSION"
219
226
  AZURE_API_TYPE_ENV_VAR = "AZURE_API_TYPE"
220
227
 
221
228
  AWS_REGION_NAME_CONFIG_KEY = "aws_region_name"
229
+ AWS_ACCESS_KEY_ID_CONFIG_KEY = "aws_access_key_id"
230
+ AWS_SECRET_ACCESS_KEY_CONFIG_KEY = "aws_secret_access_key"
231
+ AWS_SESSION_TOKEN_CONFIG_KEY = "aws_session_token"
232
+
233
+ AWS_ACCESS_KEY_ID_ENV_VAR = "AWS_ACCESS_KEY_ID"
234
+ AWS_SECRET_ACCESS_KEY_ENV_VAR = "AWS_SECRET_ACCESS_KEY"
235
+ AWS_REGION_NAME_ENV_VAR = "AWS_REGION_NAME"
236
+ AWS_SESSION_TOKEN_ENV_VAR = "AWS_SESSION_TOKEN"
222
237
 
223
238
  HUGGINGFACE_MULTIPROCESS_CONFIG_KEY = "multi_process"
224
239
  HUGGINGFACE_CACHE_FOLDER_CONFIG_KEY = "cache_folder"
@@ -280,3 +295,10 @@ RASA_PATTERN_CANNOT_HANDLE_INVALID_INTENT = (
280
295
  )
281
296
 
282
297
  ROUTE_TO_CALM_SLOT = "route_session_to_calm"
298
+
299
+ SENSITIVE_DATA = [
300
+ API_KEY,
301
+ AWS_ACCESS_KEY_ID_CONFIG_KEY,
302
+ AWS_SECRET_ACCESS_KEY_CONFIG_KEY,
303
+ AWS_SESSION_TOKEN_CONFIG_KEY,
304
+ ]
@@ -110,8 +110,8 @@ FLOW_SLOT_NAMES = [FLOW_HASHES_SLOT]
110
110
 
111
111
  # slots for audio timeout
112
112
  SLOT_SILENCE_TIMEOUT = "silence_timeout"
113
- SILENCE_TIMEOUT_DEFAULT_VALUE = 6.0
114
113
  SLOT_CONSECUTIVE_SILENCE_TIMEOUTS = "consecutive_silence_timeouts"
114
+ SILENCE_TIMEOUT_DEFAULT_VALUE = 6.0
115
115
  SILENCE_SLOTS = [SLOT_SILENCE_TIMEOUT, SLOT_CONSECUTIVE_SILENCE_TIMEOUTS]
116
116
  # slots for knowledge base
117
117
  SLOT_LISTED_ITEMS = "knowledge_base_listed_objects"
@@ -3,7 +3,6 @@ from __future__ import annotations
3
3
  import collections
4
4
  import copy
5
5
  import json
6
- import math
7
6
  import os
8
7
  from dataclasses import dataclass
9
8
  from functools import cached_property
@@ -58,7 +57,6 @@ from rasa.shared.core.events import SlotSet, UserUttered
58
57
  from rasa.shared.core.slots import (
59
58
  AnySlot,
60
59
  CategoricalSlot,
61
- FloatSlot,
62
60
  ListSlot,
63
61
  Slot,
64
62
  TextSlot,
@@ -198,6 +196,7 @@ class Domain:
198
196
  """
199
197
 
200
198
  validate_yaml: ClassVar[bool] = True
199
+ expand_env_vars: ClassVar[bool] = True
201
200
 
202
201
  @classmethod
203
202
  def empty(cls) -> Domain:
@@ -1084,7 +1083,6 @@ class Domain:
1084
1083
  self._add_knowledge_base_slots()
1085
1084
  self._add_categorical_slot_default_value()
1086
1085
  self._add_session_metadata_slot()
1087
- self._add_audio_slots()
1088
1086
 
1089
1087
  def _add_categorical_slot_default_value(self) -> None:
1090
1088
  """Add a default value to all categorical slots.
@@ -1139,29 +1137,6 @@ class Domain:
1139
1137
  )
1140
1138
  )
1141
1139
 
1142
- def _add_audio_slots(self) -> None:
1143
- """Add slots relevant for audio channels."""
1144
- self.slots.append(
1145
- FloatSlot(
1146
- rasa.shared.core.constants.SLOT_SILENCE_TIMEOUT,
1147
- mappings=[],
1148
- influence_conversation=False,
1149
- is_builtin=True,
1150
- initial_value=rasa.shared.core.constants.SILENCE_TIMEOUT_DEFAULT_VALUE,
1151
- max_value=math.inf,
1152
- )
1153
- )
1154
- self.slots.append(
1155
- FloatSlot(
1156
- rasa.shared.core.constants.SLOT_CONSECUTIVE_SILENCE_TIMEOUTS,
1157
- mappings=[],
1158
- influence_conversation=False,
1159
- is_builtin=True,
1160
- initial_value=0.0,
1161
- max_value=math.inf,
1162
- )
1163
- )
1164
-
1165
1140
  def _add_knowledge_base_slots(self) -> None:
1166
1141
  """Add slots for the knowledge base action to slots.
1167
1142
 
@@ -1981,8 +1956,8 @@ class Domain:
1981
1956
  """Check whether the domain is empty."""
1982
1957
  return self.as_dict() == Domain.empty().as_dict()
1983
1958
 
1984
- @staticmethod
1985
- def is_domain_file(filename: Union[Text, Path]) -> bool:
1959
+ @classmethod
1960
+ def is_domain_file(cls, filename: Union[Text, Path]) -> bool:
1986
1961
  """Checks whether the given file path is a Rasa domain file.
1987
1962
 
1988
1963
  Args:
@@ -2001,7 +1976,7 @@ class Domain:
2001
1976
  return False
2002
1977
 
2003
1978
  try:
2004
- content = read_yaml_file(filename)
1979
+ content = read_yaml_file(filename, expand_env_vars=cls.expand_env_vars)
2005
1980
  except (RasaException, YamlSyntaxException):
2006
1981
  structlogger.warning(
2007
1982
  "domain.cannot_load_domain_file",
@@ -2130,10 +2105,12 @@ class Domain:
2130
2105
  "domain.from_yaml.validating",
2131
2106
  )
2132
2107
  validate_raw_yaml_using_schema_file_with_responses(
2133
- raw_yaml_content, DOMAIN_SCHEMA_FILE
2108
+ raw_yaml_content,
2109
+ DOMAIN_SCHEMA_FILE,
2110
+ expand_env_vars=cls.expand_env_vars,
2134
2111
  )
2135
2112
 
2136
- return read_yaml(raw_yaml_content)
2113
+ return read_yaml(raw_yaml_content, expand_env_vars=cls.expand_env_vars)
2137
2114
 
2138
2115
 
2139
2116
  def warn_about_duplicates_found_during_domain_merging(
@@ -1,5 +1,5 @@
1
1
  from pathlib import Path
2
- from typing import Any, Dict, List, Optional, Text, Union
2
+ from typing import Any, ClassVar, Dict, List, Optional, Text, Union
3
3
 
4
4
  import jsonschema
5
5
  import ruamel.yaml.nodes as yaml_nodes
@@ -25,6 +25,8 @@ KEY_FLOWS = "flows"
25
25
  class YAMLFlowsReader:
26
26
  """Class that reads flows information in YAML format."""
27
27
 
28
+ expand_env_vars: ClassVar[bool] = True
29
+
28
30
  @classmethod
29
31
  def read_from_file(
30
32
  cls, filename: Union[Text, Path], add_line_numbers: bool = True
@@ -217,14 +219,21 @@ class YAMLFlowsReader:
217
219
  `Flow`s read from `string`.
218
220
  """
219
221
  validate_yaml_with_jsonschema(
220
- string, FLOWS_SCHEMA_FILE, humanize_error=cls.humanize_flow_error
222
+ string,
223
+ FLOWS_SCHEMA_FILE,
224
+ humanize_error=cls.humanize_flow_error,
225
+ expand_env_vars=cls.expand_env_vars,
221
226
  )
222
227
  if add_line_numbers:
223
- yaml_content = read_yaml(string, custom_constructor=line_number_constructor)
228
+ yaml_content = read_yaml(
229
+ string,
230
+ custom_constructor=line_number_constructor,
231
+ expand_env_vars=cls.expand_env_vars,
232
+ )
224
233
  yaml_content = process_yaml_content(yaml_content)
225
234
 
226
235
  else:
227
- yaml_content = read_yaml(string)
236
+ yaml_content = read_yaml(string, expand_env_vars=cls.expand_env_vars)
228
237
 
229
238
  return FlowsList.from_json(yaml_content.get(KEY_FLOWS, {}), file_path=file_path)
230
239
 
@@ -1,7 +1,18 @@
1
1
  import logging
2
2
  from abc import ABC, abstractmethod
3
3
  from functools import reduce
4
- from typing import Any, Dict, List, Optional, Set, Text, Tuple, Type, Union, cast
4
+ from typing import (
5
+ Any,
6
+ Dict,
7
+ List,
8
+ Optional,
9
+ Set,
10
+ Text,
11
+ Tuple,
12
+ Type,
13
+ Union,
14
+ cast,
15
+ )
5
16
 
6
17
  import importlib_resources
7
18
 
@@ -167,6 +178,7 @@ class TrainingDataImporter(ABC):
167
178
  domain_path: Optional[Text] = None,
168
179
  training_data_paths: Optional[List[Text]] = None,
169
180
  args: Optional[Dict[Text, Any]] = None,
181
+ expand_env_vars: bool = True,
170
182
  ) -> "TrainingDataImporter":
171
183
  """Loads a `TrainingDataImporter` instance from a dictionary."""
172
184
  from rasa.shared.importers.rasa import RasaFileImporter
@@ -182,7 +194,12 @@ class TrainingDataImporter(ABC):
182
194
  importers = [importer for importer in importers if importer]
183
195
  if not importers:
184
196
  importers = [
185
- RasaFileImporter(config_path, domain_path, training_data_paths)
197
+ RasaFileImporter(
198
+ config_path,
199
+ domain_path,
200
+ training_data_paths,
201
+ expand_env_vars=expand_env_vars,
202
+ )
186
203
  ]
187
204
 
188
205
  return E2EImporter(
@@ -29,7 +29,9 @@ class RasaFileImporter(TrainingDataImporter):
29
29
  config_file: Optional[Text] = None,
30
30
  domain_path: Optional[Text] = None,
31
31
  training_data_paths: Optional[Union[List[Text], Text]] = None,
32
+ expand_env_vars: bool = True,
32
33
  ):
34
+ self.expand_env_vars = expand_env_vars
33
35
  self._domain_path = domain_path
34
36
 
35
37
  self._nlu_files = rasa.shared.data.get_data_files(
@@ -54,7 +56,9 @@ class RasaFileImporter(TrainingDataImporter):
54
56
  logger.debug("No configuration file was provided to the RasaFileImporter.")
55
57
  return {}
56
58
 
57
- config = read_model_configuration(self.config_file)
59
+ config = read_model_configuration(
60
+ self.config_file, expand_env_vars=self.expand_env_vars
61
+ )
58
62
  return config
59
63
 
60
64
  def get_config_file_for_auto_config(self) -> Optional[Text]:
@@ -1,7 +1,18 @@
1
1
  import logging
2
2
  from collections import OrderedDict
3
3
  from pathlib import Path
4
- from typing import Text, Any, List, Dict, Tuple, Union, Iterator, Optional, Callable
4
+ from typing import (
5
+ ClassVar,
6
+ Text,
7
+ Any,
8
+ List,
9
+ Dict,
10
+ Tuple,
11
+ Union,
12
+ Iterator,
13
+ Optional,
14
+ Callable,
15
+ )
5
16
 
6
17
  import rasa.shared.data
7
18
  from rasa.shared.core.domain import Domain
@@ -55,6 +66,8 @@ STRIP_SYMBOLS = "\n\r "
55
66
  class RasaYAMLReader(TrainingDataReader):
56
67
  """Reads YAML training data and creates a TrainingData object."""
57
68
 
69
+ expand_env_vars: ClassVar[bool] = True
70
+
58
71
  def __init__(self) -> None:
59
72
  super().__init__()
60
73
  self.training_examples: List[Message] = []
@@ -69,7 +82,9 @@ class RasaYAMLReader(TrainingDataReader):
69
82
  If the string is not in the right format, an exception will be raised.
70
83
  """
71
84
  try:
72
- validate_raw_yaml_using_schema_file_with_responses(string, NLU_SCHEMA_FILE)
85
+ validate_raw_yaml_using_schema_file_with_responses(
86
+ string, NLU_SCHEMA_FILE, expand_env_vars=self.expand_env_vars
87
+ )
73
88
  except YamlException as e:
74
89
  e.filename = self.filename
75
90
  raise e
@@ -88,7 +103,7 @@ class RasaYAMLReader(TrainingDataReader):
88
103
  """
89
104
  self.validate(string)
90
105
 
91
- yaml_content = read_yaml(string)
106
+ yaml_content = read_yaml(string, expand_env_vars=self.expand_env_vars)
92
107
 
93
108
  if not validate_training_data_format_version(yaml_content, self.filename):
94
109
  return TrainingData()
@@ -14,6 +14,7 @@ from rasa.shared.constants import (
14
14
  API_TYPE_CONFIG_KEY,
15
15
  MODEL_CONFIG_KEY,
16
16
  MODEL_LIST_KEY,
17
+ USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY,
17
18
  )
18
19
  from rasa.shared.providers._configs.model_group_config import (
19
20
  ModelGroupConfig,
@@ -29,6 +30,7 @@ _LITELLM_UNSUPPORTED_KEYS = [
29
30
  PROVIDER_CONFIG_KEY,
30
31
  DEPLOYMENT_CONFIG_KEY,
31
32
  API_TYPE_CONFIG_KEY,
33
+ USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY,
32
34
  ]
33
35
 
34
36
 
@@ -84,6 +86,7 @@ class LiteLLMRouterClientConfig:
84
86
 
85
87
  _model_group_config: ModelGroupConfig
86
88
  router: Dict[str, Any]
89
+ _use_chat_completions_endpoint: bool = True
87
90
  extra_parameters: dict = field(default_factory=dict)
88
91
 
89
92
  @property
@@ -98,6 +101,14 @@ class LiteLLMRouterClientConfig:
98
101
  def litellm_model_list(self) -> List[Dict[str, Any]]:
99
102
  return self._convert_models_to_litellm_model_list()
100
103
 
104
+ @property
105
+ def litellm_router_settings(self) -> Dict[str, Any]:
106
+ return self._convert_router_to_litellm_router_settings()
107
+
108
+ @property
109
+ def use_chat_completions_endpoint(self) -> bool:
110
+ return self._use_chat_completions_endpoint
111
+
101
112
  def __post_init__(self) -> None:
102
113
  if not self.router:
103
114
  message = "Router cannot be empty."
@@ -121,7 +132,6 @@ class LiteLLMRouterClientConfig:
121
132
  Returns:
122
133
  LiteLLMRouterClientConfig
123
134
  """
124
-
125
135
  model_group_config = ModelGroupConfig.from_dict(config)
126
136
 
127
137
  # Copy config to avoid mutating the original
@@ -130,13 +140,18 @@ class LiteLLMRouterClientConfig:
130
140
  config_copy.pop(MODEL_GROUP_ID_CONFIG_KEY, None)
131
141
  config_copy.pop(MODELS_CONFIG_KEY, None)
132
142
  # Get the router settings
133
- router_settings = config_copy.pop(ROUTER_CONFIG_KEY, None)
143
+ router_settings = config_copy.pop(ROUTER_CONFIG_KEY, {})
144
+ # Get the use_chat_completions_endpoint setting
145
+ use_chat_completions_endpoint = router_settings.get(
146
+ USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY, True
147
+ )
134
148
  # The rest is considered as extra parameters
135
149
  extra_parameters = config_copy
136
150
 
137
151
  this = LiteLLMRouterClientConfig(
138
152
  _model_group_config=model_group_config,
139
153
  router=router_settings,
154
+ _use_chat_completions_endpoint=use_chat_completions_endpoint,
140
155
  extra_parameters=extra_parameters,
141
156
  )
142
157
  return this
@@ -150,14 +165,17 @@ class LiteLLMRouterClientConfig:
150
165
  return d
151
166
 
152
167
  def to_litellm_dict(self) -> dict:
153
- litellm_model_list = self._convert_models_to_litellm_model_list()
154
- d = {
168
+ return {
155
169
  **self.extra_parameters,
156
170
  MODEL_GROUP_ID_CONFIG_KEY: self.model_group_id,
157
- MODEL_LIST_KEY: litellm_model_list,
158
- ROUTER_CONFIG_KEY: self.router,
171
+ MODEL_LIST_KEY: self._convert_models_to_litellm_model_list(),
172
+ ROUTER_CONFIG_KEY: self._convert_router_to_litellm_router_settings(),
159
173
  }
160
- return d
174
+
175
+ def _convert_router_to_litellm_router_settings(self) -> Dict[str, Any]:
176
+ _router_settings_copy = copy.deepcopy(self.router)
177
+ _router_settings_copy.pop(USE_CHAT_COMPLETIONS_ENDPOINT_CONFIG_KEY, None)
178
+ return _router_settings_copy
161
179
 
162
180
  def _convert_models_to_litellm_model_list(self) -> List[Dict[str, Any]]:
163
181
  litellm_model_list = []
@@ -172,7 +190,7 @@ class LiteLLMRouterClientConfig:
172
190
  prefix = get_prefix_from_provider(provider)
173
191
 
174
192
  # Determine whether to use model or deployment key based on the provider.
175
- litellm_model_name_without_prefix = (
193
+ litellm_model_name = (
176
194
  litellm_model_config[DEPLOYMENT_CONFIG_KEY]
177
195
  if provider in DEPLOYMENT_CENTRIC_PROVIDERS
178
196
  else litellm_model_config[MODEL_CONFIG_KEY]
@@ -180,7 +198,9 @@ class LiteLLMRouterClientConfig:
180
198
 
181
199
  # Set 'model' to a provider prefixed model name e.g. openai/gpt-4
182
200
  litellm_model_config[MODEL_CONFIG_KEY] = (
183
- f"{prefix}/{litellm_model_name_without_prefix}"
201
+ litellm_model_name
202
+ if f"{prefix}/" in litellm_model_name
203
+ else f"{prefix}/{litellm_model_name}"
184
204
  )
185
205
 
186
206
  # Remove parameters that are None and not supported by LiteLLM.