together 1.4.1-py3-none-any.whl → 1.4.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,10 @@
 from __future__ import annotations
 
 import json
-from datetime import datetime
+from datetime import datetime, timezone
 from textwrap import wrap
 from typing import Any, Literal
+import re
 
 import click
 from click.core import ParameterSource  # type: ignore[attr-defined]
@@ -17,8 +18,13 @@ from together.utils import (
     log_warn,
     log_warn_once,
     parse_timestamp,
+    format_timestamp,
+)
+from together.types.finetune import (
+    DownloadCheckpointType,
+    FinetuneTrainingLimits,
+    FinetuneEventType,
 )
-from together.types.finetune import DownloadCheckpointType, FinetuneTrainingLimits
 
 
 _CONFIRMATION_MESSAGE = (
@@ -104,6 +110,18 @@ def fine_tuning(ctx: click.Context) -> None:
     default="all-linear",
     help="Trainable modules for LoRA adapters. For example, 'all-linear', 'q_proj,v_proj'",
 )
+@click.option(
+    "--training-method",
+    type=click.Choice(["sft", "dpo"]),
+    default="sft",
+    help="Training method to use. Options: sft (supervised fine-tuning), dpo (Direct Preference Optimization)",
+)
+@click.option(
+    "--dpo-beta",
+    type=float,
+    default=0.1,
+    help="Beta parameter for DPO training (only used when '--training-method' is 'dpo')",
+)
 @click.option(
     "--suffix", type=str, default=None, help="Suffix for the fine-tuned model name"
 )
@@ -126,6 +144,14 @@ def fine_tuning(ctx: click.Context) -> None:
     help="Whether to mask the user messages in conversational data or prompts in instruction data. "
     "`auto` will automatically determine whether to mask the inputs based on the data format.",
 )
+@click.option(
+    "--from-checkpoint",
+    type=str,
+    default=None,
+    help="The checkpoint identifier to continue training from a previous fine-tuning job. "
+    "The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}. "
+    "The step value is optional, without it the final checkpoint will be used.",
+)
 def create(
     ctx: click.Context,
     training_file: str,
@@ -152,6 +178,9 @@ def create(
     wandb_name: str,
     confirm: bool,
     train_on_inputs: bool | Literal["auto"],
+    training_method: str,
+    dpo_beta: float,
+    from_checkpoint: str,
 ) -> None:
     """Start fine-tuning"""
     client: Together = ctx.obj
@@ -180,6 +209,9 @@ def create(
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
         train_on_inputs=train_on_inputs,
+        training_method=training_method,
+        dpo_beta=dpo_beta,
+        from_checkpoint=from_checkpoint,
     )
 
     model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(
@@ -261,7 +293,9 @@ def list(ctx: click.Context) -> None:
 
     response.data = response.data or []
 
-    response.data.sort(key=lambda x: parse_timestamp(x.created_at or ""))
+    # Use a default datetime for None values to make sure the key function always returns a comparable value
+    epoch_start = datetime.fromtimestamp(0, tz=timezone.utc)
+    response.data.sort(key=lambda x: parse_timestamp(x.created_at or "") or epoch_start)
 
     display_list = []
     for i in response.data:
@@ -344,6 +378,34 @@ def list_events(ctx: click.Context, fine_tune_id: str) -> None:
     click.echo(table)
 
 
+@fine_tuning.command()
+@click.pass_context
+@click.argument("fine_tune_id", type=str, required=True)
+def list_checkpoints(ctx: click.Context, fine_tune_id: str) -> None:
+    """List available checkpoints for a fine-tuning job"""
+    client: Together = ctx.obj
+
+    checkpoints = client.fine_tuning.list_checkpoints(fine_tune_id)
+
+    display_list = []
+    for checkpoint in checkpoints:
+        display_list.append(
+            {
+                "Type": checkpoint.type,
+                "Timestamp": format_timestamp(checkpoint.timestamp),
+                "Name": checkpoint.name,
+            }
+        )
+
+    if display_list:
+        click.echo(f"Job {fine_tune_id} contains the following checkpoints:")
+        table = tabulate(display_list, headers="keys", tablefmt="grid")
+        click.echo(table)
+        click.echo("\nTo download a checkpoint, use `together fine-tuning download`")
+    else:
+        click.echo(f"No checkpoints found for job {fine_tune_id}")
+
+
 @fine_tuning.command()
 @click.pass_context
 @click.argument("fine_tune_id", type=str, required=True)
@@ -358,7 +420,7 @@ def list_events(ctx: click.Context, fine_tune_id: str) -> None:
     "--checkpoint-step",
     type=int,
     required=False,
-    default=-1,
+    default=None,
     help="Download fine-tuning checkpoint. Defaults to latest.",
 )
 @click.option(
@@ -372,7 +434,7 @@ def download(
     ctx: click.Context,
     fine_tune_id: str,
     output_dir: str,
-    checkpoint_step: int,
+    checkpoint_step: int | None,
     checkpoint_type: DownloadCheckpointType,
 ) -> None:
     """Download fine-tuning checkpoint"""
together/client.py CHANGED
@@ -1,13 +1,15 @@
 from __future__ import annotations
 
 import os
-from typing import Dict
+import sys
+from typing import Dict, TYPE_CHECKING
 
 from together import resources
 from together.constants import BASE_URL, MAX_RETRIES, TIMEOUT_SECS
 from together.error import AuthenticationError
 from together.types import TogetherClient
 from together.utils import enforce_trailing_slash
+from together.utils.api_helpers import get_google_colab_secret
 
 
 class Together:
@@ -44,6 +46,9 @@ class Together:
         if not api_key:
             api_key = os.environ.get("TOGETHER_API_KEY")
 
+        if not api_key and "google.colab" in sys.modules:
+            api_key = get_google_colab_secret("TOGETHER_API_KEY")
+
         if not api_key:
             raise AuthenticationError(
                 "The api_key client option must be set either by passing api_key to the client or by setting the "
@@ -117,6 +122,9 @@ class AsyncTogether:
         if not api_key:
             api_key = os.environ.get("TOGETHER_API_KEY")
 
+        if not api_key and "google.colab" in sys.modules:
+            api_key = get_google_colab_secret("TOGETHER_API_KEY")
+
         if not api_key:
             raise AuthenticationError(
                 "The api_key client option must be set either by passing api_key to the client or by setting the "
together/constants.py CHANGED
@@ -39,12 +39,18 @@ class DatasetFormat(enum.Enum):
     GENERAL = "general"
     CONVERSATION = "conversation"
     INSTRUCTION = "instruction"
+    PREFERENCE_OPENAI = "preference_openai"
 
 
 JSONL_REQUIRED_COLUMNS_MAP = {
     DatasetFormat.GENERAL: ["text"],
     DatasetFormat.CONVERSATION: ["messages"],
     DatasetFormat.INSTRUCTION: ["prompt", "completion"],
+    DatasetFormat.PREFERENCE_OPENAI: [
+        "input",
+        "preferred_output",
+        "non_preferred_output",
+    ],
 }
 REQUIRED_COLUMNS_MESSAGE = ["role", "content"]
 POSSIBLE_ROLES_CONVERSATION = ["system", "user", "assistant"]
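A hedged sketch of one JSONL record for the new preference_openai format. Only the three required top-level keys come from this diff; the nested message layout is an assumption based on the OpenAI-style preference format.

import json

# Top-level keys match JSONL_REQUIRED_COLUMNS_MAP[DatasetFormat.PREFERENCE_OPENAI];
# the nested structure below is illustrative, not taken from this diff.
record = {
    "input": {"messages": [{"role": "user", "content": "What is the capital of France?"}]},
    "preferred_output": [{"role": "assistant", "content": "The capital of France is Paris."}],
    "non_preferred_output": [{"role": "assistant", "content": "I am not sure."}],
}
print(json.dumps(record))
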
@@ -161,7 +161,7 @@ class Finetune:
         cls,
         fine_tune_id: str,
         output: str | None = None,
-        step: int = -1,
+        step: int | None = None,
     ) -> Dict[str, Any]:
         """Legacy finetuning download function."""
 
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
+import re
 from pathlib import Path
-from typing import Literal
+from typing import Literal, List
 
 from rich import print as rprint
 
@@ -22,9 +23,28 @@ from together.types import (
     TrainingType,
     FinetuneLRScheduler,
     FinetuneLinearLRSchedulerArgs,
+    TrainingMethodDPO,
+    TrainingMethodSFT,
+    FinetuneCheckpoint,
 )
-from together.types.finetune import DownloadCheckpointType
-from together.utils import log_warn_once, normalize_key
+from together.types.finetune import (
+    DownloadCheckpointType,
+    FinetuneEventType,
+    FinetuneEvent,
+)
+from together.utils import (
+    log_warn_once,
+    normalize_key,
+    get_event_step,
+)
+
+_FT_JOB_WITH_STEP_REGEX = r"^ft-[\dabcdef-]+:\d+$"
+
+
+AVAILABLE_TRAINING_METHODS = {
+    TrainingMethodSFT().method,
+    TrainingMethodDPO().method,
+}
 
 
 def createFinetuneRequest(
@@ -52,7 +72,11 @@ def createFinetuneRequest(
     wandb_project_name: str | None = None,
     wandb_name: str | None = None,
     train_on_inputs: bool | Literal["auto"] = "auto",
+    training_method: str = "sft",
+    dpo_beta: float | None = None,
+    from_checkpoint: str | None = None,
 ) -> FinetuneRequest:
+
     if batch_size == "max":
         log_warn_once(
             "Starting from together>=1.3.0, "
@@ -100,11 +124,20 @@ def createFinetuneRequest(
     if weight_decay is not None and (weight_decay < 0):
         raise ValueError("Weight decay should be non-negative")
 
+    if training_method not in AVAILABLE_TRAINING_METHODS:
+        raise ValueError(
+            f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
+        )
+
     lrScheduler = FinetuneLRScheduler(
         lr_scheduler_type="linear",
         lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
     )
 
+    training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
+    if training_method == "dpo":
+        training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
+
     finetune_request = FinetuneRequest(
         model=model,
         training_file=training_file,
@@ -125,11 +158,77 @@ def createFinetuneRequest(
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
         train_on_inputs=train_on_inputs,
+        training_method=training_method_cls,
+        from_checkpoint=from_checkpoint,
     )
 
     return finetune_request
 
 
+def _process_checkpoints_from_events(
+    events: List[FinetuneEvent], id: str
+) -> List[FinetuneCheckpoint]:
+    """
+    Helper function to process events and create checkpoint list.
+
+    Args:
+        events (List[FinetuneEvent]): List of fine-tune events to process
+        id (str): Fine-tune job ID
+
+    Returns:
+        List[FinetuneCheckpoint]: List of available checkpoints
+    """
+    checkpoints: List[FinetuneCheckpoint] = []
+
+    for event in events:
+        event_type = event.type
+
+        if event_type == FinetuneEventType.CHECKPOINT_SAVE:
+            step = get_event_step(event)
+            checkpoint_name = f"{id}:{step}" if step is not None else id
+
+            checkpoints.append(
+                FinetuneCheckpoint(
+                    type=(
+                        f"Intermediate (step {step})"
+                        if step is not None
+                        else "Intermediate"
+                    ),
+                    timestamp=event.created_at,
+                    name=checkpoint_name,
+                )
+            )
+        elif event_type == FinetuneEventType.JOB_COMPLETE:
+            if hasattr(event, "model_path"):
+                checkpoints.append(
+                    FinetuneCheckpoint(
+                        type=(
+                            "Final Merged"
+                            if hasattr(event, "adapter_path")
+                            else "Final"
+                        ),
+                        timestamp=event.created_at,
+                        name=id,
+                    )
+                )
+
+            if hasattr(event, "adapter_path"):
+                checkpoints.append(
+                    FinetuneCheckpoint(
+                        type=(
+                            "Final Adapter" if hasattr(event, "model_path") else "Final"
+                        ),
+                        timestamp=event.created_at,
+                        name=id,
+                    )
+                )
+
+    # Sort by timestamp (newest first)
+    checkpoints.sort(key=lambda x: x.timestamp, reverse=True)
+
+    return checkpoints
+
+
 class FineTuning:
     def __init__(self, client: TogetherClient) -> None:
         self._client = client
@@ -162,6 +261,9 @@ class FineTuning:
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
         train_on_inputs: bool | Literal["auto"] = "auto",
+        training_method: str = "sft",
+        dpo_beta: float | None = None,
+        from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
         Method to initiate a fine-tuning job
@@ -207,6 +309,12 @@ class FineTuning:
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                 (Instruction format), inputs will be masked.
                 Defaults to "auto".
+            training_method (str, optional): Training method. Defaults to "sft".
+                Supported methods: "sft", "dpo".
+            dpo_beta (float, optional): DPO beta parameter. Defaults to None.
+            from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
+                The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
+                The step value is optional, without it the final checkpoint will be used.
 
         Returns:
             FinetuneResponse: Object containing information about fine-tuning job.
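A minimal sketch of starting a DPO job through the parameters documented above; the training file ID and model name are placeholders, and the training file is assumed to be an uploaded preference dataset.

from together import Together

client = Together()
job = client.fine_tuning.create(
    training_file="file-xxxxxxxxxxxx",  # placeholder: an uploaded preference-format dataset
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",  # placeholder model name
    training_method="dpo",
    dpo_beta=0.1,
)
print(job.id)
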
@@ -218,7 +326,6 @@ class FineTuning:
 
         if model_limits is None:
             model_limits = self.get_model_limits(model=model)
-
         finetune_request = createFinetuneRequest(
             model_limits=model_limits,
             training_file=training_file,
@@ -244,6 +351,9 @@ class FineTuning:
             wandb_project_name=wandb_project_name,
             wandb_name=wandb_name,
             train_on_inputs=train_on_inputs,
+            training_method=training_method,
+            dpo_beta=dpo_beta,
+            from_checkpoint=from_checkpoint,
         )
 
         if verbose:
@@ -261,7 +371,6 @@ class FineTuning:
             ),
             stream=False,
         )
-
         assert isinstance(response, TogetherResponse)
 
         return FinetuneResponse(**response.data)
@@ -366,17 +475,29 @@ class FineTuning:
             ),
             stream=False,
         )
-
         assert isinstance(response, TogetherResponse)
 
         return FinetuneListEvents(**response.data)
 
+    def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
+        """
+        List available checkpoints for a fine-tuning job
+
+        Args:
+            id (str): Unique identifier of the fine-tune job to list checkpoints for
+
+        Returns:
+            List[FinetuneCheckpoint]: List of available checkpoints
+        """
+        events = self.list_events(id).data or []
+        return _process_checkpoints_from_events(events, id)
+
     def download(
         self,
         id: str,
         *,
         output: Path | str | None = None,
-        checkpoint_step: int = -1,
+        checkpoint_step: int | None = None,
         checkpoint_type: DownloadCheckpointType = DownloadCheckpointType.DEFAULT,
     ) -> FinetuneDownloadResult:
         """
@@ -397,9 +518,19 @@ class FineTuning:
             FinetuneDownloadResult: Object containing downloaded model metadata
         """
 
+        if re.match(_FT_JOB_WITH_STEP_REGEX, id) is not None:
+            if checkpoint_step is None:
+                checkpoint_step = int(id.split(":")[1])
+                id = id.split(":")[0]
+            else:
+                raise ValueError(
+                    "Fine-tuning job ID {id} contains a colon to specify the step to download, but `checkpoint_step` "
+                    "was also set. Remove one of the step specifiers to proceed."
+                )
+
         url = f"finetune/download?ft_id={id}"
 
-        if checkpoint_step > 0:
+        if checkpoint_step is not None:
             url += f"&checkpoint_step={checkpoint_step}"
 
         ft_job = self.retrieve(id)
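Per the handling above, a checkpoint step can be encoded in the job ID or passed separately, but not both. A short sketch with placeholder IDs and output paths:

from together import Together

client = Together()
# These two calls are equivalent under the new parsing; combining both forms raises ValueError.
client.fine_tuning.download("ft-xxxxxxxxxxxx:100", output="./checkpoint-100")
client.fine_tuning.download("ft-xxxxxxxxxxxx", checkpoint_step=100, output="./checkpoint-100")
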
@@ -503,6 +634,9 @@ class AsyncFineTuning:
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
         train_on_inputs: bool | Literal["auto"] = "auto",
+        training_method: str = "sft",
+        dpo_beta: float | None = None,
+        from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
         Async method to initiate a fine-tuning job
@@ -548,6 +682,12 @@ class AsyncFineTuning:
                 For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
                 (Instruction format), inputs will be masked.
                 Defaults to "auto".
+            training_method (str, optional): Training method. Defaults to "sft".
+                Supported methods: "sft", "dpo".
+            dpo_beta (float, optional): DPO beta parameter. Defaults to None.
+            from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
+                The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
+                The step value is optional, without it the final checkpoint will be used.
 
         Returns:
             FinetuneResponse: Object containing information about fine-tuning job.
@@ -585,6 +725,9 @@ class AsyncFineTuning:
             wandb_project_name=wandb_project_name,
             wandb_name=wandb_name,
             train_on_inputs=train_on_inputs,
+            training_method=training_method,
+            dpo_beta=dpo_beta,
+            from_checkpoint=from_checkpoint,
         )
 
         if verbose:
@@ -687,30 +830,45 @@ class AsyncFineTuning:
 
     async def list_events(self, id: str) -> FinetuneListEvents:
         """
-        Async method to lists events of a fine-tune job
+        List fine-tuning events
 
         Args:
-            id (str): Fine-tune ID to list events for. A string that starts with `ft-`.
+            id (str): Unique identifier of the fine-tune job to list events for
 
         Returns:
-            FinetuneListEvents: Object containing list of fine-tune events
+            FinetuneListEvents: Object containing list of fine-tune job events
         """
 
         requestor = api_requestor.APIRequestor(
             client=self._client,
         )
 
-        response, _, _ = await requestor.arequest(
+        events_response, _, _ = await requestor.arequest(
             options=TogetherRequest(
                 method="GET",
-                url=f"fine-tunes/{id}/events",
+                url=f"fine-tunes/{normalize_key(id)}/events",
             ),
             stream=False,
         )
 
-        assert isinstance(response, TogetherResponse)
+        # FIXME: API returns "data" field with no object type (should be "list")
+        events_list = FinetuneListEvents(object="list", **events_response.data)
 
-        return FinetuneListEvents(**response.data)
+        return events_list
+
+    async def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
+        """
+        List available checkpoints for a fine-tuning job
+
+        Args:
+            id (str): Unique identifier of the fine-tune job to list checkpoints for
+
+        Returns:
+            List[FinetuneCheckpoint]: Object containing list of available checkpoints
+        """
+        events_list = await self.list_events(id)
+        events = events_list.data or []
+        return _process_checkpoints_from_events(events, id)
 
     async def download(
         self, id: str, *, output: str | None = None, checkpoint_step: int = -1
@@ -31,6 +31,9 @@ from together.types.files import (
     FileType,
 )
 from together.types.finetune import (
+    TrainingMethodDPO,
+    TrainingMethodSFT,
+    FinetuneCheckpoint,
     FinetuneDownloadResult,
     FinetuneLinearLRSchedulerArgs,
     FinetuneList,
@@ -59,6 +62,7 @@ __all__ = [
     "ChatCompletionResponse",
     "EmbeddingRequest",
     "EmbeddingResponse",
+    "FinetuneCheckpoint",
     "FinetuneRequest",
     "FinetuneResponse",
     "FinetuneList",
@@ -79,6 +83,8 @@ __all__ = [
     "TrainingType",
     "FullTrainingType",
     "LoRATrainingType",
+    "TrainingMethodDPO",
+    "TrainingMethodSFT",
     "RerankRequest",
     "RerankResponse",
     "FinetuneTrainingLimits",
@@ -44,16 +44,22 @@ class ToolCalls(BaseModel):
 class ChatCompletionMessageContentType(str, Enum):
     TEXT = "text"
     IMAGE_URL = "image_url"
+    VIDEO_URL = "video_url"
 
 
 class ChatCompletionMessageContentImageURL(BaseModel):
     url: str
 
 
+class ChatCompletionMessageContentVideoURL(BaseModel):
+    url: str
+
+
 class ChatCompletionMessageContent(BaseModel):
     type: ChatCompletionMessageContentType
     text: str | None = None
     image_url: ChatCompletionMessageContentImageURL | None = None
+    video_url: ChatCompletionMessageContentVideoURL | None = None
 
 
 class ChatCompletionMessage(BaseModel):
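A hedged sketch of a user message using the new video_url content part. Only the field names come from the models above; the exact request payload accepted by the chat completions endpoint, and the URL, are assumptions.

# Field names mirror ChatCompletionMessageContent / ChatCompletionMessageContentVideoURL above.
message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe what happens in this clip."},
        {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}},
    ],
}
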
@@ -86,9 +86,9 @@ class BaseEndpoint(TogetherJSONModel):
     model: str = Field(description="The model deployed on this endpoint")
     type: str = Field(description="The type of endpoint")
     owner: str = Field(description="The owner of this endpoint")
-    state: Literal["PENDING", "STARTING", "STARTED", "STOPPING", "STOPPED", "ERROR"] = (
-        Field(description="Current state of the endpoint")
-    )
+    state: Literal[
+        "PENDING", "STARTING", "STARTED", "STOPPING", "STOPPED", "FAILED", "ERROR"
+    ] = Field(description="Current state of the endpoint")
     created_at: datetime = Field(description="Timestamp when the endpoint was created")
 
 
@@ -135,6 +135,31 @@ class LoRATrainingType(TrainingType):
     type: str = "Lora"
 
 
+class TrainingMethod(BaseModel):
+    """
+    Training method type
+    """
+
+    method: str
+
+
+class TrainingMethodSFT(TrainingMethod):
+    """
+    Training method type for SFT training
+    """
+
+    method: Literal["sft"] = "sft"
+
+
+class TrainingMethodDPO(TrainingMethod):
+    """
+    Training method type for DPO training
+    """
+
+    method: Literal["dpo"] = "dpo"
+    dpo_beta: float | None = None
+
+
 class FinetuneRequest(BaseModel):
     """
     Fine-tune request type
@@ -178,6 +203,12 @@ class FinetuneRequest(BaseModel):
     training_type: FullTrainingType | LoRATrainingType | None = None
     # train on inputs
     train_on_inputs: StrictBool | Literal["auto"] = "auto"
+    # training method
+    training_method: TrainingMethodSFT | TrainingMethodDPO = Field(
+        default_factory=TrainingMethodSFT
+    )
+    # from step
+    from_checkpoint: str | None = None
 
 
 class FinetuneResponse(BaseModel):
@@ -256,6 +287,7 @@ class FinetuneResponse(BaseModel):
     training_file_num_lines: int | None = Field(None, alias="TrainingFileNumLines")
     training_file_size: int | None = Field(None, alias="TrainingFileSize")
     train_on_inputs: StrictBool | Literal["auto"] | None = "auto"
+    from_checkpoint: str | None = None
 
     @field_validator("training_type")
     @classmethod
@@ -320,3 +352,16 @@ class FinetuneLRScheduler(BaseModel):
 
 class FinetuneLinearLRSchedulerArgs(BaseModel):
     min_lr_ratio: float | None = 0.0
+
+
+class FinetuneCheckpoint(BaseModel):
+    """
+    Fine-tuning checkpoint information
+    """
+
+    # checkpoint type (e.g. "Intermediate", "Final", "Final Merged", "Final Adapter")
+    type: str
+    # timestamp when the checkpoint was created
+    timestamp: str
+    # checkpoint name/identifier
+    name: str
@@ -8,6 +8,8 @@ from together.utils.tools import (
     finetune_price_to_dollars,
     normalize_key,
     parse_timestamp,
+    format_timestamp,
+    get_event_step,
 )
 
 
@@ -23,6 +25,8 @@ __all__ = [
     "enforce_trailing_slash",
     "normalize_key",
     "parse_timestamp",
+    "format_timestamp",
+    "get_event_step",
     "finetune_price_to_dollars",
     "convert_bytes",
     "convert_unix_timestamp",