viettelcloud_aiplatform-0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- viettelcloud/__init__.py +1 -0
- viettelcloud/aiplatform/__init__.py +15 -0
- viettelcloud/aiplatform/common/__init__.py +0 -0
- viettelcloud/aiplatform/common/constants.py +22 -0
- viettelcloud/aiplatform/common/types.py +28 -0
- viettelcloud/aiplatform/common/utils.py +40 -0
- viettelcloud/aiplatform/hub/OWNERS +14 -0
- viettelcloud/aiplatform/hub/__init__.py +25 -0
- viettelcloud/aiplatform/hub/api/__init__.py +13 -0
- viettelcloud/aiplatform/hub/api/_proxy_client.py +355 -0
- viettelcloud/aiplatform/hub/api/model_registry_client.py +561 -0
- viettelcloud/aiplatform/hub/api/model_registry_client_test.py +462 -0
- viettelcloud/aiplatform/optimizer/__init__.py +45 -0
- viettelcloud/aiplatform/optimizer/api/__init__.py +0 -0
- viettelcloud/aiplatform/optimizer/api/optimizer_client.py +248 -0
- viettelcloud/aiplatform/optimizer/backends/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/backends/base.py +77 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/backend.py +563 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/utils.py +112 -0
- viettelcloud/aiplatform/optimizer/constants/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/constants/constants.py +59 -0
- viettelcloud/aiplatform/optimizer/types/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/types/algorithm_types.py +87 -0
- viettelcloud/aiplatform/optimizer/types/optimization_types.py +135 -0
- viettelcloud/aiplatform/optimizer/types/search_types.py +95 -0
- viettelcloud/aiplatform/py.typed +0 -0
- viettelcloud/aiplatform/trainer/__init__.py +82 -0
- viettelcloud/aiplatform/trainer/api/__init__.py +3 -0
- viettelcloud/aiplatform/trainer/api/trainer_client.py +277 -0
- viettelcloud/aiplatform/trainer/api/trainer_client_test.py +72 -0
- viettelcloud/aiplatform/trainer/backends/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/base.py +94 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/base.py +195 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/docker.py +231 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/podman.py +258 -0
- viettelcloud/aiplatform/trainer/backends/container/backend.py +668 -0
- viettelcloud/aiplatform/trainer/backends/container/backend_test.py +867 -0
- viettelcloud/aiplatform/trainer/backends/container/runtime_loader.py +631 -0
- viettelcloud/aiplatform/trainer/backends/container/runtime_loader_test.py +637 -0
- viettelcloud/aiplatform/trainer/backends/container/types.py +67 -0
- viettelcloud/aiplatform/trainer/backends/container/utils.py +213 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/backend.py +710 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/backend_test.py +1344 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/constants.py +15 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/utils.py +636 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/utils_test.py +582 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/backend.py +306 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/backend_test.py +501 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/constants.py +90 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/job.py +184 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/types.py +52 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/utils.py +302 -0
- viettelcloud/aiplatform/trainer/constants/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/constants/constants.py +179 -0
- viettelcloud/aiplatform/trainer/options/__init__.py +52 -0
- viettelcloud/aiplatform/trainer/options/common.py +55 -0
- viettelcloud/aiplatform/trainer/options/kubernetes.py +502 -0
- viettelcloud/aiplatform/trainer/options/kubernetes_test.py +259 -0
- viettelcloud/aiplatform/trainer/options/localprocess.py +20 -0
- viettelcloud/aiplatform/trainer/test/common.py +22 -0
- viettelcloud/aiplatform/trainer/types/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/types/types.py +517 -0
- viettelcloud/aiplatform/trainer/types/types_test.py +115 -0
- viettelcloud_aiplatform-0.3.0.dist-info/METADATA +226 -0
- viettelcloud_aiplatform-0.3.0.dist-info/RECORD +71 -0
- viettelcloud_aiplatform-0.3.0.dist-info/WHEEL +4 -0
- viettelcloud_aiplatform-0.3.0.dist-info/licenses/LICENSE +201 -0
- viettelcloud_aiplatform-0.3.0.dist-info/licenses/NOTICE +36 -0

viettelcloud/aiplatform/trainer/backends/localprocess/job.py
@@ -0,0 +1,184 @@
# Copyright 2025 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import datetime
import logging
import os
import subprocess
import threading
from typing import Optional, Union

from viettelcloud.aiplatform.trainer.constants import constants

logger = logging.getLogger(__name__)


class LocalJob(threading.Thread):
    def __init__(
        self,
        name,
        command: Union[list, tuple[str], str],
        execution_dir: Optional[str] = None,
        env: Optional[dict[str, str]] = None,
        dependencies: Optional[list] = None,
    ):
        """Creates a LocalJob.

        Runs a local subprocess on a background thread so users can create background jobs.

        Args:
            name (str): The name of the job.
            command (Union[list, tuple, str]): The command to run.
            execution_dir (str, optional): The execution directory. Defaults to the current
                working directory.
            env (Dict[str, str], optional): Environment variables. Defaults to None.
            dependencies (List[LocalJob], optional): Jobs that must finish successfully
                before this one runs. Defaults to None.
        """
        super().__init__()
        self.name = name
        self.command = command
        self._stdout = ""
        self._returncode = None
        self._success = False
        self._status = constants.TRAINJOB_CREATED
        self._lock = threading.Lock()
        self._process = None
        self._output_updated = threading.Event()
        self._cancel_requested = threading.Event()
        self._start_time = None
        self._end_time = None
        self.env = env or {}
        self.dependencies = dependencies or []
        self.execution_dir = execution_dir or os.getcwd()

    def run(self):
        # Wait for dependency jobs; skip this job if any dependency failed.
        for dep in self.dependencies:
            dep.join()
            if not dep.success:
                with self._lock:
                    self._stdout = f"Dependency {dep.name} failed. Skipping"
                return

        current_dir = os.getcwd()
        try:
            self._start_time = datetime.now()
            _c = " ".join(self.command)
            logger.debug(f"[{self.name}] Started at {self._start_time} with command: \n {_c}")

            # Change the working directory to the venv before executing the script.
            os.chdir(self.execution_dir)

            self._process = subprocess.Popen(
                self.command,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                encoding="utf-8",
                bufsize=1,
                env=self.env,
            )
            # Set job status.
            self._status = constants.TRAINJOB_RUNNING

            while True:
                if self._cancel_requested.is_set():
                    self._process.terminate()
                    self._stdout += "[JobCancelled]\n"
                    self._status = constants.TRAINJOB_FAILED
                    self._success = False
                    return

                # Read output line by line (for streaming).
                output_line = self._process.stdout.readline()
                with self._lock:
                    if output_line:
                        self._stdout += output_line
                        self._output_updated.set()

                if not output_line and self._process.poll() is not None:
                    break

            self._process.stdout.close()
            self._returncode = self._process.wait()
            self._end_time = datetime.now()
            self._success = self._process.returncode == 0
            msg = (
                f"[{self.name}] Completed with code {self._returncode}"
                f" in {(self._end_time - self._start_time).total_seconds():.2f} seconds."
            )
            # Set status based on success or failure.
            self._status = (
                constants.TRAINJOB_COMPLETE if self._success else constants.TRAINJOB_FAILED
            )
            self._stdout += msg
            logger.debug("Job output: %s", self._stdout)

        except Exception as e:
            with self._lock:
                self._stdout += f"Exception: {e}\n"
                self._success = False
                self._status = constants.TRAINJOB_FAILED
        finally:
            os.chdir(current_dir)

    @property
    def stdout(self):
        with self._lock:
            return self._stdout

    @property
    def success(self):
        return self._success

    @property
    def status(self):
        return self._status

    def cancel(self):
        self._cancel_requested.set()

    @property
    def returncode(self):
        return self._returncode

    def logs(self, follow=False) -> list[str]:
        if not follow:
            return self._stdout.splitlines()

        try:
            for chunk in self.stream_logs():
                print(chunk, end="", flush=True)  # stream to console live
        except StopIteration:
            pass

        return self._stdout.splitlines()

    def stream_logs(self):
        """Generator that yields new output lines as they come in."""
        last_index = 0
        while self.is_alive() or last_index < len(self._stdout):
            self._output_updated.wait(timeout=1)
            with self._lock:
                data = self._stdout
                new_data = data[last_index:]
                last_index = len(data)
                self._output_updated.clear()
            if new_data:
                yield new_data

    @property
    def creation_time(self):
        return self._start_time

    @property
    def completion_time(self):
        return self._end_time
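
For orientation, a minimal usage sketch of LocalJob follows. It is illustrative only and not part of the packaged module; the job names and commands are invented, and the current environment is passed explicitly because LocalJob defaults to an empty env.

# Illustrative sketch only -- not shipped in the wheel.
import os
import sys

from viettelcloud.aiplatform.trainer.backends.localprocess.job import LocalJob

# A first job; sys.executable is used so no PATH lookup is needed.
prepare = LocalJob(
    name="prepare",
    command=[sys.executable, "-c", "print('preparing data')"],
    env=dict(os.environ),
)

# A second job that waits on `prepare` (via dependencies) and is skipped if it fails.
train = LocalJob(
    name="train",
    command=[sys.executable, "-c", "print('training step')"],
    env=dict(os.environ),
    dependencies=[prepare],
)

prepare.start()
train.start()

train.logs(follow=True)  # streams output live until the job finishes
train.join()
print(train.status, train.success, train.returncode)

Note that a dependency job must be started before its dependent reaches dep.join(), since joining a thread that was never started raises RuntimeError.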

viettelcloud/aiplatform/trainer/backends/localprocess/types.py
@@ -0,0 +1,52 @@
# Copyright 2025 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass, field
from datetime import datetime
import typing
from typing import Optional

from pydantic import BaseModel

from viettelcloud.aiplatform.trainer.backends.localprocess.job import LocalJob
from viettelcloud.aiplatform.trainer.types import types


class LocalProcessBackendConfig(BaseModel):
    cleanup_venv: bool = True


@dataclass
class LocalRuntimeTrainer(types.RuntimeTrainer):
    packages: list[str] = field(default_factory=list)


class LocalBackendStep(BaseModel):
    step_name: str
    job: LocalJob

    class Config:
        arbitrary_types_allowed = True


class LocalBackendJobs(BaseModel):
    steps: Optional[list[LocalBackendStep]] = []
    runtime: Optional[types.Runtime] = None
    name: str
    created: typing.Optional[datetime] = None
    completed: typing.Optional[datetime] = None

    class Config:
        arbitrary_types_allowed = True
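
A minimal sketch of how these models might be populated (illustrative only; the step, job, and TrainJob names are invented). The arbitrary_types_allowed config is what lets the pydantic models carry plain LocalJob thread objects:

# Illustrative sketch only -- not shipped in the wheel.
import sys
from datetime import datetime

from viettelcloud.aiplatform.trainer.backends.localprocess.job import LocalJob
from viettelcloud.aiplatform.trainer.backends.localprocess.types import (
    LocalBackendJobs,
    LocalBackendStep,
)

# Wrap a LocalJob in a step so it can be tracked with the rest of the TrainJob metadata.
job = LocalJob(name="node-0", command=[sys.executable, "-c", "print('hello')"])
step = LocalBackendStep(step_name="node-0", job=job)

record = LocalBackendJobs(
    name="my-train-job",
    steps=[step],
    created=datetime.now(),
)

print(record.name, [s.step_name for s in record.steps])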

viettelcloud/aiplatform/trainer/backends/localprocess/utils.py
@@ -0,0 +1,302 @@
import inspect
import os
from pathlib import Path
import re
import shutil
from string import Template
import textwrap
from typing import Any, Callable, Optional

from viettelcloud.aiplatform.trainer.backends.localprocess import constants as local_exec_constants
from viettelcloud.aiplatform.trainer.backends.localprocess.types import LocalRuntimeTrainer
from viettelcloud.aiplatform.trainer.constants import constants
from viettelcloud.aiplatform.trainer.types import types


def _extract_name(requirement: str) -> str:
    """
    Extract the base distribution name from a requirement string without external deps.

    Supports common PEP 508 patterns:
    - 'package'
    - 'package[extra1,extra2]'
    - 'package==1.2.3', 'package>=1.0', 'package~=1.4', etc.
    - 'package @ https://...'
    - markers after ';' are irrelevant for name extraction.

    Returns the *raw* (un-normalized) name as it appears.
    Raises ValueError if a name cannot be parsed.
    """
    if requirement is None:
        raise ValueError("Requirement string cannot be None")
    s = requirement.strip()
    if not s:
        raise ValueError("Empty requirement string")

    m = local_exec_constants.PYTHON_PACKAGE_NAME_RE.match(s)
    if not m:
        raise ValueError(f"Could not parse package name from requirement: {requirement!r}")
    return m.group(1)


def _canonicalize_name(name: str) -> str:
    """
    PEP 503-style normalization: case-insensitive, and collapse runs of -, _, . into '-'.
    """
    return re.sub(r"[-_.]+", "-", name).lower()


def get_install_packages(
    runtime_packages: list[str],
    trainer_packages: Optional[list[str]] = None,
) -> list[str]:
    """
    Merge two requirement lists into a single list of strings.

    Rules implemented:
    1) If a package appears in trainer_packages, it overwrites the one in runtime_packages.
       We keep the *trainer string verbatim* (specifier, markers, extras, spacing).
    2) Case-insensitive matching of package names (PEP 503-style normalization).
    3) Output is a list of strings.
    4) If trainer_packages contains the same dependency multiple times (case-insensitive),
       raise ValueError.
    5) If runtime_packages contains duplicates, the last one among *runtime* wins there
       (no error), but any trainer entry still overwrites it. Runtime packages shouldn't
       have any duplicates.
    6) Ordering: keep runtime-only packages in their original order (emitting only their
       last occurrence), then append all trainer packages in their original order.
    """
    if not trainer_packages:
        return runtime_packages

    # --- Parse + normalize runtime ---
    runtime_parsed: list[tuple[str, str]] = []  # (orig, canonical_name)
    last_runtime_index_by_name: dict[str, int] = {}

    for i, orig in enumerate(runtime_packages):
        raw_name = _extract_name(orig)
        canon = _canonicalize_name(raw_name)
        runtime_parsed.append((orig, canon))
        last_runtime_index_by_name[canon] = i  # last occurrence index wins among runtime

    # --- Parse + validate trainer (detect duplicates) ---
    trainer_parsed: list[tuple[str, str]] = []
    seen_trainer: set[str] = set()
    for orig in trainer_packages:
        raw_name = _extract_name(orig)
        canon = _canonicalize_name(raw_name)
        if canon in seen_trainer:
            raise ValueError(
                f"Duplicate dependency in trainer_packages: '{raw_name}' (canonical: '{canon}')"
            )
        seen_trainer.add(canon)
        trainer_parsed.append((orig, canon))

    trainer_names: set[str] = {canon for _, canon in trainer_parsed}

    # --- Build merged list respecting order semantics ---
    merged: list[str] = []

    # 1) Runtime-only packages (only emit the last occurrence for each name)
    emitted_runtime_names: set[str] = set()
    for idx, (orig, canon) in enumerate(runtime_parsed):
        if canon in trainer_names:
            continue  # overwritten by trainer
        if last_runtime_index_by_name[canon] == idx and canon not in emitted_runtime_names:
            merged.append(orig)
            emitted_runtime_names.add(canon)

    # 2) Trainer packages (overwrite and preserve trainer's exact strings, original order)
    for orig, _ in trainer_parsed:
        merged.append(orig)

    return merged


def get_local_runtime_trainer(
    runtime_name: str,
    venv_dir: str,
    framework: str,
) -> LocalRuntimeTrainer:
    """
    Get the LocalRuntimeTrainer object.
    """
    local_runtime = next(
        (rt for rt in local_exec_constants.local_runtimes if rt.name == runtime_name),
        None,
    )
    if not local_runtime:
        raise ValueError(f"Runtime {runtime_name} not found")

    trainer = LocalRuntimeTrainer(
        trainer_type=types.TrainerType.CUSTOM_TRAINER,
        framework=framework,
        packages=local_runtime.trainer.packages,
        image=local_exec_constants.LOCAL_RUNTIME_IMAGE,
    )

    # Set the command to run from the venv.
    venv_bin_dir = str(Path(venv_dir) / "bin")
    default_cmd = [str(Path(venv_bin_dir) / local_exec_constants.DEFAULT_COMMAND)]
    # Set the Trainer entrypoint.
    if framework == local_exec_constants.TORCH_FRAMEWORK_TYPE:
        _c = [os.path.join(venv_bin_dir, local_exec_constants.TORCH_COMMAND)]
        trainer.set_command(tuple(_c))
    else:
        trainer.set_command(tuple(default_cmd))

    return trainer


def get_dependencies_command(
    runtime_packages: list[str],
    pip_index_urls: list[str],
    trainer_packages: list[str],
    quiet: bool = True,
) -> str:
    # Resolve runtime dependencies and trainer dependencies.
    packages = get_install_packages(
        runtime_packages=runtime_packages,
        trainer_packages=trainer_packages,
    )

    options = [f"--index-url {pip_index_urls[0]}"]
    options.extend(f"--extra-index-url {extra_index_url}" for extra_index_url in pip_index_urls[1:])

    # The template being filled in looks like:
    #   PIP_DISABLE_PIP_VERSION_CHECK=1 pip install $QUIET $AS_USER \
    #       --no-warn-script-location $PIP_INDEX $PACKAGE_STR
    mapping = {
        "QUIET": "--quiet" if quiet else "",
        "PIP_INDEX": " ".join(options),
        "PACKAGE_STR": '"{}"'.format('" "'.join(packages)),  # quote deps
    }
    t = Template(local_exec_constants.DEPENDENCIES_SCRIPT)
    result = t.substitute(**mapping)
    return result


def get_command_using_train_func(
    runtime: types.Runtime,
    train_func: Callable,
    train_func_parameters: Optional[dict[str, Any]],
    venv_dir: str,
    train_job_name: str,
) -> str:
    """
    Get the Trainer container command from the given training function and parameters.
    """
    # Check if the runtime has a Trainer.
    if not runtime.trainer:
        raise ValueError(f"Runtime must have a trainer: {runtime}")

    # Check if the training function is callable.
    if not callable(train_func):
        raise ValueError(
            f"Training function must be callable, got function type: {type(train_func)}"
        )

    # Extract the function implementation.
    func_code = inspect.getsource(train_func)

    # Build the file name for the function inside the venv directory.
    func_file = Path(venv_dir) / local_exec_constants.LOCAL_EXEC_FILENAME.format(train_job_name)

    # The function might be defined in some indented scope (e.g. in another function),
    # so we need to dedent the function code.
    func_code = textwrap.dedent(func_code)

    # Wrap the function code so it can be executed from the file. For example:
    # TODO (andreyvelich): Find a better way to run users' scripts.
    # def train(parameters):
    #     print('Start Training...')
    # train({'lr': 0.01})
    if train_func_parameters is None:
        func_code = f"{func_code}\n{train_func.__name__}()\n"
    else:
        func_code = f"{func_code}\n{train_func.__name__}({train_func_parameters})\n"

    with open(func_file, "w") as f:
        f.write(func_code)

    t = Template(local_exec_constants.LOCAL_EXEC_ENTRYPOINT)
    mapping = {
        "PARAMETERS": "",  # Torch parameters, if any
        "PYENV_LOCATION": venv_dir,
        "ENTRYPOINT": " ".join(runtime.trainer.command),
        "FUNC_FILE": func_file,
    }
    entrypoint = t.safe_substitute(**mapping)

    return entrypoint


def get_cleanup_venv_script(venv_dir: str, cleanup_venv: bool = True) -> str:
    script = "\n"
    if not cleanup_venv:
        return script

    t = Template(local_exec_constants.LOCAL_EXEC_JOB_CLEANUP_SCRIPT)
    mapping = {
        "PYENV_LOCATION": venv_dir,
    }
    return t.substitute(**mapping)


def get_local_train_job_script(
    train_job_name: str,
    venv_dir: str,
    trainer: types.CustomTrainer,
    runtime: types.Runtime,
    cleanup_venv: bool = True,
) -> tuple:
    # Use the local-exec train job template.
    t = Template(local_exec_constants.LOCAL_EXEC_JOB_TEMPLATE)
    # Find the OS python binary to create the venv.
    python_bin = shutil.which("python")
    if not python_bin:
        python_bin = shutil.which("python3")
    if not python_bin:
        raise ValueError("No python executable found")

    # Work out whether dependencies need to be installed.
    if isinstance(runtime.trainer, LocalRuntimeTrainer):
        runtime_trainer: LocalRuntimeTrainer = runtime.trainer
    else:
        raise ValueError(f"Invalid Runtime Trainer type: {type(runtime.trainer)}")
    dependency_script = "\n"
    if trainer.packages_to_install:
        dependency_script = get_dependencies_command(
            pip_index_urls=(
                trainer.pip_index_urls
                if trainer.pip_index_urls
                else constants.DEFAULT_PIP_INDEX_URLS
            ),
            runtime_packages=runtime_trainer.packages,
            trainer_packages=trainer.packages_to_install,
            quiet=False,
        )

    entrypoint = get_command_using_train_func(
        venv_dir=venv_dir,
        runtime=runtime,
        train_func=trainer.func,
        train_func_parameters=trainer.func_args,
        train_job_name=train_job_name,
    )

    cleanup_script = get_cleanup_venv_script(cleanup_venv=cleanup_venv, venv_dir=venv_dir)

    mapping = {
        "OS_PYTHON_BIN": python_bin,
        "PYENV_LOCATION": venv_dir,
        "DEPENDENCIES_SCRIPT": dependency_script,
        "ENTRYPOINT": entrypoint,
        "CLEANUP_SCRIPT": cleanup_script,
    }

    command = t.safe_substitute(**mapping)

    return "bash", "-c", command
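
To make the merge rules of get_install_packages concrete, here is a small worked example (illustrative only; the package names are arbitrary, and it assumes PYTHON_PACKAGE_NAME_RE matches standard PEP 508 names):

# Illustrative sketch only -- not shipped in the wheel.
from viettelcloud.aiplatform.trainer.backends.localprocess.utils import get_install_packages

runtime_packages = ["torch==2.3.0", "numpy>=1.26", "requests"]
trainer_packages = ["NumPy==2.0.0", "pandas"]

# "NumPy" overrides "numpy" (case-insensitive, PEP 503 normalization), the trainer
# string is kept verbatim, and trainer entries are appended after runtime-only ones.
print(get_install_packages(runtime_packages, trainer_packages))
# -> ['torch==2.3.0', 'requests', 'NumPy==2.0.0', 'pandas']

# A repeated dependency inside trainer_packages raises ValueError:
# get_install_packages(runtime_packages, ["pandas", "Pandas==2.2"])

The ("bash", "-c", command) tuple returned by get_local_train_job_script is presumably what the local-process backend then hands to a LocalJob as its command.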