mlrun 1.7.2rc3__py3-none-any.whl → 1.8.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mlrun might be problematic.
- mlrun/__init__.py +18 -18
- mlrun/__main__.py +3 -3
- mlrun/alerts/alert.py +19 -12
- mlrun/artifacts/__init__.py +0 -2
- mlrun/artifacts/base.py +34 -11
- mlrun/artifacts/dataset.py +16 -16
- mlrun/artifacts/manager.py +13 -13
- mlrun/artifacts/model.py +66 -53
- mlrun/common/constants.py +6 -0
- mlrun/common/formatters/__init__.py +1 -0
- mlrun/common/formatters/feature_set.py +1 -0
- mlrun/common/formatters/function.py +1 -0
- mlrun/common/formatters/model_endpoint.py +30 -0
- mlrun/common/formatters/pipeline.py +1 -2
- mlrun/common/formatters/project.py +9 -0
- mlrun/common/model_monitoring/__init__.py +0 -3
- mlrun/common/model_monitoring/helpers.py +1 -1
- mlrun/common/runtimes/constants.py +1 -2
- mlrun/common/schemas/__init__.py +7 -2
- mlrun/common/schemas/alert.py +31 -18
- mlrun/common/schemas/api_gateway.py +3 -3
- mlrun/common/schemas/artifact.py +7 -13
- mlrun/common/schemas/auth.py +6 -4
- mlrun/common/schemas/background_task.py +7 -7
- mlrun/common/schemas/client_spec.py +2 -2
- mlrun/common/schemas/clusterization_spec.py +2 -2
- mlrun/common/schemas/common.py +53 -3
- mlrun/common/schemas/datastore_profile.py +1 -1
- mlrun/common/schemas/feature_store.py +9 -9
- mlrun/common/schemas/frontend_spec.py +4 -4
- mlrun/common/schemas/function.py +10 -10
- mlrun/common/schemas/hub.py +1 -1
- mlrun/common/schemas/k8s.py +3 -3
- mlrun/common/schemas/memory_reports.py +3 -3
- mlrun/common/schemas/model_monitoring/__init__.py +8 -1
- mlrun/common/schemas/model_monitoring/constants.py +62 -12
- mlrun/common/schemas/model_monitoring/grafana.py +1 -1
- mlrun/common/schemas/model_monitoring/model_endpoint_v2.py +149 -0
- mlrun/common/schemas/model_monitoring/model_endpoints.py +22 -6
- mlrun/common/schemas/notification.py +18 -3
- mlrun/common/schemas/object.py +1 -1
- mlrun/common/schemas/pagination.py +4 -4
- mlrun/common/schemas/partition.py +137 -0
- mlrun/common/schemas/pipeline.py +2 -2
- mlrun/common/schemas/project.py +22 -17
- mlrun/common/schemas/runs.py +2 -2
- mlrun/common/schemas/runtime_resource.py +5 -5
- mlrun/common/schemas/schedule.py +1 -1
- mlrun/common/schemas/secret.py +1 -1
- mlrun/common/schemas/tag.py +3 -3
- mlrun/common/schemas/workflow.py +5 -5
- mlrun/config.py +65 -15
- mlrun/data_types/__init__.py +0 -2
- mlrun/data_types/data_types.py +0 -1
- mlrun/data_types/infer.py +3 -1
- mlrun/data_types/spark.py +4 -4
- mlrun/data_types/to_pandas.py +2 -11
- mlrun/datastore/__init__.py +0 -2
- mlrun/datastore/alibaba_oss.py +4 -1
- mlrun/datastore/azure_blob.py +4 -1
- mlrun/datastore/base.py +12 -4
- mlrun/datastore/datastore.py +9 -3
- mlrun/datastore/datastore_profile.py +20 -20
- mlrun/datastore/dbfs_store.py +4 -1
- mlrun/datastore/filestore.py +4 -1
- mlrun/datastore/google_cloud_storage.py +4 -1
- mlrun/datastore/hdfs.py +4 -1
- mlrun/datastore/inmem.py +4 -1
- mlrun/datastore/redis.py +4 -1
- mlrun/datastore/s3.py +4 -1
- mlrun/datastore/sources.py +51 -49
- mlrun/datastore/store_resources.py +0 -2
- mlrun/datastore/targets.py +22 -23
- mlrun/datastore/utils.py +2 -2
- mlrun/datastore/v3io.py +4 -1
- mlrun/datastore/wasbfs/fs.py +13 -12
- mlrun/db/base.py +170 -64
- mlrun/db/factory.py +3 -0
- mlrun/db/httpdb.py +986 -238
- mlrun/db/nopdb.py +155 -57
- mlrun/errors.py +2 -2
- mlrun/execution.py +55 -29
- mlrun/feature_store/__init__.py +0 -2
- mlrun/feature_store/api.py +40 -40
- mlrun/feature_store/common.py +9 -9
- mlrun/feature_store/feature_set.py +20 -18
- mlrun/feature_store/feature_vector.py +27 -24
- mlrun/feature_store/retrieval/base.py +14 -9
- mlrun/feature_store/retrieval/job.py +2 -1
- mlrun/feature_store/steps.py +2 -2
- mlrun/features.py +30 -13
- mlrun/frameworks/__init__.py +1 -2
- mlrun/frameworks/_common/__init__.py +1 -2
- mlrun/frameworks/_common/artifacts_library.py +2 -2
- mlrun/frameworks/_common/mlrun_interface.py +10 -6
- mlrun/frameworks/_common/model_handler.py +29 -27
- mlrun/frameworks/_common/producer.py +3 -1
- mlrun/frameworks/_dl_common/__init__.py +1 -2
- mlrun/frameworks/_dl_common/loggers/__init__.py +1 -2
- mlrun/frameworks/_dl_common/loggers/mlrun_logger.py +4 -4
- mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py +3 -3
- mlrun/frameworks/_ml_common/__init__.py +1 -2
- mlrun/frameworks/_ml_common/loggers/__init__.py +1 -2
- mlrun/frameworks/_ml_common/model_handler.py +21 -21
- mlrun/frameworks/_ml_common/plans/__init__.py +1 -2
- mlrun/frameworks/_ml_common/plans/confusion_matrix_plan.py +3 -1
- mlrun/frameworks/_ml_common/plans/dataset_plan.py +3 -3
- mlrun/frameworks/_ml_common/plans/roc_curve_plan.py +4 -4
- mlrun/frameworks/auto_mlrun/__init__.py +1 -2
- mlrun/frameworks/auto_mlrun/auto_mlrun.py +22 -15
- mlrun/frameworks/huggingface/__init__.py +1 -2
- mlrun/frameworks/huggingface/model_server.py +9 -9
- mlrun/frameworks/lgbm/__init__.py +47 -44
- mlrun/frameworks/lgbm/callbacks/__init__.py +1 -2
- mlrun/frameworks/lgbm/callbacks/logging_callback.py +4 -2
- mlrun/frameworks/lgbm/callbacks/mlrun_logging_callback.py +4 -2
- mlrun/frameworks/lgbm/mlrun_interfaces/__init__.py +1 -2
- mlrun/frameworks/lgbm/mlrun_interfaces/mlrun_interface.py +5 -5
- mlrun/frameworks/lgbm/model_handler.py +15 -11
- mlrun/frameworks/lgbm/model_server.py +11 -7
- mlrun/frameworks/lgbm/utils.py +2 -2
- mlrun/frameworks/onnx/__init__.py +1 -2
- mlrun/frameworks/onnx/dataset.py +3 -3
- mlrun/frameworks/onnx/mlrun_interface.py +2 -2
- mlrun/frameworks/onnx/model_handler.py +7 -5
- mlrun/frameworks/onnx/model_server.py +8 -6
- mlrun/frameworks/parallel_coordinates.py +11 -11
- mlrun/frameworks/pytorch/__init__.py +22 -23
- mlrun/frameworks/pytorch/callbacks/__init__.py +1 -2
- mlrun/frameworks/pytorch/callbacks/callback.py +2 -1
- mlrun/frameworks/pytorch/callbacks/logging_callback.py +15 -8
- mlrun/frameworks/pytorch/callbacks/mlrun_logging_callback.py +19 -12
- mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py +22 -15
- mlrun/frameworks/pytorch/callbacks_handler.py +36 -30
- mlrun/frameworks/pytorch/mlrun_interface.py +17 -17
- mlrun/frameworks/pytorch/model_handler.py +21 -17
- mlrun/frameworks/pytorch/model_server.py +13 -9
- mlrun/frameworks/sklearn/__init__.py +19 -18
- mlrun/frameworks/sklearn/estimator.py +2 -2
- mlrun/frameworks/sklearn/metric.py +3 -3
- mlrun/frameworks/sklearn/metrics_library.py +8 -6
- mlrun/frameworks/sklearn/mlrun_interface.py +3 -2
- mlrun/frameworks/sklearn/model_handler.py +4 -3
- mlrun/frameworks/tf_keras/__init__.py +11 -12
- mlrun/frameworks/tf_keras/callbacks/__init__.py +1 -2
- mlrun/frameworks/tf_keras/callbacks/logging_callback.py +17 -14
- mlrun/frameworks/tf_keras/callbacks/mlrun_logging_callback.py +15 -12
- mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py +21 -18
- mlrun/frameworks/tf_keras/model_handler.py +17 -13
- mlrun/frameworks/tf_keras/model_server.py +12 -8
- mlrun/frameworks/xgboost/__init__.py +19 -18
- mlrun/frameworks/xgboost/model_handler.py +13 -9
- mlrun/launcher/base.py +3 -4
- mlrun/launcher/local.py +1 -1
- mlrun/launcher/remote.py +1 -1
- mlrun/lists.py +4 -3
- mlrun/model.py +110 -46
- mlrun/model_monitoring/__init__.py +1 -2
- mlrun/model_monitoring/api.py +6 -6
- mlrun/model_monitoring/applications/_application_steps.py +13 -15
- mlrun/model_monitoring/applications/histogram_data_drift.py +41 -15
- mlrun/model_monitoring/applications/results.py +55 -3
- mlrun/model_monitoring/controller.py +185 -223
- mlrun/model_monitoring/db/_schedules.py +156 -0
- mlrun/model_monitoring/db/_stats.py +189 -0
- mlrun/model_monitoring/db/stores/__init__.py +1 -1
- mlrun/model_monitoring/db/stores/base/store.py +6 -65
- mlrun/model_monitoring/db/stores/sqldb/models/__init__.py +0 -25
- mlrun/model_monitoring/db/stores/sqldb/models/base.py +0 -97
- mlrun/model_monitoring/db/stores/sqldb/models/mysql.py +2 -58
- mlrun/model_monitoring/db/stores/sqldb/models/sqlite.py +0 -15
- mlrun/model_monitoring/db/stores/sqldb/sql_store.py +6 -257
- mlrun/model_monitoring/db/stores/v3io_kv/kv_store.py +9 -271
- mlrun/model_monitoring/db/tsdb/base.py +76 -24
- mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +61 -6
- mlrun/model_monitoring/db/tsdb/tdengine/stream_graph_steps.py +33 -0
- mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +253 -28
- mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +1 -0
- mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +35 -17
- mlrun/model_monitoring/helpers.py +91 -1
- mlrun/model_monitoring/model_endpoint.py +4 -2
- mlrun/model_monitoring/stream_processing.py +16 -13
- mlrun/model_monitoring/tracking_policy.py +10 -3
- mlrun/model_monitoring/writer.py +47 -26
- mlrun/package/__init__.py +3 -6
- mlrun/package/context_handler.py +1 -1
- mlrun/package/packager.py +12 -9
- mlrun/package/packagers/__init__.py +0 -2
- mlrun/package/packagers/default_packager.py +14 -11
- mlrun/package/packagers/numpy_packagers.py +16 -7
- mlrun/package/packagers/pandas_packagers.py +18 -18
- mlrun/package/packagers/python_standard_library_packagers.py +25 -11
- mlrun/package/packagers_manager.py +31 -14
- mlrun/package/utils/__init__.py +0 -3
- mlrun/package/utils/_pickler.py +6 -6
- mlrun/platforms/__init__.py +3 -16
- mlrun/platforms/iguazio.py +4 -1
- mlrun/projects/operations.py +27 -27
- mlrun/projects/pipelines.py +34 -35
- mlrun/projects/project.py +535 -182
- mlrun/run.py +13 -10
- mlrun/runtimes/__init__.py +1 -3
- mlrun/runtimes/base.py +15 -11
- mlrun/runtimes/daskjob.py +9 -9
- mlrun/runtimes/generators.py +2 -1
- mlrun/runtimes/kubejob.py +4 -5
- mlrun/runtimes/mounts.py +572 -0
- mlrun/runtimes/mpijob/__init__.py +0 -2
- mlrun/runtimes/mpijob/abstract.py +7 -6
- mlrun/runtimes/nuclio/api_gateway.py +7 -7
- mlrun/runtimes/nuclio/application/application.py +11 -11
- mlrun/runtimes/nuclio/function.py +13 -13
- mlrun/runtimes/nuclio/serving.py +9 -9
- mlrun/runtimes/pod.py +154 -45
- mlrun/runtimes/remotesparkjob.py +3 -2
- mlrun/runtimes/sparkjob/__init__.py +0 -2
- mlrun/runtimes/sparkjob/spark3job.py +21 -11
- mlrun/runtimes/utils.py +6 -5
- mlrun/serving/merger.py +6 -4
- mlrun/serving/remote.py +18 -17
- mlrun/serving/routers.py +27 -27
- mlrun/serving/server.py +1 -1
- mlrun/serving/states.py +76 -71
- mlrun/serving/utils.py +13 -2
- mlrun/serving/v1_serving.py +3 -2
- mlrun/serving/v2_serving.py +4 -4
- mlrun/track/__init__.py +1 -1
- mlrun/track/tracker.py +2 -2
- mlrun/track/trackers/mlflow_tracker.py +6 -5
- mlrun/utils/async_http.py +1 -1
- mlrun/utils/helpers.py +70 -16
- mlrun/utils/logger.py +106 -4
- mlrun/utils/notifications/notification/__init__.py +22 -19
- mlrun/utils/notifications/notification/base.py +33 -14
- mlrun/utils/notifications/notification/console.py +6 -6
- mlrun/utils/notifications/notification/git.py +11 -11
- mlrun/utils/notifications/notification/ipython.py +10 -9
- mlrun/utils/notifications/notification/mail.py +149 -0
- mlrun/utils/notifications/notification/slack.py +6 -6
- mlrun/utils/notifications/notification/webhook.py +18 -22
- mlrun/utils/notifications/notification_pusher.py +43 -31
- mlrun/utils/regex.py +3 -1
- mlrun/utils/version/version.json +2 -2
- {mlrun-1.7.2rc3.dist-info → mlrun-1.8.0rc2.dist-info}/METADATA +18 -14
- mlrun-1.8.0rc2.dist-info/RECORD +358 -0
- {mlrun-1.7.2rc3.dist-info → mlrun-1.8.0rc2.dist-info}/WHEEL +1 -1
- mlrun-1.7.2rc3.dist-info/RECORD +0 -351
- {mlrun-1.7.2rc3.dist-info → mlrun-1.8.0rc2.dist-info}/LICENSE +0 -0
- {mlrun-1.7.2rc3.dist-info → mlrun-1.8.0rc2.dist-info}/entry_points.txt +0 -0
- {mlrun-1.7.2rc3.dist-info → mlrun-1.8.0rc2.dist-info}/top_level.txt +0 -0

mlrun/frameworks/sklearn/__init__.py

@@ -11,10 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-# flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx
+
 import warnings
-from typing import Union
+from typing import Optional, Union

 import mlrun
 from mlrun.frameworks.sklearn.metric import Metric
@@ -36,26 +35,28 @@ def apply_mlrun(
     model: SKLearnTypes.ModelType = None,
     model_name: str = "model",
     tag: str = "",
-    model_path: str = None,
-    modules_map: Union[dict[str, Union[None, str, list[str]]], str] = None,
-    custom_objects_map: Union[dict[str, Union[str, list[str]]], str] = None,
-    custom_objects_directory: str = None,
+    model_path: Optional[str] = None,
+    modules_map: Optional[Union[dict[str, Union[None, str, list[str]]], str]] = None,
+    custom_objects_map: Optional[Union[dict[str, Union[str, list[str]]], str]] = None,
+    custom_objects_directory: Optional[str] = None,
     context: mlrun.MLClientCtx = None,
-    artifacts: Union[list[MLPlan], list[str], dict[str, dict]] = None,
-    metrics: Union[
-        list[Metric],
-        list[SKLearnTypes.MetricEntryType],
-        dict[str, SKLearnTypes.MetricEntryType],
+    artifacts: Optional[Union[list[MLPlan], list[str], dict[str, dict]]] = None,
+    metrics: Optional[
+        Union[
+            list[Metric],
+            list[SKLearnTypes.MetricEntryType],
+            dict[str, SKLearnTypes.MetricEntryType],
+        ]
     ] = None,
     x_test: SKLearnTypes.DatasetType = None,
     y_test: SKLearnTypes.DatasetType = None,
     sample_set: Union[SKLearnTypes.DatasetType, mlrun.DataItem, str] = None,
-    y_columns: Union[list[str], list[int]] = None,
-    feature_vector: str = None,
-    feature_weights: list[float] = None,
-    labels: dict[str, Union[str, int, float]] = None,
-    parameters: dict[str, Union[str, int, float]] = None,
-    extra_data: dict[str, SKLearnTypes.ExtraDataType] = None,
+    y_columns: Optional[Union[list[str], list[int]]] = None,
+    feature_vector: Optional[str] = None,
+    feature_weights: Optional[list[float]] = None,
+    labels: Optional[dict[str, Union[str, int, float]]] = None,
+    parameters: Optional[dict[str, Union[str, int, float]]] = None,
+    extra_data: Optional[dict[str, SKLearnTypes.ExtraDataType]] = None,
     auto_log: bool = True,
     **kwargs,
 ) -> SKLearnModelHandler:
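
As a usage illustration (not part of the diff): a minimal sketch of calling the `apply_mlrun` signature shown above, assuming an environment where MLRun can create a context; the toy dataset and variable names are placeholders.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

from mlrun.frameworks.sklearn import apply_mlrun

x, y = load_iris(return_X_y=True, as_frame=True)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)

model = RandomForestClassifier()
# Keyword arguments below correspond to parameters in the signature above;
# omitting (or passing None for) the Optional[...] ones behaves as in 1.7.x.
apply_mlrun(
    model=model,
    model_name="iris-rf",
    x_test=x_test,
    y_test=y_test,
    labels={"framework": "sklearn"},
)
model.fit(x_train, y_train)
```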

mlrun/frameworks/sklearn/estimator.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Union
+from typing import Optional, Union

 import numpy as np
 import pandas as pd
@@ -32,7 +32,7 @@ class Estimator:
     def __init__(
         self,
         context: mlrun.MLClientCtx = None,
-        metrics: list[Metric] = None,
+        metrics: Optional[list[Metric]] = None,
     ):
         """
         Initialize an estimator with the given metrics. The estimator will log the calculated results using the given

mlrun/frameworks/sklearn/metric.py

@@ -15,7 +15,7 @@
 import importlib
 import json
 import sys
-from typing import Callable, Union
+from typing import Callable, Optional, Union

 import mlrun.errors

@@ -31,8 +31,8 @@ class Metric:
     def __init__(
         self,
         metric: Union[Callable, str],
-        name: str = None,
-        additional_arguments: dict = None,
+        name: Optional[str] = None,
+        additional_arguments: Optional[dict] = None,
         need_probabilities: bool = False,
     ):
         """

mlrun/frameworks/sklearn/metrics_library.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 #
 from abc import ABC
-from typing import Union
+from typing import Optional, Union

 import sklearn
 from sklearn.preprocessing import LabelBinarizer
@@ -39,10 +39,12 @@ class MetricsLibrary(ABC):
     @classmethod
     def get_metrics(
         cls,
-        metrics: Union[
-            list[Metric],
-            list[SKLearnTypes.MetricEntryType],
-            dict[str, SKLearnTypes.MetricEntryType],
+        metrics: Optional[
+            Union[
+                list[Metric],
+                list[SKLearnTypes.MetricEntryType],
+                dict[str, SKLearnTypes.MetricEntryType],
+            ]
         ] = None,
         context: mlrun.MLClientCtx = None,
         include_default: bool = True,
@@ -262,7 +264,7 @@ class MetricsLibrary(ABC):
     def _to_metric_class(
         cls,
         metric_entry: SKLearnTypes.MetricEntryType,
-        metric_name: str = None,
+        metric_name: Optional[str] = None,
     ) -> Metric:
         """
         Create a Metric instance from a user given metric entry.

mlrun/frameworks/sklearn/mlrun_interface.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 #
 from abc import ABC
+from typing import Optional

 import mlrun

@@ -161,8 +162,8 @@ class SKLearnMLRunInterface(MLRunInterface, ABC):
     def configure_logging(
         self,
         context: mlrun.MLClientCtx = None,
-        plans: list[MLPlan] = None,
-        metrics: list[Metric] = None,
+        plans: Optional[list[MLPlan]] = None,
+        metrics: Optional[list[Metric]] = None,
         x_test: SKLearnTypes.DatasetType = None,
         y_test: SKLearnTypes.DatasetType = None,
         model_handler: MLModelHandler = None,

mlrun/frameworks/sklearn/model_handler.py

@@ -14,6 +14,7 @@
 #
 import os
 import pickle
+from typing import Optional

 import cloudpickle

@@ -49,7 +50,7 @@ class SKLearnModelHandler(MLModelHandler):
     )

     @without_mlrun_interface(interface=SKLearnMLRunInterface)
-    def save(self, output_path: str = None, **kwargs):
+    def save(self, output_path: Optional[str] = None, **kwargs):
         """
         Save the handled model at the given output path. If a MLRun context is available, the saved model files will be
         logged and returned as artifacts.
@@ -81,10 +82,10 @@ class SKLearnModelHandler(MLModelHandler):

     def to_onnx(
         self,
-        model_name: str = None,
+        model_name: Optional[str] = None,
         optimize: bool = True,
         input_sample: SKLearnTypes.DatasetType = None,
-        log: bool = None,
+        log: Optional[bool] = None,
     ):
         """
         Convert the model in this handler to an ONNX model. The inputs names are optional, they do not change the
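
The `save` and `to_onnx` changes above follow the pattern repeated throughout this release: `None` defaults move from implicit to explicit `Optional`. A minimal sketch of the before/after, using a hypothetical function purely for illustration:

```python
from typing import Optional


# 1.7.x style: None default with a plain annotation (implicit Optional, which
# strict PEP 484 checkers such as recent mypy defaults reject).
def save_old(output_path: str = None, **kwargs):
    return output_path


# 1.8.0 style, as in SKLearnModelHandler.save above: identical runtime behaviour,
# but the type now states explicitly that None is allowed.
def save_new(output_path: Optional[str] = None, **kwargs):
    return output_path
```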

mlrun/frameworks/tf_keras/__init__.py

@@ -11,9 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-# flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx
-from typing import Any, Union
+
+from typing import Any, Optional, Union

 from tensorflow import keras

@@ -29,20 +28,20 @@ from .utils import TFKerasTypes, TFKerasUtils

 def apply_mlrun(
     model: keras.Model = None,
-    model_name: str = None,
+    model_name: Optional[str] = None,
     tag: str = "",
-    model_path: str = None,
+    model_path: Optional[str] = None,
     model_format: str = TFKerasModelHandler.ModelFormats.SAVED_MODEL,
     save_traces: bool = False,
-    modules_map: Union[dict[str, Union[None, str, list[str]]], str] = None,
-    custom_objects_map: Union[dict[str, Union[str, list[str]]], str] = None,
-    custom_objects_directory: str = None,
+    modules_map: Optional[Union[dict[str, Union[None, str, list[str]]], str]] = None,
+    custom_objects_map: Optional[Union[dict[str, Union[str, list[str]]], str]] = None,
+    custom_objects_directory: Optional[str] = None,
     context: mlrun.MLClientCtx = None,
     auto_log: bool = True,
-    tensorboard_directory: str = None,
-    mlrun_callback_kwargs: dict[str, Any] = None,
-    tensorboard_callback_kwargs: dict[str, Any] = None,
-    use_horovod: bool = None,
+    tensorboard_directory: Optional[str] = None,
+    mlrun_callback_kwargs: Optional[dict[str, Any]] = None,
+    tensorboard_callback_kwargs: Optional[dict[str, Any]] = None,
+    use_horovod: Optional[bool] = None,
     **kwargs,
 ) -> TFKerasModelHandler:
     """

mlrun/frameworks/tf_keras/callbacks/__init__.py

@@ -11,8 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-# flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx
+
 from .logging_callback import LoggingCallback
 from .mlrun_logging_callback import MLRunLoggingCallback
 from .tensorboard_logging_callback import TensorboardLoggingCallback

mlrun/frameworks/tf_keras/callbacks/logging_callback.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Callable, Union
+from typing import Callable, Optional, Union

 import numpy as np
 import tensorflow as tf
@@ -36,11 +36,14 @@ class LoggingCallback(Callback):
     def __init__(
         self,
         context: mlrun.MLClientCtx = None,
-        dynamic_hyperparameters: dict[
-            str, Union[list[Union[str, int]], Callable[[], TFKerasTypes.TrackableType]]
+        dynamic_hyperparameters: Optional[
+            dict[
+                str,
+                Union[list[Union[str, int]], Callable[[], TFKerasTypes.TrackableType]],
+            ]
         ] = None,
-        static_hyperparameters: dict[
-            str, Union[TFKerasTypes.TrackableType, list[Union[str, int]]]
+        static_hyperparameters: Optional[
+            dict[str, Union[TFKerasTypes.TrackableType, list[Union[str, int]]]]
         ] = None,
         auto_log: bool = False,
     ):
@@ -175,7 +178,7 @@ class LoggingCallback(Callback):
         """
         return self._logger.validation_iterations

-    def on_train_begin(self, logs: dict = None):
+    def on_train_begin(self, logs: Optional[dict] = None):
         """
         Called once at the beginning of training process (one time call).

@@ -185,7 +188,7 @@ class LoggingCallback(Callback):
         self._is_training = True
         self._setup_run()

-    def on_test_begin(self, logs: dict = None):
+    def on_test_begin(self, logs: Optional[dict] = None):
         """
         Called at the beginning of evaluation or validation. Will be called on each epoch according to the validation
         per epoch configuration.
@@ -202,7 +205,7 @@ class LoggingCallback(Callback):
         if not self._is_training:
             self._setup_run()

-    def on_test_end(self, logs: dict = None):
+    def on_test_end(self, logs: Optional[dict] = None):
         """
         Called at the end of evaluation or validation. Will be called on each epoch according to the validation
         per epoch configuration. The recent evaluation / validation results will be summarized and logged.
@@ -220,7 +223,7 @@ class LoggingCallback(Callback):
                 result=float(sum(epoch_values[-1]) / len(epoch_values[-1])),
             )

-    def on_epoch_begin(self, epoch: int, logs: dict = None):
+    def on_epoch_begin(self, epoch: int, logs: Optional[dict] = None):
         """
         Called at the start of an epoch, logging it and appending a new epoch to the logger's dictionaries.

@@ -236,7 +239,7 @@ class LoggingCallback(Callback):
         for metric in sum_dictionary:
             sum_dictionary[metric] = 0

-    def on_epoch_end(self, epoch: int, logs: dict = None):
+    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
         """
         Called at the end of an epoch, logging the training summaries and the current dynamic hyperparameters values.

@@ -262,7 +265,7 @@ class LoggingCallback(Callback):
                 value=self._get_hyperparameter(key_chain=key_chain),
             )

-    def on_train_batch_begin(self, batch: int, logs: dict = None):
+    def on_train_batch_begin(self, batch: int, logs: Optional[dict] = None):
         """
         Called at the beginning of a training batch in `fit` methods. The logger will check if this batch is needed to
         be logged according to the configuration. Note that if the `steps_per_execution` argument to `compile` in
@@ -274,7 +277,7 @@ class LoggingCallback(Callback):
         """
         self._logger.log_training_iteration()

-    def on_train_batch_end(self, batch: int, logs: dict = None):
+    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None):
         """
         Called at the end of a training batch in `fit` methods. The batch metrics results will be logged. Note that if
         the `steps_per_execution` argument to `compile` in `tf.keras.Model` is set to `N`, this method will only be
@@ -289,7 +292,7 @@ class LoggingCallback(Callback):
             logs=logs,
         )

-    def on_test_batch_begin(self, batch: int, logs: dict = None):
+    def on_test_batch_begin(self, batch: int, logs: Optional[dict] = None):
         """
         Called at the beginning of a batch in `evaluate` methods. Also called at the beginning of a validation batch in
         the `fit` methods, if validation data is provided. The logger will check if this batch is needed to be logged
@@ -302,7 +305,7 @@ class LoggingCallback(Callback):
         """
         self._logger.log_validation_iteration()

-    def on_test_batch_end(self, batch: int, logs: dict = None):
+    def on_test_batch_end(self, batch: int, logs: Optional[dict] = None):
         """
         Called at the end of a batch in `evaluate` methods. Also called at the end of a validation batch in the `fit`
         methods, if validation data is provided. The batch metrics results will be logged. Note that if the
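
For context (not from the diff): the `on_*` hooks above override the standard Keras callback interface, where `logs` is the metrics dictionary Keras passes in and may be absent, which is exactly what the new `Optional[dict]` annotation makes explicit. A minimal standalone sketch of that hook shape:

```python
from typing import Optional

import tensorflow as tf


class PrintLossCallback(tf.keras.callbacks.Callback):
    """Toy callback for illustration only; not part of mlrun."""

    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
        # Keras passes e.g. {"loss": ..., "accuracy": ...}; guard for None,
        # as the explicit Optional annotation now signals.
        if logs:
            print(f"epoch {epoch}: loss={logs.get('loss')}")
```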

mlrun/frameworks/tf_keras/callbacks/mlrun_logging_callback.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Callable, Union
+from typing import Callable, Optional, Union

 import mlrun
 from mlrun.artifacts import Artifact
@@ -50,16 +50,19 @@ class MLRunLoggingCallback(LoggingCallback):
         context: mlrun.MLClientCtx,
         model_handler: TFKerasModelHandler,
         log_model_tag: str = "",
-        log_model_labels: dict[str, TFKerasTypes.TrackableType] = None,
-        log_model_parameters: dict[str, TFKerasTypes.TrackableType] = None,
-        log_model_extra_data: dict[
-            str, Union[TFKerasTypes.TrackableType, Artifact]
+        log_model_labels: Optional[dict[str, TFKerasTypes.TrackableType]] = None,
+        log_model_parameters: Optional[dict[str, TFKerasTypes.TrackableType]] = None,
+        log_model_extra_data: Optional[
+            dict[str, Union[TFKerasTypes.TrackableType, Artifact]]
         ] = None,
-        dynamic_hyperparameters: dict[
-            str, Union[list[Union[str, int]], Callable[[], TFKerasTypes.TrackableType]]
+        dynamic_hyperparameters: Optional[
+            dict[
+                str,
+                Union[list[Union[str, int]], Callable[[], TFKerasTypes.TrackableType]],
+            ]
         ] = None,
-        static_hyperparameters: dict[
-            str, Union[TFKerasTypes, list[Union[str, int]]]
+        static_hyperparameters: Optional[
+            dict[str, Union[TFKerasTypes, list[Union[str, int]]]]
         ] = None,
         auto_log: bool = False,
     ):
@@ -116,7 +119,7 @@ class MLRunLoggingCallback(LoggingCallback):
         # Store the model handler:
         self._model_handler = model_handler

-    def on_train_end(self, logs: dict = None):
+    def on_train_end(self, logs: Optional[dict] = None):
         """
         Called at the end of training, logging the model and the summaries of this run.

@@ -125,7 +128,7 @@ class MLRunLoggingCallback(LoggingCallback):
         """
         self._end_run()

-    def on_test_end(self, logs: dict = None):
+    def on_test_end(self, logs: Optional[dict] = None):
         """
         Called at the end of evaluation or validation. Will be called on each epoch according to the validation
         per epoch configuration. The recent evaluation / validation results will be summarized and logged. If the logger
@@ -141,7 +144,7 @@ class MLRunLoggingCallback(LoggingCallback):
         self._logger.log_epoch_to_context(epoch=1)
         self._end_run()

-    def on_epoch_end(self, epoch: int, logs: dict = None):
+    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
         """
         Called at the end of an epoch, logging the dynamic hyperparameters and results of this epoch via the stored
         context.

mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 #
 from datetime import datetime
-from typing import Callable, Union
+from typing import Callable, Optional, Union

 import tensorflow as tf
 from packaging import version
@@ -40,8 +40,8 @@ class _TFKerasTensorboardLogger(TensorboardLogger):
         self,
         statistics_functions: list[Callable[[Union[Variable]], Union[float, Variable]]],
         context: mlrun.MLClientCtx = None,
-        tensorboard_directory: str = None,
-        run_name: str = None,
+        tensorboard_directory: Optional[str] = None,
+        run_name: Optional[str] = None,
         update_frequency: Union[int, str] = "epoch",
     ):
         """
@@ -253,17 +253,20 @@ class TensorboardLoggingCallback(LoggingCallback):
     def __init__(
         self,
         context: mlrun.MLClientCtx = None,
-        tensorboard_directory: str = None,
-        run_name: str = None,
+        tensorboard_directory: Optional[str] = None,
+        run_name: Optional[str] = None,
         weights: Union[bool, list[str]] = False,
-        statistics_functions: list[
-            Callable[[Union[Variable, Tensor]], Union[float, Tensor]]
+        statistics_functions: Optional[
+            list[Callable[[Union[Variable, Tensor]], Union[float, Tensor]]]
         ] = None,
-        dynamic_hyperparameters: dict[
-            str, Union[list[Union[str, int]], Callable[[], TFKerasTypes.TrackableType]]
+        dynamic_hyperparameters: Optional[
+            dict[
+                str,
+                Union[list[Union[str, int]], Callable[[], TFKerasTypes.TrackableType]],
+            ]
         ] = None,
-        static_hyperparameters: dict[
-            str, Union[TFKerasTypes.TrackableType, list[Union[str, int]]]
+        static_hyperparameters: Optional[
+            dict[str, Union[TFKerasTypes.TrackableType, list[Union[str, int]]]]
         ] = None,
         update_frequency: Union[int, str] = "epoch",
         auto_log: bool = False,
@@ -370,7 +373,7 @@ class TensorboardLoggingCallback(LoggingCallback):
         """
         return self._logger.weight_statistics

-    def on_train_begin(self, logs: dict = None):
+    def on_train_begin(self, logs: Optional[dict] = None):
         """
         Called once at the beginning of training process (one time call). Will log the pre-training (epoch 0)
         hyperparameters and weights.
@@ -401,7 +404,7 @@ class TensorboardLoggingCallback(LoggingCallback):
         # Make sure all values were written to the directory logs:
         self._logger.flush()

-    def on_train_end(self, logs: dict = None):
+    def on_train_end(self, logs: Optional[dict] = None):
         """
         Called at the end of training, wrapping up the tensorboard logging session.

@@ -416,7 +419,7 @@ class TensorboardLoggingCallback(LoggingCallback):
         # Close the logger:
         self._logger.close()

-    def on_test_begin(self, logs: dict = None):
+    def on_test_begin(self, logs: Optional[dict] = None):
         """
         Called at the beginning of evaluation or validation. Will be called on each epoch according to the validation
         per epoch configuration. In case it is an evaluation, the epoch 0 will be logged.
@@ -445,7 +448,7 @@ class TensorboardLoggingCallback(LoggingCallback):
         # Make sure all values were written to the directory logs:
         self._logger.flush()

-    def on_test_end(self, logs: dict = None):
+    def on_test_end(self, logs: Optional[dict] = None):
         """
         Called at the end of evaluation or validation. Will be called on each epoch according to the validation
         per epoch configuration. The recent evaluation / validation results will be summarized and logged.
@@ -466,7 +469,7 @@ class TensorboardLoggingCallback(LoggingCallback):
         # Close the logger:
         self._logger.close()

-    def on_epoch_end(self, epoch: int, logs: dict = None):
+    def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
         """
         Called at the end of an epoch, logging the current dynamic hyperparameters values, summaries and weights to
         tensorboard.
@@ -504,7 +507,7 @@ class TensorboardLoggingCallback(LoggingCallback):
         # Make sure all values were written to the directory logs:
         self._logger.flush()

-    def on_train_batch_end(self, batch: int, logs: dict = None):
+    def on_train_batch_end(self, batch: int, logs: Optional[dict] = None):
         """
         Called at the end of a training batch in `fit` methods. The batch metrics results will be logged. If it is the
         first batch to end, the model architecture and hyperparameters will be logged as well. Note that if the
@@ -526,7 +529,7 @@ class TensorboardLoggingCallback(LoggingCallback):
             self._logged_hyperparameters = True
             self._logger.write_dynamic_hyperparameters()

-    def on_test_batch_end(self, batch: int, logs: dict = None):
+    def on_test_batch_end(self, batch: int, logs: Optional[dict] = None):
         """
         Called at the end of a batch in `evaluate` methods. Also called at the end of a validation batch in the `fit`
         methods, if validation data is provided. The batch metrics results will be logged. In case it is an evaluation

mlrun/frameworks/tf_keras/model_handler.py

@@ -15,7 +15,7 @@
 import os
 import shutil
 import zipfile
-from typing import Union
+from typing import Optional, Union

 import numpy as np
 import tensorflow as tf
@@ -63,13 +63,17 @@ class TFKerasModelHandler(DLModelHandler):
     def __init__(
         self,
         model: keras.Model = None,
-        model_path: str = None,
-        model_name: str = None,
+        model_path: Optional[str] = None,
+        model_name: Optional[str] = None,
         model_format: str = ModelFormats.SAVED_MODEL,
         context: mlrun.MLClientCtx = None,
-        modules_map: Union[dict[str, Union[None, str, list[str]]], str] = None,
-        custom_objects_map: Union[dict[str, Union[str, list[str]]], str] = None,
-        custom_objects_directory: str = None,
+        modules_map: Optional[
+            Union[dict[str, Union[None, str, list[str]]], str]
+        ] = None,
+        custom_objects_map: Optional[
+            Union[dict[str, Union[str, list[str]]], str]
+        ] = None,
+        custom_objects_directory: Optional[str] = None,
         save_traces: bool = False,
         **kwargs,
     ):
@@ -190,8 +194,8 @@ class TFKerasModelHandler(DLModelHandler):

     def set_labels(
         self,
-        to_add: dict[str, Union[str, int, float]] = None,
-        to_remove: list[str] = None,
+        to_add: Optional[dict[str, Union[str, int, float]]] = None,
+        to_remove: Optional[list[str]] = None,
     ):
         """
         Update the labels dictionary of this model artifact. There are required labels that cannot be edited or removed.
@@ -210,7 +214,7 @@ class TFKerasModelHandler(DLModelHandler):
     # TODO: output_path won't work well with logging artifacts. Need to look into changing the logic of 'log_artifact'.
     @without_mlrun_interface(interface=TFKerasMLRunInterface)
     def save(
-        self, output_path: str = None, **kwargs
+        self, output_path: Optional[str] = None, **kwargs
     ) -> Union[dict[str, Artifact], None]:
         """
         Save the handled model at the given output path. If a MLRun context is available, the saved model files will be
@@ -274,7 +278,7 @@ class TFKerasModelHandler(DLModelHandler):

         return artifacts if self._context is not None else None

-    def load(self, checkpoint: str = None, **kwargs):
+    def load(self, checkpoint: Optional[str] = None, **kwargs):
         """
         Load the specified model in this handler. If a checkpoint is required to be loaded, it can be given here
         according to the provided model path in the initialization of this handler. Additional parameters for the class
@@ -318,13 +322,13 @@ class TFKerasModelHandler(DLModelHandler):

     def to_onnx(
         self,
-        model_name: str = None,
+        model_name: Optional[str] = None,
         optimize: bool = True,
         input_signature: Union[
             list[tf.TensorSpec], list[np.ndarray], tf.TensorSpec, np.ndarray
         ] = None,
-        output_path: str = None,
-        log: bool = None,
+        output_path: Optional[str] = None,
+        log: Optional[bool] = None,
     ):
         """
         Convert the model in this handler to an ONNX model.
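
As an illustration (not part of the diff, and assuming the constructor accepts an in-memory model without a context, as its defaults suggest): constructing `TFKerasModelHandler` with the 1.8.0 annotations shown above. The model below is a stand-in.

```python
from tensorflow import keras

from mlrun.frameworks.tf_keras import TFKerasModelHandler

# Any built Keras model would do here.
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])

handler = TFKerasModelHandler(
    model=model,
    model_name="toy-model",      # Optional[str] = None in 1.8.0
    custom_objects_map=None,     # Optional[Union[dict[...], str]] = None
)
# save(output_path=...), load(checkpoint=...) and to_onnx(model_name=...) keep their
# defaults; only the annotations changed from implicit to explicit Optional.
```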

mlrun/frameworks/tf_keras/model_server.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Any, Union
+from typing import Any, Optional, Union

 import numpy as np
 from tensorflow import keras
@@ -32,16 +32,20 @@ class TFKerasModelServer(V2ModelServer):
     def __init__(
         self,
         context: mlrun.MLClientCtx = None,
-        name: str = None,
+        name: Optional[str] = None,
         model: keras.Model = None,
-        model_path: str = None,
-        model_name: str = None,
-        modules_map: Union[dict[str, Union[None, str, list[str]]], str] = None,
-        custom_objects_map: Union[dict[str, Union[str, list[str]]], str] = None,
-        custom_objects_directory: str = None,
+        model_path: Optional[str] = None,
+        model_name: Optional[str] = None,
+        modules_map: Optional[
+            Union[dict[str, Union[None, str, list[str]]], str]
+        ] = None,
+        custom_objects_map: Optional[
+            Union[dict[str, Union[str, list[str]]], str]
+        ] = None,
+        custom_objects_directory: Optional[str] = None,
         model_format: str = TFKerasModelHandler.ModelFormats.SAVED_MODEL,
         to_list: bool = False,
-        protocol: str = None,
+        protocol: Optional[str] = None,
         **class_args,
     ):
         """