torchx-nightly 2025.9.21__py3-none-any.whl → 2025.9.23__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of torchx-nightly might be problematic.

torchx/cli/cmd_list.py CHANGED
@@ -33,8 +33,7 @@ class CmdList(SubCommand):
             type=str,
             default=get_default_scheduler_name(),
             choices=list(scheduler_names),
-            help=f"Name of the scheduler to use. One of: [{','.join(scheduler_names)}]."
-            " For listing app handles for ray scheduler, RAY_ADDRESS env variable should be set.",
+            help=f"Name of the scheduler to use. One of: [{','.join(scheduler_names)}].",
         )
 
     def run(self, args: argparse.Namespace) -> None:
torchx/components/integration_tests/component_provider.py CHANGED
@@ -109,7 +109,7 @@ class CopyComponentProvider(ComponentProvider):
             self._dst_path = "<None>"
 
     def setUp(self) -> None:
-        if self._scheduler in ["local_cwd", "ray"]:
+        if self._scheduler in ["local_cwd"]:
             fname = "torchx_copy_test.txt"
             self._src_path: str = os.path.join(tempfile.gettempdir(), fname)
             self._dst_path: str = os.path.join(tempfile.gettempdir(), f"{fname}.copy")
@@ -126,7 +126,7 @@ class CopyComponentProvider(ComponentProvider):
     def tearDown(self) -> None:
         if os.path.exists(self._dst_path):
             os.remove(self._dst_path)
-        if self._scheduler in ["local_cwd", "ray"] and os.path.exists(self._dst_path):
+        if self._scheduler in ["local_cwd"] and os.path.exists(self._dst_path):
             os.remove(self._dst_path)
 
     def get_app_def(self) -> AppDef:
torchx/schedulers/__init__.py CHANGED
@@ -21,7 +21,6 @@ DEFAULT_SCHEDULER_MODULES: Mapping[str, str] = {
     "kubernetes_mcad": "torchx.schedulers.kubernetes_mcad_scheduler",
     "aws_batch": "torchx.schedulers.aws_batch_scheduler",
     "aws_sagemaker": "torchx.schedulers.aws_sagemaker_scheduler",
-    "gcp_batch": "torchx.schedulers.gcp_batch_scheduler",
     "lsf": "torchx.schedulers.lsf_scheduler",
 }
 
torchx/util/strings.py CHANGED
@@ -13,7 +13,7 @@ def normalize_str(data: str) -> str:
     """
     Invokes ``lower`` on thes string and removes all
     characters that do not satisfy ``[a-z0-9\\-]`` pattern.
-    This method is mostly used to make sure kubernetes and gcp_batch scheduler gets
+    This method is mostly used to make sure kubernetes scheduler gets
     the job name that does not violate its restrictions.
     """
     if data.startswith("-"):
torchx_nightly-2025.9.23.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: torchx-nightly
-Version: 2025.9.21
+Version: 2025.9.23
 Summary: TorchX SDK and Components
 Home-page: https://github.com/pytorch/torchx
 Author: TorchX Devs
@@ -22,7 +22,6 @@ Requires-Dist: pyyaml
 Requires-Dist: docker
 Requires-Dist: filelock
 Requires-Dist: fsspec>=2023.10.0
-Requires-Dist: urllib3<1.27,>=1.21.1
 Requires-Dist: tabulate
 Provides-Extra: aws_batch
 Requires-Dist: boto3; extra == "aws-batch"
@@ -36,9 +35,6 @@ Requires-Dist: kubernetes==25.3.0; extra == "dev"
 Requires-Dist: flake8==3.9.0; extra == "dev"
 Requires-Dist: fsspec==2024.3.1; extra == "dev"
 Requires-Dist: s3fs==2024.3.1; extra == "dev"
-Requires-Dist: google-cloud-batch==0.17.14; extra == "dev"
-Requires-Dist: google-cloud-logging==3.10.0; extra == "dev"
-Requires-Dist: google-cloud-runtimeconfig==0.34.0; extra == "dev"
 Requires-Dist: hydra-core; extra == "dev"
 Requires-Dist: ipython; extra == "dev"
 Requires-Dist: mlflow-skinny; extra == "dev"
@@ -61,14 +57,6 @@ Requires-Dist: ts==0.5.1; extra == "dev"
 Requires-Dist: wheel; extra == "dev"
 Requires-Dist: lintrunner; extra == "dev"
 Requires-Dist: lintrunner-adapters; extra == "dev"
-Requires-Dist: grpcio==1.62.1; extra == "dev"
-Requires-Dist: grpcio-status==1.48.1; extra == "dev"
-Requires-Dist: googleapis-common-protos==1.63.0; extra == "dev"
-Requires-Dist: google-api-core==2.18.0; extra == "dev"
-Provides-Extra: gcp_batch
-Requires-Dist: google-cloud-batch>=0.5.0; extra == "gcp-batch"
-Requires-Dist: google-cloud-logging>=3.0.0; extra == "gcp-batch"
-Requires-Dist: google-cloud-runtimeconfig>=0.33.2; extra == "gcp-batch"
 Provides-Extra: kubernetes
 Requires-Dist: kubernetes>=11; extra == "kubernetes"
 
@@ -93,7 +81,6 @@ TorchX currently supports:
 * AWS Batch
 * Docker
 * Local
-* GCP Batch (prototype)
 
 Need a scheduler not listed? [Let us know!](https://github.com/pytorch/torchx/issues?q=is%3Aopen+is%3Aissue+label%3Ascheduler-request)
 
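The net effect of the metadata hunks is that the wheel no longer declares a gcp_batch extra (or its google-cloud/grpc dev pins) and drops the urllib3 pin. A quick standard-library check of what an installed wheel still declares, assuming torchx-nightly is installed in the current environment:

from importlib.metadata import metadata

md = metadata("torchx-nightly")
extras = md.get_all("Provides-Extra") or []
print(sorted(extras))          # "gcp_batch" is gone as of 2025.9.23
print("gcp_batch" in extras)   # False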
torchx_nightly-2025.9.23.dist-info/RECORD CHANGED
@@ -14,7 +14,7 @@ torchx/cli/cmd_base.py,sha256=SdqMtqi04CEqnzcgcS35DbDbsBeMxSgEhfynfpIkMGk,790
 torchx/cli/cmd_cancel.py,sha256=NKfOCu_44Lch9vliGSQ0Uv6BVqpUqj7Tob652TI-ua4,835
 torchx/cli/cmd_configure.py,sha256=1kTv0qbsbV44So74plAySwWu56pQrqjhfW_kbfdC3Rw,1722
 torchx/cli/cmd_describe.py,sha256=E5disbHoKTsqYKp2s3DaFW9GDLCCOgdOc3pQoHKoyCs,1283
-torchx/cli/cmd_list.py,sha256=4Y1ZOq-kqJbztoBt56hAW_InJEaJuDAjpKWgMhBw4II,1507
+torchx/cli/cmd_list.py,sha256=alkS9aIaDI8lX3W8uj8Vtr3IU3G2VeCuokKSd3zOFug,1409
 torchx/cli/cmd_log.py,sha256=v-EZYUDOcG95rEgTnrsmPJMUyxM9Mk8YFAJtUxtgViE,5475
 torchx/cli/cmd_run.py,sha256=TshvEMTxMRj5O0KhetzHepZUaAFq8R5nFgY8GC_Gl6g,18576
 torchx/cli/cmd_runopts.py,sha256=NWZiP8XpQjfTDJgays2c6MgL_8wxFoeDge6NstaZdKk,1302
@@ -32,7 +32,7 @@ torchx/components/structured_arg.py,sha256=8jMcd0rtUmzCKEQKJ_JYzxSkMMK9q0fYjkwAs
 torchx/components/train.py,sha256=vtrQXRcD7bIcbb3lSeyD9BBlIe1mv1WNW6rnLK9R0Mw,1259
 torchx/components/utils.py,sha256=QRBxBm1OnNhOhpPs0lKdbJ8_mNhWYMklY6cl1gPIw9A,9363
 torchx/components/integration_tests/__init__.py,sha256=Md3cCHD7Ano9kV15PqGbicgUO-RMdh4aVy1yKiDt_xE,208
-torchx/components/integration_tests/component_provider.py,sha256=cFNGqmclcZTJlOW_YGf5XEuGeWloTmcJEAh02Aob_PQ,3995
+torchx/components/integration_tests/component_provider.py,sha256=g-4ig1vtd5Vzgug0VAKRAFUt6KAV3TgQrBCrwRSJ7ZY,3981
 torchx/components/integration_tests/integ_tests.py,sha256=O8jd8Jq5O0mns7xzIFsHexBDHkIIAIfELQkWCzNPzRw,5165
 torchx/distributed/__init__.py,sha256=lobebigfujmRTe_SJw07_a9iohBxDhq2iiPsV1YcKjw,10247
 torchx/examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -56,13 +56,12 @@ torchx/runner/events/handlers.py,sha256=ThHCIJW21BfBgB7b6ftyjASJmD1KdizpjuTtsyqn
 torchx/runtime/__init__.py,sha256=Wxje2BryzeQneFu5r6P9JJiEKG-_C9W1CcZ_JNrKT6g,593
 torchx/runtime/tracking/__init__.py,sha256=dYnAPnrXYREfPXkpHhdOFkcYIODWEbA13PdD-wLQYBo,3055
 torchx/runtime/tracking/api.py,sha256=SmUQyUKZqG3KlAhT7CJOGqRz1O274E4m63wQeOVq3CU,5472
-torchx/schedulers/__init__.py,sha256=hliMsZHZNOKue0uTHUWxvO0V7xsKApBxN4Wb_9L0Mz4,2253
+torchx/schedulers/__init__.py,sha256=_Wx6-X3FNh8RJR82UGgUwKg7V_VQYsAkrveDoSSk2xU,2195
 torchx/schedulers/api.py,sha256=lfxNhrEO6eYYqVuQzzj9sTXrZShuZkyYxJ1jPE-Lvpo,14561
 torchx/schedulers/aws_batch_scheduler.py,sha256=hFxYzSZEK2SVS5sEyQC5YvNI0JJUJUQsWORlYpj_h3M,28105
 torchx/schedulers/aws_sagemaker_scheduler.py,sha256=flN8GumKE2Dz4X_foAt6Jnvt-ZVojWs6pcyrHwB0hz0,20921
 torchx/schedulers/devices.py,sha256=RjVcu22ZRl_9OKtOtmA1A3vNXgu2qD6A9ST0L0Hsg4I,1734
 torchx/schedulers/docker_scheduler.py,sha256=xuK00-dB6o8TV1YaZox7O5P09LHB2KeQ6t4eiNtqMYQ,16781
-torchx/schedulers/gcp_batch_scheduler.py,sha256=JQuaEJVL_7NSa9AeUc_0Qo74XZNJk_kp6XwgunvlUKI,16281
 torchx/schedulers/ids.py,sha256=3E-_vwVYC-8Tv8kjuY9-W7TbOe_-Laqd8a65uIN3hQY,1798
 torchx/schedulers/kubernetes_mcad_scheduler.py,sha256=1tuzq3OutCMdSPqg_dNmCHt_wyuSFKG0-ywLc3qITJo,42949
 torchx/schedulers/kubernetes_scheduler.py,sha256=0_loGJ7WnxEr9dhgFt3Gw-7nVLirMDVN-MAFTCq7erE,28217
@@ -97,15 +96,15 @@ torchx/util/log_tee_helpers.py,sha256=wPyozmh9BOt_2d3Gxa0iNogwnjzwFitIIMBJOJ1arI
 torchx/util/modules.py,sha256=o4y_d07gTpJ4nIVBcoUVJ0JtXIHEsEC5kbgBM6NGpgA,2135
 torchx/util/session.py,sha256=r6M_nyzXgcbk1GgYGZ324F_ehRGCqjjdVk4YgKxMj8M,1214
 torchx/util/shlex.py,sha256=eXEKu8KC3zIcd8tEy9_s8Ds5oma8BORr-0VGWNpG2dk,463
-torchx/util/strings.py,sha256=GkLWCmYS89Uv6bWc5hH0XwvHy7oQmprv2U7axC4A2e8,678
+torchx/util/strings.py,sha256=7Ef1loz2IYMrzeJ6Lewywi5cBIc3X3g7lSPbT1Tn_z4,664
 torchx/util/types.py,sha256=E9dxAWQnsJkIDuHtg-poeOJ4etucSI_xP_Z5kNJX8uI,9229
 torchx/workspace/__init__.py,sha256=cZsKVvUWwDYcGhe6SCXQGBQfbk_yTnKEImOkI6xmu30,809
 torchx/workspace/api.py,sha256=Ct_75VU94fsH9Rf1WRe-wJGpVgl5O05S_Dq_t2ArJWA,11348
 torchx/workspace/dir_workspace.py,sha256=npNW_IjUZm_yS5r-8hrRkH46ndDd9a_eApT64m1S1T4,2268
 torchx/workspace/docker_workspace.py,sha256=PFu2KQNVC-0p2aKJ-W_BKA9ZOmXdCY2ABEkCExp3udQ,10269
-torchx_nightly-2025.9.21.dist-info/LICENSE,sha256=WVHfXhFC0Ia8LTKt_nJVYobdqTJVg_4J3Crrfm2A8KQ,1721
-torchx_nightly-2025.9.21.dist-info/METADATA,sha256=0KvbX8m2uQZVgOBL_JiKB8nVyXDWBCVFRAypM61NWQU,5693
-torchx_nightly-2025.9.21.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-torchx_nightly-2025.9.21.dist-info/entry_points.txt,sha256=T328AMXeKI3JZnnxfkEew2ZcMN1oQDtkXjMz7lkV-P4,169
-torchx_nightly-2025.9.21.dist-info/top_level.txt,sha256=pxew3bc2gsiViS0zADs0jb6kC5v8o_Yy_85fhHj_J1A,7
-torchx_nightly-2025.9.21.dist-info/RECORD,,
+torchx_nightly-2025.9.23.dist-info/LICENSE,sha256=WVHfXhFC0Ia8LTKt_nJVYobdqTJVg_4J3Crrfm2A8KQ,1721
+torchx_nightly-2025.9.23.dist-info/METADATA,sha256=Dsh27u65MAyHzYM5eRSPfYEQZmnh9qfsvs1_0vWkhCo,5003
+torchx_nightly-2025.9.23.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+torchx_nightly-2025.9.23.dist-info/entry_points.txt,sha256=T328AMXeKI3JZnnxfkEew2ZcMN1oQDtkXjMz7lkV-P4,169
+torchx_nightly-2025.9.23.dist-info/top_level.txt,sha256=pxew3bc2gsiViS0zADs0jb6kC5v8o_Yy_85fhHj_J1A,7
+torchx_nightly-2025.9.23.dist-info/RECORD,,
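For readers unfamiliar with RECORD: each row is path, sha256 digest, and size in bytes (the standard wheel RECORD format), so the hash and size churn above is exactly the file-level footprint of this release's changes. A minimal stdlib parse of one row, using a path from the diff:

import csv
import io

row = "torchx/util/strings.py,sha256=7Ef1loz2IYMrzeJ6Lewywi5cBIc3X3g7lSPbT1Tn_z4,664"
path, digest, size = next(csv.reader(io.StringIO(row)))
# 664 bytes, 14 fewer than before: the length of the removed " and gcp_batch".
print(path, digest.removeprefix("sha256="), int(size))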
torchx/schedulers/gcp_batch_scheduler.py DELETED
@@ -1,497 +0,0 @@
-#!/usr/bin/env python3
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-strict
-
-"""
-
-This contains the TorchX GCP Batch scheduler which can be used to run TorchX
-components directly on GCP Batch.
-
-This scheduler is in prototype stage and may change without notice.
-
-Prerequisites
-==============
-
-You need to have a GCP project configured to use Batch by enabling and setting it up.
-See https://cloud.google.com/batch/docs/get-started#prerequisites
-
-"""
-
-from dataclasses import dataclass
-from datetime import datetime
-from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, TypedDict
-
-import torchx
-import yaml
-
-from torchx.schedulers.api import (
-    DescribeAppResponse,
-    ListAppResponse,
-    Scheduler,
-    Stream,
-)
-from torchx.schedulers.ids import make_unique
-from torchx.specs.api import (
-    AppDef,
-    AppDryRunInfo,
-    AppState,
-    macros,
-    Resource,
-    Role,
-    runopts,
-)
-from torchx.util.strings import normalize_str
-
-
-if TYPE_CHECKING:
-    from google.cloud import batch_v1
-
-
-JOB_STATE: Dict[str, AppState] = {
-    "STATE_UNSPECIFIED": AppState.UNKNOWN,
-    "QUEUED": AppState.SUBMITTED,
-    "SCHEDULED": AppState.PENDING,
-    "RUNNING": AppState.RUNNING,
-    "SUCCEEDED": AppState.SUCCEEDED,
-    "FAILED": AppState.FAILED,
-    "DELETION_IN_PROGRESS": AppState.UNKNOWN,
-}
-
-GPU_COUNT_TO_TYPE: Dict[int, str] = {
-    1: "a2-highgpu-1g",
-    2: "a2-highgpu-2g",
-    4: "a2-highgpu-4g",
-    8: "a2-highgpu-8g",
-    16: "a2-highgpu-16g",
-}
-
-GPU_TYPE_TO_COUNT: Dict[str, int] = {v: k for k, v in GPU_COUNT_TO_TYPE.items()}
-
-LABEL_VERSION: str = "torchx_version"
-LABEL_APP_NAME: str = "torchx_app_name"
-
-DEFAULT_LOC: str = "us-central1"
-
-# TODO Remove LOCATIONS list once Batch supports all locations
-# or when there is an API to query locations supported by Batch
-LOCATIONS: List[str] = [
-    DEFAULT_LOC,
-    "us-west1",
-    "us-east1",
-    "asia-southeast1",
-    "europe-north1",
-    "europe-west6",
-]
-
-BATCH_LOGGER_NAME = "batch_task_logs"
-
-
-@dataclass
-class GCPBatchJob:
-    name: str
-    project: str
-    location: str
-    job_def: "batch_v1.Job"
-
-    def __str__(self) -> str:
-        return yaml.dump(self.job_def)
-
-    def __repr__(self) -> str:
-        return str(self)
-
-
-class GCPBatchOpts(TypedDict, total=False):
-    project: Optional[str]
-    location: Optional[str]
-
-
-class GCPBatchScheduler(Scheduler[GCPBatchOpts, AppDef, AppDryRunInfo[GCPBatchJob]]):
-    """
-    GCPBatchScheduler is a TorchX scheduling interface to GCP Batch.
-
-    .. code-block:: bash
-
-        $ pip install torchx[gcp_batch]
-        $ torchx run --scheduler gcp_batch utils.echo --msg hello
-        # This launches a job with app handle like gcp_batch://torchx/project:location:app_id1234 and prints it
-        $ torchx status gcp_batch://torchx/project:location:app_id1234
-        ...
-
-    Authentication is loaded from the environment using the gcloud credential handling.
-
-    **Config Options**
-
-    .. runopts::
-        class: torchx.schedulers.gcp_batch_scheduler.create_scheduler
-
-    **Compatibility**
-
-    .. compatibility::
-        type: scheduler
-        features:
-            cancel: true
-            logs: true
-            describe: true
-            distributed: true
-            workspaces: false
-            mounts: false
-            elasticity: false
-
-    """
-
-    def __init__(
-        self,
-        session_name: str,
-        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-        client: Optional[Any] = None,
-    ) -> None:
-        # NOTE: make sure any new init options are supported in create_scheduler(...)
-        Scheduler.__init__(self, "gcp_batch", session_name)
-        # pyre-fixme[4]: Attribute annotation cannot be `Any`.
-        self.__client = client
-
-    @property
-    # pyre-fixme[3]: Return annotation cannot be `Any`.
-    def _client(self) -> Any:
-        from google.api_core import gapic_v1
-        from google.cloud import batch_v1
-
-        c = self.__client
-        if c is None:
-            client_info = gapic_v1.client_info.ClientInfo(
-                user_agent=f"TorchX/{torchx.__version__}"
-            )
-            c = self.__client = batch_v1.BatchServiceClient(client_info=client_info)
-        return c
-
-    def schedule(self, dryrun_info: AppDryRunInfo[GCPBatchJob]) -> str:
-        from google.cloud import batch_v1
-
-        req = dryrun_info.request
-        assert req is not None, f"{dryrun_info} missing request"
-
-        request = batch_v1.CreateJobRequest(
-            parent=f"projects/{req.project}/locations/{req.location}",
-            job=req.job_def,
-            job_id=req.name,
-        )
-
-        response = self._client.create_job(request=request)
-        return f"{req.project}:{req.location}:{req.name}"
-
-    def _app_to_job(self, app: AppDef) -> "batch_v1.Job":
-        from google.cloud import batch_v1
-
-        name = normalize_str(make_unique(app.name))
-
-        taskGroups = []
-        allocationPolicy = None
-
-        # 1. Convert role to task
-        # TODO implement retry_policy, mount conversion
-        # NOTE: Supports only one role for now as GCP Batch supports only one TaskGroup
-        # which is ok to start with as most components have only one role
-        for role_idx, role in enumerate(app.roles):
-            values = macros.Values(
-                img_root="",
-                app_id=name,
-                replica_id=str(0),
-                rank0_env=("BATCH_MAIN_NODE_HOSTNAME"),
-            )
-            role_dict = values.apply(role)
-            role_dict.env["TORCHX_ROLE_IDX"] = str(role_idx)
-            role_dict.env["TORCHX_ROLE_NAME"] = str(role.name)
-
-            resource = role_dict.resource
-            res = batch_v1.ComputeResource()
-            cpu = resource.cpu
-            if cpu <= 0:
-                cpu = 1
-            MILLI = 1000
-            res.cpu_milli = cpu * MILLI
-            memMB = resource.memMB
-            if memMB < 0:
-                raise ValueError(
-                    f"memMB should to be set to a positive value, got {memMB}"
-                )
-            res.memory_mib = memMB
-
-            # TODO support named resources
-            # Using v100 as default GPU type as a100 does not allow changing count for now
-            # TODO See if there is a better default GPU type
-            if resource.gpu > 0:
-                if resource.gpu not in GPU_COUNT_TO_TYPE:
-                    raise ValueError(
-                        f"gpu should to be set to one of these values: {GPU_COUNT_TO_TYPE.keys()}"
-                    )
-                machineType = GPU_COUNT_TO_TYPE[resource.gpu]
-                allocationPolicy = batch_v1.AllocationPolicy(
-                    instances=[
-                        batch_v1.AllocationPolicy.InstancePolicyOrTemplate(
-                            install_gpu_drivers=True,
-                            policy=batch_v1.AllocationPolicy.InstancePolicy(
-                                machine_type=machineType,
-                            ),
-                        )
-                    ],
-                )
-                print(f"Using GPUs of type: {machineType}")
-
-            # Configure host firewall rules to accept ingress communication
-            config_network_runnable = batch_v1.Runnable(
-                script=batch_v1.Runnable.Script(
-                    text="/sbin/iptables -A INPUT -j ACCEPT"
-                )
-            )
-
-            runnable = batch_v1.Runnable(
-                container=batch_v1.Runnable.Container(
-                    image_uri=role_dict.image,
-                    commands=[role_dict.entrypoint] + role_dict.args,
-                    entrypoint="",
-                    # Configure docker to use the host network stack to communicate with containers/other hosts in the same network
-                    options="--net host",
-                )
-            )
-
-            ts = batch_v1.TaskSpec(
-                runnables=[config_network_runnable, runnable],
-                environment=batch_v1.Environment(variables=role_dict.env),
-                max_retry_count=role_dict.max_retries,
-                compute_resource=res,
-            )
-
-            task_env = [
-                batch_v1.Environment(variables={"TORCHX_REPLICA_IDX": str(i)})
-                for i in range(role_dict.num_replicas)
-            ]
-
-            tg = batch_v1.TaskGroup(
-                task_spec=ts,
-                task_count=role_dict.num_replicas,
-                task_count_per_node=1,
-                task_environments=task_env,
-                require_hosts_file=True,
-            )
-            taskGroups.append(tg)
-
-        # 2. Convert AppDef to Job
-        job = batch_v1.Job(
-            name=name,
-            task_groups=taskGroups,
-            allocation_policy=allocationPolicy,
-            logs_policy=batch_v1.LogsPolicy(
-                destination=batch_v1.LogsPolicy.Destination.CLOUD_LOGGING,
-            ),
-            # NOTE: GCP Batch does not allow label names with "."
-            labels={
-                LABEL_VERSION: torchx.__version__.replace(".", "-"),
-                LABEL_APP_NAME: name,
-            },
-        )
-        return job
-
-    def _get_project(self) -> str:
-        from google.cloud import runtimeconfig
-
-        return runtimeconfig.Client().project
-
-    def _submit_dryrun(
-        self, app: AppDef, cfg: GCPBatchOpts
-    ) -> AppDryRunInfo[GCPBatchJob]:
-        proj = cfg.get("project")
-        if proj is None:
-            proj = self._get_project()
-        assert proj is not None and isinstance(proj, str), "project must be a str"
-
-        loc = cfg.get("location")
-        assert loc is not None and isinstance(loc, str), "location must be a str"
-
-        job = self._app_to_job(app)
-
-        # Convert JobDef + BatchOpts to GCPBatchJob
-        req = GCPBatchJob(
-            name=str(job.name),
-            project=proj,
-            location=loc,
-            job_def=job,
-        )
-
-        return AppDryRunInfo(req, repr)
-
-    def run_opts(self) -> runopts:
-        opts = runopts()
-        opts.add(
-            "project",
-            type_=str,
-            help="Name of the GCP project. Defaults to the configured GCP project in the environment",
-        )
-        opts.add(
-            "location",
-            type_=str,
-            default=DEFAULT_LOC,
-            help=f"Name of the location to schedule the job in. Defaults to {DEFAULT_LOC}",
-        )
-        return opts
-
-    def _app_id_to_job_full_name(self, app_id: str) -> str:
-        """
-        app_id format: f"{project}:{location}:{name}"
-        job_full_name format: f"projects/{project}/locations/{location}/jobs/{name}"
-        where 'name' was created uniquely for the job from the app name
-        """
-        app_id_splits = app_id.split(":")
-        if len(app_id_splits) != 3:
-            raise ValueError(f"app_id not in expected format: {app_id}")
-        return f"projects/{app_id_splits[0]}/locations/{app_id_splits[1]}/jobs/{app_id_splits[2]}"
-
-    def _get_job(self, app_id: str) -> "batch_v1.Job":
-        from google.cloud import batch_v1
-
-        job_name = self._app_id_to_job_full_name(app_id)
-        request = batch_v1.GetJobRequest(
-            name=job_name,
-        )
-        return self._client.get_job(request=request)
-
-    def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
-        job = self._get_job(app_id)
-        if job is None:
-            print(f"app not found: {app_id}")
-            return None
-
-        gpu = 0
-        if len(job.allocation_policy.instances) != 0:
-            gpu_type = job.allocation_policy.instances[0].policy.machine_type
-            gpu = GPU_TYPE_TO_COUNT[gpu_type]
-
-        roles = {}
-        for tg in job.task_groups:
-            env = tg.task_spec.environment.variables
-            role = env["TORCHX_ROLE_NAME"]
-            container = tg.task_spec.runnables[1].container
-            roles[role] = Role(
-                name=role,
-                num_replicas=tg.task_count,
-                image=container.image_uri,
-                entrypoint=container.commands[0],
-                args=list(container.commands[1:]),
-                resource=Resource(
-                    cpu=int(tg.task_spec.compute_resource.cpu_milli / 1000),
-                    memMB=tg.task_spec.compute_resource.memory_mib,
-                    gpu=gpu,
-                ),
-                env=dict(env),
-                max_retries=tg.task_spec.max_retry_count,
-            )
-
-        # Map job -> DescribeAppResponse
-        # TODO map role/replica status
-        desc = DescribeAppResponse(
-            app_id=app_id,
-            state=JOB_STATE[job.status.state.name],
-            roles=list(roles.values()),
-        )
-        return desc
-
-    def log_iter(
-        self,
-        app_id: str,
-        role_name: str = "",
-        k: int = 0,
-        regex: Optional[str] = None,
-        since: Optional[datetime] = None,
-        until: Optional[datetime] = None,
-        should_tail: bool = False,
-        streams: Optional[Stream] = None,
-    ) -> Iterable[str]:
-        if streams not in (None, Stream.COMBINED):
-            raise ValueError("GCPBatchScheduler only supports COMBINED log stream")
-
-        job = self._get_job(app_id)
-        if not job:
-            raise ValueError(f"app not found: {app_id}")
-
-        job_uid = job.uid
-        filters = [
-            f"labels.job_uid={job_uid}",
-            f"labels.task_id:{job_uid}-group0-{k}",
-        ]
-
-        if since is not None:
-            filters.append(f'timestamp>="{str(since.isoformat())}"')
-        else:
-            # gcloud logger.list by default only returns logs in the last 24 hours
-            # Since many ML jobs can run longer add timestamp filter to get all logs
-            filters.append(f'timestamp>="{str(datetime.fromtimestamp(0).isoformat())}"')
-
-        if until is not None:
-            filters.append(f'timestamp<="{str(until.isoformat())}"')
-        if regex is not None:
-            filters.append(f'textPayload =~ "{regex}"')
-        filter = " AND ".join(filters)
-        return self._batch_log_iter(filter)
-
-    def _batch_log_iter(self, filter: str) -> Iterable[str]:
-        from google.cloud import logging
-
-        logger = logging.Client().logger(BATCH_LOGGER_NAME)
-        for entry in logger.list_entries(filter_=filter):
-            yield entry.payload + "\n"
-
-    def _job_full_name_to_app_id(self, job_full_name: str) -> str:
-        """
-        job_full_name format: f"projects/{project}/locations/{location}/jobs/{name}"
-        app_id format: f"{project}:{location}:{name}"
-        where 'name' was created uniquely for the job from the app name
-        """
-        job_name_splits = job_full_name.split("/")
-        if len(job_name_splits) != 6:
-            raise ValueError(f"job full name not in expected format: {job_full_name}")
-        return f"{job_name_splits[1]}:{job_name_splits[3]}:{job_name_splits[5]}"
-
-    def list(self) -> List[ListAppResponse]:
-        all_jobs = []
-        proj = self._get_project()
-        for loc in LOCATIONS:
-            jobs = self._client.list_jobs(parent=f"projects/{proj}/locations/{loc}")
-            all_jobs += jobs
-        all_jobs.sort(key=lambda job: job.create_time.timestamp(), reverse=True)
-        return [
-            ListAppResponse(
-                app_id=self._job_full_name_to_app_id(job.name),
-                state=JOB_STATE[job.status.state.name],
-            )
-            for job in all_jobs
-        ]
-
-    def _validate(self, app: AppDef, scheduler: str, cfg: GCPBatchOpts) -> None:
-        # Skip validation step
-        pass
-
-    def _cancel_existing(self, app_id: str) -> None:
-        from google.cloud import batch_v1
-
-        job_name = self._app_id_to_job_full_name(app_id)
-        request = batch_v1.DeleteJobRequest(
-            name=job_name,
-            reason="Killed via TorchX",
-        )
-        self._client.delete_job(request=request)
-
-
-def create_scheduler(
-    session_name: str,
-    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
-    client: Optional[Any] = None,
-    **kwargs: object,
-) -> GCPBatchScheduler:
-    return GCPBatchScheduler(
-        session_name=session_name,
-        client=client,
-    )
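The deleted scheduler's docstrings above fully specify its handle format; a worked example of the app_id <-> job-name round trip it used (illustrative only, since the module is gone as of this release):

# Round-trip the identifier formats documented in the deleted
# _app_id_to_job_full_name / _job_full_name_to_app_id methods above.
app_id = "my-project:us-central1:echo-abc123"
project, location, name = app_id.split(":")
job_full_name = f"projects/{project}/locations/{location}/jobs/{name}"
assert job_full_name == "projects/my-project/locations/us-central1/jobs/echo-abc123"
assert ":".join(job_full_name.split("/")[1::2]) == app_id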