mlrun 1.6.4rc7__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlrun might be problematic; see the registry's release page for more details.

Files changed (305) hide show
  1. mlrun/__init__.py +11 -1
  2. mlrun/__main__.py +40 -122
  3. mlrun/alerts/__init__.py +15 -0
  4. mlrun/alerts/alert.py +248 -0
  5. mlrun/api/schemas/__init__.py +5 -4
  6. mlrun/artifacts/__init__.py +8 -3
  7. mlrun/artifacts/base.py +47 -257
  8. mlrun/artifacts/dataset.py +11 -192
  9. mlrun/artifacts/manager.py +79 -47
  10. mlrun/artifacts/model.py +31 -159
  11. mlrun/artifacts/plots.py +23 -380
  12. mlrun/common/constants.py +74 -1
  13. mlrun/common/db/sql_session.py +5 -5
  14. mlrun/common/formatters/__init__.py +21 -0
  15. mlrun/common/formatters/artifact.py +45 -0
  16. mlrun/common/formatters/base.py +113 -0
  17. mlrun/common/formatters/feature_set.py +33 -0
  18. mlrun/common/formatters/function.py +46 -0
  19. mlrun/common/formatters/pipeline.py +53 -0
  20. mlrun/common/formatters/project.py +51 -0
  21. mlrun/common/formatters/run.py +29 -0
  22. mlrun/common/helpers.py +12 -3
  23. mlrun/common/model_monitoring/helpers.py +9 -5
  24. mlrun/{runtimes → common/runtimes}/constants.py +37 -9
  25. mlrun/common/schemas/__init__.py +31 -5
  26. mlrun/common/schemas/alert.py +202 -0
  27. mlrun/common/schemas/api_gateway.py +196 -0
  28. mlrun/common/schemas/artifact.py +25 -4
  29. mlrun/common/schemas/auth.py +16 -5
  30. mlrun/common/schemas/background_task.py +1 -1
  31. mlrun/common/schemas/client_spec.py +4 -2
  32. mlrun/common/schemas/common.py +7 -4
  33. mlrun/common/schemas/constants.py +3 -0
  34. mlrun/common/schemas/feature_store.py +74 -44
  35. mlrun/common/schemas/frontend_spec.py +15 -7
  36. mlrun/common/schemas/function.py +12 -1
  37. mlrun/common/schemas/hub.py +11 -18
  38. mlrun/common/schemas/memory_reports.py +2 -2
  39. mlrun/common/schemas/model_monitoring/__init__.py +20 -4
  40. mlrun/common/schemas/model_monitoring/constants.py +123 -42
  41. mlrun/common/schemas/model_monitoring/grafana.py +13 -9
  42. mlrun/common/schemas/model_monitoring/model_endpoints.py +101 -54
  43. mlrun/common/schemas/notification.py +71 -14
  44. mlrun/common/schemas/object.py +2 -2
  45. mlrun/{model_monitoring/controller_handler.py → common/schemas/pagination.py} +9 -12
  46. mlrun/common/schemas/pipeline.py +8 -1
  47. mlrun/common/schemas/project.py +69 -18
  48. mlrun/common/schemas/runs.py +7 -1
  49. mlrun/common/schemas/runtime_resource.py +8 -12
  50. mlrun/common/schemas/schedule.py +4 -4
  51. mlrun/common/schemas/tag.py +1 -2
  52. mlrun/common/schemas/workflow.py +12 -4
  53. mlrun/common/types.py +14 -1
  54. mlrun/config.py +154 -69
  55. mlrun/data_types/data_types.py +6 -1
  56. mlrun/data_types/spark.py +2 -2
  57. mlrun/data_types/to_pandas.py +67 -37
  58. mlrun/datastore/__init__.py +6 -8
  59. mlrun/datastore/alibaba_oss.py +131 -0
  60. mlrun/datastore/azure_blob.py +143 -42
  61. mlrun/datastore/base.py +102 -58
  62. mlrun/datastore/datastore.py +34 -13
  63. mlrun/datastore/datastore_profile.py +146 -20
  64. mlrun/datastore/dbfs_store.py +3 -7
  65. mlrun/datastore/filestore.py +1 -4
  66. mlrun/datastore/google_cloud_storage.py +97 -33
  67. mlrun/datastore/hdfs.py +56 -0
  68. mlrun/datastore/inmem.py +6 -3
  69. mlrun/datastore/redis.py +7 -2
  70. mlrun/datastore/s3.py +34 -12
  71. mlrun/datastore/snowflake_utils.py +45 -0
  72. mlrun/datastore/sources.py +303 -111
  73. mlrun/datastore/spark_utils.py +31 -2
  74. mlrun/datastore/store_resources.py +9 -7
  75. mlrun/datastore/storeytargets.py +151 -0
  76. mlrun/datastore/targets.py +453 -176
  77. mlrun/datastore/utils.py +72 -58
  78. mlrun/datastore/v3io.py +6 -1
  79. mlrun/db/base.py +274 -41
  80. mlrun/db/factory.py +1 -1
  81. mlrun/db/httpdb.py +893 -225
  82. mlrun/db/nopdb.py +291 -33
  83. mlrun/errors.py +36 -6
  84. mlrun/execution.py +115 -42
  85. mlrun/feature_store/__init__.py +0 -2
  86. mlrun/feature_store/api.py +65 -73
  87. mlrun/feature_store/common.py +7 -12
  88. mlrun/feature_store/feature_set.py +76 -55
  89. mlrun/feature_store/feature_vector.py +39 -31
  90. mlrun/feature_store/ingestion.py +7 -6
  91. mlrun/feature_store/retrieval/base.py +16 -11
  92. mlrun/feature_store/retrieval/dask_merger.py +2 -0
  93. mlrun/feature_store/retrieval/job.py +13 -4
  94. mlrun/feature_store/retrieval/local_merger.py +2 -0
  95. mlrun/feature_store/retrieval/spark_merger.py +24 -32
  96. mlrun/feature_store/steps.py +45 -34
  97. mlrun/features.py +11 -21
  98. mlrun/frameworks/_common/artifacts_library.py +9 -9
  99. mlrun/frameworks/_common/mlrun_interface.py +5 -5
  100. mlrun/frameworks/_common/model_handler.py +48 -48
  101. mlrun/frameworks/_common/plan.py +5 -6
  102. mlrun/frameworks/_common/producer.py +3 -4
  103. mlrun/frameworks/_common/utils.py +5 -5
  104. mlrun/frameworks/_dl_common/loggers/logger.py +6 -7
  105. mlrun/frameworks/_dl_common/loggers/mlrun_logger.py +9 -9
  106. mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py +23 -47
  107. mlrun/frameworks/_ml_common/artifacts_library.py +1 -2
  108. mlrun/frameworks/_ml_common/loggers/logger.py +3 -4
  109. mlrun/frameworks/_ml_common/loggers/mlrun_logger.py +4 -5
  110. mlrun/frameworks/_ml_common/model_handler.py +24 -24
  111. mlrun/frameworks/_ml_common/pkl_model_server.py +2 -2
  112. mlrun/frameworks/_ml_common/plan.py +2 -2
  113. mlrun/frameworks/_ml_common/plans/calibration_curve_plan.py +2 -3
  114. mlrun/frameworks/_ml_common/plans/confusion_matrix_plan.py +2 -3
  115. mlrun/frameworks/_ml_common/plans/dataset_plan.py +3 -3
  116. mlrun/frameworks/_ml_common/plans/feature_importance_plan.py +3 -3
  117. mlrun/frameworks/_ml_common/plans/roc_curve_plan.py +4 -4
  118. mlrun/frameworks/_ml_common/utils.py +4 -4
  119. mlrun/frameworks/auto_mlrun/auto_mlrun.py +9 -9
  120. mlrun/frameworks/huggingface/model_server.py +4 -4
  121. mlrun/frameworks/lgbm/__init__.py +33 -33
  122. mlrun/frameworks/lgbm/callbacks/callback.py +2 -4
  123. mlrun/frameworks/lgbm/callbacks/logging_callback.py +4 -5
  124. mlrun/frameworks/lgbm/callbacks/mlrun_logging_callback.py +4 -5
  125. mlrun/frameworks/lgbm/mlrun_interfaces/booster_mlrun_interface.py +1 -3
  126. mlrun/frameworks/lgbm/mlrun_interfaces/mlrun_interface.py +6 -6
  127. mlrun/frameworks/lgbm/model_handler.py +10 -10
  128. mlrun/frameworks/lgbm/model_server.py +6 -6
  129. mlrun/frameworks/lgbm/utils.py +5 -5
  130. mlrun/frameworks/onnx/dataset.py +8 -8
  131. mlrun/frameworks/onnx/mlrun_interface.py +3 -3
  132. mlrun/frameworks/onnx/model_handler.py +6 -6
  133. mlrun/frameworks/onnx/model_server.py +7 -7
  134. mlrun/frameworks/parallel_coordinates.py +6 -6
  135. mlrun/frameworks/pytorch/__init__.py +18 -18
  136. mlrun/frameworks/pytorch/callbacks/callback.py +4 -5
  137. mlrun/frameworks/pytorch/callbacks/logging_callback.py +17 -17
  138. mlrun/frameworks/pytorch/callbacks/mlrun_logging_callback.py +11 -11
  139. mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py +23 -29
  140. mlrun/frameworks/pytorch/callbacks_handler.py +38 -38
  141. mlrun/frameworks/pytorch/mlrun_interface.py +20 -20
  142. mlrun/frameworks/pytorch/model_handler.py +17 -17
  143. mlrun/frameworks/pytorch/model_server.py +7 -7
  144. mlrun/frameworks/sklearn/__init__.py +13 -13
  145. mlrun/frameworks/sklearn/estimator.py +4 -4
  146. mlrun/frameworks/sklearn/metrics_library.py +14 -14
  147. mlrun/frameworks/sklearn/mlrun_interface.py +16 -9
  148. mlrun/frameworks/sklearn/model_handler.py +2 -2
  149. mlrun/frameworks/tf_keras/__init__.py +10 -7
  150. mlrun/frameworks/tf_keras/callbacks/logging_callback.py +15 -15
  151. mlrun/frameworks/tf_keras/callbacks/mlrun_logging_callback.py +11 -11
  152. mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py +19 -23
  153. mlrun/frameworks/tf_keras/mlrun_interface.py +9 -11
  154. mlrun/frameworks/tf_keras/model_handler.py +14 -14
  155. mlrun/frameworks/tf_keras/model_server.py +6 -6
  156. mlrun/frameworks/xgboost/__init__.py +13 -13
  157. mlrun/frameworks/xgboost/model_handler.py +6 -6
  158. mlrun/k8s_utils.py +61 -17
  159. mlrun/launcher/__init__.py +1 -1
  160. mlrun/launcher/base.py +16 -15
  161. mlrun/launcher/client.py +13 -11
  162. mlrun/launcher/factory.py +1 -1
  163. mlrun/launcher/local.py +23 -13
  164. mlrun/launcher/remote.py +17 -10
  165. mlrun/lists.py +7 -6
  166. mlrun/model.py +478 -103
  167. mlrun/model_monitoring/__init__.py +1 -1
  168. mlrun/model_monitoring/api.py +163 -371
  169. mlrun/{runtimes/mpijob/v1alpha1.py → model_monitoring/applications/__init__.py} +9 -15
  170. mlrun/model_monitoring/applications/_application_steps.py +188 -0
  171. mlrun/model_monitoring/applications/base.py +108 -0
  172. mlrun/model_monitoring/applications/context.py +341 -0
  173. mlrun/model_monitoring/{evidently_application.py → applications/evidently_base.py} +27 -22
  174. mlrun/model_monitoring/applications/histogram_data_drift.py +354 -0
  175. mlrun/model_monitoring/applications/results.py +99 -0
  176. mlrun/model_monitoring/controller.py +131 -278
  177. mlrun/model_monitoring/db/__init__.py +18 -0
  178. mlrun/model_monitoring/db/stores/__init__.py +136 -0
  179. mlrun/model_monitoring/db/stores/base/__init__.py +15 -0
  180. mlrun/model_monitoring/db/stores/base/store.py +213 -0
  181. mlrun/model_monitoring/db/stores/sqldb/__init__.py +13 -0
  182. mlrun/model_monitoring/db/stores/sqldb/models/__init__.py +71 -0
  183. mlrun/model_monitoring/db/stores/sqldb/models/base.py +190 -0
  184. mlrun/model_monitoring/db/stores/sqldb/models/mysql.py +103 -0
  185. mlrun/model_monitoring/{stores/models/mysql.py → db/stores/sqldb/models/sqlite.py} +19 -13
  186. mlrun/model_monitoring/db/stores/sqldb/sql_store.py +659 -0
  187. mlrun/model_monitoring/db/stores/v3io_kv/__init__.py +13 -0
  188. mlrun/model_monitoring/db/stores/v3io_kv/kv_store.py +726 -0
  189. mlrun/model_monitoring/db/tsdb/__init__.py +105 -0
  190. mlrun/model_monitoring/db/tsdb/base.py +448 -0
  191. mlrun/model_monitoring/db/tsdb/helpers.py +30 -0
  192. mlrun/model_monitoring/db/tsdb/tdengine/__init__.py +15 -0
  193. mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +279 -0
  194. mlrun/model_monitoring/db/tsdb/tdengine/stream_graph_steps.py +42 -0
  195. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +507 -0
  196. mlrun/model_monitoring/db/tsdb/v3io/__init__.py +15 -0
  197. mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +158 -0
  198. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +849 -0
  199. mlrun/model_monitoring/features_drift_table.py +134 -106
  200. mlrun/model_monitoring/helpers.py +199 -55
  201. mlrun/model_monitoring/metrics/__init__.py +13 -0
  202. mlrun/model_monitoring/metrics/histogram_distance.py +127 -0
  203. mlrun/model_monitoring/model_endpoint.py +3 -2
  204. mlrun/model_monitoring/stream_processing.py +131 -398
  205. mlrun/model_monitoring/tracking_policy.py +9 -2
  206. mlrun/model_monitoring/writer.py +161 -125
  207. mlrun/package/__init__.py +6 -6
  208. mlrun/package/context_handler.py +5 -5
  209. mlrun/package/packager.py +7 -7
  210. mlrun/package/packagers/default_packager.py +8 -8
  211. mlrun/package/packagers/numpy_packagers.py +15 -15
  212. mlrun/package/packagers/pandas_packagers.py +5 -5
  213. mlrun/package/packagers/python_standard_library_packagers.py +10 -10
  214. mlrun/package/packagers_manager.py +19 -23
  215. mlrun/package/utils/_formatter.py +6 -6
  216. mlrun/package/utils/_pickler.py +2 -2
  217. mlrun/package/utils/_supported_format.py +4 -4
  218. mlrun/package/utils/log_hint_utils.py +2 -2
  219. mlrun/package/utils/type_hint_utils.py +4 -9
  220. mlrun/platforms/__init__.py +11 -10
  221. mlrun/platforms/iguazio.py +24 -203
  222. mlrun/projects/operations.py +52 -25
  223. mlrun/projects/pipelines.py +191 -197
  224. mlrun/projects/project.py +1227 -400
  225. mlrun/render.py +16 -19
  226. mlrun/run.py +209 -184
  227. mlrun/runtimes/__init__.py +83 -15
  228. mlrun/runtimes/base.py +51 -35
  229. mlrun/runtimes/daskjob.py +17 -10
  230. mlrun/runtimes/databricks_job/databricks_cancel_task.py +1 -1
  231. mlrun/runtimes/databricks_job/databricks_runtime.py +8 -7
  232. mlrun/runtimes/databricks_job/databricks_wrapper.py +1 -1
  233. mlrun/runtimes/funcdoc.py +1 -29
  234. mlrun/runtimes/function_reference.py +1 -1
  235. mlrun/runtimes/kubejob.py +34 -128
  236. mlrun/runtimes/local.py +40 -11
  237. mlrun/runtimes/mpijob/__init__.py +0 -20
  238. mlrun/runtimes/mpijob/abstract.py +9 -10
  239. mlrun/runtimes/mpijob/v1.py +1 -1
  240. mlrun/{model_monitoring/stores/models/sqlite.py → runtimes/nuclio/__init__.py} +7 -9
  241. mlrun/runtimes/nuclio/api_gateway.py +769 -0
  242. mlrun/runtimes/nuclio/application/__init__.py +15 -0
  243. mlrun/runtimes/nuclio/application/application.py +758 -0
  244. mlrun/runtimes/nuclio/application/reverse_proxy.go +95 -0
  245. mlrun/runtimes/{function.py → nuclio/function.py} +200 -83
  246. mlrun/runtimes/{nuclio.py → nuclio/nuclio.py} +6 -6
  247. mlrun/runtimes/{serving.py → nuclio/serving.py} +65 -68
  248. mlrun/runtimes/pod.py +281 -101
  249. mlrun/runtimes/remotesparkjob.py +12 -9
  250. mlrun/runtimes/sparkjob/spark3job.py +67 -51
  251. mlrun/runtimes/utils.py +41 -75
  252. mlrun/secrets.py +9 -5
  253. mlrun/serving/__init__.py +8 -1
  254. mlrun/serving/remote.py +2 -7
  255. mlrun/serving/routers.py +85 -69
  256. mlrun/serving/server.py +69 -44
  257. mlrun/serving/states.py +209 -36
  258. mlrun/serving/utils.py +22 -14
  259. mlrun/serving/v1_serving.py +6 -7
  260. mlrun/serving/v2_serving.py +129 -54
  261. mlrun/track/tracker.py +2 -1
  262. mlrun/track/tracker_manager.py +3 -3
  263. mlrun/track/trackers/mlflow_tracker.py +6 -2
  264. mlrun/utils/async_http.py +6 -8
  265. mlrun/utils/azure_vault.py +1 -1
  266. mlrun/utils/clones.py +1 -2
  267. mlrun/utils/condition_evaluator.py +3 -3
  268. mlrun/utils/db.py +21 -3
  269. mlrun/utils/helpers.py +405 -225
  270. mlrun/utils/http.py +3 -6
  271. mlrun/utils/logger.py +112 -16
  272. mlrun/utils/notifications/notification/__init__.py +17 -13
  273. mlrun/utils/notifications/notification/base.py +50 -2
  274. mlrun/utils/notifications/notification/console.py +2 -0
  275. mlrun/utils/notifications/notification/git.py +24 -1
  276. mlrun/utils/notifications/notification/ipython.py +3 -1
  277. mlrun/utils/notifications/notification/slack.py +96 -21
  278. mlrun/utils/notifications/notification/webhook.py +59 -2
  279. mlrun/utils/notifications/notification_pusher.py +149 -30
  280. mlrun/utils/regex.py +9 -0
  281. mlrun/utils/retryer.py +208 -0
  282. mlrun/utils/singleton.py +1 -1
  283. mlrun/utils/v3io_clients.py +4 -6
  284. mlrun/utils/version/version.json +2 -2
  285. mlrun/utils/version/version.py +2 -6
  286. mlrun-1.7.0.dist-info/METADATA +378 -0
  287. mlrun-1.7.0.dist-info/RECORD +351 -0
  288. {mlrun-1.6.4rc7.dist-info → mlrun-1.7.0.dist-info}/WHEEL +1 -1
  289. mlrun/feature_store/retrieval/conversion.py +0 -273
  290. mlrun/kfpops.py +0 -868
  291. mlrun/model_monitoring/application.py +0 -310
  292. mlrun/model_monitoring/batch.py +0 -1095
  293. mlrun/model_monitoring/prometheus.py +0 -219
  294. mlrun/model_monitoring/stores/__init__.py +0 -111
  295. mlrun/model_monitoring/stores/kv_model_endpoint_store.py +0 -576
  296. mlrun/model_monitoring/stores/model_endpoint_store.py +0 -147
  297. mlrun/model_monitoring/stores/models/__init__.py +0 -27
  298. mlrun/model_monitoring/stores/models/base.py +0 -84
  299. mlrun/model_monitoring/stores/sql_model_endpoint_store.py +0 -384
  300. mlrun/platforms/other.py +0 -306
  301. mlrun-1.6.4rc7.dist-info/METADATA +0 -272
  302. mlrun-1.6.4rc7.dist-info/RECORD +0 -314
  303. {mlrun-1.6.4rc7.dist-info → mlrun-1.7.0.dist-info}/LICENSE +0 -0
  304. {mlrun-1.6.4rc7.dist-info → mlrun-1.7.0.dist-info}/entry_points.txt +0 -0
  305. {mlrun-1.6.4rc7.dist-info → mlrun-1.7.0.dist-info}/top_level.txt +0 -0
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  #
15
- from typing import Dict, List, Union
15
+ from typing import Union
16
16
 
