zenml-nightly 0.68.1.dev20241105__py3-none-any.whl → 0.68.1.dev20241107__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/artifacts/{load_directory_materializer.py → preexisting_data_materializer.py} +8 -9
- zenml/artifacts/utils.py +121 -59
- zenml/constants.py +1 -0
- zenml/integrations/bentoml/materializers/bentoml_bento_materializer.py +19 -31
- zenml/integrations/evidently/__init__.py +1 -1
- zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py +8 -12
- zenml/integrations/huggingface/materializers/huggingface_pt_model_materializer.py +17 -18
- zenml/integrations/huggingface/materializers/huggingface_t5_materializer.py +2 -5
- zenml/integrations/huggingface/materializers/huggingface_tf_model_materializer.py +17 -18
- zenml/integrations/huggingface/materializers/huggingface_tokenizer_materializer.py +2 -3
- zenml/integrations/lightgbm/materializers/lightgbm_booster_materializer.py +8 -15
- zenml/integrations/lightgbm/materializers/lightgbm_dataset_materializer.py +11 -16
- zenml/integrations/pillow/materializers/pillow_image_materializer.py +17 -20
- zenml/integrations/polars/materializers/dataframe_materializer.py +26 -39
- zenml/integrations/pycaret/materializers/model_materializer.py +7 -22
- zenml/integrations/tensorflow/materializers/keras_materializer.py +11 -22
- zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py +8 -15
- zenml/integrations/vllm/services/vllm_deployment.py +16 -7
- zenml/integrations/whylogs/materializers/whylogs_materializer.py +11 -18
- zenml/integrations/xgboost/materializers/xgboost_booster_materializer.py +11 -22
- zenml/integrations/xgboost/materializers/xgboost_dmatrix_materializer.py +10 -19
- zenml/materializers/base_materializer.py +68 -1
- zenml/orchestrators/step_runner.py +17 -11
- zenml/stack/flavor.py +9 -5
- zenml/steps/step_context.py +2 -0
- zenml/utils/callback_registry.py +71 -0
- zenml/zen_server/rbac/endpoint_utils.py +43 -1
- zenml/zen_server/routers/artifact_version_endpoints.py +27 -1
- zenml/zen_stores/rest_zen_store.py +52 -0
- zenml/zen_stores/sql_zen_store.py +16 -0
- zenml/zen_stores/zen_store_interface.py +13 -0
- {zenml_nightly-0.68.1.dev20241105.dist-info → zenml_nightly-0.68.1.dev20241107.dist-info}/METADATA +1 -1
- {zenml_nightly-0.68.1.dev20241105.dist-info → zenml_nightly-0.68.1.dev20241107.dist-info}/RECORD +37 -36
- {zenml_nightly-0.68.1.dev20241105.dist-info → zenml_nightly-0.68.1.dev20241107.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.68.1.dev20241105.dist-info → zenml_nightly-0.68.1.dev20241107.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.68.1.dev20241105.dist-info → zenml_nightly-0.68.1.dev20241107.dist-info}/entry_points.txt +0 -0
zenml/VERSION
CHANGED
@@ -1 +1 @@
|
|
1
|
-
0.68.1.
|
1
|
+
0.68.1.dev20241107
|
@@ -14,7 +14,6 @@
|
|
14
14
|
"""Only-load materializer for directories."""
|
15
15
|
|
16
16
|
import os
|
17
|
-
import tempfile
|
18
17
|
from pathlib import Path
|
19
18
|
from typing import Any, ClassVar, Tuple, Type
|
20
19
|
|
@@ -46,14 +45,14 @@ class PreexistingDataMaterializer(BaseMaterializer):
|
|
46
45
|
Returns:
|
47
46
|
Path to the local directory that contains the artifact files.
|
48
47
|
"""
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
48
|
+
with self.get_temporary_directory(delete_at_exit=False) as temp_dir:
|
49
|
+
if fileio.isdir(self.uri):
|
50
|
+
self._copy_directory(src=self.uri, dst=temp_dir)
|
51
|
+
return Path(temp_dir)
|
52
|
+
else:
|
53
|
+
dst = os.path.join(temp_dir, os.path.split(self.uri)[-1])
|
54
|
+
fileio.copy(src=self.uri, dst=dst)
|
55
|
+
return Path(dst)
|
57
56
|
|
58
57
|
def save(self, data: Any) -> None:
|
59
58
|
"""Store the directory in the artifact store.
|
zenml/artifacts/utils.py
CHANGED
@@ -31,7 +31,7 @@ from typing import (
|
|
31
31
|
from uuid import UUID, uuid4
|
32
32
|
|
33
33
|
from zenml.artifacts.artifact_config import ArtifactConfig
|
34
|
-
from zenml.artifacts.
|
34
|
+
from zenml.artifacts.preexisting_data_materializer import (
|
35
35
|
PreexistingDataMaterializer,
|
36
36
|
)
|
37
37
|
from zenml.client import Client
|
@@ -82,6 +82,114 @@ logger = get_logger(__name__)
|
|
82
82
|
# ----------
|
83
83
|
|
84
84
|
|
85
|
+
def _save_artifact_visualizations(
|
86
|
+
data: Any, materializer: "BaseMaterializer"
|
87
|
+
) -> List[ArtifactVisualizationRequest]:
|
88
|
+
"""Save artifact visualizations.
|
89
|
+
|
90
|
+
Args:
|
91
|
+
data: The data for which to save the visualizations.
|
92
|
+
materializer: The materializer that should be used to generate and
|
93
|
+
save the visualizations.
|
94
|
+
|
95
|
+
Returns:
|
96
|
+
List of requests for the saved visualizations.
|
97
|
+
"""
|
98
|
+
try:
|
99
|
+
visualizations = materializer.save_visualizations(data)
|
100
|
+
except Exception as e:
|
101
|
+
logger.warning("Failed to save artifact visualizations: %s", e)
|
102
|
+
return []
|
103
|
+
|
104
|
+
return [
|
105
|
+
ArtifactVisualizationRequest(
|
106
|
+
type=type,
|
107
|
+
uri=uri,
|
108
|
+
)
|
109
|
+
for uri, type in visualizations.items()
|
110
|
+
]
|
111
|
+
|
112
|
+
|
113
|
+
def _store_artifact_data_and_prepare_request(
|
114
|
+
data: Any,
|
115
|
+
name: str,
|
116
|
+
uri: str,
|
117
|
+
materializer_class: Type["BaseMaterializer"],
|
118
|
+
version: Optional[Union[int, str]] = None,
|
119
|
+
tags: Optional[List[str]] = None,
|
120
|
+
store_metadata: bool = True,
|
121
|
+
store_visualizations: bool = True,
|
122
|
+
has_custom_name: bool = True,
|
123
|
+
metadata: Optional[Dict[str, "MetadataType"]] = None,
|
124
|
+
) -> ArtifactVersionRequest:
|
125
|
+
"""Store artifact data and prepare a request to the server.
|
126
|
+
|
127
|
+
Args:
|
128
|
+
data: The artifact data.
|
129
|
+
name: The artifact name.
|
130
|
+
uri: The artifact URI.
|
131
|
+
materializer_class: The materializer class to use for storing the
|
132
|
+
artifact data.
|
133
|
+
version: The artifact version.
|
134
|
+
tags: Tags for the artifact version.
|
135
|
+
store_metadata: Whether to store metadata for the artifact version.
|
136
|
+
store_visualizations: Whether to store visualizations for the artifact
|
137
|
+
version.
|
138
|
+
has_custom_name: Whether the artifact has a custom name.
|
139
|
+
metadata: Metadata to store for the artifact version. This will be
|
140
|
+
ignored if `store_metadata` is set to `False`.
|
141
|
+
|
142
|
+
Returns:
|
143
|
+
Artifact version request for the artifact data that was stored.
|
144
|
+
"""
|
145
|
+
artifact_store = Client().active_stack.artifact_store
|
146
|
+
artifact_store.makedirs(uri)
|
147
|
+
|
148
|
+
materializer = materializer_class(uri=uri, artifact_store=artifact_store)
|
149
|
+
materializer.uri = materializer.uri.replace("\\", "/")
|
150
|
+
|
151
|
+
data_type = type(data)
|
152
|
+
materializer.validate_save_type_compatibility(data_type)
|
153
|
+
materializer.save(data)
|
154
|
+
|
155
|
+
visualizations = (
|
156
|
+
_save_artifact_visualizations(data=data, materializer=materializer)
|
157
|
+
if store_visualizations
|
158
|
+
else None
|
159
|
+
)
|
160
|
+
|
161
|
+
combined_metadata: Dict[str, "MetadataType"] = {}
|
162
|
+
if store_metadata:
|
163
|
+
try:
|
164
|
+
combined_metadata = materializer.extract_full_metadata(data)
|
165
|
+
except Exception as e:
|
166
|
+
logger.warning("Failed to extract materializer metadata: %s", e)
|
167
|
+
|
168
|
+
# Update with user metadata to potentially overwrite values coming from
|
169
|
+
# the materializer
|
170
|
+
combined_metadata.update(metadata or {})
|
171
|
+
|
172
|
+
artifact_version_request = ArtifactVersionRequest(
|
173
|
+
artifact_name=name,
|
174
|
+
version=version,
|
175
|
+
tags=tags,
|
176
|
+
type=materializer.ASSOCIATED_ARTIFACT_TYPE,
|
177
|
+
uri=materializer.uri,
|
178
|
+
materializer=source_utils.resolve(materializer.__class__),
|
179
|
+
data_type=source_utils.resolve(data_type),
|
180
|
+
user=Client().active_user.id,
|
181
|
+
workspace=Client().active_workspace.id,
|
182
|
+
artifact_store_id=artifact_store.id,
|
183
|
+
visualizations=visualizations,
|
184
|
+
has_custom_name=has_custom_name,
|
185
|
+
metadata=validate_metadata(combined_metadata)
|
186
|
+
if combined_metadata
|
187
|
+
else None,
|
188
|
+
)
|
189
|
+
|
190
|
+
return artifact_version_request
|
191
|
+
|
192
|
+
|
85
193
|
def save_artifact(
|
86
194
|
data: Any,
|
87
195
|
name: str,
|
@@ -89,13 +197,14 @@ def save_artifact(
|
|
89
197
|
tags: Optional[List[str]] = None,
|
90
198
|
extract_metadata: bool = True,
|
91
199
|
include_visualizations: bool = True,
|
92
|
-
has_custom_name: bool = True,
|
93
200
|
user_metadata: Optional[Dict[str, "MetadataType"]] = None,
|
94
201
|
materializer: Optional["MaterializerClassOrSource"] = None,
|
95
202
|
uri: Optional[str] = None,
|
96
203
|
is_model_artifact: bool = False,
|
97
204
|
is_deployment_artifact: bool = False,
|
205
|
+
# TODO: remove these once external artifact does not use this function anymore
|
98
206
|
manual_save: bool = True,
|
207
|
+
has_custom_name: bool = True,
|
99
208
|
) -> "ArtifactVersionResponse":
|
100
209
|
"""Upload and publish an artifact.
|
101
210
|
|
@@ -107,8 +216,6 @@ def save_artifact(
|
|
107
216
|
tags: Tags to associate with the artifact.
|
108
217
|
extract_metadata: If artifact metadata should be extracted and returned.
|
109
218
|
include_visualizations: If artifact visualizations should be generated.
|
110
|
-
has_custom_name: If the artifact name is custom and should be listed in
|
111
|
-
the dashboard "Artifacts" tab.
|
112
219
|
user_metadata: User-provided metadata to store with the artifact.
|
113
220
|
materializer: The materializer to use for saving the artifact to the
|
114
221
|
artifact store.
|
@@ -119,6 +226,8 @@ def save_artifact(
|
|
119
226
|
is_deployment_artifact: If the artifact is a deployment artifact.
|
120
227
|
manual_save: If this function is called manually and should therefore
|
121
228
|
link the artifact to the current step run.
|
229
|
+
has_custom_name: If the artifact name is custom and should be listed in
|
230
|
+
the dashboard "Artifacts" tab.
|
122
231
|
|
123
232
|
Returns:
|
124
233
|
The saved artifact response.
|
@@ -129,11 +238,8 @@ def save_artifact(
|
|
129
238
|
from zenml.utils import source_utils
|
130
239
|
|
131
240
|
client = Client()
|
132
|
-
|
133
|
-
# Get the current artifact store
|
134
241
|
artifact_store = client.active_stack.artifact_store
|
135
242
|
|
136
|
-
# Build and check the artifact URI
|
137
243
|
if not uri:
|
138
244
|
uri = os.path.join("custom_artifacts", name, str(uuid4()))
|
139
245
|
if not uri.startswith(artifact_store.path):
|
@@ -147,9 +253,7 @@ def save_artifact(
|
|
147
253
|
uri=uri,
|
148
254
|
name=name,
|
149
255
|
)
|
150
|
-
artifact_store.makedirs(uri)
|
151
256
|
|
152
|
-
# Find and initialize the right materializer class
|
153
257
|
if isinstance(materializer, type):
|
154
258
|
materializer_class = materializer
|
155
259
|
elif materializer:
|
@@ -158,60 +262,18 @@ def save_artifact(
|
|
158
262
|
)
|
159
263
|
else:
|
160
264
|
materializer_class = materializer_registry[type(data)]
|
161
|
-
materializer_object = materializer_class(uri)
|
162
|
-
|
163
|
-
# Force URIs to have forward slashes
|
164
|
-
materializer_object.uri = materializer_object.uri.replace("\\", "/")
|
165
|
-
|
166
|
-
# Save the artifact to the artifact store
|
167
|
-
data_type = type(data)
|
168
|
-
materializer_object.validate_save_type_compatibility(data_type)
|
169
|
-
materializer_object.save(data)
|
170
265
|
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
for vis_uri, vis_type in vis_data.items():
|
177
|
-
vis_model = ArtifactVisualizationRequest(
|
178
|
-
type=vis_type,
|
179
|
-
uri=vis_uri,
|
180
|
-
)
|
181
|
-
visualizations.append(vis_model)
|
182
|
-
except Exception as e:
|
183
|
-
logger.warning(
|
184
|
-
f"Failed to save visualization for output artifact '{name}': "
|
185
|
-
f"{e}"
|
186
|
-
)
|
187
|
-
|
188
|
-
# Save metadata of the artifact
|
189
|
-
artifact_metadata: Dict[str, "MetadataType"] = {}
|
190
|
-
if extract_metadata:
|
191
|
-
try:
|
192
|
-
artifact_metadata = materializer_object.extract_full_metadata(data)
|
193
|
-
artifact_metadata.update(user_metadata or {})
|
194
|
-
except Exception as e:
|
195
|
-
logger.warning(
|
196
|
-
f"Failed to extract metadata for output artifact '{name}': {e}"
|
197
|
-
)
|
198
|
-
|
199
|
-
artifact_version_request = ArtifactVersionRequest(
|
200
|
-
artifact_name=name,
|
266
|
+
artifact_version_request = _store_artifact_data_and_prepare_request(
|
267
|
+
data=data,
|
268
|
+
name=name,
|
269
|
+
uri=uri,
|
270
|
+
materializer_class=materializer_class,
|
201
271
|
version=version,
|
202
272
|
tags=tags,
|
203
|
-
|
204
|
-
|
205
|
-
materializer=source_utils.resolve(materializer_object.__class__),
|
206
|
-
data_type=source_utils.resolve(data_type),
|
207
|
-
user=Client().active_user.id,
|
208
|
-
workspace=Client().active_workspace.id,
|
209
|
-
artifact_store_id=artifact_store.id,
|
210
|
-
visualizations=visualizations,
|
273
|
+
store_metadata=extract_metadata,
|
274
|
+
store_visualizations=include_visualizations,
|
211
275
|
has_custom_name=has_custom_name,
|
212
|
-
metadata=
|
213
|
-
if artifact_metadata
|
214
|
-
else None,
|
276
|
+
metadata=user_metadata,
|
215
277
|
)
|
216
278
|
artifact_version = client.zen_store.create_artifact_version(
|
217
279
|
artifact_version=artifact_version_request
|
zenml/constants.py
CHANGED
@@ -338,6 +338,7 @@ ARTIFACTS = "/artifacts"
|
|
338
338
|
ARTIFACT_VERSIONS = "/artifact_versions"
|
339
339
|
ARTIFACT_VISUALIZATIONS = "/artifact_visualizations"
|
340
340
|
AUTH = "/auth"
|
341
|
+
BATCH = "/batch"
|
341
342
|
CODE_REFERENCES = "/code_references"
|
342
343
|
CODE_REPOSITORIES = "/code_repositories"
|
343
344
|
COMPONENT_TYPES = "/component-types"
|
@@ -14,7 +14,6 @@
|
|
14
14
|
"""Materializer for BentoML Bento objects."""
|
15
15
|
|
16
16
|
import os
|
17
|
-
import tempfile
|
18
17
|
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type
|
19
18
|
|
20
19
|
import bentoml
|
@@ -23,7 +22,6 @@ from bentoml.exceptions import BentoMLException
|
|
23
22
|
|
24
23
|
from zenml.enums import ArtifactType
|
25
24
|
from zenml.integrations.bentoml.constants import DEFAULT_BENTO_FILENAME
|
26
|
-
from zenml.io import fileio
|
27
25
|
from zenml.logger import get_logger
|
28
26
|
from zenml.materializers.base_materializer import BaseMaterializer
|
29
27
|
from zenml.utils import io_utils
|
@@ -49,23 +47,21 @@ class BentoMaterializer(BaseMaterializer):
|
|
49
47
|
Returns:
|
50
48
|
An bento.Bento object.
|
51
49
|
"""
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
imported_bento.save()
|
68
|
-
return imported_bento
|
50
|
+
with self.get_temporary_directory(delete_at_exit=False) as temp_dir:
|
51
|
+
# Copy from artifact store to temporary directory
|
52
|
+
io_utils.copy_dir(self.uri, temp_dir)
|
53
|
+
|
54
|
+
# Load the Bento from the temporary directory
|
55
|
+
imported_bento = Bento.import_from(
|
56
|
+
os.path.join(temp_dir, DEFAULT_BENTO_FILENAME)
|
57
|
+
)
|
58
|
+
|
59
|
+
# Try save the Bento to the local BentoML store
|
60
|
+
try:
|
61
|
+
_ = bentoml.get(imported_bento.tag)
|
62
|
+
except BentoMLException:
|
63
|
+
imported_bento.save()
|
64
|
+
return imported_bento
|
69
65
|
|
70
66
|
def save(self, bento: bento.Bento) -> None:
|
71
67
|
"""Write to artifact store.
|
@@ -73,18 +69,10 @@ class BentoMaterializer(BaseMaterializer):
|
|
73
69
|
Args:
|
74
70
|
bento: An bento.Bento object.
|
75
71
|
"""
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
# save the image in a temporary directory
|
81
|
-
bentoml.export_bento(bento.tag, temp_bento_path)
|
82
|
-
|
83
|
-
# copy the saved image to the artifact store
|
84
|
-
io_utils.copy_dir(temp_dir.name, self.uri)
|
85
|
-
|
86
|
-
# Remove the temporary directory
|
87
|
-
fileio.rmtree(temp_dir.name)
|
72
|
+
with self.get_temporary_directory(delete_at_exit=True) as temp_dir:
|
73
|
+
temp_bento_path = os.path.join(temp_dir, DEFAULT_BENTO_FILENAME)
|
74
|
+
bentoml.export_bento(bento.tag, temp_bento_path)
|
75
|
+
io_utils.copy_dir(temp_dir, self.uri)
|
88
76
|
|
89
77
|
def extract_metadata(
|
90
78
|
self, bento: bento.Bento
|
@@ -33,7 +33,7 @@ from zenml.stack import Flavor
|
|
33
33
|
|
34
34
|
# Fix numba errors in Docker and suppress logs and deprecation warning spam
|
35
35
|
try:
|
36
|
-
from numba.core.errors import (
|
36
|
+
from numba.core.errors import (
|
37
37
|
NumbaDeprecationWarning,
|
38
38
|
NumbaPendingDeprecationWarning,
|
39
39
|
)
|
@@ -15,7 +15,6 @@
|
|
15
15
|
|
16
16
|
import os
|
17
17
|
from collections import defaultdict
|
18
|
-
from tempfile import TemporaryDirectory, mkdtemp
|
19
18
|
from typing import (
|
20
19
|
TYPE_CHECKING,
|
21
20
|
Any,
|
@@ -88,12 +87,12 @@ class HFDatasetMaterializer(BaseMaterializer):
|
|
88
87
|
Returns:
|
89
88
|
The dataset read from the specified dir.
|
90
89
|
"""
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
90
|
+
with self.get_temporary_directory(delete_at_exit=False) as temp_dir:
|
91
|
+
io_utils.copy_dir(
|
92
|
+
os.path.join(self.uri, DEFAULT_DATASET_DIR),
|
93
|
+
temp_dir,
|
94
|
+
)
|
95
|
+
return load_from_disk(temp_dir)
|
97
96
|
|
98
97
|
def save(self, ds: Union[Dataset, DatasetDict]) -> None:
|
99
98
|
"""Writes a Dataset to the specified dir.
|
@@ -101,16 +100,13 @@ class HFDatasetMaterializer(BaseMaterializer):
|
|
101
100
|
Args:
|
102
101
|
ds: The Dataset to write.
|
103
102
|
"""
|
104
|
-
|
105
|
-
|
106
|
-
try:
|
103
|
+
with self.get_temporary_directory(delete_at_exit=True) as temp_dir:
|
104
|
+
path = os.path.join(temp_dir, DEFAULT_DATASET_DIR)
|
107
105
|
ds.save_to_disk(path)
|
108
106
|
io_utils.copy_dir(
|
109
107
|
path,
|
110
108
|
os.path.join(self.uri, DEFAULT_DATASET_DIR),
|
111
109
|
)
|
112
|
-
finally:
|
113
|
-
fileio.rmtree(temp_dir.name)
|
114
110
|
|
115
111
|
def extract_metadata(
|
116
112
|
self, ds: Union[Dataset, DatasetDict]
|
@@ -15,7 +15,6 @@
|
|
15
15
|
|
16
16
|
import importlib
|
17
17
|
import os
|
18
|
-
from tempfile import TemporaryDirectory
|
19
18
|
from typing import Any, ClassVar, Dict, Tuple, Type
|
20
19
|
|
21
20
|
from transformers import (
|
@@ -46,17 +45,17 @@ class HFPTModelMaterializer(BaseMaterializer):
|
|
46
45
|
Returns:
|
47
46
|
The model read from the specified dir.
|
48
47
|
"""
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
48
|
+
with self.get_temporary_directory(delete_at_exit=False) as temp_dir:
|
49
|
+
io_utils.copy_dir(
|
50
|
+
os.path.join(self.uri, DEFAULT_PT_MODEL_DIR), temp_dir
|
51
|
+
)
|
52
|
+
|
53
|
+
config = AutoConfig.from_pretrained(temp_dir)
|
54
|
+
architecture = config.architectures[0]
|
55
|
+
model_cls = getattr(
|
56
|
+
importlib.import_module("transformers"), architecture
|
57
|
+
)
|
58
|
+
return model_cls.from_pretrained(temp_dir)
|
60
59
|
|
61
60
|
def save(self, model: PreTrainedModel) -> None:
|
62
61
|
"""Writes a Model to the specified dir.
|
@@ -64,12 +63,12 @@ class HFPTModelMaterializer(BaseMaterializer):
|
|
64
63
|
Args:
|
65
64
|
model: The Torch Model to write.
|
66
65
|
"""
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
66
|
+
with self.get_temporary_directory(delete_at_exit=True) as temp_dir:
|
67
|
+
model.save_pretrained(temp_dir)
|
68
|
+
io_utils.copy_dir(
|
69
|
+
temp_dir,
|
70
|
+
os.path.join(self.uri, DEFAULT_PT_MODEL_DIR),
|
71
|
+
)
|
73
72
|
|
74
73
|
def extract_metadata(
|
75
74
|
self, model: PreTrainedModel
|
@@ -14,7 +14,6 @@
|
|
14
14
|
"""Implementation of the Huggingface t5 materializer."""
|
15
15
|
|
16
16
|
import os
|
17
|
-
import tempfile
|
18
17
|
from typing import Any, ClassVar, Type, Union
|
19
18
|
|
20
19
|
from transformers import (
|
@@ -52,8 +51,7 @@ class HFT5Materializer(BaseMaterializer):
|
|
52
51
|
ValueError: Unsupported data type used
|
53
52
|
"""
|
54
53
|
filepath = self.uri
|
55
|
-
|
56
|
-
with tempfile.TemporaryDirectory(prefix="zenml-temp-") as temp_dir:
|
54
|
+
with self.get_temporary_directory(delete_at_exit=True) as temp_dir:
|
57
55
|
# Copy files from artifact store to temporary directory
|
58
56
|
for file in fileio.listdir(filepath):
|
59
57
|
src = os.path.join(filepath, file)
|
@@ -86,8 +84,7 @@ class HFT5Materializer(BaseMaterializer):
|
|
86
84
|
Args:
|
87
85
|
obj: A T5ForConditionalGeneration model or T5Tokenizer.
|
88
86
|
"""
|
89
|
-
|
90
|
-
with tempfile.TemporaryDirectory(prefix="zenml-temp-") as temp_dir:
|
87
|
+
with self.get_temporary_directory(delete_at_exit=True) as temp_dir:
|
91
88
|
# Save the model or tokenizer
|
92
89
|
obj.save_pretrained(temp_dir)
|
93
90
|
|
@@ -15,7 +15,6 @@
|
|
15
15
|
|
16
16
|
import importlib
|
17
17
|
import os
|
18
|
-
from tempfile import TemporaryDirectory
|
19
18
|
from typing import Any, ClassVar, Dict, Tuple, Type
|
20
19
|
|
21
20
|
from transformers import (
|
@@ -46,17 +45,17 @@ class HFTFModelMaterializer(BaseMaterializer):
|
|
46
45
|
Returns:
|
47
46
|
The model read from the specified dir.
|
48
47
|
"""
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
48
|
+
with self.get_temporary_directory(delete_at_exit=False) as temp_dir:
|
49
|
+
io_utils.copy_dir(
|
50
|
+
os.path.join(self.uri, DEFAULT_TF_MODEL_DIR), temp_dir
|
51
|
+
)
|
52
|
+
|
53
|
+
config = AutoConfig.from_pretrained(temp_dir)
|
54
|
+
architecture = "TF" + config.architectures[0]
|
55
|
+
model_cls = getattr(
|
56
|
+
importlib.import_module("transformers"), architecture
|
57
|
+
)
|
58
|
+
return model_cls.from_pretrained(temp_dir)
|
60
59
|
|
61
60
|
def save(self, model: TFPreTrainedModel) -> None:
|
62
61
|
"""Writes a Model to the specified dir.
|
@@ -64,12 +63,12 @@ class HFTFModelMaterializer(BaseMaterializer):
|
|
64
63
|
Args:
|
65
64
|
model: The TF Model to write.
|
66
65
|
"""
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
66
|
+
with self.get_temporary_directory(delete_at_exit=True) as temp_dir:
|
67
|
+
model.save_pretrained(temp_dir)
|
68
|
+
io_utils.copy_dir(
|
69
|
+
temp_dir,
|
70
|
+
os.path.join(self.uri, DEFAULT_TF_MODEL_DIR),
|
71
|
+
)
|
73
72
|
|
74
73
|
def extract_metadata(
|
75
74
|
self, model: TFPreTrainedModel
|
@@ -14,7 +14,6 @@
|
|
14
14
|
"""Implementation of the Huggingface tokenizer materializer."""
|
15
15
|
|
16
16
|
import os
|
17
|
-
from tempfile import TemporaryDirectory
|
18
17
|
from typing import Any, ClassVar, Tuple, Type
|
19
18
|
|
20
19
|
from transformers import AutoTokenizer
|
@@ -46,7 +45,7 @@ class HFTokenizerMaterializer(BaseMaterializer):
|
|
46
45
|
Returns:
|
47
46
|
The tokenizer read from the specified dir.
|
48
47
|
"""
|
49
|
-
with
|
48
|
+
with self.get_temporary_directory(delete_at_exit=True) as temp_dir:
|
50
49
|
io_utils.copy_dir(
|
51
50
|
os.path.join(self.uri, DEFAULT_TOKENIZER_DIR), temp_dir
|
52
51
|
)
|
@@ -58,7 +57,7 @@ class HFTokenizerMaterializer(BaseMaterializer):
|
|
58
57
|
Args:
|
59
58
|
tokenizer: The HFTokenizer to write.
|
60
59
|
"""
|
61
|
-
with
|
60
|
+
with self.get_temporary_directory(delete_at_exit=True) as temp_dir:
|
62
61
|
tokenizer.save_pretrained(temp_dir)
|
63
62
|
io_utils.copy_dir(
|
64
63
|
temp_dir,
|
@@ -14,7 +14,6 @@
|
|
14
14
|
"""Implementation of the LightGBM booster materializer."""
|
15
15
|
|
16
16
|
import os
|
17
|
-
import tempfile
|
18
17
|
from typing import Any, ClassVar, Tuple, Type
|
19
18
|
|
20
19
|
import lightgbm as lgb
|
@@ -42,18 +41,13 @@ class LightGBMBoosterMaterializer(BaseMaterializer):
|
|
42
41
|
A lightgbm Booster object.
|
43
42
|
"""
|
44
43
|
filepath = os.path.join(self.uri, DEFAULT_FILENAME)
|
44
|
+
with self.get_temporary_directory(delete_at_exit=True) as temp_dir:
|
45
|
+
temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)
|
45
46
|
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
# Copy from artifact store to temporary file
|
51
|
-
fileio.copy(filepath, temp_file)
|
52
|
-
booster = lgb.Booster(model_file=temp_file)
|
53
|
-
|
54
|
-
# Cleanup and return
|
55
|
-
fileio.rmtree(temp_dir)
|
56
|
-
return booster
|
47
|
+
# Copy from artifact store to temporary file
|
48
|
+
fileio.copy(filepath, temp_file)
|
49
|
+
booster = lgb.Booster(model_file=temp_file)
|
50
|
+
return booster
|
57
51
|
|
58
52
|
def save(self, booster: lgb.Booster) -> None:
|
59
53
|
"""Creates a JSON serialization for a lightgbm Booster model.
|
@@ -62,8 +56,7 @@ class LightGBMBoosterMaterializer(BaseMaterializer):
|
|
62
56
|
booster: A lightgbm Booster model.
|
63
57
|
"""
|
64
58
|
filepath = os.path.join(self.uri, DEFAULT_FILENAME)
|
65
|
-
|
66
|
-
|
67
|
-
tmp_path = os.path.join(tmp_dir, "model.txt")
|
59
|
+
with self.get_temporary_directory(delete_at_exit=True) as temp_dir:
|
60
|
+
tmp_path = os.path.join(temp_dir, "model.txt")
|
68
61
|
booster.save_model(tmp_path)
|
69
62
|
fileio.copy(tmp_path, filepath)
|