viettelcloud-aiplatform 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. viettelcloud/__init__.py +1 -0
  2. viettelcloud/aiplatform/__init__.py +15 -0
  3. viettelcloud/aiplatform/common/__init__.py +0 -0
  4. viettelcloud/aiplatform/common/constants.py +22 -0
  5. viettelcloud/aiplatform/common/types.py +28 -0
  6. viettelcloud/aiplatform/common/utils.py +40 -0
  7. viettelcloud/aiplatform/hub/OWNERS +14 -0
  8. viettelcloud/aiplatform/hub/__init__.py +25 -0
  9. viettelcloud/aiplatform/hub/api/__init__.py +13 -0
  10. viettelcloud/aiplatform/hub/api/_proxy_client.py +355 -0
  11. viettelcloud/aiplatform/hub/api/model_registry_client.py +561 -0
  12. viettelcloud/aiplatform/hub/api/model_registry_client_test.py +462 -0
  13. viettelcloud/aiplatform/optimizer/__init__.py +45 -0
  14. viettelcloud/aiplatform/optimizer/api/__init__.py +0 -0
  15. viettelcloud/aiplatform/optimizer/api/optimizer_client.py +248 -0
  16. viettelcloud/aiplatform/optimizer/backends/__init__.py +13 -0
  17. viettelcloud/aiplatform/optimizer/backends/base.py +77 -0
  18. viettelcloud/aiplatform/optimizer/backends/kubernetes/__init__.py +13 -0
  19. viettelcloud/aiplatform/optimizer/backends/kubernetes/backend.py +563 -0
  20. viettelcloud/aiplatform/optimizer/backends/kubernetes/utils.py +112 -0
  21. viettelcloud/aiplatform/optimizer/constants/__init__.py +13 -0
  22. viettelcloud/aiplatform/optimizer/constants/constants.py +59 -0
  23. viettelcloud/aiplatform/optimizer/types/__init__.py +13 -0
  24. viettelcloud/aiplatform/optimizer/types/algorithm_types.py +87 -0
  25. viettelcloud/aiplatform/optimizer/types/optimization_types.py +135 -0
  26. viettelcloud/aiplatform/optimizer/types/search_types.py +95 -0
  27. viettelcloud/aiplatform/py.typed +0 -0
  28. viettelcloud/aiplatform/trainer/__init__.py +82 -0
  29. viettelcloud/aiplatform/trainer/api/__init__.py +3 -0
  30. viettelcloud/aiplatform/trainer/api/trainer_client.py +277 -0
  31. viettelcloud/aiplatform/trainer/api/trainer_client_test.py +72 -0
  32. viettelcloud/aiplatform/trainer/backends/__init__.py +0 -0
  33. viettelcloud/aiplatform/trainer/backends/base.py +94 -0
  34. viettelcloud/aiplatform/trainer/backends/container/adapters/base.py +195 -0
  35. viettelcloud/aiplatform/trainer/backends/container/adapters/docker.py +231 -0
  36. viettelcloud/aiplatform/trainer/backends/container/adapters/podman.py +258 -0
  37. viettelcloud/aiplatform/trainer/backends/container/backend.py +668 -0
  38. viettelcloud/aiplatform/trainer/backends/container/backend_test.py +867 -0
  39. viettelcloud/aiplatform/trainer/backends/container/runtime_loader.py +631 -0
  40. viettelcloud/aiplatform/trainer/backends/container/runtime_loader_test.py +637 -0
  41. viettelcloud/aiplatform/trainer/backends/container/types.py +67 -0
  42. viettelcloud/aiplatform/trainer/backends/container/utils.py +213 -0
  43. viettelcloud/aiplatform/trainer/backends/kubernetes/__init__.py +0 -0
  44. viettelcloud/aiplatform/trainer/backends/kubernetes/backend.py +710 -0
  45. viettelcloud/aiplatform/trainer/backends/kubernetes/backend_test.py +1344 -0
  46. viettelcloud/aiplatform/trainer/backends/kubernetes/constants.py +15 -0
  47. viettelcloud/aiplatform/trainer/backends/kubernetes/utils.py +636 -0
  48. viettelcloud/aiplatform/trainer/backends/kubernetes/utils_test.py +582 -0
  49. viettelcloud/aiplatform/trainer/backends/localprocess/__init__.py +0 -0
  50. viettelcloud/aiplatform/trainer/backends/localprocess/backend.py +306 -0
  51. viettelcloud/aiplatform/trainer/backends/localprocess/backend_test.py +501 -0
  52. viettelcloud/aiplatform/trainer/backends/localprocess/constants.py +90 -0
  53. viettelcloud/aiplatform/trainer/backends/localprocess/job.py +184 -0
  54. viettelcloud/aiplatform/trainer/backends/localprocess/types.py +52 -0
  55. viettelcloud/aiplatform/trainer/backends/localprocess/utils.py +302 -0
  56. viettelcloud/aiplatform/trainer/constants/__init__.py +0 -0
  57. viettelcloud/aiplatform/trainer/constants/constants.py +179 -0
  58. viettelcloud/aiplatform/trainer/options/__init__.py +52 -0
  59. viettelcloud/aiplatform/trainer/options/common.py +55 -0
  60. viettelcloud/aiplatform/trainer/options/kubernetes.py +502 -0
  61. viettelcloud/aiplatform/trainer/options/kubernetes_test.py +259 -0
  62. viettelcloud/aiplatform/trainer/options/localprocess.py +20 -0
  63. viettelcloud/aiplatform/trainer/test/common.py +22 -0
  64. viettelcloud/aiplatform/trainer/types/__init__.py +0 -0
  65. viettelcloud/aiplatform/trainer/types/types.py +517 -0
  66. viettelcloud/aiplatform/trainer/types/types_test.py +115 -0
  67. viettelcloud_aiplatform-0.3.0.dist-info/METADATA +226 -0
  68. viettelcloud_aiplatform-0.3.0.dist-info/RECORD +71 -0
  69. viettelcloud_aiplatform-0.3.0.dist-info/WHEEL +4 -0
  70. viettelcloud_aiplatform-0.3.0.dist-info/licenses/LICENSE +201 -0
  71. viettelcloud_aiplatform-0.3.0.dist-info/licenses/NOTICE +36 -0