17
17
  import numpy as np
18
18
  import pandas as pd
@@ -32,7 +32,7 @@ class Estimator:
32
32
  def __init__(
33
33
  self,
34
34
  context: mlrun.MLClientCtx = None,
35
- metrics: List[Metric] = None,
35
+ metrics: list[Metric] = None,
36
36
  ):
37
37
  """
38
38
  Initialize an estimator with the given metrics. The estimator will log the calculated results using the given
@@ -62,7 +62,7 @@ class Estimator:
62
62
  return self._context
63
63
 
64
64
  @property
65
- def results(self) -> Dict[str, float]:
65
+ def results(self) -> dict[str, float]:
66
66
  """
67
67
  Get the logged results.
68
68
 
@@ -86,7 +86,7 @@ class Estimator:
86
86
  """
87
87
  self._context = context
88
88
 
89
- def set_metrics(self, metrics: List[Metric]):
89
+ def set_metrics(self, metrics: list[Metric]):
90
90
  """
91
91
  Update the metrics of this logger to the given list of metrics here.
92
92
 
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  #
15
15
  from abc import ABC
16
- from typing import Dict, List, Union
16
+ from typing import Union
17
17
 
18
18
  import sklearn
19
19
  from sklearn.preprocessing import LabelBinarizer
@@ -40,14 +40,14 @@ class MetricsLibrary(ABC):
40
40
  def get_metrics(
41
41
  cls,
42
42
  metrics: Union[
43
- List[Metric],
44
- List[SKLearnTypes.MetricEntryType],
45
- Dict[str, SKLearnTypes.MetricEntryType],
43
+ list[Metric],
44
+ list[SKLearnTypes.MetricEntryType],
45
+ dict[str, SKLearnTypes.MetricEntryType],
46
46
  ] = None,
47
47
  context: mlrun.MLClientCtx = None,
48
48
  include_default: bool = True,
49
49
  **default_kwargs,
50
- ) -> List[Metric]:
50
+ ) -> list[Metric]:
51
51
  """
52
52
  Get metrics for a run. The metrics will be taken from the provided metrics / configuration via code, from
53
53
  provided configuration via MLRun context and if the 'include_default' is True, from the metric library's
@@ -87,11 +87,11 @@ class MetricsLibrary(ABC):
87
87
  def _parse(
88
88
  cls,
89
89
  metrics: Union[
90
- List[Metric],
91
- List[SKLearnTypes.MetricEntryType],
92
- Dict[str, SKLearnTypes.MetricEntryType],
90
+ list[Metric],
91
+ list[SKLearnTypes.MetricEntryType],
92
+ dict[str, SKLearnTypes.MetricEntryType],
93
93
  ],
94
- ) -> List[Metric]:
94
+ ) -> list[Metric]:
95
95
  """
