metaflow 2.11.1__py2.py3-none-any.whl → 2.11.3__py2.py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
Files changed (28)
  1. metaflow/flowspec.py +7 -3
  2. metaflow/metaflow_config.py +11 -1
  3. metaflow/parameters.py +6 -0
  4. metaflow/plugins/argo/argo_workflows.py +101 -23
  5. metaflow/plugins/aws/batch/batch.py +2 -0
  6. metaflow/plugins/aws/batch/batch_client.py +10 -2
  7. metaflow/plugins/aws/step_functions/dynamo_db_client.py +28 -6
  8. metaflow/plugins/aws/step_functions/production_token.py +1 -1
  9. metaflow/plugins/aws/step_functions/step_functions.py +219 -4
  10. metaflow/plugins/aws/step_functions/step_functions_cli.py +104 -6
  11. metaflow/plugins/aws/step_functions/step_functions_client.py +8 -3
  12. metaflow/plugins/aws/step_functions/step_functions_decorator.py +1 -1
  13. metaflow/plugins/cards/card_cli.py +2 -2
  14. metaflow/plugins/kubernetes/kubernetes.py +2 -0
  15. metaflow/plugins/kubernetes/kubernetes_cli.py +3 -0
  16. metaflow/plugins/kubernetes/kubernetes_client.py +10 -2
  17. metaflow/plugins/kubernetes/kubernetes_decorator.py +17 -0
  18. metaflow/plugins/kubernetes/kubernetes_job.py +27 -0
  19. metaflow/plugins/pypi/bootstrap.py +1 -1
  20. metaflow/plugins/pypi/conda_decorator.py +21 -1
  21. metaflow/plugins/pypi/conda_environment.py +21 -4
  22. metaflow/version.py +1 -1
  23. {metaflow-2.11.1.dist-info → metaflow-2.11.3.dist-info}/METADATA +2 -2
  24. {metaflow-2.11.1.dist-info → metaflow-2.11.3.dist-info}/RECORD +28 -28
  25. {metaflow-2.11.1.dist-info → metaflow-2.11.3.dist-info}/LICENSE +0 -0
  26. {metaflow-2.11.1.dist-info → metaflow-2.11.3.dist-info}/WHEEL +0 -0
  27. {metaflow-2.11.1.dist-info → metaflow-2.11.3.dist-info}/entry_points.txt +0 -0
  28. {metaflow-2.11.1.dist-info → metaflow-2.11.3.dist-info}/top_level.txt +0 -0

metaflow/plugins/aws/step_functions/step_functions.py

@@ -15,6 +15,7 @@ from metaflow.metaflow_config import (
     SFN_DYNAMO_DB_TABLE,
     SFN_EXECUTION_LOG_GROUP_ARN,
     SFN_IAM_ROLE,
+    SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH,
 )
 from metaflow.parameters import deploy_time_eval
 from metaflow.util import dict_to_cli_options, to_pascalcase
@@ -52,6 +53,7 @@ class StepFunctions(object):
         max_workers=None,
         workflow_timeout=None,
         is_project=False,
+        use_distributed_map=False,
     ):
         self.name = name
         self.graph = graph
@@ -70,6 +72,9 @@ class StepFunctions(object):
         self.max_workers = max_workers
         self.workflow_timeout = workflow_timeout
 
+        # https://aws.amazon.com/blogs/aws/step-functions-distributed-map-a-serverless-solution-for-large-scale-parallel-data-processing/
+        self.use_distributed_map = use_distributed_map
+
         self._client = StepFunctionsClient()
         self._workflow = self._compile()
         self._cron = self._cron()
@@ -166,6 +171,13 @@ class StepFunctions(object):
 
         return schedule_deleted, sfn_deleted
 
+    @classmethod
+    def terminate(cls, flow_name, name):
+        client = StepFunctionsClient()
+        execution_arn, _, _, _ = cls.get_execution(flow_name, name)
+        response = client.terminate_execution(execution_arn)
+        return response
+
     @classmethod
     def trigger(cls, name, parameters):
         try:
@@ -234,6 +246,50 @@ class StepFunctions(object):
                 )
         return None
 
+    @classmethod
+    def get_execution(cls, state_machine_name, name):
+        client = StepFunctionsClient()
+        try:
+            state_machine = client.get(state_machine_name)
+        except Exception as e:
+            raise StepFunctionsException(repr(e))
+        if state_machine is None:
+            raise StepFunctionsException(
+                "The state machine *%s* doesn't exist on AWS Step Functions."
+                % state_machine_name
+            )
+        try:
+            state_machine_arn = state_machine.get("stateMachineArn")
+            environment_vars = (
+                json.loads(state_machine.get("definition"))
+                .get("States")
+                .get("start")
+                .get("Parameters")
+                .get("ContainerOverrides")
+                .get("Environment")
+            )
+            parameters = {
+                item.get("Name"): item.get("Value") for item in environment_vars
+            }
+            executions = client.list_executions(state_machine_arn, states=["RUNNING"])
+            for execution in executions:
+                if execution.get("name") == name:
+                    try:
+                        return (
+                            execution.get("executionArn"),
+                            parameters.get("METAFLOW_OWNER"),
+                            parameters.get("METAFLOW_PRODUCTION_TOKEN"),
+                            parameters.get("SFN_STATE_MACHINE"),
+                        )
+                    except KeyError:
+                        raise StepFunctionsException(
+                            "A non-metaflow workflow *%s* already exists in AWS Step Functions."
+                            % name
+                        )
+            return None
+        except Exception as e:
+            raise StepFunctionsException(repr(e))
+
     def _compile(self):
         if self.flow._flow_decorators.get("trigger") or self.flow._flow_decorators.get(
             "trigger_on_finish"
@@ -314,17 +370,80 @@ class StepFunctions(object):
                     .parameter("SplitParentTaskId.$", "$.JobId")
                     .parameter("Parameters.$", "$.Parameters")
                     .parameter("Index.$", "$$.Map.Item.Value")
-                    .next(node.matching_join)
+                    .next(
+                        "%s_*GetManifest" % iterator_name
+                        if self.use_distributed_map
+                        else node.matching_join
+                    )
                     .iterator(
                         _visit(
                             self.graph[node.out_funcs[0]],
-                            Workflow(node.out_funcs[0]).start_at(node.out_funcs[0]),
+                            Workflow(node.out_funcs[0])
+                            .start_at(node.out_funcs[0])
+                            .mode(
+                                "DISTRIBUTED" if self.use_distributed_map else "INLINE"
+                            ),
                             node.matching_join,
                         )
                     )
                     .max_concurrency(self.max_workers)
-                    .output_path("$.[0]")
+                    # AWS Step Functions has a shortcoming for DistributedMap at the
+                    # moment that does not allow us to subset the output of for-each
+                    # to just a single element. We have to rely on a rather terrible
+                    # hack and resort to using ResultWriter to write the state to
+                    # Amazon S3 and process it in another task. But, well what can we
+                    # do...
+                    .result_writer(
+                        *(
+                            (
+                                (
+                                    SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH[len("s3://") :]
+                                    if SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH.startswith(
+                                        "s3://"
+                                    )
+                                    else SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH
+                                ).split("/", 1)
+                                + [""]
+                            )[:2]
+                            if self.use_distributed_map
+                            else (None, None)
+                        )
+                    )
+                    .output_path("$" if self.use_distributed_map else "$.[0]")
                 )
+                if self.use_distributed_map:
+                    workflow.add_state(
+                        State("%s_*GetManifest" % iterator_name)
+                        .resource("arn:aws:states:::aws-sdk:s3:getObject")
+                        .parameter("Bucket.$", "$.ResultWriterDetails.Bucket")
+                        .parameter("Key.$", "$.ResultWriterDetails.Key")
+                        .next("%s_*Map" % iterator_name)
+                        .result_selector("Body.$", "States.StringToJson($.Body)")
+                    )
+                    workflow.add_state(
+                        Map("%s_*Map" % iterator_name)
+                        .iterator(
+                            Workflow("%s_*PassWorkflow" % iterator_name)
+                            .mode("DISTRIBUTED")
+                            .start_at("%s_*Pass" % iterator_name)
+                            .add_state(
+                                Pass("%s_*Pass" % iterator_name)
+                                .end()
+                                .parameter("Output.$", "States.StringToJson($.Output)")
+                                .output_path("$.Output")
+                            )
+                        )
+                        .next(node.matching_join)
+                        .max_concurrency(1000)
+                        .item_reader(
+                            JSONItemReader()
+                            .resource("arn:aws:states:::s3:getObject")
+                            .parameter("Bucket.$", "$.Body.DestinationBucket")
+                            .parameter("Key.$", "$.Body.ResultFiles.SUCCEEDED.[0].Key")
+                        )
+                        .output_path("$.[0]")
+                    )
+
                 # Continue the traversal from the matching_join.
                 _visit(self.graph[node.matching_join], workflow, exit_node)
             # We shouldn't ideally ever get here.
@@ -393,7 +512,6 @@ class StepFunctions(object):
             "metaflow.owner": self.username,
             "metaflow.flow_name": self.flow.name,
             "metaflow.step_name": node.name,
-            "metaflow.run_id.$": "$$.Execution.Name",
             # Unfortunately we can't set the task id here since AWS Step
             # Functions lacks any notion of run-scoped task identifiers. We
             # instead co-opt the AWS Batch job id as the task id. This also
@@ -405,6 +523,10 @@ class StepFunctions(object):
             # `$$.State.RetryCount` resolves to an int dynamically and
             # AWS Batch job specification only accepts strings. We handle
            # retries/catch within AWS Batch to get around this limitation.
+            # And, we also cannot set the run id here since the run id maps to
+            # the execution name of the AWS Step Functions State Machine, which
+            # is different when executing inside a distributed map. We set it once
+            # in the start step and move it along to be consumed by all the children.
             "metaflow.version": self.environment.get_environment_info()[
                 "metaflow_version"
             ],
@@ -441,6 +563,12 @@ class StepFunctions(object):
             env["METAFLOW_S3_ENDPOINT_URL"] = S3_ENDPOINT_URL
 
         if node.name == "start":
+            # metaflow.run_id maps to AWS Step Functions State Machine Execution in all
+            # cases except for when within a for-each construct that relies on
+            # Distributed Map. To work around this issue, we pass the run id from the
+            # start step to all subsequent tasks.
+            attrs["metaflow.run_id.$"] = "$$.Execution.Name"
+
             # Initialize parameters for the flow in the `start` step.
             parameters = self._process_parameters()
             if parameters:
@@ -499,6 +627,8 @@ class StepFunctions(object):
                 env["METAFLOW_SPLIT_PARENT_TASK_ID"] = (
                     "$.Parameters.split_parent_task_id_%s" % node.split_parents[-1]
                 )
+                # Inherit the run id from the parent and pass it along to children.
+                attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']"
             else:
                 # Set appropriate environment variables for runtime replacement.
                 if len(node.in_funcs) == 1:
@@ -507,6 +637,8 @@ class StepFunctions(object):
                         % node.in_funcs[0]
                     )
                     env["METAFLOW_PARENT_TASK_ID"] = "$.JobId"
+                    # Inherit the run id from the parent and pass it along to children.
+                    attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']"
                 else:
                     # Generate the input paths in a quasi-compressed format.
                     # See util.decompress_list for why this is written the way
@@ -516,6 +648,8 @@ class StepFunctions(object):
                         "${METAFLOW_PARENT_%s_TASK_ID}" % (idx, idx)
                         for idx, _ in enumerate(node.in_funcs)
                     )
+                    # Inherit the run id from the parent and pass it along to children.
+                    attrs["metaflow.run_id.$"] = "$.[0].Parameters.['metaflow.run_id']"
                     for idx, _ in enumerate(node.in_funcs):
                         env["METAFLOW_PARENT_%s_TASK_ID" % idx] = "$.[%s].JobId" % idx
                         env["METAFLOW_PARENT_%s_STEP" % idx] = (
@@ -842,6 +976,12 @@ class Workflow(object):
         tree = lambda: defaultdict(tree)
         self.payload = tree()
 
+    def mode(self, mode):
+        self.payload["ProcessorConfig"] = {"Mode": mode}
+        if mode == "DISTRIBUTED":
+            self.payload["ProcessorConfig"]["ExecutionType"] = "STANDARD"
+        return self
+
     def start_at(self, start_at):
         self.payload["StartAt"] = start_at
         return self
@@ -889,10 +1029,18 @@ class State(object):
         self.payload["ResultPath"] = result_path
         return self
 
+    def result_selector(self, name, value):
+        self.payload["ResultSelector"][name] = value
+        return self
+
     def _partition(self):
         # This is needed to support AWS Gov Cloud and AWS CN regions
         return SFN_IAM_ROLE.split(":")[1]
 
+    def retry_strategy(self, retry_strategy):
+        self.payload["Retry"] = [retry_strategy]
+        return self
+
     def batch(self, job):
         self.resource(
             "arn:%s:states:::batch:submitJob.sync" % self._partition()
@@ -912,6 +1060,19 @@ class State(object):
         # tags may not be present in all scenarios
         if "tags" in job.payload:
             self.parameter("Tags", job.payload["tags"])
+        # set retry strategy for AWS Batch job submission to account for the
+        # measly 50 jobs / second queue admission limit which people can
+        # run into very quickly.
+        self.retry_strategy(
+            {
+                "ErrorEquals": ["Batch.AWSBatchException"],
+                "BackoffRate": 2,
+                "IntervalSeconds": 2,
+                "MaxDelaySeconds": 60,
+                "MaxAttempts": 10,
+                "JitterStrategy": "FULL",
+            }
+        )
         return self
 
     def dynamo_db(self, table_name, primary_key, values):
@@ -925,6 +1086,26 @@ class State(object):
         return self
 
 
+class Pass(object):
+    def __init__(self, name):
+        self.name = name
+        tree = lambda: defaultdict(tree)
+        self.payload = tree()
+        self.payload["Type"] = "Pass"
+
+    def end(self):
+        self.payload["End"] = True
+        return self
+
+    def parameter(self, name, value):
+        self.payload["Parameters"][name] = value
+        return self
+
+    def output_path(self, output_path):
+        self.payload["OutputPath"] = output_path
+        return self
+
+
 class Parallel(object):
     def __init__(self, name):
         self.name = name
@@ -986,3 +1167,37 @@ class Map(object):
     def result_path(self, result_path):
         self.payload["ResultPath"] = result_path
         return self
+
+    def item_reader(self, item_reader):
+        self.payload["ItemReader"] = item_reader.payload
+        return self
+
+    def result_writer(self, bucket, prefix):
+        if bucket is not None and prefix is not None:
+            self.payload["ResultWriter"] = {
+                "Resource": "arn:aws:states:::s3:putObject",
+                "Parameters": {
+                    "Bucket": bucket,
+                    "Prefix": prefix,
+                },
+            }
+        return self
+
+
+class JSONItemReader(object):
+    def __init__(self):
+        tree = lambda: defaultdict(tree)
+        self.payload = tree()
+        self.payload["ReaderConfig"] = {"InputType": "JSON", "MaxItems": 1}
+
+    def resource(self, resource):
+        self.payload["Resource"] = resource
+        return self
+
+    def parameter(self, name, value):
+        self.payload["Parameters"][name] = value
+        return self
+
+    def output_path(self, output_path):
+        self.payload["OutputPath"] = output_path
+        return self
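
The densest part of the diff above is the argument to `.result_writer(...)`, which splits the configured `SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH` into an S3 bucket and key prefix (or passes `(None, None)` when Distributed Map is off). A minimal sketch of that split in isolation, using an illustrative path rather than a real configuration value:

# Sketch of the bucket/prefix split passed to .result_writer() above.
# "s3://my-bucket/sfn/map-output" is an illustrative value, not a Metaflow default.
path = "s3://my-bucket/sfn/map-output"
stripped = path[len("s3://"):] if path.startswith("s3://") else path
bucket, prefix = (stripped.split("/", 1) + [""])[:2]
print(bucket, prefix)  # -> my-bucket sfn/map-output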

metaflow/plugins/aws/step_functions/step_functions_cli.py

@@ -1,23 +1,23 @@
 import base64
-from metaflow._vendor import click
-from hashlib import sha1
 import json
 import re
+from hashlib import sha1
 
-from metaflow import current, decorators, parameters, JSONType
+from metaflow import JSONType, current, decorators, parameters
+from metaflow._vendor import click
+from metaflow.exception import MetaflowException, MetaflowInternalError
 from metaflow.metaflow_config import (
     SERVICE_VERSION_CHECK,
     SFN_STATE_MACHINE_PREFIX,
     UI_URL,
 )
-from metaflow.exception import MetaflowException, MetaflowInternalError
 from metaflow.package import MetaflowPackage
 from metaflow.plugins.aws.batch.batch_decorator import BatchDecorator
 from metaflow.tagging_util import validate_tags
 from metaflow.util import get_username, to_bytes, to_unicode, version_parse
 
+from .production_token import load_token, new_token, store_token
 from .step_functions import StepFunctions
-from .production_token import load_token, store_token, new_token
 
 VALID_NAME = re.compile(r"[^a-zA-Z0-9_\-\.]")
 
@@ -26,6 +26,10 @@ class IncorrectProductionToken(MetaflowException):
     headline = "Incorrect production token"
 
 
+class RunIdMismatch(MetaflowException):
+    headline = "Run ID mismatch"
+
+
 class IncorrectMetadataServiceVersion(MetaflowException):
     headline = "Incorrect version for metaflow service"
 
@@ -120,6 +124,12 @@ def step_functions(obj, name=None):
     help="Log AWS Step Functions execution history to AWS CloudWatch "
     "Logs log group.",
 )
+@click.option(
+    "--use-distributed-map/--no-use-distributed-map",
+    is_flag=True,
+    help="Use AWS Step Functions Distributed Map instead of Inline Map for "
+    "defining foreach tasks in Amazon State Language.",
+)
 @click.pass_obj
 def create(
     obj,
@@ -132,6 +142,7 @@ def create(
     max_workers=None,
     workflow_timeout=None,
     log_execution_history=False,
+    use_distributed_map=False,
 ):
     validate_tags(tags)
 
@@ -161,6 +172,7 @@ def create(
         max_workers,
         workflow_timeout,
         obj.is_project,
+        use_distributed_map,
     )
 
     if only_json:
@@ -269,7 +281,15 @@ def resolve_state_machine_name(obj, name):
 
 
 def make_flow(
-    obj, token, name, tags, namespace, max_workers, workflow_timeout, is_project
+    obj,
+    token,
+    name,
+    tags,
+    namespace,
+    max_workers,
+    workflow_timeout,
+    is_project,
+    use_distributed_map,
 ):
     if obj.flow_datastore.TYPE != "s3":
         raise MetaflowException("AWS Step Functions requires --datastore=s3.")
@@ -305,6 +325,7 @@ def make_flow(
         username=get_username(),
         workflow_timeout=workflow_timeout,
         is_project=is_project,
+        use_distributed_map=use_distributed_map,
     )
 
 
@@ -614,6 +635,83 @@ def delete(obj, authorize=None):
         )
 
 
+@step_functions.command(help="Terminate flow execution on Step Functions.")
+@click.option(
+    "--authorize",
+    default=None,
+    type=str,
+    help="Authorize the termination with a production token",
+)
+@click.argument("run-id", required=True, type=str)
+@click.pass_obj
+def terminate(obj, run_id, authorize=None):
+    def _token_instructions(flow_name, prev_user):
+        obj.echo(
+            "There is an existing version of *%s* on AWS Step Functions which was "
+            "deployed by the user *%s*." % (flow_name, prev_user)
+        )
+        obj.echo(
+            "To terminate this flow, you need to use the same production token that they used."
+        )
+        obj.echo(
+            "Please reach out to them to get the token. Once you have it, call "
+            "this command:"
+        )
+        obj.echo(" step-functions terminate --authorize MY_TOKEN RUN_ID", fg="green")
+        obj.echo(
+            'See "Organizing Results" at docs.metaflow.org for more information '
+            "about production tokens."
+        )
+
+    validate_run_id(
+        obj.state_machine_name, obj.token_prefix, authorize, run_id, _token_instructions
+    )
+
+    # Trim prefix from run_id
+    name = run_id[4:]
+    obj.echo(
+        "Terminating run *{run_id}* for {flow_name} ...".format(
+            run_id=run_id, flow_name=obj.flow.name
+        ),
+        bold=True,
+    )
+
+    terminated = StepFunctions.terminate(obj.state_machine_name, name)
+    if terminated:
+        obj.echo("\nRun terminated at %s." % terminated.get("stopDate"))
+
+
+def validate_run_id(
+    state_machine_name, token_prefix, authorize, run_id, instructions_fn=None
+):
+    if not run_id.startswith("sfn-"):
+        raise RunIdMismatch(
+            "Run IDs for flows executed through AWS Step Functions begin with 'sfn-'"
+        )
+
+    name = run_id[4:]
+    execution = StepFunctions.get_execution(state_machine_name, name)
+    if execution is None:
+        raise MetaflowException(
+            "Could not find the execution *%s* (in RUNNING state) for the state machine *%s* on AWS Step Functions"
+            % (name, state_machine_name)
+        )
+
+    _, owner, token, _ = execution
+
+    if authorize is None:
+        authorize = load_token(token_prefix)
+    elif authorize.startswith("production:"):
+        authorize = authorize[11:]
+
+    if owner != get_username() and authorize != token:
+        if instructions_fn:
+            instructions_fn(flow_name=name, prev_user=owner)
+        raise IncorrectProductionToken("Try again with the correct production token.")
+
+    return True
+
+
 def validate_token(name, token_prefix, authorize, instruction_fn=None):
     """
     Validate that the production token matches that of the deployed flow.
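
At the API level, the new `terminate` command is a thin wrapper over the two classmethods added to step_functions.py above. A hedged sketch of that path, assuming AWS credentials are configured; the state machine name and run id below are illustrative:

from metaflow.plugins.aws.step_functions.step_functions import StepFunctions

run_id = "sfn-my-execution"  # illustrative run id; must start with "sfn-"
name = run_id[4:]            # the CLI trims the "sfn-" prefix first

# Look up the RUNNING execution and the metadata baked into the state machine.
execution = StepFunctions.get_execution("MyFlowStateMachine", name)
if execution is not None:
    execution_arn, owner, production_token, _ = execution
    # Stop the execution; the response carries the stop timestamp.
    response = StepFunctions.terminate("MyFlowStateMachine", name)
    print(response.get("stopDate"))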

metaflow/plugins/aws/step_functions/step_functions_client.py

@@ -81,9 +81,14 @@ class StepFunctionsClient(object):
             for execution in page["executions"]
         )
 
-    def terminate_execution(self, state_machine_arn, execution_arn):
-        # TODO
-        pass
+    def terminate_execution(self, execution_arn):
+        try:
+            response = self._client.stop_execution(executionArn=execution_arn)
+            return response
+        except self._client.exceptions.ExecutionDoesNotExist:
+            raise ValueError("The execution ARN %s does not exist." % execution_arn)
+        except Exception as e:
+            raise e
 
     def _default_logging_configuration(self, log_execution_history):
         if log_execution_history:
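
The reworked client method delegates to boto3's `stop_execution` and converts a missing execution into a `ValueError`. A rough usage sketch; the ARN below is a placeholder, not a real resource:

from metaflow.plugins.aws.step_functions.step_functions_client import StepFunctionsClient

client = StepFunctionsClient()
try:
    response = client.terminate_execution(
        "arn:aws:states:us-east-1:111122223333:execution:MyFlow:my-execution"  # placeholder ARN
    )
    print(response.get("stopDate"))
except ValueError as err:
    # Raised when the execution ARN does not exist (see the except clause above).
    print(err)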

metaflow/plugins/aws/step_functions/step_functions_decorator.py

@@ -1,5 +1,5 @@
-import os
 import json
+import os
 import time
 
 from metaflow.decorators import StepDecorator

metaflow/plugins/cards/card_cli.py

@@ -17,6 +17,7 @@ import random
 from contextlib import contextmanager
 from functools import wraps
 from metaflow.exception import MetaflowNamespaceMismatch
+
 from .card_datastore import CardDatastore, NUM_SHORT_HASH_CHARS
 from .exception import (
     CardClassFoundException,
@@ -736,8 +737,7 @@ def create(
 
     if error_stack_trace is not None and mode != "refresh":
         rendered_content = error_card().render(task, stack_trace=error_stack_trace)
-
-    if (
+    elif (
         rendered_info.is_implemented
         and rendered_info.timed_out
         and mode != "refresh"

metaflow/plugins/kubernetes/kubernetes.py

@@ -174,6 +174,7 @@ class Kubernetes(object):
         persistent_volume_claims=None,
         tolerations=None,
         labels=None,
+        shared_memory=None,
     ):
         if env is None:
             env = {}
@@ -213,6 +214,7 @@ class Kubernetes(object):
                 tmpfs_size=tmpfs_size,
                 tmpfs_path=tmpfs_path,
                 persistent_volume_claims=persistent_volume_claims,
+                shared_memory=shared_memory,
             )
             .environment_variable("METAFLOW_CODE_SHA", code_package_sha)
             .environment_variable("METAFLOW_CODE_URL", code_package_url)

metaflow/plugins/kubernetes/kubernetes_cli.py

@@ -107,6 +107,7 @@ def kubernetes():
     type=JSONTypeClass(),
     multiple=False,
 )
+@click.option("--shared-memory", default=None, help="Size of shared memory in MiB")
 @click.pass_context
 def step(
     ctx,
@@ -132,6 +133,7 @@ def step(
     run_time_limit=None,
     persistent_volume_claims=None,
     tolerations=None,
+    shared_memory=None,
     **kwargs
 ):
     def echo(msg, stream="stderr", job_id=None, **kwargs):
@@ -245,6 +247,7 @@ def step(
             env=env,
             persistent_volume_claims=persistent_volume_claims,
             tolerations=tolerations,
+            shared_memory=shared_memory,
         )
     except Exception as e:
         traceback.print_exc(chain=False)

metaflow/plugins/kubernetes/kubernetes_client.py

@@ -6,6 +6,7 @@ from metaflow.exception import MetaflowException
 
 from .kubernetes_job import KubernetesJob
 
+
 CLIENT_REFRESH_INTERVAL_SECONDS = 300
 
 
@@ -32,11 +33,18 @@ class KubernetesClient(object):
     def _refresh_client(self):
         from kubernetes import client, config
 
-        if os.getenv("KUBERNETES_SERVICE_HOST"):
+        if os.getenv("KUBECONFIG"):
+            # There are cases where we're running inside a pod, but can't use
+            # the kubernetes client for that pod's cluster: for example when
+            # running in Bitbucket Cloud or other CI system.
+            # In this scenario, the user can set a KUBECONFIG environment variable
+            # to load the kubeconfig, regardless of whether we're in a pod or not.
+            config.load_kube_config()
+        elif os.getenv("KUBERNETES_SERVICE_HOST"):
             # We are inside a pod, authenticate via ServiceAccount assigned to us
             config.load_incluster_config()
         else:
-            # Use kubeconfig, likely $HOME/.kube/config
+            # Default to using kubeconfig, likely $HOME/.kube/config
             # TODO (savin):
             # 1. Support generating kubeconfig on the fly using boto3
             # 2. Support auth via OIDC - https://docs.aws.amazon.com/eks/latest/userguide/authenticate-oidc-identity-provider.html
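
The client refresh above now resolves credentials in a fixed order: an explicit KUBECONFIG, then the in-cluster service account, then the default kubeconfig. A minimal standalone sketch of the same precedence, assuming the official `kubernetes` Python package is installed; the helper name is illustrative:

import os

from kubernetes import client, config


def make_kube_api():
    if os.getenv("KUBECONFIG"):
        # An explicit kubeconfig wins, even when running inside a pod (e.g. CI runners).
        config.load_kube_config()
    elif os.getenv("KUBERNETES_SERVICE_HOST"):
        # In-cluster: authenticate via the pod's ServiceAccount.
        config.load_incluster_config()
    else:
        # Fall back to the default kubeconfig, typically $HOME/.kube/config.
        config.load_kube_config()
    return client.CoreV1Api()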

metaflow/plugins/kubernetes/kubernetes_decorator.py

@@ -20,6 +20,7 @@ from metaflow.metaflow_config import (
     KUBERNETES_PERSISTENT_VOLUME_CLAIMS,
     KUBERNETES_TOLERATIONS,
     KUBERNETES_SERVICE_ACCOUNT,
+    KUBERNETES_SHARED_MEMORY,
 )
 from metaflow.plugins.resources_decorator import ResourcesDecorator
 from metaflow.plugins.timeout_decorator import get_run_time_limit_for_task
@@ -87,6 +88,8 @@ class KubernetesDecorator(StepDecorator):
     persistent_volume_claims : Dict[str, str], optional, default None
         A map (dictionary) of persistent volumes to be mounted to the pod for this step. The map is from persistent
         volumes to the path to which the volume is to be mounted, e.g., `{'pvc-name': '/path/to/mount/on'}`.
+    shared_memory: int, optional
+        Shared memory size (in MiB) required for this step
     """
 
     name = "kubernetes"
@@ -109,6 +112,7 @@ class KubernetesDecorator(StepDecorator):
         "tmpfs_size": None,
         "tmpfs_path": "/metaflow_temp",
         "persistent_volume_claims": None, # e.g., {"pvc-name": "/mnt/vol", "another-pvc": "/mnt/vol2"}
+        "shared_memory": None,
     }
     package_url = None
     package_sha = None
@@ -194,6 +198,8 @@ class KubernetesDecorator(StepDecorator):
         if not self.attributes["tmpfs_size"]:
             # default tmpfs behavior - https://man7.org/linux/man-pages/man5/tmpfs.5.html
             self.attributes["tmpfs_size"] = int(self.attributes["memory"]) // 2
+        if not self.attributes["shared_memory"]:
+            self.attributes["shared_memory"] = KUBERNETES_SHARED_MEMORY
 
     # Refer https://github.com/Netflix/metaflow/blob/master/docs/lifecycle.png
     def step_init(self, flow, graph, step, decos, environment, flow_datastore, logger):
@@ -289,6 +295,17 @@ class KubernetesDecorator(StepDecorator):
                 )
             )
 
+        if self.attributes["shared_memory"]:
+            if not (
+                isinstance(self.attributes["shared_memory"], int)
+                and int(self.attributes["shared_memory"]) > 0
+            ):
+                raise KubernetesException(
+                    "Invalid shared_memory value: *{size}* for step *{step}* (should be an integer greater than 0)".format(
+                        size=self.attributes["shared_memory"], step=step
+                    )
+                )
+
     def package_init(self, flow, step_name, environment):
         try:
             # Kubernetes is a soft dependency.
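
Taken together, the Kubernetes changes let a step request additional shared memory for its pod via the decorator. A hedged usage sketch; the flow, step, and values are illustrative, and the size is in MiB per the new docstring:

from metaflow import FlowSpec, kubernetes, step


class SharedMemoryFlow(FlowSpec):
    @kubernetes(memory=4096, shared_memory=1024)  # request 1 GiB of shared memory
    @step
    def start(self):
        self.next(self.end)

    @step
    def end(self):
        pass


if __name__ == "__main__":
    SharedMemoryFlow()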