omnigenome 0.3.0a0__py3-none-any.whl → 0.3.1a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- omnigenome/__init__.py +29 -44
- omnigenome/auto/auto_bench/__init__.py +0 -1
- omnigenome/auto/auto_bench/auto_bench.py +24 -14
- omnigenome/auto/auto_train/__init__.py +0 -1
- omnigenome/auto/auto_train/auto_train.py +11 -12
- omnigenome/auto/bench_hub/__init__.py +0 -1
- omnigenome/auto/bench_hub/bench_hub.py +1 -1
- omnigenome/cli/__init__.py +0 -1
- omnigenome/cli/commands/__init__.py +0 -1
- omnigenome/cli/commands/base.py +10 -10
- omnigenome/cli/commands/bench/__init__.py +0 -1
- omnigenome/cli/commands/bench/bench_cli.py +10 -10
- omnigenome/cli/commands/rna/__init__.py +0 -1
- omnigenome/cli/commands/rna/rna_design.py +10 -11
- omnigenome/src/__init__.py +0 -1
- omnigenome/src/abc/__init__.py +0 -1
- omnigenome/src/abc/abstract_dataset.py +38 -19
- omnigenome/src/abc/abstract_metric.py +7 -7
- omnigenome/src/abc/abstract_model.py +15 -14
- omnigenome/src/abc/abstract_tokenizer.py +9 -7
- omnigenome/src/dataset/omni_dataset.py +16 -14
- omnigenome/src/lora/__init__.py +0 -1
- omnigenome/src/lora/lora_model.py +47 -41
- omnigenome/src/metric/classification_metric.py +11 -11
- omnigenome/src/metric/metric.py +19 -19
- omnigenome/src/metric/ranking_metric.py +15 -15
- omnigenome/src/metric/regression_metric.py +18 -18
- omnigenome/src/misc/utils.py +214 -150
- omnigenome/src/model/augmentation/__init__.py +0 -1
- omnigenome/src/model/augmentation/model.py +17 -17
- omnigenome/src/model/classification/__init__.py +0 -1
- omnigenome/src/model/classification/model.py +28 -32
- omnigenome/src/model/embedding/__init__.py +0 -1
- omnigenome/src/model/embedding/model.py +35 -35
- omnigenome/src/model/mlm/__init__.py +0 -1
- omnigenome/src/model/mlm/model.py +13 -13
- omnigenome/src/model/module_utils.py +17 -17
- omnigenome/src/model/regression/__init__.py +0 -1
- omnigenome/src/model/regression/model.py +72 -77
- omnigenome/src/model/regression/resnet.py +32 -32
- omnigenome/src/model/rna_design/__init__.py +0 -1
- omnigenome/src/model/rna_design/model.py +168 -118
- omnigenome/src/model/seq2seq/__init__.py +0 -1
- omnigenome/src/model/seq2seq/model.py +4 -4
- omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
- omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
- omnigenome/src/trainer/accelerate_trainer.py +40 -32
- omnigenome/src/trainer/hf_trainer.py +8 -8
- omnigenome/src/trainer/trainer.py +37 -25
- omnigenome/utility/dataset_hub/__init__.py +0 -1
- omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
- omnigenome/utility/ensemble.py +26 -26
- omnigenome/utility/hub_utils.py +8 -8
- omnigenome/utility/model_hub/__init__.py +0 -1
- omnigenome/utility/model_hub/model_hub.py +26 -25
- omnigenome/utility/pipeline_hub/__init__.py +0 -1
- omnigenome/utility/pipeline_hub/pipeline.py +49 -49
- omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +3 -3
- omnigenome-0.3.1a0.dist-info/RECORD +78 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -1
- omnigenome-0.3.0a0.dist-info/RECORD +0 -85
- tests/__init__.py +0 -9
- tests/conftest.py +0 -160
- tests/test_dataset_patterns.py +0 -291
- tests/test_examples_syntax.py +0 -83
- tests/test_model_loading.py +0 -183
- tests/test_rna_functions.py +0 -255
- tests/test_training_patterns.py +0 -302
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
|
@@ -56,7 +56,7 @@ def covert_input_to_tensor(data):
|
|
|
56
56
|
class OmniGenomeDict(dict):
|
|
57
57
|
"""
|
|
58
58
|
A dictionary subclass that allows moving all tensor values to a specified device.
|
|
59
|
-
|
|
59
|
+
|
|
60
60
|
This class extends the standard Python dictionary to provide a convenient
|
|
61
61
|
method for moving all tensor values to a specific device (CPU/GPU).
|
|
62
62
|
"""
|
|
@@ -87,14 +87,14 @@ class OmniGenomeDict(dict):
|
|
|
87
87
|
class OmniDataset(torch.utils.data.Dataset):
|
|
88
88
|
"""
|
|
89
89
|
Abstract base class for all datasets in OmniGenome.
|
|
90
|
-
|
|
90
|
+
|
|
91
91
|
This class provides a unified interface for genomic datasets in the OmniGenome
|
|
92
92
|
framework. It handles data loading, preprocessing, tokenization, and provides
|
|
93
93
|
a PyTorch-compatible dataset interface.
|
|
94
|
-
|
|
94
|
+
|
|
95
95
|
The class supports various data formats and can handle different types of
|
|
96
96
|
genomic tasks including classification, regression, and token-level tasks.
|
|
97
|
-
|
|
97
|
+
|
|
98
98
|
Attributes:
|
|
99
99
|
tokenizer: The tokenizer to use for processing sequences.
|
|
100
100
|
max_length (int): The maximum sequence length for tokenization.
|
|
@@ -118,17 +118,17 @@ class OmniDataset(torch.utils.data.Dataset):
|
|
|
118
118
|
**kwargs: Additional keyword arguments.
|
|
119
119
|
- label2id (dict): A mapping from labels to integer IDs.
|
|
120
120
|
- shuffle (bool): Whether to shuffle the data. Defaults to True.
|
|
121
|
-
- structure_in (bool): Whether to include secondary structure
|
|
121
|
+
- structure_in (bool): Whether to include secondary structure
|
|
122
122
|
information. Defaults to False.
|
|
123
|
-
- drop_long_seq (bool): Whether to drop sequences longer than
|
|
123
|
+
- drop_long_seq (bool): Whether to drop sequences longer than
|
|
124
124
|
max_length. Defaults to False.
|
|
125
125
|
|
|
126
126
|
Example:
|
|
127
127
|
>>> # Initialize with a single data file
|
|
128
128
|
>>> dataset = OmniDataset("data.json", tokenizer, max_length=512)
|
|
129
|
-
|
|
129
|
+
|
|
130
130
|
>>> # Initialize with label mapping
|
|
131
|
-
>>> dataset = OmniDataset("data.json", tokenizer,
|
|
131
|
+
>>> dataset = OmniDataset("data.json", tokenizer,
|
|
132
132
|
... label2id={"A": 0, "B": 1})
|
|
133
133
|
"""
|
|
134
134
|
super(OmniDataset, self).__init__()
|
|
@@ -158,9 +158,7 @@ class OmniDataset(torch.utils.data.Dataset):
|
|
|
158
158
|
)
|
|
159
159
|
self.max_length = self.tokenizer.max_length
|
|
160
160
|
else:
|
|
161
|
-
fprint(
|
|
162
|
-
f"No max_length detected, using default max_length=512."
|
|
163
|
-
)
|
|
161
|
+
fprint(f"No max_length detected, using default max_length=512.")
|
|
164
162
|
self.max_length = 512
|
|
165
163
|
|
|
166
164
|
self.tokenizer.max_length = self.max_length
|
|
@@ -417,23 +415,44 @@ class OmniDataset(torch.utils.data.Dataset):
|
|
|
417
415
|
lines = f.readlines()
|
|
418
416
|
for line in lines:
|
|
419
417
|
examples.append({"text": line.strip()})
|
|
420
|
-
elif data_source.endswith(
|
|
418
|
+
elif data_source.endswith(
|
|
419
|
+
(".fasta", ".fa", ".fna", ".ffn", ".faa", ".frn")
|
|
420
|
+
):
|
|
421
421
|
try:
|
|
422
422
|
from Bio import SeqIO
|
|
423
423
|
except ImportError:
|
|
424
|
-
raise ImportError(
|
|
424
|
+
raise ImportError(
|
|
425
|
+
"Biopython is required for FASTA parsing. Please install with 'pip install biopython'."
|
|
426
|
+
)
|
|
425
427
|
for record in SeqIO.parse(data_source, "fasta"):
|
|
426
|
-
examples.append(
|
|
427
|
-
|
|
428
|
+
examples.append(
|
|
429
|
+
{
|
|
430
|
+
"id": record.id,
|
|
431
|
+
"sequence": str(record.seq),
|
|
432
|
+
"description": record.description,
|
|
433
|
+
}
|
|
434
|
+
)
|
|
435
|
+
elif data_source.endswith((".fastq", ".fq")):
|
|
428
436
|
try:
|
|
429
437
|
from Bio import SeqIO
|
|
430
438
|
except ImportError:
|
|
431
|
-
raise ImportError(
|
|
439
|
+
raise ImportError(
|
|
440
|
+
"Biopython is required for FASTQ parsing. Please install with 'pip install biopython'."
|
|
441
|
+
)
|
|
432
442
|
for record in SeqIO.parse(data_source, "fastq"):
|
|
433
|
-
examples.append(
|
|
434
|
-
|
|
443
|
+
examples.append(
|
|
444
|
+
{
|
|
445
|
+
"id": record.id,
|
|
446
|
+
"sequence": str(record.seq),
|
|
447
|
+
"quality": record.letter_annotations.get(
|
|
448
|
+
"phred_quality", []
|
|
449
|
+
),
|
|
450
|
+
}
|
|
451
|
+
)
|
|
452
|
+
elif data_source.endswith(".bed"):
|
|
435
453
|
import pandas as pd
|
|
436
|
-
|
|
454
|
+
|
|
455
|
+
df = pd.read_csv(data_source, sep="\t", comment="#")
|
|
437
456
|
# Assign column names for standard BED fields
|
|
438
457
|
for _, row in df.iterrows():
|
|
439
458
|
examples.append(row.to_dict())
|
|
@@ -15,17 +15,17 @@ from ..misc.utils import env_meta_info
|
|
|
15
15
|
class OmniMetric:
|
|
16
16
|
"""
|
|
17
17
|
Abstract base class for all metrics in OmniGenome, based on scikit-learn.
|
|
18
|
-
|
|
18
|
+
|
|
19
19
|
This class provides a unified interface for evaluation metrics in the OmniGenome
|
|
20
20
|
framework. It integrates with scikit-learn's metric functions and provides
|
|
21
21
|
additional functionality for handling genomic data evaluation.
|
|
22
|
-
|
|
22
|
+
|
|
23
23
|
The class automatically exposes all scikit-learn metrics as attributes,
|
|
24
24
|
making them easily accessible for evaluation tasks.
|
|
25
|
-
|
|
25
|
+
|
|
26
26
|
Attributes:
|
|
27
27
|
metric_func (callable): A callable metric function from `sklearn.metrics`.
|
|
28
|
-
ignore_y (any): A value in the ground truth labels to be ignored during
|
|
28
|
+
ignore_y (any): A value in the ground truth labels to be ignored during
|
|
29
29
|
metric computation.
|
|
30
30
|
metadata (dict): Metadata about the metric including version info.
|
|
31
31
|
"""
|
|
@@ -35,10 +35,10 @@ class OmniMetric:
|
|
|
35
35
|
Initializes the metric.
|
|
36
36
|
|
|
37
37
|
Args:
|
|
38
|
-
metric_func (callable, optional): A callable metric function from
|
|
38
|
+
metric_func (callable, optional): A callable metric function from
|
|
39
39
|
`sklearn.metrics`. If None, subclasses
|
|
40
40
|
should implement their own compute method.
|
|
41
|
-
ignore_y (any, optional): A value in the ground truth labels to be
|
|
41
|
+
ignore_y (any, optional): A value in the ground truth labels to be
|
|
42
42
|
ignored during metric computation.
|
|
43
43
|
*args: Additional positional arguments.
|
|
44
44
|
**kwargs: Additional keyword arguments.
|
|
@@ -46,7 +46,7 @@ class OmniMetric:
|
|
|
46
46
|
Example:
|
|
47
47
|
>>> # Initialize with a specific metric function
|
|
48
48
|
>>> metric = OmniMetric(metrics.accuracy_score)
|
|
49
|
-
|
|
49
|
+
|
|
50
50
|
>>> # Initialize with ignore value
|
|
51
51
|
>>> metric = OmniMetric(ignore_y=-100)
|
|
52
52
|
"""
|
|
@@ -47,14 +47,14 @@ def count_parameters(model):
|
|
|
47
47
|
class OmniModel(torch.nn.Module):
|
|
48
48
|
"""
|
|
49
49
|
Abstract base class for all models in OmniGenome.
|
|
50
|
-
|
|
50
|
+
|
|
51
51
|
This class provides a unified interface for all genomic models in the OmniGenome
|
|
52
52
|
framework. It handles model initialization, forward passes, loss computation,
|
|
53
53
|
prediction, inference, and model persistence.
|
|
54
|
-
|
|
54
|
+
|
|
55
55
|
The class is designed to work with various types of genomic data and tasks,
|
|
56
56
|
including sequence classification, token classification, regression, and more.
|
|
57
|
-
|
|
57
|
+
|
|
58
58
|
Attributes:
|
|
59
59
|
model (torch.nn.Module): The underlying PyTorch model.
|
|
60
60
|
config: The model configuration.
|
|
@@ -76,16 +76,16 @@ class OmniModel(torch.nn.Module):
|
|
|
76
76
|
- From a configuration object
|
|
77
77
|
|
|
78
78
|
Args:
|
|
79
|
-
config_or_model: A model configuration, a pre-trained model path (str),
|
|
79
|
+
config_or_model: A model configuration, a pre-trained model path (str),
|
|
80
80
|
or a `torch.nn.Module` instance.
|
|
81
81
|
tokenizer: The tokenizer associated with the model.
|
|
82
82
|
*args: Additional positional arguments.
|
|
83
83
|
**kwargs: Additional keyword arguments.
|
|
84
84
|
- label2id (dict): Mapping from class labels to IDs.
|
|
85
85
|
- num_labels (int): The number of labels.
|
|
86
|
-
- trust_remote_code (bool): Whether to trust remote code when loading
|
|
86
|
+
- trust_remote_code (bool): Whether to trust remote code when loading
|
|
87
87
|
from Hugging Face Hub. Defaults to True.
|
|
88
|
-
- ignore_mismatched_sizes (bool): Whether to ignore size mismatches
|
|
88
|
+
- ignore_mismatched_sizes (bool): Whether to ignore size mismatches
|
|
89
89
|
when loading pre-trained weights. Defaults to False.
|
|
90
90
|
- dropout (float): Dropout rate. Defaults to 0.0.
|
|
91
91
|
|
|
@@ -97,7 +97,7 @@ class OmniModel(torch.nn.Module):
|
|
|
97
97
|
Example:
|
|
98
98
|
>>> # Initialize from a pre-trained model
|
|
99
99
|
>>> model = OmniModelForSequenceClassification("model_path", tokenizer)
|
|
100
|
-
|
|
100
|
+
|
|
101
101
|
>>> # Initialize from a configuration
|
|
102
102
|
>>> config = AutoConfig.from_pretrained("model_path")
|
|
103
103
|
>>> model = OmniModelForSequenceClassification(config, tokenizer)
|
|
@@ -202,7 +202,9 @@ class OmniModel(torch.nn.Module):
|
|
|
202
202
|
)
|
|
203
203
|
self.config.num_labels = len(self.config.id2label)
|
|
204
204
|
|
|
205
|
-
assert
|
|
205
|
+
assert (
|
|
206
|
+
len(self.config.label2id) == num_labels
|
|
207
|
+
), f"Expected {num_labels} labels, but got {len(self.config.label2id)} in label2id dictionary."
|
|
206
208
|
|
|
207
209
|
# The metadata of the model
|
|
208
210
|
self.metadata = env_meta_info()
|
|
@@ -240,7 +242,7 @@ class OmniModel(torch.nn.Module):
|
|
|
240
242
|
model architectures by mapping input parameters appropriately.
|
|
241
243
|
|
|
242
244
|
Args:
|
|
243
|
-
**inputs: The inputs to the model, compatible with the base model's
|
|
245
|
+
**inputs: The inputs to the model, compatible with the base model's
|
|
244
246
|
forward method. Typically includes 'input_ids', 'attention_mask',
|
|
245
247
|
and other model-specific parameters.
|
|
246
248
|
|
|
@@ -386,7 +388,7 @@ class OmniModel(torch.nn.Module):
|
|
|
386
388
|
predictions for further processing.
|
|
387
389
|
|
|
388
390
|
Args:
|
|
389
|
-
sequence_or_inputs: A sequence (str), list of sequences, or
|
|
391
|
+
sequence_or_inputs: A sequence (str), list of sequences, or
|
|
390
392
|
tokenized inputs (dict/tuple).
|
|
391
393
|
**kwargs: Additional arguments for tokenization and inference.
|
|
392
394
|
|
|
@@ -398,7 +400,7 @@ class OmniModel(torch.nn.Module):
|
|
|
398
400
|
Example:
|
|
399
401
|
>>> # Predict on a single sequence
|
|
400
402
|
>>> outputs = model.predict("ATCGATCG")
|
|
401
|
-
|
|
403
|
+
|
|
402
404
|
>>> # Predict on multiple sequences
|
|
403
405
|
>>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
|
|
404
406
|
"""
|
|
@@ -416,7 +418,7 @@ class OmniModel(torch.nn.Module):
|
|
|
416
418
|
to class labels or probabilities.
|
|
417
419
|
|
|
418
420
|
Args:
|
|
419
|
-
sequence_or_inputs: A sequence (str), list of sequences, or
|
|
421
|
+
sequence_or_inputs: A sequence (str), list of sequences, or
|
|
420
422
|
tokenized inputs (dict/tuple).
|
|
421
423
|
**kwargs: Additional arguments for tokenization and inference.
|
|
422
424
|
|
|
@@ -429,7 +431,7 @@ class OmniModel(torch.nn.Module):
|
|
|
429
431
|
>>> # Inference on a single sequence
|
|
430
432
|
>>> results = model.inference("ATCGATCG")
|
|
431
433
|
>>> print(results['predictions']) # Class labels
|
|
432
|
-
|
|
434
|
+
|
|
433
435
|
>>> # Inference on multiple sequences
|
|
434
436
|
>>> results = model.inference(["ATCGATCG", "GCTAGCTA"])
|
|
435
437
|
"""
|
|
@@ -686,4 +688,3 @@ class OmniModel(torch.nn.Module):
|
|
|
686
688
|
info += f"Model Config: {self.config}\n"
|
|
687
689
|
fprint(info)
|
|
688
690
|
return info
|
|
689
|
-
|
|
@@ -16,15 +16,15 @@ from ..misc.utils import env_meta_info, load_module_from_path
|
|
|
16
16
|
class OmniTokenizer:
|
|
17
17
|
"""
|
|
18
18
|
A wrapper class for tokenizers to provide a consistent interface within OmniGenome.
|
|
19
|
-
|
|
19
|
+
|
|
20
20
|
This class provides a unified interface for tokenizers in the OmniGenome framework.
|
|
21
21
|
It wraps underlying tokenizers (typically from Hugging Face) and provides
|
|
22
22
|
additional functionality for genomic sequence processing.
|
|
23
|
-
|
|
23
|
+
|
|
24
24
|
The class handles various tokenization strategies and provides compatibility
|
|
25
25
|
with different model architectures. It also supports custom tokenizer wrappers
|
|
26
26
|
for specialized genomic tasks.
|
|
27
|
-
|
|
27
|
+
|
|
28
28
|
Attributes:
|
|
29
29
|
base_tokenizer: The underlying tokenizer instance (e.g., from Hugging Face).
|
|
30
30
|
max_length (int): The default maximum sequence length.
|
|
@@ -52,7 +52,7 @@ class OmniTokenizer:
|
|
|
52
52
|
>>> from transformers import AutoTokenizer
|
|
53
53
|
>>> base_tokenizer = AutoTokenizer.from_pretrained("model_name")
|
|
54
54
|
>>> tokenizer = OmniTokenizer(base_tokenizer, max_length=512)
|
|
55
|
-
|
|
55
|
+
|
|
56
56
|
>>> # Initialize with sequence conversion
|
|
57
57
|
>>> tokenizer = OmniTokenizer(base_tokenizer, u2t=True)
|
|
58
58
|
"""
|
|
@@ -87,9 +87,9 @@ class OmniTokenizer:
|
|
|
87
87
|
Example:
|
|
88
88
|
>>> # Load from a pre-trained model
|
|
89
89
|
>>> tokenizer = OmniTokenizer.from_pretrained("model_name")
|
|
90
|
-
|
|
90
|
+
|
|
91
91
|
>>> # Load with custom parameters
|
|
92
|
-
>>> tokenizer = OmniTokenizer.from_pretrained("model_name",
|
|
92
|
+
>>> tokenizer = OmniTokenizer.from_pretrained("model_name",
|
|
93
93
|
... trust_remote_code=True)
|
|
94
94
|
"""
|
|
95
95
|
wrapper_path = f"{model_name_or_path.rstrip('/')}/omnigenome_wrapper.py"
|
|
@@ -104,7 +104,9 @@ class OmniTokenizer:
|
|
|
104
104
|
warnings.warn(
|
|
105
105
|
f"No tokenizer wrapper found in {wrapper_path} -> Exception: {e}"
|
|
106
106
|
)
|
|
107
|
-
kwargs.pop(
|
|
107
|
+
kwargs.pop(
|
|
108
|
+
"num_labels", None
|
|
109
|
+
) # Remove num_labels if it exists, as it may not be applicable
|
|
108
110
|
|
|
109
111
|
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **kwargs)
|
|
110
112
|
|
|
@@ -27,11 +27,11 @@ from ... import __name__, __version__
|
|
|
27
27
|
class OmniDatasetForTokenClassification(OmniDataset):
|
|
28
28
|
"""
|
|
29
29
|
Dataset class specifically designed for token classification tasks in genomics.
|
|
30
|
-
|
|
30
|
+
|
|
31
31
|
This class extends `OmniDataset` to provide functionalities for preparing input sequences
|
|
32
32
|
and their corresponding token-level labels. It's designed for tasks where each token
|
|
33
33
|
in a sequence needs to be classified independently.
|
|
34
|
-
|
|
34
|
+
|
|
35
35
|
Attributes:
|
|
36
36
|
metadata: Dictionary containing dataset metadata including library information
|
|
37
37
|
label2id: Mapping from label strings to integer IDs
|
|
@@ -68,7 +68,7 @@ class OmniDatasetForTokenClassification(OmniDataset):
|
|
|
68
68
|
def prepare_input(self, instance, **kwargs):
|
|
69
69
|
"""
|
|
70
70
|
Prepare a single data instance for token classification.
|
|
71
|
-
|
|
71
|
+
|
|
72
72
|
This method handles both string sequences and dictionary instances
|
|
73
73
|
containing sequence and label information. It tokenizes the input
|
|
74
74
|
sequence and prepares token-level labels for classification.
|
|
@@ -138,11 +138,11 @@ class OmniDatasetForTokenClassification(OmniDataset):
|
|
|
138
138
|
class OmniDatasetForSequenceClassification(OmniDataset):
|
|
139
139
|
"""
|
|
140
140
|
Dataset class for sequence classification tasks in genomics.
|
|
141
|
-
|
|
141
|
+
|
|
142
142
|
This class extends `OmniDataset` to prepare input sequences and their corresponding
|
|
143
143
|
sequence-level labels. It's designed for tasks where the entire sequence needs
|
|
144
144
|
to be classified into one of several categories.
|
|
145
|
-
|
|
145
|
+
|
|
146
146
|
Attributes:
|
|
147
147
|
metadata: Dictionary containing dataset metadata including library information
|
|
148
148
|
label2id: Mapping from label strings to integer IDs
|
|
@@ -179,7 +179,7 @@ class OmniDatasetForSequenceClassification(OmniDataset):
|
|
|
179
179
|
def prepare_input(self, instance, **kwargs):
|
|
180
180
|
"""
|
|
181
181
|
Prepare a single data instance for sequence classification.
|
|
182
|
-
|
|
182
|
+
|
|
183
183
|
This method handles both string sequences and dictionary instances
|
|
184
184
|
containing sequence and label information. It tokenizes the input
|
|
185
185
|
sequence and prepares sequence-level labels for classification.
|
|
@@ -238,11 +238,11 @@ class OmniDatasetForSequenceClassification(OmniDataset):
|
|
|
238
238
|
class OmniDatasetForTokenRegression(OmniDataset):
|
|
239
239
|
"""
|
|
240
240
|
Dataset class for token regression tasks in genomics.
|
|
241
|
-
|
|
241
|
+
|
|
242
242
|
This class extends `OmniDataset` to prepare input sequences and their corresponding
|
|
243
243
|
token-level regression targets. It's designed for tasks where each token in a
|
|
244
244
|
sequence needs to be assigned a continuous value.
|
|
245
|
-
|
|
245
|
+
|
|
246
246
|
Attributes:
|
|
247
247
|
metadata: Dictionary containing dataset metadata including library information
|
|
248
248
|
"""
|
|
@@ -278,7 +278,7 @@ class OmniDatasetForTokenRegression(OmniDataset):
|
|
|
278
278
|
def prepare_input(self, instance, **kwargs):
|
|
279
279
|
"""
|
|
280
280
|
Prepare a single data instance for token regression.
|
|
281
|
-
|
|
281
|
+
|
|
282
282
|
This method handles both string sequences and dictionary instances
|
|
283
283
|
containing sequence and regression target information. It tokenizes
|
|
284
284
|
the input sequence and prepares token-level regression targets.
|
|
@@ -330,7 +330,9 @@ class OmniDatasetForTokenRegression(OmniDataset):
|
|
|
330
330
|
# Handle token-level regression labels
|
|
331
331
|
if isinstance(labels, (list, tuple)):
|
|
332
332
|
# Ensure labels match sequence length
|
|
333
|
-
labels = list(labels)[
|
|
333
|
+
labels = list(labels)[
|
|
334
|
+
: self.max_length - 2
|
|
335
|
+
] # Account for special tokens
|
|
334
336
|
labels = [-100] + labels + [-100] # Add padding for special tokens
|
|
335
337
|
else:
|
|
336
338
|
# Single value for the entire sequence
|
|
@@ -343,11 +345,11 @@ class OmniDatasetForTokenRegression(OmniDataset):
|
|
|
343
345
|
class OmniDatasetForSequenceRegression(OmniDataset):
|
|
344
346
|
"""
|
|
345
347
|
Dataset class for sequence regression tasks in genomics.
|
|
346
|
-
|
|
348
|
+
|
|
347
349
|
This class extends `OmniDataset` to prepare input sequences and their corresponding
|
|
348
350
|
sequence-level regression targets. It's designed for tasks where the entire
|
|
349
351
|
sequence needs to be assigned a continuous value.
|
|
350
|
-
|
|
352
|
+
|
|
351
353
|
Attributes:
|
|
352
354
|
metadata: Dictionary containing dataset metadata including library information
|
|
353
355
|
"""
|
|
@@ -383,7 +385,7 @@ class OmniDatasetForSequenceRegression(OmniDataset):
|
|
|
383
385
|
def prepare_input(self, instance, **kwargs):
|
|
384
386
|
"""
|
|
385
387
|
Prepare a single data instance for sequence regression.
|
|
386
|
-
|
|
388
|
+
|
|
387
389
|
This method handles both string sequences and dictionary instances
|
|
388
390
|
containing sequence and regression target information. It tokenizes
|
|
389
391
|
the input sequence and prepares sequence-level regression targets.
|
|
@@ -432,4 +434,4 @@ class OmniDatasetForSequenceRegression(OmniDataset):
|
|
|
432
434
|
labels = float(labels)
|
|
433
435
|
|
|
434
436
|
tokenized_inputs["labels"] = torch.tensor(labels, dtype=torch.float32)
|
|
435
|
-
return tokenized_inputs
|
|
437
|
+
return tokenized_inputs
|
omnigenome/src/lora/__init__.py
CHANGED