viettelcloud_aiplatform-0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. viettelcloud/__init__.py +1 -0
  2. viettelcloud/aiplatform/__init__.py +15 -0
  3. viettelcloud/aiplatform/common/__init__.py +0 -0
  4. viettelcloud/aiplatform/common/constants.py +22 -0
  5. viettelcloud/aiplatform/common/types.py +28 -0
  6. viettelcloud/aiplatform/common/utils.py +40 -0
  7. viettelcloud/aiplatform/hub/OWNERS +14 -0
  8. viettelcloud/aiplatform/hub/__init__.py +25 -0
  9. viettelcloud/aiplatform/hub/api/__init__.py +13 -0
  10. viettelcloud/aiplatform/hub/api/_proxy_client.py +355 -0
  11. viettelcloud/aiplatform/hub/api/model_registry_client.py +561 -0
  12. viettelcloud/aiplatform/hub/api/model_registry_client_test.py +462 -0
  13. viettelcloud/aiplatform/optimizer/__init__.py +45 -0
  14. viettelcloud/aiplatform/optimizer/api/__init__.py +0 -0
  15. viettelcloud/aiplatform/optimizer/api/optimizer_client.py +248 -0
  16. viettelcloud/aiplatform/optimizer/backends/__init__.py +13 -0
  17. viettelcloud/aiplatform/optimizer/backends/base.py +77 -0
  18. viettelcloud/aiplatform/optimizer/backends/kubernetes/__init__.py +13 -0
  19. viettelcloud/aiplatform/optimizer/backends/kubernetes/backend.py +563 -0
  20. viettelcloud/aiplatform/optimizer/backends/kubernetes/utils.py +112 -0
  21. viettelcloud/aiplatform/optimizer/constants/__init__.py +13 -0
  22. viettelcloud/aiplatform/optimizer/constants/constants.py +59 -0
  23. viettelcloud/aiplatform/optimizer/types/__init__.py +13 -0
  24. viettelcloud/aiplatform/optimizer/types/algorithm_types.py +87 -0
  25. viettelcloud/aiplatform/optimizer/types/optimization_types.py +135 -0
  26. viettelcloud/aiplatform/optimizer/types/search_types.py +95 -0
  27. viettelcloud/aiplatform/py.typed +0 -0
  28. viettelcloud/aiplatform/trainer/__init__.py +82 -0
  29. viettelcloud/aiplatform/trainer/api/__init__.py +3 -0
  30. viettelcloud/aiplatform/trainer/api/trainer_client.py +277 -0
  31. viettelcloud/aiplatform/trainer/api/trainer_client_test.py +72 -0
  32. viettelcloud/aiplatform/trainer/backends/__init__.py +0 -0
  33. viettelcloud/aiplatform/trainer/backends/base.py +94 -0
  34. viettelcloud/aiplatform/trainer/backends/container/adapters/base.py +195 -0
  35. viettelcloud/aiplatform/trainer/backends/container/adapters/docker.py +231 -0
  36. viettelcloud/aiplatform/trainer/backends/container/adapters/podman.py +258 -0
  37. viettelcloud/aiplatform/trainer/backends/container/backend.py +668 -0
  38. viettelcloud/aiplatform/trainer/backends/container/backend_test.py +867 -0
  39. viettelcloud/aiplatform/trainer/backends/container/runtime_loader.py +631 -0
  40. viettelcloud/aiplatform/trainer/backends/container/runtime_loader_test.py +637 -0
  41. viettelcloud/aiplatform/trainer/backends/container/types.py +67 -0
  42. viettelcloud/aiplatform/trainer/backends/container/utils.py +213 -0
  43. viettelcloud/aiplatform/trainer/backends/kubernetes/__init__.py +0 -0
  44. viettelcloud/aiplatform/trainer/backends/kubernetes/backend.py +710 -0
  45. viettelcloud/aiplatform/trainer/backends/kubernetes/backend_test.py +1344 -0
  46. viettelcloud/aiplatform/trainer/backends/kubernetes/constants.py +15 -0
  47. viettelcloud/aiplatform/trainer/backends/kubernetes/utils.py +636 -0
  48. viettelcloud/aiplatform/trainer/backends/kubernetes/utils_test.py +582 -0
  49. viettelcloud/aiplatform/trainer/backends/localprocess/__init__.py +0 -0
  50. viettelcloud/aiplatform/trainer/backends/localprocess/backend.py +306 -0
  51. viettelcloud/aiplatform/trainer/backends/localprocess/backend_test.py +501 -0
  52. viettelcloud/aiplatform/trainer/backends/localprocess/constants.py +90 -0
  53. viettelcloud/aiplatform/trainer/backends/localprocess/job.py +184 -0
  54. viettelcloud/aiplatform/trainer/backends/localprocess/types.py +52 -0
  55. viettelcloud/aiplatform/trainer/backends/localprocess/utils.py +302 -0
  56. viettelcloud/aiplatform/trainer/constants/__init__.py +0 -0
  57. viettelcloud/aiplatform/trainer/constants/constants.py +179 -0
  58. viettelcloud/aiplatform/trainer/options/__init__.py +52 -0
  59. viettelcloud/aiplatform/trainer/options/common.py +55 -0
  60. viettelcloud/aiplatform/trainer/options/kubernetes.py +502 -0
  61. viettelcloud/aiplatform/trainer/options/kubernetes_test.py +259 -0
  62. viettelcloud/aiplatform/trainer/options/localprocess.py +20 -0
  63. viettelcloud/aiplatform/trainer/test/common.py +22 -0
  64. viettelcloud/aiplatform/trainer/types/__init__.py +0 -0
  65. viettelcloud/aiplatform/trainer/types/types.py +517 -0
  66. viettelcloud/aiplatform/trainer/types/types_test.py +115 -0
  67. viettelcloud_aiplatform-0.3.0.dist-info/METADATA +226 -0
  68. viettelcloud_aiplatform-0.3.0.dist-info/RECORD +71 -0
  69. viettelcloud_aiplatform-0.3.0.dist-info/WHEEL +4 -0
  70. viettelcloud_aiplatform-0.3.0.dist-info/licenses/LICENSE +201 -0
  71. viettelcloud_aiplatform-0.3.0.dist-info/licenses/NOTICE +36 -0
viettelcloud/aiplatform/trainer/types/types.py
@@ -0,0 +1,517 @@
+ # Copyright 2024 The Kubeflow Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import abc
+ from dataclasses import dataclass, field
+ from datetime import datetime
+ from enum import Enum
+ from typing import Callable, Optional, Union
+ from urllib.parse import urlparse
+
+ import viettelcloud.aiplatform.common.constants as common_constants
+ from viettelcloud.aiplatform.trainer.constants import constants
+
+
+ # Configuration for the Custom Trainer.
+ @dataclass
+ class CustomTrainer:
+     """Custom Trainer configuration. Configure the self-contained function
+     that encapsulates the entire model training process.
+
+     Args:
+         func (`Callable`): The function that encapsulates the entire model training process.
+         func_args (`Optional[dict]`): The arguments to pass to the function.
+         image (`Optional[str]`): The optional container image to use in TrainJob.
+         packages_to_install (`Optional[list[str]]`):
+             A list of Python packages to install before running the function.
+         pip_index_urls (`list[str]`): The PyPI URLs from which to install
+             Python packages. The first URL will be the index-url, and the remaining ones
+             are extra-index-urls.
+         num_nodes (`Optional[int]`): The number of nodes to use for training.
+         resources_per_node (`Optional[dict]`): The computing resources to allocate per node.
+             ```python
+             resources_per_node = {"gpu": 4, "cpu": 5, "memory": "10G"}
+             ```
+             If your compute supports fractional GPUs (e.g. multi-instance GPU),
+             you can set the resources as follows (requesting one 5 GB GPU slice):
+             ```python
+             resources_per_node = {"mig-1g.5gb": 1}
+             ```
+         env (`Optional[dict[str, str]]`): The environment variables to set in the training nodes.
+     """
+
+     func: Callable
+     func_args: Optional[dict] = None
+     image: Optional[str] = None
+     packages_to_install: Optional[list[str]] = None
+     pip_index_urls: list[str] = field(
+         default_factory=lambda: list(constants.DEFAULT_PIP_INDEX_URLS)
+     )
+     num_nodes: Optional[int] = None
+     resources_per_node: Optional[dict] = None
+     env: Optional[dict[str, str]] = None
+
+
+ # Configuration for the Custom Trainer Container.
+ @dataclass
+ class CustomTrainerContainer:
+     """Custom Trainer Container configuration. Configure the container image
+     that encapsulates the entire model training process.
+
+     Args:
+         image (`str`): The container image that encapsulates the entire model training process.
+         num_nodes (`Optional[int]`): The number of nodes to use for training.
+         resources_per_node (`Optional[dict]`): The computing resources to allocate per node.
+             ```python
+             resources_per_node = {"gpu": 4, "cpu": 5, "memory": "10G"}
+             ```
+             If your compute supports fractional GPUs (e.g. multi-instance GPU),
+             you can set the resources as follows (requesting one 5 GB GPU slice):
+             ```python
+             resources_per_node = {"mig-1g.5gb": 1}
+             ```
+         env (`Optional[dict[str, str]]`): The environment variables to set in the training nodes.
+     """
+
+     image: str
+     num_nodes: Optional[int] = None
+     resources_per_node: Optional[dict] = None
+     env: Optional[dict[str, str]] = None
+
+
+ # TODO(Electronic-Waste): Add more loss functions.
+ # Loss function for the TorchTune LLM Trainer.
+ class Loss(Enum):
+     """Loss function for the TorchTune LLM Trainer."""
+
+     CEWithChunkedOutputLoss = "torchtune.modules.loss.CEWithChunkedOutputLoss"
+
+
+ # Data type for the TorchTune LLM Trainer.
+ class DataType(Enum):
+     """Data type for the TorchTune LLM Trainer."""
+
+     BF16 = "bf16"
+     FP32 = "fp32"
+
+
+ # Data file type for the TorchTune LLM Trainer.
+ class DataFormat(Enum):
+     """Data file type for the TorchTune LLM Trainer."""
+
+     JSON = "json"
+     CSV = "csv"
+     PARQUET = "parquet"
+     ARROW = "arrow"
+     TEXT = "text"
+     XML = "xml"
+
+
+ # Configuration for the TorchTune Instruct dataset.
+ @dataclass
+ class TorchTuneInstructDataset:
+     """
+     Configuration for the custom dataset with user instruction prompts and model responses.
+     REF: https://pytorch.org/torchtune/main/generated/torchtune.datasets.instruct_dataset.html
+
+     Args:
+         source (`Optional[DataFormat]`): Data file type.
+         split (`Optional[str]`):
+             The split of the dataset to use. You can use this argument to load a subset of
+             a given split, e.g. split="train[:10%]". Default is `train`.
+         train_on_input (`Optional[bool]`):
+             Whether the model is trained on the user prompt or not. Default is False.
+         new_system_prompt (`Optional[str]`):
+             The new system prompt to use. If specified, prepend a system message.
+             This can serve as instructions to guide the model response. Default is None.
+         column_map (`Optional[Dict[str, str]]`):
+             A mapping to change the expected "input" and "output" column names to the actual
+             column names in the dataset. Keys should be "input" and "output" and values should
+             be the actual column names. Default is None, keeping the default "input" and
+             "output" column names.
+     """
+
+     source: Optional[DataFormat] = None
+     split: Optional[str] = None
+     train_on_input: Optional[bool] = None
+     new_system_prompt: Optional[str] = None
+     column_map: Optional[dict[str, str]] = None
+
+
+ @dataclass
+ class LoraConfig:
+     """Configuration for LoRA/QLoRA/DoRA.
+     REF: https://meta-pytorch.org/torchtune/main/tutorials/memory_optimizations.html
+
+     Args:
+         apply_lora_to_mlp (`Optional[bool]`):
+             Whether to apply LoRA to the MLP in each transformer layer.
+         apply_lora_to_output (`Optional[bool]`):
+             Whether to apply LoRA to the model's final output projection.
+         lora_attn_modules (`list[str]`):
+             A list of strings specifying which layers of the model to apply LoRA to,
+             default is ["q_proj", "v_proj", "output_proj"]:
+             1. "q_proj" applies LoRA to the query projection layer.
+             2. "k_proj" applies LoRA to the key projection layer.
+             3. "v_proj" applies LoRA to the value projection layer.
+             4. "output_proj" applies LoRA to the attention output projection layer.
+         lora_rank (`Optional[int]`): The rank of the low rank decomposition.
+         lora_alpha (`Optional[int]`):
+             The scaling factor that adjusts the magnitude of the low-rank matrices' output.
+         lora_dropout (`Optional[float]`):
+             The probability of applying Dropout to the low rank updates.
+         quantize_base (`Optional[bool]`): Whether to enable model quantization.
+         use_dora (`Optional[bool]`): Whether to enable DoRA.
+     """
+
+     apply_lora_to_mlp: Optional[bool] = None
+     apply_lora_to_output: Optional[bool] = None
+     lora_attn_modules: list[str] = field(
+         default_factory=lambda: ["q_proj", "v_proj", "output_proj"]
+     )
+     lora_rank: Optional[int] = None
+     lora_alpha: Optional[int] = None
+     lora_dropout: Optional[float] = None
+     quantize_base: Optional[bool] = None
+     use_dora: Optional[bool] = None
+
+
+ # Configuration for the TorchTune LLM Trainer.
+ @dataclass
+ class TorchTuneConfig:
+     """TorchTune LLM Trainer configuration. Configure the parameters in
+     the TorchTune LLM Trainer that already includes the fine-tuning logic.
+
+     Args:
+         dtype (`Optional[DataType]`):
+             The underlying data type used to represent the model and optimizer parameters.
+             Currently, we only support `bf16` and `fp32`.
+         batch_size (`Optional[int]`):
+             The number of samples processed before updating model weights.
+         epochs (`Optional[int]`):
+             The number of complete passes through the training dataset.
+         loss (`Optional[Loss]`): The loss algorithm we use to fine-tune the LLM,
+             e.g. `torchtune.modules.loss.CEWithChunkedOutputLoss`.
+         num_nodes (`Optional[int]`): The number of nodes to use for training.
+         peft_config (`Optional[LoraConfig]`):
+             Configuration for PEFT (Parameter-Efficient Fine-Tuning),
+             including LoRA/QLoRA/DoRA, etc.
+         dataset_preprocess_config (`Optional[TorchTuneInstructDataset]`):
+             Configuration for the dataset preprocessing.
+         resources_per_node (`Optional[dict]`): The computing resources to allocate per node.
+     """
+
+     dtype: Optional[DataType] = None
+     batch_size: Optional[int] = None
+     epochs: Optional[int] = None
+     loss: Optional[Loss] = None
+     num_nodes: Optional[int] = None
+     peft_config: Optional[LoraConfig] = None
+     dataset_preprocess_config: Optional[TorchTuneInstructDataset] = None
+     resources_per_node: Optional[dict] = None
+
+
+ # Configuration for the Builtin Trainer.
+ @dataclass
+ class BuiltinTrainer:
+     """
+     Builtin Trainer configuration. Configure the builtin trainer that already includes
+     the fine-tuning logic, requiring only parameter adjustments.
+
+     Args:
+         config (`TorchTuneConfig`): The configuration for the builtin trainer.
+     """
+
+     config: TorchTuneConfig
+
+
+ # Change it to list: BUILTIN_CONFIGS, once we support more Builtin Trainer configs.
+ TORCH_TUNE = BuiltinTrainer.__annotations__["config"].__name__.lower().replace("config", "")
+
+
+ class TrainerType(Enum):
+     CUSTOM_TRAINER = CustomTrainer.__name__
+     BUILTIN_TRAINER = BuiltinTrainer.__name__
+
+
+ # Representation for the Trainer of the runtime.
+ @dataclass
+ class RuntimeTrainer:
+     trainer_type: TrainerType
+     framework: str
+     image: str
+     num_nodes: int = 1  # The default value is set in the APIs.
+     device: str = common_constants.UNKNOWN
+     device_count: str = common_constants.UNKNOWN
+     __command: tuple[str, ...] = field(init=False, repr=False)
+
+     @property
+     def command(self) -> tuple[str, ...]:
+         return self.__command
+
+     def set_command(self, command: tuple[str, ...]):
+         self.__command = command
+
+
+ # Representation for the Training Runtime.
+ @dataclass
+ class Runtime:
+     name: str
+     trainer: RuntimeTrainer
+     pretrained_model: Optional[str] = None
+
+
+ # Representation for the TrainJob steps.
+ @dataclass
+ class Step:
+     name: str
+     status: Optional[str]
+     pod_name: str
+     device: str = common_constants.UNKNOWN
+     device_count: str = common_constants.UNKNOWN
+
+
+ # Representation for the TrainJob.
+ @dataclass
+ class TrainJob:
+     name: str
+     runtime: Runtime
+     steps: list[Step]
+     num_nodes: int
+     creation_timestamp: datetime
+     status: str = common_constants.UNKNOWN
+
+
+ # Representation for TrainJob events.
+ @dataclass
+ class Event:
+     """Event object that represents a Kubernetes event related to a TrainJob.
+
+     Args:
+         involved_object_kind (`str`): The kind of object this event is about
+             (e.g., 'TrainJob', 'Pod').
+         involved_object_name (`str`): The name of the object this event is about.
+         message (`str`): Human-readable description of the event.
+         reason (`str`): Short, machine understandable string describing why
+             this event was generated.
+         event_time (`datetime`): The time at which the event was first recorded.
+     """
+
+     involved_object_kind: str
+     involved_object_name: str
+     message: str
+     reason: str
+     event_time: datetime
+
+
+ @dataclass
+ class BaseInitializer(abc.ABC):
+     """Base class for all initializers"""
+
+     storage_uri: str
+
+
+ @dataclass
+ class HuggingFaceDatasetInitializer(BaseInitializer):
+     """Configuration for downloading datasets from HuggingFace Hub.
+
+     Args:
+         storage_uri (`str`): The HuggingFace Hub dataset identifier in the format 'hf://username/repo_name'.
+         ignore_patterns (`Optional[list[str]]`): List of file patterns to ignore during download.
+         access_token (`Optional[str]`): HuggingFace Hub access token for private datasets.
+     """
+
+     ignore_patterns: Optional[list[str]] = None
+     access_token: Optional[str] = None
+
+     def __post_init__(self):
+         """Validate HuggingFaceDatasetInitializer parameters."""
+
+         if not self.storage_uri.startswith("hf://"):
+             raise ValueError(f"storage_uri must start with 'hf://', got {self.storage_uri}")
+
+         if urlparse(self.storage_uri).path == "":
+             raise ValueError(
+                 "storage_uri: must have absolute path with 'hf://<user_name>/<dataset_name>', got "
+                 f"{self.storage_uri}"
+             )
+
+
+ @dataclass
+ class S3DatasetInitializer(BaseInitializer):
+     """Configuration for downloading datasets from S3-compatible storage.
+
+     Args:
+         storage_uri (`str`): The S3 URI for the dataset in the format 's3://bucket-name/path/to/dataset'.
+         ignore_patterns (`Optional[list[str]]`): List of file patterns to ignore during download.
+         endpoint (`Optional[str]`): Custom S3 endpoint URL.
+         access_key_id (`Optional[str]`): Access key for authentication.
+         secret_access_key (`Optional[str]`): Secret key for authentication.
+         region (`Optional[str]`): Region used in instantiating the client.
+         role_arn (`Optional[str]`): The ARN of the role you want to assume.
+     """
+
+     ignore_patterns: Optional[list[str]] = None
+     endpoint: Optional[str] = None
+     access_key_id: Optional[str] = None
+     secret_access_key: Optional[str] = None
+     region: Optional[str] = None
+     role_arn: Optional[str] = None
+
+     def __post_init__(self):
+         """Validate S3DatasetInitializer parameters."""
+
+         if not self.storage_uri.startswith("s3://"):
+             raise ValueError(f"storage_uri must start with 's3://', got {self.storage_uri}")
+
+
+ @dataclass
+ class DataCacheInitializer(BaseInitializer):
+     """Configuration for the distributed data caching system for training workloads.
+
+     Args:
+         storage_uri (`str`): The URI for the cached data in the format
+             'cache://<SCHEMA_NAME>/<TABLE_NAME>'. This specifies the location
+             where the data cache will be stored and accessed.
+         metadata_loc (`str`): The metadata file path of an Iceberg table.
+         num_data_nodes (`int`): The number of data nodes in the distributed cache
+             system. Must be greater than 1.
+         head_cpu (`Optional[str]`): The CPU resources to allocate for the cache head node.
+         head_mem (`Optional[str]`): The memory resources to allocate for the cache head node.
+         worker_cpu (`Optional[str]`): The CPU resources to allocate for each cache worker node.
+         worker_mem (`Optional[str]`): The memory resources to allocate for each cache worker node.
+         iam_role (`Optional[str]`): The IAM role to use for accessing the metadata_loc file.
+     """
+
+     metadata_loc: str
+     num_data_nodes: int
+     head_cpu: Optional[str] = None
+     head_mem: Optional[str] = None
+     worker_cpu: Optional[str] = None
+     worker_mem: Optional[str] = None
+     iam_role: Optional[str] = None
+
+     def __post_init__(self):
+         """Validate DataCacheInitializer parameters."""
+
+         if self.num_data_nodes <= 1:
+             raise ValueError(f"num_data_nodes must be greater than 1, got {self.num_data_nodes}")
+
+         # Validate storage_uri format
+         if not self.storage_uri.startswith("cache://"):
+             raise ValueError(f"storage_uri must start with 'cache://', got {self.storage_uri}")
+
+         uri_path = self.storage_uri[len("cache://") :]
+         parts = uri_path.split("/")
+
+         if len(parts) != 2:
+             raise ValueError(
+                 f"storage_uri must be in format "
+                 f"'cache://<SCHEMA_NAME>/<TABLE_NAME>', got {self.storage_uri}"
+             )
+
+
+ @dataclass
+ class HuggingFaceModelInitializer(BaseInitializer):
+     """Configuration for downloading models from HuggingFace Hub.
+
+     Args:
+         storage_uri (`str`): The HuggingFace Hub model identifier in the format 'hf://username/repo_name'.
+         ignore_patterns (`Optional[list[str]]`): List of file patterns to ignore during download.
+         access_token (`Optional[str]`): HuggingFace Hub access token.
+     """
+
+     ignore_patterns: Optional[list[str]] = field(
+         default_factory=lambda: constants.INITIALIZER_DEFAULT_IGNORE_PATTERNS
+     )
+     access_token: Optional[str] = None
+
+     def __post_init__(self):
+         """Validate HuggingFaceModelInitializer parameters."""
+
+         if not self.storage_uri.startswith("hf://"):
+             raise ValueError(f"storage_uri must start with 'hf://', got {self.storage_uri}")
+
+
+ @dataclass
+ class S3ModelInitializer(BaseInitializer):
+     """Configuration for downloading models from S3-compatible storage.
+
+     Args:
+         storage_uri (`str`): The S3 URI for the model in the format 's3://bucket-name/path/to/model'.
+         ignore_patterns (`Optional[list[str]]`): List of file patterns to ignore during download.
+             Defaults to `['*.msgpack', '*.h5', '*.bin', '.pt', '.pth']`.
+         endpoint (`Optional[str]`): Custom S3 endpoint URL.
+         access_key_id (`Optional[str]`): Access key for authentication.
+         secret_access_key (`Optional[str]`): Secret key for authentication.
+         region (`Optional[str]`): Region used in instantiating the client.
+         role_arn (`Optional[str]`): The ARN of the role you want to assume.
+     """
+
+     ignore_patterns: Optional[list[str]] = field(
+         default_factory=lambda: constants.INITIALIZER_DEFAULT_IGNORE_PATTERNS
+     )
+     endpoint: Optional[str] = None
+     access_key_id: Optional[str] = None
+     secret_access_key: Optional[str] = None
+     region: Optional[str] = None
+     role_arn: Optional[str] = None
+
+     def __post_init__(self):
+         """Validate S3ModelInitializer parameters."""
+
+         if not self.storage_uri.startswith("s3://"):
+             raise ValueError(f"storage_uri must start with 's3://', got {self.storage_uri}")
+
+
+ @dataclass
+ class Initializer:
+     """Initializer defines configurations for dataset and pre-trained model initialization
+
+     Args:
+         dataset (`Optional[Union[HuggingFaceDatasetInitializer, S3DatasetInitializer, DataCacheInitializer]]`):
+             The configuration for one of the supported dataset initializers.
+         model (`Optional[Union[HuggingFaceModelInitializer, S3ModelInitializer]]`):
+             The configuration for one of the supported model initializers.
+     """  # noqa: E501
+
+     dataset: Optional[
+         Union[HuggingFaceDatasetInitializer, S3DatasetInitializer, DataCacheInitializer]
+     ] = None
+     model: Optional[Union[HuggingFaceModelInitializer, S3ModelInitializer]] = None
+
+
+ # TODO (andreyvelich): Add train() and optimize() methods to this class.
+ @dataclass
+ class TrainJobTemplate:
+     """TrainJob template configuration.
+
+     Args:
+         trainer (`CustomTrainer`): Configuration for a CustomTrainer.
+         runtime (`Optional[Union[str, Runtime]]`): An optional reference to one of the existing
+             runtimes. It accepts either the runtime name or a `Runtime` object from the `get_runtime()` API.
+             Defaults to the torch-distributed runtime if not provided.
+         initializer (`Optional[Initializer]`): Optional configuration for the dataset and model
+             initializers.
+     """
+
+     trainer: CustomTrainer
+     runtime: Optional[Union[str, Runtime]] = None
+     initializer: Optional[Initializer] = None
+
+     def keys(self):
+         return ["trainer", "runtime", "initializer"]
+
+     def __getitem__(self, key):
+         return getattr(self, key)
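
To make the relationships between these dataclasses concrete, here is a minimal usage sketch (an editor's illustration, not code shipped in the wheel). It assumes the import path used by the package's own tests, `viettelcloud.aiplatform.trainer.types`; the training function, resource values, and `hf://` URIs are placeholders, and submitting the template through the SDK's client APIs is outside the scope of this file.

```python
# Editor's sketch, not part of the package: how the trainer and initializer
# dataclasses above compose into a TrainJobTemplate.
from viettelcloud.aiplatform.trainer.types import types


def fine_tune():
    # Placeholder for a self-contained training function.
    print("training...")


trainer = types.CustomTrainer(
    func=fine_tune,
    num_nodes=2,
    resources_per_node={"gpu": 1, "cpu": 4, "memory": "16G"},
)

initializer = types.Initializer(
    dataset=types.HuggingFaceDatasetInitializer(storage_uri="hf://username/dataset_name"),
    model=types.HuggingFaceModelInitializer(storage_uri="hf://username/repo_name"),
)

template = types.TrainJobTemplate(trainer=trainer, initializer=initializer)

# keys()/__getitem__ let the template be read like a mapping,
# e.g. to forward its fields as keyword arguments.
print({key: template[key] for key in template.keys()})
```
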
viettelcloud/aiplatform/trainer/types/types_test.py
@@ -0,0 +1,115 @@
+ # Copyright 2025 The Kubeflow Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import pytest
+
+ from viettelcloud.aiplatform.trainer.test.common import FAILED, SUCCESS, TestCase
+ from viettelcloud.aiplatform.trainer.types import types
+
+
+ @pytest.mark.parametrize(
+     "test_case",
+     [
+         TestCase(
+             name="valid datacacheinitializer creation",
+             expected_status=SUCCESS,
+             config={
+                 "storage_uri": "cache://test_schema/test_table",
+                 "num_data_nodes": 3,
+                 "metadata_loc": "gs://my-bucket/metadata",
+             },
+             expected_output=None,
+         ),
+         TestCase(
+             name="invalid num_data_nodes raises ValueError",
+             expected_status=FAILED,
+             config={
+                 "storage_uri": "cache://test_schema/test_table",
+                 "num_data_nodes": 1,
+                 "metadata_loc": "gs://my-bucket/metadata",
+             },
+             expected_error=ValueError,
+         ),
+         TestCase(
+             name="zero num_data_nodes raises ValueError",
+             expected_status=FAILED,
+             config={
+                 "storage_uri": "cache://test_schema/test_table",
+                 "num_data_nodes": 0,
+                 "metadata_loc": "gs://my-bucket/metadata",
+             },
+             expected_error=ValueError,
+         ),
+         TestCase(
+             name="negative num_data_nodes raises ValueError",
+             expected_status=FAILED,
+             config={
+                 "storage_uri": "cache://test_schema/test_table",
+                 "num_data_nodes": -1,
+                 "metadata_loc": "gs://my-bucket/metadata",
+             },
+             expected_error=ValueError,
+         ),
+         TestCase(
+             name="invalid storage_uri without cache:// prefix raises ValueError",
+             expected_status=FAILED,
+             config={
+                 "storage_uri": "invalid://test_schema/test_table",
+                 "num_data_nodes": 3,
+                 "metadata_loc": "gs://my-bucket/metadata",
+             },
+             expected_error=ValueError,
+         ),
+         TestCase(
+             name="invalid storage_uri format raises ValueError",
+             expected_status=FAILED,
+             config={
+                 "storage_uri": "cache://test_schema",
+                 "num_data_nodes": 3,
+                 "metadata_loc": "gs://my-bucket/metadata",
+             },
+             expected_error=ValueError,
+         ),
+         TestCase(
+             name="invalid storage_uri with too many parts raises ValueError",
+             expected_status=FAILED,
+             config={
+                 "storage_uri": "cache://test_schema/test_table/extra",
+                 "num_data_nodes": 3,
+                 "metadata_loc": "gs://my-bucket/metadata",
+             },
+             expected_error=ValueError,
+         ),
+     ],
+ )
+ def test_data_cache_initializer(test_case: TestCase):
+     """Test DataCacheInitializer creation and validation."""
+     print("Executing test:", test_case.name)
+
+     try:
+         initializer = types.DataCacheInitializer(
+             storage_uri=test_case.config["storage_uri"],
+             num_data_nodes=test_case.config["num_data_nodes"],
+             metadata_loc=test_case.config["metadata_loc"],
+         )
+
+         assert test_case.expected_status == SUCCESS
+         # Only check the fields that were passed in config, not auto-generated ones
+         for key in test_case.config:
+             assert getattr(initializer, key) == test_case.config[key]
+
+     except Exception as e:
+         assert test_case.expected_status == FAILED
+         assert type(e) is test_case.expected_error
+     print("test execution complete")