mlrun 1.6.4rc8__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlrun might be problematic; see the package registry page for this release for more details.

Files changed (305)
  1. mlrun/__init__.py +11 -1
  2. mlrun/__main__.py +40 -122
  3. mlrun/alerts/__init__.py +15 -0
  4. mlrun/alerts/alert.py +248 -0
  5. mlrun/api/schemas/__init__.py +5 -4
  6. mlrun/artifacts/__init__.py +8 -3
  7. mlrun/artifacts/base.py +47 -257
  8. mlrun/artifacts/dataset.py +11 -192
  9. mlrun/artifacts/manager.py +79 -47
  10. mlrun/artifacts/model.py +31 -159
  11. mlrun/artifacts/plots.py +23 -380
  12. mlrun/common/constants.py +74 -1
  13. mlrun/common/db/sql_session.py +5 -5
  14. mlrun/common/formatters/__init__.py +21 -0
  15. mlrun/common/formatters/artifact.py +45 -0
  16. mlrun/common/formatters/base.py +113 -0
  17. mlrun/common/formatters/feature_set.py +33 -0
  18. mlrun/common/formatters/function.py +46 -0
  19. mlrun/common/formatters/pipeline.py +53 -0
  20. mlrun/common/formatters/project.py +51 -0
  21. mlrun/common/formatters/run.py +29 -0
  22. mlrun/common/helpers.py +12 -3
  23. mlrun/common/model_monitoring/helpers.py +9 -5
  24. mlrun/{runtimes → common/runtimes}/constants.py +37 -9
  25. mlrun/common/schemas/__init__.py +31 -5
  26. mlrun/common/schemas/alert.py +202 -0
  27. mlrun/common/schemas/api_gateway.py +196 -0
  28. mlrun/common/schemas/artifact.py +25 -4
  29. mlrun/common/schemas/auth.py +16 -5
  30. mlrun/common/schemas/background_task.py +1 -1
  31. mlrun/common/schemas/client_spec.py +4 -2
  32. mlrun/common/schemas/common.py +7 -4
  33. mlrun/common/schemas/constants.py +3 -0
  34. mlrun/common/schemas/feature_store.py +74 -44
  35. mlrun/common/schemas/frontend_spec.py +15 -7
  36. mlrun/common/schemas/function.py +12 -1
  37. mlrun/common/schemas/hub.py +11 -18
  38. mlrun/common/schemas/memory_reports.py +2 -2
  39. mlrun/common/schemas/model_monitoring/__init__.py +20 -4
  40. mlrun/common/schemas/model_monitoring/constants.py +123 -42
  41. mlrun/common/schemas/model_monitoring/grafana.py +13 -9
  42. mlrun/common/schemas/model_monitoring/model_endpoints.py +101 -54
  43. mlrun/common/schemas/notification.py +71 -14
  44. mlrun/common/schemas/object.py +2 -2
  45. mlrun/{model_monitoring/controller_handler.py → common/schemas/pagination.py} +9 -12
  46. mlrun/common/schemas/pipeline.py +8 -1
  47. mlrun/common/schemas/project.py +69 -18
  48. mlrun/common/schemas/runs.py +7 -1
  49. mlrun/common/schemas/runtime_resource.py +8 -12
  50. mlrun/common/schemas/schedule.py +4 -4
  51. mlrun/common/schemas/tag.py +1 -2
  52. mlrun/common/schemas/workflow.py +12 -4
  53. mlrun/common/types.py +14 -1
  54. mlrun/config.py +154 -69
  55. mlrun/data_types/data_types.py +6 -1
  56. mlrun/data_types/spark.py +2 -2
  57. mlrun/data_types/to_pandas.py +67 -37
  58. mlrun/datastore/__init__.py +6 -8
  59. mlrun/datastore/alibaba_oss.py +131 -0
  60. mlrun/datastore/azure_blob.py +143 -42
  61. mlrun/datastore/base.py +102 -58
  62. mlrun/datastore/datastore.py +34 -13
  63. mlrun/datastore/datastore_profile.py +146 -20
  64. mlrun/datastore/dbfs_store.py +3 -7
  65. mlrun/datastore/filestore.py +1 -4
  66. mlrun/datastore/google_cloud_storage.py +97 -33
  67. mlrun/datastore/hdfs.py +56 -0
  68. mlrun/datastore/inmem.py +6 -3
  69. mlrun/datastore/redis.py +7 -2
  70. mlrun/datastore/s3.py +34 -12
  71. mlrun/datastore/snowflake_utils.py +45 -0
  72. mlrun/datastore/sources.py +303 -111
  73. mlrun/datastore/spark_utils.py +31 -2
  74. mlrun/datastore/store_resources.py +9 -7
  75. mlrun/datastore/storeytargets.py +151 -0
  76. mlrun/datastore/targets.py +453 -176
  77. mlrun/datastore/utils.py +72 -58
  78. mlrun/datastore/v3io.py +6 -1
  79. mlrun/db/base.py +274 -41
  80. mlrun/db/factory.py +1 -1
  81. mlrun/db/httpdb.py +893 -225
  82. mlrun/db/nopdb.py +291 -33
  83. mlrun/errors.py +36 -6
  84. mlrun/execution.py +115 -42
  85. mlrun/feature_store/__init__.py +0 -2
  86. mlrun/feature_store/api.py +65 -73
  87. mlrun/feature_store/common.py +7 -12
  88. mlrun/feature_store/feature_set.py +76 -55
  89. mlrun/feature_store/feature_vector.py +39 -31
  90. mlrun/feature_store/ingestion.py +7 -6
  91. mlrun/feature_store/retrieval/base.py +16 -11
  92. mlrun/feature_store/retrieval/dask_merger.py +2 -0
  93. mlrun/feature_store/retrieval/job.py +13 -4
  94. mlrun/feature_store/retrieval/local_merger.py +2 -0
  95. mlrun/feature_store/retrieval/spark_merger.py +24 -32
  96. mlrun/feature_store/steps.py +45 -34
  97. mlrun/features.py +11 -21
  98. mlrun/frameworks/_common/artifacts_library.py +9 -9
  99. mlrun/frameworks/_common/mlrun_interface.py +5 -5
  100. mlrun/frameworks/_common/model_handler.py +48 -48
  101. mlrun/frameworks/_common/plan.py +5 -6
  102. mlrun/frameworks/_common/producer.py +3 -4
  103. mlrun/frameworks/_common/utils.py +5 -5
  104. mlrun/frameworks/_dl_common/loggers/logger.py +6 -7
  105. mlrun/frameworks/_dl_common/loggers/mlrun_logger.py +9 -9
  106. mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py +23 -47
  107. mlrun/frameworks/_ml_common/artifacts_library.py +1 -2
  108. mlrun/frameworks/_ml_common/loggers/logger.py +3 -4
  109. mlrun/frameworks/_ml_common/loggers/mlrun_logger.py +4 -5
  110. mlrun/frameworks/_ml_common/model_handler.py +24 -24
  111. mlrun/frameworks/_ml_common/pkl_model_server.py +2 -2
  112. mlrun/frameworks/_ml_common/plan.py +2 -2
  113. mlrun/frameworks/_ml_common/plans/calibration_curve_plan.py +2 -3
  114. mlrun/frameworks/_ml_common/plans/confusion_matrix_plan.py +2 -3
  115. mlrun/frameworks/_ml_common/plans/dataset_plan.py +3 -3
  116. mlrun/frameworks/_ml_common/plans/feature_importance_plan.py +3 -3
  117. mlrun/frameworks/_ml_common/plans/roc_curve_plan.py +4 -4
  118. mlrun/frameworks/_ml_common/utils.py +4 -4
  119. mlrun/frameworks/auto_mlrun/auto_mlrun.py +9 -9
  120. mlrun/frameworks/huggingface/model_server.py +4 -4
  121. mlrun/frameworks/lgbm/__init__.py +33 -33
  122. mlrun/frameworks/lgbm/callbacks/callback.py +2 -4
  123. mlrun/frameworks/lgbm/callbacks/logging_callback.py +4 -5
  124. mlrun/frameworks/lgbm/callbacks/mlrun_logging_callback.py +4 -5
  125. mlrun/frameworks/lgbm/mlrun_interfaces/booster_mlrun_interface.py +1 -3
  126. mlrun/frameworks/lgbm/mlrun_interfaces/mlrun_interface.py +6 -6
  127. mlrun/frameworks/lgbm/model_handler.py +10 -10
  128. mlrun/frameworks/lgbm/model_server.py +6 -6
  129. mlrun/frameworks/lgbm/utils.py +5 -5
  130. mlrun/frameworks/onnx/dataset.py +8 -8
  131. mlrun/frameworks/onnx/mlrun_interface.py +3 -3
  132. mlrun/frameworks/onnx/model_handler.py +6 -6
  133. mlrun/frameworks/onnx/model_server.py +7 -7
  134. mlrun/frameworks/parallel_coordinates.py +6 -6
  135. mlrun/frameworks/pytorch/__init__.py +18 -18
  136. mlrun/frameworks/pytorch/callbacks/callback.py +4 -5
  137. mlrun/frameworks/pytorch/callbacks/logging_callback.py +17 -17
  138. mlrun/frameworks/pytorch/callbacks/mlrun_logging_callback.py +11 -11
  139. mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py +23 -29
  140. mlrun/frameworks/pytorch/callbacks_handler.py +38 -38
  141. mlrun/frameworks/pytorch/mlrun_interface.py +20 -20
  142. mlrun/frameworks/pytorch/model_handler.py +17 -17
  143. mlrun/frameworks/pytorch/model_server.py +7 -7
  144. mlrun/frameworks/sklearn/__init__.py +13 -13
  145. mlrun/frameworks/sklearn/estimator.py +4 -4
  146. mlrun/frameworks/sklearn/metrics_library.py +14 -14
  147. mlrun/frameworks/sklearn/mlrun_interface.py +16 -9
  148. mlrun/frameworks/sklearn/model_handler.py +2 -2
  149. mlrun/frameworks/tf_keras/__init__.py +10 -7
  150. mlrun/frameworks/tf_keras/callbacks/logging_callback.py +15 -15
  151. mlrun/frameworks/tf_keras/callbacks/mlrun_logging_callback.py +11 -11
  152. mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py +19 -23
  153. mlrun/frameworks/tf_keras/mlrun_interface.py +9 -11
  154. mlrun/frameworks/tf_keras/model_handler.py +14 -14
  155. mlrun/frameworks/tf_keras/model_server.py +6 -6
  156. mlrun/frameworks/xgboost/__init__.py +13 -13
  157. mlrun/frameworks/xgboost/model_handler.py +6 -6
  158. mlrun/k8s_utils.py +61 -17
  159. mlrun/launcher/__init__.py +1 -1
  160. mlrun/launcher/base.py +16 -15
  161. mlrun/launcher/client.py +13 -11
  162. mlrun/launcher/factory.py +1 -1
  163. mlrun/launcher/local.py +23 -13
  164. mlrun/launcher/remote.py +17 -10
  165. mlrun/lists.py +7 -6
  166. mlrun/model.py +478 -103
  167. mlrun/model_monitoring/__init__.py +1 -1
  168. mlrun/model_monitoring/api.py +163 -371
  169. mlrun/{runtimes/mpijob/v1alpha1.py → model_monitoring/applications/__init__.py} +9 -15
  170. mlrun/model_monitoring/applications/_application_steps.py +188 -0
  171. mlrun/model_monitoring/applications/base.py +108 -0
  172. mlrun/model_monitoring/applications/context.py +341 -0
  173. mlrun/model_monitoring/{evidently_application.py → applications/evidently_base.py} +27 -22
  174. mlrun/model_monitoring/applications/histogram_data_drift.py +354 -0
  175. mlrun/model_monitoring/applications/results.py +99 -0
  176. mlrun/model_monitoring/controller.py +131 -278
  177. mlrun/model_monitoring/db/__init__.py +18 -0
  178. mlrun/model_monitoring/db/stores/__init__.py +136 -0
  179. mlrun/model_monitoring/db/stores/base/__init__.py +15 -0
  180. mlrun/model_monitoring/db/stores/base/store.py +213 -0
  181. mlrun/model_monitoring/db/stores/sqldb/__init__.py +13 -0
  182. mlrun/model_monitoring/db/stores/sqldb/models/__init__.py +71 -0
  183. mlrun/model_monitoring/db/stores/sqldb/models/base.py +190 -0
  184. mlrun/model_monitoring/db/stores/sqldb/models/mysql.py +103 -0
  185. mlrun/model_monitoring/{stores/models/mysql.py → db/stores/sqldb/models/sqlite.py} +19 -13
  186. mlrun/model_monitoring/db/stores/sqldb/sql_store.py +659 -0
  187. mlrun/model_monitoring/db/stores/v3io_kv/__init__.py +13 -0
  188. mlrun/model_monitoring/db/stores/v3io_kv/kv_store.py +726 -0
  189. mlrun/model_monitoring/db/tsdb/__init__.py +105 -0
  190. mlrun/model_monitoring/db/tsdb/base.py +448 -0
  191. mlrun/model_monitoring/db/tsdb/helpers.py +30 -0
  192. mlrun/model_monitoring/db/tsdb/tdengine/__init__.py +15 -0
  193. mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +279 -0
  194. mlrun/model_monitoring/db/tsdb/tdengine/stream_graph_steps.py +42 -0
  195. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +507 -0
  196. mlrun/model_monitoring/db/tsdb/v3io/__init__.py +15 -0
  197. mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +158 -0
  198. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +849 -0
  199. mlrun/model_monitoring/features_drift_table.py +134 -106
  200. mlrun/model_monitoring/helpers.py +199 -55
  201. mlrun/model_monitoring/metrics/__init__.py +13 -0
  202. mlrun/model_monitoring/metrics/histogram_distance.py +127 -0
  203. mlrun/model_monitoring/model_endpoint.py +3 -2
  204. mlrun/model_monitoring/stream_processing.py +134 -398
  205. mlrun/model_monitoring/tracking_policy.py +9 -2
  206. mlrun/model_monitoring/writer.py +161 -125
  207. mlrun/package/__init__.py +6 -6
  208. mlrun/package/context_handler.py +5 -5
  209. mlrun/package/packager.py +7 -7
  210. mlrun/package/packagers/default_packager.py +8 -8
  211. mlrun/package/packagers/numpy_packagers.py +15 -15
  212. mlrun/package/packagers/pandas_packagers.py +5 -5
  213. mlrun/package/packagers/python_standard_library_packagers.py +10 -10
  214. mlrun/package/packagers_manager.py +19 -23
  215. mlrun/package/utils/_formatter.py +6 -6
  216. mlrun/package/utils/_pickler.py +2 -2
  217. mlrun/package/utils/_supported_format.py +4 -4
  218. mlrun/package/utils/log_hint_utils.py +2 -2
  219. mlrun/package/utils/type_hint_utils.py +4 -9
  220. mlrun/platforms/__init__.py +11 -10
  221. mlrun/platforms/iguazio.py +24 -203
  222. mlrun/projects/operations.py +52 -25
  223. mlrun/projects/pipelines.py +191 -197
  224. mlrun/projects/project.py +1227 -400
  225. mlrun/render.py +16 -19
  226. mlrun/run.py +209 -184
  227. mlrun/runtimes/__init__.py +83 -15
  228. mlrun/runtimes/base.py +51 -35
  229. mlrun/runtimes/daskjob.py +17 -10
  230. mlrun/runtimes/databricks_job/databricks_cancel_task.py +1 -1
  231. mlrun/runtimes/databricks_job/databricks_runtime.py +8 -7
  232. mlrun/runtimes/databricks_job/databricks_wrapper.py +1 -1
  233. mlrun/runtimes/funcdoc.py +1 -29
  234. mlrun/runtimes/function_reference.py +1 -1
  235. mlrun/runtimes/kubejob.py +34 -128
  236. mlrun/runtimes/local.py +40 -11
  237. mlrun/runtimes/mpijob/__init__.py +0 -20
  238. mlrun/runtimes/mpijob/abstract.py +9 -10
  239. mlrun/runtimes/mpijob/v1.py +1 -1
  240. mlrun/{model_monitoring/stores/models/sqlite.py → runtimes/nuclio/__init__.py} +7 -9
  241. mlrun/runtimes/nuclio/api_gateway.py +769 -0
  242. mlrun/runtimes/nuclio/application/__init__.py +15 -0
  243. mlrun/runtimes/nuclio/application/application.py +758 -0
  244. mlrun/runtimes/nuclio/application/reverse_proxy.go +95 -0
  245. mlrun/runtimes/{function.py → nuclio/function.py} +200 -83
  246. mlrun/runtimes/{nuclio.py → nuclio/nuclio.py} +6 -6
  247. mlrun/runtimes/{serving.py → nuclio/serving.py} +65 -68
  248. mlrun/runtimes/pod.py +281 -101
  249. mlrun/runtimes/remotesparkjob.py +12 -9
  250. mlrun/runtimes/sparkjob/spark3job.py +67 -51
  251. mlrun/runtimes/utils.py +41 -75
  252. mlrun/secrets.py +9 -5
  253. mlrun/serving/__init__.py +8 -1
  254. mlrun/serving/remote.py +2 -7
  255. mlrun/serving/routers.py +85 -69
  256. mlrun/serving/server.py +69 -44
  257. mlrun/serving/states.py +209 -36
  258. mlrun/serving/utils.py +22 -14
  259. mlrun/serving/v1_serving.py +6 -7
  260. mlrun/serving/v2_serving.py +133 -54
  261. mlrun/track/tracker.py +2 -1
  262. mlrun/track/tracker_manager.py +3 -3
  263. mlrun/track/trackers/mlflow_tracker.py +6 -2
  264. mlrun/utils/async_http.py +6 -8
  265. mlrun/utils/azure_vault.py +1 -1
  266. mlrun/utils/clones.py +1 -2
  267. mlrun/utils/condition_evaluator.py +3 -3
  268. mlrun/utils/db.py +21 -3
  269. mlrun/utils/helpers.py +405 -225
  270. mlrun/utils/http.py +3 -6
  271. mlrun/utils/logger.py +112 -16
  272. mlrun/utils/notifications/notification/__init__.py +17 -13
  273. mlrun/utils/notifications/notification/base.py +50 -2
  274. mlrun/utils/notifications/notification/console.py +2 -0
  275. mlrun/utils/notifications/notification/git.py +24 -1
  276. mlrun/utils/notifications/notification/ipython.py +3 -1
  277. mlrun/utils/notifications/notification/slack.py +96 -21
  278. mlrun/utils/notifications/notification/webhook.py +59 -2
  279. mlrun/utils/notifications/notification_pusher.py +149 -30
  280. mlrun/utils/regex.py +9 -0
  281. mlrun/utils/retryer.py +208 -0
  282. mlrun/utils/singleton.py +1 -1
  283. mlrun/utils/v3io_clients.py +4 -6
  284. mlrun/utils/version/version.json +2 -2
  285. mlrun/utils/version/version.py +2 -6
  286. mlrun-1.7.0.dist-info/METADATA +378 -0
  287. mlrun-1.7.0.dist-info/RECORD +351 -0
  288. {mlrun-1.6.4rc8.dist-info → mlrun-1.7.0.dist-info}/WHEEL +1 -1
  289. mlrun/feature_store/retrieval/conversion.py +0 -273
  290. mlrun/kfpops.py +0 -868
  291. mlrun/model_monitoring/application.py +0 -310
  292. mlrun/model_monitoring/batch.py +0 -1095
  293. mlrun/model_monitoring/prometheus.py +0 -219
  294. mlrun/model_monitoring/stores/__init__.py +0 -111
  295. mlrun/model_monitoring/stores/kv_model_endpoint_store.py +0 -576
  296. mlrun/model_monitoring/stores/model_endpoint_store.py +0 -147
  297. mlrun/model_monitoring/stores/models/__init__.py +0 -27
  298. mlrun/model_monitoring/stores/models/base.py +0 -84
  299. mlrun/model_monitoring/stores/sql_model_endpoint_store.py +0 -384
  300. mlrun/platforms/other.py +0 -306
  301. mlrun-1.6.4rc8.dist-info/METADATA +0 -272
  302. mlrun-1.6.4rc8.dist-info/RECORD +0 -314
  303. {mlrun-1.6.4rc8.dist-info → mlrun-1.7.0.dist-info}/LICENSE +0 -0
  304. {mlrun-1.6.4rc8.dist-info → mlrun-1.7.0.dist-info}/entry_points.txt +0 -0
  305. {mlrun-1.6.4rc8.dist-info → mlrun-1.7.0.dist-info}/top_level.txt +0 -0
