rasa-pro 3.10.8__py3-none-any.whl → 3.10.9.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic. Click here for more details.

rasa/constants.py CHANGED
@@ -18,7 +18,7 @@ CONFIG_TELEMETRY_ID = "rasa_user_id"
18
18
  CONFIG_TELEMETRY_ENABLED = "enabled"
19
19
  CONFIG_TELEMETRY_DATE = "date"
20
20
 
21
- MINIMUM_COMPATIBLE_VERSION = "3.10.0rc1"
21
+ MINIMUM_COMPATIBLE_VERSION = "3.10.9.dev1"
22
22
 
23
23
  GLOBAL_USER_CONFIG_PATH = os.path.expanduser("~/.config/rasa/global.yml")
24
24
 
@@ -1,7 +1,8 @@
1
1
  import logging
2
+ from typing import List, Optional, Dict, Text, Set, Any
3
+
2
4
  import numpy as np
3
5
  import scipy.sparse
4
- from typing import List, Optional, Dict, Text, Set, Any
5
6
 
6
7
  from rasa.core.featurizers.precomputation import MessageContainerForCoreFeaturization
7
8
  from rasa.nlu.extractors.extractor import EntityTagSpec
@@ -360,6 +361,26 @@ class SingleStateFeaturizer:
360
361
  for action in domain.action_names_or_texts
361
362
  ]
362
363
 
364
+ def to_dict(self) -> Dict[str, Any]:
365
+ return {
366
+ "action_texts": self.action_texts,
367
+ "entity_tag_specs": self.entity_tag_specs,
368
+ "feature_states": self._default_feature_states,
369
+ }
370
+
371
+ @classmethod
372
+ def create_from_dict(
373
+ cls, data: Dict[str, Any]
374
+ ) -> Optional["SingleStateFeaturizer"]:
375
+ if not data:
376
+ return None
377
+
378
+ featurizer = SingleStateFeaturizer()
379
+ featurizer.action_texts = data["action_texts"]
380
+ featurizer._default_feature_states = data["feature_states"]
381
+ featurizer.entity_tag_specs = data["entity_tag_specs"]
382
+ return featurizer
383
+
363
384
 
364
385
  class IntentTokenizerSingleStateFeaturizer(SingleStateFeaturizer):
365
386
  """A SingleStateFeaturizer for use with policies that predict intent labels."""
@@ -1,11 +1,9 @@
1
1
  from __future__ import annotations
2
- from pathlib import Path
3
- from collections import defaultdict
4
- from abc import abstractmethod
5
- import jsonpickle
6
- import logging
7
2
 
