rasa-pro 3.14.1__py3-none-any.whl → 3.14.2__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.

Files changed (44)
  1. rasa/constants.py +1 -0
  2. rasa/core/actions/action_clean_stack.py +32 -0
  3. rasa/core/actions/constants.py +4 -0
  4. rasa/core/actions/custom_action_executor.py +70 -12
  5. rasa/core/actions/grpc_custom_action_executor.py +41 -2
  6. rasa/core/actions/http_custom_action_executor.py +49 -25
  7. rasa/core/channels/voice_stream/browser_audio.py +3 -3
  8. rasa/core/channels/voice_stream/voice_channel.py +27 -17
  9. rasa/core/config/credentials.py +3 -3
  10. rasa/core/policies/flows/flow_executor.py +49 -29
  11. rasa/core/run.py +21 -5
  12. rasa/dialogue_understanding/generator/llm_based_command_generator.py +6 -3
  13. rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +15 -7
  14. rasa/dialogue_understanding/generator/single_step/search_ready_llm_command_generator.py +15 -8
  15. rasa/dialogue_understanding/processor/command_processor.py +13 -7
  16. rasa/e2e_test/e2e_config.py +4 -3
  17. rasa/engine/recipes/default_components.py +16 -6
  18. rasa/graph_components/validators/default_recipe_validator.py +10 -4
  19. rasa/nlu/classifiers/diet_classifier.py +2 -0
  20. rasa/shared/core/flows/flow.py +8 -2
  21. rasa/shared/core/slots.py +55 -24
  22. rasa/shared/providers/_configs/azure_openai_client_config.py +4 -5
  23. rasa/shared/providers/_configs/default_litellm_client_config.py +4 -4
  24. rasa/shared/providers/_configs/litellm_router_client_config.py +3 -2
  25. rasa/shared/providers/_configs/openai_client_config.py +5 -7
  26. rasa/shared/providers/_configs/rasa_llm_client_config.py +4 -4
  27. rasa/shared/providers/_configs/self_hosted_llm_client_config.py +4 -4
  28. rasa/shared/providers/llm/_base_litellm_client.py +42 -14
  29. rasa/shared/providers/llm/litellm_router_llm_client.py +38 -15
  30. rasa/shared/providers/llm/self_hosted_llm_client.py +34 -32
  31. rasa/shared/utils/common.py +9 -1
  32. rasa/shared/utils/configs.py +5 -8
  33. rasa/utils/common.py +9 -0
  34. rasa/utils/endpoints.py +6 -0
  35. rasa/utils/installation_utils.py +111 -0
  36. rasa/utils/tensorflow/callback.py +2 -0
  37. rasa/utils/tensorflow/models.py +3 -0
  38. rasa/utils/train_utils.py +2 -0
  39. rasa/version.py +1 -1
  40. {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/METADATA +2 -2
  41. {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/RECORD +44 -43
  42. {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/NOTICE +0 -0
  43. {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/WHEEL +0 -0
  44. {rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/entry_points.txt +0 -0

rasa/shared/providers/llm/litellm_router_llm_client.py CHANGED
@@ -1,10 +1,12 @@
 from __future__ import annotations

+import asyncio
 import logging
 from typing import Any, Dict, List, Union

 import structlog

+from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
 from rasa.shared.exceptions import ProviderClientAPIException
 from rasa.shared.providers._configs.litellm_router_client_config import (
     LiteLLMRouterClientConfig,
@@ -79,13 +81,14 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):

     @suppress_logs(log_level=logging.WARNING)
     def _text_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
-        """
-        Synchronously generate completions for given prompt.
+        """Synchronously generate completions for given prompt.

         Args:
             prompt: Prompt to generate the completion for.
+
         Returns:
             List of message completions.
+
         Raises:
             ProviderClientAPIException: If the API request fails.
         """
@@ -103,21 +106,30 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):

     @suppress_logs(log_level=logging.WARNING)
     async def _atext_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
-        """
-        Asynchronously generate completions for given prompt.
+        """Asynchronously generate completions for given prompt.

         Args:
             prompt: Prompt to generate the completion for.
+
         Returns:
             List of message completions.
+
         Raises:
             ProviderClientAPIException: If the API request fails.
         """
         try:
-            response = await self.router_client.atext_completion(
-                prompt=prompt, **self._completion_fn_args
+            timeout = self._litellm_extra_parameters.get(
+                "timeout", DEFAULT_REQUEST_TIMEOUT
+            )
+            response = await asyncio.wait_for(
+                self.router_client.atext_completion(
+                    prompt=prompt, **self._completion_fn_args
+                ),
+                timeout=timeout,
             )
             return self._format_text_completion_response(response)
+        except asyncio.TimeoutError:
+            self._handle_timeout_error()
         except Exception as e:
             raise ProviderClientAPIException(e)

@@ -125,8 +137,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
     def completion(
         self, messages: Union[List[dict], List[str], str], **kwargs: Any
     ) -> LLMResponse:
-        """
-        Synchronously generate completions for given list of messages.
+        """Synchronously generate completions for given list of messages.

         Method overrides the base class method to call the appropriate
         completion method based on the configuration. If the chat completions
@@ -143,8 +154,10 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
             as a user message.
             - a single message as a string which will be formatted as user message.
             **kwargs: Additional parameters to pass to the completion call.
+
         Returns:
             List of message completions.
+
         Raises:
             ProviderClientAPIException: If the API request fails.
         """
@@ -163,8 +176,7 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
     async def acompletion(
         self, messages: Union[List[dict], List[str], str], **kwargs: Any
     ) -> LLMResponse:
-        """
-        Asynchronously generate completions for given list of messages.
+        """Asynchronously generate completions for given list of messages.

         Method overrides the base class method to call the appropriate
         completion method based on the configuration. If the chat completions
@@ -181,8 +193,10 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
             as a user message.
             - a single message as a string which will be formatted as user message.
             **kwargs: Additional parameters to pass to the completion call.
+
         Returns:
             List of message completions.
+
         Raises:
             ProviderClientAPIException: If the API request fails.
         """
@@ -190,19 +204,28 @@ class LiteLLMRouterLLMClient(_BaseLiteLLMRouterClient, _BaseLiteLLMClient):
             return await self._atext_completion(messages)
         try:
             formatted_messages = self._get_formatted_messages(messages)
-            response = await self.router_client.acompletion(
-                messages=formatted_messages, **{**self._completion_fn_args, **kwargs}
+            timeout = self._litellm_extra_parameters.get(
+                "timeout", DEFAULT_REQUEST_TIMEOUT
+            )
+            response = await asyncio.wait_for(
+                self.router_client.acompletion(
+                    messages=formatted_messages,
+                    **{**self._completion_fn_args, **kwargs},
+                ),
+                timeout=timeout,
             )
             return self._format_response(response)
+        except asyncio.TimeoutError:
+            self._handle_timeout_error()
         except Exception as e:
             raise ProviderClientAPIException(e)

     @property
     def _completion_fn_args(self) -> Dict[str, Any]:
-        """Returns the completion arguments for invoking a call through
-        LiteLLM's completion functions.
-        """
+        """Returns the completion arguments.

+        For invoking a call through LiteLLM's completion functions.
+        """
         return {
             **self._litellm_extra_parameters,
             LITE_LLM_MODEL_FIELD: self.model_group_id,
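
Both async paths above now bound the router call with asyncio.wait_for, reading "timeout" from the client's extra parameters and falling back to DEFAULT_REQUEST_TIMEOUT. A minimal, self-contained sketch of that pattern follows; the stand-in coroutine and the 300-second default are assumptions, the real default lives in rasa.core.constants.

import asyncio

DEFAULT_REQUEST_TIMEOUT = 300  # assumed value of rasa.core.constants.DEFAULT_REQUEST_TIMEOUT


async def slow_llm_call() -> str:
    # Stand-in for router_client.atext_completion(...) / acompletion(...).
    await asyncio.sleep(600)
    return "completion"


async def bounded_completion(extra_parameters: dict) -> str:
    # Same shape as the new client code: a per-client "timeout" wins,
    # otherwise the global default applies; wait_for cancels the call
    # once the deadline passes.
    timeout = extra_parameters.get("timeout", DEFAULT_REQUEST_TIMEOUT)
    try:
        return await asyncio.wait_for(slow_llm_call(), timeout=timeout)
    except asyncio.TimeoutError:
        # The real clients delegate to self._handle_timeout_error() here.
        raise TimeoutError(f"LLM call did not finish within {timeout}s")


# asyncio.run(bounded_completion({"timeout": 2.0}))  # times out after ~2s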

rasa/shared/providers/llm/self_hosted_llm_client.py CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import asyncio
 import logging
 import os
 from typing import Any, Dict, List, Optional, Union
@@ -7,6 +8,7 @@ from typing import Any, Dict, List, Optional, Union
 import structlog
 from litellm import atext_completion, text_completion

+from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
 from rasa.shared.constants import (
     API_KEY,
     SELF_HOSTED_VLLM_API_KEY_ENV_VAR,
@@ -28,7 +30,7 @@ structlogger = structlog.get_logger()


 class SelfHostedLLMClient(_BaseLiteLLMClient):
-    """A client for interfacing with Self Hosted LLM endpoints that uses
+    """A client for interfacing with Self Hosted LLM endpoints.

     Parameters:
         model (str): The model or deployment name.
@@ -95,8 +97,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):

     @property
     def provider(self) -> str:
-        """
-        Returns the provider name for the self hosted llm client.
+        """Returns the provider name for the self hosted llm client.

         Returns:
             String representing the provider name.
@@ -105,8 +106,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):

     @property
     def model(self) -> str:
-        """
-        Returns the model name for the self hosted llm client.
+        """Returns the model name for the self hosted llm client.

         Returns:
             String representing the model name.
@@ -115,8 +115,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):

     @property
     def api_base(self) -> str:
-        """
-        Returns the base URL for the API endpoint.
+        """Returns the base URL for the API endpoint.

         Returns:
             String representing the base URL.
@@ -125,8 +124,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):

     @property
     def api_type(self) -> Optional[str]:
-        """
-        Returns the type of the API endpoint. Currently only OpenAI is supported.
+        """Returns the type of the API endpoint. Currently only OpenAI is supported.

         Returns:
             String representing the API type.
@@ -135,8 +133,7 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):

     @property
     def api_version(self) -> Optional[str]:
-        """
-        Returns the version of the API endpoint.
+        """Returns the version of the API endpoint.

         Returns:
             String representing the API version.
@@ -145,8 +142,8 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):

     @property
     def config(self) -> Dict:
-        """
-        Returns the configuration for the self hosted llm client.
+        """Returns the configuration for the self hosted llm client.
+
         Returns:
             Dictionary containing the configuration.
         """
@@ -163,9 +160,9 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):

     @property
     def _litellm_model_name(self) -> str:
-        """Returns the value of LiteLLM's model parameter to be used in
-        completion/acompletion in LiteLLM format:
+        """Returns the value of LiteLLM's model parameter.

+        To be used in completion/acompletion in LiteLLM format:
         <hosted_vllm>/<model or deployment name>
         """
         if self.model and f"{SELF_HOSTED_VLLM_PREFIX}/" not in self.model:
@@ -174,15 +171,17 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):

     @property
     def _litellm_extra_parameters(self) -> Dict[str, Any]:
-        """Returns optional configuration parameters specific
-        to the client provider and deployed model.
+        """Returns optional configuration parameters.
+
+        Specific to the client provider and deployed model.
         """
         return self._extra_parameters

     @property
     def _completion_fn_args(self) -> Dict[str, Any]:
-        """Returns the completion arguments for invoking a call through
-        LiteLLM's completion functions.
+        """Returns the completion arguments.
+
+        For invoking a call through LiteLLM's completion functions.
         """
         fn_args = super()._completion_fn_args
         fn_args.update(
@@ -195,13 +194,14 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):

     @suppress_logs(log_level=logging.WARNING)
     def _text_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
-        """
-        Synchronously generate completions for given prompt.
+        """Synchronously generate completions for given prompt.

         Args:
             prompt: Prompt to generate the completion for.
+
         Returns:
             List of message completions.
+
         Raises:
             ProviderClientAPIException: If the API request fails.
         """
@@ -213,26 +213,28 @@ class SelfHostedLLMClient(_BaseLiteLLMClient):

     @suppress_logs(log_level=logging.WARNING)
     async def _atext_completion(self, prompt: Union[List[str], str]) -> LLMResponse:
-        """
-        Asynchronously generate completions for given prompt.
+        """Asynchronously generate completions for given prompt.

         Args:
-            messages: The message can be,
-            - a list of preformatted messages. Each message should be a dictionary
-            with the following keys:
-            - content: The message content.
-            - role: The role of the message (e.g. user or system).
-            - a list of messages. Each message is a string and will be formatted
-            as a user message.
-            - a single message as a string which will be formatted as user message.
+            prompt: Prompt to generate the completion for.
+
         Returns:
             List of message completions.
+
         Raises:
             ProviderClientAPIException: If the API request fails.
         """
         try:
-            response = await atext_completion(prompt=prompt, **self._completion_fn_args)
+            timeout = self._litellm_extra_parameters.get(
+                "timeout", DEFAULT_REQUEST_TIMEOUT
+            )
+            response = await asyncio.wait_for(
+                atext_completion(prompt=prompt, **self._completion_fn_args),
+                timeout=timeout,
+            )
             return self._format_text_completion_response(response)
+        except asyncio.TimeoutError:
+            self._handle_timeout_error()
         except Exception as e:
             raise ProviderClientAPIException(e)


rasa/shared/utils/common.py CHANGED
@@ -26,6 +26,7 @@ from rasa.exceptions import MissingDependencyException
 from rasa.shared.constants import DOCS_URL_MIGRATION_GUIDE
 from rasa.shared.exceptions import ProviderClientValidationError, RasaException
 from rasa.shared.utils.cli import print_success
+from rasa.utils.installation_utils import check_for_installation_issues

 logger = logging.getLogger(__name__)

@@ -396,7 +397,11 @@ Sign up at: https://feedback.rasa.com
     print_success(message)


-def conditional_import(module_name: str, class_name: str) -> Tuple[Any, bool]:
+def conditional_import(
+    module_name: str,
+    class_name: str,
+    check_installation_setup: bool = False,
+) -> Tuple[Any, bool]:
     """Conditionally import a class, returning (class, is_available) tuple.

     Args:
@@ -408,6 +413,9 @@ def conditional_import(module_name: str, class_name: str) -> Tuple[Any, bool]:
         or None if import failed, and is_available is a boolean indicating
         whether the import was successful.
     """
+    if check_installation_setup:
+        check_for_installation_issues()
+
     try:
         module = __import__(module_name, fromlist=[class_name])
         return getattr(module, class_name), True
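
A hypothetical call site for the new check_installation_setup flag; the real callers are plausibly the component imports updated in rasa/engine/recipes/default_components.py (also changed in this release), so the component named here is only illustrative.

from rasa.shared.utils.common import conditional_import

# Run the environment checks once before attempting a TensorFlow-backed import.
DIETClassifier, is_available = conditional_import(
    "rasa.nlu.classifiers.diet_classifier",
    "DIETClassifier",
    check_installation_setup=True,
)

if is_available:
    ...  # safe to register and use the component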

rasa/shared/utils/configs.py CHANGED
@@ -8,8 +8,7 @@ structlogger = structlog.get_logger()


 def resolve_aliases(config: dict, deprecated_alias_mapping: dict) -> dict:
-    """
-    Resolve aliases in the configuration to standard keys.
+    """Resolve aliases in the configuration to standard keys.

     Args:
         config: Dictionary containing the configuration.
@@ -37,13 +36,13 @@ def raise_deprecation_warnings(
     deprecated_alias_mapping: dict,
     source: Optional[str] = None,
 ) -> None:
-    """
-    Raises warnings for deprecated keys in the configuration.
+    """Raises warnings for deprecated keys in the configuration.

     Args:
         config: Dictionary containing the configuration.
         deprecated_alias_mapping: Dictionary mapping deprecated keys to
             their standard keys.
+        source: Optional source context for the deprecation warning.

     Raises:
         DeprecationWarning: If any deprecated key is found in the config.
@@ -61,8 +60,7 @@ def raise_deprecation_warnings(


 def validate_required_keys(config: dict, required_keys: list) -> None:
-    """
-    Validates that the passed config contains all the required keys.
+    """Validates that the passed config contains all the required keys.

     Args:
         config: Dictionary containing the configuration.
@@ -84,8 +82,7 @@ def validate_required_keys(config: dict, required_keys: list) -> None:


 def validate_forbidden_keys(config: dict, forbidden_keys: list) -> None:
-    """
-    Validates that the passed config doesn't contain any forbidden keys.
+    """Validates that the passed config doesn't contain any forbidden keys.

     Args:
         config: Dictionary containing the configuration.

rasa/utils/common.py CHANGED
@@ -36,6 +36,7 @@ from rasa.constants import (
     ENV_LOG_LEVEL_LIBRARIES,
     ENV_LOG_LEVEL_MATPLOTLIB,
     ENV_LOG_LEVEL_MCP,
+    ENV_LOG_LEVEL_PYMONGO,
     ENV_LOG_LEVEL_RABBITMQ,
     ENV_MCP_LOGGING_ENABLED,
 )
@@ -297,6 +298,7 @@ def configure_library_logging() -> None:
     update_rabbitmq_log_level(library_log_level)
     update_websockets_log_level(library_log_level)
     update_mcp_log_level()
+    update_pymongo_log_level(library_log_level)


 def update_apscheduler_log_level() -> None:
@@ -481,6 +483,13 @@ def update_mcp_log_level() -> None:
         logging.getLogger(logger_name).propagate = False


+def update_pymongo_log_level(library_log_level: str) -> None:
+    """Set the log level of pymongo."""
+    log_level = os.environ.get(ENV_LOG_LEVEL_PYMONGO, library_log_level)
+    logging.getLogger("pymongo").setLevel(log_level)
+    logging.getLogger("pymongo").propagate = False
+
+
 def sort_list_of_dicts_by_first_key(dicts: List[Dict]) -> List[Dict]:
     """Sorts a list of dictionaries by their first key."""
     return sorted(dicts, key=lambda d: next(iter(d.keys())))
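
The new pymongo logger follows the same precedence as the other library loggers: a dedicated environment variable wins, otherwise the general library log level applies. A standalone sketch of that precedence; the literal name "LOG_LEVEL_PYMONGO" is an assumption, since the diff only shows the constant ENV_LOG_LEVEL_PYMONGO being added to rasa/constants.py.

import logging
import os

ENV_LOG_LEVEL_PYMONGO = "LOG_LEVEL_PYMONGO"  # assumed value of the new constant


def update_pymongo_log_level(library_log_level: str) -> None:
    # The environment variable overrides the general library log level.
    log_level = os.environ.get(ENV_LOG_LEVEL_PYMONGO, library_log_level)
    logging.getLogger("pymongo").setLevel(log_level)
    logging.getLogger("pymongo").propagate = False


os.environ[ENV_LOG_LEVEL_PYMONGO] = "ERROR"  # silence pymongo specifically
update_pymongo_log_level("WARNING")          # other libraries stay at WARNING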

rasa/utils/endpoints.py CHANGED
@@ -9,6 +9,7 @@ import structlog
 from aiohttp.client_exceptions import ContentTypeError
 from sanic.request import Request

+from rasa.core.actions.constants import MISSING_DOMAIN_MARKER
 from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
 from rasa.shared.exceptions import FileNotFoundException
 from rasa.shared.utils.yaml import read_config_file
@@ -222,6 +223,11 @@ class EndpointConfig:
                 ssl=sslcontext,
                 **kwargs,
             ) as response:
+                if response.status == 449:
+                    # Return a special marker that HTTPCustomActionExecutor can detect
+                    # This avoids raising an exception for this expected case
+                    return {MISSING_DOMAIN_MARKER: True}
+
                 if response.status >= 400:
                     raise ClientResponseError(
                         response.status,
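
On the consumer side, HTTP 449 ("Retry With") is now treated as a signal that the action server needs the domain rather than as an error. A hypothetical sketch of that handling; the actual logic lives in rasa/core/actions/http_custom_action_executor.py (changed above), and the retry payload here is illustrative.

from rasa.core.actions.constants import MISSING_DOMAIN_MARKER


async def call_action_server(endpoint, payload: dict) -> dict:
    response = await endpoint.request(json=payload, method="post")
    if isinstance(response, dict) and response.get(MISSING_DOMAIN_MARKER):
        # The action server answered 449: retry once with the domain attached.
        payload_with_domain = {**payload, "domain": {}}  # attach the real domain here
        response = await endpoint.request(json=payload_with_domain, method="post")
    return response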

rasa/utils/installation_utils.py ADDED
@@ -0,0 +1,111 @@
+import importlib.util
+import sys
+
+import structlog
+
+structlogger = structlog.get_logger()
+
+
+def check_tensorflow_installation() -> None:
+    """Check if TensorFlow is installed without proper Rasa extras."""
+    # Check if tensorflow is available in the environment
+    tensorflow_available = importlib.util.find_spec("tensorflow") is not None
+
+    if not tensorflow_available:
+        return
+
+    # Check if any TensorFlow-related extras were installed
+    # We do this by checking for packages that are only installed with nlu/full extras
+    tensorflow_extras_indicators = [
+        "tensorflow_text",  # Only in nlu/full extras
+        "tensorflow_hub",  # Only in nlu/full extras
+        "tf_keras",  # Only in nlu/full extras
+    ]
+
+    extras_installed = any(
+        importlib.util.find_spec(pkg) is not None
+        for pkg in tensorflow_extras_indicators
+    )
+
+    if tensorflow_available and not extras_installed:
+        structlogger.warning(
+            "installation_utils.tensorflow_installation",
+            warning=(
+                "TensorFlow is installed but Rasa was not installed with TensorFlow "
+                "support, i.e. additional packages required to use NLU components "
+                "have not been installed. For the most reliable setup, delete your "
+                "current virtual environment, create a new one, and install Rasa "
+                "again. Please follow the instructions at "
+                "https://rasa.com/docs/pro/installation/python"
+            ),
+        )
+
+
+def check_tensorflow_integrity() -> None:
+    """Check if TensorFlow installation is corrupted or incomplete."""
+    # Only check if tensorflow is available
+    if importlib.util.find_spec("tensorflow") is None:
+        return
+
+    try:
+        # Try to import tensorflow - this will fail if installation is corrupted
+        import tensorflow as tf
+
+        # Try to access a basic TensorFlow function
+        _ = tf.constant([1, 2, 3])
+    except Exception:
+        # Simplified error message for all TensorFlow corruption issues
+        structlogger.error(
+            "installation_utils.tensorflow_integrity",
+            issue=(
+                "TensorFlow is installed but appears to be corrupted or incomplete. "
+                "For the most reliable setup, delete your current virtual "
+                "environment, create a new one, and install Rasa again. "
+                "Please follow the instructions at "
+                "https://rasa.com/docs/pro/installation/python"
+            ),
+        )
+        sys.exit(1)
+
+
+def check_rasa_availability() -> None:
+    """Check if Rasa is installed and importable."""
+    if importlib.util.find_spec("rasa") is None:
+        structlogger.error(
+            "installation_utils.rasa_availability",
+            issue=(
+                "Rasa is not installed in this environment. "
+                "Please follow the instructions at "
+                "https://rasa.com/docs/pro/installation/python"
+            ),
+        )
+        sys.exit(1)
+
+    try:
+        _ = importlib.import_module("rasa")
+    except Exception as e:
+        structlogger.error(
+            "installation_utils.rasa_availability",
+            issue=(
+                f"Rasa is installed but cannot be imported: {e!s}."
+                f"Please follow the instructions at "
+                f"https://rasa.com/docs/pro/installation/python"
+            ),
+        )
+        sys.exit(1)
+
+
+def check_for_installation_issues() -> None:
+    """Check for all potential installation issues.
+
+    Returns:
+        List of warning messages for detected issues.
+    """
+    # Check if Rasa is available first
+    check_rasa_availability()
+
+    # Check TensorFlow integrity first (more critical)
+    check_tensorflow_integrity()
+
+    # Check for orphaned TensorFlow
+    check_tensorflow_installation()
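
The module can also be invoked directly; in this release it is wired into conditional_import (see above), rasa/utils/train_utils.py, and rasa/utils/tensorflow/callback.py. A minimal usage sketch:

from rasa.utils.installation_utils import check_for_installation_issues

# Exits the process (status 1) if rasa itself cannot be imported or if an
# installed TensorFlow is corrupted; only logs a warning when TensorFlow is
# present without the nlu/full extras.
check_for_installation_issues()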

rasa/utils/tensorflow/callback.py CHANGED
@@ -2,9 +2,11 @@ import logging
 from pathlib import Path
 from typing import Any, Dict, Optional, Text

+from rasa.utils.installation_utils import check_for_installation_issues
 from rasa.utils.tensorflow import TENSORFLOW_AVAILABLE

 if TENSORFLOW_AVAILABLE:
+    check_for_installation_issues()
     import tensorflow as tf
     from tqdm import tqdm
 else:

rasa/utils/tensorflow/models.py CHANGED
@@ -498,7 +498,10 @@ class RasaModel(Model):
         # predict on one data example to speed up prediction during inference
         # the first prediction always takes a bit longer to trace tf function
         if predict_data_example:
+            # Warm-up to build any lazily created variables/branches
             model.run_inference(predict_data_example)
+            # Reload weights so newly created variables are restored as well
+            model.load_weights(model_file_name)

         logger.debug("Finished loading the model.")
         return model
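
The first inference call can create variables lazily while tracing the tf.function, and weights loaded before that point would not cover them; reloading afterwards restores those variables too. A generic sketch of the pattern, assuming a Keras model that builds part of its state on first call (not the exact RasaModel loading code):

import tensorflow as tf


def load_with_warmup(
    model: tf.keras.Model, weights_path: str, example_batch
) -> tf.keras.Model:
    model.load_weights(weights_path)  # restores variables that already exist
    model(example_batch)              # warm-up: may create lazy variables/branches
    model.load_weights(weights_path)  # second load covers the new variables too
    return model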

rasa/utils/train_utils.py CHANGED
@@ -11,10 +11,12 @@ from rasa.nlu.constants import NUMBER_OF_SUB_TOKENS
 from rasa.shared.constants import NEXT_MAJOR_VERSION_FOR_DEPRECATIONS
 from rasa.shared.exceptions import InvalidConfigException
 from rasa.shared.nlu.constants import SPLIT_ENTITIES_BY_COMMA
+from rasa.utils.installation_utils import check_for_installation_issues
 from rasa.utils.tensorflow import TENSORFLOW_AVAILABLE

 # Conditional imports for TensorFlow-dependent modules
 if TENSORFLOW_AVAILABLE:
+    check_for_installation_issues()
     from rasa.utils.tensorflow.callback import RasaModelCheckpoint, RasaTrainingLogger
     from rasa.utils.tensorflow.constants import (
         AUTO,

rasa/version.py CHANGED
@@ -1,3 +1,3 @@
 # this file will automatically be changed,
 # do not add anything but the version number here!
-__version__ = "3.14.1"
+__version__ = "3.14.2"

{rasa_pro-3.14.1.dist-info → rasa_pro-3.14.2.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rasa-pro
-Version: 3.14.1
+Version: 3.14.2
 Summary: State-of-the-art open-core Conversational AI framework for Enterprises that natively leverages generative AI for effortless assistant development.
 Keywords: nlp,machine-learning,machine-learning-library,bot,bots,botkit,rasa conversational-agents,conversational-ai,chatbot,chatbot-framework,bot-framework
 Author: Rasa Technologies GmbH
@@ -102,7 +102,7 @@ Requires-Dist: python-dateutil (>=2.8.2,<2.9.0)
 Requires-Dist: python-dotenv (>=1.0.1,<2.0.0)
 Requires-Dist: python-engineio (>=4.12.2,<4.13.0)
 Requires-Dist: python-keycloak (>=5.8.1,<5.9.0)
-Requires-Dist: python-socketio (>=5.13,<6)
+Requires-Dist: python-socketio (>=5.14.2,<5.15.0)
 Requires-Dist: pytz (>=2022.7.1,<2023.0)
 Requires-Dist: pyyaml (>=6.0.2,<6.1.0)
 Requires-Dist: qdrant-client (>=1.9.1,<1.10.0)