torchx-nightly 2024.2.12__py3-none-any.whl → 2025.1.15__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only.


Files changed (102)
  1. torchx/__init__.py +2 -0
  2. torchx/apps/serve/serve.py +2 -0
  3. torchx/apps/utils/booth_main.py +2 -0
  4. torchx/apps/utils/copy_main.py +2 -0
  5. torchx/apps/utils/process_monitor.py +2 -0
  6. torchx/cli/__init__.py +2 -0
  7. torchx/cli/argparse_util.py +38 -3
  8. torchx/cli/cmd_base.py +2 -0
  9. torchx/cli/cmd_cancel.py +2 -0
  10. torchx/cli/cmd_configure.py +2 -0
  11. torchx/cli/cmd_describe.py +2 -0
  12. torchx/cli/cmd_list.py +2 -0
  13. torchx/cli/cmd_log.py +6 -24
  14. torchx/cli/cmd_run.py +30 -12
  15. torchx/cli/cmd_runopts.py +2 -0
  16. torchx/cli/cmd_status.py +2 -0
  17. torchx/cli/cmd_tracker.py +2 -0
  18. torchx/cli/colors.py +2 -0
  19. torchx/cli/main.py +2 -0
  20. torchx/components/__init__.py +2 -0
  21. torchx/components/component_test_base.py +2 -0
  22. torchx/components/dist.py +2 -0
  23. torchx/components/integration_tests/component_provider.py +2 -0
  24. torchx/components/integration_tests/integ_tests.py +2 -0
  25. torchx/components/serve.py +2 -0
  26. torchx/components/structured_arg.py +2 -0
  27. torchx/components/utils.py +2 -0
  28. torchx/examples/apps/datapreproc/datapreproc.py +2 -0
  29. torchx/examples/apps/lightning/data.py +5 -3
  30. torchx/examples/apps/lightning/model.py +2 -0
  31. torchx/examples/apps/lightning/profiler.py +7 -4
  32. torchx/examples/apps/lightning/train.py +2 -0
  33. torchx/examples/pipelines/kfp/advanced_pipeline.py +2 -0
  34. torchx/examples/pipelines/kfp/dist_pipeline.py +3 -1
  35. torchx/examples/pipelines/kfp/intro_pipeline.py +3 -1
  36. torchx/examples/torchx_out_of_sync_training.py +11 -0
  37. torchx/notebook.py +2 -0
  38. torchx/pipelines/kfp/__init__.py +2 -0
  39. torchx/pipelines/kfp/adapter.py +7 -4
  40. torchx/pipelines/kfp/version.py +2 -0
  41. torchx/runner/__init__.py +2 -0
  42. torchx/runner/api.py +78 -20
  43. torchx/runner/config.py +34 -3
  44. torchx/runner/events/__init__.py +37 -3
  45. torchx/runner/events/api.py +13 -2
  46. torchx/runner/events/handlers.py +2 -0
  47. torchx/runtime/tracking/__init__.py +2 -0
  48. torchx/runtime/tracking/api.py +2 -0
  49. torchx/schedulers/__init__.py +10 -5
  50. torchx/schedulers/api.py +3 -1
  51. torchx/schedulers/aws_batch_scheduler.py +4 -0
  52. torchx/schedulers/aws_sagemaker_scheduler.py +596 -0
  53. torchx/schedulers/devices.py +17 -4
  54. torchx/schedulers/docker_scheduler.py +38 -8
  55. torchx/schedulers/gcp_batch_scheduler.py +8 -9
  56. torchx/schedulers/ids.py +2 -0
  57. torchx/schedulers/kubernetes_mcad_scheduler.py +3 -1
  58. torchx/schedulers/kubernetes_scheduler.py +31 -5
  59. torchx/schedulers/local_scheduler.py +45 -6
  60. torchx/schedulers/lsf_scheduler.py +3 -1
  61. torchx/schedulers/ray/ray_driver.py +7 -7
  62. torchx/schedulers/ray_scheduler.py +1 -1
  63. torchx/schedulers/slurm_scheduler.py +3 -1
  64. torchx/schedulers/streams.py +2 -0
  65. torchx/specs/__init__.py +49 -8
  66. torchx/specs/api.py +87 -5
  67. torchx/specs/builders.py +61 -19
  68. torchx/specs/file_linter.py +8 -2
  69. torchx/specs/finder.py +2 -0
  70. torchx/specs/named_resources_aws.py +109 -2
  71. torchx/specs/named_resources_generic.py +2 -0
  72. torchx/specs/test/components/__init__.py +2 -0
  73. torchx/specs/test/components/a/__init__.py +2 -0
  74. torchx/specs/test/components/a/b/__init__.py +2 -0
  75. torchx/specs/test/components/a/b/c.py +2 -0
  76. torchx/specs/test/components/c/__init__.py +2 -0
  77. torchx/specs/test/components/c/d.py +2 -0
  78. torchx/tracker/__init__.py +2 -0
  79. torchx/tracker/api.py +4 -4
  80. torchx/tracker/backend/fsspec.py +2 -0
  81. torchx/util/cuda.py +2 -0
  82. torchx/util/datetime.py +2 -0
  83. torchx/util/entrypoints.py +6 -2
  84. torchx/util/io.py +2 -0
  85. torchx/util/log_tee_helpers.py +210 -0
  86. torchx/util/modules.py +2 -0
  87. torchx/util/session.py +42 -0
  88. torchx/util/shlex.py +2 -0
  89. torchx/util/strings.py +2 -0
  90. torchx/util/types.py +20 -2
  91. torchx/version.py +3 -1
  92. torchx/workspace/__init__.py +2 -0
  93. torchx/workspace/api.py +34 -1
  94. torchx/workspace/dir_workspace.py +2 -0
  95. torchx/workspace/docker_workspace.py +25 -2
  96. {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.15.dist-info}/METADATA +55 -48
  97. torchx_nightly-2025.1.15.dist-info/RECORD +123 -0
  98. {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.15.dist-info}/WHEEL +1 -1
  99. {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.15.dist-info}/entry_points.txt +0 -1
  100. torchx_nightly-2024.2.12.dist-info/RECORD +0 -119
  101. {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.15.dist-info}/LICENSE +0 -0
  102. {torchx_nightly-2024.2.12.dist-info → torchx_nightly-2025.1.15.dist-info}/top_level.txt +0 -0
torchx/schedulers/aws_sagemaker_scheduler.py (new file)
@@ -0,0 +1,596 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import getpass
+import os
+import re
+import threading
+from collections import OrderedDict as OrdDict
+from dataclasses import asdict, dataclass
+from datetime import datetime
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    OrderedDict,
+    Tuple,
+    TYPE_CHECKING,
+    TypeVar,
+)
+
+import boto3
+import yaml
+
+from sagemaker.pytorch import PyTorch
+from torchx.components.structured_arg import StructuredNameArgument
+from torchx.schedulers.api import (
+    AppDryRunInfo,
+    DescribeAppResponse,
+    ListAppResponse,
+    Scheduler,
+    Stream,
+)
+from torchx.schedulers.ids import make_unique
+from torchx.specs.api import AppDef, AppState, CfgVal, runopts
+from torchx.workspace.docker_workspace import DockerWorkspaceMixin
+from typing_extensions import TypedDict
+
+
+if TYPE_CHECKING:
+    from docker import DockerClient  # pragma: no cover
+
+JOB_STATE: Dict[str, AppState] = {
+    "InProgress": AppState.RUNNING,
+    "Completed": AppState.SUCCEEDED,
+    "Failed": AppState.FAILED,
+    "Stopping": AppState.CANCELLED,
+    "Stopped": AppState.CANCELLED,
+}
+
+
+class AWSSageMakerOpts(TypedDict, total=False):
+    """
+    Opts that can be read from .torchxconfig or from user command args.
+    """
+
+    role: str
+    instance_count: int
+    instance_type: str
+    keep_alive_period_in_seconds: Optional[int]
+    volume_size: Optional[int]
+    volume_kms_key: Optional[str]
+    max_run: Optional[int]
+    input_mode: Optional[str]
+    output_path: Optional[str]
+    output_kms_key: Optional[str]
+    base_job_name: Optional[str]
+    tags: Optional[Dict[str, str]]
+    subnets: Optional[List[str]]
+    security_group_ids: Optional[List[str]]
+    model_uri: Optional[str]
+    model_channel_name: Optional[str]
+    metric_definitions: Optional[Dict[str, str]]
+    encrypt_inter_container_traffic: Optional[bool]
+    use_spot_instances: Optional[bool]
+    max_wait: Optional[int]
+    checkpoint_s3_uri: Optional[str]
+    checkpoint_local_path: Optional[str]
+    debugger_hook_config: Optional[bool]
+    enable_sagemaker_metrics: Optional[bool]
+    enable_network_isolation: Optional[bool]
+    disable_profiler: Optional[bool]
+    environment: Optional[Dict[str, str]]
+    max_retry_attempts: Optional[int]
+    source_dir: Optional[str]
+    git_config: Optional[Dict[str, str]]
+    hyperparameters: Optional[Dict[str, str]]
+    container_log_level: Optional[int]
+    code_location: Optional[str]
+    dependencies: Optional[List[str]]
+    training_repository_access_mode: Optional[str]
+    training_repository_credentials_provider_arn: Optional[str]
+    disable_output_compression: Optional[bool]
+    enable_infra_check: Optional[bool]
+
+
+@dataclass
+class AWSSageMakerJob:
+    """
+    Defines the key values required to schedule a job. This will be the value
+    of `request` in the AppDryRunInfo object.
+
+    - job_name: defines the job name shown in SageMaker
+    - job_def: defines the job description that will be used to schedule the job on SageMaker
+    - images_to_push: used by torchx to push to image_repo
+    """
+
+    job_name: str
+    job_def: Dict[str, Any]
+    images_to_push: Dict[str, Tuple[str, str]]
+
+    def __str__(self) -> str:
+        return yaml.dump(asdict(self))
+
+    def __repr__(self) -> str:
+        return str(self)
+
+
+T = TypeVar("T")
+
+
+def _thread_local_cache(f: Callable[[], T]) -> Callable[[], T]:
+    # decorator function for keeping object in cache
+    local: threading.local = threading.local()
+    key: str = "value"
+
+    def wrapper() -> T:
+        if key in local.__dict__:
+            return local.__dict__[key]
+        v = f()
+        local.__dict__[key] = v
+        return v
+
+    return wrapper
+
+
+@_thread_local_cache
+def _local_session() -> boto3.session.Session:
+    return boto3.session.Session()
+
+
+def _merge_ordered(
+    src: Optional[Dict[str, str]], extra: Dict[str, str]
+) -> OrderedDict[str, str]:
+    merged = OrdDict(src or {})
+    merged.update(extra)
+    return merged
+
+
+class AWSSageMakerScheduler(DockerWorkspaceMixin, Scheduler[AWSSageMakerOpts]):  # type: ignore[misc]
+    """
+    AWSSageMakerScheduler is a TorchX scheduling interface to AWS SageMaker.
+
+    .. code-block:: bash
+
+        $ torchx run -s aws_sagemaker utils.echo --image alpine:latest --msg hello
+        aws_sagemaker://torchx_user/1234
+        $ torchx status aws_sagemaker://torchx_user/1234
+        ...
+
+    Authentication is loaded from the environment using the ``boto3`` credential
+    handling.
+
+    **Config Options**
+
+    .. runopts::
+        class: torchx.schedulers.aws_sagemaker_scheduler.create_scheduler
+
+    **Compatibility**
+
+    .. compatibility::
+        type: scheduler
+        features:
+            cancel: true
+            logs: false
+            distributed: true
+            describe: |
+                Partial support. SageMakerScheduler will return job and replica
+                status but does not provide the complete original AppSpec.
+            workspaces: true
+            mounts: false
+            elasticity: false
+    """
+
+    def __init__(
+        self,
+        session_name: str,
+        client: Optional[Any] = None,  # pyre-ignore[2]
+        docker_client: Optional["DockerClient"] = None,
+    ) -> None:
+        super().__init__("aws_sagemaker", session_name, docker_client=docker_client)
+        # pyre-fixme[4]: Attribute annotation cannot be `Any`.
+        self.__client = client
+
+    @property
+    # pyre-fixme[3]: Return annotation cannot be `Any`.
+    def _client(self) -> Any:
+        if self.__client:
+            return self.__client
+        return _local_session().client("sagemaker")
+
+    def schedule(self, dryrun_info: AppDryRunInfo[AWSSageMakerJob]) -> str:
+        cfg = dryrun_info._cfg
+        assert cfg is not None, f"{dryrun_info} missing cfg"
+
+        images_to_push = dryrun_info.request.images_to_push
+        self.push_images(images_to_push)
+
+        req = dryrun_info.request
+        pt_estimator = PyTorch(**req.job_def)
+        pt_estimator.fit(wait=False, job_name=req.job_name)
+
+        return req.job_name
+
+    def _submit_dryrun(
+        self, app: AppDef, cfg: AWSSageMakerOpts
+    ) -> AppDryRunInfo[AWSSageMakerJob]:
+        role = app.roles[0]
+        entrypoint, hyperparameters = self._parse_args(role.args)
+
+        # map any local images to the remote image
+        images_to_push = self.dryrun_push_images(app, cast(Mapping[str, CfgVal], cfg))
+        structured_name_kwargs = {}
+        if entrypoint.startswith("-m"):
+            structured_name_kwargs["m"] = entrypoint.replace("-m", "").strip()
+        else:
+            structured_name_kwargs["script"] = entrypoint
+        structured_name = StructuredNameArgument.parse_from(
+            app.name, **structured_name_kwargs
+        )
+        job_name = make_unique(structured_name.run_name)
+
+        role.env["TORCHX_JOB_ID"] = job_name
+
+        # see https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#sagemaker.estimator.EstimatorBase
+        job_def = {
+            "entry_point": entrypoint,
+            "image_uri": role.image,
+            "distribution": {"torch_distributed": {"enabled": True}},
+        }
+
+        cfg["environment"] = _merge_ordered(cfg.get("environment"), role.env)
+        # hyperparameters are used for both script/module entrypoint args and the values from .torchxconfig
+        # order matters, adding script args last to handle wildcard parameters
+        cfg["hyperparameters"] = _merge_ordered(
+            cfg.get("hyperparameters"), hyperparameters
+        )
+        # tags are used for AppDef metadata and the values from .torchxconfig
+        cfg["tags"] = [  # pyre-ignore[54]
+            *(cfg.get("tags") or []),
+            *({"Key": k, "Value": v} for k, v in app.metadata.items()),
+        ]
+        # following the principle of least astonishment defaulting source_dir to current working directory
+        cfg["source_dir"] = cfg.get("source_dir") or os.getcwd()
+
+        for key in cfg:
+            if key in job_def:
+                raise ValueError(
+                    f"{key} is controlled by aws_sagemaker_scheduler and is set to {job_def[key]}"
+                )
+            value = cfg.get(key)  # type: ignore
+            if value is not None:
+                job_def[key] = value  # type: ignore
+
+        req = AWSSageMakerJob(
+            job_name=job_name,
+            job_def=job_def,
+            images_to_push=images_to_push,
+        )
+        return AppDryRunInfo(req, repr)
+
+    def _parse_args(self, args: List[str]) -> Tuple[str, Dict[str, str]]:
+        if len(args) < 1:
+            raise ValueError("Not enough args to resolve entrypoint")
+        offset = 1
+        if args[0] == "-m":
+            if len(args) < 2:
+                raise ValueError("Missing module name")
+            offset += 1
+        entrypoint = " ".join(args[:offset])
+        hyperparameters = OrdDict()  # the order matters, e.g. for wildcard params
+        while offset < len(args):
+            arg = args[offset]
+            sp_pos = arg.find("=")
+            if sp_pos < 0:
+                if offset + 1 >= len(args):
+                    raise ValueError(
+                        "SageMaker currently only supports named arguments"
+                    )
+                key = arg
+                offset += 1
+                value = args[offset]
+            else:
+                key = arg[:sp_pos]
+                value = arg[sp_pos + 1 :]
+            if not key.startswith("--"):
+                raise ValueError("SageMaker only supports arguments that start with --")
+            offset += 1
+            hyperparameters[key[2:]] = value
+        return entrypoint, hyperparameters
+
+    def _run_opts(self) -> runopts:
+        opts = runopts()
+        opts.add(
+            "role",
+            type_=str,
+            help="an AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs that create Amazon SageMaker endpoints use this role to access training data and model artifacts. After the endpoint is created, the inference code might use the IAM role, if it needs to access an AWS resource.",
+            required=True,
+        )
+        opts.add(
+            "instance_count",
+            type_=int,
+            default=1,
+            help="number of Amazon EC2 instances to use for training. Required if instance_groups is not set.",
+        )
+        opts.add(
+            "instance_type",
+            type_=str,
+            help="type of EC2 instance to use for training, for example, 'ml.c4.xlarge'",
+            required=True,
+        )
+        opts.add(
+            "user",
+            type_=str,
+            default=getpass.getuser(),
+            help="the username to tag the job with. `getpass.getuser()` if not specified.",
+        )
+        opts.add(
+            "keep_alive_period_in_seconds",
+            type_=int,
+            default=None,
+            help="the duration of time in seconds to retain configured resources in a warm pool for subsequent training jobs.",
+        )
+        opts.add(
+            "volume_size",
+            type_=int,
+            default=None,
+            help="size in GB of the storage volume to use for storing input and output data during training (default: 30).",
+        )
+        opts.add(
+            "volume_kms_key",
+            type_=str,
+            default=None,
+            help="KMS key ID for encrypting EBS volume attached to the training instance.",
+        )
+        opts.add(
+            "max_run",
+            type_=int,
+            default=None,
+            help="timeout in seconds for training (default: 24 * 60 * 60).",
+        )
+        opts.add(
+            "input_mode",
+            type_=str,
+            default=None,
+            help="the input mode that the algorithm supports (default: ‘File’).",
+        )
+        opts.add(
+            "output_path",
+            type_=str,
+            default=None,
+            help="S3 location for saving the training result (model artifacts and output files). If not specified, results are stored to a default bucket. If the bucket with the specific name does not exist, the estimator creates the bucket during the fit() method execution.",
+        )
+        opts.add(
+            "output_kms_key",
+            type_=str,
+            default=None,
+            help="KMS key ID for encrypting the training output (default: Your IAM role’s KMS key for Amazon S3).",
+        )
+        opts.add(
+            "base_job_name",
+            type_=str,
+            default=None,
+            help="prefix for training job name when the fit() method launches. If not specified, the estimator generates a default job name based on the training image name and current timestamp.",
+        )
+        opts.add(
+            "tags",
+            type_=List[Dict[str, str]],
+            default=None,
+            help="list of tags for labeling a training job.",
+        )
+        opts.add(
+            "subnets",
+            type_=List[str],
+            default=None,
+            help="list of subnet ids. If not specified training job will be created without VPC config.",
+        )
+        opts.add(
+            "security_group_ids",
+            type_=List[str],
+            default=None,
+            help="list of security group ids. If not specified training job will be created without VPC config.",
+        )
+        opts.add(
+            "model_uri",
+            type_=str,
+            default=None,
+            help="URI where a pre-trained model is stored, either locally or in S3.",
+        )
+        opts.add(
+            "model_channel_name",
+            type_=str,
+            default=None,
+            help="name of the channel where ‘model_uri’ will be downloaded (default: ‘model’).",
+        )
+        opts.add(
+            "metric_definitions",
+            type_=List[Dict[str, str]],
+            default=None,
+            help="list of dictionaries that defines the metric(s) used to evaluate the training jobs. Each dictionary contains two keys: ‘Name’ for the name of the metric, and ‘Regex’ for the regular expression used to extract the metric from the logs.",
+        )
+        opts.add(
+            "encrypt_inter_container_traffic",
+            type_=bool,
+            default=None,
+            help="specifies whether traffic between training containers is encrypted for the training job (default: False).",
+        )
+        opts.add(
+            "use_spot_instances",
+            type_=bool,
+            default=None,
+            help="specifies whether to use SageMaker Managed Spot instances for training. If enabled then the max_wait arg should also be set.",
+        )
+        opts.add(
+            "max_wait",
+            type_=int,
+            default=None,
+            help="timeout in seconds waiting for spot training job.",
+        )
+        opts.add(
+            "checkpoint_s3_uri",
+            type_=str,
+            default=None,
+            help="S3 URI in which to persist checkpoints that the algorithm persists (if any) during training.",
+        )
+        opts.add(
+            "checkpoint_local_path",
+            type_=str,
+            default=None,
+            help="local path that the algorithm writes its checkpoints to.",
+        )
+        opts.add(
+            "debugger_hook_config",
+            type_=bool,
+            default=None,
+            help="configuration for how debugging information is emitted with SageMaker Debugger. If not specified, a default one is created using the estimator’s output_path, unless the region does not support SageMaker Debugger. To disable SageMaker Debugger, set this parameter to False.",
+        )
+        opts.add(
+            "enable_sagemaker_metrics",
+            type_=bool,
+            default=None,
+            help="enable SageMaker Metrics Time Series.",
+        )
+        opts.add(
+            "enable_network_isolation",
+            type_=bool,
+            default=None,
+            help="specifies whether container will run in network isolation mode (default: False).",
+        )
+        opts.add(
+            "disable_profiler",
+            type_=bool,
+            default=None,
+            help="specifies whether Debugger monitoring and profiling will be disabled (default: False).",
+        )
+        opts.add(
+            "environment",
+            type_=Dict[str, str],
+            default=None,
+            help="environment variables to be set for use during training job",
+        )
+        opts.add(
+            "max_retry_attempts",
+            type_=int,
+            default=None,
+            help="number of times to move a job to the STARTING status. You can specify between 1 and 30 attempts.",
+        )
+        opts.add(
+            "source_dir",
+            type_=str,
+            default=None,
+            help="absolute, relative, or S3 URI Path to a directory with any other training source code dependencies aside from the entry point file (default: current working directory)",
+        )
+        opts.add(
+            "git_config",
+            type_=Dict[str, str],
+            default=None,
+            help="git configurations used for cloning files, including repo, branch, commit, 2FA_enabled, username, password, and token.",
+        )
+        opts.add(
+            "hyperparameters",
+            type_=Dict[str, str],
+            default=None,
+            help="dictionary containing the hyperparameters to initialize this estimator with.",
+        )
+        opts.add(
+            "container_log_level",
+            type_=int,
+            default=None,
+            help="log level to use within the container (default: logging.INFO).",
+        )
+        opts.add(
+            "code_location",
+            type_=str,
+            default=None,
+            help="S3 prefix URI where custom code is uploaded.",
+        )
+        opts.add(
+            "dependencies",
+            type_=List[str],
+            default=None,
+            help="list of absolute or relative paths to directories with any additional libraries that should be exported to the container.",
+        )
+        opts.add(
+            "training_repository_access_mode",
+            type_=str,
+            default=None,
+            help="specifies how SageMaker accesses the Docker image that contains the training algorithm.",
+        )
+        opts.add(
+            "training_repository_credentials_provider_arn",
+            type_=str,
+            default=None,
+            help="Amazon Resource Name (ARN) of an AWS Lambda function that provides credentials to authenticate to the private Docker registry where your training image is hosted.",
+        )
+        opts.add(
+            "disable_output_compression",
+            type_=bool,
+            default=None,
+            help="when set to true, Model is uploaded to Amazon S3 without compression after training finishes.",
+        )
+        opts.add(
+            "enable_infra_check",
+            type_=bool,
+            default=None,
+            help="specifies whether it is running Sagemaker built-in infra check jobs.",
+        )
+        return opts
+
+    def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
+        job = self._get_job(app_id)
+        if job is None:
+            return None
+
+        return DescribeAppResponse(
+            app_id=app_id,
+            state=JOB_STATE[job["TrainingJobStatus"]],
+            ui_url=self._job_ui_url(job["TrainingJobArn"]),
+        )
+
+    def list(self) -> List[ListAppResponse]:
+        raise NotImplementedError()
+
+    def _cancel_existing(self, app_id: str) -> None:
+        self._client.stop_training_job(TrainingJobName=app_id)
+
+    def log_iter(
+        self,
+        app_id: str,
+        role_name: str,
+        k: int = 0,
+        regex: Optional[str] = None,
+        since: Optional[datetime] = None,
+        until: Optional[datetime] = None,
+        should_tail: bool = False,
+        streams: Optional[Stream] = None,
+    ) -> Iterable[str]:
+        raise NotImplementedError()
+
+    def _get_job(self, app_id: str) -> Optional[Dict[str, Any]]:
+        job = self._client.describe_training_job(TrainingJobName=app_id)
+        return job
+
+    def _job_ui_url(self, job_arn: str) -> Optional[str]:
+        match = re.match(
+            "arn:aws:sagemaker:(?P<region>[a-z-0-9]+):[0-9]+:training-job/(?P<job_id>[a-z-0-9]+)",
+            job_arn,
+        )
+        if match is None:
+            return None
+        region = match.group("region")
+        job_id = match.group("job_id")
+        return f"https://{region}.console.aws.amazon.com/sagemaker/home?region={region}#jobs/{job_id}"
+
+
+def create_scheduler(session_name: str, **kwargs: object) -> AWSSageMakerScheduler:
+    return AWSSageMakerScheduler(session_name=session_name)
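
To see what the new scheduler's dryrun path produces, here is a minimal sketch (illustrative only, not part of the diff): the image, module name, and IAM role ARN below are hypothetical placeholders, and it assumes boto3 credentials and the sagemaker SDK are configured locally.

# Illustrative sketch only: exercises the new aws_sagemaker dryrun path.
# All names below (image, module, role ARN) are hypothetical placeholders.
from torchx.runner import get_runner
from torchx.specs import AppDef, Role

app = AppDef(
    name="trainer",
    roles=[
        Role(
            name="trainer",
            image="123456789012.dkr.ecr.us-west-2.amazonaws.com/trainer:latest",
            entrypoint="python",
            # _parse_args turns this into entry_point="-m my_pkg.train"
            # and hyperparameters={"lr": "0.01"}
            args=["-m", "my_pkg.train", "--lr=0.01"],
        )
    ],
)

dryrun_info = get_runner().dryrun(
    app,
    scheduler="aws_sagemaker",
    cfg={
        "role": "arn:aws:iam::123456789012:role/SageMakerRole",  # required opt
        "instance_type": "ml.p4d.24xlarge",  # required opt
        "instance_count": 2,
    },
)
print(dryrun_info)  # the AWSSageMakerJob request, rendered via yaml.dump
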
torchx/schedulers/devices.py
@@ -4,26 +4,39 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
+# pyre-strict
 import warnings
+from functools import partial
 from typing import Callable, Dict, List, Mapping

 from torchx.specs.api import DeviceMount
+from torchx.specs.named_resources_aws import EFA_DEVICE, NEURON_DEVICE


-def efa_to_devicemounts(num_devices: int) -> List[DeviceMount]:
+def to_devicemounts(num_devices: int, device_type: str) -> List[DeviceMount]:
     device_mounts = []
     for device_index in range(0, num_devices):
         device_mounts.append(
             DeviceMount(
-                src_path="/dev/infiniband/uverbs" + str(device_index),
-                dst_path="/dev/infiniband/uverbs" + str(device_index),
+                src_path=device_type + str(device_index),
+                dst_path=device_type + str(device_index),
             )
         )
     return device_mounts


+neuron_to_devicemounts: Callable[[int], List[DeviceMount]] = partial(
+    to_devicemounts, device_type="/dev/neuron"
+)
+efa_to_devicemounts: Callable[[int], List[DeviceMount]] = partial(
+    to_devicemounts, device_type="/dev/infiniband/uverbs"
+)
+
+
 DEVICES: Mapping[str, Callable[[int], List[DeviceMount]]] = {
-    "vpc.amazonaws.com/efa": efa_to_devicemounts,
+    EFA_DEVICE: efa_to_devicemounts,
+    NEURON_DEVICE: neuron_to_devicemounts,
 }
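
A quick sketch of what the devices.py refactor buys (illustrative, not part of the diff): the new partials expand a requested device count into per-index mounts, so Neuron devices resolve the same way EFA devices always have.

# Illustrative only: expanding device counts into DeviceMount lists
# via the partials introduced above.
from torchx.schedulers.devices import efa_to_devicemounts, neuron_to_devicemounts

print(efa_to_devicemounts(2))
# -> mounts for /dev/infiniband/uverbs0 and /dev/infiniband/uverbs1
print(neuron_to_devicemounts(1))
# -> mount for /dev/neuron0
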