rasa-pro 3.9.18__py3-none-any.whl → 3.10.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +0 -374
- rasa/__init__.py +1 -2
- rasa/__main__.py +5 -0
- rasa/anonymization/anonymization_rule_executor.py +2 -2
- rasa/api.py +27 -23
- rasa/cli/arguments/data.py +27 -2
- rasa/cli/arguments/default_arguments.py +25 -3
- rasa/cli/arguments/run.py +9 -9
- rasa/cli/arguments/train.py +11 -3
- rasa/cli/data.py +70 -8
- rasa/cli/e2e_test.py +104 -431
- rasa/cli/evaluate.py +1 -1
- rasa/cli/interactive.py +1 -0
- rasa/cli/llm_fine_tuning.py +398 -0
- rasa/cli/project_templates/calm/endpoints.yml +1 -1
- rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
- rasa/cli/run.py +15 -14
- rasa/cli/scaffold.py +10 -8
- rasa/cli/studio/studio.py +35 -5
- rasa/cli/train.py +56 -8
- rasa/cli/utils.py +22 -5
- rasa/cli/x.py +1 -1
- rasa/constants.py +7 -1
- rasa/core/actions/action.py +98 -49
- rasa/core/actions/action_run_slot_rejections.py +4 -1
- rasa/core/actions/custom_action_executor.py +9 -6
- rasa/core/actions/direct_custom_actions_executor.py +80 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
- rasa/core/actions/grpc_custom_action_executor.py +2 -2
- rasa/core/actions/http_custom_action_executor.py +6 -5
- rasa/core/agent.py +21 -17
- rasa/core/channels/__init__.py +2 -0
- rasa/core/channels/audiocodes.py +1 -16
- rasa/core/channels/voice_aware/__init__.py +0 -0
- rasa/core/channels/voice_aware/jambonz.py +103 -0
- rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
- rasa/core/channels/voice_aware/utils.py +20 -0
- rasa/core/channels/voice_native/__init__.py +0 -0
- rasa/core/constants.py +6 -1
- rasa/core/information_retrieval/faiss.py +7 -4
- rasa/core/information_retrieval/information_retrieval.py +8 -0
- rasa/core/information_retrieval/milvus.py +9 -2
- rasa/core/information_retrieval/qdrant.py +1 -1
- rasa/core/nlg/contextual_response_rephraser.py +32 -10
- rasa/core/nlg/summarize.py +4 -3
- rasa/core/policies/enterprise_search_policy.py +113 -45
- rasa/core/policies/flows/flow_executor.py +122 -76
- rasa/core/policies/intentless_policy.py +83 -29
- rasa/core/processor.py +72 -54
- rasa/core/run.py +5 -4
- rasa/core/tracker_store.py +8 -4
- rasa/core/training/interactive.py +1 -1
- rasa/core/utils.py +56 -57
- rasa/dialogue_understanding/coexistence/llm_based_router.py +53 -13
- rasa/dialogue_understanding/commands/__init__.py +6 -0
- rasa/dialogue_understanding/commands/restart_command.py +58 -0
- rasa/dialogue_understanding/commands/session_start_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +40 -0
- rasa/dialogue_understanding/generator/constants.py +10 -3
- rasa/dialogue_understanding/generator/flow_retrieval.py +21 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +13 -3
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +134 -90
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +47 -7
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +127 -41
- rasa/dialogue_understanding/patterns/restart.py +37 -0
- rasa/dialogue_understanding/patterns/session_start.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +16 -3
- rasa/dialogue_understanding/processor/command_processor_component.py +6 -2
- rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
- rasa/e2e_test/assertions.py +1223 -0
- rasa/e2e_test/assertions_schema.yml +106 -0
- rasa/e2e_test/constants.py +20 -0
- rasa/e2e_test/e2e_config.py +220 -0
- rasa/e2e_test/e2e_config_schema.yml +26 -0
- rasa/e2e_test/e2e_test_case.py +131 -8
- rasa/e2e_test/e2e_test_converter.py +363 -0
- rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
- rasa/e2e_test/e2e_test_coverage_report.py +364 -0
- rasa/e2e_test/e2e_test_result.py +26 -6
- rasa/e2e_test/e2e_test_runner.py +493 -71
- rasa/e2e_test/e2e_test_schema.yml +96 -0
- rasa/e2e_test/pykwalify_extensions.py +39 -0
- rasa/e2e_test/stub_custom_action.py +70 -0
- rasa/e2e_test/utils/__init__.py +0 -0
- rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
- rasa/e2e_test/utils/io.py +598 -0
- rasa/e2e_test/utils/validation.py +80 -0
- rasa/engine/graph.py +9 -3
- rasa/engine/recipes/default_components.py +0 -2
- rasa/engine/recipes/default_recipe.py +10 -2
- rasa/engine/storage/local_model_storage.py +40 -12
- rasa/engine/validation.py +78 -1
- rasa/env.py +9 -0
- rasa/graph_components/providers/story_graph_provider.py +59 -6
- rasa/llm_fine_tuning/__init__.py +0 -0
- rasa/llm_fine_tuning/annotation_module.py +241 -0
- rasa/llm_fine_tuning/conversations.py +144 -0
- rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
- rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
- rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
- rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
- rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
- rasa/llm_fine_tuning/storage.py +174 -0
- rasa/llm_fine_tuning/train_test_split_module.py +441 -0
- rasa/model_training.py +56 -16
- rasa/nlu/persistor.py +157 -36
- rasa/server.py +45 -10
- rasa/shared/constants.py +76 -16
- rasa/shared/core/domain.py +27 -19
- rasa/shared/core/events.py +28 -2
- rasa/shared/core/flows/flow.py +208 -13
- rasa/shared/core/flows/flow_path.py +84 -0
- rasa/shared/core/flows/flows_list.py +33 -11
- rasa/shared/core/flows/flows_yaml_schema.json +269 -193
- rasa/shared/core/flows/validation.py +112 -25
- rasa/shared/core/flows/yaml_flows_io.py +149 -10
- rasa/shared/core/trackers.py +6 -0
- rasa/shared/core/training_data/structures.py +20 -0
- rasa/shared/core/training_data/visualization.html +2 -2
- rasa/shared/exceptions.py +4 -0
- rasa/shared/importers/importer.py +64 -16
- rasa/shared/nlu/constants.py +2 -0
- rasa/shared/providers/_configs/__init__.py +0 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +183 -0
- rasa/shared/providers/_configs/client_config.py +57 -0
- rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
- rasa/shared/providers/_configs/openai_client_config.py +175 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +176 -0
- rasa/shared/providers/_configs/utils.py +101 -0
- rasa/shared/providers/_ssl_verification_utils.py +124 -0
- rasa/shared/providers/embedding/__init__.py +0 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +259 -0
- rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
- rasa/shared/providers/embedding/embedding_client.py +90 -0
- rasa/shared/providers/embedding/embedding_response.py +41 -0
- rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
- rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
- rasa/shared/providers/llm/__init__.py +0 -0
- rasa/shared/providers/llm/_base_litellm_client.py +251 -0
- rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
- rasa/shared/providers/llm/llm_client.py +76 -0
- rasa/shared/providers/llm/llm_response.py +50 -0
- rasa/shared/providers/llm/openai_llm_client.py +155 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +293 -0
- rasa/shared/providers/mappings.py +75 -0
- rasa/shared/utils/cli.py +30 -0
- rasa/shared/utils/io.py +65 -2
- rasa/shared/utils/llm.py +246 -200
- rasa/shared/utils/yaml.py +121 -15
- rasa/studio/auth.py +6 -4
- rasa/studio/config.py +13 -4
- rasa/studio/constants.py +1 -0
- rasa/studio/data_handler.py +10 -3
- rasa/studio/download.py +19 -13
- rasa/studio/train.py +2 -3
- rasa/studio/upload.py +19 -11
- rasa/telemetry.py +113 -58
- rasa/tracing/instrumentation/attribute_extractors.py +32 -17
- rasa/utils/common.py +18 -19
- rasa/utils/endpoints.py +7 -4
- rasa/utils/json_utils.py +60 -0
- rasa/utils/licensing.py +9 -1
- rasa/utils/ml_utils.py +4 -2
- rasa/validator.py +213 -3
- rasa/version.py +1 -1
- rasa_pro-3.10.16.dist-info/METADATA +196 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/RECORD +179 -113
- rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
- rasa/shared/providers/openai/clients.py +0 -43
- rasa/shared/providers/openai/session_handler.py +0 -110
- rasa_pro-3.9.18.dist-info/METADATA +0 -563
- /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
- /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/NOTICE +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/WHEEL +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/entry_points.txt +0 -0
rasa/core/utils.py
CHANGED
|
@@ -1,23 +1,31 @@
|
|
|
1
|
-
import json
|
|
2
1
|
import logging
|
|
3
2
|
import os
|
|
4
|
-
from decimal import Decimal
|
|
5
3
|
from pathlib import Path
|
|
6
|
-
from
|
|
4
|
+
from socket import SOCK_DGRAM, SOCK_STREAM
|
|
5
|
+
from typing import Any, Dict, Optional, Set, TYPE_CHECKING, Text, Tuple, Union
|
|
7
6
|
|
|
8
7
|
import numpy as np
|
|
8
|
+
from sanic import Sanic
|
|
9
9
|
|
|
10
|
+
import rasa.cli.utils as cli_utils
|
|
10
11
|
import rasa.shared.utils.io
|
|
11
12
|
from rasa.constants import DEFAULT_SANIC_WORKERS, ENV_SANIC_WORKERS
|
|
12
|
-
from rasa.
|
|
13
|
-
|
|
13
|
+
from rasa.core.constants import (
|
|
14
|
+
DOMAIN_GROUND_TRUTH_METADATA_KEY,
|
|
15
|
+
UTTER_SOURCE_METADATA_KEY,
|
|
16
|
+
ACTIVE_FLOW_METADATA_KEY,
|
|
17
|
+
STEP_ID_METADATA_KEY,
|
|
18
|
+
)
|
|
14
19
|
from rasa.core.lock_store import LockStore, RedisLockStore, InMemoryLockStore
|
|
20
|
+
from rasa.shared.constants import DEFAULT_ENDPOINTS_PATH, TCP_PROTOCOL
|
|
21
|
+
from rasa.shared.core.trackers import DialogueStateTracker
|
|
15
22
|
from rasa.utils.endpoints import EndpointConfig, read_endpoint_config
|
|
16
|
-
from sanic import Sanic
|
|
17
|
-
from socket import SOCK_DGRAM, SOCK_STREAM
|
|
18
|
-
import rasa.cli.utils as cli_utils
|
|
19
23
|
from rasa.utils.io import write_yaml
|
|
20
24
|
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from rasa.core.nlg import NaturalLanguageGenerator
|
|
27
|
+
from rasa.shared.core.domain import Domain
|
|
28
|
+
|
|
21
29
|
logger = logging.getLogger(__name__)
|
|
22
30
|
|
|
23
31
|
|
|
@@ -163,6 +171,8 @@ def is_limit_reached(num_messages: int, limit: Optional[int]) -> bool:
|
|
|
163
171
|
class AvailableEndpoints:
|
|
164
172
|
"""Collection of configured endpoints."""
|
|
165
173
|
|
|
174
|
+
_instance = None
|
|
175
|
+
|
|
166
176
|
@classmethod
|
|
167
177
|
def read_endpoints(cls, endpoint_file: Text) -> "AvailableEndpoints":
|
|
168
178
|
"""Read the different endpoints from a yaml file."""
|
|
@@ -209,6 +219,14 @@ class AvailableEndpoints:
|
|
|
209
219
|
self.event_broker = event_broker
|
|
210
220
|
self.vector_store = vector_store
|
|
211
221
|
|
|
222
|
+
@classmethod
|
|
223
|
+
def get_instance(cls, endpoint_file: Optional[Text] = None) -> "AvailableEndpoints":
|
|
224
|
+
"""Get the singleton instance of AvailableEndpoints."""
|
|
225
|
+
# Ensure that the instance is initialized only once.
|
|
226
|
+
if cls._instance is None:
|
|
227
|
+
cls._instance = cls.read_endpoints(endpoint_file)
|
|
228
|
+
return cls._instance
|
|
229
|
+
|
|
212
230
|
|
|
213
231
|
def read_endpoints_from_path(
|
|
214
232
|
endpoints_path: Optional[Union[Path, Text]] = None,
|
|
@@ -226,55 +244,7 @@ def read_endpoints_from_path(
|
|
|
226
244
|
endpoints_config_path = cli_utils.get_validated_path(
|
|
227
245
|
endpoints_path, "endpoints", DEFAULT_ENDPOINTS_PATH, True
|
|
228
246
|
)
|
|
229
|
-
return AvailableEndpoints.
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
def replace_floats_with_decimals(obj: Any, round_digits: int = 9) -> Any:
|
|
233
|
-
"""Convert all instances in `obj` of `float` to `Decimal`.
|
|
234
|
-
|
|
235
|
-
Args:
|
|
236
|
-
obj: Input object.
|
|
237
|
-
round_digits: Rounding precision of `Decimal` values.
|
|
238
|
-
|
|
239
|
-
Returns:
|
|
240
|
-
Input `obj` with all `float` types replaced by `Decimal`s rounded to
|
|
241
|
-
`round_digits` decimal places.
|
|
242
|
-
"""
|
|
243
|
-
|
|
244
|
-
def _float_to_rounded_decimal(s: Text) -> Decimal:
|
|
245
|
-
return Decimal(s).quantize(Decimal(10) ** -round_digits)
|
|
246
|
-
|
|
247
|
-
return json.loads(json.dumps(obj), parse_float=_float_to_rounded_decimal)
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
class DecimalEncoder(json.JSONEncoder):
|
|
251
|
-
"""`json.JSONEncoder` that dumps `Decimal`s as `float`s."""
|
|
252
|
-
|
|
253
|
-
def default(self, obj: Any) -> Any:
|
|
254
|
-
"""Get serializable object for `o`.
|
|
255
|
-
|
|
256
|
-
Args:
|
|
257
|
-
obj: Object to serialize.
|
|
258
|
-
|
|
259
|
-
Returns:
|
|
260
|
-
`obj` converted to `float` if `o` is a `Decimals`, else the base class
|
|
261
|
-
`default()` method.
|
|
262
|
-
"""
|
|
263
|
-
if isinstance(obj, Decimal):
|
|
264
|
-
return float(obj)
|
|
265
|
-
return super().default(obj)
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
def replace_decimals_with_floats(obj: Any) -> Any:
|
|
269
|
-
"""Convert all instances in `obj` of `Decimal` to `float`.
|
|
270
|
-
|
|
271
|
-
Args:
|
|
272
|
-
obj: A `List` or `Dict` object.
|
|
273
|
-
|
|
274
|
-
Returns:
|
|
275
|
-
Input `obj` with all `Decimal` types replaced by `float`s.
|
|
276
|
-
"""
|
|
277
|
-
return json.loads(json.dumps(obj, cls=DecimalEncoder))
|
|
247
|
+
return AvailableEndpoints.get_instance(endpoints_config_path)
|
|
278
248
|
|
|
279
249
|
|
|
280
250
|
def _lock_store_is_multi_worker_compatible(
|
|
@@ -337,3 +307,32 @@ def number_of_sanic_workers(lock_store: Union[EndpointConfig, LockStore, None])
|
|
|
337
307
|
f"configuration has been found."
|
|
338
308
|
)
|
|
339
309
|
return _log_and_get_default_number_of_workers()
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def add_bot_utterance_metadata(
|
|
313
|
+
message: Dict[str, Any],
|
|
314
|
+
domain_response_name: str,
|
|
315
|
+
nlg: "NaturalLanguageGenerator",
|
|
316
|
+
domain: "Domain",
|
|
317
|
+
tracker: Optional[DialogueStateTracker],
|
|
318
|
+
) -> Dict[str, Any]:
|
|
319
|
+
"""Add metadata to the bot message."""
|
|
320
|
+
message["utter_action"] = domain_response_name
|
|
321
|
+
|
|
322
|
+
utter_source = message.get(UTTER_SOURCE_METADATA_KEY)
|
|
323
|
+
if utter_source is None:
|
|
324
|
+
utter_source = nlg.__class__.__name__
|
|
325
|
+
message[UTTER_SOURCE_METADATA_KEY] = utter_source
|
|
326
|
+
|
|
327
|
+
if tracker:
|
|
328
|
+
message[ACTIVE_FLOW_METADATA_KEY] = tracker.active_flow
|
|
329
|
+
message[STEP_ID_METADATA_KEY] = tracker.current_step_id
|
|
330
|
+
|
|
331
|
+
if utter_source in ["IntentlessPolicy", "ContextualResponseRephraser"]:
|
|
332
|
+
message[DOMAIN_GROUND_TRUTH_METADATA_KEY] = [
|
|
333
|
+
response.get("text")
|
|
334
|
+
for response in domain.responses.get(domain_response_name, [])
|
|
335
|
+
if response.get("text") is not None
|
|
336
|
+
]
|
|
337
|
+
|
|
338
|
+
return message
|
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
|
|
2
3
|
import importlib
|
|
3
4
|
from typing import Any, Dict, List, Optional
|
|
4
5
|
|
|
5
6
|
import structlog
|
|
6
7
|
from jinja2 import Template
|
|
7
8
|
|
|
9
|
+
import rasa.shared.utils.io
|
|
8
10
|
from rasa.dialogue_understanding.coexistence.constants import (
|
|
9
11
|
CALM_ENTRY,
|
|
10
12
|
NLU_ENTRY,
|
|
@@ -18,25 +20,32 @@ from rasa.engine.graph import ExecutionContext, GraphComponent
|
|
|
18
20
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
19
21
|
from rasa.engine.storage.resource import Resource
|
|
20
22
|
from rasa.engine.storage.storage import ModelStorage
|
|
21
|
-
from rasa.shared.constants import
|
|
23
|
+
from rasa.shared.constants import (
|
|
24
|
+
ROUTE_TO_CALM_SLOT,
|
|
25
|
+
PROMPT_CONFIG_KEY,
|
|
26
|
+
PROVIDER_CONFIG_KEY,
|
|
27
|
+
MODEL_CONFIG_KEY,
|
|
28
|
+
OPENAI_PROVIDER,
|
|
29
|
+
TIMEOUT_CONFIG_KEY,
|
|
30
|
+
)
|
|
22
31
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
23
32
|
from rasa.shared.exceptions import InvalidConfigException, FileIOException
|
|
24
33
|
from rasa.shared.nlu.constants import COMMANDS, TEXT
|
|
25
34
|
from rasa.shared.nlu.training_data.message import Message
|
|
26
35
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
27
|
-
import rasa.shared.utils.io
|
|
28
36
|
from rasa.shared.utils.llm import (
|
|
29
37
|
DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
30
38
|
get_prompt_template,
|
|
31
39
|
llm_factory,
|
|
40
|
+
try_instantiate_llm_client,
|
|
32
41
|
)
|
|
42
|
+
from rasa.utils.log_utils import log_llm
|
|
33
43
|
|
|
34
44
|
LLM_BASED_ROUTER_PROMPT_FILE_NAME = "llm_based_router_prompt.jinja2"
|
|
35
45
|
DEFAULT_COMMAND_PROMPT_TEMPLATE = importlib.resources.read_text(
|
|
36
46
|
"rasa.dialogue_understanding.coexistence", "router_template.jinja2"
|
|
37
47
|
)
|
|
38
48
|
|
|
39
|
-
|
|
40
49
|
# Token ids for gpt 3.5 and gpt 4 corresponding to space + capitalized Letter
|
|
41
50
|
A_TO_C_TOKEN_IDS_CHATGPT = [
|
|
42
51
|
362, # " A"
|
|
@@ -45,10 +54,10 @@ A_TO_C_TOKEN_IDS_CHATGPT = [
|
|
|
45
54
|
]
|
|
46
55
|
|
|
47
56
|
DEFAULT_LLM_CONFIG = {
|
|
48
|
-
|
|
49
|
-
|
|
57
|
+
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
58
|
+
MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
59
|
+
TIMEOUT_CONFIG_KEY: 7,
|
|
50
60
|
"temperature": 0.0,
|
|
51
|
-
"model_name": DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
52
61
|
"max_tokens": 1,
|
|
53
62
|
"logit_bias": {str(token_id): 100 for token_id in A_TO_C_TOKEN_IDS_CHATGPT},
|
|
54
63
|
}
|
|
@@ -67,7 +76,7 @@ class LLMBasedRouter(GraphComponent):
|
|
|
67
76
|
def get_default_config() -> Dict[str, Any]:
|
|
68
77
|
"""The component's default config (see parent class for full docstring)."""
|
|
69
78
|
return {
|
|
70
|
-
|
|
79
|
+
PROMPT_CONFIG_KEY: None,
|
|
71
80
|
CALM_ENTRY: {STICKY: None},
|
|
72
81
|
NLU_ENTRY: {
|
|
73
82
|
NON_STICKY: "handles chitchat",
|
|
@@ -88,7 +97,7 @@ class LLMBasedRouter(GraphComponent):
|
|
|
88
97
|
self.prompt_template = (
|
|
89
98
|
prompt_template
|
|
90
99
|
or get_prompt_template(
|
|
91
|
-
config.get(
|
|
100
|
+
config.get(PROMPT_CONFIG_KEY),
|
|
92
101
|
DEFAULT_COMMAND_PROMPT_TEMPLATE,
|
|
93
102
|
).strip()
|
|
94
103
|
)
|
|
@@ -120,6 +129,14 @@ class LLMBasedRouter(GraphComponent):
|
|
|
120
129
|
|
|
121
130
|
def train(self, training_data: TrainingData) -> Resource:
|
|
122
131
|
"""Train the intent classifier on a data set."""
|
|
132
|
+
# Validate llm configuration
|
|
133
|
+
try_instantiate_llm_client(
|
|
134
|
+
self.config.get(LLM_CONFIG_KEY),
|
|
135
|
+
DEFAULT_LLM_CONFIG,
|
|
136
|
+
"llm_based_router.train",
|
|
137
|
+
"LLMBasedRouter",
|
|
138
|
+
)
|
|
139
|
+
|
|
123
140
|
self.persist()
|
|
124
141
|
return self._resource
|
|
125
142
|
|
|
@@ -144,7 +161,14 @@ class LLMBasedRouter(GraphComponent):
|
|
|
144
161
|
"llm_based_router.load.failed", error=e, resource=resource.name
|
|
145
162
|
)
|
|
146
163
|
|
|
147
|
-
|
|
164
|
+
router = cls(config, model_storage, resource, prompt_template=prompt_template)
|
|
165
|
+
try_instantiate_llm_client(
|
|
166
|
+
router.config.get(LLM_CONFIG_KEY),
|
|
167
|
+
DEFAULT_LLM_CONFIG,
|
|
168
|
+
"llm_based_router.load",
|
|
169
|
+
LLMBasedRouter.__name__,
|
|
170
|
+
)
|
|
171
|
+
return router
|
|
148
172
|
|
|
149
173
|
@classmethod
|
|
150
174
|
def create(
|
|
@@ -188,12 +212,27 @@ class LLMBasedRouter(GraphComponent):
|
|
|
188
212
|
route_session_to_calm = tracker.get_slot(ROUTE_TO_CALM_SLOT)
|
|
189
213
|
if route_session_to_calm is None:
|
|
190
214
|
prompt = self.render_template(message)
|
|
191
|
-
|
|
215
|
+
log_llm(
|
|
216
|
+
logger=structlogger,
|
|
217
|
+
log_module="LLMBasedRouter",
|
|
218
|
+
log_event="llm_based_router.prompt_rendered",
|
|
219
|
+
prompt=prompt,
|
|
220
|
+
)
|
|
192
221
|
# generating answer
|
|
193
222
|
answer = await self._generate_answer_using_llm(prompt)
|
|
194
|
-
|
|
223
|
+
log_llm(
|
|
224
|
+
logger=structlogger,
|
|
225
|
+
log_module="LLMBasedRouter",
|
|
226
|
+
log_event="llm_based_router.llm_answer",
|
|
227
|
+
answer=answer,
|
|
228
|
+
)
|
|
195
229
|
commands = self.parse_answer(answer)
|
|
196
|
-
|
|
230
|
+
log_llm(
|
|
231
|
+
logger=structlogger,
|
|
232
|
+
log_module="LLMBasedRouter",
|
|
233
|
+
log_event="llm_based_router.final_commands",
|
|
234
|
+
commands=commands,
|
|
235
|
+
)
|
|
197
236
|
return commands
|
|
198
237
|
elif route_session_to_calm is True:
|
|
199
238
|
# don't set any commands so that a `LLMBasedCommandGenerator` is triggered
|
|
@@ -252,7 +291,8 @@ class LLMBasedRouter(GraphComponent):
|
|
|
252
291
|
llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
|
|
253
292
|
|
|
254
293
|
try:
|
|
255
|
-
|
|
294
|
+
llm_response = await llm.acompletion(prompt)
|
|
295
|
+
return llm_response.choices[0]
|
|
256
296
|
except Exception as e:
|
|
257
297
|
# unfortunately, langchain does not wrap LLM exceptions which means
|
|
258
298
|
# we have to catch all exceptions here
|
|
@@ -9,6 +9,7 @@ from rasa.dialogue_understanding.commands.knowledge_answer_command import (
|
|
|
9
9
|
from rasa.dialogue_understanding.commands.chit_chat_answer_command import (
|
|
10
10
|
ChitChatAnswerCommand,
|
|
11
11
|
)
|
|
12
|
+
from rasa.dialogue_understanding.commands.restart_command import RestartCommand
|
|
12
13
|
from rasa.dialogue_understanding.commands.skip_question_command import (
|
|
13
14
|
SkipQuestionCommand,
|
|
14
15
|
)
|
|
@@ -28,6 +29,9 @@ from rasa.dialogue_understanding.commands.correct_slots_command import (
|
|
|
28
29
|
)
|
|
29
30
|
from rasa.dialogue_understanding.commands.noop_command import NoopCommand
|
|
30
31
|
from rasa.dialogue_understanding.commands.change_flow_command import ChangeFlowCommand
|
|
32
|
+
from rasa.dialogue_understanding.commands.session_start_command import (
|
|
33
|
+
SessionStartCommand,
|
|
34
|
+
)
|
|
31
35
|
|
|
32
36
|
__all__ = [
|
|
33
37
|
"Command",
|
|
@@ -46,4 +50,6 @@ __all__ = [
|
|
|
46
50
|
"ErrorCommand",
|
|
47
51
|
"NoopCommand",
|
|
48
52
|
"ChangeFlowCommand",
|
|
53
|
+
"SessionStartCommand",
|
|
54
|
+
"RestartCommand",
|
|
49
55
|
]
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Dict, List
|
|
5
|
+
|
|
6
|
+
from rasa.dialogue_understanding.commands import Command
|
|
7
|
+
from rasa.dialogue_understanding.patterns.restart import RestartPatternFlowStackFrame
|
|
8
|
+
from rasa.shared.core.events import Event
|
|
9
|
+
from rasa.shared.core.flows import FlowsList
|
|
10
|
+
from rasa.shared.core.trackers import DialogueStateTracker
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
|
|
14
|
+
class RestartCommand(Command):
|
|
15
|
+
"""A command to restart a session."""
|
|
16
|
+
|
|
17
|
+
@classmethod
|
|
18
|
+
def command(cls) -> str:
|
|
19
|
+
"""Returns the command type."""
|
|
20
|
+
return "restart"
|
|
21
|
+
|
|
22
|
+
@classmethod
|
|
23
|
+
def from_dict(cls, data: Dict[str, Any]) -> RestartCommand:
|
|
24
|
+
"""Converts the dictionary to a command.
|
|
25
|
+
|
|
26
|
+
Returns:
|
|
27
|
+
The converted dictionary.
|
|
28
|
+
"""
|
|
29
|
+
return RestartCommand()
|
|
30
|
+
|
|
31
|
+
def run_command_on_tracker(
|
|
32
|
+
self,
|
|
33
|
+
tracker: DialogueStateTracker,
|
|
34
|
+
all_flows: FlowsList,
|
|
35
|
+
original_tracker: DialogueStateTracker,
|
|
36
|
+
) -> List[Event]:
|
|
37
|
+
"""Runs the command on the tracker.
|
|
38
|
+
|
|
39
|
+
Args:
|
|
40
|
+
tracker: The tracker to run the command on.
|
|
41
|
+
all_flows: All flows in the assistant.
|
|
42
|
+
original_tracker: The tracker before any command was executed.
|
|
43
|
+
|
|
44
|
+
Returns:
|
|
45
|
+
The events to apply to the tracker.
|
|
46
|
+
"""
|
|
47
|
+
stack = tracker.stack
|
|
48
|
+
stack.push(RestartPatternFlowStackFrame())
|
|
49
|
+
return tracker.create_stack_updated_events(stack)
|
|
50
|
+
|
|
51
|
+
def __hash__(self) -> int:
|
|
52
|
+
return hash(self.command())
|
|
53
|
+
|
|
54
|
+
def __eq__(self, other: object) -> bool:
|
|
55
|
+
if not isinstance(other, RestartCommand):
|
|
56
|
+
return False
|
|
57
|
+
|
|
58
|
+
return True
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Dict, List
|
|
5
|
+
from rasa.dialogue_understanding.commands import Command
|
|
6
|
+
from rasa.dialogue_understanding.patterns.session_start import (
|
|
7
|
+
SessionStartPatternFlowStackFrame,
|
|
8
|
+
)
|
|
9
|
+
from rasa.shared.core.events import Event
|
|
10
|
+
from rasa.shared.core.flows import FlowsList
|
|
11
|
+
from rasa.shared.core.trackers import DialogueStateTracker
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
|
|
15
|
+
class SessionStartCommand(Command):
|
|
16
|
+
"""A command to indicate the start of a session."""
|
|
17
|
+
|
|
18
|
+
@classmethod
|
|
19
|
+
def command(cls) -> str:
|
|
20
|
+
"""Returns the command type."""
|
|
21
|
+
return "session start"
|
|
22
|
+
|
|
23
|
+
@classmethod
|
|
24
|
+
def from_dict(cls, data: Dict[str, Any]) -> SessionStartCommand:
|
|
25
|
+
"""Converts the dictionary to a command.
|
|
26
|
+
|
|
27
|
+
Returns:
|
|
28
|
+
The converted dictionary.
|
|
29
|
+
"""
|
|
30
|
+
return SessionStartCommand()
|
|
31
|
+
|
|
32
|
+
def run_command_on_tracker(
|
|
33
|
+
self,
|
|
34
|
+
tracker: DialogueStateTracker,
|
|
35
|
+
all_flows: FlowsList,
|
|
36
|
+
original_tracker: DialogueStateTracker,
|
|
37
|
+
) -> List[Event]:
|
|
38
|
+
"""Runs the command on the tracker.
|
|
39
|
+
|
|
40
|
+
Args:
|
|
41
|
+
tracker: The tracker to run the command on.
|
|
42
|
+
all_flows: All flows in the assistant.
|
|
43
|
+
original_tracker: The tracker before any command was executed.
|
|
44
|
+
|
|
45
|
+
Returns:
|
|
46
|
+
The events to apply to the tracker.
|
|
47
|
+
"""
|
|
48
|
+
stack = tracker.stack
|
|
49
|
+
stack.push(SessionStartPatternFlowStackFrame())
|
|
50
|
+
return tracker.create_stack_updated_events(stack)
|
|
51
|
+
|
|
52
|
+
def __hash__(self) -> int:
|
|
53
|
+
return hash(self.command())
|
|
54
|
+
|
|
55
|
+
def __eq__(self, other: object) -> bool:
|
|
56
|
+
if not isinstance(other, SessionStartCommand):
|
|
57
|
+
return False
|
|
58
|
+
|
|
59
|
+
return True
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from typing import Dict, Type
|
|
2
|
+
|
|
3
|
+
from rasa.dialogue_understanding.commands import (
|
|
4
|
+
CancelFlowCommand,
|
|
5
|
+
CannotHandleCommand,
|
|
6
|
+
ChitChatAnswerCommand,
|
|
7
|
+
Command,
|
|
8
|
+
HumanHandoffCommand,
|
|
9
|
+
KnowledgeAnswerCommand,
|
|
10
|
+
SessionStartCommand,
|
|
11
|
+
SkipQuestionCommand,
|
|
12
|
+
RestartCommand,
|
|
13
|
+
)
|
|
14
|
+
from rasa.dialogue_understanding.patterns.cancel import CancelPatternFlowStackFrame
|
|
15
|
+
from rasa.dialogue_understanding.patterns.cannot_handle import (
|
|
16
|
+
CannotHandlePatternFlowStackFrame,
|
|
17
|
+
)
|
|
18
|
+
from rasa.dialogue_understanding.patterns.chitchat import ChitchatPatternFlowStackFrame
|
|
19
|
+
from rasa.dialogue_understanding.patterns.human_handoff import (
|
|
20
|
+
HumanHandoffPatternFlowStackFrame,
|
|
21
|
+
)
|
|
22
|
+
from rasa.dialogue_understanding.patterns.restart import RestartPatternFlowStackFrame
|
|
23
|
+
from rasa.dialogue_understanding.patterns.search import SearchPatternFlowStackFrame
|
|
24
|
+
from rasa.dialogue_understanding.patterns.session_start import (
|
|
25
|
+
SessionStartPatternFlowStackFrame,
|
|
26
|
+
)
|
|
27
|
+
from rasa.dialogue_understanding.patterns.skip_question import (
|
|
28
|
+
SkipQuestionPatternFlowStackFrame,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
triggerable_pattern_to_command_class: Dict[str, Type[Command]] = {
|
|
32
|
+
SessionStartPatternFlowStackFrame.flow_id: SessionStartCommand,
|
|
33
|
+
CancelPatternFlowStackFrame.flow_id: CancelFlowCommand,
|
|
34
|
+
ChitchatPatternFlowStackFrame.flow_id: ChitChatAnswerCommand,
|
|
35
|
+
HumanHandoffPatternFlowStackFrame.flow_id: HumanHandoffCommand,
|
|
36
|
+
SearchPatternFlowStackFrame.flow_id: KnowledgeAnswerCommand,
|
|
37
|
+
SkipQuestionPatternFlowStackFrame.flow_id: SkipQuestionCommand,
|
|
38
|
+
CannotHandlePatternFlowStackFrame.flow_id: CannotHandleCommand,
|
|
39
|
+
RestartPatternFlowStackFrame.flow_id: RestartCommand,
|
|
40
|
+
}
|
|
@@ -1,14 +1,20 @@
|
|
|
1
|
+
from rasa.shared.constants import (
|
|
2
|
+
PROVIDER_CONFIG_KEY,
|
|
3
|
+
OPENAI_PROVIDER,
|
|
4
|
+
MODEL_CONFIG_KEY,
|
|
5
|
+
TIMEOUT_CONFIG_KEY,
|
|
6
|
+
)
|
|
1
7
|
from rasa.shared.utils.llm import (
|
|
2
8
|
DEFAULT_OPENAI_CHAT_MODEL_NAME_ADVANCED,
|
|
3
9
|
DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
|
|
4
10
|
)
|
|
5
11
|
|
|
6
12
|
DEFAULT_LLM_CONFIG = {
|
|
7
|
-
|
|
8
|
-
|
|
13
|
+
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
14
|
+
MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME_ADVANCED,
|
|
9
15
|
"temperature": 0.0,
|
|
10
|
-
"model_name": DEFAULT_OPENAI_CHAT_MODEL_NAME_ADVANCED,
|
|
11
16
|
"max_tokens": DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
|
|
17
|
+
TIMEOUT_CONFIG_KEY: 7,
|
|
12
18
|
}
|
|
13
19
|
|
|
14
20
|
LLM_CONFIG_KEY = "llm"
|
|
@@ -16,3 +22,4 @@ USER_INPUT_CONFIG_KEY = "user_input"
|
|
|
16
22
|
|
|
17
23
|
FLOW_RETRIEVAL_KEY = "flow_retrieval"
|
|
18
24
|
FLOW_RETRIEVAL_ACTIVE_KEY = "active"
|
|
25
|
+
FLOW_RETRIEVAL_EMBEDDINGS_CONFIG_KEY = "embeddings"
|
|
@@ -25,16 +25,24 @@ import structlog
|
|
|
25
25
|
from jinja2 import Template
|
|
26
26
|
from langchain.docstore.document import Document
|
|
27
27
|
from langchain.schema.embeddings import Embeddings
|
|
28
|
-
from
|
|
29
|
-
from
|
|
28
|
+
from langchain_community.vectorstores.faiss import FAISS
|
|
29
|
+
from langchain_community.vectorstores.utils import DistanceStrategy
|
|
30
30
|
from rasa.engine.storage.resource import Resource
|
|
31
31
|
from rasa.engine.storage.storage import ModelStorage
|
|
32
|
+
from rasa.shared.constants import (
|
|
33
|
+
EMBEDDINGS_CONFIG_KEY,
|
|
34
|
+
PROVIDER_CONFIG_KEY,
|
|
35
|
+
OPENAI_PROVIDER,
|
|
36
|
+
)
|
|
32
37
|
from rasa.shared.core.domain import Domain
|
|
33
38
|
from rasa.shared.core.flows import FlowsList
|
|
34
39
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
35
40
|
from rasa.shared.nlu.constants import TEXT, FLOWS_FROM_SEMANTIC_SEARCH
|
|
36
41
|
from rasa.shared.nlu.training_data.message import Message
|
|
37
42
|
from rasa.shared.exceptions import ProviderClientAPIException
|
|
43
|
+
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
|
|
44
|
+
_LangchainEmbeddingClientAdapter,
|
|
45
|
+
)
|
|
38
46
|
from rasa.shared.utils.llm import (
|
|
39
47
|
tracker_as_readable_transcript,
|
|
40
48
|
embedder_factory,
|
|
@@ -42,15 +50,15 @@ from rasa.shared.utils.llm import (
|
|
|
42
50
|
USER,
|
|
43
51
|
get_prompt_template,
|
|
44
52
|
allowed_values_for_slot,
|
|
53
|
+
try_instantiate_embedder,
|
|
45
54
|
)
|
|
46
55
|
|
|
47
56
|
DEFAULT_FLOW_DOCUMENT_TEMPLATE = importlib.resources.read_text(
|
|
48
57
|
"rasa.dialogue_understanding.generator", "flow_document_template.jinja2"
|
|
49
58
|
)
|
|
50
59
|
|
|
51
|
-
EMBEDDINGS_CONFIG_KEY = "embeddings"
|
|
52
60
|
DEFAULT_EMBEDDINGS_CONFIG = {
|
|
53
|
-
|
|
61
|
+
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
54
62
|
"model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
55
63
|
}
|
|
56
64
|
|
|
@@ -135,6 +143,12 @@ class FlowRetrieval:
|
|
|
135
143
|
"""Load flow retrieval with previously populated FAISS vector store."""
|
|
136
144
|
# initialize base flow retrieval
|
|
137
145
|
flow_retrieval = FlowRetrieval(config, model_storage, resource)
|
|
146
|
+
try_instantiate_embedder(
|
|
147
|
+
flow_retrieval.config.get(EMBEDDINGS_CONFIG_KEY),
|
|
148
|
+
DEFAULT_EMBEDDINGS_CONFIG,
|
|
149
|
+
"flow_retrieval.load",
|
|
150
|
+
FlowRetrieval.__name__,
|
|
151
|
+
)
|
|
138
152
|
# load vector store
|
|
139
153
|
vector_store = cls._load_vector_store(
|
|
140
154
|
flow_retrieval.config, model_storage, resource
|
|
@@ -154,6 +168,7 @@ class FlowRetrieval:
|
|
|
154
168
|
folder_path=model_path,
|
|
155
169
|
embeddings=embeddings,
|
|
156
170
|
distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
|
|
171
|
+
allow_dangerous_deserialization=True,
|
|
157
172
|
)
|
|
158
173
|
except Exception as e:
|
|
159
174
|
structlogger.warning(
|
|
@@ -170,9 +185,10 @@ class FlowRetrieval:
|
|
|
170
185
|
Returns:
|
|
171
186
|
The embedder.
|
|
172
187
|
"""
|
|
173
|
-
|
|
188
|
+
client = embedder_factory(
|
|
174
189
|
config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
|
|
175
190
|
)
|
|
191
|
+
return _LangchainEmbeddingClientAdapter(client)
|
|
176
192
|
|
|
177
193
|
def persist(self) -> None:
|
|
178
194
|
self._persist_vector_store()
|
|
@@ -32,8 +32,9 @@ from rasa.shared.nlu.constants import FLOWS_IN_PROMPT
|
|
|
32
32
|
from rasa.shared.nlu.training_data.message import Message
|
|
33
33
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
34
34
|
from rasa.shared.utils.llm import (
|
|
35
|
-
llm_factory,
|
|
36
35
|
allowed_values_for_slot,
|
|
36
|
+
llm_factory,
|
|
37
|
+
try_instantiate_llm_client,
|
|
37
38
|
)
|
|
38
39
|
from rasa.utils.log_utils import log_llm
|
|
39
40
|
|
|
@@ -167,6 +168,14 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
|
|
|
167
168
|
"""Train the llm based command generator. Stores all flows into a vector
|
|
168
169
|
store.
|
|
169
170
|
"""
|
|
171
|
+
# Validate llm configuration
|
|
172
|
+
try_instantiate_llm_client(
|
|
173
|
+
self.config.get(LLM_CONFIG_KEY),
|
|
174
|
+
DEFAULT_LLM_CONFIG,
|
|
175
|
+
"llm_based_command_generator.train",
|
|
176
|
+
"LLMBasedCommandGenerator",
|
|
177
|
+
)
|
|
178
|
+
|
|
170
179
|
# flow retrieval is populated with only user-defined flows
|
|
171
180
|
try:
|
|
172
181
|
if self.flow_retrieval is not None and not flows.is_empty():
|
|
@@ -174,7 +183,7 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
|
|
|
174
183
|
except Exception as e:
|
|
175
184
|
structlogger.error(
|
|
176
185
|
"llm_based_command_generator.train.failed",
|
|
177
|
-
event_info=
|
|
186
|
+
event_info="Flow retrieval store is inaccessible.",
|
|
178
187
|
error=e,
|
|
179
188
|
)
|
|
180
189
|
raise
|
|
@@ -286,7 +295,8 @@ class LLMBasedCommandGenerator(GraphComponent, CommandGenerator, ABC):
|
|
|
286
295
|
"""
|
|
287
296
|
llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG)
|
|
288
297
|
try:
|
|
289
|
-
|
|
298
|
+
llm_response = await llm.acompletion(prompt)
|
|
299
|
+
return llm_response.choices[0]
|
|
290
300
|
except Exception as e:
|
|
291
301
|
# unfortunately, langchain does not wrap LLM exceptions which means
|
|
292
302
|
# we have to catch all exceptions here
|