rasa-pro 3.11.0a4.dev3__py3-none-any.whl → 3.11.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/__main__.py +22 -12
- rasa/api.py +1 -1
- rasa/cli/arguments/default_arguments.py +1 -2
- rasa/cli/arguments/shell.py +5 -1
- rasa/cli/e2e_test.py +1 -1
- rasa/cli/evaluate.py +8 -8
- rasa/cli/inspect.py +4 -4
- rasa/cli/llm_fine_tuning.py +1 -1
- rasa/cli/project_templates/calm/config.yml +5 -7
- rasa/cli/project_templates/calm/endpoints.yml +8 -0
- rasa/cli/project_templates/tutorial/config.yml +8 -5
- rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
- rasa/cli/project_templates/tutorial/domain.yml +14 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +7 -7
- rasa/cli/run.py +1 -1
- rasa/cli/scaffold.py +4 -2
- rasa/cli/utils.py +5 -0
- rasa/cli/x.py +8 -8
- rasa/constants.py +1 -1
- rasa/core/channels/channel.py +3 -0
- rasa/core/channels/inspector/dist/assets/{arc-6852c607.js → arc-bc141fb2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-acc952b2.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-848a7597.js → classDiagram-936ed81e-55366915.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-a73d3e68.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-e5ee049d.js → createText-62fc7601-b0ec81d6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-771e517e.js → edges-f2ad444c-6166330c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-aa347178.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-651fc57d.js → flowDb-1972c806-fca3bfe4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-ca67804f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-2dbc568d.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-25a65bd8.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-fdc7378d.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-6f1fd606.js → index-2c4b9a3b-f55afcdf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-efdd30c1.js → index-e7cef9de.js} +68 -68
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-cb1a041a.js → infoDiagram-736b4530-124d4a14.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-14609879.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-2490f52b.js → layout-b9885fb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-40186f1f.js → line-7c59abb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-08814e93.js → linear-4776f780.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-1a534584.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-72397b61.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-3bb0b6a3.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-57334f61.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-111e1297.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-10bcfe62.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-acaf7513.js → stateDiagram-59f0c015-042b3137.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-3ec2a235.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-62730289.js → styles-080da4f6-23ffa4fc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-5284ee76.js → styles-3dcbcfbf-94f59763.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-642435e3.js → styles-9c745c82-78a6bebc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-b250a350.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-c2b147ed.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-f92cfea9.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/App.tsx +1 -1
- rasa/core/channels/inspector/src/helpers/audiostream.ts +77 -16
- rasa/core/channels/socketio.py +2 -1
- rasa/core/channels/telegram.py +1 -1
- rasa/core/channels/twilio.py +1 -1
- rasa/core/channels/voice_ready/jambonz.py +2 -2
- rasa/core/channels/voice_stream/asr/asr_event.py +5 -0
- rasa/core/channels/voice_stream/asr/azure.py +122 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +16 -6
- rasa/core/channels/voice_stream/audio_bytes.py +1 -0
- rasa/core/channels/voice_stream/browser_audio.py +31 -8
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/azure.py +6 -2
- rasa/core/channels/voice_stream/tts/cartesia.py +10 -6
- rasa/core/channels/voice_stream/tts/tts_engine.py +1 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +27 -18
- rasa/core/channels/voice_stream/util.py +4 -4
- rasa/core/channels/voice_stream/voice_channel.py +177 -39
- rasa/core/featurizers/single_state_featurizer.py +22 -1
- rasa/core/featurizers/tracker_featurizers.py +115 -18
- rasa/core/nlg/contextual_response_rephraser.py +16 -22
- rasa/core/persistor.py +86 -39
- rasa/core/policies/enterprise_search_policy.py +159 -60
- rasa/core/policies/flows/flow_executor.py +7 -4
- rasa/core/policies/intentless_policy.py +120 -22
- rasa/core/policies/ted_policy.py +58 -33
- rasa/core/policies/unexpected_intent_policy.py +15 -7
- rasa/core/processor.py +25 -0
- rasa/core/training/interactive.py +34 -35
- rasa/core/utils.py +8 -3
- rasa/dialogue_understanding/coexistence/llm_based_router.py +58 -16
- rasa/dialogue_understanding/commands/change_flow_command.py +6 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +5 -0
- rasa/dialogue_understanding/generator/constants.py +4 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +65 -3
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +68 -26
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +57 -8
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +64 -7
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +39 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/e2e_test/e2e_test_runner.py +4 -2
- rasa/e2e_test/utils/io.py +1 -1
- rasa/engine/validation.py +297 -7
- rasa/model_manager/config.py +15 -3
- rasa/model_manager/model_api.py +15 -7
- rasa/model_manager/runner_service.py +8 -6
- rasa/model_manager/socket_bridge.py +6 -3
- rasa/model_manager/trainer_service.py +7 -5
- rasa/model_manager/utils.py +28 -7
- rasa/model_service.py +6 -2
- rasa/model_training.py +2 -0
- rasa/nlu/classifiers/diet_classifier.py +38 -25
- rasa/nlu/classifiers/logistic_regression_classifier.py +22 -9
- rasa/nlu/classifiers/sklearn_intent_classifier.py +37 -16
- rasa/nlu/extractors/crf_entity_extractor.py +93 -50
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -16
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +52 -17
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +5 -3
- rasa/shared/constants.py +36 -3
- rasa/shared/core/constants.py +7 -0
- rasa/shared/core/domain.py +26 -0
- rasa/shared/core/flows/flow.py +5 -0
- rasa/shared/core/flows/flows_yaml_schema.json +10 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +96 -0
- rasa/shared/core/slots.py +5 -0
- rasa/shared/nlu/training_data/features.py +120 -2
- rasa/shared/providers/_configs/azure_openai_client_config.py +5 -3
- rasa/shared/providers/_configs/litellm_router_client_config.py +200 -0
- rasa/shared/providers/_configs/model_group_config.py +167 -0
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
- rasa/shared/providers/_configs/utils.py +16 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +12 -15
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
- rasa/shared/providers/llm/_base_litellm_client.py +31 -30
- rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
- rasa/shared/providers/llm/litellm_router_llm_client.py +127 -0
- rasa/shared/providers/llm/rasa_llm_client.py +112 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +1 -1
- rasa/shared/providers/mappings.py +19 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +149 -0
- rasa/shared/providers/router/router_client.py +73 -0
- rasa/shared/utils/common.py +8 -0
- rasa/shared/utils/health_check.py +533 -0
- rasa/shared/utils/io.py +28 -6
- rasa/shared/utils/llm.py +350 -46
- rasa/shared/utils/yaml.py +11 -13
- rasa/studio/upload.py +64 -20
- rasa/telemetry.py +80 -17
- rasa/tracing/instrumentation/attribute_extractors.py +74 -17
- rasa/utils/io.py +0 -66
- rasa/utils/log_utils.py +9 -2
- rasa/utils/tensorflow/feature_array.py +366 -0
- rasa/utils/tensorflow/model_data.py +2 -193
- rasa/validator.py +70 -0
- rasa/version.py +1 -1
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/METADATA +10 -10
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/RECORD +162 -146
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-587d82d8.js +0 -1
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc1.dist-info}/entry_points.txt +0 -0
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import importlib
|
|
4
|
-
import os
|
|
5
4
|
from typing import Any, Dict, List, Optional
|
|
6
5
|
|
|
7
6
|
import structlog
|
|
@@ -16,13 +15,15 @@ from rasa.dialogue_understanding.coexistence.constants import (
|
|
|
16
15
|
)
|
|
17
16
|
from rasa.dialogue_understanding.commands import Command, SetSlotCommand
|
|
18
17
|
from rasa.dialogue_understanding.commands.noop_command import NoopCommand
|
|
19
|
-
from rasa.dialogue_understanding.generator.constants import
|
|
18
|
+
from rasa.dialogue_understanding.generator.constants import (
|
|
19
|
+
LLM_CONFIG_KEY,
|
|
20
|
+
TRAINED_MODEL_NAME_CONFIG_KEY,
|
|
21
|
+
)
|
|
20
22
|
from rasa.engine.graph import ExecutionContext, GraphComponent
|
|
21
23
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
22
24
|
from rasa.engine.storage.resource import Resource
|
|
23
25
|
from rasa.engine.storage.storage import ModelStorage
|
|
24
26
|
from rasa.shared.constants import (
|
|
25
|
-
LLM_API_HEALTH_CHECK_ENV_VAR,
|
|
26
27
|
ROUTE_TO_CALM_SLOT,
|
|
27
28
|
PROMPT_CONFIG_KEY,
|
|
28
29
|
PROVIDER_CONFIG_KEY,
|
|
@@ -35,12 +36,16 @@ from rasa.shared.exceptions import InvalidConfigException, FileIOException
|
|
|
35
36
|
from rasa.shared.nlu.constants import COMMANDS, TEXT
|
|
36
37
|
from rasa.shared.nlu.training_data.message import Message
|
|
37
38
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
39
|
+
from rasa.shared.utils.io import deep_container_fingerprint
|
|
38
40
|
from rasa.shared.utils.llm import (
|
|
39
41
|
DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
40
42
|
get_prompt_template,
|
|
41
|
-
llm_api_health_check,
|
|
42
43
|
llm_factory,
|
|
43
|
-
|
|
44
|
+
resolve_model_client_config,
|
|
45
|
+
)
|
|
46
|
+
from rasa.shared.utils.health_check import (
|
|
47
|
+
perform_training_time_llm_health_check,
|
|
48
|
+
perform_inference_time_llm_health_check,
|
|
44
49
|
)
|
|
45
50
|
from rasa.utils.log_utils import log_llm
|
|
46
51
|
|
|
@@ -48,6 +53,7 @@ LLM_BASED_ROUTER_PROMPT_FILE_NAME = "llm_based_router_prompt.jinja2"
|
|
|
48
53
|
DEFAULT_COMMAND_PROMPT_TEMPLATE = importlib.resources.read_text(
|
|
49
54
|
"rasa.dialogue_understanding.coexistence", "router_template.jinja2"
|
|
50
55
|
)
|
|
56
|
+
LLM_BASED_ROUTER_CONFIG_FILE_NAME = "config.json"
|
|
51
57
|
|
|
52
58
|
# Token ids for gpt 3.5 and gpt 4 corresponding to space + capitalized Letter
|
|
53
59
|
A_TO_C_TOKEN_IDS_CHATGPT = [
|
|
@@ -96,6 +102,9 @@ class LLMBasedRouter(GraphComponent):
|
|
|
96
102
|
prompt_template: Optional[str] = None,
|
|
97
103
|
) -> None:
|
|
98
104
|
self.config = {**self.get_default_config(), **config}
|
|
105
|
+
self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
|
|
106
|
+
self.config.get(LLM_CONFIG_KEY), LLMBasedRouter.__name__
|
|
107
|
+
)
|
|
99
108
|
|
|
100
109
|
self.prompt_template = (
|
|
101
110
|
prompt_template
|
|
@@ -129,20 +138,20 @@ class LLMBasedRouter(GraphComponent):
|
|
|
129
138
|
rasa.shared.utils.io.write_text_file(
|
|
130
139
|
self.prompt_template, path / LLM_BASED_ROUTER_PROMPT_FILE_NAME
|
|
131
140
|
)
|
|
141
|
+
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
142
|
+
path / LLM_BASED_ROUTER_CONFIG_FILE_NAME, self.config
|
|
143
|
+
)
|
|
132
144
|
|
|
133
145
|
def train(self, training_data: TrainingData) -> Resource:
|
|
134
146
|
"""Train the intent classifier on a data set."""
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
)
|
|
142
|
-
if os.getenv(LLM_API_HEALTH_CHECK_ENV_VAR, "true").lower() == "true":
|
|
143
|
-
llm_api_health_check(
|
|
144
|
-
llm_client, "llm_based_router.train", LLMBasedRouter.__name__
|
|
147
|
+
self.config[TRAINED_MODEL_NAME_CONFIG_KEY] = (
|
|
148
|
+
perform_training_time_llm_health_check(
|
|
149
|
+
self.config.get(LLM_CONFIG_KEY),
|
|
150
|
+
DEFAULT_LLM_CONFIG,
|
|
151
|
+
"llm_based_router.train",
|
|
152
|
+
LLMBasedRouter.__name__,
|
|
145
153
|
)
|
|
154
|
+
)
|
|
146
155
|
|
|
147
156
|
self.persist()
|
|
148
157
|
return self._resource
|
|
@@ -158,17 +167,36 @@ class LLMBasedRouter(GraphComponent):
|
|
|
158
167
|
) -> "LLMBasedRouter":
|
|
159
168
|
"""Loads trained component (see parent class for full docstring)."""
|
|
160
169
|
prompt_template = None
|
|
170
|
+
persisted_config = None
|
|
161
171
|
try:
|
|
162
172
|
with model_storage.read_from(resource) as path:
|
|
163
173
|
prompt_template = rasa.shared.utils.io.read_file(
|
|
164
174
|
path / LLM_BASED_ROUTER_PROMPT_FILE_NAME
|
|
165
175
|
)
|
|
176
|
+
persisted_config = rasa.shared.utils.io.read_json_file(
|
|
177
|
+
path / LLM_BASED_ROUTER_CONFIG_FILE_NAME
|
|
178
|
+
)
|
|
166
179
|
except (FileNotFoundError, FileIOException) as e:
|
|
167
180
|
structlogger.warning(
|
|
168
181
|
"llm_based_router.load.failed", error=e, resource=resource.name
|
|
169
182
|
)
|
|
170
183
|
|
|
171
|
-
|
|
184
|
+
router = cls(config, model_storage, resource, prompt_template=prompt_template)
|
|
185
|
+
|
|
186
|
+
train_model_name = (
|
|
187
|
+
persisted_config.get(TRAINED_MODEL_NAME_CONFIG_KEY, None)
|
|
188
|
+
if persisted_config
|
|
189
|
+
else None
|
|
190
|
+
)
|
|
191
|
+
perform_inference_time_llm_health_check(
|
|
192
|
+
router.config.get(LLM_CONFIG_KEY),
|
|
193
|
+
DEFAULT_LLM_CONFIG,
|
|
194
|
+
train_model_name,
|
|
195
|
+
"llm_based_router.load",
|
|
196
|
+
LLMBasedRouter.__name__,
|
|
197
|
+
)
|
|
198
|
+
|
|
199
|
+
return router
|
|
172
200
|
|
|
173
201
|
@classmethod
|
|
174
202
|
def create(
|
|
@@ -298,3 +326,17 @@ class LLMBasedRouter(GraphComponent):
|
|
|
298
326
|
# we have to catch all exceptions here
|
|
299
327
|
structlogger.error("llm_based_router.llm.error", error=e)
|
|
300
328
|
return None
|
|
329
|
+
|
|
330
|
+
@classmethod
|
|
331
|
+
def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
|
|
332
|
+
"""Add a fingerprint of llm based router for the graph."""
|
|
333
|
+
prompt_template = get_prompt_template(
|
|
334
|
+
config.get(PROMPT_CONFIG_KEY),
|
|
335
|
+
DEFAULT_COMMAND_PROMPT_TEMPLATE,
|
|
336
|
+
)
|
|
337
|
+
|
|
338
|
+
llm_config = resolve_model_client_config(
|
|
339
|
+
config.get(LLM_CONFIG_KEY), LLMBasedRouter.__name__
|
|
340
|
+
)
|
|
341
|
+
|
|
342
|
+
return deep_container_fingerprint([prompt_template, llm_config])
|
|
@@ -36,3 +36,9 @@ class ChangeFlowCommand(Command):
|
|
|
36
36
|
# the change flow command is not actually pushing anything to the tracker,
|
|
37
37
|
# but it is predicted by the MultiStepLLMCommandGenerator and used internally
|
|
38
38
|
return []
|
|
39
|
+
|
|
40
|
+
def __eq__(self, other: Any) -> bool:
|
|
41
|
+
return isinstance(other, ChangeFlowCommand)
|
|
42
|
+
|
|
43
|
+
def __hash__(self) -> int:
|
|
44
|
+
return hash(self.command())
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Dict, List
|
|
5
|
+
from rasa.dialogue_understanding.commands import Command
|
|
6
|
+
from rasa.dialogue_understanding.patterns.user_silence import (
|
|
7
|
+
UserSilencePatternFlowStackFrame,
|
|
8
|
+
)
|
|
9
|
+
from rasa.shared.core.events import Event
|
|
10
|
+
from rasa.shared.core.flows import FlowsList
|
|
11
|
+
from rasa.shared.core.trackers import DialogueStateTracker
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
|
|
15
|
+
class UserSilenceCommand(Command):
|
|
16
|
+
"""A command to indicate user silence."""
|
|
17
|
+
|
|
18
|
+
@classmethod
|
|
19
|
+
def command(cls) -> str:
|
|
20
|
+
"""Returns the command type."""
|
|
21
|
+
return "user silence"
|
|
22
|
+
|
|
23
|
+
@classmethod
|
|
24
|
+
def from_dict(cls, data: Dict[str, Any]) -> UserSilenceCommand:
|
|
25
|
+
"""Converts the dictionary to a command.
|
|
26
|
+
|
|
27
|
+
Returns:
|
|
28
|
+
The converted dictionary.
|
|
29
|
+
"""
|
|
30
|
+
return UserSilenceCommand()
|
|
31
|
+
|
|
32
|
+
def run_command_on_tracker(
|
|
33
|
+
self,
|
|
34
|
+
tracker: DialogueStateTracker,
|
|
35
|
+
all_flows: FlowsList,
|
|
36
|
+
original_tracker: DialogueStateTracker,
|
|
37
|
+
) -> List[Event]:
|
|
38
|
+
"""Runs the command on the tracker.
|
|
39
|
+
|
|
40
|
+
Args:
|
|
41
|
+
tracker: The tracker to run the command on.
|
|
42
|
+
all_flows: All flows in the assistant.
|
|
43
|
+
original_tracker: The tracker before any command was executed.
|
|
44
|
+
|
|
45
|
+
Returns:
|
|
46
|
+
The events to apply to the tracker.
|
|
47
|
+
"""
|
|
48
|
+
stack = tracker.stack
|
|
49
|
+
stack.push(UserSilencePatternFlowStackFrame())
|
|
50
|
+
return tracker.create_stack_updated_events(stack)
|
|
51
|
+
|
|
52
|
+
def __hash__(self) -> int:
|
|
53
|
+
return hash(self.command())
|
|
54
|
+
|
|
55
|
+
def __eq__(self, other: object) -> bool:
|
|
56
|
+
if not isinstance(other, UserSilenceCommand):
|
|
57
|
+
return False
|
|
58
|
+
|
|
59
|
+
return True
|
|
@@ -11,6 +11,7 @@ from rasa.dialogue_understanding.commands import (
|
|
|
11
11
|
SkipQuestionCommand,
|
|
12
12
|
RestartCommand,
|
|
13
13
|
)
|
|
14
|
+
from rasa.dialogue_understanding.commands.user_silence_command import UserSilenceCommand
|
|
14
15
|
from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
|
|
15
16
|
from rasa.dialogue_understanding.patterns.cannot_handle import (
|
|
16
17
|
CannotHandlePatternFlowStackFrame,
|
|
@@ -27,9 +28,13 @@ from rasa.dialogue_understanding.patterns.session_start import (
|
|
|
27
28
|
from rasa.dialogue_understanding.patterns.skip_question import (
|
|
28
29
|
SkipQuestionPatternFlowStackFrame,
|
|
29
30
|
)
|
|
31
|
+
from rasa.dialogue_understanding.patterns.user_silence import (
|
|
32
|
+
UserSilencePatternFlowStackFrame,
|
|
33
|
+
)
|
|
30
34
|
|
|
31
35
|
triggerable_pattern_to_command_class: Dict[str, Type[Command]] = {
|
|
32
36
|
SessionStartPatternFlowStackFrame.flow_id: SessionStartCommand,
|
|
37
|
+
UserSilencePatternFlowStackFrame.flow_id: UserSilenceCommand,
|
|
33
38
|
CancelPatternFlowStackFrame.flow_id: CancelFlowCommand,
|
|
34
39
|
ChitchatPatternFlowStackFrame.flow_id: ChitChatAnswerCommand,
|
|
35
40
|
HumanHandoffPatternFlowStackFrame.flow_id: HumanHandoffCommand,
|
|
@@ -18,8 +18,12 @@ DEFAULT_LLM_CONFIG = {
|
|
|
18
18
|
}
|
|
19
19
|
|
|
20
20
|
LLM_CONFIG_KEY = "llm"
|
|
21
|
+
TRAINED_MODEL_NAME_CONFIG_KEY = "trained_llm_model_name"
|
|
22
|
+
TRAINED_EMBEDDINGS_CONFIG_KEY = "trained_embeddings_model_name"
|
|
21
23
|
USER_INPUT_CONFIG_KEY = "user_input"
|
|
22
24
|
|
|
23
25
|
FLOW_RETRIEVAL_KEY = "flow_retrieval"
|
|
24
26
|
FLOW_RETRIEVAL_ACTIVE_KEY = "active"
|
|
25
27
|
FLOW_RETRIEVAL_EMBEDDINGS_CONFIG_KEY = "embeddings"
|
|
28
|
+
|
|
29
|
+
FLOW_RETRIEVAL_FLOW_THRESHOLD = 20
|
|
@@ -1,5 +1,4 @@
|
|
|
1
|
-
"""
|
|
2
|
-
The module is primarily centered around the `FlowRetrieval` class which handles the
|
|
1
|
+
"""The module is primarily centered around the `FlowRetrieval` class which handles the
|
|
3
2
|
initialization, configuration validation, vector store management, and flow retrieval
|
|
4
3
|
logic. It integrates components for managing embeddings, vector stores, and
|
|
5
4
|
flow-specific templates, facilitating semantic search functionalities.
|
|
@@ -27,8 +26,13 @@ from langchain.docstore.document import Document
|
|
|
27
26
|
from langchain.schema.embeddings import Embeddings
|
|
28
27
|
from langchain_community.vectorstores.faiss import FAISS
|
|
29
28
|
from langchain_community.vectorstores.utils import DistanceStrategy
|
|
29
|
+
|
|
30
|
+
from rasa.dialogue_understanding.generator.constants import (
|
|
31
|
+
TRAINED_EMBEDDINGS_CONFIG_KEY,
|
|
32
|
+
)
|
|
30
33
|
from rasa.engine.storage.resource import Resource
|
|
31
34
|
from rasa.engine.storage.storage import ModelStorage
|
|
35
|
+
|
|
32
36
|
from rasa.shared.constants import (
|
|
33
37
|
EMBEDDINGS_CONFIG_KEY,
|
|
34
38
|
PROVIDER_CONFIG_KEY,
|
|
@@ -37,9 +41,9 @@ from rasa.shared.constants import (
|
|
|
37
41
|
from rasa.shared.core.domain import Domain
|
|
38
42
|
from rasa.shared.core.flows import FlowsList
|
|
39
43
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
44
|
+
from rasa.shared.exceptions import ProviderClientAPIException, FileIOException
|
|
40
45
|
from rasa.shared.nlu.constants import TEXT, FLOWS_FROM_SEMANTIC_SEARCH
|
|
41
46
|
from rasa.shared.nlu.training_data.message import Message
|
|
42
|
-
from rasa.shared.exceptions import ProviderClientAPIException
|
|
43
47
|
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
|
|
44
48
|
_LangchainEmbeddingClientAdapter,
|
|
45
49
|
)
|
|
@@ -50,12 +54,20 @@ from rasa.shared.utils.llm import (
|
|
|
50
54
|
USER,
|
|
51
55
|
get_prompt_template,
|
|
52
56
|
allowed_values_for_slot,
|
|
57
|
+
resolve_model_client_config,
|
|
58
|
+
)
|
|
59
|
+
from rasa.shared.utils.health_check import (
|
|
60
|
+
perform_training_time_embeddings_health_check,
|
|
61
|
+
perform_inference_time_embeddings_health_check,
|
|
53
62
|
)
|
|
63
|
+
from rasa.shared.utils.io import dump_obj_as_json_to_file, read_json_file
|
|
54
64
|
|
|
55
65
|
DEFAULT_FLOW_DOCUMENT_TEMPLATE = importlib.resources.read_text(
|
|
56
66
|
"rasa.dialogue_understanding.generator", "flow_document_template.jinja2"
|
|
57
67
|
)
|
|
58
68
|
|
|
69
|
+
FLOW_RETRIEVAL_CONFIG_FILE_NAME = "flow_retrieval_config.json"
|
|
70
|
+
|
|
59
71
|
DEFAULT_EMBEDDINGS_CONFIG = {
|
|
60
72
|
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
61
73
|
"model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
@@ -82,6 +94,7 @@ class FlowRetrieval:
|
|
|
82
94
|
MAX_FLOWS_FROM_SEMANTIC_SEARCH_KEY: DEFAULT_MAX_FLOWS_FROM_SEMANTIC_SEARCH,
|
|
83
95
|
TURNS_TO_EMBED_KEY: DEFAULT_TURNS_TO_EMBED,
|
|
84
96
|
SHOULD_EMBED_SLOTS_KEY: DEFAULT_SHOULD_EMBED_SLOTS,
|
|
97
|
+
TRAINED_EMBEDDINGS_CONFIG_KEY: None,
|
|
85
98
|
}
|
|
86
99
|
|
|
87
100
|
def __init__(
|
|
@@ -92,6 +105,9 @@ class FlowRetrieval:
|
|
|
92
105
|
):
|
|
93
106
|
config = {**self.get_default_config(), **config}
|
|
94
107
|
self.config = self.validate_config(config)
|
|
108
|
+
self.config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
109
|
+
self.config.get(EMBEDDINGS_CONFIG_KEY), FlowRetrieval.__name__
|
|
110
|
+
)
|
|
95
111
|
self.vector_store: Optional[FAISS] = None
|
|
96
112
|
self.flow_document_template = get_prompt_template(
|
|
97
113
|
None, DEFAULT_FLOW_DOCUMENT_TEMPLATE
|
|
@@ -131,6 +147,16 @@ class FlowRetrieval:
|
|
|
131
147
|
|
|
132
148
|
return config
|
|
133
149
|
|
|
150
|
+
def train(self) -> None:
|
|
151
|
+
self.config[TRAINED_EMBEDDINGS_CONFIG_KEY] = (
|
|
152
|
+
perform_training_time_embeddings_health_check(
|
|
153
|
+
self.config.get(EMBEDDINGS_CONFIG_KEY),
|
|
154
|
+
DEFAULT_EMBEDDINGS_CONFIG,
|
|
155
|
+
"flow_retrieval.train",
|
|
156
|
+
FlowRetrieval.__name__,
|
|
157
|
+
)
|
|
158
|
+
)
|
|
159
|
+
|
|
134
160
|
@classmethod
|
|
135
161
|
def load(
|
|
136
162
|
cls,
|
|
@@ -147,6 +173,31 @@ class FlowRetrieval:
|
|
|
147
173
|
flow_retrieval.config, model_storage, resource
|
|
148
174
|
)
|
|
149
175
|
flow_retrieval.vector_store = vector_store
|
|
176
|
+
|
|
177
|
+
persisted_config = None
|
|
178
|
+
try:
|
|
179
|
+
with model_storage.read_from(resource) as path:
|
|
180
|
+
persisted_config = read_json_file(
|
|
181
|
+
path / FLOW_RETRIEVAL_CONFIG_FILE_NAME
|
|
182
|
+
)
|
|
183
|
+
except (FileNotFoundError, FileIOException) as e:
|
|
184
|
+
structlogger.warning(
|
|
185
|
+
"flow_retrieval.load.failed", error=e, resource=resource.name
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
train_embeddings_name = (
|
|
189
|
+
persisted_config.get(TRAINED_EMBEDDINGS_CONFIG_KEY, None)
|
|
190
|
+
if persisted_config
|
|
191
|
+
else None
|
|
192
|
+
)
|
|
193
|
+
perform_inference_time_embeddings_health_check(
|
|
194
|
+
flow_retrieval.config.get(EMBEDDINGS_CONFIG_KEY),
|
|
195
|
+
DEFAULT_EMBEDDINGS_CONFIG,
|
|
196
|
+
train_embeddings_name,
|
|
197
|
+
"flow_retrieval.load",
|
|
198
|
+
FlowRetrieval.__name__,
|
|
199
|
+
)
|
|
200
|
+
|
|
150
201
|
return flow_retrieval
|
|
151
202
|
|
|
152
203
|
@classmethod
|
|
@@ -178,13 +229,24 @@ class FlowRetrieval:
|
|
|
178
229
|
Returns:
|
|
179
230
|
The embedder.
|
|
180
231
|
"""
|
|
232
|
+
# Copy the config so original config is not modified
|
|
233
|
+
config = config.copy()
|
|
234
|
+
# Resolve config and instantiate the embedding client
|
|
235
|
+
config[EMBEDDINGS_CONFIG_KEY] = resolve_model_client_config(
|
|
236
|
+
config.get(EMBEDDINGS_CONFIG_KEY), FlowRetrieval.__name__
|
|
237
|
+
)
|
|
181
238
|
client = embedder_factory(
|
|
182
239
|
config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
|
|
183
240
|
)
|
|
241
|
+
# Wrap the embedding client in the adapter
|
|
184
242
|
return _LangchainEmbeddingClientAdapter(client)
|
|
185
243
|
|
|
186
244
|
def persist(self) -> None:
|
|
187
245
|
self._persist_vector_store()
|
|
246
|
+
with self._model_storage.write_to(self._resource) as path:
|
|
247
|
+
dump_obj_as_json_to_file(
|
|
248
|
+
path / FLOW_RETRIEVAL_CONFIG_FILE_NAME, self.config
|
|
249
|
+
)
|
|
188
250
|
|
|
189
251
|
def _persist_vector_store(self) -> None:
|
|
190
252
|
"""Persists the FAISS vector store."""
|
|
@@ -2,7 +2,6 @@ from abc import ABC, abstractmethod
|
|
|
2
2
|
from functools import lru_cache
|
|
3
3
|
from typing import Dict, Any, List, Optional, Tuple, Union, Text
|
|
4
4
|
|
|
5
|
-
import os
|
|
6
5
|
import structlog
|
|
7
6
|
from jinja2 import Template
|
|
8
7
|
|
|
@@ -17,13 +16,14 @@ from rasa.dialogue_understanding.generator.constants import (
|
|
|
17
16
|
LLM_CONFIG_KEY,
|
|
18
17
|
FLOW_RETRIEVAL_KEY,
|
|
19
18
|
FLOW_RETRIEVAL_ACTIVE_KEY,
|
|
19
|
+
FLOW_RETRIEVAL_FLOW_THRESHOLD,
|
|
20
|
+
TRAINED_MODEL_NAME_CONFIG_KEY,
|
|
20
21
|
)
|
|
21
22
|
from rasa.dialogue_understanding.generator.flow_retrieval import FlowRetrieval
|
|
22
23
|
from rasa.engine.graph import GraphComponent, ExecutionContext
|
|
23
24
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
24
25
|
from rasa.engine.storage.resource import Resource
|
|
25
26
|
from rasa.engine.storage.storage import ModelStorage
|
|
26
|
-
from rasa.shared.constants import LLM_API_HEALTH_CHECK_ENV_VAR
|
|
27
27
|
from rasa.shared.core.domain import Domain
|
|
28
28
|
from rasa.shared.core.flows import FlowStep, Flow, FlowsList
|
|
29
29
|
from rasa.shared.core.flows.steps.collect import CollectInformationFlowStep
|
|
@@ -35,15 +35,18 @@ from rasa.shared.nlu.training_data.message import Message
|
|
|
35
35
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
36
36
|
from rasa.shared.utils.llm import (
|
|
37
37
|
allowed_values_for_slot,
|
|
38
|
-
llm_api_health_check,
|
|
39
38
|
llm_factory,
|
|
40
|
-
|
|
39
|
+
resolve_model_client_config,
|
|
41
40
|
)
|
|
41
|
+
from rasa.shared.utils.health_check import perform_training_time_llm_health_check
|
|
42
42
|
from rasa.utils.log_utils import log_llm
|
|
43
43
|
|
|
44
44
|
structlogger = structlog.get_logger()
|
|
45
45
|
|
|
46
46
|
|
|
47
|
+
LLM_BASED_COMMAND_GENERATOR_CONFIG_FILE = "config.json"
|
|
48
|
+
|
|
49
|
+
|
|
47
50
|
@DefaultV1Recipe.register(
|
|
48
51
|
[
|
|
49
52
|
DefaultV1Recipe.ComponentType.COMMAND_GENERATOR,
|
|
@@ -64,6 +67,9 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
|
|
|
64
67
|
) -> None:
|
|
65
68
|
super().__init__(config)
|
|
66
69
|
self.config = {**self.get_default_config(), **config}
|
|
70
|
+
self.config[LLM_CONFIG_KEY] = resolve_model_client_config(
|
|
71
|
+
self.config.get(LLM_CONFIG_KEY), LLMBasedCommandGenerator.__name__
|
|
72
|
+
)
|
|
67
73
|
self._model_storage = model_storage
|
|
68
74
|
self._resource = resource
|
|
69
75
|
self.flow_retrieval: Optional[FlowRetrieval]
|
|
@@ -73,17 +79,9 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
|
|
|
73
79
|
self.config[FLOW_RETRIEVAL_KEY], model_storage, resource
|
|
74
80
|
)
|
|
75
81
|
structlogger.info("llm_based_command_generator.flow_retrieval.enabled")
|
|
82
|
+
self.config[FLOW_RETRIEVAL_KEY] = self.flow_retrieval.config
|
|
76
83
|
else:
|
|
77
84
|
self.flow_retrieval = None
|
|
78
|
-
structlogger.warn(
|
|
79
|
-
"llm_based_command_generator.flow_retrieval.disabled",
|
|
80
|
-
event_info=(
|
|
81
|
-
"Disabling flow retrieval can cause issues when there are a "
|
|
82
|
-
"large number of flows to be included in the prompt. For more"
|
|
83
|
-
"information see:\n"
|
|
84
|
-
"https://rasa.com/docs/rasa-pro/concepts/dialogue-understanding#how-the-llmcommandgenerator-works"
|
|
85
|
-
),
|
|
86
|
-
)
|
|
87
85
|
|
|
88
86
|
### Abstract methods
|
|
89
87
|
@staticmethod
|
|
@@ -108,7 +106,11 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
|
|
|
108
106
|
@abstractmethod
|
|
109
107
|
def persist(self) -> None:
|
|
110
108
|
"""Persist the component to disk for future loading."""
|
|
111
|
-
|
|
109
|
+
# persist the config to store the resolved llm and embedding config
|
|
110
|
+
with self._model_storage.write_to(self._resource) as path:
|
|
111
|
+
rasa.shared.utils.io.dump_obj_as_json_to_file(
|
|
112
|
+
path / LLM_BASED_COMMAND_GENERATOR_CONFIG_FILE, self.config
|
|
113
|
+
)
|
|
112
114
|
|
|
113
115
|
@abstractmethod
|
|
114
116
|
async def predict_commands(
|
|
@@ -171,19 +173,35 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
|
|
|
171
173
|
"""Train the llm based command generator. Stores all flows into a vector
|
|
172
174
|
store.
|
|
173
175
|
"""
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
"llm_based_command_generator.train",
|
|
179
|
-
LLMBasedCommandGenerator.__name__,
|
|
180
|
-
)
|
|
181
|
-
if os.getenv(LLM_API_HEALTH_CHECK_ENV_VAR, "true").lower() == "true":
|
|
182
|
-
llm_api_health_check(
|
|
183
|
-
llm_client,
|
|
176
|
+
self.config[TRAINED_MODEL_NAME_CONFIG_KEY] = (
|
|
177
|
+
perform_training_time_llm_health_check(
|
|
178
|
+
self.config.get(LLM_CONFIG_KEY),
|
|
179
|
+
DEFAULT_LLM_CONFIG,
|
|
184
180
|
"llm_based_command_generator.train",
|
|
185
181
|
LLMBasedCommandGenerator.__name__,
|
|
186
182
|
)
|
|
183
|
+
)
|
|
184
|
+
|
|
185
|
+
if (
|
|
186
|
+
self.flow_retrieval is None
|
|
187
|
+
and len(flows.user_flows) > FLOW_RETRIEVAL_FLOW_THRESHOLD
|
|
188
|
+
):
|
|
189
|
+
structlogger.warn(
|
|
190
|
+
"llm_based_command_generator.flow_retrieval.disabled",
|
|
191
|
+
event_info=(
|
|
192
|
+
f"You have {len(flows.user_flows)} user flows but flow "
|
|
193
|
+
f"retrieval is disabled. "
|
|
194
|
+
f"It is recommended to enable flow retrieval if the "
|
|
195
|
+
f"total number of user flows exceed "
|
|
196
|
+
f"{FLOW_RETRIEVAL_FLOW_THRESHOLD}. "
|
|
197
|
+
f"Keeping it disabled can result in deterioration of "
|
|
198
|
+
f"command generator's functional "
|
|
199
|
+
f"performance and higher costs because of increased "
|
|
200
|
+
f"number of tokens in the prompt. For more"
|
|
201
|
+
"information see:\n"
|
|
202
|
+
"https://rasa.com/docs/rasa-pro/concepts/dialogue-understanding#how-the-llmcommandgenerator-works"
|
|
203
|
+
),
|
|
204
|
+
)
|
|
187
205
|
|
|
188
206
|
# flow retrieval is populated with only user-defined flows
|
|
189
207
|
try:
|
|
@@ -192,10 +210,12 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
|
|
|
192
210
|
except Exception as e:
|
|
193
211
|
structlogger.error(
|
|
194
212
|
"llm_based_command_generator.train.failed",
|
|
195
|
-
event_info=
|
|
213
|
+
event_info="Flow retrieval store isinaccessible.",
|
|
196
214
|
error=e,
|
|
197
215
|
)
|
|
198
216
|
raise
|
|
217
|
+
if self.flow_retrieval is not None:
|
|
218
|
+
self.flow_retrieval.train()
|
|
199
219
|
self.persist()
|
|
200
220
|
return self._resource
|
|
201
221
|
|
|
@@ -231,9 +251,31 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
|
|
|
231
251
|
)
|
|
232
252
|
return None
|
|
233
253
|
|
|
254
|
+
@classmethod
|
|
255
|
+
def load_config_from_model_storage(
|
|
256
|
+
cls,
|
|
257
|
+
model_storage: ModelStorage,
|
|
258
|
+
resource: Resource,
|
|
259
|
+
) -> Optional[Text]:
|
|
260
|
+
try:
|
|
261
|
+
with model_storage.read_from(resource) as path:
|
|
262
|
+
return rasa.shared.utils.io.read_json_file(
|
|
263
|
+
path / LLM_BASED_COMMAND_GENERATOR_CONFIG_FILE
|
|
264
|
+
)
|
|
265
|
+
except (FileNotFoundError, FileIOException) as e:
|
|
266
|
+
structlogger.warning(
|
|
267
|
+
"llm_based_command_generator.load_config.failed",
|
|
268
|
+
error=e,
|
|
269
|
+
resource=resource.name,
|
|
270
|
+
)
|
|
271
|
+
return None
|
|
272
|
+
|
|
234
273
|
@classmethod
|
|
235
274
|
def load_flow_retrival(
|
|
236
|
-
cls,
|
|
275
|
+
cls,
|
|
276
|
+
config: Dict[str, Any],
|
|
277
|
+
model_storage: ModelStorage,
|
|
278
|
+
resource: Resource,
|
|
237
279
|
) -> Optional[FlowRetrieval]:
|
|
238
280
|
"""Load the FlowRetrieval component if it is enabled in the configuration."""
|
|
239
281
|
enable_flow_retrieval = config.get(FLOW_RETRIEVAL_KEY, {}).get(
|