mlrun 1.10.0rc40__py3-none-any.whl → 1.11.0rc16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mlrun might be problematic. Click here for more details.
- mlrun/__init__.py +3 -2
- mlrun/__main__.py +0 -4
- mlrun/artifacts/dataset.py +2 -2
- mlrun/artifacts/plots.py +1 -1
- mlrun/{model_monitoring/db/tsdb/tdengine → auth}/__init__.py +2 -3
- mlrun/auth/nuclio.py +89 -0
- mlrun/auth/providers.py +429 -0
- mlrun/auth/utils.py +415 -0
- mlrun/common/constants.py +7 -0
- mlrun/common/model_monitoring/helpers.py +41 -4
- mlrun/common/runtimes/constants.py +28 -0
- mlrun/common/schemas/__init__.py +13 -3
- mlrun/common/schemas/alert.py +2 -2
- mlrun/common/schemas/api_gateway.py +3 -0
- mlrun/common/schemas/auth.py +10 -10
- mlrun/common/schemas/client_spec.py +4 -0
- mlrun/common/schemas/constants.py +25 -0
- mlrun/common/schemas/frontend_spec.py +1 -8
- mlrun/common/schemas/function.py +24 -0
- mlrun/common/schemas/hub.py +3 -2
- mlrun/common/schemas/model_monitoring/__init__.py +1 -1
- mlrun/common/schemas/model_monitoring/constants.py +2 -2
- mlrun/common/schemas/secret.py +17 -2
- mlrun/common/secrets.py +95 -1
- mlrun/common/types.py +10 -10
- mlrun/config.py +53 -15
- mlrun/data_types/infer.py +2 -2
- mlrun/datastore/__init__.py +2 -3
- mlrun/datastore/base.py +274 -10
- mlrun/datastore/datastore.py +1 -1
- mlrun/datastore/datastore_profile.py +49 -17
- mlrun/datastore/model_provider/huggingface_provider.py +6 -2
- mlrun/datastore/model_provider/model_provider.py +2 -2
- mlrun/datastore/model_provider/openai_provider.py +2 -2
- mlrun/datastore/s3.py +15 -16
- mlrun/datastore/sources.py +1 -1
- mlrun/datastore/store_resources.py +4 -4
- mlrun/datastore/storeytargets.py +16 -10
- mlrun/datastore/targets.py +1 -1
- mlrun/datastore/utils.py +16 -3
- mlrun/datastore/v3io.py +1 -1
- mlrun/db/base.py +36 -12
- mlrun/db/httpdb.py +316 -101
- mlrun/db/nopdb.py +29 -11
- mlrun/errors.py +4 -2
- mlrun/execution.py +11 -12
- mlrun/feature_store/api.py +1 -1
- mlrun/feature_store/common.py +1 -1
- mlrun/feature_store/feature_vector_utils.py +1 -1
- mlrun/feature_store/steps.py +8 -6
- mlrun/frameworks/_common/utils.py +3 -3
- mlrun/frameworks/_dl_common/loggers/logger.py +1 -1
- mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py +2 -1
- mlrun/frameworks/_ml_common/loggers/mlrun_logger.py +1 -1
- mlrun/frameworks/_ml_common/utils.py +2 -1
- mlrun/frameworks/auto_mlrun/auto_mlrun.py +4 -3
- mlrun/frameworks/lgbm/mlrun_interfaces/mlrun_interface.py +2 -1
- mlrun/frameworks/onnx/dataset.py +2 -1
- mlrun/frameworks/onnx/mlrun_interface.py +2 -1
- mlrun/frameworks/pytorch/callbacks/logging_callback.py +5 -4
- mlrun/frameworks/pytorch/callbacks/mlrun_logging_callback.py +2 -1
- mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py +2 -1
- mlrun/frameworks/pytorch/utils.py +2 -1
- mlrun/frameworks/sklearn/metric.py +2 -1
- mlrun/frameworks/tf_keras/callbacks/logging_callback.py +5 -4
- mlrun/frameworks/tf_keras/callbacks/mlrun_logging_callback.py +2 -1
- mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py +2 -1
- mlrun/hub/__init__.py +37 -0
- mlrun/hub/base.py +142 -0
- mlrun/hub/module.py +67 -76
- mlrun/hub/step.py +113 -0
- mlrun/launcher/base.py +2 -1
- mlrun/launcher/local.py +2 -1
- mlrun/model.py +12 -2
- mlrun/model_monitoring/__init__.py +0 -1
- mlrun/model_monitoring/api.py +2 -2
- mlrun/model_monitoring/applications/base.py +20 -6
- mlrun/model_monitoring/applications/context.py +1 -0
- mlrun/model_monitoring/controller.py +7 -17
- mlrun/model_monitoring/db/_schedules.py +2 -16
- mlrun/model_monitoring/db/_stats.py +2 -13
- mlrun/model_monitoring/db/tsdb/__init__.py +9 -7
- mlrun/model_monitoring/db/tsdb/base.py +2 -4
- mlrun/model_monitoring/db/tsdb/preaggregate.py +234 -0
- mlrun/model_monitoring/db/tsdb/stream_graph_steps.py +63 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/queries/timescaledb_metrics_queries.py +414 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/queries/timescaledb_predictions_queries.py +376 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/queries/timescaledb_results_queries.py +590 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_connection.py +434 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_connector.py +541 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_operations.py +808 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_schema.py +502 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_stream.py +163 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_stream_graph_steps.py +60 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/utils/timescaledb_dataframe_processor.py +141 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/utils/timescaledb_query_builder.py +585 -0
- mlrun/model_monitoring/db/tsdb/timescaledb/writer_graph_steps.py +73 -0
- mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +4 -6
- mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +147 -79
- mlrun/model_monitoring/features_drift_table.py +2 -1
- mlrun/model_monitoring/helpers.py +2 -1
- mlrun/model_monitoring/stream_processing.py +18 -16
- mlrun/model_monitoring/writer.py +4 -3
- mlrun/package/__init__.py +2 -1
- mlrun/platforms/__init__.py +0 -44
- mlrun/platforms/iguazio.py +1 -1
- mlrun/projects/operations.py +11 -10
- mlrun/projects/project.py +81 -82
- mlrun/run.py +4 -7
- mlrun/runtimes/__init__.py +2 -204
- mlrun/runtimes/base.py +89 -21
- mlrun/runtimes/constants.py +225 -0
- mlrun/runtimes/daskjob.py +4 -2
- mlrun/runtimes/databricks_job/databricks_runtime.py +2 -1
- mlrun/runtimes/mounts.py +5 -0
- mlrun/runtimes/nuclio/__init__.py +12 -8
- mlrun/runtimes/nuclio/api_gateway.py +36 -6
- mlrun/runtimes/nuclio/application/application.py +200 -32
- mlrun/runtimes/nuclio/function.py +154 -49
- mlrun/runtimes/nuclio/serving.py +55 -42
- mlrun/runtimes/pod.py +59 -10
- mlrun/secrets.py +46 -2
- mlrun/serving/__init__.py +2 -0
- mlrun/serving/remote.py +5 -5
- mlrun/serving/routers.py +3 -3
- mlrun/serving/server.py +46 -43
- mlrun/serving/serving_wrapper.py +6 -2
- mlrun/serving/states.py +554 -207
- mlrun/serving/steps.py +1 -1
- mlrun/serving/system_steps.py +42 -33
- mlrun/track/trackers/mlflow_tracker.py +29 -31
- mlrun/utils/helpers.py +89 -16
- mlrun/utils/http.py +9 -2
- mlrun/utils/notifications/notification/git.py +1 -1
- mlrun/utils/notifications/notification/mail.py +39 -16
- mlrun/utils/notifications/notification_pusher.py +2 -2
- mlrun/utils/version/version.json +2 -2
- mlrun/utils/version/version.py +3 -4
- {mlrun-1.10.0rc40.dist-info → mlrun-1.11.0rc16.dist-info}/METADATA +39 -49
- {mlrun-1.10.0rc40.dist-info → mlrun-1.11.0rc16.dist-info}/RECORD +144 -130
- mlrun/db/auth_utils.py +0 -152
- mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +0 -343
- mlrun/model_monitoring/db/tsdb/tdengine/stream_graph_steps.py +0 -75
- mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connection.py +0 -281
- mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +0 -1368
- mlrun/model_monitoring/db/tsdb/tdengine/writer_graph_steps.py +0 -51
- {mlrun-1.10.0rc40.dist-info → mlrun-1.11.0rc16.dist-info}/WHEEL +0 -0
- {mlrun-1.10.0rc40.dist-info → mlrun-1.11.0rc16.dist-info}/entry_points.txt +0 -0
- {mlrun-1.10.0rc40.dist-info → mlrun-1.11.0rc16.dist-info}/licenses/LICENSE +0 -0
- {mlrun-1.10.0rc40.dist-info → mlrun-1.11.0rc16.dist-info}/top_level.txt +0 -0
mlrun/__init__.py
CHANGED
|
@@ -24,6 +24,7 @@ __all__ = [
|
|
|
24
24
|
"v3io_cred",
|
|
25
25
|
"auto_mount",
|
|
26
26
|
"VolumeMount",
|
|
27
|
+
"sync_secret_tokens",
|
|
27
28
|
]
|
|
28
29
|
|
|
29
30
|
from os import environ, path
|
|
@@ -37,7 +38,7 @@ from .datastore import DataItem, ModelProvider, store_manager
|
|
|
37
38
|
from .db import get_run_db
|
|
38
39
|
from .errors import MLRunInvalidArgumentError, MLRunNotFoundError
|
|
39
40
|
from .execution import MLClientCtx
|
|
40
|
-
from .hub import get_hub_module, import_module
|
|
41
|
+
from .hub import get_hub_item, get_hub_module, get_hub_step, import_module
|
|
41
42
|
from .model import RunObject, RunTemplate, new_task
|
|
42
43
|
from .package import ArtifactType, DefaultPackager, Packager, handler
|
|
43
44
|
from .projects import (
|
|
@@ -68,7 +69,7 @@ from .run import (
|
|
|
68
69
|
wait_for_pipeline_completion,
|
|
69
70
|
)
|
|
70
71
|
from .runtimes import mounts, new_model_server
|
|
71
|
-
from .secrets import get_secret_or_env
|
|
72
|
+
from .secrets import get_secret_or_env, sync_secret_tokens
|
|
72
73
|
from .utils.version import Version
|
|
73
74
|
|
|
74
75
|
__version__ = Version().get()["version"]
|
mlrun/__main__.py
CHANGED
|
@@ -203,7 +203,6 @@ def main():
|
|
|
203
203
|
@click.option(
|
|
204
204
|
"--allow-cross-project",
|
|
205
205
|
is_flag=True,
|
|
206
|
-
default=True, # TODO: remove this default in 1.11
|
|
207
206
|
help="Override the loaded project name. This flag ensures awareness of loading an existing project yaml "
|
|
208
207
|
"as a baseline for a new project with a different name",
|
|
209
208
|
)
|
|
@@ -513,7 +512,6 @@ def run(
|
|
|
513
512
|
@click.option(
|
|
514
513
|
"--allow-cross-project",
|
|
515
514
|
is_flag=True,
|
|
516
|
-
default=True, # TODO: remove this default in 1.11
|
|
517
515
|
help="Override the loaded project name. This flag ensures awareness of loading an existing project yaml "
|
|
518
516
|
"as a baseline for a new project with a different name",
|
|
519
517
|
)
|
|
@@ -672,7 +670,6 @@ def build(
|
|
|
672
670
|
@click.option(
|
|
673
671
|
"--allow-cross-project",
|
|
674
672
|
is_flag=True,
|
|
675
|
-
default=True, # TODO: remove this default in 1.11
|
|
676
673
|
help="Override the loaded project name. This flag ensures awareness of loading an existing project yaml "
|
|
677
674
|
"as a baseline for a new project with a different name",
|
|
678
675
|
)
|
|
@@ -1008,7 +1005,6 @@ def logs(uid, project, offset, db):
|
|
|
1008
1005
|
@click.option(
|
|
1009
1006
|
"--allow-cross-project",
|
|
1010
1007
|
is_flag=True,
|
|
1011
|
-
default=True, # TODO: remove this default in 1.11
|
|
1012
1008
|
help="Override the loaded project name. This flag ensures awareness of loading an existing project yaml "
|
|
1013
1009
|
"as a baseline for a new project with a different name",
|
|
1014
1010
|
)
|
mlrun/artifacts/dataset.py
CHANGED
|
@@ -366,9 +366,9 @@ def get_df_stats(df):
|
|
|
366
366
|
for col, values in df.describe(include="all").items():
|
|
367
367
|
stats_dict = {}
|
|
368
368
|
for stat, val in values.dropna().items():
|
|
369
|
-
if isinstance(val,
|
|
369
|
+
if isinstance(val, float | np.floating | np.float64):
|
|
370
370
|
stats_dict[stat] = float(val)
|
|
371
|
-
elif isinstance(val,
|
|
371
|
+
elif isinstance(val, int | np.integer | np.int64):
|
|
372
372
|
stats_dict[stat] = int(val)
|
|
373
373
|
else:
|
|
374
374
|
stats_dict[stat] = str(val)
|
mlrun/artifacts/plots.py
CHANGED
|
@@ -42,7 +42,7 @@ class PlotArtifact(Artifact):
|
|
|
42
42
|
import matplotlib
|
|
43
43
|
|
|
44
44
|
if not self.spec.get_body() or not isinstance(
|
|
45
|
-
self.spec.get_body(),
|
|
45
|
+
self.spec.get_body(), bytes | matplotlib.figure.Figure
|
|
46
46
|
):
|
|
47
47
|
raise ValueError(
|
|
48
48
|
"matplotlib fig or png bytes must be provided as artifact body"
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright
|
|
1
|
+
# Copyright 2025 Iguazio
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -11,5 +11,4 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
-
|
|
15
|
-
from .tdengine_connector import TDEngineConnector
|
|
14
|
+
from .providers import IGTokenProvider, OAuthClientIDTokenProvider, StaticTokenProvider
|
mlrun/auth/nuclio.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
# Copyright 2023 Iguazio
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import base64
|
|
16
|
+
|
|
17
|
+
import requests.auth
|
|
18
|
+
from nuclio.auth import AuthInfo as _NuclioAuthInfo
|
|
19
|
+
from nuclio.auth import AuthKinds as NuclioAuthKinds
|
|
20
|
+
|
|
21
|
+
import mlrun.auth.providers
|
|
22
|
+
import mlrun.common.schemas.auth
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class NuclioAuthInfo(_NuclioAuthInfo):
    """
    Nuclio auth info that can additionally carry a bearer token.

    When a token is set, :py:meth:`to_requests_auth` yields bearer-token authentication; otherwise the parent
    nuclio ``AuthInfo`` behavior is used unchanged.

    :param token:  Optional bearer token to authenticate with.
    :param kwargs: Forwarded as-is to the parent ``AuthInfo`` constructor.
    """

    def __init__(self, token=None, **kwargs):
        super().__init__(**kwargs)
        self._token = token

    @classmethod
    def from_auth_info(cls, auth_info: "mlrun.common.schemas.auth.AuthInfo"):
        """
        Build a nuclio auth info out of an MLRun auth info.

        :param auth_info: The MLRun auth info to convert, may be empty.

        :return: The nuclio auth info, or None when ``auth_info`` is empty or, outside of iguazio v4 mode,
                 when its session is an empty string.
        """
        if not auth_info:
            return None
        if mlrun.mlconf.is_iguazio_v4_mode():
            # in v4 mode the credentials are taken from the original request headers
            return cls.from_request_headers(auth_info.request_headers)
        if auth_info.session != "":
            # use `cls`, like the other alternate constructors, so subclasses get instances of their own type
            return cls(password=auth_info.session, mode=NuclioAuthKinds.iguazio)
        return None

    @classmethod
    def from_request_headers(cls, headers: dict[str, str]):
        """
        Build a nuclio auth info out of the ``Authorization`` header of a request.

        Both the header name and the scheme (``Bearer`` / ``Basic``) are matched case-insensitively.

        :param headers: The request headers, may be empty.

        :return: An auth info holding the bearer token or the basic username and password. When there are no
                 headers, or none of them is a bearer/basic ``Authorization`` header, an auth info created with
                 no arguments is returned.

        :raise ValueError: If decoded basic credentials do not contain a ``:`` separator.
        """
        if not headers:
            return cls()
        for key, value in headers.items():
            if key.lower() != "authorization":
                continue
            scheme_check = value.lower()
            if scheme_check.startswith("bearer "):
                return cls(
                    token=value[len("bearer ") :],
                    mode=NuclioAuthKinds.iguazio,
                )
            if scheme_check.startswith("basic "):
                encoded_credentials = value[len("basic ") :]
                decoded_credentials = base64.b64decode(encoded_credentials).decode(
                    "utf-8"
                )
                # split on the first colon only - the password itself may contain colons
                username, password = decoded_credentials.split(":", 1)
                return cls(
                    username=username,
                    password=password,
                    mode=NuclioAuthKinds.iguazio,
                )
        return cls()

    @classmethod
    def from_envvar(cls):
        """
        Build a nuclio auth info for the current environment.

        In iguazio v4 mode a token is obtained through an ``IGTokenProvider`` pointed at the configured token
        endpoint; otherwise the parent implementation is used.
        """
        if mlrun.mlconf.is_iguazio_v4_mode():
            token_provider = mlrun.auth.providers.IGTokenProvider(
                token_endpoint=mlrun.mlconf.auth_token_endpoint,
            )
            return cls(
                token=token_provider.get_token(),
                mode=NuclioAuthKinds.iguazio,
            )
        return super().from_envvar()

    def to_requests_auth(self) -> "requests.auth.AuthBase":
        """
        Translate this auth info into an auth object usable with ``requests``.

        :return: Bearer token auth when a token is set, otherwise whatever the parent implementation returns
                 (NOTE(review): the parent may return a non-``AuthBase`` value such as None - confirm).
        """
        if self._token:
            # in iguazio v4 mode we use bearer token auth
            return _RequestAuthBearerToken(self._token)
        return super().to_requests_auth()
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class _RequestAuthBearerToken(requests.auth.AuthBase):
    """Requests auth object that authenticates by sending a token as a bearer ``Authorization`` header."""

    def __init__(self, token: str):
        self._token = token

    def __call__(self, r):
        """Set the bearer ``Authorization`` header on the given request and hand the request back."""
        bearer_value = f"Bearer {self._token}"
        r.headers["Authorization"] = bearer_value
        return r
|
mlrun/auth/providers.py
ADDED
|
@@ -0,0 +1,429 @@
|
|
|
1
|
+
# Copyright 2025 Iguazio
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import typing
|
|
16
|
+
from abc import ABC, abstractmethod
|
|
17
|
+
from datetime import datetime, timedelta
|
|
18
|
+
|
|
19
|
+
import jwt
|
|
20
|
+
import requests
|
|
21
|
+
|
|
22
|
+
import mlrun.auth.utils
|
|
23
|
+
import mlrun.errors
|
|
24
|
+
import mlrun.secrets
|
|
25
|
+
import mlrun.utils.helpers
|
|
26
|
+
from mlrun.config import config as mlconf
|
|
27
|
+
from mlrun.utils import logger
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class TokenProvider(ABC):
    """Interface for objects that supply an authentication token."""

    @abstractmethod
    def get_token(self):
        """Return the token to authenticate with."""

    @abstractmethod
    def is_iguazio_session(self):
        """Return whether the supplied token is an Iguazio session."""
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class StaticTokenProvider(TokenProvider):
    """
    A token provider that always returns the fixed token it was created with.

    :param token: The token to return from :py:meth:`get_token`.
    """

    def __init__(self, token: str):
        self.token = token

    def get_token(self):
        """Return the token given at construction time."""
        return self.token

    def is_iguazio_session(self):
        """Return whether the held token is an Iguazio session, as judged by ``mlrun.platforms.iguazio``."""
        # Imported here because this module does not import `mlrun.platforms.iguazio` itself; without it the
        # attribute lookup below only works if some other module happened to import that submodule first.
        import mlrun.platforms.iguazio

        return mlrun.platforms.iguazio.is_iguazio_session(self.token)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class DynamicTokenProvider(TokenProvider):
    """
    A token provider that dynamically fetches and refreshes tokens from a token endpoint.

    This class handles token retrieval and automatic refresh when the token is expired or about to expire.
    It uses a session with retry capabilities for robust communication with the token endpoint.
    A first token is fetched as part of construction, so construction fails if no token can be obtained.

    :param token_endpoint: The URL of the token endpoint.
    :param timeout: The timeout for token requests, in seconds.
    :param max_retries: The retry count handed to ``run_with_retry`` when fetching a token.
    """

    def __init__(self, token_endpoint: str, timeout=5, max_retries=0):
        if not token_endpoint:
            raise mlrun.errors.MLRunValueError(
                "No token endpoint provided, cannot initialize token provider"
            )
        self._token = None
        self._token_endpoint = token_endpoint
        self._timeout = timeout
        self._max_retries = max_retries

        # Since we're only issuing POST requests, which are actually a disguised GET, then it's ok to allow retries
        # on them.
        self._session = mlrun.utils.HTTPSessionWithRetry(
            retry_on_post=True,
            verbose=True,
        )
        # let the subclass initialize its token metadata before the first fetch
        self._cleanup()
        self._refresh_token_if_needed()

    def get_token(self):
        """
        Retrieve the current access token, refreshing it if necessary.

        :return: The current access token.
        """
        self._refresh_token_if_needed()
        return self._token

    def is_iguazio_session(self):
        """Dynamically fetched tokens are never reported as Iguazio sessions."""
        return False

    def fetch_token(self):
        """Fetch a new access token, retrying according to the configured ``max_retries``."""
        mlrun.utils.helpers.run_with_retry(
            retry_count=self._max_retries,
            func=self._fetch_token,
        )

    def _fetch_token(self):
        """
        Fetch a new access token from the token endpoint.

        This method builds the token request, sends it to the token endpoint, and hands the JSON response to
        `_parse_response`.

        :raise MLRunRuntimeError: If sending the request fails.
        Unsuccessful responses are logged and then raised through ``mlrun.errors.raise_for_status``.
        """
        request_body, headers, body_type = self._build_token_request(
            raise_on_error=True
        )

        try:
            request_kwargs = {
                "method": "POST",
                "url": self._token_endpoint,
                "timeout": self._timeout,
                "headers": headers,
                "verify": mlconf.httpdb.http.verify,
            }
            # the subclass decides whether the body goes out as JSON or as form data
            if body_type == "json":
                request_kwargs["json"] = request_body
            else:
                request_kwargs["data"] = request_body

            response = self._session.request(**request_kwargs)
        except requests.RequestException as exc:
            error = f"Retrieving token failed: {mlrun.errors.err_to_str(exc)}"
            raise mlrun.errors.MLRunRuntimeError(error) from exc

        if not response.ok:
            error = "No error available"
            if response.content:
                try:
                    # keep the default when the body carries no (or an empty) "error" field
                    error = response.json().get("error") or error
                except Exception:
                    # Deliberate best effort - the body is only used to enrich the log line below, the failure
                    # itself is raised right after it.
                    pass
            logger.warning(
                "Retrieving token failed", status=response.status_code, error=error
            )
            mlrun.errors.raise_for_status(response)

        self._parse_response(response.json())

    def _refresh_token_if_needed(self):
        """
        Refresh the access token if it is expired or about to expire.

        A failed fetch is tolerated as long as a token is still held; without one the failure is raised.

        :return: The current (possibly refreshed) access token.

        :raise MLRunRuntimeError: If fetching failed and there is no existing token to fall back to.
        """
        raise_on_error = True

        # Check if there is an existing access token and if it is within the refresh threshold
        if self._token and self._is_token_within_refresh_threshold(
            cleanup_if_expired=True
        ):
            return self._token

        try:
            self.fetch_token()
        except Exception as exc:
            raise_on_error = False
            # Token fetch failed and there is no existing token - cannot proceed
            if not self._token:
                raise mlrun.errors.MLRunRuntimeError(
                    "Failed to fetch a valid access token. Authentication procedure stopped."
                ) from exc

        finally:
            # runs on success and on failure alike, including while the error above propagates
            self._post_fetch_hook(raise_on_error)

        return self._token

    @abstractmethod
    def _post_fetch_hook(self, raise_on_error=True):
        """
        A hook that is called after every attempt to fetch a new token, whether it succeeded or not.
        Can be used to perform additional actions, such as logging or updating state.

        :param raise_on_error: False when the fetch attempt itself already failed, True otherwise.
        """

    @abstractmethod
    def _is_token_within_refresh_threshold(self, cleanup_if_expired=True) -> bool:
        """
        Check if the current access token is valid.

        :param cleanup_if_expired: Whether to clean up the token if it is expired.
        :return: True if the token is valid, False otherwise.
        """

    @abstractmethod
    def _cleanup(self):
        """
        Clean up the token and related metadata.

        Also called once from the constructor, before the first fetch.
        """

    @abstractmethod
    def _build_token_request(self, raise_on_error=False):
        """
        Build the request body and headers for the token request.

        :param raise_on_error: Whether to raise an error if the request cannot be built.
        :return: A tuple of (request body, headers, body type). A body type of "json" makes the body be sent as
                 JSON, any other value makes it be sent as form data.
        """

    @abstractmethod
    def _parse_response(self, data: dict):
        """
        Parse the response from the token endpoint.

        :param data: The JSON response data from the token endpoint.
        """
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
class OAuthClientIDTokenProvider(DynamicTokenProvider):
    """
    A token provider that obtains access tokens using the OAuth client-credentials grant.

    The token is considered due for refresh once half of its reported lifetime has passed.

    :param token_endpoint: The URL of the token endpoint.
    :param client_id: The client id to authenticate with.
    :param client_secret: The client secret to authenticate with.
    :param timeout: The timeout for token requests, in seconds.
    """

    def __init__(
        self, token_endpoint: str, client_id: str, client_secret: str, timeout=5
    ):
        if not token_endpoint or not client_id or not client_secret:
            raise mlrun.errors.MLRunValueError(
                "Invalid client_id configuration for authentication. Must provide token endpoint, client-id and secret"
            )
        # should be set before calling the parent constructor
        self._client_id = client_id
        self._client_secret = client_secret
        super().__init__(token_endpoint=token_endpoint, timeout=timeout)

    def _cleanup(self):
        """Drop the token together with its expiry and refresh times."""
        self._token = self.token_expiry_time = self.token_refresh_time = None

    def _is_token_within_refresh_threshold(self, cleanup_if_expired=True) -> bool:
        """
        Check if the current access token is valid.

        :param cleanup_if_expired: Whether to clean up the token if it is expired.
        :return: True if the token is valid, False otherwise.
        """
        if not self._token or not self.token_expiry_time:
            return False

        now = datetime.now()

        if now <= self.token_refresh_time:
            return True

        if now < self.token_expiry_time:
            # past refresh time but not expired yet → not valid
            return False

        # expired
        if cleanup_if_expired:
            # We only cleanup if token was really expired - even if we fail in refreshing the token, we can still
            # use the existing one given that it's not expired.
            self._cleanup()
        return False

    def _build_token_request(self, raise_on_error=False):
        """
        Build the client-credentials token request.

        :param raise_on_error: Unused by this provider.
        :return: A tuple of (request body, headers, body type), the body type being "data" (form encoded).
        """
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        request_body = {
            "grant_type": "client_credentials",
            "client_id": self._client_id,
            "client_secret": self._client_secret,
        }
        return request_body, headers, "data"

    def _parse_response(self, data: dict):
        """
        Store the access token and its expiry and refresh times out of the token endpoint response.

        :param data: The JSON response data from the token endpoint.
        """
        # Response is described in https://datatracker.ietf.org/doc/html/rfc6749#section-4.4.3
        # According to spec, there isn't a refresh token - just the access token and its expiry time (in seconds).
        access_token = data.get("access_token")
        expires_in = data.get("expires_in")
        if access_token:
            # Only replace the held token when the response actually carries one - a malformed response must not
            # wipe an existing token that may still be usable.
            self._token = access_token
        if not access_token or not expires_in:
            token_str = "****" if access_token else "missing"
            logger.warning(
                "Failed to parse token response", token=token_str, expires_in=expires_in
            )
            return

        now = datetime.now()
        self.token_expiry_time = now + timedelta(seconds=expires_in)
        # refresh once half of the lifetime has passed
        self.token_refresh_time = now + timedelta(seconds=expires_in / 2)
        logger.info(
            "Successfully retrieved client-id token",
            expires_in=expires_in,
            expiry=str(self.token_expiry_time),
            refresh=str(self.token_refresh_time),
        )

    def _post_fetch_hook(self, raise_on_error=True):
        """
        A hook that is called after fetching a new token.
        Nothing needs to be done for this provider.
        """
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
class IGTokenProvider(DynamicTokenProvider):
    """
    A token provider for Iguazio that uses a refresh token to fetch access tokens.

    This class implements the Iguazio-specific token refresh flow to retrieve access tokens
    from a token endpoint.

    :param token_endpoint: The URL of the token endpoint.
    :param timeout: The timeout for token requests, in seconds.
    """

    def __init__(self, token_endpoint: str, timeout=5):
        super().__init__(token_endpoint=token_endpoint, timeout=timeout, max_retries=2)

    @property
    def authenticated_user_id(self) -> typing.Optional[str]:
        """The JWT subject of the currently held access token."""
        return mlrun.auth.utils.resolve_jwt_subject(self._token, raise_on_error=True)

    def _cleanup(self):
        """Drop the token together with its lifetime and expiry time."""
        self._token = None
        self._token_total_lifetime = 0
        self._token_expiry_time = None

    def _is_token_within_refresh_threshold(self, cleanup_if_expired=True) -> bool:
        """
        Check if the current access token is valid and has sufficient lifetime remaining.

        :param cleanup_if_expired: Whether to clean up the token if it is expired.
        :return: True if the token is valid, False otherwise.
        """
        if (
            not self._token
            or self._token_total_lifetime <= 0
            or not self._token_expiry_time
        ):
            return False

        now = datetime.now()
        remaining_lifetime = (self._token_expiry_time - now).total_seconds()
        if remaining_lifetime <= 0:
            # An expired token is never valid, whatever the configured threshold is. It is only dropped on
            # request, so a caller can still inspect it.
            if cleanup_if_expired:
                self._cleanup()
            return False

        # valid as long as the elapsed part of the lifetime is below the configured fraction of it
        elapsed_lifetime = self._token_total_lifetime - remaining_lifetime
        return (
            elapsed_lifetime
            < self._token_total_lifetime
            * mlconf.auth_with_oauth_token.refresh_threshold
        )

    def _build_token_request(self, raise_on_error=False):
        """
        Build the request body and headers for the token request.

        :param raise_on_error: Passed on to ``load_offline_token``.
        :return: A tuple of (request body, headers, body type), the body type being "json". All three are None
                 when no offline token could be loaded.
        """
        offline_token = mlrun.auth.utils.load_offline_token(
            raise_on_error=raise_on_error
        )
        if not offline_token:
            # Error already handled in `load_offline_token`.
            # Keep the three-item shape that the caller unpacks.
            return None, None, None

        headers = {"Content-Type": "application/json"}
        request_body = {"refreshToken": offline_token}
        return request_body, headers, "json"

    def _parse_response(self, response_data):
        """
        Parse the response from the token endpoint.

        :param response_data: The JSON response data from the token endpoint.

        :raise MLRunRuntimeError: If the response carries no access token.
        """
        spec = response_data.get("spec", {})
        access_token = spec.get("accessToken")

        if not access_token:
            raise mlrun.errors.MLRunRuntimeError(
                "Access token is missing in the response from the token endpoint"
            )

        self._token = access_token

        self._token_total_lifetime, self._token_expiry_time = (
            self._get_token_lifetime_and_expiry(access_token)
        )

    def _post_fetch_hook(self, raise_on_error=True):
        """
        Warn when the held token is past its refresh threshold and, if requested, fail when no token is held.

        :param raise_on_error: Whether to raise when there is no token after the fetch attempt.
        """
        # if we reach this point and the token is non-empty but invalid,
        # it means the refresh threshold has been reached and the token will expire soon.
        if self._token and not self._is_token_within_refresh_threshold(
            cleanup_if_expired=True
        ):
            logger.warning(
                "Failed to fetch a new token. Using the existing token, which remains valid but is close to expiring."
            )

        # Perform a secondary validation that token fetch succeeded.
        # We enter this block if token fetch failed and did not raise an error
        if not self._token and raise_on_error:
            raise mlrun.errors.MLRunRuntimeError(
                "Failed to fetch a valid access token. Authentication procedure stopped."
            )

    @staticmethod
    def _get_token_lifetime_and_expiry(
        token: str,
    ) -> tuple[int, typing.Optional[datetime]]:
        """
        Calculate the total lifetime and expiration time of the token.

        :param token: The access token to decode.
        :return: A tuple containing the total lifetime of the token in seconds and its expiration time as a datetime.
                 ``(0, None)`` is returned when the token is empty, cannot be decoded, or lacks the ``exp`` or
                 ``iat`` claim.
        """
        if not token:
            return 0, None
        try:
            # already been verified earlier during the refresh access token call
            decoded_token = jwt.decode(token, options={"verify_signature": False})
            exp_timestamp = decoded_token.get("exp")
            iat_timestamp = decoded_token.get("iat")
            if exp_timestamp and iat_timestamp:
                return exp_timestamp - iat_timestamp, datetime.fromtimestamp(
                    exp_timestamp
                )
        except jwt.PyJWTError as exc:
            logger.warning(
                "Failed to decode access token",
                error=str(exc),
            )
        return 0, None
|