viettelcloud-aiplatform 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- viettelcloud/__init__.py +1 -0
- viettelcloud/aiplatform/__init__.py +15 -0
- viettelcloud/aiplatform/common/__init__.py +0 -0
- viettelcloud/aiplatform/common/constants.py +22 -0
- viettelcloud/aiplatform/common/types.py +28 -0
- viettelcloud/aiplatform/common/utils.py +40 -0
- viettelcloud/aiplatform/hub/OWNERS +14 -0
- viettelcloud/aiplatform/hub/__init__.py +25 -0
- viettelcloud/aiplatform/hub/api/__init__.py +13 -0
- viettelcloud/aiplatform/hub/api/_proxy_client.py +355 -0
- viettelcloud/aiplatform/hub/api/model_registry_client.py +561 -0
- viettelcloud/aiplatform/hub/api/model_registry_client_test.py +462 -0
- viettelcloud/aiplatform/optimizer/__init__.py +45 -0
- viettelcloud/aiplatform/optimizer/api/__init__.py +0 -0
- viettelcloud/aiplatform/optimizer/api/optimizer_client.py +248 -0
- viettelcloud/aiplatform/optimizer/backends/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/backends/base.py +77 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/backend.py +563 -0
- viettelcloud/aiplatform/optimizer/backends/kubernetes/utils.py +112 -0
- viettelcloud/aiplatform/optimizer/constants/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/constants/constants.py +59 -0
- viettelcloud/aiplatform/optimizer/types/__init__.py +13 -0
- viettelcloud/aiplatform/optimizer/types/algorithm_types.py +87 -0
- viettelcloud/aiplatform/optimizer/types/optimization_types.py +135 -0
- viettelcloud/aiplatform/optimizer/types/search_types.py +95 -0
- viettelcloud/aiplatform/py.typed +0 -0
- viettelcloud/aiplatform/trainer/__init__.py +82 -0
- viettelcloud/aiplatform/trainer/api/__init__.py +3 -0
- viettelcloud/aiplatform/trainer/api/trainer_client.py +277 -0
- viettelcloud/aiplatform/trainer/api/trainer_client_test.py +72 -0
- viettelcloud/aiplatform/trainer/backends/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/base.py +94 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/base.py +195 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/docker.py +231 -0
- viettelcloud/aiplatform/trainer/backends/container/adapters/podman.py +258 -0
- viettelcloud/aiplatform/trainer/backends/container/backend.py +668 -0
- viettelcloud/aiplatform/trainer/backends/container/backend_test.py +867 -0
- viettelcloud/aiplatform/trainer/backends/container/runtime_loader.py +631 -0
- viettelcloud/aiplatform/trainer/backends/container/runtime_loader_test.py +637 -0
- viettelcloud/aiplatform/trainer/backends/container/types.py +67 -0
- viettelcloud/aiplatform/trainer/backends/container/utils.py +213 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/backend.py +710 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/backend_test.py +1344 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/constants.py +15 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/utils.py +636 -0
- viettelcloud/aiplatform/trainer/backends/kubernetes/utils_test.py +582 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/backend.py +306 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/backend_test.py +501 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/constants.py +90 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/job.py +184 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/types.py +52 -0
- viettelcloud/aiplatform/trainer/backends/localprocess/utils.py +302 -0
- viettelcloud/aiplatform/trainer/constants/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/constants/constants.py +179 -0
- viettelcloud/aiplatform/trainer/options/__init__.py +52 -0
- viettelcloud/aiplatform/trainer/options/common.py +55 -0
- viettelcloud/aiplatform/trainer/options/kubernetes.py +502 -0
- viettelcloud/aiplatform/trainer/options/kubernetes_test.py +259 -0
- viettelcloud/aiplatform/trainer/options/localprocess.py +20 -0
- viettelcloud/aiplatform/trainer/test/common.py +22 -0
- viettelcloud/aiplatform/trainer/types/__init__.py +0 -0
- viettelcloud/aiplatform/trainer/types/types.py +517 -0
- viettelcloud/aiplatform/trainer/types/types_test.py +115 -0
- viettelcloud_aiplatform-0.3.0.dist-info/METADATA +226 -0
- viettelcloud_aiplatform-0.3.0.dist-info/RECORD +71 -0
- viettelcloud_aiplatform-0.3.0.dist-info/WHEEL +4 -0
- viettelcloud_aiplatform-0.3.0.dist-info/licenses/LICENSE +201 -0
- viettelcloud_aiplatform-0.3.0.dist-info/licenses/NOTICE +36 -0
viettelcloud/aiplatform/trainer/types/types.py

@@ -0,0 +1,517 @@
# Copyright 2024 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Callable, Optional, Union
from urllib.parse import urlparse

import viettelcloud.aiplatform.common.constants as common_constants
from viettelcloud.aiplatform.trainer.constants import constants

# Configuration for the Custom Trainer.
@dataclass
class CustomTrainer:
    """Custom Trainer configuration. Configure the self-contained function
    that encapsulates the entire model training process.

    Args:
        func (`Callable`): The function that encapsulates the entire model training process.
        func_args (`Optional[dict]`): The arguments to pass to the function.
        image (`Optional[str]`): The optional container image to use in the TrainJob.
        packages_to_install (`Optional[list[str]]`):
            A list of Python packages to install before running the function.
        pip_index_urls (`list[str]`): The PyPI URLs from which to install
            Python packages. The first URL is the index-url, and the remaining ones
            are extra-index-urls.
        num_nodes (`Optional[int]`): The number of nodes to use for training.
        resources_per_node (`Optional[dict]`): The computing resources to allocate per node.
            ```python
            resources_per_node = {"gpu": 4, "cpu": 5, "memory": "10G"}
            ```
            If your compute supports fractional GPUs (e.g. multi-instance GPU),
            you can set the resources as follows (request one 5GB GPU slice):
            ```python
            resources_per_node = {"mig-1g.5gb": 1}
            ```
        env (`Optional[dict[str, str]]`): The environment variables to set in the training nodes.
    """

    func: Callable
    func_args: Optional[dict] = None
    image: Optional[str] = None
    packages_to_install: Optional[list[str]] = None
    pip_index_urls: list[str] = field(
        default_factory=lambda: list(constants.DEFAULT_PIP_INDEX_URLS)
    )
    num_nodes: Optional[int] = None
    resources_per_node: Optional[dict] = None
    env: Optional[dict[str, str]] = None
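
For orientation, a minimal sketch of how these fields fit together; the training function, package list, and resource values below are illustrative assumptions, not defaults shipped with the package:

```python
from viettelcloud.aiplatform.trainer.types import types

# Hypothetical self-contained training function: everything the job
# needs must live inside the function body.
def train_fn(lr: float, epochs: int):
    print(f"training with lr={lr} for {epochs} epochs")

trainer = types.CustomTrainer(
    func=train_fn,
    func_args={"lr": 1e-3, "epochs": 2},
    packages_to_install=["torch"],  # installed before func runs
    num_nodes=2,
    resources_per_node={"gpu": 1, "cpu": 4, "memory": "10G"},
    env={"NCCL_DEBUG": "INFO"},
)
```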

# Configuration for the Custom Trainer Container.
@dataclass
class CustomTrainerContainer:
    """Custom Trainer Container configuration. Configure the container image
    that encapsulates the entire model training process.

    Args:
        image (`str`): The container image that encapsulates the entire model training process.
        num_nodes (`Optional[int]`): The number of nodes to use for training.
        resources_per_node (`Optional[dict]`): The computing resources to allocate per node.
            ```python
            resources_per_node = {"gpu": 4, "cpu": 5, "memory": "10G"}
            ```
            If your compute supports fractional GPUs (e.g. multi-instance GPU),
            you can set the resources as follows (request one 5GB GPU slice):
            ```python
            resources_per_node = {"mig-1g.5gb": 1}
            ```
        env (`Optional[dict[str, str]]`): The environment variables to set in the training nodes.
    """

    image: str
    num_nodes: Optional[int] = None
    resources_per_node: Optional[dict] = None
    env: Optional[dict[str, str]] = None

# TODO(Electronic-Waste): Add more loss functions.
# Loss function for the TorchTune LLM Trainer.
class Loss(Enum):
    """Loss function for the TorchTune LLM Trainer."""

    CEWithChunkedOutputLoss = "torchtune.modules.loss.CEWithChunkedOutputLoss"


# Data type for the TorchTune LLM Trainer.
class DataType(Enum):
    """Data type for the TorchTune LLM Trainer."""

    BF16 = "bf16"
    FP32 = "fp32"


# Data file type for the TorchTune LLM Trainer.
class DataFormat(Enum):
    """Data file type for the TorchTune LLM Trainer."""

    JSON = "json"
    CSV = "csv"
    PARQUET = "parquet"
    ARROW = "arrow"
    TEXT = "text"
    XML = "xml"

# Configuration for the TorchTune Instruct dataset.
@dataclass
class TorchTuneInstructDataset:
    """
    Configuration for the custom dataset with user instruction prompts and model responses.
    REF: https://pytorch.org/torchtune/main/generated/torchtune.datasets.instruct_dataset.html

    Args:
        source (`Optional[DataFormat]`): Data file type.
        split (`Optional[str]`):
            The split of the dataset to use. You can use this argument to load a subset of
            a given split, e.g. split="train[:10%]". Default is `train`.
        train_on_input (`Optional[bool]`):
            Whether the model is trained on the user prompt. Default is False.
        new_system_prompt (`Optional[str]`):
            The new system prompt to use. If specified, prepend a system message.
            This can serve as instructions to guide the model response. Default is None.
        column_map (`Optional[dict[str, str]]`):
            A mapping to change the expected "input" and "output" column names to the actual
            column names in the dataset. Keys should be "input" and "output" and values should
            be the actual column names. Default is None, keeping the default "input" and
            "output" column names.
    """

    source: Optional[DataFormat] = None
    split: Optional[str] = None
    train_on_input: Optional[bool] = None
    new_system_prompt: Optional[str] = None
    column_map: Optional[dict[str, str]] = None

@dataclass
class LoraConfig:
    """Configuration for LoRA/QLoRA/DoRA.
    REF: https://meta-pytorch.org/torchtune/main/tutorials/memory_optimizations.html

    Args:
        apply_lora_to_mlp (`Optional[bool]`):
            Whether to apply LoRA to the MLP in each transformer layer.
        apply_lora_to_output (`Optional[bool]`):
            Whether to apply LoRA to the model's final output projection.
        lora_attn_modules (`list[str]`):
            A list of strings specifying which layers of the model to apply LoRA to;
            the default is ["q_proj", "v_proj", "output_proj"]:
            1. "q_proj" applies LoRA to the query projection layer.
            2. "k_proj" applies LoRA to the key projection layer.
            3. "v_proj" applies LoRA to the value projection layer.
            4. "output_proj" applies LoRA to the attention output projection layer.
        lora_rank (`Optional[int]`): The rank of the low-rank decomposition.
        lora_alpha (`Optional[int]`):
            The scaling factor that adjusts the magnitude of the low-rank matrices' output.
        lora_dropout (`Optional[float]`):
            The probability of applying dropout to the low-rank updates.
        quantize_base (`Optional[bool]`): Whether to enable model quantization.
        use_dora (`Optional[bool]`): Whether to enable DoRA.
    """

    apply_lora_to_mlp: Optional[bool] = None
    apply_lora_to_output: Optional[bool] = None
    lora_attn_modules: list[str] = field(
        default_factory=lambda: ["q_proj", "v_proj", "output_proj"]
    )
    lora_rank: Optional[int] = None
    lora_alpha: Optional[int] = None
    lora_dropout: Optional[float] = None
    quantize_base: Optional[bool] = None
    use_dora: Optional[bool] = None

# Configuration for the TorchTune LLM Trainer.
@dataclass
class TorchTuneConfig:
    """TorchTune LLM Trainer configuration. Configure the parameters in
    the TorchTune LLM Trainer that already includes the fine-tuning logic.

    Args:
        dtype (`Optional[DataType]`):
            The underlying data type used to represent the model and optimizer parameters.
            Currently, we only support `bf16` and `fp32`.
        batch_size (`Optional[int]`):
            The number of samples processed before updating model weights.
        epochs (`Optional[int]`):
            The number of complete passes through the training dataset.
        loss (`Optional[Loss]`): The loss algorithm we use to fine-tune the LLM,
            e.g. `torchtune.modules.loss.CEWithChunkedOutputLoss`.
        num_nodes (`Optional[int]`): The number of nodes to use for training.
        peft_config (`Optional[LoraConfig]`):
            Configuration for PEFT (Parameter-Efficient Fine-Tuning),
            including LoRA/QLoRA/DoRA, etc.
        dataset_preprocess_config (`Optional[TorchTuneInstructDataset]`):
            Configuration for the dataset preprocessing.
        resources_per_node (`Optional[dict]`): The computing resources to allocate per node.
    """

    dtype: Optional[DataType] = None
    batch_size: Optional[int] = None
    epochs: Optional[int] = None
    loss: Optional[Loss] = None
    num_nodes: Optional[int] = None
    peft_config: Optional[LoraConfig] = None
    dataset_preprocess_config: Optional[TorchTuneInstructDataset] = None
    resources_per_node: Optional[dict] = None

# Configuration for the Builtin Trainer.
@dataclass
class BuiltinTrainer:
    """
    Builtin Trainer configuration. Configure the builtin trainer that already includes
    the fine-tuning logic, requiring only parameter adjustments.

    Args:
        config (`TorchTuneConfig`): The configuration for the builtin trainer.
    """

    config: TorchTuneConfig


# Change this to a list (BUILTIN_CONFIGS) once we support more Builtin Trainer configs.
TORCH_TUNE = BuiltinTrainer.__annotations__["config"].__name__.lower().replace("config", "")


class TrainerType(Enum):
    CUSTOM_TRAINER = CustomTrainer.__name__
    BUILTIN_TRAINER = BuiltinTrainer.__name__
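
A sketch of how the TorchTune pieces compose; all hyperparameter values here are illustrative assumptions:

```python
from viettelcloud.aiplatform.trainer.types import types

trainer = types.BuiltinTrainer(
    config=types.TorchTuneConfig(
        dtype=types.DataType.BF16,
        batch_size=4,
        epochs=3,
        loss=types.Loss.CEWithChunkedOutputLoss,
        peft_config=types.LoraConfig(lora_rank=8, lora_alpha=16),
        dataset_preprocess_config=types.TorchTuneInstructDataset(
            source=types.DataFormat.JSON,
            split="train[:10%]",
        ),
        resources_per_node={"gpu": 4},
    )
)

# TORCH_TUNE is derived from the config annotation:
# "TorchTuneConfig" -> "torchtuneconfig" -> "torchtune".
assert types.TORCH_TUNE == "torchtune"
```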

# Representation for the Trainer of the runtime.
@dataclass
class RuntimeTrainer:
    trainer_type: TrainerType
    framework: str
    image: str
    num_nodes: int = 1  # The default value is set in the APIs.
    device: str = common_constants.UNKNOWN
    device_count: str = common_constants.UNKNOWN
    __command: tuple[str, ...] = field(init=False, repr=False)

    @property
    def command(self) -> tuple[str, ...]:
        return self.__command

    def set_command(self, command: tuple[str, ...]):
        self.__command = command
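
Note the name-mangled `__command` field: it is excluded from `__init__` and `repr`, so it must be populated through `set_command()` before `command` is read. A small sketch (the field values are placeholders):

```python
from viettelcloud.aiplatform.trainer.types import types

rt = types.RuntimeTrainer(
    trainer_type=types.TrainerType.CUSTOM_TRAINER,
    framework="torch",
    image="example.registry/trainer:latest",
)
# Reading rt.command before set_command() raises AttributeError,
# because the field is declared with init=False and has no default.
rt.set_command(("torchrun", "train.py"))
print(rt.command)  # ('torchrun', 'train.py')
```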

# Representation for the Training Runtime.
@dataclass
class Runtime:
    name: str
    trainer: RuntimeTrainer
    pretrained_model: Optional[str] = None


# Representation for the TrainJob steps.
@dataclass
class Step:
    name: str
    status: Optional[str]
    pod_name: str
    device: str = common_constants.UNKNOWN
    device_count: str = common_constants.UNKNOWN


# Representation for the TrainJob.
@dataclass
class TrainJob:
    name: str
    runtime: Runtime
    steps: list[Step]
    num_nodes: int
    creation_timestamp: datetime
    status: str = common_constants.UNKNOWN

# Representation for TrainJob events.
@dataclass
class Event:
    """Event object that represents a Kubernetes event related to a TrainJob.

    Args:
        involved_object_kind (`str`): The kind of object this event is about
            (e.g., 'TrainJob', 'Pod').
        involved_object_name (`str`): The name of the object this event is about.
        message (`str`): Human-readable description of the event.
        reason (`str`): Short, machine-understandable string describing why
            this event was generated.
        event_time (`datetime`): The time at which the event was first recorded.
    """

    involved_object_kind: str
    involved_object_name: str
    message: str
    reason: str
    event_time: datetime

@dataclass
class BaseInitializer(abc.ABC):
    """Base class for all initializers."""

    storage_uri: str


@dataclass
class HuggingFaceDatasetInitializer(BaseInitializer):
    """Configuration for downloading datasets from HuggingFace Hub.

    Args:
        storage_uri (`str`): The HuggingFace Hub dataset identifier in the format
            'hf://username/repo_name'.
        ignore_patterns (`Optional[list[str]]`): List of file patterns to ignore during download.
        access_token (`Optional[str]`): HuggingFace Hub access token for private datasets.
    """

    ignore_patterns: Optional[list[str]] = None
    access_token: Optional[str] = None

    def __post_init__(self):
        """Validate HuggingFaceDatasetInitializer parameters."""

        if not self.storage_uri.startswith("hf://"):
            raise ValueError(f"storage_uri must start with 'hf://', got {self.storage_uri}")

        if urlparse(self.storage_uri).path == "":
            raise ValueError(
                "storage_uri must include a path in the format "
                f"'hf://<user_name>/<dataset_name>', got {self.storage_uri}"
            )
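
The `__post_init__` hook makes invalid URIs fail at construction time. A sketch of both outcomes (the repository name is a placeholder):

```python
from viettelcloud.aiplatform.trainer.types import types

# A well-formed dataset reference passes both checks.
ds = types.HuggingFaceDatasetInitializer(storage_uri="hf://username/repo_name")

# A URI with the right scheme but no path fails the urlparse() check.
try:
    types.HuggingFaceDatasetInitializer(storage_uri="hf://")
except ValueError as err:
    print(err)
```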

@dataclass
class S3DatasetInitializer(BaseInitializer):
    """Configuration for downloading datasets from S3-compatible storage.

    Args:
        storage_uri (`str`): The S3 URI for the dataset in the format
            's3://bucket-name/path/to/dataset'.
        ignore_patterns (`Optional[list[str]]`): List of file patterns to ignore during download.
        endpoint (`Optional[str]`): Custom S3 endpoint URL.
        access_key_id (`Optional[str]`): Access key for authentication.
        secret_access_key (`Optional[str]`): Secret key for authentication.
        region (`Optional[str]`): Region used when instantiating the client.
        role_arn (`Optional[str]`): The ARN of the role you want to assume.
    """

    ignore_patterns: Optional[list[str]] = None
    endpoint: Optional[str] = None
    access_key_id: Optional[str] = None
    secret_access_key: Optional[str] = None
    region: Optional[str] = None
    role_arn: Optional[str] = None

    def __post_init__(self):
        """Validate S3DatasetInitializer parameters."""

        if not self.storage_uri.startswith("s3://"):
            raise ValueError(f"storage_uri must start with 's3://', got {self.storage_uri}")

@dataclass
class DataCacheInitializer(BaseInitializer):
    """Configuration for the distributed data caching system for training workloads.

    Args:
        storage_uri (`str`): The URI for the cached data in the format
            'cache://<SCHEMA_NAME>/<TABLE_NAME>'. This specifies the location
            where the data cache will be stored and accessed.
        metadata_loc (`str`): The metadata file path of an Iceberg table.
        num_data_nodes (`int`): The number of data nodes in the distributed cache
            system. Must be greater than 1.
        head_cpu (`Optional[str]`): The CPU resources to allocate for the cache head node.
        head_mem (`Optional[str]`): The memory resources to allocate for the cache head node.
        worker_cpu (`Optional[str]`): The CPU resources to allocate for each cache worker node.
        worker_mem (`Optional[str]`): The memory resources to allocate for each cache worker node.
        iam_role (`Optional[str]`): The IAM role to use for accessing the metadata_loc file.
    """

    metadata_loc: str
    num_data_nodes: int
    head_cpu: Optional[str] = None
    head_mem: Optional[str] = None
    worker_cpu: Optional[str] = None
    worker_mem: Optional[str] = None
    iam_role: Optional[str] = None

    def __post_init__(self):
        """Validate DataCacheInitializer parameters."""

        if self.num_data_nodes <= 1:
            raise ValueError(f"num_data_nodes must be greater than 1, got {self.num_data_nodes}")

        # Validate the storage_uri format.
        if not self.storage_uri.startswith("cache://"):
            raise ValueError(f"storage_uri must start with 'cache://', got {self.storage_uri}")

        uri_path = self.storage_uri[len("cache://") :]
        parts = uri_path.split("/")

        if len(parts) != 2:
            raise ValueError(
                f"storage_uri must be in format "
                f"'cache://<SCHEMA_NAME>/<TABLE_NAME>', got {self.storage_uri}"
            )
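
A sketch of the validation behavior; the schema/table names and metadata path are placeholders:

```python
from viettelcloud.aiplatform.trainer.types import types

cache = types.DataCacheInitializer(
    storage_uri="cache://sales/transactions",  # exactly <SCHEMA>/<TABLE>
    metadata_loc="s3://my-bucket/metadata.json",
    num_data_nodes=3,  # must be > 1
)

# num_data_nodes=1 violates the "greater than 1" rule and raises.
try:
    types.DataCacheInitializer(
        storage_uri="cache://sales/transactions",
        metadata_loc="s3://my-bucket/metadata.json",
        num_data_nodes=1,
    )
except ValueError as err:
    print(err)
```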

@dataclass
class HuggingFaceModelInitializer(BaseInitializer):
    """Configuration for downloading models from HuggingFace Hub.

    Args:
        storage_uri (`str`): The HuggingFace Hub model identifier in the format
            'hf://username/repo_name'.
        ignore_patterns (`Optional[list[str]]`): List of file patterns to ignore during download.
        access_token (`Optional[str]`): HuggingFace Hub access token.
    """

    ignore_patterns: Optional[list[str]] = field(
        default_factory=lambda: constants.INITIALIZER_DEFAULT_IGNORE_PATTERNS
    )
    access_token: Optional[str] = None

    def __post_init__(self):
        """Validate HuggingFaceModelInitializer parameters."""

        if not self.storage_uri.startswith("hf://"):
            raise ValueError(f"storage_uri must start with 'hf://', got {self.storage_uri}")

@dataclass
class S3ModelInitializer(BaseInitializer):
    """Configuration for downloading models from S3-compatible storage.

    Args:
        storage_uri (`str`): The S3 URI for the model in the format
            's3://bucket-name/path/to/model'.
        ignore_patterns (`Optional[list[str]]`): List of file patterns to ignore during download.
            Defaults to `['*.msgpack', '*.h5', '*.bin', '.pt', '.pth']`.
        endpoint (`Optional[str]`): Custom S3 endpoint URL.
        access_key_id (`Optional[str]`): Access key for authentication.
        secret_access_key (`Optional[str]`): Secret key for authentication.
        region (`Optional[str]`): Region used when instantiating the client.
        role_arn (`Optional[str]`): The ARN of the role you want to assume.
    """

    ignore_patterns: Optional[list[str]] = field(
        default_factory=lambda: constants.INITIALIZER_DEFAULT_IGNORE_PATTERNS
    )
    endpoint: Optional[str] = None
    access_key_id: Optional[str] = None
    secret_access_key: Optional[str] = None
    region: Optional[str] = None
    role_arn: Optional[str] = None

    def __post_init__(self):
        """Validate S3ModelInitializer parameters."""

        if not self.storage_uri.startswith("s3://"):
            raise ValueError(f"storage_uri must start with 's3://', got {self.storage_uri}")
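
For an S3-compatible store behind a custom endpoint (e.g. MinIO), the optional fields can be supplied explicitly; the endpoint and credentials below are placeholders:

```python
from viettelcloud.aiplatform.trainer.types import types

model = types.S3ModelInitializer(
    storage_uri="s3://bucket-name/path/to/model",
    endpoint="https://minio.example.com",  # custom S3 endpoint
    access_key_id="EXAMPLE_KEY",
    secret_access_key="EXAMPLE_SECRET",
    region="us-east-1",
)
```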

@dataclass
class Initializer:
    """Initializer defines configurations for dataset and pre-trained model initialization.

    Args:
        dataset (`Optional[Union[HuggingFaceDatasetInitializer, S3DatasetInitializer, DataCacheInitializer]]`):
            The configuration for one of the supported dataset initializers.
        model (`Optional[Union[HuggingFaceModelInitializer, S3ModelInitializer]]`):
            The configuration for one of the supported model initializers.
    """  # noqa: E501

    dataset: Optional[
        Union[HuggingFaceDatasetInitializer, S3DatasetInitializer, DataCacheInitializer]
    ] = None
    model: Optional[Union[HuggingFaceModelInitializer, S3ModelInitializer]] = None
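
Dataset and model initializers are combined through this wrapper; both URIs below are placeholders:

```python
from viettelcloud.aiplatform.trainer.types import types

initializer = types.Initializer(
    dataset=types.S3DatasetInitializer(storage_uri="s3://bucket-name/path/to/dataset"),
    model=types.HuggingFaceModelInitializer(storage_uri="hf://username/repo_name"),
)
```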

# TODO (andreyvelich): Add train() and optimize() methods to this class.
@dataclass
class TrainJobTemplate:
    """TrainJob template configuration.

    Args:
        trainer (`CustomTrainer`): Configuration for a CustomTrainer.
        runtime (`Optional[Union[str, Runtime]]`): Optional reference to one of the existing
            runtimes. It accepts either a runtime name or a Runtime object from the
            `get_runtime()` API. Defaults to the torch-distributed runtime if not provided.
        initializer (`Optional[Initializer]`): Optional configuration for the dataset and model
            initializers.
    """

    trainer: CustomTrainer
    runtime: Optional[Union[str, Runtime]] = None
    initializer: Optional[Initializer] = None

    def keys(self):
        return ["trainer", "runtime", "initializer"]

    def __getitem__(self, key):
        return getattr(self, key)
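
The `keys()`/`__getitem__` pair makes a template usable as a mapping, so it can be splatted with `**` into APIs that accept `trainer`, `runtime`, and `initializer` keyword arguments (presumably how it is consumed elsewhere in the package). A sketch with a placeholder training function:

```python
from viettelcloud.aiplatform.trainer.types import types

def train_fn():
    print("hello from the trainer")

template = types.TrainJobTemplate(trainer=types.CustomTrainer(func=train_fn))

kwargs = dict(**template)
print(kwargs["runtime"])  # None unless a Runtime is supplied
```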

viettelcloud/aiplatform/trainer/types/types_test.py

@@ -0,0 +1,115 @@
# Copyright 2025 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from viettelcloud.aiplatform.trainer.test.common import FAILED, SUCCESS, TestCase
from viettelcloud.aiplatform.trainer.types import types


@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="valid datacacheinitializer creation",
            expected_status=SUCCESS,
            config={
                "storage_uri": "cache://test_schema/test_table",
                "num_data_nodes": 3,
                "metadata_loc": "gs://my-bucket/metadata",
            },
            expected_output=None,
        ),
        TestCase(
            name="invalid num_data_nodes raises ValueError",
            expected_status=FAILED,
            config={
                "storage_uri": "cache://test_schema/test_table",
                "num_data_nodes": 1,
                "metadata_loc": "gs://my-bucket/metadata",
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="zero num_data_nodes raises ValueError",
            expected_status=FAILED,
            config={
                "storage_uri": "cache://test_schema/test_table",
                "num_data_nodes": 0,
                "metadata_loc": "gs://my-bucket/metadata",
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="negative num_data_nodes raises ValueError",
            expected_status=FAILED,
            config={
                "storage_uri": "cache://test_schema/test_table",
                "num_data_nodes": -1,
                "metadata_loc": "gs://my-bucket/metadata",
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="invalid storage_uri without cache:// prefix raises ValueError",
            expected_status=FAILED,
            config={
                "storage_uri": "invalid://test_schema/test_table",
                "num_data_nodes": 3,
                "metadata_loc": "gs://my-bucket/metadata",
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="invalid storage_uri format raises ValueError",
            expected_status=FAILED,
            config={
                "storage_uri": "cache://test_schema",
                "num_data_nodes": 3,
                "metadata_loc": "gs://my-bucket/metadata",
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="invalid storage_uri with too many parts raises ValueError",
            expected_status=FAILED,
            config={
                "storage_uri": "cache://test_schema/test_table/extra",
                "num_data_nodes": 3,
                "metadata_loc": "gs://my-bucket/metadata",
            },
            expected_error=ValueError,
        ),
    ],
)
def test_data_cache_initializer(test_case: TestCase):
    """Test DataCacheInitializer creation and validation."""
    print("Executing test:", test_case.name)

    try:
        initializer = types.DataCacheInitializer(
            storage_uri=test_case.config["storage_uri"],
            num_data_nodes=test_case.config["num_data_nodes"],
            metadata_loc=test_case.config["metadata_loc"],
        )

        assert test_case.expected_status == SUCCESS
        # Only check the fields that were passed in config, not auto-generated ones.
        for key in test_case.config:
            assert getattr(initializer, key) == test_case.config[key]

    except Exception as e:
        assert test_case.expected_status == FAILED
        assert type(e) is test_case.expected_error

    print("test execution complete")