mlrun 1.7.0rc13__py3-none-any.whl → 1.7.0rc14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlrun might be problematic; see the package registry's advisory page for more details.

@@ -110,6 +110,7 @@ class ProjectSummary(pydantic.BaseModel):
110
110
  files_count: int
111
111
  feature_sets_count: int
112
112
  models_count: int
113
+ runs_completed_recent_count: int
113
114
  runs_failed_recent_count: int
114
115
  runs_running_count: int
115
116
  schedules_count: int
mlrun/config.py CHANGED
@@ -188,6 +188,7 @@ default_config = {
188
188
  "background_tasks": {
189
189
  # enabled / disabled
190
190
  "timeout_mode": "enabled",
191
+ "function_deletion_batch_size": 10,
191
192
  # timeout in seconds to wait for background task to be updated / finished by the worker responsible for the task
192
193
  "default_timeouts": {
193
194
  "operations": {
@@ -196,6 +197,7 @@ default_config = {
196
197
  "run_abortion": "600",
197
198
  "abort_grace_period": "10",
198
199
  "delete_project": "900",
200
+ "delete_function": "900",
199
201
  },
200
202
  "runtimes": {"dask": "600"},
201
203
  },
@@ -692,6 +694,10 @@ default_config = {
692
694
  # supported modes: "enabled", "disabled".
693
695
  "mode": "disabled"
694
696
  },
697
+ "auth_with_client_id": {
698
+ "enabled": False,
699
+ "request_timeout": 5,
700
+ },
695
701
  }
696
702
 
697
703
  _is_running_as_api = None
@@ -1396,7 +1402,11 @@ def read_env(env=None, prefix=env_prefix):
1396
1402
  log_formatter = mlrun.utils.create_formatter_instance(
1397
1403
  mlrun.utils.FormatterKinds(log_formatter_name)
1398
1404
  )
1399
- mlrun.utils.logger.get_handler("default").setFormatter(log_formatter)
1405
+ current_handler = mlrun.utils.logger.get_handler("default")
1406
+ current_formatter_name = current_handler.formatter.__class__.__name__
1407
+ desired_formatter_name = log_formatter.__class__.__name__
1408
+ if current_formatter_name != desired_formatter_name:
1409
+ current_handler.setFormatter(log_formatter)
1400
1410
 
1401
1411
  # The default function pod resource values are of type str; however, when reading from environment variable numbers,
1402
1412
  # it converts them to type int if contains only number, so we want to convert them to str.
mlrun/datastore/hdfs.py CHANGED
@@ -12,6 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  import os
15
+ from urllib.parse import urlparse
15
16
 
16
17
  import fsspec
17
18
 
@@ -49,3 +50,7 @@ class HdfsStore(DataStore):
49
50
  @property
50
51
  def spark_url(self):
51
52
  return f"hdfs://{self.host}:{self.port}"
53
+
54
+ def rm(self, url, recursive=False, maxdepth=None):
55
+ path = urlparse(url).path
56
+ self.filesystem.rm(path=path, recursive=recursive, maxdepth=maxdepth)
@@ -1390,6 +1390,39 @@ class RedisNoSqlTarget(NoSqlBaseTarget):
1390
1390
  support_spark = True
1391
1391
  writer_step_name = "RedisNoSqlTarget"
1392
1392
 
1393
+ @property
1394
+ def _target_path_object(self):
1395
+ url = self.path or mlrun.mlconf.redis.url
1396
+ if self._resource and url:
1397
+ parsed_url = urlparse(url)
1398
+ if not parsed_url.path or parsed_url.path == "/":
1399
+ kind_prefix = (
1400
+ "sets"
1401
+ if self._resource.kind
1402
+ == mlrun.common.schemas.ObjectKind.feature_set
1403
+ else "vectors"
1404
+ )
1405
+ kind = self.kind
1406
+ name = self._resource.metadata.name
1407
+ project = (
1408
+ self._resource.metadata.project or mlrun.mlconf.default_project
1409
+ )
1410
+ data_prefix = get_default_prefix_for_target(kind).format(
1411
+ ds_profile_name=parsed_url.netloc,
1412
+ authority=parsed_url.netloc,
1413
+ project=project,
1414
+ kind=kind,
1415
+ name=name,
1416
+ )
1417
+ if url.startswith("rediss://"):
1418
+ data_prefix = data_prefix.replace("redis://", "rediss://", 1)
1419
+ if not self.run_id:
1420
+ version = self._resource.metadata.tag or "latest"
1421
+ name = f"{name}-{version}"
1422
+ url = f"{data_prefix}/{kind_prefix}/{name}"
1423
+ return TargetPathObject(url, self.run_id, False)
1424
+ return super()._target_path_object
1425
+
1393
1426
  # Fetch server url from the RedisNoSqlTarget::__init__() 'path' parameter.
1394
1427
  # If not set fetch it from 'mlrun.mlconf.redis.url' (MLRUN_REDIS__URL environment variable).
1395
1428
  # Then look for username and password at REDIS_xxx secrets
@@ -1516,6 +1549,25 @@ class StreamTarget(BaseStoreTarget):
1516
1549
 
1517
1550
 
1518
1551
  class KafkaTarget(BaseStoreTarget):
1552
+ """
1553
+ Kafka target storage driver, used to write data into kafka topics.
1554
+ example::
1555
+ # define target
1556
+ kafka_target = KafkaTarget(
1557
+ name="kafka", path="my_topic", brokers="localhost:9092"
1558
+ )
1559
+ # ingest
1560
+ stocks_set.ingest(stocks, [kafka_target])
1561
+ :param name: target name
1562
+ :param path: topic name e.g. "my_topic"
1563
+ :param after_step: optional, after what step in the graph to add the target
1564
+ :param columns: optional, which columns from data to write
1565
+ :param bootstrap_servers: Deprecated. Use the brokers parameter instead
1566
+ :param producer_options: additional configurations for kafka producer
1567
+ :param brokers: kafka broker as represented by a host:port pair, or a list of kafka brokers, e.g.
1568
+ "localhost:9092", or ["kafka-broker-1:9092", "kafka-broker-2:9092"]
1569
+ """
1570
+
1519
1571
  kind = TargetTypes.kafka
1520
1572
  is_table = False
1521
1573
  is_online = False
mlrun/datastore/v3io.py CHANGED
@@ -29,7 +29,7 @@ from .base import (
29
29
  )
30
30
 
31
31
  V3IO_LOCAL_ROOT = "v3io"
32
- V3IO_DEFAULT_UPLOAD_CHUNK_SIZE = 1024 * 1024 * 100
32
+ V3IO_DEFAULT_UPLOAD_CHUNK_SIZE = 1024 * 1024 * 10
33
33
 
34
34
 
35
35
  class V3ioStore(DataStore):
mlrun/db/auth_utils.py ADDED
@@ -0,0 +1,152 @@
1
+ # Copyright 2024 Iguazio
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import ABC, abstractmethod
16
+ from datetime import datetime, timedelta
17
+
18
+ import requests
19
+
20
+ import mlrun.errors
21
+ from mlrun.utils import logger
22
+
23
+
24
+ class TokenProvider(ABC):
25
+ @abstractmethod
26
+ def get_token(self):
27
+ pass
28
+
29
+ @abstractmethod
30
+ def is_iguazio_session(self):
31
+ pass
32
+
33
+
34
+ class StaticTokenProvider(TokenProvider):
35
+ def __init__(self, token: str):
36
+ self.token = token
37
+
38
+ def get_token(self):
39
+ return self.token
40
+
41
+ def is_iguazio_session(self):
42
+ return mlrun.platforms.iguazio.is_iguazio_session(self.token)
43
+
44
+
45
+ class OAuthClientIDTokenProvider(TokenProvider):
46
+ def __init__(
47
+ self, token_endpoint: str, client_id: str, client_secret: str, timeout=5
48
+ ):
49
+ if not token_endpoint or not client_id or not client_secret:
50
+ raise mlrun.errors.MLRunValueError(
51
+ "Invalid client_id configuration for authentication. Must provide token endpoint, client-id and secret"
52
+ )
53
+ self.token_endpoint = token_endpoint
54
+ self.client_id = client_id
55
+ self.client_secret = client_secret
56
+ self.timeout = timeout
57
+
58
+ # Since we're only issuing POST requests, which are actually a disguised GET, then it's ok to allow retries
59
+ # on them.
60
+ self._session = mlrun.utils.HTTPSessionWithRetry(
61
+ retry_on_post=True,
62
+ verbose=True,
63
+ )
64
+
65
+ self._cleanup()
66
+ self._refresh_token_if_needed()
67
+
68
+ def get_token(self):
69
+ self._refresh_token_if_needed()
70
+ return self.token
71
+
72
+ def is_iguazio_session(self):
73
+ return False
74
+
75
+ def _cleanup(self):
76
+ self.token = self.token_expiry_time = self.token_refresh_time = None
77
+
78
+ def _refresh_token_if_needed(self):
79
+ now = datetime.now()
80
+ if self.token:
81
+ if self.token_refresh_time and now <= self.token_refresh_time:
82
+ return self.token
83
+
84
+ # We only cleanup if token was really expired - even if we fail in refreshing the token, we can still
85
+ # use the existing one given that it's not expired.
86
+ if now >= self.token_expiry_time:
87
+ self._cleanup()
88
+
89
+ self._issue_token_request()
90
+ return self.token
91
+
92
+ def _issue_token_request(self, raise_on_error=False):
93
+ try:
94
+ headers = {"Content-Type": "application/x-www-form-urlencoded"}
95
+ request_body = {
96
+ "grant_type": "client_credentials",
97
+ "client_id": self.client_id,
98
+ "client_secret": self.client_secret,
99
+ }
100
+ response = self._session.request(
101
+ "POST",
102
+ self.token_endpoint,
103
+ timeout=self.timeout,
104
+ headers=headers,
105
+ data=request_body,
106
+ )
107
+ except requests.RequestException as exc:
108
+ error = f"Retrieving token failed: {mlrun.errors.err_to_str(exc)}"
109
+ if raise_on_error:
110
+ raise mlrun.errors.MLRunRuntimeError(error) from exc
111
+ else:
112
+ logger.warning(error)
113
+ return
114
+
115
+ if not response.ok:
116
+ error = "No error available"
117
+ if response.content:
118
+ try:
119
+ data = response.json()
120
+ error = data.get("error")
121
+ except Exception:
122
+ pass
123
+ logger.warning(
124
+ "Retrieving token failed", status=response.status_code, error=error
125
+ )
126
+ if raise_on_error:
127
+ mlrun.errors.raise_for_status(response)
128
+ return
129
+
130
+ self._parse_response(response.json())
131
+
132
+ def _parse_response(self, data: dict):
133
+ # Response is described in https://datatracker.ietf.org/doc/html/rfc6749#section-4.4.3
134
+ # According to spec, there isn't a refresh token - just the access token and its expiry time (in seconds).
135
+ self.token = data.get("access_token")
136
+ expires_in = data.get("expires_in")
137
+ if not self.token or not expires_in:
138
+ token_str = "****" if self.token else "missing"
139
+ logger.warning(
140
+ "Failed to parse token response", token=token_str, expires_in=expires_in
141
+ )
142
+ return
143
+
144
+ now = datetime.now()
145
+ self.token_expiry_time = now + timedelta(seconds=expires_in)
146
+ self.token_refresh_time = now + timedelta(seconds=expires_in / 2)
147
+ logger.info(
148
+ "Successfully retrieved client-id token",
149
+ expires_in=expires_in,
150
+ expiry=str(self.token_expiry_time),
151
+ refresh=str(self.token_refresh_time),
152
+ )
mlrun/db/base.py CHANGED
@@ -802,7 +802,7 @@ class RunDBInterface(ABC):
802
802
  project: str,
803
803
  base_period: int = 10,
804
804
  image: str = "mlrun/mlrun",
805
- ):
805
+ ) -> None:
806
806
  pass
