together 2.0.0a10__py3-none-any.whl → 2.0.0a12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- together/_base_client.py +8 -2
- together/_version.py +1 -1
- together/lib/cli/api/fine_tuning.py +20 -3
- together/lib/cli/api/utils.py +87 -6
- together/lib/constants.py +9 -0
- together/lib/resources/files.py +65 -6
- together/lib/resources/fine_tuning.py +15 -1
- together/lib/types/fine_tuning.py +36 -0
- together/lib/utils/files.py +187 -29
- together/resources/audio/transcriptions.py +6 -4
- together/resources/audio/translations.py +6 -4
- together/resources/fine_tuning.py +25 -17
- together/types/audio/transcription_create_params.py +5 -2
- together/types/audio/translation_create_params.py +5 -2
- together/types/fine_tuning_cancel_response.py +14 -0
- together/types/fine_tuning_list_response.py +14 -0
- together/types/finetune_response.py +28 -2
- {together-2.0.0a10.dist-info → together-2.0.0a12.dist-info}/METADATA +3 -3
- {together-2.0.0a10.dist-info → together-2.0.0a12.dist-info}/RECORD +22 -22
- {together-2.0.0a10.dist-info → together-2.0.0a12.dist-info}/licenses/LICENSE +1 -1
- {together-2.0.0a10.dist-info → together-2.0.0a12.dist-info}/WHEEL +0 -0
- {together-2.0.0a10.dist-info → together-2.0.0a12.dist-info}/entry_points.txt +0 -0
together/lib/utils/files.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 import os
 import csv
 import json
-from typing import Any, Dict, List, cast
+from typing import Any, Dict, List, Union, cast
 from pathlib import Path
 from traceback import format_exc
 
@@ -13,8 +13,11 @@ from together.types import FilePurpose
 from together.lib.constants import (
     MIN_SAMPLES,
     DISABLE_TQDM,
+    MAX_IMAGE_BYTES,
     NUM_BYTES_IN_GB,
     MAX_FILE_SIZE_GB,
+    MAX_IMAGES_PER_EXAMPLE,
+    MAX_BASE64_IMAGE_LENGTH,
     PARQUET_EXPECTED_COLUMNS,
     REQUIRED_COLUMNS_MESSAGE,
     JSONL_REQUIRED_COLUMNS_MAP,
@@ -22,6 +25,15 @@ from together.lib.constants import (
     DatasetFormat,
 )
 
+# MessageContent is a string or a list of dicts with 'type': 'text' or 'image_url', and 'text' or 'image_url.url'
+# Example: "Hello" or [
+#     {"type": "text", "text": "Hello"},
+#     {"type": "image_url", "image_url": {
+#         "url": "data:image/jpeg;base64,..."
+#     }}
+# ]
+MessageContent = Union[str, List[Dict[str, Any]]]
+
 
 class InvalidFileFormatError(ValueError):
     """Exception raised for invalid file formats during file checks."""
@@ -103,7 +115,7 @@ def check_file(
         return report_dict
 
 
-def _check_conversation_type(messages: List[Dict[str, str | bool]], idx: int) -> None:
+def _check_conversation_type(messages: List[Dict[str, str | int | MessageContent]], idx: int) -> None:
     """Check that the conversation has correct type.
 
     Args:
@@ -144,12 +156,6 @@ def _check_conversation_type(messages: List[Dict[str, str | bool]], idx: int) -> None:
                 line_number=idx + 1,
                 error_source="key_value",
             )
-            if not isinstance(message[column], str):
-                raise InvalidFileFormatError(
-                    message=f"Column `{column}` is not a string on line {idx + 1}. Found {type(message[column])}",
-                    line_number=idx + 1,
-                    error_source="text_field",
-                )
 
 
 def _check_conversation_roles(require_assistant_role: bool, assistant_role_exists: bool, idx: int) -> None:
@@ -172,7 +178,7 @@ def _check_conversation_roles(require_assistant_role: bool, assistant_role_exists: bool, idx: int) -> None:
     )
 
 
-def _check_message_weight(message: Dict[str, str | bool], idx: int) -> None:
+def _check_message_weight(message: Dict[str, str | int | MessageContent], idx: int) -> int | None:
     """Check that the message has a weight with the correct type and value.
 
     Args:
@@ -196,9 +202,12 @@ def _check_message_weight(message: Dict[str, str | bool], idx: int) -> None:
             line_number=idx + 1,
             error_source="key_value",
         )
+        return weight
+
+    return None
 
 
-def _check_message_role(message: Dict[str, str | bool], previous_role: str | bool, idx: int) -> str:
+def _check_message_role(message: Dict[str, str | int | MessageContent], previous_role: str | None, idx: int) -> str:
     """Check that the message has correct roles.
 
     Args:
@@ -212,6 +221,14 @@ def _check_message_role(message: Dict[str, str | bool], previous_role: str | bool, idx: int) -> str:
     Raises:
         InvalidFileFormatError: If the message role is invalid.
     """
+    if not isinstance(message["role"], str):
+        raise InvalidFileFormatError(
+            message=f"Invalid role `{message['role']}` in conversation on line {idx + 1}. "
+            f"Role must be a string. Found {type(message['role'])}",
+            line_number=idx + 1,
+            error_source="key_value",
+        )
+
     if message["role"] not in POSSIBLE_ROLES_CONVERSATION:
         raise InvalidFileFormatError(
             message=f"Invalid role `{message['role']}` in conversation on line {idx + 1}. "
@@ -229,7 +246,130 @@ def _check_message_role(message: Dict[str, str | bool], previous_role: str | bool, idx: int) -> str:
         return message["role"]
 
 
-def validate_messages(messages: List[Dict[str, str | bool]], idx: int, require_assistant_role: bool = True) -> None:
+def _check_message_content(message_content: str | int | MessageContent, role: str, idx: int) -> tuple[bool, int]:
+    """Check that the message content has the correct type.
+
+    Message content can be either a) a string or b) an OpenAI-style multimodal list of content items
+    Example:
+        a) "Hello", or
+        b) [
+            {"type": "text", "text": "Hello"},
+            {"type": "image_url", "image_url": {
+                "url": "data:image/jpeg;base64,..."
+            }}
+        ]
+
+    Args:
+        message_content: The message content to check.
+        role: The role of the message.
+        idx: Line number in the file.
+
+    Returns:
+        tuple[bool, int]: A tuple with message is multimodal and the number of images in the message content.
+    """
+    # Text-only message content
+    if isinstance(message_content, str):
+        return False, 0
+
+    # Multimodal message content
+    if isinstance(message_content, list):
+        num_images = 0
+        for item in message_content:
+            if not isinstance(cast(Any, item), dict):
+                raise InvalidFileFormatError(
+                    "The dataset is malformed, the `content` field must be a list of dicts.",
+                    line_number=idx + 1,
+                    error_source="key_value",
+                )
+            if "type" not in item:
+                raise InvalidFileFormatError(
+                    "The dataset is malformed, the `content` field must be a list of dicts with a `type` field.",
+                    line_number=idx + 1,
+                    error_source="key_value",
+                )
+
+            if item["type"] == "text":
+                if "text" not in item or not isinstance(item["text"], str):
+                    raise InvalidFileFormatError(
+                        "The dataset is malformed, the `text` field must be present in the `content` item field and be"
+                        f" a string. Got '{item.get('text')!r}' instead.",
+                        line_number=idx + 1,
+                        error_source="key_value",
+                    )
+            elif item["type"] == "image_url":
+                if role != "user":
+                    raise InvalidFileFormatError(
+                        "The dataset is malformed, only user messages can contain images.",
+                        line_number=idx + 1,
+                        error_source="key_value",
+                    )
+
+                if "image_url" not in item or not isinstance(item["image_url"], dict):
+                    raise InvalidFileFormatError(
+                        "The dataset is malformed, the `image_url` field must be present in the `content` field and "
+                        f"be a dictionary. Got {item.get('image_url')!r} instead.",
+                        line_number=idx + 1,
+                        error_source="key_value",
+                    )
+
+                image_data = cast(Any, item["image_url"]).get("url")
+                if not image_data or not isinstance(image_data, str):
+                    raise InvalidFileFormatError(
+                        "The dataset is malformed, the `url` field must be present in the `image_url` field and be "
+                        f"a string. Got {image_data!r} instead.",
+                        line_number=idx + 1,
+                        error_source="key_value",
+                    )
+
+                if not any(image_data.startswith(f"data:image/{fmt};base64,") for fmt in ["jpeg", "png", "webp"]):
+                    raise InvalidFileFormatError(
+                        "The dataset is malformed, the `url` field must be either a JPEG, PNG or WEBP base64-encoded "
+                        "image in 'data:image/<format>;base64,<base64_encoded_image>' format. "
+                        f"Got '{image_data[:100]}...' instead.",
+                        line_number=idx + 1,
+                    )
+
+                if len(image_data) > MAX_BASE64_IMAGE_LENGTH:
+                    raise InvalidFileFormatError(
+                        "The dataset is malformed, the `url` field must contain base64-encoded image "
+                        f"that is less than {MAX_IMAGE_BYTES // (1024**2)}MB, found ~{len(image_data) * 3 // 4} bytes.",
+                        line_number=idx + 1,
+                        error_source="key_value",
+                    )
+
+                num_images += 1
+            else:
+                raise InvalidFileFormatError(
+                    "The dataset is malformed, the `type` field must be either 'text' or 'image_url'. "
+                    f"Got {item['type']!r}.",
+                    line_number=idx + 1,
+                    error_source="key_value",
+                )
+
+        if num_images > MAX_IMAGES_PER_EXAMPLE:
+            raise InvalidFileFormatError(
+                f"The dataset is malformed, the `content` field must contain at most "
+                f"{MAX_IMAGES_PER_EXAMPLE} images, found {num_images}.",
+                line_number=idx + 1,
+                error_source="key_value",
+            )
+
+        # We still consider text-only messages in such format as multimodal, even if they don't have any images
+        # included - so we can process datasets with rather sparse images (i.e. not in each sample) consistently.
+        return True, num_images
+
+    raise InvalidFileFormatError(
+        f"Invalid content type on line {idx + 1} of the input file. Expected string or multimodal list of dicts, "
+        f"found {type(message_content)}",
+        line_number=idx + 1,
+        error_source="key_value",
+    )
+
+
+def validate_messages(
+    messages: List[Dict[str, str | int | MessageContent]],
+    idx: int,
+    require_assistant_role: bool = True,
+) -> None:
     """Validate the messages column.
 
     Args:
@@ -242,15 +382,43 @@ def validate_messages(messages: List[Dict[str, str | bool]], idx: int, require_assistant_role: bool = True) -> None:
     """
     _check_conversation_type(messages, idx)
 
-    has_weights = any("weight" in message for message in messages)
     previous_role = None
     assistant_role_exists = False
 
+    messages_are_multimodal: bool | None = None
+    total_number_of_images = 0
+
     for message in messages:
-
-        _check_message_weight(message, idx)
+        message_weight = _check_message_weight(message, idx)
         previous_role = _check_message_role(message, previous_role, idx)
         assistant_role_exists |= previous_role == "assistant"
+        is_multimodal, number_of_images = _check_message_content(message["content"], role=previous_role, idx=idx)
+        # Multimodal validation
+        if number_of_images > 0 and message_weight is not None and message_weight != 0:
+            raise InvalidFileFormatError(
+                "Messages with images cannot have non-zero weights.",
+                line_number=idx + 1,
+                error_source="key_value",
+            )
+        if messages_are_multimodal is None:
+            # Detect the format of the messages in the conversation.
+            messages_are_multimodal = is_multimodal
+        elif messages_are_multimodal != is_multimodal:
+            # Due to the format limitation, we cannot mix multimodal and text only messages in the same sample.
+            raise InvalidFileFormatError(
+                "Messages in the conversation must be either all in multimodal or all in text-only format.",
+                line_number=idx + 1,
+                error_source="key_value",
+            )
+        total_number_of_images += number_of_images
+
+    if total_number_of_images > MAX_IMAGES_PER_EXAMPLE:
+        raise InvalidFileFormatError(
+            f"The dataset is malformed, the `messages` must contain at most {MAX_IMAGES_PER_EXAMPLE} images. "
+            f"Found {total_number_of_images} images.",
+            line_number=idx + 1,
+            error_source="key_value",
+        )
 
     _check_conversation_roles(require_assistant_role, assistant_role_exists, idx)
@@ -279,7 +447,7 @@ def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> None:
             error_source="key_value",
         )
 
-    messages: List[Dict[str, str | bool]] = cast(Any, example["input"]["messages"])
+    messages: List[Dict[str, str | int | MessageContent]] = cast(Any, example["input"]["messages"])
     validate_messages(messages, idx, require_assistant_role=False)
 
     if example["input"]["messages"][-1]["role"] == "assistant":
@@ -341,12 +509,7 @@ def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> None:
             error_source="key_value",
        )
 
-
-        raise InvalidFileFormatError(
-            message=f"The dataset is malformed, the 'content' field in `{key}` must be a string on line {idx + 1}.",
-            line_number=idx + 1,
-            error_source="key_value",
-        )
+    _check_message_content(example[key][0]["content"], role="assistant", idx=idx)
 
 
 def _check_utf8(file: Path) -> Dict[str, Any]:
@@ -514,7 +677,7 @@ def _check_jsonl(file: Path, purpose: FilePurpose | str) -> Dict[str, Any]:
             elif current_format == DatasetFormat.CONVERSATION:
                 message_column = JSONL_REQUIRED_COLUMNS_MAP[DatasetFormat.CONVERSATION][0]
                 require_assistant = purpose != "eval"
-                message: List[Dict[str, str | bool]] = cast(Any, json_line[message_column])
+                message: List[Dict[str, str | int | MessageContent]] = cast(Any, json_line[message_column])
                 validate_messages(
                     message,
                     idx,
@@ -522,13 +685,8 @@ def _check_jsonl(file: Path, purpose: FilePurpose | str) -> Dict[str, Any]:
                 )
             else:
                 for column in JSONL_REQUIRED_COLUMNS_MAP[current_format]:
-                    if not isinstance(json_line[column], str):
-                        raise InvalidFileFormatError(
-                            message=f'Invalid value type for "{column}" key on line {idx + 1}. '
-                            f"Expected string. Found {type(cast(Any, json_line[column]))}.",
-                            line_number=idx + 1,
-                            error_source="key_value",
-                        )
+                    role = "assistant" if column in {"completion"} else "user"
+                    _check_message_content(cast(Any, json_line[column]), role=role, idx=idx)
 
     if dataset_format is None:
         dataset_format = current_format
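For context on the multimodal checks above, here is a minimal sketch of a conversation sample that the new `validate_messages` would accept. The image payload is a placeholder, and the actual limits (`MAX_IMAGES_PER_EXAMPLE`, `MAX_BASE64_IMAGE_LENGTH`, `MAX_IMAGE_BYTES`) live in `together/lib/constants.py`, whose values this diff does not show:

import base64
import json

# Placeholder image bytes; any JPEG, PNG, or WEBP under the size limit works.
image_b64 = base64.b64encode(b"<jpeg bytes here>").decode()

# One JSONL line in the conversation format. The new checks require that
# images appear only in "user" messages, that a message carrying images
# has no non-zero weight, and that every message in a sample uses the same
# style (all multimodal here, including the text-only assistant turn).
sample = {
    "messages": [
        {
            "role": "user",
            "weight": 0,
            "content": [
                {"type": "text", "text": "What is in this picture?"},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"}},
            ],
        },
        {"role": "assistant", "content": [{"type": "text", "text": "A cat on a chair."}]},
    ]
}
print(json.dumps(sample))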
together/resources/audio/transcriptions.py
CHANGED
@@ -47,7 +47,7 @@ class TranscriptionsResource(SyncAPIResource):
     def create(
         self,
         *,
-        file: FileTypes,
+        file: Union[FileTypes, str],
         diarize: bool | Omit = omit,
         language: str | Omit = omit,
         max_speakers: int | Omit = omit,
@@ -68,7 +68,8 @@ class TranscriptionsResource(SyncAPIResource):
         Transcribes audio into text
 
         Args:
-          file: Audio file
+          file: Audio file upload or public HTTP/HTTPS URL. Supported formats .wav, .mp3, .m4a,
+              .webm, .flac.
 
           diarize: Whether to enable speaker diarization. When enabled, you will get the speaker id
               for each word in the transcription. In the response, in the words array, you
@@ -168,7 +169,7 @@ class AsyncTranscriptionsResource(AsyncAPIResource):
     async def create(
         self,
         *,
-        file: FileTypes,
+        file: Union[FileTypes, str],
         diarize: bool | Omit = omit,
         language: str | Omit = omit,
         max_speakers: int | Omit = omit,
@@ -189,7 +190,8 @@ class AsyncTranscriptionsResource(AsyncAPIResource):
         Transcribes audio into text
 
         Args:
-          file: Audio file
+          file: Audio file upload or public HTTP/HTTPS URL. Supported formats .wav, .mp3, .m4a,
+              .webm, .flac.
 
           diarize: Whether to enable speaker diarization. When enabled, you will get the speaker id
               for each word in the transcription. In the response, in the words array, you
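With `file` widened to `Union[FileTypes, str]`, a transcription request can now point at a hosted audio file instead of uploading one. A minimal sketch, assuming a `TOGETHER_API_KEY` in the environment; the URL is a placeholder, the model literal is taken from the translations signature below, and the `.text` field is assumed from the Whisper-style response shape (`translations.create` accepts the same union):

from together import Together

client = Together()

# A plain string is treated as a public HTTP/HTTPS URL;
# .wav, .mp3, .m4a, .webm, and .flac are supported.
transcript = client.audio.transcriptions.create(
    file="https://example.com/sample.mp3",
    model="openai/whisper-large-v3",
)
print(transcript.text)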
together/resources/audio/translations.py
CHANGED
@@ -47,7 +47,7 @@ class TranslationsResource(SyncAPIResource):
     def create(
         self,
         *,
-        file: FileTypes,
+        file: Union[FileTypes, str],
         language: str | Omit = omit,
         model: Literal["openai/whisper-large-v3"] | Omit = omit,
         prompt: str | Omit = omit,
@@ -65,7 +65,8 @@ class TranslationsResource(SyncAPIResource):
         Translates audio into English
 
         Args:
-          file: Audio file
+          file: Audio file upload or public HTTP/HTTPS URL. Supported formats .wav, .mp3, .m4a,
+              .webm, .flac.
 
           language: Target output language. Optional ISO 639-1 language code. If omitted, language
               is set to English.
@@ -145,7 +146,7 @@ class AsyncTranslationsResource(AsyncAPIResource):
     async def create(
         self,
         *,
-        file: FileTypes,
+        file: Union[FileTypes, str],
         language: str | Omit = omit,
         model: Literal["openai/whisper-large-v3"] | Omit = omit,
         prompt: str | Omit = omit,
@@ -163,7 +164,8 @@ class AsyncTranslationsResource(AsyncAPIResource):
         Translates audio into English
 
         Args:
-          file: Audio file
+          file: Audio file upload or public HTTP/HTTPS URL. Supported formats .wav, .mp3, .m4a,
+              .webm, .flac.
 
           language: Target output language. Optional ISO 639-1 language code. If omitted, language
               is set to English.
together/resources/fine_tuning.py
CHANGED
@@ -53,6 +53,7 @@ _WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
     "Proceed at your own risk."
 )
 
+
 class FineTuningResource(SyncAPIResource):
     @cached_property
     def with_raw_response(self) -> FineTuningResourceWithRawResponse:
@@ -95,6 +96,7 @@ class FineTuningResource(SyncAPIResource):
         lora_dropout: float | None = 0,
         lora_alpha: float | None = None,
         lora_trainable_modules: str | None = "all-linear",
+        train_vision: bool = False,
         suffix: str | None = None,
         wandb_api_key: str | None = None,
         wandb_base_url: str | None = None,
@@ -140,6 +142,7 @@ class FineTuningResource(SyncAPIResource):
             lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
             lora_alpha (float, optional): Alpha for LoRA adapters. Defaults to 8.
             lora_trainable_modules (str, optional): Trainable modules for LoRA adapters. Defaults to "all-linear".
+            train_vision (bool, optional): Whether to train the vision encoder (Only for multimodal models). Defaults to False.
             suffix (str, optional): Up to 40 character suffix that will be added to your fine-tuned model name.
                 Defaults to None.
             wandb_api_key (str, optional): API key for Weights & Biases integration.
@@ -214,6 +217,7 @@ class FineTuningResource(SyncAPIResource):
             lora_dropout=lora_dropout,
             lora_alpha=lora_alpha,
             lora_trainable_modules=lora_trainable_modules,
+            train_vision=train_vision,
             suffix=suffix,
             wandb_api_key=wandb_api_key,
             wandb_base_url=wandb_base_url,
@@ -232,29 +236,32 @@ class FineTuningResource(SyncAPIResource):
             hf_output_repo_name=hf_output_repo_name,
         )
 
-
-        price_estimation_result = self.estimate_price(
-            training_file=training_file,
-            from_checkpoint=from_checkpoint or Omit(),
-            validation_file=validation_file or Omit(),
-            model=model or "",
-            n_epochs=finetune_request.n_epochs,
-            n_evals=finetune_request.n_evals or 0,
-            training_type=training_type_cls,
-            training_method=training_method_cls,
-        )
-
+        if not model_limits.supports_vision:
+            price_estimation_result = self.estimate_price(
+                training_file=training_file,
+                from_checkpoint=from_checkpoint or Omit(),
+                validation_file=validation_file or Omit(),
+                model=model or "",
+                n_epochs=finetune_request.n_epochs,
+                n_evals=finetune_request.n_evals or 0,
+                training_type=training_type_cls,
+                training_method=training_method_cls,
+            )
+            price_limit_passed = price_estimation_result.allowed_to_proceed
+        else:
+            # unsupported case
+            price_limit_passed = True
 
         if verbose:
             rprint(
                 "Submitting a fine-tuning job with the following parameters:",
                 finetune_request,
             )
-        if not price_estimation_result.allowed_to_proceed:
+        if not price_limit_passed:
             rprint(
                 "[red]"
                 + _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(
-                    price_estimation_result.estimated_total_price
+                    price_estimation_result.estimated_total_price  # pyright: ignore[reportPossiblyUnboundVariable]
                 )
                 + "[/red]",
             )
@@ -627,6 +634,7 @@ class AsyncFineTuningResource(AsyncAPIResource):
         lora_dropout: float | None = 0,
         lora_alpha: float | None = None,
         lora_trainable_modules: str | None = "all-linear",
+        train_vision: bool = False,
         suffix: str | None = None,
         wandb_api_key: str | None = None,
         wandb_base_url: str | None = None,
@@ -672,6 +680,7 @@ class AsyncFineTuningResource(AsyncAPIResource):
             lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
             lora_alpha (float, optional): Alpha for LoRA adapters. Defaults to 8.
             lora_trainable_modules (str, optional): Trainable modules for LoRA adapters. Defaults to "all-linear".
+            train_vision (bool, optional): Whether to train the vision encoder (Only for multimodal models). Defaults to False.
             suffix (str, optional): Up to 40 character suffix that will be added to your fine-tuned model name.
                 Defaults to None.
             wandb_api_key (str, optional): API key for Weights & Biases integration.
@@ -746,6 +755,7 @@ class AsyncFineTuningResource(AsyncAPIResource):
             lora_dropout=lora_dropout,
             lora_alpha=lora_alpha,
             lora_trainable_modules=lora_trainable_modules,
+            train_vision=train_vision,
             suffix=suffix,
             wandb_api_key=wandb_api_key,
             wandb_base_url=wandb_base_url,
@@ -764,7 +774,6 @@ class AsyncFineTuningResource(AsyncAPIResource):
             hf_output_repo_name=hf_output_repo_name,
         )
 
-
         price_estimation_result = await self.estimate_price(
             training_file=training_file,
             from_checkpoint=from_checkpoint or Omit(),
@@ -776,7 +785,6 @@ class AsyncFineTuningResource(AsyncAPIResource):
             training_method=training_method_cls,
         )
 
-
         if verbose:
             rprint(
                 "Submitting a fine-tuning job with the following parameters:",
@@ -786,7 +794,7 @@ class AsyncFineTuningResource(AsyncAPIResource):
             rprint(
                 "[red]"
                 + _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(
-                    price_estimation_result.estimated_total_price
+                    price_estimation_result.estimated_total_price  # pyright: ignore[reportPossiblyUnboundVariable]
                 )
                 + "[/red]",
            )
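A hedged sketch of the new knob: `train_vision=True` asks the job to also tune the vision encoder of a multimodal model, and per the gated branch above, price estimation is skipped (the insufficient-funds check passes by default) when `model_limits.supports_vision` is set. The file ID and model name below are placeholders, and the returned job is assumed to expose `.id`:

from together import Together

client = Together()

# Both IDs below are placeholders.
job = client.fine_tuning.create(
    training_file="file-xxxxxxxxxxxx",
    model="example-org/example-vision-model",
    train_vision=True,  # defaults to False; only meaningful for multimodal models
)
print(job.id)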
together/types/audio/transcription_create_params.py
CHANGED
@@ -11,8 +11,11 @@ __all__ = ["TranscriptionCreateParams"]
 
 
 class TranscriptionCreateParams(TypedDict, total=False):
-    file: Required[FileTypes]
-    """Audio file"""
+    file: Required[Union[FileTypes, str]]
+    """Audio file upload or public HTTP/HTTPS URL.
+
+    Supported formats .wav, .mp3, .m4a, .webm, .flac.
+    """
 
     diarize: bool
     """Whether to enable speaker diarization.
together/types/audio/translation_create_params.py
CHANGED
@@ -11,8 +11,11 @@ __all__ = ["TranslationCreateParams"]
 
 
 class TranslationCreateParams(TypedDict, total=False):
-    file: Required[FileTypes]
-    """Audio file"""
+    file: Required[Union[FileTypes, str]]
+    """Audio file upload or public HTTP/HTTPS URL.
+
+    Supported formats .wav, .mp3, .m4a, .webm, .flac.
+    """
 
     language: str
     """Target output language.
together/types/fine_tuning_cancel_response.py
CHANGED
@@ -15,6 +15,7 @@ __all__ = [
     "LrSchedulerLrSchedulerArgs",
     "LrSchedulerLrSchedulerArgsLinearLrSchedulerArgs",
     "LrSchedulerLrSchedulerArgsCosineLrSchedulerArgs",
+    "Progress",
     "TrainingMethod",
     "TrainingMethodTrainingMethodSft",
     "TrainingMethodTrainingMethodDpo",
@@ -50,6 +51,16 @@ class LrScheduler(BaseModel):
     lr_scheduler_args: Optional[LrSchedulerLrSchedulerArgs] = None
 
 
+class Progress(BaseModel):
+    """Progress information for the fine-tuning job"""
+
+    estimate_available: bool
+    """Whether time estimate is available"""
+
+    seconds_remaining: int
+    """Estimated time remaining in seconds for the fine-tuning job to next state"""
+
+
 class TrainingMethodTrainingMethodSft(BaseModel):
     method: Literal["sft"]
 
@@ -163,6 +174,9 @@ class FineTuningCancelResponse(BaseModel):
     owner_address: Optional[str] = None
     """Owner address information"""
 
+    progress: Optional[Progress] = None
+    """Progress information for the fine-tuning job"""
+
     suffix: Optional[str] = None
     """Suffix added to the fine-tuned model name"""
 
together/types/fine_tuning_list_response.py
CHANGED
@@ -16,6 +16,7 @@ __all__ = [
     "DataLrSchedulerLrSchedulerArgs",
     "DataLrSchedulerLrSchedulerArgsLinearLrSchedulerArgs",
     "DataLrSchedulerLrSchedulerArgsCosineLrSchedulerArgs",
+    "DataProgress",
     "DataTrainingMethod",
     "DataTrainingMethodTrainingMethodSft",
     "DataTrainingMethodTrainingMethodDpo",
@@ -51,6 +52,16 @@ class DataLrScheduler(BaseModel):
     lr_scheduler_args: Optional[DataLrSchedulerLrSchedulerArgs] = None
 
 
+class DataProgress(BaseModel):
+    """Progress information for the fine-tuning job"""
+
+    estimate_available: bool
+    """Whether time estimate is available"""
+
+    seconds_remaining: int
+    """Estimated time remaining in seconds for the fine-tuning job to next state"""
+
+
 class DataTrainingMethodTrainingMethodSft(BaseModel):
     method: Literal["sft"]
 
@@ -164,6 +175,9 @@ class Data(BaseModel):
     owner_address: Optional[str] = None
     """Owner address information"""
 
+    progress: Optional[DataProgress] = None
+    """Progress information for the fine-tuning job"""
+
     suffix: Optional[str] = None
     """Suffix added to the fine-tuned model name"""
 
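The `Progress` model added to both response types above (and, per the summary, to `together/types/finetune_response.py` as well) surfaces a coarse time estimate on job objects. A sketch of reading it from a job listing, assuming the list response exposes its jobs under `.data` and each entry carries `.id`:

from together import Together

client = Together()

for job in client.fine_tuning.list().data:
    progress = job.progress  # Optional; None when the API sends no estimate
    if progress is not None and progress.estimate_available:
        print(f"{job.id}: ~{progress.seconds_remaining}s to the next state")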