viettelcloud-aiplatform 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- viettelcloud/__init__.py +1 -0
- viettelcloud/aiplatform/__init__.py +15 -0
- viettelcloud/aiplatform/common/__init__.py +0 -0
- viettelcloud/aiplatform/common/constants.py +22 -0
- viettelcloud/aiplatform/common/types.py +28 -0
- viettelcloud/aiplatform/common/utils.py +40 -0
- viettelcloud/aiplatform/hub/OWNERS +14 -0
- viettelcloud/aiplatform/hub/__init__.py +25 -0
- viettelcloud/aiplatform/hub/api/__init__.py +13 -0
- viettelcloud/aiplatform/hub/api/_proxy_client.py +355 -0
- viettelcloud/aiplatform/hub/api/model_registry_client.py +561 -0
- viettelcloud/aiplatform/hub/api/model_registry_client_test.py +462 -0
- viettelcloud/aiplatform/optimizer/__init__.py +45 -0
- viettelcloud/aiplatform/optimizer/api/__init__.py +0 -0
- viettelcloud/aiplatform/optimizer/api/optimizer_client.py +248 -0
- viettelcloud/aiplatform/optimizer/backends/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/backends/base.py +77 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/backend.py +563 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/utils.py +112 -0
- viettelcloud/aiplatform/optimizer/constants/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/constants/constants.py +59 -0
- viettelcloud/aiplatform/optimizer/types/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/types/algorithm_types.py +87 -0
- viettelcloud/aiplatform/optimizer/types/optimization_types.py +135 -0
- viettelcloud/aiplatform/optimizer/types/search_types.py +95 -0
- viettelcloud/aiplatform/py.typed +0 -0
- viettelcloud/aiplatform/trainer/__init__.py +82 -0
- viettelcloud/aiplatform/trainer/api/__init__.py +3 -0
- viettelcloud/aiplatform/trainer/api/trainer_client.py +277 -0
- viettelcloud/aiplatform/trainer/api/trainer_client_test.py +72 -0
- viettelcloud/aiplatform/trainer/backends/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/base.py +94 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/base.py +195 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/docker.py +231 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/podman.py +258 -0
- viettelcloud/aiplatform/trainer/backends/container/backend.py +668 -0
- viettelcloud/aiplatform/trainer/backends/container/backend_test.py +867 -0
- viettelcloud/aiplatform/trainer/backends/container/runtime_loader.py +631 -0
- viettelcloud/aiplatform/trainer/backends/container/runtime_loader_test.py +637 -0
- viettelcloud/aiplatform/trainer/backends/container/types.py +67 -0
- viettelcloud/aiplatform/trainer/backends/container/utils.py +213 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/backend.py +710 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/backend_test.py +1344 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/constants.py +15 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/utils.py +636 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/utils_test.py +582 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/backend.py +306 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/backend_test.py +501 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/constants.py +90 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/job.py +184 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/types.py +52 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/utils.py +302 -0
- viettelcloud/aiplatform/trainer/constants/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/constants/constants.py +179 -0
- viettelcloud/aiplatform/trainer/options/__init__.py +52 -0
- viettelcloud/aiplatform/trainer/options/common.py +55 -0
- viettelcloud/aiplatform/trainer/options/kubernetes.py +502 -0
- viettelcloud/aiplatform/trainer/options/kubernetes_test.py +259 -0
- viettelcloud/aiplatform/trainer/options/localprocess.py +20 -0
- viettelcloud/aiplatform/trainer/test/common.py +22 -0
- viettelcloud/aiplatform/trainer/types/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/types/types.py +517 -0
- viettelcloud/aiplatform/trainer/types/types_test.py +115 -0
- viettelcloud_aiplatform-0.3.0.dist-info/METADATA +226 -0
- viettelcloud_aiplatform-0.3.0.dist-info/RECORD +71 -0
- viettelcloud_aiplatform-0.3.0.dist-info/WHEEL +4 -0
- viettelcloud_aiplatform-0.3.0.dist-info/licenses/LICENSE +201 -0
- viettelcloud_aiplatform-0.3.0.dist-info/licenses/NOTICE +36 -0
|
@@ -0,0 +1,631 @@
|
|
|
1
|
+
# Copyright 2025 The Kubeflow Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
Runtime loader for container backends (Docker, Podman).
|
|
17
|
+
|
|
18
|
+
We support loading training runtime definitions from multiple sources:
|
|
19
|
+
1. GitHub: Fetches latest runtimes from kubeflow/trainer repository (with caching)
|
|
20
|
+
2. Local bundled: Falls back to `kubeflow/trainer/config/training_runtimes/` YAML files
|
|
21
|
+
3. User custom: Additional YAML files in the local directory
|
|
22
|
+
|
|
23
|
+
The loader tries GitHub first (with 24-hour cache), then falls back to bundled files
|
|
24
|
+
if the network is unavailable or GitHub fetch fails.
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
from __future__ import annotations
|
|
28
|
+
|
|
29
|
+
from datetime import datetime, timedelta
|
|
30
|
+
import json
|
|
31
|
+
import logging
|
|
32
|
+
from pathlib import Path
|
|
33
|
+
from typing import Any, Optional
|
|
34
|
+
import urllib.error
|
|
35
|
+
import urllib.request
|
|
36
|
+
|
|
37
|
+
import yaml
|
|
38
|
+
|
|
39
|
+
from viettelcloud.aiplatform.trainer.constants import constants
|
|
40
|
+
from viettelcloud.aiplatform.trainer.types import types as base_types
|
|
41
|
+
|
|
42
|
+
logger = logging.getLogger(__name__)
|
|
43
|
+
|
|
44
|
+
TRAINING_RUNTIMES_DIR = Path(__file__).parents[2] / "config" / "training_runtimes"
|
|
45
|
+
CACHE_DIR = Path.home() / ".kubeflow" / "trainer" / "cache"
|
|
46
|
+
CACHE_DURATION = timedelta(hours=24)
|
|
47
|
+
|
|
48
|
+
# GitHub runtimes configuration
|
|
49
|
+
GITHUB_RUNTIMES_BASE_URL = (
|
|
50
|
+
"https://raw.githubusercontent.com/kubeflow/trainer/master/manifests/base/runtimes"
|
|
51
|
+
)
|
|
52
|
+
GITHUB_RUNTIMES_TREE_URL = "https://github.com/kubeflow/trainer/tree/master/manifests/base/runtimes"
|
|
53
|
+
|
|
54
|
+
__all__ = [
|
|
55
|
+
"TRAINING_RUNTIMES_DIR",
|
|
56
|
+
"get_training_runtime_from_sources",
|
|
57
|
+
"list_training_runtimes_from_sources",
|
|
58
|
+
]
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _load_runtime_from_yaml(path: Path) -> dict[str, Any]:
    """Read a local YAML file and return its parsed contents as a dict."""
    text = path.read_text()
    parsed: dict[str, Any] = yaml.safe_load(text)
    return parsed
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _discover_github_runtime_files(
    owner: str = "kubeflow",
    repo: str = "trainer",
    branch: str = "master",
    path: str = "manifests/base/runtimes",
) -> list[str]:
    """
    Discover available runtime YAML files from a GitHub repository.

    Scrapes the repository's directory-listing page and pulls out candidate
    ``*.yaml`` filenames, dropping known non-runtime files and duplicates.

    Args:
        owner: GitHub repository owner (default: "kubeflow")
        repo: GitHub repository name (default: "trainer")
        branch: Git branch name (default: "master")
        path: Path to runtimes directory (default: "manifests/base/runtimes")

    Returns:
        List of YAML filenames (e.g., ['torch_distributed.yaml', ...]).
        Returns an empty list if discovery fails for any reason.
    """
    tree_url = f"https://github.com/{owner}/{repo}/tree/{branch}/{path}"
    try:
        logger.debug(f"Discovering runtimes from GitHub: {tree_url}")
        with urllib.request.urlopen(tree_url, timeout=5) as response:
            html_content = response.read().decode("utf-8")

        import re

        # Any run of word characters/hyphens ending in ".yaml" counts as a
        # candidate; the HTML repeats names, so dedupe while keeping order.
        candidates = re.findall(r"([\w-]+\.yaml)", html_content)

        # Files that match the pattern but are not runtime definitions.
        excluded = {"kustomization.yaml", "golangci.yaml", "pre-commit-config.yaml"}

        discovered: list[str] = []
        seen: set[str] = set()
        for candidate in candidates:
            if candidate in seen or candidate in excluded:
                continue
            discovered.append(candidate)
            seen.add(candidate)

        logger.debug(f"Discovered {len(discovered)} runtime files: {discovered}")
        return discovered

    except Exception as e:
        # Best-effort discovery: any failure (network, parse) yields [].
        logger.debug(f"Failed to discover GitHub runtime files from {tree_url}: {e}")
        return []
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _fetch_runtime_from_github(
    runtime_file: str,
    owner: str = "kubeflow",
    repo: str = "trainer",
    branch: str = "master",
    path: str = "manifests/base/runtimes",
) -> Optional[dict[str, Any]]:
    """
    Fetch a runtime YAML from GitHub.

    Args:
        runtime_file: YAML filename to fetch
        owner: GitHub repository owner (default: "kubeflow")
        repo: GitHub repository name (default: "trainer")
        branch: Git branch name (default: "master")
        path: Path to runtimes directory (default: "manifests/base/runtimes")

    Returns None if fetch fails (network error, timeout, etc.)
    """
    url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{path}/{runtime_file}"
    try:
        logger.debug(f"Fetching runtime from GitHub: {url}")
        with urllib.request.urlopen(url, timeout=5) as response:
            content = response.read().decode("utf-8")
        data = yaml.safe_load(content)
        logger.debug(f"Successfully fetched {runtime_file} from GitHub")
        return data
    except Exception as e:
        # Previously caught (urllib.error.URLError, TimeoutError, Exception) —
        # a redundant tuple, since Exception subsumes the other two. The fetch
        # is best-effort by design, so a single broad catch is intentional.
        logger.debug(f"Failed to fetch {runtime_file} from GitHub: {e}")
        return None
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def _get_cached_runtime_list() -> Optional[list[str]]:
    """
    Get cached runtime file list if it exists and is not expired.

    Returns None if cache doesn't exist or is expired.
    """
    cache_file = CACHE_DIR / "runtime_list.json"

    # A missing cache directory implies a missing cache file, so a single
    # existence check on the file covers both (the old extra CACHE_DIR check
    # was redundant).
    if not cache_file.exists():
        return None

    try:
        with open(cache_file) as f:
            data = json.load(f)

        cached_time = datetime.fromisoformat(data["cached_at"])
        if datetime.now() - cached_time > CACHE_DURATION:
            logger.debug("Runtime list cache expired")
            return None

        logger.debug(f"Using cached runtime list: {data['files']}")
        return data["files"]
    except Exception as e:
        # Expected failures: corrupt JSON (JSONDecodeError), missing keys
        # (KeyError), bad timestamp (ValueError), or I/O errors. The previous
        # tuple listed those plus Exception, which already subsumes them.
        # The cache is best-effort: any failure means "no usable cache".
        logger.debug(f"Failed to read runtime list cache: {e}")
        return None
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _cache_runtime_list(files: list[str]) -> None:
    """Persist the discovered runtime file list to the on-disk cache (best effort)."""
    try:
        CACHE_DIR.mkdir(parents=True, exist_ok=True)

        payload = {
            "cached_at": datetime.now().isoformat(),
            "files": files,
        }
        with open(CACHE_DIR / "runtime_list.json", "w") as f:
            json.dump(payload, f)

        logger.debug(f"Cached runtime list: {files}")
    except Exception as e:
        # Caching is optional; a failure here must never break discovery.
        logger.debug(f"Failed to cache runtime list: {e}")
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def _get_github_runtime_files() -> list[str]:
    """
    Return the runtime file list from GitHub, consulting the cache first.

    Order of attempts:
    1. Unexpired on-disk cache.
    2. Live discovery against GitHub (result is cached on success).
    3. Empty list when both are unavailable.
    """
    cached = _get_cached_runtime_list()
    if cached is not None:
        return cached

    discovered = _discover_github_runtime_files()
    if not discovered:
        return []

    _cache_runtime_list(discovered)
    return discovered
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def _get_cached_runtime(runtime_file: str) -> Optional[dict[str, Any]]:
    """
    Get cached runtime if it exists and is not expired.

    Args:
        runtime_file: YAML filename whose cached copy should be loaded.

    Returns:
        The cached runtime data, or None if the cache is missing, expired,
        or unreadable.
    """
    cache_file = CACHE_DIR / runtime_file
    metadata_file = CACHE_DIR / f"{runtime_file}.metadata"

    # A missing cache directory implies missing files, so the per-file checks
    # suffice (the old extra CACHE_DIR.exists() check was redundant).
    if not cache_file.exists() or not metadata_file.exists():
        return None

    try:
        # Check if cache is expired
        with open(metadata_file) as f:
            metadata = json.load(f)

        cached_time = datetime.fromisoformat(metadata["cached_at"])
        if datetime.now() - cached_time > CACHE_DURATION:
            logger.debug(f"Cache expired for {runtime_file}")
            return None

        # Load cached runtime
        with open(cache_file) as f:
            data = yaml.safe_load(f)

        logger.debug(f"Using cached runtime: {runtime_file}")
        return data
    except Exception as e:
        # Previously caught (JSONDecodeError, KeyError, ValueError, Exception);
        # Exception already subsumes the rest. Cache reads are best-effort, so
        # a single broad catch keeps the same "no usable cache" semantics.
        logger.debug(f"Failed to read cache for {runtime_file}: {e}")
        return None
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def _cache_runtime(runtime_file: str, data: dict[str, Any]) -> None:
    """Write a runtime YAML plus a timestamp metadata file to the cache (best effort)."""
    try:
        CACHE_DIR.mkdir(parents=True, exist_ok=True)

        # The runtime body and its freshness metadata live side by side.
        with open(CACHE_DIR / runtime_file, "w") as f:
            yaml.safe_dump(data, f)

        meta = {
            "cached_at": datetime.now().isoformat(),
            "source": "github",
        }
        with open(CACHE_DIR / f"{runtime_file}.metadata", "w") as f:
            json.dump(meta, f)

        logger.debug(f"Cached runtime: {runtime_file}")
    except Exception as e:
        # Caching is optional; never let a cache write failure propagate.
        logger.debug(f"Failed to cache {runtime_file}: {e}")
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def _load_runtime_from_github_with_cache(runtime_file: str) -> Optional[dict[str, Any]]:
    """
    Load a runtime definition from GitHub, consulting the on-disk cache first.

    Order of attempts:
    1. Unexpired cache entry.
    2. Live fetch from GitHub (cached on success).
    3. None when both fail.

    Args:
        runtime_file: YAML filename to load
    """
    cached = _get_cached_runtime(runtime_file)
    if cached is not None:
        return cached

    fetched = _fetch_runtime_from_github(runtime_file)
    if fetched is None:
        return None

    _cache_runtime(runtime_file, fetched)
    return fetched
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def _create_default_runtimes() -> list[base_types.Runtime]:
    """
    Build fallback Runtime objects from the DEFAULT_FRAMEWORK_IMAGES constant.

    Each framework in the mapping yields one single-node custom-trainer
    runtime named "<framework>-distributed".

    Returns:
        List of default Runtime objects, one per known framework.
    """
    defaults: list[base_types.Runtime] = []

    for fw, img in constants.DEFAULT_FRAMEWORK_IMAGES.items():
        rt = base_types.Runtime(
            name=f"{fw}-distributed",
            trainer=base_types.RuntimeTrainer(
                trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
                framework=fw,
                num_nodes=1,
                image=img,
            ),
        )
        defaults.append(rt)
        logger.debug(f"Created default runtime: {rt.name} with image {img}")

    return defaults
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
def _parse_runtime_yaml(data: dict[str, Any], source: str = "unknown") -> base_types.Runtime:
    """
    Parse a runtime YAML dict into a Runtime object.

    The document must be shaped like a ClusterTrainingRuntime / TrainingRuntime
    CRD, carry a framework label, and define a 'node' replicated job with a
    container image.

    Args:
        data: The YAML data as a dictionary
        source: Source of the YAML (for error messages)

    Returns:
        Runtime object

    Raises:
        ValueError: If the YAML is malformed or missing required fields
    """
    # Strictly require a CRD-shaped document of either accepted kind.
    kind_ok = data.get("kind") in {"ClusterTrainingRuntime", "TrainingRuntime"}
    if not (kind_ok and data.get("metadata")):
        raise ValueError(
            f"Runtime YAML from {source} must be a ClusterTrainingRuntime CRD-shaped document"
        )

    metadata = data["metadata"]
    name = metadata.get("name")
    if not name:
        raise ValueError(f"Runtime YAML from {source} missing metadata.name")

    framework = metadata.get("labels", {}).get("trainer.kubeflow.org/framework")
    if not framework:
        raise ValueError(
            f"Runtime {name} from {source} must set "
            f"metadata.labels['trainer.kubeflow.org/framework']"
        )

    spec = data.get("spec", {})
    num_nodes = int(spec.get("mlPolicy", {}).get("numNodes", 1))

    # The runtime must declare a 'node' replicated job that carries containers.
    job_spec = spec.get("template", {}).get("spec", {})
    node_jobs = [
        job for job in job_spec.get("replicatedJobs", []) if job.get("name") == "node"
    ]
    if not node_jobs:
        raise ValueError(
            f"Runtime {name} from {source} must define replicatedJobs with a 'node' entry"
        )

    pod_spec = (
        node_jobs[0].get("template", {}).get("spec", {}).get("template", {}).get("spec", {})
    )
    containers = pod_spec.get("containers", [])
    if not containers:
        raise ValueError(f"Runtime {name} from {source} 'node' must specify at least one container")

    # Prefer the image of the container named 'node'; otherwise fall back to
    # the first container that declares any image at all.
    image = next(
        (c.get("image") for c in containers if c.get("name") == "node" and c.get("image")),
        None,
    )
    if not image:
        image = next((c.get("image") for c in containers if c.get("image")), None)
    if not image:
        raise ValueError(
            f"Runtime {name} from {source} 'node' must specify an image in at least one container"
        )

    return base_types.Runtime(
        name=name,
        trainer=base_types.RuntimeTrainer(
            trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
            framework=framework,
            num_nodes=num_nodes,
            image=image,
        ),
    )
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
def _parse_source_url(source: str) -> tuple[str, str]:
|
|
422
|
+
"""
|
|
423
|
+
Parse a source URL to determine its type and path.
|
|
424
|
+
|
|
425
|
+
Args:
|
|
426
|
+
source: Source URL with scheme (github://, https://, file://, or absolute path)
|
|
427
|
+
|
|
428
|
+
Returns:
|
|
429
|
+
Tuple of (source_type, path) where source_type is one of:
|
|
430
|
+
'github', 'http', 'https', 'file'
|
|
431
|
+
"""
|
|
432
|
+
if source.startswith("github://"):
|
|
433
|
+
return ("github", source[9:]) # Remove 'github://'
|
|
434
|
+
elif source.startswith("https://"):
|
|
435
|
+
return ("https", source)
|
|
436
|
+
elif source.startswith("http://"):
|
|
437
|
+
return ("http", source)
|
|
438
|
+
elif source.startswith("file://"):
|
|
439
|
+
return ("file", source[7:]) # Remove 'file://'
|
|
440
|
+
elif source.startswith("/"):
|
|
441
|
+
# Absolute path without file:// prefix
|
|
442
|
+
return ("file", source)
|
|
443
|
+
else:
|
|
444
|
+
raise ValueError(
|
|
445
|
+
f"Unsupported source URL scheme: {source}. "
|
|
446
|
+
f"Supported: github://, https://, http://, file://, or absolute paths"
|
|
447
|
+
)
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
def _load_from_github_url(github_path: str) -> list[base_types.Runtime]:
    """
    Load runtimes from a GitHub URL (github://owner/repo[/path]).

    Args:
        github_path: Path after github:// (e.g., "kubeflow/trainer" or "myorg/myrepo")

    Returns:
        List of Runtime objects loaded from GitHub
    """
    loaded: list[base_types.Runtime] = []
    seen_names: set[str] = set()

    # The path takes the form owner/repo[/path/to/runtimes].
    segments = github_path.split("/")
    if len(segments) < 2:
        logger.warning(f"Invalid GitHub path format: {github_path}. Expected owner/repo[/path]")
        return loaded

    owner, repo = segments[0], segments[1]
    # Anything after owner/repo overrides the default runtimes directory.
    runtimes_dir = "/".join(segments[2:]) if len(segments) > 2 else "manifests/base/runtimes"

    logger.debug(f"Loading runtimes from GitHub: {owner}/{repo}/{runtimes_dir}")
    discovered = _discover_github_runtime_files(owner=owner, repo=repo, path=runtimes_dir)

    for runtime_file in discovered:
        try:
            data = _fetch_runtime_from_github(
                runtime_file, owner=owner, repo=repo, path=runtimes_dir
            )
            if data is None:
                continue
            runtime = _parse_runtime_yaml(data, source=f"github://{github_path}/{runtime_file}")
            if runtime.name in seen_names:
                continue
            loaded.append(runtime)
            seen_names.add(runtime.name)
            logger.debug(f"Loaded runtime from GitHub: {runtime.name}")
        except Exception as e:
            # Skip unparseable runtimes but keep processing the rest.
            logger.debug(f"Failed to parse GitHub runtime {runtime_file}: {e}")

    return loaded
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def _load_from_http_url(url: str) -> list[base_types.Runtime]:
    """
    Load runtimes from an HTTP(S) URL.

    Args:
        url: HTTP(S) URL to a runtime YAML file or directory listing

    Returns:
        List of Runtime objects loaded from HTTP(S); empty on any failure
        (network error, timeout, malformed YAML).
    """
    runtimes: list[base_types.Runtime] = []

    try:
        # NOTE: the previous function-local `import urllib.request` and
        # `import yaml` were redundant — both are imported at module level.
        logger.debug(f"Fetching runtime from HTTP: {url}")
        with urllib.request.urlopen(url, timeout=5) as response:
            content = response.read().decode("utf-8")

        data = yaml.safe_load(content)
        runtime = _parse_runtime_yaml(data, source=url)
        runtimes.append(runtime)
        logger.debug(f"Loaded runtime from HTTP: {runtime.name}")
    except Exception as e:
        # Best-effort source: any failure yields an empty result.
        logger.debug(f"Failed to load runtime from HTTP {url}: {e}")

    return runtimes
