trainy-policy-nightly 0.1.0.dev20251111104744__py3-none-any.whl → 0.1.0.dev20251121104539__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
trainy/__init__.py CHANGED
@@ -4,7 +4,7 @@ import os
4
4
  import subprocess
5
5
 
6
6
  # Replaced with the current commit when building the wheels.
7
- _POLICY_COMMIT_SHA = "14ff34c33b11d38c00a264197f0b885f67a65c45"
7
+ _POLICY_COMMIT_SHA = "a7ce34b2b1ed91770b4485278b256f3dbc7d609c"
8
8
 
9
9
 
10
10
  def _get_git_commit():
trainy/policy/gke.py CHANGED
@@ -10,6 +10,9 @@ kubernetes:
10
10
 
11
11
  """
12
12
 
13
+ from copy import deepcopy
14
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
15
+
13
16
  import sky
14
17
  import uuid
15
18
 
@@ -22,6 +25,252 @@ DEFAULT_QUEUE = "user-queue"
22
25
  PRIORITY_CLASSES = ["low-priority", "high-priority"]
23
26
  ALLOWED_GPUS = ["H100", "H100-MEGA-80GB", "A100-80GB", "A100"]
24
27
 
28
+ _POD_SPEC_PATH: Tuple[str, ...] = ("kubernetes", "pod_config", "spec")
29
+
30
+
31
+ def _get_dict_path(
32
+ data: Dict[str, Any], path: Sequence[str]
33
+ ) -> Optional[Dict[str, Any]]:
34
+ """Returns the nested dict at path, creating dicts along the way."""
35
+ cursor: Dict[str, Any] = data
36
+ for key in path:
37
+ if key not in cursor or not isinstance(cursor[key], dict):
38
+ cursor[key] = {}
39
+ cursor = cursor[key]
40
+ return cursor
41
+
42
+
43
+ def _merge_named_items(
44
+ base_list: Optional[List[Dict[str, Any]]],
45
+ override_list: Optional[List[Dict[str, Any]]],
46
+ ) -> List[Dict[str, Any]]:
47
+ """Merges two lists of dicts keyed by `name` while preserving order."""
48
+ merged: List[Dict[str, Any]] = []
49
+ index_by_name: Dict[str, int] = {}
50
+
51
+ def _append(item: Dict[str, Any]) -> None:
52
+ index_by_name[item.get("name", f"__idx_{len(merged)}")] = len(merged)
53
+ merged.append(item)
54
+
55
+ for entry in base_list or []:
56
+ if isinstance(entry, dict):
57
+ _append(deepcopy(entry))
58
+
59
+ for entry in override_list or []:
60
+ if not isinstance(entry, dict):
61
+ continue
62
+ name = entry.get("name")
63
+ replacement = deepcopy(entry)
64
+ if name is not None and name in index_by_name:
65
+ merged[index_by_name[name]] = replacement
66
+ else:
67
+ _append(replacement)
68
+
69
+ return merged
70
+
71
+
72
+ def _merge_container_sections(
73
+ base_containers: Optional[List[Dict[str, Any]]],
74
+ override_containers: Optional[List[Dict[str, Any]]],
75
+ ) -> List[Dict[str, Any]]:
76
+ """Merges env/volumeMounts between lists of containers."""
77
+ base_containers = base_containers or []
78
+ override_containers = override_containers or []
79
+ merged_containers: List[Dict[str, Any]] = []
80
+ max_len = max(len(base_containers), len(override_containers))
81
+
82
+ for idx in range(max_len):
83
+ base_container = base_containers[idx] if idx < len(base_containers) else None
84
+ override_container = (
85
+ deepcopy(override_containers[idx]) if idx < len(override_containers) else None
86
+ )
87
+
88
+ if override_container is None:
89
+ if base_container is not None:
90
+ merged_containers.append(deepcopy(base_container))
91
+ continue
92
+
93
+ merged_env = _merge_named_items(
94
+ base_container.get("env") if base_container else None,
95
+ override_container.get("env"),
96
+ )
97
+ if merged_env:
98
+ override_container["env"] = merged_env
99
+ elif "env" in override_container:
100
+ del override_container["env"]
101
+
102
+ merged_volume_mounts = _merge_named_items(
103
+ base_container.get("volumeMounts") if base_container else None,
104
+ override_container.get("volumeMounts"),
105
+ )
106
+ if merged_volume_mounts:
107
+ override_container["volumeMounts"] = merged_volume_mounts
108
+ elif "volumeMounts" in override_container:
109
+ del override_container["volumeMounts"]
110
+
111
+ merged_containers.append(override_container)
112
+
113
+ return merged_containers
114
+
115
+
116
+ def _merge_pod_spec_sections(
117
+ config: sky.skypilot_config.Config, override_config: Dict[str, Any]
118
+ ) -> Dict[str, Any]:
119
+ """Returns override_config with volumes/env/volumeMounts merged from config."""
120
+ merged_override = deepcopy(override_config)
121
+ spec = _get_dict_path(merged_override, _POD_SPEC_PATH)
122
+
123
+ base_volumes = config.get_nested(_POD_SPEC_PATH + ("volumes",), None)
124
+ override_volumes = spec.get("volumes")
125
+ merged_volumes = _merge_named_items(base_volumes, override_volumes)
126
+ if merged_volumes:
127
+ spec["volumes"] = merged_volumes
128
+ elif "volumes" in spec:
129
+ del spec["volumes"]
130
+
131
+ merged_containers = _merge_container_sections(
132
+ config.get_nested(_POD_SPEC_PATH + ("containers",), None),
133
+ spec.get("containers"),
134
+ )
135
+ if merged_containers:
136
+ spec["containers"] = merged_containers
137
+ elif "containers" in spec:
138
+ del spec["containers"]
139
+
140
+ merged_init_containers = _merge_container_sections(
141
+ config.get_nested(_POD_SPEC_PATH + ("initContainers",), None),
142
+ spec.get("initContainers"),
143
+ )
144
+ if merged_init_containers:
145
+ spec["initContainers"] = merged_init_containers
146
+ elif "initContainers" in spec:
147
+ del spec["initContainers"]
148
+
149
+ return merged_override
150
+
151
+ _POD_SPEC_PATH: Tuple[str, ...] = ("kubernetes", "pod_config", "spec")
152
+
153
+
154
+ def _get_dict_path(
155
+ data: Dict[str, Any], path: Sequence[str]
156
+ ) -> Optional[Dict[str, Any]]:
157
+ """Returns the nested dict at path, creating dicts along the way."""
158
+ cursor: Dict[str, Any] = data
159
+ for key in path:
160
+ if key not in cursor or not isinstance(cursor[key], dict):
161
+ cursor[key] = {}
162
+ cursor = cursor[key]
163
+ return cursor
164
+
165
+
166
+ def _merge_named_items(
167
+ base_list: Optional[List[Dict[str, Any]]],
168
+ override_list: Optional[List[Dict[str, Any]]],
169
+ ) -> List[Dict[str, Any]]:
170
+ """Merges two lists of dicts keyed by `name` while preserving order."""
171
+ merged: List[Dict[str, Any]] = []
172
+ index_by_name: Dict[str, int] = {}
173
+
174
+ def _append(item: Dict[str, Any]) -> None:
175
+ index_by_name[item.get("name", f"__idx_{len(merged)}")] = len(merged)
176
+ merged.append(item)
177
+
178
+ for entry in base_list or []:
179
+ if isinstance(entry, dict):
180
+ _append(deepcopy(entry))
181
+
182
+ for entry in override_list or []:
183
+ if not isinstance(entry, dict):
184
+ continue
185
+ name = entry.get("name")
186
+ replacement = deepcopy(entry)
187
+ if name is not None and name in index_by_name:
188
+ merged[index_by_name[name]] = replacement
189
+ else:
190
+ _append(replacement)
191
+
192
+ return merged
193
+
194
+
195
+ def _merge_container_sections(
196
+ base_containers: Optional[List[Dict[str, Any]]],
197
+ override_containers: Optional[List[Dict[str, Any]]],
198
+ ) -> List[Dict[str, Any]]:
199
+ """Merges env/volumeMounts between lists of containers."""
200
+ base_containers = base_containers or []
201
+ override_containers = override_containers or []
202
+ merged_containers: List[Dict[str, Any]] = []
203
+ max_len = max(len(base_containers), len(override_containers))
204
+
205
+ for idx in range(max_len):
206
+ base_container = base_containers[idx] if idx < len(base_containers) else None
207
+ override_container = (
208
+ deepcopy(override_containers[idx]) if idx < len(override_containers) else None
209
+ )
210
+
211
+ if override_container is None:
212
+ if base_container is not None:
213
+ merged_containers.append(deepcopy(base_container))
214
+ continue
215
+
216
+ merged_env = _merge_named_items(
217
+ base_container.get("env") if base_container else None,
218
+ override_container.get("env"),
219
+ )
220
+ if merged_env:
221
+ override_container["env"] = merged_env
222
+ elif "env" in override_container:
223
+ del override_container["env"]
224
+
225
+ merged_volume_mounts = _merge_named_items(
226
+ base_container.get("volumeMounts") if base_container else None,
227
+ override_container.get("volumeMounts"),
228
+ )
229
+ if merged_volume_mounts:
230
+ override_container["volumeMounts"] = merged_volume_mounts
231
+ elif "volumeMounts" in override_container:
232
+ del override_container["volumeMounts"]
233
+
234
+ merged_containers.append(override_container)
235
+
236
+ return merged_containers
237
+
238
+
239
+ def _merge_pod_spec_sections(
240
+ config: sky.skypilot_config.Config, override_config: Dict[str, Any]
241
+ ) -> Dict[str, Any]:
242
+ """Returns override_config with volumes/env/volumeMounts merged from config."""
243
+ merged_override = deepcopy(override_config)
244
+ spec = _get_dict_path(merged_override, _POD_SPEC_PATH)
245
+
246
+ base_volumes = config.get_nested(_POD_SPEC_PATH + ("volumes",), None)
247
+ override_volumes = spec.get("volumes")
248
+ merged_volumes = _merge_named_items(base_volumes, override_volumes)
249
+ if merged_volumes:
250
+ spec["volumes"] = merged_volumes
251
+ elif "volumes" in spec:
252
+ del spec["volumes"]
253
+
254
+ merged_containers = _merge_container_sections(
255
+ config.get_nested(_POD_SPEC_PATH + ("containers",), None),
256
+ spec.get("containers"),
257
+ )
258
+ if merged_containers:
259
+ spec["containers"] = merged_containers
260
+ elif "containers" in spec:
261
+ del spec["containers"]
262
+
263
+ merged_init_containers = _merge_container_sections(
264
+ config.get_nested(_POD_SPEC_PATH + ("initContainers",), None),
265
+ spec.get("initContainers"),
266
+ )
267
+ if merged_init_containers:
268
+ spec["initContainers"] = merged_init_containers
269
+ elif "initContainers" in spec:
270
+ del spec["initContainers"]
271
+
272
+ return merged_override
273
+
25
274
 
26
275
  def set_tcpxo_config(user_request: sky.UserRequest) -> sky.MutatedUserRequest:
27
276
  """Sets pod specs for running TCPXO"""
@@ -32,9 +281,13 @@ def set_tcpxo_config(user_request: sky.UserRequest) -> sky.MutatedUserRequest:
32
281
  if accelerator == "H100-MEGA-80GB":
33
282
  k8s_override_config = load_config("gke.yaml")
34
283
  config = user_request.skypilot_config
35
- new_config = sky.skypilot_config._recursive_update(
284
+ merged_override = _merge_pod_spec_sections(
36
285
  config, k8s_override_config
37
286
  )
287
+ new_config = sky.skypilot_config._recursive_update(
288
+ config, merged_override
289
+ )
290
+
38
291
  return sky.MutatedUserRequest(
39
292
  task=user_request.task, skypilot_config=new_config
40
293
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: trainy-policy-nightly
3
- Version: 0.1.0.dev20251111104744
3
+ Version: 0.1.0.dev20251121104539
4
4
  Summary: Trainy Skypilot Policy
5
5
  Author: Andrew Aikawa
6
6
  Author-email: asai@berkeley.edu
@@ -1,4 +1,4 @@
1
- trainy/__init__.py,sha256=JOzvZYiDSOsMOMFV3WUJWII72wYWqtturfqtMNJJTUA,1236
1
+ trainy/__init__.py,sha256=-YRjM9Yuq-e3fxZkRsqp_7BgvmNxsdBgmg7MZlhYJzk,1236
2
2
  trainy/config.py,sha256=MkAVzIvbS61y73uneWzZYS_FikGYZHI__DuPmRJ1lH0,322
3
3
  trainy/configs/eks.yaml,sha256=QlUFr3w7ZcTU0c0jqoGxUlWN1nVl5yE_b5IWd0ToDgs,554
4
4
  trainy/configs/gke.yaml,sha256=aB52IPKZhvb2ZoNn32RMkP299z3PoW2CVE-mHA5j-oE,4972
@@ -6,8 +6,8 @@ trainy/configs/nebius.yaml,sha256=hH0W_EKIYe5EaRC9pkUHWX3P7hbM46tg0OLrc6QatJw,44
6
6
  trainy/logging.py,sha256=CMHdnnZ2b-rNDTfU7bi6jQ3AmxVv0M4sOlghnVJIqko,1186
7
7
  trainy/policy/__init__.py,sha256=DElGvZKzOqFbUs90Pi53ZpNBE3BFsalgpYMos_4tboU,180
8
8
  trainy/policy/eks.py,sha256=FBgRprnSjmp94xQvCV0tnKTgxKBk2SPNIVNFg6PPrZI,2493
9
- trainy/policy/gke.py,sha256=ClkVT6W_cWosiW1dhnJCfun_euE0YSy8B5UA-u-7f9I,8173
9
+ trainy/policy/gke.py,sha256=1OLEt9g9vrt3bZOHuA3v86xzgBFsxh5yHP3InKW4mh0,16967
10
10
  trainy/policy/nebius.py,sha256=79zYL0Z6JsOomB0IF3CimSnkZfyyPvNYaJNQb0iv_Uw,3062
11
- trainy_policy_nightly-0.1.0.dev20251111104744.dist-info/METADATA,sha256=50fw9N5uKazKy_M0bs4Bv40sm_NZy8gQW1W9edjrN7E,1228
12
- trainy_policy_nightly-0.1.0.dev20251111104744.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
13
- trainy_policy_nightly-0.1.0.dev20251111104744.dist-info/RECORD,,
11
+ trainy_policy_nightly-0.1.0.dev20251121104539.dist-info/METADATA,sha256=InPLPXltF-kjnkPFaS28QkKHpT1uwhUiw8Vhzm5DveM,1228
12
+ trainy_policy_nightly-0.1.0.dev20251121104539.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
13
+ trainy_policy_nightly-0.1.0.dev20251121104539.dist-info/RECORD,,