ob-metaflow 2.15.7.2__py2.py3-none-any.whl → 2.15.11.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ob-metaflow might be problematic.

Files changed (39)
  1. metaflow/cli.py +8 -0
  2. metaflow/cli_components/run_cmds.py +2 -2
  3. metaflow/cmd/main_cli.py +1 -1
  4. metaflow/includefile.py +2 -2
  5. metaflow/metadata_provider/metadata.py +35 -0
  6. metaflow/metaflow_config.py +6 -0
  7. metaflow/metaflow_environment.py +6 -1
  8. metaflow/metaflow_git.py +115 -0
  9. metaflow/metaflow_version.py +2 -2
  10. metaflow/plugins/__init__.py +1 -0
  11. metaflow/plugins/argo/argo_workflows.py +13 -2
  12. metaflow/plugins/argo/argo_workflows_cli.py +1 -0
  13. metaflow/plugins/aws/aws_client.py +4 -3
  14. metaflow/plugins/datastores/gs_storage.py +3 -1
  15. metaflow/plugins/datatools/s3/s3.py +54 -45
  16. metaflow/plugins/datatools/s3/s3op.py +149 -62
  17. metaflow/plugins/kubernetes/kubernetes.py +4 -0
  18. metaflow/plugins/kubernetes/kubernetes_cli.py +8 -0
  19. metaflow/plugins/kubernetes/kubernetes_decorator.py +10 -0
  20. metaflow/plugins/kubernetes/kubernetes_job.py +8 -0
  21. metaflow/plugins/kubernetes/kubernetes_jobsets.py +7 -0
  22. metaflow/plugins/pypi/conda_decorator.py +2 -1
  23. metaflow/plugins/pypi/conda_environment.py +1 -0
  24. metaflow/plugins/uv/__init__.py +0 -0
  25. metaflow/plugins/uv/bootstrap.py +100 -0
  26. metaflow/plugins/uv/uv_environment.py +70 -0
  27. metaflow/runner/deployer.py +8 -2
  28. metaflow/runner/deployer_impl.py +6 -2
  29. metaflow/runner/metaflow_runner.py +7 -2
  30. metaflow/version.py +1 -1
  31. {ob_metaflow-2.15.7.2.data → ob_metaflow-2.15.11.1.data}/data/share/metaflow/devtools/Makefile +2 -0
  32. {ob_metaflow-2.15.7.2.dist-info → ob_metaflow-2.15.11.1.dist-info}/METADATA +2 -2
  33. {ob_metaflow-2.15.7.2.dist-info → ob_metaflow-2.15.11.1.dist-info}/RECORD +39 -35
  34. {ob_metaflow-2.15.7.2.dist-info → ob_metaflow-2.15.11.1.dist-info}/WHEEL +1 -1
  35. {ob_metaflow-2.15.7.2.data → ob_metaflow-2.15.11.1.data}/data/share/metaflow/devtools/Tiltfile +0 -0
  36. {ob_metaflow-2.15.7.2.data → ob_metaflow-2.15.11.1.data}/data/share/metaflow/devtools/pick_services.sh +0 -0
  37. {ob_metaflow-2.15.7.2.dist-info → ob_metaflow-2.15.11.1.dist-info}/entry_points.txt +0 -0
  38. {ob_metaflow-2.15.7.2.dist-info → ob_metaflow-2.15.11.1.dist-info}/licenses/LICENSE +0 -0
  39. {ob_metaflow-2.15.7.2.dist-info → ob_metaflow-2.15.11.1.dist-info}/top_level.txt +0 -0
metaflow/plugins/datatools/s3/s3op.py

@@ -1,5 +1,6 @@
  from __future__ import print_function

+ import errno
  import json
  import time
  import math
@@ -15,7 +16,10 @@ from tempfile import NamedTemporaryFile
  from multiprocessing import Process, Queue
  from itertools import starmap, chain, islice

+ from boto3.exceptions import RetriesExceededError, S3UploadFailedError
  from boto3.s3.transfer import TransferConfig
+ from botocore.config import Config
+ from botocore.exceptions import ClientError, SSLError

  try:
  # python2
@@ -46,13 +50,21 @@ from metaflow.plugins.datatools.s3.s3util import (
  import metaflow.tracing as tracing
  from metaflow.metaflow_config import (
  S3_WORKER_COUNT,
+ S3_CLIENT_RETRY_CONFIG,
  )

  DOWNLOAD_FILE_THRESHOLD = 2 * TransferConfig().multipart_threshold
  DOWNLOAD_MAX_CHUNK = 2 * 1024 * 1024 * 1024 - 1

+ DEFAULT_S3_CLIENT_PARAMS = {"config": Config(retries=S3_CLIENT_RETRY_CONFIG)}
  RANGE_MATCH = re.compile(r"bytes (?P<start>[0-9]+)-(?P<end>[0-9]+)/(?P<total>[0-9]+)")

+ # from botocore ClientError MSG_TEMPLATE:
+ # https://github.com/boto/botocore/blob/68ca78f3097906c9231840a49931ef4382c41eea/botocore/exceptions.py#L521
+ BOTOCORE_MSG_TEMPLATE_MATCH = re.compile(
+ r"An error occurred \((\w+)\) when calling the (\w+) operation.*: (.+)"
+ )
+
  S3Config = namedtuple("S3Config", "role session_vars client_params")


@@ -97,6 +109,7 @@ ERROR_VERIFY_FAILED = 9
  ERROR_LOCAL_FILE_NOT_FOUND = 10
  ERROR_INVALID_RANGE = 11
  ERROR_TRANSIENT = 12
+ ERROR_OUT_OF_DISK_SPACE = 13


  def format_result_line(idx, prefix, url="", local=""):
@@ -147,6 +160,7 @@ def normalize_client_error(err):
  "LimitExceededException",
  "RequestThrottled",
  "EC2ThrottledException",
+ "InternalError",
  ):
  return 503
  return error_code
@@ -221,54 +235,68 @@ def worker(result_file_name, queue, mode, s3config):
  elif mode == "download":
  tmp = NamedTemporaryFile(dir=".", mode="wb", delete=False)
  try:
- if url.range:
- resp = s3.get_object(
- Bucket=url.bucket, Key=url.path, Range=url.range
- )
- range_result = resp["ContentRange"]
- range_result_match = RANGE_MATCH.match(range_result)
- if range_result_match is None:
- raise RuntimeError(
- "Wrong format for ContentRange: %s"
- % str(range_result)
+ try:
+ if url.range:
+ resp = s3.get_object(
+ Bucket=url.bucket, Key=url.path, Range=url.range
  )
- range_result = {
- x: int(range_result_match.group(x))
- for x in ["total", "start", "end"]
- }
- else:
- resp = s3.get_object(Bucket=url.bucket, Key=url.path)
- range_result = None
- sz = resp["ContentLength"]
- if range_result is None:
- range_result = {"total": sz, "start": 0, "end": sz - 1}
- if not url.range and sz > DOWNLOAD_FILE_THRESHOLD:
- # In this case, it is more efficient to use download_file as it
- # will download multiple parts in parallel (it does it after
- # multipart_threshold)
- s3.download_file(url.bucket, url.path, tmp.name)
- else:
- read_in_chunks(tmp, resp["Body"], sz, DOWNLOAD_MAX_CHUNK)
- tmp.close()
- os.rename(tmp.name, url.local)
- except client_error as err:
- tmp.close()
- os.unlink(tmp.name)
- error_code = normalize_client_error(err)
- if error_code == 404:
- result_file.write("%d %d\n" % (idx, -ERROR_URL_NOT_FOUND))
+ range_result = resp["ContentRange"]
+ range_result_match = RANGE_MATCH.match(range_result)
+ if range_result_match is None:
+ raise RuntimeError(
+ "Wrong format for ContentRange: %s"
+ % str(range_result)
+ )
+ range_result = {
+ x: int(range_result_match.group(x))
+ for x in ["total", "start", "end"]
+ }
+ else:
+ resp = s3.get_object(Bucket=url.bucket, Key=url.path)
+ range_result = None
+ sz = resp["ContentLength"]
+ if range_result is None:
+ range_result = {"total": sz, "start": 0, "end": sz - 1}
+ if not url.range and sz > DOWNLOAD_FILE_THRESHOLD:
+ # In this case, it is more efficient to use download_file as it
+ # will download multiple parts in parallel (it does it after
+ # multipart_threshold)
+ s3.download_file(url.bucket, url.path, tmp.name)
+ else:
+ read_in_chunks(
+ tmp, resp["Body"], sz, DOWNLOAD_MAX_CHUNK
+ )
+ tmp.close()
+ os.rename(tmp.name, url.local)
+ except client_error as err:
+ tmp.close()
+ os.unlink(tmp.name)
+ handle_client_error(err, idx, result_file)
  continue
- elif error_code == 403:
- result_file.write(
- "%d %d\n" % (idx, -ERROR_URL_ACCESS_DENIED)
- )
+ except RetriesExceededError as e:
+ tmp.close()
+ os.unlink(tmp.name)
+ err = convert_to_client_error(e)
+ handle_client_error(err, idx, result_file)
  continue
- elif error_code == 503:
- result_file.write("%d %d\n" % (idx, -ERROR_TRANSIENT))
+ except OSError as e:
+ tmp.close()
+ os.unlink(tmp.name)
+ if e.errno == errno.ENOSPC:
+ result_file.write(
+ "%d %d\n" % (idx, -ERROR_OUT_OF_DISK_SPACE)
+ )
+ else:
+ result_file.write("%d %d\n" % (idx, -ERROR_TRANSIENT))
+ result_file.flush()
  continue
- else:
- raise
- # TODO specific error message for out of disk space
+ except (SSLError, Exception) as e:
+ tmp.close()
+ os.unlink(tmp.name)
+ # assume anything else is transient
+ result_file.write("%d %d\n" % (idx, -ERROR_TRANSIENT))
+ result_file.flush()
+ continue
  # If we need the metadata, get it and write it out
  if pre_op_info:
  with open("%s_meta" % url.local, mode="w") as f:
@@ -316,28 +344,67 @@ def worker(result_file_name, queue, mode, s3config):
  if url.encryption is not None:
  extra["ServerSideEncryption"] = url.encryption
  try:
- s3.upload_file(
- url.local, url.bucket, url.path, ExtraArgs=extra
- )
- # We indicate that the file was uploaded
- result_file.write("%d %d\n" % (idx, 0))
- except client_error as err:
- error_code = normalize_client_error(err)
- if error_code == 403:
- result_file.write(
- "%d %d\n" % (idx, -ERROR_URL_ACCESS_DENIED)
+ try:
+ s3.upload_file(
+ url.local, url.bucket, url.path, ExtraArgs=extra
  )
+ # We indicate that the file was uploaded
+ result_file.write("%d %d\n" % (idx, 0))
+ except client_error as err:
+ # Shouldn't get here, but just in case.
+ # Internally, botocore catches ClientError and returns a S3UploadFailedError.
+ # See https://github.com/boto/boto3/blob/develop/boto3/s3/transfer.py#L377
+ handle_client_error(err, idx, result_file)
  continue
- elif error_code == 503:
- result_file.write("%d %d\n" % (idx, -ERROR_TRANSIENT))
+ except S3UploadFailedError as e:
+ err = convert_to_client_error(e)
+ handle_client_error(err, idx, result_file)
  continue
- else:
- raise
+ except (SSLError, Exception) as e:
+ # assume anything else is transient
+ result_file.write("%d %d\n" % (idx, -ERROR_TRANSIENT))
+ result_file.flush()
+ continue
  except:
  traceback.print_exc()
+ result_file.flush()
  sys.exit(ERROR_WORKER_EXCEPTION)


+ def convert_to_client_error(e):
+ match = BOTOCORE_MSG_TEMPLATE_MATCH.search(str(e))
+ if not match:
+ raise e
+ error_code = match.group(1)
+ operation_name = match.group(2)
+ error_message = match.group(3)
+ response = {
+ "Error": {
+ "Code": error_code,
+ "Message": error_message,
+ }
+ }
+ return ClientError(response, operation_name)
+
+
+ def handle_client_error(err, idx, result_file):
+ error_code = normalize_client_error(err)
+ if error_code == 404:
+ result_file.write("%d %d\n" % (idx, -ERROR_URL_NOT_FOUND))
+ result_file.flush()
+ elif error_code == 403:
+ result_file.write("%d %d\n" % (idx, -ERROR_URL_ACCESS_DENIED))
+ result_file.flush()
+ elif error_code == 503:
+ result_file.write("%d %d\n" % (idx, -ERROR_TRANSIENT))
+ result_file.flush()
+ else:
+ # optimistically assume it is a transient error
+ result_file.write("%d %d\n" % (idx, -ERROR_TRANSIENT))
+ result_file.flush()
+ # TODO specific error message for out of disk space
+
+
  def start_workers(mode, urls, num_workers, inject_failure, s3config):
  # We start the minimum of len(urls) or num_workers to avoid starting
  # workers that will definitely do nothing
@@ -381,6 +448,22 @@ def start_workers(mode, urls, num_workers, inject_failure, s3config):
  if proc.exitcode is not None:
  if proc.exitcode != 0:
  msg = "Worker process failed (exit code %d)" % proc.exitcode
+
+ # IMPORTANT: if this process has put items on a queue, then it will not terminate
+ # until all buffered items have been flushed to the pipe, causing a deadlock.
+ # `cancel_join_thread()` allows it to exit without flushing the queue.
+ # Without this line, the parent process would hang indefinitely when a subprocess
+ # did not exit cleanly in the case of unhandled exceptions.
+ #
+ # The error situation is:
+ # 1. this process puts stuff in queue
+ # 2. subprocess dies so doesn't consume its end-of-queue marker (the None)
+ # 3. other subprocesses consume all useful bits AND their end-of-queue marker
+ # 4. one marker is left and not consumed
+ # 5. this process cannot shut down until the queue is empty.
+ # 6. it will never be empty because all subprocesses (workers) have died.
+ queue.cancel_join_thread()
+
  exit(msg, proc.exitcode)
  # Read the output file if all went well
  with open(out_path, "r") as out_file:
@@ -573,6 +656,8 @@ def exit(exit_code, url):
  msg = "Local file not found: %s" % url
  elif exit_code == ERROR_TRANSIENT:
  msg = "Transient error for url: %s" % url
+ elif exit_code == ERROR_OUT_OF_DISK_SPACE:
+ msg = "Out of disk space when downloading URL: %s" % url
  else:
  msg = "Unknown error"
  print("s3op failed:\n%s" % msg, file=sys.stderr)
@@ -745,7 +830,7 @@ def lst(
  s3config = S3Config(
  s3role,
  json.loads(s3sessionvars) if s3sessionvars else None,
- json.loads(s3clientparams) if s3clientparams else None,
+ json.loads(s3clientparams) if s3clientparams else DEFAULT_S3_CLIENT_PARAMS,
  )

  urllist = []
@@ -878,7 +963,7 @@ def put(
  s3config = S3Config(
  s3role,
  json.loads(s3sessionvars) if s3sessionvars else None,
- json.loads(s3clientparams) if s3clientparams else None,
+ json.loads(s3clientparams) if s3clientparams else DEFAULT_S3_CLIENT_PARAMS,
  )

  urls = list(starmap(_make_url, _files()))
@@ -1025,7 +1110,7 @@ def get(
  s3config = S3Config(
  s3role,
  json.loads(s3sessionvars) if s3sessionvars else None,
- json.loads(s3clientparams) if s3clientparams else None,
+ json.loads(s3clientparams) if s3clientparams else DEFAULT_S3_CLIENT_PARAMS,
  )

  # Construct a list of URL (prefix) objects
@@ -1103,6 +1188,8 @@ def get(
  )
  if verify:
  verify_info.append((url, sz))
+ elif sz == -ERROR_OUT_OF_DISK_SPACE:
+ exit(ERROR_OUT_OF_DISK_SPACE, url)
  elif sz == -ERROR_URL_ACCESS_DENIED:
  denied_url = url
  break
@@ -1172,7 +1259,7 @@ def info(
  s3config = S3Config(
  s3role,
  json.loads(s3sessionvars) if s3sessionvars else None,
- json.loads(s3clientparams) if s3clientparams else None,
+ json.loads(s3clientparams) if s3clientparams else DEFAULT_S3_CLIENT_PARAMS,
  )

  # Construct a list of URL (prefix) objects
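The error handling above now funnels everything through two new helpers: convert_to_client_error() parses the stringified boto3 transfer error back into a botocore ClientError via BOTOCORE_MSG_TEMPLATE_MATCH, and handle_client_error() maps it onto the existing 404/403/503 result codes. A minimal sketch of that round trip, outside the diff; the wrapped message text below is illustrative, only the botocore-formatted tail matters to the regex:

import re

from botocore.exceptions import ClientError

BOTOCORE_MSG_TEMPLATE_MATCH = re.compile(
    r"An error occurred \((\w+)\) when calling the (\w+) operation.*: (.+)"
)

# Illustrative text of a wrapped upload failure; only the botocore-formatted
# tail ("An error occurred (...) when calling the ... operation: ...") is
# what the regex relies on.
wrapped = (
    "Failed to upload ./part.bin to my-bucket/key: An error occurred (SlowDown) "
    "when calling the PutObject operation: Please reduce your request rate."
)

match = BOTOCORE_MSG_TEMPLATE_MATCH.search(wrapped)
if match:
    err = ClientError(
        {"Error": {"Code": match.group(1), "Message": match.group(3)}},
        match.group(2),
    )
    # normalize_client_error() treats throttling-style codes like this as
    # transient (503), so handle_client_error() records the op as retryable.
    print(err.response["Error"]["Code"], err.operation_name)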
metaflow/plugins/kubernetes/kubernetes.py

@@ -194,6 +194,7 @@ class Kubernetes(object):
  port=None,
  num_parallel=None,
  qos=None,
+ security_context=None,
  ):
  name = "js-%s" % str(uuid4())[:6]
  jobset = (
@@ -227,6 +228,7 @@ class Kubernetes(object):
  port=port,
  num_parallel=num_parallel,
  qos=qos,
+ security_context=security_context,
  )
  .environment_variable("METAFLOW_CODE_SHA", code_package_sha)
  .environment_variable("METAFLOW_CODE_URL", code_package_url)
@@ -504,6 +506,7 @@ class Kubernetes(object):
  name_pattern=None,
  qos=None,
  annotations=None,
+ security_context=None,
  ):
  if env is None:
  env = {}
@@ -546,6 +549,7 @@ class Kubernetes(object):
  shared_memory=shared_memory,
  port=port,
  qos=qos,
+ security_context=security_context,
  )
  .environment_variable("METAFLOW_CODE_SHA", code_package_sha)
  .environment_variable("METAFLOW_CODE_URL", code_package_url)
metaflow/plugins/kubernetes/kubernetes_cli.py

@@ -145,6 +145,12 @@ def kubernetes():
  type=JSONTypeClass(),
  multiple=False,
  )
+ @click.option(
+ "--security-context",
+ default=None,
+ type=JSONTypeClass(),
+ multiple=False,
+ )
  @click.pass_context
  def step(
  ctx,
@@ -176,6 +182,7 @@ def step(
  qos=None,
  labels=None,
  annotations=None,
+ security_context=None,
  **kwargs
  ):
  def echo(msg, stream="stderr", job_id=None, **kwargs):
@@ -319,6 +326,7 @@ def step(
  qos=qos,
  labels=labels,
  annotations=annotations,
+ security_context=security_context,
  )
  except Exception:
  traceback.print_exc(chain=False)
metaflow/plugins/kubernetes/kubernetes_decorator.py

@@ -124,6 +124,14 @@ class KubernetesDecorator(StepDecorator):
  Only applicable when @parallel is used.
  qos: str, default: Burstable
  Quality of Service class to assign to the pod. Supported values are: Guaranteed, Burstable, BestEffort
+
+ security_context: Dict[str, Any], optional, default None
+ Container security context. Applies to the task container. Allows the following keys:
+ - privileged: bool, optional, default None
+ - allow_privilege_escalation: bool, optional, default None
+ - run_as_user: int, optional, default None
+ - run_as_group: int, optional, default None
+ - run_as_non_root: bool, optional, default None
  """

  name = "kubernetes"
@@ -154,6 +162,7 @@ class KubernetesDecorator(StepDecorator):
  "executable": None,
  "hostname_resolution_timeout": 10 * 60,
  "qos": KUBERNETES_QOS,
+ "security_context": None,
  }
  package_url = None
  package_sha = None
@@ -489,6 +498,7 @@ class KubernetesDecorator(StepDecorator):
  "persistent_volume_claims",
  "labels",
  "annotations",
+ "security_context",
  ]:
  cli_args.command_options[k] = json.dumps(v)
  else:
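As a usage illustration (not part of the diff), the new option is set per step like any other @kubernetes argument; the flow name and resource values below are made up, and only keys listed in the docstring above are used:

from metaflow import FlowSpec, kubernetes, step


class SecurityContextFlow(FlowSpec):
    @kubernetes(
        cpu=1,
        memory=1024,
        security_context={
            "run_as_user": 1000,
            "run_as_group": 1000,
            "run_as_non_root": True,
            "allow_privilege_escalation": False,
        },
    )
    @step
    def start(self):
        # Runs in a pod whose task container carries the security context above.
        print("running with a restricted container security context")
        self.next(self.end)

    @step
    def end(self):
        pass


if __name__ == "__main__":
    SecurityContextFlow()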
metaflow/plugins/kubernetes/kubernetes_job.py

@@ -94,6 +94,13 @@ class KubernetesJob(object):
  ],
  }

+ security_context = self._kwargs.get("security_context", {})
+ _security_context = {}
+ if security_context is not None and len(security_context) > 0:
+ _security_context = {
+ "security_context": client.V1SecurityContext(**security_context)
+ }
+
  return client.V1JobSpec(
  # Retries are handled by Metaflow when it is responsible for
  # executing the flow. The responsibility is moved to Kubernetes
@@ -224,6 +231,7 @@ class KubernetesJob(object):
  if self._kwargs["persistent_volume_claims"] is not None
  else []
  ),
+ **_security_context,
  )
  ],
  node_selector=self._kwargs.get("node_selector"),
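The **_security_context expansion above acts as a conditional keyword argument: when no security context was supplied, the V1Container call is unchanged; otherwise it gains a V1SecurityContext built from the user dict. A standalone sketch of the same pattern using the kubernetes Python client (container values are illustrative):

from kubernetes import client

security_context = {"run_as_user": 1000, "run_as_non_root": True}

_security_context = {}
if security_context:  # omit the keyword entirely when nothing was requested
    _security_context = {
        "security_context": client.V1SecurityContext(**security_context)
    }

container = client.V1Container(
    name="main",
    image="python:3.11",
    command=["python", "-V"],
    **_security_context,
)
print(container.security_context)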
metaflow/plugins/kubernetes/kubernetes_jobsets.py

@@ -562,6 +562,12 @@ class JobSetSpec(object):
  self._kwargs["memory"],
  self._kwargs["disk"],
  )
+ security_context = self._kwargs.get("security_context", {})
+ _security_context = {}
+ if security_context is not None and len(security_context) > 0:
+ _security_context = {
+ "security_context": client.V1SecurityContext(**security_context)
+ }
  return dict(
  name=self.name,
  template=client.api_client.ApiClient().sanitize_for_serialization(
@@ -708,6 +714,7 @@ class JobSetSpec(object):
  is not None
  else []
  ),
+ **_security_context,
  )
  ],
  node_selector=self._kwargs.get("node_selector"),
metaflow/plugins/pypi/conda_decorator.py

@@ -227,7 +227,8 @@ class CondaStepDecorator(StepDecorator):
  self.interpreter = (
  self.environment.interpreter(self.step)
  if not any(
- decorator.name in ["batch", "kubernetes", "nvidia", "snowpark", "slurm"]
+ decorator.name
+ in ["batch", "kubernetes", "nvidia", "snowpark", "slurm", "nvct"]
  for decorator in next(
  step for step in self.flow if step.name == self.step
  ).decorators
metaflow/plugins/pypi/conda_environment.py

@@ -326,6 +326,7 @@ class CondaEnvironment(MetaflowEnvironment):
  "nvidia",
  "snowpark",
  "slurm",
+ "nvct",
  ]:
  target_platform = getattr(decorator, "target_platform", "linux-64")
  break
metaflow/plugins/uv/__init__.py
File without changes
metaflow/plugins/uv/bootstrap.py

@@ -0,0 +1,100 @@
+import os
+import subprocess
+import sys
+import time
+
+from metaflow.util import which
+from metaflow.metaflow_config import get_pinned_conda_libs
+from urllib.request import Request, urlopen
+from urllib.error import URLError
+
+# TODO: support version/platform/architecture selection.
+UV_URL = "https://github.com/astral-sh/uv/releases/download/0.6.11/uv-x86_64-unknown-linux-gnu.tar.gz"
+
+if __name__ == "__main__":
+
+    def run_cmd(cmd, stdin_str=None):
+        result = subprocess.run(
+            cmd,
+            shell=True,
+            input=stdin_str,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True,
+        )
+        if result.returncode != 0:
+            print(f"Bootstrap failed while executing: {cmd}")
+            print("Stdout:", result.stdout)
+            print("Stderr:", result.stderr)
+            sys.exit(1)
+
+    def install_uv():
+        import tarfile
+
+        uv_install_path = os.path.join(os.getcwd(), "uv_install")
+        if which("uv"):
+            return
+
+        print("Installing uv...")
+
+        # Prepare directory once
+        os.makedirs(uv_install_path, exist_ok=True)
+
+        # Download and decompress in one go
+        headers = {
+            "Accept-Encoding": "gzip, deflate, br",
+            "Connection": "keep-alive",
+            "User-Agent": "python-urllib",
+        }
+
+        def _tar_filter(member: tarfile.TarInfo, path):
+            if os.path.basename(member.name) != "uv":
+                return None  # skip
+            member.path = os.path.basename(member.path)
+            return member
+
+        max_retries = 3
+        for attempt in range(max_retries):
+            try:
+                req = Request(UV_URL, headers=headers)
+                with urlopen(req) as response:
+                    with tarfile.open(fileobj=response, mode="r:gz") as tar:
+                        tar.extractall(uv_install_path, filter=_tar_filter)
+                break
+            except (URLError, IOError) as e:
+                if attempt == max_retries - 1:
+                    raise Exception(
+                        f"Failed to download UV after {max_retries} attempts: {e}"
+                    )
+                time.sleep(2**attempt)
+
+        # Update PATH only once at the end
+        os.environ["PATH"] += os.pathsep + uv_install_path
+
+    def get_dependencies(datastore_type):
+        # return required dependencies for Metaflow that must be added to the UV environment.
+        pinned = get_pinned_conda_libs(None, datastore_type)
+
+        # return only dependency names instead of pinned versions
+        return pinned.keys()
+
+    def sync_uv_project(datastore_type):
+        print("Syncing uv project...")
+        dependencies = " ".join(get_dependencies(datastore_type))
+        cmd = f"""set -e;
+            uv sync --frozen --no-install-package metaflow;
+            uv pip install {dependencies} --strict
+        """
+        run_cmd(cmd)
+
+    if len(sys.argv) != 2:
+        print("Usage: bootstrap.py <datastore_type>")
+        sys.exit(1)
+
+    try:
+        datastore_type = sys.argv[1]
+        install_uv()
+        sync_uv_project(datastore_type)
+    except Exception as e:
+        print(f"Error: {str(e)}", file=sys.stderr)
+        sys.exit(1)
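install_uv() above leans on the tarfile extraction-filter hook (available on Python builds that include the PEP 706 filters) to pull only the uv binary out of the release archive and flatten its path. A standalone sketch of just that piece, assuming the archive has already been downloaded to the working directory:

import os
import tarfile


def keep_only_uv(member: tarfile.TarInfo, path):
    # Skip every member except the one literally named "uv", and flatten
    # its path so the binary lands directly in the target directory.
    if os.path.basename(member.name) != "uv":
        return None
    member.path = os.path.basename(member.path)
    return member


with tarfile.open("uv-x86_64-unknown-linux-gnu.tar.gz", mode="r:gz") as tar:
    tar.extractall("uv_install", filter=keep_only_uv)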
metaflow/plugins/uv/uv_environment.py

@@ -0,0 +1,70 @@
+import os
+
+from metaflow.exception import MetaflowException
+from metaflow.metaflow_environment import MetaflowEnvironment
+
+
+class UVException(MetaflowException):
+    headline = "uv error"
+
+
+class UVEnvironment(MetaflowEnvironment):
+    TYPE = "uv"
+
+    def __init__(self, flow):
+        self.flow = flow
+
+    def validate_environment(self, logger, datastore_type):
+        self.datastore_type = datastore_type
+        self.logger = logger
+
+    def init_environment(self, echo, only_steps=None):
+        self.logger("Bootstrapping uv...")
+
+    def executable(self, step_name, default=None):
+        return "uv run python"
+
+    def add_to_package(self):
+        # NOTE: We treat uv.lock and pyproject.toml as regular project assets and ship these along user code as part of the code package
+        # These are the minimal required files to reproduce the UV environment on the remote platform.
+        def _find(filename):
+            current_dir = os.getcwd()
+            while True:
+                file_path = os.path.join(current_dir, filename)
+                if os.path.isfile(file_path):
+                    return file_path
+                parent_dir = os.path.dirname(current_dir)
+                if parent_dir == current_dir:  # Reached root
+                    raise UVException(
+                        f"Could not find {filename} in current directory or any parent directory"
+                    )
+                current_dir = parent_dir
+
+        pyproject_path = _find("pyproject.toml")
+        uv_lock_path = _find("uv.lock")
+        files = [
+            (uv_lock_path, "uv.lock"),
+            (pyproject_path, "pyproject.toml"),
+        ]
+        return files
+
+    def pylint_config(self):
+        config = super().pylint_config()
+        # Disable (import-error) in pylint
+        config.append("--disable=F0401")
+        return config
+
+    def bootstrap_commands(self, step_name, datastore_type):
+        return [
+            "echo 'Bootstrapping uv project...'",
+            "flush_mflogs",
+            # We have to prevent the tracing module from loading, as the bootstrapping process
+            # uses the internal S3 client which would fail to import tracing due to the required
+            # dependencies being bundled into the conda environment, which is yet to be
+            # initialized at this point.
+            'DISABLE_TRACING=True python -m metaflow.plugins.uv.bootstrap "%s"'
+            % datastore_type,
+            "echo 'uv project bootstrapped.'",
+            "flush_mflogs",
+            "export PATH=$PATH:$(pwd)/uv_install",
+        ]
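Taken together, the new uv plugin ships uv.lock and pyproject.toml from the nearest project root with the code package (add_to_package), swaps the step interpreter to "uv run python", and rebuilds the locked environment remotely with `python -m metaflow.plugins.uv.bootstrap <datastore_type>` before the step runs. Assuming Metaflow's usual environment selection by TYPE, a uv-managed project would opt in with something like `--environment=uv` on the run command; that invocation is an inference from TYPE = "uv" rather than something shown in this diff.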