|
|
524
|
+
|
|
525
|
+
|
|
526
|
+
def _load_from_filesystem(path: str) -> list[base_types.Runtime]:
    """
    Load runtimes from local filesystem path.

    Args:
        path: Local filesystem path to a directory or YAML file

    Returns:
        List of Runtime objects loaded from filesystem; files that fail to
        parse are skipped with a warning.
    """
    # NOTE: the previous function-local `from pathlib import Path` was
    # redundant — Path is imported at module level.
    runtimes: list[base_types.Runtime] = []
    runtime_path = Path(path).expanduser()

    try:
        if runtime_path.is_dir():
            # Load all YAML files from the directory, in deterministic order.
            for yaml_file in sorted(runtime_path.glob("*.yaml")):
                try:
                    data = _load_runtime_from_yaml(yaml_file)
                    runtime = _parse_runtime_yaml(data, source=str(yaml_file))
                    runtimes.append(runtime)
                    logger.debug(f"Loaded runtime from file: {runtime.name}")
                except Exception as e:
                    logger.warning(f"Failed to load runtime from {yaml_file}: {e}")
        elif runtime_path.is_file():
            # Load a single YAML file.
            data = _load_runtime_from_yaml(runtime_path)
            runtime = _parse_runtime_yaml(data, source=str(runtime_path))
            runtimes.append(runtime)
            logger.debug(f"Loaded runtime from file: {runtime.name}")
        else:
            logger.warning(f"Path does not exist: {runtime_path}")
    except Exception as e:
        logger.warning(f"Failed to load runtimes from {path}: {e}")

    return runtimes
