torchx-nightly 2024.2.11__py3-none-any.whl → 2025.1.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchx-nightly might be problematic; see the registry's release page for more details.

Files changed (102)
  1. torchx/__init__.py +2 -0
  2. torchx/apps/serve/serve.py +2 -0
  3. torchx/apps/utils/booth_main.py +2 -0
  4. torchx/apps/utils/copy_main.py +2 -0
  5. torchx/apps/utils/process_monitor.py +2 -0
  6. torchx/cli/__init__.py +2 -0
  7. torchx/cli/argparse_util.py +38 -3
  8. torchx/cli/cmd_base.py +2 -0
  9. torchx/cli/cmd_cancel.py +2 -0
  10. torchx/cli/cmd_configure.py +2 -0
  11. torchx/cli/cmd_describe.py +2 -0
  12. torchx/cli/cmd_list.py +2 -0
  13. torchx/cli/cmd_log.py +6 -24
  14. torchx/cli/cmd_run.py +30 -12
  15. torchx/cli/cmd_runopts.py +2 -0
  16. torchx/cli/cmd_status.py +2 -0
  17. torchx/cli/cmd_tracker.py +2 -0
  18. torchx/cli/colors.py +2 -0
  19. torchx/cli/main.py +2 -0
  20. torchx/components/__init__.py +2 -0
  21. torchx/components/component_test_base.py +2 -0
  22. torchx/components/dist.py +2 -0
  23. torchx/components/integration_tests/component_provider.py +2 -0
  24. torchx/components/integration_tests/integ_tests.py +2 -0
  25. torchx/components/serve.py +2 -0
  26. torchx/components/structured_arg.py +2 -0
  27. torchx/components/utils.py +2 -0
  28. torchx/examples/apps/datapreproc/datapreproc.py +2 -0
  29. torchx/examples/apps/lightning/data.py +5 -3
  30. torchx/examples/apps/lightning/model.py +2 -0
  31. torchx/examples/apps/lightning/profiler.py +7 -4
  32. torchx/examples/apps/lightning/train.py +2 -0
  33. torchx/examples/pipelines/kfp/advanced_pipeline.py +2 -0
  34. torchx/examples/pipelines/kfp/dist_pipeline.py +3 -1
  35. torchx/examples/pipelines/kfp/intro_pipeline.py +3 -1
  36. torchx/examples/torchx_out_of_sync_training.py +11 -0
  37. torchx/notebook.py +2 -0
  38. torchx/pipelines/kfp/__init__.py +2 -0
  39. torchx/pipelines/kfp/adapter.py +7 -4
  40. torchx/pipelines/kfp/version.py +2 -0
  41. torchx/runner/__init__.py +2 -0
  42. torchx/runner/api.py +78 -20
  43. torchx/runner/config.py +34 -3
  44. torchx/runner/events/__init__.py +37 -3
  45. torchx/runner/events/api.py +13 -2
  46. torchx/runner/events/handlers.py +2 -0
  47. torchx/runtime/tracking/__init__.py +2 -0
  48. torchx/runtime/tracking/api.py +2 -0
  49. torchx/schedulers/__init__.py +10 -5
  50. torchx/schedulers/api.py +3 -1
  51. torchx/schedulers/aws_batch_scheduler.py +4 -0
  52. torchx/schedulers/aws_sagemaker_scheduler.py +596 -0
  53. torchx/schedulers/devices.py +17 -4
  54. torchx/schedulers/docker_scheduler.py +38 -8
  55. torchx/schedulers/gcp_batch_scheduler.py +8 -9
  56. torchx/schedulers/ids.py +2 -0
  57. torchx/schedulers/kubernetes_mcad_scheduler.py +3 -1
  58. torchx/schedulers/kubernetes_scheduler.py +31 -5
  59. torchx/schedulers/local_scheduler.py +45 -6
  60. torchx/schedulers/lsf_scheduler.py +3 -1
  61. torchx/schedulers/ray/ray_driver.py +7 -7
  62. torchx/schedulers/ray_scheduler.py +1 -1
  63. torchx/schedulers/slurm_scheduler.py +3 -1
  64. torchx/schedulers/streams.py +2 -0
  65. torchx/specs/__init__.py +49 -8
  66. torchx/specs/api.py +87 -5
  67. torchx/specs/builders.py +61 -19
  68. torchx/specs/file_linter.py +8 -2
  69. torchx/specs/finder.py +2 -0
  70. torchx/specs/named_resources_aws.py +109 -2
  71. torchx/specs/named_resources_generic.py +2 -0
  72. torchx/specs/test/components/__init__.py +2 -0
  73. torchx/specs/test/components/a/__init__.py +2 -0
  74. torchx/specs/test/components/a/b/__init__.py +2 -0
  75. torchx/specs/test/components/a/b/c.py +2 -0
  76. torchx/specs/test/components/c/__init__.py +2 -0
  77. torchx/specs/test/components/c/d.py +2 -0
  78. torchx/tracker/__init__.py +2 -0
  79. torchx/tracker/api.py +4 -4
  80. torchx/tracker/backend/fsspec.py +2 -0
  81. torchx/util/cuda.py +2 -0
  82. torchx/util/datetime.py +2 -0
  83. torchx/util/entrypoints.py +6 -2
  84. torchx/util/io.py +2 -0
  85. torchx/util/log_tee_helpers.py +210 -0
  86. torchx/util/modules.py +2 -0
  87. torchx/util/session.py +42 -0
  88. torchx/util/shlex.py +2 -0
  89. torchx/util/strings.py +2 -0
  90. torchx/util/types.py +20 -2
  91. torchx/version.py +3 -1
  92. torchx/workspace/__init__.py +2 -0
  93. torchx/workspace/api.py +34 -1
  94. torchx/workspace/dir_workspace.py +2 -0
  95. torchx/workspace/docker_workspace.py +25 -2
  96. {torchx_nightly-2024.2.11.dist-info → torchx_nightly-2025.1.14.dist-info}/METADATA +55 -48
  97. torchx_nightly-2025.1.14.dist-info/RECORD +123 -0
  98. {torchx_nightly-2024.2.11.dist-info → torchx_nightly-2025.1.14.dist-info}/WHEEL +1 -1
  99. {torchx_nightly-2024.2.11.dist-info → torchx_nightly-2025.1.14.dist-info}/entry_points.txt +0 -1
  100. torchx_nightly-2024.2.11.dist-info/RECORD +0 -119
  101. {torchx_nightly-2024.2.11.dist-info → torchx_nightly-2025.1.14.dist-info}/LICENSE +0 -0
  102. {torchx_nightly-2024.2.11.dist-info → torchx_nightly-2025.1.14.dist-info}/top_level.txt +0 -0