96
96
  Parse the given metrics by the possible rules of the framework implementing.
97
97
 
@@ -116,8 +116,8 @@ class MetricsLibrary(ABC):
116
116
 
117
117
  @classmethod
118
118
  def _from_list(
119
- cls, metrics_list: List[Union[Metric, SKLearnTypes.MetricEntryType]]
120
- ) -> List[Metric]:
119
+ cls, metrics_list: list[Union[Metric, SKLearnTypes.MetricEntryType]]
120
+ ) -> list[Metric]:
121
121
  """
122
122
  Collect the given metrics configurations from a list. The metrics names will be chosen by the following rules:
123
123
 
@@ -143,8 +143,8 @@ class MetricsLibrary(ABC):
143
143
 
144
144
  @classmethod
145
145
  def _from_dict(
146
- cls, metrics_dictionary: Dict[str, SKLearnTypes.MetricEntryType]
147
- ) -> List[Metric]:
146
+ cls, metrics_dictionary: dict[str, SKLearnTypes.MetricEntryType]
147
+ ) -> list[Metric]:
148
148
  """
149
149
  Collect the given metrics configurations from a dictionary.
150
150
 
@@ -165,7 +165,7 @@ class MetricsLibrary(ABC):
165
165
  @classmethod
166
166
  def _default(
167
167
  cls, model: SKLearnTypes.ModelType, y: SKLearnTypes.DatasetType = None
168
- ) -> List[Metric]:
168
+ ) -> list[Metric]:
169
169
  """
170
170
  Get the default metrics list according to the algorithm functionality.
171
171
 
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  #
15
15
  from abc import ABC
16
- from typing import List
17
16
 
18
17
  import mlrun
19
18
 
@@ -75,9 +74,7 @@ class SKLearnMLRunInterface(MLRunInterface, ABC):
75
74
  cls._REPLACED_METHODS.remove("predict_proba")
76
75
 
77
76
  # Add the interface to the model:
78
- super(SKLearnMLRunInterface, cls).add_interface(
79
- obj=obj, restoration=restoration
80
- )
77
+ super().add_interface(obj=obj, restoration=restoration)
81
78
 
82
79
  # Restore the '_REPLACED_METHODS' list for next models:
83
80
  if "predict_proba" not in cls._REPLACED_METHODS:
@@ -100,7 +97,7 @@ class SKLearnMLRunInterface(MLRunInterface, ABC):
100
97
 
101
98
  def wrapper(
102
99
  self: SKLearnTypes.ModelType,
103
- X: SKLearnTypes.DatasetType,
100
+ X: SKLearnTypes.DatasetType, # noqa: N803 - should be lowercase "x", kept for BC
104
101
  y: SKLearnTypes.DatasetType = None,
105
102
  *args,
106
103
  **kwargs,
@@ -127,7 +124,12 @@ class SKLearnMLRunInterface(MLRunInterface, ABC):
127
124
 
128
125
  return wrapper
129
126
 
130
- def mlrun_predict(self, X: SKLearnTypes.DatasetType, *args, **kwargs):
127
+ def mlrun_predict(
128
+ self,
129
+ X: SKLearnTypes.DatasetType, # noqa: N803 - should be lowercase "x", kept for BC
130
+ *args,
131
+ **kwargs,
132
+ ):
131
133
  """
