mlrun 1.6.4rc2__py3-none-any.whl → 1.7.0rc20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlrun might be problematic; the original registry page links to more details.

Files changed (291):
  1. mlrun/__init__.py +11 -1
  2. mlrun/__main__.py +26 -112
  3. mlrun/alerts/__init__.py +15 -0
  4. mlrun/alerts/alert.py +144 -0
  5. mlrun/api/schemas/__init__.py +5 -4
  6. mlrun/artifacts/__init__.py +8 -3
  7. mlrun/artifacts/base.py +46 -257
  8. mlrun/artifacts/dataset.py +11 -192
  9. mlrun/artifacts/manager.py +47 -48
  10. mlrun/artifacts/model.py +31 -159
  11. mlrun/artifacts/plots.py +23 -380
  12. mlrun/common/constants.py +69 -0
  13. mlrun/common/db/sql_session.py +2 -3
  14. mlrun/common/formatters/__init__.py +19 -0
  15. mlrun/common/formatters/artifact.py +21 -0
  16. mlrun/common/formatters/base.py +78 -0
  17. mlrun/common/formatters/function.py +41 -0
  18. mlrun/common/formatters/pipeline.py +53 -0
  19. mlrun/common/formatters/project.py +51 -0
  20. mlrun/common/helpers.py +1 -2
  21. mlrun/common/model_monitoring/helpers.py +9 -5
  22. mlrun/{runtimes → common/runtimes}/constants.py +37 -9
  23. mlrun/common/schemas/__init__.py +24 -4
  24. mlrun/common/schemas/alert.py +203 -0
  25. mlrun/common/schemas/api_gateway.py +148 -0
  26. mlrun/common/schemas/artifact.py +18 -8
  27. mlrun/common/schemas/auth.py +11 -5
  28. mlrun/common/schemas/background_task.py +1 -1
  29. mlrun/common/schemas/client_spec.py +4 -1
  30. mlrun/common/schemas/feature_store.py +16 -16
  31. mlrun/common/schemas/frontend_spec.py +8 -7
  32. mlrun/common/schemas/function.py +5 -1
  33. mlrun/common/schemas/hub.py +11 -18
  34. mlrun/common/schemas/memory_reports.py +2 -2
  35. mlrun/common/schemas/model_monitoring/__init__.py +18 -3
  36. mlrun/common/schemas/model_monitoring/constants.py +83 -26
  37. mlrun/common/schemas/model_monitoring/grafana.py +13 -9
  38. mlrun/common/schemas/model_monitoring/model_endpoints.py +99 -16
  39. mlrun/common/schemas/notification.py +4 -4
  40. mlrun/common/schemas/object.py +2 -2
  41. mlrun/{runtimes/mpijob/v1alpha1.py → common/schemas/pagination.py} +10 -13
  42. mlrun/common/schemas/pipeline.py +1 -10
  43. mlrun/common/schemas/project.py +24 -23
  44. mlrun/common/schemas/runtime_resource.py +8 -12
  45. mlrun/common/schemas/schedule.py +3 -3
  46. mlrun/common/schemas/tag.py +1 -2
  47. mlrun/common/schemas/workflow.py +2 -2
  48. mlrun/common/types.py +7 -1
  49. mlrun/config.py +54 -17
  50. mlrun/data_types/to_pandas.py +10 -12
  51. mlrun/datastore/__init__.py +5 -8
  52. mlrun/datastore/alibaba_oss.py +130 -0
  53. mlrun/datastore/azure_blob.py +17 -5
  54. mlrun/datastore/base.py +62 -39
  55. mlrun/datastore/datastore.py +28 -9
  56. mlrun/datastore/datastore_profile.py +146 -20
  57. mlrun/datastore/filestore.py +0 -1
  58. mlrun/datastore/google_cloud_storage.py +6 -2
  59. mlrun/datastore/hdfs.py +56 -0
  60. mlrun/datastore/inmem.py +2 -2
  61. mlrun/datastore/redis.py +6 -2
  62. mlrun/datastore/s3.py +9 -0
  63. mlrun/datastore/snowflake_utils.py +43 -0
  64. mlrun/datastore/sources.py +201 -96
  65. mlrun/datastore/spark_utils.py +1 -2
  66. mlrun/datastore/store_resources.py +7 -7
  67. mlrun/datastore/targets.py +358 -104
  68. mlrun/datastore/utils.py +72 -58
  69. mlrun/datastore/v3io.py +5 -1
  70. mlrun/db/base.py +185 -35
  71. mlrun/db/factory.py +1 -1
  72. mlrun/db/httpdb.py +614 -179
  73. mlrun/db/nopdb.py +210 -26
  74. mlrun/errors.py +12 -1
  75. mlrun/execution.py +41 -24
  76. mlrun/feature_store/__init__.py +0 -2
  77. mlrun/feature_store/api.py +40 -72
  78. mlrun/feature_store/common.py +1 -1
  79. mlrun/feature_store/feature_set.py +76 -55
  80. mlrun/feature_store/feature_vector.py +28 -30
  81. mlrun/feature_store/ingestion.py +7 -6
  82. mlrun/feature_store/retrieval/base.py +16 -11
  83. mlrun/feature_store/retrieval/conversion.py +11 -13
  84. mlrun/feature_store/retrieval/dask_merger.py +2 -0
  85. mlrun/feature_store/retrieval/job.py +9 -3
  86. mlrun/feature_store/retrieval/local_merger.py +2 -0
  87. mlrun/feature_store/retrieval/spark_merger.py +34 -24
  88. mlrun/feature_store/steps.py +37 -34
  89. mlrun/features.py +9 -20
  90. mlrun/frameworks/_common/artifacts_library.py +9 -9
  91. mlrun/frameworks/_common/mlrun_interface.py +5 -5
  92. mlrun/frameworks/_common/model_handler.py +48 -48
  93. mlrun/frameworks/_common/plan.py +2 -3
  94. mlrun/frameworks/_common/producer.py +3 -4
  95. mlrun/frameworks/_common/utils.py +5 -5
  96. mlrun/frameworks/_dl_common/loggers/logger.py +6 -7
  97. mlrun/frameworks/_dl_common/loggers/mlrun_logger.py +9 -9
  98. mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py +23 -47
  99. mlrun/frameworks/_ml_common/artifacts_library.py +1 -2
  100. mlrun/frameworks/_ml_common/loggers/logger.py +3 -4
  101. mlrun/frameworks/_ml_common/loggers/mlrun_logger.py +4 -5
  102. mlrun/frameworks/_ml_common/model_handler.py +24 -24
  103. mlrun/frameworks/_ml_common/pkl_model_server.py +2 -2
  104. mlrun/frameworks/_ml_common/plan.py +1 -1
  105. mlrun/frameworks/_ml_common/plans/calibration_curve_plan.py +2 -3
  106. mlrun/frameworks/_ml_common/plans/confusion_matrix_plan.py +2 -3
  107. mlrun/frameworks/_ml_common/plans/dataset_plan.py +3 -3
  108. mlrun/frameworks/_ml_common/plans/feature_importance_plan.py +3 -3
  109. mlrun/frameworks/_ml_common/plans/roc_curve_plan.py +4 -4
  110. mlrun/frameworks/_ml_common/utils.py +4 -4
  111. mlrun/frameworks/auto_mlrun/auto_mlrun.py +9 -9
  112. mlrun/frameworks/huggingface/model_server.py +4 -4
  113. mlrun/frameworks/lgbm/__init__.py +33 -33
  114. mlrun/frameworks/lgbm/callbacks/callback.py +2 -4
  115. mlrun/frameworks/lgbm/callbacks/logging_callback.py +4 -5
  116. mlrun/frameworks/lgbm/callbacks/mlrun_logging_callback.py +4 -5
  117. mlrun/frameworks/lgbm/mlrun_interfaces/booster_mlrun_interface.py +1 -3
  118. mlrun/frameworks/lgbm/mlrun_interfaces/mlrun_interface.py +6 -6
  119. mlrun/frameworks/lgbm/model_handler.py +10 -10
  120. mlrun/frameworks/lgbm/model_server.py +6 -6
  121. mlrun/frameworks/lgbm/utils.py +5 -5
  122. mlrun/frameworks/onnx/dataset.py +8 -8
  123. mlrun/frameworks/onnx/mlrun_interface.py +3 -3
  124. mlrun/frameworks/onnx/model_handler.py +6 -6
  125. mlrun/frameworks/onnx/model_server.py +7 -7
  126. mlrun/frameworks/parallel_coordinates.py +4 -3
  127. mlrun/frameworks/pytorch/__init__.py +18 -18
  128. mlrun/frameworks/pytorch/callbacks/callback.py +4 -5
  129. mlrun/frameworks/pytorch/callbacks/logging_callback.py +17 -17
  130. mlrun/frameworks/pytorch/callbacks/mlrun_logging_callback.py +11 -11
  131. mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py +23 -29
  132. mlrun/frameworks/pytorch/callbacks_handler.py +38 -38
  133. mlrun/frameworks/pytorch/mlrun_interface.py +20 -20
  134. mlrun/frameworks/pytorch/model_handler.py +17 -17
  135. mlrun/frameworks/pytorch/model_server.py +7 -7
  136. mlrun/frameworks/sklearn/__init__.py +13 -13
  137. mlrun/frameworks/sklearn/estimator.py +4 -4
  138. mlrun/frameworks/sklearn/metrics_library.py +14 -14
  139. mlrun/frameworks/sklearn/mlrun_interface.py +3 -6
  140. mlrun/frameworks/sklearn/model_handler.py +2 -2
  141. mlrun/frameworks/tf_keras/__init__.py +10 -7
  142. mlrun/frameworks/tf_keras/callbacks/logging_callback.py +15 -15
  143. mlrun/frameworks/tf_keras/callbacks/mlrun_logging_callback.py +11 -11
  144. mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py +19 -23
  145. mlrun/frameworks/tf_keras/mlrun_interface.py +9 -11
  146. mlrun/frameworks/tf_keras/model_handler.py +14 -14
  147. mlrun/frameworks/tf_keras/model_server.py +6 -6
  148. mlrun/frameworks/xgboost/__init__.py +13 -13
  149. mlrun/frameworks/xgboost/model_handler.py +6 -6
  150. mlrun/k8s_utils.py +14 -16
  151. mlrun/launcher/__init__.py +1 -1
  152. mlrun/launcher/base.py +16 -15
  153. mlrun/launcher/client.py +8 -6
  154. mlrun/launcher/factory.py +1 -1
  155. mlrun/launcher/local.py +17 -11
  156. mlrun/launcher/remote.py +16 -10
  157. mlrun/lists.py +7 -6
  158. mlrun/model.py +238 -73
  159. mlrun/model_monitoring/__init__.py +1 -1
  160. mlrun/model_monitoring/api.py +138 -315
  161. mlrun/model_monitoring/application.py +5 -296
  162. mlrun/model_monitoring/applications/__init__.py +24 -0
  163. mlrun/model_monitoring/applications/_application_steps.py +157 -0
  164. mlrun/model_monitoring/applications/base.py +282 -0
  165. mlrun/model_monitoring/applications/context.py +214 -0
  166. mlrun/model_monitoring/applications/evidently_base.py +211 -0
  167. mlrun/model_monitoring/applications/histogram_data_drift.py +349 -0
  168. mlrun/model_monitoring/applications/results.py +99 -0
  169. mlrun/model_monitoring/controller.py +104 -84
  170. mlrun/model_monitoring/controller_handler.py +13 -5
  171. mlrun/model_monitoring/db/__init__.py +18 -0
  172. mlrun/model_monitoring/{stores → db/stores}/__init__.py +43 -36
  173. mlrun/model_monitoring/db/stores/base/__init__.py +15 -0
  174. mlrun/model_monitoring/{stores/model_endpoint_store.py → db/stores/base/store.py} +64 -40
  175. mlrun/model_monitoring/db/stores/sqldb/__init__.py +13 -0
  176. mlrun/model_monitoring/db/stores/sqldb/models/__init__.py +71 -0
  177. mlrun/model_monitoring/{stores → db/stores/sqldb}/models/base.py +109 -5
  178. mlrun/model_monitoring/db/stores/sqldb/models/mysql.py +88 -0
  179. mlrun/model_monitoring/{stores/models/mysql.py → db/stores/sqldb/models/sqlite.py} +19 -13
  180. mlrun/model_monitoring/db/stores/sqldb/sql_store.py +684 -0
  181. mlrun/model_monitoring/db/stores/v3io_kv/__init__.py +13 -0
  182. mlrun/model_monitoring/{stores/kv_model_endpoint_store.py → db/stores/v3io_kv/kv_store.py} +310 -165
  183. mlrun/model_monitoring/db/tsdb/__init__.py +100 -0
  184. mlrun/model_monitoring/db/tsdb/base.py +329 -0
  185. mlrun/model_monitoring/db/tsdb/helpers.py +30 -0
  186. mlrun/model_monitoring/db/tsdb/tdengine/__init__.py +15 -0
  187. mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +240 -0
  188. mlrun/model_monitoring/db/tsdb/tdengine/stream_graph_steps.py +45 -0
  189. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +397 -0
  190. mlrun/model_monitoring/db/tsdb/v3io/__init__.py +15 -0
  191. mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +117 -0
  192. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +630 -0
  193. mlrun/model_monitoring/evidently_application.py +6 -118
  194. mlrun/model_monitoring/features_drift_table.py +134 -106
  195. mlrun/model_monitoring/helpers.py +127 -28
  196. mlrun/model_monitoring/metrics/__init__.py +13 -0
  197. mlrun/model_monitoring/metrics/histogram_distance.py +127 -0
  198. mlrun/model_monitoring/model_endpoint.py +3 -2
  199. mlrun/model_monitoring/prometheus.py +1 -4
  200. mlrun/model_monitoring/stream_processing.py +62 -231
  201. mlrun/model_monitoring/tracking_policy.py +9 -2
  202. mlrun/model_monitoring/writer.py +152 -124
  203. mlrun/package/__init__.py +6 -6
  204. mlrun/package/context_handler.py +5 -5
  205. mlrun/package/packager.py +7 -7
  206. mlrun/package/packagers/default_packager.py +6 -6
  207. mlrun/package/packagers/numpy_packagers.py +15 -15
  208. mlrun/package/packagers/pandas_packagers.py +5 -5
  209. mlrun/package/packagers/python_standard_library_packagers.py +10 -10
  210. mlrun/package/packagers_manager.py +19 -23
  211. mlrun/package/utils/_formatter.py +6 -6
  212. mlrun/package/utils/_pickler.py +2 -2
  213. mlrun/package/utils/_supported_format.py +4 -4
  214. mlrun/package/utils/log_hint_utils.py +2 -2
  215. mlrun/package/utils/type_hint_utils.py +4 -9
  216. mlrun/platforms/__init__.py +11 -10
  217. mlrun/platforms/iguazio.py +24 -203
  218. mlrun/projects/operations.py +35 -21
  219. mlrun/projects/pipelines.py +68 -99
  220. mlrun/projects/project.py +830 -266
  221. mlrun/render.py +3 -11
  222. mlrun/run.py +162 -166
  223. mlrun/runtimes/__init__.py +62 -7
  224. mlrun/runtimes/base.py +39 -32
  225. mlrun/runtimes/daskjob.py +8 -8
  226. mlrun/runtimes/databricks_job/databricks_cancel_task.py +1 -1
  227. mlrun/runtimes/databricks_job/databricks_runtime.py +7 -7
  228. mlrun/runtimes/databricks_job/databricks_wrapper.py +1 -1
  229. mlrun/runtimes/funcdoc.py +0 -28
  230. mlrun/runtimes/function_reference.py +1 -1
  231. mlrun/runtimes/kubejob.py +28 -122
  232. mlrun/runtimes/local.py +6 -3
  233. mlrun/runtimes/mpijob/__init__.py +0 -20
  234. mlrun/runtimes/mpijob/abstract.py +9 -10
  235. mlrun/runtimes/mpijob/v1.py +1 -1
  236. mlrun/{model_monitoring/stores/models/sqlite.py → runtimes/nuclio/__init__.py} +7 -9
  237. mlrun/runtimes/nuclio/api_gateway.py +709 -0
  238. mlrun/runtimes/nuclio/application/__init__.py +15 -0
  239. mlrun/runtimes/nuclio/application/application.py +523 -0
  240. mlrun/runtimes/nuclio/application/reverse_proxy.go +95 -0
  241. mlrun/runtimes/{function.py → nuclio/function.py} +112 -73
  242. mlrun/runtimes/{nuclio.py → nuclio/nuclio.py} +6 -6
  243. mlrun/runtimes/{serving.py → nuclio/serving.py} +45 -51
  244. mlrun/runtimes/pod.py +286 -88
  245. mlrun/runtimes/remotesparkjob.py +2 -2
  246. mlrun/runtimes/sparkjob/spark3job.py +51 -34
  247. mlrun/runtimes/utils.py +7 -75
  248. mlrun/secrets.py +9 -5
  249. mlrun/serving/remote.py +2 -7
  250. mlrun/serving/routers.py +13 -10
  251. mlrun/serving/server.py +22 -26
  252. mlrun/serving/states.py +99 -25
  253. mlrun/serving/utils.py +3 -3
  254. mlrun/serving/v1_serving.py +6 -7
  255. mlrun/serving/v2_serving.py +59 -20
  256. mlrun/track/tracker.py +2 -1
  257. mlrun/track/tracker_manager.py +3 -3
  258. mlrun/track/trackers/mlflow_tracker.py +1 -2
  259. mlrun/utils/async_http.py +5 -7
  260. mlrun/utils/azure_vault.py +1 -1
  261. mlrun/utils/clones.py +1 -2
  262. mlrun/utils/condition_evaluator.py +3 -3
  263. mlrun/utils/db.py +3 -3
  264. mlrun/utils/helpers.py +183 -197
  265. mlrun/utils/http.py +2 -5
  266. mlrun/utils/logger.py +76 -14
  267. mlrun/utils/notifications/notification/__init__.py +17 -12
  268. mlrun/utils/notifications/notification/base.py +14 -2
  269. mlrun/utils/notifications/notification/console.py +2 -0
  270. mlrun/utils/notifications/notification/git.py +3 -1
  271. mlrun/utils/notifications/notification/ipython.py +3 -1
  272. mlrun/utils/notifications/notification/slack.py +101 -21
  273. mlrun/utils/notifications/notification/webhook.py +11 -1
  274. mlrun/utils/notifications/notification_pusher.py +155 -30
  275. mlrun/utils/retryer.py +208 -0
  276. mlrun/utils/singleton.py +1 -1
  277. mlrun/utils/v3io_clients.py +2 -4
  278. mlrun/utils/version/version.json +2 -2
  279. mlrun/utils/version/version.py +2 -6
  280. {mlrun-1.6.4rc2.dist-info → mlrun-1.7.0rc20.dist-info}/METADATA +31 -19
  281. mlrun-1.7.0rc20.dist-info/RECORD +353 -0
  282. mlrun/kfpops.py +0 -868
  283. mlrun/model_monitoring/batch.py +0 -1095
  284. mlrun/model_monitoring/stores/models/__init__.py +0 -27
  285. mlrun/model_monitoring/stores/sql_model_endpoint_store.py +0 -384
  286. mlrun/platforms/other.py +0 -306
  287. mlrun-1.6.4rc2.dist-info/RECORD +0 -314
  288. {mlrun-1.6.4rc2.dist-info → mlrun-1.7.0rc20.dist-info}/LICENSE +0 -0
  289. {mlrun-1.6.4rc2.dist-info → mlrun-1.7.0rc20.dist-info}/WHEEL +0 -0
  290. {mlrun-1.6.4rc2.dist-info → mlrun-1.7.0rc20.dist-info}/entry_points.txt +0 -0
  291. {mlrun-1.6.4rc2.dist-info → mlrun-1.7.0rc20.dist-info}/top_level.txt +0 -0
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  #
15
15
  from datetime import datetime
