seekrai 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. seekrai/__init__.py +64 -0
  2. seekrai/abstract/__init__.py +1 -0
  3. seekrai/abstract/api_requestor.py +710 -0
  4. seekrai/cli/__init__.py +0 -0
  5. seekrai/cli/api/__init__.py +0 -0
  6. seekrai/cli/api/chat.py +245 -0
  7. seekrai/cli/api/completions.py +107 -0
  8. seekrai/cli/api/files.py +125 -0
  9. seekrai/cli/api/finetune.py +175 -0
  10. seekrai/cli/api/images.py +82 -0
  11. seekrai/cli/api/models.py +42 -0
  12. seekrai/cli/cli.py +77 -0
  13. seekrai/client.py +154 -0
  14. seekrai/constants.py +32 -0
  15. seekrai/error.py +188 -0
  16. seekrai/filemanager.py +393 -0
  17. seekrai/legacy/__init__.py +0 -0
  18. seekrai/legacy/base.py +27 -0
  19. seekrai/legacy/complete.py +91 -0
  20. seekrai/legacy/embeddings.py +25 -0
  21. seekrai/legacy/files.py +140 -0
  22. seekrai/legacy/finetune.py +173 -0
  23. seekrai/legacy/images.py +25 -0
  24. seekrai/legacy/models.py +44 -0
  25. seekrai/resources/__init__.py +25 -0
  26. seekrai/resources/chat/__init__.py +24 -0
  27. seekrai/resources/chat/completions.py +241 -0
  28. seekrai/resources/completions.py +205 -0
  29. seekrai/resources/embeddings.py +100 -0
  30. seekrai/resources/files.py +173 -0
  31. seekrai/resources/finetune.py +425 -0
  32. seekrai/resources/images.py +156 -0
  33. seekrai/resources/models.py +75 -0
  34. seekrai/seekrflow_response.py +50 -0
  35. seekrai/types/__init__.py +67 -0
  36. seekrai/types/abstract.py +26 -0
  37. seekrai/types/chat_completions.py +151 -0
  38. seekrai/types/common.py +64 -0
  39. seekrai/types/completions.py +86 -0
  40. seekrai/types/embeddings.py +35 -0
  41. seekrai/types/error.py +16 -0
  42. seekrai/types/files.py +88 -0
  43. seekrai/types/finetune.py +218 -0
  44. seekrai/types/images.py +42 -0
  45. seekrai/types/models.py +43 -0
  46. seekrai/utils/__init__.py +28 -0
  47. seekrai/utils/_log.py +61 -0
  48. seekrai/utils/api_helpers.py +84 -0
  49. seekrai/utils/files.py +204 -0
  50. seekrai/utils/tools.py +75 -0
  51. seekrai/version.py +6 -0
  52. seekrai-0.0.1.dist-info/LICENSE +201 -0
  53. seekrai-0.0.1.dist-info/METADATA +401 -0
  54. seekrai-0.0.1.dist-info/RECORD +56 -0
  55. seekrai-0.0.1.dist-info/WHEEL +4 -0
  56. seekrai-0.0.1.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,218 @@
1
+ from __future__ import annotations
2
+
3
+ from enum import Enum
4
+ from typing import List, Literal
5
+
6
+ from pydantic import Field
7
+
8
+ from seekrai.types.abstract import BaseModel
9
+ from seekrai.types.common import (
10
+ ObjectType,
11
+ )
12
+ from datetime import datetime
13
+
14
+
15
+
16
class FinetuneJobStatus(str, Enum):
    """
    Possible fine-tune job status values, as reported by the API.
    """

    # Job accepted but not yet queued.
    STATUS_PENDING = "pending"
    # Job waiting in the scheduling queue.
    STATUS_QUEUED = "queued"
    # Job actively training.
    STATUS_RUNNING = "running"
    # STATUS_COMPRESSING = "compressing"
    # STATUS_UPLOADING = "uploading"
    # Cancellation requested but not yet finalized.
    STATUS_CANCEL_REQUESTED = "cancel_requested"
    # Job cancelled before completion.
    STATUS_CANCELLED = "cancelled"
    # Job terminated with an error.
    STATUS_FAILED = "failed"
    # Job finished successfully.
    STATUS_COMPLETED = "completed"
