mlrun 1.6.0rc35__py3-none-any.whl → 1.7.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mlrun might be problematic. Click here for more details.
- mlrun/__main__.py +3 -3
- mlrun/api/schemas/__init__.py +1 -1
- mlrun/artifacts/base.py +11 -6
- mlrun/artifacts/dataset.py +2 -2
- mlrun/artifacts/model.py +30 -24
- mlrun/artifacts/plots.py +2 -2
- mlrun/common/db/sql_session.py +5 -3
- mlrun/common/helpers.py +1 -2
- mlrun/common/schemas/artifact.py +3 -3
- mlrun/common/schemas/auth.py +3 -3
- mlrun/common/schemas/background_task.py +1 -1
- mlrun/common/schemas/client_spec.py +1 -1
- mlrun/common/schemas/feature_store.py +16 -16
- mlrun/common/schemas/frontend_spec.py +7 -7
- mlrun/common/schemas/function.py +1 -1
- mlrun/common/schemas/hub.py +4 -9
- mlrun/common/schemas/memory_reports.py +2 -2
- mlrun/common/schemas/model_monitoring/grafana.py +4 -4
- mlrun/common/schemas/model_monitoring/model_endpoints.py +14 -15
- mlrun/common/schemas/notification.py +4 -4
- mlrun/common/schemas/object.py +2 -2
- mlrun/common/schemas/pipeline.py +1 -1
- mlrun/common/schemas/project.py +3 -3
- mlrun/common/schemas/runtime_resource.py +8 -12
- mlrun/common/schemas/schedule.py +3 -3
- mlrun/common/schemas/tag.py +1 -2
- mlrun/common/schemas/workflow.py +2 -2
- mlrun/config.py +8 -4
- mlrun/data_types/to_pandas.py +1 -3
- mlrun/datastore/base.py +0 -28
- mlrun/datastore/datastore_profile.py +9 -9
- mlrun/datastore/filestore.py +0 -1
- mlrun/datastore/google_cloud_storage.py +1 -1
- mlrun/datastore/sources.py +7 -11
- mlrun/datastore/spark_utils.py +1 -2
- mlrun/datastore/targets.py +31 -31
- mlrun/datastore/utils.py +4 -6
- mlrun/datastore/v3io.py +70 -46
- mlrun/db/base.py +22 -23
- mlrun/db/httpdb.py +34 -34
- mlrun/db/nopdb.py +19 -19
- mlrun/errors.py +1 -1
- mlrun/execution.py +4 -4
- mlrun/feature_store/api.py +20 -21
- mlrun/feature_store/common.py +1 -1
- mlrun/feature_store/feature_set.py +28 -32
- mlrun/feature_store/feature_vector.py +24 -27
- mlrun/feature_store/retrieval/base.py +7 -7
- mlrun/feature_store/retrieval/conversion.py +2 -4
- mlrun/feature_store/steps.py +7 -15
- mlrun/features.py +5 -7
- mlrun/frameworks/_common/artifacts_library.py +9 -9
- mlrun/frameworks/_common/mlrun_interface.py +5 -5
- mlrun/frameworks/_common/model_handler.py +48 -48
- mlrun/frameworks/_common/plan.py +2 -3
- mlrun/frameworks/_common/producer.py +3 -4
- mlrun/frameworks/_common/utils.py +5 -5
- mlrun/frameworks/_dl_common/loggers/logger.py +6 -7
- mlrun/frameworks/_dl_common/loggers/mlrun_logger.py +9 -9
- mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py +16 -35
- mlrun/frameworks/_ml_common/artifacts_library.py +1 -2
- mlrun/frameworks/_ml_common/loggers/logger.py +3 -4
- mlrun/frameworks/_ml_common/loggers/mlrun_logger.py +4 -5
- mlrun/frameworks/_ml_common/model_handler.py +24 -24
- mlrun/frameworks/_ml_common/pkl_model_server.py +2 -2
- mlrun/frameworks/_ml_common/plan.py +1 -1
- mlrun/frameworks/_ml_common/plans/calibration_curve_plan.py +2 -3
- mlrun/frameworks/_ml_common/plans/confusion_matrix_plan.py +2 -3
- mlrun/frameworks/_ml_common/plans/dataset_plan.py +3 -3
- mlrun/frameworks/_ml_common/plans/feature_importance_plan.py +3 -3
- mlrun/frameworks/_ml_common/plans/roc_curve_plan.py +4 -4
- mlrun/frameworks/_ml_common/utils.py +4 -4
- mlrun/frameworks/auto_mlrun/auto_mlrun.py +7 -7
- mlrun/frameworks/huggingface/model_server.py +4 -4
- mlrun/frameworks/lgbm/__init__.py +32 -32
- mlrun/frameworks/lgbm/callbacks/logging_callback.py +4 -5
- mlrun/frameworks/lgbm/callbacks/mlrun_logging_callback.py +4 -5
- mlrun/frameworks/lgbm/mlrun_interfaces/booster_mlrun_interface.py +1 -3
- mlrun/frameworks/lgbm/mlrun_interfaces/mlrun_interface.py +6 -6
- mlrun/frameworks/lgbm/model_handler.py +9 -9
- mlrun/frameworks/lgbm/model_server.py +6 -6
- mlrun/frameworks/lgbm/utils.py +5 -5
- mlrun/frameworks/onnx/dataset.py +8 -8
- mlrun/frameworks/onnx/mlrun_interface.py +3 -3
- mlrun/frameworks/onnx/model_handler.py +6 -6
- mlrun/frameworks/onnx/model_server.py +7 -7
- mlrun/frameworks/parallel_coordinates.py +2 -2
- mlrun/frameworks/pytorch/__init__.py +16 -16
- mlrun/frameworks/pytorch/callbacks/callback.py +4 -5
- mlrun/frameworks/pytorch/callbacks/logging_callback.py +17 -17
- mlrun/frameworks/pytorch/callbacks/mlrun_logging_callback.py +11 -11
- mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py +23 -29
- mlrun/frameworks/pytorch/callbacks_handler.py +38 -38
- mlrun/frameworks/pytorch/mlrun_interface.py +20 -20
- mlrun/frameworks/pytorch/model_handler.py +17 -17
- mlrun/frameworks/pytorch/model_server.py +7 -7
- mlrun/frameworks/sklearn/__init__.py +12 -12
- mlrun/frameworks/sklearn/estimator.py +4 -4
- mlrun/frameworks/sklearn/metrics_library.py +14 -14
- mlrun/frameworks/sklearn/mlrun_interface.py +3 -6
- mlrun/frameworks/sklearn/model_handler.py +2 -2
- mlrun/frameworks/tf_keras/__init__.py +5 -5
- mlrun/frameworks/tf_keras/callbacks/logging_callback.py +14 -14
- mlrun/frameworks/tf_keras/callbacks/mlrun_logging_callback.py +11 -11
- mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py +19 -23
- mlrun/frameworks/tf_keras/mlrun_interface.py +7 -9
- mlrun/frameworks/tf_keras/model_handler.py +14 -14
- mlrun/frameworks/tf_keras/model_server.py +6 -6
- mlrun/frameworks/xgboost/__init__.py +12 -12
- mlrun/frameworks/xgboost/model_handler.py +6 -6
- mlrun/k8s_utils.py +4 -5
- mlrun/kfpops.py +2 -2
- mlrun/launcher/base.py +10 -10
- mlrun/launcher/local.py +8 -8
- mlrun/launcher/remote.py +7 -7
- mlrun/lists.py +3 -4
- mlrun/model.py +205 -55
- mlrun/model_monitoring/api.py +21 -24
- mlrun/model_monitoring/application.py +4 -4
- mlrun/model_monitoring/batch.py +17 -17
- mlrun/model_monitoring/controller.py +2 -1
- mlrun/model_monitoring/features_drift_table.py +44 -31
- mlrun/model_monitoring/prometheus.py +1 -4
- mlrun/model_monitoring/stores/kv_model_endpoint_store.py +11 -13
- mlrun/model_monitoring/stores/model_endpoint_store.py +9 -11
- mlrun/model_monitoring/stores/models/__init__.py +2 -2
- mlrun/model_monitoring/stores/sql_model_endpoint_store.py +11 -13
- mlrun/model_monitoring/stream_processing.py +16 -34
- mlrun/model_monitoring/tracking_policy.py +2 -1
- mlrun/package/__init__.py +6 -6
- mlrun/package/context_handler.py +5 -5
- mlrun/package/packager.py +7 -7
- mlrun/package/packagers/default_packager.py +6 -6
- mlrun/package/packagers/numpy_packagers.py +15 -15
- mlrun/package/packagers/pandas_packagers.py +5 -5
- mlrun/package/packagers/python_standard_library_packagers.py +10 -10
- mlrun/package/packagers_manager.py +18 -23
- mlrun/package/utils/_formatter.py +4 -4
- mlrun/package/utils/_pickler.py +2 -2
- mlrun/package/utils/_supported_format.py +4 -4
- mlrun/package/utils/log_hint_utils.py +2 -2
- mlrun/package/utils/type_hint_utils.py +4 -9
- mlrun/platforms/other.py +1 -2
- mlrun/projects/operations.py +5 -5
- mlrun/projects/pipelines.py +9 -9
- mlrun/projects/project.py +58 -46
- mlrun/render.py +1 -1
- mlrun/run.py +9 -9
- mlrun/runtimes/__init__.py +7 -4
- mlrun/runtimes/base.py +20 -23
- mlrun/runtimes/constants.py +5 -5
- mlrun/runtimes/daskjob.py +8 -8
- mlrun/runtimes/databricks_job/databricks_cancel_task.py +1 -1
- mlrun/runtimes/databricks_job/databricks_runtime.py +7 -7
- mlrun/runtimes/function_reference.py +1 -1
- mlrun/runtimes/local.py +1 -1
- mlrun/runtimes/mpijob/abstract.py +1 -2
- mlrun/runtimes/nuclio/__init__.py +20 -0
- mlrun/runtimes/{function.py → nuclio/function.py} +15 -16
- mlrun/runtimes/{nuclio.py → nuclio/nuclio.py} +6 -6
- mlrun/runtimes/{serving.py → nuclio/serving.py} +13 -12
- mlrun/runtimes/pod.py +95 -48
- mlrun/runtimes/remotesparkjob.py +1 -1
- mlrun/runtimes/sparkjob/spark3job.py +50 -33
- mlrun/runtimes/utils.py +1 -2
- mlrun/secrets.py +3 -3
- mlrun/serving/remote.py +0 -4
- mlrun/serving/routers.py +6 -6
- mlrun/serving/server.py +4 -4
- mlrun/serving/states.py +29 -0
- mlrun/serving/utils.py +3 -3
- mlrun/serving/v1_serving.py +6 -7
- mlrun/serving/v2_serving.py +50 -8
- mlrun/track/tracker_manager.py +3 -3
- mlrun/track/trackers/mlflow_tracker.py +1 -2
- mlrun/utils/async_http.py +5 -7
- mlrun/utils/azure_vault.py +1 -1
- mlrun/utils/clones.py +1 -2
- mlrun/utils/condition_evaluator.py +3 -3
- mlrun/utils/db.py +3 -3
- mlrun/utils/helpers.py +37 -119
- mlrun/utils/http.py +1 -4
- mlrun/utils/logger.py +49 -14
- mlrun/utils/notifications/notification/__init__.py +3 -3
- mlrun/utils/notifications/notification/base.py +2 -2
- mlrun/utils/notifications/notification/ipython.py +1 -1
- mlrun/utils/notifications/notification_pusher.py +8 -14
- mlrun/utils/retryer.py +207 -0
- mlrun/utils/singleton.py +1 -1
- mlrun/utils/v3io_clients.py +2 -3
- mlrun/utils/version/version.json +2 -2
- mlrun/utils/version/version.py +2 -6
- {mlrun-1.6.0rc35.dist-info → mlrun-1.7.0rc2.dist-info}/METADATA +9 -9
- mlrun-1.7.0rc2.dist-info/RECORD +315 -0
- mlrun-1.6.0rc35.dist-info/RECORD +0 -313
- {mlrun-1.6.0rc35.dist-info → mlrun-1.7.0rc2.dist-info}/LICENSE +0 -0
- {mlrun-1.6.0rc35.dist-info → mlrun-1.7.0rc2.dist-info}/WHEEL +0 -0
- {mlrun-1.6.0rc35.dist-info → mlrun-1.7.0rc2.dist-info}/entry_points.txt +0 -0
- {mlrun-1.6.0rc35.dist-info → mlrun-1.7.0rc2.dist-info}/top_level.txt +0 -0
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
|
-
from typing import Callable,
|
|
15
|
+
from typing import Callable, Union
|
|
16
16
|
|
|
17
17
|
import numpy as np
|
|
18
18
|
from torch import Tensor
|
|
@@ -59,15 +59,15 @@ class LoggingCallback(Callback):
|
|
|
59
59
|
def __init__(
|
|
60
60
|
self,
|
|
61
61
|
context: mlrun.MLClientCtx = None,
|
|
62
|
-
dynamic_hyperparameters:
|
|
62
|
+
dynamic_hyperparameters: dict[
|
|
63
63
|
str,
|
|
64
|
-
|
|
64
|
+
tuple[
|
|
65
65
|
str,
|
|
66
|
-
Union[
|
|
66
|
+
Union[list[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]],
|
|
67
67
|
],
|
|
68
68
|
] = None,
|
|
69
|
-
static_hyperparameters:
|
|
70
|
-
str, Union[PyTorchTypes.TrackableType,
|
|
69
|
+
static_hyperparameters: dict[
|
|
70
|
+
str, Union[PyTorchTypes.TrackableType, tuple[str, list[Union[str, int]]]]
|
|
71
71
|
] = None,
|
|
72
72
|
auto_log: bool = False,
|
|
73
73
|
):
|
|
@@ -100,7 +100,7 @@ class LoggingCallback(Callback):
|
|
|
100
100
|
:param auto_log: Whether or not to enable auto logging, trying to track common static and dynamic
|
|
101
101
|
hyperparameters.
|
|
102
102
|
"""
|
|
103
|
-
super(
|
|
103
|
+
super().__init__()
|
|
104
104
|
|
|
105
105
|
# Store the configurations:
|
|
106
106
|
self._dynamic_hyperparameters_keys = (
|
|
@@ -117,7 +117,7 @@ class LoggingCallback(Callback):
|
|
|
117
117
|
self._is_training = None # type: bool
|
|
118
118
|
self._auto_log = auto_log
|
|
119
119
|
|
|
120
|
-
def get_training_results(self) ->
|
|
120
|
+
def get_training_results(self) -> dict[str, list[list[float]]]:
|
|
121
121
|
"""
|
|
122
122
|
Get the training results logged. The results will be stored in a dictionary where each key is the metric name
|
|
123
123
|
and the value is a list of lists of values. The first list is by epoch and the second list is by iteration
|
|
@@ -127,7 +127,7 @@ class LoggingCallback(Callback):
|
|
|
127
127
|
"""
|
|
128
128
|
return self._logger.training_results
|
|
129
129
|
|
|
130
|
-
def get_validation_results(self) ->
|
|
130
|
+
def get_validation_results(self) -> dict[str, list[list[float]]]:
|
|
131
131
|
"""
|
|
132
132
|
Get the validation results logged. The results will be stored in a dictionary where each key is the metric name
|
|
133
133
|
and the value is a list of lists of values. The first list is by epoch and the second list is by iteration
|
|
@@ -137,7 +137,7 @@ class LoggingCallback(Callback):
|
|
|
137
137
|
"""
|
|
138
138
|
return self._logger.validation_results
|
|
139
139
|
|
|
140
|
-
def get_static_hyperparameters(self) ->
|
|
140
|
+
def get_static_hyperparameters(self) -> dict[str, PyTorchTypes.TrackableType]:
|
|
141
141
|
"""
|
|
142
142
|
Get the static hyperparameters logged. The hyperparameters will be stored in a dictionary where each key is the
|
|
143
143
|
hyperparameter name and the value is his logged value.
|
|
@@ -148,7 +148,7 @@ class LoggingCallback(Callback):
|
|
|
148
148
|
|
|
149
149
|
def get_dynamic_hyperparameters(
|
|
150
150
|
self,
|
|
151
|
-
) ->
|
|
151
|
+
) -> dict[str, list[PyTorchTypes.TrackableType]]:
|
|
152
152
|
"""
|
|
153
153
|
Get the dynamic hyperparameters logged. The hyperparameters will be stored in a dictionary where each key is the
|
|
154
154
|
hyperparameter name and the value is a list of his logged values per epoch.
|
|
@@ -157,7 +157,7 @@ class LoggingCallback(Callback):
|
|
|
157
157
|
"""
|
|
158
158
|
return self._logger.dynamic_hyperparameters
|
|
159
159
|
|
|
160
|
-
def get_summaries(self) ->
|
|
160
|
+
def get_summaries(self) -> dict[str, list[float]]:
|
|
161
161
|
"""
|
|
162
162
|
Get the validation summaries of the metrics results. The summaries will be stored in a dictionary where each key
|
|
163
163
|
is the metric names and the value is a list of all the summary values per epoch.
|
|
@@ -210,7 +210,7 @@ class LoggingCallback(Callback):
|
|
|
210
210
|
self._add_auto_hyperparameters()
|
|
211
211
|
# # Static hyperparameters:
|
|
212
212
|
for name, value in self._static_hyperparameters_keys.items():
|
|
213
|
-
if isinstance(value,
|
|
213
|
+
if isinstance(value, tuple):
|
|
214
214
|
# Its a parameter that needed to be extracted via key chain.
|
|
215
215
|
self._logger.log_static_hyperparameter(
|
|
216
216
|
parameter_name=name,
|
|
@@ -294,7 +294,7 @@ class LoggingCallback(Callback):
|
|
|
294
294
|
self._logger.set_mode(mode=LoggingMode.EVALUATION)
|
|
295
295
|
|
|
296
296
|
def on_validation_end(
|
|
297
|
-
self, loss_value: PyTorchTypes.MetricValueType, metric_values:
|
|
297
|
+
self, loss_value: PyTorchTypes.MetricValueType, metric_values: list[float]
|
|
298
298
|
):
|
|
299
299
|
"""
|
|
300
300
|
Before the validation (in a training case it will be per epoch) ends, this method will be called to log the
|
|
@@ -372,7 +372,7 @@ class LoggingCallback(Callback):
|
|
|
372
372
|
result=float(loss_value),
|
|
373
373
|
)
|
|
374
374
|
|
|
375
|
-
def on_train_metrics_end(self, metric_values:
|
|
375
|
+
def on_train_metrics_end(self, metric_values: list[PyTorchTypes.MetricValueType]):
|
|
376
376
|
"""
|
|
377
377
|
After the training calculation of the metrics, this method will be called to log the metrics values.
|
|
378
378
|
|
|
@@ -389,7 +389,7 @@ class LoggingCallback(Callback):
|
|
|
389
389
|
)
|
|
390
390
|
|
|
391
391
|
def on_validation_metrics_end(
|
|
392
|
-
self, metric_values:
|
|
392
|
+
self, metric_values: list[PyTorchTypes.MetricValueType]
|
|
393
393
|
):
|
|
394
394
|
"""
|
|
395
395
|
After the validating calculation of the metrics, this method will be called to log the metrics values.
|
|
@@ -456,7 +456,7 @@ class LoggingCallback(Callback):
|
|
|
456
456
|
self,
|
|
457
457
|
source: str,
|
|
458
458
|
key_chain: Union[
|
|
459
|
-
|
|
459
|
+
list[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]
|
|
460
460
|
],
|
|
461
461
|
) -> PyTorchTypes.TrackableType:
|
|
462
462
|
"""
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
|
-
from typing import Callable,
|
|
15
|
+
from typing import Callable, Union
|
|
16
16
|
|
|
17
17
|
import torch
|
|
18
18
|
from torch import Tensor
|
|
@@ -53,20 +53,20 @@ class MLRunLoggingCallback(LoggingCallback):
|
|
|
53
53
|
context: mlrun.MLClientCtx,
|
|
54
54
|
model_handler: PyTorchModelHandler,
|
|
55
55
|
log_model_tag: str = "",
|
|
56
|
-
log_model_labels:
|
|
57
|
-
log_model_parameters:
|
|
58
|
-
log_model_extra_data:
|
|
56
|
+
log_model_labels: dict[str, PyTorchTypes.TrackableType] = None,
|
|
57
|
+
log_model_parameters: dict[str, PyTorchTypes.TrackableType] = None,
|
|
58
|
+
log_model_extra_data: dict[
|
|
59
59
|
str, Union[PyTorchTypes.TrackableType, Artifact]
|
|
60
60
|
] = None,
|
|
61
|
-
dynamic_hyperparameters:
|
|
61
|
+
dynamic_hyperparameters: dict[
|
|
62
62
|
str,
|
|
63
|
-
|
|
63
|
+
tuple[
|
|
64
64
|
str,
|
|
65
|
-
Union[
|
|
65
|
+
Union[list[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]],
|
|
66
66
|
],
|
|
67
67
|
] = None,
|
|
68
|
-
static_hyperparameters:
|
|
69
|
-
str, Union[PyTorchTypes.TrackableType,
|
|
68
|
+
static_hyperparameters: dict[
|
|
69
|
+
str, Union[PyTorchTypes.TrackableType, tuple[str, list[Union[str, int]]]]
|
|
70
70
|
] = None,
|
|
71
71
|
auto_log: bool = False,
|
|
72
72
|
):
|
|
@@ -107,7 +107,7 @@ class MLRunLoggingCallback(LoggingCallback):
|
|
|
107
107
|
:param auto_log: Whether or not to enable auto logging for logging the context parameters and
|
|
108
108
|
trying to track common static and dynamic hyperparameters.
|
|
109
109
|
"""
|
|
110
|
-
super(
|
|
110
|
+
super().__init__(
|
|
111
111
|
dynamic_hyperparameters=dynamic_hyperparameters,
|
|
112
112
|
static_hyperparameters=static_hyperparameters,
|
|
113
113
|
auto_log=auto_log,
|
|
@@ -160,7 +160,7 @@ class MLRunLoggingCallback(LoggingCallback):
|
|
|
160
160
|
|
|
161
161
|
:param epoch: The epoch that has just ended.
|
|
162
162
|
"""
|
|
163
|
-
super(
|
|
163
|
+
super().on_epoch_end(epoch=epoch)
|
|
164
164
|
|
|
165
165
|
# Create child context to hold the current epoch's results:
|
|
166
166
|
self._logger.log_epoch_to_context(epoch=epoch)
|
|
@@ -13,7 +13,7 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
15
|
from datetime import datetime
|
|
16
|
-
from typing import Callable,
|
|
16
|
+
from typing import Callable, Union
|
|
17
17
|
|
|
18
18
|
import torch
|
|
19
19
|
from torch import Tensor
|
|
@@ -63,7 +63,7 @@ class _PyTorchTensorboardLogger(TensorboardLogger):
|
|
|
63
63
|
|
|
64
64
|
def __init__(
|
|
65
65
|
self,
|
|
66
|
-
statistics_functions:
|
|
66
|
+
statistics_functions: list[
|
|
67
67
|
Callable[[Union[Parameter]], Union[float, Parameter]]
|
|
68
68
|
],
|
|
69
69
|
context: mlrun.MLClientCtx = None,
|
|
@@ -94,7 +94,7 @@ class _PyTorchTensorboardLogger(TensorboardLogger):
|
|
|
94
94
|
update. Notice that writing to tensorboard too frequently may cause the training
|
|
95
95
|
to be slower. Default: 'epoch'.
|
|
96
96
|
"""
|
|
97
|
-
super(
|
|
97
|
+
super().__init__(
|
|
98
98
|
statistics_functions=statistics_functions,
|
|
99
99
|
context=context,
|
|
100
100
|
tensorboard_directory=tensorboard_directory,
|
|
@@ -249,19 +249,19 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
249
249
|
context: mlrun.MLClientCtx = None,
|
|
250
250
|
tensorboard_directory: str = None,
|
|
251
251
|
run_name: str = None,
|
|
252
|
-
weights: Union[bool,
|
|
253
|
-
statistics_functions:
|
|
252
|
+
weights: Union[bool, list[str]] = False,
|
|
253
|
+
statistics_functions: list[
|
|
254
254
|
Callable[[Union[Parameter, Tensor]], Union[float, Tensor]]
|
|
255
255
|
] = None,
|
|
256
|
-
dynamic_hyperparameters:
|
|
256
|
+
dynamic_hyperparameters: dict[
|
|
257
257
|
str,
|
|
258
|
-
|
|
258
|
+
tuple[
|
|
259
259
|
str,
|
|
260
|
-
Union[
|
|
260
|
+
Union[list[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]],
|
|
261
261
|
],
|
|
262
262
|
] = None,
|
|
263
|
-
static_hyperparameters:
|
|
264
|
-
str, Union[PyTorchTypes.TrackableType,
|
|
263
|
+
static_hyperparameters: dict[
|
|
264
|
+
str, Union[PyTorchTypes.TrackableType, tuple[str, list[Union[str, int]]]]
|
|
265
265
|
] = None,
|
|
266
266
|
update_frequency: Union[int, str] = "epoch",
|
|
267
267
|
auto_log: bool = False,
|
|
@@ -322,7 +322,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
322
322
|
:raise MLRunInvalidArgumentError: In case both 'context' and 'tensorboard_directory' parameters were not given
|
|
323
323
|
or the 'update_frequency' was incorrect.
|
|
324
324
|
"""
|
|
325
|
-
super(
|
|
325
|
+
super().__init__(
|
|
326
326
|
dynamic_hyperparameters=dynamic_hyperparameters,
|
|
327
327
|
static_hyperparameters=static_hyperparameters,
|
|
328
328
|
auto_log=auto_log,
|
|
@@ -345,7 +345,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
345
345
|
# Save the configurations:
|
|
346
346
|
self._tracked_weights = weights
|
|
347
347
|
|
|
348
|
-
def get_weights(self) ->
|
|
348
|
+
def get_weights(self) -> dict[str, Parameter]:
|
|
349
349
|
"""
|
|
350
350
|
Get the weights tensors tracked. The weights will be stored in a dictionary where each key is the weight's name
|
|
351
351
|
and the value is the weight's parameter (tensor).
|
|
@@ -354,7 +354,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
354
354
|
"""
|
|
355
355
|
return self._logger.weights
|
|
356
356
|
|
|
357
|
-
def get_weights_statistics(self) ->
|
|
357
|
+
def get_weights_statistics(self) -> dict[str, dict[str, list[float]]]:
|
|
358
358
|
"""
|
|
359
359
|
Get the weights mean results logged. The results will be stored in a dictionary where each key is the weight's
|
|
360
360
|
name and the value is a list of mean values per epoch.
|
|
@@ -365,7 +365,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
365
365
|
|
|
366
366
|
@staticmethod
|
|
367
367
|
def get_default_weight_statistics_list() -> (
|
|
368
|
-
|
|
368
|
+
list[Callable[[Union[Parameter, Tensor]], Union[float, Tensor]]]
|
|
369
369
|
):
|
|
370
370
|
"""
|
|
371
371
|
Get the default list of statistics functions being applied on the tracked weights each epoch.
|
|
@@ -381,7 +381,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
381
381
|
validation_set: DataLoader = None,
|
|
382
382
|
loss_function: Module = None,
|
|
383
383
|
optimizer: Optimizer = None,
|
|
384
|
-
metric_functions:
|
|
384
|
+
metric_functions: list[PyTorchTypes.MetricFunctionType] = None,
|
|
385
385
|
scheduler=None,
|
|
386
386
|
):
|
|
387
387
|
"""
|
|
@@ -396,7 +396,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
396
396
|
:param metric_functions: The metric functions to be stored in this callback.
|
|
397
397
|
:param scheduler: The scheduler to be stored in this callback.
|
|
398
398
|
"""
|
|
399
|
-
super(
|
|
399
|
+
super().on_setup(
|
|
400
400
|
model=model,
|
|
401
401
|
training_set=training_set,
|
|
402
402
|
validation_set=validation_set,
|
|
@@ -439,7 +439,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
439
439
|
for logging. Epoch 0 (pre-run state) will be logged here.
|
|
440
440
|
"""
|
|
441
441
|
# Setup all the results and hyperparameters dictionaries:
|
|
442
|
-
super(
|
|
442
|
+
super().on_run_begin()
|
|
443
443
|
|
|
444
444
|
# Log the initial summary of the run:
|
|
445
445
|
self._logger.write_initial_summary_text()
|
|
@@ -470,10 +470,10 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
470
470
|
# Write the final summary of the run:
|
|
471
471
|
self._logger.write_final_summary_text()
|
|
472
472
|
|
|
473
|
-
super(
|
|
473
|
+
super().on_run_end()
|
|
474
474
|
|
|
475
475
|
def on_validation_end(
|
|
476
|
-
self, loss_value: PyTorchTypes.MetricValueType, metric_values:
|
|
476
|
+
self, loss_value: PyTorchTypes.MetricValueType, metric_values: list[float]
|
|
477
477
|
):
|
|
478
478
|
"""
|
|
479
479
|
Before the validation (in a training case it will be per epoch) ends, this method will be called to log the
|
|
@@ -482,9 +482,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
482
482
|
:param loss_value: The loss summary of this validation.
|
|
483
483
|
:param metric_values: The metrics summaries of this validation.
|
|
484
484
|
"""
|
|
485
|
-
super(
|
|
486
|
-
loss_value=loss_value, metric_values=metric_values
|
|
487
|
-
)
|
|
485
|
+
super().on_validation_end(loss_value=loss_value, metric_values=metric_values)
|
|
488
486
|
|
|
489
487
|
# Check if this run was part of an evaluation:
|
|
490
488
|
if not self._is_training:
|
|
@@ -503,7 +501,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
503
501
|
|
|
504
502
|
:param epoch: The epoch that has just ended.
|
|
505
503
|
"""
|
|
506
|
-
super(
|
|
504
|
+
super().on_epoch_end(epoch=epoch)
|
|
507
505
|
|
|
508
506
|
# Log the weights statistics:
|
|
509
507
|
self._logger.log_weights_statistics()
|
|
@@ -540,9 +538,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
540
538
|
:param y_true: The true value part of the current batch.
|
|
541
539
|
:param y_pred: The prediction (output) of the model for this batch's input ('x').
|
|
542
540
|
"""
|
|
543
|
-
super(
|
|
544
|
-
batch=batch, x=x, y_true=y_true, y_pred=y_pred
|
|
545
|
-
)
|
|
541
|
+
super().on_train_batch_end(batch=batch, x=x, y_true=y_true, y_pred=y_pred)
|
|
546
542
|
|
|
547
543
|
# Write the batch loss and metrics results to their graphs:
|
|
548
544
|
self._logger.write_training_results()
|
|
@@ -559,9 +555,7 @@ class TensorboardLoggingCallback(LoggingCallback):
|
|
|
559
555
|
:param y_true: The true value part of the current batch.
|
|
560
556
|
:param y_pred: The prediction (output) of the model for this batch's input ('x').
|
|
561
557
|
"""
|
|
562
|
-
super(
|
|
563
|
-
batch=batch, x=x, y_true=y_true, y_pred=y_pred
|
|
564
|
-
)
|
|
558
|
+
super().on_validation_batch_end(batch=batch, x=x, y_true=y_true, y_pred=y_pred)
|
|
565
559
|
|
|
566
560
|
# Write the batch loss and metrics results to their graphs:
|
|
567
561
|
self._logger.write_validation_results()
|
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
#
|
|
15
|
-
from typing import
|
|
15
|
+
from typing import Union
|
|
16
16
|
|
|
17
17
|
from torch import Tensor
|
|
18
18
|
from torch.nn import Module
|
|
@@ -66,7 +66,7 @@ class CallbacksHandler:
|
|
|
66
66
|
A class for handling multiple callbacks during a run.
|
|
67
67
|
"""
|
|
68
68
|
|
|
69
|
-
def __init__(self, callbacks:
|
|
69
|
+
def __init__(self, callbacks: list[Union[Callback, tuple[str, Callback]]]):
|
|
70
70
|
"""
|
|
71
71
|
Initialize the callbacks handler with the given callbacks he will handle. The callbacks can be passed as their
|
|
72
72
|
initialized instances or as a tuple where [0] is a name that will be attached to him and [1] will be the
|
|
@@ -99,7 +99,7 @@ class CallbacksHandler:
|
|
|
99
99
|
self._callbacks[callback.__class__.__name__] = callback
|
|
100
100
|
|
|
101
101
|
@property
|
|
102
|
-
def callbacks(self) ->
|
|
102
|
+
def callbacks(self) -> dict[str, Callback]:
|
|
103
103
|
"""
|
|
104
104
|
Get the callbacks dictionary handled by this handler.
|
|
105
105
|
|
|
@@ -114,9 +114,9 @@ class CallbacksHandler:
|
|
|
114
114
|
validation_set: DataLoader,
|
|
115
115
|
loss_function: Module,
|
|
116
116
|
optimizer: Optimizer,
|
|
117
|
-
metric_functions:
|
|
117
|
+
metric_functions: list[PyTorchTypes.MetricFunctionType],
|
|
118
118
|
scheduler,
|
|
119
|
-
callbacks:
|
|
119
|
+
callbacks: list[str] = None,
|
|
120
120
|
) -> bool:
|
|
121
121
|
"""
|
|
122
122
|
Call the 'on_setup' method of every callback in the callbacks list. If the list is 'None' (not given), all
|
|
@@ -145,7 +145,7 @@ class CallbacksHandler:
|
|
|
145
145
|
scheduler=scheduler,
|
|
146
146
|
)
|
|
147
147
|
|
|
148
|
-
def on_run_begin(self, callbacks:
|
|
148
|
+
def on_run_begin(self, callbacks: list[str] = None) -> bool:
|
|
149
149
|
"""
|
|
150
150
|
Call the 'on_run_begin' method of every callback in the callbacks list. If the list is 'None' (not given), all
|
|
151
151
|
callbacks will be called.
|
|
@@ -159,7 +159,7 @@ class CallbacksHandler:
|
|
|
159
159
|
callbacks=self._parse_names(names=callbacks),
|
|
160
160
|
)
|
|
161
161
|
|
|
162
|
-
def on_run_end(self, callbacks:
|
|
162
|
+
def on_run_end(self, callbacks: list[str] = None) -> bool:
|
|
163
163
|
"""
|
|
164
164
|
Call the 'on_run_end' method of every callback in the callbacks list. If the list is 'None' (not given), all
|
|
165
165
|
callbacks will be called.
|
|
@@ -173,7 +173,7 @@ class CallbacksHandler:
|
|
|
173
173
|
callbacks=self._parse_names(names=callbacks),
|
|
174
174
|
)
|
|
175
175
|
|
|
176
|
-
def on_epoch_begin(self, epoch: int, callbacks:
|
|
176
|
+
def on_epoch_begin(self, epoch: int, callbacks: list[str] = None) -> bool:
|
|
177
177
|
"""
|
|
178
178
|
Call the 'on_epoch_begin' method of every callback in the callbacks list. If the list is 'None' (not given), all
|
|
179
179
|
callbacks will be called.
|
|
@@ -189,7 +189,7 @@ class CallbacksHandler:
|
|
|
189
189
|
epoch=epoch,
|
|
190
190
|
)
|
|
191
191
|
|
|
192
|
-
def on_epoch_end(self, epoch: int, callbacks:
|
|
192
|
+
def on_epoch_end(self, epoch: int, callbacks: list[str] = None) -> bool:
|
|
193
193
|
"""
|
|
194
194
|
Call the 'on_epoch_end' method of every callback in the callbacks list. If the list is 'None' (not given), all
|
|
195
195
|
callbacks will be called.
|
|
@@ -205,7 +205,7 @@ class CallbacksHandler:
|
|
|
205
205
|
epoch=epoch,
|
|
206
206
|
)
|
|
207
207
|
|
|
208
|
-
def on_train_begin(self, callbacks:
|
|
208
|
+
def on_train_begin(self, callbacks: list[str] = None) -> bool:
|
|
209
209
|
"""
|
|
210
210
|
Call the 'on_train_begin' method of every callback in the callbacks list. If the list is 'None' (not given), all
|
|
211
211
|
callbacks will be called.
|
|
@@ -219,7 +219,7 @@ class CallbacksHandler:
|
|
|
219
219
|
callbacks=self._parse_names(names=callbacks),
|
|
220
220
|
)
|
|
221
221
|
|
|
222
|
-
def on_train_end(self, callbacks:
|
|
222
|
+
def on_train_end(self, callbacks: list[str] = None) -> bool:
|
|
223
223
|
"""
|
|
224
224
|
Call the 'on_train_end' method of every callback in the callbacks list. If the list is 'None' (not given), all
|
|
225
225
|
callbacks will be called.
|
|
@@ -233,7 +233,7 @@ class CallbacksHandler:
|
|
|
233
233
|
callbacks=self._parse_names(names=callbacks),
|
|
234
234
|
)
|
|
235
235
|
|
|
236
|
-
def on_validation_begin(self, callbacks:
|
|
236
|
+
def on_validation_begin(self, callbacks: list[str] = None) -> bool:
|
|
237
237
|
"""
|
|
238
238
|
Call the 'on_validation_begin' method of every callback in the callbacks list. If the list is 'None'
|
|
239
239
|
(not given), all callbacks will be called.
|
|
@@ -250,8 +250,8 @@ class CallbacksHandler:
|
|
|
250
250
|
def on_validation_end(
|
|
251
251
|
self,
|
|
252
252
|
loss_value: PyTorchTypes.MetricValueType,
|
|
253
|
-
metric_values:
|
|
254
|
-
callbacks:
|
|
253
|
+
metric_values: list[float],
|
|
254
|
+
callbacks: list[str] = None,
|
|
255
255
|
) -> bool:
|
|
256
256
|
"""
|
|
257
257
|
Call the 'on_validation_end' method of every callback in the callbacks list. If the list is 'None' (not given),
|
|
@@ -271,7 +271,7 @@ class CallbacksHandler:
|
|
|
271
271
|
)
|
|
272
272
|
|
|
273
273
|
def on_train_batch_begin(
|
|
274
|
-
self, batch: int, x, y_true: Tensor, callbacks:
|
|
274
|
+
self, batch: int, x, y_true: Tensor, callbacks: list[str] = None
|
|
275
275
|
) -> bool:
|
|
276
276
|
"""
|
|
277
277
|
Call the 'on_train_batch_begin' method of every callback in the callbacks list. If the list is 'None'
|
|
@@ -298,7 +298,7 @@ class CallbacksHandler:
|
|
|
298
298
|
x,
|
|
299
299
|
y_pred: Tensor,
|
|
300
300
|
y_true: Tensor,
|
|
301
|
-
callbacks:
|
|
301
|
+
callbacks: list[str] = None,
|
|
302
302
|
) -> bool:
|
|
303
303
|
"""
|
|
304
304
|
Call the 'on_train_batch_end' method of every callback in the callbacks list. If the list is 'None' (not given),
|
|
@@ -322,7 +322,7 @@ class CallbacksHandler:
|
|
|
322
322
|
)
|
|
323
323
|
|
|
324
324
|
def on_validation_batch_begin(
|
|
325
|
-
self, batch: int, x, y_true: Tensor, callbacks:
|
|
325
|
+
self, batch: int, x, y_true: Tensor, callbacks: list[str] = None
|
|
326
326
|
) -> bool:
|
|
327
327
|
"""
|
|
328
328
|
Call the 'on_validation_batch_begin' method of every callback in the callbacks list. If the list is 'None'
|
|
@@ -349,7 +349,7 @@ class CallbacksHandler:
|
|
|
349
349
|
x,
|
|
350
350
|
y_pred: Tensor,
|
|
351
351
|
y_true: Tensor,
|
|
352
|
-
callbacks:
|
|
352
|
+
callbacks: list[str] = None,
|
|
353
353
|
) -> bool:
|
|
354
354
|
"""
|
|
355
355
|
Call the 'on_validation_batch_end' method of every callback in the callbacks list. If the list is 'None'
|
|
@@ -375,7 +375,7 @@ class CallbacksHandler:
|
|
|
375
375
|
def on_inference_begin(
|
|
376
376
|
self,
|
|
377
377
|
x,
|
|
378
|
-
callbacks:
|
|
378
|
+
callbacks: list[str] = None,
|
|
379
379
|
) -> bool:
|
|
380
380
|
"""
|
|
381
381
|
Call the 'on_inference_begin' method of every callback in the callbacks list. If the list is 'None' (not given),
|
|
@@ -396,7 +396,7 @@ class CallbacksHandler:
|
|
|
396
396
|
self,
|
|
397
397
|
y_pred: Tensor,
|
|
398
398
|
y_true: Tensor,
|
|
399
|
-
callbacks:
|
|
399
|
+
callbacks: list[str] = None,
|
|
400
400
|
) -> bool:
|
|
401
401
|
"""
|
|
402
402
|
Call the 'on_inference_end' method of every callback in the callbacks list. If the list is 'None' (not given),
|
|
@@ -415,7 +415,7 @@ class CallbacksHandler:
|
|
|
415
415
|
y_true=y_true,
|
|
416
416
|
)
|
|
417
417
|
|
|
418
|
-
def on_train_loss_begin(self, callbacks:
|
|
418
|
+
def on_train_loss_begin(self, callbacks: list[str] = None) -> bool:
|
|
419
419
|
"""
|
|
420
420
|
Call the 'on_train_loss_begin' method of every callback in the callbacks list. If the list is 'None'
|
|
421
421
|
(not given), all callbacks will be called.
|
|
@@ -430,7 +430,7 @@ class CallbacksHandler:
|
|
|
430
430
|
)
|
|
431
431
|
|
|
432
432
|
def on_train_loss_end(
|
|
433
|
-
self, loss_value: PyTorchTypes.MetricValueType, callbacks:
|
|
433
|
+
self, loss_value: PyTorchTypes.MetricValueType, callbacks: list[str] = None
|
|
434
434
|
) -> bool:
|
|
435
435
|
"""
|
|
436
436
|
Call the 'on_train_loss_end' method of every callback in the callbacks list. If the list is 'None' (not given),
|
|
@@ -447,7 +447,7 @@ class CallbacksHandler:
|
|
|
447
447
|
loss_value=loss_value,
|
|
448
448
|
)
|
|
449
449
|
|
|
450
|
-
def on_validation_loss_begin(self, callbacks:
|
|
450
|
+
def on_validation_loss_begin(self, callbacks: list[str] = None) -> bool:
|
|
451
451
|
"""
|
|
452
452
|
Call the 'on_validation_loss_begin' method of every callback in the callbacks list. If the list is 'None'
|
|
453
453
|
(not given), all callbacks will be called.
|
|
@@ -462,7 +462,7 @@ class CallbacksHandler:
|
|
|
462
462
|
)
|
|
463
463
|
|
|
464
464
|
def on_validation_loss_end(
|
|
465
|
-
self, loss_value: PyTorchTypes.MetricValueType, callbacks:
|
|
465
|
+
self, loss_value: PyTorchTypes.MetricValueType, callbacks: list[str] = None
|
|
466
466
|
) -> bool:
|
|
467
467
|
"""
|
|
468
468
|
Call the 'on_validation_loss_end' method of every callback in the callbacks list. If the list is 'None'
|
|
@@ -479,7 +479,7 @@ class CallbacksHandler:
|
|
|
479
479
|
loss_value=loss_value,
|
|
480
480
|
)
|
|
481
481
|
|
|
482
|
-
def on_train_metrics_begin(self, callbacks:
|
|
482
|
+
def on_train_metrics_begin(self, callbacks: list[str] = None) -> bool:
|
|
483
483
|
"""
|
|
484
484
|
Call the 'on_train_metrics_begin' method of every callback in the callbacks list. If the list is 'None'
|
|
485
485
|
(not given), all callbacks will be called.
|
|
@@ -495,8 +495,8 @@ class CallbacksHandler:
|
|
|
495
495
|
|
|
496
496
|
def on_train_metrics_end(
|
|
497
497
|
self,
|
|
498
|
-
metric_values:
|
|
499
|
-
callbacks:
|
|
498
|
+
metric_values: list[PyTorchTypes.MetricValueType],
|
|
499
|
+
callbacks: list[str] = None,
|
|
500
500
|
) -> bool:
|
|
501
501
|
"""
|
|
502
502
|
Call the 'on_train_metrics_end' method of every callback in the callbacks list. If the list is 'None'
|
|
@@ -513,7 +513,7 @@ class CallbacksHandler:
|
|
|
513
513
|
metric_values=metric_values,
|
|
514
514
|
)
|
|
515
515
|
|
|
516
|
-
def on_validation_metrics_begin(self, callbacks:
|
|
516
|
+
def on_validation_metrics_begin(self, callbacks: list[str] = None) -> bool:
|
|
517
517
|
"""
|
|
518
518
|
Call the 'on_validation_metrics_begin' method of every callback in the callbacks list. If the list is 'None'
|
|
519
519
|
(not given), all callbacks will be called.
|
|
@@ -529,8 +529,8 @@ class CallbacksHandler:
|
|
|
529
529
|
|
|
530
530
|
def on_validation_metrics_end(
|
|
531
531
|
self,
|
|
532
|
-
metric_values:
|
|
533
|
-
callbacks:
|
|
532
|
+
metric_values: list[PyTorchTypes.MetricValueType],
|
|
533
|
+
callbacks: list[str] = None,
|
|
534
534
|
) -> bool:
|
|
535
535
|
"""
|
|
536
536
|
Call the 'on_validation_metrics_end' method of every callback in the callbacks list. If the list is 'None'
|
|
@@ -547,7 +547,7 @@ class CallbacksHandler:
|
|
|
547
547
|
metric_values=metric_values,
|
|
548
548
|
)
|
|
549
549
|
|
|
550
|
-
def on_backward_begin(self, callbacks:
|
|
550
|
+
def on_backward_begin(self, callbacks: list[str] = None) -> bool:
|
|
551
551
|
"""
|
|
552
552
|
Call the 'on_backward_begin' method of every callback in the callbacks list. If the list is 'None' (not given),
|
|
553
553
|
all callbacks will be called.
|
|
@@ -561,7 +561,7 @@ class CallbacksHandler:
|
|
|
561
561
|
callbacks=self._parse_names(names=callbacks),
|
|
562
562
|
)
|
|
563
563
|
|
|
564
|
-
def on_backward_end(self, callbacks:
|
|
564
|
+
def on_backward_end(self, callbacks: list[str] = None) -> bool:
|
|
565
565
|
"""
|
|
566
566
|
Call the 'on_backward_end' method of every callback in the callbacks list. If the list is 'None' (not given),
|
|
567
567
|
all callbacks will be called.
|
|
@@ -575,7 +575,7 @@ class CallbacksHandler:
|
|
|
575
575
|
callbacks=self._parse_names(names=callbacks),
|
|
576
576
|
)
|
|
577
577
|
|
|
578
|
-
def on_optimizer_step_begin(self, callbacks:
|
|
578
|
+
def on_optimizer_step_begin(self, callbacks: list[str] = None) -> bool:
|
|
579
579
|
"""
|
|
580
580
|
Call the 'on_optimizer_step_begin' method of every callback in the callbacks list. If the list is 'None'
|
|
581
581
|
(not given), all callbacks will be called.
|
|
@@ -589,7 +589,7 @@ class CallbacksHandler:
|
|
|
589
589
|
callbacks=self._parse_names(names=callbacks),
|
|
590
590
|
)
|
|
591
591
|
|
|
592
|
-
def on_optimizer_step_end(self, callbacks:
|
|
592
|
+
def on_optimizer_step_end(self, callbacks: list[str] = None) -> bool:
|
|
593
593
|
"""
|
|
594
594
|
Call the 'on_optimizer_step_end' method of every callback in the callbacks list. If the list is 'None'
|
|
595
595
|
(not given), all callbacks will be called.
|
|
@@ -603,7 +603,7 @@ class CallbacksHandler:
|
|
|
603
603
|
callbacks=self._parse_names(names=callbacks),
|
|
604
604
|
)
|
|
605
605
|
|
|
606
|
-
def on_scheduler_step_begin(self, callbacks:
|
|
606
|
+
def on_scheduler_step_begin(self, callbacks: list[str] = None) -> bool:
|
|
607
607
|
"""
|
|
608
608
|
Call the 'on_scheduler_step_begin' method of every callback in the callbacks list. If the list is 'None'
|
|
609
609
|
(not given), all callbacks will be called.
|
|
@@ -617,7 +617,7 @@ class CallbacksHandler:
|
|
|
617
617
|
callbacks=self._parse_names(names=callbacks),
|
|
618
618
|
)
|
|
619
619
|
|
|
620
|
-
def on_scheduler_step_end(self, callbacks:
|
|
620
|
+
def on_scheduler_step_end(self, callbacks: list[str] = None) -> bool:
|
|
621
621
|
"""
|
|
622
622
|
Call the 'on_scheduler_step_end' method of every callback in the callbacks list. If the list is 'None'
|
|
623
623
|
(not given), all callbacks will be called.
|
|
@@ -631,7 +631,7 @@ class CallbacksHandler:
|
|
|
631
631
|
callbacks=self._parse_names(names=callbacks),
|
|
632
632
|
)
|
|
633
633
|
|
|
634
|
-
def _parse_names(self, names: Union[
|
|
634
|
+
def _parse_names(self, names: Union[list[str], None]) -> list[str]:
|
|
635
635
|
"""
|
|
636
636
|
Parse the given callbacks names. If they are not 'None' then the names will be returned as they are, otherwise
|
|
637
637
|
all of the callbacks handled by this handler will be returned (the default behavior of when there were no names
|
|
@@ -646,7 +646,7 @@ class CallbacksHandler:
|
|
|
646
646
|
return list(self._callbacks.keys())
|
|
647
647
|
|
|
648
648
|
def _run_callbacks(
|
|
649
|
-
self, method_name: str, callbacks:
|
|
649
|
+
self, method_name: str, callbacks: list[str], *args, **kwargs
|
|
650
650
|
) -> bool:
|
|
651
651
|
"""
|
|
652
652
|
Run the given method from the 'CallbackInterface' on all the specified callbacks with the given arguments.
|