16
- from typing import Callable, Dict, List, Tuple, Union
16
+ from typing import Callable, Union
17
17
 
18
18
  import torch
19
19
  from torch import Tensor
@@ -63,7 +63,7 @@ class _PyTorchTensorboardLogger(TensorboardLogger):
63
63
 
64
64
  def __init__(
65
65
  self,
66
- statistics_functions: List[
66
+ statistics_functions: list[
67
67
  Callable[[Union[Parameter]], Union[float, Parameter]]
68
68
  ],
69
69
  context: mlrun.MLClientCtx = None,
@@ -94,7 +94,7 @@ class _PyTorchTensorboardLogger(TensorboardLogger):
94
94
  update. Notice that writing to tensorboard too frequently may cause the training
95
95
  to be slower. Default: 'epoch'.
96
96
  """
97
- super(_PyTorchTensorboardLogger, self).__init__(
97
+ super().__init__(
98
98
  statistics_functions=statistics_functions,
99
99
  context=context,
100
100
  tensorboard_directory=tensorboard_directory,
@@ -249,19 +249,19 @@ class TensorboardLoggingCallback(LoggingCallback):
249
249
  context: mlrun.MLClientCtx = None,
250
250
  tensorboard_directory: str = None,
251
251
  run_name: str = None,
252
- weights: Union[bool, List[str]] = False,
253
- statistics_functions: List[
252
+ weights: Union[bool, list[str]] = False,
253
+ statistics_functions: list[
254
254
  Callable[[Union[Parameter, Tensor]], Union[float, Tensor]]
255
255
  ] = None,
256
- dynamic_hyperparameters: Dict[
256
+ dynamic_hyperparameters: dict[
257
257
  str,
258
- Tuple[
258
+ tuple[
259
259
  str,
260
- Union[List[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]],
260
+ Union[list[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]],
261
261
  ],
262
262
  ] = None,
263
- static_hyperparameters: Dict[
264
- str, Union[PyTorchTypes.TrackableType, Tuple[str, List[Union[str, int]]]]
263
+ static_hyperparameters: dict[
264
+ str, Union[PyTorchTypes.TrackableType, tuple[str, list[Union[str, int]]]]
265
265
  ] = None,
266
266
  update_frequency: Union[int, str] = "epoch",
267
267
  auto_log: bool = False,
@@ -322,7 +322,7 @@ class TensorboardLoggingCallback(LoggingCallback):
322
322
  :raise MLRunInvalidArgumentError: In case both 'context' and 'tensorboard_directory' parameters were not given
323
323
  or the 'update_frequency' was incorrect.
324
324
  """
325
- super(TensorboardLoggingCallback, self).__init__(
325
+ super().__init__(
326
326
  dynamic_hyperparameters=dynamic_hyperparameters,
327
327
  static_hyperparameters=static_hyperparameters,
328
328
  auto_log=auto_log,
@@ -345,7 +345,7 @@ class TensorboardLoggingCallback(LoggingCallback):
345
345
  # Save the configurations:
346
346
  self._tracked_weights = weights
347
347
 
348
- def get_weights(self) -> Dict[str, Parameter]:
348
+ def get_weights(self) -> dict[str, Parameter]:
349
349
  """
350
350
  Get the weights tensors tracked. The weights will be stored in a dictionary where each key is the weight's name
351
351
  and the value is the weight's parameter (tensor).
@@ -354,7 +354,7 @@ class TensorboardLoggingCallback(LoggingCallback):
354
354
  """
355
355
  return self._logger.weights
356
356
 
357
- def get_weights_statistics(self) -> Dict[str, Dict[str, List[float]]]:
357
+ def get_weights_statistics(self) -> dict[str, dict[str, list[float]]]:
358
358
  """
359
359
  Get the weights mean results logged. The results will be stored in a dictionary where each key is the weight's
360
360
  name and the value is a list of mean values per epoch.
@@ -365,7 +365,7 @@ class TensorboardLoggingCallback(LoggingCallback):
365
365
 
366
366
  @staticmethod
367
367
  def get_default_weight_statistics_list() -> (
368
- List[Callable[[Union[Parameter, Tensor]], Union[float, Tensor]]]
368
+ list[Callable[[Union[Parameter, Tensor]], Union[float, Tensor]]]
369
369
  ):
370
370
  """
371
371
  Get the default list of statistics functions being applied on the tracked weights each epoch.
@@ -381,7 +381,7 @@ class TensorboardLoggingCallback(LoggingCallback):
381
381
  validation_set: DataLoader = None,
382
382
  loss_function: Module = None,
383
383
  optimizer: Optimizer = None,
384
- metric_functions: List[PyTorchTypes.MetricFunctionType] = None,
384
+ metric_functions: list[PyTorchTypes.MetricFunctionType] = None,
385
385
  scheduler=None,
386
386
  ):
387
387
  """
@@ -396,7 +396,7 @@ class TensorboardLoggingCallback(LoggingCallback):
396
396
  :param metric_functions: The metric functions to be stored in this callback.
397
397
  :param scheduler: The scheduler to be stored in this callback.
398
398
  """
399
- super(TensorboardLoggingCallback, self).on_setup(
399
+ super().on_setup(
400
400
  model=model,
401
401
  training_set=training_set,
402
402
  validation_set=validation_set,
@@ -439,7 +439,7 @@ class TensorboardLoggingCallback(LoggingCallback):
439
439
  for logging. Epoch 0 (pre-run state) will be logged here.
440
440
  """
441
441
  # Setup all the results and hyperparameters dictionaries:
442
- super(TensorboardLoggingCallback, self).on_run_begin()
442
+ super().on_run_begin()
443
443
 
444
444
  # Log the initial summary of the run:
445
445
  self._logger.write_initial_summary_text()
@@ -470,10 +470,10 @@ class TensorboardLoggingCallback(LoggingCallback):
470
470
  # Write the final summary of the run:
471
471
  self._logger.write_final_summary_text()
472
472
 
473
- super(TensorboardLoggingCallback, self).on_run_end()
473
+ super().on_run_end()
474
474
 
475
475
  def on_validation_end(
476
- self, loss_value: PyTorchTypes.MetricValueType, metric_values: List[float]
476
+ self, loss_value: PyTorchTypes.MetricValueType, metric_values: list[float]
477
477
  ):
478
478
  """
479
479
  Before the validation (in a training case it will be per epoch) ends, this method will be called to log the
@@ -482,9 +482,7 @@ class TensorboardLoggingCallback(LoggingCallback):
482
482
  :param loss_value: The loss summary of this validation.
483
483
  :param metric_values: The metrics summaries of this validation.
484
484
  """
485
- super(TensorboardLoggingCallback, self).on_validation_end(
486
- loss_value=loss_value, metric_values=metric_values
487
- )
485
+ super().on_validation_end(loss_value=loss_value, metric_values=metric_values)
488
486
 
489
487
  # Check if this run was part of an evaluation:
490
488
  if not self._is_training:
@@ -503,7 +501,7 @@ class TensorboardLoggingCallback(LoggingCallback):
503
501
 
504
502
  :param epoch: The epoch that has just ended.
505
503
  """
506
- super(TensorboardLoggingCallback, self).on_epoch_end(epoch=epoch)
504
+ super().on_epoch_end(epoch=epoch)
507
505
 
508
506
  # Log the weights statistics:
509
507
  self._logger.log_weights_statistics()
@@ -540,9 +538,7 @@ class TensorboardLoggingCallback(LoggingCallback):
540
538
  :param y_true: The true value part of the current batch.
541
539
  :param y_pred: The prediction (output) of the model for this batch's input ('x').
542
540
  """
543
- super(TensorboardLoggingCallback, self).on_train_batch_end(
544
- batch=batch, x=x, y_true=y_true, y_pred=y_pred
545
- )
541
+ super().on_train_batch_end(batch=batch, x=x, y_true=y_true, y_pred=y_pred)
546
542
 
547
543
  # Write the batch loss and metrics results to their graphs:
548
544
  self._logger.write_training_results()
@@ -559,9 +555,7 @@ class TensorboardLoggingCallback(LoggingCallback):
559
555
  :param y_true: The true value part of the current batch.
560
556
  :param y_pred: The prediction (output) of the model for this batch's input ('x').
561
557
  """
562
- super(TensorboardLoggingCallback, self).on_validation_batch_end(
563
- batch=batch, x=x, y_true=y_true, y_pred=y_pred
564
- )
558
+ super().on_validation_batch_end(batch=batch, x=x, y_true=y_true, y_pred=y_pred)
565
559
 
566
560
  # Write the batch loss and metrics results to their graphs:
567
561
  self._logger.write_validation_results()
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  #
15
- from typing import Dict, List, Tuple, Union
15
+ from typing import Union
16
16
 
17
17
  from torch import Tensor
18
18
  from torch.nn import Module
@@ -66,7 +66,7 @@ class CallbacksHandler:
66
66
  A class for handling multiple callbacks during a run.
67
67
  """
68
68
 
69
- def __init__(self, callbacks: List[Union[Callback, Tuple[str, Callback]]]):
69
+ def __init__(self, callbacks: list[Union[Callback, tuple[str, Callback]]]):
70
70
  """
71
71
  Initialize the callbacks handler with the given callbacks he will handle. The callbacks can be passed as their
72
72
  initialized instances or as a tuple where [0] is a name that will be attached to him and [1] will be the
@@ -99,7 +99,7 @@ class CallbacksHandler:
99
99
  self._callbacks[callback.__class__.__name__] = callback
100
100
 
101
101
  @property
102
- def callbacks(self) -> Dict[str, Callback]:
102
+ def callbacks(self) -> dict[str, Callback]:
103
103
  """
104
104
  Get the callbacks dictionary handled by this handler.
105
105
 
@@ -114,9 +114,9 @@ class CallbacksHandler:
114
114
  validation_set: DataLoader,
115
115
  loss_function: Module,
116
116
  optimizer: Optimizer,
117
- metric_functions: List[PyTorchTypes.MetricFunctionType],
117
+ metric_functions: list[PyTorchTypes.MetricFunctionType],
118
118
  scheduler,
119
- callbacks: List[str] = None,
119
+ callbacks: list[str] = None,
120
120
  ) -> bool:
121
121
  """
122
122
  Call the 'on_setup' method of every callback in the callbacks list. If the list is 'None' (not given), all
@@ -145,7 +145,7 @@ class CallbacksHandler:
145
145
  scheduler=scheduler,
146
146
  )
147
147
 
148
- def on_run_begin(self, callbacks: List[str] = None) -> bool:
148
+ def on_run_begin(self, callbacks: list[str] = None) -> bool:
149
149
  """
150
150
  Call the 'on_run_begin' method of every callback in the callbacks list. If the list is 'None' (not given), all
151
151
  callbacks will be called.
@@ -159,7 +159,7 @@ class CallbacksHandler:
159
159
  callbacks=self._parse_names(names=callbacks),
160
160
  )
161
161
 
162
- def on_run_end(self, callbacks: List[str] = None) -> bool:
162
+ def on_run_end(self, callbacks: list[str] = None) -> bool:
163
163
  """
164
164
  Call the 'on_run_end' method of every callback in the callbacks list. If the list is 'None' (not given), all
165
165
  callbacks will be called.
@@ -173,7 +173,7 @@ class CallbacksHandler:
173
173
  callbacks=self._parse_names(names=callbacks),
174
174
  )
175
175
 
176
- def on_epoch_begin(self, epoch: int, callbacks: List[str] = None) -> bool:
176
+ def on_epoch_begin(self, epoch: int, callbacks: list[str] = None) -> bool:
177
177
  """
178
178
  Call the 'on_epoch_begin' method of every callback in the callbacks list. If the list is 'None' (not given), all
179
179
  callbacks will be called.
@@ -189,7 +189,7 @@ class CallbacksHandler:
189
189
  epoch=epoch,
190
190
  )
191
191
 
192
- def on_epoch_end(self, epoch: int, callbacks: List[str] = None) -> bool:
192
+ def on_epoch_end(self, epoch: int, callbacks: list[str] = None) -> bool:
193
193
  """
194
194
  Call the 'on_epoch_end' method of every callback in the callbacks list. If the list is 'None' (not given), all
195
195
  callbacks will be called.
@@ -205,7 +205,7 @@ class CallbacksHandler:
205
205
  epoch=epoch,
206
206
  )
207
207
 
208
- def on_train_begin(self, callbacks: List[str] = None) -> bool:
208
+ def on_train_begin(self, callbacks: list[str] = None) -> bool:
209
209
  """
210
210
  Call the 'on_train_begin' method of every callback in the callbacks list. If the list is 'None' (not given), all
211
211
  callbacks will be called.
@@ -219,7 +219,7 @@ class CallbacksHandler:
219
219
  callbacks=self._parse_names(names=callbacks),
220
220
  )
221
221
 
222
- def on_train_end(self, callbacks: List[str] = None) -> bool:
222
+ def on_train_end(self, callbacks: list[str] = None) -> bool:
223
223
  """
224
224
  Call the 'on_train_end' method of every callback in the callbacks list. If the list is 'None' (not given), all
225
225
  callbacks will be called.
@@ -233,7 +233,7 @@ class CallbacksHandler:
233
233
  callbacks=self._parse_names(names=callbacks),
234
234
  )
235
235
 
236
- def on_validation_begin(self, callbacks: List[str] = None) -> bool:
236
+ def on_validation_begin(self, callbacks: list[str] = None) -> bool:
237
237
  """
238
238
  Call the 'on_validation_begin' method of every callback in the callbacks list. If the list is 'None'
239
239
  (not given), all callbacks will be called.
@@ -250,8 +250,8 @@ class CallbacksHandler:
250
250
  def on_validation_end(
251
251
  self,
252
252
  loss_value: PyTorchTypes.MetricValueType,
253
- metric_values: List[float],
254
- callbacks: List[str] = None,
253
+ metric_values: list[float],
254
+ callbacks: list[str] = None,
255
255
  ) -> bool:
256
256
  """
257
257
  Call the 'on_validation_end' method of every callback in the callbacks list. If the list is 'None' (not given),
@@ -271,7 +271,7 @@ class CallbacksHandler:
271
271
  )
