together 1.5.32__tar.gz → 1.5.34__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {together-1.5.32 → together-1.5.34}/PKG-INFO +1 -1
- {together-1.5.32 → together-1.5.34}/pyproject.toml +1 -1
- {together-1.5.32 → together-1.5.34}/src/together/cli/api/finetune.py +27 -11
- together-1.5.34/src/together/cli/api/utils.py +139 -0
- {together-1.5.32 → together-1.5.34}/src/together/constants.py +6 -0
- {together-1.5.32 → together-1.5.34}/src/together/resources/finetune.py +26 -4
- {together-1.5.32 → together-1.5.34}/src/together/types/__init__.py +29 -31
- {together-1.5.32 → together-1.5.34}/src/together/types/finetune.py +36 -6
- {together-1.5.32 → together-1.5.34}/src/together/utils/files.py +202 -35
- together-1.5.32/src/together/cli/api/utils.py +0 -51
- {together-1.5.32 → together-1.5.34}/LICENSE +0 -0
- {together-1.5.32 → together-1.5.34}/README.md +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/__init__.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/abstract/__init__.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/abstract/api_requestor.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/cli/__init__.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/cli/api/__init__.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/cli/api/chat.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/cli/api/completions.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/cli/api/endpoints.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/cli/api/evaluation.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/cli/api/files.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/cli/api/images.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/cli/api/models.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/cli/cli.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/client.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/error.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/filemanager.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/legacy/__init__.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/legacy/base.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/legacy/complete.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/legacy/embeddings.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/legacy/files.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/legacy/finetune.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/legacy/images.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/legacy/models.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/resources/__init__.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/resources/audio/__init__.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/resources/audio/speech.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/resources/audio/transcriptions.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/resources/audio/translations.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/resources/audio/voices.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/resources/batch.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/resources/chat/__init__.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/resources/chat/completions.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/resources/code_interpreter.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/resources/completions.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/resources/embeddings.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/resources/endpoints.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/resources/evaluation.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/resources/files.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/resources/images.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/resources/models.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/resources/rerank.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/resources/videos.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/together_response.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/types/abstract.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/types/audio_speech.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/types/batch.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/types/chat_completions.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/types/code_interpreter.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/types/common.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/types/completions.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/types/embeddings.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/types/endpoints.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/types/error.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/types/evaluation.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/types/files.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/types/images.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/types/models.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/types/rerank.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/types/videos.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/utils/__init__.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/utils/_log.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/utils/api_helpers.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/utils/tools.py +0 -0
- {together-1.5.32 → together-1.5.34}/src/together/version.py +0 -0

{together-1.5.32 → together-1.5.34}/pyproject.toml

@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
 
 [tool.poetry]
 name = "together"
-version = "1.5.32"
+version = "1.5.34"
 authors = ["Together AI <support@together.ai>"]
 description = "Python client for Together's Cloud Platform! Note: SDK 2.0 is now available at https://github.com/togethercomputer/together-py"
 readme = "README.md"

{together-1.5.32 → together-1.5.34}/src/together/cli/api/finetune.py

@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import json
-import re
 from datetime import datetime, timezone
 from textwrap import wrap
 from typing import Any, Literal
@@ -9,22 +8,16 @@ from typing import Any, Literal
 import click
 from click.core import ParameterSource  # type: ignore[attr-defined]
 from rich import print as rprint
+from rich.json import JSON
 from tabulate import tabulate
 
 from together import Together
-from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX
-from together.types.finetune import (
-    DownloadCheckpointType,
-    FinetuneEventType,
-    FinetuneTrainingLimits,
-    FullTrainingType,
-    LoRATrainingType,
-)
+from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX, generate_progress_bar
+from together.types.finetune import DownloadCheckpointType, FinetuneTrainingLimits
 from together.utils import (
     finetune_price_to_dollars,
     format_timestamp,
     log_warn,
-    log_warn_once,
     parse_timestamp,
 )
 
@@ -202,6 +195,12 @@ def fine_tuning(ctx: click.Context) -> None:
     help="Whether to mask the user messages in conversational data or prompts in instruction data. "
    "`auto` will automatically determine whether to mask the inputs based on the data format.",
 )
+@click.option(
+    "--train-vision",
+    type=bool,
+    default=False,
+    help="Whether to train the vision encoder. Only supported for multimodal models.",
+)
 @click.option(
     "--from-checkpoint",
     type=str,
@@ -257,6 +256,7 @@ def create(
     lora_dropout: float,
     lora_alpha: float,
     lora_trainable_modules: str,
+    train_vision: bool,
     suffix: str,
     wandb_api_key: str,
     wandb_base_url: str,
@@ -298,6 +298,7 @@ def create(
         lora_dropout=lora_dropout,
         lora_alpha=lora_alpha,
         lora_trainable_modules=lora_trainable_modules,
+        train_vision=train_vision,
         suffix=suffix,
         wandb_api_key=wandb_api_key,
         wandb_base_url=wandb_base_url,
@@ -367,6 +368,10 @@ def create(
             "You have specified a number of evaluation loops but no validation file."
         )
 
+    if model_limits.supports_vision:
+        # Don't show price estimation for multimodal models yet
+        confirm = True
+
     finetune_price_estimation_result = client.fine_tuning.estimate_price(
         training_file=training_file,
         validation_file=validation_file,
@@ -435,6 +440,9 @@ def list(ctx: click.Context) -> None:
                 "Price": f"""${
                     finetune_price_to_dollars(float(str(i.total_price)))
                 }""",  # convert to string for mypy typing
+                "Progress": generate_progress_bar(
+                    i, datetime.now().astimezone(), use_rich=False
+                ),
             }
         )
     table = tabulate(display_list, headers="keys", tablefmt="grid", showindex=True)
@@ -454,7 +462,15 @@ def retrieve(ctx: click.Context, fine_tune_id: str) -> None:
     # remove events from response for cleaner output
     response.events = None
 
-
+    rprint(JSON.from_data(response.model_dump(exclude_none=True)))
+    progress_text = generate_progress_bar(
+        response, datetime.now().astimezone(), use_rich=True
+    )
+    status = "Unknown"
+    if response.status is not None:
+        status = response.status.value
+    prefix = f"Status: [bold]{status}[/bold],"
+    rprint(f"{prefix} {progress_text}")
 
 
 @fine_tuning.command()
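
For context, the new progress line can be reproduced outside the CLI. A minimal sketch through the SDK, assuming a job ID placeholder (kept elided here, not a value from this diff):

    from datetime import datetime

    from together import Together
    from together.cli.api.utils import generate_progress_bar

    client = Together()
    response = client.fine_tuning.retrieve("ft-...")  # placeholder job ID

    # Mirrors the new retrieve output; use_rich=False strips the rich markup tags
    status = response.status.value if response.status is not None else "Unknown"
    progress_text = generate_progress_bar(response, datetime.now().astimezone(), use_rich=False)
    print(f"Status: {status}, {progress_text}")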

together-1.5.34/src/together/cli/api/utils.py (new file)

@@ -0,0 +1,139 @@
+from __future__ import annotations
+
+import math
+import re
+from gettext import gettext as _
+from typing import Literal
+from datetime import datetime
+
+import click
+
+from together.types.finetune import FinetuneResponse, COMPLETED_STATUSES
+
+_PROGRESS_BAR_WIDTH = 40
+
+
+class AutoIntParamType(click.ParamType):
+    name = "integer_or_max"
+    _number_class = int
+
+    def convert(
+        self, value: str, param: click.Parameter | None, ctx: click.Context | None
+    ) -> int | Literal["max"] | None:
+        if value == "max":
+            return "max"
+        try:
+            return int(value)
+        except ValueError:
+            self.fail(
+                _("{value!r} is not a valid {number_type}.").format(
+                    value=value, number_type=self.name
+                ),
+                param,
+                ctx,
+            )
+
+
+class BooleanWithAutoParamType(click.ParamType):
+    name = "boolean_or_auto"
+
+    def convert(
+        self, value: str, param: click.Parameter | None, ctx: click.Context | None
+    ) -> bool | Literal["auto"] | None:
+        if value == "auto":
+            return "auto"
+        try:
+            return bool(value)
+        except ValueError:
+            self.fail(
+                _("{value!r} is not a valid {type}.").format(
+                    value=value, type=self.name
+                ),
+                param,
+                ctx,
+            )
+
+
+INT_WITH_MAX = AutoIntParamType()
+BOOL_WITH_AUTO = BooleanWithAutoParamType()
+
+
+def _human_readable_time(timedelta: float) -> str:
+    """Convert a timedelta to a compact human-readable string
+    Examples:
+        00:00:10 -> 10s
+        01:23:45 -> 1h 23min 45s
+        1 Month 23 days 04:56:07 -> 1month 23d 4h 56min 7s
+    Args:
+        timedelta (float): The timedelta in seconds to convert.
+    Returns:
+        A string representing the timedelta in a human-readable format.
+    """
+    units = [
+        (30 * 24 * 60 * 60, "month"),  # 30 days
+        (24 * 60 * 60, "d"),
+        (60 * 60, "h"),
+        (60, "min"),
+        (1, "s"),
+    ]
+
+    total_seconds = int(timedelta)
+    parts = []
+
+    for unit_seconds, unit_name in units:
+        if total_seconds >= unit_seconds:
+            value = total_seconds // unit_seconds
+            total_seconds %= unit_seconds
+            parts.append(f"{value}{unit_name}")
+
+    return " ".join(parts) if parts else "0s"
+
+
+def generate_progress_bar(
+    finetune_job: FinetuneResponse, current_time: datetime, use_rich: bool = False
+) -> str:
+    """Generate a progress bar for a finetune job.
+    Args:
+        finetune_job: The finetune job to generate a progress bar for.
+        current_time: The current time.
+        use_rich: Whether to use rich formatting.
+    Returns:
+        A string representing the progress bar.
+    """
+    progress = "Progress: [bold red]unavailable[/bold red]"
+    if finetune_job.status in COMPLETED_STATUSES:
+        progress = "Progress: [bold green]completed[/bold green]"
+    elif finetune_job.updated_at is not None:
+        # Replace 'Z' with '+00:00' for Python 3.10 compatibility
+        updated_at_str = finetune_job.updated_at.replace("Z", "+00:00")
+        update_at = datetime.fromisoformat(updated_at_str).astimezone()
+
+        if finetune_job.progress is not None:
+            if current_time < update_at:
+                return progress
+
+            if not finetune_job.progress.estimate_available:
+                return progress
+
+            if finetune_job.progress.seconds_remaining <= 0:
+                return progress
+
+            elapsed_time = (current_time - update_at).total_seconds()
+            ratio_filled = min(
+                elapsed_time / finetune_job.progress.seconds_remaining, 1.0
+            )
+            percentage = ratio_filled * 100
+            filled = math.ceil(ratio_filled * _PROGRESS_BAR_WIDTH)
+            bar = "█" * filled + "░" * (_PROGRESS_BAR_WIDTH - filled)
+            time_left = "N/A"
+            if finetune_job.progress.seconds_remaining > elapsed_time:
+                time_left = _human_readable_time(
+                    finetune_job.progress.seconds_remaining - elapsed_time
+                )
+            time_text = f"{time_left} left"
+            progress = f"Progress: {bar} [bold]{percentage:>3.0f}%[/bold] [yellow]{time_text}[/yellow]"
+
+    if use_rich:
+        return progress
+
+    return re.sub(r"\[/?[^\]]+\]", "", progress)
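
A quick sanity check of `_human_readable_time`, using the unit breakdown from its own docstring (the inputs below are illustrative values, not taken from the diff):

    from together.cli.api.utils import _human_readable_time

    assert _human_readable_time(10) == "10s"              # 00:00:10
    assert _human_readable_time(5025) == "1h 23min 45s"   # 01:23:45
    assert _human_readable_time(0) == "0s"                # no parts -> "0s" fallback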

{together-1.5.32 → together-1.5.34}/src/together/constants.py

@@ -1,5 +1,6 @@
 import enum
 
+
 # Session constants
 TIMEOUT_SECS = 600
 MAX_SESSION_LIFETIME_SECS = 180
@@ -40,6 +41,11 @@ MIN_SAMPLES = 1
 # the number of bytes in a gigabyte, used to convert bytes to GB for readable comparison
 NUM_BYTES_IN_GB = 2**30
 
+# Multimodal limits
+MAX_IMAGES_PER_EXAMPLE = 10
+MAX_IMAGE_BYTES = 10 * 1024 * 1024  # 10MB
+# Max length = Header length + base64 factor (4/3) * image bytes
+MAX_BASE64_IMAGE_LENGTH = len("data:image/jpeg;base64,") + 4 * MAX_IMAGE_BYTES // 3
 
 # expected columns for Parquet files
 PARQUET_EXPECTED_COLUMNS = ["input_ids", "attention_mask", "labels"]
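
The new image-size constant encodes the base64 overhead explicitly; the arithmetic works out as follows:

    header = len("data:image/jpeg;base64,")   # 23 characters of data-URL prefix
    max_image_bytes = 10 * 1024 * 1024        # 10_485_760 raw image bytes
    # base64 inflates payloads by a factor of 4/3, so the allowed URL length is:
    print(header + 4 * max_image_bytes // 3)  # 13981036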

{together-1.5.32 → together-1.5.34}/src/together/resources/finetune.py

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import re
 from pathlib import Path
-from typing import
+from typing import Dict, List, Literal
 
 from rich import print as rprint
 
@@ -18,10 +18,11 @@ from together.types import (
     FinetuneList,
     FinetuneListEvents,
     FinetuneLRScheduler,
-    FinetuneRequest,
-    FinetuneResponse,
+    FinetuneMultimodalParams,
     FinetunePriceEstimationRequest,
     FinetunePriceEstimationResponse,
+    FinetuneRequest,
+    FinetuneResponse,
     FinetuneTrainingLimits,
     FullTrainingType,
     LinearLRScheduler,
@@ -73,6 +74,7 @@ def create_finetune_request(
     lora_dropout: float | None = 0,
     lora_alpha: float | None = None,
     lora_trainable_modules: str | None = "all-linear",
+    train_vision: bool = False,
     suffix: str | None = None,
     wandb_api_key: str | None = None,
     wandb_base_url: str | None = None,
@@ -252,6 +254,15 @@ def create_finetune_request(
         simpo_gamma=simpo_gamma,
     )
 
+    if model_limits.supports_vision:
+        multimodal_params = FinetuneMultimodalParams(train_vision=train_vision)
+    elif not model_limits.supports_vision and train_vision:
+        raise ValueError(
+            f"Vision encoder training is not supported for the non-multimodal model `{model}`"
+        )
+    else:
+        multimodal_params = None
+
     finetune_request = FinetuneRequest(
         model=model,
         training_file=training_file,
@@ -272,6 +283,7 @@ def create_finetune_request(
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
         training_method=training_method_cls,
+        multimodal_params=multimodal_params,
         from_checkpoint=from_checkpoint,
         from_hf_model=from_hf_model,
         hf_model_revision=hf_model_revision,
@@ -342,6 +354,7 @@ class FineTuning:
         lora_dropout: float | None = 0,
         lora_alpha: float | None = None,
         lora_trainable_modules: str | None = "all-linear",
+        train_vision: bool = False,
         suffix: str | None = None,
         wandb_api_key: str | None = None,
         wandb_base_url: str | None = None,
@@ -387,6 +400,7 @@ class FineTuning:
            lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
            lora_alpha (float, optional): Alpha for LoRA adapters. Defaults to 8.
            lora_trainable_modules (str, optional): Trainable modules for LoRA adapters. Defaults to "all-linear".
+           train_vision (bool, optional): Whether to train vision encoder in multimodal models. Defaults to False.
            suffix (str, optional): Up to 40 character suffix that will be added to your fine-tuned model name.
                Defaults to None.
            wandb_api_key (str, optional): API key for Weights & Biases integration.
@@ -464,6 +478,7 @@ class FineTuning:
            lora_dropout=lora_dropout,
            lora_alpha=lora_alpha,
            lora_trainable_modules=lora_trainable_modules,
+           train_vision=train_vision,
            suffix=suffix,
            wandb_api_key=wandb_api_key,
            wandb_base_url=wandb_base_url,
@@ -906,6 +921,7 @@ class AsyncFineTuning:
        lora_dropout: float | None = 0,
        lora_alpha: float | None = None,
        lora_trainable_modules: str | None = "all-linear",
+       train_vision: bool = False,
        suffix: str | None = None,
        wandb_api_key: str | None = None,
        wandb_base_url: str | None = None,
@@ -951,6 +967,7 @@ class AsyncFineTuning:
            lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
            lora_alpha (float, optional): Alpha for LoRA adapters. Defaults to 8.
            lora_trainable_modules (str, optional): Trainable modules for LoRA adapters. Defaults to "all-linear".
+           train_vision (bool, optional): Whether to train vision encoder in multimodal models. Defaults to False.
            suffix (str, optional): Up to 40 character suffix that will be added to your fine-tuned model name.
                Defaults to None.
            wandb_api_key (str, optional): API key for Weights & Biases integration.
@@ -1028,6 +1045,7 @@ class AsyncFineTuning:
            lora_dropout=lora_dropout,
            lora_alpha=lora_alpha,
            lora_trainable_modules=lora_trainable_modules,
+           train_vision=train_vision,
            suffix=suffix,
            wandb_api_key=wandb_api_key,
            wandb_base_url=wandb_base_url,
@@ -1046,7 +1064,11 @@ class AsyncFineTuning:
            hf_output_repo_name=hf_output_repo_name,
        )
 
-        if from_checkpoint is None and from_hf_model is None:
+        if (
+            from_checkpoint is None
+            and from_hf_model is None
+            and not model_limits.supports_vision
+        ):
            price_estimation_result = await self.estimate_price(
                training_file=training_file,
                validation_file=validation_file,
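
Putting the new parameter together end to end, a hypothetical SDK call might look like this (the model name and file ID are placeholders, not values from this diff):

    from together import Together

    client = Together()
    job = client.fine_tuning.create(
        model="example-org/multimodal-model",  # hypothetical multimodal model name
        training_file="file-xxxxxxxx",         # hypothetical uploaded file ID
        train_vision=True,  # new in 1.5.34; raises ValueError for non-vision models
    )
    print(job.id)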

{together-1.5.32 → together-1.5.34}/src/together/types/__init__.py

@@ -7,17 +7,18 @@ from together.types.audio_speech import (
     AudioSpeechStreamChunk,
     AudioSpeechStreamEvent,
     AudioSpeechStreamResponse,
+    AudioTimestampGranularities,
     AudioTranscriptionRequest,
-    AudioTranslationRequest,
     AudioTranscriptionResponse,
+    AudioTranscriptionResponseFormat,
     AudioTranscriptionVerboseResponse,
+    AudioTranslationRequest,
     AudioTranslationResponse,
     AudioTranslationVerboseResponse,
-    AudioTranscriptionResponseFormat,
-    AudioTimestampGranularities,
     ModelVoices,
     VoiceListResponse,
 )
+from together.types.batch import BatchEndpoint, BatchJob, BatchJobStatus
 from together.types.chat_completions import (
     ChatCompletionChunk,
     ChatCompletionRequest,
@@ -31,6 +32,19 @@ from together.types.completions import (
 )
 from together.types.embeddings import EmbeddingRequest, EmbeddingResponse
 from together.types.endpoints import Autoscaling, DedicatedEndpoint, ListEndpoint
+from together.types.evaluation import (
+    ClassifyParameters,
+    CompareParameters,
+    EvaluationCreateResponse,
+    EvaluationJob,
+    EvaluationRequest,
+    EvaluationStatus,
+    EvaluationStatusResponse,
+    EvaluationType,
+    JudgeModelConfig,
+    ModelRequest,
+    ScoreParameters,
+)
 from together.types.files import (
     FileDeleteResponse,
     FileList,
@@ -41,49 +55,32 @@ from together.types.files import (
     FileType,
 )
 from together.types.finetune import (
-    TrainingMethodDPO,
-    TrainingMethodSFT,
-    FinetuneCheckpoint,
     CosineLRScheduler,
     CosineLRSchedulerArgs,
+    FinetuneCheckpoint,
+    FinetuneDeleteResponse,
     FinetuneDownloadResult,
-    LinearLRScheduler,
-    LinearLRSchedulerArgs,
-    FinetuneLRScheduler,
     FinetuneList,
     FinetuneListEvents,
-
-
+    FinetuneLRScheduler,
+    FinetuneMultimodalParams,
     FinetunePriceEstimationRequest,
     FinetunePriceEstimationResponse,
-
+    FinetuneRequest,
+    FinetuneResponse,
     FinetuneTrainingLimits,
     FullTrainingType,
+    LinearLRScheduler,
+    LinearLRSchedulerArgs,
     LoRATrainingType,
+    TrainingMethodDPO,
+    TrainingMethodSFT,
     TrainingType,
 )
 from together.types.images import ImageRequest, ImageResponse
 from together.types.models import ModelObject, ModelUploadRequest, ModelUploadResponse
 from together.types.rerank import RerankRequest, RerankResponse
-from together.types.batch import BatchEndpoint, BatchJob, BatchJobStatus
-from together.types.evaluation import (
-    EvaluationType,
-    EvaluationStatus,
-    JudgeModelConfig,
-    ModelRequest,
-    ClassifyParameters,
-    ScoreParameters,
-    CompareParameters,
-    EvaluationRequest,
-    EvaluationCreateResponse,
-    EvaluationJob,
-    EvaluationStatusResponse,
-)
-from together.types.videos import (
-    CreateVideoBody,
-    CreateVideoResponse,
-    VideoJob,
-)
+from together.types.videos import CreateVideoBody, CreateVideoResponse, VideoJob
 
 
 __all__ = [
@@ -131,6 +128,7 @@ __all__ = [
     "RerankRequest",
     "RerankResponse",
     "FinetuneTrainingLimits",
+    "FinetuneMultimodalParams",
     "AudioSpeechRequest",
     "AudioResponseFormat",
     "AudioLanguage",

{together-1.5.32 → together-1.5.34}/src/together/types/finetune.py

@@ -1,14 +1,12 @@
 from __future__ import annotations
 
 from enum import Enum
-from typing import List, Literal
+from typing import Any, List, Literal
 
 from pydantic import Field, StrictBool, field_validator
 
 from together.types.abstract import BaseModel
-from together.types.common import (
-    ObjectType,
-)
+from together.types.common import ObjectType
 
 
 class FinetuneJobStatus(str, Enum):
@@ -28,6 +26,14 @@ class FinetuneJobStatus(str, Enum):
     STATUS_COMPLETED = "completed"
 
 
+COMPLETED_STATUSES = [
+    FinetuneJobStatus.STATUS_ERROR,
+    FinetuneJobStatus.STATUS_USER_ERROR,
+    FinetuneJobStatus.STATUS_COMPLETED,
+    FinetuneJobStatus.STATUS_CANCELLED,
+]
+
+
 class FinetuneEventLevels(str, Enum):
     """
     Fine-tune job event status levels
@@ -167,6 +173,23 @@ class TrainingMethodDPO(TrainingMethod):
     simpo_gamma: float | None = None
 
 
+class FinetuneMultimodalParams(BaseModel):
+    """
+    Multimodal parameters
+    """
+
+    train_vision: bool = False
+
+
+class FinetuneProgress(BaseModel):
+    """
+    Fine-tune job progress
+    """
+
+    estimate_available: bool = False
+    seconds_remaining: float = 0
+
+
 class FinetuneRequest(BaseModel):
     """
     Fine-tune request type
@@ -214,6 +237,8 @@ class FinetuneRequest(BaseModel):
     )
     # from step
     from_checkpoint: str | None = None
+    # multimodal parameters
+    multimodal_params: FinetuneMultimodalParams | None = None
     # hf related fields
     hf_api_token: str | None = None
     hf_output_repo_name: str | None = None
@@ -296,6 +321,10 @@ class FinetuneResponse(BaseModel):
     training_file_size: int | None = Field(None, alias="TrainingFileSize")
     train_on_inputs: StrictBool | Literal["auto"] | None = "auto"
     from_checkpoint: str | None = None
+    # multimodal parameters
+    multimodal_params: FinetuneMultimodalParams | None = None
+
+    progress: FinetuneProgress | None = None
 
     @field_validator("training_type")
     @classmethod
@@ -318,8 +347,8 @@ class FinetunePriceEstimationRequest(BaseModel):
     model: str
     n_epochs: int
     n_evals: int
-    training_type:
-    training_method:
+    training_type: LoRATrainingType | FullTrainingType
+    training_method: TrainingMethodSFT | TrainingMethodDPO
 
 
 class FinetunePriceEstimationResponse(BaseModel):
@@ -390,6 +419,7 @@ class FinetuneTrainingLimits(BaseModel):
     min_learning_rate: float
     full_training: FinetuneFullTrainingLimits | None = None
     lora_training: FinetuneLoraTrainingLimits | None = None
+    supports_vision: bool = False
 
 
 class LinearLRSchedulerArgs(BaseModel):
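
The two new pydantic models are small enough to exercise in isolation; a sketch with illustrative values:

    from together.types.finetune import FinetuneMultimodalParams, FinetuneProgress

    params = FinetuneMultimodalParams(train_vision=True)
    progress = FinetuneProgress(estimate_available=True, seconds_remaining=120.0)
    print(params.train_vision, progress.seconds_remaining)  # True 120.0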

{together-1.5.32 → together-1.5.34}/src/together/utils/files.py

@@ -1,8 +1,8 @@
 from __future__ import annotations
 
+import csv
 import json
 import os
-import csv
 from pathlib import Path
 from traceback import format_exc
 from typing import Any, Dict, List
@@ -10,18 +10,30 @@ from typing import Any, Dict, List
 from tqdm import tqdm
 
 from together.constants import (
+    JSONL_REQUIRED_COLUMNS_MAP,
+    MAX_BASE64_IMAGE_LENGTH,
     MAX_FILE_SIZE_GB,
+    MAX_IMAGES_PER_EXAMPLE,
     MIN_SAMPLES,
     NUM_BYTES_IN_GB,
     PARQUET_EXPECTED_COLUMNS,
-    JSONL_REQUIRED_COLUMNS_MAP,
-    REQUIRED_COLUMNS_MESSAGE,
     POSSIBLE_ROLES_CONVERSATION,
+    REQUIRED_COLUMNS_MESSAGE,
     DatasetFormat,
 )
 from together.types import FilePurpose
 
 
+# MessageContent is a string or a list of dicts with 'type': 'text' or 'image_url', and 'text' or 'image_url.url'
+# Example: "Hello" or [
+#     {"type": "text", "text": "Hello"},
+#     {"type": "image_url", "image_url": {
+#         "url": "data:image/jpeg;base64,..."
+#     }}
+# ]
+MessageContent = str | list[dict[str, Any]]
+
+
 class InvalidFileFormatError(ValueError):
     """Exception raised for invalid file formats during file checks."""
 
@@ -70,7 +82,7 @@ def check_file(
 
     if file_size > MAX_FILE_SIZE_GB * NUM_BYTES_IN_GB:
         report_dict["message"] = (
-            f"Maximum supported file size is {MAX_FILE_SIZE_GB} GB. Found file with size of {round(file_size / NUM_BYTES_IN_GB
+            f"Maximum supported file size is {MAX_FILE_SIZE_GB} GB. Found file with size of {round(file_size / NUM_BYTES_IN_GB, 3)} GB."
         )
         report_dict["is_check_passed"] = False
     elif file_size == 0:
@@ -103,7 +115,9 @@ def check_file(
     return report_dict
 
 
-def _check_conversation_type(messages: List[Dict[str, str | bool]], idx: int) -> None:
+def _check_conversation_type(
+    messages: List[Dict[str, str | int | MessageContent]], idx: int
+) -> None:
     """Check that the conversation has correct type.
 
     Args:
@@ -145,12 +159,6 @@ def _check_conversation_type(messages: List[Dict[str, str | bool]], idx: int) -> None:
                 line_number=idx + 1,
                 error_source="key_value",
             )
-        if not isinstance(message[column], str):
-            raise InvalidFileFormatError(
-                message=f"Column `{column}` is not a string on line {idx + 1}. Found {type(message[column])}",
-                line_number=idx + 1,
-                error_source="text_field",
-            )
 
 
 def _check_conversation_roles(
@@ -175,7 +183,9 @@ def _check_conversation_roles(
         )
 
 
-def _check_message_weight(message: Dict[str, str | bool], idx: int) -> None:
+def _check_message_weight(
+    message: Dict[str, str | int | MessageContent], idx: int
+) -> int | None:
     """Check that the message has a weight with the correct type and value.
 
     Args:
@@ -199,11 +209,14 @@ def _check_message_weight(message: Dict[str, str | bool], idx: int) -> None:
                 line_number=idx + 1,
                 error_source="key_value",
             )
+        return weight
+
+    return None
 
 
 def _check_message_role(
-    message: Dict[str, str | bool], previous_role: str | None, idx: int
-) -> str:
+    message: Dict[str, str | int | MessageContent], previous_role: str | None, idx: int
+) -> str:
     """Check that the message has correct roles.
 
     Args:
@@ -217,6 +230,14 @@ def _check_message_role(
     Raises:
        InvalidFileFormatError: If the message role is invalid.
    """
+    if not isinstance(message["role"], str):
+        raise InvalidFileFormatError(
+            message=f"Invalid role `{message['role']}` in conversation on line {idx + 1}. "
+            f"Role must be a string. Found {type(message['role'])}",
+            line_number=idx + 1,
+            error_source="key_value",
+        )
+
    if message["role"] not in POSSIBLE_ROLES_CONVERSATION:
        raise InvalidFileFormatError(
            message=f"Invalid role `{message['role']}` in conversation on line {idx + 1}. "
@@ -234,8 +255,134 @@ def _check_message_role(
     return message["role"]
 
 
+def _check_message_content(
+    message_content: str | int | MessageContent, role: str, idx: int
+) -> tuple[bool, int]:
+    """Check that the message content has the correct type.
+    Message content can be either a) a string or b) an OpenAI-style multimodal list of content items
+    Example:
+        a) "Hello", or
+        b) [
+            {"type": "text", "text": "Hello"},
+            {"type": "image_url", "image_url": {
+                "url": "data:image/jpeg;base64,..."
+            }}
+        ]
+
+    Args:
+        message: The message to check.
+        role: The role of the message.
+        idx: Line number in the file.
+
+    Returns:
+        tuple[bool, int]: A tuple with message is multimodal and the number of images in the message content.
+    """
+    # Text-only message content
+    if isinstance(message_content, str):
+        return False, 0
+
+    # Multimodal message content
+    if isinstance(message_content, list):
+        num_images = 0
+        for item in message_content:
+            if not isinstance(item, dict):
+                raise InvalidFileFormatError(
+                    "The dataset is malformed, the `content` field must be a list of dicts.",
+                    line_number=idx + 1,
+                    error_source="key_value",
+                )
+            if "type" not in item:
+                raise InvalidFileFormatError(
+                    "The dataset is malformed, the `content` field must be a list of dicts with a `type` field.",
+                    line_number=idx + 1,
+                    error_source="key_value",
+                )
+
+            if item["type"] == "text":
+                if "text" not in item or not isinstance(item["text"], str):
+                    raise InvalidFileFormatError(
+                        "The dataset is malformed, the `text` field must be present in the `content` item field and be"
+                        f" a string. Got '{item.get('text')!r}' instead.",
+                        line_number=idx + 1,
+                        error_source="key_value",
+                    )
+            elif item["type"] == "image_url":
+                if role != "user":
+                    raise InvalidFileFormatError(
+                        "The dataset is malformed, only user messages can contain images.",
+                        line_number=idx + 1,
+                        error_source="key_value",
+                    )
+
+                if "image_url" not in item or not isinstance(item["image_url"], dict):
+                    raise InvalidFileFormatError(
+                        "The dataset is malformed, the `image_url` field must be present in the `content` field and "
+                        f"be a dictionary. Got {item.get('image_url')!r} instead.",
+                        line_number=idx + 1,
+                        error_source="key_value",
+                    )
+
+                image_data = item["image_url"].get("url")
+                if not image_data or not isinstance(image_data, str):
+                    raise InvalidFileFormatError(
+                        "The dataset is malformed, the `url` field must be present in the `image_url` field and be "
+                        f"a string. Got {image_data!r} instead.",
+                        line_number=idx + 1,
+                        error_source="key_value",
+                    )
+
+                if not any(
+                    image_data.startswith(f"data:image/{fmt};base64,")
+                    for fmt in ["jpeg", "png", "webp"]
+                ):
+                    raise InvalidFileFormatError(
+                        "The dataset is malformed, the `url` field must be either a JPEG, PNG or WEBP base64-encoded "
+                        "image in 'data:image/<format>;base64,<base64_encoded_image>' format. "
+                        f"Got '{image_data[:100]}...' instead.",
+                        line_number=idx + 1,
+                    )
+
+                if len(image_data) > MAX_BASE64_IMAGE_LENGTH:
+                    raise InvalidFileFormatError(
+                        "The dataset is malformed, the `url` field must contain base64-encoded image "
+                        f"that is less than 10MB, found ~{len(image_data) * 3 // 4} bytes.",
+                        line_number=idx + 1,
+                        error_source="key_value",
+                    )
+
+                num_images += 1
+            else:
+                raise InvalidFileFormatError(
+                    "The dataset is malformed, the `type` field must be either 'text' or 'image_url'. "
+                    f"Got {item['type']!r}.",
+                    line_number=idx + 1,
+                    error_source="key_value",
+                )
+
+        if num_images > MAX_IMAGES_PER_EXAMPLE:
+            raise InvalidFileFormatError(
+                f"The dataset is malformed, the `content` field must contain at most "
+                f"{MAX_IMAGES_PER_EXAMPLE} images, found {num_images}.",
+                line_number=idx + 1,
+                error_source="key_value",
+            )
+
+        # We still consider text-only messages in such format as multimodal, even if they don't have any images
+        # included - so we can process datasets with rather sparse images (i.e. not in each sample) consistently.
+        return True, num_images
+
+    raise InvalidFileFormatError(
+        f"Invalid content type on line {idx + 1} of the input file. Expected string or multimodal list of dicts, "
+        f"found {type(message_content)}",
+        line_number=idx + 1,
+        error_source="key_value",
+    )
+
+
 def validate_messages(
-    messages: List[Dict[str, str | bool]], idx: int, require_assistant_role: bool = True
+    messages: List[Dict[str, str | int | MessageContent]],
+    idx: int,
+    require_assistant_role: bool = True,
 ) -> None:
     """Validate the messages column.
 
@@ -249,15 +396,45 @@ def validate_messages(
     """
     _check_conversation_type(messages, idx)
 
-    has_weights = any("weight" in message for message in messages)
     previous_role = None
     assistant_role_exists = False
 
+    messages_are_multimodal: bool | None = None
+    total_number_of_images = 0
+
     for message in messages:
-
-        _check_message_weight(message, idx)
+        message_weight = _check_message_weight(message, idx)
         previous_role = _check_message_role(message, previous_role, idx)
         assistant_role_exists |= previous_role == "assistant"
+        is_multimodal, number_of_images = _check_message_content(
+            message["content"], role=previous_role, idx=idx
+        )
+        # Multimodal validation
+        if number_of_images > 0 and message_weight is not None and message_weight != 0:
+            raise InvalidFileFormatError(
+                "Messages with images cannot have non-zero weights.",
+                line_number=idx + 1,
+                error_source="key_value",
+            )
+        if messages_are_multimodal is None:
+            # Detect the format of the messages in the conversation.
+            messages_are_multimodal = is_multimodal
+        elif messages_are_multimodal != is_multimodal:
+            # Due to the format limitation, we cannot mix multimodal and text only messages in the same sample.
+            raise InvalidFileFormatError(
+                "Messages in the conversation must be either all in multimodal or all in text-only format.",
+                line_number=idx + 1,
+                error_source="key_value",
+            )
+        total_number_of_images += number_of_images
+
+    if total_number_of_images > MAX_IMAGES_PER_EXAMPLE:
+        raise InvalidFileFormatError(
+            f"The dataset is malformed, the `messages` must contain at most {MAX_IMAGES_PER_EXAMPLE} images. "
+            f"Found {total_number_of_images} images.",
+            line_number=idx + 1,
+            error_source="key_value",
+        )
 
     _check_conversation_roles(require_assistant_role, assistant_role_exists, idx)
 
@@ -347,12 +524,7 @@ def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> None:
            error_source="key_value",
        )
 
-
-        raise InvalidFileFormatError(
-            message=f"The dataset is malformed, the 'content' field in `{key}` must be a string on line {idx + 1}.",
-            line_number=idx + 1,
-            error_source="key_value",
-        )
+    _check_message_content(example[key][0]["content"], role="assistant", idx=idx)
 
 
 def _check_utf8(file: Path) -> Dict[str, Any]:
@@ -454,8 +626,7 @@ def _check_csv(file: Path, purpose: FilePurpose | str) -> Dict[str, Any]:
        report_dict["load_csv"] = False
        if idx < 0:
            report_dict["message"] = (
-                "Unable to decode file. "
-                "File may be empty or in an unsupported format. "
+                "Unable to decode file. File may be empty or in an unsupported format. "
            )
        else:
            report_dict["message"] = (
@@ -542,13 +713,10 @@ def _check_jsonl(file: Path, purpose: FilePurpose | str) -> Dict[str, Any]:
                    )
                else:
                    for column in JSONL_REQUIRED_COLUMNS_MAP[current_format]:
-                        if
-
-
-
-                            line_number=idx + 1,
-                            error_source="key_value",
-                        )
+                        role = "assistant" if column in {"completion"} else "user"
+                        _check_message_content(
+                            json_line[column], role=role, idx=idx
+                        )
 
                if dataset_format is None:
                    dataset_format = current_format
@@ -578,8 +746,7 @@ def _check_jsonl(file: Path, purpose: FilePurpose | str) -> Dict[str, Any]:
            report_dict["load_json"] = False
            if idx < 0:
                report_dict["message"] = (
-                    "Unable to decode file. "
-                    "File may be empty or in an unsupported format. "
+                    "Unable to decode file. File may be empty or in an unsupported format. "
                )
            else:
                report_dict["message"] = (
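
For reference, a conversation sample in the multimodal format the validator above accepts (the base64 payload is elided; a real one must use a jpeg, png, or webp data-URL prefix and stay under MAX_BASE64_IMAGE_LENGTH). Only user messages may carry images, image-bearing messages cannot have non-zero weights, and every message in a sample must use the same list-style content format:

    example = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is in this picture?"},
                    {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": "A cat."}]},
        ]
    }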

together-1.5.32/src/together/cli/api/utils.py (deleted; its contents moved into the new together-1.5.34 file above)

@@ -1,51 +0,0 @@
-from __future__ import annotations
-
-from gettext import gettext as _
-from typing import Literal
-
-import click
-
-
-class AutoIntParamType(click.ParamType):
-    name = "integer_or_max"
-    _number_class = int
-
-    def convert(
-        self, value: str, param: click.Parameter | None, ctx: click.Context | None
-    ) -> int | Literal["max"] | None:
-        if value == "max":
-            return "max"
-        try:
-            return int(value)
-        except ValueError:
-            self.fail(
-                _("{value!r} is not a valid {number_type}.").format(
-                    value=value, number_type=self.name
-                ),
-                param,
-                ctx,
-            )
-
-
-class BooleanWithAutoParamType(click.ParamType):
-    name = "boolean_or_auto"
-
-    def convert(
-        self, value: str, param: click.Parameter | None, ctx: click.Context | None
-    ) -> bool | Literal["auto"] | None:
-        if value == "auto":
-            return "auto"
-        try:
-            return bool(value)
-        except ValueError:
-            self.fail(
-                _("{value!r} is not a valid {type}.").format(
-                    value=value, type=self.name
-                ),
-                param,
-                ctx,
-            )
-
-
-INT_WITH_MAX = AutoIntParamType()
-BOOL_WITH_AUTO = BooleanWithAutoParamType()

All other files listed above are unchanged between together-1.5.32 and together-1.5.34.