zenml-nightly 0.68.1.dev20241105__py3-none-any.whl → 0.68.1.dev20241107__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37):
  1. zenml/VERSION +1 -1
  2. zenml/artifacts/{load_directory_materializer.py → preexisting_data_materializer.py} +8 -9
  3. zenml/artifacts/utils.py +121 -59
  4. zenml/constants.py +1 -0
  5. zenml/integrations/bentoml/materializers/bentoml_bento_materializer.py +19 -31
  6. zenml/integrations/evidently/__init__.py +1 -1
  7. zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py +8 -12
  8. zenml/integrations/huggingface/materializers/huggingface_pt_model_materializer.py +17 -18
  9. zenml/integrations/huggingface/materializers/huggingface_t5_materializer.py +2 -5
  10. zenml/integrations/huggingface/materializers/huggingface_tf_model_materializer.py +17 -18
  11. zenml/integrations/huggingface/materializers/huggingface_tokenizer_materializer.py +2 -3
  12. zenml/integrations/lightgbm/materializers/lightgbm_booster_materializer.py +8 -15
  13. zenml/integrations/lightgbm/materializers/lightgbm_dataset_materializer.py +11 -16
  14. zenml/integrations/pillow/materializers/pillow_image_materializer.py +17 -20
  15. zenml/integrations/polars/materializers/dataframe_materializer.py +26 -39
  16. zenml/integrations/pycaret/materializers/model_materializer.py +7 -22
  17. zenml/integrations/tensorflow/materializers/keras_materializer.py +11 -22
  18. zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py +8 -15
  19. zenml/integrations/vllm/services/vllm_deployment.py +16 -7
  20. zenml/integrations/whylogs/materializers/whylogs_materializer.py +11 -18
  21. zenml/integrations/xgboost/materializers/xgboost_booster_materializer.py +11 -22
  22. zenml/integrations/xgboost/materializers/xgboost_dmatrix_materializer.py +10 -19
  23. zenml/materializers/base_materializer.py +68 -1
  24. zenml/orchestrators/step_runner.py +17 -11
  25. zenml/stack/flavor.py +9 -5
  26. zenml/steps/step_context.py +2 -0
  27. zenml/utils/callback_registry.py +71 -0
  28. zenml/zen_server/rbac/endpoint_utils.py +43 -1
  29. zenml/zen_server/routers/artifact_version_endpoints.py +27 -1
  30. zenml/zen_stores/rest_zen_store.py +52 -0
  31. zenml/zen_stores/sql_zen_store.py +16 -0
  32. zenml/zen_stores/zen_store_interface.py +13 -0
  33. {zenml_nightly-0.68.1.dev20241105.dist-info → zenml_nightly-0.68.1.dev20241107.dist-info}/METADATA +1 -1
  34. {zenml_nightly-0.68.1.dev20241105.dist-info → zenml_nightly-0.68.1.dev20241107.dist-info}/RECORD +37 -36
  35. {zenml_nightly-0.68.1.dev20241105.dist-info → zenml_nightly-0.68.1.dev20241107.dist-info}/LICENSE +0 -0
  36. {zenml_nightly-0.68.1.dev20241105.dist-info → zenml_nightly-0.68.1.dev20241107.dist-info}/WHEEL +0 -0
  37. {zenml_nightly-0.68.1.dev20241105.dist-info → zenml_nightly-0.68.1.dev20241107.dist-info}/entry_points.txt +0 -0
zenml/VERSION CHANGED
@@ -1 +1 @@
1
- 0.68.1.dev20241105
1
+ 0.68.1.dev20241107
@@ -14,7 +14,6 @@
14
14
  """Only-load materializer for directories."""
15
15
 
16
16
  import os
17
- import tempfile
18
17
  from pathlib import Path
19
18
  from typing import Any, ClassVar, Tuple, Type
20
19
 
@@ -46,14 +45,14 @@ class PreexistingDataMaterializer(BaseMaterializer):
46
45
  Returns:
