rasa-pro 3.9.18__py3-none-any.whl → 3.10.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +0 -374
- rasa/__init__.py +1 -2
- rasa/__main__.py +5 -0
- rasa/anonymization/anonymization_rule_executor.py +2 -2
- rasa/api.py +27 -23
- rasa/cli/arguments/data.py +27 -2
- rasa/cli/arguments/default_arguments.py +25 -3
- rasa/cli/arguments/run.py +9 -9
- rasa/cli/arguments/train.py +11 -3
- rasa/cli/data.py +70 -8
- rasa/cli/e2e_test.py +104 -431
- rasa/cli/evaluate.py +1 -1
- rasa/cli/interactive.py +1 -0
- rasa/cli/llm_fine_tuning.py +398 -0
- rasa/cli/project_templates/calm/endpoints.yml +1 -1
- rasa/cli/project_templates/tutorial/endpoints.yml +1 -1
- rasa/cli/run.py +15 -14
- rasa/cli/scaffold.py +10 -8
- rasa/cli/studio/studio.py +35 -5
- rasa/cli/train.py +56 -8
- rasa/cli/utils.py +22 -5
- rasa/cli/x.py +1 -1
- rasa/constants.py +7 -1
- rasa/core/actions/action.py +98 -49
- rasa/core/actions/action_run_slot_rejections.py +4 -1
- rasa/core/actions/custom_action_executor.py +9 -6
- rasa/core/actions/direct_custom_actions_executor.py +80 -0
- rasa/core/actions/e2e_stub_custom_action_executor.py +68 -0
- rasa/core/actions/grpc_custom_action_executor.py +2 -2
- rasa/core/actions/http_custom_action_executor.py +6 -5
- rasa/core/agent.py +21 -17
- rasa/core/channels/__init__.py +2 -0
- rasa/core/channels/audiocodes.py +1 -16
- rasa/core/channels/voice_aware/__init__.py +0 -0
- rasa/core/channels/voice_aware/jambonz.py +103 -0
- rasa/core/channels/voice_aware/jambonz_protocol.py +344 -0
- rasa/core/channels/voice_aware/utils.py +20 -0
- rasa/core/channels/voice_native/__init__.py +0 -0
- rasa/core/constants.py +6 -1
- rasa/core/information_retrieval/faiss.py +7 -4
- rasa/core/information_retrieval/information_retrieval.py +8 -0
- rasa/core/information_retrieval/milvus.py +9 -2
- rasa/core/information_retrieval/qdrant.py +1 -1
- rasa/core/nlg/contextual_response_rephraser.py +32 -10
- rasa/core/nlg/summarize.py +4 -3
- rasa/core/policies/enterprise_search_policy.py +113 -45
- rasa/core/policies/flows/flow_executor.py +122 -76
- rasa/core/policies/intentless_policy.py +83 -29
- rasa/core/processor.py +72 -54
- rasa/core/run.py +5 -4
- rasa/core/tracker_store.py +8 -4
- rasa/core/training/interactive.py +1 -1
- rasa/core/utils.py +56 -57
- rasa/dialogue_understanding/coexistence/llm_based_router.py +53 -13
- rasa/dialogue_understanding/commands/__init__.py +6 -0
- rasa/dialogue_understanding/commands/restart_command.py +58 -0
- rasa/dialogue_understanding/commands/session_start_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +40 -0
- rasa/dialogue_understanding/generator/constants.py +10 -3
- rasa/dialogue_understanding/generator/flow_retrieval.py +21 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +13 -3
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +134 -90
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +47 -7
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +127 -41
- rasa/dialogue_understanding/patterns/restart.py +37 -0
- rasa/dialogue_understanding/patterns/session_start.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +16 -3
- rasa/dialogue_understanding/processor/command_processor_component.py +6 -2
- rasa/e2e_test/aggregate_test_stats_calculator.py +134 -0
- rasa/e2e_test/assertions.py +1223 -0
- rasa/e2e_test/assertions_schema.yml +106 -0
- rasa/e2e_test/constants.py +20 -0
- rasa/e2e_test/e2e_config.py +220 -0
- rasa/e2e_test/e2e_config_schema.yml +26 -0
- rasa/e2e_test/e2e_test_case.py +131 -8
- rasa/e2e_test/e2e_test_converter.py +363 -0
- rasa/e2e_test/e2e_test_converter_prompt.jinja2 +70 -0
- rasa/e2e_test/e2e_test_coverage_report.py +364 -0
- rasa/e2e_test/e2e_test_result.py +26 -6
- rasa/e2e_test/e2e_test_runner.py +493 -71
- rasa/e2e_test/e2e_test_schema.yml +96 -0
- rasa/e2e_test/pykwalify_extensions.py +39 -0
- rasa/e2e_test/stub_custom_action.py +70 -0
- rasa/e2e_test/utils/__init__.py +0 -0
- rasa/e2e_test/utils/e2e_yaml_utils.py +55 -0
- rasa/e2e_test/utils/io.py +598 -0
- rasa/e2e_test/utils/validation.py +80 -0
- rasa/engine/graph.py +9 -3
- rasa/engine/recipes/default_components.py +0 -2
- rasa/engine/recipes/default_recipe.py +10 -2
- rasa/engine/storage/local_model_storage.py +40 -12
- rasa/engine/validation.py +78 -1
- rasa/env.py +9 -0
- rasa/graph_components/providers/story_graph_provider.py +59 -6
- rasa/llm_fine_tuning/__init__.py +0 -0
- rasa/llm_fine_tuning/annotation_module.py +241 -0
- rasa/llm_fine_tuning/conversations.py +144 -0
- rasa/llm_fine_tuning/llm_data_preparation_module.py +178 -0
- rasa/llm_fine_tuning/notebooks/unsloth_finetuning.ipynb +407 -0
- rasa/llm_fine_tuning/paraphrasing/__init__.py +0 -0
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +281 -0
- rasa/llm_fine_tuning/paraphrasing/default_rephrase_prompt_template.jina2 +44 -0
- rasa/llm_fine_tuning/paraphrasing/rephrase_validator.py +121 -0
- rasa/llm_fine_tuning/paraphrasing/rephrased_user_message.py +10 -0
- rasa/llm_fine_tuning/paraphrasing_module.py +128 -0
- rasa/llm_fine_tuning/storage.py +174 -0
- rasa/llm_fine_tuning/train_test_split_module.py +441 -0
- rasa/model_training.py +56 -16
- rasa/nlu/persistor.py +157 -36
- rasa/server.py +45 -10
- rasa/shared/constants.py +76 -16
- rasa/shared/core/domain.py +27 -19
- rasa/shared/core/events.py +28 -2
- rasa/shared/core/flows/flow.py +208 -13
- rasa/shared/core/flows/flow_path.py +84 -0
- rasa/shared/core/flows/flows_list.py +33 -11
- rasa/shared/core/flows/flows_yaml_schema.json +269 -193
- rasa/shared/core/flows/validation.py +112 -25
- rasa/shared/core/flows/yaml_flows_io.py +149 -10
- rasa/shared/core/trackers.py +6 -0
- rasa/shared/core/training_data/structures.py +20 -0
- rasa/shared/core/training_data/visualization.html +2 -2
- rasa/shared/exceptions.py +4 -0
- rasa/shared/importers/importer.py +64 -16
- rasa/shared/nlu/constants.py +2 -0
- rasa/shared/providers/_configs/__init__.py +0 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +183 -0
- rasa/shared/providers/_configs/client_config.py +57 -0
- rasa/shared/providers/_configs/default_litellm_client_config.py +130 -0
- rasa/shared/providers/_configs/huggingface_local_embedding_client_config.py +234 -0
- rasa/shared/providers/_configs/openai_client_config.py +175 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +176 -0
- rasa/shared/providers/_configs/utils.py +101 -0
- rasa/shared/providers/_ssl_verification_utils.py +124 -0
- rasa/shared/providers/embedding/__init__.py +0 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +259 -0
- rasa/shared/providers/embedding/_langchain_embedding_client_adapter.py +74 -0
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +277 -0
- rasa/shared/providers/embedding/default_litellm_embedding_client.py +102 -0
- rasa/shared/providers/embedding/embedding_client.py +90 -0
- rasa/shared/providers/embedding/embedding_response.py +41 -0
- rasa/shared/providers/embedding/huggingface_local_embedding_client.py +191 -0
- rasa/shared/providers/embedding/openai_embedding_client.py +172 -0
- rasa/shared/providers/llm/__init__.py +0 -0
- rasa/shared/providers/llm/_base_litellm_client.py +251 -0
- rasa/shared/providers/llm/azure_openai_llm_client.py +338 -0
- rasa/shared/providers/llm/default_litellm_llm_client.py +84 -0
- rasa/shared/providers/llm/llm_client.py +76 -0
- rasa/shared/providers/llm/llm_response.py +50 -0
- rasa/shared/providers/llm/openai_llm_client.py +155 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +293 -0
- rasa/shared/providers/mappings.py +75 -0
- rasa/shared/utils/cli.py +30 -0
- rasa/shared/utils/io.py +65 -2
- rasa/shared/utils/llm.py +246 -200
- rasa/shared/utils/yaml.py +121 -15
- rasa/studio/auth.py +6 -4
- rasa/studio/config.py +13 -4
- rasa/studio/constants.py +1 -0
- rasa/studio/data_handler.py +10 -3
- rasa/studio/download.py +19 -13
- rasa/studio/train.py +2 -3
- rasa/studio/upload.py +19 -11
- rasa/telemetry.py +113 -58
- rasa/tracing/instrumentation/attribute_extractors.py +32 -17
- rasa/utils/common.py +18 -19
- rasa/utils/endpoints.py +7 -4
- rasa/utils/json_utils.py +60 -0
- rasa/utils/licensing.py +9 -1
- rasa/utils/ml_utils.py +4 -2
- rasa/validator.py +213 -3
- rasa/version.py +1 -1
- rasa_pro-3.10.16.dist-info/METADATA +196 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/RECORD +179 -113
- rasa/nlu/classifiers/llm_intent_classifier.py +0 -519
- rasa/shared/providers/openai/clients.py +0 -43
- rasa/shared/providers/openai/session_handler.py +0 -110
- rasa_pro-3.9.18.dist-info/METADATA +0 -563
- /rasa/{shared/providers/openai → cli/project_templates/tutorial/actions}/__init__.py +0 -0
- /rasa/cli/project_templates/tutorial/{actions.py → actions/actions.py} +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/NOTICE +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/WHEEL +0 -0
- {rasa_pro-3.9.18.dist-info → rasa_pro-3.10.16.dist-info}/entry_points.txt +0 -0
|
@@ -8,13 +8,14 @@ import tiktoken
|
|
|
8
8
|
from jinja2 import Template
|
|
9
9
|
from langchain.docstore.document import Document
|
|
10
10
|
from langchain.schema.embeddings import Embeddings
|
|
11
|
-
from
|
|
11
|
+
from langchain_community.vectorstores.faiss import FAISS
|
|
12
12
|
|
|
13
13
|
import rasa.shared.utils.io
|
|
14
14
|
from rasa import telemetry
|
|
15
15
|
from rasa.core.constants import (
|
|
16
16
|
CHAT_POLICY_PRIORITY,
|
|
17
17
|
POLICY_PRIORITY,
|
|
18
|
+
UTTER_SOURCE_METADATA_KEY,
|
|
18
19
|
)
|
|
19
20
|
from rasa.core.policies.policy import Policy, PolicyPrediction, SupportedData
|
|
20
21
|
from rasa.dialogue_understanding.stack.frames import (
|
|
@@ -27,7 +28,17 @@ from rasa.engine.storage.resource import Resource
|
|
|
27
28
|
from rasa.engine.storage.storage import ModelStorage
|
|
28
29
|
from rasa.graph_components.providers.forms_provider import Forms
|
|
29
30
|
from rasa.graph_components.providers.responses_provider import Responses
|
|
30
|
-
from rasa.shared.constants import
|
|
31
|
+
from rasa.shared.constants import (
|
|
32
|
+
REQUIRED_SLOTS_KEY,
|
|
33
|
+
EMBEDDINGS_CONFIG_KEY,
|
|
34
|
+
LLM_CONFIG_KEY,
|
|
35
|
+
MODEL_CONFIG_KEY,
|
|
36
|
+
MODEL_NAME_CONFIG_KEY,
|
|
37
|
+
PROMPT_CONFIG_KEY,
|
|
38
|
+
PROVIDER_CONFIG_KEY,
|
|
39
|
+
OPENAI_PROVIDER,
|
|
40
|
+
TIMEOUT_CONFIG_KEY,
|
|
41
|
+
)
|
|
31
42
|
from rasa.shared.core.constants import ACTION_LISTEN_NAME
|
|
32
43
|
from rasa.shared.core.domain import KEY_RESPONSES_TEXT, Domain
|
|
33
44
|
from rasa.shared.core.events import (
|
|
@@ -42,6 +53,10 @@ from rasa.shared.core.trackers import DialogueStateTracker
|
|
|
42
53
|
from rasa.shared.exceptions import FileIOException, RasaCoreException
|
|
43
54
|
from rasa.shared.nlu.constants import PREDICTED_CONFIDENCE_KEY
|
|
44
55
|
from rasa.shared.nlu.training_data.training_data import TrainingData
|
|
56
|
+
from rasa.shared.providers.embedding._langchain_embedding_client_adapter import (
|
|
57
|
+
_LangchainEmbeddingClientAdapter,
|
|
58
|
+
)
|
|
59
|
+
from rasa.shared.providers.llm.llm_client import LLMClient
|
|
45
60
|
from rasa.shared.utils.io import deep_container_fingerprint
|
|
46
61
|
from rasa.shared.utils.llm import (
|
|
47
62
|
AI,
|
|
@@ -55,6 +70,8 @@ from rasa.shared.utils.llm import (
|
|
|
55
70
|
llm_factory,
|
|
56
71
|
sanitize_message_for_prompt,
|
|
57
72
|
tracker_as_readable_transcript,
|
|
73
|
+
try_instantiate_llm_client,
|
|
74
|
+
try_instantiate_embedder,
|
|
58
75
|
)
|
|
59
76
|
from rasa.utils.ml_utils import (
|
|
60
77
|
extract_ai_response_examples,
|
|
@@ -64,11 +81,12 @@ from rasa.utils.ml_utils import (
|
|
|
64
81
|
persist_faiss_vector_store,
|
|
65
82
|
response_for_template,
|
|
66
83
|
)
|
|
84
|
+
from rasa.dialogue_understanding.patterns.chitchat import FLOW_PATTERN_CHITCHAT
|
|
85
|
+
from rasa.shared.core.constants import ACTION_TRIGGER_CHITCHAT
|
|
67
86
|
from rasa.utils.log_utils import log_llm
|
|
68
87
|
|
|
69
88
|
if TYPE_CHECKING:
|
|
70
89
|
from rasa.core.featurizers.tracker_featurizers import TrackerFeaturizer
|
|
71
|
-
from langchain.llms.base import BaseLLM
|
|
72
90
|
|
|
73
91
|
structlogger = structlog.get_logger()
|
|
74
92
|
|
|
@@ -87,18 +105,16 @@ MAX_NUMBER_OF_TOKENS_FOR_SAMPLES = 900
|
|
|
87
105
|
# the config property name for the confidence of the nlu prediction
|
|
88
106
|
NLU_ABSTENTION_THRESHOLD = "nlu_abstention_threshold"
|
|
89
107
|
|
|
90
|
-
PROMPT = "prompt"
|
|
91
|
-
|
|
92
108
|
DEFAULT_LLM_CONFIG = {
|
|
93
|
-
|
|
94
|
-
|
|
109
|
+
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
110
|
+
MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
95
111
|
"temperature": 0.0,
|
|
96
|
-
"model_name": DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
97
112
|
"max_tokens": DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
|
|
113
|
+
TIMEOUT_CONFIG_KEY: 5,
|
|
98
114
|
}
|
|
99
115
|
|
|
100
116
|
DEFAULT_EMBEDDINGS_CONFIG = {
|
|
101
|
-
|
|
117
|
+
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
102
118
|
"model": DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
103
119
|
}
|
|
104
120
|
|
|
@@ -106,8 +122,6 @@ DEFAULT_INTENTLESS_PROMPT_TEMPLATE = importlib.resources.open_text(
|
|
|
106
122
|
"rasa.core.policies", "intentless_prompt_template.jinja2"
|
|
107
123
|
).name
|
|
108
124
|
|
|
109
|
-
EMBEDDINGS_CONFIG_KEY = "embeddings"
|
|
110
|
-
LLM_CONFIG_KEY = "llm"
|
|
111
125
|
INTENTLESS_PROMPT_TEMPLATE_FILE_NAME = "intentless_policy_prompt.jinja2"
|
|
112
126
|
|
|
113
127
|
|
|
@@ -163,6 +177,21 @@ def filter_responses(responses: Responses, forms: Forms, flows: FlowsList) -> Re
|
|
|
163
177
|
for name, variants in responses.data.items()
|
|
164
178
|
if name not in combined_responses
|
|
165
179
|
}
|
|
180
|
+
|
|
181
|
+
pattern_chitchat = flows.flow_by_id(FLOW_PATTERN_CHITCHAT)
|
|
182
|
+
|
|
183
|
+
# The following condition is highly unlikely, but mypy requires the case
|
|
184
|
+
# of pattern_chitchat == None to be addressed
|
|
185
|
+
if not pattern_chitchat:
|
|
186
|
+
return Responses(data=filtered_responses)
|
|
187
|
+
|
|
188
|
+
# if action_trigger_chitchat, filter out "utter_free_chitchat_response"
|
|
189
|
+
has_action_trigger_chitchat = pattern_chitchat.has_action_step(
|
|
190
|
+
ACTION_TRIGGER_CHITCHAT
|
|
191
|
+
)
|
|
192
|
+
if has_action_trigger_chitchat:
|
|
193
|
+
filtered_responses.pop("utter_free_chitchat_response", None)
|
|
194
|
+
|
|
166
195
|
return Responses(data=filtered_responses)
|
|
167
196
|
|
|
168
197
|
|
|
@@ -366,7 +395,7 @@ class IntentlessPolicy(Policy):
|
|
|
366
395
|
NLU_ABSTENTION_THRESHOLD: 0.9,
|
|
367
396
|
LLM_CONFIG_KEY: DEFAULT_LLM_CONFIG,
|
|
368
397
|
EMBEDDINGS_CONFIG_KEY: DEFAULT_EMBEDDINGS_CONFIG,
|
|
369
|
-
|
|
398
|
+
PROMPT_CONFIG_KEY: DEFAULT_INTENTLESS_PROMPT_TEMPLATE,
|
|
370
399
|
}
|
|
371
400
|
|
|
372
401
|
@staticmethod
|
|
@@ -405,7 +434,7 @@ class IntentlessPolicy(Policy):
|
|
|
405
434
|
self.conversation_samples_index = samples_docsearch
|
|
406
435
|
self.embedder = self._create_plain_embedder(config)
|
|
407
436
|
self.prompt_template = prompt_template or rasa.shared.utils.io.read_file(
|
|
408
|
-
self.config[
|
|
437
|
+
self.config[PROMPT_CONFIG_KEY]
|
|
409
438
|
)
|
|
410
439
|
self.trace_prompt_tokens = self.config.get("trace_prompt_tokens", False)
|
|
411
440
|
|
|
@@ -416,9 +445,10 @@ class IntentlessPolicy(Policy):
|
|
|
416
445
|
Returns:
|
|
417
446
|
The embedder.
|
|
418
447
|
"""
|
|
419
|
-
|
|
448
|
+
client = embedder_factory(
|
|
420
449
|
config.get(EMBEDDINGS_CONFIG_KEY), DEFAULT_EMBEDDINGS_CONFIG
|
|
421
450
|
)
|
|
451
|
+
return _LangchainEmbeddingClientAdapter(client)
|
|
422
452
|
|
|
423
453
|
def embeddings_property(self, prop: str) -> Optional[str]:
|
|
424
454
|
"""Returns the property of the embeddings config."""
|
|
@@ -458,6 +488,13 @@ class IntentlessPolicy(Policy):
|
|
|
458
488
|
A policy must return its resource locator so that potential children nodes
|
|
459
489
|
can load the policy from the resource.
|
|
460
490
|
"""
|
|
491
|
+
try_instantiate_llm_client(
|
|
492
|
+
self.config.get(LLM_CONFIG_KEY),
|
|
493
|
+
DEFAULT_LLM_CONFIG,
|
|
494
|
+
"intentless_policy.train",
|
|
495
|
+
"IntentlessPolicy",
|
|
496
|
+
)
|
|
497
|
+
|
|
461
498
|
responses = filter_responses(responses, forms, flows or FlowsList([]))
|
|
462
499
|
telemetry.track_intentless_policy_train()
|
|
463
500
|
response_texts = [r for r in extract_ai_response_examples(responses.data)]
|
|
@@ -500,11 +537,12 @@ class IntentlessPolicy(Policy):
|
|
|
500
537
|
|
|
501
538
|
structlogger.info("intentless_policy.training.completed")
|
|
502
539
|
telemetry.track_intentless_policy_train_completed(
|
|
503
|
-
embeddings_type=self.embeddings_property(
|
|
504
|
-
embeddings_model=self.embeddings_property(
|
|
505
|
-
or self.embeddings_property(
|
|
506
|
-
llm_type=self.llm_property(
|
|
507
|
-
llm_model=self.llm_property(
|
|
540
|
+
embeddings_type=self.embeddings_property(PROVIDER_CONFIG_KEY),
|
|
541
|
+
embeddings_model=self.embeddings_property(MODEL_CONFIG_KEY)
|
|
542
|
+
or self.embeddings_property(MODEL_NAME_CONFIG_KEY),
|
|
543
|
+
llm_type=self.llm_property(PROVIDER_CONFIG_KEY),
|
|
544
|
+
llm_model=self.llm_property(MODEL_CONFIG_KEY)
|
|
545
|
+
or self.llm_property(MODEL_NAME_CONFIG_KEY),
|
|
508
546
|
)
|
|
509
547
|
|
|
510
548
|
self.persist()
|
|
@@ -578,11 +616,12 @@ class IntentlessPolicy(Policy):
|
|
|
578
616
|
)
|
|
579
617
|
|
|
580
618
|
telemetry.track_intentless_policy_predict(
|
|
581
|
-
embeddings_type=self.embeddings_property(
|
|
582
|
-
embeddings_model=self.embeddings_property(
|
|
583
|
-
or self.embeddings_property(
|
|
584
|
-
llm_type=self.llm_property(
|
|
585
|
-
llm_model=self.llm_property(
|
|
619
|
+
embeddings_type=self.embeddings_property(PROVIDER_CONFIG_KEY),
|
|
620
|
+
embeddings_model=self.embeddings_property(MODEL_CONFIG_KEY)
|
|
621
|
+
or self.embeddings_property(MODEL_NAME_CONFIG_KEY),
|
|
622
|
+
llm_type=self.llm_property(PROVIDER_CONFIG_KEY),
|
|
623
|
+
llm_model=self.llm_property(MODEL_CONFIG_KEY)
|
|
624
|
+
or self.llm_property(MODEL_NAME_CONFIG_KEY),
|
|
586
625
|
score=score,
|
|
587
626
|
)
|
|
588
627
|
|
|
@@ -595,7 +634,9 @@ class IntentlessPolicy(Policy):
|
|
|
595
634
|
else:
|
|
596
635
|
events = []
|
|
597
636
|
|
|
598
|
-
|
|
637
|
+
action_metadata = {UTTER_SOURCE_METADATA_KEY: self.__class__.__name__}
|
|
638
|
+
|
|
639
|
+
return self._prediction(result, events=events, action_metadata=action_metadata)
|
|
599
640
|
|
|
600
641
|
async def generate_answer(
|
|
601
642
|
self,
|
|
@@ -619,9 +660,10 @@ class IntentlessPolicy(Policy):
|
|
|
619
660
|
)
|
|
620
661
|
return await self._generate_llm_answer(llm, prompt)
|
|
621
662
|
|
|
622
|
-
async def _generate_llm_answer(self, llm:
|
|
663
|
+
async def _generate_llm_answer(self, llm: LLMClient, prompt: str) -> Optional[str]:
|
|
623
664
|
try:
|
|
624
|
-
|
|
665
|
+
llm_response = await llm.acompletion(prompt)
|
|
666
|
+
return llm_response.choices[0]
|
|
625
667
|
except Exception as e:
|
|
626
668
|
# unfortunately, langchain does not wrap LLM exceptions which means
|
|
627
669
|
# we have to catch all exceptions here
|
|
@@ -685,6 +727,7 @@ class IntentlessPolicy(Policy):
|
|
|
685
727
|
number_of_samples=NUMBER_OF_CONVERSATION_SAMPLES,
|
|
686
728
|
max_number_of_tokens=MAX_NUMBER_OF_TOKENS_FOR_SAMPLES,
|
|
687
729
|
)
|
|
730
|
+
|
|
688
731
|
extra_ai_responses = self.extract_ai_responses(conversation_samples)
|
|
689
732
|
|
|
690
733
|
# put conversation responses in front of sampled examples,
|
|
@@ -876,6 +919,18 @@ class IntentlessPolicy(Policy):
|
|
|
876
919
|
**kwargs: Any,
|
|
877
920
|
) -> "IntentlessPolicy":
|
|
878
921
|
"""Loads a trained policy (see parent class for full docstring)."""
|
|
922
|
+
try_instantiate_llm_client(
|
|
923
|
+
config.get(LLM_CONFIG_KEY),
|
|
924
|
+
DEFAULT_LLM_CONFIG,
|
|
925
|
+
"intentless_policy.load",
|
|
926
|
+
IntentlessPolicy.__name__,
|
|
927
|
+
)
|
|
928
|
+
try_instantiate_embedder(
|
|
929
|
+
config.get(EMBEDDINGS_CONFIG_KEY),
|
|
930
|
+
DEFAULT_EMBEDDINGS_CONFIG,
|
|
931
|
+
"intentless_policy.load",
|
|
932
|
+
IntentlessPolicy.__name__,
|
|
933
|
+
)
|
|
879
934
|
responses_docsearch = None
|
|
880
935
|
samples_docsearch = None
|
|
881
936
|
prompt_template = None
|
|
@@ -901,7 +956,6 @@ class IntentlessPolicy(Policy):
|
|
|
901
956
|
structlogger.warning(
|
|
902
957
|
"intentless_policy.load.failed", error=e, resource_name=resource.name
|
|
903
958
|
)
|
|
904
|
-
|
|
905
959
|
return cls(
|
|
906
960
|
config,
|
|
907
961
|
model_storage,
|
|
@@ -916,7 +970,7 @@ class IntentlessPolicy(Policy):
|
|
|
916
970
|
def fingerprint_addon(cls, config: Dict[str, Any]) -> Optional[str]:
|
|
917
971
|
"""Add a fingerprint of the knowledge base for the graph."""
|
|
918
972
|
prompt_template = get_prompt_template(
|
|
919
|
-
config.get(
|
|
973
|
+
config.get(PROMPT_CONFIG_KEY),
|
|
920
974
|
DEFAULT_INTENTLESS_PROMPT_TEMPLATE,
|
|
921
975
|
)
|
|
922
976
|
return deep_container_fingerprint(prompt_template)
|
rasa/core/processor.py
CHANGED
|
@@ -3,12 +3,12 @@ import copy
|
|
|
3
3
|
import logging
|
|
4
4
|
import structlog
|
|
5
5
|
import os
|
|
6
|
+
import re
|
|
6
7
|
from pathlib import Path
|
|
7
8
|
import tarfile
|
|
8
9
|
import time
|
|
9
10
|
from types import LambdaType
|
|
10
11
|
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Text, Tuple, Union
|
|
11
|
-
|
|
12
12
|
from rasa.core.actions.action_exceptions import ActionExecutionRejection
|
|
13
13
|
from rasa.core.actions.forms import FormAction
|
|
14
14
|
from rasa.core.http_interpreter import RasaNLUHttpInterpreter
|
|
@@ -101,6 +101,9 @@ logger = logging.getLogger(__name__)
|
|
|
101
101
|
structlogger = structlog.get_logger()
|
|
102
102
|
|
|
103
103
|
MAX_NUMBER_OF_PREDICTIONS = int(os.environ.get("MAX_NUMBER_OF_PREDICTIONS", "10"))
|
|
104
|
+
MAX_NUMBER_OF_PREDICTIONS_CALM = int(
|
|
105
|
+
os.environ.get("MAX_NUMBER_OF_PREDICTIONS_CALM", "1000")
|
|
106
|
+
)
|
|
104
107
|
|
|
105
108
|
|
|
106
109
|
class MessageProcessor:
|
|
@@ -114,6 +117,7 @@ class MessageProcessor:
|
|
|
114
117
|
generator: NaturalLanguageGenerator,
|
|
115
118
|
action_endpoint: Optional[EndpointConfig] = None,
|
|
116
119
|
max_number_of_predictions: int = MAX_NUMBER_OF_PREDICTIONS,
|
|
120
|
+
max_number_of_predictions_calm: int = MAX_NUMBER_OF_PREDICTIONS_CALM,
|
|
117
121
|
on_circuit_break: Optional[LambdaType] = None,
|
|
118
122
|
http_interpreter: Optional[RasaNLUHttpInterpreter] = None,
|
|
119
123
|
endpoints: Optional["AvailableEndpoints"] = None,
|
|
@@ -122,7 +126,6 @@ class MessageProcessor:
|
|
|
122
126
|
self.nlg = generator
|
|
123
127
|
self.tracker_store = tracker_store
|
|
124
128
|
self.lock_store = lock_store
|
|
125
|
-
self.max_number_of_predictions = max_number_of_predictions
|
|
126
129
|
self.on_circuit_break = on_circuit_break
|
|
127
130
|
self.action_endpoint = action_endpoint
|
|
128
131
|
self.model_filename, self.model_metadata, self.graph_runner = self._load_model(
|
|
@@ -130,6 +133,10 @@ class MessageProcessor:
|
|
|
130
133
|
)
|
|
131
134
|
self.endpoints = endpoints
|
|
132
135
|
|
|
136
|
+
self.max_number_of_predictions = max_number_of_predictions
|
|
137
|
+
self.max_number_of_predictions_calm = max_number_of_predictions_calm
|
|
138
|
+
self.is_calm_assistant = self._is_calm_assistant()
|
|
139
|
+
|
|
133
140
|
if self.model_metadata.assistant_id is None:
|
|
134
141
|
rasa.shared.utils.io.raise_warning(
|
|
135
142
|
f"The model metadata does not contain a value for the "
|
|
@@ -751,20 +758,17 @@ class MessageProcessor:
|
|
|
751
758
|
message=processed_message, domain=self.domain
|
|
752
759
|
)
|
|
753
760
|
|
|
754
|
-
# Invalid use of slash syntax
|
|
761
|
+
# Invalid use of slash syntax, sanitize the message before passing
|
|
762
|
+
# it to the graph
|
|
755
763
|
if (
|
|
756
764
|
processed_message.starts_with_slash_syntax()
|
|
757
765
|
and not processed_message.has_intent()
|
|
758
766
|
and not processed_message.has_commands()
|
|
759
767
|
):
|
|
760
|
-
|
|
761
|
-
processed_message, tracker, only_output_properties
|
|
762
|
-
)
|
|
768
|
+
message = self._sanitize_message(message)
|
|
763
769
|
|
|
764
770
|
# Intent or commands are not explicitly present. Pass message to graph.
|
|
765
|
-
|
|
766
|
-
processed_message.has_intent() or processed_message.has_commands()
|
|
767
|
-
):
|
|
771
|
+
if not (processed_message.has_intent() or processed_message.has_commands()):
|
|
768
772
|
parse_data = await self._parse_message_with_graph(
|
|
769
773
|
message, tracker, only_output_properties
|
|
770
774
|
)
|
|
@@ -788,44 +792,16 @@ class MessageProcessor:
|
|
|
788
792
|
|
|
789
793
|
return parse_data
|
|
790
794
|
|
|
791
|
-
def
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
"passed. Returning CannotHandleCommand() as a fallback."
|
|
802
|
-
),
|
|
803
|
-
message=message.get(TEXT),
|
|
804
|
-
)
|
|
805
|
-
parse_data: Dict[Text, Any] = {
|
|
806
|
-
TEXT: "",
|
|
807
|
-
INTENT: {INTENT_NAME_KEY: None, PREDICTED_CONFIDENCE_KEY: 0.0},
|
|
808
|
-
ENTITIES: [],
|
|
809
|
-
}
|
|
810
|
-
parse_data.update(
|
|
811
|
-
message.as_dict(only_output_properties=only_output_properties)
|
|
812
|
-
)
|
|
813
|
-
commands = parse_data.get(COMMANDS, [])
|
|
814
|
-
commands += [
|
|
815
|
-
CannotHandleCommand(RASA_PATTERN_CANNOT_HANDLE_INVALID_INTENT).as_dict()
|
|
816
|
-
]
|
|
817
|
-
|
|
818
|
-
if (
|
|
819
|
-
tracker is not None
|
|
820
|
-
and tracker.has_coexistence_routing_slot
|
|
821
|
-
and tracker.get_slot(ROUTE_TO_CALM_SLOT) is None
|
|
822
|
-
):
|
|
823
|
-
# if we are currently not routing to either CALM or dm1
|
|
824
|
-
# we make a sticky routing to CALM
|
|
825
|
-
commands += [SetSlotCommand(ROUTE_TO_CALM_SLOT, True).as_dict()]
|
|
826
|
-
|
|
827
|
-
parse_data[COMMANDS] = commands
|
|
828
|
-
return parse_data
|
|
795
|
+
def _sanitize_message(self, message: UserMessage) -> UserMessage:
|
|
796
|
+
"""Sanitize user message by removing prepended slashes before the
|
|
797
|
+
actual content.
|
|
798
|
+
"""
|
|
799
|
+
# Regex pattern to match leading slashes and any whitespace before
|
|
800
|
+
# actual content
|
|
801
|
+
pattern = r"^[/\s]+"
|
|
802
|
+
# Remove the matched pattern from the beginning of the message
|
|
803
|
+
message.text = re.sub(pattern, "", message.text).strip()
|
|
804
|
+
return message
|
|
829
805
|
|
|
830
806
|
async def _parse_message_with_commands_and_intents(
|
|
831
807
|
self,
|
|
@@ -1003,11 +979,15 @@ class MessageProcessor:
|
|
|
1003
979
|
) -> int:
|
|
1004
980
|
"""Select the action limit based on the tracker state.
|
|
1005
981
|
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
|
-
|
|
1010
|
-
|
|
982
|
+
This function determines the maximum number of predictions that should be
|
|
983
|
+
made during a dialogue conversation. Typically, the number of predictions
|
|
984
|
+
is limited to the number of actions executed so far in the conversation.
|
|
985
|
+
However, in certain states (e.g., when the user is correcting the
|
|
986
|
+
conversation flow), more predictions may be allowed as the system traverses
|
|
987
|
+
through a long dialogue flow.
|
|
988
|
+
|
|
989
|
+
Additionally, if the `ROUTE_TO_CALM_SLOT` is present in the tracker slots,
|
|
990
|
+
the action limit is adjusted to a separate limit for CALM-based flows.
|
|
1011
991
|
|
|
1012
992
|
Args:
|
|
1013
993
|
tracker: instance of DialogueStateTracker.
|
|
@@ -1015,6 +995,18 @@ class MessageProcessor:
|
|
|
1015
995
|
Returns:
|
|
1016
996
|
The maximum number of predictions to make.
|
|
1017
997
|
"""
|
|
998
|
+
# Check if it is a CALM assistant and if so, that the `ROUTE_TO_CALM_SLOT`
|
|
999
|
+
# is either not present or set to `True`.
|
|
1000
|
+
# If it does, use the specific prediction limit for CALM assistants.
|
|
1001
|
+
# Otherwise, use the default prediction limit.
|
|
1002
|
+
if self.is_calm_assistant and (
|
|
1003
|
+
not tracker.has_coexistence_routing_slot
|
|
1004
|
+
or tracker.get_slot(ROUTE_TO_CALM_SLOT)
|
|
1005
|
+
):
|
|
1006
|
+
max_number_of_predictions = self.max_number_of_predictions_calm
|
|
1007
|
+
else:
|
|
1008
|
+
max_number_of_predictions = self.max_number_of_predictions
|
|
1009
|
+
|
|
1018
1010
|
reversed_events = list(tracker.events)[::-1]
|
|
1019
1011
|
is_conversation_in_flow_correction = False
|
|
1020
1012
|
for e in reversed_events:
|
|
@@ -1029,8 +1021,10 @@ class MessageProcessor:
|
|
|
1029
1021
|
# allow for more predictions to be made as we might be traversing through
|
|
1030
1022
|
# a long flow. We multiply the number of predictions by 10 to allow for
|
|
1031
1023
|
# more predictions to be made - the factor is a best guess.
|
|
1032
|
-
return
|
|
1033
|
-
|
|
1024
|
+
return max_number_of_predictions * 5
|
|
1025
|
+
|
|
1026
|
+
# Return the default
|
|
1027
|
+
return max_number_of_predictions
|
|
1034
1028
|
|
|
1035
1029
|
def is_action_limit_reached(
|
|
1036
1030
|
self, tracker: DialogueStateTracker, should_predict_another_action: bool
|
|
@@ -1420,3 +1414,27 @@ class MessageProcessor:
|
|
|
1420
1414
|
]
|
|
1421
1415
|
|
|
1422
1416
|
return len(filtered_commands) > 0
|
|
1417
|
+
|
|
1418
|
+
def _is_calm_assistant(self) -> bool:
|
|
1419
|
+
"""Inspects the nodes of the graph schema to determine whether
|
|
1420
|
+
any node is associated with the `FlowPolicy`, which is indicative of a
|
|
1421
|
+
CALM assistant setup.
|
|
1422
|
+
|
|
1423
|
+
Returns:
|
|
1424
|
+
bool: True if any node in the graph schema uses `FlowPolicy`.
|
|
1425
|
+
"""
|
|
1426
|
+
# Get the graph schema's nodes from the graph runner.
|
|
1427
|
+
nodes: dict[str, Any] = self.graph_runner._graph_schema.nodes # type: ignore[attr-defined]
|
|
1428
|
+
|
|
1429
|
+
flow_policy_class_path = "rasa.core.policies.flow_policy.FlowPolicy"
|
|
1430
|
+
# Iterate over the nodes and check if any node uses `FlowPolicy`.
|
|
1431
|
+
for node_name, schema_node in nodes.items():
|
|
1432
|
+
if (
|
|
1433
|
+
schema_node.uses is not None
|
|
1434
|
+
and f"{schema_node.uses.__module__}.{schema_node.uses.__name__}"
|
|
1435
|
+
== flow_policy_class_path
|
|
1436
|
+
):
|
|
1437
|
+
return True
|
|
1438
|
+
|
|
1439
|
+
# Return False if no node is found using `FlowPolicy`.
|
|
1440
|
+
return False
|
rasa/core/run.py
CHANGED
|
@@ -9,12 +9,12 @@ from functools import partial
|
|
|
9
9
|
from typing import (
|
|
10
10
|
Any,
|
|
11
11
|
Callable,
|
|
12
|
+
Dict,
|
|
12
13
|
List,
|
|
13
14
|
Optional,
|
|
14
15
|
Text,
|
|
15
16
|
Tuple,
|
|
16
17
|
Union,
|
|
17
|
-
Dict,
|
|
18
18
|
)
|
|
19
19
|
|
|
20
20
|
from sanic import Sanic
|
|
@@ -24,7 +24,6 @@ import rasa.core.utils
|
|
|
24
24
|
import rasa.shared.utils.common
|
|
25
25
|
import rasa.shared.utils.io
|
|
26
26
|
import rasa.utils
|
|
27
|
-
from rasa.utils import licensing
|
|
28
27
|
import rasa.utils.common
|
|
29
28
|
import rasa.utils.io
|
|
30
29
|
from rasa import server, telemetry
|
|
@@ -34,9 +33,11 @@ from rasa.core.agent import Agent
|
|
|
34
33
|
from rasa.core.channels import console
|
|
35
34
|
from rasa.core.channels.channel import InputChannel
|
|
36
35
|
from rasa.core.utils import AvailableEndpoints
|
|
36
|
+
from rasa.nlu.persistor import StorageType
|
|
37
37
|
from rasa.plugin import plugin_manager
|
|
38
38
|
from rasa.shared.exceptions import RasaException
|
|
39
39
|
from rasa.shared.utils.yaml import read_config_file
|
|
40
|
+
from rasa.utils import licensing
|
|
40
41
|
|
|
41
42
|
logger = logging.getLogger() # get the root logger
|
|
42
43
|
|
|
@@ -210,7 +211,7 @@ def serve_application(
|
|
|
210
211
|
jwt_private_key: Optional[Text] = None,
|
|
211
212
|
jwt_method: Optional[Text] = None,
|
|
212
213
|
endpoints: Optional[AvailableEndpoints] = None,
|
|
213
|
-
remote_storage: Optional[
|
|
214
|
+
remote_storage: Optional[StorageType] = None,
|
|
214
215
|
log_file: Optional[Text] = None,
|
|
215
216
|
ssl_certificate: Optional[Text] = None,
|
|
216
217
|
ssl_keyfile: Optional[Text] = None,
|
|
@@ -295,7 +296,7 @@ def serve_application(
|
|
|
295
296
|
async def load_agent_on_start(
|
|
296
297
|
model_path: Text,
|
|
297
298
|
endpoints: AvailableEndpoints,
|
|
298
|
-
remote_storage: Optional[
|
|
299
|
+
remote_storage: Optional[StorageType],
|
|
299
300
|
app: Sanic,
|
|
300
301
|
loop: AbstractEventLoop,
|
|
301
302
|
) -> Agent:
|
rasa/core/tracker_store.py
CHANGED
|
@@ -26,10 +26,10 @@ from typing import (
|
|
|
26
26
|
from boto3.dynamodb.conditions import Key
|
|
27
27
|
from pymongo.collection import Collection
|
|
28
28
|
|
|
29
|
-
import rasa.core.utils as core_utils
|
|
30
29
|
import rasa.shared.utils.cli
|
|
31
30
|
import rasa.shared.utils.common
|
|
32
31
|
import rasa.shared.utils.io
|
|
32
|
+
import rasa.utils.json_utils
|
|
33
33
|
from rasa.plugin import plugin_manager
|
|
34
34
|
from rasa.shared.core.constants import ACTION_LISTEN_NAME
|
|
35
35
|
from rasa.core.brokers.broker import EventBroker
|
|
@@ -705,7 +705,7 @@ class DynamoTrackerStore(TrackerStore, SerializedTrackerAsDict):
|
|
|
705
705
|
|
|
706
706
|
DynamoDB cannot store `float`s, so we'll convert them to `Decimal`s.
|
|
707
707
|
"""
|
|
708
|
-
return
|
|
708
|
+
return rasa.utils.json_utils.replace_floats_with_decimals(
|
|
709
709
|
SerializedTrackerAsDict.serialise_tracker(tracker)
|
|
710
710
|
)
|
|
711
711
|
|
|
@@ -747,12 +747,16 @@ class DynamoTrackerStore(TrackerStore, SerializedTrackerAsDict):
|
|
|
747
747
|
events_with_floats = []
|
|
748
748
|
for dialogue in dialogues:
|
|
749
749
|
if dialogue.get("events"):
|
|
750
|
-
events =
|
|
750
|
+
events = rasa.utils.json_utils.replace_decimals_with_floats(
|
|
751
|
+
dialogue["events"]
|
|
752
|
+
)
|
|
751
753
|
events_with_floats += events
|
|
752
754
|
else:
|
|
753
755
|
events = dialogues[0].get("events", [])
|
|
754
756
|
# `float`s are stored as `Decimal` objects - we need to convert them back
|
|
755
|
-
events_with_floats =
|
|
757
|
+
events_with_floats = rasa.utils.json_utils.replace_decimals_with_floats(
|
|
758
|
+
events
|
|
759
|
+
)
|
|
756
760
|
|
|
757
761
|
if self.domain is None:
|
|
758
762
|
slots = []
|
|
@@ -1688,7 +1688,7 @@ def run_interactive_learning(
|
|
|
1688
1688
|
p = None
|
|
1689
1689
|
|
|
1690
1690
|
app = run.configure_app(port=port, conversation_id="default", enable_api=True)
|
|
1691
|
-
endpoints = AvailableEndpoints.
|
|
1691
|
+
endpoints = AvailableEndpoints.get_instance(server_args.get("endpoints"))
|
|
1692
1692
|
|
|
1693
1693
|
# before_server_start handlers make sure the agent is loaded before the
|
|
1694
1694
|
# interactive learning IO starts
|