codon-model 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codon/__init__.py +5 -0
- codon/base.py +167 -0
- codon/exp/__init__.py +0 -0
- codon/exp/moe.py +307 -0
- codon/model/__init__.py +0 -0
- codon/model/motif/__init__.py +1 -0
- codon/model/motif/motif_a1.py +121 -0
- codon/model/patch_disc.py +151 -0
- codon/model/tcn.py +124 -0
- codon/ops/__init__.py +3 -0
- codon/ops/attention.py +107 -0
- codon/ops/bio.py +0 -0
- codon/utils/__init__.py +0 -0
- codon/utils/dataset/__init__.py +3 -0
- codon/utils/dataset/base.py +46 -0
- codon/utils/dataset/corpus.py +478 -0
- codon/utils/dataset/dataviewer.py +196 -0
- codon/utils/dataset/flatdata.py +455 -0
- codon/utils/mask.py +266 -0
- codon/utils/safecode.py +24 -0
- codon/utils/seed.py +75 -0
- codon/utils/theta.py +55 -0
- codon/utils/token.py +276 -0
- codon_model-0.0.1.dist-info/METADATA +17 -0
- codon_model-0.0.1.dist-info/RECORD +28 -0
- codon_model-0.0.1.dist-info/WHEEL +5 -0
- codon_model-0.0.1.dist-info/licenses/LICENSE +201 -0
- codon_model-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,455 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import random
|
|
3
|
+
import pandas as pd
|
|
4
|
+
import pyarrow.parquet as pq
|
|
5
|
+
|
|
6
|
+
from typing import Any, Dict, Optional, Union, Callable
|
|
7
|
+
|
|
8
|
+
from .base import CodonDataset
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class FlatColumnDataset(CodonDataset):
    '''
    A dataset wrapper that provides access to a specific column of a FlatDataset.

    This class allows treating a single column of a structured dataset (like CSV,
    JSONL, or Parquet) as a standalone dataset. Item ``i`` of this wrapper is the
    column value of item ``i`` of the underlying dataset, i.e. it follows the
    underlying dataset's access order (including shuffling), so
    ``column_ds[i] == dataset[i][column]``.

    Attributes:
        dataset (FlatDataset): The underlying source dataset.
        column (str): The name of the column to access.
    '''

    def __init__(self, source: Union['FlatDataset', str], column: str, **kwargs):
        '''
        Initializes the FlatColumnDataset.

        Args:
            source (Union[FlatDataset, str]): The source dataset instance or a
                file path to create a new FlatDataset.
            column (str): The name of the column to retrieve.
            **kwargs: Additional arguments passed to FlatDataset constructor if
                source is a path.

        Raises:
            TypeError: If source is not a FlatDataset instance or a string.
            KeyError: If the specified column does not exist in the source dataset.
        '''
        if isinstance(source, str):
            self.dataset = FlatDataset(source, **kwargs)
        elif isinstance(source, FlatDataset):
            self.dataset = source
        else:
            raise TypeError("Source must be a FlatDataset instance or a file path string")

        self.column = column

        # Verify column exists
        if self.column not in self.dataset.fields:
            raise KeyError(f"Column '{self.column}' not found in dataset fields: {self.dataset.fields}")

    def __len__(self) -> int:
        '''
        Returns the length of the dataset.

        Returns:
            int: The number of rows in the dataset.
        '''
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Any:
        '''
        Retrieves the value of the specified column at the given index.

        The index is resolved the same way as ``FlatDataset.__getitem__``:
        it is bounds-checked and then mapped through the underlying dataset's
        access order, so shuffling of the source dataset is respected.

        Args:
            idx (int): The index of the row. Any integer-like value (e.g. a
                numpy integer) is accepted; negative values count from the end.

        Returns:
            Any: The value in the configured column at the given index.

        Raises:
            TypeError: If idx is not an integer.
            IndexError: If idx is out of range.
        '''
        import operator

        try:
            position = operator.index(idx)
        except TypeError:
            raise TypeError(f"Index must be int, got {type(idx)}") from None

        length = len(self.dataset)
        if position < 0:
            position += length
        if position < 0 or position >= length:
            raise IndexError("Index out of range")

        # Translate the logical position into the physical row index, exactly
        # like FlatDataset.__getitem__ does, before reading the single column.
        real_idx = self.dataset._indices[position]
        return self.dataset.get_value(real_idx, self.column)

    def to_flat_dataset(self) -> 'MappedFlatDataset':
        '''
        Converts the column data to a FlatDataset.

        This method extracts the column values and creates a new MappedFlatDataset
        that treats each dictionary value as a row. Supports chaining for nested
        dictionary structures.

        Note: the returned dataset reads the parent by physical row index and is
        created without shuffling, so its rows follow the source file order even
        if the underlying dataset was shuffled.

        Returns:
            MappedFlatDataset: A new dataset where each row is the dictionary value
                from the specified column.

        Raises:
            TypeError: If any value in the column is not a dictionary.
        '''
        def extract_column_as_row(row: Dict[str, Any]) -> Dict[str, Any]:
            '''
            Extracts the column value from a row and validates it is a dict.

            Args:
                row (Dict[str, Any]): The source row.

            Returns:
                Dict[str, Any]: The column value (must be a dictionary).

            Raises:
                TypeError: If the column value is not a dictionary.
            '''
            col_value = row[self.column]
            if not isinstance(col_value, dict):
                raise TypeError(
                    f"Column '{self.column}' value must be a dictionary, "
                    f"got {type(col_value).__name__}"
                )
            return col_value

        return MappedFlatDataset(
            parent_dataset=self.dataset,
            map_fn=extract_column_as_row,
            in_memory=self.dataset.in_memory
        )
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class FlatDataset(CodonDataset):
    '''
    A dataset implementation for flat file formats (JSONL, CSV, Parquet).

    This class supports both in-memory loading (for smaller datasets) and
    lazy loading (for larger datasets) to efficiently handle data access.

    Lazy-loading behaviour:
        * JSONL / CSV: the byte offset of every non-blank line is recorded up
          front; each access re-opens the file in binary mode, seeks to that
          offset and reads a single line.
        * CSV lines are split with the standard ``csv`` module, so quoted
          fields containing commas are handled. A quoted field spanning several
          physical lines is NOT supported, and lazily read values are always
          strings (in-memory mode uses ``pandas.read_csv`` type inference).
        * Parquet: the row group holding a row is located by binary search over
          cumulative row-group sizes, and the most recently read row group is
          kept in a one-slot cache so sequential access does not re-read it for
          every row.

    Attributes:
        path (str): The file path to the dataset.
        in_memory (bool): Whether to load the entire dataset into memory.
        shuffle (bool): Whether to shuffle the data indices.
    '''

    def __init__(
        self,
        path: str,
        in_memory: bool = False,
        shuffle: bool = False
    ):
        '''
        Initializes the FlatDataset.

        Args:
            path (str): The file path to the dataset (supports .jsonl, .csv,
                .parquet; the extension check is case-insensitive).
            in_memory (bool): If True, loads all data into memory. If False,
                uses lazy loading (offsets for text files, row groups for Parquet).
                Defaults to False.
            shuffle (bool): If True, shuffles the access indices. Defaults to False.

        Raises:
            ValueError: If the file extension is not supported.
        '''
        self.path = path
        self.in_memory = in_memory
        self.shuffle = shuffle
        self._data = []
        self._offsets = []
        self._indices = []
        self._file_type = self._detect_file_type(path)
        self._length = 0
        self._columns = []

        # Parquet specific
        self._pq_file = None
        self._pq_meta = None
        # Global index of the first row of each row group (non-decreasing).
        self._pq_row_starts = []
        # One-slot cache: ((row_group_index, columns_tuple), pyarrow table).
        self._pq_cache = None

        if self.in_memory:
            self._load_all()
        else:
            self._setup_lazy_loading()

        if self.shuffle:
            self._indices = list(range(self._length))
            random.shuffle(self._indices)
        else:
            self._indices = range(self._length)

    def _detect_file_type(self, path: str) -> str:
        '''
        Detects the file type based on the file extension (case-insensitive).

        Args:
            path (str): The file path.

        Returns:
            str: The detected file type ('jsonl', 'csv', or 'parquet').

        Raises:
            ValueError: If the file extension is not supported.
        '''
        lowered = path.lower()
        if lowered.endswith('.jsonl'):
            return 'jsonl'
        elif lowered.endswith('.csv'):
            return 'csv'
        elif lowered.endswith('.parquet'):
            return 'parquet'
        else:
            raise ValueError(f"Unsupported file type: {path}")

    @staticmethod
    def _parse_csv_line(line: str) -> list[str]:
        '''
        Splits one physical CSV line into fields using the ``csv`` module.

        Args:
            line (str): A decoded line, optionally ending with a line terminator.

        Returns:
            list[str]: The parsed fields (empty list for an empty line).
        '''
        import csv

        return next(csv.reader([line.rstrip('\r\n')]), [])

    def _read_line_at(self, offset: int) -> bytes:
        '''
        Reads the raw bytes of the line starting at byte ``offset``.

        The file is opened in binary mode because the offsets are byte
        positions; text-mode ``seek`` only accepts values returned by ``tell``.

        Args:
            offset (int): Byte offset of the start of the line.

        Returns:
            bytes: The line including its terminator (if any).
        '''
        with open(self.path, 'rb') as f:
            f.seek(offset)
            return f.readline()

    def _load_all(self):
        '''
        Loads the entire dataset into memory.

        Blank JSONL lines are skipped (matching lazy mode), and a leading UTF-8
        BOM is tolerated.
        '''
        if self._file_type == 'jsonl':
            with open(self.path, 'r', encoding='utf-8-sig') as f:
                self._data = [json.loads(line) for line in f if line.strip()]
            if self._data:
                self._columns = list(self._data[0].keys())
        elif self._file_type == 'csv':
            df = pd.read_csv(self.path)
            self._data = df.to_dict('records')
            self._columns = df.columns.tolist()
        elif self._file_type == 'parquet':
            df = pd.read_parquet(self.path)
            self._data = df.to_dict('records')
            self._columns = df.columns.tolist()

        self._length = len(self._data)

    def _setup_lazy_loading(self):
        '''
        Sets up lazy loading by calculating file offsets or metadata.

        Blank lines in JSONL/CSV files are not counted as rows, which keeps the
        lazy length consistent with in-memory loading.
        '''
        if self._file_type == 'jsonl':
            with open(self.path, 'rb') as f:
                offset = 0
                for line in f:
                    if line.strip():
                        self._offsets.append(offset)
                    offset += len(line)
            self._length = len(self._offsets)
            # Peek first record for columns; json.loads accepts bytes (and a BOM).
            if self._length > 0:
                self._columns = list(json.loads(self._read_line_at(self._offsets[0])).keys())

        elif self._file_type == 'csv':
            with open(self.path, 'rb') as f:
                # Read header; 'utf-8-sig' drops a leading BOM if present.
                header_line = f.readline()
                self._columns = self._parse_csv_line(header_line.decode('utf-8-sig'))
                offset = len(header_line)
                for line in f:
                    if line.strip():
                        self._offsets.append(offset)
                    offset += len(line)
            self._length = len(self._offsets)

        elif self._file_type == 'parquet':
            self._pq_file = pq.ParquetFile(self.path)
            self._pq_meta = self._pq_file.metadata
            self._length = self._pq_meta.num_rows
            # Top-level column names (as pandas.read_parquet reports them);
            # ParquetFile.schema.names lists leaf columns of nested structs.
            self._columns = self._pq_file.schema_arrow.names
            starts = []
            total = 0
            for i in range(self._pq_file.num_row_groups):
                starts.append(total)
                total += self._pq_meta.row_group(i).num_rows
            self._pq_row_starts = starts

    @property
    def fields(self) -> list[str]:
        '''
        Returns the list of column names in the dataset.

        Returns:
            list[str]: A list of column names.
        '''
        return self._columns

    def __len__(self) -> int:
        '''
        Returns the number of rows in the dataset.

        Returns:
            int: The total number of rows.
        '''
        return self._length

    def __getitem__(self, idx: Union[int, str]) -> Union[Dict[str, Any], 'FlatColumnDataset']:
        '''
        Retrieves a row by index or a column wrapper by name.

        Args:
            idx (Union[int, str]): The row index (any integer-like value such as
                a numpy integer; negative values count from the end) or a
                column name (str).

        Returns:
            Union[Dict[str, Any], FlatColumnDataset]: If idx is an int, returns
                the row data as a dictionary. If idx is a str, returns a
                FlatColumnDataset wrapper for that column.

        Raises:
            TypeError: If idx is not an int or str.
            IndexError: If the numeric index is out of range.
        '''
        if isinstance(idx, str):
            return FlatColumnDataset(self, idx)

        import operator

        try:
            position = operator.index(idx)
        except TypeError:
            raise TypeError(f"Index must be int or str, got {type(idx)}") from None

        if position < 0:
            position += self._length
        if position < 0 or position >= self._length:
            raise IndexError("Index out of range")

        real_idx = self._indices[position]
        return self.get_value(real_idx)

    def get_value(self, idx: int, column: Optional[str] = None) -> Any:
        '''
        Retrieves the value at the specified index, optionally for a specific column.

        Args:
            idx (int): The real index of the row (after shuffling logic).
            column (Optional[str]): The specific column name to retrieve.
                If None, returns the entire row.

        Returns:
            Any: The value of the column or the full row dictionary.

        Raises:
            IndexError: If idx is out of range for a lazily loaded Parquet file.
            RuntimeError: If an unexpected file type handling path is reached.
        '''
        # Handle in-memory access
        if self.in_memory:
            row = self._data[idx]
            if column is not None:
                return row[column]
            return row

        # Handle lazy loading
        if self._file_type == 'jsonl':
            row = json.loads(self._read_line_at(self._offsets[idx]))
        elif self._file_type == 'csv':
            line = self._read_line_at(self._offsets[idx]).decode('utf-8')
            row = dict(zip(self._columns, self._parse_csv_line(line)))
        elif self._file_type == 'parquet':
            return self._get_parquet_value(idx, column)
        else:
            raise RuntimeError("Should not reach here")

        if column is not None:
            return row[column]
        return row

    def _get_parquet_value(self, idx: int, column: Optional[str]) -> Any:
        '''
        Reads one row (or one cell) from a lazily opened Parquet file.

        Args:
            idx (int): The real row index.
            column (Optional[str]): Column to read, or None for the whole row.

        Returns:
            Any: The cell value, or a dict of all columns for the row.

        Raises:
            IndexError: If idx is out of range.
        '''
        import bisect

        if idx < 0 or idx >= self._length:
            raise IndexError("Index out of range")

        # bisect_right steps past empty row groups that share a start offset.
        row_group_index = bisect.bisect_right(self._pq_row_starts, idx) - 1
        row_in_group = idx - self._pq_row_starts[row_group_index]

        # Optimization: Read specific column only if requested
        cols_to_read = [column] if column is not None else list(self._columns)
        cache_key = (row_group_index, tuple(cols_to_read))
        if self._pq_cache is not None and self._pq_cache[0] == cache_key:
            table = self._pq_cache[1]
        else:
            table = self._pq_file.read_row_group(row_group_index, columns=cols_to_read)
            self._pq_cache = (cache_key, table)

        if column is not None:
            # If single column requested, return the value directly
            return table.column(column)[row_in_group].as_py()
        # Return full row dict
        return {
            col_name: table.column(col_name)[row_in_group].as_py()
            for col_name in self._columns
        }
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
class MappedFlatDataset(FlatDataset):
    '''
    A dataset that applies a mapping function to rows from a parent FlatDataset.

    This class enables lazy transformation of data without materializing all rows
    into memory. It supports chaining transformations for nested data structures.
    Row ``i`` (physical index) of this dataset is ``map_fn(parent.get_value(i))``.

    Attributes:
        parent_dataset (FlatDataset): The source dataset to read from.
        map_fn (Callable): Function that transforms each row.
        in_memory (bool): Whether to cache mapped data in memory.
    '''

    def __init__(
        self,
        parent_dataset: FlatDataset,
        map_fn: Callable[[Dict[str, Any]], Dict[str, Any]],
        in_memory: bool = False,
        shuffle: bool = False
    ):
        '''
        Initializes the MappedFlatDataset.

        FlatDataset.__init__ is deliberately not called: there is no file to
        open, all data comes from ``parent_dataset``.

        Args:
            parent_dataset (FlatDataset): The source dataset to read from.
            map_fn (Callable): Function that transforms each row dictionary.
                Should accept a dict and return a dict.
            in_memory (bool): If True, caches all mapped rows in memory.
                Defaults to False.
            shuffle (bool): If True, shuffles the access indices. Defaults to False.
        '''
        self.parent_dataset = parent_dataset
        self.map_fn = map_fn
        self.in_memory = in_memory
        self.shuffle = shuffle
        self._data = []
        self._indices = []
        self._length = len(parent_dataset)
        self._columns = []
        self._file_type = 'mapped'

        if self.in_memory:
            self._load_all_mapped()

        # Peek first row to determine columns; reuse the cached row when it is
        # already materialized so map_fn is not applied to row 0 twice.
        if self._length > 0:
            first_row = self._data[0] if self.in_memory else self.map_fn(parent_dataset.get_value(0))
            self._columns = list(first_row.keys())

        if self.shuffle:
            self._indices = list(range(self._length))
            random.shuffle(self._indices)
        else:
            self._indices = range(self._length)

    def _load_all_mapped(self):
        '''
        Loads all mapped rows into memory, in physical parent order.
        '''
        self._data = [
            self.map_fn(self.parent_dataset.get_value(i))
            for i in range(self._length)
        ]

    def get_value(self, idx: int, column: Optional[str] = None) -> Any:
        '''
        Retrieves the value at the specified index, optionally for a specific column.

        Args:
            idx (int): The real index of the row.
            column (Optional[str]): The specific column name to retrieve.
                If None, returns the entire row.

        Returns:
            Any: The value of the column or the full row dictionary.
        '''
        if self.in_memory:
            row = self._data[idx]
        else:
            # Lazy loading: fetch from parent and apply mapping
            row = self.map_fn(self.parent_dataset.get_value(idx))

        if column is not None:
            return row[column]
        return row
|