807
807
 
808
808
  @abstractmethod
mlrun/db/httpdb.py CHANGED
@@ -38,6 +38,7 @@ import mlrun.platforms
38
38
  import mlrun.projects
39
39
  import mlrun.runtimes.nuclio.api_gateway
40
40
  import mlrun.utils
41
+ from mlrun.db.auth_utils import OAuthClientIDTokenProvider, StaticTokenProvider
41
42
  from mlrun.errors import MLRunInvalidArgumentError, err_to_str
42
43
 
43
44
  from ..artifacts import Artifact
@@ -138,17 +139,28 @@ class HTTPRunDB(RunDBInterface):
138
139
  endpoint += f":{parsed_url.port}"
139
140
  base_url = f"{parsed_url.scheme}://{endpoint}{parsed_url.path}"
140
141
 
142
+ self.base_url = base_url
141
143
  username = parsed_url.username or config.httpdb.user
142
144
  password = parsed_url.password or config.httpdb.password
145
+ self.token_provider = None
143
146
 
144
- username, password, token = mlrun.platforms.add_or_refresh_credentials(
145
- parsed_url.hostname, username, password, config.httpdb.token
146
- )
147
+ if config.auth_with_client_id.enabled:
148
+ self.token_provider = OAuthClientIDTokenProvider(
149
+ token_endpoint=mlrun.get_secret_or_env("MLRUN_AUTH_TOKEN_ENDPOINT"),
150
+ client_id=mlrun.get_secret_or_env("MLRUN_AUTH_CLIENT_ID"),
151
+ client_secret=mlrun.get_secret_or_env("MLRUN_AUTH_CLIENT_SECRET"),
152
+ timeout=config.auth_with_client_id.request_timeout,
153
+ )
154
+ else:
155
+ username, password, token = mlrun.platforms.add_or_refresh_credentials(
156
+ parsed_url.hostname, username, password, config.httpdb.token
157
+ )
158
+
159
+ if token:
160
+ self.token_provider = StaticTokenProvider(token)
147
161
 