272
272
 
273
273
  def on_train_batch_begin(
274
- self, batch: int, x, y_true: Tensor, callbacks: List[str] = None
274
+ self, batch: int, x, y_true: Tensor, callbacks: list[str] = None
275
275
  ) -> bool:
276
276
  """
277
277
  Call the 'on_train_batch_begin' method of every callback in the callbacks list. If the list is 'None'
@@ -298,7 +298,7 @@ class CallbacksHandler:
298
298
  x,
299
299
  y_pred: Tensor,
300
300
  y_true: Tensor,
301
- callbacks: List[str] = None,
301
+ callbacks: list[str] = None,
302
302
  ) -> bool:
303
303
  """
304
304
  Call the 'on_train_batch_end' method of every callback in the callbacks list. If the list is 'None' (not given),
@@ -322,7 +322,7 @@ class CallbacksHandler:
322
322
  )
323
323
 
324
324
  def on_validation_batch_begin(
325
- self, batch: int, x, y_true: Tensor, callbacks: List[str] = None
325
+ self, batch: int, x, y_true: Tensor, callbacks: list[str] = None
326
326
  ) -> bool:
327
327
  """
328
328
  Call the 'on_validation_batch_begin' method of every callback in the callbacks list. If the list is 'None'
@@ -349,7 +349,7 @@ class CallbacksHandler:
349
349
  x,
350
350
  y_pred: Tensor,
351
351
  y_true: Tensor,
352
- callbacks: List[str] = None,
352
+ callbacks: list[str] = None,
353
353
  ) -> bool:
354
354
  """