132
134
  MLRun's wrapper for the common ML API predict method.
133
135
  """
@@ -139,7 +141,12 @@ class SKLearnMLRunInterface(MLRunInterface, ABC):
139
141
 
140
142
  return y_pred
141
143
 
142
- def mlrun_predict_proba(self, X: SKLearnTypes.DatasetType, *args, **kwargs):
144
+ def mlrun_predict_proba(
145
+ self,
146
+ X: SKLearnTypes.DatasetType, # noqa: N803 - should be lowercase "x", kept for BC
147
+ *args,
148
+ **kwargs,
149
+ ):
143
150
  """
144
151
  MLRun's wrapper for the common ML API predict_proba method.
145
152
  """
@@ -154,8 +161,8 @@ class SKLearnMLRunInterface(MLRunInterface, ABC):
154
161
  def configure_logging(
155
162
  self,
156
163
  context: mlrun.MLClientCtx = None,
157
- plans: List[MLPlan] = None,
158
- metrics: List[Metric] = None,
164
+ plans: list[MLPlan] = None,
165
+ metrics: list[Metric] = None,
159
166
  x_test: SKLearnTypes.DatasetType = None,
160
167
  y_test: SKLearnTypes.DatasetType = None,
161
168
  model_handler: MLModelHandler = None,
@@ -59,7 +59,7 @@ class SKLearnModelHandler(MLModelHandler):
59
59
 
60
60
  :return The saved model additional artifacts (if needed) dictionary if context is available and None otherwise.
61
61
  """
