codon-model 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codon/__init__.py +5 -0
- codon/base.py +167 -0
- codon/exp/__init__.py +0 -0
- codon/exp/moe.py +307 -0
- codon/model/__init__.py +0 -0
- codon/model/motif/__init__.py +1 -0
- codon/model/motif/motif_a1.py +121 -0
- codon/model/patch_disc.py +151 -0
- codon/model/tcn.py +124 -0
- codon/ops/__init__.py +3 -0
- codon/ops/attention.py +107 -0
- codon/ops/bio.py +0 -0
- codon/utils/__init__.py +0 -0
- codon/utils/dataset/__init__.py +3 -0
- codon/utils/dataset/base.py +46 -0
- codon/utils/dataset/corpus.py +478 -0
- codon/utils/dataset/dataviewer.py +196 -0
- codon/utils/dataset/flatdata.py +455 -0
- codon/utils/mask.py +266 -0
- codon/utils/safecode.py +24 -0
- codon/utils/seed.py +75 -0
- codon/utils/theta.py +55 -0
- codon/utils/token.py +276 -0
- codon_model-0.0.1.dist-info/METADATA +17 -0
- codon_model-0.0.1.dist-info/RECORD +28 -0
- codon_model-0.0.1.dist-info/WHEEL +5 -0
- codon_model-0.0.1.dist-info/licenses/LICENSE +201 -0
- codon_model-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,478 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
import csv
|
|
4
|
+
|
|
5
|
+
from enum import Enum, auto
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from datetime import datetime
|
|
8
|
+
from typing import Union
|
|
9
|
+
|
|
10
|
+
import pandas as pd
|
|
11
|
+
|
|
12
|
+
from .base import CodonDataset
|
|
13
|
+
from .flatdata import FlatDataset
|
|
14
|
+
|
|
15
|
+
from codon.utils.safecode import safecode
|
|
16
|
+
|
|
17
|
+
class FileType(Enum):
    '''
    Storage formats a corpus file can use.

    Members are persisted by name (e.g. 'PARQUET'), so the numeric values
    only need to be distinct.
    '''
    PARQUET = 1
    JSONL = 2
    CSV = 3
|
|
21
|
+
|
|
22
|
+
@dataclass
class CorpusData:
    '''
    A single corpus entry as returned by CorpusDataset.get().

    Attributes:
        content (str): The text content of the entry.
        num_token (int): Token count of the content. CorpusDataset computes
            this as the character count, i.e. len(content).
        uuid (str): Identifier stored alongside the entry. CorpusDataset.add()
            fills it with safecode(8); it is an empty string when the
            underlying row has no 'uuid' field.
    '''
    content: str
    num_token: int
    uuid: str
