torchx-nightly 2023.10.21__py3-none-any.whl → 2025.12.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchx-nightly might be problematic. Click here for more details.

Files changed (110)
  1. torchx/__init__.py +2 -0
  2. torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
  3. torchx/apps/serve/serve.py +2 -0
  4. torchx/apps/utils/booth_main.py +2 -0
  5. torchx/apps/utils/copy_main.py +2 -0
  6. torchx/apps/utils/process_monitor.py +2 -0
  7. torchx/cli/__init__.py +2 -0
  8. torchx/cli/argparse_util.py +38 -3
  9. torchx/cli/cmd_base.py +2 -0
  10. torchx/cli/cmd_cancel.py +2 -0
  11. torchx/cli/cmd_configure.py +2 -0
  12. torchx/cli/cmd_delete.py +30 -0
  13. torchx/cli/cmd_describe.py +2 -0
  14. torchx/cli/cmd_list.py +8 -4
  15. torchx/cli/cmd_log.py +6 -24
  16. torchx/cli/cmd_run.py +269 -45
  17. torchx/cli/cmd_runopts.py +2 -0
  18. torchx/cli/cmd_status.py +12 -1
  19. torchx/cli/cmd_tracker.py +3 -1
  20. torchx/cli/colors.py +2 -0
  21. torchx/cli/main.py +4 -0
  22. torchx/components/__init__.py +3 -8
  23. torchx/components/component_test_base.py +2 -0
  24. torchx/components/dist.py +18 -7
  25. torchx/components/integration_tests/component_provider.py +4 -2
  26. torchx/components/integration_tests/integ_tests.py +2 -0
  27. torchx/components/serve.py +2 -0
  28. torchx/components/structured_arg.py +7 -6
  29. torchx/components/utils.py +15 -4
  30. torchx/distributed/__init__.py +2 -4
  31. torchx/examples/apps/datapreproc/datapreproc.py +2 -0
  32. torchx/examples/apps/lightning/data.py +5 -3
  33. torchx/examples/apps/lightning/model.py +7 -6
  34. torchx/examples/apps/lightning/profiler.py +7 -4
  35. torchx/examples/apps/lightning/train.py +11 -2
  36. torchx/examples/torchx_out_of_sync_training.py +11 -0
  37. torchx/notebook.py +2 -0
  38. torchx/runner/__init__.py +2 -0
  39. torchx/runner/api.py +167 -60
  40. torchx/runner/config.py +43 -10
  41. torchx/runner/events/__init__.py +57 -13
  42. torchx/runner/events/api.py +14 -3
  43. torchx/runner/events/handlers.py +2 -0
  44. torchx/runtime/tracking/__init__.py +2 -0
  45. torchx/runtime/tracking/api.py +2 -0
  46. torchx/schedulers/__init__.py +16 -15
  47. torchx/schedulers/api.py +70 -14
  48. torchx/schedulers/aws_batch_scheduler.py +79 -5
  49. torchx/schedulers/aws_sagemaker_scheduler.py +598 -0
  50. torchx/schedulers/devices.py +17 -4
  51. torchx/schedulers/docker_scheduler.py +43 -11
  52. torchx/schedulers/ids.py +29 -23
  53. torchx/schedulers/kubernetes_mcad_scheduler.py +10 -8
  54. torchx/schedulers/kubernetes_scheduler.py +383 -38
  55. torchx/schedulers/local_scheduler.py +100 -27
  56. torchx/schedulers/lsf_scheduler.py +5 -4
  57. torchx/schedulers/slurm_scheduler.py +336 -20
  58. torchx/schedulers/streams.py +2 -0
  59. torchx/specs/__init__.py +89 -12
  60. torchx/specs/api.py +431 -32
  61. torchx/specs/builders.py +176 -38
  62. torchx/specs/file_linter.py +143 -57
  63. torchx/specs/finder.py +68 -28
  64. torchx/specs/named_resources_aws.py +254 -22
  65. torchx/specs/named_resources_generic.py +2 -0
  66. torchx/specs/overlays.py +106 -0
  67. torchx/specs/test/components/__init__.py +2 -0
  68. torchx/specs/test/components/a/__init__.py +2 -0
  69. torchx/specs/test/components/a/b/__init__.py +2 -0
  70. torchx/specs/test/components/a/b/c.py +2 -0
  71. torchx/specs/test/components/c/__init__.py +2 -0
  72. torchx/specs/test/components/c/d.py +2 -0
  73. torchx/tracker/__init__.py +12 -6
  74. torchx/tracker/api.py +15 -18
  75. torchx/tracker/backend/fsspec.py +2 -0
  76. torchx/util/cuda.py +2 -0
  77. torchx/util/datetime.py +2 -0
  78. torchx/util/entrypoints.py +39 -15
  79. torchx/util/io.py +2 -0
  80. torchx/util/log_tee_helpers.py +210 -0
  81. torchx/util/modules.py +65 -0
  82. torchx/util/session.py +42 -0
  83. torchx/util/shlex.py +2 -0
  84. torchx/util/strings.py +3 -1
  85. torchx/util/types.py +90 -29
  86. torchx/version.py +4 -2
  87. torchx/workspace/__init__.py +2 -0
  88. torchx/workspace/api.py +136 -6
  89. torchx/workspace/dir_workspace.py +2 -0
  90. torchx/workspace/docker_workspace.py +30 -2
  91. torchx_nightly-2025.12.24.dist-info/METADATA +167 -0
  92. torchx_nightly-2025.12.24.dist-info/RECORD +113 -0
  93. {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info}/WHEEL +1 -1
  94. {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info}/entry_points.txt +0 -1
  95. torchx/examples/pipelines/__init__.py +0 -0
  96. torchx/examples/pipelines/kfp/__init__.py +0 -0
  97. torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -287
  98. torchx/examples/pipelines/kfp/dist_pipeline.py +0 -69
  99. torchx/examples/pipelines/kfp/intro_pipeline.py +0 -81
  100. torchx/pipelines/kfp/__init__.py +0 -28
  101. torchx/pipelines/kfp/adapter.py +0 -271
  102. torchx/pipelines/kfp/version.py +0 -17
  103. torchx/schedulers/gcp_batch_scheduler.py +0 -487
  104. torchx/schedulers/ray/ray_common.py +0 -22
  105. torchx/schedulers/ray/ray_driver.py +0 -307
  106. torchx/schedulers/ray_scheduler.py +0 -453
  107. torchx_nightly-2023.10.21.dist-info/METADATA +0 -174
  108. torchx_nightly-2023.10.21.dist-info/RECORD +0 -118
  109. {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info/licenses}/LICENSE +0 -0
  110. {torchx_nightly-2023.10.21.dist-info → torchx_nightly-2025.12.24.dist-info}/top_level.txt +0 -0
