rasa-pro 3.13.0.dev3__py3-none-any.whl → 3.13.0.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/__main__.py +3 -1
- rasa/api.py +4 -0
- rasa/cli/arguments/default_arguments.py +13 -1
- rasa/cli/arguments/train.py +2 -0
- rasa/cli/evaluate.py +1 -1
- rasa/cli/export.py +2 -2
- rasa/cli/inspect.py +8 -4
- rasa/cli/project_templates/default/config.yml +5 -32
- rasa/cli/project_templates/{calm → default}/e2e_tests/cancelations/user_cancels_during_a_correction.yml +1 -1
- rasa/cli/project_templates/{calm → default}/e2e_tests/cancelations/user_changes_mind_on_a_whim.yml +1 -1
- rasa/cli/project_templates/{calm → default}/e2e_tests/corrections/user_corrects_contact_handle.yml +1 -1
- rasa/cli/project_templates/{calm → default}/e2e_tests/corrections/user_corrects_contact_name.yml +1 -1
- rasa/cli/project_templates/{calm → default}/e2e_tests/happy_paths/user_adds_contact_to_their_list.yml +1 -1
- rasa/cli/project_templates/{calm → default}/e2e_tests/happy_paths/user_lists_contacts.yml +1 -1
- rasa/cli/project_templates/{calm → default}/e2e_tests/happy_paths/user_removes_contact.yml +1 -1
- rasa/cli/project_templates/{calm → default}/e2e_tests/happy_paths/user_removes_contact_from_list.yml +1 -1
- rasa/cli/project_templates/default/endpoints.yml +18 -2
- rasa/cli/scaffold.py +3 -4
- rasa/cli/studio/download.py +1 -1
- rasa/cli/studio/upload.py +0 -6
- rasa/cli/train.py +1 -0
- rasa/constants.py +2 -0
- rasa/core/agent.py +2 -2
- rasa/core/brokers/kafka.py +4 -0
- rasa/core/brokers/pika.py +4 -0
- rasa/core/brokers/sql.py +1 -1
- rasa/core/channels/channel.py +68 -5
- rasa/core/channels/inspector/.eslintrc.cjs +12 -6
- rasa/core/channels/inspector/.prettierrc +5 -0
- rasa/core/channels/inspector/README.md +10 -4
- rasa/core/channels/inspector/dist/assets/{arc-c7691751.js → arc-c4b064fc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{blockDiagram-38ab4fdb-ab99dff7.js → blockDiagram-38ab4fdb-215b5026.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-3d4e48cf-08c35a6b.js → c4Diagram-3d4e48cf-2b54a0a3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/channel-3730f5fd.js +1 -0
- rasa/core/channels/inspector/dist/assets/{classDiagram-70f12bd4-9e9c71c9.js → classDiagram-70f12bd4-daacea5f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-f2320105-15e7e2bf.js → classDiagram-v2-f2320105-930d4dc2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/clone-e847561e.js +1 -0
- rasa/core/channels/inspector/dist/assets/{createText-2e5e7dd3-9c105cb1.js → createText-2e5e7dd3-83c206ba.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-e0da2a9e-77e89e48.js → edges-e0da2a9e-b0eb01d0.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9861fffd-7a011646.js → erDiagram-9861fffd-17586500.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-956e92f1-b6f105ac.js → flowDb-956e92f1-be2a1776.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-66a62f08-ce4f18c2.js → flowDiagram-66a62f08-c2120ebd.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-efbbfe00.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-4a651766-cb5f6da4.js → flowchart-elk-definition-4a651766-a6ab5c48.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-c361ad54-e4d19e28.js → ganttDiagram-c361ad54-ef613457.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-72cf32ee-727b1c33.js → gitGraphDiagram-72cf32ee-d59185b3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{graph-6e2ab9a7.js → graph-0f155405.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-3862675e-84ec700f.js → index-3862675e-d5f1d1b7.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-098a1a24.js → index-47737d3a.js} +162 -149
- rasa/core/channels/inspector/dist/assets/{infoDiagram-f8f76790-78dda442.js → infoDiagram-f8f76790-b07d141f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-49397b02-f1cc6dd1.js → journeyDiagram-49397b02-1936d429.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-d98dcd0c.js → layout-dde8d0f3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-838e3d82.js → line-0c2c7ee0.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-eae72406.js → linear-35dd89a4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-fc14e90a-c96fd84b.js → mindmap-definition-fc14e90a-56192851.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-8a3498a8-c936d4e2.js → pieDiagram-8a3498a8-fc21ed78.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-120e2f19-b338eb8f.js → quadrantDiagram-120e2f19-25e98518.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-deff3bca-c6b6c0d5.js → requirementDiagram-deff3bca-546ff1f5.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-04a897e0-b9372e19.js → sankeyDiagram-04a897e0-02d8b82d.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-704730f1-479e0a3f.js → sequenceDiagram-704730f1-3ca5a92e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-587899a1-fd26eebc.js → stateDiagram-587899a1-128ea07c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-d93cdb3a-3233e0ae.js → stateDiagram-v2-d93cdb3a-95f290af.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-6aaf32cf-1fdd392b.js → styles-6aaf32cf-4984898a.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9a916d00-6d7bfa1b.js → styles-9a916d00-1bf266ba.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-c10674c1-f86aab11.js → styles-c10674c1-60521c63.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-08f97a94-e3e49d7a.js → svgDrawCommon-08f97a94-a25b6e12.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-85554ec2-6fe08b4d.js → timeline-definition-85554ec2-0fc086bf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-e933f94c-c2e06fd6.js → xychartDiagram-e933f94c-44ee592e.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/package.json +3 -1
- rasa/core/channels/inspector/src/App.tsx +92 -90
- rasa/core/channels/inspector/src/components/Chat.tsx +61 -36
- rasa/core/channels/inspector/src/components/DiagramFlow.tsx +40 -43
- rasa/core/channels/inspector/src/components/DialogueInformation.tsx +57 -57
- rasa/core/channels/inspector/src/components/DialogueStack.tsx +36 -27
- rasa/core/channels/inspector/src/components/ExpandIcon.tsx +4 -4
- rasa/core/channels/inspector/src/components/FullscreenButton.tsx +7 -7
- rasa/core/channels/inspector/src/components/LoadingSpinner.tsx +28 -12
- rasa/core/channels/inspector/src/components/NoActiveFlow.tsx +9 -9
- rasa/core/channels/inspector/src/components/RasaLogo.tsx +5 -5
- rasa/core/channels/inspector/src/components/RecruitmentPanel.tsx +55 -60
- rasa/core/channels/inspector/src/components/SaraDiagrams.tsx +5 -5
- rasa/core/channels/inspector/src/components/Slots.tsx +22 -22
- rasa/core/channels/inspector/src/components/Welcome.tsx +28 -31
- rasa/core/channels/inspector/src/helpers/audio/audiostream.ts +245 -0
- rasa/core/channels/inspector/src/helpers/audio/microphone-processor.js +12 -0
- rasa/core/channels/inspector/src/helpers/audio/playback-processor.js +36 -0
- rasa/core/channels/inspector/src/helpers/conversation.ts +16 -0
- rasa/core/channels/inspector/src/helpers/formatters.test.ts +181 -181
- rasa/core/channels/inspector/src/helpers/formatters.ts +111 -111
- rasa/core/channels/inspector/src/helpers/utils.ts +78 -61
- rasa/core/channels/inspector/src/main.tsx +8 -8
- rasa/core/channels/inspector/src/theme/Button/Button.ts +8 -8
- rasa/core/channels/inspector/src/theme/Heading/Heading.ts +7 -7
- rasa/core/channels/inspector/src/theme/Input/Input.ts +9 -9
- rasa/core/channels/inspector/src/theme/Link/Link.ts +6 -6
- rasa/core/channels/inspector/src/theme/Modal/Modal.ts +13 -13
- rasa/core/channels/inspector/src/theme/Table/Table.tsx +10 -10
- rasa/core/channels/inspector/src/theme/Tooltip/Tooltip.ts +5 -5
- rasa/core/channels/inspector/src/theme/base/breakpoints.ts +7 -7
- rasa/core/channels/inspector/src/theme/base/colors.ts +64 -64
- rasa/core/channels/inspector/src/theme/base/fonts/fontFaces.css +21 -18
- rasa/core/channels/inspector/src/theme/base/radii.ts +8 -8
- rasa/core/channels/inspector/src/theme/base/shadows.ts +5 -5
- rasa/core/channels/inspector/src/theme/base/sizes.ts +5 -5
- rasa/core/channels/inspector/src/theme/base/space.ts +12 -12
- rasa/core/channels/inspector/src/theme/base/styles.ts +5 -5
- rasa/core/channels/inspector/src/theme/base/typography.ts +12 -12
- rasa/core/channels/inspector/src/theme/base/zIndices.ts +3 -3
- rasa/core/channels/inspector/src/theme/index.ts +38 -38
- rasa/core/channels/inspector/src/types.ts +56 -50
- rasa/core/channels/inspector/yarn.lock +5 -0
- rasa/core/channels/voice_ready/audiocodes.py +75 -32
- rasa/core/channels/voice_ready/twilio_voice.py +48 -1
- rasa/core/channels/voice_stream/tts/azure.py +11 -2
- rasa/core/channels/voice_stream/twilio_media_streams.py +101 -26
- rasa/core/channels/voice_stream/voice_channel.py +28 -2
- rasa/core/concurrent_lock_store.py +24 -10
- rasa/core/evaluation/marker_tracker_loader.py +1 -1
- rasa/core/exporter.py +1 -1
- rasa/core/lock_store.py +151 -60
- rasa/core/nlg/contextual_response_rephraser.py +4 -2
- rasa/core/nlg/summarize.py +1 -1
- rasa/core/persistor.py +55 -20
- rasa/core/policies/enterprise_search_policy.py +7 -4
- rasa/core/policies/intentless_policy.py +15 -9
- rasa/core/processor.py +2 -2
- rasa/core/run.py +7 -2
- rasa/core/{auth_retry_tracker_store.py → tracker_stores/auth_retry_tracker_store.py} +5 -1
- rasa/core/tracker_stores/dynamo_tracker_store.py +218 -0
- rasa/core/tracker_stores/mongo_tracker_store.py +206 -0
- rasa/core/tracker_stores/redis_tracker_store.py +219 -0
- rasa/core/tracker_stores/sql_tracker_store.py +555 -0
- rasa/core/tracker_stores/tracker_store.py +805 -0
- rasa/core/utils.py +6 -0
- rasa/dialogue_understanding/coexistence/llm_based_router.py +8 -3
- rasa/dialogue_understanding/commands/clarify_command.py +2 -2
- rasa/dialogue_understanding/commands/knowledge_answer_command.py +2 -2
- rasa/dialogue_understanding/generator/constants.py +2 -2
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +1 -1
- rasa/dialogue_understanding/generator/prompt_templates/command_prompt_v2_gpt_4o_2024_11_20_template.jinja2 +33 -12
- rasa/dialogue_understanding/generator/single_step/compact_llm_command_generator.py +2 -2
- rasa/dialogue_understanding_test/du_test_case.py +16 -8
- rasa/hooks.py +2 -2
- rasa/keys +1 -0
- rasa/llm_fine_tuning/paraphrasing/conversation_rephraser.py +4 -2
- rasa/model_manager/config.py +3 -1
- rasa/model_manager/model_api.py +1 -2
- rasa/model_manager/runner_service.py +8 -4
- rasa/model_manager/trainer_service.py +1 -0
- rasa/model_training.py +12 -3
- rasa/nlu/extractors/crf_entity_extractor.py +66 -16
- rasa/plugin.py +1 -4
- rasa/server.py +6 -2
- rasa/shared/constants.py +4 -0
- rasa/shared/core/domain.py +165 -11
- rasa/shared/core/events.py +68 -2
- rasa/shared/core/flows/flow.py +155 -131
- rasa/shared/core/flows/flow_step.py +19 -3
- rasa/shared/core/flows/flow_step_links.py +15 -0
- rasa/shared/core/flows/flow_step_sequence.py +6 -0
- rasa/shared/core/flows/nlu_trigger.py +13 -0
- rasa/shared/core/flows/steps/action.py +7 -4
- rasa/shared/core/flows/steps/call.py +11 -4
- rasa/shared/core/flows/steps/collect.py +27 -6
- rasa/shared/core/flows/steps/internal.py +6 -1
- rasa/shared/core/flows/steps/link.py +7 -4
- rasa/shared/core/flows/steps/no_operation.py +7 -4
- rasa/shared/core/flows/steps/set_slots.py +8 -4
- rasa/shared/core/flows/yaml_flows_io.py +106 -5
- rasa/shared/importers/importer.py +8 -0
- rasa/shared/providers/_configs/azure_openai_client_config.py +4 -0
- rasa/shared/providers/_configs/openai_client_config.py +4 -0
- rasa/shared/providers/_utils.py +83 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +3 -0
- rasa/shared/providers/llm/_base_litellm_client.py +11 -5
- rasa/shared/providers/llm/azure_openai_llm_client.py +6 -68
- rasa/shared/providers/router/_base_litellm_router_client.py +53 -1
- rasa/shared/utils/common.py +42 -0
- rasa/studio/download/__init__.py +0 -0
- rasa/studio/download/domains.py +49 -0
- rasa/studio/download/download.py +439 -0
- rasa/studio/download/flows.py +359 -0
- rasa/studio/results_logger.py +6 -1
- rasa/studio/upload.py +69 -5
- rasa/telemetry.py +2 -2
- rasa/tracing/config.py +1 -1
- rasa/tracing/instrumentation/attribute_extractors.py +1 -1
- rasa/tracing/instrumentation/instrumentation.py +1 -1
- rasa/utils/common.py +36 -0
- rasa/utils/endpoints.py +22 -1
- rasa/utils/licensing.py +2 -3
- rasa/validator.py +1 -2
- rasa/version.py +1 -1
- {rasa_pro-3.13.0.dev3.dist-info → rasa_pro-3.13.0.dev7.dist-info}/METADATA +10 -10
- {rasa_pro-3.13.0.dev3.dist-info → rasa_pro-3.13.0.dev7.dist-info}/RECORD +213 -210
- rasa/cli/project_templates/calm/config.yml +0 -10
- rasa/cli/project_templates/calm/credentials.yml +0 -33
- rasa/cli/project_templates/calm/endpoints.yml +0 -58
- rasa/cli/project_templates/default/actions/actions.py +0 -27
- rasa/cli/project_templates/default/data/nlu.yml +0 -91
- rasa/cli/project_templates/default/data/rules.yml +0 -13
- rasa/cli/project_templates/default/data/stories.yml +0 -30
- rasa/cli/project_templates/default/domain.yml +0 -34
- rasa/cli/project_templates/default/tests/test_stories.yml +0 -91
- rasa/core/channels/inspector/dist/assets/channel-11268142.js +0 -1
- rasa/core/channels/inspector/dist/assets/clone-ff7f2ce7.js +0 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-96b9c2cf-cba7ae20.js +0 -1
- rasa/core/channels/inspector/src/helpers/audiostream.ts +0 -191
- rasa/core/tracker_store.py +0 -1792
- rasa/studio/download.py +0 -489
- /rasa/cli/project_templates/{calm → default}/actions/action_template.py +0 -0
- /rasa/cli/project_templates/{calm → default}/actions/add_contact.py +0 -0
- /rasa/cli/project_templates/{calm → default}/actions/db.py +0 -0
- /rasa/cli/project_templates/{calm → default}/actions/list_contacts.py +0 -0
- /rasa/cli/project_templates/{calm → default}/actions/remove_contact.py +0 -0
- /rasa/cli/project_templates/{calm → default}/data/flows/add_contact.yml +0 -0
- /rasa/cli/project_templates/{calm → default}/data/flows/list_contacts.yml +0 -0
- /rasa/cli/project_templates/{calm → default}/data/flows/remove_contact.yml +0 -0
- /rasa/cli/project_templates/{calm → default}/db/contacts.json +0 -0
- /rasa/cli/project_templates/{calm → default}/domain/add_contact.yml +0 -0
- /rasa/cli/project_templates/{calm → default}/domain/list_contacts.yml +0 -0
- /rasa/cli/project_templates/{calm → default}/domain/remove_contact.yml +0 -0
- /rasa/cli/project_templates/{calm → default}/domain/shared.yml +0 -0
- /rasa/{cli/project_templates/calm/actions → core/tracker_stores}/__init__.py +0 -0
- {rasa_pro-3.13.0.dev3.dist-info → rasa_pro-3.13.0.dev7.dist-info}/NOTICE +0 -0
- {rasa_pro-3.13.0.dev3.dist-info → rasa_pro-3.13.0.dev7.dist-info}/WHEEL +0 -0
- {rasa_pro-3.13.0.dev3.dist-info → rasa_pro-3.13.0.dev7.dist-info}/entry_points.txt +0 -0
rasa/core/persistor.py
CHANGED
|
@@ -121,10 +121,12 @@ def get_persistor(storage: StorageType) -> Optional[Persistor]:
|
|
|
121
121
|
class Persistor(abc.ABC):
|
|
122
122
|
"""Store models in cloud and fetch them when needed."""
|
|
123
123
|
|
|
124
|
-
def persist(self, trained_model: str) -> None:
|
|
124
|
+
def persist(self, trained_model: str, remote_root_only: bool = False) -> None:
|
|
125
125
|
"""Uploads a trained model persisted in the `target_dir` to cloud storage."""
|
|
126
126
|
absolute_file_key = self._create_file_key(trained_model)
|
|
127
|
-
file_key =
|
|
127
|
+
file_key = (
|
|
128
|
+
Path(absolute_file_key).name if remote_root_only else absolute_file_key
|
|
129
|
+
)
|
|
128
130
|
self._persist_tar(file_key, trained_model)
|
|
129
131
|
|
|
130
132
|
def retrieve(self, model_name: Text, target_path: Text) -> Text:
|
|
@@ -143,30 +145,32 @@ class Persistor(abc.ABC):
|
|
|
143
145
|
# ensure backward compatibility
|
|
144
146
|
tar_name = self._tar_name(model_name)
|
|
145
147
|
tar_name = self._create_file_key(tar_name)
|
|
146
|
-
|
|
147
|
-
self._retrieve_tar(target_filename)
|
|
148
|
-
self._copy(os.path.basename(tar_name), target_path)
|
|
148
|
+
self._retrieve_tar(tar_name, target_path)
|
|
149
149
|
|
|
150
150
|
if os.path.isdir(target_path):
|
|
151
151
|
return os.path.join(target_path, model_name)
|
|
152
152
|
|
|
153
153
|
return target_path
|
|
154
154
|
|
|
155
|
-
def size_of_persisted_model(
|
|
155
|
+
def size_of_persisted_model(
|
|
156
|
+
self, model_name: Text, target_path: Optional[str] = None
|
|
157
|
+
) -> int:
|
|
156
158
|
"""Returns the size of the model that has been persisted to cloud storage.
|
|
157
159
|
|
|
158
160
|
Args:
|
|
159
161
|
model_name: The name of the model to retrieve.
|
|
162
|
+
target_path: The path to which the model should be saved.
|
|
160
163
|
"""
|
|
161
164
|
tar_name = model_name
|
|
162
165
|
if not model_name.endswith(MODEL_ARCHIVE_EXTENSION):
|
|
163
166
|
# ensure backward compatibility
|
|
164
167
|
tar_name = self._tar_name(model_name)
|
|
165
168
|
tar_name = self._create_file_key(tar_name)
|
|
166
|
-
|
|
167
|
-
return self._retrieve_tar_size(target_filename)
|
|
169
|
+
return self._retrieve_tar_size(tar_name, target_path)
|
|
168
170
|
|
|
169
|
-
def _retrieve_tar_size(
|
|
171
|
+
def _retrieve_tar_size(
|
|
172
|
+
self, filename: Text, target_path: Optional[str] = None
|
|
173
|
+
) -> int:
|
|
170
174
|
"""Returns the size of the model that has been persisted to cloud storage."""
|
|
171
175
|
structlogger.warning(
|
|
172
176
|
"persistor.retrieve_tar_size.not_implemented",
|
|
@@ -179,11 +183,11 @@ class Persistor(abc.ABC):
|
|
|
179
183
|
"size directly from the cloud storage."
|
|
180
184
|
),
|
|
181
185
|
)
|
|
182
|
-
self._retrieve_tar(filename)
|
|
186
|
+
self._retrieve_tar(filename, target_path)
|
|
183
187
|
return os.path.getsize(os.path.basename(filename))
|
|
184
188
|
|
|
185
189
|
@abc.abstractmethod
|
|
186
|
-
def _retrieve_tar(self, filename:
|
|
190
|
+
def _retrieve_tar(self, filename: str, target_path: Optional[str] = None) -> None:
|
|
187
191
|
"""Downloads a model previously persisted to cloud storage."""
|
|
188
192
|
raise NotImplementedError
|
|
189
193
|
|
|
@@ -302,7 +306,9 @@ class AWSPersistor(Persistor):
|
|
|
302
306
|
with open(tar_path, "rb") as f:
|
|
303
307
|
self.s3.Object(self.bucket_name, file_key).put(Body=f)
|
|
304
308
|
|
|
305
|
-
def _retrieve_tar_size(
|
|
309
|
+
def _retrieve_tar_size(
|
|
310
|
+
self, model_path: Text, target_path: Optional[str] = None
|
|
311
|
+
) -> int:
|
|
306
312
|
"""Returns the size of the model that has been persisted to s3."""
|
|
307
313
|
try:
|
|
308
314
|
obj = self.s3.Object(self.bucket_name, model_path)
|
|
@@ -310,7 +316,9 @@ class AWSPersistor(Persistor):
|
|
|
310
316
|
except Exception:
|
|
311
317
|
raise ModelNotFound()
|
|
312
318
|
|
|
313
|
-
def _retrieve_tar(
|
|
319
|
+
def _retrieve_tar(
|
|
320
|
+
self, target_filename: str, target_path: Optional[str] = None
|
|
321
|
+
) -> None:
|
|
314
322
|
"""Downloads a model that has previously been persisted to s3."""
|
|
315
323
|
from botocore import exceptions
|
|
316
324
|
|
|
@@ -320,8 +328,14 @@ class AWSPersistor(Persistor):
|
|
|
320
328
|
f"in the bucket."
|
|
321
329
|
)
|
|
322
330
|
|
|
331
|
+
tar_name = (
|
|
332
|
+
os.path.join(target_path, os.path.basename(target_filename))
|
|
333
|
+
if target_path
|
|
334
|
+
else os.path.basename(target_filename)
|
|
335
|
+
)
|
|
336
|
+
|
|
323
337
|
try:
|
|
324
|
-
with open(
|
|
338
|
+
with open(tar_name, "wb") as f:
|
|
325
339
|
self.bucket.download_fileobj(target_filename, f)
|
|
326
340
|
|
|
327
341
|
structlogger.debug(
|
|
@@ -425,7 +439,9 @@ class GCSPersistor(Persistor):
|
|
|
425
439
|
blob = self.bucket.blob(file_key)
|
|
426
440
|
blob.upload_from_filename(tar_path)
|
|
427
441
|
|
|
428
|
-
def _retrieve_tar_size(
|
|
442
|
+
def _retrieve_tar_size(
|
|
443
|
+
self, target_filename: Text, target_path: Optional[str] = None
|
|
444
|
+
) -> int:
|
|
429
445
|
"""Returns the size of the model that has been persisted to GCS."""
|
|
430
446
|
try:
|
|
431
447
|
blob = self.bucket.blob(target_filename)
|
|
@@ -433,13 +449,22 @@ class GCSPersistor(Persistor):
|
|
|
433
449
|
except Exception:
|
|
434
450
|
raise ModelNotFound()
|
|
435
451
|
|
|
436
|
-
def _retrieve_tar(
|
|
452
|
+
def _retrieve_tar(
|
|
453
|
+
self, target_filename: str, target_path: Optional[str] = None
|
|
454
|
+
) -> None:
|
|
437
455
|
"""Downloads a model that has previously been persisted to GCS."""
|
|
438
456
|
from google.api_core import exceptions
|
|
439
457
|
|
|
440
458
|
blob = self.bucket.blob(target_filename)
|
|
459
|
+
|
|
460
|
+
destination = (
|
|
461
|
+
os.path.join(target_path, os.path.basename(target_filename))
|
|
462
|
+
if target_path
|
|
463
|
+
else target_filename
|
|
464
|
+
)
|
|
465
|
+
|
|
441
466
|
try:
|
|
442
|
-
blob.download_to_filename(
|
|
467
|
+
blob.download_to_filename(destination)
|
|
443
468
|
|
|
444
469
|
structlogger.debug(
|
|
445
470
|
"gcs_persistor.retrieve_tar.object_found", object_key=target_filename
|
|
@@ -500,7 +525,9 @@ class AzurePersistor(Persistor):
|
|
|
500
525
|
with open(tar_path, "rb") as data:
|
|
501
526
|
self._container_client().upload_blob(name=file_key, data=data)
|
|
502
527
|
|
|
503
|
-
def _retrieve_tar_size(
|
|
528
|
+
def _retrieve_tar_size(
|
|
529
|
+
self, target_filename: Text, target_path: Optional[str] = None
|
|
530
|
+
) -> int:
|
|
504
531
|
"""Returns the size of the model that has been persisted to Azure."""
|
|
505
532
|
try:
|
|
506
533
|
blob_client = self._container_client().get_blob_client(target_filename)
|
|
@@ -509,12 +536,20 @@ class AzurePersistor(Persistor):
|
|
|
509
536
|
except Exception:
|
|
510
537
|
raise ModelNotFound()
|
|
511
538
|
|
|
512
|
-
def _retrieve_tar(
|
|
539
|
+
def _retrieve_tar(
|
|
540
|
+
self, target_filename: Text, target_path: Optional[str] = None
|
|
541
|
+
) -> None:
|
|
513
542
|
"""Downloads a model that has previously been persisted to Azure."""
|
|
514
543
|
from azure.core.exceptions import AzureError
|
|
515
544
|
|
|
545
|
+
destination = (
|
|
546
|
+
os.path.join(target_path, os.path.basename(target_filename))
|
|
547
|
+
if target_path
|
|
548
|
+
else target_filename
|
|
549
|
+
)
|
|
550
|
+
|
|
516
551
|
try:
|
|
517
|
-
with open(
|
|
552
|
+
with open(destination, "wb") as model_file:
|
|
518
553
|
blob_client = self._container_client().get_blob_client(target_filename)
|
|
519
554
|
download_stream = blob_client.download_blob()
|
|
520
555
|
model_file.write(download_stream.readall())
|
|
@@ -46,12 +46,15 @@ from rasa.graph_components.providers.forms_provider import Forms
|
|
|
46
46
|
from rasa.graph_components.providers.responses_provider import Responses
|
|
47
47
|
from rasa.shared.constants import (
|
|
48
48
|
EMBEDDINGS_CONFIG_KEY,
|
|
49
|
+
MAX_COMPLETION_TOKENS_CONFIG_KEY,
|
|
50
|
+
MAX_RETRIES_CONFIG_KEY,
|
|
49
51
|
MODEL_CONFIG_KEY,
|
|
50
52
|
MODEL_GROUP_ID_CONFIG_KEY,
|
|
51
53
|
MODEL_NAME_CONFIG_KEY,
|
|
52
54
|
OPENAI_PROVIDER,
|
|
53
55
|
PROMPT_CONFIG_KEY,
|
|
54
56
|
PROVIDER_CONFIG_KEY,
|
|
57
|
+
TEMPERATURE_CONFIG_KEY,
|
|
55
58
|
TIMEOUT_CONFIG_KEY,
|
|
56
59
|
)
|
|
57
60
|
from rasa.shared.core.constants import (
|
|
@@ -135,14 +138,14 @@ DEFAULT_LLM_CONFIG = {
|
|
|
135
138
|
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
136
139
|
MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
137
140
|
TIMEOUT_CONFIG_KEY: 10,
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
+
TEMPERATURE_CONFIG_KEY: 0.0,
|
|
142
|
+
MAX_COMPLETION_TOKENS_CONFIG_KEY: 256,
|
|
143
|
+
MAX_RETRIES_CONFIG_KEY: 1,
|
|
141
144
|
}
|
|
142
145
|
|
|
143
146
|
DEFAULT_EMBEDDINGS_CONFIG = {
|
|
144
147
|
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
145
|
-
|
|
148
|
+
MODEL_CONFIG_KEY: DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
146
149
|
}
|
|
147
150
|
|
|
148
151
|
ENTERPRISE_SEARCH_PROMPT_FILE_NAME = "enterprise_search_policy_prompt.jinja2"
|
|
@@ -31,12 +31,14 @@ from rasa.graph_components.providers.responses_provider import Responses
|
|
|
31
31
|
from rasa.shared.constants import (
|
|
32
32
|
EMBEDDINGS_CONFIG_KEY,
|
|
33
33
|
LLM_CONFIG_KEY,
|
|
34
|
+
MAX_COMPLETION_TOKENS_CONFIG_KEY,
|
|
34
35
|
MODEL_CONFIG_KEY,
|
|
35
36
|
MODEL_GROUP_ID_CONFIG_KEY,
|
|
36
37
|
MODEL_NAME_CONFIG_KEY,
|
|
37
38
|
OPENAI_PROVIDER,
|
|
38
39
|
PROMPT_CONFIG_KEY,
|
|
39
40
|
PROVIDER_CONFIG_KEY,
|
|
41
|
+
TEMPERATURE_CONFIG_KEY,
|
|
40
42
|
TIMEOUT_CONFIG_KEY,
|
|
41
43
|
)
|
|
42
44
|
from rasa.shared.core.constants import ACTION_LISTEN_NAME
|
|
@@ -111,14 +113,14 @@ NLU_ABSTENTION_THRESHOLD = "nlu_abstention_threshold"
|
|
|
111
113
|
DEFAULT_LLM_CONFIG = {
|
|
112
114
|
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
113
115
|
MODEL_CONFIG_KEY: DEFAULT_OPENAI_CHAT_MODEL_NAME,
|
|
114
|
-
|
|
115
|
-
|
|
116
|
+
TEMPERATURE_CONFIG_KEY: 0.0,
|
|
117
|
+
MAX_COMPLETION_TOKENS_CONFIG_KEY: DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
|
|
116
118
|
TIMEOUT_CONFIG_KEY: 5,
|
|
117
119
|
}
|
|
118
120
|
|
|
119
121
|
DEFAULT_EMBEDDINGS_CONFIG = {
|
|
120
122
|
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
121
|
-
|
|
123
|
+
MODEL_CONFIG_KEY: DEFAULT_OPENAI_EMBEDDING_MODEL_NAME,
|
|
122
124
|
}
|
|
123
125
|
|
|
124
126
|
DEFAULT_INTENTLESS_PROMPT_TEMPLATE = importlib.resources.open_text(
|
|
@@ -344,8 +346,6 @@ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
|
|
|
344
346
|
# ensures that the policy will not override a deterministic policy
|
|
345
347
|
# which utilizes the nlu predictions confidence (e.g. Memoization).
|
|
346
348
|
NLU_ABSTENTION_THRESHOLD: 0.9,
|
|
347
|
-
LLM_CONFIG_KEY: DEFAULT_LLM_CONFIG,
|
|
348
|
-
EMBEDDINGS_CONFIG_KEY: DEFAULT_EMBEDDINGS_CONFIG,
|
|
349
349
|
PROMPT_CONFIG_KEY: DEFAULT_INTENTLESS_PROMPT_TEMPLATE,
|
|
350
350
|
}
|
|
351
351
|
|
|
@@ -381,13 +381,19 @@ class IntentlessPolicy(LLMHealthCheckMixin, EmbeddingsHealthCheckMixin, Policy):
|
|
|
381
381
|
super().__init__(config, model_storage, resource, execution_context, featurizer)
|
|
382
382
|
|
|
383
383
|
# Resolve LLM config
|
|
384
|
-
self.config[LLM_CONFIG_KEY] =
|
|
385
|
-
|
|
384
|
+
self.config[LLM_CONFIG_KEY] = combine_custom_and_default_config(
|
|
385
|
+
resolve_model_client_config(
|
|
386
|
+
self.config.get(LLM_CONFIG_KEY), IntentlessPolicy.__name__
|
|
387
|
+
),
|
|
388
|
+
DEFAULT_LLM_CONFIG,
|
|
386
389
|
)
|
|
387
390
|
|
|
388
391
|
# Resolve embeddings config
|
|
389
|
-
self.config[EMBEDDINGS_CONFIG_KEY] =
|
|
390
|
-
|
|
392
|
+
self.config[EMBEDDINGS_CONFIG_KEY] = combine_custom_and_default_config(
|
|
393
|
+
resolve_model_client_config(
|
|
394
|
+
self.config.get(EMBEDDINGS_CONFIG_KEY), IntentlessPolicy.__name__
|
|
395
|
+
),
|
|
396
|
+
DEFAULT_EMBEDDINGS_CONFIG,
|
|
391
397
|
)
|
|
392
398
|
|
|
393
399
|
self.nlu_abstention_threshold: float = self.config[NLU_ABSTENTION_THRESHOLD]
|
rasa/core/processor.py
CHANGED
|
@@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Text, Tuple, Union
|
|
|
12
12
|
import structlog
|
|
13
13
|
|
|
14
14
|
import rasa.core.actions.action
|
|
15
|
-
import rasa.core.tracker_store
|
|
15
|
+
import rasa.core.tracker_stores.tracker_store
|
|
16
16
|
import rasa.core.utils
|
|
17
17
|
import rasa.shared.core.trackers
|
|
18
18
|
import rasa.shared.utils.io
|
|
@@ -126,7 +126,7 @@ class MessageProcessor:
|
|
|
126
126
|
def __init__(
|
|
127
127
|
self,
|
|
128
128
|
model_path: Union[Text, Path],
|
|
129
|
-
tracker_store: rasa.core.tracker_store.TrackerStore,
|
|
129
|
+
tracker_store: rasa.core.tracker_stores.tracker_store.TrackerStore,
|
|
130
130
|
lock_store: LockStore,
|
|
131
131
|
generator: NaturalLanguageGenerator,
|
|
132
132
|
action_endpoint: Optional[EndpointConfig] = None,
|
rasa/core/run.py
CHANGED
|
@@ -86,13 +86,15 @@ def _create_single_channel(channel: Text, credentials: Dict[Text, Any]) -> Any:
|
|
|
86
86
|
)
|
|
87
87
|
|
|
88
88
|
|
|
89
|
-
def _create_app_without_api(
|
|
89
|
+
def _create_app_without_api(
|
|
90
|
+
cors: Optional[Union[Text, List[Text]]] = None, is_inspector_enabled: bool = False
|
|
91
|
+
) -> Sanic:
|
|
90
92
|
app = Sanic("rasa_core_no_api", configure_logging=False)
|
|
91
93
|
|
|
92
94
|
# Reset Sanic warnings filter that allows the triggering of Sanic warnings
|
|
93
95
|
warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"sanic.*")
|
|
94
96
|
|
|
95
|
-
server.add_root_route(app)
|
|
97
|
+
server.add_root_route(app, is_inspector_enabled)
|
|
96
98
|
server.configure_cors(app, cors)
|
|
97
99
|
return app
|
|
98
100
|
|
|
@@ -127,6 +129,7 @@ def configure_app(
|
|
|
127
129
|
server_listeners: Optional[List[Tuple[Callable, Text]]] = None,
|
|
128
130
|
use_uvloop: Optional[bool] = True,
|
|
129
131
|
keep_alive_timeout: int = constants.DEFAULT_KEEP_ALIVE_TIMEOUT,
|
|
132
|
+
is_inspector_enabled: bool = False,
|
|
130
133
|
) -> Sanic:
|
|
131
134
|
"""Run the agent."""
|
|
132
135
|
rasa.core.utils.configure_file_logging(
|
|
@@ -144,6 +147,7 @@ def configure_app(
|
|
|
144
147
|
jwt_private_key=jwt_private_key,
|
|
145
148
|
jwt_method=jwt_method,
|
|
146
149
|
endpoints=endpoints,
|
|
150
|
+
is_inspector_enabled=is_inspector_enabled,
|
|
147
151
|
)
|
|
148
152
|
)
|
|
149
153
|
else:
|
|
@@ -259,6 +263,7 @@ def serve_application(
|
|
|
259
263
|
syslog_protocol=syslog_protocol,
|
|
260
264
|
request_timeout=request_timeout,
|
|
261
265
|
server_listeners=server_listeners,
|
|
266
|
+
is_inspector_enabled=inspect,
|
|
262
267
|
)
|
|
263
268
|
|
|
264
269
|
ssl_context = server.create_ssl_context(
|
|
@@ -3,7 +3,7 @@ from typing import Iterable, Optional, Text
|
|
|
3
3
|
|
|
4
4
|
from rasa.core.brokers.broker import EventBroker
|
|
5
5
|
from rasa.core.secrets_manager.secret_manager import EndpointResolver
|
|
6
|
-
from rasa.core.tracker_store import TrackerStore, create_tracker_store
|
|
6
|
+
from rasa.core.tracker_stores.tracker_store import TrackerStore, create_tracker_store
|
|
7
7
|
from rasa.shared.core.domain import Domain
|
|
8
8
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
9
9
|
from rasa.utils.endpoints import EndpointConfig
|
|
@@ -119,3 +119,7 @@ class AuthRetryTrackerStore(TrackerStore):
|
|
|
119
119
|
"""Recreate tracker store with updated credentials."""
|
|
120
120
|
endpoint_config = EndpointResolver.update_config(self.endpoint_config)
|
|
121
121
|
return create_tracker_store(endpoint_config, domain, event_broker)
|
|
122
|
+
|
|
123
|
+
async def delete(self, sender_id: Text) -> None:
|
|
124
|
+
"""Delete tracker for the given sender_id."""
|
|
125
|
+
await self._tracker_store.delete(sender_id)
|
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Text
|
|
5
|
+
|
|
6
|
+
import structlog
|
|
7
|
+
from boto3.dynamodb.conditions import Key
|
|
8
|
+
|
|
9
|
+
import rasa.utils
|
|
10
|
+
from rasa.constants import DEFAULT_SANIC_WORKERS, ENV_SANIC_WORKERS
|
|
11
|
+
from rasa.core.tracker_stores.tracker_store import (
|
|
12
|
+
SerializedTrackerAsDict,
|
|
13
|
+
TrackerStore,
|
|
14
|
+
)
|
|
15
|
+
from rasa.shared.core.domain import Domain
|
|
16
|
+
from rasa.shared.core.trackers import DialogueStateTracker
|
|
17
|
+
from rasa.shared.exceptions import RasaException
|
|
18
|
+
from rasa.utils.endpoints import EndpointConfig
|
|
19
|
+
|
|
20
|
+
# Module-level structured logger, named after this module.
structlogger = structlog.get_logger(__name__)
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
import boto3.resources.factory.dynamodb.Table
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class DynamoTrackerStore(TrackerStore, SerializedTrackerAsDict):
    """Stores conversation history in DynamoDB.

    Trackers are written with `put_item` and looked up by `sender_id`. When this
    class has to create the table itself, `sender_id` (string) is declared as
    the table's hash key.
    """

    def __init__(
        self,
        domain: Domain,
        table_name: Text = "states",
        region: Text = "us-east-1",
        event_broker: Optional[EndpointConfig] = None,
        **kwargs: Dict[Text, Any],
    ) -> None:
        """Initialize `DynamoTrackerStore`.

        Args:
            domain: Domain associated with this tracker store.
            table_name: The name of the DynamoDB table, does not need to be present a
                priori.
            region: The name of the region associated with the client.
                A client is associated with a single region.
            event_broker: An event broker used to publish events.
            kwargs: Additional kwargs.

        Raises:
            RasaException: If the table does not exist and more than one Sanic
                worker is configured (see `get_or_create_table`).
        """
        # boto3 is imported lazily so the module can be imported without it.
        import boto3

        self.client = boto3.client("dynamodb", region_name=region)
        self.region = region
        self.table_name = table_name
        # Needs `self.client` and `self.region`, so it must come after them.
        self.db = self.get_or_create_table(table_name)
        super().__init__(domain, event_broker, **kwargs)

    def get_or_create_table(
        self, table_name: Text
    ) -> "boto3.resources.factory.dynamodb.Table":
        """Returns table or creates one if the table name is not in the table list.

        Args:
            table_name: Name of the DynamoDB table to look up or create.

        Returns:
            The DynamoDB table resource for `table_name`.

        Raises:
            RasaException: If the table is missing and the number of Sanic workers
                read from the environment is greater than one.
        """
        import boto3

        dynamo = boto3.resource("dynamodb", region_name=self.region)
        try:
            self.client.describe_table(TableName=table_name)
        except self.client.exceptions.ResourceNotFoundException:
            sanic_workers_count = int(
                os.environ.get(ENV_SANIC_WORKERS, DEFAULT_SANIC_WORKERS)
            )

            if sanic_workers_count > 1:
                # `event_info` must be a plain string; a trailing comma inside the
                # parentheses would turn it into a one-element tuple.
                structlogger.error(
                    "dynamo_tracker_store.table_creation_not_supported_in_multi_worker_mode",
                    event_info=(
                        "DynamoDB table creation is not "
                        "supported in multi-worker mode. "
                        "Table should already exist."
                    ),
                )
                raise RasaException(
                    "DynamoDB table creation is not supported in "
                    "case of multiple sanic workers. To create the table either "
                    "run Rasa with a single worker or create the table manually. "
                    "Here are the defaults which can be used to "
                    "create the table manually: "
                    f"Table name: {table_name}, Primary key: sender_id, "
                    "key type `HASH`, attribute type `S` (String), "
                    "Provisioned throughput: Read capacity units: 5, "
                    "Write capacity units: 5"
                )

            # Create the table under the same name that was looked up above and
            # that is waited on below.
            table = dynamo.create_table(
                TableName=table_name,
                KeySchema=[{"AttributeName": "sender_id", "KeyType": "HASH"}],
                AttributeDefinitions=[
                    {"AttributeName": "sender_id", "AttributeType": "S"}
                ],
                ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5},
            )

            # Wait until the table exists.
            table.meta.client.get_waiter("table_exists").wait(TableName=table_name)
        else:
            table = dynamo.Table(table_name)

        return table

    async def save(self, tracker: DialogueStateTracker) -> None:
        """Saves the current conversation state.

        The tracker's events are streamed via `stream_events` first, then the
        serialised tracker is written to the table.

        Args:
            tracker: Tracker to persist.
        """
        await self.stream_events(tracker)
        serialized = self.serialise_tracker(tracker)

        self.db.put_item(Item=serialized)

    async def delete(self, sender_id: Text) -> None:
        """Delete tracker for the given sender_id.

        If no tracker exists for `sender_id`, this is logged and nothing is
        deleted.

        Args:
            sender_id: Conversation ID whose tracker should be deleted.
        """
        if not await self.exists(sender_id):
            structlogger.info(
                "dynamo_tracker_store.delete.no_tracker_for_sender_id",
                event_info=f"Could not find tracker for conversation ID '{sender_id}'.",
            )
            return

        # NOTE(review): the delete is conditional on the item still existing; if
        # it disappears between the check above and this call, the error from
        # DynamoDB propagates to the caller - confirm that this is intended.
        self.db.delete_item(
            Key={"sender_id": sender_id},
            ConditionExpression="attribute_exists(sender_id)",
        )
        structlogger.info(
            "dynamo_tracker_store.delete.deleted_tracker",
            sender_id=sender_id,
        )

    @staticmethod
    def serialise_tracker(
        tracker: "DialogueStateTracker",
    ) -> Dict:
        """Serializes the tracker, returns object with decimal types.

        DynamoDB cannot store `float`s, so we'll convert them to `Decimal`s.

        Args:
            tracker: Tracker to serialise.

        Returns:
            The dict produced by `SerializedTrackerAsDict.serialise_tracker`, passed
            through `replace_floats_with_decimals`.
        """
        return rasa.utils.json_utils.replace_floats_with_decimals(
            SerializedTrackerAsDict.serialise_tracker(tracker)
        )

    async def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
        """Retrieves the tracker for the last conversation session.

        Args:
            sender_id: Conversation ID to fetch the tracker for.

        Returns:
            The tracker restricted to the last session, or `None` if nothing is
            stored for `sender_id`.
        """
        return await self._retrieve(sender_id, fetch_all_sessions=False)

    async def retrieve_full_tracker(
        self, sender_id: Text
    ) -> Optional[DialogueStateTracker]:
        """Retrieves tracker for all conversation sessions.

        Args:
            sender_id: Conversation ID to fetch the tracker for.

        Returns:
            The tracker with the events of all sessions, or `None` if nothing is
            stored for `sender_id`.
        """
        return await self._retrieve(sender_id, fetch_all_sessions=True)

    async def _retrieve(
        self, sender_id: Text, fetch_all_sessions: bool
    ) -> Optional[DialogueStateTracker]:
        """Returns tracker matching sender_id.

        Args:
            sender_id: Conversation ID to fetch the tracker for.
            fetch_all_sessions: Whether to fetch all sessions or only the last one.

        Returns:
            The rebuilt tracker, or `None` if the query returns no items.
        """
        dialogues = self.db.query(
            KeyConditionExpression=Key("sender_id").eq(sender_id),
            ScanIndexForward=False,
        )["Items"]

        if not dialogues:
            return None

        # Stored numbers come back as `Decimal`s; convert them before the events
        # are handed to the tracker.
        events_with_floats = []
        for dialogue in dialogues:
            if dialogue.get("events"):
                events = rasa.utils.json_utils.replace_decimals_with_floats(
                    dialogue["events"]
                )
                events_with_floats += events

        slots = [] if self.domain is None else self.domain.slots

        tracker = DialogueStateTracker.from_dict(sender_id, events_with_floats, slots)

        if fetch_all_sessions:
            return tracker

        # only return the last session
        multiple_tracker_sessions = (
            rasa.shared.core.trackers.get_trackers_for_conversation_sessions(tracker)
        )

        if len(multiple_tracker_sessions) <= 1:
            return tracker

        return multiple_tracker_sessions[-1]

    async def keys(self) -> Iterable[Text]:
        """Returns sender_ids of the `DynamoTrackerStore`.

        The table is scanned page by page until the response no longer carries a
        `LastEvaluatedKey`.
        """
        scan_kwargs: Dict[Text, Any] = {"ProjectionExpression": "sender_id"}
        sender_ids = []

        while True:
            response = self.db.scan(**scan_kwargs)
            sender_ids.extend(item["sender_id"] for item in response["Items"])

            last_evaluated_key = response.get("LastEvaluatedKey")
            if not last_evaluated_key:
                return sender_ids

            # Continue the scan right after the last key of the previous page.
            scan_kwargs["ExclusiveStartKey"] = last_evaluated_key
|