torchx-nightly 2024.1.6 (py3-none-any.whl) → 2025.12.24 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchx-nightly might be problematic. Click here for more details.

Files changed (110)
  1. torchx/__init__.py +2 -0
  2. torchx/{schedulers/ray/__init__.py → _version.py} +3 -1
  3. torchx/apps/serve/serve.py +2 -0
  4. torchx/apps/utils/booth_main.py +2 -0
  5. torchx/apps/utils/copy_main.py +2 -0
  6. torchx/apps/utils/process_monitor.py +2 -0
  7. torchx/cli/__init__.py +2 -0
  8. torchx/cli/argparse_util.py +38 -3
  9. torchx/cli/cmd_base.py +2 -0
  10. torchx/cli/cmd_cancel.py +2 -0
  11. torchx/cli/cmd_configure.py +2 -0
  12. torchx/cli/cmd_delete.py +30 -0
  13. torchx/cli/cmd_describe.py +2 -0
  14. torchx/cli/cmd_list.py +8 -4
  15. torchx/cli/cmd_log.py +6 -24
  16. torchx/cli/cmd_run.py +269 -45
  17. torchx/cli/cmd_runopts.py +2 -0
  18. torchx/cli/cmd_status.py +12 -1
  19. torchx/cli/cmd_tracker.py +3 -1
  20. torchx/cli/colors.py +2 -0
  21. torchx/cli/main.py +4 -0
  22. torchx/components/__init__.py +3 -8
  23. torchx/components/component_test_base.py +2 -0
  24. torchx/components/dist.py +18 -7
  25. torchx/components/integration_tests/component_provider.py +4 -2
  26. torchx/components/integration_tests/integ_tests.py +2 -0
  27. torchx/components/serve.py +2 -0
  28. torchx/components/structured_arg.py +4 -3
  29. torchx/components/utils.py +15 -4
  30. torchx/distributed/__init__.py +2 -4
  31. torchx/examples/apps/datapreproc/datapreproc.py +2 -0
  32. torchx/examples/apps/lightning/data.py +5 -3
  33. torchx/examples/apps/lightning/model.py +7 -6
  34. torchx/examples/apps/lightning/profiler.py +7 -4
  35. torchx/examples/apps/lightning/train.py +11 -2
  36. torchx/examples/torchx_out_of_sync_training.py +11 -0
  37. torchx/notebook.py +2 -0
  38. torchx/runner/__init__.py +2 -0
  39. torchx/runner/api.py +167 -60
  40. torchx/runner/config.py +43 -10
  41. torchx/runner/events/__init__.py +57 -13
  42. torchx/runner/events/api.py +14 -3
  43. torchx/runner/events/handlers.py +2 -0
  44. torchx/runtime/tracking/__init__.py +2 -0
  45. torchx/runtime/tracking/api.py +2 -0
  46. torchx/schedulers/__init__.py +16 -15
  47. torchx/schedulers/api.py +70 -14
  48. torchx/schedulers/aws_batch_scheduler.py +75 -6
  49. torchx/schedulers/aws_sagemaker_scheduler.py +598 -0
  50. torchx/schedulers/devices.py +17 -4
  51. torchx/schedulers/docker_scheduler.py +43 -11
  52. torchx/schedulers/ids.py +29 -23
  53. torchx/schedulers/kubernetes_mcad_scheduler.py +9 -7
  54. torchx/schedulers/kubernetes_scheduler.py +383 -38
  55. torchx/schedulers/local_scheduler.py +100 -27
  56. torchx/schedulers/lsf_scheduler.py +5 -4
  57. torchx/schedulers/slurm_scheduler.py +336 -20
  58. torchx/schedulers/streams.py +2 -0
  59. torchx/specs/__init__.py +89 -12
  60. torchx/specs/api.py +418 -30
  61. torchx/specs/builders.py +176 -38
  62. torchx/specs/file_linter.py +143 -57
  63. torchx/specs/finder.py +68 -28
  64. torchx/specs/named_resources_aws.py +181 -4
  65. torchx/specs/named_resources_generic.py +2 -0
  66. torchx/specs/overlays.py +106 -0
  67. torchx/specs/test/components/__init__.py +2 -0
  68. torchx/specs/test/components/a/__init__.py +2 -0
  69. torchx/specs/test/components/a/b/__init__.py +2 -0
  70. torchx/specs/test/components/a/b/c.py +2 -0
  71. torchx/specs/test/components/c/__init__.py +2 -0
  72. torchx/specs/test/components/c/d.py +2 -0
  73. torchx/tracker/__init__.py +12 -6
  74. torchx/tracker/api.py +15 -18
  75. torchx/tracker/backend/fsspec.py +2 -0
  76. torchx/util/cuda.py +2 -0
  77. torchx/util/datetime.py +2 -0
  78. torchx/util/entrypoints.py +39 -15
  79. torchx/util/io.py +2 -0
  80. torchx/util/log_tee_helpers.py +210 -0
  81. torchx/util/modules.py +65 -0
  82. torchx/util/session.py +42 -0
  83. torchx/util/shlex.py +2 -0
  84. torchx/util/strings.py +3 -1
  85. torchx/util/types.py +90 -29
  86. torchx/version.py +4 -2
  87. torchx/workspace/__init__.py +2 -0
  88. torchx/workspace/api.py +136 -6
  89. torchx/workspace/dir_workspace.py +2 -0
  90. torchx/workspace/docker_workspace.py +30 -2
  91. torchx_nightly-2025.12.24.dist-info/METADATA +167 -0
  92. torchx_nightly-2025.12.24.dist-info/RECORD +113 -0
  93. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/WHEEL +1 -1
  94. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/entry_points.txt +0 -1
  95. torchx/examples/pipelines/__init__.py +0 -0
  96. torchx/examples/pipelines/kfp/__init__.py +0 -0
  97. torchx/examples/pipelines/kfp/advanced_pipeline.py +0 -287
  98. torchx/examples/pipelines/kfp/dist_pipeline.py +0 -69
  99. torchx/examples/pipelines/kfp/intro_pipeline.py +0 -81
  100. torchx/pipelines/kfp/__init__.py +0 -28
  101. torchx/pipelines/kfp/adapter.py +0 -271
  102. torchx/pipelines/kfp/version.py +0 -17
  103. torchx/schedulers/gcp_batch_scheduler.py +0 -487
  104. torchx/schedulers/ray/ray_common.py +0 -22
  105. torchx/schedulers/ray/ray_driver.py +0 -307
  106. torchx/schedulers/ray_scheduler.py +0 -453
  107. torchx_nightly-2024.1.6.dist-info/METADATA +0 -176
  108. torchx_nightly-2024.1.6.dist-info/RECORD +0 -118
  109. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info/licenses}/LICENSE +0 -0
  110. {torchx_nightly-2024.1.6.dist-info → torchx_nightly-2025.12.24.dist-info}/top_level.txt +0 -0
