rasa-pro 3.12.0.dev13__py3-none-any.whl → 3.12.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (128)
  1. rasa/anonymization/anonymization_rule_executor.py +16 -10
  2. rasa/cli/data.py +16 -0
  3. rasa/cli/project_templates/calm/config.yml +2 -2
  4. rasa/cli/project_templates/calm/endpoints.yml +2 -2
  5. rasa/cli/utils.py +12 -0
  6. rasa/core/actions/action.py +84 -191
  7. rasa/core/actions/action_run_slot_rejections.py +16 -4
  8. rasa/core/channels/__init__.py +2 -0
  9. rasa/core/channels/studio_chat.py +19 -0
  10. rasa/core/channels/telegram.py +42 -24
  11. rasa/core/channels/voice_ready/utils.py +1 -1
  12. rasa/core/channels/voice_stream/asr/asr_engine.py +10 -4
  13. rasa/core/channels/voice_stream/asr/azure.py +14 -1
  14. rasa/core/channels/voice_stream/asr/deepgram.py +20 -4
  15. rasa/core/channels/voice_stream/audiocodes.py +264 -0
  16. rasa/core/channels/voice_stream/browser_audio.py +4 -1
  17. rasa/core/channels/voice_stream/call_state.py +3 -0
  18. rasa/core/channels/voice_stream/genesys.py +6 -2
  19. rasa/core/channels/voice_stream/tts/azure.py +9 -1
  20. rasa/core/channels/voice_stream/tts/cartesia.py +14 -8
  21. rasa/core/channels/voice_stream/voice_channel.py +23 -2
  22. rasa/core/constants.py +2 -0
  23. rasa/core/nlg/contextual_response_rephraser.py +18 -1
  24. rasa/core/nlg/generator.py +83 -15
  25. rasa/core/nlg/response.py +6 -3
  26. rasa/core/nlg/translate.py +55 -0
  27. rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +1 -1
  28. rasa/core/policies/flows/flow_executor.py +12 -5
  29. rasa/core/processor.py +72 -9
  30. rasa/dialogue_understanding/commands/can_not_handle_command.py +20 -2
  31. rasa/dialogue_understanding/commands/cancel_flow_command.py +24 -6
  32. rasa/dialogue_understanding/commands/change_flow_command.py +20 -2
  33. rasa/dialogue_understanding/commands/chit_chat_answer_command.py +20 -2
  34. rasa/dialogue_understanding/commands/clarify_command.py +29 -3
  35. rasa/dialogue_understanding/commands/command.py +1 -16
  36. rasa/dialogue_understanding/commands/command_syntax_manager.py +55 -0
  37. rasa/dialogue_understanding/commands/human_handoff_command.py +20 -2
  38. rasa/dialogue_understanding/commands/knowledge_answer_command.py +20 -2
  39. rasa/dialogue_understanding/commands/prompt_command.py +94 -0
  40. rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +20 -2
  41. rasa/dialogue_understanding/commands/set_slot_command.py +24 -2
  42. rasa/dialogue_understanding/commands/skip_question_command.py +20 -2
  43. rasa/dialogue_understanding/commands/start_flow_command.py +20 -2
  44. rasa/dialogue_understanding/commands/utils.py +98 -4
  45. rasa/dialogue_understanding/generator/__init__.py +2 -0
  46. rasa/dialogue_understanding/generator/command_parser.py +15 -12
  47. rasa/dialogue_understanding/generator/constants.py +3 -0
  48. rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -5
  49. rasa/dialogue_understanding/generator/llm_command_generator.py +5 -3
  50. rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +16 -2
  51. rasa/dialogue_understanding/generator/prompt_templates/__init__.py +0 -0
  52. rasa/dialogue_understanding/generator/{single_step → prompt_templates}/command_prompt_template.jinja2 +2 -0
  53. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +77 -0
  54. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_default.jinja2 +68 -0
  55. rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +84 -0
  56. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +460 -0
  57. rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +12 -310
  58. rasa/dialogue_understanding/patterns/collect_information.py +1 -1
  59. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +16 -0
  60. rasa/dialogue_understanding/patterns/validate_slot.py +65 -0
  61. rasa/dialogue_understanding/processor/command_processor.py +39 -0
  62. rasa/dialogue_understanding_test/du_test_case.py +28 -8
  63. rasa/dialogue_understanding_test/du_test_result.py +13 -9
  64. rasa/dialogue_understanding_test/io.py +14 -0
  65. rasa/e2e_test/utils/io.py +0 -37
  66. rasa/engine/graph.py +1 -0
  67. rasa/engine/language.py +140 -0
  68. rasa/engine/recipes/config_files/default_config.yml +4 -0
  69. rasa/engine/recipes/default_recipe.py +2 -0
  70. rasa/engine/recipes/graph_recipe.py +2 -0
  71. rasa/engine/storage/local_model_storage.py +1 -0
  72. rasa/engine/storage/storage.py +4 -1
  73. rasa/model_manager/runner_service.py +7 -4
  74. rasa/model_manager/socket_bridge.py +7 -6
  75. rasa/shared/constants.py +15 -13
  76. rasa/shared/core/constants.py +2 -0
  77. rasa/shared/core/flows/constants.py +11 -0
  78. rasa/shared/core/flows/flow.py +83 -19
  79. rasa/shared/core/flows/flows_yaml_schema.json +31 -3
  80. rasa/shared/core/flows/steps/collect.py +1 -36
  81. rasa/shared/core/flows/utils.py +28 -4
  82. rasa/shared/core/flows/validation.py +1 -1
  83. rasa/shared/core/slot_mappings.py +208 -5
  84. rasa/shared/core/slots.py +131 -1
  85. rasa/shared/core/trackers.py +74 -1
  86. rasa/shared/importers/importer.py +50 -2
  87. rasa/shared/nlu/training_data/schemas/responses.yml +19 -12
  88. rasa/shared/providers/_configs/azure_entra_id_config.py +541 -0
  89. rasa/shared/providers/_configs/azure_openai_client_config.py +138 -3
  90. rasa/shared/providers/_configs/client_config.py +3 -1
  91. rasa/shared/providers/_configs/default_litellm_client_config.py +3 -1
  92. rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +3 -1
  93. rasa/shared/providers/_configs/litellm_router_client_config.py +3 -1
  94. rasa/shared/providers/_configs/model_group_config.py +4 -2
  95. rasa/shared/providers/_configs/oauth_config.py +33 -0
  96. rasa/shared/providers/_configs/openai_client_config.py +3 -1
  97. rasa/shared/providers/_configs/rasa_llm_client_config.py +3 -1
  98. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +3 -1
  99. rasa/shared/providers/constants.py +6 -0
  100. rasa/shared/providers/embedding/azure_openai_embedding_client.py +28 -3
  101. rasa/shared/providers/embedding/litellm_router_embedding_client.py +3 -1
  102. rasa/shared/providers/llm/_base_litellm_client.py +42 -17
  103. rasa/shared/providers/llm/azure_openai_llm_client.py +81 -25
  104. rasa/shared/providers/llm/default_litellm_llm_client.py +3 -1
  105. rasa/shared/providers/llm/litellm_router_llm_client.py +29 -8
  106. rasa/shared/providers/llm/llm_client.py +23 -7
  107. rasa/shared/providers/llm/openai_llm_client.py +9 -3
  108. rasa/shared/providers/llm/rasa_llm_client.py +11 -2
  109. rasa/shared/providers/llm/self_hosted_llm_client.py +30 -11
  110. rasa/shared/providers/router/_base_litellm_router_client.py +3 -1
  111. rasa/shared/providers/router/router_client.py +3 -1
  112. rasa/shared/utils/constants.py +3 -0
  113. rasa/shared/utils/llm.py +30 -7
  114. rasa/shared/utils/pykwalify_extensions.py +24 -0
  115. rasa/shared/utils/schemas/domain.yml +26 -0
  116. rasa/telemetry.py +2 -1
  117. rasa/tracing/config.py +2 -0
  118. rasa/tracing/constants.py +12 -0
  119. rasa/tracing/instrumentation/instrumentation.py +36 -0
  120. rasa/tracing/instrumentation/metrics.py +41 -0
  121. rasa/tracing/metric_instrument_provider.py +40 -0
  122. rasa/validator.py +372 -7
  123. rasa/version.py +1 -1
  124. {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/METADATA +2 -1
  125. {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/RECORD +128 -113
  126. {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/NOTICE +0 -0
  127. {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/WHEEL +0 -0
  128. {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc1.dist-info}/entry_points.txt +0 -0
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import os
2
4
  import re
3
5
  from typing import Any, Dict, Optional
@@ -21,13 +23,38 @@ from rasa.shared.constants import (
21
23
  )
22
24
  from rasa.shared.exceptions import ProviderClientValidationError
23
25
  from rasa.shared.providers._configs.azure_openai_client_config import (
26
+ AzureEntraIDOAuthConfig,
24
27
  AzureOpenAIClientConfig,
25
28
  )
29
+ from rasa.shared.providers.constants import (
30
+ DEFAULT_AZURE_API_KEY_NAME,
31
+ LITE_LLM_API_BASE_FIELD,
32
+ LITE_LLM_API_KEY_FIELD,
33
+ LITE_LLM_API_VERSION_FIELD,
34
+ LITE_LLM_AZURE_AD_TOKEN,
35
+ )
26
36
  from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
27
37
  from rasa.shared.utils.io import raise_deprecation_warning
28
38
 
29
39
  structlogger = structlog.get_logger()
30
40
 
41
+ AZURE_CLIENT_ID = "AZURE_CLIENT_ID"
42
+ AZURE_CLIENT_SECRET = "AZURE_CLIENT_SECRET"
43
+ AZURE_TENANT_ID = "AZURE_TENANT_ID"
44
+ CLIENT_SECRET_VARS = (AZURE_CLIENT_ID, AZURE_CLIENT_SECRET, AZURE_TENANT_ID)
45
+
46
+ AZURE_CLIENT_CERTIFICATE_PATH = "AZURE_CLIENT_CERTIFICATE_PATH"
47
+ AZURE_CLIENT_CERTIFICATE_PASSWORD = "AZURE_CLIENT_CERTIFICATE_PASSWORD"
48
+ AZURE_CLIENT_SEND_CERTIFICATE_CHAIN = "AZURE_CLIENT_SEND_CERTIFICATE_CHAIN"
49
+ CERT_VARS = (AZURE_CLIENT_ID, AZURE_CLIENT_CERTIFICATE_PATH, AZURE_TENANT_ID)
50
+
51
+
52
+ class AzureADConfig:
53
+ def __init__(
54
+ self, client_id: str, client_secret: str, tenant_id: str, scopes: str
55
+ ) -> None:
56
+ self.scopes = scopes
57
+
31
58
 
32
59
  class AzureOpenAILLMClient(_BaseLiteLLMClient):
33
60
  """A client for interfacing with Azure's OpenAI LLM deployments.
@@ -41,6 +68,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
41
68
  it will be set via environment variables.
42
69
  api_version (Optional[str]): The version of the API to use. If not provided,
43
70
  it will be set via environment variable.
71
+
44
72
  kwargs (Optional[Dict[str, Any]]): Optional configuration parameters specific
45
73
  to the model deployment.
46
74
 
@@ -57,6 +85,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
57
85
  api_type: Optional[str] = None,
58
86
  api_base: Optional[str] = None,
59
87
  api_version: Optional[str] = None,
88
+ oauth: Optional[AzureEntraIDOAuthConfig] = None,
60
89
  **kwargs: Any,
61
90
  ):
62
91
  super().__init__() # type: ignore
@@ -80,8 +109,6 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
80
109
  or os.getenv(OPENAI_API_VERSION_ENV_VAR)
81
110
  )
82
111
 
83
- self._api_key_env_var = self._resolve_api_key_env_var()
84
-
85
112
  # Not used by LiteLLM, here for backward compatibility
86
113
  self._api_type = (
87
114
  api_type
@@ -89,6 +116,19 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
89
116
  or os.getenv(OPENAI_API_TYPE_ENV_VAR)
90
117
  )
91
118
 
119
+ os.unsetenv("OPENAI_API_KEY")
120
+ os.unsetenv("AZURE_API_KEY")
121
+
122
+ self._oauth = oauth
123
+
124
+ if self._oauth:
125
+ os.unsetenv(DEFAULT_AZURE_API_KEY_NAME)
126
+ os.unsetenv(AZURE_API_KEY_ENV_VAR)
127
+ os.unsetenv(OPENAI_API_KEY_ENV_VAR)
128
+ self._api_key_env_var = (
129
+ self._resolve_api_key_env_var() if not self._oauth else None
130
+ )
131
+
92
132
  # Run helper function to check and raise deprecation warning if
93
133
  # deprecated environment variables were used for initialization of the
94
134
  # client settings
@@ -157,7 +197,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
157
197
  return self._extra_parameters[API_KEY]
158
198
 
159
199
  if os.getenv(AZURE_API_KEY_ENV_VAR) is not None:
160
- return "${AZURE_API_KEY}"
200
+ return f"${{{DEFAULT_AZURE_API_KEY_NAME}}}"
161
201
 
162
202
  if os.getenv(OPENAI_API_KEY_ENV_VAR) is not None:
163
203
  # API key can be set through OPENAI_API_KEY too,
@@ -188,7 +228,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
188
228
  )
189
229
 
190
230
  @classmethod
191
- def from_config(cls, config: Dict[str, Any]) -> "AzureOpenAILLMClient":
231
+ def from_config(cls, config: Dict[str, Any]) -> AzureOpenAILLMClient:
192
232
  """Initializes the client from given configuration.
193
233
 
194
234
  Args:
@@ -215,11 +255,12 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
215
255
  raise
216
256
 
217
257
  return cls(
218
- azure_openai_config.deployment,
219
- azure_openai_config.model,
220
- azure_openai_config.api_type,
221
- azure_openai_config.api_base,
222
- azure_openai_config.api_version,
258
+ deployment=azure_openai_config.deployment,
259
+ model=azure_openai_config.model,
260
+ api_type=azure_openai_config.api_type,
261
+ api_base=azure_openai_config.api_base,
262
+ api_version=azure_openai_config.api_version,
263
+ oauth=azure_openai_config.oauth,
223
264
  **azure_openai_config.extra_parameters,
224
265
  )
225
266
 
@@ -234,6 +275,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
234
275
  api_base=self._api_base,
235
276
  api_version=self._api_version,
236
277
  api_type=self._api_type,
278
+ oauth=self._oauth,
237
279
  extra_parameters=self._extra_parameters,
238
280
  )
239
281
  return config.to_dict()
@@ -282,12 +324,23 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
282
324
  """Returns the completion arguments for invoking a call through
283
325
  LiteLLM's completion functions.
284
326
  """
327
+ # Set the API key env var to None if OAuth is used
328
+ auth_parameter: Dict[str, str] = {}
329
+
330
+ if self._oauth:
331
+ auth_parameter = {
332
+ **auth_parameter,
333
+ LITE_LLM_AZURE_AD_TOKEN: self._oauth.get_bearer_token(),
334
+ }
335
+ elif self._api_key_env_var:
336
+ auth_parameter = {LITE_LLM_API_KEY_FIELD: self._api_key_env_var}
337
+
285
338
  fn_args = super()._completion_fn_args
286
339
  fn_args.update(
287
340
  {
288
- "api_base": self.api_base,
289
- "api_version": self.api_version,
290
- "api_key": self._api_key_env_var,
341
+ LITE_LLM_API_BASE_FIELD: self.api_base,
342
+ LITE_LLM_API_VERSION_FIELD: self.api_version,
343
+ **auth_parameter,
291
344
  }
292
345
  )
293
346
  return fn_args
@@ -314,41 +367,44 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
314
367
 
315
368
  return info.format(setting=setting, options=options)
316
369
 
370
+ env_var_field = "env_var"
371
+ config_key_field = "config_key"
372
+ current_value_field = "current_value"
317
373
  # All required settings for Azure OpenAI client
318
374
  settings: Dict[str, Dict[str, Any]] = {
319
375
  "API Base": {
320
- "current_value": self.api_base,
321
- "env_var": AZURE_API_BASE_ENV_VAR,
322
- "config_key": API_BASE_CONFIG_KEY,
376
+ current_value_field: self.api_base,
377
+ env_var_field: AZURE_API_BASE_ENV_VAR,
378
+ config_key_field: API_BASE_CONFIG_KEY,
323
379
  },
324
380
  "API Version": {
325
- "current_value": self.api_version,
326
- "env_var": AZURE_API_VERSION_ENV_VAR,
327
- "config_key": API_VERSION_CONFIG_KEY,
381
+ current_value_field: self.api_version,
382
+ env_var_field: AZURE_API_VERSION_ENV_VAR,
383
+ config_key_field: API_VERSION_CONFIG_KEY,
328
384
  },
329
385
  "Deployment Name": {
330
- "current_value": self.deployment,
331
- "env_var": None,
332
- "config_key": DEPLOYMENT_CONFIG_KEY,
386
+ current_value_field: self.deployment,
387
+ env_var_field: None,
388
+ config_key_field: DEPLOYMENT_CONFIG_KEY,
333
389
  },
334
390
  }
335
391
 
336
392
  missing_settings = [
337
393
  setting_name
338
394
  for setting_name, setting_info in settings.items()
339
- if setting_info["current_value"] is None
395
+ if setting_info[current_value_field] is None
340
396
  ]
341
397
 
342
398
  if missing_settings:
343
399
  event_info = f"Client settings not set: " f"{', '.join(missing_settings)}. "
344
400
 
345
401
  for missing_setting in missing_settings:
346
- if settings[missing_setting]["current_value"] is not None:
402
+ if settings[missing_setting][current_value_field] is not None:
347
403
  continue
348
404
  event_info += generate_event_info_for_missing_setting(
349
405
  missing_setting,
350
- settings[missing_setting]["env_var"],
351
- settings[missing_setting]["config_key"],
406
+ settings[missing_setting][env_var_field],
407
+ settings[missing_setting][config_key_field],
352
408
  )
353
409
 
354
410
  structlogger.error(
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from typing import Any, Dict
2
4
 
3
5
  from rasa.shared.constants import (
@@ -35,7 +37,7 @@ class DefaultLiteLLMClient(_BaseLiteLLMClient):
35
37
  self.validate_client_setup()
36
38
 
37
39
  @classmethod
38
- def from_config(cls, config: Dict[str, Any]) -> "DefaultLiteLLMClient":
40
+ def from_config(cls, config: Dict[str, Any]) -> DefaultLiteLLMClient:
39
41
  default_config = DefaultLiteLLMClientConfig.from_dict(config)
40
42
  return cls(
41
43
  model=default_config.model,
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  from typing import Any, Dict, List, Union
3
5
 
@@ -7,6 +9,7 @@ from rasa.shared.exceptions import ProviderClientAPIException
7
9
  from rasa.shared.providers._configs.litellm_router_client_config import (
8
10
  LiteLLMRouterClientConfig,
9
11
  )
12
+ from rasa.shared.providers.constants import LITE_LLM_MODEL_FIELD
10
13
  from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
11
14
  from rasa.shared.providers.llm.llm_response import LLMResponse
12
15
  from rasa.shared.providers.router._base_litellm_router_client import (
@@ -42,7 +45,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
42
45
  )
43
46
 
44
47
  @classmethod
45
- def from_config(cls, config: Dict[str, Any]) -> "LiteLLMRouterLLMClient":
48
+ def from_config(cls, config: Dict[str, Any]) -> LiteLLMRouterLLMClient:
46
49
  """Instantiates a LiteLLM Router LLM client from a configuration dict.
47
50
 
48
51
  Args:
@@ -87,6 +90,10 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
87
90
  ProviderClientAPIException: If the API request fails.
88
91
  """
89
92
  try:
93
+ structlogger.info(
94
+ "litellm_router_llm_client.text_completion",
95
+ _completion_fn_args=self._completion_fn_args,
96
+ )
90
97
  response = self.router_client.text_completion(
91
98
  prompt=prompt, **self._completion_fn_args
92
99
  )
@@ -115,7 +122,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
115
122
  raise ProviderClientAPIException(e)
116
123
 
117
124
  @suppress_logs(log_level=logging.WARNING)
118
- def completion(self, messages: Union[List[str], str]) -> LLMResponse:
125
+ def completion(self, messages: Union[List[dict], List[str], str]) -> LLMResponse:
119
126
  """
120
127
  Synchronously generate completions for given list of messages.
121
128
 
@@ -125,8 +132,14 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
125
132
  text_completion method is called.
126
133
 
127
134
  Args:
128
- messages: List of messages or a single message to generate the
129
- completion for.
135
+ messages: The message can be,
136
+ - a list of preformatted messages. Each message should be a dictionary
137
+ with the following keys:
138
+ - content: The message content.
139
+ - role: The role of the message (e.g. user or system).
140
+ - a list of messages. Each message is a string and will be formatted
141
+ as a user message.
142
+ - a single message as a string which will be formatted as user message.
130
143
  Returns:
131
144
  List of message completions.
132
145
  Raises:
@@ -144,7 +157,9 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
144
157
  raise ProviderClientAPIException(e)
145
158
 
146
159
  @suppress_logs(log_level=logging.WARNING)
147
- async def acompletion(self, messages: Union[List[str], str]) -> LLMResponse:
160
+ async def acompletion(
161
+ self, messages: Union[List[dict], List[str], str]
162
+ ) -> LLMResponse:
148
163
  """
149
164
  Asynchronously generate completions for given list of messages.
150
165
 
@@ -154,8 +169,14 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
154
169
  text_completion method is called.
155
170
 
156
171
  Args:
157
- messages: List of messages or a single message to generate the
158
- completion for.
172
+ messages: The message can be,
173
+ - a list of preformatted messages. Each message should be a dictionary
174
+ with the following keys:
175
+ - content: The message content.
176
+ - role: The role of the message (e.g. user or system).
177
+ - a list of messages. Each message is a string and will be formatted
178
+ as a user message.
179
+ - a single message as a string which will be formatted as user message.
159
180
  Returns:
160
181
  List of message completions.
161
182
  Raises:
@@ -179,5 +200,5 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
179
200
  """
180
201
  return {
181
202
  **self._litellm_extra_parameters,
182
- "model": self.model_group_id,
203
+ LITE_LLM_MODEL_FIELD: self.model_group_id,
183
204
  }
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from typing import Dict, List, Protocol, Union, runtime_checkable
2
4
 
3
5
  from rasa.shared.providers.llm.llm_response import LLMResponse
@@ -11,7 +13,7 @@ class LLMClient(Protocol):
11
13
  """
12
14
 
13
15
  @classmethod
14
- def from_config(cls, config: dict) -> "LLMClient":
16
+ def from_config(cls, config: dict) -> LLMClient:
15
17
  """
16
18
  Initializes the llm client with the given configuration.
17
19
 
@@ -30,7 +32,7 @@ class LLMClient(Protocol):
30
32
  """
31
33
  ...
32
34
 
33
- def completion(self, messages: Union[List[str], str]) -> LLMResponse:
35
+ def completion(self, messages: Union[List[dict], List[str], str]) -> LLMResponse:
34
36
  """
35
37
  Synchronously generate completions for given list of messages.
36
38
 
@@ -38,14 +40,22 @@ class LLMClient(Protocol):
38
40
  strings) and return a list of completions (as strings).
39
41
 
40
42
  Args:
41
- messages: List of messages or a single message to generate the
42
- completion for.
43
+ messages: The message can be,
44
+ - a list of preformatted messages. Each message should be a dictionary
45
+ with the following keys:
46
+ - content: The message content.
47
+ - role: The role of the message (e.g. user or system).
48
+ - a list of messages. Each message is a string and will be formatted
49
+ as a user message.
50
+ - a single message as a string which will be formatted as user message.
43
51
  Returns:
44
52
  LLMResponse
45
53
  """
46
54
  ...
47
55
 
48
- async def acompletion(self, messages: Union[List[str], str]) -> LLMResponse:
56
+ async def acompletion(
57
+ self, messages: Union[List[dict], List[str], str]
58
+ ) -> LLMResponse:
49
59
  """
50
60
  Asynchronously generate completions for given list of messages.
51
61
 
@@ -53,8 +63,14 @@ class LLMClient(Protocol):
53
63
  strings) and return a list of completions (as strings).
54
64
 
55
65
  Args:
56
- messages: List of messages or a single message to generate the
57
- completion for.
66
+ messages: The message can be,
67
+ - a list of preformatted messages. Each message should be a dictionary
68
+ with the following keys:
69
+ - content: The message content.
70
+ - role: The role of the message (e.g. user or system).
71
+ - a list of messages. Each message is a string and will be formatted
72
+ as a user message.
73
+ - a single message as a string which will be formatted as user message.
58
74
  Returns:
59
75
  LLMResponse
60
76
  """
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import os
2
4
  import re
3
5
  from typing import Any, Dict, Optional
@@ -11,6 +13,10 @@ from rasa.shared.constants import (
11
13
  OPENAI_PROVIDER,
12
14
  )
13
15
  from rasa.shared.providers._configs.openai_client_config import OpenAIClientConfig
16
+ from rasa.shared.providers.constants import (
17
+ LITE_LLM_API_KEY_FIELD,
18
+ LITE_LLM_API_VERSION_FIELD,
19
+ )
14
20
  from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
15
21
 
16
22
  structlogger = structlog.get_logger()
@@ -57,7 +63,7 @@ class OpenAILLMClient(_BaseLiteLLMClient):
57
63
  self.validate_client_setup()
58
64
 
59
65
  @classmethod
60
- def from_config(cls, config: Dict[str, Any]) -> "OpenAILLMClient":
66
+ def from_config(cls, config: Dict[str, Any]) -> OpenAILLMClient:
61
67
  """
62
68
  Initializes the client from given configuration.
63
69
 
@@ -148,8 +154,8 @@ class OpenAILLMClient(_BaseLiteLLMClient):
148
154
  fn_args = super()._completion_fn_args
149
155
  fn_args.update(
150
156
  {
151
- "api_base": self.api_base,
152
- "api_version": self.api_version,
157
+ LITE_LLM_API_KEY_FIELD: self.api_base,
158
+ LITE_LLM_API_VERSION_FIELD: self.api_version,
153
159
  }
154
160
  )
155
161
  return fn_args
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from typing import Any, Dict, Optional
2
4
 
3
5
  import structlog
@@ -9,6 +11,10 @@ from rasa.shared.constants import (
9
11
  from rasa.shared.providers._configs.rasa_llm_client_config import (
10
12
  RasaLLMClientConfig,
11
13
  )
14
+ from rasa.shared.providers.constants import (
15
+ LITE_LLM_API_BASE_FIELD,
16
+ LITE_LLM_API_KEY_FIELD,
17
+ )
12
18
  from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
13
19
  from rasa.utils.licensing import retrieve_license_from_env
14
20
 
@@ -82,12 +88,15 @@ class RasaLLMClient(_BaseLiteLLMClient):
82
88
  """Returns the completion arguments for invoking a call using completions."""
83
89
  fn_args = super()._completion_fn_args
84
90
  fn_args.update(
85
- {"api_base": self.api_base, "api_key": retrieve_license_from_env()}
91
+ {
92
+ LITE_LLM_API_BASE_FIELD: self.api_base,
93
+ LITE_LLM_API_KEY_FIELD: retrieve_license_from_env(),
94
+ }
86
95
  )
87
96
  return fn_args
88
97
 
89
98
  @classmethod
90
- def from_config(cls, config: Dict[str, Any]) -> "RasaLLMClient":
99
+ def from_config(cls, config: Dict[str, Any]) -> RasaLLMClient:
91
100
  try:
92
101
  client_config = RasaLLMClientConfig.from_dict(config)
93
102
  except ValueError as e:
@@ -1,12 +1,11 @@
1
+ from __future__ import annotations
2
+
1
3
  import logging
2
4
  import os
3
5
  from typing import Any, Dict, List, Optional, Union
4
6
 
5
7
  import structlog
6
- from litellm import (
7
- atext_completion,
8
- text_completion,
9
- )
8
+ from litellm import atext_completion, text_completion
10
9
 
11
10
  from rasa.shared.constants import (
12
11
  API_KEY,
@@ -17,6 +16,10 @@ from rasa.shared.exceptions import ProviderClientAPIException
17
16
  from rasa.shared.providers._configs.self_hosted_llm_client_config import (
18
17
  SelfHostedLLMClientConfig,
19
18
  )
19
+ from rasa.shared.providers.constants import (
20
+ LITE_LLM_API_BASE_FIELD,
21
+ LITE_LLM_API_VERSION_FIELD,
22
+ )
20
23
  from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
21
24
  from rasa.shared.providers.llm.llm_response import LLMResponse
22
25
  from rasa.shared.utils.io import suppress_logs
@@ -67,7 +70,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
67
70
  self._apply_dummy_api_key_if_missing()
68
71
 
69
72
  @classmethod
70
- def from_config(cls, config: Dict[str, Any]) -> "SelfHostedLLMClient":
73
+ def from_config(cls, config: Dict[str, Any]) -> SelfHostedLLMClient:
71
74
  try:
72
75
  client_config = SelfHostedLLMClientConfig.from_dict(config)
73
76
  except ValueError as e:
@@ -184,8 +187,8 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
184
187
  fn_args = super()._completion_fn_args
185
188
  fn_args.update(
186
189
  {
187
- "api_base": self.api_base,
188
- "api_version": self.api_version,
190
+ LITE_LLM_API_BASE_FIELD: self.api_base,
191
+ LITE_LLM_API_VERSION_FIELD: self.api_version,
189
192
  }
190
193
  )
191
194
  return fn_args
@@ -214,7 +217,14 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
214
217
  Asynchronously generate completions for given prompt.
215
218
 
216
219
  Args:
217
- prompt: Prompt to generate the completion for.
220
+ messages: The message can be,
221
+ - a list of preformatted messages. Each message should be a dictionary
222
+ with the following keys:
223
+ - content: The message content.
224
+ - role: The role of the message (e.g. user or system).
225
+ - a list of messages. Each message is a string and will be formatted
226
+ as a user message.
227
+ - a single message as a string which will be formatted as user message.
218
228
  Returns:
219
229
  List of message completions.
220
230
  Raises:
@@ -226,7 +236,9 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
226
236
  except Exception as e:
227
237
  raise ProviderClientAPIException(e)
228
238
 
229
- async def acompletion(self, messages: Union[List[str], str]) -> LLMResponse:
239
+ async def acompletion(
240
+ self, messages: Union[List[dict], List[str], str]
241
+ ) -> LLMResponse:
230
242
  """Asynchronous completion of the model with the given messages.
231
243
 
232
244
  Method overrides the base class method to call the appropriate
@@ -235,7 +247,14 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
235
247
  atext_completion method is called.
236
248
 
237
249
  Args:
238
- messages: The messages to be used for completion.
250
+ messages: The message can be,
251
+ - a list of preformatted messages. Each message should be a dictionary
252
+ with the following keys:
253
+ - content: The message content.
254
+ - role: The role of the message (e.g. user or system).
255
+ - a list of messages. Each message is a string and will be formatted
256
+ as a user message.
257
+ - a single message as a string which will be formatted as user message.
239
258
 
240
259
  Returns:
241
260
  The completion response.
@@ -244,7 +263,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
244
263
  return await super().acompletion(messages)
245
264
  return await self._atext_completion(messages)
246
265
 
247
- def completion(self, messages: Union[List[str], str]) -> LLMResponse:
266
+ def completion(self, messages: Union[List[dict], List[str], str]) -> LLMResponse:
248
267
  """Completion of the model with the given messages.
249
268
 
250
269
  Method overrides the base class method to call the appropriate
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import os
2
4
  from typing import Any, Dict, List
3
5
 
@@ -93,7 +95,7 @@ class _BaseLiteLLMRouterClient:
93
95
  return
94
96
 
95
97
  @classmethod
96
- def from_config(cls, config: Dict[str, Any]) -> "_BaseLiteLLMRouterClient":
98
+ def from_config(cls, config: Dict[str, Any]) -> _BaseLiteLLMRouterClient:
97
99
  """Instantiates a LiteLLM Router Embedding client from a configuration dict.
98
100
 
99
101
  Args:
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from typing import Any, Dict, List, Protocol, runtime_checkable
2
4
 
3
5
 
@@ -9,7 +11,7 @@ class RouterClient(Protocol):
9
11
  """
10
12
 
11
13
  @classmethod
12
- def from_config(cls, config: dict) -> "RouterClient":
14
+ def from_config(cls, config: dict) -> RouterClient:
13
15
  """
14
16
  Initializes the router client with the given configuration.
15
17
 
@@ -2,3 +2,6 @@ DEFAULT_ENCODING = "utf-8"
2
2
 
3
3
  READ_YAML_FILE_CACHE_MAXSIZE_ENV_VAR = "READ_YAML_FILE_CACHE_MAXSIZE"
4
4
  DEFAULT_READ_YAML_FILE_CACHE_MAXSIZE = 256
5
+ RASA_PRO_BETA_PREDICATES_IN_RESPONSE_CONDITIONS_ENV_VAR_NAME = (
6
+ "RASA_PRO_BETA_PREDICATES_IN_RESPONSE_CONDITIONS"
7
+ )