torchx-nightly 2024.1.6__py3-none-any.whl → 2025.12.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchx-nightly might be problematic. More details are available on the package's registry page.

Files changed (110)
  1. torchx/__init__.py +2 -0
  2. torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
  3. torchx/apps/serve/serve.py +2 -0
  4. torchx/apps/utils/booth_main.py +2 -0
  5. torchx/apps/utils/copy_main.py +2 -0
  6. torchx/apps/utils/process_monitor.py +2 -0
  7. torchx/cli/__init__.py +2 -0
  8. torchx/cli/argparse_util.py +38 -3
  9. torchx/cli/cmd_base.py +2 -0
  10. torchx/cli/cmd_cancel.py +2 -0
  11. torchx/cli/cmd_configure.py +2 -0
  12. torchx/cli/cmd_delete.py +30 -0
  13. torchx/cli/cmd_describe.py +2 -0
  14. torchx/cli/cmd_list.py +8 -4
  15. torchx/cli/cmd_log.py +6 -24
  16. torchx/cli/cmd_run.py +269 -45
  17. torchx/cli/cmd_runopts.py +2 -0
  18. torchx/cli/cmd_status.py +12 -1
  19. torchx/cli/cmd_tracker.py +3 -1
  20. torchx/cli/colors.py +2 -0
  21. torchx/cli/main.py +4 -0
  22. torchx/components/__init__.py +3 -8
  23. torchx/components/component_test_base.py +2 -0
  24. torchx/components/dist.py +18 -7
  25. torchx/components/integration_tests/component_provider.py +4 -2
  26. torchx/components/integration_tests/integ_tests.py +2 -0
  27. torchx/components/serve.py +2 -0
  28. torchx/components/structured_arg.py +4 -3
  29. torchx/components/utils.py +15 -4
  30. torchx/distributed/__init__.py +2 -4
  31. torchx/examples/apps/datapreproc/datapreproc.py +2 -0
  32. torchx/examples/apps/lightning/data.py +5 -3
  33. torchx/examples/apps/lightning/model.py +7 -6
  34. torchx/examples/apps/lightning/profiler.py +7 -4
  35. torchx/examples/apps/lightning/train.py +11 -2
  36. torchx/examples/torchx_out_of_sync_training.py +11 -0
  37. torchx/notebook.py +2 -0
  38. torchx/runner/__init__.py +2 -0
  39. torchx/runner/api.py +167 -60
  40. torchx/runner/config.py +43 -10
  41. torchx/runner/events/__init__.py +57 -13
  42. torchx/runner/events/api.py +14 -3
  43. torchx/runner/events/handlers.py +2 -0
  44. torchx/runtime/tracking/__init__.py +2 -0
  45. torchx/runtime/tracking/api.py +2 -0
  46. torchx/schedulers/__init__.py +16 -15
  47. torchx/schedulers/api.py +70 -14
  48. torchx/schedulers/aws_batch_scheduler.py +75 -6
  49. torchx/schedulers/aws_sagemaker_scheduler.py +598 -0
  50. torchx/schedulers/devices.py +17 -4
  51. torchx/schedulers/docker_scheduler.py +43 -11
  52. torchx/schedulers/ids.py +29 -23
  53. torchx/schedulers/kubernetes_mcad_scheduler.py +9 -7
  54. torchx/schedulers/kubernetes_scheduler.py +383 -38
  55. torchx/schedulers/local_scheduler.py +100 -27
  56. torchx/schedulers/lsf_scheduler.py +5 -4
  57. torchx/schedulers/slurm_scheduler.py +336 -20
  58. torchx/schedulers/streams.py +2 -0
  59. torchx/specs/__init__.py +89 -12
  60. torchx/specs/api.py +418 -30
  61. torchx/specs/builders.py +176 -38
  62. torchx/specs/file_linter.py +143 -57
  63. torchx/specs/finder.py +68 -28
  64. torchx/specs/named_resources_aws.py +181 -4
  65. torchx/specs/named_resources_generic.py +2 -0
  66. torchx/specs/overlays.py +106 -0
  67. torchx/specs/test/components/__init__.py +2 -0
  68. torchx/specs/test/components/a/__init__.py +2 -0
  69. torchx/specs/test/components/a/b/__init__.py +2 -0
  70. torchx/specs/test/components/a/b/c.py +2 -0
  71. torchx/specs/test/components/c/__init__.py +2 -0
  72. torchx/specs/test/components/c/d.py +2 -0
  73. torchx/tracker/__init__.py +12 -6
  74. torchx/tracker/api.py +15 -18
  75. torchx/tracker/backend/fsspec.py +2 -0
  76. torchx/util/cuda.py +2 -0
  77. torchx/util/datetime.py +2 -0
  78. torchx/util/entrypoints.py +39 -15
  79. torchx/util/io.py +2 -0
  80. torchx/util/log_tee_helpers.py +210 -0
  81. torchx/util/modules.py +65 -0
  82. torchx/util/session.py +42 -0
  83. torchx/util/shlex.py +2 -0
  84. torchx/util/strings.py +3 -1
  85. torchx/util/types.py +90 -29
  86. torchx/version.py +4 -2
  87. torchx/workspace/__init__.py +2 -0
  88. torchx/workspace/api.py +136 -6
  89. torchx/workspace/dir_workspace.py +2 -0
  90. torchx/workspace/docker_workspace.py +30 -2
  91. torchx_nightly-2025.12.24.dist-info/METADATA +167 -0
  92. torchx_nightly-2025.12.24.dist-info/RECORD +113 -0
  93. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/WHEEL +1 -1
  94. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/entry_points.txt +0 -1
  95. torchx/examples/pipelines/__init__.py +0 -0
  96. torchx/examples/pipelines/kfp/__init__.py +0 -0
  97. torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -287
  98. torchx/examples/pipelines/kfp/dist_pipeline.py +0 -69
  99. torchx/examples/pipelines/kfp/intro_pipeline.py +0 -81
  100. torchx/pipelines/kfp/__init__.py +0 -28
  101. torchx/pipelines/kfp/adapter.py +0 -271
  102. torchx/pipelines/kfp/version.py +0 -17
  103. torchx/schedulers/gcp_batch_scheduler.py +0 -487
  104. torchx/schedulers/ray/ray_common.py +0 -22
  105. torchx/schedulers/ray/ray_driver.py +0 -307
  106. torchx/schedulers/ray_scheduler.py +0 -453
  107. torchx_nightly-2024.1.6.dist-info/METADATA +0 -176
  108. torchx_nightly-2024.1.6.dist-info/RECORD +0 -118
  109. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info/licenses}/LICENSE +0 -0
  110. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/top_level.txt +0 -0
