together 1.4.0__py3-none-any.whl → 1.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,8 @@
  from __future__ import annotations

+ import re
  from pathlib import Path
- from typing import Literal
+ from typing import Literal, List

  from rich import print as rprint

@@ -22,9 +23,28 @@ from together.types import (
  TrainingType,
  FinetuneLRScheduler,
  FinetuneLinearLRSchedulerArgs,
+ TrainingMethodDPO,
+ TrainingMethodSFT,
+ FinetuneCheckpoint,
  )
- from together.types.finetune import DownloadCheckpointType
- from together.utils import log_warn_once, normalize_key
+ from together.types.finetune import (
+ DownloadCheckpointType,
+ FinetuneEventType,
+ FinetuneEvent,
+ )
+ from together.utils import (
+ log_warn_once,
+ normalize_key,
+ get_event_step,
+ )
+
+ _FT_JOB_WITH_STEP_REGEX = r"^ft-[\dabcdef-]+:\d+$"
+
+
+ AVAILABLE_TRAINING_METHODS = {
+ TrainingMethodSFT().method,
+ TrainingMethodDPO().method,
+ }


  def createFinetuneRequest(
@@ -52,7 +72,11 @@ def createFinetuneRequest(
  wandb_project_name: str | None = None,
  wandb_name: str | None = None,
  train_on_inputs: bool | Literal["auto"] = "auto",
+ training_method: str = "sft",
+ dpo_beta: float | None = None,
+ from_checkpoint: str | None = None,
  ) -> FinetuneRequest:
+
  if batch_size == "max":
  log_warn_once(
  "Starting from together>=1.3.0, "
@@ -100,11 +124,20 @@ def createFinetuneRequest(
  if weight_decay is not None and (weight_decay < 0):
  raise ValueError("Weight decay should be non-negative")

+ if training_method not in AVAILABLE_TRAINING_METHODS:
+ raise ValueError(
+ f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
+ )
+
  lrScheduler = FinetuneLRScheduler(
  lr_scheduler_type="linear",
  lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
  )

+ training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
+ if training_method == "dpo":
+ training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
+
  finetune_request = FinetuneRequest(
  model=model,
  training_file=training_file,
@@ -125,11 +158,77 @@ def createFinetuneRequest(
  wandb_project_name=wandb_project_name,
  wandb_name=wandb_name,
  train_on_inputs=train_on_inputs,
+ training_method=training_method_cls,
+ from_checkpoint=from_checkpoint,
  )

  return finetune_request


+ def _process_checkpoints_from_events(
+ events: List[FinetuneEvent], id: str
+ ) -> List[FinetuneCheckpoint]:
+ """
+ Helper function to process events and create checkpoint list.
+
+ Args:
+ events (List[FinetuneEvent]): List of fine-tune events to process
+ id (str): Fine-tune job ID
+
+ Returns:
+ List[FinetuneCheckpoint]: List of available checkpoints
+ """
+ checkpoints: List[FinetuneCheckpoint] = []
+
+ for event in events:
+ event_type = event.type
+
+ if event_type == FinetuneEventType.CHECKPOINT_SAVE:
+ step = get_event_step(event)
+ checkpoint_name = f"{id}:{step}" if step is not None else id
+
+ checkpoints.append(
+ FinetuneCheckpoint(
+ type=(
+ f"Intermediate (step {step})"
+ if step is not None
+ else "Intermediate"
+ ),
+ timestamp=event.created_at,
+ name=checkpoint_name,
+ )
+ )
+ elif event_type == FinetuneEventType.JOB_COMPLETE:
+ if hasattr(event, "model_path"):
+ checkpoints.append(
+ FinetuneCheckpoint(
+ type=(
+ "Final Merged"
+ if hasattr(event, "adapter_path")
+ else "Final"
+ ),
+ timestamp=event.created_at,
+ name=id,
+ )
+ )
+
+ if hasattr(event, "adapter_path"):
+ checkpoints.append(
+ FinetuneCheckpoint(
+ type=(
+ "Final Adapter" if hasattr(event, "model_path") else "Final"
+ ),
+ timestamp=event.created_at,
+ name=id,
+ )
+ )
+
+ # Sort by timestamp (newest first)
+ checkpoints.sort(key=lambda x: x.timestamp, reverse=True)
+
+ return checkpoints
+
+
  class FineTuning:
  def __init__(self, client: TogetherClient) -> None:
  self._client = client
@@ -162,6 +261,9 @@ class FineTuning:
  verbose: bool = False,
  model_limits: FinetuneTrainingLimits | None = None,
  train_on_inputs: bool | Literal["auto"] = "auto",
+ training_method: str = "sft",
+ dpo_beta: float | None = None,
+ from_checkpoint: str | None = None,
  ) -> FinetuneResponse:
  """
  Method to initiate a fine-tuning job
@@ -207,6 +309,12 @@ class FineTuning:
  For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
  (Instruction format), inputs will be masked.
  Defaults to "auto".
+ training_method (str, optional): Training method. Defaults to "sft".
+ Supported methods: "sft", "dpo".
+ dpo_beta (float, optional): DPO beta parameter. Defaults to None.
+ from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
+ The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
+ The step value is optional, without it the final checkpoint will be used.

  Returns:
  FinetuneResponse: Object containing information about fine-tuning job.
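
To make the new options above concrete, here is a minimal usage sketch (assuming the resource is exposed as client.fine_tuning on the Together client, as elsewhere in this SDK; the file ID, model name, and job ID are placeholders):

from together import Together

client = Together()

# Hypothetical example: a DPO fine-tuning run that resumes from an earlier job's checkpoint.
job = client.fine_tuning.create(
    training_file="file-XXXXXXXX",  # placeholder uploaded file ID
    model="example-org/example-base-model",  # placeholder model name
    training_method="dpo",  # new parameter; "sft" remains the default
    dpo_beta=0.1,  # only used by the DPO method
    from_checkpoint="ft-XXXXXXXX:100",  # placeholder job ID; the ":100" step suffix is optional
)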
@@ -218,7 +326,6 @@ class FineTuning:

  if model_limits is None:
  model_limits = self.get_model_limits(model=model)
-
  finetune_request = createFinetuneRequest(
  model_limits=model_limits,
  training_file=training_file,
@@ -244,6 +351,9 @@ class FineTuning:
  wandb_project_name=wandb_project_name,
  wandb_name=wandb_name,
  train_on_inputs=train_on_inputs,
+ training_method=training_method,
+ dpo_beta=dpo_beta,
+ from_checkpoint=from_checkpoint,
  )

  if verbose:
@@ -261,7 +371,6 @@ class FineTuning:
  ),
  stream=False,
  )
-
  assert isinstance(response, TogetherResponse)

  return FinetuneResponse(**response.data)
@@ -366,17 +475,29 @@ class FineTuning:
  ),
  stream=False,
  )
-
  assert isinstance(response, TogetherResponse)

  return FinetuneListEvents(**response.data)

+ def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
+ """
+ List available checkpoints for a fine-tuning job
+
+ Args:
+ id (str): Unique identifier of the fine-tune job to list checkpoints for
+
+ Returns:
+ List[FinetuneCheckpoint]: List of available checkpoints
+ """
+ events = self.list_events(id).data or []
+ return _process_checkpoints_from_events(events, id)
+
  def download(
  self,
  id: str,
  *,
  output: Path | str | None = None,
- checkpoint_step: int = -1,
+ checkpoint_step: int | None = None,
  checkpoint_type: DownloadCheckpointType = DownloadCheckpointType.DEFAULT,
  ) -> FinetuneDownloadResult:
  """
@@ -397,9 +518,19 @@ class FineTuning:
  FinetuneDownloadResult: Object containing downloaded model metadata
  """

+ if re.match(_FT_JOB_WITH_STEP_REGEX, id) is not None:
+ if checkpoint_step is None:
+ checkpoint_step = int(id.split(":")[1])
+ id = id.split(":")[0]
+ else:
+ raise ValueError(
+ f"Fine-tuning job ID {id} contains a colon to specify the step to download, but `checkpoint_step` "
+ "was also set. Remove one of the step specifiers to proceed."
+ )
+
  url = f"finetune/download?ft_id={id}"

- if checkpoint_step > 0:
+ if checkpoint_step is not None:
  url += f"&checkpoint_step={checkpoint_step}"

  ft_job = self.retrieve(id)
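
A short sketch of the checkpoint workflow added above, continuing the example client from earlier (the job ID is a placeholder; per the check above, the ft-...:<step> suffix and checkpoint_step are mutually exclusive):

# Hypothetical usage of list_checkpoints() and the step-aware download() added in this release.
for ckpt in client.fine_tuning.list_checkpoints("ft-XXXXXXXX"):
    print(ckpt.type, ckpt.name, ckpt.timestamp)

# Either pass the step explicitly...
client.fine_tuning.download("ft-XXXXXXXX", checkpoint_step=100)
# ...or encode it in the job ID; supplying both raises ValueError.
client.fine_tuning.download("ft-XXXXXXXX:100")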
@@ -503,6 +634,9 @@ class AsyncFineTuning:
  verbose: bool = False,
  model_limits: FinetuneTrainingLimits | None = None,
  train_on_inputs: bool | Literal["auto"] = "auto",
+ training_method: str = "sft",
+ dpo_beta: float | None = None,
+ from_checkpoint: str | None = None,
  ) -> FinetuneResponse:
  """
  Async method to initiate a fine-tuning job
@@ -548,6 +682,12 @@ class AsyncFineTuning:
  For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
  (Instruction format), inputs will be masked.
  Defaults to "auto".
+ training_method (str, optional): Training method. Defaults to "sft".
+ Supported methods: "sft", "dpo".
+ dpo_beta (float, optional): DPO beta parameter. Defaults to None.
+ from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
+ The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
+ The step value is optional, without it the final checkpoint will be used.

  Returns:
  FinetuneResponse: Object containing information about fine-tuning job.
@@ -585,6 +725,9 @@ class AsyncFineTuning:
  wandb_project_name=wandb_project_name,
  wandb_name=wandb_name,
  train_on_inputs=train_on_inputs,
+ training_method=training_method,
+ dpo_beta=dpo_beta,
+ from_checkpoint=from_checkpoint,
  )

  if verbose:
@@ -687,30 +830,45 @@ class AsyncFineTuning:

  async def list_events(self, id: str) -> FinetuneListEvents:
  """
- Async method to lists events of a fine-tune job
+ List fine-tuning events

  Args:
- id (str): Fine-tune ID to list events for. A string that starts with `ft-`.
+ id (str): Unique identifier of the fine-tune job to list events for

  Returns:
- FinetuneListEvents: Object containing list of fine-tune events
+ FinetuneListEvents: Object containing list of fine-tune job events
  """

  requestor = api_requestor.APIRequestor(
  client=self._client,
  )

- response, _, _ = await requestor.arequest(
+ events_response, _, _ = await requestor.arequest(
  options=TogetherRequest(
  method="GET",
- url=f"fine-tunes/{id}/events",
+ url=f"fine-tunes/{normalize_key(id)}/events",
  ),
  stream=False,
  )

- assert isinstance(response, TogetherResponse)
+ # FIXME: API returns "data" field with no object type (should be "list")
+ events_list = FinetuneListEvents(object="list", **events_response.data)

- return FinetuneListEvents(**response.data)
+ return events_list
+
+ async def list_checkpoints(self, id: str) -> List[FinetuneCheckpoint]:
+ """
+ List available checkpoints for a fine-tuning job
+
+ Args:
+ id (str): Unique identifier of the fine-tune job to list checkpoints for
+
+ Returns:
+ List[FinetuneCheckpoint]: Object containing list of available checkpoints
+ """
+ events_list = await self.list_events(id)
+ events = events_list.data or []
+ return _process_checkpoints_from_events(events, id)

  async def download(
  self, id: str, *, output: str | None = None, checkpoint_step: int = -1
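
The async client gains the same checkpoint listing; a minimal sketch, assuming AsyncTogether exposes the resource as client.fine_tuning like the sync client (the job ID is a placeholder):

import asyncio
from together import AsyncTogether

async def main() -> None:
    client = AsyncTogether()
    # list_checkpoints awaits list_events internally and groups the checkpoint events.
    checkpoints = await client.fine_tuning.list_checkpoints("ft-XXXXXXXX")
    for ckpt in checkpoints:
        print(ckpt.type, ckpt.name)

asyncio.run(main())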
@@ -1,4 +1,13 @@
  from together.types.abstract import TogetherClient
+ from together.types.audio_speech import (
+ AudioLanguage,
+ AudioResponseEncoding,
+ AudioResponseFormat,
+ AudioSpeechRequest,
+ AudioSpeechStreamChunk,
+ AudioSpeechStreamEvent,
+ AudioSpeechStreamResponse,
+ )
  from together.types.chat_completions import (
  ChatCompletionChunk,
  ChatCompletionRequest,
@@ -11,6 +20,7 @@ from together.types.completions import (
  CompletionResponse,
  )
  from together.types.embeddings import EmbeddingRequest, EmbeddingResponse
+ from together.types.endpoints import Autoscaling, DedicatedEndpoint, ListEndpoint
  from together.types.files import (
  FileDeleteResponse,
  FileList,
@@ -21,36 +31,25 @@ from together.types.files import (
  FileType,
  )
  from together.types.finetune import (
+ TrainingMethodDPO,
+ TrainingMethodSFT,
+ FinetuneCheckpoint,
  FinetuneDownloadResult,
+ FinetuneLinearLRSchedulerArgs,
  FinetuneList,
  FinetuneListEvents,
+ FinetuneLRScheduler,
  FinetuneRequest,
  FinetuneResponse,
+ FinetuneTrainingLimits,
  FullTrainingType,
  LoRATrainingType,
  TrainingType,
- FinetuneTrainingLimits,
- FinetuneLRScheduler,
- FinetuneLinearLRSchedulerArgs,
- )
- from together.types.images import (
- ImageRequest,
- ImageResponse,
  )
+ from together.types.images import ImageRequest, ImageResponse
  from together.types.models import ModelObject
- from together.types.rerank import (
- RerankRequest,
- RerankResponse,
- )
- from together.types.audio_speech import (
- AudioSpeechRequest,
- AudioResponseFormat,
- AudioLanguage,
- AudioResponseEncoding,
- AudioSpeechStreamChunk,
- AudioSpeechStreamEvent,
- AudioSpeechStreamResponse,
- )
+ from together.types.rerank import RerankRequest, RerankResponse
+

  __all__ = [
  "TogetherClient",
@@ -63,6 +62,7 @@ __all__ = [
  "ChatCompletionResponse",
  "EmbeddingRequest",
  "EmbeddingResponse",
+ "FinetuneCheckpoint",
  "FinetuneRequest",
  "FinetuneResponse",
  "FinetuneList",
@@ -83,6 +83,8 @@ __all__ = [
  "TrainingType",
  "FullTrainingType",
  "LoRATrainingType",
+ "TrainingMethodDPO",
+ "TrainingMethodSFT",
  "RerankRequest",
  "RerankResponse",
  "FinetuneTrainingLimits",
@@ -93,4 +95,7 @@ __all__ = [
  "AudioSpeechStreamChunk",
  "AudioSpeechStreamEvent",
  "AudioSpeechStreamResponse",
+ "DedicatedEndpoint",
+ "ListEndpoint",
+ "Autoscaling",
  ]
@@ -44,16 +44,22 @@ class ToolCalls(BaseModel):
  class ChatCompletionMessageContentType(str, Enum):
  TEXT = "text"
  IMAGE_URL = "image_url"
+ VIDEO_URL = "video_url"


  class ChatCompletionMessageContentImageURL(BaseModel):
  url: str


+ class ChatCompletionMessageContentVideoURL(BaseModel):
+ url: str
+
+
  class ChatCompletionMessageContent(BaseModel):
  type: ChatCompletionMessageContentType
  text: str | None = None
  image_url: ChatCompletionMessageContentImageURL | None = None
+ video_url: ChatCompletionMessageContentVideoURL | None = None


  class ChatCompletionMessage(BaseModel):
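
The new video part mirrors the existing image part; a hedged sketch of a request message using it (the URL is a placeholder, and model support for video input is not implied by this diff):

# Hypothetical chat message content using the new "video_url" part type.
message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe this clip."},
        {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}},
    ],
}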
@@ -0,0 +1,123 @@
+ from __future__ import annotations
+
+ from datetime import datetime
+ from typing import Any, Dict, Literal, Optional, Union
+
+ from pydantic import BaseModel, Field
+
+
+ class TogetherJSONModel(BaseModel):
+ """Base model with JSON serialization support."""
+
+ def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
+ exclude_none = kwargs.pop("exclude_none", True)
+ data = super().model_dump(exclude_none=exclude_none, **kwargs)
+
+ # Convert datetime objects to ISO format strings
+ for key, value in data.items():
+ if isinstance(value, datetime):
+ data[key] = value.isoformat()
+
+ return data
+
+
+ class Autoscaling(TogetherJSONModel):
+ """Configuration for automatic scaling of replicas based on demand."""
+
+ min_replicas: int = Field(
+ description="The minimum number of replicas to maintain, even when there is no load"
+ )
+ max_replicas: int = Field(
+ description="The maximum number of replicas to scale up to under load"
+ )
+
+
+ class EndpointPricing(TogetherJSONModel):
+ """Pricing details for using an endpoint."""
+
+ cents_per_minute: float = Field(
+ description="Cost per minute of endpoint uptime in cents"
+ )
+
+
+ class HardwareSpec(TogetherJSONModel):
+ """Detailed specifications of a hardware configuration."""
+
+ gpu_type: str = Field(description="The type/model of GPU")
+ gpu_link: str = Field(description="The GPU interconnect technology")
+ gpu_memory: Union[float, int] = Field(description="Amount of GPU memory in GB")
+ gpu_count: int = Field(description="Number of GPUs in this configuration")
+
+
+ class HardwareAvailability(TogetherJSONModel):
+ """Indicates the current availability status of a hardware configuration."""
+
+ status: Literal["available", "unavailable", "insufficient"] = Field(
+ description="The availability status of the hardware configuration"
+ )
+
+
+ class HardwareWithStatus(TogetherJSONModel):
+ """Hardware configuration details with optional availability status."""
+
+ object: Literal["hardware"] = Field(description="The type of object")
+ id: str = Field(description="Unique identifier for the hardware configuration")
+ pricing: EndpointPricing = Field(
+ description="Pricing details for this hardware configuration"
+ )
+ specs: HardwareSpec = Field(description="Detailed specifications of this hardware")
+ availability: Optional[HardwareAvailability] = Field(
+ default=None,
+ description="Current availability status of this hardware configuration",
+ )
+ updated_at: datetime = Field(
+ description="Timestamp of when the hardware status was last updated"
+ )
+
+
+ class BaseEndpoint(TogetherJSONModel):
+ """Base class for endpoint models with common fields."""
+
+ object: Literal["endpoint"] = Field(description="The type of object")
+ id: Optional[str] = Field(
+ default=None, description="Unique identifier for the endpoint"
+ )
+ name: str = Field(description="System name for the endpoint")
+ model: str = Field(description="The model deployed on this endpoint")
+ type: str = Field(description="The type of endpoint")
+ owner: str = Field(description="The owner of this endpoint")
+ state: Literal[
+ "PENDING", "STARTING", "STARTED", "STOPPING", "STOPPED", "FAILED", "ERROR"
+ ] = Field(description="Current state of the endpoint")
+ created_at: datetime = Field(description="Timestamp when the endpoint was created")
+
+
+ class ListEndpoint(BaseEndpoint):
+ """Details about an endpoint when listed via the list endpoint."""
+
+ type: Literal["dedicated", "serverless"] = Field(description="The type of endpoint")
+
+
+ class DedicatedEndpoint(BaseEndpoint):
+ """Details about a dedicated endpoint deployment."""
+
+ id: str = Field(description="Unique identifier for the endpoint")
+ type: Literal["dedicated"] = Field(description="The type of endpoint")
+ display_name: str = Field(description="Human-readable name for the endpoint")
+ hardware: str = Field(
+ description="The hardware configuration used for this endpoint"
+ )
+ autoscaling: Autoscaling = Field(
+ description="Configuration for automatic scaling of the endpoint"
+ )
+
+
+ __all__ = [
+ "DedicatedEndpoint",
+ "ListEndpoint",
+ "Autoscaling",
+ "EndpointPricing",
+ "HardwareSpec",
+ "HardwareAvailability",
+ "HardwareWithStatus",
+ ]
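
A brief sketch of the new endpoint models in isolation, showing the custom model_dump() behavior defined above (all values are illustrative; the import path follows the re-exports added to together.types earlier in this diff):

from datetime import datetime, timezone
from together.types import Autoscaling, DedicatedEndpoint

# Illustrative values only; this just exercises the Pydantic models defined above.
endpoint = DedicatedEndpoint(
    object="endpoint",
    id="endpoint-XXXX",
    name="example-endpoint",
    display_name="Example endpoint",
    model="example-org/example-model",
    type="dedicated",
    owner="example-org",
    state="STARTED",
    created_at=datetime.now(timezone.utc),
    hardware="example-gpu-config",
    autoscaling=Autoscaling(min_replicas=1, max_replicas=4),
)
print(endpoint.model_dump())  # datetimes become ISO strings; None fields are dropped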
@@ -135,6 +135,31 @@ class LoRATrainingType(TrainingType):
  type: str = "Lora"


+ class TrainingMethod(BaseModel):
+ """
+ Training method type
+ """
+
+ method: str
+
+
+ class TrainingMethodSFT(TrainingMethod):
+ """
+ Training method type for SFT training
+ """
+
+ method: Literal["sft"] = "sft"
+
+
+ class TrainingMethodDPO(TrainingMethod):
+ """
+ Training method type for DPO training
+ """
+
+ method: Literal["dpo"] = "dpo"
+ dpo_beta: float | None = None
+
+
  class FinetuneRequest(BaseModel):
  """
  Fine-tune request type
@@ -178,6 +203,12 @@ class FinetuneRequest(BaseModel):
  training_type: FullTrainingType | LoRATrainingType | None = None
  # train on inputs
  train_on_inputs: StrictBool | Literal["auto"] = "auto"
+ # training method
+ training_method: TrainingMethodSFT | TrainingMethodDPO = Field(
+ default_factory=TrainingMethodSFT
+ )
+ # from step
+ from_checkpoint: str | None = None


  class FinetuneResponse(BaseModel):
@@ -256,6 +287,7 @@ class FinetuneResponse(BaseModel):
  training_file_num_lines: int | None = Field(None, alias="TrainingFileNumLines")
  training_file_size: int | None = Field(None, alias="TrainingFileSize")
  train_on_inputs: StrictBool | Literal["auto"] | None = "auto"
+ from_checkpoint: str | None = None

  @field_validator("training_type")
  @classmethod
@@ -320,3 +352,16 @@ class FinetuneLRScheduler(BaseModel):

  class FinetuneLinearLRSchedulerArgs(BaseModel):
  min_lr_ratio: float | None = 0.0
+
+
+ class FinetuneCheckpoint(BaseModel):
+ """
+ Fine-tuning checkpoint information
+ """
+
+ # checkpoint type (e.g. "Intermediate", "Final", "Final Merged", "Final Adapter")
+ type: str
+ # timestamp when the checkpoint was created
+ timestamp: str
+ # checkpoint name/identifier
+ name: str
@@ -8,6 +8,8 @@ from together.utils.tools import (
  finetune_price_to_dollars,
  normalize_key,
  parse_timestamp,
+ format_timestamp,
+ get_event_step,
  )

@@ -23,6 +25,8 @@ __all__ = [
  "enforce_trailing_slash",
  "normalize_key",
  "parse_timestamp",
+ "format_timestamp",
+ "get_event_step",
  "finetune_price_to_dollars",
  "convert_bytes",
  "convert_unix_timestamp",