148
- self.base_url = base_url
149
162
  self.user = username
150
163
  self.password = password
151
- self.token = token
152
164
 
153
165
  def __repr__(self):
154
166
  cls = self.__class__.__name__
@@ -218,17 +230,19 @@ class HTTPRunDB(RunDBInterface):
218
230
 
219
231
  if self.user:
220
232
  kw["auth"] = (self.user, self.password)
221
- elif self.token:
222
- # Iguazio auth doesn't support passing token through bearer, so use cookie instead
223
- if mlrun.platforms.iguazio.is_iguazio_session(self.token):
224
- session_cookie = f'j:{{"sid": "{self.token}"}}'
225
- cookies = {
226
- "session": session_cookie,
227
- }
228
- kw["cookies"] = cookies
229
- else:
230
- if "Authorization" not in kw.setdefault("headers", {}):
231
- kw["headers"].update({"Authorization": "Bearer " + self.token})
233
+ elif self.token_provider:
234
+ token = self.token_provider.get_token()
235
+ if token:
236
+ # Iguazio auth doesn't support passing token through bearer, so use cookie instead
237
+ if self.token_provider.is_iguazio_session():
238
+ session_cookie = f'j:{{"sid": "{token}"}}'
239
+ cookies = {
240
+ "session": session_cookie,
241
+ }
242
+ kw["cookies"] = cookies
243
+ else:
244
+ if "Authorization" not in kw.setdefault("headers", {}):
245
+ kw["headers"].update({"Authorization": "Bearer " + token})
232
246
 