62
- super(SKLearnModelHandler, self).save(output_path=output_path)
62
+ super().save(output_path=output_path)
63
63
 
64
64
  # Save the model pkl file:
65
65
  self._model_file = f"{self._model_name}.pkl"
@@ -73,7 +73,7 @@ class SKLearnModelHandler(MLModelHandler):
73
73
  Load the specified model in this handler. Additional parameters for the class initializer can be passed via the
74
74
  kwargs dictionary.
75
75
  """
76
- super(SKLearnModelHandler, self).load()
76
+ super().load()
77
77
 
78
78
  # Load from a pkl file:
79
79
  with open(self._model_file, "rb") as pickle_file:
@@ -13,11 +13,12 @@
13
13
  # limitations under the License.
14
14
  #
15
15
  # flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx
16
- from typing import Any, Dict, List, Union
16
+ from typing import Any, Union
17
17
 
18
18
  from tensorflow import keras
19
19
 
20
20
  import mlrun
21
+ import mlrun.common.constants as mlrun_constants
21
22
 
22
23
  from .callbacks import MLRunLoggingCallback, TensorboardLoggingCallback
23
24
  from .mlrun_interface import TFKerasMLRunInterface
@@ -33,14 +34,14 @@ def apply_mlrun(
33
34
  model_path: str = None,
34
35
  model_format: str = TFKerasModelHandler.ModelFormats.SAVED_MODEL,
35
36
  save_traces: bool = False,
36
- modules_map: Union[Dict[str, Union[None, str, List[str]]], str] = None,
37
- custom_objects_map: Union[Dict[str, Union[str, List[str]]], str] = None,
37
+ modules_map: Union[dict[str, Union[None, str, list[str]]], str] = None,
38
+ custom_objects_map: Union[dict[str, Union[str, list[str]]], str] = None,
38
39
  custom_objects_directory: str = None,
39
40
  context: mlrun.MLClientCtx = None,
40
41
  auto_log: bool = True,
41
42
  tensorboard_directory: str = None,
42
- mlrun_callback_kwargs: Dict[str, Any] = None,
43
- tensorboard_callback_kwargs: Dict[str, Any] = None,
43
+ mlrun_callback_kwargs: dict[str, Any] = None,
44
+ tensorboard_callback_kwargs: dict[str, Any] = None,
44
45
  use_horovod: bool = None,
45
46
  **kwargs,
46
47
  ) -> TFKerasModelHandler:
@@ -85,7 +86,7 @@ def apply_mlrun(
85
86
 
86
87
  {
87
88
  "/.../custom_optimizer.py": "optimizer",
88
- "/.../custom_layers.py": ["layer1", "layer2"]
89
+ "/.../custom_layers.py": ["layer1", "layer2"],
89
90
  }
90
91
 
91
92
  All the paths will be accessed from the given 'custom_objects_directory',
@@ -126,7 +127,9 @@ def apply_mlrun(
126
127
  # # Use horovod:
127
128
  if use_horovod is None:
128
129
  use_horovod = (
129
- context.labels.get("kind", "") == "mpijob" if context is not None else False
130
+ context.labels.get(mlrun_constants.MLRunInternalLabels.kind, "") == "mpijob"
131
+ if context is not None
132
+ else False
130
133
  )
131
134
 
132
135
  # Create a model handler:
@@ -12,12 +12,12 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  #
15
- from typing import Callable, Dict, List, Union
15
+ from typing import Callable, Union
16
16
 
17
17
  import numpy as np
18
18
  import tensorflow as tf
19
19
  from tensorflow import Tensor, Variable
20
- from tensorflow.keras.callbacks import Callback
20
+ from tensorflow.python.keras.callbacks import Callback
21
21
 
22
22
  import mlrun
23
23
 
@@ -36,11 +36,11 @@ class LoggingCallback(Callback):
36
36
  def __init__(
37
37
  self,
38
38
  context: mlrun.MLClientCtx = None,
39
- dynamic_hyperparameters: Dict[
40
- str, Union[List[Union[str, int]], Callable[[], TFKerasTypes.TrackableType]]
39
+ dynamic_hyperparameters: dict[
40
+ str, Union[list[Union[str, int]], Callable[[], TFKerasTypes.TrackableType]]
41
41
  ] = None,
42
- static_hyperparameters: Dict[
43
- str, Union[TFKerasTypes.TrackableType, List[Union[str, int]]]
42
+ static_hyperparameters: dict[
43
+ str, Union[TFKerasTypes.TrackableType, list[Union[str, int]]]
44
44
  ] = None,
45
45
  auto_log: bool = False,
46
46
  ):
@@ -70,7 +70,7 @@ class LoggingCallback(Callback):
70
70
  :param auto_log: Whether or not to enable auto logging, trying to track common static and dynamic
71
71
  hyperparameters.
72
72
  """
