rasa-pro 3.10.7.dev5__py3-none-any.whl → 3.10.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

Files changed (76)
  1. README.md +37 -1
  2. rasa/api.py +2 -8
  3. rasa/cli/arguments/default_arguments.py +2 -23
  4. rasa/cli/arguments/run.py +0 -2
  5. rasa/cli/e2e_test.py +8 -10
  6. rasa/cli/inspect.py +2 -5
  7. rasa/cli/run.py +0 -7
  8. rasa/cli/studio/studio.py +21 -1
  9. rasa/cli/train.py +4 -9
  10. rasa/cli/utils.py +3 -3
  11. rasa/core/agent.py +2 -2
  12. rasa/core/brokers/kafka.py +1 -3
  13. rasa/core/brokers/pika.py +1 -3
  14. rasa/core/channels/socketio.py +1 -5
  15. rasa/core/channels/voice_aware/utils.py +5 -6
  16. rasa/core/nlg/contextual_response_rephraser.py +2 -11
  17. rasa/core/policies/enterprise_search_policy.py +2 -11
  18. rasa/core/policies/intentless_policy.py +2 -9
  19. rasa/core/run.py +1 -2
  20. rasa/core/secrets_manager/constants.py +0 -4
  21. rasa/core/secrets_manager/factory.py +0 -8
  22. rasa/core/secrets_manager/vault.py +1 -11
  23. rasa/core/utils.py +19 -30
  24. rasa/dialogue_understanding/coexistence/llm_based_router.py +2 -9
  25. rasa/dialogue_understanding/commands/__init__.py +2 -0
  26. rasa/dialogue_understanding/commands/restart_command.py +58 -0
  27. rasa/dialogue_understanding/commands/set_slot_command.py +5 -1
  28. rasa/dialogue_understanding/commands/utils.py +3 -1
  29. rasa/dialogue_understanding/generator/llm_based_command_generator.py +2 -11
  30. rasa/dialogue_understanding/generator/llm_command_generator.py +1 -1
  31. rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +15 -15
  32. rasa/dialogue_understanding/patterns/restart.py +37 -0
  33. rasa/e2e_test/e2e_test_runner.py +1 -1
  34. rasa/engine/graph.py +1 -0
  35. rasa/engine/recipes/config_files/default_config.yml +3 -0
  36. rasa/engine/recipes/default_recipe.py +1 -0
  37. rasa/engine/recipes/graph_recipe.py +1 -0
  38. rasa/engine/storage/local_model_storage.py +1 -0
  39. rasa/engine/storage/storage.py +5 -1
  40. rasa/model_training.py +6 -11
  41. rasa/{core → nlu}/persistor.py +1 -1
  42. rasa/server.py +1 -1
  43. rasa/shared/constants.py +3 -2
  44. rasa/shared/core/domain.py +47 -101
  45. rasa/shared/core/flows/flows_list.py +6 -19
  46. rasa/shared/core/flows/validation.py +0 -25
  47. rasa/shared/core/flows/yaml_flows_io.py +24 -3
  48. rasa/shared/importers/importer.py +32 -32
  49. rasa/shared/importers/multi_project.py +11 -23
  50. rasa/shared/importers/rasa.py +2 -7
  51. rasa/shared/importers/remote_importer.py +2 -2
  52. rasa/shared/importers/utils.py +1 -3
  53. rasa/shared/nlu/training_data/training_data.py +19 -18
  54. rasa/shared/providers/_configs/azure_openai_client_config.py +5 -3
  55. rasa/shared/providers/llm/_base_litellm_client.py +26 -10
  56. rasa/shared/providers/llm/self_hosted_llm_client.py +15 -3
  57. rasa/shared/utils/common.py +22 -3
  58. rasa/shared/utils/llm.py +5 -29
  59. rasa/shared/utils/schemas/model_config.yml +10 -0
  60. rasa/studio/auth.py +4 -0
  61. rasa/tracing/instrumentation/attribute_extractors.py +1 -1
  62. rasa/validator.py +5 -2
  63. rasa/version.py +1 -1
  64. {rasa_pro-3.10.7.dev5.dist-info → rasa_pro-3.10.8.dist-info}/METADATA +43 -7
  65. {rasa_pro-3.10.7.dev5.dist-info → rasa_pro-3.10.8.dist-info}/RECORD +68 -74
  66. rasa/model_manager/__init__.py +0 -0
  67. rasa/model_manager/config.py +0 -12
  68. rasa/model_manager/model_api.py +0 -467
  69. rasa/model_manager/runner_service.py +0 -185
  70. rasa/model_manager/socket_bridge.py +0 -44
  71. rasa/model_manager/trainer_service.py +0 -240
  72. rasa/model_manager/utils.py +0 -27
  73. rasa/model_service.py +0 -66
  74. {rasa_pro-3.10.7.dev5.dist-info → rasa_pro-3.10.8.dist-info}/NOTICE +0 -0
  75. {rasa_pro-3.10.7.dev5.dist-info → rasa_pro-3.10.8.dist-info}/WHEEL +0 -0
  76. {rasa_pro-3.10.7.dev5.dist-info → rasa_pro-3.10.8.dist-info}/entry_points.txt +0 -0
