mlrun 1.6.0rc35__py3-none-any.whl → 1.7.0rc2__py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in that registry.

Potentially problematic release: this version of mlrun might be problematic.

Files changed (199)
  1. mlrun/__main__.py +3 -3
  2. mlrun/api/schemas/__init__.py +1 -1
  3. mlrun/artifacts/base.py +11 -6
  4. mlrun/artifacts/dataset.py +2 -2
  5. mlrun/artifacts/model.py +30 -24
  6. mlrun/artifacts/plots.py +2 -2
  7. mlrun/common/db/sql_session.py +5 -3
  8. mlrun/common/helpers.py +1 -2
  9. mlrun/common/schemas/artifact.py +3 -3
  10. mlrun/common/schemas/auth.py +3 -3
  11. mlrun/common/schemas/background_task.py +1 -1
  12. mlrun/common/schemas/client_spec.py +1 -1
  13. mlrun/common/schemas/feature_store.py +16 -16
  14. mlrun/common/schemas/frontend_spec.py +7 -7
  15. mlrun/common/schemas/function.py +1 -1
  16. mlrun/common/schemas/hub.py +4 -9
  17. mlrun/common/schemas/memory_reports.py +2 -2
  18. mlrun/common/schemas/model_monitoring/grafana.py +4 -4
  19. mlrun/common/schemas/model_monitoring/model_endpoints.py +14 -15
  20. mlrun/common/schemas/notification.py +4 -4
  21. mlrun/common/schemas/object.py +2 -2
  22. mlrun/common/schemas/pipeline.py +1 -1
  23. mlrun/common/schemas/project.py +3 -3
  24. mlrun/common/schemas/runtime_resource.py +8 -12
  25. mlrun/common/schemas/schedule.py +3 -3
  26. mlrun/common/schemas/tag.py +1 -2
  27. mlrun/common/schemas/workflow.py +2 -2
  28. mlrun/config.py +8 -4
  29. mlrun/data_types/to_pandas.py +1 -3
  30. mlrun/datastore/base.py +0 -28
  31. mlrun/datastore/datastore_profile.py +9 -9
  32. mlrun/datastore/filestore.py +0 -1
  33. mlrun/datastore/google_cloud_storage.py +1 -1
  34. mlrun/datastore/sources.py +7 -11
  35. mlrun/datastore/spark_utils.py +1 -2
  36. mlrun/datastore/targets.py +31 -31
  37. mlrun/datastore/utils.py +4 -6
  38. mlrun/datastore/v3io.py +70 -46
  39. mlrun/db/base.py +22 -23
  40. mlrun/db/httpdb.py +34 -34
  41. mlrun/db/nopdb.py +19 -19
  42. mlrun/errors.py +1 -1
  43. mlrun/execution.py +4 -4
  44. mlrun/feature_store/api.py +20 -21
  45. mlrun/feature_store/common.py +1 -1
  46. mlrun/feature_store/feature_set.py +28 -32
  47. mlrun/feature_store/feature_vector.py +24 -27
  48. mlrun/feature_store/retrieval/base.py +7 -7
  49. mlrun/feature_store/retrieval/conversion.py +2 -4
  50. mlrun/feature_store/steps.py +7 -15
  51. mlrun/features.py +5 -7
  52. mlrun/frameworks/_common/artifacts_library.py +9 -9
  53. mlrun/frameworks/_common/mlrun_interface.py +5 -5
  54. mlrun/frameworks/_common/model_handler.py +48 -48
  55. mlrun/frameworks/_common/plan.py +2 -3
  56. mlrun/frameworks/_common/producer.py +3 -4
  57. mlrun/frameworks/_common/utils.py +5 -5
  58. mlrun/frameworks/_dl_common/loggers/logger.py +6 -7
  59. mlrun/frameworks/_dl_common/loggers/mlrun_logger.py +9 -9
  60. mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py +16 -35
  61. mlrun/frameworks/_ml_common/artifacts_library.py +1 -2
  62. mlrun/frameworks/_ml_common/loggers/logger.py +3 -4
  63. mlrun/frameworks/_ml_common/loggers/mlrun_logger.py +4 -5
  64. mlrun/frameworks/_ml_common/model_handler.py +24 -24
  65. mlrun/frameworks/_ml_common/pkl_model_server.py +2 -2
  66. mlrun/frameworks/_ml_common/plan.py +1 -1
  67. mlrun/frameworks/_ml_common/plans/calibration_curve_plan.py +2 -3
  68. mlrun/frameworks/_ml_common/plans/confusion_matrix_plan.py +2 -3
  69. mlrun/frameworks/_ml_common/plans/dataset_plan.py +3 -3
  70. mlrun/frameworks/_ml_common/plans/feature_importance_plan.py +3 -3
  71. mlrun/frameworks/_ml_common/plans/roc_curve_plan.py +4 -4
  72. mlrun/frameworks/_ml_common/utils.py +4 -4
  73. mlrun/frameworks/auto_mlrun/auto_mlrun.py +7 -7
  74. mlrun/frameworks/huggingface/model_server.py +4 -4
  75. mlrun/frameworks/lgbm/__init__.py +32 -32
  76. mlrun/frameworks/lgbm/callbacks/logging_callback.py +4 -5
  77. mlrun/frameworks/lgbm/callbacks/mlrun_logging_callback.py +4 -5
  78. mlrun/frameworks/lgbm/mlrun_interfaces/booster_mlrun_interface.py +1 -3
  79. mlrun/frameworks/lgbm/mlrun_interfaces/mlrun_interface.py +6 -6
  80. mlrun/frameworks/lgbm/model_handler.py +9 -9
  81. mlrun/frameworks/lgbm/model_server.py +6 -6
  82. mlrun/frameworks/lgbm/utils.py +5 -5
  83. mlrun/frameworks/onnx/dataset.py +8 -8
  84. mlrun/frameworks/onnx/mlrun_interface.py +3 -3
  85. mlrun/frameworks/onnx/model_handler.py +6 -6
  86. mlrun/frameworks/onnx/model_server.py +7 -7
  87. mlrun/frameworks/parallel_coordinates.py +2 -2
  88. mlrun/frameworks/pytorch/__init__.py +16 -16
  89. mlrun/frameworks/pytorch/callbacks/callback.py +4 -5
  90. mlrun/frameworks/pytorch/callbacks/logging_callback.py +17 -17
  91. mlrun/frameworks/pytorch/callbacks/mlrun_logging_callback.py +11 -11
  92. mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py +23 -29
  93. mlrun/frameworks/pytorch/callbacks_handler.py +38 -38
  94. mlrun/frameworks/pytorch/mlrun_interface.py +20 -20
  95. mlrun/frameworks/pytorch/model_handler.py +17 -17
  96. mlrun/frameworks/pytorch/model_server.py +7 -7
  97. mlrun/frameworks/sklearn/__init__.py +12 -12
  98. mlrun/frameworks/sklearn/estimator.py +4 -4
  99. mlrun/frameworks/sklearn/metrics_library.py +14 -14
  100. mlrun/frameworks/sklearn/mlrun_interface.py +3 -6
  101. mlrun/frameworks/sklearn/model_handler.py +2 -2
  102. mlrun/frameworks/tf_keras/__init__.py +5 -5
  103. mlrun/frameworks/tf_keras/callbacks/logging_callback.py +14 -14
  104. mlrun/frameworks/tf_keras/callbacks/mlrun_logging_callback.py +11 -11
  105. mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py +19 -23
  106. mlrun/frameworks/tf_keras/mlrun_interface.py +7 -9
  107. mlrun/frameworks/tf_keras/model_handler.py +14 -14
  108. mlrun/frameworks/tf_keras/model_server.py +6 -6
  109. mlrun/frameworks/xgboost/__init__.py +12 -12
  110. mlrun/frameworks/xgboost/model_handler.py +6 -6
  111. mlrun/k8s_utils.py +4 -5
  112. mlrun/kfpops.py +2 -2
  113. mlrun/launcher/base.py +10 -10
  114. mlrun/launcher/local.py +8 -8
  115. mlrun/launcher/remote.py +7 -7
  116. mlrun/lists.py +3 -4
  117. mlrun/model.py +205 -55
  118. mlrun/model_monitoring/api.py +21 -24
  119. mlrun/model_monitoring/application.py +4 -4
  120. mlrun/model_monitoring/batch.py +17 -17
  121. mlrun/model_monitoring/controller.py +2 -1
  122. mlrun/model_monitoring/features_drift_table.py +44 -31
  123. mlrun/model_monitoring/prometheus.py +1 -4
  124. mlrun/model_monitoring/stores/kv_model_endpoint_store.py +11 -13
  125. mlrun/model_monitoring/stores/model_endpoint_store.py +9 -11
  126. mlrun/model_monitoring/stores/models/__init__.py +2 -2
  127. mlrun/model_monitoring/stores/sql_model_endpoint_store.py +11 -13
  128. mlrun/model_monitoring/stream_processing.py +16 -34
  129. mlrun/model_monitoring/tracking_policy.py +2 -1
  130. mlrun/package/__init__.py +6 -6
  131. mlrun/package/context_handler.py +5 -5
  132. mlrun/package/packager.py +7 -7
  133. mlrun/package/packagers/default_packager.py +6 -6
  134. mlrun/package/packagers/numpy_packagers.py +15 -15
  135. mlrun/package/packagers/pandas_packagers.py +5 -5
  136. mlrun/package/packagers/python_standard_library_packagers.py +10 -10
  137. mlrun/package/packagers_manager.py +18 -23
  138. mlrun/package/utils/_formatter.py +4 -4
  139. mlrun/package/utils/_pickler.py +2 -2
  140. mlrun/package/utils/_supported_format.py +4 -4
  141. mlrun/package/utils/log_hint_utils.py +2 -2
  142. mlrun/package/utils/type_hint_utils.py +4 -9
  143. mlrun/platforms/other.py +1 -2
  144. mlrun/projects/operations.py +5 -5
  145. mlrun/projects/pipelines.py +9 -9
  146. mlrun/projects/project.py +58 -46
  147. mlrun/render.py +1 -1
  148. mlrun/run.py +9 -9
  149. mlrun/runtimes/__init__.py +7 -4
  150. mlrun/runtimes/base.py +20 -23
  151. mlrun/runtimes/constants.py +5 -5
  152. mlrun/runtimes/daskjob.py +8 -8
  153. mlrun/runtimes/databricks_job/databricks_cancel_task.py +1 -1
  154. mlrun/runtimes/databricks_job/databricks_runtime.py +7 -7
  155. mlrun/runtimes/function_reference.py +1 -1
  156. mlrun/runtimes/local.py +1 -1
  157. mlrun/runtimes/mpijob/abstract.py +1 -2
  158. mlrun/runtimes/nuclio/__init__.py +20 -0
  159. mlrun/runtimes/{function.py → nuclio/function.py} +15 -16
  160. mlrun/runtimes/{nuclio.py → nuclio/nuclio.py} +6 -6
  161. mlrun/runtimes/{serving.py → nuclio/serving.py} +13 -12
  162. mlrun/runtimes/pod.py +95 -48
  163. mlrun/runtimes/remotesparkjob.py +1 -1
  164. mlrun/runtimes/sparkjob/spark3job.py +50 -33
  165. mlrun/runtimes/utils.py +1 -2
  166. mlrun/secrets.py +3 -3
  167. mlrun/serving/remote.py +0 -4
  168. mlrun/serving/routers.py +6 -6
  169. mlrun/serving/server.py +4 -4
  170. mlrun/serving/states.py +29 -0
  171. mlrun/serving/utils.py +3 -3
  172. mlrun/serving/v1_serving.py +6 -7
  173. mlrun/serving/v2_serving.py +50 -8
  174. mlrun/track/tracker_manager.py +3 -3
  175. mlrun/track/trackers/mlflow_tracker.py +1 -2
  176. mlrun/utils/async_http.py +5 -7
  177. mlrun/utils/azure_vault.py +1 -1
  178. mlrun/utils/clones.py +1 -2
  179. mlrun/utils/condition_evaluator.py +3 -3
  180. mlrun/utils/db.py +3 -3
  181. mlrun/utils/helpers.py +37 -119
  182. mlrun/utils/http.py +1 -4
  183. mlrun/utils/logger.py +49 -14
  184. mlrun/utils/notifications/notification/__init__.py +3 -3
  185. mlrun/utils/notifications/notification/base.py +2 -2
  186. mlrun/utils/notifications/notification/ipython.py +1 -1
  187. mlrun/utils/notifications/notification_pusher.py +8 -14
  188. mlrun/utils/retryer.py +207 -0
  189. mlrun/utils/singleton.py +1 -1
  190. mlrun/utils/v3io_clients.py +2 -3
  191. mlrun/utils/version/version.json +2 -2
  192. mlrun/utils/version/version.py +2 -6
  193. {mlrun-1.6.0rc35.dist-info → mlrun-1.7.0rc2.dist-info}/METADATA +9 -9
  194. mlrun-1.7.0rc2.dist-info/RECORD +315 -0
  195. mlrun-1.6.0rc35.dist-info/RECORD +0 -313
  196. {mlrun-1.6.0rc35.dist-info → mlrun-1.7.0rc2.dist-info}/LICENSE +0 -0
  197. {mlrun-1.6.0rc35.dist-info → mlrun-1.7.0rc2.dist-info}/WHEEL +0 -0
  198. {mlrun-1.6.0rc35.dist-info → mlrun-1.7.0rc2.dist-info}/entry_points.txt +0 -0
  199. {mlrun-1.6.0rc35.dist-info → mlrun-1.7.0rc2.dist-info}/top_level.txt +0 -0
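
The hunks reproduced below all come from the PyTorch framework callbacks (files 90-93 in the list above) and follow two mechanical patterns: typing.Dict/List/Tuple annotations are replaced by the built-in generics dict/list/tuple (usable as annotations since Python 3.9, PEP 585), and two-argument super(Class, self) calls are collapsed to the zero-argument super(). A minimal before/after sketch of the pattern; the classes and methods here are illustrative only, not mlrun code:

    from typing import Callable, Union


    class ExampleLogger:
        """Illustrative only, not an mlrun class."""

        # Before: results: Dict[str, List[float]], getter: Union[List[str], Callable[[], float]]
        # After (PEP 585): the built-in generics are used directly in annotations.
        def log(self, results: dict[str, list[float]], getter: Union[list[str], Callable[[], float]]) -> None:
            self.results = results
            self.getter = getter


    class ExampleChildLogger(ExampleLogger):
        def __init__(self):
            # Before: super(ExampleChildLogger, self).__init__()
            # After: the zero-argument form resolves the class and instance implicitly.
            super().__init__()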

mlrun/frameworks/pytorch/callbacks/logging_callback.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Callable, Dict, List, Tuple, Union
+from typing import Callable, Union
 
 import numpy as np
 from torch import Tensor
@@ -59,15 +59,15 @@ class LoggingCallback(Callback):
     def __init__(
         self,
         context: mlrun.MLClientCtx = None,
-        dynamic_hyperparameters: Dict[
+        dynamic_hyperparameters: dict[
             str,
-            Tuple[
+            tuple[
                 str,
-                Union[List[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]],
+                Union[list[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]],
             ],
         ] = None,
-        static_hyperparameters: Dict[
-            str, Union[PyTorchTypes.TrackableType, Tuple[str, List[Union[str, int]]]]
+        static_hyperparameters: dict[
+            str, Union[PyTorchTypes.TrackableType, tuple[str, list[Union[str, int]]]]
         ] = None,
         auto_log: bool = False,
     ):
@@ -100,7 +100,7 @@ class LoggingCallback(Callback):
         :param auto_log: Whether or not to enable auto logging, trying to track common static and dynamic
                          hyperparameters.
         """
-        super(LoggingCallback, self).__init__()
+        super().__init__()
 
         # Store the configurations:
         self._dynamic_hyperparameters_keys = (
@@ -117,7 +117,7 @@
         self._is_training = None  # type: bool
         self._auto_log = auto_log
 
-    def get_training_results(self) -> Dict[str, List[List[float]]]:
+    def get_training_results(self) -> dict[str, list[list[float]]]:
         """
         Get the training results logged. The results will be stored in a dictionary where each key is the metric name
         and the value is a list of lists of values. The first list is by epoch and the second list is by iteration
@@ -127,7 +127,7 @@
         """
         return self._logger.training_results
 
-    def get_validation_results(self) -> Dict[str, List[List[float]]]:
+    def get_validation_results(self) -> dict[str, list[list[float]]]:
         """
         Get the validation results logged. The results will be stored in a dictionary where each key is the metric name
         and the value is a list of lists of values. The first list is by epoch and the second list is by iteration
@@ -137,7 +137,7 @@
         """
         return self._logger.validation_results
 
-    def get_static_hyperparameters(self) -> Dict[str, PyTorchTypes.TrackableType]:
+    def get_static_hyperparameters(self) -> dict[str, PyTorchTypes.TrackableType]:
         """
         Get the static hyperparameters logged. The hyperparameters will be stored in a dictionary where each key is the
         hyperparameter name and the value is his logged value.
@@ -148,7 +148,7 @@
 
     def get_dynamic_hyperparameters(
         self,
-    ) -> Dict[str, List[PyTorchTypes.TrackableType]]:
+    ) -> dict[str, list[PyTorchTypes.TrackableType]]:
         """
         Get the dynamic hyperparameters logged. The hyperparameters will be stored in a dictionary where each key is the
         hyperparameter name and the value is a list of his logged values per epoch.
@@ -157,7 +157,7 @@
         """
         return self._logger.dynamic_hyperparameters
 
-    def get_summaries(self) -> Dict[str, List[float]]:
+    def get_summaries(self) -> dict[str, list[float]]:
         """
         Get the validation summaries of the metrics results. The summaries will be stored in a dictionary where each key
         is the metric names and the value is a list of all the summary values per epoch.
@@ -210,7 +210,7 @@
         self._add_auto_hyperparameters()
         # # Static hyperparameters:
        for name, value in self._static_hyperparameters_keys.items():
-            if isinstance(value, Tuple):
+            if isinstance(value, tuple):
                # Its a parameter that needed to be extracted via key chain.
                self._logger.log_static_hyperparameter(
                    parameter_name=name,
@@ -294,7 +294,7 @@
         self._logger.set_mode(mode=LoggingMode.EVALUATION)
 
     def on_validation_end(
-        self, loss_value: PyTorchTypes.MetricValueType, metric_values: List[float]
+        self, loss_value: PyTorchTypes.MetricValueType, metric_values: list[float]
     ):
         """
         Before the validation (in a training case it will be per epoch) ends, this method will be called to log the
@@ -372,7 +372,7 @@
             result=float(loss_value),
         )
 
-    def on_train_metrics_end(self, metric_values: List[PyTorchTypes.MetricValueType]):
+    def on_train_metrics_end(self, metric_values: list[PyTorchTypes.MetricValueType]):
         """
         After the training calculation of the metrics, this method will be called to log the metrics values.
 
@@ -389,7 +389,7 @@
         )
 
     def on_validation_metrics_end(
-        self, metric_values: List[PyTorchTypes.MetricValueType]
+        self, metric_values: list[PyTorchTypes.MetricValueType]
     ):
         """
         After the validating calculation of the metrics, this method will be called to log the metrics values.
@@ -456,7 +456,7 @@
         self,
         source: str,
         key_chain: Union[
-            List[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]
+            list[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]
         ],
     ) -> PyTorchTypes.TrackableType:
         """

mlrun/frameworks/pytorch/callbacks/mlrun_logging_callback.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Callable, Dict, List, Tuple, Union
+from typing import Callable, Union
 
 import torch
 from torch import Tensor
@@ -53,20 +53,20 @@ class MLRunLoggingCallback(LoggingCallback):
         context: mlrun.MLClientCtx,
         model_handler: PyTorchModelHandler,
         log_model_tag: str = "",
-        log_model_labels: Dict[str, PyTorchTypes.TrackableType] = None,
-        log_model_parameters: Dict[str, PyTorchTypes.TrackableType] = None,
-        log_model_extra_data: Dict[
+        log_model_labels: dict[str, PyTorchTypes.TrackableType] = None,
+        log_model_parameters: dict[str, PyTorchTypes.TrackableType] = None,
+        log_model_extra_data: dict[
             str, Union[PyTorchTypes.TrackableType, Artifact]
         ] = None,
-        dynamic_hyperparameters: Dict[
+        dynamic_hyperparameters: dict[
             str,
-            Tuple[
+            tuple[
                 str,
-                Union[List[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]],
+                Union[list[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]],
             ],
         ] = None,
-        static_hyperparameters: Dict[
-            str, Union[PyTorchTypes.TrackableType, Tuple[str, List[Union[str, int]]]]
+        static_hyperparameters: dict[
+            str, Union[PyTorchTypes.TrackableType, tuple[str, list[Union[str, int]]]]
         ] = None,
         auto_log: bool = False,
     ):
@@ -107,7 +107,7 @@ class MLRunLoggingCallback(LoggingCallback):
         :param auto_log: Whether or not to enable auto logging for logging the context parameters and
                          trying to track common static and dynamic hyperparameters.
         """
-        super(MLRunLoggingCallback, self).__init__(
+        super().__init__(
             dynamic_hyperparameters=dynamic_hyperparameters,
             static_hyperparameters=static_hyperparameters,
             auto_log=auto_log,
@@ -160,7 +160,7 @@ class MLRunLoggingCallback(LoggingCallback):
 
         :param epoch: The epoch that has just ended.
         """
-        super(MLRunLoggingCallback, self).on_epoch_end(epoch=epoch)
+        super().on_epoch_end(epoch=epoch)
 
         # Create child context to hold the current epoch's results:
         self._logger.log_epoch_to_context(epoch=epoch)

mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 #
 from datetime import datetime
-from typing import Callable, Dict, List, Tuple, Union
+from typing import Callable, Union
 
 import torch
 from torch import Tensor
@@ -63,7 +63,7 @@ class _PyTorchTensorboardLogger(TensorboardLogger):
 
     def __init__(
         self,
-        statistics_functions: List[
+        statistics_functions: list[
             Callable[[Union[Parameter]], Union[float, Parameter]]
         ],
         context: mlrun.MLClientCtx = None,
@@ -94,7 +94,7 @@ class _PyTorchTensorboardLogger(TensorboardLogger):
                                        update. Notice that writing to tensorboard too frequently may cause the training
                                        to be slower. Default: 'epoch'.
         """
-        super(_PyTorchTensorboardLogger, self).__init__(
+        super().__init__(
             statistics_functions=statistics_functions,
             context=context,
             tensorboard_directory=tensorboard_directory,
@@ -249,19 +249,19 @@ class TensorboardLoggingCallback(LoggingCallback):
         context: mlrun.MLClientCtx = None,
         tensorboard_directory: str = None,
         run_name: str = None,
-        weights: Union[bool, List[str]] = False,
-        statistics_functions: List[
+        weights: Union[bool, list[str]] = False,
+        statistics_functions: list[
             Callable[[Union[Parameter, Tensor]], Union[float, Tensor]]
         ] = None,
-        dynamic_hyperparameters: Dict[
+        dynamic_hyperparameters: dict[
             str,
-            Tuple[
+            tuple[
                 str,
-                Union[List[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]],
+                Union[list[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]],
             ],
         ] = None,
-        static_hyperparameters: Dict[
-            str, Union[PyTorchTypes.TrackableType, Tuple[str, List[Union[str, int]]]]
+        static_hyperparameters: dict[
+            str, Union[PyTorchTypes.TrackableType, tuple[str, list[Union[str, int]]]]
         ] = None,
         update_frequency: Union[int, str] = "epoch",
         auto_log: bool = False,
@@ -322,7 +322,7 @@ class TensorboardLoggingCallback(LoggingCallback):
         :raise MLRunInvalidArgumentError: In case both 'context' and 'tensorboard_directory' parameters were not given
                                           or the 'update_frequency' was incorrect.
         """
-        super(TensorboardLoggingCallback, self).__init__(
+        super().__init__(
             dynamic_hyperparameters=dynamic_hyperparameters,
             static_hyperparameters=static_hyperparameters,
             auto_log=auto_log,
@@ -345,7 +345,7 @@ class TensorboardLoggingCallback(LoggingCallback):
         # Save the configurations:
         self._tracked_weights = weights
 
-    def get_weights(self) -> Dict[str, Parameter]:
+    def get_weights(self) -> dict[str, Parameter]:
         """
         Get the weights tensors tracked. The weights will be stored in a dictionary where each key is the weight's name
         and the value is the weight's parameter (tensor).
@@ -354,7 +354,7 @@ class TensorboardLoggingCallback(LoggingCallback):
         """
         return self._logger.weights
 
-    def get_weights_statistics(self) -> Dict[str, Dict[str, List[float]]]:
+    def get_weights_statistics(self) -> dict[str, dict[str, list[float]]]:
         """
         Get the weights mean results logged. The results will be stored in a dictionary where each key is the weight's
         name and the value is a list of mean values per epoch.
@@ -365,7 +365,7 @@ class TensorboardLoggingCallback(LoggingCallback):
 
     @staticmethod
     def get_default_weight_statistics_list() -> (
-        List[Callable[[Union[Parameter, Tensor]], Union[float, Tensor]]]
+        list[Callable[[Union[Parameter, Tensor]], Union[float, Tensor]]]
     ):
         """
         Get the default list of statistics functions being applied on the tracked weights each epoch.
@@ -381,7 +381,7 @@ class TensorboardLoggingCallback(LoggingCallback):
         validation_set: DataLoader = None,
         loss_function: Module = None,
         optimizer: Optimizer = None,
-        metric_functions: List[PyTorchTypes.MetricFunctionType] = None,
+        metric_functions: list[PyTorchTypes.MetricFunctionType] = None,
         scheduler=None,
     ):
         """
@@ -396,7 +396,7 @@ class TensorboardLoggingCallback(LoggingCallback):
         :param metric_functions: The metric functions to be stored in this callback.
         :param scheduler:        The scheduler to be stored in this callback.
         """
-        super(TensorboardLoggingCallback, self).on_setup(
+        super().on_setup(
             model=model,
             training_set=training_set,
             validation_set=validation_set,
@@ -439,7 +439,7 @@ class TensorboardLoggingCallback(LoggingCallback):
                                 for logging. Epoch 0 (pre-run state) will be logged here.
         """
         # Setup all the results and hyperparameters dictionaries:
-        super(TensorboardLoggingCallback, self).on_run_begin()
+        super().on_run_begin()
 
         # Log the initial summary of the run:
         self._logger.write_initial_summary_text()
@@ -470,10 +470,10 @@ class TensorboardLoggingCallback(LoggingCallback):
         # Write the final summary of the run:
         self._logger.write_final_summary_text()
 
-        super(TensorboardLoggingCallback, self).on_run_end()
+        super().on_run_end()
 
     def on_validation_end(
-        self, loss_value: PyTorchTypes.MetricValueType, metric_values: List[float]
+        self, loss_value: PyTorchTypes.MetricValueType, metric_values: list[float]
     ):
         """
         Before the validation (in a training case it will be per epoch) ends, this method will be called to log the
@@ -482,9 +482,7 @@ class TensorboardLoggingCallback(LoggingCallback):
         :param loss_value:    The loss summary of this validation.
         :param metric_values: The metrics summaries of this validation.
         """
-        super(TensorboardLoggingCallback, self).on_validation_end(
-            loss_value=loss_value, metric_values=metric_values
-        )
+        super().on_validation_end(loss_value=loss_value, metric_values=metric_values)
 
         # Check if this run was part of an evaluation:
         if not self._is_training:
@@ -503,7 +501,7 @@ class TensorboardLoggingCallback(LoggingCallback):
 
         :param epoch: The epoch that has just ended.
         """
-        super(TensorboardLoggingCallback, self).on_epoch_end(epoch=epoch)
+        super().on_epoch_end(epoch=epoch)
 
         # Log the weights statistics:
         self._logger.log_weights_statistics()
@@ -540,9 +538,7 @@ class TensorboardLoggingCallback(LoggingCallback):
         :param y_true: The true value part of the current batch.
         :param y_pred: The prediction (output) of the model for this batch's input ('x').
         """
-        super(TensorboardLoggingCallback, self).on_train_batch_end(
-            batch=batch, x=x, y_true=y_true, y_pred=y_pred
-        )
+        super().on_train_batch_end(batch=batch, x=x, y_true=y_true, y_pred=y_pred)
 
         # Write the batch loss and metrics results to their graphs:
         self._logger.write_training_results()
@@ -559,9 +555,7 @@ class TensorboardLoggingCallback(LoggingCallback):
         :param y_true: The true value part of the current batch.
         :param y_pred: The prediction (output) of the model for this batch's input ('x').
         """
-        super(TensorboardLoggingCallback, self).on_validation_batch_end(
-            batch=batch, x=x, y_true=y_true, y_pred=y_pred
-        )
+        super().on_validation_batch_end(batch=batch, x=x, y_true=y_true, y_pred=y_pred)
 
         # Write the batch loss and metrics results to their graphs:
         self._logger.write_validation_results()
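
For orientation, the constructor signature shown in the @@ -249,19 +249,19 @@ hunk above can be exercised roughly as follows. This is a hedged usage sketch: the import path is assumed from the file location listed above, and the argument values (directory, statistics functions, key chains) are illustrative, not mlrun defaults:

    import torch

    # Assumed import path, based on the file listed above.
    from mlrun.frameworks.pytorch.callbacks.tensorboard_logging_callback import (
        TensorboardLoggingCallback,
    )

    callback = TensorboardLoggingCallback(
        tensorboard_directory="./tb_logs",              # illustrative output location
        weights=True,                                   # or a list of weight names to track
        statistics_functions=[torch.mean, torch.std],   # each maps a Parameter/Tensor to a scalar
        # dynamic hyperparameter: name -> (source object name, key chain into that object)
        dynamic_hyperparameters={"lr": ("optimizer", ["param_groups", 0, "lr"])},
        static_hyperparameters={"batch_size": 32},
        update_frequency="epoch",
    )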

mlrun/frameworks/pytorch/callbacks_handler.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Dict, List, Tuple, Union
+from typing import Union
 
 from torch import Tensor
 from torch.nn import Module
@@ -66,7 +66,7 @@ class CallbacksHandler:
     A class for handling multiple callbacks during a run.
     """
 
-    def __init__(self, callbacks: List[Union[Callback, Tuple[str, Callback]]]):
+    def __init__(self, callbacks: list[Union[Callback, tuple[str, Callback]]]):
         """
         Initialize the callbacks handler with the given callbacks he will handle. The callbacks can be passed as their
         initialized instances or as a tuple where [0] is a name that will be attached to him and [1] will be the
@@ -99,7 +99,7 @@ class CallbacksHandler:
                 self._callbacks[callback.__class__.__name__] = callback
 
     @property
-    def callbacks(self) -> Dict[str, Callback]:
+    def callbacks(self) -> dict[str, Callback]:
         """
         Get the callbacks dictionary handled by this handler.
 
@@ -114,9 +114,9 @@ class CallbacksHandler:
         validation_set: DataLoader,
         loss_function: Module,
         optimizer: Optimizer,
-        metric_functions: List[PyTorchTypes.MetricFunctionType],
+        metric_functions: list[PyTorchTypes.MetricFunctionType],
         scheduler,
-        callbacks: List[str] = None,
+        callbacks: list[str] = None,
     ) -> bool:
         """
         Call the 'on_setup' method of every callback in the callbacks list. If the list is 'None' (not given), all
@@ -145,7 +145,7 @@ class CallbacksHandler:
             scheduler=scheduler,
         )
 
-    def on_run_begin(self, callbacks: List[str] = None) -> bool:
+    def on_run_begin(self, callbacks: list[str] = None) -> bool:
         """
         Call the 'on_run_begin' method of every callback in the callbacks list. If the list is 'None' (not given), all
         callbacks will be called.
@@ -159,7 +159,7 @@ class CallbacksHandler:
             callbacks=self._parse_names(names=callbacks),
         )
 
-    def on_run_end(self, callbacks: List[str] = None) -> bool:
+    def on_run_end(self, callbacks: list[str] = None) -> bool:
         """
         Call the 'on_run_end' method of every callback in the callbacks list. If the list is 'None' (not given), all
         callbacks will be called.
@@ -173,7 +173,7 @@ class CallbacksHandler:
             callbacks=self._parse_names(names=callbacks),
         )
 
-    def on_epoch_begin(self, epoch: int, callbacks: List[str] = None) -> bool:
+    def on_epoch_begin(self, epoch: int, callbacks: list[str] = None) -> bool:
         """
         Call the 'on_epoch_begin' method of every callback in the callbacks list. If the list is 'None' (not given), all
         callbacks will be called.
@@ -189,7 +189,7 @@ class CallbacksHandler:
             epoch=epoch,
         )
 
-    def on_epoch_end(self, epoch: int, callbacks: List[str] = None) -> bool:
+    def on_epoch_end(self, epoch: int, callbacks: list[str] = None) -> bool:
         """
         Call the 'on_epoch_end' method of every callback in the callbacks list. If the list is 'None' (not given), all
         callbacks will be called.
@@ -205,7 +205,7 @@ class CallbacksHandler:
             epoch=epoch,
         )
 
-    def on_train_begin(self, callbacks: List[str] = None) -> bool:
+    def on_train_begin(self, callbacks: list[str] = None) -> bool:
         """
         Call the 'on_train_begin' method of every callback in the callbacks list. If the list is 'None' (not given), all
         callbacks will be called.
@@ -219,7 +219,7 @@ class CallbacksHandler:
             callbacks=self._parse_names(names=callbacks),
         )
 
-    def on_train_end(self, callbacks: List[str] = None) -> bool:
+    def on_train_end(self, callbacks: list[str] = None) -> bool:
         """
         Call the 'on_train_end' method of every callback in the callbacks list. If the list is 'None' (not given), all
         callbacks will be called.
@@ -233,7 +233,7 @@ class CallbacksHandler:
             callbacks=self._parse_names(names=callbacks),
         )
 
-    def on_validation_begin(self, callbacks: List[str] = None) -> bool:
+    def on_validation_begin(self, callbacks: list[str] = None) -> bool:
         """
         Call the 'on_validation_begin' method of every callback in the callbacks list. If the list is 'None'
         (not given), all callbacks will be called.
@@ -250,8 +250,8 @@ class CallbacksHandler:
     def on_validation_end(
         self,
         loss_value: PyTorchTypes.MetricValueType,
-        metric_values: List[float],
-        callbacks: List[str] = None,
+        metric_values: list[float],
+        callbacks: list[str] = None,
     ) -> bool:
         """
         Call the 'on_validation_end' method of every callback in the callbacks list. If the list is 'None' (not given),
@@ -271,7 +271,7 @@ class CallbacksHandler:
         )
 
     def on_train_batch_begin(
-        self, batch: int, x, y_true: Tensor, callbacks: List[str] = None
+        self, batch: int, x, y_true: Tensor, callbacks: list[str] = None
     ) -> bool:
         """
         Call the 'on_train_batch_begin' method of every callback in the callbacks list. If the list is 'None'
@@ -298,7 +298,7 @@ class CallbacksHandler:
         x,
         y_pred: Tensor,
         y_true: Tensor,
-        callbacks: List[str] = None,
+        callbacks: list[str] = None,
     ) -> bool:
         """
         Call the 'on_train_batch_end' method of every callback in the callbacks list. If the list is 'None' (not given),
@@ -322,7 +322,7 @@ class CallbacksHandler:
         )
 
     def on_validation_batch_begin(
-        self, batch: int, x, y_true: Tensor, callbacks: List[str] = None
+        self, batch: int, x, y_true: Tensor, callbacks: list[str] = None
     ) -> bool:
         """
         Call the 'on_validation_batch_begin' method of every callback in the callbacks list. If the list is 'None'
@@ -349,7 +349,7 @@ class CallbacksHandler:
         x,
         y_pred: Tensor,
         y_true: Tensor,
-        callbacks: List[str] = None,
+        callbacks: list[str] = None,
     ) -> bool:
         """
         Call the 'on_validation_batch_end' method of every callback in the callbacks list. If the list is 'None'
@@ -375,7 +375,7 @@ class CallbacksHandler:
     def on_inference_begin(
         self,
         x,
-        callbacks: List[str] = None,
+        callbacks: list[str] = None,
     ) -> bool:
         """
         Call the 'on_inference_begin' method of every callback in the callbacks list. If the list is 'None' (not given),
@@ -396,7 +396,7 @@ class CallbacksHandler:
         self,
         y_pred: Tensor,
         y_true: Tensor,
-        callbacks: List[str] = None,
+        callbacks: list[str] = None,
     ) -> bool:
         """
         Call the 'on_inference_end' method of every callback in the callbacks list. If the list is 'None' (not given),
@@ -415,7 +415,7 @@ class CallbacksHandler:
             y_true=y_true,
         )
 
-    def on_train_loss_begin(self, callbacks: List[str] = None) -> bool:
+    def on_train_loss_begin(self, callbacks: list[str] = None) -> bool:
         """
         Call the 'on_train_loss_begin' method of every callback in the callbacks list. If the list is 'None'
         (not given), all callbacks will be called.
@@ -430,7 +430,7 @@ class CallbacksHandler:
         )
 
     def on_train_loss_end(
-        self, loss_value: PyTorchTypes.MetricValueType, callbacks: List[str] = None
+        self, loss_value: PyTorchTypes.MetricValueType, callbacks: list[str] = None
    ) -> bool:
        """
        Call the 'on_train_loss_end' method of every callback in the callbacks list. If the list is 'None' (not given),
@@ -447,7 +447,7 @@ class CallbacksHandler:
             loss_value=loss_value,
         )
 
-    def on_validation_loss_begin(self, callbacks: List[str] = None) -> bool:
+    def on_validation_loss_begin(self, callbacks: list[str] = None) -> bool:
         """
         Call the 'on_validation_loss_begin' method of every callback in the callbacks list. If the list is 'None'
         (not given), all callbacks will be called.
@@ -462,7 +462,7 @@ class CallbacksHandler:
         )
 
     def on_validation_loss_end(
-        self, loss_value: PyTorchTypes.MetricValueType, callbacks: List[str] = None
+        self, loss_value: PyTorchTypes.MetricValueType, callbacks: list[str] = None
     ) -> bool:
         """
         Call the 'on_validation_loss_end' method of every callback in the callbacks list. If the list is 'None'
@@ -479,7 +479,7 @@ class CallbacksHandler:
             loss_value=loss_value,
         )
 
-    def on_train_metrics_begin(self, callbacks: List[str] = None) -> bool:
+    def on_train_metrics_begin(self, callbacks: list[str] = None) -> bool:
         """
         Call the 'on_train_metrics_begin' method of every callback in the callbacks list. If the list is 'None'
         (not given), all callbacks will be called.
@@ -495,8 +495,8 @@ class CallbacksHandler:
 
     def on_train_metrics_end(
         self,
-        metric_values: List[PyTorchTypes.MetricValueType],
-        callbacks: List[str] = None,
+        metric_values: list[PyTorchTypes.MetricValueType],
+        callbacks: list[str] = None,
     ) -> bool:
         """
         Call the 'on_train_metrics_end' method of every callback in the callbacks list. If the list is 'None'
@@ -513,7 +513,7 @@ class CallbacksHandler:
             metric_values=metric_values,
         )
 
-    def on_validation_metrics_begin(self, callbacks: List[str] = None) -> bool:
+    def on_validation_metrics_begin(self, callbacks: list[str] = None) -> bool:
         """
         Call the 'on_validation_metrics_begin' method of every callback in the callbacks list. If the list is 'None'
         (not given), all callbacks will be called.
@@ -529,8 +529,8 @@ class CallbacksHandler:
 
     def on_validation_metrics_end(
         self,
-        metric_values: List[PyTorchTypes.MetricValueType],
-        callbacks: List[str] = None,
+        metric_values: list[PyTorchTypes.MetricValueType],
+        callbacks: list[str] = None,
     ) -> bool:
         """
         Call the 'on_validation_metrics_end' method of every callback in the callbacks list. If the list is 'None'
@@ -547,7 +547,7 @@ class CallbacksHandler:
             metric_values=metric_values,
         )
 
-    def on_backward_begin(self, callbacks: List[str] = None) -> bool:
+    def on_backward_begin(self, callbacks: list[str] = None) -> bool:
         """
         Call the 'on_backward_begin' method of every callback in the callbacks list. If the list is 'None' (not given),
         all callbacks will be called.
@@ -561,7 +561,7 @@ class CallbacksHandler:
             callbacks=self._parse_names(names=callbacks),
         )
 
-    def on_backward_end(self, callbacks: List[str] = None) -> bool:
+    def on_backward_end(self, callbacks: list[str] = None) -> bool:
         """
         Call the 'on_backward_end' method of every callback in the callbacks list. If the list is 'None' (not given),
         all callbacks will be called.
@@ -575,7 +575,7 @@ class CallbacksHandler:
             callbacks=self._parse_names(names=callbacks),
         )
 
-    def on_optimizer_step_begin(self, callbacks: List[str] = None) -> bool:
+    def on_optimizer_step_begin(self, callbacks: list[str] = None) -> bool:
         """
         Call the 'on_optimizer_step_begin' method of every callback in the callbacks list. If the list is 'None'
         (not given), all callbacks will be called.
@@ -589,7 +589,7 @@ class CallbacksHandler:
             callbacks=self._parse_names(names=callbacks),
         )
 
-    def on_optimizer_step_end(self, callbacks: List[str] = None) -> bool:
+    def on_optimizer_step_end(self, callbacks: list[str] = None) -> bool:
         """
         Call the 'on_optimizer_step_end' method of every callback in the callbacks list. If the list is 'None'
         (not given), all callbacks will be called.
@@ -603,7 +603,7 @@ class CallbacksHandler:
             callbacks=self._parse_names(names=callbacks),
         )
 
-    def on_scheduler_step_begin(self, callbacks: List[str] = None) -> bool:
+    def on_scheduler_step_begin(self, callbacks: list[str] = None) -> bool:
         """
         Call the 'on_scheduler_step_begin' method of every callback in the callbacks list. If the list is 'None'
         (not given), all callbacks will be called.
@@ -617,7 +617,7 @@ class CallbacksHandler:
             callbacks=self._parse_names(names=callbacks),
         )
 
-    def on_scheduler_step_end(self, callbacks: List[str] = None) -> bool:
+    def on_scheduler_step_end(self, callbacks: list[str] = None) -> bool:
         """
         Call the 'on_scheduler_step_end' method of every callback in the callbacks list. If the list is 'None'
         (not given), all callbacks will be called.
@@ -631,7 +631,7 @@ class CallbacksHandler:
             callbacks=self._parse_names(names=callbacks),
         )
 
-    def _parse_names(self, names: Union[List[str], None]) -> List[str]:
+    def _parse_names(self, names: Union[list[str], None]) -> list[str]:
         """
         Parse the given callbacks names. If they are not 'None' then the names will be returned as they are, otherwise
         all of the callbacks handled by this handler will be returned (the default behavior of when there were no names
@@ -646,7 +646,7 @@ class CallbacksHandler:
         return list(self._callbacks.keys())
 
     def _run_callbacks(
-        self, method_name: str, callbacks: List[str], *args, **kwargs
+        self, method_name: str, callbacks: list[str], *args, **kwargs
     ) -> bool:
         """
         Run the given method from the 'CallbackInterface' on all the specified callbacks with the given arguments.
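
The callbacks_handler.py hunks only touch annotations; per the docstrings above, the handler stores callbacks by name and invokes the requested CallbackInterface method on each selected one. A minimal sketch of that name-based dispatch pattern, written as an illustration rather than mlrun's actual implementation:

    class MiniCallbacksHandler:
        """Illustrative name-based dispatcher, not mlrun's CallbacksHandler."""

        def __init__(self, callbacks: dict[str, object]):
            self._callbacks = callbacks  # callbacks stored by name, as in the diff above

        def _run_callbacks(self, method_name: str, callbacks: list[str], *args, **kwargs) -> bool:
            # Call the named method on every selected callback. Treating a False
            # return value as "stop" is an assumption made for this sketch only.
            for name in callbacks:
                if getattr(self._callbacks[name], method_name)(*args, **kwargs) is False:
                    return False
            return True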