mlrun-1.9.1-py3-none-any.whl → mlrun-1.9.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mlrun might be problematic; see the registry's advisory page for more details.
- mlrun/artifacts/manager.py +1 -1
- mlrun/frameworks/tf_keras/mlrun_interface.py +9 -17
- mlrun/frameworks/tf_keras/model_handler.py +23 -3
- mlrun/utils/version/version.json +2 -2
- {mlrun-1.9.1.dist-info → mlrun-1.9.2.dist-info}/METADATA +2 -2
- {mlrun-1.9.1.dist-info → mlrun-1.9.2.dist-info}/RECORD +10 -10
- {mlrun-1.9.1.dist-info → mlrun-1.9.2.dist-info}/WHEEL +0 -0
- {mlrun-1.9.1.dist-info → mlrun-1.9.2.dist-info}/entry_points.txt +0 -0
- {mlrun-1.9.1.dist-info → mlrun-1.9.2.dist-info}/licenses/LICENSE +0 -0
- {mlrun-1.9.1.dist-info → mlrun-1.9.2.dist-info}/top_level.txt +0 -0
mlrun/frameworks/tf_keras/mlrun_interface.py
CHANGED
|
@@ -107,14 +107,10 @@ class TFKerasMLRunInterface(MLRunInterface, ABC):
|
|
|
107
107
|
)
|
|
108
108
|
|
|
109
109
|
# Call the pre compile method:
|
|
110
|
-
|
|
111
|
-
optimizer=kwargs["optimizer"]
|
|
112
|
-
)
|
|
110
|
+
optimizer = self._pre_compile(optimizer=kwargs["optimizer"])
|
|
113
111
|
|
|
114
112
|
# Assign parameters:
|
|
115
113
|
kwargs["optimizer"] = optimizer
|
|
116
|
-
if experimental_run_tf_function is not None:
|
|
117
|
-
kwargs["experimental_run_tf_function"] = experimental_run_tf_function
|
|
118
114
|
|
|
119
115
|
# Call the original compile method:
|
|
120
116
|
return self.original_compile(*args, **kwargs)
|
|
@@ -235,23 +231,20 @@ class TFKerasMLRunInterface(MLRunInterface, ABC):
|
|
|
235
231
|
"""
|
|
236
232
|
self._RANK_0_ONLY_CALLBACKS.add(callback_name)
|
|
237
233
|
|
|
238
|
-
def _pre_compile(self, optimizer: Optimizer) ->
|
|
234
|
+
def _pre_compile(self, optimizer: Optimizer) -> Optimizer:
|
|
239
235
|
"""
|
|
240
236
|
Method to call before calling 'compile' to setup the run and inputs for using horovod.
|
|
241
237
|
|
|
242
238
|
:param optimizer: The optimzier to compile. It will be wrapped in horovod's distributed optimizer:
|
|
243
239
|
'hvd.DistributedOptimizer'.
|
|
244
240
|
|
|
245
|
-
:return: The updated
|
|
246
|
-
[0] = Wrapped optimizer.
|
|
247
|
-
[1] = The 'experimental_run_tf_function' parameter for 'compile' kwargs or 'None' if horovod should not
|
|
248
|
-
be used.
|
|
241
|
+
:return: The updated Wrapped optimizer.
|
|
249
242
|
|
|
250
243
|
:raise MLRunInvalidArgumentError: In case the optimizer was passed as a string.
|
|
251
244
|
"""
|
|
252
245
|
# Check if needed to run with horovod:
|
|
253
246
|
if self._hvd is None:
|
|
254
|
-
return optimizer
|
|
247
|
+
return optimizer
|
|
255
248
|
|
|
256
249
|
# Validate the optimizer input:
|
|
257
250
|
if isinstance(optimizer, str):
|
|
@@ -280,16 +273,15 @@ class TFKerasMLRunInterface(MLRunInterface, ABC):
|
|
|
280
273
|
print(f"Horovod worker #{self._hvd.rank()} is using CPU")
|
|
281
274
|
|
|
282
275
|
# Adjust learning rate based on the number of GPUs:
|
|
283
|
-
optimizer
|
|
276
|
+
if hasattr(optimizer, "lr"):
|
|
277
|
+
optimizer.lr = optimizer.lr * self._hvd.size()
|
|
278
|
+
else:
|
|
279
|
+
optimizer.learning_rate = optimizer.learning_rate * self._hvd.size()
|
|
284
280
|
|
|
285
281
|
# Wrap the optimizer in horovod's distributed optimizer: 'hvd.DistributedOptimizer'.
|
|
286
282
|
optimizer = self._hvd.DistributedOptimizer(optimizer)
|
|
287
283
|
|
|
288
|
-
|
|
289
|
-
# optimizer to compute the gradients:
|
|
290
|
-
experimental_run_tf_function = False
|
|
291
|
-
|
|
292
|
-
return optimizer, experimental_run_tf_function
|
|
284
|
+
return optimizer
|
|
293
285
|
|
|
294
286
|
def _pre_fit(
|
|
295
287
|
self,
|
|
mlrun/frameworks/tf_keras/model_handler.py
CHANGED
|
@@ -518,7 +518,6 @@ class TFKerasModelHandler(DLModelHandler):
|
|
|
518
518
|
)
|
|
519
519
|
|
|
520
520
|
# Read additional files according to the model format used:
|
|
521
|
-
# # ModelFormats.SAVED_MODEL - Unzip the SavedModel archive:
|
|
522
521
|
if self._model_format == TFKerasModelHandler.ModelFormats.SAVED_MODEL:
|
|
523
522
|
# Unzip the SavedModel directory:
|
|
524
523
|
with zipfile.ZipFile(self._model_file, "r") as zip_file:
|
|
@@ -527,11 +526,18 @@ class TFKerasModelHandler(DLModelHandler):
|
|
|
527
526
|
self._model_file = os.path.join(
|
|
528
527
|
os.path.dirname(self._model_file), self._model_name
|
|
529
528
|
)
|
|
530
|
-
|
|
531
|
-
|
|
529
|
+
elif self._model_format == TFKerasModelHandler.ModelFormats.KERAS:
|
|
530
|
+
# Rename the model file suffix:
|
|
531
|
+
self._rename_model_file_suffix(suffix="keras")
|
|
532
|
+
elif self._model_format == TFKerasModelHandler.ModelFormats.H5:
|
|
533
|
+
# Rename the model file suffix:
|
|
534
|
+
self._rename_model_file_suffix(suffix="h5")
|
|
535
|
+
elif ( # ModelFormats.JSON_ARCHITECTURE_H5_WEIGHTS
|
|
532
536
|
self._model_format
|
|
533
537
|
== TFKerasModelHandler.ModelFormats.JSON_ARCHITECTURE_H5_WEIGHTS
|
|
534
538
|
):
|
|
539
|
+
# Rename the model file suffix:
|
|
540
|
+
self._rename_model_file_suffix(suffix="json")
|
|
535
541
|
# Get the weights file:
|
|
536
542
|
self._weights_file = self._extra_data[
|
|
537
543
|
self._get_weights_file_artifact_name()
|
|
@@ -540,6 +546,20 @@ class TFKerasModelHandler(DLModelHandler):
|
|
|
540
546
|
# Continue collecting from abstract class:
|
|
541
547
|
super()._collect_files_from_store_object()
|
|
542
548
|
|
|
549
|
+
def _rename_model_file_suffix(self, suffix: str):
|
|
550
|
+
"""
|
|
551
|
+
Rename the model file suffix to the given one.
|
|
552
|
+
|
|
553
|
+
This is used for the case of loading a model from a store object that was saved with a different suffix as when
|
|
554
|
+
keras tries to load it, it validates the suffix. The `artifacts.model.get_model` function is downloading the
|
|
555
|
+
file to a temp file with a `pkl` suffix, so it needs to be replaced:than the one keras expects.
|
|
556
|
+
|
|
557
|
+
:param suffix: The suffix to rename the model file to (without the trailing dot).
|
|
558
|
+
"""
|
|
559
|
+
new_name = self._model_file.rsplit(".", 1)[0] + f".{suffix}"
|
|
560
|
+
os.rename(self._model_file, new_name)
|
|
561
|
+
self._model_file = new_name
|
|
562
|
+
|
|
543
563
|
def _collect_files_from_local_path(self):
|
|
544
564
|
"""
|
|
545
565
|
If the model path given is of a local path, search for the needed model files and collect them into this handler
|
{mlrun-1.9.1.dist-info → mlrun-1.9.2.dist-info}/METADATA
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: mlrun
|
|
3
|
-
Version: 1.9.1
|
|
3
|
+
Version: 1.9.2
|
|
4
4
|
Summary: Tracking and config of machine learning runs
|
|
5
5
|
Home-page: https://github.com/mlrun/mlrun
|
|
6
6
|
Author: Yaron Haviv
|
|
@@ -32,7 +32,7 @@ Requires-Dist: ipython~=8.10
|
|
|
32
32
|
Requires-Dist: nuclio-jupyter~=0.11.1
|
|
33
33
|
Requires-Dist: numpy<1.27.0,>=1.26.4
|
|
34
34
|
Requires-Dist: pandas<2.2,>=1.2
|
|
35
|
-
Requires-Dist: pyarrow<
|
|
35
|
+
Requires-Dist: pyarrow<18,>=10.0
|
|
36
36
|
Requires-Dist: pyyaml<7,>=6.0.2
|
|
37
37
|
Requires-Dist: requests~=2.32
|
|
38
38
|
Requires-Dist: tabulate~=0.8.6
|
|
{mlrun-1.9.1.dist-info → mlrun-1.9.2.dist-info}/RECORD
CHANGED
|
@@ -17,7 +17,7 @@ mlrun/artifacts/__init__.py,sha256=ofC2extBCOC1wg1YtdTzWzH3eeG_f-sFBUkHjYtZJpk,1
|
|
|
17
17
|
mlrun/artifacts/base.py,sha256=mQnDToP65cgXzhWUBNOT3ObkO_FqltHNVDEn7nFgDRQ,29982
|
|
18
18
|
mlrun/artifacts/dataset.py,sha256=p8Rk0yrBUszh4pe7VLfcUK9piD-J_UX_X6gU5fYCyQg,16665
|
|
19
19
|
mlrun/artifacts/document.py,sha256=3X1i27NYSd-cOcX-lEvaNTUvwS2UKWXW2EnlfWokrVk,17374
|
|
20
|
-
mlrun/artifacts/manager.py,sha256=
|
|
20
|
+
mlrun/artifacts/manager.py,sha256=rI4FXu4ckGgZh7n08ITWHh6oYFW5G5C4OasHVcAGvjg,16212
|
|
21
21
|
mlrun/artifacts/model.py,sha256=J5b8zODrpx5ULtsgS9RGKqzMXYs7ADacE0BLBglmhrs,22239
|
|
22
22
|
mlrun/artifacts/plots.py,sha256=TxOHBaGbj7fEKNTHVIM_uxQjqPLpU3Rh1pqGh2_inuo,4833
|
|
23
23
|
mlrun/common/__init__.py,sha256=xY3wHC4TEJgez7qtnn1pQvHosi8-5UJOCtyGBS7FcGE,571
|
|
@@ -199,8 +199,8 @@ mlrun/frameworks/sklearn/mlrun_interface.py,sha256=JzHMBQM4sPBJqzb8P-rsG_2RQ_QrX
|
|
|
199
199
|
mlrun/frameworks/sklearn/model_handler.py,sha256=n0vpsQznva_WVloz7GTnfMGcMDQU_f1bHhUAJ_qxjfE,4753
|
|
200
200
|
mlrun/frameworks/sklearn/utils.py,sha256=Cg_pSxUMvKe8vBSLQor6JM8u9_ccKJg4Rk5EPDzTsVo,1209
|
|
201
201
|
mlrun/frameworks/tf_keras/__init__.py,sha256=M2sMbYHLrlF-KFR5kvA9mevRo3Nf8U0B5a_DM9rzwCY,10484
|
|
202
|
-
mlrun/frameworks/tf_keras/mlrun_interface.py,sha256=
|
|
203
|
-
mlrun/frameworks/tf_keras/model_handler.py,sha256=
|
|
202
|
+
mlrun/frameworks/tf_keras/mlrun_interface.py,sha256=ZKBqlkPY9Kk0Mhd1TLEamWCAOGTaiCyYytSCOlbYxXQ,16110
|
|
203
|
+
mlrun/frameworks/tf_keras/model_handler.py,sha256=2PZgxgB0XLIM6V2pLTWpU6UVKrjgGdSgcQNPz2kwFnU,32170
|
|
204
204
|
mlrun/frameworks/tf_keras/model_server.py,sha256=PZW6OBGTJ6bSfHedAhhW8HATbJyp2VaAzSDC02zjyKk,9653
|
|
205
205
|
mlrun/frameworks/tf_keras/utils.py,sha256=Z8hA1CgpSJWLC_T6Ay7xZKVyWlX9B85MSmQr2biXRag,4582
|
|
206
206
|
mlrun/frameworks/tf_keras/callbacks/__init__.py,sha256=sd8aWG2jO9mO_noZca0ReVf8X6fSCqO_di1Z-mT8FH8,742
|
|
@@ -341,11 +341,11 @@ mlrun/utils/notifications/notification/mail.py,sha256=ZyJ3eqd8simxffQmXzqd3bgbAq
|
|
|
341
341
|
mlrun/utils/notifications/notification/slack.py,sha256=kfhogR5keR7Zjh0VCjJNK3NR5_yXT7Cv-x9GdOUW4Z8,7294
|
|
342
342
|
mlrun/utils/notifications/notification/webhook.py,sha256=zxh8CAlbPnTazsk6r05X5TKwqUZVOH5KBU2fJbzQlG4,5330
|
|
343
343
|
mlrun/utils/version/__init__.py,sha256=7kkrB7hEZ3cLXoWj1kPoDwo4MaswsI2JVOBpbKgPAgc,614
|
|
344
|
-
mlrun/utils/version/version.json,sha256=
|
|
344
|
+
mlrun/utils/version/version.json,sha256=km_lKY6BDgwi6j7iUtnx5w9ONLfK5AYN2W7JQfgm2lI,84
|
|
345
345
|
mlrun/utils/version/version.py,sha256=eEW0tqIAkU9Xifxv8Z9_qsYnNhn3YH7NRAfM-pPLt1g,1878
|
|
346
|
-
mlrun-1.9.
|
|
347
|
-
mlrun-1.9.
|
|
348
|
-
mlrun-1.9.
|
|
349
|
-
mlrun-1.9.
|
|
350
|
-
mlrun-1.9.
|
|
351
|
-
mlrun-1.9.
|
|
346
|
+
mlrun-1.9.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
347
|
+
mlrun-1.9.2.dist-info/METADATA,sha256=evpYiCA1EF8NuHQs0recAwt-DUxWt-BJgb1ob9gb4ZY,25756
|
|
348
|
+
mlrun-1.9.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
349
|
+
mlrun-1.9.2.dist-info/entry_points.txt,sha256=1Owd16eAclD5pfRCoJpYC2ZJSyGNTtUr0nCELMioMmU,46
|
|
350
|
+
mlrun-1.9.2.dist-info/top_level.txt,sha256=NObLzw3maSF9wVrgSeYBv-fgnHkAJ1kEkh12DLdd5KM,6
|
|
351
|
+
mlrun-1.9.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|