73
- super(LoggingCallback, self).__init__()
73
+ super().__init__()
74
74
  self._supports_tf_logs = True
75
75
 
76
76
  # Store the configurations:
@@ -93,7 +93,7 @@ class LoggingCallback(Callback):
93
93
  self._is_training = None # type: bool
94
94
  self._auto_log = auto_log
95
95
 
96
- def get_training_results(self) -> Dict[str, List[List[float]]]:
96
+ def get_training_results(self) -> dict[str, list[list[float]]]:
97
97
  """
98
98
  Get the training results logged. The results will be stored in a dictionary where each key is the metric name
99
99
  and the value is a list of lists of values. The first list is by epoch and the second list is by iteration
@@ -103,7 +103,7 @@ class LoggingCallback(Callback):
103
103
  """
104
104
  return self._logger.training_results
105
105
 
106
- def get_validation_results(self) -> Dict[str, List[List[float]]]:
106
+ def get_validation_results(self) -> dict[str, list[list[float]]]:
107
107
  """
108
108
  Get the validation results logged. The results will be stored in a dictionary where each key is the metric name
109
109
  and the value is a list of lists of values. The first list is by epoch and the second list is by iteration
@@ -113,7 +113,7 @@ class LoggingCallback(Callback):
113
113
  """
114
114
  return self._logger.validation_results
115
115
 
116
- def get_training_summaries(self) -> Dict[str, List[float]]:
116
+ def get_training_summaries(self) -> dict[str, list[float]]:
117
117
  """
118
118
  Get the training summaries of the metrics results. The summaries will be stored in a dictionary where each key
119
119
  is the metric names and the value is a list of all the summary values per epoch.
@@ -122,7 +122,7 @@ class LoggingCallback(Callback):
122
122
  """
123
123
  return self._logger.training_summaries
124
124
 
125
- def get_validation_summaries(self) -> Dict[str, List[float]]:
125
+ def get_validation_summaries(self) -> dict[str, list[float]]:
126
126
  """
127
127
  Get the validation summaries of the metrics results. The summaries will be stored in a dictionary where each key
128
128
  is the metric names and the value is a list of all the summary values per epoch.
@@ -131,7 +131,7 @@ class LoggingCallback(Callback):
131
131
  """
132
132
  return self._logger.validation_summaries
133
133
 
134
- def get_static_hyperparameters(self) -> Dict[str, TFKerasTypes.TrackableType]:
134
+ def get_static_hyperparameters(self) -> dict[str, TFKerasTypes.TrackableType]:
135
135
  """
136
136
  Get the static hyperparameters logged. The hyperparameters will be stored in a dictionary where each key is the
137
137
  hyperparameter name and the value is his logged value.
@@ -142,7 +142,7 @@ class LoggingCallback(Callback):
142
142
 
143
143
  def get_dynamic_hyperparameters(
144
144
  self,
145
- ) -> Dict[str, List[TFKerasTypes.TrackableType]]:
145
+ ) -> dict[str, list[TFKerasTypes.TrackableType]]:
146
146
  """
147
147
  Get the dynamic hyperparameters logged. The hyperparameters will be stored in a dictionary where each key is the
148
148
  hyperparameter name and the value is a list of his logged values per epoch.
@@ -329,7 +329,7 @@ class LoggingCallback(Callback):
329
329
 
330
330
  # Static hyperparameters:
331
331
  for name, value in self._static_hyperparameters_keys.items():
332
- if isinstance(value, List):
332
+ if isinstance(value, list):
333
333
  # Its a parameter that needed to be extracted via key chain.
