together 1.3.3.tar.gz → 1.3.5.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {together-1.3.3 → together-1.3.5}/PKG-INFO +2 -2
- {together-1.3.3 → together-1.3.5}/pyproject.toml +3 -3
- {together-1.3.3 → together-1.3.5}/src/together/cli/api/finetune.py +45 -18
- {together-1.3.3 → together-1.3.5}/src/together/cli/api/utils.py +21 -0
- {together-1.3.3 → together-1.3.5}/src/together/constants.py +19 -0
- {together-1.3.3 → together-1.3.5}/src/together/resources/finetune.py +64 -4
- {together-1.3.3 → together-1.3.5}/src/together/types/__init__.py +4 -0
- {together-1.3.3 → together-1.3.5}/src/together/types/finetune.py +24 -1
- together-1.3.5/src/together/utils/files.py +324 -0
- together-1.3.3/src/together/utils/files.py +0 -204
- {together-1.3.3 → together-1.3.5}/LICENSE +0 -0
- {together-1.3.3 → together-1.3.5}/README.md +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/__init__.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/abstract/__init__.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/abstract/api_requestor.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/cli/__init__.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/cli/api/__init__.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/cli/api/chat.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/cli/api/completions.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/cli/api/files.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/cli/api/images.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/cli/api/models.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/cli/cli.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/client.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/error.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/filemanager.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/legacy/__init__.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/legacy/base.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/legacy/complete.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/legacy/embeddings.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/legacy/files.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/legacy/finetune.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/legacy/images.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/legacy/models.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/resources/__init__.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/resources/chat/__init__.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/resources/chat/completions.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/resources/completions.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/resources/embeddings.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/resources/files.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/resources/images.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/resources/models.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/resources/rerank.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/together_response.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/types/abstract.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/types/chat_completions.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/types/common.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/types/completions.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/types/embeddings.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/types/error.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/types/files.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/types/images.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/types/models.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/types/rerank.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/utils/__init__.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/utils/_log.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/utils/api_helpers.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/utils/tools.py +0 -0
- {together-1.3.3 → together-1.3.5}/src/together/version.py +0 -0
{together-1.3.3 → together-1.3.5}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: together
-Version: 1.3.3
+Version: 1.3.5
 Summary: Python client for Together's Cloud Platform!
 Home-page: https://github.com/togethercomputer/together-python
 License: Apache-2.0
@@ -29,7 +29,7 @@ Requires-Dist: requests (>=2.31.0,<3.0.0)
 Requires-Dist: rich (>=13.8.1,<14.0.0)
 Requires-Dist: tabulate (>=0.9.0,<0.10.0)
 Requires-Dist: tqdm (>=4.66.2,<5.0.0)
-Requires-Dist: typer (>=0.9,<0.
+Requires-Dist: typer (>=0.9,<0.14)
 Project-URL: Bug Tracker, https://github.com/togethercomputer/together-python/issues
 Project-URL: Repository, https://github.com/togethercomputer/together-python
 Description-Content-Type: text/markdown
{together-1.3.3 → together-1.3.5}/pyproject.toml

@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
 
 [tool.poetry]
 name = "together"
-version = "1.3.3"
+version = "1.3.5"
 authors = [
     "Together AI <support@together.ai>"
 ]
@@ -29,7 +29,7 @@ homepage = "https://github.com/togethercomputer/together-python"
 
 [tool.poetry.dependencies]
 python = "^3.8"
-typer = ">=0.9,<0.
+typer = ">=0.9,<0.14"
 requests = "^2.31.0"
 rich = "^13.8.1"
 tqdm = "^4.66.2"
@@ -51,7 +51,7 @@ optional = true
 
 [tool.poetry.group.quality.dependencies]
 black = ">=23.1,<25.0"
-ruff = ">=0.3.2,<0.
+ruff = ">=0.3.2,<0.8.0"
 types-tqdm = "^4.65.0.0"
 types-tabulate = "^0.9.0.3"
 pre-commit = "3.5.0"
{together-1.3.3 → together-1.3.5}/src/together/cli/api/finetune.py

@@ -11,8 +11,13 @@ from rich import print as rprint
 from tabulate import tabulate
 
 from together import Together
-from together.cli.api.utils import INT_WITH_MAX
-from together.utils import
+from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX
+from together.utils import (
+    finetune_price_to_dollars,
+    log_warn,
+    log_warn_once,
+    parse_timestamp,
+)
 from together.types.finetune import DownloadCheckpointType, FinetuneTrainingLimits
 
 
@@ -60,12 +65,30 @@ def fine_tuning(ctx: click.Context) -> None:
 )
 @click.option("--batch-size", type=INT_WITH_MAX, default="max", help="Train batch size")
 @click.option("--learning-rate", type=float, default=1e-5, help="Learning rate")
+@click.option(
+    "--min-lr-ratio",
+    type=float,
+    default=0.0,
+    help="The ratio of the final learning rate to the peak learning rate",
+)
 @click.option(
     "--warmup-ratio",
     type=float,
     default=0.0,
     help="Warmup ratio for learning rate scheduler.",
 )
+@click.option(
+    "--max-grad-norm",
+    type=float,
+    default=1.0,
+    help="Max gradient norm to be used for gradient clipping. Set to 0 to disable.",
+)
+@click.option(
+    "--weight-decay",
+    type=float,
+    default=0.0,
+    help="Weight decay",
+)
 @click.option(
     "--lora/--no-lora",
     type=bool,
@@ -93,6 +116,13 @@ def fine_tuning(ctx: click.Context) -> None:
     default=False,
     help="Whether to skip the launch confirmation message",
 )
+@click.option(
+    "--train-on-inputs",
+    type=BOOL_WITH_AUTO,
+    default="auto",
+    help="Whether to mask the user messages in conversational data or prompts in instruction data. "
+    "`auto` will automatically determine whether to mask the inputs based on the data format.",
+)
 def create(
     ctx: click.Context,
     training_file: str,
@@ -103,7 +133,10 @@
     n_checkpoints: int,
     batch_size: int | Literal["max"],
     learning_rate: float,
+    min_lr_ratio: float,
     warmup_ratio: float,
+    max_grad_norm: float,
+    weight_decay: float,
     lora: bool,
     lora_r: int,
     lora_dropout: float,
@@ -112,6 +145,7 @@
     suffix: str,
     wandb_api_key: str,
    confirm: bool,
+    train_on_inputs: bool | Literal["auto"],
 ) -> None:
     """Start fine-tuning"""
     client: Together = ctx.obj
@@ -125,7 +159,10 @@
         n_checkpoints=n_checkpoints,
         batch_size=batch_size,
         learning_rate=learning_rate,
+        min_lr_ratio=min_lr_ratio,
         warmup_ratio=warmup_ratio,
+        max_grad_norm=max_grad_norm,
+        weight_decay=weight_decay,
         lora=lora,
         lora_r=lora_r,
         lora_dropout=lora_dropout,
@@ -133,6 +170,7 @@
         lora_trainable_modules=lora_trainable_modules,
         suffix=suffix,
         wandb_api_key=wandb_api_key,
+        train_on_inputs=train_on_inputs,
     )
 
     model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(
@@ -150,6 +188,10 @@
             "batch_size": model_limits.lora_training.max_batch_size,
             "learning_rate": 1e-3,
         }
+        log_warn_once(
+            f"The default LoRA rank for {model} has been changed to {default_values['lora_r']} as the max available.\n"
+            f"Also, the default learning rate for LoRA fine-tuning has been changed to {default_values['learning_rate']}."
+        )
         for arg in default_values:
             arg_source = ctx.get_parameter_source("arg") # type: ignore[attr-defined]
             if arg_source == ParameterSource.DEFAULT:
@@ -186,22 +228,7 @@
 
     if confirm or click.confirm(_CONFIRMATION_MESSAGE, default=True, show_default=True):
         response = client.fine_tuning.create(
-
-            model=model,
-            n_epochs=n_epochs,
-            validation_file=validation_file,
-            n_evals=n_evals,
-            n_checkpoints=n_checkpoints,
-            batch_size=batch_size,
-            learning_rate=learning_rate,
-            warmup_ratio=warmup_ratio,
-            lora=lora,
-            lora_r=lora_r,
-            lora_dropout=lora_dropout,
-            lora_alpha=lora_alpha,
-            lora_trainable_modules=lora_trainable_modules,
-            suffix=suffix,
-            wandb_api_key=wandb_api_key,
+            **training_args,
             verbose=True,
         )
 
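The CLI above applies model-specific defaults only when an option was left untouched, using click's parameter-source tracking. The following is a minimal, self-contained sketch of that pattern; it is not taken from the package, and the option and fallback value are illustrative.

import click
from click.core import ParameterSource


@click.command()
@click.option("--learning-rate", type=float, default=1e-5, help="Learning rate")
@click.pass_context
def train(ctx: click.Context, learning_rate: float) -> None:
    # If the value came from the option's default rather than the command line,
    # substitute a model-specific default, mirroring the logic in `create` above.
    if ctx.get_parameter_source("learning_rate") == ParameterSource.DEFAULT:
        learning_rate = 1e-3  # hypothetical model-specific default
    click.echo(f"learning_rate={learning_rate}")


if __name__ == "__main__":
    train()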
{together-1.3.3 → together-1.3.5}/src/together/cli/api/utils.py

@@ -27,4 +27,25 @@ class AutoIntParamType(click.ParamType):
         )
 
 
+class BooleanWithAutoParamType(click.ParamType):
+    name = "boolean_or_auto"
+
+    def convert(
+        self, value: str, param: click.Parameter | None, ctx: click.Context | None
+    ) -> bool | Literal["auto"] | None:
+        if value == "auto":
+            return "auto"
+        try:
+            return bool(value)
+        except ValueError:
+            self.fail(
+                _("{value!r} is not a valid {type}.").format(
+                    value=value, type=self.name
+                ),
+                param,
+                ctx,
+            )
+
+
 INT_WITH_MAX = AutoIntParamType()
+BOOL_WITH_AUTO = BooleanWithAutoParamType()
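For reference, a short sketch of how the new parameter type converts raw CLI strings; this exercises the class defined above and is not part of the package. Note that Python's bool() treats any non-empty string as True.

from together.cli.api.utils import BOOL_WITH_AUTO

print(BOOL_WITH_AUTO.convert("auto", None, None))  # "auto" is passed through unchanged
print(BOOL_WITH_AUTO.convert("", None, None))      # empty string -> False via bool("")
print(BOOL_WITH_AUTO.convert("true", None, None))  # any non-empty string -> True via bool(...)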
{together-1.3.3 → together-1.3.5}/src/together/constants.py

@@ -1,3 +1,5 @@
+import enum
+
 # Session constants
 TIMEOUT_SECS = 600
 MAX_SESSION_LIFETIME_SECS = 180
@@ -29,3 +31,20 @@ MAX_FILE_SIZE_GB = 4.9
 
 # expected columns for Parquet files
 PARQUET_EXPECTED_COLUMNS = ["input_ids", "attention_mask", "labels"]
+
+
+class DatasetFormat(enum.Enum):
+    """Dataset format enum."""
+
+    GENERAL = "general"
+    CONVERSATION = "conversation"
+    INSTRUCTION = "instruction"
+
+
+JSONL_REQUIRED_COLUMNS_MAP = {
+    DatasetFormat.GENERAL: ["text"],
+    DatasetFormat.CONVERSATION: ["messages"],
+    DatasetFormat.INSTRUCTION: ["prompt", "completion"],
+}
+REQUIRED_COLUMNS_MESSAGE = ["role", "content"]
+POSSIBLE_ROLES_CONVERSATION = ["system", "user", "assistant"]
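A small sketch (not part of the package) of how these new constants classify a JSONL sample, mirroring the detection logic added to utils/files.py further below:

from together.constants import JSONL_REQUIRED_COLUMNS_MAP


def detect_formats(sample: dict) -> list:
    # A sample matches a DatasetFormat when all of that format's required columns are present.
    return [
        fmt
        for fmt, columns in JSONL_REQUIRED_COLUMNS_MAP.items()
        if all(column in sample for column in columns)
    ]


print(detect_formats({"text": "hello"}))                                  # [DatasetFormat.GENERAL]
print(detect_formats({"messages": [{"role": "user", "content": "hi"}]}))  # [DatasetFormat.CONVERSATION]
print(detect_formats({"prompt": "2+2?", "completion": "4"}))              # [DatasetFormat.INSTRUCTION]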
{together-1.3.3 → together-1.3.5}/src/together/resources/finetune.py

@@ -20,6 +20,8 @@ from together.types import (
     TogetherClient,
     TogetherRequest,
     TrainingType,
+    FinetuneLRScheduler,
+    FinetuneLinearLRSchedulerArgs,
 )
 from together.types.finetune import DownloadCheckpointType
 from together.utils import log_warn_once, normalize_key
@@ -35,7 +37,10 @@ def createFinetuneRequest(
     n_checkpoints: int | None = 1,
     batch_size: int | Literal["max"] = "max",
     learning_rate: float | None = 0.00001,
-
+    min_lr_ratio: float = 0.0,
+    warmup_ratio: float = 0.0,
+    max_grad_norm: float = 1.0,
+    weight_decay: float = 0.0,
     lora: bool = False,
     lora_r: int | None = None,
     lora_dropout: float | None = 0,
@@ -43,6 +48,7 @@ def createFinetuneRequest(
     lora_trainable_modules: str | None = "all-linear",
     suffix: str | None = None,
     wandb_api_key: str | None = None,
+    train_on_inputs: bool | Literal["auto"] = "auto",
 ) -> FinetuneRequest:
     if batch_size == "max":
         log_warn_once(
@@ -82,6 +88,20 @@ def createFinetuneRequest(
     if warmup_ratio > 1 or warmup_ratio < 0:
         raise ValueError("Warmup ratio should be between 0 and 1")
 
+    if min_lr_ratio is not None and (min_lr_ratio > 1 or min_lr_ratio < 0):
+        raise ValueError("Min learning rate ratio should be between 0 and 1")
+
+    if max_grad_norm < 0:
+        raise ValueError("Max gradient norm should be non-negative")
+
+    if weight_decay is not None and (weight_decay < 0):
+        raise ValueError("Weight decay should be non-negative")
+
+    lrScheduler = FinetuneLRScheduler(
+        lr_scheduler_type="linear",
+        lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
+    )
+
     finetune_request = FinetuneRequest(
         model=model,
         training_file=training_file,
@@ -91,10 +111,14 @@ def createFinetuneRequest(
         n_checkpoints=n_checkpoints,
         batch_size=batch_size,
         learning_rate=learning_rate,
+        lr_scheduler=lrScheduler,
         warmup_ratio=warmup_ratio,
+        max_grad_norm=max_grad_norm,
+        weight_decay=weight_decay,
         training_type=training_type,
         suffix=suffix,
         wandb_key=wandb_api_key,
+        train_on_inputs=train_on_inputs,
     )
 
     return finetune_request
@@ -115,7 +139,10 @@ class FineTuning:
         n_checkpoints: int | None = 1,
         batch_size: int | Literal["max"] = "max",
         learning_rate: float | None = 0.00001,
-
+        min_lr_ratio: float = 0.0,
+        warmup_ratio: float = 0.0,
+        max_grad_norm: float = 1.0,
+        weight_decay: float = 0.0,
         lora: bool = False,
         lora_r: int | None = None,
         lora_dropout: float | None = 0,
@@ -125,6 +152,7 @@ class FineTuning:
         wandb_api_key: str | None = None,
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
+        train_on_inputs: bool | Literal["auto"] = "auto",
     ) -> FinetuneResponse:
         """
         Method to initiate a fine-tuning job
@@ -137,10 +165,14 @@ class FineTuning:
             n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
             n_checkpoints (int, optional): Number of checkpoints to save during fine-tuning.
                 Defaults to 1.
-            batch_size (int
+            batch_size (int or "max"): Batch size for fine-tuning. Defaults to max.
             learning_rate (float, optional): Learning rate multiplier to use for training
                 Defaults to 0.00001.
+            min_lr_ratio (float, optional): Min learning rate ratio of the initial learning rate for
+                the learning rate scheduler. Defaults to 0.0.
             warmup_ratio (float, optional): Warmup ratio for learning rate scheduler.
+            max_grad_norm (float, optional): Max gradient norm. Defaults to 1.0, set to 0 to disable.
+            weight_decay (float, optional): Weight decay. Defaults to 0.0.
             lora (bool, optional): Whether to use LoRA adapters. Defaults to True.
             lora_r (int, optional): Rank of LoRA adapters. Defaults to 8.
             lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
@@ -154,6 +186,12 @@ class FineTuning:
                 Defaults to False.
             model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
                 Defaults to None.
+            train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
+                "auto" will automatically determine whether to mask the inputs based on the data format.
+                For datasets with the "text" field (general format), inputs will not be masked.
+                For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
+                (Instruction format), inputs will be masked.
+                Defaults to "auto".
 
         Returns:
             FinetuneResponse: Object containing information about fine-tuning job.
@@ -176,7 +214,10 @@ class FineTuning:
             n_checkpoints=n_checkpoints,
             batch_size=batch_size,
             learning_rate=learning_rate,
+            min_lr_ratio=min_lr_ratio,
             warmup_ratio=warmup_ratio,
+            max_grad_norm=max_grad_norm,
+            weight_decay=weight_decay,
             lora=lora,
             lora_r=lora_r,
             lora_dropout=lora_dropout,
@@ -184,6 +225,7 @@ class FineTuning:
             lora_trainable_modules=lora_trainable_modules,
             suffix=suffix,
             wandb_api_key=wandb_api_key,
+            train_on_inputs=train_on_inputs,
         )
 
         if verbose:
@@ -426,7 +468,10 @@ class AsyncFineTuning:
         n_checkpoints: int | None = 1,
         batch_size: int | Literal["max"] = "max",
         learning_rate: float | None = 0.00001,
-
+        min_lr_ratio: float = 0.0,
+        warmup_ratio: float = 0.0,
+        max_grad_norm: float = 1.0,
+        weight_decay: float = 0.0,
         lora: bool = False,
         lora_r: int | None = None,
         lora_dropout: float | None = 0,
@@ -436,6 +481,7 @@ class AsyncFineTuning:
         wandb_api_key: str | None = None,
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
+        train_on_inputs: bool | Literal["auto"] = "auto",
     ) -> FinetuneResponse:
         """
         Async method to initiate a fine-tuning job
@@ -451,7 +497,11 @@ class AsyncFineTuning:
             batch_size (int, optional): Batch size for fine-tuning. Defaults to max.
             learning_rate (float, optional): Learning rate multiplier to use for training
                 Defaults to 0.00001.
+            min_lr_ratio (float, optional): Min learning rate ratio of the initial learning rate for
+                the learning rate scheduler. Defaults to 0.0.
             warmup_ratio (float, optional): Warmup ratio for learning rate scheduler.
+            max_grad_norm (float, optional): Max gradient norm. Defaults to 1.0, set to 0 to disable.
+            weight_decay (float, optional): Weight decay. Defaults to 0.0.
             lora (bool, optional): Whether to use LoRA adapters. Defaults to True.
             lora_r (int, optional): Rank of LoRA adapters. Defaults to 8.
             lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
@@ -465,6 +515,12 @@ class AsyncFineTuning:
                 Defaults to False.
             model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
                 Defaults to None.
+            train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
+                "auto" will automatically determine whether to mask the inputs based on the data format.
+                For datasets with the "text" field (general format), inputs will not be masked.
+                For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
+                (Instruction format), inputs will be masked.
+                Defaults to "auto".
 
         Returns:
             FinetuneResponse: Object containing information about fine-tuning job.
@@ -487,7 +543,10 @@ class AsyncFineTuning:
             n_checkpoints=n_checkpoints,
             batch_size=batch_size,
             learning_rate=learning_rate,
+            min_lr_ratio=min_lr_ratio,
             warmup_ratio=warmup_ratio,
+            max_grad_norm=max_grad_norm,
+            weight_decay=weight_decay,
             lora=lora,
             lora_r=lora_r,
             lora_dropout=lora_dropout,
@@ -495,6 +554,7 @@ class AsyncFineTuning:
             lora_trainable_modules=lora_trainable_modules,
             suffix=suffix,
             wandb_api_key=wandb_api_key,
+            train_on_inputs=train_on_inputs,
         )
 
         if verbose:
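A hedged usage sketch of the extended create() signature shown above. The model name and training file ID are placeholders, TOGETHER_API_KEY is assumed to be set in the environment, and running it would start a real fine-tuning job.

from together import Together

client = Together()  # reads TOGETHER_API_KEY from the environment
job = client.fine_tuning.create(
    training_file="file-xxxxxxxx",  # placeholder ID of an uploaded training file
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",  # placeholder model name
    min_lr_ratio=0.1,        # final LR of the linear scheduler = 10% of the peak LR
    max_grad_norm=1.0,       # gradient-clipping threshold; 0 disables clipping
    weight_decay=0.01,
    train_on_inputs="auto",  # mask inputs based on the detected dataset format
)
print(job)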
{together-1.3.3 → together-1.3.5}/src/together/types/__init__.py

@@ -30,6 +30,8 @@ from together.types.finetune import (
     LoRATrainingType,
     TrainingType,
     FinetuneTrainingLimits,
+    FinetuneLRScheduler,
+    FinetuneLinearLRSchedulerArgs,
 )
 from together.types.images import (
     ImageRequest,
@@ -57,6 +59,8 @@ __all__ = [
     "FinetuneList",
     "FinetuneListEvents",
     "FinetuneDownloadResult",
+    "FinetuneLRScheduler",
+    "FinetuneLinearLRSchedulerArgs",
     "FileRequest",
     "FileResponse",
     "FileList",
{together-1.3.3 → together-1.3.5}/src/together/types/finetune.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 from enum import Enum
 from typing import List, Literal
 
-from pydantic import Field, validator, field_validator
+from pydantic import StrictBool, Field, validator, field_validator
 
 from together.types.abstract import BaseModel
 from together.types.common import (
@@ -150,8 +150,14 @@ class FinetuneRequest(BaseModel):
     n_epochs: int
     # training learning rate
     learning_rate: float
+    # learning rate scheduler type and args
+    lr_scheduler: FinetuneLRScheduler | None = None
     # learning rate warmup ratio
     warmup_ratio: float
+    # max gradient norm
+    max_grad_norm: float
+    # weight decay
+    weight_decay: float
     # number of checkpoints to save
     n_checkpoints: int | None = None
     # number of evaluation loops to run
@@ -163,6 +169,7 @@ class FinetuneRequest(BaseModel):
     # weights & biases api key
     wandb_key: str | None = None
     training_type: FullTrainingType | LoRATrainingType | None = None
+    train_on_inputs: StrictBool | Literal["auto"] = "auto"
 
 
 class FinetuneResponse(BaseModel):
@@ -192,8 +199,14 @@ class FinetuneResponse(BaseModel):
     batch_size: int | None = None
     # training learning rate
     learning_rate: float | None = None
+    # learning rate scheduler type and args
+    lr_scheduler: FinetuneLRScheduler | None = None
     # learning rate warmup ratio
     warmup_ratio: float | None = None
+    # max gradient norm
+    max_grad_norm: float | None = None
+    # weight decay
+    weight_decay: float | None = None
     # number of steps between evals
     eval_steps: int | None = None
     # training type
@@ -230,6 +243,7 @@ class FinetuneResponse(BaseModel):
     # training file metadata
     training_file_num_lines: int | None = Field(None, alias="TrainingFileNumLines")
     training_file_size: int | None = Field(None, alias="TrainingFileSize")
+    train_on_inputs: StrictBool | Literal["auto"] | None = "auto"
 
     @field_validator("training_type")
     @classmethod
@@ -285,3 +299,12 @@ class FinetuneTrainingLimits(BaseModel):
     min_learning_rate: float
     full_training: FinetuneFullTrainingLimits | None = None
     lora_training: FinetuneLoraTrainingLimits | None = None
+
+
+class FinetuneLRScheduler(BaseModel):
+    lr_scheduler_type: str
+    lr_scheduler_args: FinetuneLinearLRSchedulerArgs | None = None
+
+
+class FinetuneLinearLRSchedulerArgs(BaseModel):
+    min_lr_ratio: float | None = 0.0
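To illustrate the new scheduler models, a minimal construction sketch; the values are illustrative and serialization assumes the pydantic v2 API already used elsewhere in the package.

from together.types import FinetuneLRScheduler, FinetuneLinearLRSchedulerArgs

scheduler = FinetuneLRScheduler(
    lr_scheduler_type="linear",  # the only scheduler type created by createFinetuneRequest above
    lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=0.1),
)
print(scheduler.model_dump())
# expected: {'lr_scheduler_type': 'linear', 'lr_scheduler_args': {'min_lr_ratio': 0.1}}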
together-1.3.5/src/together/utils/files.py (added in 1.3.5)

@@ -0,0 +1,324 @@
+from __future__ import annotations
+
+import json
+import os
+from pathlib import Path
+from traceback import format_exc
+from typing import Any, Dict
+
+from pyarrow import ArrowInvalid, parquet
+
+from together.constants import (
+    MAX_FILE_SIZE_GB,
+    MIN_SAMPLES,
+    NUM_BYTES_IN_GB,
+    PARQUET_EXPECTED_COLUMNS,
+    JSONL_REQUIRED_COLUMNS_MAP,
+    REQUIRED_COLUMNS_MESSAGE,
+    POSSIBLE_ROLES_CONVERSATION,
+    DatasetFormat,
+)
+
+
+class InvalidFileFormatError(ValueError):
+    """Exception raised for invalid file formats during file checks."""
+
+    def __init__(
+        self,
+        message: str = "",
+        line_number: int | None = None,
+        error_source: str | None = None,
+    ) -> None:
+        super().__init__(message)
+        self.message = message
+        self.line_number = line_number
+        self.error_source = error_source
+
+
+def check_file(
+    file: Path | str,
+) -> Dict[str, Any]:
+    if not isinstance(file, Path):
+        file = Path(file)
+
+    report_dict = {
+        "is_check_passed": True,
+        "message": "Checks passed",
+        "found": None,
+        "file_size": None,
+        "utf8": None,
+        "line_type": None,
+        "text_field": None,
+        "key_value": None,
+        "has_min_samples": None,
+        "num_samples": None,
+        "load_json": None,
+    }
+
+    if not file.is_file():
+        report_dict["found"] = False
+        report_dict["is_check_passed"] = False
+        return report_dict
+    else:
+        report_dict["found"] = True
+
+    file_size = os.stat(file.as_posix()).st_size
+
+    if file_size > MAX_FILE_SIZE_GB * NUM_BYTES_IN_GB:
+        report_dict["message"] = (
+            f"Maximum supported file size is {MAX_FILE_SIZE_GB} GB. Found file with size of {round(file_size / NUM_BYTES_IN_GB ,3)} GB."
+        )
+        report_dict["is_check_passed"] = False
+    elif file_size == 0:
+        report_dict["message"] = "File is empty"
+        report_dict["file_size"] = 0
+        report_dict["is_check_passed"] = False
+        return report_dict
+    else:
+        report_dict["file_size"] = file_size
+
+    data_report_dict = {}
+    if file.suffix == ".jsonl":
+        report_dict["filetype"] = "jsonl"
+        data_report_dict = _check_jsonl(file)
+    elif file.suffix == ".parquet":
+        report_dict["filetype"] = "parquet"
+        data_report_dict = _check_parquet(file)
+    else:
+        report_dict["filetype"] = (
+            f"Unknown extension of file {file}. "
+            "Only files with extensions .jsonl and .parquet are supported."
+        )
+        report_dict["is_check_passed"] = False
+
+    report_dict.update(data_report_dict)
+
+    return report_dict
+
+
+def _check_jsonl(file: Path) -> Dict[str, Any]:
+    report_dict: Dict[str, Any] = {}
+    # Check that the file is UTF-8 encoded. If not report where the error occurs.
+    try:
+        with file.open(encoding="utf-8") as f:
+            f.read()
+        report_dict["utf8"] = True
+    except UnicodeDecodeError as e:
+        report_dict["utf8"] = False
+        report_dict["message"] = f"File is not UTF-8 encoded. Error raised: {e}."
+        report_dict["is_check_passed"] = False
+        return report_dict
+
+    dataset_format = None
+    with file.open() as f:
+        idx = -1
+        try:
+            for idx, line in enumerate(f):
+                json_line = json.loads(line)
+
+                if not isinstance(json_line, dict):
+                    raise InvalidFileFormatError(
+                        message=(
+                            f"Error parsing file. Invalid format on line {idx + 1} of the input file. "
+                            'Example of valid json: {"text": "my sample string"}. '
+                        ),
+                        line_number=idx + 1,
+                        error_source="line_type",
+                    )
+
+                current_format = None
+                for possible_format in JSONL_REQUIRED_COLUMNS_MAP:
+                    if all(
+                        column in json_line
+                        for column in JSONL_REQUIRED_COLUMNS_MAP[possible_format]
+                    ):
+                        if current_format is None:
+                            current_format = possible_format
+                        elif current_format != possible_format:
+                            raise InvalidFileFormatError(
+                                message="Found multiple dataset formats in the input file. "
+                                f"Got {current_format} and {possible_format} on line {idx + 1}.",
+                                line_number=idx + 1,
+                                error_source="format",
+                            )
+
+                if current_format is None:
+                    raise InvalidFileFormatError(
+                        message=(
+                            f"Error parsing file. Could not detect a format for the line {idx + 1} with the columns:\n"
+                            f"{json_line.keys()}"
+                        ),
+                        line_number=idx + 1,
+                        error_source="format",
+                    )
+
+                if current_format == DatasetFormat.CONVERSATION:
+                    message_column = JSONL_REQUIRED_COLUMNS_MAP[
+                        DatasetFormat.CONVERSATION
+                    ][0]
+                    if not isinstance(json_line[message_column], list):
+                        raise InvalidFileFormatError(
+                            message=f"Invalid format on line {idx + 1} of the input file. "
+                            f"Expected a list of messages. Found {type(json_line[message_column])}",
+                            line_number=idx + 1,
+                            error_source="key_value",
+                        )
+
+                    for turn_id, turn in enumerate(json_line[message_column]):
+                        if not isinstance(turn, dict):
+                            raise InvalidFileFormatError(
+                                message=f"Invalid format on line {idx + 1} of the input file. "
+                                f"Expected a dictionary in the {turn_id + 1} turn. Found {type(turn)}",
+                                line_number=idx + 1,
+                                error_source="key_value",
+                            )
+
+                    previous_role = None
+                    for turn in json_line[message_column]:
+                        for column in REQUIRED_COLUMNS_MESSAGE:
+                            if column not in turn:
+                                raise InvalidFileFormatError(
+                                    message=f"Field `{column}` is missing for a turn `{turn}` on line {idx + 1} "
+                                    "of the the input file.",
+                                    line_number=idx + 1,
+                                    error_source="key_value",
+                                )
+                            else:
+                                if not isinstance(turn[column], str):
+                                    raise InvalidFileFormatError(
+                                        message=f"Invalid format on line {idx + 1} in the column {column} for turn `{turn}` "
+                                        f"of the input file. Expected string. Found {type(turn[column])}",
+                                        line_number=idx + 1,
+                                        error_source="text_field",
+                                    )
+                        role = turn["role"]
+
+                        if role not in POSSIBLE_ROLES_CONVERSATION:
+                            raise InvalidFileFormatError(
+                                message=f"Found invalid role `{role}` in the messages on the line {idx + 1}. "
+                                f"Possible roles in the conversation are: {POSSIBLE_ROLES_CONVERSATION}",
+                                line_number=idx + 1,
+                                error_source="key_value",
+                            )
+
+                        if previous_role == role:
+                            raise InvalidFileFormatError(
+                                message=f"Invalid role turns on line {idx + 1} of the input file. "
+                                "`user` and `assistant` roles must alternate user/assistant/user/assistant/...",
+                                line_number=idx + 1,
+                                error_source="key_value",
+                            )
+
+                        previous_role = role
+
+                else:
+                    for column in JSONL_REQUIRED_COLUMNS_MAP[current_format]:
+                        if not isinstance(json_line[column], str):
+                            raise InvalidFileFormatError(
+                                message=f'Invalid value type for "{column}" key on line {idx + 1}. '
+                                f"Expected string. Found {type(json_line[column])}.",
+                                line_number=idx + 1,
+                                error_source="key_value",
+                            )
+
+                if dataset_format is None:
+                    dataset_format = current_format
+                elif current_format is not None:
+                    if current_format != dataset_format:
+                        raise InvalidFileFormatError(
+                            message="All samples in the dataset must have the same dataset format. "
+                            f"Got {dataset_format} for the first line and {current_format} "
+                            f"for the line {idx + 1}.",
+                            line_number=idx + 1,
+                            error_source="format",
+                        )
+
+            if idx + 1 < MIN_SAMPLES:
+                report_dict["has_min_samples"] = False
+                report_dict["message"] = (
+                    f"Processing {file} resulted in only {idx + 1} samples. "
+                    f"Our minimum is {MIN_SAMPLES} samples. "
+                )
+                report_dict["is_check_passed"] = False
+            else:
+                report_dict["num_samples"] = idx + 1
+                report_dict["has_min_samples"] = True
+                report_dict["is_check_passed"] = True
+
+            report_dict["load_json"] = True
+
+        except InvalidFileFormatError as e:
+            report_dict["load_json"] = False
+            report_dict["is_check_passed"] = False
+            report_dict["message"] = e.message
+            if e.line_number is not None:
+                report_dict["line_number"] = e.line_number
+            if e.error_source is not None:
+                report_dict[e.error_source] = False
+        except ValueError:
+            report_dict["load_json"] = False
+            if idx < 0:
+                report_dict["message"] = (
+                    "Unable to decode file. "
+                    "File may be empty or in an unsupported format. "
+                )
+            else:
+                report_dict["message"] = (
+                    f"Error parsing json payload. Unexpected format on line {idx + 1}."
+                )
+            report_dict["is_check_passed"] = False
+
+    if "text_field" not in report_dict:
+        report_dict["text_field"] = True
+    if "line_type" not in report_dict:
+        report_dict["line_type"] = True
+    if "key_value" not in report_dict:
+        report_dict["key_value"] = True
+    return report_dict
+
+
+def _check_parquet(file: Path) -> Dict[str, Any]:
+    report_dict: Dict[str, Any] = {}
+
+    try:
+        table = parquet.read_table(str(file), memory_map=True)
+    except ArrowInvalid:
+        report_dict["load_parquet"] = (
+            f"An exception has occurred when loading the Parquet file {file}. Please check the file for corruption. "
+            f"Exception trace:\n{format_exc()}"
+        )
+        report_dict["is_check_passed"] = False
+        return report_dict
+
+    column_names = table.schema.names
+    if "input_ids" not in column_names:
+        report_dict["load_parquet"] = (
+            f"Parquet file {file} does not contain the `input_ids` column."
+        )
+        report_dict["is_check_passed"] = False
+        return report_dict
+
+    for column_name in column_names:
+        if column_name not in PARQUET_EXPECTED_COLUMNS:
+            report_dict["load_parquet"] = (
+                f"Parquet file {file} contains an unexpected column {column_name}. "
+                f"Only columns {PARQUET_EXPECTED_COLUMNS} are supported."
+            )
+            report_dict["is_check_passed"] = False
+            return report_dict
+
+    num_samples = len(table)
+    if num_samples < MIN_SAMPLES:
+        report_dict["has_min_samples"] = False
+        report_dict["message"] = (
+            f"Processing {file} resulted in only {num_samples} samples. "
+            f"Our minimum is {MIN_SAMPLES} samples. "
+        )
+        report_dict["is_check_passed"] = False
+        return report_dict
+    else:
+        report_dict["num_samples"] = num_samples
+
+    report_dict["is_check_passed"] = True
+
+    return report_dict
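A usage sketch for the new validator; the file name and contents are made up for illustration, and the loop writes more rows than any plausible MIN_SAMPLES (the minimum is defined in together.constants).

import json

from together.utils.files import check_file

# Write a tiny conversational dataset with alternating user/assistant turns.
with open("train.jsonl", "w") as f:
    for i in range(100):
        sample = {
            "messages": [
                {"role": "user", "content": f"Question {i}"},
                {"role": "assistant", "content": f"Answer {i}"},
            ]
        }
        f.write(json.dumps(sample) + "\n")

report = check_file("train.jsonl")
print(report["is_check_passed"], report.get("num_samples"), report["message"])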
together-1.3.3/src/together/utils/files.py (removed; replaced by the new files.py above)

@@ -1,204 +0,0 @@
-from __future__ import annotations
-
-import json
-import os
-from pathlib import Path
-from traceback import format_exc
-from typing import Any, Dict
-
-from pyarrow import ArrowInvalid, parquet
-
-from together.constants import (
-    MAX_FILE_SIZE_GB,
-    MIN_SAMPLES,
-    NUM_BYTES_IN_GB,
-    PARQUET_EXPECTED_COLUMNS,
-)
-
-
-def check_file(
-    file: Path | str,
-) -> Dict[str, Any]:
-    if not isinstance(file, Path):
-        file = Path(file)
-
-    report_dict = {
-        "is_check_passed": True,
-        "message": "Checks passed",
-        "found": None,
-        "file_size": None,
-        "utf8": None,
-        "line_type": None,
-        "text_field": None,
-        "key_value": None,
-        "min_samples": None,
-        "num_samples": None,
-        "load_json": None,
-    }
-
-    if not file.is_file():
-        report_dict["found"] = False
-        report_dict["is_check_passed"] = False
-        return report_dict
-    else:
-        report_dict["found"] = True
-
-    file_size = os.stat(file.as_posix()).st_size
-
-    if file_size > MAX_FILE_SIZE_GB * NUM_BYTES_IN_GB:
-        report_dict["message"] = (
-            f"Maximum supported file size is {MAX_FILE_SIZE_GB} GB. Found file with size of {round(file_size / NUM_BYTES_IN_GB ,3)} GB."
-        )
-        report_dict["is_check_passed"] = False
-    elif file_size == 0:
-        report_dict["message"] = "File is empty"
-        report_dict["file_size"] = 0
-        report_dict["is_check_passed"] = False
-        return report_dict
-    else:
-        report_dict["file_size"] = file_size
-
-    if file.suffix == ".jsonl":
-        report_dict["filetype"] = "jsonl"
-        data_report_dict = _check_jsonl(file)
-    elif file.suffix == ".parquet":
-        report_dict["filetype"] = "parquet"
-        data_report_dict = _check_parquet(file)
-    else:
-        report_dict["filetype"] = (
-            f"Unknown extension of file {file}. "
-            "Only files with extensions .jsonl and .parquet are supported."
-        )
-        report_dict["is_check_passed"] = False
-
-    report_dict.update(data_report_dict)
-    return report_dict
-
-
-def _check_jsonl(file: Path) -> Dict[str, Any]:
-    report_dict: Dict[str, Any] = {}
-    # Check that the file is UTF-8 encoded. If not report where the error occurs.
-    try:
-        with file.open(encoding="utf-8") as f:
-            f.read()
-        report_dict["utf8"] = True
-    except UnicodeDecodeError as e:
-        report_dict["utf8"] = False
-        report_dict["message"] = f"File is not UTF-8 encoded. Error raised: {e}."
-        report_dict["is_check_passed"] = False
-        return report_dict
-
-    with file.open() as f:
-        # idx must be instantiated so decode errors (e.g. file is a tar) or empty files are caught
-        idx = -1
-        try:
-            for idx, line in enumerate(f):
-                json_line = json.loads(line)  # each line in jsonlines should be a json
-
-                if not isinstance(json_line, dict):
-                    report_dict["line_type"] = False
-                    report_dict["message"] = (
-                        f"Error parsing file. Invalid format on line {idx + 1} of the input file. "
-                        'Example of valid json: {"text": "my sample string"}. '
-                    )
-
-                    report_dict["is_check_passed"] = False
-
-                if "text" not in json_line.keys():
-                    report_dict["text_field"] = False
-                    report_dict["message"] = (
-                        f"Missing 'text' field was found on line {idx + 1} of the the input file. "
-                        "Expected format: {'text': 'my sample string'}. "
-                    )
-                    report_dict["is_check_passed"] = False
-                else:
-                    # check to make sure the value of the "text" key is a string
-                    if not isinstance(json_line["text"], str):
-                        report_dict["key_value"] = False
-                        report_dict["message"] = (
-                            f'Invalid value type for "text" key on line {idx + 1}. '
-                            f'Expected string. Found {type(json_line["text"])}.'
-                        )
-
-                        report_dict["is_check_passed"] = False
-
-            # make sure this is outside the for idx, line in enumerate(f): for loop
-            if idx + 1 < MIN_SAMPLES:
-                report_dict["min_samples"] = False
-                report_dict["message"] = (
-                    f"Processing {file} resulted in only {idx + 1} samples. "
-                    f"Our minimum is {MIN_SAMPLES} samples. "
-                )
-                report_dict["is_check_passed"] = False
-            else:
-                report_dict["num_samples"] = idx + 1
-                report_dict["min_samples"] = True
-
-            report_dict["load_json"] = True
-
-        except ValueError:
-            report_dict["load_json"] = False
-            if idx < 0:
-                report_dict["message"] = (
-                    "Unable to decode file. "
-                    "File may be empty or in an unsupported format. "
-                )
-            else:
-                report_dict["message"] = (
-                    f"Error parsing json payload. Unexpected format on line {idx + 1}."
-                )
-            report_dict["is_check_passed"] = False
-
-    if "text_field" not in report_dict:
-        report_dict["text_field"] = True
-    if "line_type" not in report_dict:
-        report_dict["line_type"] = True
-    if "key_value" not in report_dict:
-        report_dict["key_value"] = True
-    return report_dict
-
-
-def _check_parquet(file: Path) -> Dict[str, Any]:
-    report_dict: Dict[str, Any] = {}
-
-    try:
-        table = parquet.read_table(str(file), memory_map=True)
-    except ArrowInvalid:
-        report_dict["load_parquet"] = (
-            f"An exception has occurred when loading the Parquet file {file}. Please check the file for corruption. "
-            f"Exception trace:\n{format_exc()}"
-        )
-        report_dict["is_check_passed"] = False
-        return report_dict
-
-    column_names = table.schema.names
-    if "input_ids" not in column_names:
-        report_dict["load_parquet"] = (
-            f"Parquet file {file} does not contain the `input_ids` column."
-        )
-        report_dict["is_check_passed"] = False
-        return report_dict
-
-    for column_name in column_names:
-        if column_name not in PARQUET_EXPECTED_COLUMNS:
-            report_dict["load_parquet"] = (
-                f"Parquet file {file} contains an unexpected column {column_name}. "
-                f"Only columns {PARQUET_EXPECTED_COLUMNS} are supported."
-            )
-            report_dict["is_check_passed"] = False
-            return report_dict
-
-    num_samples = len(table)
-    if num_samples < MIN_SAMPLES:
-        report_dict["min_samples"] = (
-            f"Processing {file} resulted in only {num_samples} samples. "
-            f"Our minimum is {MIN_SAMPLES} samples. "
-        )
-        report_dict["is_check_passed"] = False
-        return report_dict
-    else:
-        report_dict["num_samples"] = num_samples
-
-    report_dict["is_check_passed"] = True
-
-    return report_dict