ob-metaflow-extensions 1.1.156__py2.py3-none-any.whl → 1.1.158__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

metaflow_extensions/outerbounds/plugins/nim/nim_manager.py

@@ -186,18 +186,61 @@ class NimManager(object):
         ]
         self.models = {}

-        for each_model in models:
-            if each_model in nvcf_models:
-                self.models[each_model] = NimChatCompletion(
-                    model=each_model,
+        # Convert models to a standard format
+        standardized_models = []
+        # If models is a single string, convert it to a list with a dict
+        if isinstance(models, str):
+            standardized_models = [{"name": models}]
+        # If models is a list, process each item
+        elif isinstance(models, list):
+            for model_item in models:
+                # If the item is a string, convert it to a dict
+                if isinstance(model_item, str):
+                    standardized_models.append({"name": model_item})
+                # If it's already a dict, use it as is
+                elif isinstance(model_item, dict):
+                    standardized_models.append(model_item)
+                else:
+                    raise ValueError(
+                        f"Model specification must be a string or dictionary, got {type(model_item)}"
+                    )
+        else:
+            raise ValueError(
+                f"Models must be a string or a list of strings/dictionaries, got {type(models)}"
+            )
+
+        # Process each standardized model
+        for each_model_dict in standardized_models:
+            model_name = each_model_dict.get("name", "")
+            nvcf_id = each_model_dict.get("nvcf_id", "")
+            nvcf_version = each_model_dict.get("nvcf_version", "")
+
+            if model_name and not (nvcf_id and nvcf_version):
+                if model_name in nvcf_models:
+                    self.models[model_name] = NimChatCompletion(
+                        model=model_name,
+                        nvcf_id=nvcf_id,
+                        nvcf_version=nvcf_version,
+                        nim_metadata=nim_metadata,
+                        monitor=monitor,
+                    )
+                else:
+                    raise ValueError(
+                        f"Model {model_name} not supported by the Outerbounds @nim offering."
+                        f"\nYou can choose from these options: {nvcf_models}\n\n"
+                        "Reach out to Outerbounds if there are other models you'd like supported."
+                    )
+            elif nvcf_id and nvcf_version:
+                self.models[model_name] = NimChatCompletion(
+                    model=model_name,
+                    nvcf_id=nvcf_id,
+                    nvcf_version=nvcf_version,
                     nim_metadata=nim_metadata,
                     monitor=monitor,
                 )
             else:
                 raise ValueError(
-                    f"Model {each_model} not supported by the Outerbounds @nim offering."
-                    f"\nYou can choose from these options: {nvcf_models}\n\n"
-                    "Reach out to Outerbounds if there are other models you'd like supported."
+                    "You must provide either a valid 'name' or a custom 'name' along with both 'nvcf_id' and 'nvcf_version'."
                 )

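The hunk above normalizes the `models` attribute before building `NimChatCompletion` objects. A minimal sketch of the three spec shapes it now accepts (all values below are placeholders, not taken from the package):

# Accepted shapes for the @nim `models` attribute after this change (placeholder values).
models_as_string = "meta/llama3-8b-instruct"        # normalized to [{"name": "meta/llama3-8b-instruct"}]
models_by_name = ["meta/llama3-8b-instruct", "meta/llama3-70b-instruct"]
models_with_custom_nvcf_function = [
    {
        "name": "my-finetuned-model",      # hypothetical model name
        "nvcf_id": "FUNCTION_ID",          # placeholder; must be supplied together with nvcf_version
        "nvcf_version": "VERSION_ID",      # placeholder
    }
]

Dicts that carry both `nvcf_id` and `nvcf_version` bypass the `nvcf_models` allow-list and are passed straight to `NimChatCompletion`.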
@@ -205,6 +248,8 @@ class NimChatCompletion(object):
     def __init__(
         self,
         model: str = "meta/llama3-8b-instruct",
+        nvcf_id: str = "",
+        nvcf_version: str = "",
         nim_metadata: NimMetadata = None,
         monitor: bool = False,
         **kwargs,

@@ -217,18 +262,34 @@ class NimChatCompletion(object):
         self.model_name = model
         self.nim_metadata = nim_metadata
         self.monitor = monitor
-
         all_nvcf_models = self.nim_metadata.get_nvcf_chat_completion_models()
-        all_nvcf_model_names = [m["name"] for m in all_nvcf_models]

-        if self.model_name not in all_nvcf_model_names:
-            raise ValueError(
-                f"Model {self.model_name} not found in available NVCF models"
-            )
+        if nvcf_id and nvcf_version:
+            matching_models = [
+                m
+                for m in all_nvcf_models
+                if m["function-id"] == nvcf_id and m["version-id"] == nvcf_version
+            ]
+            if matching_models:
+                self.model = matching_models[0]
+                self.function_id = self.model["function-id"]
+                self.version_id = self.model["version-id"]
+                self.model_name = self.model["name"]
+            else:
+                raise ValueError(
+                    f"Function {self.function_id} with version {self.version_id} not found on NVCF"
+                )
+        else:
+            all_nvcf_model_names = [m["name"] for m in all_nvcf_models]
+
+            if self.model_name not in all_nvcf_model_names:
+                raise ValueError(
+                    f"Model {self.model_name} not found in available NVCF models"
+                )

-        self.model = all_nvcf_models[all_nvcf_model_names.index(self.model_name)]
-        self.function_id = self.model["function-id"]
-        self.version_id = self.model["version-id"]
+            self.model = all_nvcf_models[all_nvcf_model_names.index(self.model_name)]
+            self.function_id = self.model["function-id"]
+            self.version_id = self.model["version-id"]

         self.first_request = True

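With the two hunks above, `NimChatCompletion.__init__` resolves a model either by name or by an explicit NVCF function/version pair. A minimal sketch of the matching branch, assuming entries from `get_nvcf_chat_completion_models()` carry the "name", "function-id", and "version-id" keys used in the diff (the ids below are invented):

all_nvcf_models = [
    {"name": "meta/llama3-8b-instruct", "function-id": "fn-123", "version-id": "v1"},
    {"name": "my-finetuned-model", "function-id": "fn-999", "version-id": "v7"},
]
nvcf_id, nvcf_version = "fn-999", "v7"
matching_models = [
    m
    for m in all_nvcf_models
    if m["function-id"] == nvcf_id and m["version-id"] == nvcf_version
]
assert matching_models and matching_models[0]["name"] == "my-finetuned-model"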
metaflow_extensions/outerbounds/plugins/ollama/__init__.py

@@ -1,6 +1,7 @@
 from metaflow.decorators import StepDecorator
 from metaflow import current
 import functools
+import os

 from .ollama import OllamaManager
 from ..card_utilities.injector import CardDecoratorInjector

@@ -13,10 +14,10 @@ class OllamaDecorator(StepDecorator, CardDecoratorInjector):
     This decorator is used to run Ollama APIs as Metaflow task sidecars.

     User code call
-    -----------
+    --------------
     @ollama(
-        models=['meta/llama3-8b-instruct', 'meta/llama3-70b-instruct'],
-        backend='local'
+        models=[...],
+        ...
     )

     Valid backend options

@@ -26,21 +27,39 @@ class OllamaDecorator(StepDecorator, CardDecoratorInjector):
     - (TODO) 'remote': Spin up separate instance to serve Ollama models.

     Valid model options
-    ----------------
-    - 'llama3.2'
-    - 'llama3.3'
-    - any model here https://ollama.com/search
+    -------------------
+    Any model here https://ollama.com/search, e.g. 'llama3.2', 'llama3.3'

     Parameters
     ----------
-    models: list[Ollama]
+    models: list[str]
         List of Ollama containers running models in sidecars.
     backend: str
         Determines where and how to run the Ollama process.
+    force_pull: bool
+        Whether to run `ollama pull` no matter what, or first check the remote cache in Metaflow datastore for this model key.
+    skip_push_check: bool
+        Whether to skip the check that populates/overwrites remote cache on terminating an ollama model.
+    debug: bool
+        Whether to turn on verbose debugging logs.
     """

     name = "ollama"
-    defaults = {"models": [], "backend": "local", "debug": False}
+    defaults = {
+        "models": [],
+        "backend": "local",
+        "force_pull": False,
+        "skip_push_check": False,
+        "debug": False,
+    }
+
+    def step_init(
+        self, flow, graph, step_name, decorators, environment, flow_datastore, logger
+    ):
+        super().step_init(
+            flow, graph, step_name, decorators, environment, flow_datastore, logger
+        )
+        self.flow_datastore_backend = flow_datastore._storage_impl

     def task_decorate(
         self, step_func, flow, graph, retry_count, max_user_code_retries, ubf_context

@@ -51,6 +70,9 @@ class OllamaDecorator(StepDecorator, CardDecoratorInjector):
            self.ollama_manager = OllamaManager(
                models=self.attributes["models"],
                backend=self.attributes["backend"],
+               flow_datastore_backend=self.flow_datastore_backend,
+               force_pull=self.attributes["force_pull"],
+               skip_push_check=self.attributes["skip_push_check"],
                debug=self.attributes["debug"],
            )
        except Exception as e:

@@ -59,10 +81,7 @@ class OllamaDecorator(StepDecorator, CardDecoratorInjector):
        try:
            step_func()
        finally:
-           try:
-               self.ollama_manager.terminate_models()
-           except Exception as term_e:
-               print(f"[@ollama] Error during sidecar termination: {term_e}")
+           self.ollama_manager.terminate_models()
            if self.attributes["debug"]:
                print(f"[@ollama] process statuses: {self.ollama_manager.processes}")
                print(f"[@ollama] process runtime stats: {self.ollama_manager.stats}")
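Taken together, the decorator hunks above add `force_pull` and `skip_push_check` attributes and hand the flow datastore to `OllamaManager` via the new `step_init`. A hedged usage sketch; the `from metaflow import ollama` import path is assumed from the toplevel plugin listing in RECORD, so verify it against your installation:

from metaflow import FlowSpec, step, ollama  # assumed import path for @ollama

class OllamaCacheFlow(FlowSpec):

    @ollama(
        models=["llama3.2"],      # any model from https://ollama.com/search
        backend="local",
        force_pull=False,         # reuse the remote cache in the Metaflow datastore when present
        skip_push_check=False,    # push/update the cache when the sidecar terminates
        debug=True,
    )
    @step
    def start(self):
        # user code talks to the Ollama sidecar on localhost:11434 here
        self.next(self.end)

    @step
    def end(self):
        pass

if __name__ == "__main__":
    OllamaCacheFlow()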
metaflow_extensions/outerbounds/plugins/ollama/constants.py (new file)

@@ -0,0 +1 @@
+OLLAMA_SUFFIX = "mf.ollama"

metaflow_extensions/outerbounds/plugins/ollama/exceptions.py (new file)

@@ -0,0 +1,22 @@
+from metaflow.exception import MetaflowException
+
+
+class UnspecifiedRemoteStorageRootException(MetaflowException):
+    headline = "Storage root not specified."
+
+    def __init__(self, message):
+        super(UnspecifiedRemoteStorageRootException, self).__init__(message)
+
+
+class EmptyOllamaManifestCacheException(MetaflowException):
+    headline = "Model not found."
+
+    def __init__(self, message):
+        super(EmptyOllamaManifestCacheException, self).__init__(message)
+
+
+class EmptyOllamaBlobCacheException(MetaflowException):
+    headline = "Blob not found."
+
+    def __init__(self, message):
+        super(EmptyOllamaBlobCacheException, self).__init__(message)
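These are plain `MetaflowException` subclasses that `ollama.py` (below) raises and catches to signal cache misses; a cache miss ultimately falls back to `ollama pull` in `_pull_model`. Illustration only:

from metaflow_extensions.outerbounds.plugins.ollama.exceptions import EmptyOllamaManifestCacheException

try:
    raise EmptyOllamaManifestCacheException("No manifest in remote storage for model llama3.2")
except EmptyOllamaManifestCacheException as exc:
    print(f"cache miss, falling back to 'ollama pull': {exc}")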
metaflow_extensions/outerbounds/plugins/ollama/ollama.py

@@ -5,6 +5,15 @@ import socket
 import sys
 import os
 import functools
+import json
+import requests
+
+from .constants import OLLAMA_SUFFIX
+from .exceptions import (
+    EmptyOllamaManifestCacheException,
+    EmptyOllamaBlobCacheException,
+    UnspecifiedRemoteStorageRootException,
+)


 class ProcessStatus:

@@ -14,17 +23,40 @@ class ProcessStatus:


 class OllamaManager:
-
     """
     A process manager for Ollama runtimes.
-    This is run locally, e.g., whether @ollama has a local, remote, or managed backend.
+    Implements interface @ollama([models=...], ...) has a local, remote, or managed backend.
     """

-    def __init__(self, models, backend="local", debug=False):
+    def __init__(
+        self,
+        models,
+        backend="local",
+        flow_datastore_backend=None,
+        remote_storage_root=None,
+        force_pull=False,
+        skip_push_check=False,
+        debug=False,
+    ):
         self.models = {}
         self.processes = {}
+        self.flow_datastore_backend = flow_datastore_backend
+        if self.flow_datastore_backend is not None:
+            self.remote_storage_root = self.get_ollama_storage_root(
+                self.flow_datastore_backend
+            )
+        elif remote_storage_root is not None:
+            self.remote_storage_root = remote_storage_root
+        else:
+            raise UnspecifiedRemoteStorageRootException(
+                "Can not determine the storage root, as both flow_datastore_backend and remote_storage_root arguments of OllamaManager are None."
+            )
+        self.force_pull = force_pull
+        self.skip_push_check = skip_push_check
         self.debug = debug
         self.stats = {}
+        self.storage_info = {}
+        self.ollama_url = "http://localhost:11434"  # Ollama API base URL

         if backend != "local":
             raise ValueError(

@@ -41,7 +73,7 @@ class OllamaManager:
                 try:
                     future.result()
                 except Exception as e:
-                    raise RuntimeError(f"Error pulling one or more models: {e}") from e
+                    raise RuntimeError(f"Error pulling one or more models. {e}") from e

         # Run models as background processes.
         for m in models:

@@ -55,22 +87,24 @@ class OllamaManager:
            self.stats[name] = {"process_runtime": tf - t0}

     def _install_ollama(self, max_retries=3):
-
         try:
             result = subprocess.run(["which", "ollama"], capture_output=True, text=True)
             if result.returncode == 0:
-                if self.debug:
-                    print("[@ollama] is already installed.")
+                print("[@ollama] Ollama is already installed.")
                 return
         except Exception as e:
-            print("[@ollama] Did not find Ollama installation: %s" % e)
+            if self.debug:
+                print(f"[@ollama] Did not find Ollama installation: {e}")
         if sys.platform == "darwin":
             raise RuntimeError(
-                "On macOS, please install Ollama manually from https://ollama.com/download"
+                "On macOS, please install Ollama manually from https://ollama.com/download."
             )

+        if self.debug:
+            print("[@ollama] Installing Ollama...")
         env = os.environ.copy()
         env["CURL_IPRESOLVE"] = "4"
+
         for attempt in range(max_retries):
             try:
                 install_cmd = ["curl", "-fsSL", "https://ollama.com/install.sh"]

@@ -93,10 +127,11 @@ class OllamaManager:
                         f"Ollama installation script failed: stdout: {sh_proc.stdout}, stderr: {sh_proc.stderr}"
                     )
                 if self.debug:
-                    print("[@ollama] Installed successfully.")
-                    break
+                    print("[@ollama] Ollama installed successfully.")
+                break
             except Exception as e:
-                print(f"Installation attempt {attempt+1} failed: {e}")
+                if self.debug:
+                    print(f"[@ollama] Installation attempt {attempt+1} failed: {e}")
                 if attempt < max_retries - 1:
                     time.sleep(5)
                 else:

@@ -117,11 +152,9 @@ class OllamaManager:
     def _launch_server(self):
         """
         Start the Ollama server process and ensure it's running.
-        This version waits until the server is listening on port 11434.
         """
         try:
-            if self.debug:
-                print("[@ollama] Starting Ollama server...")
+            print("[@ollama] Starting Ollama server...")
             process = subprocess.Popen(
                 ["ollama", "serve"],
                 stdout=subprocess.PIPE,

@@ -133,23 +166,22 @@ class OllamaManager:
                 "properties": {"type": "api-server", "error_details": None},
                 "status": ProcessStatus.RUNNING,
             }
+
             if self.debug:
-                print(
-                    "[@ollama] Started Ollama server process with PID %s" % process.pid
-                )
+                print(f"[@ollama] Started server process with PID {process.pid}.")

-            # Wait until the server is ready (listening on 127.0.0.1:11434)
+            # Wait until the server is ready
             host, port = "127.0.0.1", 11434
             retries = 0
             max_retries = 10
             while (
                 not self._is_port_open(host, port, timeout=1) and retries < max_retries
             ):
-                print(
-                    "[@ollama] Waiting for server to be ready... (%d/%d)"
-                    % (retries + 1, max_retries)
-                )
-                time.sleep(1)
+                if retries == 0:
+                    print("[@ollama] Waiting for server to be ready...")
+                elif retries % 3 == 0:
+                    print(f"[@ollama] Still waiting... ({retries + 1}/{max_retries})")
+                time.sleep(5)
                 retries += 1

             if not self._is_port_open(host, port, timeout=1):

@@ -162,7 +194,7 @@ class OllamaManager:
                 self.processes[process.pid]["status"] = ProcessStatus.FAILED
                 raise RuntimeError(f"Ollama server failed to start. {error_details}")

-            # Check if the process has unexpectedly terminated
+            # Check if process terminated unexpectedly
             returncode = process.poll()
             if returncode is not None:
                 stdout, stderr = process.communicate()

@@ -181,21 +213,384 @@ class OllamaManager:
             self.processes[process.pid]["properties"]["error_details"] = str(e)
             raise RuntimeError(f"Error starting Ollama server: {e}") from e

-    def _pull_model(self, m):
+    def _setup_storage(self, m):
+        """
+        Configure local and remote storage paths for an Ollama model.
+        """
+        # Parse model and tag name
+        ollama_model_name_components = m.split(":")
+        if len(ollama_model_name_components) == 1:
+            model_name = ollama_model_name_components[0]
+            tag = "latest"
+        elif len(ollama_model_name_components) == 2:
+            model_name = ollama_model_name_components[0]
+            tag = ollama_model_name_components[1]
+
+        # Find where Ollama actually stores models
+        possible_storage_roots = [
+            os.environ.get("OLLAMA_MODELS"),
+            "/usr/share/ollama/.ollama/models",
+            os.path.expanduser("~/.ollama/models"),
+            "/root/.ollama/models",
+        ]
+
+        ollama_local_storage_root = None
+        for root in possible_storage_roots:
+            if root and os.path.exists(root):
+                ollama_local_storage_root = root
+                break
+
+        if not ollama_local_storage_root:
+            # https://github.com/ollama/ollama/blob/main/docs/faq.md#where-are-models-stored
+            if sys.platform.startswith("linux"):
+                ollama_local_storage_root = "/usr/share/ollama/.ollama/models"
+            elif sys.platform == "darwin":
+                ollama_local_storage_root = os.path.expanduser("~/.ollama/models")
+
+        if self.debug:
+            print(
+                f"[@ollama {m}] Using Ollama storage root: {ollama_local_storage_root}."
+            )
+
+        blob_local_path = os.path.join(ollama_local_storage_root, "blobs")
+        manifest_base_path = os.path.join(
+            ollama_local_storage_root,
+            "manifests/registry.ollama.ai/library",
+            model_name,
+        )
+
+        # Create directories
         try:
+            os.makedirs(blob_local_path, exist_ok=True)
+            os.makedirs(manifest_base_path, exist_ok=True)
+        except FileExistsError:
+            pass
+
+        # Set up remote paths
+        if not self.local_datastore and self.remote_storage_root is not None:
+            blob_remote_key = os.path.join(self.remote_storage_root, "blobs")
+            manifest_remote_key = os.path.join(
+                self.remote_storage_root,
+                "manifests/registry.ollama.ai/library",
+                model_name,
+                tag,
+            )
+        else:
+            blob_remote_key = None
+            manifest_remote_key = None
+
+        self.storage_info[m] = {
+            "blob_local_root": blob_local_path,
+            "blob_remote_root": blob_remote_key,
+            "manifest_local": os.path.join(manifest_base_path, tag),
+            "manifest_remote": manifest_remote_key,
+            "manifest_content": None,
+            "model_name": model_name,
+            "tag": tag,
+            "storage_root": ollama_local_storage_root,
+        }
+
+        if self.debug:
+            print(f"[@ollama {m}] Storage paths configured.")
+
+    def _fetch_manifest(self, m):
+        """
+        Load the manifest file and content, either from local storage or remote cache.
+        """
+        if self.debug:
+            print(f"[@ollama {m}] Checking for cached manifest...")
+
+        def _disk_to_memory():
+            with open(self.storage_info[m]["manifest_local"], "r") as f:
+                self.storage_info[m]["manifest_content"] = json.load(f)
+
+        if os.path.exists(self.storage_info[m]["manifest_local"]):
+            if self.storage_info[m]["manifest_content"] is None:
+                _disk_to_memory()
             if self.debug:
-                print("[@ollama] Pulling model: %s" % m)
-            result = subprocess.run(
-                ["ollama", "pull", m], capture_output=True, text=True
+                print(f"[@ollama {m}] Manifest found locally.")
+        elif self.local_datastore:
+            if self.debug:
+                print(f"[@ollama {m}] No manifest found in local datastore.")
+            return None
+        else:
+            from metaflow import S3
+            from metaflow.plugins.datatools.s3.s3 import MetaflowS3NotFound
+
+            try:
+                with S3() as s3:
+                    s3obj = s3.get(self.storage_info[m]["manifest_remote"])
+                    if not s3obj.exists:
+                        raise EmptyOllamaManifestCacheException(
+                            f"No manifest in remote storage for model {m}"
+                        )
+
+                    if self.debug:
+                        print(f"[@ollama {m}] Downloaded manifest from cache.")
+                    os.rename(s3obj.path, self.storage_info[m]["manifest_local"])
+                    _disk_to_memory()
+
+                    if self.debug:
+                        print(
+                            f"[@ollama {m}] Manifest found in remote cache, downloaded locally."
+                        )
+            except (MetaflowS3NotFound, EmptyOllamaManifestCacheException):
+                if self.debug:
+                    print(
+                        f"[@ollama {m}] No manifest found locally or in remote cache."
+                    )
+                return None
+
+        return self.storage_info[m]["manifest_content"]
+
+    def _fetch_blobs(self, m):
+        """
+        Fetch missing blobs from remote cache.
+        """
+        if self.debug:
+            print(f"[@ollama {m}] Checking for cached blobs...")
+
+        manifest = self._fetch_manifest(m)
+        if not manifest:
+            raise EmptyOllamaBlobCacheException(f"No manifest available for model {m}")
+
+        blobs_required = [layer["digest"] for layer in manifest["layers"]]
+        missing_blob_info = []
+
+        # Check which blobs are missing locally
+        for blob_digest in blobs_required:
+            blob_filename = blob_digest.replace(":", "-")
+            local_blob_path = os.path.join(
+                self.storage_info[m]["blob_local_root"], blob_filename
             )
-            if result.returncode != 0:
-                raise RuntimeError(
-                    f"Failed to pull model {m}: stdout: {result.stdout}, stderr: {result.stderr}"
+
+            if not os.path.exists(local_blob_path):
+                if self.debug:
+                    print(f"[@ollama {m}] Blob {blob_digest} not found locally.")
+
+                remote_blob_path = os.path.join(
+                    self.storage_info[m]["blob_remote_root"], blob_filename
+                )
+                missing_blob_info.append(
+                    {
+                        "digest": blob_digest,
+                        "filename": blob_filename,
+                        "remote_path": remote_blob_path,
+                        "local_path": local_blob_path,
+                    }
                 )
+
+        if not missing_blob_info:
+            if self.debug:
+                print(f"[@ollama {m}] All blobs found locally.")
+            return
+
+        if self.debug:
+            print(
+                f"[@ollama {m}] Downloading {len(missing_blob_info)} missing blobs from cache..."
+            )
+
+        remote_urls = [blob_info["remote_path"] for blob_info in missing_blob_info]
+
+        from metaflow import S3
+
+        try:
+            with S3() as s3:
+                if len(remote_urls) == 1:
+                    s3objs = [s3.get(remote_urls[0])]
+                else:
+                    s3objs = s3.get_many(remote_urls)
+
+                if not isinstance(s3objs, list):
+                    s3objs = [s3objs]
+
+                # Move each downloaded blob to correct location
+                for i, s3obj in enumerate(s3objs):
+                    if not s3obj.exists:
+                        blob_info = missing_blob_info[i]
+                        raise EmptyOllamaBlobCacheException(
+                            f"Blob {blob_info['digest']} not found in remote cache for model {m}"
+                        )
+
+                    blob_info = missing_blob_info[i]
+                    os.makedirs(os.path.dirname(blob_info["local_path"]), exist_ok=True)
+                    os.rename(s3obj.path, blob_info["local_path"])
+
+                    if self.debug:
+                        print(f"[@ollama {m}] Downloaded blob {blob_info['filename']}.")
+
+        except Exception as e:
+            if self.debug:
+                print(f"[@ollama {m}] Error during blob fetch: {e}")
+            raise EmptyOllamaBlobCacheException(
+                f"Failed to fetch blobs for model {m}: {e}"
+            )
+
+        if self.debug:
+            print(
+                f"[@ollama {m}] Successfully downloaded all missing blobs from cache."
+            )
+
+    def _verify_model_available(self, m):
+        """
+        Verify model is available using Ollama API
+        """
+        try:
+            response = requests.post(
+                f"{self.ollama_url}/api/show", json={"model": m}, timeout=10
+            )
+
+            available = response.status_code == 200
+
             if self.debug:
-                print("[@ollama] Model %s pulled successfully." % m)
+                if available:
+                    print(f"[@ollama {m}] ✓ Model is available via API.")
+                else:
+                    print(
+                        f"[@ollama {m}] ✗ Model not available via API (status: {response.status_code})."
+                    )
+
+            return available
+
         except Exception as e:
-            raise RuntimeError(f"Error pulling Ollama model {m}: {e}") from e
+            if self.debug:
+                print(f"[@ollama {m}] Error verifying model: {e}")
+            return False
+
+    def _register_cached_model_with_ollama(self, m):
+        """
+        Register a cached model with Ollama using the API.
+        """
+        try:
+            show_response = requests.post(
+                f"{self.ollama_url}/api/show", json={"model": m}, timeout=10
+            )
+
+            if show_response.status_code == 200:
+                if self.debug:
+                    print(f"[@ollama {m}] Model already registered with Ollama.")
+                return True
+
+            # Try to create/register the model from existing files
+            if self.debug:
+                print(f"[@ollama {m}] Registering cached model with Ollama...")
+
+            create_response = requests.post(
+                f"{self.ollama_url}/api/create",
+                json={
+                    "model": m,
+                    "from": m,  # Use same name - should find existing files
+                    "stream": False,
+                },
+                timeout=60,
+            )
+
+            if create_response.status_code == 200:
+                result = create_response.json()
+                if result.get("status") == "success":
+                    if self.debug:
+                        print(f"[@ollama {m}] Successfully registered cached model.")
+                    return True
+                else:
+                    if self.debug:
+                        print(f"[@ollama {m}] Create response: {result}.")
+
+            # Fallback: try a pull which should be fast if files exist
+            if self.debug:
+                print(f"[@ollama {m}] Create failed, trying pull to register...")
+
+            pull_response = requests.post(
+                f"{self.ollama_url}/api/pull",
+                json={"model": m, "stream": False},
+                timeout=120,
+            )
+
+            if pull_response.status_code == 200:
+                result = pull_response.json()
+                if result.get("status") == "success":
+                    if self.debug:
+                        print(f"[@ollama {m}] Model registered via pull.")
+                    return True
+
+        except requests.exceptions.RequestException as e:
+            if self.debug:
+                print(f"[@ollama {m}] API registration failed: {e}")
+        except Exception as e:
+            if self.debug:
+                print(f"[@ollama {m}] Error during registration: {e}")
+
+        return False
+
+    def _pull_model(self, m):
+        """
+        Pull/setup a model, using cache when possible.
+        """
+        self._setup_storage(m)
+
+        # Try to fetch manifest from cache first
+        manifest = None
+        try:
+            manifest = self._fetch_manifest(m)
+        except (EmptyOllamaManifestCacheException, Exception) as e:
+            if self.debug:
+                print(f"[@ollama {m}] No cached manifest found or error fetching: {e}")
+            manifest = None
+
+        # If we don't have a cached manifest or force_pull is True, pull the model
+        if self.force_pull or not manifest:
+            try:
+                print(f"[@ollama {m}] Not using cache. Downloading model {m}...")
+                result = subprocess.run(
+                    ["ollama", "pull", m], capture_output=True, text=True
+                )
+                if result.returncode != 0:
+                    raise RuntimeError(
+                        f"Failed to pull model {m}: stdout: {result.stdout}, stderr: {result.stderr}"
+                    )
+                print(f"[@ollama {m}] Model downloaded successfully.")
+            except Exception as e:
+                raise RuntimeError(f"Error pulling Ollama model {m}: {e}") from e
+        else:
+            # We have a cached manifest, try to fetch the blobs
+            try:
+                self._fetch_blobs(m)
+                print(f"[@ollama {m}] Using cached model.")
+
+                # Register the cached model with Ollama
+                if not self._verify_model_available(m):
+                    if not self._register_cached_model_with_ollama(m):
+                        raise RuntimeError(
+                            f"Failed to register cached model {m} with Ollama"
+                        )
+
+                # self.skip_push_check = True
+
+            except (EmptyOllamaBlobCacheException, Exception) as e:
+                if self.debug:
+                    print(f"[@ollama {m}] Cache failed, downloading model...")
+                    print(f"[@ollama {m}] Error: {e}")
+
+                # Fallback to pulling the model
+                try:
+                    result = subprocess.run(
+                        ["ollama", "pull", m], capture_output=True, text=True
+                    )
+                    if result.returncode != 0:
+                        raise RuntimeError(
+                            f"Failed to pull model {m}: stdout: {result.stdout}, stderr: {result.stderr}"
+                        )
+                    print(f"[@ollama {m}] Model downloaded successfully (fallback).")
+                except Exception as pull_e:
+                    raise RuntimeError(
+                        f"Error pulling Ollama model {m} as fallback: {pull_e}"
+                    ) from pull_e
+
+        # Final verification that the model is available
+        if not self._verify_model_available(m):
+            raise RuntimeError(f"Model {m} is not available to Ollama after setup")
+
+        if self.debug:
+            print(f"[@ollama {m}] Model setup complete and verified.")

     def _run_model(self, m):
         """

@@ -204,7 +599,8 @@ class OllamaManager:
         process = None
         try:
             if self.debug:
-                print("[@ollama] Running model: %s" % m)
+                print(f"[@ollama {m}] Starting model process...")
+
             process = subprocess.Popen(
                 ["ollama", "run", m],
                 stdout=subprocess.PIPE,

@@ -216,8 +612,9 @@ class OllamaManager:
                 "properties": {"type": "model", "model": m, "error_details": None},
                 "status": ProcessStatus.RUNNING,
             }
+
             if self.debug:
-                print("[@ollama] Stored process %s for model %s." % (process.pid, m))
+                print(f"[@ollama {m}] Model process PID: {process.pid}.")

             try:
                 process.wait(timeout=1)

@@ -231,8 +628,7 @@ class OllamaManager:
                 self.processes[process.pid]["status"] = ProcessStatus.SUCCESSFUL
                 if self.debug:
                     print(
-                        "[@ollama] Process %s for model %s exited successfully."
-                        % (process.pid, m)
+                        f"[@ollama {m}] Process {process.pid} exited successfully."
                     )
             else:
                 error_details = f"Return code: {returncode}, Error: {stderr}"

@@ -242,8 +638,7 @@ class OllamaManager:
                 self.processes[process.pid]["status"] = ProcessStatus.FAILED
                 if self.debug:
                     print(
-                        "[@ollama] Process %s for model %s failed: %s"
-                        % (process.pid, m, error_details)
+                        f"[@ollama {m}] Process {process.pid} failed: {error_details}."
                     )
         except Exception as e:
             if process and process.pid in self.processes:

@@ -251,20 +646,25 @@ class OllamaManager:
                 self.processes[process.pid]["properties"]["error_details"] = str(e)
             raise RuntimeError(f"Error running Ollama model {m}: {e}") from e

-    def terminate_models(self):
+    def terminate_models(self, skip_push_check=None):
         """
-        Terminate all processes gracefully.
-        First, stop model processes using 'ollama stop <model>'.
-        Then, shut down the API server process.
+        Terminate all processes gracefully and update cache.
         """
+        print("[@ollama] Shutting down models...")
+
+        if skip_push_check is not None:
+            assert isinstance(
+                skip_push_check, bool
+            ), "skip_push_check passed to terminate_models must be a bool if specified."
+            self.skip_push_check = skip_push_check

         for pid, process_info in list(self.processes.items()):
             if process_info["properties"].get("type") == "model":
                 model_name = process_info["properties"].get("model")
+
                 if self.debug:
-                    print(
-                        "[@ollama] Stopping model %s using 'ollama stop'" % model_name
-                    )
+                    print(f"[@ollama {model_name}] Stopping model process...")
+
                 try:
                     result = subprocess.run(
                         ["ollama", "stop", model_name], capture_output=True, text=True

@@ -272,28 +672,27 @@ class OllamaManager:
                     if result.returncode == 0:
                         process_info["status"] = ProcessStatus.SUCCESSFUL
                         if self.debug:
-                            print(
-                                "[@ollama] Model %s stopped successfully." % model_name
-                            )
+                            print(f"[@ollama {model_name}] Stopped successfully.")
                     else:
                         process_info["status"] = ProcessStatus.FAILED
                         if self.debug:
                             print(
-                                "[@ollama] Model %s failed to stop gracefully. Return code: %s, Error: %s"
-                                % (model_name, result.returncode, result.stderr)
+                                f"[@ollama {model_name}] Stop failed: {result.stderr}"
                             )
                 except Exception as e:
                     process_info["status"] = ProcessStatus.FAILED
-                    print("[@ollama] Error stopping model %s: %s" % (model_name, e))
+                    print(f"[@ollama {model_name}] Error stopping: {e}")

-        # Then, stop the API server
+                # Update cache if needed
+                if not self.skip_push_check:
+                    self._update_model_cache(model_name)
+
+        # Stop the API server
         for pid, process_info in list(self.processes.items()):
             if process_info["properties"].get("type") == "api-server":
                 if self.debug:
-                    print(
-                        "[@ollama] Stopping API server process with PID %s using process.terminate()"
-                        % pid
-                    )
+                    print(f"[@ollama] Stopping API server process PID {pid}.")
+
                 process = process_info["p"]
                 try:
                     process.terminate()

@@ -301,28 +700,114 @@ class OllamaManager:
                     process.wait(timeout=5)
                 except subprocess.TimeoutExpired:
                     print(
-                        "[@ollama] API server process %s did not terminate in time; killing it."
-                        % pid
+                        f"[@ollama] API server PID {pid} did not terminate, killing..."
                     )
                     process.kill()
                     process.wait()
-                returncode = process.poll()
-                if returncode is None or returncode != 0:
-                    process_info["status"] = ProcessStatus.FAILED
-                    print(
-                        "[@ollama] API server process %s terminated with error code %s."
-                        % (pid, returncode)
-                    )
-                else:
-                    process_info["status"] = ProcessStatus.SUCCESSFUL
-                    if self.debug:
-                        print(
-                            "[@ollama] API server process %s terminated successfully."
-                            % pid
-                        )
+
+                    process_info["status"] = ProcessStatus.SUCCESSFUL
+                    if self.debug:
+                        print(f"[@ollama] API server terminated successfully.")
                 except Exception as e:
                     process_info["status"] = ProcessStatus.FAILED
+                    print(f"[@ollama] Warning: Error terminating API server: {e}")
+
+        print("[@ollama] All models stopped.")
+
+        # Show performance summary
+        if self.debug:
+            if hasattr(self, "stats") and self.stats:
+                print("[@ollama] Performance summary:")
+                for operation, stats in self.stats.items():
+                    runtime = stats.get("process_runtime", 0)
+                    if runtime > 1:  # Only show operations that took meaningful time
+                        print(f"[@ollama] {operation}: {runtime:.1f}s")
+
+    def _update_model_cache(self, model_name):
+        """
+        Update the remote cache with model files if needed.
+        """
+        try:
+            manifest = self._fetch_manifest(model_name)
+            if not manifest:
+                if self.debug:
                     print(
-                        "[@ollama] Warning: Error while terminating API server process %s: %s"
-                        % (pid, e)
+                        f"[@ollama {model_name}] No manifest available for cache update."
                     )
+                return
+
+            from metaflow import S3
+
+            cache_up_to_date = True
+            key_paths = [
+                (
+                    self.storage_info[model_name]["manifest_remote"],
+                    self.storage_info[model_name]["manifest_local"],
+                )
+            ]
+
+            with S3() as s3:
+                # Check if blobs need updating
+                s3objs = s3.list_paths(
+                    [self.storage_info[model_name]["blob_remote_root"]]
+                )
+                for layer in manifest["layers"]:
+                    expected_blob_sha = layer["digest"]
+                    if expected_blob_sha not in s3objs:
+                        cache_up_to_date = False
+                        break
+
+                if not cache_up_to_date:
+                    blob_count = len(manifest.get("layers", []))
+                    print(
+                        f"[@ollama {model_name}] Uploading {blob_count} files to cache..."
+                    )
+
+                    # Add blob paths to upload
+                    for layer in manifest["layers"]:
+                        blob_filename = layer["digest"].replace(":", "-")
+                        key_paths.append(
+                            (
+                                os.path.join(
+                                    self.storage_info[model_name]["blob_remote_root"],
+                                    blob_filename,
+                                ),
+                                os.path.join(
+                                    self.storage_info[model_name]["blob_local_root"],
+                                    blob_filename,
+                                ),
+                            )
+                        )
+
+                    s3.put_files(key_paths)
+                    print(f"[@ollama {model_name}] Cache updated.")
+                else:
+                    if self.debug:
+                        print(f"[@ollama {model_name}] Cache is up to date.")
+
+        except Exception as e:
+            if self.debug:
+                print(f"[@ollama {model_name}] Error updating cache: {e}")
+
+    def get_ollama_storage_root(self, backend):
+        """
+        Return the path to the root of the datastore.
+        """
+        if backend.TYPE == "s3":
+            from metaflow.metaflow_config import DATASTORE_SYSROOT_S3
+
+            self.local_datastore = False
+            return os.path.join(DATASTORE_SYSROOT_S3, OLLAMA_SUFFIX)
+        elif backend.TYPE == "azure":
+            from metaflow.metaflow_config import DATASTORE_SYSROOT_AZURE
+
+            self.local_datastore = False
+            return os.path.join(DATASTORE_SYSROOT_AZURE, OLLAMA_SUFFIX)
+        elif backend.TYPE == "gs":
+            from metaflow.metaflow_config import DATASTORE_SYSROOT_GS
+
+            self.local_datastore = False
+            return os.path.join(DATASTORE_SYSROOT_GS, OLLAMA_SUFFIX)
+        else:
+            self.local_datastore = True
+            return None
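The caching added throughout `ollama.py` mirrors Ollama's on-disk layout inside the Metaflow datastore: manifests live under `.../mf.ollama/manifests/registry.ollama.ai/library/MODEL/TAG` and blobs under `.../mf.ollama/blobs/`, each blob named by its digest with ':' replaced by '-'. A hypothetical helper, not part of the package, that reproduces those keys for an S3 datastore root:

import os

OLLAMA_SUFFIX = "mf.ollama"  # matches constants.py above

def ollama_cache_keys(datastore_sysroot, model_spec):
    """Return (manifest_key, blob_prefix) for a spec like 'llama3.2' or 'llama3.2:latest'."""
    name, _, tag = model_spec.partition(":")
    tag = tag or "latest"
    root = os.path.join(datastore_sysroot, OLLAMA_SUFFIX)
    manifest_key = os.path.join(root, "manifests/registry.ollama.ai/library", name, tag)
    blob_prefix = os.path.join(root, "blobs")  # blob objects use the digest with ':' -> '-'
    return manifest_key, blob_prefix

# ollama_cache_keys("s3://my-bucket/metaflow", "llama3.2") ->
#   ("s3://my-bucket/metaflow/mf.ollama/manifests/registry.ollama.ai/library/llama3.2/latest",
#    "s3://my-bucket/metaflow/mf.ollama/blobs")

At shutdown, `terminate_models(skip_push_check=...)` can override the constructor's `skip_push_check` before `_update_model_cache` decides whether to upload these keys.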
METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ob-metaflow-extensions
-Version: 1.1.156
+Version: 1.1.158
 Summary: Outerbounds Platform Extensions for Metaflow
 Author: Outerbounds, Inc.
 License: Commercial

RECORD

@@ -25,7 +25,7 @@ metaflow_extensions/outerbounds/plugins/kubernetes/__init__.py,sha256=5zG8gShSj8
 metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py,sha256=fx_XUkgR4r6hF2ilDfT5LubRyVrYMVIv5f6clHkCaEk,5988
 metaflow_extensions/outerbounds/plugins/nim/card.py,sha256=dXOJvsZed5NyYyxYLPDvtwg9z_X4azL9HTJGYaiNriY,4690
 metaflow_extensions/outerbounds/plugins/nim/nim_decorator.py,sha256=50YVvC7mcZYlPluM0Wq1UtufhzlQb-RxzZkTOJJ3LkM,3439
-metaflow_extensions/outerbounds/plugins/nim/nim_manager.py,sha256=5YkohM-vfoDHPUMWb19sY0HErORoKOKf4jexERJTO80,10912
+metaflow_extensions/outerbounds/plugins/nim/nim_manager.py,sha256=y8U71106KJtrC6nlhsNnzX9Xkv3RnyZ1KEpRFwqZZFk,13686
 metaflow_extensions/outerbounds/plugins/nim/utils.py,sha256=nU-v1sheBjmITXfHiJx2ucm_Tq_nGb5BcuAm5c235cQ,1164
 metaflow_extensions/outerbounds/plugins/nvcf/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 metaflow_extensions/outerbounds/plugins/nvcf/constants.py,sha256=aGHdNw_hqBu8i0zWXcatQM6e769wUXox0l8g0f6fNZ8,146

@@ -42,8 +42,10 @@ metaflow_extensions/outerbounds/plugins/nvct/nvct_cli.py,sha256=bB9AURhRep9PV_-b
 metaflow_extensions/outerbounds/plugins/nvct/nvct_decorator.py,sha256=LaJ_Tk-vNjvrglzSTR-U6pk8f9MtQRKObU9m7vBYtkI,8695
 metaflow_extensions/outerbounds/plugins/nvct/nvct_runner.py,sha256=8IPkdvuTZNIqgAAt75gVNn-ydr-Zz2sKC8UX_6pNEKI,7091
 metaflow_extensions/outerbounds/plugins/nvct/utils.py,sha256=U4_Fu8H94j_Bbox7mmMhNnlRhlYHqnK28R5w_TMWEFM,1029
-metaflow_extensions/outerbounds/plugins/ollama/__init__.py,sha256=HEsI5U4ckQby7K2NsGBOdizhPY3WWqXSnXx_IHL7_No,2307
-metaflow_extensions/outerbounds/plugins/ollama/ollama.py,sha256=KlP8_EmnUoi8-PidyU0IDuENYxKjQaHFC33yGsvaeic,13320
+metaflow_extensions/outerbounds/plugins/ollama/__init__.py,sha256=vzh8sQEfwKRdx0fsGFJ-km4mwfi0vm2q1_vsZv-EMcc,3034
+metaflow_extensions/outerbounds/plugins/ollama/constants.py,sha256=hxkTpWEJp1pKHwUcG4EE3-17M6x2CyeMfbeqgUzF9TA,28
+metaflow_extensions/outerbounds/plugins/ollama/exceptions.py,sha256=8Ss296_MGZl1wXAoDNwpH-hsPe6iYLe90Ji1pczNocU,668
+metaflow_extensions/outerbounds/plugins/ollama/ollama.py,sha256=oe-k1ISSMtUF2y3YpfmJhU_3yR7SP31PVilN5NPgKv0,31450
 metaflow_extensions/outerbounds/plugins/profilers/deco_injector.py,sha256=oI_C3c64XBm7n88FILqHwn-Nnc5DeT_68I67lM9rXaI,2434
 metaflow_extensions/outerbounds/plugins/profilers/gpu_profile_decorator.py,sha256=gDHQ2sMIp4NuZSzUspbSd8RGdFAoO5mgZAyFcZ2a51Y,2619
 metaflow_extensions/outerbounds/plugins/secrets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0

@@ -68,7 +70,7 @@ metaflow_extensions/outerbounds/toplevel/plugins/gcp/__init__.py,sha256=BbZiaH3u
 metaflow_extensions/outerbounds/toplevel/plugins/kubernetes/__init__.py,sha256=5zG8gShSj8m7rgF4xgWBZFuY3GDP5n1T0ktjRpGJLHA,69
 metaflow_extensions/outerbounds/toplevel/plugins/ollama/__init__.py,sha256=GRSz2zwqkvlmFS6bcfYD_CX6CMko9DHQokMaH1iBshA,47
 metaflow_extensions/outerbounds/toplevel/plugins/snowflake/__init__.py,sha256=LptpH-ziXHrednMYUjIaosS1SXD3sOtF_9_eRqd8SJw,50
-ob_metaflow_extensions-1.1.156.dist-info/METADATA,sha256=G9c19j9g0v8dDQU5sP5Zaaub2fot__EMCJ6iBQBb4Qo,521
-ob_metaflow_extensions-1.1.156.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
-ob_metaflow_extensions-1.1.156.dist-info/top_level.txt,sha256=NwG0ukwjygtanDETyp_BUdtYtqIA_lOjzFFh1TsnxvI,20
-ob_metaflow_extensions-1.1.156.dist-info/RECORD,,
+ob_metaflow_extensions-1.1.158.dist-info/METADATA,sha256=0t_P8-Uhi3I39xyeSGv2BpRQO5Upe1eIjs04e6Stjd8,521
+ob_metaflow_extensions-1.1.158.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
+ob_metaflow_extensions-1.1.158.dist-info/top_level.txt,sha256=NwG0ukwjygtanDETyp_BUdtYtqIA_lOjzFFh1TsnxvI,20
+ob_metaflow_extensions-1.1.158.dist-info/RECORD,,