ob-metaflow-extensions 1.1.130__py2.py3-none-any.whl → 1.5.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ob-metaflow-extensions might be problematic. Click here for more details.

Files changed (105)
  1. metaflow_extensions/outerbounds/__init__.py +1 -1
  2. metaflow_extensions/outerbounds/plugins/__init__.py +34 -4
  3. metaflow_extensions/outerbounds/plugins/apps/__init__.py +0 -0
  4. metaflow_extensions/outerbounds/plugins/apps/app_cli.py +0 -0
  5. metaflow_extensions/outerbounds/plugins/apps/app_utils.py +187 -0
  6. metaflow_extensions/outerbounds/plugins/apps/consts.py +3 -0
  7. metaflow_extensions/outerbounds/plugins/apps/core/__init__.py +15 -0
  8. metaflow_extensions/outerbounds/plugins/apps/core/_state_machine.py +506 -0
  9. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/__init__.py +0 -0
  10. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/__init__.py +4 -0
  11. metaflow_extensions/outerbounds/plugins/apps/core/_vendor/spinner/spinners.py +478 -0
  12. metaflow_extensions/outerbounds/plugins/apps/core/app_config.py +128 -0
  13. metaflow_extensions/outerbounds/plugins/apps/core/app_deploy_decorator.py +330 -0
  14. metaflow_extensions/outerbounds/plugins/apps/core/artifacts.py +0 -0
  15. metaflow_extensions/outerbounds/plugins/apps/core/capsule.py +958 -0
  16. metaflow_extensions/outerbounds/plugins/apps/core/click_importer.py +24 -0
  17. metaflow_extensions/outerbounds/plugins/apps/core/code_package/__init__.py +3 -0
  18. metaflow_extensions/outerbounds/plugins/apps/core/code_package/code_packager.py +618 -0
  19. metaflow_extensions/outerbounds/plugins/apps/core/code_package/examples.py +125 -0
  20. metaflow_extensions/outerbounds/plugins/apps/core/config/__init__.py +15 -0
  21. metaflow_extensions/outerbounds/plugins/apps/core/config/cli_generator.py +165 -0
  22. metaflow_extensions/outerbounds/plugins/apps/core/config/config_utils.py +966 -0
  23. metaflow_extensions/outerbounds/plugins/apps/core/config/schema_export.py +299 -0
  24. metaflow_extensions/outerbounds/plugins/apps/core/config/typed_configs.py +233 -0
  25. metaflow_extensions/outerbounds/plugins/apps/core/config/typed_init_generator.py +537 -0
  26. metaflow_extensions/outerbounds/plugins/apps/core/config/unified_config.py +1125 -0
  27. metaflow_extensions/outerbounds/plugins/apps/core/config_schema.yaml +337 -0
  28. metaflow_extensions/outerbounds/plugins/apps/core/dependencies.py +115 -0
  29. metaflow_extensions/outerbounds/plugins/apps/core/deployer.py +959 -0
  30. metaflow_extensions/outerbounds/plugins/apps/core/experimental/__init__.py +89 -0
  31. metaflow_extensions/outerbounds/plugins/apps/core/perimeters.py +87 -0
  32. metaflow_extensions/outerbounds/plugins/apps/core/secrets.py +164 -0
  33. metaflow_extensions/outerbounds/plugins/apps/core/utils.py +233 -0
  34. metaflow_extensions/outerbounds/plugins/apps/core/validations.py +17 -0
  35. metaflow_extensions/outerbounds/plugins/apps/deploy_decorator.py +201 -0
  36. metaflow_extensions/outerbounds/plugins/apps/supervisord_utils.py +243 -0
  37. metaflow_extensions/outerbounds/plugins/aws/__init__.py +4 -0
  38. metaflow_extensions/outerbounds/plugins/aws/assume_role.py +3 -0
  39. metaflow_extensions/outerbounds/plugins/aws/assume_role_decorator.py +118 -0
  40. metaflow_extensions/outerbounds/plugins/card_utilities/injector.py +1 -1
  41. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/__init__.py +2 -0
  42. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/coreweave.py +71 -0
  43. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/external_chckpt.py +85 -0
  44. metaflow_extensions/outerbounds/plugins/checkpoint_datastores/nebius.py +73 -0
  45. metaflow_extensions/outerbounds/plugins/fast_bakery/baker.py +110 -0
  46. metaflow_extensions/outerbounds/plugins/fast_bakery/docker_environment.py +43 -9
  47. metaflow_extensions/outerbounds/plugins/fast_bakery/fast_bakery.py +12 -0
  48. metaflow_extensions/outerbounds/plugins/kubernetes/kubernetes_client.py +18 -44
  49. metaflow_extensions/outerbounds/plugins/kubernetes/pod_killer.py +374 -0
  50. metaflow_extensions/outerbounds/plugins/nim/card.py +2 -16
  51. metaflow_extensions/outerbounds/plugins/nim/{__init__.py → nim_decorator.py} +13 -49
  52. metaflow_extensions/outerbounds/plugins/nim/nim_manager.py +294 -233
  53. metaflow_extensions/outerbounds/plugins/nim/utils.py +36 -0
  54. metaflow_extensions/outerbounds/plugins/nvcf/constants.py +2 -2
  55. metaflow_extensions/outerbounds/plugins/nvcf/nvcf.py +100 -19
  56. metaflow_extensions/outerbounds/plugins/nvcf/nvcf_decorator.py +6 -1
  57. metaflow_extensions/outerbounds/plugins/nvct/__init__.py +0 -0
  58. metaflow_extensions/outerbounds/plugins/nvct/exceptions.py +71 -0
  59. metaflow_extensions/outerbounds/plugins/nvct/nvct.py +131 -0
  60. metaflow_extensions/outerbounds/plugins/nvct/nvct_cli.py +289 -0
  61. metaflow_extensions/outerbounds/plugins/nvct/nvct_decorator.py +286 -0
  62. metaflow_extensions/outerbounds/plugins/nvct/nvct_runner.py +218 -0
  63. metaflow_extensions/outerbounds/plugins/nvct/utils.py +29 -0
  64. metaflow_extensions/outerbounds/plugins/ollama/__init__.py +225 -0
  65. metaflow_extensions/outerbounds/plugins/ollama/constants.py +1 -0
  66. metaflow_extensions/outerbounds/plugins/ollama/exceptions.py +22 -0
  67. metaflow_extensions/outerbounds/plugins/ollama/ollama.py +1924 -0
  68. metaflow_extensions/outerbounds/plugins/ollama/status_card.py +292 -0
  69. metaflow_extensions/outerbounds/plugins/optuna/__init__.py +48 -0
  70. metaflow_extensions/outerbounds/plugins/profilers/simple_card_decorator.py +96 -0
  71. metaflow_extensions/outerbounds/plugins/s3_proxy/__init__.py +7 -0
  72. metaflow_extensions/outerbounds/plugins/s3_proxy/binary_caller.py +132 -0
  73. metaflow_extensions/outerbounds/plugins/s3_proxy/constants.py +11 -0
  74. metaflow_extensions/outerbounds/plugins/s3_proxy/exceptions.py +13 -0
  75. metaflow_extensions/outerbounds/plugins/s3_proxy/proxy_bootstrap.py +59 -0
  76. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_api.py +93 -0
  77. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_decorator.py +250 -0
  78. metaflow_extensions/outerbounds/plugins/s3_proxy/s3_proxy_manager.py +225 -0
  79. metaflow_extensions/outerbounds/plugins/secrets/secrets.py +38 -2
  80. metaflow_extensions/outerbounds/plugins/snowflake/snowflake.py +81 -11
  81. metaflow_extensions/outerbounds/plugins/snowpark/snowpark.py +18 -8
  82. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_cli.py +6 -0
  83. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_client.py +45 -18
  84. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_decorator.py +18 -9
  85. metaflow_extensions/outerbounds/plugins/snowpark/snowpark_job.py +10 -4
  86. metaflow_extensions/outerbounds/plugins/torchtune/__init__.py +163 -0
  87. metaflow_extensions/outerbounds/plugins/vllm/__init__.py +255 -0
  88. metaflow_extensions/outerbounds/plugins/vllm/constants.py +1 -0
  89. metaflow_extensions/outerbounds/plugins/vllm/exceptions.py +1 -0
  90. metaflow_extensions/outerbounds/plugins/vllm/status_card.py +352 -0
  91. metaflow_extensions/outerbounds/plugins/vllm/vllm_manager.py +621 -0
  92. metaflow_extensions/outerbounds/remote_config.py +46 -9
  93. metaflow_extensions/outerbounds/toplevel/global_aliases_for_metaflow_package.py +94 -2
  94. metaflow_extensions/outerbounds/toplevel/ob_internal.py +4 -0
  95. metaflow_extensions/outerbounds/toplevel/plugins/ollama/__init__.py +1 -0
  96. metaflow_extensions/outerbounds/toplevel/plugins/optuna/__init__.py +1 -0
  97. metaflow_extensions/outerbounds/toplevel/plugins/torchtune/__init__.py +1 -0
  98. metaflow_extensions/outerbounds/toplevel/plugins/vllm/__init__.py +1 -0
  99. metaflow_extensions/outerbounds/toplevel/s3_proxy.py +88 -0
  100. {ob_metaflow_extensions-1.1.130.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/METADATA +2 -2
  101. ob_metaflow_extensions-1.5.1.dist-info/RECORD +133 -0
  102. metaflow_extensions/outerbounds/plugins/nim/utilities.py +0 -5
  103. ob_metaflow_extensions-1.1.130.dist-info/RECORD +0 -56
  104. {ob_metaflow_extensions-1.1.130.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/WHEEL +0 -0
  105. {ob_metaflow_extensions-1.1.130.dist-info → ob_metaflow_extensions-1.5.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,289 @@
1
+ import json
2
+ import os
3
+ import sys
4
+ import time
5
+ import traceback
6
+
7
+ from metaflow import util, Run
8
+ from metaflow._vendor import click
9
+ from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY
10
+ from metaflow.metadata_provider.util import sync_local_metadata_from_datastore
11
+ from metaflow.metaflow_config import (
12
+ CARD_S3ROOT,
13
+ DATASTORE_LOCAL_DIR,
14
+ DATASTORE_SYSROOT_S3,
15
+ DATATOOLS_S3ROOT,
16
+ DEFAULT_METADATA,
17
+ SERVICE_HEADERS,
18
+ SERVICE_URL,
19
+ DEFAULT_SECRETS_BACKEND_TYPE,
20
+ DEFAULT_AWS_CLIENT_PROVIDER,
21
+ AWS_SECRETS_MANAGER_DEFAULT_REGION,
22
+ S3_ENDPOINT_URL,
23
+ AZURE_STORAGE_BLOB_SERVICE_ENDPOINT,
24
+ DATASTORE_SYSROOT_AZURE,
25
+ CARD_AZUREROOT,
26
+ DATASTORE_SYSROOT_GS,
27
+ CARD_GSROOT,
28
+ KUBERNETES_SANDBOX_INIT_SCRIPT,
29
+ OTEL_ENDPOINT,
30
+ )
31
+ from metaflow.mflog import TASK_LOG_SOURCE
32
+ from .nvct_runner import NvctRunner
33
+ from .nvct import NVCTClient
34
+ from .utils import get_ngc_api_key
35
+ from .exceptions import NvctKilledException
36
+
37
+
38
# Root click group for this plugin. Deliberately documented with a comment
# rather than a docstring: no `help=` is passed to the decorator here, and
# click falls back to the function docstring for help text.
@click.group()
def cli():
    return None
41
+
42
+
43
# `nvct` command group; it only namespaces the subcommands registered on it
# (`list`, `kill`, `step`) and does no work of its own.
@cli.group(help="Commands related to nvct.")
def nvct():
    return None
46
+
47
+
48
@nvct.command(help="List steps / tasks running as an nvct job.")
@click.option(
    "--run-id",
    default=None,
    required=True,
    help="List unfinished and running tasks corresponding to the run id.",
)
@click.pass_context
def list(ctx, run_id):
    """Echo one line per unfinished task of the run that carries nvct metadata.

    The function name is also the click command name (``nvct list``), which is
    why it shadows the builtin ``list`` inside this module.

    Parameters
    ----------
    ctx : click context
        ``ctx.obj.flow.name`` supplies the flow name and ``ctx.obj.echo`` is
        used for output.
    run_id : str
        Run id whose tasks are inspected.
    """
    flow_name = ctx.obj.flow.name
    run_obj = Run(pathspec=f"{flow_name}/{run_id}", _namespace_check=False)
    running_invocations = []

    for each_step in run_obj:
        for each_task in each_step:
            # Check `finished` first so metadata is only read for tasks that
            # are still unfinished (same evaluation order as before).
            if each_task.finished:
                continue
            # Read metadata_dict once per task; each access may trigger a
            # fresh metadata lookup in the Metaflow client.
            task_metadata = each_task.metadata_dict
            if "nvct-task-id" not in task_metadata:
                continue

            attempt = task_metadata.get("attempt")
            # Report the components of the task's own pathspec.
            task_flow, task_run, step_name, task_id = each_task.pathspec.split("/")
            running_invocations.append(
                f"Flow Name: {task_flow}, Run ID: {task_run}, Step Name: {step_name}, Task ID: {task_id}, Retry Count: {attempt}"
            )

    for each_invocation in running_invocations:
        ctx.obj.echo(each_invocation)
74
+
75
+
76
@nvct.command(help="Cancel steps / tasks running as an nvct job.")
@click.option(
    "--run-id",
    default=None,
    required=True,
    help="Terminate unfinished tasks corresponding to the run id.",
)
@click.pass_context
def kill(ctx, run_id):
    """Cancel every unfinished task of the run that has an nvct task id.

    For each such task, ``NVCTClient.cancel`` is called with the task's
    ``nvct-task-id`` metadata value; afterwards one confirmation line per
    cancelled task is echoed.

    Parameters
    ----------
    ctx : click context
        ``ctx.obj.flow.name`` supplies the flow name and ``ctx.obj.echo`` is
        used for output.
    run_id : str
        Run id whose tasks are cancelled.
    """
    ngc_api_key = get_ngc_api_key()
    nvct_client = NVCTClient(api_key=ngc_api_key)

    flow_name = ctx.obj.flow.name
    run_obj = Run(pathspec=f"{flow_name}/{run_id}", _namespace_check=False)
    tasks_cancelled = []

    for each_step in run_obj:
        for each_task in each_step:
            # Check `finished` first so metadata is only read for tasks that
            # are still unfinished (same evaluation order as before).
            if each_task.finished:
                continue
            # Read metadata_dict once per task; each access may trigger a
            # fresh metadata lookup in the Metaflow client.
            task_metadata = each_task.metadata_dict
            if "nvct-task-id" not in task_metadata:
                continue

            attempt = task_metadata.get("attempt")
            _, _, step_name, task_id = each_task.pathspec.split("/")

            nvct_task_id = task_metadata.get("nvct-task-id")
            nvct_client.cancel(nvct_task_id)

            tasks_cancelled.append(
                f"[{nvct_task_id}] -- Flow Name: {flow_name}, Run ID: {run_id}, Step Name: {step_name}, Task ID: {task_id}, Retry Count: {attempt} is cancelled."
            )

    # Confirmations are printed only after all cancel calls have returned.
    for each_cancelled_task in tasks_cancelled:
        ctx.obj.echo(each_cancelled_task)
109
+
110
+
111
@nvct.command(
    help="Execute a single task using @nvct. This command calls the "
    "top-level step command inside an nvct job with the given options. "
    "Typically you do not call this command directly; it is used internally by "
    "Metaflow."
)
@click.argument("step-name")
@click.argument("code-package-sha")
@click.argument("code-package-url")
@click.option("--gpu-type", help="Type of Nvidia GPU to use.")
@click.option("--instance-type", help="Instance type to use.")
@click.option("--backend", help="Backend to use.")
@click.option("--ngc-api-key", help="NGC API key.")
@click.option("--run-id", help="Passed to the top-level 'step'.")
@click.option("--task-id", help="Passed to the top-level 'step'.")
@click.option("--input-paths", help="Passed to the top-level 'step'.")
@click.option("--split-index", help="Passed to the top-level 'step'.")
@click.option("--clone-path", help="Passed to the top-level 'step'.")
@click.option("--clone-run-id", help="Passed to the top-level 'step'.")
@click.option(
    "--tag", multiple=True, default=None, help="Passed to the top-level 'step'."
)
@click.option("--namespace", default=None, help="Passed to the top-level 'step'.")
@click.option("--retry-count", default=0, help="Passed to the top-level 'step'.")
@click.option(
    "--max-user-code-retries", default=0, help="Passed to the top-level 'step'."
)
@click.pass_context
def step(
    ctx,
    step_name,
    code_package_sha,
    code_package_url,
    gpu_type,
    instance_type,
    backend,
    ngc_api_key,
    **kwargs,
):
    """Launch one Metaflow task through NvctRunner and wait for it to finish.

    Builds the top-level ``step`` command line and the environment for the
    task, launches it with ``NvctRunner.launch_task`` and then blocks in
    ``NvctRunner.wait_for_completion``.

    Parameters
    ----------
    ctx : click context
        Supplies the flow, graph, environment, datastore, metadata provider,
        monitor and echo functions via ``ctx.obj``.
    step_name, code_package_sha, code_package_url : str
        Positional CLI arguments identifying the step and its code package.
    gpu_type, instance_type, backend, ngc_api_key : str
        Forwarded unchanged to ``NvctRunner``.
    **kwargs
        Remaining options (run_id, task_id, input_paths, retry_count, ...);
        they are rendered back into CLI options for the top-level ``step``.

    Exits with ``METAFLOW_EXIT_DISALLOW_RETRY`` when launching fails or when
    the task raises ``NvctKilledException``.
    """

    def echo(msg, stream="stderr", _id=None, **echo_kwargs):
        msg = util.to_unicode(msg)
        if _id:
            msg = "[%s] %s" % (_id, msg)
        # NOTE(review): `stream` defaults to the *string* "stderr" but is
        # compared against the `sys.stderr` object, so `err` is False unless a
        # caller passes the stream object itself. Left as-is because changing
        # it would move output between streams; confirm the intent.
        ctx.obj.echo_always(msg, err=(stream == sys.stderr), **echo_kwargs)

    executable = ctx.obj.environment.executable(step_name)
    entrypoint = "%s -u %s" % (executable, os.path.basename(sys.argv[0]))

    top_args = " ".join(util.dict_to_cli_options(ctx.parent.parent.params))

    input_paths = kwargs.get("input_paths")
    split_vars = None
    if input_paths:
        # Chop the input-paths string into 30 KiB pieces, store each piece in
        # its own METAFLOW_INPUT_PATHS_<n> variable and pass the option as a
        # concatenation of "${...}" references to those variables.
        max_size = 30 * 1024
        split_vars = {
            "METAFLOW_INPUT_PATHS_%d" % (i // max_size): input_paths[i : i + max_size]
            for i in range(0, len(input_paths), max_size)
        }
        kwargs["input_paths"] = "".join("${%s}" % s for s in split_vars.keys())

    step_args = " ".join(util.dict_to_cli_options(kwargs))
    step_cli = "{entrypoint} {top_args} step {step} {step_args}".format(
        entrypoint=entrypoint,
        top_args=top_args,
        step=step_name,
        step_args=step_args,
    )
    node = ctx.obj.graph[step_name]

    # Get retry information
    retry_count = kwargs.get("retry_count", 0)
    retry_deco = [deco for deco in node.decorators if deco.name == "retry"]
    minutes_between_retries = None
    if retry_deco:
        minutes_between_retries = int(
            retry_deco[0].attributes.get("minutes_between_retries", 1)
        )

    task_spec = {
        "flow_name": ctx.obj.flow.name,
        "step_name": step_name,
        "run_id": kwargs["run_id"],
        "task_id": kwargs["task_id"],
        "retry_count": str(retry_count),
    }

    env = {
        "METAFLOW_CODE_SHA": code_package_sha,
        "METAFLOW_CODE_URL": code_package_url,
        "METAFLOW_CODE_DS": ctx.obj.flow_datastore.TYPE,
        "METAFLOW_SERVICE_URL": SERVICE_URL,
        "METAFLOW_SERVICE_HEADERS": json.dumps(SERVICE_HEADERS),
        "METAFLOW_DATASTORE_SYSROOT_S3": DATASTORE_SYSROOT_S3,
        "METAFLOW_DATATOOLS_S3ROOT": DATATOOLS_S3ROOT,
        "METAFLOW_DEFAULT_DATASTORE": ctx.obj.flow_datastore.TYPE,
        "METAFLOW_USER": util.get_username(),
        "METAFLOW_DEFAULT_METADATA": DEFAULT_METADATA,
        "METAFLOW_CARD_S3ROOT": CARD_S3ROOT,
        "METAFLOW_RUNTIME_ENVIRONMENT": "nvct",
        "METAFLOW_DEFAULT_SECRETS_BACKEND_TYPE": DEFAULT_SECRETS_BACKEND_TYPE,
        "METAFLOW_DEFAULT_AWS_CLIENT_PROVIDER": DEFAULT_AWS_CLIENT_PROVIDER,
        "METAFLOW_AWS_SECRETS_MANAGER_DEFAULT_REGION": AWS_SECRETS_MANAGER_DEFAULT_REGION,
        "METAFLOW_S3_ENDPOINT_URL": S3_ENDPOINT_URL,
        "METAFLOW_AZURE_STORAGE_BLOB_SERVICE_ENDPOINT": AZURE_STORAGE_BLOB_SERVICE_ENDPOINT,
        "METAFLOW_DATASTORE_SYSROOT_AZURE": DATASTORE_SYSROOT_AZURE,
        "METAFLOW_CARD_AZUREROOT": CARD_AZUREROOT,
        "METAFLOW_DATASTORE_SYSROOT_GS": DATASTORE_SYSROOT_GS,
        "METAFLOW_CARD_GSROOT": CARD_GSROOT,
        "METAFLOW_INIT_SCRIPT": KUBERNETES_SANDBOX_INIT_SCRIPT,
        "METAFLOW_OTEL_ENDPOINT": OTEL_ENDPOINT,
        "NVCT_CONTEXT": "1",
    }

    # Variables from an @environment decorator override the defaults above.
    env_deco = [deco for deco in node.decorators if deco.name == "environment"]
    if env_deco:
        env.update(env_deco[0].attributes["vars"])

    # Add the environment variables related to the input-paths argument
    if split_vars:
        env.update(split_vars)

    # Only sleep when a @retry decorator actually supplied an interval;
    # without one `minutes_between_retries` is None and formatting it with
    # "%d" (or multiplying it for time.sleep) would raise TypeError.
    if retry_count and minutes_between_retries is not None:
        ctx.obj.echo_always(
            "Sleeping %d minutes before the next retry" % minutes_between_retries
        )
        time.sleep(minutes_between_retries * 60)

    # this information is needed for log tailing
    ds = ctx.obj.flow_datastore.get_task_datastore(
        mode="w",
        run_id=kwargs["run_id"],
        step_name=step_name,
        task_id=kwargs["task_id"],
        attempt=int(retry_count),
    )
    stdout_location = ds.get_log_location(TASK_LOG_SOURCE, "stdout")
    stderr_location = ds.get_log_location(TASK_LOG_SOURCE, "stderr")

    def _sync_metadata():
        # Only the "local" metadata provider needs to be synced back from the
        # task datastore.
        if ctx.obj.metadata.TYPE == "local":
            sync_local_metadata_from_datastore(
                DATASTORE_LOCAL_DIR,
                ctx.obj.flow_datastore.get_task_datastore(
                    kwargs["run_id"], step_name, kwargs["task_id"]
                ),
            )

    # NOTE: this local name shadows the module-level `nvct` click group for
    # the remainder of the function.
    nvct = NvctRunner(
        ctx.obj.metadata,
        ctx.obj.flow_datastore,
        ctx.obj.environment,
        gpu_type,
        instance_type,
        backend,
        ngc_api_key,
    )
    try:
        with ctx.obj.monitor.measure("metaflow.nvct.launch_task"):
            nvct.launch_task(
                step_name,
                step_cli,
                task_spec,
                code_package_sha,
                code_package_url,
                ctx.obj.flow_datastore.TYPE,
                env=env,
            )
    except Exception:
        # A failed launch is reported and exits with the no-retry code.
        traceback.print_exc()
        _sync_metadata()
        sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
    try:
        nvct.wait_for_completion(stdout_location, stderr_location, echo=echo)
    except NvctKilledException:
        # don't retry killed tasks
        traceback.print_exc()
        sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
    finally:
        _sync_metadata()
@@ -0,0 +1,286 @@
1
+ import os
2
+ import sys
3
+
4
+ from metaflow.exception import MetaflowException
5
+ from metaflow.decorators import StepDecorator
6
+ from metaflow.plugins.parallel_decorator import ParallelDecorator
7
+ from metaflow.metadata_provider.util import sync_local_metadata_to_datastore
8
+ from metaflow.metaflow_config import DATASTORE_LOCAL_DIR
9
+ from metaflow.sidecar import Sidecar
10
+ from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task
11
+ from metaflow.metadata_provider import MetaDatum
12
+
13
+ from .utils import get_ngc_api_key
14
+ from .exceptions import (
15
+ UnsupportedNvctDatastoreException,
16
+ NvctTimeoutTooShortException,
17
+ RequestedGPUTypeUnavailableException,
18
+ UnsupportedNvctConfigurationException,
19
+ )
20
+
21
+
22
# GPU type used when the decorator's `gpu_type` attribute is left unset.
DEFAULT_GPU_TYPE = "H100"

# Maps each selectable GPU type to the configurations available for it.
# Every configuration names a GPU count plus the instance type and backend
# that go with that count.
SUPPORTABLE_GPU_TYPES = {
    "L40": [
        {"n_gpus": 1, "instance_type": "gl40_1.br20_2xlarge", "backend": "GFN"},
    ],
    "L40S": [
        {"n_gpus": 1, "instance_type": "gl40s_1.br25_2xlarge", "backend": "GFN"},
    ],
    "L40G": [
        {"n_gpus": 1, "instance_type": "gl40g_1.br25_2xlarge", "backend": "GFN"},
    ],
    # The H100 families differ only in GPU count, so their rows are generated.
    "H100": [
        {
            "n_gpus": gpu_count,
            "instance_type": "OCI.GPU.H100_%dx" % gpu_count,
            "backend": "nvcf-dgxc-k8s-oci-nrt-prd8",
        }
        for gpu_count in (1, 2, 4, 8)
    ],
    "NEBIUS_H100": [
        {
            "n_gpus": gpu_count,
            "instance_type": "ON-PREM.GPU.H100_%dx" % gpu_count,
            "backend": "default-project-eu-north1",
        }
        for gpu_count in (1, 2, 4, 8)
    ],
}
91
+
92
+
93
class NvctDecorator(StepDecorator):

    """
    Specifies that this step should execute on DGX cloud.

    Parameters
    ----------
    gpu : int
        Number of GPUs to use.
    gpu_type : str
        Type of Nvidia GPU to use.
    """

    name = "nvct"
    # `ngc_api_key`, `instance_type` and `backend` are always overwritten in
    # step_init (from get_ngc_api_key() and the SUPPORTABLE_GPU_TYPES table).
    defaults = {
        "gpu": 1,
        "gpu_type": None,
        "ngc_api_key": None,
        "instance_type": None,
        "backend": None,
    }

    # Class-level cache filled by _save_package_once so the code package is
    # saved to the datastore only once.
    package_url = None
    package_sha = None

    # Refer https://github.com/Netflix/metaflow/blob/master/docs/lifecycle.png
    # to understand where these functions are invoked in the lifecycle of a
    # Metaflow flow.
    def step_init(self, flow, graph, step, decos, environment, flow_datastore, logger):
        """Validate the step's configuration and resolve the GPU settings.

        Raises
        ------
        UnsupportedNvctDatastoreException
            If the datastore type is not one of "s3", "azure" or "gs".
        MetaflowException
            If the step also carries @kubernetes or a @parallel decorator.
        NvctTimeoutTooShortException
            If the computed run time limit is below 60.
        RequestedGPUTypeUnavailableException
            If the GPU type is not a key of SUPPORTABLE_GPU_TYPES.
        UnsupportedNvctConfigurationException
            If no configuration of that GPU type has the requested GPU count.
        """
        # Executing NVCT functions requires a non-local datastore.
        if flow_datastore.TYPE not in ("s3", "azure", "gs"):
            raise UnsupportedNvctDatastoreException(flow_datastore.TYPE)

        # Set internal state.
        self.logger = logger
        self.environment = environment
        self.step = step
        self.flow_datastore = flow_datastore

        if any(deco.name == "kubernetes" for deco in decos):
            raise MetaflowException(
                "Step *{step}* is marked for execution both on Kubernetes and "
                "Nvct. Please use one or the other.".format(step=step)
            )
        if any(isinstance(deco, ParallelDecorator) for deco in decos):
            raise MetaflowException(
                "Step *{step}* contains a @parallel decorator "
                "with the @nvct decorator. @parallel decorators are not currently supported with @nvct.".format(
                    step=step
                )
            )

        # Set run time limit for NVCT.
        self.run_time_limit = get_run_time_limit_for_task(decos)
        if self.run_time_limit < 60:
            raise NvctTimeoutTooShortException(step)

        self.attributes["ngc_api_key"] = get_ngc_api_key()

        requested_gpu_type = self.attributes["gpu_type"]
        requested_n_gpus = self.attributes["gpu"]

        if requested_gpu_type is None:
            requested_gpu_type = DEFAULT_GPU_TYPE
        if requested_gpu_type not in SUPPORTABLE_GPU_TYPES:
            raise RequestedGPUTypeUnavailableException(
                requested_gpu_type, list(SUPPORTABLE_GPU_TYPES.keys())
            )

        # Pick the first configuration whose GPU count matches exactly.
        available_configurations = SUPPORTABLE_GPU_TYPES[requested_gpu_type]
        valid_config = next(
            (
                each_config
                for each_config in available_configurations
                if each_config["n_gpus"] == requested_n_gpus
            ),
            None,
        )

        if valid_config is None:
            raise UnsupportedNvctConfigurationException(
                requested_n_gpus,
                requested_gpu_type,
                available_configurations,
                step,
            )

        self.attributes["instance_type"] = valid_config["instance_type"]
        self.attributes["gpu_type"] = requested_gpu_type
        # "NEBIUS_H100" only selects a different instance type / backend; the
        # GPU type passed on is plain "H100".
        if self.attributes["gpu_type"] == "NEBIUS_H100":
            self.attributes["gpu_type"] = "H100"
        self.attributes["backend"] = valid_config["backend"]

    def runtime_init(self, flow, graph, package, run_id):
        """Remember the flow, graph, code package and run id for later hooks."""
        # Set some more internal state.
        self.flow = flow
        self.graph = graph
        self.package = package
        self.run_id = run_id

    def runtime_task_created(
        self, task_datastore, task_id, split_index, input_paths, is_cloned, ubf_context
    ):
        """Save the code package (once) for every task that is not cloned."""
        if not is_cloned:
            self._save_package_once(self.flow_datastore, self.package)

    def runtime_step_cli(
        self, cli_args, retry_count, max_user_code_retries, ubf_context
    ):
        """Redirect the task's command line to the `nvct step` command."""
        if retry_count <= max_user_code_retries:
            # after all attempts to run the user code have failed, we don't need
            # to execute on NVCT anymore. We can execute possible fallback
            # code locally.
            cli_args.commands = ["nvct", "step"]
            cli_args.command_args.append(self.package_sha)
            cli_args.command_args.append(self.package_url)
            cli_options = {
                "gpu_type": self.attributes["gpu_type"],
                "instance_type": self.attributes["instance_type"],
                "backend": self.attributes["backend"],
                "ngc_api_key": self.attributes["ngc_api_key"],
            }
            cli_args.command_options.update(cli_options)
            cli_args.entrypoint[0] = sys.executable

    def task_pre_step(
        self,
        step_name,
        task_datastore,
        metadata,
        run_id,
        task_id,
        flow,
        graph,
        retry_count,
        max_retries,
        ubf_context,
        inputs,
    ):
        """Register NVCT bookkeeping metadata and start the log-saving sidecar."""
        self.metadata = metadata
        self.task_datastore = task_datastore

        # task_pre_step may run locally if fallback is activated for @catch
        # decorator.

        if "NVCT_CONTEXT" in os.environ:
            # Metadata key -> environment variable it is read from.
            meta = {
                "nvct-task-id": os.environ.get("NVCT_TASK_ID"),
                "nvct-task-name": os.environ.get("NVCT_TASK_NAME"),
                "nvct-ncaid": os.environ.get("NVCT_NCA_ID"),
                "nvct-progress-file-path": os.environ.get("NVCT_PROGRESS_FILE_PATH"),
                "nvct-results-dir": os.environ.get("NVCT_RESULTS_DIR"),
            }

            # Variables that are not set are skipped.
            entries = [
                MetaDatum(
                    field=k,
                    value=v,
                    type=k,
                    tags=["attempt_id:{0}".format(retry_count)],
                )
                for k, v in meta.items()
                if v is not None
            ]
            # Register book-keeping metadata for debugging.
            metadata.register_metadata(run_id, step_name, task_id, entries)

            self._save_logs_sidecar = Sidecar("save_logs_periodically")
            self._save_logs_sidecar.start()

    def task_finished(
        self, step_name, flow, graph, is_task_ok, retry_count, max_retries
    ):
        """Sync local metadata to the datastore and stop the log sidecar."""
        # task_finished may run locally if fallback is activated for @catch
        # decorator.
        if "NVCT_CONTEXT" in os.environ:
            # If `local` metadata is configured, we would need to copy task
            # execution metadata from the NVCT container to user's
            # local file system after the user code has finished execution.
            # This happens via datastore as a communication bridge.
            if hasattr(self, "metadata") and self.metadata.TYPE == "local":
                sync_local_metadata_to_datastore(
                    DATASTORE_LOCAL_DIR, self.task_datastore
                )

        try:
            self._save_logs_sidecar.terminate()
        except Exception:
            # Best effort kill. This also covers the sidecar never having
            # been started (AttributeError). `Exception` rather than a bare
            # `except` so KeyboardInterrupt / SystemExit still propagate.
            pass

    @classmethod
    def _save_package_once(cls, flow_datastore, package):
        """Save the code package blob unless a package URL is already cached."""
        if cls.package_url is None:
            cls.package_url, cls.package_sha = flow_datastore.save_data(
                [package.blob], len_hint=1
            )[0]