334
334
  self._logger.log_static_hyperparameter(
335
335
  parameter_name=name,
@@ -398,7 +398,7 @@ class LoggingCallback(Callback):
398
398
  def _get_hyperparameter(
399
399
  self,
400
400
  key_chain: Union[
401
- Callable[[], TFKerasTypes.TrackableType], List[Union[str, int]]
401
+ Callable[[], TFKerasTypes.TrackableType], list[Union[str, int]]
402
402
  ],
403
403
  ) -> TFKerasTypes.TrackableType:
404
404
  """
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  #
15
- from typing import Callable, Dict, List, Union
15
+ from typing import Callable, Union
16
16
 
17
17
  import mlrun
18
18
  from mlrun.artifacts import Artifact
@@ -50,16 +50,16 @@ class MLRunLoggingCallback(LoggingCallback):
50
50
  context: mlrun.MLClientCtx,
51
51
  model_handler: TFKerasModelHandler,
52
52
  log_model_tag: str = "",
53
- log_model_labels: Dict[str, TFKerasTypes.TrackableType] = None,
54
- log_model_parameters: Dict[str, TFKerasTypes.TrackableType] = None,
55
- log_model_extra_data: Dict[
53
+ log_model_labels: dict[str, TFKerasTypes.TrackableType] = None,
54
+ log_model_parameters: dict[str, TFKerasTypes.TrackableType] = None,
55
+ log_model_extra_data: dict[
56
56
  str, Union[TFKerasTypes.TrackableType, Artifact]
57
57
  ] = None,
58
- dynamic_hyperparameters: Dict[
59
- str, Union[List[Union[str, int]], Callable[[], TFKerasTypes.TrackableType]]
58
+ dynamic_hyperparameters: dict[
59
+ str, Union[list[Union[str, int]], Callable[[], TFKerasTypes.TrackableType]]
60
60
  ] = None,
61
- static_hyperparameters: Dict[
62
- str, Union[TFKerasTypes, List[Union[str, int]]]
61
+ static_hyperparameters: dict[
62
+ str, Union[TFKerasTypes, list[Union[str, int]]]
63
63
  ] = None,
64
64
  auto_log: bool = False,
65
65
  ):
@@ -97,7 +97,7 @@ class MLRunLoggingCallback(LoggingCallback):
97
97
  trying to track common static and dynamic hyperparameters such as learning
98
98
  rate.
99
99
  """
100
- super(MLRunLoggingCallback, self).__init__(
100
+ super().__init__(
101
101
  dynamic_hyperparameters=dynamic_hyperparameters,
102
102
  static_hyperparameters=static_hyperparameters,
103
103
  auto_log=auto_log,
@@ -134,7 +134,7 @@ class MLRunLoggingCallback(LoggingCallback):
134
134
  :param logs: Currently no data is passed to this argument for this method but that may change in the
135
135
  future.
136
136
  """
137
- super(MLRunLoggingCallback, self).on_test_end(logs=logs)
137
+ super().on_test_end(logs=logs)
138
138
 
139
139
  # Check if its part of evaluation. If so, end the run:
140
140
  if self._logger.mode == LoggingMode.EVALUATION:
@@ -151,7 +151,7 @@ class MLRunLoggingCallback(LoggingCallback):
151
151
  performed. Validation result keys are prefixed with `val_`. For training epoch, the values of the
152
152
  `Model`'s metrics are returned. Example : `{'loss': 0.2, 'acc': 0.7}`.
153
153
  """
154
- super(MLRunLoggingCallback, self).on_epoch_end(epoch=epoch)
154
+ super().on_epoch_end(epoch=epoch)
155
155
 
156
156
  # Log the current epoch's results:
157
157
  self._logger.log_epoch_to_context(epoch=epoch)
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  #
15
15
  from datetime import datetime
16
- from typing import Callable, Dict, List, Union
16
+ from typing import Callable, Union
17
17
 
18
18
  import tensorflow as tf
19
19
  from packaging import version
@@ -38,7 +38,7 @@ class _TFKerasTensorboardLogger(TensorboardLogger):
38
38
 
39
39
  def __init__(
40
40
  self,
41
- statistics_functions: List[Callable[[Union[Variable]], Union[float, Variable]]],
41
+ statistics_functions: list[Callable[[Union[Variable]], Union[float, Variable]]],
42
42
  context: mlrun.MLClientCtx = None,
43
43
  tensorboard_directory: str = None,
44
44
  run_name: str = None,
@@ -67,7 +67,7 @@ class _TFKerasTensorboardLogger(TensorboardLogger):
67
67
  update. Notice that writing to tensorboard too frequently may cause the training
68
68
  to be slower. Default: 'epoch'.
69
69
  """
70
- super(_TFKerasTensorboardLogger, self).__init__(
70
+ super().__init__(
71
71
  statistics_functions=statistics_functions,
72
72
  context=context,
73
73
  tensorboard_directory=tensorboard_directory,
@@ -255,15 +255,15 @@ class TensorboardLoggingCallback(LoggingCallback):
255
255
  context: mlrun.MLClientCtx = None,
256
256
  tensorboard_directory: str = None,
257
257
  run_name: str = None,
258
- weights: Union[bool, List[str]] = False,
259
- statistics_functions: List[
258
+ weights: Union[bool, list[str]] = False,
259
+ statistics_functions: list[
260
260
  Callable[[Union[Variable, Tensor]], Union[float, Tensor]]
261
261
  ] = None,
262
- dynamic_hyperparameters: Dict[
263
- str, Union[List[Union[str, int]], Callable[[], TFKerasTypes.TrackableType]]
262
+ dynamic_hyperparameters: dict[
263
+ str, Union[list[Union[str, int]], Callable[[], TFKerasTypes.TrackableType]]
264
264
  ] = None,
265
- static_hyperparameters: Dict[
266
- str, Union[TFKerasTypes.TrackableType, List[Union[str, int]]]
265
+ static_hyperparameters: dict[
266
+ str, Union[TFKerasTypes.TrackableType, list[Union[str, int]]]
267
267
  ] = None,
268
268
  update_frequency: Union[int, str] = "epoch",
269
269
  auto_log: bool = False,
@@ -325,7 +325,7 @@ class TensorboardLoggingCallback(LoggingCallback):
325
325
  :raise MLRunInvalidArgumentError: In case both 'context' and 'tensorboard_directory' parameters were not given
326
326
  or the 'update_frequency' was incorrect.
327
327
  """
328
- super(TensorboardLoggingCallback, self).__init__(
328
+ super().__init__(
329
329
  dynamic_hyperparameters=dynamic_hyperparameters,
330
330
  static_hyperparameters=static_hyperparameters,
331
331
  auto_log=auto_log,
@@ -352,7 +352,7 @@ class TensorboardLoggingCallback(LoggingCallback):
352
352
  self._logged_model = False
353
353
  self._logged_hyperparameters = False
354
354
 
355
- def get_weights(self) -> Dict[str, Variable]:
355
+ def get_weights(self) -> dict[str, Variable]:
356
356
  """
357
357
  Get the weights tensors tracked. The weights will be stored in a dictionary where each key is the weight's name
358
358
  and the value is the weight's parameter (tensor).
@@ -361,7 +361,7 @@ class TensorboardLoggingCallback(LoggingCallback):
361
361
  """
362
362
  return self._logger.weights
363
363
 
364
- def get_weights_statistics(self) -> Dict[str, Dict[str, List[float]]]:
364
+ def get_weights_statistics(self) -> dict[str, dict[str, list[float]]]:
365
365
  """
366
366
  Get the weights mean results logged. The results will be stored in a dictionary where each key is the weight's
367
367
  name and the value is a list of mean values per epoch.
@@ -408,7 +408,7 @@ class TensorboardLoggingCallback(LoggingCallback):
408
408
  :param logs: Currently the output of the last call to `on_epoch_end()` is passed to this argument for this
409
409
  method but that may change in the future.
410
410
  """
411
- super(TensorboardLoggingCallback, self).on_train_end()
411
+ super().on_train_end()
412
412
 
413
413
  # Write the final run summary:
414
414
  self._logger.write_final_summary_text()
@@ -453,7 +453,7 @@ class TensorboardLoggingCallback(LoggingCallback):
453
453
  :param logs: Currently no data is passed to this argument for this method but that may change in the
454
454
  future.
455
455
  """
456
- super(TensorboardLoggingCallback, self).on_test_end(logs=logs)
456
+ super().on_test_end(logs=logs)
457
457
 
458
458
  # Check if needed to end the run (in case of evaluation and not training):
459
459
  if not self._is_training:
@@ -477,7 +477,7 @@ class TensorboardLoggingCallback(LoggingCallback):
477
477
  `Model`'s metrics are returned. Example : `{'loss': 0.2, 'acc': 0.7}`.
478
478
  """
479
479
  # Update the dynamic hyperparameters
480
- super(TensorboardLoggingCallback, self).on_epoch_end(epoch=epoch)
480
+ super().on_epoch_end(epoch=epoch)
481
481
 
482
482
  # Log the weights statistics:
483
483
  self._logger.log_weights_statistics()
@@ -515,9 +515,7 @@ class TensorboardLoggingCallback(LoggingCallback):
515
515
  :param logs: Aggregated metric results up until this batch.
516
516
  """
517
517
  # Log the batch's results:
518
- super(TensorboardLoggingCallback, self).on_train_batch_end(
519
- batch=batch, logs=logs
520
- )
518
+ super().on_train_batch_end(batch=batch, logs=logs)
521
519
 
522
520
  # Write the batch loss and metrics results to their graphs:
523
521
  self._logger.write_training_results()
@@ -540,9 +538,7 @@ class TensorboardLoggingCallback(LoggingCallback):
540
538
  :param logs: Aggregated metric results up until this batch.
541
539
  """
542
540
  # Log the batch's results:
543
- super(TensorboardLoggingCallback, self).on_test_batch_end(
544
- batch=batch, logs=logs
545
- )
541
+ super().on_test_batch_end(batch=batch, logs=logs)
546
542
 
547
543
  # Write the batch loss and metrics results to their graphs:
548
544
  self._logger.write_validation_results()
@@ -555,7 +551,7 @@ class TensorboardLoggingCallback(LoggingCallback):
555
551
 
556
552
  @staticmethod
557
553
  def get_default_weight_statistics_list() -> (
558
- List[Callable[[Union[Variable, Tensor]], Union[float, Tensor]]]
554
+ list[Callable[[Union[Variable, Tensor]], Union[float, Tensor]]]
559
555
  ):
560
556
  """
561
557
  Get the default list of statistics functions being applied on the tracked weights each epoch.
@@ -569,7 +565,7 @@ class TensorboardLoggingCallback(LoggingCallback):
569
565
  After the trainer / evaluator run begins, this method will be called to setup the results, hyperparameters
570
566
  and weights dictionaries for logging.
571
567
  """
572
- super(TensorboardLoggingCallback, self)._setup_run()
568
+ super()._setup_run()
573
569
 
574
570
  # Check if needed to track weights:
575
571
  if self._tracked_weights is False:
@@ -15,11 +15,12 @@
15
15
  import importlib
16
16
  import os
17
17
  from abc import ABC
18
- from typing import List, Tuple, Union
18
+ from typing import Union
19
19
 
20
20
  import tensorflow as tf
21
21
  from tensorflow import keras
22
- from tensorflow.keras.callbacks import (
22
+ from tensorflow.keras.optimizers import Optimizer
23
+ from tensorflow.python.keras.callbacks import (
23
24
  BaseLogger,
24
25
  Callback,
25
26
  CSVLogger,
@@ -27,7 +28,6 @@ from tensorflow.keras.callbacks import (
27
28
  ProgbarLogger,
28
29
  TensorBoard,
29
30
  )
30
- from tensorflow.keras.optimizers import Optimizer
31
31
 
32
32
  import mlrun
33
33
 
@@ -88,9 +88,7 @@ class TFKerasMLRunInterface(MLRunInterface, ABC):
88
88
  :param restoration: Restoration information tuple as returned from 'remove_interface' in order to
89
89
  add the interface in a certain state.
90
90
  """
91
- super(TFKerasMLRunInterface, cls).add_interface(
92
- obj=obj, restoration=restoration
93
- )
91
+ super().add_interface(obj=obj, restoration=restoration)
94
92
 
95
93
  def mlrun_compile(self, *args, **kwargs):
96
94
  """
@@ -237,7 +235,7 @@ class TFKerasMLRunInterface(MLRunInterface, ABC):
237
235
  """
238
236
  self._RANK_0_ONLY_CALLBACKS.add(callback_name)
239
237
 
240
- def _pre_compile(self, optimizer: Optimizer) -> Tuple[Optimizer, Union[bool, None]]:
238
+ def _pre_compile(self, optimizer: Optimizer) -> tuple[Optimizer, Union[bool, None]]:
241
239
  """
242
240
  Method to call before calling 'compile' to setup the run and inputs for using horovod.
243
241
 
@@ -295,11 +293,11 @@ class TFKerasMLRunInterface(MLRunInterface, ABC):
295
293
 
296
294
  def _pre_fit(
297
295
  self,
298
- callbacks: List[Callback],
296
+ callbacks: list[Callback],
299
297
  verbose: int,
300
298
  steps_per_epoch: Union[int, None],
301
299
  validation_steps: Union[int, None],
302
- ) -> Tuple[List[Callback], int, Union[int, None], Union[int, None]]:
300
+ ) -> tuple[list[Callback], int, Union[int, None], Union[int, None]]:
303
301
  """
304
302
  Method to call before calling 'fit' to setup the run and inputs for using horovod.
305
303
 
@@ -366,9 +364,9 @@ class TFKerasMLRunInterface(MLRunInterface, ABC):
366
364
 
367
365
  def _pre_evaluate(
368
366
  self,
369
- callbacks: List[Callback],
367
+ callbacks: list[Callback],
370
368
  steps: Union[int, None],
371
- ) -> Tuple[List[Callback], Union[int, None]]:
369
+ ) -> tuple[list[Callback], Union[int, None]]:
372
370
  """
373
371
  Method to call before calling 'evaluate' to setup the run and inputs for using horovod.
374
372