@@ -0,0 +1,184 @@
1
+ # Copyright 2025 The Kubeflow Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from datetime import datetime
15
+ import logging
16
+ import os
17
+ import subprocess
18
+ import threading
19
+ from typing import Union
20
+
21
+ from viettelcloud.aiplatform.trainer.constants import constants
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class LocalJob(threading.Thread):
    """Background thread that runs a single command as a local subprocess.

    Before starting, the job joins every job in ``dependencies`` and skips itself
    if any of them did not succeed. Otherwise it launches ``command`` in
    ``execution_dir`` with ``env`` as the child's complete environment, collects the
    child's combined stdout/stderr into an in-memory buffer (readable while the job
    runs via ``stdout`` / ``stream_logs``), and records status, return code and
    start/end timestamps.
    """

    # Seconds to wait for a terminated child to exit before escalating to kill().
    _TERMINATE_TIMEOUT_SECONDS = 10

    def __init__(
        self,
        name,
        command: Union[list, tuple[str], str],
        execution_dir: str = None,
        env: dict[str, str] = None,
        dependencies: list = None,
    ):
        """Creates a LocalJob.

        Creates a local subprocess with threading to allow users to create background jobs.

        Args:
            name (str): The name of the job (also used as the thread name).
            command (list | tuple | str): The command passed to ``subprocess.Popen``
                (no shell is used), typically an argv sequence.
            execution_dir (str, optional): Working directory of the subprocess.
                Defaults to the current working directory at construction time.
            env (Dict[str, str], optional): The complete environment of the
                subprocess; the parent's environment is not inherited.
                Defaults to an empty environment.
            dependencies (list, optional): Jobs (e.g. other ``LocalJob`` instances)
                that must finish successfully before this one starts. Defaults to None.
        """
        super().__init__()
        self.name = name
        self.command = command
        self._stdout = ""
        self._returncode = None
        self._success = False
        self._status = constants.TRAINJOB_CREATED
        self._lock = threading.Lock()
        self._process = None
        self._output_updated = threading.Event()
        self._cancel_requested = threading.Event()
        self._start_time = None
        self._end_time = None
        self.env = env or {}
        self.dependencies = dependencies or []
        self.execution_dir = execution_dir or os.getcwd()

    def run(self):
        """Thread body: wait for dependencies, run the command, record the outcome."""
        for dep in self.dependencies:
            dep.join()
            if not dep.success:
                with self._lock:
                    self._stdout = f"Dependency {dep.name} failed. Skipping"
                    self._output_updated.set()
                return

        try:
            self._start_time = datetime.now()
            logger.debug(
                "[%s] Started at %s with command: \n %s",
                self.name,
                self._start_time,
                self._format_command(),
            )

            # Give the child its working directory via ``cwd`` instead of calling
            # os.chdir(): chdir changes the directory of the whole Python process,
            # which races with other LocalJob threads and with the caller.
            # errors="replace" keeps undecodable child output from aborting the job.
            self._process = subprocess.Popen(
                self.command,
                cwd=self.execution_dir,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                encoding="utf-8",
                errors="replace",
                bufsize=1,
                env=self.env,
            )
            # set job status
            self._status = constants.TRAINJOB_RUNNING

            while True:
                # NOTE: readline() blocks, so a cancel request is only noticed when
                # the child writes its next line or exits.
                if self._cancel_requested.is_set():
                    self._stop_process()
                    with self._lock:
                        self._returncode = self._process.returncode
                        self._end_time = datetime.now()
                        self._stdout += "[JobCancelled]\n"
                        self._status = constants.TRAINJOB_FAILED
                        self._success = False
                        self._output_updated.set()
                    return

                # Read output line by line (for streaming)
                output_line = self._process.stdout.readline()
                if output_line:
                    with self._lock:
                        self._stdout += output_line
                        self._output_updated.set()
                elif self._process.poll() is not None:
                    # EOF on the pipe and the child has exited.
                    break

            self._process.stdout.close()
            self._returncode = self._process.wait()
            self._end_time = datetime.now()
            self._success = self._returncode == 0
            msg = (
                f"[{self.name}] Completed with code {self._returncode}"
                f" in {self._end_time - self._start_time} seconds."
            )
            with self._lock:
                # set status based on success or failure
                self._status = (
                    constants.TRAINJOB_COMPLETE if self._success else constants.TRAINJOB_FAILED
                )
                self._stdout += msg
                # Wake stream_logs() readers so they pick up the final message.
                self._output_updated.set()
            logger.debug("Job output: %s", self.stdout)

        except Exception as e:
            # Best effort: don't leave an orphaned child running if something
            # failed after it was started (e.g. while reading its output).
            if self._process is not None and self._process.poll() is None:
                try:
                    self._stop_process()
                except OSError:
                    logger.warning("[%s] Failed to stop subprocess", self.name, exc_info=True)
            with self._lock:
                self._stdout += f"Exception: {e}\n"
                self._success = False
                self._status = constants.TRAINJOB_FAILED
                self._output_updated.set()

    def _format_command(self) -> str:
        """Return ``command`` as a single display string for log messages."""
        if isinstance(self.command, str):
            return self.command
        return " ".join(str(part) for part in self.command)

    def _stop_process(self):
        """Terminate the subprocess, close its output pipe, and reap it.

        Escalates to ``kill()`` if the process has not exited within
        ``_TERMINATE_TIMEOUT_SECONDS`` after ``terminate()``.
        """
        self._process.terminate()
        # Closing the read end unblocks a child that is stuck writing to a full pipe.
        self._process.stdout.close()
        try:
            self._process.wait(timeout=self._TERMINATE_TIMEOUT_SECONDS)
        except subprocess.TimeoutExpired:
            self._process.kill()
            self._process.wait()

    @property
    def stdout(self):
        """All output captured so far (combined stdout/stderr plus status messages)."""
        with self._lock:
            return self._stdout

    @property
    def success(self):
        """True only if the subprocess ran and exited with return code 0."""
        return self._success

    @property
    def status(self):
        """Current job status (one of the ``constants.TRAINJOB_*`` values)."""
        return self._status

    def cancel(self):
        """Request cancellation; the running thread terminates the subprocess."""
        self._cancel_requested.set()

    @property
    def returncode(self):
        """Exit code of the subprocess, or None if it has not finished."""
        return self._returncode

    def logs(self, follow=False) -> list[str]:
        """Return the captured output split into lines.

        Args:
            follow: If True, first print output chunks to the console as they arrive
                until the job has finished.
        """
        if follow:
            for chunk in self.stream_logs():
                print(chunk, end="", flush=True)  # stream to console live
        return self.stdout.splitlines()

    def stream_logs(self):
        """Yield newly captured output as string chunks (not necessarily whole lines).

        Stops once the thread is no longer alive and all captured output has been
        yielded. Waits at most one second between checks for new output.
        """
        last_index = 0
        while self.is_alive() or last_index < len(self.stdout):
            self._output_updated.wait(timeout=1)
            with self._lock:
                data = self._stdout
                self._output_updated.clear()
            new_data = data[last_index:]
            last_index = len(data)
            if new_data:
                yield new_data

    @property
    def creation_time(self):
        """Time the subprocess was started, or None if it never started."""
        return self._start_time

    @property
    def completion_time(self):
        """Time the subprocess finished (or was cancelled), or None."""
        return self._end_time
