mlrun 1.6.4rc8__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mlrun might be problematic. Click here for more details.
- mlrun/__init__.py +11 -1
- mlrun/__main__.py +40 -122
- mlrun/alerts/__init__.py +15 -0
- mlrun/alerts/alert.py +248 -0
- mlrun/api/schemas/__init__.py +5 -4
- mlrun/artifacts/__init__.py +8 -3
- mlrun/artifacts/base.py +47 -257
- mlrun/artifacts/dataset.py +11 -192
- mlrun/artifacts/manager.py +79 -47
- mlrun/artifacts/model.py +31 -159
- mlrun/artifacts/plots.py +23 -380
- mlrun/common/constants.py +74 -1
- mlrun/common/db/sql_session.py +5 -5
- mlrun/common/formatters/__init__.py +21 -0
- mlrun/common/formatters/artifact.py +45 -0
- mlrun/common/formatters/base.py +113 -0
- mlrun/common/formatters/feature_set.py +33 -0
- mlrun/common/formatters/function.py +46 -0
- mlrun/common/formatters/pipeline.py +53 -0
- mlrun/common/formatters/project.py +51 -0
- mlrun/common/formatters/run.py +29 -0
- mlrun/common/helpers.py +12 -3
- mlrun/common/model_monitoring/helpers.py +9 -5
- mlrun/{runtimes → common/runtimes}/constants.py +37 -9
- mlrun/common/schemas/__init__.py +31 -5
- mlrun/common/schemas/alert.py +202 -0
- mlrun/common/schemas/api_gateway.py +196 -0
- mlrun/common/schemas/artifact.py +25 -4
- mlrun/common/schemas/auth.py +16 -5
- mlrun/common/schemas/background_task.py +1 -1
- mlrun/common/schemas/client_spec.py +4 -2
- mlrun/common/schemas/common.py +7 -4
- mlrun/common/schemas/constants.py +3 -0
- mlrun/common/schemas/feature_store.py +74 -44
- mlrun/common/schemas/frontend_spec.py +15 -7
- mlrun/common/schemas/function.py +12 -1
- mlrun/common/schemas/hub.py +11 -18
- mlrun/common/schemas/memory_reports.py +2 -2
- mlrun/common/schemas/model_monitoring/__init__.py +20 -4
- mlrun/common/schemas/model_monitoring/constants.py +123 -42
- mlrun/common/schemas/model_monitoring/grafana.py +13 -9
- mlrun/common/schemas/model_monitoring/model_endpoints.py +101 -54
- mlrun/common/schemas/notification.py +71 -14
- mlrun/common/schemas/object.py +2 -2
- mlrun/{model_monitoring/controller_handler.py → common/schemas/pagination.py} +9 -12
- mlrun/common/schemas/pipeline.py +8 -1
- mlrun/common/schemas/project.py +69 -18
- mlrun/common/schemas/runs.py +7 -1
- mlrun/common/schemas/runtime_resource.py +8 -12
- mlrun/common/schemas/schedule.py +4 -4
- mlrun/common/schemas/tag.py +1 -2
- mlrun/common/schemas/workflow.py +12 -4
- mlrun/common/types.py +14 -1
- mlrun/config.py +154 -69
- mlrun/data_types/data_types.py +6 -1
- mlrun/data_types/spark.py +2 -2
- mlrun/data_types/to_pandas.py +67 -37
- mlrun/datastore/__init__.py +6 -8
- mlrun/datastore/alibaba_oss.py +131 -0
- mlrun/datastore/azure_blob.py +143 -42
- mlrun/datastore/base.py +102 -58
- mlrun/datastore/datastore.py +34 -13
- mlrun/datastore/datastore_profile.py +146 -20
- mlrun/datastore/dbfs_store.py +3 -7
- mlrun/datastore/filestore.py +1 -4
- mlrun/datastore/google_cloud_storage.py +97 -33
- mlrun/datastore/hdfs.py +56 -0
- mlrun/datastore/inmem.py +6 -3
- mlrun/datastore/redis.py +7 -2
- mlrun/datastore/s3.py +34 -12
- mlrun/datastore/snowflake_utils.py +45 -0
- mlrun/datastore/sources.py +303 -111
- mlrun/datastore/spark_utils.py +31 -2
- mlrun/datastore/store_resources.py +9 -7
- mlrun/datastore/storeytargets.py +151 -0
- mlrun/datastore/targets.py +453 -176
- mlrun/datastore/utils.py +72 -58
- mlrun/datastore/v3io.py +6 -1
- mlrun/db/base.py +274 -41
- mlrun/db/factory.py +1 -1
- mlrun/db/httpdb.py +893 -225
- mlrun/db/nopdb.py +291 -33
- mlrun/errors.py +36 -6
- mlrun/execution.py +115 -42
- mlrun/feature_store/__init__.py +0 -2
- mlrun/feature_store/api.py +65 -73
- mlrun/feature_store/common.py +7 -12
- mlrun/feature_store/feature_set.py +76 -55
- mlrun/feature_store/feature_vector.py +39 -31
- mlrun/feature_store/ingestion.py +7 -6
- mlrun/feature_store/retrieval/base.py +16 -11
- mlrun/feature_store/retrieval/dask_merger.py +2 -0
- mlrun/feature_store/retrieval/job.py +13 -4
- mlrun/feature_store/retrieval/local_merger.py +2 -0
- mlrun/feature_store/retrieval/spark_merger.py +24 -32
- mlrun/feature_store/steps.py +45 -34
- mlrun/features.py +11 -21
- mlrun/frameworks/_common/artifacts_library.py +9 -9
- mlrun/frameworks/_common/mlrun_interface.py +5 -5
- mlrun/frameworks/_common/model_handler.py +48 -48
- mlrun/frameworks/_common/plan.py +5 -6
- mlrun/frameworks/_common/producer.py +3 -4
- mlrun/frameworks/_common/utils.py +5 -5
- mlrun/frameworks/_dl_common/loggers/logger.py +6 -7
- mlrun/frameworks/_dl_common/loggers/mlrun_logger.py +9 -9
- mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py +23 -47
- mlrun/frameworks/_ml_common/artifacts_library.py +1 -2
- mlrun/frameworks/_ml_common/loggers/logger.py +3 -4
- mlrun/frameworks/_ml_common/loggers/mlrun_logger.py +4 -5
- mlrun/frameworks/_ml_common/model_handler.py +24 -24
- mlrun/frameworks/_ml_common/pkl_model_server.py +2 -2
- mlrun/frameworks/_ml_common/plan.py +2 -2
- mlrun/frameworks/_ml_common/plans/calibration_curve_plan.py +2 -3
- mlrun/frameworks/_ml_common/plans/confusion_matrix_plan.py +2 -3
- mlrun/frameworks/_ml_common/plans/dataset_plan.py +3 -3
- mlrun/frameworks/_ml_common/plans/feature_importance_plan.py +3 -3
- mlrun/frameworks/_ml_common/plans/roc_curve_plan.py +4 -4
- mlrun/frameworks/_ml_common/utils.py +4 -4
- mlrun/frameworks/auto_mlrun/auto_mlrun.py +9 -9
- mlrun/frameworks/huggingface/model_server.py +4 -4
- mlrun/frameworks/lgbm/__init__.py +33 -33
- mlrun/frameworks/lgbm/callbacks/callback.py +2 -4
- mlrun/frameworks/lgbm/callbacks/logging_callback.py +4 -5
- mlrun/frameworks/lgbm/callbacks/mlrun_logging_callback.py +4 -5
- mlrun/frameworks/lgbm/mlrun_interfaces/booster_mlrun_interface.py +1 -3
- mlrun/frameworks/lgbm/mlrun_interfaces/mlrun_interface.py +6 -6
- mlrun/frameworks/lgbm/model_handler.py +10 -10
- mlrun/frameworks/lgbm/model_server.py +6 -6
- mlrun/frameworks/lgbm/utils.py +5 -5
- mlrun/frameworks/onnx/dataset.py +8 -8
- mlrun/frameworks/onnx/mlrun_interface.py +3 -3
- mlrun/frameworks/onnx/model_handler.py +6 -6
- mlrun/frameworks/onnx/model_server.py +7 -7
- mlrun/frameworks/parallel_coordinates.py +6 -6
- mlrun/frameworks/pytorch/__init__.py +18 -18
- mlrun/frameworks/pytorch/callbacks/callback.py +4 -5
- mlrun/frameworks/pytorch/callbacks/logging_callback.py +17 -17
- mlrun/frameworks/pytorch/callbacks/mlrun_logging_callback.py +11 -11
- mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py +23 -29
- mlrun/frameworks/pytorch/callbacks_handler.py +38 -38
- mlrun/frameworks/pytorch/mlrun_interface.py +20 -20
- mlrun/frameworks/pytorch/model_handler.py +17 -17
- mlrun/frameworks/pytorch/model_server.py +7 -7
- mlrun/frameworks/sklearn/__init__.py +13 -13
- mlrun/frameworks/sklearn/estimator.py +4 -4
- mlrun/frameworks/sklearn/metrics_library.py +14 -14
- mlrun/frameworks/sklearn/mlrun_interface.py +16 -9
- mlrun/frameworks/sklearn/model_handler.py +2 -2
- mlrun/frameworks/tf_keras/__init__.py +10 -7
- mlrun/frameworks/tf_keras/callbacks/logging_callback.py +15 -15
- mlrun/frameworks/tf_keras/callbacks/mlrun_logging_callback.py +11 -11
- mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py +19 -23
- mlrun/frameworks/tf_keras/mlrun_interface.py +9 -11
- mlrun/frameworks/tf_keras/model_handler.py +14 -14
- mlrun/frameworks/tf_keras/model_server.py +6 -6
- mlrun/frameworks/xgboost/__init__.py +13 -13
- mlrun/frameworks/xgboost/model_handler.py +6 -6
- mlrun/k8s_utils.py +61 -17
- mlrun/launcher/__init__.py +1 -1
- mlrun/launcher/base.py +16 -15
- mlrun/launcher/client.py +13 -11
- mlrun/launcher/factory.py +1 -1
- mlrun/launcher/local.py +23 -13
- mlrun/launcher/remote.py +17 -10
- mlrun/lists.py +7 -6
- mlrun/model.py +478 -103
- mlrun/model_monitoring/__init__.py +1 -1
- mlrun/model_monitoring/api.py +163 -371
- mlrun/{runtimes/mpijob/v1alpha1.py → model_monitoring/applications/__init__.py} +9 -15
- mlrun/model_monitoring/applications/_application_steps.py +188 -0
- mlrun/model_monitoring/applications/base.py +108 -0
- mlrun/model_monitoring/applications/context.py +341 -0
- mlrun/model_monitoring/{evidently_application.py → applications/evidently_base.py} +27 -22
- mlrun/model_monitoring/applications/histogram_data_drift.py +354 -0
- mlrun/model_monitoring/applications/results.py +99 -0
- mlrun/model_monitoring/controller.py +131 -278
- mlrun/model_monitoring/db/__init__.py +18 -0
- mlrun/model_monitoring/db/stores/__init__.py +136 -0
- mlrun/model_monitoring/db/stores/base/__init__.py +15 -0
- mlrun/model_monitoring/db/stores/base/store.py +213 -0
- mlrun/model_monitoring/db/stores/sqldb/__init__.py +13 -0
- mlrun/model_monitoring/db/stores/sqldb/models/__init__.py +71 -0
- mlrun/model_monitoring/db/stores/sqldb/models/base.py +190 -0
- mlrun/model_monitoring/db/stores/sqldb/models/mysql.py +103 -0
- mlrun/model_monitoring/{stores/models/mysql.py → db/stores/sqldb/models/sqlite.py} +19 -13
- mlrun/model_monitoring/db/stores/sqldb/sql_store.py +659 -0
- mlrun/model_monitoring/db/stores/v3io_kv/__init__.py +13 -0
- mlrun/model_monitoring/db/stores/v3io_kv/kv_store.py +726 -0
- mlrun/model_monitoring/db/tsdb/__init__.py +105 -0
- mlrun/model_monitoring/db/tsdb/base.py +448 -0
- mlrun/model_monitoring/db/tsdb/helpers.py +30 -0
- mlrun/model_monitoring/db/tsdb/tdengine/__init__.py +15 -0
- mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +279 -0
- mlrun/model_monitoring/db/tsdb/tdengine/stream_graph_steps.py +42 -0
- mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +507 -0
- mlrun/model_monitoring/db/tsdb/v3io/__init__.py +15 -0
- mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +158 -0
- mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +849 -0
- mlrun/model_monitoring/features_drift_table.py +134 -106
- mlrun/model_monitoring/helpers.py +199 -55
- mlrun/model_monitoring/metrics/__init__.py +13 -0
- mlrun/model_monitoring/metrics/histogram_distance.py +127 -0
- mlrun/model_monitoring/model_endpoint.py +3 -2
- mlrun/model_monitoring/stream_processing.py +134 -398
- mlrun/model_monitoring/tracking_policy.py +9 -2
- mlrun/model_monitoring/writer.py +161 -125
- mlrun/package/__init__.py +6 -6
- mlrun/package/context_handler.py +5 -5
- mlrun/package/packager.py +7 -7
- mlrun/package/packagers/default_packager.py +8 -8
- mlrun/package/packagers/numpy_packagers.py +15 -15
- mlrun/package/packagers/pandas_packagers.py +5 -5
- mlrun/package/packagers/python_standard_library_packagers.py +10 -10
- mlrun/package/packagers_manager.py +19 -23
- mlrun/package/utils/_formatter.py +6 -6
- mlrun/package/utils/_pickler.py +2 -2
- mlrun/package/utils/_supported_format.py +4 -4
- mlrun/package/utils/log_hint_utils.py +2 -2
- mlrun/package/utils/type_hint_utils.py +4 -9
- mlrun/platforms/__init__.py +11 -10
- mlrun/platforms/iguazio.py +24 -203
- mlrun/projects/operations.py +52 -25
- mlrun/projects/pipelines.py +191 -197
- mlrun/projects/project.py +1227 -400
- mlrun/render.py +16 -19
- mlrun/run.py +209 -184
- mlrun/runtimes/__init__.py +83 -15
- mlrun/runtimes/base.py +51 -35
- mlrun/runtimes/daskjob.py +17 -10
- mlrun/runtimes/databricks_job/databricks_cancel_task.py +1 -1
- mlrun/runtimes/databricks_job/databricks_runtime.py +8 -7
- mlrun/runtimes/databricks_job/databricks_wrapper.py +1 -1
- mlrun/runtimes/funcdoc.py +1 -29
- mlrun/runtimes/function_reference.py +1 -1
- mlrun/runtimes/kubejob.py +34 -128
- mlrun/runtimes/local.py +40 -11
- mlrun/runtimes/mpijob/__init__.py +0 -20
- mlrun/runtimes/mpijob/abstract.py +9 -10
- mlrun/runtimes/mpijob/v1.py +1 -1
- mlrun/{model_monitoring/stores/models/sqlite.py → runtimes/nuclio/__init__.py} +7 -9
- mlrun/runtimes/nuclio/api_gateway.py +769 -0
- mlrun/runtimes/nuclio/application/__init__.py +15 -0
- mlrun/runtimes/nuclio/application/application.py +758 -0
- mlrun/runtimes/nuclio/application/reverse_proxy.go +95 -0
- mlrun/runtimes/{function.py → nuclio/function.py} +200 -83
- mlrun/runtimes/{nuclio.py → nuclio/nuclio.py} +6 -6
- mlrun/runtimes/{serving.py → nuclio/serving.py} +65 -68
- mlrun/runtimes/pod.py +281 -101
- mlrun/runtimes/remotesparkjob.py +12 -9
- mlrun/runtimes/sparkjob/spark3job.py +67 -51
- mlrun/runtimes/utils.py +41 -75
- mlrun/secrets.py +9 -5
- mlrun/serving/__init__.py +8 -1
- mlrun/serving/remote.py +2 -7
- mlrun/serving/routers.py +85 -69
- mlrun/serving/server.py +69 -44
- mlrun/serving/states.py +209 -36
- mlrun/serving/utils.py +22 -14
- mlrun/serving/v1_serving.py +6 -7
- mlrun/serving/v2_serving.py +133 -54
- mlrun/track/tracker.py +2 -1
- mlrun/track/tracker_manager.py +3 -3
- mlrun/track/trackers/mlflow_tracker.py +6 -2
- mlrun/utils/async_http.py +6 -8
- mlrun/utils/azure_vault.py +1 -1
- mlrun/utils/clones.py +1 -2
- mlrun/utils/condition_evaluator.py +3 -3
- mlrun/utils/db.py +21 -3
- mlrun/utils/helpers.py +405 -225
- mlrun/utils/http.py +3 -6
- mlrun/utils/logger.py +112 -16
- mlrun/utils/notifications/notification/__init__.py +17 -13
- mlrun/utils/notifications/notification/base.py +50 -2
- mlrun/utils/notifications/notification/console.py +2 -0
- mlrun/utils/notifications/notification/git.py +24 -1
- mlrun/utils/notifications/notification/ipython.py +3 -1
- mlrun/utils/notifications/notification/slack.py +96 -21
- mlrun/utils/notifications/notification/webhook.py +59 -2
- mlrun/utils/notifications/notification_pusher.py +149 -30
- mlrun/utils/regex.py +9 -0
- mlrun/utils/retryer.py +208 -0
- mlrun/utils/singleton.py +1 -1
- mlrun/utils/v3io_clients.py +4 -6
- mlrun/utils/version/version.json +2 -2
- mlrun/utils/version/version.py +2 -6
- mlrun-1.7.0.dist-info/METADATA +378 -0
- mlrun-1.7.0.dist-info/RECORD +351 -0
- {mlrun-1.6.4rc8.dist-info → mlrun-1.7.0.dist-info}/WHEEL +1 -1
- mlrun/feature_store/retrieval/conversion.py +0 -273
- mlrun/kfpops.py +0 -868
- mlrun/model_monitoring/application.py +0 -310
- mlrun/model_monitoring/batch.py +0 -1095
- mlrun/model_monitoring/prometheus.py +0 -219
- mlrun/model_monitoring/stores/__init__.py +0 -111
- mlrun/model_monitoring/stores/kv_model_endpoint_store.py +0 -576
- mlrun/model_monitoring/stores/model_endpoint_store.py +0 -147
- mlrun/model_monitoring/stores/models/__init__.py +0 -27
- mlrun/model_monitoring/stores/models/base.py +0 -84
- mlrun/model_monitoring/stores/sql_model_endpoint_store.py +0 -384
- mlrun/platforms/other.py +0 -306
- mlrun-1.6.4rc8.dist-info/METADATA +0 -272
- mlrun-1.6.4rc8.dist-info/RECORD +0 -314
- {mlrun-1.6.4rc8.dist-info → mlrun-1.7.0.dist-info}/LICENSE +0 -0
- {mlrun-1.6.4rc8.dist-info → mlrun-1.7.0.dist-info}/entry_points.txt +0 -0
- {mlrun-1.6.4rc8.dist-info → mlrun-1.7.0.dist-info}/top_level.txt +0 -0
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
|
-
from typing import
|
|
15
|
+
from typing import Union
|
|
16
16
|
|
|
17
17
|
from torch import Tensor
|
|
18
18
|
from torch.nn import Module
|
|
@@ -66,7 +66,7 @@ class CallbacksHandler:
|
|
|
66
66
|
A class for handling multiple callbacks during a run.
|
|
67
67
|
"""
|
|
68
68
|
|
|
69
|
-
def __init__(self, callbacks:
|
|
69
|
+
def __init__(self, callbacks: list[Union[Callback, tuple[str, Callback]]]):
|
|
70
70
|
"""
|
|
71
71
|
Initialize the callbacks handler with the given callbacks he will handle. The callbacks can be passed as their
|
|
72
72
|
initialized instances or as a tuple where [0] is a name that will be attached to him and [1] will be the
|
|
@@ -99,7 +99,7 @@ class CallbacksHandler:
|
|
|
99
99
|
self._callbacks[callback.__class__.__name__] = callback
|
|
100
100
|
|
|
101
101
|
@property
|
|
102
|
-
def callbacks(self) ->
|
|
102
|
+
def callbacks(self) -> dict[str, Callback]:
|
|
103
103
|
"""
|
|
104
104
|
Get the callbacks dictionary handled by this handler.
|
|
105
105
|
|
|
@@ -114,9 +114,9 @@ class CallbacksHandler:
|
|
|
114
114
|
validation_set: DataLoader,
|
|
115
115
|
loss_function: Module,
|
|
116
116
|
optimizer: Optimizer,
|
|
117
|
-
metric_functions:
|
|
117
|
+
metric_functions: list[PyTorchTypes.MetricFunctionType],
|
|
118
118
|
scheduler,
|
|
119
|
-
callbacks:
|
|
119
|
+
callbacks: list[str] = None,
|
|
120
120
|
) -> bool:
|
|
121
121
|
"""
|
|
122
122
|
Call the 'on_setup' method of every callback in the callbacks list. If the list is 'None' (not given), all
|
|
@@ -145,7 +145,7 @@ class CallbacksHandler:
|
|
|
145
145
|
scheduler=scheduler,
|
|
146
146
|
)
|
|
147
147
|
|
|
148
|
-
def on_run_begin(self, callbacks:
|
|
148
|
+
def on_run_begin(self, callbacks: list[str] = None) -> bool:
|
|
149
149
|
"""
|
|
150
150
|
Call the 'on_run_begin' method of every callback in the callbacks list. If the list is 'None' (not given), all
|
|
151
151
|
callbacks will be called.
|
|
@@ -159,7 +159,7 @@ class CallbacksHandler:
|
|
|
159
159
|
callbacks=self._parse_names(names=callbacks),
|
|
160
160
|
)
|
|
161
161
|
|
|
162
|
-
def on_run_end(self, callbacks:
|
|
162
|
+
def on_run_end(self, callbacks: list[str] = None) -> bool:
|
|
163
163
|
"""
|
|
164
164
|
Call the 'on_run_end' method of every callback in the callbacks list. If the list is 'None' (not given), all
|
|
165
165
|
callbacks will be called.
|
|
@@ -173,7 +173,7 @@ class CallbacksHandler:
|
|
|
173
173
|
callbacks=self._parse_names(names=callbacks),
|
|
174
174
|
)
|
|
175
175
|
|
|
176
|
-
def on_epoch_begin(self, epoch: int, callbacks:
|
|
176
|
+
def on_epoch_begin(self, epoch: int, callbacks: list[str] = None) -> bool:
|
|
177
177
|
"""
|
|
178
178
|
Call the 'on_epoch_begin' method of every callback in the callbacks list. If the list is 'None' (not given), all
|
|
179
179
|
callbacks will be called.
|
|
@@ -189,7 +189,7 @@ class CallbacksHandler:
|
|
|
189
189
|
epoch=epoch,
|
|
190
190
|
)
|
|
191
191
|
|
|
192
|
-
def on_epoch_end(self, epoch: int, callbacks:
|
|
192
|
+
def on_epoch_end(self, epoch: int, callbacks: list[str] = None) -> bool:
|
|
193
193
|
"""
|
|
194
194
|
Call the 'on_epoch_end' method of every callback in the callbacks list. If the list is 'None' (not given), all
|
|
195
195
|
callbacks will be called.
|
|
@@ -205,7 +205,7 @@ class CallbacksHandler:
|
|
|
205
205
|
epoch=epoch,
|
|
206
206
|
)
|
|
207
207
|
|
|
208
|
-
def on_train_begin(self, callbacks:
|
|
208
|
+
def on_train_begin(self, callbacks: list[str] = None) -> bool:
|
|
209
209
|
"""
|
|
210
210
|
Call the 'on_train_begin' method of every callback in the callbacks list. If the list is 'None' (not given), all
|
|
211
211
|
callbacks will be called.
|
|
@@ -219,7 +219,7 @@ class CallbacksHandler:
|
|
|
219
219
|
callbacks=self._parse_names(names=callbacks),
|
|
220
220
|
)
|
|
221
221
|
|
|
222
|
-
def on_train_end(self, callbacks:
|
|
222
|
+
def on_train_end(self, callbacks: list[str] = None) -> bool:
|
|
223
223
|
"""
|
|
224
224
|
Call the 'on_train_end' method of every callback in the callbacks list. If the list is 'None' (not given), all
|
|
225
225
|
callbacks will be called.
|
|
@@ -233,7 +233,7 @@ class CallbacksHandler:
|
|
|
233
233
|
callbacks=self._parse_names(names=callbacks),
|
|
234
234
|
)
|
|
235
235
|
|
|
236
|
-
def on_validation_begin(self, callbacks:
|
|
236
|
+
def on_validation_begin(self, callbacks: list[str] = None) -> bool:
|
|
237
237
|
"""
|
|
238
238
|
Call the 'on_validation_begin' method of every callback in the callbacks list. If the list is 'None'
|
|
239
239
|
(not given), all callbacks will be called.
|
|
@@ -250,8 +250,8 @@ class CallbacksHandler:
|
|
|
250
250
|
def on_validation_end(
|
|
251
251
|
self,
|
|
252
252
|
loss_value: PyTorchTypes.MetricValueType,
|
|
253
|
-
metric_values:
|
|
254
|
-
callbacks:
|
|
253
|
+
metric_values: list[float],
|
|
254
|
+
callbacks: list[str] = None,
|
|
255
255
|
) -> bool:
|
|
256
256
|
"""
|
|
257
257
|
Call the 'on_validation_end' method of every callback in the callbacks list. If the list is 'None' (not given),
|
|
@@ -271,7 +271,7 @@ class CallbacksHandler:
|
|
|
271
271
|
)
|
|
272
272
|
|
|
273
273
|
def on_train_batch_begin(
|
|
274
|
-
self, batch: int, x, y_true: Tensor, callbacks:
|
|
274
|
+
self, batch: int, x, y_true: Tensor, callbacks: list[str] = None
|
|
275
275
|
) -> bool:
|
|
276
276
|
"""
|
|
277
277
|
Call the 'on_train_batch_begin' method of every callback in the callbacks list. If the list is 'None'
|
|
@@ -298,7 +298,7 @@ class CallbacksHandler:
|
|
|
298
298
|
x,
|
|
299
299
|
y_pred: Tensor,
|
|
300
300
|
y_true: Tensor,
|
|
301
|
-
callbacks:
|
|
301
|
+
callbacks: list[str] = None,
|
|
302
302
|
) -> bool:
|
|
303
303
|
"""
|
|
304
304
|
Call the 'on_train_batch_end' method of every callback in the callbacks list. If the list is 'None' (not given),
|
|
@@ -322,7 +322,7 @@ class CallbacksHandler:
|
|
|
322
322
|
)
|
|
323
323
|
|
|
324
324
|
def on_validation_batch_begin(
|
|
325
|
-
self, batch: int, x, y_true: Tensor, callbacks:
|
|
325
|
+
self, batch: int, x, y_true: Tensor, callbacks: list[str] = None
|
|
326
326
|
) -> bool:
|
|
327
327
|
"""
|
|
328
328
|
Call the 'on_validation_batch_begin' method of every callback in the callbacks list. If the list is 'None'
|
|
@@ -349,7 +349,7 @@ class CallbacksHandler:
|
|
|
349
349
|
x,
|
|
350
350
|
y_pred: Tensor,
|
|
351
351
|
y_true: Tensor,
|
|
352
|
-
callbacks:
|
|
352
|
+
callbacks: list[str] = None,
|
|
353
353
|
) -> bool:
|
|
354
354
|
"""
|
|
355
355
|
Call the 'on_validation_batch_end' method of every callback in the callbacks list. If the list is 'None'
|
|
@@ -375,7 +375,7 @@ class CallbacksHandler:
|
|
|
375
375
|
def on_inference_begin(
|
|
376
376
|
self,
|
|
377
377
|
x,
|
|
378
|
-
callbacks:
|
|
378
|
+
callbacks: list[str] = None,
|
|
379
379
|
) -> bool:
|
|
380
380
|
"""
|
|
381
381
|
Call the 'on_inference_begin' method of every callback in the callbacks list. If the list is 'None' (not given),
|
|
@@ -396,7 +396,7 @@ class CallbacksHandler:
|
|
|
396
396
|
self,
|
|
397
397
|
y_pred: Tensor,
|
|
398
398
|
y_true: Tensor,
|
|
399
|
-
callbacks:
|
|
399
|
+
callbacks: list[str] = None,
|
|
400
400
|
) -> bool:
|
|
401
401
|
"""
|
|
402
402
|
Call the 'on_inference_end' method of every callback in the callbacks list. If the list is 'None' (not given),
|
|
@@ -415,7 +415,7 @@ class CallbacksHandler:
|
|
|
415
415
|
y_true=y_true,
|
|
416
416
|
)
|
|
417
417
|
|
|
418
|
-
def on_train_loss_begin(self, callbacks:
|
|
418
|
+
def on_train_loss_begin(self, callbacks: list[str] = None) -> bool:
|
|
419
419
|
"""
|
|
420
420
|
Call the 'on_train_loss_begin' method of every callback in the callbacks list. If the list is 'None'
|
|
421
421
|
(not given), all callbacks will be called.
|
|
@@ -430,7 +430,7 @@ class CallbacksHandler:
|
|
|
430
430
|
)
|
|
431
431
|
|
|
432
432
|
def on_train_loss_end(
|
|
433
|
-
self, loss_value: PyTorchTypes.MetricValueType, callbacks:
|
|
433
|
+
self, loss_value: PyTorchTypes.MetricValueType, callbacks: list[str] = None
|
|
434
434
|
) -> bool:
|
|
435
435
|
"""
|
|
436
436
|
Call the 'on_train_loss_end' method of every callback in the callbacks list. If the list is 'None' (not given),
|
|
@@ -447,7 +447,7 @@ class CallbacksHandler:
|
|
|
447
447
|
loss_value=loss_value,
|
|
448
448
|
)
|
|
449
449
|
|
|
450
|
-
def on_validation_loss_begin(self, callbacks:
|
|
450
|
+
def on_validation_loss_begin(self, callbacks: list[str] = None) -> bool:
|
|
451
451
|
"""
|
|
452
452
|
Call the 'on_validation_loss_begin' method of every callback in the callbacks list. If the list is 'None'
|
|
453
453
|
(not given), all callbacks will be called.
|
|
@@ -462,7 +462,7 @@ class CallbacksHandler:
|
|
|
462
462
|
)
|
|
463
463
|
|
|
464
464
|
def on_validation_loss_end(
|
|
465
|
-
self, loss_value: PyTorchTypes.MetricValueType, callbacks:
|
|
465
|
+
self, loss_value: PyTorchTypes.MetricValueType, callbacks: list[str] = None
|
|
466
466
|
) -> bool:
|
|
467
467
|
"""
|
|
468
468
|
Call the 'on_validation_loss_end' method of every callback in the callbacks list. If the list is 'None'
|
|
@@ -479,7 +479,7 @@ class CallbacksHandler:
|
|
|
479
479
|
loss_value=loss_value,
|
|
480
480
|
)
|
|
481
481
|
|
|
482
|
-
def on_train_metrics_begin(self, callbacks:
|
|
482
|
+
def on_train_metrics_begin(self, callbacks: list[str] = None) -> bool:
|
|
483
483
|
"""
|
|
484
484
|
Call the 'on_train_metrics_begin' method of every callback in the callbacks list. If the list is 'None'
|
|
485
485
|
(not given), all callbacks will be called.
|
|
@@ -495,8 +495,8 @@ class CallbacksHandler:
|
|
|
495
495
|
|
|
496
496
|
def on_train_metrics_end(
|
|
497
497
|
self,
|
|
498
|
-
metric_values:
|
|
499
|
-
callbacks:
|
|
498
|
+
metric_values: list[PyTorchTypes.MetricValueType],
|
|
499
|
+
callbacks: list[str] = None,
|
|
500
500
|
) -> bool:
|
|
501
501
|
"""
|
|
502
502
|
Call the 'on_train_metrics_end' method of every callback in the callbacks list. If the list is 'None'
|
|
@@ -513,7 +513,7 @@ class CallbacksHandler:
|
|
|
513
513
|
metric_values=metric_values,
|
|
514
514
|
)
|
|
515
515
|
|
|
516
|
-
def on_validation_metrics_begin(self, callbacks:
|
|
516
|
+
def on_validation_metrics_begin(self, callbacks: list[str] = None) -> bool:
|
|
517
517
|
"""
|
|
518
518
|
Call the 'on_validation_metrics_begin' method of every callback in the callbacks list. If the list is 'None'
|
|
519
519
|
(not given), all callbacks will be called.
|
|
@@ -529,8 +529,8 @@ class CallbacksHandler:
|
|
|
529
529
|
|
|
530
530
|
def on_validation_metrics_end(
|
|
531
531
|
self,
|
|
532
|
-
metric_values:
|
|
533
|
-
callbacks:
|
|
532
|
+
metric_values: list[PyTorchTypes.MetricValueType],
|
|
533
|
+
callbacks: list[str] = None,
|
|
534
534
|
) -> bool:
|
|
535
535
|
"""
|
|
536
536
|
Call the 'on_validation_metrics_end' method of every callback in the callbacks list. If the list is 'None'
|
|
@@ -547,7 +547,7 @@ class CallbacksHandler:
|
|
|
547
547
|
metric_values=metric_values,
|
|
548
548
|
)
|
|
549
549
|
|
|
550
|
-
def on_backward_begin(self, callbacks:
|
|
550
|
+
def on_backward_begin(self, callbacks: list[str] = None) -> bool:
|
|
551
551
|
"""
|
|
552
552
|
Call the 'on_backward_begin' method of every callback in the callbacks list. If the list is 'None' (not given),
|
|
553
553
|
all callbacks will be called.
|
|
@@ -561,7 +561,7 @@ class CallbacksHandler:
|
|
|
561
561
|
callbacks=self._parse_names(names=callbacks),
|
|
562
562
|
)
|
|
563
563
|
|
|
564
|
-
def on_backward_end(self, callbacks:
|
|
564
|
+
def on_backward_end(self, callbacks: list[str] = None) -> bool:
|
|
565
565
|
"""
|
|
566
566
|
Call the 'on_backward_end' method of every callback in the callbacks list. If the list is 'None' (not given),
|
|
567
567
|
all callbacks will be called.
|
|
@@ -575,7 +575,7 @@ class CallbacksHandler:
|
|
|
575
575
|
callbacks=self._parse_names(names=callbacks),
|
|
576
576
|
)
|
|
577
577
|
|
|
578
|
-
def on_optimizer_step_begin(self, callbacks:
|
|
578
|
+
def on_optimizer_step_begin(self, callbacks: list[str] = None) -> bool:
|
|
579
579
|
"""
|
|
580
580
|
Call the 'on_optimizer_step_begin' method of every callback in the callbacks list. If the list is 'None'
|
|
581
581
|
(not given), all callbacks will be called.
|
|
@@ -589,7 +589,7 @@ class CallbacksHandler:
|
|
|
589
589
|
callbacks=self._parse_names(names=callbacks),
|
|
590
590
|
)
|
|
591
591
|
|
|
592
|
-
def on_optimizer_step_end(self, callbacks:
|
|
592
|
+
def on_optimizer_step_end(self, callbacks: list[str] = None) -> bool:
|
|
593
593
|
"""
|
|
594
594
|
Call the 'on_optimizer_step_end' method of every callback in the callbacks list. If the list is 'None'
|
|
595
595
|
(not given), all callbacks will be called.
|
|
@@ -603,7 +603,7 @@ class CallbacksHandler:
|
|
|
603
603
|
callbacks=self._parse_names(names=callbacks),
|
|
604
604
|
)
|
|
605
605
|
|
|
606
|
-
def on_scheduler_step_begin(self, callbacks:
|
|
606
|
+
def on_scheduler_step_begin(self, callbacks: list[str] = None) -> bool:
|
|
607
607
|
"""
|
|
608
608
|
Call the 'on_scheduler_step_begin' method of every callback in the callbacks list. If the list is 'None'
|
|
609
609
|
(not given), all callbacks will be called.
|
|
@@ -617,7 +617,7 @@ class CallbacksHandler:
|
|
|
617
617
|
callbacks=self._parse_names(names=callbacks),
|
|
618
618
|
)
|
|
619
619
|
|
|
620
|
-
def on_scheduler_step_end(self, callbacks:
|
|
620
|
+
def on_scheduler_step_end(self, callbacks: list[str] = None) -> bool:
|
|
621
621
|
"""
|
|
622
622
|
Call the 'on_scheduler_step_end' method of every callback in the callbacks list. If the list is 'None'
|
|
623
623
|
(not given), all callbacks will be called.
|
|
@@ -631,7 +631,7 @@ class CallbacksHandler:
|
|
|
631
631
|
callbacks=self._parse_names(names=callbacks),
|
|
632
632
|
)
|
|
633
633
|
|
|
634
|
-
def _parse_names(self, names: Union[
|
|
634
|
+
def _parse_names(self, names: Union[list[str], None]) -> list[str]:
|
|
635
635
|
"""
|
|
636
636
|
Parse the given callbacks names. If they are not 'None' then the names will be returned as they are, otherwise
|
|
637
637
|
all of the callbacks handled by this handler will be returned (the default behavior of when there were no names
|
|
@@ -646,7 +646,7 @@ class CallbacksHandler:
|
|
|
646
646
|
return list(self._callbacks.keys())
|
|
647
647
|
|
|
648
648
|
def _run_callbacks(
|
|
649
|
-
self, method_name: str, callbacks:
|
|
649
|
+
self, method_name: str, callbacks: list[str], *args, **kwargs
|
|
650
650
|
) -> bool:
|
|
651
651
|
"""
|
|
652
652
|
Run the given method from the 'CallbackInterface' on all the specified callbacks with the given arguments.
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
#
|
|
15
15
|
import importlib
|
|
16
16
|
import sys
|
|
17
|
-
from typing import Any,
|
|
17
|
+
from typing import Any, Union
|
|
18
18
|
|
|
19
19
|
import torch
|
|
20
20
|
import torch.multiprocessing as mp
|
|
@@ -109,13 +109,13 @@ class PyTorchMLRunInterface:
|
|
|
109
109
|
loss_function: Module,
|
|
110
110
|
optimizer: Optimizer,
|
|
111
111
|
validation_set: DataLoader = None,
|
|
112
|
-
metric_functions:
|
|
112
|
+
metric_functions: list[PyTorchTypes.MetricFunctionType] = None,
|
|
113
113
|
scheduler=None,
|
|
114
114
|
scheduler_step_frequency: Union[int, float, str] = "epoch",
|
|
115
115
|
epochs: int = 1,
|
|
116
116
|
training_iterations: int = None,
|
|
117
117
|
validation_iterations: int = None,
|
|
118
|
-
callbacks:
|
|
118
|
+
callbacks: list[Callback] = None,
|
|
119
119
|
use_cuda: bool = True,
|
|
120
120
|
use_horovod: bool = None,
|
|
121
121
|
):
|
|
@@ -221,12 +221,12 @@ class PyTorchMLRunInterface:
|
|
|
221
221
|
self,
|
|
222
222
|
dataset: DataLoader,
|
|
223
223
|
loss_function: Module = None,
|
|
224
|
-
metric_functions:
|
|
224
|
+
metric_functions: list[PyTorchTypes.MetricFunctionType] = None,
|
|
225
225
|
iterations: int = None,
|
|
226
|
-
callbacks:
|
|
226
|
+
callbacks: list[Callback] = None,
|
|
227
227
|
use_cuda: bool = True,
|
|
228
228
|
use_horovod: bool = None,
|
|
229
|
-
) ->
|
|
229
|
+
) -> list[PyTorchTypes.MetricValueType]:
|
|
230
230
|
"""
|
|
231
231
|
Initiate an evaluation process on this interface configuration.
|
|
232
232
|
|
|
@@ -303,9 +303,9 @@ class PyTorchMLRunInterface:
|
|
|
303
303
|
def add_auto_logging_callbacks(
|
|
304
304
|
self,
|
|
305
305
|
add_mlrun_logger: bool = True,
|
|
306
|
-
mlrun_callback_kwargs:
|
|
306
|
+
mlrun_callback_kwargs: dict[str, Any] = None,
|
|
307
307
|
add_tensorboard_logger: bool = True,
|
|
308
|
-
tensorboard_callback_kwargs:
|
|
308
|
+
tensorboard_callback_kwargs: dict[str, Any] = None,
|
|
309
309
|
):
|
|
310
310
|
"""
|
|
311
311
|
Get automatic logging callbacks to both MLRun's context and Tensorboard. For further features of logging to both
|
|
@@ -347,7 +347,7 @@ class PyTorchMLRunInterface:
|
|
|
347
347
|
|
|
348
348
|
def predict(
|
|
349
349
|
self,
|
|
350
|
-
inputs: Union[Tensor,
|
|
350
|
+
inputs: Union[Tensor, list[Tensor]],
|
|
351
351
|
use_cuda: bool = True,
|
|
352
352
|
batch_size: int = -1,
|
|
353
353
|
) -> Tensor:
|
|
@@ -402,13 +402,13 @@ class PyTorchMLRunInterface:
|
|
|
402
402
|
loss_function: Module = None,
|
|
403
403
|
optimizer: Optimizer = None,
|
|
404
404
|
validation_set: DataLoader = None,
|
|
405
|
-
metric_functions:
|
|
405
|
+
metric_functions: list[PyTorchTypes.MetricFunctionType] = None,
|
|
406
406
|
scheduler=None,
|
|
407
407
|
scheduler_step_frequency: Union[int, float, str] = "epoch",
|
|
408
408
|
epochs: int = 1,
|
|
409
409
|
training_iterations: int = None,
|
|
410
410
|
validation_iterations: int = None,
|
|
411
|
-
callbacks:
|
|
411
|
+
callbacks: list[Callback] = None,
|
|
412
412
|
use_cuda: bool = True,
|
|
413
413
|
use_horovod: bool = None,
|
|
414
414
|
):
|
|
@@ -734,7 +734,7 @@ class PyTorchMLRunInterface:
|
|
|
734
734
|
|
|
735
735
|
def _validate(
|
|
736
736
|
self, is_evaluation: bool = False
|
|
737
|
-
) ->
|
|
737
|
+
) -> tuple[PyTorchTypes.MetricValueType, list[PyTorchTypes.MetricValueType]]:
|
|
738
738
|
"""
|
|
739
739
|
Initiate a single epoch validation.
|
|
740
740
|
|
|
@@ -817,7 +817,7 @@ class PyTorchMLRunInterface:
|
|
|
817
817
|
)
|
|
818
818
|
return loss_value, metric_values
|
|
819
819
|
|
|
820
|
-
def _print_results(self, loss_value: Tensor, metric_values:
|
|
820
|
+
def _print_results(self, loss_value: Tensor, metric_values: list[float]):
|
|
821
821
|
"""
|
|
822
822
|
Print the given result between each epoch.
|
|
823
823
|
|
|
@@ -832,7 +832,7 @@ class PyTorchMLRunInterface:
|
|
|
832
832
|
+ tabulate(table, headers=["Metrics", "Values"], tablefmt="pretty")
|
|
833
833
|
)
|
|
834
834
|
|
|
835
|
-
def _metrics(self, y_pred: Tensor, y_true: Tensor) ->
|
|
835
|
+
def _metrics(self, y_pred: Tensor, y_true: Tensor) -> list[float]:
|
|
836
836
|
"""
|
|
837
837
|
Call all the metrics on the given batch's truth and prediction output.
|
|
838
838
|
|
|
@@ -860,7 +860,7 @@ class PyTorchMLRunInterface:
|
|
|
860
860
|
average_tensor = self._hvd.allreduce(rank_value, name=name)
|
|
861
861
|
return average_tensor.item()
|
|
862
862
|
|
|
863
|
-
def _get_learning_rate(self) -> Union[
|
|
863
|
+
def _get_learning_rate(self) -> Union[tuple[str, list[Union[str, int]]], None]:
|
|
864
864
|
"""
|
|
865
865
|
Try and get the learning rate value form the stored optimizer.
|
|
866
866
|
|
|
@@ -949,8 +949,8 @@ class PyTorchMLRunInterface:
|
|
|
949
949
|
|
|
950
950
|
@staticmethod
|
|
951
951
|
def _tensor_to_cuda(
|
|
952
|
-
tensor: Union[Tensor,
|
|
953
|
-
) -> Union[Tensor,
|
|
952
|
+
tensor: Union[Tensor, dict, list, tuple],
|
|
953
|
+
) -> Union[Tensor, dict, list, tuple]:
|
|
954
954
|
"""
|
|
955
955
|
Send to given tensor to cuda if it is a tensor. If the given object is a dictionary, the dictionary values will
|
|
956
956
|
be sent to the function again recursively. If the given object is a list or a tuple, all the values in it will
|
|
@@ -997,7 +997,7 @@ class PyTorchMLRunInterface:
|
|
|
997
997
|
dataset: DataLoader,
|
|
998
998
|
iterations: int,
|
|
999
999
|
description: str,
|
|
1000
|
-
metrics:
|
|
1000
|
+
metrics: list[PyTorchTypes.MetricFunctionType],
|
|
1001
1001
|
) -> tqdm:
|
|
1002
1002
|
"""
|
|
1003
1003
|
Create a progress bar for training and validating / evaluating.
|
|
@@ -1028,8 +1028,8 @@ class PyTorchMLRunInterface:
|
|
|
1028
1028
|
@staticmethod
|
|
1029
1029
|
def _update_progress_bar(
|
|
1030
1030
|
progress_bar: tqdm,
|
|
1031
|
-
metrics:
|
|
1032
|
-
values:
|
|
1031
|
+
metrics: list[PyTorchTypes.MetricFunctionType],
|
|
1032
|
+
values: list[PyTorchTypes.MetricValueType],
|
|
1033
1033
|
):
|
|
1034
1034
|
"""
|
|
1035
1035
|
Update the progress bar metrics results.
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
15
|
import os
|
|
16
|
-
from typing import
|
|
16
|
+
from typing import Union
|
|
17
17
|
|
|
18
18
|
import numpy as np
|
|
19
19
|
import torch
|
|
@@ -50,9 +50,9 @@ class PyTorchModelHandler(DLModelHandler):
|
|
|
50
50
|
model: Module = None,
|
|
51
51
|
model_path: str = None,
|
|
52
52
|
model_name: str = None,
|
|
53
|
-
model_class: Union[
|
|
54
|
-
modules_map: Union[
|
|
55
|
-
custom_objects_map: Union[
|
|
53
|
+
model_class: Union[type[Module], str] = None,
|
|
54
|
+
modules_map: Union[dict[str, Union[None, str, list[str]]], str] = None,
|
|
55
|
+
custom_objects_map: Union[dict[str, Union[str, list[str]]], str] = None,
|
|
56
56
|
custom_objects_directory: str = None,
|
|
57
57
|
context: mlrun.MLClientCtx = None,
|
|
58
58
|
**kwargs,
|
|
@@ -136,7 +136,7 @@ class PyTorchModelHandler(DLModelHandler):
|
|
|
136
136
|
)
|
|
137
137
|
|
|
138
138
|
# Set up the base handler class:
|
|
139
|
-
super(
|
|
139
|
+
super().__init__(
|
|
140
140
|
model=model,
|
|
141
141
|
model_path=model_path,
|
|
142
142
|
model_name=model_name,
|
|
@@ -152,8 +152,8 @@ class PyTorchModelHandler(DLModelHandler):
|
|
|
152
152
|
|
|
153
153
|
def set_labels(
|
|
154
154
|
self,
|
|
155
|
-
to_add:
|
|
156
|
-
to_remove:
|
|
155
|
+
to_add: dict[str, Union[str, int, float]] = None,
|
|
156
|
+
to_remove: list[str] = None,
|
|
157
157
|
):
|
|
158
158
|
"""
|
|
159
159
|
Update the labels dictionary of this model artifact. There are required labels that cannot be edited or removed.
|
|
@@ -162,14 +162,14 @@ class PyTorchModelHandler(DLModelHandler):
|
|
|
162
162
|
:param to_remove: A list of labels keys to remove.
|
|
163
163
|
"""
|
|
164
164
|
# Update the user's labels:
|
|
165
|
-
super(
|
|
165
|
+
super().set_labels(to_add=to_add, to_remove=to_remove)
|
|
166
166
|
|
|
167
167
|
# Set the required labels:
|
|
168
168
|
self._labels[self._LabelKeys.MODEL_CLASS_NAME] = self._model_class_name
|
|
169
169
|
|
|
170
170
|
def save(
|
|
171
171
|
self, output_path: str = None, **kwargs
|
|
172
|
-
) -> Union[
|
|
172
|
+
) -> Union[dict[str, Artifact], None]:
|
|
173
173
|
"""
|
|
174
174
|
Save the handled model at the given output path.
|
|
175
175
|
|
|
@@ -182,7 +182,7 @@ class PyTorchModelHandler(DLModelHandler):
|
|
|
182
182
|
:raise MLRunInvalidArgumentError: If an output path was not given, yet a context was not provided in
|
|
183
183
|
initialization.
|
|
184
184
|
"""
|
|
185
|
-
super(
|
|
185
|
+
super().save(output_path=output_path)
|
|
186
186
|
|
|
187
187
|
# Set the output path:
|
|
188
188
|
if output_path is None:
|
|
@@ -207,7 +207,7 @@ class PyTorchModelHandler(DLModelHandler):
|
|
|
207
207
|
|
|
208
208
|
:raise MLRunInvalidArgumentError: If the model's class is not in the custom objects map.
|
|
209
209
|
"""
|
|
210
|
-
super(
|
|
210
|
+
super().load()
|
|
211
211
|
|
|
212
212
|
# Validate the model's class is in the custom objects map:
|
|
213
213
|
if (
|
|
@@ -233,10 +233,10 @@ class PyTorchModelHandler(DLModelHandler):
|
|
|
233
233
|
def to_onnx(
|
|
234
234
|
self,
|
|
235
235
|
model_name: str = None,
|
|
236
|
-
input_sample: Union[torch.Tensor,
|
|
237
|
-
input_layers_names:
|
|
238
|
-
output_layers_names:
|
|
239
|
-
dynamic_axes:
|
|
236
|
+
input_sample: Union[torch.Tensor, tuple[torch.Tensor, ...]] = None,
|
|
237
|
+
input_layers_names: list[str] = None,
|
|
238
|
+
output_layers_names: list[str] = None,
|
|
239
|
+
dynamic_axes: dict[str, dict[int, str]] = None,
|
|
240
240
|
is_batched: bool = True,
|
|
241
241
|
optimize: bool = True,
|
|
242
242
|
output_path: str = None,
|
|
@@ -406,7 +406,7 @@ class PyTorchModelHandler(DLModelHandler):
|
|
|
406
406
|
]
|
|
407
407
|
|
|
408
408
|
# Continue collecting from abstract class:
|
|
409
|
-
super(
|
|
409
|
+
super()._collect_files_from_store_object()
|
|
410
410
|
|
|
411
411
|
def _collect_files_from_local_path(self):
|
|
412
412
|
"""
|
|
@@ -443,7 +443,7 @@ class PyTorchModelHandler(DLModelHandler):
|
|
|
443
443
|
"""
|
|
444
444
|
# Supported types:
|
|
445
445
|
if isinstance(sample, np.ndarray):
|
|
446
|
-
return super(
|
|
446
|
+
return super()._read_sample(sample=sample)
|
|
447
447
|
elif isinstance(sample, torch.Tensor):
|
|
448
448
|
return Feature(
|
|
449
449
|
value_type=PyTorchUtils.convert_torch_dtype_to_value_type(
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
|
-
from typing import Any,
|
|
15
|
+
from typing import Any, Union
|
|
16
16
|
|
|
17
17
|
import numpy as np
|
|
18
18
|
import torch
|
|
@@ -39,9 +39,9 @@ class PyTorchModelServer(V2ModelServer):
|
|
|
39
39
|
model: Module = None,
|
|
40
40
|
model_path: str = None,
|
|
41
41
|
model_name: str = None,
|
|
42
|
-
model_class: Union[
|
|
43
|
-
modules_map: Union[
|
|
44
|
-
custom_objects_map: Union[
|
|
42
|
+
model_class: Union[type[Module], str] = None,
|
|
43
|
+
modules_map: Union[dict[str, Union[None, str, list[str]]], str] = None,
|
|
44
|
+
custom_objects_map: Union[dict[str, Union[str, list[str]]], str] = None,
|
|
45
45
|
custom_objects_directory: str = None,
|
|
46
46
|
use_cuda: bool = True,
|
|
47
47
|
to_list: bool = False,
|
|
@@ -106,7 +106,7 @@ class PyTorchModelServer(V2ModelServer):
|
|
|
106
106
|
:param protocol: -
|
|
107
107
|
:param class_args: -
|
|
108
108
|
"""
|
|
109
|
-
super(
|
|
109
|
+
super().__init__(
|
|
110
110
|
context=context,
|
|
111
111
|
name=name,
|
|
112
112
|
model_path=model_path,
|
|
@@ -158,7 +158,7 @@ class PyTorchModelServer(V2ModelServer):
|
|
|
158
158
|
model=self._model_handler.model, context=self.context
|
|
159
159
|
)
|
|
160
160
|
|
|
161
|
-
def predict(self, request:
|
|
161
|
+
def predict(self, request: dict[str, Any]) -> Union[Tensor, list]:
|
|
162
162
|
"""
|
|
163
163
|
Infer the inputs through the model using MLRun's PyTorch interface and return its output. The inferred data will
|
|
164
164
|
be read from the "inputs" key of the request.
|
|
@@ -183,7 +183,7 @@ class PyTorchModelServer(V2ModelServer):
|
|
|
183
183
|
# Return as list if required:
|
|
184
184
|
return predictions if not self.to_list else predictions.tolist()
|
|
185
185
|
|
|
186
|
-
def explain(self, request:
|
|
186
|
+
def explain(self, request: dict[str, Any]) -> str:
|
|
187
187
|
"""
|
|
188
188
|
Return a string explaining what model is being serve in this serving function and the function name.
|
|
189
189
|
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
#
|
|
15
15
|
# flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx
|
|
16
16
|
import warnings
|
|
17
|
-
from typing import
|
|
17
|
+
from typing import Union
|
|
18
18
|
|
|
19
19
|
import mlrun
|
|
20
20
|
from mlrun.frameworks.sklearn.metric import Metric
|
|
@@ -37,25 +37,25 @@ def apply_mlrun(
|
|
|
37
37
|
model_name: str = "model",
|
|
38
38
|
tag: str = "",
|
|
39
39
|
model_path: str = None,
|
|
40
|
-
modules_map: Union[
|
|
41
|
-
custom_objects_map: Union[
|
|
40
|
+
modules_map: Union[dict[str, Union[None, str, list[str]]], str] = None,
|
|
41
|
+
custom_objects_map: Union[dict[str, Union[str, list[str]]], str] = None,
|
|
42
42
|
custom_objects_directory: str = None,
|
|
43
43
|
context: mlrun.MLClientCtx = None,
|
|
44
|
-
artifacts: Union[
|
|
44
|
+
artifacts: Union[list[MLPlan], list[str], dict[str, dict]] = None,
|
|
45
45
|
metrics: Union[
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
46
|
+
list[Metric],
|
|
47
|
+
list[SKLearnTypes.MetricEntryType],
|
|
48
|
+
dict[str, SKLearnTypes.MetricEntryType],
|
|
49
49
|
] = None,
|
|
50
50
|
x_test: SKLearnTypes.DatasetType = None,
|
|
51
51
|
y_test: SKLearnTypes.DatasetType = None,
|
|
52
52
|
sample_set: Union[SKLearnTypes.DatasetType, mlrun.DataItem, str] = None,
|
|
53
|
-
y_columns: Union[
|
|
53
|
+
y_columns: Union[list[str], list[int]] = None,
|
|
54
54
|
feature_vector: str = None,
|
|
55
|
-
feature_weights:
|
|
56
|
-
labels:
|
|
57
|
-
parameters:
|
|
58
|
-
extra_data:
|
|
55
|
+
feature_weights: list[float] = None,
|
|
56
|
+
labels: dict[str, Union[str, int, float]] = None,
|
|
57
|
+
parameters: dict[str, Union[str, int, float]] = None,
|
|
58
|
+
extra_data: dict[str, SKLearnTypes.ExtraDataType] = None,
|
|
59
59
|
auto_log: bool = True,
|
|
60
60
|
**kwargs,
|
|
61
61
|
) -> SKLearnModelHandler:
|
|
@@ -92,7 +92,7 @@ def apply_mlrun(
|
|
|
92
92
|
|
|
93
93
|
{
|
|
94
94
|
"/.../custom_model.py": "MyModel",
|
|
95
|
-
"/.../custom_objects.py": ["object1", "object2"]
|
|
95
|
+
"/.../custom_objects.py": ["object1", "object2"],
|
|
96
96
|
}
|
|
97
97
|
|
|
98
98
|
All the paths will be accessed from the given 'custom_objects_directory', meaning
|