mlrun 1.6.4rc8__py3-none-any.whl → 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mlrun might be problematic. Click here for more details.
- mlrun/__init__.py +11 -1
- mlrun/__main__.py +40 -122
- mlrun/alerts/__init__.py +15 -0
- mlrun/alerts/alert.py +248 -0
- mlrun/api/schemas/__init__.py +5 -4
- mlrun/artifacts/__init__.py +8 -3
- mlrun/artifacts/base.py +47 -257
- mlrun/artifacts/dataset.py +11 -192
- mlrun/artifacts/manager.py +79 -47
- mlrun/artifacts/model.py +31 -159
- mlrun/artifacts/plots.py +23 -380
- mlrun/common/constants.py +74 -1
- mlrun/common/db/sql_session.py +5 -5
- mlrun/common/formatters/__init__.py +21 -0
- mlrun/common/formatters/artifact.py +45 -0
- mlrun/common/formatters/base.py +113 -0
- mlrun/common/formatters/feature_set.py +33 -0
- mlrun/common/formatters/function.py +46 -0
- mlrun/common/formatters/pipeline.py +53 -0
- mlrun/common/formatters/project.py +51 -0
- mlrun/common/formatters/run.py +29 -0
- mlrun/common/helpers.py +12 -3
- mlrun/common/model_monitoring/helpers.py +9 -5
- mlrun/{runtimes → common/runtimes}/constants.py +37 -9
- mlrun/common/schemas/__init__.py +31 -5
- mlrun/common/schemas/alert.py +202 -0
- mlrun/common/schemas/api_gateway.py +196 -0
- mlrun/common/schemas/artifact.py +25 -4
- mlrun/common/schemas/auth.py +16 -5
- mlrun/common/schemas/background_task.py +1 -1
- mlrun/common/schemas/client_spec.py +4 -2
- mlrun/common/schemas/common.py +7 -4
- mlrun/common/schemas/constants.py +3 -0
- mlrun/common/schemas/feature_store.py +74 -44
- mlrun/common/schemas/frontend_spec.py +15 -7
- mlrun/common/schemas/function.py +12 -1
- mlrun/common/schemas/hub.py +11 -18
- mlrun/common/schemas/memory_reports.py +2 -2
- mlrun/common/schemas/model_monitoring/__init__.py +20 -4
- mlrun/common/schemas/model_monitoring/constants.py +123 -42
- mlrun/common/schemas/model_monitoring/grafana.py +13 -9
- mlrun/common/schemas/model_monitoring/model_endpoints.py +101 -54
- mlrun/common/schemas/notification.py +71 -14
- mlrun/common/schemas/object.py +2 -2
- mlrun/{model_monitoring/controller_handler.py → common/schemas/pagination.py} +9 -12
- mlrun/common/schemas/pipeline.py +8 -1
- mlrun/common/schemas/project.py +69 -18
- mlrun/common/schemas/runs.py +7 -1
- mlrun/common/schemas/runtime_resource.py +8 -12
- mlrun/common/schemas/schedule.py +4 -4
- mlrun/common/schemas/tag.py +1 -2
- mlrun/common/schemas/workflow.py +12 -4
- mlrun/common/types.py +14 -1
- mlrun/config.py +154 -69
- mlrun/data_types/data_types.py +6 -1
- mlrun/data_types/spark.py +2 -2
- mlrun/data_types/to_pandas.py +67 -37
- mlrun/datastore/__init__.py +6 -8
- mlrun/datastore/alibaba_oss.py +131 -0
- mlrun/datastore/azure_blob.py +143 -42
- mlrun/datastore/base.py +102 -58
- mlrun/datastore/datastore.py +34 -13
- mlrun/datastore/datastore_profile.py +146 -20
- mlrun/datastore/dbfs_store.py +3 -7
- mlrun/datastore/filestore.py +1 -4
- mlrun/datastore/google_cloud_storage.py +97 -33
- mlrun/datastore/hdfs.py +56 -0
- mlrun/datastore/inmem.py +6 -3
- mlrun/datastore/redis.py +7 -2
- mlrun/datastore/s3.py +34 -12
- mlrun/datastore/snowflake_utils.py +45 -0
- mlrun/datastore/sources.py +303 -111
- mlrun/datastore/spark_utils.py +31 -2
- mlrun/datastore/store_resources.py +9 -7
- mlrun/datastore/storeytargets.py +151 -0
- mlrun/datastore/targets.py +453 -176
- mlrun/datastore/utils.py +72 -58
- mlrun/datastore/v3io.py +6 -1
- mlrun/db/base.py +274 -41
- mlrun/db/factory.py +1 -1
- mlrun/db/httpdb.py +893 -225
- mlrun/db/nopdb.py +291 -33
- mlrun/errors.py +36 -6
- mlrun/execution.py +115 -42
- mlrun/feature_store/__init__.py +0 -2
- mlrun/feature_store/api.py +65 -73
- mlrun/feature_store/common.py +7 -12
- mlrun/feature_store/feature_set.py +76 -55
- mlrun/feature_store/feature_vector.py +39 -31
- mlrun/feature_store/ingestion.py +7 -6
- mlrun/feature_store/retrieval/base.py +16 -11
- mlrun/feature_store/retrieval/dask_merger.py +2 -0
- mlrun/feature_store/retrieval/job.py +13 -4
- mlrun/feature_store/retrieval/local_merger.py +2 -0
- mlrun/feature_store/retrieval/spark_merger.py +24 -32
- mlrun/feature_store/steps.py +45 -34
- mlrun/features.py +11 -21
- mlrun/frameworks/_common/artifacts_library.py +9 -9
- mlrun/frameworks/_common/mlrun_interface.py +5 -5
- mlrun/frameworks/_common/model_handler.py +48 -48
- mlrun/frameworks/_common/plan.py +5 -6
- mlrun/frameworks/_common/producer.py +3 -4
- mlrun/frameworks/_common/utils.py +5 -5
- mlrun/frameworks/_dl_common/loggers/logger.py +6 -7
- mlrun/frameworks/_dl_common/loggers/mlrun_logger.py +9 -9
- mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py +23 -47
- mlrun/frameworks/_ml_common/artifacts_library.py +1 -2
- mlrun/frameworks/_ml_common/loggers/logger.py +3 -4
- mlrun/frameworks/_ml_common/loggers/mlrun_logger.py +4 -5
- mlrun/frameworks/_ml_common/model_handler.py +24 -24
- mlrun/frameworks/_ml_common/pkl_model_server.py +2 -2
- mlrun/frameworks/_ml_common/plan.py +2 -2
- mlrun/frameworks/_ml_common/plans/calibration_curve_plan.py +2 -3
- mlrun/frameworks/_ml_common/plans/confusion_matrix_plan.py +2 -3
- mlrun/frameworks/_ml_common/plans/dataset_plan.py +3 -3
- mlrun/frameworks/_ml_common/plans/feature_importance_plan.py +3 -3
- mlrun/frameworks/_ml_common/plans/roc_curve_plan.py +4 -4
- mlrun/frameworks/_ml_common/utils.py +4 -4
- mlrun/frameworks/auto_mlrun/auto_mlrun.py +9 -9
- mlrun/frameworks/huggingface/model_server.py +4 -4
- mlrun/frameworks/lgbm/__init__.py +33 -33
- mlrun/frameworks/lgbm/callbacks/callback.py +2 -4
- mlrun/frameworks/lgbm/callbacks/logging_callback.py +4 -5
- mlrun/frameworks/lgbm/callbacks/mlrun_logging_callback.py +4 -5
- mlrun/frameworks/lgbm/mlrun_interfaces/booster_mlrun_interface.py +1 -3
- mlrun/frameworks/lgbm/mlrun_interfaces/mlrun_interface.py +6 -6
- mlrun/frameworks/lgbm/model_handler.py +10 -10
- mlrun/frameworks/lgbm/model_server.py +6 -6
- mlrun/frameworks/lgbm/utils.py +5 -5
- mlrun/frameworks/onnx/dataset.py +8 -8
- mlrun/frameworks/onnx/mlrun_interface.py +3 -3
- mlrun/frameworks/onnx/model_handler.py +6 -6
- mlrun/frameworks/onnx/model_server.py +7 -7
- mlrun/frameworks/parallel_coordinates.py +6 -6
- mlrun/frameworks/pytorch/__init__.py +18 -18
- mlrun/frameworks/pytorch/callbacks/callback.py +4 -5
- mlrun/frameworks/pytorch/callbacks/logging_callback.py +17 -17
- mlrun/frameworks/pytorch/callbacks/mlrun_logging_callback.py +11 -11
- mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py +23 -29
- mlrun/frameworks/pytorch/callbacks_handler.py +38 -38
- mlrun/frameworks/pytorch/mlrun_interface.py +20 -20
- mlrun/frameworks/pytorch/model_handler.py +17 -17
- mlrun/frameworks/pytorch/model_server.py +7 -7
- mlrun/frameworks/sklearn/__init__.py +13 -13
- mlrun/frameworks/sklearn/estimator.py +4 -4
- mlrun/frameworks/sklearn/metrics_library.py +14 -14
- mlrun/frameworks/sklearn/mlrun_interface.py +16 -9
- mlrun/frameworks/sklearn/model_handler.py +2 -2
- mlrun/frameworks/tf_keras/__init__.py +10 -7
- mlrun/frameworks/tf_keras/callbacks/logging_callback.py +15 -15
- mlrun/frameworks/tf_keras/callbacks/mlrun_logging_callback.py +11 -11
- mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py +19 -23
- mlrun/frameworks/tf_keras/mlrun_interface.py +9 -11
- mlrun/frameworks/tf_keras/model_handler.py +14 -14
- mlrun/frameworks/tf_keras/model_server.py +6 -6
- mlrun/frameworks/xgboost/__init__.py +13 -13
- mlrun/frameworks/xgboost/model_handler.py +6 -6
- mlrun/k8s_utils.py +61 -17
- mlrun/launcher/__init__.py +1 -1
- mlrun/launcher/base.py +16 -15
- mlrun/launcher/client.py +13 -11
- mlrun/launcher/factory.py +1 -1
- mlrun/launcher/local.py +23 -13
- mlrun/launcher/remote.py +17 -10
- mlrun/lists.py +7 -6
- mlrun/model.py +478 -103
- mlrun/model_monitoring/__init__.py +1 -1
- mlrun/model_monitoring/api.py +163 -371
- mlrun/{runtimes/mpijob/v1alpha1.py → model_monitoring/applications/__init__.py} +9 -15
- mlrun/model_monitoring/applications/_application_steps.py +188 -0
- mlrun/model_monitoring/applications/base.py +108 -0
- mlrun/model_monitoring/applications/context.py +341 -0
- mlrun/model_monitoring/{evidently_application.py → applications/evidently_base.py} +27 -22
- mlrun/model_monitoring/applications/histogram_data_drift.py +354 -0
- mlrun/model_monitoring/applications/results.py +99 -0
- mlrun/model_monitoring/controller.py +131 -278
- mlrun/model_monitoring/db/__init__.py +18 -0
- mlrun/model_monitoring/db/stores/__init__.py +136 -0
- mlrun/model_monitoring/db/stores/base/__init__.py +15 -0
- mlrun/model_monitoring/db/stores/base/store.py +213 -0
- mlrun/model_monitoring/db/stores/sqldb/__init__.py +13 -0
- mlrun/model_monitoring/db/stores/sqldb/models/__init__.py +71 -0
- mlrun/model_monitoring/db/stores/sqldb/models/base.py +190 -0
- mlrun/model_monitoring/db/stores/sqldb/models/mysql.py +103 -0
- mlrun/model_monitoring/{stores/models/mysql.py → db/stores/sqldb/models/sqlite.py} +19 -13
- mlrun/model_monitoring/db/stores/sqldb/sql_store.py +659 -0
- mlrun/model_monitoring/db/stores/v3io_kv/__init__.py +13 -0
- mlrun/model_monitoring/db/stores/v3io_kv/kv_store.py +726 -0
- mlrun/model_monitoring/db/tsdb/__init__.py +105 -0
- mlrun/model_monitoring/db/tsdb/base.py +448 -0
- mlrun/model_monitoring/db/tsdb/helpers.py +30 -0
- mlrun/model_monitoring/db/tsdb/tdengine/__init__.py +15 -0
- mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +279 -0
- mlrun/model_monitoring/db/tsdb/tdengine/stream_graph_steps.py +42 -0
- mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +507 -0
- mlrun/model_monitoring/db/tsdb/v3io/__init__.py +15 -0
- mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +158 -0
- mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +849 -0
- mlrun/model_monitoring/features_drift_table.py +134 -106
- mlrun/model_monitoring/helpers.py +199 -55
- mlrun/model_monitoring/metrics/__init__.py +13 -0
- mlrun/model_monitoring/metrics/histogram_distance.py +127 -0
- mlrun/model_monitoring/model_endpoint.py +3 -2
- mlrun/model_monitoring/stream_processing.py +134 -398
- mlrun/model_monitoring/tracking_policy.py +9 -2
- mlrun/model_monitoring/writer.py +161 -125
- mlrun/package/__init__.py +6 -6
- mlrun/package/context_handler.py +5 -5
- mlrun/package/packager.py +7 -7
- mlrun/package/packagers/default_packager.py +8 -8
- mlrun/package/packagers/numpy_packagers.py +15 -15
- mlrun/package/packagers/pandas_packagers.py +5 -5
- mlrun/package/packagers/python_standard_library_packagers.py +10 -10
- mlrun/package/packagers_manager.py +19 -23
- mlrun/package/utils/_formatter.py +6 -6
- mlrun/package/utils/_pickler.py +2 -2
- mlrun/package/utils/_supported_format.py +4 -4
- mlrun/package/utils/log_hint_utils.py +2 -2
- mlrun/package/utils/type_hint_utils.py +4 -9
- mlrun/platforms/__init__.py +11 -10
- mlrun/platforms/iguazio.py +24 -203
- mlrun/projects/operations.py +52 -25
- mlrun/projects/pipelines.py +191 -197
- mlrun/projects/project.py +1227 -400
- mlrun/render.py +16 -19
- mlrun/run.py +209 -184
- mlrun/runtimes/__init__.py +83 -15
- mlrun/runtimes/base.py +51 -35
- mlrun/runtimes/daskjob.py +17 -10
- mlrun/runtimes/databricks_job/databricks_cancel_task.py +1 -1
- mlrun/runtimes/databricks_job/databricks_runtime.py +8 -7
- mlrun/runtimes/databricks_job/databricks_wrapper.py +1 -1
- mlrun/runtimes/funcdoc.py +1 -29
- mlrun/runtimes/function_reference.py +1 -1
- mlrun/runtimes/kubejob.py +34 -128
- mlrun/runtimes/local.py +40 -11
- mlrun/runtimes/mpijob/__init__.py +0 -20
- mlrun/runtimes/mpijob/abstract.py +9 -10
- mlrun/runtimes/mpijob/v1.py +1 -1
- mlrun/{model_monitoring/stores/models/sqlite.py → runtimes/nuclio/__init__.py} +7 -9
- mlrun/runtimes/nuclio/api_gateway.py +769 -0
- mlrun/runtimes/nuclio/application/__init__.py +15 -0
- mlrun/runtimes/nuclio/application/application.py +758 -0
- mlrun/runtimes/nuclio/application/reverse_proxy.go +95 -0
- mlrun/runtimes/{function.py → nuclio/function.py} +200 -83
- mlrun/runtimes/{nuclio.py → nuclio/nuclio.py} +6 -6
- mlrun/runtimes/{serving.py → nuclio/serving.py} +65 -68
- mlrun/runtimes/pod.py +281 -101
- mlrun/runtimes/remotesparkjob.py +12 -9
- mlrun/runtimes/sparkjob/spark3job.py +67 -51
- mlrun/runtimes/utils.py +41 -75
- mlrun/secrets.py +9 -5
- mlrun/serving/__init__.py +8 -1
- mlrun/serving/remote.py +2 -7
- mlrun/serving/routers.py +85 -69
- mlrun/serving/server.py +69 -44
- mlrun/serving/states.py +209 -36
- mlrun/serving/utils.py +22 -14
- mlrun/serving/v1_serving.py +6 -7
- mlrun/serving/v2_serving.py +133 -54
- mlrun/track/tracker.py +2 -1
- mlrun/track/tracker_manager.py +3 -3
- mlrun/track/trackers/mlflow_tracker.py +6 -2
- mlrun/utils/async_http.py +6 -8
- mlrun/utils/azure_vault.py +1 -1
- mlrun/utils/clones.py +1 -2
- mlrun/utils/condition_evaluator.py +3 -3
- mlrun/utils/db.py +21 -3
- mlrun/utils/helpers.py +405 -225
- mlrun/utils/http.py +3 -6
- mlrun/utils/logger.py +112 -16
- mlrun/utils/notifications/notification/__init__.py +17 -13
- mlrun/utils/notifications/notification/base.py +50 -2
- mlrun/utils/notifications/notification/console.py +2 -0
- mlrun/utils/notifications/notification/git.py +24 -1
- mlrun/utils/notifications/notification/ipython.py +3 -1
- mlrun/utils/notifications/notification/slack.py +96 -21
- mlrun/utils/notifications/notification/webhook.py +59 -2
- mlrun/utils/notifications/notification_pusher.py +149 -30
- mlrun/utils/regex.py +9 -0
- mlrun/utils/retryer.py +208 -0
- mlrun/utils/singleton.py +1 -1
- mlrun/utils/v3io_clients.py +4 -6
- mlrun/utils/version/version.json +2 -2
- mlrun/utils/version/version.py +2 -6
- mlrun-1.7.0.dist-info/METADATA +378 -0
- mlrun-1.7.0.dist-info/RECORD +351 -0
- {mlrun-1.6.4rc8.dist-info → mlrun-1.7.0.dist-info}/WHEEL +1 -1
- mlrun/feature_store/retrieval/conversion.py +0 -273
- mlrun/kfpops.py +0 -868
- mlrun/model_monitoring/application.py +0 -310
- mlrun/model_monitoring/batch.py +0 -1095
- mlrun/model_monitoring/prometheus.py +0 -219
- mlrun/model_monitoring/stores/__init__.py +0 -111
- mlrun/model_monitoring/stores/kv_model_endpoint_store.py +0 -576
- mlrun/model_monitoring/stores/model_endpoint_store.py +0 -147
- mlrun/model_monitoring/stores/models/__init__.py +0 -27
- mlrun/model_monitoring/stores/models/base.py +0 -84
- mlrun/model_monitoring/stores/sql_model_endpoint_store.py +0 -384
- mlrun/platforms/other.py +0 -306
- mlrun-1.6.4rc8.dist-info/METADATA +0 -272
- mlrun-1.6.4rc8.dist-info/RECORD +0 -314
- {mlrun-1.6.4rc8.dist-info → mlrun-1.7.0.dist-info}/LICENSE +0 -0
- {mlrun-1.6.4rc8.dist-info → mlrun-1.7.0.dist-info}/entry_points.txt +0 -0
- {mlrun-1.6.4rc8.dist-info → mlrun-1.7.0.dist-info}/top_level.txt +0 -0
|
@@ -14,12 +14,11 @@
|
|
|
14
14
|
#
|
|
15
15
|
import datetime
|
|
16
16
|
import os
|
|
17
|
-
from typing import
|
|
17
|
+
from typing import Union
|
|
18
18
|
|
|
19
19
|
import numpy as np
|
|
20
20
|
import pandas as pd
|
|
21
|
-
from IPython.
|
|
22
|
-
from IPython.display import display
|
|
21
|
+
from IPython.display import HTML, display
|
|
23
22
|
from pandas.api.types import is_numeric_dtype, is_string_dtype
|
|
24
23
|
|
|
25
24
|
import mlrun
|
|
@@ -216,7 +215,7 @@ def _show_and_export_html(html: str, show=None, filename=None, runs_list=None):
|
|
|
216
215
|
fp.write("</body></html>")
|
|
217
216
|
else:
|
|
218
217
|
fp.write(html)
|
|
219
|
-
if show or (show is None and mlrun.utils.
|
|
218
|
+
if show or (show is None and mlrun.utils.is_jupyter):
|
|
220
219
|
display(HTML(html))
|
|
221
220
|
if runs_list and len(runs_list) <= max_table_rows:
|
|
222
221
|
display(HTML(html_table))
|
|
@@ -239,7 +238,7 @@ def _runs_list_to_df(runs_list, extend_iterations=False):
|
|
|
239
238
|
|
|
240
239
|
@filter_warnings("ignore", FutureWarning)
|
|
241
240
|
def compare_run_objects(
|
|
242
|
-
runs_list: Union[mlrun.model.RunObject,
|
|
241
|
+
runs_list: Union[mlrun.model.RunObject, list[mlrun.model.RunObject]],
|
|
243
242
|
hide_identical: bool = True,
|
|
244
243
|
exclude: list = None,
|
|
245
244
|
show: bool = None,
|
|
@@ -295,7 +294,7 @@ def compare_db_runs(
|
|
|
295
294
|
iter=False,
|
|
296
295
|
start_time_from: datetime = None,
|
|
297
296
|
hide_identical: bool = True,
|
|
298
|
-
exclude: list =
|
|
297
|
+
exclude: list = None,
|
|
299
298
|
show=None,
|
|
300
299
|
colorscale: str = "Blues",
|
|
301
300
|
filename=None,
|
|
@@ -332,6 +331,7 @@ def compare_db_runs(
|
|
|
332
331
|
**query_args,
|
|
333
332
|
)
|
|
334
333
|
|
|
334
|
+
exclude = exclude or []
|
|
335
335
|
runs_df = _runs_list_to_df(runs_list)
|
|
336
336
|
plot_as_html = gen_pcp_plot(
|
|
337
337
|
runs_df,
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
15
|
# flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx
|
|
16
|
-
from typing import Any,
|
|
16
|
+
from typing import Any, Union
|
|
17
17
|
|
|
18
18
|
from torch.nn import Module
|
|
19
19
|
from torch.optim import Optimizer
|
|
@@ -35,23 +35,23 @@ def train(
|
|
|
35
35
|
loss_function: Module,
|
|
36
36
|
optimizer: Optimizer,
|
|
37
37
|
validation_set: DataLoader = None,
|
|
38
|
-
metric_functions:
|
|
38
|
+
metric_functions: list[PyTorchTypes.MetricFunctionType] = None,
|
|
39
39
|
scheduler=None,
|
|
40
40
|
scheduler_step_frequency: Union[int, float, str] = "epoch",
|
|
41
41
|
epochs: int = 1,
|
|
42
42
|
training_iterations: int = None,
|
|
43
43
|
validation_iterations: int = None,
|
|
44
|
-
callbacks_list:
|
|
44
|
+
callbacks_list: list[Callback] = None,
|
|
45
45
|
use_cuda: bool = True,
|
|
46
46
|
use_horovod: bool = None,
|
|
47
47
|
auto_log: bool = True,
|
|
48
48
|
model_name: str = None,
|
|
49
|
-
modules_map: Union[
|
|
50
|
-
custom_objects_map: Union[
|
|
49
|
+
modules_map: Union[dict[str, Union[None, str, list[str]]], str] = None,
|
|
50
|
+
custom_objects_map: Union[dict[str, Union[str, list[str]]], str] = None,
|
|
51
51
|
custom_objects_directory: str = None,
|
|
52
52
|
tensorboard_directory: str = None,
|
|
53
|
-
mlrun_callback_kwargs:
|
|
54
|
-
tensorboard_callback_kwargs:
|
|
53
|
+
mlrun_callback_kwargs: dict[str, Any] = None,
|
|
54
|
+
tensorboard_callback_kwargs: dict[str, Any] = None,
|
|
55
55
|
context: mlrun.MLClientCtx = None,
|
|
56
56
|
) -> PyTorchModelHandler:
|
|
57
57
|
"""
|
|
@@ -112,7 +112,7 @@ def train(
|
|
|
112
112
|
|
|
113
113
|
{
|
|
114
114
|
"/.../custom_optimizer.py": "optimizer",
|
|
115
|
-
"/.../custom_layers.py": ["layer1", "layer2"]
|
|
115
|
+
"/.../custom_layers.py": ["layer1", "layer2"],
|
|
116
116
|
}
|
|
117
117
|
|
|
118
118
|
All the paths will be accessed from the given 'custom_objects_directory',
|
|
@@ -205,19 +205,19 @@ def evaluate(
|
|
|
205
205
|
dataset: DataLoader,
|
|
206
206
|
model: Module = None,
|
|
207
207
|
loss_function: Module = None,
|
|
208
|
-
metric_functions:
|
|
208
|
+
metric_functions: list[PyTorchTypes.MetricFunctionType] = None,
|
|
209
209
|
iterations: int = None,
|
|
210
|
-
callbacks_list:
|
|
210
|
+
callbacks_list: list[Callback] = None,
|
|
211
211
|
use_cuda: bool = True,
|
|
212
212
|
use_horovod: bool = False,
|
|
213
213
|
auto_log: bool = True,
|
|
214
214
|
model_name: str = None,
|
|
215
|
-
modules_map: Union[
|
|
216
|
-
custom_objects_map: Union[
|
|
215
|
+
modules_map: Union[dict[str, Union[None, str, list[str]]], str] = None,
|
|
216
|
+
custom_objects_map: Union[dict[str, Union[str, list[str]]], str] = None,
|
|
217
217
|
custom_objects_directory: str = None,
|
|
218
|
-
mlrun_callback_kwargs:
|
|
218
|
+
mlrun_callback_kwargs: dict[str, Any] = None,
|
|
219
219
|
context: mlrun.MLClientCtx = None,
|
|
220
|
-
) ->
|
|
220
|
+
) -> tuple[PyTorchModelHandler, list[PyTorchTypes.MetricValueType]]:
|
|
221
221
|
"""
|
|
222
222
|
Use MLRun's PyTorch interface to evaluate the model with the given parameters. For more information and further
|
|
223
223
|
options regarding the auto logging, see 'PyTorchMLRunInterface' documentation. Notice for auto-logging: In order to
|
|
@@ -264,7 +264,7 @@ def evaluate(
|
|
|
264
264
|
|
|
265
265
|
{
|
|
266
266
|
"/.../custom_optimizer.py": "optimizer",
|
|
267
|
-
"/.../custom_layers.py": ["layer1", "layer2"]
|
|
267
|
+
"/.../custom_layers.py": ["layer1", "layer2"],
|
|
268
268
|
}
|
|
269
269
|
|
|
270
270
|
All the paths will be accessed from the given 'custom_objects_directory', meaning
|
|
@@ -343,9 +343,9 @@ def evaluate(
|
|
|
343
343
|
def _parse_callbacks_kwargs(
|
|
344
344
|
handler: PyTorchModelHandler,
|
|
345
345
|
tensorboard_directory: Union[str, None],
|
|
346
|
-
mlrun_callback_kwargs: Union[
|
|
347
|
-
tensorboard_callback_kwargs: Union[
|
|
348
|
-
) ->
|
|
346
|
+
mlrun_callback_kwargs: Union[dict[str, Any], None],
|
|
347
|
+
tensorboard_callback_kwargs: Union[dict[str, Any], None],
|
|
348
|
+
) -> tuple[dict, dict]:
|
|
349
349
|
"""
|
|
350
350
|
Parse the given parameters into the MLRun and Tensorboard callbacks kwargs.
|
|
351
351
|
|
|
@@ -13,7 +13,6 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
15
|
from abc import ABC, abstractmethod
|
|
16
|
-
from typing import List
|
|
17
16
|
|
|
18
17
|
from torch import Tensor
|
|
19
18
|
from torch.nn import Module
|
|
@@ -68,7 +67,7 @@ class Callback(ABC):
|
|
|
68
67
|
validation_set: DataLoader = None,
|
|
69
68
|
loss_function: Module = None,
|
|
70
69
|
optimizer: Optimizer = None,
|
|
71
|
-
metric_functions:
|
|
70
|
+
metric_functions: list[PyTorchTypes.MetricFunctionType] = None,
|
|
72
71
|
scheduler=None,
|
|
73
72
|
):
|
|
74
73
|
"""
|
|
@@ -141,7 +140,7 @@ class Callback(ABC):
|
|
|
141
140
|
pass
|
|
142
141
|
|
|
143
142
|
def on_validation_end(
|
|
144
|
-
self, loss_value: PyTorchTypes.MetricValueType, metric_values:
|
|
143
|
+
self, loss_value: PyTorchTypes.MetricValueType, metric_values: list[float]
|
|
145
144
|
) -> bool:
|
|
146
145
|
"""
|
|
147
146
|
Before the validation (in a training case it will be per epoch) ends, this method will be called.
|
|
@@ -258,7 +257,7 @@ class Callback(ABC):
|
|
|
258
257
|
"""
|
|
259
258
|
pass
|
|
260
259
|
|
|
261
|
-
def on_train_metrics_end(self, metric_values:
|
|
260
|
+
def on_train_metrics_end(self, metric_values: list[PyTorchTypes.MetricValueType]):
|
|
262
261
|
"""
|
|
263
262
|
After the training calculation of the metrics, this method will be called.
|
|
264
263
|
|
|
@@ -273,7 +272,7 @@ class Callback(ABC):
|
|
|
273
272
|
pass
|
|
274
273
|
|
|
275
274
|
def on_validation_metrics_end(
|
|
276
|
-
self, metric_values:
|
|
275
|
+
self, metric_values: list[PyTorchTypes.MetricValueType]
|
|
277
276
|
):
|
|
278
277
|
"""
|
|
279
278
|
After the validating calculation of the metrics, this method will be called.
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
|
-
from typing import Callable,
|
|
15
|
+
from typing import Callable, Union
|
|
16
16
|
|
|
17
17
|
import numpy as np
|
|
18
18
|
from torch import Tensor
|
|
@@ -59,15 +59,15 @@ class LoggingCallback(Callback):
|
|
|
59
59
|
def __init__(
|
|
60
60
|
self,
|
|
61
61
|
context: mlrun.MLClientCtx = None,
|
|
62
|
-
dynamic_hyperparameters:
|
|
62
|
+
dynamic_hyperparameters: dict[
|
|
63
63
|
str,
|
|
64
|
-
|
|
64
|
+
tuple[
|
|
65
65
|
str,
|
|
66
|
-
Union[
|
|
66
|
+
Union[list[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]],
|
|
67
67
|
],
|
|
68
68
|
] = None,
|
|
69
|
-
static_hyperparameters:
|
|
70
|
-
str, Union[PyTorchTypes.TrackableType,
|
|
69
|
+
static_hyperparameters: dict[
|
|
70
|
+
str, Union[PyTorchTypes.TrackableType, tuple[str, list[Union[str, int]]]]
|
|
71
71
|
] = None,
|
|
72
72
|
auto_log: bool = False,
|
|
73
73
|
):
|
|
@@ -100,7 +100,7 @@ class LoggingCallback(Callback):
|
|
|
100
100
|
:param auto_log: Whether or not to enable auto logging, trying to track common static and dynamic
|
|
101
101
|
hyperparameters.
|
|
102
102
|
"""
|
|
103
|
-
super(
|
|
103
|
+
super().__init__()
|
|
104
104
|
|
|
105
105
|
# Store the configurations:
|
|
106
106
|
self._dynamic_hyperparameters_keys = (
|
|
@@ -117,7 +117,7 @@ class LoggingCallback(Callback):
|
|
|
117
117
|
self._is_training = None # type: bool
|
|
118
118
|
self._auto_log = auto_log
|
|
119
119
|
|
|
120
|
-
def get_training_results(self) ->
|
|
120
|
+
def get_training_results(self) -> dict[str, list[list[float]]]:
|
|
121
121
|
"""
|
|
122
122
|
Get the training results logged. The results will be stored in a dictionary where each key is the metric name
|
|
123
123
|
and the value is a list of lists of values. The first list is by epoch and the second list is by iteration
|
|
@@ -127,7 +127,7 @@ class LoggingCallback(Callback):
|
|
|
127
127
|
"""
|
|
128
128
|
return self._logger.training_results
|
|
129
129
|
|
|
130
|
-
def get_validation_results(self) ->
|
|
130
|
+
def get_validation_results(self) -> dict[str, list[list[float]]]:
|
|
131
131
|
"""
|
|
132
132
|
Get the validation results logged. The results will be stored in a dictionary where each key is the metric name
|
|
133
133
|
and the value is a list of lists of values. The first list is by epoch and the second list is by iteration
|
|
@@ -137,7 +137,7 @@ class LoggingCallback(Callback):
|
|
|
137
137
|
"""
|
|
138
138
|
return self._logger.validation_results
|
|
139
139
|
|
|
140
|
-
def get_static_hyperparameters(self) ->
|
|
140
|
+
def get_static_hyperparameters(self) -> dict[str, PyTorchTypes.TrackableType]:
|
|
141
141
|
"""
|
|
142
142
|
Get the static hyperparameters logged. The hyperparameters will be stored in a dictionary where each key is the
|
|
143
143
|
hyperparameter name and the value is his logged value.
|
|
@@ -148,7 +148,7 @@ class LoggingCallback(Callback):
|
|
|
148
148
|
|
|
149
149
|
def get_dynamic_hyperparameters(
|
|
150
150
|
self,
|
|
151
|
-
) ->
|
|
151
|
+
) -> dict[str, list[PyTorchTypes.TrackableType]]:
|
|
152
152
|
"""
|
|
153
153
|
Get the dynamic hyperparameters logged. The hyperparameters will be stored in a dictionary where each key is the
|
|
154
154
|
hyperparameter name and the value is a list of his logged values per epoch.
|
|
@@ -157,7 +157,7 @@ class LoggingCallback(Callback):
|
|
|
157
157
|
"""
|
|
158
158
|
return self._logger.dynamic_hyperparameters
|
|
159
159
|
|
|
160
|
-
def get_summaries(self) ->
|
|
160
|
+
def get_summaries(self) -> dict[str, list[float]]:
|
|
161
161
|
"""
|
|
162
162
|
Get the validation summaries of the metrics results. The summaries will be stored in a dictionary where each key
|
|
163
163
|
is the metric names and the value is a list of all the summary values per epoch.
|
|
@@ -210,7 +210,7 @@ class LoggingCallback(Callback):
|
|
|
210
210
|
self._add_auto_hyperparameters()
|
|
211
211
|
# # Static hyperparameters:
|
|
212
212
|
for name, value in self._static_hyperparameters_keys.items():
|
|
213
|
-
if isinstance(value,
|
|
213
|
+
if isinstance(value, tuple):
|
|
214
214
|
# Its a parameter that needed to be extracted via key chain.
|
|
215
215
|
self._logger.log_static_hyperparameter(
|
|
216
216
|
parameter_name=name,
|
|
@@ -294,7 +294,7 @@ class LoggingCallback(Callback):
|
|
|
294
294
|
self._logger.set_mode(mode=LoggingMode.EVALUATION)
|
|
295
295
|
|
|
296
296
|
def on_validation_end(
|
|
297
|
-
self, loss_value: PyTorchTypes.MetricValueType, metric_values:
|
|
297
|
+
self, loss_value: PyTorchTypes.MetricValueType, metric_values: list[float]
|
|
298
298
|
):
|
|
299
299
|
"""
|
|
300
300
|
Before the validation (in a training case it will be per epoch) ends, this method will be called to log the
|
|
@@ -372,7 +372,7 @@ class LoggingCallback(Callback):
|
|
|
372
372
|
result=float(loss_value),
|
|
373
373
|
)
|
|
374
374
|
|
|
375
|
-
def on_train_metrics_end(self, metric_values:
|
|
375
|
+
def on_train_metrics_end(self, metric_values: list[PyTorchTypes.MetricValueType]):
|
|
376
376
|
"""
|
|
377
377
|
After the training calculation of the metrics, this method will be called to log the metrics values.
|
|
378
378
|
|
|
@@ -389,7 +389,7 @@ class LoggingCallback(Callback):
|
|
|
389
389
|
)
|
|
390
390
|
|
|
391
391
|
def on_validation_metrics_end(
|
|
392
|
-
self, metric_values:
|
|
392
|
+
self, metric_values: list[PyTorchTypes.MetricValueType]
|
|
393
393
|
):
|
|
394
394
|
"""
|
|
395
395
|
After the validating calculation of the metrics, this method will be called to log the metrics values.
|
|
@@ -456,7 +456,7 @@ class LoggingCallback(Callback):
|
|
|
456
456
|
self,
|
|
457
457
|
source: str,
|
|
458
458
|
key_chain: Union[
|
|
459
|
-
|
|
459
|
+
list[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]
|
|
460
460
|
],
|
|
461
461
|
) -> PyTorchTypes.TrackableType:
|
|
462
462
|
"""
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
|
-
from typing import Callable,
|
|
15
|
+
from typing import Callable, Union
|
|
16
16
|
|
|
17
17
|
import torch
|
|
18
18
|
from torch import Tensor
|
|
@@ -53,20 +53,20 @@ class MLRunLoggingCallback(LoggingCallback):
|
|
|
53
53
|
context: mlrun.MLClientCtx,
|
|
54
54
|
model_handler: PyTorchModelHandler,
|
|
55
55
|
log_model_tag: str = "",
|
|
56
|
-
log_model_labels:
|
|
57
|
-
log_model_parameters:
|
|
58
|
-
log_model_extra_data:
|
|
56
|
+
log_model_labels: dict[str, PyTorchTypes.TrackableType] = None,
|
|
57
|
+
log_model_parameters: dict[str, PyTorchTypes.TrackableType] = None,
|
|
58
|
+
log_model_extra_data: dict[
|
|
59
59
|
str, Union[PyTorchTypes.TrackableType, Artifact]
|
|
60
60
|
] = None,
|
|
61
|
-
dynamic_hyperparameters:
|
|
61
|
+
dynamic_hyperparameters: dict[
|
|
62
62
|
str,
|
|
63
|
-
|
|
63
|
+
tuple[
|
|
64
64
|
str,
|
|
65
|
-
Union[
|
|
65
|
+
Union[list[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]],
|
|
66
66
|
],
|
|
67
67
|
] = None,
|
|
68
|
-
static_hyperparameters:
|
|
69
|
-
str, Union[PyTorchTypes.TrackableType,
|
|
68
|
+
static_hyperparameters: dict[
|
|
69
|
+
str, Union[PyTorchTypes.TrackableType, tuple[str, list[Union[str, int]]]]
|
|
70
70
|
] = None,
|
|
71
71
|
auto_log: bool = False,
|
|
72
72
|
):
|
|
@@ -107,7 +107,7 @@ class MLRunLoggingCallback(LoggingCallback):
|
|
|
107
107
|
:param auto_log: Whether or not to enable auto logging for logging the context parameters and
|
|
108
108
|
trying to track common static and dynamic hyperparameters.
|
|
109
109
|
"""
|
|
110
|
-
super(
|
|
110
|
+
super().__init__(
|
|
111
111
|
dynamic_hyperparameters=dynamic_hyperparameters,
|
|
112
112
|
static_hyperparameters=static_hyperparameters,
|
|
113
113
|
auto_log=auto_log,
|
|
@@ -160,7 +160,7 @@ class MLRunLoggingCallback(LoggingCallback):
|
|
|
160
160
|
|
|
161
161
|
:param epoch: The epoch that has just ended.
|
|
162
162
|
"""
|
|
163
|
-
super(
|
|
163
|
+
super().on_epoch_end(epoch=epoch)
|
|
164
164
|
|
|
165
165
|
# Create child context to hold the current epoch's results:
|
|
166
166
|
self._logger.log_epoch_to_context(epoch=epoch)
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
15
|
from datetime import datetime
|
|
16
|
-
from typing import Callable,
|
|
16
|
+
from typing import Callable, Union
|
|
17
17
|
|
|
18
18
|
import torch
|
|
19
19
|
from torch import Tensor
|
|
@@ -63,7 +63,7 @@ class _PyTorchTensorboardLogger(TensorboardLogger):
|
|
|
63
63
|
|
|
64
64
|
def __init__(
|
|
65
65
|
self,
|
|
66
|
-
statistics_functions:
|
|
66
|
+
statistics_functions: list[
|
|
67
67
|
Callable[[Union[Parameter]], Union[float, Parameter]]
|
|
68
68
|
],
|
|
69
69
|
context: mlrun.MLClientCtx = None,
|
|
@@ -94,7 +94,7 @@ class _PyTorchTensorboardLogger(TensorboardLogger):
|
|
|
94
94
|
update. Notice that writing to tensorboard too frequently may cause the training
|
|
95
95
|
to be slower. Default: 'epoch'.
|
|
96
96
|
"""
|
|
97
|
-
super(
|
|
97
|
+
super().__init__(
|
|
98
98
|
statistics_functions=statistics_functions,
|
|
99
99
|
context=context,
|
|
100
100
|
tensorboard_directory=tensorboard_directory,
|
|
@@ -249,19 +249,19 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
249
249
|
context: mlrun.MLClientCtx = None,
|
|
250
250
|
tensorboard_directory: str = None,
|
|
251
251
|
run_name: str = None,
|
|
252
|
-
weights: Union[bool,
|
|
253
|
-
statistics_functions:
|
|
252
|
+
weights: Union[bool, list[str]] = False,
|
|
253
|
+
statistics_functions: list[
|
|
254
254
|
Callable[[Union[Parameter, Tensor]], Union[float, Tensor]]
|
|
255
255
|
] = None,
|
|
256
|
-
dynamic_hyperparameters:
|
|
256
|
+
dynamic_hyperparameters: dict[
|
|
257
257
|
str,
|
|
258
|
-
|
|
258
|
+
tuple[
|
|
259
259
|
str,
|
|
260
|
-
Union[
|
|
260
|
+
Union[list[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]],
|
|
261
261
|
],
|
|
262
262
|
] = None,
|
|
263
|
-
static_hyperparameters:
|
|
264
|
-
str, Union[PyTorchTypes.TrackableType,
|
|
263
|
+
static_hyperparameters: dict[
|
|
264
|
+
str, Union[PyTorchTypes.TrackableType, tuple[str, list[Union[str, int]]]]
|
|
265
265
|
] = None,
|
|
266
266
|
update_frequency: Union[int, str] = "epoch",
|
|
267
267
|
auto_log: bool = False,
|
|
@@ -322,7 +322,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
322
322
|
:raise MLRunInvalidArgumentError: In case both 'context' and 'tensorboard_directory' parameters were not given
|
|
323
323
|
or the 'update_frequency' was incorrect.
|
|
324
324
|
"""
|
|
325
|
-
super(
|
|
325
|
+
super().__init__(
|
|
326
326
|
dynamic_hyperparameters=dynamic_hyperparameters,
|
|
327
327
|
static_hyperparameters=static_hyperparameters,
|
|
328
328
|
auto_log=auto_log,
|
|
@@ -345,7 +345,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
345
345
|
# Save the configurations:
|
|
346
346
|
self._tracked_weights = weights
|
|
347
347
|
|
|
348
|
-
def get_weights(self) ->
|
|
348
|
+
def get_weights(self) -> dict[str, Parameter]:
|
|
349
349
|
"""
|
|
350
350
|
Get the weights tensors tracked. The weights will be stored in a dictionary where each key is the weight's name
|
|
351
351
|
and the value is the weight's parameter (tensor).
|
|
@@ -354,7 +354,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
354
354
|
"""
|
|
355
355
|
return self._logger.weights
|
|
356
356
|
|
|
357
|
-
def get_weights_statistics(self) ->
|
|
357
|
+
def get_weights_statistics(self) -> dict[str, dict[str, list[float]]]:
|
|
358
358
|
"""
|
|
359
359
|
Get the weights mean results logged. The results will be stored in a dictionary where each key is the weight's
|
|
360
360
|
name and the value is a list of mean values per epoch.
|
|
@@ -365,7 +365,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
365
365
|
|
|
366
366
|
@staticmethod
|
|
367
367
|
def get_default_weight_statistics_list() -> (
|
|
368
|
-
|
|
368
|
+
list[Callable[[Union[Parameter, Tensor]], Union[float, Tensor]]]
|
|
369
369
|
):
|
|
370
370
|
"""
|
|
371
371
|
Get the default list of statistics functions being applied on the tracked weights each epoch.
|
|
@@ -381,7 +381,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
381
381
|
validation_set: DataLoader = None,
|
|
382
382
|
loss_function: Module = None,
|
|
383
383
|
optimizer: Optimizer = None,
|
|
384
|
-
metric_functions:
|
|
384
|
+
metric_functions: list[PyTorchTypes.MetricFunctionType] = None,
|
|
385
385
|
scheduler=None,
|
|
386
386
|
):
|
|
387
387
|
"""
|
|
@@ -396,7 +396,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
396
396
|
:param metric_functions: The metric functions to be stored in this callback.
|
|
397
397
|
:param scheduler: The scheduler to be stored in this callback.
|
|
398
398
|
"""
|
|
399
|
-
super(
|
|
399
|
+
super().on_setup(
|
|
400
400
|
model=model,
|
|
401
401
|
training_set=training_set,
|
|
402
402
|
validation_set=validation_set,
|
|
@@ -439,7 +439,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
439
439
|
for logging. Epoch 0 (pre-run state) will be logged here.
|
|
440
440
|
"""
|
|
441
441
|
# Setup all the results and hyperparameters dictionaries:
|
|
442
|
-
super(
|
|
442
|
+
super().on_run_begin()
|
|
443
443
|
|
|
444
444
|
# Log the initial summary of the run:
|
|
445
445
|
self._logger.write_initial_summary_text()
|
|
@@ -470,10 +470,10 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
470
470
|
# Write the final summary of the run:
|
|
471
471
|
self._logger.write_final_summary_text()
|
|
472
472
|
|
|
473
|
-
super(
|
|
473
|
+
super().on_run_end()
|
|
474
474
|
|
|
475
475
|
def on_validation_end(
|
|
476
|
-
self, loss_value: PyTorchTypes.MetricValueType, metric_values:
|
|
476
|
+
self, loss_value: PyTorchTypes.MetricValueType, metric_values: list[float]
|
|
477
477
|
):
|
|
478
478
|
"""
|
|
479
479
|
Before the validation (in a training case it will be per epoch) ends, this method will be called to log the
|
|
@@ -482,9 +482,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
482
482
|
:param loss_value: The loss summary of this validation.
|
|
483
483
|
:param metric_values: The metrics summaries of this validation.
|
|
484
484
|
"""
|
|
485
|
-
super(
|
|
486
|
-
loss_value=loss_value, metric_values=metric_values
|
|
487
|
-
)
|
|
485
|
+
super().on_validation_end(loss_value=loss_value, metric_values=metric_values)
|
|
488
486
|
|
|
489
487
|
# Check if this run was part of an evaluation:
|
|
490
488
|
if not self._is_training:
|
|
@@ -503,7 +501,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
503
501
|
|
|
504
502
|
:param epoch: The epoch that has just ended.
|
|
505
503
|
"""
|
|
506
|
-
super(
|
|
504
|
+
super().on_epoch_end(epoch=epoch)
|
|
507
505
|
|
|
508
506
|
# Log the weights statistics:
|
|
509
507
|
self._logger.log_weights_statistics()
|
|
@@ -540,9 +538,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
540
538
|
:param y_true: The true value part of the current batch.
|
|
541
539
|
:param y_pred: The prediction (output) of the model for this batch's input ('x').
|
|
542
540
|
"""
|
|
543
|
-
super(
|
|
544
|
-
batch=batch, x=x, y_true=y_true, y_pred=y_pred
|
|
545
|
-
)
|
|
541
|
+
super().on_train_batch_end(batch=batch, x=x, y_true=y_true, y_pred=y_pred)
|
|
546
542
|
|
|
547
543
|
# Write the batch loss and metrics results to their graphs:
|
|
548
544
|
self._logger.write_training_results()
|
|
@@ -559,9 +555,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
559
555
|
:param y_true: The true value part of the current batch.
|
|
560
556
|
:param y_pred: The prediction (output) of the model for this batch's input ('x').
|
|
561
557
|
"""
|
|
562
|
-
super(
|
|
563
|
-
batch=batch, x=x, y_true=y_true, y_pred=y_pred
|
|
564
|
-
)
|
|
558
|
+
super().on_validation_batch_end(batch=batch, x=x, y_true=y_true, y_pred=y_pred)
|
|
565
559
|
|
|
566
560
|
# Write the batch loss and metrics results to their graphs:
|
|
567
561
|
self._logger.write_validation_results()
|