@@ -0,0 +1,52 @@
1
+ # Copyright 2025 The Kubeflow Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from dataclasses import dataclass, field
17
+ from datetime import datetime
18
+ import typing
19
+ from typing import Optional
20
+
21
+ from pydantic import BaseModel
22
+
23
+ from viettelcloud.aiplatform.trainer.backends.localprocess.job import LocalJob
24
+ from viettelcloud.aiplatform.trainer.types import types
25
+
26
+
27
class LocalProcessBackendConfig(BaseModel):
    """Configuration for the local-process training backend."""

    # Whether the generated job script should clean up the job's virtual environment.
    # NOTE(review): presumably forwarded as ``cleanup_venv`` to the localprocess job-script
    # builder — verify in the backend.
    cleanup_venv: bool = True
29
+
30
+
31
@dataclass
class LocalRuntimeTrainer(types.RuntimeTrainer):
    """``types.RuntimeTrainer`` extended with a local runtime's default pip requirements."""

    # Requirement strings installed by default for this runtime. When the install list
    # is built, trainer-requested packages take precedence over entries naming the same
    # package.
    packages: list[str] = field(default_factory=list)
34
+
35
+
36
class LocalBackendStep(BaseModel):
    """One named step of a local train job, executed by a ``LocalJob`` thread."""

    step_name: str
    job: LocalJob

    class Config:
        # LocalJob is a threading.Thread subclass rather than a pydantic type, so
        # arbitrary types must be allowed for the model to accept it.
        arbitrary_types_allowed = True
