together-1.3.2.tar.gz → together-1.3.4.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. {together-1.3.2 → together-1.3.4}/PKG-INFO +2 -2
  2. {together-1.3.2 → together-1.3.4}/pyproject.toml +3 -3
  3. {together-1.3.2 → together-1.3.4}/src/together/cli/api/finetune.py +21 -18
  4. together-1.3.4/src/together/cli/api/utils.py +51 -0
  5. {together-1.3.2 → together-1.3.4}/src/together/constants.py +19 -0
  6. {together-1.3.2 → together-1.3.4}/src/together/resources/finetune.py +19 -1
  7. {together-1.3.2 → together-1.3.4}/src/together/types/finetune.py +3 -1
  8. together-1.3.4/src/together/utils/files.py +324 -0
  9. together-1.3.2/src/together/cli/api/utils.py +0 -23
  10. together-1.3.2/src/together/utils/files.py +0 -204
  11. {together-1.3.2 → together-1.3.4}/LICENSE +0 -0
  12. {together-1.3.2 → together-1.3.4}/README.md +0 -0
  13. {together-1.3.2 → together-1.3.4}/src/together/__init__.py +0 -0
  14. {together-1.3.2 → together-1.3.4}/src/together/abstract/__init__.py +0 -0
  15. {together-1.3.2 → together-1.3.4}/src/together/abstract/api_requestor.py +0 -0
  16. {together-1.3.2 → together-1.3.4}/src/together/cli/__init__.py +0 -0
  17. {together-1.3.2 → together-1.3.4}/src/together/cli/api/__init__.py +0 -0
  18. {together-1.3.2 → together-1.3.4}/src/together/cli/api/chat.py +0 -0
  19. {together-1.3.2 → together-1.3.4}/src/together/cli/api/completions.py +0 -0
  20. {together-1.3.2 → together-1.3.4}/src/together/cli/api/files.py +0 -0
  21. {together-1.3.2 → together-1.3.4}/src/together/cli/api/images.py +0 -0
  22. {together-1.3.2 → together-1.3.4}/src/together/cli/api/models.py +0 -0
  23. {together-1.3.2 → together-1.3.4}/src/together/cli/cli.py +0 -0
  24. {together-1.3.2 → together-1.3.4}/src/together/client.py +0 -0
  25. {together-1.3.2 → together-1.3.4}/src/together/error.py +0 -0
  26. {together-1.3.2 → together-1.3.4}/src/together/filemanager.py +0 -0
  27. {together-1.3.2 → together-1.3.4}/src/together/legacy/__init__.py +0 -0
  28. {together-1.3.2 → together-1.3.4}/src/together/legacy/base.py +0 -0
  29. {together-1.3.2 → together-1.3.4}/src/together/legacy/complete.py +0 -0
  30. {together-1.3.2 → together-1.3.4}/src/together/legacy/embeddings.py +0 -0
  31. {together-1.3.2 → together-1.3.4}/src/together/legacy/files.py +0 -0
  32. {together-1.3.2 → together-1.3.4}/src/together/legacy/finetune.py +0 -0
  33. {together-1.3.2 → together-1.3.4}/src/together/legacy/images.py +0 -0
  34. {together-1.3.2 → together-1.3.4}/src/together/legacy/models.py +0 -0
  35. {together-1.3.2 → together-1.3.4}/src/together/resources/__init__.py +0 -0
  36. {together-1.3.2 → together-1.3.4}/src/together/resources/chat/__init__.py +0 -0
  37. {together-1.3.2 → together-1.3.4}/src/together/resources/chat/completions.py +0 -0
  38. {together-1.3.2 → together-1.3.4}/src/together/resources/completions.py +0 -0
  39. {together-1.3.2 → together-1.3.4}/src/together/resources/embeddings.py +0 -0
  40. {together-1.3.2 → together-1.3.4}/src/together/resources/files.py +0 -0
  41. {together-1.3.2 → together-1.3.4}/src/together/resources/images.py +0 -0
  42. {together-1.3.2 → together-1.3.4}/src/together/resources/models.py +0 -0
  43. {together-1.3.2 → together-1.3.4}/src/together/resources/rerank.py +0 -0
  44. {together-1.3.2 → together-1.3.4}/src/together/together_response.py +0 -0
  45. {together-1.3.2 → together-1.3.4}/src/together/types/__init__.py +0 -0
  46. {together-1.3.2 → together-1.3.4}/src/together/types/abstract.py +0 -0
  47. {together-1.3.2 → together-1.3.4}/src/together/types/chat_completions.py +0 -0
  48. {together-1.3.2 → together-1.3.4}/src/together/types/common.py +0 -0
  49. {together-1.3.2 → together-1.3.4}/src/together/types/completions.py +0 -0
  50. {together-1.3.2 → together-1.3.4}/src/together/types/embeddings.py +0 -0
  51. {together-1.3.2 → together-1.3.4}/src/together/types/error.py +0 -0
  52. {together-1.3.2 → together-1.3.4}/src/together/types/files.py +0 -0
  53. {together-1.3.2 → together-1.3.4}/src/together/types/images.py +0 -0
  54. {together-1.3.2 → together-1.3.4}/src/together/types/models.py +0 -0
  55. {together-1.3.2 → together-1.3.4}/src/together/types/rerank.py +0 -0
  56. {together-1.3.2 → together-1.3.4}/src/together/utils/__init__.py +0 -0
  57. {together-1.3.2 → together-1.3.4}/src/together/utils/_log.py +0 -0
  58. {together-1.3.2 → together-1.3.4}/src/together/utils/api_helpers.py +0 -0
  59. {together-1.3.2 → together-1.3.4}/src/together/utils/tools.py +0 -0
  60. {together-1.3.2 → together-1.3.4}/src/together/version.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: together