8
- from tqdm import tqdm
3
+ import logging
4
+ from abc import abstractmethod
5
+ from collections import defaultdict
6
+ from pathlib import Path
9
7
  from typing import (
10
8
  Tuple,
11
9
  List,
@@ -18,25 +16,30 @@ from typing import (
18
16
  Set,
19
17
  DefaultDict,
20
18
  cast,
19
+ Type,
20
+ Callable,
21
+ ClassVar,
21
22
  )
23
+
22
24
  import numpy as np
25
+ from tqdm import tqdm
23
26
 
24
- from rasa.core.featurizers.single_state_featurizer import SingleStateFeaturizer
25
- from rasa.core.featurizers.precomputation import MessageContainerForCoreFeaturization
26
- from rasa.core.exceptions import InvalidTrackerFeaturizerUsageError
27
27
  import rasa.shared.core.trackers
28
28
  import rasa.shared.utils.io
29
- from rasa.shared.nlu.constants import TEXT, INTENT, ENTITIES, ACTION_NAME
30
- from rasa.shared.nlu.training_data.features import Features
31
- from rasa.shared.core.trackers import DialogueStateTracker
32
- from rasa.shared.core.domain import State, Domain
33
- from rasa.shared.core.events import Event, ActionExecuted, UserUttered
29
+ from rasa.core.exceptions import InvalidTrackerFeaturizerUsageError
30
+ from rasa.core.featurizers.precomputation import MessageContainerForCoreFeaturization
31
+ from rasa.core.featurizers.single_state_featurizer import SingleStateFeaturizer
34
32
  from rasa.shared.core.constants import (
35
33
  USER,
36
34
  ACTION_UNLIKELY_INTENT_NAME,
37
35
  PREVIOUS_ACTION,
38
36
  )
37
+ from rasa.shared.core.domain import State, Domain
38
+ from rasa.shared.core.events import Event, ActionExecuted, UserUttered
39
+ from rasa.shared.core.trackers import DialogueStateTracker
39
40
  from rasa.shared.exceptions import RasaException
41
+ from rasa.shared.nlu.constants import TEXT, INTENT, ENTITIES, ACTION_NAME
42
+ from rasa.shared.nlu.training_data.features import Features
40
43
  from rasa.utils.tensorflow.constants import LABEL_PAD_ID
41
44
  from rasa.utils.tensorflow.model_data import ragged_array_to_ndarray
42
45
 
@@ -64,6 +67,10 @@ class InvalidStory(RasaException):
64
67
  class TrackerFeaturizer:
65
68
  """Base class for actual tracker featurizers."""
66
69
 
70
+ # Class registry to store all subclasses
71
+ _registry: ClassVar[Dict[str, Type["TrackerFeaturizer"]]] = {}
72
+ _featurizer_type: str = "TrackerFeaturizer"
73
+
67
74
  def __init__(
68
75
  self, state_featurizer: Optional[SingleStateFeaturizer] = None
69
76
  ) -> None:
@@ -74,6 +81,36 @@ class TrackerFeaturizer:
74
81
  """
75
82
  self.state_featurizer = state_featurizer
76
83
 
84
+ @classmethod
85
+ def register(cls, featurizer_type: str) -> Callable:
86
+ """Decorator to register featurizer subclasses."""
87
+
88
+ def wrapper(subclass: Type["TrackerFeaturizer"]) -> Type["TrackerFeaturizer"]:
89
+ cls._registry[featurizer_type] = subclass
90
+ # Store the type identifier in the class for serialization
91
+ subclass._featurizer_type = featurizer_type
92
+ return subclass
93
+
94
+ return wrapper
95
+
96
+ @classmethod
97
+ def from_dict(cls, data: Dict[str, Any]) -> "TrackerFeaturizer":
98
+ """Create featurizer instance from dictionary."""
99
+ featurizer_type = data.pop("type")
100
+
101
+ if featurizer_type not in cls._registry:
102
+ raise ValueError(f"Unknown featurizer type: {featurizer_type}")
103
+
104
+ # Get the correct subclass and instantiate it
105
+ subclass = cls._registry[featurizer_type]
106
+ return subclass.create_from_dict(data)
107
+
108
+ @classmethod
109
+ @abstractmethod
110
+ def create_from_dict(cls, data: Dict[str, Any]) -> "TrackerFeaturizer":
111
+ """Each subclass must implement its own creation from dict method."""
112
+ pass
113
+
77
114
  @staticmethod
78
115
  def _create_states(
79
116
  tracker: DialogueStateTracker,
@@ -465,9 +502,7 @@ class TrackerFeaturizer:
465
502
  self.state_featurizer.entity_tag_specs = []
466
503
 
467
504
  # noinspection PyTypeChecker
468
- rasa.shared.utils.io.write_text_file(
469
- str(jsonpickle.encode(self)), featurizer_file
470
- )
505
+ rasa.shared.utils.io.dump_obj_as_json_to_file(featurizer_file, self.to_dict())
471
506
 
472
507
  @staticmethod
473
508
  def load(path: Union[Text, Path]) -> Optional[TrackerFeaturizer]:
@@ -481,7 +516,17 @@ class TrackerFeaturizer:
481
516
  """
482
517
  featurizer_file = Path(path) / FEATURIZER_FILE
483
518
  if featurizer_file.is_file():
484
- return jsonpickle.decode(rasa.shared.utils.io.read_file(featurizer_file))
519
+ data = rasa.shared.utils.io.read_json_file(featurizer_file)
520
+
521
+ if "type" not in data:
522
+ logger.error(
523
+ f"Couldn't load featurizer for policy. "
524
+ f"File '{featurizer_file}' does not contain all "
525
+ f"necessary information. 'type' is missing."
526
+ )
527
+ return None
528
+
529
+ return TrackerFeaturizer.from_dict(data)
485
530
 
486
531
  logger.error(
487
532
  f"Couldn't load featurizer for policy. "
@@ -508,7 +553,16 @@ class TrackerFeaturizer:
508
553
  )
509
554
  ]
510
555
 
