together 1.4.1__py3-none-any.whl → 1.4.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- together/cli/api/finetune.py +67 -5
- together/client.py +9 -1
- together/constants.py +6 -0
- together/legacy/finetune.py +1 -1
- together/resources/finetune.py +173 -15
- together/types/__init__.py +6 -0
- together/types/chat_completions.py +6 -0
- together/types/endpoints.py +3 -3
- together/types/finetune.py +45 -0
- together/utils/__init__.py +4 -0
- together/utils/api_helpers.py +40 -0
- together/utils/files.py +139 -66
- together/utils/tools.py +53 -2
- {together-1.4.1.dist-info → together-1.4.5.dist-info}/METADATA +93 -23
- {together-1.4.1.dist-info → together-1.4.5.dist-info}/RECORD +18 -18
- {together-1.4.1.dist-info → together-1.4.5.dist-info}/WHEEL +1 -1
- {together-1.4.1.dist-info → together-1.4.5.dist-info}/LICENSE +0 -0
- {together-1.4.1.dist-info → together-1.4.5.dist-info}/entry_points.txt +0 -0
together/cli/api/finetune.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
-
from datetime import datetime
|
|
4
|
+
from datetime import datetime, timezone
|
|
5
5
|
from textwrap import wrap
|
|
6
6
|
from typing import Any, Literal
|
|
7
|
+
import re
|
|
7
8
|
|
|
8
9
|
import click
|
|
9
10
|
from click.core import ParameterSource # type: ignore[attr-defined]
|
|
@@ -17,8 +18,13 @@ from together.utils import (
|
|
|
17
18
|
log_warn,
|
|
18
19
|
log_warn_once,
|
|
19
20
|
parse_timestamp,
|
|
21
|
+
format_timestamp,
|
|
22
|
+
)
|
|
23
|
+
from together.types.finetune import (
|
|
24
|
+
DownloadCheckpointType,
|
|
25
|
+
FinetuneTrainingLimits,
|
|
26
|
+
FinetuneEventType,
|
|
20
27
|
)
|
|
21
|
-
from together.types.finetune import DownloadCheckpointType, FinetuneTrainingLimits
|
|
22
28
|
|
|
23
29
|
|
|
24
30
|
_CONFIRMATION_MESSAGE = (
|
|
@@ -104,6 +110,18 @@ def fine_tuning(ctx: click.Context) -> None:
|
|
|
104
110
|
default="all-linear",
|
|
105
111
|
help="Trainable modules for LoRA adapters. For example, 'all-linear', 'q_proj,v_proj'",
|
|
106
112
|
)
|
|
113
|
+
@click.option(
|
|
114
|
+
"--training-method",
|
|
115
|
+
type=click.Choice(["sft", "dpo"]),
|
|
116
|
+
default="sft",
|
|
117
|
+
help="Training method to use. Options: sft (supervised fine-tuning), dpo (Direct Preference Optimization)",
|
|
118
|
+
)
|
|
119
|
+
@click.option(
|
|
120
|
+
"--dpo-beta",
|
|
121
|
+
type=float,
|
|
122
|
+
default=0.1,
|
|
123
|
+
help="Beta parameter for DPO training (only used when '--training-method' is 'dpo')",
|
|
124
|
+
)
|
|
107
125
|
@click.option(
|
|
108
126
|
"--suffix", type=str, default=None, help="Suffix for the fine-tuned model name"
|
|
109
127
|
)
|
|
@@ -126,6 +144,14 @@ def fine_tuning(ctx: click.Context) -> None:
|
|
|
126
144
|
help="Whether to mask the user messages in conversational data or prompts in instruction data. "
|
|
127
145
|
"`auto` will automatically determine whether to mask the inputs based on the data format.",
|
|
128
146
|
)
|
|
147
|
+
@click.option(
|
|
148
|
+
"--from-checkpoint",
|
|
149
|
+
type=str,
|
|
150
|
+
default=None,
|
|
151
|
+
help="The checkpoint identifier to continue training from a previous fine-tuning job. "
|
|
152
|
+
"The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}. "
|
|
153
|
+
"The step value is optional, without it the final checkpoint will be used.",
|
|
154
|
+
)
|
|
129
155
|
def create(
|
|
130
156
|
ctx: click.Context,
|
|
131
157
|
training_file: str,
|
|
@@ -152,6 +178,9 @@ def create(
|
|
|
152
178
|
wandb_name: str,
|
|
153
179
|
confirm: bool,
|
|
154
180
|
train_on_inputs: bool | Literal["auto"],
|
|
181
|
+
training_method: str,
|
|
182
|
+
dpo_beta: float,
|
|
183
|
+
from_checkpoint: str,
|
|
155
184
|
) -> None:
|
|
156
185
|
"""Start fine-tuning"""
|
|
157
186
|
client: Together = ctx.obj
|
|
@@ -180,6 +209,9 @@ def create(
|
|
|
180
209
|
wandb_project_name=wandb_project_name,
|
|
181
210
|
wandb_name=wandb_name,
|
|
182
211
|
train_on_inputs=train_on_inputs,
|
|
212
|
+
training_method=training_method,
|
|
213
|
+
dpo_beta=dpo_beta,
|
|
214
|
+
from_checkpoint=from_checkpoint,
|
|
183
215
|
)
|
|
184
216
|
|
|
185
217
|
model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(
|
|
@@ -261,7 +293,9 @@ def list(ctx: click.Context) -> None:
|
|
|
261
293
|
|
|
262
294
|
response.data = response.data or []
|
|
263
295
|
|
|
264
|
-
|
|
296
|
+
# Use a default datetime for None values to make sure the key function always returns a comparable value
|
|
297
|
+
epoch_start = datetime.fromtimestamp(0, tz=timezone.utc)
|
|
298
|
+
response.data.sort(key=lambda x: parse_timestamp(x.created_at or "") or epoch_start)
|
|
265
299
|
|
|
266
300
|
display_list = []
|
|
267
301
|
for i in response.data:
|
|
@@ -344,6 +378,34 @@ def list_events(ctx: click.Context, fine_tune_id: str) -> None:
|
|
|
344
378
|
click.echo(table)
|
|
345
379
|
|
|
346
380
|
|
|
381
|
+
@fine_tuning.command()
|
|
382
|
+
@click.pass_context
|
|
383
|
+
@click.argument("fine_tune_id", type=str, required=True)
|
|
384
|
+
def list_checkpoints(ctx: click.Context, fine_tune_id: str) -> None:
|
|
385
|
+
"""List available checkpoints for a fine-tuning job"""
|
|
386
|
+
client: Together = ctx.obj
|
|
387
|
+
|
|
388
|
+
checkpoints = client.fine_tuning.list_checkpoints(fine_tune_id)
|
|
389
|
+
|
|
390
|
+
display_list = []
|
|
391
|
+
for checkpoint in checkpoints:
|
|
392
|
+
display_list.append(
|
|
393
|
+
{
|
|
394
|
+
"Type": checkpoint.type,
|
|
395
|
+
"Timestamp": format_timestamp(checkpoint.timestamp),
|
|
396
|
+
"Name": checkpoint.name,
|
|
397
|
+
}
|
|
398
|
+
)
|
|
399
|
+
|
|
400
|
+
if display_list:
|
|
401
|
+
click.echo(f"Job {fine_tune_id} contains the following checkpoints:")
|
|
402
|
+
table = tabulate(display_list, headers="keys", tablefmt="grid")
|
|
403
|
+
click.echo(table)
|
|
404
|
+
click.echo("\nTo download a checkpoint, use `together fine-tuning download`")
|
|
405
|
+
else:
|
|
406
|
+
click.echo(f"No checkpoints found for job {fine_tune_id}")
|
|
407
|
+
|
|
408
|
+
|
|
347
409
|
@fine_tuning.command()
|
|
348
410
|
@click.pass_context
|
|
349
411
|
@click.argument("fine_tune_id", type=str, required=True)
|
|
@@ -358,7 +420,7 @@ def list_events(ctx: click.Context, fine_tune_id: str) -> None:
|
|
|
358
420
|
"--checkpoint-step",
|
|
359
421
|
type=int,
|
|
360
422
|
required=False,
|
|
361
|
-
default
|
|
423
|
+
default=None,
|
|
362
424
|
help="Download fine-tuning checkpoint. Defaults to latest.",
|
|
363
425
|
)
|
|
364
426
|
@click.option(
|
|
@@ -372,7 +434,7 @@ def download(
|
|
|
372
434
|
ctx: click.Context,
|
|
373
435
|
fine_tune_id: str,
|
|
374
436
|
output_dir: str,
|
|
375
|
-
checkpoint_step: int,
|
|
437
|
+
checkpoint_step: int | None,
|
|
376
438
|
checkpoint_type: DownloadCheckpointType,
|
|
377
439
|
) -> None:
|
|
378
440
|
"""Download fine-tuning checkpoint"""
|
together/client.py
CHANGED
|
@@ -1,13 +1,15 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
|
-
|
|
4
|
+
import sys
|
|
5
|
+
from typing import Dict, TYPE_CHECKING
|
|
5
6
|
|
|
6
7
|
from together import resources
|
|
7
8
|
from together.constants import BASE_URL, MAX_RETRIES, TIMEOUT_SECS
|
|
8
9
|
from together.error import AuthenticationError
|
|
9
10
|
from together.types import TogetherClient
|
|
10
11
|
from together.utils import enforce_trailing_slash
|
|
12
|
+
from together.utils.api_helpers import get_google_colab_secret
|
|
11
13
|
|
|
12
14
|
|
|
13
15
|
class Together:
|
|
@@ -44,6 +46,9 @@ class Together:
|
|
|
44
46
|
if not api_key:
|
|
45
47
|
api_key = os.environ.get("TOGETHER_API_KEY")
|
|
46
48
|
|
|
49
|
+
if not api_key and "google.colab" in sys.modules:
|
|
50
|
+
api_key = get_google_colab_secret("TOGETHER_API_KEY")
|
|
51
|
+
|
|
47
52
|
if not api_key:
|
|
48
53
|
raise AuthenticationError(
|
|
49
54
|
"The api_key client option must be set either by passing api_key to the client or by setting the "
|
|
@@ -117,6 +122,9 @@ class AsyncTogether:
|
|
|
117
122
|
if not api_key:
|
|
118
123
|
api_key = os.environ.get("TOGETHER_API_KEY")
|
|
119
124
|
|
|
125
|
+
if not api_key and "google.colab" in sys.modules:
|
|
126
|
+
api_key = get_google_colab_secret("TOGETHER_API_KEY")
|
|
127
|
+
|
|
120
128
|
if not api_key:
|
|
121
129
|
raise AuthenticationError(
|
|
122
130
|
"The api_key client option must be set either by passing api_key to the client or by setting the "
|
together/constants.py
CHANGED
|
@@ -39,12 +39,18 @@ class DatasetFormat(enum.Enum):
|
|
|
39
39
|
GENERAL = "general"
|
|
40
40
|
CONVERSATION = "conversation"
|
|
41
41
|
INSTRUCTION = "instruction"
|
|
42
|
+
PREFERENCE_OPENAI = "preference_openai"
|
|
42
43
|
|
|
43
44
|
|
|
44
45
|
JSONL_REQUIRED_COLUMNS_MAP = {
|
|
45
46
|
DatasetFormat.GENERAL: ["text"],
|
|
46
47
|
DatasetFormat.CONVERSATION: ["messages"],
|
|
47
48
|
DatasetFormat.INSTRUCTION: ["prompt", "completion"],
|
|
49
|
+
DatasetFormat.PREFERENCE_OPENAI: [
|
|
50
|
+
"input",
|
|
51
|
+
"preferred_output",
|
|
52
|
+
"non_preferred_output",
|
|
53
|
+
],
|
|
48
54
|
}
|
|
49
55
|
REQUIRED_COLUMNS_MESSAGE = ["role", "content"]
|
|
50
56
|
POSSIBLE_ROLES_CONVERSATION = ["system", "user", "assistant"]
|
together/legacy/finetune.py
CHANGED
together/resources/finetune.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
import re
|
|
3
4
|
from pathlib import Path
|
|
4
|
-
from typing import Literal
|
|
5
|
+
from typing import Literal, List
|
|
5
6
|
|
|
6
7
|
from rich import print as rprint
|
|
7
8
|
|
|
@@ -22,9 +23,28 @@ from together.types import (
|
|
|
22
23
|
TrainingType,
|
|
23
24
|
FinetuneLRScheduler,
|
|
24
25
|
FinetuneLinearLRSchedulerArgs,
|
|
26
|
+
TrainingMethodDPO,
|
|
27
|
+
TrainingMethodSFT,
|
|
28
|
+
FinetuneCheckpoint,
|
|
25
29
|
)
|
|
26
|
-
from together.types.finetune import
|
|
27
|
-
|
|
30
|
+
from together.types.finetune import (
|
|
31
|
+
DownloadCheckpointType,
|
|
32
|
+
FinetuneEventType,
|
|
33
|
+
FinetuneEvent,
|
|
34
|
+
)
|
|
35
|
+
from together.utils import (
|
|
36
|
+
log_warn_once,
|
|
37
|
+
normalize_key,
|
|
38
|
+
get_event_step,
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
_FT_JOB_WITH_STEP_REGEX = r"^ft-[\dabcdef-]+:\d+$"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
AVAILABLE_TRAINING_METHODS = {
|
|
45
|
+
TrainingMethodSFT().method,
|
|
46
|
+
TrainingMethodDPO().method,
|
|
47
|
+
}
|
|
28
48
|
|
|
29
49
|
|
|
30
50
|
def createFinetuneRequest(
|
|
@@ -52,7 +72,11 @@ def createFinetuneRequest(
|
|
|
52
72
|
wandb_project_name: str | None = None,
|
|
53
73
|
wandb_name: str | None = None,
|
|
54
74
|
train_on_inputs: bool | Literal["auto"] = "auto",
|
|
75
|
+
training_method: str = "sft",
|
|
76
|
+
dpo_beta: float | None = None,
|
|
77
|
+
from_checkpoint: str | None = None,
|
|
55
78
|
) -> FinetuneRequest:
|
|
79
|
+
|
|
56
80
|
if batch_size == "max":
|
|
57
81
|
log_warn_once(
|
|
58
82
|
"Starting from together>=1.3.0, "
|
|
@@ -100,11 +124,20 @@ def createFinetuneRequest(
|
|
|
100
124
|
if weight_decay is not None and (weight_decay < 0):
|
|
101
125
|
raise ValueError("Weight decay should be non-negative")
|
|
102
126
|
|
|
127
|
+
if training_method not in AVAILABLE_TRAINING_METHODS:
|
|
128
|
+
raise ValueError(
|
|
129
|
+
f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
|
|
130
|
+
)
|
|
131
|
+
|
|
103
132
|
lrScheduler = FinetuneLRScheduler(
|
|
104
133
|
lr_scheduler_type="linear",
|
|
105
134
|
lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
|
|
106
135
|
)
|
|
107
136
|
|
|
137
|
+
training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
|
|
138
|
+
if training_method == "dpo":
|
|
139
|
+
training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
|
|
140
|
+
|
|
108
141
|
finetune_request = FinetuneRequest(
|
|
109
142
|
model=model,
|
|
110
143
|
training_file=training_file,
|
|
@@ -125,11 +158,77 @@ def createFinetuneRequest(
|
|
|
125
158
|
wandb_project_name=wandb_project_name,
|
|
126
159
|
wandb_name=wandb_name,
|
|
127
160
|
train_on_inputs=train_on_inputs,
|
|
161
|
+
training_method=training_method_cls,
|
|
162
|
+
from_checkpoint=from_checkpoint,
|
|
128
163
|
)
|
|
129
164
|
|
|
130
165
|
return finetune_request
|
|
131
166
|
|
|
132
167
|
|
|
168
|
+
def _process_checkpoints_from_events(
|
|
169
|
+
events: List[FinetuneEvent], id: str
|
|
170
|
+
) -> List[FinetuneCheckpoint]:
|
|
171
|
+
"""
|
|
172
|
+
Helper function to process events and create checkpoint list.
|
|
173
|
+
|
|
174
|
+
Args:
|
|
175
|
+
events (List[FinetuneEvent]): List of fine-tune events to process
|
|
176
|
+
id (str): Fine-tune job ID
|
|
177
|
+
|
|
178
|
+
Returns:
|
|
179
|
+
List[FinetuneCheckpoint]: List of available checkpoints
|
|
180
|
+
"""
|
|
181
|
+
checkpoints: List[FinetuneCheckpoint] = []
|
|
182
|
+
|
|
183
|
+
for event in events:
|
|
184
|
+
event_type = event.type
|
|
185
|
+
|
|
186
|
+
if event_type == FinetuneEventType.CHECKPOINT_SAVE:
|
|
187
|
+
step = get_event_step(event)
|
|
188
|
+
checkpoint_name = f"{id}:{step}" if step is not None else id
|
|
189
|
+
|
|
190
|
+
checkpoints.append(
|
|
191
|
+
FinetuneCheckpoint(
|
|
192
|
+
type=(
|
|
193
|
+
f"Intermediate (step {step})"
|
|
194
|
+
if step is not None
|
|
195
|
+
else "Intermediate"
|
|
196
|
+
),
|
|
197
|
+
timestamp=event.created_at,
|
|
198
|
+
name=checkpoint_name,
|
|
199
|
+
)
|
|
200
|
+
)
|
|
201
|
+
elif event_type == FinetuneEventType.JOB_COMPLETE:
|
|
202
|
+
if hasattr(event, "model_path"):
|
|
203
|
+
checkpoints.append(
|
|
204
|
+
FinetuneCheckpoint(
|
|
205
|
+
type=(
|
|
206
|
+
"Final Merged"
|
|
207
|
+
if hasattr(event, "adapter_path")
|
|
208
|
+
else "Final"
|
|
209
|
+
),
|
|
210
|
+
timestamp=event.created_at,
|
|
211
|
+
name=id,
|
|
212
|
+
)
|
|
213
|
+
)
|
|
214
|
+
|
|
215
|
+
if hasattr(event, "adapter_path"):
|
|
216
|
+
checkpoints.append(
|
|
217
|
+
FinetuneCheckpoint(
|
|
218
|
+
type=(
|
|
219
|
+
"Final Adapter" if hasattr(event, "model_path") else "Final"
|
|
220
|
+
),
|
|
221
|
+
timestamp=event.created_at,
|
|
222
|
+
name=id,
|
|
223
|
+
)
|
|
224
|
+
)
|
|
225
|
+
|
|
226
|
+
# Sort by timestamp (newest first)
|
|
227
|
+
checkpoints.sort(key=lambda x: x.timestamp, reverse=True)
|
|
228
|
+
|
|
229
|
+
return checkpoints
|
|
230
|
+
|
|
231
|
+
|
|
133
232
|
class FineTuning:
|
|
134
233
|
def __init__(self, client: TogetherClient) -> None:
|
|
135
234
|
self._client = client
|
|
@@ -162,6 +261,9 @@ class FineTuning:
|
|
|
162
261
|
verbose: bool = False,
|
|
163
262
|
model_limits: FinetuneTrainingLimits | None = None,
|
|
164
263
|
train_on_inputs: bool | Literal["auto"] = "auto",
|
|
264
|
+
training_method: str = "sft",
|
|
265
|
+
dpo_beta: float | None = None,
|
|
266
|
+
from_checkpoint: str | None = None,
|
|
165
267
|
) -> FinetuneResponse:
|
|
166
268
|
"""
|
|
167
269
|
Method to initiate a fine-tuning job
|
|
@@ -207,6 +309,12 @@ class FineTuning:
|
|
|
207
309
|
For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
|
|
208
310
|
(Instruction format), inputs will be masked.
|
|
209
311
|
Defaults to "auto".
|
|
312
|
+
training_method (str, optional): Training method. Defaults to "sft".
|
|
313
|
+
Supported methods: "sft", "dpo".
|
|
314
|
+
dpo_beta (float, optional): DPO beta parameter. Defaults to None.
|
|
315
|
+
from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
|
|
316
|
+
The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
|
|
317
|
+
The step value is optional, without it the final checkpoint will be used.
|
|
210
318
|
|
|
211
319
|
Returns:
|
|
212
320
|
FinetuneResponse: Object containing information about fine-tuning job.
|
|
@@ -218,7 +326,6 @@ class FineTuning:
|
|
|
218
326
|
|
|
219
327
|
if model_limits is None:
|
|
220
328
|
model_limits = self.get_model_limits(model=model)
|
|
221
|
-
|
|
222
329
|
finetune_request = createFinetuneRequest(
|
|
223
330
|
model_limits=model_limits,
|
|
224
331
|
training_file=training_file,
|
|
@@ -244,6 +351,9 @@ class FineTuning:
|
|
|
244
351
|
wandb_project_name=wandb_project_name,
|
|
245
352
|
wandb_name=wandb_name,
|
|
246
353
|
train_on_inputs=train_on_inputs,
|
|
354
|
+
training_method=training_method,
|
|
355
|
+
dpo_beta=dpo_beta,
|
|
356
|
+
from_checkpoint=from_checkpoint,
|
|
247
357
|
)
|
|
248
358
|
|
|
249
359
|
if verbose:
|
|
@@ -261,7 +371,6 @@ class FineTuning:
|
|
|
261
371
|
),
|
|
262
372
|
stream=False,
|
|
263
373
|
)
|
|
264
|
-
|
|
265
374
|
assert isinstance(response, TogetherResponse)
|
|
266
375
|
|
|
267
376
|
return FinetuneResponse(**response.data)
|
|
@@ -366,17 +475,29 @@ class FineTuning:
|
|
|
366
475
|
),
|
|
367
476
|
stream=False,
|
|
368
477
|
)
|
|
369
|
-
|
|
370
478
|
assert isinstance(response, TogetherResponse)
|
|
371
479
|
|
|
372
480
|
return FinetuneListEvents(**response.data)
|
|
373
481
|
|
|
482
|
+
def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
|
|
483
|
+
"""
|
|
484
|
+
List available checkpoints for a fine-tuning job
|
|
485
|
+
|
|
486
|
+
Args:
|
|
487
|
+
id (str): Unique identifier of the fine-tune job to list checkpoints for
|
|
488
|
+
|
|
489
|
+
Returns:
|
|
490
|
+
List[FinetuneCheckpoint]: List of available checkpoints
|
|
491
|
+
"""
|
|
492
|
+
events = self.list_events(id).data or []
|
|
493
|
+
return _process_checkpoints_from_events(events, id)
|
|
494
|
+
|
|
374
495
|
def download(
|
|
375
496
|
self,
|
|
376
497
|
id: str,
|
|
377
498
|
*,
|
|
378
499
|
output: Path | str | None = None,
|
|
379
|
-
checkpoint_step: int =
|
|
500
|
+
checkpoint_step: int | None = None,
|
|
380
501
|
checkpoint_type: DownloadCheckpointType = DownloadCheckpointType.DEFAULT,
|
|
381
502
|
) -> FinetuneDownloadResult:
|
|
382
503
|
"""
|
|
@@ -397,9 +518,19 @@ class FineTuning:
|
|
|
397
518
|
FinetuneDownloadResult: Object containing downloaded model metadata
|
|
398
519
|
"""
|
|
399
520
|
|
|
521
|
+
if re.match(_FT_JOB_WITH_STEP_REGEX, id) is not None:
|
|
522
|
+
if checkpoint_step is None:
|
|
523
|
+
checkpoint_step = int(id.split(":")[1])
|
|
524
|
+
id = id.split(":")[0]
|
|
525
|
+
else:
|
|
526
|
+
raise ValueError(
|
|
527
|
+
"Fine-tuning job ID {id} contains a colon to specify the step to download, but `checkpoint_step` "
|
|
528
|
+
"was also set. Remove one of the step specifiers to proceed."
|
|
529
|
+
)
|
|
530
|
+
|
|
400
531
|
url = f"finetune/download?ft_id={id}"
|
|
401
532
|
|
|
402
|
-
if checkpoint_step
|
|
533
|
+
if checkpoint_step is not None:
|
|
403
534
|
url += f"&checkpoint_step={checkpoint_step}"
|
|
404
535
|
|
|
405
536
|
ft_job = self.retrieve(id)
|
|
@@ -503,6 +634,9 @@ class AsyncFineTuning:
|
|
|
503
634
|
verbose: bool = False,
|
|
504
635
|
model_limits: FinetuneTrainingLimits | None = None,
|
|
505
636
|
train_on_inputs: bool | Literal["auto"] = "auto",
|
|
637
|
+
training_method: str = "sft",
|
|
638
|
+
dpo_beta: float | None = None,
|
|
639
|
+
from_checkpoint: str | None = None,
|
|
506
640
|
) -> FinetuneResponse:
|
|
507
641
|
"""
|
|
508
642
|
Async method to initiate a fine-tuning job
|
|
@@ -548,6 +682,12 @@ class AsyncFineTuning:
|
|
|
548
682
|
For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
|
|
549
683
|
(Instruction format), inputs will be masked.
|
|
550
684
|
Defaults to "auto".
|
|
685
|
+
training_method (str, optional): Training method. Defaults to "sft".
|
|
686
|
+
Supported methods: "sft", "dpo".
|
|
687
|
+
dpo_beta (float, optional): DPO beta parameter. Defaults to None.
|
|
688
|
+
from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
|
|
689
|
+
The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
|
|
690
|
+
The step value is optional, without it the final checkpoint will be used.
|
|
551
691
|
|
|
552
692
|
Returns:
|
|
553
693
|
FinetuneResponse: Object containing information about fine-tuning job.
|
|
@@ -585,6 +725,9 @@ class AsyncFineTuning:
|
|
|
585
725
|
wandb_project_name=wandb_project_name,
|
|
586
726
|
wandb_name=wandb_name,
|
|
587
727
|
train_on_inputs=train_on_inputs,
|
|
728
|
+
training_method=training_method,
|
|
729
|
+
dpo_beta=dpo_beta,
|
|
730
|
+
from_checkpoint=from_checkpoint,
|
|
588
731
|
)
|
|
589
732
|
|
|
590
733
|
if verbose:
|
|
@@ -687,30 +830,45 @@ class AsyncFineTuning:
|
|
|
687
830
|
|
|
688
831
|
async def list_events(self, id: str) -> FinetuneListEvents:
|
|
689
832
|
"""
|
|
690
|
-
|
|
833
|
+
List fine-tuning events
|
|
691
834
|
|
|
692
835
|
Args:
|
|
693
|
-
id (str):
|
|
836
|
+
id (str): Unique identifier of the fine-tune job to list events for
|
|
694
837
|
|
|
695
838
|
Returns:
|
|
696
|
-
FinetuneListEvents: Object containing list of fine-tune events
|
|
839
|
+
FinetuneListEvents: Object containing list of fine-tune job events
|
|
697
840
|
"""
|
|
698
841
|
|
|
699
842
|
requestor = api_requestor.APIRequestor(
|
|
700
843
|
client=self._client,
|
|
701
844
|
)
|
|
702
845
|
|
|
703
|
-
|
|
846
|
+
events_response, _, _ = await requestor.arequest(
|
|
704
847
|
options=TogetherRequest(
|
|
705
848
|
method="GET",
|
|
706
|
-
url=f"fine-tunes/{id}/events",
|
|
849
|
+
url=f"fine-tunes/{normalize_key(id)}/events",
|
|
707
850
|
),
|
|
708
851
|
stream=False,
|
|
709
852
|
)
|
|
710
853
|
|
|
711
|
-
|
|
854
|
+
# FIXME: API returns "data" field with no object type (should be "list")
|
|
855
|
+
events_list = FinetuneListEvents(object="list", **events_response.data)
|
|
712
856
|
|
|
713
|
-
return
|
|
857
|
+
return events_list
|
|
858
|
+
|
|
859
|
+
async def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
|
|
860
|
+
"""
|
|
861
|
+
List available checkpoints for a fine-tuning job
|
|
862
|
+
|
|
863
|
+
Args:
|
|
864
|
+
id (str): Unique identifier of the fine-tune job to list checkpoints for
|
|
865
|
+
|
|
866
|
+
Returns:
|
|
867
|
+
List[FinetuneCheckpoint]: Object containing list of available checkpoints
|
|
868
|
+
"""
|
|
869
|
+
events_list = await self.list_events(id)
|
|
870
|
+
events = events_list.data or []
|
|
871
|
+
return _process_checkpoints_from_events(events, id)
|
|
714
872
|
|
|
715
873
|
async def download(
|
|
716
874
|
self, id: str, *, output: str | None = None, checkpoint_step: int = -1
|
together/types/__init__.py
CHANGED
|
@@ -31,6 +31,9 @@ from together.types.files import (
|
|
|
31
31
|
FileType,
|
|
32
32
|
)
|
|
33
33
|
from together.types.finetune import (
|
|
34
|
+
TrainingMethodDPO,
|
|
35
|
+
TrainingMethodSFT,
|
|
36
|
+
FinetuneCheckpoint,
|
|
34
37
|
FinetuneDownloadResult,
|
|
35
38
|
FinetuneLinearLRSchedulerArgs,
|
|
36
39
|
FinetuneList,
|
|
@@ -59,6 +62,7 @@ __all__ = [
|
|
|
59
62
|
"ChatCompletionResponse",
|
|
60
63
|
"EmbeddingRequest",
|
|
61
64
|
"EmbeddingResponse",
|
|
65
|
+
"FinetuneCheckpoint",
|
|
62
66
|
"FinetuneRequest",
|
|
63
67
|
"FinetuneResponse",
|
|
64
68
|
"FinetuneList",
|
|
@@ -79,6 +83,8 @@ __all__ = [
|
|
|
79
83
|
"TrainingType",
|
|
80
84
|
"FullTrainingType",
|
|
81
85
|
"LoRATrainingType",
|
|
86
|
+
"TrainingMethodDPO",
|
|
87
|
+
"TrainingMethodSFT",
|
|
82
88
|
"RerankRequest",
|
|
83
89
|
"RerankResponse",
|
|
84
90
|
"FinetuneTrainingLimits",
|
|
@@ -44,16 +44,22 @@ class ToolCalls(BaseModel):
|
|
|
44
44
|
class ChatCompletionMessageContentType(str, Enum):
|
|
45
45
|
TEXT = "text"
|
|
46
46
|
IMAGE_URL = "image_url"
|
|
47
|
+
VIDEO_URL = "video_url"
|
|
47
48
|
|
|
48
49
|
|
|
49
50
|
class ChatCompletionMessageContentImageURL(BaseModel):
|
|
50
51
|
url: str
|
|
51
52
|
|
|
52
53
|
|
|
54
|
+
class ChatCompletionMessageContentVideoURL(BaseModel):
|
|
55
|
+
url: str
|
|
56
|
+
|
|
57
|
+
|
|
53
58
|
class ChatCompletionMessageContent(BaseModel):
|
|
54
59
|
type: ChatCompletionMessageContentType
|
|
55
60
|
text: str | None = None
|
|
56
61
|
image_url: ChatCompletionMessageContentImageURL | None = None
|
|
62
|
+
video_url: ChatCompletionMessageContentVideoURL | None = None
|
|
57
63
|
|
|
58
64
|
|
|
59
65
|
class ChatCompletionMessage(BaseModel):
|
together/types/endpoints.py
CHANGED
|
@@ -86,9 +86,9 @@ class BaseEndpoint(TogetherJSONModel):
|
|
|
86
86
|
model: str = Field(description="The model deployed on this endpoint")
|
|
87
87
|
type: str = Field(description="The type of endpoint")
|
|
88
88
|
owner: str = Field(description="The owner of this endpoint")
|
|
89
|
-
state: Literal[
|
|
90
|
-
|
|
91
|
-
)
|
|
89
|
+
state: Literal[
|
|
90
|
+
"PENDING", "STARTING", "STARTED", "STOPPING", "STOPPED", "FAILED", "ERROR"
|
|
91
|
+
] = Field(description="Current state of the endpoint")
|
|
92
92
|
created_at: datetime = Field(description="Timestamp when the endpoint was created")
|
|
93
93
|
|
|
94
94
|
|
together/types/finetune.py
CHANGED
|
@@ -135,6 +135,31 @@ class LoRATrainingType(TrainingType):
|
|
|
135
135
|
type: str = "Lora"
|
|
136
136
|
|
|
137
137
|
|
|
138
|
+
class TrainingMethod(BaseModel):
|
|
139
|
+
"""
|
|
140
|
+
Training method type
|
|
141
|
+
"""
|
|
142
|
+
|
|
143
|
+
method: str
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class TrainingMethodSFT(TrainingMethod):
|
|
147
|
+
"""
|
|
148
|
+
Training method type for SFT training
|
|
149
|
+
"""
|
|
150
|
+
|
|
151
|
+
method: Literal["sft"] = "sft"
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class TrainingMethodDPO(TrainingMethod):
|
|
155
|
+
"""
|
|
156
|
+
Training method type for DPO training
|
|
157
|
+
"""
|
|
158
|
+
|
|
159
|
+
method: Literal["dpo"] = "dpo"
|
|
160
|
+
dpo_beta: float | None = None
|
|
161
|
+
|
|
162
|
+
|
|
138
163
|
class FinetuneRequest(BaseModel):
|
|
139
164
|
"""
|
|
140
165
|
Fine-tune request type
|
|
@@ -178,6 +203,12 @@ class FinetuneRequest(BaseModel):
|
|
|
178
203
|
training_type: FullTrainingType | LoRATrainingType | None = None
|
|
179
204
|
# train on inputs
|
|
180
205
|
train_on_inputs: StrictBool | Literal["auto"] = "auto"
|
|
206
|
+
# training method
|
|
207
|
+
training_method: TrainingMethodSFT | TrainingMethodDPO = Field(
|
|
208
|
+
default_factory=TrainingMethodSFT
|
|
209
|
+
)
|
|
210
|
+
# from step
|
|
211
|
+
from_checkpoint: str | None = None
|
|
181
212
|
|
|
182
213
|
|
|
183
214
|
class FinetuneResponse(BaseModel):
|
|
@@ -256,6 +287,7 @@ class FinetuneResponse(BaseModel):
|
|
|
256
287
|
training_file_num_lines: int | None = Field(None, alias="TrainingFileNumLines")
|
|
257
288
|
training_file_size: int | None = Field(None, alias="TrainingFileSize")
|
|
258
289
|
train_on_inputs: StrictBool | Literal["auto"] | None = "auto"
|
|
290
|
+
from_checkpoint: str | None = None
|
|
259
291
|
|
|
260
292
|
@field_validator("training_type")
|
|
261
293
|
@classmethod
|
|
@@ -320,3 +352,16 @@ class FinetuneLRScheduler(BaseModel):
|
|
|
320
352
|
|
|
321
353
|
class FinetuneLinearLRSchedulerArgs(BaseModel):
|
|
322
354
|
min_lr_ratio: float | None = 0.0
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
class FinetuneCheckpoint(BaseModel):
|
|
358
|
+
"""
|
|
359
|
+
Fine-tuning checkpoint information
|
|
360
|
+
"""
|
|
361
|
+
|
|
362
|
+
# checkpoint type (e.g. "Intermediate", "Final", "Final Merged", "Final Adapter")
|
|
363
|
+
type: str
|
|
364
|
+
# timestamp when the checkpoint was created
|
|
365
|
+
timestamp: str
|
|
366
|
+
# checkpoint name/identifier
|
|
367
|
+
name: str
|
together/utils/__init__.py
CHANGED
|
@@ -8,6 +8,8 @@ from together.utils.tools import (
|
|
|
8
8
|
finetune_price_to_dollars,
|
|
9
9
|
normalize_key,
|
|
10
10
|
parse_timestamp,
|
|
11
|
+
format_timestamp,
|
|
12
|
+
get_event_step,
|
|
11
13
|
)
|
|
12
14
|
|
|
13
15
|
|
|
@@ -23,6 +25,8 @@ __all__ = [
|
|
|
23
25
|
"enforce_trailing_slash",
|
|
24
26
|
"normalize_key",
|
|
25
27
|
"parse_timestamp",
|
|
28
|
+
"format_timestamp",
|
|
29
|
+
"get_event_step",
|
|
26
30
|
"finetune_price_to_dollars",
|
|
27
31
|
"convert_bytes",
|
|
28
32
|
"convert_unix_timestamp",
|