torchx-nightly 2024.1.6__py3-none-any.whl → 2025.12.24__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.

Potentially problematic release: this version of torchx-nightly might be problematic.

Files changed (110)
  1. torchx/__init__.py +2 -0
  2. torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
  3. torchx/apps/serve/serve.py +2 -0
  4. torchx/apps/utils/booth_main.py +2 -0
  5. torchx/apps/utils/copy_main.py +2 -0
  6. torchx/apps/utils/process_monitor.py +2 -0
  7. torchx/cli/__init__.py +2 -0
  8. torchx/cli/argparse_util.py +38 -3
  9. torchx/cli/cmd_base.py +2 -0
  10. torchx/cli/cmd_cancel.py +2 -0
  11. torchx/cli/cmd_configure.py +2 -0
  12. torchx/cli/cmd_delete.py +30 -0
  13. torchx/cli/cmd_describe.py +2 -0
  14. torchx/cli/cmd_list.py +8 -4
  15. torchx/cli/cmd_log.py +6 -24
  16. torchx/cli/cmd_run.py +269 -45
  17. torchx/cli/cmd_runopts.py +2 -0
  18. torchx/cli/cmd_status.py +12 -1
  19. torchx/cli/cmd_tracker.py +3 -1
  20. torchx/cli/colors.py +2 -0
  21. torchx/cli/main.py +4 -0
  22. torchx/components/__init__.py +3 -8
  23. torchx/components/component_test_base.py +2 -0
  24. torchx/components/dist.py +18 -7
  25. torchx/components/integration_tests/component_provider.py +4 -2
  26. torchx/components/integration_tests/integ_tests.py +2 -0
  27. torchx/components/serve.py +2 -0
  28. torchx/components/structured_arg.py +4 -3
  29. torchx/components/utils.py +15 -4
  30. torchx/distributed/__init__.py +2 -4
  31. torchx/examples/apps/datapreproc/datapreproc.py +2 -0
  32. torchx/examples/apps/lightning/data.py +5 -3
  33. torchx/examples/apps/lightning/model.py +7 -6
  34. torchx/examples/apps/lightning/profiler.py +7 -4
  35. torchx/examples/apps/lightning/train.py +11 -2
  36. torchx/examples/torchx_out_of_sync_training.py +11 -0
  37. torchx/notebook.py +2 -0
  38. torchx/runner/__init__.py +2 -0
  39. torchx/runner/api.py +167 -60
  40. torchx/runner/config.py +43 -10
  41. torchx/runner/events/__init__.py +57 -13
  42. torchx/runner/events/api.py +14 -3
  43. torchx/runner/events/handlers.py +2 -0
  44. torchx/runtime/tracking/__init__.py +2 -0
  45. torchx/runtime/tracking/api.py +2 -0
  46. torchx/schedulers/__init__.py +16 -15
  47. torchx/schedulers/api.py +70 -14
  48. torchx/schedulers/aws_batch_scheduler.py +75 -6
  49. torchx/schedulers/aws_sagemaker_scheduler.py +598 -0
  50. torchx/schedulers/devices.py +17 -4
  51. torchx/schedulers/docker_scheduler.py +43 -11
  52. torchx/schedulers/ids.py +29 -23
  53. torchx/schedulers/kubernetes_mcad_scheduler.py +9 -7
  54. torchx/schedulers/kubernetes_scheduler.py +383 -38
  55. torchx/schedulers/local_scheduler.py +100 -27
  56. torchx/schedulers/lsf_scheduler.py +5 -4
  57. torchx/schedulers/slurm_scheduler.py +336 -20
  58. torchx/schedulers/streams.py +2 -0
  59. torchx/specs/__init__.py +89 -12
  60. torchx/specs/api.py +418 -30
  61. torchx/specs/builders.py +176 -38
  62. torchx/specs/file_linter.py +143 -57
  63. torchx/specs/finder.py +68 -28
  64. torchx/specs/named_resources_aws.py +181 -4
  65. torchx/specs/named_resources_generic.py +2 -0
  66. torchx/specs/overlays.py +106 -0
  67. torchx/specs/test/components/__init__.py +2 -0
  68. torchx/specs/test/components/a/__init__.py +2 -0
  69. torchx/specs/test/components/a/b/__init__.py +2 -0
  70. torchx/specs/test/components/a/b/c.py +2 -0
  71. torchx/specs/test/components/c/__init__.py +2 -0
  72. torchx/specs/test/components/c/d.py +2 -0
  73. torchx/tracker/__init__.py +12 -6
  74. torchx/tracker/api.py +15 -18
  75. torchx/tracker/backend/fsspec.py +2 -0
  76. torchx/util/cuda.py +2 -0
  77. torchx/util/datetime.py +2 -0
  78. torchx/util/entrypoints.py +39 -15
  79. torchx/util/io.py +2 -0
  80. torchx/util/log_tee_helpers.py +210 -0
  81. torchx/util/modules.py +65 -0
  82. torchx/util/session.py +42 -0
  83. torchx/util/shlex.py +2 -0
  84. torchx/util/strings.py +3 -1
  85. torchx/util/types.py +90 -29
  86. torchx/version.py +4 -2
  87. torchx/workspace/__init__.py +2 -0
  88. torchx/workspace/api.py +136 -6
  89. torchx/workspace/dir_workspace.py +2 -0
  90. torchx/workspace/docker_workspace.py +30 -2
  91. torchx_nightly-2025.12.24.dist-info/METADATA +167 -0
  92. torchx_nightly-2025.12.24.dist-info/RECORD +113 -0
  93. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/WHEEL +1 -1
  94. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/entry_points.txt +0 -1
  95. torchx/examples/pipelines/__init__.py +0 -0
  96. torchx/examples/pipelines/kfp/__init__.py +0 -0
  97. torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -287
  98. torchx/examples/pipelines/kfp/dist_pipeline.py +0 -69
  99. torchx/examples/pipelines/kfp/intro_pipeline.py +0 -81
  100. torchx/pipelines/kfp/__init__.py +0 -28
  101. torchx/pipelines/kfp/adapter.py +0 -271
  102. torchx/pipelines/kfp/version.py +0 -17
  103. torchx/schedulers/gcp_batch_scheduler.py +0 -487
  104. torchx/schedulers/ray/ray_common.py +0 -22
  105. torchx/schedulers/ray/ray_driver.py +0 -307
  106. torchx/schedulers/ray_scheduler.py +0 -453
  107. torchx_nightly-2024.1.6.dist-info/METADATA +0 -176
  108. torchx_nightly-2024.1.6.dist-info/RECORD +0 -118
  109. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info/licenses}/LICENSE +0 -0
  110. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/top_level.txt +0 -0
torchx/schedulers/aws_sagemaker_scheduler.py (new file)
@@ -0,0 +1,598 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import getpass
+import os
+import re
+import threading
+from collections import OrderedDict as OrdDict
+from dataclasses import asdict, dataclass
+from datetime import datetime
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    OrderedDict,
+    Tuple,
+    TYPE_CHECKING,
+    TypedDict,
+    TypeVar,
+)
+
+import boto3
+import yaml
+
+from sagemaker.pytorch import PyTorch
+from torchx.components.structured_arg import StructuredNameArgument
+from torchx.schedulers.api import (
+    DescribeAppResponse,
+    ListAppResponse,
+    Scheduler,
+    Stream,
+)
+from torchx.schedulers.ids import make_unique
+from torchx.specs.api import AppDef, AppDryRunInfo, AppState, CfgVal, runopts
+from torchx.workspace.docker_workspace import DockerWorkspaceMixin
+
+
+if TYPE_CHECKING:
+    from docker import DockerClient  # pragma: no cover
+
+JOB_STATE: Dict[str, AppState] = {
+    "InProgress": AppState.RUNNING,
+    "Completed": AppState.SUCCEEDED,
+    "Failed": AppState.FAILED,
+    "Stopping": AppState.CANCELLED,
+    "Stopped": AppState.CANCELLED,
+}
+
+
+class AWSSageMakerOpts(TypedDict, total=False):
+    """
+    Opts that can be sourced from .torchxconfig or from user command args
+    """
+
+    role: str
+    instance_count: int
+    instance_type: str
+    keep_alive_period_in_seconds: Optional[int]
+    volume_size: Optional[int]
+    volume_kms_key: Optional[str]
+    max_run: Optional[int]
+    input_mode: Optional[str]
+    output_path: Optional[str]
+    output_kms_key: Optional[str]
+    base_job_name: Optional[str]
+    tags: Optional[Dict[str, str]]
+    subnets: Optional[List[str]]
+    security_group_ids: Optional[List[str]]
+    model_uri: Optional[str]
+    model_channel_name: Optional[str]
+    metric_definitions: Optional[Dict[str, str]]
+    encrypt_inter_container_traffic: Optional[bool]
+    use_spot_instances: Optional[bool]
+    max_wait: Optional[int]
+    checkpoint_s3_uri: Optional[str]
+    checkpoint_local_path: Optional[str]
+    debugger_hook_config: Optional[bool]
+    enable_sagemaker_metrics: Optional[bool]
+    enable_network_isolation: Optional[bool]
+    disable_profiler: Optional[bool]
+    environment: Optional[Dict[str, str]]
+    max_retry_attempts: Optional[int]
+    source_dir: Optional[str]
+    git_config: Optional[Dict[str, str]]
+    hyperparameters: Optional[Dict[str, str]]
+    container_log_level: Optional[int]
+    code_location: Optional[str]
+    dependencies: Optional[List[str]]
+    training_repository_access_mode: Optional[str]
+    training_repository_credentials_provider_arn: Optional[str]
+    disable_output_compression: Optional[bool]
+    enable_infra_check: Optional[bool]
+
+
+@dataclass
+class AWSSageMakerJob:
+    """
+    Defines the key values required to schedule a job. This will be the value
+    of `request` in the AppDryRunInfo object.
+
+    - job_name: defines the job name shown in SageMaker
+    - job_def: defines the job description that will be used to schedule the job on SageMaker
+    - images_to_push: used by torchx to push to image_repo
+    """
+
+    job_name: str
+    job_def: Dict[str, Any]
+    images_to_push: Dict[str, Tuple[str, str]]
+
+    def __str__(self) -> str:
+        return yaml.dump(asdict(self))
+
+    def __repr__(self) -> str:
+        return str(self)
+
+
+T = TypeVar("T")
+
+
+def _thread_local_cache(f: Callable[[], T]) -> Callable[[], T]:
+    # decorator that caches the wrapped factory's result per thread
+    local: threading.local = threading.local()
+    key: str = "value"
+
+    def wrapper() -> T:
+        if key in local.__dict__:
+            return local.__dict__[key]
+        v = f()
+        local.__dict__[key] = v
+        return v
+
+    return wrapper
+
+
+@_thread_local_cache
+def _local_session() -> boto3.session.Session:
+    return boto3.session.Session()
+
+
+def _merge_ordered(
+    src: Optional[Dict[str, str]], extra: Dict[str, str]
+) -> OrderedDict[str, str]:
+    merged = OrdDict(src or {})
+    merged.update(extra)
+    return merged
+
+
+class AWSSageMakerScheduler(
+    DockerWorkspaceMixin,
+    Scheduler[AWSSageMakerOpts],
+):
+    """
+    AWSSageMakerScheduler is a TorchX scheduling interface to AWS SageMaker.
+
+    .. code-block:: bash
+
+        $ torchx run -s aws_sagemaker utils.echo --image alpine:latest --msg hello
+        aws_sagemaker://torchx_user/1234
+        $ torchx status aws_sagemaker://torchx_user/1234
+        ...
+
+    Authentication is loaded from the environment using the ``boto3`` credential
+    handling.
+
+    **Config Options**
+
+    .. runopts::
+        class: torchx.schedulers.aws_sagemaker_scheduler.create_scheduler
+
+    **Compatibility**
+
+    .. compatibility::
+        type: scheduler
+        features:
+            cancel: true
+            logs: false
+            distributed: true
+            describe: |
+                Partial support. SageMakerScheduler will return job and replica
+                status but does not provide the complete original AppSpec.
+            workspaces: true
+            mounts: false
+            elasticity: false
+    """
+
+    def __init__(
+        self,
+        session_name: str,
+        client: Optional[Any] = None,  # pyre-ignore[2]
+        docker_client: Optional["DockerClient"] = None,
+    ) -> None:
+        super().__init__("aws_sagemaker", session_name, docker_client=docker_client)
+        # pyre-fixme[4]: Attribute annotation cannot be `Any`.
+        self.__client = client
+
+    @property
+    # pyre-fixme[3]: Return annotation cannot be `Any`.
+    def _client(self) -> Any:
+        if self.__client:
+            return self.__client
+        return _local_session().client("sagemaker")
+
+    def schedule(self, dryrun_info: AppDryRunInfo[AWSSageMakerJob]) -> str:
+        cfg = dryrun_info._cfg
+        assert cfg is not None, f"{dryrun_info} missing cfg"
+
+        images_to_push = dryrun_info.request.images_to_push
+        self.push_images(images_to_push)
+
+        req = dryrun_info.request
+        pt_estimator = PyTorch(**req.job_def)
+        pt_estimator.fit(wait=False, job_name=req.job_name)
+
+        return req.job_name
+
+    def _submit_dryrun(
+        self, app: AppDef, cfg: AWSSageMakerOpts
+    ) -> AppDryRunInfo[AWSSageMakerJob]:
+        role = app.roles[0]
+        entrypoint, hyperparameters = self._parse_args(role.args)
+
+        # map any local images to the remote image
+        images_to_push = self.dryrun_push_images(app, cast(Mapping[str, CfgVal], cfg))
+        structured_name_kwargs = {}
+        if entrypoint.startswith("-m"):
+            structured_name_kwargs["m"] = entrypoint.replace("-m", "").strip()
+        else:
+            structured_name_kwargs["script"] = entrypoint
+        structured_name = StructuredNameArgument.parse_from(
+            app.name, **structured_name_kwargs
+        )
+        job_name = make_unique(structured_name.run_name)
+
+        role.env["TORCHX_JOB_ID"] = job_name
+
+        # see https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.EstimatorBase
+        job_def = {
+            "entry_point": entrypoint,
+            "image_uri": role.image,
+            "distribution": {"torch_distributed": {"enabled": True}},
+        }
+
+        cfg["environment"] = _merge_ordered(cfg.get("environment"), role.env)
+        # hyperparameters are used for both script/module entrypoint args and the values from .torchxconfig
+        # order matters, adding script args last to handle wildcard parameters
+        cfg["hyperparameters"] = _merge_ordered(
+            cfg.get("hyperparameters"), hyperparameters
+        )
+        # tags are used for AppDef metadata and the values from .torchxconfig
+        cfg["tags"] = [  # pyre-ignore[54]
+            *(cfg.get("tags") or []),
+            *({"Key": k, "Value": v} for k, v in app.metadata.items()),
+        ]
+        # following the principle of least astonishment, default source_dir to the current working directory
+        cfg["source_dir"] = cfg.get("source_dir") or os.getcwd()
+
+        for key in cfg:
+            if key in job_def:
+                raise ValueError(
+                    f"{key} is controlled by aws_sagemaker_scheduler and is set to {job_def[key]}"
+                )
+            value = cfg.get(key)  # type: ignore
+            if value is not None:
+                job_def[key] = value  # type: ignore
+
+        req = AWSSageMakerJob(
+            job_name=job_name,
+            job_def=job_def,
+            images_to_push=images_to_push,
+        )
+        return AppDryRunInfo(req, repr)
+
+    def _parse_args(self, args: List[str]) -> Tuple[str, Dict[str, str]]:
+        if len(args) < 1:
+            raise ValueError("Not enough args to resolve entrypoint")
+        offset = 1
+        if args[0] == "-m":
+            if len(args) < 2:
+                raise ValueError("Missing module name")
+            offset += 1
+        entrypoint = " ".join(args[:offset])
+        hyperparameters = OrdDict()  # the order matters, e.g. for wildcard params
+        while offset < len(args):
+            arg = args[offset]
+            sp_pos = arg.find("=")
+            if sp_pos < 0:
+                if offset + 1 >= len(args):
+                    raise ValueError(
+                        "SageMaker currently only supports named arguments"
+                    )
+                key = arg
+                offset += 1
+                value = args[offset]
+            else:
+                key = arg[:sp_pos]
+                value = arg[sp_pos + 1 :]
+            if not key.startswith("--"):
+                raise ValueError("SageMaker only supports arguments that start with --")
+            offset += 1
+            hyperparameters[key[2:]] = value
+        return entrypoint, hyperparameters
+
+    def _run_opts(self) -> runopts:
+        opts = runopts()
+        opts.add(
+            "role",
+            type_=str,
+            help="an AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs that create Amazon SageMaker endpoints use this role to access training data and model artifacts. After the endpoint is created, the inference code might use the IAM role, if it needs to access an AWS resource.",
+            required=True,
+        )
+        opts.add(
+            "instance_count",
+            type_=int,
+            default=1,
+            help="number of Amazon EC2 instances to use for training. Required if instance_groups is not set.",
+        )
+        opts.add(
+            "instance_type",
+            type_=str,
+            help="type of EC2 instance to use for training, for example, 'ml.c4.xlarge'",
+            required=True,
+        )
+        opts.add(
+            "user",
+            type_=str,
+            default=getpass.getuser(),
+            help="the username to tag the job with. `getpass.getuser()` if not specified.",
+        )
+        opts.add(
+            "keep_alive_period_in_seconds",
+            type_=int,
+            default=None,
+            help="the duration of time in seconds to retain configured resources in a warm pool for subsequent training jobs.",
+        )
+        opts.add(
+            "volume_size",
+            type_=int,
+            default=None,
+            help="size in GB of the storage volume to use for storing input and output data during training (default: 30).",
+        )
+        opts.add(
+            "volume_kms_key",
+            type_=str,
+            default=None,
+            help="KMS key ID for encrypting EBS volume attached to the training instance.",
+        )
+        opts.add(
+            "max_run",
+            type_=int,
+            default=None,
+            help="timeout in seconds for training (default: 24 * 60 * 60).",
+        )
+        opts.add(
+            "input_mode",
+            type_=str,
+            default=None,
+            help="the input mode that the algorithm supports (default: ‘File’).",
+        )
+        opts.add(
+            "output_path",
+            type_=str,
+            default=None,
+            help="S3 location for saving the training result (model artifacts and output files). If not specified, results are stored to a default bucket. If the bucket with the specific name does not exist, the estimator creates the bucket during the fit() method execution.",
+        )
+        opts.add(
+            "output_kms_key",
+            type_=str,
+            default=None,
+            help="KMS key ID for encrypting the training output (default: Your IAM role’s KMS key for Amazon S3).",
+        )
+        opts.add(
+            "base_job_name",
+            type_=str,
+            default=None,
+            help="prefix for training job name when the fit() method launches. If not specified, the estimator generates a default job name based on the training image name and current timestamp.",
+        )
+        opts.add(
+            "tags",
+            type_=List[Dict[str, str]],
+            default=None,
+            help="list of tags for labeling a training job.",
+        )
+        opts.add(
+            "subnets",
+            type_=List[str],
+            default=None,
+            help="list of subnet ids. If not specified training job will be created without VPC config.",
+        )
+        opts.add(
+            "security_group_ids",
+            type_=List[str],
+            default=None,
+            help="list of security group ids. If not specified training job will be created without VPC config.",
+        )
+        opts.add(
+            "model_uri",
+            type_=str,
+            default=None,
+            help="URI where a pre-trained model is stored, either locally or in S3.",
+        )
+        opts.add(
+            "model_channel_name",
+            type_=str,
+            default=None,
+            help="name of the channel where ‘model_uri’ will be downloaded (default: ‘model’).",
+        )
+        opts.add(
+            "metric_definitions",
+            type_=List[Dict[str, str]],
+            default=None,
+            help="list of dictionaries that defines the metric(s) used to evaluate the training jobs. Each dictionary contains two keys: ‘Name’ for the name of the metric, and ‘Regex’ for the regular expression used to extract the metric from the logs.",
+        )
+        opts.add(
+            "encrypt_inter_container_traffic",
+            type_=bool,
+            default=None,
+            help="specifies whether traffic between training containers is encrypted for the training job (default: False).",
+        )
+        opts.add(
+            "use_spot_instances",
+            type_=bool,
+            default=None,
+            help="specifies whether to use SageMaker Managed Spot instances for training. If enabled then the max_wait arg should also be set.",
+        )
+        opts.add(
+            "max_wait",
+            type_=int,
+            default=None,
+            help="timeout in seconds waiting for spot training job.",
+        )
+        opts.add(
+            "checkpoint_s3_uri",
+            type_=str,
+            default=None,
+            help="S3 URI in which to persist checkpoints that the algorithm persists (if any) during training.",
+        )
+        opts.add(
+            "checkpoint_local_path",
+            type_=str,
+            default=None,
+            help="local path that the algorithm writes its checkpoints to.",
+        )
+        opts.add(
+            "debugger_hook_config",
+            type_=bool,
+            default=None,
+            help="configuration for how debugging information is emitted with SageMaker Debugger. If not specified, a default one is created using the estimator’s output_path, unless the region does not support SageMaker Debugger. To disable SageMaker Debugger, set this parameter to False.",
+        )
+        opts.add(
+            "enable_sagemaker_metrics",
+            type_=bool,
+            default=None,
+            help="enable SageMaker Metrics Time Series.",
+        )
+        opts.add(
+            "enable_network_isolation",
+            type_=bool,
+            default=None,
+            help="specifies whether container will run in network isolation mode (default: False).",
+        )
+        opts.add(
+            "disable_profiler",
+            type_=bool,
+            default=None,
+            help="specifies whether Debugger monitoring and profiling will be disabled (default: False).",
+        )
+        opts.add(
+            "environment",
+            type_=Dict[str, str],
+            default=None,
+            help="environment variables to be set for use during training job",
+        )
+        opts.add(
+            "max_retry_attempts",
+            type_=int,
+            default=None,
+            help="number of times to move a job to the STARTING status. You can specify between 1 and 30 attempts.",
+        )
+        opts.add(
+            "source_dir",
+            type_=str,
+            default=None,
+            help="absolute, relative, or S3 URI Path to a directory with any other training source code dependencies aside from the entry point file (default: current working directory)",
+        )
+        opts.add(
+            "git_config",
+            type_=Dict[str, str],
+            default=None,
+            help="git configurations used for cloning files, including repo, branch, commit, 2FA_enabled, username, password, and token.",
+        )
+        opts.add(
+            "hyperparameters",
+            type_=Dict[str, str],
+            default=None,
+            help="dictionary containing the hyperparameters to initialize this estimator with.",
+        )
+        opts.add(
+            "container_log_level",
+            type_=int,
+            default=None,
+            help="log level to use within the container (default: logging.INFO).",
+        )
+        opts.add(
+            "code_location",
+            type_=str,
+            default=None,
+            help="S3 prefix URI where custom code is uploaded.",
+        )
+        opts.add(
+            "dependencies",
+            type_=List[str],
+            default=None,
+            help="list of absolute or relative paths to directories with any additional libraries that should be exported to the container.",
+        )
+        opts.add(
+            "training_repository_access_mode",
+            type_=str,
+            default=None,
+            help="specifies how SageMaker accesses the Docker image that contains the training algorithm.",
+        )
+        opts.add(
+            "training_repository_credentials_provider_arn",
+            type_=str,
+            default=None,
+            help="Amazon Resource Name (ARN) of an AWS Lambda function that provides credentials to authenticate to the private Docker registry where your training image is hosted.",
+        )
+        opts.add(
+            "disable_output_compression",
+            type_=bool,
+            default=None,
+            help="when set to true, Model is uploaded to Amazon S3 without compression after training finishes.",
+        )
+        opts.add(
+            "enable_infra_check",
+            type_=bool,
+            default=None,
+            help="specifies whether it is running Sagemaker built-in infra check jobs.",
+        )
+        return opts
+
+    def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
+        job = self._get_job(app_id)
+        if job is None:
+            return None
+
+        return DescribeAppResponse(
+            app_id=app_id,
+            state=JOB_STATE[job["TrainingJobStatus"]],
+            ui_url=self._job_ui_url(job["TrainingJobArn"]),
+        )
+
+    def list(self) -> List[ListAppResponse]:
+        raise NotImplementedError()
+
+    def _cancel_existing(self, app_id: str) -> None:
+        self._client.stop_training_job(TrainingJobName=app_id)
+
+    def log_iter(
+        self,
+        app_id: str,
+        role_name: str,
+        k: int = 0,
+        regex: Optional[str] = None,
+        since: Optional[datetime] = None,
+        until: Optional[datetime] = None,
+        should_tail: bool = False,
+        streams: Optional[Stream] = None,
+    ) -> Iterable[str]:
+        raise NotImplementedError()
+
+    def _get_job(self, app_id: str) -> Optional[Dict[str, Any]]:
+        job = self._client.describe_training_job(TrainingJobName=app_id)
+        return job
+
+    def _job_ui_url(self, job_arn: str) -> Optional[str]:
+        match = re.match(
+            "arn:aws:sagemaker:(?P<region>[a-z-0-9]+):[0-9]+:training-job/(?P<job_id>[a-z-0-9]+)",
+            job_arn,
+        )
+        if match is None:
+            return None
+        region = match.group("region")
+        job_id = match.group("job_id")
+        return f"https://{region}.console.aws.amazon.com/sagemaker/home?region={region}#jobs/{job_id}"
+
+
+def create_scheduler(session_name: str, **kwargs: object) -> AWSSageMakerScheduler:
+    return AWSSageMakerScheduler(session_name=session_name)
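
For context, here is a minimal sketch of the entrypoint/hyperparameter mapping that `_parse_args` above implements. This is illustrative only and not part of the diff: `_parse_args` is a private helper, and the module, flag names, and values below are made up.

    from torchx.schedulers.aws_sagemaker_scheduler import AWSSageMakerScheduler

    scheduler = AWSSageMakerScheduler(session_name="demo")
    # A "-m module" entrypoint is kept intact; the remaining "--key value" /
    # "--key=value" pairs become SageMaker hyperparameters, preserving order
    # (order matters for wildcard parameters, per the comments in the diff).
    entrypoint, hparams = scheduler._parse_args(
        ["-m", "trainer.main", "--epochs", "3", "--lr=0.01"]
    )
    assert entrypoint == "-m trainer.main"
    assert dict(hparams) == {"epochs": "3", "lr": "0.01"}

At submit time these hyperparameters are merged after any `.torchxconfig` values via `_merge_ordered`, so script args win on key collisions.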
torchx/schedulers/devices.py
@@ -4,26 +4,39 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
+# pyre-strict
 import warnings
+from functools import partial
 from typing import Callable, Dict, List, Mapping
 
 from torchx.specs.api import DeviceMount
+from torchx.specs.named_resources_aws import EFA_DEVICE, NEURON_DEVICE
 
 
-def efa_to_devicemounts(num_devices: int) -> List[DeviceMount]:
+def to_devicemounts(num_devices: int, device_type: str) -> List[DeviceMount]:
     device_mounts = []
     for device_index in range(0, num_devices):
         device_mounts.append(
             DeviceMount(
-                src_path="/dev/infiniband/uverbs" + str(device_index),
-                dst_path="/dev/infiniband/uverbs" + str(device_index),
+                src_path=device_type + str(device_index),
+                dst_path=device_type + str(device_index),
             )
         )
     return device_mounts
 
 
+neuron_to_devicemounts: Callable[[int], List[DeviceMount]] = partial(
+    to_devicemounts, device_type="/dev/neuron"
+)
+efa_to_devicemounts: Callable[[int], List[DeviceMount]] = partial(
+    to_devicemounts, device_type="/dev/infiniband/uverbs"
+)
+
+
 DEVICES: Mapping[str, Callable[[int], List[DeviceMount]]] = {
-    "vpc.amazonaws.com/efa": efa_to_devicemounts,
+    EFA_DEVICE: efa_to_devicemounts,
+    NEURON_DEVICE: neuron_to_devicemounts,
 }
 
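
This hunk replaces the EFA-specific factory with a generic `to_devicemounts` plus `functools.partial` bindings per device family. A quick usage sketch (not part of the diff, assuming the import path shown above):

    from torchx.schedulers.devices import neuron_to_devicemounts

    # Produces one DeviceMount per device index: /dev/neuron0, /dev/neuron1, ...
    for mount in neuron_to_devicemounts(2):
        print(mount.src_path, "->", mount.dst_path)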