torchx/specs/api.py CHANGED
@@ -1,16 +1,26 @@
1
- #!/usr/bin/env python3
2
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
3
2
  # All rights reserved.
4
3
  #
5
4
  # This source code is licensed under the BSD-style license found in the
6
5
  # LICENSE file in the root directory of this source tree.
7
6
 
7
+ # pyre-strict
8
+
9
+ import asyncio
8
10
  import copy
11
+ import inspect
9
12
  import json
13
+ import logging as logger
14
+ import os
15
+ import pathlib
10
16
  import re
17
+ import shutil
18
+ import typing
19
+ import warnings
11
20
  from dataclasses import asdict, dataclass, field
12
21
  from datetime import datetime
13
- from enum import Enum
22
+ from enum import Enum, IntEnum
23
+ from json import JSONDecodeError
14
24
  from string import Template
15
25
  from typing import (
16
26
  Any,
@@ -56,6 +66,35 @@ _RPC_ERROR_MESSAGE_RE: Pattern[str] = re.compile(
56
66
  # (most recent call first):
57
67
  _EMBEDDED_ERROR_MESSAGE_RE: Pattern[str] = re.compile(r"(?P<msg>.+)\nException.*")
58
68
 
69
+ YELLOW_BOLD = "\033[1;33m"
70
+ RESET = "\033[0m"
71
+
72
+
73
+ def TORCHX_HOME(*subdir_paths: str) -> pathlib.Path:
74
+ """
75
+ Path to the "dot-directory" for torchx.
76
+ Defaults to `~/.torchx` and is overridable via the `TORCHX_HOME` environment variable.
77
+
78
+ Usage:
79
+
80
+ .. doc-test::
81
+
82
+ from pathlib import Path
83
+ from torchx.specs import TORCHX_HOME
84
+
85
+ assert TORCHX_HOME() == Path.home() / ".torchx"
86
+ assert TORCHX_HOME("conda-pack-out") == Path.home() / ".torchx" / "conda-pack-out"
87
+ ```
88
+ """
89
+
90
+ default_dir = str(pathlib.Path.home() / ".torchx")
91
+ torchx_home = pathlib.Path(os.getenv("TORCHX_HOME", default_dir))
92
+
93
+ torchx_home = torchx_home / os.path.sep.join(subdir_paths)
94
+ torchx_home.mkdir(parents=True, exist_ok=True)
95
+
96
+ return torchx_home
97
+
59
98
 
60
99
  # ========================================
61
100
  # ==== Distributed AppDef API =======
@@ -73,6 +112,8 @@ class Resource:
73
112
  memMB: MB of ram
74
113
  capabilities: additional hardware specs (interpreted by scheduler)
75
114
  devices: a list of named devices with their quantities
115
+ tags: metadata tags for the resource (not interpreted by schedulers)
116
+ used to add non-functional information about resources (e.g. whether it is an alias of another resource)
76
117
 
77
118
  Note: you should prefer to use named_resources instead of specifying the raw
78
119
  resource requirement directly.
@@ -83,6 +124,7 @@ class Resource:
83
124
  memMB: int
84
125
  capabilities: Dict[str, Any] = field(default_factory=dict)
85
126
  devices: Dict[str, int] = field(default_factory=dict)
127
+ tags: Dict[str, object] = field(default_factory=dict)
86
128
 
87
129
  @staticmethod
88
130
  def copy(original: "Resource", **capabilities: Any) -> "Resource":
@@ -91,6 +133,7 @@ class Resource:
91
133
  are present in the original resource and as parameter, the one from parameter
92
134
  will be used.
93
135
  """
136
+
94
137
  res_capabilities = dict(original.capabilities)
95
138
  res_capabilities.update(capabilities)
96
139
  return Resource(
@@ -183,16 +226,48 @@ class macros:
183
226
  apply applies the values to a copy the specified role and returns it.
184
227
  """
185
228
 
229
+ # Overrides might contain future values which can't be serialized so taken out for the copy
230
+ overrides = role.overrides
231
+ if len(overrides) > 0:
232
+ logger.warning(
233
+ "Role overrides are not supported for macros. Overrides will not be copied"
234
+ )
235
+ role.overrides = {}
186
236
  role = copy.deepcopy(role)
237
+ role.overrides = overrides
238
+
187
239
  role.args = [self.substitute(arg) for arg in role.args]
188
240
  role.env = {key: self.substitute(arg) for key, arg in role.env.items()}
241
+ role.metadata = self._apply_nested(role.metadata)
242
+
189
243
  return role
190
244
 
245
+ def _apply_nested(self, d: typing.Dict[str, Any]) -> typing.Dict[str, Any]:
246
+ stack = [d]
247
+ while stack:
248
+ current_dict = stack.pop()
249
+ for k, v in current_dict.items():
250
+ if isinstance(v, dict):
251
+ stack.append(v)
252
+ elif isinstance(v, str):
253
+ current_dict[k] = self.substitute(v)
254
+ elif isinstance(v, list):
255
+ for i in range(len(v)):
256
+ if isinstance(v[i], dict):
257
+ stack.append(v[i])
258
+ elif isinstance(v[i], str):
259
+ v[i] = self.substitute(v[i])
260
+ return d
261
+
262
+ # Overrides the asdict method to generate a dictionary of macro values to be substituted.
263
+ def to_dict(self) -> Dict[str, Any]:
264
+ return asdict(self)
265
+
191
266
  def substitute(self, arg: str) -> str:
192
267
  """
193
268
  substitute applies the values to the template arg.
194
269
  """
195
- return Template(arg).safe_substitute(**asdict(self))
270
+ return Template(arg).safe_substitute(**self.to_dict())
196
271
 
197
272
 
198
273
  class RetryPolicy(str, Enum):
@@ -216,11 +291,13 @@ class RetryPolicy(str, Enum):
216
291
  application to deal with failed replica departures and
217
292
  replacement replica admittance.
218
293
  2. APPLICATION: Restarts the entire application.
219
-
294
+ 3. ROLE: Restarts the role when any error occurs in that role. This does not
295
+ restart the whole job.
220
296
  """
221
297
 
222
298
  REPLICA = "REPLICA"
223
299
  APPLICATION = "APPLICATION"
300
+ ROLE = "ROLE"
224
301
 
225
302
 
226
303
  class MountType(str, Enum):
@@ -277,6 +354,121 @@ class DeviceMount:
277
354
  permissions: str = "rwm"
278
355
 
279
356
 
357
+ @dataclass
358
+ class Workspace:
359
+ """
360
+ Specifies a local "workspace" (a set of directories). Workspaces are ad-hoc built
361
+ into an (usually ephemeral) image. This effectively mirrors the local code changes
362
+ at job submission time.
363
+
364
+ For example:
365
+
366
+ 1. ``projects={"~/github/torch": "torch"}`` copies ``~/github/torch/**`` into ``$REMOTE_WORKSPACE_ROOT/torch/**``
367
+ 2. ``projects={"~/github/torch": ""}`` copies ``~/github/torch/**`` into ``$REMOTE_WORKSPACE_ROOT/**``
368
+
369
+ The exact location of ``$REMOTE_WORKSPACE_ROOT`` is implementation dependent and varies between
370
+ different implementations of :py:class:`~torchx.workspace.api.WorkspaceMixin`.
371
+ Check the scheduler documentation for details on which workspace it supports.
372
+
373
+ Note: ``projects`` maps the location of the local project to a sub-directory in the remote workspace root directory.
374
+ Typically the local project location is a directory path (e.g. ``/home/foo/github/torch``).
375
+
376
+
377
+ Attributes:
378
+ projects: mapping of local project to the sub-dir in the remote workspace dir.
379
+ """
380
+
381
+ projects: dict[str, str]
382
+
383
+ def __bool__(self) -> bool:
384
+ """False if no projects mapping. Lets us use workspace object in an if-statement"""
385
+ return bool(self.projects)
386
+
387
+ def __eq__(self, other: object) -> bool:
388
+ if not isinstance(other, Workspace):
389
+ return False
390
+ return self.projects == other.projects
391
+
392
+ def __hash__(self) -> int:
393
+ # makes it possible to use Workspace as the key in the workspace build cache
394
+ # see WorkspaceMixin.caching_build_workspace_and_update_role
395
+ return hash(frozenset(self.projects.items()))
396
+
397
+ def is_unmapped_single_project(self) -> bool:
398
+ """
399
+ Returns ``True`` if this workspace only has 1 project
400
+ and its target mapping is an empty string.
401
+ """
402
+ return len(self.projects) == 1 and not next(iter(self.projects.values()))
403
+
404
+ def merge_into(self, outdir: str | pathlib.Path) -> None:
405
+ """
406
+ Copies each project dir of this workspace into the specified ``outdir``.
407
+ Each project dir is copied into ``{outdir}/{target}`` where ``target`` is
408
+ the target mapping of the project dir.
409
+
410
+ For example:
411
+
412
+ .. code-block:: python
413
+ from os.path import expanduser
414
+
415
+ workspace = Workspace(
416
+ projects={
417
+ expanduser("~/workspace/torch"): "torch",
418
+ expanduser("~/workspace/my_project": "")
419
+ }
420
+ )
421
+ workspace.merge_into(expanduser("~/tmp"))
422
+
423
+ Copies:
424
+
425
+ * ``~/workspace/torch/**`` into ``~/tmp/torch/**``
426
+ * ``~/workspace/my_project/**`` into ``~/tmp/**``
427
+
428
+ """
429
+
430
+ for src, dst in self.projects.items():
431
+ dst_path = pathlib.Path(outdir) / dst
432
+ if pathlib.Path(src).is_file():
433
+ shutil.copy2(src, dst_path)
434
+ else: # src is dir
435
+ shutil.copytree(src, dst_path, dirs_exist_ok=True)
436
+
437
+ @staticmethod
438
+ def from_str(workspace: str | None) -> "Workspace":
439
+ import yaml
440
+
441
+ if not workspace:
442
+ return Workspace({})
443
+
444
+ projects = yaml.safe_load(workspace)
445
+ if isinstance(projects, str): # single project workspace
446
+ projects = {projects: ""}
447
+ else: # multi-project workspace
448
+ # Replace None mappings with "" (empty string)
449
+ projects = {k: ("" if v is None else v) for k, v in projects.items()}
450
+
451
+ return Workspace(projects)
452
+
453
+ def __str__(self) -> str:
454
+ """
455
+ Returns a string representation of the Workspace by concatenating
456
+ the project mappings using ';' as a delimiter and ':' between key and value.
457
+ If the single-project workspace with no target mapping, then simply
458
+ returns the src (local project dir)
459
+
460
+ NOTE: meant to be used for logging purposes not serde.
461
+ Therefore not symmetric with :py:func:`Workspace.from_str`.
462
+
463
+ """
464
+ if self.is_unmapped_single_project():
465
+ return next(iter(self.projects))
466
+ else:
467
+ return ";".join(
468
+ k if not v else f"{k}:{v}" for k, v in self.projects.items()
469
+ )
470
+
471
+
280
472
  @dataclass
281
473
  class Role:
282
474
  """
@@ -329,12 +521,15 @@ class Role:
329
521
  metadata: Free form information that is associated with the role, for example
330
522
  scheduler specific data. The key should follow the pattern: ``$scheduler.$key``
331
523
  mounts: a list of mounts on the machine
524
+ workspace: local project directories to be mirrored on the remote job.
525
+ NOTE: The workspace argument provided to the :py:class:`~torchx.runner.api.Runner` APIs
526
+ only takes effect on ``appdef.role[0]`` and overrides this attribute.
527
+
332
528
  """
333
529
 
334
530
  name: str
335
531
  image: str
336
532
  min_replicas: Optional[int] = None
337
- base_image: Optional[str] = None # DEPRECATED DO NOT SET, WILL BE REMOVED SOON
338
533
  entrypoint: str = MISSING
339
534
  args: List[str] = field(default_factory=list)
340
535
  env: Dict[str, str] = field(default_factory=dict)
@@ -344,9 +539,28 @@ class Role:
344
539
  resource: Resource = field(default_factory=_null_resource)
345
540
  port_map: Dict[str, int] = field(default_factory=dict)
346
541
  metadata: Dict[str, Any] = field(default_factory=dict)
347
- mounts: List[Union[BindMount, VolumeMount, DeviceMount]] = field(
348
- default_factory=list
349
- )
542
+ mounts: List[BindMount | VolumeMount | DeviceMount] = field(default_factory=list)
543
+ workspace: Workspace | None = None
544
+
545
+ # DEPRECATED DO NOT SET, WILL BE REMOVED SOON
546
+ overrides: Dict[str, Any] = field(default_factory=dict)
547
+
548
+ # pyre-ignore
549
+ def __getattribute__(self, attrname: str) -> Any:
550
+ if attrname == "overrides":
551
+ return super().__getattribute__(attrname)
552
+ try:
553
+ ov = super().__getattribute__("overrides")
554
+ except AttributeError:
555
+ ov = {}
556
+ if attrname in ov:
557
+ if inspect.isawaitable(ov[attrname]):
558
+ result = asyncio.get_event_loop().run_until_complete(ov[attrname])
559
+ else:
560
+ result = ov[attrname]()
561
+ setattr(self, attrname, result)
562
+ ov[attrname] = lambda: result
563
+ return super().__getattribute__(attrname)
350
564
 
351
565
  def pre_proc(
352
566
  self,
@@ -482,6 +696,15 @@ class RoleStatus:
482
696
  role: str
483
697
  replicas: List[ReplicaStatus]
484
698
 
699
+ def to_json(self) -> Dict[str, Any]:
700
+ """
701
+ Convert the RoleStatus to a json object.
702
+ """
703
+ return {
704
+ "role": self.role,
705
+ "replicas": [asdict(replica) for replica in self.replicas],
706
+ }
707
+
485
708
 
486
709
  @dataclass
487
710
  class AppStatus:
@@ -552,7 +775,10 @@ class AppStatus:
552
775
 
553
776
  def _format_replica_status(self, replica_status: ReplicaStatus) -> str:
554
777
  if replica_status.structured_error_msg != NONE:
555
- error_data = json.loads(replica_status.structured_error_msg)
778
+ try:
779
+ error_data = json.loads(replica_status.structured_error_msg)
780
+ except JSONDecodeError:
781
+ return replica_status.structured_error_msg
556
782
  error_message = self._format_error_message(
557
783
  msg=error_data["message"]["message"], header=" error_msg: "
558
784
  )
@@ -598,6 +824,21 @@ class AppStatus:
598
824
  replica_data += self._format_replica_status(replica)
599
825
  return f"{replica_data}"
600
826
 
827
+ def to_json(self, filter_roles: Optional[List[str]] = None) -> Dict[str, Any]:
828
+ """
829
+ Convert the AppStatus to a json object, including RoleStatus.
830
+ """
831
+ roles = self._get_role_statuses(self.roles, filter_roles)
832
+
833
+ return {
834
+ "state": str(self.state),
835
+ "num_restarts": self.num_restarts,
836
+ "roles": [role_status.to_json() for role_status in roles],
837
+ "msg": self.msg,
838
+ "structured_error_msg": self.structured_error_msg,
839
+ "url": self.ui_url,
840
+ }
841
+
601
842
  def format(
602
843
  self,
603
844
  filter_roles: Optional[List[str]] = None,
@@ -613,6 +854,7 @@ class AppStatus:
613
854
  """
614
855
  roles_data = ""
615
856
  roles = self._get_role_statuses(self.roles, filter_roles)
857
+
616
858
  for role_status in roles:
617
859
  roles_data += self._format_role_status(role_status)
618
860
  return Template(_APP_STATUS_FORMAT_TEMPLATE).substitute(
@@ -637,11 +879,11 @@ class AppStatusError(Exception):
637
879
  self.status = status
638
880
 
639
881
 
640
- # valid run cfg values; only support primitives (str, int, float, bool, List[str])
882
+ # valid run cfg values; only support primitives (str, int, float, bool, List[str], Dict[str, str])
641
883
  # TODO(wilsonhong): python 3.9+ supports list[T] in typing, which can be used directly
642
884
  # in isinstance(). Should replace with that.
643
885
  # see: https://docs.python.org/3/library/stdtypes.html#generic-alias-type
644
- CfgVal = Union[str, int, float, bool, List[str], None]
886
+ CfgVal = Union[str, int, float, bool, List[str], Dict[str, str], None]
645
887
 
646
888
 
647
889
  T = TypeVar("T")
@@ -700,6 +942,62 @@ class runopt:
700
942
  opt_type: Type[CfgVal]
701
943
  is_required: bool
702
944
  help: str
945
+ aliases: list[str] | None = None
946
+ deprecated_aliases: list[str] | None = None
947
+
948
+ @property
949
+ def is_type_list_of_str(self) -> bool:
950
+ """
951
+ Checks if the option type is a list of strings.
952
+
953
+ Returns:
954
+ bool: True if the option type is either List[str] or list[str], False otherwise.
955
+ """
956
+ return self.opt_type in (List[str], list[str])
957
+
958
+ @property
959
+ def is_type_dict_of_str(self) -> bool:
960
+ """
961
+ Checks if the option type is a dict of string keys to string values.
962
+
963
+ Returns:
964
+ bool: True if the option type is either Dict[str, str] or dict[str, str], False otherwise.
965
+ """
966
+ return self.opt_type in (Dict[str, str], dict[str, str])
967
+
968
+ def cast_to_type(self, value: str) -> CfgVal:
969
+ """Casts the given `value` (in its string representation) to the type of this run option.
970
+ Below are the cast rules for each option type and value literal:
971
+
972
+ 1. opt_type=str, value="foo" -> "foo"
973
+ 1. opt_type=bool, value="True"/"False" -> True/False
974
+ 1. opt_type=int, value="1" -> 1
975
+ 1. opt_type=float, value="1.1" -> 1.1
976
+ 1. opt_type=list[str]/List[str], value="a,b,c" or value="a;b;c" -> ["a", "b", "c"]
977
+ 1. opt_type=dict[str,str]/Dict[str,str],
978
+ value="key1:val1,key2:val2" or value="key1:val1;key2:val2" -> {"key1": "val1", "key2": "val2"}
979
+
980
+ NOTE: dict parsing uses ":" as the kv separator (rather than the standard "=") because "=" is used
981
+ at the top-level cfg to parse runopts (notice the plural) from the CLI. Originally torchx only supported
982
+ primitives and list[str] as CfgVal but dict[str,str] was added in https://github.com/meta-pytorch/torchx/pull/855
983
+ """
984
+
985
+ if self.opt_type is None:
986
+ raise ValueError("runopt's opt_type cannot be `None`")
987
+ elif self.opt_type == bool:
988
+ return value.lower() == "true"
989
+ elif self.opt_type in (List[str], list[str]):
990
+ # lists may be ; or , delimited
991
+ # also deal with trailing "," by removing empty strings
992
+ return [v for v in value.replace(";", ",").split(",") if v]
993
+ elif self.opt_type in (Dict[str, str], dict[str, str]):
994
+ return {
995
+ s.split(":", 1)[0]: s.split(":", 1)[1]
996
+ for s in value.replace(";", ",").split(",")
997
+ }
998
+ else:
999
+ assert self.opt_type in (str, int, float)
1000
+ return self.opt_type(value)
703
1001
 
704
1002
 
705
1003
  class runopts:
@@ -737,6 +1035,7 @@ class runopts:
737
1035
 
738
1036
  def __init__(self) -> None:
739
1037
  self._opts: Dict[str, runopt] = {}
1038
+ self._alias_to_key: dict[str, str] = {}
740
1039
 
741
1040
  def __iter__(self) -> Iterator[Tuple[str, runopt]]:
742
1041
  return self._opts.items().__iter__()
@@ -755,14 +1054,25 @@ class runopts:
755
1054
  except TypeError:
756
1055
  if isinstance(obj, list):
757
1056
  return all(isinstance(e, str) for e in obj)
1057
+ elif isinstance(obj, dict):
1058
+ return all(
1059
+ isinstance(k, str) and isinstance(v, str) for k, v in obj.items()
1060
+ )
758
1061
  else:
759
1062
  return False
760
1063
 
761
1064
  def get(self, name: str) -> Optional[runopt]:
762
1065
  """
763
- Returns option if any was registered, or None otherwise
1066
+ Returns option if any was registered, or None otherwise.
1067
+ First searches for the option by ``name``, then falls-back to matching ``name`` with any
1068
+ registered aliases.
1069
+
764
1070
  """
765
- return self._opts.get(name, None)
1071
+ if name in self._opts:
1072
+ return self._opts[name]
1073
+ if name in self._alias_to_key:
1074
+ return self._opts[self._alias_to_key[name]]
1075
+ return None
766
1076
 
767
1077
  def resolve(self, cfg: Mapping[str, CfgVal]) -> Dict[str, CfgVal]:
768
1078
  """
@@ -777,6 +1087,36 @@ class runopts:
777
1087
 
778
1088
  for cfg_key, runopt in self._opts.items():
779
1089
  val = resolved_cfg.get(cfg_key)
1090
+ resolved_name = None
1091
+ aliases = runopt.aliases or []
1092
+ deprecated_aliases = runopt.deprecated_aliases or []
1093
+ if val is None:
1094
+ for alias in aliases:
1095
+ val = resolved_cfg.get(alias)
1096
+ if alias in cfg or val is not None:
1097
+ resolved_name = alias
1098
+ break
1099
+ for alias in deprecated_aliases:
1100
+ val = resolved_cfg.get(alias)
1101
+ if val is not None:
1102
+ resolved_name = alias
1103
+ use_instead = self._alias_to_key.get(alias)
1104
+ warnings.warn(
1105
+ f"Run option `{alias}` is deprecated, use `{use_instead}` instead",
1106
+ UserWarning,
1107
+ stacklevel=2,
1108
+ )
1109
+ break
1110
+ else:
1111
+ resolved_name = cfg_key
1112
+ for alias in aliases:
1113
+ duplicate_val = resolved_cfg.get(alias)
1114
+ if alias in cfg or duplicate_val is not None:
1115
+ raise InvalidRunConfigException(
1116
+ f"Duplicate opt name. runopt: `{resolved_name}``, is an alias of runopt: `{alias}`",
1117
+ resolved_name,
1118
+ cfg,
1119
+ )
780
1120
 
781
1121
  # check required opt
782
1122
  if runopt.is_required and val is None:
@@ -796,7 +1136,7 @@ class runopts:
796
1136
  )
797
1137
 
798
1138
  # not required and not set, set to default
799
- if val is None:
1139
+ if val is None and resolved_name is None:
800
1140
  resolved_cfg[cfg_key] = runopt.default
801
1141
  return resolved_cfg
802
1142
 
@@ -856,22 +1196,37 @@ class runopts:
856
1196
 
857
1197
  """
858
1198
 
859
- def _cast_to_type(value: str, opt_type: Type[CfgVal]) -> CfgVal:
860
- if opt_type == bool:
861
- return value.lower() == "true"
862
- elif opt_type == List[str]:
863
- # lists may be ; or , delimited
864
- # also deal with trailing "," by removing empty strings
865
- return [v for v in value.replace(";", ",").split(",") if v]
1199
+ cfg: Dict[str, CfgVal] = {}
1200
+ for key, val in to_dict(cfg_str).items():
1201
+ opt = self.get(key)
1202
+ if opt:
1203
+ cfg[key] = opt.cast_to_type(val)
866
1204
  else:
867
- # pyre-ignore[19]
868
- return opt_type(value)
1205
+ logger.warning(
1206
+ f"{YELLOW_BOLD}Unknown run option passed to scheduler: {key}={val}{RESET}"
1207
+ )
1208
+ return cfg
869
1209
 
1210
+ def cfg_from_json_repr(self, json_repr: str) -> Dict[str, CfgVal]:
1211
+ """
1212
+ Converts the given dict to a valid cfg for this ``runopts`` object.
1213
+ """
870
1214
  cfg: Dict[str, CfgVal] = {}
871
- for key, val in to_dict(cfg_str).items():
872
- runopt_ = self.get(key)
873
- if runopt_:
874
- cfg[key] = _cast_to_type(val, runopt_.opt_type)
1215
+ cfg_dict = json.loads(json_repr)
1216
+ for key, val in cfg_dict.items():
1217
+ opt = self.get(key)
1218
+ if opt:
1219
+ # Optional runopt cfg values default their value to None,
1220
+ # but use `_type` to specify their type when provided.
1221
+ # Make sure not to treat None's as lists/dictionaries
1222
+ if val is None:
1223
+ cfg[key] = val
1224
+ elif opt.is_type_list_of_str:
1225
+ cfg[key] = [str(v) for v in val]
1226
+ elif opt.is_type_dict_of_str:
1227
+ cfg[key] = {str(k): str(v) for k, v in val.items()}
1228
+ else:
1229
+ cfg[key] = val
875
1230
  return cfg
876
1231
 
877
1232
  def add(
@@ -881,12 +1236,16 @@ class runopts:
881
1236
  help: str,
882
1237
  default: CfgVal = None,
883
1238
  required: bool = False,
1239
+ aliases: Optional[list[str]] = None,
1240
+ deprecated_aliases: Optional[list[str]] = None,
884
1241
  ) -> None:
885
1242
  """
886
1243
  Adds the ``config`` option with the given help string and ``default``
887
1244
  value (if any). If the ``default`` is not specified then this option
888
1245
  is a required option.
889
1246
  """
1247
+ aliases = aliases or []
1248
+ deprecated_aliases = deprecated_aliases or []
890
1249
  if required and default is not None:
891
1250
  raise ValueError(
892
1251
  f"Required option: {cfg_key} must not specify default value. Given: {default}"
@@ -898,7 +1257,19 @@ class runopts:
898
1257
  f" Given: {default} ({type(default).__name__})"
899
1258
  )
900
1259
 
901
- self._opts[cfg_key] = runopt(default, type_, required, help)
1260
+ opt = runopt(
1261
+ default,
1262
+ type_,
1263
+ required,
1264
+ help,
1265
+ list(set(aliases)),
1266
+ list(set(deprecated_aliases)),
1267
+ )
1268
+ for alias in aliases:
1269
+ self._alias_to_key[alias] = cfg_key
1270
+ for deprecated_alias in deprecated_aliases:
1271
+ self._alias_to_key[deprecated_alias] = cfg_key
1272
+ self._opts[cfg_key] = opt
902
1273
 
903
1274
  def update(self, other: "runopts") -> None:
904
1275
  self._opts.update(other._opts)
@@ -996,14 +1367,31 @@ class UnknownAppException(Exception):
996
1367
 
997
1368
  def parse_app_handle(app_handle: AppHandle) -> ParsedAppHandle:
998
1369
  """
999
- parses the app handle into ```(scheduler_backend, session_name, and app_id)```
1370
+ Parses the app handle into ```(scheduler_backend, session_name, and app_id)```.
1371
+
1372
+ Example:
1373
+
1374
+ .. doctest::
1375
+
1376
+ assert parse_app_handle("k8s://default/foo_bar") == ("k8s", "default", "foo_bar")
1377
+ assert parse_app_handle("k8s:///foo_bar") == ("k8s", "", "foo_bar")
1378
+
1379
+ Args:
1380
+ app_handle: a URI of the form ``{scheduler}://{session_name}/{app_id}``,
1381
+ where the ``session_name`` is optional. In this case the app handle is
1382
+ of the form ``{scheduler}:///{app_id}`` (notice the triple slashes).
1383
+
1384
+ Returns: A ``Tuple`` of three elements, ``(scheduler, session_name, app_id)``
1385
+ parsed from the app_handle URI str. If the session name is not present then
1386
+ an empty string is returned in its place in the tuple.
1387
+
1000
1388
  """
1001
1389
 
1002
1390
  # parse it manually b/c currently torchx does not
1003
1391
  # define allowed characters nor length for session name and app_id
1004
1392
  import re
1005
1393
 
1006
- pattern = r"(?P<scheduler_backend>.+)://(?P<session_name>.+)/(?P<app_id>.+)"
1394
+ pattern = r"(?P<scheduler_backend>.+)://(?P<session_name>.*)/(?P<app_id>.+)"
1007
1395
  match = re.match(pattern, app_handle)
1008
1396
  if not match:
1009
1397
  raise MalformedAppHandleException(app_handle)