rasa-pro 3.12.0.dev13__py3-none-any.whl → 3.12.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +10 -13
- rasa/anonymization/anonymization_rule_executor.py +16 -10
- rasa/cli/data.py +16 -0
- rasa/cli/project_templates/calm/config.yml +2 -2
- rasa/cli/project_templates/calm/domain/list_contacts.yml +1 -2
- rasa/cli/project_templates/calm/domain/remove_contact.yml +1 -2
- rasa/cli/project_templates/calm/domain/shared.yml +1 -4
- rasa/cli/project_templates/calm/endpoints.yml +2 -2
- rasa/cli/utils.py +12 -0
- rasa/core/actions/action.py +84 -191
- rasa/core/actions/action_handle_digressions.py +35 -13
- rasa/core/actions/action_run_slot_rejections.py +16 -4
- rasa/core/channels/__init__.py +2 -0
- rasa/core/channels/studio_chat.py +19 -0
- rasa/core/channels/telegram.py +42 -24
- rasa/core/channels/voice_ready/utils.py +1 -1
- rasa/core/channels/voice_stream/asr/asr_engine.py +10 -4
- rasa/core/channels/voice_stream/asr/azure.py +14 -1
- rasa/core/channels/voice_stream/asr/deepgram.py +20 -4
- rasa/core/channels/voice_stream/audiocodes.py +264 -0
- rasa/core/channels/voice_stream/browser_audio.py +4 -1
- rasa/core/channels/voice_stream/call_state.py +3 -0
- rasa/core/channels/voice_stream/genesys.py +6 -2
- rasa/core/channels/voice_stream/tts/azure.py +9 -1
- rasa/core/channels/voice_stream/tts/cartesia.py +14 -8
- rasa/core/channels/voice_stream/voice_channel.py +23 -2
- rasa/core/constants.py +2 -0
- rasa/core/nlg/contextual_response_rephraser.py +18 -1
- rasa/core/nlg/generator.py +83 -15
- rasa/core/nlg/response.py +6 -3
- rasa/core/nlg/translate.py +55 -0
- rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +1 -1
- rasa/core/policies/flows/flow_executor.py +19 -7
- rasa/core/processor.py +71 -9
- rasa/dialogue_understanding/commands/can_not_handle_command.py +20 -2
- rasa/dialogue_understanding/commands/cancel_flow_command.py +24 -6
- rasa/dialogue_understanding/commands/change_flow_command.py +20 -2
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +20 -2
- rasa/dialogue_understanding/commands/clarify_command.py +29 -3
- rasa/dialogue_understanding/commands/command.py +1 -16
- rasa/dialogue_understanding/commands/command_syntax_manager.py +55 -0
- rasa/dialogue_understanding/commands/handle_digressions_command.py +1 -7
- rasa/dialogue_understanding/commands/human_handoff_command.py +20 -2
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +20 -2
- rasa/dialogue_understanding/commands/prompt_command.py +94 -0
- rasa/dialogue_understanding/commands/repeat_bot_messages_command.py +20 -2
- rasa/dialogue_understanding/commands/set_slot_command.py +24 -2
- rasa/dialogue_understanding/commands/skip_question_command.py +20 -2
- rasa/dialogue_understanding/commands/start_flow_command.py +22 -2
- rasa/dialogue_understanding/commands/utils.py +71 -4
- rasa/dialogue_understanding/generator/__init__.py +2 -0
- rasa/dialogue_understanding/generator/command_parser.py +15 -12
- rasa/dialogue_understanding/generator/constants.py +3 -0
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +12 -5
- rasa/dialogue_understanding/generator/llm_command_generator.py +5 -3
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +17 -3
- rasa/dialogue_understanding/generator/prompt_templates/__init__.py +0 -0
- rasa/dialogue_understanding/generator/{single_step → prompt_templates}/command_prompt_template.jinja2 +2 -0
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_claude_3_5_sonnet_20240620_template.jinja2 +77 -0
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_default.jinja2 +68 -0
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +84 -0
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +522 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +12 -310
- rasa/dialogue_understanding/patterns/collect_information.py +1 -1
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +16 -0
- rasa/dialogue_understanding/patterns/validate_slot.py +65 -0
- rasa/dialogue_understanding/processor/command_processor.py +39 -0
- rasa/dialogue_understanding/stack/utils.py +38 -0
- rasa/dialogue_understanding_test/du_test_case.py +58 -18
- rasa/dialogue_understanding_test/du_test_result.py +14 -10
- rasa/dialogue_understanding_test/io.py +14 -0
- rasa/e2e_test/assertions.py +6 -8
- rasa/e2e_test/llm_judge_prompts/answer_relevance_prompt_template.jinja2 +5 -1
- rasa/e2e_test/llm_judge_prompts/groundedness_prompt_template.jinja2 +4 -0
- rasa/e2e_test/utils/io.py +0 -37
- rasa/engine/graph.py +1 -0
- rasa/engine/language.py +140 -0
- rasa/engine/recipes/config_files/default_config.yml +4 -0
- rasa/engine/recipes/default_recipe.py +2 -0
- rasa/engine/recipes/graph_recipe.py +2 -0
- rasa/engine/storage/local_model_storage.py +1 -0
- rasa/engine/storage/storage.py +4 -1
- rasa/llm_fine_tuning/conversations.py +1 -1
- rasa/model_manager/runner_service.py +7 -4
- rasa/model_manager/socket_bridge.py +7 -6
- rasa/shared/constants.py +15 -13
- rasa/shared/core/constants.py +2 -0
- rasa/shared/core/flows/constants.py +11 -0
- rasa/shared/core/flows/flow.py +83 -19
- rasa/shared/core/flows/flows_yaml_schema.json +31 -3
- rasa/shared/core/flows/steps/collect.py +1 -36
- rasa/shared/core/flows/utils.py +28 -4
- rasa/shared/core/flows/validation.py +1 -1
- rasa/shared/core/slot_mappings.py +208 -5
- rasa/shared/core/slots.py +137 -1
- rasa/shared/core/trackers.py +74 -1
- rasa/shared/importers/importer.py +50 -2
- rasa/shared/nlu/training_data/schemas/responses.yml +19 -12
- rasa/shared/providers/_configs/azure_entra_id_config.py +541 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +138 -3
- rasa/shared/providers/_configs/client_config.py +3 -1
- rasa/shared/providers/_configs/default_litellm_client_config.py +3 -1
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +3 -1
- rasa/shared/providers/_configs/litellm_router_client_config.py +3 -1
- rasa/shared/providers/_configs/model_group_config.py +4 -2
- rasa/shared/providers/_configs/oauth_config.py +33 -0
- rasa/shared/providers/_configs/openai_client_config.py +3 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +3 -1
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +3 -1
- rasa/shared/providers/constants.py +6 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +28 -3
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +3 -1
- rasa/shared/providers/llm/_base_litellm_client.py +42 -17
- rasa/shared/providers/llm/azure_openai_llm_client.py +81 -25
- rasa/shared/providers/llm/default_litellm_llm_client.py +3 -1
- rasa/shared/providers/llm/litellm_router_llm_client.py +29 -8
- rasa/shared/providers/llm/llm_client.py +23 -7
- rasa/shared/providers/llm/openai_llm_client.py +9 -3
- rasa/shared/providers/llm/rasa_llm_client.py +11 -2
- rasa/shared/providers/llm/self_hosted_llm_client.py +30 -11
- rasa/shared/providers/router/_base_litellm_router_client.py +3 -1
- rasa/shared/providers/router/router_client.py +3 -1
- rasa/shared/utils/constants.py +3 -0
- rasa/shared/utils/llm.py +33 -7
- rasa/shared/utils/pykwalify_extensions.py +24 -0
- rasa/shared/utils/schemas/domain.yml +26 -0
- rasa/telemetry.py +2 -1
- rasa/tracing/config.py +2 -0
- rasa/tracing/constants.py +12 -0
- rasa/tracing/instrumentation/instrumentation.py +36 -0
- rasa/tracing/instrumentation/metrics.py +41 -0
- rasa/tracing/metric_instrument_provider.py +40 -0
- rasa/validator.py +372 -7
- rasa/version.py +1 -1
- {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc2.dist-info}/METADATA +13 -14
- {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc2.dist-info}/RECORD +139 -124
- {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.12.0.dev13.dist-info → rasa_pro-3.12.0rc2.dist-info}/entry_points.txt +0 -0
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import os
|
|
2
4
|
import re
|
|
3
5
|
from typing import Any, Dict, Optional
|
|
@@ -21,13 +23,38 @@ from rasa.shared.constants import (
|
|
|
21
23
|
)
|
|
22
24
|
from rasa.shared.exceptions import ProviderClientValidationError
|
|
23
25
|
from rasa.shared.providers._configs.azure_openai_client_config import (
|
|
26
|
+
AzureEntraIDOAuthConfig,
|
|
24
27
|
AzureOpenAIClientConfig,
|
|
25
28
|
)
|
|
29
|
+
from rasa.shared.providers.constants import (
|
|
30
|
+
DEFAULT_AZURE_API_KEY_NAME,
|
|
31
|
+
LITE_LLM_API_BASE_FIELD,
|
|
32
|
+
LITE_LLM_API_KEY_FIELD,
|
|
33
|
+
LITE_LLM_API_VERSION_FIELD,
|
|
34
|
+
LITE_LLM_AZURE_AD_TOKEN,
|
|
35
|
+
)
|
|
26
36
|
from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
|
|
27
37
|
from rasa.shared.utils.io import raise_deprecation_warning
|
|
28
38
|
|
|
29
39
|
structlogger = structlog.get_logger()
|
|
30
40
|
|
|
41
|
+
AZURE_CLIENT_ID = "AZURE_CLIENT_ID"
|
|
42
|
+
AZURE_CLIENT_SECRET = "AZURE_CLIENT_SECRET"
|
|
43
|
+
AZURE_TENANT_ID = "AZURE_TENANT_ID"
|
|
44
|
+
CLIENT_SECRET_VARS = (AZURE_CLIENT_ID, AZURE_CLIENT_SECRET, AZURE_TENANT_ID)
|
|
45
|
+
|
|
46
|
+
AZURE_CLIENT_CERTIFICATE_PATH = "AZURE_CLIENT_CERTIFICATE_PATH"
|
|
47
|
+
AZURE_CLIENT_CERTIFICATE_PASSWORD = "AZURE_CLIENT_CERTIFICATE_PASSWORD"
|
|
48
|
+
AZURE_CLIENT_SEND_CERTIFICATE_CHAIN = "AZURE_CLIENT_SEND_CERTIFICATE_CHAIN"
|
|
49
|
+
CERT_VARS = (AZURE_CLIENT_ID, AZURE_CLIENT_CERTIFICATE_PATH, AZURE_TENANT_ID)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class AzureADConfig:
|
|
53
|
+
def __init__(
|
|
54
|
+
self, client_id: str, client_secret: str, tenant_id: str, scopes: str
|
|
55
|
+
) -> None:
|
|
56
|
+
self.scopes = scopes
|
|
57
|
+
|
|
31
58
|
|
|
32
59
|
class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
33
60
|
"""A client for interfacing with Azure's OpenAI LLM deployments.
|
|
@@ -41,6 +68,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
41
68
|
it will be set via environment variables.
|
|
42
69
|
api_version (Optional[str]): The version of the API to use. If not provided,
|
|
43
70
|
it will be set via environment variable.
|
|
71
|
+
|
|
44
72
|
kwargs (Optional[Dict[str, Any]]): Optional configuration parameters specific
|
|
45
73
|
to the model deployment.
|
|
46
74
|
|
|
@@ -57,6 +85,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
57
85
|
api_type: Optional[str] = None,
|
|
58
86
|
api_base: Optional[str] = None,
|
|
59
87
|
api_version: Optional[str] = None,
|
|
88
|
+
oauth: Optional[AzureEntraIDOAuthConfig] = None,
|
|
60
89
|
**kwargs: Any,
|
|
61
90
|
):
|
|
62
91
|
super().__init__() # type: ignore
|
|
@@ -80,8 +109,6 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
80
109
|
or os.getenv(OPENAI_API_VERSION_ENV_VAR)
|
|
81
110
|
)
|
|
82
111
|
|
|
83
|
-
self._api_key_env_var = self._resolve_api_key_env_var()
|
|
84
|
-
|
|
85
112
|
# Not used by LiteLLM, here for backward compatibility
|
|
86
113
|
self._api_type = (
|
|
87
114
|
api_type
|
|
@@ -89,6 +116,19 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
89
116
|
or os.getenv(OPENAI_API_TYPE_ENV_VAR)
|
|
90
117
|
)
|
|
91
118
|
|
|
119
|
+
os.unsetenv("OPENAI_API_KEY")
|
|
120
|
+
os.unsetenv("AZURE_API_KEY")
|
|
121
|
+
|
|
122
|
+
self._oauth = oauth
|
|
123
|
+
|
|
124
|
+
if self._oauth:
|
|
125
|
+
os.unsetenv(DEFAULT_AZURE_API_KEY_NAME)
|
|
126
|
+
os.unsetenv(AZURE_API_KEY_ENV_VAR)
|
|
127
|
+
os.unsetenv(OPENAI_API_KEY_ENV_VAR)
|
|
128
|
+
self._api_key_env_var = (
|
|
129
|
+
self._resolve_api_key_env_var() if not self._oauth else None
|
|
130
|
+
)
|
|
131
|
+
|
|
92
132
|
# Run helper function to check and raise deprecation warning if
|
|
93
133
|
# deprecated environment variables were used for initialization of the
|
|
94
134
|
# client settings
|
|
@@ -157,7 +197,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
157
197
|
return self._extra_parameters[API_KEY]
|
|
158
198
|
|
|
159
199
|
if os.getenv(AZURE_API_KEY_ENV_VAR) is not None:
|
|
160
|
-
return "${
|
|
200
|
+
return f"${{{DEFAULT_AZURE_API_KEY_NAME}}}"
|
|
161
201
|
|
|
162
202
|
if os.getenv(OPENAI_API_KEY_ENV_VAR) is not None:
|
|
163
203
|
# API key can be set through OPENAI_API_KEY too,
|
|
@@ -188,7 +228,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
188
228
|
)
|
|
189
229
|
|
|
190
230
|
@classmethod
|
|
191
|
-
def from_config(cls, config: Dict[str, Any]) ->
|
|
231
|
+
def from_config(cls, config: Dict[str, Any]) -> AzureOpenAILLMClient:
|
|
192
232
|
"""Initializes the client from given configuration.
|
|
193
233
|
|
|
194
234
|
Args:
|
|
@@ -215,11 +255,12 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
215
255
|
raise
|
|
216
256
|
|
|
217
257
|
return cls(
|
|
218
|
-
azure_openai_config.deployment,
|
|
219
|
-
azure_openai_config.model,
|
|
220
|
-
azure_openai_config.api_type,
|
|
221
|
-
azure_openai_config.api_base,
|
|
222
|
-
azure_openai_config.api_version,
|
|
258
|
+
deployment=azure_openai_config.deployment,
|
|
259
|
+
model=azure_openai_config.model,
|
|
260
|
+
api_type=azure_openai_config.api_type,
|
|
261
|
+
api_base=azure_openai_config.api_base,
|
|
262
|
+
api_version=azure_openai_config.api_version,
|
|
263
|
+
oauth=azure_openai_config.oauth,
|
|
223
264
|
**azure_openai_config.extra_parameters,
|
|
224
265
|
)
|
|
225
266
|
|
|
@@ -234,6 +275,7 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
234
275
|
api_base=self._api_base,
|
|
235
276
|
api_version=self._api_version,
|
|
236
277
|
api_type=self._api_type,
|
|
278
|
+
oauth=self._oauth,
|
|
237
279
|
extra_parameters=self._extra_parameters,
|
|
238
280
|
)
|
|
239
281
|
return config.to_dict()
|
|
@@ -282,12 +324,23 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
282
324
|
"""Returns the completion arguments for invoking a call through
|
|
283
325
|
LiteLLM's completion functions.
|
|
284
326
|
"""
|
|
327
|
+
# Set the API key env var to None if OAuth is used
|
|
328
|
+
auth_parameter: Dict[str, str] = {}
|
|
329
|
+
|
|
330
|
+
if self._oauth:
|
|
331
|
+
auth_parameter = {
|
|
332
|
+
**auth_parameter,
|
|
333
|
+
LITE_LLM_AZURE_AD_TOKEN: self._oauth.get_bearer_token(),
|
|
334
|
+
}
|
|
335
|
+
elif self._api_key_env_var:
|
|
336
|
+
auth_parameter = {LITE_LLM_API_KEY_FIELD: self._api_key_env_var}
|
|
337
|
+
|
|
285
338
|
fn_args = super()._completion_fn_args
|
|
286
339
|
fn_args.update(
|
|
287
340
|
{
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
341
|
+
LITE_LLM_API_BASE_FIELD: self.api_base,
|
|
342
|
+
LITE_LLM_API_VERSION_FIELD: self.api_version,
|
|
343
|
+
**auth_parameter,
|
|
291
344
|
}
|
|
292
345
|
)
|
|
293
346
|
return fn_args
|
|
@@ -314,41 +367,44 @@ class AzureOpenAILLMClient(_BaseLiteLLMClient):
|
|
|
314
367
|
|
|
315
368
|
return info.format(setting=setting, options=options)
|
|
316
369
|
|
|
370
|
+
env_var_field = "env_var"
|
|
371
|
+
config_key_field = "config_key"
|
|
372
|
+
current_value_field = "current_value"
|
|
317
373
|
# All required settings for Azure OpenAI client
|
|
318
374
|
settings: Dict[str, Dict[str, Any]] = {
|
|
319
375
|
"API Base": {
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
376
|
+
current_value_field: self.api_base,
|
|
377
|
+
env_var_field: AZURE_API_BASE_ENV_VAR,
|
|
378
|
+
config_key_field: API_BASE_CONFIG_KEY,
|
|
323
379
|
},
|
|
324
380
|
"API Version": {
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
381
|
+
current_value_field: self.api_version,
|
|
382
|
+
env_var_field: AZURE_API_VERSION_ENV_VAR,
|
|
383
|
+
config_key_field: API_VERSION_CONFIG_KEY,
|
|
328
384
|
},
|
|
329
385
|
"Deployment Name": {
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
386
|
+
current_value_field: self.deployment,
|
|
387
|
+
env_var_field: None,
|
|
388
|
+
config_key_field: DEPLOYMENT_CONFIG_KEY,
|
|
333
389
|
},
|
|
334
390
|
}
|
|
335
391
|
|
|
336
392
|
missing_settings = [
|
|
337
393
|
setting_name
|
|
338
394
|
for setting_name, setting_info in settings.items()
|
|
339
|
-
if setting_info[
|
|
395
|
+
if setting_info[current_value_field] is None
|
|
340
396
|
]
|
|
341
397
|
|
|
342
398
|
if missing_settings:
|
|
343
399
|
event_info = f"Client settings not set: " f"{', '.join(missing_settings)}. "
|
|
344
400
|
|
|
345
401
|
for missing_setting in missing_settings:
|
|
346
|
-
if settings[missing_setting][
|
|
402
|
+
if settings[missing_setting][current_value_field] is not None:
|
|
347
403
|
continue
|
|
348
404
|
event_info += generate_event_info_for_missing_setting(
|
|
349
405
|
missing_setting,
|
|
350
|
-
settings[missing_setting][
|
|
351
|
-
settings[missing_setting][
|
|
406
|
+
settings[missing_setting][env_var_field],
|
|
407
|
+
settings[missing_setting][config_key_field],
|
|
352
408
|
)
|
|
353
409
|
|
|
354
410
|
structlogger.error(
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from typing import Any, Dict
|
|
2
4
|
|
|
3
5
|
from rasa.shared.constants import (
|
|
@@ -35,7 +37,7 @@ class DefaultLiteLLMClient(_BaseLiteLLMClient):
|
|
|
35
37
|
self.validate_client_setup()
|
|
36
38
|
|
|
37
39
|
@classmethod
|
|
38
|
-
def from_config(cls, config: Dict[str, Any]) ->
|
|
40
|
+
def from_config(cls, config: Dict[str, Any]) -> DefaultLiteLLMClient:
|
|
39
41
|
default_config = DefaultLiteLLMClientConfig.from_dict(config)
|
|
40
42
|
return cls(
|
|
41
43
|
model=default_config.model,
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import logging
|
|
2
4
|
from typing import Any, Dict, List, Union
|
|
3
5
|
|
|
@@ -7,6 +9,7 @@ from rasa.shared.exceptions import ProviderClientAPIException
|
|
|
7
9
|
from rasa.shared.providers._configs.litellm_router_client_config import (
|
|
8
10
|
LiteLLMRouterClientConfig,
|
|
9
11
|
)
|
|
12
|
+
from rasa.shared.providers.constants import LITE_LLM_MODEL_FIELD
|
|
10
13
|
from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
|
|
11
14
|
from rasa.shared.providers.llm.llm_response import LLMResponse
|
|
12
15
|
from rasa.shared.providers.router._base_litellm_router_client import (
|
|
@@ -42,7 +45,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
42
45
|
)
|
|
43
46
|
|
|
44
47
|
@classmethod
|
|
45
|
-
def from_config(cls, config: Dict[str, Any]) ->
|
|
48
|
+
def from_config(cls, config: Dict[str, Any]) -> LiteLLMRouterLLMClient:
|
|
46
49
|
"""Instantiates a LiteLLM Router LLM client from a configuration dict.
|
|
47
50
|
|
|
48
51
|
Args:
|
|
@@ -87,6 +90,10 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
87
90
|
ProviderClientAPIException: If the API request fails.
|
|
88
91
|
"""
|
|
89
92
|
try:
|
|
93
|
+
structlogger.info(
|
|
94
|
+
"litellm_router_llm_client.text_completion",
|
|
95
|
+
_completion_fn_args=self._completion_fn_args,
|
|
96
|
+
)
|
|
90
97
|
response = self.router_client.text_completion(
|
|
91
98
|
prompt=prompt, **self._completion_fn_args
|
|
92
99
|
)
|
|
@@ -115,7 +122,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
115
122
|
raise ProviderClientAPIException(e)
|
|
116
123
|
|
|
117
124
|
@suppress_logs(log_level=logging.WARNING)
|
|
118
|
-
def completion(self, messages: Union[List[str], str]) -> LLMResponse:
|
|
125
|
+
def completion(self, messages: Union[List[dict], List[str], str]) -> LLMResponse:
|
|
119
126
|
"""
|
|
120
127
|
Synchronously generate completions for given list of messages.
|
|
121
128
|
|
|
@@ -125,8 +132,14 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
125
132
|
text_completion method is called.
|
|
126
133
|
|
|
127
134
|
Args:
|
|
128
|
-
messages:
|
|
129
|
-
|
|
135
|
+
messages: The message can be,
|
|
136
|
+
- a list of preformatted messages. Each message should be a dictionary
|
|
137
|
+
with the following keys:
|
|
138
|
+
- content: The message content.
|
|
139
|
+
- role: The role of the message (e.g. user or system).
|
|
140
|
+
- a list of messages. Each message is a string and will be formatted
|
|
141
|
+
as a user message.
|
|
142
|
+
- a single message as a string which will be formatted as user message.
|
|
130
143
|
Returns:
|
|
131
144
|
List of message completions.
|
|
132
145
|
Raises:
|
|
@@ -144,7 +157,9 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
144
157
|
raise ProviderClientAPIException(e)
|
|
145
158
|
|
|
146
159
|
@suppress_logs(log_level=logging.WARNING)
|
|
147
|
-
async def acompletion(
|
|
160
|
+
async def acompletion(
|
|
161
|
+
self, messages: Union[List[dict], List[str], str]
|
|
162
|
+
) -> LLMResponse:
|
|
148
163
|
"""
|
|
149
164
|
Asynchronously generate completions for given list of messages.
|
|
150
165
|
|
|
@@ -154,8 +169,14 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
154
169
|
text_completion method is called.
|
|
155
170
|
|
|
156
171
|
Args:
|
|
157
|
-
messages:
|
|
158
|
-
|
|
172
|
+
messages: The message can be,
|
|
173
|
+
- a list of preformatted messages. Each message should be a dictionary
|
|
174
|
+
with the following keys:
|
|
175
|
+
- content: The message content.
|
|
176
|
+
- role: The role of the message (e.g. user or system).
|
|
177
|
+
- a list of messages. Each message is a string and will be formatted
|
|
178
|
+
as a user message.
|
|
179
|
+
- a single message as a string which will be formatted as user message.
|
|
159
180
|
Returns:
|
|
160
181
|
List of message completions.
|
|
161
182
|
Raises:
|
|
@@ -179,5 +200,5 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
179
200
|
"""
|
|
180
201
|
return {
|
|
181
202
|
**self._litellm_extra_parameters,
|
|
182
|
-
|
|
203
|
+
LITE_LLM_MODEL_FIELD: self.model_group_id,
|
|
183
204
|
}
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from typing import Dict, List, Protocol, Union, runtime_checkable
|
|
2
4
|
|
|
3
5
|
from rasa.shared.providers.llm.llm_response import LLMResponse
|
|
@@ -11,7 +13,7 @@ class LLMClient(Protocol):
|
|
|
11
13
|
"""
|
|
12
14
|
|
|
13
15
|
@classmethod
|
|
14
|
-
def from_config(cls, config: dict) ->
|
|
16
|
+
def from_config(cls, config: dict) -> LLMClient:
|
|
15
17
|
"""
|
|
16
18
|
Initializes the llm client with the given configuration.
|
|
17
19
|
|
|
@@ -30,7 +32,7 @@ class LLMClient(Protocol):
|
|
|
30
32
|
"""
|
|
31
33
|
...
|
|
32
34
|
|
|
33
|
-
def completion(self, messages: Union[List[str], str]) -> LLMResponse:
|
|
35
|
+
def completion(self, messages: Union[List[dict], List[str], str]) -> LLMResponse:
|
|
34
36
|
"""
|
|
35
37
|
Synchronously generate completions for given list of messages.
|
|
36
38
|
|
|
@@ -38,14 +40,22 @@ class LLMClient(Protocol):
|
|
|
38
40
|
strings) and return a list of completions (as strings).
|
|
39
41
|
|
|
40
42
|
Args:
|
|
41
|
-
messages:
|
|
42
|
-
|
|
43
|
+
messages: The message can be,
|
|
44
|
+
- a list of preformatted messages. Each message should be a dictionary
|
|
45
|
+
with the following keys:
|
|
46
|
+
- content: The message content.
|
|
47
|
+
- role: The role of the message (e.g. user or system).
|
|
48
|
+
- a list of messages. Each message is a string and will be formatted
|
|
49
|
+
as a user message.
|
|
50
|
+
- a single message as a string which will be formatted as user message.
|
|
43
51
|
Returns:
|
|
44
52
|
LLMResponse
|
|
45
53
|
"""
|
|
46
54
|
...
|
|
47
55
|
|
|
48
|
-
async def acompletion(
|
|
56
|
+
async def acompletion(
|
|
57
|
+
self, messages: Union[List[dict], List[str], str]
|
|
58
|
+
) -> LLMResponse:
|
|
49
59
|
"""
|
|
50
60
|
Asynchronously generate completions for given list of messages.
|
|
51
61
|
|
|
@@ -53,8 +63,14 @@ class LLMClient(Protocol):
|
|
|
53
63
|
strings) and return a list of completions (as strings).
|
|
54
64
|
|
|
55
65
|
Args:
|
|
56
|
-
messages:
|
|
57
|
-
|
|
66
|
+
messages: The message can be,
|
|
67
|
+
- a list of preformatted messages. Each message should be a dictionary
|
|
68
|
+
with the following keys:
|
|
69
|
+
- content: The message content.
|
|
70
|
+
- role: The role of the message (e.g. user or system).
|
|
71
|
+
- a list of messages. Each message is a string and will be formatted
|
|
72
|
+
as a user message.
|
|
73
|
+
- a single message as a string which will be formatted as user message.
|
|
58
74
|
Returns:
|
|
59
75
|
LLMResponse
|
|
60
76
|
"""
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import os
|
|
2
4
|
import re
|
|
3
5
|
from typing import Any, Dict, Optional
|
|
@@ -11,6 +13,10 @@ from rasa.shared.constants import (
|
|
|
11
13
|
OPENAI_PROVIDER,
|
|
12
14
|
)
|
|
13
15
|
from rasa.shared.providers._configs.openai_client_config import OpenAIClientConfig
|
|
16
|
+
from rasa.shared.providers.constants import (
|
|
17
|
+
LITE_LLM_API_BASE_FIELD,
|
|
18
|
+
LITE_LLM_API_VERSION_FIELD,
|
|
19
|
+
)
|
|
14
20
|
from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
|
|
15
21
|
|
|
16
22
|
structlogger = structlog.get_logger()
|
|
@@ -57,7 +63,7 @@ class OpenAILLMClient(_BaseLiteLLMClient):
|
|
|
57
63
|
self.validate_client_setup()
|
|
58
64
|
|
|
59
65
|
@classmethod
|
|
60
|
-
def from_config(cls, config: Dict[str, Any]) ->
|
|
66
|
+
def from_config(cls, config: Dict[str, Any]) -> OpenAILLMClient:
|
|
61
67
|
"""
|
|
62
68
|
Initializes the client from given configuration.
|
|
63
69
|
|
|
@@ -148,8 +154,8 @@ class OpenAILLMClient(_BaseLiteLLMClient):
|
|
|
148
154
|
fn_args = super()._completion_fn_args
|
|
149
155
|
fn_args.update(
|
|
150
156
|
{
|
|
151
|
-
|
|
152
|
-
|
|
157
|
+
LITE_LLM_API_BASE_FIELD: self.api_base,
|
|
158
|
+
LITE_LLM_API_VERSION_FIELD: self.api_version,
|
|
153
159
|
}
|
|
154
160
|
)
|
|
155
161
|
return fn_args
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from typing import Any, Dict, Optional
|
|
2
4
|
|
|
3
5
|
import structlog
|
|
@@ -9,6 +11,10 @@ from rasa.shared.constants import (
|
|
|
9
11
|
from rasa.shared.providers._configs.rasa_llm_client_config import (
|
|
10
12
|
RasaLLMClientConfig,
|
|
11
13
|
)
|
|
14
|
+
from rasa.shared.providers.constants import (
|
|
15
|
+
LITE_LLM_API_BASE_FIELD,
|
|
16
|
+
LITE_LLM_API_KEY_FIELD,
|
|
17
|
+
)
|
|
12
18
|
from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
|
|
13
19
|
from rasa.utils.licensing import retrieve_license_from_env
|
|
14
20
|
|
|
@@ -82,12 +88,15 @@ class RasaLLMClient(_BaseLiteLLMClient):
|
|
|
82
88
|
"""Returns the completion arguments for invoking a call using completions."""
|
|
83
89
|
fn_args = super()._completion_fn_args
|
|
84
90
|
fn_args.update(
|
|
85
|
-
{
|
|
91
|
+
{
|
|
92
|
+
LITE_LLM_API_BASE_FIELD: self.api_base,
|
|
93
|
+
LITE_LLM_API_KEY_FIELD: retrieve_license_from_env(),
|
|
94
|
+
}
|
|
86
95
|
)
|
|
87
96
|
return fn_args
|
|
88
97
|
|
|
89
98
|
@classmethod
|
|
90
|
-
def from_config(cls, config: Dict[str, Any]) ->
|
|
99
|
+
def from_config(cls, config: Dict[str, Any]) -> RasaLLMClient:
|
|
91
100
|
try:
|
|
92
101
|
client_config = RasaLLMClientConfig.from_dict(config)
|
|
93
102
|
except ValueError as e:
|
|
@@ -1,12 +1,11 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import logging
|
|
2
4
|
import os
|
|
3
5
|
from typing import Any, Dict, List, Optional, Union
|
|
4
6
|
|
|
5
7
|
import structlog
|
|
6
|
-
from litellm import
|
|
7
|
-
atext_completion,
|
|
8
|
-
text_completion,
|
|
9
|
-
)
|
|
8
|
+
from litellm import atext_completion, text_completion
|
|
10
9
|
|
|
11
10
|
from rasa.shared.constants import (
|
|
12
11
|
API_KEY,
|
|
@@ -17,6 +16,10 @@ from rasa.shared.exceptions import ProviderClientAPIException
|
|
|
17
16
|
from rasa.shared.providers._configs.self_hosted_llm_client_config import (
|
|
18
17
|
SelfHostedLLMClientConfig,
|
|
19
18
|
)
|
|
19
|
+
from rasa.shared.providers.constants import (
|
|
20
|
+
LITE_LLM_API_BASE_FIELD,
|
|
21
|
+
LITE_LLM_API_VERSION_FIELD,
|
|
22
|
+
)
|
|
20
23
|
from rasa.shared.providers.llm._base_litellm_client import _BaseLiteLLMClient
|
|
21
24
|
from rasa.shared.providers.llm.llm_response import LLMResponse
|
|
22
25
|
from rasa.shared.utils.io import suppress_logs
|
|
@@ -67,7 +70,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
67
70
|
self._apply_dummy_api_key_if_missing()
|
|
68
71
|
|
|
69
72
|
@classmethod
|
|
70
|
-
def from_config(cls, config: Dict[str, Any]) ->
|
|
73
|
+
def from_config(cls, config: Dict[str, Any]) -> SelfHostedLLMClient:
|
|
71
74
|
try:
|
|
72
75
|
client_config = SelfHostedLLMClientConfig.from_dict(config)
|
|
73
76
|
except ValueError as e:
|
|
@@ -184,8 +187,8 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
184
187
|
fn_args = super()._completion_fn_args
|
|
185
188
|
fn_args.update(
|
|
186
189
|
{
|
|
187
|
-
|
|
188
|
-
|
|
190
|
+
LITE_LLM_API_BASE_FIELD: self.api_base,
|
|
191
|
+
LITE_LLM_API_VERSION_FIELD: self.api_version,
|
|
189
192
|
}
|
|
190
193
|
)
|
|
191
194
|
return fn_args
|
|
@@ -214,7 +217,14 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
214
217
|
Asynchronously generate completions for given prompt.
|
|
215
218
|
|
|
216
219
|
Args:
|
|
217
|
-
|
|
220
|
+
messages: The message can be,
|
|
221
|
+
- a list of preformatted messages. Each message should be a dictionary
|
|
222
|
+
with the following keys:
|
|
223
|
+
- content: The message content.
|
|
224
|
+
- role: The role of the message (e.g. user or system).
|
|
225
|
+
- a list of messages. Each message is a string and will be formatted
|
|
226
|
+
as a user message.
|
|
227
|
+
- a single message as a string which will be formatted as user message.
|
|
218
228
|
Returns:
|
|
219
229
|
List of message completions.
|
|
220
230
|
Raises:
|
|
@@ -226,7 +236,9 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
226
236
|
except Exception as e:
|
|
227
237
|
raise ProviderClientAPIException(e)
|
|
228
238
|
|
|
229
|
-
async def acompletion(
|
|
239
|
+
async def acompletion(
|
|
240
|
+
self, messages: Union[List[dict], List[str], str]
|
|
241
|
+
) -> LLMResponse:
|
|
230
242
|
"""Asynchronous completion of the model with the given messages.
|
|
231
243
|
|
|
232
244
|
Method overrides the base class method to call the appropriate
|
|
@@ -235,7 +247,14 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
235
247
|
atext_completion method is called.
|
|
236
248
|
|
|
237
249
|
Args:
|
|
238
|
-
messages: The
|
|
250
|
+
messages: The message can be,
|
|
251
|
+
- a list of preformatted messages. Each message should be a dictionary
|
|
252
|
+
with the following keys:
|
|
253
|
+
- content: The message content.
|
|
254
|
+
- role: The role of the message (e.g. user or system).
|
|
255
|
+
- a list of messages. Each message is a string and will be formatted
|
|
256
|
+
as a user message.
|
|
257
|
+
- a single message as a string which will be formatted as user message.
|
|
239
258
|
|
|
240
259
|
Returns:
|
|
241
260
|
The completion response.
|
|
@@ -244,7 +263,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
244
263
|
return await super().acompletion(messages)
|
|
245
264
|
return await self._atext_completion(messages)
|
|
246
265
|
|
|
247
|
-
def completion(self, messages: Union[List[str], str]) -> LLMResponse:
|
|
266
|
+
def completion(self, messages: Union[List[dict], List[str], str]) -> LLMResponse:
|
|
248
267
|
"""Completion of the model with the given messages.
|
|
249
268
|
|
|
250
269
|
Method overrides the base class method to call the appropriate
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import os
|
|
2
4
|
from typing import Any, Dict, List
|
|
3
5
|
|
|
@@ -93,7 +95,7 @@ class _BaseLiteLLMRouterClient:
|
|
|
93
95
|
return
|
|
94
96
|
|
|
95
97
|
@classmethod
|
|
96
|
-
def from_config(cls, config: Dict[str, Any]) ->
|
|
98
|
+
def from_config(cls, config: Dict[str, Any]) -> _BaseLiteLLMRouterClient:
|
|
97
99
|
"""Instantiates a LiteLLM Router Embedding client from a configuration dict.
|
|
98
100
|
|
|
99
101
|
Args:
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
from typing import Any, Dict, List, Protocol, runtime_checkable
|
|
2
4
|
|
|
3
5
|
|
|
@@ -9,7 +11,7 @@ class RouterClient(Protocol):
|
|
|
9
11
|
"""
|
|
10
12
|
|
|
11
13
|
@classmethod
|
|
12
|
-
def from_config(cls, config: dict) ->
|
|
14
|
+
def from_config(cls, config: dict) -> RouterClient:
|
|
13
15
|
"""
|
|
14
16
|
Initializes the router client with the given configuration.
|
|
15
17
|
|
rasa/shared/utils/constants.py
CHANGED