30
+
31
+
32
class FinetuneEventLevels(str, Enum):
    """
    Fine-tune job event status levels.
    """

    NULL = ""
    INFO = "Info"
    WARNING = "Warning"
    ERROR = "Error"
    # Lower-case variants; the LEGACY_ prefix suggests these exist for
    # backward compatibility with older event payloads — confirm against API.
    LEGACY_INFO = "info"
    LEGACY_IWARNING = "warning"
    LEGACY_IERROR = "error"
44
+
45
+
46
class FinetuneEventType(str, Enum):
    """
    Fine-tune job event types.
    """

    # Job lifecycle
    JOB_PENDING = "JOB_PENDING"
    JOB_START = "JOB_START"
    JOB_STOPPED = "JOB_STOPPED"
    # Asset download progress
    MODEL_DOWNLOADING = "MODEL_DOWNLOADING"
    MODEL_DOWNLOAD_COMPLETE = "MODEL_DOWNLOAD_COMPLETE"
    TRAINING_DATA_DOWNLOADING = "TRAINING_DATA_DOWNLOADING"
    TRAINING_DATA_DOWNLOAD_COMPLETE = "TRAINING_DATA_DOWNLOAD_COMPLETE"
    VALIDATION_DATA_DOWNLOADING = "VALIDATION_DATA_DOWNLOADING"
    VALIDATION_DATA_DOWNLOAD_COMPLETE = "VALIDATION_DATA_DOWNLOAD_COMPLETE"
    # Training progress
    WANDB_INIT = "WANDB_INIT"
    TRAINING_START = "TRAINING_START"
    CHECKPOINT_SAVE = "CHECKPOINT_SAVE"
    BILLING_LIMIT = "BILLING_LIMIT"
    EPOCH_COMPLETE = "EPOCH_COMPLETE"
    TRAINING_COMPLETE = "TRAINING_COMPLETE"
    # Model packaging / upload
    # NOTE(review): value differs from the member name ("COMPRESSING_MODEL"
    # vs MODEL_COMPRESSING) — confirm this matches the API wire format.
    MODEL_COMPRESSING = "COMPRESSING_MODEL"
    MODEL_COMPRESSION_COMPLETE = "MODEL_COMPRESSION_COMPLETE"
    MODEL_UPLOADING = "MODEL_UPLOADING"
    MODEL_UPLOAD_COMPLETE = "MODEL_UPLOAD_COMPLETE"
    # Terminal / exceptional states
    JOB_COMPLETE = "JOB_COMPLETE"
    JOB_ERROR = "JOB_ERROR"
    CANCEL_REQUESTED = "CANCEL_REQUESTED"
    JOB_RESTARTED = "JOB_RESTARTED"
    REFUND = "REFUND"
    WARNING = "WARNING"
76
+
77
+
78
class FinetuneEvent(BaseModel):
    """
    A single event in a fine-tune job's event log.
    """

    # object type discriminator (always ObjectType.FinetuneEvent)
    object: Literal[ObjectType.FinetuneEvent]
    # created-at datetime stamp (string form; format not visible here — verify)
    created_at: str | None = None
    # event log level
    level: FinetuneEventLevels | None = None
    # event message string
    message: str | None = None
    # event type
    type: FinetuneEventType | None = None
    # optional: model parameter count
    param_count: int | None = None
    # optional: dataset token count
    token_count: int | None = None
    # optional: weights & biases run url
    wandb_url: str | None = None
    # event hash
    hash: str | None = None
101
+
102
+
103
+
104
class TrainingConfig(BaseModel):
    """Hyperparameters and data references for a fine-tune job."""

    # IDs of the uploaded training files
    training_files: List[str]
    # base model identifier string
    model: str
    # number of epochs to train for
    n_epochs: int
    # training learning rate
    learning_rate: float
    # number of checkpoints to save
    n_checkpoints: int | None = None
    # training batch size
    batch_size: int | None = None
    # up to 40 character suffix for output model name
    experiment_name: str | None = None
    # # weights & biases api key
    # wandb_key: str | None = None
