mlrun 1.10.0rc11__py3-none-any.whl → 1.10.0rc12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mlrun might be problematic. Click here for more details.
- mlrun/__init__.py +2 -1
- mlrun/__main__.py +7 -1
- mlrun/artifacts/base.py +9 -3
- mlrun/artifacts/dataset.py +2 -1
- mlrun/artifacts/llm_prompt.py +1 -1
- mlrun/artifacts/model.py +2 -2
- mlrun/common/constants.py +1 -0
- mlrun/common/runtimes/constants.py +10 -1
- mlrun/config.py +19 -2
- mlrun/datastore/__init__.py +3 -1
- mlrun/datastore/alibaba_oss.py +1 -1
- mlrun/datastore/azure_blob.py +1 -1
- mlrun/datastore/base.py +6 -31
- mlrun/datastore/datastore.py +109 -33
- mlrun/datastore/datastore_profile.py +31 -0
- mlrun/datastore/dbfs_store.py +1 -1
- mlrun/datastore/google_cloud_storage.py +2 -2
- mlrun/datastore/model_provider/__init__.py +13 -0
- mlrun/datastore/model_provider/model_provider.py +82 -0
- mlrun/datastore/model_provider/openai_provider.py +120 -0
- mlrun/datastore/remote_client.py +54 -0
- mlrun/datastore/s3.py +1 -1
- mlrun/datastore/storeytargets.py +1 -1
- mlrun/datastore/utils.py +22 -0
- mlrun/datastore/v3io.py +1 -1
- mlrun/db/base.py +1 -1
- mlrun/db/httpdb.py +9 -4
- mlrun/db/nopdb.py +1 -1
- mlrun/execution.py +23 -7
- mlrun/launcher/base.py +23 -13
- mlrun/launcher/local.py +3 -1
- mlrun/launcher/remote.py +4 -2
- mlrun/model.py +65 -0
- mlrun/package/packagers_manager.py +2 -0
- mlrun/projects/operations.py +8 -1
- mlrun/projects/project.py +23 -5
- mlrun/run.py +17 -0
- mlrun/runtimes/__init__.py +6 -0
- mlrun/runtimes/base.py +24 -6
- mlrun/runtimes/daskjob.py +1 -0
- mlrun/runtimes/databricks_job/databricks_runtime.py +1 -0
- mlrun/runtimes/local.py +1 -6
- mlrun/serving/server.py +0 -2
- mlrun/serving/states.py +30 -5
- mlrun/serving/system_steps.py +22 -28
- mlrun/utils/helpers.py +13 -2
- mlrun/utils/notifications/notification_pusher.py +15 -0
- mlrun/utils/version/version.json +2 -2
- {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc12.dist-info}/METADATA +2 -2
- {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc12.dist-info}/RECORD +54 -50
- {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc12.dist-info}/WHEEL +0 -0
- {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc12.dist-info}/entry_points.txt +0 -0
- {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc12.dist-info}/licenses/LICENSE +0 -0
- {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc12.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
# Copyright 2025 Iguazio
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
from collections.abc import Awaitable
|
|
15
|
+
from typing import Callable, Optional, TypeVar
|
|
16
|
+
|
|
17
|
+
import mlrun.errors
|
|
18
|
+
from mlrun.datastore.remote_client import (
|
|
19
|
+
BaseRemoteClient,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
# Generic result type of the operation callable accepted by ``customized_invoke``.
T = TypeVar("T")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ModelProvider(BaseRemoteClient):
    """Base class for clients that invoke a remote model.

    Concrete providers override ``load_client``, ``invoke`` and
    ``customized_invoke`` (and, when ``support_async`` is True, the async
    variants). The base implementations only raise ``NotImplementedError``.
    """

    # Subclasses that provide an async client set this to True; otherwise
    # ``async_client`` raises.
    support_async = False

    def __init__(
        self,
        parent,
        kind,
        name,
        endpoint="",
        secrets: Optional[dict] = None,
        default_invoke_kwargs: Optional[dict] = None,
    ):
        """
        :param parent:                Forwarded to :class:`BaseRemoteClient`.
        :param kind:                  Provider kind (forwarded to the base class).
        :param name:                  Provider name.
        :param endpoint:              Provider endpoint (forwarded to the base class).
        :param secrets:               Explicit secrets dict (forwarded to the base class).
        :param default_invoke_kwargs: Keyword arguments merged into every invocation;
                                      per-call kwargs take precedence.
        """
        super().__init__(
            parent=parent, name=name, kind=kind, endpoint=endpoint, secrets=secrets
        )
        self.default_invoke_kwargs = default_invoke_kwargs or {}
        # Populated by subclasses in ``load_client``.
        self._client = None
        self._default_operation = None
        self._async_client = None
        self._default_async_operation = None

    def load_client(self) -> None:
        """Create the underlying SDK client. Must be implemented by subclasses."""
        raise NotImplementedError("load_client method is not implemented")

    def invoke(self, prompt: Optional[str] = None, **invoke_kwargs) -> str:
        """Invoke the model and return its text output. Must be implemented by subclasses."""
        raise NotImplementedError("invoke method is not implemented")

    def customized_invoke(
        self, operation: Optional[Callable[..., T]] = None, **invoke_kwargs
    ) -> Optional[T]:
        """Invoke ``operation`` (or the default operation) with the merged kwargs.

        Must be implemented by subclasses.
        """
        raise NotImplementedError("customized_invoke method is not implemented")

    @property
    def client(self):
        """The synchronous SDK client (``None`` until ``load_client`` sets it)."""
        return self._client

    @property
    def model(self):
        """Model identifier; ``None`` in the base class."""
        return None

    def get_invoke_kwargs(self, invoke_kwargs):
        """Return ``default_invoke_kwargs`` overlaid with ``invoke_kwargs``.

        :param invoke_kwargs: Per-call keyword arguments (``None`` is treated as empty).
        :returns: A new dict; the defaults are not mutated.
        """
        kwargs = self.default_invoke_kwargs.copy()
        kwargs.update(invoke_kwargs or {})
        return kwargs

    @property
    def async_client(self):
        """The async SDK client.

        :raises MLRunInvalidArgumentError: If this provider does not support async operations.
        """
        if not self.support_async:
            raise mlrun.errors.MLRunInvalidArgumentError(
                f"{self.__class__.__name__} does not support async operations"
            )
        return self._async_client

    async def async_customized_invoke(self, **kwargs):
        """Async counterpart of ``customized_invoke``. Must be implemented by subclasses."""
        raise NotImplementedError("async_customized_invoke is not implemented")

    # An ``async def`` is annotated with the awaited result type; calling it
    # already yields an awaitable, so the result is ``str``, not ``Awaitable[str]``.
    async def async_invoke(self, prompt: str, **invoke_kwargs) -> str:
        """Async counterpart of ``invoke``. Must be implemented by subclasses."""
        raise NotImplementedError("async_invoke is not implemented")
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
# Copyright 2025 Iguazio
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Callable, Optional, TypeVar
|
|
16
|
+
|
|
17
|
+
import mlrun
|
|
18
|
+
from mlrun.datastore.model_provider.model_provider import ModelProvider
|
|
19
|
+
|
|
20
|
+
# Generic result type of the operation callable accepted by ``customized_invoke``.
T = TypeVar("T")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class OpenAIProvider(ModelProvider):
    """Model provider that calls OpenAI through the ``openai`` SDK.

    The provider ``endpoint`` holds the OpenAI model name. When it is empty it
    falls back to ``mlrun.mlconf.model_providers.openai_default_model``. Client
    options are read with ``_get_secret_or_env`` (see ``get_client_options``).
    """

    def __init__(
        self,
        parent,
        schema,
        name,
        endpoint="",
        secrets: Optional[dict] = None,
        default_invoke_kwargs: Optional[dict] = None,
    ):
        """Validate the schema, resolve the model name and create the client.

        :param parent:                Forwarded to :class:`ModelProvider`.
        :param schema:                Provider kind; must be ``"openai"``.
        :param name:                  Provider name.
        :param endpoint:              OpenAI model name (defaults to the configured default model).
        :param secrets:               Explicit secrets, forwarded to :class:`ModelProvider`.
        :param default_invoke_kwargs: Keyword arguments applied to every invocation.

        :raises MLRunInvalidArgumentError: If ``schema`` is not ``"openai"`` or a numeric
                                           client option is not a valid number.
        :raises ImportError:               If the ``openai`` package is not installed.
        """
        endpoint = endpoint or mlrun.mlconf.model_providers.openai_default_model
        if schema != "openai":
            raise mlrun.errors.MLRunInvalidArgumentError(
                "OpenAIProvider supports only 'openai' as the provider kind."
            )
        super().__init__(
            parent=parent,
            kind=schema,
            name=name,
            endpoint=endpoint,
            secrets=secrets,
            default_invoke_kwargs=default_invoke_kwargs,
        )
        self.options = self.get_client_options()
        self.load_client()

    @classmethod
    def parse_endpoint_and_path(cls, endpoint, subpath) -> tuple[str, str]:
        """Fold ``subpath`` into ``endpoint`` and return an empty subpath.

        OpenAI has no use for a subpath: if the model name contains "/", it is
        part of the model name itself.
        """
        if endpoint and subpath:
            endpoint = endpoint + subpath
        return endpoint, ""

    @property
    def model(self):
        """The OpenAI model name (stored as the provider endpoint)."""
        return self.endpoint

    def load_client(self) -> None:
        """Create the ``openai.OpenAI`` client and bind chat completions as the default operation.

        :raises ImportError: If the ``openai`` package is not installed.
        """
        # Keep the try narrow so ImportErrors raised by the client itself are not
        # misreported as a missing package.
        try:
            from openai import OpenAI  # noqa
        except ImportError as exc:
            raise ImportError("openai package is not installed") from exc

        self._client = OpenAI(**self.options)
        self._default_operation = self.client.chat.completions.create

    def get_client_options(self):
        """Collect keyword arguments for the ``openai.OpenAI`` constructor.

        Values come from ``_get_secret_or_env`` for ``OPENAI_API_KEY``,
        ``OPENAI_ORG_ID``, ``OPENAI_PROJECT_ID``, ``OPENAI_BASE_URL``,
        ``OPENAI_TIMEOUT`` and ``OPENAI_MAX_RETRIES``. Unset (``None``) or empty
        values are dropped so they are not passed to the client.

        :raises MLRunInvalidArgumentError: If the timeout or max-retries value is not numeric.
        """
        res = dict(
            api_key=self._get_secret_or_env("OPENAI_API_KEY"),
            organization=self._get_secret_or_env("OPENAI_ORG_ID"),
            project=self._get_secret_or_env("OPENAI_PROJECT_ID"),
            base_url=self._get_secret_or_env("OPENAI_BASE_URL"),
            timeout=self._get_secret_or_env("OPENAI_TIMEOUT"),
            max_retries=self._get_secret_or_env("OPENAI_MAX_RETRIES"),
        )
        options = self._sanitize_options(res)
        # Environment variables (and typical secret values) are strings, but the
        # client needs a number for ``timeout`` and an int for ``max_retries``.
        for option_name, convert in (("timeout", float), ("max_retries", int)):
            value = options.get(option_name)
            if isinstance(value, str):
                try:
                    options[option_name] = convert(value)
                except ValueError as exc:
                    raise mlrun.errors.MLRunInvalidArgumentError(
                        f"Invalid value for OpenAI option '{option_name}': {value!r}"
                    ) from exc
        return options

    def customized_invoke(
        self, operation: Optional[Callable[..., T]] = None, **invoke_kwargs
    ) -> Optional[T]:
        """Call ``operation`` (default: chat completions create) with the merged kwargs and ``model``.

        :param operation:     Optional callable to invoke instead of the default operation.
        :param invoke_kwargs: Per-call kwargs, overlaid on ``default_invoke_kwargs``.
        :returns: Whatever ``operation`` returns.
        """
        invoke_kwargs = self.get_invoke_kwargs(invoke_kwargs)
        target = operation or self._default_operation
        return target(**invoke_kwargs, model=self.model)

    def _get_messages_parameter(
        self, prompt: Optional[str] = None, **invoke_kwargs
    ) -> tuple[list, dict]:
        """Resolve the chat ``messages`` from either ``prompt`` or a ``messages`` kwarg.

        :returns: ``(messages, invoke_kwargs)``. ``messages`` is removed from the
                  returned kwargs so callers can pass it explicitly without a
                  duplicate-keyword error.
        :raises MLRunInvalidArgumentError: If both or neither of ``messages`` and ``prompt`` are given.
        """
        invoke_kwargs = self.get_invoke_kwargs(invoke_kwargs)
        messages = invoke_kwargs.pop("messages", None)
        if messages:
            if prompt:
                raise mlrun.errors.MLRunInvalidArgumentError(
                    "can not provide 'messages' and 'prompt' to invoke"
                )
        elif prompt:
            messages = [
                {
                    "role": "user",
                    "content": prompt,
                },
            ]
        else:
            raise mlrun.errors.MLRunInvalidArgumentError(
                "must provide 'messages' or 'prompt' to invoke"
            )
        return messages, invoke_kwargs

    def invoke(self, prompt: Optional[str] = None, **invoke_kwargs) -> str:
        """Send a chat completion request and return the content of the first choice.

        :param prompt:        A single user message; mutually exclusive with ``messages``.
        :param invoke_kwargs: Extra request kwargs (may include ``messages``).
        """
        messages, invoke_kwargs = self._get_messages_parameter(
            prompt=prompt, **invoke_kwargs
        )
        response = self._default_operation(
            model=self.model, messages=messages, **invoke_kwargs
        )
        return response.choices[0].message.content
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# Copyright 2025 Iguazio
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Optional
|
|
16
|
+
|
|
17
|
+
import mlrun
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class BaseRemoteClient:
    """Common base for clients of remote services.

    Holds the client identity (``kind``, ``name``, ``endpoint``) and resolves
    secrets: first from the explicit ``secrets`` dict, then from the parent's
    ``secret()`` method (keys prefixed with ``secret_pfx``).
    """

    def __init__(self, parent, kind, name, endpoint="", secrets: Optional[dict] = None):
        self._parent = parent
        self.kind = kind
        self.name = name
        self.endpoint = endpoint
        self._secrets = secrets or {}
        # Prefix prepended to keys when they are looked up on the parent.
        self.secret_pfx = ""

    def _get_secret_or_env(self, key, default=None):
        """Resolve ``key`` from this client's secrets, falling back to the environment."""
        # Project-secrets are mounted as env variables whose name can be retrieved from SecretsStore
        return mlrun.get_secret_or_env(
            key, secret_provider=self._get_secret, default=default
        )

    def _get_parent_secret(self, key):
        """Look ``secret_pfx + key`` up on the parent; ``None`` when there is no parent."""
        # Without this guard a parent-less client (e.g. one created directly)
        # fails with AttributeError during secret resolution.
        if self._parent is None:
            return None
        return self._parent.secret(self.secret_pfx + key)

    def _get_secret(self, key: str, default=None):
        """Return ``key`` from the explicit secrets (or ``default``), else from the parent."""
        return self._secrets.get(key, default) or self._get_parent_secret(key)

    @property
    def url(self):
        """``<kind>://<endpoint>``."""
        return f"{self.kind}://{self.endpoint}"

    @staticmethod
    def _sanitize_options(options):
        """Return a copy of ``options`` without ``None`` or empty-string values."""
        if not options:
            return {}
        return {k: v for k, v in options.items() if v is not None and v != ""}

    @classmethod
    def parse_endpoint_and_path(cls, endpoint, subpath) -> tuple[str, str]:
        """Return ``(endpoint, subpath)`` unchanged; subclasses may override."""
        return endpoint, subpath
|
mlrun/datastore/s3.py
CHANGED
mlrun/datastore/storeytargets.py
CHANGED
|
@@ -46,7 +46,7 @@ def get_url_and_storage_options(path, external_storage_options=None):
|
|
|
46
46
|
storage_options = merge(external_storage_options, storage_options)
|
|
47
47
|
else:
|
|
48
48
|
storage_options = storage_options or external_storage_options
|
|
49
|
-
return url, DataStore.
|
|
49
|
+
return url, DataStore._sanitize_options(storage_options)
|
|
50
50
|
|
|
51
51
|
|
|
52
52
|
class TDEngineStoreyTarget(storey.TDEngineTarget):
|
mlrun/datastore/utils.py
CHANGED
|
@@ -311,3 +311,25 @@ class KafkaParameters:
|
|
|
311
311
|
valid_keys.update(ref_dict.keys())
|
|
312
312
|
# Return a new dictionary with only valid keys
|
|
313
313
|
return {k: v for k, v in input_dict.items() if k in valid_keys}
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def parse_url(url):
    """Split ``url`` into ``(schema, endpoint, parsed_url)``.

    ``v3io://`` URLs without a third slash are rewritten to ``v3io:///`` first.
    ``schema`` is the lower-cased scheme. ``endpoint`` is the host in its
    original letter case, with ``:port`` appended when a port is present, or
    ``None`` when the URL has no host. ``parsed_url`` is the ``urlparse`` result.
    """
    if url and url.startswith("v3io://") and not url.startswith("v3io:///"):
        url = url.replace("v3io://", "v3io:///", 1)
    parsed_url = urlparse(url)
    schema = parsed_url.scheme.lower()
    endpoint = parsed_url.hostname
    if endpoint:
        # urlparse lower-cases the hostname, but we want the original case, so
        # locate it in the netloc. Search only after the last "@" (as urlparse
        # does) so a username containing the host text is not matched instead.
        host_part = str(parsed_url.netloc).rpartition("@")[2]
        start = host_part.lower().index(endpoint)
        endpoint = host_part[start : start + len(endpoint)]
        if parsed_url.port:
            endpoint += f":{parsed_url.port}"
    return schema, endpoint, parsed_url
|
mlrun/datastore/v3io.py
CHANGED
mlrun/db/base.py
CHANGED
mlrun/db/httpdb.py
CHANGED
|
@@ -608,7 +608,7 @@ class HTTPRunDB(RunDBInterface):
|
|
|
608
608
|
error = f"store log {project}/{uid}"
|
|
609
609
|
self.api_call("POST", path, error, params, body)
|
|
610
610
|
|
|
611
|
-
def get_log(self, uid, project="", offset=0, size=None):
|
|
611
|
+
def get_log(self, uid, project="", offset=0, size=None, attempt=None):
|
|
612
612
|
"""Retrieve 1 MB data of log.
|
|
613
613
|
|
|
614
614
|
:param uid: Log unique ID
|
|
@@ -616,6 +616,8 @@ class HTTPRunDB(RunDBInterface):
|
|
|
616
616
|
:param offset: Retrieve partial log, get up to ``size`` bytes starting at offset ``offset``
|
|
617
617
|
from beginning of log (must be >= 0)
|
|
618
618
|
:param size: If set to ``-1`` will retrieve and print all data to end of the log by chunks of 1MB each.
|
|
619
|
+
:param attempt: For retriable runs, the attempt number to retrieve the log for.
|
|
620
|
+
1 is the initial attempt.
|
|
619
621
|
:returns: The following objects:
|
|
620
622
|
|
|
621
623
|
- state - The state of the runtime object which generates this log, if it exists. In case no known state
|
|
@@ -636,6 +638,8 @@ class HTTPRunDB(RunDBInterface):
|
|
|
636
638
|
return state, offset
|
|
637
639
|
|
|
638
640
|
params = {"offset": offset, "size": size}
|
|
641
|
+
if attempt:
|
|
642
|
+
params["attempt"] = attempt
|
|
639
643
|
path = self._path_of("logs", project, uid)
|
|
640
644
|
error = f"get log {project}/{uid}"
|
|
641
645
|
resp = self.api_call("GET", path, error, params=params)
|
|
@@ -658,7 +662,7 @@ class HTTPRunDB(RunDBInterface):
|
|
|
658
662
|
resp = self.api_call("GET", path, error)
|
|
659
663
|
return resp.json()["size"]
|
|
660
664
|
|
|
661
|
-
def watch_log(self, uid, project="", watch=True, offset=0):
|
|
665
|
+
def watch_log(self, uid, project="", watch=True, offset=0, attempt=None):
|
|
662
666
|
"""Retrieve logs of a running process by chunks of 1MB, and watch the progress of the execution until it
|
|
663
667
|
completes. This method will print out the logs and continue to periodically poll for, and print,
|
|
664
668
|
new logs as long as the state of the runtime which generates this log is either ``pending`` or ``running``.
|
|
@@ -668,10 +672,11 @@ class HTTPRunDB(RunDBInterface):
|
|
|
668
672
|
:param watch: If set to ``True`` will continue tracking the log as described above. Otherwise this function
|
|
669
673
|
is practically equivalent to the :py:func:`~get_log` function.
|
|
670
674
|
:param offset: Minimal offset in the log to watch.
|
|
675
|
+
:param attempt: For retriable runs, the attempt number to retrieve the log for. 1 is the initial attempt.
|
|
671
676
|
:returns: The final state of the log being watched and the final offset.
|
|
672
677
|
"""
|
|
673
678
|
|
|
674
|
-
state, text = self.get_log(uid, project, offset=offset)
|
|
679
|
+
state, text = self.get_log(uid, project, offset=offset, attempt=attempt)
|
|
675
680
|
if text:
|
|
676
681
|
print(text.decode(errors=mlrun.mlconf.httpdb.logs.decode.errors))
|
|
677
682
|
nil_resp = 0
|
|
@@ -687,7 +692,7 @@ class HTTPRunDB(RunDBInterface):
|
|
|
687
692
|
mlrun.mlconf.httpdb.logs.pull_logs_backoff_no_logs_default_interval
|
|
688
693
|
)
|
|
689
694
|
)
|
|
690
|
-
state, text = self.get_log(uid, project, offset=offset)
|
|
695
|
+
state, text = self.get_log(uid, project, offset=offset, attempt=attempt)
|
|
691
696
|
if text:
|
|
692
697
|
nil_resp = 0
|
|
693
698
|
print(
|
mlrun/db/nopdb.py
CHANGED
|
@@ -63,7 +63,7 @@ class NopDB(RunDBInterface):
|
|
|
63
63
|
def store_log(self, uid, project="", body=None, append=False):
|
|
64
64
|
pass
|
|
65
65
|
|
|
66
|
-
def get_log(self, uid, project="", offset=0, size=0):
|
|
66
|
+
def get_log(self, uid, project="", offset=0, size=0, attempt=None):
|
|
67
67
|
pass
|
|
68
68
|
|
|
69
69
|
def store_run(self, struct, uid, project="", iter=0):
|
mlrun/execution.py
CHANGED
|
@@ -26,6 +26,7 @@ from dateutil import parser
|
|
|
26
26
|
import mlrun
|
|
27
27
|
import mlrun.common.constants as mlrun_constants
|
|
28
28
|
import mlrun.common.formatters
|
|
29
|
+
import mlrun.common.runtimes.constants
|
|
29
30
|
from mlrun.artifacts import (
|
|
30
31
|
Artifact,
|
|
31
32
|
DatasetArtifact,
|
|
@@ -91,6 +92,8 @@ class MLClientCtx:
|
|
|
91
92
|
self._autocommit = autocommit
|
|
92
93
|
self._notifications = []
|
|
93
94
|
self._state_thresholds = {}
|
|
95
|
+
self._retry_spec = {}
|
|
96
|
+
self._retry_count = None
|
|
94
97
|
|
|
95
98
|
self._labels = {}
|
|
96
99
|
self._annotations = {}
|
|
@@ -432,6 +435,7 @@ class MLClientCtx:
|
|
|
432
435
|
self._tolerations = spec.get("tolerations", self._tolerations)
|
|
433
436
|
self._affinity = spec.get("affinity", self._affinity)
|
|
434
437
|
self._reset_on_run = spec.get("reset_on_run", self._reset_on_run)
|
|
438
|
+
self._retry_spec = spec.get("retry", self._retry_spec)
|
|
435
439
|
|
|
436
440
|
self._init_dbs(rundb)
|
|
437
441
|
|
|
@@ -450,10 +454,11 @@ class MLClientCtx:
|
|
|
450
454
|
if start:
|
|
451
455
|
start = parser.parse(start) if isinstance(start, str) else start
|
|
452
456
|
self._start_time = start
|
|
453
|
-
self._state =
|
|
457
|
+
self._state = mlrun.common.runtimes.constants.RunStates.running
|
|
454
458
|
|
|
455
459
|
status = attrs.get("status")
|
|
456
|
-
|
|
460
|
+
retry_configured = self._retry_spec and self._retry_spec.get("count")
|
|
461
|
+
if (include_status or retry_configured) and status:
|
|
457
462
|
self._results = status.get("results", self._results)
|
|
458
463
|
for artifact in status.get("artifacts", []):
|
|
459
464
|
artifact_obj = dict_to_artifact(artifact)
|
|
@@ -462,7 +467,10 @@ class MLClientCtx:
|
|
|
462
467
|
)
|
|
463
468
|
for key, uri in status.get("artifact_uris", {}).items():
|
|
464
469
|
self._artifacts_manager.artifact_uris[key] = uri
|
|
465
|
-
self.
|
|
470
|
+
self._retry_count = status.get("retry_count", self._retry_count)
|
|
471
|
+
# if run is a retry, the state needs to move to running
|
|
472
|
+
if include_status:
|
|
473
|
+
self._state = status.get("state", self._state)
|
|
466
474
|
|
|
467
475
|
# No need to store the run for every worker
|
|
468
476
|
if store_run and self.is_logging_worker():
|
|
@@ -1107,13 +1115,13 @@ class MLClientCtx:
|
|
|
1107
1115
|
:param completed: Mark run as completed
|
|
1108
1116
|
"""
|
|
1109
1117
|
# Changing state to completed is allowed only when the execution is in running state
|
|
1110
|
-
if self._state !=
|
|
1118
|
+
if self._state != mlrun.common.runtimes.constants.RunStates.running:
|
|
1111
1119
|
completed = False
|
|
1112
1120
|
|
|
1113
1121
|
if message:
|
|
1114
1122
|
self._annotations["message"] = message
|
|
1115
1123
|
if completed:
|
|
1116
|
-
self._state =
|
|
1124
|
+
self._state = mlrun.common.runtimes.constants.RunStates.completed
|
|
1117
1125
|
|
|
1118
1126
|
if self._parent:
|
|
1119
1127
|
self._parent.update_child_iterations()
|
|
@@ -1147,9 +1155,15 @@ class MLClientCtx:
|
|
|
1147
1155
|
updates = {"status.last_update": now_date().isoformat()}
|
|
1148
1156
|
|
|
1149
1157
|
if error is not None:
|
|
1150
|
-
|
|
1158
|
+
state = mlrun.common.runtimes.constants.RunStates.error
|
|
1159
|
+
max_retries = self._retry_spec.get("count", 0)
|
|
1160
|
+
self._retry_count = self._retry_count or 0
|
|
1161
|
+
if max_retries and self._retry_count < max_retries:
|
|
1162
|
+
state = mlrun.common.runtimes.constants.RunStates.pending_retry
|
|
1163
|
+
|
|
1164
|
+
self._state = state
|
|
1151
1165
|
self._error = str(error)
|
|
1152
|
-
updates["status.state"] =
|
|
1166
|
+
updates["status.state"] = state
|
|
1153
1167
|
updates["status.error"] = error
|
|
1154
1168
|
elif (
|
|
1155
1169
|
execution_state
|
|
@@ -1241,11 +1255,13 @@ class MLClientCtx:
|
|
|
1241
1255
|
"node_selector": self._node_selector,
|
|
1242
1256
|
"tolerations": self._tolerations,
|
|
1243
1257
|
"affinity": self._affinity,
|
|
1258
|
+
"retry": self._retry_spec,
|
|
1244
1259
|
},
|
|
1245
1260
|
"status": {
|
|
1246
1261
|
"results": self._results,
|
|
1247
1262
|
"start_time": to_date_str(self._start_time),
|
|
1248
1263
|
"last_update": to_date_str(self._last_update),
|
|
1264
|
+
"retry_count": self._retry_count,
|
|
1249
1265
|
},
|
|
1250
1266
|
}
|
|
1251
1267
|
|
mlrun/launcher/base.py
CHANGED
|
@@ -18,6 +18,8 @@ import os
|
|
|
18
18
|
import uuid
|
|
19
19
|
from typing import Any, Callable, Optional, Union
|
|
20
20
|
|
|
21
|
+
import mlrun.common.constants
|
|
22
|
+
import mlrun.common.runtimes.constants
|
|
21
23
|
import mlrun.common.schemas
|
|
22
24
|
import mlrun.config
|
|
23
25
|
import mlrun.errors
|
|
@@ -72,6 +74,7 @@ class BaseLauncher(abc.ABC):
|
|
|
72
74
|
notifications: Optional[list[mlrun.model.Notification]] = None,
|
|
73
75
|
returns: Optional[list[Union[str, dict[str, str]]]] = None,
|
|
74
76
|
state_thresholds: Optional[dict[str, int]] = None,
|
|
77
|
+
retry: Optional[Union[mlrun.model.Retry, dict]] = None,
|
|
75
78
|
) -> "mlrun.run.RunObject":
|
|
76
79
|
"""run the function from the server/client[local/remote]"""
|
|
77
80
|
pass
|
|
@@ -133,7 +136,7 @@ class BaseLauncher(abc.ABC):
|
|
|
133
136
|
"""Check if the runtime requires to build the image and updates the spec accordingly"""
|
|
134
137
|
pass
|
|
135
138
|
|
|
136
|
-
def
|
|
139
|
+
def _validate_run(
|
|
137
140
|
self,
|
|
138
141
|
runtime: "mlrun.runtimes.BaseRuntime",
|
|
139
142
|
run: "mlrun.run.RunObject",
|
|
@@ -194,7 +197,7 @@ class BaseLauncher(abc.ABC):
|
|
|
194
197
|
)
|
|
195
198
|
|
|
196
199
|
@classmethod
|
|
197
|
-
def _validate_run_single_param(cls, param_name, param_value):
|
|
200
|
+
def _validate_run_single_param(cls, param_name: str, param_value: int):
|
|
198
201
|
# verify that integer parameters don't exceed a int64
|
|
199
202
|
if isinstance(param_value, int) and abs(param_value) >= 2**63:
|
|
200
203
|
raise mlrun.errors.MLRunInvalidArgumentError(
|
|
@@ -203,8 +206,6 @@ class BaseLauncher(abc.ABC):
|
|
|
203
206
|
|
|
204
207
|
@staticmethod
|
|
205
208
|
def _create_run_object(task):
|
|
206
|
-
valid_task_types = (dict, mlrun.run.RunTemplate, mlrun.run.RunObject)
|
|
207
|
-
|
|
208
209
|
if not task:
|
|
209
210
|
# if task passed generate default RunObject
|
|
210
211
|
return mlrun.run.RunObject.from_dict(task)
|
|
@@ -215,18 +216,18 @@ class BaseLauncher(abc.ABC):
|
|
|
215
216
|
if isinstance(task, str):
|
|
216
217
|
task = ast.literal_eval(task)
|
|
217
218
|
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
219
|
+
valid_task_types = (dict, mlrun.run.RunTemplate, mlrun.run.RunObject)
|
|
220
|
+
if isinstance(task, mlrun.run.RunObject):
|
|
221
|
+
# if task is already a RunObject, we can return it as is
|
|
222
|
+
return task
|
|
223
223
|
if isinstance(task, mlrun.run.RunTemplate):
|
|
224
224
|
return mlrun.run.RunObject.from_template(task)
|
|
225
225
|
elif isinstance(task, dict):
|
|
226
226
|
return mlrun.run.RunObject.from_dict(task)
|
|
227
227
|
|
|
228
|
-
|
|
229
|
-
|
|
228
|
+
raise mlrun.errors.MLRunInvalidArgumentError(
|
|
229
|
+
f"Task is not a valid object, type={type(task)}, expected types={valid_task_types}"
|
|
230
|
+
)
|
|
230
231
|
|
|
231
232
|
@staticmethod
|
|
232
233
|
def _enrich_run(
|
|
@@ -246,6 +247,7 @@ class BaseLauncher(abc.ABC):
|
|
|
246
247
|
workdir=None,
|
|
247
248
|
notifications: Optional[list[mlrun.model.Notification]] = None,
|
|
248
249
|
state_thresholds: Optional[dict[str, int]] = None,
|
|
250
|
+
retry: Optional[Union[mlrun.model.Retry, dict]] = None,
|
|
249
251
|
):
|
|
250
252
|
run.spec.handler = (
|
|
251
253
|
handler or run.spec.handler or runtime.spec.default_handler or ""
|
|
@@ -364,6 +366,7 @@ class BaseLauncher(abc.ABC):
|
|
|
364
366
|
| state_thresholds
|
|
365
367
|
)
|
|
366
368
|
run.spec.state_thresholds = state_thresholds or run.spec.state_thresholds
|
|
369
|
+
run.spec.retry = retry or run.spec.retry
|
|
367
370
|
return run
|
|
368
371
|
|
|
369
372
|
@staticmethod
|
|
@@ -410,7 +413,7 @@ class BaseLauncher(abc.ABC):
|
|
|
410
413
|
)
|
|
411
414
|
if (
|
|
412
415
|
run.status.state
|
|
413
|
-
in mlrun.common.runtimes.constants.RunStates.
|
|
416
|
+
in mlrun.common.runtimes.constants.RunStates.error_states()
|
|
414
417
|
):
|
|
415
418
|
if runtime._is_remote and not runtime.is_child:
|
|
416
419
|
logger.error(
|
|
@@ -418,7 +421,14 @@ class BaseLauncher(abc.ABC):
|
|
|
418
421
|
state=run.status.state,
|
|
419
422
|
status=run.status.to_dict(),
|
|
420
423
|
)
|
|
421
|
-
|
|
424
|
+
|
|
425
|
+
error = run.error
|
|
426
|
+
if (
|
|
427
|
+
run.status.state
|
|
428
|
+
== mlrun.common.runtimes.constants.RunStates.pending_retry
|
|
429
|
+
):
|
|
430
|
+
error = f"Run is pending retry, error: {run.error}"
|
|
431
|
+
raise mlrun.runtimes.utils.RunError(error)
|
|
422
432
|
return run
|
|
423
433
|
|
|
424
434
|
return None
|
mlrun/launcher/local.py
CHANGED
|
@@ -72,6 +72,7 @@ class ClientLocalLauncher(launcher.ClientBaseLauncher):
|
|
|
72
72
|
returns: Optional[list[Union[str, dict[str, str]]]] = None,
|
|
73
73
|
state_thresholds: Optional[dict[str, int]] = None,
|
|
74
74
|
reset_on_run: Optional[bool] = None,
|
|
75
|
+
retry: Optional[Union[mlrun.model.Retry, dict]] = None,
|
|
75
76
|
) -> "mlrun.run.RunObject":
|
|
76
77
|
# do not allow local function to be scheduled
|
|
77
78
|
if schedule is not None:
|
|
@@ -122,8 +123,9 @@ class ClientLocalLauncher(launcher.ClientBaseLauncher):
|
|
|
122
123
|
workdir=workdir,
|
|
123
124
|
notifications=notifications,
|
|
124
125
|
state_thresholds=state_thresholds,
|
|
126
|
+
retry=retry,
|
|
125
127
|
)
|
|
126
|
-
self.
|
|
128
|
+
self._validate_run(runtime, run)
|
|
127
129
|
result = self._execute(
|
|
128
130
|
runtime=runtime,
|
|
129
131
|
run=run,
|
mlrun/launcher/remote.py
CHANGED
|
@@ -61,6 +61,7 @@ class ClientRemoteLauncher(launcher.ClientBaseLauncher):
|
|
|
61
61
|
returns: Optional[list[Union[str, dict[str, str]]]] = None,
|
|
62
62
|
state_thresholds: Optional[dict[str, int]] = None,
|
|
63
63
|
reset_on_run: Optional[bool] = None,
|
|
64
|
+
retry: Optional[Union[mlrun.model.Retry, dict]] = None,
|
|
64
65
|
) -> "mlrun.run.RunObject":
|
|
65
66
|
self.enrich_runtime(runtime, project)
|
|
66
67
|
run = self._create_run_object(task)
|
|
@@ -82,8 +83,9 @@ class ClientRemoteLauncher(launcher.ClientBaseLauncher):
|
|
|
82
83
|
workdir=workdir,
|
|
83
84
|
notifications=notifications,
|
|
84
85
|
state_thresholds=state_thresholds,
|
|
86
|
+
retry=retry,
|
|
85
87
|
)
|
|
86
|
-
self.
|
|
88
|
+
self._validate_run(runtime, run)
|
|
87
89
|
|
|
88
90
|
if not runtime.is_deployed():
|
|
89
91
|
if runtime.spec.build.auto_build or auto_build:
|
|
@@ -190,7 +192,7 @@ class ClientRemoteLauncher(launcher.ClientBaseLauncher):
|
|
|
190
192
|
return self._wrap_run_result(runtime, resp, run, schedule=schedule)
|
|
191
193
|
|
|
192
194
|
@classmethod
|
|
193
|
-
def _validate_run_single_param(cls, param_name, param_value):
|
|
195
|
+
def _validate_run_single_param(cls, param_name: str, param_value: int):
|
|
194
196
|
if isinstance(param_value, pd.DataFrame):
|
|
195
197
|
raise mlrun.errors.MLRunInvalidArgumentTypeError(
|
|
196
198
|
f"Parameter '{param_name}' has an unsupported value of type"
|