omnigenome 0.3.0a1__py3-none-any.whl → 0.3.1a0__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
- omnigenome/__init__.py +16 -8
- omnigenome/auto/auto_bench/__init__.py +0 -1
- omnigenome/auto/auto_bench/auto_bench.py +24 -14
- omnigenome/auto/auto_train/__init__.py +0 -1
- omnigenome/auto/auto_train/auto_train.py +11 -12
- omnigenome/auto/bench_hub/__init__.py +0 -1
- omnigenome/auto/bench_hub/bench_hub.py +1 -1
- omnigenome/cli/__init__.py +0 -1
- omnigenome/cli/commands/__init__.py +0 -1
- omnigenome/cli/commands/base.py +10 -10
- omnigenome/cli/commands/bench/__init__.py +0 -1
- omnigenome/cli/commands/bench/bench_cli.py +10 -10
- omnigenome/cli/commands/rna/__init__.py +0 -1
- omnigenome/cli/commands/rna/rna_design.py +10 -11
- omnigenome/src/__init__.py +0 -1
- omnigenome/src/abc/__init__.py +0 -1
- omnigenome/src/abc/abstract_dataset.py +38 -19
- omnigenome/src/abc/abstract_metric.py +7 -7
- omnigenome/src/abc/abstract_model.py +15 -14
- omnigenome/src/abc/abstract_tokenizer.py +9 -7
- omnigenome/src/dataset/omni_dataset.py +16 -14
- omnigenome/src/lora/__init__.py +0 -1
- omnigenome/src/lora/lora_model.py +47 -41
- omnigenome/src/metric/classification_metric.py +11 -11
- omnigenome/src/metric/metric.py +19 -19
- omnigenome/src/metric/ranking_metric.py +15 -15
- omnigenome/src/metric/regression_metric.py +18 -18
- omnigenome/src/misc/utils.py +40 -36
- omnigenome/src/model/augmentation/__init__.py +0 -1
- omnigenome/src/model/augmentation/model.py +17 -17
- omnigenome/src/model/classification/__init__.py +0 -1
- omnigenome/src/model/classification/model.py +28 -32
- omnigenome/src/model/embedding/__init__.py +0 -1
- omnigenome/src/model/embedding/model.py +35 -35
- omnigenome/src/model/mlm/__init__.py +0 -1
- omnigenome/src/model/mlm/model.py +13 -13
- omnigenome/src/model/module_utils.py +17 -17
- omnigenome/src/model/regression/__init__.py +0 -1
- omnigenome/src/model/regression/model.py +72 -77
- omnigenome/src/model/regression/resnet.py +32 -32
- omnigenome/src/model/rna_design/__init__.py +0 -1
- omnigenome/src/model/rna_design/model.py +65 -58
- omnigenome/src/model/seq2seq/__init__.py +0 -1
- omnigenome/src/model/seq2seq/model.py +4 -4
- omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
- omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
- omnigenome/src/trainer/accelerate_trainer.py +40 -32
- omnigenome/src/trainer/hf_trainer.py +8 -8
- omnigenome/src/trainer/trainer.py +37 -25
- omnigenome/utility/dataset_hub/__init__.py +0 -1
- omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
- omnigenome/utility/ensemble.py +26 -26
- omnigenome/utility/hub_utils.py +8 -8
- omnigenome/utility/model_hub/__init__.py +0 -1
- omnigenome/utility/model_hub/model_hub.py +26 -25
- omnigenome/utility/pipeline_hub/__init__.py +0 -1
- omnigenome/utility/pipeline_hub/pipeline.py +49 -49
- omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +2 -2
- omnigenome-0.3.1a0.dist-info/RECORD +78 -0
- omnigenome-0.3.0a1.dist-info/RECORD +0 -78
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -0
omnigenome/src/tokenizer/bpe_tokenizer.py (+27 -27)

@@ -17,17 +17,17 @@ warnings.filterwarnings("once")
  Context: the docstring of is_bpe_tokenization(tokens, threshold=0.1), which checks whether a tokenization is BPE-based by analyzing token length distributions and special-token patterns (Args: tokens (list), threshold (float, optional, defaults to 0.1); Returns: bool; Example: tokens = ["▁hello", "▁world", "▁how", "▁are", "▁you"]; is_bpe = is_bpe_tokenization(tokens)).
  Change: the four blank lines inside this docstring lose their trailing spaces (whitespace-only).
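The hunk above only touches whitespace, but the docstring it sits in describes a concrete heuristic: deciding from token length distributions and marker patterns whether a tokenizer is BPE-based. A minimal sketch of such a heuristic is shown below; the function name looks_like_bpe, the marker test, and the threshold semantics are illustrative assumptions, not the logic of omnigenome's is_bpe_tokenization.

```python
# Illustrative sketch only -- not the omnigenome implementation.
# Heuristic: BPE/SentencePiece vocabularies typically mark segment starts with "▁"
# (or "##" continuations for WordPiece) and produce tokens of varying length.

def looks_like_bpe(tokens, threshold=0.1):
    """Return True if the share of BPE-style marker tokens exceeds `threshold`."""
    if not tokens:
        return False
    marked = sum(1 for t in tokens if t.startswith("▁") or t.startswith("##"))
    return marked / len(tokens) >= threshold


if __name__ == "__main__":
    print(looks_like_bpe(["▁hello", "▁world", "▁how", "▁are", "▁you"]))  # True
    print(looks_like_bpe(["A", "C", "G", "U"]))                          # False
```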
@@ -52,15 +52,15 @@ def is_bpe_tokenization(tokens, threshold=0.1):
@@ -75,7 +75,7 @@ class OmniBPETokenizer(OmniTokenizer):
@@ -86,21 +86,21 @@ class OmniBPETokenizer(OmniTokenizer):
@@ -136,14 +136,14 @@ class OmniBPETokenizer(OmniTokenizer):
@@ -159,14 +159,14 @@ class OmniBPETokenizer(OmniTokenizer):
@@ -178,17 +178,17 @@ class OmniBPETokenizer(OmniTokenizer):
@@ -203,17 +203,17 @@ class OmniBPETokenizer(OmniTokenizer):
  Context: the docstrings of the OmniBPETokenizer class and of its __init__, __call__, from_pretrained, tokenize, encode, and decode methods. They describe a BPE tokenizer for genomic sequences that wraps a base BPE tokenizer, validates that the base tokenizer is BPE-based, handles preprocessing (U/T conversion, whitespace addition), and returns dicts with 'input_ids' and 'attention_mask'; the examples include OmniBPETokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D") and tokenizer("ACGUAGGUAUCGUAGA").
  Change: every modified line in these hunks is a blank docstring line whose trailing spaces are removed (whitespace-only; together with the hunk above this accounts for all +27/-27 lines in the file).
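Pieced together from the docstring examples that appear as context in these hunks, typical OmniBPETokenizer usage looks roughly like the following; treat it as a sketch, since only docstring fragments are visible in the diff and the exact signatures in 0.3.1a0 are not re-verified here.

```python
# Usage sketch assembled from docstring fragments shown in this diff;
# not verified against the installed package.
from omnigenome.src.tokenizer import OmniBPETokenizer

tokenizer = OmniBPETokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

sequence = "ACGUAGGUAUCGUAGA"
tokenized = tokenizer(sequence, max_length=512)  # dict with 'input_ids' and 'attention_mask'
token_ids = tokenizer.encode(sequence)           # list of token IDs
decoded = tokenizer.decode(token_ids)            # back to a sequence string
print(list(tokenized.keys()), decoded)
```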
omnigenome/src/tokenizer/kmers_tokenizer.py (+22 -22)

@@ -16,18 +16,18 @@ warnings.filterwarnings("once")
@@ -42,7 +42,7 @@ class OmniKmersTokenizer(OmniTokenizer):
@@ -59,18 +59,18 @@ class OmniKmersTokenizer(OmniTokenizer):
@@ -126,14 +126,14 @@ class OmniKmersTokenizer(OmniTokenizer):
@@ -149,17 +149,17 @@ class OmniKmersTokenizer(OmniTokenizer):
@@ -184,11 +184,11 @@ class OmniKmersTokenizer(OmniTokenizer):
@@ -197,11 +197,11 @@ class OmniKmersTokenizer(OmniTokenizer):
@@ -210,13 +210,13 @@ class OmniKmersTokenizer(OmniTokenizer):
  Context: the docstrings of the OmniKmersTokenizer class and of __init__(base_tokenizer=None, k=3, overlap=0, max_length=512, **kwargs), __call__, from_pretrained, tokenize, encode, decode, and encode_plus (the last still raises NotImplementedError). They describe breaking genomic sequences into overlapping k-mers, converting the k-mers to token IDs with a base tokenizer, handling U/T conversion, adding special tokens, and returning dicts with 'input_ids' and 'attention_mask'.
  Change: all 22 modified lines are blank docstring lines whose trailing spaces are removed (whitespace-only).
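The OmniKmersTokenizer docstrings describe splitting a sequence into k-mers of size k with a configurable overlap before the base tokenizer maps them to IDs. Below is a self-contained sketch of that splitting step, assuming a stride of k - overlap; the exact stride convention used by omnigenome is not visible in this diff.

```python
# Illustrative sketch of overlapping k-mer splitting; the stride convention
# (k - overlap) is an assumption, not taken from the omnigenome source.
def to_kmers(sequence: str, k: int = 3, overlap: int = 0) -> list[str]:
    stride = max(1, k - overlap)
    return [sequence[i:i + k] for i in range(0, len(sequence) - k + 1, stride)]


print(to_kmers("ACGUAGGU", k=3, overlap=0))  # ['ACG', 'UAG']
print(to_kmers("ACGUAGGU", k=3, overlap=1))  # ['ACG', 'GUA', 'AGG']
```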
omnigenome/src/tokenizer/single_nucleotide_tokenizer.py (+11 -11)

@@ -19,16 +19,16 @@ warnings.filterwarnings("once")
@@ -54,7 +54,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
@@ -76,7 +76,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
@@ -134,7 +134,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
@@ -156,7 +156,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
@@ -172,7 +172,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
@@ -191,7 +191,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
@@ -211,7 +211,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
@@ -231,7 +231,7 @@ class OmniSingleNucleotideTokenizer(OmniTokenizer):
  Context: the docstrings of the OmniSingleNucleotideTokenizer class and of __call__, from_pretrained, tokenize, encode, decode, and encode_plus. They describe tokenizing each nucleotide (A, T, C, G, U) as a separate token, with optional U/T conversion (u2t, t2u), whitespace insertion, and BOS/EOS handling; the examples show tokenizer("ATCGATCG") returning input_ids of shape [1, seq_len] and tokenize("ATCGATCG") returning [['A', 'T', 'C', 'G', 'A', 'T', 'C', 'G']].
  Change: all 11 modified lines are blank docstring lines whose trailing spaces are removed (whitespace-only).
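For reference, the usage shown in the OmniSingleNucleotideTokenizer docstrings follows this shape; the import path is assumed by analogy with the other tokenizers, the checkpoint path is a placeholder, and the printed shapes are copied from the docstring comments rather than re-run.

```python
# Usage sketch assembled from docstring fragments; "path/to/pretrained_model" is a
# placeholder and the shape comments are taken from the docstrings as-is.
from omnigenome.src.tokenizer import OmniSingleNucleotideTokenizer  # import path assumed

tokenizer = OmniSingleNucleotideTokenizer.from_pretrained("path/to/pretrained_model")

inputs = tokenizer("ATCGATCG")                # single sequence
print(inputs["input_ids"].shape)              # torch.Size([1, seq_len])

inputs = tokenizer(["ATCGATCG", "GCTAGCTA"])  # batch of sequences
print(inputs["input_ids"].shape)              # torch.Size([2, seq_len])

tokens = tokenizer.tokenize("ATCGATCG")
print(tokens)                                 # [['A', 'T', 'C', 'G', 'A', 'T', 'C', 'G']]
```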
omnigenome/src/trainer/accelerate_trainer.py (+40 -32)

@@ -21,15 +21,15 @@ from ..misc.utils import env_meta_info, fprint, seed_everything
@@ -91,11 +91,11 @@ def _infer_optimization_direction(metrics, prev_metrics):
@@ -110,7 +110,7 @@ class AccelerateTrainer:
@@ -143,7 +143,7 @@ class AccelerateTrainer:
@@ -293,14 +293,14 @@ class AccelerateTrainer:
@@ -364,14 +364,14 @@ class AccelerateTrainer:
@@ -431,18 +431,18 @@ class AccelerateTrainer:
  Context: the docstrings of _infer_optimization_direction(metrics, prev_metrics), which looks at the trend of metric values to decide whether larger values are better (e.g. accuracy) or smaller values are better (e.g. loss) and returns 'larger_is_better' or 'smaller_is_better', and of the AccelerateTrainer class and its __init__, evaluate, test, and train methods. AccelerateTrainer provides distributed training via HuggingFace Accelerate with automatic mixed precision, gradient accumulation, early stopping, checkpointing, and gathering of evaluation and test results across processes.
  Change: all modified lines in these hunks are blank lines in and around these docstrings whose trailing spaces are removed (whitespace-only).
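The _infer_optimization_direction docstring above says the direction is inferred from the trend of metric values. A toy illustration of trend-based inference is given below; the actual rule in omnigenome, and its handling of the (metrics, prev_metrics) arguments, may differ, so this is an assumption for illustration only.

```python
# Illustrative sketch of trend-based direction inference; not the omnigenome rule.
def infer_optimization_direction(prev_metric_values):
    """Guess whether the tracked metric improves upward or downward."""
    if len(prev_metric_values) < 2:
        return "larger_is_better"  # default when there is no trend yet
    diffs = [b - a for a, b in zip(prev_metric_values, prev_metric_values[1:])]
    rising = sum(d > 0 for d in diffs)
    falling = sum(d < 0 for d in diffs)
    return "larger_is_better" if rising >= falling else "smaller_is_better"


print(infer_optimization_direction([0.61, 0.68, 0.72]))  # larger_is_better
print(infer_optimization_direction([1.90, 1.42, 1.10]))  # smaller_is_better
```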
@@ -489,12 +489,20 @@ class AccelerateTrainer:
  This is the only functional change visible in the diff: when a model's outputs contain no "loss", the fallback now also checks a LoRA-wrapped inner model. The old six-line fallback (lines 492-497, which began `if hasattr(self.model, "loss_function") and callable(`; its continuation lines are not legible in this rendering) is replaced by:

     if "loss" not in outputs:
         # Generally, the model should return a loss in the outputs via OmniGenBench
         # For the Lora models, the loss is computed separately
+        if hasattr(self.model, "loss_function") and callable(
+            self.model.loss_function
+        ):
+            loss = self.model.loss_function(
+                outputs["logits"], outputs["labels"]
+            )
+        elif (
+            hasattr(self.model, "model")
+            and hasattr(self.model.model, "loss_function")
+            and callable(self.model.model.loss_function)
+        ):
+            loss = self.model.model.loss_function(
+                outputs["logits"], outputs["labels"]
+            )
         else:
             raise ValueError(
                 "The model does not have a loss function defined. "

@@ -585,11 +593,11 @@ class AccelerateTrainer:
@@ -643,10 +651,10 @@ class AccelerateTrainer:
@@ -655,10 +663,10 @@ class AccelerateTrainer:
@@ -667,10 +675,10 @@ class AccelerateTrainer:
@@ -682,7 +690,7 @@ class AccelerateTrainer:
  Context: the docstrings of _is_metric_better(metrics, stage="valid"), predict(data_loader), get_model, compute_metrics (raises NotImplementedError when a subclass does not implement it), and save_model(path, overwrite=False, **kwargs).
  Change: all modified lines in these hunks are blank docstring lines whose trailing spaces are removed (whitespace-only).
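The new fallback relies on duck typing: a model (or the inner model of a LoRA wrapper) that returns logits and labels but no loss must expose loss_function(logits, labels). Below is a minimal, self-contained example of a model that satisfies this contract; the class name and dimensions are made up for illustration and are not part of omnigenome.

```python
# Minimal sketch of the duck typing the trainer's fallback relies on: a model whose
# forward() does not return "loss" but which exposes loss_function(logits, labels).
import torch


class ToyClassifier(torch.nn.Module):
    def __init__(self, hidden=8, num_labels=2):
        super().__init__()
        self.head = torch.nn.Linear(hidden, num_labels)

    def forward(self, hidden_states, labels=None):
        logits = self.head(hidden_states)
        # No "loss" key returned: a trainer with this fallback would call loss_function().
        return {"logits": logits, "labels": labels}

    def loss_function(self, logits, labels):
        return torch.nn.functional.cross_entropy(logits, labels)


model = ToyClassifier()
outputs = model(torch.randn(4, 8), labels=torch.tensor([0, 1, 1, 0]))
loss = model.loss_function(outputs["logits"], outputs["labels"])
print(loss.item())
```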
omnigenome/src/trainer/hf_trainer.py (+8 -8)

@@ -24,19 +24,19 @@ from ... import __version__ as omnigenome_version
@@ -51,19 +51,19 @@ class HFTrainingArguments(TrainingArguments):
  Context: the docstrings of HFTrainer(Trainer) and HFTrainingArguments(TrainingArguments) and of their __init__(*args, **kwargs) methods. Both classes wrap their HuggingFace counterparts, attach a metadata dictionary with OmniGenome library information, and otherwise stay fully compatible with the HuggingFace training ecosystem.
  Change: all 8 modified lines are blank lines in and around these docstrings whose trailing spaces are removed (whitespace-only).