rasa-pro 3.13.1a14__py3-none-any.whl → 3.13.1a16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic; see the package's registry page for more details.
- rasa/builder/config.py +11 -1
- rasa/builder/exceptions.py +6 -0
- rasa/builder/inkeep-rag-response-schema.json +64 -0
- rasa/builder/inkeep_document_retrieval.py +212 -0
- rasa/builder/llm_service.py +22 -32
- rasa/builder/main.py +95 -9
- rasa/builder/models.py +61 -10
- rasa/builder/project_generator.py +7 -6
- rasa/builder/scrape_rasa_docs.py +4 -4
- rasa/builder/service.py +626 -436
- rasa/builder/training_service.py +3 -3
- rasa/cli/inspect.py +7 -0
- rasa/cli/project_templates/telco/actions/actions_billing.py +6 -5
- rasa/cli/project_templates/telco/actions/actions_get_data_from_db.py +3 -2
- rasa/cli/shell.py +6 -1
- rasa/cli/train.py +4 -0
- rasa/core/tracker_stores/dynamo_tracker_store.py +30 -2
- rasa/model_manager/model_api.py +1 -2
- rasa/shared/core/trackers.py +17 -0
- rasa/shared/importers/utils.py +77 -1
- rasa/studio/upload.py +11 -45
- rasa/utils/json_utils.py +6 -1
- rasa/utils/openapi.py +144 -0
- rasa/utils/plotting.py +1 -1
- rasa/version.py +1 -1
- {rasa_pro-3.13.1a14.dist-info → rasa_pro-3.13.1a16.dist-info}/METADATA +10 -9
- {rasa_pro-3.13.1a14.dist-info → rasa_pro-3.13.1a16.dist-info}/RECORD +30 -27
- {rasa_pro-3.13.1a14.dist-info → rasa_pro-3.13.1a16.dist-info}/NOTICE +0 -0
- {rasa_pro-3.13.1a14.dist-info → rasa_pro-3.13.1a16.dist-info}/WHEEL +0 -0
- {rasa_pro-3.13.1a14.dist-info → rasa_pro-3.13.1a16.dist-info}/entry_points.txt +0 -0
rasa/builder/training_service.py
CHANGED
|
@@ -9,7 +9,7 @@ from rasa.builder import config
|
|
|
9
9
|
from rasa.builder.exceptions import AgentLoadError, TrainingError
|
|
10
10
|
from rasa.core import agent
|
|
11
11
|
from rasa.core.utils import AvailableEndpoints, read_endpoints_from_path
|
|
12
|
-
from rasa.model_training import train
|
|
12
|
+
from rasa.model_training import TrainingResult, train
|
|
13
13
|
from rasa.shared.importers.importer import TrainingDataImporter
|
|
14
14
|
from rasa.shared.utils.yaml import dump_obj_as_yaml_to_string
|
|
15
15
|
|
|
@@ -55,7 +55,7 @@ async def train_and_load_agent(importer: TrainingDataImporter) -> agent.Agent:
|
|
|
55
55
|
raise TrainingError(f"SystemExit during training: {e}")
|
|
56
56
|
|
|
57
57
|
|
|
58
|
-
async def _setup_endpoints():
|
|
58
|
+
async def _setup_endpoints() -> None:
|
|
59
59
|
"""Setup endpoints configuration for training."""
|
|
60
60
|
try:
|
|
61
61
|
with tempfile.NamedTemporaryFile(
|
|
@@ -75,7 +75,7 @@ async def _setup_endpoints():
|
|
|
75
75
|
raise TrainingError(f"Failed to setup endpoints: {e}")
|
|
76
76
|
|
|
77
77
|
|
|
78
|
-
async def _train_model(importer: TrainingDataImporter):
|
|
78
|
+
async def _train_model(importer: TrainingDataImporter) -> TrainingResult:
|
|
79
79
|
"""Train the Rasa model."""
|
|
80
80
|
try:
|
|
81
81
|
structlogger.info("training.started")
|
rasa/cli/inspect.py
CHANGED
|
@@ -9,6 +9,7 @@ from rasa import telemetry
|
|
|
9
9
|
from rasa.cli import SubParsersAction
|
|
10
10
|
from rasa.cli.arguments import shell as arguments
|
|
11
11
|
from rasa.core import constants
|
|
12
|
+
from rasa.core.available_endpoints import AvailableEndpoints
|
|
12
13
|
from rasa.engine.storage.local_model_storage import LocalModelStorage
|
|
13
14
|
from rasa.exceptions import ModelNotFound
|
|
14
15
|
from rasa.model import get_local_model
|
|
@@ -83,6 +84,12 @@ def inspect(args: argparse.Namespace) -> None:
|
|
|
83
84
|
|
|
84
85
|
model = get_validated_path(args.model, "model", DEFAULT_MODELS_PATH)
|
|
85
86
|
|
|
87
|
+
# Load endpoints with proper endpoint file location
|
|
88
|
+
# This will initialise the endpoints singleton properly so that
|
|
89
|
+
# it can be used safely throughout the codebase with
|
|
90
|
+
# `AvailableEndpoints.get_instance()`
|
|
91
|
+
AvailableEndpoints.get_instance(args.endpoints)
|
|
92
|
+
|
|
86
93
|
try:
|
|
87
94
|
model = get_local_model(model)
|
|
88
95
|
except ModelNotFound:
|
|
@@ -1,16 +1,17 @@
|
|
|
1
1
|
import csv
|
|
2
2
|
import logging
|
|
3
3
|
from datetime import datetime
|
|
4
|
+
from typing import Dict, List
|
|
4
5
|
|
|
5
6
|
from rasa_sdk import Action
|
|
6
7
|
from rasa_sdk.events import SlotSet
|
|
7
8
|
|
|
8
9
|
|
|
9
10
|
class ActionVerifyBillByDate(Action):
|
|
10
|
-
def name(self):
|
|
11
|
+
def name(self) -> str:
|
|
11
12
|
return "action_verify_bill_by_date"
|
|
12
13
|
|
|
13
|
-
def text_to_date(month_text):
|
|
14
|
+
def text_to_date(self, month_text: str) -> str:
|
|
14
15
|
try:
|
|
15
16
|
# Get the current year
|
|
16
17
|
current_year = datetime.now().year
|
|
@@ -28,7 +29,7 @@ class ActionVerifyBillByDate(Action):
|
|
|
28
29
|
except ValueError:
|
|
29
30
|
return "Invalid format. Please use a full month name (e.g., 'March')."
|
|
30
31
|
|
|
31
|
-
def run(self, dispatcher, tracker, domain):
|
|
32
|
+
def run(self, dispatcher, tracker, domain) -> List[Dict]:
|
|
32
33
|
# Get customer ID and date from slots
|
|
33
34
|
customer_id = tracker.get_slot("customer_id")
|
|
34
35
|
bill_month = tracker.get_slot("bill_month")
|
|
@@ -113,10 +114,10 @@ class ActionVerifyBillByDate(Action):
|
|
|
113
114
|
|
|
114
115
|
|
|
115
116
|
class ActionRecapBill(Action):
|
|
116
|
-
def name(self):
|
|
117
|
+
def name(self) -> str:
|
|
117
118
|
return "action_recap_bill"
|
|
118
119
|
|
|
119
|
-
def run(self, dispatcher, tracker, domain):
|
|
120
|
+
def run(self, dispatcher, tracker, domain) -> List[Dict]:
|
|
120
121
|
# Get customer_id and bill_date from slots
|
|
121
122
|
customer_id = tracker.get_slot("customer_id")
|
|
122
123
|
bill_month = tracker.get_slot("bill_month")
|
|
@@ -1,14 +1,15 @@
|
|
|
1
1
|
import csv
|
|
2
|
+
from typing import Dict, List
|
|
2
3
|
|
|
3
4
|
from rasa_sdk import Action
|
|
4
5
|
from rasa_sdk.events import SlotSet
|
|
5
6
|
|
|
6
7
|
|
|
7
8
|
class ActionGetCustomerInfo(Action):
|
|
8
|
-
def name(self):
|
|
9
|
+
def name(self) -> str:
|
|
9
10
|
return "action_get_customer_info"
|
|
10
11
|
|
|
11
|
-
def run(self, dispatcher, tracker, domain):
|
|
12
|
+
def run(self, dispatcher, tracker, domain) -> List[Dict]:
|
|
12
13
|
# Load CSV file
|
|
13
14
|
file_path = "csvs/customers.csv" # get information from your DBs
|
|
14
15
|
customer_id = tracker.get_slot("customer_id")
|
rasa/cli/shell.py
CHANGED
|
@@ -6,6 +6,7 @@ from typing import List
|
|
|
6
6
|
from rasa import telemetry
|
|
7
7
|
from rasa.cli import SubParsersAction
|
|
8
8
|
from rasa.cli.arguments import shell as arguments
|
|
9
|
+
from rasa.core.available_endpoints import AvailableEndpoints
|
|
9
10
|
from rasa.engine.storage.local_model_storage import LocalModelStorage
|
|
10
11
|
from rasa.exceptions import ModelNotFound
|
|
11
12
|
from rasa.model import get_local_model
|
|
@@ -105,7 +106,11 @@ def shell(args: argparse.Namespace) -> None:
|
|
|
105
106
|
from rasa.shared.constants import DEFAULT_MODELS_PATH
|
|
106
107
|
|
|
107
108
|
args.connector = "cmdline"
|
|
108
|
-
|
|
109
|
+
# Load endpoints with proper endpoint file location
|
|
110
|
+
# This will initialise the endpoints singleton properly so that
|
|
111
|
+
# it can be used safely throughout the codebase with
|
|
112
|
+
# `AvailableEndpoints.get_instance()`
|
|
113
|
+
AvailableEndpoints.get_instance(args.endpoints)
|
|
109
114
|
model = get_validated_path(args.model, "model", DEFAULT_MODELS_PATH)
|
|
110
115
|
|
|
111
116
|
try:
|
rasa/cli/train.py
CHANGED
|
@@ -112,6 +112,10 @@ def run_training(args: argparse.Namespace, can_exit: bool = False) -> Optional[T
|
|
|
112
112
|
)
|
|
113
113
|
config = rasa.cli.utils.get_validated_config(args.config, CONFIG_MANDATORY_KEYS)
|
|
114
114
|
|
|
115
|
+
# Validates and loads endpoints with proper endpoint file location
|
|
116
|
+
# This will initialise the endpoints singleton properly so that
|
|
117
|
+
# it can be used safely throughout the codebase with
|
|
118
|
+
# `AvailableEndpoints.get_instance()`
|
|
115
119
|
_check_nlg_endpoint_validity(args.endpoints)
|
|
116
120
|
|
|
117
121
|
training_files = [
|
|
@@ -109,7 +109,35 @@ class DynamoTrackerStore(TrackerStore, SerializedTrackerAsDict):
|
|
|
109
109
|
await self.stream_events(tracker)
|
|
110
110
|
serialized = self.serialise_tracker(tracker)
|
|
111
111
|
|
|
112
|
-
self.
|
|
112
|
+
full_tracker = await self.retrieve_full_tracker(tracker.sender_id)
|
|
113
|
+
if full_tracker is None:
|
|
114
|
+
self.db.put_item(Item=serialized)
|
|
115
|
+
return None
|
|
116
|
+
|
|
117
|
+
# return the latest events since the last user message
|
|
118
|
+
new_tracker = DialogueStateTracker.from_dict(
|
|
119
|
+
serialized["sender_id"], events_as_dict=serialized["events"]
|
|
120
|
+
)
|
|
121
|
+
new_events = new_tracker.get_last_turn_events()
|
|
122
|
+
new_serialized_events = [event.as_dict() for event in new_events]
|
|
123
|
+
|
|
124
|
+
# we need to save the full tracker if it is a new tracker
|
|
125
|
+
# without events following a user message
|
|
126
|
+
if not new_serialized_events:
|
|
127
|
+
self.db.put_item(Item=serialized)
|
|
128
|
+
return None
|
|
129
|
+
|
|
130
|
+
# append new events to the existing tracker
|
|
131
|
+
self.db.update_item(
|
|
132
|
+
Key={"sender_id": tracker.sender_id},
|
|
133
|
+
UpdateExpression="SET events = list_append(if_not_exists(events, :empty_list), :events)", # noqa: E501
|
|
134
|
+
ExpressionAttributeValues={
|
|
135
|
+
":events": new_serialized_events,
|
|
136
|
+
":empty_list": [],
|
|
137
|
+
},
|
|
138
|
+
ReturnValues="UPDATED_NEW",
|
|
139
|
+
)
|
|
140
|
+
return None
|
|
113
141
|
|
|
114
142
|
async def delete(self, sender_id: Text) -> None:
|
|
115
143
|
"""Delete tracker for the given sender_id."""
|
|
@@ -181,7 +209,7 @@ class DynamoTrackerStore(TrackerStore, SerializedTrackerAsDict):
|
|
|
181
209
|
events = rasa.utils.json_utils.replace_decimals_with_floats(
|
|
182
210
|
dialogue["events"]
|
|
183
211
|
)
|
|
184
|
-
events_with_floats
|
|
212
|
+
events_with_floats.extend(events)
|
|
185
213
|
|
|
186
214
|
if self.domain is None:
|
|
187
215
|
slots = []
|
rasa/model_manager/model_api.py
CHANGED
|
@@ -39,7 +39,6 @@ from rasa.model_manager.trainer_service import (
|
|
|
39
39
|
update_training_status,
|
|
40
40
|
)
|
|
41
41
|
from rasa.model_manager.utils import (
|
|
42
|
-
InvalidPathException,
|
|
43
42
|
get_logs_content,
|
|
44
43
|
logs_base_path,
|
|
45
44
|
models_base_path,
|
|
@@ -52,7 +51,7 @@ from rasa.server import ErrorResponse
|
|
|
52
51
|
from rasa.shared.exceptions import InvalidConfigException
|
|
53
52
|
from rasa.shared.utils.yaml import dump_obj_as_yaml_to_string
|
|
54
53
|
from rasa.studio.upload import build_calm_import_parts
|
|
55
|
-
from rasa.utils.io import subpath
|
|
54
|
+
from rasa.utils.io import InvalidPathException, subpath
|
|
56
55
|
|
|
57
56
|
dotenv.load_dotenv()
|
|
58
57
|
|
rasa/shared/core/trackers.py
CHANGED
|
@@ -1170,6 +1170,23 @@ class DialogueStateTracker:
|
|
|
1170
1170
|
"Example: `language: en`."
|
|
1171
1171
|
)
|
|
1172
1172
|
|
|
1173
|
+
def get_last_turn_events(self) -> List[Event]:
    """Get all events of the last conversation turn.

    The turn starts at the most recent `UserUttered` event, as returned by
    `get_last_event_for(UserUttered, event_verbosity=EventVerbosity.ALL)`, and
    extends to the newest event of the tracker.

    Returns:
        The events of the last turn in chronological order, beginning with the
        last user message, or an empty list if there is no user message.
    """
    last_user_message = self.get_last_event_for(
        UserUttered, event_verbosity=EventVerbosity.ALL
    )
    if last_user_message is None:
        return []

    last_turn_events: List[Event] = []
    for event in reversed(self.events):
        # Anything older than the user message cannot belong to this turn.
        if event.timestamp < last_user_message.timestamp:
            break
        last_turn_events.append(event)
        # Stop at the user message itself so that earlier events which merely
        # share its timestamp (coarse clocks, restored trackers) are not pulled
        # into the turn. Callers that append these events to a stored tracker
        # would otherwise duplicate them.
        if event is last_user_message:
            break

    last_turn_events.reverse()
    return last_turn_events
|
|
1189
|
+
|
|
1173
1190
|
|
|
1174
1191
|
class TrackerEventDiffEngine:
|
|
1175
1192
|
"""Computes event difference of two trackers."""
|
rasa/shared/importers/utils.py
CHANGED
|
@@ -1,9 +1,37 @@
|
|
|
1
|
-
from typing import Iterable, List, Optional, Text
|
|
1
|
+
from typing import Any, Dict, Iterable, List, Optional, Text
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Field
|
|
2
4
|
|
|
3
5
|
from rasa.shared.core.domain import Domain
|
|
4
6
|
from rasa.shared.core.flows import FlowsList
|
|
7
|
+
from rasa.shared.core.flows.yaml_flows_io import get_flows_as_json
|
|
5
8
|
from rasa.shared.core.training_data.structures import StoryGraph
|
|
9
|
+
from rasa.shared.importers.importer import TrainingDataImporter
|
|
10
|
+
from rasa.shared.nlu.training_data.formats.rasa_yaml import RasaYAMLWriter
|
|
6
11
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
12
|
+
from rasa.utils.json_utils import extract_values
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class CALMUserData(BaseModel):
    """All pieces that will be uploaded to Rasa Studio."""

    # Each part is a str-keyed dict. Parts that are not supplied default to a
    # fresh empty dict per instance (default_factory avoids a shared default).
    flows: Dict[str, Any] = Field(default_factory=dict)  # flows payload
    domain: Dict[str, Any] = Field(default_factory=dict)  # domain payload
    config: Dict[str, Any] = Field(default_factory=dict)  # model config
    endpoints: Dict[str, Any] = Field(default_factory=dict)  # endpoints config
    nlu: Dict[str, Any] = Field(default_factory=dict)  # NLU training data
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# Top-level domain sections that `extract_calm_import_parts_from_importer`
# keeps (via `extract_values`) when building the Studio payload; every other
# key of the domain dict is left out.
DOMAIN_KEYS = [
    "version",
    "actions",
    "responses",
    "slots",
    "intents",
    "entities",
    "forms",
    "session_config",
]
|
|
7
35
|
|
|
8
36
|
|
|
9
37
|
def training_data_from_paths(paths: Iterable[Text], language: Text) -> TrainingData:
|
|
@@ -34,3 +62,51 @@ def flows_from_paths(files: List[Text]) -> FlowsList:
|
|
|
34
62
|
)
|
|
35
63
|
flows.validate()
|
|
36
64
|
return flows
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def extract_calm_import_parts_from_importer(
    importer: TrainingDataImporter,
    config: Optional[Dict[str, Any]] = None,
    endpoints: Optional[Dict[str, Any]] = None,
) -> CALMUserData:
    """Extracts CALMUserData from a TrainingDataImporter.

    Args:
        importer: The training data importer. It supplies the user domain,
            the user flows, the NLU data and, when `config` is not given,
            the config.
        config: Optional config dict. If not provided, `importer.get_config()`
            is used.
        endpoints: Optional endpoints dict. Defaults to an empty dict.

    Returns:
        CALMUserData containing flows, domain, config, endpoints, and nlu data.
        Any part that is missing or empty is stored as an empty dict.
    """
    if config is None:
        config = importer.get_config()

    # Keep only the domain sections listed in DOMAIN_KEYS. `extract_values`
    # also drops sections whose value is empty/falsy.
    domain_from_files = importer.get_user_domain().as_dict()
    domain = extract_values(domain_from_files, DOMAIN_KEYS)

    flows_dict = get_flows_as_json(importer.get_user_flows())

    # Keep only NLU examples whose intent is one of `nlu_data.intents`.
    # The intent collection is looked up once instead of once per example.
    nlu_data = importer.get_nlu_data()
    known_intents = nlu_data.intents
    nlu_examples = nlu_data.filter_training_examples(
        lambda example: example.get("intent") in known_intents
    )
    nlu_dict = RasaYAMLWriter().training_data_to_dict(nlu_examples)

    # `or {}` normalises None/empty parts (including a missing `endpoints`)
    # to an empty dict.
    return CALMUserData(
        flows=flows_dict or {},
        domain=domain or {},
        config=config or {},
        endpoints=endpoints or {},
        nlu=nlu_dict or {},
    )
|
rasa/studio/upload.py
CHANGED
|
@@ -7,7 +7,6 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Text, Tuple, Union
|
|
|
7
7
|
import questionary
|
|
8
8
|
import requests
|
|
9
9
|
import structlog
|
|
10
|
-
from pydantic import BaseModel, Field
|
|
11
10
|
|
|
12
11
|
import rasa.cli.telemetry
|
|
13
12
|
import rasa.cli.utils
|
|
@@ -24,9 +23,13 @@ from rasa.shared.constants import (
|
|
|
24
23
|
DEFAULT_DOMAIN_PATHS,
|
|
25
24
|
)
|
|
26
25
|
from rasa.shared.core.domain import Domain
|
|
27
|
-
from rasa.shared.core.flows.yaml_flows_io import YAMLFlowsReader
|
|
26
|
+
from rasa.shared.core.flows.yaml_flows_io import YAMLFlowsReader
|
|
28
27
|
from rasa.shared.exceptions import RasaException
|
|
29
|
-
from rasa.shared.importers.importer import
|
|
28
|
+
from rasa.shared.importers.importer import TrainingDataImporter
|
|
29
|
+
from rasa.shared.importers.utils import (
|
|
30
|
+
CALMUserData,
|
|
31
|
+
extract_calm_import_parts_from_importer,
|
|
32
|
+
)
|
|
30
33
|
from rasa.shared.nlu.training_data.formats.rasa_yaml import (
|
|
31
34
|
RasaYAMLReader,
|
|
32
35
|
RasaYAMLWriter,
|
|
@@ -34,7 +37,6 @@ from rasa.shared.nlu.training_data.formats.rasa_yaml import (
|
|
|
34
37
|
from rasa.shared.utils.llm import collect_custom_prompts
|
|
35
38
|
from rasa.shared.utils.yaml import (
|
|
36
39
|
dump_obj_as_yaml_to_string,
|
|
37
|
-
read_yaml,
|
|
38
40
|
read_yaml_file,
|
|
39
41
|
)
|
|
40
42
|
from rasa.studio import results_logger
|
|
@@ -43,6 +45,7 @@ from rasa.studio.config import StudioConfig
|
|
|
43
45
|
from rasa.studio.results_logger import StudioResult, with_studio_error_handler
|
|
44
46
|
from rasa.studio.utils import validate_argument_paths
|
|
45
47
|
from rasa.telemetry import track_upload_to_studio_failed
|
|
48
|
+
from rasa.utils.json_utils import extract_values
|
|
46
49
|
|
|
47
50
|
structlogger = structlog.get_logger()
|
|
48
51
|
|
|
@@ -68,16 +71,6 @@ DOMAIN_KEYS = [
|
|
|
68
71
|
]
|
|
69
72
|
|
|
70
73
|
|
|
71
|
-
class CALMImportParts(BaseModel):
|
|
72
|
-
"""All pieces that will be uploaded to Rasa Studio."""
|
|
73
|
-
|
|
74
|
-
flows: Dict[str, Any]
|
|
75
|
-
domain: Dict[str, Any]
|
|
76
|
-
config: Dict[str, Any]
|
|
77
|
-
endpoints: Dict[str, Any]
|
|
78
|
-
nlu: Dict[str, Any] = Field(default_factory=dict)
|
|
79
|
-
|
|
80
|
-
|
|
81
74
|
def _get_selected_entities_and_intents(
|
|
82
75
|
args: argparse.Namespace,
|
|
83
76
|
intents_from_files: Set[Text],
|
|
@@ -171,6 +164,7 @@ def handle_upload(args: argparse.Namespace) -> None:
|
|
|
171
164
|
|
|
172
165
|
config = read_yaml_file(args.config, expand_env_vars=False)
|
|
173
166
|
assistant_name = args.assistant_name or _get_assistant_name(config)
|
|
167
|
+
args.assistant_name = assistant_name
|
|
174
168
|
if not _handle_existing_assistant(
|
|
175
169
|
assistant_name, studio_config.studio_url, verify, args
|
|
176
170
|
):
|
|
@@ -193,11 +187,6 @@ config_keys = [
|
|
|
193
187
|
]
|
|
194
188
|
|
|
195
189
|
|
|
196
|
-
def extract_values(data: Dict, keys: List[Text]) -> Dict:
|
|
197
|
-
"""Extracts values for given keys from a dictionary."""
|
|
198
|
-
return {key: data.get(key) for key in keys if data.get(key)}
|
|
199
|
-
|
|
200
|
-
|
|
201
190
|
def _get_assistant_name(config: Dict[Text, Any]) -> str:
|
|
202
191
|
config_assistant_id = config.get("assistant_id", "")
|
|
203
192
|
assistant_name = questionary.text(
|
|
@@ -236,7 +225,7 @@ def build_calm_import_parts(
|
|
|
236
225
|
config_path: Text,
|
|
237
226
|
endpoints_path: Optional[Text] = None,
|
|
238
227
|
assistant_name: Optional[Text] = None,
|
|
239
|
-
) -> Tuple[str,
|
|
228
|
+
) -> Tuple[str, CALMUserData]:
|
|
240
229
|
"""Builds the parts of the assistant to be uploaded to Studio.
|
|
241
230
|
|
|
242
231
|
Args:
|
|
@@ -259,33 +248,10 @@ def build_calm_import_parts(
|
|
|
259
248
|
endpoints = read_yaml_file(endpoints_path, expand_env_vars=False)
|
|
260
249
|
assistant_name = assistant_name or _get_assistant_name(config)
|
|
261
250
|
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
flow_importer = FlowSyncImporter.load_from_dict(
|
|
266
|
-
training_data_paths=[str(data_path)], expand_env_vars=False
|
|
267
|
-
)
|
|
268
|
-
|
|
269
|
-
flows = list(flow_importer.get_user_flows())
|
|
270
|
-
flows_yaml = YamlFlowsWriter().dumps(flows)
|
|
271
|
-
flows = read_yaml(flows_yaml, expand_env_vars=False)
|
|
272
|
-
|
|
273
|
-
nlu_importer = TrainingDataImporter.load_from_dict(
|
|
274
|
-
training_data_paths=[str(data_path)], expand_env_vars=False
|
|
275
|
-
)
|
|
276
|
-
nlu_data = nlu_importer.get_nlu_data()
|
|
277
|
-
nlu_examples = nlu_data.filter_training_examples(
|
|
278
|
-
lambda ex: ex.get("intent") in nlu_data.intents
|
|
279
|
-
)
|
|
280
|
-
nlu_examples_yaml = RasaYAMLWriter().dumps(nlu_examples)
|
|
281
|
-
nlu = read_yaml(nlu_examples_yaml, expand_env_vars=False)
|
|
282
|
-
|
|
283
|
-
parts = CALMImportParts(
|
|
284
|
-
flows=flows,
|
|
285
|
-
domain=domain,
|
|
251
|
+
parts = extract_calm_import_parts_from_importer(
|
|
252
|
+
importer=importer,
|
|
286
253
|
config=config,
|
|
287
254
|
endpoints=endpoints,
|
|
288
|
-
nlu=nlu,
|
|
289
255
|
)
|
|
290
256
|
|
|
291
257
|
return assistant_name, parts
|
rasa/utils/json_utils.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import json
|
|
2
2
|
from decimal import Decimal
|
|
3
|
-
from typing import Any, Text
|
|
3
|
+
from typing import Any, Dict, List, Text
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
class DecimalEncoder(json.JSONEncoder):
|
|
@@ -58,3 +58,8 @@ def replace_decimals_with_floats(obj: Any) -> Any:
|
|
|
58
58
|
Input `obj` with all `Decimal` types replaced by `float`s.
|
|
59
59
|
"""
|
|
60
60
|
return json.loads(json.dumps(obj, cls=DecimalEncoder))
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def extract_values(data: Dict, keys: List[Text]) -> Dict:
    """Pick the entries of `data` whose keys are listed in `keys`.

    Keys that are missing from `data`, or whose value is falsy (for example
    `None`, `0`, `""` or an empty container), are left out. The result follows
    the order of `keys`.

    Args:
        data: The dictionary to read from.
        keys: The keys to look up.

    Returns:
        A new dictionary containing only the selected, truthy entries.
    """
    selected = {}
    for key in keys:
        value = data.get(key)
        if value:
            selected[key] = value
    return selected
|
rasa/utils/openapi.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Type
|
|
2
|
+
|
|
3
|
+
from pydantic.main import BaseModel
|
|
4
|
+
from sanic_openapi import openapi
|
|
5
|
+
from sanic_openapi.openapi3.types import Schema
|
|
6
|
+
|
|
7
|
+
_SUPPORTED_ATTRIBUTES = frozenset(["format", "enum", "required", "example"])
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _to_schema(
    definition_stack: List[str], schema_def: Dict[str, Any], definitions: Dict[str, Any]
) -> Schema:
    """Convert a JSON-schema fragment (as emitted by pydantic) to a sanic-openapi schema.

    Args:
        definition_stack: Names of the definitions currently being expanded,
            outermost first. A definition that is already on the stack is
            rendered as a bare placeholder object instead of being expanded
            again, so self-referencing models terminate.
        schema_def: The JSON-schema fragment to convert.
        definitions: Named definitions (`$defs` / `definitions`) against which
            `$ref` pointers are resolved.

    Returns:
        The corresponding sanic-openapi `Schema`.
    """
    schema_type = schema_def.get("type")
    description = schema_def.get("description")

    def resolve(definition: str, **placeholder_kwargs: Any) -> Schema:
        """Expand a named definition, or return a placeholder object.

        The placeholder is an `openapi.Object` titled after the definition and
        built from `placeholder_kwargs`. It is used when the definition is
        unknown or is already being expanded further up the stack.
        """
        definition_data = definitions.get(definition)
        if definition_data is None or definition in definition_stack:
            return openapi.Object(title=definition, **placeholder_kwargs)
        return _to_schema(
            definition_stack=definition_stack + [definition],
            schema_def=definition_data,
            definitions=definitions,
        )

    if schema_type == "object":
        properties = {
            key: _to_schema(
                definition_stack=definition_stack,
                schema_def=property_def,
                definitions=definitions,
            )
            for key, property_def in schema_def.get("properties", {}).items()
        }
        return openapi.Object(
            title=schema_def.get("title"),
            description=description,
            required=schema_def.get("required"),
            properties=properties,
        )

    if schema_type == "array":
        return openapi.Array(
            description=description,
            required=schema_def.get("required"),
            # An array schema may come without `items`; treat its items as
            # "any" instead of failing on `None`.
            items=_to_schema(
                definition_stack=definition_stack,
                schema_def=schema_def.get("items") or {},
                definitions=definitions,
            ),
        )

    if schema_type is not None:  # Any other explicit type (string, integer, ...)
        schema_spec = {"type": schema_type, "description": description}
        for attribute in _SUPPORTED_ATTRIBUTES:
            if schema_def.get(attribute):
                schema_spec[attribute] = schema_def.get(attribute)
        return Schema(**schema_spec)

    if allof_spec := schema_def.get("allOf"):  # Model, Enum
        definition = allof_spec[0]["$ref"].split("/")[-1]
        return resolve(definition, description=description)

    if anyof_spec := schema_def.get("anyOf"):  # Union
        variants = []
        for variant in anyof_spec:
            if variant.get("type"):
                variants.append(
                    Schema(
                        type=variant.get("type"),
                        description=variant.get("description"),
                    )
                )
            elif variant_ref := variant.get("$ref"):
                definition = variant_ref.split("/")[-1]
                variants.append(
                    resolve(
                        definition,
                        description=schema_def.get("description", definition),
                        properties={},
                    )
                )
            else:
                # Variants with neither `type` nor `$ref` (e.g. `{}` for Any or
                # a nested composite) are converted recursively instead of
                # raising KeyError on a missing `$ref`.
                variants.append(
                    _to_schema(
                        definition_stack=definition_stack,
                        schema_def=variant,
                        definitions=definitions,
                    )
                )
        return Schema(anyOf=variants)

    if ref := schema_def.get("$ref"):  # $ref
        # Guarded like allOf/anyOf: a model that references itself directly
        # (or through a list) would otherwise recurse without end.
        return resolve(ref.split("/")[-1], description=description)

    # Any type
    return Schema(type="object", description=description)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def model_to_schema(model: Type[BaseModel]) -> Schema:
    """Build a sanic-openapi schema from a pydantic model's JSON schema.

    The model's shared definitions are split off from the root schema and
    passed separately so that `$ref` pointers can be resolved against them.

    Args:
        model: The pydantic model class to describe.

    Returns:
        The sanic-openapi `Schema` for the model.
    """
    json_schema = model.model_json_schema()
    # Handle both $defs (newer JSON Schema) and definitions (older JSON Schema)
    shared_definitions = (
        json_schema.get("$defs") or json_schema.get("definitions") or {}
    )
    root_schema = {
        name: value
        for name, value in json_schema.items()
        if name not in ("definitions", "$defs")
    }
    return _to_schema(
        definition_stack=[],
        schema_def=root_schema,
        definitions=shared_definitions,
    )
|
rasa/utils/plotting.py
CHANGED
|
@@ -99,7 +99,7 @@ def plot_confusion_matrix(
|
|
|
99
99
|
zmax = confusion_matrix.max() if len(confusion_matrix) > 0 else 1
|
|
100
100
|
plt.clf()
|
|
101
101
|
if not color_map:
|
|
102
|
-
color_map = plt.cm.Blues
|
|
102
|
+
color_map = plt.cm.get_cmap("Blues")
|
|
103
103
|
plt.imshow(
|
|
104
104
|
confusion_matrix,
|
|
105
105
|
interpolation="nearest",
|
rasa/version.py
CHANGED