ob-metaflow-extensions 1.1.45rc3__py2.py3-none-any.whl → 1.5.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (128)
  1. metaflow_extensions/outerbounds/__init__.py +1 -7
  2. metaflow_extensions/outerbounds/config/__init__.py +35 -0
  3. metaflow_extensions/outerbounds/plugins/__init__.py +186 -57
  4. metaflow_extensions/outerbounds/plugins/apps/__init__.py +0 -0
  5. metaflow_extensions/outerbounds/plugins/apps/app_cli.py +0 -0
  6. metaflow_extensions/outerbounds/plugins/apps/app_utils.py +187 -0
  7. metaflow_extensions/outerbounds/plugins/apps/consts.py +3 -0
  8. metaflow_extensions/outerbounds/plugins/apps/core/__init__.py +15 -0
  9. metaflow_extensions/outerbounds/plugins/apps/core/_state_machine.py +506 -0
  10. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/__init__.py +0 -0
  11. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/__init__.py +4 -0
  12. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/spinners.py +478 -0
  13. metaflow_extensions/outerbounds/plugins/apps/core/app_config.py +128 -0
  14. metaflow_extensions/outerbounds/plugins/apps/core/app_deploy_decorator.py +330 -0
  15. metaflow_extensions/outerbounds/plugins/apps/core/artifacts.py +0 -0
  16. metaflow_extensions/outerbounds/plugins/apps/core/capsule.py +958 -0
  17. metaflow_extensions/outerbounds/plugins/apps/core/click_importer.py +24 -0
  18. metaflow_extensions/outerbounds/plugins/apps/core/code_package/__init__.py +3 -0
  19. metaflow_extensions/outerbounds/plugins/apps/core/code_package/code_packager.py +618 -0
  20. metaflow_extensions/outerbounds/plugins/apps/core/code_package/examples.py +125 -0
  21. metaflow_extensions/outerbounds/plugins/apps/core/config/__init__.py +15 -0
  22. metaflow_extensions/outerbounds/plugins/apps/core/config/cli_generator.py +165 -0
  23. metaflow_extensions/outerbounds/plugins/apps/core/config/config_utils.py +966 -0
  24. metaflow_extensions/outerbounds/plugins/apps/core/config/schema_export.py +299 -0
  25. metaflow_extensions/outerbounds/plugins/apps/core/config/typed_configs.py +233 -0
  26. metaflow_extensions/outerbounds/plugins/apps/core/config/typed_init_generator.py +537 -0
  27. metaflow_extensions/outerbounds/plugins/apps/core/config/unified_config.py +1125 -0
  28. metaflow_extensions/outerbounds/plugins/apps/core/config_schema.yaml +337 -0
  29. metaflow_extensions/outerbounds/plugins/apps/core/dependencies.py +115 -0
  30. metaflow_extensions/outerbounds/plugins/apps/core/deployer.py +959 -0
  31. metaflow_extensions/outerbounds/plugins/apps/core/experimental/__init__.py +89 -0
  32. metaflow_extensions/outerbounds/plugins/apps/core/perimeters.py +87 -0
  33. metaflow_extensions/outerbounds/plugins/apps/core/secrets.py +164 -0
  34. metaflow_extensions/outerbounds/plugins/apps/core/utils.py +233 -0
  35. metaflow_extensions/outerbounds/plugins/apps/core/validations.py +17 -0
  36. metaflow_extensions/outerbounds/plugins/apps/deploy_decorator.py +201 -0
  37. metaflow_extensions/outerbounds/plugins/apps/supervisord_utils.py +243 -0
  38. metaflow_extensions/outerbounds/plugins/auth_server.py +28 -8
  39. metaflow_extensions/outerbounds/plugins/aws/__init__.py +4 -0
  40. metaflow_extensions/outerbounds/plugins/aws/assume_role.py +3 -0
  41. metaflow_extensions/outerbounds/plugins/aws/assume_role_decorator.py +118 -0
  42. metaflow_extensions/outerbounds/plugins/card_utilities/__init__.py +0 -0
  43. metaflow_extensions/outerbounds/plugins/card_utilities/async_cards.py +142 -0
  44. metaflow_extensions/outerbounds/plugins/card_utilities/extra_components.py +545 -0
  45. metaflow_extensions/outerbounds/plugins/card_utilities/injector.py +70 -0
  46. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/__init__.py +2 -0
  47. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/coreweave.py +71 -0
  48. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/external_chckpt.py +85 -0
  49. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/nebius.py +73 -0
  50. metaflow_extensions/outerbounds/plugins/fast_bakery/__init__.py +0 -0
  51. metaflow_extensions/outerbounds/plugins/fast_bakery/baker.py +110 -0
  52. metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py +391 -0
  53. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py +188 -0
  54. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_cli.py +54 -0
  55. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_decorator.py +50 -0
  56. metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py +79 -0
  57. metaflow_extensions/outerbounds/plugins/kubernetes/pod_killer.py +374 -0
  58. metaflow_extensions/outerbounds/plugins/nim/card.py +140 -0
  59. metaflow_extensions/outerbounds/plugins/nim/nim_decorator.py +101 -0
  60. metaflow_extensions/outerbounds/plugins/nim/nim_manager.py +379 -0
  61. metaflow_extensions/outerbounds/plugins/nim/utils.py +36 -0
  62. metaflow_extensions/outerbounds/plugins/nvcf/__init__.py +0 -0
  63. metaflow_extensions/outerbounds/plugins/nvcf/constants.py +3 -0
  64. metaflow_extensions/outerbounds/plugins/nvcf/exceptions.py +94 -0
  65. metaflow_extensions/outerbounds/plugins/nvcf/heartbeat_store.py +178 -0
  66. metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py +417 -0
  67. metaflow_extensions/outerbounds/plugins/nvcf/nvcf_cli.py +280 -0
  68. metaflow_extensions/outerbounds/plugins/nvcf/nvcf_decorator.py +242 -0
  69. metaflow_extensions/outerbounds/plugins/nvcf/utils.py +6 -0
  70. metaflow_extensions/outerbounds/plugins/nvct/__init__.py +0 -0
  71. metaflow_extensions/outerbounds/plugins/nvct/exceptions.py +71 -0
  72. metaflow_extensions/outerbounds/plugins/nvct/nvct.py +131 -0
  73. metaflow_extensions/outerbounds/plugins/nvct/nvct_cli.py +289 -0
  74. metaflow_extensions/outerbounds/plugins/nvct/nvct_decorator.py +286 -0
  75. metaflow_extensions/outerbounds/plugins/nvct/nvct_runner.py +218 -0
  76. metaflow_extensions/outerbounds/plugins/nvct/utils.py +29 -0
  77. metaflow_extensions/outerbounds/plugins/ollama/__init__.py +225 -0
  78. metaflow_extensions/outerbounds/plugins/ollama/constants.py +1 -0
  79. metaflow_extensions/outerbounds/plugins/ollama/exceptions.py +22 -0
  80. metaflow_extensions/outerbounds/plugins/ollama/ollama.py +1924 -0
  81. metaflow_extensions/outerbounds/plugins/ollama/status_card.py +292 -0
  82. metaflow_extensions/outerbounds/plugins/optuna/__init__.py +48 -0
  83. metaflow_extensions/outerbounds/plugins/perimeters.py +19 -5
  84. metaflow_extensions/outerbounds/plugins/profilers/deco_injector.py +70 -0
  85. metaflow_extensions/outerbounds/plugins/profilers/gpu_profile_decorator.py +88 -0
  86. metaflow_extensions/outerbounds/plugins/profilers/simple_card_decorator.py +96 -0
  87. metaflow_extensions/outerbounds/plugins/s3_proxy/__init__.py +7 -0
  88. metaflow_extensions/outerbounds/plugins/s3_proxy/binary_caller.py +132 -0
  89. metaflow_extensions/outerbounds/plugins/s3_proxy/constants.py +11 -0
  90. metaflow_extensions/outerbounds/plugins/s3_proxy/exceptions.py +13 -0
  91. metaflow_extensions/outerbounds/plugins/s3_proxy/proxy_bootstrap.py +59 -0
  92. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_api.py +93 -0
  93. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_decorator.py +250 -0
  94. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_manager.py +225 -0
  95. metaflow_extensions/outerbounds/plugins/secrets/__init__.py +0 -0
  96. metaflow_extensions/outerbounds/plugins/secrets/secrets.py +204 -0
  97. metaflow_extensions/outerbounds/plugins/snowflake/__init__.py +3 -0
  98. metaflow_extensions/outerbounds/plugins/snowflake/snowflake.py +378 -0
  99. metaflow_extensions/outerbounds/plugins/snowpark/__init__.py +0 -0
  100. metaflow_extensions/outerbounds/plugins/snowpark/snowpark.py +309 -0
  101. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_cli.py +277 -0
  102. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py +150 -0
  103. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_decorator.py +273 -0
  104. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_exceptions.py +13 -0
  105. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py +241 -0
  106. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_service_spec.py +259 -0
  107. metaflow_extensions/outerbounds/plugins/tensorboard/__init__.py +50 -0
  108. metaflow_extensions/outerbounds/plugins/torchtune/__init__.py +163 -0
  109. metaflow_extensions/outerbounds/plugins/vllm/__init__.py +255 -0
  110. metaflow_extensions/outerbounds/plugins/vllm/constants.py +1 -0
  111. metaflow_extensions/outerbounds/plugins/vllm/exceptions.py +1 -0
  112. metaflow_extensions/outerbounds/plugins/vllm/status_card.py +352 -0
  113. metaflow_extensions/outerbounds/plugins/vllm/vllm_manager.py +621 -0
  114. metaflow_extensions/outerbounds/profilers/gpu.py +131 -47
  115. metaflow_extensions/outerbounds/remote_config.py +53 -16
  116. metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py +138 -2
  117. metaflow_extensions/outerbounds/toplevel/ob_internal.py +4 -0
  118. metaflow_extensions/outerbounds/toplevel/plugins/ollama/__init__.py +1 -0
  119. metaflow_extensions/outerbounds/toplevel/plugins/optuna/__init__.py +1 -0
  120. metaflow_extensions/outerbounds/toplevel/plugins/snowflake/__init__.py +1 -0
  121. metaflow_extensions/outerbounds/toplevel/plugins/torchtune/__init__.py +1 -0
  122. metaflow_extensions/outerbounds/toplevel/plugins/vllm/__init__.py +1 -0
  123. metaflow_extensions/outerbounds/toplevel/s3_proxy.py +88 -0
  124. {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/METADATA +2 -2
  125. ob_metaflow_extensions-1.5.1.dist-info/RECORD +133 -0
  126. ob_metaflow_extensions-1.1.45rc3.dist-info/RECORD +0 -19
  127. {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/WHEEL +0 -0
  128. {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/top_level.txt +0 -0
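
The bulk of this release is the new Ollama plugin (entry 80 above, ollama.py, +1924 lines); the hunk below shows its opening portion. For orientation, here is a minimal sketch of how the OllamaManager defined in that file might be driven directly, based only on the constructor signature visible in the diff. The model name and storage root below are hypothetical, and in normal use the manager is presumably constructed by the @ollama decorator its docstring refers to rather than by hand:

    from metaflow_extensions.outerbounds.plugins.ollama.ollama import OllamaManager

    # Hypothetical values throughout. OllamaManager needs either a
    # flow_datastore_backend or an explicit remote_storage_root for its model
    # cache; otherwise it raises UnspecifiedRemoteStorageRootException.
    manager = OllamaManager(
        models=["llama3.2:1b"],  # models to pull and serve
        backend="local",  # the only backend this version supports
        remote_storage_root="s3://my-bucket/ollama-cache",
        cache_update_policy="auto",  # "auto", "force", or "never"
        debug=True,
    )
    # __init__ does all the work: it installs Ollama if missing, launches
    # `ollama serve`, starts the background health checker, pulls each model
    # (from the remote cache when possible), and runs each model as a subprocess.
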
@@ -0,0 +1,1924 @@
1
+ import subprocess
2
+ from concurrent.futures import ThreadPoolExecutor, as_completed
3
+ import time
4
+ import socket
5
+ import sys
6
+ import os
7
+ import functools
8
+ import json
9
+ import requests
10
+ from enum import Enum
11
+ import threading
12
+ from datetime import datetime
13
+
14
+ from .constants import OLLAMA_SUFFIX
15
+ from .exceptions import (
16
+ EmptyOllamaManifestCacheException,
17
+ EmptyOllamaBlobCacheException,
18
+ UnspecifiedRemoteStorageRootException,
19
+ )
20
+
21
+
22
+ class ProcessStatus:
23
+ RUNNING = "RUNNING"
24
+ FAILED = "FAILED"
25
+ SUCCESSFUL = "SUCCESSFUL"
26
+
27
+
28
+ class CircuitBreakerState(Enum):
29
+ CLOSED = "CLOSED"
30
+ OPEN = "OPEN"
31
+ HALF_OPEN = "HALF_OPEN"
32
+
33
+
34
+ class CircuitBreaker:
35
+ def __init__(
36
+ self,
37
+ failure_threshold,
38
+ recovery_timeout,
39
+ reset_timeout,
40
+ debug=False,
41
+ status_card=None,
42
+ ):
43
+ self.failure_threshold = failure_threshold
44
+ self.recovery_timeout = recovery_timeout
45
+ self.reset_timeout = reset_timeout
46
+ self.state = CircuitBreakerState.CLOSED
47
+ self.failure_count = 0
48
+ self.last_failure_time = None
49
+ self.last_open_time = None
50
+ self.debug = debug
51
+ self.status_card = status_card
52
+ self.lock = threading.Lock()
53
+ self.request_count = 0 # Track total requests for pattern detection
54
+
55
+ if self.debug:
56
+ print(
57
+ f"[@ollama] CircuitBreaker initialized: threshold={failure_threshold}, recovery={recovery_timeout}, reset={reset_timeout}"
58
+ )
59
+
60
+ def _log_state_change(self, new_state):
61
+ if self.debug:
62
+ print(
63
+ f"[@ollama] Circuit Breaker state change: {self.state.value} -> {new_state.value}"
64
+ )
65
+ self.state = new_state
66
+ self._update_status_card()
67
+
68
+ def _update_status_card(self):
69
+ """Update the status card with current circuit breaker state"""
70
+ if self.status_card:
71
+ self.status_card.update_status(
72
+ "circuit_breaker",
73
+ {
74
+ "state": self.state.value,
75
+ "failure_count": self.failure_count,
76
+ "last_failure_time": self.last_failure_time,
77
+ "last_open_time": self.last_open_time,
78
+ },
79
+ )
80
+
81
+ def record_success(self):
82
+ with self.lock:
83
+ self.request_count += 1
84
+ if self.state == CircuitBreakerState.HALF_OPEN:
85
+ self._log_state_change(CircuitBreakerState.CLOSED)
86
+ self.failure_count = 0
87
+ elif self.state == CircuitBreakerState.OPEN:
88
+ # Allow transition to HALF_OPEN on success - server might have recovered
89
+ self._log_state_change(CircuitBreakerState.HALF_OPEN)
90
+ if self.debug:
91
+ print(
92
+ f"[@ollama] Success recorded while circuit OPEN. Transitioning to HALF_OPEN for testing."
93
+ )
94
+ self.failure_count = 0
95
+ self.last_failure_time = None
96
+
97
+ # Log request count milestone for pattern detection
98
+ if self.debug and self.request_count % 100 == 0:
99
+ print(f"[@ollama] Request count: {self.request_count}")
100
+
101
+ self._update_status_card()
102
+
103
+ def record_failure(self):
104
+ with self.lock:
105
+ self.request_count += 1
106
+ self.failure_count += 1
107
+ self.last_failure_time = time.time()
108
+ if (
109
+ self.failure_count >= self.failure_threshold
110
+ and self.state == CircuitBreakerState.CLOSED
111
+ ):
112
+ self._log_state_change(CircuitBreakerState.OPEN)
113
+ self.last_open_time = time.time()
114
+ elif self.state == CircuitBreakerState.HALF_OPEN:
115
+ # If we fail while testing recovery, go back to OPEN
116
+ self._log_state_change(CircuitBreakerState.OPEN)
117
+ self.last_open_time = time.time()
118
+ if self.debug:
119
+ print(
120
+ f"[@ollama] Failure recorded. Count: {self.failure_count}, State: {self.state.value}, Total requests: {self.request_count}"
121
+ )
122
+ self._update_status_card()
123
+
124
+ def should_attempt_reset(self):
125
+ """Check if we should attempt to reset/restart Ollama based on reset_timeout"""
126
+ with self.lock:
127
+ if self.state == CircuitBreakerState.OPEN and self.last_open_time:
128
+ elapsed_time = time.time() - self.last_open_time
129
+ return elapsed_time > self.reset_timeout
130
+ return False
131
+
132
+ def is_request_allowed(self):
133
+ with self.lock:
134
+ if self.state == CircuitBreakerState.OPEN:
135
+ elapsed_time = time.time() - self.last_open_time
136
+ if elapsed_time > self.recovery_timeout:
137
+ self._log_state_change(CircuitBreakerState.HALF_OPEN)
138
+ if self.debug:
139
+ print(
140
+ f"[@ollama] Circuit Breaker transitioning to HALF_OPEN after {elapsed_time:.1f}s."
141
+ )
142
+ return True # Allow a single request to test recovery
143
+ else:
144
+ if self.debug:
145
+ print(
146
+ f"[@ollama] Circuit Breaker is OPEN. Not allowing request. Time until HALF_OPEN: {self.recovery_timeout - elapsed_time:.1f}s"
147
+ )
148
+ return False
149
+ elif self.state == CircuitBreakerState.HALF_OPEN:
150
+ # In HALF_OPEN, be more restrictive - only allow one request at a time
151
+ if self.debug:
152
+ print(
153
+ f"[@ollama] Circuit Breaker is HALF_OPEN. Allowing request to test recovery."
154
+ )
155
+ return True
156
+ else: # CLOSED
157
+ return True
158
+
159
+ def get_status(self):
160
+ with self.lock:
161
+ status = {
162
+ "state": self.state.value,
163
+ "failure_count": self.failure_count,
164
+ "last_failure_time": self.last_failure_time,
165
+ "last_open_time": self.last_open_time,
166
+ }
167
+ if self.debug:
168
+ print(f"[@ollama] Circuit Breaker status: {status}")
169
+ return status
170
+
171
+
172
+ class TimeoutCommand:
173
+ def __init__(self, command, timeout, debug=False, **kwargs):
174
+ self.command = command
175
+ self.timeout = timeout
176
+ self.debug = debug
177
+ self.input_data = kwargs.pop("input", None) # Remove input from kwargs
178
+ self.kwargs = kwargs
179
+
180
+ def run(self):
181
+ if self.debug:
182
+ print(
183
+ f"[@ollama] Executing command with timeout {self.timeout}s: {' '.join(self.command)}"
184
+ )
185
+ try:
186
+ process = subprocess.Popen(
187
+ self.command,
188
+ stdout=subprocess.PIPE,
189
+ stderr=subprocess.PIPE,
190
+ text=True,
191
+ **self.kwargs,
192
+ )
193
+ stdout, stderr = process.communicate(
194
+ input=self.input_data, timeout=self.timeout
195
+ )
196
+ return process.returncode, stdout, stderr
197
+ except subprocess.TimeoutExpired:
198
+ if self.debug:
199
+ print(
200
+ f"[@ollama] Command timed out after {self.timeout}s: {' '.join(self.command)}"
201
+ )
202
+ process.kill()
203
+ stdout, stderr = process.communicate()
204
+ return (
205
+ 124,
206
+ stdout,
207
+ stderr,
208
+ ) # 124 is the standard exit code for `timeout` command
209
+ except Exception as e:
210
+ if self.debug:
211
+ print(
212
+ f"[@ollama] Error executing command {' '.join(self.command)}: {e}"
213
+ )
214
+ return 1, "", str(e)
215
+
216
+
217
+ class OllamaHealthChecker:
218
+ def __init__(self, ollama_url, circuit_breaker, ollama_manager, debug=False):
219
+ self.ollama_url = ollama_url
220
+ self.circuit_breaker = circuit_breaker
221
+ self.ollama_manager = ollama_manager
222
+ self.debug = debug
223
+ self._stop_event = threading.Event()
224
+ self._thread = None
225
+ self._interval = 30 # Check every 30 seconds (less aggressive)
226
+
227
+ def _check_health(self):
228
+ try:
229
+ health_timeout = self.ollama_manager.timeouts.get("health_check", 5)
230
+ if self.debug:
231
+ print(f"[@ollama] Health check: Pinging {self.ollama_url}/api/tags")
232
+ response = requests.get(
233
+ f"{self.ollama_url}/api/tags", timeout=health_timeout
234
+ )
235
+ if response.status_code == 200:
236
+ self.circuit_breaker.record_success()
237
+ self._update_server_health_status("Healthy")
238
+ if self.debug:
239
+ print(
240
+ f"[@ollama] Health check successful. Circuit state: {self.circuit_breaker.state.value}"
241
+ )
242
+ return True
243
+ else:
244
+ if self.debug:
245
+ print(
246
+ f"[@ollama] Health check failed. Status code: {response.status_code}. Circuit state: {self.circuit_breaker.state.value}"
247
+ )
248
+ self.circuit_breaker.record_failure()
249
+ self._update_server_health_status(
250
+ f"Unhealthy (HTTP {response.status_code})"
251
+ )
252
+ return False
253
+ except requests.exceptions.RequestException as e:
254
+ if self.debug:
255
+ print(
256
+ f"[@ollama] Health check exception: {e}. Circuit state: {self.circuit_breaker.state.value}"
257
+ )
258
+ self.circuit_breaker.record_failure()
259
+ self._update_server_health_status(f"Unhealthy ({str(e)[:50]})")
260
+ return False
261
+
262
+ def _update_server_health_status(self, status):
263
+ """Update server health status in the status card"""
264
+ if self.ollama_manager.status_card:
265
+ self.ollama_manager.status_card.update_status(
266
+ "server", {"health_status": status, "last_health_check": datetime.now()}
267
+ )
268
+
269
+ def _run_health_check_loop(self):
270
+ while not self._stop_event.is_set():
271
+ # Always perform health check to monitor server status
272
+ self._check_health()
273
+
274
+ # Check if we should attempt a restart based on reset_timeout
275
+ if self.circuit_breaker.should_attempt_reset():
276
+ try:
277
+ if self.debug:
278
+ print(
279
+ "[@ollama] Circuit breaker reset timeout reached. Attempting restart..."
280
+ )
281
+ restart_success = self.ollama_manager._attempt_ollama_restart()
282
+ if restart_success:
283
+ if self.debug:
284
+ print("[@ollama] Restart successful via health checker")
285
+ else:
286
+ if self.debug:
287
+ print("[@ollama] Restart failed via health checker")
288
+ except Exception as e:
289
+ if self.debug:
290
+ print(
291
+ f"[@ollama] Error during health checker restart attempt: {e}"
292
+ )
293
+
294
+ self._stop_event.wait(self._interval)
295
+
296
+ def start(self):
297
+ if self._thread is None or not self._thread.is_alive():
298
+ self._stop_event.clear()
299
+ self._thread = threading.Thread(
300
+ target=self._run_health_check_loop, daemon=True
301
+ )
302
+ self._thread.start()
303
+ if self.debug:
304
+ print("[@ollama] OllamaHealthChecker started.")
305
+
306
+ def stop(self):
307
+ if self._thread and self._thread.is_alive():
308
+ self._stop_event.set()
309
+ self._thread.join(timeout=self._interval + 1) # Wait for thread to finish
310
+ if self.debug:
311
+ print("[@ollama] OllamaHealthChecker stopped.")
312
+
313
+
314
+ class OllamaRequestInterceptor:
315
+ def __init__(self, circuit_breaker, debug=False):
316
+ self.circuit_breaker = circuit_breaker
317
+ self.debug = debug
318
+ self.original_methods = {}
319
+ self._protection_installed = False
320
+
321
+ def install_protection(self):
322
+ """Install request protection by monkey-patching the ollama package"""
323
+ if self._protection_installed:
324
+ return
325
+
326
+ try:
327
+ import ollama # Import the actual ollama package
328
+
329
+ # Store original methods
330
+ self.original_methods = {
331
+ "chat": getattr(ollama, "chat", None),
332
+ "generate": getattr(ollama, "generate", None),
333
+ "embeddings": getattr(ollama, "embeddings", None),
334
+ }
335
+
336
+ # Replace with protected versions
337
+ if hasattr(ollama, "chat"):
338
+ ollama.chat = self._protected_chat
339
+ if hasattr(ollama, "generate"):
340
+ ollama.generate = self._protected_generate
341
+ if hasattr(ollama, "embeddings"):
342
+ ollama.embeddings = self._protected_embeddings
343
+
344
+ self._protection_installed = True
345
+ if self.debug:
346
+ print(
347
+ "[@ollama] Request protection installed on ollama package methods"
348
+ )
349
+
350
+ except ImportError:
351
+ if self.debug:
352
+ print(
353
+ "[@ollama] Warning: Could not import ollama package for request protection"
354
+ )
355
+ except Exception as e:
356
+ if self.debug:
357
+ print(f"[@ollama] Error installing request protection: {e}")
358
+
359
+ def remove_protection(self):
360
+ """Remove request protection by restoring original methods"""
361
+ if not self._protection_installed:
362
+ return
363
+
364
+ try:
365
+ import ollama
366
+
367
+ # Restore original methods
368
+ for method_name, original_method in self.original_methods.items():
369
+ if original_method is not None and hasattr(ollama, method_name):
370
+ setattr(ollama, method_name, original_method)
371
+
372
+ self._protection_installed = False
373
+ if self.debug:
374
+ print("[@ollama] Request protection removed")
375
+
376
+ except Exception as e:
377
+ if self.debug:
378
+ print(f"[@ollama] Error removing request protection: {e}")
379
+
380
+ def _protected_chat(self, *args, **kwargs):
381
+ if not self.circuit_breaker.is_request_allowed():
382
+ raise RuntimeError(
383
+ f"Ollama server is currently unavailable. Circuit Breaker is {self.circuit_breaker.state.value}. "
384
+ "Please wait or check Ollama server status. "
385
+ f"Current status: {self.circuit_breaker.get_status()}"
386
+ )
387
+ try:
388
+ if self.debug:
389
+ # Debug: log model being used in request
390
+ model_name = kwargs.get("model", "unknown")
391
+ if args and isinstance(args[0], dict) and "model" in args[0]:
392
+ model_name = args[0]["model"]
393
+ print(f"[@ollama] DEBUG: Making chat request with model: {model_name}")
394
+ print(f"[@ollama] DEBUG: Request args: {args}")
395
+ print(f"[@ollama] DEBUG: Request kwargs keys: {list(kwargs.keys())}")
396
+
397
+ result = self.original_methods["chat"](*args, **kwargs)
398
+ self.circuit_breaker.record_success()
399
+ return result
400
+ except Exception as e:
401
+ if self.debug:
402
+ print(f"[@ollama] Protected chat call failed: {e}")
403
+ print(f"[@ollama] DEBUG: Exception type: {type(e)}")
404
+ self.circuit_breaker.record_failure()
405
+ raise
406
+
407
+ def _protected_generate(self, *args, **kwargs):
408
+ if not self.circuit_breaker.is_request_allowed():
409
+ raise RuntimeError(
410
+ f"Ollama server is currently unavailable. Circuit Breaker is {self.circuit_breaker.state.value}. "
411
+ "Please wait or check Ollama server status. "
412
+ f"Current status: {self.circuit_breaker.get_status()}"
413
+ )
414
+ try:
415
+ result = self.original_methods["generate"](*args, **kwargs)
416
+ self.circuit_breaker.record_success()
417
+ return result
418
+ except Exception as e:
419
+ if self.debug:
420
+ print(f"[@ollama] Protected generate call failed: {e}")
421
+ self.circuit_breaker.record_failure()
422
+ raise
423
+
424
+ def _protected_embeddings(self, *args, **kwargs):
425
+ if not self.circuit_breaker.is_request_allowed():
426
+ raise RuntimeError(
427
+ f"Ollama server is currently unavailable. Circuit Breaker is {self.circuit_breaker.state.value}. "
428
+ "Please wait or check Ollama server status. "
429
+ f"Current status: {self.circuit_breaker.get_status()}"
430
+ )
431
+ try:
432
+ result = self.original_methods["embeddings"](*args, **kwargs)
433
+ self.circuit_breaker.record_success()
434
+ return result
435
+ except Exception as e:
436
+ if self.debug:
437
+ print(f"[@ollama] Protected embeddings call failed: {e}")
438
+ self.circuit_breaker.record_failure()
439
+ raise
440
+
441
+
442
+ class OllamaManager:
443
+ """
444
+ A process manager for Ollama runtimes.
445
+ Implements interface @ollama([models=...], ...) has a local, remote, or managed backend.
446
+ """
447
+
448
+ def __init__(
449
+ self,
450
+ models,
451
+ backend="local",
452
+ flow_datastore_backend=None,
453
+ remote_storage_root=None,
454
+ force_pull=False,
455
+ cache_update_policy="auto",
456
+ force_cache_update=False,
457
+ debug=False,
458
+ circuit_breaker_config=None,
459
+ timeout_config=None,
460
+ status_card=None,
461
+ ):
462
+ self.models = {}
463
+ self.processes = {}
464
+ self.flow_datastore_backend = flow_datastore_backend
465
+ if self.flow_datastore_backend is not None:
466
+ self.remote_storage_root = self.get_ollama_storage_root(
467
+ self.flow_datastore_backend
468
+ )
469
+ elif remote_storage_root is not None:
470
+ self.remote_storage_root = remote_storage_root
471
+ else:
472
+ raise UnspecifiedRemoteStorageRootException(
473
+ "Can not determine the storage root, as both flow_datastore_backend and remote_storage_root arguments of OllamaManager are None."
474
+ )
475
+ self.force_pull = force_pull
476
+
477
+ # New cache logic
478
+ self.cache_update_policy = cache_update_policy
479
+ if force_cache_update: # Simple override
480
+ self.cache_update_policy = "force"
481
+ self.cache_status = {} # Track cache status per model
482
+
483
+ self.debug = debug
484
+ self.stats = {}
485
+ self.storage_info = {}
486
+ self.ollama_url = "http://localhost:11434" # Ollama API base URL
487
+ self.status_card = status_card
488
+ self.initialization_start = time.time()
489
+
490
+ if backend != "local":
491
+ raise ValueError(
492
+ "OllamaManager only supports the 'local' backend at this time."
493
+ )
494
+
495
+ # Validate and set up circuit breaker config
496
+ if circuit_breaker_config is None:
497
+ circuit_breaker_config = {
498
+ "failure_threshold": 3,
499
+ "recovery_timeout": 30, # Reduced from 60s - faster testing
500
+ "reset_timeout": 60, # Reduced from 300s - faster restart
501
+ }
502
+
503
+ # Set up timeout configuration
504
+ if timeout_config is None:
505
+ timeout_config = {
506
+ "pull": 600, # 10 minutes for model pulls
507
+ "stop": 30, # 30 seconds for model stops
508
+ "health_check": 5, # 5 seconds for health checks
509
+ "install": 60, # 1 minute for Ollama installation
510
+ "server_startup": 300, # 5 minutes for server startup
511
+ }
512
+ self.timeouts = timeout_config
513
+
514
+ # Initialize Circuit Breaker and Health Checker
515
+ self.circuit_breaker = CircuitBreaker(
516
+ failure_threshold=circuit_breaker_config.get("failure_threshold", 3),
517
+ recovery_timeout=circuit_breaker_config.get("recovery_timeout", 30),
518
+ reset_timeout=circuit_breaker_config.get("reset_timeout", 60),
519
+ debug=self.debug,
520
+ status_card=self.status_card,
521
+ )
522
+ self.health_checker = OllamaHealthChecker(
523
+ self.ollama_url, self.circuit_breaker, self, self.debug
524
+ )
525
+
526
+ self._log_event("info", "Starting Ollama initialization")
527
+ self._timeit(self._install_ollama, "install_ollama")
528
+ self._timeit(self._launch_server, "launch_server")
529
+ self.health_checker.start()
530
+
531
+ # Collect version information
532
+ self._collect_version_info()
533
+
534
+ # Initialize cache status display
535
+ self._update_cache_status()
536
+
537
+ # Pull models concurrently
538
+ with ThreadPoolExecutor() as executor:
539
+ futures = [executor.submit(self._pull_model, m) for m in models]
540
+ for future in as_completed(futures):
541
+ try:
542
+ future.result()
543
+ except Exception as e:
544
+ raise RuntimeError(f"Error pulling one or more models. {e}") from e
545
+
546
+ # Update final cache status
547
+ self._update_cache_status()
548
+
549
+ # Run models as background processes.
550
+ for m in models:
551
+ f = functools.partial(self._run_model, m)
552
+ self._timeit(f, f"model_{m.lower()}")
553
+
554
+ # Record total initialization time
555
+ total_init_time = time.time() - self.initialization_start
556
+ self._update_performance("total_initialization_time", total_init_time)
557
+ self._log_event(
558
+ "success", f"Ollama initialization completed in {total_init_time:.1f}s"
559
+ )
560
+
561
+ def _collect_version_info(self):
562
+ """Collect version information for Ollama system and Python client"""
563
+ version_info = {}
564
+
565
+ # Get Ollama system version
566
+ try:
567
+ result = subprocess.run(
568
+ ["ollama", "--version"], capture_output=True, text=True, timeout=10
569
+ )
570
+ if result.returncode == 0:
571
+ # Extract version from output - handle different formats
572
+ version_line = result.stdout.strip()
573
+ # Common formats: "ollama version 0.1.0", "0.1.0", "v0.1.0"
574
+ if "version" in version_line.lower():
575
+ # Extract everything after "version"
576
+ parts = version_line.lower().split("version")
577
+ if len(parts) > 1:
578
+ version_info["ollama_system"] = parts[1].strip()
579
+ else:
580
+ version_info["ollama_system"] = version_line
581
+ elif version_line.startswith("v"):
582
+ version_info["ollama_system"] = version_line[
583
+ 1:
584
+ ] # Remove 'v' prefix
585
+ else:
586
+ version_info["ollama_system"] = version_line
587
+ else:
588
+ version_info["ollama_system"] = "Unknown"
589
+ except Exception as e:
590
+ version_info["ollama_system"] = "Error detecting"
591
+ if self.debug:
592
+ print(f"[@ollama] Error getting system version: {e}")
593
+
594
+ # Get Python ollama client version
595
+ try:
596
+ import ollama
597
+
598
+ if hasattr(ollama, "__version__"):
599
+ version_info["ollama_python"] = ollama.__version__
600
+ else:
601
+ # Try alternative methods to get version
602
+ try:
603
+ import pkg_resources
604
+
605
+ version_info["ollama_python"] = pkg_resources.get_distribution(
606
+ "ollama"
607
+ ).version
608
+ except:
609
+ try:
610
+ # Try importlib.metadata (Python 3.8+)
611
+ from importlib import metadata
612
+
613
+ version_info["ollama_python"] = metadata.version("ollama")
614
+ except:
615
+ version_info["ollama_python"] = "Unknown"
616
+ except ImportError:
617
+ version_info["ollama_python"] = "Not installed"
618
+ except Exception as e:
619
+ version_info["ollama_python"] = "Error detecting"
620
+ if self.debug:
621
+ print(f"[@ollama] Error getting Python client version: {e}")
622
+
623
+ # Update status card with version info
624
+ if self.status_card:
625
+ self.status_card.update_status("versions", version_info)
626
+ self._log_event(
627
+ "info",
628
+ f"Versions: System {version_info.get('ollama_system', 'Unknown')}, Python {version_info.get('ollama_python', 'Unknown')}",
629
+ )
630
+
631
+ def _check_cache_exists(self, m):
632
+ """Check if cache exists for the given model"""
633
+ if self.local_datastore:
634
+ # Local datastore - no remote cache
635
+ return False
636
+
637
+ if m not in self.storage_info:
638
+ # Storage not set up yet
639
+ return False
640
+
641
+ try:
642
+ from metaflow import S3
643
+ from metaflow.plugins.datatools.s3.s3 import MetaflowS3NotFound
644
+
645
+ with S3() as s3:
646
+ # Check if manifest exists in remote storage
647
+ manifest_exists = s3.get(self.storage_info[m]["manifest_remote"]).exists
648
+
649
+ if manifest_exists:
650
+ self.cache_status[m] = "exists"
651
+ self._update_cache_status()
652
+ return True
653
+ else:
654
+ self.cache_status[m] = "missing"
655
+ self._update_cache_status()
656
+ return False
657
+
658
+ except Exception as e:
659
+ if self.debug:
660
+ print(f"[@ollama {m}] Error checking cache existence: {e}")
661
+ self.cache_status[m] = "error"
662
+ self._update_cache_status()
663
+ return False
664
+
665
+ def _should_update_cache(self, m):
666
+ """Determine if we should update cache for this model based on policy"""
667
+ if self.cache_update_policy == "never":
668
+ return False
669
+ elif self.cache_update_policy == "force":
670
+ return True
671
+ elif self.cache_update_policy == "auto":
672
+ # Only update if cache doesn't exist
673
+ cache_exists = self._check_cache_exists(m)
674
+ return not cache_exists
675
+ else:
676
+ # Unknown policy, default to auto behavior
677
+ cache_exists = self._check_cache_exists(m)
678
+ return not cache_exists
679
+
680
+ def _log_event(self, event_type, message):
681
+ """Log an event to the status card"""
682
+ if self.status_card:
683
+ self.status_card.add_event(event_type, message)
684
+ if self.debug:
685
+ print(f"[@ollama] {event_type.upper()}: {message}")
686
+
687
+ def _update_server_status(self, status, **kwargs):
688
+ """Update server status in the status card"""
689
+ if self.status_card:
690
+ update_data = {"status": status}
691
+ update_data.update(kwargs)
692
+ self.status_card.update_status("server", update_data)
693
+
694
+ def _update_model_status(self, model_name, **kwargs):
695
+ """Update model status in the status card"""
696
+ if self.status_card:
697
+ current_models = self.status_card.status_data.get("models", {})
698
+ if model_name not in current_models:
699
+ current_models[model_name] = {}
700
+ current_models[model_name].update(kwargs)
701
+ self.status_card.update_status("models", current_models)
702
+
703
+ def _update_performance(self, metric, value):
704
+ """Update performance metrics in the status card"""
705
+ if self.status_card:
706
+ self.status_card.update_status("performance", {metric: value})
707
+
708
+ def _update_circuit_breaker_status(self):
709
+ """Update circuit breaker status in the status card"""
710
+ if self.status_card:
711
+ cb_status = self.circuit_breaker.get_status()
712
+ self.status_card.update_status("circuit_breaker", cb_status)
713
+
714
+ def _update_cache_status(self):
715
+ """Update cache status in the status card"""
716
+ if self.status_card:
717
+ self.status_card.update_status(
718
+ "cache",
719
+ {
720
+ "policy": self.cache_update_policy,
721
+ "model_status": self.cache_status.copy(),
722
+ },
723
+ )
724
+
725
+ def _timeit(self, f, name):
726
+ t0 = time.time()
727
+ f()
728
+ tf = time.time()
729
+ duration = tf - t0
730
+ self.stats[name] = {"process_runtime": duration}
731
+
732
+ # Update performance metrics for status card
733
+ if name == "install_ollama":
734
+ self._update_performance("install_time", duration)
735
+ elif name == "launch_server":
736
+ self._update_performance("server_startup_time", duration)
737
+
738
+ def _is_port_open(self, host, port, timeout=1):
739
+ """Check if a TCP port is open on a given host."""
740
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
741
+ sock.settimeout(timeout)
742
+ try:
743
+ sock.connect((host, port))
744
+ return True
745
+ except socket.error:
746
+ return False
747
+
748
+ def _install_ollama(self, max_retries=3):
749
+ self._log_event("info", "Checking for existing Ollama installation")
750
+ try:
751
+ result = subprocess.run(["which", "ollama"], capture_output=True, text=True)
752
+ if result.returncode == 0:
753
+ self._log_event("success", "Ollama is already installed")
754
+ print("[@ollama] Ollama is already installed.")
755
+ return
756
+ except Exception as e:
757
+ if self.debug:
758
+ print(f"[@ollama] Did not find Ollama installation: {e}")
759
+ if sys.platform == "darwin":
760
+ raise RuntimeError(
761
+ "On macOS, please install Ollama manually from https://ollama.com/download."
762
+ )
763
+
764
+ self._log_event("info", "Installing Ollama...")
765
+ if self.debug:
766
+ print("[@ollama] Installing Ollama...")
767
+ env = os.environ.copy()
768
+ env["CURL_IPRESOLVE"] = "4"
769
+
770
+ for attempt in range(max_retries):
771
+ try:
772
+ install_cmd = ["curl", "-fsSL", "https://ollama.com/install.sh"]
773
+ curl_proc = subprocess.run(
774
+ install_cmd, capture_output=True, text=True, env=env, timeout=120
775
+ )
776
+ if curl_proc.returncode != 0:
777
+ raise RuntimeError(
778
+ f"Failed to download Ollama install script: stdout: {curl_proc.stdout}, stderr: {curl_proc.stderr}"
779
+ )
780
+ sh_proc = subprocess.run(
781
+ ["sh"],
782
+ input=curl_proc.stdout,
783
+ capture_output=True,
784
+ text=True,
785
+ env=env,
786
+ timeout=self.timeouts.get("install", 60),
787
+ )
788
+ if sh_proc.returncode != 0:
789
+ raise RuntimeError(
790
+ f"Ollama installation script failed: stdout: {sh_proc.stdout}, stderr: {sh_proc.stderr}"
791
+ )
792
+ self._log_event("success", "Ollama installation completed successfully")
793
+ if self.debug:
794
+ print("[@ollama] Ollama installed successfully.")
795
+ break
796
+ except Exception as e:
797
+ self._log_event(
798
+ "warning", f"Installation attempt {attempt+1} failed: {str(e)}"
799
+ )
800
+ if self.debug:
801
+ print(f"[@ollama] Installation attempt {attempt+1} failed: {e}")
802
+ if attempt < max_retries - 1:
803
+ time.sleep(5)
804
+ else:
805
+ self._log_event(
806
+ "error",
807
+ f"Ollama installation failed after {max_retries} attempts",
808
+ )
809
+ raise RuntimeError(
810
+ f"Error installing Ollama after {max_retries} attempts: {e}"
811
+ ) from e
812
+
813
+ def _launch_server(self):
814
+ """
815
+ Start the Ollama server process and ensure it's running.
816
+ """
817
+ self._update_server_status("Starting")
818
+ self._log_event("info", "Starting Ollama server...")
819
+
820
+ try:
821
+ print("[@ollama] Starting Ollama server...")
822
+ process = subprocess.Popen(
823
+ ["ollama", "serve"],
824
+ stdout=subprocess.PIPE,
825
+ stderr=subprocess.PIPE,
826
+ text=True,
827
+ )
828
+ self.processes[process.pid] = {
829
+ "p": process,
830
+ "properties": {"type": "api-server", "error_details": None},
831
+ "status": ProcessStatus.RUNNING,
832
+ }
833
+
834
+ if self.debug:
835
+ print(f"[@ollama] Started server process with PID {process.pid}.")
836
+
837
+ # Wait until the server is ready
838
+ host, port = "127.0.0.1", 11434
839
+ retries = 0
840
+ max_retries = 10
841
+ while (
842
+ not self._is_port_open(host, port, timeout=1) and retries < max_retries
843
+ ):
844
+ if retries == 0:
845
+ print("[@ollama] Waiting for server to be ready...")
846
+ elif retries % 3 == 0:
847
+ print(f"[@ollama] Still waiting... ({retries + 1}/{max_retries})")
848
+
849
+ # Check if process terminated unexpectedly during startup
850
+ returncode = process.poll()
851
+ if returncode is not None:
852
+ # Process exited, get error details but don't call communicate() which can hang
853
+ error_details = f"Return code: {returncode}"
854
+ self.processes[process.pid]["properties"][
855
+ "error_details"
856
+ ] = error_details
857
+ self.processes[process.pid]["status"] = ProcessStatus.FAILED
858
+ self._update_server_status("Failed", error_details=error_details)
859
+ self._log_event(
860
+ "error", f"Ollama server failed to start: {error_details}"
861
+ )
862
+ raise RuntimeError(
863
+ f"Ollama server failed to start. {error_details}"
864
+ )
865
+
866
+ time.sleep(5)
867
+ retries += 1
868
+
869
+ if not self._is_port_open(host, port, timeout=1):
870
+ error_details = (
871
+ f"Ollama server did not start listening on {host}:{port}"
872
+ )
873
+ self.processes[process.pid]["properties"][
874
+ "error_details"
875
+ ] = error_details
876
+ self.processes[process.pid]["status"] = ProcessStatus.FAILED
877
+ self._update_server_status("Failed", error_details=error_details)
878
+ self._log_event("error", f"Server startup timeout: {error_details}")
879
+ raise RuntimeError(f"Ollama server failed to start. {error_details}")
880
+
881
+ # Final check if process terminated unexpectedly
882
+ returncode = process.poll()
883
+ if returncode is not None:
884
+ error_details = f"Return code: {returncode}"
885
+ self.processes[process.pid]["properties"][
886
+ "error_details"
887
+ ] = error_details
888
+ self.processes[process.pid]["status"] = ProcessStatus.FAILED
889
+ self._update_server_status("Failed", error_details=error_details)
890
+ self._log_event(
891
+ "error", f"Server process died unexpectedly: {error_details}"
892
+ )
893
+ raise RuntimeError(f"Ollama server failed to start. {error_details}")
894
+
895
+ self._update_server_status("Running", uptime_start=datetime.now())
896
+ self._log_event("success", "Ollama server is ready and listening")
897
+ print("[@ollama] Server is ready.")
898
+
899
+ except Exception as e:
900
+ if "process" in locals() and process.pid in self.processes:
901
+ self.processes[process.pid]["status"] = ProcessStatus.FAILED
902
+ self.processes[process.pid]["properties"]["error_details"] = str(e)
903
+ self._update_server_status("Failed", error_details=str(e))
904
+ self._log_event("error", f"Error starting Ollama server: {str(e)}")
905
+ raise RuntimeError(f"Error starting Ollama server: {e}") from e
906
+
907
+ def _setup_storage(self, m):
908
+ """
909
+ Configure local and remote storage paths for an Ollama model.
910
+ """
911
+ # Parse model and tag name
912
+ ollama_model_name_components = m.split(":")
913
+ if len(ollama_model_name_components) == 1:
914
+ model_name = ollama_model_name_components[0]
915
+ tag = "latest"
916
+ elif len(ollama_model_name_components) == 2:
917
+ model_name = ollama_model_name_components[0]
918
+ tag = ollama_model_name_components[1]
919
+
920
+ # Find where Ollama actually stores models
921
+ possible_storage_roots = [
922
+ os.environ.get("OLLAMA_MODELS"),
923
+ "/usr/share/ollama/.ollama/models",
924
+ os.path.expanduser("~/.ollama/models"),
925
+ "/root/.ollama/models",
926
+ ]
927
+
928
+ ollama_local_storage_root = None
929
+ for root in possible_storage_roots:
930
+ if root and os.path.exists(root):
931
+ ollama_local_storage_root = root
932
+ break
933
+
934
+ if not ollama_local_storage_root:
935
+ # https://github.com/ollama/ollama/blob/main/docs/faq.md#where-are-models-stored
936
+ if sys.platform.startswith("linux"):
937
+ ollama_local_storage_root = "/usr/share/ollama/.ollama/models"
938
+ elif sys.platform == "darwin":
939
+ ollama_local_storage_root = os.path.expanduser("~/.ollama/models")
940
+
941
+ if self.debug:
942
+ print(
943
+ f"[@ollama {m}] Using Ollama storage root: {ollama_local_storage_root}."
944
+ )
945
+
946
+ blob_local_path = os.path.join(ollama_local_storage_root, "blobs")
947
+ manifest_base_path = os.path.join(
948
+ ollama_local_storage_root,
949
+ "manifests/registry.ollama.ai/library",
950
+ model_name,
951
+ )
952
+
953
+ # Create directories
954
+ try:
955
+ os.makedirs(blob_local_path, exist_ok=True)
956
+ os.makedirs(manifest_base_path, exist_ok=True)
957
+ except FileExistsError:
958
+ pass
959
+
960
+ # Set up remote paths
961
+ if not self.local_datastore and self.remote_storage_root is not None:
962
+ blob_remote_key = os.path.join(self.remote_storage_root, "blobs")
963
+ manifest_remote_key = os.path.join(
964
+ self.remote_storage_root,
965
+ "manifests/registry.ollama.ai/library",
966
+ model_name,
967
+ tag,
968
+ )
969
+ else:
970
+ blob_remote_key = None
971
+ manifest_remote_key = None
972
+
973
+ self.storage_info[m] = {
974
+ "blob_local_root": blob_local_path,
975
+ "blob_remote_root": blob_remote_key,
976
+ "manifest_local": os.path.join(manifest_base_path, tag),
977
+ "manifest_remote": manifest_remote_key,
978
+ "manifest_content": None,
979
+ "model_name": model_name,
980
+ "tag": tag,
981
+ "storage_root": ollama_local_storage_root,
982
+ }
983
+
984
+ if self.debug:
985
+ print(f"[@ollama {m}] Storage paths configured.")
986
+
987
+ def _fetch_manifest(self, m):
988
+ """
989
+ Load the manifest file and content, either from local storage or remote cache.
990
+ """
991
+ if self.debug:
992
+ print(f"[@ollama {m}] Checking for cached manifest...")
993
+
994
+ def _disk_to_memory():
995
+ with open(self.storage_info[m]["manifest_local"], "r") as f:
996
+ self.storage_info[m]["manifest_content"] = json.load(f)
997
+
998
+ if os.path.exists(self.storage_info[m]["manifest_local"]):
999
+ if self.storage_info[m]["manifest_content"] is None:
1000
+ _disk_to_memory()
1001
+ if self.debug:
1002
+ print(f"[@ollama {m}] Manifest found locally.")
1003
+ elif self.local_datastore:
1004
+ if self.debug:
1005
+ print(f"[@ollama {m}] No manifest found in local datastore.")
1006
+ return None
1007
+ else:
1008
+ from metaflow import S3
1009
+ from metaflow.plugins.datatools.s3.s3 import MetaflowS3NotFound
1010
+
1011
+ try:
1012
+ with S3() as s3:
1013
+ s3obj = s3.get(self.storage_info[m]["manifest_remote"])
1014
+ if not s3obj.exists:
1015
+ raise EmptyOllamaManifestCacheException(
1016
+ f"No manifest in remote storage for model {m}"
1017
+ )
1018
+
1019
+ if self.debug:
1020
+ print(f"[@ollama {m}] Downloaded manifest from cache.")
1021
+ os.rename(s3obj.path, self.storage_info[m]["manifest_local"])
1022
+ _disk_to_memory()
1023
+
1024
+ if self.debug:
1025
+ print(
1026
+ f"[@ollama {m}] Manifest found in remote cache, downloaded locally."
1027
+ )
1028
+ except (MetaflowS3NotFound, EmptyOllamaManifestCacheException):
1029
+ if self.debug:
1030
+ print(
1031
+ f"[@ollama {m}] No manifest found locally or in remote cache."
1032
+ )
1033
+ return None
1034
+
1035
+ return self.storage_info[m]["manifest_content"]
1036
+
1037
+ def _fetch_blobs(self, m):
1038
+ """
1039
+ Fetch missing blobs from remote cache.
1040
+ """
1041
+ if self.debug:
1042
+ print(f"[@ollama {m}] Checking for cached blobs...")
1043
+
1044
+ manifest = self._fetch_manifest(m)
1045
+ if not manifest:
1046
+ raise EmptyOllamaBlobCacheException(f"No manifest available for model {m}")
1047
+
1048
+ blobs_required = [layer["digest"] for layer in manifest["layers"]]
1049
+ missing_blob_info = []
1050
+
1051
+ # Check which blobs are missing locally
1052
+ for blob_digest in blobs_required:
1053
+ blob_filename = blob_digest.replace(":", "-")
1054
+ local_blob_path = os.path.join(
1055
+ self.storage_info[m]["blob_local_root"], blob_filename
1056
+ )
1057
+
1058
+ if not os.path.exists(local_blob_path):
1059
+ if self.debug:
1060
+ print(f"[@ollama {m}] Blob {blob_digest} not found locally.")
1061
+
1062
+ remote_blob_path = os.path.join(
1063
+ self.storage_info[m]["blob_remote_root"], blob_filename
1064
+ )
1065
+ missing_blob_info.append(
1066
+ {
1067
+ "digest": blob_digest,
1068
+ "filename": blob_filename,
1069
+ "remote_path": remote_blob_path,
1070
+ "local_path": local_blob_path,
1071
+ }
1072
+ )
1073
+
1074
+ if not missing_blob_info:
1075
+ if self.debug:
1076
+ print(f"[@ollama {m}] All blobs found locally.")
1077
+ return
1078
+
1079
+ if self.debug:
1080
+ print(
1081
+ f"[@ollama {m}] Downloading {len(missing_blob_info)} missing blobs from cache..."
1082
+ )
1083
+
1084
+ remote_urls = [blob_info["remote_path"] for blob_info in missing_blob_info]
1085
+
1086
+ from metaflow import S3
1087
+
1088
+ try:
1089
+ with S3() as s3:
1090
+ if len(remote_urls) == 1:
1091
+ s3objs = [s3.get(remote_urls[0])]
1092
+ else:
1093
+ s3objs = s3.get_many(remote_urls)
1094
+
1095
+ if not isinstance(s3objs, list):
1096
+ s3objs = [s3objs]
1097
+
1098
+ # Move each downloaded blob to correct location
1099
+ for i, s3obj in enumerate(s3objs):
1100
+ if not s3obj.exists:
1101
+ blob_info = missing_blob_info[i]
1102
+ raise EmptyOllamaBlobCacheException(
1103
+ f"Blob {blob_info['digest']} not found in remote cache for model {m}"
1104
+ )
1105
+
1106
+ blob_info = missing_blob_info[i]
1107
+ os.makedirs(os.path.dirname(blob_info["local_path"]), exist_ok=True)
1108
+ os.rename(s3obj.path, blob_info["local_path"])
1109
+
1110
+ if self.debug:
1111
+ print(f"[@ollama {m}] Downloaded blob {blob_info['filename']}.")
1112
+
1113
+ except Exception as e:
1114
+ if self.debug:
1115
+ print(f"[@ollama {m}] Error during blob fetch: {e}")
1116
+ raise EmptyOllamaBlobCacheException(
1117
+ f"Failed to fetch blobs for model {m}: {e}"
1118
+ )
1119
+
1120
+ if self.debug:
1121
+ print(
1122
+ f"[@ollama {m}] Successfully downloaded all missing blobs from cache."
1123
+ )
1124
+
1125
+ def _verify_model_available(self, m):
1126
+ """
1127
+ Verify model is available using Ollama API
1128
+ """
1129
+ try:
1130
+ if self.debug:
1131
+ print(f"[@ollama] DEBUG: Verifying model availability for: {m}")
1132
+
1133
+ response = requests.post(
1134
+ f"{self.ollama_url}/api/show", json={"model": m}, timeout=10
1135
+ )
1136
+
1137
+ available = response.status_code == 200
1138
+
1139
+ if self.debug:
1140
+ if available:
1141
+ print(f"[@ollama {m}] ✓ Model is available via API.")
1142
+ # Also list all available models for debugging
1143
+ try:
1144
+ tags_response = requests.get(
1145
+ f"{self.ollama_url}/api/tags", timeout=10
1146
+ )
1147
+ if tags_response.status_code == 200:
1148
+ models = tags_response.json().get("models", [])
1149
+ model_names = [
1150
+ model.get("name", "unknown") for model in models
1151
+ ]
1152
+ print(
1153
+ f"[@ollama] DEBUG: All available models: {model_names}"
1154
+ )
1155
+ except Exception as e:
1156
+ print(f"[@ollama] DEBUG: Could not list models: {e}")
1157
+ else:
1158
+ print(
1159
+ f"[@ollama {m}] ✗ Model not available via API (status: {response.status_code})."
1160
+ )
1161
+ try:
1162
+ error_detail = response.text
1163
+ print(f"[@ollama] DEBUG: Error response: {error_detail}")
1164
+ except:
1165
+ pass
1166
+
1167
+ return available
1168
+
1169
+ except Exception as e:
1170
+ if self.debug:
1171
+ print(f"[@ollama {m}] Error verifying model: {e}")
1172
+ return False
1173
+
1174
+ def _register_cached_model_with_ollama(self, m):
1175
+ """
1176
+ Register a cached model with Ollama using the API.
1177
+ """
1178
+ try:
1179
+ show_response = requests.post(
1180
+ f"{self.ollama_url}/api/show", json={"model": m}, timeout=10
1181
+ )
1182
+
1183
+ if show_response.status_code == 200:
1184
+ if self.debug:
1185
+ print(f"[@ollama {m}] Model already registered with Ollama.")
1186
+ return True
1187
+
1188
+ # Try to create/register the model from existing files
1189
+ if self.debug:
1190
+ print(f"[@ollama {m}] Registering cached model with Ollama...")
1191
+
1192
+ create_response = requests.post(
1193
+ f"{self.ollama_url}/api/create",
1194
+ json={
1195
+ "model": m,
1196
+ "from": m, # Use same name - should find existing files
1197
+ "stream": False,
1198
+ },
1199
+ timeout=60,
1200
+ )
1201
+
1202
+ if create_response.status_code == 200:
1203
+ result = create_response.json()
1204
+ if result.get("status") == "success":
1205
+ if self.debug:
1206
+ print(f"[@ollama {m}] Successfully registered cached model.")
1207
+ return True
1208
+ else:
1209
+ if self.debug:
1210
+ print(f"[@ollama {m}] Create response: {result}.")
1211
+
1212
+ # Fallback: try a pull which should be fast if files exist
1213
+ if self.debug:
1214
+ print(f"[@ollama {m}] Create failed, trying pull to register...")
1215
+
1216
+ pull_response = requests.post(
1217
+ f"{self.ollama_url}/api/pull",
1218
+ json={"model": m, "stream": False},
1219
+ timeout=120,
1220
+ )
1221
+
1222
+ if pull_response.status_code == 200:
1223
+ result = pull_response.json()
1224
+ if result.get("status") == "success":
1225
+ if self.debug:
1226
+ print(f"[@ollama {m}] Model registered via pull.")
1227
+ return True
1228
+
1229
+ except requests.exceptions.RequestException as e:
1230
+ if self.debug:
1231
+ print(f"[@ollama {m}] API registration failed: {e}")
1232
+ except Exception as e:
1233
+ if self.debug:
1234
+ print(f"[@ollama {m}] Error during registration: {e}")
1235
+
1236
+ return False
1237
+
1238
+ def _pull_model(self, m):
1239
+ """
1240
+ Pull/setup a model, using cache when possible.
1241
+ """
1242
+ self._update_model_status(m, status="Setting up storage")
1243
+ self._log_event("info", f"Setting up model {m}")
1244
+ pull_start_time = time.time()
1245
+
1246
+ self._setup_storage(m)
1247
+
1248
+ # Check cache existence and inform user about cache strategy
1249
+ cache_exists = self._check_cache_exists(m)
1250
+ will_update_cache = self._should_update_cache(m)
1251
+
1252
+ if cache_exists:
1253
+ if will_update_cache:
1254
+ self._log_event(
1255
+ "info",
1256
+ f"Cache exists for {m}, but will be updated due to {self.cache_update_policy} policy",
1257
+ )
1258
+ print(
1259
+ f"[@ollama {m}] Cache exists but will be updated ({self.cache_update_policy} policy)"
1260
+ )
1261
+ else:
1262
+ self._log_event("info", f"Using existing cache for {m}")
1263
+ print(f"[@ollama {m}] Using existing cache")
1264
+ else:
1265
+ if will_update_cache:
1266
+ self._log_event(
1267
+ "info",
1268
+ f"No cache found for {m}, will populate after successful setup",
1269
+ )
1270
+ print(f"[@ollama {m}] No cache found, will populate cache after setup")
1271
+ else:
1272
+ self._log_event(
1273
+ "info",
1274
+ f"No cache found for {m}, but cache updates disabled ({self.cache_update_policy} policy)",
1275
+ )
1276
+ print(
1277
+ f"[@ollama {m}] No cache found, cache updates disabled ({self.cache_update_policy} policy)"
1278
+ )
1279
+
1280
+ # Try to fetch manifest from cache first
1281
+ manifest = None
1282
+ try:
1283
+ manifest = self._fetch_manifest(m)
1284
+ except (EmptyOllamaManifestCacheException, Exception) as e:
1285
+ if self.debug:
1286
+ print(f"[@ollama {m}] No cached manifest found or error fetching: {e}")
1287
+ manifest = None
1288
+
1289
+ # If we don't have a cached manifest or force_pull is True, pull the model
1290
+ if self.force_pull or not manifest:
1291
+ try:
1292
+ self._update_model_status(m, status="Downloading")
1293
+ self._log_event("info", f"Downloading model {m}...")
1294
+ print(f"[@ollama {m}] Not using cache. Downloading model {m}...")
1295
+ result = subprocess.run(
1296
+ ["ollama", "pull", m],
1297
+ capture_output=True,
1298
+ text=True,
1299
+ timeout=self.timeouts.get("pull", 600),
1300
+ )
1301
+ if result.returncode != 0:
1302
+ self._update_model_status(m, status="Failed")
1303
+ self._log_event(
1304
+ "error", f"Failed to pull model {m}: {result.stderr}"
1305
+ )
1306
+ raise RuntimeError(
1307
+ f"Failed to pull model {m}: stdout: {result.stdout}, stderr: {result.stderr}"
1308
+ )
1309
+ pull_time = time.time() - pull_start_time
1310
+ self._update_model_status(m, status="Downloaded", pull_time=pull_time)
1311
+ self._log_event("success", f"Model {m} downloaded in {pull_time:.1f}s")
1312
+ print(f"[@ollama {m}] Model downloaded successfully.")
1313
+ except Exception as e:
1314
+ self._update_model_status(m, status="Failed")
1315
+ self._log_event("error", f"Error pulling model {m}: {str(e)}")
1316
+ raise RuntimeError(f"Error pulling Ollama model {m}: {e}") from e
1317
+ else:
1318
+ # We have a cached manifest, try to fetch the blobs
1319
+ try:
1320
+ self._update_model_status(m, status="Loading from cache")
1321
+ self._log_event("info", f"Loading model {m} from cache")
1322
+ self._fetch_blobs(m)
1323
+ print(f"[@ollama {m}] Using cached model.")
1324
+
1325
+ # Register the cached model with Ollama
1326
+ if not self._verify_model_available(m):
1327
+ if not self._register_cached_model_with_ollama(m):
1328
+ self._update_model_status(m, status="Failed")
1329
+ self._log_event("error", f"Failed to register cached model {m}")
1330
+ raise RuntimeError(
1331
+ f"Failed to register cached model {m} with Ollama"
1332
+ )
1333
+
1334
+ pull_time = time.time() - pull_start_time
1335
+ self._update_model_status(m, status="Cached", pull_time=pull_time)
1336
+ self._log_event(
1337
+ "success", f"Model {m} loaded from cache in {pull_time:.1f}s"
1338
+ )
1339
+
1340
+ except (EmptyOllamaBlobCacheException, Exception) as e:
1341
+ if self.debug:
1342
+ print(f"[@ollama {m}] Cache failed, downloading model...")
1343
+ print(f"[@ollama {m}] Error: {e}")
1344
+
1345
+                # Fallback to pulling the model
+                try:
+                    self._update_model_status(m, status="Downloading (fallback)")
+                    self._log_event(
+                        "warning", f"Cache failed for {m}, downloading as fallback"
+                    )
+                    result = subprocess.run(
+                        ["ollama", "pull", m],
+                        capture_output=True,
+                        text=True,
+                        timeout=self.timeouts.get("pull", 600),
+                    )
+                    if result.returncode != 0:
+                        self._update_model_status(m, status="Failed")
+                        self._log_event("error", f"Fallback pull failed for model {m}")
+                        raise RuntimeError(
+                            f"Failed to pull model {m}: stdout: {result.stdout}, stderr: {result.stderr}"
+                        )
+                    pull_time = time.time() - pull_start_time
+                    self._update_model_status(
+                        m, status="Downloaded (fallback)", pull_time=pull_time
+                    )
+                    self._log_event(
+                        "success",
+                        f"Model {m} downloaded via fallback in {pull_time:.1f}s",
+                    )
+                    print(f"[@ollama {m}] Model downloaded successfully (fallback).")
+                except Exception as pull_e:
+                    self._update_model_status(m, status="Failed")
+                    self._log_event(
+                        "error",
+                        f"Fallback download failed for model {m}: {str(pull_e)}",
+                    )
+                    raise RuntimeError(
+                        f"Error pulling Ollama model {m} as fallback: {pull_e}"
+                    ) from pull_e
+
+        # Final verification that the model is available
+        if not self._verify_model_available(m):
+            self._update_model_status(m, status="Failed")
+            self._log_event("error", f"Model {m} verification failed")
+            raise RuntimeError(f"Model {m} is not available to Ollama after setup")
+
+        # Collect model metadata (size and blob count)
+        metadata = self._collect_model_metadata(m)
+        self._update_model_status(
+            m,
+            status="Ready",
+            size_formatted=metadata["size_formatted"],
+            blob_count=metadata["blob_count"],
+        )
+        self._log_event("success", f"Model {m} setup complete and verified")
+        if self.debug:
+            print(f"[@ollama {m}] Model setup complete and verified.")
+            if metadata["size_formatted"] != "Unknown":
+                print(
+                    f"[@ollama {m}] Model size: {metadata['size_formatted']}, Blobs: {metadata['blob_count']}"
+                )
+
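Both the download path and the cache-fallback path above shell out to the Ollama CLI with a bounded timeout. As a minimal standalone sketch of that step (assuming the `ollama` CLI is on PATH; the model name is illustrative):

    import subprocess

    def pull_model(name: str, timeout: int = 600) -> None:
        # Bounded `ollama pull`, failing loudly on a non-zero exit,
        # mirroring the pull logic in the method above.
        result = subprocess.run(
            ["ollama", "pull", name],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
        if result.returncode != 0:
            raise RuntimeError(f"Failed to pull {name}: {result.stderr}")

    pull_model("llama3.2")  # illustrative model name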
+    def _run_model(self, m):
+        """
+        Start the Ollama model as a subprocess and record its status.
+        """
+        process = None
+        try:
+            self._update_model_status(m, status="Starting process")
+            self._log_event("info", f"Starting model process for {m}")
+            if self.debug:
+                print(f"[@ollama {m}] Starting model process...")
+
+            # For `ollama run`, we want it to stay running, so no timeout on Popen.
+            # The health checker will detect if it becomes unresponsive.
+            process = subprocess.Popen(
+                ["ollama", "run", m],
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+                text=True,
+            )
+            self.processes[process.pid] = {
+                "p": process,
+                "properties": {"type": "model", "model": m, "error_details": None},
+                "status": ProcessStatus.RUNNING,
+            }
+
+            if self.debug:
+                print(f"[@ollama {m}] Model process PID: {process.pid}.")
+
+            # We don't want to wait here indefinitely. Just check if it failed
+            # immediately; the health checker monitors long-term responsiveness.
+            try:
+                process.wait(timeout=1)  # Check if it exited immediately
+                returncode = process.poll()
+                if (
+                    returncode is not None and returncode != 0
+                ):  # It exited immediately with an error
+                    stdout, stderr = process.communicate()
+                    error_details = f"Return code: {returncode}, Error: {stderr}"
+                    self.processes[process.pid]["properties"][
+                        "error_details"
+                    ] = error_details
+                    self.processes[process.pid]["status"] = ProcessStatus.FAILED
+                    self._update_model_status(m, status="Failed")
+                    self._log_event(
+                        "error",
+                        f"Model {m} process failed immediately: {error_details}",
+                    )
+                    if self.debug:
+                        print(
+                            f"[@ollama {m}] Process {process.pid} failed immediately: {error_details}."
+                        )
+                    raise RuntimeError(
+                        f"Ollama model {m} failed to start immediately: {error_details}"
+                    )
+                elif returncode == 0:
+                    # This should not happen for a long-running model process
+                    if self.debug:
+                        print(
+                            f"[@ollama {m}] Process {process.pid} exited immediately with success. This might be unexpected for a model process."
+                        )
+                    self.processes[process.pid]["status"] = ProcessStatus.SUCCESSFUL
+
+            except subprocess.TimeoutExpired:
+                # Expected case: the process is still running in the background
+                self._update_model_status(m, status="Running")
+                self._log_event("success", f"Model {m} process started successfully")
+                if self.debug:
+                    print(
+                        f"[@ollama {m}] Model process {process.pid} is running in background."
+                    )
+
+        except Exception as e:
+            if process and process.pid in self.processes:
+                self.processes[process.pid]["status"] = ProcessStatus.FAILED
+                self.processes[process.pid]["properties"]["error_details"] = str(e)
+            self._update_model_status(m, status="Failed")
+            self._log_event("error", f"Error running model {m}: {str(e)}")
+            raise RuntimeError(f"Error running Ollama model {m}: {e}") from e
+
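The immediate-exit check in `_run_model` is a reusable pattern: give a long-running child one second to crash, and treat `subprocess.TimeoutExpired` as the success signal. A simplified sketch (unlike the method above, it treats any immediate exit, even a clean one, as an error):

    import subprocess

    def start_long_running(cmd):
        # Popen the child, then wait briefly; TimeoutExpired here means
        # the process is still alive, which is the desired outcome.
        p = subprocess.Popen(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
        )
        try:
            p.wait(timeout=1)
            raise RuntimeError(f"{cmd} exited immediately with code {p.returncode}")
        except subprocess.TimeoutExpired:
            return p  # still running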
+    def terminate_models(self, skip_push_check=None):
+        """
+        Terminate all processes gracefully and update cache.
+        """
+        shutdown_start_time = time.time()
+        self._log_event("info", "Starting Ollama shutdown sequence")
+        print("[@ollama] Shutting down models...")
+
+        # Stop the health checker first
+        self.health_checker.stop()
+
+        # Handle backward compatibility for the skip_push_check parameter
+        if skip_push_check is not None:
+            # Legacy parameter provided
+            if skip_push_check:
+                self.cache_update_policy = "never"
+                self._log_event(
+                    "warning",
+                    "Using legacy skip_push_check=True, setting cache policy to 'never'",
+                )
+            else:
+                self.cache_update_policy = "force"
+                self._log_event(
+                    "warning",
+                    "Using legacy skip_push_check=False, setting cache policy to 'force'",
+                )
+
+        # Shut down models
+        model_shutdown_results = {}
+        for pid, process_info in list(self.processes.items()):
+            if process_info["properties"].get("type") == "model":
+                model_name = process_info["properties"].get("model")
+                model_shutdown_start = time.time()
+                shutdown_cause = "graceful"
+
+                self._update_model_status(model_name, status="Stopping")
+                self._log_event("info", f"Stopping model {model_name}")
+                if self.debug:
+                    print(f"[@ollama {model_name}] Stopping model process...")
+
+                try:
+                    result = subprocess.run(
+                        ["ollama", "stop", model_name],
+                        capture_output=True,
+                        text=True,
+                        timeout=self.timeouts.get("stop", 30),
+                    )
+                    if result.returncode == 0:
+                        process_info["status"] = ProcessStatus.SUCCESSFUL
+                        self._update_model_status(model_name, status="Stopped")
+                        self._log_event(
+                            "success", f"Model {model_name} stopped gracefully"
+                        )
+                        if self.debug:
+                            print(f"[@ollama {model_name}] Stopped successfully.")
+                    else:
+                        process_info["status"] = ProcessStatus.FAILED
+                        shutdown_cause = "force_kill"
+                        self._update_model_status(model_name, status="Force stopped")
+                        self._log_event(
+                            "warning",
+                            f"Model {model_name} stop command failed, killing process",
+                        )
+                        if self.debug:
+                            print(
+                                f"[@ollama {model_name}] Stop failed: {result.stderr}. Attempting to kill process directly."
+                            )
+                        # Fallback: if 'ollama stop' fails, try to kill the process directly
+                        try:
+                            process_info["p"].terminate()
+                            process_info["p"].wait(timeout=5)
+                            if process_info["p"].poll() is None:
+                                process_info["p"].kill()
+                                process_info["p"].wait()
+                            process_info["status"] = ProcessStatus.SUCCESSFUL
+                            self._update_model_status(model_name, status="Killed")
+                            self._log_event(
+                                "warning", f"Model {model_name} process killed directly"
+                            )
+                            if self.debug:
+                                print(
+                                    f"[@ollama {model_name}] Process killed directly."
+                                )
+                        except Exception as kill_e:
+                            process_info["status"] = ProcessStatus.FAILED
+                            shutdown_cause = "failed"
+                            self._update_model_status(
+                                model_name, status="Failed to stop"
+                            )
+                            self._log_event(
+                                "error",
+                                f"Model {model_name} failed to stop: {str(kill_e)}",
+                            )
+                            print(
+                                f"[@ollama {model_name}] Error killing process directly: {kill_e}"
+                            )
+
+                except Exception as e:
+                    process_info["status"] = ProcessStatus.FAILED
+                    shutdown_cause = "failed"
+                    self._update_model_status(model_name, status="Failed to stop")
+                    self._log_event(
+                        "error", f"Model {model_name} shutdown error: {str(e)}"
+                    )
+                    print(f"[@ollama {model_name}] Error stopping: {e}")
+
+                # Record model shutdown timing
+                model_shutdown_time = time.time() - model_shutdown_start
+                model_shutdown_results[model_name] = {
+                    "shutdown_time": model_shutdown_time,
+                    "shutdown_cause": shutdown_cause,
+                }
+
+                # Smart cache update logic
+                should_update = self._should_update_cache(model_name)
+                if should_update:
+                    self._log_event(
+                        "info",
+                        f"Updating cache for {model_name} ({self.cache_update_policy} policy)",
+                    )
+                    self._update_model_cache(model_name)
+                else:
+                    cache_reason = f"policy is '{self.cache_update_policy}'"
+                    if (
+                        self.cache_update_policy == "auto"
+                        and self.cache_status.get(model_name) == "exists"
+                    ):
+                        cache_reason = "cache already exists"
+                    self._log_event(
+                        "info",
+                        f"Skipping cache update for {model_name} ({cache_reason})",
+                    )
+                    if self.debug:
+                        print(
+                            f"[@ollama {model_name}] Skipping cache update: {cache_reason}"
+                        )
+
+        # Stop the API server
+        server_shutdown_cause = "graceful"
+        server_shutdown_start = time.time()
+        for pid, process_info in list(self.processes.items()):
+            if process_info["properties"].get("type") == "api-server":
+                self._update_server_status("Stopping")
+                self._log_event("info", "Stopping Ollama API server")
+                if self.debug:
+                    print(f"[@ollama] Stopping API server process PID {pid}.")
+
+                process = process_info["p"]
+                try:
+                    process.terminate()
+                    try:
+                        process.wait(timeout=5)
+                    except subprocess.TimeoutExpired:
+                        server_shutdown_cause = "force_kill"
+                        self._log_event(
+                            "warning",
+                            "API server did not terminate gracefully, killing...",
+                        )
+                        print(
+                            f"[@ollama] API server PID {pid} did not terminate, killing..."
+                        )
+                        process.kill()
+                        process.wait()
+
+                    process_info["status"] = ProcessStatus.SUCCESSFUL
+                    self._update_server_status("Stopped")
+                    self._log_event(
+                        "success", f"API server stopped ({server_shutdown_cause})"
+                    )
+                    if self.debug:
+                        print("[@ollama] API server terminated successfully.")
+                except Exception as e:
+                    process_info["status"] = ProcessStatus.FAILED
+                    server_shutdown_cause = "failed"
+                    self._update_server_status("Failed to stop")
+                    self._log_event("error", f"API server shutdown error: {str(e)}")
+                    print(f"[@ollama] Warning: Error terminating API server: {e}")
+
+        # Record total shutdown time and performance metrics
+        total_shutdown_time = time.time() - shutdown_start_time
+        server_shutdown_time = time.time() - server_shutdown_start
+
+        # Update performance metrics
+        self._update_performance("server_shutdown_time", server_shutdown_time)
+        self._update_performance("total_shutdown_time", total_shutdown_time)
+        self._update_performance("shutdown_cause", server_shutdown_cause)
+
+        # Log individual model shutdown times
+        for model_name, results in model_shutdown_results.items():
+            self._update_performance(
+                f"{model_name}_shutdown_time", results["shutdown_time"]
+            )
+            self._update_performance(
+                f"{model_name}_shutdown_cause", results["shutdown_cause"]
+            )
+
+        self._log_event(
+            "success", f"Ollama shutdown completed in {total_shutdown_time:.1f}s"
+        )
+        print("[@ollama] All models stopped.")
+
+        # Show performance summary
+        if self.debug and hasattr(self, "stats") and self.stats:
+            print("[@ollama] Performance summary:")
+            for operation, stats in self.stats.items():
+                runtime = stats.get("process_runtime", 0)
+                if runtime > 1:  # Only show operations that took meaningful time
+                    print(f"[@ollama] {operation}: {runtime:.1f}s")
+
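The legacy `skip_push_check` flag is translated into the newer `cache_update_policy` setting, so old callers keep working. A hedged usage sketch (`manager` is a hypothetical instance of this class):

    # skip_push_check=True  -> cache_update_policy = "never" (no cache writes)
    # skip_push_check=False -> cache_update_policy = "force" (always write)
    manager.terminate_models(skip_push_check=True)

    # Preferred: leave skip_push_check unset and set cache_update_policy directly.
    manager.cache_update_policy = "auto"
    manager.terminate_models()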
+    def _update_model_cache(self, model_name):
+        """
+        Update the remote cache with model files if needed.
+        """
+        try:
+            manifest = self._fetch_manifest(model_name)
+            if not manifest:
+                if self.debug:
+                    print(
+                        f"[@ollama {model_name}] No manifest available for cache update."
+                    )
+                return
+
+            from metaflow import S3
+
+            cache_up_to_date = True
+            key_paths = [
+                (
+                    self.storage_info[model_name]["manifest_remote"],
+                    self.storage_info[model_name]["manifest_local"],
+                )
+            ]
+
+            with S3() as s3:
+                # Check whether any blobs are missing from the cache. Cached
+                # blob keys use the digest with ":" replaced by "-", so compare
+                # against that form rather than the raw digest string.
+                s3objs = s3.list_paths(
+                    [self.storage_info[model_name]["blob_remote_root"]]
+                )
+                existing_blobs = {obj.key.split("/")[-1] for obj in s3objs}
+                for layer in manifest["layers"]:
+                    expected_blob = layer["digest"].replace(":", "-")
+                    if expected_blob not in existing_blobs:
+                        cache_up_to_date = False
+                        break
+
+                if not cache_up_to_date:
+                    blob_count = len(manifest.get("layers", []))
+                    print(
+                        f"[@ollama {model_name}] Uploading {blob_count} files to cache..."
+                    )
+
+                    # Add blob paths to upload
+                    for layer in manifest["layers"]:
+                        blob_filename = layer["digest"].replace(":", "-")
+                        key_paths.append(
+                            (
+                                os.path.join(
+                                    self.storage_info[model_name]["blob_remote_root"],
+                                    blob_filename,
+                                ),
+                                os.path.join(
+                                    self.storage_info[model_name]["blob_local_root"],
+                                    blob_filename,
+                                ),
+                            )
+                        )
+
+                    s3.put_files(key_paths)
+                    print(f"[@ollama {model_name}] Cache updated.")
+                else:
+                    if self.debug:
+                        print(f"[@ollama {model_name}] Cache is up to date.")
+
+        except Exception as e:
+            if self.debug:
+                print(f"[@ollama {model_name}] Error updating cache: {e}")
+
+    def get_ollama_storage_root(self, backend):
+        """
+        Return the path to the root of the datastore.
+        """
+        if backend.TYPE == "s3":
+            from metaflow.metaflow_config import DATASTORE_SYSROOT_S3
+
+            self.local_datastore = False
+            return os.path.join(DATASTORE_SYSROOT_S3, OLLAMA_SUFFIX)
+        elif backend.TYPE == "azure":
+            from metaflow.metaflow_config import DATASTORE_SYSROOT_AZURE
+
+            self.local_datastore = False
+            return os.path.join(DATASTORE_SYSROOT_AZURE, OLLAMA_SUFFIX)
+        elif backend.TYPE == "gs":
+            from metaflow.metaflow_config import DATASTORE_SYSROOT_GS
+
+            self.local_datastore = False
+            return os.path.join(DATASTORE_SYSROOT_GS, OLLAMA_SUFFIX)
+        else:
+            self.local_datastore = True
+            return None
+
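For a cloud datastore the cache root is simply the Metaflow datastore sysroot with the Ollama suffix appended. Assuming `OLLAMA_SUFFIX` (defined elsewhere in this module) is a constant like `"ollama"`:

    # DATASTORE_SYSROOT_S3 = "s3://my-bucket/metaflow"  (from Metaflow config)
    # os.path.join(DATASTORE_SYSROOT_S3, OLLAMA_SUFFIX)
    # -> "s3://my-bucket/metaflow/ollama"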
+    def _attempt_ollama_restart(self):
+        """Attempt to restart Ollama when the circuit breaker suggests it."""
+        try:
+            print("[@ollama] Attempting Ollama restart due to circuit breaker...")
+
+            # Stop existing server processes
+            server_stopped = False
+            for pid, process_info in list(self.processes.items()):
+                if process_info["properties"].get("type") == "api-server":
+                    process = process_info["p"]
+                    try:
+                        process.terminate()
+                        process.wait(timeout=10)
+                        if process.poll() is None:
+                            process.kill()
+                            process.wait()
+                        process_info["status"] = ProcessStatus.SUCCESSFUL
+                        server_stopped = True
+                        if self.debug:
+                            print(
+                                f"[@ollama] Stopped server process {pid} during restart"
+                            )
+                    except Exception as e:
+                        if self.debug:
+                            print(
+                                f"[@ollama] Error stopping server {pid} during restart: {e}"
+                            )
+
+            if not server_stopped:
+                if self.debug:
+                    print("[@ollama] No server process found to stop during restart")
+
+            # Small delay to ensure cleanup
+            time.sleep(2)
+
+            # Restart the server
+            self._launch_server()
+
+            # Verify health with multiple attempts
+            health_attempts = 3
+            for attempt in range(health_attempts):
+                if self._verify_server_health():
+                    print("[@ollama] Restart successful")
+                    return True
+                if attempt < health_attempts - 1:
+                    if self.debug:
+                        print(
+                            f"[@ollama] Health check failed, attempt {attempt + 1}/{health_attempts}"
+                        )
+                    time.sleep(5)
+
+            print(
+                "[@ollama] Restart failed - server not healthy after multiple attempts"
+            )
+            return False
+
+        except Exception as e:
+            print(f"[@ollama] Restart failed: {e}")
+            return False
+
+    def _verify_server_health(self):
+        """Quick health check for server availability."""
+        try:
+            response = requests.get(
+                f"{self.ollama_url}/api/tags",
+                timeout=self.timeouts.get("health_check", 5),
+            )
+            return response.status_code == 200
+        except Exception:
+            return False
+
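`GET /api/tags` merely lists locally available models, which makes it a cheap liveness probe. A sketch of polling it until the server comes up (URL, attempt count, and delay are illustrative; 11434 is Ollama's default port):

    import time
    import requests

    def wait_until_healthy(url="http://localhost:11434", attempts=10, delay=2):
        # Poll GET {url}/api/tags until it returns 200 or attempts run out.
        for _ in range(attempts):
            try:
                if requests.get(f"{url}/api/tags", timeout=5).status_code == 200:
                    return True
            except requests.RequestException:
                pass
            time.sleep(delay)
        return False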
+    def _collect_model_metadata(self, m):
+        """
+        Collect model metadata, including size and blob count, from the manifest and API.
+        """
+        metadata = {"size_bytes": None, "size_formatted": "Unknown", "blob_count": 0}
+
+        try:
+            # First try to get info from the manifest (works for cached models)
+            manifest = self._fetch_manifest(m)
+            if manifest and "layers" in manifest:
+                metadata["blob_count"] = len(manifest["layers"])
+
+                # Calculate total size from manifest layers
+                total_size = 0
+                for layer in manifest["layers"]:
+                    if "size" in layer:
+                        total_size += layer["size"]
+
+                if total_size > 0:
+                    metadata["size_bytes"] = total_size
+                    metadata["size_formatted"] = self._format_bytes(total_size)
+
+            # Try to get more detailed info from the Ollama API if available
+            try:
+                response = requests.post(
+                    f"{self.ollama_url}/api/show", json={"model": m}, timeout=10
+                )
+                if response.status_code == 200:
+                    model_info = response.json()
+
+                    # Extract size if available in the response
+                    if (
+                        "details" in model_info
+                        and "parameter_size" in model_info["details"]
+                    ):
+                        # Sometimes the API returns parameter size info
+                        param_size = model_info["details"]["parameter_size"]
+                        if self.debug:
+                            print(f"[@ollama {m}] Parameter size: {param_size}")
+
+                    # If we got model_info but had no manifest info, mark the blob
+                    # count as coming from the API rather than leaving it at 0
+                    if metadata["blob_count"] == 0 and "details" in model_info:
+                        details = model_info["details"]
+                        if "families" in details or "family" in details:
+                            # API response structure varies, so use a placeholder
+                            metadata["blob_count"] = "API"
+
+            except Exception as api_e:
+                if self.debug:
+                    print(f"[@ollama {m}] Could not get API metadata: {api_e}")
+
+        except Exception as e:
+            if self.debug:
+                print(f"[@ollama {m}] Error collecting model metadata: {e}")
+
+        return metadata
+
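Summing layer sizes out of a manifest is simple arithmetic; a sketch against an illustrative manifest dict shaped like the ones this code reads:

    manifest = {
        "layers": [
            {"digest": "sha256:abc", "size": 4_000_000_000},
            {"digest": "sha256:def", "size": 120_000_000},
        ]
    }
    total = sum(layer.get("size", 0) for layer in manifest["layers"])
    # total == 4_120_000_000 bytes; blob count == len(manifest["layers"]) == 2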
+    def _format_bytes(self, bytes_count):
+        """Format a byte count into a human-readable string."""
+        if bytes_count is None:
+            return "Unknown"
+
+        for unit in ["B", "KB", "MB", "GB", "TB"]:
+            if bytes_count < 1024.0:
+                if unit == "B":
+                    return f"{int(bytes_count)} {unit}"
+                return f"{bytes_count:.1f} {unit}"
+            bytes_count /= 1024.0
+        return f"{bytes_count:.1f} PB"