@@ -1,487 +0,0 @@
1
- #!/usr/bin/env python3
2
- # Copyright (c) Meta Platforms, Inc. and affiliates.
3
- # All rights reserved.
4
- #
5
- # This source code is licensed under the BSD-style license found in the
6
- # LICENSE file in the root directory of this source tree.
7
-
8
- """
9
-
10
- This contains the TorchX GCP Batch scheduler which can be used to run TorchX
11
- components directly on GCP Batch.
12
-
13
- This scheduler is in prototype stage and may change without notice.
14
-
15
- Prerequisites
16
- ==============
17
-
18
- You need to have a GCP project configured to use Batch by enabling and setting it up.
19
- See https://cloud.google.com/batch/docs/get-started#prerequisites
20
-
21
- """
22
-
23
- from dataclasses import dataclass
24
- from datetime import datetime
25
- from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING
26
-
27
- import torchx
28
- import yaml
29
-
30
- from torchx.schedulers.api import (
31
- AppDryRunInfo,
32
- DescribeAppResponse,
33
- ListAppResponse,
34
- Scheduler,
35
- Stream,
36
- )
37
- from torchx.schedulers.ids import make_unique
38
- from torchx.specs.api import AppDef, AppState, macros, Resource, Role, runopts
39
- from torchx.util.strings import normalize_str
40
- from typing_extensions import TypedDict
41
-
42
-
43
- if TYPE_CHECKING:
44
- from google.cloud import batch_v1
45
-
46
-
47
- JOB_STATE: Dict[str, AppState] = {
48
- "STATE_UNSPECIFIED": AppState.UNKNOWN,
49
- "QUEUED": AppState.SUBMITTED,
50
- "SCHEDULED": AppState.PENDING,
51
- "RUNNING": AppState.RUNNING,
52
- "SUCCEEDED": AppState.SUCCEEDED,
53
- "FAILED": AppState.FAILED,
54
- "DELETION_IN_PROGRESS": AppState.UNKNOWN,
55
- }
56
-
57
- GPU_COUNT_TO_TYPE: Dict[int, str] = {
58
- 1: "a2-highgpu-1g",
59
- 2: "a2-highgpu-2g",
60
- 4: "a2-highgpu-4g",
61
- 8: "a2-highgpu-8g",
62
- 16: "a2-highgpu-16g",
63
- }
64
-
65
- GPU_TYPE_TO_COUNT: Dict[str, int] = {v: k for k, v in GPU_COUNT_TO_TYPE.items()}
66
-
67
- LABEL_VERSION: str = "torchx_version"
68
- LABEL_APP_NAME: str = "torchx_app_name"
69
-
70
- DEFAULT_LOC: str = "us-central1"
71
-
72
- # TODO Remove LOCATIONS list once Batch supports all locations
73
- # or when there is an API to query locations supported by Batch
74
- LOCATIONS: List[str] = [
75
- DEFAULT_LOC,
76
- "us-west1",
77
- "us-east1",
78
- "asia-southeast1",
79
- "europe-north1",
80
- "europe-west6",
81
- ]
82
-
83
- BATCH_LOGGER_NAME = "batch_task_logs"
84
-
85
-
86
- @dataclass
87
- class GCPBatchJob:
88
- name: str
89
- project: str
90
- location: str
91
- job_def: "batch_v1.Job"
92
-
93
- def __str__(self) -> str:
94
- return yaml.dump(self.job_def)
95
-
96
- def __repr__(self) -> str:
97
- return str(self)
98
-
99
-
100
- class GCPBatchOpts(TypedDict, total=False):
101
- project: Optional[str]
102
- location: Optional[str]
103
-
104
-
105
- class GCPBatchScheduler(Scheduler[GCPBatchOpts]):
106
- """
107
- GCPBatchScheduler is a TorchX scheduling interface to GCP Batch.
108
-
109
- .. code-block:: bash
110
-
111
- $ pip install torchx[gcp_batch]
112
- $ torchx run --scheduler gcp_batch utils.echo --msg hello
113
- # This launches a job with app handle like gcp_batch://torchx/project:location:app_id1234 and prints it
114
- $ torchx status gcp_batch://torchx/project:location:app_id1234
115
- ...
116
-
117
- Authentication is loaded from the environment using the gcloud credential handling.
118
-
119
- **Config Options**
120
-
121
- .. runopts::
122
- class: torchx.schedulers.gcp_batch_scheduler.create_scheduler
123
-
124
- **Compatibility**
125
-
126
- .. compatibility::
127
- type: scheduler
128
- features:
129
- cancel: true
130
- logs: true
131
- describe: true
132
- distributed: true
133
- workspaces: false
134
- mounts: false
135
- elasticity: false
136
-
137
- """
138
-
139
- def __init__(
140
- self,
141
- session_name: str,
142
- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
143
- client: Optional[Any] = None,
144
- ) -> None:
145
- # NOTE: make sure any new init options are supported in create_scheduler(...)
146
- Scheduler.__init__(self, "gcp_batch", session_name)
147
- # pyre-fixme[4]: Attribute annotation cannot be `Any`.
148
- self.__client = client
149
-
150
- @property
151
- # pyre-fixme[3]: Return annotation cannot be `Any`.
152
- def _client(self) -> Any:
153
- from google.api_core import gapic_v1
154
- from google.cloud import batch_v1
155
-
156
- c = self.__client
157
- if c is None:
158
- client_info = gapic_v1.client_info.ClientInfo(
159
- user_agent=f"TorchX/{torchx.__version__}"
160
- )
161
- c = self.__client = batch_v1.BatchServiceClient(client_info=client_info)
162
- return c
163
-
164
- def schedule(self, dryrun_info: AppDryRunInfo[GCPBatchJob]) -> str:
165
- from google.cloud import batch_v1
166
-
167
- req = dryrun_info.request
168
- assert req is not None, f"{dryrun_info} missing request"
169
-
170
- request = batch_v1.CreateJobRequest(
171
- parent=f"projects/{req.project}/locations/{req.location}",
172
- job=req.job_def,
173
- job_id=req.name,
174
- )
175
-
176
- response = self._client.create_job(request=request)
177
- return f"{req.project}:{req.location}:{req.name}"
178
-
179
- def _app_to_job(self, app: AppDef) -> "batch_v1.Job":
180
- from google.cloud import batch_v1
181
-
182
- name = normalize_str(make_unique(app.name))
183
-
184
- taskGroups = []
185
- allocationPolicy = None
186
-
187
- # 1. Convert role to task
188
- # TODO implement retry_policy, mount conversion
189
- # NOTE: Supports only one role for now as GCP Batch supports only one TaskGroup
190
- # which is ok to start with as most components have only one role
191
- for role_idx, role in enumerate(app.roles):
192
- values = macros.Values(
193
- img_root="",
194
- app_id=name,
195
- replica_id=str(0),
196
- rank0_env=("BATCH_MAIN_NODE_HOSTNAME"),
197
- )
198
- role_dict = values.apply(role)
199
- role_dict.env["TORCHX_ROLE_IDX"] = str(role_idx)
200
- role_dict.env["TORCHX_ROLE_NAME"] = str(role.name)
201
-
202
- resource = role_dict.resource
203
- res = batch_v1.ComputeResource()
204
- cpu = resource.cpu
205
- if cpu <= 0:
206
- cpu = 1
207
- MILLI = 1000
208
- res.cpu_milli = cpu * MILLI
209
- memMB = resource.memMB
210
- if memMB < 0:
211
- raise ValueError(
212
- f"memMB should to be set to a positive value, got {memMB}"
213
- )
214
- res.memory_mib = memMB
215
-
216
- # TODO support named resources
217
- # Using v100 as default GPU type as a100 does not allow changing count for now
218
- # TODO See if there is a better default GPU type
219
- if resource.gpu > 0:
220
- if resource.gpu not in GPU_COUNT_TO_TYPE:
221
- raise ValueError(
222
- f"gpu should to be set to one of these values: {GPU_COUNT_TO_TYPE.keys()}"
223
- )
224
- machineType = GPU_COUNT_TO_TYPE[resource.gpu]
225
- allocationPolicy = batch_v1.AllocationPolicy(
226
- instances=[
227
- batch_v1.AllocationPolicy.InstancePolicyOrTemplate(
228
- install_gpu_drivers=True,
229
- policy=batch_v1.AllocationPolicy.InstancePolicy(
230
- machine_type=machineType,
231
- ),
232
- )
233
- ],
234
- )
235
- print(f"Using GPUs of type: {machineType}")
236
-
237
- # Configure host firewall rules to accept ingress communication
238
- config_network_runnable = batch_v1.Runnable(
239
- script=batch_v1.Runnable.Script(
240
- text="/sbin/iptables -A INPUT -j ACCEPT"
241
- )
242
- )
243
-
244
- runnable = batch_v1.Runnable(
245
- container=batch_v1.Runnable.Container(
246
- image_uri=role_dict.image,
247
- commands=[role_dict.entrypoint] + role_dict.args,
248
- entrypoint="",
249
- # Configure docker to use the host network stack to communicate with containers/other hosts in the same network
250
- options="--net host",
251
- )
252
- )
253
-
254
- ts = batch_v1.TaskSpec(
255
- runnables=[config_network_runnable, runnable],
256
- environment=batch_v1.Environment(variables=role_dict.env),
257
- max_retry_count=role_dict.max_retries,
258
- compute_resource=res,
259
- )
260
-
261
- task_env = [
262
- batch_v1.Environment(variables={"TORCHX_REPLICA_IDX": str(i)})
263
- for i in range(role_dict.num_replicas)
264
- ]
265
-
266
- tg = batch_v1.TaskGroup(
267
- task_spec=ts,
268
- task_count=role_dict.num_replicas,
269
- task_count_per_node=1,
270
- task_environments=task_env,
271
- require_hosts_file=True,
272
- )
273
- taskGroups.append(tg)
274
-
275
- # 2. Convert AppDef to Job
276
- job = batch_v1.Job(
277
- name=name,
278
- task_groups=taskGroups,
279
- allocation_policy=allocationPolicy,
280
- logs_policy=batch_v1.LogsPolicy(
281
- destination=batch_v1.LogsPolicy.Destination.CLOUD_LOGGING,
282
- ),
283
- # NOTE: GCP Batch does not allow label names with "."
284
- labels={
285
- LABEL_VERSION: torchx.__version__.replace(".", "-"),
286
- LABEL_APP_NAME: name,
287
- },
288
- )
289
- return job
290
-
291
- def _get_project(self) -> str:
292
- from google.cloud import runtimeconfig
293
-
294
- return runtimeconfig.Client().project
295
-
296
- def _submit_dryrun(
297
- self, app: AppDef, cfg: GCPBatchOpts
298
- ) -> AppDryRunInfo[GCPBatchJob]:
299
- proj = cfg.get("project")
300
- if proj is None:
301
- proj = self._get_project()
302
- assert proj is not None and isinstance(proj, str), "project must be a str"
303
-
304
- loc = cfg.get("location")
305
- assert loc is not None and isinstance(loc, str), "location must be a str"
306
-
307
- job = self._app_to_job(app)
308
-
309
- # Convert JobDef + BatchOpts to GCPBatchJob
310
- req = GCPBatchJob(
311
- name=str(job.name),
312
- project=proj,
313
- location=loc,
314
- job_def=job,
315
- )
316
-
317
- return AppDryRunInfo(req, repr)
318
-
319
- def run_opts(self) -> runopts:
320
- opts = runopts()
321
- opts.add(
322
- "project",
323
- type_=str,
324
- help="Name of the GCP project. Defaults to the configured GCP project in the environment",
325
- )
326
- opts.add(
327
- "location",
328
- type_=str,
329
- default=DEFAULT_LOC,
330
- help=f"Name of the location to schedule the job in. Defaults to {DEFAULT_LOC}",
331
- )
332
- return opts
333
-
334
- def _app_id_to_job_full_name(self, app_id: str) -> str:
335
- """
336
- app_id format: f"{project}:{location}:{name}"
337
- job_full_name format: f"projects/{project}/locations/{location}/jobs/{name}"
338
- where 'name' was created uniquely for the job from the app name
339
- """
340
- app_id_splits = app_id.split(":")
341
- if len(app_id_splits) != 3:
342
- raise ValueError(f"app_id not in expected format: {app_id}")
343
- return f"projects/{app_id_splits[0]}/locations/{app_id_splits[1]}/jobs/{app_id_splits[2]}"
344
-
345
- def _get_job(self, app_id: str) -> "batch_v1.Job":
346
- from google.cloud import batch_v1
347
-
348
- job_name = self._app_id_to_job_full_name(app_id)
349
- request = batch_v1.GetJobRequest(
350
- name=job_name,
351
- )
352
- return self._client.get_job(request=request)
353
-
354
- def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
355
- job = self._get_job(app_id)
356
- if job is None:
357
- print(f"app not found: {app_id}")
358
- return None
359
-
360
- gpu = 0
361
- if len(job.allocation_policy.instances) != 0:
362
- gpu_type = job.allocation_policy.instances[0].policy.machine_type
363
- gpu = GPU_TYPE_TO_COUNT[gpu_type]
364
-
365
- roles = {}
366
- for tg in job.task_groups:
367
- env = tg.task_spec.environment.variables
368
- role = env["TORCHX_ROLE_NAME"]
369
- container = tg.task_spec.runnables[1].container
370
- roles[role] = Role(
371
- name=role,
372
- num_replicas=tg.task_count,
373
- image=container.image_uri,
374
- entrypoint=container.commands[0],
375
- args=list(container.commands[1:]),
376
- resource=Resource(
377
- cpu=int(tg.task_spec.compute_resource.cpu_milli / 1000),
378
- memMB=tg.task_spec.compute_resource.memory_mib,
379
- gpu=gpu,
380
- ),
381
- env=dict(env),
382
- max_retries=tg.task_spec.max_retry_count,
383
- )
384
-
385
- # Map job -> DescribeAppResponse
386
- # TODO map role/replica status
387
- desc = DescribeAppResponse(
388
- app_id=app_id,
389
- state=JOB_STATE[job.status.state.name],
390
- roles=list(roles.values()),
391
- )
392
- return desc
393
-
394
- def log_iter(
395
- self,
396
- app_id: str,
397
- role_name: str = "",
398
- k: int = 0,
399
- regex: Optional[str] = None,
400
- since: Optional[datetime] = None,
401
- until: Optional[datetime] = None,
402
- should_tail: bool = False,
403
- streams: Optional[Stream] = None,
404
- ) -> Iterable[str]:
405
- if streams not in (None, Stream.COMBINED):
406
- raise ValueError("GCPBatchScheduler only supports COMBINED log stream")
407
-
408
- job = self._get_job(app_id)
409
- if not job:
410
- raise ValueError(f"app not found: {app_id}")
411
-
412
- job_uid = job.uid
413
- filters = [f"labels.job_uid={job_uid}"]
414
- filters.append(f"resource.labels.task_id:task/{job_uid}-group0-{k}")
415
-
416
- if since is not None:
417
- filters.append(f'timestamp>="{str(since.isoformat())}"')
418
- else:
419
- # gcloud logger.list by default only returns logs in the last 24 hours
420
- # Since many ML jobs can run longer add timestamp filter to get all logs
421
- filters.append(f'timestamp>="{str(datetime.fromtimestamp(0).isoformat())}"')
422
-
423
- if until is not None:
424
- filters.append(f'timestamp<="{str(until.isoformat())}"')
425
- if regex is not None:
426
- filters.append(f'textPayload =~ "{regex}"')
427
- filter = " AND ".join(filters)
428
- return self._batch_log_iter(filter)
429
-
430
- def _batch_log_iter(self, filter: str) -> Iterable[str]:
431
- from google.cloud import logging
432
-
433
- logger = logging.Client().logger(BATCH_LOGGER_NAME)
434
- for entry in logger.list_entries(filter_=filter):
435
- yield entry.payload
436
-
437
- def _job_full_name_to_app_id(self, job_full_name: str) -> str:
438
- """
439
- job_full_name format: f"projects/{project}/locations/{location}/jobs/{name}"
440
- app_id format: f"{project}:{location}:{name}"
441
- where 'name' was created uniquely for the job from the app name
442
- """
443
- job_name_splits = job_full_name.split("/")
444
- if len(job_name_splits) != 6:
445
- raise ValueError(f"job full name not in expected format: {job_full_name}")
446
- return f"{job_name_splits[1]}:{job_name_splits[3]}:{job_name_splits[5]}"
447
-
448
- def list(self) -> List[ListAppResponse]:
449
- all_jobs = []
450
- proj = self._get_project()
451
- for loc in LOCATIONS:
452
- jobs = self._client.list_jobs(parent=f"projects/{proj}/locations/{loc}")
453
- all_jobs += jobs
454
- all_jobs.sort(key=lambda job: job.create_time.timestamp(), reverse=True)
455
- return [
456
- ListAppResponse(
457
- app_id=self._job_full_name_to_app_id(job.name),
458
- state=JOB_STATE[job.status.state.name],
459
- )
460
- for job in all_jobs
461
- ]
462
-
463
- def _validate(self, app: AppDef, scheduler: str) -> None:
464
- # Skip validation step
465
- pass
466
-
467
- def _cancel_existing(self, app_id: str) -> None:
468
- from google.cloud import batch_v1
469
-
470
- job_name = self._app_id_to_job_full_name(app_id)
471
- request = batch_v1.DeleteJobRequest(
472
- name=job_name,
473
- reason="Killed via TorchX",
474
- )
475
- self._client.delete_job(request=request)
476
-
477
-
478
- def create_scheduler(
479
- session_name: str,
480
- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
481
- client: Optional[Any] = None,
482
- **kwargs: object,
483
- ) -> GCPBatchScheduler:
484
- return GCPBatchScheduler(
485
- session_name=session_name,
486
- client=client,
487
- )
@@ -1,22 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the BSD-style license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- from dataclasses import dataclass, field
8
- from typing import Dict, List, Optional
9
-
10
- TORCHX_RANK0_HOST: str = "TORCHX_RANK0_HOST"
11
-
12
-
13
- @dataclass
14
- class RayActor:
15
- """Describes an actor (a.k.a. worker/replica in TorchX terms)."""
16
-
17
- name: str
18
- command: List[str]
19
- env: Dict[str, str] = field(default_factory=dict)
20
- num_cpus: int = 1
21
- num_gpus: int = 0
22
- min_replicas: Optional[int] = None