556
+ def to_dict(self) -> Dict[str, Any]:
557
+ return {
558
+ "type": self.__class__._featurizer_type,
559
+ "state_featurizer": (
560
+ self.state_featurizer.to_dict() if self.state_featurizer else None
561
+ ),
562
+ }
563
+
511
564
 
565
+ @TrackerFeaturizer.register("FullDialogueTrackerFeaturizer")
512
566
  class FullDialogueTrackerFeaturizer(TrackerFeaturizer):
513
567
  """Creates full dialogue training data for time distributed architectures.
514
568
 
@@ -646,7 +700,20 @@ class FullDialogueTrackerFeaturizer(TrackerFeaturizer):
646
700
 
647
701
  return trackers_as_states
648
702
 
703
+ def to_dict(self) -> Dict[str, Any]:
704
+ return super().to_dict()
649
705
 
706
+ @classmethod
707
+ def create_from_dict(cls, data: Dict[str, Any]) -> "FullDialogueTrackerFeaturizer":
708
+ state_featurizer = SingleStateFeaturizer.create_from_dict(
709
+ data["state_featurizer"]
710
+ )
711
+ return cls(
712
+ state_featurizer,
713
+ )
714
+
715
+
716
+ @TrackerFeaturizer.register("MaxHistoryTrackerFeaturizer")
650
717
  class MaxHistoryTrackerFeaturizer(TrackerFeaturizer):
651
718
  """Truncates the tracker history into `max_history` long sequences.
652
719
 
@@ -884,7 +951,25 @@ class MaxHistoryTrackerFeaturizer(TrackerFeaturizer):
884
951
 
885
952
  return trackers_as_states
886
953
 
954
+ def to_dict(self) -> Dict[str, Any]:
955
+ data = super().to_dict()
956
+ data.update(
957
+ {
958
+ "remove_duplicates": self.remove_duplicates,
959
+ "max_history": self.max_history,
960
+ }
961
+ )
962
+ return data
963
+
964
+ @classmethod
965
+ def create_from_dict(cls, data: Dict[str, Any]) -> "MaxHistoryTrackerFeaturizer":
966
+ state_featurizer = SingleStateFeaturizer.create_from_dict(
967
+ data["state_featurizer"]
968
+ )
969
+ return cls(state_featurizer, data["max_history"], data["remove_duplicates"])
887
970
 
971
+
972
+ @TrackerFeaturizer.register("IntentMaxHistoryTrackerFeaturizer")
888
973
  class IntentMaxHistoryTrackerFeaturizer(MaxHistoryTrackerFeaturizer):