121
+
122
class InfrastructureConfig(BaseModel):
    """Compute resources requested for a fine-tune job."""

    # number of CPUs to allocate
    n_cpu: int
    # number of GPUs to allocate
    n_gpu: int
125
+
126
class FinetuneRequest(BaseModel):
    """
    Fine-tune request body: training hyperparameters plus compute resources.
    """

    # hyperparameters and data references for the job
    training_config: TrainingConfig
    # compute resources for the job
    infrastructure_config: InfrastructureConfig
132
+
133
+
134
+
135
+
136
+
137
class FinetuneResponse(BaseModel):
    """
    Fine-tune API response type.

    All fields except ``inference_available`` are optional; many planned
    fields are retained below as commented-out TODOs.
    """

    # job ID
    id: str | None = None
    # IDs of the training files used for the job
    training_files: List[str] | None = None
    # validation file id
    # validation_files: str | None = None TODO
    # base model name
    model: str | None = None
    # number of epochs
    # n_epochs: int | None = None
    # number of checkpoints to save
    # n_checkpoints: int | None = None # TODO
    # training batch size
    # batch_size: int | None = None
    # training learning rate
    # learning_rate: float | None = None
    # number of steps between evals
    # eval_steps: int | None = None TODO
    # is LoRA finetune boolean
    # lora: bool | None = None
    # lora_r: int | None = None
    # lora_alpha: int | None = None
    # lora_dropout: int | None = None
    # created datetime stamp
    created_at: datetime | None = None
    # updated_at: str | None = None
    # job status
    status: FinetuneJobStatus | None = None

    # list of fine-tune events
    events: List[FinetuneEvent] | None = None
    # whether the fine-tuned model is available for inference
    inference_available: bool = False
    # dataset token count
    # TODO
    # token_count: int | None = None
    # # model parameter count
    # param_count: int | None = None
    # # fine-tune job price
    # total_price: int | None = None
    # # number of epochs completed (incrementing counter)
    # epochs_completed: int | None = None
    # # place in job queue (decrementing counter)
    # queue_depth: int | None = None
    # # weights & biases project name
    # wandb_project_name: str | None = None
    # # weights & biases job url
    # wandb_url: str | None = None
    # # training file metadata
    # training_file_num_lines: int | None = Field(None, alias="TrainingFileNumLines")
    # training_file_size: int | None = Field(None, alias="TrainingFileSize")
192
+
193
+
194
class FinetuneList(BaseModel):
    """List response wrapping multiple fine-tune jobs."""

    # object type
    object: Literal["list"] | None = None
    # list of fine-tune job objects
    data: List[FinetuneResponse] | None = None
199
+
200
+
201
class FinetuneListEvents(BaseModel):
    """List response wrapping a fine-tune job's events."""

    # object type
    object: Literal["list"] | None = None
    # list of fine-tune events
    data: List[FinetuneEvent] | None = None
206
+
207
+
208
class FinetuneDownloadResult(BaseModel):
    """Result of downloading a fine-tuned model checkpoint to local disk."""

    # object type ("local" marks a file written to local disk)
    object: Literal["local"] | None = None
    # fine-tune job id
    id: str | None = None
    # checkpoint step number
    checkpoint_step: int | None = None
    # local path filename
    filename: str | None = None
    # size in bytes
    size: int | None = None
@@ -0,0 +1,42 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import List, Literal
4
+
5
+ from seekrai.types.abstract import BaseModel
6
+
7
+
8
class ImageRequest(BaseModel):
    """Image-generation request parameters."""

    # text prompt to generate from (a single string; the original comment
    # claimed "input or list of inputs", but the type only allows one string)
    prompt: str
    # model to query
    model: str
    # number of generation steps
    steps: int | None = 20
    # RNG seed for reproducible generations
    seed: int | None = None
    # number of results to return
    n: int | None = 1
    # pixel height
    height: int | None = 1024
    # pixel width
    width: int | None = 1024
    # negative prompt
    negative_prompt: str | None = None
25
+
26
+
27
class ImageChoicesData(BaseModel):
    """One generated image within an ImageResponse."""

    # response index
    index: int
    # base64-encoded image payload
    b64_json: str