@@ -14,12 +14,11 @@
14
14
  #
15
15
  import datetime
16
16
  import os
17
- from typing import List, Union
17
+ from typing import Union
18
18
 
19
19
  import numpy as np
20
20
  import pandas as pd
21
- from IPython.core.display import HTML
22
- from IPython.display import display
21
+ from IPython.display import HTML, display
23
22
  from pandas.api.types import is_numeric_dtype, is_string_dtype
24
23
 
25
24
  import mlrun
@@ -216,7 +215,7 @@ def _show_and_export_html(html: str, show=None, filename=None, runs_list=None):
216
215
  fp.write("</body></html>")
217
216
  else:
218
217
  fp.write(html)
219
- if show or (show is None and mlrun.utils.is_ipython):
218
+ if show or (show is None and mlrun.utils.is_jupyter):
220
219
  display(HTML(html))
221
220
  if runs_list and len(runs_list) <= max_table_rows:
222
221
  display(HTML(html_table))
@@ -239,7 +238,7 @@ def _runs_list_to_df(runs_list, extend_iterations=False):
239
238
 
240
239
  @filter_warnings("ignore", FutureWarning)
241
240
  def compare_run_objects(
242
- runs_list: Union[mlrun.model.RunObject, List[mlrun.model.RunObject]],
241
+ runs_list: Union[mlrun.model.RunObject, list[mlrun.model.RunObject]],
243
242
  hide_identical: bool = True,
244
243
  exclude: list = None,
245
244
  show: bool = None,
@@ -295,7 +294,7 @@ def compare_db_runs(
295
294
  iter=False,
296
295
  start_time_from: datetime = None,
297
296
  hide_identical: bool = True,
298
- exclude: list = [],
297
+ exclude: list = None,
299
298
  show=None,
300
299
  colorscale: str = "Blues",
301
300
  filename=None,
@@ -332,6 +331,7 @@ def compare_db_runs(
332
331
  **query_args,
333
332
  )
334
333
 
334
+ exclude = exclude or []
335
335
  runs_df = _runs_list_to_df(runs_list)
336
336
  plot_as_html = gen_pcp_plot(
337
337
  runs_df,
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  #
15
15
  # flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx
16
- from typing import Any, Dict, List, Tuple, Union
16
+ from typing import Any, Union
17
17
 
18
18
  from torch.nn import Module
19
19
  from torch.optim import Optimizer
@@ -35,23 +35,23 @@ def train(
35
35
  loss_function: Module,
36
36
  optimizer: Optimizer,
37
37
  validation_set: DataLoader = None,
38
- metric_functions: List[PyTorchTypes.MetricFunctionType] = None,
38
+ metric_functions: list[PyTorchTypes.MetricFunctionType] = None,
39
39
  scheduler=None,
40
40
  scheduler_step_frequency: Union[int, float, str] = "epoch",
41
41
  epochs: int = 1,
42
42
  training_iterations: int = None,
43
43
  validation_iterations: int = None,
44
- callbacks_list: List[Callback] = None,
44
+ callbacks_list: list[Callback] = None,
45
45
  use_cuda: bool = True,
46
46
  use_horovod: bool = None,
47
47
  auto_log: bool = True,
48
48
  model_name: str = None,
49
- modules_map: Union[Dict[str, Union[None, str, List[str]]], str] = None,
50
- custom_objects_map: Union[Dict[str, Union[str, List[str]]], str] = None,
49
+ modules_map: Union[dict[str, Union[None, str, list[str]]], str] = None,
50
+ custom_objects_map: Union[dict[str, Union[str, list[str]]], str] = None,
51
51
  custom_objects_directory: str = None,
52
52
  tensorboard_directory: str = None,
53
- mlrun_callback_kwargs: Dict[str, Any] = None,
54
- tensorboard_callback_kwargs: Dict[str, Any] = None,
53
+ mlrun_callback_kwargs: dict[str, Any] = None,
54
+ tensorboard_callback_kwargs: dict[str, Any] = None,
55
55
  context: mlrun.MLClientCtx = None,
56
56
  ) -> PyTorchModelHandler:
57
57
  """
@@ -112,7 +112,7 @@ def train(
112
112
 
113
113
  {
114
114
  "/.../custom_optimizer.py": "optimizer",
115
- "/.../custom_layers.py": ["layer1", "layer2"]
115
+ "/.../custom_layers.py": ["layer1", "layer2"],
116
116
  }
117
117
 
118
118
  All the paths will be accessed from the given 'custom_objects_directory',
@@ -205,19 +205,19 @@ def evaluate(
205
205
  dataset: DataLoader,
206
206
  model: Module = None,
207
207
  loss_function: Module = None,
208
- metric_functions: List[PyTorchTypes.MetricFunctionType] = None,
208
+ metric_functions: list[PyTorchTypes.MetricFunctionType] = None,
209
209
  iterations: int = None,
210
- callbacks_list: List[Callback] = None,
210
+ callbacks_list: list[Callback] = None,
211
211
  use_cuda: bool = True,
212
212
  use_horovod: bool = False,
213
213
  auto_log: bool = True,
214
214
  model_name: str = None,
215
- modules_map: Union[Dict[str, Union[None, str, List[str]]], str] = None,
216
- custom_objects_map: Union[Dict[str, Union[str, List[str]]], str] = None,
215
+ modules_map: Union[dict[str, Union[None, str, list[str]]], str] = None,
216
+ custom_objects_map: Union[dict[str, Union[str, list[str]]], str] = None,
217
217
  custom_objects_directory: str = None,
218
- mlrun_callback_kwargs: Dict[str, Any] = None,
218
+ mlrun_callback_kwargs: dict[str, Any] = None,
219
219
  context: mlrun.MLClientCtx = None,
220
- ) -> Tuple[PyTorchModelHandler, List[PyTorchTypes.MetricValueType]]:
220
+ ) -> tuple[PyTorchModelHandler, list[PyTorchTypes.MetricValueType]]:
221
221
  """
222
222
  Use MLRun's PyTorch interface to evaluate the model with the given parameters. For more information and further
223
223
  options regarding the auto logging, see 'PyTorchMLRunInterface' documentation. Notice for auto-logging: In order to
@@ -264,7 +264,7 @@ def evaluate(
264
264
 
265
265
  {
266
266
  "/.../custom_optimizer.py": "optimizer",
267
- "/.../custom_layers.py": ["layer1", "layer2"]
267
+ "/.../custom_layers.py": ["layer1", "layer2"],
268
268
  }
269
269
 
270
270
  All the paths will be accessed from the given 'custom_objects_directory', meaning
@@ -343,9 +343,9 @@ def evaluate(
343
343
  def _parse_callbacks_kwargs(
344
344
  handler: PyTorchModelHandler,
345
345
  tensorboard_directory: Union[str, None],
346
- mlrun_callback_kwargs: Union[Dict[str, Any], None],
347
- tensorboard_callback_kwargs: Union[Dict[str, Any], None],
348
- ) -> Tuple[dict, dict]:
346
+ mlrun_callback_kwargs: Union[dict[str, Any], None],
347
+ tensorboard_callback_kwargs: Union[dict[str, Any], None],
348
+ ) -> tuple[dict, dict]:
349
349
  """
350
350
  Parse the given parameters into the MLRun and Tensorboard callbacks kwargs.
351
351
 
@@ -13,7 +13,6 @@
13
13
  # limitations under the License.
14
14
  #
15
15
  from abc import ABC, abstractmethod
16
- from typing import List
17
16
 
18
17
  from torch import Tensor
19
18
  from torch.nn import Module
@@ -68,7 +67,7 @@ class Callback(ABC):
68
67
  validation_set: DataLoader = None,
69
68
  loss_function: Module = None,
70
69
  optimizer: Optimizer = None,
71
- metric_functions: List[PyTorchTypes.MetricFunctionType] = None,
70
+ metric_functions: list[PyTorchTypes.MetricFunctionType] = None,
72
71
  scheduler=None,
73
72
  ):
74
73
  """
@@ -141,7 +140,7 @@ class Callback(ABC):
141
140
  pass
142
141
 
143
142
  def on_validation_end(
144
- self, loss_value: PyTorchTypes.MetricValueType, metric_values: List[float]
143
+ self, loss_value: PyTorchTypes.MetricValueType, metric_values: list[float]
145
144
  ) -> bool:
146
145
  """
147
146
  Before the validation (in a training case it will be per epoch) ends, this method will be called.
@@ -258,7 +257,7 @@ class Callback(ABC):
258
257
  """
259
258
  pass
260
259
 
261
- def on_train_metrics_end(self, metric_values: List[PyTorchTypes.MetricValueType]):
260
+ def on_train_metrics_end(self, metric_values: list[PyTorchTypes.MetricValueType]):
262
261
  """
263
262
  After the training calculation of the metrics, this method will be called.
264
263
 
@@ -273,7 +272,7 @@ class Callback(ABC):
273
272
  pass
274
273
 
275
274
  def on_validation_metrics_end(
276
- self, metric_values: List[PyTorchTypes.MetricValueType]
275
+ self, metric_values: list[PyTorchTypes.MetricValueType]
277
276
  ):
278
277
  """
279
278
  After the validating calculation of the metrics, this method will be called.
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  #
15
- from typing import Callable, Dict, List, Tuple, Union
15
+ from typing import Callable, Union
16
16
 
17
17
  import numpy as np
18
18
  from torch import Tensor
@@ -59,15 +59,15 @@ class LoggingCallback(Callback):
59
59
  def __init__(
60
60
  self,
61
61
  context: mlrun.MLClientCtx = None,
62
- dynamic_hyperparameters: Dict[
62
+ dynamic_hyperparameters: dict[
63
63
  str,
64
- Tuple[
64
+ tuple[
65
65
  str,
66
- Union[List[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]],
66
+ Union[list[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]],
67
67
  ],
68
68
  ] = None,
69
- static_hyperparameters: Dict[
70
- str, Union[PyTorchTypes.TrackableType, Tuple[str, List[Union[str, int]]]]
69
+ static_hyperparameters: dict[
70
+ str, Union[PyTorchTypes.TrackableType, tuple[str, list[Union[str, int]]]]
71
71
  ] = None,
72
72
  auto_log: bool = False,
73
73
  ):
@@ -100,7 +100,7 @@ class LoggingCallback(Callback):
100
100
  :param auto_log: Whether or not to enable auto logging, trying to track common static and dynamic
101
101
  hyperparameters.
102
102
  """
103
- super(LoggingCallback, self).__init__()
103
+ super().__init__()
104
104
 
105
105
  # Store the configurations:
106
106
  self._dynamic_hyperparameters_keys = (
@@ -117,7 +117,7 @@ class LoggingCallback(Callback):
117
117
  self._is_training = None # type: bool
118
118
  self._auto_log = auto_log
119
119
 
120
- def get_training_results(self) -> Dict[str, List[List[float]]]:
120
+ def get_training_results(self) -> dict[str, list[list[float]]]:
121
121
  """
122
122
  Get the training results logged. The results will be stored in a dictionary where each key is the metric name
123
123
  and the value is a list of lists of values. The first list is by epoch and the second list is by iteration
@@ -127,7 +127,7 @@ class LoggingCallback(Callback):
127
127
  """
128
128
  return self._logger.training_results
129
129
 
130
- def get_validation_results(self) -> Dict[str, List[List[float]]]:
130
+ def get_validation_results(self) -> dict[str, list[list[float]]]:
131
131
  """
132
132
  Get the validation results logged. The results will be stored in a dictionary where each key is the metric name
133
133
  and the value is a list of lists of values. The first list is by epoch and the second list is by iteration
@@ -137,7 +137,7 @@ class LoggingCallback(Callback):
137
137
  """
138
138
  return self._logger.validation_results
139
139
 
140
- def get_static_hyperparameters(self) -> Dict[str, PyTorchTypes.TrackableType]:
140
+ def get_static_hyperparameters(self) -> dict[str, PyTorchTypes.TrackableType]:
141
141
  """
142
142
  Get the static hyperparameters logged. The hyperparameters will be stored in a dictionary where each key is the
143
143
  hyperparameter name and the value is his logged value.
@@ -148,7 +148,7 @@ class LoggingCallback(Callback):
148
148
 
149
149
  def get_dynamic_hyperparameters(
150
150
  self,
151
- ) -> Dict[str, List[PyTorchTypes.TrackableType]]:
151
+ ) -> dict[str, list[PyTorchTypes.TrackableType]]:
152
152
  """
153
153
  Get the dynamic hyperparameters logged. The hyperparameters will be stored in a dictionary where each key is the
154
154
  hyperparameter name and the value is a list of his logged values per epoch.
@@ -157,7 +157,7 @@ class LoggingCallback(Callback):
157
157
  """
158
158
  return self._logger.dynamic_hyperparameters
159
159
 
160
- def get_summaries(self) -> Dict[str, List[float]]:
160
+ def get_summaries(self) -> dict[str, list[float]]:
161
161
  """
162
162
  Get the validation summaries of the metrics results. The summaries will be stored in a dictionary where each key
163
163
  is the metric names and the value is a list of all the summary values per epoch.
@@ -210,7 +210,7 @@ class LoggingCallback(Callback):
210
210
  self._add_auto_hyperparameters()
211
211
  # # Static hyperparameters:
212
212
  for name, value in self._static_hyperparameters_keys.items():
213
- if isinstance(value, Tuple):
213
+ if isinstance(value, tuple):
214
214
  # Its a parameter that needed to be extracted via key chain.
215
215
  self._logger.log_static_hyperparameter(
216
216
  parameter_name=name,
@@ -294,7 +294,7 @@ class LoggingCallback(Callback):
294
294
  self._logger.set_mode(mode=LoggingMode.EVALUATION)
295
295
 
296
296
  def on_validation_end(
297
- self, loss_value: PyTorchTypes.MetricValueType, metric_values: List[float]
297
+ self, loss_value: PyTorchTypes.MetricValueType, metric_values: list[float]
298
298
  ):
299
299
  """
300
300
  Before the validation (in a training case it will be per epoch) ends, this method will be called to log the
@@ -372,7 +372,7 @@ class LoggingCallback(Callback):
372
372
  result=float(loss_value),
373
373
  )
374
374
 
375
- def on_train_metrics_end(self, metric_values: List[PyTorchTypes.MetricValueType]):
375
+ def on_train_metrics_end(self, metric_values: list[PyTorchTypes.MetricValueType]):
376
376
  """
377
377
  After the training calculation of the metrics, this method will be called to log the metrics values.
378
378
 
@@ -389,7 +389,7 @@ class LoggingCallback(Callback):
389
389
  )
390
390
 
391
391
  def on_validation_metrics_end(
392
- self, metric_values: List[PyTorchTypes.MetricValueType]
392
+ self, metric_values: list[PyTorchTypes.MetricValueType]
393
393
  ):
394
394
  """
395
395
  After the validating calculation of the metrics, this method will be called to log the metrics values.
@@ -456,7 +456,7 @@ class LoggingCallback(Callback):
456
456
  self,
457
457
  source: str,
458
458
  key_chain: Union[
459
- List[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]
459
+ list[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]
460
460
  ],
461
461
  ) -> PyTorchTypes.TrackableType:
462
462
  """
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  #
15
- from typing import Callable, Dict, List, Tuple, Union
15
+ from typing import Callable, Union
16
16
 
17
17
  import torch
18
18
  from torch import Tensor
@@ -53,20 +53,20 @@ class MLRunLoggingCallback(LoggingCallback):
53
53
  context: mlrun.MLClientCtx,
54
54
  model_handler: PyTorchModelHandler,
55
55
  log_model_tag: str = "",
56
- log_model_labels: Dict[str, PyTorchTypes.TrackableType] = None,
57
- log_model_parameters: Dict[str, PyTorchTypes.TrackableType] = None,
58
- log_model_extra_data: Dict[
56
+ log_model_labels: dict[str, PyTorchTypes.TrackableType] = None,
57
+ log_model_parameters: dict[str, PyTorchTypes.TrackableType] = None,
58
+ log_model_extra_data: dict[
59
59
  str, Union[PyTorchTypes.TrackableType, Artifact]
60
60
  ] = None,
61
- dynamic_hyperparameters: Dict[
61
+ dynamic_hyperparameters: dict[
62
62
  str,
63
- Tuple[
63
+ tuple[
64
64
  str,
65
- Union[List[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]],
65
+ Union[list[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]],
66
66
  ],
67
67
  ] = None,
68
- static_hyperparameters: Dict[
69
- str, Union[PyTorchTypes.TrackableType, Tuple[str, List[Union[str, int]]]]
68
+ static_hyperparameters: dict[
69
+ str, Union[PyTorchTypes.TrackableType, tuple[str, list[Union[str, int]]]]
70
70
  ] = None,
71
71
  auto_log: bool = False,
72
72
  ):
@@ -107,7 +107,7 @@ class MLRunLoggingCallback(LoggingCallback):
107
107
  :param auto_log: Whether or not to enable auto logging for logging the context parameters and
108
108
  trying to track common static and dynamic hyperparameters.
109
109
  """
110
- super(MLRunLoggingCallback, self).__init__(
110
+ super().__init__(
111
111
  dynamic_hyperparameters=dynamic_hyperparameters,
112
112
  static_hyperparameters=static_hyperparameters,
113
113
  auto_log=auto_log,
@@ -160,7 +160,7 @@ class MLRunLoggingCallback(LoggingCallback):
160
160
 
161
161
  :param epoch: The epoch that has just ended.
162
162
  """
163
- super(MLRunLoggingCallback, self).on_epoch_end(epoch=epoch)
163
+ super().on_epoch_end(epoch=epoch)
164
164
 
165
165
  # Create child context to hold the current epoch's results:
166
166
  self._logger.log_epoch_to_context(epoch=epoch)
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  #
15
15
  from datetime import datetime
16
- from typing import Callable, Dict, List, Tuple, Union
16
+ from typing import Callable, Union
17
17
 
18
18
  import torch
19
19
  from torch import Tensor
@@ -63,7 +63,7 @@ class _PyTorchTensorboardLogger(TensorboardLogger):
63
63
 
64
64
  def __init__(
65
65
  self,
66
- statistics_functions: List[
66
+ statistics_functions: list[
67
67
  Callable[[Union[Parameter]], Union[float, Parameter]]
68
68
  ],
69
69
  context: mlrun.MLClientCtx = None,
@@ -94,7 +94,7 @@ class _PyTorchTensorboardLogger(TensorboardLogger):
94
94
  update. Notice that writing to tensorboard too frequently may cause the training
95
95
  to be slower. Default: 'epoch'.
96
96
  """
97
- super(_PyTorchTensorboardLogger, self).__init__(
97
+ super().__init__(
98
98
  statistics_functions=statistics_functions,
99
99
  context=context,
100
100
  tensorboard_directory=tensorboard_directory,
@@ -249,19 +249,19 @@ class TensorboardLoggingCallback(LoggingCallback):
249
249
  context: mlrun.MLClientCtx = None,
250
250
  tensorboard_directory: str = None,
251
251
  run_name: str = None,
252
- weights: Union[bool, List[str]] = False,
253
- statistics_functions: List[
252
+ weights: Union[bool, list[str]] = False,
253
+ statistics_functions: list[
254
254
  Callable[[Union[Parameter, Tensor]], Union[float, Tensor]]
255
255
  ] = None,
256
- dynamic_hyperparameters: Dict[
256
+ dynamic_hyperparameters: dict[
257
257
  str,
258
- Tuple[
258
+ tuple[
259
259
  str,
260
- Union[List[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]],
260
+ Union[list[Union[str, int]], Callable[[], PyTorchTypes.TrackableType]],
261
261
  ],
262
262
  ] = None,
263
- static_hyperparameters: Dict[
264
- str, Union[PyTorchTypes.TrackableType, Tuple[str, List[Union[str, int]]]]
263
+ static_hyperparameters: dict[
264
+ str, Union[PyTorchTypes.TrackableType, tuple[str, list[Union[str, int]]]]
265
265
  ] = None,
266
266
  update_frequency: Union[int, str] = "epoch",
267
267
  auto_log: bool = False,
@@ -322,7 +322,7 @@ class TensorboardLoggingCallback(LoggingCallback):
322
322
  :raise MLRunInvalidArgumentError: In case both 'context' and 'tensorboard_directory' parameters were not given
323
323
  or the 'update_frequency' was incorrect.
324
324
  """
325
- super(TensorboardLoggingCallback, self).__init__(
325
+ super().__init__(
326
326
  dynamic_hyperparameters=dynamic_hyperparameters,
327
327
  static_hyperparameters=static_hyperparameters,
328
328
  auto_log=auto_log,
@@ -345,7 +345,7 @@ class TensorboardLoggingCallback(LoggingCallback):
345
345
  # Save the configurations:
346
346
  self._tracked_weights = weights
347
347
 
348
- def get_weights(self) -> Dict[str, Parameter]:
348
+ def get_weights(self) -> dict[str, Parameter]:
349
349
  """
350
350
  Get the weights tensors tracked. The weights will be stored in a dictionary where each key is the weight's name
351
351
  and the value is the weight's parameter (tensor).
@@ -354,7 +354,7 @@ class TensorboardLoggingCallback(LoggingCallback):
354
354
  """
355
355
  return self._logger.weights
356
356
 
357
- def get_weights_statistics(self) -> Dict[str, Dict[str, List[float]]]:
357
+ def get_weights_statistics(self) -> dict[str, dict[str, list[float]]]:
358
358
  """
359
359
  Get the weights mean results logged. The results will be stored in a dictionary where each key is the weight's
360
360
  name and the value is a list of mean values per epoch.
@@ -365,7 +365,7 @@ class TensorboardLoggingCallback(LoggingCallback):
365
365
 
366
366
  @staticmethod
367
367
  def get_default_weight_statistics_list() -> (
368
- List[Callable[[Union[Parameter, Tensor]], Union[float, Tensor]]]
368
+ list[Callable[[Union[Parameter, Tensor]], Union[float, Tensor]]]
369
369
  ):
370
370
  """
371
371
  Get the default list of statistics functions being applied on the tracked weights each epoch.
@@ -381,7 +381,7 @@ class TensorboardLoggingCallback(LoggingCallback):
381
381
  validation_set: DataLoader = None,
382
382
  loss_function: Module = None,
383
383
  optimizer: Optimizer = None,
384
- metric_functions: List[PyTorchTypes.MetricFunctionType] = None,
384
+ metric_functions: list[PyTorchTypes.MetricFunctionType] = None,
385
385
  scheduler=None,
386
386
  ):
387
387
  """
@@ -396,7 +396,7 @@ class TensorboardLoggingCallback(LoggingCallback):
396
396
  :param metric_functions: The metric functions to be stored in this callback.
397
397
  :param scheduler: The scheduler to be stored in this callback.
398
398
  """
399
- super(TensorboardLoggingCallback, self).on_setup(
399
+ super().on_setup(
400
400
  model=model,
401
401
  training_set=training_set,
402
402
  validation_set=validation_set,
@@ -439,7 +439,7 @@ class TensorboardLoggingCallback(LoggingCallback):
439
439
  for logging. Epoch 0 (pre-run state) will be logged here.
440
440
  """
441
441
  # Setup all the results and hyperparameters dictionaries:
442
- super(TensorboardLoggingCallback, self).on_run_begin()
442
+ super().on_run_begin()
443
443
 
444
444
  # Log the initial summary of the run:
445
445
  self._logger.write_initial_summary_text()
@@ -470,10 +470,10 @@ class TensorboardLoggingCallback(LoggingCallback):
470
470
  # Write the final summary of the run:
471
471
  self._logger.write_final_summary_text()
472
472
 
473
- super(TensorboardLoggingCallback, self).on_run_end()
473
+ super().on_run_end()
474
474
 
475
475
  def on_validation_end(
476
- self, loss_value: PyTorchTypes.MetricValueType, metric_values: List[float]
476
+ self, loss_value: PyTorchTypes.MetricValueType, metric_values: list[float]
477
477
  ):
478
478
  """
479
479
  Before the validation (in a training case it will be per epoch) ends, this method will be called to log the
@@ -482,9 +482,7 @@ class TensorboardLoggingCallback(LoggingCallback):
482
482
  :param loss_value: The loss summary of this validation.
483
483
  :param metric_values: The metrics summaries of this validation.
484
484
  """
485
- super(TensorboardLoggingCallback, self).on_validation_end(
486
- loss_value=loss_value, metric_values=metric_values
487
- )
485
+ super().on_validation_end(loss_value=loss_value, metric_values=metric_values)
488
486
 
489
487
  # Check if this run was part of an evaluation:
490
488
  if not self._is_training:
@@ -503,7 +501,7 @@ class TensorboardLoggingCallback(LoggingCallback):
503
501
 
504
502
  :param epoch: The epoch that has just ended.
505
503
  """
506
- super(TensorboardLoggingCallback, self).on_epoch_end(epoch=epoch)
504
+ super().on_epoch_end(epoch=epoch)
507
505
 
508
506
  # Log the weights statistics:
509
507
  self._logger.log_weights_statistics()
@@ -540,9 +538,7 @@ class TensorboardLoggingCallback(LoggingCallback):
540
538
  :param y_true: The true value part of the current batch.
541
539
  :param y_pred: The prediction (output) of the model for this batch's input ('x').
542
540
  """
543
- super(TensorboardLoggingCallback, self).on_train_batch_end(
544
- batch=batch, x=x, y_true=y_true, y_pred=y_pred
545
- )
541
+ super().on_train_batch_end(batch=batch, x=x, y_true=y_true, y_pred=y_pred)
546
542
 
547
543
  # Write the batch loss and metrics results to their graphs:
548
544
  self._logger.write_training_results()
@@ -559,9 +555,7 @@ class TensorboardLoggingCallback(LoggingCallback):
559
555
  :param y_true: The true value part of the current batch.
560
556
  :param y_pred: The prediction (output) of the model for this batch's input ('x').
561
557
  """
562
- super(TensorboardLoggingCallback, self).on_validation_batch_end(
563
- batch=batch, x=x, y_true=y_true, y_pred=y_pred
564
- )
558
+ super().on_validation_batch_end(batch=batch, x=x, y_true=y_true, y_pred=y_pred)
565
559
 
566
560
  # Write the batch loss and metrics results to their graphs:
567
561
  self._logger.write_validation_results()