omnigenome 0.3.0a1__py3-none-any.whl → 0.3.3a0__py3-none-any.whl
This diff shows the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of omnigenome might be problematic.
- omnigenome/__init__.py +252 -258
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/METADATA +10 -10
- omnigenome-0.3.3a0.dist-info/RECORD +7 -0
- omnigenome/auto/__init__.py +0 -3
- omnigenome/auto/auto_bench/__init__.py +0 -12
- omnigenome/auto/auto_bench/auto_bench.py +0 -484
- omnigenome/auto/auto_bench/auto_bench_cli.py +0 -230
- omnigenome/auto/auto_bench/auto_bench_config.py +0 -216
- omnigenome/auto/auto_bench/config_check.py +0 -34
- omnigenome/auto/auto_train/__init__.py +0 -13
- omnigenome/auto/auto_train/auto_train.py +0 -430
- omnigenome/auto/auto_train/auto_train_cli.py +0 -222
- omnigenome/auto/bench_hub/__init__.py +0 -12
- omnigenome/auto/bench_hub/bench_hub.py +0 -25
- omnigenome/cli/__init__.py +0 -13
- omnigenome/cli/commands/__init__.py +0 -13
- omnigenome/cli/commands/base.py +0 -83
- omnigenome/cli/commands/bench/__init__.py +0 -13
- omnigenome/cli/commands/bench/bench_cli.py +0 -202
- omnigenome/cli/commands/rna/__init__.py +0 -13
- omnigenome/cli/commands/rna/rna_design.py +0 -178
- omnigenome/cli/omnigenome_cli.py +0 -128
- omnigenome/src/__init__.py +0 -12
- omnigenome/src/abc/__init__.py +0 -12
- omnigenome/src/abc/abstract_dataset.py +0 -622
- omnigenome/src/abc/abstract_metric.py +0 -114
- omnigenome/src/abc/abstract_model.py +0 -689
- omnigenome/src/abc/abstract_tokenizer.py +0 -267
- omnigenome/src/dataset/__init__.py +0 -16
- omnigenome/src/dataset/omni_dataset.py +0 -435
- omnigenome/src/lora/__init__.py +0 -13
- omnigenome/src/lora/lora_model.py +0 -294
- omnigenome/src/metric/__init__.py +0 -15
- omnigenome/src/metric/classification_metric.py +0 -184
- omnigenome/src/metric/metric.py +0 -199
- omnigenome/src/metric/ranking_metric.py +0 -142
- omnigenome/src/metric/regression_metric.py +0 -191
- omnigenome/src/misc/__init__.py +0 -3
- omnigenome/src/misc/utils.py +0 -499
- omnigenome/src/model/__init__.py +0 -19
- omnigenome/src/model/augmentation/__init__.py +0 -12
- omnigenome/src/model/augmentation/model.py +0 -219
- omnigenome/src/model/classification/__init__.py +0 -12
- omnigenome/src/model/classification/model.py +0 -642
- omnigenome/src/model/embedding/__init__.py +0 -12
- omnigenome/src/model/embedding/model.py +0 -263
- omnigenome/src/model/mlm/__init__.py +0 -12
- omnigenome/src/model/mlm/model.py +0 -177
- omnigenome/src/model/module_utils.py +0 -232
- omnigenome/src/model/regression/__init__.py +0 -12
- omnigenome/src/model/regression/model.py +0 -786
- omnigenome/src/model/regression/resnet.py +0 -483
- omnigenome/src/model/rna_design/__init__.py +0 -12
- omnigenome/src/model/rna_design/model.py +0 -469
- omnigenome/src/model/seq2seq/__init__.py +0 -12
- omnigenome/src/model/seq2seq/model.py +0 -44
- omnigenome/src/tokenizer/__init__.py +0 -16
- omnigenome/src/tokenizer/bpe_tokenizer.py +0 -226
- omnigenome/src/tokenizer/kmers_tokenizer.py +0 -247
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +0 -249
- omnigenome/src/trainer/__init__.py +0 -14
- omnigenome/src/trainer/accelerate_trainer.py +0 -739
- omnigenome/src/trainer/hf_trainer.py +0 -75
- omnigenome/src/trainer/trainer.py +0 -579
- omnigenome/utility/__init__.py +0 -3
- omnigenome/utility/dataset_hub/__init__.py +0 -13
- omnigenome/utility/dataset_hub/dataset_hub.py +0 -178
- omnigenome/utility/ensemble.py +0 -324
- omnigenome/utility/hub_utils.py +0 -517
- omnigenome/utility/model_hub/__init__.py +0 -12
- omnigenome/utility/model_hub/model_hub.py +0 -231
- omnigenome/utility/pipeline_hub/__init__.py +0 -12
- omnigenome/utility/pipeline_hub/pipeline.py +0 -483
- omnigenome/utility/pipeline_hub/pipeline_hub.py +0 -129
- omnigenome-0.3.0a1.dist-info/RECORD +0 -78
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/entry_points.txt +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/licenses/LICENSE +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/top_level.txt +0 -0

--- omnigenome/src/model/embedding/model.py (removed in 0.3.3a0)
@@ -1,263 +0,0 @@
-# -*- coding: utf-8 -*-
-# file: model.py
-# time: 18:37 22/09/2024
-# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
-# github: https://github.com/yangheng95
-# huggingface: https://huggingface.co/yangheng
-# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
-# Copyright (C) 2019-2024. All Rights Reserved.
-
-import torch
-from transformers import AutoTokenizer, AutoModel
-
-from omnigenome.src.misc.utils import fprint
-
-
-class OmniModelForEmbedding(torch.nn.Module):
-    """
-    A wrapper class for generating embeddings from pre-trained models.
-
-    This class provides a unified interface for loading pre-trained models and
-    generating embeddings from genomic sequences. It supports various aggregation
-    methods and batch processing for efficient embedding generation.
-
-    Attributes:
-        tokenizer: The tokenizer for processing input sequences
-        model: The pre-trained model for generating embeddings
-        _device: The device (CPU/GPU) where the model is loaded
-
-    Example:
-        >>> from omnigenome import OmniModelForEmbedding
-        >>> model = OmniModelForEmbedding("anonymous8/OmniGenome-186M")
-        >>> sequences = ["ATCGGCTA", "GGCTAGCTA"]
-        >>> embeddings = model.batch_encode(sequences)
-        >>> print(f"Embeddings shape: {embeddings.shape}")
-        torch.Size([2, 768])
-    """
-
-    def __init__(self, model_name_or_path, *args, **kwargs):
-        """
-        Initialize the embedding model.
-
-        Args:
-            model_name_or_path (str): Name or path of the pre-trained model to load
-            *args: Additional positional arguments passed to AutoModel.from_pretrained
-            **kwargs: Additional keyword arguments passed to AutoModel.from_pretrained
-        """
-        super().__init__()
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
-        self.model = AutoModel.from_pretrained(model_name_or_path, *args, **kwargs)
-        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model.to(self._device)
-        self.model.eval()  # Set model to evaluation mode
-
-    def batch_encode(self, sequences, batch_size=8, max_length=512, agg='head'):
-        """
-        Encode a list of sequences to their corresponding embeddings.
-
-        This method processes sequences in batches for memory efficiency and
-        supports different aggregation methods for the final embeddings.
-
-        Args:
-            sequences (list of str): List of input sequences to encode
-            batch_size (int, optional): Batch size for processing. Defaults to 8
-            max_length (int, optional): Maximum sequence length for encoding. Defaults to 512
-            agg (str, optional): Aggregation method for embeddings. Options are 'head', 'mean', 'tail'. Defaults to 'head'
-
-        Returns:
-            torch.Tensor: Embeddings for the input sequences with shape (n_sequences, embedding_dim)
-
-        Raises:
-            ValueError: If unsupported aggregation method is provided
-
-        Example:
-            >>> sequences = ["ATCGGCTA", "GGCTAGCTA", "TATCGCTA"]
-            >>> embeddings = model.batch_encode(sequences, batch_size=2, agg='mean')
-            >>> print(f"Embeddings shape: {embeddings.shape}")
-            torch.Size([3, 768])
-        """
-        embeddings = []
-
-        for i in range(0, len(sequences), batch_size):
-            batch_sequences = sequences[i: i + batch_size]
-            inputs = self.tokenizer(
-                batch_sequences,
-                return_tensors="pt",
-                padding=True,
-                truncation=True,
-                max_length=max_length,
-            )
-            inputs = {key: value.to(self.device) for key, value in inputs.items()}
-
-            with torch.no_grad():
-                outputs = self.model(**inputs)
-
-            batch_embeddings = outputs.last_hidden_state.cpu()
-
-            if agg == 'head':
-                emb = batch_embeddings[:, 0, :]
-            elif agg == 'mean':
-                attention_mask = inputs["attention_mask"].cpu()
-                masked_embeddings = batch_embeddings * attention_mask.unsqueeze(-1)
-                lengths = attention_mask.sum(dim=1).unsqueeze(1)
-                emb = masked_embeddings.sum(dim=1) / lengths
-            elif agg == 'tail':
-                attention_mask = inputs["attention_mask"]
-                lengths = attention_mask.sum(dim=1) - 1
-                emb = torch.stack([
-                    batch_embeddings[i, l.item()] for i, l in enumerate(lengths)
-                ])
-            else:
-                raise ValueError(f"Unsupported aggregation method: {agg}")
-
-            embeddings.append(emb)
-
-        embeddings = torch.cat(embeddings, dim=0)
-        fprint(f"Generated embeddings for {len(sequences)} sequences.")
-        return embeddings
-
-    def encode(self, sequence, max_length=512, agg='head', keep_dim=False):
-        """
-        Encode a single sequence to its corresponding embedding.
-
-        Args:
-            sequence (str): Input sequence to encode
-            max_length (int, optional): Maximum sequence length for encoding. Defaults to 512
-            agg (str, optional): Aggregation method. Options are 'head', 'mean', 'tail'. Defaults to 'head'
-            keep_dim (bool, optional): Whether to retain the batch dimension. Defaults to False
-
-        Returns:
-            torch.Tensor: Embedding for the input sequence
-
-        Raises:
-            ValueError: If unsupported aggregation method is provided
-
-        Example:
-            >>> sequence = "ATCGGCTA"
-            >>> embedding = model.encode(sequence, agg='mean')
-            >>> print(f"Embedding shape: {embedding.shape}")
-            torch.Size([768])
-        """
-        inputs = self.tokenizer(
-            sequence,
-            return_tensors="pt",
-            padding=True,
-            truncation=True,
-            max_length=max_length,
-        )
-        inputs = {key: value.to(self.device) for key, value in inputs.items()}
-
-        with torch.no_grad():
-            outputs = self.model(**inputs)
-
-        last_hidden = outputs.last_hidden_state.cpu()
-
-        if agg == 'head':
-            emb = last_hidden[0, 0]
-        elif agg == 'mean':
-            attention_mask = inputs["attention_mask"].cpu()
-            masked_embeddings = last_hidden * attention_mask.unsqueeze(-1)
-            lengths = attention_mask.sum(dim=1).unsqueeze(1)
-            emb = masked_embeddings.sum(dim=1) / lengths
-            emb = emb.squeeze(0)
-        elif agg == 'tail':
-            attention_mask = inputs["attention_mask"]
-            lengths = attention_mask.sum(dim=1) - 1
-            emb = last_hidden[0, lengths[0].item()]
-        else:
-            raise ValueError(f"Unsupported aggregation method: {agg}")
-
-        return emb.unsqueeze(0) if keep_dim else emb
-
-    def save_embeddings(self, embeddings, output_path):
-        """
-        Save the generated embeddings to a file.
-
-        Args:
-            embeddings (torch.Tensor): The embeddings to save
-            output_path (str): Path to save the embeddings
-
-        Example:
-            >>> embeddings = model.batch_encode(sequences)
-            >>> model.save_embeddings(embeddings, "embeddings.pt")
-            >>> print("Embeddings saved successfully")
-        """
-        torch.save(embeddings, output_path)
-        fprint(f"Embeddings saved to {output_path}")
-
-    def load_embeddings(self, embedding_path):
-        """
-        Load embeddings from a file.
-
-        Args:
-            embedding_path (str): Path to the saved embeddings
-
-        Returns:
-            torch.Tensor: The loaded embeddings
-
-        Example:
-            >>> embeddings = model.load_embeddings("embeddings.pt")
-            >>> print(f"Loaded embeddings shape: {embeddings.shape}")
-            torch.Size([100, 768])
-        """
-        embeddings = torch.load(embedding_path)
-        fprint(f"Loaded embeddings from {embedding_path}")
-        return embeddings
-
-    def compute_similarity(self, embedding1, embedding2, dim=0):
-        """
-        Compute cosine similarity between two embeddings.
-
-        Args:
-            embedding1 (torch.Tensor): The first embedding
-            embedding2 (torch.Tensor): The second embedding
-            dim (int, optional): Dimension along which to compute cosine similarity. Defaults to 0
-
-        Returns:
-            float: Cosine similarity score between -1 and 1
-
-        Example:
-            >>> emb1 = model.encode("ATCGGCTA")
-            >>> emb2 = model.encode("GGCTAGCTA")
-            >>> similarity = model.compute_similarity(emb1, emb2)
-            >>> print(f"Cosine similarity: {similarity:.4f}")
-            0.8234
-        """
-        similarity = torch.nn.functional.cosine_similarity(
-            embedding1, embedding2, dim=dim
-        )
-        return similarity
-
-    @property
-    def device(self):
-        """
-        Get the current device ('cuda' or 'cpu').
-
-        Returns:
-            torch.device: The device where the model is loaded
-        """
-        return self._device
-
-
-# Example usage
-if __name__ == "__main__":
-    model_name = "anonymous8/OmniGenome-186M"
-    embedding_model = OmniModelForEmbedding(model_name)
-
-    # Encode multiple sequences
-    sequences = ["ATCGGCTA", "GGCTAGCTA"]
-    embedding = embedding_model.encode(sequences[0])
-    fprint(f"Single embedding shape: {embedding.shape}")
-
-    embeddings = embedding_model.batch_encode(sequences)
-    fprint(f"Embeddings for sequences: {embeddings}")
-
-    # Save and load embeddings
-    embedding_model.save_embeddings(embeddings, "embeddings.pt")
-    loaded_embeddings = embedding_model.load_embeddings("embeddings.pt")
-
-    # Compute similarity between two embeddings
-    similarity = embedding_model.compute_similarity(
-        loaded_embeddings[0], loaded_embeddings[1]
-    )
-    fprint(f"Cosine similarity: {similarity}")
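
Note: the hunk above removes the embedding wrapper entirely. For readers comparing against 0.3.0a1, here is a minimal usage sketch of the removed class, built only from names that appear in the removed file itself (the `from omnigenome import OmniModelForEmbedding` path and the `anonymous8/OmniGenome-186M` checkpoint both come from its docstrings):

    import torch
    from omnigenome import OmniModelForEmbedding  # import path as shown in the removed docstring

    model = OmniModelForEmbedding("anonymous8/OmniGenome-186M")
    sequences = ["ATCGGCTA", "GGCTAGCTA"]

    # 'head' pooling takes the first (CLS-style) token of each sequence -> (2, hidden_size)
    embeddings = model.batch_encode(sequences, batch_size=2, agg="head")

    # single-sequence, attention-masked mean pooling -> (hidden_size,)
    emb = model.encode(sequences[0], agg="mean")

    # cosine similarity between the two batch embeddings (scalar tensor)
    sim = model.compute_similarity(embeddings[0], embeddings[1])
    print(float(sim))

The `agg` argument selects CLS-token ('head'), attention-masked mean ('mean'), or last-token ('tail') pooling, mirroring the three branches in `batch_encode` and `encode` above.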

--- omnigenome/src/model/mlm/__init__.py (removed in 0.3.3a0)
@@ -1,12 +0,0 @@
-# -*- coding: utf-8 -*-
-# file: __init__.py
-# time: 13:30 10/04/2024
-# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
-# github: https://github.com/yangheng95
-# huggingface: https://huggingface.co/yangheng
-# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
-# Copyright (C) 2019-2024. All Rights Reserved.
-"""
-This package contains modules for Masked Language Models (MLM).
-"""
-

--- omnigenome/src/model/mlm/model.py (removed in 0.3.3a0)
@@ -1,177 +0,0 @@
-# -*- coding: utf-8 -*-
-# file: model.py
-# time: 13:30 10/04/2024
-# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
-# github: https://github.com/yangheng95
-# huggingface: https://huggingface.co/yangheng
-# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
-# Copyright (C) 2019-2024. All Rights Reserved.
-"""
-Masked Language Model (MLM) for genomic sequences.
-
-This module provides a masked language model implementation specifically designed
-for genomic sequences. It supports masked language modeling tasks where tokens
-are randomly masked and the model learns to predict the original tokens.
-"""
-import numpy as np
-import torch
-from transformers import BatchEncoding
-
-from ...abc.abstract_model import OmniModel
-
-
-class OmniModelForMLM(OmniModel):
-    """
-    Masked Language Model for genomic sequences.
-
-    This model implements masked language modeling for genomic sequences, where
-    tokens are randomly masked and the model learns to predict the original tokens.
-    It's useful for pre-training genomic language models and understanding sequence
-    patterns and dependencies.
-
-    Attributes:
-        loss_fn: Cross-entropy loss function for masked language modeling
-    """
-
-    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
-        """
-        Initialize the MLM model.
-
-        Args:
-            config_or_model: Model configuration or pre-trained model
-            tokenizer: Tokenizer for processing input sequences
-            *args: Additional positional arguments
-            **kwargs: Additional keyword arguments
-
-        Raises:
-            ValueError: If the model doesn't support masked language modeling
-        """
-        super().__init__(config_or_model, tokenizer, *args, **kwargs)
-        self.metadata["model_name"] = self.__class__.__name__
-        if "MaskedLM" not in self.model.__class__.__name__:
-            raise ValueError(
-                "The model does not have a language model head, which is required for MLM."
-                "Please use a model that supports masked language modeling."
-            )
-
-        self.loss_fn = torch.nn.CrossEntropyLoss()
-
-    def forward(self, **inputs):
-        """
-        Forward pass for masked language modeling.
-
-        Args:
-            **inputs: Input tensors including input_ids, attention_mask, and labels
-
-        Returns:
-            dict: Dictionary containing loss, logits, and last_hidden_state
-        """
-        inputs = inputs.pop("inputs")
-        outputs = self.model(**inputs, output_hidden_states=True)
-        last_hidden_state = (
-            outputs["last_hidden_state"]
-            if "last_hidden_state" in outputs
-            else outputs["hidden_states"][-1]
-        )
-        logits = outputs["logits"] if "logits" in outputs else None
-        loss = outputs["loss"] if "loss" in outputs else None
-        outputs = {
-            "loss": loss,
-            "logits": logits,
-            "last_hidden_state": last_hidden_state,
-        }
-        return outputs
-
-    def predict(self, sequence_or_inputs, **kwargs):
-        """
-        Generate predictions for masked language modeling.
-
-        Args:
-            sequence_or_inputs: Input sequences or pre-processed inputs
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            dict: Dictionary containing predictions, logits, and last_hidden_state
-        """
-        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-
-        logits = raw_outputs["logits"]
-        last_hidden_state = raw_outputs["last_hidden_state"]
-
-        predictions = []
-        for i in range(logits.shape[0]):
-            predictions.append(logits[i].argmax(dim=-1).cpu())
-
-        if not isinstance(sequence_or_inputs, list):
-            outputs = {
-                "predictions": predictions[0],
-                "logits": logits[0],
-                "last_hidden_state": last_hidden_state[0],
-            }
-        else:
-            outputs = {
-                "predictions": (
-                    torch.stack(predictions)
-                    if predictions[0].shape
-                    else torch.tensor(predictions).to(self.model.device)
-                ),
-                "logits": logits,
-                "last_hidden_state": last_hidden_state,
-            }
-
-        return outputs
-
-    def inference(self, sequence_or_inputs, **kwargs):
-        """
-        Perform inference for masked language modeling, decoding predictions to sequences.
-
-        Args:
-            sequence_or_inputs: Input sequences or pre-processed inputs
-            **kwargs: Additional keyword arguments
-
-        Returns:
-            dict: Dictionary containing decoded predictions, logits, and last_hidden_state
-        """
-        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-
-        inputs = raw_outputs["inputs"]
-        logits = raw_outputs["logits"]
-        last_hidden_state = raw_outputs["last_hidden_state"]
-
-        predictions = []
-        for i in range(logits.shape[0]):
-            i_logit = logits[i][inputs["input_ids"][i].ne(self.config.pad_token_id)][
-                1:-1
-            ]
-            prediction = self.tokenizer.decode(i_logit.argmax(dim=-1)).replace(" ", "")
-            predictions.append(list(prediction))
-
-        if not isinstance(sequence_or_inputs, list):
-            outputs = {
-                "predictions": predictions[0],
-                "logits": logits[0],
-                "last_hidden_state": last_hidden_state[0],
-            }
-        else:
-            outputs = {
-                "predictions": predictions,
-                "logits": logits,
-                "last_hidden_state": last_hidden_state,
-            }
-
-        return outputs
-
-    def loss_function(self, logits, labels):
-        """
-        Compute the loss for masked language modeling.
-
-        Args:
-            logits (torch.Tensor): Model predictions [batch_size, seq_len, vocab_size]
-            labels (torch.Tensor): Ground truth labels [batch_size, seq_len]
-
-        Returns:
-            torch.Tensor: Computed cross-entropy loss value
-        """
-        loss_fn = torch.nn.CrossEntropyLoss()
-        loss = loss_fn(logits.view(-1, self.tokenizer.vocab_size), labels.view(-1))
-        return loss
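
Note: the removed `OmniModelForMLM` is a thin wrapper over a `*ForMaskedLM` backbone: `predict` returns per-position argmax token ids, while `inference` drops padded positions, strips the two special tokens at the ends, and decodes the argmaxes back to a list of characters. A hedged sketch of how it was driven under 0.3.0a1 follows; the top-level import, the checkpoint name, and whether a plain Hugging Face tokenizer is accepted by the `OmniModel` base class are assumptions, not verified against this diff:

    from transformers import AutoTokenizer
    from omnigenome import OmniModelForMLM   # assumption: top-level export, mirroring the embedding docstring

    base = "anonymous8/OmniGenome-186M"      # checkpoint name reused from the embedding example above
    tokenizer = AutoTokenizer.from_pretrained(base)  # assumption: a plain HF tokenizer is acceptable here
    mlm = OmniModelForMLM(base, tokenizer)

    # inference() decodes the argmax token at every kept position, so the result is
    # the model's reconstruction of the input sequence (masked positions included)
    out = mlm.inference("ATCGGCTAATCGGCTA")
    print("".join(out["predictions"]))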

--- omnigenome/src/model/module_utils.py (removed in 0.3.3a0)
@@ -1,232 +0,0 @@
-# -*- coding: utf-8 -*-
-# file: module_utils.py
-# time: 22:53 18/07/2024
-# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
-# github: https://github.com/yangheng95
-# huggingface: https://huggingface.co/yangheng
-# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
-# Copyright (C) 2019-2024. All Rights Reserved.
-"""
-Module utilities for OmniGenome models.
-
-This module provides utility classes and functions for handling model inputs,
-pooling operations, and attention mechanisms used across different OmniGenome model types.
-"""
-import torch
-import torch.nn as nn
-
-from transformers.models.bert.modeling_bert import BertPooler
-from transformers.tokenization_utils_base import BatchEncoding
-
-
-class OmniPooling(torch.nn.Module):
-    """
-    A flexible pooling layer for OmniGenome models that handles different input formats.
-
-    This class provides a unified interface for pooling operations across different
-    model architectures, supporting both causal language models and encoder-based models.
-    It can handle various input formats including tuples, dictionaries, BatchEncoding
-    objects, and tensors.
-
-    Attributes:
-        config: Model configuration object containing architecture and tokenizer settings
-        pooler: BertPooler instance for non-causal models, None for causal models
-    """
-
-    def __init__(self, config, *args, **kwargs):
-        """
-        Initialize the OmniPooling layer.
-
-        Args:
-            config: Model configuration object containing architecture information
-            *args: Additional positional arguments
-            **kwargs: Additional keyword arguments
-        """
-        super().__init__(*args, **kwargs)
-        self.config = config
-        self.pooler = BertPooler(self.config) if not self._is_causal_lm() else None
-
-    def forward(self, inputs, last_hidden_state):
-        """
-        Perform pooling operation on the last hidden state.
-
-        This method handles different input formats and applies appropriate pooling:
-        - For causal language models: Uses the last non-padded token
-        - For encoder models: Uses the BertPooler
-
-        Args:
-            inputs: Input data in various formats (tuple, dict, BatchEncoding, or tensor)
-            last_hidden_state (torch.Tensor): Hidden states from the model [batch_size, seq_len, hidden_size]
-
-        Returns:
-            torch.Tensor: Pooled representation [batch_size, hidden_size]
-
-        Raises:
-            ValueError: If input format is not supported or cannot be parsed
-        """
-        if isinstance(inputs, tuple):
-            input_ids = inputs[0]
-            attention_mask = inputs[1] if len(inputs) > 1 else None
-        elif isinstance(inputs, BatchEncoding) or isinstance(inputs, dict):
-            input_ids = inputs["input_ids"]
-            attention_mask = (
-                inputs["attention_mask"] if "attention_mask" in inputs else None
-            )
-        elif isinstance(inputs, torch.Tensor):
-            shape = inputs.shape
-            try:
-                if len(shape) == 3:
-                    # compatible with hf_trainer in AutoBenchmark
-                    if shape[1] == 2:
-                        input_ids = inputs[:, 0]
-                        attention_mask = inputs[:, 1]
-                    else:
-                        input_ids = inputs[0]
-                        attention_mask = inputs[1] if len(inputs) > 1 else None
-                elif len(shape) == 2:
-                    input_ids = inputs
-                    attention_mask = None
-            except:
-                raise ValueError(
-                    f"Failed to get the input_ids and attention_mask from the inputs, got shape {shape}."
-                )
-        else:
-            raise ValueError(
-                f"The inputs should be a tuple, BatchEncoding or a dictionary-like object, got {type(inputs)}."
-            )
-
-        if not self.pooler:
-            pad_token_id = getattr(self.config, "pad_token_id", -100)
-            sequence_lengths = input_ids.ne(pad_token_id).sum(dim=1) - 1
-            last_hidden_state = last_hidden_state[
-                torch.arange(input_ids.size(0), device=last_hidden_state.device),
-                sequence_lengths,
-            ]
-        else:
-            last_hidden_state = self.pooler(last_hidden_state)
-
-        return last_hidden_state
-
-    def _is_causal_lm(self):
-        """
-        Check if the model is a causal language model.
-
-        Determines if the model architecture is causal based on the configuration.
-
-        Returns:
-            bool: True if the model is a causal language model, False otherwise
-        """
-        if (
-            hasattr(self.config, "architectures")
-            and "CausalLM" in str(self.config.architectures)
-        ) or (
-            hasattr(self.config, "auto_map") and "CausalLM" in str(self.config.auto_map)
-        ):
-            return True
-        else:
-            return False
-
-
-# class InteractingAttention(nn.Module):
-#     def __init__(self, embed_size, num_heads=12):
-#         super(InteractingAttention, self).__init__()
-#         self.num_heads = num_heads
-#         self.embed_size = embed_size
-#
-#         assert embed_size % num_heads == 0, "Embedding size should be divisible by number of heads"
-#
-#         self.head_dim = embed_size // num_heads
-#
-#         self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
-#         self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
-#         self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
-#         self.layer_norm = nn.LayerNorm(num_heads * self.head_dim, eps=1e-6)
-#
-#         self.fc_out = nn.Linear(num_heads * self.head_dim, embed_size)
-#
-#     # def forward(self, query, keys, values):
-#     def forward(self, query, keys, values):
-#
-#         N = query.shape[0]
-#         value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
-#
-#         # Split embedding into self.num_heads pieces
-#         values = values.reshape(N, value_len, self.num_heads, self.head_dim)
-#         keys = keys.reshape(N, key_len, self.num_heads, self.head_dim)
-#         queries = query.reshape(N, query_len, self.num_heads, self.head_dim)
-#
-#         values = self.values(values)  # (N, value_len, heads, head_dim)
-#         keys = self.keys(keys)  # (N, key_len, heads, head_dim)
-#         queries = self.queries(queries)  # (N, query_len, heads, head_dim)
-#
-#         energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
-#
-#         attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)  # (N, heads, query_len, key_len)
-#
-#         out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
-#             N, query_len, self.num_heads * self.head_dim
-#         )
-#         out = self.layer_norm(out + query)
-#         out = self.fc_out(out)
-#         out = self.layer_norm(out + query)
-#         return out
-
-
-class InteractingAttention(nn.Module):
-    """
-    An interacting attention mechanism for sequence modeling.
-
-    This class implements a multi-head attention mechanism with residual connections
-    and layer normalization. It's designed for processing sequences where different
-    parts of the sequence need to interact with each other.
-
-    Attributes:
-        attention: Multi-head attention layer
-        layer_norm: Layer normalization for residual connections
-        fc_out: Output projection layer
-    """
-
-    def __init__(self, embed_size, num_heads=24):
-        """
-        Initialize the InteractingAttention module.
-
-        Args:
-            embed_size (int): Size of the embedding dimension
-            num_heads (int): Number of attention heads (default: 24)
-
-        Raises:
-            AssertionError: If embed_size is not divisible by num_heads
-        """
-        super(InteractingAttention, self).__init__()
-        assert (
-            embed_size % num_heads == 0
-        ), "Embedding size should be divisible by number of heads"
-
-        self.attention = nn.MultiheadAttention(
-            embed_dim=embed_size, num_heads=num_heads, batch_first=True
-        )
-
-        self.layer_norm = nn.LayerNorm(embed_size, eps=1e-6)
-
-        self.fc_out = nn.Linear(embed_size, embed_size)
-
-    def forward(self, query, keys, values):
-        """
-        Forward pass through the interacting attention mechanism.
-
-        Args:
-            query (torch.Tensor): Query tensor [batch_size, query_len, embed_size]
-            keys (torch.Tensor): Key tensor [batch_size, key_len, embed_size]
-            values (torch.Tensor): Value tensor [batch_size, value_len, embed_size]
-
-        Returns:
-            torch.Tensor: Output tensor with same shape as query
-        """
-        att_output, _ = self.attention(query, keys, values)
-
-        query = self.layer_norm(att_output + query)
-
-        output = self.fc_out(query)
-        output = self.layer_norm(output + query)
-
-        return output
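
Note: the two utilities removed above are self-contained, so their behaviour is easy to exercise with dummy tensors. Below is a sketch that assumes `OmniPooling` and `InteractingAttention` (as defined in the removed module) are in scope; the `BertConfig` stand-in is an assumption, chosen only because `BertPooler` needs `hidden_size` and the config must not advertise a CausalLM architecture:

    import torch
    from transformers import BertConfig
    # assumes OmniPooling and InteractingAttention from the removed module_utils.py are importable/in scope

    config = BertConfig(hidden_size=768)     # any non-causal encoder config; only hidden_size is needed here
    pool = OmniPooling(config)

    batch = {
        "input_ids": torch.randint(5, 100, (2, 16)),
        "attention_mask": torch.ones(2, 16, dtype=torch.long),
    }
    hidden = torch.randn(2, 16, 768)
    pooled = pool(batch, hidden)             # encoder path: BertPooler over the first token -> (2, 768)

    attn = InteractingAttention(embed_size=768, num_heads=24)
    x = torch.randn(2, 16, 768)
    y = attn(x, x, x)                        # self-interaction; output keeps the query shape (2, 16, 768)

For a causal config (one whose `architectures` or `auto_map` mentions CausalLM), `OmniPooling` instead gathers the hidden state of the last non-padded token per sequence.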

--- omnigenome/src/model/regression/__init__.py (removed in 0.3.3a0)
@@ -1,12 +0,0 @@
-# -*- coding: utf-8 -*-
-# file: __init__.py
-# time: 21:11 08/04/2024
-# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
-# github: https://github.com/yangheng95
-# huggingface: https://huggingface.co/yangheng
-# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
-# Copyright (C) 2019-2024. All Rights Reserved.
-"""
-This package contains modules for regression models.
-"""
-