42
+
43
+
44
class LocalBackendJobs(BaseModel):
    """Record of a local train job: its name, runtime, steps and timestamps."""

    # Steps of the job. NOTE(review): the literal ``[]`` default is presumably safe
    # because pydantic copies mutable defaults per instance — verify for the pinned
    # pydantic version.
    steps: Optional[list[LocalBackendStep]] = []
    runtime: Optional[types.Runtime] = None
    name: str
    created: typing.Optional[datetime] = None
    completed: typing.Optional[datetime] = None

    class Config:
        # Allow field types that are not pydantic models. NOTE(review): confirm which
        # field (e.g. types.Runtime) requires this.
        arbitrary_types_allowed = True
@@ -0,0 +1,302 @@
1
+ import inspect
2
+ import os
3
+ from pathlib import Path
4
+ import re
5
+ import shutil
6
+ from string import Template
7
+ import textwrap
8
+ from typing import Any, Callable, Optional
9
+
10
+ from viettelcloud.aiplatform.trainer.backends.localprocess import constants as local_exec_constants
11
+ from viettelcloud.aiplatform.trainer.backends.localprocess.types import LocalRuntimeTrainer
12
+ from viettelcloud.aiplatform.trainer.constants import constants
13
+ from viettelcloud.aiplatform.trainer.types import types
14
+
15
+
16
def _extract_name(requirement: str) -> str:
    """
    Return the distribution name at the start of a requirement string.

    Parsing uses ``PYTHON_PACKAGE_NAME_RE`` only, with no external dependencies,
    and is intended to cover the common PEP 508 forms:
      - 'package'
      - 'package[extra1,extra2]'
      - 'package==1.2.3', 'package>=1.0', 'package~=1.4', etc.
      - 'package @ https://...'
      - environment markers after ';' do not affect the name.

    The name is returned exactly as written (not normalized).
    Raises ValueError for None, blank input, or input whose name cannot be parsed.
    """
    if requirement is None:
        raise ValueError("Requirement string cannot be None")

    cleaned = requirement.strip()
    if cleaned:
        match = local_exec_constants.PYTHON_PACKAGE_NAME_RE.match(cleaned)
        if match is None:
            raise ValueError(f"Could not parse package name from requirement: {requirement!r}")
        return match.group(1)

    raise ValueError("Empty requirement string")