@@ -23,7 +23,6 @@ from rasa.shared.core.training_data.structures import StoryGraph
23
23
  from rasa.shared.nlu.constants import ACTION_NAME, ENTITIES
24
24
  from rasa.shared.nlu.training_data.message import Message
25
25
  from rasa.shared.nlu.training_data.training_data import TrainingData
26
- from rasa.shared.utils.common import cached_method
27
26
  from rasa.shared.utils.yaml import read_config_file
28
27
 
29
28
  logger = logging.getLogger(__name__)
@@ -114,7 +113,7 @@ class TrainingDataImporter(ABC):
114
113
  config_path: Text,
115
114
  domain_path: Optional[Text] = None,
116
115
  training_data_paths: Optional[List[Text]] = None,
117
- args: Optional[Dict[Text, Any]] = None,
116
+ args: Optional[Dict[Text, Any]] = {},
118
117
  ) -> "TrainingDataImporter":
119
118
  """Loads a `TrainingDataImporter` instance from a configuration file."""
120
119
  config = read_config_file(config_path)
@@ -127,7 +126,7 @@ class TrainingDataImporter(ABC):
127
126
  config_path: Text,
128
127
  domain_path: Optional[Text] = None,
129
128
  training_data_paths: Optional[List[Text]] = None,
130
- args: Optional[Dict[Text, Any]] = None,
129
+ args: Optional[Dict[Text, Any]] = {},
131
130
  ) -> "TrainingDataImporter":
132
131
  """Loads core `TrainingDataImporter` instance.
133
132
 
@@ -143,7 +142,7 @@ class TrainingDataImporter(ABC):
143
142
  config_path: Text,
144
143
  domain_path: Optional[Text] = None,
145
144
  training_data_paths: Optional[List[Text]] = None,
146
- args: Optional[Dict[Text, Any]] = None,
145
+ args: Optional[Dict[Text, Any]] = {},
147
146
  ) -> "TrainingDataImporter":
148
147
  """Loads nlu `TrainingDataImporter` instance.
149
148
 
@@ -227,8 +226,7 @@ class TrainingDataImporter(ABC):
227
226
  **constructor_arguments,
228
227
  )
229
228
 
230
- @staticmethod
231
- def fingerprint() -> Text:
229
+ def fingerprint(self) -> Text:
232
230
  """Returns a random fingerprint as data shouldn't be cached."""
233
231
  return rasa.shared.utils.io.random_string(25)
234
232
 
@@ -284,6 +282,7 @@ class NluDataImporter(TrainingDataImporter):
284
282
  """Retrieves NLU training data (see parent class for full docstring)."""
285
283
  return self._importer.get_nlu_data(language)
286
284
 
285
+ @rasa.shared.utils.common.cached_method
287
286
  def get_config_file_for_auto_config(self) -> Optional[Text]:
288
287
  """Returns config file path for auto-config only if there is a single one."""
289
288
  return self._importer.get_config_file_for_auto_config()
@@ -299,14 +298,14 @@ class CombinedDataImporter(TrainingDataImporter):
299
298
  def __init__(self, importers: List[TrainingDataImporter]):
300
299
  self._importers = importers
301
300
 
302
- @cached_method
301
+ @rasa.shared.utils.common.cached_method
303
302
  def get_config(self) -> Dict:
304
303
  """Retrieves model config (see parent class for full docstring)."""
305
304
  configs = [importer.get_config() for importer in self._importers]
306
305
 
307
306
  return reduce(lambda merged, other: {**merged, **(other or {})}, configs, {})
308
307
 
309
- @cached_method
308
+ @rasa.shared.utils.common.cached_method
310
309
  def get_domain(self) -> Domain:
311
310
  """Retrieves model domain (see parent class for full docstring)."""