355
355
  Call the 'on_validation_batch_end' method of every callback in the callbacks list. If the list is 'None'
@@ -375,7 +375,7 @@ class CallbacksHandler:
375
375
  def on_inference_begin(
376
376
  self,
377
377
  x,
378
- callbacks: List[str] = None,
378
+ callbacks: list[str] = None,
379
379
  ) -> bool:
380
380
  """
381
381
  Call the 'on_inference_begin' method of every callback in the callbacks list. If the list is 'None' (not given),
@@ -396,7 +396,7 @@ class CallbacksHandler:
396
396
  self,
397
397
  y_pred: Tensor,
398
398
  y_true: Tensor,
399
- callbacks: List[str] = None,
399
+ callbacks: list[str] = None,
400
400
  ) -> bool:
401
401
  """
402
402
  Call the 'on_inference_end' method of every callback in the callbacks list. If the list is 'None' (not given),
@@ -415,7 +415,7 @@ class CallbacksHandler:
415
415
  y_true=y_true,
416
416
  )
417
417
 
418
- def on_train_loss_begin(self, callbacks: List[str] = None) -> bool:
418
+ def on_train_loss_begin(self, callbacks: list[str] = None) -> bool:
419
419
  """
420
420
  Call the 'on_train_loss_begin' method of every callback in the callbacks list. If the list is 'None'
421
421
  (not given), all callbacks will be called.
@@ -430,7 +430,7 @@ class CallbacksHandler:
430
430
  )
431
431
 
432
432
  def on_train_loss_end(
433
- self, loss_value: PyTorchTypes.MetricValueType, callbacks: List[str] = None
433
+ self, loss_value: PyTorchTypes.MetricValueType, callbacks: list[str] = None
434
434
  ) -> bool:
435
435
  """
