pathways-cli 0.1.0__tar.gz → 0.1.1__tar.gz

@@ -9,7 +9,7 @@ This document details the codebase design, architecture, key lessons, and integr
  The project is structured around a standard PEP 621 package layout:

  ```
- /Users/stoelinga/workspace/pathways-cli/
+ pathways-cli/
  ├── pyproject.toml # Package configurations, CLI scripts, and Pytest options
  ├── .gitignore # Excludes local environments, caches, and secrets
  ├── README.md # User documentation and example verification steps
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: pathways-cli
- Version: 0.1.0
+ Version: 0.1.1
  Summary: Pathways CLI to easily bring up pathways clusters.
  Author-email: Sam Stoelinga <sammiestoel@gmail.com>
  Requires-Python: >=3.12
@@ -26,26 +26,28 @@ Description-Content-Type: text/markdown

  ## Installation

- This project utilizes [uv](https://github.com/astral-sh/uv) for fast, modern Python package and dependency management.
-
- To sync the environment and install `pwy`:
+ Install `pathways-cli` from PyPI using your preferred package manager:

  ```bash
- uv sync
+ # Using pip
+ pip install pathways-cli
+
+ # Or using uv (recommended for fast tool management)
+ uv tool install pathways-cli
  ```

  ---

  ## Usage

- You can invoke `pwy` commands directly using `uv run`:
+ Once installed, you can invoke the `pwy` CLI directly:

  ### 1. Provision / Preview a Cluster (`pwy up`)

  Starts a Pathways JobSet or dry-runs the configuration.

  ```bash
- uv run pwy up \
+ pwy up \
  --tpu-type v6e-16 \
  --gcs-scratch-location gs://my-bucket/pathways-staging \
  --num-slices 1 \
@@ -71,7 +73,7 @@ uv run pwy up \
  Deletes the running Pathways JobSet.

  ```bash
- uv run pwy down --name pathways-interactive --namespace default
+ pwy down --name pathways-interactive --namespace default
  ```

  ---
@@ -101,15 +103,13 @@ Once the interactive cluster is running, you can verify execution by `exec`ing i

  ## TPU Type Mappings

- `pwy` handles all resource-limit math and topologies automatically according to the following matrix:
+ `pwy` handles all resource-limit math and topologies automatically. It supports a wide range of TPU generations, including:

- | TPU Type | GKE Topology | VMs Per Slice | RM Instance Type |
- | :--- | :--- | :--- | :--- |
- | `v6e-4` | `2x2` | 1 | `tpuv6e:2x2` |
- | `v6e-8` | `2x4` | 2 | `tpuv6e:2x4` |
- | `v6e-16` | `4x4` | 4 | `tpuv6e:4x4` |
- | `v6e-32` | `4x8` | 8 | `tpuv6e:4x8` |
- | `v6e-64` | `8x8` | 16 | `tpuv6e:8x8` |
+ - **TPU v6e**: `v6e-4` up to `v6e-256` (including `v6e-8-1` with 8 chips per VM)
+ - **TPU v5p**: `v5p-8` up to `v5p-17920`
+ - **TPU v5e (v5LitePod)**: `v5litepod-8` up to `v5litepod-256`
+ - **TPU v4**: `v4-8` up to `v4-4096`
+ - **TPU 7x**: `7x-8` up to `7x-8192`

  ---

@@ -16,26 +16,28 @@

  ## Installation

- This project utilizes [uv](https://github.com/astral-sh/uv) for fast, modern Python package and dependency management.
-
- To sync the environment and install `pwy`:
+ Install `pathways-cli` from PyPI using your preferred package manager:

  ```bash
- uv sync
+ # Using pip
+ pip install pathways-cli
+
+ # Or using uv (recommended for fast tool management)
+ uv tool install pathways-cli
  ```

  ---

  ## Usage

- You can invoke `pwy` commands directly using `uv run`:
+ Once installed, you can invoke the `pwy` CLI directly:

  ### 1. Provision / Preview a Cluster (`pwy up`)

  Starts a Pathways JobSet or dry-runs the configuration.

  ```bash
- uv run pwy up \
+ pwy up \
  --tpu-type v6e-16 \
  --gcs-scratch-location gs://my-bucket/pathways-staging \
  --num-slices 1 \
@@ -61,7 +63,7 @@ uv run pwy up \
  Deletes the running Pathways JobSet.

  ```bash
- uv run pwy down --name pathways-interactive --namespace default
+ pwy down --name pathways-interactive --namespace default
  ```

  ---
@@ -91,15 +93,13 @@ Once the interactive cluster is running, you can verify execution by `exec`ing i

  ## TPU Type Mappings

- `pwy` handles all resource-limit math and topologies automatically according to the following matrix:
+ `pwy` handles all resource-limit math and topologies automatically. It supports a wide range of TPU generations, including:

- | TPU Type | GKE Topology | VMs Per Slice | RM Instance Type |
- | :--- | :--- | :--- | :--- |
- | `v6e-4` | `2x2` | 1 | `tpuv6e:2x2` |
- | `v6e-8` | `2x4` | 2 | `tpuv6e:2x4` |
- | `v6e-16` | `4x4` | 4 | `tpuv6e:4x4` |
- | `v6e-32` | `4x8` | 8 | `tpuv6e:4x8` |
- | `v6e-64` | `8x8` | 16 | `tpuv6e:8x8` |
+ - **TPU v6e**: `v6e-4` up to `v6e-256` (including `v6e-8-1` with 8 chips per VM)
+ - **TPU v5p**: `v5p-8` up to `v5p-17920`
+ - **TPU v5e (v5LitePod)**: `v5litepod-8` up to `v5litepod-256`
+ - **TPU v4**: `v4-8` up to `v4-4096`
+ - **TPU 7x**: `7x-8` up to `7x-8192`

  ---

@@ -300,7 +300,7 @@ spec:
  A new standalone directory structure will be created under the workspace (mocking a new repository context):

  ```
- /Users/stoelinga/workspace/pathways-cli/
+ pathways-cli/
  ├── pyproject.toml
  ├── README.md
  ├── pwy/
@@ -1,6 +1,6 @@
  [project]
  name = "pathways-cli"
- version = "0.1.0"
+ version = "0.1.1"
  description = "Pathways CLI to easily bring up pathways clusters."
  readme = "README.md"
  authors = [
@@ -0,0 +1,270 @@
+ from pwy.templates import YAML_TEMPLATE
+
+ TPU_MAPPINGS = {
+     # v6e
+     "v6e-4": {"topology": "2x2", "vms_per_slice": 1, "gke_accelerator": "tpu-v6e-slice", "rm_type": "tpuv6e:2x2", "chips_per_vm": 4},
+     "v6e-8": {"topology": "2x4", "vms_per_slice": 2, "gke_accelerator": "tpu-v6e-slice", "rm_type": "tpuv6e:2x4", "chips_per_vm": 4},
+     "v6e-8-1": {"topology": "2x4", "vms_per_slice": 1, "gke_accelerator": "tpu-v6e-slice", "rm_type": "tpuv6e:2x4", "chips_per_vm": 8},
+     "v6e-16": {"topology": "4x4", "vms_per_slice": 4, "gke_accelerator": "tpu-v6e-slice", "rm_type": "tpuv6e:4x4", "chips_per_vm": 4},
+     "v6e-32": {"topology": "4x8", "vms_per_slice": 8, "gke_accelerator": "tpu-v6e-slice", "rm_type": "tpuv6e:4x8", "chips_per_vm": 4},
+     "v6e-64": {"topology": "8x8", "vms_per_slice": 16, "gke_accelerator": "tpu-v6e-slice", "rm_type": "tpuv6e:8x8", "chips_per_vm": 4},
+     "v6e-128": {"topology": "8x16", "vms_per_slice": 32, "gke_accelerator": "tpu-v6e-slice", "rm_type": "tpuv6e:8x16", "chips_per_vm": 4},
+     "v6e-256": {"topology": "16x16", "vms_per_slice": 64, "gke_accelerator": "tpu-v6e-slice", "rm_type": "tpuv6e:16x16", "chips_per_vm": 4},
+
+     # v5litepod
+     "v5litepod-8": {"topology": "2x4", "vms_per_slice": 2, "gke_accelerator": "tpu-v5-lite-podslice", "rm_type": "tpuv5litepod:2x4", "chips_per_vm": 4},
+     "v5litepod-16": {"topology": "4x4", "vms_per_slice": 4, "gke_accelerator": "tpu-v5-lite-podslice", "rm_type": "tpuv5litepod:4x4", "chips_per_vm": 4},
+     "v5litepod-32": {"topology": "4x8", "vms_per_slice": 8, "gke_accelerator": "tpu-v5-lite-podslice", "rm_type": "tpuv5litepod:4x8", "chips_per_vm": 4},
+     "v5litepod-64": {"topology": "8x8", "vms_per_slice": 16, "gke_accelerator": "tpu-v5-lite-podslice", "rm_type": "tpuv5litepod:8x8", "chips_per_vm": 4},
+     "v5litepod-128": {"topology": "8x16", "vms_per_slice": 32, "gke_accelerator": "tpu-v5-lite-podslice", "rm_type": "tpuv5litepod:8x16", "chips_per_vm": 4},
+     "v5litepod-256": {"topology": "16x16", "vms_per_slice": 64, "gke_accelerator": "tpu-v5-lite-podslice", "rm_type": "tpuv5litepod:16x16", "chips_per_vm": 4},
+
+     # 7x
+     "7x-8": {"topology": "2x2x1", "vms_per_slice": 1, "gke_accelerator": "tpu7x", "rm_type": "tpu7x:2x2x1", "chips_per_vm": 4},
+     "7x-16": {"topology": "2x2x2", "vms_per_slice": 2, "gke_accelerator": "tpu7x", "rm_type": "tpu7x:2x2x2", "chips_per_vm": 4},
+     "7x-32": {"topology": "2x2x4", "vms_per_slice": 4, "gke_accelerator": "tpu7x", "rm_type": "tpu7x:2x2x4", "chips_per_vm": 4},
+     "7x-64": {"topology": "2x4x4", "vms_per_slice": 8, "gke_accelerator": "tpu7x", "rm_type": "tpu7x:2x4x4", "chips_per_vm": 4},
+     "7x-128": {"topology": "4x4x4", "vms_per_slice": 16, "gke_accelerator": "tpu7x", "rm_type": "tpu7x:4x4x4", "chips_per_vm": 4},
+     "7x-256": {"topology": "4x4x8", "vms_per_slice": 32, "gke_accelerator": "tpu7x", "rm_type": "tpu7x:4x4x8", "chips_per_vm": 4},
+     "7x-512": {"topology": "4x8x8", "vms_per_slice": 64, "gke_accelerator": "tpu7x", "rm_type": "tpu7x:4x8x8", "chips_per_vm": 4},
+     "7x-1024": {"topology": "8x8x8", "vms_per_slice": 128, "gke_accelerator": "tpu7x", "rm_type": "tpu7x:8x8x8", "chips_per_vm": 4},
+     "7x-2048": {"topology": "8x8x16", "vms_per_slice": 256, "gke_accelerator": "tpu7x", "rm_type": "tpu7x:8x8x16", "chips_per_vm": 4},
+     "7x-4096": {"topology": "8x16x16", "vms_per_slice": 512, "gke_accelerator": "tpu7x", "rm_type": "tpu7x:8x16x16", "chips_per_vm": 4},
+     "7x-8192": {"topology": "16x16x16", "vms_per_slice": 1024, "gke_accelerator": "tpu7x", "rm_type": "tpu7x:16x16x16", "chips_per_vm": 4},
+
+     # v4
+     "v4-8": {"topology": "2x2x1", "vms_per_slice": 1, "gke_accelerator": "tpu-v4-podslice", "rm_type": "tpuv4:2x2x1", "chips_per_vm": 4},
+     "v4-16": {"topology": "2x2x2", "vms_per_slice": 2, "gke_accelerator": "tpu-v4-podslice", "rm_type": "tpuv4:2x2x2", "chips_per_vm": 4},
+     "v4-32": {"topology": "2x2x4", "vms_per_slice": 4, "gke_accelerator": "tpu-v4-podslice", "rm_type": "tpuv4:2x2x4", "chips_per_vm": 4},
+     "v4-64": {"topology": "2x4x4", "vms_per_slice": 8, "gke_accelerator": "tpu-v4-podslice", "rm_type": "tpuv4:2x4x4", "chips_per_vm": 4},
+     "v4-128": {"topology": "4x4x4", "vms_per_slice": 16, "gke_accelerator": "tpu-v4-podslice", "rm_type": "tpuv4:4x4x4", "chips_per_vm": 4},
+     "v4-256": {"topology": "4x4x8", "vms_per_slice": 32, "gke_accelerator": "tpu-v4-podslice", "rm_type": "tpuv4:4x4x8", "chips_per_vm": 4},
+     "v4-512": {"topology": "4x8x8", "vms_per_slice": 64, "gke_accelerator": "tpu-v4-podslice", "rm_type": "tpuv4:4x8x8", "chips_per_vm": 4},
+     "v4-1024": {"topology": "8x8x8", "vms_per_slice": 128, "gke_accelerator": "tpu-v4-podslice", "rm_type": "tpuv4:8x8x8", "chips_per_vm": 4},
+     "v4-1536": {"topology": "8x8x12", "vms_per_slice": 192, "gke_accelerator": "tpu-v4-podslice", "rm_type": "tpuv4:8x8x12", "chips_per_vm": 4},
+     "v4-2048": {"topology": "8x8x16", "vms_per_slice": 256, "gke_accelerator": "tpu-v4-podslice", "rm_type": "tpuv4:8x8x16", "chips_per_vm": 4},
+     "v4-4096": {"topology": "8x16x16", "vms_per_slice": 512, "gke_accelerator": "tpu-v4-podslice", "rm_type": "tpuv4:8x16x16", "chips_per_vm": 4},
+ }
+
+ # Dynamically populate v5p topologies
+ _V5P_DATA = [
+     ("v5p-8", "2x2x1", 1),
+     ("v5p-16", "2x2x2", 2),
+     ("v5p-32", "2x2x4", 4),
+     ("v5p-64", "2x4x4", 8),
+     ("v5p-128", "4x4x4", 16),
+     ("v5p-256", "4x4x8", 32),
+     ("v5p-384", "4x4x12", 48),
+     ("v5p-512", "4x8x8", 64),
+     ("v5p-640", "4x4x20", 80),
+     ("v5p-768", "4x8x12", 96),
+     ("v5p-896", "4x4x28", 112),
+     ("v5p-1024", "8x8x8", 128),
+     ("v5p-1152", "4x12x12", 144),
+     ("v5p-1280", "4x8x20", 160),
+     ("v5p-1408", "4x4x44", 176),
+     ("v5p-1536", "8x8x12", 192),
+     ("v5p-1664", "4x4x52", 208),
+     ("v5p-1792", "4x8x28", 224),
+     ("v5p-1920", "4x12x20", 240),
+     ("v5p-2048", "8x8x16", 256),
+     ("v5p-2176", "4x4x68", 272),
+     ("v5p-2304", "8x12x12", 288),
+     ("v5p-2432", "4x4x76", 304),
+     ("v5p-2560", "8x8x20", 320),
+     ("v5p-2688", "4x12x28", 336),
+     ("v5p-2816", "4x8x44", 352),
+     ("v5p-2944", "4x4x92", 368),
+     ("v5p-3072", "8x12x16", 384),
+     ("v5p-3200", "4x20x20", 400),
+     ("v5p-3328", "4x8x52", 416),
+     ("v5p-3456", "12x12x12", 432),
+     ("v5p-3584", "8x8x28", 448),
+     ("v5p-3712", "4x4x116", 464),
+     ("v5p-3840", "8x12x20", 480),
+     ("v5p-3968", "4x4x124", 496),
+     ("v5p-4096", "8x16x16", 512),
+     ("v5p-4224", "4x12x44", 528),
+     ("v5p-4352", "4x8x68", 544),
+     ("v5p-4480", "4x20x28", 560),
+     ("v5p-4608", "12x12x16", 576),
+     ("v5p-4736", "4x4x148", 592),
+     ("v5p-4864", "4x8x76", 608),
+     ("v5p-4992", "4x12x52", 624),
+     ("v5p-5120", "8x16x20", 640),
+     ("v5p-5248", "4x4x164", 656),
+     ("v5p-5376", "8x12x28", 672),
+     ("v5p-5504", "4x4x172", 688),
+     ("v5p-5632", "8x8x44", 704),
+     ("v5p-5760", "12x12x20", 720),
+     ("v5p-5888", "4x8x92", 736),
+     ("v5p-6016", "4x4x188", 752),
+     ("v5p-6144", "12x16x16", 768),
+     ("v5p-6272", "4x28x28", 784),
+     ("v5p-6400", "8x20x20", 800),
+     ("v5p-6528", "4x12x68", 816),
+     ("v5p-6656", "8x8x52", 832),
+     ("v5p-6784", "4x4x212", 848),
+     ("v5p-6912", "12x12x24", 864),
+     ("v5p-7040", "4x20x44", 880),
+     ("v5p-7168", "8x16x28", 896),
+     ("v5p-7296", "4x12x76", 912),
+     ("v5p-7424", "4x8x116", 928),
+     ("v5p-7552", "4x4x236", 944),
+     ("v5p-7680", "12x16x20", 960),
+     ("v5p-7808", "4x4x244", 976),
+     ("v5p-7936", "4x8x124", 992),
+     ("v5p-8064", "12x12x28", 1008),
+     ("v5p-8192", "16x16x16", 1024),
+     ("v5p-8320", "4x20x52", 1040),
+     ("v5p-8448", "8x12x44", 1056),
+     ("v5p-8704", "8x8x68", 1088),
+     ("v5p-8832", "4x12x92", 1104),
+     ("v5p-8960", "8x20x28", 1120),
+     ("v5p-9216", "12x16x24", 1152),
+     ("v5p-9472", "4x8x148", 1184),
+     ("v5p-9600", "12x20x20", 1200),
+     ("v5p-9728", "8x8x76", 1216),
+     ("v5p-9856", "4x28x44", 1232),
+     ("v5p-9984", "8x12x52", 1248),
+     ("v5p-10240", "16x16x20", 1280),
+     ("v5p-10368", "12x12x36", 1296),
+     ("v5p-10496", "4x8x164", 1312),
+     ("v5p-10752", "12x16x28", 1344),
+     ("v5p-10880", "4x20x68", 1360),
+     ("v5p-11008", "4x8x172", 1376),
+     ("v5p-11136", "4x12x116", 1392),
+     ("v5p-11264", "8x16x44", 1408),
+     ("v5p-11520", "12x20x24", 1440),
+     ("v5p-11648", "4x28x52", 1456),
+     ("v5p-11776", "8x8x92", 1472),
+     ("v5p-11904", "4x12x124", 1488),
+     ("v5p-12032", "4x8x188", 1504),
+     ("v5p-12160", "4x20x76", 1520),
+     ("v5p-12288", "16x16x24", 1536),
+     ("v5p-13824", "12x24x24", 1728),
+     ("v5p-17920", "16x20x28", 2240),
+ ]
+ for key, topo, vms in _V5P_DATA:
+     TPU_MAPPINGS[key] = {
+         "topology": topo,
+         "vms_per_slice": vms,
+         "gke_accelerator": "tpu-v5p-slice",
+         "rm_type": f"tpuv5p:{topo}",
+         "chips_per_vm": 4,
+     }
+
+ def get_colocated_python_image(client_image: str) -> str:
+     if "/" in client_image and ":" in client_image:
+         try:
+             path, tag = client_image.rsplit(":", 1)
+             repo, _ = path.rsplit("/", 1)
+             return f"{repo}/colocated-python:{tag}"
+         except Exception:
+             pass
+     return "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/colocated-python:jax-0.10.0"
+
+ def generate_yaml(
+     name: str,
+     namespace: str,
+     tpu_type: str,
+     gcs_scratch_location: str,
+     num_slices: int = 1,
+     jax_client_image: str = "python:3.12-slim",
+     command: str = None,
+     enable_spot: bool = False,
+     colocated_python: bool = False,
+ ) -> str:
+     if tpu_type not in TPU_MAPPINGS:
+         raise ValueError(
+             f"Unsupported TPU type: {tpu_type}. Supported types: {list(TPU_MAPPINGS.keys())}"
+         )
+
+     mapping = TPU_MAPPINGS[tpu_type]
+     gke_topology = mapping["topology"]
+     vms_per_slice = mapping["vms_per_slice"]
+     rm_instance_type = mapping["rm_type"]
+
+     # Format client execution command
+     if not command:
+         client_command = "sleep infinity"
+     else:
+         client_command = command
+
+     # Format Spot VM Node Selector and Tolerations
+     if enable_spot:
+         spot_toleration_head = (
+             ' - key: "cloud.google.com/gke-spot"\n'
+             ' operator: "Equal"\n'
+             ' value: "true"\n'
+             ' effect: "NoSchedule"'
+         )
+         spot_node_selector_worker = ' cloud.google.com/gke-spot: "true"'
+         spot_toleration_worker = (
+             ' - key: "cloud.google.com/gke-spot"\n'
+             ' operator: "Equal"\n'
+             ' value: "true"\n'
+             ' effect: "NoSchedule"'
+         )
+     else:
+         spot_toleration_head = ""
+         spot_node_selector_worker = ""
+         spot_toleration_worker = ""
+
+     # Format colocated python options
+     if colocated_python:
+         proxy_sidecar_arg = "\n - --sidecar_name=external"
+         tpu_premapped_buffer_size = 34359738368 # 32 GiB
+         colocated_img = get_colocated_python_image(jax_client_image)
+         worker_init_containers = (
+             " initContainers:\n"
+             " - name: colocated-python\n"
+             f" image: {colocated_img}\n"
+             " imagePullPolicy: Always\n"
+             " restartPolicy: Always\n"
+             " ports:\n"
+             " - containerPort: 50051\n"
+             " protocol: TCP\n"
+             " env:\n"
+             " - name: CLOUD_PATHWAYS_SIDECAR_SHM_DIRECTORY\n"
+             " value: /tmp/ifrt_proxy\n"
+             " - name: GRPC_SERVER_ADDRESS\n"
+             " value: 0.0.0.0:50051\n"
+             " volumeMounts:\n"
+             " - name: shared-memory\n"
+             " mountPath: /tmp/ifrt_proxy"
+         )
+     else:
+         proxy_sidecar_arg = ""
+         tpu_premapped_buffer_size = 274877906944 # 256 GiB
+         worker_init_containers = ""
+
+     # Interpolate variables in the template
+     yaml_content = YAML_TEMPLATE.format(
+         NAME=name,
+         NAMESPACE=namespace,
+         CLIENT_IMAGE=jax_client_image,
+         CLIENT_EXECUTION_COMMAND=client_command,
+         TPU_TYPE=tpu_type,
+         NUM_SLICES=num_slices,
+         RM_INSTANCE_TYPE=rm_instance_type,
+         GCS_SCRATCH_LOCATION=gcs_scratch_location,
+         GKE_TOPOLOGY=gke_topology,
+         VMS_PER_SLICE=vms_per_slice,
+         GKE_ACCELERATOR=mapping["gke_accelerator"],
+         CHIPS_PER_VM=mapping["chips_per_vm"],
+         SPOT_TOLERATION_HEAD=spot_toleration_head,
+         SPOT_NODE_SELECTOR_WORKER=spot_node_selector_worker,
+         SPOT_TOLERATION_WORKER=spot_toleration_worker,
+         PROXY_SIDECAR_ARG=proxy_sidecar_arg,
+         TPU_PREMAPPED_BUFFER_SIZE=tpu_premapped_buffer_size,
+         WORKER_INIT_CONTAINERS=worker_init_containers,
+     )
+
+     # Clean up empty lines caused by optional block placeholders
+     # (specifically ensuring there are no lines with only whitespace or empty lines where placeholders were)
+     lines = []
+     for line in yaml_content.splitlines():
+         if line.strip() or line == "":
+             lines.append(line)
+     return "\n".join(lines) + "\n"
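Note on `get_colocated_python_image`: the function derives the sidecar image by swapping the last path component of the client image for `colocated-python` while keeping the tag, and otherwise falls back to a pinned default. The sketch below reproduces the function from the hunk above (the indentation is reconstructed, and `DEFAULT_IMAGE` is a name introduced here for brevity) and runs it on a few inputs. The image names other than the default are hypothetical.

```python
DEFAULT_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/colocated-python:jax-0.10.0"


def get_colocated_python_image(client_image: str) -> str:
    # Same logic as the function added in this diff.
    if "/" in client_image and ":" in client_image:
        try:
            path, tag = client_image.rsplit(":", 1)
            repo, _ = path.rsplit("/", 1)
            return f"{repo}/colocated-python:{tag}"
        except Exception:
            pass
    return DEFAULT_IMAGE


# A fully qualified image keeps its repository and tag.
print(get_colocated_python_image("us-docker.pkg.dev/my-project/jax/client:v1"))
# -> us-docker.pkg.dev/my-project/jax/colocated-python:v1

# The default client image has no "/" and therefore uses the fallback.
print(get_colocated_python_image("python:3.12-slim") == DEFAULT_IMAGE)
# -> True

# A registry port without a tag: the text after the last ":" is "5000/client",
# the remaining "localhost" has no "/", unpacking fails, and the fallback is used.
print(get_colocated_python_image("localhost:5000/client") == DEFAULT_IMAGE)
# -> True
```

The last case shows that the broad `except` also covers untagged images on a registry with a port, which silently resolve to the default image rather than raising.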
@@ -139,7 +139,7 @@ spec:
  - metadata
  - metadata.google.internal
  nodeSelector:
- cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
+ cloud.google.com/gke-tpu-accelerator: {GKE_ACCELERATOR}
  cloud.google.com/gke-tpu-topology: {GKE_TOPOLOGY}
  {SPOT_NODE_SELECTOR_WORKER}
  tolerations:
@@ -163,7 +163,7 @@ spec:
  privileged: true
  resources:
  limits:
- google.com/tpu: 4
+ google.com/tpu: {CHIPS_PER_VM}
  env:
  - name: TPU_TYPE
  value: {TPU_TYPE}
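Note on the two template hunks above: the hard-coded accelerator label and chip limit become `str.format` placeholders, so one template can serve every TPU generation. The sketch below illustrates the mechanism with a heavily trimmed, hypothetical stand-in for `YAML_TEMPLATE` (the real template and its indentation are not fully visible in this diff), followed by the same whitespace-only line filter that `generate_yaml` applies. `MINI_TEMPLATE` and `render` are names introduced only for this illustration.

```python
# Hypothetical, trimmed stand-in for pwy.templates.YAML_TEMPLATE.
MINI_TEMPLATE = (
    "nodeSelector:\n"
    "  cloud.google.com/gke-tpu-accelerator: {GKE_ACCELERATOR}\n"
    "  cloud.google.com/gke-tpu-topology: {GKE_TOPOLOGY}\n"
    "  {SPOT_NODE_SELECTOR_WORKER}\n"
    "resources:\n"
    "  limits:\n"
    "    google.com/tpu: {CHIPS_PER_VM}\n"
)

# Values copied from the "v6e-8-1" entry of TPU_MAPPINGS in this diff.
V6E_8_1 = {"topology": "2x4", "gke_accelerator": "tpu-v6e-slice", "chips_per_vm": 8}


def render(mapping: dict, spot_selector: str = "") -> str:
    text = MINI_TEMPLATE.format(
        GKE_ACCELERATOR=mapping["gke_accelerator"],
        GKE_TOPOLOGY=mapping["topology"],
        CHIPS_PER_VM=mapping["chips_per_vm"],
        SPOT_NODE_SELECTOR_WORKER=spot_selector,
    )
    # Same filter as generate_yaml: drop lines that contain only whitespace
    # (left behind by empty placeholders) but keep truly empty lines.
    kept = [line for line in text.splitlines() if line.strip() or line == ""]
    return "\n".join(kept) + "\n"


print(render(V6E_8_1))
```

With an empty `spot_selector`, the placeholder line collapses to two spaces and is removed by the filter, so the rendered YAML contains no stray indented blank line.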
@@ -108,7 +108,33 @@ def test_generate_yaml_invalid_tpu_type():
          generate_yaml(
              name="test-run",
              namespace="default",
-             tpu_type="v5p-8", # not in the mappings
+             tpu_type="invalid-tpu", # not in the mappings
              gcs_scratch_location="gs://my-bucket/staging",
          )
-     assert "Unsupported TPU type: v5p-8" in str(excinfo.value)
+     assert "Unsupported TPU type: invalid-tpu" in str(excinfo.value)
+
+ def test_generate_yaml_v5p_8():
+     yaml_content = generate_yaml(
+         name="test-run",
+         namespace="default",
+         tpu_type="v5p-8",
+         gcs_scratch_location="gs://my-bucket/staging",
+     )
+     assert "cloud.google.com/gke-tpu-accelerator: tpu-v5p-slice" in yaml_content
+     assert "cloud.google.com/gke-tpu-topology: 2x2x1" in yaml_content
+     assert "--instance_type=tpuv5p:2x2x1" in yaml_content
+     assert "google.com/tpu: 4" in yaml_content
+
+ def test_generate_yaml_v6e_8_1_eight_chips():
+     yaml_content = generate_yaml(
+         name="test-run",
+         namespace="default",
+         tpu_type="v6e-8-1",
+         gcs_scratch_location="gs://my-bucket/staging",
+     )
+     assert "cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice" in yaml_content
+     assert "cloud.google.com/gke-tpu-topology: 2x4" in yaml_content
+     assert "--instance_type=tpuv6e:2x4" in yaml_content
+     assert "google.com/tpu: 8" in yaml_content
+
+
@@ -1,122 +0,0 @@
- from pwy.templates import YAML_TEMPLATE
-
- TPU_MAPPINGS = {
-     "v6e-4": {"topology": "2x2", "vms_per_slice": 1, "rm_type": "tpuv6e:2x2"},
-     "v6e-8": {"topology": "2x4", "vms_per_slice": 2, "rm_type": "tpuv6e:2x4"},
-     "v6e-16": {"topology": "4x4", "vms_per_slice": 4, "rm_type": "tpuv6e:4x4"},
-     "v6e-32": {"topology": "4x8", "vms_per_slice": 8, "rm_type": "tpuv6e:4x8"},
-     "v6e-64": {"topology": "8x8", "vms_per_slice": 16, "rm_type": "tpuv6e:8x8"},
- }
-
- def get_colocated_python_image(client_image: str) -> str:
-     if "/" in client_image and ":" in client_image:
-         try:
-             path, tag = client_image.rsplit(":", 1)
-             repo, _ = path.rsplit("/", 1)
-             return f"{repo}/colocated-python:{tag}"
-         except Exception:
-             pass
-     return "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/colocated-python:jax-0.10.0"
-
- def generate_yaml(
-     name: str,
-     namespace: str,
-     tpu_type: str,
-     gcs_scratch_location: str,
-     num_slices: int = 1,
-     jax_client_image: str = "python:3.12-slim",
-     command: str = None,
-     enable_spot: bool = False,
-     colocated_python: bool = False,
- ) -> str:
-     if tpu_type not in TPU_MAPPINGS:
-         raise ValueError(
-             f"Unsupported TPU type: {tpu_type}. Supported types: {list(TPU_MAPPINGS.keys())}"
-         )
-
-     mapping = TPU_MAPPINGS[tpu_type]
-     gke_topology = mapping["topology"]
-     vms_per_slice = mapping["vms_per_slice"]
-     rm_instance_type = mapping["rm_type"]
-
-     # Format client execution command
-     if not command:
-         client_command = "sleep infinity"
-     else:
-         client_command = command
-
-     # Format Spot VM Node Selector and Tolerations
-     if enable_spot:
-         spot_toleration_head = (
-             ' - key: "cloud.google.com/gke-spot"\n'
-             ' operator: "Equal"\n'
-             ' value: "true"\n'
-             ' effect: "NoSchedule"'
-         )
-         spot_node_selector_worker = ' cloud.google.com/gke-spot: "true"'
-         spot_toleration_worker = (
-             ' - key: "cloud.google.com/gke-spot"\n'
-             ' operator: "Equal"\n'
-             ' value: "true"\n'
-             ' effect: "NoSchedule"'
-         )
-     else:
-         spot_toleration_head = ""
-         spot_node_selector_worker = ""
-         spot_toleration_worker = ""
-
-     # Format colocated python options
-     if colocated_python:
-         proxy_sidecar_arg = "\n - --sidecar_name=external"
-         tpu_premapped_buffer_size = 34359738368 # 32 GiB
-         colocated_img = get_colocated_python_image(jax_client_image)
-         worker_init_containers = (
-             " initContainers:\n"
-             " - name: colocated-python\n"
-             f" image: {colocated_img}\n"
-             " imagePullPolicy: Always\n"
-             " restartPolicy: Always\n"
-             " ports:\n"
-             " - containerPort: 50051\n"
-             " protocol: TCP\n"
-             " env:\n"
-             " - name: CLOUD_PATHWAYS_SIDECAR_SHM_DIRECTORY\n"
-             " value: /tmp/ifrt_proxy\n"
-             " - name: GRPC_SERVER_ADDRESS\n"
-             " value: 0.0.0.0:50051\n"
-             " volumeMounts:\n"
-             " - name: shared-memory\n"
-             " mountPath: /tmp/ifrt_proxy"
-         )
-     else:
-         proxy_sidecar_arg = ""
-         tpu_premapped_buffer_size = 274877906944 # 256 GiB
-         worker_init_containers = ""
-
-     # Interpolate variables in the template
-     yaml_content = YAML_TEMPLATE.format(
-         NAME=name,
-         NAMESPACE=namespace,
-         CLIENT_IMAGE=jax_client_image,
-         CLIENT_EXECUTION_COMMAND=client_command,
-         TPU_TYPE=tpu_type,
-         NUM_SLICES=num_slices,
-         RM_INSTANCE_TYPE=rm_instance_type,
-         GCS_SCRATCH_LOCATION=gcs_scratch_location,
-         GKE_TOPOLOGY=gke_topology,
-         VMS_PER_SLICE=vms_per_slice,
-         SPOT_TOLERATION_HEAD=spot_toleration_head,
-         SPOT_NODE_SELECTOR_WORKER=spot_node_selector_worker,
-         SPOT_TOLERATION_WORKER=spot_toleration_worker,
-         PROXY_SIDECAR_ARG=proxy_sidecar_arg,
-         TPU_PREMAPPED_BUFFER_SIZE=tpu_premapped_buffer_size,
-         WORKER_INIT_CONTAINERS=worker_init_containers,
-     )
-
-     # Clean up empty lines caused by optional block placeholders
-     # (specifically ensuring there are no lines with only whitespace or empty lines where placeholders were)
-     lines = []
-     for line in yaml_content.splitlines():
-         if line.strip() or line == "":
-             lines.append(line)
-     return "\n".join(lines) + "\n"
File without changes