together 1.4.0__py3-none-any.whl → 1.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- together/abstract/api_requestor.py +7 -9
- together/cli/api/endpoints.py +415 -0
- together/cli/api/finetune.py +67 -5
- together/cli/cli.py +2 -0
- together/client.py +1 -0
- together/constants.py +6 -0
- together/error.py +3 -0
- together/legacy/finetune.py +1 -1
- together/resources/__init__.py +4 -1
- together/resources/endpoints.py +488 -0
- together/resources/finetune.py +173 -15
- together/types/__init__.py +25 -20
- together/types/chat_completions.py +6 -0
- together/types/endpoints.py +123 -0
- together/types/finetune.py +45 -0
- together/utils/__init__.py +4 -0
- together/utils/files.py +139 -66
- together/utils/tools.py +53 -2
- {together-1.4.0.dist-info → together-1.4.4.dist-info}/METADATA +93 -23
- {together-1.4.0.dist-info → together-1.4.4.dist-info}/RECORD +23 -20
- {together-1.4.0.dist-info → together-1.4.4.dist-info}/WHEEL +1 -1
- {together-1.4.0.dist-info → together-1.4.4.dist-info}/LICENSE +0 -0
- {together-1.4.0.dist-info → together-1.4.4.dist-info}/entry_points.txt +0 -0
together/resources/finetune.py
CHANGED
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
+import re
 from pathlib import Path
-from typing import Literal
+from typing import Literal, List
 
 from rich import print as rprint
 
@@ -22,9 +23,28 @@ from together.types import (
     TrainingType,
     FinetuneLRScheduler,
     FinetuneLinearLRSchedulerArgs,
+    TrainingMethodDPO,
+    TrainingMethodSFT,
+    FinetuneCheckpoint,
 )
-from together.types.finetune import DownloadCheckpointType
-from together.utils import log_warn_once, normalize_key
+from together.types.finetune import (
+    DownloadCheckpointType,
+    FinetuneEventType,
+    FinetuneEvent,
+)
+from together.utils import (
+    log_warn_once,
+    normalize_key,
+    get_event_step,
+)
+
+_FT_JOB_WITH_STEP_REGEX = r"^ft-[\dabcdef-]+:\d+$"
+
+
+AVAILABLE_TRAINING_METHODS = {
+    TrainingMethodSFT().method,
+    TrainingMethodDPO().method,
+}
 
 
 def createFinetuneRequest(
@@ -52,7 +72,11 @@ def createFinetuneRequest(
     wandb_project_name: str | None = None,
     wandb_name: str | None = None,
     train_on_inputs: bool | Literal["auto"] = "auto",
+    training_method: str = "sft",
+    dpo_beta: float | None = None,
+    from_checkpoint: str | None = None,
 ) -> FinetuneRequest:
+
     if batch_size == "max":
         log_warn_once(
             "Starting from together>=1.3.0, "
@@ -100,11 +124,20 @@ def createFinetuneRequest(
     if weight_decay is not None and (weight_decay < 0):
         raise ValueError("Weight decay should be non-negative")
 
+    if training_method not in AVAILABLE_TRAINING_METHODS:
+        raise ValueError(
+            f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
+        )
+
     lrScheduler = FinetuneLRScheduler(
         lr_scheduler_type="linear",
         lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
     )
 
+    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
+    if training_method == "dpo":
+        training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
+
     finetune_request = FinetuneRequest(
         model=model,
         training_file=training_file,
@@ -125,11 +158,77 @@ def createFinetuneRequest(
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
         train_on_inputs=train_on_inputs,
+        training_method=training_method_cls,
+        from_checkpoint=from_checkpoint,
     )
 
     return finetune_request
 
 
+def _process_checkpoints_from_events(
+    events: List[FinetuneEvent], id: str
+) -> List[FinetuneCheckpoint]:
+    """
+    Helper function to process events and create checkpoint list.
+
+    Args:
+        events (List[FinetuneEvent]): List of fine-tune events to process
+        id (str): Fine-tune job ID
+
+    Returns:
+        List[FinetuneCheckpoint]: List of available checkpoints
+    """
+    checkpoints: List[FinetuneCheckpoint] = []
+
+    for event in events:
+        event_type = event.type
+
+        if event_type == FinetuneEventType.CHECKPOINT_SAVE:
+            step = get_event_step(event)
+            checkpoint_name = f"{id}:{step}" if step is not None else id
+
+            checkpoints.append(
+                FinetuneCheckpoint(
+                    type=(
+                        f"Intermediate (step {step})"
+                        if step is not None
+                        else "Intermediate"
+                    ),
+                    timestamp=event.created_at,
+                    name=checkpoint_name,
+                )
+            )
+        elif event_type == FinetuneEventType.JOB_COMPLETE:
+            if hasattr(event, "model_path"):
+                checkpoints.append(
+                    FinetuneCheckpoint(
+                        type=(
+                            "Final Merged"
+                            if hasattr(event, "adapter_path")
+                            else "Final"
+                        ),
+                        timestamp=event.created_at,
+                        name=id,
+                    )
+                )
+
+            if hasattr(event, "adapter_path"):
+                checkpoints.append(
+                    FinetuneCheckpoint(
+                        type=(
+                            "Final Adapter" if hasattr(event, "model_path") else "Final"
+                        ),
+                        timestamp=event.created_at,
+                        name=id,
+                    )
+                )
+
+    # Sort by timestamp (newest first)
+    checkpoints.sort(key=lambda x: x.timestamp, reverse=True)
+
+    return checkpoints
+
+
 class FineTuning:
     def __init__(self, client: TogetherClient) -> None:
         self._client = client
@@ -162,6 +261,9 @@ class FineTuning:
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
         train_on_inputs: bool | Literal["auto"] = "auto",
+        training_method: str = "sft",
+        dpo_beta: float | None = None,
+        from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
         Method to initiate a fine-tuning job
@@ -207,6 +309,12 @@ class FineTuning:
                For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                (Instruction format), inputs will be masked.
                Defaults to "auto".
+            training_method (str, optional): Training method. Defaults to "sft".
+                Supported methods: "sft", "dpo".
+            dpo_beta (float, optional): DPO beta parameter. Defaults to None.
+            from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
+                The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
+                The step value is optional, without it the final checkpoint will be used.
 
         Returns:
             FinetuneResponse: Object containing information about fine-tuning job.
@@ -218,7 +326,6 @@ class FineTuning:
 
         if model_limits is None:
             model_limits = self.get_model_limits(model=model)
-
         finetune_request = createFinetuneRequest(
             model_limits=model_limits,
             training_file=training_file,
@@ -244,6 +351,9 @@ class FineTuning:
             wandb_project_name=wandb_project_name,
             wandb_name=wandb_name,
             train_on_inputs=train_on_inputs,
+            training_method=training_method,
+            dpo_beta=dpo_beta,
+            from_checkpoint=from_checkpoint,
         )
 
         if verbose:
@@ -261,7 +371,6 @@ class FineTuning:
             ),
             stream=False,
         )
-
         assert isinstance(response, TogetherResponse)
 
         return FinetuneResponse(**response.data)
@@ -366,17 +475,29 @@ class FineTuning:
             ),
             stream=False,
         )
-
         assert isinstance(response, TogetherResponse)
 
         return FinetuneListEvents(**response.data)
 
+    def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
+        """
+        List available checkpoints for a fine-tuning job
+
+        Args:
+            id (str): Unique identifier of the fine-tune job to list checkpoints for
+
+        Returns:
+            List[FinetuneCheckpoint]: List of available checkpoints
+        """
+        events = self.list_events(id).data or []
+        return _process_checkpoints_from_events(events, id)
+
     def download(
         self,
         id: str,
         *,
         output: Path | str | None = None,
-        checkpoint_step: int = -1,
+        checkpoint_step: int | None = None,
        checkpoint_type: DownloadCheckpointType = DownloadCheckpointType.DEFAULT,
     ) -> FinetuneDownloadResult:
         """
@@ -397,9 +518,19 @@ class FineTuning:
             FinetuneDownloadResult: Object containing downloaded model metadata
         """
 
+        if re.match(_FT_JOB_WITH_STEP_REGEX, id) is not None:
+            if checkpoint_step is None:
+                checkpoint_step = int(id.split(":")[1])
+                id = id.split(":")[0]
+            else:
+                raise ValueError(
+                    "Fine-tuning job ID {id} contains a colon to specify the step to download, but `checkpoint_step` "
+                    "was also set. Remove one of the step specifiers to proceed."
+                )
+
         url = f"finetune/download?ft_id={id}"
 
-        if checkpoint_step
+        if checkpoint_step is not None:
            url += f"&checkpoint_step={checkpoint_step}"
 
         ft_job = self.retrieve(id)
@@ -503,6 +634,9 @@ class AsyncFineTuning:
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
         train_on_inputs: bool | Literal["auto"] = "auto",
+        training_method: str = "sft",
+        dpo_beta: float | None = None,
+        from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
         Async method to initiate a fine-tuning job
@@ -548,6 +682,12 @@ class AsyncFineTuning:
                For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                (Instruction format), inputs will be masked.
                Defaults to "auto".
+            training_method (str, optional): Training method. Defaults to "sft".
+                Supported methods: "sft", "dpo".
+            dpo_beta (float, optional): DPO beta parameter. Defaults to None.
+            from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
+                The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
+                The step value is optional, without it the final checkpoint will be used.
 
         Returns:
             FinetuneResponse: Object containing information about fine-tuning job.
@@ -585,6 +725,9 @@ class AsyncFineTuning:
             wandb_project_name=wandb_project_name,
             wandb_name=wandb_name,
             train_on_inputs=train_on_inputs,
+            training_method=training_method,
+            dpo_beta=dpo_beta,
+            from_checkpoint=from_checkpoint,
         )
 
         if verbose:
@@ -687,30 +830,45 @@ class AsyncFineTuning:
 
     async def list_events(self, id: str) -> FinetuneListEvents:
         """
-
+        List fine-tuning events
 
         Args:
-            id (str):
+            id (str): Unique identifier of the fine-tune job to list events for
 
         Returns:
-            FinetuneListEvents: Object containing list of fine-tune events
+            FinetuneListEvents: Object containing list of fine-tune job events
         """
 
         requestor = api_requestor.APIRequestor(
             client=self._client,
         )
 
-
+        events_response, _, _ = await requestor.arequest(
             options=TogetherRequest(
                 method="GET",
-                url=f"fine-tunes/{id}/events",
+                url=f"fine-tunes/{normalize_key(id)}/events",
             ),
             stream=False,
         )
 
-
+        # FIXME: API returns "data" field with no object type (should be "list")
+        events_list = FinetuneListEvents(object="list", **events_response.data)
 
-        return
+        return events_list
+
+    async def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
+        """
+        List available checkpoints for a fine-tuning job
+
+        Args:
+            id (str): Unique identifier of the fine-tune job to list checkpoints for
+
+        Returns:
+            List[FinetuneCheckpoint]: Object containing list of available checkpoints
+        """
+        events_list = await self.list_events(id)
+        events = events_list.data or []
+        return _process_checkpoints_from_events(events, id)
 
     async def download(
         self, id: str, *, output: str | None = None, checkpoint_step: int = -1
together/types/__init__.py
CHANGED
@@ -1,4 +1,13 @@
 from together.types.abstract import TogetherClient
+from together.types.audio_speech import (
+    AudioLanguage,
+    AudioResponseEncoding,
+    AudioResponseFormat,
+    AudioSpeechRequest,
+    AudioSpeechStreamChunk,
+    AudioSpeechStreamEvent,
+    AudioSpeechStreamResponse,
+)
 from together.types.chat_completions import (
     ChatCompletionChunk,
     ChatCompletionRequest,
@@ -11,6 +20,7 @@ from together.types.completions import (
     CompletionResponse,
 )
 from together.types.embeddings import EmbeddingRequest, EmbeddingResponse
+from together.types.endpoints import Autoscaling, DedicatedEndpoint, ListEndpoint
 from together.types.files import (
     FileDeleteResponse,
     FileList,
@@ -21,36 +31,25 @@ from together.types.files import (
     FileType,
 )
 from together.types.finetune import (
+    TrainingMethodDPO,
+    TrainingMethodSFT,
+    FinetuneCheckpoint,
     FinetuneDownloadResult,
+    FinetuneLinearLRSchedulerArgs,
     FinetuneList,
     FinetuneListEvents,
+    FinetuneLRScheduler,
     FinetuneRequest,
     FinetuneResponse,
+    FinetuneTrainingLimits,
     FullTrainingType,
     LoRATrainingType,
     TrainingType,
-    FinetuneTrainingLimits,
-    FinetuneLRScheduler,
-    FinetuneLinearLRSchedulerArgs,
-)
-from together.types.images import (
-    ImageRequest,
-    ImageResponse,
 )
+from together.types.images import ImageRequest, ImageResponse
 from together.types.models import ModelObject
-from together.types.rerank import (
-    RerankRequest,
-    RerankResponse,
-)
-from together.types.audio_speech import (
-    AudioSpeechRequest,
-    AudioResponseFormat,
-    AudioLanguage,
-    AudioResponseEncoding,
-    AudioSpeechStreamChunk,
-    AudioSpeechStreamEvent,
-    AudioSpeechStreamResponse,
-)
+from together.types.rerank import RerankRequest, RerankResponse
+
 
 __all__ = [
     "TogetherClient",
@@ -63,6 +62,7 @@ __all__ = [
     "ChatCompletionResponse",
     "EmbeddingRequest",
     "EmbeddingResponse",
+    "FinetuneCheckpoint",
     "FinetuneRequest",
     "FinetuneResponse",
     "FinetuneList",
@@ -83,6 +83,8 @@ __all__ = [
     "TrainingType",
     "FullTrainingType",
     "LoRATrainingType",
+    "TrainingMethodDPO",
+    "TrainingMethodSFT",
     "RerankRequest",
     "RerankResponse",
     "FinetuneTrainingLimits",
@@ -93,4 +95,7 @@ __all__ = [
     "AudioSpeechStreamChunk",
     "AudioSpeechStreamEvent",
     "AudioSpeechStreamResponse",
+    "DedicatedEndpoint",
+    "ListEndpoint",
+    "Autoscaling",
 ]
together/types/chat_completions.py
CHANGED
@@ -44,16 +44,22 @@ class ToolCalls(BaseModel):
 class ChatCompletionMessageContentType(str, Enum):
     TEXT = "text"
     IMAGE_URL = "image_url"
+    VIDEO_URL = "video_url"
 
 
 class ChatCompletionMessageContentImageURL(BaseModel):
     url: str
 
 
+class ChatCompletionMessageContentVideoURL(BaseModel):
+    url: str
+
+
 class ChatCompletionMessageContent(BaseModel):
     type: ChatCompletionMessageContentType
     text: str | None = None
     image_url: ChatCompletionMessageContentImageURL | None = None
+    video_url: ChatCompletionMessageContentVideoURL | None = None
 
 
 class ChatCompletionMessage(BaseModel):
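
The chat completion types gain a third content-part variant for video, alongside text and image_url. A small sketch building the new part directly from the models above; the URL is a placeholder:

from together.types.chat_completions import (
    ChatCompletionMessageContent,
    ChatCompletionMessageContentType,
    ChatCompletionMessageContentVideoURL,
)

# Compose a video content part the same way an image_url part is composed.
part = ChatCompletionMessageContent(
    type=ChatCompletionMessageContentType.VIDEO_URL,
    video_url=ChatCompletionMessageContentVideoURL(url="https://example.com/clip.mp4"),
)
print(part.model_dump(exclude_none=True))
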
together/types/endpoints.py
ADDED
@@ -0,0 +1,123 @@
+from __future__ import annotations
+
+from datetime import datetime
+from typing import Any, Dict, Literal, Optional, Union
+
+from pydantic import BaseModel, Field
+
+
+class TogetherJSONModel(BaseModel):
+    """Base model with JSON serialization support."""
+
+    def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
+        exclude_none = kwargs.pop("exclude_none", True)
+        data = super().model_dump(exclude_none=exclude_none, **kwargs)
+
+        # Convert datetime objects to ISO format strings
+        for key, value in data.items():
+            if isinstance(value, datetime):
+                data[key] = value.isoformat()
+
+        return data
+
+
+class Autoscaling(TogetherJSONModel):
+    """Configuration for automatic scaling of replicas based on demand."""
+
+    min_replicas: int = Field(
+        description="The minimum number of replicas to maintain, even when there is no load"
+    )
+    max_replicas: int = Field(
+        description="The maximum number of replicas to scale up to under load"
+    )
+
+
+class EndpointPricing(TogetherJSONModel):
+    """Pricing details for using an endpoint."""
+
+    cents_per_minute: float = Field(
+        description="Cost per minute of endpoint uptime in cents"
+    )
+
+
+class HardwareSpec(TogetherJSONModel):
+    """Detailed specifications of a hardware configuration."""
+
+    gpu_type: str = Field(description="The type/model of GPU")
+    gpu_link: str = Field(description="The GPU interconnect technology")
+    gpu_memory: Union[float, int] = Field(description="Amount of GPU memory in GB")
+    gpu_count: int = Field(description="Number of GPUs in this configuration")
+
+
+class HardwareAvailability(TogetherJSONModel):
+    """Indicates the current availability status of a hardware configuration."""
+
+    status: Literal["available", "unavailable", "insufficient"] = Field(
+        description="The availability status of the hardware configuration"
+    )
+
+
+class HardwareWithStatus(TogetherJSONModel):
+    """Hardware configuration details with optional availability status."""
+
+    object: Literal["hardware"] = Field(description="The type of object")
+    id: str = Field(description="Unique identifier for the hardware configuration")
+    pricing: EndpointPricing = Field(
+        description="Pricing details for this hardware configuration"
+    )
+    specs: HardwareSpec = Field(description="Detailed specifications of this hardware")
+    availability: Optional[HardwareAvailability] = Field(
+        default=None,
+        description="Current availability status of this hardware configuration",
+    )
+    updated_at: datetime = Field(
+        description="Timestamp of when the hardware status was last updated"
+    )
+
+
+class BaseEndpoint(TogetherJSONModel):
+    """Base class for endpoint models with common fields."""
+
+    object: Literal["endpoint"] = Field(description="The type of object")
+    id: Optional[str] = Field(
+        default=None, description="Unique identifier for the endpoint"
+    )
+    name: str = Field(description="System name for the endpoint")
+    model: str = Field(description="The model deployed on this endpoint")
+    type: str = Field(description="The type of endpoint")
+    owner: str = Field(description="The owner of this endpoint")
+    state: Literal[
+        "PENDING", "STARTING", "STARTED", "STOPPING", "STOPPED", "FAILED", "ERROR"
+    ] = Field(description="Current state of the endpoint")
+    created_at: datetime = Field(description="Timestamp when the endpoint was created")
+
+
+class ListEndpoint(BaseEndpoint):
+    """Details about an endpoint when listed via the list endpoint."""
+
+    type: Literal["dedicated", "serverless"] = Field(description="The type of endpoint")
+
+
+class DedicatedEndpoint(BaseEndpoint):
+    """Details about a dedicated endpoint deployment."""
+
+    id: str = Field(description="Unique identifier for the endpoint")
+    type: Literal["dedicated"] = Field(description="The type of endpoint")
+    display_name: str = Field(description="Human-readable name for the endpoint")
+    hardware: str = Field(
+        description="The hardware configuration used for this endpoint"
+    )
+    autoscaling: Autoscaling = Field(
+        description="Configuration for automatic scaling of the endpoint"
+    )
+
+
+__all__ = [
+    "DedicatedEndpoint",
+    "ListEndpoint",
+    "Autoscaling",
+    "EndpointPricing",
+    "HardwareSpec",
+    "HardwareAvailability",
+    "HardwareWithStatus",
+]
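
together/types/endpoints.py is a new module backing the dedicated-endpoints support added in 1.4.x. A quick local sketch of the models and of the TogetherJSONModel serialization behavior; every field value below is invented for illustration:

from datetime import datetime, timezone

from together.types.endpoints import Autoscaling, DedicatedEndpoint

endpoint = DedicatedEndpoint(
    object="endpoint",
    id="endpoint-0000",                # placeholder ID
    name="acme/llama-3-8b-endpoint",
    display_name="Llama 3 8B endpoint",
    model="meta-llama/Llama-3-8b-hf",  # placeholder model name
    type="dedicated",
    owner="acme",
    state="STARTED",
    created_at=datetime.now(timezone.utc),
    hardware="1x-nvidia-h100-80gb",    # placeholder hardware ID
    autoscaling=Autoscaling(min_replicas=1, max_replicas=2),
)

# model_dump() drops None fields and renders datetimes as ISO-8601 strings,
# per the TogetherJSONModel.model_dump override above.
print(endpoint.model_dump()["created_at"])
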
together/types/finetune.py
CHANGED
@@ -135,6 +135,31 @@ class LoRATrainingType(TrainingType):
     type: str = "Lora"
 
 
+class TrainingMethod(BaseModel):
+    """
+    Training method type
+    """
+
+    method: str
+
+
+class TrainingMethodSFT(TrainingMethod):
+    """
+    Training method type for SFT training
+    """
+
+    method: Literal["sft"] = "sft"
+
+
+class TrainingMethodDPO(TrainingMethod):
+    """
+    Training method type for DPO training
+    """
+
+    method: Literal["dpo"] = "dpo"
+    dpo_beta: float | None = None
+
+
 class FinetuneRequest(BaseModel):
     """
     Fine-tune request type
@@ -178,6 +203,12 @@ class FinetuneRequest(BaseModel):
     training_type: FullTrainingType | LoRATrainingType | None = None
     # train on inputs
     train_on_inputs: StrictBool | Literal["auto"] = "auto"
+    # training method
+    training_method: TrainingMethodSFT | TrainingMethodDPO = Field(
+        default_factory=TrainingMethodSFT
+    )
+    # from step
+    from_checkpoint: str | None = None
 
 
 class FinetuneResponse(BaseModel):
@@ -256,6 +287,7 @@ class FinetuneResponse(BaseModel):
     training_file_num_lines: int | None = Field(None, alias="TrainingFileNumLines")
     training_file_size: int | None = Field(None, alias="TrainingFileSize")
     train_on_inputs: StrictBool | Literal["auto"] | None = "auto"
+    from_checkpoint: str | None = None
 
     @field_validator("training_type")
     @classmethod
@@ -320,3 +352,16 @@ class FinetuneLRScheduler(BaseModel):
 
 class FinetuneLinearLRSchedulerArgs(BaseModel):
     min_lr_ratio: float | None = 0.0
+
+
+class FinetuneCheckpoint(BaseModel):
+    """
+    Fine-tuning checkpoint information
+    """
+
+    # checkpoint type (e.g. "Intermediate", "Final", "Final Merged", "Final Adapter")
+    type: str
+    # timestamp when the checkpoint was created
+    timestamp: str
+    # checkpoint name/identifier
+    name: str
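
The new training-method and checkpoint models are small pydantic types; a brief sketch of how they behave, with illustrative values only:

from together.types.finetune import (
    FinetuneCheckpoint,
    TrainingMethodDPO,
    TrainingMethodSFT,
)

# The fixed "method" literals are what AVAILABLE_TRAINING_METHODS in
# resources/finetune.py is built from.
print(TrainingMethodSFT().method)              # -> "sft"
print(TrainingMethodDPO(dpo_beta=0.1).method)  # -> "dpo"

# FinetuneCheckpoint is the plain record returned by list_checkpoints();
# the values here are placeholders.
ckpt = FinetuneCheckpoint(
    type="Intermediate (step 10)",
    timestamp="2025-01-01T00:00:00Z",
    name="ft-abc123:10",
)
print(ckpt.name)
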
together/utils/__init__.py
CHANGED
@@ -8,6 +8,8 @@ from together.utils.tools import (
     finetune_price_to_dollars,
     normalize_key,
     parse_timestamp,
+    format_timestamp,
+    get_event_step,
 )
 
 
@@ -23,6 +25,8 @@ __all__ = [
     "enforce_trailing_slash",
     "normalize_key",
     "parse_timestamp",
+    "format_timestamp",
+    "get_event_step",
     "finetune_price_to_dollars",
     "convert_bytes",
     "convert_unix_timestamp",
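
format_timestamp and get_event_step are re-exported here because the new checkpoint listing depends on them. A hedged sketch of how get_event_step is used, mirroring _process_checkpoints_from_events above; its exact return semantics (a step number or None) are inferred from that usage:

from together.types.finetune import FinetuneEvent
from together.utils import get_event_step


def checkpoint_name_for(event: FinetuneEvent, job_id: str) -> str:
    # get_event_step() returns the training step an event refers to, or None.
    step = get_event_step(event)
    return f"{job_id}:{step}" if step is not None else job_id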