32
+
33
+
34
class ImageResponse(BaseModel):
    """Image-generation API response."""

    # job id
    id: str | None = None
    # model that served the request
    model: str | None = None
    # object type
    object: Literal["list"] | None = None
    # list of generated image choices
    data: List[ImageChoicesData] | None = None
@@ -0,0 +1,43 @@
1
+ from __future__ import annotations
2
+
3
+ from enum import Enum
4
+ from typing import Literal
5
+
6
+ from seekrai.types.abstract import BaseModel
7
+ from seekrai.types.common import ObjectType
8
+
9
+
10
class ModelType(str, Enum):
    """Broad capability category of a served model."""

    CHAT = "chat"
    LANGUAGE = "language"
    CODE = "code"
    IMAGE = "image"
    EMBEDDING = "embedding"
    MODERATION = "moderation"
17
+
18
+
19
class PricingObject(BaseModel):
    """Per-model pricing entries.

    NOTE(review): the units (e.g. dollars per token / per hour) are not
    visible in this module — confirm against the API documentation.
    """

    # price per unit of input
    input: float | None = None
    # price per unit of output
    output: float | None = None
    # hourly rate
    hourly: float | None = None
    # base rate
    base: float | None = None
    # fine-tuning rate
    finetune: float | None = None
25
+
26
+
27
class ModelObject(BaseModel):
    """Metadata describing a model available through the API."""

    # model id
    id: str
    # object type discriminator (always ObjectType.Model)
    object: Literal[ObjectType.Model]
    # creation timestamp (presumably unix seconds — confirm with API docs)
    created: int | None = None
    # model type
    type: ModelType | None = None
    # pretty display name
    display_name: str | None = None
    # model creator organization
    organization: str | None = None
    # link to model resource
    link: str | None = None
    # license name
    license: str | None = None
    # maximum context length
    context_length: int | None = None
    # pricing info (required, unlike the optional fields above)
    pricing: PricingObject
@@ -0,0 +1,28 @@
1
+ from seekrai.utils._log import log_debug, log_info, log_warn, logfmt
2
+ from seekrai.utils.api_helpers import default_api_key, get_headers
3
+ from seekrai.utils.files import check_file
4
+ from seekrai.utils.tools import (
5
+ convert_bytes,
6
+ convert_unix_timestamp,
7
+ enforce_trailing_slash,
8
+ finetune_price_to_dollars,
9
+ normalize_key,
10
+ parse_timestamp,
11
+ )
12
+
13
+
14
+ __all__ = [
15
+ "check_file",
16
+ "get_headers",
17
+ "default_api_key",
18
+ "log_debug",
19
+ "log_info",
20
+ "log_warn",
21
+ "logfmt",
22
+ "enforce_trailing_slash",
23
+ "normalize_key",
24
+ "parse_timestamp",
25
+ "finetune_price_to_dollars",
26
+ "convert_bytes",
27
+ "convert_unix_timestamp",
28
+ ]
seekrai/utils/_log.py ADDED
@@ -0,0 +1,61 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import os
5
+ import re
6
+ import sys
7
+ from typing import Any, Dict
8
+
9
+ import seekrai
10
+
11
+
12
+ logger = logging.getLogger("seekrai")
13
+
14
+ SEEKRFLOW_LOG = os.environ.get("SEEKRFLOW_LOG")
15
+
16
+
17
def _console_log_level() -> str | None:
    """Return ``"debug"`` or ``"info"`` when console logging is enabled.

    The in-process ``seekrai.log`` setting takes priority over the
    ``SEEKRFLOW_LOG`` environment variable; any other value disables
    console mirroring (returns None).
    """
    for candidate in (seekrai.log, SEEKRFLOW_LOG):
        if candidate in ("debug", "info"):
            return candidate
    return None
