ob-metaflow-extensions 1.1.45rc3__py2.py3-none-any.whl → 1.5.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ob-metaflow-extensions might be problematic. Click here for more details.

Files changed (128)
  1. metaflow_extensions/outerbounds/__init__.py +1 -7
  2. metaflow_extensions/outerbounds/config/__init__.py +35 -0
  3. metaflow_extensions/outerbounds/plugins/__init__.py +186 -57
  4. metaflow_extensions/outerbounds/plugins/apps/__init__.py +0 -0
  5. metaflow_extensions/outerbounds/plugins/apps/app_cli.py +0 -0
  6. metaflow_extensions/outerbounds/plugins/apps/app_utils.py +187 -0
  7. metaflow_extensions/outerbounds/plugins/apps/consts.py +3 -0
  8. metaflow_extensions/outerbounds/plugins/apps/core/__init__.py +15 -0
  9. metaflow_extensions/outerbounds/plugins/apps/core/_state_machine.py +506 -0
  10. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/__init__.py +0 -0
  11. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/__init__.py +4 -0
  12. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/spinners.py +478 -0
  13. metaflow_extensions/outerbounds/plugins/apps/core/app_config.py +128 -0
  14. metaflow_extensions/outerbounds/plugins/apps/core/app_deploy_decorator.py +330 -0
  15. metaflow_extensions/outerbounds/plugins/apps/core/artifacts.py +0 -0
  16. metaflow_extensions/outerbounds/plugins/apps/core/capsule.py +958 -0
  17. metaflow_extensions/outerbounds/plugins/apps/core/click_importer.py +24 -0
  18. metaflow_extensions/outerbounds/plugins/apps/core/code_package/__init__.py +3 -0
  19. metaflow_extensions/outerbounds/plugins/apps/core/code_package/code_packager.py +618 -0
  20. metaflow_extensions/outerbounds/plugins/apps/core/code_package/examples.py +125 -0
  21. metaflow_extensions/outerbounds/plugins/apps/core/config/__init__.py +15 -0
  22. metaflow_extensions/outerbounds/plugins/apps/core/config/cli_generator.py +165 -0
  23. metaflow_extensions/outerbounds/plugins/apps/core/config/config_utils.py +966 -0
  24. metaflow_extensions/outerbounds/plugins/apps/core/config/schema_export.py +299 -0
  25. metaflow_extensions/outerbounds/plugins/apps/core/config/typed_configs.py +233 -0
  26. metaflow_extensions/outerbounds/plugins/apps/core/config/typed_init_generator.py +537 -0
  27. metaflow_extensions/outerbounds/plugins/apps/core/config/unified_config.py +1125 -0
  28. metaflow_extensions/outerbounds/plugins/apps/core/config_schema.yaml +337 -0
  29. metaflow_extensions/outerbounds/plugins/apps/core/dependencies.py +115 -0
  30. metaflow_extensions/outerbounds/plugins/apps/core/deployer.py +959 -0
  31. metaflow_extensions/outerbounds/plugins/apps/core/experimental/__init__.py +89 -0
  32. metaflow_extensions/outerbounds/plugins/apps/core/perimeters.py +87 -0
  33. metaflow_extensions/outerbounds/plugins/apps/core/secrets.py +164 -0
  34. metaflow_extensions/outerbounds/plugins/apps/core/utils.py +233 -0
  35. metaflow_extensions/outerbounds/plugins/apps/core/validations.py +17 -0
  36. metaflow_extensions/outerbounds/plugins/apps/deploy_decorator.py +201 -0
  37. metaflow_extensions/outerbounds/plugins/apps/supervisord_utils.py +243 -0
  38. metaflow_extensions/outerbounds/plugins/auth_server.py +28 -8
  39. metaflow_extensions/outerbounds/plugins/aws/__init__.py +4 -0
  40. metaflow_extensions/outerbounds/plugins/aws/assume_role.py +3 -0
  41. metaflow_extensions/outerbounds/plugins/aws/assume_role_decorator.py +118 -0
  42. metaflow_extensions/outerbounds/plugins/card_utilities/__init__.py +0 -0
  43. metaflow_extensions/outerbounds/plugins/card_utilities/async_cards.py +142 -0
  44. metaflow_extensions/outerbounds/plugins/card_utilities/extra_components.py +545 -0
  45. metaflow_extensions/outerbounds/plugins/card_utilities/injector.py +70 -0
  46. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/__init__.py +2 -0
  47. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/coreweave.py +71 -0
  48. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/external_chckpt.py +85 -0
  49. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/nebius.py +73 -0
  50. metaflow_extensions/outerbounds/plugins/fast_bakery/__init__.py +0 -0
  51. metaflow_extensions/outerbounds/plugins/fast_bakery/baker.py +110 -0
  52. metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py +391 -0
  53. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py +188 -0
  54. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_cli.py +54 -0
  55. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery_decorator.py +50 -0
  56. metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py +79 -0
  57. metaflow_extensions/outerbounds/plugins/kubernetes/pod_killer.py +374 -0
  58. metaflow_extensions/outerbounds/plugins/nim/card.py +140 -0
  59. metaflow_extensions/outerbounds/plugins/nim/nim_decorator.py +101 -0
  60. metaflow_extensions/outerbounds/plugins/nim/nim_manager.py +379 -0
  61. metaflow_extensions/outerbounds/plugins/nim/utils.py +36 -0
  62. metaflow_extensions/outerbounds/plugins/nvcf/__init__.py +0 -0
  63. metaflow_extensions/outerbounds/plugins/nvcf/constants.py +3 -0
  64. metaflow_extensions/outerbounds/plugins/nvcf/exceptions.py +94 -0
  65. metaflow_extensions/outerbounds/plugins/nvcf/heartbeat_store.py +178 -0
  66. metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py +417 -0
  67. metaflow_extensions/outerbounds/plugins/nvcf/nvcf_cli.py +280 -0
  68. metaflow_extensions/outerbounds/plugins/nvcf/nvcf_decorator.py +242 -0
  69. metaflow_extensions/outerbounds/plugins/nvcf/utils.py +6 -0
  70. metaflow_extensions/outerbounds/plugins/nvct/__init__.py +0 -0
  71. metaflow_extensions/outerbounds/plugins/nvct/exceptions.py +71 -0
  72. metaflow_extensions/outerbounds/plugins/nvct/nvct.py +131 -0
  73. metaflow_extensions/outerbounds/plugins/nvct/nvct_cli.py +289 -0
  74. metaflow_extensions/outerbounds/plugins/nvct/nvct_decorator.py +286 -0
  75. metaflow_extensions/outerbounds/plugins/nvct/nvct_runner.py +218 -0
  76. metaflow_extensions/outerbounds/plugins/nvct/utils.py +29 -0
  77. metaflow_extensions/outerbounds/plugins/ollama/__init__.py +225 -0
  78. metaflow_extensions/outerbounds/plugins/ollama/constants.py +1 -0
  79. metaflow_extensions/outerbounds/plugins/ollama/exceptions.py +22 -0
  80. metaflow_extensions/outerbounds/plugins/ollama/ollama.py +1924 -0
  81. metaflow_extensions/outerbounds/plugins/ollama/status_card.py +292 -0
  82. metaflow_extensions/outerbounds/plugins/optuna/__init__.py +48 -0
  83. metaflow_extensions/outerbounds/plugins/perimeters.py +19 -5
  84. metaflow_extensions/outerbounds/plugins/profilers/deco_injector.py +70 -0
  85. metaflow_extensions/outerbounds/plugins/profilers/gpu_profile_decorator.py +88 -0
  86. metaflow_extensions/outerbounds/plugins/profilers/simple_card_decorator.py +96 -0
  87. metaflow_extensions/outerbounds/plugins/s3_proxy/__init__.py +7 -0
  88. metaflow_extensions/outerbounds/plugins/s3_proxy/binary_caller.py +132 -0
  89. metaflow_extensions/outerbounds/plugins/s3_proxy/constants.py +11 -0
  90. metaflow_extensions/outerbounds/plugins/s3_proxy/exceptions.py +13 -0
  91. metaflow_extensions/outerbounds/plugins/s3_proxy/proxy_bootstrap.py +59 -0
  92. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_api.py +93 -0
  93. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_decorator.py +250 -0
  94. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_manager.py +225 -0
  95. metaflow_extensions/outerbounds/plugins/secrets/__init__.py +0 -0
  96. metaflow_extensions/outerbounds/plugins/secrets/secrets.py +204 -0
  97. metaflow_extensions/outerbounds/plugins/snowflake/__init__.py +3 -0
  98. metaflow_extensions/outerbounds/plugins/snowflake/snowflake.py +378 -0
  99. metaflow_extensions/outerbounds/plugins/snowpark/__init__.py +0 -0
  100. metaflow_extensions/outerbounds/plugins/snowpark/snowpark.py +309 -0
  101. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_cli.py +277 -0
  102. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py +150 -0
  103. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_decorator.py +273 -0
  104. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_exceptions.py +13 -0
  105. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py +241 -0
  106. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_service_spec.py +259 -0
  107. metaflow_extensions/outerbounds/plugins/tensorboard/__init__.py +50 -0
  108. metaflow_extensions/outerbounds/plugins/torchtune/__init__.py +163 -0
  109. metaflow_extensions/outerbounds/plugins/vllm/__init__.py +255 -0
  110. metaflow_extensions/outerbounds/plugins/vllm/constants.py +1 -0
  111. metaflow_extensions/outerbounds/plugins/vllm/exceptions.py +1 -0
  112. metaflow_extensions/outerbounds/plugins/vllm/status_card.py +352 -0
  113. metaflow_extensions/outerbounds/plugins/vllm/vllm_manager.py +621 -0
  114. metaflow_extensions/outerbounds/profilers/gpu.py +131 -47
  115. metaflow_extensions/outerbounds/remote_config.py +53 -16
  116. metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py +138 -2
  117. metaflow_extensions/outerbounds/toplevel/ob_internal.py +4 -0
  118. metaflow_extensions/outerbounds/toplevel/plugins/ollama/__init__.py +1 -0
  119. metaflow_extensions/outerbounds/toplevel/plugins/optuna/__init__.py +1 -0
  120. metaflow_extensions/outerbounds/toplevel/plugins/snowflake/__init__.py +1 -0
  121. metaflow_extensions/outerbounds/toplevel/plugins/torchtune/__init__.py +1 -0
  122. metaflow_extensions/outerbounds/toplevel/plugins/vllm/__init__.py +1 -0
  123. metaflow_extensions/outerbounds/toplevel/s3_proxy.py +88 -0
  124. {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/METADATA +2 -2
  125. ob_metaflow_extensions-1.5.1.dist-info/RECORD +133 -0
  126. ob_metaflow_extensions-1.1.45rc3.dist-info/RECORD +0 -19
  127. {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/WHEEL +0 -0
  128. {ob_metaflow_extensions-1.1.45rc3.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,178 @@
1
+ import os
2
+ import sys
3
+ import time
4
+ import subprocess
5
+ from io import BytesIO
6
+ from datetime import datetime, timezone
7
+
8
+ from metaflow.exception import MetaflowException
9
+
10
+
11
def kill_process_and_descendants(pid, termination_timeout=1, iterations=20, delay=0.5):
    """Terminate ``pid`` and its direct children, escalating SIGTERM -> SIGKILL.

    Each iteration sends SIGTERM to the children of ``pid`` (via ``pkill -P``)
    and to ``pid`` itself, waits ``termination_timeout`` seconds, then does the
    same with SIGKILL. Iterations are ``delay`` seconds apart, up to
    ``iterations`` times, and stop early once ``pid`` no longer exists.

    Every signal is sent independently. ``pkill`` exits non-zero when no child
    matches, and that must not prevent ``pid`` itself from being signalled.

    Args:
        pid: process id to terminate; ``None`` is a no-op.
        termination_timeout: seconds between the TERM and KILL phases.
        iterations: maximum number of TERM/KILL rounds.
        delay: seconds to wait between rounds.
    """
    import signal

    if pid is None:
        return
    pid = int(pid)
    pid_str = str(pid)

    def _signal_children(sig_name):
        # Best effort: pkill returns 1 when there are no children and may be
        # missing on minimal images; neither case is an error here.
        try:
            subprocess.call(
                ["pkill", "-" + sig_name, "-P", pid_str],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
        except OSError:
            pass

    def _signal_main(sig):
        # Return False once the process is gone, so the caller can stop.
        try:
            os.kill(pid, sig)
        except ProcessLookupError:
            return False
        except PermissionError:
            # The process exists, but we may not signal it; keep trying.
            return True
        return True

    for i in range(iterations):
        _signal_children("TERM")
        if not _signal_main(signal.SIGTERM):
            return

        time.sleep(termination_timeout)

        _signal_children("KILL")
        if not _signal_main(signal.SIGKILL):
            return

        # Don't delay after the last iteration
        if i < iterations - 1:
            time.sleep(delay)
27
+
28
+
29
class HeartbeatStore(object):
    """Write and watch liveness markers (heartbeat / tombstone) in a datastore.

    The emitting side (``emit_heartbeat`` / ``emit_tombstone``) writes objects
    through ``storage_backend.save_bytes``. The monitoring side (``monitor``)
    reads them back with ``storage_backend.load_bytes`` while ``main_pid`` is
    alive. It kills that process when a tombstone appears or when too many
    consecutive heartbeats are missing or stale.

    Args:
        main_pid: PID of the process to watch and kill (used by ``monitor``).
        storage_backend: storage object providing ``save_bytes``/``load_bytes``.
        emit_frequency: seconds between heartbeat writes.
        missed_heartbeat_timeout: age in seconds after which a heartbeat is stale.
        monitor_frequency: seconds between monitor checks.
        max_missed_heartbeats: the process is killed once MORE than this many
            consecutive checks saw a missing, stale or unreadable heartbeat.
    """

    def __init__(
        self,
        main_pid=None,
        storage_backend=None,
        emit_frequency=30,
        missed_heartbeat_timeout=60,
        monitor_frequency=15,
        max_missed_heartbeats=3,
    ) -> None:
        self.main_pid = main_pid
        self.storage_backend = storage_backend
        self.emit_frequency = emit_frequency
        self.monitor_frequency = monitor_frequency
        self.missed_heartbeat_timeout = missed_heartbeat_timeout
        self.max_missed_heartbeats = max_missed_heartbeats
        self.missed_heartbeats = 0

    @staticmethod
    def _object_key(prefix, leaf, folder_name=None):
        """Return the datastore key ``[<folder_name>/]<prefix>/<leaf>``."""
        key = f"{prefix}/{leaf}"
        return f"{folder_name}/{key}" if folder_name else key

    def emit_heartbeat(self, heartbeat_prefix: str, folder_name=None):
        """Write the current UTC epoch timestamp as the heartbeat object, forever.

        Writes happen every ``emit_frequency`` seconds. On a write error this
        prints the error and calls ``sys.exit(1)``. When run in a non-main
        thread, that ends only the thread.
        """
        heartbeat_key = self._object_key(heartbeat_prefix, "heartbeat", folder_name)

        while True:
            try:
                payload = str(datetime.now(timezone.utc).timestamp()).encode("utf-8")
                self.storage_backend.save_bytes(
                    [(heartbeat_key, BytesIO(payload))], overwrite=True
                )
            except Exception as e:
                print(f"Error writing heartbeat: {e}")
                sys.exit(1)

            time.sleep(self.emit_frequency)

    def emit_tombstone(self, tombstone_prefix: str, folder_name=None):
        """Write a tombstone object, which tells monitors to kill their task."""
        tombstone_key = self._object_key(tombstone_prefix, "tombstone", folder_name)

        try:
            self.storage_backend.save_bytes(
                [(tombstone_key, BytesIO("tombstone".encode("utf-8")))],
                overwrite=True,
            )
        except Exception as e:
            print(f"Error writing tombstone: {e}")
            sys.exit(1)

    def __handle_tombstone(self, path):
        """Kill the main process and exit if the file at ``path`` is a tombstone."""
        if path is not None:
            with open(path) as f:
                contents = f.read()
            if "tombstone" in contents:
                print("[Outerbounds] Tombstone detected. Terminating the task..")
                kill_process_and_descendants(self.main_pid)
                sys.exit(1)

    def __handle_heartbeat(self, path):
        """Update the consecutive-miss counter from the heartbeat file at ``path``.

        A missing object (``path is None``), an unparsable timestamp, or a
        timestamp older than ``missed_heartbeat_timeout`` counts as a miss. A
        fresh heartbeat resets the counter. Once the counter exceeds
        ``max_missed_heartbeats``, the main process is killed and this exits.
        """
        if path is None:
            self.missed_heartbeats += 1
        else:
            with open(path) as f:
                contents = f.read()
            try:
                last_beat = float(contents)
            except ValueError:
                # e.g. an empty object: count it as a miss instead of
                # crashing the monitor loop.
                last_beat = None
            now = datetime.now(timezone.utc).timestamp()
            if last_beat is None or now - last_beat > self.missed_heartbeat_timeout:
                self.missed_heartbeats += 1
            else:
                self.missed_heartbeats = 0

        if self.missed_heartbeats > self.max_missed_heartbeats:
            print(
                f"[Outerbounds] Missed {self.missed_heartbeats} consecutive heartbeats. Terminating the task.."
            )
            kill_process_and_descendants(self.main_pid)
            sys.exit(1)

    def is_main_process_running(self):
        """Return True while ``main_pid`` exists (signal 0 probe)."""
        try:
            os.kill(self.main_pid, 0)
        except ProcessLookupError:
            return False
        except PermissionError:
            # EPERM: the process exists but belongs to someone else.
            return True
        return True

    def monitor(self, heartbeat_prefix: str, tombstone_prefix: str, folder_name=None):
        """Poll the heartbeat and tombstone objects while the main process runs.

        Each round loads both keys, handles them, then sleeps
        ``monitor_frequency`` seconds.
        """
        heartbeat_key = self._object_key(heartbeat_prefix, "heartbeat", folder_name)
        tombstone_key = self._object_key(tombstone_prefix, "tombstone", folder_name)

        while self.is_main_process_running():
            with self.storage_backend.load_bytes(
                [heartbeat_key, tombstone_key]
            ) as results:
                for key, path, _ in results:
                    if key == tombstone_key:
                        self.__handle_tombstone(path)
                    elif key == heartbeat_key:
                        self.__handle_heartbeat(path)

            time.sleep(self.monitor_frequency)
136
+
137
+
138
if __name__ == "__main__":
    # CLI entry point: `heartbeat_store <main_pid> <datastore_type> <folder_name>`.
    # It watches <main_pid> and kills it when a tombstone is written for the run
    # or when too many heartbeats for this task attempt are missed.
    from metaflow.plugins import DATASTORES
    from metaflow.metaflow_config import NVIDIA_HEARTBEAT_THRESHOLD

    if len(sys.argv) != 4:
        print("Usage: heartbeat_store.py <main_pid> <datastore_type> <folder_name>")
        sys.exit(1)
    _, main_pid, datastore_type, folder_name = sys.argv

    if datastore_type not in ("azure", "gs", "s3"):
        print(f"Datastore unsupported for type: {datastore_type}")
        sys.exit(1)

    datastores = [d for d in DATASTORES if d.TYPE == datastore_type]
    if not datastores:
        # Avoid an opaque IndexError when no plugin of this TYPE is registered.
        print(f"No datastore implementation registered for type: {datastore_type}")
        sys.exit(1)
    datastore_impl = datastores[0]

    datastore_sysroot = datastore_impl.get_datastore_root_from_config(
        lambda *args, **kwargs: None
    )
    if datastore_sysroot is None:
        raise MetaflowException(
            msg="METAFLOW_DATASTORE_SYSROOT_{datastore_type} must be set!".format(
                datastore_type=datastore_type.upper()
            )
        )

    storage = datastore_impl(datastore_sysroot)

    # MF_PATHSPEC is expected as <flow_name>/<run_id>/<step_name>/<task_id>.
    pathspec = os.getenv("MF_PATHSPEC")
    pathspec_parts = pathspec.split("/") if pathspec else []
    if len(pathspec_parts) != 4:
        raise MetaflowException(
            msg="MF_PATHSPEC must be set to <flow_name>/<run_id>/<step_name>/<task_id>!"
        )
    flow_name, run_id, _, _ = pathspec_parts

    # Heartbeats are per task attempt; tombstones are per run.
    heartbeat_prefix = f"{pathspec}/{os.getenv('MF_ATTEMPT')}"
    tombstone_prefix = f"{flow_name}/{run_id}"

    store = HeartbeatStore(
        main_pid=int(main_pid),
        storage_backend=storage,
        max_missed_heartbeats=int(NVIDIA_HEARTBEAT_THRESHOLD),
    )

    store.monitor(
        heartbeat_prefix=heartbeat_prefix,
        tombstone_prefix=tombstone_prefix,
        folder_name=folder_name,
    )
@@ -0,0 +1,417 @@
1
+ import json
2
+ import os
3
+ import time
4
+ import threading
5
+ from urllib.request import HTTPError, Request, URLError, urlopen
6
+ from functools import wraps
7
+
8
+ from metaflow import util
9
+ from metaflow.mflog import (
10
+ BASH_SAVE_LOGS,
11
+ bash_capture_logs,
12
+ export_mflog_env_vars,
13
+ tail_logs,
14
+ get_log_tailer,
15
+ )
16
+ from .exceptions import NvcfJobFailedException, NvcfPollingConnectionError
17
+
18
# Structured (mflog) logs are written under $PWD/.logs/ inside the job; the
# shell expands $PWD at runtime, so these paths stay symbolic here.
LOGS_DIR = "$PWD/.logs"
STDOUT_FILE = "mflog_stdout"
STDERR_FILE = "mflog_stderr"
STDOUT_PATH, STDERR_PATH = (
    os.path.join(LOGS_DIR, log_file) for log_file in (STDOUT_FILE, STDERR_FILE)
)
24
+
25
+
26
def retry_on_status(status_codes=(500,), max_retries=3, delay=1):
    """Retry a Job method when it raises ``HTTPError`` with one of ``status_codes``.

    The wrapped function is called as ``func(instance, *args, **kwargs)``.

    When 504 (queue / long-poll timeout) is in ``status_codes``, the number of
    retries is ``ceil(instance._queue_timeout / instance._poll_seconds)``.
    Before the last retry, ``instance._poll_seconds`` is shortened to the
    remaining time. In that mode 504 retries don't sleep, because the request
    itself blocks. Otherwise ``max_retries`` retries are made, sleeping
    ``delay`` seconds between them.

    After the retries, one final attempt is made. If it also fails with a
    handled code, ``instance._status`` is set to ``JobStatus.FAILED``. An
    exhausted 504 is raised as ``NvcfPollingConnectionError``.

    HTTP errors with codes this layer does not handle are re-raised untouched,
    and ``instance._status`` is NOT modified. Decorators are stacked (see
    ``Job._poll``), and an outer layer may still retry the error. Marking the
    job FAILED here would leave a successfully retried job looking failed.
    A ``URLError`` marks the job FAILED and is re-raised.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(instance, *args, **kwargs):
            # Determine retry limit upfront
            use_queue_timeout = 504 in status_codes
            if use_queue_timeout:
                poll_seconds = int(instance._poll_seconds)
                retry_limit = (
                    instance._queue_timeout + (poll_seconds - 1)
                ) // poll_seconds
                remainder = instance._queue_timeout % poll_seconds
                last_timeout = remainder if remainder != 0 else poll_seconds
            else:
                retry_limit = max_retries

            for retries in range(retry_limit):
                try:
                    return func(instance, *args, **kwargs)
                except HTTPError as e:
                    if e.code not in status_codes:
                        # Not ours: let the caller or an outer retry layer decide.
                        raise

                    if e.code == 504 and retries == retry_limit - 1:
                        # Last retry only waits for what is left of the queue timeout.
                        instance._poll_seconds = str(last_timeout)

                    print(
                        f"[@nvidia] {'Queue timeout' if e.code == 504 else f'Received {e.code}'}, "
                        f"retrying ({retries + 1}/{retry_limit})... with poll seconds as {instance._poll_seconds}"
                    )

                    if e.code != 504:
                        time.sleep(delay)
                except URLError:
                    instance._status = JobStatus.FAILED
                    raise

            # final attempt: retries are exhausted, so a handled error is fatal
            try:
                return func(instance, *args, **kwargs)
            except HTTPError as e:
                if e.code in status_codes:
                    instance._status = JobStatus.FAILED
                    if e.code == 504:
                        raise NvcfPollingConnectionError(
                            "Request timed out after all retries"
                        ) from e
                raise
            except URLError:
                instance._status = JobStatus.FAILED
                raise

        return wrapper

    return decorator
77
+
78
+
79
class Nvcf(object):
    """Run a Metaflow task on an NVIDIA Cloud Functions (NVCF) function.

    ``launch_job`` builds a bash entry point for the task and submits it as a
    ``Job``. ``wait`` tails the task logs while the job is running and raises
    ``NvcfJobFailedException`` if it failed.
    """

    def __init__(
        self, metadata, datastore, environment, function_id, ngc_api_key, queue_timeout
    ):
        self.metadata = metadata
        self.datastore = datastore
        self.environment = environment
        self._function_id = function_id
        self._ngc_api_key = ngc_api_key
        self._queue_timeout = queue_timeout

    def launch_job(
        self,
        step_name,
        step_cli,
        task_spec,
        code_package_sha,
        code_package_url,
        code_package_ds,
        env=None,
    ):
        """Build the task's bash entry point and submit it to NVCF.

        The task command is started in the background, and the
        ``heartbeat_store`` monitor is started next to it with the task's PID.
        The monitor is stopped after the task exits, and the entry point exits
        with the task's exit code.

        Args:
            step_name: name of the step being executed.
            step_cli: shell command that runs the step.
            task_spec: dict passed to ``export_mflog_env_vars`` and ``Job``.
            code_package_sha: not used by this method.
            code_package_url: URL of the code package to bootstrap.
            code_package_ds: datastore type (also passed to the heartbeat monitor).
            env: extra environment variables for the job (``None`` -> none).
        """
        # Avoid a shared mutable default argument.
        if env is None:
            env = {}

        mflog_expr = export_mflog_env_vars(
            datastore_type=code_package_ds,
            stdout_path=STDOUT_PATH,
            stderr_path=STDERR_PATH,
            **task_spec,
        )
        init_cmds = self.environment.get_package_commands(
            code_package_url, code_package_ds
        )
        init_expr = " && ".join(init_cmds)
        heartbeat_expr = f'python -m metaflow_extensions.outerbounds.plugins.nvcf.heartbeat_store "$MAIN_PID" {code_package_ds} nvcf_heartbeats & HEARTBEAT_PID=$!;'
        step_expr = bash_capture_logs(
            " && ".join(
                self.environment.bootstrap_commands(step_name, code_package_ds)
                + [step_cli + " & MAIN_PID=$!; " + heartbeat_expr + " wait $MAIN_PID"]
            )
        )

        # construct an entry point that
        # 1) initializes the mflog environment (mflog_expr)
        # 2) bootstraps a metaflow environment (init_expr)
        # 3) executes a task (step_expr)

        cmd_str = "mkdir -p %s && %s && %s && %s; " % (
            LOGS_DIR,
            mflog_expr,
            init_expr,
            step_expr,
        )
        # after the task has finished, we save its exit code (fail/success),
        # stop the heartbeat monitor and persist the final logs. The whole
        # entrypoint should exit with the exit code (c) of the task.
        #
        # Note that if step_expr OOMs, this tail expression is never executed.
        # We lose the last logs in this scenario.
        cmd_str += (
            "c=$?; kill $HEARTBEAT_PID; wait $HEARTBEAT_PID; %s; exit $c"
            % BASH_SAVE_LOGS
        )
        cmd_str = (
            '${METAFLOW_INIT_SCRIPT:+eval \\"${METAFLOW_INIT_SCRIPT}\\"} && %s'
            % cmd_str
        )
        self.job = Job(
            'bash -c "%s"' % cmd_str,
            env,
            task_spec,
            self.datastore._storage_impl,
            self._function_id,
            self._ngc_api_key,
            self._queue_timeout,
        )
        self.job.submit()

    def wait(self, stdout_location, stderr_location, echo=None):
        """Tail the job's logs until it stops running, then report its result.

        Args:
            stdout_location: datastore location of the task's stdout log.
            stderr_location: datastore location of the task's stderr log.
            echo: callable ``echo(msg, stream, _id=...)``. It is called
                unconditionally, so despite the ``None`` default it is required.

        Raises:
            NvcfJobFailedException: if the job ended in the FAILED state.
        """

        def wait_for_launch(job):
            # Reports the current status once (this does not loop).
            status = job._status
            echo(
                "Task status: %s..." % status,
                "stderr",
                _id=job.id,
            )

        prefix = b"[%s] " % util.to_bytes(self.job.id)
        stdout_tail = get_log_tailer(stdout_location, self.datastore.TYPE)
        stderr_tail = get_log_tailer(stderr_location, self.datastore.TYPE)

        # 1) Report the job status after submission
        wait_for_launch(self.job)

        # 2) Tail logs until the job has finished
        tail_logs(
            prefix=prefix,
            stdout_tail=stdout_tail,
            stderr_tail=stderr_tail,
            echo=echo,
            has_log_updates=lambda: self.job.is_running,
        )

        echo(
            "Task finished with exit code %s." % self.job.result.get("exit_code"),
            "stderr",
            _id=self.job.id,
        )
        if self.job.has_failed:
            raise NvcfJobFailedException(
                "This could be a transient error. Use @retry to retry."
            )
188
+
189
+
190
class JobStatus(object):
    """String constants for the lifecycle states of an NVCF ``Job``."""

    # The Job object exists but has not been submitted yet.
    CREATED = "CREATED"
    # Submitted to NVCF (accepted, still running).
    SUBMITTED = "SUBMITTED"
    # Polled successfully at least once while still running.
    POLLED = "POLLED"
    # Finished with exit code 0.
    SUCCESSFUL = "SUCCESSFUL"
    # Finished unsuccessfully, or could not be tracked.
    FAILED = "FAILED"
    # Vanished from NVCF after being tracked; treated as likely successful.
    DISAPPEARED = "DISAPPEARED"
197
+
198
+
199
# Once a job reaches one of these states, Job._long_poll_loop stops polling.
terminal_states = [
    JobStatus.SUCCESSFUL,
    JobStatus.FAILED,
    JobStatus.DISAPPEARED,
]

# NVCF REST API: base URL plus the "pexec" submit and status endpoints.
nvcf_url = "https://api.nvcf.nvidia.com"
submit_endpoint = "/".join((nvcf_url, "v2/nvcf/pexec/functions"))
result_endpoint = "/".join((nvcf_url, "v2/nvcf/pexec/status"))
204
+
205
+
206
class Job(object):
    """One NVCF invocation of a task command, with its status tracking.

    Construction starts a daemon thread that writes heartbeats for the task
    (under the ``nvcf_heartbeats`` folder) through ``backend``. ``submit`` posts
    the payload. If NVCF answers 202, a daemon thread long-polls the status
    endpoint until the job reaches a state in ``terminal_states``.

    Args:
        command: shell command to run on NVCF.
        env: environment variables; entries whose value is ``None`` are dropped.
        task_spec: dict with flow_name, run_id, step_name, task_id, retry_count.
        backend: storage implementation used for heartbeats.
        function_id: NVCF function id to invoke.
        ngc_api_key: bearer token for the NVCF API.
        queue_timeout: seconds to keep retrying 504 (queue timeout) responses.
    """

    def __init__(
        self, command, env, task_spec, backend, function_id, ngc_api_key, queue_timeout
    ):
        self._payload = {
            "command": command,
            "env": {k: v for k, v in env.items() if v is not None},
        }
        self._result = {}
        self._function_id = function_id
        self._ngc_api_key = ngc_api_key
        self._queue_timeout = queue_timeout
        # Sent as NVCF-POLL-SECONDS; retry_on_status may shorten it on the last retry.
        self._poll_seconds = "3600"
        # Filled from the NVCF-REQID response header by submit(); None until then,
        # so `id` doesn't raise AttributeError if submission fails.
        self._invocation_id = None

        # Initialize status and tracking variables
        self._status = JobStatus.CREATED
        self._last_poll_time = time.time()

        # State tracking for long polling
        self._long_polling_active = False
        self._poll_response = None

        flow_name = task_spec.get("flow_name")
        run_id = task_spec.get("run_id")
        step_name = task_spec.get("step_name")
        task_id = task_spec.get("task_id")
        retry_count = task_spec.get("retry_count")

        heartbeat_prefix = "/".join(
            (flow_name, str(run_id), step_name, str(task_id), str(retry_count))
        )

        ## import is done here to avoid the following warning:
        # RuntimeWarning: 'metaflow_extensions.outerbounds.plugins.nvcf.heartbeat_store' found in sys.modules
        # after import of package 'metaflow_extensions.outerbounds.plugins.nvcf', but prior to execution of
        # 'metaflow_extensions.outerbounds.plugins.nvcf.heartbeat_store'; this may result in unpredictable behaviour
        from metaflow_extensions.outerbounds.plugins.nvcf.heartbeat_store import (
            HeartbeatStore,
        )

        store = HeartbeatStore(
            main_pid=None,
            storage_backend=backend,
        )

        self.heartbeat_thread = threading.Thread(
            target=store.emit_heartbeat,
            args=(
                heartbeat_prefix,
                "nvcf_heartbeats",
            ),
            daemon=True,
        )
        self.heartbeat_thread.start()

    def _headers(self):
        """Return the request headers shared by submit and status calls."""
        return {
            "Authorization": f"Bearer {self._ngc_api_key}",
            "Content-Type": "application/json",
            "nvcf-feature-enable-gateway-timeout": "true",
            "NVCF-POLL-SECONDS": self._poll_seconds,
        }

    def _record_result(self, data):
        """Store a final result payload; exit_code 0 -> SUCCESSFUL, else FAILED."""
        if data.get("exit_code") == 0:
            self._status = JobStatus.SUCCESSFUL
        else:
            self._status = JobStatus.FAILED
        self._result = data

    @retry_on_status(status_codes=[504])
    def submit(self):
        """Submit the payload to NVCF.

        A 200 response is the final result. A 202 response means the job is
        running, and background long polling starts. Any other code marks the
        job FAILED.
        """
        try:
            request = Request(
                f"{submit_endpoint}/{self._function_id}",
                data=json.dumps(self._payload).encode(),
                headers=self._headers(),
            )
            # Close the HTTP response once it has been read.
            with urlopen(request) as response:
                self._invocation_id = response.headers.get("NVCF-REQID")
                status_code = response.getcode()
                body = response.read() if status_code == 200 else None

            if status_code == 200:
                self._record_result(json.loads(body))
            elif status_code == 202:
                self._status = JobStatus.SUBMITTED
                # Start long polling immediately after receiving 202
                self._start_long_polling()
            else:
                self._status = JobStatus.FAILED
        except URLError:
            self._status = JobStatus.FAILED
            raise

    def _start_long_polling(self):
        """Start the background polling thread, unless it is already active."""
        if not self._long_polling_active:
            self._long_polling_active = True
            polling_thread = threading.Thread(target=self._long_poll_loop, daemon=True)
            polling_thread.start()

    def _long_poll_loop(self):
        """Poll until the job reaches a terminal state or polling is deactivated."""
        while self._long_polling_active and self._status not in terminal_states:
            try:
                self._poll()
                # No sleep needed - the request itself will block for up to self._poll_seconds
            except Exception as e:
                print(f"[@nvidia] Long polling error: {e}")
                # Brief pause before retry on error
                time.sleep(1)

        self._long_polling_active = False

    @property
    def id(self):
        """NVCF request id (``NVCF-REQID``), or None before a successful submit."""
        return self._invocation_id

    @property
    def is_running(self):
        # Job is running if it's in SUBMITTED or POLLED state
        return self._status in [JobStatus.SUBMITTED, JobStatus.POLLED]

    @property
    def has_failed(self):
        return self._status == JobStatus.FAILED

    @property
    def result(self):
        return self._result

    @retry_on_status(status_codes=[500], max_retries=3, delay=5)
    @retry_on_status(status_codes=[504])
    def _poll(self):
        """Query the NVCF status endpoint once and update the job state.

        Raises:
            HTTPError: for 500/504, so the retry decorators can handle it.
            NvcfPollingConnectionError: for other HTTP or connection errors.
        """
        try:
            # Rate limit: at most one status request per second.
            elapsed = time.time() - getattr(self, "_last_poll_time", 0.0)
            if elapsed < 1:
                time.sleep(1 - elapsed)

            headers = self._headers()
            request = Request(
                f"{result_endpoint}/{self._invocation_id}", headers=headers
            )

            # Record time before making the request
            self._last_poll_time = time.time()

            with urlopen(request) as response:
                status_code = response.getcode()
                body = response.read()
                redirect_location = response.headers.get("Location")
            print(f"[@nvidia] polling status code: {status_code}")

            if status_code == 200:
                self._record_result(json.loads(body))
                self._long_polling_active = False  # Stop polling once job completes
            elif status_code == 202:
                # Job is still running - status remains SUBMITTED or POLLED
                if self._status == JobStatus.SUBMITTED:
                    self._status = JobStatus.POLLED
            elif status_code == 302:
                # Handle redirects for large responses or requests in different regions.
                # NOTE(review): urlopen's default opener normally follows 302
                # itself, so this branch may rarely be reached — confirm.
                if redirect_location:
                    redirect_request = Request(redirect_location, headers=headers)
                    with urlopen(redirect_request) as redirect_response:
                        if redirect_response.getcode() == 200:
                            data = json.loads(redirect_response.read())
                            self._record_result(data)
                            self._long_polling_active = False
            else:
                print(
                    f"[@nvidia] Unexpected response code: {status_code}. Please notify an Outerbounds support engineer if this error persists."
                )
                self._status = JobStatus.FAILED

        except HTTPError as e:
            if e.code == 404:
                # 404 interpretation depends on job lifecycle
                if self._status in [JobStatus.POLLED, JobStatus.SUBMITTED]:
                    # We've submitted or successfully polled this job before,
                    # so a 404 likely means it completed and was removed
                    self._status = JobStatus.DISAPPEARED
                    self._result = {"exit_code": 0}
                    print(
                        "[@nvidia] 404 received for job that was previously tracked - assuming job completed"
                    )
                else:
                    # Job was never successfully tracked
                    print(
                        "[@nvidia] 404 received for job that was never successfully tracked - treating as failure"
                    )
                    self._status = JobStatus.FAILED
                    raise NvcfPollingConnectionError(e)
            elif e.code in [500, 504]:
                # Don't set status to FAILED, just re-raise for retry decorator
                raise
            else:
                self._status = JobStatus.FAILED
                raise NvcfPollingConnectionError(e)
        except URLError as e:
            self._status = JobStatus.FAILED
            raise NvcfPollingConnectionError(e)