436
436
  Call the 'on_train_loss_end' method of every callback in the callbacks list. If the list is 'None' (not given),
@@ -447,7 +447,7 @@ class CallbacksHandler:
447
447
  loss_value=loss_value,
448
448
  )
449
449
 
450
- def on_validation_loss_begin(self, callbacks: List[str] = None) -> bool:
450
+ def on_validation_loss_begin(self, callbacks: list[str] = None) -> bool:
451
451
  """
452
452
  Call the 'on_validation_loss_begin' method of every callback in the callbacks list. If the list is 'None'
453
453
  (not given), all callbacks will be called.
@@ -462,7 +462,7 @@ class CallbacksHandler:
462
462
  )
463
463
 
464
464
  def on_validation_loss_end(
465
- self, loss_value: PyTorchTypes.MetricValueType, callbacks: List[str] = None
465
+ self, loss_value: PyTorchTypes.MetricValueType, callbacks: list[str] = None
466
466
  ) -> bool:
467
467
  """
468
468
  Call the 'on_validation_loss_end' method of every callback in the callbacks list. If the list is 'None'
@@ -479,7 +479,7 @@ class CallbacksHandler:
479
479
  loss_value=loss_value,
480
480
  )
481
481
 
482
- def on_train_metrics_begin(self, callbacks: List[str] = None) -> bool:
482
+ def on_train_metrics_begin(self, callbacks: list[str] = None) -> bool:
483
483
  """
484
484
  Call the 'on_train_metrics_begin' method of every callback in the callbacks list. If the list is 'None'
485
485
  (not given), all callbacks will be called.
@@ -495,8 +495,8 @@ class CallbacksHandler:
495
495
 
496
496
  def on_train_metrics_end(
497
497
  self,
498
- metric_values: List[PyTorchTypes.MetricValueType],
499
- callbacks: List[str] = None,
498
+ metric_values: list[PyTorchTypes.MetricValueType],
499
+ callbacks: list[str] = None,
500
500
  ) -> bool:
501
501
  """