889
974
  """Truncates the tracker history into `max_history` long sequences.
890
975
 
@@ -1159,6 +1244,18 @@ class IntentMaxHistoryTrackerFeaturizer(MaxHistoryTrackerFeaturizer):
1159
1244
 
1160
1245
  return trackers_as_states
1161
1246
 
1247
+ def to_dict(self) -> Dict[str, Any]:
1248
+ return super().to_dict()
1249
+
1250
+ @classmethod
1251
+ def create_from_dict(
1252
+ cls, data: Dict[str, Any]
1253
+ ) -> "IntentMaxHistoryTrackerFeaturizer":
1254
+ state_featurizer = SingleStateFeaturizer.create_from_dict(
1255
+ data["state_featurizer"]
1256
+ )
1257
+ return cls(state_featurizer, data["max_history"], data["remove_duplicates"])
1258
+
1162
1259
 
1163
1260
  def _is_prev_action_unlikely_intent_in_state(state: State) -> bool:
1164
1261
  prev_action_name = state.get(PREVIOUS_ACTION, {}).get(ACTION_NAME)
@@ -1,15 +1,15 @@
1
1
  from __future__ import annotations
2
- import logging
3
2
 
4
- from rasa.engine.recipes.default_recipe import DefaultV1Recipe
3
+ import logging
5
4
  from pathlib import Path
6
5
  from collections import defaultdict
7
6
  import contextlib
7
+ from typing import Any, List, Optional, Text, Dict, Tuple, Union, Type
8
8
 
9
9
  import numpy as np
10
10
  import tensorflow as tf
11
- from typing import Any, List, Optional, Text, Dict, Tuple, Union, Type
12
11
 
12
+ from rasa.engine.recipes.default_recipe import DefaultV1Recipe
13
13
  from rasa.engine.graph import ExecutionContext
14
14
  from rasa.engine.storage.resource import Resource
15
15
  from rasa.engine.storage.storage import ModelStorage
@@ -49,18 +49,22 @@ from rasa.shared.core.generator import TrackerWithCachedStates
49
49
  from rasa.shared.core.events import EntitiesAdded, Event
50
50
  from rasa.shared.core.domain import Domain
51
51
  from rasa.shared.nlu.training_data.message import Message
52
- from rasa.shared.nlu.training_data.features import Features
52
+ from rasa.shared.nlu.training_data.features import (
53
+ Features,
54
+ save_features,
55
+ load_features,
56
+ )
53
57
  import rasa.shared.utils.io
54
58
  import rasa.utils.io
55
59
  from rasa.utils import train_utils
56
- from rasa.utils.tensorflow.models import RasaModel, TransformerRasaModel
57
- from rasa.utils.tensorflow import rasa_layers
58
- from rasa.utils.tensorflow.model_data import (
59
- RasaModelData,
60
- FeatureSignature,
60
+ from rasa.utils.tensorflow.feature_array import (
61
61
  FeatureArray,
62
- Data,
62
+ serialize_nested_feature_arrays,
63
+ deserialize_nested_feature_arrays,
63
64
  )
65
+ from rasa.utils.tensorflow.models import RasaModel, TransformerRasaModel
66
+ from rasa.utils.tensorflow import rasa_layers
67
+ from rasa.utils.tensorflow.model_data import RasaModelData, FeatureSignature, Data
64
68
  from rasa.utils.tensorflow.model_data_utils import convert_to_data_format
65
69
  from rasa.utils.tensorflow.constants import (
66
70
  LABEL,
@@ -961,22 +965,32 @@ class TEDPolicy(Policy):
961
965
  model_path: Path where model is to be persisted
962
966
  """
963
967
  model_filename = self._metadata_filename()