47
46
  Path to the local directory that contains the artifact files.
48
47
  """
49
- directory = tempfile.mkdtemp(prefix="zenml-artifact")
50
- if fileio.isdir(self.uri):
51
- self._copy_directory(src=self.uri, dst=directory)
52
- return Path(directory)
53
- else:
54
- dst = os.path.join(directory, os.path.split(self.uri)[-1])
55
- fileio.copy(src=self.uri, dst=dst)
56
- return Path(dst)
48
+ with self.get_temporary_directory(delete_at_exit=False) as temp_dir:
49
+ if fileio.isdir(self.uri):
50
+ self._copy_directory(src=self.uri, dst=temp_dir)
51
+ return Path(temp_dir)
52
+ else:
53
+ dst = os.path.join(temp_dir, os.path.split(self.uri)[-1])
54
+ fileio.copy(src=self.uri, dst=dst)
55
+ return Path(dst)
57
56
 
58
57
  def save(self, data: Any) -> None:
59
58
  """Store the directory in the artifact store.
zenml/artifacts/utils.py CHANGED
@@ -31,7 +31,7 @@ from typing import (
31
31
  from uuid import UUID, uuid4
32
32
 
33
33
  from zenml.artifacts.artifact_config import ArtifactConfig
34
- from zenml.artifacts.load_directory_materializer import (
34
+ from zenml.artifacts.preexisting_data_materializer import (
35
35
  PreexistingDataMaterializer,
36
36
  )
37
37
  from zenml.client import Client
@@ -82,6 +82,114 @@ logger = get_logger(__name__)
82
82
  # ----------
83
83
 
84
84
 
85
+ def _save_artifact_visualizations(
86
+ data: Any, materializer: "BaseMaterializer"
87
+ ) -> List[ArtifactVisualizationRequest]:
88
+ """Save artifact visualizations.
89
+
90
+ Args:
91
+ data: The data for which to save the visualizations.
92
+ materializer: The materializer that should be used to generate and
93
+ save the visualizations.
94
+
95
+ Returns:
96
+ List of requests for the saved visualizations.
97
+ """
98
+ try:
99
+ visualizations = materializer.save_visualizations(data)
100
+ except Exception as e:
101
+ logger.warning("Failed to save artifact visualizations: %s", e)
102
+ return []
103
+
104
+ return [
105
+ ArtifactVisualizationRequest(
106
+ type=type,
107
+ uri=uri,
108
+ )
109
+ for uri, type in visualizations.items()
110
+ ]
111
+
112
+
113
+ def _store_artifact_data_and_prepare_request(
114
+ data: Any,
115
+ name: str,
116
+ uri: str,
117
+ materializer_class: Type["BaseMaterializer"],
118
+ version: Optional[Union[int, str]] = None,
119
+ tags: Optional[List[str]] = None,
120
+ store_metadata: bool = True,
121
+ store_visualizations: bool = True,
122
+ has_custom_name: bool = True,
123
+ metadata: Optional[Dict[str, "MetadataType"]] = None,
124
+ ) -> ArtifactVersionRequest:
125
+ """Store artifact data and prepare a request to the server.
126
+
127
+ Args:
128
+ data: The artifact data.
129
+ name: The artifact name.
130
+ uri: The artifact URI.
131
+ materializer_class: The materializer class to use for storing the
132
+ artifact data.
133
+ version: The artifact version.
134
+ tags: Tags for the artifact version.
135
+ store_metadata: Whether to store metadata for the artifact version.
136
+ store_visualizations: Whether to store visualizations for the artifact
137
+ version.
138
+ has_custom_name: Whether the artifact has a custom name.
139
+ metadata: Metadata to store for the artifact version. This will be
140
+ ignored if `store_metadata` is set to `False`.
141
+
142
+ Returns:
143
+ Artifact version request for the artifact data that was stored.
144
+ """
145
+ artifact_store = Client().active_stack.artifact_store
146
+ artifact_store.makedirs(uri)
147
+
148
+ materializer = materializer_class(uri=uri, artifact_store=artifact_store)
149
+ materializer.uri = materializer.uri.replace("\\", "/")
150
+
151
+ data_type = type(data)
152
+ materializer.validate_save_type_compatibility(data_type)
153
+ materializer.save(data)
154
+
155
+ visualizations = (
156
+ _save_artifact_visualizations(data=data, materializer=materializer)
157
+ if store_visualizations
158
+ else None
159
+ )
160
+
161
+ combined_metadata: Dict[str, "MetadataType"] = {}
162
+ if store_metadata:
163
+ try:
164
+ combined_metadata = materializer.extract_full_metadata(data)
165
+ except Exception as e:
166
+ logger.warning("Failed to extract materializer metadata: %s", e)
167
+
168
+ # Update with user metadata to potentially overwrite values coming from
169
+ # the materializer
170
+ combined_metadata.update(metadata or {})
171
+
172
+ artifact_version_request = ArtifactVersionRequest(
173
+ artifact_name=name,
174
+ version=version,
175
+ tags=tags,
176
+ type=materializer.ASSOCIATED_ARTIFACT_TYPE,
177
+ uri=materializer.uri,
178
+ materializer=source_utils.resolve(materializer.__class__),
179
+ data_type=source_utils.resolve(data_type),
180
+ user=Client().active_user.id,
181
+ workspace=Client().active_workspace.id,
182
+ artifact_store_id=artifact_store.id,
183
+ visualizations=visualizations,
184
+ has_custom_name=has_custom_name,
185
+ metadata=validate_metadata(combined_metadata)
186
+ if combined_metadata
187
+ else None,
188
+ )
189
+
190
+ return artifact_version_request
191
+
192
+
85
193
  def save_artifact(
86
194
  data: Any,
87
195
  name: str,
@@ -89,13 +197,14 @@ def save_artifact(
89
197
  tags: Optional[List[str]] = None,
90
198
  extract_metadata: bool = True,
91
199
  include_visualizations: bool = True,
92
- has_custom_name: bool = True,
93
200
  user_metadata: Optional[Dict[str, "MetadataType"]] = None,
94
201
  materializer: Optional["MaterializerClassOrSource"] = None,
95
202
  uri: Optional[str] = None,
96
203
  is_model_artifact: bool = False,
97
204
  is_deployment_artifact: bool = False,
205
+ # TODO: remove these once external artifact does not use this function anymore
98
206
  manual_save: bool = True,
207
+ has_custom_name: bool = True,
99
208
  ) -> "ArtifactVersionResponse":
100
209
  """Upload and publish an artifact.
101
210
 
@@ -107,8 +216,6 @@ def save_artifact(
107
216
  tags: Tags to associate with the artifact.
108
217
  extract_metadata: If artifact metadata should be extracted and returned.
109
218
  include_visualizations: If artifact visualizations should be generated.
110
- has_custom_name: If the artifact name is custom and should be listed in
111
- the dashboard "Artifacts" tab.
112
219
  user_metadata: User-provided metadata to store with the artifact.
113
220
  materializer: The materializer to use for saving the artifact to the
114
221
  artifact store.
@@ -119,6 +226,8 @@ def save_artifact(
119
226
  is_deployment_artifact: If the artifact is a deployment artifact.
120
227
  manual_save: If this function is called manually and should therefore
121
228
  link the artifact to the current step run.
229
+ has_custom_name: If the artifact name is custom and should be listed in
230
+ the dashboard "Artifacts" tab.
122
231
 
123
232
  Returns:
124
233
  The saved artifact response.
@@ -129,11 +238,8 @@ def save_artifact(
129
238
  from zenml.utils import source_utils
130
239
 
131
240
  client = Client()
132
-
133
- # Get the current artifact store
134
241
  artifact_store = client.active_stack.artifact_store
135
242
 
136
- # Build and check the artifact URI
137
243
  if not uri:
138
244
  uri = os.path.join("custom_artifacts", name, str(uuid4()))
139
245
  if not uri.startswith(artifact_store.path):
@@ -147,9 +253,7 @@ def save_artifact(
147
253
  uri=uri,
148
254
  name=name,
149
255
  )
150
- artifact_store.makedirs(uri)
151
256
 
152
- # Find and initialize the right materializer class
153
257
  if isinstance(materializer, type):
154
258
  materializer_class = materializer
155
259
  elif materializer:
@@ -158,60 +262,18 @@ def save_artifact(
158
262
  )
159
263
  else:
160
264
  materializer_class = materializer_registry[type(data)]
161
- materializer_object = materializer_class(uri)
162
-
163
- # Force URIs to have forward slashes
164
- materializer_object.uri = materializer_object.uri.replace("\\", "/")
165
-
166
- # Save the artifact to the artifact store
167
- data_type = type(data)
168
- materializer_object.validate_save_type_compatibility(data_type)
169
- materializer_object.save(data)
170
265
 
171
- # Save visualizations of the artifact
172
- visualizations: List[ArtifactVisualizationRequest] = []
173
- if include_visualizations:
174
- try:
175
- vis_data = materializer_object.save_visualizations(data)
176
- for vis_uri, vis_type in vis_data.items():
177
- vis_model = ArtifactVisualizationRequest(
178
- type=vis_type,
179
- uri=vis_uri,
180
- )
181
- visualizations.append(vis_model)
182
- except Exception as e:
183
- logger.warning(
184
- f"Failed to save visualization for output artifact '{name}': "
185
- f"{e}"
186
- )
187
-
188
- # Save metadata of the artifact
189
- artifact_metadata: Dict[str, "MetadataType"] = {}
190
- if extract_metadata:
191
- try:
192
- artifact_metadata = materializer_object.extract_full_metadata(data)
193
- artifact_metadata.update(user_metadata or {})
194
- except Exception as e:
195
- logger.warning(
196
- f"Failed to extract metadata for output artifact '{name}': {e}"
197
- )
198
-
199
- artifact_version_request = ArtifactVersionRequest(
200
- artifact_name=name,
266
+ artifact_version_request = _store_artifact_data_and_prepare_request(
267
+ data=data,
268
+ name=name,
269
+ uri=uri,
270
+ materializer_class=materializer_class,
201
271
  version=version,
202
272
  tags=tags,
203
- type=materializer_object.ASSOCIATED_ARTIFACT_TYPE,
204
- uri=materializer_object.uri,
205
- materializer=source_utils.resolve(materializer_object.__class__),
206
- data_type=source_utils.resolve(data_type),
207
- user=Client().active_user.id,
208
- workspace=Client().active_workspace.id,
209
- artifact_store_id=artifact_store.id,
210
- visualizations=visualizations,
273
+ store_metadata=extract_metadata,
274
+ store_visualizations=include_visualizations,
211
275
  has_custom_name=has_custom_name,
212
- metadata=validate_metadata(artifact_metadata)
213
- if artifact_metadata
214
- else None,
276
+ metadata=user_metadata,
215
277
  )
216
278
  artifact_version = client.zen_store.create_artifact_version(
217
279
  artifact_version=artifact_version_request
zenml/constants.py CHANGED
@@ -338,6 +338,7 @@ ARTIFACTS = "/artifacts"
338
338
  ARTIFACT_VERSIONS = "/artifact_versions"
339
339
  ARTIFACT_VISUALIZATIONS = "/artifact_visualizations"
340
340
  AUTH = "/auth"
341
+ BATCH = "/batch"
341
342
  CODE_REFERENCES = "/code_references"
342
343
  CODE_REPOSITORIES = "/code_repositories"
343
344
  COMPONENT_TYPES = "/component-types"
@@ -14,7 +14,6 @@
14
14
  """Materializer for BentoML Bento objects."""
15
15
 
16
16
  import os
17
- import tempfile
18
17
  from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type
19
18
 
20
19
  import bentoml
@@ -23,7 +22,6 @@ from bentoml.exceptions import BentoMLException
23
22
 
24
23
  from zenml.enums import ArtifactType
25
24
  from zenml.integrations.bentoml.constants import DEFAULT_BENTO_FILENAME
26
- from zenml.io import fileio
27
25
  from zenml.logger import get_logger
28
26
  from zenml.materializers.base_materializer import BaseMaterializer
29
27
  from zenml.utils import io_utils
@@ -49,23 +47,21 @@ class BentoMaterializer(BaseMaterializer):
49
47
  Returns:
50
48
  An bento.Bento object.
51
49
  """
52
- # Create a temporary directory to store the model
53
- temp_dir = tempfile.TemporaryDirectory()
54
-
55
- # Copy from artifact store to temporary directory
56
- io_utils.copy_dir(self.uri, temp_dir.name)
57
-
58
- # Load the Bento from the temporary directory
59
- imported_bento = Bento.import_from(
60
- os.path.join(temp_dir.name, DEFAULT_BENTO_FILENAME)
61
- )
62
-
63
- # Try save the Bento to the local BentoML store
64
- try:
65
- _ = bentoml.get(imported_bento.tag)
66
- except BentoMLException:
67
- imported_bento.save()
68
- return imported_bento
50
+ with self.get_temporary_directory(delete_at_exit=False) as temp_dir:
51
+ # Copy from artifact store to temporary directory
52
+ io_utils.copy_dir(self.uri, temp_dir)
53
+
54
+ # Load the Bento from the temporary directory
55
+ imported_bento = Bento.import_from(
56
+ os.path.join(temp_dir, DEFAULT_BENTO_FILENAME)
57
+ )
58
+
59
+ # Try save the Bento to the local BentoML store
60
+ try:
61
+ _ = bentoml.get(imported_bento.tag)
62
+ except BentoMLException:
63
+ imported_bento.save()
64
+ return imported_bento
69
65
 
70
66
  def save(self, bento: bento.Bento) -> None:
71
67
  """Write to artifact store.
@@ -73,18 +69,10 @@ class BentoMaterializer(BaseMaterializer):
73
69
  Args:
74
70
  bento: An bento.Bento object.
75
71
  """
76
- # Create a temporary directory to store the model
77
- temp_dir = tempfile.TemporaryDirectory(prefix="zenml-temp-")
78
- temp_bento_path = os.path.join(temp_dir.name, DEFAULT_BENTO_FILENAME)
79
-
80
- # save the image in a temporary directory
81
- bentoml.export_bento(bento.tag, temp_bento_path)
82
-
83
- # copy the saved image to the artifact store
84
- io_utils.copy_dir(temp_dir.name, self.uri)
85
-
86
- # Remove the temporary directory
87
- fileio.rmtree(temp_dir.name)
72
+ with self.get_temporary_directory(delete_at_exit=True) as temp_dir:
73
+ temp_bento_path = os.path.join(temp_dir, DEFAULT_BENTO_FILENAME)
74
+ bentoml.export_bento(bento.tag, temp_bento_path)
75
+ io_utils.copy_dir(temp_dir, self.uri)
88
76
 
89
77
  def extract_metadata(
90
78
  self, bento: bento.Bento
@@ -33,7 +33,7 @@ from zenml.stack import Flavor
33
33
 
34
34
  # Fix numba errors in Docker and suppress logs and deprecation warning spam
35
35
  try:
36
- from numba.core.errors import ( # type: ignore[import-not-found]
36
+ from numba.core.errors import (
37
37
  NumbaDeprecationWarning,
38
38
  NumbaPendingDeprecationWarning,
39
39
  )
@@ -15,7 +15,6 @@
15
15
 
16
16
  import os
17
17
  from collections import defaultdict
18
- from tempfile import TemporaryDirectory, mkdtemp
19
18
  from typing import (
20
19
  TYPE_CHECKING,
21
20
  Any,
@@ -88,12 +87,12 @@ class HFDatasetMaterializer(BaseMaterializer):
88
87
  Returns:
89
88
  The dataset read from the specified dir.
90
89
  """
91
- temp_dir = mkdtemp()
92
- io_utils.copy_dir(
93
- os.path.join(self.uri, DEFAULT_DATASET_DIR),
94
- temp_dir,
95
- )
96
- return load_from_disk(temp_dir)
90
+ with self.get_temporary_directory(delete_at_exit=False) as temp_dir:
91
+ io_utils.copy_dir(
92
+ os.path.join(self.uri, DEFAULT_DATASET_DIR),
93
+ temp_dir,
94
+ )
95
+ return load_from_disk(temp_dir)
97
96
 
98
97
  def save(self, ds: Union[Dataset, DatasetDict]) -> None:
99
98
  """Writes a Dataset to the specified dir.
@@ -101,16 +100,13 @@ class HFDatasetMaterializer(BaseMaterializer):
101
100
  Args:
102
101
  ds: The Dataset to write.
103
102
  """
104
- temp_dir = TemporaryDirectory()
105
- path = os.path.join(temp_dir.name, DEFAULT_DATASET_DIR)
106
- try:
103
+ with self.get_temporary_directory(delete_at_exit=True) as temp_dir:
104
+ path = os.path.join(temp_dir, DEFAULT_DATASET_DIR)
107
105
  ds.save_to_disk(path)
108
106
  io_utils.copy_dir(
109
107
  path,
110
108
  os.path.join(self.uri, DEFAULT_DATASET_DIR),
111
109
  )
112
- finally:
113
- fileio.rmtree(temp_dir.name)
114
110
 
115
111
  def extract_metadata(
116
112
  self, ds: Union[Dataset, DatasetDict]
@@ -15,7 +15,6 @@
15
15
 
16
16
  import importlib
17
17
  import os
18
- from tempfile import TemporaryDirectory
19
18
  from typing import Any, ClassVar, Dict, Tuple, Type
20
19
 
21
20
  from transformers import (
@@ -46,17 +45,17 @@ class HFPTModelMaterializer(BaseMaterializer):
46
45
  Returns:
47
46
  The model read from the specified dir.
48
47
  """
49
- temp_dir = TemporaryDirectory()
50
- io_utils.copy_dir(
51
- os.path.join(self.uri, DEFAULT_PT_MODEL_DIR), temp_dir.name
52
- )
53
-
54
- config = AutoConfig.from_pretrained(temp_dir.name)
55
- architecture = config.architectures[0]
56
- model_cls = getattr(
57
- importlib.import_module("transformers"), architecture
58
- )
59
- return model_cls.from_pretrained(temp_dir.name)
48
+ with self.get_temporary_directory(delete_at_exit=False) as temp_dir:
49
+ io_utils.copy_dir(
50
+ os.path.join(self.uri, DEFAULT_PT_MODEL_DIR), temp_dir
51
+ )
52
+
53
+ config = AutoConfig.from_pretrained(temp_dir)
54
+ architecture = config.architectures[0]
55
+ model_cls = getattr(
56
+ importlib.import_module("transformers"), architecture
57
+ )
58
+ return model_cls.from_pretrained(temp_dir)
60
59
 
61
60
  def save(self, model: PreTrainedModel) -> None:
62
61
  """Writes a Model to the specified dir.
@@ -64,12 +63,12 @@ class HFPTModelMaterializer(BaseMaterializer):
64
63
  Args:
65
64
  model: The Torch Model to write.
66
65
  """
67
- temp_dir = TemporaryDirectory()
68
- model.save_pretrained(temp_dir.name)
69
- io_utils.copy_dir(
70
- temp_dir.name,
71
- os.path.join(self.uri, DEFAULT_PT_MODEL_DIR),
72
- )
66
+ with self.get_temporary_directory(delete_at_exit=True) as temp_dir:
67
+ model.save_pretrained(temp_dir)
68
+ io_utils.copy_dir(
69
+ temp_dir,
70
+ os.path.join(self.uri, DEFAULT_PT_MODEL_DIR),
71
+ )
73
72
 
74
73
  def extract_metadata(
75
74
  self, model: PreTrainedModel
@@ -14,7 +14,6 @@
14
14
  """Implementation of the Huggingface t5 materializer."""
15
15
 
16
16
  import os
17
- import tempfile
18
17
  from typing import Any, ClassVar, Type, Union
19
18
 
20
19
  from transformers import (
@@ -52,8 +51,7 @@ class HFT5Materializer(BaseMaterializer):
52
51
  ValueError: Unsupported data type used
53
52
  """
54
53
  filepath = self.uri
55
-
56
- with tempfile.TemporaryDirectory(prefix="zenml-temp-") as temp_dir:
54
+ with self.get_temporary_directory(delete_at_exit=True) as temp_dir:
57
55
  # Copy files from artifact store to temporary directory
58
56
  for file in fileio.listdir(filepath):
59
57
  src = os.path.join(filepath, file)
@@ -86,8 +84,7 @@ class HFT5Materializer(BaseMaterializer):
86
84
  Args:
87
85
  obj: A T5ForConditionalGeneration model or T5Tokenizer.
88
86
  """
89
- # Create a temporary directory
90
- with tempfile.TemporaryDirectory(prefix="zenml-temp-") as temp_dir:
87
+ with self.get_temporary_directory(delete_at_exit=True) as temp_dir:
91
88
  # Save the model or tokenizer
92
89
  obj.save_pretrained(temp_dir)
93
90
 
@@ -15,7 +15,6 @@
15
15
 
16
16
  import importlib
17
17
  import os
18
- from tempfile import TemporaryDirectory
19
18
  from typing import Any, ClassVar, Dict, Tuple, Type
20
19
 
21
20
  from transformers import (
@@ -46,17 +45,17 @@ class HFTFModelMaterializer(BaseMaterializer):
46
45
  Returns:
47
46
  The model read from the specified dir.
48
47
  """
49
- temp_dir = TemporaryDirectory()
50
- io_utils.copy_dir(
51
- os.path.join(self.uri, DEFAULT_TF_MODEL_DIR), temp_dir.name
52
- )
53
-
54
- config = AutoConfig.from_pretrained(temp_dir.name)
55
- architecture = "TF" + config.architectures[0]
56
- model_cls = getattr(
57
- importlib.import_module("transformers"), architecture
58
- )
59
- return model_cls.from_pretrained(temp_dir.name)
48
+ with self.get_temporary_directory(delete_at_exit=False) as temp_dir:
49
+ io_utils.copy_dir(
50
+ os.path.join(self.uri, DEFAULT_TF_MODEL_DIR), temp_dir
51
+ )
52
+
53
+ config = AutoConfig.from_pretrained(temp_dir)
54
+ architecture = "TF" + config.architectures[0]
55
+ model_cls = getattr(
56
+ importlib.import_module("transformers"), architecture
57
+ )
58
+ return model_cls.from_pretrained(temp_dir)
60
59
 
61
60
  def save(self, model: TFPreTrainedModel) -> None:
62
61
  """Writes a Model to the specified dir.
@@ -64,12 +63,12 @@ class HFTFModelMaterializer(BaseMaterializer):
64
63
  Args:
65
64
  model: The TF Model to write.
66
65
  """
67
- temp_dir = TemporaryDirectory()
68
- model.save_pretrained(temp_dir.name)
69
- io_utils.copy_dir(
70
- temp_dir.name,
71
- os.path.join(self.uri, DEFAULT_TF_MODEL_DIR),
72
- )
66
+ with self.get_temporary_directory(delete_at_exit=True) as temp_dir:
67
+ model.save_pretrained(temp_dir)
68
+ io_utils.copy_dir(
69
+ temp_dir,
70
+ os.path.join(self.uri, DEFAULT_TF_MODEL_DIR),
71
+ )
73
72
 
74
73
  def extract_metadata(
75
74
  self, model: TFPreTrainedModel
@@ -14,7 +14,6 @@
14
14
  """Implementation of the Huggingface tokenizer materializer."""
15
15
 
16
16
  import os
17
- from tempfile import TemporaryDirectory
18
17
  from typing import Any, ClassVar, Tuple, Type
19
18
 
20
19
  from transformers import AutoTokenizer
@@ -46,7 +45,7 @@ class HFTokenizerMaterializer(BaseMaterializer):
46
45
  Returns:
47
46
  The tokenizer read from the specified dir.
48
47
  """
49
- with TemporaryDirectory() as temp_dir:
48
+ with self.get_temporary_directory(delete_at_exit=True) as temp_dir:
50
49
  io_utils.copy_dir(
51
50
  os.path.join(self.uri, DEFAULT_TOKENIZER_DIR), temp_dir
52
51
  )
@@ -58,7 +57,7 @@ class HFTokenizerMaterializer(BaseMaterializer):
58
57
  Args:
59
58
  tokenizer: The HFTokenizer to write.
60
59
  """
61
- with TemporaryDirectory() as temp_dir:
60
+ with self.get_temporary_directory(delete_at_exit=True) as temp_dir:
62
61
  tokenizer.save_pretrained(temp_dir)
63
62
  io_utils.copy_dir(
64
63
  temp_dir,
@@ -14,7 +14,6 @@
14
14
  """Implementation of the LightGBM booster materializer."""
15
15
 
16
16
  import os
17
- import tempfile
18
17
  from typing import Any, ClassVar, Tuple, Type
19
18
 
20
19
  import lightgbm as lgb
@@ -42,18 +41,13 @@ class LightGBMBoosterMaterializer(BaseMaterializer):
42
41
  A lightgbm Booster object.
43
42
  """
44
43
  filepath = os.path.join(self.uri, DEFAULT_FILENAME)
44
+ with self.get_temporary_directory(delete_at_exit=True) as temp_dir:
45
+ temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)
45
46
 
46
- # Create a temporary folder
47
- temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
48
- temp_file = os.path.join(str(temp_dir), DEFAULT_FILENAME)
49
-
50
- # Copy from artifact store to temporary file
51
- fileio.copy(filepath, temp_file)
52
- booster = lgb.Booster(model_file=temp_file)
53
-
54
- # Cleanup and return
55
- fileio.rmtree(temp_dir)
56
- return booster
47
+ # Copy from artifact store to temporary file
48
+ fileio.copy(filepath, temp_file)
49
+ booster = lgb.Booster(model_file=temp_file)
50
+ return booster
57
51
 
58
52
  def save(self, booster: lgb.Booster) -> None:
59
53
  """Creates a JSON serialization for a lightgbm Booster model.
@@ -62,8 +56,7 @@ class LightGBMBoosterMaterializer(BaseMaterializer):
62
56
  booster: A lightgbm Booster model.
63
57
  """
64
58
  filepath = os.path.join(self.uri, DEFAULT_FILENAME)
65
-
66
- with tempfile.TemporaryDirectory() as tmp_dir:
67
- tmp_path = os.path.join(tmp_dir, "model.txt")
59
+ with self.get_temporary_directory(delete_at_exit=True) as temp_dir:
60
+ tmp_path = os.path.join(temp_dir, "model.txt")
68
61
  booster.save_model(tmp_path)
69
62
  fileio.copy(tmp_path, filepath)