502
502
  Call the 'on_train_metrics_end' method of every callback in the callbacks list. If the list is 'None'
@@ -513,7 +513,7 @@ class CallbacksHandler:
513
513
  metric_values=metric_values,
514
514
  )
515
515
 
516
- def on_validation_metrics_begin(self, callbacks: List[str] = None) -> bool:
516
+ def on_validation_metrics_begin(self, callbacks: list[str] = None) -> bool:
517
517
  """
518
518
  Call the 'on_validation_metrics_begin' method of every callback in the callbacks list. If the list is 'None'
519
519
  (not given), all callbacks will be called.
@@ -529,8 +529,8 @@ class CallbacksHandler:
529
529
 
530
530
  def on_validation_metrics_end(
531
531
  self,
532
- metric_values: List[PyTorchTypes.MetricValueType],
533
- callbacks: List[str] = None,
532
+ metric_values: list[PyTorchTypes.MetricValueType],
533
+ callbacks: list[str] = None,
534
534
  ) -> bool:
535
535
  """
536
536
  Call the 'on_validation_metrics_end' method of every callback in the callbacks list. If the list is 'None'
@@ -547,7 +547,7 @@ class CallbacksHandler:
547
547
  metric_values=metric_values,
548
548
  )
549
549
 
550
- def on_backward_begin(self, callbacks: List[str] = None) -> bool:
550
+ def on_backward_begin(self, callbacks: list[str] = None) -> bool:
551
551
  """
552
552
  Call the 'on_backward_begin' method of every callback in the callbacks list. If the list is 'None' (not given),
553
553
  all callbacks will be called.
@@ -561,7 +561,7 @@ class CallbacksHandler:
561
561
  callbacks=self._parse_names(names=callbacks),
562
562
  )
563
563
 
564
- def on_backward_end(self, callbacks: List[str] = None) -> bool:
564
+ def on_backward_end(self, callbacks: list[str] = None) -> bool:
565
565
  """
