metaflow 2.12.10__py2.py3-none-any.whl → 2.12.12__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. metaflow/client/core.py +6 -6
  2. metaflow/client/filecache.py +16 -3
  3. metaflow/cmd/develop/stub_generator.py +62 -47
  4. metaflow/datastore/content_addressed_store.py +1 -1
  5. metaflow/datastore/task_datastore.py +1 -1
  6. metaflow/decorators.py +2 -4
  7. metaflow/extension_support/__init__.py +3 -3
  8. metaflow/extension_support/plugins.py +3 -3
  9. metaflow/metaflow_config.py +35 -18
  10. metaflow/parameters.py +3 -3
  11. metaflow/plugins/airflow/airflow.py +6 -6
  12. metaflow/plugins/airflow/airflow_utils.py +5 -3
  13. metaflow/plugins/argo/argo_workflows.py +407 -193
  14. metaflow/plugins/argo/argo_workflows_cli.py +17 -4
  15. metaflow/plugins/argo/argo_workflows_decorator.py +6 -13
  16. metaflow/plugins/argo/capture_error.py +70 -0
  17. metaflow/plugins/aws/step_functions/step_functions.py +3 -3
  18. metaflow/plugins/cards/card_modules/basic.py +5 -3
  19. metaflow/plugins/cards/card_modules/convert_to_native_type.py +2 -2
  20. metaflow/plugins/cards/card_modules/renderer_tools.py +1 -0
  21. metaflow/plugins/cards/card_modules/test_cards.py +0 -2
  22. metaflow/plugins/datatools/s3/s3op.py +5 -3
  23. metaflow/plugins/kubernetes/kubernetes.py +1 -0
  24. metaflow/plugins/kubernetes/kubernetes_job.py +10 -8
  25. metaflow/plugins/kubernetes/kubernetes_jobsets.py +15 -14
  26. metaflow/plugins/logs_cli.py +1 -0
  27. metaflow/plugins/pypi/conda_environment.py +1 -3
  28. metaflow/plugins/pypi/pip.py +3 -3
  29. metaflow/plugins/tag_cli.py +3 -3
  30. metaflow/procpoll.py +1 -1
  31. metaflow/runtime.py +1 -0
  32. metaflow/util.py +6 -6
  33. metaflow/version.py +1 -1
  34. {metaflow-2.12.10.dist-info → metaflow-2.12.12.dist-info}/METADATA +2 -2
  35. {metaflow-2.12.10.dist-info → metaflow-2.12.12.dist-info}/RECORD +39 -38
  36. {metaflow-2.12.10.dist-info → metaflow-2.12.12.dist-info}/WHEEL +1 -1
  37. {metaflow-2.12.10.dist-info → metaflow-2.12.12.dist-info}/LICENSE +0 -0
  38. {metaflow-2.12.10.dist-info → metaflow-2.12.12.dist-info}/entry_points.txt +0 -0
  39. {metaflow-2.12.10.dist-info → metaflow-2.12.12.dist-info}/top_level.txt +0 -0
metaflow/plugins/argo/argo_workflows_cli.py CHANGED
@@ -5,11 +5,14 @@ import re
  import sys
  from hashlib import sha1

- from metaflow import Run, JSONType, current, decorators, parameters
- from metaflow.client.core import get_metadata
- from metaflow.exception import MetaflowNotFound
+ from metaflow import JSONType, Run, current, decorators, parameters
  from metaflow._vendor import click
- from metaflow.exception import MetaflowException, MetaflowInternalError
+ from metaflow.client.core import get_metadata
+ from metaflow.exception import (
+     MetaflowException,
+     MetaflowInternalError,
+     MetaflowNotFound,
+ )
  from metaflow.metaflow_config import (
      ARGO_WORKFLOWS_UI_URL,
      KUBERNETES_NAMESPACE,
@@ -181,6 +184,12 @@ def argo_workflows(obj, name=None):
      help="Write the workflow name to the file specified. Used internally for Metaflow's Deployer API.",
      hidden=True,
  )
+ @click.option(
+     "--enable-error-msg-capture/--no-enable-error-msg-capture",
+     default=False,
+     show_default=True,
+     help="Capture stack trace of first failed task in exit hook.",
+ )
  @click.pass_obj
  def create(
      obj,
@@ -200,6 +209,7 @@ def create(
      notify_pager_duty_integration_key=None,
      enable_heartbeat_daemon=True,
      deployer_attribute_file=None,
+     enable_error_msg_capture=False,
  ):
      validate_tags(tags)

@@ -248,6 +258,7 @@ def create(
          notify_slack_webhook_url,
          notify_pager_duty_integration_key,
          enable_heartbeat_daemon,
+         enable_error_msg_capture,
      )

      if only_json:
@@ -421,6 +432,7 @@ def make_flow(
      notify_slack_webhook_url,
      notify_pager_duty_integration_key,
      enable_heartbeat_daemon,
+     enable_error_msg_capture,
  ):
      # TODO: Make this check less specific to Amazon S3 as we introduce
      # support for more cloud object stores.
@@ -484,6 +496,7 @@ def make_flow(
          notify_slack_webhook_url=notify_slack_webhook_url,
          notify_pager_duty_integration_key=notify_pager_duty_integration_key,
          enable_heartbeat_daemon=enable_heartbeat_daemon,
+         enable_error_msg_capture=enable_error_msg_capture,
      )
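In practice this means an Argo Workflows deployment can opt into error capture at deploy time, e.g. `python myflow.py argo-workflows create --enable-error-msg-capture` (hypothetical flow file shown for illustration; the flag defaults to disabled, and `--no-enable-error-msg-capture` turns it back off).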
 
metaflow/plugins/argo/argo_workflows_decorator.py CHANGED
@@ -54,7 +54,7 @@ class ArgoWorkflowsInternalDecorator(StepDecorator):
                      "_", 1
                  )[
                      0
-                 ]  # infer type from env var key
+                 ],  # infer type from env var key
                  # Add more event metadata here in the future
              }
          )
@@ -108,18 +108,12 @@ class ArgoWorkflowsInternalDecorator(StepDecorator):
          # we run pods with a security context. We work around this constraint by
          # mounting an emptyDir volume.
          if graph[step_name].type == "foreach":
-             # A DAGNode is considered a `parallel_step` if it is annotated by the @parallel decorator.
-             # A DAGNode is considered a `parallel_foreach` if it contains a `num_parallel` kwarg provided to the
-             # `next` method of that DAGNode.
-             # At this moment in the code we care if a node is marked as a `parallel_foreach` so that we can pass down the
-             # value of `num_parallel` to the subsequent steps.
-             # For @parallel, the implmentation uses 1 jobset object. That one jobset
-             # object internally creates 'num_parallel' jobs. So, we set foreach_num_splits
-             # to 1 here for @parallel. The parallelism of jobset is handled in
-             # kubernetes_job.py.
              if graph[step_name].parallel_foreach:
+                 # If a node is marked as a `parallel_foreach`, pass down the value of
+                 # `num_parallel` to the subsequent steps.
                  with open("/mnt/out/num_parallel", "w") as f:
                      json.dump(flow._parallel_ubf_iter.num_parallel, f)
+                 # Set splits to 1 since parallelism is handled by JobSet.
                  flow._foreach_num_splits = 1
                  with open("/mnt/out/task_id_entropy", "w") as file:
                      import uuid
@@ -131,10 +125,9 @@ class ArgoWorkflowsInternalDecorator(StepDecorator):
                  with open("/mnt/out/split_cardinality", "w") as file:
                      json.dump(flow._foreach_num_splits, file)

-         # for steps that have a `@parallel` decorator set to them, we will be relying on Jobsets
+         # For steps that have a `@parallel` decorator set to them, we will be relying on Jobsets
          # to run the task. In this case, we cannot set anything in the
-         # `/mnt/out` directory, since such form of output mounts are not available to jobset execution as
-         # argo just treats it like A K8s resource that it throws in the cluster.
+         # `/mnt/out` directory, since such form of output mounts are not available to Jobset executions.
          if not graph[step_name].parallel_step:
              # Unfortunately, we can't always use pod names as task-ids since the pod names
              # are not static across retries. We write the task-id to a file that is read
metaflow/plugins/argo/capture_error.py ADDED
@@ -0,0 +1,70 @@
+ import json
+ import os
+ from datetime import datetime
+
+ ###
+ # Algorithm to determine 1st error:
+ # ignore the failures where message = ""
+ # group the failures via templateName
+ # sort each group by finishedAt
+ # find the group for which the last finishedAt is earliest
+ # if the earliest message is "No more retries left" then
+ #   get the n-1th message from that group
+ # else
+ #   return the last message.
+ ###
+
+
+ def parse_workflow_failures():
+     failures = json.loads(
+         json.loads(os.getenv("METAFLOW_ARGO_WORKFLOW_FAILURES", "[]"), strict=False),
+         strict=False,
+     )
+     return [wf for wf in failures if wf.get("message")]
+
+
+ def group_failures_by_template(failures):
+     groups = {}
+     for failure in failures:
+         groups.setdefault(failure["templateName"], []).append(failure)
+     return groups
+
+
+ def sort_by_finished_at(items):
+     return sorted(
+         items, key=lambda x: datetime.strptime(x["finishedAt"], "%Y-%m-%dT%H:%M:%SZ")
+     )
+
+
+ def find_earliest_last_finished_group(groups):
+     return min(
+         groups,
+         key=lambda k: datetime.strptime(
+             groups[k][-1]["finishedAt"], "%Y-%m-%dT%H:%M:%SZ"
+         ),
+     )
+
+
+ def determine_first_error():
+     failures = parse_workflow_failures()
+     if not failures:
+         return None
+
+     grouped_failures = group_failures_by_template(failures)
+     for group in grouped_failures.values():
+         group.sort(
+             key=lambda x: datetime.strptime(x["finishedAt"], "%Y-%m-%dT%H:%M:%SZ")
+         )
+
+     earliest_group = grouped_failures[
+         find_earliest_last_finished_group(grouped_failures)
+     ]
+
+     if earliest_group[-1]["message"] == "No more retries left":
+         return earliest_group[-2]
+     return earliest_group[-1]
+
+
+ if __name__ == "__main__":
+     first_err = determine_first_error()
+     print(json.dumps(first_err, indent=2))
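For illustration, a minimal self-contained sketch of how this helper behaves. The sample failures and timestamps below are made up, only the fields the script actually reads (templateName, message, finishedAt) are included, and the import assumes metaflow 2.12.12 is installed so the new module resolves:

import json
import os

from metaflow.plugins.argo.capture_error import determine_first_error

# Hand-written sample payload in the shape the script expects.
failures = [
    {"templateName": "end", "message": "pod deleted",
     "finishedAt": "2024-08-01T10:05:00Z"},
    {"templateName": "start", "message": "Error (exit code 1)",
     "finishedAt": "2024-08-01T10:00:00Z"},
    {"templateName": "start", "message": "No more retries left",
     "finishedAt": "2024-08-01T10:01:00Z"},
]

# The workflow failures appear to arrive as a JSON-encoded string (hence the
# double json.loads in parse_workflow_failures), so double-encode here too.
os.environ["METAFLOW_ARGO_WORKFLOW_FAILURES"] = json.dumps(json.dumps(failures))

# The "start" group finishes earliest and its last message is
# "No more retries left", so the failure before it is reported.
print(determine_first_error()["message"])  # -> Error (exit code 1)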
metaflow/plugins/aws/step_functions/step_functions.py CHANGED
@@ -664,9 +664,9 @@ class StepFunctions(object):
          # input to those descendent tasks. We set and propagate the
          # task ids pointing to split_parents through every state.
          if any(self.graph[n].type == "foreach" for n in node.in_funcs):
-             attrs[
-                 "split_parent_task_id_%s.$" % node.split_parents[-1]
-             ] = "$.SplitParentTaskId"
+             attrs["split_parent_task_id_%s.$" % node.split_parents[-1]] = (
+                 "$.SplitParentTaskId"
+             )
          for parent in node.split_parents[:-1]:
              if self.graph[parent].type == "foreach":
                  attrs["split_parent_task_id_%s.$" % parent] = (
metaflow/plugins/cards/card_modules/basic.py CHANGED
@@ -26,9 +26,11 @@ def transform_flow_graph(step_info):
      graph_dict[stepname] = {
          "type": node_to_type(step_info[stepname]["type"]),
          "box_next": step_info[stepname]["type"] not in ("linear", "join"),
-         "box_ends": None
-         if "matching_join" not in step_info[stepname]
-         else step_info[stepname]["matching_join"],
+         "box_ends": (
+             None
+             if "matching_join" not in step_info[stepname]
+             else step_info[stepname]["matching_join"]
+         ),
          "next": step_info[stepname]["next"],
          "doc": step_info[stepname]["doc"],
      }
metaflow/plugins/cards/card_modules/convert_to_native_type.py CHANGED
@@ -314,8 +314,8 @@ class TaskToDict:
          # If there is any form of TypeError or ValueError we set the column value to "Unsupported Type"
          # We also set columns which are have null values to "null" strings
          time_format = "%Y-%m-%dT%H:%M:%S%Z"
-         truncate_long_objects = (
-             lambda x: x.astype("string").str.slice(0, 30) + "..."
+         truncate_long_objects = lambda x: (
+             x.astype("string").str.slice(0, 30) + "..."
              if len(x) > 0 and x.astype("string").str.len().max() > 30
              else x.astype("string")
          )
metaflow/plugins/cards/card_modules/renderer_tools.py CHANGED
@@ -40,6 +40,7 @@ def render_safely(func):
      This is a decorator that can be added to any `MetaflowCardComponent.render`
      The goal is to render subcomponents safely and ensure that they are JSON serializable.
      """
+
      # expects a renderer func
      def ret_func(self, *args, **kwargs):
          return _render_component_safely(self, func, True, *args, **kwargs)
metaflow/plugins/cards/card_modules/test_cards.py CHANGED
@@ -138,7 +138,6 @@ class TestJSONComponent(MetaflowCardComponent):


  class TestRefreshCard(MetaflowCard):
-
      """
      This card takes no components and helps test the `current.card.refresh(data)` interface.
      """
@@ -178,7 +177,6 @@ def _component_values_to_hash(components):


  class TestRefreshComponentCard(MetaflowCard):
-
      """
      This card takes components and helps test the `current.card.components["A"].update()`
      interface
metaflow/plugins/datatools/s3/s3op.py CHANGED
@@ -1119,9 +1119,11 @@ def get(
                      str(url.idx),
                      url_quote(url.prefix).decode(encoding="utf-8"),
                      url_quote(url.url).decode(encoding="utf-8"),
-                     url_quote(url.range).decode(encoding="utf-8")
-                     if url.range
-                     else "<norange>",
+                     (
+                         url_quote(url.range).decode(encoding="utf-8")
+                         if url.range
+                         else "<norange>"
+                     ),
                  ]
              )
              + "\n"
metaflow/plugins/kubernetes/kubernetes.py CHANGED
@@ -299,6 +299,7 @@ class Kubernetes(object):

          jobset.environment_variables_from_selectors(
              {
+                 "METAFLOW_KUBERNETES_NAMESPACE": "metadata.namespace",
                  "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
                  "METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
                  "METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
metaflow/plugins/kubernetes/kubernetes_job.py CHANGED
@@ -99,13 +99,15 @@ class KubernetesJob(object):
                      client.V1Container(
                          command=self._kwargs["command"],
                          termination_message_policy="FallbackToLogsOnError",
-                         ports=[]
-                         if self._kwargs["port"] is None
-                         else [
-                             client.V1ContainerPort(
-                                 container_port=int(self._kwargs["port"])
-                             )
-                         ],
+                         ports=(
+                             []
+                             if self._kwargs["port"] is None
+                             else [
+                                 client.V1ContainerPort(
+                                     container_port=int(self._kwargs["port"])
+                                 )
+                             ]
+                         ),
                          env=[
                              client.V1EnvVar(name=k, value=str(v))
                              for k, v in self._kwargs.get(
@@ -125,6 +127,7 @@ class KubernetesJob(object):
                              ),
                          )
                          for k, v in {
+                             "METAFLOW_KUBERNETES_NAMESPACE": "metadata.namespace",
                              "METAFLOW_KUBERNETES_POD_NAMESPACE": "metadata.namespace",
                              "METAFLOW_KUBERNETES_POD_NAME": "metadata.name",
                              "METAFLOW_KUBERNETES_POD_ID": "metadata.uid",
@@ -257,7 +260,6 @@ class KubernetesJob(object):
                          if self._kwargs["persistent_volume_claims"] is not None
                          else []
                      ),
-                     # TODO (savin): Set termination_message_policy
                  ),
              ),
          )
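For reference, each entry in the field-selector mappings above (including the newly added METAFLOW_KUBERNETES_NAMESPACE) is presumably expanded into a downward-API environment variable. A rough sketch with the standard kubernetes Python client of what one such entry resolves to:

from kubernetes import client

# One downward-API env var: the pod's namespace is injected at runtime.
env_var = client.V1EnvVar(
    name="METAFLOW_KUBERNETES_NAMESPACE",
    value_from=client.V1EnvVarSource(
        field_ref=client.V1ObjectFieldSelector(field_path="metadata.namespace")
    ),
)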
metaflow/plugins/kubernetes/kubernetes_jobsets.py CHANGED
@@ -52,8 +52,6 @@ def k8s_retry(deadline_seconds=60, max_backoff=32):
      return decorator


- CONTROL_JOB_NAME = "control"
-
  JobsetStatus = namedtuple(
      "JobsetStatus",
      [
@@ -587,13 +585,17 @@ class JobSetSpec(object):
                          client.V1Container(
                              command=self._kwargs["command"],
                              termination_message_policy="FallbackToLogsOnError",
-                             ports=[]
-                             if self._kwargs["port"] is None
-                             else [
-                                 client.V1ContainerPort(
-                                     container_port=int(self._kwargs["port"])
-                                 )
-                             ],
+                             ports=(
+                                 []
+                                 if self._kwargs["port"] is None
+                                 else [
+                                     client.V1ContainerPort(
+                                         container_port=int(
+                                             self._kwargs["port"]
+                                         )
+                                     )
+                                 ]
+                             ),
                              env=[
                                  client.V1EnvVar(name=k, value=str(v))
                                  for k, v in self._kwargs.get(
@@ -757,7 +759,6 @@ class JobSetSpec(object):
                          is not None
                          else []
                      ),
-                     # TODO (savin): Set termination_message_policy
                  ),
              ),
          ),
@@ -791,14 +792,14 @@ class KubernetesJobSet(object):

          self._jobset_control_addr = _make_domain_name(
              name,
-             CONTROL_JOB_NAME,
+             "control",
              0,
              0,
              namespace,
          )

          self._control_spec = JobSetSpec(
-             client.get(), name=CONTROL_JOB_NAME, namespace=namespace, **kwargs
+             client.get(), name="control", namespace=namespace, **kwargs
          )
          self._worker_spec = JobSetSpec(
              client.get(), name="worker", namespace=namespace, **kwargs
@@ -919,14 +920,14 @@ class KubernetesArgoJobSet(object):

          self._jobset_control_addr = _make_domain_name(
              name,
-             CONTROL_JOB_NAME,
+             "control",
              0,
              0,
              namespace,
          )

          self._control_spec = JobSetSpec(
-             kubernetes_sdk, name=CONTROL_JOB_NAME, namespace=namespace, **kwargs
+             kubernetes_sdk, name="control", namespace=namespace, **kwargs
          )
          self._worker_spec = JobSetSpec(
              kubernetes_sdk, name="worker", namespace=namespace, **kwargs
metaflow/plugins/logs_cli.py CHANGED
@@ -7,6 +7,7 @@ from ..datastore import TaskDataStoreSet, TaskDataStore

  from ..mflog import mflog, LOG_SOURCES

+
  # main motivation from https://github.com/pallets/click/issues/430
  # in order to support a default command being called for a Click group.
  #
metaflow/plugins/pypi/conda_environment.py CHANGED
@@ -298,9 +298,7 @@ class CondaEnvironment(MetaflowEnvironment):
              lambda f: lambda obj: (
                  {k: f(f)(v) for k, v in sorted(obj.items())}
                  if isinstance(obj, dict)
-                 else sorted([f(f)(e) for e in obj])
-                 if isinstance(obj, list)
-                 else obj
+                 else sorted([f(f)(e) for e in obj]) if isinstance(obj, list) else obj
              )
          )

metaflow/plugins/pypi/pip.py CHANGED
@@ -121,9 +121,9 @@ class Pip(object):
                  res["url"] = "{vcs}+{url}@{commit_id}{subdir_str}".format(
                      **vcs_info,
                      **res,
-                     subdir_str="#subdirectory=%s" % subdirectory
-                     if subdirectory
-                     else ""
+                     subdir_str=(
+                         "#subdirectory=%s" % subdirectory if subdirectory else ""
+                     )
                  )
                  # used to deduplicate the storage location in case wheel does not
                  # build with enough unique identifiers.
metaflow/plugins/tag_cli.py CHANGED
@@ -507,9 +507,9 @@ def tag_list(

      if not group_by_run and not group_by_tag:
          # We list all the runs that match to print them out if needed.
-         system_tags_by_some_grouping[
-             ",".join(pathspecs)
-         ] = system_tags_by_some_grouping.get("_", set())
+         system_tags_by_some_grouping[",".join(pathspecs)] = (
+             system_tags_by_some_grouping.get("_", set())
+         )
          all_tags_by_some_grouping[",".join(pathspecs)] = all_tags_by_some_grouping.get(
              "_", set()
          )
metaflow/procpoll.py CHANGED
@@ -31,7 +31,7 @@ class LinuxProcPoll(ProcPoll):
          self._poll.unregister(fd)

      def poll(self, timeout):
-         for (fd, event) in self._poll.poll(timeout):
+         for fd, event in self._poll.poll(timeout):
              yield ProcPollEvent(
                  fd=fd,
                  can_read=bool(event & select.POLLIN),
metaflow/runtime.py CHANGED
@@ -4,6 +4,7 @@ Local backend
  Execute the flow with a native runtime
  using local / remote processes
  """
+
  from __future__ import print_function
  import os
  import sys
metaflow/util.py CHANGED
@@ -382,9 +382,9 @@ def to_camelcase(obj):
      if isinstance(obj, dict):
          res = obj.__class__()
          for k in obj:
-             res[
-                 re.sub(r"(?!^)_([a-zA-Z])", lambda x: x.group(1).upper(), k)
-             ] = to_camelcase(obj[k])
+             res[re.sub(r"(?!^)_([a-zA-Z])", lambda x: x.group(1).upper(), k)] = (
+                 to_camelcase(obj[k])
+             )
      elif isinstance(obj, (list, set, tuple)):
          res = obj.__class__(to_camelcase(v) for v in obj)
      else:
@@ -401,9 +401,9 @@ def to_pascalcase(obj):
      if isinstance(obj, dict):
          res = obj.__class__()
          for k in obj:
-             res[
-                 re.sub("([a-zA-Z])", lambda x: x.groups()[0].upper(), k, 1)
-             ] = to_pascalcase(obj[k])
+             res[re.sub("([a-zA-Z])", lambda x: x.groups()[0].upper(), k, 1)] = (
+                 to_pascalcase(obj[k])
+             )
      elif isinstance(obj, (list, set, tuple)):
          res = obj.__class__(to_pascalcase(v) for v in obj)
      else:
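Both hunks are formatting-only. For orientation, a tiny standalone sketch of what the two key transforms do (it reuses only the regexes shown above, not the metaflow helpers themselves):

import re

key = "split_parent_task_id"

# to_camelcase rewrites snake_case keys as camelCase.
print(re.sub(r"(?!^)_([a-zA-Z])", lambda x: x.group(1).upper(), key))
# -> splitParentTaskId

# to_pascalcase uppercases only the first letter of each key.
print(re.sub("([a-zA-Z])", lambda x: x.groups()[0].upper(), key, count=1))
# -> Split_parent_task_id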
metaflow/version.py CHANGED
@@ -1 +1 @@
- metaflow_version = "2.12.10"
+ metaflow_version = "2.12.12"
{metaflow-2.12.10.dist-info → metaflow-2.12.12.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: metaflow
- Version: 2.12.10
+ Version: 2.12.12
  Summary: Metaflow: More Data Science, Less Engineering
  Author: Metaflow Developers
  Author-email: help@metaflow.org
@@ -26,7 +26,7 @@ License-File: LICENSE
  Requires-Dist: requests
  Requires-Dist: boto3
  Provides-Extra: stubs
- Requires-Dist: metaflow-stubs ==2.12.10 ; extra == 'stubs'
+ Requires-Dist: metaflow-stubs==2.12.12; extra == "stubs"

  ![Metaflow_Logo_Horizontal_FullColor_Ribbon_Dark_RGB](https://user-images.githubusercontent.com/763451/89453116-96a57e00-d713-11ea-9fa6-82b29d4d6eff.png)