40
+
41
+
42
def _canonicalize_name(name: str) -> str:
    """
    Normalize a package name PEP 503-style.

    Every run of '-', '_' or '.' becomes a single '-', and the result is lower-cased,
    so that names differing only in case or separators compare equal.
    """
    dashed = re.sub(r"[-_.]+", "-", name)
    return dashed.lower()
47
+
48
+
49
def get_install_packages(
    runtime_packages: list[str],
    trainer_packages: Optional[list[str]] = None,
) -> list[str]:
    """
    Merge two requirement lists into a single list of strings.

    Rules implemented:
    1) If a package appears in trainer_packages, it overwrites the one in runtime_packages.
       We keep the *trainer string verbatim* (specifier, markers, extras, spacing).
    2) Case-insensitive matching of package names (PEP 503-style normalization).
    3) Output is a list of strings.
    4) If trainer_packages contains the same dependency multiple times (case-insensitive),
       raise ValueError.
    5) If runtime_packages contains duplicates, the last one among *runtime* wins there
       (no error), but any trainer entry still overwrites it. Runtime packages shouldn't
       have any duplicates.
    6) Ordering: keep runtime-only packages in their original order (emitting only their
       last occurrence), then append all trainer packages in their original order.

    Args:
        runtime_packages: Requirement strings provided by the runtime. ``None`` is
            treated as an empty list.
        trainer_packages: Requirement strings requested by the trainer.

    Returns:
        A new list; never the caller's ``runtime_packages`` object itself.

    Raises:
        ValueError: If a requirement name cannot be parsed, or if ``trainer_packages``
            names the same package more than once.
    """
    runtime_packages = runtime_packages or []
    if not trainer_packages:
        # Return a copy: handing back the caller's list would let later mutation of
        # the result leak into the runtime's own (possibly shared) package list.
        return list(runtime_packages)

    # --- Parse + normalize runtime: (original string, canonical name) ---
    runtime_parsed: list[tuple[str, str]] = [
        (orig, _canonicalize_name(_extract_name(orig))) for orig in runtime_packages
    ]
    # Later occurrences overwrite earlier ones, so each name maps to its last index.
    last_runtime_index_by_name: dict[str, int] = {
        canon: idx for idx, (_, canon) in enumerate(runtime_parsed)
    }

    # --- Parse + validate trainer (detect duplicates) ---
    trainer_names: set[str] = set()
    for orig in trainer_packages:
        raw_name = _extract_name(orig)
        canon = _canonicalize_name(raw_name)
        if canon in trainer_names:
            raise ValueError(
                f"Duplicate dependency in trainer_packages: '{raw_name}' (canonical: '{canon}')"
            )
        trainer_names.add(canon)

    # 1) Runtime-only packages, each emitted once at the position of its last occurrence.
    merged: list[str] = [
        orig
        for idx, (orig, canon) in enumerate(runtime_parsed)
        if canon not in trainer_names and last_runtime_index_by_name[canon] == idx
    ]
    # 2) Trainer packages, verbatim and in their original order.
    merged.extend(trainer_packages)
    return merged