24
+
25
+
26
def logfmt(props: Dict[str, Any]) -> str:
    """Render *props* as a logfmt-style ``key=value`` line, sorted by key.

    Bytes-like values are decoded as UTF-8; non-strings are stringified.
    Keys or values containing whitespace are rendered with ``repr`` so the
    output stays unambiguous.
    """

    def _format_pair(key: str, value: Any) -> str:
        # bytes / bytearray -> text before formatting
        if hasattr(value, "decode"):
            value = value.decode("utf-8")
        # stringify anything that is not already a string (avoids re-repr'ing)
        if not isinstance(value, str):
            value = str(value)
        if re.search(r"\s", value):
            value = repr(value)
        # keys are expected to be strings already
        if re.search(r"\s", key):
            key = repr(key)
        return f"{key}={value}"

    return " ".join(_format_pair(k, v) for k, v in sorted(props.items()))
42
+
43
+
44
def log_debug(message: str | Any, **params: Any) -> None:
    """Log *message* (plus key/value params) at DEBUG level.

    Also mirrors the line to stderr when console logging is set to "debug".
    """
    entry = logfmt(dict(message=message, **params))
    if _console_log_level() == "debug":
        print(entry, file=sys.stderr)
    logger.debug(entry)
49
+
50
+
51
def log_info(message: str | Any, **params: Any) -> None:
    """Log *message* (plus key/value params) at INFO level.

    Also mirrors the line to stderr when console logging is "debug" or "info".
    """
    entry = logfmt(dict(message=message, **params))
    if _console_log_level() in ("debug", "info"):
        print(entry, file=sys.stderr)
    logger.info(entry)
56
+
57
+
58
def log_warn(message: str | Any, **params: Any) -> None:
    """Log *message* (plus key/value params) at WARNING level.

    Unlike log_debug/log_info, the line is always printed to stderr
    regardless of the console log level.
    """
    msg = logfmt(dict(message=message, **params))
    print(msg, file=sys.stderr)
    # Fix: Logger.warn is a deprecated alias; the supported API is warning().
    logger.warning(msg)
@@ -0,0 +1,84 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ import platform
6
+ from typing import TYPE_CHECKING, Any, Dict
7
+
8
+
9
+ if TYPE_CHECKING:
10
+ from _typeshed import SupportsKeysAndGetItem
11
+
12
+ import seekrai
13
+ from seekrai import error
14
+ from seekrai.utils._log import _console_log_level
15
+
16
+
17
def get_headers(
    method: str | None = None,
    api_key: str | None = None,
    extra: "SupportsKeysAndGetItem[str, Any] | None" = None,
) -> Dict[str, str]:
    """
    Generates request headers with API key, metadata, and supplied headers

    Args:
        method (str, optional): HTTP request type (POST, GET, etc.)
            Defaults to None. (Currently unused by this function.)
        api_key (str, optional): API key to add as an Authorization header.
            Defaults to None.
        extra (SupportsKeysAndGetItem[str, Any], optional): Additional headers to add to request.
            Defaults to None.

    Returns:
        headers (Dict[str, str]): Compiled headers from data
    """

    user_agent = f"SeekrFlow/v1 PythonBindings/{seekrai.version}"

    # platform.uname() minus the hostname, for telemetry without leaking it
    uname_fields = platform.uname()._asdict()
    uname_without_node = " ".join(
        value for field, value in uname_fields.items() if field != "node"
    )

    client_ua = {
        "bindings_version": seekrai.version,
        "httplib": "requests",
        "lang": "python",
        "lang_version": platform.python_version(),
        "platform": platform.platform(),
        "publisher": "seekrai",
        "uname": uname_without_node,
    }

    headers: Dict[str, Any] = {
        "X-SeekrFlow-Client-User-Agent": json.dumps(client_ua),
        "Authorization": default_api_key(api_key),
        "User-Agent": user_agent,
    }

    debug_level = _console_log_level()
    if debug_level:
        headers["SeekrFlow-Debug"] = debug_level
    if extra:
        headers.update(extra)

    return headers
64
+
65
+
66
def default_api_key(api_key: str | None = None) -> str | None:
    """
    API key fallback logic from input argument and environment variable

    Args:
        api_key (str, optional): Supplied API key. This argument takes priority over env var

    Returns:
        seekrflow_api_key (str): Returns API key from supplied input or env var

    Raises:
        seekrai.error.AuthenticationError: if API key not found
    """
    if api_key:
        return api_key
    # Fix: read the environment variable once instead of two separate lookups.
    env_key = os.environ.get("SEEKRFLOW_API_KEY")
    if env_key:
        return env_key

    raise error.AuthenticationError(seekrai.constants.MISSING_API_KEY_MESSAGE)
