ob-metaflow-extensions 1.1.142__py2.py3-none-any.whl → 1.4.33__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metaflow_extensions/outerbounds/__init__.py +1 -1
- metaflow_extensions/outerbounds/plugins/__init__.py +26 -5
- metaflow_extensions/outerbounds/plugins/apps/app_cli.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/app_deploy_decorator.py +146 -0
- metaflow_extensions/outerbounds/plugins/apps/core/__init__.py +10 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_state_machine.py +506 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_vendor/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/__init__.py +4 -0
- metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/spinners.py +478 -0
- metaflow_extensions/outerbounds/plugins/apps/core/app_cli.py +1200 -0
- metaflow_extensions/outerbounds/plugins/apps/core/app_config.py +146 -0
- metaflow_extensions/outerbounds/plugins/apps/core/artifacts.py +0 -0
- metaflow_extensions/outerbounds/plugins/apps/core/capsule.py +958 -0
- metaflow_extensions/outerbounds/plugins/apps/core/click_importer.py +24 -0
- metaflow_extensions/outerbounds/plugins/apps/core/code_package/__init__.py +3 -0
- metaflow_extensions/outerbounds/plugins/apps/core/code_package/code_packager.py +618 -0
- metaflow_extensions/outerbounds/plugins/apps/core/code_package/examples.py +125 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/__init__.py +12 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/cli_generator.py +161 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/config_utils.py +868 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/schema_export.py +288 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/typed_configs.py +139 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/typed_init_generator.py +398 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config/unified_config.py +1088 -0
- metaflow_extensions/outerbounds/plugins/apps/core/config_schema.yaml +337 -0
- metaflow_extensions/outerbounds/plugins/apps/core/dependencies.py +115 -0
- metaflow_extensions/outerbounds/plugins/apps/core/deployer.py +303 -0
- metaflow_extensions/outerbounds/plugins/apps/core/experimental/__init__.py +89 -0
- metaflow_extensions/outerbounds/plugins/apps/core/perimeters.py +87 -0
- metaflow_extensions/outerbounds/plugins/apps/core/secrets.py +164 -0
- metaflow_extensions/outerbounds/plugins/apps/core/utils.py +233 -0
- metaflow_extensions/outerbounds/plugins/apps/core/validations.py +17 -0
- metaflow_extensions/outerbounds/plugins/aws/__init__.py +4 -0
- metaflow_extensions/outerbounds/plugins/aws/assume_role.py +3 -0
- metaflow_extensions/outerbounds/plugins/aws/assume_role_decorator.py +78 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/__init__.py +2 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/coreweave.py +71 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/external_chckpt.py +85 -0
- metaflow_extensions/outerbounds/plugins/checkpoint_datastores/nebius.py +73 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/baker.py +110 -0
- metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py +17 -3
- metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py +1 -0
- metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py +18 -44
- metaflow_extensions/outerbounds/plugins/kubernetes/pod_killer.py +374 -0
- metaflow_extensions/outerbounds/plugins/nim/card.py +1 -6
- metaflow_extensions/outerbounds/plugins/nim/{__init__.py → nim_decorator.py} +13 -49
- metaflow_extensions/outerbounds/plugins/nim/nim_manager.py +294 -233
- metaflow_extensions/outerbounds/plugins/nim/utils.py +36 -0
- metaflow_extensions/outerbounds/plugins/nvcf/constants.py +2 -2
- metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py +100 -19
- metaflow_extensions/outerbounds/plugins/nvcf/nvcf_decorator.py +6 -1
- metaflow_extensions/outerbounds/plugins/nvct/__init__.py +0 -0
- metaflow_extensions/outerbounds/plugins/nvct/exceptions.py +71 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct.py +131 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct_cli.py +289 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct_decorator.py +286 -0
- metaflow_extensions/outerbounds/plugins/nvct/nvct_runner.py +218 -0
- metaflow_extensions/outerbounds/plugins/nvct/utils.py +29 -0
- metaflow_extensions/outerbounds/plugins/ollama/__init__.py +171 -16
- metaflow_extensions/outerbounds/plugins/ollama/constants.py +1 -0
- metaflow_extensions/outerbounds/plugins/ollama/exceptions.py +22 -0
- metaflow_extensions/outerbounds/plugins/ollama/ollama.py +1710 -114
- metaflow_extensions/outerbounds/plugins/ollama/status_card.py +292 -0
- metaflow_extensions/outerbounds/plugins/optuna/__init__.py +48 -0
- metaflow_extensions/outerbounds/plugins/profilers/simple_card_decorator.py +96 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/__init__.py +7 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/binary_caller.py +132 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/constants.py +11 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/exceptions.py +13 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/proxy_bootstrap.py +59 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_api.py +93 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_decorator.py +250 -0
- metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_manager.py +225 -0
- metaflow_extensions/outerbounds/plugins/secrets/secrets.py +38 -2
- metaflow_extensions/outerbounds/plugins/snowflake/snowflake.py +44 -4
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py +6 -3
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_decorator.py +13 -7
- metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py +8 -2
- metaflow_extensions/outerbounds/plugins/torchtune/__init__.py +163 -0
- metaflow_extensions/outerbounds/plugins/vllm/__init__.py +255 -0
- metaflow_extensions/outerbounds/plugins/vllm/constants.py +1 -0
- metaflow_extensions/outerbounds/plugins/vllm/exceptions.py +1 -0
- metaflow_extensions/outerbounds/plugins/vllm/status_card.py +352 -0
- metaflow_extensions/outerbounds/plugins/vllm/vllm_manager.py +621 -0
- metaflow_extensions/outerbounds/remote_config.py +27 -3
- metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py +87 -2
- metaflow_extensions/outerbounds/toplevel/ob_internal.py +4 -0
- metaflow_extensions/outerbounds/toplevel/plugins/optuna/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/plugins/torchtune/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/plugins/vllm/__init__.py +1 -0
- metaflow_extensions/outerbounds/toplevel/s3_proxy.py +88 -0
- {ob_metaflow_extensions-1.1.142.dist-info → ob_metaflow_extensions-1.4.33.dist-info}/METADATA +2 -2
- ob_metaflow_extensions-1.4.33.dist-info/RECORD +134 -0
- metaflow_extensions/outerbounds/plugins/nim/utilities.py +0 -5
- ob_metaflow_extensions-1.1.142.dist-info/RECORD +0 -64
- {ob_metaflow_extensions-1.1.142.dist-info → ob_metaflow_extensions-1.4.33.dist-info}/WHEEL +0 -0
- {ob_metaflow_extensions-1.1.142.dist-info → ob_metaflow_extensions-1.4.33.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import re
|
|
3
|
+
import time
|
|
4
|
+
import math
|
|
5
|
+
import shlex
|
|
6
|
+
import atexit
|
|
7
|
+
|
|
8
|
+
from metaflow import util
|
|
9
|
+
from metaflow.mflog import (
|
|
10
|
+
BASH_SAVE_LOGS,
|
|
11
|
+
bash_capture_logs,
|
|
12
|
+
export_mflog_env_vars,
|
|
13
|
+
tail_logs,
|
|
14
|
+
get_log_tailer,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
from .nvct import NVCTClient, NVCTTask, NVCTRequest
|
|
18
|
+
from .exceptions import (
|
|
19
|
+
NvctKilledException,
|
|
20
|
+
NvctExecutionException,
|
|
21
|
+
NvctTaskFailedException,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
# Constants for Metaflow logs
|
|
25
|
+
LOGS_DIR = "$PWD/.logs"
|
|
26
|
+
STDOUT_FILE = "mflog_stdout"
|
|
27
|
+
STDERR_FILE = "mflog_stderr"
|
|
28
|
+
STDOUT_PATH = os.path.join(LOGS_DIR, STDOUT_FILE)
|
|
29
|
+
STDERR_PATH = os.path.join(LOGS_DIR, STDERR_FILE)
|
|
30
|
+
NVCT_WRAPPER = "/usr/local/bin/nvct-wrapper.sh"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class NvctRunner:
    """Build, submit and monitor one Metaflow step as an NVCT task.

    The runner assembles the shell command for a step, wraps it for the
    NVCT container entry point, submits it through ``NVCTClient`` /
    ``NVCTTask`` and then tails the task logs until the task leaves the
    running state.
    """

    def __init__(
        self,
        metadata,
        datastore,
        environment,
        gpu_type,
        instance_type,
        backend,
        ngc_api_key,
    ):
        """Store the launch configuration and register exit-time cancellation.

        Parameters
        ----------
        metadata, datastore, environment
            Metaflow objects stored as-is. ``environment`` supplies the
            package/bootstrap commands and ``datastore.TYPE`` selects the
            log tailer.
        gpu_type, instance_type, backend
            Forwarded unchanged to ``NVCTRequest.gpu``.
        ngc_api_key
            Credential handed to ``NVCTClient`` when a task is launched.
        """
        self.metadata = metadata
        self.datastore = datastore
        self.environment = environment
        self.gpu_type = gpu_type
        self.instance_type = instance_type
        self.backend = backend
        self._ngc_api_key = ngc_api_key
        self.client = None
        self.task = None
        # Cancel the task when the interpreter exits. A bound method is used
        # (instead of a ``hasattr`` check, which is always true here) so that
        # exiting before any task was launched does not call ``None.cancel()``.
        atexit.register(self._cancel_task_at_exit)

    def _cancel_task_at_exit(self):
        """Cancel the launched task, if any; a no-op when nothing was launched."""
        task = getattr(self, "task", None)
        if task is not None:
            task.cancel()

    def launch_task(
        self,
        step_name,
        step_cli,
        task_spec,
        code_package_sha,
        code_package_url,
        code_package_ds,
        env=None,
        max_runtime="PT7H",  # <8H allowed for GFN backend
        max_queued="PT120H",  # 5 days
    ):
        """Assemble the step command, submit it as an NVCT task, return its id.

        Parameters
        ----------
        step_name : str
            Step being executed; used for bootstrap commands and the task name.
        step_cli : str
            Command line that runs the step; appended after the bootstrap
            commands.
        task_spec : dict
            Expanded as keyword arguments into ``export_mflog_env_vars``; the
            keys ``flow_name``, ``run_id``, ``task_id`` and ``retry_count``
            are also read to build the task name.
        code_package_sha
            Accepted for interface compatibility; not used in this method.
        code_package_url, code_package_ds
            Passed to the environment to obtain the package/bootstrap
            commands; ``code_package_ds`` is also the mflog datastore type.
        env : dict, optional
            Environment variables for the task. Entries whose value is
            ``None`` are skipped; other values are converted with ``str``.
        max_runtime : str
            Maximum runtime duration. Note that for any backend other than
            ``"GFN"`` this value is replaced by ``"PT72H"``.
        max_queued : str
            Maximum queued duration.

        Returns
        -------
        The ``id`` attribute of the submitted ``NVCTTask``.
        """
        # ``None`` instead of a shared mutable ``{}`` default.
        env = {} if env is None else env

        mflog_expr = export_mflog_env_vars(
            datastore_type=code_package_ds,
            stdout_path=STDOUT_PATH,
            stderr_path=STDERR_PATH,
            **task_spec,
        )
        init_cmds = self.environment.get_package_commands(
            code_package_url, code_package_ds
        )
        init_expr = " && ".join(init_cmds)
        step_expr = bash_capture_logs(
            " && ".join(
                self.environment.bootstrap_commands(step_name, code_package_ds)
                + [step_cli]
            )
        )
        # Run the step, remember its exit code, persist the logs, then exit
        # with the step's exit code.
        cmd_str = "mkdir -p %s && %s && %s && %s; c=$?; %s; exit $c" % (
            LOGS_DIR,
            mflog_expr,
            init_expr,
            step_expr,
            BASH_SAVE_LOGS,
        )

        # Add optional initialization script execution
        cmd_str = (
            '${METAFLOW_INIT_SCRIPT:+eval \\"${METAFLOW_INIT_SCRIPT}\\"} && %s'
            % cmd_str
        )

        # Round-trip through shlex to obtain the command exactly as bash
        # would receive it after one level of double-quote processing.
        cmd_str = shlex.split('bash -c "%s"' % cmd_str)[-1]

        def modify_python_c(match):
            content = match.group(1)
            # Escape double quotes within the python -c command
            content = content.replace('"', r"\"")
            # Re-wrap the python -c payload in double quotes
            return 'python -c "%s"' % content

        # Convert python -c single quotes to double quotes
        cmd_str = re.sub(r"python -c '(.*?)'", modify_python_c, cmd_str)
        # Remaining single quotes become double quotes so the whole command
        # can be enclosed in a single pair of single quotes below.
        cmd_str = cmd_str.replace("'", '"')
        # Create the final command with outer single quotes to pass to NVCT wrapper
        nvct_cmd = f"{NVCT_WRAPPER} bash -c '{cmd_str}'"

        flow_name = task_spec.get("flow_name")
        run_id = task_spec.get("run_id")
        task_id = task_spec.get("task_id")
        retry_count = task_spec.get("retry_count")
        task_name = f"{flow_name}-{run_id}-{step_name}-{task_id}-{retry_count}"

        if self.backend != "GFN":
            # if maxRuntimeDuration exceeds 8 hours for a Task on the GFN backend,
            # the request will be rejected.
            # (https://docs.nvidia.com/cloud-functions/user-guide/latest/cloud-function/tasks.html#create-task)
            # thus, if it is non GFN backend, we increase it to 3 days
            max_runtime = "PT72H"

        request = (
            NVCTRequest(task_name)
            .container_image("nvcr.io/zhxkmsaasxhw/nvct-base:2.0-jovyan")
            .container_args(nvct_cmd)
            .gpu(
                gpu=self.gpu_type,
                instance_type=self.instance_type,
                backend=self.backend,
            )
            .max_runtime(max_runtime)
            .max_queued(max_queued)
        )

        for k, v in env.items():
            if v is not None:
                request.env(k, str(v))

        self.client = NVCTClient(self._ngc_api_key)
        self.task = NVCTTask(self.client, request.to_dict())

        self.task.submit()
        return self.task.id

    def wait_for_completion(self, stdout_location, stderr_location, echo=None):
        """Block until the launched task finishes, streaming its logs.

        Parameters
        ----------
        stdout_location, stderr_location
            Log locations handed to ``get_log_tailer`` together with
            ``self.datastore.TYPE``.
        echo : callable, optional
            Called as ``echo(message, "stderr", _id=task_id)`` for status
            lines and also passed to ``tail_logs``. When omitted, status
            lines are discarded instead of failing on ``None(...)``.

        Raises
        ------
        NvctExecutionException
            If no task has been launched.
        NvctTaskFailedException
            If the task reports ``has_failed`` after log tailing ends.
        NvctKilledException
            If the task still reports ``is_running`` after log tailing ends.
        """
        if not self.task:
            raise NvctExecutionException("No task has been launched")

        if echo is None:
            # The declared default would otherwise crash on the first call.
            def echo(*args, **kwargs):
                return None

        def update_delay(secs_since_start):
            # this sigmoid function reaches
            # - 0.1 after 11 minutes
            # - 0.5 after 15 minutes
            # - 1.0 after 23 minutes
            # in other words, the user will see very frequent updates
            # during the first 10 minutes
            sigmoid = 1.0 / (1.0 + math.exp(-0.01 * secs_since_start + 9.0))
            return 0.5 + sigmoid * 30.0

        def wait_for_launch(task):
            status = task.status
            echo(
                "Task is starting (%s)..." % status,
                "stderr",
                _id=task.id,
            )

            t = time.time()
            start_time = time.time()
            while task.is_waiting:
                new_status = task.status
                # Re-announce on every status change, or every 30 seconds.
                if status != new_status or (time.time() - t) > 30:
                    status = new_status
                    echo(
                        "Task is starting (%s)..." % status,
                        "stderr",
                        _id=task.id,
                    )
                    t = time.time()
                time.sleep(update_delay(time.time() - start_time))

        _make_prefix = lambda: b"[%s] " % util.to_bytes(self.task.id)
        stdout_tail = get_log_tailer(stdout_location, self.datastore.TYPE)
        stderr_tail = get_log_tailer(stderr_location, self.datastore.TYPE)

        # 1) Loop until the job has started
        wait_for_launch(self.task)

        echo(
            "Task is starting (%s)..." % self.task.status,
            "stderr",
            _id=self.task.id,
        )

        # 2) Tail logs until the job has finished
        tail_logs(
            prefix=_make_prefix(),
            stdout_tail=stdout_tail,
            stderr_tail=stderr_tail,
            echo=echo,
            has_log_updates=lambda: self.task.is_running,
        )

        if self.task.has_failed:
            raise NvctTaskFailedException(
                f"Task failed with status: {self.task.status}. This could be a transient error. Use @retry to retry."
            )
        else:
            if self.task.is_running:
                # Kill the job if it is still running by throwing an exception.
                raise NvctKilledException("Task failed!")
            echo(
                f"Task finished with status: {self.task.status}",
                "stderr",
                _id=self.task.id,
            )
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
import requests
|
|
4
|
+
from urllib.parse import urlparse
|
|
5
|
+
from metaflow.metaflow_config import SERVICE_URL
|
|
6
|
+
from metaflow.metaflow_config_funcs import init_config
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def get_ngc_api_key():
    """Fetch the NGC API key from the auth server's ``/generate/nim`` endpoint.

    The auth host is ``OBP_AUTH_SERVER`` from the Metaflow config when set;
    otherwise it is derived from ``SERVICE_URL`` by replacing the first
    hostname label with ``auth``. The request is authenticated with
    ``METAFLOW_SERVICE_AUTH_KEY`` from the config when present, otherwise
    with the JSON-encoded headers in the ``METAFLOW_SERVICE_HEADERS``
    environment variable.

    Returns
    -------
    The value at ``["nvcf"]["api_key"]`` in the JSON response.

    Raises
    ------
    ValueError
        If neither ``METAFLOW_SERVICE_AUTH_KEY`` nor
        ``METAFLOW_SERVICE_HEADERS`` is available (previously this surfaced
        as an opaque ``TypeError`` from ``json.loads(None)``), or if the
        headers are not valid JSON.
    requests.HTTPError
        If the auth server answers with an error status.
    """
    conf = init_config()
    if "OBP_AUTH_SERVER" in conf:
        auth_host = conf["OBP_AUTH_SERVER"]
    else:
        auth_host = "auth." + urlparse(SERVICE_URL).hostname.split(".", 1)[1]

    # NOTE: reusing the same auth_host as the one used in NimMetadata,
    # however, user should not need to use nim container to use @nvct.
    # May want to refactor this to a common endpoint.
    nim_info_url = "https://" + auth_host + "/generate/nim"

    if "METAFLOW_SERVICE_AUTH_KEY" in conf:
        headers = {"x-api-key": conf["METAFLOW_SERVICE_AUTH_KEY"]}
    else:
        raw_headers = os.environ.get("METAFLOW_SERVICE_HEADERS")
        if raw_headers is None:
            raise ValueError(
                "Cannot authenticate against %s: neither METAFLOW_SERVICE_AUTH_KEY "
                "(Metaflow config) nor METAFLOW_SERVICE_HEADERS (environment) is set."
                % nim_info_url
            )
        headers = json.loads(raw_headers)

    res = requests.get(nim_info_url, headers=headers)
    res.raise_for_status()
    return res.json()["nvcf"]["api_key"]
|
|
@@ -1,8 +1,11 @@
|
|
|
1
1
|
from metaflow.decorators import StepDecorator
|
|
2
2
|
from metaflow import current
|
|
3
3
|
import functools
|
|
4
|
+
import os
|
|
5
|
+
import threading
|
|
4
6
|
|
|
5
|
-
from .ollama import OllamaManager
|
|
7
|
+
from .ollama import OllamaManager, OllamaRequestInterceptor
|
|
8
|
+
from .status_card import OllamaStatusCard
|
|
6
9
|
from ..card_utilities.injector import CardDecoratorInjector
|
|
7
10
|
|
|
8
11
|
__mf_promote_submodules__ = ["plugins.ollama"]
|
|
@@ -13,10 +16,10 @@ class OllamaDecorator(StepDecorator, CardDecoratorInjector):
|
|
|
13
16
|
This decorator is used to run Ollama APIs as Metaflow task sidecars.
|
|
14
17
|
|
|
15
18
|
User code call
|
|
16
|
-
|
|
19
|
+
--------------
|
|
17
20
|
@ollama(
|
|
18
|
-
models=[
|
|
19
|
-
|
|
21
|
+
models=[...],
|
|
22
|
+
...
|
|
20
23
|
)
|
|
21
24
|
|
|
22
25
|
Valid backend options
|
|
@@ -26,45 +29,197 @@ class OllamaDecorator(StepDecorator, CardDecoratorInjector):
|
|
|
26
29
|
- (TODO) 'remote': Spin up separate instance to serve Ollama models.
|
|
27
30
|
|
|
28
31
|
Valid model options
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
- 'llama3.3'
|
|
32
|
-
- any model here https://ollama.com/search
|
|
32
|
+
-------------------
|
|
33
|
+
Any model here https://ollama.com/search, e.g. 'llama3.2', 'llama3.3'
|
|
33
34
|
|
|
34
35
|
Parameters
|
|
35
36
|
----------
|
|
36
|
-
models: list[
|
|
37
|
+
models: list[str]
|
|
37
38
|
List of Ollama containers running models in sidecars.
|
|
38
39
|
backend: str
|
|
39
40
|
Determines where and how to run the Ollama process.
|
|
41
|
+
force_pull: bool
|
|
42
|
+
Whether to run `ollama pull` no matter what, or first check the remote cache in Metaflow datastore for this model key.
|
|
43
|
+
cache_update_policy: str
|
|
44
|
+
Cache update policy: "auto", "force", or "never".
|
|
45
|
+
force_cache_update: bool
|
|
46
|
+
Simple override for "force" cache update policy.
|
|
47
|
+
debug: bool
|
|
48
|
+
Whether to turn on verbose debugging logs.
|
|
49
|
+
circuit_breaker_config: dict
|
|
50
|
+
Configuration for circuit breaker protection. Keys: failure_threshold, recovery_timeout, reset_timeout.
|
|
51
|
+
timeout_config: dict
|
|
52
|
+
Configuration for various operation timeouts. Keys: pull, stop, health_check, install, server_startup.
|
|
40
53
|
"""
|
|
41
54
|
|
|
42
55
|
name = "ollama"
|
|
43
|
-
defaults = {
|
|
56
|
+
defaults = {
|
|
57
|
+
"models": [],
|
|
58
|
+
"backend": "local",
|
|
59
|
+
"force_pull": False,
|
|
60
|
+
"cache_update_policy": "auto", # "auto", "force", "never"
|
|
61
|
+
"force_cache_update": False, # Simple override for "force"
|
|
62
|
+
"debug": False,
|
|
63
|
+
"circuit_breaker_config": {
|
|
64
|
+
"failure_threshold": 3,
|
|
65
|
+
"recovery_timeout": 60,
|
|
66
|
+
"reset_timeout": 30,
|
|
67
|
+
},
|
|
68
|
+
"timeout_config": {
|
|
69
|
+
"pull": 600, # 10 minutes for model pulls
|
|
70
|
+
"stop": 30, # 30 seconds for model stops
|
|
71
|
+
"health_check": 5, # 5 seconds for health checks
|
|
72
|
+
"install": 60, # 1 minute for Ollama installation
|
|
73
|
+
"server_startup": 300, # 5 minutes for server startup
|
|
74
|
+
},
|
|
75
|
+
"card_refresh_interval": 10, # seconds - how often to update the status card
|
|
76
|
+
}
|
|
77
|
+
|
|
78
|
+
def step_init(
    self, flow, graph, step_name, decorators, environment, flow_datastore, logger
):
    """Run the parent initialisation, then wire up the Ollama status card.

    Keeps a reference to the flow datastore's ``_storage_impl`` on the
    decorator and attaches a card decorator with id ``"ollama_status"``
    whose refresh interval comes from the ``card_refresh_interval``
    attribute.
    """
    super().step_init(
        flow, graph, step_name, decorators, environment, flow_datastore, logger
    )
    self.flow_datastore_backend = flow_datastore._storage_impl

    # Attach the ollama status card
    card_options = {"refresh_interval": self.attributes["card_refresh_interval"]}
    self.attach_card_decorator(
        flow, step_name, "ollama_status", "blank", **card_options
    )
|
|
44
94
|
|
|
45
95
|
def task_decorate(
    self, step_func, flow, graph, retry_count, max_user_code_retries, ubf_context
):
    """Wrap ``step_func`` so Ollama is set up before it and torn down after.

    The returned wrapper:

    1. creates the status card and starts a daemon thread that refreshes it
       every ``card_refresh_interval`` seconds,
    2. builds the ``OllamaManager`` and installs the request interceptor,
    3. runs the user step function,
    4. in ``finally``: removes the interceptor, terminates the models,
       renders one last card update and signals the monitor thread to stop.

    Any error during step 1-2 is reported on the card (best effort),
    printed, and re-raised.
    """
    # Imported once here rather than on every monitor-loop iteration.
    import time

    @functools.wraps(step_func)
    def ollama_wrapper():
        self.ollama_manager = None
        self.request_interceptor = None
        self.status_card = None
        self.card_monitor_thread = None

        try:
            # Initialize status card and monitoring
            self.status_card = OllamaStatusCard(
                refresh_interval=self.attributes["card_refresh_interval"]
            )

            # Start card monitoring in background
            def monitor_card():
                try:
                    self.status_card.on_startup(current.card["ollama_status"])

                    # The stop flag is a plain attribute set on the thread
                    # object by the wrapper's ``finally`` block.
                    while not getattr(
                        self.card_monitor_thread, "_stop_event", False
                    ):
                        try:
                            # Trigger card update with current data
                            self.status_card.on_update(
                                current.card["ollama_status"], None
                            )
                            time.sleep(self.attributes["card_refresh_interval"])
                        except Exception as e:
                            if self.attributes["debug"]:
                                print(f"[@ollama] Card monitoring error: {e}")
                            break
                except Exception as e:
                    if self.attributes["debug"]:
                        print(f"[@ollama] Card monitor thread error: {e}")
                    self.status_card.on_error(current.card["ollama_status"], str(e))

            self.card_monitor_thread = threading.Thread(
                target=monitor_card, daemon=True
            )
            self.card_monitor_thread._stop_event = False
            self.card_monitor_thread.start()

            # Initialize OllamaManager with status card
            self.ollama_manager = OllamaManager(
                models=self.attributes["models"],
                backend=self.attributes["backend"],
                flow_datastore_backend=self.flow_datastore_backend,
                force_pull=self.attributes["force_pull"],
                cache_update_policy=self.attributes["cache_update_policy"],
                force_cache_update=self.attributes["force_cache_update"],
                debug=self.attributes["debug"],
                circuit_breaker_config=self.attributes["circuit_breaker_config"],
                timeout_config=self.attributes["timeout_config"],
                status_card=self.status_card,
            )

            # Install request protection by monkey-patching ollama package
            self.request_interceptor = OllamaRequestInterceptor(
                self.ollama_manager.circuit_breaker, self.attributes["debug"]
            )
            self.request_interceptor.install_protection()

            if self.attributes["debug"]:
                print(
                    "[@ollama] OllamaManager initialized and request protection installed"
                )

        except Exception as e:
            if self.status_card:
                self.status_card.add_event(
                    "error", f"Initialization failed: {str(e)}"
                )
                try:
                    self.status_card.on_error(current.card["ollama_status"], str(e))
                except Exception:
                    # Card rendering is best effort; the original error is
                    # re-raised below. ``Exception`` (not a bare except) so
                    # KeyboardInterrupt/SystemExit are not swallowed here.
                    pass
            print(f"[@ollama] Error initializing OllamaManager: {e}")
            raise

        try:
            if self.status_card:
                self.status_card.add_event("info", "Starting user step function")
            step_func()
            if self.status_card:
                self.status_card.add_event(
                    "success", "User step function completed successfully"
                )
        finally:
            # Remove request protection first (before terminating models)
            if self.request_interceptor:
                self.request_interceptor.remove_protection()
                if self.attributes["debug"]:
                    print("[@ollama] Request protection removed")

            # Then cleanup ollama manager (while card monitoring is still active)
            if self.ollama_manager:
                self.ollama_manager.terminate_models()

            # Give the card a moment to render the final shutdown events
            if self.card_monitor_thread and self.status_card:
                # Trigger one final card update to capture all shutdown events
                try:
                    self.status_card.on_update(current.card["ollama_status"], None)
                except Exception as e:
                    if self.attributes["debug"]:
                        print(f"[@ollama] Final card update error: {e}")
                time.sleep(2)  # Allow final events to be rendered

            # Now stop card monitoring
            if self.card_monitor_thread:
                self.card_monitor_thread._stop_event = True

            if self.ollama_manager and self.attributes["debug"]:
                print(
                    f"[@ollama] process statuses: {self.ollama_manager.processes}"
                )
                print(
                    f"[@ollama] process runtime stats: {self.ollama_manager.stats}"
                )
                print(
                    f"[@ollama] Circuit Breaker status: {self.ollama_manager.circuit_breaker.get_status()}"
                )

    return ollama_wrapper
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
OLLAMA_SUFFIX = "mf.ollama"
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from metaflow.exception import MetaflowException
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class UnspecifiedRemoteStorageRootException(MetaflowException):
    """Metaflow exception shown under the headline "Storage root not specified."."""

    headline = "Storage root not specified."

    def __init__(self, message):
        # Hand the message straight to MetaflowException.
        super().__init__(message)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class EmptyOllamaManifestCacheException(MetaflowException):
    """Metaflow exception shown under the headline "Model not found."."""

    headline = "Model not found."

    def __init__(self, message):
        # Hand the message straight to MetaflowException.
        super().__init__(message)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class EmptyOllamaBlobCacheException(MetaflowException):
    """Metaflow exception shown under the headline "Blob not found."."""

    headline = "Blob not found."

    def __init__(self, message):
        # Hand the message straight to MetaflowException.
        super().__init__(message)
|