rasa-pro 3.13.1a14__py3-none-any.whl → 3.13.1a15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rasa-pro might be problematic; see the package's registry page for more details.

@@ -9,7 +9,7 @@ from rasa.builder import config
9
9
  from rasa.builder.exceptions import AgentLoadError, TrainingError
10
10
  from rasa.core import agent
11
11
  from rasa.core.utils import AvailableEndpoints, read_endpoints_from_path
12
- from rasa.model_training import train
12
+ from rasa.model_training import TrainingResult, train
13
13
  from rasa.shared.importers.importer import TrainingDataImporter
14
14
  from rasa.shared.utils.yaml import dump_obj_as_yaml_to_string
15
15
 
@@ -55,7 +55,7 @@ async def train_and_load_agent(importer: TrainingDataImporter) -> agent.Agent:
55
55
  raise TrainingError(f"SystemExit during training: {e}")
56
56
 
57
57
 
58
- async def _setup_endpoints():
58
+ async def _setup_endpoints() -> None:
59
59
  """Setup endpoints configuration for training."""
60
60
  try:
61
61
  with tempfile.NamedTemporaryFile(
@@ -75,7 +75,7 @@ async def _setup_endpoints():
75
75
  raise TrainingError(f"Failed to setup endpoints: {e}")
76
76
 
77
77
 
78
- async def _train_model(importer: TrainingDataImporter):
78
+ async def _train_model(importer: TrainingDataImporter) -> TrainingResult:
79
79
  """Train the Rasa model."""
80
80
  try:
81
81
  structlogger.info("training.started")
rasa/cli/inspect.py CHANGED
@@ -9,6 +9,7 @@ from rasa import telemetry
9
9
  from rasa.cli import SubParsersAction
10
10
  from rasa.cli.arguments import shell as arguments
11
11
  from rasa.core import constants
12
+ from rasa.core.available_endpoints import AvailableEndpoints
12
13
  from rasa.engine.storage.local_model_storage import LocalModelStorage
13
14
  from rasa.exceptions import ModelNotFound
14
15
  from rasa.model import get_local_model
@@ -83,6 +84,12 @@ def inspect(args: argparse.Namespace) -> None:
83
84
 
84
85
  model = get_validated_path(args.model, "model", DEFAULT_MODELS_PATH)
85
86
 
87
+ # Load endpoints with proper endpoint file location
88
+ # This will initialise the endpoints singleton properly so that
89
+ # it can be used safely throughout the codebase with
90
+ # `AvailableEndpoints.get_instance()`
91
+ AvailableEndpoints.get_instance(args.endpoints)
92
+
86
93
  try:
87
94
  model = get_local_model(model)
88
95
  except ModelNotFound:
@@ -1,16 +1,17 @@
1
1
  import csv
2
2
  import logging
3
3
  from datetime import datetime
4
+ from typing import Dict, List
4
5
 
5
6
  from rasa_sdk import Action
6
7
  from rasa_sdk.events import SlotSet
7
8
 
8
9
 
9
10
  class ActionVerifyBillByDate(Action):
10
- def name(self):
11
+ def name(self) -> str:
11
12
  return "action_verify_bill_by_date"
12
13
 
13
- def text_to_date(month_text):
14
+ def text_to_date(self, month_text: str) -> str:
14
15
  try:
15
16
  # Get the current year
16
17
  current_year = datetime.now().year
@@ -28,7 +29,7 @@ class ActionVerifyBillByDate(Action):
28
29
  except ValueError:
29
30
  return "Invalid format. Please use a full month name (e.g., 'March')."
30
31
 
31
- def run(self, dispatcher, tracker, domain):
32
+ def run(self, dispatcher, tracker, domain) -> List[Dict]:
32
33
  # Get customer ID and date from slots
33
34
  customer_id = tracker.get_slot("customer_id")
34
35
  bill_month = tracker.get_slot("bill_month")
@@ -113,10 +114,10 @@ class ActionVerifyBillByDate(Action):
113
114
 
114
115
 
115
116
  class ActionRecapBill(Action):
116
- def name(self):
117
+ def name(self) -> str:
117
118
  return "action_recap_bill"
118
119
 
119
- def run(self, dispatcher, tracker, domain):
120
+ def run(self, dispatcher, tracker, domain) -> List[Dict]:
120
121
  # Get customer_id and bill_date from slots
121
122
  customer_id = tracker.get_slot("customer_id")
122
123
  bill_month = tracker.get_slot("bill_month")
@@ -1,14 +1,15 @@
1
1
  import csv
2
+ from typing import Dict, List
2
3
 
3
4
  from rasa_sdk import Action
4
5
  from rasa_sdk.events import SlotSet
5
6
 
6
7
 
7
8
  class ActionGetCustomerInfo(Action):
8
- def name(self):
9
+ def name(self) -> str:
9
10
  return "action_get_customer_info"
10
11
 
11
- def run(self, dispatcher, tracker, domain):
12
+ def run(self, dispatcher, tracker, domain) -> List[Dict]:
12
13
  # Load CSV file
13
14
  file_path = "csvs/customers.csv" # get information from your DBs
14
15
  customer_id = tracker.get_slot("customer_id")
rasa/cli/shell.py CHANGED
@@ -6,6 +6,7 @@ from typing import List
6
6
  from rasa import telemetry
7
7
  from rasa.cli import SubParsersAction
8
8
  from rasa.cli.arguments import shell as arguments
9
+ from rasa.core.available_endpoints import AvailableEndpoints
9
10
  from rasa.engine.storage.local_model_storage import LocalModelStorage
10
11
  from rasa.exceptions import ModelNotFound
11
12
  from rasa.model import get_local_model
@@ -105,7 +106,11 @@ def shell(args: argparse.Namespace) -> None:
105
106
  from rasa.shared.constants import DEFAULT_MODELS_PATH
106
107
 
107
108
  args.connector = "cmdline"
108
-
109
+ # Load endpoints with proper endpoint file location
110
+ # This will initialise the endpoints singleton properly so that
111
+ # it can be used safely throughout the codebase with
112
+ # `AvailableEndpoints.get_instance()`
113
+ AvailableEndpoints.get_instance(args.endpoints)
109
114
  model = get_validated_path(args.model, "model", DEFAULT_MODELS_PATH)
110
115
 
111
116
  try:
rasa/cli/train.py CHANGED
@@ -112,6 +112,10 @@ def run_training(args: argparse.Namespace, can_exit: bool = False) -> Optional[T
112
112
  )
113
113
  config = rasa.cli.utils.get_validated_config(args.config, CONFIG_MANDATORY_KEYS)
114
114
 
115
+ # Validates and loads endpoints with proper endpoint file location
116
+ # This will initialise the endpoints singleton properly so that
117
+ # it can be used safely throughout the codebase with
118
+ # `AvailableEndpoints.get_instance()`
115
119
  _check_nlg_endpoint_validity(args.endpoints)
116
120
 
117
121
  training_files = [
@@ -109,7 +109,35 @@ class DynamoTrackerStore(TrackerStore, SerializedTrackerAsDict):
109
109
  await self.stream_events(tracker)
110
110
  serialized = self.serialise_tracker(tracker)
111
111
 
112
- self.db.put_item(Item=serialized)
112
+ full_tracker = await self.retrieve_full_tracker(tracker.sender_id)
113
+ if full_tracker is None:
114
+ self.db.put_item(Item=serialized)
115
+ return None
116
+
117
+ # return the latest events since the last user message
118
+ new_tracker = DialogueStateTracker.from_dict(
119
+ serialized["sender_id"], events_as_dict=serialized["events"]
120
+ )
121
+ new_events = new_tracker.get_last_turn_events()
122
+ new_serialized_events = [event.as_dict() for event in new_events]
123
+
124
+ # we need to save the full tracker if it is a new tracker
125
+ # without events following a user message
126
+ if not new_serialized_events:
127
+ self.db.put_item(Item=serialized)
128
+ return None
129
+
130
+ # append new events to the existing tracker
131
+ self.db.update_item(
132
+ Key={"sender_id": tracker.sender_id},
133
+ UpdateExpression="SET events = list_append(if_not_exists(events, :empty_list), :events)", # noqa: E501
134
+ ExpressionAttributeValues={
135
+ ":events": new_serialized_events,
136
+ ":empty_list": [],
137
+ },
138
+ ReturnValues="UPDATED_NEW",
139
+ )
140
+ return None
113
141
 
114
142
  async def delete(self, sender_id: Text) -> None:
115
143
  """Delete tracker for the given sender_id."""
@@ -181,7 +209,7 @@ class DynamoTrackerStore(TrackerStore, SerializedTrackerAsDict):
181
209
  events = rasa.utils.json_utils.replace_decimals_with_floats(
182
210
  dialogue["events"]
183
211
  )
184
- events_with_floats += events
212
+ events_with_floats.extend(events)
185
213
 
186
214
  if self.domain is None:
187
215
  slots = []
@@ -39,7 +39,6 @@ from rasa.model_manager.trainer_service import (
39
39
  update_training_status,
40
40
  )
41
41
  from rasa.model_manager.utils import (
42
- InvalidPathException,
43
42
  get_logs_content,
44
43
  logs_base_path,
45
44
  models_base_path,
@@ -52,7 +51,7 @@ from rasa.server import ErrorResponse
52
51
  from rasa.shared.exceptions import InvalidConfigException
53
52
  from rasa.shared.utils.yaml import dump_obj_as_yaml_to_string
54
53
  from rasa.studio.upload import build_calm_import_parts
55
- from rasa.utils.io import subpath
54
+ from rasa.utils.io import InvalidPathException, subpath
56
55
 
57
56
  dotenv.load_dotenv()
58
57
 
@@ -1170,6 +1170,23 @@ class DialogueStateTracker:
1170
1170
  "Example: `language: en`."
1171
1171
  )
1172
1172
 
1173
+ def get_last_turn_events(self) -> List[Event]:
1174
+ """Get all events of the last conversation turn."""
1175
+ last_user_message = self.get_last_event_for(
1176
+ UserUttered, event_verbosity=EventVerbosity.ALL
1177
+ )
1178
+ if not last_user_message:
1179
+ return []
1180
+
1181
+ last_turn_events = []
1182
+ for event in reversed(self.events):
1183
+ if event.timestamp >= last_user_message.timestamp:
1184
+ last_turn_events.append(event)
1185
+ else:
1186
+ break
1187
+
1188
+ return list(reversed(last_turn_events))
1189
+
1173
1190
 
1174
1191
  class TrackerEventDiffEngine:
1175
1192
  """Computes event difference of two trackers."""
@@ -1,9 +1,37 @@
1
- from typing import Iterable, List, Optional, Text
1
+ from typing import Any, Dict, Iterable, List, Optional, Text
2
+
3
+ from pydantic import BaseModel, Field
2
4
 
3
5
  from rasa.shared.core.domain import Domain
4
6
  from rasa.shared.core.flows import FlowsList
7
+ from rasa.shared.core.flows.yaml_flows_io import get_flows_as_json
5
8
  from rasa.shared.core.training_data.structures import StoryGraph
9
+ from rasa.shared.importers.importer import TrainingDataImporter
10
+ from rasa.shared.nlu.training_data.formats.rasa_yaml import RasaYAMLWriter
6
11
  from rasa.shared.nlu.training_data.training_data import TrainingData
12
+ from rasa.utils.json_utils import extract_values
13
+
14
+
15
+ class CALMUserData(BaseModel):
16
+ """All pieces that will be uploaded to Rasa Studio."""
17
+
18
+ flows: Dict[str, Any] = Field(default_factory=dict)
19
+ domain: Dict[str, Any] = Field(default_factory=dict)
20
+ config: Dict[str, Any] = Field(default_factory=dict)
21
+ endpoints: Dict[str, Any] = Field(default_factory=dict)
22
+ nlu: Dict[str, Any] = Field(default_factory=dict)
23
+
24
+
25
+ DOMAIN_KEYS = [
26
+ "version",
27
+ "actions",
28
+ "responses",
29
+ "slots",
30
+ "intents",
31
+ "entities",
32
+ "forms",
33
+ "session_config",
34
+ ]
7
35
 
8
36
 
9
37
  def training_data_from_paths(paths: Iterable[Text], language: Text) -> TrainingData:
@@ -34,3 +62,51 @@ def flows_from_paths(files: List[Text]) -> FlowsList:
34
62
  )
35
63
  flows.validate()
36
64
  return flows
65
+
66
+
67
+ def extract_calm_import_parts_from_importer(
68
+ importer: TrainingDataImporter,
69
+ config: Optional[Dict[str, Any]] = None,
70
+ endpoints: Optional[Dict[str, Any]] = None,
71
+ ) -> CALMUserData:
72
+ """Extracts CALMUserData from a TrainingDataImporter.
73
+
74
+ Args:
75
+ importer: The training data importer
76
+ data_paths: The path(s) to the training data for flows
77
+ config: Optional config dict, if not provided will use importer.get_config()
78
+ endpoints: Optional endpoints dict, defaults to empty dict
79
+
80
+ Returns:
81
+ CALMUserData containing flows, domain, config, endpoints, and nlu data
82
+ """
83
+ # Extract config
84
+ if config is None:
85
+ config = importer.get_config()
86
+
87
+ # Extract domain
88
+ domain_from_files = importer.get_user_domain().as_dict()
89
+ domain = extract_values(domain_from_files, DOMAIN_KEYS)
90
+
91
+ # Extract flows
92
+ flows = importer.get_user_flows()
93
+ flows_dict = get_flows_as_json(flows)
94
+
95
+ # Extract NLU data
96
+ nlu_data = importer.get_nlu_data()
97
+ nlu_examples = nlu_data.filter_training_examples(
98
+ lambda ex: ex.get("intent") in nlu_data.intents
99
+ )
100
+ nlu_dict = RasaYAMLWriter().training_data_to_dict(nlu_examples)
101
+
102
+ # Use provided endpoints or default to empty dict
103
+ if endpoints is None:
104
+ endpoints = {}
105
+
106
+ return CALMUserData(
107
+ flows=flows_dict or {},
108
+ domain=domain or {},
109
+ config=config or {},
110
+ endpoints=endpoints or {},
111
+ nlu=nlu_dict or {},
112
+ )
rasa/studio/upload.py CHANGED
@@ -7,7 +7,6 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Text, Tuple, Union
7
7
  import questionary
8
8
  import requests
9
9
  import structlog
10
- from pydantic import BaseModel, Field
11
10
 
12
11
  import rasa.cli.telemetry
13
12
  import rasa.cli.utils
@@ -24,9 +23,13 @@ from rasa.shared.constants import (
24
23
  DEFAULT_DOMAIN_PATHS,
25
24
  )
26
25
  from rasa.shared.core.domain import Domain
27
- from rasa.shared.core.flows.yaml_flows_io import YAMLFlowsReader, YamlFlowsWriter
26
+ from rasa.shared.core.flows.yaml_flows_io import YAMLFlowsReader
28
27
  from rasa.shared.exceptions import RasaException
29
- from rasa.shared.importers.importer import FlowSyncImporter, TrainingDataImporter
28
+ from rasa.shared.importers.importer import TrainingDataImporter
29
+ from rasa.shared.importers.utils import (
30
+ CALMUserData,
31
+ extract_calm_import_parts_from_importer,
32
+ )
30
33
  from rasa.shared.nlu.training_data.formats.rasa_yaml import (
31
34
  RasaYAMLReader,
32
35
  RasaYAMLWriter,
@@ -34,7 +37,6 @@ from rasa.shared.nlu.training_data.formats.rasa_yaml import (
34
37
  from rasa.shared.utils.llm import collect_custom_prompts
35
38
  from rasa.shared.utils.yaml import (
36
39
  dump_obj_as_yaml_to_string,
37
- read_yaml,
38
40
  read_yaml_file,
39
41
  )
40
42
  from rasa.studio import results_logger
@@ -43,6 +45,7 @@ from rasa.studio.config import StudioConfig
43
45
  from rasa.studio.results_logger import StudioResult, with_studio_error_handler
44
46
  from rasa.studio.utils import validate_argument_paths
45
47
  from rasa.telemetry import track_upload_to_studio_failed
48
+ from rasa.utils.json_utils import extract_values
46
49
 
47
50
  structlogger = structlog.get_logger()
48
51
 
@@ -68,16 +71,6 @@ DOMAIN_KEYS = [
68
71
  ]
69
72
 
70
73
 
71
- class CALMImportParts(BaseModel):
72
- """All pieces that will be uploaded to Rasa Studio."""
73
-
74
- flows: Dict[str, Any]
75
- domain: Dict[str, Any]
76
- config: Dict[str, Any]
77
- endpoints: Dict[str, Any]
78
- nlu: Dict[str, Any] = Field(default_factory=dict)
79
-
80
-
81
74
  def _get_selected_entities_and_intents(
82
75
  args: argparse.Namespace,
83
76
  intents_from_files: Set[Text],
@@ -171,6 +164,7 @@ def handle_upload(args: argparse.Namespace) -> None:
171
164
 
172
165
  config = read_yaml_file(args.config, expand_env_vars=False)
173
166
  assistant_name = args.assistant_name or _get_assistant_name(config)
167
+ args.assistant_name = assistant_name
174
168
  if not _handle_existing_assistant(
175
169
  assistant_name, studio_config.studio_url, verify, args
176
170
  ):
@@ -193,11 +187,6 @@ config_keys = [
193
187
  ]
194
188
 
195
189
 
196
- def extract_values(data: Dict, keys: List[Text]) -> Dict:
197
- """Extracts values for given keys from a dictionary."""
198
- return {key: data.get(key) for key in keys if data.get(key)}
199
-
200
-
201
190
  def _get_assistant_name(config: Dict[Text, Any]) -> str:
202
191
  config_assistant_id = config.get("assistant_id", "")
203
192
  assistant_name = questionary.text(
@@ -236,7 +225,7 @@ def build_calm_import_parts(
236
225
  config_path: Text,
237
226
  endpoints_path: Optional[Text] = None,
238
227
  assistant_name: Optional[Text] = None,
239
- ) -> Tuple[str, CALMImportParts]:
228
+ ) -> Tuple[str, CALMUserData]:
240
229
  """Builds the parts of the assistant to be uploaded to Studio.
241
230
 
242
231
  Args:
@@ -259,33 +248,10 @@ def build_calm_import_parts(
259
248
  endpoints = read_yaml_file(endpoints_path, expand_env_vars=False)
260
249
  assistant_name = assistant_name or _get_assistant_name(config)
261
250
 
262
- domain_from_files = importer.get_user_domain().as_dict()
263
- domain = extract_values(domain_from_files, DOMAIN_KEYS)
264
-
265
- flow_importer = FlowSyncImporter.load_from_dict(
266
- training_data_paths=[str(data_path)], expand_env_vars=False
267
- )
268
-
269
- flows = list(flow_importer.get_user_flows())
270
- flows_yaml = YamlFlowsWriter().dumps(flows)
271
- flows = read_yaml(flows_yaml, expand_env_vars=False)
272
-
273
- nlu_importer = TrainingDataImporter.load_from_dict(
274
- training_data_paths=[str(data_path)], expand_env_vars=False
275
- )
276
- nlu_data = nlu_importer.get_nlu_data()
277
- nlu_examples = nlu_data.filter_training_examples(
278
- lambda ex: ex.get("intent") in nlu_data.intents
279
- )
280
- nlu_examples_yaml = RasaYAMLWriter().dumps(nlu_examples)
281
- nlu = read_yaml(nlu_examples_yaml, expand_env_vars=False)
282
-
283
- parts = CALMImportParts(
284
- flows=flows,
285
- domain=domain,
251
+ parts = extract_calm_import_parts_from_importer(
252
+ importer=importer,
286
253
  config=config,
287
254
  endpoints=endpoints,
288
- nlu=nlu,
289
255
  )
290
256
 
291
257
  return assistant_name, parts
rasa/utils/json_utils.py CHANGED
@@ -1,6 +1,6 @@
1
1
  import json
2
2
  from decimal import Decimal
3
- from typing import Any, Text
3
+ from typing import Any, Dict, List, Text
4
4
 
5
5
 
6
6
  class DecimalEncoder(json.JSONEncoder):
@@ -58,3 +58,8 @@ def replace_decimals_with_floats(obj: Any) -> Any:
58
58
  Input `obj` with all `Decimal` types replaced by `float`s.
59
59
  """
60
60
  return json.loads(json.dumps(obj, cls=DecimalEncoder))
61
+
62
+
63
+ def extract_values(data: Dict, keys: List[Text]) -> Dict:
64
+ """Extracts values for given keys from a dictionary."""
65
+ return {key: data.get(key) for key in keys if data.get(key)}
rasa/utils/openapi.py ADDED
@@ -0,0 +1,144 @@
1
+ from typing import Any, Dict, List, Type
2
+
3
+ from pydantic.main import BaseModel
4
+ from sanic_openapi import openapi
5
+ from sanic_openapi.openapi3.types import Schema
6
+
7
+ _SUPPORTED_ATTRIBUTES = frozenset(["format", "enum", "required", "example"])
8
+
9
+
10
+ def _to_schema(
11
+ definition_stack: List[str], schema_def: Dict[str, Any], definitions: Dict[str, Any]
12
+ ) -> Schema:
13
+ type = schema_def.get("type")
14
+
15
+ if type == "object":
16
+ properties_spec = schema_def.get("properties", {})
17
+ properties = {}
18
+ for key in properties_spec:
19
+ properties[key] = _to_schema(
20
+ definition_stack=definition_stack,
21
+ schema_def=properties_spec[key],
22
+ definitions=definitions,
23
+ )
24
+ schema = openapi.Object(
25
+ title=schema_def.get("title"),
26
+ description=schema_def.get("description"),
27
+ required=schema_def.get("required"),
28
+ properties=properties,
29
+ )
30
+ elif type == "array":
31
+ schema = openapi.Array(
32
+ description=schema_def.get("description"),
33
+ required=schema_def.get("required"),
34
+ items=_to_schema(
35
+ definition_stack=definition_stack,
36
+ schema_def=schema_def.get("items"),
37
+ definitions=definitions,
38
+ ),
39
+ )
40
+ elif type is None:
41
+ if allof_spec := schema_def.get("allOf"): # Model, Enum
42
+ definition = allof_spec[0]["$ref"].split("/")[-1]
43
+ definition_data = definitions.get(definition)
44
+ if definition_data is None:
45
+ schema = openapi.Object(
46
+ title=definition, description=schema_def.get("description")
47
+ )
48
+ else:
49
+ schema = (
50
+ _to_schema(
51
+ definition_stack=definition_stack + [definition],
52
+ schema_def={**definition_data},
53
+ definitions=definitions,
54
+ )
55
+ if definition not in definition_stack
56
+ else openapi.Object(
57
+ title=definition, description=schema_def.get("description")
58
+ )
59
+ )
60
+
61
+ elif anyof_spec := schema_def.get("anyOf"): # Union
62
+ anyof = []
63
+ for any in anyof_spec:
64
+ if any.get("type"):
65
+ schema_type_obj = Schema(
66
+ **{
67
+ "type": any.get("type"),
68
+ "description": any.get("description"),
69
+ }
70
+ )
71
+ anyof.append(schema_type_obj)
72
+ else:
73
+ definition = any["$ref"].split("/")[-1]
74
+ if definition not in definition_stack:
75
+ definition_data = definitions.get(definition)
76
+ if definition_data is not None:
77
+ anyof.append(
78
+ _to_schema(
79
+ definition_stack=definition_stack + [definition],
80
+ schema_def=definition_data,
81
+ definitions=definitions,
82
+ )
83
+ )
84
+ else:
85
+ anyof.append(
86
+ openapi.Object(
87
+ title=definition,
88
+ description=schema_def.get(
89
+ "description", definition
90
+ ),
91
+ properties={},
92
+ )
93
+ )
94
+ else:
95
+ anyof.append(
96
+ openapi.Object(
97
+ title=definition,
98
+ description=schema_def.get("description", definition),
99
+ properties={},
100
+ )
101
+ )
102
+ schema = Schema(anyOf=anyof)
103
+ elif ref := schema_def.get("$ref"): # $ref
104
+ definition = ref.split("/")[-1]
105
+ definition_data = definitions.get(definition)
106
+ if definition_data is not None:
107
+ schema = _to_schema(
108
+ definition_stack=definition_stack,
109
+ schema_def=definition_data,
110
+ definitions=definitions,
111
+ )
112
+ else:
113
+ schema = openapi.Object(
114
+ title=definition, description=schema_def.get("description")
115
+ )
116
+ else: # Any type
117
+ schema = Schema(
118
+ **{"type": "object", "description": schema_def.get("description")}
119
+ )
120
+
121
+ else:
122
+ schema_spec = {
123
+ "type": schema_def.get("type"),
124
+ "description": schema_def.get("description"),
125
+ }
126
+ for spec in _SUPPORTED_ATTRIBUTES:
127
+ if schema_def.get(spec):
128
+ schema_spec[spec] = schema_def.get(spec)
129
+ schema = Schema(**schema_spec)
130
+
131
+ return schema
132
+
133
+
134
+ def model_to_schema(model: Type[BaseModel]) -> Schema:
135
+ schema = model.model_json_schema()
136
+ # Handle both $defs (newer JSON Schema) and definitions (older JSON Schema)
137
+ definitions = schema.get("$defs") or schema.get("definitions") or {}
138
+ return _to_schema(
139
+ definition_stack=[],
140
+ schema_def=dict(
141
+ filter(lambda key: key[0] not in ("definitions", "$defs"), schema.items())
142
+ ),
143
+ definitions=definitions,
144
+ )
rasa/utils/plotting.py CHANGED
@@ -99,7 +99,7 @@ def plot_confusion_matrix(
99
99
  zmax = confusion_matrix.max() if len(confusion_matrix) > 0 else 1
100
100
  plt.clf()
101
101
  if not color_map:
102
- color_map = plt.cm.Blues
102
+ color_map = plt.cm.get_cmap("Blues")
103
103
  plt.imshow(
104
104
  confusion_matrix,
105
105
  interpolation="nearest",
rasa/version.py CHANGED
@@ -1,3 +1,3 @@
1
1
  # this file will automatically be changed,
2
2
  # do not add anything but the version number here!
3
- __version__ = "3.13.1a14"
3
+ __version__ = "3.13.1a15"