mlrun 1.6.4rc7__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mlrun might be problematic. Click here for more details.
- mlrun/__init__.py +11 -1
- mlrun/__main__.py +40 -122
- mlrun/alerts/__init__.py +15 -0
- mlrun/alerts/alert.py +248 -0
- mlrun/api/schemas/__init__.py +5 -4
- mlrun/artifacts/__init__.py +8 -3
- mlrun/artifacts/base.py +47 -257
- mlrun/artifacts/dataset.py +11 -192
- mlrun/artifacts/manager.py +79 -47
- mlrun/artifacts/model.py +31 -159
- mlrun/artifacts/plots.py +23 -380
- mlrun/common/constants.py +74 -1
- mlrun/common/db/sql_session.py +5 -5
- mlrun/common/formatters/__init__.py +21 -0
- mlrun/common/formatters/artifact.py +45 -0
- mlrun/common/formatters/base.py +113 -0
- mlrun/common/formatters/feature_set.py +33 -0
- mlrun/common/formatters/function.py +46 -0
- mlrun/common/formatters/pipeline.py +53 -0
- mlrun/common/formatters/project.py +51 -0
- mlrun/common/formatters/run.py +29 -0
- mlrun/common/helpers.py +12 -3
- mlrun/common/model_monitoring/helpers.py +9 -5
- mlrun/{runtimes → common/runtimes}/constants.py +37 -9
- mlrun/common/schemas/__init__.py +31 -5
- mlrun/common/schemas/alert.py +202 -0
- mlrun/common/schemas/api_gateway.py +196 -0
- mlrun/common/schemas/artifact.py +25 -4
- mlrun/common/schemas/auth.py +16 -5
- mlrun/common/schemas/background_task.py +1 -1
- mlrun/common/schemas/client_spec.py +4 -2
- mlrun/common/schemas/common.py +7 -4
- mlrun/common/schemas/constants.py +3 -0
- mlrun/common/schemas/feature_store.py +74 -44
- mlrun/common/schemas/frontend_spec.py +15 -7
- mlrun/common/schemas/function.py +12 -1
- mlrun/common/schemas/hub.py +11 -18
- mlrun/common/schemas/memory_reports.py +2 -2
- mlrun/common/schemas/model_monitoring/__init__.py +20 -4
- mlrun/common/schemas/model_monitoring/constants.py +123 -42
- mlrun/common/schemas/model_monitoring/grafana.py +13 -9
- mlrun/common/schemas/model_monitoring/model_endpoints.py +101 -54
- mlrun/common/schemas/notification.py +71 -14
- mlrun/common/schemas/object.py +2 -2
- mlrun/{model_monitoring/controller_handler.py → common/schemas/pagination.py} +9 -12
- mlrun/common/schemas/pipeline.py +8 -1
- mlrun/common/schemas/project.py +69 -18
- mlrun/common/schemas/runs.py +7 -1
- mlrun/common/schemas/runtime_resource.py +8 -12
- mlrun/common/schemas/schedule.py +4 -4
- mlrun/common/schemas/tag.py +1 -2
- mlrun/common/schemas/workflow.py +12 -4
- mlrun/common/types.py +14 -1
- mlrun/config.py +154 -69
- mlrun/data_types/data_types.py +6 -1
- mlrun/data_types/spark.py +2 -2
- mlrun/data_types/to_pandas.py +67 -37
- mlrun/datastore/__init__.py +6 -8
- mlrun/datastore/alibaba_oss.py +131 -0
- mlrun/datastore/azure_blob.py +143 -42
- mlrun/datastore/base.py +102 -58
- mlrun/datastore/datastore.py +34 -13
- mlrun/datastore/datastore_profile.py +146 -20
- mlrun/datastore/dbfs_store.py +3 -7
- mlrun/datastore/filestore.py +1 -4
- mlrun/datastore/google_cloud_storage.py +97 -33
- mlrun/datastore/hdfs.py +56 -0
- mlrun/datastore/inmem.py +6 -3
- mlrun/datastore/redis.py +7 -2
- mlrun/datastore/s3.py +34 -12
- mlrun/datastore/snowflake_utils.py +45 -0
- mlrun/datastore/sources.py +303 -111
- mlrun/datastore/spark_utils.py +31 -2
- mlrun/datastore/store_resources.py +9 -7
- mlrun/datastore/storeytargets.py +151 -0
- mlrun/datastore/targets.py +453 -176
- mlrun/datastore/utils.py +72 -58
- mlrun/datastore/v3io.py +6 -1
- mlrun/db/base.py +274 -41
- mlrun/db/factory.py +1 -1
- mlrun/db/httpdb.py +893 -225
- mlrun/db/nopdb.py +291 -33
- mlrun/errors.py +36 -6
- mlrun/execution.py +115 -42
- mlrun/feature_store/__init__.py +0 -2
- mlrun/feature_store/api.py +65 -73
- mlrun/feature_store/common.py +7 -12
- mlrun/feature_store/feature_set.py +76 -55
- mlrun/feature_store/feature_vector.py +39 -31
- mlrun/feature_store/ingestion.py +7 -6
- mlrun/feature_store/retrieval/base.py +16 -11
- mlrun/feature_store/retrieval/dask_merger.py +2 -0
- mlrun/feature_store/retrieval/job.py +13 -4
- mlrun/feature_store/retrieval/local_merger.py +2 -0
- mlrun/feature_store/retrieval/spark_merger.py +24 -32
- mlrun/feature_store/steps.py +45 -34
- mlrun/features.py +11 -21
- mlrun/frameworks/_common/artifacts_library.py +9 -9
- mlrun/frameworks/_common/mlrun_interface.py +5 -5
- mlrun/frameworks/_common/model_handler.py +48 -48
- mlrun/frameworks/_common/plan.py +5 -6
- mlrun/frameworks/_common/producer.py +3 -4
- mlrun/frameworks/_common/utils.py +5 -5
- mlrun/frameworks/_dl_common/loggers/logger.py +6 -7
- mlrun/frameworks/_dl_common/loggers/mlrun_logger.py +9 -9
- mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py +23 -47
- mlrun/frameworks/_ml_common/artifacts_library.py +1 -2
- mlrun/frameworks/_ml_common/loggers/logger.py +3 -4
- mlrun/frameworks/_ml_common/loggers/mlrun_logger.py +4 -5
- mlrun/frameworks/_ml_common/model_handler.py +24 -24
- mlrun/frameworks/_ml_common/pkl_model_server.py +2 -2
- mlrun/frameworks/_ml_common/plan.py +2 -2
- mlrun/frameworks/_ml_common/plans/calibration_curve_plan.py +2 -3
- mlrun/frameworks/_ml_common/plans/confusion_matrix_plan.py +2 -3
- mlrun/frameworks/_ml_common/plans/dataset_plan.py +3 -3
- mlrun/frameworks/_ml_common/plans/feature_importance_plan.py +3 -3
- mlrun/frameworks/_ml_common/plans/roc_curve_plan.py +4 -4
- mlrun/frameworks/_ml_common/utils.py +4 -4
- mlrun/frameworks/auto_mlrun/auto_mlrun.py +9 -9
- mlrun/frameworks/huggingface/model_server.py +4 -4
- mlrun/frameworks/lgbm/__init__.py +33 -33
- mlrun/frameworks/lgbm/callbacks/callback.py +2 -4
- mlrun/frameworks/lgbm/callbacks/logging_callback.py +4 -5
- mlrun/frameworks/lgbm/callbacks/mlrun_logging_callback.py +4 -5
- mlrun/frameworks/lgbm/mlrun_interfaces/booster_mlrun_interface.py +1 -3
- mlrun/frameworks/lgbm/mlrun_interfaces/mlrun_interface.py +6 -6
- mlrun/frameworks/lgbm/model_handler.py +10 -10
- mlrun/frameworks/lgbm/model_server.py +6 -6
- mlrun/frameworks/lgbm/utils.py +5 -5
- mlrun/frameworks/onnx/dataset.py +8 -8
- mlrun/frameworks/onnx/mlrun_interface.py +3 -3
- mlrun/frameworks/onnx/model_handler.py +6 -6
- mlrun/frameworks/onnx/model_server.py +7 -7
- mlrun/frameworks/parallel_coordinates.py +6 -6
- mlrun/frameworks/pytorch/__init__.py +18 -18
- mlrun/frameworks/pytorch/callbacks/callback.py +4 -5
- mlrun/frameworks/pytorch/callbacks/logging_callback.py +17 -17
- mlrun/frameworks/pytorch/callbacks/mlrun_logging_callback.py +11 -11
- mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py +23 -29
- mlrun/frameworks/pytorch/callbacks_handler.py +38 -38
- mlrun/frameworks/pytorch/mlrun_interface.py +20 -20
- mlrun/frameworks/pytorch/model_handler.py +17 -17
- mlrun/frameworks/pytorch/model_server.py +7 -7
- mlrun/frameworks/sklearn/__init__.py +13 -13
- mlrun/frameworks/sklearn/estimator.py +4 -4
- mlrun/frameworks/sklearn/metrics_library.py +14 -14
- mlrun/frameworks/sklearn/mlrun_interface.py +16 -9
- mlrun/frameworks/sklearn/model_handler.py +2 -2
- mlrun/frameworks/tf_keras/__init__.py +10 -7
- mlrun/frameworks/tf_keras/callbacks/logging_callback.py +15 -15
- mlrun/frameworks/tf_keras/callbacks/mlrun_logging_callback.py +11 -11
- mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py +19 -23
- mlrun/frameworks/tf_keras/mlrun_interface.py +9 -11
- mlrun/frameworks/tf_keras/model_handler.py +14 -14
- mlrun/frameworks/tf_keras/model_server.py +6 -6
- mlrun/frameworks/xgboost/__init__.py +13 -13
- mlrun/frameworks/xgboost/model_handler.py +6 -6
- mlrun/k8s_utils.py +61 -17
- mlrun/launcher/__init__.py +1 -1
- mlrun/launcher/base.py +16 -15
- mlrun/launcher/client.py +13 -11
- mlrun/launcher/factory.py +1 -1
- mlrun/launcher/local.py +23 -13
- mlrun/launcher/remote.py +17 -10
- mlrun/lists.py +7 -6
- mlrun/model.py +478 -103
- mlrun/model_monitoring/__init__.py +1 -1
- mlrun/model_monitoring/api.py +163 -371
- mlrun/{runtimes/mpijob/v1alpha1.py → model_monitoring/applications/__init__.py} +9 -15
- mlrun/model_monitoring/applications/_application_steps.py +188 -0
- mlrun/model_monitoring/applications/base.py +108 -0
- mlrun/model_monitoring/applications/context.py +341 -0
- mlrun/model_monitoring/{evidently_application.py → applications/evidently_base.py} +27 -22
- mlrun/model_monitoring/applications/histogram_data_drift.py +354 -0
- mlrun/model_monitoring/applications/results.py +99 -0
- mlrun/model_monitoring/controller.py +131 -278
- mlrun/model_monitoring/db/__init__.py +18 -0
- mlrun/model_monitoring/db/stores/__init__.py +136 -0
- mlrun/model_monitoring/db/stores/base/__init__.py +15 -0
- mlrun/model_monitoring/db/stores/base/store.py +213 -0
- mlrun/model_monitoring/db/stores/sqldb/__init__.py +13 -0
- mlrun/model_monitoring/db/stores/sqldb/models/__init__.py +71 -0
- mlrun/model_monitoring/db/stores/sqldb/models/base.py +190 -0
- mlrun/model_monitoring/db/stores/sqldb/models/mysql.py +103 -0
- mlrun/model_monitoring/{stores/models/mysql.py → db/stores/sqldb/models/sqlite.py} +19 -13
- mlrun/model_monitoring/db/stores/sqldb/sql_store.py +659 -0
- mlrun/model_monitoring/db/stores/v3io_kv/__init__.py +13 -0
- mlrun/model_monitoring/db/stores/v3io_kv/kv_store.py +726 -0
- mlrun/model_monitoring/db/tsdb/__init__.py +105 -0
- mlrun/model_monitoring/db/tsdb/base.py +448 -0
- mlrun/model_monitoring/db/tsdb/helpers.py +30 -0
- mlrun/model_monitoring/db/tsdb/tdengine/__init__.py +15 -0
- mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +279 -0
- mlrun/model_monitoring/db/tsdb/tdengine/stream_graph_steps.py +42 -0
- mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +507 -0
- mlrun/model_monitoring/db/tsdb/v3io/__init__.py +15 -0
- mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +158 -0
- mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +849 -0
- mlrun/model_monitoring/features_drift_table.py +134 -106
- mlrun/model_monitoring/helpers.py +199 -55
- mlrun/model_monitoring/metrics/__init__.py +13 -0
- mlrun/model_monitoring/metrics/histogram_distance.py +127 -0
- mlrun/model_monitoring/model_endpoint.py +3 -2
- mlrun/model_monitoring/stream_processing.py +131 -398
- mlrun/model_monitoring/tracking_policy.py +9 -2
- mlrun/model_monitoring/writer.py +161 -125
- mlrun/package/__init__.py +6 -6
- mlrun/package/context_handler.py +5 -5
- mlrun/package/packager.py +7 -7
- mlrun/package/packagers/default_packager.py +8 -8
- mlrun/package/packagers/numpy_packagers.py +15 -15
- mlrun/package/packagers/pandas_packagers.py +5 -5
- mlrun/package/packagers/python_standard_library_packagers.py +10 -10
- mlrun/package/packagers_manager.py +19 -23
- mlrun/package/utils/_formatter.py +6 -6
- mlrun/package/utils/_pickler.py +2 -2
- mlrun/package/utils/_supported_format.py +4 -4
- mlrun/package/utils/log_hint_utils.py +2 -2
- mlrun/package/utils/type_hint_utils.py +4 -9
- mlrun/platforms/__init__.py +11 -10
- mlrun/platforms/iguazio.py +24 -203
- mlrun/projects/operations.py +52 -25
- mlrun/projects/pipelines.py +191 -197
- mlrun/projects/project.py +1227 -400
- mlrun/render.py +16 -19
- mlrun/run.py +209 -184
- mlrun/runtimes/__init__.py +83 -15
- mlrun/runtimes/base.py +51 -35
- mlrun/runtimes/daskjob.py +17 -10
- mlrun/runtimes/databricks_job/databricks_cancel_task.py +1 -1
- mlrun/runtimes/databricks_job/databricks_runtime.py +8 -7
- mlrun/runtimes/databricks_job/databricks_wrapper.py +1 -1
- mlrun/runtimes/funcdoc.py +1 -29
- mlrun/runtimes/function_reference.py +1 -1
- mlrun/runtimes/kubejob.py +34 -128
- mlrun/runtimes/local.py +40 -11
- mlrun/runtimes/mpijob/__init__.py +0 -20
- mlrun/runtimes/mpijob/abstract.py +9 -10
- mlrun/runtimes/mpijob/v1.py +1 -1
- mlrun/{model_monitoring/stores/models/sqlite.py → runtimes/nuclio/__init__.py} +7 -9
- mlrun/runtimes/nuclio/api_gateway.py +769 -0
- mlrun/runtimes/nuclio/application/__init__.py +15 -0
- mlrun/runtimes/nuclio/application/application.py +758 -0
- mlrun/runtimes/nuclio/application/reverse_proxy.go +95 -0
- mlrun/runtimes/{function.py → nuclio/function.py} +200 -83
- mlrun/runtimes/{nuclio.py → nuclio/nuclio.py} +6 -6
- mlrun/runtimes/{serving.py → nuclio/serving.py} +65 -68
- mlrun/runtimes/pod.py +281 -101
- mlrun/runtimes/remotesparkjob.py +12 -9
- mlrun/runtimes/sparkjob/spark3job.py +67 -51
- mlrun/runtimes/utils.py +41 -75
- mlrun/secrets.py +9 -5
- mlrun/serving/__init__.py +8 -1
- mlrun/serving/remote.py +2 -7
- mlrun/serving/routers.py +85 -69
- mlrun/serving/server.py +69 -44
- mlrun/serving/states.py +209 -36
- mlrun/serving/utils.py +22 -14
- mlrun/serving/v1_serving.py +6 -7
- mlrun/serving/v2_serving.py +129 -54
- mlrun/track/tracker.py +2 -1
- mlrun/track/tracker_manager.py +3 -3
- mlrun/track/trackers/mlflow_tracker.py +6 -2
- mlrun/utils/async_http.py +6 -8
- mlrun/utils/azure_vault.py +1 -1
- mlrun/utils/clones.py +1 -2
- mlrun/utils/condition_evaluator.py +3 -3
- mlrun/utils/db.py +21 -3
- mlrun/utils/helpers.py +405 -225
- mlrun/utils/http.py +3 -6
- mlrun/utils/logger.py +112 -16
- mlrun/utils/notifications/notification/__init__.py +17 -13
- mlrun/utils/notifications/notification/base.py +50 -2
- mlrun/utils/notifications/notification/console.py +2 -0
- mlrun/utils/notifications/notification/git.py +24 -1
- mlrun/utils/notifications/notification/ipython.py +3 -1
- mlrun/utils/notifications/notification/slack.py +96 -21
- mlrun/utils/notifications/notification/webhook.py +59 -2
- mlrun/utils/notifications/notification_pusher.py +149 -30
- mlrun/utils/regex.py +9 -0
- mlrun/utils/retryer.py +208 -0
- mlrun/utils/singleton.py +1 -1
- mlrun/utils/v3io_clients.py +4 -6
- mlrun/utils/version/version.json +2 -2
- mlrun/utils/version/version.py +2 -6
- mlrun-1.7.0.dist-info/METADATA +378 -0
- mlrun-1.7.0.dist-info/RECORD +351 -0
- {mlrun-1.6.4rc7.dist-info → mlrun-1.7.0.dist-info}/WHEEL +1 -1
- mlrun/feature_store/retrieval/conversion.py +0 -273
- mlrun/kfpops.py +0 -868
- mlrun/model_monitoring/application.py +0 -310
- mlrun/model_monitoring/batch.py +0 -1095
- mlrun/model_monitoring/prometheus.py +0 -219
- mlrun/model_monitoring/stores/__init__.py +0 -111
- mlrun/model_monitoring/stores/kv_model_endpoint_store.py +0 -576
- mlrun/model_monitoring/stores/model_endpoint_store.py +0 -147
- mlrun/model_monitoring/stores/models/__init__.py +0 -27
- mlrun/model_monitoring/stores/models/base.py +0 -84
- mlrun/model_monitoring/stores/sql_model_endpoint_store.py +0 -384
- mlrun/platforms/other.py +0 -306
- mlrun-1.6.4rc7.dist-info/METADATA +0 -272
- mlrun-1.6.4rc7.dist-info/RECORD +0 -314
- {mlrun-1.6.4rc7.dist-info → mlrun-1.7.0.dist-info}/LICENSE +0 -0
- {mlrun-1.6.4rc7.dist-info → mlrun-1.7.0.dist-info}/entry_points.txt +0 -0
- {mlrun-1.6.4rc7.dist-info → mlrun-1.7.0.dist-info}/top_level.txt +0 -0
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
from typing import Any
|
|
15
|
+
from typing import Any
|
|
16
16
|
|
|
17
17
|
import numpy as np
|
|
18
18
|
import transformers
|
|
@@ -65,7 +65,7 @@ class HuggingFaceModelServer(V2ModelServer):
|
|
|
65
65
|
framework of the `model`, or to PyTorch if no model is provided
|
|
66
66
|
:param class_args: -
|
|
67
67
|
"""
|
|
68
|
-
super(
|
|
68
|
+
super().__init__(
|
|
69
69
|
context=context,
|
|
70
70
|
name=name,
|
|
71
71
|
model_path=model_path,
|
|
@@ -104,7 +104,7 @@ class HuggingFaceModelServer(V2ModelServer):
|
|
|
104
104
|
framework=self.framework,
|
|
105
105
|
)
|
|
106
106
|
|
|
107
|
-
def predict(self, request:
|
|
107
|
+
def predict(self, request: dict[str, Any]) -> list:
|
|
108
108
|
"""
|
|
109
109
|
Generate model predictions from sample.
|
|
110
110
|
:param request: The request to the model. The input to the model will be read from the "inputs" key.
|
|
@@ -135,7 +135,7 @@ class HuggingFaceModelServer(V2ModelServer):
|
|
|
135
135
|
|
|
136
136
|
return result
|
|
137
137
|
|
|
138
|
-
def explain(self, request:
|
|
138
|
+
def explain(self, request: dict) -> str:
|
|
139
139
|
"""
|
|
140
140
|
Return a string explaining what model is being served in this serving function and the function name.
|
|
141
141
|
:param request: A given request.
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
15
|
# flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx
|
|
16
|
-
from typing import Any,
|
|
16
|
+
from typing import Any, Union
|
|
17
17
|
|
|
18
18
|
import lightgbm as lgb
|
|
19
19
|
|
|
@@ -37,20 +37,20 @@ LGBMArtifactsLibrary = MLArtifactsLibrary
|
|
|
37
37
|
def _apply_mlrun_on_module(
|
|
38
38
|
model_name: str = "model",
|
|
39
39
|
tag: str = "",
|
|
40
|
-
modules_map: Union[
|
|
41
|
-
custom_objects_map: Union[
|
|
40
|
+
modules_map: Union[dict[str, Union[None, str, list[str]]], str] = None,
|
|
41
|
+
custom_objects_map: Union[dict[str, Union[str, list[str]]], str] = None,
|
|
42
42
|
custom_objects_directory: str = None,
|
|
43
43
|
context: mlrun.MLClientCtx = None,
|
|
44
44
|
model_format: str = LGBMModelHandler.ModelFormats.PKL,
|
|
45
45
|
sample_set: Union[LGBMTypes.DatasetType, mlrun.DataItem, str] = None,
|
|
46
|
-
y_columns: Union[
|
|
46
|
+
y_columns: Union[list[str], list[int]] = None,
|
|
47
47
|
feature_vector: str = None,
|
|
48
|
-
feature_weights:
|
|
49
|
-
labels:
|
|
50
|
-
parameters:
|
|
51
|
-
extra_data:
|
|
48
|
+
feature_weights: list[float] = None,
|
|
49
|
+
labels: dict[str, Union[str, int, float]] = None,
|
|
50
|
+
parameters: dict[str, Union[str, int, float]] = None,
|
|
51
|
+
extra_data: dict[str, LGBMTypes.ExtraDataType] = None,
|
|
52
52
|
auto_log: bool = True,
|
|
53
|
-
mlrun_logging_callback_kwargs:
|
|
53
|
+
mlrun_logging_callback_kwargs: dict[str, Any] = None,
|
|
54
54
|
):
|
|
55
55
|
# Apply MLRun's interface on the LightGBM module:
|
|
56
56
|
LGBMMLRunInterface.add_interface(obj=lgb)
|
|
@@ -85,26 +85,26 @@ def _apply_mlrun_on_model(
|
|
|
85
85
|
model_name: str = "model",
|
|
86
86
|
tag: str = "",
|
|
87
87
|
model_path: str = None,
|
|
88
|
-
modules_map: Union[
|
|
89
|
-
custom_objects_map: Union[
|
|
88
|
+
modules_map: Union[dict[str, Union[None, str, list[str]]], str] = None,
|
|
89
|
+
custom_objects_map: Union[dict[str, Union[str, list[str]]], str] = None,
|
|
90
90
|
custom_objects_directory: str = None,
|
|
91
91
|
context: mlrun.MLClientCtx = None,
|
|
92
92
|
model_format: str = LGBMModelHandler.ModelFormats.PKL,
|
|
93
|
-
artifacts: Union[
|
|
93
|
+
artifacts: Union[list[MLPlan], list[str], dict[str, dict]] = None,
|
|
94
94
|
metrics: Union[
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
95
|
+
list[Metric],
|
|
96
|
+
list[LGBMTypes.MetricEntryType],
|
|
97
|
+
dict[str, LGBMTypes.MetricEntryType],
|
|
98
98
|
] = None,
|
|
99
99
|
x_test: LGBMTypes.DatasetType = None,
|
|
100
100
|
y_test: LGBMTypes.DatasetType = None,
|
|
101
101
|
sample_set: Union[LGBMTypes.DatasetType, mlrun.DataItem, str] = None,
|
|
102
|
-
y_columns: Union[
|
|
102
|
+
y_columns: Union[list[str], list[int]] = None,
|
|
103
103
|
feature_vector: str = None,
|
|
104
|
-
feature_weights:
|
|
105
|
-
labels:
|
|
106
|
-
parameters:
|
|
107
|
-
extra_data:
|
|
104
|
+
feature_weights: list[float] = None,
|
|
105
|
+
labels: dict[str, Union[str, int, float]] = None,
|
|
106
|
+
parameters: dict[str, Union[str, int, float]] = None,
|
|
107
|
+
extra_data: dict[str, LGBMTypes.ExtraDataType] = None,
|
|
108
108
|
auto_log: bool = True,
|
|
109
109
|
**kwargs,
|
|
110
110
|
):
|
|
@@ -183,28 +183,28 @@ def apply_mlrun(
|
|
|
183
183
|
model_name: str = "model",
|
|
184
184
|
tag: str = "",
|
|
185
185
|
model_path: str = None,
|
|
186
|
-
modules_map: Union[
|
|
187
|
-
custom_objects_map: Union[
|
|
186
|
+
modules_map: Union[dict[str, Union[None, str, list[str]]], str] = None,
|
|
187
|
+
custom_objects_map: Union[dict[str, Union[str, list[str]]], str] = None,
|
|
188
188
|
custom_objects_directory: str = None,
|
|
189
189
|
context: mlrun.MLClientCtx = None,
|
|
190
190
|
model_format: str = LGBMModelHandler.ModelFormats.PKL,
|
|
191
|
-
artifacts: Union[
|
|
191
|
+
artifacts: Union[list[MLPlan], list[str], dict[str, dict]] = None,
|
|
192
192
|
metrics: Union[
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
193
|
+
list[Metric],
|
|
194
|
+
list[LGBMTypes.MetricEntryType],
|
|
195
|
+
dict[str, LGBMTypes.MetricEntryType],
|
|
196
196
|
] = None,
|
|
197
197
|
x_test: LGBMTypes.DatasetType = None,
|
|
198
198
|
y_test: LGBMTypes.DatasetType = None,
|
|
199
199
|
sample_set: Union[LGBMTypes.DatasetType, mlrun.DataItem, str] = None,
|
|
200
|
-
y_columns: Union[
|
|
200
|
+
y_columns: Union[list[str], list[int]] = None,
|
|
201
201
|
feature_vector: str = None,
|
|
202
|
-
feature_weights:
|
|
203
|
-
labels:
|
|
204
|
-
parameters:
|
|
205
|
-
extra_data:
|
|
202
|
+
feature_weights: list[float] = None,
|
|
203
|
+
labels: dict[str, Union[str, int, float]] = None,
|
|
204
|
+
parameters: dict[str, Union[str, int, float]] = None,
|
|
205
|
+
extra_data: dict[str, LGBMTypes.ExtraDataType] = None,
|
|
206
206
|
auto_log: bool = True,
|
|
207
|
-
mlrun_logging_callback_kwargs:
|
|
207
|
+
mlrun_logging_callback_kwargs: dict[str, Any] = None,
|
|
208
208
|
**kwargs,
|
|
209
209
|
) -> Union[LGBMModelHandler, None]:
|
|
210
210
|
"""
|
|
@@ -241,7 +241,7 @@ def apply_mlrun(
|
|
|
241
241
|
|
|
242
242
|
{
|
|
243
243
|
"/.../custom_model.py": "MyModel",
|
|
244
|
-
"/.../custom_objects.py": ["object1", "object2"]
|
|
244
|
+
"/.../custom_objects.py": ["object1", "object2"],
|
|
245
245
|
}
|
|
246
246
|
|
|
247
247
|
All the paths will be accessed from the given 'custom_objects_directory', meaning
|
|
@@ -63,11 +63,9 @@ class Callback(ABC):
|
|
|
63
63
|
def on_train_end(self):
|
|
64
64
|
print("{self.name}: Done training!")
|
|
65
65
|
|
|
66
|
+
|
|
66
67
|
apply_mlrun()
|
|
67
|
-
lgb.train(
|
|
68
|
-
...,
|
|
69
|
-
callbacks=[ExampleCallback(name="Example")]
|
|
70
|
-
)
|
|
68
|
+
lgb.train(..., callbacks=[ExampleCallback(name="Example")])
|
|
71
69
|
"""
|
|
72
70
|
|
|
73
71
|
def __init__(self, order: int = 10, before_iteration: bool = False):
|
|
@@ -12,7 +12,6 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
|
-
from typing import List
|
|
16
15
|
|
|
17
16
|
from ..._ml_common.loggers import Logger
|
|
18
17
|
from ..utils import LGBMTypes
|
|
@@ -26,8 +25,8 @@ class LoggingCallback(Callback):
|
|
|
26
25
|
|
|
27
26
|
def __init__(
|
|
28
27
|
self,
|
|
29
|
-
dynamic_hyperparameters:
|
|
30
|
-
static_hyperparameters:
|
|
28
|
+
dynamic_hyperparameters: list[str] = None,
|
|
29
|
+
static_hyperparameters: list[str] = None,
|
|
31
30
|
):
|
|
32
31
|
"""
|
|
33
32
|
Initialize the logging callback with the given configuration. All the metrics data will be collected but the
|
|
@@ -41,7 +40,7 @@ class LoggingCallback(Callback):
|
|
|
41
40
|
The parameter expects a list of all the hyperparameters names to track our of
|
|
42
41
|
the `params` dictionary.
|
|
43
42
|
"""
|
|
44
|
-
super(
|
|
43
|
+
super().__init__()
|
|
45
44
|
self._logger = Logger()
|
|
46
45
|
self._dynamic_hyperparameters_keys = (
|
|
47
46
|
dynamic_hyperparameters if dynamic_hyperparameters is not None else {}
|
|
@@ -76,7 +75,7 @@ class LoggingCallback(Callback):
|
|
|
76
75
|
self._log_hyperparameters(parameters=env.params)
|
|
77
76
|
|
|
78
77
|
def _log_results(
|
|
79
|
-
self, evaluation_result_list:
|
|
78
|
+
self, evaluation_result_list: list[LGBMTypes.EvaluationResultType]
|
|
80
79
|
):
|
|
81
80
|
"""
|
|
82
81
|
Log the callback environment results data into the logger.
|
|
@@ -12,7 +12,6 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
|
-
from typing import List
|
|
16
15
|
|
|
17
16
|
import mlrun
|
|
18
17
|
|
|
@@ -34,8 +33,8 @@ class MLRunLoggingCallback(LoggingCallback):
|
|
|
34
33
|
def __init__(
|
|
35
34
|
self,
|
|
36
35
|
context: mlrun.MLClientCtx,
|
|
37
|
-
dynamic_hyperparameters:
|
|
38
|
-
static_hyperparameters:
|
|
36
|
+
dynamic_hyperparameters: list[str] = None,
|
|
37
|
+
static_hyperparameters: list[str] = None,
|
|
39
38
|
logging_frequency: int = 100,
|
|
40
39
|
):
|
|
41
40
|
"""
|
|
@@ -55,7 +54,7 @@ class MLRunLoggingCallback(LoggingCallback):
|
|
|
55
54
|
them and the results to MLRun). Two low frequency may slow the training time.
|
|
56
55
|
Default: 100.
|
|
57
56
|
"""
|
|
58
|
-
super(
|
|
57
|
+
super().__init__(
|
|
59
58
|
dynamic_hyperparameters=dynamic_hyperparameters,
|
|
60
59
|
static_hyperparameters=static_hyperparameters,
|
|
61
60
|
)
|
|
@@ -75,7 +74,7 @@ class MLRunLoggingCallback(LoggingCallback):
|
|
|
75
74
|
information check the `Callback` doc string.
|
|
76
75
|
"""
|
|
77
76
|
# Log the results and parameters:
|
|
78
|
-
super(
|
|
77
|
+
super().__call__(env=env)
|
|
79
78
|
|
|
80
79
|
# Produce the artifacts (post iteration stage):
|
|
81
80
|
if env.iteration % self._logging_frequency == 0:
|
|
@@ -43,6 +43,4 @@ class LGBMBoosterMLRunInterface(MLRunInterface, ABC):
|
|
|
43
43
|
:param restoration: Restoration information tuple as returned from 'remove_interface' in order to add the
|
|
44
44
|
interface in a certain state.
|
|
45
45
|
"""
|
|
46
|
-
super(
|
|
47
|
-
obj=obj, restoration=restoration
|
|
48
|
-
)
|
|
46
|
+
super().add_interface(obj=obj, restoration=restoration)
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
#
|
|
15
15
|
from abc import ABC
|
|
16
16
|
from types import ModuleType
|
|
17
|
-
from typing import Callable,
|
|
17
|
+
from typing import Callable, Union
|
|
18
18
|
|
|
19
19
|
import lightgbm as lgb
|
|
20
20
|
|
|
@@ -88,7 +88,7 @@ class LGBMMLRunInterface(MLRunInterface, ABC):
|
|
|
88
88
|
globals().update({"lightgbm": lgb, "lgb": lgb})
|
|
89
89
|
|
|
90
90
|
# Add the interface to the provided lightgbm module:
|
|
91
|
-
super(
|
|
91
|
+
super().add_interface(obj=obj, restoration=restoration)
|
|
92
92
|
|
|
93
93
|
@staticmethod
|
|
94
94
|
def mlrun_train(*args, **kwargs):
|
|
@@ -223,7 +223,7 @@ class LGBMMLRunInterface(MLRunInterface, ABC):
|
|
|
223
223
|
pass
|
|
224
224
|
|
|
225
225
|
@staticmethod
|
|
226
|
-
def _parse_callbacks(callbacks:
|
|
226
|
+
def _parse_callbacks(callbacks: list[Callable]):
|
|
227
227
|
"""
|
|
228
228
|
Parse the callbacks passed to the training API functions of LightGBM for adding logging and enabling the MLRun
|
|
229
229
|
callbacks API.
|
|
@@ -259,9 +259,9 @@ class LGBMMLRunInterface(MLRunInterface, ABC):
|
|
|
259
259
|
@staticmethod
|
|
260
260
|
def _post_train(
|
|
261
261
|
booster: lgb.Booster,
|
|
262
|
-
train_set:
|
|
263
|
-
validation_sets:
|
|
264
|
-
|
|
262
|
+
train_set: tuple[MLTypes.DatasetType, Union[MLTypes.DatasetType, None]],
|
|
263
|
+
validation_sets: list[
|
|
264
|
+
tuple[tuple[MLTypes.DatasetType, Union[MLTypes.DatasetType, None]], str]
|
|
265
265
|
],
|
|
266
266
|
):
|
|
267
267
|
"""
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
#
|
|
15
15
|
import os
|
|
16
16
|
import pickle
|
|
17
|
-
from typing import
|
|
17
|
+
from typing import Union
|
|
18
18
|
|
|
19
19
|
import cloudpickle
|
|
20
20
|
import lightgbm as lgb
|
|
@@ -56,8 +56,8 @@ class LGBMModelHandler(MLModelHandler):
|
|
|
56
56
|
model_name: str = None,
|
|
57
57
|
model_path: str = None,
|
|
58
58
|
model: LGBMTypes.ModelType = None,
|
|
59
|
-
modules_map: Union[
|
|
60
|
-
custom_objects_map: Union[
|
|
59
|
+
modules_map: Union[dict[str, Union[None, str, list[str]]], str] = None,
|
|
60
|
+
custom_objects_map: Union[dict[str, Union[str, list[str]]], str] = None,
|
|
61
61
|
custom_objects_directory: str = None,
|
|
62
62
|
context: mlrun.MLClientCtx = None,
|
|
63
63
|
model_format: str = ModelFormats.PKL,
|
|
@@ -103,7 +103,7 @@ class LGBMModelHandler(MLModelHandler):
|
|
|
103
103
|
|
|
104
104
|
{
|
|
105
105
|
"/.../custom_model.py": "MyModel",
|
|
106
|
-
"/.../custom_objects.py": ["object1", "object2"]
|
|
106
|
+
"/.../custom_objects.py": ["object1", "object2"],
|
|
107
107
|
}
|
|
108
108
|
|
|
109
109
|
All the paths will be accessed from the given 'custom_objects_directory',
|
|
@@ -139,7 +139,7 @@ class LGBMModelHandler(MLModelHandler):
|
|
|
139
139
|
self._model_format = model_format
|
|
140
140
|
|
|
141
141
|
# Set up the base handler class:
|
|
142
|
-
super(
|
|
142
|
+
super().__init__(
|
|
143
143
|
model=model,
|
|
144
144
|
model_path=model_path,
|
|
145
145
|
model_name=model_name,
|
|
@@ -152,8 +152,8 @@ class LGBMModelHandler(MLModelHandler):
|
|
|
152
152
|
|
|
153
153
|
def set_labels(
|
|
154
154
|
self,
|
|
155
|
-
to_add:
|
|
156
|
-
to_remove:
|
|
155
|
+
to_add: dict[str, Union[str, int, float]] = None,
|
|
156
|
+
to_remove: list[str] = None,
|
|
157
157
|
):
|
|
158
158
|
"""
|
|
159
159
|
Update the labels dictionary of this model artifact. There are required labels that cannot be edited or removed.
|
|
@@ -162,7 +162,7 @@ class LGBMModelHandler(MLModelHandler):
|
|
|
162
162
|
:param to_remove: A list of labels keys to remove.
|
|
163
163
|
"""
|
|
164
164
|
# Update the user's labels:
|
|
165
|
-
super(
|
|
165
|
+
super().set_labels(to_add=to_add, to_remove=to_remove)
|
|
166
166
|
|
|
167
167
|
# Set the required labels:
|
|
168
168
|
self._labels[self._LabelKeys.MODEL_FORMAT] = self._model_format
|
|
@@ -193,7 +193,7 @@ class LGBMModelHandler(MLModelHandler):
|
|
|
193
193
|
|
|
194
194
|
:return The saved model additional artifacts (if needed) dictionary if context is available and None otherwise.
|
|
195
195
|
"""
|
|
196
|
-
super(
|
|
196
|
+
super().save(output_path=output_path)
|
|
197
197
|
|
|
198
198
|
if isinstance(self._model, lgb.LGBMModel):
|
|
199
199
|
return self._save_lgbmmodel()
|
|
@@ -204,7 +204,7 @@ class LGBMModelHandler(MLModelHandler):
|
|
|
204
204
|
Load the specified model in this handler. Additional parameters for the class initializer can be passed via the
|
|
205
205
|
kwargs dictionary.
|
|
206
206
|
"""
|
|
207
|
-
super(
|
|
207
|
+
super().load()
|
|
208
208
|
|
|
209
209
|
# ModelFormats.PKL - Load from a pkl file:
|
|
210
210
|
if self._model_format == LGBMModelHandler.ModelFormats.PKL:
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
|
-
from typing import Any,
|
|
15
|
+
from typing import Any, Union
|
|
16
16
|
|
|
17
17
|
import numpy as np
|
|
18
18
|
|
|
@@ -37,8 +37,8 @@ class LGBMModelServer(V2ModelServer):
|
|
|
37
37
|
model_path: LGBMTypes.PathType = None,
|
|
38
38
|
model_name: str = None,
|
|
39
39
|
model_format: str = LGBMModelHandler.ModelFormats.PKL,
|
|
40
|
-
modules_map: Union[
|
|
41
|
-
custom_objects_map: Union[
|
|
40
|
+
modules_map: Union[dict[str, Union[None, str, list[str]]], str] = None,
|
|
41
|
+
custom_objects_map: Union[dict[str, Union[str, list[str]]], str] = None,
|
|
42
42
|
custom_objects_directory: str = None,
|
|
43
43
|
to_list: bool = True,
|
|
44
44
|
protocol: str = None,
|
|
@@ -100,7 +100,7 @@ class LGBMModelServer(V2ModelServer):
|
|
|
100
100
|
:param protocol: -
|
|
101
101
|
:param class_args: -
|
|
102
102
|
"""
|
|
103
|
-
super(
|
|
103
|
+
super().__init__(
|
|
104
104
|
context=context,
|
|
105
105
|
name=name,
|
|
106
106
|
model_path=model_path,
|
|
@@ -139,7 +139,7 @@ class LGBMModelServer(V2ModelServer):
|
|
|
139
139
|
self._model_handler.load()
|
|
140
140
|
self.model = self._model_handler.model
|
|
141
141
|
|
|
142
|
-
def predict(self, request:
|
|
142
|
+
def predict(self, request: dict[str, Any]) -> Union[np.ndarray, list]:
|
|
143
143
|
"""
|
|
144
144
|
Infer the inputs through the model using MLRun's PyTorch interface and return its output. The inferred data will
|
|
145
145
|
be read from the "inputs" key of the request.
|
|
@@ -158,7 +158,7 @@ class LGBMModelServer(V2ModelServer):
|
|
|
158
158
|
# Return as list if required:
|
|
159
159
|
return predictions if not self.to_list else predictions.tolist()
|
|
160
160
|
|
|
161
|
-
def explain(self, request:
|
|
161
|
+
def explain(self, request: dict[str, Any]) -> str:
|
|
162
162
|
"""
|
|
163
163
|
Return a string explaining what model is being served in this serving function and the function name.
|
|
164
164
|
|
mlrun/frameworks/lgbm/utils.py
CHANGED
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
|
-
from typing import
|
|
15
|
+
from typing import Union
|
|
16
16
|
|
|
17
17
|
import lightgbm as lgb
|
|
18
18
|
import numpy as np
|
|
@@ -36,13 +36,13 @@ class LGBMTypes(MLTypes):
|
|
|
36
36
|
|
|
37
37
|
# An evaluation result as packaged by the training in LightGBM:
|
|
38
38
|
EvaluationResultType = Union[
|
|
39
|
-
|
|
40
|
-
|
|
39
|
+
tuple[str, str, float, bool], # As packaged in `lightgbm.train`
|
|
40
|
+
tuple[str, str, float, bool, float], # As packaged in `lightgbm.cv`
|
|
41
41
|
]
|
|
42
42
|
|
|
43
43
|
# Detailed type for the named tuple `CallbackEnv` passed during LightGBM's training for the callbacks:
|
|
44
|
-
CallbackEnvType =
|
|
45
|
-
lgb.Booster, dict, int, int, int,
|
|
44
|
+
CallbackEnvType = tuple[
|
|
45
|
+
lgb.Booster, dict, int, int, int, list[EvaluationResultType]
|
|
46
46
|
]
|
|
47
47
|
|
|
48
48
|
|
mlrun/frameworks/onnx/dataset.py
CHANGED
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
15
|
import math
|
|
16
|
-
from typing import Callable,
|
|
16
|
+
from typing import Callable, Union
|
|
17
17
|
|
|
18
18
|
import numpy as np
|
|
19
19
|
|
|
@@ -25,11 +25,11 @@ class ONNXDataset:
|
|
|
25
25
|
|
|
26
26
|
def __init__(
|
|
27
27
|
self,
|
|
28
|
-
x: Union[np.ndarray,
|
|
29
|
-
y: Union[np.ndarray,
|
|
28
|
+
x: Union[np.ndarray, list[np.ndarray]],
|
|
29
|
+
y: Union[np.ndarray, list[np.ndarray]] = None,
|
|
30
30
|
batch_size: int = 1,
|
|
31
|
-
x_transforms:
|
|
32
|
-
y_transforms:
|
|
31
|
+
x_transforms: list[Callable[[np.ndarray], np.ndarray]] = None,
|
|
32
|
+
y_transforms: list[Callable[[np.ndarray], np.ndarray]] = None,
|
|
33
33
|
is_batched_transforms: bool = False,
|
|
34
34
|
):
|
|
35
35
|
"""
|
|
@@ -71,7 +71,7 @@ class ONNXDataset:
|
|
|
71
71
|
self._index = 0
|
|
72
72
|
return self
|
|
73
73
|
|
|
74
|
-
def __next__(self) -> Union[np.ndarray,
|
|
74
|
+
def __next__(self) -> Union[np.ndarray, tuple[np.ndarray, np.ndarray]]:
|
|
75
75
|
"""
|
|
76
76
|
Get the next item in line (by the inner index) since calling '__iter__'. If ground truth was provided (y),
|
|
77
77
|
a tuple of (x, y) will be returned. Otherwise x.
|
|
@@ -92,7 +92,7 @@ class ONNXDataset:
|
|
|
92
92
|
|
|
93
93
|
def __getitem__(
|
|
94
94
|
self, index: int
|
|
95
|
-
) -> Union[np.ndarray,
|
|
95
|
+
) -> Union[np.ndarray, tuple[np.ndarray, np.ndarray]]:
|
|
96
96
|
"""
|
|
97
97
|
Get the item at the given index. If ground truth was provided, a tuple of (x, y) will be returned. Otherwise x.
|
|
98
98
|
|
|
@@ -155,7 +155,7 @@ class ONNXDataset:
|
|
|
155
155
|
def _call_transforms(
|
|
156
156
|
self,
|
|
157
157
|
items: np.ndarray,
|
|
158
|
-
transforms:
|
|
158
|
+
transforms: list[Callable[[np.ndarray], np.ndarray]],
|
|
159
159
|
is_batched: bool,
|
|
160
160
|
):
|
|
161
161
|
"""
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
|
-
from typing import Callable
|
|
15
|
+
from typing import Callable
|
|
16
16
|
|
|
17
17
|
import numpy as np
|
|
18
18
|
import onnx
|
|
@@ -35,7 +35,7 @@ class ONNXMLRunInterface:
|
|
|
35
35
|
def __init__(
|
|
36
36
|
self,
|
|
37
37
|
model: onnx.ModelProto,
|
|
38
|
-
execution_providers:
|
|
38
|
+
execution_providers: list[str] = None,
|
|
39
39
|
context: mlrun.MLClientCtx = None,
|
|
40
40
|
):
|
|
41
41
|
# Set the context:
|
|
@@ -74,7 +74,7 @@ class ONNXMLRunInterface:
|
|
|
74
74
|
def evaluate(
|
|
75
75
|
self,
|
|
76
76
|
dataset: ONNXDataset,
|
|
77
|
-
metrics:
|
|
77
|
+
metrics: list[Callable[[np.ndarray, np.ndarray], float]],
|
|
78
78
|
):
|
|
79
79
|
pass
|
|
80
80
|
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
15
|
import os
|
|
16
|
-
from typing import
|
|
16
|
+
from typing import Union
|
|
17
17
|
|
|
18
18
|
import onnx
|
|
19
19
|
import onnxoptimizer
|
|
@@ -60,7 +60,7 @@ class ONNXModelHandler(ModelHandler):
|
|
|
60
60
|
:raise MLRunInvalidArgumentError: There was no model or model directory supplied.
|
|
61
61
|
"""
|
|
62
62
|
# Setup the base handler class:
|
|
63
|
-
super(
|
|
63
|
+
super().__init__(
|
|
64
64
|
model=model,
|
|
65
65
|
model_path=model_path,
|
|
66
66
|
model_name=model_name,
|
|
@@ -71,7 +71,7 @@ class ONNXModelHandler(ModelHandler):
|
|
|
71
71
|
# TODO: output_path won't work well with logging artifacts. Need to look into changing the logic of 'log_artifact'.
|
|
72
72
|
def save(
|
|
73
73
|
self, output_path: str = None, **kwargs
|
|
74
|
-
) -> Union[
|
|
74
|
+
) -> Union[dict[str, Artifact], None]:
|
|
75
75
|
"""
|
|
76
76
|
Save the handled model at the given output path. If a MLRun context is available, the saved model files will be
|
|
77
77
|
logged and returned as artifacts.
|
|
@@ -81,7 +81,7 @@ class ONNXModelHandler(ModelHandler):
|
|
|
81
81
|
|
|
82
82
|
:return The saved model additional artifacts (if needed) dictionary if context is available and None otherwise.
|
|
83
83
|
"""
|
|
84
|
-
super(
|
|
84
|
+
super().save(output_path=output_path)
|
|
85
85
|
|
|
86
86
|
# Set the output path:
|
|
87
87
|
if output_path is None:
|
|
@@ -97,7 +97,7 @@ class ONNXModelHandler(ModelHandler):
|
|
|
97
97
|
"""
|
|
98
98
|
Load the specified model in this handler.
|
|
99
99
|
"""
|
|
100
|
-
super(
|
|
100
|
+
super().load()
|
|
101
101
|
|
|
102
102
|
# Check that the model is well-formed:
|
|
103
103
|
# TODO: Currently not working well with HuggingFace models so we skip it
|
|
@@ -106,7 +106,7 @@ class ONNXModelHandler(ModelHandler):
|
|
|
106
106
|
# Load the ONNX model:
|
|
107
107
|
self._model = onnx.load(self._model_file)
|
|
108
108
|
|
|
109
|
-
def optimize(self, optimizations:
|
|
109
|
+
def optimize(self, optimizations: list[str] = None, fixed_point: bool = False):
|
|
110
110
|
"""
|
|
111
111
|
Use ONNX optimizer to optimize the ONNX model. The optimizations supported can be seen by calling
|
|
112
112
|
'onnxoptimizer.get_available_passes()'
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
|
-
from typing import Any,
|
|
15
|
+
from typing import Any, Union
|
|
16
16
|
|
|
17
17
|
import numpy as np
|
|
18
18
|
import onnx
|
|
@@ -37,7 +37,7 @@ class ONNXModelServer(V2ModelServer):
|
|
|
37
37
|
model: onnx.ModelProto = None,
|
|
38
38
|
model_path: str = None,
|
|
39
39
|
model_name: str = None,
|
|
40
|
-
execution_providers:
|
|
40
|
+
execution_providers: list[Union[str, tuple[str, dict[str, Any]]]] = None,
|
|
41
41
|
protocol: str = None,
|
|
42
42
|
**class_args,
|
|
43
43
|
):
|
|
@@ -76,7 +76,7 @@ class ONNXModelServer(V2ModelServer):
|
|
|
76
76
|
:param protocol: -
|
|
77
77
|
:param class_args: -
|
|
78
78
|
"""
|
|
79
|
-
super(
|
|
79
|
+
super().__init__(
|
|
80
80
|
context=context,
|
|
81
81
|
name=name,
|
|
82
82
|
model_path=model_path,
|
|
@@ -98,8 +98,8 @@ class ONNXModelServer(V2ModelServer):
|
|
|
98
98
|
# Prepare inference parameters:
|
|
99
99
|
self._model_handler: ONNXModelHandler = None
|
|
100
100
|
self._inference_session: onnxruntime.InferenceSession = None
|
|
101
|
-
self._input_layers:
|
|
102
|
-
self._output_layers:
|
|
101
|
+
self._input_layers: list[str] = None
|
|
102
|
+
self._output_layers: list[str] = None
|
|
103
103
|
|
|
104
104
|
def load(self):
|
|
105
105
|
"""
|
|
@@ -134,7 +134,7 @@ class ONNXModelServer(V2ModelServer):
|
|
|
134
134
|
output_layer.name for output_layer in self._inference_session.get_outputs()
|
|
135
135
|
]
|
|
136
136
|
|
|
137
|
-
def predict(self, request:
|
|
137
|
+
def predict(self, request: dict[str, Any]) -> np.ndarray:
|
|
138
138
|
"""
|
|
139
139
|
Infer the inputs through the model using ONNXRunTime and return its output. The inferred data will be
|
|
140
140
|
read from the "inputs" key of the request.
|
|
@@ -155,7 +155,7 @@ class ONNXModelServer(V2ModelServer):
|
|
|
155
155
|
},
|
|
156
156
|
)
|
|
157
157
|
|
|
158
|
-
def explain(self, request:
|
|
158
|
+
def explain(self, request: dict[str, Any]) -> str:
|
|
159
159
|
"""
|
|
160
160
|
Return a string explaining what model is being serve in this serving function and the function name.
|
|
161
161
|
|