omnigenome 0.3.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of omnigenome might be problematic.

Files changed (85)
  1. omnigenome/__init__.py +281 -0
  2. omnigenome/auto/__init__.py +3 -0
  3. omnigenome/auto/auto_bench/__init__.py +12 -0
  4. omnigenome/auto/auto_bench/auto_bench.py +484 -0
  5. omnigenome/auto/auto_bench/auto_bench_cli.py +230 -0
  6. omnigenome/auto/auto_bench/auto_bench_config.py +216 -0
  7. omnigenome/auto/auto_bench/config_check.py +34 -0
  8. omnigenome/auto/auto_train/__init__.py +13 -0
  9. omnigenome/auto/auto_train/auto_train.py +430 -0
  10. omnigenome/auto/auto_train/auto_train_cli.py +222 -0
  11. omnigenome/auto/bench_hub/__init__.py +12 -0
  12. omnigenome/auto/bench_hub/bench_hub.py +25 -0
  13. omnigenome/cli/__init__.py +13 -0
  14. omnigenome/cli/commands/__init__.py +13 -0
  15. omnigenome/cli/commands/base.py +83 -0
  16. omnigenome/cli/commands/bench/__init__.py +13 -0
  17. omnigenome/cli/commands/bench/bench_cli.py +202 -0
  18. omnigenome/cli/commands/rna/__init__.py +13 -0
  19. omnigenome/cli/commands/rna/rna_design.py +178 -0
  20. omnigenome/cli/omnigenome_cli.py +128 -0
  21. omnigenome/src/__init__.py +12 -0
  22. omnigenome/src/abc/__init__.py +12 -0
  23. omnigenome/src/abc/abstract_dataset.py +622 -0
  24. omnigenome/src/abc/abstract_metric.py +114 -0
  25. omnigenome/src/abc/abstract_model.py +689 -0
  26. omnigenome/src/abc/abstract_tokenizer.py +267 -0
  27. omnigenome/src/dataset/__init__.py +16 -0
  28. omnigenome/src/dataset/omni_dataset.py +435 -0
  29. omnigenome/src/lora/__init__.py +13 -0
  30. omnigenome/src/lora/lora_model.py +294 -0
  31. omnigenome/src/metric/__init__.py +15 -0
  32. omnigenome/src/metric/classification_metric.py +184 -0
  33. omnigenome/src/metric/metric.py +199 -0
  34. omnigenome/src/metric/ranking_metric.py +142 -0
  35. omnigenome/src/metric/regression_metric.py +191 -0
  36. omnigenome/src/misc/__init__.py +3 -0
  37. omnigenome/src/misc/utils.py +439 -0
  38. omnigenome/src/model/__init__.py +19 -0
  39. omnigenome/src/model/augmentation/__init__.py +12 -0
  40. omnigenome/src/model/augmentation/model.py +219 -0
  41. omnigenome/src/model/classification/__init__.py +12 -0
  42. omnigenome/src/model/classification/model.py +642 -0
  43. omnigenome/src/model/embedding/__init__.py +12 -0
  44. omnigenome/src/model/embedding/model.py +263 -0
  45. omnigenome/src/model/mlm/__init__.py +12 -0
  46. omnigenome/src/model/mlm/model.py +177 -0
  47. omnigenome/src/model/module_utils.py +232 -0
  48. omnigenome/src/model/regression/__init__.py +12 -0
  49. omnigenome/src/model/regression/model.py +786 -0
  50. omnigenome/src/model/regression/resnet.py +483 -0
  51. omnigenome/src/model/rna_design/__init__.py +12 -0
  52. omnigenome/src/model/rna_design/model.py +426 -0
  53. omnigenome/src/model/seq2seq/__init__.py +12 -0
  54. omnigenome/src/model/seq2seq/model.py +44 -0
  55. omnigenome/src/tokenizer/__init__.py +16 -0
  56. omnigenome/src/tokenizer/bpe_tokenizer.py +226 -0
  57. omnigenome/src/tokenizer/kmers_tokenizer.py +247 -0
  58. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +249 -0
  59. omnigenome/src/trainer/__init__.py +14 -0
  60. omnigenome/src/trainer/accelerate_trainer.py +739 -0
  61. omnigenome/src/trainer/hf_trainer.py +75 -0
  62. omnigenome/src/trainer/trainer.py +579 -0
  63. omnigenome/utility/__init__.py +3 -0
  64. omnigenome/utility/dataset_hub/__init__.py +13 -0
  65. omnigenome/utility/dataset_hub/dataset_hub.py +178 -0
  66. omnigenome/utility/ensemble.py +324 -0
  67. omnigenome/utility/hub_utils.py +517 -0
  68. omnigenome/utility/model_hub/__init__.py +12 -0
  69. omnigenome/utility/model_hub/model_hub.py +231 -0
  70. omnigenome/utility/pipeline_hub/__init__.py +12 -0
  71. omnigenome/utility/pipeline_hub/pipeline.py +483 -0
  72. omnigenome/utility/pipeline_hub/pipeline_hub.py +129 -0
  73. omnigenome-0.3.0a0.dist-info/METADATA +224 -0
  74. omnigenome-0.3.0a0.dist-info/RECORD +85 -0
  75. omnigenome-0.3.0a0.dist-info/WHEEL +5 -0
  76. omnigenome-0.3.0a0.dist-info/entry_points.txt +3 -0
  77. omnigenome-0.3.0a0.dist-info/licenses/LICENSE +201 -0
  78. omnigenome-0.3.0a0.dist-info/top_level.txt +2 -0
  79. tests/__init__.py +9 -0
  80. tests/conftest.py +160 -0
  81. tests/test_dataset_patterns.py +291 -0
  82. tests/test_examples_syntax.py +83 -0
  83. tests/test_model_loading.py +183 -0
  84. tests/test_rna_functions.py +255 -0
  85. tests/test_training_patterns.py +302 -0
@@ -0,0 +1,622 @@
+ # -*- coding: utf-8 -*-
+ # file: abstract_dataset.py
+ # time: 14:13 06/04/2024
+ # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+ # github: https://github.com/yangheng95
+ # huggingface: https://huggingface.co/yangheng
+ # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+ # Copyright (C) 2019-2024. All Rights Reserved.
+ import random
+ import warnings
+ from collections import Counter
+
+ import numpy as np
+ import torch
+ import tqdm
+
+ from transformers import BatchEncoding
+
+ from ..misc.utils import fprint, env_meta_info, RNA2StructureCache
+
+
+ def covert_input_to_tensor(data):
+     """
+     Recursively converts numerical values in a nested data structure to PyTorch tensors.
+
+     This function traverses through nested dictionaries and lists, converting
+     numerical values to PyTorch tensors while preserving the structure.
+
+     Args:
+         data (list or dict): A list or dictionary containing data samples.
+
+     Returns:
+         list or dict: The data structure with numerical values converted to tensors.
+
+     Example:
+         >>> data = [{'input_ids': [1, 2, 3], 'labels': [0]}]
+         >>> tensor_data = covert_input_to_tensor(data)
+         >>> print(type(tensor_data[0]['input_ids']))  # <class 'torch.Tensor'>
+     """
+     for d in data:
+         if isinstance(d, dict) or isinstance(d, BatchEncoding):
+             for key, value in d.items():
+                 try:
+                     if not isinstance(value, torch.Tensor):
+                         d[key] = torch.tensor(value)
+                 except Exception as e:
+                     pass
+         elif isinstance(d, list):
+             for value in d:
+                 covert_input_to_tensor(value)
+             covert_input_to_tensor(d)
+
+     return data
+
+
+ class OmniGenomeDict(dict):
+     """
+     A dictionary subclass that allows moving all tensor values to a specified device.
+
+     This class extends the standard Python dictionary to provide a convenient
+     method for moving all tensor values to a specific device (CPU/GPU).
+     """
+
+     def __init__(self, *args, **kwargs):
+         super(OmniGenomeDict, self).__init__(*args, **kwargs)
+
+     def to(self, device):
+         """
+         Moves all tensor values in the dictionary to the specified device.
+
+         Args:
+             device (str or torch.device): The target device (e.g., 'cuda:0' or 'cpu').
+
+         Returns:
+             OmniGenomeDict: The dictionary itself, with tensors moved to the new device.
+
+         Example:
+             >>> data = OmniGenomeDict({'input_ids': torch.tensor([1, 2, 3])})
+             >>> data.to('cuda:0')  # Moves tensors to GPU
+         """
+         for key, value in self.items():
+             if isinstance(value, torch.Tensor):
+                 self[key] = value.to(device)
+         return self
+
+
+ class OmniDataset(torch.utils.data.Dataset):
+     """
+     Abstract base class for all datasets in OmniGenome.
+
+     This class provides a unified interface for genomic datasets in the OmniGenome
+     framework. It handles data loading, preprocessing, tokenization, and provides
+     a PyTorch-compatible dataset interface.
+
+     The class supports various data formats and can handle different types of
+     genomic tasks including classification, regression, and token-level tasks.
+
+     Attributes:
+         tokenizer: The tokenizer to use for processing sequences.
+         max_length (int): The maximum sequence length for tokenization.
+         label2id (dict): Mapping from labels to integer IDs.
+         id2label (dict): Mapping from integer IDs to labels.
+         shuffle (bool): Whether to shuffle the data.
+         structure_in (bool): Whether to include secondary structure information.
+         drop_long_seq (bool): Whether to drop sequences longer than max_length.
+         metadata (dict): Metadata about the dataset including version info.
+         rna2structure (RNA2StructureCache): Cache for RNA structure predictions.
+     """
+
+     def __init__(self, data_source, tokenizer, max_length=None, **kwargs):
+         """
+         Initializes the dataset.
+
+         Args:
+             data_source (str or list): Path to the data file or a list of paths.
+             tokenizer: The tokenizer to use for processing sequences.
+             max_length (int, optional): The maximum sequence length.
+             **kwargs: Additional keyword arguments.
+                 - label2id (dict): A mapping from labels to integer IDs.
+                 - shuffle (bool): Whether to shuffle the data. Defaults to True.
+                 - structure_in (bool): Whether to include secondary structure
+                   information. Defaults to False.
+                 - drop_long_seq (bool): Whether to drop sequences longer than
+                   max_length. Defaults to False.
+
+         Example:
+             >>> # Initialize with a single data file
+             >>> dataset = OmniDataset("data.json", tokenizer, max_length=512)
+
+             >>> # Initialize with label mapping
+             >>> dataset = OmniDataset("data.json", tokenizer,
+             ...                       label2id={"A": 0, "B": 1})
+         """
+         super(OmniDataset, self).__init__()
+         self.metadata = env_meta_info()
+         self.tokenizer = tokenizer
+         self.label2id = kwargs.get("label2id", None)
+         self.shuffle = kwargs.get("shuffle", True)
+         self.structure_in = kwargs.get("structure_in", False)
+         self.drop_long_seq = kwargs.get("drop_long_seq", False)
+         if self.structure_in and not hasattr(self, "rna2structure"):
+             self.rna2structure = RNA2StructureCache()
+
+         if self.label2id is not None:
+             self.id2label = {v: k for k, v in self.label2id.items()}
+
+         if max_length is not None:
+             fprint(
+                 f"Detected max_length={max_length} in the dataset, using it as the max_length."
+             )
+             self.max_length = max_length
+         elif (
+             hasattr(self.tokenizer, "max_length")
+             and self.tokenizer.max_length is not None
+         ):
+             fprint(
+                 f"Detected max_length={self.tokenizer.max_length} from the tokenizer."
+             )
+             self.max_length = self.tokenizer.max_length
+         else:
+             fprint(
+                 f"No max_length detected, using default max_length=512."
+             )
+             self.max_length = 512
+
+         self.tokenizer.max_length = self.max_length
+         self.examples = []
+         self.data = []
+
+         if data_source is not None:
+             fprint(f"Loading data from {data_source}...")
+             self.load_data_source(data_source, **kwargs)
+             self._preprocessing()
+
+             for example in tqdm.tqdm(self.examples):
+                 if hasattr(self.tokenizer, "max_length"):
+                     self.tokenizer.max_length = self.max_length
+                 else:
+                     self.tokenizer.base_tokenizer.max_length = self.max_length
+
+                 import inspect
+
+                 new_args = {}
+                 tokenization_args = inspect.getfullargspec(self.tokenizer.encode).args
+                 for key in kwargs:
+                     if key in tokenization_args:
+                         new_args[key] = kwargs[key]
+                 prepared_input = self.prepare_input(example, **new_args)
+
+                 if (
+                     self.drop_long_seq
+                     and len(prepared_input["input_ids"]) > self.max_length
+                 ):
+                     fprint(
+                         f"Dropping sequence {example['sequence']} due to length > {self.max_length}"
+                     )
+                 else:
+                     self.data.append(prepared_input)
+
+             self._postprocessing()
+
+     def print_label_distribution(self):
+         """
+         Print the distribution of labels for 0-dimensional (scalar) labels.
+         This is useful for classification tasks where each sample has a single label.
+         """
+         # Check if we have scalar labels
+         if self.data and "labels" in self.data[0]:
+             first_label = self.data[0]["labels"]
+             if isinstance(first_label.item(), float):
+                 return
+
+             if not isinstance(first_label, torch.Tensor) or first_label.ndim == 0:
+                 # Convert labels to list of integers
+                 labels = [int(d["labels"]) for d in self.data]
+
+                 # Count frequency of each label
+                 label_counts = Counter(labels)
+                 total_samples = len(labels)
+
+                 # Sort by label value
+                 sorted_counts = sorted(label_counts.items())
+
+                 fprint("\nLabel Distribution:")
+                 fprint("-" * 40)
+                 fprint(f"{'Label':<10}\t\t{'Count':<10}\t\t{'Percentage':<10}")
+                 fprint("-" * 40)
+
+                 for label, count in sorted_counts:
+                     percentage = (count / total_samples) * 100
+                     label_name = (
+                         self.id2label[label]
+                         if hasattr(self, "id2label")
+                         else str(label)
+                     )
+                     fprint(f"{label_name:<10}\t\t{count:<10}\t\t{percentage:.2f}%")
+
+                 fprint("-" * 40)
+                 fprint(f"Total samples: {total_samples}")
+             else:
+                 fprint(
+                     "Warning: This method is only for scalar (0-dimensional) labels."
+                 )
+         else:
+             fprint("No labels found in the dataset.")
+
+     def to(self, device):
+         """
+         Moves all tensor data in the dataset to the specified device.
+
+         Args:
+             device (str or torch.device): The target device.
+
+         Returns:
+             OmniDataset: The dataset itself.
+         """
+         for data_item in self.data:
+             for key, value in data_item.items():
+                 if isinstance(value, torch.Tensor):
+                     data_item[key] = value.to(device)
+         return self
+
+     def _pad_and_truncate(self, pad_value=0):
+         """
+         Pads and truncates sequences in the dataset to a uniform length.
+         The length is determined dynamically based on the longest sequence in the batch,
+         up to the `self.max_length` limit, and adjusted to be a multiple of 8.
+
+         Args:
+             pad_value (int, optional): The value to use for padding. Defaults to 0.
+
+         Returns:
+             list: The padded and truncated data.
+         """
+         if hasattr(self.tokenizer, "pad_token_id"):
+             pad_token_id = self.tokenizer.pad_token_id
+         else:
+             pad_token_id = self.tokenizer.base_tokenizer.pad_token_id
+
+         # Compute the maximum input and label lengths
+         max_input_length = max(
+             [
+                 torch.sum(data_item["input_ids"] != pad_token_id).item()
+                 for data_item in self.data
+             ]
+         )
+         max_label_length = max(
+             [
+                 (data_item["labels"].shape[0] if data_item["labels"].ndim >= 1 else 0)
+                 for data_item in self.data
+             ]
+         )
+
+         # Determine the initial max_length, capped at self.max_length
+         original_max_length = max(max_input_length, max_label_length)
+         original_max_length = min(original_max_length, self.max_length)
+
+         # Round up to a multiple of 8 without exceeding self.max_length
+         remainder = original_max_length % 8
+         if remainder != 0:
+             adjusted_max_length = original_max_length + (8 - remainder)
+             adjusted_max_length = min(adjusted_max_length, self.max_length)
+         else:
+             adjusted_max_length = original_max_length
+         max_length = adjusted_max_length
+
+         # Handle the special case for labels (the key part of the bug fix)
+         first_labels = self.data[0]["labels"]
+
+         label_shape = first_labels.shape
+         if len(label_shape) >= 1:
+             label_padding_length = max(max_length, self.data[0]["labels"].shape[0])
+             label_padding_length = min(label_padding_length, max_length)
+             max_length = max(max_length, label_padding_length)
+         else:
+             label_padding_length = 0
+
+         fprint(
+             f"Max sequence length updated -> Reset max_length={max_length},"
+             f" label_padding_length={label_padding_length}"
+         )
+
+         for data_item in self.data:
+             for key, value in data_item.items():
+                 # Make sure the value is a tensor
+                 if not isinstance(value, torch.Tensor):
+                     value = torch.as_tensor(value)
+                 dtype = value.dtype
+                 if "label" in key and (
+                     value.dtype == torch.int16 or value.dtype == torch.int32
+                 ):
+                     data_item[key] = value.long()
+                 # Determine the padding length
+                 if "label" in key:
+                     if value.ndim == 0:  # Scalar labels need no padding
+                         padding_length = 0
+                     else:
+                         padding_length = label_padding_length - value.size(0)
+                 else:
+                     padding_length = max_length - value.size(0)
+
+                 # Pad or truncate as needed
+                 if padding_length > 0:
+                     # Choose the padding value
+                     if key == "input_ids":
+                         _pad_value = pad_token_id
+                     elif key == "attention_mask":
+                         _pad_value = 0
+                     elif "ids" in key:
+                         _pad_value = 0
+                     elif "label" in key:
+                         _pad_value = -100
+                     elif "ids" in key:
+                         _pad_value = pad_token_id
+                     else:
+                         _pad_value = pad_value
+
+                     # Build the padding tensor
+                     if value.ndim == 2:
+                         pad_shape = (padding_length, value.size(1))
+                     else:
+                         pad_shape = (padding_length,)
+                     pad_tensor = torch.full(pad_shape, _pad_value, dtype=dtype)
+                     data_item[key] = torch.cat([value, pad_tensor], dim=0)
+                 elif padding_length < 0:
+                     data_item[key] = value[:max_length]
+
+                 # Keep the original dtype
+                 data_item[key] = data_item[key].to(dtype)
+
+         return self.data
+
+     def load_data_source(self, data_source, **kwargs):
+         """
+         Loads data from a file or list of files.
+
+         Args:
+             data_source (str or list): Path to the data file or a list of paths.
+             **kwargs: Additional keyword arguments, e.g., `max_examples`.
+
+         Returns:
+             list: A list of examples.
+         """
+         examples = []
+         max_examples = kwargs.get("max_examples", None)
+         if not isinstance(data_source, list):
+             data_source = [data_source]
+
+         for data_source in data_source:
+             if data_source.endswith(".csv"):
+                 import pandas as pd
+
+                 df = pd.read_csv(data_source)
+                 for i in range(len(df)):
+                     examples.append(df.iloc[i].to_dict())
+             elif data_source.endswith(".json"):
+                 import json
+
+                 try:
+                     with open(data_source, "r", encoding="utf8") as f:
+                         examples = json.load(f)
+                 except:
+                     with open(data_source, "r", encoding="utf8") as f:
+                         lines = f.readlines()  # Assume the data is a list of examples
+                     for i in range(len(lines)):
+                         lines[i] = json.loads(lines[i])
+                     for line in lines:
+                         examples.append(line)
+             elif data_source.endswith(".parquet"):
+                 import pandas as pd
+
+                 df = pd.read_parquet(data_source)
+                 for i in range(len(df)):
+                     examples.append(df.iloc[i].to_dict())
+             elif data_source.endswith(".txt") or data_source.endswith(".dat"):
+                 with open(data_source, "r", encoding="utf8") as f:
+                     lines = f.readlines()
+                 for line in lines:
+                     examples.append({"text": line.strip()})
+             elif data_source.endswith(('.fasta', '.fa', '.fna', '.ffn', '.faa', '.frn')):
+                 try:
+                     from Bio import SeqIO
+                 except ImportError:
+                     raise ImportError("Biopython is required for FASTA parsing. Please install with 'pip install biopython'.")
+                 for record in SeqIO.parse(data_source, "fasta"):
+                     examples.append({"id": record.id, "sequence": str(record.seq), "description": record.description})
+             elif data_source.endswith(('.fastq', '.fq')):
+                 try:
+                     from Bio import SeqIO
+                 except ImportError:
+                     raise ImportError("Biopython is required for FASTQ parsing. Please install with 'pip install biopython'.")
+                 for record in SeqIO.parse(data_source, "fastq"):
+                     examples.append({"id": record.id, "sequence": str(record.seq), "quality": record.letter_annotations.get("phred_quality", [])})
+             elif data_source.endswith('.bed'):
+                 import pandas as pd
+                 df = pd.read_csv(data_source, sep='\t', comment='#')
+                 # Assign column names for standard BED fields
+                 for _, row in df.iterrows():
+                     examples.append(row.to_dict())
+             else:
+                 raise Exception("Unknown file format.")
+
+         fprint(f"Loaded {len(examples)} examples from {data_source}")
+
+         if self.shuffle is True:
+             fprint("Detected shuffle=True, shuffling the examples...")
+             random.shuffle(examples)
+
+         if max_examples is not None:
+             fprint(f"Detected max_examples={max_examples}, truncating the examples...")
+             examples = examples[:max_examples]
+
+         self.examples = examples
+         return examples
+
+     def prepare_input(self, instance, **kwargs):
+         """
+         Prepares a single data instance for the model. Must be implemented by subclasses.
+
+         Args:
+             instance (dict): A single data instance (e.g., a dictionary).
+             **kwargs: Additional keyword arguments for tokenization.
+
+         Returns:
+             dict: A dictionary of tokenized inputs.
+         """
+         raise NotImplementedError(
+             "The prepare_input() function should be implemented for your dataset."
+         )
+
+     def _preprocessing(self):
+         """
+         Performs preprocessing on the loaded examples.
+         This method standardizes the 'sequence' field and adds secondary structure
+         information if `structure_in` is True.
+         """
+         for idx, ex in enumerate(self.examples):
+             if (
+                 "seq" in self.examples[idx]
+             ):  # For the RNA or DNA stored in the "seq" field
+                 self.examples[idx]["sequence"] = self.examples[idx]["seq"]
+                 del self.examples[idx]["seq"]
+             if (
+                 "text" in self.examples[idx]
+             ):  # For the RNA or DNA stored in the "text" field
+                 self.examples[idx]["sequence"] = self.examples[idx]["text"]
+                 del self.examples[idx]["text"]
+
+             if "sequence" not in self.examples[idx]:
+                 warnings.warn("The 'sequence' field is missing in the raw dataset.")
+         if "sequence" in self.examples[0]:
+             sequences = [ex["sequence"] for ex in self.examples]
+             if self.structure_in:
+                 structures = self.rna2structure.fold(sequences)
+                 for idx, (sequence, structure) in enumerate(zip(sequences, structures)):
+                     self.examples[idx][
+                         "sequence"
+                     ] = f"{sequence}{self.tokenizer.eos_token}{structure}"
+
+
+     def _postprocessing(self):
+         """
+         Performs postprocessing on the tokenized data.
+         This method standardizes the 'labels' field and prints the label distribution
+         for classification tasks.
+         """
+         for idx, ex in enumerate(self.data):
+             if "label" in self.data[idx]:
+                 self.data[idx]["labels"] = self.data[idx]["label"]
+                 # del self.data[idx]["label"]
+             # assert (
+             #     "labels" in self.data[idx]
+             # ), "The 'labels' field is required in the tokenized dataset."
+             if "labels" not in self.data[idx].data or self.data[idx]["labels"] is None:
+                 self.data[idx]["labels"] = torch.tensor([-100])
+
+         if self.data[0]["labels"].dim() == 0:
+             self.print_label_distribution()
+
+     def __len__(self):
+         """
+         Returns the number of samples in the dataset.
+
+         Returns:
+             int: The number of samples in the dataset.
+         """
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         """
+         Returns a single data sample at the given index.
+
+         Args:
+             idx (int): The index of the sample.
+
+         Returns:
+             OmniGenomeDict: An `OmniGenomeDict` containing the data sample.
+         """
+         # convert the data item to an OmniGenomeDict
+         return OmniGenomeDict(self.data[idx])
+
+     def sample(self, n=1):
+         """
+         Returns a random sample of n items from the dataset.
+
+         Args:
+             n (int): The number of samples to return.
+
+         Returns:
+             list: A list of data samples.
+         """
+         return random.sample(self.data, n)
+
+     def get_column(self, column_name):
+         """
+         Returns all values for a specific column in the dataset.
+
+         Args:
+             column_name (str): The name of the column.
+
+         Returns:
+             list: A list of values from the specified column.
+         """
+         return [data_item[column_name] for data_item in self.data]
+
+     def get_labels(self):
+         """
+         Returns the set of unique labels in the dataset.
+
+         Returns:
+             set: The set of unique labels.
+         """
+         return set(self.get_column("labels"))
+
+     def get_inputs_length(self):
+         """
+         Calculates and returns statistics about sequence and label lengths.
+
+         Returns:
+             dict: A dictionary with length statistics (min, max, avg).
+         """
+         if hasattr(self.tokenizer, "pad_token_id"):
+             pad_token_id = self.tokenizer.pad_token_id
+         else:
+             pad_token_id = self.tokenizer.base_tokenizer.pad_token_id
+         length = {}
+         all_seq_lengths = [
+             torch.sum(data_item["input_ids"] != pad_token_id) for data_item in self.data
+         ]
+         all_label_lengths = [
+             data_item["labels"].shape[0] if data_item["labels"].shape else 1
+             for data_item in self.data
+         ]
+         length["avg_seq_len"] = np.mean(all_seq_lengths)
+         length["max_seq_len"] = np.max(all_seq_lengths)
+         length["min_seq_len"] = np.min(all_seq_lengths)
+         length["avg_label_len"] = np.mean(all_label_lengths)
+         length["max_label_len"] = np.max(all_label_lengths)
+         length["min_label_len"] = np.min(all_label_lengths)
+         return length
+
+     def _max_labels_length(self):
+         """
+         Returns the maximum length of labels in the dataset.
+
+         Returns:
+             int: The maximum length of labels.
+         """
+         if self.data[0]["labels"].dim() > 0:
+             return max([len(ex["labels"]) for ex in self.data])
+         else:
+             return 1
+
+     def __iter__(self):
+         """
+         Returns an iterator over the dataset.
+
+         Returns:
+             iterator: An iterator over the dataset.
+         """
+         for data_item in self.data:
+             yield OmniGenomeDict(data_item)
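
For readers evaluating this module, the sketch below (not part of the wheel) illustrates one way a task-specific dataset might subclass OmniDataset: only prepare_input must be implemented. Because _postprocessing reads a .data attribute on each prepared item, the sketch returns the tokenizer's BatchEncoding rather than a plain dict. The data file name, label names, and model id are illustrative assumptions.

# Illustrative sketch only -- not shipped in omnigenome 0.3.0a0.
# Assumes a Hugging Face tokenizer and a hypothetical JSON file of
# {"sequence": ..., "label": ...} records.
import torch
from transformers import AutoTokenizer, BatchEncoding

from omnigenome.src.abc.abstract_dataset import OmniDataset


class ToyClassificationDataset(OmniDataset):
    def prepare_input(self, instance, **kwargs):
        # Tokenize one raw sequence; the length cap comes from self.max_length.
        inputs = self.tokenizer(
            instance["sequence"],
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        # Drop the batch dimension but keep a BatchEncoding, whose .data
        # attribute is later read by OmniDataset._postprocessing().
        inputs = BatchEncoding({k: v.squeeze(0) for k, v in inputs.items()})
        label = instance.get("label")
        if label is not None:
            # label2id is the optional mapping passed to OmniDataset.__init__.
            inputs["labels"] = torch.tensor(
                self.label2id[label] if self.label2id else int(label)
            )
        return inputs


tokenizer = AutoTokenizer.from_pretrained("yangheng/OmniGenome-52M")  # placeholder model id
dataset = ToyClassificationDataset(
    "toy_rna.json",                      # hypothetical data file
    tokenizer,
    max_length=128,
    label2id={"negative": 0, "positive": 1},
)
sample = dataset[0]                      # an OmniGenomeDict
print(sample["input_ids"].shape, sample["labels"])
sample.to("cpu")                         # OmniGenomeDict.to() moves tensors in place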
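
A side note on _pad_and_truncate above: the batch length it settles on is the longest unpadded input (or label) length, capped at self.max_length and then rounded up to a multiple of 8 without exceeding that cap. A standalone restatement of that arithmetic, for illustration only:

def rounded_batch_length(longest: int, max_length: int) -> int:
    """Mirrors the length rule in OmniDataset._pad_and_truncate (illustrative only)."""
    length = min(longest, max_length)        # cap at the configured max_length
    remainder = length % 8
    if remainder != 0:                       # round up to the next multiple of 8 ...
        length = min(length + (8 - remainder), max_length)  # ... but never past the cap
    return length


assert rounded_batch_length(101, 512) == 104   # 101 -> 104, the next multiple of 8
assert rounded_batch_length(510, 512) == 512   # rounding is clipped at max_length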