mlrun 1.10.0rc11__py3-none-any.whl → 1.10.0rc12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mlrun might be problematic.
Files changed (54)
  1. mlrun/__init__.py +2 -1
  2. mlrun/__main__.py +7 -1
  3. mlrun/artifacts/base.py +9 -3
  4. mlrun/artifacts/dataset.py +2 -1
  5. mlrun/artifacts/llm_prompt.py +1 -1
  6. mlrun/artifacts/model.py +2 -2
  7. mlrun/common/constants.py +1 -0
  8. mlrun/common/runtimes/constants.py +10 -1
  9. mlrun/config.py +19 -2
  10. mlrun/datastore/__init__.py +3 -1
  11. mlrun/datastore/alibaba_oss.py +1 -1
  12. mlrun/datastore/azure_blob.py +1 -1
  13. mlrun/datastore/base.py +6 -31
  14. mlrun/datastore/datastore.py +109 -33
  15. mlrun/datastore/datastore_profile.py +31 -0
  16. mlrun/datastore/dbfs_store.py +1 -1
  17. mlrun/datastore/google_cloud_storage.py +2 -2
  18. mlrun/datastore/model_provider/__init__.py +13 -0
  19. mlrun/datastore/model_provider/model_provider.py +82 -0
  20. mlrun/datastore/model_provider/openai_provider.py +120 -0
  21. mlrun/datastore/remote_client.py +54 -0
  22. mlrun/datastore/s3.py +1 -1
  23. mlrun/datastore/storeytargets.py +1 -1
  24. mlrun/datastore/utils.py +22 -0
  25. mlrun/datastore/v3io.py +1 -1
  26. mlrun/db/base.py +1 -1
  27. mlrun/db/httpdb.py +9 -4
  28. mlrun/db/nopdb.py +1 -1
  29. mlrun/execution.py +23 -7
  30. mlrun/launcher/base.py +23 -13
  31. mlrun/launcher/local.py +3 -1
  32. mlrun/launcher/remote.py +4 -2
  33. mlrun/model.py +65 -0
  34. mlrun/package/packagers_manager.py +2 -0
  35. mlrun/projects/operations.py +8 -1
  36. mlrun/projects/project.py +23 -5
  37. mlrun/run.py +17 -0
  38. mlrun/runtimes/__init__.py +6 -0
  39. mlrun/runtimes/base.py +24 -6
  40. mlrun/runtimes/daskjob.py +1 -0
  41. mlrun/runtimes/databricks_job/databricks_runtime.py +1 -0
  42. mlrun/runtimes/local.py +1 -6
  43. mlrun/serving/server.py +0 -2
  44. mlrun/serving/states.py +30 -5
  45. mlrun/serving/system_steps.py +22 -28
  46. mlrun/utils/helpers.py +13 -2
  47. mlrun/utils/notifications/notification_pusher.py +15 -0
  48. mlrun/utils/version/version.json +2 -2
  49. {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc12.dist-info}/METADATA +2 -2
  50. {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc12.dist-info}/RECORD +54 -50
  51. {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc12.dist-info}/WHEEL +0 -0
  52. {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc12.dist-info}/entry_points.txt +0 -0
  53. {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc12.dist-info}/licenses/LICENSE +0 -0
  54. {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc12.dist-info}/top_level.txt +0 -0
mlrun/datastore/model_provider/model_provider.py ADDED
@@ -0,0 +1,82 @@
+# Copyright 2025 Iguazio
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections.abc import Awaitable
+from typing import Callable, Optional, TypeVar
+
+import mlrun.errors
+from mlrun.datastore.remote_client import (
+    BaseRemoteClient,
+)
+
+T = TypeVar("T")
+
+
+class ModelProvider(BaseRemoteClient):
+    support_async = False
+
+    def __init__(
+        self,
+        parent,
+        kind,
+        name,
+        endpoint="",
+        secrets: Optional[dict] = None,
+        default_invoke_kwargs: Optional[dict] = None,
+    ):
+        super().__init__(
+            parent=parent, name=name, kind=kind, endpoint=endpoint, secrets=secrets
+        )
+        self.default_invoke_kwargs = default_invoke_kwargs or {}
+        self._client = None
+        self._default_operation = None
+        self._async_client = None
+        self._default_async_operation = None
+
+    def load_client(self) -> None:
+        raise NotImplementedError("load_client method is not implemented")
+
+    def invoke(self, prompt: Optional[str] = None, **invoke_kwargs) -> str:
+        raise NotImplementedError("invoke method is not implemented")
+
+    def customized_invoke(
+        self, operation: Optional[Callable[..., T]] = None, **invoke_kwargs
+    ) -> Optional[T]:
+        raise NotImplementedError("customized_invoke method is not implemented")
+
+    @property
+    def client(self):
+        return self._client
+
+    @property
+    def model(self):
+        return None
+
+    def get_invoke_kwargs(self, invoke_kwargs):
+        kwargs = self.default_invoke_kwargs.copy()
+        kwargs.update(invoke_kwargs)
+        return kwargs
+
+    @property
+    def async_client(self):
+        if not self.support_async:
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                f"{self.__class__.__name__} does not support async operations"
+            )
+        return self._async_client
+
+    async def async_customized_invoke(self, **kwargs):
+        raise NotImplementedError("async_customized_invoke is not implemented")
+
+    async def async_invoke(self, prompt: str, **invoke_kwargs) -> Awaitable[str]:
+        raise NotImplementedError("async_invoke is not implemented")
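
The ModelProvider base above is an abstract interface; a concrete provider overrides load_client and invoke (and the async variants when support_async is set). A minimal sketch of a custom subclass follows; EchoProvider is hypothetical and not part of the package, it only illustrates the surface a subclass implements.

```python
from typing import Optional

from mlrun.datastore.model_provider.model_provider import ModelProvider


class EchoProvider(ModelProvider):
    """Hypothetical provider that echoes prompts back, for illustration only."""

    def load_client(self) -> None:
        # a real provider would construct its SDK client here and keep it in self._client
        self._client = object()

    def invoke(self, prompt: Optional[str] = None, **invoke_kwargs) -> str:
        # merge per-call kwargs over default_invoke_kwargs, as the base class expects
        kwargs = self.get_invoke_kwargs(invoke_kwargs)
        return f"echo: {prompt} (kwargs={kwargs})"


provider = EchoProvider(parent=None, kind="echo", name="echo-provider")
provider.load_client()
print(provider.invoke(prompt="hello"))  # -> "echo: hello (kwargs={})"
```
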
mlrun/datastore/model_provider/openai_provider.py ADDED
@@ -0,0 +1,120 @@
+# Copyright 2025 Iguazio
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, TypeVar
+
+import mlrun
+from mlrun.datastore.model_provider.model_provider import ModelProvider
+
+T = TypeVar("T")
+
+
+class OpenAIProvider(ModelProvider):
+    def __init__(
+        self,
+        parent,
+        schema,
+        name,
+        endpoint="",
+        secrets: Optional[dict] = None,
+        default_invoke_kwargs: Optional[dict] = None,
+    ):
+        endpoint = endpoint or mlrun.mlconf.model_providers.openai_default_model
+        if schema != "openai":
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                "OpenAIProvider supports only 'openai' as the provider kind."
+            )
+        super().__init__(
+            parent=parent,
+            kind=schema,
+            name=name,
+            endpoint=endpoint,
+            secrets=secrets,
+            default_invoke_kwargs=default_invoke_kwargs,
+        )
+        self.options = self.get_client_options()
+        self.load_client()
+
+    @classmethod
+    def parse_endpoint_and_path(cls, endpoint, subpath) -> (str, str):
+        if endpoint and subpath:
+            endpoint = endpoint + subpath
+        # in openai there is no usage of subpath variable. if the model contains "/", it is part of the model name.
+        subpath = ""
+        return endpoint, subpath
+
+    @property
+    def model(self):
+        return self.endpoint
+
+    def load_client(self) -> None:
+        try:
+            from openai import OpenAI  # noqa
+
+            self._client = OpenAI(**self.options)
+            self._default_operation = self.client.chat.completions.create
+        except ImportError as exc:
+            raise ImportError("openai package is not installed") from exc
+
+    def get_client_options(self):
+        res = dict(
+            api_key=self._get_secret_or_env("OPENAI_API_KEY"),
+            organization=self._get_secret_or_env("OPENAI_ORG_ID"),
+            project=self._get_secret_or_env("OPENAI_PROJECT_ID"),
+            base_url=self._get_secret_or_env("OPENAI_BASE_URL"),
+            timeout=self._get_secret_or_env("OPENAI_TIMEOUT"),
+            max_retries=self._get_secret_or_env("OPENAI_MAX_RETRIES"),
+        )
+        return self._sanitize_options(res)
+
+    def customized_invoke(
+        self, operation: Optional[Callable[..., T]] = None, **invoke_kwargs
+    ) -> Optional[T]:
+        invoke_kwargs = self.get_invoke_kwargs(invoke_kwargs)
+        if operation:
+            return operation(**invoke_kwargs, model=self.model)
+        else:
+            return self._default_operation(**invoke_kwargs, model=self.model)
+
+    def _get_messages_parameter(
+        self, prompt: Optional[str] = None, **invoke_kwargs
+    ) -> (str, dict):
+        invoke_kwargs = self.get_invoke_kwargs(invoke_kwargs)
+        messages = invoke_kwargs.get("messages")
+        if messages:
+            if prompt:
+                raise mlrun.errors.MLRunInvalidArgumentError(
+                    "can not provide 'messages' and 'prompt' to invoke"
+                )
+        elif prompt:
+            messages = [
+                {
+                    "role": "user",
+                    "content": prompt,
+                },
+            ]
+        else:
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                "must provide 'messages' or 'prompt' to invoke"
+            )
+        return messages, invoke_kwargs
+
+    def invoke(self, prompt: Optional[str] = None, **invoke_kwargs) -> str:
+        messages, invoke_kwargs = self._get_messages_parameter(
+            prompt=prompt, **invoke_kwargs
+        )
+        response = self._default_operation(
+            model=self.endpoint, messages=messages, **invoke_kwargs
+        )
+        return response.choices[0].message.content
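
A hedged usage sketch of the new provider, assuming the openai package is installed and an API key is supplied through the secrets dict. The parent argument is normally provided by MLRun's datastore machinery, so a trivial stub stands in for it here; the model name and the "sk-..." key are placeholders.

```python
from mlrun.datastore.model_provider.openai_provider import OpenAIProvider


class _StubParent:
    """Stand-in for the datastore object that normally owns the provider."""

    def secret(self, key):
        return None  # no project secrets in this sketch


provider = OpenAIProvider(
    parent=_StubParent(),
    schema="openai",          # anything else raises MLRunInvalidArgumentError
    name="my-openai",
    endpoint="gpt-4o-mini",   # used as the model name; empty falls back to the mlconf default
    secrets={"OPENAI_API_KEY": "sk-..."},
    default_invoke_kwargs={"temperature": 0},
)

# a prompt is wrapped into a single user message; passing both prompt and messages raises
print(provider.invoke(prompt="Summarize what MLRun does in one sentence."))

# customized_invoke forwards kwargs (plus model=) to chat.completions.create by default
response = provider.customized_invoke(messages=[{"role": "user", "content": "hi"}])
print(response.choices[0].message.content)
```
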
mlrun/datastore/remote_client.py ADDED
@@ -0,0 +1,54 @@
+# Copyright 2025 Iguazio
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import mlrun
+
+
+class BaseRemoteClient:
+    def __init__(self, parent, kind, name, endpoint="", secrets: Optional[dict] = None):
+        self._parent = parent
+        self.kind = kind
+        self.name = name
+        self.endpoint = endpoint
+        self._secrets = secrets or {}
+        self.secret_pfx = ""
+
+    def _get_secret_or_env(self, key, default=None):
+        # Project-secrets are mounted as env variables whose name can be retrieved from SecretsStore
+        return mlrun.get_secret_or_env(
+            key, secret_provider=self._get_secret, default=default
+        )
+
+    def _get_parent_secret(self, key):
+        return self._parent.secret(self.secret_pfx + key)
+
+    def _get_secret(self, key: str, default=None):
+        return self._secrets.get(key, default) or self._get_parent_secret(key)
+
+    @property
+    def url(self):
+        return f"{self.kind}://{self.endpoint}"
+
+    @staticmethod
+    def _sanitize_options(options):
+        if not options:
+            return {}
+        options = {k: v for k, v in options.items() if v is not None and v != ""}
+        return options
+
+    @classmethod
+    def parse_endpoint_and_path(cls, endpoint, subpath) -> (str, str):
+        return endpoint, subpath
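
The _sanitize_options helper above drops unset options before they reach a client SDK; the datastore hunks below apply the same rename from _sanitize_storage_options to _sanitize_options. A quick illustration of the behavior of the method shown here:

```python
from mlrun.datastore.remote_client import BaseRemoteClient

options = {"api_key": "abc", "timeout": None, "profile": "", "max_retries": 3}
print(BaseRemoteClient._sanitize_options(options))
# -> {'api_key': 'abc', 'max_retries': 3}  (None and empty-string values are dropped)
```
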
mlrun/datastore/s3.py CHANGED
@@ -186,7 +186,7 @@ class S3Store(DataStore):
         if profile:
             storage_options["profile"] = profile
 
-        return self._sanitize_storage_options(storage_options)
+        return self._sanitize_options(storage_options)
 
     @property
     def spark_url(self):
mlrun/datastore/storeytargets.py CHANGED
@@ -46,7 +46,7 @@ def get_url_and_storage_options(path, external_storage_options=None):
         storage_options = merge(external_storage_options, storage_options)
     else:
         storage_options = storage_options or external_storage_options
-    return url, DataStore._sanitize_storage_options(storage_options)
+    return url, DataStore._sanitize_options(storage_options)
 
 
 class TDEngineStoreyTarget(storey.TDEngineTarget):
mlrun/datastore/utils.py CHANGED
@@ -311,3 +311,25 @@ class KafkaParameters:
         valid_keys.update(ref_dict.keys())
         # Return a new dictionary with only valid keys
         return {k: v for k, v in input_dict.items() if k in valid_keys}
+
+
+def parse_url(url):
+    if url and url.startswith("v3io://") and not url.startswith("v3io:///"):
+        url = url.replace("v3io://", "v3io:///", 1)
+    parsed_url = urlparse(url)
+    schema = parsed_url.scheme.lower()
+    endpoint = parsed_url.hostname
+    if endpoint:
+        # HACK - urlparse returns the hostname after in lower case - we want the original case:
+        # the hostname is a substring of the netloc, in which it's the original case, so we find the indexes of the
+        # hostname in the netloc and take it from there
+        lower_hostname = parsed_url.hostname
+        netloc = str(parsed_url.netloc)
+        lower_netloc = netloc.lower()
+        hostname_index_in_netloc = lower_netloc.index(str(lower_hostname))
+        endpoint = netloc[
+            hostname_index_in_netloc : hostname_index_in_netloc + len(lower_hostname)
+        ]
+        if parsed_url.port:
+            endpoint += f":{parsed_url.port}"
+    return schema, endpoint, parsed_url
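
A short illustration of the new parse_url helper, worked through by hand from the code above (the URLs are made up, and the function is assumed to be exposed at module level as the hunk suggests): the scheme is lower-cased, the hostname keeps its original casing, and v3io:// URLs are normalized to the triple-slash form, which leaves them without an endpoint.

```python
from mlrun.datastore.utils import parse_url

schema, endpoint, parsed = parse_url("s3://MyBucket/path/to/key")
print(schema, endpoint)   # -> "s3 MyBucket"  (case of the host is preserved)

schema, endpoint, parsed = parse_url("ds://profile-name@MyHost:9000/some/path")
print(schema, endpoint)   # -> "ds MyHost:9000"  (port is re-appended)

schema, endpoint, parsed = parse_url("v3io://container/path")
print(schema, endpoint)   # -> "v3io None"  (rewritten to v3io:///, so no hostname)
```
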
mlrun/datastore/v3io.py CHANGED
@@ -97,7 +97,7 @@ class V3ioStore(DataStore):
             v3io_access_key=self._get_secret_or_env("V3IO_ACCESS_KEY"),
             v3io_api=mlrun.mlconf.v3io_api,
         )
-        return self._sanitize_storage_options(res)
+        return self._sanitize_options(res)
 
     def _upload(
         self,
mlrun/db/base.py CHANGED
@@ -44,7 +44,7 @@ class RunDBInterface(ABC):
         pass
 
     @abstractmethod
-    def get_log(self, uid, project="", offset=0, size=0):
+    def get_log(self, uid, project="", offset=0, size=0, attempt=None):
         pass
 
     @abstractmethod
mlrun/db/httpdb.py CHANGED
@@ -608,7 +608,7 @@ class HTTPRunDB(RunDBInterface):
         error = f"store log {project}/{uid}"
         self.api_call("POST", path, error, params, body)
 
-    def get_log(self, uid, project="", offset=0, size=None):
+    def get_log(self, uid, project="", offset=0, size=None, attempt=None):
         """Retrieve 1 MB data of log.
 
         :param uid: Log unique ID
@@ -616,6 +616,8 @@ class HTTPRunDB(RunDBInterface):
         :param offset: Retrieve partial log, get up to ``size`` bytes starting at offset ``offset``
             from beginning of log (must be >= 0)
         :param size: If set to ``-1`` will retrieve and print all data to end of the log by chunks of 1MB each.
+        :param attempt: For retriable runs, the attempt number to retrieve the log for.
+            1 is the initial attempt.
         :returns: The following objects:
 
             - state - The state of the runtime object which generates this log, if it exists. In case no known state
@@ -636,6 +638,8 @@ class HTTPRunDB(RunDBInterface):
             return state, offset
 
         params = {"offset": offset, "size": size}
+        if attempt:
+            params["attempt"] = attempt
         path = self._path_of("logs", project, uid)
         error = f"get log {project}/{uid}"
         resp = self.api_call("GET", path, error, params=params)
@@ -658,7 +662,7 @@ class HTTPRunDB(RunDBInterface):
         resp = self.api_call("GET", path, error)
         return resp.json()["size"]
 
-    def watch_log(self, uid, project="", watch=True, offset=0):
+    def watch_log(self, uid, project="", watch=True, offset=0, attempt=None):
         """Retrieve logs of a running process by chunks of 1MB, and watch the progress of the execution until it
         completes. This method will print out the logs and continue to periodically poll for, and print,
         new logs as long as the state of the runtime which generates this log is either ``pending`` or ``running``.
@@ -668,10 +672,11 @@ class HTTPRunDB(RunDBInterface):
         :param watch: If set to ``True`` will continue tracking the log as described above. Otherwise this function
             is practically equivalent to the :py:func:`~get_log` function.
         :param offset: Minimal offset in the log to watch.
+        :param attempt: For retriable runs, the attempt number to retrieve the log for. 1 is the initial attempt.
         :returns: The final state of the log being watched and the final offset.
         """
 
-        state, text = self.get_log(uid, project, offset=offset)
+        state, text = self.get_log(uid, project, offset=offset, attempt=attempt)
         if text:
             print(text.decode(errors=mlrun.mlconf.httpdb.logs.decode.errors))
         nil_resp = 0
@@ -687,7 +692,7 @@ class HTTPRunDB(RunDBInterface):
                         mlrun.mlconf.httpdb.logs.pull_logs_backoff_no_logs_default_interval
                     )
                 )
-            state, text = self.get_log(uid, project, offset=offset)
+            state, text = self.get_log(uid, project, offset=offset, attempt=attempt)
             if text:
                 nil_resp = 0
                 print(
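
The attempt parameter threaded through get_log and watch_log above lets clients pull the log of a specific attempt of a retriable run. A hedged sketch against the HTTP run DB; the uid and project values are placeholders.

```python
import mlrun

db = mlrun.get_run_db()

# attempt=1 is the initial attempt; attempt=2 is the first retry
state, log_bytes = db.get_log(uid="<run-uid>", project="my-project", attempt=2)
print(state, log_bytes.decode(errors="replace"))

# watch_log forwards the same attempt to every get_log poll
db.watch_log(uid="<run-uid>", project="my-project", watch=False, attempt=2)
```
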
mlrun/db/nopdb.py CHANGED
@@ -63,7 +63,7 @@ class NopDB(RunDBInterface):
     def store_log(self, uid, project="", body=None, append=False):
         pass
 
-    def get_log(self, uid, project="", offset=0, size=0):
+    def get_log(self, uid, project="", offset=0, size=0, attempt=None):
         pass
 
     def store_run(self, struct, uid, project="", iter=0):
mlrun/execution.py CHANGED
@@ -26,6 +26,7 @@ from dateutil import parser
 import mlrun
 import mlrun.common.constants as mlrun_constants
 import mlrun.common.formatters
+import mlrun.common.runtimes.constants
 from mlrun.artifacts import (
     Artifact,
     DatasetArtifact,
@@ -91,6 +92,8 @@ class MLClientCtx:
         self._autocommit = autocommit
         self._notifications = []
         self._state_thresholds = {}
+        self._retry_spec = {}
+        self._retry_count = None
 
         self._labels = {}
         self._annotations = {}
@@ -432,6 +435,7 @@ class MLClientCtx:
         self._tolerations = spec.get("tolerations", self._tolerations)
         self._affinity = spec.get("affinity", self._affinity)
         self._reset_on_run = spec.get("reset_on_run", self._reset_on_run)
+        self._retry_spec = spec.get("retry", self._retry_spec)
         self._init_dbs(rundb)
 
 
@@ -450,10 +454,11 @@ class MLClientCtx:
         if start:
             start = parser.parse(start) if isinstance(start, str) else start
             self._start_time = start
-        self._state = "running"
+        self._state = mlrun.common.runtimes.constants.RunStates.running
 
         status = attrs.get("status")
-        if include_status and status:
+        retry_configured = self._retry_spec and self._retry_spec.get("count")
+        if (include_status or retry_configured) and status:
             self._results = status.get("results", self._results)
             for artifact in status.get("artifacts", []):
                 artifact_obj = dict_to_artifact(artifact)
@@ -462,7 +467,10 @@ class MLClientCtx:
                 )
             for key, uri in status.get("artifact_uris", {}).items():
                 self._artifacts_manager.artifact_uris[key] = uri
-            self._state = status.get("state", self._state)
+            self._retry_count = status.get("retry_count", self._retry_count)
+            # if run is a retry, the state needs to move to running
+            if include_status:
+                self._state = status.get("state", self._state)
 
         # No need to store the run for every worker
         if store_run and self.is_logging_worker():
@@ -1107,13 +1115,13 @@ class MLClientCtx:
         :param completed: Mark run as completed
         """
         # Changing state to completed is allowed only when the execution is in running state
-        if self._state != "running":
+        if self._state != mlrun.common.runtimes.constants.RunStates.running:
            completed = False
 
         if message:
             self._annotations["message"] = message
         if completed:
-            self._state = "completed"
+            self._state = mlrun.common.runtimes.constants.RunStates.completed
 
         if self._parent:
             self._parent.update_child_iterations()
@@ -1147,9 +1155,15 @@ class MLClientCtx:
         updates = {"status.last_update": now_date().isoformat()}
 
         if error is not None:
-            self._state = "error"
+            state = mlrun.common.runtimes.constants.RunStates.error
+            max_retries = self._retry_spec.get("count", 0)
+            self._retry_count = self._retry_count or 0
+            if max_retries and self._retry_count < max_retries:
+                state = mlrun.common.runtimes.constants.RunStates.pending_retry
+
+            self._state = state
             self._error = str(error)
-            updates["status.state"] = "error"
+            updates["status.state"] = state
             updates["status.error"] = error
         elif (
             execution_state
@@ -1241,11 +1255,13 @@ class MLClientCtx:
                 "node_selector": self._node_selector,
                 "tolerations": self._tolerations,
                 "affinity": self._affinity,
+                "retry": self._retry_spec,
             },
             "status": {
                 "results": self._results,
                 "start_time": to_date_str(self._start_time),
                 "last_update": to_date_str(self._last_update),
+                "retry_count": self._retry_count,
             },
         }
 
mlrun/launcher/base.py CHANGED
@@ -18,6 +18,8 @@ import os
 import uuid
 from typing import Any, Callable, Optional, Union
 
+import mlrun.common.constants
+import mlrun.common.runtimes.constants
 import mlrun.common.schemas
 import mlrun.config
 import mlrun.errors
@@ -72,6 +74,7 @@ class BaseLauncher(abc.ABC):
         notifications: Optional[list[mlrun.model.Notification]] = None,
         returns: Optional[list[Union[str, dict[str, str]]]] = None,
         state_thresholds: Optional[dict[str, int]] = None,
+        retry: Optional[Union[mlrun.model.Retry, dict]] = None,
     ) -> "mlrun.run.RunObject":
         """run the function from the server/client[local/remote]"""
         pass
@@ -133,7 +136,7 @@ class BaseLauncher(abc.ABC):
         """Check if the runtime requires to build the image and updates the spec accordingly"""
         pass
 
-    def _validate_runtime(
+    def _validate_run(
         self,
         runtime: "mlrun.runtimes.BaseRuntime",
         run: "mlrun.run.RunObject",
@@ -194,7 +197,7 @@ class BaseLauncher(abc.ABC):
             )
 
     @classmethod
-    def _validate_run_single_param(cls, param_name, param_value):
+    def _validate_run_single_param(cls, param_name: str, param_value: int):
         # verify that integer parameters don't exceed a int64
         if isinstance(param_value, int) and abs(param_value) >= 2**63:
             raise mlrun.errors.MLRunInvalidArgumentError(
@@ -203,8 +206,6 @@ class BaseLauncher(abc.ABC):
             )
 
    @staticmethod
    def _create_run_object(task):
-        valid_task_types = (dict, mlrun.run.RunTemplate, mlrun.run.RunObject)
-
         if not task:
             # if task passed generate default RunObject
             return mlrun.run.RunObject.from_dict(task)
@@ -215,18 +216,18 @@ class BaseLauncher(abc.ABC):
         if isinstance(task, str):
             task = ast.literal_eval(task)
 
-        if not isinstance(task, valid_task_types):
-            raise mlrun.errors.MLRunInvalidArgumentError(
-                f"Task is not a valid object, type={type(task)}, expected types={valid_task_types}"
-            )
-
+        valid_task_types = (dict, mlrun.run.RunTemplate, mlrun.run.RunObject)
+        if isinstance(task, mlrun.run.RunObject):
+            # if task is already a RunObject, we can return it as is
+            return task
         if isinstance(task, mlrun.run.RunTemplate):
             return mlrun.run.RunObject.from_template(task)
         elif isinstance(task, dict):
             return mlrun.run.RunObject.from_dict(task)
 
-        # task is already a RunObject
-        return task
+        raise mlrun.errors.MLRunInvalidArgumentError(
+            f"Task is not a valid object, type={type(task)}, expected types={valid_task_types}"
+        )
 
     @staticmethod
     def _enrich_run(
@@ -246,6 +247,7 @@ class BaseLauncher(abc.ABC):
         workdir=None,
         notifications: Optional[list[mlrun.model.Notification]] = None,
         state_thresholds: Optional[dict[str, int]] = None,
+        retry: Optional[Union[mlrun.model.Retry, dict]] = None,
     ):
         run.spec.handler = (
             handler or run.spec.handler or runtime.spec.default_handler or ""
@@ -364,6 +366,7 @@ class BaseLauncher(abc.ABC):
                 | state_thresholds
             )
         run.spec.state_thresholds = state_thresholds or run.spec.state_thresholds
+        run.spec.retry = retry or run.spec.retry
         return run
 
     @staticmethod
@@ -410,7 +413,7 @@ class BaseLauncher(abc.ABC):
                 )
             if (
                 run.status.state
-                in mlrun.common.runtimes.constants.RunStates.error_and_abortion_states()
+                in mlrun.common.runtimes.constants.RunStates.error_states()
             ):
                 if runtime._is_remote and not runtime.is_child:
                     logger.error(
@@ -418,7 +421,14 @@ class BaseLauncher(abc.ABC):
                         state=run.status.state,
                         status=run.status.to_dict(),
                     )
-                raise mlrun.runtimes.utils.RunError(run.error)
+
+                error = run.error
+                if (
+                    run.status.state
+                    == mlrun.common.runtimes.constants.RunStates.pending_retry
+                ):
+                    error = f"Run is pending retry, error: {run.error}"
+                raise mlrun.runtimes.utils.RunError(error)
             return run
 
         return None
mlrun/launcher/local.py CHANGED
@@ -72,6 +72,7 @@ class ClientLocalLauncher(launcher.ClientBaseLauncher):
         returns: Optional[list[Union[str, dict[str, str]]]] = None,
         state_thresholds: Optional[dict[str, int]] = None,
         reset_on_run: Optional[bool] = None,
+        retry: Optional[Union[mlrun.model.Retry, dict]] = None,
     ) -> "mlrun.run.RunObject":
         # do not allow local function to be scheduled
         if schedule is not None:
@@ -122,8 +123,9 @@ class ClientLocalLauncher(launcher.ClientBaseLauncher):
             workdir=workdir,
             notifications=notifications,
             state_thresholds=state_thresholds,
+            retry=retry,
         )
-        self._validate_runtime(runtime, run)
+        self._validate_run(runtime, run)
         result = self._execute(
             runtime=runtime,
             run=run,
mlrun/launcher/remote.py CHANGED
@@ -61,6 +61,7 @@ class ClientRemoteLauncher(launcher.ClientBaseLauncher):
         returns: Optional[list[Union[str, dict[str, str]]]] = None,
         state_thresholds: Optional[dict[str, int]] = None,
         reset_on_run: Optional[bool] = None,
+        retry: Optional[Union[mlrun.model.Retry, dict]] = None,
     ) -> "mlrun.run.RunObject":
         self.enrich_runtime(runtime, project)
         run = self._create_run_object(task)
@@ -82,8 +83,9 @@ class ClientRemoteLauncher(launcher.ClientBaseLauncher):
             workdir=workdir,
             notifications=notifications,
             state_thresholds=state_thresholds,
+            retry=retry,
         )
-        self._validate_runtime(runtime, run)
+        self._validate_run(runtime, run)
 
         if not runtime.is_deployed():
             if runtime.spec.build.auto_build or auto_build:
@@ -190,7 +192,7 @@ class ClientRemoteLauncher(launcher.ClientBaseLauncher):
         return self._wrap_run_result(runtime, resp, run, schedule=schedule)
 
     @classmethod
-    def _validate_run_single_param(cls, param_name, param_value):
+    def _validate_run_single_param(cls, param_name: str, param_value: int):
         if isinstance(param_value, pd.DataFrame):
             raise mlrun.errors.MLRunInvalidArgumentTypeError(
                 f"Parameter '{param_name}' has an unsupported value of type"