114
+
115
+
116
def get_local_runtime_trainer(
    runtime_name: str,
    venv_dir: str,
    framework: str,
) -> LocalRuntimeTrainer:
    """
    Build the LocalRuntimeTrainer for the named local runtime.

    Args:
        runtime_name: Name of an entry in ``local_exec_constants.local_runtimes``.
        venv_dir: Virtual-environment directory; the trainer command points into
            its ``bin`` directory.
        framework: Framework name. ``TORCH_FRAMEWORK_TYPE`` selects ``TORCH_COMMAND``
            as the entrypoint; anything else uses ``DEFAULT_COMMAND``.

    Returns:
        A new LocalRuntimeTrainer with the runtime's default packages and the
        venv-based command set.

    Raises:
        ValueError: If no local runtime is named ``runtime_name``.
    """
    local_runtime = next(
        (rt for rt in local_exec_constants.local_runtimes if rt.name == runtime_name),
        None,
    )
    if local_runtime is None:
        raise ValueError(f"Runtime {runtime_name} not found")

    trainer = LocalRuntimeTrainer(
        trainer_type=types.TrainerType.CUSTOM_TRAINER,
        framework=framework,
        # Copy so edits to this trainer's package list cannot leak into the shared
        # module-level runtime definition.
        packages=list(local_runtime.trainer.packages),
        image=local_exec_constants.LOCAL_RUNTIME_IMAGE,
    )

    # The Trainer entrypoint runs from the venv's bin directory.
    if framework == local_exec_constants.TORCH_FRAMEWORK_TYPE:
        entrypoint = local_exec_constants.TORCH_COMMAND
    else:
        entrypoint = local_exec_constants.DEFAULT_COMMAND
    trainer.set_command((str(Path(venv_dir) / "bin" / entrypoint),))

    return trainer
149
+
150
+
151
def _bash_double_quote(value: str) -> str:
    """Wrap ``value`` in bash double quotes, escaping the characters still special there.

    Inside double quotes bash interprets backslash, double quote, dollar sign and
    backtick, so each is backslash-escaped. A plain package name comes out as
    ``"name"``.
    """
    escaped = value.replace("\\", "\\\\")
    for special in ('"', "$", "`"):
        escaped = escaped.replace(special, "\\" + special)
    return f'"{escaped}"'