3
- Version: 1.3.2
3
+ Version: 1.3.4
4
4
  Summary: Python client for Together's Cloud Platform!
5
5
  Home-page: https://github.com/togethercomputer/together-python
6
6
  License: Apache-2.0
@@ -29,7 +29,7 @@ Requires-Dist: requests (>=2.31.0,<3.0.0)
29
29
  Requires-Dist: rich (>=13.8.1,<14.0.0)
30
30
  Requires-Dist: tabulate (>=0.9.0,<0.10.0)
31
31
  Requires-Dist: tqdm (>=4.66.2,<5.0.0)
32
- Requires-Dist: typer (>=0.9,<0.13)
32
+ Requires-Dist: typer (>=0.9,<0.14)
33
33
  Project-URL: Bug Tracker, https://github.com/togethercomputer/together-python/issues
34
34
  Project-URL: Repository, https://github.com/togethercomputer/together-python
35
35
  Description-Content-Type: text/markdown
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
12
12
 
13
13
  [tool.poetry]
14
14
  name = "together"
15
- version = "1.3.2"
15
+ version = "1.3.4"
16
16
  authors = [
17
17
  "Together AI <support@together.ai>"
18
18
  ]
@@ -29,7 +29,7 @@ homepage = "https://github.com/togethercomputer/together-python"
29
29
 
30
30
  [tool.poetry.dependencies]
31
31
  python = "^3.8"
32
- typer = ">=0.9,<0.13"
32
+ typer = ">=0.9,<0.14"
33
33
  requests = "^2.31.0"
34
34
  rich = "^13.8.1"
35
35
  tqdm = "^4.66.2"
@@ -51,7 +51,7 @@ optional = true
51
51
 
52
52
  [tool.poetry.group.quality.dependencies]
53
53
  black = ">=23.1,<25.0"
54
- ruff = ">=0.3.2,<0.7.0"
54
+ ruff = ">=0.3.2,<0.8.0"
55
55
  types-tqdm = "^4.65.0.0"
56
56
  types-tabulate = "^0.9.0.3"
57
57
  pre-commit = "3.5.0"
@@ -11,8 +11,13 @@ from rich import print as rprint
11
11
  from tabulate import tabulate
12
12
 
13
13
  from together import Together
14
- from together.cli.api.utils import INT_WITH_MAX
15
- from together.utils import finetune_price_to_dollars, log_warn, parse_timestamp
14
+ from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX
15
+ from together.utils import (
16
+ finetune_price_to_dollars,
17
+ log_warn,
18
+ log_warn_once,
19
+ parse_timestamp,
20
+ )
16
21
  from together.types.finetune import DownloadCheckpointType, FinetuneTrainingLimits
17
22
 
18
23
 
@@ -93,6 +98,13 @@ def fine_tuning(ctx: click.Context) -> None:
93
98
  default=False,
94
99
  help="Whether to skip the launch confirmation message",
95
100
  )