566
566
  Call the 'on_backward_end' method of every callback in the callbacks list. If the list is 'None' (not given),
567
567
  all callbacks will be called.
@@ -575,7 +575,7 @@ class CallbacksHandler:
575
575
  callbacks=self._parse_names(names=callbacks),
576
576
  )
577
577
 
578
- def on_optimizer_step_begin(self, callbacks: List[str] = None) -> bool:
578
+ def on_optimizer_step_begin(self, callbacks: list[str] = None) -> bool:
579
579
  """
580
580
  Call the 'on_optimizer_step_begin' method of every callback in the callbacks list. If the list is 'None'
581
581
  (not given), all callbacks will be called.
@@ -589,7 +589,7 @@ class CallbacksHandler:
589
589
  callbacks=self._parse_names(names=callbacks),
590
590
  )
591
591
 
592
- def on_optimizer_step_end(self, callbacks: List[str] = None) -> bool:
592
+ def on_optimizer_step_end(self, callbacks: list[str] = None) -> bool:
593
593
  """
594
594
  Call the 'on_optimizer_step_end' method of every callback in the callbacks list. If the list is 'None'
595
595
  (not given), all callbacks will be called.
@@ -603,7 +603,7 @@ class CallbacksHandler:
603
603
  callbacks=self._parse_names(names=callbacks),
604
604
  )
605
605
 
606
- def on_scheduler_step_begin(self, callbacks: List[str] = None) -> bool:
606
+ def on_scheduler_step_begin(self, callbacks: list[str] = None) -> bool:
607
607
  """
608
608
  Call the 'on_scheduler_step_begin' method of every callback in the callbacks list. If the list is 'None'
609
609
  (not given), all callbacks will be called.
@@ -617,7 +617,7 @@ class CallbacksHandler:
617
617
  callbacks=self._parse_names(names=callbacks),
618
618
  )
619
619
 
620
- def on_scheduler_step_end(self, callbacks: List[str] = None) -> bool:
620
+ def on_scheduler_step_end(self, callbacks: list[str] = None) -> bool:
621
621
  """
622
622
  Call the 'on_scheduler_step_end' method of every callback in the callbacks list. If the list is 'None'
623
623
  (not given), all callbacks will be called.
@@ -631,7 +631,7 @@ class CallbacksHandler:
631
631
  callbacks=self._parse_names(names=callbacks),
632
632
  )
633
633
 
634
- def _parse_names(self, names: Union[List[str], None]) -> List[str]:
634
+ def _parse_names(self, names: Union[list[str], None]) -> list[str]:
635
635
  """
636
636
  Parse the given callbacks names. If they are not 'None' then the names will be returned as they are, otherwise
637
637
  all of the callbacks handled by this handler will be returned (the default behavior of when there were no names
@@ -646,7 +646,7 @@ class CallbacksHandler:
646
646
  return list(self._callbacks.keys())
647
647
 
648
648
  def _run_callbacks(
649
- self, method_name: str, callbacks: List[str], *args, **kwargs
649
+ self, method_name: str, callbacks: list[str], *args, **kwargs
650
650
  ) -> bool:
651
651
  """
652
652
  Run the given method from the 'CallbackInterface' on all the specified callbacks with the given arguments.
@@ -14,7 +14,7 @@
14
14
  #
15
15
  import importlib
16
16
  import sys
17
- from typing import Any, Dict, List, Tuple, Union
17
+ from typing import Any, Union
18
18
 
19
19
  import torch
20
20
  import torch.multiprocessing as mp
@@ -109,13 +109,13 @@ class PyTorchMLRunInterface:
109
109
  loss_function: Module,
110
110
  optimizer: Optimizer,
111
111
  validation_set: DataLoader = None,
112
- metric_functions: List[PyTorchTypes.MetricFunctionType] = None,
112
+ metric_functions: list[PyTorchTypes.MetricFunctionType] = None,
113
113
  scheduler=None,
114
114
  scheduler_step_frequency: Union[int, float, str] = "epoch",