def get_dependencies_command(
    runtime_packages: list[str],
    pip_index_urls: list[str],
    trainer_packages: list[str],
    quiet: bool = True,
) -> str:
    """
    Render the shell snippet that pip-installs the merged runtime and trainer packages.

    ``local_exec_constants.DEPENDENCIES_SCRIPT`` is rendered with
    ``string.Template.substitute`` using the placeholders ``$QUIET``, ``$PIP_INDEX``
    and ``$PACKAGE_STR``.

    Args:
        runtime_packages: Requirement strings provided by the runtime.
        pip_index_urls: Package index URLs. The first becomes ``--index-url`` and the
            rest become ``--extra-index-url``. If empty, no index options are emitted.
        trainer_packages: Requirement strings requested by the trainer; they override
            runtime entries for the same package (see ``get_install_packages``).
        quiet: Whether to pass ``--quiet`` to pip.

    Returns:
        The rendered shell snippet.

    Raises:
        ValueError: Propagated from ``get_install_packages``.
        KeyError: If the template uses a placeholder not supplied here.
    """
    # resolve runtime dependencies and trainer dependencies.
    packages = get_install_packages(
        runtime_packages=runtime_packages,
        trainer_packages=trainer_packages,
    )

    options: list[str] = []
    if pip_index_urls:
        options.append(f"--index-url {pip_index_urls[0]}")
        options.extend(
            f"--extra-index-url {extra_index_url}" for extra_index_url in pip_index_urls[1:]
        )

    mapping = {
        "QUIET": "--quiet" if quiet else "",
        "PIP_INDEX": " ".join(options),
        # Quote each requirement for bash, escaping embedded quotes so PEP 508 markers
        # such as `python_version < "3.12"` reach pip intact.
        "PACKAGE_STR": " ".join(_bash_double_quote(pkg) for pkg in packages),
    }
    return Template(local_exec_constants.DEPENDENCIES_SCRIPT).substitute(**mapping)
178
+
179
+
180
def get_command_using_train_func(
    runtime: types.Runtime,
    train_func: Callable,
    train_func_parameters: Optional[dict[str, Any]],
    venv_dir: str,
    train_job_name: str,
) -> str:
    """
    Write the training function to a script file and return the Trainer entrypoint command.

    The function's source is dedented, followed by a call to it, and written to
    ``venv_dir / LOCAL_EXEC_FILENAME.format(train_job_name)``. The returned string is
    ``LOCAL_EXEC_ENTRYPOINT`` rendered (with ``safe_substitute``) with the runtime
    trainer's command, the venv location and that file.

    Args:
        runtime: Runtime whose ``trainer.command`` forms the entrypoint.
        train_func: The user's training function.
        train_func_parameters: Arguments for ``train_func``, or None to call it without
            arguments.
        venv_dir: Directory the script is written into.
        train_job_name: Name used to build the script file name.

    Returns:
        The rendered entrypoint command string.

    Raises:
        ValueError: If ``runtime`` has no trainer or ``train_func`` is not callable.
    """
    # Check if the runtime has a Trainer.
    if not runtime.trainer:
        raise ValueError(f"Runtime must have a trainer: {runtime}")

    # Check if training function is callable.
    if not callable(train_func):
        raise ValueError(
            f"Training function must be callable, got function type: {type(train_func)}"
        )

    # The function might be defined in an indented scope (e.g. inside another function),
    # so dedent its source before writing it out as a top-level definition.
    func_code = textwrap.dedent(inspect.getsource(train_func))

    # Append a call so that running the file executes the function. For example:
    # def train(parameters):
    #     print('Start Training...')
    # train({'lr': 0.01})
    # NOTE: the parameters are embedded via their str() form, which is only valid Python
    # for literal-like values.
    # TODO (andreyvelich): Find a better way to run users' scripts.
    if train_func_parameters is None:
        func_code = f"{func_code}\n{train_func.__name__}()\n"
    else:
        func_code = f"{func_code}\n{train_func.__name__}({train_func_parameters})\n"

    # The generated script lives in the venv directory.
    func_file = Path(venv_dir) / local_exec_constants.LOCAL_EXEC_FILENAME.format(train_job_name)
    # Write UTF-8 explicitly: user source may contain non-ASCII text, Python reads source
    # files as UTF-8 by default, and the platform default encoding may differ.
    with open(func_file, "w", encoding="utf-8") as f:
        f.write(func_code)

    t = Template(local_exec_constants.LOCAL_EXEC_ENTRYPOINT)
    mapping = {
        "PARAMETERS": "",  ## Torch Parameters if any
        "PYENV_LOCATION": venv_dir,
        "ENTRYPOINT": " ".join(runtime.trainer.command),
        "FUNC_FILE": func_file,
    }
    return t.safe_substitute(**mapping)