torchx/specs/api.py CHANGED
@@ -5,12 +5,19 @@
5
5
  # This source code is licensed under the BSD-style license found in the
6
6
  # LICENSE file in the root directory of this source tree.
7
7
 
8
+ # pyre-strict
9
+
10
+ import asyncio
8
11
  import copy
12
+ import inspect
9
13
  import json
14
+ import logging as logger
10
15
  import re
16
+ import typing
11
17
  from dataclasses import asdict, dataclass, field
12
18
  from datetime import datetime
13
19
  from enum import Enum
20
+ from json import JSONDecodeError
14
21
  from string import Template
15
22
  from typing import (
16
23
  Any,
@@ -183,11 +190,37 @@ class macros:
183
190
  apply applies the values to a copy the specified role and returns it.
184
191
  """
185
192
 
193
+ # Overrides might contain future values which can't be serialized so taken out for the copy
194
+ overrides = role.overrides
195
+ if len(overrides) > 0:
196
+ logger.warning(
197
+ "Role overrides are not supported for macros. Overrides will not be copied"
198
+ )
199
+ role.overrides = {}
186
200
  role = copy.deepcopy(role)
201
+ role.overrides = overrides
202
+
187
203
  role.args = [self.substitute(arg) for arg in role.args]
188
204
  role.env = {key: self.substitute(arg) for key, arg in role.env.items()}
205
+ role.metadata = self._apply_nested(role.metadata)
206
+
189
207
  return role
190
208
 
209
+ def _apply_nested(self, d: typing.Dict[str, Any]) -> typing.Dict[str, Any]:
210
+ stack = [d]
211
+ while stack:
212
+ current_dict = stack.pop()
213
+ for k, v in current_dict.items():
214
+ if isinstance(v, dict):
215
+ stack.append(v)
216
+ elif isinstance(v, str):
217
+ current_dict[k] = self.substitute(v)
218
+ elif isinstance(v, list):
219
+ for i in range(len(v)):
220
+ if isinstance(v[i], str):
221
+ v[i] = self.substitute(v[i])
222
+ return d
223
+
191
224
  def substitute(self, arg: str) -> str:
192
225
  """
193
226
  substitute applies the values to the template arg.
@@ -216,11 +249,13 @@ class RetryPolicy(str, Enum):
216
249
  application to deal with failed replica departures and
217
250
  replacement replica admittance.
218
251
  2. APPLICATION: Restarts the entire application.
219
-
252
+ 3. ROLE: Restarts the role when any error occurs in that role. This does not
253
+ restart the whole job.
220
254
  """
221
255
 
222
256
  REPLICA = "REPLICA"
223
257
  APPLICATION = "APPLICATION"
258
+ ROLE = "ROLE"
224
259
 
225
260
 
226
261
  class MountType(str, Enum):
@@ -347,6 +382,24 @@ class Role:
347
382
  mounts: List[Union[BindMount, VolumeMount, DeviceMount]] = field(
348
383
  default_factory=list
349
384
  )
385
+ overrides: Dict[str, Any] = field(default_factory=dict)
386
+
387
+ # pyre-ignore
388
+ def __getattribute__(self, attrname: str) -> Any:
389
+ if attrname == "overrides":
390
+ return super().__getattribute__(attrname)
391
+ try:
392
+ ov = super().__getattribute__("overrides")
393
+ except AttributeError:
394
+ ov = {}
395
+ if attrname in ov:
396
+ if inspect.isawaitable(ov[attrname]):
397
+ result = asyncio.get_event_loop().run_until_complete(ov[attrname])
398
+ else:
399
+ result = ov[attrname]()
400
+ setattr(self, attrname, result)
401
+ del ov[attrname]
402
+ return super().__getattribute__(attrname)
350
403
 
351
404
  def pre_proc(
352
405
  self,
@@ -552,7 +605,10 @@ class AppStatus:
552
605
 
553
606
  def _format_replica_status(self, replica_status: ReplicaStatus) -> str:
554
607
  if replica_status.structured_error_msg != NONE:
555
- error_data = json.loads(replica_status.structured_error_msg)
608
+ try:
609
+ error_data = json.loads(replica_status.structured_error_msg)
610
+ except JSONDecodeError:
611
+ return replica_status.structured_error_msg
556
612
  error_message = self._format_error_message(
557
613
  msg=error_data["message"]["message"], header=" error_msg: "
558
614
  )
@@ -637,11 +693,11 @@ class AppStatusError(Exception):
637
693
  self.status = status
638
694
 
639
695
 
640
- # valid run cfg values; only support primitives (str, int, float, bool, List[str])
696
+ # valid run cfg values; only support primitives (str, int, float, bool, List[str], Dict[str, str])
641
697
  # TODO(wilsonhong): python 3.9+ supports list[T] in typing, which can be used directly
642
698
  # in isinstance(). Should replace with that.
643
699
  # see: https://docs.python.org/3/library/stdtypes.html#generic-alias-type
644
- CfgVal = Union[str, int, float, bool, List[str], None]
700
+ CfgVal = Union[str, int, float, bool, List[str], Dict[str, str], None]
645
701
 
646
702
 
647
703
  T = TypeVar("T")
@@ -755,6 +811,10 @@ class runopts:
755
811
  except TypeError:
756
812
  if isinstance(obj, list):
757
813
  return all(isinstance(e, str) for e in obj)
814
+ elif isinstance(obj, dict):
815
+ return all(
816
+ isinstance(k, str) and isinstance(v, str) for k, v in obj.items()
817
+ )
758
818
  else:
759
819
  return False
760
820
 
@@ -863,8 +923,13 @@ class runopts:
863
923
  # lists may be ; or , delimited
864
924
  # also deal with trailing "," by removing empty strings
865
925
  return [v for v in value.replace(";", ",").split(",") if v]
926
+ elif opt_type == Dict[str, str]:
927
+ return {
928
+ s.split(":", 1)[0]: s.split(":", 1)[1]
929
+ for s in value.replace(";", ",").split(",")
930
+ }
866
931
  else:
867
- # pyre-ignore[19]
932
+ # pyre-ignore[19, 6] type won't be dict here as we handled it above
868
933
  return opt_type(value)
869
934
 
870
935
  cfg: Dict[str, CfgVal] = {}
@@ -874,6 +939,23 @@ class runopts:
874
939
  cfg[key] = _cast_to_type(val, runopt_.opt_type)
875
940
  return cfg
876
941
 
942
+ def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]:
943
+ """
944
+ Converts the given dict to a valid cfg for this ``runopts`` object.
945
+ """
946
+ cfg: Dict[str, CfgVal] = {}
947
+ cfg_dict = json.loads(json_repr)
948
+ for key, val in cfg_dict.items():
949
+ runopt_ = self.get(key)
950
+ if runopt_:
951
+ if runopt_.opt_type == List[str]:
952
+ cfg[key] = [str(v) for v in val]
953
+ elif runopt_.opt_type == Dict[str, str]:
954
+ cfg[key] = {str(k): str(v) for k, v in val.items()}
955
+ else:
956
+ cfg[key] = val
957
+ return cfg
958
+
877
959
  def add(
878
960
  self,
879
961
  cfg_key: str,
torchx/specs/builders.py CHANGED
@@ -4,25 +4,25 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  import argparse
8
10
  import inspect
11
+ import os
12
+ from argparse import Namespace
9
13
  from typing import Any, Callable, Dict, List, Mapping, Optional, Union
10
14
 
11
15
  from torchx.specs.api import BindMount, MountType, VolumeMount
12
16
  from torchx.specs.file_linter import get_fn_docstring, TorchXArgumentHelpFormatter
13
- from torchx.util.types import (
14
- decode_from_string,
15
- decode_optional,
16
- get_argparse_param_type,
17
- is_bool,
18
- is_primitive,
19
- )
17
+ from torchx.util.types import decode, decode_optional, get_argparse_param_type, is_bool
20
18
 
21
19
  from .api import AppDef, DeviceMount
22
20
 
23
21
 
24
22
  def _create_args_parser(
25
- cmpnt_fn: Callable[..., AppDef], cmpnt_defaults: Optional[Dict[str, str]] = None
23
+ cmpnt_fn: Callable[..., AppDef],
24
+ cmpnt_defaults: Optional[Dict[str, str]] = None,
25
+ config: Optional[Dict[str, Any]] = None,
26
26
  ) -> argparse.ArgumentParser:
27
27
  parameters = inspect.signature(cmpnt_fn).parameters
28
28
  function_desc, args_desc = get_fn_docstring(cmpnt_fn)
@@ -85,15 +85,55 @@ def _create_args_parser(
85
85
  if len(param_name) == 1:
86
86
  arg_names = [f"-{param_name}"] + arg_names
87
87
  if "default" not in args:
88
- args["required"] = True
88
+ if (config and param_name not in config) or not config:
89
+ args["required"] = True
90
+
89
91
  script_parser.add_argument(*arg_names, **args)
90
92
  return script_parser
91
93
 
92
94
 
95
+ def _merge_config_values_with_args(
96
+ parsed_args: argparse.Namespace, config: Dict[str, Any]
97
+ ) -> None:
98
+ for key, val in config.items():
99
+ if key in parsed_args:
100
+ setattr(parsed_args, key, val)
101
+
102
+
103
+ def parse_args(
104
+ cmpnt_fn: Callable[..., AppDef],
105
+ cmpnt_args: List[str],
106
+ cmpnt_defaults: Optional[Dict[str, Any]] = None,
107
+ config: Optional[Dict[str, Any]] = None,
108
+ ) -> Namespace:
109
+ """
110
+ Parse passed arguments, defaults, and config values into a namespace for
111
+ a component function.
112
+
113
+ Args:
114
+ cmpnt_fn: Component function
115
+ cmpnt_args: Function args
116
+ cmpnt_defaults: Additional default values for parameters of ``app_fn``
117
+ (overrides the defaults set on the fn declaration)
118
+ config: Optional dict containing additional configuration for the component from a passed config file
119
+
120
+ Returns:
121
+ A Namespace object with the args, defaults, and config values incorporated.
122
+ """
123
+
124
+ script_parser = _create_args_parser(cmpnt_fn, cmpnt_defaults, config)
125
+ parsed_args = script_parser.parse_args(cmpnt_args)
126
+ if config:
127
+ _merge_config_values_with_args(parsed_args, config)
128
+
129
+ return parsed_args
130
+
131
+
93
132
  def materialize_appdef(
94
133
  cmpnt_fn: Callable[..., AppDef],
95
134
  cmpnt_args: List[str],
96
- cmpnt_defaults: Optional[Dict[str, str]] = None,
135
+ cmpnt_defaults: Optional[Dict[str, Any]] = None,
136
+ config: Optional[Dict[str, Any]] = None,
97
137
  ) -> AppDef:
98
138
  """
99
139
  Creates an application by running user defined ``app_fn``.
@@ -118,26 +158,23 @@ def materialize_appdef(
118
158
  cmpnt_args: Function args
119
159
  cmpnt_defaults: Additional default values for parameters of ``app_fn``
120
160
  (overrides the defaults set on the fn declaration)
161
+ config: Optional dict containing additional configuration for the component from a passed config file
121
162
  Returns:
122
163
  An application spec
123
164
  """
124
165
 
125
- script_parser = _create_args_parser(cmpnt_fn, cmpnt_defaults)
126
- parsed_args = script_parser.parse_args(cmpnt_args)
127
-
128
166
  function_args = []
129
167
  var_arg = []
130
168
  kwargs = {}
131
169
 
170
+ parsed_args = parse_args(cmpnt_fn, cmpnt_args, cmpnt_defaults, config)
171
+
132
172
  parameters = inspect.signature(cmpnt_fn).parameters
133
173
  for param_name, parameter in parameters.items():
134
174
  arg_value = getattr(parsed_args, param_name)
135
175
  parameter_type = parameter.annotation
136
176
  parameter_type = decode_optional(parameter_type)
137
- if is_bool(parameter_type):
138
- arg_value = arg_value and arg_value.lower() == "true"
139
- elif not is_primitive(parameter_type):
140
- arg_value = decode_from_string(arg_value, parameter_type)
177
+ arg_value = decode(arg_value, parameter_type)
141
178
  if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
142
179
  var_arg = arg_value
143
180
  elif parameter.kind == inspect.Parameter.KEYWORD_ONLY:
@@ -149,7 +186,9 @@ def materialize_appdef(
149
186
  if len(var_arg) > 0 and var_arg[0] == "--":
150
187
  var_arg = var_arg[1:]
151
188
 
152
- return cmpnt_fn(*function_args, *var_arg, **kwargs)
189
+ appdef = cmpnt_fn(*function_args, *var_arg, **kwargs)
190
+
191
+ return appdef
153
192
 
154
193
 
155
194
  def make_app_handle(scheduler_backend: str, session_name: str, app_id: str) -> str:
@@ -205,9 +244,12 @@ def parse_mounts(opts: List[str]) -> List[Union[BindMount, VolumeMount, DeviceMo
205
244
  for opts in mount_opts:
206
245
  typ = opts.get("type")
207
246
  if typ == MountType.BIND:
247
+ src_path = opts["src"]
248
+ if src_path.startswith("~"):
249
+ src_path = os.path.expanduser(src_path)
208
250
  mounts.append(
209
251
  BindMount(
210
- src_path=opts["src"],
252
+ src_path=src_path,
211
253
  dst_path=opts["dst"],
212
254
  read_only="readonly" in opts,
213
255
  )
@@ -5,6 +5,8 @@
5
5
  # This source code is licensed under the BSD-style license found in the
6
6
  # LICENSE file in the root directory of this source tree.
7
7
 
8
+ # pyre-strict
9
+
8
10
  import abc
9
11
  import argparse
10
12
  import ast
@@ -29,7 +31,11 @@ def _get_default_arguments_descriptions(fn: Callable[..., object]) -> Dict[str,
29
31
  return args_decs
30
32
 
31
33
 
32
- class TorchXArgumentHelpFormatter(argparse.HelpFormatter):
34
+ class TorchXArgumentHelpFormatter(
35
+ argparse.RawDescriptionHelpFormatter,
36
+ argparse.ArgumentDefaultsHelpFormatter,
37
+ argparse.MetavarTypeHelpFormatter,
38
+ ):
33
39
  """Help message formatter which adds default values and required to argument help.
34
40
 
35
41
  If the argument is required, the class appends `(required)` at the end of the help message.
@@ -79,7 +85,7 @@ to your component (see: https://pytorch.org/torchx/latest/component_best_practic
79
85
  args_description[param.arg_name] = param.description
80
86
  short_func_description = docstring.short_description or default_fn_desc
81
87
  if docstring.long_description:
82
- short_func_description += " ..."
88
+ short_func_description += "\n" + docstring.long_description
83
89
  return (short_func_description or default_fn_desc, args_description)
84
90
 
85
91
 
torchx/specs/finder.py CHANGED
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  import abc
8
10
  import importlib
9
11
  import inspect
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  r"""
8
10
  `torchx.specs.named_resources_aws` contains resource definitions that represent corresponding AWS instance types
9
11
  taken from https://aws.amazon.com/ec2/instance-types/. The resources are exposed
@@ -35,6 +37,7 @@ from typing import Callable, Mapping
35
37
  from torchx.specs.api import Resource
36
38
 
37
39
  EFA_DEVICE = "vpc.amazonaws.com/efa"
40
+ NEURON_DEVICE = "aws.amazon.com/neurondevice"
38
41
 
39
42
  # ecs and ec2 have memtax and currently AWS Batch uses hard memory limits
40
43
  # so we have to account for mem tax when registering these resources for AWS
@@ -107,6 +110,16 @@ def aws_p4de_24xlarge() -> Resource:
107
110
  )
108
111
 
109
112
 
113
+ def aws_p5_48xlarge() -> Resource:
114
+ return Resource(
115
+ cpu=192,
116
+ gpu=8,
117
+ memMB=2048 * GiB,
118
+ capabilities={K8S_ITYPE: "p5.48xlarge"},
119
+ devices={EFA_DEVICE: 32},
120
+ )
121
+
122
+
110
123
  def aws_t3_medium() -> Resource:
111
124
  return Resource(cpu=2, gpu=0, memMB=4 * GiB, capabilities={K8S_ITYPE: "t3.medium"})
112
125
 
@@ -117,6 +130,12 @@ def aws_m5_2xlarge() -> Resource:
117
130
  )
118
131
 
119
132
 
133
+ def aws_c5_18xlarge() -> Resource:
134
+ return Resource(
135
+ cpu=72, gpu=0, memMB=144 * GiB, capabilities={K8S_ITYPE: "c5.18xlarge"}
136
+ )
137
+
138
+
120
139
  def aws_g4dn_xlarge() -> Resource:
121
140
  return Resource(
122
141
  cpu=4, gpu=1, memMB=16 * GiB, capabilities={K8S_ITYPE: "g4dn.xlarge"}
@@ -241,9 +260,87 @@ def aws_g5_48xlarge() -> Resource:
241
260
  )
242
261
 
243
262
 
263
+ def aws_g6e_xlarge() -> Resource:
264
+ return Resource(
265
+ cpu=4,
266
+ gpu=1,
267
+ memMB=32 * GiB,
268
+ capabilities={K8S_ITYPE: "g6e.xlarge"},
269
+ )
270
+
271
+
272
+ def aws_g6e_2xlarge() -> Resource:
273
+ return Resource(
274
+ cpu=8,
275
+ gpu=1,
276
+ memMB=64 * GiB,
277
+ capabilities={K8S_ITYPE: "g6e.2xlarge"},
278
+ )
279
+
280
+
281
+ def aws_g6e_4xlarge() -> Resource:
282
+ return Resource(
283
+ cpu=16,
284
+ gpu=1,
285
+ memMB=128 * GiB,
286
+ capabilities={K8S_ITYPE: "g6e.4xlarge"},
287
+ )
288
+
289
+
290
+ def aws_g6e_8xlarge() -> Resource:
291
+ return Resource(
292
+ cpu=32,
293
+ gpu=1,
294
+ memMB=256 * GiB,
295
+ capabilities={K8S_ITYPE: "g6e.8xlarge"},
296
+ )
297
+
298
+
299
+ def aws_g6e_16xlarge() -> Resource:
300
+ return Resource(
301
+ cpu=64,
302
+ gpu=1,
303
+ memMB=512 * GiB,
304
+ capabilities={K8S_ITYPE: "g6e.16xlarge"},
305
+ )
306
+
307
+
308
+ def aws_g6e_12xlarge() -> Resource:
309
+ return Resource(
310
+ cpu=48,
311
+ gpu=4,
312
+ memMB=384 * GiB,
313
+ capabilities={K8S_ITYPE: "g6e.12xlarge"},
314
+ )
315
+
316
+
317
+ def aws_g6e_24xlarge() -> Resource:
318
+ return Resource(
319
+ cpu=96,
320
+ gpu=4,
321
+ memMB=768 * GiB,
322
+ capabilities={K8S_ITYPE: "g6e.24xlarge"},
323
+ devices={EFA_DEVICE: 2},
324
+ )
325
+
326
+
327
+ def aws_g6e_48xlarge() -> Resource:
328
+ return Resource(
329
+ cpu=192,
330
+ gpu=8,
331
+ memMB=1536 * GiB,
332
+ capabilities={K8S_ITYPE: "g6e.48xlarge"},
333
+ devices={EFA_DEVICE: 4},
334
+ )
335
+
336
+
244
337
  def aws_trn1_2xlarge() -> Resource:
245
338
  return Resource(
246
- cpu=8, gpu=0, memMB=32 * GiB, capabilities={K8S_ITYPE: "trn1.2xlarge"}
339
+ cpu=8,
340
+ gpu=0,
341
+ memMB=32 * GiB,
342
+ capabilities={K8S_ITYPE: "trn1.2xlarge"},
343
+ devices={NEURON_DEVICE: 1},
247
344
  )
248
345
 
249
346
 
@@ -253,19 +350,21 @@ def aws_trn1_32xlarge() -> Resource:
253
350
  gpu=0,
254
351
  memMB=512 * GiB,
255
352
  capabilities={K8S_ITYPE: "trn1.32xlarge"},
256
- devices={EFA_DEVICE: 8},
353
+ devices={EFA_DEVICE: 8, NEURON_DEVICE: 16},
257
354
  )
258
355
 
259
356
 
260
357
  NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = {
261
358
  "aws_t3.medium": aws_t3_medium,
262
359
  "aws_m5.2xlarge": aws_m5_2xlarge,
360
+ "aws_c5.18xlarge": aws_c5_18xlarge,
263
361
  "aws_p3.2xlarge": aws_p3_2xlarge,
264
362
  "aws_p3.8xlarge": aws_p3_8xlarge,
265
363
  "aws_p3.16xlarge": aws_p3_16xlarge,
266
364
  "aws_p3dn.24xlarge": aws_p3dn_24xlarge,
267
365
  "aws_p4d.24xlarge": aws_p4d_24xlarge,
268
366
  "aws_p4de.24xlarge": aws_p4de_24xlarge,
367
+ "aws_p5.48xlarge": aws_p5_48xlarge,
269
368
  "aws_g4dn.xlarge": aws_g4dn_xlarge,
270
369
  "aws_g4dn.2xlarge": aws_g4dn_2xlarge,
271
370
  "aws_g4dn.4xlarge": aws_g4dn_4xlarge,
@@ -281,6 +380,14 @@ NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = {
281
380
  "aws_g5.12xlarge": aws_g5_12xlarge,
282
381
  "aws_g5.24xlarge": aws_g5_24xlarge,
283
382
  "aws_g5.48xlarge": aws_g5_48xlarge,
383
+ "aws_g6e.xlarge": aws_g6e_xlarge,
384
+ "aws_g6e.2xlarge": aws_g6e_2xlarge,
385
+ "aws_g6e.4xlarge": aws_g6e_4xlarge,
386
+ "aws_g6e.8xlarge": aws_g6e_8xlarge,
387
+ "aws_g6e.16xlarge": aws_g6e_16xlarge,
388
+ "aws_g6e.12xlarge": aws_g6e_12xlarge,
389
+ "aws_g6e.24xlarge": aws_g6e_24xlarge,
390
+ "aws_g6e.48xlarge": aws_g6e_48xlarge,
284
391
  "aws_trn1.2xlarge": aws_trn1_2xlarge,
285
392
  "aws_trn1.32xlarge": aws_trn1_32xlarge,
286
393
  }
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  """
8
10
  Defines generic named resources that are not specific to any cloud provider's
9
11
  instance types. These generic named resources are meant to be used as
@@ -3,3 +3,5 @@
3
3
  #
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
@@ -3,6 +3,8 @@
3
3
  #
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
6
8
  import torchx
7
9
  from torchx import specs
8
10
 
@@ -3,3 +3,5 @@
3
3
  #
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
@@ -3,6 +3,8 @@
3
3
  #
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
6
8
  import torchx
7
9
  from torchx import specs
8
10
 
@@ -4,3 +4,5 @@
4
4
  #
5
5
  # This source code is licensed under the BSD-style license found in the
6
6
  # LICENSE file in the root directory of this source tree.
7
+
8
+ # pyre-strict
@@ -3,6 +3,8 @@
3
3
  #
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
6
8
  import torchx
7
9
  from torchx import specs
8
10
 
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  """
8
10
  .. note:: PROTOTYPE, USE AT YOUR OWN RISK, APIs SUBJECT TO CHANGE
9
11
 
torchx/tracker/api.py CHANGED
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  from __future__ import annotations
8
10
 
9
11
  import logging
@@ -67,8 +69,7 @@ class AppRunTrackableSource:
67
69
  artifact_name: Optional[str]
68
70
 
69
71
 
70
- class Lineage:
71
- ...
72
+ class Lineage: ...
72
73
 
73
74
 
74
75
  class TrackerBase(ABC):
@@ -332,5 +333,4 @@ class AppRun:
332
333
 
333
334
  return model_run_sources
334
335
 
335
- def children(self) -> Iterable[AppRun]:
336
- ...
336
+ def children(self) -> Iterable[AppRun]: ...
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  from __future__ import annotations
8
10
 
9
11
  import json
torchx/util/cuda.py CHANGED
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  import torch
8
10
 
9
11
 
torchx/util/datetime.py CHANGED
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  from datetime import datetime, timedelta
8
10
 
9
11