115
115
  epochs: int = 1,
116
116
  training_iterations: int = None,
117
117
  validation_iterations: int = None,
118
- callbacks: List[Callback] = None,
118
+ callbacks: list[Callback] = None,
119
119
  use_cuda: bool = True,
120
120
  use_horovod: bool = None,
121
121
  ):
@@ -221,12 +221,12 @@ class PyTorchMLRunInterface:
221
221
  self,
222
222
  dataset: DataLoader,
223
223
  loss_function: Module = None,
224
- metric_functions: List[PyTorchTypes.MetricFunctionType] = None,
224
+ metric_functions: list[PyTorchTypes.MetricFunctionType] = None,
225
225
  iterations: int = None,
226
- callbacks: List[Callback] = None,
226
+ callbacks: list[Callback] = None,
227
227
  use_cuda: bool = True,
228
228
  use_horovod: bool = None,
229
- ) -> List[PyTorchTypes.MetricValueType]:
229
+ ) -> list[PyTorchTypes.MetricValueType]:
230
230
  """
231
231
  Initiate an evaluation process on this interface configuration.
232
232
 
@@ -303,9 +303,9 @@ class PyTorchMLRunInterface:
303
303
  def add_auto_logging_callbacks(
304
304
  self,
305
305
  add_mlrun_logger: bool = True,
306
- mlrun_callback_kwargs: Dict[str, Any] = None,
306
+ mlrun_callback_kwargs: dict[str, Any] = None,
307
307
  add_tensorboard_logger: bool = True,
308
- tensorboard_callback_kwargs: Dict[str, Any] = None,
308
+ tensorboard_callback_kwargs: dict[str, Any] = None,
309
309
  ):
310
310
  """
311
311
  Get automatic logging callbacks to both MLRun's context and Tensorboard. For further features of logging to both
@@ -347,7 +347,7 @@ class PyTorchMLRunInterface:
347
347
 
348
348
  def predict(
349
349
  self,
350
- inputs: Union[Tensor, List[Tensor]],
350
+ inputs: Union[Tensor, list[Tensor]],
351
351
  use_cuda: bool = True,
352
352
  batch_size: int = -1,
353
353
  ) -> Tensor:
@@ -402,13 +402,13 @@ class PyTorchMLRunInterface:
402
402
  loss_function: Module = None,
403
403
  optimizer: Optimizer = None,
404
404
  validation_set: DataLoader = None,
405
- metric_functions: List[PyTorchTypes.MetricFunctionType] = None,
405
+ metric_functions: list[PyTorchTypes.MetricFunctionType] = None,
406
406
  scheduler=None,
407
407
  scheduler_step_frequency: Union[int, float, str] = "epoch",
408
408
  epochs: int = 1,
409
409
  training_iterations: int = None,
410
410
  validation_iterations: int = None,
411
- callbacks: List[Callback] = None,
411
+ callbacks: list[Callback] = None,
412
412
  use_cuda: bool = True,
413
413
  use_horovod: bool = None,
414
414
  ):
@@ -734,7 +734,7 @@ class PyTorchMLRunInterface:
734
734
 
735
735
  def _validate(
736
736
  self, is_evaluation: bool = False
737
- ) -> Tuple[PyTorchTypes.MetricValueType, List[PyTorchTypes.MetricValueType]]:
737
+ ) -> tuple[PyTorchTypes.MetricValueType, list[PyTorchTypes.MetricValueType]]:
738
738
  """
739
739
  Initiate a single epoch validation.
740
740
 
@@ -817,7 +817,7 @@ class PyTorchMLRunInterface:
817
817
  )
818
818
  return loss_value, metric_values
819
819
 
820
- def _print_results(self, loss_value: Tensor, metric_values: List[float]):
820
+ def _print_results(self, loss_value: Tensor, metric_values: list[float]):
821
821
  """
822
822
  Print the given result between each epoch.
823
823
 
@@ -832,7 +832,7 @@ class PyTorchMLRunInterface:
832
832
  + tabulate(table, headers=["Metrics", "Values"], tablefmt="pretty")
833
833
  )
834
834
 
835
- def _metrics(self, y_pred: Tensor, y_true: Tensor) -> List[float]:
835
+ def _metrics(self, y_pred: Tensor, y_true: Tensor) -> list[float]:
836
836
  """
837
837
  Call all the metrics on the given batch's truth and prediction output.
838
838
 
@@ -860,7 +860,7 @@ class PyTorchMLRunInterface:
860
860
  average_tensor = self._hvd.allreduce(rank_value, name=name)
861
861
  return average_tensor.item()
862
862
 
863
- def _get_learning_rate(self) -> Union[Tuple[str, List[Union[str, int]]], None]:
863
+ def _get_learning_rate(self) -> Union[tuple[str, list[Union[str, int]]], None]:
864
864
  """
865
865
  Try and get the learning rate value form the stored optimizer.
866
866
 
@@ -949,8 +949,8 @@ class PyTorchMLRunInterface:
949
949
 
950
950
  @staticmethod
951
951
  def _tensor_to_cuda(
952
- tensor: Union[Tensor, Dict, List, Tuple],
953
- ) -> Union[Tensor, Dict, List, Tuple]:
952
+ tensor: Union[Tensor, dict, list, tuple],
953
+ ) -> Union[Tensor, dict, list, tuple]:
954
954
  """
955
955
  Send to given tensor to cuda if it is a tensor. If the given object is a dictionary, the dictionary values will
956
956
  be sent to the function again recursively. If the given object is a list or a tuple, all the values in it will
@@ -997,7 +997,7 @@ class PyTorchMLRunInterface:
997
997
  dataset: DataLoader,
998
998
  iterations: int,
999
999
  description: str,
1000
- metrics: List[PyTorchTypes.MetricFunctionType],
1000
+ metrics: list[PyTorchTypes.MetricFunctionType],
1001
1001
  ) -> tqdm:
1002
1002
  """
1003
1003
  Create a progress bar for training and validating / evaluating.
@@ -1028,8 +1028,8 @@ class PyTorchMLRunInterface:
1028
1028
  @staticmethod
1029
1029
  def _update_progress_bar(
1030
1030
  progress_bar: tqdm,
1031
- metrics: List[PyTorchTypes.MetricFunctionType],
1032
- values: List[PyTorchTypes.MetricValueType],
1031
+ metrics: list[PyTorchTypes.MetricFunctionType],
1032
+ values: list[PyTorchTypes.MetricValueType],
1033
1033
  ):
1034
1034
  """
1035
1035
  Update the progress bar metrics results.