rasa-pro 3.11.0a4.dev3__py3-none-any.whl → 3.11.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rasa-pro might be problematic. Click here for more details.
- rasa/__main__.py +22 -12
- rasa/api.py +1 -1
- rasa/cli/arguments/default_arguments.py +1 -2
- rasa/cli/arguments/shell.py +5 -1
- rasa/cli/e2e_test.py +1 -1
- rasa/cli/evaluate.py +8 -8
- rasa/cli/inspect.py +6 -4
- rasa/cli/llm_fine_tuning.py +1 -1
- rasa/cli/project_templates/calm/config.yml +5 -7
- rasa/cli/project_templates/calm/endpoints.yml +8 -0
- rasa/cli/project_templates/tutorial/config.yml +8 -5
- rasa/cli/project_templates/tutorial/data/flows.yml +1 -1
- rasa/cli/project_templates/tutorial/data/patterns.yml +5 -0
- rasa/cli/project_templates/tutorial/domain.yml +14 -0
- rasa/cli/project_templates/tutorial/endpoints.yml +7 -7
- rasa/cli/run.py +1 -1
- rasa/cli/scaffold.py +4 -2
- rasa/cli/studio/studio.py +18 -8
- rasa/cli/utils.py +5 -0
- rasa/cli/x.py +8 -8
- rasa/constants.py +1 -1
- rasa/core/actions/action_repeat_bot_messages.py +17 -0
- rasa/core/channels/channel.py +20 -0
- rasa/core/channels/inspector/dist/assets/{arc-6852c607.js → arc-bc141fb2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{c4Diagram-d0fbc5ce-acc952b2.js → c4Diagram-d0fbc5ce-be2db283.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-936ed81e-848a7597.js → classDiagram-936ed81e-55366915.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{classDiagram-v2-c3cb15f1-a73d3e68.js → classDiagram-v2-c3cb15f1-bb529518.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{createText-62fc7601-e5ee049d.js → createText-62fc7601-b0ec81d6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{edges-f2ad444c-771e517e.js → edges-f2ad444c-6166330c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{erDiagram-9d236eb7-aa347178.js → erDiagram-9d236eb7-5ccc6a8e.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDb-1972c806-651fc57d.js → flowDb-1972c806-fca3bfe4.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{flowDiagram-7ea5b25a-ca67804f.js → flowDiagram-7ea5b25a-4739080f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-736177bf.js +1 -0
- rasa/core/channels/inspector/dist/assets/{flowchart-elk-definition-abe16c3d-2dbc568d.js → flowchart-elk-definition-abe16c3d-7c1b0e0f.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{ganttDiagram-9b5ea136-25a65bd8.js → ganttDiagram-9b5ea136-772fd050.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{gitGraphDiagram-99d0ae7c-fdc7378d.js → gitGraphDiagram-99d0ae7c-8eae1dc9.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-2c4b9a3b-6f1fd606.js → index-2c4b9a3b-f55afcdf.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{index-efdd30c1.js → index-e7cef9de.js} +68 -68
- rasa/core/channels/inspector/dist/assets/{infoDiagram-736b4530-cb1a041a.js → infoDiagram-736b4530-124d4a14.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{journeyDiagram-df861f2b-14609879.js → journeyDiagram-df861f2b-7c4fae44.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{layout-2490f52b.js → layout-b9885fb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{line-40186f1f.js → line-7c59abb6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{linear-08814e93.js → linear-4776f780.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{mindmap-definition-beec6740-1a534584.js → mindmap-definition-beec6740-2332c46c.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{pieDiagram-dbbf0591-72397b61.js → pieDiagram-dbbf0591-8fb39303.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{quadrantDiagram-4d7f4fd6-3bb0b6a3.js → quadrantDiagram-4d7f4fd6-3c7180a2.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{requirementDiagram-6fc4c22a-57334f61.js → requirementDiagram-6fc4c22a-e910bcb8.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sankeyDiagram-8f13d901-111e1297.js → sankeyDiagram-8f13d901-ead16c89.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{sequenceDiagram-b655622a-10bcfe62.js → sequenceDiagram-b655622a-29a02a19.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-59f0c015-acaf7513.js → stateDiagram-59f0c015-042b3137.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{stateDiagram-v2-2b26beab-3ec2a235.js → stateDiagram-v2-2b26beab-2178c0f3.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-080da4f6-62730289.js → styles-080da4f6-23ffa4fc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-3dcbcfbf-5284ee76.js → styles-3dcbcfbf-94f59763.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{styles-9c745c82-642435e3.js → styles-9c745c82-78a6bebc.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{svgDrawCommon-4835440b-b250a350.js → svgDrawCommon-4835440b-eae2a6f6.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{timeline-definition-5b62e21b-c2b147ed.js → timeline-definition-5b62e21b-5c968d92.js} +1 -1
- rasa/core/channels/inspector/dist/assets/{xychartDiagram-2b33534f-f92cfea9.js → xychartDiagram-2b33534f-fd3db0d5.js} +1 -1
- rasa/core/channels/inspector/dist/index.html +1 -1
- rasa/core/channels/inspector/src/App.tsx +1 -1
- rasa/core/channels/inspector/src/helpers/audiostream.ts +77 -16
- rasa/core/channels/socketio.py +2 -1
- rasa/core/channels/telegram.py +1 -1
- rasa/core/channels/twilio.py +1 -1
- rasa/core/channels/voice_ready/audiocodes.py +12 -0
- rasa/core/channels/voice_ready/jambonz.py +15 -4
- rasa/core/channels/voice_ready/twilio_voice.py +6 -21
- rasa/core/channels/voice_stream/asr/asr_event.py +5 -0
- rasa/core/channels/voice_stream/asr/azure.py +122 -0
- rasa/core/channels/voice_stream/asr/deepgram.py +16 -6
- rasa/core/channels/voice_stream/audio_bytes.py +1 -0
- rasa/core/channels/voice_stream/browser_audio.py +31 -8
- rasa/core/channels/voice_stream/call_state.py +23 -0
- rasa/core/channels/voice_stream/tts/azure.py +6 -2
- rasa/core/channels/voice_stream/tts/cartesia.py +10 -6
- rasa/core/channels/voice_stream/tts/tts_engine.py +1 -0
- rasa/core/channels/voice_stream/twilio_media_streams.py +27 -18
- rasa/core/channels/voice_stream/util.py +4 -4
- rasa/core/channels/voice_stream/voice_channel.py +189 -39
- rasa/core/featurizers/single_state_featurizer.py +22 -1
- rasa/core/featurizers/tracker_featurizers.py +115 -18
- rasa/core/nlg/contextual_response_rephraser.py +32 -30
- rasa/core/persistor.py +86 -39
- rasa/core/policies/enterprise_search_policy.py +119 -60
- rasa/core/policies/flows/flow_executor.py +7 -4
- rasa/core/policies/intentless_policy.py +78 -22
- rasa/core/policies/ted_policy.py +58 -33
- rasa/core/policies/unexpected_intent_policy.py +15 -7
- rasa/core/processor.py +25 -0
- rasa/core/training/interactive.py +34 -35
- rasa/core/utils.py +8 -3
- rasa/dialogue_understanding/coexistence/llm_based_router.py +39 -12
- rasa/dialogue_understanding/commands/change_flow_command.py +6 -0
- rasa/dialogue_understanding/commands/user_silence_command.py +59 -0
- rasa/dialogue_understanding/commands/utils.py +5 -0
- rasa/dialogue_understanding/generator/constants.py +2 -0
- rasa/dialogue_understanding/generator/flow_retrieval.py +49 -4
- rasa/dialogue_understanding/generator/llm_based_command_generator.py +37 -23
- rasa/dialogue_understanding/generator/multi_step/multi_step_llm_command_generator.py +57 -10
- rasa/dialogue_understanding/generator/nlu_command_adapter.py +19 -1
- rasa/dialogue_understanding/generator/single_step/single_step_llm_command_generator.py +71 -11
- rasa/dialogue_understanding/patterns/default_flows_for_patterns.yml +39 -0
- rasa/dialogue_understanding/patterns/user_silence.py +37 -0
- rasa/dialogue_understanding/processor/command_processor.py +21 -1
- rasa/e2e_test/e2e_test_case.py +85 -6
- rasa/e2e_test/e2e_test_runner.py +4 -2
- rasa/e2e_test/utils/io.py +1 -1
- rasa/engine/validation.py +316 -10
- rasa/model_manager/config.py +15 -3
- rasa/model_manager/model_api.py +15 -7
- rasa/model_manager/runner_service.py +8 -6
- rasa/model_manager/socket_bridge.py +6 -3
- rasa/model_manager/trainer_service.py +7 -5
- rasa/model_manager/utils.py +28 -7
- rasa/model_service.py +9 -2
- rasa/model_training.py +2 -0
- rasa/nlu/classifiers/diet_classifier.py +38 -25
- rasa/nlu/classifiers/logistic_regression_classifier.py +22 -9
- rasa/nlu/classifiers/sklearn_intent_classifier.py +37 -16
- rasa/nlu/extractors/crf_entity_extractor.py +93 -50
- rasa/nlu/featurizers/sparse_featurizer/count_vectors_featurizer.py +45 -16
- rasa/nlu/featurizers/sparse_featurizer/lexical_syntactic_featurizer.py +52 -17
- rasa/nlu/featurizers/sparse_featurizer/regex_featurizer.py +5 -3
- rasa/nlu/tokenizers/whitespace_tokenizer.py +3 -14
- rasa/server.py +3 -1
- rasa/shared/constants.py +36 -3
- rasa/shared/core/constants.py +7 -0
- rasa/shared/core/domain.py +26 -0
- rasa/shared/core/flows/flow.py +5 -0
- rasa/shared/core/flows/flows_list.py +5 -1
- rasa/shared/core/flows/flows_yaml_schema.json +10 -0
- rasa/shared/core/flows/utils.py +39 -0
- rasa/shared/core/flows/validation.py +96 -0
- rasa/shared/core/slots.py +5 -0
- rasa/shared/nlu/training_data/features.py +120 -2
- rasa/shared/providers/_configs/azure_openai_client_config.py +5 -3
- rasa/shared/providers/_configs/litellm_router_client_config.py +200 -0
- rasa/shared/providers/_configs/model_group_config.py +167 -0
- rasa/shared/providers/_configs/openai_client_config.py +1 -1
- rasa/shared/providers/_configs/rasa_llm_client_config.py +73 -0
- rasa/shared/providers/_configs/self_hosted_llm_client_config.py +1 -0
- rasa/shared/providers/_configs/utils.py +16 -0
- rasa/shared/providers/embedding/_base_litellm_embedding_client.py +18 -29
- rasa/shared/providers/embedding/azure_openai_embedding_client.py +54 -21
- rasa/shared/providers/embedding/litellm_router_embedding_client.py +135 -0
- rasa/shared/providers/llm/_base_litellm_client.py +37 -31
- rasa/shared/providers/llm/azure_openai_llm_client.py +50 -29
- rasa/shared/providers/llm/litellm_router_llm_client.py +127 -0
- rasa/shared/providers/llm/rasa_llm_client.py +112 -0
- rasa/shared/providers/llm/self_hosted_llm_client.py +1 -1
- rasa/shared/providers/mappings.py +19 -0
- rasa/shared/providers/router/__init__.py +0 -0
- rasa/shared/providers/router/_base_litellm_router_client.py +149 -0
- rasa/shared/providers/router/router_client.py +73 -0
- rasa/shared/utils/common.py +8 -0
- rasa/shared/utils/health_check/__init__.py +0 -0
- rasa/shared/utils/health_check/embeddings_health_check_mixin.py +31 -0
- rasa/shared/utils/health_check/health_check.py +256 -0
- rasa/shared/utils/health_check/llm_health_check_mixin.py +31 -0
- rasa/shared/utils/io.py +28 -6
- rasa/shared/utils/llm.py +353 -46
- rasa/shared/utils/yaml.py +111 -73
- rasa/studio/auth.py +3 -5
- rasa/studio/config.py +13 -4
- rasa/studio/constants.py +1 -0
- rasa/studio/data_handler.py +10 -3
- rasa/studio/upload.py +81 -26
- rasa/telemetry.py +92 -17
- rasa/tracing/config.py +2 -0
- rasa/tracing/instrumentation/attribute_extractors.py +94 -17
- rasa/tracing/instrumentation/instrumentation.py +121 -0
- rasa/utils/common.py +5 -0
- rasa/utils/io.py +7 -81
- rasa/utils/log_utils.py +9 -2
- rasa/utils/sanic_error_handler.py +32 -0
- rasa/utils/tensorflow/feature_array.py +366 -0
- rasa/utils/tensorflow/model_data.py +2 -193
- rasa/validator.py +70 -0
- rasa/version.py +1 -1
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/METADATA +11 -10
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/RECORD +183 -163
- rasa/core/channels/inspector/dist/assets/flowDiagram-v2-855bc5b3-587d82d8.js +0 -1
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/NOTICE +0 -0
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/WHEEL +0 -0
- {rasa_pro-3.11.0a4.dev3.dist-info → rasa_pro-3.11.0rc2.dist-info}/entry_points.txt +0 -0
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
import logging
|
|
2
|
+
from typing import List, Optional, Dict, Text, Set, Any
|
|
3
|
+
|
|
2
4
|
import numpy as np
|
|
3
5
|
import scipy.sparse
|
|
4
|
-
from typing import List, Optional, Dict, Text, Set, Any
|
|
5
6
|
|
|
6
7
|
from rasa.core.featurizers.precomputation import MessageContainerForCoreFeaturization
|
|
7
8
|
from rasa.nlu.extractors.extractor import EntityTagSpec
|
|
@@ -360,6 +361,26 @@ class SingleStateFeaturizer:
|
|
|
360
361
|
for action in domain.action_names_or_texts
|
|
361
362
|
]
|
|
362
363
|
|
|
364
|
+
def to_dict(self) -> Dict[str, Any]:
|
|
365
|
+
return {
|
|
366
|
+
"action_texts": self.action_texts,
|
|
367
|
+
"entity_tag_specs": self.entity_tag_specs,
|
|
368
|
+
"feature_states": self._default_feature_states,
|
|
369
|
+
}
|
|
370
|
+
|
|
371
|
+
@classmethod
|
|
372
|
+
def create_from_dict(
|
|
373
|
+
cls, data: Dict[str, Any]
|
|
374
|
+
) -> Optional["SingleStateFeaturizer"]:
|
|
375
|
+
if not data:
|
|
376
|
+
return None
|
|
377
|
+
|
|
378
|
+
featurizer = SingleStateFeaturizer()
|
|
379
|
+
featurizer.action_texts = data["action_texts"]
|
|
380
|
+
featurizer._default_feature_states = data["feature_states"]
|
|
381
|
+
featurizer.entity_tag_specs = data["entity_tag_specs"]
|
|
382
|
+
return featurizer
|
|
383
|
+
|
|
363
384
|
|
|
364
385
|
class IntentTokenizerSingleStateFeaturizer(SingleStateFeaturizer):
|
|
365
386
|
"""A SingleStateFeaturizer for use with policies that predict intent labels."""
|
|
@@ -1,11 +1,9 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
-
from pathlib import Path
|
|
3
|
-
from collections import defaultdict
|
|
4
|
-
from abc import abstractmethod
|
|
5
|
-
import jsonpickle
|
|
6
|
-
import logging
|
|
7
2
|
|
|
8
|
-
|
|
3
|
+
import logging
|
|
4
|
+
from abc import abstractmethod
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
from pathlib import Path
|
|
9
7
|
from typing import (
|
|
10
8
|
Tuple,
|
|
11
9
|
List,
|
|
@@ -18,25 +16,30 @@ from typing import (
|
|
|
18
16
|
Set,
|
|
19
17
|
DefaultDict,
|
|
20
18
|
cast,
|
|
19
|
+
Type,
|
|
20
|
+
Callable,
|
|
21
|
+
ClassVar,
|
|
21
22
|
)
|
|
23
|
+
|
|
22
24
|
import numpy as np
|
|
25
|
+
from tqdm import tqdm
|
|
23
26
|
|
|
24
|
-
from rasa.core.featurizers.single_state_featurizer import SingleStateFeaturizer
|
|
25
|
-
from rasa.core.featurizers.precomputation import MessageContainerForCoreFeaturization
|
|
26
|
-
from rasa.core.exceptions import InvalidTrackerFeaturizerUsageError
|
|
27
27
|
import rasa.shared.core.trackers
|
|
28
28
|
import rasa.shared.utils.io
|
|
29
|
-
from rasa.
|
|
30
|
-
from rasa.
|
|
31
|
-
from rasa.
|
|
32
|
-
from rasa.shared.core.domain import State, Domain
|
|
33
|
-
from rasa.shared.core.events import Event, ActionExecuted, UserUttered
|
|
29
|
+
from rasa.core.exceptions import InvalidTrackerFeaturizerUsageError
|
|
30
|
+
from rasa.core.featurizers.precomputation import MessageContainerForCoreFeaturization
|
|
31
|
+
from rasa.core.featurizers.single_state_featurizer import SingleStateFeaturizer
|
|
34
32
|
from rasa.shared.core.constants import (
|
|
35
33
|
USER,
|
|
36
34
|
ACTION_UNLIKELY_INTENT_NAME,
|
|
37
35
|
PREVIOUS_ACTION,
|
|
38
36
|
)
|
|
37
|
+
from rasa.shared.core.domain import State, Domain
|
|
38
|
+
from rasa.shared.core.events import Event, ActionExecuted, UserUttered
|
|
39
|
+
from rasa.shared.core.trackers import DialogueStateTracker
|
|
39
40
|
from rasa.shared.exceptions import RasaException
|
|
41
|
+
from rasa.shared.nlu.constants import TEXT, INTENT, ENTITIES, ACTION_NAME
|
|
42
|
+
from rasa.shared.nlu.training_data.features import Features
|
|
40
43
|
from rasa.utils.tensorflow.constants import LABEL_PAD_ID
|
|
41
44
|
from rasa.utils.tensorflow.model_data import ragged_array_to_ndarray
|
|
42
45
|
|
|
@@ -64,6 +67,10 @@ class InvalidStory(RasaException):
|
|
|
64
67
|
class TrackerFeaturizer:
|
|
65
68
|
"""Base class for actual tracker featurizers."""
|
|
66
69
|
|
|
70
|
+
# Class registry to store all subclasses
|
|
71
|
+
_registry: ClassVar[Dict[str, Type["TrackerFeaturizer"]]] = {}
|
|
72
|
+
_featurizer_type: str = "TrackerFeaturizer"
|
|
73
|
+
|
|
67
74
|
def __init__(
|
|
68
75
|
self, state_featurizer: Optional[SingleStateFeaturizer] = None
|
|
69
76
|
) -> None:
|
|
@@ -74,6 +81,36 @@ class TrackerFeaturizer:
|
|
|
74
81
|
"""
|
|
75
82
|
self.state_featurizer = state_featurizer
|
|
76
83
|
|
|
84
|
+
@classmethod
|
|
85
|
+
def register(cls, featurizer_type: str) -> Callable:
|
|
86
|
+
"""Decorator to register featurizer subclasses."""
|
|
87
|
+
|
|
88
|
+
def wrapper(subclass: Type["TrackerFeaturizer"]) -> Type["TrackerFeaturizer"]:
|
|
89
|
+
cls._registry[featurizer_type] = subclass
|
|
90
|
+
# Store the type identifier in the class for serialization
|
|
91
|
+
subclass._featurizer_type = featurizer_type
|
|
92
|
+
return subclass
|
|
93
|
+
|
|
94
|
+
return wrapper
|
|
95
|
+
|
|
96
|
+
@classmethod
|
|
97
|
+
def from_dict(cls, data: Dict[str, Any]) -> "TrackerFeaturizer":
|
|
98
|
+
"""Create featurizer instance from dictionary."""
|
|
99
|
+
featurizer_type = data.pop("type")
|
|
100
|
+
|
|
101
|
+
if featurizer_type not in cls._registry:
|
|
102
|
+
raise ValueError(f"Unknown featurizer type: {featurizer_type}")
|
|
103
|
+
|
|
104
|
+
# Get the correct subclass and instantiate it
|
|
105
|
+
subclass = cls._registry[featurizer_type]
|
|
106
|
+
return subclass.create_from_dict(data)
|
|
107
|
+
|
|
108
|
+
@classmethod
|
|
109
|
+
@abstractmethod
|
|
110
|
+
def create_from_dict(cls, data: Dict[str, Any]) -> "TrackerFeaturizer":
|
|
111
|
+
"""Each subclass must implement its own creation from dict method."""
|
|
112
|
+
pass
|
|
113
|
+
|
|
77
114
|
@staticmethod
|
|
78
115
|
def _create_states(
|
|
79
116
|
tracker: DialogueStateTracker,
|
|
@@ -465,9 +502,7 @@ class TrackerFeaturizer:
|
|
|
465
502
|
self.state_featurizer.entity_tag_specs = []
|
|
466
503
|
|
|
467
504
|
# noinspection PyTypeChecker
|
|
468
|
-
rasa.shared.utils.io.
|
|
469
|
-
str(jsonpickle.encode(self)), featurizer_file
|
|
470
|
-
)
|
|
505
|
+
rasa.shared.utils.io.dump_obj_as_json_to_file(featurizer_file, self.to_dict())
|
|
471
506
|
|
|
472
507
|
@staticmethod
|
|
473
508
|
def load(path: Union[Text, Path]) -> Optional[TrackerFeaturizer]:
|
|
@@ -481,7 +516,17 @@ class TrackerFeaturizer:
|
|
|
481
516
|
"""
|
|
482
517
|
featurizer_file = Path(path) / FEATURIZER_FILE
|
|
483
518
|
if featurizer_file.is_file():
|
|
484
|
-
|
|
519
|
+
data = rasa.shared.utils.io.read_json_file(featurizer_file)
|
|
520
|
+
|
|
521
|
+
if "type" not in data:
|
|
522
|
+
logger.error(
|
|
523
|
+
f"Couldn't load featurizer for policy. "
|
|
524
|
+
f"File '{featurizer_file}' does not contain all "
|
|
525
|
+
f"necessary information. 'type' is missing."
|
|
526
|
+
)
|
|
527
|
+
return None
|
|
528
|
+
|
|
529
|
+
return TrackerFeaturizer.from_dict(data)
|
|
485
530
|
|
|
486
531
|
logger.error(
|
|
487
532
|
f"Couldn't load featurizer for policy. "
|
|
@@ -508,7 +553,16 @@ class TrackerFeaturizer:
|
|
|
508
553
|
)
|
|
509
554
|
]
|
|
510
555
|
|
|
556
|
+
def to_dict(self) -> Dict[str, Any]:
|
|
557
|
+
return {
|
|
558
|
+
"type": self.__class__._featurizer_type,
|
|
559
|
+
"state_featurizer": (
|
|
560
|
+
self.state_featurizer.to_dict() if self.state_featurizer else None
|
|
561
|
+
),
|
|
562
|
+
}
|
|
563
|
+
|
|
511
564
|
|
|
565
|
+
@TrackerFeaturizer.register("FullDialogueTrackerFeaturizer")
|
|
512
566
|
class FullDialogueTrackerFeaturizer(TrackerFeaturizer):
|
|
513
567
|
"""Creates full dialogue training data for time distributed architectures.
|
|
514
568
|
|
|
@@ -646,7 +700,20 @@ class FullDialogueTrackerFeaturizer(TrackerFeaturizer):
|
|
|
646
700
|
|
|
647
701
|
return trackers_as_states
|
|
648
702
|
|
|
703
|
+
def to_dict(self) -> Dict[str, Any]:
|
|
704
|
+
return super().to_dict()
|
|
649
705
|
|
|
706
|
+
@classmethod
|
|
707
|
+
def create_from_dict(cls, data: Dict[str, Any]) -> "FullDialogueTrackerFeaturizer":
|
|
708
|
+
state_featurizer = SingleStateFeaturizer.create_from_dict(
|
|
709
|
+
data["state_featurizer"]
|
|
710
|
+
)
|
|
711
|
+
return cls(
|
|
712
|
+
state_featurizer,
|
|
713
|
+
)
|
|
714
|
+
|
|
715
|
+
|
|
716
|
+
@TrackerFeaturizer.register("MaxHistoryTrackerFeaturizer")
|
|
650
717
|
class MaxHistoryTrackerFeaturizer(TrackerFeaturizer):
|
|
651
718
|
"""Truncates the tracker history into `max_history` long sequences.
|
|
652
719
|
|
|
@@ -884,7 +951,25 @@ class MaxHistoryTrackerFeaturizer(TrackerFeaturizer):
|
|
|
884
951
|
|
|
885
952
|
return trackers_as_states
|
|
886
953
|
|
|
954
|
+
def to_dict(self) -> Dict[str, Any]:
|
|
955
|
+
data = super().to_dict()
|
|
956
|
+
data.update(
|
|
957
|
+
{
|
|
958
|
+
"remove_duplicates": self.remove_duplicates,
|
|
959
|
+
"max_history": self.max_history,
|
|
960
|
+
}
|
|
961
|
+
)
|
|
962
|
+
return data
|
|
963
|
+
|
|
964
|
+
@classmethod
|
|
965
|
+
def create_from_dict(cls, data: Dict[str, Any]) -> "MaxHistoryTrackerFeaturizer":
|
|
966
|
+
state_featurizer = SingleStateFeaturizer.create_from_dict(
|
|
967
|
+
data["state_featurizer"]
|
|
968
|
+
)
|
|
969
|
+
return cls(state_featurizer, data["max_history"], data["remove_duplicates"])
|
|
887
970
|
|
|
971
|
+
|
|
972
|
+
@TrackerFeaturizer.register("IntentMaxHistoryTrackerFeaturizer")
|
|
888
973
|
class IntentMaxHistoryTrackerFeaturizer(MaxHistoryTrackerFeaturizer):
|
|
889
974
|
"""Truncates the tracker history into `max_history` long sequences.
|
|
890
975
|
|
|
@@ -1159,6 +1244,18 @@ class IntentMaxHistoryTrackerFeaturizer(MaxHistoryTrackerFeaturizer):
|
|
|
1159
1244
|
|
|
1160
1245
|
return trackers_as_states
|
|
1161
1246
|
|
|
1247
|
+
def to_dict(self) -> Dict[str, Any]:
|
|
1248
|
+
return super().to_dict()
|
|
1249
|
+
|
|
1250
|
+
@classmethod
|
|
1251
|
+
def create_from_dict(
|
|
1252
|
+
cls, data: Dict[str, Any]
|
|
1253
|
+
) -> "IntentMaxHistoryTrackerFeaturizer":
|
|
1254
|
+
state_featurizer = SingleStateFeaturizer.create_from_dict(
|
|
1255
|
+
data["state_featurizer"]
|
|
1256
|
+
)
|
|
1257
|
+
return cls(state_featurizer, data["max_history"], data["remove_duplicates"])
|
|
1258
|
+
|
|
1162
1259
|
|
|
1163
1260
|
def _is_prev_action_unlikely_intent_in_state(state: State) -> bool:
|
|
1164
1261
|
prev_action_name = state.get(PREVIOUS_ACTION, {}).get(ACTION_NAME)
|
|
@@ -1,13 +1,11 @@
|
|
|
1
1
|
from typing import Any, Dict, Optional, Text
|
|
2
2
|
|
|
3
|
-
import os
|
|
4
3
|
import structlog
|
|
5
4
|
from jinja2 import Template
|
|
6
|
-
|
|
7
5
|
from rasa import telemetry
|
|
8
6
|
from rasa.core.nlg.response import TemplatedNaturalLanguageGenerator
|
|
7
|
+
from rasa.core.nlg.summarize import summarize_conversation
|
|
9
8
|
from rasa.shared.constants import (
|
|
10
|
-
LLM_API_HEALTH_CHECK_ENV_VAR,
|
|
11
9
|
LLM_CONFIG_KEY,
|
|
12
10
|
MODEL_CONFIG_KEY,
|
|
13
11
|
MODEL_NAME_CONFIG_KEY,
|
|
@@ -15,27 +13,25 @@ from rasa.shared.constants import (
|
|
|
15
13
|
PROVIDER_CONFIG_KEY,
|
|
16
14
|
OPENAI_PROVIDER,
|
|
17
15
|
TIMEOUT_CONFIG_KEY,
|
|
16
|
+
MODEL_GROUP_CONFIG_KEY,
|
|
18
17
|
)
|
|
19
18
|
from rasa.shared.core.domain import KEY_RESPONSES_TEXT, Domain
|
|
20
19
|
from rasa.shared.core.events import BotUttered, UserUttered
|
|
21
20
|
from rasa.shared.core.trackers import DialogueStateTracker
|
|
21
|
+
from rasa.shared.utils.health_check.llm_health_check_mixin import LLMHealthCheckMixin
|
|
22
22
|
from rasa.shared.utils.llm import (
|
|
23
23
|
DEFAULT_OPENAI_GENERATE_MODEL_NAME,
|
|
24
24
|
DEFAULT_OPENAI_MAX_GENERATED_TOKENS,
|
|
25
25
|
USER,
|
|
26
26
|
combine_custom_and_default_config,
|
|
27
27
|
get_prompt_template,
|
|
28
|
-
llm_api_health_check,
|
|
29
28
|
llm_factory,
|
|
30
|
-
|
|
29
|
+
resolve_model_client_config,
|
|
31
30
|
)
|
|
32
|
-
from rasa.utils.endpoints import EndpointConfig
|
|
33
31
|
from rasa.shared.utils.llm import (
|
|
34
32
|
tracker_as_readable_transcript,
|
|
35
33
|
)
|
|
36
|
-
|
|
37
|
-
from rasa.core.nlg.summarize import summarize_conversation
|
|
38
|
-
|
|
34
|
+
from rasa.utils.endpoints import EndpointConfig
|
|
39
35
|
from rasa.utils.log_utils import log_llm
|
|
40
36
|
|
|
41
37
|
structlogger = structlog.get_logger()
|
|
@@ -47,6 +43,8 @@ RESPONSE_REPHRASING_TEMPLATE_KEY = "rephrase_prompt"
|
|
|
47
43
|
RESPONSE_SUMMARISE_CONVERSATION_KEY = "summarize_conversation"
|
|
48
44
|
|
|
49
45
|
DEFAULT_REPHRASE_ALL = False
|
|
46
|
+
DEFAULT_SUMMARIZE_HISTORY = True
|
|
47
|
+
DEFAULT_MAX_HISTORICAL_TURNS = 5
|
|
50
48
|
|
|
51
49
|
DEFAULT_LLM_CONFIG = {
|
|
52
50
|
PROVIDER_CONFIG_KEY: OPENAI_PROVIDER,
|
|
@@ -71,7 +69,9 @@ Suggested AI Response: {{suggested_response}}
|
|
|
71
69
|
Rephrased AI Response:"""
|
|
72
70
|
|
|
73
71
|
|
|
74
|
-
class ContextualResponseRephraser(
|
|
72
|
+
class ContextualResponseRephraser(
|
|
73
|
+
LLMHealthCheckMixin, TemplatedNaturalLanguageGenerator
|
|
74
|
+
):
|
|
75
75
|
"""Generates responses based on modified templates.
|
|
76
76
|
|
|
77
77
|
The templates are filled with the entities and slots that are available in the
|
|
@@ -105,18 +105,24 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
|
|
|
105
105
|
self.trace_prompt_tokens = self.nlg_endpoint.kwargs.get(
|
|
106
106
|
"trace_prompt_tokens", False
|
|
107
107
|
)
|
|
108
|
-
|
|
108
|
+
self.summarize_history = self.nlg_endpoint.kwargs.get(
|
|
109
|
+
"summarize_history", DEFAULT_SUMMARIZE_HISTORY
|
|
110
|
+
)
|
|
111
|
+
self.max_historical_turns = self.nlg_endpoint.kwargs.get(
|
|
112
|
+
"max_historical_turns", DEFAULT_MAX_HISTORICAL_TURNS
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
self.llm_config = resolve_model_client_config(
|
|
109
116
|
self.nlg_endpoint.kwargs.get(LLM_CONFIG_KEY),
|
|
117
|
+
ContextualResponseRephraser.__name__,
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
self.perform_llm_health_check(
|
|
121
|
+
self.llm_config,
|
|
110
122
|
DEFAULT_LLM_CONFIG,
|
|
111
123
|
"contextual_response_rephraser.init",
|
|
112
124
|
ContextualResponseRephraser.__name__,
|
|
113
125
|
)
|
|
114
|
-
if os.getenv(LLM_API_HEALTH_CHECK_ENV_VAR, "true").lower() == "true":
|
|
115
|
-
llm_api_health_check(
|
|
116
|
-
llm_client,
|
|
117
|
-
"contextual_response_rephraser.init",
|
|
118
|
-
ContextualResponseRephraser.__name__,
|
|
119
|
-
)
|
|
120
126
|
|
|
121
127
|
def _last_message_if_human(self, tracker: DialogueStateTracker) -> Optional[str]:
|
|
122
128
|
"""Returns the latest message from the tracker.
|
|
@@ -145,9 +151,7 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
|
|
|
145
151
|
Returns:
|
|
146
152
|
generated text
|
|
147
153
|
"""
|
|
148
|
-
llm = llm_factory(
|
|
149
|
-
self.nlg_endpoint.kwargs.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG
|
|
150
|
-
)
|
|
154
|
+
llm = llm_factory(self.llm_config, DEFAULT_LLM_CONFIG)
|
|
151
155
|
|
|
152
156
|
try:
|
|
153
157
|
llm_response = await llm.acompletion(prompt)
|
|
@@ -161,7 +165,7 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
|
|
|
161
165
|
def llm_property(self, prop: str) -> Optional[str]:
|
|
162
166
|
"""Returns a property of the LLM provider."""
|
|
163
167
|
return combine_custom_and_default_config(
|
|
164
|
-
self.
|
|
168
|
+
self.llm_config, DEFAULT_LLM_CONFIG
|
|
165
169
|
).get(prop)
|
|
166
170
|
|
|
167
171
|
def custom_prompt_template(self, prompt_template: str) -> Optional[str]:
|
|
@@ -194,9 +198,7 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
|
|
|
194
198
|
Returns:
|
|
195
199
|
The history for the prompt.
|
|
196
200
|
"""
|
|
197
|
-
llm = llm_factory(
|
|
198
|
-
self.nlg_endpoint.kwargs.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG
|
|
199
|
-
)
|
|
201
|
+
llm = llm_factory(self.llm_config, DEFAULT_LLM_CONFIG)
|
|
200
202
|
return await summarize_conversation(tracker, llm, max_turns=5)
|
|
201
203
|
|
|
202
204
|
async def rephrase(
|
|
@@ -220,18 +222,17 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
|
|
|
220
222
|
prompt_template_text = self._template_for_response_rephrasing(response)
|
|
221
223
|
|
|
222
224
|
# Retrieve inputs for the dynamic prompt
|
|
223
|
-
transcript = tracker_as_readable_transcript(tracker, max_turns=5)
|
|
224
225
|
latest_message = self._last_message_if_human(tracker)
|
|
225
226
|
current_input = f"{USER}: {latest_message}" if latest_message else ""
|
|
226
227
|
|
|
227
228
|
# Only summarise conversation history if flagged
|
|
228
|
-
|
|
229
|
-
RESPONSE_SUMMARISE_CONVERSATION_KEY, False
|
|
230
|
-
)
|
|
231
|
-
if summarize_conversation_flag:
|
|
229
|
+
if self.summarize_history:
|
|
232
230
|
history = await self._create_history(tracker)
|
|
233
231
|
else:
|
|
234
|
-
history
|
|
232
|
+
# make sure the transcript/history contains the last user utterance
|
|
233
|
+
max_turns = max(self.max_historical_turns, 1)
|
|
234
|
+
history = tracker_as_readable_transcript(tracker, max_turns=max_turns)
|
|
235
|
+
# the history already contains the current input
|
|
235
236
|
current_input = ""
|
|
236
237
|
|
|
237
238
|
prompt = Template(prompt_template_text).render(
|
|
@@ -252,6 +253,7 @@ class ContextualResponseRephraser(TemplatedNaturalLanguageGenerator):
|
|
|
252
253
|
llm_type=self.llm_property(PROVIDER_CONFIG_KEY),
|
|
253
254
|
llm_model=self.llm_property(MODEL_CONFIG_KEY)
|
|
254
255
|
or self.llm_property(MODEL_NAME_CONFIG_KEY),
|
|
256
|
+
llm_model_group_id=self.llm_property(MODEL_GROUP_CONFIG_KEY),
|
|
255
257
|
)
|
|
256
258
|
if not (updated_text := await self._generate_llm_response(prompt)):
|
|
257
259
|
# If the LLM fails to generate a response, we
|
rasa/core/persistor.py
CHANGED
|
@@ -30,7 +30,7 @@ from rasa.shared.utils.io import raise_warning
|
|
|
30
30
|
|
|
31
31
|
if TYPE_CHECKING:
|
|
32
32
|
from azure.storage.blob import ContainerClient
|
|
33
|
-
import
|
|
33
|
+
from botocore.exceptions import ClientError
|
|
34
34
|
|
|
35
35
|
structlogger = structlog.get_logger()
|
|
36
36
|
|
|
@@ -233,13 +233,13 @@ class AWSPersistor(Persistor):
|
|
|
233
233
|
def _ensure_bucket_exists(
|
|
234
234
|
self, bucket_name: Text, region_name: Optional[Text] = None
|
|
235
235
|
) -> None:
|
|
236
|
-
import
|
|
236
|
+
from botocore import exceptions
|
|
237
237
|
|
|
238
238
|
# noinspection PyUnresolvedReferences
|
|
239
239
|
try:
|
|
240
240
|
self.s3.meta.client.head_bucket(Bucket=bucket_name)
|
|
241
|
-
except
|
|
242
|
-
if self.
|
|
241
|
+
except exceptions.ClientError as exc:
|
|
242
|
+
if self._error_code(exc) == HTTP_STATUS_FORBIDDEN:
|
|
243
243
|
log = (
|
|
244
244
|
f"Access to the specified bucket '{bucket_name}' is forbidden. "
|
|
245
245
|
"Please make sure you have the necessary "
|
|
@@ -251,7 +251,7 @@ class AWSPersistor(Persistor):
|
|
|
251
251
|
event_info=log,
|
|
252
252
|
)
|
|
253
253
|
raise RasaException(log)
|
|
254
|
-
elif self.
|
|
254
|
+
elif self._error_code(exc) == HTTP_STATUS_NOT_FOUND:
|
|
255
255
|
log = (
|
|
256
256
|
f"The specified bucket '{bucket_name}' does not exist. "
|
|
257
257
|
"Please make sure to create the bucket first."
|
|
@@ -264,7 +264,7 @@ class AWSPersistor(Persistor):
|
|
|
264
264
|
raise RasaException(log)
|
|
265
265
|
|
|
266
266
|
@staticmethod
|
|
267
|
-
def
|
|
267
|
+
def _error_code(e: "ClientError") -> int:
|
|
268
268
|
return int(e.response["Error"]["Code"])
|
|
269
269
|
|
|
270
270
|
def _persist_tar(self, file_key: Text, tar_path: Text) -> None:
|
|
@@ -274,26 +274,48 @@ class AWSPersistor(Persistor):
|
|
|
274
274
|
|
|
275
275
|
def _retrieve_tar(self, model_path: Text) -> None:
|
|
276
276
|
"""Downloads a model that has previously been persisted to s3."""
|
|
277
|
-
import
|
|
277
|
+
from botocore import exceptions
|
|
278
278
|
|
|
279
279
|
target_filename = os.path.basename(model_path)
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
280
|
+
bucket_objects = list(self.bucket.objects.all())
|
|
281
|
+
|
|
282
|
+
model_found = False
|
|
283
|
+
|
|
284
|
+
log = (
|
|
285
|
+
f"Model '{target_filename}' not found in the specified bucket "
|
|
286
|
+
f"'{self.bucket_name}'. Please make sure the model exists "
|
|
287
|
+
f"in the bucket."
|
|
288
|
+
)
|
|
289
|
+
|
|
290
|
+
for obj in bucket_objects:
|
|
291
|
+
if model_path not in obj.key:
|
|
292
|
+
continue
|
|
293
|
+
structlogger.debug(
|
|
294
|
+
"aws_persistor.retrieve_tar.object_found", object_key=obj.key
|
|
295
|
+
)
|
|
296
|
+
|
|
297
|
+
try:
|
|
298
|
+
with open(target_filename, "wb") as f:
|
|
299
|
+
self.bucket.download_fileobj(obj.key, f)
|
|
300
|
+
model_found = True
|
|
301
|
+
break
|
|
302
|
+
except exceptions.ClientError as exc:
|
|
303
|
+
if self._error_code(exc) == HTTP_STATUS_NOT_FOUND:
|
|
304
|
+
structlogger.error(
|
|
305
|
+
"aws_persistor.retrieve_tar.model_not_found",
|
|
306
|
+
bucket_name=self.bucket_name,
|
|
307
|
+
target_filename=target_filename,
|
|
308
|
+
event_info=log,
|
|
309
|
+
)
|
|
310
|
+
raise ModelNotFound() from exc
|
|
311
|
+
if not model_found:
|
|
312
|
+
structlogger.error(
|
|
313
|
+
"aws_persistor.retrieve_tar.model_not_found",
|
|
314
|
+
bucket_name=self.bucket_name,
|
|
315
|
+
target_filename=target_filename,
|
|
316
|
+
event_info=log,
|
|
317
|
+
)
|
|
318
|
+
raise ModelNotFound()
|
|
297
319
|
|
|
298
320
|
|
|
299
321
|
class GCSPersistor(Persistor):
|
|
@@ -322,7 +344,7 @@ class GCSPersistor(Persistor):
|
|
|
322
344
|
|
|
323
345
|
try:
|
|
324
346
|
self.storage_client.get_bucket(bucket_name)
|
|
325
|
-
except auth_exceptions.GoogleAuthError as
|
|
347
|
+
except auth_exceptions.GoogleAuthError as exc:
|
|
326
348
|
log = (
|
|
327
349
|
f"An error occurred while authenticating with Google Cloud "
|
|
328
350
|
f"Storage. Please make sure you have the necessary credentials "
|
|
@@ -333,8 +355,8 @@ class GCSPersistor(Persistor):
|
|
|
333
355
|
bucket_name=bucket_name,
|
|
334
356
|
event_info=log,
|
|
335
357
|
)
|
|
336
|
-
raise RasaException(log) from
|
|
337
|
-
except exceptions.NotFound as
|
|
358
|
+
raise RasaException(log) from exc
|
|
359
|
+
except exceptions.NotFound as exc:
|
|
338
360
|
log = (
|
|
339
361
|
f"The specified Google Cloud Storage bucket '{bucket_name}' "
|
|
340
362
|
f"does not exist. Please make sure to create the bucket first or "
|
|
@@ -345,20 +367,20 @@ class GCSPersistor(Persistor):
|
|
|
345
367
|
bucket_name=bucket_name,
|
|
346
368
|
event_info=log,
|
|
347
369
|
)
|
|
348
|
-
raise RasaException(log) from
|
|
349
|
-
except exceptions.Forbidden as
|
|
370
|
+
raise RasaException(log) from exc
|
|
371
|
+
except exceptions.Forbidden as exc:
|
|
350
372
|
log = (
|
|
351
373
|
f"Access to the specified Google Cloud storage bucket '{bucket_name}' "
|
|
352
374
|
f"is forbidden. Please make sure you have the necessary "
|
|
353
|
-
f"
|
|
375
|
+
f"permissions to access the bucket. "
|
|
354
376
|
)
|
|
355
377
|
structlogger.error(
|
|
356
378
|
"gcp_persistor.ensure_bucket_exists.bucket_access_forbidden",
|
|
357
379
|
bucket_name=bucket_name,
|
|
358
380
|
event_info=log,
|
|
359
381
|
)
|
|
360
|
-
raise RasaException(log) from
|
|
361
|
-
except ValueError as
|
|
382
|
+
raise RasaException(log) from exc
|
|
383
|
+
except ValueError as exc:
|
|
362
384
|
# bucket_name is None
|
|
363
385
|
log = (
|
|
364
386
|
"The specified Google Cloud Storage bucket name is None. Please "
|
|
@@ -368,7 +390,7 @@ class GCSPersistor(Persistor):
|
|
|
368
390
|
"gcp_persistor.ensure_bucket_exists.bucket_name_none",
|
|
369
391
|
event_info=log,
|
|
370
392
|
)
|
|
371
|
-
raise RasaException(log) from
|
|
393
|
+
raise RasaException(log) from exc
|
|
372
394
|
|
|
373
395
|
def _persist_tar(self, file_key: Text, tar_path: Text) -> None:
|
|
374
396
|
"""Uploads a model persisted in the `target_dir` to GCS."""
|
|
@@ -382,7 +404,7 @@ class GCSPersistor(Persistor):
|
|
|
382
404
|
blob = self.bucket.blob(target_filename)
|
|
383
405
|
try:
|
|
384
406
|
blob.download_to_filename(target_filename)
|
|
385
|
-
except exceptions.NotFound as
|
|
407
|
+
except exceptions.NotFound as exc:
|
|
386
408
|
log = (
|
|
387
409
|
f"Model '{target_filename}' not found in the specified bucket "
|
|
388
410
|
f"'{self.bucket_name}'. Please make sure the model exists "
|
|
@@ -394,7 +416,7 @@ class GCSPersistor(Persistor):
|
|
|
394
416
|
target_filename=target_filename,
|
|
395
417
|
event_info=log,
|
|
396
418
|
)
|
|
397
|
-
raise ModelNotFound() from
|
|
419
|
+
raise ModelNotFound() from exc
|
|
398
420
|
|
|
399
421
|
|
|
400
422
|
class AzurePersistor(Persistor):
|
|
@@ -440,8 +462,33 @@ class AzurePersistor(Persistor):
|
|
|
440
462
|
|
|
441
463
|
def _retrieve_tar(self, target_filename: Text) -> None:
|
|
442
464
|
"""Downloads a model that has previously been persisted to Azure."""
|
|
443
|
-
|
|
465
|
+
try:
|
|
466
|
+
blob_list = self._container_client().list_blobs()
|
|
467
|
+
|
|
468
|
+
for blob in blob_list:
|
|
469
|
+
if target_filename not in blob.name:
|
|
470
|
+
continue
|
|
471
|
+
|
|
472
|
+
structlogger.debug(
|
|
473
|
+
"azure_persistor.retrieve_tar.blob_found", blob_name=blob.name
|
|
474
|
+
)
|
|
444
475
|
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
476
|
+
with open(target_filename, "wb") as model_file:
|
|
477
|
+
blob_client = self._container_client().get_blob_client(blob.name)
|
|
478
|
+
download_stream = blob_client.download_blob()
|
|
479
|
+
model_file.write(download_stream.readall())
|
|
480
|
+
except Exception as exc:
|
|
481
|
+
log = (
|
|
482
|
+
f"An exception occurred while trying to download "
|
|
483
|
+
f"the model '{target_filename}' in the specified container "
|
|
484
|
+
f"'{self.container_name}'. Please make sure the model exists "
|
|
485
|
+
f"in the container."
|
|
486
|
+
)
|
|
487
|
+
structlogger.error(
|
|
488
|
+
"azure_persistor.retrieve_tar.model_download_error",
|
|
489
|
+
container_name=self.container_name,
|
|
490
|
+
target_filename=target_filename,
|
|
491
|
+
event_info=log,
|
|
492
|
+
exception=exc,
|
|
493
|
+
)
|
|
494
|
+
raise ModelNotFound() from exc
|