mlrun 1.7.2rc3__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mlrun might be problematic. Click here for more details.
- mlrun/__init__.py +26 -22
- mlrun/__main__.py +15 -16
- mlrun/alerts/alert.py +150 -15
- mlrun/api/schemas/__init__.py +1 -9
- mlrun/artifacts/__init__.py +2 -3
- mlrun/artifacts/base.py +62 -19
- mlrun/artifacts/dataset.py +17 -17
- mlrun/artifacts/document.py +454 -0
- mlrun/artifacts/manager.py +28 -18
- mlrun/artifacts/model.py +91 -59
- mlrun/artifacts/plots.py +2 -2
- mlrun/common/constants.py +8 -0
- mlrun/common/formatters/__init__.py +1 -0
- mlrun/common/formatters/artifact.py +1 -1
- mlrun/common/formatters/feature_set.py +2 -0
- mlrun/common/formatters/function.py +1 -0
- mlrun/{model_monitoring/db/stores/v3io_kv/__init__.py → common/formatters/model_endpoint.py} +17 -0
- mlrun/common/formatters/pipeline.py +1 -2
- mlrun/common/formatters/project.py +9 -0
- mlrun/common/model_monitoring/__init__.py +0 -5
- mlrun/common/model_monitoring/helpers.py +12 -62
- mlrun/common/runtimes/constants.py +25 -4
- mlrun/common/schemas/__init__.py +9 -5
- mlrun/common/schemas/alert.py +114 -19
- mlrun/common/schemas/api_gateway.py +3 -3
- mlrun/common/schemas/artifact.py +22 -9
- mlrun/common/schemas/auth.py +8 -4
- mlrun/common/schemas/background_task.py +7 -7
- mlrun/common/schemas/client_spec.py +4 -4
- mlrun/common/schemas/clusterization_spec.py +2 -2
- mlrun/common/schemas/common.py +53 -3
- mlrun/common/schemas/constants.py +15 -0
- mlrun/common/schemas/datastore_profile.py +1 -1
- mlrun/common/schemas/feature_store.py +9 -9
- mlrun/common/schemas/frontend_spec.py +4 -4
- mlrun/common/schemas/function.py +10 -10
- mlrun/common/schemas/hub.py +1 -1
- mlrun/common/schemas/k8s.py +3 -3
- mlrun/common/schemas/memory_reports.py +3 -3
- mlrun/common/schemas/model_monitoring/__init__.py +4 -8
- mlrun/common/schemas/model_monitoring/constants.py +127 -46
- mlrun/common/schemas/model_monitoring/grafana.py +18 -12
- mlrun/common/schemas/model_monitoring/model_endpoints.py +154 -160
- mlrun/common/schemas/notification.py +24 -3
- mlrun/common/schemas/object.py +1 -1
- mlrun/common/schemas/pagination.py +4 -4
- mlrun/common/schemas/partition.py +142 -0
- mlrun/common/schemas/pipeline.py +3 -3
- mlrun/common/schemas/project.py +26 -18
- mlrun/common/schemas/runs.py +3 -3
- mlrun/common/schemas/runtime_resource.py +5 -5
- mlrun/common/schemas/schedule.py +1 -1
- mlrun/common/schemas/secret.py +1 -1
- mlrun/{model_monitoring/db/stores/sqldb/__init__.py → common/schemas/serving.py} +10 -1
- mlrun/common/schemas/tag.py +3 -3
- mlrun/common/schemas/workflow.py +6 -5
- mlrun/common/types.py +1 -0
- mlrun/config.py +157 -89
- mlrun/data_types/__init__.py +5 -3
- mlrun/data_types/infer.py +13 -3
- mlrun/data_types/spark.py +2 -1
- mlrun/datastore/__init__.py +59 -18
- mlrun/datastore/alibaba_oss.py +4 -1
- mlrun/datastore/azure_blob.py +4 -1
- mlrun/datastore/base.py +19 -24
- mlrun/datastore/datastore.py +10 -4
- mlrun/datastore/datastore_profile.py +178 -45
- mlrun/datastore/dbfs_store.py +4 -1
- mlrun/datastore/filestore.py +4 -1
- mlrun/datastore/google_cloud_storage.py +4 -1
- mlrun/datastore/hdfs.py +4 -1
- mlrun/datastore/inmem.py +4 -1
- mlrun/datastore/redis.py +4 -1
- mlrun/datastore/s3.py +14 -3
- mlrun/datastore/sources.py +89 -92
- mlrun/datastore/store_resources.py +7 -4
- mlrun/datastore/storeytargets.py +51 -16
- mlrun/datastore/targets.py +38 -31
- mlrun/datastore/utils.py +87 -4
- mlrun/datastore/v3io.py +4 -1
- mlrun/datastore/vectorstore.py +291 -0
- mlrun/datastore/wasbfs/fs.py +13 -12
- mlrun/db/base.py +286 -100
- mlrun/db/httpdb.py +1562 -490
- mlrun/db/nopdb.py +250 -83
- mlrun/errors.py +6 -2
- mlrun/execution.py +194 -50
- mlrun/feature_store/__init__.py +2 -10
- mlrun/feature_store/api.py +20 -458
- mlrun/feature_store/common.py +9 -9
- mlrun/feature_store/feature_set.py +20 -18
- mlrun/feature_store/feature_vector.py +105 -479
- mlrun/feature_store/feature_vector_utils.py +466 -0
- mlrun/feature_store/retrieval/base.py +15 -11
- mlrun/feature_store/retrieval/job.py +2 -1
- mlrun/feature_store/retrieval/storey_merger.py +1 -1
- mlrun/feature_store/steps.py +3 -3
- mlrun/features.py +30 -13
- mlrun/frameworks/__init__.py +1 -2
- mlrun/frameworks/_common/__init__.py +1 -2
- mlrun/frameworks/_common/artifacts_library.py +2 -2
- mlrun/frameworks/_common/mlrun_interface.py +10 -6
- mlrun/frameworks/_common/model_handler.py +31 -31
- mlrun/frameworks/_common/producer.py +3 -1
- mlrun/frameworks/_dl_common/__init__.py +1 -2
- mlrun/frameworks/_dl_common/loggers/__init__.py +1 -2
- mlrun/frameworks/_dl_common/loggers/mlrun_logger.py +4 -4
- mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py +3 -3
- mlrun/frameworks/_ml_common/__init__.py +1 -2
- mlrun/frameworks/_ml_common/loggers/__init__.py +1 -2
- mlrun/frameworks/_ml_common/model_handler.py +21 -21
- mlrun/frameworks/_ml_common/plans/__init__.py +1 -2
- mlrun/frameworks/_ml_common/plans/confusion_matrix_plan.py +3 -1
- mlrun/frameworks/_ml_common/plans/dataset_plan.py +3 -3
- mlrun/frameworks/_ml_common/plans/roc_curve_plan.py +4 -4
- mlrun/frameworks/auto_mlrun/__init__.py +1 -2
- mlrun/frameworks/auto_mlrun/auto_mlrun.py +22 -15
- mlrun/frameworks/huggingface/__init__.py +1 -2
- mlrun/frameworks/huggingface/model_server.py +9 -9
- mlrun/frameworks/lgbm/__init__.py +47 -44
- mlrun/frameworks/lgbm/callbacks/__init__.py +1 -2
- mlrun/frameworks/lgbm/callbacks/logging_callback.py +4 -2
- mlrun/frameworks/lgbm/callbacks/mlrun_logging_callback.py +4 -2
- mlrun/frameworks/lgbm/mlrun_interfaces/__init__.py +1 -2
- mlrun/frameworks/lgbm/mlrun_interfaces/mlrun_interface.py +5 -5
- mlrun/frameworks/lgbm/model_handler.py +15 -11
- mlrun/frameworks/lgbm/model_server.py +11 -7
- mlrun/frameworks/lgbm/utils.py +2 -2
- mlrun/frameworks/onnx/__init__.py +1 -2
- mlrun/frameworks/onnx/dataset.py +3 -3
- mlrun/frameworks/onnx/mlrun_interface.py +2 -2
- mlrun/frameworks/onnx/model_handler.py +7 -5
- mlrun/frameworks/onnx/model_server.py +8 -6
- mlrun/frameworks/parallel_coordinates.py +11 -11
- mlrun/frameworks/pytorch/__init__.py +22 -23
- mlrun/frameworks/pytorch/callbacks/__init__.py +1 -2
- mlrun/frameworks/pytorch/callbacks/callback.py +2 -1
- mlrun/frameworks/pytorch/callbacks/logging_callback.py +15 -8
- mlrun/frameworks/pytorch/callbacks/mlrun_logging_callback.py +19 -12
- mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py +22 -15
- mlrun/frameworks/pytorch/callbacks_handler.py +36 -30
- mlrun/frameworks/pytorch/mlrun_interface.py +17 -17
- mlrun/frameworks/pytorch/model_handler.py +21 -17
- mlrun/frameworks/pytorch/model_server.py +13 -9
- mlrun/frameworks/sklearn/__init__.py +19 -18
- mlrun/frameworks/sklearn/estimator.py +2 -2
- mlrun/frameworks/sklearn/metric.py +3 -3
- mlrun/frameworks/sklearn/metrics_library.py +8 -6
- mlrun/frameworks/sklearn/mlrun_interface.py +3 -2
- mlrun/frameworks/sklearn/model_handler.py +4 -3
- mlrun/frameworks/tf_keras/__init__.py +11 -12
- mlrun/frameworks/tf_keras/callbacks/__init__.py +1 -2
- mlrun/frameworks/tf_keras/callbacks/logging_callback.py +17 -14
- mlrun/frameworks/tf_keras/callbacks/mlrun_logging_callback.py +15 -12
- mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py +21 -18
- mlrun/frameworks/tf_keras/model_handler.py +17 -13
- mlrun/frameworks/tf_keras/model_server.py +12 -8
- mlrun/frameworks/xgboost/__init__.py +19 -18
- mlrun/frameworks/xgboost/model_handler.py +13 -9
- mlrun/k8s_utils.py +2 -5
- mlrun/launcher/base.py +3 -4
- mlrun/launcher/client.py +2 -2
- mlrun/launcher/local.py +6 -2
- mlrun/launcher/remote.py +1 -1
- mlrun/lists.py +8 -4
- mlrun/model.py +132 -46
- mlrun/model_monitoring/__init__.py +3 -5
- mlrun/model_monitoring/api.py +113 -98
- mlrun/model_monitoring/applications/__init__.py +0 -5
- mlrun/model_monitoring/applications/_application_steps.py +81 -50
- mlrun/model_monitoring/applications/base.py +467 -14
- mlrun/model_monitoring/applications/context.py +212 -134
- mlrun/model_monitoring/{db/stores/base → applications/evidently}/__init__.py +6 -2
- mlrun/model_monitoring/applications/evidently/base.py +146 -0
- mlrun/model_monitoring/applications/histogram_data_drift.py +89 -56
- mlrun/model_monitoring/applications/results.py +67 -15
- mlrun/model_monitoring/controller.py +701 -315
- mlrun/model_monitoring/db/__init__.py +0 -2
- mlrun/model_monitoring/db/_schedules.py +242 -0
- mlrun/model_monitoring/db/_stats.py +189 -0
- mlrun/model_monitoring/db/tsdb/__init__.py +33 -22
- mlrun/model_monitoring/db/tsdb/base.py +243 -49
- mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +76 -36
- mlrun/model_monitoring/db/tsdb/tdengine/stream_graph_steps.py +33 -0
- mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connection.py +213 -0
- mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +534 -88
- mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +1 -0
- mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +436 -106
- mlrun/model_monitoring/helpers.py +356 -114
- mlrun/model_monitoring/stream_processing.py +190 -345
- mlrun/model_monitoring/tracking_policy.py +11 -4
- mlrun/model_monitoring/writer.py +49 -90
- mlrun/package/__init__.py +3 -6
- mlrun/package/context_handler.py +2 -2
- mlrun/package/packager.py +12 -9
- mlrun/package/packagers/__init__.py +0 -2
- mlrun/package/packagers/default_packager.py +14 -11
- mlrun/package/packagers/numpy_packagers.py +16 -7
- mlrun/package/packagers/pandas_packagers.py +18 -18
- mlrun/package/packagers/python_standard_library_packagers.py +25 -11
- mlrun/package/packagers_manager.py +35 -32
- mlrun/package/utils/__init__.py +0 -3
- mlrun/package/utils/_pickler.py +6 -6
- mlrun/platforms/__init__.py +47 -16
- mlrun/platforms/iguazio.py +4 -1
- mlrun/projects/operations.py +30 -30
- mlrun/projects/pipelines.py +116 -47
- mlrun/projects/project.py +1292 -329
- mlrun/render.py +5 -9
- mlrun/run.py +57 -14
- mlrun/runtimes/__init__.py +1 -3
- mlrun/runtimes/base.py +30 -22
- mlrun/runtimes/daskjob.py +9 -9
- mlrun/runtimes/databricks_job/databricks_runtime.py +6 -5
- mlrun/runtimes/function_reference.py +5 -2
- mlrun/runtimes/generators.py +3 -2
- mlrun/runtimes/kubejob.py +6 -7
- mlrun/runtimes/mounts.py +574 -0
- mlrun/runtimes/mpijob/__init__.py +0 -2
- mlrun/runtimes/mpijob/abstract.py +7 -6
- mlrun/runtimes/nuclio/api_gateway.py +7 -7
- mlrun/runtimes/nuclio/application/application.py +11 -13
- mlrun/runtimes/nuclio/application/reverse_proxy.go +66 -64
- mlrun/runtimes/nuclio/function.py +127 -70
- mlrun/runtimes/nuclio/serving.py +105 -37
- mlrun/runtimes/pod.py +159 -54
- mlrun/runtimes/remotesparkjob.py +3 -2
- mlrun/runtimes/sparkjob/__init__.py +0 -2
- mlrun/runtimes/sparkjob/spark3job.py +22 -12
- mlrun/runtimes/utils.py +7 -6
- mlrun/secrets.py +2 -2
- mlrun/serving/__init__.py +8 -0
- mlrun/serving/merger.py +7 -5
- mlrun/serving/remote.py +35 -22
- mlrun/serving/routers.py +186 -240
- mlrun/serving/server.py +41 -10
- mlrun/serving/states.py +432 -118
- mlrun/serving/utils.py +13 -2
- mlrun/serving/v1_serving.py +3 -2
- mlrun/serving/v2_serving.py +161 -203
- mlrun/track/__init__.py +1 -1
- mlrun/track/tracker.py +2 -2
- mlrun/track/trackers/mlflow_tracker.py +6 -5
- mlrun/utils/async_http.py +35 -22
- mlrun/utils/clones.py +7 -4
- mlrun/utils/helpers.py +511 -58
- mlrun/utils/logger.py +119 -13
- mlrun/utils/notifications/notification/__init__.py +22 -19
- mlrun/utils/notifications/notification/base.py +39 -15
- mlrun/utils/notifications/notification/console.py +6 -6
- mlrun/utils/notifications/notification/git.py +11 -11
- mlrun/utils/notifications/notification/ipython.py +10 -9
- mlrun/utils/notifications/notification/mail.py +176 -0
- mlrun/utils/notifications/notification/slack.py +16 -8
- mlrun/utils/notifications/notification/webhook.py +24 -8
- mlrun/utils/notifications/notification_pusher.py +191 -200
- mlrun/utils/regex.py +12 -2
- mlrun/utils/version/version.json +2 -2
- {mlrun-1.7.2rc3.dist-info → mlrun-1.8.0.dist-info}/METADATA +81 -54
- mlrun-1.8.0.dist-info/RECORD +351 -0
- {mlrun-1.7.2rc3.dist-info → mlrun-1.8.0.dist-info}/WHEEL +1 -1
- mlrun/model_monitoring/applications/evidently_base.py +0 -137
- mlrun/model_monitoring/db/stores/__init__.py +0 -136
- mlrun/model_monitoring/db/stores/base/store.py +0 -213
- mlrun/model_monitoring/db/stores/sqldb/models/__init__.py +0 -71
- mlrun/model_monitoring/db/stores/sqldb/models/base.py +0 -190
- mlrun/model_monitoring/db/stores/sqldb/models/mysql.py +0 -103
- mlrun/model_monitoring/db/stores/sqldb/models/sqlite.py +0 -40
- mlrun/model_monitoring/db/stores/sqldb/sql_store.py +0 -659
- mlrun/model_monitoring/db/stores/v3io_kv/kv_store.py +0 -726
- mlrun/model_monitoring/model_endpoint.py +0 -118
- mlrun-1.7.2rc3.dist-info/RECORD +0 -351
- {mlrun-1.7.2rc3.dist-info → mlrun-1.8.0.dist-info}/entry_points.txt +0 -0
- {mlrun-1.7.2rc3.dist-info → mlrun-1.8.0.dist-info/licenses}/LICENSE +0 -0
- {mlrun-1.7.2rc3.dist-info → mlrun-1.8.0.dist-info}/top_level.txt +0 -0
|
@@ -11,9 +11,8 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
from typing import Any, Union
|
|
14
|
+
|
|
15
|
+
from typing import Any, Optional, Union
|
|
17
16
|
|
|
18
17
|
import lightgbm as lgb
|
|
19
18
|
|
|
@@ -37,20 +36,20 @@ LGBMArtifactsLibrary = MLArtifactsLibrary
|
|
|
37
36
|
def _apply_mlrun_on_module(
|
|
38
37
|
model_name: str = "model",
|
|
39
38
|
tag: str = "",
|
|
40
|
-
modules_map: Union[dict[str, Union[None, str, list[str]]], str] = None,
|
|
41
|
-
custom_objects_map: Union[dict[str, Union[str, list[str]]], str] = None,
|
|
42
|
-
custom_objects_directory: str = None,
|
|
39
|
+
modules_map: Optional[Union[dict[str, Union[None, str, list[str]]], str]] = None,
|
|
40
|
+
custom_objects_map: Optional[Union[dict[str, Union[str, list[str]]], str]] = None,
|
|
41
|
+
custom_objects_directory: Optional[str] = None,
|
|
43
42
|
context: mlrun.MLClientCtx = None,
|
|
44
43
|
model_format: str = LGBMModelHandler.ModelFormats.PKL,
|
|
45
44
|
sample_set: Union[LGBMTypes.DatasetType, mlrun.DataItem, str] = None,
|
|
46
|
-
y_columns: Union[list[str], list[int]] = None,
|
|
47
|
-
feature_vector: str = None,
|
|
48
|
-
feature_weights: list[float] = None,
|
|
49
|
-
labels: dict[str, Union[str, int, float]] = None,
|
|
50
|
-
parameters: dict[str, Union[str, int, float]] = None,
|
|
51
|
-
extra_data: dict[str, LGBMTypes.ExtraDataType] = None,
|
|
45
|
+
y_columns: Optional[Union[list[str], list[int]]] = None,
|
|
46
|
+
feature_vector: Optional[str] = None,
|
|
47
|
+
feature_weights: Optional[list[float]] = None,
|
|
48
|
+
labels: Optional[dict[str, Union[str, int, float]]] = None,
|
|
49
|
+
parameters: Optional[dict[str, Union[str, int, float]]] = None,
|
|
50
|
+
extra_data: Optional[dict[str, LGBMTypes.ExtraDataType]] = None,
|
|
52
51
|
auto_log: bool = True,
|
|
53
|
-
mlrun_logging_callback_kwargs: dict[str, Any] = None,
|
|
52
|
+
mlrun_logging_callback_kwargs: Optional[dict[str, Any]] = None,
|
|
54
53
|
):
|
|
55
54
|
# Apply MLRun's interface on the LightGBM module:
|
|
56
55
|
LGBMMLRunInterface.add_interface(obj=lgb)
|
|
@@ -84,27 +83,29 @@ def _apply_mlrun_on_model(
|
|
|
84
83
|
model: LGBMTypes.ModelType = None,
|
|
85
84
|
model_name: str = "model",
|
|
86
85
|
tag: str = "",
|
|
87
|
-
model_path: str = None,
|
|
88
|
-
modules_map: Union[dict[str, Union[None, str, list[str]]], str] = None,
|
|
89
|
-
custom_objects_map: Union[dict[str, Union[str, list[str]]], str] = None,
|
|
90
|
-
custom_objects_directory: str = None,
|
|
86
|
+
model_path: Optional[str] = None,
|
|
87
|
+
modules_map: Optional[Union[dict[str, Union[None, str, list[str]]], str]] = None,
|
|
88
|
+
custom_objects_map: Optional[Union[dict[str, Union[str, list[str]]], str]] = None,
|
|
89
|
+
custom_objects_directory: Optional[str] = None,
|
|
91
90
|
context: mlrun.MLClientCtx = None,
|
|
92
91
|
model_format: str = LGBMModelHandler.ModelFormats.PKL,
|
|
93
|
-
artifacts: Union[list[MLPlan], list[str], dict[str, dict]] = None,
|
|
94
|
-
metrics:
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
92
|
+
artifacts: Optional[Union[list[MLPlan], list[str], dict[str, dict]]] = None,
|
|
93
|
+
metrics: Optional[
|
|
94
|
+
Union[
|
|
95
|
+
list[Metric],
|
|
96
|
+
list[LGBMTypes.MetricEntryType],
|
|
97
|
+
dict[str, LGBMTypes.MetricEntryType],
|
|
98
|
+
]
|
|
98
99
|
] = None,
|
|
99
100
|
x_test: LGBMTypes.DatasetType = None,
|
|
100
101
|
y_test: LGBMTypes.DatasetType = None,
|
|
101
102
|
sample_set: Union[LGBMTypes.DatasetType, mlrun.DataItem, str] = None,
|
|
102
|
-
y_columns: Union[list[str], list[int]] = None,
|
|
103
|
-
feature_vector: str = None,
|
|
104
|
-
feature_weights: list[float] = None,
|
|
105
|
-
labels: dict[str, Union[str, int, float]] = None,
|
|
106
|
-
parameters: dict[str, Union[str, int, float]] = None,
|
|
107
|
-
extra_data: dict[str, LGBMTypes.ExtraDataType] = None,
|
|
103
|
+
y_columns: Optional[Union[list[str], list[int]]] = None,
|
|
104
|
+
feature_vector: Optional[str] = None,
|
|
105
|
+
feature_weights: Optional[list[float]] = None,
|
|
106
|
+
labels: Optional[dict[str, Union[str, int, float]]] = None,
|
|
107
|
+
parameters: Optional[dict[str, Union[str, int, float]]] = None,
|
|
108
|
+
extra_data: Optional[dict[str, LGBMTypes.ExtraDataType]] = None,
|
|
108
109
|
auto_log: bool = True,
|
|
109
110
|
**kwargs,
|
|
110
111
|
):
|
|
@@ -182,29 +183,31 @@ def apply_mlrun(
|
|
|
182
183
|
model: LGBMTypes.ModelType = None,
|
|
183
184
|
model_name: str = "model",
|
|
184
185
|
tag: str = "",
|
|
185
|
-
model_path: str = None,
|
|
186
|
-
modules_map: Union[dict[str, Union[None, str, list[str]]], str] = None,
|
|
187
|
-
custom_objects_map: Union[dict[str, Union[str, list[str]]], str] = None,
|
|
188
|
-
custom_objects_directory: str = None,
|
|
186
|
+
model_path: Optional[str] = None,
|
|
187
|
+
modules_map: Optional[Union[dict[str, Union[None, str, list[str]]], str]] = None,
|
|
188
|
+
custom_objects_map: Optional[Union[dict[str, Union[str, list[str]]], str]] = None,
|
|
189
|
+
custom_objects_directory: Optional[str] = None,
|
|
189
190
|
context: mlrun.MLClientCtx = None,
|
|
190
191
|
model_format: str = LGBMModelHandler.ModelFormats.PKL,
|
|
191
|
-
artifacts: Union[list[MLPlan], list[str], dict[str, dict]] = None,
|
|
192
|
-
metrics:
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
192
|
+
artifacts: Optional[Union[list[MLPlan], list[str], dict[str, dict]]] = None,
|
|
193
|
+
metrics: Optional[
|
|
194
|
+
Union[
|
|
195
|
+
list[Metric],
|
|
196
|
+
list[LGBMTypes.MetricEntryType],
|
|
197
|
+
dict[str, LGBMTypes.MetricEntryType],
|
|
198
|
+
]
|
|
196
199
|
] = None,
|
|
197
200
|
x_test: LGBMTypes.DatasetType = None,
|
|
198
201
|
y_test: LGBMTypes.DatasetType = None,
|
|
199
202
|
sample_set: Union[LGBMTypes.DatasetType, mlrun.DataItem, str] = None,
|
|
200
|
-
y_columns: Union[list[str], list[int]] = None,
|
|
201
|
-
feature_vector: str = None,
|
|
202
|
-
feature_weights: list[float] = None,
|
|
203
|
-
labels: dict[str, Union[str, int, float]] = None,
|
|
204
|
-
parameters: dict[str, Union[str, int, float]] = None,
|
|
205
|
-
extra_data: dict[str, LGBMTypes.ExtraDataType] = None,
|
|
203
|
+
y_columns: Optional[Union[list[str], list[int]]] = None,
|
|
204
|
+
feature_vector: Optional[str] = None,
|
|
205
|
+
feature_weights: Optional[list[float]] = None,
|
|
206
|
+
labels: Optional[dict[str, Union[str, int, float]]] = None,
|
|
207
|
+
parameters: Optional[dict[str, Union[str, int, float]]] = None,
|
|
208
|
+
extra_data: Optional[dict[str, LGBMTypes.ExtraDataType]] = None,
|
|
206
209
|
auto_log: bool = True,
|
|
207
|
-
mlrun_logging_callback_kwargs: dict[str, Any] = None,
|
|
210
|
+
mlrun_logging_callback_kwargs: Optional[dict[str, Any]] = None,
|
|
208
211
|
**kwargs,
|
|
209
212
|
) -> Union[LGBMModelHandler, None]:
|
|
210
213
|
"""
|
|
@@ -11,8 +11,7 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
# flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx
|
|
14
|
+
|
|
16
15
|
from .callback import Callback
|
|
17
16
|
from .logging_callback import LoggingCallback
|
|
18
17
|
from .mlrun_logging_callback import MLRunLoggingCallback
|
|
@@ -13,6 +13,8 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
15
|
|
|
16
|
+
from typing import Optional
|
|
17
|
+
|
|
16
18
|
from ..._ml_common.loggers import Logger
|
|
17
19
|
from ..utils import LGBMTypes
|
|
18
20
|
from .callback import Callback, CallbackEnv
|
|
@@ -25,8 +27,8 @@ class LoggingCallback(Callback):
|
|
|
25
27
|
|
|
26
28
|
def __init__(
|
|
27
29
|
self,
|
|
28
|
-
dynamic_hyperparameters: list[str] = None,
|
|
29
|
-
static_hyperparameters: list[str] = None,
|
|
30
|
+
dynamic_hyperparameters: Optional[list[str]] = None,
|
|
31
|
+
static_hyperparameters: Optional[list[str]] = None,
|
|
30
32
|
):
|
|
31
33
|
"""
|
|
32
34
|
Initialize the logging callback with the given configuration. All the metrics data will be collected but the
|
|
@@ -13,6 +13,8 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
15
|
|
|
16
|
+
from typing import Optional
|
|
17
|
+
|
|
16
18
|
import mlrun
|
|
17
19
|
|
|
18
20
|
from ..._ml_common.loggers import MLRunLogger
|
|
@@ -33,8 +35,8 @@ class MLRunLoggingCallback(LoggingCallback):
|
|
|
33
35
|
def __init__(
|
|
34
36
|
self,
|
|
35
37
|
context: mlrun.MLClientCtx,
|
|
36
|
-
dynamic_hyperparameters: list[str] = None,
|
|
37
|
-
static_hyperparameters: list[str] = None,
|
|
38
|
+
dynamic_hyperparameters: Optional[list[str]] = None,
|
|
39
|
+
static_hyperparameters: Optional[list[str]] = None,
|
|
38
40
|
logging_frequency: int = 100,
|
|
39
41
|
):
|
|
40
42
|
"""
|
|
@@ -11,8 +11,7 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
# flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx
|
|
14
|
+
|
|
16
15
|
from .booster_mlrun_interface import LGBMBoosterMLRunInterface
|
|
17
16
|
from .mlrun_interface import LGBMMLRunInterface
|
|
18
17
|
from .model_mlrun_interface import LGBMModelMLRunInterface
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
#
|
|
15
15
|
from abc import ABC
|
|
16
16
|
from types import ModuleType
|
|
17
|
-
from typing import Callable, Union
|
|
17
|
+
from typing import Callable, Optional, Union
|
|
18
18
|
|
|
19
19
|
import lightgbm as lgb
|
|
20
20
|
|
|
@@ -67,7 +67,7 @@ class LGBMMLRunInterface(MLRunInterface, ABC):
|
|
|
67
67
|
@classmethod
|
|
68
68
|
def add_interface(
|
|
69
69
|
cls,
|
|
70
|
-
obj: ModuleType = None,
|
|
70
|
+
obj: Optional[ModuleType] = None,
|
|
71
71
|
restoration: LGBMTypes.MLRunInterfaceRestorationType = None,
|
|
72
72
|
):
|
|
73
73
|
"""
|
|
@@ -167,10 +167,10 @@ class LGBMMLRunInterface(MLRunInterface, ABC):
|
|
|
167
167
|
def configure_logging(
|
|
168
168
|
context: mlrun.MLClientCtx = None,
|
|
169
169
|
log_model: bool = True,
|
|
170
|
-
model_handler_kwargs: dict = None,
|
|
171
|
-
log_model_kwargs: dict = None,
|
|
170
|
+
model_handler_kwargs: Optional[dict] = None,
|
|
171
|
+
log_model_kwargs: Optional[dict] = None,
|
|
172
172
|
log_training: bool = True,
|
|
173
|
-
mlrun_logging_callback_kwargs: dict = None,
|
|
173
|
+
mlrun_logging_callback_kwargs: Optional[dict] = None,
|
|
174
174
|
):
|
|
175
175
|
"""
|
|
176
176
|
Configure the logging of the training API in LightGBM to log the training and model into MLRun. Each `train`
|
|
@@ -14,7 +14,7 @@
|
|
|
14
14
|
#
|
|
15
15
|
import os
|
|
16
16
|
import pickle
|
|
17
|
-
from typing import Union
|
|
17
|
+
from typing import Optional, Union
|
|
18
18
|
|
|
19
19
|
import cloudpickle
|
|
20
20
|
import lightgbm as lgb
|
|
@@ -53,12 +53,16 @@ class LGBMModelHandler(MLModelHandler):
|
|
|
53
53
|
|
|
54
54
|
def __init__(
|
|
55
55
|
self,
|
|
56
|
-
model_name: str = None,
|
|
57
|
-
model_path: str = None,
|
|
56
|
+
model_name: Optional[str] = None,
|
|
57
|
+
model_path: Optional[str] = None,
|
|
58
58
|
model: LGBMTypes.ModelType = None,
|
|
59
|
-
modules_map:
|
|
60
|
-
|
|
61
|
-
|
|
59
|
+
modules_map: Optional[
|
|
60
|
+
Union[dict[str, Union[None, str, list[str]]], str]
|
|
61
|
+
] = None,
|
|
62
|
+
custom_objects_map: Optional[
|
|
63
|
+
Union[dict[str, Union[str, list[str]]], str]
|
|
64
|
+
] = None,
|
|
65
|
+
custom_objects_directory: Optional[str] = None,
|
|
62
66
|
context: mlrun.MLClientCtx = None,
|
|
63
67
|
model_format: str = ModelFormats.PKL,
|
|
64
68
|
**kwargs,
|
|
@@ -152,8 +156,8 @@ class LGBMModelHandler(MLModelHandler):
|
|
|
152
156
|
|
|
153
157
|
def set_labels(
|
|
154
158
|
self,
|
|
155
|
-
to_add: dict[str, Union[str, int, float]] = None,
|
|
156
|
-
to_remove: list[str] = None,
|
|
159
|
+
to_add: Optional[dict[str, Union[str, int, float]]] = None,
|
|
160
|
+
to_remove: Optional[list[str]] = None,
|
|
157
161
|
):
|
|
158
162
|
"""
|
|
159
163
|
Update the labels dictionary of this model artifact. There are required labels that cannot be edited or removed.
|
|
@@ -183,7 +187,7 @@ class LGBMModelHandler(MLModelHandler):
|
|
|
183
187
|
f"'model_path': '{self._model_path}'"
|
|
184
188
|
)
|
|
185
189
|
|
|
186
|
-
def save(self, output_path: str = None, **kwargs):
|
|
190
|
+
def save(self, output_path: Optional[str] = None, **kwargs):
|
|
187
191
|
"""
|
|
188
192
|
Save the handled model at the given output path. If a MLRun context is available, the saved model files will be
|
|
189
193
|
logged and returned as artifacts.
|
|
@@ -217,10 +221,10 @@ class LGBMModelHandler(MLModelHandler):
|
|
|
217
221
|
|
|
218
222
|
def to_onnx(
|
|
219
223
|
self,
|
|
220
|
-
model_name: str = None,
|
|
224
|
+
model_name: Optional[str] = None,
|
|
221
225
|
optimize: bool = True,
|
|
222
226
|
input_sample: LGBMTypes.DatasetType = None,
|
|
223
|
-
log: bool = None,
|
|
227
|
+
log: Optional[bool] = None,
|
|
224
228
|
):
|
|
225
229
|
"""
|
|
226
230
|
Convert the model in this handler to an ONNX model. The inputs names are optional, they do not change the
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
|
-
from typing import Any, Union
|
|
15
|
+
from typing import Any, Optional, Union
|
|
16
16
|
|
|
17
17
|
import numpy as np
|
|
18
18
|
|
|
@@ -32,16 +32,20 @@ class LGBMModelServer(V2ModelServer):
|
|
|
32
32
|
def __init__(
|
|
33
33
|
self,
|
|
34
34
|
context: mlrun.MLClientCtx = None,
|
|
35
|
-
name: str = None,
|
|
35
|
+
name: Optional[str] = None,
|
|
36
36
|
model: LGBMTypes.ModelType = None,
|
|
37
37
|
model_path: LGBMTypes.PathType = None,
|
|
38
|
-
model_name: str = None,
|
|
38
|
+
model_name: Optional[str] = None,
|
|
39
39
|
model_format: str = LGBMModelHandler.ModelFormats.PKL,
|
|
40
|
-
modules_map:
|
|
41
|
-
|
|
42
|
-
|
|
40
|
+
modules_map: Optional[
|
|
41
|
+
Union[dict[str, Union[None, str, list[str]]], str]
|
|
42
|
+
] = None,
|
|
43
|
+
custom_objects_map: Optional[
|
|
44
|
+
Union[dict[str, Union[str, list[str]]], str]
|
|
45
|
+
] = None,
|
|
46
|
+
custom_objects_directory: Optional[str] = None,
|
|
43
47
|
to_list: bool = True,
|
|
44
|
-
protocol: str = None,
|
|
48
|
+
protocol: Optional[str] = None,
|
|
45
49
|
**class_args,
|
|
46
50
|
):
|
|
47
51
|
"""
|
mlrun/frameworks/lgbm/utils.py
CHANGED
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
|
-
from typing import Union
|
|
15
|
+
from typing import Optional, Union
|
|
16
16
|
|
|
17
17
|
import lightgbm as lgb
|
|
18
18
|
import numpy as np
|
|
@@ -109,7 +109,7 @@ class LGBMUtils(MLUtils):
|
|
|
109
109
|
def get_algorithm_functionality(
|
|
110
110
|
model: MLTypes.ModelType = None,
|
|
111
111
|
y: MLTypes.DatasetType = None,
|
|
112
|
-
objective: str = None,
|
|
112
|
+
objective: Optional[str] = None,
|
|
113
113
|
) -> AlgorithmFunctionality:
|
|
114
114
|
"""
|
|
115
115
|
Get the algorithm functionality of the LightGBM model. If SciKit-Learn API is used, pass the LGBBMModel and a y
|
|
@@ -11,8 +11,7 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
# flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx
|
|
14
|
+
|
|
16
15
|
from .dataset import ONNXDataset
|
|
17
16
|
from .model_handler import ONNXModelHandler
|
|
18
17
|
from .model_server import ONNXModelServer
|
mlrun/frameworks/onnx/dataset.py
CHANGED
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
15
|
import math
|
|
16
|
-
from typing import Callable, Union
|
|
16
|
+
from typing import Callable, Optional, Union
|
|
17
17
|
|
|
18
18
|
import numpy as np
|
|
19
19
|
|
|
@@ -28,8 +28,8 @@ class ONNXDataset:
|
|
|
28
28
|
x: Union[np.ndarray, list[np.ndarray]],
|
|
29
29
|
y: Union[np.ndarray, list[np.ndarray]] = None,
|
|
30
30
|
batch_size: int = 1,
|
|
31
|
-
x_transforms: list[Callable[[np.ndarray], np.ndarray]] = None,
|
|
32
|
-
y_transforms: list[Callable[[np.ndarray], np.ndarray]] = None,
|
|
31
|
+
x_transforms: Optional[list[Callable[[np.ndarray], np.ndarray]]] = None,
|
|
32
|
+
y_transforms: Optional[list[Callable[[np.ndarray], np.ndarray]]] = None,
|
|
33
33
|
is_batched_transforms: bool = False,
|
|
34
34
|
):
|
|
35
35
|
"""
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
|
-
from typing import Callable
|
|
15
|
+
from typing import Callable, Optional
|
|
16
16
|
|
|
17
17
|
import numpy as np
|
|
18
18
|
import onnx
|
|
@@ -35,7 +35,7 @@ class ONNXMLRunInterface:
|
|
|
35
35
|
def __init__(
|
|
36
36
|
self,
|
|
37
37
|
model: onnx.ModelProto,
|
|
38
|
-
execution_providers: list[str] = None,
|
|
38
|
+
execution_providers: Optional[list[str]] = None,
|
|
39
39
|
context: mlrun.MLClientCtx = None,
|
|
40
40
|
):
|
|
41
41
|
# Set the context:
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
15
|
import os
|
|
16
|
-
from typing import Union
|
|
16
|
+
from typing import Optional, Union
|
|
17
17
|
|
|
18
18
|
import onnx
|
|
19
19
|
import onnxoptimizer
|
|
@@ -35,8 +35,8 @@ class ONNXModelHandler(ModelHandler):
|
|
|
35
35
|
def __init__(
|
|
36
36
|
self,
|
|
37
37
|
model: onnx.ModelProto = None,
|
|
38
|
-
model_path: str = None,
|
|
39
|
-
model_name: str = None,
|
|
38
|
+
model_path: Optional[str] = None,
|
|
39
|
+
model_name: Optional[str] = None,
|
|
40
40
|
context: mlrun.MLClientCtx = None,
|
|
41
41
|
**kwargs,
|
|
42
42
|
):
|
|
@@ -70,7 +70,7 @@ class ONNXModelHandler(ModelHandler):
|
|
|
70
70
|
|
|
71
71
|
# TODO: output_path won't work well with logging artifacts. Need to look into changing the logic of 'log_artifact'.
|
|
72
72
|
def save(
|
|
73
|
-
self, output_path: str = None, **kwargs
|
|
73
|
+
self, output_path: Optional[str] = None, **kwargs
|
|
74
74
|
) -> Union[dict[str, Artifact], None]:
|
|
75
75
|
"""
|
|
76
76
|
Save the handled model at the given output path. If a MLRun context is available, the saved model files will be
|
|
@@ -106,7 +106,9 @@ class ONNXModelHandler(ModelHandler):
|
|
|
106
106
|
# Load the ONNX model:
|
|
107
107
|
self._model = onnx.load(self._model_file)
|
|
108
108
|
|
|
109
|
-
def optimize(
|
|
109
|
+
def optimize(
|
|
110
|
+
self, optimizations: Optional[list[str]] = None, fixed_point: bool = False
|
|
111
|
+
):
|
|
110
112
|
"""
|
|
111
113
|
Use ONNX optimizer to optimize the ONNX model. The optimizations supported can be seen by calling
|
|
112
114
|
'onnxoptimizer.get_available_passes()'
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
|
-
from typing import Any, Union
|
|
15
|
+
from typing import Any, Optional, Union
|
|
16
16
|
|
|
17
17
|
import numpy as np
|
|
18
18
|
import onnx
|
|
@@ -33,12 +33,14 @@ class ONNXModelServer(V2ModelServer):
|
|
|
33
33
|
def __init__(
|
|
34
34
|
self,
|
|
35
35
|
context: mlrun.MLClientCtx = None,
|
|
36
|
-
name: str = None,
|
|
36
|
+
name: Optional[str] = None,
|
|
37
37
|
model: onnx.ModelProto = None,
|
|
38
|
-
model_path: str = None,
|
|
39
|
-
model_name: str = None,
|
|
40
|
-
execution_providers:
|
|
41
|
-
|
|
38
|
+
model_path: Optional[str] = None,
|
|
39
|
+
model_name: Optional[str] = None,
|
|
40
|
+
execution_providers: Optional[
|
|
41
|
+
list[Union[str, tuple[str, dict[str, Any]]]]
|
|
42
|
+
] = None,
|
|
43
|
+
protocol: Optional[str] = None,
|
|
42
44
|
**class_args,
|
|
43
45
|
):
|
|
44
46
|
"""
|
|
@@ -11,10 +11,10 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
import datetime
|
|
14
|
+
|
|
16
15
|
import os
|
|
17
|
-
from
|
|
16
|
+
from datetime import datetime
|
|
17
|
+
from typing import Optional, Union
|
|
18
18
|
|
|
19
19
|
import numpy as np
|
|
20
20
|
import pandas as pd
|
|
@@ -48,7 +48,7 @@ def _gen_dropdown_buttons(output_cols) -> list:
|
|
|
48
48
|
|
|
49
49
|
|
|
50
50
|
def _gen_dimensions(
|
|
51
|
-
df: pd.DataFrame, col: str, prefix: str = None, is_index=False
|
|
51
|
+
df: pd.DataFrame, col: str, prefix: Optional[str] = None, is_index=False
|
|
52
52
|
) -> dict:
|
|
53
53
|
"""
|
|
54
54
|
Computes the plotting dimensions of each parameter/output col according to its type.
|
|
@@ -107,8 +107,8 @@ def gen_pcp_plot(
|
|
|
107
107
|
source_df: pd.DataFrame,
|
|
108
108
|
index_col: str,
|
|
109
109
|
hide_identical: bool = True,
|
|
110
|
-
exclude: list = None,
|
|
111
|
-
colorscale: str = None,
|
|
110
|
+
exclude: Optional[list] = None,
|
|
111
|
+
colorscale: Optional[str] = None,
|
|
112
112
|
):
|
|
113
113
|
"""
|
|
114
114
|
Creates a list composed of the data to be plotted as a Parallel Coordinate, this includes
|
|
@@ -240,11 +240,11 @@ def _runs_list_to_df(runs_list, extend_iterations=False):
|
|
|
240
240
|
def compare_run_objects(
|
|
241
241
|
runs_list: Union[mlrun.model.RunObject, list[mlrun.model.RunObject]],
|
|
242
242
|
hide_identical: bool = True,
|
|
243
|
-
exclude: list = None,
|
|
244
|
-
show: bool = None,
|
|
243
|
+
exclude: Optional[list] = None,
|
|
244
|
+
show: Optional[bool] = None,
|
|
245
245
|
extend_iterations=True,
|
|
246
246
|
filename=None,
|
|
247
|
-
colorscale: str = None,
|
|
247
|
+
colorscale: Optional[str] = None,
|
|
248
248
|
):
|
|
249
249
|
"""return/show parallel coordinates plot + table to compare between a list of runs or run iterations
|
|
250
250
|
|
|
@@ -292,9 +292,9 @@ def compare_db_runs(
|
|
|
292
292
|
run_name=None,
|
|
293
293
|
labels=None,
|
|
294
294
|
iter=False,
|
|
295
|
-
start_time_from: datetime = None,
|
|
295
|
+
start_time_from: Optional[datetime] = None,
|
|
296
296
|
hide_identical: bool = True,
|
|
297
|
-
exclude: list = None,
|
|
297
|
+
exclude: Optional[list] = None,
|
|
298
298
|
show=None,
|
|
299
299
|
colorscale: str = "Blues",
|
|
300
300
|
filename=None,
|
|
@@ -11,9 +11,8 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
from typing import Any, Union
|
|
14
|
+
|
|
15
|
+
from typing import Any, Optional, Union
|
|
17
16
|
|
|
18
17
|
from torch.nn import Module
|
|
19
18
|
from torch.optim import Optimizer
|
|
@@ -35,23 +34,23 @@ def train(
|
|
|
35
34
|
loss_function: Module,
|
|
36
35
|
optimizer: Optimizer,
|
|
37
36
|
validation_set: DataLoader = None,
|
|
38
|
-
metric_functions: list[PyTorchTypes.MetricFunctionType] = None,
|
|
37
|
+
metric_functions: Optional[list[PyTorchTypes.MetricFunctionType]] = None,
|
|
39
38
|
scheduler=None,
|
|
40
39
|
scheduler_step_frequency: Union[int, float, str] = "epoch",
|
|
41
40
|
epochs: int = 1,
|
|
42
|
-
training_iterations: int = None,
|
|
43
|
-
validation_iterations: int = None,
|
|
44
|
-
callbacks_list: list[Callback] = None,
|
|
41
|
+
training_iterations: Optional[int] = None,
|
|
42
|
+
validation_iterations: Optional[int] = None,
|
|
43
|
+
callbacks_list: Optional[list[Callback]] = None,
|
|
45
44
|
use_cuda: bool = True,
|
|
46
|
-
use_horovod: bool = None,
|
|
45
|
+
use_horovod: Optional[bool] = None,
|
|
47
46
|
auto_log: bool = True,
|
|
48
|
-
model_name: str = None,
|
|
49
|
-
modules_map: Union[dict[str, Union[None, str, list[str]]], str] = None,
|
|
50
|
-
custom_objects_map: Union[dict[str, Union[str, list[str]]], str] = None,
|
|
51
|
-
custom_objects_directory: str = None,
|
|
52
|
-
tensorboard_directory: str = None,
|
|
53
|
-
mlrun_callback_kwargs: dict[str, Any] = None,
|
|
54
|
-
tensorboard_callback_kwargs: dict[str, Any] = None,
|
|
47
|
+
model_name: Optional[str] = None,
|
|
48
|
+
modules_map: Optional[Union[dict[str, Union[None, str, list[str]]], str]] = None,
|
|
49
|
+
custom_objects_map: Optional[Union[dict[str, Union[str, list[str]]], str]] = None,
|
|
50
|
+
custom_objects_directory: Optional[str] = None,
|
|
51
|
+
tensorboard_directory: Optional[str] = None,
|
|
52
|
+
mlrun_callback_kwargs: Optional[dict[str, Any]] = None,
|
|
53
|
+
tensorboard_callback_kwargs: Optional[dict[str, Any]] = None,
|
|
55
54
|
context: mlrun.MLClientCtx = None,
|
|
56
55
|
) -> PyTorchModelHandler:
|
|
57
56
|
"""
|
|
@@ -205,17 +204,17 @@ def evaluate(
|
|
|
205
204
|
dataset: DataLoader,
|
|
206
205
|
model: Module = None,
|
|
207
206
|
loss_function: Module = None,
|
|
208
|
-
metric_functions: list[PyTorchTypes.MetricFunctionType] = None,
|
|
209
|
-
iterations: int = None,
|
|
210
|
-
callbacks_list: list[Callback] = None,
|
|
207
|
+
metric_functions: Optional[list[PyTorchTypes.MetricFunctionType]] = None,
|
|
208
|
+
iterations: Optional[int] = None,
|
|
209
|
+
callbacks_list: Optional[list[Callback]] = None,
|
|
211
210
|
use_cuda: bool = True,
|
|
212
211
|
use_horovod: bool = False,
|
|
213
212
|
auto_log: bool = True,
|
|
214
|
-
model_name: str = None,
|
|
215
|
-
modules_map: Union[dict[str, Union[None, str, list[str]]], str] = None,
|
|
216
|
-
custom_objects_map: Union[dict[str, Union[str, list[str]]], str] = None,
|
|
217
|
-
custom_objects_directory: str = None,
|
|
218
|
-
mlrun_callback_kwargs: dict[str, Any] = None,
|
|
213
|
+
model_name: Optional[str] = None,
|
|
214
|
+
modules_map: Optional[Union[dict[str, Union[None, str, list[str]]], str]] = None,
|
|
215
|
+
custom_objects_map: Optional[Union[dict[str, Union[str, list[str]]], str]] = None,
|
|
216
|
+
custom_objects_directory: Optional[str] = None,
|
|
217
|
+
mlrun_callback_kwargs: Optional[dict[str, Any]] = None,
|
|
219
218
|
context: mlrun.MLClientCtx = None,
|
|
220
219
|
) -> tuple[PyTorchModelHandler, list[PyTorchTypes.MetricValueType]]:
|
|
221
220
|
"""
|
|
@@ -11,8 +11,7 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
# flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx
|
|
14
|
+
|
|
16
15
|
from .callback import Callback
|
|
17
16
|
from .logging_callback import HyperparametersKeys, LoggingCallback
|
|
18
17
|
from .mlrun_logging_callback import MLRunLoggingCallback
|
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
15
|
from abc import ABC, abstractmethod
|
|
16
|
+
from typing import Optional
|
|
16
17
|
|
|
17
18
|
from torch import Tensor
|
|
18
19
|
from torch.nn import Module
|
|
@@ -67,7 +68,7 @@ class Callback(ABC):
|
|
|
67
68
|
validation_set: DataLoader = None,
|
|
68
69
|
loss_function: Module = None,
|
|
69
70
|
optimizer: Optimizer = None,
|
|
70
|
-
metric_functions: list[PyTorchTypes.MetricFunctionType] = None,
|
|
71
|
+
metric_functions: Optional[list[PyTorchTypes.MetricFunctionType]] = None,
|
|
71
72
|
scheduler=None,
|
|
72
73
|
):
|
|
73
74
|
"""
|