together 1.5.34__py3-none-any.whl → 2.0.0a6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- together/__init__.py +101 -114
- together/_base_client.py +1995 -0
- together/_client.py +1033 -0
- together/_compat.py +219 -0
- together/_constants.py +14 -0
- together/_exceptions.py +108 -0
- together/_files.py +123 -0
- together/_models.py +857 -0
- together/_qs.py +150 -0
- together/_resource.py +43 -0
- together/_response.py +830 -0
- together/_streaming.py +370 -0
- together/_types.py +260 -0
- together/_utils/__init__.py +64 -0
- together/_utils/_compat.py +45 -0
- together/_utils/_datetime_parse.py +136 -0
- together/_utils/_logs.py +25 -0
- together/_utils/_proxy.py +65 -0
- together/_utils/_reflection.py +42 -0
- together/_utils/_resources_proxy.py +24 -0
- together/_utils/_streams.py +12 -0
- together/_utils/_sync.py +58 -0
- together/_utils/_transform.py +457 -0
- together/_utils/_typing.py +156 -0
- together/_utils/_utils.py +421 -0
- together/_version.py +4 -0
- together/lib/.keep +4 -0
- together/lib/__init__.py +23 -0
- together/{cli → lib/cli}/api/endpoints.py +65 -81
- together/{cli/api/evaluation.py → lib/cli/api/evals.py} +152 -43
- together/{cli → lib/cli}/api/files.py +20 -17
- together/{cli/api/finetune.py → lib/cli/api/fine_tuning.py} +116 -172
- together/{cli → lib/cli}/api/models.py +34 -27
- together/lib/cli/api/utils.py +50 -0
- together/{cli → lib/cli}/cli.py +16 -26
- together/{constants.py → lib/constants.py} +11 -24
- together/lib/resources/__init__.py +11 -0
- together/lib/resources/files.py +999 -0
- together/lib/resources/fine_tuning.py +280 -0
- together/lib/resources/models.py +35 -0
- together/lib/types/__init__.py +13 -0
- together/lib/types/error.py +9 -0
- together/lib/types/fine_tuning.py +397 -0
- together/{utils → lib/utils}/__init__.py +6 -14
- together/{utils → lib/utils}/_log.py +11 -16
- together/{utils → lib/utils}/files.py +90 -288
- together/lib/utils/serializer.py +10 -0
- together/{utils → lib/utils}/tools.py +19 -55
- together/resources/__init__.py +225 -39
- together/resources/audio/__init__.py +72 -48
- together/resources/audio/audio.py +198 -0
- together/resources/audio/speech.py +574 -128
- together/resources/audio/transcriptions.py +247 -261
- together/resources/audio/translations.py +221 -241
- together/resources/audio/voices.py +111 -41
- together/resources/batches.py +417 -0
- together/resources/chat/__init__.py +30 -21
- together/resources/chat/chat.py +102 -0
- together/resources/chat/completions.py +1063 -263
- together/resources/code_interpreter/__init__.py +33 -0
- together/resources/code_interpreter/code_interpreter.py +258 -0
- together/resources/code_interpreter/sessions.py +135 -0
- together/resources/completions.py +884 -225
- together/resources/embeddings.py +172 -68
- together/resources/endpoints.py +589 -477
- together/resources/evals.py +452 -0
- together/resources/files.py +397 -129
- together/resources/fine_tuning.py +1033 -0
- together/resources/hardware.py +181 -0
- together/resources/images.py +258 -104
- together/resources/jobs.py +214 -0
- together/resources/models.py +223 -193
- together/resources/rerank.py +190 -92
- together/resources/videos.py +286 -214
- together/types/__init__.py +66 -167
- together/types/audio/__init__.py +10 -0
- together/types/audio/speech_create_params.py +75 -0
- together/types/audio/transcription_create_params.py +54 -0
- together/types/audio/transcription_create_response.py +111 -0
- together/types/audio/translation_create_params.py +40 -0
- together/types/audio/translation_create_response.py +70 -0
- together/types/audio/voice_list_response.py +23 -0
- together/types/audio_speech_stream_chunk.py +16 -0
- together/types/autoscaling.py +13 -0
- together/types/autoscaling_param.py +15 -0
- together/types/batch_create_params.py +24 -0
- together/types/batch_create_response.py +14 -0
- together/types/batch_job.py +45 -0
- together/types/batch_list_response.py +10 -0
- together/types/chat/__init__.py +18 -0
- together/types/chat/chat_completion.py +60 -0
- together/types/chat/chat_completion_chunk.py +61 -0
- together/types/chat/chat_completion_structured_message_image_url_param.py +18 -0
- together/types/chat/chat_completion_structured_message_text_param.py +13 -0
- together/types/chat/chat_completion_structured_message_video_url_param.py +18 -0
- together/types/chat/chat_completion_usage.py +13 -0
- together/types/chat/chat_completion_warning.py +9 -0
- together/types/chat/completion_create_params.py +329 -0
- together/types/code_interpreter/__init__.py +5 -0
- together/types/code_interpreter/session_list_response.py +31 -0
- together/types/code_interpreter_execute_params.py +45 -0
- together/types/completion.py +42 -0
- together/types/completion_chunk.py +66 -0
- together/types/completion_create_params.py +138 -0
- together/types/dedicated_endpoint.py +44 -0
- together/types/embedding.py +24 -0
- together/types/embedding_create_params.py +31 -0
- together/types/endpoint_create_params.py +43 -0
- together/types/endpoint_list_avzones_response.py +11 -0
- together/types/endpoint_list_params.py +18 -0
- together/types/endpoint_list_response.py +41 -0
- together/types/endpoint_update_params.py +27 -0
- together/types/eval_create_params.py +263 -0
- together/types/eval_create_response.py +16 -0
- together/types/eval_list_params.py +21 -0
- together/types/eval_list_response.py +10 -0
- together/types/eval_status_response.py +100 -0
- together/types/evaluation_job.py +139 -0
- together/types/execute_response.py +108 -0
- together/types/file_delete_response.py +13 -0
- together/types/file_list.py +12 -0
- together/types/file_purpose.py +9 -0
- together/types/file_response.py +31 -0
- together/types/file_type.py +7 -0
- together/types/fine_tuning_cancel_response.py +194 -0
- together/types/fine_tuning_content_params.py +24 -0
- together/types/fine_tuning_delete_params.py +11 -0
- together/types/fine_tuning_delete_response.py +12 -0
- together/types/fine_tuning_list_checkpoints_response.py +21 -0
- together/types/fine_tuning_list_events_response.py +12 -0
- together/types/fine_tuning_list_response.py +199 -0
- together/types/finetune_event.py +41 -0
- together/types/finetune_event_type.py +33 -0
- together/types/finetune_response.py +177 -0
- together/types/hardware_list_params.py +16 -0
- together/types/hardware_list_response.py +58 -0
- together/types/image_data_b64.py +15 -0
- together/types/image_data_url.py +15 -0
- together/types/image_file.py +23 -0
- together/types/image_generate_params.py +85 -0
- together/types/job_list_response.py +47 -0
- together/types/job_retrieve_response.py +43 -0
- together/types/log_probs.py +18 -0
- together/types/model_list_response.py +10 -0
- together/types/model_object.py +42 -0
- together/types/model_upload_params.py +36 -0
- together/types/model_upload_response.py +23 -0
- together/types/rerank_create_params.py +36 -0
- together/types/rerank_create_response.py +36 -0
- together/types/tool_choice.py +23 -0
- together/types/tool_choice_param.py +23 -0
- together/types/tools_param.py +23 -0
- together/types/training_method_dpo.py +22 -0
- together/types/training_method_sft.py +18 -0
- together/types/video_create_params.py +86 -0
- together/types/video_create_response.py +10 -0
- together/types/video_job.py +57 -0
- together-2.0.0a6.dist-info/METADATA +729 -0
- together-2.0.0a6.dist-info/RECORD +165 -0
- {together-1.5.34.dist-info → together-2.0.0a6.dist-info}/WHEEL +1 -1
- together-2.0.0a6.dist-info/entry_points.txt +2 -0
- {together-1.5.34.dist-info → together-2.0.0a6.dist-info}/licenses/LICENSE +1 -1
- together/abstract/api_requestor.py +0 -770
- together/cli/api/chat.py +0 -298
- together/cli/api/completions.py +0 -119
- together/cli/api/images.py +0 -93
- together/cli/api/utils.py +0 -139
- together/client.py +0 -186
- together/error.py +0 -194
- together/filemanager.py +0 -635
- together/legacy/__init__.py +0 -0
- together/legacy/base.py +0 -27
- together/legacy/complete.py +0 -93
- together/legacy/embeddings.py +0 -27
- together/legacy/files.py +0 -146
- together/legacy/finetune.py +0 -177
- together/legacy/images.py +0 -27
- together/legacy/models.py +0 -44
- together/resources/batch.py +0 -165
- together/resources/code_interpreter.py +0 -82
- together/resources/evaluation.py +0 -808
- together/resources/finetune.py +0 -1388
- together/together_response.py +0 -50
- together/types/abstract.py +0 -26
- together/types/audio_speech.py +0 -311
- together/types/batch.py +0 -54
- together/types/chat_completions.py +0 -210
- together/types/code_interpreter.py +0 -57
- together/types/common.py +0 -67
- together/types/completions.py +0 -107
- together/types/embeddings.py +0 -35
- together/types/endpoints.py +0 -123
- together/types/error.py +0 -16
- together/types/evaluation.py +0 -93
- together/types/files.py +0 -93
- together/types/finetune.py +0 -464
- together/types/images.py +0 -42
- together/types/models.py +0 -96
- together/types/rerank.py +0 -43
- together/types/videos.py +0 -69
- together/utils/api_helpers.py +0 -124
- together/version.py +0 -6
- together-1.5.34.dist-info/METADATA +0 -583
- together-1.5.34.dist-info/RECORD +0 -77
- together-1.5.34.dist-info/entry_points.txt +0 -3
- /together/{abstract → lib/cli}/__init__.py +0 -0
- /together/{cli → lib/cli/api}/__init__.py +0 -0
- /together/{cli/api/__init__.py → py.typed} +0 -0
together/resources/finetune.py
DELETED
|
@@ -1,1388 +0,0 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
|
-
import re
|
|
4
|
-
from pathlib import Path
|
|
5
|
-
from typing import Dict, List, Literal
|
|
6
|
-
|
|
7
|
-
from rich import print as rprint
|
|
8
|
-
|
|
9
|
-
from together.abstract import api_requestor
|
|
10
|
-
from together.filemanager import DownloadManager
|
|
11
|
-
from together.together_response import TogetherResponse
|
|
12
|
-
from together.types import (
|
|
13
|
-
CosineLRScheduler,
|
|
14
|
-
CosineLRSchedulerArgs,
|
|
15
|
-
FinetuneCheckpoint,
|
|
16
|
-
FinetuneDeleteResponse,
|
|
17
|
-
FinetuneDownloadResult,
|
|
18
|
-
FinetuneList,
|
|
19
|
-
FinetuneListEvents,
|
|
20
|
-
FinetuneLRScheduler,
|
|
21
|
-
FinetuneMultimodalParams,
|
|
22
|
-
FinetunePriceEstimationRequest,
|
|
23
|
-
FinetunePriceEstimationResponse,
|
|
24
|
-
FinetuneRequest,
|
|
25
|
-
FinetuneResponse,
|
|
26
|
-
FinetuneTrainingLimits,
|
|
27
|
-
FullTrainingType,
|
|
28
|
-
LinearLRScheduler,
|
|
29
|
-
LinearLRSchedulerArgs,
|
|
30
|
-
LoRATrainingType,
|
|
31
|
-
TogetherClient,
|
|
32
|
-
TogetherRequest,
|
|
33
|
-
TrainingMethodDPO,
|
|
34
|
-
TrainingMethodSFT,
|
|
35
|
-
TrainingType,
|
|
36
|
-
)
|
|
37
|
-
from together.types.finetune import DownloadCheckpointType, TrainingMethod
|
|
38
|
-
from together.utils import log_warn_once, normalize_key
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
_FT_JOB_WITH_STEP_REGEX = r"^ft-[\dabcdef-]+:\d+$"
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
AVAILABLE_TRAINING_METHODS = {
|
|
45
|
-
TrainingMethodSFT().method,
|
|
46
|
-
TrainingMethodDPO().method,
|
|
47
|
-
}
|
|
48
|
-
_WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
|
|
49
|
-
"The estimated price of the fine-tuning job is {} which is significantly "
|
|
50
|
-
"greater than your current credit limit and balance combined. "
|
|
51
|
-
"It will likely get cancelled due to insufficient funds. "
|
|
52
|
-
"Proceed at your own risk."
|
|
53
|
-
)
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
def create_finetune_request(
|
|
57
|
-
model_limits: FinetuneTrainingLimits,
|
|
58
|
-
training_file: str,
|
|
59
|
-
model: str | None = None,
|
|
60
|
-
n_epochs: int = 1,
|
|
61
|
-
validation_file: str | None = "",
|
|
62
|
-
n_evals: int | None = 0,
|
|
63
|
-
n_checkpoints: int | None = 1,
|
|
64
|
-
batch_size: int | Literal["max"] = "max",
|
|
65
|
-
learning_rate: float | None = 0.00001,
|
|
66
|
-
lr_scheduler_type: Literal["linear", "cosine"] = "cosine",
|
|
67
|
-
min_lr_ratio: float = 0.0,
|
|
68
|
-
scheduler_num_cycles: float = 0.5,
|
|
69
|
-
warmup_ratio: float | None = None,
|
|
70
|
-
max_grad_norm: float = 1.0,
|
|
71
|
-
weight_decay: float = 0.0,
|
|
72
|
-
lora: bool = False,
|
|
73
|
-
lora_r: int | None = None,
|
|
74
|
-
lora_dropout: float | None = 0,
|
|
75
|
-
lora_alpha: float | None = None,
|
|
76
|
-
lora_trainable_modules: str | None = "all-linear",
|
|
77
|
-
train_vision: bool = False,
|
|
78
|
-
suffix: str | None = None,
|
|
79
|
-
wandb_api_key: str | None = None,
|
|
80
|
-
wandb_base_url: str | None = None,
|
|
81
|
-
wandb_project_name: str | None = None,
|
|
82
|
-
wandb_name: str | None = None,
|
|
83
|
-
train_on_inputs: bool | Literal["auto"] | None = None,
|
|
84
|
-
training_method: str = "sft",
|
|
85
|
-
dpo_beta: float | None = None,
|
|
86
|
-
dpo_normalize_logratios_by_length: bool = False,
|
|
87
|
-
rpo_alpha: float | None = None,
|
|
88
|
-
simpo_gamma: float | None = None,
|
|
89
|
-
from_checkpoint: str | None = None,
|
|
90
|
-
from_hf_model: str | None = None,
|
|
91
|
-
hf_model_revision: str | None = None,
|
|
92
|
-
hf_api_token: str | None = None,
|
|
93
|
-
hf_output_repo_name: str | None = None,
|
|
94
|
-
) -> FinetuneRequest:
|
|
95
|
-
if model is not None and from_checkpoint is not None:
|
|
96
|
-
raise ValueError(
|
|
97
|
-
"You must specify either a model or a checkpoint to start a job from, not both"
|
|
98
|
-
)
|
|
99
|
-
|
|
100
|
-
if model is None and from_checkpoint is None:
|
|
101
|
-
raise ValueError("You must specify either a model or a checkpoint")
|
|
102
|
-
|
|
103
|
-
if from_checkpoint is not None and from_hf_model is not None:
|
|
104
|
-
raise ValueError(
|
|
105
|
-
"You must specify either a Hugging Face Hub model or a previous checkpoint from "
|
|
106
|
-
"Together to start a job from, not both"
|
|
107
|
-
)
|
|
108
|
-
|
|
109
|
-
if from_hf_model is not None and model is None:
|
|
110
|
-
raise ValueError(
|
|
111
|
-
"You must specify the base model to fine-tune a model from the Hugging Face Hub"
|
|
112
|
-
)
|
|
113
|
-
|
|
114
|
-
model_or_checkpoint = model or from_checkpoint
|
|
115
|
-
|
|
116
|
-
if warmup_ratio is None:
|
|
117
|
-
warmup_ratio = 0.0
|
|
118
|
-
|
|
119
|
-
training_type: TrainingType = FullTrainingType()
|
|
120
|
-
if lora:
|
|
121
|
-
if model_limits.lora_training is None:
|
|
122
|
-
raise ValueError(
|
|
123
|
-
f"LoRA adapters are not supported for the selected model ({model_or_checkpoint})."
|
|
124
|
-
)
|
|
125
|
-
|
|
126
|
-
if lora_dropout is not None:
|
|
127
|
-
if not 0 <= lora_dropout < 1.0:
|
|
128
|
-
raise ValueError("LoRA dropout must be in [0, 1) range.")
|
|
129
|
-
|
|
130
|
-
lora_r = lora_r if lora_r is not None else model_limits.lora_training.max_rank
|
|
131
|
-
lora_alpha = lora_alpha if lora_alpha is not None else lora_r * 2
|
|
132
|
-
training_type = LoRATrainingType(
|
|
133
|
-
lora_r=lora_r,
|
|
134
|
-
lora_alpha=lora_alpha,
|
|
135
|
-
lora_dropout=lora_dropout,
|
|
136
|
-
lora_trainable_modules=lora_trainable_modules,
|
|
137
|
-
)
|
|
138
|
-
|
|
139
|
-
max_batch_size = model_limits.lora_training.max_batch_size
|
|
140
|
-
min_batch_size = model_limits.lora_training.min_batch_size
|
|
141
|
-
max_batch_size_dpo = model_limits.lora_training.max_batch_size_dpo
|
|
142
|
-
else:
|
|
143
|
-
if model_limits.full_training is None:
|
|
144
|
-
raise ValueError(
|
|
145
|
-
f"Full training is not supported for the selected model ({model_or_checkpoint})."
|
|
146
|
-
)
|
|
147
|
-
|
|
148
|
-
max_batch_size = model_limits.full_training.max_batch_size
|
|
149
|
-
min_batch_size = model_limits.full_training.min_batch_size
|
|
150
|
-
max_batch_size_dpo = model_limits.full_training.max_batch_size_dpo
|
|
151
|
-
|
|
152
|
-
if batch_size != "max":
|
|
153
|
-
if training_method == "sft":
|
|
154
|
-
if batch_size > max_batch_size:
|
|
155
|
-
raise ValueError(
|
|
156
|
-
f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size}."
|
|
157
|
-
)
|
|
158
|
-
elif training_method == "dpo":
|
|
159
|
-
if batch_size > max_batch_size_dpo:
|
|
160
|
-
raise ValueError(
|
|
161
|
-
f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size_dpo}."
|
|
162
|
-
)
|
|
163
|
-
|
|
164
|
-
if batch_size < min_batch_size:
|
|
165
|
-
raise ValueError(
|
|
166
|
-
f"Requested batch size of {batch_size} is lower that the minimum allowed value of {min_batch_size}."
|
|
167
|
-
)
|
|
168
|
-
|
|
169
|
-
if warmup_ratio > 1 or warmup_ratio < 0:
|
|
170
|
-
raise ValueError(f"Warmup ratio should be between 0 and 1 (got {warmup_ratio})")
|
|
171
|
-
|
|
172
|
-
if min_lr_ratio is not None and (min_lr_ratio > 1 or min_lr_ratio < 0):
|
|
173
|
-
raise ValueError(
|
|
174
|
-
f"Min learning rate ratio should be between 0 and 1 (got {min_lr_ratio})"
|
|
175
|
-
)
|
|
176
|
-
|
|
177
|
-
if max_grad_norm < 0:
|
|
178
|
-
raise ValueError(
|
|
179
|
-
f"Max gradient norm should be non-negative (got {max_grad_norm})"
|
|
180
|
-
)
|
|
181
|
-
|
|
182
|
-
if weight_decay is not None and (weight_decay < 0):
|
|
183
|
-
raise ValueError(f"Weight decay should be non-negative (got {weight_decay})")
|
|
184
|
-
|
|
185
|
-
if training_method not in AVAILABLE_TRAINING_METHODS:
|
|
186
|
-
raise ValueError(
|
|
187
|
-
f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
|
|
188
|
-
)
|
|
189
|
-
|
|
190
|
-
if train_on_inputs is not None and training_method != "sft":
|
|
191
|
-
raise ValueError("train_on_inputs is only supported for SFT training")
|
|
192
|
-
|
|
193
|
-
if train_on_inputs is None and training_method == "sft":
|
|
194
|
-
log_warn_once(
|
|
195
|
-
"train_on_inputs is not set for SFT training, it will be set to 'auto'"
|
|
196
|
-
)
|
|
197
|
-
train_on_inputs = "auto"
|
|
198
|
-
|
|
199
|
-
if dpo_beta is not None and training_method != "dpo":
|
|
200
|
-
raise ValueError("dpo_beta is only supported for DPO training")
|
|
201
|
-
if dpo_normalize_logratios_by_length and training_method != "dpo":
|
|
202
|
-
raise ValueError(
|
|
203
|
-
"dpo_normalize_logratios_by_length=True is only supported for DPO training"
|
|
204
|
-
)
|
|
205
|
-
if rpo_alpha is not None:
|
|
206
|
-
if training_method != "dpo":
|
|
207
|
-
raise ValueError("rpo_alpha is only supported for DPO training")
|
|
208
|
-
if not rpo_alpha >= 0.0:
|
|
209
|
-
raise ValueError(f"rpo_alpha should be non-negative (got {rpo_alpha})")
|
|
210
|
-
|
|
211
|
-
if simpo_gamma is not None:
|
|
212
|
-
if training_method != "dpo":
|
|
213
|
-
raise ValueError("simpo_gamma is only supported for DPO training")
|
|
214
|
-
if not simpo_gamma >= 0.0:
|
|
215
|
-
raise ValueError(f"simpo_gamma should be non-negative (got {simpo_gamma})")
|
|
216
|
-
|
|
217
|
-
lr_scheduler: FinetuneLRScheduler
|
|
218
|
-
if lr_scheduler_type == "cosine":
|
|
219
|
-
if scheduler_num_cycles <= 0.0:
|
|
220
|
-
raise ValueError(
|
|
221
|
-
f"Number of cycles should be greater than 0 (got {scheduler_num_cycles})"
|
|
222
|
-
)
|
|
223
|
-
|
|
224
|
-
lr_scheduler = CosineLRScheduler(
|
|
225
|
-
lr_scheduler_args=CosineLRSchedulerArgs(
|
|
226
|
-
min_lr_ratio=min_lr_ratio, num_cycles=scheduler_num_cycles
|
|
227
|
-
),
|
|
228
|
-
)
|
|
229
|
-
else:
|
|
230
|
-
lr_scheduler = LinearLRScheduler(
|
|
231
|
-
lr_scheduler_args=LinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
|
|
232
|
-
)
|
|
233
|
-
|
|
234
|
-
training_method_cls: TrainingMethodSFT | TrainingMethodDPO
|
|
235
|
-
if training_method == "sft":
|
|
236
|
-
training_method_cls = TrainingMethodSFT(train_on_inputs=train_on_inputs)
|
|
237
|
-
elif training_method == "dpo":
|
|
238
|
-
if simpo_gamma is not None and simpo_gamma > 0:
|
|
239
|
-
dpo_reference_free = True
|
|
240
|
-
dpo_normalize_logratios_by_length = True
|
|
241
|
-
rprint(
|
|
242
|
-
f"Parameter simpo_gamma was set to {simpo_gamma}. "
|
|
243
|
-
"SimPO training detected. Reference logits will not be used "
|
|
244
|
-
"and length normalization of log-probabilities will be enabled."
|
|
245
|
-
)
|
|
246
|
-
else:
|
|
247
|
-
dpo_reference_free = False
|
|
248
|
-
|
|
249
|
-
training_method_cls = TrainingMethodDPO(
|
|
250
|
-
dpo_beta=dpo_beta,
|
|
251
|
-
dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
|
|
252
|
-
dpo_reference_free=dpo_reference_free,
|
|
253
|
-
rpo_alpha=rpo_alpha,
|
|
254
|
-
simpo_gamma=simpo_gamma,
|
|
255
|
-
)
|
|
256
|
-
|
|
257
|
-
if model_limits.supports_vision:
|
|
258
|
-
multimodal_params = FinetuneMultimodalParams(train_vision=train_vision)
|
|
259
|
-
elif not model_limits.supports_vision and train_vision:
|
|
260
|
-
raise ValueError(
|
|
261
|
-
f"Vision encoder training is not supported for the non-multimodal model `{model}`"
|
|
262
|
-
)
|
|
263
|
-
else:
|
|
264
|
-
multimodal_params = None
|
|
265
|
-
|
|
266
|
-
finetune_request = FinetuneRequest(
|
|
267
|
-
model=model,
|
|
268
|
-
training_file=training_file,
|
|
269
|
-
validation_file=validation_file,
|
|
270
|
-
n_epochs=n_epochs,
|
|
271
|
-
n_evals=n_evals,
|
|
272
|
-
n_checkpoints=n_checkpoints,
|
|
273
|
-
batch_size=batch_size,
|
|
274
|
-
learning_rate=learning_rate,
|
|
275
|
-
lr_scheduler=lr_scheduler,
|
|
276
|
-
warmup_ratio=warmup_ratio,
|
|
277
|
-
max_grad_norm=max_grad_norm,
|
|
278
|
-
weight_decay=weight_decay,
|
|
279
|
-
training_type=training_type,
|
|
280
|
-
suffix=suffix,
|
|
281
|
-
wandb_key=wandb_api_key,
|
|
282
|
-
wandb_base_url=wandb_base_url,
|
|
283
|
-
wandb_project_name=wandb_project_name,
|
|
284
|
-
wandb_name=wandb_name,
|
|
285
|
-
training_method=training_method_cls,
|
|
286
|
-
multimodal_params=multimodal_params,
|
|
287
|
-
from_checkpoint=from_checkpoint,
|
|
288
|
-
from_hf_model=from_hf_model,
|
|
289
|
-
hf_model_revision=hf_model_revision,
|
|
290
|
-
hf_api_token=hf_api_token,
|
|
291
|
-
hf_output_repo_name=hf_output_repo_name,
|
|
292
|
-
)
|
|
293
|
-
|
|
294
|
-
return finetune_request
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
def _parse_raw_checkpoints(
|
|
298
|
-
checkpoints: List[Dict[str, str]], id: str
|
|
299
|
-
) -> List[FinetuneCheckpoint]:
|
|
300
|
-
"""
|
|
301
|
-
Helper function to process raw checkpoints and create checkpoint list.
|
|
302
|
-
|
|
303
|
-
Args:
|
|
304
|
-
checkpoints (List[Dict[str, str]]): List of raw checkpoints metadata
|
|
305
|
-
id (str): Fine-tune job ID
|
|
306
|
-
|
|
307
|
-
Returns:
|
|
308
|
-
List[FinetuneCheckpoint]: List of available checkpoints
|
|
309
|
-
"""
|
|
310
|
-
|
|
311
|
-
parsed_checkpoints = []
|
|
312
|
-
for checkpoint in checkpoints:
|
|
313
|
-
step = checkpoint["step"]
|
|
314
|
-
checkpoint_type = checkpoint["checkpoint_type"]
|
|
315
|
-
checkpoint_name = (
|
|
316
|
-
f"{id}:{step}" if "intermediate" in checkpoint_type.lower() else id
|
|
317
|
-
)
|
|
318
|
-
|
|
319
|
-
parsed_checkpoints.append(
|
|
320
|
-
FinetuneCheckpoint(
|
|
321
|
-
type=checkpoint_type,
|
|
322
|
-
timestamp=checkpoint["created_at"],
|
|
323
|
-
name=checkpoint_name,
|
|
324
|
-
)
|
|
325
|
-
)
|
|
326
|
-
|
|
327
|
-
parsed_checkpoints.sort(key=lambda x: x.timestamp, reverse=True)
|
|
328
|
-
return parsed_checkpoints
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
class FineTuning:
|
|
332
|
-
def __init__(self, client: TogetherClient) -> None:
|
|
333
|
-
self._client = client
|
|
334
|
-
|
|
335
|
-
def create(
|
|
336
|
-
self,
|
|
337
|
-
*,
|
|
338
|
-
training_file: str,
|
|
339
|
-
model: str | None = None,
|
|
340
|
-
n_epochs: int = 1,
|
|
341
|
-
validation_file: str | None = "",
|
|
342
|
-
n_evals: int | None = 0,
|
|
343
|
-
n_checkpoints: int | None = 1,
|
|
344
|
-
batch_size: int | Literal["max"] = "max",
|
|
345
|
-
learning_rate: float | None = 0.00001,
|
|
346
|
-
lr_scheduler_type: Literal["linear", "cosine"] = "cosine",
|
|
347
|
-
min_lr_ratio: float = 0.0,
|
|
348
|
-
scheduler_num_cycles: float = 0.5,
|
|
349
|
-
warmup_ratio: float = 0.0,
|
|
350
|
-
max_grad_norm: float = 1.0,
|
|
351
|
-
weight_decay: float = 0.0,
|
|
352
|
-
lora: bool = True,
|
|
353
|
-
lora_r: int | None = None,
|
|
354
|
-
lora_dropout: float | None = 0,
|
|
355
|
-
lora_alpha: float | None = None,
|
|
356
|
-
lora_trainable_modules: str | None = "all-linear",
|
|
357
|
-
train_vision: bool = False,
|
|
358
|
-
suffix: str | None = None,
|
|
359
|
-
wandb_api_key: str | None = None,
|
|
360
|
-
wandb_base_url: str | None = None,
|
|
361
|
-
wandb_project_name: str | None = None,
|
|
362
|
-
wandb_name: str | None = None,
|
|
363
|
-
verbose: bool = False,
|
|
364
|
-
model_limits: FinetuneTrainingLimits | None = None,
|
|
365
|
-
train_on_inputs: bool | Literal["auto"] | None = None,
|
|
366
|
-
training_method: str = "sft",
|
|
367
|
-
dpo_beta: float | None = None,
|
|
368
|
-
dpo_normalize_logratios_by_length: bool = False,
|
|
369
|
-
rpo_alpha: float | None = None,
|
|
370
|
-
simpo_gamma: float | None = None,
|
|
371
|
-
from_checkpoint: str | None = None,
|
|
372
|
-
from_hf_model: str | None = None,
|
|
373
|
-
hf_model_revision: str | None = None,
|
|
374
|
-
hf_api_token: str | None = None,
|
|
375
|
-
hf_output_repo_name: str | None = None,
|
|
376
|
-
) -> FinetuneResponse:
|
|
377
|
-
"""
|
|
378
|
-
Method to initiate a fine-tuning job
|
|
379
|
-
|
|
380
|
-
Args:
|
|
381
|
-
training_file (str): File-ID of a file uploaded to the Together API
|
|
382
|
-
model (str, optional): Name of the base model to run fine-tune job on
|
|
383
|
-
n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
|
|
384
|
-
validation file (str, optional): File ID of a file uploaded to the Together API for validation.
|
|
385
|
-
n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
|
|
386
|
-
n_checkpoints (int, optional): Number of checkpoints to save during fine-tuning.
|
|
387
|
-
Defaults to 1.
|
|
388
|
-
batch_size (int or "max"): Batch size for fine-tuning. Defaults to max.
|
|
389
|
-
learning_rate (float, optional): Learning rate multiplier to use for training
|
|
390
|
-
Defaults to 0.00001.
|
|
391
|
-
lr_scheduler_type (Literal["linear", "cosine"]): Learning rate scheduler type. Defaults to "cosine".
|
|
392
|
-
min_lr_ratio (float, optional): Min learning rate ratio of the initial learning rate for
|
|
393
|
-
the learning rate scheduler. Defaults to 0.0.
|
|
394
|
-
scheduler_num_cycles (float, optional): Number or fraction of cycles for the cosine learning rate scheduler. Defaults to 0.5.
|
|
395
|
-
warmup_ratio (float, optional): Warmup ratio for the learning rate scheduler.
|
|
396
|
-
max_grad_norm (float, optional): Max gradient norm. Defaults to 1.0, set to 0 to disable.
|
|
397
|
-
weight_decay (float, optional): Weight decay. Defaults to 0.0.
|
|
398
|
-
lora (bool, optional): Whether to use LoRA adapters. Defaults to True.
|
|
399
|
-
lora_r (int, optional): Rank of LoRA adapters. Defaults to 8.
|
|
400
|
-
lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
|
|
401
|
-
lora_alpha (float, optional): Alpha for LoRA adapters. Defaults to 8.
|
|
402
|
-
lora_trainable_modules (str, optional): Trainable modules for LoRA adapters. Defaults to "all-linear".
|
|
403
|
-
train_vision (bool, optional): Whether to train vision encoder in multimodal models. Defaults to False.
|
|
404
|
-
suffix (str, optional): Up to 40 character suffix that will be added to your fine-tuned model name.
|
|
405
|
-
Defaults to None.
|
|
406
|
-
wandb_api_key (str, optional): API key for Weights & Biases integration.
|
|
407
|
-
Defaults to None.
|
|
408
|
-
wandb_base_url (str, optional): Base URL for Weights & Biases integration.
|
|
409
|
-
Defaults to None.
|
|
410
|
-
wandb_project_name (str, optional): Project name for Weights & Biases integration.
|
|
411
|
-
Defaults to None.
|
|
412
|
-
wandb_name (str, optional): Run name for Weights & Biases integration.
|
|
413
|
-
Defaults to None.
|
|
414
|
-
verbose (bool, optional): whether to print the job parameters before submitting a request.
|
|
415
|
-
Defaults to False.
|
|
416
|
-
model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
|
|
417
|
-
Defaults to None.
|
|
418
|
-
train_on_inputs (bool or "auto", optional): Whether to mask the user messages in conversational data or prompts in instruction data.
|
|
419
|
-
"auto" will automatically determine whether to mask the inputs based on the data format.
|
|
420
|
-
For datasets with the "text" field (general format), inputs will not be masked.
|
|
421
|
-
For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
|
|
422
|
-
(Instruction format), inputs will be masked.
|
|
423
|
-
Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request).
|
|
424
|
-
training_method (str, optional): Training method. Defaults to "sft".
|
|
425
|
-
Supported methods: "sft", "dpo".
|
|
426
|
-
dpo_beta (float, optional): DPO beta parameter. Defaults to None.
|
|
427
|
-
dpo_normalize_logratios_by_length (bool): Whether or not normalize logratios by sample length. Defaults to False,
|
|
428
|
-
rpo_alpha (float, optional): RPO alpha parameter of DPO training to include NLL in the loss. Defaults to None.
|
|
429
|
-
simpo_gamma: (float, optional): SimPO gamma parameter. Defaults to None.
|
|
430
|
-
from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
|
|
431
|
-
The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
|
|
432
|
-
The step value is optional, without it the final checkpoint will be used.
|
|
433
|
-
from_hf_model (str, optional): The Hugging Face Hub repo to start training from.
|
|
434
|
-
Should be as close as possible to the base model (specified by the `model` argument) in terms of architecture and size.
|
|
435
|
-
hf_model_revision (str, optional): The revision of the Hugging Face Hub model to continue training from. Defaults to None.
|
|
436
|
-
Example: hf_model_revision=None (defaults to the latest revision in `main`) or
|
|
437
|
-
hf_model_revision="607a30d783dfa663caf39e06633721c8d4cfcd7e" (specific commit).
|
|
438
|
-
hf_api_token (str, optional): API key for the Hugging Face Hub. Defaults to None.
|
|
439
|
-
hf_output_repo_name (str, optional): HF repo to upload the fine-tuned model to. Defaults to None.
|
|
440
|
-
|
|
441
|
-
Returns:
|
|
442
|
-
FinetuneResponse: Object containing information about fine-tuning job.
|
|
443
|
-
"""
|
|
444
|
-
|
|
445
|
-
requestor = api_requestor.APIRequestor(
|
|
446
|
-
client=self._client,
|
|
447
|
-
)
|
|
448
|
-
|
|
449
|
-
if model_limits is None:
|
|
450
|
-
# mypy doesn't understand that model or from_checkpoint is not None
|
|
451
|
-
if model is not None:
|
|
452
|
-
model_name = model
|
|
453
|
-
elif from_checkpoint is not None:
|
|
454
|
-
model_name = from_checkpoint.split(":")[0]
|
|
455
|
-
else:
|
|
456
|
-
# this branch is unreachable, but mypy doesn't know that
|
|
457
|
-
pass
|
|
458
|
-
model_limits = self.get_model_limits(model=model_name)
|
|
459
|
-
|
|
460
|
-
finetune_request = create_finetune_request(
|
|
461
|
-
model_limits=model_limits,
|
|
462
|
-
training_file=training_file,
|
|
463
|
-
model=model,
|
|
464
|
-
n_epochs=n_epochs,
|
|
465
|
-
validation_file=validation_file,
|
|
466
|
-
n_evals=n_evals,
|
|
467
|
-
n_checkpoints=n_checkpoints,
|
|
468
|
-
batch_size=batch_size,
|
|
469
|
-
learning_rate=learning_rate,
|
|
470
|
-
lr_scheduler_type=lr_scheduler_type,
|
|
471
|
-
min_lr_ratio=min_lr_ratio,
|
|
472
|
-
scheduler_num_cycles=scheduler_num_cycles,
|
|
473
|
-
warmup_ratio=warmup_ratio,
|
|
474
|
-
max_grad_norm=max_grad_norm,
|
|
475
|
-
weight_decay=weight_decay,
|
|
476
|
-
lora=lora,
|
|
477
|
-
lora_r=lora_r,
|
|
478
|
-
lora_dropout=lora_dropout,
|
|
479
|
-
lora_alpha=lora_alpha,
|
|
480
|
-
lora_trainable_modules=lora_trainable_modules,
|
|
481
|
-
train_vision=train_vision,
|
|
482
|
-
suffix=suffix,
|
|
483
|
-
wandb_api_key=wandb_api_key,
|
|
484
|
-
wandb_base_url=wandb_base_url,
|
|
485
|
-
wandb_project_name=wandb_project_name,
|
|
486
|
-
wandb_name=wandb_name,
|
|
487
|
-
train_on_inputs=train_on_inputs,
|
|
488
|
-
training_method=training_method,
|
|
489
|
-
dpo_beta=dpo_beta,
|
|
490
|
-
dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
|
|
491
|
-
rpo_alpha=rpo_alpha,
|
|
492
|
-
simpo_gamma=simpo_gamma,
|
|
493
|
-
from_checkpoint=from_checkpoint,
|
|
494
|
-
from_hf_model=from_hf_model,
|
|
495
|
-
hf_model_revision=hf_model_revision,
|
|
496
|
-
hf_api_token=hf_api_token,
|
|
497
|
-
hf_output_repo_name=hf_output_repo_name,
|
|
498
|
-
)
|
|
499
|
-
if from_checkpoint is None and from_hf_model is None:
|
|
500
|
-
price_estimation_result = self.estimate_price(
|
|
501
|
-
training_file=training_file,
|
|
502
|
-
validation_file=validation_file,
|
|
503
|
-
model=model_name,
|
|
504
|
-
n_epochs=finetune_request.n_epochs,
|
|
505
|
-
n_evals=finetune_request.n_evals,
|
|
506
|
-
training_type="lora" if lora else "full",
|
|
507
|
-
training_method=training_method,
|
|
508
|
-
)
|
|
509
|
-
price_limit_passed = price_estimation_result.allowed_to_proceed
|
|
510
|
-
else:
|
|
511
|
-
# unsupported case
|
|
512
|
-
price_limit_passed = True
|
|
513
|
-
|
|
514
|
-
if verbose:
|
|
515
|
-
rprint(
|
|
516
|
-
"Submitting a fine-tuning job with the following parameters:",
|
|
517
|
-
finetune_request,
|
|
518
|
-
)
|
|
519
|
-
if not price_limit_passed:
|
|
520
|
-
rprint(
|
|
521
|
-
"[red]"
|
|
522
|
-
+ _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(
|
|
523
|
-
price_estimation_result.estimated_total_price
|
|
524
|
-
)
|
|
525
|
-
+ "[/red]",
|
|
526
|
-
)
|
|
527
|
-
parameter_payload = finetune_request.model_dump(exclude_none=True)
|
|
528
|
-
|
|
529
|
-
response, _, _ = requestor.request(
|
|
530
|
-
options=TogetherRequest(
|
|
531
|
-
method="POST",
|
|
532
|
-
url="fine-tunes",
|
|
533
|
-
params=parameter_payload,
|
|
534
|
-
),
|
|
535
|
-
stream=False,
|
|
536
|
-
)
|
|
537
|
-
assert isinstance(response, TogetherResponse)
|
|
538
|
-
|
|
539
|
-
return FinetuneResponse(**response.data)
|
|
540
|
-
|
|
541
|
-
def estimate_price(
|
|
542
|
-
self,
|
|
543
|
-
*,
|
|
544
|
-
training_file: str,
|
|
545
|
-
model: str,
|
|
546
|
-
validation_file: str | None = None,
|
|
547
|
-
n_epochs: int | None = 1,
|
|
548
|
-
n_evals: int | None = 0,
|
|
549
|
-
training_type: str = "lora",
|
|
550
|
-
training_method: str = "sft",
|
|
551
|
-
) -> FinetunePriceEstimationResponse:
|
|
552
|
-
"""
|
|
553
|
-
Estimates the price of a fine-tuning job
|
|
554
|
-
|
|
555
|
-
Args:
|
|
556
|
-
training_file (str): File-ID of a file uploaded to the Together API
|
|
557
|
-
model (str): Name of the base model to run fine-tune job on
|
|
558
|
-
validation_file (str, optional): File ID of a file uploaded to the Together API for validation.
|
|
559
|
-
n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
|
|
560
|
-
n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
|
|
561
|
-
training_type (str, optional): Training type. Defaults to "lora".
|
|
562
|
-
training_method (str, optional): Training method. Defaults to "sft".
|
|
563
|
-
|
|
564
|
-
Returns:
|
|
565
|
-
FinetunePriceEstimationResponse: Object containing the price estimation result.
|
|
566
|
-
"""
|
|
567
|
-
training_type_cls: TrainingType
|
|
568
|
-
training_method_cls: TrainingMethod
|
|
569
|
-
|
|
570
|
-
if training_method == "sft":
|
|
571
|
-
training_method_cls = TrainingMethodSFT(method="sft")
|
|
572
|
-
elif training_method == "dpo":
|
|
573
|
-
training_method_cls = TrainingMethodDPO(method="dpo")
|
|
574
|
-
else:
|
|
575
|
-
raise ValueError(f"Unknown training method: {training_method}")
|
|
576
|
-
|
|
577
|
-
if training_type.lower() == "lora":
|
|
578
|
-
# parameters of lora are unused in price estimation
|
|
579
|
-
# but we need to set them to valid values
|
|
580
|
-
training_type_cls = LoRATrainingType(
|
|
581
|
-
type="Lora",
|
|
582
|
-
lora_r=16,
|
|
583
|
-
lora_alpha=16,
|
|
584
|
-
lora_dropout=0.0,
|
|
585
|
-
lora_trainable_modules="all-linear",
|
|
586
|
-
)
|
|
587
|
-
elif training_type.lower() == "full":
|
|
588
|
-
training_type_cls = FullTrainingType(type="Full")
|
|
589
|
-
else:
|
|
590
|
-
raise ValueError(f"Unknown training type: {training_type}")
|
|
591
|
-
|
|
592
|
-
request = FinetunePriceEstimationRequest(
|
|
593
|
-
training_file=training_file,
|
|
594
|
-
validation_file=validation_file,
|
|
595
|
-
model=model,
|
|
596
|
-
n_epochs=n_epochs,
|
|
597
|
-
n_evals=n_evals,
|
|
598
|
-
training_type=training_type_cls,
|
|
599
|
-
training_method=training_method_cls,
|
|
600
|
-
)
|
|
601
|
-
parameter_payload = request.model_dump(exclude_none=True)
|
|
602
|
-
requestor = api_requestor.APIRequestor(
|
|
603
|
-
client=self._client,
|
|
604
|
-
)
|
|
605
|
-
|
|
606
|
-
response, _, _ = requestor.request(
|
|
607
|
-
options=TogetherRequest(
|
|
608
|
-
method="POST", url="fine-tunes/estimate-price", params=parameter_payload
|
|
609
|
-
),
|
|
610
|
-
stream=False,
|
|
611
|
-
)
|
|
612
|
-
assert isinstance(response, TogetherResponse)
|
|
613
|
-
|
|
614
|
-
return FinetunePriceEstimationResponse(**response.data)
|
|
615
|
-
|
|
616
|
-
def list(self) -> FinetuneList:
|
|
617
|
-
"""
|
|
618
|
-
Lists fine-tune job history
|
|
619
|
-
|
|
620
|
-
Returns:
|
|
621
|
-
FinetuneList: Object containing a list of fine-tune jobs
|
|
622
|
-
"""
|
|
623
|
-
|
|
624
|
-
requestor = api_requestor.APIRequestor(
|
|
625
|
-
client=self._client,
|
|
626
|
-
)
|
|
627
|
-
|
|
628
|
-
response, _, _ = requestor.request(
|
|
629
|
-
options=TogetherRequest(
|
|
630
|
-
method="GET",
|
|
631
|
-
url="fine-tunes",
|
|
632
|
-
),
|
|
633
|
-
stream=False,
|
|
634
|
-
)
|
|
635
|
-
|
|
636
|
-
assert isinstance(response, TogetherResponse)
|
|
637
|
-
|
|
638
|
-
return FinetuneList(**response.data)
|
|
639
|
-
|
|
640
|
-
def retrieve(self, id: str) -> FinetuneResponse:
|
|
641
|
-
"""
|
|
642
|
-
Retrieves fine-tune job details
|
|
643
|
-
|
|
644
|
-
Args:
|
|
645
|
-
id (str): Fine-tune ID to retrieve. A string that starts with `ft-`.
|
|
646
|
-
|
|
647
|
-
Returns:
|
|
648
|
-
FinetuneResponse: Object containing information about fine-tuning job.
|
|
649
|
-
"""
|
|
650
|
-
|
|
651
|
-
requestor = api_requestor.APIRequestor(
|
|
652
|
-
client=self._client,
|
|
653
|
-
)
|
|
654
|
-
|
|
655
|
-
response, _, _ = requestor.request(
|
|
656
|
-
options=TogetherRequest(
|
|
657
|
-
method="GET",
|
|
658
|
-
url=f"fine-tunes/{id}",
|
|
659
|
-
),
|
|
660
|
-
stream=False,
|
|
661
|
-
)
|
|
662
|
-
|
|
663
|
-
assert isinstance(response, TogetherResponse)
|
|
664
|
-
|
|
665
|
-
return FinetuneResponse(**response.data)
|
|
666
|
-
|
|
667
|
-
def cancel(self, id: str) -> FinetuneResponse:
|
|
668
|
-
"""
|
|
669
|
-
Method to cancel a running fine-tuning job
|
|
670
|
-
|
|
671
|
-
Args:
|
|
672
|
-
id (str): Fine-tune ID to cancel. A string that starts with `ft-`.
|
|
673
|
-
|
|
674
|
-
Returns:
|
|
675
|
-
FinetuneResponse: Object containing information about cancelled fine-tuning job.
|
|
676
|
-
"""
|
|
677
|
-
|
|
678
|
-
requestor = api_requestor.APIRequestor(
|
|
679
|
-
client=self._client,
|
|
680
|
-
)
|
|
681
|
-
|
|
682
|
-
response, _, _ = requestor.request(
|
|
683
|
-
options=TogetherRequest(
|
|
684
|
-
method="POST",
|
|
685
|
-
url=f"fine-tunes/{id}/cancel",
|
|
686
|
-
),
|
|
687
|
-
stream=False,
|
|
688
|
-
)
|
|
689
|
-
|
|
690
|
-
assert isinstance(response, TogetherResponse)
|
|
691
|
-
|
|
692
|
-
return FinetuneResponse(**response.data)
|
|
693
|
-
|
|
694
|
-
def delete(self, id: str, force: bool = False) -> FinetuneDeleteResponse:
|
|
695
|
-
"""
|
|
696
|
-
Method to delete a fine-tuning job
|
|
697
|
-
|
|
698
|
-
Args:
|
|
699
|
-
id (str): Fine-tune ID to delete. A string that starts with `ft-`.
|
|
700
|
-
force (bool, optional): Force deletion. Defaults to False.
|
|
701
|
-
|
|
702
|
-
Returns:
|
|
703
|
-
FinetuneDeleteResponse: Object containing deletion confirmation message.
|
|
704
|
-
"""
|
|
705
|
-
|
|
706
|
-
requestor = api_requestor.APIRequestor(
|
|
707
|
-
client=self._client,
|
|
708
|
-
)
|
|
709
|
-
|
|
710
|
-
params = {"force": str(force).lower()}
|
|
711
|
-
|
|
712
|
-
response, _, _ = requestor.request(
|
|
713
|
-
options=TogetherRequest(
|
|
714
|
-
method="DELETE",
|
|
715
|
-
url=f"fine-tunes/{id}",
|
|
716
|
-
params=params,
|
|
717
|
-
),
|
|
718
|
-
stream=False,
|
|
719
|
-
)
|
|
720
|
-
|
|
721
|
-
assert isinstance(response, TogetherResponse)
|
|
722
|
-
|
|
723
|
-
return FinetuneDeleteResponse(**response.data)
|
|
724
|
-
|
|
725
|
-
def list_events(self, id: str) -> FinetuneListEvents:
|
|
726
|
-
"""
|
|
727
|
-
Lists events of a fine-tune job
|
|
728
|
-
|
|
729
|
-
Args:
|
|
730
|
-
id (str): Fine-tune ID to list events for. A string that starts with `ft-`.
|
|
731
|
-
|
|
732
|
-
Returns:
|
|
733
|
-
FinetuneListEvents: Object containing list of fine-tune events
|
|
734
|
-
"""
|
|
735
|
-
|
|
736
|
-
requestor = api_requestor.APIRequestor(
|
|
737
|
-
client=self._client,
|
|
738
|
-
)
|
|
739
|
-
|
|
740
|
-
response, _, _ = requestor.request(
|
|
741
|
-
options=TogetherRequest(
|
|
742
|
-
method="GET",
|
|
743
|
-
url=f"fine-tunes/{id}/events",
|
|
744
|
-
),
|
|
745
|
-
stream=False,
|
|
746
|
-
)
|
|
747
|
-
assert isinstance(response, TogetherResponse)
|
|
748
|
-
|
|
749
|
-
return FinetuneListEvents(**response.data)
|
|
750
|
-
|
|
751
|
-
def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
|
|
752
|
-
"""
|
|
753
|
-
List available checkpoints for a fine-tuning job
|
|
754
|
-
|
|
755
|
-
Args:
|
|
756
|
-
id (str): Unique identifier of the fine-tune job to list checkpoints for
|
|
757
|
-
|
|
758
|
-
Returns:
|
|
759
|
-
List[FinetuneCheckpoint]: List of available checkpoints
|
|
760
|
-
"""
|
|
761
|
-
requestor = api_requestor.APIRequestor(
|
|
762
|
-
client=self._client,
|
|
763
|
-
)
|
|
764
|
-
|
|
765
|
-
response, _, _ = requestor.request(
|
|
766
|
-
options=TogetherRequest(
|
|
767
|
-
method="GET",
|
|
768
|
-
url=f"fine-tunes/{id}/checkpoints",
|
|
769
|
-
),
|
|
770
|
-
stream=False,
|
|
771
|
-
)
|
|
772
|
-
assert isinstance(response, TogetherResponse)
|
|
773
|
-
|
|
774
|
-
raw_checkpoints = response.data["data"]
|
|
775
|
-
return _parse_raw_checkpoints(raw_checkpoints, id)
|
|
776
|
-
|
|
777
|
-
def download(
|
|
778
|
-
self,
|
|
779
|
-
id: str,
|
|
780
|
-
*,
|
|
781
|
-
output: Path | str | None = None,
|
|
782
|
-
checkpoint_step: int | None = None,
|
|
783
|
-
checkpoint_type: DownloadCheckpointType | str = DownloadCheckpointType.DEFAULT,
|
|
784
|
-
) -> FinetuneDownloadResult:
|
|
785
|
-
"""
|
|
786
|
-
Downloads compressed fine-tuned model or checkpoint to local disk.
|
|
787
|
-
|
|
788
|
-
Defaults file location to `$PWD/{model_name}.{extension}`
|
|
789
|
-
|
|
790
|
-
Args:
|
|
791
|
-
id (str): Fine-tune ID to download. A string that starts with `ft-`.
|
|
792
|
-
output (pathlib.Path | str, optional): Specifies output file name for downloaded model.
|
|
793
|
-
Defaults to None.
|
|
794
|
-
checkpoint_step (int, optional): Specifies step number for checkpoint to download.
|
|
795
|
-
Defaults to -1 (download the final model)
|
|
796
|
-
checkpoint_type (CheckpointType | str, optional): Specifies which checkpoint to download.
|
|
797
|
-
Defaults to CheckpointType.DEFAULT.
|
|
798
|
-
|
|
799
|
-
Returns:
|
|
800
|
-
FinetuneDownloadResult: Object containing downloaded model metadata
|
|
801
|
-
"""
|
|
802
|
-
|
|
803
|
-
if re.match(_FT_JOB_WITH_STEP_REGEX, id) is not None:
|
|
804
|
-
if checkpoint_step is None:
|
|
805
|
-
checkpoint_step = int(id.split(":")[1])
|
|
806
|
-
id = id.split(":")[0]
|
|
807
|
-
else:
|
|
808
|
-
raise ValueError(
|
|
809
|
-
"Fine-tuning job ID {id} contains a colon to specify the step to download, but `checkpoint_step` "
|
|
810
|
-
"was also set. Remove one of the step specifiers to proceed."
|
|
811
|
-
)
|
|
812
|
-
|
|
813
|
-
url = f"finetune/download?ft_id={id}"
|
|
814
|
-
|
|
815
|
-
if checkpoint_step is not None:
|
|
816
|
-
url += f"&checkpoint_step={checkpoint_step}"
|
|
817
|
-
|
|
818
|
-
ft_job = self.retrieve(id)
|
|
819
|
-
|
|
820
|
-
# convert str to DownloadCheckpointType
|
|
821
|
-
if isinstance(checkpoint_type, str):
|
|
822
|
-
try:
|
|
823
|
-
checkpoint_type = DownloadCheckpointType(checkpoint_type.lower())
|
|
824
|
-
except ValueError:
|
|
825
|
-
enum_strs = ", ".join(e.value for e in DownloadCheckpointType)
|
|
826
|
-
raise ValueError(
|
|
827
|
-
f"Invalid checkpoint type: {checkpoint_type}. Choose one of {{{enum_strs}}}."
|
|
828
|
-
)
|
|
829
|
-
|
|
830
|
-
if isinstance(ft_job.training_type, FullTrainingType):
|
|
831
|
-
if checkpoint_type != DownloadCheckpointType.DEFAULT:
|
|
832
|
-
raise ValueError(
|
|
833
|
-
"Only DEFAULT checkpoint type is allowed for FullTrainingType"
|
|
834
|
-
)
|
|
835
|
-
url += "&checkpoint=model_output_path"
|
|
836
|
-
elif isinstance(ft_job.training_type, LoRATrainingType):
|
|
837
|
-
if checkpoint_type == DownloadCheckpointType.DEFAULT:
|
|
838
|
-
checkpoint_type = DownloadCheckpointType.MERGED
|
|
839
|
-
|
|
840
|
-
if checkpoint_type in {
|
|
841
|
-
DownloadCheckpointType.MERGED,
|
|
842
|
-
DownloadCheckpointType.ADAPTER,
|
|
843
|
-
}:
|
|
844
|
-
url += f"&checkpoint={checkpoint_type.value}"
|
|
845
|
-
else:
|
|
846
|
-
raise ValueError(
|
|
847
|
-
f"Invalid checkpoint type for LoRATrainingType: {checkpoint_type}"
|
|
848
|
-
)
|
|
849
|
-
|
|
850
|
-
remote_name = ft_job.output_name
|
|
851
|
-
|
|
852
|
-
download_manager = DownloadManager(self._client)
|
|
853
|
-
|
|
854
|
-
if isinstance(output, str):
|
|
855
|
-
output = Path(output)
|
|
856
|
-
|
|
857
|
-
downloaded_filename, file_size = download_manager.download(
|
|
858
|
-
url, output, normalize_key(remote_name or id), fetch_metadata=True
|
|
859
|
-
)
|
|
860
|
-
|
|
861
|
-
return FinetuneDownloadResult(
|
|
862
|
-
object="local",
|
|
863
|
-
id=id,
|
|
864
|
-
checkpoint_step=checkpoint_step,
|
|
865
|
-
filename=downloaded_filename,
|
|
866
|
-
size=file_size,
|
|
867
|
-
)
|
|
868
|
-
|
|
869
|
-
def get_model_limits(self, *, model: str) -> FinetuneTrainingLimits:
|
|
870
|
-
"""
|
|
871
|
-
Requests training limits for a specific model
|
|
872
|
-
|
|
873
|
-
Args:
|
|
874
|
-
model_name (str): Name of the model to get limits for
|
|
875
|
-
|
|
876
|
-
Returns:
|
|
877
|
-
FinetuneTrainingLimits: Object containing training limits for the model
|
|
878
|
-
"""
|
|
879
|
-
|
|
880
|
-
requestor = api_requestor.APIRequestor(
|
|
881
|
-
client=self._client,
|
|
882
|
-
)
|
|
883
|
-
|
|
884
|
-
model_limits_response, _, _ = requestor.request(
|
|
885
|
-
options=TogetherRequest(
|
|
886
|
-
method="GET",
|
|
887
|
-
url="fine-tunes/models/limits",
|
|
888
|
-
params={"model_name": model},
|
|
889
|
-
),
|
|
890
|
-
stream=False,
|
|
891
|
-
)
|
|
892
|
-
|
|
893
|
-
model_limits = FinetuneTrainingLimits(**model_limits_response.data)
|
|
894
|
-
|
|
895
|
-
return model_limits
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
class AsyncFineTuning:
|
|
899
|
-
def __init__(self, client: TogetherClient) -> None:
|
|
900
|
-
self._client = client
|
|
901
|
-
|
|
902
|
-
async def create(
|
|
903
|
-
self,
|
|
904
|
-
*,
|
|
905
|
-
training_file: str,
|
|
906
|
-
model: str | None = None,
|
|
907
|
-
n_epochs: int = 1,
|
|
908
|
-
validation_file: str | None = "",
|
|
909
|
-
n_evals: int | None = 0,
|
|
910
|
-
n_checkpoints: int | None = 1,
|
|
911
|
-
batch_size: int | Literal["max"] = "max",
|
|
912
|
-
learning_rate: float | None = 0.00001,
|
|
913
|
-
lr_scheduler_type: Literal["linear", "cosine"] = "cosine",
|
|
914
|
-
min_lr_ratio: float = 0.0,
|
|
915
|
-
scheduler_num_cycles: float = 0.5,
|
|
916
|
-
warmup_ratio: float = 0.0,
|
|
917
|
-
max_grad_norm: float = 1.0,
|
|
918
|
-
weight_decay: float = 0.0,
|
|
919
|
-
lora: bool = True,
|
|
920
|
-
lora_r: int | None = None,
|
|
921
|
-
lora_dropout: float | None = 0,
|
|
922
|
-
lora_alpha: float | None = None,
|
|
923
|
-
lora_trainable_modules: str | None = "all-linear",
|
|
924
|
-
train_vision: bool = False,
|
|
925
|
-
suffix: str | None = None,
|
|
926
|
-
wandb_api_key: str | None = None,
|
|
927
|
-
wandb_base_url: str | None = None,
|
|
928
|
-
wandb_project_name: str | None = None,
|
|
929
|
-
wandb_name: str | None = None,
|
|
930
|
-
verbose: bool = False,
|
|
931
|
-
model_limits: FinetuneTrainingLimits | None = None,
|
|
932
|
-
train_on_inputs: bool | Literal["auto"] | None = None,
|
|
933
|
-
training_method: str = "sft",
|
|
934
|
-
dpo_beta: float | None = None,
|
|
935
|
-
dpo_normalize_logratios_by_length: bool = False,
|
|
936
|
-
rpo_alpha: float | None = None,
|
|
937
|
-
simpo_gamma: float | None = None,
|
|
938
|
-
from_checkpoint: str | None = None,
|
|
939
|
-
from_hf_model: str | None = None,
|
|
940
|
-
hf_model_revision: str | None = None,
|
|
941
|
-
hf_api_token: str | None = None,
|
|
942
|
-
hf_output_repo_name: str | None = None,
|
|
943
|
-
) -> FinetuneResponse:
|
|
944
|
-
"""
|
|
945
|
-
Async method to initiate a fine-tuning job
|
|
946
|
-
|
|
947
|
-
Args:
|
|
948
|
-
training_file (str): File-ID of a file uploaded to the Together API
|
|
949
|
-
model (str, optional): Name of the base model to run fine-tune job on
|
|
950
|
-
n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
|
|
951
|
-
validation file (str, optional): File ID of a file uploaded to the Together API for validation.
|
|
952
|
-
n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
|
|
953
|
-
n_checkpoints (int, optional): Number of checkpoints to save during fine-tuning.
|
|
954
|
-
Defaults to 1.
|
|
955
|
-
batch_size (int, optional): Batch size for fine-tuning. Defaults to max.
|
|
956
|
-
learning_rate (float, optional): Learning rate multiplier to use for training
|
|
957
|
-
Defaults to 0.00001.
|
|
958
|
-
lr_scheduler_type (Literal["linear", "cosine"]): Learning rate scheduler type. Defaults to "cosine".
|
|
959
|
-
min_lr_ratio (float, optional): Min learning rate ratio of the initial learning rate for
|
|
960
|
-
the learning rate scheduler. Defaults to 0.0.
|
|
961
|
-
scheduler_num_cycles (float, optional): Number or fraction of cycles for the cosine learning rate scheduler. Defaults to 0.5.
|
|
962
|
-
warmup_ratio (float, optional): Warmup ratio for the learning rate scheduler.
|
|
963
|
-
max_grad_norm (float, optional): Max gradient norm. Defaults to 1.0, set to 0 to disable.
|
|
964
|
-
weight_decay (float, optional): Weight decay. Defaults to 0.0.
|
|
965
|
-
lora (bool, optional): Whether to use LoRA adapters. Defaults to True.
|
|
966
|
-
lora_r (int, optional): Rank of LoRA adapters. Defaults to 8.
|
|
967
|
-
lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
|
|
968
|
-
lora_alpha (float, optional): Alpha for LoRA adapters. Defaults to 8.
|
|
969
|
-
lora_trainable_modules (str, optional): Trainable modules for LoRA adapters. Defaults to "all-linear".
|
|
970
|
-
train_vision (bool, optional): Whether to train vision encoder in multimodal models. Defaults to False.
|
|
971
|
-
suffix (str, optional): Up to 40 character suffix that will be added to your fine-tuned model name.
|
|
972
|
-
Defaults to None.
|
|
973
|
-
wandb_api_key (str, optional): API key for Weights & Biases integration.
|
|
974
|
-
Defaults to None.
|
|
975
|
-
wandb_base_url (str, optional): Base URL for Weights & Biases integration.
|
|
976
|
-
Defaults to None.
|
|
977
|
-
wandb_project_name (str, optional): Project name for Weights & Biases integration.
|
|
978
|
-
Defaults to None.
|
|
979
|
-
wandb_name (str, optional): Run name for Weights & Biases integration.
|
|
980
|
-
Defaults to None.
|
|
981
|
-
verbose (bool, optional): whether to print the job parameters before submitting a request.
|
|
982
|
-
Defaults to False.
|
|
983
|
-
model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
|
|
984
|
-
Defaults to None.
|
|
985
|
-
train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
|
|
986
|
-
"auto" will automatically determine whether to mask the inputs based on the data format.
|
|
987
|
-
For datasets with the "text" field (general format), inputs will not be masked.
|
|
988
|
-
For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
|
|
989
|
-
(Instruction format), inputs will be masked.
|
|
990
|
-
Defaults to None, or "auto" if training_method is "sft" (set in create_finetune_request).
|
|
991
|
-
training_method (str, optional): Training method. Defaults to "sft".
|
|
992
|
-
Supported methods: "sft", "dpo".
|
|
993
|
-
dpo_beta (float, optional): DPO beta parameter. Defaults to None.
|
|
994
|
-
dpo_normalize_logratios_by_length (bool): Whether or not normalize logratios by sample length. Defaults to False,
|
|
995
|
-
rpo_alpha (float, optional): RPO alpha parameter of DPO training to include NLL in the loss. Defaults to None.
|
|
996
|
-
-            simpo_gamma (float, optional): SimPO gamma parameter. Defaults to None.
-            from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
-                The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
-                The step value is optional; without it, the final checkpoint will be used.
-            from_hf_model (str, optional): The Hugging Face Hub repo to start training from.
-                Should be as close as possible to the base model (specified by the `model` argument) in terms of architecture and size.
-            hf_model_revision (str, optional): The revision of the Hugging Face Hub model to continue training from. Defaults to None.
-                Example: hf_model_revision=None (defaults to the latest revision in `main`) or
-                hf_model_revision="607a30d783dfa663caf39e06633721c8d4cfcd7e" (specific commit).
-            hf_api_token (str, optional): API key for the Hugging Face Hub. Defaults to None.
-            hf_output_repo_name (str, optional): HF repo to upload the fine-tuned model to. Defaults to None.
-
-        Returns:
-            FinetuneResponse: Object containing information about fine-tuning job.
-        """
-
-        requestor = api_requestor.APIRequestor(
-            client=self._client,
-        )
-
-        if model_limits is None:
-            # mypy doesn't understand that model or from_checkpoint is not None
-            if model is not None:
-                model_name = model
-            elif from_checkpoint is not None:
-                model_name = from_checkpoint.split(":")[0]
-            else:
-                # this branch is unreachable, but mypy doesn't know that
-                pass
-            model_limits = await self.get_model_limits(model=model_name)
-
-        finetune_request = create_finetune_request(
-            model_limits=model_limits,
-            training_file=training_file,
-            model=model,
-            n_epochs=n_epochs,
-            validation_file=validation_file,
-            n_evals=n_evals,
-            n_checkpoints=n_checkpoints,
-            batch_size=batch_size,
-            learning_rate=learning_rate,
-            lr_scheduler_type=lr_scheduler_type,
-            min_lr_ratio=min_lr_ratio,
-            scheduler_num_cycles=scheduler_num_cycles,
-            warmup_ratio=warmup_ratio,
-            max_grad_norm=max_grad_norm,
-            weight_decay=weight_decay,
-            lora=lora,
-            lora_r=lora_r,
-            lora_dropout=lora_dropout,
-            lora_alpha=lora_alpha,
-            lora_trainable_modules=lora_trainable_modules,
-            train_vision=train_vision,
-            suffix=suffix,
-            wandb_api_key=wandb_api_key,
-            wandb_base_url=wandb_base_url,
-            wandb_project_name=wandb_project_name,
-            wandb_name=wandb_name,
-            train_on_inputs=train_on_inputs,
-            training_method=training_method,
-            dpo_beta=dpo_beta,
-            dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
-            rpo_alpha=rpo_alpha,
-            simpo_gamma=simpo_gamma,
-            from_checkpoint=from_checkpoint,
-            from_hf_model=from_hf_model,
-            hf_model_revision=hf_model_revision,
-            hf_api_token=hf_api_token,
-            hf_output_repo_name=hf_output_repo_name,
-        )
-
-        if (
-            from_checkpoint is None
-            and from_hf_model is None
-            and not model_limits.supports_vision
-        ):
-            price_estimation_result = await self.estimate_price(
-                training_file=training_file,
-                validation_file=validation_file,
-                model=model_name,
-                n_epochs=finetune_request.n_epochs,
-                n_evals=finetune_request.n_evals,
-                training_type="lora" if lora else "full",
-                training_method=training_method,
-            )
-            price_limit_passed = price_estimation_result.allowed_to_proceed
-        else:
-            # unsupported case
-            price_limit_passed = True
-
-        if verbose:
-            rprint(
-                "Submitting a fine-tuning job with the following parameters:",
-                finetune_request,
-            )
-            if not price_limit_passed:
-                rprint(
-                    "[red]"
-                    + _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(
-                        price_estimation_result.estimated_total_price
-                    )
-                    + "[/red]",
-                )
-        parameter_payload = finetune_request.model_dump(exclude_none=True)
-
-        response, _, _ = await requestor.arequest(
-            options=TogetherRequest(
-                method="POST",
-                url="fine-tunes",
-                params=parameter_payload,
-            ),
-            stream=False,
-        )
-
-        assert isinstance(response, TogetherResponse)
-
-        return FinetuneResponse(**response.data)
-
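For reference, a minimal usage sketch of the legacy async `create` path shown above. It targets the 1.5.x client, assumes the resource is exposed as `fine_tuning` on `AsyncTogether` and that `TOGETHER_API_KEY` is set; the file ID and model name are placeholders, not values taken from this diff.

```python
import asyncio

from together import AsyncTogether


async def main() -> None:
    # Placeholder file ID and model name; assumes TOGETHER_API_KEY is set.
    client = AsyncTogether()
    job = await client.fine_tuning.create(
        training_file="file-xxxxxxxx",
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",
        n_epochs=1,
        lora=True,      # LoRA run, so the price check above uses training_type="lora"
        verbose=True,   # prints the submitted request, as in the rprint() call above
    )
    print(job.id, job.status)


asyncio.run(main())
```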
-    async def estimate_price(
-        self,
-        *,
-        training_file: str,
-        model: str,
-        validation_file: str | None = None,
-        n_epochs: int | None = 1,
-        n_evals: int | None = 0,
-        training_type: str = "lora",
-        training_method: str = "sft",
-    ) -> FinetunePriceEstimationResponse:
-        """
-        Estimates the price of a fine-tuning job
-
-        Args:
-            training_file (str): File-ID of a file uploaded to the Together API
-            model (str): Name of the base model to run fine-tune job on
-            validation_file (str, optional): File ID of a file uploaded to the Together API for validation.
-            n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
-            n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
-            training_type (str, optional): Training type. Defaults to "lora".
-            training_method (str, optional): Training method. Defaults to "sft".
-
-        Returns:
-            FinetunePriceEstimationResponse: Object containing the price estimation result.
-        """
-        training_type_cls: TrainingType
-        training_method_cls: TrainingMethod
-
-        if training_method == "sft":
-            training_method_cls = TrainingMethodSFT(method="sft")
-        elif training_method == "dpo":
-            training_method_cls = TrainingMethodDPO(method="dpo")
-        else:
-            raise ValueError(f"Unknown training method: {training_method}")
-
-        if training_type.lower() == "lora":
-            # parameters of lora are unused in price estimation
-            # but we need to set them to valid values
-            training_type_cls = LoRATrainingType(
-                type="Lora",
-                lora_r=16,
-                lora_alpha=16,
-                lora_dropout=0.0,
-                lora_trainable_modules="all-linear",
-            )
-        elif training_type.lower() == "full":
-            training_type_cls = FullTrainingType(type="Full")
-        else:
-            raise ValueError(f"Unknown training type: {training_type}")
-
-        request = FinetunePriceEstimationRequest(
-            training_file=training_file,
-            validation_file=validation_file,
-            model=model,
-            n_epochs=n_epochs,
-            n_evals=n_evals,
-            training_type=training_type_cls,
-            training_method=training_method_cls,
-        )
-        parameter_payload = request.model_dump(exclude_none=True)
-        requestor = api_requestor.APIRequestor(
-            client=self._client,
-        )
-
-        response, _, _ = await requestor.arequest(
-            options=TogetherRequest(
-                method="POST", url="fine-tunes/estimate-price", params=parameter_payload
-            ),
-            stream=False,
-        )
-        assert isinstance(response, TogetherResponse)
-
-        return FinetunePriceEstimationResponse(**response.data)
-
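A hedged sketch of calling `estimate_price` directly, mirroring the arguments the `create` flow passes to it. The file ID and model name are placeholders; the `estimated_total_price` and `allowed_to_proceed` fields are the ones the `create` code above reads from the response.

```python
import asyncio

from together import AsyncTogether


async def main() -> None:
    # Placeholder file ID and model name; assumes the resource is `fine_tuning`.
    client = AsyncTogether()
    estimate = await client.fine_tuning.estimate_price(
        training_file="file-xxxxxxxx",
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",
        n_epochs=1,
        training_type="lora",
        training_method="sft",
    )
    # Fields referenced by the create() flow above.
    print(estimate.estimated_total_price, estimate.allowed_to_proceed)


asyncio.run(main())
```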
-    async def list(self) -> FinetuneList:
-        """
-        Async method to list fine-tune job history
-
-        Returns:
-            FinetuneList: Object containing a list of fine-tune jobs
-        """
-
-        requestor = api_requestor.APIRequestor(
-            client=self._client,
-        )
-
-        response, _, _ = await requestor.arequest(
-            options=TogetherRequest(
-                method="GET",
-                url="fine-tunes",
-            ),
-            stream=False,
-        )
-
-        assert isinstance(response, TogetherResponse)
-
-        return FinetuneList(**response.data)
-
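A minimal sketch of listing job history with the method above. It assumes the jobs are exposed on the response's `data` field, mirroring the raw `fine-tunes` payload that `FinetuneList(**response.data)` wraps.

```python
import asyncio

from together import AsyncTogether


async def main() -> None:
    client = AsyncTogether()
    history = await client.fine_tuning.list()
    # `.data` is an assumption based on the raw API payload shape.
    for job in history.data or []:
        print(job.id, job.status, job.model)


asyncio.run(main())
```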
-    async def retrieve(self, id: str) -> FinetuneResponse:
-        """
-        Async method to retrieve fine-tune job details
-
-        Args:
-            id (str): Fine-tune ID to retrieve. A string that starts with `ft-`.
-
-        Returns:
-            FinetuneResponse: Object containing information about fine-tuning job.
-        """
-
-        requestor = api_requestor.APIRequestor(
-            client=self._client,
-        )
-
-        response, _, _ = await requestor.arequest(
-            options=TogetherRequest(
-                method="GET",
-                url=f"fine-tunes/{id}",
-            ),
-            stream=False,
-        )
-
-        assert isinstance(response, TogetherResponse)
-
-        return FinetuneResponse(**response.data)
-
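A hedged polling sketch built on `retrieve`. The job ID is a placeholder, and the set of terminal status strings is an assumption about the job lifecycle rather than something stated in this diff.

```python
import asyncio

from together import AsyncTogether


async def wait_for_job(job_id: str) -> None:
    # job_id is a placeholder ("ft-..."); terminal statuses below are assumptions.
    client = AsyncTogether()
    while True:
        job = await client.fine_tuning.retrieve(job_id)
        print(job.status)
        if str(job.status) in ("completed", "error", "cancelled"):
            break
        await asyncio.sleep(30)


asyncio.run(wait_for_job("ft-00000000-0000-0000-0000-000000000000"))
```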
-    async def cancel(self, id: str) -> FinetuneResponse:
-        """
-        Async method to cancel a running fine-tuning job
-
-        Args:
-            id (str): Fine-tune ID to cancel. A string that starts with `ft-`.
-
-        Returns:
-            FinetuneResponse: Object containing information about cancelled fine-tuning job.
-        """
-
-        requestor = api_requestor.APIRequestor(
-            client=self._client,
-        )
-
-        response, _, _ = await requestor.arequest(
-            options=TogetherRequest(
-                method="POST",
-                url=f"fine-tunes/{id}/cancel",
-            ),
-            stream=False,
-        )
-
-        assert isinstance(response, TogetherResponse)
-
-        return FinetuneResponse(**response.data)
-
-    async def delete(self, id: str, force: bool = False) -> FinetuneDeleteResponse:
-        """
-        Async method to delete a fine-tuning job
-
-        Args:
-            id (str): Fine-tune ID to delete. A string that starts with `ft-`.
-            force (bool, optional): Force deletion. Defaults to False.
-
-        Returns:
-            FinetuneDeleteResponse: Object containing deletion confirmation message.
-        """
-
-        requestor = api_requestor.APIRequestor(
-            client=self._client,
-        )
-
-        params = {"force": str(force).lower()}
-
-        response, _, _ = await requestor.arequest(
-            options=TogetherRequest(
-                method="DELETE",
-                url=f"fine-tunes/{id}",
-                params=params,
-            ),
-            stream=False,
-        )
-
-        assert isinstance(response, TogetherResponse)
-
-        return FinetuneDeleteResponse(**response.data)
-
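A hedged teardown sketch combining the `cancel` and `delete` methods above. The job ID is a placeholder; `force=True` corresponds to the `force` flag that `delete` serializes into the query parameters.

```python
import asyncio

from together import AsyncTogether


async def teardown(job_id: str) -> None:
    # job_id is a placeholder; assumes the resource is exposed as `fine_tuning`.
    client = AsyncTogether()
    cancelled = await client.fine_tuning.cancel(job_id)
    print("cancel ->", cancelled.status)
    # `force` maps to the ?force=true query parameter built in delete() above.
    deleted = await client.fine_tuning.delete(job_id, force=True)
    print("delete ->", deleted)


asyncio.run(teardown("ft-00000000-0000-0000-0000-000000000000"))
```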
-    async def list_events(self, id: str) -> FinetuneListEvents:
-        """
-        List fine-tuning events
-
-        Args:
-            id (str): Unique identifier of the fine-tune job to list events for
-
-        Returns:
-            FinetuneListEvents: Object containing list of fine-tune job events
-        """
-
-        requestor = api_requestor.APIRequestor(
-            client=self._client,
-        )
-
-        events_response, _, _ = await requestor.arequest(
-            options=TogetherRequest(
-                method="GET",
-                url=f"fine-tunes/{normalize_key(id)}/events",
-            ),
-            stream=False,
-        )
-        assert isinstance(events_response, TogetherResponse)
-
-        return FinetuneListEvents(**events_response.data)
-
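A sketch of reading events for a job via the method above. The job ID is a placeholder, and the `.data`, `created_at`, and `message` attribute names are assumptions about the `FinetuneListEvents` payload rather than details visible in this diff.

```python
import asyncio

from together import AsyncTogether


async def show_events(job_id: str) -> None:
    # Placeholder job ID; event field names below are assumptions.
    client = AsyncTogether()
    events = await client.fine_tuning.list_events(job_id)
    for event in events.data or []:
        print(event.created_at, event.message)


asyncio.run(show_events("ft-00000000-0000-0000-0000-000000000000"))
```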
-    async def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
-        """
-        List available checkpoints for a fine-tuning job
-
-        Args:
-            id (str): Unique identifier of the fine-tune job to list checkpoints for
-
-        Returns:
-            List[FinetuneCheckpoint]: List of available checkpoints
-        """
-        requestor = api_requestor.APIRequestor(
-            client=self._client,
-        )
-
-        response, _, _ = await requestor.arequest(
-            options=TogetherRequest(
-                method="GET",
-                url=f"fine-tunes/{id}/checkpoints",
-            ),
-            stream=False,
-        )
-        assert isinstance(response, TogetherResponse)
-
-        raw_checkpoints = response.data["data"]
-        return _parse_raw_checkpoints(raw_checkpoints, id)
-
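A sketch of inspecting checkpoints with the method above, for example before continuing training from a specific step. The job ID is a placeholder; the `from_checkpoint` format referenced in the comment follows the `create` docstring earlier in this file ({$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}, step optional).

```python
import asyncio

from together import AsyncTogether


async def pick_checkpoint(job_id: str) -> None:
    # Placeholder job ID; list_checkpoints() returns FinetuneCheckpoint objects.
    client = AsyncTogether()
    checkpoints = await client.fine_tuning.list_checkpoints(job_id)
    for ckpt in checkpoints:
        print(ckpt)
    # A chosen step can then be passed to create() as
    # from_checkpoint=f"{job_id}:{step}", per the docstring format above.


asyncio.run(pick_checkpoint("ft-00000000-0000-0000-0000-000000000000"))
```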
-    async def download(
-        self, id: str, *, output: str | None = None, checkpoint_step: int = -1
-    ) -> str:
-        """
-        TODO: Implement async download method
-        """
-
-        raise NotImplementedError(
-            "AsyncFineTuning.download not implemented. "
-            "Please use FineTuning.download function instead."
-        )
-
-    async def get_model_limits(self, *, model: str) -> FinetuneTrainingLimits:
-        """
-        Requests training limits for a specific model
-
-        Args:
-            model (str): Name of the model to get limits for
-
-        Returns:
-            FinetuneTrainingLimits: Object containing training limits for the model
-        """
-
-        requestor = api_requestor.APIRequestor(
-            client=self._client,
-        )
-
-        model_limits_response, _, _ = await requestor.arequest(
-            options=TogetherRequest(
-                method="GET",
-                url="fine-tunes/models/limits",
-                params={"model": model},
-            ),
-            stream=False,
-        )
-
-        model_limits = FinetuneTrainingLimits(**model_limits_response.data)
-
-        return model_limits
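Finally, a hedged sketch of querying per-model training limits with the method above. The model name is a placeholder; `supports_vision` is the field the `create` flow consults before running the price estimation branch.

```python
import asyncio

from together import AsyncTogether


async def main() -> None:
    # Placeholder model name; assumes the resource is exposed as `fine_tuning`.
    client = AsyncTogether()
    limits = await client.fine_tuning.get_model_limits(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference"
    )
    # create() above checks this object (e.g. supports_vision) before
    # deciding whether to run the price estimation step.
    print(limits.supports_vision)


asyncio.run(main())
```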