seekrai/utils/files.py ADDED
@@ -0,0 +1,204 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ from pathlib import Path
6
+ from traceback import format_exc
7
+ from typing import Any, Dict
8
+
9
+ from pyarrow import ArrowInvalid, parquet
10
+
11
+ from seekrai.constants import (
12
+ MAX_FILE_SIZE_GB,
13
+ MIN_SAMPLES,
14
+ NUM_BYTES_IN_GB,
15
+ PARQUET_EXPECTED_COLUMNS,
16
+ )
17
+
18
+
19
def check_file(
    file: Path | str,
) -> Dict[str, Any]:
    """Validate a training-data file and return a report of the checks.

    Args:
        file: Path (or path string) to a ``.jsonl`` or ``.parquet`` file.

    Returns:
        Dict[str, Any]: report with ``is_check_passed`` plus per-check fields
        (``found``, ``file_size``, ``utf8``, ``line_type``, ``text_field``,
        ``key_value``, ``min_samples``, ``num_samples``, ``load_json``) and a
        human-readable ``message``.
    """
    if not isinstance(file, Path):
        file = Path(file)

    report_dict = {
        "is_check_passed": True,
        "message": "Checks passed",
        "found": None,
        "file_size": None,
        "utf8": None,
        "line_type": None,
        "text_field": None,
        "key_value": None,
        "min_samples": None,
        "num_samples": None,
        "load_json": None,
    }

    if not file.is_file():
        report_dict["found"] = False
        report_dict["is_check_passed"] = False
        return report_dict
    report_dict["found"] = True

    file_size = os.stat(file.as_posix()).st_size

    if file_size > MAX_FILE_SIZE_GB * NUM_BYTES_IN_GB:
        report_dict["message"] = (
            f"Maximum supported file size is {MAX_FILE_SIZE_GB} GB. Found file with size of {round(file_size / NUM_BYTES_IN_GB, 3)} GB."
        )
        # Fix: record the actual size; previously it was left as None here.
        report_dict["file_size"] = file_size
        report_dict["is_check_passed"] = False
    elif file_size == 0:
        report_dict["message"] = "File is empty"
        report_dict["file_size"] = 0
        report_dict["is_check_passed"] = False
        return report_dict
    else:
        report_dict["file_size"] = file_size

    if file.suffix == ".jsonl":
        report_dict["filetype"] = "jsonl"
        data_report_dict = _check_jsonl(file)
    elif file.suffix == ".parquet":
        report_dict["filetype"] = "parquet"
        data_report_dict = _check_parquet(file)
    else:
        report_dict["filetype"] = (
            f"Unknown extension of file {file}. "
            "Only files with extensions .jsonl and .parquet are supported."
        )
        report_dict["is_check_passed"] = False
        # Fix: return here. Previously execution fell through to
        # report_dict.update(data_report_dict) with data_report_dict unbound,
        # raising UnboundLocalError for any unsupported extension.
        return report_dict

    report_dict.update(data_report_dict)
    return report_dict