234
+
235
+
236
def get_cleanup_venv_script(venv_dir: str, cleanup_venv: bool = True) -> str:
    """Return the venv cleanup shell snippet, or a lone newline when cleanup is disabled.

    When enabled, the snippet is ``LOCAL_EXEC_JOB_CLEANUP_SCRIPT`` with ``$PYENV_LOCATION``
    substituted by ``venv_dir``.
    """
    if cleanup_venv:
        cleanup_template = Template(local_exec_constants.LOCAL_EXEC_JOB_CLEANUP_SCRIPT)
        return cleanup_template.substitute(PYENV_LOCATION=venv_dir)
    return "\n"
246
+
247
+
248
def get_local_train_job_script(
    train_job_name: str,
    venv_dir: str,
    trainer: types.CustomTrainer,
    runtime: types.Runtime,
    cleanup_venv: bool = True,
) -> tuple[str, str, str]:
    """
    Build the ``bash -c`` command that runs a local train job.

    The script comes from ``LOCAL_EXEC_JOB_TEMPLATE`` and is filled with:
    - the host Python binary (``python``, else ``python3``, found on PATH);
    - the venv location;
    - a pip-install snippet when ``trainer.packages_to_install`` is set;
    - the entrypoint that runs the user's training function;
    - the venv cleanup snippet (empty when ``cleanup_venv`` is False).

    Returns:
        ``("bash", "-c", <script>)``, suitable as a ``subprocess`` argv.

    Raises:
        ValueError: If no Python executable is found, or ``runtime.trainer`` is not a
            ``LocalRuntimeTrainer``.
    """
    # use local-exec train job template
    t = Template(local_exec_constants.LOCAL_EXEC_JOB_TEMPLATE)

    # find os python binary to create venv
    python_bin = shutil.which("python") or shutil.which("python3")
    if not python_bin:
        raise ValueError("No python executable found")

    if not isinstance(runtime.trainer, LocalRuntimeTrainer):
        raise ValueError(f"Invalid Runtime Trainer type: {type(runtime.trainer)}")
    runtime_trainer: LocalRuntimeTrainer = runtime.trainer

    # Only emit an install step when the trainer asks for extra packages.
    dependency_script = "\n"
    if trainer.packages_to_install:
        dependency_script = get_dependencies_command(
            pip_index_urls=trainer.pip_index_urls or constants.DEFAULT_PIP_INDEX_URLS,
            runtime_packages=runtime_trainer.packages,
            trainer_packages=trainer.packages_to_install,
            quiet=False,
        )

    entrypoint = get_command_using_train_func(
        venv_dir=venv_dir,
        runtime=runtime,
        train_func=trainer.func,
        train_func_parameters=trainer.func_args,
        train_job_name=train_job_name,
    )

    cleanup_script = get_cleanup_venv_script(cleanup_venv=cleanup_venv, venv_dir=venv_dir)

    mapping = {
        "OS_PYTHON_BIN": python_bin,
        "PYENV_LOCATION": venv_dir,
        "DEPENDENCIES_SCRIPT": dependency_script,
        "ENTRYPOINT": entrypoint,
        "CLEANUP_SCRIPT": cleanup_script,
    }

    command = t.safe_substitute(**mapping)

    return "bash", "-c", command
File without changes