101
+ @click.option(
102
+ "--train-on-inputs",
103
+ type=BOOL_WITH_AUTO,
104
+ default="auto",
105
+ help="Whether to mask the user messages in conversational data or prompts in instruction data. "
106
+ "`auto` will automatically determine whether to mask the inputs based on the data format.",
107
+ )
96
108
  def create(
97
109
  ctx: click.Context,
98
110
  training_file: str,
@@ -112,6 +124,7 @@ def create(
112
124
  suffix: str,
113
125
  wandb_api_key: str,
114
126
  confirm: bool,
127
+ train_on_inputs: bool | Literal["auto"],
115
128
  ) -> None:
116
129
  """Start fine-tuning"""
117
130
  client: Together = ctx.obj
@@ -133,6 +146,7 @@ def create(
133
146
  lora_trainable_modules=lora_trainable_modules,
134
147
  suffix=suffix,
135
148
  wandb_api_key=wandb_api_key,
149
+ train_on_inputs=train_on_inputs,
136
150
  )
137
151
 
138
152
  model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits(
@@ -150,6 +164,10 @@ def create(
150
164
  "batch_size": model_limits.lora_training.max_batch_size,
151
165
  "learning_rate": 1e-3,
152
166
  }
167
+ log_warn_once(
168
+ f"The default LoRA rank for {model} has been changed to {default_values['lora_r']} as the max available.\n"
169
+ f"Also, the default learning rate for LoRA fine-tuning has been changed to {default_values['learning_rate']}."
170
+ )
153
171
  for arg in default_values:
154
172
  arg_source = ctx.get_parameter_source("arg") # type: ignore[attr-defined]
155
173
  if arg_source == ParameterSource.DEFAULT:
@@ -186,22 +204,7 @@ def create(
186
204
 
187
205
  if confirm or click.confirm(_CONFIRMATION_MESSAGE, default=True, show_default=True):
188
206
  response = client.fine_tuning.create(
189
- training_file=training_file,
190
- model=model,
191
- n_epochs=n_epochs,
192
- validation_file=validation_file,
193
- n_evals=n_evals,
194
- n_checkpoints=n_checkpoints,
195
- batch_size=batch_size,
196
- learning_rate=learning_rate,
197
- warmup_ratio=warmup_ratio,
198
- lora=lora,
199
- lora_r=lora_r,
200
- lora_dropout=lora_dropout,
201
- lora_alpha=lora_alpha,
202
- lora_trainable_modules=lora_trainable_modules,
203
- suffix=suffix,
204
- wandb_api_key=wandb_api_key,
207
+ **training_args,
205
208
  verbose=True,
206
209
  )
207
210
 
@@ -0,0 +1,51 @@
1
+ from __future__ import annotations
2
+
3
+ from gettext import gettext as _
4
+ from typing import Literal
5
+
6
+ import click
7
+
8
+
9
+ class AutoIntParamType(click.ParamType):
10
+ name = "integer_or_max"
11
+ _number_class = int
12
+
13
+ def convert(
14
+ self, value: str, param: click.Parameter | None, ctx: click.Context | None
15
+ ) -> int | Literal["max"] | None:
16
+ if value == "max":
17
+ return "max"
18
+ try:
19
+ return int(value)
20
+ except ValueError:
21
+ self.fail(
22
+ _("{value!r} is not a valid {number_type}.").format(
23
+ value=value, number_type=self.name
24
+ ),
25
+ param,
26
+ ctx,
27
+ )
28
+
29
+
30
+ class BooleanWithAutoParamType(click.ParamType):
31
+ name = "boolean_or_auto"
32
+
33
+ def convert(
34
+ self, value: str, param: click.Parameter | None, ctx: click.Context | None
35
+ ) -> bool | Literal["auto"] | None:
36
+ if value == "auto":
37
+ return "auto"
38
+ try:
39
+ return bool(value)
40
+ except ValueError:
41
+ self.fail(
42
+ _("{value!r} is not a valid {type}.").format(
43
+ value=value, type=self.name
44
+ ),
45
+ param,
46
+ ctx,
47
+ )
48
+
49
+
50
+ INT_WITH_MAX = AutoIntParamType()
51
+ BOOL_WITH_AUTO = BooleanWithAutoParamType()
@@ -1,3 +1,5 @@
1
+ import enum
2
+
1
3
  # Session constants
2
4
  TIMEOUT_SECS = 600
3
5
  MAX_SESSION_LIFETIME_SECS = 180
@@ -29,3 +31,20 @@ MAX_FILE_SIZE_GB = 4.9
29
31
 
30
32
  # expected columns for Parquet files
31
33
  PARQUET_EXPECTED_COLUMNS = ["input_ids", "attention_mask", "labels"]
34
+
35
+
36
+ class DatasetFormat(enum.Enum):
37
+ """Dataset format enum."""
38
+
39
+ GENERAL = "general"
40
+ CONVERSATION = "conversation"
41
+ INSTRUCTION = "instruction"
42
+
43
+
44
+ JSONL_REQUIRED_COLUMNS_MAP = {
45
+ DatasetFormat.GENERAL: ["text"],
46
+ DatasetFormat.CONVERSATION: ["messages"],
47
+ DatasetFormat.INSTRUCTION: ["prompt", "completion"],
48
+ }
49
+ REQUIRED_COLUMNS_MESSAGE = ["role", "content"]
50
+ POSSIBLE_ROLES_CONVERSATION = ["system", "user", "assistant"]
@@ -43,6 +43,7 @@ def createFinetuneRequest(
43
43
  lora_trainable_modules: str | None = "all-linear",
44
44
  suffix: str | None = None,
45
45
  wandb_api_key: str | None = None,
46
+ train_on_inputs: bool | Literal["auto"] = "auto",
46
47
  ) -> FinetuneRequest:
47
48
  if batch_size == "max":
48
49
  log_warn_once(
@@ -95,6 +96,7 @@ def createFinetuneRequest(
95
96
  training_type=training_type,
96
97
  suffix=suffix,
97
98
  wandb_key=wandb_api_key,
99
+ train_on_inputs=train_on_inputs,
98
100
  )
99
101
 
100
102
  return finetune_request
@@ -125,6 +127,7 @@ class FineTuning:
125
127
  wandb_api_key: str | None = None,
126
128
  verbose: bool = False,
127
129
  model_limits: FinetuneTrainingLimits | None = None,
130
+ train_on_inputs: bool | Literal["auto"] = "auto",
128
131
  ) -> FinetuneResponse:
129
132
  """
130
133
  Method to initiate a fine-tuning job
@@ -137,7 +140,7 @@ class FineTuning:
137
140
  n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
138
141
  n_checkpoints (int, optional): Number of checkpoints to save during fine-tuning.
139
142
  Defaults to 1.
140
- batch_size (int, optional): Batch size for fine-tuning. Defaults to max.
143
+ batch_size (int or "max"): Batch size for fine-tuning. Defaults to max.
141
144
  learning_rate (float, optional): Learning rate multiplier to use for training
142
145
  Defaults to 0.00001.
143
146
  warmup_ratio (float, optional): Warmup ratio for learning rate scheduler.
@@ -154,6 +157,12 @@ class FineTuning:
154
157
  Defaults to False.
155
158
  model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
156
159
  Defaults to None.
160
+ train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
161
+ "auto" will automatically determine whether to mask the inputs based on the data format.
162
+ For datasets with the "text" field (general format), inputs will not be masked.
163
+ For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
164
+ (Instruction format), inputs will be masked.
165
+ Defaults to "auto".
157
166
 
158
167
  Returns:
159
168
  FinetuneResponse: Object containing information about fine-tuning job.
@@ -184,6 +193,7 @@ class FineTuning:
184
193
  lora_trainable_modules=lora_trainable_modules,
185
194
  suffix=suffix,
186
195
  wandb_api_key=wandb_api_key,
196
+ train_on_inputs=train_on_inputs,
187
197
  )
188
198
 
189
199
  if verbose:
@@ -436,6 +446,7 @@ class AsyncFineTuning:
436
446
  wandb_api_key: str | None = None,
437
447
  verbose: bool = False,
438
448
  model_limits: FinetuneTrainingLimits | None = None,
449
+ train_on_inputs: bool | Literal["auto"] = "auto",
439
450
  ) -> FinetuneResponse:
440
451
  """
441
452
  Async method to initiate a fine-tuning job
@@ -465,6 +476,12 @@ class AsyncFineTuning:
465
476
  Defaults to False.
466
477
  model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
467
478
  Defaults to None.
479
+ train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
480
+ "auto" will automatically determine whether to mask the inputs based on the data format.
481
+ For datasets with the "text" field (general format), inputs will not be masked.
482
+ For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields
483
+ (Instruction format), inputs will be masked.
484
+ Defaults to "auto".
468
485
 
469
486
  Returns:
470
487
  FinetuneResponse: Object containing information about fine-tuning job.
@@ -495,6 +512,7 @@ class AsyncFineTuning:
495
512
  lora_trainable_modules=lora_trainable_modules,
496
513
  suffix=suffix,
497
514
  wandb_api_key=wandb_api_key,
515
+ train_on_inputs=train_on_inputs,
498
516
  )
499
517
 
500
518
  if verbose:
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  from enum import Enum
4
4
  from typing import List, Literal
5
5
 
6
- from pydantic import Field, validator, field_validator
6
+ from pydantic import StrictBool, Field, validator, field_validator
7
7
 
8
8
  from together.types.abstract import BaseModel
9
9
  from together.types.common import (
@@ -163,6 +163,7 @@ class FinetuneRequest(BaseModel):
163
163
  # weights & biases api key
164
164
  wandb_key: str | None = None
165
165
  training_type: FullTrainingType | LoRATrainingType | None = None
166
+ train_on_inputs: StrictBool | Literal["auto"] = "auto"
166
167
 
167
168
 
168
169
  class FinetuneResponse(BaseModel):
@@ -230,6 +231,7 @@ class FinetuneResponse(BaseModel):
230
231
  # training file metadata
231
232
  training_file_num_lines: int | None = Field(None, alias="TrainingFileNumLines")
232
233
  training_file_size: int | None = Field(None, alias="TrainingFileSize")
234
+ train_on_inputs: StrictBool | Literal["auto"] | None = "auto"
233
235
 
234
236
  @field_validator("training_type")
235
237
  @classmethod
@@ -0,0 +1,324 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ from pathlib import Path
6
+ from traceback import format_exc
7
+ from typing import Any, Dict
8
+
9
+ from pyarrow import ArrowInvalid, parquet
10
+
11
+ from together.constants import (
12
+ MAX_FILE_SIZE_GB,
13
+ MIN_SAMPLES,
14
+ NUM_BYTES_IN_GB,
15
+ PARQUET_EXPECTED_COLUMNS,
16
+ JSONL_REQUIRED_COLUMNS_MAP,
17
+ REQUIRED_COLUMNS_MESSAGE,
18
+ POSSIBLE_ROLES_CONVERSATION,
19
+ DatasetFormat,
20
+ )
21
+
22
+
23
+ class InvalidFileFormatError(ValueError):
24
+ """Exception raised for invalid file formats during file checks."""
25
+
26
+ def __init__(
27
+ self,
28
+ message: str = "",
29
+ line_number: int | None = None,
30
+ error_source: str | None = None,
31
+ ) -> None:
32
+ super().__init__(message)
33
+ self.message = message
34
+ self.line_number = line_number
35
+ self.error_source = error_source
36
+
37
+
38
+ def check_file(
39
+ file: Path | str,
40
+ ) -> Dict[str, Any]:
41
+ if not isinstance(file, Path):
42
+ file = Path(file)
43
+
44
+ report_dict = {
45
+ "is_check_passed": True,
46
+ "message": "Checks passed",
47
+ "found": None,
48
+ "file_size": None,
49
+ "utf8": None,
50
+ "line_type": None,
51
+ "text_field": None,
52
+ "key_value": None,
53
+ "has_min_samples": None,
54
+ "num_samples": None,
55
+ "load_json": None,
56
+ }
57
+
58
+ if not file.is_file():
59
+ report_dict["found"] = False
60
+ report_dict["is_check_passed"] = False
61
+ return report_dict
62
+ else:
63
+ report_dict["found"] = True
64
+
65
+ file_size = os.stat(file.as_posix()).st_size
66
+
67
+ if file_size > MAX_FILE_SIZE_GB * NUM_BYTES_IN_GB:
68
+ report_dict["message"] = (
69
+ f"Maximum supported file size is {MAX_FILE_SIZE_GB} GB. Found file with size of {round(file_size / NUM_BYTES_IN_GB ,3)} GB."
70
+ )
71
+ report_dict["is_check_passed"] = False
72
+ elif file_size == 0:
73
+ report_dict["message"] = "File is empty"
74
+ report_dict["file_size"] = 0
75
+ report_dict["is_check_passed"] = False
76
+ return report_dict
77
+ else:
78
+ report_dict["file_size"] = file_size
79
+
80
+ data_report_dict = {}
81
+ if file.suffix == ".jsonl":
82
+ report_dict["filetype"] = "jsonl"
83
+ data_report_dict = _check_jsonl(file)
84
+ elif file.suffix == ".parquet":
85
+ report_dict["filetype"] = "parquet"
86
+ data_report_dict = _check_parquet(file)
87
+ else:
88
+ report_dict["filetype"] = (
89
+ f"Unknown extension of file {file}. "
90
+ "Only files with extensions .jsonl and .parquet are supported."
91
+ )
92
+ report_dict["is_check_passed"] = False
93
+
94
+ report_dict.update(data_report_dict)
95
+
96
+ return report_dict
97
+
98
+
99
+ def _check_jsonl(file: Path) -> Dict[str, Any]:
100
+ report_dict: Dict[str, Any] = {}
101
+ # Check that the file is UTF-8 encoded. If not report where the error occurs.
102
+ try:
103
+ with file.open(encoding="utf-8") as f:
104
+ f.read()
105
+ report_dict["utf8"] = True
106
+ except UnicodeDecodeError as e:
107
+ report_dict["utf8"] = False
108
+ report_dict["message"] = f"File is not UTF-8 encoded. Error raised: {e}."
109
+ report_dict["is_check_passed"] = False
110
+ return report_dict
111
+
112
+ dataset_format = None
113
+ with file.open() as f:
114
+ idx = -1
115
+ try:
116
+ for idx, line in enumerate(f):
117
+ json_line = json.loads(line)
118
+
119
+ if not isinstance(json_line, dict):
120
+ raise InvalidFileFormatError(
121
+ message=(
122
+ f"Error parsing file. Invalid format on line {idx + 1} of the input file. "
123
+ 'Example of valid json: {"text": "my sample string"}. '
124
+ ),
125
+ line_number=idx + 1,
126
+ error_source="line_type",
127
+ )
128
+
129
+ current_format = None
130
+ for possible_format in JSONL_REQUIRED_COLUMNS_MAP:
131
+ if all(
132
+ column in json_line
133
+ for column in JSONL_REQUIRED_COLUMNS_MAP[possible_format]
134
+ ):
135
+ if current_format is None:
136
+ current_format = possible_format
137
+ elif current_format != possible_format:
138
+ raise InvalidFileFormatError(
139
+ message="Found multiple dataset formats in the input file. "
140
+ f"Got {current_format} and {possible_format} on line {idx + 1}.",
141
+ line_number=idx + 1,
142
+ error_source="format",
143
+ )
144
+
145
+ if current_format is None:
146
+ raise InvalidFileFormatError(
147
+ message=(
148
+ f"Error parsing file. Could not detect a format for the line {idx + 1} with the columns:\n"
149
+ f"{json_line.keys()}"
150
+ ),
151
+ line_number=idx + 1,
152
+ error_source="format",
153
+ )
154
+
155
+ if current_format == DatasetFormat.CONVERSATION:
156
+ message_column = JSONL_REQUIRED_COLUMNS_MAP[
157
+ DatasetFormat.CONVERSATION
158
+ ][0]
159
+ if not isinstance(json_line[message_column], list):
160
+ raise InvalidFileFormatError(
161
+ message=f"Invalid format on line {idx + 1} of the input file. "
162
+ f"Expected a list of messages. Found {type(json_line[message_column])}",
163
+ line_number=idx + 1,
164
+ error_source="key_value",
165
+ )
166
+
167
+ for turn_id, turn in enumerate(json_line[message_column]):
168
+ if not isinstance(turn, dict):
169
+ raise InvalidFileFormatError(
170
+ message=f"Invalid format on line {idx + 1} of the input file. "
171
+ f"Expected a dictionary in the {turn_id + 1} turn. Found {type(turn)}",
172
+ line_number=idx + 1,
173
+ error_source="key_value",
174
+ )
175
+
176
+ previous_role = None
177
+ for turn in json_line[message_column]:
178
+ for column in REQUIRED_COLUMNS_MESSAGE:
179
+ if column not in turn:
180
+ raise InvalidFileFormatError(
181
+ message=f"Field `{column}` is missing for a turn `{turn}` on line {idx + 1} "
182
+ "of the the input file.",
183
+ line_number=idx + 1,
184
+ error_source="key_value",
185
+ )
186
+ else:
187
+ if not isinstance(turn[column], str):
188
+ raise InvalidFileFormatError(
189
+ message=f"Invalid format on line {idx + 1} in the column {column} for turn `{turn}` "
190
+ f"of the input file. Expected string. Found {type(turn[column])}",
191
+ line_number=idx + 1,
192
+ error_source="text_field",
193
+ )
194
+ role = turn["role"]
195
+
196
+ if role not in POSSIBLE_ROLES_CONVERSATION:
197
+ raise InvalidFileFormatError(
198
+ message=f"Found invalid role `{role}` in the messages on the line {idx + 1}. "
199
+ f"Possible roles in the conversation are: {POSSIBLE_ROLES_CONVERSATION}",
200
+ line_number=idx + 1,
201
+ error_source="key_value",
202
+ )
203
+
204
+ if previous_role == role:
205
+ raise InvalidFileFormatError(
206
+ message=f"Invalid role turns on line {idx + 1} of the input file. "
207
+ "`user` and `assistant` roles must alternate user/assistant/user/assistant/...",
208
+ line_number=idx + 1,
209
+ error_source="key_value",
210
+ )
211
+
212
+ previous_role = role
213
+
214
+ else:
215
+ for column in JSONL_REQUIRED_COLUMNS_MAP[current_format]:
216
+ if not isinstance(json_line[column], str):
217
+ raise InvalidFileFormatError(
218
+ message=f'Invalid value type for "{column}" key on line {idx + 1}. '
219
+ f"Expected string. Found {type(json_line[column])}.",
220
+ line_number=idx + 1,
221
+ error_source="key_value",
222
+ )
223
+
224
+ if dataset_format is None:
225
+ dataset_format = current_format
226
+ elif current_format is not None:
227
+ if current_format != dataset_format:
228
+ raise InvalidFileFormatError(
229
+ message="All samples in the dataset must have the same dataset format. "
230
+ f"Got {dataset_format} for the first line and {current_format} "
231
+ f"for the line {idx + 1}.",
232
+ line_number=idx + 1,
233
+ error_source="format",
234
+ )
235
+
236
+ if idx + 1 < MIN_SAMPLES:
237
+ report_dict["has_min_samples"] = False
238
+ report_dict["message"] = (
239
+ f"Processing {file} resulted in only {idx + 1} samples. "
240
+ f"Our minimum is {MIN_SAMPLES} samples. "
241
+ )
242
+ report_dict["is_check_passed"] = False
243
+ else:
244
+ report_dict["num_samples"] = idx + 1
245
+ report_dict["has_min_samples"] = True
246
+ report_dict["is_check_passed"] = True
247
+
248
+ report_dict["load_json"] = True
249
+
250
+ except InvalidFileFormatError as e:
251
+ report_dict["load_json"] = False
252
+ report_dict["is_check_passed"] = False
253
+ report_dict["message"] = e.message
254
+ if e.line_number is not None:
255
+ report_dict["line_number"] = e.line_number
256
+ if e.error_source is not None:
257
+ report_dict[e.error_source] = False
258
+ except ValueError:
259
+ report_dict["load_json"] = False
260
+ if idx < 0:
261
+ report_dict["message"] = (
262
+ "Unable to decode file. "
263
+ "File may be empty or in an unsupported format. "
264
+ )
265
+ else:
266
+ report_dict["message"] = (
267
+ f"Error parsing json payload. Unexpected format on line {idx + 1}."
268
+ )
269
+ report_dict["is_check_passed"] = False
270
+
271
+ if "text_field" not in report_dict:
272
+ report_dict["text_field"] = True
273
+ if "line_type" not in report_dict:
274
+ report_dict["line_type"] = True
275
+ if "key_value" not in report_dict:
276
+ report_dict["key_value"] = True
277
+ return report_dict
278
+
279
+
280
+ def _check_parquet(file: Path) -> Dict[str, Any]:
281
+ report_dict: Dict[str, Any] = {}
282
+
283
+ try:
284
+ table = parquet.read_table(str(file), memory_map=True)
285
+ except ArrowInvalid:
286
+ report_dict["load_parquet"] = (
287
+ f"An exception has occurred when loading the Parquet file {file}. Please check the file for corruption. "
288
+ f"Exception trace:\n{format_exc()}"
289
+ )
290
+ report_dict["is_check_passed"] = False
291
+ return report_dict
292
+
293
+ column_names = table.schema.names
294
+ if "input_ids" not in column_names:
295
+ report_dict["load_parquet"] = (
296
+ f"Parquet file {file} does not contain the `input_ids` column."
297
+ )
298
+ report_dict["is_check_passed"] = False
299
+ return report_dict
300
+
301
+ for column_name in column_names:
302
+ if column_name not in PARQUET_EXPECTED_COLUMNS:
303
+ report_dict["load_parquet"] = (
304
+ f"Parquet file {file} contains an unexpected column {column_name}. "
305
+ f"Only columns {PARQUET_EXPECTED_COLUMNS} are supported."
306
+ )
307
+ report_dict["is_check_passed"] = False
308
+ return report_dict
309
+
310
+ num_samples = len(table)
311
+ if num_samples < MIN_SAMPLES:
312
+ report_dict["has_min_samples"] = False
313
+ report_dict["message"] = (
314
+ f"Processing {file} resulted in only {num_samples} samples. "
315
+ f"Our minimum is {MIN_SAMPLES} samples. "
316
+ )
317
+ report_dict["is_check_passed"] = False
318
+ return report_dict
319
+ else:
320
+ report_dict["num_samples"] = num_samples
321
+
322
+ report_dict["is_check_passed"] = True
323
+
324
+ return report_dict
@@ -1,23 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import click
4
-
5
- from typing import Literal
6
-
7
-
8
- class AutoIntParamType(click.ParamType):
9
- name = "integer"
10
-
11
- def convert(
12
- self, value: str, param: click.Parameter | None, ctx: click.Context | None
13
- ) -> int | Literal["max"] | None:
14
- if isinstance(value, int):
15
- return value
16
-
17
- if value == "max":
18
- return "max"
19
-
20
- self.fail("Invalid integer value: {value}")
21
-
22
-
23
- INT_WITH_MAX = AutoIntParamType()
@@ -1,204 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import json
4
- import os
5
- from pathlib import Path
6
- from traceback import format_exc
7
- from typing import Any, Dict
8
-
9
- from pyarrow import ArrowInvalid, parquet
10
-
11
- from together.constants import (
12
- MAX_FILE_SIZE_GB,
13
- MIN_SAMPLES,
14
- NUM_BYTES_IN_GB,
15
- PARQUET_EXPECTED_COLUMNS,
16
- )
17
-
18
-
19
- def check_file(
20
- file: Path | str,
21
- ) -> Dict[str, Any]:
22
- if not isinstance(file, Path):
23
- file = Path(file)
24
-
25
- report_dict = {
26
- "is_check_passed": True,
27
- "message": "Checks passed",
28
- "found": None,
29
- "file_size": None,
30
- "utf8": None,
31
- "line_type": None,
32
- "text_field": None,
33
- "key_value": None,
34
- "min_samples": None,
35
- "num_samples": None,
36
- "load_json": None,
37
- }
38
-
39
- if not file.is_file():
40
- report_dict["found"] = False
41
- report_dict["is_check_passed"] = False
42
- return report_dict
43
- else:
44
- report_dict["found"] = True
45
-
46
- file_size = os.stat(file.as_posix()).st_size
47
-
48
- if file_size > MAX_FILE_SIZE_GB * NUM_BYTES_IN_GB:
49
- report_dict["message"] = (
50
- f"Maximum supported file size is {MAX_FILE_SIZE_GB} GB. Found file with size of {round(file_size / NUM_BYTES_IN_GB ,3)} GB."
51
- )
52
- report_dict["is_check_passed"] = False
53
- elif file_size == 0:
54
- report_dict["message"] = "File is empty"
55
- report_dict["file_size"] = 0
56
- report_dict["is_check_passed"] = False
57
- return report_dict
58
- else:
59
- report_dict["file_size"] = file_size
60
-
61
- if file.suffix == ".jsonl":
62
- report_dict["filetype"] = "jsonl"
63
- data_report_dict = _check_jsonl(file)
64
- elif file.suffix == ".parquet":
65
- report_dict["filetype"] = "parquet"
66
- data_report_dict = _check_parquet(file)
67
- else:
68
- report_dict["filetype"] = (
69
- f"Unknown extension of file {file}. "
70
- "Only files with extensions .jsonl and .parquet are supported."
71
- )
72
- report_dict["is_check_passed"] = False
73
-
74
- report_dict.update(data_report_dict)
75
- return report_dict
76
-
77
-
78
- def _check_jsonl(file: Path) -> Dict[str, Any]:
79
- report_dict: Dict[str, Any] = {}
80
- # Check that the file is UTF-8 encoded. If not report where the error occurs.
81
- try:
82
- with file.open(encoding="utf-8") as f:
83
- f.read()
84
- report_dict["utf8"] = True
85
- except UnicodeDecodeError as e:
86
- report_dict["utf8"] = False
87
- report_dict["message"] = f"File is not UTF-8 encoded. Error raised: {e}."
88
- report_dict["is_check_passed"] = False
89
- return report_dict
90
-
91
- with file.open() as f:
92
- # idx must be instantiated so decode errors (e.g. file is a tar) or empty files are caught
93
- idx = -1
94
- try:
95
- for idx, line in enumerate(f):
96
- json_line = json.loads(line) # each line in jsonlines should be a json
97
-
98
- if not isinstance(json_line, dict):
99
- report_dict["line_type"] = False
100
- report_dict["message"] = (
101
- f"Error parsing file. Invalid format on line {idx + 1} of the input file. "
102
- 'Example of valid json: {"text": "my sample string"}. '
103
- )
104
-
105
- report_dict["is_check_passed"] = False
106
-
107
- if "text" not in json_line.keys():
108
- report_dict["text_field"] = False
109
- report_dict["message"] = (
110
- f"Missing 'text' field was found on line {idx + 1} of the the input file. "
111
- "Expected format: {'text': 'my sample string'}. "
112
- )
113
- report_dict["is_check_passed"] = False
114
- else:
115
- # check to make sure the value of the "text" key is a string
116
- if not isinstance(json_line["text"], str):
117
- report_dict["key_value"] = False
118
- report_dict["message"] = (
119
- f'Invalid value type for "text" key on line {idx + 1}. '
120
- f'Expected string. Found {type(json_line["text"])}.'
121
- )
122
-
123
- report_dict["is_check_passed"] = False
124
-
125
- # make sure this is outside the for idx, line in enumerate(f): for loop
126
- if idx + 1 < MIN_SAMPLES:
127
- report_dict["min_samples"] = False
128
- report_dict["message"] = (
129
- f"Processing {file} resulted in only {idx + 1} samples. "
130
- f"Our minimum is {MIN_SAMPLES} samples. "
131
- )
132
- report_dict["is_check_passed"] = False
133
- else:
134
- report_dict["num_samples"] = idx + 1
135
- report_dict["min_samples"] = True
136
-
137
- report_dict["load_json"] = True
138
-
139
- except ValueError:
140
- report_dict["load_json"] = False
141
- if idx < 0:
142
- report_dict["message"] = (
143
- "Unable to decode file. "
144
- "File may be empty or in an unsupported format. "
145
- )
146
- else:
147
- report_dict["message"] = (
148
- f"Error parsing json payload. Unexpected format on line {idx + 1}."
149
- )
150
- report_dict["is_check_passed"] = False
151
-
152
- if "text_field" not in report_dict:
153
- report_dict["text_field"] = True
154
- if "line_type" not in report_dict:
155
- report_dict["line_type"] = True
156
- if "key_value" not in report_dict:
157
- report_dict["key_value"] = True
158
- return report_dict
159
-
160
-
161
- def _check_parquet(file: Path) -> Dict[str, Any]:
162
- report_dict: Dict[str, Any] = {}
163
-
164
- try:
165
- table = parquet.read_table(str(file), memory_map=True)
166
- except ArrowInvalid:
167
- report_dict["load_parquet"] = (
168
- f"An exception has occurred when loading the Parquet file {file}. Please check the file for corruption. "
169
- f"Exception trace:\n{format_exc()}"
170
- )
171
- report_dict["is_check_passed"] = False
172
- return report_dict
173
-
174
- column_names = table.schema.names
175
- if "input_ids" not in column_names:
176
- report_dict["load_parquet"] = (
177
- f"Parquet file {file} does not contain the `input_ids` column."
178
- )
179
- report_dict["is_check_passed"] = False
180
- return report_dict
181
-
182
- for column_name in column_names:
183
- if column_name not in PARQUET_EXPECTED_COLUMNS:
184
- report_dict["load_parquet"] = (
185
- f"Parquet file {file} contains an unexpected column {column_name}. "
186
- f"Only columns {PARQUET_EXPECTED_COLUMNS} are supported."
187
- )
188
- report_dict["is_check_passed"] = False
189
- return report_dict
190
-
191
- num_samples = len(table)
192
- if num_samples < MIN_SAMPLES:
193
- report_dict["min_samples"] = (
194
- f"Processing {file} resulted in only {num_samples} samples. "
195
- f"Our minimum is {MIN_SAMPLES} samples. "
196
- )
197
- report_dict["is_check_passed"] = False
198
- return report_dict
199
- else:
200
- report_dict["num_samples"] = num_samples
201
-
202
- report_dict["is_check_passed"] = True
203
-
204
- return report_dict
File without changes
File without changes
File without changes