trainy-policy-nightly 0.1.0.dev20251025104156__tar.gz → 0.1.0.dev20251226104750__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of trainy-policy-nightly might be problematic.
- {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/PKG-INFO +1 -1
- {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/pyproject.toml +1 -1
- {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/trainy/__init__.py +1 -1
- trainy_policy_nightly-0.1.0.dev20251226104750/trainy/policy/gke.py +477 -0
- trainy_policy_nightly-0.1.0.dev20251025104156/trainy/policy/gke.py +0 -224
- {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/README.md +0 -0
- {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/trainy/config.py +0 -0
- {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/trainy/configs/eks.yaml +0 -0
- {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/trainy/configs/gke.yaml +0 -0
- {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/trainy/configs/nebius.yaml +0 -0
- {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/trainy/logging.py +0 -0
- {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/trainy/policy/__init__.py +0 -0
- {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/trainy/policy/eks.py +0 -0
- {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/trainy/policy/nebius.py +0 -0
trainy_policy_nightly-0.1.0.dev20251226104750/trainy/policy/gke.py (added)
@@ -0,0 +1,477 @@
+"""GKE policy
+Users must use the following `~/.sky/config.yaml`
+
+admin_policy: trainy.policy.GKEPolicy
+
+kubernetes:
+  autoscaler: gke
+  provision_timeout: 600 # should be at least 10 min
+  remote_identity: SERVICE_ACCOUNT
+
+"""
+
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Sequence, Tuple
+
+import sky
+import uuid
+
+from trainy.config import load_config
+from trainy.logging import get_logger
+
+logger = get_logger(__file__)
+
+DEFAULT_QUEUE = "user-queue"
+PRIORITY_CLASSES = ["low-priority", "high-priority"]
+ALLOWED_GPUS = ["H100", "H100-MEGA-80GB", "A100-80GB", "A100"]
+
+_POD_SPEC_PATH: Tuple[str, ...] = ("kubernetes", "pod_config", "spec")
+
+
+def _get_dict_path(
+    data: Dict[str, Any], path: Sequence[str]
+) -> Optional[Dict[str, Any]]:
+    """Returns the nested dict at path, creating dicts along the way."""
+    cursor: Dict[str, Any] = data
+    for key in path:
+        if key not in cursor or not isinstance(cursor[key], dict):
+            cursor[key] = {}
+        cursor = cursor[key]
+    return cursor
+
+
+def _merge_named_items(
+    base_list: Optional[List[Dict[str, Any]]],
+    override_list: Optional[List[Dict[str, Any]]],
+) -> List[Dict[str, Any]]:
+    """Merges two lists of dicts keyed by `name` while preserving order."""
+    merged: List[Dict[str, Any]] = []
+    index_by_name: Dict[str, int] = {}
+
+    def _append(item: Dict[str, Any]) -> None:
+        index_by_name[item.get("name", f"__idx_{len(merged)}")] = len(merged)
+        merged.append(item)
+
+    for entry in base_list or []:
+        if isinstance(entry, dict):
+            _append(deepcopy(entry))
+
+    for entry in override_list or []:
+        if not isinstance(entry, dict):
+            continue
+        name = entry.get("name")
+        replacement = deepcopy(entry)
+        if name is not None and name in index_by_name:
+            merged[index_by_name[name]] = replacement
+        else:
+            _append(replacement)
+
+    return merged
+
+
+def _merge_container_sections(
+    base_containers: Optional[List[Dict[str, Any]]],
+    override_containers: Optional[List[Dict[str, Any]]],
+) -> List[Dict[str, Any]]:
+    """Merges env/volumeMounts between lists of containers."""
+    base_containers = base_containers or []
+    override_containers = override_containers or []
+    merged_containers: List[Dict[str, Any]] = []
+    max_len = max(len(base_containers), len(override_containers))
+
+    for idx in range(max_len):
+        base_container = base_containers[idx] if idx < len(base_containers) else None
+        override_container = (
+            deepcopy(override_containers[idx]) if idx < len(override_containers) else None
+        )
+
+        if override_container is None:
+            if base_container is not None:
+                merged_containers.append(deepcopy(base_container))
+            continue
+
+        merged_env = _merge_named_items(
+            base_container.get("env") if base_container else None,
+            override_container.get("env"),
+        )
+        if merged_env:
+            override_container["env"] = merged_env
+        elif "env" in override_container:
+            del override_container["env"]
+
+        merged_volume_mounts = _merge_named_items(
+            base_container.get("volumeMounts") if base_container else None,
+            override_container.get("volumeMounts"),
+        )
+        if merged_volume_mounts:
+            override_container["volumeMounts"] = merged_volume_mounts
+        elif "volumeMounts" in override_container:
+            del override_container["volumeMounts"]
+
+        merged_containers.append(override_container)
+
+    return merged_containers
+
+
+def _merge_pod_spec_sections(
+    config: sky.skypilot_config.Config, override_config: Dict[str, Any]
+) -> Dict[str, Any]:
+    """Returns override_config with volumes/env/volumeMounts merged from config."""
+    merged_override = deepcopy(override_config)
+    spec = _get_dict_path(merged_override, _POD_SPEC_PATH)
+
+    base_volumes = config.get_nested(_POD_SPEC_PATH + ("volumes",), None)
+    override_volumes = spec.get("volumes")
+    merged_volumes = _merge_named_items(base_volumes, override_volumes)
+    if merged_volumes:
+        spec["volumes"] = merged_volumes
+    elif "volumes" in spec:
+        del spec["volumes"]
+
+    merged_containers = _merge_container_sections(
+        config.get_nested(_POD_SPEC_PATH + ("containers",), None),
+        spec.get("containers"),
+    )
+    if merged_containers:
+        spec["containers"] = merged_containers
+    elif "containers" in spec:
+        del spec["containers"]
+
+    merged_init_containers = _merge_container_sections(
+        config.get_nested(_POD_SPEC_PATH + ("initContainers",), None),
+        spec.get("initContainers"),
+    )
+    if merged_init_containers:
+        spec["initContainers"] = merged_init_containers
+    elif "initContainers" in spec:
+        del spec["initContainers"]
+
+    return merged_override
+
+_POD_SPEC_PATH: Tuple[str, ...] = ("kubernetes", "pod_config", "spec")
+
+
+def _get_dict_path(
+    data: Dict[str, Any], path: Sequence[str]
+) -> Optional[Dict[str, Any]]:
+    """Returns the nested dict at path, creating dicts along the way."""
+    cursor: Dict[str, Any] = data
+    for key in path:
+        if key not in cursor or not isinstance(cursor[key], dict):
+            cursor[key] = {}
+        cursor = cursor[key]
+    return cursor
+
+
+def _merge_named_items(
+    base_list: Optional[List[Dict[str, Any]]],
+    override_list: Optional[List[Dict[str, Any]]],
+) -> List[Dict[str, Any]]:
+    """Merges two lists of dicts keyed by `name` while preserving order."""
+    merged: List[Dict[str, Any]] = []
+    index_by_name: Dict[str, int] = {}
+
+    def _append(item: Dict[str, Any]) -> None:
+        index_by_name[item.get("name", f"__idx_{len(merged)}")] = len(merged)
+        merged.append(item)
+
+    for entry in base_list or []:
+        if isinstance(entry, dict):
+            _append(deepcopy(entry))
+
+    for entry in override_list or []:
+        if not isinstance(entry, dict):
+            continue
+        name = entry.get("name")
+        replacement = deepcopy(entry)
+        if name is not None and name in index_by_name:
+            merged[index_by_name[name]] = replacement
+        else:
+            _append(replacement)
+
+    return merged
+
+
+def _merge_container_sections(
+    base_containers: Optional[List[Dict[str, Any]]],
+    override_containers: Optional[List[Dict[str, Any]]],
+) -> List[Dict[str, Any]]:
+    """Merges env/volumeMounts between lists of containers."""
+    base_containers = base_containers or []
+    override_containers = override_containers or []
+    merged_containers: List[Dict[str, Any]] = []
+    max_len = max(len(base_containers), len(override_containers))
+
+    for idx in range(max_len):
+        base_container = base_containers[idx] if idx < len(base_containers) else None
+        override_container = (
+            deepcopy(override_containers[idx]) if idx < len(override_containers) else None
+        )
+
+        if override_container is None:
+            if base_container is not None:
+                merged_containers.append(deepcopy(base_container))
+            continue
+
+        merged_env = _merge_named_items(
+            base_container.get("env") if base_container else None,
+            override_container.get("env"),
+        )
+        if merged_env:
+            override_container["env"] = merged_env
+        elif "env" in override_container:
+            del override_container["env"]
+
+        merged_volume_mounts = _merge_named_items(
+            base_container.get("volumeMounts") if base_container else None,
+            override_container.get("volumeMounts"),
+        )
+        if merged_volume_mounts:
+            override_container["volumeMounts"] = merged_volume_mounts
+        elif "volumeMounts" in override_container:
+            del override_container["volumeMounts"]
+
+        merged_containers.append(override_container)
+
+    return merged_containers
+
+
+def _merge_pod_spec_sections(
+    config: sky.skypilot_config.Config, override_config: Dict[str, Any]
+) -> Dict[str, Any]:
+    """Returns override_config with volumes/env/volumeMounts merged from config."""
+    merged_override = deepcopy(override_config)
+    spec = _get_dict_path(merged_override, _POD_SPEC_PATH)
+
+    base_volumes = config.get_nested(_POD_SPEC_PATH + ("volumes",), None)
+    override_volumes = spec.get("volumes")
+    merged_volumes = _merge_named_items(base_volumes, override_volumes)
+    if merged_volumes:
+        spec["volumes"] = merged_volumes
+    elif "volumes" in spec:
+        del spec["volumes"]
+
+    merged_containers = _merge_container_sections(
+        config.get_nested(_POD_SPEC_PATH + ("containers",), None),
+        spec.get("containers"),
+    )
+    if merged_containers:
+        spec["containers"] = merged_containers
+    elif "containers" in spec:
+        del spec["containers"]
+
+    merged_init_containers = _merge_container_sections(
+        config.get_nested(_POD_SPEC_PATH + ("initContainers",), None),
+        spec.get("initContainers"),
+    )
+    if merged_init_containers:
+        spec["initContainers"] = merged_init_containers
+    elif "initContainers" in spec:
+        del spec["initContainers"]
+
+    return merged_override
+
+
+def set_tcpxo_config(user_request: sky.UserRequest) -> sky.MutatedUserRequest:
+    """Sets pod specs for running TCPXO"""
+    task = user_request.task
+    config = user_request.skypilot_config
+    for resource in task.resources:
+        for accelerator, count in resource.accelerators.items():
+            if accelerator == "H100-MEGA-80GB":
+                k8s_override_config = load_config("gke.yaml")
+                config = user_request.skypilot_config
+                merged_override = _merge_pod_spec_sections(
+                    config, k8s_override_config
+                )
+                new_config = sky.skypilot_config._recursive_update(
+                    config, merged_override
+                )
+
+                return sky.MutatedUserRequest(
+                    task=user_request.task, skypilot_config=new_config
+                )
+    return sky.MutatedUserRequest(task=task, skypilot_config=config)
+
+
+def validate_set_kueue_dws_labels_annotations(
+    user_request: sky.MutatedUserRequest,
+) -> sky.MutatedUserRequest:
+    """Checks for valid Kueue user-queue and priority queue values
+    populates the required labels and annotations for Kueue
+    """
+    task = user_request.task
+    config = user_request.skypilot_config
+    new_resources = []
+    for resource in task.resources:
+        if resource.cloud is None or str(resource.cloud) != "Kubernetes":
+            raise ValueError("Only `kubernetes` is permitted as a cloud on Trainy")
+        if not resource.accelerators:
+            raise ValueError(
+                "You must request a GPU instance. Set `accelerators: "
+                "H100-MEGA-80GB:8` under resources for example"
+            )
+        for accelerator, count in resource.accelerators.items():
+            if accelerator not in ALLOWED_GPUS:
+                raise ValueError(
+                    f"{resource.accelerators} requested,"
+                    f"only `{ALLOWED_GPUS}` allowed"
+                )
+            nodepool = (
+                "kubernetes",
+                "pod_config",
+                "spec",
+                "nodeSelector",
+                "cloud.google.com/gke-nodepool",
+            )
+            if config.get_nested(nodepool, None) is None:
+                logger.info(
+                    "`cloud.google.com/gke-nodepool` not set, "
+                    f"setting to default nodepool `{accelerator.lower()}-pool`"
+                )
+                config.set_nested(nodepool, f"{accelerator.lower()}-pool")
+
+        labels = resource.labels
+        if labels is None:
+            labels = {}
+        queue_name: str = labels.get("kueue.x-k8s.io/queue-name", DEFAULT_QUEUE)
+        priority: str = labels.get("kueue.x-k8s.io/priority-class", "low-priority")
+        run_duration: str = labels.get("max-run-duration-seconds", None)
+        # if queue_name != DEFAULT_QUEUE:
+        #     raise ValueError(
+        #         f"{queue_name} queue was selected, "
+        #         f"only {DEFAULT_QUEUE} queue is permitted for hosted Trainy clusters"
+        #     )
+        if priority not in PRIORITY_CLASSES:
+            raise ValueError(
+                f"priority `{priority}` was selected, "
+                f"only {PRIORITY_CLASSES} are available"
+            )
+        if task.name is None:
+            raise ValueError("no sky.Task name defined. You must set a task name")
+        if len(task.name) > 59:
+            raise ValueError(f"sky.Task name is {len(task.name)} long. Expected 58 characters or less.")
+        labels.update(
+            {
+                "kueue.x-k8s.io/queue-name": queue_name,
+                "kueue.x-k8s.io/priority-class": priority,
+                "kueue.x-k8s.io/pod-group-name": f"{task.name}-"
+                f"{uuid.uuid4().hex[:4]}",
+            }
+        )
+        if resource.labels is not None:
+            resource.labels.update(labels)
+        new_resources.append(resource)
+    task.set_resources(type(task.resources)(new_resources))
+
+    # pod annotations
+    config.set_nested(
+        (
+            "kubernetes",
+            "pod_config",
+            "metadata",
+            "annotations",
+            "kueue.x-k8s.io/pod-group-total-count",
+        ),
+        str(task.num_nodes),
+    )
+    config.set_nested(
+        (
+            "kubernetes",
+            "pod_config",
+            "metadata",
+            "annotations",
+            "kueue.x-k8s.io/retriable-in-group",
+        ),
+        "false",
+    )
+
+    maxRunDurationSeconds = (
+        "kubernetes",
+        "pod_config",
+        "metadata",
+        "annotations",
+        "provreq.kueue.x-k8s.io/maxRunDurationSeconds",
+    )
+    if config.get_nested(maxRunDurationSeconds, None) is None:
+        if run_duration is None:
+            raise ValueError("You must specify a label for `max-run-duration-seconds`")
+        # maximum runtime on gke dws is 7 days
+        config.set_nested(maxRunDurationSeconds, str(run_duration))
+
+    run_duration = config.get_nested(maxRunDurationSeconds, None)
+    assert run_duration is not None
+    if not (0 < int(run_duration) <= 3600 * 24 * 7):
+        raise ValueError(
+            f"largest allowed run duration is 7 days "
+            f" = {3600 * 24 * 7} seconds {int(run_duration)} requested "
+            "from either `max-run-duration-seconds` or "
+            "`provreq.kueue.x-k8s.io/maxRunDurationSeconds`"
+        )
+
+    safe_to_evict = (
+        "kubernetes",
+        "pod_config",
+        "metadata",
+        "annotations",
+        "cluster-autoscaler.kubernetes.io/safe-to-evict",
+    )
+    if config.get_nested(safe_to_evict, None) is None:
+        config.set_nested(safe_to_evict, "false")
+
+    return sky.MutatedUserRequest(task=task, skypilot_config=config)
+
+
+class GKEPolicy(sky.AdminPolicy):
+    """GKE specific configurations."""
+
+    @classmethod
+    def validate_and_mutate(
+        cls, user_request: sky.UserRequest
+    ) -> sky.MutatedUserRequest:
+        """Updates the kubernetes context to use
+        and kueue labels and sets GKE autoscaler
+        """
+        if not user_request.task.is_controller_task():
+            new_request: sky.MutatedUserRequest = set_tcpxo_config(user_request)
+            new_request = validate_set_kueue_dws_labels_annotations(user_request)
+            return sky.MutatedUserRequest(
+                task=new_request.task, skypilot_config=new_request.skypilot_config
+            )
+        return sky.MutatedUserRequest(
+            task=user_request.task, skypilot_config=user_request.skypilot_config
+        )
+
+
+def configure_and_get_allowed_contexts():
+    """Mock implementation of getting allowed kubernetes contexts."""
+    from sky.provision.kubernetes import utils
+
+    contexts = utils.get_all_kube_config_context_names()
+    return contexts[:2]
+
+
+class TailscaleGKEPolicy(GKEPolicy):
+    @classmethod
+    def validate_and_mutate(
+        cls, user_request: sky.UserRequest
+    ) -> sky.MutatedUserRequest:
+        """Updates the kubernetes context to use
+        and kueue labels and sets GKE autoscaler
+        """
+
+        super().validate_and_mutate(user_request=user_request)
+
+        # Append any new kubernetes clusters in local kubeconfig. An example
+        # implementation of this method can be:
+        # 1. Query tailscale for k8s clusters.
+        # 2. Append the new credentials to the local kubeconfig.
+        # 3. Set the allow contexts for the cluster.
+
+        # Get the allowed contexts for the user. Similarly, it can retrieve
+        # the latest allowed contexts from an organization's internal API.
+        # allowed_contexts = configure_and_get_allowed_contexts()

+        # # Update the kubernetes allowed contexts in skypilot config.
+        # config = user_request.skypilot_config
+        # config.set_nested(("kubernetes", "allowed_contexts"), allowed_contexts)
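The main addition in this release is the set of pod-spec merge helpers (`_get_dict_path`, `_merge_named_items`, `_merge_container_sections`, `_merge_pod_spec_sections`) used by `set_tcpxo_config`: instead of letting the TCPXO override from `gke.yaml` clobber the user's `pod_config` lists, entries are now merged by their `name` key. A minimal sketch of the expected behavior, assuming the new release is installed; the env var names and values below are illustrative, not taken from the package's `gke.yaml`:

```python
# Illustration only: merge-by-name behavior of _merge_named_items as defined
# in the hunk above. Base entries keep their order; same-named override
# entries replace them in place; new names are appended at the end.
from trainy.policy.gke import _merge_named_items

base_env = [
    {"name": "NCCL_DEBUG", "value": "INFO"},        # from the user's pod_config
    {"name": "MY_CUSTOM_VAR", "value": "keep-me"},  # user-only entry
]
override_env = [
    {"name": "NCCL_DEBUG", "value": "WARN"},        # same name -> replaced in place
    {"name": "EXTRA_TCPXO_VAR", "value": "x"},      # new name -> appended
]

merged = _merge_named_items(base_env, override_env)
# Expected: NCCL_DEBUG takes the override value, MY_CUSTOM_VAR is preserved,
# EXTRA_TCPXO_VAR is appended at the end.
print(merged)
```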
trainy_policy_nightly-0.1.0.dev20251025104156/trainy/policy/gke.py (removed)
@@ -1,224 +0,0 @@
-"""GKE policy
-Users must use the following `~/.sky/config.yaml`
-
-admin_policy: trainy.policy.GKEPolicy
-
-kubernetes:
-  autoscaler: gke
-  provision_timeout: 600 # should be at least 10 min
-  remote_identity: SERVICE_ACCOUNT
-
-"""
-
-import sky
-import uuid
-
-from trainy.config import load_config
-from trainy.logging import get_logger
-
-logger = get_logger(__file__)
-
-DEFAULT_QUEUE = "user-queue"
-PRIORITY_CLASSES = ["low-priority", "high-priority"]
-ALLOWED_GPUS = ["H100", "H100-MEGA-80GB", "A100-80GB", "A100"]
-
-
-def set_tcpxo_config(user_request: sky.UserRequest) -> sky.MutatedUserRequest:
-    """Sets pod specs for running TCPXO"""
-    task = user_request.task
-    config = user_request.skypilot_config
-    for resource in task.resources:
-        for accelerator, count in resource.accelerators.items():
-            if accelerator == "H100-MEGA-80GB":
-                k8s_override_config = load_config("gke.yaml")
-                config = user_request.skypilot_config
-                new_config = sky.skypilot_config._recursive_update(
-                    config, k8s_override_config
-                )
-                return sky.MutatedUserRequest(
-                    task=user_request.task, skypilot_config=new_config
-                )
-    return sky.MutatedUserRequest(task=task, skypilot_config=config)
-
-
-def validate_set_kueue_dws_labels_annotations(
-    user_request: sky.MutatedUserRequest,
-) -> sky.MutatedUserRequest:
-    """Checks for valid Kueue user-queue and priority queue values
-    populates the required labels and annotations for Kueue
-    """
-    task = user_request.task
-    config = user_request.skypilot_config
-    new_resources = []
-    for resource in task.resources:
-        if resource.cloud is None or str(resource.cloud) != "Kubernetes":
-            raise ValueError("Only `kubernetes` is permitted as a cloud on Trainy")
-        if not resource.accelerators:
-            raise ValueError(
-                "You must request a GPU instance. Set `accelerators: "
-                "H100-MEGA-80GB:8` under resources for example"
-            )
-        for accelerator, count in resource.accelerators.items():
-            if accelerator not in ALLOWED_GPUS:
-                raise ValueError(
-                    f"{resource.accelerators} requested,"
-                    f"only `{ALLOWED_GPUS}` allowed"
-                )
-            nodepool = (
-                "kubernetes",
-                "pod_config",
-                "spec",
-                "nodeSelector",
-                "cloud.google.com/gke-nodepool",
-            )
-            if config.get_nested(nodepool, None) is None:
-                logger.info(
-                    "`cloud.google.com/gke-nodepool` not set, "
-                    f"setting to default nodepool `{accelerator.lower()}-pool`"
-                )
-                config.set_nested(nodepool, f"{accelerator.lower()}-pool")
-
-        labels = resource.labels
-        if labels is None:
-            labels = {}
-        queue_name: str = labels.get("kueue.x-k8s.io/queue-name", DEFAULT_QUEUE)
-        priority: str = labels.get("kueue.x-k8s.io/priority-class", "low-priority")
-        run_duration: str = labels.get("max-run-duration-seconds", None)
-        # if queue_name != DEFAULT_QUEUE:
-        #     raise ValueError(
-        #         f"{queue_name} queue was selected, "
-        #         f"only {DEFAULT_QUEUE} queue is permitted for hosted Trainy clusters"
-        #     )
-        if priority not in PRIORITY_CLASSES:
-            raise ValueError(
-                f"priority `{priority}` was selected, "
-                f"only {PRIORITY_CLASSES} are available"
-            )
-        if task.name is None:
-            raise ValueError("no sky.Task name defined. You must set a task name")
-        if len(task.name) > 59:
-            raise ValueError(f"sky.Task name is {len(task.name)} long. Expected 58 characters or less.")
-        labels.update(
-            {
-                "kueue.x-k8s.io/queue-name": queue_name,
-                "kueue.x-k8s.io/priority-class": priority,
-                "kueue.x-k8s.io/pod-group-name": f"{task.name}-"
-                f"{uuid.uuid4().hex[:4]}",
-            }
-        )
-        if resource.labels is not None:
-            resource.labels.update(labels)
-        new_resources.append(resource)
-    task.set_resources(type(task.resources)(new_resources))
-
-    # pod annotations
-    config.set_nested(
-        (
-            "kubernetes",
-            "pod_config",
-            "metadata",
-            "annotations",
-            "kueue.x-k8s.io/pod-group-total-count",
-        ),
-        str(task.num_nodes),
-    )
-    config.set_nested(
-        (
-            "kubernetes",
-            "pod_config",
-            "metadata",
-            "annotations",
-            "kueue.x-k8s.io/retriable-in-group",
-        ),
-        "false",
-    )
-
-    maxRunDurationSeconds = (
-        "kubernetes",
-        "pod_config",
-        "metadata",
-        "annotations",
-        "provreq.kueue.x-k8s.io/maxRunDurationSeconds",
-    )
-    if config.get_nested(maxRunDurationSeconds, None) is None:
-        if run_duration is None:
-            raise ValueError("You must specify a label for `max-run-duration-seconds`")
-        # maximum runtime on gke dws is 7 days
-        config.set_nested(maxRunDurationSeconds, str(run_duration))
-
-    run_duration = config.get_nested(maxRunDurationSeconds, None)
-    assert run_duration is not None
-    if not (0 < int(run_duration) <= 3600 * 24 * 7):
-        raise ValueError(
-            f"largest allowed run duration is 7 days "
-            f" = {3600 * 24 * 7} seconds {int(run_duration)} requested "
-            "from either `max-run-duration-seconds` or "
-            "`provreq.kueue.x-k8s.io/maxRunDurationSeconds`"
-        )
-
-    safe_to_evict = (
-        "kubernetes",
-        "pod_config",
-        "metadata",
-        "annotations",
-        "cluster-autoscaler.kubernetes.io/safe-to-evict",
-    )
-    if config.get_nested(safe_to_evict, None) is None:
-        config.set_nested(safe_to_evict, "false")
-
-    return sky.MutatedUserRequest(task=task, skypilot_config=config)
-
-
-class GKEPolicy(sky.AdminPolicy):
-    """GKE specific configurations."""
-
-    @classmethod
-    def validate_and_mutate(
-        cls, user_request: sky.UserRequest
-    ) -> sky.MutatedUserRequest:
-        """Updates the kubernetes context to use
-        and kueue labels and sets GKE autoscaler
-        """
-        if not user_request.task.is_controller_task():
-            new_request: sky.MutatedUserRequest = set_tcpxo_config(user_request)
-            new_request = validate_set_kueue_dws_labels_annotations(user_request)
-            return sky.MutatedUserRequest(
-                task=new_request.task, skypilot_config=new_request.skypilot_config
-            )
-        return sky.MutatedUserRequest(
-            task=user_request.task, skypilot_config=user_request.skypilot_config
-        )
-
-
-def configure_and_get_allowed_contexts():
-    """Mock implementation of getting allowed kubernetes contexts."""
-    from sky.provision.kubernetes import utils
-
-    contexts = utils.get_all_kube_config_context_names()
-    return contexts[:2]
-
-
-class TailscaleGKEPolicy(GKEPolicy):
-    @classmethod
-    def validate_and_mutate(
-        cls, user_request: sky.UserRequest
-    ) -> sky.MutatedUserRequest:
-        """Updates the kubernetes context to use
-        and kueue labels and sets GKE autoscaler
-        """
-
-        super().validate_and_mutate(user_request=user_request)
-
-        # Append any new kubernetes clusters in local kubeconfig. An example
-        # implementation of this method can be:
-        # 1. Query tailscale for k8s clusters.
-        # 2. Append the new credentials to the local kubeconfig.
-        # 3. Set the allow contexts for the cluster.
-
-        # Get the allowed contexts for the user. Similarly, it can retrieve
-        # the latest allowed contexts from an organization's internal API.
-        # allowed_contexts = configure_and_get_allowed_contexts()
-
-        # # Update the kubernetes allowed contexts in skypilot config.
-        # config = user_request.skypilot_config
-        # config.set_nested(("kubernetes", "allowed_contexts"), allowed_contexts)
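In both versions, `validate_set_kueue_dws_labels_annotations` enforces the same user-facing contract: Kubernetes-only resources, a GPU type from `ALLOWED_GPUS`, a task name within the length limit, a priority class from `PRIORITY_CLASSES`, and a run duration of at most 7 days. A hypothetical sketch of task labels that would pass those checks; the key names and limits come from the policy code above, the values are made up:

```python
# Hypothetical labels for a SkyPilot task submitted under GKEPolicy.
PRIORITY_CLASSES = ["low-priority", "high-priority"]  # mirrors the policy constant

labels = {
    "kueue.x-k8s.io/queue-name": "user-queue",         # DEFAULT_QUEUE
    "kueue.x-k8s.io/priority-class": "high-priority",  # must be in PRIORITY_CLASSES
    "max-run-duration-seconds": "86400",               # required, capped at 7 days
}

# The same bounds the policy checks before annotating the pod spec.
assert labels["kueue.x-k8s.io/priority-class"] in PRIORITY_CLASSES
assert 0 < int(labels["max-run-duration-seconds"]) <= 3600 * 24 * 7
```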