rasa-pro 3.14.1__py3-none-any.whl → 3.14.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/constants.py +1 -0
- rasa/core/actions/action_clean_stack.py +32 -0
- rasa/core/actions/constants.py +4 -0
- rasa/core/actions/custom_action_executor.py +70 -12
- rasa/core/actions/grpc_custom_action_executor.py +41 -2
- rasa/core/actions/http_custom_action_executor.py +49 -25
- rasa/core/channels/voice_stream/browser_audio.py +3 -3
- rasa/core/channels/voice_stream/voice_channel.py +27 -17
- rasa/core/config/credentials.py +3 -3
- rasa/core/policies/flows/flow_executor.py +49 -29
- rasa/core/run.py +21 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +6 -3
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +15 -7
- rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +15 -8
- rasa/dialogue_understanding/processor/command_processor.py +13 -7
- rasa/e2e_test/e2e_config.py +4 -3
- rasa/engine/recipes/default_components.py +16 -6
- rasa/graph_components/validators/default_recipe_validator.py +10 -4
- rasa/nlu/classifiers/diet_classifier.py +2 -0
- rasa/shared/core/flows/flow.py +8 -2
- rasa/shared/core/slots.py +55 -24
- rasa/shared/providers/_configs/azure_openai_client_config.py +4 -5
- rasa/shared/providers/_configs/default_litellm_client_config.py +4 -4
- rasa/shared/providers/_configs/litellm_router_client_config.py +3 -2
- rasa/shared/providers/_configs/openai_client_config.py +5 -7
- rasa/shared/providers/_configs/rasa_llm_client_config.py +4 -4
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +4 -4
- rasa/shared/providers/llm/_base_litellm_client.py +42 -14
- rasa/shared/providers/llm/litellm_router_llm_client.py +38 -15
- rasa/shared/providers/llm/self_hosted_llm_client.py +34 -32
- rasa/shared/utils/common.py +9 -1
- rasa/shared/utils/configs.py +5 -8
- rasa/utils/common.py +9 -0
- rasa/utils/endpoints.py +6 -0
- rasa/utils/installation_utils.py +111 -0
- rasa/utils/tensorflow/callback.py +2 -0
- rasa/utils/tensorflow/models.py +3 -0
- rasa/utils/train_utils.py +2 -0
- rasa/version.py +1 -1
- {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/METADATA +2 -2
- {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/RECORD +44 -43
- {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/entry_points.txt +0 -0
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import logging
|
|
4
5
|
from typing import Any, Dict, List, Union
|
|
5
6
|
|
|
6
7
|
import structlog
|
|
7
8
|
|
|
9
|
+
from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
|
|
8
10
|
from rasa.shared.exceptions import ProviderClientAPIException
|
|
9
11
|
from rasa.shared.providers._configs.litellm_router_client_config import (
|
|
10
12
|
LiteLLMRouterClientConfig,
|
|
@@ -79,13 +81,14 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
79
81
|
|
|
80
82
|
@suppress_logs(log_level=logging.WARNING)
|
|
81
83
|
def _text_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
|
|
82
|
-
"""
|
|
83
|
-
Synchronously generate completions for given prompt.
|
|
84
|
+
"""Synchronously generate completions for given prompt.
|
|
84
85
|
|
|
85
86
|
Args:
|
|
86
87
|
prompt: Prompt to generate the completion for.
|
|
88
|
+
|
|
87
89
|
Returns:
|
|
88
90
|
List of message completions.
|
|
91
|
+
|
|
89
92
|
Raises:
|
|
90
93
|
ProviderClientAPIException: If the API request fails.
|
|
91
94
|
"""
|
|
@@ -103,21 +106,30 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
103
106
|
|
|
104
107
|
@suppress_logs(log_level=logging.WARNING)
|
|
105
108
|
async def _atext_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
|
|
106
|
-
"""
|
|
107
|
-
Asynchronously generate completions for given prompt.
|
|
109
|
+
"""Asynchronously generate completions for given prompt.
|
|
108
110
|
|
|
109
111
|
Args:
|
|
110
112
|
prompt: Prompt to generate the completion for.
|
|
113
|
+
|
|
111
114
|
Returns:
|
|
112
115
|
List of message completions.
|
|
116
|
+
|
|
113
117
|
Raises:
|
|
114
118
|
ProviderClientAPIException: If the API request fails.
|
|
115
119
|
"""
|
|
116
120
|
try:
|
|
117
|
-
|
|
118
|
-
|
|
121
|
+
timeout = self._litellm_extra_parameters.get(
|
|
122
|
+
"timeout", DEFAULT_REQUEST_TIMEOUT
|
|
123
|
+
)
|
|
124
|
+
response = await asyncio.wait_for(
|
|
125
|
+
self.router_client.atext_completion(
|
|
126
|
+
prompt=prompt, **self._completion_fn_args
|
|
127
|
+
),
|
|
128
|
+
timeout=timeout,
|
|
119
129
|
)
|
|
120
130
|
return self._format_text_completion_response(response)
|
|
131
|
+
except asyncio.TimeoutError:
|
|
132
|
+
self._handle_timeout_error()
|
|
121
133
|
except Exception as e:
|
|
122
134
|
raise ProviderClientAPIException(e)
|
|
123
135
|
|
|
@@ -125,8 +137,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
125
137
|
def completion(
|
|
126
138
|
self, messages: Union[List[dict], List[str], str], **kwargs: Any
|
|
127
139
|
) -> LLMResponse:
|
|
128
|
-
"""
|
|
129
|
-
Synchronously generate completions for given list of messages.
|
|
140
|
+
"""Synchronously generate completions for given list of messages.
|
|
130
141
|
|
|
131
142
|
Method overrides the base class method to call the appropriate
|
|
132
143
|
completion method based on the configuration. If the chat completions
|
|
@@ -143,8 +154,10 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
143
154
|
as a user message.
|
|
144
155
|
- a single message as a string which will be formatted as user message.
|
|
145
156
|
**kwargs: Additional parameters to pass to the completion call.
|
|
157
|
+
|
|
146
158
|
Returns:
|
|
147
159
|
List of message completions.
|
|
160
|
+
|
|
148
161
|
Raises:
|
|
149
162
|
ProviderClientAPIException: If the API request fails.
|
|
150
163
|
"""
|
|
@@ -163,8 +176,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
163
176
|
async def acompletion(
|
|
164
177
|
self, messages: Union[List[dict], List[str], str], **kwargs: Any
|
|
165
178
|
) -> LLMResponse:
|
|
166
|
-
"""
|
|
167
|
-
Asynchronously generate completions for given list of messages.
|
|
179
|
+
"""Asynchronously generate completions for given list of messages.
|
|
168
180
|
|
|
169
181
|
Method overrides the base class method to call the appropriate
|
|
170
182
|
completion method based on the configuration. If the chat completions
|
|
@@ -181,8 +193,10 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
181
193
|
as a user message.
|
|
182
194
|
- a single message as a string which will be formatted as user message.
|
|
183
195
|
**kwargs: Additional parameters to pass to the completion call.
|
|
196
|
+
|
|
184
197
|
Returns:
|
|
185
198
|
List of message completions.
|
|
199
|
+
|
|
186
200
|
Raises:
|
|
187
201
|
ProviderClientAPIException: If the API request fails.
|
|
188
202
|
"""
|
|
@@ -190,19 +204,28 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
|
|
|
190
204
|
return await self._atext_completion(messages)
|
|
191
205
|
try:
|
|
192
206
|
formatted_messages = self._get_formatted_messages(messages)
|
|
193
|
-
|
|
194
|
-
|
|
207
|
+
timeout = self._litellm_extra_parameters.get(
|
|
208
|
+
"timeout", DEFAULT_REQUEST_TIMEOUT
|
|
209
|
+
)
|
|
210
|
+
response = await asyncio.wait_for(
|
|
211
|
+
self.router_client.acompletion(
|
|
212
|
+
messages=formatted_messages,
|
|
213
|
+
**{**self._completion_fn_args, **kwargs},
|
|
214
|
+
),
|
|
215
|
+
timeout=timeout,
|
|
195
216
|
)
|
|
196
217
|
return self._format_response(response)
|
|
218
|
+
except asyncio.TimeoutError:
|
|
219
|
+
self._handle_timeout_error()
|
|
197
220
|
except Exception as e:
|
|
198
221
|
raise ProviderClientAPIException(e)
|
|
199
222
|
|
|
200
223
|
@property
|
|
201
224
|
def _completion_fn_args(self) -> Dict[str, Any]:
|
|
202
|
-
"""Returns the completion arguments
|
|
203
|
-
LiteLLM's completion functions.
|
|
204
|
-
"""
|
|
225
|
+
"""Returns the completion arguments.
|
|
205
226
|
|
|
227
|
+
For invoking a call through LiteLLM's completion functions.
|
|
228
|
+
"""
|
|
206
229
|
return {
|
|
207
230
|
**self._litellm_extra_parameters,
|
|
208
231
|
LITE_LLM_MODEL_FIELD: self.model_group_id,
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import logging
|
|
4
5
|
import os
|
|
5
6
|
from typing import Any, Dict, List, Optional, Union
|
|
@@ -7,6 +8,7 @@ from typing import Any, Dict, List, Optional, Union
|
|
|
7
8
|
import structlog
|
|
8
9
|
from litellm import atext_completion, text_completion
|
|
9
10
|
|
|
11
|
+
from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
|
|
10
12
|
from rasa.shared.constants import (
|
|
11
13
|
API_KEY,
|
|
12
14
|
SELF_HOSTED_VLLM_API_KEY_ENV_VAR,
|
|
@@ -28,7 +30,7 @@ structlogger = structlog.get_logger()
|
|
|
28
30
|
|
|
29
31
|
|
|
30
32
|
class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
31
|
-
"""A client for interfacing with Self Hosted LLM endpoints
|
|
33
|
+
"""A client for interfacing with Self Hosted LLM endpoints.
|
|
32
34
|
|
|
33
35
|
Parameters:
|
|
34
36
|
model (str): The model or deployment name.
|
|
@@ -95,8 +97,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
95
97
|
|
|
96
98
|
@property
|
|
97
99
|
def provider(self) -> str:
|
|
98
|
-
"""
|
|
99
|
-
Returns the provider name for the self hosted llm client.
|
|
100
|
+
"""Returns the provider name for the self hosted llm client.
|
|
100
101
|
|
|
101
102
|
Returns:
|
|
102
103
|
String representing the provider name.
|
|
@@ -105,8 +106,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
105
106
|
|
|
106
107
|
@property
|
|
107
108
|
def model(self) -> str:
|
|
108
|
-
"""
|
|
109
|
-
Returns the model name for the self hosted llm client.
|
|
109
|
+
"""Returns the model name for the self hosted llm client.
|
|
110
110
|
|
|
111
111
|
Returns:
|
|
112
112
|
String representing the model name.
|
|
@@ -115,8 +115,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
115
115
|
|
|
116
116
|
@property
|
|
117
117
|
def api_base(self) -> str:
|
|
118
|
-
"""
|
|
119
|
-
Returns the base URL for the API endpoint.
|
|
118
|
+
"""Returns the base URL for the API endpoint.
|
|
120
119
|
|
|
121
120
|
Returns:
|
|
122
121
|
String representing the base URL.
|
|
@@ -125,8 +124,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
125
124
|
|
|
126
125
|
@property
|
|
127
126
|
def api_type(self) -> Optional[str]:
|
|
128
|
-
"""
|
|
129
|
-
Returns the type of the API endpoint. Currently only OpenAI is supported.
|
|
127
|
+
"""Returns the type of the API endpoint. Currently only OpenAI is supported.
|
|
130
128
|
|
|
131
129
|
Returns:
|
|
132
130
|
String representing the API type.
|
|
@@ -135,8 +133,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
135
133
|
|
|
136
134
|
@property
|
|
137
135
|
def api_version(self) -> Optional[str]:
|
|
138
|
-
"""
|
|
139
|
-
Returns the version of the API endpoint.
|
|
136
|
+
"""Returns the version of the API endpoint.
|
|
140
137
|
|
|
141
138
|
Returns:
|
|
142
139
|
String representing the API version.
|
|
@@ -145,8 +142,8 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
145
142
|
|
|
146
143
|
@property
|
|
147
144
|
def config(self) -> Dict:
|
|
148
|
-
"""
|
|
149
|
-
|
|
145
|
+
"""Returns the configuration for the self hosted llm client.
|
|
146
|
+
|
|
150
147
|
Returns:
|
|
151
148
|
Dictionary containing the configuration.
|
|
152
149
|
"""
|
|
@@ -163,9 +160,9 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
163
160
|
|
|
164
161
|
@property
|
|
165
162
|
def _litellm_model_name(self) -> str:
|
|
166
|
-
"""Returns the value of LiteLLM's model parameter
|
|
167
|
-
completion/acompletion in LiteLLM format:
|
|
163
|
+
"""Returns the value of LiteLLM's model parameter.
|
|
168
164
|
|
|
165
|
+
To be used in completion/acompletion in LiteLLM format:
|
|
169
166
|
<hosted_vllm>/<model or deployment name>
|
|
170
167
|
"""
|
|
171
168
|
if self.model and f"{SELF_HOSTED_VLLM_PREFIX}/" not in self.model:
|
|
@@ -174,15 +171,17 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
174
171
|
|
|
175
172
|
@property
|
|
176
173
|
def _litellm_extra_parameters(self) -> Dict[str, Any]:
|
|
177
|
-
"""Returns optional configuration parameters
|
|
178
|
-
|
|
174
|
+
"""Returns optional configuration parameters.
|
|
175
|
+
|
|
176
|
+
Specific to the client provider and deployed model.
|
|
179
177
|
"""
|
|
180
178
|
return self._extra_parameters
|
|
181
179
|
|
|
182
180
|
@property
|
|
183
181
|
def _completion_fn_args(self) -> Dict[str, Any]:
|
|
184
|
-
"""Returns the completion arguments
|
|
185
|
-
|
|
182
|
+
"""Returns the completion arguments.
|
|
183
|
+
|
|
184
|
+
For invoking a call through LiteLLM's completion functions.
|
|
186
185
|
"""
|
|
187
186
|
fn_args = super()._completion_fn_args
|
|
188
187
|
fn_args.update(
|
|
@@ -195,13 +194,14 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
195
194
|
|
|
196
195
|
@suppress_logs(log_level=logging.WARNING)
|
|
197
196
|
def _text_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
|
|
198
|
-
"""
|
|
199
|
-
Synchronously generate completions for given prompt.
|
|
197
|
+
"""Synchronously generate completions for given prompt.
|
|
200
198
|
|
|
201
199
|
Args:
|
|
202
200
|
prompt: Prompt to generate the completion for.
|
|
201
|
+
|
|
203
202
|
Returns:
|
|
204
203
|
List of message completions.
|
|
204
|
+
|
|
205
205
|
Raises:
|
|
206
206
|
ProviderClientAPIException: If the API request fails.
|
|
207
207
|
"""
|
|
@@ -213,26 +213,28 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):
|
|
|
213
213
|
|
|
214
214
|
@suppress_logs(log_level=logging.WARNING)
|
|
215
215
|
async def _atext_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
|
|
216
|
-
"""
|
|
217
|
-
Asynchronously generate completions for given prompt.
|
|
216
|
+
"""Asynchronously generate completions for given prompt.
|
|
218
217
|
|
|
219
218
|
Args:
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
with the following keys:
|
|
223
|
-
- content: The message content.
|
|
224
|
-
- role: The role of the message (e.g. user or system).
|
|
225
|
-
- a list of messages. Each message is a string and will be formatted
|
|
226
|
-
as a user message.
|
|
227
|
-
- a single message as a string which will be formatted as user message.
|
|
219
|
+
prompt: Prompt to generate the completion for.
|
|
220
|
+
|
|
228
221
|
Returns:
|
|
229
222
|
List of message completions.
|
|
223
|
+
|
|
230
224
|
Raises:
|
|
231
225
|
ProviderClientAPIException: If the API request fails.
|
|
232
226
|
"""
|
|
233
227
|
try:
|
|
234
|
-
|
|
228
|
+
timeout = self._litellm_extra_parameters.get(
|
|
229
|
+
"timeout", DEFAULT_REQUEST_TIMEOUT
|
|
230
|
+
)
|
|
231
|
+
response = await asyncio.wait_for(
|
|
232
|
+
atext_completion(prompt=prompt, **self._completion_fn_args),
|
|
233
|
+
timeout=timeout,
|
|
234
|
+
)
|
|
235
235
|
return self._format_text_completion_response(response)
|
|
236
|
+
except asyncio.TimeoutError:
|
|
237
|
+
self._handle_timeout_error()
|
|
236
238
|
except Exception as e:
|
|
237
239
|
raise ProviderClientAPIException(e)
|
|
238
240
|
|
rasa/shared/utils/common.py
CHANGED
|
@@ -26,6 +26,7 @@ from rasa.exceptions import MissingDependencyException
|
|
|
26
26
|
from rasa.shared.constants import DOCS_URL_MIGRATION_GUIDE
|
|
27
27
|
from rasa.shared.exceptions import ProviderClientValidationError, RasaException
|
|
28
28
|
from rasa.shared.utils.cli import print_success
|
|
29
|
+
from rasa.utils.installation_utils import check_for_installation_issues
|
|
29
30
|
|
|
30
31
|
logger = logging.getLogger(__name__)
|
|
31
32
|
|
|
@@ -396,7 +397,11 @@ Sign up at: https://feedback.rasa.com
|
|
|
396
397
|
print_success(message)
|
|
397
398
|
|
|
398
399
|
|
|
399
|
-
def conditional_import(
|
|
400
|
+
def conditional_import(
|
|
401
|
+
module_name: str,
|
|
402
|
+
class_name: str,
|
|
403
|
+
check_installation_setup: bool = False,
|
|
404
|
+
) -> Tuple[Any, bool]:
|
|
400
405
|
"""Conditionally import a class, returning (class, is_available) tuple.
|
|
401
406
|
|
|
402
407
|
Args:
|
|
@@ -408,6 +413,9 @@ def conditional_import(module_name: str, class_name: str) -> Tuple[Any, bool]:
|
|
|
408
413
|
or None if import failed, and is_available is a boolean indicating
|
|
409
414
|
whether the import was successful.
|
|
410
415
|
"""
|
|
416
|
+
if check_installation_setup:
|
|
417
|
+
check_for_installation_issues()
|
|
418
|
+
|
|
411
419
|
try:
|
|
412
420
|
module = __import__(module_name, fromlist=[class_name])
|
|
413
421
|
return getattr(module, class_name), True
|
rasa/shared/utils/configs.py
CHANGED
|
@@ -8,8 +8,7 @@ structlogger = structlog.get_logger()
|
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
def resolve_aliases(config: dict, deprecated_alias_mapping: dict) -> dict:
|
|
11
|
-
"""
|
|
12
|
-
Resolve aliases in the configuration to standard keys.
|
|
11
|
+
"""Resolve aliases in the configuration to standard keys.
|
|
13
12
|
|
|
14
13
|
Args:
|
|
15
14
|
config: Dictionary containing the configuration.
|
|
@@ -37,13 +36,13 @@ def raise_deprecation_warnings(
|
|
|
37
36
|
deprecated_alias_mapping: dict,
|
|
38
37
|
source: Optional[str] = None,
|
|
39
38
|
) -> None:
|
|
40
|
-
"""
|
|
41
|
-
Raises warnings for deprecated keys in the configuration.
|
|
39
|
+
"""Raises warnings for deprecated keys in the configuration.
|
|
42
40
|
|
|
43
41
|
Args:
|
|
44
42
|
config: Dictionary containing the configuration.
|
|
45
43
|
deprecated_alias_mapping: Dictionary mapping deprecated keys to
|
|
46
44
|
their standard keys.
|
|
45
|
+
source: Optional source context for the deprecation warning.
|
|
47
46
|
|
|
48
47
|
Raises:
|
|
49
48
|
DeprecationWarning: If any deprecated key is found in the config.
|
|
@@ -61,8 +60,7 @@ def raise_deprecation_warnings(
|
|
|
61
60
|
|
|
62
61
|
|
|
63
62
|
def validate_required_keys(config: dict, required_keys: list) -> None:
|
|
64
|
-
"""
|
|
65
|
-
Validates that the passed config contains all the required keys.
|
|
63
|
+
"""Validates that the passed config contains all the required keys.
|
|
66
64
|
|
|
67
65
|
Args:
|
|
68
66
|
config: Dictionary containing the configuration.
|
|
@@ -84,8 +82,7 @@ def validate_required_keys(config: dict, required_keys: list) -> None:
|
|
|
84
82
|
|
|
85
83
|
|
|
86
84
|
def validate_forbidden_keys(config: dict, forbidden_keys: list) -> None:
|
|
87
|
-
"""
|
|
88
|
-
Validates that the passed config doesn't contain any forbidden keys.
|
|
85
|
+
"""Validates that the passed config doesn't contain any forbidden keys.
|
|
89
86
|
|
|
90
87
|
Args:
|
|
91
88
|
config: Dictionary containing the configuration.
|
rasa/utils/common.py
CHANGED
|
@@ -36,6 +36,7 @@ from rasa.constants import (
|
|
|
36
36
|
ENV_LOG_LEVEL_LIBRARIES,
|
|
37
37
|
ENV_LOG_LEVEL_MATPLOTLIB,
|
|
38
38
|
ENV_LOG_LEVEL_MCP,
|
|
39
|
+
ENV_LOG_LEVEL_PYMONGO,
|
|
39
40
|
ENV_LOG_LEVEL_RABBITMQ,
|
|
40
41
|
ENV_MCP_LOGGING_ENABLED,
|
|
41
42
|
)
|
|
@@ -297,6 +298,7 @@ def configure_library_logging() -> None:
|
|
|
297
298
|
update_rabbitmq_log_level(library_log_level)
|
|
298
299
|
update_websockets_log_level(library_log_level)
|
|
299
300
|
update_mcp_log_level()
|
|
301
|
+
update_pymongo_log_level(library_log_level)
|
|
300
302
|
|
|
301
303
|
|
|
302
304
|
def update_apscheduler_log_level() -> None:
|
|
@@ -481,6 +483,13 @@ def update_mcp_log_level() -> None:
|
|
|
481
483
|
logging.getLogger(logger_name).propagate = False
|
|
482
484
|
|
|
483
485
|
|
|
486
|
+
def update_pymongo_log_level(library_log_level: str) -> None:
    """Configure the level of the ``pymongo`` logger and disable propagation.

    Args:
        library_log_level: Level to fall back to when the environment
            variable named by ``ENV_LOG_LEVEL_PYMONGO`` is not set.
    """
    pymongo_logger = logging.getLogger("pymongo")
    pymongo_logger.setLevel(
        os.environ.get(ENV_LOG_LEVEL_PYMONGO, library_log_level)
    )
    # Keep pymongo records from bubbling up to ancestor loggers.
    pymongo_logger.propagate = False
|
|
491
|
+
|
|
492
|
+
|
|
484
493
|
def sort_list_of_dicts_by_first_key(dicts: List[Dict]) -> List[Dict]:
|
|
485
494
|
"""Sorts a list of dictionaries by their first key."""
|
|
486
495
|
return sorted(dicts, key=lambda d: next(iter(d.keys())))
|
rasa/utils/endpoints.py
CHANGED
|
@@ -9,6 +9,7 @@ import structlog
|
|
|
9
9
|
from aiohttp.client_exceptions import ContentTypeError
|
|
10
10
|
from sanic.request import Request
|
|
11
11
|
|
|
12
|
+
from rasa.core.actions.constants import MISSING_DOMAIN_MARKER
|
|
12
13
|
from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
|
|
13
14
|
from rasa.shared.exceptions import FileNotFoundException
|
|
14
15
|
from rasa.shared.utils.yaml import read_config_file
|
|
@@ -222,6 +223,11 @@ class EndpointConfig:
|
|
|
222
223
|
ssl=sslcontext,
|
|
223
224
|
**kwargs,
|
|
224
225
|
) as response:
|
|
226
|
+
if response.status == 449:
|
|
227
|
+
# Return a special marker that HTTPCustomActionExecutor can detect
|
|
228
|
+
# This avoids raising an exception for this expected case
|
|
229
|
+
return {MISSING_DOMAIN_MARKER: True}
|
|
230
|
+
|
|
225
231
|
if response.status >= 400:
|
|
226
232
|
raise ClientResponseError(
|
|
227
233
|
response.status,
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
import importlib.util
|
|
2
|
+
import sys
|
|
3
|
+
|
|
4
|
+
import structlog
|
|
5
|
+
|
|
6
|
+
structlogger = structlog.get_logger()
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def check_tensorflow_installation() -> None:
    """Warn when TensorFlow is present but Rasa's TensorFlow extras are not.

    Does nothing when ``tensorflow`` cannot be located. Otherwise it looks
    for packages that ship only with the nlu/full extras and logs a warning
    if none of them can be found.
    """
    if importlib.util.find_spec("tensorflow") is None:
        # No TensorFlow in this environment, so there is nothing to report.
        return

    # Packages that are only installed with the nlu/full extras.
    extras_only_packages = ("tensorflow_text", "tensorflow_hub", "tf_keras")
    for package in extras_only_packages:
        if importlib.util.find_spec(package) is not None:
            # One indicator is enough to conclude the extras were installed.
            return

    message = (
        "TensorFlow is installed but Rasa was not installed with TensorFlow "
        "support, i.e. additional packages required to use NLU components "
        "have not been installed. For the most reliable setup, delete your "
        "current virtual environment, create a new one, and install Rasa "
        "again. Please follow the instructions at "
        "https://rasa.com/docs/pro/installation/python"
    )
    structlogger.warning(
        "installation_utils.tensorflow_installation",
        warning=message,
    )
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def check_tensorflow_integrity() -> None:
    """Exit the process when an installed TensorFlow cannot be used.

    Returns immediately if ``tensorflow`` cannot be located. Otherwise the
    package is imported and a small constant is created; any exception raised
    while doing so is logged as an error and the process exits with code 1.
    """
    tensorflow_spec = importlib.util.find_spec("tensorflow")
    if tensorflow_spec is None:
        return

    try:
        # Importing fails if the installation is corrupted.
        import tensorflow as tf

        # Exercise a basic TensorFlow function as a smoke test.
        tf.constant([1, 2, 3])
    except Exception:
        # One simplified message covers every kind of TensorFlow breakage.
        problem = (
            "TensorFlow is installed but appears to be corrupted or incomplete. "
            "For the most reliable setup, delete your current virtual "
            "environment, create a new one, and install Rasa again. "
            "Please follow the instructions at "
            "https://rasa.com/docs/pro/installation/python"
        )
        structlogger.error(
            "installation_utils.tensorflow_integrity",
            issue=problem,
        )
        sys.exit(1)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def check_rasa_availability() -> None:
    """Check if Rasa is installed and importable.

    Logs an error and exits the process with code 1 when the ``rasa``
    package either cannot be located or raises while being imported.
    Returns ``None`` when the import succeeds.
    """
    if importlib.util.find_spec("rasa") is None:
        structlogger.error(
            "installation_utils.rasa_availability",
            issue=(
                "Rasa is not installed in this environment. "
                "Please follow the instructions at "
                "https://rasa.com/docs/pro/installation/python"
            ),
        )
        sys.exit(1)

    try:
        _ = importlib.import_module("rasa")
    except Exception as e:
        structlogger.error(
            "installation_utils.rasa_availability",
            issue=(
                # The trailing space keeps the adjacent literals from being
                # glued together as "...<error>.Please follow...".
                f"Rasa is installed but cannot be imported: {e!s}. "
                "Please follow the instructions at "
                "https://rasa.com/docs/pro/installation/python"
            ),
        )
        sys.exit(1)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def check_for_installation_issues() -> None:
    """Run all installation checks, most critical first.

    Nothing is returned; each check reports problems through the structured
    logger. The Rasa availability and TensorFlow integrity checks terminate
    the process with exit code 1 on failure, whereas the TensorFlow extras
    check only logs a warning.
    """
    # Rasa itself has to be importable before anything else is checked.
    check_rasa_availability()

    # A broken TensorFlow is fatal, so check it before the extras warning.
    check_tensorflow_integrity()

    # Finally, warn about TensorFlow installed without the Rasa extras.
    check_tensorflow_installation()
|
|
@@ -2,9 +2,11 @@ import logging
|
|
|
2
2
|
from pathlib import Path
|
|
3
3
|
from typing import Any, Dict, Optional, Text
|
|
4
4
|
|
|
5
|
+
from rasa.utils.installation_utils import check_for_installation_issues
|
|
5
6
|
from rasa.utils.tensorflow import TENSORFLOW_AVAILABLE
|
|
6
7
|
|
|
7
8
|
if TENSORFLOW_AVAILABLE:
|
|
9
|
+
check_for_installation_issues()
|
|
8
10
|
import tensorflow as tf
|
|
9
11
|
from tqdm import tqdm
|
|
10
12
|
else:
|
rasa/utils/tensorflow/models.py
CHANGED
|
@@ -498,7 +498,10 @@ class RasaModel(Model):
|
|
|
498
498
|
# predict on one data example to speed up prediction during inference
|
|
499
499
|
# the first prediction always takes a bit longer to trace tf function
|
|
500
500
|
if predict_data_example:
|
|
501
|
+
# Warm-up to build any lazily created variables/branches
|
|
501
502
|
model.run_inference(predict_data_example)
|
|
503
|
+
# Reload weights so newly created variables are restored as well
|
|
504
|
+
model.load_weights(model_file_name)
|
|
502
505
|
|
|
503
506
|
logger.debug("Finished loading the model.")
|
|
504
507
|
return model
|
rasa/utils/train_utils.py
CHANGED
|
@@ -11,10 +11,12 @@ from rasa.nlu.constants import NUMBER_OF_SUB_TOKENS
|
|
|
11
11
|
from rasa.shared.constants import NEXT_MAJOR_VERSION_FOR_DEPRECATIONS
|
|
12
12
|
from rasa.shared.exceptions import InvalidConfigException
|
|
13
13
|
from rasa.shared.nlu.constants import SPLIT_ENTITIES_BY_COMMA
|
|
14
|
+
from rasa.utils.installation_utils import check_for_installation_issues
|
|
14
15
|
from rasa.utils.tensorflow import TENSORFLOW_AVAILABLE
|
|
15
16
|
|
|
16
17
|
# Conditional imports for TensorFlow-dependent modules
|
|
17
18
|
if TENSORFLOW_AVAILABLE:
|
|
19
|
+
check_for_installation_issues()
|
|
18
20
|
from rasa.utils.tensorflow.callback import RasaModelCheckpoint, RasaTrainingLogger
|
|
19
21
|
from rasa.utils.tensorflow.constants import (
|
|
20
22
|
AUTO,
|
rasa/version.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: rasa-pro
|
|
3
|
-
Version: 3.14.1
|
|
3
|
+
Version: 3.14.2
|
|
4
4
|
Summary: State-of-the-art open-core Conversational AI framework for Enterprises that natively leverages generative AI for effortless assistant development.
|
|
5
5
|
Keywords: nlp,machine-learning,machine-learning-library,bot,bots,botkit,rasa conversational-agents,conversational-ai,chatbot,chatbot-framework,bot-framework
|
|
6
6
|
Author: Rasa Technologies GmbH
|
|
@@ -102,7 +102,7 @@ Requires-Dist: python-dateutil (>=2.8.2,<2.9.0)
|
|
|
102
102
|
Requires-Dist: python-dotenv (>=1.0.1,<2.0.0)
|
|
103
103
|
Requires-Dist: python-engineio (>=4.12.2,<4.13.0)
|
|
104
104
|
Requires-Dist: python-keycloak (>=5.8.1,<5.9.0)
|
|
105
|
-
Requires-Dist: python-socketio (>=5.
|
|
105
|
+
Requires-Dist: python-socketio (>=5.14.2,<5.15.0)
|
|
106
106
|
Requires-Dist: pytz (>=2022.7.1,<2023.0)
|
|
107
107
|
Requires-Dist: pyyaml (>=6.0.2,<6.1.0)
|
|
108
108
|
Requires-Dist: qdrant-client (>=1.9.1,<1.10.0)
|