76
+
77
+
78
def _check_jsonl(file: Path) -> Dict[str, Any]:
    """Validate a ``.jsonl`` training file.

    Checks UTF-8 encoding, that every line is a JSON object with a string
    ``"text"`` field, and that the file has at least MIN_SAMPLES lines.
    Returns a partial report dict that check_file merges into its own.
    """
    report_dict: Dict[str, Any] = {}

    # Check that the file is UTF-8 encoded. If not, report the decode error.
    try:
        with file.open(encoding="utf-8") as f:
            f.read()
        report_dict["utf8"] = True
    except UnicodeDecodeError as e:
        report_dict["utf8"] = False
        report_dict["message"] = f"File is not UTF-8 encoded. Error raised: {e}."
        report_dict["is_check_passed"] = False
        return report_dict

    with file.open() as f:
        # idx must be pre-set so decode errors (e.g. file is a tar) or empty
        # files are caught by the idx < 0 branch below.
        idx = -1
        try:
            for idx, line in enumerate(f):
                json_line = json.loads(line)  # each line in jsonlines should be a json

                if not isinstance(json_line, dict):
                    report_dict["line_type"] = False
                    report_dict["message"] = (
                        f"Error parsing file. Invalid format on line {idx + 1} of the input file. "
                        'Example of valid json: {"text": "my sample string"}. '
                    )
                    report_dict["is_check_passed"] = False
                    # Fix: skip the field checks for non-dict lines. Previously
                    # execution fell through and called .keys() on the value,
                    # raising AttributeError for e.g. a JSON list.
                    continue

                if "text" not in json_line:
                    report_dict["text_field"] = False
                    report_dict["message"] = (
                        # Fix: "the the" typo in the original message.
                        f"Missing 'text' field was found on line {idx + 1} of the input file. "
                        "Expected format: {'text': 'my sample string'}. "
                    )
                    report_dict["is_check_passed"] = False
                elif not isinstance(json_line["text"], str):
                    # the value of the "text" key must be a string
                    report_dict["key_value"] = False
                    report_dict["message"] = (
                        f'Invalid value type for "text" key on line {idx + 1}. '
                        f'Expected string. Found {type(json_line["text"])}.'
                    )
                    report_dict["is_check_passed"] = False

            # Outside the per-line loop: enforce the minimum sample count.
            if idx + 1 < MIN_SAMPLES:
                report_dict["min_samples"] = False
                report_dict["message"] = (
                    f"Processing {file} resulted in only {idx + 1} samples. "
                    f"Our minimum is {MIN_SAMPLES} samples. "
                )
                report_dict["is_check_passed"] = False
            else:
                report_dict["num_samples"] = idx + 1
                report_dict["min_samples"] = True

            report_dict["load_json"] = True

        except ValueError:
            report_dict["load_json"] = False
            if idx < 0:
                report_dict["message"] = (
                    "Unable to decode file. "
                    "File may be empty or in an unsupported format. "
                )
            else:
                report_dict["message"] = (
                    f"Error parsing json payload. Unexpected format on line {idx + 1}."
                )
            report_dict["is_check_passed"] = False

    # Checks that never fired default to passing.
    if "text_field" not in report_dict:
        report_dict["text_field"] = True
    if "line_type" not in report_dict:
        report_dict["line_type"] = True
    if "key_value" not in report_dict:
        report_dict["key_value"] = True
    return report_dict
159
+
160
+
161
def _check_parquet(file: Path) -> Dict[str, Any]:
    """Validate a ``.parquet`` tokenized-dataset file.

    Checks that the file loads, contains an ``input_ids`` column, has no
    columns outside PARQUET_EXPECTED_COLUMNS, and holds at least MIN_SAMPLES
    rows. Returns a partial report dict that check_file merges into its own.
    """
    report: Dict[str, Any] = {}

    try:
        table = parquet.read_table(str(file), memory_map=True)
    except ArrowInvalid:
        report["load_parquet"] = (
            f"An exception has occurred when loading the Parquet file {file}. Please check the file for corruption. "
            f"Exception trace:\n{format_exc()}"
        )
        report["is_check_passed"] = False
        return report

    columns = table.schema.names

    # The tokenized dataset must carry the input_ids column.
    if "input_ids" not in columns:
        report["load_parquet"] = (
            f"Parquet file {file} does not contain the `input_ids` column."
        )
        report["is_check_passed"] = False
        return report

    # Reject the first column (in schema order) that is not expected.
    unexpected = [name for name in columns if name not in PARQUET_EXPECTED_COLUMNS]
    if unexpected:
        column_name = unexpected[0]
        report["load_parquet"] = (
            f"Parquet file {file} contains an unexpected column {column_name}. "
            f"Only columns {PARQUET_EXPECTED_COLUMNS} are supported."
        )
        report["is_check_passed"] = False
        return report

    num_samples = len(table)
    if num_samples < MIN_SAMPLES:
        report["min_samples"] = (
            f"Processing {file} resulted in only {num_samples} samples. "
            f"Our minimum is {MIN_SAMPLES} samples. "
        )
        report["is_check_passed"] = False
        return report
    report["num_samples"] = num_samples

    report["is_check_passed"] = True
    return report