312
311
  domains = [importer.get_domain() for importer in self._importers]
@@ -317,7 +316,7 @@ class CombinedDataImporter(TrainingDataImporter):
317
316
  Domain.empty(),
318
317
  )
319
318
 
320
- @cached_method
319
+ @rasa.shared.utils.common.cached_method
321
320
  def get_stories(self, exclusion_percentage: Optional[int] = None) -> StoryGraph:
322
321
  """Retrieves training stories / rules (see parent class for full docstring)."""
323
322
  stories = [
@@ -328,7 +327,7 @@ class CombinedDataImporter(TrainingDataImporter):
328
327
  lambda merged, other: merged.merge(other), stories, StoryGraph([])
329
328
  )
330
329
 
331
- @cached_method
330
+ @rasa.shared.utils.common.cached_method
332
331
  def get_flows(self) -> FlowsList:
333
332
  """Retrieves training stories / rules (see parent class for full docstring)."""
334
333
  flow_lists = [importer.get_flows() for importer in self._importers]
@@ -339,7 +338,7 @@ class CombinedDataImporter(TrainingDataImporter):
339
338
  FlowsList(underlying_flows=[]),
340
339
  )
341
340
 
342
- @cached_method
341
+ @rasa.shared.utils.common.cached_method
343
342
  def get_conversation_tests(self) -> StoryGraph:
344
343
  """Retrieves conversation test stories (see parent class for full docstring)."""
345
344
  stories = [importer.get_conversation_tests() for importer in self._importers]
@@ -348,7 +347,7 @@ class CombinedDataImporter(TrainingDataImporter):
348
347
  lambda merged, other: merged.merge(other), stories, StoryGraph([])
349
348
  )
350
349
 
351
- @cached_method
350
+ @rasa.shared.utils.common.cached_method
352
351
  def get_nlu_data(self, language: Optional[Text] = "en") -> TrainingData:
353
352
  """Retrieves NLU training data (see parent class for full docstring)."""
354
353
  nlu_data = [importer.get_nlu_data(language) for importer in self._importers]
@@ -357,7 +356,7 @@ class CombinedDataImporter(TrainingDataImporter):
357
356
  lambda merged, other: merged.merge(other), nlu_data, TrainingData()
358
357
  )
359
358
 
360
- @cached_method
359
+ @rasa.shared.utils.common.cached_method
361
360
  def get_config_file_for_auto_config(self) -> Optional[Text]:
362
361
  """Returns config file path for auto-config only if there is a single one."""
363
362
  if len(self._importers) != 1:
@@ -416,22 +415,26 @@ class FlowSyncImporter(PassThroughImporter):
416
415
  """Loads the default flows from the file system."""
417
416
  from rasa.shared.core.flows.yaml_flows_io import YAMLFlowsReader
418
417
 
419
- flows = YAMLFlowsReader.read_from_file(FlowSyncImporter.default_pattern_path())
418
+ default_flows_file = str(
419
+ importlib_resources.files("rasa.dialogue_understanding.patterns").joinpath(
420
+ DEFAULT_PATTERN_FLOWS_FILE_NAME
421
+ )
422
+ )
423
+
424
+ flows = YAMLFlowsReader.read_from_file(default_flows_file)
420
425
  flows.validate()
421
426
  return flows
422
427
 
423
428
  @staticmethod