964
- rasa.utils.io.json_pickle(
965
- model_path / f"{model_filename}.priority.pkl", self.priority
966
- )
967
- rasa.utils.io.pickle_dump(
968
- model_path / f"{model_filename}.meta.pkl", self.config
968
+ rasa.shared.utils.io.dump_obj_as_json_to_file(
969
+ model_path / f"{model_filename}.priority.json", self.priority
969
970
  )
970
- rasa.utils.io.pickle_dump(
971
- model_path / f"{model_filename}.data_example.pkl", self.data_example
971
+ rasa.shared.utils.io.dump_obj_as_json_to_file(
972
+ model_path / f"{model_filename}.meta.json", self.config
972
973
  )
973
- rasa.utils.io.pickle_dump(
974
- model_path / f"{model_filename}.fake_features.pkl", self.fake_features
974
+ # save data example
975
+ serialize_nested_feature_arrays(
976
+ self.data_example,
977
+ str(model_path / f"{model_filename}.data_example.st"),
978
+ str(model_path / f"{model_filename}.data_example_metadata.json"),
975
979
  )
976
- rasa.utils.io.pickle_dump(
977
- model_path / f"{model_filename}.label_data.pkl",
980
+ # save label data
981
+ serialize_nested_feature_arrays(
978
982
  dict(self._label_data.data) if self._label_data is not None else {},
983
+ str(model_path / f"{model_filename}.label_data.st"),
984
+ str(model_path / f"{model_filename}.label_data_metadata.json"),
985
+ )
986
+ # save fake features
987
+ metadata = save_features(
988
+ self.fake_features, str(model_path / f"{model_filename}.fake_features.st")
989
+ )
990
+ rasa.shared.utils.io.dump_obj_as_json_to_file(
991
+ model_path / f"{model_filename}.fake_features_metadata.json", metadata
979
992
  )
993
+
980
994
  entity_tag_specs = (
981
995
  [tag_spec._asdict() for tag_spec in self._entity_tag_specs]
982
996
  if self._entity_tag_specs
@@ -994,18 +1008,29 @@ class TEDPolicy(Policy):
994
1008
  model_path: Path where model is to be persisted.
995
1009
  """
996
1010
  tf_model_file = model_path / f"{cls._metadata_filename()}.tf_model"
997
- loaded_data = rasa.utils.io.pickle_load(
998
- model_path / f"{cls._metadata_filename()}.data_example.pkl"
1011
+
1012
+ # load data example
1013
+ loaded_data = deserialize_nested_feature_arrays(
1014
+ str(model_path / f"{cls._metadata_filename()}.data_example.st"),
1015
+ str(model_path / f"{cls._metadata_filename()}.data_example_metadata.json"),
999
1016
  )
1000
- label_data = rasa.utils.io.pickle_load(
1001
- model_path / f"{cls._metadata_filename()}.label_data.pkl"
1017
+ # load label data
1018
+ loaded_label_data = deserialize_nested_feature_arrays(
1019
+ str(model_path / f"{cls._metadata_filename()}.label_data.st"),
1020
+ str(model_path / f"{cls._metadata_filename()}.label_data_metadata.json"),
1002
1021
  )
1003
- fake_features = rasa.utils.io.pickle_load(
1004
- model_path / f"{cls._metadata_filename()}.fake_features.pkl"
1022
+ label_data = RasaModelData(data=loaded_label_data)
1023
+
1024
+ # load fake features
1025
+ metadata = rasa.shared.utils.io.read_json_file(
1026
+ model_path / f"{cls._metadata_filename()}.fake_features_metadata.json"
1005
1027
  )
1006
- label_data = RasaModelData(data=label_data)
1007
- priority = rasa.utils.io.json_unpickle(
1008
- model_path / f"{cls._metadata_filename()}.priority.pkl"
1028
+ fake_features = load_features(
1029
+ str(model_path / f"{cls._metadata_filename()}.fake_features.st"), metadata
1030
+ )
1031
+
1032
+ priority = rasa.shared.utils.io.read_json_file(
1033
+ model_path / f"{cls._metadata_filename()}.priority.json"
1009
1034
  )
1010
1035
  entity_tag_specs = rasa.shared.utils.io.read_json_file(
1011
1036
  model_path / f"{cls._metadata_filename()}.entity_tag_specs.json"
@@ -1023,8 +1048,8 @@ class TEDPolicy(Policy):
1023
1048
  )
1024
1049
  for tag_spec in entity_tag_specs
1025
1050
  ]
1026
- model_config = rasa.utils.io.pickle_load(
1027
- model_path / f"{cls._metadata_filename()}.meta.pkl"
1051
+ model_config = rasa.shared.utils.io.read_json_file(
1052
+ model_path / f"{cls._metadata_filename()}.meta.json"
1028
1053
  )
1029
1054
 
1030
1055
  return {
@@ -1070,7 +1095,7 @@ class TEDPolicy(Policy):
1070
1095
  ) -> TEDPolicy:
1071
1096
  featurizer = TrackerFeaturizer.load(model_path)
1072
1097
 
1073
- if not (model_path / f"{cls._metadata_filename()}.data_example.pkl").is_file():
1098
+ if not (model_path / f"{cls._metadata_filename()}.data_example.st").is_file():
1074
1099
  return cls(
1075
1100
  config,
1076
1101
  model_storage,
@@ -5,6 +5,7 @@ from typing import Any, List, Optional, Text, Dict, Type, Union
5
5
 
6
6
  import numpy as np
7
7
  import tensorflow as tf
8
+
8
9
  import rasa.utils.common
9
10
  from rasa.engine.graph import ExecutionContext
10
11
  from rasa.engine.recipes.default_recipe import DefaultV1Recipe
@@ -16,6 +17,7 @@ from rasa.shared.core.domain import Domain
16
17
  from rasa.shared.core.trackers import DialogueStateTracker
17
18
  from rasa.shared.core.constants import SLOTS, ACTIVE_LOOP, ACTION_UNLIKELY_INTENT_NAME
18
19
  from rasa.shared.core.events import UserUttered, ActionExecuted
20
+ import rasa.shared.utils.io
19
21
  from rasa.shared.nlu.constants import (
20
22
  INTENT,
21
23
  TEXT,
@@ -103,8 +105,6 @@ from rasa.utils.tensorflow.constants import (
103
105
  )
104
106
  from rasa.utils.tensorflow import layers
105
107
  from rasa.utils.tensorflow.model_data import RasaModelData, FeatureArray, Data
106
-
107
- import rasa.utils.io as io_utils
108
108
  from rasa.core.exceptions import RasaCoreException
109
109
  from rasa.shared.utils import common
110
110
 
@@ -881,9 +881,12 @@ class UnexpecTEDIntentPolicy(TEDPolicy):
881
881
  model_path: Path where model is to be persisted
882
882
  """
883
883
  super().persist_model_utilities(model_path)
884
- io_utils.pickle_dump(
885
- model_path / f"{self._metadata_filename()}.label_quantiles.pkl",
886
- self.label_quantiles,
884
+
885
+ from safetensors.numpy import save_file
886
+
887
+ save_file(
888
+ {str(k): np.array(v) for k, v in self.label_quantiles.items()},
889
+ model_path / f"{self._metadata_filename()}.label_quantiles.st",
887
890
  )
888
891
 
889
892
  @classmethod
@@ -894,9 +897,14 @@ class UnexpecTEDIntentPolicy(TEDPolicy):
894
897
  model_path: Path where model is to be persisted.
895
898
  """
896
899
  model_utilties = super()._load_model_utilities(model_path)
897
- label_quantiles = io_utils.pickle_load(
898
- model_path / f"{cls._metadata_filename()}.label_quantiles.pkl"
900
+
901
+ from safetensors.numpy import load_file
902
+
903
+ loaded_label_quantiles = load_file(
904
+ model_path / f"{cls._metadata_filename()}.label_quantiles.st"
899
905
  )
906
+ label_quantiles = {int(k): list(v) for k, v in loaded_label_quantiles.items()}
907
+
900
908
  model_utilties.update({"label_quantiles": label_quantiles})
901
909
  return model_utilties
902
910
 
rasa/core/processor.py CHANGED
@@ -101,6 +101,9 @@ logger = logging.getLogger(__name__)
101
101
  structlogger = structlog.get_logger()
102
102
 
103
103
  MAX_NUMBER_OF_PREDICTIONS = int(os.environ.get("MAX_NUMBER_OF_PREDICTIONS", "10"))
104
+ MAX_NUMBER_OF_PREDICTIONS_CALM = int(
105
+ os.environ.get("MAX_NUMBER_OF_PREDICTIONS_CALM", "1000")
106
+ )
104
107
 
105
108
 
106
109
  class MessageProcessor:
@@ -114,6 +117,7 @@ class MessageProcessor:
114
117
  generator: NaturalLanguageGenerator,
115
118
  action_endpoint: Optional[EndpointConfig] = None,
116
119
  max_number_of_predictions: int = MAX_NUMBER_OF_PREDICTIONS,
120
+ max_number_of_predictions_calm: int = MAX_NUMBER_OF_PREDICTIONS_CALM,
117
121
  on_circuit_break: Optional[LambdaType] = None,
118
122
  http_interpreter: Optional[RasaNLUHttpInterpreter] = None,
119
123
  endpoints: Optional["AvailableEndpoints"] = None,
@@ -122,7 +126,6 @@ class MessageProcessor:
122
126
  self.nlg = generator
123
127
  self.tracker_store = tracker_store
124
128
  self.lock_store = lock_store
125
- self.max_number_of_predictions = max_number_of_predictions
126
129
  self.on_circuit_break = on_circuit_break
127
130
  self.action_endpoint = action_endpoint
128
131
  self.model_filename, self.model_metadata, self.graph_runner = self._load_model(
@@ -130,6 +133,10 @@ class MessageProcessor:
130
133
  )
131
134
  self.endpoints = endpoints
132
135
 
136
+ self.max_number_of_predictions = max_number_of_predictions
137
+ self.max_number_of_predictions_calm = max_number_of_predictions_calm
138
+ self.is_calm_assistant = self._is_calm_assistant()
139
+
133
140
  if self.model_metadata.assistant_id is None:
134
141
  rasa.shared.utils.io.raise_warning(
135
142
  f"The model metadata does not contain a value for the "
@@ -972,11 +979,15 @@ class MessageProcessor:
972
979
  ) -> int:
973
980
  """Select the action limit based on the tracker state.
974
981
 
975
- Usually, we want to limit the number of predictions to the number of actions
976
- that have been executed in the conversation so far. However, if the
977
- conversation is currently in a state where the user is correcting the flow
978
- we want to allow for more predictions to be made as we might be traversing
979
- through a long flow.
982
+ This function determines the maximum number of predictions that should be
983
+ made during a dialogue conversation. Typically, the number of predictions
984
+ is limited to the number of actions executed so far in the conversation.
985
+ However, in certain states (e.g., when the user is correcting the
986
+ conversation flow), more predictions may be allowed as the system traverses
987
+ through a long dialogue flow.
988
+
989
+ Additionally, if the `ROUTE_TO_CALM_SLOT` is present in the tracker slots,
990
+ the action limit is adjusted to a separate limit for CALM-based flows.
980
991
 
981
992
  Args:
982
993
  tracker: instance of DialogueStateTracker.
@@ -984,6 +995,18 @@ class MessageProcessor:
984
995
  Returns:
985
996
  The maximum number of predictions to make.
986
997
  """
998
+ # Check if it is a CALM assistant and if so, that the `ROUTE_TO_CALM_SLOT`
999
+ # is either not present or set to `True`.
1000
+ # If it does, use the specific prediction limit for CALM assistants.
1001
+ # Otherwise, use the default prediction limit.
1002
+ if self.is_calm_assistant and (
1003
+ not tracker.has_coexistence_routing_slot
1004
+ or tracker.get_slot(ROUTE_TO_CALM_SLOT)
1005
+ ):
1006
+ max_number_of_predictions = self.max_number_of_predictions_calm
1007
+ else:
1008
+ max_number_of_predictions = self.max_number_of_predictions
1009
+
987
1010
  reversed_events = list(tracker.events)[::-1]
988
1011
  is_conversation_in_flow_correction = False
989
1012
  for e in reversed_events:
@@ -998,8 +1021,10 @@ class MessageProcessor:
998
1021
  # allow for more predictions to be made as we might be traversing through
999
1022
  # a long flow. We multiply the number of predictions by 10 to allow for
1000
1023
  # more predictions to be made - the factor is a best guess.
1001
- return self.max_number_of_predictions * 5
1002
- return self.max_number_of_predictions
1024
+ return max_number_of_predictions * 5
1025
+
1026
+ # Return the default
1027
+ return max_number_of_predictions
1003
1028
 
1004
1029
  def is_action_limit_reached(
1005
1030
  self, tracker: DialogueStateTracker, should_predict_another_action: bool
@@ -1387,3 +1412,27 @@ class MessageProcessor:
1387
1412
  ]
1388
1413
 
1389
1414
  return len(filtered_commands) > 0
1415
+
1416
+ def _is_calm_assistant(self) -> bool:
1417
+ """Inspects the nodes of the graph schema to determine whether
1418
+ any node is associated with the `FlowPolicy`, which is indicative of a
1419
+ CALM assistant setup.
1420
+
1421
+ Returns:
1422
+ bool: True if any node in the graph schema uses `FlowPolicy`.
1423
+ """
1424
+ # Get the graph schema's nodes from the graph runner.
1425
+ nodes: dict[str, Any] = self.graph_runner._graph_schema.nodes # type: ignore[attr-defined]
1426
+
1427
+ flow_policy_class_path = "rasa.core.policies.flow_policy.FlowPolicy"
1428
+ # Iterate over the nodes and check if any node uses `FlowPolicy`.
1429
+ for node_name, schema_node in nodes.items():
1430
+ if (
1431
+ schema_node.uses is not None
1432
+ and f"{schema_node.uses.__module__}.{schema_node.uses.__name__}"
1433
+ == flow_policy_class_path
1434
+ ):
1435
+ return True
1436
+
1437
+ # Return False if no node is found using `FlowPolicy`.
1438
+ return False
rasa/e2e_test/utils/io.py CHANGED
@@ -346,7 +346,7 @@ def read_test_cases(path: str) -> TestSuite:
346
346
  beta_flag_verified = False
347
347
 
348
348
  for test_file in test_files:
349
- test_file_content = parse_raw_yaml(Path(test_file).read_text())
349
+ test_file_content = parse_raw_yaml(Path(test_file).read_text(encoding="utf-8"))
350
350
 
351
351
  validate_yaml_data_using_schema_with_assertions(
352
352
  yaml_data=test_file_content, schema_content=e2e_test_schema
@@ -506,6 +506,8 @@ def transform_results_output_to_yaml(yaml_string: str) -> str:
506
506
  result.append(s)
507
507
  elif s.startswith("\n"):
508
508
  result.append(s.strip())
509
+ elif s.strip().startswith("#"):
510
+ continue
509
511
  else:
510
512
  result.append(s)
511
513
  return "".join(result)