omnigenome 0.3.0a1__py3-none-any.whl → 1.0.0b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- omnigenome/__init__.py +26 -258
- {omnigenome-0.3.0a1.dist-info → omnigenome-1.0.0b0.dist-info}/METADATA +9 -10
- omnigenome-1.0.0b0.dist-info/RECORD +6 -0
- omnigenome/auto/__init__.py +0 -3
- omnigenome/auto/auto_bench/__init__.py +0 -12
- omnigenome/auto/auto_bench/auto_bench.py +0 -484
- omnigenome/auto/auto_bench/auto_bench_cli.py +0 -230
- omnigenome/auto/auto_bench/auto_bench_config.py +0 -216
- omnigenome/auto/auto_bench/config_check.py +0 -34
- omnigenome/auto/auto_train/__init__.py +0 -13
- omnigenome/auto/auto_train/auto_train.py +0 -430
- omnigenome/auto/auto_train/auto_train_cli.py +0 -222
- omnigenome/auto/bench_hub/__init__.py +0 -12
- omnigenome/auto/bench_hub/bench_hub.py +0 -25
- omnigenome/cli/__init__.py +0 -13
- omnigenome/cli/commands/__init__.py +0 -13
- omnigenome/cli/commands/base.py +0 -83
- omnigenome/cli/commands/bench/__init__.py +0 -13
- omnigenome/cli/commands/bench/bench_cli.py +0 -202
- omnigenome/cli/commands/rna/__init__.py +0 -13
- omnigenome/cli/commands/rna/rna_design.py +0 -178
- omnigenome/cli/omnigenome_cli.py +0 -128
- omnigenome/src/__init__.py +0 -12
- omnigenome/src/abc/__init__.py +0 -12
- omnigenome/src/abc/abstract_dataset.py +0 -622
- omnigenome/src/abc/abstract_metric.py +0 -114
- omnigenome/src/abc/abstract_model.py +0 -689
- omnigenome/src/abc/abstract_tokenizer.py +0 -267
- omnigenome/src/dataset/__init__.py +0 -16
- omnigenome/src/dataset/omni_dataset.py +0 -435
- omnigenome/src/lora/__init__.py +0 -13
- omnigenome/src/lora/lora_model.py +0 -294
- omnigenome/src/metric/__init__.py +0 -15
- omnigenome/src/metric/classification_metric.py +0 -184
- omnigenome/src/metric/metric.py +0 -199
- omnigenome/src/metric/ranking_metric.py +0 -142
- omnigenome/src/metric/regression_metric.py +0 -191
- omnigenome/src/misc/__init__.py +0 -3
- omnigenome/src/misc/utils.py +0 -499
- omnigenome/src/model/__init__.py +0 -19
- omnigenome/src/model/augmentation/__init__.py +0 -12
- omnigenome/src/model/augmentation/model.py +0 -219
- omnigenome/src/model/classification/__init__.py +0 -12
- omnigenome/src/model/classification/model.py +0 -642
- omnigenome/src/model/embedding/__init__.py +0 -12
- omnigenome/src/model/embedding/model.py +0 -263
- omnigenome/src/model/mlm/__init__.py +0 -12
- omnigenome/src/model/mlm/model.py +0 -177
- omnigenome/src/model/module_utils.py +0 -232
- omnigenome/src/model/regression/__init__.py +0 -12
- omnigenome/src/model/regression/model.py +0 -786
- omnigenome/src/model/regression/resnet.py +0 -483
- omnigenome/src/model/rna_design/__init__.py +0 -12
- omnigenome/src/model/rna_design/model.py +0 -469
- omnigenome/src/model/seq2seq/__init__.py +0 -12
- omnigenome/src/model/seq2seq/model.py +0 -44
- omnigenome/src/tokenizer/__init__.py +0 -16
- omnigenome/src/tokenizer/bpe_tokenizer.py +0 -226
- omnigenome/src/tokenizer/kmers_tokenizer.py +0 -247
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +0 -249
- omnigenome/src/trainer/__init__.py +0 -14
- omnigenome/src/trainer/accelerate_trainer.py +0 -739
- omnigenome/src/trainer/hf_trainer.py +0 -75
- omnigenome/src/trainer/trainer.py +0 -579
- omnigenome/utility/__init__.py +0 -3
- omnigenome/utility/dataset_hub/__init__.py +0 -13
- omnigenome/utility/dataset_hub/dataset_hub.py +0 -178
- omnigenome/utility/ensemble.py +0 -324
- omnigenome/utility/hub_utils.py +0 -517
- omnigenome/utility/model_hub/__init__.py +0 -12
- omnigenome/utility/model_hub/model_hub.py +0 -231
- omnigenome/utility/pipeline_hub/__init__.py +0 -12
- omnigenome/utility/pipeline_hub/pipeline.py +0 -483
- omnigenome/utility/pipeline_hub/pipeline_hub.py +0 -129
- omnigenome-0.3.0a1.dist-info/RECORD +0 -78
- omnigenome-0.3.0a1.dist-info/entry_points.txt +0 -3
- {omnigenome-0.3.0a1.dist-info → omnigenome-1.0.0b0.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-1.0.0b0.dist-info}/licenses/LICENSE +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-1.0.0b0.dist-info}/top_level.txt +0 -0
omnigenome/src/tokenizer/bpe_tokenizer.py
@@ -1,226 +0,0 @@
-# -*- coding: utf-8 -*-
-# file: bpe_tokenizer.py
-# time: 18:32 08/04/2024
-# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
-# github: https://github.com/yangheng95
-# huggingface: https://huggingface.co/yangheng
-# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
-# Copyright (C) 2019-2024. All Rights Reserved.
-import numpy as np
-import warnings
-
-from ..abc.abstract_tokenizer import OmniTokenizer
-
-warnings.filterwarnings("once")
-
-
-def is_bpe_tokenization(tokens, threshold=0.1):
-    """
-    Check if the tokenization is BPE-based by analyzing token characteristics.
-
-    This function examines the tokens to determine if they follow BPE tokenization
-    patterns by analyzing token length distributions and special token patterns.
-
-    Args:
-        tokens (list): List of tokens to analyze
-        threshold (float, optional): Threshold for determining BPE tokenization. Defaults to 0.1
-
-    Returns:
-        bool: True if tokens appear to be BPE-based, False otherwise
-
-    Example:
-        >>> tokens = ["▁hello", "▁world", "▁how", "▁are", "▁you"]
-        >>> is_bpe = is_bpe_tokenization(tokens)
-        >>> print(is_bpe)
-        True
-    """
-    if not tokens:
-        return False
-
-    # bpe_endings_count = sum(
-    #     1
-    #     for token in tokens
-    #     if token.startswith("##") or token.startswith("@@") or token.startswith("▁")
-    # )
-    # bpe_ratio = bpe_endings_count / len(tokens)
-
-    rmse = np.mean([len(token) ** 2 for token in tokens]) ** 0.5
-
-    return rmse >= threshold
-
-
-class OmniBPETokenizer(OmniTokenizer):
-    """
-    A Byte Pair Encoding (BPE) tokenizer for genomic sequences.
-
-    This tokenizer uses BPE tokenization for genomic sequences and provides
-    validation to ensure the base tokenizer is BPE-based. It supports sequence
-    preprocessing and handles various input formats.
-
-    Attributes:
-        base_tokenizer: The underlying BPE tokenizer
-        metadata: Dictionary containing tokenizer metadata
-
-    Example:
-        >>> from omnigenome.src.tokenizer import OmniBPETokenizer
-        >>> from transformers import AutoTokenizer
-        >>> base_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
-        >>> tokenizer = OmniBPETokenizer(base_tokenizer)
-        >>> sequence = "ACGUAGGUAUCGUAGA"
-        >>> tokens = tokenizer.tokenize(sequence)
-        >>> print(tokens[:5])
-        ['▁A', 'C', 'G', 'U', 'A']
-    """
-
-    def __init__(self, base_tokenizer=None, **kwargs):
-        """
-        Initialize the OmniBPETokenizer.
-
-        Args:
-            base_tokenizer: The base BPE tokenizer
-            **kwargs: Additional keyword arguments passed to parent class
-        """
-        super(OmniBPETokenizer, self).__init__(base_tokenizer, **kwargs)
-        self.metadata["tokenizer_name"] = self.__class__.__name__
-
-    def __call__(self, sequence, **kwargs):
-        """
-        Tokenize a sequence using BPE tokenization.
-
-        This method processes the input sequence using BPE tokenization,
-        handles sequence preprocessing (U/T conversion, whitespace addition),
-        and validates that the tokenization is BPE-based.
-
-        Args:
-            sequence (str): Input sequence to tokenize
-            **kwargs: Additional keyword arguments including max_length
-
-        Returns:
-            dict: Dictionary containing tokenized inputs with keys 'input_ids' and 'attention_mask'
-
-        Raises:
-            ValueError: If the tokenizer is not BPE-based
-
-        Example:
-            >>> sequence = "ACGUAGGUAUCGUAGA"
-            >>> tokenized = tokenizer(sequence)
-            >>> print(tokenized['input_ids'].shape)
-            torch.Size([1, 17])
-        """
-        if self.u2t:
-            sequence = sequence.replace("U", "T")
-        if self.add_whitespace:
-            sequence = " ".join(list(sequence))
-
-        sequence_tokens = self.tokenize(sequence)[
-            : min(self.max_length, kwargs.get("max_length", 512)) - 2
-        ]
-
-        if not is_bpe_tokenization(sequence_tokens):
-            raise ValueError("The tokenizer seems not to be a BPE tokenizer.")
-        tokenized_inputs = dict()
-        tokenized_inputs["input_ids"] = self.base_tokenizer.convert_tokens_to_ids(
-            sequence_tokens
-        )
-        tokenized_inputs["attention_mask"] = [1] * len(tokenized_inputs["input_ids"])
-
-        tokenized_inputs = self.base_tokenizer.pad(
-            tokenized_inputs,
-            padding="max_length",
-            max_length=len(sequence_tokens),
-            return_tensors="pt",
-        )
-        return tokenized_inputs
-
-    @staticmethod
-    def from_pretrained(model_name_or_path, **kwargs):
-        """
-        Create a BPE tokenizer from a pre-trained model.
-
-        Args:
-            model_name_or_path (str): Name or path of the pre-trained model
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            OmniBPETokenizer: Initialized BPE tokenizer
-
-        Example:
-            >>> tokenizer = OmniBPETokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
-            >>> print(type(tokenizer))
-            <class 'omnigenome.src.tokenizer.bpe_tokenizer.OmniBPETokenizer'>
-        """
-        from transformers import AutoTokenizer
-
-        self = OmniBPETokenizer(
-            AutoTokenizer.from_pretrained(model_name_or_path, **kwargs)
-        )
-        return self
-
-    def tokenize(self, sequence, **kwargs):
-        """
-        Tokenize a sequence using the base BPE tokenizer.
-
-        Args:
-            sequence (str): Input sequence to tokenize
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            list: List of tokens
-
-        Example:
-            >>> sequence = "ACGUAGGUAUCGUAGA"
-            >>> tokens = tokenizer.tokenize(sequence)
-            >>> print(tokens[:5])
-            ['▁A', 'C', 'G', 'U', 'A']
-        """
-        return self.base_tokenizer.tokenize(sequence)
-
-    def encode(self, sequence, **kwargs):
-        """
-        Encode a sequence using the base BPE tokenizer.
-
-        Args:
-            sequence (str): Input sequence to encode
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            list: List of token IDs
-
-        Raises:
-            AssertionError: If the base tokenizer is not BPE-based
-
-        Example:
-            >>> sequence = "ACGUAGGUAUCGUAGA"
-            >>> token_ids = tokenizer.encode(sequence)
-            >>> print(len(token_ids))
-            17
-        """
-        assert hasattr(
-            self.base_tokenizer, "bpe"
-        ), "The base tokenizer must be a BPE tokenizer."
-        return self.base_tokenizer.encode(sequence, **kwargs)
-
-    def decode(self, sequence, **kwargs):
-        """
-        Decode a sequence using the base BPE tokenizer.
-
-        Args:
-            sequence: Input sequence to decode (can be token IDs or tokens)
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            str: Decoded sequence
-
-        Raises:
-            AssertionError: If the base tokenizer is not BPE-based
-
-        Example:
-            >>> token_ids = [1, 2, 3, 4, 5]
-            >>> sequence = tokenizer.decode(token_ids)
-            >>> print(sequence)
-            "ACGUAGGUAUCGUAGA"
-        """
-        assert hasattr(
-            self.base_tokenizer, "bpe"
-        ), "The base tokenizer must be a BPE tokenizer."
-        return self.base_tokenizer.decode(sequence, **kwargs)
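Note on the removed is_bpe_tokenization helper above: it gates OmniBPETokenizer.__call__ on the root-mean-square of token lengths rather than on BPE marker prefixes (the prefix-counting variant is left commented out). A minimal standalone sketch of that check, under the same default threshold of 0.1, which any non-empty token list satisfies because every token is at least one character long:

import numpy as np

def is_bpe_tokenization(tokens, threshold=0.1):
    # Root-mean-square of the token lengths; with threshold=0.1 this is
    # True for any non-empty token list, since each length is >= 1.
    if not tokens:
        return False
    rmse = np.mean([len(token) ** 2 for token in tokens]) ** 0.5
    return rmse >= threshold

print(is_bpe_tokenization(["▁AC", "GU", "AGG"]))  # True
print(is_bpe_tokenization([]))                    # False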
omnigenome/src/tokenizer/kmers_tokenizer.py
@@ -1,247 +0,0 @@
-# -*- coding: utf-8 -*-
-# file: kmers_tokenizer.py
-# time: 18:31 08/04/2024
-# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
-# github: https://github.com/yangheng95
-# huggingface: https://huggingface.co/yangheng
-# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
-# Copyright (C) 2019-2024. All Rights Reserved.
-import warnings
-
-from ..abc.abstract_tokenizer import OmniTokenizer
-
-warnings.filterwarnings("once")
-
-
-class OmniKmersTokenizer(OmniTokenizer):
-    """
-    A k-mer based tokenizer for genomic sequences.
-
-    This tokenizer breaks genomic sequences into overlapping k-mers and uses
-    a base tokenizer to convert them into token IDs. It supports various
-    k-mer sizes and overlap configurations for different genomic applications.
-
-    Attributes:
-        base_tokenizer: The underlying tokenizer for converting k-mers to IDs
-        k: Size of k-mers
-        overlap: Number of overlapping positions between consecutive k-mers
-        max_length: Maximum sequence length for tokenization
-        metadata: Dictionary containing tokenizer metadata
-
-    Example:
-        >>> from omnigenome.src.tokenizer import OmniKmersTokenizer
-        >>> from transformers import AutoTokenizer
-        >>> base_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
-        >>> tokenizer = OmniKmersTokenizer(base_tokenizer, k=4, overlap=2)
-        >>> sequence = "ACGUAGGUAUCGUAGA"
-        >>> tokens = tokenizer.tokenize(sequence)
-        >>> print(tokens)
-        [['ACGU', 'GUAG', 'UAGG', 'AGGU', 'GGUA', 'GUAU', 'UAUC', 'AUCG', 'UCGU', 'CGUA', 'GUAG', 'UAGA']]
-    """
-
-    def __init__(self, base_tokenizer=None, k=3, overlap=0, max_length=512, **kwargs):
-        """
-        Initialize the OmniKmersTokenizer.
-
-        Args:
-            base_tokenizer: The base tokenizer for converting k-mers to token IDs
-            k (int, optional): Size of k-mers. Defaults to 3
-            overlap (int, optional): Number of overlapping positions between consecutive k-mers. Defaults to 0
-            max_length (int, optional): Maximum sequence length for tokenization. Defaults to 512
-            **kwargs: Additional keyword arguments passed to parent class
-        """
-        super(OmniKmersTokenizer, self).__init__(base_tokenizer, **kwargs)
-        self.k = k
-        self.overlap = overlap
-        self.max_length = max_length
-        self.metadata["tokenizer_name"] = self.__class__.__name__
-
-    def __call__(self, sequence, **kwargs):
-        """
-        Tokenize a sequence or list of sequences into tokenized inputs.
-
-        This method processes the input sequence(s) by first converting them to k-mers,
-        then using the base tokenizer to convert k-mers to token IDs. It handles
-        sequence preprocessing (U/T conversion) and adds special tokens.
-
-        Args:
-            sequence (str or list): Input sequence(s) to tokenize
-            **kwargs: Additional keyword arguments including max_length
-
-        Returns:
-            dict: Dictionary containing tokenized inputs with keys 'input_ids' and 'attention_mask'
-
-        Example:
-            >>> sequence = "ACGUAGGUAUCGUAGA"
-            >>> tokenized = tokenizer(sequence)
-            >>> print(tokenized['input_ids'].shape)
-            torch.Size([1, 14])
-        """
-        if self.u2t:
-            sequence = "".join([seq.replace("U", "T").upper() for seq in sequence])
-        if self.t2u:
-            sequence = "".join([seq.replace("T", "U").upper() for seq in sequence])
-
-        sequence_tokens = self.tokenize(sequence)[
-            : kwargs.get("max_length", self.max_length) - 2
-        ]
-        tokenized_inputs = {
-            "input_ids": [],
-            "attention_mask": [],
-        }
-        bos_id = (
-            self.base_tokenizer.bos_token_id
-            if self.base_tokenizer.bos_token_id is not None
-            else self.base_tokenizer.cls_token_id
-        )
-        eos_id = (
-            self.base_tokenizer.eos_token_id
-            if self.base_tokenizer.eos_token_id is not None
-            else self.base_tokenizer.sep_token_id
-        )
-
-        for tokens in sequence_tokens:
-            tokenized_inputs["input_ids"].append(
-                [bos_id] + self.base_tokenizer.convert_tokens_to_ids(tokens) + [eos_id]
-            )
-            tokenized_inputs["attention_mask"].append(
-                [1] * len(tokenized_inputs["input_ids"][-1])
-            )
-
-        for i, ids in enumerate(tokenized_inputs["input_ids"]):
-            if ids.count(self.base_tokenizer.unk_token_id) / len(ids) > 0.1:
-                warnings.warn(
-                    f"Unknown tokens are more than 10% in the {i}th sequence, please check the tokenization process."
-                )
-        tokenized_inputs = self.base_tokenizer.pad(
-            tokenized_inputs,
-            padding="max_length",
-            max_length=len(sequence_tokens),
-            return_attention_mask=True,
-            return_tensors="pt",
-        )
-        return tokenized_inputs
-
-    @staticmethod
-    def from_pretrained(model_name_or_path, **kwargs):
-        """
-        Create a k-mers tokenizer from a pre-trained model.
-
-        Args:
-            model_name_or_path (str): Name or path of the pre-trained model
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            OmniKmersTokenizer: Initialized k-mers tokenizer
-
-        Example:
-            >>> tokenizer = OmniKmersTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
-            >>> print(type(tokenizer))
-            <class 'omnigenome.src.tokenizer.kmers_tokenizer.OmniKmersTokenizer'>
-        """
-        from transformers import AutoTokenizer
-
-        self = OmniKmersTokenizer(
-            AutoTokenizer.from_pretrained(model_name_or_path, **kwargs)
-        )
-        return self
-
-    def tokenize(self, sequence, **kwargs):
-        """
-        Convert sequence(s) into k-mers.
-
-        This method breaks the input sequence(s) into overlapping k-mers based on
-        the configured k-mer size and overlap parameters.
-
-        Args:
-            sequence (str or list): Input sequence(s) to convert to k-mers
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            list: List of k-mer lists for each input sequence
-
-        Example:
-            >>> sequence = "ACGUAGGUAUCGUAGA"
-            >>> k_mers = tokenizer.tokenize(sequence)
-            >>> print(k_mers[0][:3])
-            ['ACGU', 'GUAG', 'UAGG']
-        """
-        if isinstance(sequence, str):
-            sequences = [sequence]
-        else:
-            sequences = sequence
-
-        sequence_tokens = []
-        for i in range(len(sequences)):
-            tokens = []
-            for j in range(0, len(sequences[i]), self.k - self.overlap):
-                tokens.append(sequences[i][j : j + self.k])
-
-            sequence_tokens.append(tokens)
-
-        return sequence_tokens
-
-    def encode(self, input_ids, **kwargs):
-        """
-        Encode input IDs using the base tokenizer.
-
-        Args:
-            input_ids: Input IDs to encode
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            Encoded input IDs
-        """
-        return self.base_tokenizer.encode(input_ids, **kwargs)
-
-    def decode(self, input_ids, **kwargs):
-        """
-        Decode input IDs using the base tokenizer.
-
-        Args:
-            input_ids: Input IDs to decode
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            Decoded sequence
-        """
-        return self.base_tokenizer.decode(input_ids, **kwargs)
-
-    def encode_plus(self, sequence, **kwargs):
-        """
-        Encode a sequence with additional information.
-
-        This method is not yet implemented for k-mers tokenizer.
-
-        Args:
-            sequence: Input sequence
-            **kwargs: Additional keyword arguments
-
-        Raises:
-            NotImplementedError: This method is not implemented yet
-        """
-        raise NotImplementedError("The encode_plus() function is not implemented yet.")
-
-
-if __name__ == "__main__":
-    from transformers import AutoTokenizer
-
-    # RNA = "ACGUAGGUAUCGUAGA"
-    # # base_tokenizer_name = 'bert-base-cased'
-    # base_tokenizer_name = "facebook/esm2_t12_35M_UR50D"
-    # base_tokenizer = AutoTokenizer.from_pretrained(base_tokenizer_name)
-    # tokenizer = KmersTokenizer(base_tokenizer)
-    # tokens = tokenizer.tokenize(RNA)
-    # fprint(tokens)
-    # tokenized_inputs = tokenizer(RNA)
-    # fprint(tokenized_inputs)
-
-    RNA = "ACGUAGGUAUCGUAGA"
-    # base_tokenizer_name = 'bert-base-cased'
-    base_tokenizer_name = "facebook/esm2_t12_35M_UR50D"
-    base_tokenizer = AutoTokenizer.from_pretrained(base_tokenizer_name)
-    tokenizer = OmniKmersTokenizer(base_tokenizer, k=4, overlap=2, max_length=512)
-    tokens = tokenizer.tokenize(RNA)
-    print(tokens)
-    tokenized_inputs = tokenizer(RNA)
-    print(tokenized_inputs)
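Note on the removed OmniKmersTokenizer.tokenize above: it windows each sequence with a stride of k - overlap, so consecutive k-mers share overlap characters and the trailing window can be shorter than k. A minimal standalone sketch of that windowing (the kmerize name is illustrative, not part of the package):

def kmerize(sequence, k=4, overlap=2):
    # Slide a window of size k with stride k - overlap (requires overlap < k).
    stride = k - overlap
    return [sequence[j:j + k] for j in range(0, len(sequence), stride)]

print(kmerize("ACGUAGGUAUCGUAGA"))
# ['ACGU', 'GUAG', 'AGGU', 'GUAU', 'AUCG', 'CGUA', 'UAGA', 'GA']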