ob-metaflow-extensions 1.1.151__py2.py3-none-any.whl → 1.4.33__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. metaflow_extensions/outerbounds/__init__.py +1 -1
  2. metaflow_extensions/outerbounds/plugins/__init__.py +17 -3
  3. metaflow_extensions/outerbounds/plugins/apps/app_cli.py +0 -0
  4. metaflow_extensions/outerbounds/plugins/apps/app_deploy_decorator.py +146 -0
  5. metaflow_extensions/outerbounds/plugins/apps/core/__init__.py +10 -0
  6. metaflow_extensions/outerbounds/plugins/apps/core/_state_machine.py +506 -0
  7. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/__init__.py +0 -0
  8. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/__init__.py +4 -0
  9. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/spinners.py +478 -0
  10. metaflow_extensions/outerbounds/plugins/apps/core/app_cli.py +1200 -0
  11. metaflow_extensions/outerbounds/plugins/apps/core/app_config.py +146 -0
  12. metaflow_extensions/outerbounds/plugins/apps/core/artifacts.py +0 -0
  13. metaflow_extensions/outerbounds/plugins/apps/core/capsule.py +958 -0
  14. metaflow_extensions/outerbounds/plugins/apps/core/click_importer.py +24 -0
  15. metaflow_extensions/outerbounds/plugins/apps/core/code_package/__init__.py +3 -0
  16. metaflow_extensions/outerbounds/plugins/apps/core/code_package/code_packager.py +618 -0
  17. metaflow_extensions/outerbounds/plugins/apps/core/code_package/examples.py +125 -0
  18. metaflow_extensions/outerbounds/plugins/apps/core/config/__init__.py +12 -0
  19. metaflow_extensions/outerbounds/plugins/apps/core/config/cli_generator.py +161 -0
  20. metaflow_extensions/outerbounds/plugins/apps/core/config/config_utils.py +868 -0
  21. metaflow_extensions/outerbounds/plugins/apps/core/config/schema_export.py +288 -0
  22. metaflow_extensions/outerbounds/plugins/apps/core/config/typed_configs.py +139 -0
  23. metaflow_extensions/outerbounds/plugins/apps/core/config/typed_init_generator.py +398 -0
  24. metaflow_extensions/outerbounds/plugins/apps/core/config/unified_config.py +1088 -0
  25. metaflow_extensions/outerbounds/plugins/apps/core/config_schema.yaml +337 -0
  26. metaflow_extensions/outerbounds/plugins/apps/core/dependencies.py +115 -0
  27. metaflow_extensions/outerbounds/plugins/apps/core/deployer.py +303 -0
  28. metaflow_extensions/outerbounds/plugins/apps/core/experimental/__init__.py +89 -0
  29. metaflow_extensions/outerbounds/plugins/apps/core/perimeters.py +87 -0
  30. metaflow_extensions/outerbounds/plugins/apps/core/secrets.py +164 -0
  31. metaflow_extensions/outerbounds/plugins/apps/core/utils.py +233 -0
  32. metaflow_extensions/outerbounds/plugins/apps/core/validations.py +17 -0
  33. metaflow_extensions/outerbounds/plugins/aws/__init__.py +4 -0
  34. metaflow_extensions/outerbounds/plugins/aws/assume_role.py +3 -0
  35. metaflow_extensions/outerbounds/plugins/aws/assume_role_decorator.py +78 -0
  36. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/coreweave.py +9 -77
  37. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/external_chckpt.py +85 -0
  38. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/nebius.py +7 -78
  39. metaflow_extensions/outerbounds/plugins/fast_bakery/baker.py +110 -0
  40. metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py +17 -3
  41. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py +1 -0
  42. metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py +18 -44
  43. metaflow_extensions/outerbounds/plugins/kubernetes/pod_killer.py +374 -0
  44. metaflow_extensions/outerbounds/plugins/nim/card.py +1 -6
  45. metaflow_extensions/outerbounds/plugins/nim/{__init__.py → nim_decorator.py} +13 -49
  46. metaflow_extensions/outerbounds/plugins/nim/nim_manager.py +294 -233
  47. metaflow_extensions/outerbounds/plugins/nim/utils.py +36 -0
  48. metaflow_extensions/outerbounds/plugins/nvcf/constants.py +2 -2
  49. metaflow_extensions/outerbounds/plugins/nvct/nvct_decorator.py +32 -8
  50. metaflow_extensions/outerbounds/plugins/nvct/nvct_runner.py +1 -1
  51. metaflow_extensions/outerbounds/plugins/ollama/__init__.py +171 -16
  52. metaflow_extensions/outerbounds/plugins/ollama/constants.py +1 -0
  53. metaflow_extensions/outerbounds/plugins/ollama/exceptions.py +22 -0
  54. metaflow_extensions/outerbounds/plugins/ollama/ollama.py +1710 -114
  55. metaflow_extensions/outerbounds/plugins/ollama/status_card.py +292 -0
  56. metaflow_extensions/outerbounds/plugins/optuna/__init__.py +48 -0
  57. metaflow_extensions/outerbounds/plugins/profilers/simple_card_decorator.py +96 -0
  58. metaflow_extensions/outerbounds/plugins/s3_proxy/__init__.py +7 -0
  59. metaflow_extensions/outerbounds/plugins/s3_proxy/binary_caller.py +132 -0
  60. metaflow_extensions/outerbounds/plugins/s3_proxy/constants.py +11 -0
  61. metaflow_extensions/outerbounds/plugins/s3_proxy/exceptions.py +13 -0
  62. metaflow_extensions/outerbounds/plugins/s3_proxy/proxy_bootstrap.py +59 -0
  63. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_api.py +93 -0
  64. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_decorator.py +250 -0
  65. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_manager.py +225 -0
  66. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py +6 -3
  67. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_decorator.py +13 -7
  68. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py +8 -2
  69. metaflow_extensions/outerbounds/plugins/torchtune/__init__.py +163 -0
  70. metaflow_extensions/outerbounds/plugins/vllm/__init__.py +255 -0
  71. metaflow_extensions/outerbounds/plugins/vllm/constants.py +1 -0
  72. metaflow_extensions/outerbounds/plugins/vllm/exceptions.py +1 -0
  73. metaflow_extensions/outerbounds/plugins/vllm/status_card.py +352 -0
  74. metaflow_extensions/outerbounds/plugins/vllm/vllm_manager.py +621 -0
  75. metaflow_extensions/outerbounds/remote_config.py +27 -3
  76. metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py +86 -2
  77. metaflow_extensions/outerbounds/toplevel/ob_internal.py +4 -0
  78. metaflow_extensions/outerbounds/toplevel/plugins/optuna/__init__.py +1 -0
  79. metaflow_extensions/outerbounds/toplevel/plugins/torchtune/__init__.py +1 -0
  80. metaflow_extensions/outerbounds/toplevel/plugins/vllm/__init__.py +1 -0
  81. metaflow_extensions/outerbounds/toplevel/s3_proxy.py +88 -0
  82. {ob_metaflow_extensions-1.1.151.dist-info → ob_metaflow_extensions-1.4.33.dist-info}/METADATA +2 -2
  83. ob_metaflow_extensions-1.4.33.dist-info/RECORD +134 -0
  84. metaflow_extensions/outerbounds/plugins/nim/utilities.py +0 -5
  85. ob_metaflow_extensions-1.1.151.dist-info/RECORD +0 -74
  86. {ob_metaflow_extensions-1.1.151.dist-info → ob_metaflow_extensions-1.4.33.dist-info}/WHEEL +0 -0
  87. {ob_metaflow_extensions-1.1.151.dist-info → ob_metaflow_extensions-1.4.33.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,958 @@
1
+ from datetime import datetime
2
+ import json
3
+ import os
4
+ import pathlib
5
+ import requests
6
+ import sys
7
+ import time
8
+ from functools import partial
9
+ import shlex
10
+ from typing import Optional, List, Dict, Any, Tuple, Union, Callable
11
+ from .utils import TODOException, safe_requests_wrapper, MaximumRetriesExceeded
12
+ from .app_config import AppConfig, CAPSULE_DEBUG, AuthType
13
+ from . import experimental
14
+ from ._state_machine import (
15
+ _capsule_worker_semantic_status,
16
+ _capsule_worker_status_diff,
17
+ CapsuleWorkerSemanticStatus,
18
+ WorkerStatus,
19
+ CapsuleStatus,
20
+ DEPLOYMENT_READY_CONDITIONS,
21
+ LogLine,
22
+ )
23
+
24
+
25
+ def _format_url_string(url):
26
+ if url is None:
27
+ return None
28
+
29
+ if url.startswith("http://") or url.startswith("https://"):
30
+ return url
31
+
32
+ return f"https://{url}"
33
+
34
+
35
class CapsuleStateMachine:
    """
    Accumulates the top-level status of a capsule deployment over time.

    Backend semantics this tracker relies on:
    - Every capsule create/update call returns an `identifier` and a `version`
      of the object; each update yields a new version.
    - `status.currentlyServedVersion` is the version currently serving traffic.
    - `status.updateInProgress` is True while an upgrade is rolling out.

    Deployment flow (create and upgrade behave alike):
    - Happy path: wait for `status.updateInProgress` to flip to False, while
      interleaving worker-endpoint polls so the CLI can show how many workers
      are coming up. If the user opted out of waiting for full completion
      (e.g. a `--dont-wait-to-fully-finish` style flag), check
      `status.currentlyServedVersion` instead to see whether even one replica
      is already able to serve traffic.
    - Unhappy path: while waiting for `status.updateInProgress` to become
      False, the interleaved worker polls may show that the worker for the
      current deployment instance version is crashlooping; the deployment
      process is then crashed with the error messages and logs.
    """

    def __init__(self, capsule_id: str, current_deployment_instance_version: str):
        self._capsule_id = capsule_id
        self._status_trail: List[Dict[str, Any]] = []
        self._current_deployment_instance_version = current_deployment_instance_version

    def get_status_trail(self):
        return self._status_trail

    def add_status(self, status: CapsuleStatus):
        # Keep the raw status together with the observation time.
        self._status_trail.append({"timestamp": time.time(), "status": status})

    @property
    def current_status(self):
        return self._status_trail[-1].get("status")

    @property
    def out_of_cluster_url(self):
        info = self.current_status.get("accessInfo", {}) or {}
        return _format_url_string(info.get("outOfClusterURL", None))

    @property
    def in_cluster_url(self):
        info = self.current_status.get("accessInfo", {}) or {}
        return _format_url_string(info.get("inClusterURL", None))

    @property
    def update_in_progress(self):
        return self.current_status.get("updateInProgress", False)

    @property
    def currently_served_version(self):
        return self.current_status.get("currentlyServedVersion", None)

    @property
    def ready_to_serve_traffic(self):
        # Readiness additionally requires at least one reachable URL.
        if not self.current_status.get("readyToServeTraffic", False):
            return False
        return any(
            u is not None for u in (self.out_of_cluster_url, self.in_cluster_url)
        )

    @property
    def available_replicas(self):
        return self.current_status.get("availableReplicas", 0)

    def report_current_status(self, logger):
        # Intentionally a no-op; kept for interface parity with
        # CapsuleWorkersStateMachine.report_current_status.
        pass

    def save_debug_info(self, state_dir: str):
        # Dump the full observation trail for post-mortem debugging.
        debug_path = os.path.join(
            state_dir, f"debug_capsule_sm_{self._capsule_id}.json"
        )
        with open(debug_path, "w") as f:
            json.dump(self._status_trail, f, indent=4)
123
+
124
+
125
class CapsuleWorkersStateMachine:
    """Accumulates per-worker status for a capsule rollout and reports
    human-readable deltas between successive polls."""

    def __init__(
        self,
        capsule_id: str,
        end_state_capsule_version: str,
        deployment_mode: str = DEPLOYMENT_READY_CONDITIONS.ATLEAST_ONE_RUNNING,
        minimum_replicas: int = 1,
    ):
        self._capsule_id = capsule_id
        self._end_state_capsule_version = end_state_capsule_version
        self._deployment_mode = deployment_mode
        self._minimum_replicas = minimum_replicas
        self._status_trail: List[Dict[str, Union[float, List[WorkerStatus]]]] = []

    def get_status_trail(self):
        return self._status_trail

    def add_status(self, worker_list_response: List[WorkerStatus]):
        """Record one worker-list poll result with its observation time.

        worker_list_response: List[Dict[str, Any]], e.g.
        [
            {
                "workerId": "c-4pqikm-659dd9ccdc-5hcwz",
                "phase": "Running",
                "activity": 0,
                "activityDataAvailable": false,
                "version": "0xhgaewiqb"
            },
            {
                "workerId": "c-4pqikm-b8559688b-xk2jh",
                "phase": "Pending",
                "activity": 0,
                "activityDataAvailable": false,
                "version": "421h48qh95"
            }
        ]
        """
        entry = {"timestamp": time.time(), "status": worker_list_response}
        self._status_trail.append(entry)

    def save_debug_info(self, state_dir: str):
        # Full observation trail.
        trail_file = os.path.join(
            state_dir, f"debug_capsule_workers_{self._capsule_id}_trail.json"
        )
        with open(trail_file, "w") as f:
            json.dump(self._status_trail, f, indent=4)

        # Semantic summary of the latest observation.
        summary_file = os.path.join(
            state_dir, f"debug_capsule_workers_{self._capsule_id}_status.json"
        )
        with open(summary_file, "w") as f:
            json.dump(self.current_version_deployment_status(), f, indent=4)

    def report_current_status(self, logger):
        # Nothing observed yet: nothing to report.
        if not self._status_trail:
            return
        previous = None
        if len(self._status_trail) > 1:
            previous = _capsule_worker_semantic_status(
                self._status_trail[-2].get("status"),
                self._end_state_capsule_version,
                self._minimum_replicas,
            )
        latest = self.current_version_deployment_status()
        # Only log when something actually changed since the last poll.
        changes = _capsule_worker_status_diff(latest, previous)
        if changes:
            logger(*changes)

    @property
    def current_status(self) -> List[WorkerStatus]:
        return self._status_trail[-1].get("status")  # type: ignore

    def current_version_deployment_status(self) -> CapsuleWorkerSemanticStatus:
        return _capsule_worker_semantic_status(
            self.current_status, self._end_state_capsule_version, self._minimum_replicas
        )

    @property
    def is_crashlooping(self) -> bool:
        semantic = self.current_version_deployment_status()
        return semantic["status"]["at_least_one_crashlooping"]
207
+
208
+
209
class CapsuleInput:
    """Builds the JSON payload the capsule backend expects from an AppConfig."""

    @classmethod
    def construct_exec_command(cls, commands: List[str]):
        """Pack `commands` into a single `bash -c` invocation string.

        The commands are prefixed with `set -eEuo pipefail`, newline-joined,
        and base64-encoded. One of the reasons we don't directly pass the
        command string to the backend with a newline join is because the
        backend controller doesn't play nice when the command can be a
        multi-line string. So we encode it to a base64 string and decode it
        back to a command string at runtime to provide to `bash -c`. The ideal
        thing to have done is to run "bash -c {shlex.quote(command_string)}"
        and call it a day but the backend controller yields the following error:
        `error parsing template: error converting YAML to JSON: yaml: line 111: mapping values are not allowed in this context`
        So we go to great length to ensure the command is provided in base64
        to avoid any issues with the backend controller.
        """
        commands = ["set -eEuo pipefail"] + commands
        command_string = "\n".join(commands)
        import base64

        encoded_command = base64.b64encode(command_string.encode()).decode()
        decode_cmd = f"echo {encoded_command} | base64 -d > ./_ob_app_run.sh"
        return (
            f"bash -c '{decode_cmd} && cat ./_ob_app_run.sh && bash ./_ob_app_run.sh'"
        )

    @classmethod
    def _marshal_environment_variables(cls, app_config: AppConfig):
        """Convert the config's `environment` mapping into the backend's
        [{"name": ..., "value": ...}] list form.

        Dicts and lists are JSON-encoded so nested structure survives the
        env-var round trip; every other value is plain-stringified.
        """
        envs = app_config.get_state("environment", {}).copy()
        _return = []
        for k, v in envs.items():
            # Merged dict/list branches (both JSON-encoded) and dropped the
            # dead `_v = v` pre-assignment that was always overwritten.
            if isinstance(v, (dict, list)):
                _v = json.dumps(v)
            else:
                _v = str(v)
            _return.append(
                {
                    "name": k,
                    "value": _v,
                }
            )
        return _return

    @classmethod
    def from_app_config(cls, app_config: AppConfig):
        """Assemble the full capsule-create payload from `app_config` state."""
        ## Replica settings
        replicas = app_config.get_state("replicas", {})
        fixed, _min, _max = (
            replicas.get("fixed"),
            replicas.get("min"),
            replicas.get("max"),
        )
        rpm = replicas.get("scaling_policy", {}).get("rpm", None)
        autoscaling_config = {}
        if rpm:
            autoscaling_config = {
                "requestRateBasedAutoscalingConfig": {"targetRequestsPerMinute": rpm}
            }
        # A fixed replica count pins min == max.
        if fixed is not None:
            _min, _max = fixed, fixed

        ## Optional resources (only sent when configured)
        gpu_resource = app_config.get_state("resources").get("gpu")
        resources = {}
        shared_memory = app_config.get_state("resources").get("shared_memory")
        if gpu_resource:
            resources["gpu"] = gpu_resource
        if shared_memory:
            resources["sharedMemory"] = shared_memory

        ## Compute-pool scheduling (optional)
        _scheduling_config = {}
        if app_config.get_state("compute_pools", None):
            _scheduling_config["schedulingConfig"] = {
                "computePools": [
                    {"name": x} for x in app_config.get_state("compute_pools")
                ]
            }

        ## Optional descriptive metadata
        _description = app_config.get_state("description")
        _app_type = app_config.get_state("app_type")
        _final_info = {}
        if _description:
            _final_info["description"] = _description
        if _app_type:
            _final_info["endpointType"] = _app_type

        return {
            "perimeter": app_config.get_state("perimeter"),
            **_final_info,
            "codePackagePath": app_config.get_state("code_package_url"),
            "image": app_config.get_state("image"),
            "resourceIntegrations": [
                {"name": x} for x in app_config.get_state("secrets", [])
            ],
            "resourceConfig": {
                "cpu": str(app_config.get_state("resources").get("cpu")),
                "memory": str(app_config.get_state("resources").get("memory")),
                "ephemeralStorage": str(app_config.get_state("resources").get("disk")),
                **resources,
            },
            "autoscalingConfig": {
                "minReplicas": _min,
                "maxReplicas": _max,
                **autoscaling_config,
            },
            **_scheduling_config,
            "containerStartupConfig": {
                "entrypoint": cls.construct_exec_command(
                    app_config.get_state("commands")
                )
            },
            "environmentVariables": cls._marshal_environment_variables(app_config),
            # "assets": [{"name": "startup-script.sh"}],
            "authConfig": {
                "authType": app_config.get_state("auth").get("type"),
                "publicToDeployment": app_config.get_state("auth").get("public"),
            },
            # Tags arrive as a list of single-entry dicts; flatten to
            # backend-style {key, value} pairs.
            "tags": [
                dict(key=k, value=v)
                for tag in app_config.get_state("tags", [])
                for k, v in tag.items()
            ],
            "port": app_config.get_state("port"),
            "displayName": app_config.get_state("name"),
            "forceUpdate": app_config.get_state("force_upgrade", False),
        }
329
+
330
+
331
class CapsuleApiException(Exception):
    """Raised when a capsule backend HTTP call fails (bad status, exhausted
    retries, or an unparseable response body)."""

    def __init__(
        self,
        url: str,
        method: str,
        status_code: int,
        text: str,
        message: Optional[str] = None,
    ):
        self.url = url
        self.method = method
        self.status_code = status_code
        self.text = text
        self.message = message

    def __str__(self):
        rendered = f"CapsuleApiException: {self.url} [{self.method}]: Status Code: {self.status_code} \n\n {self.text}"
        if self.message:
            rendered += f"\n\n {self.message}"
        return rendered
351
+
352
+
353
class CapsuleDeploymentException(Exception):
    """Raised when a capsule deployment cannot proceed — for example, the
    capsule was upgraded out-of-band or its workers are crashlooping."""

    def __init__(
        self,
        capsule_id: str,
        message: str,
    ):
        self.capsule_id = capsule_id
        self.message = message

    def __str__(self):
        return "CapsuleDeploymentException: [%s] :: %s" % (
            self.capsule_id,
            self.message,
        )
364
+
365
+
366
class CapsuleApi:
    """Thin HTTP client for the capsule backend's REST API.

    All calls funnel through `_wrapped_api_caller`, which attaches the
    standard headers, retries via `safe_requests_wrapper`, and converts
    failures into `CapsuleApiException`s.
    """

    def __init__(self, base_url: str, perimeter: str, logger_fn=None):
        self._base_url = self._create_base_url(base_url, perimeter)
        # NOTE(review): local import — presumably to defer Metaflow config
        # resolution until a client is actually constructed; confirm.
        from metaflow.metaflow_config import SERVICE_HEADERS

        self._logger_fn = logger_fn
        # Base JSON headers, overlaid with any service-auth headers Metaflow
        # is configured with.
        self._request_headers = {
            **{"Content-Type": "application/json", "Connection": "keep-alive"},
            **(SERVICE_HEADERS or {}),
        }

    @staticmethod
    def _create_base_url(base_url: str, perimeter: str):
        """Return `<base_url>/v1/perimeters/<perimeter>/capsules`."""
        return os.path.join(
            base_url,
            "v1",
            "perimeters",
            perimeter,
            "capsules",
        )

    def _wrapped_api_caller(self, method_func, *args, **kwargs):
        """Call a `requests` verb through the retrying wrapper.

        Raises:
            CapsuleApiException: when retries are exhausted, or when the
                final response has a 4xx/5xx status code.
        """
        try:
            response = safe_requests_wrapper(
                method_func,
                *args,
                headers=self._request_headers,
                logger_fn=self._logger_fn,
                **kwargs,
            )
        except MaximumRetriesExceeded as e:
            raise CapsuleApiException(
                e.url,
                e.method,
                e.status_code,
                e.text,
                message=f"Maximum retries exceeded for {e.url} [{e.method}]",
            )
        if response.status_code >= 400:
            # args[0] is the URL positional argument passed to the verb.
            raise CapsuleApiException(
                args[0],
                method_func.__name__,
                response.status_code,
                response.text,
            )
        return response

    def create(self, capsule_input: dict):
        """POST `capsule_input` to create/update a capsule; returns the
        decoded JSON response."""
        _data = json.dumps(capsule_input)
        response = self._wrapped_api_caller(
            requests.post,
            self._base_url,
            data=_data,
        )
        try:
            return response.json()
        except json.JSONDecodeError as e:
            raise CapsuleApiException(
                self._base_url,
                "post",
                response.status_code,
                response.text,
                message="Capsule JSON decode failed",
            )

    def get(self, capsule_id: str) -> Dict[str, Any]:
        """GET a single capsule by id; returns the decoded JSON response."""
        _url = os.path.join(self._base_url, capsule_id)
        response = self._wrapped_api_caller(
            requests.get,
            _url,
            retryable_status_codes=[409, 404],  # todo : verify me
            conn_error_retries=3,
        )
        try:
            return response.json()
        except json.JSONDecodeError as e:
            raise CapsuleApiException(
                _url,
                "get",
                response.status_code,
                response.text,
                message="Capsule JSON decode failed",
            )

    # TODO: refactor me since name *currently(9/8/25)* is unique across capsules.
    def get_by_name(self, name: str, most_recent_only: bool = True):
        """Look up capsules by display name.

        With `most_recent_only` (the default), returns the newest matching
        capsule (by `metadata.createdAt`) or None when nothing matches;
        otherwise returns the raw JSON listing response.
        """
        # NOTE(review): `name` is interpolated un-encoded into the query
        # string — confirm callers never pass characters needing URL-escaping.
        _url = os.path.join(self._base_url, f"?displayName={name}")
        response = self._wrapped_api_caller(
            requests.get,
            _url,
            retryable_status_codes=[409],  # todo : verify me
            conn_error_retries=3,
        )
        try:
            if most_recent_only:
                result = response.json()
                candidates = result["capsules"]
                if not candidates:
                    return None
                return sorted(
                    candidates, key=lambda x: x["metadata"]["createdAt"], reverse=True
                )[0]
            else:
                return response.json()
        except json.JSONDecodeError as e:
            raise CapsuleApiException(
                _url,
                "get",
                response.status_code,
                response.text,
                message="Capsule JSON decode failed",
            )

    def list(self):
        """Return all capsules in the perimeter as a list (possibly empty).

        Raises:
            CapsuleApiException: on undecodable JSON, or when the response
                lacks the expected "capsules" key.
        """
        response = self._wrapped_api_caller(
            requests.get,
            self._base_url,
            retryable_status_codes=[409],  # todo : verify me
            conn_error_retries=3,
        )
        try:
            response_json = response.json()
        except json.JSONDecodeError as e:
            raise CapsuleApiException(
                self._base_url,
                "get",
                response.status_code,
                response.text,
                message="Capsule JSON decode failed",
            )
        if "capsules" not in response_json:
            # NOTE(review): message reuses the decode-failure wording even
            # though this branch is a missing-key (shape) failure.
            raise CapsuleApiException(
                self._base_url,
                "get",
                response.status_code,
                response.text,
                message="Capsule JSON decode failed",
            )
        return response_json.get("capsules", []) or []

    def delete(self, capsule_id: str):
        """DELETE a capsule by id; returns True on HTTP 200, False otherwise."""
        _url = os.path.join(self._base_url, capsule_id)
        response = self._wrapped_api_caller(
            requests.delete,
            _url,
            retryable_status_codes=[409],  # todo : verify me
        )
        # Defensive: _wrapped_api_caller already raises on >= 400, so this
        # branch is normally unreachable.
        if response.status_code >= 400:
            raise CapsuleApiException(
                _url,
                "delete",
                response.status_code,
                response.text,
            )

        if response.status_code == 200:
            return True
        return False

    def get_workers(self, capsule_id: str) -> List[Dict[str, Any]]:
        """Return the worker list for a capsule (possibly empty)."""
        _url = os.path.join(self._base_url, capsule_id, "workers")
        response = self._wrapped_api_caller(
            requests.get,
            _url,
            retryable_status_codes=[409, 404],  # todo : verify me
            # Adding 404s because sometimes we might even end up getting 404s if
            # the backend cache is not updated yet. So on consistent 404s we should
            # just crash out.
            conn_error_retries=3,
        )
        try:
            return response.json().get("workers", []) or []
        except json.JSONDecodeError as e:
            raise CapsuleApiException(
                _url,
                "get",
                response.status_code,
                response.text,
                message="Capsule JSON decode failed",
            )

    def logs(
        self, capsule_id: str, worker_id: str, previous: bool = False
    ) -> List[LogLine]:
        """Return log lines for one worker; `previous=True` requests logs from
        the worker's prior container incarnation."""
        _url = os.path.join(self._base_url, capsule_id, "workers", worker_id, "logs")
        options = None
        if previous:
            options = {"previous": True}
        response = self._wrapped_api_caller(
            requests.get,
            _url,
            retryable_status_codes=[409],  # todo : verify me
            params=options,
        )
        try:
            return response.json().get("logs", []) or []
        except json.JSONDecodeError as e:
            raise CapsuleApiException(
                _url,
                "get",
                response.status_code,
                response.text,
                message="Capsule JSON decode failed",
            )

    def patch(self, capsule_id: str, patch_input: dict):
        """Fetch the capsule's current spec, overlay `patch_input`, and
        resubmit it via `create` (the backend treats create as an upsert).

        Raises:
            CapsuleApiException: when the fetched capsule has no usable spec.
        """
        capsule_response = self.get(capsule_id)
        if "spec" not in capsule_response or len(capsule_response.get("spec", {})) == 0:
            raise CapsuleApiException(
                self._base_url,
                "patch",
                403,
                "Capsule response of incorrect format",
            )

        spec = capsule_response.get("spec")
        spec.update(patch_input)
        return self.create(spec)
584
+
585
+
586
def list_and_filter_capsules(
    capsule_api: CapsuleApi, project, branch, name, tags, auth_type, capsule_id
):
    """Return every capsule visible via `capsule_api` that satisfies all of
    the provided filters; a falsy filter value means "don't filter on it"."""

    def _has_tag(capsule_tags, key, value):
        # True when some tag entry matches both key and value exactly.
        return any(t["key"] == key and t["value"] == value for t in capsule_tags)

    def _has_all_tags(capsule_tags, wanted):
        return all(_has_tag(capsule_tags, t["key"], t["value"]) for t in wanted)

    matched = []
    for capsule in capsule_api.list():
        spec = capsule.get("spec", {})
        capsule_tags = spec.get("tags", [])

        if auth_type and spec.get("authConfig", {}).get("authType", None) != auth_type:
            continue
        if project and not _has_tag(capsule_tags, "project", project):
            continue
        if branch and not _has_tag(capsule_tags, "branch", branch):
            continue
        if name and spec.get("displayName", None) != name:
            continue
        if tags and not _has_all_tags(capsule_tags, tags):
            continue
        if capsule_id and capsule.get("id", None) != capsule_id:
            continue

        matched.append(capsule)
    return matched
629
+
630
+
631
from collections import namedtuple

# Pairs a capsule's record ("info") with its worker listing ("workers").
CapsuleInfo = namedtuple("CapsuleInfo", ["info", "workers"])
634
+
635
+
636
+ class CapsuleDeployer:
637
+
638
+ status: CapsuleStateMachine
639
+
640
+ identifier = None
641
+
642
+ # TODO: Current default timeout is very large of 5 minutes. Ideally we should have finished the deployed in less than 1 minutes.
643
+ def __init__(
644
+ self,
645
+ app_config: AppConfig,
646
+ base_url: str,
647
+ create_timeout: int = 60 * 5,
648
+ debug_dir: Optional[str] = None,
649
+ success_terminal_state_condition: str = DEPLOYMENT_READY_CONDITIONS.ATLEAST_ONE_RUNNING,
650
+ readiness_wait_time: int = 20,
651
+ logger_fn=None,
652
+ ):
653
+ self._app_config = app_config
654
+ self._capsule_api = CapsuleApi(
655
+ base_url,
656
+ app_config.get_state("perimeter"),
657
+ logger_fn=logger_fn or partial(print, file=sys.stderr),
658
+ )
659
+ self._create_timeout = create_timeout
660
+ self._logger_fn = logger_fn
661
+ self._debug_dir = debug_dir
662
+ self._capsule_deploy_response = None
663
+ self._success_terminal_state_condition = success_terminal_state_condition
664
+ self._readiness_wait_time = readiness_wait_time
665
+
666
+ @property
667
+ def url(self):
668
+ return _format_url_string(
669
+ ({} or self._capsule_deploy_response).get("outOfClusterUrl", None)
670
+ )
671
+
672
+ @property
673
+ def capsule_api(self):
674
+ return self._capsule_api
675
+
676
    @property
    def capsule_type(self):
        """User-facing kind of this capsule, derived from its auth mode:
        browser-only auth -> "App"; API (or browser+API) auth -> "Endpoint".

        Raises:
            TODOException: if the configured auth type is none of the known
                AuthType values.
        """
        auth_type = self._app_config.get_state("auth", {}).get("type", AuthType.default)
        if auth_type == AuthType.BROWSER:
            return "App"
        elif auth_type == AuthType.API or auth_type == AuthType.BROWSER_AND_API:
            return "Endpoint"
        else:
            raise TODOException(f"Unknown auth type: {auth_type}")
685
+
686
+ @property
687
+ def name(self):
688
+ return self._app_config.get_state("name")
689
+
690
+ def create_input(self):
691
+ return experimental.capsule_input_overrides(
692
+ self._app_config, CapsuleInput.from_app_config(self._app_config)
693
+ )
694
+
695
+ @property
696
+ def current_deployment_instance_version(self):
697
+ """
698
+ The backend `create` call returns a version of the object that will be
699
+ """
700
+ if self._capsule_deploy_response is None:
701
+ return None
702
+ return self._capsule_deploy_response.get("version", None)
703
+
704
+ def create(self):
705
+ capsule_response = self._capsule_api.create(self.create_input())
706
+ self.identifier = capsule_response.get("id")
707
+ self._capsule_deploy_response = capsule_response
708
+ return self.identifier
709
+
710
+ def get(self):
711
+ return self._capsule_api.get(self.identifier)
712
+
713
+ def get_workers(self):
714
+ return self._capsule_api.get_workers(self.identifier)
715
+
716
+ def _backend_version_mismatch_check(
717
+ self, capsule_response: dict, current_deployment_instance_version: str
718
+ ):
719
+ """
720
+ - `capsule_response.version` contains the version of the object present in the database
721
+ - `current_deployment_instance_version` contains the version of the object that was deployed by this instance of the deployer.
722
+ In the situation that the versions of the objects become a mismatch then it means that current deployment process is not giving the user the
723
+ output that they desire.
724
+ """
725
+ if capsule_response.get("version", None) != current_deployment_instance_version:
726
+ raise CapsuleDeploymentException(
727
+ self.identifier, # type: ignore
728
+ f"A capsule upgrade was triggered outside current deployment instance. Current deployment version was discarded. Current deployment version: {current_deployment_instance_version} and new version: {capsule_response.get('version', None)}",
729
+ )
730
+
731
    def _update_capsule_and_worker_sm(
        self,
        capsule_sm: "CapsuleStateMachine",
        workers_sm: "CapsuleWorkersStateMachine",
        logger: Callable[[str], None],
    ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
        """Poll the backend once and feed both state machines.

        Fetches the capsule and its workers, records their statuses on the
        respective state machines, and logs the current state of each.
        Returns the raw ``(capsule_response, workers_response)`` pair.

        Raises CapsuleDeploymentException (via the version-mismatch check) if
        another deployment has upgraded the capsule underneath this instance.
        """
        capsule_response = self.get()
        capsule_sm.add_status(capsule_response.get("status", {}))  # type: ignore

        # We need to check if someone has not upgraded the capsule under the hood and
        # the current deployment instance is invalid. Done before polling workers so
        # we fail fast on a stale deployment.
        self._backend_version_mismatch_check(
            capsule_response, self.current_deployment_instance_version  # type: ignore
        )
        workers_response = self.get_workers()
        capsule_sm.report_current_status(logger)
        workers_sm.add_status(workers_response)
        workers_sm.report_current_status(logger)
        return capsule_response, workers_response
750
+
751
+ def _publish_capsule_debug_info(
752
+ self,
753
+ capsule_sm: "CapsuleStateMachine",
754
+ workers_sm: "CapsuleWorkersStateMachine",
755
+ capsule_response: Dict[str, Any],
756
+ ):
757
+ if CAPSULE_DEBUG and self._debug_dir:
758
+ capsule_sm.save_debug_info(self._debug_dir)
759
+ workers_sm.save_debug_info(self._debug_dir)
760
+ debug_path = os.path.join(
761
+ self._debug_dir, f"debug_capsule_{self.identifier}.json"
762
+ )
763
+ with open(debug_path, "w") as f:
764
+ f.write(json.dumps(capsule_response, indent=4))
765
+
766
+ def _monitor_worker_readiness(
767
+ self,
768
+ workers_sm: "CapsuleWorkersStateMachine",
769
+ capsule_sm: "CapsuleStateMachine",
770
+ ):
771
+ """returns True if the worker is crashlooping, False otherwise"""
772
+ logger = self._logger_fn or partial(print, file=sys.stderr)
773
+ for i in range(self._readiness_wait_time):
774
+ time.sleep(1)
775
+ self._update_capsule_and_worker_sm(capsule_sm, workers_sm, logger)
776
+ if workers_sm.is_crashlooping:
777
+ return True
778
+ return False
779
+
780
+ def _extract_logs_from_crashlooping_worker(
781
+ self, workers_sm: "CapsuleWorkersStateMachine"
782
+ ):
783
+ def _extract_worker_id_of_crashlooping_worker(
784
+ workers_status: List[WorkerStatus],
785
+ ):
786
+ for worker in workers_status:
787
+ if worker["phase"] == "CrashLoopBackOff" or worker["phase"] == "Failed":
788
+ return worker["workerId"]
789
+ return None
790
+
791
+ worker_id = _extract_worker_id_of_crashlooping_worker(workers_sm.current_status)
792
+ if worker_id is None:
793
+ return None, None
794
+ logs = self.capsule_api.logs(self.identifier, worker_id, previous=True)
795
+ return logs, worker_id
796
+
797
+ def _get_min_replicas(self):
798
+ replicas = self._app_config.get_state("replicas", {})
799
+ fixed, _min, _ = replicas.get("fixed"), replicas.get("min"), replicas.get("max")
800
+ if fixed is not None:
801
+ return fixed
802
+ return _min
803
+
804
    def wait_for_terminal_state(
        self,
    ):
        """Poll the backend until the deployment reaches a terminal state.

        Drives a capsule state machine and a workers state machine by polling
        once per second for up to `_create_timeout` seconds. Exits the loop on
        a readiness or failure condition; on failure (or a failed last-minute
        readiness check) surfaces crashlooping-worker logs and raises
        CapsuleDeploymentException. Returns a summary dict of the deployment.
        """
        logger = self._logger_fn or partial(print, file=sys.stderr)
        state_machine = CapsuleStateMachine(
            self.identifier, self.current_deployment_instance_version
        )
        # min_replicas will always be present
        # NOTE(review): _get_min_replicas can return None when neither `fixed`
        # nor `min` is configured; `min_replicas > 0` below would then raise
        # TypeError — confirm config always sets one of them.
        min_replicas = self._get_min_replicas()
        workers_state_machine = CapsuleWorkersStateMachine(
            self.identifier,
            self.current_deployment_instance_version,
            deployment_mode=self._success_terminal_state_condition,
            minimum_replicas=min_replicas,
        )
        self.status = state_machine

        # This loop will check all the conditions that help verify the terminal state.
        # How it works is by extracting the statuses of the capsule and workers and
        # then adding them as a part of a state-machine that helps track transitions and
        # helps derive terminal states.
        # We will first keep checking for terminal conditions or outright failure conditions
        # If we reach a terminal condition like described in `DEPLOYMENT_READY_CONDITIONS`, then
        # we will further check for readiness conditions.
        # NOTE(review): if `_create_timeout` is 0 the loop body never runs and
        # `capsule_response` below would be unbound — confirm the timeout is
        # always positive.
        for i in range(self._create_timeout):
            time.sleep(1)
            capsule_response, _ = self._update_capsule_and_worker_sm(
                state_machine, workers_state_machine, logger
            )
            # Deployment readiness checks will determine what is the terminal state
            # of the workerstate machine. If we detect a terminal state in the workers,
            # then even if the capsule upgrade is still in progress we will end up crashing
            # the deployment.
            (
                capsule_ready,
                further_check_worker_readiness,
            ) = DEPLOYMENT_READY_CONDITIONS.check_readiness_condition(
                state_machine.current_status,
                workers_state_machine.current_version_deployment_status(),
                self._success_terminal_state_condition,
            )

            failure_condition_satisfied = (
                DEPLOYMENT_READY_CONDITIONS.check_failure_condition(
                    state_machine.current_status,
                    workers_state_machine.current_version_deployment_status(),
                )
            )
            if capsule_ready or failure_condition_satisfied:
                logger(
                    "💊 %s deployment status: %s "
                    % (
                        self.capsule_type.title(),
                        (
                            "in progress"
                            if state_machine.update_in_progress
                            else "completed"
                        ),
                    )
                )
                _further_readiness_check_failed = False
                if further_check_worker_readiness:
                    # HACK : monitor the workers for N seconds to make sure they are healthy
                    # this is a hack. Ideally we should implement a healthcheck as a first class citizen
                    # but it will take some time to do that so in the meanwhile a timeout set on the cli
                    # side will be really helpful.
                    logger(
                        "💊 Running last minute readiness check for %s..."
                        % self.identifier
                    )
                    _further_readiness_check_failed = self._monitor_worker_readiness(
                        workers_state_machine,
                        state_machine,
                    )

                if CAPSULE_DEBUG:
                    logger(
                        f"[debug] 💊 {self.capsule_type} {self.identifier}: further_check_worker_readiness {_further_readiness_check_failed} | failure_condition_satisfied {failure_condition_satisfied}"
                    )

                # We should still check for failure state and crash if we detect something in the readiness check
                if failure_condition_satisfied or _further_readiness_check_failed:
                    # hit the logs endpoint for the worker and get the logs
                    # Print those logs out on the terminal
                    # raise an exception that should be caught gracefully by the cli
                    logs, worker_id = self._extract_logs_from_crashlooping_worker(
                        workers_state_machine
                    )
                    if logs is not None:
                        # todo: It would be really odd if the logs are not present and we discover something is crashlooping.
                        # Handle that condition later
                        logger(
                            *(
                                [
                                    f"💥 Worker ID ({worker_id}) is crashlooping. Please check the following logs for more information: "
                                ]
                                + ["\t" + l["message"] for l in logs]
                            )
                        )
                    raise CapsuleDeploymentException(
                        self.identifier,
                        f"Worker ID ({worker_id}) is crashlooping. Please check the logs for more information.",
                    )

                if state_machine.ready_to_serve_traffic:
                    logger(
                        "💊 %s %s is ready to serve traffic on the URL: %s"
                        % (
                            self.capsule_type,
                            self.identifier,
                            state_machine.out_of_cluster_url,
                        ),
                    )

                break

            self._publish_capsule_debug_info(
                state_machine, workers_state_machine, capsule_response
            )

            if CAPSULE_DEBUG and i % 3 == 0:  # Every 3 seconds report the status
                logger(
                    f"[debug] 💊 {self.capsule_type} {self.identifier} deployment status: {state_machine.current_status} | worker states: {workers_state_machine.current_status} | capsule_ready : {capsule_ready} | further_check_worker_readiness {further_check_worker_readiness}"
                )

        self._publish_capsule_debug_info(
            state_machine, workers_state_machine, capsule_response
        )

        # We will only check ready_to_serve_traffic under the following conditions:
        # the readiness condition is not ASYNC and the min_replicas configured for
        # this deployment instance is > 0 (with zero replicas there is nothing to
        # serve traffic yet).
        _is_async_readiness = (
            self._success_terminal_state_condition == DEPLOYMENT_READY_CONDITIONS.ASYNC
        )
        if (
            min_replicas > 0
            and not _is_async_readiness
            and not self.status.ready_to_serve_traffic
        ):
            raise CapsuleDeploymentException(
                self.identifier,
                f"Capsule {self.identifier} failed to be ready to serve traffic",
            )

        return dict(
            id=self.identifier,
            auth_type=self.capsule_type,
            public_url=self.url,
            available_replicas=self.status.available_replicas,
            name=self.name,
            deployed_version=self.current_deployment_instance_version,
            deployed_at=datetime.now().isoformat(),
        )