|
|
564
|
+
|
|
565
|
+
|
|
566
|
+
def list_training_runtimes_from_sources(sources: list[str]) -> list[base_types.Runtime]:
    """
    List all available training runtimes from configured sources.

    Sources are consulted in order; the first runtime seen under a given name
    wins. Built-in defaults are appended for any name no source supplied.

    Args:
        sources: List of source URLs with schemes (github://, https://, http://, file://, or paths)

    Returns:
        List of Runtime objects (built-in runtimes used as default if not found in sources)
    """
    collected: list[base_types.Runtime] = []
    names_seen: set[str] = set()

    def _add_unique(candidates: list[base_types.Runtime]) -> None:
        # Append candidates whose name has not been collected yet.
        for candidate in candidates:
            if candidate.name not in names_seen:
                collected.append(candidate)
                names_seen.add(candidate.name)

    for source in sources:
        try:
            source_type, source_path = _parse_source_url(source)

            if source_type == "github":
                _add_unique(_load_from_github_url(source_path))
            elif source_type in ("http", "https"):
                _add_unique(_load_from_http_url(source))
            elif source_type == "file":
                _add_unique(_load_from_filesystem(source_path))
            else:
                logger.warning(f"Unsupported source type: {source_type}")
        except Exception as e:
            # A broken source must not prevent the remaining sources from loading.
            logger.debug(f"Failed to load from source {source}: {e}")

    # Fill in built-in defaults for any framework not provided by a source.
    _add_unique(_create_default_runtimes())

    return collected
|
|
609
|
+
|
|
610
|
+
|
|
611
|
+
def get_training_runtime_from_sources(name: str, sources: list[str]) -> base_types.Runtime:
    """
    Get a specific training runtime by name from configured sources.

    Args:
        name: The name of the runtime to get
        sources: List of source URLs with schemes

    Returns:
        Runtime object

    Raises:
        ValueError: If the runtime is not found
    """
    # Load the runtime list exactly once: the previous implementation called
    # list_training_runtimes_from_sources a second time just to build the
    # error message, repeating all network/filesystem work on the miss path.
    runtimes = list_training_runtimes_from_sources(sources)
    for rt in runtimes:
        if rt.name == name:
            return rt
    raise ValueError(
        f"Runtime '{name}' not found. Available runtimes: "
        f"{[rt.name for rt in runtimes]}"
    )
|