rasa-pro 3.11.13__py3-none-any.whl → 3.11.14__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the registry.
Potentially problematic release: this version of rasa-pro has been flagged in the registry; consult the registry advisory for details.
- rasa/api.py +4 -0
- rasa/cli/arguments/default_arguments.py +13 -1
- rasa/cli/arguments/train.py +2 -0
- rasa/cli/train.py +1 -0
- rasa/constants.py +2 -0
- rasa/core/nlg/contextual_response_rephraser.py +40 -14
- rasa/core/nlg/summarize.py +37 -5
- rasa/core/persistor.py +55 -20
- rasa/core/policies/enterprise_search_policy.py +10 -7
- rasa/core/policies/intentless_policy.py +17 -11
- rasa/core/run.py +7 -2
- rasa/dialogue_understanding/coexistence/llm_based_router.py +11 -6
- rasa/dialogue_understanding/generator/constants.py +6 -4
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +1 -1
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +1 -1
- rasa/dialogue_understanding/processor/command_processor_component.py +3 -3
- rasa/engine/recipes/default_recipe.py +26 -2
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +4 -2
- rasa/model_manager/config.py +3 -1
- rasa/model_manager/model_api.py +1 -2
- rasa/model_manager/runner_service.py +8 -4
- rasa/model_manager/trainer_service.py +1 -0
- rasa/model_training.py +12 -3
- rasa/server.py +6 -2
- rasa/shared/constants.py +6 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +14 -10
- rasa/shared/providers/_configs/openai_client_config.py +13 -9
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +3 -0
- rasa/shared/providers/llm/_base_litellm_client.py +5 -2
- rasa/shared/utils/llm.py +8 -2
- rasa/version.py +1 -1
- {rasa_pro-3.11.13.dist-info → rasa_pro-3.11.14.dist-info}/METADATA +1 -1
- {rasa_pro-3.11.13.dist-info → rasa_pro-3.11.14.dist-info}/RECORD +36 -36
- {rasa_pro-3.11.13.dist-info → rasa_pro-3.11.14.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.13.dist-info → rasa_pro-3.11.14.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.13.dist-info → rasa_pro-3.11.14.dist-info}/entry_points.txt +0 -0
rasa/api.py
CHANGED
@@ -81,6 +81,7 @@ def train(
     remote_storage: Optional[StorageType] = None,
     file_importer: Optional["TrainingDataImporter"] = None,
     keep_local_model_copy: bool = False,
+    remote_root_only: bool = False,
 ) -> "TrainingResult":
     """Runs Rasa Core and NLU training in `async` loop.

@@ -108,6 +109,8 @@ def train(
            If it is not provided, a new instance will be created.
        keep_local_model_copy: If `True` the model will be stored locally even if
            remote storage is configured.
+       remote_root_only: If `True`, the model will be stored in the root of the
+           remote model storage.

    Returns:
        An instance of `TrainingResult`.
@@ -131,6 +134,7 @@ def train(
            remote_storage=remote_storage,
            file_importer=file_importer,
            keep_local_model_copy=keep_local_model_copy,
+           remote_root_only=remote_root_only,
        )
    )

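In practice the new parameter rides along with the existing remote-storage options. A minimal sketch of a call through the Python API (the project paths and the storage value are hypothetical; `remote_root_only` only matters when `remote_storage` is configured):

import rasa.api

result = rasa.api.train(
    domain="domain.yml",          # hypothetical project files
    config="config.yml",
    training_files=["data/"],
    remote_storage="aws",         # one of RemoteStorageType.list(), or a custom Persistor
    remote_root_only=True,        # upload the model tar to the storage root
)
print(result.code)                # TrainingResult exposes a status code, as used by the CLI below
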
rasa/cli/arguments/default_arguments.py
CHANGED
@@ -172,7 +172,7 @@ def add_remote_storage_param(
 ) -> None:
     parser.add_argument(
         "--remote-storage",
-        help="Remote storage which should be used to store/load the model."
+        help="Remote storage which should be used to store/load the model. "
         f"Supported storages are: {RemoteStorageType.list()}. "
         "You can also provide your own implementation of the `Persistor` interface.",
         required=required,
@@ -180,6 +180,18 @@ def add_remote_storage_param(
     )


+def add_remote_root_only_param(
+    parser: argparse.ArgumentParser, required: bool = False
+) -> None:
+    parser.add_argument(
+        "--remote-root-only",
+        action="store_true",
+        help="If set, models will be stored only at the root directory "
+        "of the remote storage.",
+        required=required,
+    )
+
+
 def parse_remote_storage_arg(value: str) -> StorageType:
     try:
         return parse_remote_storage(value)

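Because the flag is registered with `action="store_true"`, it defaults to `False` and flips to `True` only when passed. A small standalone demonstration of that behavior (a throwaway parser, not Rasa's actual CLI setup):

import argparse

parser = argparse.ArgumentParser(prog="rasa train")
parser.add_argument(
    "--remote-root-only",
    action="store_true",
    help="If set, models will be stored only at the root directory "
    "of the remote storage.",
)

print(parser.parse_args([]))                      # Namespace(remote_root_only=False)
print(parser.parse_args(["--remote-root-only"]))  # Namespace(remote_root_only=True)

argparse maps `--remote-root-only` to the `remote_root_only` attribute, which is why `rasa/cli/train.py` below can read `args.remote_root_only` directly.
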
rasa/cli/arguments/train.py
CHANGED
@@ -8,6 +8,7 @@ from rasa.cli.arguments.default_arguments import (
     add_out_param,
     add_domain_param,
     add_endpoint_param,
+    add_remote_root_only_param,
     add_remote_storage_param,
 )
 from rasa.graph_components.providers.training_tracker_provider import (
@@ -41,6 +42,7 @@ def set_train_arguments(parser: argparse.ArgumentParser) -> None:
         parser, help_text="Configuration file for the connectors as a yml file."
     )
     add_remote_storage_param(parser)
+    add_remote_root_only_param(parser)


 def set_train_core_arguments(parser: argparse.ArgumentParser) -> None:

rasa/cli/train.py
CHANGED
@@ -153,6 +153,7 @@ def run_training(args: argparse.Namespace, can_exit: bool = False) -> Optional[T
         remote_storage=args.remote_storage,
         file_importer=training_data_importer,
         keep_local_model_copy=args.keep_local_model_copy,
+        remote_root_only=args.remote_root_only,
     )
     if training_result.code != 0 and can_exit:
         sys.exit(training_result.code)

rasa/core/nlg/contextual_response_rephraser.py
CHANGED
@@ -4,16 +4,21 @@ import structlog
 from jinja2 import Template
 from rasa import telemetry
 from rasa.core.nlg.response import TemplatedNaturalLanguageGenerator
-from rasa.core.nlg.summarize import summarize_conversation
+from rasa.core.nlg.summarize import (
+    _count_multiple_utterances_as_single_turn,
+    summarize_conversation,
+)
 from rasa.shared.constants import (
     LLM_CONFIG_KEY,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY,
     MODEL_CONFIG_KEY,
+    MODEL_GROUP_ID_CONFIG_KEY,
     MODEL_NAME_CONFIG_KEY,
+    OPENAI_PROVIDER,
     PROMPT_CONFIG_KEY,
     PROVIDER_CONFIG_KEY,
-    OPENAI_PROVIDER,
+    TEMPERATURE_CONFIG_KEY,
     TIMEOUT_CONFIG_KEY,
-    MODEL_GROUP_ID_CONFIG_KEY,
 )
 from rasa.shared.core.domain import KEY_RESPONSES_TEXT, Domain
 from rasa.shared.core.events import BotUttered, UserUttered
@@ -45,12 +50,13 @@ RESPONSE_SUMMARISE_CONVERSATION_KEY = "summarize_conversation"
 DEFAULT_REPHRASE_ALL = False
 DEFAULT_SUMMARIZE_HISTORY = True
 DEFAULT_MAX_HISTORICAL_TURNS = 5
+DEFAULT_COUNT_MULTIPLE_UTTERANCES_AS_SINGLE_TURN = True

 DEFAULT_LLM_CONFIG = {
     PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
     MODEL_CONFIG_KEY: DEFAULT_OPENAI_GENERATE_MODEL_NAME,
-    "temperature": 0.3,
-    "max_tokens": DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
+    TEMPERATURE_CONFIG_KEY: 0.3,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY: DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
     TIMEOUT_CONFIG_KEY: 5,
 }
@@ -62,6 +68,7 @@ its meaning. Use simple english.
 Context / previous conversation with the user:
 {{history}}

+Last user message:
 {{current_input}}

 Suggested AI Response: {{suggested_response}}
@@ -112,6 +119,11 @@ class ContextualResponseRephraser(
             "max_historical_turns", DEFAULT_MAX_HISTORICAL_TURNS
         )

+        self.count_multiple_utterances_as_single_turn = self.nlg_endpoint.kwargs.get(
+            "count_multiple_utterances_as_single_turn",
+            DEFAULT_COUNT_MULTIPLE_UTTERANCES_AS_SINGLE_TURN,
+        )
+
         self.llm_config = resolve_model_client_config(
             self.nlg_endpoint.kwargs.get(LLM_CONFIG_KEY),
             ContextualResponseRephraser.__name__,
@@ -198,8 +210,16 @@ class ContextualResponseRephraser(
         Returns:
             The history for the prompt.
         """
+        # Count multiple utterances by bot/user as single turn in conversation history
+        turns_wrapper = (
+            _count_multiple_utterances_as_single_turn
+            if self.count_multiple_utterances_as_single_turn
+            else None
+        )
         llm = llm_factory(self.llm_config, DEFAULT_LLM_CONFIG)
-        return await summarize_conversation(tracker, llm, max_turns=5)
+        return await summarize_conversation(
+            tracker, llm, max_turns=5, turns_wrapper=turns_wrapper
+        )

     async def rephrase(
         self,
@@ -211,7 +231,6 @@ class ContextualResponseRephraser(
         Args:
             response: The response to rephrase.
             tracker: The tracker to use for the prediction.
-            model_name: The name of the model to use for the prediction.

         Returns:
             The response with the rephrased text.
@@ -221,19 +240,26 @@ class ContextualResponseRephraser(

         prompt_template_text = self._template_for_response_rephrasing(response)

-        #
-
-        current_input =
+        # Last user message (=current input) should always be in prompt if available
+        last_message_by_user = getattr(tracker.latest_message, "text", "")
+        current_input = (
+            f"{USER}: {last_message_by_user}" if last_message_by_user else ""
+        )

         # Only summarise conversation history if flagged
         if self.summarize_history:
             history = await self._create_history(tracker)
         else:
-            #
+            # Count multiple utterances by bot/user as single turn
+            turns_wrapper = (
+                _count_multiple_utterances_as_single_turn
+                if self.count_multiple_utterances_as_single_turn
+                else None
+            )
             max_turns = max(self.max_historical_turns, 1)
-            history = tracker_as_readable_transcript(
-                tracker, max_turns=max_turns
-            )
+            history = tracker_as_readable_transcript(
+                tracker, max_turns=max_turns, turns_wrapper=turns_wrapper
+            )

         prompt = Template(prompt_template_text).render(
             history=history,

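The new `count_multiple_utterances_as_single_turn` switch is read from the NLG endpoint's extra kwargs, alongside the existing `summarize_history` and `max_historical_turns` options. A minimal sketch of wiring it up programmatically (the `type` value, and the assumption that extra endpoints.yml keys land in `EndpointConfig.kwargs`, are inferred from the `.kwargs.get(...)` reads in the diff above):

from rasa.utils.endpoints import EndpointConfig

# Hypothetical NLG endpoint configuration; extra keyword arguments end up in
# EndpointConfig.kwargs, which is where ContextualResponseRephraser reads them.
nlg_endpoint = EndpointConfig(
    url=None,
    type="rephrase",                                 # assumed endpoint type
    summarize_history=False,                         # use the raw transcript
    count_multiple_utterances_as_single_turn=False,  # keep utterances as separate turns
)
print(nlg_endpoint.kwargs["count_multiple_utterances_as_single_turn"])  # False

With both flags left at their defaults, consecutive utterances from the same speaker are merged into one turn before the history is summarized or rendered into the prompt.
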
rasa/core/nlg/summarize.py
CHANGED
@@ -1,4 +1,5 @@
-from typing import Optional
+from itertools import groupby
+from typing import Callable, List, Optional

 import structlog
 from jinja2 import Template
@@ -22,20 +23,47 @@ SUMMARY_PROMPT_TEMPLATE = Template(_DEFAULT_SUMMARIZER_TEMPLATE)
 MAX_TURNS_DEFAULT = 20


+def _count_multiple_utterances_as_single_turn(transcript: List[str]) -> List[str]:
+    """Counts multiple utterances as a single turn.
+
+    Args:
+        transcript: the lines of the transcript
+
+    Returns:
+        transcript: with multiple utterances counted as a single turn
+    """
+    if not transcript:
+        return []
+
+    def get_speaker_label(line: str) -> str:
+        return line.partition(": ")[0] if ": " in line else ""
+
+    modified_transcript = [
+        f"{speaker}: {' '.join(line.partition(': ')[2] for line in group)}"
+        for speaker, group in groupby(transcript, key=get_speaker_label)
+        if speaker
+    ]
+
+    return modified_transcript
+
+
 def _create_summarization_prompt(
-    tracker: DialogueStateTracker, max_turns: Optional[int]
+    tracker: DialogueStateTracker,
+    max_turns: Optional[int],
+    turns_wrapper: Optional[Callable[[List[str]], List[str]]],
 ) -> str:
     """Creates an LLM prompt to summarize the conversation in the tracker.

     Args:
         tracker: tracker of the conversation to be summarized
         max_turns: maximum number of turns to summarize
+        turns_wrapper: optional function to wrap the turns


     Returns:
         The prompt to summarize the conversation.
     """
-    transcript = tracker_as_readable_transcript(tracker, max_turns=max_turns)
+    transcript = tracker_as_readable_transcript(
+        tracker, max_turns=max_turns, turns_wrapper=turns_wrapper
+    )
     return SUMMARY_PROMPT_TEMPLATE.render(
         conversation=transcript,
     )
@@ -45,6 +73,7 @@ async def summarize_conversation(
     tracker: DialogueStateTracker,
     llm: LLMClient,
     max_turns: Optional[int] = MAX_TURNS_DEFAULT,
+    turns_wrapper: Optional[Callable[[List[str]], List[str]]] = None,
 ) -> str:
     """Summarizes the dialogue using the LLM.

@@ -52,11 +81,12 @@ async def summarize_conversation(
         tracker: the tracker to summarize
         llm: the LLM to use for summarization
         max_turns: maximum number of turns to summarize
+        turns_wrapper: optional function to wrap the turns

     Returns:
         The summary of the dialogue.
     """
-    prompt = _create_summarization_prompt(tracker, max_turns)
+    prompt = _create_summarization_prompt(tracker, max_turns, turns_wrapper)
     try:
         llm_response = await llm.acompletion(prompt)
         summarization = llm_response.choices[0].strip()
@@ -65,6 +95,8 @@ async def summarize_conversation(
         )
         return summarization
     except Exception as e:
-        transcript = tracker_as_readable_transcript(tracker, max_turns=max_turns)
+        transcript = tracker_as_readable_transcript(
+            tracker, max_turns=max_turns, turns_wrapper=turns_wrapper
+        )
        structlogger.error("summarization.error", error=e)
        return transcript

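To see what the new helper does, here is a small standalone reproduction of its grouping logic (the same `itertools.groupby` approach; the transcript lines are made up):

from itertools import groupby

transcript = [
    "USER: hi",
    "AI: hello!",
    "AI: how can I help you today?",
    "USER: I want to book a flight",
]

def get_speaker_label(line: str) -> str:
    return line.partition(": ")[0] if ": " in line else ""

merged = [
    f"{speaker}: {' '.join(line.partition(': ')[2] for line in group)}"
    for speaker, group in groupby(transcript, key=get_speaker_label)
    if speaker
]
print(merged)
# ['USER: hi', 'AI: hello! how can I help you today?', 'USER: I want to book a flight']

The two consecutive AI lines collapse into a single line, so a "turn" now means one contiguous block of utterances per speaker rather than one raw utterance.
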
rasa/core/persistor.py
CHANGED
@@ -121,10 +121,12 @@ def get_persistor(storage: StorageType) -> Optional[Persistor]:
 class Persistor(abc.ABC):
     """Store models in cloud and fetch them when needed."""

-    def persist(self, trained_model: str) -> None:
+    def persist(self, trained_model: str, remote_root_only: bool = False) -> None:
         """Uploads a trained model persisted in the `target_dir` to cloud storage."""
         absolute_file_key = self._create_file_key(trained_model)
-        file_key = absolute_file_key
+        file_key = (
+            Path(absolute_file_key).name if remote_root_only else absolute_file_key
+        )
         self._persist_tar(file_key, trained_model)

     def retrieve(self, model_name: Text, target_path: Text) -> Text:
@@ -143,30 +145,32 @@ class Persistor(abc.ABC):
             # ensure backward compatibility
             tar_name = self._tar_name(model_name)
             tar_name = self._create_file_key(tar_name)
-
-        self._retrieve_tar(target_filename)
-        self._copy(os.path.basename(tar_name), target_path)
+        self._retrieve_tar(tar_name, target_path)

         if os.path.isdir(target_path):
             return os.path.join(target_path, model_name)

         return target_path

-    def size_of_persisted_model(self, model_name: Text) -> int:
+    def size_of_persisted_model(
+        self, model_name: Text, target_path: Optional[str] = None
+    ) -> int:
         """Returns the size of the model that has been persisted to cloud storage.

         Args:
             model_name: The name of the model to retrieve.
+            target_path: The path to which the model should be saved.
         """
         tar_name = model_name
         if not model_name.endswith(MODEL_ARCHIVE_EXTENSION):
             # ensure backward compatibility
             tar_name = self._tar_name(model_name)
             tar_name = self._create_file_key(tar_name)
-
-        return self._retrieve_tar_size(target_filename)
+        return self._retrieve_tar_size(tar_name, target_path)

-    def _retrieve_tar_size(self, filename: Text) -> int:
+    def _retrieve_tar_size(
+        self, filename: Text, target_path: Optional[str] = None
+    ) -> int:
         """Returns the size of the model that has been persisted to cloud storage."""
         structlogger.warning(
             "persistor.retrieve_tar_size.not_implemented",
@@ -179,11 +183,11 @@ class Persistor(abc.ABC):
                 "size directly from the cloud storage."
             ),
         )
-        self._retrieve_tar(filename)
+        self._retrieve_tar(filename, target_path)
         return os.path.getsize(os.path.basename(filename))

     @abc.abstractmethod
-    def _retrieve_tar(self, filename: Text) -> None:
+    def _retrieve_tar(self, filename: str, target_path: Optional[str] = None) -> None:
         """Downloads a model previously persisted to cloud storage."""
         raise NotImplementedError

@@ -302,7 +306,9 @@ class AWSPersistor(Persistor):
         with open(tar_path, "rb") as f:
             self.s3.Object(self.bucket_name, file_key).put(Body=f)

-    def _retrieve_tar_size(self, model_path: Text) -> int:
+    def _retrieve_tar_size(
+        self, model_path: Text, target_path: Optional[str] = None
+    ) -> int:
         """Returns the size of the model that has been persisted to s3."""
         try:
             obj = self.s3.Object(self.bucket_name, model_path)
@@ -310,7 +316,9 @@ class AWSPersistor(Persistor):
         except Exception:
             raise ModelNotFound()

-    def _retrieve_tar(self, target_filename: Text) -> None:
+    def _retrieve_tar(
+        self, target_filename: str, target_path: Optional[str] = None
+    ) -> None:
         """Downloads a model that has previously been persisted to s3."""
         from botocore import exceptions

@@ -320,8 +328,14 @@ class AWSPersistor(Persistor):
             f"in the bucket."
         )

+        tar_name = (
+            os.path.join(target_path, os.path.basename(target_filename))
+            if target_path
+            else os.path.basename(target_filename)
+        )
+
         try:
-            with open(os.path.basename(target_filename), "wb") as f:
+            with open(tar_name, "wb") as f:
                 self.bucket.download_fileobj(target_filename, f)

             structlogger.debug(
@@ -425,7 +439,9 @@ class GCSPersistor(Persistor):
         blob = self.bucket.blob(file_key)
         blob.upload_from_filename(tar_path)

-    def _retrieve_tar_size(self, target_filename: Text) -> int:
+    def _retrieve_tar_size(
+        self, target_filename: Text, target_path: Optional[str] = None
+    ) -> int:
         """Returns the size of the model that has been persisted to GCS."""
         try:
             blob = self.bucket.blob(target_filename)
@@ -433,13 +449,22 @@ class GCSPersistor(Persistor):
         except Exception:
             raise ModelNotFound()

-    def _retrieve_tar(self, target_filename: Text) -> None:
+    def _retrieve_tar(
+        self, target_filename: str, target_path: Optional[str] = None
+    ) -> None:
         """Downloads a model that has previously been persisted to GCS."""
         from google.api_core import exceptions

         blob = self.bucket.blob(target_filename)
+
+        destination = (
+            os.path.join(target_path, os.path.basename(target_filename))
+            if target_path
+            else target_filename
+        )
+
         try:
-            blob.download_to_filename(target_filename)
+            blob.download_to_filename(destination)

             structlogger.debug(
                 "gcs_persistor.retrieve_tar.object_found", object_key=target_filename
@@ -500,7 +525,9 @@ class AzurePersistor(Persistor):
         with open(tar_path, "rb") as data:
             self._container_client().upload_blob(name=file_key, data=data)

-    def _retrieve_tar_size(self, target_filename: Text) -> int:
+    def _retrieve_tar_size(
+        self, target_filename: Text, target_path: Optional[str] = None
+    ) -> int:
         """Returns the size of the model that has been persisted to Azure."""
         try:
             blob_client = self._container_client().get_blob_client(target_filename)
@@ -509,12 +536,20 @@ class AzurePersistor(Persistor):
         except Exception:
             raise ModelNotFound()

-    def _retrieve_tar(self, target_filename: Text) -> None:
+    def _retrieve_tar(
+        self, target_filename: Text, target_path: Optional[str] = None
+    ) -> None:
         """Downloads a model that has previously been persisted to Azure."""
         from azure.core.exceptions import AzureError

+        destination = (
+            os.path.join(target_path, os.path.basename(target_filename))
+            if target_path
+            else target_filename
+        )
+
         try:
-            with open(target_filename, "wb") as model_file:
+            with open(destination, "wb") as model_file:
                 blob_client = self._container_client().get_blob_client(target_filename)
                 download_stream = blob_client.download_blob()
                 model_file.write(download_stream.readall())

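All three backends now resolve the download destination the same way: when a `target_path` is given, the archive is written under it; otherwise each backend keeps its previous behavior. A condensed, hypothetical helper capturing the pattern (the real code inlines this per backend):

import os
from typing import Optional

def resolve_download_destination(
    target_filename: str, target_path: Optional[str] = None
) -> str:
    # With a target_path, write <target_path>/<basename>. Without one, the AWS
    # persistor falls back to the basename in the working directory, while the
    # GCS and Azure persistors fall back to the full target_filename.
    if target_path:
        return os.path.join(target_path, os.path.basename(target_filename))
    return os.path.basename(target_filename)  # AWS-style fallback

print(resolve_download_destination("models/20240101-abc.tar.gz", "/tmp"))
# /tmp/20240101-abc.tar.gz
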
rasa/core/policies/enterprise_search_policy.py
CHANGED
@@ -45,13 +45,16 @@ from rasa.graph_components.providers.forms_provider import Forms
 from rasa.graph_components.providers.responses_provider import Responses
 from rasa.shared.constants import (
     EMBEDDINGS_CONFIG_KEY,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY,
+    MAX_RETRIES_CONFIG_KEY,
     MODEL_CONFIG_KEY,
+    MODEL_GROUP_ID_CONFIG_KEY,
+    MODEL_NAME_CONFIG_KEY,
+    OPENAI_PROVIDER,
     PROMPT_CONFIG_KEY,
     PROVIDER_CONFIG_KEY,
-    OPENAI_PROVIDER,
+    TEMPERATURE_CONFIG_KEY,
     TIMEOUT_CONFIG_KEY,
-    MODEL_NAME_CONFIG_KEY,
-    MODEL_GROUP_ID_CONFIG_KEY,
 )
 from rasa.shared.core.constants import (
     ACTION_CANCEL_FLOW,
@@ -121,14 +124,14 @@ DEFAULT_LLM_CONFIG = {
     PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
     MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME,
     TIMEOUT_CONFIG_KEY: 10,
-    "temperature": 0.0,
-    "max_tokens": 256,
-    "max_retries": 1,
+    TEMPERATURE_CONFIG_KEY: 0.0,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY: 256,
+    MAX_RETRIES_CONFIG_KEY: 1,
 }

 DEFAULT_EMBEDDINGS_CONFIG = {
     PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
-    "model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
+    MODEL_CONFIG_KEY: DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
 }

 ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"

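The pattern here, repeated across several files in this release, replaces literal config keys with shared constants. Since the constants are plain strings, the rendered config is unchanged; assuming the constants keep their lowercase spellings (an assumption, the real values live in rasa.shared.constants), the equivalence looks like:

# Assumed constant values, for illustration only.
TEMPERATURE_CONFIG_KEY = "temperature"
MAX_RETRIES_CONFIG_KEY = "max_retries"

DEFAULT_LLM_CONFIG = {
    TEMPERATURE_CONFIG_KEY: 0.0,
    MAX_RETRIES_CONFIG_KEY: 1,
}
assert DEFAULT_LLM_CONFIG == {"temperature": 0.0, "max_retries": 1}

The payoff is that every component spells the keys identically, so overrides merge predictably instead of silently duplicating a key under two spellings.
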
rasa/core/policies/intentless_policy.py
CHANGED
@@ -31,13 +31,15 @@ from rasa.graph_components.providers.responses_provider import Responses
 from rasa.shared.constants import (
     EMBEDDINGS_CONFIG_KEY,
     LLM_CONFIG_KEY,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY,
     MODEL_CONFIG_KEY,
+    MODEL_GROUP_ID_CONFIG_KEY,
     MODEL_NAME_CONFIG_KEY,
+    OPENAI_PROVIDER,
     PROMPT_CONFIG_KEY,
     PROVIDER_CONFIG_KEY,
-    OPENAI_PROVIDER,
+    TEMPERATURE_CONFIG_KEY,
     TIMEOUT_CONFIG_KEY,
-    MODEL_GROUP_ID_CONFIG_KEY,
 )
 from rasa.shared.core.constants import ACTION_LISTEN_NAME
 from rasa.shared.core.domain import KEY_RESPONSES_TEXT, Domain
@@ -110,14 +112,14 @@ NLU_ABSTENTION_THRESHOLD = "nlu_abstention_threshold"
 DEFAULT_LLM_CONFIG = {
     PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
     MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME,
-    "temperature": 0.0,
-    "max_tokens": DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
+    TEMPERATURE_CONFIG_KEY: 0.0,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY: DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
     TIMEOUT_CONFIG_KEY: 5,
 }

 DEFAULT_EMBEDDINGS_CONFIG = {
     PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
-    "model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
+    MODEL_CONFIG_KEY: DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
 }

 DEFAULT_INTENTLESS_PROMPT_TEMPLATE = importlib.resources.open_text(
@@ -343,8 +345,6 @@ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
         # ensures that the policy will not override a deterministic policy
         # which utilizes the nlu predictions confidence (e.g. Memoization).
         NLU_ABSTENTION_THRESHOLD: 0.9,
-        LLM_CONFIG_KEY: DEFAULT_LLM_CONFIG,
-        EMBEDDINGS_CONFIG_KEY: DEFAULT_EMBEDDINGS_CONFIG,
         PROMPT_CONFIG_KEY: DEFAULT_INTENTLESS_PROMPT_TEMPLATE,
     }

@@ -380,13 +380,19 @@ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
         super().__init__(config, model_storage, resource, execution_context, featurizer)

         # Resolve LLM config
-        self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
-            self.config.get(LLM_CONFIG_KEY), IntentlessPolicy.__name__
+        self.config[LLM_CONFIG_KEY] = combine_custom_and_default_config(
+            resolve_model_client_config(
+                self.config.get(LLM_CONFIG_KEY), IntentlessPolicy.__name__
+            ),
+            DEFAULT_LLM_CONFIG,
         )

         # Resolve embeddings config
-        self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
-            self.config.get(EMBEDDINGS_CONFIG_KEY), IntentlessPolicy.__name__
+        self.config[EMBEDDINGS_CONFIG_KEY] = combine_custom_and_default_config(
+            resolve_model_client_config(
+                self.config.get(EMBEDDINGS_CONFIG_KEY), IntentlessPolicy.__name__
+            ),
+            DEFAULT_EMBEDDINGS_CONFIG,
        )

         self.nlu_abstention_threshold: float = self.config[NLU_ABSTENTION_THRESHOLD]

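`LLM_CONFIG_KEY` and `EMBEDDINGS_CONFIG_KEY` are dropped from the static defaults and instead merged in at construction time. A minimal sketch of the merge semantics this implies (assuming `combine_custom_and_default_config` is a defaults-plus-overrides merge; the function itself is not shown in this diff):

from typing import Any, Dict, Optional

def combine_custom_and_default_config_sketch(
    custom: Optional[Dict[str, Any]], default: Dict[str, Any]
) -> Dict[str, Any]:
    # Start from the defaults, then let any user-supplied keys win.
    merged = dict(default)
    merged.update(custom or {})
    return merged

default = {"provider": "openai", "temperature": 0.0}
custom = {"temperature": 0.7}
print(combine_custom_and_default_config_sketch(custom, default))
# {'provider': 'openai', 'temperature': 0.7}
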
rasa/core/run.py
CHANGED
@@ -86,13 +86,15 @@ def _create_single_channel(channel: Text, credentials: Dict[Text, Any]) -> Any:
     )


-def _create_app_without_api(cors: Optional[Union[Text, List[Text]]] = None) -> Sanic:
+def _create_app_without_api(
+    cors: Optional[Union[Text, List[Text]]] = None, is_inspector_enabled: bool = False
+) -> Sanic:
     app = Sanic("rasa_core_no_api", configure_logging=False)

     # Reset Sanic warnings filter that allows the triggering of Sanic warnings
     warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"sanic.*")

-    server.add_root_route(app)
+    server.add_root_route(app, is_inspector_enabled)
     server.configure_cors(app, cors)
     return app

@@ -127,6 +129,7 @@ def configure_app(
     server_listeners: Optional[List[Tuple[Callable, Text]]] = None,
     use_uvloop: Optional[bool] = True,
     keep_alive_timeout: int = constants.DEFAULT_KEEP_ALIVE_TIMEOUT,
+    is_inspector_enabled: bool = False,
 ) -> Sanic:
     """Run the agent."""
     rasa.core.utils.configure_file_logging(
@@ -144,6 +147,7 @@ def configure_app(
                jwt_private_key=jwt_private_key,
                jwt_method=jwt_method,
                endpoints=endpoints,
+               is_inspector_enabled=is_inspector_enabled,
            )
        )
    else:
@@ -259,6 +263,7 @@ def serve_application(
        syslog_protocol=syslog_protocol,
        request_timeout=request_timeout,
        server_listeners=server_listeners,
+       is_inspector_enabled=inspect,
    )

    ssl_context = server.create_ssl_context(

rasa/dialogue_understanding/coexistence/llm_based_router.py
CHANGED
@@ -23,11 +23,14 @@ from rasa.engine.recipes.default_recipe import DefaultV1Recipe
 from rasa.engine.storage.resource import Resource
 from rasa.engine.storage.storage import ModelStorage
 from rasa.shared.constants import (
-    ROUTE_TO_CALM_SLOT,
-    PROMPT_CONFIG_KEY,
-    PROVIDER_CONFIG_KEY,
+    LOGIT_BIAS_CONFIG_KEY,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY,
     MODEL_CONFIG_KEY,
     OPENAI_PROVIDER,
+    PROMPT_CONFIG_KEY,
+    PROVIDER_CONFIG_KEY,
+    ROUTE_TO_CALM_SLOT,
+    TEMPERATURE_CONFIG_KEY,
     TIMEOUT_CONFIG_KEY,
 )
 from rasa.shared.core.trackers import DialogueStateTracker
@@ -62,9 +65,11 @@ DEFAULT_LLM_CONFIG = {
     PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
     MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME,
     TIMEOUT_CONFIG_KEY: 7,
-    "temperature": 0.0,
-    "max_tokens": 1,
-    "logit_bias": {str(token_id): 100 for token_id in A_TO_C_TOKEN_IDS_CHATGPT},
+    TEMPERATURE_CONFIG_KEY: 0.0,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY: 1,
+    LOGIT_BIAS_CONFIG_KEY: {
+        str(token_id): 100 for token_id in A_TO_C_TOKEN_IDS_CHATGPT
+    },
 }

 structlogger = structlog.get_logger()

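The default config here is worth unpacking: the router caps the completion at a single token and uses `logit_bias` to boost the token ids behind the candidate answer letters (the constant name suggests the letters A through C), so the one emitted token is effectively forced to be a routing letter. A standalone sketch of the request shape, with hypothetical token ids standing in for `A_TO_C_TOKEN_IDS_CHATGPT`:

# Hypothetical ids for the tokens "A", "B", "C" in the model's vocabulary;
# the real ids live in A_TO_C_TOKEN_IDS_CHATGPT.
A_TO_C_TOKEN_IDS = [32, 33, 34]

routing_request = {
    "temperature": 0.0,  # deterministic pick
    "max_tokens": 1,     # the answer is exactly one token: the routing letter
    # a +100 bias makes the biased tokens overwhelmingly likely to be sampled
    "logit_bias": {str(token_id): 100 for token_id in A_TO_C_TOKEN_IDS},
}
print(routing_request["logit_bias"])  # {'32': 100, '33': 100, '34': 100}
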
rasa/dialogue_understanding/generator/constants.py
CHANGED
@@ -1,8 +1,10 @@
 from rasa.shared.constants import (
-    PROVIDER_CONFIG_KEY,
-    OPENAI_PROVIDER,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY,
     MODEL_CONFIG_KEY,
+    OPENAI_PROVIDER,
+    PROVIDER_CONFIG_KEY,
     TIMEOUT_CONFIG_KEY,
+    TEMPERATURE_CONFIG_KEY,
 )
 from rasa.shared.utils.llm import (
     DEFAULT_OPENAI_CHAT_MODEL_NAME_ADVANCED,
@@ -12,8 +14,8 @@ from rasa.shared.utils.llm import (
 DEFAULT_LLM_CONFIG = {
     PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
     MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME_ADVANCED,
-    "temperature": 0.0,
-    "max_tokens": DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
+    TEMPERATURE_CONFIG_KEY: 0.0,
+    MAX_COMPLETION_TOKENS_CONFIG_KEY: DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
     TIMEOUT_CONFIG_KEY: 7,
 }