torchx/specs/api.py CHANGED
@@ -1,16 +1,26 @@
1
- #!/usr/bin/env python3
2
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
3
2
  # All rights reserved.
4
3
  #
5
4
  # This source code is licensed under the BSD-style license found in the
6
5
  # LICENSE file in the root directory of this source tree.
7
6
 
7
+ # pyre-strict
8
+
9
+ import asyncio
8
10
  import copy
11
+ import inspect
9
12
  import json
13
+ import logging as logger
14
+ import os
15
+ import pathlib
10
16
  import re
17
+ import shutil
18
+ import typing
19
+ import warnings
11
20
  from dataclasses import asdict, dataclass, field
12
21
  from datetime import datetime
13
- from enum import Enum
22
+ from enum import Enum, IntEnum
23
+ from json import JSONDecodeError
14
24
  from string import Template
15
25
  from typing import (
16
26
  Any,
@@ -20,6 +30,7 @@ from typing import (
20
30
  Iterator,
21
31
  List,
22
32
  Mapping,
33
+ NamedTuple,
23
34
  Optional,
24
35
  Pattern,
25
36
  Tuple,
@@ -55,6 +66,35 @@ _RPC_ERROR_MESSAGE_RE: Pattern[str] = re.compile(
55
66
  # (most recent call first):
56
67
  _EMBEDDED_ERROR_MESSAGE_RE: Pattern[str] = re.compile(r"(?P<msg>.+)\nException.*")
57
68
 
69
+ YELLOW_BOLD = "\033[1;33m"
70
+ RESET = "\033[0m"
71
+
72
+
73
+ def TORCHX_HOME(*subdir_paths: str) -> pathlib.Path:
74
+ """
75
+ Path to the "dot-directory" for torchx.
76
+ Defaults to `~/.torchx` and is overridable via the `TORCHX_HOME` environment variable.
77
+
78
+ Usage:
79
+
80
+ .. doctest::
81
+
82
+ from pathlib import Path
83
+ from torchx.specs import TORCHX_HOME
84
+
85
+ assert TORCHX_HOME() == Path.home() / ".torchx"
86
+ assert TORCHX_HOME("conda-pack-out") == Path.home() / ".torchx" / "conda-pack-out"
87
+
88
+ """
89
+
90
+ default_dir = str(pathlib.Path.home() / ".torchx")
91
+ torchx_home = pathlib.Path(os.getenv("TORCHX_HOME", default_dir))
92
+
93
+ torchx_home = torchx_home / os.path.sep.join(subdir_paths)
94
+ torchx_home.mkdir(parents=True, exist_ok=True)
95
+
96
+ return torchx_home
97
+
58
98
 
59
99
  # ========================================
60
100
  # ==== Distributed AppDef API =======
@@ -72,6 +112,8 @@ class Resource:
72
112
  memMB: MB of ram
73
113
  capabilities: additional hardware specs (interpreted by scheduler)
74
114
  devices: a list of named devices with their quantities
115
+ tags: metadata tags for the resource (not interpreted by schedulers)
116
+ used to add non-functional information about resources (e.g. whether it is an alias of another resource)
75
117
 
76
118
  Note: you should prefer to use named_resources instead of specifying the raw
77
119
  resource requirement directly.
@@ -82,6 +124,7 @@ class Resource:
82
124
  memMB: int
83
125
  capabilities: Dict[str, Any] = field(default_factory=dict)
84
126
  devices: Dict[str, int] = field(default_factory=dict)
127
+ tags: Dict[str, object] = field(default_factory=dict)
85
128
 
86
129
  @staticmethod
87
130
  def copy(original: "Resource", **capabilities: Any) -> "Resource":
@@ -90,6 +133,7 @@ class Resource:
90
133
  are present in the original resource and as parameter, the one from parameter
91
134
  will be used.
92
135
  """
136
+
93
137
  res_capabilities = dict(original.capabilities)
94
138
  res_capabilities.update(capabilities)
95
139
  return Resource(
@@ -182,16 +226,48 @@ class macros:
182
226
  apply applies the values to a copy of the specified role and returns it.
183
227
  """
184
228
 
229
+ # Overrides might contain future values which can't be serialized so taken out for the copy
230
+ overrides = role.overrides
231
+ if len(overrides) > 0:
232
+ logger.warning(
233
+ "Role overrides are not supported for macros. Overrides will not be copied"
234
+ )
235
+ role.overrides = {}
185
236
  role = copy.deepcopy(role)
237
+ role.overrides = overrides
238
+
186
239
  role.args = [self.substitute(arg) for arg in role.args]
187
240
  role.env = {key: self.substitute(arg) for key, arg in role.env.items()}
241
+ role.metadata = self._apply_nested(role.metadata)
242
+
188
243
  return role
189
244
 
245
+ def _apply_nested(self, d: typing.Dict[str, Any]) -> typing.Dict[str, Any]:
246
+ stack = [d]
247
+ while stack:
248
+ current_dict = stack.pop()
249
+ for k, v in current_dict.items():
250
+ if isinstance(v, dict):
251
+ stack.append(v)
252
+ elif isinstance(v, str):
253
+ current_dict[k] = self.substitute(v)
254
+ elif isinstance(v, list):
255
+ for i in range(len(v)):
256
+ if isinstance(v[i], dict):
257
+ stack.append(v[i])
258
+ elif isinstance(v[i], str):
259
+ v[i] = self.substitute(v[i])
260
+ return d
261
+
262
+ # Overrides the asdict method to generate a dictionary of macro values to be substituted.
263
+ def to_dict(self) -> Dict[str, Any]:
264
+ return asdict(self)
265
+
190
266
  def substitute(self, arg: str) -> str:
191
267
  """
192
268
  substitute applies the values to the template arg.
193
269
  """
194
- return Template(arg).safe_substitute(**asdict(self))
270
+ return Template(arg).safe_substitute(**self.to_dict())
195
271
 
196
272
 
197
273
  class RetryPolicy(str, Enum):
@@ -215,11 +291,13 @@ class RetryPolicy(str, Enum):
215
291
  application to deal with failed replica departures and
216
292
  replacement replica admittance.
217
293
  2. APPLICATION: Restarts the entire application.
218
-
294
+ 3. ROLE: Restarts the role when any error occurs in that role. This does not
295
+ restart the whole job.
219
296
  """
220
297
 
221
298
  REPLICA = "REPLICA"
222
299
  APPLICATION = "APPLICATION"
300
+ ROLE = "ROLE"
223
301
 
224
302
 
225
303
  class MountType(str, Enum):
@@ -276,6 +354,121 @@ class DeviceMount:
276
354
  permissions: str = "rwm"
277
355
 
278
356
 
357
+ @dataclass
358
+ class Workspace:
359
+ """
360
+ Specifies a local "workspace" (a set of directories). Workspaces are ad-hoc built
361
+ into an (usually ephemeral) image. This effectively mirrors the local code changes
362
+ at job submission time.
363
+
364
+ For example:
365
+
366
+ 1. ``projects={"~/github/torch": "torch"}`` copies ``~/github/torch/**`` into ``$REMOTE_WORKSPACE_ROOT/torch/**``
367
+ 2. ``projects={"~/github/torch": ""}`` copies ``~/github/torch/**`` into ``$REMOTE_WORKSPACE_ROOT/**``
368
+
369
+ The exact location of ``$REMOTE_WORKSPACE_ROOT`` is implementation dependent and varies between
370
+ different implementations of :py:class:`~torchx.workspace.api.WorkspaceMixin`.
371
+ Check the scheduler documentation for details on which workspace it supports.
372
+
373
+ Note: ``projects`` maps the location of the local project to a sub-directory in the remote workspace root directory.
374
+ Typically the local project location is a directory path (e.g. ``/home/foo/github/torch``).
375
+
376
+
377
+ Attributes:
378
+ projects: mapping of local project to the sub-dir in the remote workspace dir.
379
+ """
380
+
381
+ projects: dict[str, str]
382
+
383
+ def __bool__(self) -> bool:
384
+ """False if no projects mapping. Lets us use workspace object in an if-statement"""
385
+ return bool(self.projects)
386
+
387
+ def __eq__(self, other: object) -> bool:
388
+ if not isinstance(other, Workspace):
389
+ return False
390
+ return self.projects == other.projects
391
+
392
+ def __hash__(self) -> int:
393
+ # makes it possible to use Workspace as the key in the workspace build cache
394
+ # see WorkspaceMixin.caching_build_workspace_and_update_role
395
+ return hash(frozenset(self.projects.items()))
396
+
397
+ def is_unmapped_single_project(self) -> bool:
398
+ """
399
+ Returns ``True`` if this workspace only has 1 project
400
+ and its target mapping is an empty string.
401
+ """
402
+ return len(self.projects) == 1 and not next(iter(self.projects.values()))
403
+
404
+ def merge_into(self, outdir: str | pathlib.Path) -> None:
405
+ """
406
+ Copies each project dir of this workspace into the specified ``outdir``.
407
+ Each project dir is copied into ``{outdir}/{target}`` where ``target`` is
408
+ the target mapping of the project dir.
409
+
410
+ For example:
411
+
412
+ .. code-block:: python
413
+ from os.path import expanduser
414
+
415
+ workspace = Workspace(
416
+ projects={
417
+ expanduser("~/workspace/torch"): "torch",
418
+ expanduser("~/workspace/my_project"): ""
419
+ }
420
+ )
421
+ workspace.merge_into(expanduser("~/tmp"))
422
+
423
+ Copies:
424
+
425
+ * ``~/workspace/torch/**`` into ``~/tmp/torch/**``
426
+ * ``~/workspace/my_project/**`` into ``~/tmp/**``
427
+
428
+ """
429
+
430
+ for src, dst in self.projects.items():
431
+ dst_path = pathlib.Path(outdir) / dst
432
+ if pathlib.Path(src).is_file():
433
+ shutil.copy2(src, dst_path)
434
+ else: # src is dir
435
+ shutil.copytree(src, dst_path, dirs_exist_ok=True)
436
+
437
+ @staticmethod
438
+ def from_str(workspace: str | None) -> "Workspace":
439
+ import yaml
440
+
441
+ if not workspace:
442
+ return Workspace({})
443
+
444
+ projects = yaml.safe_load(workspace)
445
+ if isinstance(projects, str): # single project workspace
446
+ projects = {projects: ""}
447
+ else: # multi-project workspace
448
+ # Replace None mappings with "" (empty string)
449
+ projects = {k: ("" if v is None else v) for k, v in projects.items()}
450
+
451
+ return Workspace(projects)
452
+
453
+ def __str__(self) -> str:
454
+ """
455
+ Returns a string representation of the Workspace by concatenating
456
+ the project mappings using ';' as a delimiter and ':' between key and value.
457
+ If this is a single-project workspace with no target mapping, then simply
458
+ returns the src (local project dir)
459
+
460
+ NOTE: meant to be used for logging purposes not serde.
461
+ Therefore not symmetric with :py:func:`Workspace.from_str`.
462
+
463
+ """
464
+ if self.is_unmapped_single_project():
465
+ return next(iter(self.projects))
466
+ else:
467
+ return ";".join(
468
+ k if not v else f"{k}:{v}" for k, v in self.projects.items()
469
+ )
470
+
471
+
279
472
  @dataclass
280
473
  class Role:
281
474
  """
@@ -328,12 +521,15 @@ class Role:
328
521
  metadata: Free form information that is associated with the role, for example
329
522
  scheduler specific data. The key should follow the pattern: ``$scheduler.$key``
330
523
  mounts: a list of mounts on the machine
524
+ workspace: local project directories to be mirrored on the remote job.
525
+ NOTE: The workspace argument provided to the :py:class:`~torchx.runner.api.Runner` APIs
526
+ only takes effect on ``appdef.role[0]`` and overrides this attribute.
527
+
331
528
  """
332
529
 
333
530
  name: str
334
531
  image: str
335
532
  min_replicas: Optional[int] = None
336
- base_image: Optional[str] = None # DEPRECATED DO NOT SET, WILL BE REMOVED SOON
337
533
  entrypoint: str = MISSING
338
534
  args: List[str] = field(default_factory=list)
339
535
  env: Dict[str, str] = field(default_factory=dict)
@@ -343,9 +539,28 @@ class Role:
343
539
  resource: Resource = field(default_factory=_null_resource)
344
540
  port_map: Dict[str, int] = field(default_factory=dict)
345
541
  metadata: Dict[str, Any] = field(default_factory=dict)
346
- mounts: List[Union[BindMount, VolumeMount, DeviceMount]] = field(
347
- default_factory=list
348
- )
542
+ mounts: List[BindMount | VolumeMount | DeviceMount] = field(default_factory=list)
543
+ workspace: Workspace | None = None
544
+
545
+ # DEPRECATED DO NOT SET, WILL BE REMOVED SOON
546
+ overrides: Dict[str, Any] = field(default_factory=dict)
547
+
548
+ # pyre-ignore
549
+ def __getattribute__(self, attrname: str) -> Any:
550
+ if attrname == "overrides":
551
+ return super().__getattribute__(attrname)
552
+ try:
553
+ ov = super().__getattribute__("overrides")
554
+ except AttributeError:
555
+ ov = {}
556
+ if attrname in ov:
557
+ if inspect.isawaitable(ov[attrname]):
558
+ result = asyncio.get_event_loop().run_until_complete(ov[attrname])
559
+ else:
560
+ result = ov[attrname]()
561
+ setattr(self, attrname, result)
562
+ ov[attrname] = lambda: result
563
+ return super().__getattribute__(attrname)
349
564
 
350
565
  def pre_proc(
351
566
  self,
@@ -481,6 +696,15 @@ class RoleStatus:
481
696
  role: str
482
697
  replicas: List[ReplicaStatus]
483
698
 
699
+ def to_json(self) -> Dict[str, Any]:
700
+ """
701
+ Convert the RoleStatus to a json object.
702
+ """
703
+ return {
704
+ "role": self.role,
705
+ "replicas": [asdict(replica) for replica in self.replicas],
706
+ }
707
+
484
708
 
485
709
  @dataclass
486
710
  class AppStatus:
@@ -551,7 +775,10 @@ class AppStatus:
551
775
 
552
776
  def _format_replica_status(self, replica_status: ReplicaStatus) -> str:
553
777
  if replica_status.structured_error_msg != NONE:
554
- error_data = json.loads(replica_status.structured_error_msg)
778
+ try:
779
+ error_data = json.loads(replica_status.structured_error_msg)
780
+ except JSONDecodeError:
781
+ return replica_status.structured_error_msg
555
782
  error_message = self._format_error_message(
556
783
  msg=error_data["message"]["message"], header=" error_msg: "
557
784
  )
@@ -597,6 +824,21 @@ class AppStatus:
597
824
  replica_data += self._format_replica_status(replica)
598
825
  return f"{replica_data}"
599
826
 
827
+ def to_json(self, filter_roles: Optional[List[str]] = None) -> Dict[str, Any]:
828
+ """
829
+ Convert the AppStatus to a json object, including RoleStatus.
830
+ """
831
+ roles = self._get_role_statuses(self.roles, filter_roles)
832
+
833
+ return {
834
+ "state": str(self.state),
835
+ "num_restarts": self.num_restarts,
836
+ "roles": [role_status.to_json() for role_status in roles],
837
+ "msg": self.msg,
838
+ "structured_error_msg": self.structured_error_msg,
839
+ "url": self.ui_url,
840
+ }
841
+
600
842
  def format(
601
843
  self,
602
844
  filter_roles: Optional[List[str]] = None,
@@ -612,6 +854,7 @@ class AppStatus:
612
854
  """
613
855
  roles_data = ""
614
856
  roles = self._get_role_statuses(self.roles, filter_roles)
857
+
615
858
  for role_status in roles:
616
859
  roles_data += self._format_role_status(role_status)
617
860
  return Template(_APP_STATUS_FORMAT_TEMPLATE).substitute(
@@ -636,11 +879,11 @@ class AppStatusError(Exception):
636
879
  self.status = status
637
880
 
638
881
 
639
- # valid run cfg values; only support primitives (str, int, float, bool, List[str])
882
+ # valid run cfg values; only support primitives (str, int, float, bool, List[str], Dict[str, str])
640
883
  # TODO(wilsonhong): python 3.9+ supports list[T] in typing, which can be used directly
641
884
  # in isinstance(). Should replace with that.
642
885
  # see: https://docs.python.org/3/library/stdtypes.html#generic-alias-type
643
- CfgVal = Union[str, int, float, bool, List[str], None]
886
+ CfgVal = Union[str, int, float, bool, List[str], Dict[str, str], None]
644
887
 
645
888
 
646
889
  T = TypeVar("T")
@@ -699,6 +942,62 @@ class runopt:
699
942
  opt_type: Type[CfgVal]
700
943
  is_required: bool
701
944
  help: str
945
+ aliases: list[str] | None = None
946
+ deprecated_aliases: list[str] | None = None
947
+
948
+ @property
949
+ def is_type_list_of_str(self) -> bool:
950
+ """
951
+ Checks if the option type is a list of strings.
952
+
953
+ Returns:
954
+ bool: True if the option type is either List[str] or list[str], False otherwise.
955
+ """
956
+ return self.opt_type in (List[str], list[str])
957
+
958
+ @property
959
+ def is_type_dict_of_str(self) -> bool:
960
+ """
961
+ Checks if the option type is a dict of string keys to string values.
962
+
963
+ Returns:
964
+ bool: True if the option type is either Dict[str, str] or dict[str, str], False otherwise.
965
+ """
966
+ return self.opt_type in (Dict[str, str], dict[str, str])
967
+
968
+ def cast_to_type(self, value: str) -> CfgVal:
969
+ """Casts the given `value` (in its string representation) to the type of this run option.
970
+ Below are the cast rules for each option type and value literal:
971
+
972
+ 1. opt_type=str, value="foo" -> "foo"
973
+ 1. opt_type=bool, value="True"/"False" -> True/False
974
+ 1. opt_type=int, value="1" -> 1
975
+ 1. opt_type=float, value="1.1" -> 1.1
976
+ 1. opt_type=list[str]/List[str], value="a,b,c" or value="a;b;c" -> ["a", "b", "c"]
977
+ 1. opt_type=dict[str,str]/Dict[str,str],
978
+ value="key1:val1,key2:val2" or value="key1:val1;key2:val2" -> {"key1": "val1", "key2": "val2"}
979
+
980
+ NOTE: dict parsing uses ":" as the kv separator (rather than the standard "=") because "=" is used
981
+ at the top-level cfg to parse runopts (notice the plural) from the CLI. Originally torchx only supported
982
+ primitives and list[str] as CfgVal but dict[str,str] was added in https://github.com/meta-pytorch/torchx/pull/855
983
+ """
984
+
985
+ if self.opt_type is None:
986
+ raise ValueError("runopt's opt_type cannot be `None`")
987
+ elif self.opt_type == bool:
988
+ return value.lower() == "true"
989
+ elif self.opt_type in (List[str], list[str]):
990
+ # lists may be ; or , delimited
991
+ # also deal with trailing "," by removing empty strings
992
+ return [v for v in value.replace(";", ",").split(",") if v]
993
+ elif self.opt_type in (Dict[str, str], dict[str, str]):
994
+ return {
995
+ s.split(":", 1)[0]: s.split(":", 1)[1]
996
+ for s in value.replace(";", ",").split(",")
997
+ }
998
+ else:
999
+ assert self.opt_type in (str, int, float)
1000
+ return self.opt_type(value)
702
1001
 
703
1002
 
704
1003
  class runopts:
@@ -736,6 +1035,7 @@ class runopts:
736
1035
 
737
1036
  def __init__(self) -> None:
738
1037
  self._opts: Dict[str, runopt] = {}
1038
+ self._alias_to_key: dict[str, str] = {}
739
1039
 
740
1040
  def __iter__(self) -> Iterator[Tuple[str, runopt]]:
741
1041
  return self._opts.items().__iter__()
@@ -754,14 +1054,25 @@ class runopts:
754
1054
  except TypeError:
755
1055
  if isinstance(obj, list):
756
1056
  return all(isinstance(e, str) for e in obj)
1057
+ elif isinstance(obj, dict):
1058
+ return all(
1059
+ isinstance(k, str) and isinstance(v, str) for k, v in obj.items()
1060
+ )
757
1061
  else:
758
1062
  return False
759
1063
 
760
1064
  def get(self, name: str) -> Optional[runopt]:
761
1065
  """
762
- Returns option if any was registered, or None otherwise
1066
+ Returns option if any was registered, or None otherwise.
1067
+ First searches for the option by ``name``, then falls-back to matching ``name`` with any
1068
+ registered aliases.
1069
+
763
1070
  """
764
- return self._opts.get(name, None)
1071
+ if name in self._opts:
1072
+ return self._opts[name]
1073
+ if name in self._alias_to_key:
1074
+ return self._opts[self._alias_to_key[name]]
1075
+ return None
765
1076
 
766
1077
  def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]:
767
1078
  """
@@ -776,6 +1087,36 @@ class runopts:
776
1087
 
777
1088
  for cfg_key, runopt in self._opts.items():
778
1089
  val = resolved_cfg.get(cfg_key)
1090
+ resolved_name = None
1091
+ aliases = runopt.aliases or []
1092
+ deprecated_aliases = runopt.deprecated_aliases or []
1093
+ if val is None:
1094
+ for alias in aliases:
1095
+ val = resolved_cfg.get(alias)
1096
+ if alias in cfg or val is not None:
1097
+ resolved_name = alias
1098
+ break
1099
+ for alias in deprecated_aliases:
1100
+ val = resolved_cfg.get(alias)
1101
+ if val is not None:
1102
+ resolved_name = alias
1103
+ use_instead = self._alias_to_key.get(alias)
1104
+ warnings.warn(
1105
+ f"Run option `{alias}` is deprecated, use `{use_instead}` instead",
1106
+ UserWarning,
1107
+ stacklevel=2,
1108
+ )
1109
+ break
1110
+ else:
1111
+ resolved_name = cfg_key
1112
+ for alias in aliases:
1113
+ duplicate_val = resolved_cfg.get(alias)
1114
+ if alias in cfg or duplicate_val is not None:
1115
+ raise InvalidRunConfigException(
1116
+ f"Duplicate opt name. runopt: `{resolved_name}``, is an alias of runopt: `{alias}`",
1117
+ resolved_name,
1118
+ cfg,
1119
+ )
779
1120
 
780
1121
  # check required opt
781
1122
  if runopt.is_required and val is None:
@@ -795,7 +1136,7 @@ class runopts:
795
1136
  )
796
1137
 
797
1138
  # not required and not set, set to default
798
- if val is None:
1139
+ if val is None and resolved_name is None:
799
1140
  resolved_cfg[cfg_key] = runopt.default
800
1141
  return resolved_cfg
801
1142
 
@@ -855,22 +1196,37 @@ class runopts:
855
1196
 
856
1197
  """
857
1198
 
858
- def _cast_to_type(value: str, opt_type: Type[CfgVal]) -> CfgVal:
859
- if opt_type == bool:
860
- return value.lower() == "true"
861
- elif opt_type == List[str]:
862
- # lists may be ; or , delimited
863
- # also deal with trailing "," by removing empty strings
864
- return [v for v in value.replace(";", ",").split(",") if v]
1199
+ cfg: Dict[str, CfgVal] = {}
1200
+ for key, val in to_dict(cfg_str).items():
1201
+ opt = self.get(key)
1202
+ if opt:
1203
+ cfg[key] = opt.cast_to_type(val)
865
1204
  else:
866
- # pyre-ignore[19]
867
- return opt_type(value)
1205
+ logger.warning(
1206
+ f"{YELLOW_BOLD}Unknown run option passed to scheduler: {key}={val}{RESET}"
1207
+ )
1208
+ return cfg
868
1209
 
1210
+ def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]:
1211
+ """
1212
+ Converts the given JSON string (a serialized dict) to a valid cfg for this ``runopts`` object.
1213
+ """
869
1214
  cfg: Dict[str, CfgVal] = {}
870
- for key, val in to_dict(cfg_str).items():
871
- runopt_ = self.get(key)
872
- if runopt_:
873
- cfg[key] = _cast_to_type(val, runopt_.opt_type)
1215
+ cfg_dict = json.loads(json_repr)
1216
+ for key, val in cfg_dict.items():
1217
+ opt = self.get(key)
1218
+ if opt:
1219
+ # Optional runopt cfg values default their value to None,
1220
+ # but use `_type` to specify their type when provided.
1221
+ # Make sure not to treat None's as lists/dictionaries
1222
+ if val is None:
1223
+ cfg[key] = val
1224
+ elif opt.is_type_list_of_str:
1225
+ cfg[key] = [str(v) for v in val]
1226
+ elif opt.is_type_dict_of_str:
1227
+ cfg[key] = {str(k): str(v) for k, v in val.items()}
1228
+ else:
1229
+ cfg[key] = val
874
1230
  return cfg
875
1231
 
876
1232
  def add(
@@ -880,12 +1236,16 @@ class runopts:
880
1236
  help: str,
881
1237
  default: CfgVal = None,
882
1238
  required: bool = False,
1239
+ aliases: Optional[list[str]] = None,
1240
+ deprecated_aliases: Optional[list[str]] = None,
883
1241
  ) -> None:
884
1242
  """
885
1243
  Adds the ``config`` option with the given help string and ``default``
886
1244
  value (if any). If the ``default`` is not specified then this option
887
1245
  is a required option.
888
1246
  """
1247
+ aliases = aliases or []
1248
+ deprecated_aliases = deprecated_aliases or []
889
1249
  if required and default is not None:
890
1250
  raise ValueError(
891
1251
  f"Required option: {cfg_key} must not specify default value. Given: {default}"
@@ -897,7 +1257,19 @@ class runopts:
897
1257
  f" Given: {default} ({type(default).__name__})"
898
1258
  )
899
1259
 
900
- self._opts[cfg_key] = runopt(default, type_, required, help)
1260
+ opt = runopt(
1261
+ default,
1262
+ type_,
1263
+ required,
1264
+ help,
1265
+ list(set(aliases)),
1266
+ list(set(deprecated_aliases)),
1267
+ )
1268
+ for alias in aliases:
1269
+ self._alias_to_key[alias] = cfg_key
1270
+ for deprecated_alias in deprecated_aliases:
1271
+ self._alias_to_key[deprecated_alias] = cfg_key
1272
+ self._opts[cfg_key] = opt
901
1273
 
902
1274
  def update(self, other: "runopts") -> None:
903
1275
  self._opts.update(other._opts)
@@ -970,6 +1342,16 @@ class UnknownSchedulerException(Exception):
970
1342
  AppHandle = str
971
1343
 
972
1344
 
1345
+ class ParsedAppHandle(NamedTuple):
1346
+ """
1347
+ Individually accessible components of the `AppHandle`
1348
+ """
1349
+
1350
+ scheduler_backend: str
1351
+ session_name: str
1352
+ app_id: str
1353
+
1354
+
973
1355
  class UnknownAppException(Exception):
974
1356
  """
975
1357
  Raised by ``Session`` APIs when either the application does not
@@ -983,18 +1365,35 @@ class UnknownAppException(Exception):
983
1365
  )
984
1366
 
985
1367
 
986
- def parse_app_handle(app_handle: AppHandle) -> Tuple[str, str, str]:
1368
+ def parse_app_handle(app_handle: AppHandle) -> ParsedAppHandle:
987
1369
  """
988
- parses the app handle into ```(scheduler_backend, session_name, and app_id)```
1370
+ Parses the app handle into ``(scheduler_backend, session_name, and app_id)``.
1371
+
1372
+ Example:
1373
+
1374
+ .. doctest::
1375
+
1376
+ assert parse_app_handle("k8s://default/foo_bar") == ("k8s", "default", "foo_bar")
1377
+ assert parse_app_handle("k8s:///foo_bar") == ("k8s", "", "foo_bar")
1378
+
1379
+ Args:
1380
+ app_handle: a URI of the form ``{scheduler}://{session_name}/{app_id}``,
1381
+ where the ``session_name`` is optional. In this case the app handle is
1382
+ of the form ``{scheduler}:///{app_id}`` (notice the triple slashes).
1383
+
1384
+ Returns: A ``Tuple`` of three elements, ``(scheduler, session_name, app_id)``
1385
+ parsed from the app_handle URI str. If the session name is not present then
1386
+ an empty string is returned in its place in the tuple.
1387
+
989
1388
  """
990
1389
 
991
1390
  # parse it manually b/c currently torchx does not
992
1391
  # define allowed characters nor length for session name and app_id
993
1392
  import re
994
1393
 
995
- pattern = r"(?P<scheduler_backend>.+)://(?P<session_name>.+)/(?P<app_id>.+)"
1394
+ pattern = r"(?P<scheduler_backend>.+)://(?P<session_name>.*)/(?P<app_id>.+)"
996
1395
  match = re.match(pattern, app_handle)
997
1396
  if not match:
998
1397
  raise MalformedAppHandleException(app_handle)
999
1398
  gd = match.groupdict()
1000
- return gd["scheduler_backend"], gd["session_name"], gd["app_id"]
1399
+ return ParsedAppHandle(gd["scheduler_backend"], gd["session_name"], gd["app_id"])