mlrun 1.6.4rc7__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mlrun might be problematic. Click here for more details.
- mlrun/__init__.py +11 -1
- mlrun/__main__.py +40 -122
- mlrun/alerts/__init__.py +15 -0
- mlrun/alerts/alert.py +248 -0
- mlrun/api/schemas/__init__.py +5 -4
- mlrun/artifacts/__init__.py +8 -3
- mlrun/artifacts/base.py +47 -257
- mlrun/artifacts/dataset.py +11 -192
- mlrun/artifacts/manager.py +79 -47
- mlrun/artifacts/model.py +31 -159
- mlrun/artifacts/plots.py +23 -380
- mlrun/common/constants.py +74 -1
- mlrun/common/db/sql_session.py +5 -5
- mlrun/common/formatters/__init__.py +21 -0
- mlrun/common/formatters/artifact.py +45 -0
- mlrun/common/formatters/base.py +113 -0
- mlrun/common/formatters/feature_set.py +33 -0
- mlrun/common/formatters/function.py +46 -0
- mlrun/common/formatters/pipeline.py +53 -0
- mlrun/common/formatters/project.py +51 -0
- mlrun/common/formatters/run.py +29 -0
- mlrun/common/helpers.py +12 -3
- mlrun/common/model_monitoring/helpers.py +9 -5
- mlrun/{runtimes → common/runtimes}/constants.py +37 -9
- mlrun/common/schemas/__init__.py +31 -5
- mlrun/common/schemas/alert.py +202 -0
- mlrun/common/schemas/api_gateway.py +196 -0
- mlrun/common/schemas/artifact.py +25 -4
- mlrun/common/schemas/auth.py +16 -5
- mlrun/common/schemas/background_task.py +1 -1
- mlrun/common/schemas/client_spec.py +4 -2
- mlrun/common/schemas/common.py +7 -4
- mlrun/common/schemas/constants.py +3 -0
- mlrun/common/schemas/feature_store.py +74 -44
- mlrun/common/schemas/frontend_spec.py +15 -7
- mlrun/common/schemas/function.py +12 -1
- mlrun/common/schemas/hub.py +11 -18
- mlrun/common/schemas/memory_reports.py +2 -2
- mlrun/common/schemas/model_monitoring/__init__.py +20 -4
- mlrun/common/schemas/model_monitoring/constants.py +123 -42
- mlrun/common/schemas/model_monitoring/grafana.py +13 -9
- mlrun/common/schemas/model_monitoring/model_endpoints.py +101 -54
- mlrun/common/schemas/notification.py +71 -14
- mlrun/common/schemas/object.py +2 -2
- mlrun/{model_monitoring/controller_handler.py → common/schemas/pagination.py} +9 -12
- mlrun/common/schemas/pipeline.py +8 -1
- mlrun/common/schemas/project.py +69 -18
- mlrun/common/schemas/runs.py +7 -1
- mlrun/common/schemas/runtime_resource.py +8 -12
- mlrun/common/schemas/schedule.py +4 -4
- mlrun/common/schemas/tag.py +1 -2
- mlrun/common/schemas/workflow.py +12 -4
- mlrun/common/types.py +14 -1
- mlrun/config.py +154 -69
- mlrun/data_types/data_types.py +6 -1
- mlrun/data_types/spark.py +2 -2
- mlrun/data_types/to_pandas.py +67 -37
- mlrun/datastore/__init__.py +6 -8
- mlrun/datastore/alibaba_oss.py +131 -0
- mlrun/datastore/azure_blob.py +143 -42
- mlrun/datastore/base.py +102 -58
- mlrun/datastore/datastore.py +34 -13
- mlrun/datastore/datastore_profile.py +146 -20
- mlrun/datastore/dbfs_store.py +3 -7
- mlrun/datastore/filestore.py +1 -4
- mlrun/datastore/google_cloud_storage.py +97 -33
- mlrun/datastore/hdfs.py +56 -0
- mlrun/datastore/inmem.py +6 -3
- mlrun/datastore/redis.py +7 -2
- mlrun/datastore/s3.py +34 -12
- mlrun/datastore/snowflake_utils.py +45 -0
- mlrun/datastore/sources.py +303 -111
- mlrun/datastore/spark_utils.py +31 -2
- mlrun/datastore/store_resources.py +9 -7
- mlrun/datastore/storeytargets.py +151 -0
- mlrun/datastore/targets.py +453 -176
- mlrun/datastore/utils.py +72 -58
- mlrun/datastore/v3io.py +6 -1
- mlrun/db/base.py +274 -41
- mlrun/db/factory.py +1 -1
- mlrun/db/httpdb.py +893 -225
- mlrun/db/nopdb.py +291 -33
- mlrun/errors.py +36 -6
- mlrun/execution.py +115 -42
- mlrun/feature_store/__init__.py +0 -2
- mlrun/feature_store/api.py +65 -73
- mlrun/feature_store/common.py +7 -12
- mlrun/feature_store/feature_set.py +76 -55
- mlrun/feature_store/feature_vector.py +39 -31
- mlrun/feature_store/ingestion.py +7 -6
- mlrun/feature_store/retrieval/base.py +16 -11
- mlrun/feature_store/retrieval/dask_merger.py +2 -0
- mlrun/feature_store/retrieval/job.py +13 -4
- mlrun/feature_store/retrieval/local_merger.py +2 -0
- mlrun/feature_store/retrieval/spark_merger.py +24 -32
- mlrun/feature_store/steps.py +45 -34
- mlrun/features.py +11 -21
- mlrun/frameworks/_common/artifacts_library.py +9 -9
- mlrun/frameworks/_common/mlrun_interface.py +5 -5
- mlrun/frameworks/_common/model_handler.py +48 -48
- mlrun/frameworks/_common/plan.py +5 -6
- mlrun/frameworks/_common/producer.py +3 -4
- mlrun/frameworks/_common/utils.py +5 -5
- mlrun/frameworks/_dl_common/loggers/logger.py +6 -7
- mlrun/frameworks/_dl_common/loggers/mlrun_logger.py +9 -9
- mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py +23 -47
- mlrun/frameworks/_ml_common/artifacts_library.py +1 -2
- mlrun/frameworks/_ml_common/loggers/logger.py +3 -4
- mlrun/frameworks/_ml_common/loggers/mlrun_logger.py +4 -5
- mlrun/frameworks/_ml_common/model_handler.py +24 -24
- mlrun/frameworks/_ml_common/pkl_model_server.py +2 -2
- mlrun/frameworks/_ml_common/plan.py +2 -2
- mlrun/frameworks/_ml_common/plans/calibration_curve_plan.py +2 -3
- mlrun/frameworks/_ml_common/plans/confusion_matrix_plan.py +2 -3
- mlrun/frameworks/_ml_common/plans/dataset_plan.py +3 -3
- mlrun/frameworks/_ml_common/plans/feature_importance_plan.py +3 -3
- mlrun/frameworks/_ml_common/plans/roc_curve_plan.py +4 -4
- mlrun/frameworks/_ml_common/utils.py +4 -4
- mlrun/frameworks/auto_mlrun/auto_mlrun.py +9 -9
- mlrun/frameworks/huggingface/model_server.py +4 -4
- mlrun/frameworks/lgbm/__init__.py +33 -33
- mlrun/frameworks/lgbm/callbacks/callback.py +2 -4
- mlrun/frameworks/lgbm/callbacks/logging_callback.py +4 -5
- mlrun/frameworks/lgbm/callbacks/mlrun_logging_callback.py +4 -5
- mlrun/frameworks/lgbm/mlrun_interfaces/booster_mlrun_interface.py +1 -3
- mlrun/frameworks/lgbm/mlrun_interfaces/mlrun_interface.py +6 -6
- mlrun/frameworks/lgbm/model_handler.py +10 -10
- mlrun/frameworks/lgbm/model_server.py +6 -6
- mlrun/frameworks/lgbm/utils.py +5 -5
- mlrun/frameworks/onnx/dataset.py +8 -8
- mlrun/frameworks/onnx/mlrun_interface.py +3 -3
- mlrun/frameworks/onnx/model_handler.py +6 -6
- mlrun/frameworks/onnx/model_server.py +7 -7
- mlrun/frameworks/parallel_coordinates.py +6 -6
- mlrun/frameworks/pytorch/__init__.py +18 -18
- mlrun/frameworks/pytorch/callbacks/callback.py +4 -5
- mlrun/frameworks/pytorch/callbacks/logging_callback.py +17 -17
- mlrun/frameworks/pytorch/callbacks/mlrun_logging_callback.py +11 -11
- mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py +23 -29
- mlrun/frameworks/pytorch/callbacks_handler.py +38 -38
- mlrun/frameworks/pytorch/mlrun_interface.py +20 -20
- mlrun/frameworks/pytorch/model_handler.py +17 -17
- mlrun/frameworks/pytorch/model_server.py +7 -7
- mlrun/frameworks/sklearn/__init__.py +13 -13
- mlrun/frameworks/sklearn/estimator.py +4 -4
- mlrun/frameworks/sklearn/metrics_library.py +14 -14
- mlrun/frameworks/sklearn/mlrun_interface.py +16 -9
- mlrun/frameworks/sklearn/model_handler.py +2 -2
- mlrun/frameworks/tf_keras/__init__.py +10 -7
- mlrun/frameworks/tf_keras/callbacks/logging_callback.py +15 -15
- mlrun/frameworks/tf_keras/callbacks/mlrun_logging_callback.py +11 -11
- mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py +19 -23
- mlrun/frameworks/tf_keras/mlrun_interface.py +9 -11
- mlrun/frameworks/tf_keras/model_handler.py +14 -14
- mlrun/frameworks/tf_keras/model_server.py +6 -6
- mlrun/frameworks/xgboost/__init__.py +13 -13
- mlrun/frameworks/xgboost/model_handler.py +6 -6
- mlrun/k8s_utils.py +61 -17
- mlrun/launcher/__init__.py +1 -1
- mlrun/launcher/base.py +16 -15
- mlrun/launcher/client.py +13 -11
- mlrun/launcher/factory.py +1 -1
- mlrun/launcher/local.py +23 -13
- mlrun/launcher/remote.py +17 -10
- mlrun/lists.py +7 -6
- mlrun/model.py +478 -103
- mlrun/model_monitoring/__init__.py +1 -1
- mlrun/model_monitoring/api.py +163 -371
- mlrun/{runtimes/mpijob/v1alpha1.py → model_monitoring/applications/__init__.py} +9 -15
- mlrun/model_monitoring/applications/_application_steps.py +188 -0
- mlrun/model_monitoring/applications/base.py +108 -0
- mlrun/model_monitoring/applications/context.py +341 -0
- mlrun/model_monitoring/{evidently_application.py → applications/evidently_base.py} +27 -22
- mlrun/model_monitoring/applications/histogram_data_drift.py +354 -0
- mlrun/model_monitoring/applications/results.py +99 -0
- mlrun/model_monitoring/controller.py +131 -278
- mlrun/model_monitoring/db/__init__.py +18 -0
- mlrun/model_monitoring/db/stores/__init__.py +136 -0
- mlrun/model_monitoring/db/stores/base/__init__.py +15 -0
- mlrun/model_monitoring/db/stores/base/store.py +213 -0
- mlrun/model_monitoring/db/stores/sqldb/__init__.py +13 -0
- mlrun/model_monitoring/db/stores/sqldb/models/__init__.py +71 -0
- mlrun/model_monitoring/db/stores/sqldb/models/base.py +190 -0
- mlrun/model_monitoring/db/stores/sqldb/models/mysql.py +103 -0
- mlrun/model_monitoring/{stores/models/mysql.py → db/stores/sqldb/models/sqlite.py} +19 -13
- mlrun/model_monitoring/db/stores/sqldb/sql_store.py +659 -0
- mlrun/model_monitoring/db/stores/v3io_kv/__init__.py +13 -0
- mlrun/model_monitoring/db/stores/v3io_kv/kv_store.py +726 -0
- mlrun/model_monitoring/db/tsdb/__init__.py +105 -0
- mlrun/model_monitoring/db/tsdb/base.py +448 -0
- mlrun/model_monitoring/db/tsdb/helpers.py +30 -0
- mlrun/model_monitoring/db/tsdb/tdengine/__init__.py +15 -0
- mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +279 -0
- mlrun/model_monitoring/db/tsdb/tdengine/stream_graph_steps.py +42 -0
- mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +507 -0
- mlrun/model_monitoring/db/tsdb/v3io/__init__.py +15 -0
- mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +158 -0
- mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +849 -0
- mlrun/model_monitoring/features_drift_table.py +134 -106
- mlrun/model_monitoring/helpers.py +199 -55
- mlrun/model_monitoring/metrics/__init__.py +13 -0
- mlrun/model_monitoring/metrics/histogram_distance.py +127 -0
- mlrun/model_monitoring/model_endpoint.py +3 -2
- mlrun/model_monitoring/stream_processing.py +131 -398
- mlrun/model_monitoring/tracking_policy.py +9 -2
- mlrun/model_monitoring/writer.py +161 -125
- mlrun/package/__init__.py +6 -6
- mlrun/package/context_handler.py +5 -5
- mlrun/package/packager.py +7 -7
- mlrun/package/packagers/default_packager.py +8 -8
- mlrun/package/packagers/numpy_packagers.py +15 -15
- mlrun/package/packagers/pandas_packagers.py +5 -5
- mlrun/package/packagers/python_standard_library_packagers.py +10 -10
- mlrun/package/packagers_manager.py +19 -23
- mlrun/package/utils/_formatter.py +6 -6
- mlrun/package/utils/_pickler.py +2 -2
- mlrun/package/utils/_supported_format.py +4 -4
- mlrun/package/utils/log_hint_utils.py +2 -2
- mlrun/package/utils/type_hint_utils.py +4 -9
- mlrun/platforms/__init__.py +11 -10
- mlrun/platforms/iguazio.py +24 -203
- mlrun/projects/operations.py +52 -25
- mlrun/projects/pipelines.py +191 -197
- mlrun/projects/project.py +1227 -400
- mlrun/render.py +16 -19
- mlrun/run.py +209 -184
- mlrun/runtimes/__init__.py +83 -15
- mlrun/runtimes/base.py +51 -35
- mlrun/runtimes/daskjob.py +17 -10
- mlrun/runtimes/databricks_job/databricks_cancel_task.py +1 -1
- mlrun/runtimes/databricks_job/databricks_runtime.py +8 -7
- mlrun/runtimes/databricks_job/databricks_wrapper.py +1 -1
- mlrun/runtimes/funcdoc.py +1 -29
- mlrun/runtimes/function_reference.py +1 -1
- mlrun/runtimes/kubejob.py +34 -128
- mlrun/runtimes/local.py +40 -11
- mlrun/runtimes/mpijob/__init__.py +0 -20
- mlrun/runtimes/mpijob/abstract.py +9 -10
- mlrun/runtimes/mpijob/v1.py +1 -1
- mlrun/{model_monitoring/stores/models/sqlite.py → runtimes/nuclio/__init__.py} +7 -9
- mlrun/runtimes/nuclio/api_gateway.py +769 -0
- mlrun/runtimes/nuclio/application/__init__.py +15 -0
- mlrun/runtimes/nuclio/application/application.py +758 -0
- mlrun/runtimes/nuclio/application/reverse_proxy.go +95 -0
- mlrun/runtimes/{function.py → nuclio/function.py} +200 -83
- mlrun/runtimes/{nuclio.py → nuclio/nuclio.py} +6 -6
- mlrun/runtimes/{serving.py → nuclio/serving.py} +65 -68
- mlrun/runtimes/pod.py +281 -101
- mlrun/runtimes/remotesparkjob.py +12 -9
- mlrun/runtimes/sparkjob/spark3job.py +67 -51
- mlrun/runtimes/utils.py +41 -75
- mlrun/secrets.py +9 -5
- mlrun/serving/__init__.py +8 -1
- mlrun/serving/remote.py +2 -7
- mlrun/serving/routers.py +85 -69
- mlrun/serving/server.py +69 -44
- mlrun/serving/states.py +209 -36
- mlrun/serving/utils.py +22 -14
- mlrun/serving/v1_serving.py +6 -7
- mlrun/serving/v2_serving.py +129 -54
- mlrun/track/tracker.py +2 -1
- mlrun/track/tracker_manager.py +3 -3
- mlrun/track/trackers/mlflow_tracker.py +6 -2
- mlrun/utils/async_http.py +6 -8
- mlrun/utils/azure_vault.py +1 -1
- mlrun/utils/clones.py +1 -2
- mlrun/utils/condition_evaluator.py +3 -3
- mlrun/utils/db.py +21 -3
- mlrun/utils/helpers.py +405 -225
- mlrun/utils/http.py +3 -6
- mlrun/utils/logger.py +112 -16
- mlrun/utils/notifications/notification/__init__.py +17 -13
- mlrun/utils/notifications/notification/base.py +50 -2
- mlrun/utils/notifications/notification/console.py +2 -0
- mlrun/utils/notifications/notification/git.py +24 -1
- mlrun/utils/notifications/notification/ipython.py +3 -1
- mlrun/utils/notifications/notification/slack.py +96 -21
- mlrun/utils/notifications/notification/webhook.py +59 -2
- mlrun/utils/notifications/notification_pusher.py +149 -30
- mlrun/utils/regex.py +9 -0
- mlrun/utils/retryer.py +208 -0
- mlrun/utils/singleton.py +1 -1
- mlrun/utils/v3io_clients.py +4 -6
- mlrun/utils/version/version.json +2 -2
- mlrun/utils/version/version.py +2 -6
- mlrun-1.7.0.dist-info/METADATA +378 -0
- mlrun-1.7.0.dist-info/RECORD +351 -0
- {mlrun-1.6.4rc7.dist-info → mlrun-1.7.0.dist-info}/WHEEL +1 -1
- mlrun/feature_store/retrieval/conversion.py +0 -273
- mlrun/kfpops.py +0 -868
- mlrun/model_monitoring/application.py +0 -310
- mlrun/model_monitoring/batch.py +0 -1095
- mlrun/model_monitoring/prometheus.py +0 -219
- mlrun/model_monitoring/stores/__init__.py +0 -111
- mlrun/model_monitoring/stores/kv_model_endpoint_store.py +0 -576
- mlrun/model_monitoring/stores/model_endpoint_store.py +0 -147
- mlrun/model_monitoring/stores/models/__init__.py +0 -27
- mlrun/model_monitoring/stores/models/base.py +0 -84
- mlrun/model_monitoring/stores/sql_model_endpoint_store.py +0 -384
- mlrun/platforms/other.py +0 -306
- mlrun-1.6.4rc7.dist-info/METADATA +0 -272
- mlrun-1.6.4rc7.dist-info/RECORD +0 -314
- {mlrun-1.6.4rc7.dist-info → mlrun-1.7.0.dist-info}/LICENSE +0 -0
- {mlrun-1.6.4rc7.dist-info → mlrun-1.7.0.dist-info}/entry_points.txt +0 -0
- {mlrun-1.6.4rc7.dist-info → mlrun-1.7.0.dist-info}/top_level.txt +0 -0
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
|
-
from typing import
|
|
15
|
+
from typing import Union
|
|
16
16
|
|
|
17
17
|
import numpy as np
|
|
18
18
|
import pandas as pd
|
|
@@ -32,7 +32,7 @@ class Estimator:
|
|
|
32
32
|
def __init__(
|
|
33
33
|
self,
|
|
34
34
|
context: mlrun.MLClientCtx = None,
|
|
35
|
-
metrics:
|
|
35
|
+
metrics: list[Metric] = None,
|
|
36
36
|
):
|
|
37
37
|
"""
|
|
38
38
|
Initialize an estimator with the given metrics. The estimator will log the calculated results using the given
|
|
@@ -62,7 +62,7 @@ class Estimator:
|
|
|
62
62
|
return self._context
|
|
63
63
|
|
|
64
64
|
@property
|
|
65
|
-
def results(self) ->
|
|
65
|
+
def results(self) -> dict[str, float]:
|
|
66
66
|
"""
|
|
67
67
|
Get the logged results.
|
|
68
68
|
|
|
@@ -86,7 +86,7 @@ class Estimator:
|
|
|
86
86
|
"""
|
|
87
87
|
self._context = context
|
|
88
88
|
|
|
89
|
-
def set_metrics(self, metrics:
|
|
89
|
+
def set_metrics(self, metrics: list[Metric]):
|
|
90
90
|
"""
|
|
91
91
|
Update the metrics of this logger to the given list of metrics here.
|
|
92
92
|
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
15
|
from abc import ABC
|
|
16
|
-
from typing import
|
|
16
|
+
from typing import Union
|
|
17
17
|
|
|
18
18
|
import sklearn
|
|
19
19
|
from sklearn.preprocessing import LabelBinarizer
|
|
@@ -40,14 +40,14 @@ class MetricsLibrary(ABC):
|
|
|
40
40
|
def get_metrics(
|
|
41
41
|
cls,
|
|
42
42
|
metrics: Union[
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
43
|
+
list[Metric],
|
|
44
|
+
list[SKLearnTypes.MetricEntryType],
|
|
45
|
+
dict[str, SKLearnTypes.MetricEntryType],
|
|
46
46
|
] = None,
|
|
47
47
|
context: mlrun.MLClientCtx = None,
|
|
48
48
|
include_default: bool = True,
|
|
49
49
|
**default_kwargs,
|
|
50
|
-
) ->
|
|
50
|
+
) -> list[Metric]:
|
|
51
51
|
"""
|
|
52
52
|
Get metrics for a run. The metrics will be taken from the provided metrics / configuration via code, from
|
|
53
53
|
provided configuration via MLRun context and if the 'include_default' is True, from the metric library's
|
|
@@ -87,11 +87,11 @@ class MetricsLibrary(ABC):
|
|
|
87
87
|
def _parse(
|
|
88
88
|
cls,
|
|
89
89
|
metrics: Union[
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
90
|
+
list[Metric],
|
|
91
|
+
list[SKLearnTypes.MetricEntryType],
|
|
92
|
+
dict[str, SKLearnTypes.MetricEntryType],
|
|
93
93
|
],
|
|
94
|
-
) ->
|
|
94
|
+
) -> list[Metric]:
|
|
95
95
|
"""
|
|
96
96
|
Parse the given metrics by the possible rules of the framework implementing.
|
|
97
97
|
|
|
@@ -116,8 +116,8 @@ class MetricsLibrary(ABC):
|
|
|
116
116
|
|
|
117
117
|
@classmethod
|
|
118
118
|
def _from_list(
|
|
119
|
-
cls, metrics_list:
|
|
120
|
-
) ->
|
|
119
|
+
cls, metrics_list: list[Union[Metric, SKLearnTypes.MetricEntryType]]
|
|
120
|
+
) -> list[Metric]:
|
|
121
121
|
"""
|
|
122
122
|
Collect the given metrics configurations from a list. The metrics names will be chosen by the following rules:
|
|
123
123
|
|
|
@@ -143,8 +143,8 @@ class MetricsLibrary(ABC):
|
|
|
143
143
|
|
|
144
144
|
@classmethod
|
|
145
145
|
def _from_dict(
|
|
146
|
-
cls, metrics_dictionary:
|
|
147
|
-
) ->
|
|
146
|
+
cls, metrics_dictionary: dict[str, SKLearnTypes.MetricEntryType]
|
|
147
|
+
) -> list[Metric]:
|
|
148
148
|
"""
|
|
149
149
|
Collect the given metrics configurations from a dictionary.
|
|
150
150
|
|
|
@@ -165,7 +165,7 @@ class MetricsLibrary(ABC):
|
|
|
165
165
|
@classmethod
|
|
166
166
|
def _default(
|
|
167
167
|
cls, model: SKLearnTypes.ModelType, y: SKLearnTypes.DatasetType = None
|
|
168
|
-
) ->
|
|
168
|
+
) -> list[Metric]:
|
|
169
169
|
"""
|
|
170
170
|
Get the default metrics list according to the algorithm functionality.
|
|
171
171
|
|
|
@@ -13,7 +13,6 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
15
|
from abc import ABC
|
|
16
|
-
from typing import List
|
|
17
16
|
|
|
18
17
|
import mlrun
|
|
19
18
|
|
|
@@ -75,9 +74,7 @@ class SKLearnMLRunInterface(MLRunInterface, ABC):
|
|
|
75
74
|
cls._REPLACED_METHODS.remove("predict_proba")
|
|
76
75
|
|
|
77
76
|
# Add the interface to the model:
|
|
78
|
-
super(
|
|
79
|
-
obj=obj, restoration=restoration
|
|
80
|
-
)
|
|
77
|
+
super().add_interface(obj=obj, restoration=restoration)
|
|
81
78
|
|
|
82
79
|
# Restore the '_REPLACED_METHODS' list for next models:
|
|
83
80
|
if "predict_proba" not in cls._REPLACED_METHODS:
|
|
@@ -100,7 +97,7 @@ class SKLearnMLRunInterface(MLRunInterface, ABC):
|
|
|
100
97
|
|
|
101
98
|
def wrapper(
|
|
102
99
|
self: SKLearnTypes.ModelType,
|
|
103
|
-
X: SKLearnTypes.DatasetType,
|
|
100
|
+
X: SKLearnTypes.DatasetType, # noqa: N803 - should be lowercase "x", kept for BC
|
|
104
101
|
y: SKLearnTypes.DatasetType = None,
|
|
105
102
|
*args,
|
|
106
103
|
**kwargs,
|
|
@@ -127,7 +124,12 @@ class SKLearnMLRunInterface(MLRunInterface, ABC):
|
|
|
127
124
|
|
|
128
125
|
return wrapper
|
|
129
126
|
|
|
130
|
-
def mlrun_predict(
|
|
127
|
+
def mlrun_predict(
|
|
128
|
+
self,
|
|
129
|
+
X: SKLearnTypes.DatasetType, # noqa: N803 - should be lowercase "x", kept for BC
|
|
130
|
+
*args,
|
|
131
|
+
**kwargs,
|
|
132
|
+
):
|
|
131
133
|
"""
|
|
132
134
|
MLRun's wrapper for the common ML API predict method.
|
|
133
135
|
"""
|
|
@@ -139,7 +141,12 @@ class SKLearnMLRunInterface(MLRunInterface, ABC):
|
|
|
139
141
|
|
|
140
142
|
return y_pred
|
|
141
143
|
|
|
142
|
-
def mlrun_predict_proba(
|
|
144
|
+
def mlrun_predict_proba(
|
|
145
|
+
self,
|
|
146
|
+
X: SKLearnTypes.DatasetType, # noqa: N803 - should be lowercase "x", kept for BC
|
|
147
|
+
*args,
|
|
148
|
+
**kwargs,
|
|
149
|
+
):
|
|
143
150
|
"""
|
|
144
151
|
MLRun's wrapper for the common ML API predict_proba method.
|
|
145
152
|
"""
|
|
@@ -154,8 +161,8 @@ class SKLearnMLRunInterface(MLRunInterface, ABC):
|
|
|
154
161
|
def configure_logging(
|
|
155
162
|
self,
|
|
156
163
|
context: mlrun.MLClientCtx = None,
|
|
157
|
-
plans:
|
|
158
|
-
metrics:
|
|
164
|
+
plans: list[MLPlan] = None,
|
|
165
|
+
metrics: list[Metric] = None,
|
|
159
166
|
x_test: SKLearnTypes.DatasetType = None,
|
|
160
167
|
y_test: SKLearnTypes.DatasetType = None,
|
|
161
168
|
model_handler: MLModelHandler = None,
|
|
@@ -59,7 +59,7 @@ class SKLearnModelHandler(MLModelHandler):
|
|
|
59
59
|
|
|
60
60
|
:return The saved model additional artifacts (if needed) dictionary if context is available and None otherwise.
|
|
61
61
|
"""
|
|
62
|
-
super(
|
|
62
|
+
super().save(output_path=output_path)
|
|
63
63
|
|
|
64
64
|
# Save the model pkl file:
|
|
65
65
|
self._model_file = f"{self._model_name}.pkl"
|
|
@@ -73,7 +73,7 @@ class SKLearnModelHandler(MLModelHandler):
|
|
|
73
73
|
Load the specified model in this handler. Additional parameters for the class initializer can be passed via the
|
|
74
74
|
kwargs dictionary.
|
|
75
75
|
"""
|
|
76
|
-
super(
|
|
76
|
+
super().load()
|
|
77
77
|
|
|
78
78
|
# Load from a pkl file:
|
|
79
79
|
with open(self._model_file, "rb") as pickle_file:
|
|
@@ -13,11 +13,12 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
15
|
# flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx
|
|
16
|
-
from typing import Any,
|
|
16
|
+
from typing import Any, Union
|
|
17
17
|
|
|
18
18
|
from tensorflow import keras
|
|
19
19
|
|
|
20
20
|
import mlrun
|
|
21
|
+
import mlrun.common.constants as mlrun_constants
|
|
21
22
|
|
|
22
23
|
from .callbacks import MLRunLoggingCallback, TensorboardLoggingCallback
|
|
23
24
|
from .mlrun_interface import TFKerasMLRunInterface
|
|
@@ -33,14 +34,14 @@ def apply_mlrun(
|
|
|
33
34
|
model_path: str = None,
|
|
34
35
|
model_format: str = TFKerasModelHandler.ModelFormats.SAVED_MODEL,
|
|
35
36
|
save_traces: bool = False,
|
|
36
|
-
modules_map: Union[
|
|
37
|
-
custom_objects_map: Union[
|
|
37
|
+
modules_map: Union[dict[str, Union[None, str, list[str]]], str] = None,
|
|
38
|
+
custom_objects_map: Union[dict[str, Union[str, list[str]]], str] = None,
|
|
38
39
|
custom_objects_directory: str = None,
|
|
39
40
|
context: mlrun.MLClientCtx = None,
|
|
40
41
|
auto_log: bool = True,
|
|
41
42
|
tensorboard_directory: str = None,
|
|
42
|
-
mlrun_callback_kwargs:
|
|
43
|
-
tensorboard_callback_kwargs:
|
|
43
|
+
mlrun_callback_kwargs: dict[str, Any] = None,
|
|
44
|
+
tensorboard_callback_kwargs: dict[str, Any] = None,
|
|
44
45
|
use_horovod: bool = None,
|
|
45
46
|
**kwargs,
|
|
46
47
|
) -> TFKerasModelHandler:
|
|
@@ -85,7 +86,7 @@ def apply_mlrun(
|
|
|
85
86
|
|
|
86
87
|
{
|
|
87
88
|
"/.../custom_optimizer.py": "optimizer",
|
|
88
|
-
"/.../custom_layers.py": ["layer1", "layer2"]
|
|
89
|
+
"/.../custom_layers.py": ["layer1", "layer2"],
|
|
89
90
|
}
|
|
90
91
|
|
|
91
92
|
All the paths will be accessed from the given 'custom_objects_directory',
|
|
@@ -126,7 +127,9 @@ def apply_mlrun(
|
|
|
126
127
|
# # Use horovod:
|
|
127
128
|
if use_horovod is None:
|
|
128
129
|
use_horovod = (
|
|
129
|
-
context.labels.get(
|
|
130
|
+
context.labels.get(mlrun_constants.MLRunInternalLabels.kind, "") == "mpijob"
|
|
131
|
+
if context is not None
|
|
132
|
+
else False
|
|
130
133
|
)
|
|
131
134
|
|
|
132
135
|
# Create a model handler:
|
|
@@ -12,12 +12,12 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
|
-
from typing import Callable,
|
|
15
|
+
from typing import Callable, Union
|
|
16
16
|
|
|
17
17
|
import numpy as np
|
|
18
18
|
import tensorflow as tf
|
|
19
19
|
from tensorflow import Tensor, Variable
|
|
20
|
-
from tensorflow.keras.callbacks import Callback
|
|
20
|
+
from tensorflow.python.keras.callbacks import Callback
|
|
21
21
|
|
|
22
22
|
import mlrun
|
|
23
23
|
|
|
@@ -36,11 +36,11 @@ class LoggingCallback(Callback):
|
|
|
36
36
|
def __init__(
|
|
37
37
|
self,
|
|
38
38
|
context: mlrun.MLClientCtx = None,
|
|
39
|
-
dynamic_hyperparameters:
|
|
40
|
-
str, Union[
|
|
39
|
+
dynamic_hyperparameters: dict[
|
|
40
|
+
str, Union[list[Union[str, int]], Callable[[], TFKerasTypes.TrackableType]]
|
|
41
41
|
] = None,
|
|
42
|
-
static_hyperparameters:
|
|
43
|
-
str, Union[TFKerasTypes.TrackableType,
|
|
42
|
+
static_hyperparameters: dict[
|
|
43
|
+
str, Union[TFKerasTypes.TrackableType, list[Union[str, int]]]
|
|
44
44
|
] = None,
|
|
45
45
|
auto_log: bool = False,
|
|
46
46
|
):
|
|
@@ -70,7 +70,7 @@ class LoggingCallback(Callback):
|
|
|
70
70
|
:param auto_log: Whether or not to enable auto logging, trying to track common static and dynamic
|
|
71
71
|
hyperparameters.
|
|
72
72
|
"""
|
|
73
|
-
super(
|
|
73
|
+
super().__init__()
|
|
74
74
|
self._supports_tf_logs = True
|
|
75
75
|
|
|
76
76
|
# Store the configurations:
|
|
@@ -93,7 +93,7 @@ class LoggingCallback(Callback):
|
|
|
93
93
|
self._is_training = None # type: bool
|
|
94
94
|
self._auto_log = auto_log
|
|
95
95
|
|
|
96
|
-
def get_training_results(self) ->
|
|
96
|
+
def get_training_results(self) -> dict[str, list[list[float]]]:
|
|
97
97
|
"""
|
|
98
98
|
Get the training results logged. The results will be stored in a dictionary where each key is the metric name
|
|
99
99
|
and the value is a list of lists of values. The first list is by epoch and the second list is by iteration
|
|
@@ -103,7 +103,7 @@ class LoggingCallback(Callback):
|
|
|
103
103
|
"""
|
|
104
104
|
return self._logger.training_results
|
|
105
105
|
|
|
106
|
-
def get_validation_results(self) ->
|
|
106
|
+
def get_validation_results(self) -> dict[str, list[list[float]]]:
|
|
107
107
|
"""
|
|
108
108
|
Get the validation results logged. The results will be stored in a dictionary where each key is the metric name
|
|
109
109
|
and the value is a list of lists of values. The first list is by epoch and the second list is by iteration
|
|
@@ -113,7 +113,7 @@ class LoggingCallback(Callback):
|
|
|
113
113
|
"""
|
|
114
114
|
return self._logger.validation_results
|
|
115
115
|
|
|
116
|
-
def get_training_summaries(self) ->
|
|
116
|
+
def get_training_summaries(self) -> dict[str, list[float]]:
|
|
117
117
|
"""
|
|
118
118
|
Get the training summaries of the metrics results. The summaries will be stored in a dictionary where each key
|
|
119
119
|
is the metric names and the value is a list of all the summary values per epoch.
|
|
@@ -122,7 +122,7 @@ class LoggingCallback(Callback):
|
|
|
122
122
|
"""
|
|
123
123
|
return self._logger.training_summaries
|
|
124
124
|
|
|
125
|
-
def get_validation_summaries(self) ->
|
|
125
|
+
def get_validation_summaries(self) -> dict[str, list[float]]:
|
|
126
126
|
"""
|
|
127
127
|
Get the validation summaries of the metrics results. The summaries will be stored in a dictionary where each key
|
|
128
128
|
is the metric names and the value is a list of all the summary values per epoch.
|
|
@@ -131,7 +131,7 @@ class LoggingCallback(Callback):
|
|
|
131
131
|
"""
|
|
132
132
|
return self._logger.validation_summaries
|
|
133
133
|
|
|
134
|
-
def get_static_hyperparameters(self) ->
|
|
134
|
+
def get_static_hyperparameters(self) -> dict[str, TFKerasTypes.TrackableType]:
|
|
135
135
|
"""
|
|
136
136
|
Get the static hyperparameters logged. The hyperparameters will be stored in a dictionary where each key is the
|
|
137
137
|
hyperparameter name and the value is his logged value.
|
|
@@ -142,7 +142,7 @@ class LoggingCallback(Callback):
|
|
|
142
142
|
|
|
143
143
|
def get_dynamic_hyperparameters(
|
|
144
144
|
self,
|
|
145
|
-
) ->
|
|
145
|
+
) -> dict[str, list[TFKerasTypes.TrackableType]]:
|
|
146
146
|
"""
|
|
147
147
|
Get the dynamic hyperparameters logged. The hyperparameters will be stored in a dictionary where each key is the
|
|
148
148
|
hyperparameter name and the value is a list of his logged values per epoch.
|
|
@@ -329,7 +329,7 @@ class LoggingCallback(Callback):
|
|
|
329
329
|
|
|
330
330
|
# Static hyperparameters:
|
|
331
331
|
for name, value in self._static_hyperparameters_keys.items():
|
|
332
|
-
if isinstance(value,
|
|
332
|
+
if isinstance(value, list):
|
|
333
333
|
# Its a parameter that needed to be extracted via key chain.
|
|
334
334
|
self._logger.log_static_hyperparameter(
|
|
335
335
|
parameter_name=name,
|
|
@@ -398,7 +398,7 @@ class LoggingCallback(Callback):
|
|
|
398
398
|
def _get_hyperparameter(
|
|
399
399
|
self,
|
|
400
400
|
key_chain: Union[
|
|
401
|
-
Callable[[], TFKerasTypes.TrackableType],
|
|
401
|
+
Callable[[], TFKerasTypes.TrackableType], list[Union[str, int]]
|
|
402
402
|
],
|
|
403
403
|
) -> TFKerasTypes.TrackableType:
|
|
404
404
|
"""
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
|
-
from typing import Callable,
|
|
15
|
+
from typing import Callable, Union
|
|
16
16
|
|
|
17
17
|
import mlrun
|
|
18
18
|
from mlrun.artifacts import Artifact
|
|
@@ -50,16 +50,16 @@ class MLRunLoggingCallback(LoggingCallback):
|
|
|
50
50
|
context: mlrun.MLClientCtx,
|
|
51
51
|
model_handler: TFKerasModelHandler,
|
|
52
52
|
log_model_tag: str = "",
|
|
53
|
-
log_model_labels:
|
|
54
|
-
log_model_parameters:
|
|
55
|
-
log_model_extra_data:
|
|
53
|
+
log_model_labels: dict[str, TFKerasTypes.TrackableType] = None,
|
|
54
|
+
log_model_parameters: dict[str, TFKerasTypes.TrackableType] = None,
|
|
55
|
+
log_model_extra_data: dict[
|
|
56
56
|
str, Union[TFKerasTypes.TrackableType, Artifact]
|
|
57
57
|
] = None,
|
|
58
|
-
dynamic_hyperparameters:
|
|
59
|
-
str, Union[
|
|
58
|
+
dynamic_hyperparameters: dict[
|
|
59
|
+
str, Union[list[Union[str, int]], Callable[[], TFKerasTypes.TrackableType]]
|
|
60
60
|
] = None,
|
|
61
|
-
static_hyperparameters:
|
|
62
|
-
str, Union[TFKerasTypes,
|
|
61
|
+
static_hyperparameters: dict[
|
|
62
|
+
str, Union[TFKerasTypes, list[Union[str, int]]]
|
|
63
63
|
] = None,
|
|
64
64
|
auto_log: bool = False,
|
|
65
65
|
):
|
|
@@ -97,7 +97,7 @@ class MLRunLoggingCallback(LoggingCallback):
|
|
|
97
97
|
trying to track common static and dynamic hyperparameters such as learning
|
|
98
98
|
rate.
|
|
99
99
|
"""
|
|
100
|
-
super(
|
|
100
|
+
super().__init__(
|
|
101
101
|
dynamic_hyperparameters=dynamic_hyperparameters,
|
|
102
102
|
static_hyperparameters=static_hyperparameters,
|
|
103
103
|
auto_log=auto_log,
|
|
@@ -134,7 +134,7 @@ class MLRunLoggingCallback(LoggingCallback):
|
|
|
134
134
|
:param logs: Currently no data is passed to this argument for this method but that may change in the
|
|
135
135
|
future.
|
|
136
136
|
"""
|
|
137
|
-
super(
|
|
137
|
+
super().on_test_end(logs=logs)
|
|
138
138
|
|
|
139
139
|
# Check if its part of evaluation. If so, end the run:
|
|
140
140
|
if self._logger.mode == LoggingMode.EVALUATION:
|
|
@@ -151,7 +151,7 @@ class MLRunLoggingCallback(LoggingCallback):
|
|
|
151
151
|
performed. Validation result keys are prefixed with `val_`. For training epoch, the values of the
|
|
152
152
|
`Model`'s metrics are returned. Example : `{'loss': 0.2, 'acc': 0.7}`.
|
|
153
153
|
"""
|
|
154
|
-
super(
|
|
154
|
+
super().on_epoch_end(epoch=epoch)
|
|
155
155
|
|
|
156
156
|
# Log the current epoch's results:
|
|
157
157
|
self._logger.log_epoch_to_context(epoch=epoch)
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
15
|
from datetime import datetime
|
|
16
|
-
from typing import Callable,
|
|
16
|
+
from typing import Callable, Union
|
|
17
17
|
|
|
18
18
|
import tensorflow as tf
|
|
19
19
|
from packaging import version
|
|
@@ -38,7 +38,7 @@ class _TFKerasTensorboardLogger(TensorboardLogger):
|
|
|
38
38
|
|
|
39
39
|
def __init__(
|
|
40
40
|
self,
|
|
41
|
-
statistics_functions:
|
|
41
|
+
statistics_functions: list[Callable[[Union[Variable]], Union[float, Variable]]],
|
|
42
42
|
context: mlrun.MLClientCtx = None,
|
|
43
43
|
tensorboard_directory: str = None,
|
|
44
44
|
run_name: str = None,
|
|
@@ -67,7 +67,7 @@ class _TFKerasTensorboardLogger(TensorboardLogger):
|
|
|
67
67
|
update. Notice that writing to tensorboard too frequently may cause the training
|
|
68
68
|
to be slower. Default: 'epoch'.
|
|
69
69
|
"""
|
|
70
|
-
super(
|
|
70
|
+
super().__init__(
|
|
71
71
|
statistics_functions=statistics_functions,
|
|
72
72
|
context=context,
|
|
73
73
|
tensorboard_directory=tensorboard_directory,
|
|
@@ -255,15 +255,15 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
255
255
|
context: mlrun.MLClientCtx = None,
|
|
256
256
|
tensorboard_directory: str = None,
|
|
257
257
|
run_name: str = None,
|
|
258
|
-
weights: Union[bool,
|
|
259
|
-
statistics_functions:
|
|
258
|
+
weights: Union[bool, list[str]] = False,
|
|
259
|
+
statistics_functions: list[
|
|
260
260
|
Callable[[Union[Variable, Tensor]], Union[float, Tensor]]
|
|
261
261
|
] = None,
|
|
262
|
-
dynamic_hyperparameters:
|
|
263
|
-
str, Union[
|
|
262
|
+
dynamic_hyperparameters: dict[
|
|
263
|
+
str, Union[list[Union[str, int]], Callable[[], TFKerasTypes.TrackableType]]
|
|
264
264
|
] = None,
|
|
265
|
-
static_hyperparameters:
|
|
266
|
-
str, Union[TFKerasTypes.TrackableType,
|
|
265
|
+
static_hyperparameters: dict[
|
|
266
|
+
str, Union[TFKerasTypes.TrackableType, list[Union[str, int]]]
|
|
267
267
|
] = None,
|
|
268
268
|
update_frequency: Union[int, str] = "epoch",
|
|
269
269
|
auto_log: bool = False,
|
|
@@ -325,7 +325,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
325
325
|
:raise MLRunInvalidArgumentError: In case both 'context' and 'tensorboard_directory' parameters were not given
|
|
326
326
|
or the 'update_frequency' was incorrect.
|
|
327
327
|
"""
|
|
328
|
-
super(
|
|
328
|
+
super().__init__(
|
|
329
329
|
dynamic_hyperparameters=dynamic_hyperparameters,
|
|
330
330
|
static_hyperparameters=static_hyperparameters,
|
|
331
331
|
auto_log=auto_log,
|
|
@@ -352,7 +352,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
352
352
|
self._logged_model = False
|
|
353
353
|
self._logged_hyperparameters = False
|
|
354
354
|
|
|
355
|
-
def get_weights(self) ->
|
|
355
|
+
def get_weights(self) -> dict[str, Variable]:
|
|
356
356
|
"""
|
|
357
357
|
Get the weights tensors tracked. The weights will be stored in a dictionary where each key is the weight's name
|
|
358
358
|
and the value is the weight's parameter (tensor).
|
|
@@ -361,7 +361,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
361
361
|
"""
|
|
362
362
|
return self._logger.weights
|
|
363
363
|
|
|
364
|
-
def get_weights_statistics(self) ->
|
|
364
|
+
def get_weights_statistics(self) -> dict[str, dict[str, list[float]]]:
|
|
365
365
|
"""
|
|
366
366
|
Get the weights mean results logged. The results will be stored in a dictionary where each key is the weight's
|
|
367
367
|
name and the value is a list of mean values per epoch.
|
|
@@ -408,7 +408,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
408
408
|
:param logs: Currently the output of the last call to `on_epoch_end()` is passed to this argument for this
|
|
409
409
|
method but that may change in the future.
|
|
410
410
|
"""
|
|
411
|
-
super(
|
|
411
|
+
super().on_train_end()
|
|
412
412
|
|
|
413
413
|
# Write the final run summary:
|
|
414
414
|
self._logger.write_final_summary_text()
|
|
@@ -453,7 +453,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
453
453
|
:param logs: Currently no data is passed to this argument for this method but that may change in the
|
|
454
454
|
future.
|
|
455
455
|
"""
|
|
456
|
-
super(
|
|
456
|
+
super().on_test_end(logs=logs)
|
|
457
457
|
|
|
458
458
|
# Check if needed to end the run (in case of evaluation and not training):
|
|
459
459
|
if not self._is_training:
|
|
@@ -477,7 +477,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
477
477
|
`Model`'s metrics are returned. Example : `{'loss': 0.2, 'acc': 0.7}`.
|
|
478
478
|
"""
|
|
479
479
|
# Update the dynamic hyperparameters
|
|
480
|
-
super(
|
|
480
|
+
super().on_epoch_end(epoch=epoch)
|
|
481
481
|
|
|
482
482
|
# Log the weights statistics:
|
|
483
483
|
self._logger.log_weights_statistics()
|
|
@@ -515,9 +515,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
515
515
|
:param logs: Aggregated metric results up until this batch.
|
|
516
516
|
"""
|
|
517
517
|
# Log the batch's results:
|
|
518
|
-
super(
|
|
519
|
-
batch=batch, logs=logs
|
|
520
|
-
)
|
|
518
|
+
super().on_train_batch_end(batch=batch, logs=logs)
|
|
521
519
|
|
|
522
520
|
# Write the batch loss and metrics results to their graphs:
|
|
523
521
|
self._logger.write_training_results()
|
|
@@ -540,9 +538,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
540
538
|
:param logs: Aggregated metric results up until this batch.
|
|
541
539
|
"""
|
|
542
540
|
# Log the batch's results:
|
|
543
|
-
super(
|
|
544
|
-
batch=batch, logs=logs
|
|
545
|
-
)
|
|
541
|
+
super().on_test_batch_end(batch=batch, logs=logs)
|
|
546
542
|
|
|
547
543
|
# Write the batch loss and metrics results to their graphs:
|
|
548
544
|
self._logger.write_validation_results()
|
|
@@ -555,7 +551,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
555
551
|
|
|
556
552
|
@staticmethod
|
|
557
553
|
def get_default_weight_statistics_list() -> (
|
|
558
|
-
|
|
554
|
+
list[Callable[[Union[Variable, Tensor]], Union[float, Tensor]]]
|
|
559
555
|
):
|
|
560
556
|
"""
|
|
561
557
|
Get the default list of statistics functions being applied on the tracked weights each epoch.
|
|
@@ -569,7 +565,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
569
565
|
After the trainer / evaluator run begins, this method will be called to setup the results, hyperparameters
|
|
570
566
|
and weights dictionaries for logging.
|
|
571
567
|
"""
|
|
572
|
-
super(
|
|
568
|
+
super()._setup_run()
|
|
573
569
|
|
|
574
570
|
# Check if needed to track weights:
|
|
575
571
|
if self._tracked_weights is False:
|
|
@@ -15,11 +15,12 @@
|
|
|
15
15
|
import importlib
|
|
16
16
|
import os
|
|
17
17
|
from abc import ABC
|
|
18
|
-
from typing import
|
|
18
|
+
from typing import Union
|
|
19
19
|
|
|
20
20
|
import tensorflow as tf
|
|
21
21
|
from tensorflow import keras
|
|
22
|
-
from tensorflow.keras.
|
|
22
|
+
from tensorflow.keras.optimizers import Optimizer
|
|
23
|
+
from tensorflow.python.keras.callbacks import (
|
|
23
24
|
BaseLogger,
|
|
24
25
|
Callback,
|
|
25
26
|
CSVLogger,
|
|
@@ -27,7 +28,6 @@ from tensorflow.keras.callbacks import (
|
|
|
27
28
|
ProgbarLogger,
|
|
28
29
|
TensorBoard,
|
|
29
30
|
)
|
|
30
|
-
from tensorflow.keras.optimizers import Optimizer
|
|
31
31
|
|
|
32
32
|
import mlrun
|
|
33
33
|
|
|
@@ -88,9 +88,7 @@ class TFKerasMLRunInterface(MLRunInterface, ABC):
|
|
|
88
88
|
:param restoration: Restoration information tuple as returned from 'remove_interface' in order to
|
|
89
89
|
add the interface in a certain state.
|
|
90
90
|
"""
|
|
91
|
-
super(
|
|
92
|
-
obj=obj, restoration=restoration
|
|
93
|
-
)
|
|
91
|
+
super().add_interface(obj=obj, restoration=restoration)
|
|
94
92
|
|
|
95
93
|
def mlrun_compile(self, *args, **kwargs):
|
|
96
94
|
"""
|
|
@@ -237,7 +235,7 @@ class TFKerasMLRunInterface(MLRunInterface, ABC):
|
|
|
237
235
|
"""
|
|
238
236
|
self._RANK_0_ONLY_CALLBACKS.add(callback_name)
|
|
239
237
|
|
|
240
|
-
def _pre_compile(self, optimizer: Optimizer) ->
|
|
238
|
+
def _pre_compile(self, optimizer: Optimizer) -> tuple[Optimizer, Union[bool, None]]:
|
|
241
239
|
"""
|
|
242
240
|
Method to call before calling 'compile' to setup the run and inputs for using horovod.
|
|
243
241
|
|
|
@@ -295,11 +293,11 @@ class TFKerasMLRunInterface(MLRunInterface, ABC):
|
|
|
295
293
|
|
|
296
294
|
def _pre_fit(
|
|
297
295
|
self,
|
|
298
|
-
callbacks:
|
|
296
|
+
callbacks: list[Callback],
|
|
299
297
|
verbose: int,
|
|
300
298
|
steps_per_epoch: Union[int, None],
|
|
301
299
|
validation_steps: Union[int, None],
|
|
302
|
-
) ->
|
|
300
|
+
) -> tuple[list[Callback], int, Union[int, None], Union[int, None]]:
|
|
303
301
|
"""
|
|
304
302
|
Method to call before calling 'fit' to setup the run and inputs for using horovod.
|
|
305
303
|
|
|
@@ -366,9 +364,9 @@ class TFKerasMLRunInterface(MLRunInterface, ABC):
|
|
|
366
364
|
|
|
367
365
|
def _pre_evaluate(
|
|
368
366
|
self,
|
|
369
|
-
callbacks:
|
|
367
|
+
callbacks: list[Callback],
|
|
370
368
|
steps: Union[int, None],
|
|
371
|
-
) ->
|
|
369
|
+
) -> tuple[list[Callback], Union[int, None]]:
|
|
372
370
|
"""
|
|
373
371
|
Method to call before calling 'evaluate' to setup the run and inputs for using horovod.
|
|
374
372
|
|