233
247
  if mlrun.common.schemas.HeaderNames.client_version not in kw.setdefault(
234
248
  "headers", {}
@@ -1142,7 +1156,29 @@ class HTTPRunDB(RunDBInterface):
1142
1156
  project = project or config.default_project
1143
1157
  path = f"projects/{project}/functions/{name}"
1144
1158
  error_message = f"Failed deleting function {project}/{name}"
1145
- self.api_call("DELETE", path, error_message)
1159
+ response = self.api_call("DELETE", path, error_message, version="v2")
1160
+ if response.status_code == http.HTTPStatus.ACCEPTED:
1161
+ logger.info(
1162
+ "Function is being deleted", project_name=project, function_name=name
1163
+ )
1164
+ background_task = mlrun.common.schemas.BackgroundTask(**response.json())
1165
+ background_task = self._wait_for_background_task_to_reach_terminal_state(
1166
+ background_task.metadata.name, project=project
1167
+ )
1168
+ if (
1169
+ background_task.status.state
1170
+ == mlrun.common.schemas.BackgroundTaskState.succeeded
1171
+ ):
1172
+ logger.info(
1173
+ "Function deleted", project_name=project, function_name=name
1174
+ )
1175
+ elif (
1176
+ background_task.status.state
1177
+ == mlrun.common.schemas.BackgroundTaskState.failed
1178
+ ):
1179
+ logger.info(
1180
+ "Function deletion failed", project_name=project, function_name=name
1181
+ )
1146
1182
 
1147
1183
  def list_functions(self, name=None, project=None, tag=None, labels=None):
1148
1184
  """Retrieve a list of functions, filtered by specific criteria.
@@ -1488,16 +1524,15 @@ class HTTPRunDB(RunDBInterface):
1488
1524
  """
1489
1525
 
1490
1526
  try:
1527
+ normalized_name = normalize_name(func.metadata.name)
1491
1528
  params = {
1492
- "name": normalize_name(func.metadata.name),
1529
+ "name": normalized_name,
1493
1530
  "project": func.metadata.project,
1494
1531
  "tag": func.metadata.tag,
1495
1532
  "last_log_timestamp": str(last_log_timestamp),
1496
1533
  "verbose": bool2str(verbose),
1497
1534
  }
1498
- _path = (
1499
- f"projects/{func.metadata.project}/nuclio/{func.metadata.name}/deploy"
1500
- )
1535
+ _path = f"projects/{func.metadata.project}/nuclio/{normalized_name}/deploy"
1501
1536
  resp = self.api_call("GET", _path, params=params)
1502
1537
  except OSError as err:
1503
1538
  logger.error(f"error getting deploy status: {err_to_str(err)}")
@@ -3214,7 +3249,7 @@ class HTTPRunDB(RunDBInterface):
3214
3249
  project: str,
3215
3250
  base_period: int = 10,
3216
3251
  image: str = "mlrun/mlrun",
3217
- ):
3252
+ ) -> None:
3218
3253
  """
3219
3254
  Redeploy model monitoring application controller function.
3220
3255
 
@@ -3224,13 +3259,14 @@ class HTTPRunDB(RunDBInterface):
3224
3259
  :param image: The image of the model monitoring controller function.
3225
3260
  By default, the image is mlrun/mlrun.
3226
3261
  """
