mlrun 1.10.0rc40__py3-none-any.whl → 1.11.0rc16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlrun might be problematic. Click here for more details.

Files changed (150)
  1. mlrun/__init__.py +3 -2
  2. mlrun/__main__.py +0 -4
  3. mlrun/artifacts/dataset.py +2 -2
  4. mlrun/artifacts/plots.py +1 -1
  5. mlrun/{model_monitoring/db/tsdb/tdengine → auth}/__init__.py +2 -3
  6. mlrun/auth/nuclio.py +89 -0
  7. mlrun/auth/providers.py +429 -0
  8. mlrun/auth/utils.py +415 -0
  9. mlrun/common/constants.py +7 -0
  10. mlrun/common/model_monitoring/helpers.py +41 -4
  11. mlrun/common/runtimes/constants.py +28 -0
  12. mlrun/common/schemas/__init__.py +13 -3
  13. mlrun/common/schemas/alert.py +2 -2
  14. mlrun/common/schemas/api_gateway.py +3 -0
  15. mlrun/common/schemas/auth.py +10 -10
  16. mlrun/common/schemas/client_spec.py +4 -0
  17. mlrun/common/schemas/constants.py +25 -0
  18. mlrun/common/schemas/frontend_spec.py +1 -8
  19. mlrun/common/schemas/function.py +24 -0
  20. mlrun/common/schemas/hub.py +3 -2
  21. mlrun/common/schemas/model_monitoring/__init__.py +1 -1
  22. mlrun/common/schemas/model_monitoring/constants.py +2 -2
  23. mlrun/common/schemas/secret.py +17 -2
  24. mlrun/common/secrets.py +95 -1
  25. mlrun/common/types.py +10 -10
  26. mlrun/config.py +53 -15
  27. mlrun/data_types/infer.py +2 -2
  28. mlrun/datastore/__init__.py +2 -3
  29. mlrun/datastore/base.py +274 -10
  30. mlrun/datastore/datastore.py +1 -1
  31. mlrun/datastore/datastore_profile.py +49 -17
  32. mlrun/datastore/model_provider/huggingface_provider.py +6 -2
  33. mlrun/datastore/model_provider/model_provider.py +2 -2
  34. mlrun/datastore/model_provider/openai_provider.py +2 -2
  35. mlrun/datastore/s3.py +15 -16
  36. mlrun/datastore/sources.py +1 -1
  37. mlrun/datastore/store_resources.py +4 -4
  38. mlrun/datastore/storeytargets.py +16 -10
  39. mlrun/datastore/targets.py +1 -1
  40. mlrun/datastore/utils.py +16 -3
  41. mlrun/datastore/v3io.py +1 -1
  42. mlrun/db/base.py +36 -12
  43. mlrun/db/httpdb.py +316 -101
  44. mlrun/db/nopdb.py +29 -11
  45. mlrun/errors.py +4 -2
  46. mlrun/execution.py +11 -12
  47. mlrun/feature_store/api.py +1 -1
  48. mlrun/feature_store/common.py +1 -1
  49. mlrun/feature_store/feature_vector_utils.py +1 -1
  50. mlrun/feature_store/steps.py +8 -6
  51. mlrun/frameworks/_common/utils.py +3 -3
  52. mlrun/frameworks/_dl_common/loggers/logger.py +1 -1
  53. mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py +2 -1
  54. mlrun/frameworks/_ml_common/loggers/mlrun_logger.py +1 -1
  55. mlrun/frameworks/_ml_common/utils.py +2 -1
  56. mlrun/frameworks/auto_mlrun/auto_mlrun.py +4 -3
  57. mlrun/frameworks/lgbm/mlrun_interfaces/mlrun_interface.py +2 -1
  58. mlrun/frameworks/onnx/dataset.py +2 -1
  59. mlrun/frameworks/onnx/mlrun_interface.py +2 -1
  60. mlrun/frameworks/pytorch/callbacks/logging_callback.py +5 -4
  61. mlrun/frameworks/pytorch/callbacks/mlrun_logging_callback.py +2 -1
  62. mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py +2 -1
  63. mlrun/frameworks/pytorch/utils.py +2 -1
  64. mlrun/frameworks/sklearn/metric.py +2 -1
  65. mlrun/frameworks/tf_keras/callbacks/logging_callback.py +5 -4
  66. mlrun/frameworks/tf_keras/callbacks/mlrun_logging_callback.py +2 -1
  67. mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py +2 -1
  68. mlrun/hub/__init__.py +37 -0
  69. mlrun/hub/base.py +142 -0
  70. mlrun/hub/module.py +67 -76
  71. mlrun/hub/step.py +113 -0
  72. mlrun/launcher/base.py +2 -1
  73. mlrun/launcher/local.py +2 -1
  74. mlrun/model.py +12 -2
  75. mlrun/model_monitoring/__init__.py +0 -1
  76. mlrun/model_monitoring/api.py +2 -2
  77. mlrun/model_monitoring/applications/base.py +20 -6
  78. mlrun/model_monitoring/applications/context.py +1 -0
  79. mlrun/model_monitoring/controller.py +7 -17
  80. mlrun/model_monitoring/db/_schedules.py +2 -16
  81. mlrun/model_monitoring/db/_stats.py +2 -13
  82. mlrun/model_monitoring/db/tsdb/__init__.py +9 -7
  83. mlrun/model_monitoring/db/tsdb/base.py +2 -4
  84. mlrun/model_monitoring/db/tsdb/preaggregate.py +234 -0
  85. mlrun/model_monitoring/db/tsdb/stream_graph_steps.py +63 -0
  86. mlrun/model_monitoring/db/tsdb/timescaledb/queries/timescaledb_metrics_queries.py +414 -0
  87. mlrun/model_monitoring/db/tsdb/timescaledb/queries/timescaledb_predictions_queries.py +376 -0
  88. mlrun/model_monitoring/db/tsdb/timescaledb/queries/timescaledb_results_queries.py +590 -0
  89. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_connection.py +434 -0
  90. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_connector.py +541 -0
  91. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_operations.py +808 -0
  92. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_schema.py +502 -0
  93. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_stream.py +163 -0
  94. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_stream_graph_steps.py +60 -0
  95. mlrun/model_monitoring/db/tsdb/timescaledb/utils/timescaledb_dataframe_processor.py +141 -0
  96. mlrun/model_monitoring/db/tsdb/timescaledb/utils/timescaledb_query_builder.py +585 -0
  97. mlrun/model_monitoring/db/tsdb/timescaledb/writer_graph_steps.py +73 -0
  98. mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +4 -6
  99. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +147 -79
  100. mlrun/model_monitoring/features_drift_table.py +2 -1
  101. mlrun/model_monitoring/helpers.py +2 -1
  102. mlrun/model_monitoring/stream_processing.py +18 -16
  103. mlrun/model_monitoring/writer.py +4 -3
  104. mlrun/package/__init__.py +2 -1
  105. mlrun/platforms/__init__.py +0 -44
  106. mlrun/platforms/iguazio.py +1 -1
  107. mlrun/projects/operations.py +11 -10
  108. mlrun/projects/project.py +81 -82
  109. mlrun/run.py +4 -7
  110. mlrun/runtimes/__init__.py +2 -204
  111. mlrun/runtimes/base.py +89 -21
  112. mlrun/runtimes/constants.py +225 -0
  113. mlrun/runtimes/daskjob.py +4 -2
  114. mlrun/runtimes/databricks_job/databricks_runtime.py +2 -1
  115. mlrun/runtimes/mounts.py +5 -0
  116. mlrun/runtimes/nuclio/__init__.py +12 -8
  117. mlrun/runtimes/nuclio/api_gateway.py +36 -6
  118. mlrun/runtimes/nuclio/application/application.py +200 -32
  119. mlrun/runtimes/nuclio/function.py +154 -49
  120. mlrun/runtimes/nuclio/serving.py +55 -42
  121. mlrun/runtimes/pod.py +59 -10
  122. mlrun/secrets.py +46 -2
  123. mlrun/serving/__init__.py +2 -0
  124. mlrun/serving/remote.py +5 -5
  125. mlrun/serving/routers.py +3 -3
  126. mlrun/serving/server.py +46 -43
  127. mlrun/serving/serving_wrapper.py +6 -2
  128. mlrun/serving/states.py +554 -207
  129. mlrun/serving/steps.py +1 -1
  130. mlrun/serving/system_steps.py +42 -33
  131. mlrun/track/trackers/mlflow_tracker.py +29 -31
  132. mlrun/utils/helpers.py +89 -16
  133. mlrun/utils/http.py +9 -2
  134. mlrun/utils/notifications/notification/git.py +1 -1
  135. mlrun/utils/notifications/notification/mail.py +39 -16
  136. mlrun/utils/notifications/notification_pusher.py +2 -2
  137. mlrun/utils/version/version.json +2 -2
  138. mlrun/utils/version/version.py +3 -4
  139. {mlrun-1.10.0rc40.dist-info → mlrun-1.11.0rc16.dist-info}/METADATA +39 -49
  140. {mlrun-1.10.0rc40.dist-info → mlrun-1.11.0rc16.dist-info}/RECORD +144 -130
  141. mlrun/db/auth_utils.py +0 -152
  142. mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +0 -343
  143. mlrun/model_monitoring/db/tsdb/tdengine/stream_graph_steps.py +0 -75
  144. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connection.py +0 -281
  145. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +0 -1368
  146. mlrun/model_monitoring/db/tsdb/tdengine/writer_graph_steps.py +0 -51
  147. {mlrun-1.10.0rc40.dist-info → mlrun-1.11.0rc16.dist-info}/WHEEL +0 -0
  148. {mlrun-1.10.0rc40.dist-info → mlrun-1.11.0rc16.dist-info}/entry_points.txt +0 -0
  149. {mlrun-1.10.0rc40.dist-info → mlrun-1.11.0rc16.dist-info}/licenses/LICENSE +0 -0
  150. {mlrun-1.10.0rc40.dist-info → mlrun-1.11.0rc16.dist-info}/top_level.txt +0 -0
