together 2.0.0a10-py3-none-any.whl → 2.0.0a12-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- together/_base_client.py +8 -2
- together/_version.py +1 -1
- together/lib/cli/api/fine_tuning.py +20 -3
- together/lib/cli/api/utils.py +87 -6
- together/lib/constants.py +9 -0
- together/lib/resources/files.py +65 -6
- together/lib/resources/fine_tuning.py +15 -1
- together/lib/types/fine_tuning.py +36 -0
- together/lib/utils/files.py +187 -29
- together/resources/audio/transcriptions.py +6 -4
- together/resources/audio/translations.py +6 -4
- together/resources/fine_tuning.py +25 -17
- together/types/audio/transcription_create_params.py +5 -2
- together/types/audio/translation_create_params.py +5 -2
- together/types/fine_tuning_cancel_response.py +14 -0
- together/types/fine_tuning_list_response.py +14 -0
- together/types/finetune_response.py +28 -2
- {together-2.0.0a10.dist-info → together-2.0.0a12.dist-info}/METADATA +3 -3
- {together-2.0.0a10.dist-info → together-2.0.0a12.dist-info}/RECORD +22 -22
- {together-2.0.0a10.dist-info → together-2.0.0a12.dist-info}/licenses/LICENSE +1 -1
- {together-2.0.0a10.dist-info → together-2.0.0a12.dist-info}/WHEEL +0 -0
- {together-2.0.0a10.dist-info → together-2.0.0a12.dist-info}/entry_points.txt +0 -0
together/_base_client.py
CHANGED
@@ -1247,9 +1247,12 @@ class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
         *,
         cast_to: Type[ResponseT],
         body: Body | None = None,
+        files: RequestFiles | None = None,
         options: RequestOptions = {},
     ) -> ResponseT:
-        opts = FinalRequestOptions.construct(
+        opts = FinalRequestOptions.construct(
+            method="patch", url=path, json_data=body, files=to_httpx_files(files), **options
+        )
         return self.request(cast_to, opts)
 
     def put(
@@ -1767,9 +1770,12 @@ class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
         *,
         cast_to: Type[ResponseT],
         body: Body | None = None,
+        files: RequestFiles | None = None,
         options: RequestOptions = {},
     ) -> ResponseT:
-        opts = FinalRequestOptions.construct(
+        opts = FinalRequestOptions.construct(
+            method="patch", url=path, json_data=body, files=await async_to_httpx_files(files), **options
+        )
         return await self.request(cast_to, opts)
 
     async def put(
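The net effect of both hunks is that the sync and async `patch` helpers now accept an optional `files` argument, converted with `to_httpx_files` / `async_to_httpx_files` before the request options are built. A minimal caller sketch; the endpoint path, body field, and response type below are invented for illustration, and only the keyword names (`cast_to`, `body`, `files`) come from the signature shown above:

# Hypothetical caller: `client` is assumed to be a SyncAPIClient instance and the
# path/body are placeholders -- this sketches the new signature, not an endpoint
# that necessarily exists.
def patch_with_upload(client, local_path: str):
    with open(local_path, "rb") as fh:
        return client.patch(
            "/hypothetical/endpoint",
            cast_to=dict,                        # stand-in for a real response model
            body={"description": "example"},     # illustrative JSON body
            files=[("file", (local_path, fh))],  # forwarded through to_httpx_files(...)
        )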
together/_version.py
CHANGED
together/lib/cli/api/fine_tuning.py
CHANGED

@@ -10,6 +10,7 @@ from textwrap import wrap
 import click
 from rich import print as rprint
 from tabulate import tabulate
+from rich.json import JSON
 from click.core import ParameterSource  # type: ignore[attr-defined]
 
 from together import Together
@@ -17,7 +18,7 @@ from together.types import fine_tuning_estimate_price_params as pe_params
 from together._types import NOT_GIVEN, NotGiven
 from together.lib.utils import log_warn
 from together.lib.utils.tools import format_timestamp, finetune_price_to_dollars
-from together.lib.cli.api.utils import INT_WITH_MAX, BOOL_WITH_AUTO
+from together.lib.cli.api.utils import INT_WITH_MAX, BOOL_WITH_AUTO, generate_progress_bar
 from together.lib.resources.files import DownloadManager
 from together.lib.utils.serializer import datetime_serializer
 from together.types.finetune_response import TrainingTypeFullTrainingType, TrainingTypeLoRaTrainingType
@@ -175,6 +176,12 @@ def fine_tuning(ctx: click.Context) -> None:
     help="Whether to mask the user messages in conversational data or prompts in instruction data. "
     "`auto` will automatically determine whether to mask the inputs based on the data format.",
 )
+@click.option(
+    "--train-vision",
+    type=bool,
+    default=False,
+    help="Whether to train the vision encoder. Only supported for multimodal models.",
+)
 @click.option(
     "--from-checkpoint",
     type=str,
@@ -230,6 +237,7 @@ def create(
     lora_dropout: float | None,
     lora_alpha: float | None,
     lora_trainable_modules: str | None,
+    train_vision: bool,
     suffix: str | None,
     wandb_api_key: str | None,
     wandb_base_url: str | None,
@@ -271,6 +279,7 @@ def create(
         lora_dropout=lora_dropout,
         lora_alpha=lora_alpha,
         lora_trainable_modules=lora_trainable_modules,
+        train_vision=train_vision,
         suffix=suffix,
         wandb_api_key=wandb_api_key,
         wandb_base_url=wandb_base_url,
@@ -361,7 +370,11 @@ def create(
         rpo_alpha=rpo_alpha or 0,
         simpo_gamma=simpo_gamma or 0,
     )
-
+
+    if model_limits.supports_vision:
+        # Don't show price estimation for multimodal models yet
+        confirm = True
+
     finetune_price_estimation_result = client.fine_tuning.estimate_price(
         training_file=training_file,
         validation_file=validation_file,
@@ -425,6 +438,7 @@ def list(ctx: click.Context) -> None:
                "Price": f"""${
                    finetune_price_to_dollars(float(str(i.total_price)))
                }""",  # convert to string for mypy typing
+               "Progress": generate_progress_bar(i, datetime.now().astimezone(), use_rich=False),
             }
         )
     table = tabulate(display_list, headers="keys", tablefmt="grid", showindex=True)
@@ -444,7 +458,10 @@ def retrieve(ctx: click.Context, fine_tune_id: str) -> None:
     # remove events from response for cleaner output
     response.events = None
 
-
+    rprint(JSON.from_data(response.model_json_schema()))
+    progress_text = generate_progress_bar(response, datetime.now().astimezone(), use_rich=True)
+    prefix = f"Status: [bold]{response.status}[/bold],"
+    rprint(f"{prefix} {progress_text}")
 
 
 @fine_tuning.command()
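These hunks register a new `--train-vision` option on `create`, add a Progress column to `list`, and print a status/progress line in `retrieve`. A hedged check that the flag is wired up, assuming the click command object is importable from the module path shown above (invoking the CLI itself would need API credentials, so this only inspects the registered parameters):

# Sketch only: `create` is the click command defined in together/lib/cli/api/fine_tuning.py.
from together.lib.cli.api.fine_tuning import create as create_command

param_names = {param.name for param in create_command.params}
assert "train_vision" in param_names  # exposed on the CLI as --train-vision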
together/lib/cli/api/utils.py
CHANGED
@@ -1,18 +1,25 @@
 from __future__ import annotations
 
-
+import re
+import math
+from typing import List, Union, Literal
 from gettext import gettext as _
-from
+from datetime import datetime
 
 import click
 
+from together.lib.types.fine_tuning import COMPLETED_STATUSES, FinetuneResponse
+from together.types.finetune_response import FinetuneResponse as _FinetuneResponse
+from together.types.fine_tuning_list_response import Data
+
+_PROGRESS_BAR_WIDTH = 40
+
 
 class AutoIntParamType(click.ParamType):
     name = "integer_or_max"
     _number_class = int
 
-
-    def convert(
+    def convert(  # pyright: ignore[reportImplicitOverride]
         self, value: str, param: click.Parameter | None, ctx: click.Context | None
     ) -> int | Literal["max"] | None:
         if value == "max":
@@ -30,8 +37,7 @@ class AutoIntParamType(click.ParamType):
 class BooleanWithAutoParamType(click.ParamType):
     name = "boolean_or_auto"
 
-
-    def convert(
+    def convert(  # pyright: ignore[reportImplicitOverride]
         self, value: str, param: click.Parameter | None, ctx: click.Context | None
     ) -> bool | Literal["auto"] | None:
         if value == "auto":
@@ -48,3 +54,78 @@ class BooleanWithAutoParamType(click.ParamType):
 
 INT_WITH_MAX = AutoIntParamType()
 BOOL_WITH_AUTO = BooleanWithAutoParamType()
+
+
+def _human_readable_time(timedelta: float) -> str:
+    """Convert a timedelta to a compact human-readble string
+    Examples:
+        00:00:10 -> 10s
+        01:23:45 -> 1h 23min 45s
+        1 Month 23 days 04:56:07 -> 1month 23d 4h 56min 7s
+    Args:
+        timedelta (float): The timedelta in seconds to convert.
+    Returns:
+        A string representing the timedelta in a human-readable format.
+    """
+    units = [
+        (30 * 24 * 60 * 60, "month"),  # 30 days
+        (24 * 60 * 60, "d"),
+        (60 * 60, "h"),
+        (60, "min"),
+        (1, "s"),
+    ]
+
+    total_seconds = int(timedelta)
+    parts: List[str] = []
+
+    for unit_seconds, unit_name in units:
+        if total_seconds >= unit_seconds:
+            value = total_seconds // unit_seconds
+            total_seconds %= unit_seconds
+            parts.append(f"{value}{unit_name}")
+
+    return " ".join(parts) if parts else "0s"
+
+
+def generate_progress_bar(
+    finetune_job: Union[Data, FinetuneResponse, _FinetuneResponse], current_time: datetime, use_rich: bool = False
+) -> str:
+    """Generate a progress bar for a finetune job.
+    Args:
+        finetune_job: The finetune job to generate a progress bar for.
+        current_time: The current time.
+        use_rich: Whether to use rich formatting.
+    Returns:
+        A string representing the progress bar.
+    """
+    progress = "Progress: [bold red]unavailable[/bold red]"
+    if finetune_job.status in COMPLETED_STATUSES:
+        progress = "Progress: [bold green]completed[/bold green]"
+    elif finetune_job.updated_at is not None:
+        update_at = finetune_job.updated_at.astimezone()
+
+        if finetune_job.progress is not None:
+            if current_time < update_at:
+                return progress
+
+            if not finetune_job.progress.estimate_available:
+                return progress
+
+            if finetune_job.progress.seconds_remaining <= 0:
+                return progress
+
+            elapsed_time = (current_time - update_at).total_seconds()
+            ratio_filled = min(elapsed_time / finetune_job.progress.seconds_remaining, 1.0)
+            percentage = ratio_filled * 100
+            filled = math.ceil(ratio_filled * _PROGRESS_BAR_WIDTH)
+            bar = "█" * filled + "░" * (_PROGRESS_BAR_WIDTH - filled)
+            time_left = "N/A"
+            if finetune_job.progress.seconds_remaining > elapsed_time:
+                time_left = _human_readable_time(finetune_job.progress.seconds_remaining - elapsed_time)
+            time_text = f"{time_left} left"
+            progress = f"Progress: {bar} [bold]{percentage:>3.0f}%[/bold] [yellow]{time_text}[/yellow]"
+
+    if use_rich:
+        return progress
+
+    return re.sub(r"\[/?[^\]]+\]", "", progress)
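`generate_progress_bar` interpolates between the job's last `updated_at` timestamp and the server-provided `seconds_remaining` estimate, rendering a 40-character bar; when `use_rich=False` the rich markup tags are stripped with a regex. A standalone sketch of just the formatting math, with invented sample values:

import math
import re

_PROGRESS_BAR_WIDTH = 40  # same width as the constant added above

# Invented sample values: 600 s elapsed since the last update, 1800 s estimated to remain.
elapsed_time, seconds_remaining = 600.0, 1800.0

ratio_filled = min(elapsed_time / seconds_remaining, 1.0)
filled = math.ceil(ratio_filled * _PROGRESS_BAR_WIDTH)
bar = "█" * filled + "░" * (_PROGRESS_BAR_WIDTH - filled)
rich_text = f"Progress: {bar} [bold]{ratio_filled * 100:>3.0f}%[/bold]"

# The use_rich=False path removes the markup with the same pattern used in the diff.
plain_text = re.sub(r"\[/?[^\]]+\]", "", rich_text)
print(plain_text)  # e.g. "Progress: ██████████████░░░░…  33%" (bar truncated here for brevity)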
together/lib/constants.py
CHANGED
@@ -14,6 +14,9 @@ import enum
 # Download defaults
 DOWNLOAD_BLOCK_SIZE = 10 * 1024 * 1024  # 10 MB
 DISABLE_TQDM = False
+MAX_DOWNLOAD_RETRIES = 5  # Maximum retries for download failures
+DOWNLOAD_INITIAL_RETRY_DELAY = 1.0  # Initial retry delay in seconds
+DOWNLOAD_MAX_RETRY_DELAY = 30.0  # Maximum retry delay in seconds
 
 # Upload defaults
 MAX_CONCURRENT_PARTS = 4  # Maximum concurrent parts for multipart upload
@@ -34,6 +37,12 @@ NUM_BYTES_IN_GB = 2**30
 # maximum number of GB sized files we support finetuning for
 MAX_FILE_SIZE_GB = 50.1
 
+# Multimodal limits
+MAX_IMAGES_PER_EXAMPLE = 10
+MAX_IMAGE_BYTES = 10 * 1024 * 1024  # 10MB
+# Max length = Header length + base64 factor (4/3) * image bytes
+MAX_BASE64_IMAGE_LENGTH = len("data:image/jpeg;base64,") + 4 * MAX_IMAGE_BYTES // 3
+
 # expected columns for Parquet files
 PARQUET_EXPECTED_COLUMNS = ["input_ids", "attention_mask", "labels"]
 
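The new `MAX_BASE64_IMAGE_LENGTH` bound is just the data-URI prefix length plus roughly 4/3 of the raw image budget (base64 expands every 3 bytes into 4 characters). A quick check of the arithmetic, reusing the expressions from the diff:

MAX_IMAGE_BYTES = 10 * 1024 * 1024                   # 10 MB, as above
prefix_length = len("data:image/jpeg;base64,")       # 23 characters
max_base64_image_length = prefix_length + 4 * MAX_IMAGE_BYTES // 3

print(prefix_length, max_base64_image_length)        # 23 13981036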
together/lib/resources/files.py
CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
 import os
 import math
 import stat
+import time
 import uuid
 import shutil
 import asyncio
@@ -29,12 +30,15 @@ from ..constants import (
     MAX_MULTIPART_PARTS,
     TARGET_PART_SIZE_MB,
     MAX_CONCURRENT_PARTS,
+    MAX_DOWNLOAD_RETRIES,
     MULTIPART_THRESHOLD_GB,
+    DOWNLOAD_MAX_RETRY_DELAY,
     MULTIPART_UPLOAD_TIMEOUT,
+    DOWNLOAD_INITIAL_RETRY_DELAY,
 )
 from ..._resource import SyncAPIResource, AsyncAPIResource
 from ..types.error import DownloadError, FileTypeError
-from ..._exceptions import APIStatusError, AuthenticationError
+from ..._exceptions import APIStatusError, APIConnectionError, AuthenticationError
 
 log: logging.Logger = logging.getLogger(__name__)
 
@@ -198,6 +202,11 @@ class DownloadManager(SyncAPIResource):
 
         assert file_size != 0, "Unable to retrieve remote file."
 
+        # Download with retry logic
+        bytes_downloaded = 0
+        retry_count = 0
+        retry_delay = DOWNLOAD_INITIAL_RETRY_DELAY
+
         with tqdm(
             total=file_size,
             unit="B",
@@ -205,14 +214,64 @@ class DownloadManager(SyncAPIResource):
             desc=f"Downloading file {file_path.name}",
             disable=bool(DISABLE_TQDM),
         ) as pbar:
-
-
-
+            while bytes_downloaded < file_size:
+                try:
+                    # If this is a retry, close the previous response and create a new one with Range header
+                    if bytes_downloaded > 0:
+                        response.close()
+
+                        log.info(f"Resuming download from byte {bytes_downloaded}")
+                        response = self._client.get(
+                            path=url,
+                            cast_to=httpx.Response,
+                            stream=True,
+                            options=RequestOptions(
+                                headers={"Range": f"bytes={bytes_downloaded}-"},
+                            ),
+                        )
+
+                    # Download chunks
+                    for chunk in response.iter_bytes(DOWNLOAD_BLOCK_SIZE):
+                        temp_file.write(chunk)  # type: ignore
+                        bytes_downloaded += len(chunk)
+                        pbar.update(len(chunk))
+
+                    # Successfully completed download
+                    break
+
+                except (httpx.RequestError, httpx.StreamError, APIConnectionError) as e:
+                    if retry_count >= MAX_DOWNLOAD_RETRIES:
+                        log.error(f"Download failed after {retry_count} retries")
+                        raise DownloadError(
+                            f"Download failed after {retry_count} retries. Last error: {str(e)}"
+                        ) from e
+
+                    retry_count += 1
+                    log.warning(
+                        f"Download interrupted at {bytes_downloaded}/{file_size} bytes. "
+                        f"Retry {retry_count}/{MAX_DOWNLOAD_RETRIES} in {retry_delay}s..."
+                    )
+                    time.sleep(retry_delay)
+
+                    # Exponential backoff with max delay cap
+                    retry_delay = min(retry_delay * 2, DOWNLOAD_MAX_RETRY_DELAY)
+
+                except APIStatusError as e:
+                    # For API errors, don't retry
+                    log.error(f"API error during download: {e}")
+                    raise APIStatusError(
+                        "Error downloading file",
+                        response=e.response,
+                        body=e.response,
+                    ) from e
+
+        # Close the response
+        response.close()
 
         # Raise exception if remote file size does not match downloaded file size
         if os.stat(temp_file.name).st_size != file_size:
-            DownloadError(
-                f"Downloaded file size `{
+            raise DownloadError(
+                f"Downloaded file size `{bytes_downloaded}` bytes does not match remote file size `{file_size}` bytes."
             )
 
         # Moves temp file to output file path
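The retry loop resumes from the last byte received via an HTTP `Range` header and backs off exponentially, starting at `DOWNLOAD_INITIAL_RETRY_DELAY` and capped at `DOWNLOAD_MAX_RETRY_DELAY`. A small sketch of the resulting delay schedule, mirroring the `retry_delay = min(retry_delay * 2, ...)` update above:

MAX_DOWNLOAD_RETRIES = 5
DOWNLOAD_INITIAL_RETRY_DELAY = 1.0
DOWNLOAD_MAX_RETRY_DELAY = 30.0

delays = []
retry_delay = DOWNLOAD_INITIAL_RETRY_DELAY
for _ in range(MAX_DOWNLOAD_RETRIES):
    delays.append(retry_delay)
    retry_delay = min(retry_delay * 2, DOWNLOAD_MAX_RETRY_DELAY)

print(delays)  # [1.0, 2.0, 4.0, 8.0, 16.0] -> about 31 s of waiting across the 5 allowed retries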
together/lib/resources/fine_tuning.py
CHANGED

@@ -22,6 +22,7 @@ from together.lib.types.fine_tuning import (
     CosineLRSchedulerArgs,
     LinearLRSchedulerArgs,
     FinetuneTrainingLimits,
+    FinetuneMultimodalParams,
 )
 
 AVAILABLE_TRAINING_METHODS = {
@@ -51,6 +52,7 @@ def create_finetune_request(
     lora_dropout: float | None = 0,
     lora_alpha: float | None = None,
     lora_trainable_modules: str | None = "all-linear",
+    train_vision: bool = False,
     suffix: str | None = None,
     wandb_api_key: str | None = None,
     wandb_base_url: str | None = None,
@@ -207,6 +209,13 @@ def create_finetune_request(
             simpo_gamma=simpo_gamma,
         )
 
+    if model_limits.supports_vision:
+        multimodal_params = FinetuneMultimodalParams(train_vision=train_vision)
+    elif not model_limits.supports_vision and train_vision:
+        raise ValueError(f"Vision encoder training is not supported for the non-multimodal model `{model}`")
+    else:
+        multimodal_params = None
+
     finetune_request = FinetuneRequest(
         model=model,
         training_file=training_file,
@@ -227,6 +236,7 @@
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
         training_method=training_method_cls,  # pyright: ignore[reportPossiblyUnboundVariable]
+        multimodal_params=multimodal_params,
         from_checkpoint=from_checkpoint,
         from_hf_model=from_hf_model,
         hf_model_revision=hf_model_revision,
@@ -238,7 +248,10 @@
 
     return finetune_request, training_type_pe, training_method_pe
 
-
+
+def create_price_estimation_params(
+    finetune_request: FinetuneRequest,
+) -> tuple[pe_params.TrainingType, pe_params.TrainingMethod]:
     training_type_cls: pe_params.TrainingType
     if isinstance(finetune_request.training_type, FullTrainingType):
         training_type_cls = pe_params.TrainingTypeFullTrainingType(
@@ -275,6 +288,7 @@ def create_price_estimation_params(finetune_request: FinetuneRequest) -> tuple[p
 
     return training_type_cls, training_method_cls
 
+
 def get_model_limits(client: Together, model: str) -> FinetuneTrainingLimits:
     """
     Requests training limits for a specific model
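`create_finetune_request` now attaches `multimodal_params` only when the model's training limits report `supports_vision`, and rejects `train_vision=True` for text-only models. A standalone restatement of that guard; the helper name and the sample model strings are illustrative, while `FinetuneMultimodalParams` is the new model from the types diff below:

from together.lib.types.fine_tuning import FinetuneMultimodalParams

def pick_multimodal_params(supports_vision: bool, train_vision: bool, model: str):
    # Mirrors the branch added to create_finetune_request above.
    if supports_vision:
        return FinetuneMultimodalParams(train_vision=train_vision)
    if train_vision:
        raise ValueError(f"Vision encoder training is not supported for the non-multimodal model `{model}`")
    return None

print(pick_multimodal_params(True, True, "example/vision-model"))   # train_vision=True
print(pick_multimodal_params(False, False, "example/text-model"))   # None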
together/lib/types/fine_tuning.py
CHANGED

@@ -25,6 +25,14 @@ class FinetuneJobStatus(str, Enum):
     STATUS_COMPLETED = "completed"
 
 
+COMPLETED_STATUSES = [
+    FinetuneJobStatus.STATUS_ERROR,
+    FinetuneJobStatus.STATUS_USER_ERROR,
+    FinetuneJobStatus.STATUS_COMPLETED,
+    FinetuneJobStatus.STATUS_CANCELLED,
+]
+
+
 class FinetuneEventType(str, Enum):
     """
     Fine-tune job event types
@@ -181,6 +189,7 @@ class TrainingMethodUnknown(BaseModel):
 
     method: str
 
+
 TrainingMethod: TypeAlias = Union[
     TrainingMethodSFT,
     TrainingMethodDPO,
@@ -194,6 +203,7 @@ class FinetuneTrainingLimits(BaseModel):
     min_learning_rate: float
     full_training: Optional[FinetuneFullTrainingLimits] = None
     lora_training: Optional[FinetuneLoraTrainingLimits] = None
+    supports_vision: bool = False
 
 
 class LinearLRSchedulerArgs(BaseModel):
@@ -241,6 +251,7 @@ class EmptyLRScheduler(BaseModel):
     lr_scheduler_type: Literal[""]
     lr_scheduler_args: None = None
 
+
 class UnknownLRScheduler(BaseModel):
     """
     Unknown learning rate scheduler
@@ -260,6 +271,23 @@ FinetuneLRScheduler: TypeAlias = Union[
 ]
 
 
+class FinetuneMultimodalParams(BaseModel):
+    """
+    Multimodal parameters
+    """
+
+    train_vision: bool = False
+
+
+class FinetuneProgress(BaseModel):
+    """
+    Fine-tune job progress
+    """
+
+    estimate_available: bool = False
+    seconds_remaining: float = 0
+
+
 class FinetuneResponse(BaseModel):
     """
     Fine-tune API response type
@@ -286,6 +314,9 @@ class FinetuneResponse(BaseModel):
     from_checkpoint: Optional[str] = None
     """Checkpoint used to continue training"""
 
+    multimodal_params: Optional[FinetuneMultimodalParams] = None
+    """Multimodal parameters"""
+
     from_hf_model: Optional[str] = None
     """Hugging Face Hub repo to start training from"""
 
@@ -393,6 +424,8 @@ class FinetuneResponse(BaseModel):
     training_file_size: Optional[int] = Field(None, alias="TrainingFileSize")
     train_on_inputs: Union[StrictBool, Literal["auto"], None] = "auto"
 
+    progress: Union[FinetuneProgress, None] = None
+
     @classmethod
     def validate_training_type(cls, v: TrainingType) -> TrainingType:
         if v.type == "Full" or v.type == "":
@@ -448,6 +481,9 @@ class FinetuneRequest(BaseModel):
     training_method: TrainingMethod = Field(default_factory=TrainingMethodSFT)
     # from step
     from_checkpoint: Union[str, None] = None
+    # multimodal parameters
+    multimodal_params: Union[FinetuneMultimodalParams, None] = None
+    # hugging face related fields
     from_hf_model: Union[str, None] = None
     hf_model_revision: Union[str, None] = None
     # hf related fields