together 1.3.3.tar.gz → 1.3.5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. {together-1.3.3 → together-1.3.5}/PKG-INFO +2 -2
  2. {together-1.3.3 → together-1.3.5}/pyproject.toml +3 -3
  3. {together-1.3.3 → together-1.3.5}/src/together/cli/api/finetune.py +45 -18
  4. {together-1.3.3 → together-1.3.5}/src/together/cli/api/utils.py +21 -0
  5. {together-1.3.3 → together-1.3.5}/src/together/constants.py +19 -0
  6. {together-1.3.3 → together-1.3.5}/src/together/resources/finetune.py +64 -4
  7. {together-1.3.3 → together-1.3.5}/src/together/types/__init__.py +4 -0
  8. {together-1.3.3 → together-1.3.5}/src/together/types/finetune.py +24 -1
  9. together-1.3.5/src/together/utils/files.py +324 -0
  10. together-1.3.3/src/together/utils/files.py +0 -204
  11. {together-1.3.3 → together-1.3.5}/LICENSE +0 -0
  12. {together-1.3.3 → together-1.3.5}/README.md +0 -0
  13. {together-1.3.3 → together-1.3.5}/src/together/__init__.py +0 -0
  14. {together-1.3.3 → together-1.3.5}/src/together/abstract/__init__.py +0 -0
  15. {together-1.3.3 → together-1.3.5}/src/together/abstract/api_requestor.py +0 -0
  16. {together-1.3.3 → together-1.3.5}/src/together/cli/__init__.py +0 -0
  17. {together-1.3.3 → together-1.3.5}/src/together/cli/api/__init__.py +0 -0
  18. {together-1.3.3 → together-1.3.5}/src/together/cli/api/chat.py +0 -0
  19. {together-1.3.3 → together-1.3.5}/src/together/cli/api/completions.py +0 -0
  20. {together-1.3.3 → together-1.3.5}/src/together/cli/api/files.py +0 -0
  21. {together-1.3.3 → together-1.3.5}/src/together/cli/api/images.py +0 -0
  22. {together-1.3.3 → together-1.3.5}/src/together/cli/api/models.py +0 -0
  23. {together-1.3.3 → together-1.3.5}/src/together/cli/cli.py +0 -0
  24. {together-1.3.3 → together-1.3.5}/src/together/client.py +0 -0
  25. {together-1.3.3 → together-1.3.5}/src/together/error.py +0 -0
  26. {together-1.3.3 → together-1.3.5}/src/together/filemanager.py +0 -0
  27. {together-1.3.3 → together-1.3.5}/src/together/legacy/__init__.py +0 -0
  28. {together-1.3.3 → together-1.3.5}/src/together/legacy/base.py +0 -0
  29. {together-1.3.3 → together-1.3.5}/src/together/legacy/complete.py +0 -0
  30. {together-1.3.3 → together-1.3.5}/src/together/legacy/embeddings.py +0 -0
  31. {together-1.3.3 → together-1.3.5}/src/together/legacy/files.py +0 -0
  32. {together-1.3.3 → together-1.3.5}/src/together/legacy/finetune.py +0 -0
  33. {together-1.3.3 → together-1.3.5}/src/together/legacy/images.py +0 -0
  34. {together-1.3.3 → together-1.3.5}/src/together/legacy/models.py +0 -0
  35. {together-1.3.3 → together-1.3.5}/src/together/resources/__init__.py +0 -0
  36. {together-1.3.3 → together-1.3.5}/src/together/resources/chat/__init__.py +0 -0
  37. {together-1.3.3 → together-1.3.5}/src/together/resources/chat/completions.py +0 -0
  38. {together-1.3.3 → together-1.3.5}/src/together/resources/completions.py +0 -0
  39. {together-1.3.3 → together-1.3.5}/src/together/resources/embeddings.py +0 -0
  40. {together-1.3.3 → together-1.3.5}/src/together/resources/files.py +0 -0
  41. {together-1.3.3 → together-1.3.5}/src/together/resources/images.py +0 -0
  42. {together-1.3.3 → together-1.3.5}/src/together/resources/models.py +0 -0
  43. {together-1.3.3 → together-1.3.5}/src/together/resources/rerank.py +0 -0
  44. {together-1.3.3 → together-1.3.5}/src/together/together_response.py +0 -0
  45. {together-1.3.3 → together-1.3.5}/src/together/types/abstract.py +0 -0
  46. {together-1.3.3 → together-1.3.5}/src/together/types/chat_completions.py +0 -0
  47. {together-1.3.3 → together-1.3.5}/src/together/types/common.py +0 -0
  48. {together-1.3.3 → together-1.3.5}/src/together/types/completions.py +0 -0
  49. {together-1.3.3 → together-1.3.5}/src/together/types/embeddings.py +0 -0
  50. {together-1.3.3 → together-1.3.5}/src/together/types/error.py +0 -0
  51. {together-1.3.3 → together-1.3.5}/src/together/types/files.py +0 -0
  52. {together-1.3.3 → together-1.3.5}/src/together/types/images.py +0 -0
  53. {together-1.3.3 → together-1.3.5}/src/together/types/models.py +0 -0
  54. {together-1.3.3 → together-1.3.5}/src/together/types/rerank.py +0 -0
  55. {together-1.3.3 → together-1.3.5}/src/together/utils/__init__.py +0 -0
  56. {together-1.3.3 → together-1.3.5}/src/together/utils/_log.py +0 -0
  57. {together-1.3.3 → together-1.3.5}/src/together/utils/api_helpers.py +0 -0
  58. {together-1.3.3 → together-1.3.5}/src/together/utils/tools.py +0 -0
  59. {together-1.3.3 → together-1.3.5}/src/together/version.py +0 -0
--- together-1.3.3/PKG-INFO
+++ together-1.3.5/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: together
-Version: 1.3.3
+Version: 1.3.5
 Summary: Python client for Together's Cloud Platform!
 Home-page: https://github.com/togethercomputer/together-python
 License: Apache-2.0
@@ -29,7 +29,7 @@ Requires-Dist: requests (>=2.31.0,<3.0.0)
 Requires-Dist: rich (>=13.8.1,<14.0.0)
 Requires-Dist: tabulate (>=0.9.0,<0.10.0)
 Requires-Dist: tqdm (>=4.66.2,<5.0.0)
-Requires-Dist: typer (>=0.9,<0.13)
+Requires-Dist: typer (>=0.9,<0.14)
 Project-URL: Bug Tracker, https://github.com/togethercomputer/together-python/issues
 Project-URL: Repository, https://github.com/togethercomputer/together-python
 Description-Content-Type: text/markdown
--- together-1.3.3/pyproject.toml
+++ together-1.3.5/pyproject.toml
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
 
 [tool.poetry]
 name = "together"
-version = "1.3.3"
+version = "1.3.5"
 authors = [
     "Together AI <support@together.ai>"
 ]
@@ -29,7 +29,7 @@ homepage = "https://github.com/togethercomputer/together-python"
 
 [tool.poetry.dependencies]
 python = "^3.8"
-typer = ">=0.9,<0.13"
+typer = ">=0.9,<0.14"
 requests = "^2.31.0"
 rich = "^13.8.1"
 tqdm = "^4.66.2"
@@ -51,7 +51,7 @@ optional = true
 
 [tool.poetry.group.quality.dependencies]
 black = ">=23.1,<25.0"
-ruff = ">=0.3.2,<0.7.0"
+ruff = ">=0.3.2,<0.8.0"
 types-tqdm = "^4.65.0.0"
 types-tabulate = "^0.9.0.3"
 pre-commit = "3.5.0"
--- together-1.3.3/src/together/cli/api/finetune.py
+++ together-1.3.5/src/together/cli/api/finetune.py
@@ -11,8 +11,13 @@ from rich import print as rprint
 from tabulate import tabulate
 
 from together import Together
-from together.cli.api.utils import INT_WITH_MAX
-from together.utils import finetune_price_to_dollars, log_warn, parse_timestamp
+from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX
+from together.utils import (
+    finetune_price_to_dollars,
+    log_warn,
+    log_warn_once,
+    parse_timestamp,
+)
 from together.types.finetune import DownloadCheckpointType, FinetuneTrainingLimits
 
 
@@ -60,12 +65,30 @@ def fine_tuning(ctx: click.Context) -> None:
 )
 @click.option("--batch-size", type=INT_WITH_MAX, default="max", help="Train batch size")
 @click.option("--learning-rate", type=float, default=1e-5, help="Learning rate")
+@click.option(
+    "--min-lr-ratio",
+    type=float,
+    default=0.0,
+    help="The ratio of the final learning rate to the peak learning rate",
+)
 @click.option(
     "--warmup-ratio",
     type=float,
     default=0.0,
     help="Warmup ratio for learning rate scheduler.",
 )
+@click.option(
+    "--max-grad-norm",
+    type=float,
+    default=1.0,
+    help="Max gradient norm to be used for gradient clipping. Set to 0 to disable.",
+)
+@click.option(
+    "--weight-decay",
+    type=float,
+    default=0.0,
+    help="Weight decay",
+)
 @click.option(
     "--lora/--no-lora",
     type=bool,
@@ -93,6 +116,13 @@ def fine_tuning(ctx: click.Context) -> None:
     default=False,
     help="Whether to skip the launch confirmation message",
 )
+@click.option(
+    "--train-on-inputs",
+    type=BOOL_WITH_AUTO,
+    default="auto",
+    help="Whether to mask the user messages in conversational data or prompts in instruction data. "
+    "`auto` will automatically determine whether to mask the inputs based on the data format.",
+)
 def create(
     ctx: click.Context,
     training_file: str,
@@ -103,7 +133,10 @@ def create(
     n_checkpoints: int,
     batch_size: int | Literal["max"],
     learning_rate: float,
+    min_lr_ratio: float,
     warmup_ratio: float,
+    max_grad_norm: float,
+    weight_decay: float,
     lora: bool,
     lora_r: int,
     lora_dropout: float,
@@ -112,6 +145,7 @@ def create(
     suffix: str,
     wandb_api_key: str,
     confirm: bool,
+    train_on_inputs: bool | Literal["auto"],
 ) -> None:
     """Start fine-tuning"""
     client: Together = ctx.obj
@@ -125,7 +159,10 @@ def create(
         n_checkpoints=n_checkpoints,
         batch_size=batch_size,
         learning_rate=learning_rate,
+        min_lr_ratio=min_lr_ratio,
         warmup_ratio=warmup_ratio,
+        max_grad_norm=max_grad_norm,
+        weight_decay=weight_decay,
         lora=lora,
         lora_r=lora_r,
         lora_dropout=lora_dropout,
@@ -133,6 +170,7 @@ def create(
         lora_trainable_modules=lora_trainable_modules,
         suffix=suffix,
         wandb_api_key=wandb_api_key,
+        train_on_inputs=train_on_inputs,
     )
 
     model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(
@@ -150,6 +188,10 @@ def create(
             "batch_size": model_limits.lora_training.max_batch_size,
             "learning_rate": 1e-3,
         }
+        log_warn_once(
+            f"The default LoRA rank for {model} has been changed to {default_values['lora_r']} as the max available.\n"
+            f"Also, the default learning rate for LoRA fine-tuning has been changed to {default_values['learning_rate']}."
+        )
         for arg in default_values:
             arg_source = ctx.get_parameter_source("arg")  # type: ignore[attr-defined]
             if arg_source == ParameterSource.DEFAULT:
@@ -186,22 +228,7 @@ def create(
 
     if confirm or click.confirm(_CONFIRMATION_MESSAGE, default=True, show_default=True):
         response = client.fine_tuning.create(
-            training_file=training_file,
-            model=model,
-            n_epochs=n_epochs,
-            validation_file=validation_file,
-            n_evals=n_evals,
-            n_checkpoints=n_checkpoints,
-            batch_size=batch_size,
-            learning_rate=learning_rate,
-            warmup_ratio=warmup_ratio,
-            lora=lora,
-            lora_r=lora_r,
-            lora_dropout=lora_dropout,
-            lora_alpha=lora_alpha,
-            lora_trainable_modules=lora_trainable_modules,
-            suffix=suffix,
-            wandb_api_key=wandb_api_key,
+            **training_args,
             verbose=True,
         )
 
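For context, a hypothetical invocation of the updated `together fine-tuning create` command using the options added in this release might look like the following. The training file and model name below are placeholders, not values taken from this diff:

    together fine-tuning create \
      --training-file my_dataset.jsonl \
      --model <model-name> \
      --min-lr-ratio 0.1 \
      --max-grad-norm 1.0 \
      --weight-decay 0.01 \
      --train-on-inputs auto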
--- together-1.3.3/src/together/cli/api/utils.py
+++ together-1.3.5/src/together/cli/api/utils.py
@@ -27,4 +27,25 @@ class AutoIntParamType(click.ParamType):
             )
 
 
+class BooleanWithAutoParamType(click.ParamType):
+    name = "boolean_or_auto"
+
+    def convert(
+        self, value: str, param: click.Parameter | None, ctx: click.Context | None
+    ) -> bool | Literal["auto"] | None:
+        if value == "auto":
+            return "auto"
+        try:
+            return bool(value)
+        except ValueError:
+            self.fail(
+                _("{value!r} is not a valid {type}.").format(
+                    value=value, type=self.name
+                ),
+                param,
+                ctx,
+            )
+
+
 INT_WITH_MAX = AutoIntParamType()
+BOOL_WITH_AUTO = BooleanWithAutoParamType()
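A rough illustration of how the new parameter type converts values (this snippet is not part of the package source; note that convert() falls back to Python's bool(), so any non-empty string other than "auto" converts to True):

    from together.cli.api.utils import BOOL_WITH_AUTO

    BOOL_WITH_AUTO.convert("auto", None, None)   # -> "auto"
    BOOL_WITH_AUTO.convert("false", None, None)  # -> True (any non-empty string is truthy)
    BOOL_WITH_AUTO.convert("", None, None)       # -> False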
--- together-1.3.3/src/together/constants.py
+++ together-1.3.5/src/together/constants.py
@@ -1,3 +1,5 @@
+import enum
+
 # Session constants
 TIMEOUT_SECS = 600
 MAX_SESSION_LIFETIME_SECS = 180
@@ -29,3 +31,20 @@ MAX_FILE_SIZE_GB = 4.9
 
 # expected columns for Parquet files
 PARQUET_EXPECTED_COLUMNS = ["input_ids", "attention_mask", "labels"]
+
+
+class DatasetFormat(enum.Enum):
+    """Dataset format enum."""
+
+    GENERAL = "general"
+    CONVERSATION = "conversation"
+    INSTRUCTION = "instruction"
+
+
+JSONL_REQUIRED_COLUMNS_MAP = {
+    DatasetFormat.GENERAL: ["text"],
+    DatasetFormat.CONVERSATION: ["messages"],
+    DatasetFormat.INSTRUCTION: ["prompt", "completion"],
+}
+REQUIRED_COLUMNS_MESSAGE = ["role", "content"]
+POSSIBLE_ROLES_CONVERSATION = ["system", "user", "assistant"]
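For illustration only (sample records, not taken from this diff), one JSONL line per dataset format described by these constants could look like:

    {"text": "my sample string"}
    {"messages": [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]}
    {"prompt": "Translate to French: cheese", "completion": "fromage"}

The first is the general format, the second the conversation format (roles limited to system/user/assistant), and the third the instruction format.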
--- together-1.3.3/src/together/resources/finetune.py
+++ together-1.3.5/src/together/resources/finetune.py
@@ -20,6 +20,8 @@ from together.types import (
     TogetherClient,
     TogetherRequest,
     TrainingType,
+    FinetuneLRScheduler,
+    FinetuneLinearLRSchedulerArgs,
 )
 from together.types.finetune import DownloadCheckpointType
 from together.utils import log_warn_once, normalize_key
@@ -35,7 +37,10 @@ def createFinetuneRequest(
     n_checkpoints: int | None = 1,
     batch_size: int | Literal["max"] = "max",
     learning_rate: float | None = 0.00001,
-    warmup_ratio: float | None = 0.0,
+    min_lr_ratio: float = 0.0,
+    warmup_ratio: float = 0.0,
+    max_grad_norm: float = 1.0,
+    weight_decay: float = 0.0,
     lora: bool = False,
     lora_r: int | None = None,
     lora_dropout: float | None = 0,
@@ -43,6 +48,7 @@ def createFinetuneRequest(
     lora_trainable_modules: str | None = "all-linear",
     suffix: str | None = None,
     wandb_api_key: str | None = None,
+    train_on_inputs: bool | Literal["auto"] = "auto",
 ) -> FinetuneRequest:
     if batch_size == "max":
         log_warn_once(
@@ -82,6 +88,20 @@ def createFinetuneRequest(
     if warmup_ratio > 1 or warmup_ratio < 0:
         raise ValueError("Warmup ratio should be between 0 and 1")
 
+    if min_lr_ratio is not None and (min_lr_ratio > 1 or min_lr_ratio < 0):
+        raise ValueError("Min learning rate ratio should be between 0 and 1")
+
+    if max_grad_norm < 0:
+        raise ValueError("Max gradient norm should be non-negative")
+
+    if weight_decay is not None and (weight_decay < 0):
+        raise ValueError("Weight decay should be non-negative")
+
+    lrScheduler = FinetuneLRScheduler(
+        lr_scheduler_type="linear",
+        lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
+    )
+
     finetune_request = FinetuneRequest(
         model=model,
         training_file=training_file,
@@ -91,10 +111,14 @@ def createFinetuneRequest(
         n_checkpoints=n_checkpoints,
         batch_size=batch_size,
         learning_rate=learning_rate,
+        lr_scheduler=lrScheduler,
         warmup_ratio=warmup_ratio,
+        max_grad_norm=max_grad_norm,
+        weight_decay=weight_decay,
         training_type=training_type,
         suffix=suffix,
         wandb_key=wandb_api_key,
+        train_on_inputs=train_on_inputs,
     )
 
     return finetune_request
@@ -115,7 +139,10 @@ class FineTuning:
         n_checkpoints: int | None = 1,
         batch_size: int | Literal["max"] = "max",
         learning_rate: float | None = 0.00001,
-        warmup_ratio: float | None = 0.0,
+        min_lr_ratio: float = 0.0,
+        warmup_ratio: float = 0.0,
+        max_grad_norm: float = 1.0,
+        weight_decay: float = 0.0,
         lora: bool = False,
         lora_r: int | None = None,
         lora_dropout: float | None = 0,
@@ -125,6 +152,7 @@ class FineTuning:
         wandb_api_key: str | None = None,
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
+        train_on_inputs: bool | Literal["auto"] = "auto",
     ) -> FinetuneResponse:
         """
         Method to initiate a fine-tuning job
@@ -137,10 +165,14 @@ class FineTuning:
             n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
             n_checkpoints (int, optional): Number of checkpoints to save during fine-tuning.
                 Defaults to 1.
-            batch_size (int, optional): Batch size for fine-tuning. Defaults to max.
+            batch_size (int or "max"): Batch size for fine-tuning. Defaults to max.
             learning_rate (float, optional): Learning rate multiplier to use for training
                 Defaults to 0.00001.
+            min_lr_ratio (float, optional): Min learning rate ratio of the initial learning rate for
+                the learning rate scheduler. Defaults to 0.0.
             warmup_ratio (float, optional): Warmup ratio for learning rate scheduler.
+            max_grad_norm (float, optional): Max gradient norm. Defaults to 1.0, set to 0 to disable.
+            weight_decay (float, optional): Weight decay. Defaults to 0.0.
             lora (bool, optional): Whether to use LoRA adapters. Defaults to True.
             lora_r (int, optional): Rank of LoRA adapters. Defaults to 8.
             lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
@@ -154,6 +186,12 @@ class FineTuning:
                 Defaults to False.
             model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
                 Defaults to None.
+            train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
+                "auto" will automatically determine whether to mask the inputs based on the data format.
+                For datasets with the "text" field (general format), inputs will not be masked.
+                For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
+                (Instruction format), inputs will be masked.
+                Defaults to "auto".
 
         Returns:
             FinetuneResponse: Object containing information about fine-tuning job.
@@ -176,7 +214,10 @@ class FineTuning:
             n_checkpoints=n_checkpoints,
             batch_size=batch_size,
             learning_rate=learning_rate,
+            min_lr_ratio=min_lr_ratio,
             warmup_ratio=warmup_ratio,
+            max_grad_norm=max_grad_norm,
+            weight_decay=weight_decay,
             lora=lora,
             lora_r=lora_r,
             lora_dropout=lora_dropout,
@@ -184,6 +225,7 @@ class FineTuning:
             lora_trainable_modules=lora_trainable_modules,
             suffix=suffix,
             wandb_api_key=wandb_api_key,
+            train_on_inputs=train_on_inputs,
         )
 
         if verbose:
@@ -426,7 +468,10 @@ class AsyncFineTuning:
         n_checkpoints: int | None = 1,
         batch_size: int | Literal["max"] = "max",
         learning_rate: float | None = 0.00001,
-        warmup_ratio: float | None = 0.0,
+        min_lr_ratio: float = 0.0,
+        warmup_ratio: float = 0.0,
+        max_grad_norm: float = 1.0,
+        weight_decay: float = 0.0,
         lora: bool = False,
         lora_r: int | None = None,
         lora_dropout: float | None = 0,
@@ -436,6 +481,7 @@ class AsyncFineTuning:
         wandb_api_key: str | None = None,
         verbose: bool = False,
         model_limits: FinetuneTrainingLimits | None = None,
+        train_on_inputs: bool | Literal["auto"] = "auto",
     ) -> FinetuneResponse:
         """
         Async method to initiate a fine-tuning job
@@ -451,7 +497,11 @@ class AsyncFineTuning:
             batch_size (int, optional): Batch size for fine-tuning. Defaults to max.
             learning_rate (float, optional): Learning rate multiplier to use for training
                 Defaults to 0.00001.
+            min_lr_ratio (float, optional): Min learning rate ratio of the initial learning rate for
+                the learning rate scheduler. Defaults to 0.0.
             warmup_ratio (float, optional): Warmup ratio for learning rate scheduler.
+            max_grad_norm (float, optional): Max gradient norm. Defaults to 1.0, set to 0 to disable.
+            weight_decay (float, optional): Weight decay. Defaults to 0.0.
             lora (bool, optional): Whether to use LoRA adapters. Defaults to True.
             lora_r (int, optional): Rank of LoRA adapters. Defaults to 8.
             lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
@@ -465,6 +515,12 @@ class AsyncFineTuning:
                 Defaults to False.
             model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
                 Defaults to None.
+            train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
+                "auto" will automatically determine whether to mask the inputs based on the data format.
+                For datasets with the "text" field (general format), inputs will not be masked.
+                For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
+                (Instruction format), inputs will be masked.
+                Defaults to "auto".
 
         Returns:
             FinetuneResponse: Object containing information about fine-tuning job.
@@ -487,7 +543,10 @@ class AsyncFineTuning:
             n_checkpoints=n_checkpoints,
             batch_size=batch_size,
             learning_rate=learning_rate,
+            min_lr_ratio=min_lr_ratio,
             warmup_ratio=warmup_ratio,
+            max_grad_norm=max_grad_norm,
+            weight_decay=weight_decay,
             lora=lora,
             lora_r=lora_r,
             lora_dropout=lora_dropout,
@@ -495,6 +554,7 @@ class AsyncFineTuning:
             lora_trainable_modules=lora_trainable_modules,
             suffix=suffix,
             wandb_api_key=wandb_api_key,
+            train_on_inputs=train_on_inputs,
         )
 
         if verbose:
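A minimal sketch of how the new keyword arguments surface in the Python client (the file ID and model name are placeholders, and a configured TOGETHER_API_KEY is assumed):

    from together import Together

    client = Together()
    job = client.fine_tuning.create(
        training_file="file-xxxxxxxx",   # placeholder file ID
        model="<model-name>",            # placeholder model name
        min_lr_ratio=0.1,                # final LR as a fraction of the peak LR (linear scheduler)
        max_grad_norm=1.0,               # gradient clipping; 0 disables it
        weight_decay=0.01,
        train_on_inputs="auto",          # mask inputs based on the detected dataset format
    )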
--- together-1.3.3/src/together/types/__init__.py
+++ together-1.3.5/src/together/types/__init__.py
@@ -30,6 +30,8 @@ from together.types.finetune import (
     LoRATrainingType,
     TrainingType,
     FinetuneTrainingLimits,
+    FinetuneLRScheduler,
+    FinetuneLinearLRSchedulerArgs,
 )
 from together.types.images import (
     ImageRequest,
@@ -57,6 +59,8 @@ __all__ = [
     "FinetuneList",
     "FinetuneListEvents",
     "FinetuneDownloadResult",
+    "FinetuneLRScheduler",
+    "FinetuneLinearLRSchedulerArgs",
     "FileRequest",
     "FileResponse",
     "FileList",
--- together-1.3.3/src/together/types/finetune.py
+++ together-1.3.5/src/together/types/finetune.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 from enum import Enum
 from typing import List, Literal
 
-from pydantic import Field, validator, field_validator
+from pydantic import StrictBool, Field, validator, field_validator
 
 from together.types.abstract import BaseModel
 from together.types.common import (
@@ -150,8 +150,14 @@ class FinetuneRequest(BaseModel):
     n_epochs: int
     # training learning rate
     learning_rate: float
+    # learning rate scheduler type and args
+    lr_scheduler: FinetuneLRScheduler | None = None
     # learning rate warmup ratio
     warmup_ratio: float
+    # max gradient norm
+    max_grad_norm: float
+    # weight decay
+    weight_decay: float
     # number of checkpoints to save
     n_checkpoints: int | None = None
     # number of evaluation loops to run
@@ -163,6 +169,7 @@ class FinetuneRequest(BaseModel):
     # weights & biases api key
     wandb_key: str | None = None
     training_type: FullTrainingType | LoRATrainingType | None = None
+    train_on_inputs: StrictBool | Literal["auto"] = "auto"
 
 
 class FinetuneResponse(BaseModel):
@@ -192,8 +199,14 @@ class FinetuneResponse(BaseModel):
     batch_size: int | None = None
     # training learning rate
     learning_rate: float | None = None
+    # learning rate scheduler type and args
+    lr_scheduler: FinetuneLRScheduler | None = None
     # learning rate warmup ratio
     warmup_ratio: float | None = None
+    # max gradient norm
+    max_grad_norm: float | None = None
+    # weight decay
+    weight_decay: float | None = None
     # number of steps between evals
     eval_steps: int | None = None
     # training type
@@ -230,6 +243,7 @@ class FinetuneResponse(BaseModel):
     # training file metadata
     training_file_num_lines: int | None = Field(None, alias="TrainingFileNumLines")
     training_file_size: int | None = Field(None, alias="TrainingFileSize")
+    train_on_inputs: StrictBool | Literal["auto"] | None = "auto"
 
     @field_validator("training_type")
     @classmethod
@@ -285,3 +299,12 @@ class FinetuneTrainingLimits(BaseModel):
     min_learning_rate: float
     full_training: FinetuneFullTrainingLimits | None = None
     lora_training: FinetuneLoraTrainingLimits | None = None
+
+
+class FinetuneLRScheduler(BaseModel):
+    lr_scheduler_type: str
+    lr_scheduler_args: FinetuneLinearLRSchedulerArgs | None = None
+
+
+class FinetuneLinearLRSchedulerArgs(BaseModel):
+    min_lr_ratio: float | None = 0.0
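As a small sketch of how the new scheduler types compose (mirroring the way createFinetuneRequest wires them together above; the 0.1 value is illustrative):

    from together.types import FinetuneLRScheduler, FinetuneLinearLRSchedulerArgs

    scheduler = FinetuneLRScheduler(
        lr_scheduler_type="linear",
        lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=0.1),
    )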
--- /dev/null
+++ together-1.3.5/src/together/utils/files.py
@@ -0,0 +1,324 @@
+from __future__ import annotations
+
+import json
+import os
+from pathlib import Path
+from traceback import format_exc
+from typing import Any, Dict
+
+from pyarrow import ArrowInvalid, parquet
+
+from together.constants import (
+    MAX_FILE_SIZE_GB,
+    MIN_SAMPLES,
+    NUM_BYTES_IN_GB,
+    PARQUET_EXPECTED_COLUMNS,
+    JSONL_REQUIRED_COLUMNS_MAP,
+    REQUIRED_COLUMNS_MESSAGE,
+    POSSIBLE_ROLES_CONVERSATION,
+    DatasetFormat,
+)
+
+
+class InvalidFileFormatError(ValueError):
+    """Exception raised for invalid file formats during file checks."""
+
+    def __init__(
+        self,
+        message: str = "",
+        line_number: int | None = None,
+        error_source: str | None = None,
+    ) -> None:
+        super().__init__(message)
+        self.message = message
+        self.line_number = line_number
+        self.error_source = error_source
+
+
+def check_file(
+    file: Path | str,
+) -> Dict[str, Any]:
+    if not isinstance(file, Path):
+        file = Path(file)
+
+    report_dict = {
+        "is_check_passed": True,
+        "message": "Checks passed",
+        "found": None,
+        "file_size": None,
+        "utf8": None,
+        "line_type": None,
+        "text_field": None,
+        "key_value": None,
+        "has_min_samples": None,
+        "num_samples": None,
+        "load_json": None,
+    }
+
+    if not file.is_file():
+        report_dict["found"] = False
+        report_dict["is_check_passed"] = False
+        return report_dict
+    else:
+        report_dict["found"] = True
+
+    file_size = os.stat(file.as_posix()).st_size
+
+    if file_size > MAX_FILE_SIZE_GB * NUM_BYTES_IN_GB:
+        report_dict["message"] = (
+            f"Maximum supported file size is {MAX_FILE_SIZE_GB} GB. Found file with size of {round(file_size / NUM_BYTES_IN_GB ,3)} GB."
+        )
+        report_dict["is_check_passed"] = False
+    elif file_size == 0:
+        report_dict["message"] = "File is empty"
+        report_dict["file_size"] = 0
+        report_dict["is_check_passed"] = False
+        return report_dict
+    else:
+        report_dict["file_size"] = file_size
+
+    data_report_dict = {}
+    if file.suffix == ".jsonl":
+        report_dict["filetype"] = "jsonl"
+        data_report_dict = _check_jsonl(file)
+    elif file.suffix == ".parquet":
+        report_dict["filetype"] = "parquet"
+        data_report_dict = _check_parquet(file)
+    else:
+        report_dict["filetype"] = (
+            f"Unknown extension of file {file}. "
+            "Only files with extensions .jsonl and .parquet are supported."
+        )
+        report_dict["is_check_passed"] = False
+
+    report_dict.update(data_report_dict)
+
+    return report_dict
+
+
+def _check_jsonl(file: Path) -> Dict[str, Any]:
+    report_dict: Dict[str, Any] = {}
+    # Check that the file is UTF-8 encoded. If not report where the error occurs.
+    try:
+        with file.open(encoding="utf-8") as f:
+            f.read()
+        report_dict["utf8"] = True
+    except UnicodeDecodeError as e:
+        report_dict["utf8"] = False
+        report_dict["message"] = f"File is not UTF-8 encoded. Error raised: {e}."
+        report_dict["is_check_passed"] = False
+        return report_dict
+
+    dataset_format = None
+    with file.open() as f:
+        idx = -1
+        try:
+            for idx, line in enumerate(f):
+                json_line = json.loads(line)
+
+                if not isinstance(json_line, dict):
+                    raise InvalidFileFormatError(
+                        message=(
+                            f"Error parsing file. Invalid format on line {idx + 1} of the input file. "
+                            'Example of valid json: {"text": "my sample string"}. '
+                        ),
+                        line_number=idx + 1,
+                        error_source="line_type",
+                    )
+
+                current_format = None
+                for possible_format in JSONL_REQUIRED_COLUMNS_MAP:
+                    if all(
+                        column in json_line
+                        for column in JSONL_REQUIRED_COLUMNS_MAP[possible_format]
+                    ):
+                        if current_format is None:
+                            current_format = possible_format
+                        elif current_format != possible_format:
+                            raise InvalidFileFormatError(
+                                message="Found multiple dataset formats in the input file. "
+                                f"Got {current_format} and {possible_format} on line {idx + 1}.",
+                                line_number=idx + 1,
+                                error_source="format",
+                            )
+
+                if current_format is None:
+                    raise InvalidFileFormatError(
+                        message=(
+                            f"Error parsing file. Could not detect a format for the line {idx + 1} with the columns:\n"
+                            f"{json_line.keys()}"
+                        ),
+                        line_number=idx + 1,
+                        error_source="format",
+                    )
+
+                if current_format == DatasetFormat.CONVERSATION:
+                    message_column = JSONL_REQUIRED_COLUMNS_MAP[
+                        DatasetFormat.CONVERSATION
+                    ][0]
+                    if not isinstance(json_line[message_column], list):
+                        raise InvalidFileFormatError(
+                            message=f"Invalid format on line {idx + 1} of the input file. "
+                            f"Expected a list of messages. Found {type(json_line[message_column])}",
+                            line_number=idx + 1,
+                            error_source="key_value",
+                        )
+
+                    for turn_id, turn in enumerate(json_line[message_column]):
+                        if not isinstance(turn, dict):
+                            raise InvalidFileFormatError(
+                                message=f"Invalid format on line {idx + 1} of the input file. "
+                                f"Expected a dictionary in the {turn_id + 1} turn. Found {type(turn)}",
+                                line_number=idx + 1,
+                                error_source="key_value",
+                            )
+
+                    previous_role = None
+                    for turn in json_line[message_column]:
+                        for column in REQUIRED_COLUMNS_MESSAGE:
+                            if column not in turn:
+                                raise InvalidFileFormatError(
+                                    message=f"Field `{column}` is missing for a turn `{turn}` on line {idx + 1} "
+                                    "of the the input file.",
+                                    line_number=idx + 1,
+                                    error_source="key_value",
+                                )
+                            else:
+                                if not isinstance(turn[column], str):
+                                    raise InvalidFileFormatError(
+                                        message=f"Invalid format on line {idx + 1} in the column {column} for turn `{turn}` "
+                                        f"of the input file. Expected string. Found {type(turn[column])}",
+                                        line_number=idx + 1,
+                                        error_source="text_field",
+                                    )
+                        role = turn["role"]
+
+                        if role not in POSSIBLE_ROLES_CONVERSATION:
+                            raise InvalidFileFormatError(
+                                message=f"Found invalid role `{role}` in the messages on the line {idx + 1}. "
+                                f"Possible roles in the conversation are: {POSSIBLE_ROLES_CONVERSATION}",
+                                line_number=idx + 1,
+                                error_source="key_value",
+                            )
+
+                        if previous_role == role:
+                            raise InvalidFileFormatError(
+                                message=f"Invalid role turns on line {idx + 1} of the input file. "
+                                "`user` and `assistant` roles must alternate user/assistant/user/assistant/...",
+                                line_number=idx + 1,
+                                error_source="key_value",
+                            )
+
+                        previous_role = role
+
+                else:
+                    for column in JSONL_REQUIRED_COLUMNS_MAP[current_format]:
+                        if not isinstance(json_line[column], str):
+                            raise InvalidFileFormatError(
+                                message=f'Invalid value type for "{column}" key on line {idx + 1}. '
+                                f"Expected string. Found {type(json_line[column])}.",
+                                line_number=idx + 1,
+                                error_source="key_value",
+                            )
+
+                if dataset_format is None:
+                    dataset_format = current_format
+                elif current_format is not None:
+                    if current_format != dataset_format:
+                        raise InvalidFileFormatError(
+                            message="All samples in the dataset must have the same dataset format. "
+                            f"Got {dataset_format} for the first line and {current_format} "
+                            f"for the line {idx + 1}.",
+                            line_number=idx + 1,
+                            error_source="format",
+                        )
+
+            if idx + 1 < MIN_SAMPLES:
+                report_dict["has_min_samples"] = False
+                report_dict["message"] = (
+                    f"Processing {file} resulted in only {idx + 1} samples. "
+                    f"Our minimum is {MIN_SAMPLES} samples. "
+                )
+                report_dict["is_check_passed"] = False
+            else:
+                report_dict["num_samples"] = idx + 1
+                report_dict["has_min_samples"] = True
+                report_dict["is_check_passed"] = True
+
+            report_dict["load_json"] = True
+
+        except InvalidFileFormatError as e:
+            report_dict["load_json"] = False
+            report_dict["is_check_passed"] = False
+            report_dict["message"] = e.message
+            if e.line_number is not None:
+                report_dict["line_number"] = e.line_number
+            if e.error_source is not None:
+                report_dict[e.error_source] = False
+        except ValueError:
+            report_dict["load_json"] = False
+            if idx < 0:
+                report_dict["message"] = (
+                    "Unable to decode file. "
+                    "File may be empty or in an unsupported format. "
+                )
+            else:
+                report_dict["message"] = (
+                    f"Error parsing json payload. Unexpected format on line {idx + 1}."
+                )
+            report_dict["is_check_passed"] = False
+
+    if "text_field" not in report_dict:
+        report_dict["text_field"] = True
+    if "line_type" not in report_dict:
+        report_dict["line_type"] = True
+    if "key_value" not in report_dict:
+        report_dict["key_value"] = True
+    return report_dict
+
+
+def _check_parquet(file: Path) -> Dict[str, Any]:
+    report_dict: Dict[str, Any] = {}
+
+    try:
+        table = parquet.read_table(str(file), memory_map=True)
+    except ArrowInvalid:
+        report_dict["load_parquet"] = (
+            f"An exception has occurred when loading the Parquet file {file}. Please check the file for corruption. "
+            f"Exception trace:\n{format_exc()}"
+        )
+        report_dict["is_check_passed"] = False
+        return report_dict
+
+    column_names = table.schema.names
+    if "input_ids" not in column_names:
+        report_dict["load_parquet"] = (
+            f"Parquet file {file} does not contain the `input_ids` column."
+        )
+        report_dict["is_check_passed"] = False
+        return report_dict
+
+    for column_name in column_names:
+        if column_name not in PARQUET_EXPECTED_COLUMNS:
+            report_dict["load_parquet"] = (
+                f"Parquet file {file} contains an unexpected column {column_name}. "
+                f"Only columns {PARQUET_EXPECTED_COLUMNS} are supported."
+            )
+            report_dict["is_check_passed"] = False
+            return report_dict
+
+    num_samples = len(table)
+    if num_samples < MIN_SAMPLES:
+        report_dict["has_min_samples"] = False
+        report_dict["message"] = (
+            f"Processing {file} resulted in only {num_samples} samples. "
+            f"Our minimum is {MIN_SAMPLES} samples. "
+        )
+        report_dict["is_check_passed"] = False
+        return report_dict
+    else:
+        report_dict["num_samples"] = num_samples
+
+    report_dict["is_check_passed"] = True
+
+    return report_dict
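As a usage sketch (the path below is a placeholder), the expanded validator can be called directly and reports which check failed:

    from together.utils.files import check_file

    report = check_file("data.jsonl")
    if not report["is_check_passed"]:
        print(report["message"])  # explains which line and which check failed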
--- together-1.3.3/src/together/utils/files.py
+++ /dev/null
@@ -1,204 +0,0 @@
-from __future__ import annotations
-
-import json
-import os
-from pathlib import Path
-from traceback import format_exc
-from typing import Any, Dict
-
-from pyarrow import ArrowInvalid, parquet
-
-from together.constants import (
-    MAX_FILE_SIZE_GB,
-    MIN_SAMPLES,
-    NUM_BYTES_IN_GB,
-    PARQUET_EXPECTED_COLUMNS,
-)
-
-
-def check_file(
-    file: Path | str,
-) -> Dict[str, Any]:
-    if not isinstance(file, Path):
-        file = Path(file)
-
-    report_dict = {
-        "is_check_passed": True,
-        "message": "Checks passed",
-        "found": None,
-        "file_size": None,
-        "utf8": None,
-        "line_type": None,
-        "text_field": None,
-        "key_value": None,
-        "min_samples": None,
-        "num_samples": None,
-        "load_json": None,
-    }
-
-    if not file.is_file():
-        report_dict["found"] = False
-        report_dict["is_check_passed"] = False
-        return report_dict
-    else:
-        report_dict["found"] = True
-
-    file_size = os.stat(file.as_posix()).st_size
-
-    if file_size > MAX_FILE_SIZE_GB * NUM_BYTES_IN_GB:
-        report_dict["message"] = (
-            f"Maximum supported file size is {MAX_FILE_SIZE_GB} GB. Found file with size of {round(file_size / NUM_BYTES_IN_GB ,3)} GB."
-        )
-        report_dict["is_check_passed"] = False
-    elif file_size == 0:
-        report_dict["message"] = "File is empty"
-        report_dict["file_size"] = 0
-        report_dict["is_check_passed"] = False
-        return report_dict
-    else:
-        report_dict["file_size"] = file_size
-
-    if file.suffix == ".jsonl":
-        report_dict["filetype"] = "jsonl"
-        data_report_dict = _check_jsonl(file)
-    elif file.suffix == ".parquet":
-        report_dict["filetype"] = "parquet"
-        data_report_dict = _check_parquet(file)
-    else:
-        report_dict["filetype"] = (
-            f"Unknown extension of file {file}. "
-            "Only files with extensions .jsonl and .parquet are supported."
-        )
-        report_dict["is_check_passed"] = False
-
-    report_dict.update(data_report_dict)
-    return report_dict
-
-
-def _check_jsonl(file: Path) -> Dict[str, Any]:
-    report_dict: Dict[str, Any] = {}
-    # Check that the file is UTF-8 encoded. If not report where the error occurs.
-    try:
-        with file.open(encoding="utf-8") as f:
-            f.read()
-        report_dict["utf8"] = True
-    except UnicodeDecodeError as e:
-        report_dict["utf8"] = False
-        report_dict["message"] = f"File is not UTF-8 encoded. Error raised: {e}."
-        report_dict["is_check_passed"] = False
-        return report_dict
-
-    with file.open() as f:
-        # idx must be instantiated so decode errors (e.g. file is a tar) or empty files are caught
-        idx = -1
-        try:
-            for idx, line in enumerate(f):
-                json_line = json.loads(line)  # each line in jsonlines should be a json
-
-                if not isinstance(json_line, dict):
-                    report_dict["line_type"] = False
-                    report_dict["message"] = (
-                        f"Error parsing file. Invalid format on line {idx + 1} of the input file. "
-                        'Example of valid json: {"text": "my sample string"}. '
-                    )
-
-                    report_dict["is_check_passed"] = False
-
-                if "text" not in json_line.keys():
-                    report_dict["text_field"] = False
-                    report_dict["message"] = (
-                        f"Missing 'text' field was found on line {idx + 1} of the the input file. "
-                        "Expected format: {'text': 'my sample string'}. "
-                    )
-                    report_dict["is_check_passed"] = False
-                else:
-                    # check to make sure the value of the "text" key is a string
-                    if not isinstance(json_line["text"], str):
-                        report_dict["key_value"] = False
-                        report_dict["message"] = (
-                            f'Invalid value type for "text" key on line {idx + 1}. '
-                            f'Expected string. Found {type(json_line["text"])}.'
-                        )
-
-                        report_dict["is_check_passed"] = False
-
-            # make sure this is outside the for idx, line in enumerate(f): for loop
-            if idx + 1 < MIN_SAMPLES:
-                report_dict["min_samples"] = False
-                report_dict["message"] = (
-                    f"Processing {file} resulted in only {idx + 1} samples. "
-                    f"Our minimum is {MIN_SAMPLES} samples. "
-                )
-                report_dict["is_check_passed"] = False
-            else:
-                report_dict["num_samples"] = idx + 1
-                report_dict["min_samples"] = True
-
-            report_dict["load_json"] = True
-
-        except ValueError:
-            report_dict["load_json"] = False
-            if idx < 0:
-                report_dict["message"] = (
-                    "Unable to decode file. "
-                    "File may be empty or in an unsupported format. "
-                )
-            else:
-                report_dict["message"] = (
-                    f"Error parsing json payload. Unexpected format on line {idx + 1}."
-                )
-            report_dict["is_check_passed"] = False
-
-    if "text_field" not in report_dict:
-        report_dict["text_field"] = True
-    if "line_type" not in report_dict:
-        report_dict["line_type"] = True
-    if "key_value" not in report_dict:
-        report_dict["key_value"] = True
-    return report_dict
-
-
-def _check_parquet(file: Path) -> Dict[str, Any]:
-    report_dict: Dict[str, Any] = {}
-
-    try:
-        table = parquet.read_table(str(file), memory_map=True)
-    except ArrowInvalid:
-        report_dict["load_parquet"] = (
-            f"An exception has occurred when loading the Parquet file {file}. Please check the file for corruption. "
-            f"Exception trace:\n{format_exc()}"
-        )
-        report_dict["is_check_passed"] = False
-        return report_dict
-
-    column_names = table.schema.names
-    if "input_ids" not in column_names:
-        report_dict["load_parquet"] = (
-            f"Parquet file {file} does not contain the `input_ids` column."
-        )
-        report_dict["is_check_passed"] = False
-        return report_dict
-
-    for column_name in column_names:
-        if column_name not in PARQUET_EXPECTED_COLUMNS:
-            report_dict["load_parquet"] = (
-                f"Parquet file {file} contains an unexpected column {column_name}. "
-                f"Only columns {PARQUET_EXPECTED_COLUMNS} are supported."
-            )
-            report_dict["is_check_passed"] = False
-            return report_dict
-
-    num_samples = len(table)
-    if num_samples < MIN_SAMPLES:
-        report_dict["min_samples"] = (
-            f"Processing {file} resulted in only {num_samples} samples. "
-            f"Our minimum is {MIN_SAMPLES} samples. "
-        )
-        report_dict["is_check_passed"] = False
-        return report_dict
-    else:
-        report_dict["num_samples"] = num_samples
-
-    report_dict["is_check_passed"] = True
-
-    return report_dict