trainy-policy-nightly 0.1.0.dev20251025104156__tar.gz → 0.1.0.dev20251226104750__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of trainy-policy-nightly might be problematic.

Files changed (14)
  1. {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/PKG-INFO +1 -1
  2. {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/pyproject.toml +1 -1
  3. {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/trainy/__init__.py +1 -1
  4. trainy_policy_nightly-0.1.0.dev20251226104750/trainy/policy/gke.py +477 -0
  5. trainy_policy_nightly-0.1.0.dev20251025104156/trainy/policy/gke.py +0 -224
  6. {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/README.md +0 -0
  7. {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/trainy/config.py +0 -0
  8. {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/trainy/configs/eks.yaml +0 -0
  9. {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/trainy/configs/gke.yaml +0 -0
  10. {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/trainy/configs/nebius.yaml +0 -0
  11. {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/trainy/logging.py +0 -0
  12. {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/trainy/policy/__init__.py +0 -0
  13. {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/trainy/policy/eks.py +0 -0
  14. {trainy_policy_nightly-0.1.0.dev20251025104156 → trainy_policy_nightly-0.1.0.dev20251226104750}/trainy/policy/nebius.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: trainy-policy-nightly
- Version: 0.1.0.dev20251025104156
+ Version: 0.1.0.dev20251226104750
  Summary: Trainy Skypilot Policy
  Author: Andrew Aikawa
  Author-email: asai@berkeley.edu
pyproject.toml
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "trainy-policy-nightly"
- version = "0.1.0.dev20251025104156"
+ version = "0.1.0.dev20251226104750"
  description = "Trainy Skypilot Policy"
  packages = [
      {include = "trainy"}
trainy/__init__.py
@@ -4,7 +4,7 @@ import os
  import subprocess

  # Replaced with the current commit when building the wheels.
- _POLICY_COMMIT_SHA = "14ff34c33b11d38c00a264197f0b885f67a65c45"
+ _POLICY_COMMIT_SHA = "a7ce34b2b1ed91770b4485278b256f3dbc7d609c"


  def _get_git_commit():
trainy_policy_nightly-0.1.0.dev20251226104750/trainy/policy/gke.py (new file)
@@ -0,0 +1,477 @@
+ """GKE policy
+ Users must use the following `~/.sky/config.yaml`
+
+ admin_policy: trainy.policy.GKEPolicy
+
+ kubernetes:
+   autoscaler: gke
+   provision_timeout: 600 # should be at least 10 min
+   remote_identity: SERVICE_ACCOUNT
+
+ """
+
+ from copy import deepcopy
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
+
+ import sky
+ import uuid
+
+ from trainy.config import load_config
+ from trainy.logging import get_logger
+
+ logger = get_logger(__file__)
+
+ DEFAULT_QUEUE = "user-queue"
+ PRIORITY_CLASSES = ["low-priority", "high-priority"]
+ ALLOWED_GPUS = ["H100", "H100-MEGA-80GB", "A100-80GB", "A100"]
+
+ _POD_SPEC_PATH: Tuple[str, ...] = ("kubernetes", "pod_config", "spec")
+
+
+ def _get_dict_path(
+     data: Dict[str, Any], path: Sequence[str]
+ ) -> Optional[Dict[str, Any]]:
+     """Returns the nested dict at path, creating dicts along the way."""
+     cursor: Dict[str, Any] = data
+     for key in path:
+         if key not in cursor or not isinstance(cursor[key], dict):
+             cursor[key] = {}
+         cursor = cursor[key]
+     return cursor
+
+
+ def _merge_named_items(
+     base_list: Optional[List[Dict[str, Any]]],
+     override_list: Optional[List[Dict[str, Any]]],
+ ) -> List[Dict[str, Any]]:
+     """Merges two lists of dicts keyed by `name` while preserving order."""
+     merged: List[Dict[str, Any]] = []
+     index_by_name: Dict[str, int] = {}
+
+     def _append(item: Dict[str, Any]) -> None:
+         index_by_name[item.get("name", f"__idx_{len(merged)}")] = len(merged)
+         merged.append(item)
+
+     for entry in base_list or []:
+         if isinstance(entry, dict):
+             _append(deepcopy(entry))
+
+     for entry in override_list or []:
+         if not isinstance(entry, dict):
+             continue
+         name = entry.get("name")
+         replacement = deepcopy(entry)
+         if name is not None and name in index_by_name:
+             merged[index_by_name[name]] = replacement
+         else:
+             _append(replacement)
+
+     return merged
+
+
+ def _merge_container_sections(
+     base_containers: Optional[List[Dict[str, Any]]],
+     override_containers: Optional[List[Dict[str, Any]]],
+ ) -> List[Dict[str, Any]]:
+     """Merges env/volumeMounts between lists of containers."""
+     base_containers = base_containers or []
+     override_containers = override_containers or []
+     merged_containers: List[Dict[str, Any]] = []
+     max_len = max(len(base_containers), len(override_containers))
+
+     for idx in range(max_len):
+         base_container = base_containers[idx] if idx < len(base_containers) else None
+         override_container = (
+             deepcopy(override_containers[idx]) if idx < len(override_containers) else None
+         )
+
+         if override_container is None:
+             if base_container is not None:
+                 merged_containers.append(deepcopy(base_container))
+             continue
+
+         merged_env = _merge_named_items(
+             base_container.get("env") if base_container else None,
+             override_container.get("env"),
+         )
+         if merged_env:
+             override_container["env"] = merged_env
+         elif "env" in override_container:
+             del override_container["env"]
+
+         merged_volume_mounts = _merge_named_items(
+             base_container.get("volumeMounts") if base_container else None,
+             override_container.get("volumeMounts"),
+         )
+         if merged_volume_mounts:
+             override_container["volumeMounts"] = merged_volume_mounts
+         elif "volumeMounts" in override_container:
+             del override_container["volumeMounts"]
+
+         merged_containers.append(override_container)
+
+     return merged_containers
+
+
+ def _merge_pod_spec_sections(
+     config: sky.skypilot_config.Config, override_config: Dict[str, Any]
+ ) -> Dict[str, Any]:
+     """Returns override_config with volumes/env/volumeMounts merged from config."""
+     merged_override = deepcopy(override_config)
+     spec = _get_dict_path(merged_override, _POD_SPEC_PATH)
+
+     base_volumes = config.get_nested(_POD_SPEC_PATH + ("volumes",), None)
+     override_volumes = spec.get("volumes")
+     merged_volumes = _merge_named_items(base_volumes, override_volumes)
+     if merged_volumes:
+         spec["volumes"] = merged_volumes
+     elif "volumes" in spec:
+         del spec["volumes"]
+
+     merged_containers = _merge_container_sections(
+         config.get_nested(_POD_SPEC_PATH + ("containers",), None),
+         spec.get("containers"),
+     )
+     if merged_containers:
+         spec["containers"] = merged_containers
+     elif "containers" in spec:
+         del spec["containers"]
+
+     merged_init_containers = _merge_container_sections(
+         config.get_nested(_POD_SPEC_PATH + ("initContainers",), None),
+         spec.get("initContainers"),
+     )
+     if merged_init_containers:
+         spec["initContainers"] = merged_init_containers
+     elif "initContainers" in spec:
+         del spec["initContainers"]
+
+     return merged_override
+
+ _POD_SPEC_PATH: Tuple[str, ...] = ("kubernetes", "pod_config", "spec")
+
+
+ def _get_dict_path(
+     data: Dict[str, Any], path: Sequence[str]
+ ) -> Optional[Dict[str, Any]]:
+     """Returns the nested dict at path, creating dicts along the way."""
+     cursor: Dict[str, Any] = data
+     for key in path:
+         if key not in cursor or not isinstance(cursor[key], dict):
+             cursor[key] = {}
+         cursor = cursor[key]
+     return cursor
+
+
+ def _merge_named_items(
+     base_list: Optional[List[Dict[str, Any]]],
+     override_list: Optional[List[Dict[str, Any]]],
+ ) -> List[Dict[str, Any]]:
+     """Merges two lists of dicts keyed by `name` while preserving order."""
+     merged: List[Dict[str, Any]] = []
+     index_by_name: Dict[str, int] = {}
+
+     def _append(item: Dict[str, Any]) -> None:
+         index_by_name[item.get("name", f"__idx_{len(merged)}")] = len(merged)
+         merged.append(item)
+
+     for entry in base_list or []:
+         if isinstance(entry, dict):
+             _append(deepcopy(entry))
+
+     for entry in override_list or []:
+         if not isinstance(entry, dict):
+             continue
+         name = entry.get("name")
+         replacement = deepcopy(entry)
+         if name is not None and name in index_by_name:
+             merged[index_by_name[name]] = replacement
+         else:
+             _append(replacement)
+
+     return merged
+
+
+ def _merge_container_sections(
+     base_containers: Optional[List[Dict[str, Any]]],
+     override_containers: Optional[List[Dict[str, Any]]],
+ ) -> List[Dict[str, Any]]:
+     """Merges env/volumeMounts between lists of containers."""
+     base_containers = base_containers or []
+     override_containers = override_containers or []
+     merged_containers: List[Dict[str, Any]] = []
+     max_len = max(len(base_containers), len(override_containers))
+
+     for idx in range(max_len):
+         base_container = base_containers[idx] if idx < len(base_containers) else None
+         override_container = (
+             deepcopy(override_containers[idx]) if idx < len(override_containers) else None
+         )
+
+         if override_container is None:
+             if base_container is not None:
+                 merged_containers.append(deepcopy(base_container))
+             continue
+
+         merged_env = _merge_named_items(
+             base_container.get("env") if base_container else None,
+             override_container.get("env"),
+         )
+         if merged_env:
+             override_container["env"] = merged_env
+         elif "env" in override_container:
+             del override_container["env"]
+
+         merged_volume_mounts = _merge_named_items(
+             base_container.get("volumeMounts") if base_container else None,
+             override_container.get("volumeMounts"),
+         )
+         if merged_volume_mounts:
+             override_container["volumeMounts"] = merged_volume_mounts
+         elif "volumeMounts" in override_container:
+             del override_container["volumeMounts"]
+
+         merged_containers.append(override_container)
+
+     return merged_containers
+
+
+ def _merge_pod_spec_sections(
+     config: sky.skypilot_config.Config, override_config: Dict[str, Any]
+ ) -> Dict[str, Any]:
+     """Returns override_config with volumes/env/volumeMounts merged from config."""
+     merged_override = deepcopy(override_config)
+     spec = _get_dict_path(merged_override, _POD_SPEC_PATH)
+
+     base_volumes = config.get_nested(_POD_SPEC_PATH + ("volumes",), None)
+     override_volumes = spec.get("volumes")
+     merged_volumes = _merge_named_items(base_volumes, override_volumes)
+     if merged_volumes:
+         spec["volumes"] = merged_volumes
+     elif "volumes" in spec:
+         del spec["volumes"]
+
+     merged_containers = _merge_container_sections(
+         config.get_nested(_POD_SPEC_PATH + ("containers",), None),
+         spec.get("containers"),
+     )
+     if merged_containers:
+         spec["containers"] = merged_containers
+     elif "containers" in spec:
+         del spec["containers"]
+
+     merged_init_containers = _merge_container_sections(
+         config.get_nested(_POD_SPEC_PATH + ("initContainers",), None),
+         spec.get("initContainers"),
+     )
+     if merged_init_containers:
+         spec["initContainers"] = merged_init_containers
+     elif "initContainers" in spec:
+         del spec["initContainers"]
+
+     return merged_override
+
+
+ def set_tcpxo_config(user_request: sky.UserRequest) -> sky.MutatedUserRequest:
+     """Sets pod specs for running TCPXO"""
+     task = user_request.task
+     config = user_request.skypilot_config
+     for resource in task.resources:
+         for accelerator, count in resource.accelerators.items():
+             if accelerator == "H100-MEGA-80GB":
+                 k8s_override_config = load_config("gke.yaml")
+                 config = user_request.skypilot_config
+                 merged_override = _merge_pod_spec_sections(
+                     config, k8s_override_config
+                 )
+                 new_config = sky.skypilot_config._recursive_update(
+                     config, merged_override
+                 )
+
+                 return sky.MutatedUserRequest(
+                     task=user_request.task, skypilot_config=new_config
+                 )
+     return sky.MutatedUserRequest(task=task, skypilot_config=config)
+
+
+ def validate_set_kueue_dws_labels_annotations(
+     user_request: sky.MutatedUserRequest,
+ ) -> sky.MutatedUserRequest:
+     """Checks for valid Kueue user-queue and priority queue values
+     populates the required labels and annotations for Kueue
+     """
+     task = user_request.task
+     config = user_request.skypilot_config
+     new_resources = []
+     for resource in task.resources:
+         if resource.cloud is None or str(resource.cloud) != "Kubernetes":
+             raise ValueError("Only `kubernetes` is permitted as a cloud on Trainy")
+         if not resource.accelerators:
+             raise ValueError(
+                 "You must request a GPU instance. Set `accelerators: "
+                 "H100-MEGA-80GB:8` under resources for example"
+             )
+         for accelerator, count in resource.accelerators.items():
+             if accelerator not in ALLOWED_GPUS:
+                 raise ValueError(
+                     f"{resource.accelerators} requested,"
+                     f"only `{ALLOWED_GPUS}` allowed"
+                 )
+             nodepool = (
+                 "kubernetes",
+                 "pod_config",
+                 "spec",
+                 "nodeSelector",
+                 "cloud.google.com/gke-nodepool",
+             )
+             if config.get_nested(nodepool, None) is None:
+                 logger.info(
+                     "`cloud.google.com/gke-nodepool` not set, "
+                     f"setting to default nodepool `{accelerator.lower()}-pool`"
+                 )
+                 config.set_nested(nodepool, f"{accelerator.lower()}-pool")
+
+         labels = resource.labels
+         if labels is None:
+             labels = {}
+         queue_name: str = labels.get("kueue.x-k8s.io/queue-name", DEFAULT_QUEUE)
+         priority: str = labels.get("kueue.x-k8s.io/priority-class", "low-priority")
+         run_duration: str = labels.get("max-run-duration-seconds", None)
+         # if queue_name != DEFAULT_QUEUE:
+         #     raise ValueError(
+         #         f"{queue_name} queue was selected, "
+         #         f"only {DEFAULT_QUEUE} queue is permitted for hosted Trainy clusters"
+         #     )
+         if priority not in PRIORITY_CLASSES:
+             raise ValueError(
+                 f"priority `{priority}` was selected, "
+                 f"only {PRIORITY_CLASSES} are available"
+             )
+         if task.name is None:
+             raise ValueError("no sky.Task name defined. You must set a task name")
+         if len(task.name) > 59:
+             raise ValueError(f"sky.Task name is {len(task.name)} long. Expected 58 characters or less.")
+         labels.update(
+             {
+                 "kueue.x-k8s.io/queue-name": queue_name,
+                 "kueue.x-k8s.io/priority-class": priority,
+                 "kueue.x-k8s.io/pod-group-name": f"{task.name}-"
+                 f"{uuid.uuid4().hex[:4]}",
+             }
+         )
+         if resource.labels is not None:
+             resource.labels.update(labels)
+         new_resources.append(resource)
+     task.set_resources(type(task.resources)(new_resources))
+
+     # pod annotations
+     config.set_nested(
+         (
+             "kubernetes",
+             "pod_config",
+             "metadata",
+             "annotations",
+             "kueue.x-k8s.io/pod-group-total-count",
+         ),
+         str(task.num_nodes),
+     )
+     config.set_nested(
+         (
+             "kubernetes",
+             "pod_config",
+             "metadata",
+             "annotations",
+             "kueue.x-k8s.io/retriable-in-group",
+         ),
+         "false",
+     )
+
+     maxRunDurationSeconds = (
+         "kubernetes",
+         "pod_config",
+         "metadata",
+         "annotations",
+         "provreq.kueue.x-k8s.io/maxRunDurationSeconds",
+     )
+     if config.get_nested(maxRunDurationSeconds, None) is None:
+         if run_duration is None:
+             raise ValueError("You must specify a label for `max-run-duration-seconds`")
+         # maximum runtime on gke dws is 7 days
+         config.set_nested(maxRunDurationSeconds, str(run_duration))
+
+     run_duration = config.get_nested(maxRunDurationSeconds, None)
+     assert run_duration is not None
+     if not (0 < int(run_duration) <= 3600 * 24 * 7):
+         raise ValueError(
+             f"largest allowed run duration is 7 days "
+             f" = {3600 * 24 * 7} seconds {int(run_duration)} requested "
+             "from either `max-run-duration-seconds` or "
+             "`provreq.kueue.x-k8s.io/maxRunDurationSeconds`"
+         )
+
+     safe_to_evict = (
+         "kubernetes",
+         "pod_config",
+         "metadata",
+         "annotations",
+         "cluster-autoscaler.kubernetes.io/safe-to-evict",
+     )
+     if config.get_nested(safe_to_evict, None) is None:
+         config.set_nested(safe_to_evict, "false")
+
+     return sky.MutatedUserRequest(task=task, skypilot_config=config)
+
+
+ class GKEPolicy(sky.AdminPolicy):
+     """GKE specific configurations."""
+
+     @classmethod
+     def validate_and_mutate(
+         cls, user_request: sky.UserRequest
+     ) -> sky.MutatedUserRequest:
+         """Updates the kubernetes context to use
+         and kueue labels and sets GKE autoscaler
+         """
+         if not user_request.task.is_controller_task():
+             new_request: sky.MutatedUserRequest = set_tcpxo_config(user_request)
+             new_request = validate_set_kueue_dws_labels_annotations(user_request)
+             return sky.MutatedUserRequest(
+                 task=new_request.task, skypilot_config=new_request.skypilot_config
+             )
+         return sky.MutatedUserRequest(
+             task=user_request.task, skypilot_config=user_request.skypilot_config
+         )
+
+
+ def configure_and_get_allowed_contexts():
+     """Mock implementation of getting allowed kubernetes contexts."""
+     from sky.provision.kubernetes import utils
+
+     contexts = utils.get_all_kube_config_context_names()
+     return contexts[:2]
+
+
+ class TailscaleGKEPolicy(GKEPolicy):
+     @classmethod
+     def validate_and_mutate(
+         cls, user_request: sky.UserRequest
+     ) -> sky.MutatedUserRequest:
+         """Updates the kubernetes context to use
+         and kueue labels and sets GKE autoscaler
+         """
+
+         super().validate_and_mutate(user_request=user_request)
+
+         # Append any new kubernetes clusters in local kubeconfig. An example
+         # implementation of this method can be:
+         # 1. Query tailscale for k8s clusters.
+         # 2. Append the new credentials to the local kubeconfig.
+         # 3. Set the allow contexts for the cluster.
+
+         # Get the allowed contexts for the user. Similarly, it can retrieve
+         # the latest allowed contexts from an organization's internal API.
+         # allowed_contexts = configure_and_get_allowed_contexts()
+
+         # # Update the kubernetes allowed contexts in skypilot config.
+         # config = user_request.skypilot_config
+         # config.set_nested(("kubernetes", "allowed_contexts"), allowed_contexts)
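The name-keyed merge introduced by the new helpers works as follows: base entries keep their order, an override entry whose `name` matches a base entry replaces it wholesale, and unmatched override entries are appended at the end. A minimal sketch of what `_merge_named_items` would return (the volume names and values here are hypothetical, not taken from the package's gke.yaml, which is not shown in this diff):

    # Hypothetical inputs to _merge_named_items.
    base = [
        {"name": "dshm", "emptyDir": {"medium": "Memory"}},
        {"name": "workdir", "emptyDir": {}},
    ]
    override = [
        {"name": "dshm", "emptyDir": {"medium": "Memory", "sizeLimit": "64Gi"}},
        {"name": "tcpxo-daemon", "hostPath": {"path": "/var/lib/tcpxo"}},
    ]

    # _merge_named_items(base, override) would return, in base order:
    #   {"name": "dshm", ...}          -- replaced wholesale by the override entry
    #   {"name": "workdir", ...}       -- kept from base
    #   {"name": "tcpxo-daemon", ...}  -- appended, no matching base name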
trainy_policy_nightly-0.1.0.dev20251025104156/trainy/policy/gke.py (deleted)
@@ -1,224 +0,0 @@
- """GKE policy
- Users must use the following `~/.sky/config.yaml`
-
- admin_policy: trainy.policy.GKEPolicy
-
- kubernetes:
-   autoscaler: gke
-   provision_timeout: 600 # should be at least 10 min
-   remote_identity: SERVICE_ACCOUNT
-
- """
-
- import sky
- import uuid
-
- from trainy.config import load_config
- from trainy.logging import get_logger
-
- logger = get_logger(__file__)
-
- DEFAULT_QUEUE = "user-queue"
- PRIORITY_CLASSES = ["low-priority", "high-priority"]
- ALLOWED_GPUS = ["H100", "H100-MEGA-80GB", "A100-80GB", "A100"]
-
-
- def set_tcpxo_config(user_request: sky.UserRequest) -> sky.MutatedUserRequest:
-     """Sets pod specs for running TCPXO"""
-     task = user_request.task
-     config = user_request.skypilot_config
-     for resource in task.resources:
-         for accelerator, count in resource.accelerators.items():
-             if accelerator == "H100-MEGA-80GB":
-                 k8s_override_config = load_config("gke.yaml")
-                 config = user_request.skypilot_config
-                 new_config = sky.skypilot_config._recursive_update(
-                     config, k8s_override_config
-                 )
-                 return sky.MutatedUserRequest(
-                     task=user_request.task, skypilot_config=new_config
-                 )
-     return sky.MutatedUserRequest(task=task, skypilot_config=config)
-
-
- def validate_set_kueue_dws_labels_annotations(
-     user_request: sky.MutatedUserRequest,
- ) -> sky.MutatedUserRequest:
-     """Checks for valid Kueue user-queue and priority queue values
-     populates the required labels and annotations for Kueue
-     """
-     task = user_request.task
-     config = user_request.skypilot_config
-     new_resources = []
-     for resource in task.resources:
-         if resource.cloud is None or str(resource.cloud) != "Kubernetes":
-             raise ValueError("Only `kubernetes` is permitted as a cloud on Trainy")
-         if not resource.accelerators:
-             raise ValueError(
-                 "You must request a GPU instance. Set `accelerators: "
-                 "H100-MEGA-80GB:8` under resources for example"
-             )
-         for accelerator, count in resource.accelerators.items():
-             if accelerator not in ALLOWED_GPUS:
-                 raise ValueError(
-                     f"{resource.accelerators} requested,"
-                     f"only `{ALLOWED_GPUS}` allowed"
-                 )
-             nodepool = (
-                 "kubernetes",
-                 "pod_config",
-                 "spec",
-                 "nodeSelector",
-                 "cloud.google.com/gke-nodepool",
-             )
-             if config.get_nested(nodepool, None) is None:
-                 logger.info(
-                     "`cloud.google.com/gke-nodepool` not set, "
-                     f"setting to default nodepool `{accelerator.lower()}-pool`"
-                 )
-                 config.set_nested(nodepool, f"{accelerator.lower()}-pool")
-
-         labels = resource.labels
-         if labels is None:
-             labels = {}
-         queue_name: str = labels.get("kueue.x-k8s.io/queue-name", DEFAULT_QUEUE)
-         priority: str = labels.get("kueue.x-k8s.io/priority-class", "low-priority")
-         run_duration: str = labels.get("max-run-duration-seconds", None)
-         # if queue_name != DEFAULT_QUEUE:
-         #     raise ValueError(
-         #         f"{queue_name} queue was selected, "
-         #         f"only {DEFAULT_QUEUE} queue is permitted for hosted Trainy clusters"
-         #     )
-         if priority not in PRIORITY_CLASSES:
-             raise ValueError(
-                 f"priority `{priority}` was selected, "
-                 f"only {PRIORITY_CLASSES} are available"
-             )
-         if task.name is None:
-             raise ValueError("no sky.Task name defined. You must set a task name")
-         if len(task.name) > 59:
-             raise ValueError(f"sky.Task name is {len(task.name)} long. Expected 58 characters or less.")
-         labels.update(
-             {
-                 "kueue.x-k8s.io/queue-name": queue_name,
-                 "kueue.x-k8s.io/priority-class": priority,
-                 "kueue.x-k8s.io/pod-group-name": f"{task.name}-"
-                 f"{uuid.uuid4().hex[:4]}",
-             }
-         )
-         if resource.labels is not None:
-             resource.labels.update(labels)
-         new_resources.append(resource)
-     task.set_resources(type(task.resources)(new_resources))
-
-     # pod annotations
-     config.set_nested(
-         (
-             "kubernetes",
-             "pod_config",
-             "metadata",
-             "annotations",
-             "kueue.x-k8s.io/pod-group-total-count",
-         ),
-         str(task.num_nodes),
-     )
-     config.set_nested(
-         (
-             "kubernetes",
-             "pod_config",
-             "metadata",
-             "annotations",
-             "kueue.x-k8s.io/retriable-in-group",
-         ),
-         "false",
-     )
-
-     maxRunDurationSeconds = (
-         "kubernetes",
-         "pod_config",
-         "metadata",
-         "annotations",
-         "provreq.kueue.x-k8s.io/maxRunDurationSeconds",
-     )
-     if config.get_nested(maxRunDurationSeconds, None) is None:
-         if run_duration is None:
-             raise ValueError("You must specify a label for `max-run-duration-seconds`")
-         # maximum runtime on gke dws is 7 days
-         config.set_nested(maxRunDurationSeconds, str(run_duration))
-
-     run_duration = config.get_nested(maxRunDurationSeconds, None)
-     assert run_duration is not None
-     if not (0 < int(run_duration) <= 3600 * 24 * 7):
-         raise ValueError(
-             f"largest allowed run duration is 7 days "
-             f" = {3600 * 24 * 7} seconds {int(run_duration)} requested "
-             "from either `max-run-duration-seconds` or "
-             "`provreq.kueue.x-k8s.io/maxRunDurationSeconds`"
-         )
-
-     safe_to_evict = (
-         "kubernetes",
-         "pod_config",
-         "metadata",
-         "annotations",
-         "cluster-autoscaler.kubernetes.io/safe-to-evict",
-     )
-     if config.get_nested(safe_to_evict, None) is None:
-         config.set_nested(safe_to_evict, "false")
-
-     return sky.MutatedUserRequest(task=task, skypilot_config=config)
-
-
- class GKEPolicy(sky.AdminPolicy):
-     """GKE specific configurations."""
-
-     @classmethod
-     def validate_and_mutate(
-         cls, user_request: sky.UserRequest
-     ) -> sky.MutatedUserRequest:
-         """Updates the kubernetes context to use
-         and kueue labels and sets GKE autoscaler
-         """
-         if not user_request.task.is_controller_task():
-             new_request: sky.MutatedUserRequest = set_tcpxo_config(user_request)
-             new_request = validate_set_kueue_dws_labels_annotations(user_request)
-             return sky.MutatedUserRequest(
-                 task=new_request.task, skypilot_config=new_request.skypilot_config
-             )
-         return sky.MutatedUserRequest(
-             task=user_request.task, skypilot_config=user_request.skypilot_config
-         )
-
-
- def configure_and_get_allowed_contexts():
-     """Mock implementation of getting allowed kubernetes contexts."""
-     from sky.provision.kubernetes import utils
-
-     contexts = utils.get_all_kube_config_context_names()
-     return contexts[:2]
-
-
- class TailscaleGKEPolicy(GKEPolicy):
-     @classmethod
-     def validate_and_mutate(
-         cls, user_request: sky.UserRequest
-     ) -> sky.MutatedUserRequest:
-         """Updates the kubernetes context to use
-         and kueue labels and sets GKE autoscaler
-         """
-
-         super().validate_and_mutate(user_request=user_request)
-
-         # Append any new kubernetes clusters in local kubeconfig. An example
-         # implementation of this method can be:
-         # 1. Query tailscale for k8s clusters.
-         # 2. Append the new credentials to the local kubeconfig.
-         # 3. Set the allow contexts for the cluster.
-
-         # Get the allowed contexts for the user. Similarly, it can retrieve
-         # the latest allowed contexts from an organization's internal API.
-         # allowed_contexts = configure_and_get_allowed_contexts()
-
-         # # Update the kubernetes allowed contexts in skypilot config.
-         # config = user_request.skypilot_config
-         # config.set_nested(("kubernetes", "allowed_contexts"), allowed_contexts)
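Taken together, the checks in validate_set_kueue_dws_labels_annotations (unchanged between the two versions) imply a small set of task labels. A hedged sketch based only on the constants and error paths visible in this diff; the values are illustrative:

    # Labels a sky.Task's resources could carry to pass the policy's checks.
    labels = {
        # optional; defaults to DEFAULT_QUEUE ("user-queue") when unset
        "kueue.x-k8s.io/queue-name": "user-queue",
        # must be one of PRIORITY_CLASSES; defaults to "low-priority"
        "kueue.x-k8s.io/priority-class": "high-priority",
        # required unless provreq.kueue.x-k8s.io/maxRunDurationSeconds is
        # already set as a pod annotation
        "max-run-duration-seconds": str(24 * 3600),
    }

    # the policy enforces 0 < duration <= 7 days
    assert 0 < int(labels["max-run-duration-seconds"]) <= 3600 * 24 * 7

The task must also target the Kubernetes cloud, request one of the GPUs in ALLOWED_GPUS, and set a task name of at most 59 characters, since the pod-group name is derived from it.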