rasa-pro 3.8.18__py3-none-any.whl → 3.9.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- README.md +6 -42
- rasa/__main__.py +14 -9
- rasa/anonymization/anonymization_pipeline.py +0 -1
- rasa/anonymization/anonymization_rule_executor.py +3 -3
- rasa/anonymization/utils.py +4 -3
- rasa/api.py +2 -2
- rasa/cli/arguments/default_arguments.py +1 -1
- rasa/cli/arguments/run.py +2 -2
- rasa/cli/arguments/test.py +1 -1
- rasa/cli/arguments/train.py +10 -10
- rasa/cli/e2e_test.py +27 -7
- rasa/cli/export.py +0 -1
- rasa/cli/license.py +3 -3
- rasa/cli/project_templates/calm/actions/action_template.py +1 -1
- rasa/cli/project_templates/calm/config.yml +1 -1
- rasa/cli/project_templates/calm/credentials.yml +1 -1
- rasa/cli/project_templates/calm/data/flows/add_contact.yml +1 -1
- rasa/cli/project_templates/calm/data/flows/remove_contact.yml +1 -1
- rasa/cli/project_templates/calm/domain/add_contact.yml +8 -2
- rasa/cli/project_templates/calm/domain/list_contacts.yml +3 -0
- rasa/cli/project_templates/calm/domain/remove_contact.yml +9 -2
- rasa/cli/project_templates/calm/domain/shared.yml +5 -0
- rasa/cli/project_templates/calm/endpoints.yml +4 -4
- rasa/cli/project_templates/default/actions/actions.py +1 -1
- rasa/cli/project_templates/default/config.yml +5 -5
- rasa/cli/project_templates/default/credentials.yml +1 -1
- rasa/cli/project_templates/default/endpoints.yml +4 -4
- rasa/cli/project_templates/default/tests/test_stories.yml +1 -1
- rasa/cli/project_templates/tutorial/config.yml +1 -1
- rasa/cli/project_templates/tutorial/credentials.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +6 -0
- rasa/cli/project_templates/tutorial/domain.yml +4 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +6 -6
- rasa/cli/run.py +0 -1
- rasa/cli/scaffold.py +3 -2
- rasa/cli/studio/download.py +11 -0
- rasa/cli/studio/studio.py +180 -24
- rasa/cli/studio/upload.py +0 -8
- rasa/cli/telemetry.py +18 -6
- rasa/cli/utils.py +21 -10
- rasa/cli/x.py +3 -2
- rasa/constants.py +1 -1
- rasa/core/actions/action.py +90 -315
- rasa/core/actions/action_exceptions.py +24 -0
- rasa/core/actions/constants.py +3 -0
- rasa/core/actions/custom_action_executor.py +188 -0
- rasa/core/actions/forms.py +11 -7
- rasa/core/actions/grpc_custom_action_executor.py +251 -0
- rasa/core/actions/http_custom_action_executor.py +140 -0
- rasa/core/actions/loops.py +3 -0
- rasa/core/actions/two_stage_fallback.py +1 -1
- rasa/core/agent.py +2 -4
- rasa/core/brokers/pika.py +1 -2
- rasa/core/channels/audiocodes.py +1 -1
- rasa/core/channels/botframework.py +0 -1
- rasa/core/channels/callback.py +0 -1
- rasa/core/channels/console.py +6 -8
- rasa/core/channels/development_inspector.py +1 -1
- rasa/core/channels/facebook.py +0 -3
- rasa/core/channels/hangouts.py +0 -6
- rasa/core/channels/inspector/dist/assets/{arc-5623b6dc.js → arc-b6e548fe.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-685c106a.js → c4Diagram-d0fbc5ce-fa03ac9e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-8cbed007.js → classDiagram-936ed81e-ee67392a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-5889cf12.js → classDiagram-v2-c3cb15f1-9b283fae.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-24c249d7.js → createText-62fc7601-8b6fcc2a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-7dd06a75.js → edges-f2ad444c-22e77f4f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-62c1e54c.js → erDiagram-9d236eb7-60ffc87f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-ce49b86f.js → flowDb-1972c806-9dd802e4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-4067e48f.js → flowDiagram-7ea5b25a-5fa1912f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-1844e5a5.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-59fe4051.js → flowchart-elk-definition-abe16c3d-622a1fd2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-47e3a43b.js → ganttDiagram-9b5ea136-e285a63a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-5a2ac0d9.js → gitGraphDiagram-99d0ae7c-f237bdca.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-dfb8efc4.js → index-2c4b9a3b-4b03d70e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-268a75c0.js → index-a5d3e69d.js} +4 -4
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-b0c470f2.js → infoDiagram-736b4530-72a0fa5f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-2edb829a.js → journeyDiagram-df861f2b-82218c41.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-b6873d69.js → layout-78cff630.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-1efc5781.js → line-5038b469.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-661e9b94.js → linear-c4fc4098.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-2d2e727f.js → mindmap-definition-beec6740-c33c8ea6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-9d3ea93d.js → pieDiagram-dbbf0591-a8d03059.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-06a178a2.js → quadrantDiagram-4d7f4fd6-6a0e56b2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-0bfedffc.js → requirementDiagram-6fc4c22a-2dc7c7bd.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-d76d0a04.js → sankeyDiagram-8f13d901-2360fe39.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-37bb4341.js → sequenceDiagram-b655622a-41b9f9ad.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-f52f7f57.js → stateDiagram-59f0c015-0aad326f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-4a986a20.js → stateDiagram-v2-2b26beab-9847d984.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-7dd9ae12.js → styles-080da4f6-564d890e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-46e1ca14.js → styles-3dcbcfbf-38957613.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-4a97439a.js → styles-9c745c82-f0fc6921.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-823917a3.js → svgDrawCommon-4835440b-ef3c5a77.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-9ea72896.js → timeline-definition-5b62e21b-bf3e91c1.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-b631a8b6.js → xychartDiagram-2b33534f-4d4026c0.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +10 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +4 -7
- rasa/core/channels/inspector/src/helpers/formatters.ts +3 -2
- rasa/core/channels/rest.py +36 -21
- rasa/core/channels/rocketchat.py +0 -1
- rasa/core/channels/socketio.py +1 -1
- rasa/core/channels/telegram.py +3 -3
- rasa/core/channels/webexteams.py +0 -1
- rasa/core/concurrent_lock_store.py +1 -1
- rasa/core/evaluation/marker_base.py +1 -3
- rasa/core/evaluation/marker_stats.py +1 -2
- rasa/core/featurizers/single_state_featurizer.py +3 -26
- rasa/core/featurizers/tracker_featurizers.py +18 -122
- rasa/core/information_retrieval/__init__.py +7 -0
- rasa/core/information_retrieval/faiss.py +9 -4
- rasa/core/information_retrieval/information_retrieval.py +64 -7
- rasa/core/information_retrieval/milvus.py +7 -14
- rasa/core/information_retrieval/qdrant.py +8 -15
- rasa/core/lock_store.py +0 -1
- rasa/core/migrate.py +1 -2
- rasa/core/nlg/callback.py +3 -4
- rasa/core/policies/enterprise_search_policy.py +86 -22
- rasa/core/policies/enterprise_search_prompt_template.jinja2 +4 -41
- rasa/core/policies/enterprise_search_prompt_with_citation_template.jinja2 +60 -0
- rasa/core/policies/flows/flow_executor.py +104 -2
- rasa/core/policies/intentless_policy.py +7 -9
- rasa/core/policies/memoization.py +3 -3
- rasa/core/policies/policy.py +18 -9
- rasa/core/policies/rule_policy.py +8 -11
- rasa/core/policies/ted_policy.py +61 -88
- rasa/core/policies/unexpected_intent_policy.py +8 -17
- rasa/core/processor.py +136 -47
- rasa/core/run.py +41 -25
- rasa/core/secrets_manager/endpoints.py +2 -2
- rasa/core/secrets_manager/vault.py +6 -8
- rasa/core/test.py +3 -5
- rasa/core/tracker_store.py +49 -14
- rasa/core/train.py +1 -3
- rasa/core/training/interactive.py +9 -6
- rasa/core/utils.py +5 -10
- rasa/dialogue_understanding/coexistence/intent_based_router.py +11 -4
- rasa/dialogue_understanding/coexistence/llm_based_router.py +2 -3
- rasa/dialogue_understanding/commands/__init__.py +4 -0
- rasa/dialogue_understanding/commands/can_not_handle_command.py +9 -0
- rasa/dialogue_understanding/commands/cancel_flow_command.py +9 -0
- rasa/dialogue_understanding/commands/change_flow_command.py +38 -0
- rasa/dialogue_understanding/commands/chit_chat_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/clarify_command.py +9 -0
- rasa/dialogue_understanding/commands/correct_slots_command.py +9 -0
- rasa/dialogue_understanding/commands/error_command.py +12 -0
- rasa/dialogue_understanding/commands/handle_code_change_command.py +9 -0
- rasa/dialogue_understanding/commands/human_handoff_command.py +9 -0
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +9 -0
- rasa/dialogue_understanding/commands/noop_command.py +9 -0
- rasa/dialogue_understanding/commands/set_slot_command.py +38 -3
- rasa/dialogue_understanding/commands/skip_question_command.py +9 -0
- rasa/dialogue_understanding/commands/start_flow_command.py +9 -0
- rasa/dialogue_understanding/generator/__init__.py +16 -1
- rasa/dialogue_understanding/generator/command_generator.py +92 -6
- rasa/dialogue_understanding/generator/constants.py +18 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +7 -5
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +467 -0
- rasa/dialogue_understanding/generator/llm_command_generator.py +39 -609
- rasa/dialogue_understanding/generator/multi_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/multi_step/fill_slots_prompt.jinja2 +62 -0
- rasa/dialogue_understanding/generator/multi_step/handle_flows_prompt.jinja2 +38 -0
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +827 -0
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +69 -8
- rasa/dialogue_understanding/generator/single_step/__init__.py +0 -0
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +345 -0
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +36 -31
- rasa/dialogue_understanding/processor/command_processor.py +112 -3
- rasa/e2e_test/constants.py +1 -0
- rasa/e2e_test/e2e_test_case.py +44 -0
- rasa/e2e_test/e2e_test_runner.py +114 -11
- rasa/e2e_test/e2e_test_schema.yml +18 -0
- rasa/engine/caching.py +0 -1
- rasa/engine/graph.py +18 -6
- rasa/engine/recipes/config_files/default_config.yml +3 -3
- rasa/engine/recipes/default_components.py +1 -1
- rasa/engine/recipes/default_recipe.py +4 -5
- rasa/engine/recipes/recipe.py +1 -1
- rasa/engine/runner/dask.py +3 -9
- rasa/engine/storage/local_model_storage.py +0 -2
- rasa/engine/validation.py +179 -145
- rasa/exceptions.py +2 -2
- rasa/graph_components/validators/default_recipe_validator.py +3 -5
- rasa/hooks.py +0 -1
- rasa/model.py +1 -1
- rasa/model_training.py +1 -0
- rasa/nlu/classifiers/diet_classifier.py +33 -52
- rasa/nlu/classifiers/logistic_regression_classifier.py +9 -22
- rasa/nlu/classifiers/sklearn_intent_classifier.py +16 -37
- rasa/nlu/extractors/crf_entity_extractor.py +54 -97
- rasa/nlu/extractors/duckling_entity_extractor.py +1 -1
- rasa/nlu/featurizers/dense_featurizer/convert_featurizer.py +1 -5
- rasa/nlu/featurizers/dense_featurizer/lm_featurizer.py +0 -4
- rasa/nlu/featurizers/featurizer.py +1 -1
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +18 -49
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +26 -64
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +3 -5
- rasa/nlu/persistor.py +68 -26
- rasa/nlu/selectors/response_selector.py +7 -10
- rasa/nlu/test.py +0 -3
- rasa/nlu/utils/hugging_face/registry.py +1 -1
- rasa/nlu/utils/spacy_utils.py +1 -3
- rasa/server.py +22 -7
- rasa/shared/constants.py +12 -1
- rasa/shared/core/command_payload_reader.py +109 -0
- rasa/shared/core/constants.py +4 -5
- rasa/shared/core/domain.py +57 -56
- rasa/shared/core/events.py +4 -7
- rasa/shared/core/flows/flow.py +9 -0
- rasa/shared/core/flows/flows_list.py +12 -0
- rasa/shared/core/flows/steps/action.py +7 -2
- rasa/shared/core/generator.py +12 -11
- rasa/shared/core/slot_mappings.py +315 -24
- rasa/shared/core/slots.py +4 -2
- rasa/shared/core/trackers.py +32 -14
- rasa/shared/core/training_data/loading.py +0 -1
- rasa/shared/core/training_data/story_reader/story_reader.py +3 -3
- rasa/shared/core/training_data/story_reader/yaml_story_reader.py +11 -11
- rasa/shared/core/training_data/story_writer/yaml_story_writer.py +5 -3
- rasa/shared/core/training_data/structures.py +1 -1
- rasa/shared/core/training_data/visualization.py +1 -1
- rasa/shared/data.py +58 -1
- rasa/shared/exceptions.py +36 -2
- rasa/shared/importers/importer.py +1 -2
- rasa/shared/importers/rasa.py +0 -1
- rasa/shared/nlu/constants.py +2 -0
- rasa/shared/nlu/training_data/entities_parser.py +1 -2
- rasa/shared/nlu/training_data/features.py +2 -120
- rasa/shared/nlu/training_data/formats/dialogflow.py +3 -2
- rasa/shared/nlu/training_data/formats/rasa_yaml.py +3 -5
- rasa/shared/nlu/training_data/formats/readerwriter.py +0 -1
- rasa/shared/nlu/training_data/message.py +13 -0
- rasa/shared/nlu/training_data/training_data.py +0 -2
- rasa/shared/providers/openai/session_handler.py +2 -2
- rasa/shared/utils/constants.py +3 -0
- rasa/shared/utils/io.py +11 -1
- rasa/shared/utils/llm.py +1 -2
- rasa/shared/utils/pykwalify_extensions.py +1 -0
- rasa/shared/utils/schemas/domain.yml +3 -0
- rasa/shared/utils/yaml.py +44 -35
- rasa/studio/auth.py +26 -10
- rasa/studio/constants.py +2 -0
- rasa/studio/data_handler.py +114 -107
- rasa/studio/download.py +160 -27
- rasa/studio/results_logger.py +137 -0
- rasa/studio/train.py +6 -7
- rasa/studio/upload.py +159 -134
- rasa/telemetry.py +188 -34
- rasa/tracing/config.py +18 -3
- rasa/tracing/constants.py +26 -2
- rasa/tracing/instrumentation/attribute_extractors.py +50 -41
- rasa/tracing/instrumentation/instrumentation.py +290 -44
- rasa/tracing/instrumentation/intentless_policy_instrumentation.py +7 -5
- rasa/tracing/instrumentation/metrics.py +109 -21
- rasa/tracing/metric_instrument_provider.py +83 -3
- rasa/utils/cli.py +2 -1
- rasa/utils/common.py +1 -1
- rasa/utils/endpoints.py +1 -2
- rasa/utils/io.py +72 -6
- rasa/utils/licensing.py +246 -31
- rasa/utils/ml_utils.py +1 -1
- rasa/utils/tensorflow/data_generator.py +1 -1
- rasa/utils/tensorflow/environment.py +1 -1
- rasa/utils/tensorflow/model_data.py +201 -12
- rasa/utils/tensorflow/model_data_utils.py +499 -500
- rasa/utils/tensorflow/models.py +5 -6
- rasa/utils/tensorflow/rasa_layers.py +15 -15
- rasa/utils/train_utils.py +1 -1
- rasa/utils/url_tools.py +53 -0
- rasa/validator.py +305 -3
- rasa/version.py +1 -1
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/METADATA +25 -61
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/RECORD +276 -259
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-85583a23.js +0 -1
- rasa/utils/tensorflow/feature_array.py +0 -370
- /rasa/dialogue_understanding/generator/{command_prompt_template.jinja2 → single_step/command_prompt_template.jinja2} +0 -0
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/NOTICE +0 -0
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/WHEEL +0 -0
- {rasa_pro-3.8.18.dist-info → rasa_pro-3.9.15.dist-info}/entry_points.txt +0 -0
rasa/core/tracker_store.py
CHANGED
|
@@ -276,10 +276,10 @@ class TrackerStore:
|
|
|
276
276
|
async def retrieve_full_tracker(
|
|
277
277
|
self, conversation_id: Text
|
|
278
278
|
) -> Optional[DialogueStateTracker]:
|
|
279
|
-
"""Retrieve method for fetching all tracker events
|
|
280
|
-
that may be overridden by specific tracker.
|
|
279
|
+
"""Retrieve method for fetching all tracker events.
|
|
281
280
|
|
|
282
|
-
The default implementation
|
|
281
|
+
Fetches events across conversation sessions. The default implementation
|
|
282
|
+
uses `self.retrieve()`.
|
|
283
283
|
|
|
284
284
|
Args:
|
|
285
285
|
conversation_id: The conversation ID to retrieve the tracker for.
|
|
@@ -339,6 +339,28 @@ class TrackerStore:
|
|
|
339
339
|
"""Returns the set of values for the tracker store's primary key."""
|
|
340
340
|
raise NotImplementedError()
|
|
341
341
|
|
|
342
|
+
async def count_conversations(self, after_timestamp: float = 0.0) -> int:
|
|
343
|
+
"""Returns the number of conversations that have occurred after a timestamp.
|
|
344
|
+
|
|
345
|
+
By default, this method returns the number of conversations that
|
|
346
|
+
have occurred after the Unix epoch (i.e. timestamp 0). A conversation
|
|
347
|
+
is considered to have occurred after a timestamp if at least one event
|
|
348
|
+
happened after that timestamp.
|
|
349
|
+
"""
|
|
350
|
+
tracker_keys = await self.keys()
|
|
351
|
+
|
|
352
|
+
conversation_count = 0
|
|
353
|
+
for key in tracker_keys:
|
|
354
|
+
tracker = await self.retrieve(key)
|
|
355
|
+
if tracker is None or not tracker.events:
|
|
356
|
+
continue
|
|
357
|
+
|
|
358
|
+
last_event = tracker.events[-1]
|
|
359
|
+
if last_event.timestamp >= after_timestamp:
|
|
360
|
+
conversation_count += 1
|
|
361
|
+
|
|
362
|
+
return conversation_count
|
|
363
|
+
|
|
342
364
|
def deserialise_tracker(
|
|
343
365
|
self, sender_id: Text, serialised_tracker: Union[Text, bytes]
|
|
344
366
|
) -> Optional[DialogueStateTracker]:
|
|
@@ -930,7 +952,7 @@ def _create_sequence(table_name: Text) -> "Sequence":
|
|
|
930
952
|
"""Creates a sequence object for a specific table name.
|
|
931
953
|
|
|
932
954
|
If using Oracle you will need to create a sequence in your database,
|
|
933
|
-
as described here: https://rasa.com/docs/rasa/tracker-stores#sqltrackerstore
|
|
955
|
+
as described here: https://rasa.com/docs/rasa-pro/production/tracker-stores#sqltrackerstore
|
|
934
956
|
Args:
|
|
935
957
|
table_name: The name of the table, which gets a Sequence assigned
|
|
936
958
|
|
|
@@ -1045,6 +1067,8 @@ class SQLTrackerStore(TrackerStore, SerializedTrackerAsText):
|
|
|
1045
1067
|
from sqlalchemy.orm import DeclarativeBase
|
|
1046
1068
|
|
|
1047
1069
|
class Base(DeclarativeBase):
|
|
1070
|
+
"""Base class for all tracker store tables."""
|
|
1071
|
+
|
|
1048
1072
|
pass
|
|
1049
1073
|
|
|
1050
1074
|
class SQLEvent(Base):
|
|
@@ -1113,7 +1137,6 @@ class SQLTrackerStore(TrackerStore, SerializedTrackerAsText):
|
|
|
1113
1137
|
sqlalchemy.exc.OperationalError,
|
|
1114
1138
|
sqlalchemy.exc.IntegrityError,
|
|
1115
1139
|
) as error:
|
|
1116
|
-
|
|
1117
1140
|
logger.warning(error)
|
|
1118
1141
|
sleep(5)
|
|
1119
1142
|
|
|
@@ -1132,8 +1155,10 @@ class SQLTrackerStore(TrackerStore, SerializedTrackerAsText):
|
|
|
1132
1155
|
login_db: Optional[Text] = None,
|
|
1133
1156
|
query: Optional[Dict] = None,
|
|
1134
1157
|
) -> Union[Text, "URL"]:
|
|
1135
|
-
"""Build an SQLAlchemy `URL` object
|
|
1136
|
-
|
|
1158
|
+
"""Build an SQLAlchemy `URL` object.
|
|
1159
|
+
|
|
1160
|
+
The URL object represents the parameters needed to connect to an
|
|
1161
|
+
SQL database.
|
|
1137
1162
|
|
|
1138
1163
|
Args:
|
|
1139
1164
|
dialect: SQL database type.
|
|
@@ -1260,11 +1285,24 @@ class SQLTrackerStore(TrackerStore, SerializedTrackerAsText):
|
|
|
1260
1285
|
conversation_id, fetch_events_from_all_sessions=True
|
|
1261
1286
|
)
|
|
1262
1287
|
|
|
1288
|
+
async def count_conversations(self, after_timestamp: float = 0.0) -> int:
|
|
1289
|
+
"""Returns the number of conversations that have occurred after a timestamp.
|
|
1290
|
+
|
|
1291
|
+
By default, this method returns the number of conversations that
|
|
1292
|
+
have occurred after the Unix epoch (i.e. timestamp 0).
|
|
1293
|
+
"""
|
|
1294
|
+
with self.session_scope() as session:
|
|
1295
|
+
query = (
|
|
1296
|
+
session.query(self.SQLEvent.sender_id)
|
|
1297
|
+
.distinct()
|
|
1298
|
+
.filter(self.SQLEvent.timestamp >= after_timestamp)
|
|
1299
|
+
)
|
|
1300
|
+
return query.count()
|
|
1301
|
+
|
|
1263
1302
|
async def _retrieve(
|
|
1264
1303
|
self, sender_id: Text, fetch_events_from_all_sessions: bool
|
|
1265
1304
|
) -> Optional[DialogueStateTracker]:
|
|
1266
1305
|
with self.session_scope() as session:
|
|
1267
|
-
|
|
1268
1306
|
serialised_events = self._event_query(
|
|
1269
1307
|
session,
|
|
1270
1308
|
sender_id,
|
|
@@ -1290,6 +1328,7 @@ class SQLTrackerStore(TrackerStore, SerializedTrackerAsText):
|
|
|
1290
1328
|
self, session: "Session", sender_id: Text, fetch_events_from_all_sessions: bool
|
|
1291
1329
|
) -> "Query":
|
|
1292
1330
|
"""Provide the query to retrieve the conversation events for a specific sender.
|
|
1331
|
+
|
|
1293
1332
|
The events are ordered by ID to ensure correct sequence of events.
|
|
1294
1333
|
As `timestamp` is not guaranteed to be unique and low-precision (float), it
|
|
1295
1334
|
cannot be used to order the events.
|
|
@@ -1637,9 +1676,7 @@ class AwaitableTrackerStore(TrackerStore):
|
|
|
1637
1676
|
"""Wrapper to call `retrieve` method of primary tracker store."""
|
|
1638
1677
|
result = self._tracker_store.retrieve(sender_id)
|
|
1639
1678
|
return (
|
|
1640
|
-
await result
|
|
1641
|
-
if isawaitable(result)
|
|
1642
|
-
else result # type: ignore[return-value]
|
|
1679
|
+
await result if isawaitable(result) else result # type: ignore[return-value, misc]
|
|
1643
1680
|
)
|
|
1644
1681
|
|
|
1645
1682
|
async def keys(self) -> Iterable[Text]:
|
|
@@ -1658,7 +1695,5 @@ class AwaitableTrackerStore(TrackerStore):
|
|
|
1658
1695
|
"""Wrapper to call `retrieve_full_tracker` method of primary tracker store."""
|
|
1659
1696
|
result = self._tracker_store.retrieve_full_tracker(conversation_id)
|
|
1660
1697
|
return (
|
|
1661
|
-
await result
|
|
1662
|
-
if isawaitable(result)
|
|
1663
|
-
else result # type: ignore[return-value]
|
|
1698
|
+
await result if isawaitable(result) else result # type: ignore[return-value, misc]
|
|
1664
1699
|
)
|
rasa/core/train.py
CHANGED
|
@@ -34,9 +34,7 @@ async def train_comparison_models(
|
|
|
34
34
|
for policy_config in policy_configs:
|
|
35
35
|
config_name = os.path.splitext(os.path.basename(policy_config))[0]
|
|
36
36
|
logging.info(
|
|
37
|
-
"Starting to train {} round {}/{}"
|
|
38
|
-
" with {}% exclusion"
|
|
39
|
-
"".format(
|
|
37
|
+
"Starting to train {} round {}/{} with {}% exclusion".format(
|
|
40
38
|
config_name, current_run, len(exclusion_percentages), percentage
|
|
41
39
|
)
|
|
42
40
|
)
|
|
@@ -3,6 +3,7 @@ import logging
|
|
|
3
3
|
import os
|
|
4
4
|
import textwrap
|
|
5
5
|
import uuid
|
|
6
|
+
import warnings
|
|
6
7
|
from functools import partial
|
|
7
8
|
from multiprocessing import Process
|
|
8
9
|
from typing import (
|
|
@@ -336,7 +337,7 @@ async def _ask_questions(
|
|
|
336
337
|
|
|
337
338
|
|
|
338
339
|
def _selection_choices_from_intent_prediction(
|
|
339
|
-
predictions: List[Dict[Text, Any]]
|
|
340
|
+
predictions: List[Dict[Text, Any]],
|
|
340
341
|
) -> List[Dict[Text, Any]]:
|
|
341
342
|
"""Given a list of ML predictions create a UI choice list."""
|
|
342
343
|
sorted_intents = sorted(
|
|
@@ -762,7 +763,7 @@ async def _request_export_info() -> Tuple[Text, Text, Text]:
|
|
|
762
763
|
|
|
763
764
|
|
|
764
765
|
def _split_conversation_at_restarts(
|
|
765
|
-
events: List[Dict[Text, Any]]
|
|
766
|
+
events: List[Dict[Text, Any]],
|
|
766
767
|
) -> List[List[Dict[Text, Any]]]:
|
|
767
768
|
"""Split a conversation at restart events.
|
|
768
769
|
|
|
@@ -1600,6 +1601,7 @@ def _serve_application(
|
|
|
1600
1601
|
"""Start a core server and attach the interactive learning IO."""
|
|
1601
1602
|
endpoint = EndpointConfig(url=DEFAULT_SERVER_FORMAT.format("http", port))
|
|
1602
1603
|
|
|
1604
|
+
@app.after_server_start
|
|
1603
1605
|
async def run_interactive_io(running_app: Sanic) -> None:
|
|
1604
1606
|
"""Small wrapper to shut down the server once cmd io is done."""
|
|
1605
1607
|
await record_messages(
|
|
@@ -1613,11 +1615,9 @@ def _serve_application(
|
|
|
1613
1615
|
|
|
1614
1616
|
running_app.stop() # kill the sanic server
|
|
1615
1617
|
|
|
1616
|
-
app.add_task(run_interactive_io)
|
|
1617
|
-
|
|
1618
1618
|
update_sanic_log_level()
|
|
1619
1619
|
|
|
1620
|
-
app.run(host="0.0.0.0", port=port)
|
|
1620
|
+
app.run(host="0.0.0.0", port=port, legacy=True)
|
|
1621
1621
|
|
|
1622
1622
|
return app
|
|
1623
1623
|
|
|
@@ -1626,6 +1626,9 @@ def start_visualization(image_path: Text, port: int) -> None:
|
|
|
1626
1626
|
"""Add routes to serve the conversation visualization files."""
|
|
1627
1627
|
app = Sanic("rasa_interactive")
|
|
1628
1628
|
|
|
1629
|
+
# Reset Sanic warnings filter that allows the triggering of Sanic warnings
|
|
1630
|
+
warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"sanic.*")
|
|
1631
|
+
|
|
1629
1632
|
# noinspection PyUnusedLocal
|
|
1630
1633
|
@app.exception(NotFound)
|
|
1631
1634
|
async def ignore_404s(request: Request, exception: Exception) -> HTTPResponse:
|
|
@@ -1647,7 +1650,7 @@ def start_visualization(image_path: Text, port: int) -> None:
|
|
|
1647
1650
|
|
|
1648
1651
|
update_sanic_log_level()
|
|
1649
1652
|
|
|
1650
|
-
app.run(host="0.0.0.0", port=port, access_log=False)
|
|
1653
|
+
app.run(host="0.0.0.0", port=port, access_log=False, legacy=True)
|
|
1651
1654
|
|
|
1652
1655
|
|
|
1653
1656
|
def run_interactive_learning(
|
rasa/core/utils.py
CHANGED
|
@@ -73,8 +73,9 @@ def one_hot(hot_idx: int, length: int, dtype: Optional[Text] = None) -> np.ndarr
|
|
|
73
73
|
"""
|
|
74
74
|
if hot_idx >= length:
|
|
75
75
|
raise ValueError(
|
|
76
|
-
"Can't create one hot. Index '{}' is out "
|
|
77
|
-
|
|
76
|
+
"Can't create one hot. Index '{}' is out of range (length '{}')".format(
|
|
77
|
+
hot_idx, length
|
|
78
|
+
)
|
|
78
79
|
)
|
|
79
80
|
r = np.zeros(length, dtype)
|
|
80
81
|
r[hot_idx] = 1
|
|
@@ -159,12 +160,6 @@ def is_limit_reached(num_messages: int, limit: Optional[int]) -> bool:
|
|
|
159
160
|
return limit is not None and num_messages >= limit
|
|
160
161
|
|
|
161
162
|
|
|
162
|
-
def file_as_bytes(path: Text) -> bytes:
|
|
163
|
-
"""Read in a file as a byte array."""
|
|
164
|
-
with open(path, "rb") as f:
|
|
165
|
-
return f.read()
|
|
166
|
-
|
|
167
|
-
|
|
168
163
|
class AvailableEndpoints:
|
|
169
164
|
"""Collection of configured endpoints."""
|
|
170
165
|
|
|
@@ -216,7 +211,7 @@ class AvailableEndpoints:
|
|
|
216
211
|
|
|
217
212
|
|
|
218
213
|
def read_endpoints_from_path(
|
|
219
|
-
endpoints_path: Optional[Union[Path, Text]] = None
|
|
214
|
+
endpoints_path: Optional[Union[Path, Text]] = None,
|
|
220
215
|
) -> AvailableEndpoints:
|
|
221
216
|
"""Get `AvailableEndpoints` object from specified path.
|
|
222
217
|
|
|
@@ -283,7 +278,7 @@ def replace_decimals_with_floats(obj: Any) -> Any:
|
|
|
283
278
|
|
|
284
279
|
|
|
285
280
|
def _lock_store_is_multi_worker_compatible(
|
|
286
|
-
lock_store: Union[EndpointConfig, LockStore, None]
|
|
281
|
+
lock_store: Union[EndpointConfig, LockStore, None],
|
|
287
282
|
) -> bool:
|
|
288
283
|
if isinstance(lock_store, InMemoryLockStore):
|
|
289
284
|
return False
|
|
@@ -19,6 +19,7 @@ from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
|
19
19
|
from rasa.engine.storage.resource import Resource
|
|
20
20
|
from rasa.engine.storage.storage import ModelStorage
|
|
21
21
|
from rasa.shared.constants import ROUTE_TO_CALM_SLOT
|
|
22
|
+
from rasa.shared.core.domain import Domain
|
|
22
23
|
from rasa.shared.core.flows import FlowsList
|
|
23
24
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
24
25
|
from rasa.shared.exceptions import InvalidConfigException
|
|
@@ -111,6 +112,7 @@ class IntentBasedRouter(GraphComponent):
|
|
|
111
112
|
messages: List[Message],
|
|
112
113
|
flows: FlowsList,
|
|
113
114
|
tracker: Optional[DialogueStateTracker] = None,
|
|
115
|
+
domain: Optional[Domain] = None,
|
|
114
116
|
) -> List[Message]:
|
|
115
117
|
"""Process a list of messages."""
|
|
116
118
|
if tracker is None:
|
|
@@ -118,7 +120,7 @@ class IntentBasedRouter(GraphComponent):
|
|
|
118
120
|
return messages
|
|
119
121
|
|
|
120
122
|
for message in messages:
|
|
121
|
-
commands = await self.predict_commands(message, flows, tracker)
|
|
123
|
+
commands = await self.predict_commands(message, flows, tracker, domain)
|
|
122
124
|
commands_dicts = [command.as_dict() for command in commands]
|
|
123
125
|
message.set(COMMANDS, commands_dicts, add_to_output=True)
|
|
124
126
|
|
|
@@ -129,6 +131,7 @@ class IntentBasedRouter(GraphComponent):
|
|
|
129
131
|
message: Message,
|
|
130
132
|
flows: FlowsList,
|
|
131
133
|
tracker: DialogueStateTracker,
|
|
134
|
+
domain: Optional[Domain] = None,
|
|
132
135
|
) -> List[Command]:
|
|
133
136
|
if not tracker.has_coexistence_routing_slot:
|
|
134
137
|
raise InvalidConfigException(
|
|
@@ -144,8 +147,8 @@ class IntentBasedRouter(GraphComponent):
|
|
|
144
147
|
)
|
|
145
148
|
return commands
|
|
146
149
|
elif route_session_to_calm is True:
|
|
147
|
-
# don't set any commands so that
|
|
148
|
-
# and can predict the actual commands.
|
|
150
|
+
# don't set any commands so that a `LLMBasedCommandGenerator` is
|
|
151
|
+
# triggered and can predict the actual commands.
|
|
149
152
|
return []
|
|
150
153
|
else:
|
|
151
154
|
# If the session is assigned to DM1 add a `NoopCommand` to silence
|
|
@@ -156,7 +159,11 @@ class IntentBasedRouter(GraphComponent):
|
|
|
156
159
|
self, message: Message, tracker: DialogueStateTracker, flows: FlowsList
|
|
157
160
|
) -> bool:
|
|
158
161
|
"""Check if the intent is part of a nlu trigger."""
|
|
159
|
-
commands = NLUCommandAdapter.convert_nlu_to_commands(
|
|
162
|
+
commands = NLUCommandAdapter.convert_nlu_to_commands(
|
|
163
|
+
message,
|
|
164
|
+
tracker,
|
|
165
|
+
flows,
|
|
166
|
+
)
|
|
160
167
|
return len(commands) > 0
|
|
161
168
|
|
|
162
169
|
def _generate_command_using_intent(
|
|
@@ -13,7 +13,7 @@ from rasa.dialogue_understanding.coexistence.constants import (
|
|
|
13
13
|
)
|
|
14
14
|
from rasa.dialogue_understanding.commands import Command, SetSlotCommand
|
|
15
15
|
from rasa.dialogue_understanding.commands.noop_command import NoopCommand
|
|
16
|
-
from rasa.dialogue_understanding.generator.
|
|
16
|
+
from rasa.dialogue_understanding.generator.constants import LLM_CONFIG_KEY
|
|
17
17
|
from rasa.engine.graph import ExecutionContext, GraphComponent
|
|
18
18
|
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
|
|
19
19
|
from rasa.engine.storage.resource import Resource
|
|
@@ -196,7 +196,7 @@ class LLMBasedRouter(GraphComponent):
|
|
|
196
196
|
structlogger.info("llm_based_router.predicated_commands", commands=commands)
|
|
197
197
|
return commands
|
|
198
198
|
elif route_session_to_calm is True:
|
|
199
|
-
# don't set any commands so that
|
|
199
|
+
# don't set any commands so that a `LLMBasedCommandGenerator` is triggered
|
|
200
200
|
# and can predict the actual commands.
|
|
201
201
|
return []
|
|
202
202
|
else:
|
|
@@ -231,7 +231,6 @@ class LLMBasedRouter(GraphComponent):
|
|
|
231
231
|
return [SetSlotCommand(ROUTE_TO_CALM_SLOT, False)]
|
|
232
232
|
|
|
233
233
|
def render_template(self, message: Message) -> str:
|
|
234
|
-
|
|
235
234
|
inputs = {
|
|
236
235
|
"user_message": message.get(TEXT),
|
|
237
236
|
f"{CALM_ENTRY}_{STICKY}": self.config[CALM_ENTRY][STICKY],
|
|
@@ -26,6 +26,8 @@ from rasa.dialogue_understanding.commands.correct_slots_command import (
|
|
|
26
26
|
CorrectSlotsCommand,
|
|
27
27
|
CorrectedSlot,
|
|
28
28
|
)
|
|
29
|
+
from rasa.dialogue_understanding.commands.noop_command import NoopCommand
|
|
30
|
+
from rasa.dialogue_understanding.commands.change_flow_command import ChangeFlowCommand
|
|
29
31
|
|
|
30
32
|
__all__ = [
|
|
31
33
|
"Command",
|
|
@@ -42,4 +44,6 @@ __all__ = [
|
|
|
42
44
|
"CorrectSlotsCommand",
|
|
43
45
|
"CorrectedSlot",
|
|
44
46
|
"ErrorCommand",
|
|
47
|
+
"NoopCommand",
|
|
48
|
+
"ChangeFlowCommand",
|
|
45
49
|
]
|
|
@@ -59,3 +59,12 @@ class CannotHandleCommand(Command):
|
|
|
59
59
|
else:
|
|
60
60
|
stack.push(CannotHandlePatternFlowStackFrame())
|
|
61
61
|
return tracker.create_stack_updated_events(stack)
|
|
62
|
+
|
|
63
|
+
def __hash__(self) -> int:
|
|
64
|
+
return hash(self.reason)
|
|
65
|
+
|
|
66
|
+
def __eq__(self, other: object) -> bool:
|
|
67
|
+
if not isinstance(other, CannotHandleCommand):
|
|
68
|
+
return False
|
|
69
|
+
|
|
70
|
+
return other.reason == self.reason
|
|
@@ -114,3 +114,12 @@ class CancelFlowCommand(Command):
|
|
|
114
114
|
applied_events.append(FlowCancelled(user_frame.flow_id, user_frame.step_id))
|
|
115
115
|
|
|
116
116
|
return applied_events + tracker.create_stack_updated_events(stack)
|
|
117
|
+
|
|
118
|
+
def __hash__(self) -> int:
|
|
119
|
+
return hash(self.command())
|
|
120
|
+
|
|
121
|
+
def __eq__(self, other: object) -> bool:
|
|
122
|
+
if not isinstance(other, CancelFlowCommand):
|
|
123
|
+
return False
|
|
124
|
+
|
|
125
|
+
return True
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Dict, List
|
|
5
|
+
from rasa.dialogue_understanding.commands import Command
|
|
6
|
+
from rasa.shared.core.events import Event
|
|
7
|
+
from rasa.shared.core.flows import FlowsList
|
|
8
|
+
from rasa.shared.core.trackers import DialogueStateTracker
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
|
|
12
|
+
class ChangeFlowCommand(Command):
|
|
13
|
+
"""A command to indicate a change of flows was requested by the command
|
|
14
|
+
generator."""
|
|
15
|
+
|
|
16
|
+
@classmethod
|
|
17
|
+
def command(cls) -> str:
|
|
18
|
+
"""Returns the command type."""
|
|
19
|
+
return "change_flow"
|
|
20
|
+
|
|
21
|
+
@classmethod
|
|
22
|
+
def from_dict(cls, data: Dict[str, Any]) -> ChangeFlowCommand:
|
|
23
|
+
"""Converts the dictionary to a command.
|
|
24
|
+
|
|
25
|
+
Returns:
|
|
26
|
+
The converted dictionary.
|
|
27
|
+
"""
|
|
28
|
+
return ChangeFlowCommand()
|
|
29
|
+
|
|
30
|
+
def run_command_on_tracker(
|
|
31
|
+
self,
|
|
32
|
+
tracker: DialogueStateTracker,
|
|
33
|
+
all_flows: FlowsList,
|
|
34
|
+
original_tracker: DialogueStateTracker,
|
|
35
|
+
) -> List[Event]:
|
|
36
|
+
# the change flow command is not actually pushing anything to the tracker,
|
|
37
|
+
# but it is predicted by the MultiStepLLMCommandGenerator and used internally
|
|
38
|
+
return []
|
|
@@ -46,3 +46,12 @@ class ChitChatAnswerCommand(FreeFormAnswerCommand):
|
|
|
46
46
|
stack = tracker.stack
|
|
47
47
|
stack.push(ChitchatPatternFlowStackFrame())
|
|
48
48
|
return tracker.create_stack_updated_events(stack)
|
|
49
|
+
|
|
50
|
+
def __hash__(self) -> int:
|
|
51
|
+
return hash(self.command())
|
|
52
|
+
|
|
53
|
+
def __eq__(self, other: object) -> bool:
|
|
54
|
+
if not isinstance(other, ChitChatAnswerCommand):
|
|
55
|
+
return False
|
|
56
|
+
|
|
57
|
+
return True
|
|
@@ -75,3 +75,12 @@ class ClarifyCommand(Command):
|
|
|
75
75
|
names = [flow.readable_name() for flow in relevant_flows if flow is not None]
|
|
76
76
|
stack.push(ClarifyPatternFlowStackFrame(names=names))
|
|
77
77
|
return tracker.create_stack_updated_events(stack)
|
|
78
|
+
|
|
79
|
+
def __hash__(self) -> int:
|
|
80
|
+
return hash(tuple(self.options))
|
|
81
|
+
|
|
82
|
+
def __eq__(self, other: object) -> bool:
|
|
83
|
+
if not isinstance(other, ClarifyCommand):
|
|
84
|
+
return False
|
|
85
|
+
|
|
86
|
+
return other.options == self.options
|
|
@@ -286,3 +286,12 @@ class CorrectSlotsCommand(Command):
|
|
|
286
286
|
|
|
287
287
|
stack.push(correction_frame, index=insertion_index)
|
|
288
288
|
return tracker.create_stack_updated_events(stack)
|
|
289
|
+
|
|
290
|
+
def __hash__(self) -> int:
|
|
291
|
+
return hash(self.command())
|
|
292
|
+
|
|
293
|
+
def __eq__(self, other: object) -> bool:
|
|
294
|
+
if not isinstance(other, CorrectSlotsCommand):
|
|
295
|
+
return False
|
|
296
|
+
|
|
297
|
+
return True
|
|
@@ -65,3 +65,15 @@ class ErrorCommand(Command):
|
|
|
65
65
|
)
|
|
66
66
|
)
|
|
67
67
|
return tracker.create_stack_updated_events(stack)
|
|
68
|
+
|
|
69
|
+
def __hash__(self) -> int:
|
|
70
|
+
hashed = hash(self.error_type)
|
|
71
|
+
if self.info:
|
|
72
|
+
hashed += hash(str(self.info))
|
|
73
|
+
return hashed
|
|
74
|
+
|
|
75
|
+
def __eq__(self, other: object) -> bool:
|
|
76
|
+
if not isinstance(other, ErrorCommand):
|
|
77
|
+
return False
|
|
78
|
+
|
|
79
|
+
return other.error_type == self.error_type and other.info == self.info
|
|
@@ -62,3 +62,12 @@ class HandleCodeChangeCommand(Command):
|
|
|
62
62
|
|
|
63
63
|
stack.push(CodeChangeFlowStackFrame())
|
|
64
64
|
return tracker.create_stack_updated_events(stack)
|
|
65
|
+
|
|
66
|
+
def __hash__(self) -> int:
|
|
67
|
+
return hash(self.command())
|
|
68
|
+
|
|
69
|
+
def __eq__(self, other: object) -> bool:
|
|
70
|
+
if not isinstance(other, HandleCodeChangeCommand):
|
|
71
|
+
return False
|
|
72
|
+
|
|
73
|
+
return True
|
|
@@ -55,3 +55,12 @@ class HumanHandoffCommand(Command):
|
|
|
55
55
|
"command_executor.human_handoff.pushed_to_stack", command=self
|
|
56
56
|
)
|
|
57
57
|
return tracker.create_stack_updated_events(stack)
|
|
58
|
+
|
|
59
|
+
def __hash__(self) -> int:
|
|
60
|
+
return hash(self.command())
|
|
61
|
+
|
|
62
|
+
def __eq__(self, other: object) -> bool:
|
|
63
|
+
if not isinstance(other, HumanHandoffCommand):
|
|
64
|
+
return False
|
|
65
|
+
|
|
66
|
+
return True
|
|
@@ -46,3 +46,12 @@ class KnowledgeAnswerCommand(FreeFormAnswerCommand):
|
|
|
46
46
|
stack = tracker.stack
|
|
47
47
|
stack.push(SearchPatternFlowStackFrame())
|
|
48
48
|
return tracker.create_stack_updated_events(stack)
|
|
49
|
+
|
|
50
|
+
def __hash__(self) -> int:
|
|
51
|
+
return hash(self.command())
|
|
52
|
+
|
|
53
|
+
def __eq__(self, other: object) -> bool:
|
|
54
|
+
if not isinstance(other, KnowledgeAnswerCommand):
|
|
55
|
+
return False
|
|
56
|
+
|
|
57
|
+
return True
|
|
@@ -43,3 +43,12 @@ class NoopCommand(Command):
|
|
|
43
43
|
The events to apply to the tracker.
|
|
44
44
|
"""
|
|
45
45
|
return []
|
|
46
|
+
|
|
47
|
+
def __hash__(self) -> int:
|
|
48
|
+
return hash(self.command())
|
|
49
|
+
|
|
50
|
+
def __eq__(self, other: object) -> bool:
|
|
51
|
+
if not isinstance(other, NoopCommand):
|
|
52
|
+
return False
|
|
53
|
+
|
|
54
|
+
return True
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
|
+
from enum import Enum
|
|
4
5
|
from typing import Any, Dict, List
|
|
5
6
|
|
|
6
7
|
import structlog
|
|
@@ -15,10 +16,22 @@ from rasa.shared.constants import ROUTE_TO_CALM_SLOT
|
|
|
15
16
|
from rasa.shared.core.events import Event, SlotSet
|
|
16
17
|
from rasa.shared.core.flows import FlowsList
|
|
17
18
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
19
|
+
from rasa.shared.nlu.constants import SET_SLOT_COMMAND
|
|
18
20
|
|
|
19
21
|
structlogger = structlog.get_logger()
|
|
20
22
|
|
|
21
23
|
|
|
24
|
+
class SetSlotExtractor(Enum):
|
|
25
|
+
"""The extractors that can set a slot."""
|
|
26
|
+
|
|
27
|
+
LLM = "LLM"
|
|
28
|
+
COMMAND_PAYLOAD_READER = "CommandPayloadReader"
|
|
29
|
+
NLU = "NLU"
|
|
30
|
+
|
|
31
|
+
def __str__(self) -> str:
|
|
32
|
+
return self.value
|
|
33
|
+
|
|
34
|
+
|
|
22
35
|
def get_flows_predicted_to_start_from_tracker(
|
|
23
36
|
tracker: DialogueStateTracker,
|
|
24
37
|
) -> List[str]:
|
|
@@ -45,11 +58,12 @@ class SetSlotCommand(Command):
|
|
|
45
58
|
|
|
46
59
|
name: str
|
|
47
60
|
value: Any
|
|
61
|
+
extractor: str = SetSlotExtractor.LLM.value
|
|
48
62
|
|
|
49
63
|
@classmethod
|
|
50
64
|
def command(cls) -> str:
|
|
51
65
|
"""Returns the command type."""
|
|
52
|
-
return
|
|
66
|
+
return SET_SLOT_COMMAND
|
|
53
67
|
|
|
54
68
|
@classmethod
|
|
55
69
|
def from_dict(cls, data: Dict[str, Any]) -> SetSlotCommand:
|
|
@@ -59,7 +73,11 @@ class SetSlotCommand(Command):
|
|
|
59
73
|
The converted dictionary.
|
|
60
74
|
"""
|
|
61
75
|
try:
|
|
62
|
-
return SetSlotCommand(
|
|
76
|
+
return SetSlotCommand(
|
|
77
|
+
name=data["name"],
|
|
78
|
+
value=data["value"],
|
|
79
|
+
extractor=data.get("extractor", SetSlotExtractor.LLM.value),
|
|
80
|
+
)
|
|
63
81
|
except KeyError as e:
|
|
64
82
|
raise ValueError(f"Missing key when parsing SetSlotCommand: {e}") from e
|
|
65
83
|
|
|
@@ -106,7 +124,15 @@ class SetSlotCommand(Command):
|
|
|
106
124
|
if isinstance(top_frame, CollectInformationPatternFlowStackFrame):
|
|
107
125
|
slots_of_active_flow.add(top_frame.collect)
|
|
108
126
|
|
|
109
|
-
if
|
|
127
|
+
if (
|
|
128
|
+
self.name not in slots_of_active_flow
|
|
129
|
+
and self.name != ROUTE_TO_CALM_SLOT
|
|
130
|
+
and self.extractor
|
|
131
|
+
in {
|
|
132
|
+
SetSlotExtractor.LLM.value,
|
|
133
|
+
SetSlotExtractor.COMMAND_PAYLOAD_READER.value,
|
|
134
|
+
}
|
|
135
|
+
):
|
|
110
136
|
# Get the other predicted flows from the most recent message on the tracker.
|
|
111
137
|
predicted_flows = get_flows_predicted_to_start_from_tracker(tracker)
|
|
112
138
|
use_slot_fill = any(
|
|
@@ -123,3 +149,12 @@ class SetSlotCommand(Command):
|
|
|
123
149
|
|
|
124
150
|
structlogger.debug("command_executor.set_slot", command=self)
|
|
125
151
|
return [SlotSet(self.name, slot.coerce_value(self.value))]
|
|
152
|
+
|
|
153
|
+
def __hash__(self) -> int:
|
|
154
|
+
return hash(self.value) + hash(self.name)
|
|
155
|
+
|
|
156
|
+
def __eq__(self, other: object) -> bool:
|
|
157
|
+
if not isinstance(other, SetSlotCommand):
|
|
158
|
+
return False
|
|
159
|
+
|
|
160
|
+
return other.value == self.value and other.name == self.name
|
|
@@ -64,3 +64,12 @@ class SkipQuestionCommand(Command):
|
|
|
64
64
|
|
|
65
65
|
stack.push(SkipQuestionPatternFlowStackFrame())
|
|
66
66
|
return tracker.create_stack_updated_events(stack)
|
|
67
|
+
|
|
68
|
+
def __hash__(self) -> int:
|
|
69
|
+
return hash(self.command())
|
|
70
|
+
|
|
71
|
+
def __eq__(self, other: object) -> bool:
|
|
72
|
+
if not isinstance(other, SkipQuestionCommand):
|
|
73
|
+
return False
|
|
74
|
+
|
|
75
|
+
return True
|
|
@@ -96,3 +96,12 @@ class StartFlowCommand(Command):
|
|
|
96
96
|
structlogger.debug("command_executor.start_flow", command=self)
|
|
97
97
|
stack.push(UserFlowStackFrame(flow_id=self.flow, frame_type=frame_type))
|
|
98
98
|
return applied_events + tracker.create_stack_updated_events(stack)
|
|
99
|
+
|
|
100
|
+
def __hash__(self) -> int:
|
|
101
|
+
return hash(self.flow)
|
|
102
|
+
|
|
103
|
+
def __eq__(self, other: object) -> bool:
|
|
104
|
+
if not isinstance(other, StartFlowCommand):
|
|
105
|
+
return False
|
|
106
|
+
|
|
107
|
+
return other.flow == self.flow
|
|
@@ -1,6 +1,21 @@
|
|
|
1
1
|
from rasa.dialogue_understanding.generator.command_generator import CommandGenerator
|
|
2
|
+
from rasa.dialogue_understanding.generator.llm_based_command_generator import (
|
|
3
|
+
LLMBasedCommandGenerator,
|
|
4
|
+
)
|
|
2
5
|
from rasa.dialogue_understanding.generator.llm_command_generator import (
|
|
3
6
|
LLMCommandGenerator,
|
|
4
7
|
)
|
|
8
|
+
from rasa.dialogue_understanding.generator.multi_step.multi_step_llm_command_generator import ( # noqa: E501
|
|
9
|
+
MultiStepLLMCommandGenerator,
|
|
10
|
+
)
|
|
11
|
+
from rasa.dialogue_understanding.generator.single_step.single_step_llm_command_generator import ( # noqa: E501
|
|
12
|
+
SingleStepLLMCommandGenerator,
|
|
13
|
+
)
|
|
5
14
|
|
|
6
|
-
__all__ = [
|
|
15
|
+
__all__ = [
|
|
16
|
+
"CommandGenerator",
|
|
17
|
+
"LLMBasedCommandGenerator",
|
|
18
|
+
"LLMCommandGenerator",
|
|
19
|
+
"MultiStepLLMCommandGenerator",
|
|
20
|
+
"SingleStepLLMCommandGenerator",
|
|
21
|
+
]
|