|
|
35
|
+
|
|
36
|
+
class CorpusDataset(CodonDataset):
    '''
    A dataset for managing linguistic corpora with token counting.

    This class maintains a folder of corpus files (PARQUET, JSONL, CSV) and
    tracks metadata in a JSON configuration file. Token count is calculated
    based on character count. Supports Key-Value access with both string keys
    (filename:row_number) and integer keys (global row index).

    Attributes:
        folder_path (str): Path to the folder containing corpus files.
        config_path (str): Path to the configuration JSON file.
        file_limit (int): Size budget per corpus file, in bytes.
        file_type (FileType): Format used for files written by this instance.
        _config (dict): Configuration dictionary loaded from JSON.
        _total_token (int): Total token count across all files.
        _file_index (list): List of file metadata for index mapping.
        _cumulative_rows (list): Cumulative row counts for index mapping.
        _current_file_idx (int): Numeric suffix of the file being appended to.
        _current_file_size (int): Bytes accounted to the current file.
        _parquet_buffer (list): Rows accepted but not yet written to parquet.
        _flat_dataset_cache (dict): Lazy readers keyed by file path.
        _file_name_to_idx (dict): Filename -> position in _file_index.
    '''

    def __init__(
        self,
        folder_path: str,
        file_type: Union[FileType, None] = None,
        file_limit: int = 2 * 1024 * 1024 * 1024
    ) -> None:
        '''
        Initializes the CorpusDataset.

        Args:
            folder_path (str): Path to the folder containing corpus files.
            file_type (FileType): The file type for storing data. If not provided,
                will attempt to read from config.json. Required for new datasets.
            file_limit (int): Maximum file size in bytes. Defaults to 2GB.

        Raises:
            ValueError: If file_type is not provided and config.json does not exist.
        '''
        self.folder_path = folder_path
        self.config_path = os.path.join(folder_path, 'config.json')
        self.file_limit = file_limit
        self._config: dict = {}
        self._total_token: int = 0
        self._file_index: list = []
        self._cumulative_rows: list = [0]
        self._current_file_size: int = 0
        self._current_file_idx: int = 0
        self.file_type = file_type
        self._flat_dataset_cache: dict = {}
        self._parquet_buffer: list = []
        self._file_name_to_idx: dict = {}

        # Create folder if it doesn't exist
        os.makedirs(folder_path, exist_ok=True)

        # Load or initialize config
        self._load_config()

        # Determine file_type
        if self.file_type is None:
            if 'file_type' in self._config:
                self.file_type = FileType[self._config['file_type']]
            else:
                raise ValueError(
                    'file_type must be specified when creating a new CorpusDataset'
                )

        # Store file_type in config
        self._config['file_type'] = self.file_type.name
        self._save_config()

        # Continue appending where a previous session stopped instead of
        # starting again at corpus_000 with a zero size counter.
        self._resume_write_position()

    def _resume_write_position(self) -> None:
        '''
        Points the write cursor at the newest existing file of this file type.

        Without this, reopening a dataset would append to corpus_000 again and
        ignore file_limit for data written in earlier sessions. The size
        counter is seeded from the on-disk size, which is only an
        approximation of the content-byte accounting done by add().
        '''
        prefix = 'corpus_'
        suffix = f'.{self._get_file_extension()}'
        indices = []
        for info in self._file_index:
            name = info['filename']
            if not (name.startswith(prefix) and name.endswith(suffix)):
                continue
            stem = name[len(prefix):-len(suffix)]
            if stem.isascii() and stem.isdigit():
                indices.append(int(stem))

        if not indices:
            return

        self._current_file_idx = max(indices)
        current_path = self._current_file_path()
        if os.path.exists(current_path):
            self._current_file_size = os.path.getsize(current_path)

    def _current_filename(self) -> str:
        '''
        Returns the name of the file currently being appended to.
        '''
        return f'corpus_{self._current_file_idx:03d}.{self._get_file_extension()}'

    def _current_file_path(self) -> str:
        '''
        Returns the full path of the file currently being appended to.
        '''
        return os.path.join(self.folder_path, self._current_filename())

    def _load_config(self) -> None:
        '''
        Loads configuration from JSON file or creates a new one.
        '''
        if os.path.exists(self.config_path):
            with open(self.config_path, 'r', encoding='utf-8') as f:
                self._config = json.load(f)
            # add() indexes self._config['files'] directly; tolerate configs
            # that were written without that key.
            self._config.setdefault('files', [])
        else:
            self._config = {
                'version': '1.0',
                'total_token': 0,
                'files': []
            }
            self._save_config()

        # Rebuild file index and cumulative rows
        self._rebuild_index()

    def _save_config(self) -> None:
        '''
        Saves configuration to JSON file.

        The file is written to a temporary sibling and then moved into place,
        so an interrupted write cannot leave a truncated config.json behind.
        '''
        self._config['total_token'] = self._total_token
        tmp_path = self.config_path + '.tmp'
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(self._config, f, indent=2, ensure_ascii=False)
        os.replace(tmp_path, self.config_path)

    def _rebuild_index(self) -> None:
        '''
        Rebuilds the file index and cumulative row counts.

        Also recomputes _total_token from the per-file entries.
        '''
        self._file_index = []
        self._cumulative_rows = [0]
        self._total_token = 0
        self._file_name_to_idx = {}

        for file_info in self._config.get('files', []):
            self._file_index.append(file_info)
            self._file_name_to_idx[file_info['filename']] = len(self._file_index) - 1
            self._total_token += file_info.get('num_token', 0)
            self._cumulative_rows.append(
                self._cumulative_rows[-1] + file_info.get('num_rows', 0)
            )

    def _detect_file_type(self, file_path: str) -> FileType:
        '''
        Detects file type based on extension (case-insensitive).

        Args:
            file_path (str): Path to the file.

        Returns:
            FileType: The detected file type.

        Raises:
            ValueError: If file type is not supported.
        '''
        lowered = file_path.lower()
        if lowered.endswith('.parquet'):
            return FileType.PARQUET
        elif lowered.endswith('.jsonl'):
            return FileType.JSONL
        elif lowered.endswith('.csv'):
            return FileType.CSV
        else:
            raise ValueError(f'Unsupported file type: {file_path}')

    def _count_tokens(self, content: str) -> int:
        '''
        Counts tokens in content based on character count.

        Args:
            content (str): The text content.

        Returns:
            int: The number of tokens (characters).
        '''
        return len(content)

    def _load_file_data(self, file_path: str, file_type: FileType) -> list:
        '''
        Loads data from a file.

        Args:
            file_path (str): Path to the file.
            file_type (FileType): Type of the file.

        Returns:
            list: List of dictionaries representing rows.

        Raises:
            ValueError: If file type is not supported.
        '''
        if file_type == FileType.PARQUET:
            df = pd.read_parquet(file_path)
            return df.to_dict('records')
        elif file_type == FileType.JSONL:
            data = []
            with open(file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    # Blank lines are not JSON documents and would make
                    # json.loads raise.
                    if not line.strip():
                        continue
                    data.append(json.loads(line))
            return data
        elif file_type == FileType.CSV:
            df = pd.read_csv(file_path)
            return df.to_dict('records')
        else:
            raise ValueError(f'Unsupported file type: {file_type}')

    def add(self, data: str) -> None:
        '''
        Adds a string data entry to the dataset. If adding the data would exceed
        the file size limit, automatically creates a new file.

        For PARQUET datasets rows are buffered in memory and written every
        1000 rows; call flush() to persist the remainder.

        Args:
            data (str): The text content to add.
        '''
        # Calculate token count (character count)
        num_token = self._count_tokens(data)
        data_size = len(data.encode('utf-8'))

        # Check if adding this data would exceed limit
        if self._current_file_size + data_size > self.file_limit and len(self._config['files']) > 0:
            # Pending parquet rows belong to the file being closed.
            self.flush()
            # Start a new file
            self._current_file_idx += 1
            self._current_file_size = 0

        # Get or create current file
        current_filename = self._current_filename()
        current_file_path = os.path.join(self.folder_path, current_filename)

        # Generate UUID for this entry
        entry_uuid = safecode(8)

        # Prepare row data
        row_data = {'content': data, 'uuid': entry_uuid}

        # Append to file
        if self.file_type == FileType.JSONL:
            with open(current_file_path, 'a', encoding='utf-8') as f:
                f.write(json.dumps(row_data, ensure_ascii=False) + '\n')
            self._invalidate_reader(current_file_path)
        elif self.file_type == FileType.CSV:
            file_exists = os.path.exists(current_file_path)
            with open(current_file_path, 'a', encoding='utf-8', newline='') as f:
                writer = csv.DictWriter(f, fieldnames=['content', 'uuid'])
                if not file_exists:
                    writer.writeheader()
                writer.writerow(row_data)
            self._invalidate_reader(current_file_path)
        elif self.file_type == FileType.PARQUET:
            # For parquet, accumulate in memory and write periodically
            self._parquet_buffer.append(row_data)
            # Write to parquet every 1000 rows or when buffer is large
            if len(self._parquet_buffer) >= 1000:
                self._flush_parquet_buffer(current_file_path)

        # Update file info
        self._update_file_info(current_filename, num_token, data_size)

    def flush(self) -> None:
        '''
        Writes any buffered parquet rows to the current file.

        Rows added to a PARQUET dataset are counted in config.json right away
        but only reach disk in batches; this persists the remainder. It is a
        no-op for JSONL/CSV datasets and when the buffer is empty.
        '''
        if self.file_type == FileType.PARQUET and self._parquet_buffer:
            self._flush_parquet_buffer(self._current_file_path())

    def _invalidate_reader(self, file_path: str) -> None:
        '''
        Drops the cached lazy reader for a file that has just been modified.

        Whether FlatDataset snapshots row positions at construction is not
        visible from this module, so the safe choice is to rebuild the reader
        on the next access.
        '''
        self._flat_dataset_cache.pop(file_path, None)

    def _get_file_extension(self) -> str:
        '''
        Gets the file extension based on file_type.

        Returns:
            str: The file extension without the dot.

        Raises:
            ValueError: If file type is not supported.
        '''
        if self.file_type == FileType.PARQUET:
            return 'parquet'
        elif self.file_type == FileType.JSONL:
            return 'jsonl'
        elif self.file_type == FileType.CSV:
            return 'csv'
        else:
            raise ValueError(f'Unsupported file type: {self.file_type}')

    def _flush_parquet_buffer(self, file_path: str) -> None:
        '''
        Flushes accumulated parquet data to file.

        The existing file (if any) is read back, extended and rewritten in
        full, so the cost of a flush grows with the file size.

        Args:
            file_path (str): Path to the parquet file.
        '''
        if not self._parquet_buffer:
            return

        df = pd.DataFrame(self._parquet_buffer)
        if os.path.exists(file_path):
            existing_df = pd.read_parquet(file_path)
            df = pd.concat([existing_df, df], ignore_index=True)
        df.to_parquet(file_path, index=False)
        self._parquet_buffer = []
        self._invalidate_reader(file_path)

    def _update_file_info(self, filename: str, num_token: int, data_size: int) -> None:
        '''
        Updates file information in config.

        Args:
            filename (str): The filename.
            num_token (int): Number of tokens added.
            data_size (int): Size of data in bytes.
        '''
        # Find or create file info
        file_info = next(
            (info for info in self._config['files'] if info['filename'] == filename),
            None
        )

        if file_info is None:
            file_info = {
                'filename': filename,
                'file_type': self.file_type.name,
                'num_rows': 0,
                'num_token': 0,
                'created_at': datetime.now().isoformat()
            }
            self._config['files'].append(file_info)

        file_info['num_rows'] += 1
        file_info['num_token'] += num_token
        self._total_token += num_token
        self._current_file_size += data_size

        # Save config and rebuild index
        self._save_config()
        self._rebuild_index()

    @staticmethod
    def _is_missing(value) -> bool:
        '''
        Tells whether a cell read from a source file holds no value.

        Covers None plus the missing-value markers pandas produces when
        loading CSV/parquet (NaN, pd.NA, pd.NaT).
        '''
        if value is None or value is pd.NA or value is pd.NaT:
            return True
        # NaN is the only float that is not equal to itself.
        return isinstance(value, float) and value != value

    def add_from_file(self, file_path: str, fields: list[str]) -> None:
        '''
        Adds data from a file by reading specified fields and concatenating them.

        Args:
            file_path (str): Path to the source file.
            fields (list[str]): List of field names to read and concatenate.
                If a field doesn't exist, or holds a missing value
                (None/NaN), it is skipped.

        Raises:
            FileNotFoundError: If the file does not exist.
        '''
        if not os.path.exists(file_path):
            raise FileNotFoundError(f'File not found: {file_path}')

        file_type = self._detect_file_type(file_path)
        data = self._load_file_data(file_path, file_type)

        try:
            # Process each row
            for row in data:
                # Concatenate specified fields, leaving out empty cells so the
                # corpus does not receive literal 'nan' / 'None' text.
                content_parts = [
                    str(row[field])
                    for field in fields
                    if field in row and not self._is_missing(row[field])
                ]

                if content_parts:
                    self.add(''.join(content_parts))
        finally:
            # Rows accepted so far are already counted in config.json; make
            # sure the parquet tail reaches disk even if a later row failed.
            self.flush()

    def _get_or_create_flat_dataset(self, file_path: str) -> FlatDataset:
        '''
        Gets or creates a FlatDataset instance for lazy loading.

        This method caches FlatDataset instances to avoid recreating them
        for repeated access to the same file.

        Args:
            file_path (str): Path to the corpus file.

        Returns:
            FlatDataset: A FlatDataset instance for lazy loading the file.
        '''
        if file_path not in self._flat_dataset_cache:
            self._flat_dataset_cache[file_path] = FlatDataset(
                file_path,
                in_memory=False,
                shuffle=False
            )
        return self._flat_dataset_cache[file_path]

    def get(self, key: Union[int, str]) -> CorpusData:
        '''
        Retrieves a corpus data entry by key using lazy loading.

        Args:
            key (Union[int, str]): The key to retrieve. Can be:
                - int: Global row index (0, 1, 2, ...)
                - str: "filename:row_number" format (e.g., "corpus_001.parquet:5")

        Returns:
            CorpusData: The corpus data at the specified key.

        Raises:
            IndexError: If index is out of range.
            KeyError: If string key format is invalid.
        '''
        # Buffered parquet rows are already part of the index; write them out
        # so the reader below can find them.
        self.flush()

        # Handle string key format: "filename:row_number"
        if isinstance(key, str):
            if ':' not in key:
                raise KeyError(f'Invalid key format: {key}. Expected "filename:row_number"')
            filename, row_str = key.rsplit(':', 1)
            try:
                row_in_file = int(row_str)
            except ValueError:
                raise KeyError(f'Invalid row number in key: {key}')

            # Find file by filename using optimized lookup
            if filename not in self._file_name_to_idx:
                raise KeyError(f'File not found: {filename}')

            file_idx = self._file_name_to_idx[filename]
            file_info = self._file_index[file_idx]

            if row_in_file < 0 or row_in_file >= file_info['num_rows']:
                raise IndexError(f'Row {row_in_file} out of range for {filename}')
        else:
            # Handle integer key: global row index
            if key < 0 or key >= self._cumulative_rows[-1]:
                raise IndexError(f'Index {key} out of range')

            # Find which file contains this index: the first file whose
            # cumulative end is beyond the key.
            file_idx = 0
            for i, cumulative_row in enumerate(self._cumulative_rows[1:], 1):
                if key < cumulative_row:
                    file_idx = i - 1
                    break

            row_in_file = key - self._cumulative_rows[file_idx]
            file_info = self._file_index[file_idx]

        # Use FlatDataset for lazy loading
        file_path = os.path.join(self.folder_path, file_info['filename'])
        flat_dataset = self._get_or_create_flat_dataset(file_path)
        # NOTE(review): get_value is assumed to return a dict-like row
        # (it is queried with `in`, .keys() and .get() below) - confirm
        # against FlatDataset.
        row = flat_dataset.get_value(row_in_file)

        # Find content column (try common names)
        content_column = None
        for col_name in ['content', 'text', 'data', 'corpus']:
            if col_name in row:
                content_column = col_name
                break

        if content_column is None:
            # Use first column if no standard name found
            content_column = list(row.keys())[0]

        content = str(row[content_column])
        num_token = self._count_tokens(content)
        uuid = str(row.get('uuid', ''))

        return CorpusData(content=content, num_token=num_token, uuid=uuid)

    def __len__(self) -> int:
        '''
        Returns the total number of entries in the dataset.

        Returns:
            int: The total number of entries.
        '''
        return self._cumulative_rows[-1] if self._cumulative_rows else 0

    def __getitem__(self, key: Union[int, str]) -> CorpusData:
        '''
        Retrieves an entry by key using bracket notation.

        Args:
            key (Union[int, str]): The key to retrieve (int index or str "filename:row").

        Returns:
            CorpusData: The corpus data at the specified key.
        '''
        return self.get(key)
|
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
'''
|
|
2
|
+
Data viewer module for previewing dataset fields.
|
|
3
|
+
|
|
4
|
+
This module provides utilities to inspect the structure and schema
|
|
5
|
+
of various dataset file formats including JSONL, Parquet, and CSV.
|
|
6
|
+
'''
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
import pandas as pd
|
|
12
|
+
import pyarrow.parquet as pq
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class DataViewer:
|
|
16
|
+
'''
|
|
17
|
+
A utility class for previewing and inspecting dataset files.
|
|
18
|
+
|
|
19
|
+
Supports JSONL, Parquet, and CSV file formats. Provides methods
|
|
20
|
+
to view fields, schema, and sample data.
|
|
21
|
+
|
|
22
|
+
Attributes:
|
|
23
|
+
file_path (str): Path to the dataset file.
|
|
24
|
+
file_type (str): Type of the file ('jsonl', 'parquet', 'csv').
|
|
25
|
+
'''
|
|
26
|
+
|
|
27
|
+
def __init__(self, file_path: str) -> None:
|
|
28
|
+
'''
|
|
29
|
+
Initialize the DataViewer with a file path.
|
|
30
|
+
|
|
31
|
+
Args:
|
|
32
|
+
file_path (str): Path to the dataset file.
|
|
33
|
+
|
|
34
|
+
Raises:
|
|
35
|
+
FileNotFoundError: If the file does not exist.
|
|
36
|
+
ValueError: If the file type is not supported.
|
|
37
|
+
'''
|
|
38
|
+
if not os.path.exists(file_path):
|
|
39
|
+
raise FileNotFoundError(f'File not found: {file_path}')
|
|
40
|
+
|
|
41
|
+
self.file_path = file_path
|
|
42
|
+
self.file_type = self._detect_file_type(file_path)
|
|
43
|
+
self._df: pd.DataFrame | None = None
|
|
44
|
+
|
|
45
|
+
def _detect_file_type(self, file_path: str) -> str:
|
|
46
|
+
'''
|
|
47
|
+
Detect file type based on extension.
|
|
48
|
+
|
|
49
|
+
Args:
|
|
50
|
+
file_path (str): Path to the file.
|
|
51
|
+
|
|
52
|
+
Returns:
|
|
53
|
+
str: File type ('jsonl', 'parquet', or 'csv').
|
|
54
|
+
|
|
55
|
+
Raises:
|
|
56
|
+
ValueError: If file extension is not supported.
|
|
57
|
+
'''
|
|
58
|
+
ext = os.path.splitext(file_path)[1].lower()
|
|
59
|
+
type_map = {
|
|
60
|
+
'.jsonl': 'jsonl',
|
|
61
|
+
'.parquet': 'parquet',
|
|
62
|
+
'.csv': 'csv',
|
|
63
|
+
}
|
|
64
|
+
if ext not in type_map:
|
|
65
|
+
raise ValueError(
|
|
66
|
+
f'Unsupported file type: {ext}. '
|
|
67
|
+
f'Supported types: {list(type_map.keys())}'
|
|
68
|
+
)
|
|
69
|
+
return type_map[ext]
|
|
70
|
+
|
|
71
|
+
def _load_data(self, nrows: int | None = None) -> pd.DataFrame:
|
|
72
|
+
'''
|
|
73
|
+
Load data from file into a pandas DataFrame.
|
|
74
|
+
|
|
75
|
+
Args:
|
|
76
|
+
nrows (int | None): Number of rows to load. If None, loads all rows.
|
|
77
|
+
|
|
78
|
+
Returns:
|
|
79
|
+
pd.DataFrame: Loaded data.
|
|
80
|
+
'''
|
|
81
|
+
if self._df is not None and nrows is None:
|
|
82
|
+
return self._df
|
|
83
|
+
|
|
84
|
+
if self.file_type == 'jsonl':
|
|
85
|
+
df = pd.read_json(self.file_path, lines=True, nrows=nrows)
|
|
86
|
+
elif self.file_type == 'parquet':
|
|
87
|
+
parquet_file = pq.ParquetFile(self.file_path)
|
|
88
|
+
if nrows is not None:
|
|
89
|
+
df = parquet_file.read_row_group(0).to_pandas()[:nrows]
|
|
90
|
+
else:
|
|
91
|
+
df = parquet_file.read().to_pandas()
|
|
92
|
+
elif self.file_type == 'csv':
|
|
93
|
+
df = pd.read_csv(self.file_path, nrows=nrows)
|
|
94
|
+
else:
|
|
95
|
+
raise ValueError(f'Unknown file type: {self.file_type}')
|
|
96
|
+
|
|
97
|
+
if nrows is None:
|
|
98
|
+
self._df = df
|
|
99
|
+
return df
|
|
100
|
+
|
|
101
|
+
def get_fields(self) -> list[str]:
|
|
102
|
+
'''
|
|
103
|
+
Get list of field names (columns) in the dataset.
|
|
104
|
+
|
|
105
|
+
Returns:
|
|
106
|
+
list[str]: List of field names.
|
|
107
|
+
'''
|
|
108
|
+
df = self._load_data(nrows=1)
|
|
109
|
+
return list(df.columns)
|
|
110
|
+
|
|
111
|
+
def get_schema(self) -> dict[str, str]:
|
|
112
|
+
'''
|
|
113
|
+
Get schema information with field names and their data types.
|
|
114
|
+
|
|
115
|
+
Returns:
|
|
116
|
+
dict[str, str]: Dictionary mapping field names to data types.
|
|
117
|
+
'''
|
|
118
|
+
df = self._load_data(nrows=1)
|
|
119
|
+
schema = {}
|
|
120
|
+
for col in df.columns:
|
|
121
|
+
dtype = str(df[col].dtype)
|
|
122
|
+
schema[col] = dtype
|
|
123
|
+
return schema
|
|
124
|
+
|
|
125
|
+
def preview(self, nrows: int = 5) -> pd.DataFrame:
|
|
126
|
+
'''
|
|
127
|
+
Preview the first N rows of the dataset.
|
|
128
|
+
|
|
129
|
+
Args:
|
|
130
|
+
nrows (int): Number of rows to preview. Default is 5.
|
|
131
|
+
|
|
132
|
+
Returns:
|
|
133
|
+
pd.DataFrame: First N rows of the dataset.
|
|
134
|
+
'''
|
|
135
|
+
return self._load_data(nrows=nrows).head(nrows)
|
|
136
|
+
|
|
137
|
+
def get_stats(self) -> dict[str, Any]:
|
|
138
|
+
'''
|
|
139
|
+
Get basic statistics about the dataset.
|
|
140
|
+
|
|
141
|
+
Returns:
|
|
142
|
+
dict[str, Any]: Dictionary containing:
|
|
143
|
+
- num_rows: Total number of rows
|
|
144
|
+
- num_columns: Total number of columns
|
|
145
|
+
- file_size: File size in bytes
|
|
146
|
+
- memory_usage: Memory usage in bytes (if data loaded)
|
|
147
|
+
'''
|
|
148
|
+
df = self._load_data()
|
|
149
|
+
file_size = os.path.getsize(self.file_path)
|
|
150
|
+
memory_usage = df.memory_usage(deep=True).sum()
|
|
151
|
+
|
|
152
|
+
return {
|
|
153
|
+
'num_rows': len(df),
|
|
154
|
+
'num_columns': len(df.columns),
|
|
155
|
+
'file_size': file_size,
|
|
156
|
+
'memory_usage': memory_usage,
|
|
157
|
+
}
|
|
158
|
+
|
|
159
|
+
def __repr__(self) -> str:
|
|
160
|
+
'''
|
|
161
|
+
Return string representation of the DataViewer.
|
|
162
|
+
|
|
163
|
+
Returns:
|
|
164
|
+
str: String representation showing file path and type.
|
|
165
|
+
'''
|
|
166
|
+
return f'DataViewer(file={self.file_path}, type={self.file_type})'
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def preview_fields(file_path: str, nrows: int = 5) -> None:
    '''
    Print a three-part report for a dataset file: schema, sample rows and
    basic statistics.

    Args:
        file_path (str): Path to the dataset file.
        nrows (int): Number of rows to show in the sample. Default is 5.
    '''
    inspector = DataViewer(file_path)

    # Header
    print(f'File: {inspector.file_path}')
    print(f'Type: {inspector.file_type}')
    print()

    # Column names with their inferred types
    print('Schema:')
    column_types = inspector.get_schema()
    for column in column_types:
        print(f' {column}: {column_types[column]}')
    print()

    # Sample rows
    print(f'Preview (first {nrows} rows):')
    sample = inspector.preview(nrows)
    print(sample)
    print()

    # Size figures
    print('Statistics:')
    figures = inspector.get_stats()
    for label in figures:
        print(f' {label}: {figures[label]}')
|