mlrun/__init__.py CHANGED
@@ -24,6 +24,7 @@ __all__ = [
24
24
  "v3io_cred",
25
25
  "auto_mount",
26
26
  "VolumeMount",
27
+ "sync_secret_tokens",
27
28
  ]
28
29
 
29
30
  from os import environ, path
@@ -37,7 +38,7 @@ from .datastore import DataItem, ModelProvider, store_manager
37
38
  from .db import get_run_db
38
39
  from .errors import MLRunInvalidArgumentError, MLRunNotFoundError
39
40
  from .execution import MLClientCtx
40
- from .hub import get_hub_module, import_module
41
+ from .hub import get_hub_item, get_hub_module, get_hub_step, import_module
41
42
  from .model import RunObject, RunTemplate, new_task
42
43
  from .package import ArtifactType, DefaultPackager, Packager, handler
43
44
  from .projects import (
@@ -68,7 +69,7 @@ from .run import (
68
69
  wait_for_pipeline_completion,
69
70
  )
70
71
  from .runtimes import mounts, new_model_server
71
- from .secrets import get_secret_or_env
72
+ from .secrets import get_secret_or_env, sync_secret_tokens
72
73
  from .utils.version import Version
73
74
 
74
75
  __version__ = Version().get()["version"]
mlrun/__main__.py CHANGED
@@ -203,7 +203,6 @@ def main():
203
203
  @click.option(
204
204
  "--allow-cross-project",
205
205
  is_flag=True,
206
- default=True, # TODO: remove this default in 1.11
207
206
  help="Override the loaded project name. This flag ensures awareness of loading an existing project yaml "
208
207
  "as a baseline for a new project with a different name",
209
208
  )
@@ -513,7 +512,6 @@ def run(
513
512
  @click.option(
514
513
  "--allow-cross-project",
515
514
  is_flag=True,
516
- default=True, # TODO: remove this default in 1.11
517
515
  help="Override the loaded project name. This flag ensures awareness of loading an existing project yaml "
518
516
  "as a baseline for a new project with a different name",
519
517
  )
@@ -672,7 +670,6 @@ def build(
672
670
  @click.option(
673
671
  "--allow-cross-project",
674
672
  is_flag=True,
675
- default=True, # TODO: remove this default in 1.11
676
673
  help="Override the loaded project name. This flag ensures awareness of loading an existing project yaml "
677
674
  "as a baseline for a new project with a different name",
678
675
  )
@@ -1008,7 +1005,6 @@ def logs(uid, project, offset, db):
1008
1005
  @click.option(
1009
1006
  "--allow-cross-project",
1010
1007
  is_flag=True,
1011
- default=True, # TODO: remove this default in 1.11
1012
1008
  help="Override the loaded project name. This flag ensures awareness of loading an existing project yaml "
1013
1009
  "as a baseline for a new project with a different name",
1014
1010
  )
@@ -366,9 +366,9 @@ def get_df_stats(df):
366
366
  for col, values in df.describe(include="all").items():
367
367
  stats_dict = {}
368
368
  for stat, val in values.dropna().items():
369
- if isinstance(val, (float, np.floating, np.float64)):
369
+ if isinstance(val, float | np.floating | np.float64):
370
370
  stats_dict[stat] = float(val)
371
- elif isinstance(val, (int, np.integer, np.int64)):
371
+ elif isinstance(val, int | np.integer | np.int64):
372
372
  stats_dict[stat] = int(val)
373
373
  else:
374
374
  stats_dict[stat] = str(val)
mlrun/artifacts/plots.py CHANGED
@@ -42,7 +42,7 @@ class PlotArtifact(Artifact):
42
42
  import matplotlib
43
43
 
44
44
  if not self.spec.get_body() or not isinstance(
45
- self.spec.get_body(), (bytes, matplotlib.figure.Figure)
45
+ self.spec.get_body(), bytes | matplotlib.figure.Figure
46
46
  ):
47
47
  raise ValueError(
48
48
  "matplotlib fig or png bytes must be provided as artifact body"
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Iguazio
1
+ # Copyright 2025 Iguazio
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -11,5 +11,4 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
-
15
- from .tdengine_connector import TDEngineConnector
14
+ from .providers import IGTokenProvider, OAuthClientIDTokenProvider, StaticTokenProvider
mlrun/auth/nuclio.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright 2023 Iguazio
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import base64
16
+
17
+ import requests.auth
18
+ from nuclio.auth import AuthInfo as _NuclioAuthInfo
19
+ from nuclio.auth import AuthKinds as NuclioAuthKinds
20
+
21
+ import mlrun.auth.providers
22
+ import mlrun.common.schemas.auth
23
+
24
+
25
class NuclioAuthInfo(_NuclioAuthInfo):
    """
    Nuclio ``AuthInfo`` extended with bearer-token support.

    When a token is set, :meth:`to_requests_auth` authenticates with ``Authorization: Bearer <token>``.
    Otherwise it falls back to the behavior of the base nuclio ``AuthInfo``.

    :param token:  Optional bearer token.
    :param kwargs: Forwarded to the base nuclio ``AuthInfo`` (e.g. ``username``, ``password``, ``mode``).
    """

    def __init__(self, token=None, **kwargs):
        super().__init__(**kwargs)
        self._token = token

    @classmethod
    def from_auth_info(cls, auth_info: "mlrun.common.schemas.auth.AuthInfo"):
        """
        Build nuclio auth info from an MLRun ``AuthInfo``.

        In iguazio v4 mode the credentials are taken from ``auth_info.request_headers``.
        Otherwise a non-empty session is used as the iguazio password.

        :param auth_info: The MLRun auth info (may be ``None``).
        :return: A ``NuclioAuthInfo``, or ``None`` when no usable credentials are present.
        """
        if not auth_info:
            return None
        if mlrun.mlconf.is_iguazio_v4_mode():
            return cls.from_request_headers(auth_info.request_headers)
        # Truthiness (not just `!= ""`) so a missing/None session does not produce a password of None
        if auth_info.session:
            return cls(password=auth_info.session, mode=NuclioAuthKinds.iguazio)
        return None

    @classmethod
    def from_request_headers(cls, headers: dict[str, str]):
        """
        Build nuclio auth info from an ``Authorization`` request header.

        Supports ``Bearer <token>`` and ``Basic <base64(username:password)>``. The header name and
        scheme are matched case-insensitively.

        :param headers: Request headers (may be empty or ``None``).
        :return: A ``NuclioAuthInfo``. It has no credentials when no supported ``Authorization``
                 header is found.
        :raises MLRunInvalidArgumentError: If a ``Basic`` header cannot be decoded into ``username:password``.
        """
        if not headers:
            return cls()
        for key, value in headers.items():
            if key.lower() != "authorization":
                continue
            if value.lower().startswith("bearer "):
                return cls(
                    token=value[len("bearer ") :].strip(),
                    mode=NuclioAuthKinds.iguazio,
                )
            if value.lower().startswith("basic "):
                encoded = value[len("basic ") :]
                try:
                    decoded_token = base64.b64decode(encoded).decode("utf-8")
                    username, password = decoded_token.split(":", 1)
                except ValueError as exc:
                    # binascii.Error, UnicodeDecodeError and the unpack error are all ValueError subclasses
                    raise mlrun.errors.MLRunInvalidArgumentError(
                        "Malformed Basic authorization header, expected base64 encoded 'username:password'"
                    ) from exc
                return cls(
                    username=username,
                    password=password,
                    mode=NuclioAuthKinds.iguazio,
                )
        return cls()

    @classmethod
    def from_envvar(cls):
        """
        Build nuclio auth info from the environment.

        In iguazio v4 mode, an access token is fetched through ``IGTokenProvider`` using
        ``mlconf.auth_token_endpoint``. Otherwise the call is delegated to the base nuclio implementation.
        """
        if mlrun.mlconf.is_iguazio_v4_mode():
            token_provider = mlrun.auth.providers.IGTokenProvider(
                token_endpoint=mlrun.mlconf.auth_token_endpoint,
            )
            return cls(
                token=token_provider.get_token(),
                mode=NuclioAuthKinds.iguazio,
            )
        return super().from_envvar()

    def to_requests_auth(self) -> "requests.auth.AuthBase":
        """Return a ``requests`` auth object: bearer auth when a token is set, else the base behavior."""
        if self._token:
            # in iguazio v4 mode we use bearer token auth
            return _RequestAuthBearerToken(self._token)
        return super().to_requests_auth()
81
+
82
+
83
class _RequestAuthBearerToken(requests.auth.AuthBase):
    """``requests`` auth hook that sets ``Authorization: Bearer <token>`` on every outgoing request."""

    def __init__(self, token: str):
        # the raw token, without the "Bearer " scheme prefix
        self._token = token

    def __call__(self, r):
        authorization_value = f"Bearer {self._token}"
        r.headers["Authorization"] = authorization_value
        return r
@@ -0,0 +1,429 @@
1
+ # Copyright 2025 Iguazio
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import typing
16
+ from abc import ABC, abstractmethod
17
+ from datetime import datetime, timedelta
18
+
19
+ import jwt
20
+ import requests
21
+
22
+ import mlrun.auth.utils
23
+ import mlrun.errors
24
+ import mlrun.secrets
25
+ import mlrun.utils.helpers
26
+ from mlrun.config import config as mlconf
27
+ from mlrun.utils import logger
28
+
29
+
30
class TokenProvider(ABC):
    """Abstract interface for objects that supply an access token."""

    @abstractmethod
    def get_token(self):
        """Return the access token that should be used for authentication."""

    @abstractmethod
    def is_iguazio_session(self):
        """Return whether the provided token is an iguazio session."""
38
+
39
+
40
class StaticTokenProvider(TokenProvider):
    """Token provider that always hands out the fixed token it was constructed with."""

    def __init__(self, token: str):
        self.token = token

    def get_token(self):
        """Return the stored token as-is (no refresh logic)."""
        return self.token

    def is_iguazio_session(self):
        """Delegate to ``mlrun.platforms.iguazio.is_iguazio_session`` for the stored token."""
        iguazio_platform = mlrun.platforms.iguazio
        return iguazio_platform.is_iguazio_session(self.token)
49
+
50
+
51
class DynamicTokenProvider(TokenProvider):
    """
    A token provider that dynamically fetches and refreshes tokens from a token endpoint.

    The first token is fetched eagerly during construction. :meth:`get_token` refreshes the token
    whenever :meth:`_is_token_within_refresh_threshold` reports it is no longer fresh. If a refresh
    fails while a previous token is still held, that previous token keeps being served. Requests
    go through an ``HTTPSessionWithRetry`` session.

    Subclasses implement :meth:`_build_token_request`, :meth:`_parse_response`,
    :meth:`_is_token_within_refresh_threshold`, :meth:`_cleanup` and :meth:`_post_fetch_hook`.
    Any attribute those hooks depend on must be set *before* calling this constructor, because the
    constructor already performs the first fetch.

    :param token_endpoint: The URL of the token endpoint.
    :param timeout:        The timeout for token requests, in seconds.
    :param max_retries:    Retry count passed to ``run_with_retry`` for each fetch.
    :raises MLRunValueError:   If no token endpoint is provided.
    :raises MLRunRuntimeError: If the initial token cannot be fetched.
    """

    def __init__(self, token_endpoint: str, timeout=5, max_retries=0):
        if not token_endpoint:
            raise mlrun.errors.MLRunValueError(
                "No token endpoint provided, cannot initialize token provider"
            )
        self._token = None
        self._token_endpoint = token_endpoint
        self._timeout = timeout
        self._max_retries = max_retries

        # Since we're only issuing POST requests, which are actually a disguised GET, then it's ok to allow retries
        # on them.
        self._session = mlrun.utils.HTTPSessionWithRetry(
            retry_on_post=True,
            verbose=True,
        )
        self._cleanup()
        self._refresh_token_if_needed()

    def get_token(self):
        """
        Retrieve the current access token, refreshing it if necessary.

        :return: The current access token.
        """
        self._refresh_token_if_needed()
        return self._token

    def is_iguazio_session(self):
        """Dynamically fetched tokens are never iguazio sessions."""
        return False

    def fetch_token(self):
        """Fetch a new token from the endpoint, retrying according to ``max_retries``."""
        mlrun.utils.helpers.run_with_retry(
            retry_count=self._max_retries,
            func=self._fetch_token,
        )

    def _fetch_token(self):
        """
        Fetch a new access token from the token endpoint and store it via :meth:`_parse_response`.

        :raises MLRunRuntimeError: If the token request cannot be built or the HTTP request itself fails.
        :raises: Whatever ``mlrun.errors.raise_for_status`` raises when the endpoint returns a non-OK status.
        """
        request_body, headers, body_type = self._build_token_request(
            raise_on_error=True
        )
        if request_body is None:
            # Subclasses return an empty request when they cannot build one (e.g. no offline token);
            # fail clearly instead of posting an empty body to the endpoint.
            raise mlrun.errors.MLRunRuntimeError(
                "Failed to build the token request, cannot fetch an access token"
            )

        try:
            request_kwargs = {
                "method": "POST",
                "url": self._token_endpoint,
                "timeout": self._timeout,
                "headers": headers,
                "verify": mlconf.httpdb.http.verify,
            }
            if body_type == "json":
                request_kwargs["json"] = request_body
            else:
                request_kwargs["data"] = request_body

            response = self._session.request(**request_kwargs)
        except requests.RequestException as exc:
            error = f"Retrieving token failed: {mlrun.errors.err_to_str(exc)}"
            raise mlrun.errors.MLRunRuntimeError(error) from exc

        if not response.ok:
            error = "No error available"
            if response.content:
                try:
                    data = response.json()
                    error = data.get("error")
                except Exception:
                    # best effort only - the status code is reported regardless
                    pass
            logger.warning(
                "Retrieving token failed", status=response.status_code, error=error
            )
            mlrun.errors.raise_for_status(response)

        self._parse_response(response.json())

    def _refresh_token_if_needed(self):
        """
        Refresh the access token if it is expired or about to expire.

        If the fetch fails but a token is still held, the existing token is kept and a warning is logged.

        :return: The (possibly refreshed) access token.
        :raises MLRunRuntimeError: If fetching fails and there is no existing token to fall back to.
        """
        raise_on_error = True

        # Check if there is an existing access token and if it is within the refresh threshold
        if self._token and self._is_token_within_refresh_threshold(
            cleanup_if_expired=True
        ):
            return self._token

        try:
            self.fetch_token()
        except Exception as exc:
            raise_on_error = False
            # Token fetch failed and there is no existing token - cannot proceed
            if not self._token:
                raise mlrun.errors.MLRunRuntimeError(
                    "Failed to fetch a valid access token. Authentication procedure stopped."
                ) from exc
            # Do not swallow the failure silently - keep serving the existing token but report why
            logger.warning(
                "Failed to refresh access token, keeping the existing one",
                error=mlrun.errors.err_to_str(exc),
            )

        finally:
            self._post_fetch_hook(raise_on_error)

        return self._token

    @abstractmethod
    def _post_fetch_hook(self, raise_on_error=True):
        """
        A hook that is called after every fetch attempt (successful or not).

        :param raise_on_error: ``False`` when the fetch attempt raised, ``True`` otherwise.
        """
        pass

    @abstractmethod
    def _is_token_within_refresh_threshold(self, cleanup_if_expired=True) -> bool:
        """
        Check whether the current access token is still fresh (does not need a refresh yet).

        :param cleanup_if_expired: Whether to clean up the token if it is expired.
        :return: True if the token can be used without refreshing, False otherwise.
        """
        pass

    @abstractmethod
    def _cleanup(self):
        """
        Clean up the token and related metadata.
        """
        pass

    @abstractmethod
    def _build_token_request(self, raise_on_error=False):
        """
        Build the request for the token endpoint.

        :param raise_on_error: Whether to raise an error if the request cannot be built.
        :return: A 3-tuple ``(request_body, headers, body_type)``. When ``body_type == "json"`` the body
                 is sent as JSON, otherwise as form data. ``request_body`` is ``None`` when the request
                 cannot be built.
        """
        pass

    @abstractmethod
    def _parse_response(self, data: dict):
        """
        Parse the response from the token endpoint and store the token.

        :param data: The JSON response data from the token endpoint.
        """
        pass
215
+
216
+
217
class OAuthClientIDTokenProvider(DynamicTokenProvider):
    """
    Token provider using the OAuth 2.0 client-credentials grant.

    Tokens are requested with ``grant_type=client_credentials`` sent as form data. A token is
    refreshed once half of its advertised ``expires_in`` lifetime has elapsed.

    :param token_endpoint: The URL of the token endpoint.
    :param client_id:      The OAuth client id.
    :param client_secret:  The OAuth client secret.
    :param timeout:        The timeout for token requests, in seconds.
    :raises MLRunValueError: If any of ``token_endpoint``, ``client_id`` or ``client_secret`` is empty.
    """

    def __init__(
        self, token_endpoint: str, client_id: str, client_secret: str, timeout=5
    ):
        if not token_endpoint or not client_id or not client_secret:
            raise mlrun.errors.MLRunValueError(
                "Invalid client_id configuration for authentication. Must provide token endpoint, client-id and secret"
            )
        # should be set before calling the parent constructor, which already fetches the first token
        self._client_id = client_id
        self._client_secret = client_secret
        super().__init__(token_endpoint=token_endpoint, timeout=timeout)

    def _cleanup(self):
        """Forget the cached token and its expiry/refresh timestamps."""
        self._token = self.token_expiry_time = self.token_refresh_time = None

    def _is_token_within_refresh_threshold(self, cleanup_if_expired=True) -> bool:
        """
        Check whether the current access token is still fresh (before its refresh time).

        :param cleanup_if_expired: Whether to clean up the token if it is expired.
        :return: True if the token can be used without refreshing, False otherwise.
        """
        if (
            not self._token
            or not self.token_expiry_time
            or not self.token_refresh_time
        ):
            return False

        now = datetime.now()

        if now <= self.token_refresh_time:
            return True

        if now < self.token_expiry_time:
            # past refresh time but not expired yet → needs a refresh (still usable as a fallback)
            return False

        # expired
        if cleanup_if_expired:
            # We only cleanup if token was really expired - even if we fail in refreshing the token, we can still
            # use the existing one given that it's not expired.
            self._cleanup()
        return False

    def _build_token_request(self, raise_on_error=False):
        """Build the client-credentials form request: ``(body, headers, "data")``."""
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        request_body = {
            "grant_type": "client_credentials",
            "client_id": self._client_id,
            "client_secret": self._client_secret,
        }
        return request_body, headers, "data"

    def _parse_response(self, data: dict):
        """
        Store the access token and schedule its refresh from the token endpoint response.

        :param data: The JSON response data from the token endpoint.
        :raises MLRunRuntimeError: If the response has no ``access_token``.
        """
        # Response is described in https://datatracker.ietf.org/doc/html/rfc6749#section-4.4.3
        # According to spec, there isn't a refresh token - just the access token and its expiry time (in seconds).
        access_token = data.get("access_token")
        expires_in = data.get("expires_in")
        if not access_token:
            logger.warning(
                "Failed to parse token response", token="missing", expires_in=expires_in
            )
            # Raise instead of storing None: this keeps a still-valid previous token and lets the
            # caller's failure handling (fall back to the existing token, or raise) take over.
            raise mlrun.errors.MLRunRuntimeError(
                "Access token is missing in the response from the token endpoint"
            )

        self._token = access_token
        try:
            lifetime = float(expires_in)
        except (TypeError, ValueError):
            lifetime = 0.0
        if lifetime <= 0:
            # `expires_in` is only RECOMMENDED by the RFC. Without it no refresh can be scheduled, so drop
            # timestamps left over from a previous token; the token is re-fetched on the next use.
            self.token_expiry_time = self.token_refresh_time = None
            logger.warning(
                "Failed to parse token response", token="****", expires_in=expires_in
            )
            return

        now = datetime.now()
        self.token_expiry_time = now + timedelta(seconds=lifetime)
        self.token_refresh_time = now + timedelta(seconds=lifetime / 2)
        logger.info(
            "Successfully retrieved client-id token",
            expires_in=expires_in,
            expiry=str(self.token_expiry_time),
            refresh=str(self.token_refresh_time),
        )

    def _post_fetch_hook(self, raise_on_error=True):
        """No additional post-fetch handling is needed for client-credentials tokens."""
        pass
296
+
297
+
298
class IGTokenProvider(DynamicTokenProvider):
    """
    A token provider for Iguazio that uses a refresh token to fetch access tokens.

    The offline (refresh) token is loaded with ``mlrun.auth.utils.load_offline_token`` and exchanged
    at the token endpoint for an access token. The access token is refreshed once the elapsed
    fraction of its lifetime reaches ``mlconf.auth_with_oauth_token.refresh_threshold``. Fetches are
    performed with ``max_retries=2``.

    :param token_endpoint: The URL of the token endpoint.
    :param timeout: The timeout for token requests, in seconds.
    """

    def __init__(self, token_endpoint: str, timeout=5):
        super().__init__(token_endpoint=token_endpoint, timeout=timeout, max_retries=2)

    @property
    def authenticated_user_id(self) -> typing.Optional[str]:
        """The JWT subject of the current access token, as resolved by ``mlrun.auth.utils.resolve_jwt_subject``."""
        return mlrun.auth.utils.resolve_jwt_subject(self._token, raise_on_error=True)

    def _cleanup(self):
        """Forget the cached token together with its lifetime and expiry time."""
        self._token = None
        self._token_total_lifetime = 0
        self._token_expiry_time = None

    def _is_token_within_refresh_threshold(self, cleanup_if_expired=True) -> bool:
        """
        Check if the current access token is valid and has sufficient lifetime remaining.

        :param cleanup_if_expired: Whether to clean up the token if it is expired.
        :return: True if the token is valid, False otherwise.
        """
        if (
            not self._token
            or self._token_total_lifetime <= 0
            or not self._token_expiry_time
        ):
            return False

        now = datetime.now()
        remaining_lifetime = (self._token_expiry_time - now).total_seconds()
        if remaining_lifetime <= 0:
            # An expired token is never valid, regardless of whether we are asked to clean it up
            if cleanup_if_expired:
                self._cleanup()
            return False

        elapsed_lifetime = self._token_total_lifetime - remaining_lifetime
        return (
            elapsed_lifetime
            < self._token_total_lifetime
            * mlconf.auth_with_oauth_token.refresh_threshold
        )

    def _build_token_request(self, raise_on_error=False):
        """
        Build the JSON request exchanging the offline token for an access token.

        :param raise_on_error: Whether to raise an error if the request cannot be built.
        :return: A 3-tuple ``(request_body, headers, "json")``, or ``(None, None, None)`` when no
                 offline token is available.
        """
        offline_token = mlrun.auth.utils.load_offline_token(
            raise_on_error=raise_on_error
        )
        if not offline_token:
            # Error already handled in `load_offline_token`. Keep the 3-tuple shape the caller unpacks.
            return None, None, None

        headers = {"Content-Type": "application/json"}
        request_body = {"refreshToken": offline_token}
        return request_body, headers, "json"

    def _parse_response(self, response_data):
        """
        Parse the response from the token endpoint and store the access token with its lifetime.

        :param response_data: The JSON response data from the token endpoint.
        :raises MLRunRuntimeError: If ``spec.accessToken`` is missing from the response.
        """
        # `or {}` also covers an explicit `"spec": null` in the response
        spec = response_data.get("spec") or {}
        access_token = spec.get("accessToken")

        if not access_token:
            raise mlrun.errors.MLRunRuntimeError(
                "Access token is missing in the response from the token endpoint"
            )

        self._token = access_token

        self._token_total_lifetime, self._token_expiry_time = (
            self._get_token_lifetime_and_expiry(access_token)
        )

    def _post_fetch_hook(self, raise_on_error=True):
        """
        Warn when falling back to an existing near-expiry token, and fail when no token is held.

        :param raise_on_error: ``False`` when the fetch attempt raised, ``True`` otherwise.
        :raises MLRunRuntimeError: If no token is held and ``raise_on_error`` is set.
        """
        # if we reach this point and the token is non-empty but invalid,
        # it means the refresh threshold has been reached and the token will expire soon.
        if self._token and not self._is_token_within_refresh_threshold(
            cleanup_if_expired=True
        ):
            logger.warning(
                "Failed to fetch a new token. Using the existing token, which remains valid but is close to expiring."
            )

        # Perform a secondary validation that token fetch succeeded.
        # We enter this block if token fetch failed and did not raise an error
        if not self._token and raise_on_error:
            raise mlrun.errors.MLRunRuntimeError(
                "Failed to fetch a valid access token. Authentication procedure stopped."
            )

    @staticmethod
    def _get_token_lifetime_and_expiry(
        token: str,
    ) -> tuple[int, typing.Optional[datetime]]:
        """
        Calculate the total lifetime and expiration time of the token.

        :param token: The access token to decode.
        :return: A tuple of the token's total lifetime in seconds (``exp - iat``) and its expiration time
                 as a naive local datetime. Returns ``(0, None)`` when it cannot be determined.
        """
        if not token:
            return 0, None
        try:
            # already been verified earlier during the refresh access token call
            decoded_token = jwt.decode(token, options={"verify_signature": False})
            exp_timestamp = decoded_token.get("exp")
            iat_timestamp = decoded_token.get("iat")
            if exp_timestamp and iat_timestamp:
                return exp_timestamp - iat_timestamp, datetime.fromtimestamp(
                    exp_timestamp
                )
        except (jwt.PyJWTError, TypeError, ValueError, OverflowError, OSError) as exc:
            # Malformed tokens or non-numeric / out-of-range claims: treat the lifetime as unknown
            logger.warning(
                "Failed to decode access token",
                error=str(exc),
            )
        return 0, None