3227
-
3228
- params = {
3229
- "image": image,
3230
- "base_period": base_period,
3231
- }
3232
- path = f"projects/{project}/model-monitoring/model-monitoring-controller"
3233
- self.api_call(method="POST", path=path, params=params)
3262
+ self.api_call(
3263
+ method=mlrun.common.types.HTTPMethod.POST,
3264
+ path=f"projects/{project}/model-monitoring/model-monitoring-controller",
3265
+ params={
3266
+ "base_period": base_period,
3267
+ "image": image,
3268
+ },
3269
+ )
3234
3270
 
3235
3271
  def enable_model_monitoring(
3236
3272
  self,
mlrun/model.py CHANGED
@@ -766,6 +766,11 @@ class RunMetadata(ModelObj):
766
766
  def iteration(self, iteration):
767
767
  self._iteration = iteration
768
768
 
769
+ def is_workflow_runner(self):
770
+ if not self.labels:
771
+ return False
772
+ return self.labels.get("job-type", "") == "workflow-runner"
773
+
769
774
 
770
775
  class HyperParamStrategies:
771
776
  grid = "grid"
@@ -1218,6 +1223,19 @@ class RunStatus(ModelObj):
1218
1223
  self.reason = reason
1219
1224
  self.notifications = notifications or {}
1220
1225
 
1226
+ def is_failed(self) -> Optional[bool]:
1227
+ """
1228
+ This method returns whether a run has failed.
1229
+ Returns none if state has yet to be defined. callee is responsible for handling None.
1230
+ (e.g wait for state to be defined)
1231
+ """
1232
+ if not self.state:
1233
+ return None
1234
+ return self.state.casefold() in [
1235
+ mlrun.run.RunStatuses.failed.casefold(),
1236
+ mlrun.run.RunStatuses.error.casefold(),
1237
+ ]
1238
+
1221
1239
 
1222
1240
  class RunTemplate(ModelObj):
1223
1241
  """Run template"""
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
  import abc
15
15
  import builtins
16
+ import http
16
17
  import importlib.util as imputil
17
18
  import os
18
19
  import tempfile
@@ -521,7 +522,7 @@ class _PipelineRunner(abc.ABC):
521
522
  @staticmethod
522
523
  def _get_handler(workflow_handler, workflow_spec, project, secrets):
523
524
  if not (workflow_handler and callable(workflow_handler)):
524
- workflow_file = workflow_spec.get_source_file(project.spec.context)
525
+ workflow_file = workflow_spec.get_source_file(project.spec.get_code_path())
525
526
  workflow_handler = create_pipeline(
526
527
  project,
527
528
  workflow_file,
@@ -553,7 +554,7 @@ class _KFPRunner(_PipelineRunner):
553
554
  @classmethod
554
555
  def save(cls, project, workflow_spec: WorkflowSpec, target, artifact_path=None):
555
556
  pipeline_context.set(project, workflow_spec)
556
- workflow_file = workflow_spec.get_source_file(project.spec.context)
557
+ workflow_file = workflow_spec.get_source_file(project.spec.get_code_path())
557
558
  functions = FunctionsDict(project)
558
559
  pipeline = create_pipeline(
559
560
  project,
@@ -882,17 +883,33 @@ class _RemoteRunner(_PipelineRunner):
882
883
  get_workflow_id_timeout=get_workflow_id_timeout,
883
884
  )
884
885
 
886
+ def _get_workflow_id_or_bail():
887
+ try:
888
+ return run_db.get_workflow_id(
889
+ project=project.name,
890
+ name=workflow_response.name,
891
+ run_id=workflow_response.run_id,
892
+ engine=workflow_spec.engine,
893
+ )
894
+ except mlrun.errors.MLRunHTTPStatusError as get_wf_exc:
895
+ # fail fast on specific errors
896
+ if get_wf_exc.error_status_code in [
897
+ http.HTTPStatus.PRECONDITION_FAILED
898
+ ]:
899
+ raise mlrun.errors.MLRunFatalFailureError(
900
+ original_exception=get_wf_exc
901
+ )
902
+
903
+ # raise for a retry (on other errors)
904
+ raise
905
+
885
906
  # Getting workflow id from run:
886
907
  response = retry_until_successful(
887
908
  1,
888
909
  get_workflow_id_timeout,
889
910
  logger,
890
911
  False,
891
- run_db.get_workflow_id,
892
- project=project.name,
893
- name=workflow_response.name,
894
- run_id=workflow_response.run_id,
895
- engine=workflow_spec.engine,
912
+ _get_workflow_id_or_bail,
896
913
  )
897
914
  workflow_id = response.workflow_id
898
915
  # After fetching the workflow_id the workflow executed successfully