424
- def default_pattern_path() -> str:
425
- return str(
429
+ def load_default_pattern_flows_domain() -> Domain:
430
+ """Loads the default flows from the file system."""
431
+ default_flows_file = str(
426
432
  importlib_resources.files("rasa.dialogue_understanding.patterns").joinpath(
427
433
  DEFAULT_PATTERN_FLOWS_FILE_NAME
428
434
  )
429
435
  )
430
436
 
431
- @staticmethod
432
- def load_default_pattern_flows_domain() -> Domain:
433
- """Loads the default flows from the file system."""
434
- return Domain.from_path(FlowSyncImporter.default_pattern_path())
437
+ return Domain.from_path(default_flows_file)
435
438
 
436
439
  @classmethod
437
440
  def merge_with_default_flows(cls, flows: FlowsList) -> FlowsList:
@@ -457,7 +460,7 @@ class FlowSyncImporter(PassThroughImporter):
457
460
 
458
461
  return flows.merge(FlowsList(missing_default_flows))
459
462
 
460
- @cached_method
463
+ @rasa.shared.utils.common.cached_method
461
464
  def get_flows(self) -> FlowsList:
462
465
  flows = self._importer.get_flows()
463
466
 
@@ -467,11 +470,11 @@ class FlowSyncImporter(PassThroughImporter):
467
470
 
468
471
  return self.merge_with_default_flows(flows)
469
472
 
470
- @cached_method
473
+ @rasa.shared.utils.common.cached_method
471
474
  def get_user_flows(self) -> FlowsList:
472
475
  return self._importer.get_flows()
473
476
 
474
- @cached_method
477
+ @rasa.shared.utils.common.cached_method
475
478
  def get_domain(self) -> Domain:
476
479
  """Merge existing domain with properties of flows."""
477
480
  # load domain data from user defined domain files
@@ -513,7 +516,7 @@ class ResponsesSyncImporter(PassThroughImporter):
513
516
  back to the Domain.
514
517
  """
515
518
 
516
- @cached_method
519
+ @rasa.shared.utils.common.cached_method
517
520
  def get_domain(self) -> Domain:
518
521
  """Merge existing domain with properties of retrieval intents in NLU data."""
519
522
  existing_domain = self._importer.get_domain()
@@ -595,7 +598,7 @@ class ResponsesSyncImporter(PassThroughImporter):
595
598
  }
596
599
  )
597
600
 
598
- @cached_method
601
+ @rasa.shared.utils.common.cached_method
599
602
  def get_nlu_data(self, language: Optional[Text] = "en") -> TrainingData:
600
603
  """Updates NLU data with responses for retrieval intents from domain."""
601
604
  existing_nlu_data = self._importer.get_nlu_data(language)
@@ -630,7 +633,7 @@ class E2EImporter(PassThroughImporter):
630
633
  - adds potential end-to-end bot messages from stories as actions to the domain
631
634
  """
632
635
 
633
- @cached_method
636
+ @rasa.shared.utils.common.cached_method
634
637
  def get_user_flows(self) -> FlowsList:
635
638
  if not isinstance(self._importer, FlowSyncImporter):
636
639
  raise NotImplementedError(
@@ -639,12 +642,9 @@ class E2EImporter(PassThroughImporter):
639
642
 
640
643
  return self._importer.get_user_flows()
641
644
 
642
- @cached_method
645
+ @rasa.shared.utils.common.cached_method
643
646
  def get_domain(self) -> Domain:
644
- """Merge existing domain with properties of end-to-end actions in stories.
645
-
646
- Returns: Domain with end-to-end actions added to action names.
647
- """
647
+ """Retrieves model domain (see parent class for full docstring)."""
648
648
  original = self._importer.get_domain()
649
649
  e2e_domain = self._get_domain_with_e2e_actions()
650
650
 
@@ -674,7 +674,7 @@ class E2EImporter(PassThroughImporter):
674
674
 
675
675
  return Domain.from_dict({KEY_E2E_ACTIONS: list(additional_e2e_action_names)})
676
676
 
677
- @cached_method
677
+ @rasa.shared.utils.common.cached_method
678
678
  def get_nlu_data(self, language: Optional[Text] = "en") -> TrainingData:
679
679
  """Retrieves NLU training data (see parent class for full docstring)."""
680
680
  training_datasets = [
@@ -1,23 +1,22 @@
1
- import os
1
+ import logging
2
2
  from functools import reduce
3
- from typing import Any, Dict, List, Optional, Set, Text, Union
4
-
5
- import structlog
3
+ from typing import Text, Set, Dict, Optional, List, Union, Any
4
+ import os
6
5
 
7
6
  import rasa.shared.data
8
7
  import rasa.shared.utils.io
9
8
  from rasa.shared.core.domain import Domain
9
+ from rasa.shared.importers.importer import TrainingDataImporter
10
+ from rasa.shared.importers import utils
11
+ from rasa.shared.nlu.training_data.training_data import TrainingData
12
+ from rasa.shared.core.training_data.structures import StoryGraph
13
+ from rasa.shared.utils.common import mark_as_experimental_feature
10
14
  from rasa.shared.core.training_data.story_reader.yaml_story_reader import (
11
15
  YAMLStoryReader,
12
16
  )
13
- from rasa.shared.core.training_data.structures import StoryGraph
14
- from rasa.shared.importers import utils
15
- from rasa.shared.importers.importer import TrainingDataImporter
16
- from rasa.shared.nlu.training_data.training_data import TrainingData
17
- from rasa.shared.utils.common import cached_method, mark_as_experimental_feature
18
17
  from rasa.shared.utils.yaml import read_config_file, read_model_configuration
19
18
 
20
- structlogger = structlog.get_logger()
19
+ logger = logging.getLogger(__name__)
21
20
 
22
21
 
23
22
  class MultiProjectImporter(TrainingDataImporter):
@@ -51,13 +50,8 @@ class MultiProjectImporter(TrainingDataImporter):
51
50
  self._story_paths += extra_story_files
52
51
  self._nlu_paths += extra_nlu_files
53
52
 
54
- structlogger.debug(
55
- "multi_project_importer.initialisation",
56
- event_info=(
57
- "Selected projects: {}".format(
58
- "".join([f"\n-{i}" for i in self._imports])
59
- )
60
- ),
53
+ logger.debug(
54
+ "Selected projects: {}".format("".join([f"\n-{i}" for i in self._imports]))
61
55
  )
62
56
 
63
57
  mark_as_experimental_feature(feature_name="MultiProjectImporter")
@@ -141,7 +135,6 @@ class MultiProjectImporter(TrainingDataImporter):
141
135
 
142
136
  return training_paths
143
137
 
144
- @cached_method
145
138
  def is_imported(self, path: Text) -> bool:
146
139
  """Checks whether a path is imported by a skill.
147
140
 
@@ -182,7 +175,6 @@ class MultiProjectImporter(TrainingDataImporter):
182
175
  [rasa.shared.utils.io.is_subdirectory(path, i) for i in self._imports]
183
176
  )
184
177
 
185
- @cached_method
186
178
  def get_domain(self) -> Domain:
187
179
  """Retrieves model domain (see parent class for full docstring)."""
188
180
  domains = [Domain.load(path) for path in self._domain_paths]
@@ -192,24 +184,20 @@ class MultiProjectImporter(TrainingDataImporter):
192
184
  Domain.empty(),
193
185
  )
194
186
 
195
- @cached_method
196
187
  def get_stories(self, exclusion_percentage: Optional[int] = None) -> StoryGraph:
197
188
  """Retrieves training stories / rules (see parent class for full docstring)."""
198
189
  return utils.story_graph_from_paths(
199
190
  self._story_paths, self.get_domain(), exclusion_percentage
200
191
  )
201
192
 
202
- @cached_method
203
193
  def get_conversation_tests(self) -> StoryGraph:
204
194
  """Retrieves conversation test stories (see parent class for full docstring)."""
205
195
  return utils.story_graph_from_paths(self._e2e_story_paths, self.get_domain())
206
196
 
207
- @cached_method
208
197
  def get_config(self) -> Dict:
209
198
  """Retrieves model config (see parent class for full docstring)."""
210
199
  return self.config
211
200
 
212
- @cached_method
213
201
  def get_nlu_data(self, language: Optional[Text] = "en") -> TrainingData:
214
202
  """Retrieves NLU training data (see parent class for full docstring)."""
215
203
  return utils.training_data_from_paths(self._nlu_paths, language)
@@ -6,6 +6,7 @@ import rasa.shared.core.flows.yaml_flows_io
6
6
  from rasa.shared.core.flows import FlowsList
7
7
 
8
8
  import rasa.shared.data
9
+ import rasa.shared.utils.common
9
10
  import rasa.shared.utils.io
10
11
  from rasa.shared.core.training_data.structures import StoryGraph
11
12
  from rasa.shared.importers import utils
@@ -15,7 +16,6 @@ from rasa.shared.core.domain import InvalidDomain, Domain
15
16
  from rasa.shared.core.training_data.story_reader.yaml_story_reader import (
16
17
  YAMLStoryReader,
17
18
  )
18
- from rasa.shared.utils.common import cached_method
19
19
  from rasa.shared.utils.yaml import read_model_configuration
20
20
 
21
21
  logger = logging.getLogger(__name__)
@@ -47,7 +47,6 @@ class RasaFileImporter(TrainingDataImporter):
47
47
 
48
48
  self.config_file = config_file
49
49
 
50
- @cached_method
51
50
  def get_config(self) -> Dict:
52
51
  """Retrieves model config (see parent class for full docstring)."""
53
52
  if not self.config_file or not os.path.exists(self.config_file):
@@ -57,35 +56,31 @@ class RasaFileImporter(TrainingDataImporter):
57
56
  config = read_model_configuration(self.config_file)
58
57
  return config
59
58
 
59
+ @rasa.shared.utils.common.cached_method
60
60
  def get_config_file_for_auto_config(self) -> Optional[Text]:
61
61
  """Returns config file path for auto-config only if there is a single one."""
62
62
  return self.config_file
63
63
 
64
- @cached_method
65
64
  def get_stories(self, exclusion_percentage: Optional[int] = None) -> StoryGraph:
66
65
  """Retrieves training stories / rules (see parent class for full docstring)."""
67
66
  return utils.story_graph_from_paths(
68
67
  self._story_files, self.get_domain(), exclusion_percentage
69
68
  )
70
69
 
71
- @cached_method
72
70
  def get_flows(self) -> FlowsList:
73
71
  """Retrieves training stories / rules (see parent class for full docstring)."""
74
72
  return utils.flows_from_paths(self._flow_files)
75
73
 
76
- @cached_method
77
74
  def get_conversation_tests(self) -> StoryGraph:
78
75
  """Retrieves conversation test stories (see parent class for full docstring)."""
79
76
  return utils.story_graph_from_paths(
80
77
  self._conversation_test_files, self.get_domain()
81
78
  )
82
79
 
83
- @cached_method
84
80
  def get_nlu_data(self, language: Optional[Text] = "en") -> TrainingData:
85
81
  """Retrieves NLU training data (see parent class for full docstring)."""
86
82
  return utils.training_data_from_paths(self._nlu_files, language)
87
83
 
88
- @cached_method
89
84
  def get_domain(self) -> Domain:
90
85
  """Retrieves model domain (see parent class for full docstring)."""
91
86
  domain = Domain.empty()
@@ -8,7 +8,7 @@ import rasa.shared.core.flows.yaml_flows_io
8
8
  import rasa.shared.data
9
9
  import rasa.shared.utils.common
10
10
  import rasa.shared.utils.io
11
- from rasa.core.persistor import StorageType
11
+ from rasa.nlu.persistor import StorageType
12
12
  from rasa.shared.core.domain import Domain, InvalidDomain
13
13
  from rasa.shared.core.flows import FlowsList
14
14
  from rasa.shared.core.training_data.story_reader.yaml_story_reader import (
@@ -79,7 +79,7 @@ class RemoteTrainingDataImporter(TrainingDataImporter):
79
79
  self, training_file: str, training_data_path: Optional[str] = None
80
80
  ) -> str:
81
81
  """Fetches training files from remote storage."""
82
- from rasa.core.persistor import get_persistor
82
+ from rasa.nlu.persistor import get_persistor
83
83
 
84
84
  persistor = get_persistor(self.remote_storage)
85
85
  if persistor is None:
@@ -29,8 +29,6 @@ def flows_from_paths(files: List[Text]) -> FlowsList:
29
29
 
30
30
  flows = FlowsList(underlying_flows=[])
31
31
  for file in files:
32
- flows = flows.merge(
33
- YAMLFlowsReader.read_from_file(file), ignore_duplicates=False
34
- )
32
+ flows = flows.merge(YAMLFlowsReader.read_from_file(file))
35
33
  flows.validate()
36
34
  return flows
@@ -1,6 +1,5 @@
1
1
  import logging
2
2
  import os
3
- from functools import cached_property
4
3
  from pathlib import Path
5
4
  import random
6
5
  from collections import Counter, OrderedDict
@@ -10,6 +9,7 @@ from typing import Any, Dict, List, Optional, Set, Text, Tuple, Callable
10
9
  import operator
11
10
 
12
11
  import rasa.shared.data
12
+ from rasa.shared.utils.common import lazy_property
13
13
  import rasa.shared.utils.io
14
14
  from rasa.shared.nlu.constants import (
15
15
  RESPONSE,
@@ -202,7 +202,7 @@ class TrainingData:
202
202
 
203
203
  return list(OrderedDict.fromkeys(examples))
204
204
 
205
- @cached_property
205
+ @lazy_property
206
206
  def nlu_examples(self) -> List[Message]:
207
207
  """Return examples which have come from NLU training data.
208
208
 
@@ -215,32 +215,32 @@ class TrainingData:
215
215
  ex for ex in self.training_examples if not ex.is_core_or_domain_message()
216
216
  ]
217
217
 
218
- @cached_property
218
+ @lazy_property
219
219
  def intent_examples(self) -> List[Message]:
220
220
  """Returns the list of examples that have intent."""
221
221
  return [ex for ex in self.nlu_examples if ex.get(INTENT)]
222
222
 
223
- @cached_property
223
+ @lazy_property
224
224
  def response_examples(self) -> List[Message]:
225
225
  """Returns the list of examples that have response."""
226
226
  return [ex for ex in self.nlu_examples if ex.get(INTENT_RESPONSE_KEY)]
227
227
 
228
- @cached_property
228
+ @lazy_property
229
229
  def entity_examples(self) -> List[Message]:
230
230
  """Returns the list of examples that have entities."""
231
231
  return [ex for ex in self.nlu_examples if ex.get(ENTITIES)]
232
232
 
233
- @cached_property
233
+ @lazy_property
234
234
  def intents(self) -> Set[Text]:
235
235
  """Returns the set of intents in the training data."""
236
236
  return {ex.get(INTENT) for ex in self.training_examples} - {None}
237
237
 
238
- @cached_property
238
+ @lazy_property
239
239
  def action_names(self) -> Set[Text]:
240
240
  """Returns the set of action names in the training data."""
241
241
  return {ex.get(ACTION_NAME) for ex in self.training_examples} - {None}
242
242
 
243
- @cached_property
243
+ @lazy_property
244
244
  def retrieval_intents(self) -> Set[Text]:
245
245
  """Returns the total number of response types in the training data."""
246
246
  return {
@@ -249,13 +249,13 @@ class TrainingData:
249
249
  if ex.get(INTENT_RESPONSE_KEY)
250
250
  }
251
251
 
252
- @cached_property
252
+ @lazy_property
253
253
  def number_of_examples_per_intent(self) -> Dict[Text, int]:
254
254
  """Calculates the number of examples per intent."""
255
255
  intents = [ex.get(INTENT) for ex in self.nlu_examples]
256
256
  return dict(Counter(intents))
257
257
 
258
- @cached_property
258
+ @lazy_property
259
259
  def number_of_examples_per_response(self) -> Dict[Text, int]:
260
260
  """Calculates the number of examples per response."""
261
261
  responses = [
@@ -265,12 +265,12 @@ class TrainingData:
265
265
  ]
266
266
  return dict(Counter(responses))
267
267
 
268
- @cached_property
268
+ @lazy_property
269
269
  def entities(self) -> Set[Text]:
270
270
  """Returns the set of entity types in the training data."""
271
271
  return {e.get(ENTITY_ATTRIBUTE_TYPE) for e in self.sorted_entities()}
272
272
 
273
- @cached_property
273
+ @lazy_property
274
274
  def entity_roles(self) -> Set[Text]:
275
275
  """Returns the set of entity roles in the training data."""
276
276
  entity_types = {
@@ -280,7 +280,7 @@ class TrainingData:
280
280
  }
281
281
  return entity_types - {NO_ENTITY_TAG}
282
282
 
283
- @cached_property
283
+ @lazy_property
284
284
  def entity_groups(self) -> Set[Text]:
285
285
  """Returns the set of entity groups in the training data."""
286
286
  entity_types = {
@@ -299,7 +299,7 @@ class TrainingData:
299
299
 
300
300
  return entity_groups_used or entity_roles_used
301
301
 
302
- @cached_property
302
+ @lazy_property
303
303
  def number_of_examples_per_entity(self) -> Dict[Text, int]:
304
304
  """Calculates the number of examples per entity."""
305
305
  entities = []
@@ -426,9 +426,8 @@ class TrainingData:
426
426
  def persist(
427
427
  self, dir_name: Text, filename: Text = DEFAULT_TRAINING_DATA_OUTPUT_PATH
428
428
  ) -> Dict[Text, Any]:
429
- """Persists this training data to disk.
430
-
431
- Returns: necessary information to load it again.
429
+ """Persists this training data to disk and returns necessary
430
+ information to load it again.
432
431
  """
433
432
  if not os.path.exists(dir_name):
434
433
  os.makedirs(dir_name)
@@ -499,7 +498,9 @@ class TrainingData:
499
498
  def train_test_split(
500
499
  self, train_frac: float = 0.8, random_seed: Optional[int] = None
501
500
  ) -> Tuple["TrainingData", "TrainingData"]:
502
- """Split into a training and test dataset, preserving the fraction of examples per intent.""" # noqa: E501
501
+ """Split into a training and test dataset,
502
+ preserving the fraction of examples per intent.
503
+ """
503
504
  # collect all nlu data
504
505
  test, train = self.split_nlu_examples(train_frac, random_seed)
505
506
 
@@ -107,8 +107,7 @@ class AzureOpenAIClientConfig:
107
107
 
108
108
  @classmethod
109
109
  def from_dict(cls, config: dict) -> "AzureOpenAIClientConfig":
110
- """
111
- Initializes a dataclass from the passed config.
110
+ """Initializes a dataclass from the passed config.
112
111
 
113
112
  Args:
114
113
  config: (dict) The config from which to initialize.
@@ -175,7 +174,10 @@ def is_azure_openai_config(config: dict) -> bool:
175
174
 
176
175
  # Case: Configuration contains `deployment` key
177
176
  # (specific to Azure OpenAI configuration)
178
- if config.get(DEPLOYMENT_CONFIG_KEY) is not None:
177
+ if (
178
+ config.get(DEPLOYMENT_CONFIG_KEY) is not None
179
+ and config.get(PROVIDER_CONFIG_KEY) is None
180
+ ):
179
181
  return True
180
182
 
181
183
  return False
@@ -1,7 +1,7 @@
1
1
  from abc import abstractmethod
2
2
  from typing import Dict, List, Any, Union
3
-
4
3
  import logging
4
+
5
5
  import structlog
6
6
  from litellm import (
7
7
  completion,
@@ -29,8 +29,7 @@ logging.getLogger("LiteLLM").setLevel(logging.WARNING)
29
29
 
30
30
 
31
31
  class _BaseLiteLLMClient:
32
- """
33
- An abstract base class for LiteLLM clients.
32
+ """An abstract base class for LiteLLM clients.
34
33
 
35
34
  This class defines the interface and common functionality for all clients
36
35
  based on LiteLLM.
@@ -132,14 +131,15 @@ class _BaseLiteLLMClient:
132
131
 
133
132
  @suppress_logs(log_level=logging.WARNING)
134
133
  def completion(self, messages: Union[List[str], str]) -> LLMResponse:
135
- """
136
- Synchronously generate completions for given list of messages.
134
+ """Synchronously generate completions for given list of messages.
137
135
 
138
136
  Args:
139
137
  messages: List of messages or a single message to generate the
140
138
  completion for.
139
+
141
140
  Returns:
142
141
  List of message completions.
142
+
143
143
  Raises:
144
144
  ProviderClientAPIException: If the API request fails.
145
145
  """
@@ -154,14 +154,15 @@ class _BaseLiteLLMClient:
154
154
 
155
155
  @suppress_logs(log_level=logging.WARNING)
156
156
  async def acompletion(self, messages: Union[List[str], str]) -> LLMResponse:
157
- """
158
- Asynchronously generate completions for given list of messages.
157
+ """Asynchronously generate completions for given list of messages.
159
158
 
160
159
  Args:
161
160
  messages: List of messages or a single message to generate the
162
161
  completion for.
162
+
163
163
  Returns:
164
164
  List of message completions.
165
+
165
166
  Raises:
166
167
  ProviderClientAPIException: If the API request fails.
167
168
  """
@@ -172,7 +173,23 @@ class _BaseLiteLLMClient:
172
173
  )
173
174
  return self._format_response(response)
174
175
  except Exception as e:
175
- raise ProviderClientAPIException(e)
176
+ message = ""
177
+ from rasa.shared.providers.llm.self_hosted_llm_client import (
178
+ SelfHostedLLMClient,
179
+ )
180
+
181
+ if isinstance(self, SelfHostedLLMClient):
182
+ message = (
183
+ "If you are using 'provider=self-hosted' to call a hosted vllm "
184
+ "server make sure your config is correctly setup. You should have "
185
+ "the following mandatory keys in your config: "
186
+ "provider=self-hosted; "
187
+ "model='<your-vllm-model-name>'; "
188
+ "api_base='your-hosted-vllm-serv'."
189
+ "In case you are getting OpenAI connection errors, such as missing "
190
+ "API key, your configuration is incorrect."
191
+ )
192
+ raise ProviderClientAPIException(e, message)
176
193
 
177
194
  def _format_messages(self, messages: Union[List[str], str]) -> List[Dict[str, str]]:
178
195
  """Formats messages (or a single message) to OpenAI format."""
@@ -216,8 +233,7 @@ class _BaseLiteLLMClient:
216
233
 
217
234
  @staticmethod
218
235
  def _ensure_certificates() -> None:
219
- """
220
- Configures SSL certificates for LiteLLM. This method is invoked during
236
+ """Configures SSL certificates for LiteLLM. This method is invoked during
221
237
  client initialization.
222
238
 
223
239
  LiteLLM may utilize `openai` clients or other providers that require