omnigenome 0.3.0a1__py3-none-any.whl → 0.3.3a0__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of omnigenome may be problematic.

Files changed (79)
  1. omnigenome/__init__.py +252 -258
  2. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/METADATA +10 -10
  3. omnigenome-0.3.3a0.dist-info/RECORD +7 -0
  4. omnigenome/auto/__init__.py +0 -3
  5. omnigenome/auto/auto_bench/__init__.py +0 -12
  6. omnigenome/auto/auto_bench/auto_bench.py +0 -484
  7. omnigenome/auto/auto_bench/auto_bench_cli.py +0 -230
  8. omnigenome/auto/auto_bench/auto_bench_config.py +0 -216
  9. omnigenome/auto/auto_bench/config_check.py +0 -34
  10. omnigenome/auto/auto_train/__init__.py +0 -13
  11. omnigenome/auto/auto_train/auto_train.py +0 -430
  12. omnigenome/auto/auto_train/auto_train_cli.py +0 -222
  13. omnigenome/auto/bench_hub/__init__.py +0 -12
  14. omnigenome/auto/bench_hub/bench_hub.py +0 -25
  15. omnigenome/cli/__init__.py +0 -13
  16. omnigenome/cli/commands/__init__.py +0 -13
  17. omnigenome/cli/commands/base.py +0 -83
  18. omnigenome/cli/commands/bench/__init__.py +0 -13
  19. omnigenome/cli/commands/bench/bench_cli.py +0 -202
  20. omnigenome/cli/commands/rna/__init__.py +0 -13
  21. omnigenome/cli/commands/rna/rna_design.py +0 -178
  22. omnigenome/cli/omnigenome_cli.py +0 -128
  23. omnigenome/src/__init__.py +0 -12
  24. omnigenome/src/abc/__init__.py +0 -12
  25. omnigenome/src/abc/abstract_dataset.py +0 -622
  26. omnigenome/src/abc/abstract_metric.py +0 -114
  27. omnigenome/src/abc/abstract_model.py +0 -689
  28. omnigenome/src/abc/abstract_tokenizer.py +0 -267
  29. omnigenome/src/dataset/__init__.py +0 -16
  30. omnigenome/src/dataset/omni_dataset.py +0 -435
  31. omnigenome/src/lora/__init__.py +0 -13
  32. omnigenome/src/lora/lora_model.py +0 -294
  33. omnigenome/src/metric/__init__.py +0 -15
  34. omnigenome/src/metric/classification_metric.py +0 -184
  35. omnigenome/src/metric/metric.py +0 -199
  36. omnigenome/src/metric/ranking_metric.py +0 -142
  37. omnigenome/src/metric/regression_metric.py +0 -191
  38. omnigenome/src/misc/__init__.py +0 -3
  39. omnigenome/src/misc/utils.py +0 -499
  40. omnigenome/src/model/__init__.py +0 -19
  41. omnigenome/src/model/augmentation/__init__.py +0 -12
  42. omnigenome/src/model/augmentation/model.py +0 -219
  43. omnigenome/src/model/classification/__init__.py +0 -12
  44. omnigenome/src/model/classification/model.py +0 -642
  45. omnigenome/src/model/embedding/__init__.py +0 -12
  46. omnigenome/src/model/embedding/model.py +0 -263
  47. omnigenome/src/model/mlm/__init__.py +0 -12
  48. omnigenome/src/model/mlm/model.py +0 -177
  49. omnigenome/src/model/module_utils.py +0 -232
  50. omnigenome/src/model/regression/__init__.py +0 -12
  51. omnigenome/src/model/regression/model.py +0 -786
  52. omnigenome/src/model/regression/resnet.py +0 -483
  53. omnigenome/src/model/rna_design/__init__.py +0 -12
  54. omnigenome/src/model/rna_design/model.py +0 -469
  55. omnigenome/src/model/seq2seq/__init__.py +0 -12
  56. omnigenome/src/model/seq2seq/model.py +0 -44
  57. omnigenome/src/tokenizer/__init__.py +0 -16
  58. omnigenome/src/tokenizer/bpe_tokenizer.py +0 -226
  59. omnigenome/src/tokenizer/kmers_tokenizer.py +0 -247
  60. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +0 -249
  61. omnigenome/src/trainer/__init__.py +0 -14
  62. omnigenome/src/trainer/accelerate_trainer.py +0 -739
  63. omnigenome/src/trainer/hf_trainer.py +0 -75
  64. omnigenome/src/trainer/trainer.py +0 -579
  65. omnigenome/utility/__init__.py +0 -3
  66. omnigenome/utility/dataset_hub/__init__.py +0 -13
  67. omnigenome/utility/dataset_hub/dataset_hub.py +0 -178
  68. omnigenome/utility/ensemble.py +0 -324
  69. omnigenome/utility/hub_utils.py +0 -517
  70. omnigenome/utility/model_hub/__init__.py +0 -12
  71. omnigenome/utility/model_hub/model_hub.py +0 -231
  72. omnigenome/utility/pipeline_hub/__init__.py +0 -12
  73. omnigenome/utility/pipeline_hub/pipeline.py +0 -483
  74. omnigenome/utility/pipeline_hub/pipeline_hub.py +0 -129
  75. omnigenome-0.3.0a1.dist-info/RECORD +0 -78
  76. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/WHEEL +0 -0
  77. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/entry_points.txt +0 -0
  78. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/licenses/LICENSE +0 -0
  79. {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.3a0.dist-info}/top_level.txt +0 -0
omnigenome/src/model/embedding/model.py
@@ -1,263 +0,0 @@
- # -*- coding: utf-8 -*-
- # file: model.py
- # time: 18:37 22/09/2024
- # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
- # github: https://github.com/yangheng95
- # huggingface: https://huggingface.co/yangheng
- # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
- # Copyright (C) 2019-2024. All Rights Reserved.
-
- import torch
- from transformers import AutoTokenizer, AutoModel
-
- from omnigenome.src.misc.utils import fprint
-
-
- class OmniModelForEmbedding(torch.nn.Module):
-     """
-     A wrapper class for generating embeddings from pre-trained models.
-
-     This class provides a unified interface for loading pre-trained models and
-     generating embeddings from genomic sequences. It supports various aggregation
-     methods and batch processing for efficient embedding generation.
-
-     Attributes:
-         tokenizer: The tokenizer for processing input sequences
-         model: The pre-trained model for generating embeddings
-         _device: The device (CPU/GPU) where the model is loaded
-
-     Example:
-         >>> from omnigenome import OmniModelForEmbedding
-         >>> model = OmniModelForEmbedding("anonymous8/OmniGenome-186M")
-         >>> sequences = ["ATCGGCTA", "GGCTAGCTA"]
-         >>> embeddings = model.batch_encode(sequences)
-         >>> print(f"Embeddings shape: {embeddings.shape}")
-         torch.Size([2, 768])
-     """
-
-     def __init__(self, model_name_or_path, *args, **kwargs):
-         """
-         Initialize the embedding model.
-
-         Args:
-             model_name_or_path (str): Name or path of the pre-trained model to load
-             *args: Additional positional arguments passed to AutoModel.from_pretrained
-             **kwargs: Additional keyword arguments passed to AutoModel.from_pretrained
-         """
-         super().__init__()
-         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
-         self.model = AutoModel.from_pretrained(model_name_or_path, *args, **kwargs)
-         self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         self.model.to(self._device)
-         self.model.eval()  # Set model to evaluation mode
-
-     def batch_encode(self, sequences, batch_size=8, max_length=512, agg='head'):
-         """
-         Encode a list of sequences to their corresponding embeddings.
-
-         This method processes sequences in batches for memory efficiency and
-         supports different aggregation methods for the final embeddings.
-
-         Args:
-             sequences (list of str): List of input sequences to encode
-             batch_size (int, optional): Batch size for processing. Defaults to 8
-             max_length (int, optional): Maximum sequence length for encoding. Defaults to 512
-             agg (str, optional): Aggregation method for embeddings. Options are 'head', 'mean', 'tail'. Defaults to 'head'
-
-         Returns:
-             torch.Tensor: Embeddings for the input sequences with shape (n_sequences, embedding_dim)
-
-         Raises:
-             ValueError: If unsupported aggregation method is provided
-
-         Example:
-             >>> sequences = ["ATCGGCTA", "GGCTAGCTA", "TATCGCTA"]
-             >>> embeddings = model.batch_encode(sequences, batch_size=2, agg='mean')
-             >>> print(f"Embeddings shape: {embeddings.shape}")
-             torch.Size([3, 768])
-         """
-         embeddings = []
-
-         for i in range(0, len(sequences), batch_size):
-             batch_sequences = sequences[i: i + batch_size]
-             inputs = self.tokenizer(
-                 batch_sequences,
-                 return_tensors="pt",
-                 padding=True,
-                 truncation=True,
-                 max_length=max_length,
-             )
-             inputs = {key: value.to(self.device) for key, value in inputs.items()}
-
-             with torch.no_grad():
-                 outputs = self.model(**inputs)
-
-             batch_embeddings = outputs.last_hidden_state.cpu()
-
-             if agg == 'head':
-                 emb = batch_embeddings[:, 0, :]
-             elif agg == 'mean':
-                 attention_mask = inputs["attention_mask"].cpu()
-                 masked_embeddings = batch_embeddings * attention_mask.unsqueeze(-1)
-                 lengths = attention_mask.sum(dim=1).unsqueeze(1)
-                 emb = masked_embeddings.sum(dim=1) / lengths
-             elif agg == 'tail':
-                 attention_mask = inputs["attention_mask"]
-                 lengths = attention_mask.sum(dim=1) - 1
-                 emb = torch.stack([
-                     batch_embeddings[i, l.item()] for i, l in enumerate(lengths)
-                 ])
-             else:
-                 raise ValueError(f"Unsupported aggregation method: {agg}")
-
-             embeddings.append(emb)
-
-         embeddings = torch.cat(embeddings, dim=0)
-         fprint(f"Generated embeddings for {len(sequences)} sequences.")
-         return embeddings
-
-     def encode(self, sequence, max_length=512, agg='head', keep_dim=False):
-         """
-         Encode a single sequence to its corresponding embedding.
-
-         Args:
-             sequence (str): Input sequence to encode
-             max_length (int, optional): Maximum sequence length for encoding. Defaults to 512
-             agg (str, optional): Aggregation method. Options are 'head', 'mean', 'tail'. Defaults to 'head'
-             keep_dim (bool, optional): Whether to retain the batch dimension. Defaults to False
-
-         Returns:
-             torch.Tensor: Embedding for the input sequence
-
-         Raises:
-             ValueError: If unsupported aggregation method is provided
-
-         Example:
-             >>> sequence = "ATCGGCTA"
-             >>> embedding = model.encode(sequence, agg='mean')
-             >>> print(f"Embedding shape: {embedding.shape}")
-             torch.Size([768])
-         """
-         inputs = self.tokenizer(
-             sequence,
-             return_tensors="pt",
-             padding=True,
-             truncation=True,
-             max_length=max_length,
-         )
-         inputs = {key: value.to(self.device) for key, value in inputs.items()}
-
-         with torch.no_grad():
-             outputs = self.model(**inputs)
-
-         last_hidden = outputs.last_hidden_state.cpu()
-
-         if agg == 'head':
-             emb = last_hidden[0, 0]
-         elif agg == 'mean':
-             attention_mask = inputs["attention_mask"].cpu()
-             masked_embeddings = last_hidden * attention_mask.unsqueeze(-1)
-             lengths = attention_mask.sum(dim=1).unsqueeze(1)
-             emb = masked_embeddings.sum(dim=1) / lengths
-             emb = emb.squeeze(0)
-         elif agg == 'tail':
-             attention_mask = inputs["attention_mask"]
-             lengths = attention_mask.sum(dim=1) - 1
-             emb = last_hidden[0, lengths[0].item()]
-         else:
-             raise ValueError(f"Unsupported aggregation method: {agg}")
-
-         return emb.unsqueeze(0) if keep_dim else emb
-
-     def save_embeddings(self, embeddings, output_path):
-         """
-         Save the generated embeddings to a file.
-
-         Args:
-             embeddings (torch.Tensor): The embeddings to save
-             output_path (str): Path to save the embeddings
-
-         Example:
-             >>> embeddings = model.batch_encode(sequences)
-             >>> model.save_embeddings(embeddings, "embeddings.pt")
-             >>> print("Embeddings saved successfully")
-         """
-         torch.save(embeddings, output_path)
-         fprint(f"Embeddings saved to {output_path}")
-
-     def load_embeddings(self, embedding_path):
-         """
-         Load embeddings from a file.
-
-         Args:
-             embedding_path (str): Path to the saved embeddings
-
-         Returns:
-             torch.Tensor: The loaded embeddings
-
-         Example:
-             >>> embeddings = model.load_embeddings("embeddings.pt")
-             >>> print(f"Loaded embeddings shape: {embeddings.shape}")
-             torch.Size([100, 768])
-         """
-         embeddings = torch.load(embedding_path)
-         fprint(f"Loaded embeddings from {embedding_path}")
-         return embeddings
-
-     def compute_similarity(self, embedding1, embedding2, dim=0):
-         """
-         Compute cosine similarity between two embeddings.
-
-         Args:
-             embedding1 (torch.Tensor): The first embedding
-             embedding2 (torch.Tensor): The second embedding
-             dim (int, optional): Dimension along which to compute cosine similarity. Defaults to 0
-
-         Returns:
-             float: Cosine similarity score between -1 and 1
-
-         Example:
-             >>> emb1 = model.encode("ATCGGCTA")
-             >>> emb2 = model.encode("GGCTAGCTA")
-             >>> similarity = model.compute_similarity(emb1, emb2)
-             >>> print(f"Cosine similarity: {similarity:.4f}")
-             0.8234
-         """
-         similarity = torch.nn.functional.cosine_similarity(
-             embedding1, embedding2, dim=dim
-         )
-         return similarity
-
-     @property
-     def device(self):
-         """
-         Get the current device ('cuda' or 'cpu').
-
-         Returns:
-             torch.device: The device where the model is loaded
-         """
-         return self._device
-
-
- # Example usage
- if __name__ == "__main__":
-     model_name = "anonymous8/OmniGenome-186M"
-     embedding_model = OmniModelForEmbedding(model_name)
-
-     # Encode multiple sequences
-     sequences = ["ATCGGCTA", "GGCTAGCTA"]
-     embedding = embedding_model.encode(sequences[0])
-     fprint(f"Single embedding shape: {embedding.shape}")
-
-     embeddings = embedding_model.batch_encode(sequences)
-     fprint(f"Embeddings for sequences: {embeddings}")
-
-     # Save and load embeddings
-     embedding_model.save_embeddings(embeddings, "embeddings.pt")
-     loaded_embeddings = embedding_model.load_embeddings("embeddings.pt")
-
-     # Compute similarity between two embeddings
-     similarity = embedding_model.compute_similarity(
-         loaded_embeddings[0], loaded_embeddings[1]
-     )
-     fprint(f"Cosine similarity: {similarity}")
omnigenome/src/model/mlm/__init__.py
@@ -1,12 +0,0 @@
- # -*- coding: utf-8 -*-
- # file: __init__.py
- # time: 13:30 10/04/2024
- # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
- # github: https://github.com/yangheng95
- # huggingface: https://huggingface.co/yangheng
- # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
- # Copyright (C) 2019-2024. All Rights Reserved.
- """
- This package contains modules for Masked Language Models (MLM).
- """
-
omnigenome/src/model/mlm/model.py
@@ -1,177 +0,0 @@
- # -*- coding: utf-8 -*-
- # file: model.py
- # time: 13:30 10/04/2024
- # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
- # github: https://github.com/yangheng95
- # huggingface: https://huggingface.co/yangheng
- # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
- # Copyright (C) 2019-2024. All Rights Reserved.
- """
- Masked Language Model (MLM) for genomic sequences.
-
- This module provides a masked language model implementation specifically designed
- for genomic sequences. It supports masked language modeling tasks where tokens
- are randomly masked and the model learns to predict the original tokens.
- """
- import numpy as np
- import torch
- from transformers import BatchEncoding
-
- from ...abc.abstract_model import OmniModel
-
-
- class OmniModelForMLM(OmniModel):
-     """
-     Masked Language Model for genomic sequences.
-
-     This model implements masked language modeling for genomic sequences, where
-     tokens are randomly masked and the model learns to predict the original tokens.
-     It's useful for pre-training genomic language models and understanding sequence
-     patterns and dependencies.
-
-     Attributes:
-         loss_fn: Cross-entropy loss function for masked language modeling
-     """
-
-     def __init__(self, config_or_model, tokenizer, *args, **kwargs):
-         """
-         Initialize the MLM model.
-
-         Args:
-             config_or_model: Model configuration or pre-trained model
-             tokenizer: Tokenizer for processing input sequences
-             *args: Additional positional arguments
-             **kwargs: Additional keyword arguments
-
-         Raises:
-             ValueError: If the model doesn't support masked language modeling
-         """
-         super().__init__(config_or_model, tokenizer, *args, **kwargs)
-         self.metadata["model_name"] = self.__class__.__name__
-         if "MaskedLM" not in self.model.__class__.__name__:
-             raise ValueError(
-                 "The model does not have a language model head, which is required for MLM."
-                 "Please use a model that supports masked language modeling."
-             )
-
-         self.loss_fn = torch.nn.CrossEntropyLoss()
-
-     def forward(self, **inputs):
-         """
-         Forward pass for masked language modeling.
-
-         Args:
-             **inputs: Input tensors including input_ids, attention_mask, and labels
-
-         Returns:
-             dict: Dictionary containing loss, logits, and last_hidden_state
-         """
-         inputs = inputs.pop("inputs")
-         outputs = self.model(**inputs, output_hidden_states=True)
-         last_hidden_state = (
-             outputs["last_hidden_state"]
-             if "last_hidden_state" in outputs
-             else outputs["hidden_states"][-1]
-         )
-         logits = outputs["logits"] if "logits" in outputs else None
-         loss = outputs["loss"] if "loss" in outputs else None
-         outputs = {
-             "loss": loss,
-             "logits": logits,
-             "last_hidden_state": last_hidden_state,
-         }
-         return outputs
-
-     def predict(self, sequence_or_inputs, **kwargs):
-         """
-         Generate predictions for masked language modeling.
-
-         Args:
-             sequence_or_inputs: Input sequences or pre-processed inputs
-             **kwargs: Additional keyword arguments
-
-         Returns:
-             dict: Dictionary containing predictions, logits, and last_hidden_state
-         """
-         raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-
-         logits = raw_outputs["logits"]
-         last_hidden_state = raw_outputs["last_hidden_state"]
-
-         predictions = []
-         for i in range(logits.shape[0]):
-             predictions.append(logits[i].argmax(dim=-1).cpu())
-
-         if not isinstance(sequence_or_inputs, list):
-             outputs = {
-                 "predictions": predictions[0],
-                 "logits": logits[0],
-                 "last_hidden_state": last_hidden_state[0],
-             }
-         else:
-             outputs = {
-                 "predictions": (
-                     torch.stack(predictions)
-                     if predictions[0].shape
-                     else torch.tensor(predictions).to(self.model.device)
-                 ),
-                 "logits": logits,
-                 "last_hidden_state": last_hidden_state,
-             }
-
-         return outputs
-
-     def inference(self, sequence_or_inputs, **kwargs):
-         """
-         Perform inference for masked language modeling, decoding predictions to sequences.
-
-         Args:
-             sequence_or_inputs: Input sequences or pre-processed inputs
-             **kwargs: Additional keyword arguments
-
-         Returns:
-             dict: Dictionary containing decoded predictions, logits, and last_hidden_state
-         """
-         raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
-
-         inputs = raw_outputs["inputs"]
-         logits = raw_outputs["logits"]
-         last_hidden_state = raw_outputs["last_hidden_state"]
-
-         predictions = []
-         for i in range(logits.shape[0]):
-             i_logit = logits[i][inputs["input_ids"][i].ne(self.config.pad_token_id)][
-                 1:-1
-             ]
-             prediction = self.tokenizer.decode(i_logit.argmax(dim=-1)).replace(" ", "")
-             predictions.append(list(prediction))
-
-         if not isinstance(sequence_or_inputs, list):
-             outputs = {
-                 "predictions": predictions[0],
-                 "logits": logits[0],
-                 "last_hidden_state": last_hidden_state[0],
-             }
-         else:
-             outputs = {
-                 "predictions": predictions,
-                 "logits": logits,
-                 "last_hidden_state": last_hidden_state,
-             }
-
-         return outputs
-
-     def loss_function(self, logits, labels):
-         """
-         Compute the loss for masked language modeling.
-
-         Args:
-             logits (torch.Tensor): Model predictions [batch_size, seq_len, vocab_size]
-             labels (torch.Tensor): Ground truth labels [batch_size, seq_len]
-
-         Returns:
-             torch.Tensor: Computed cross-entropy loss value
-         """
-         loss_fn = torch.nn.CrossEntropyLoss()
-         loss = loss_fn(logits.view(-1, self.tokenizer.vocab_size), labels.view(-1))
-         return loss
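
The removed loss_function flattens token-level logits and labels before applying cross-entropy. A self-contained sketch of that computation with dummy tensors (shapes are illustrative only; in the real model vocab_size comes from the tokenizer):

import torch

# Stand-in shapes; the deleted code uses self.tokenizer.vocab_size instead of a constant.
batch_size, seq_len, vocab_size = 2, 8, 16
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))

# Same flattening the removed loss_function applies: [B, L, V] -> [B*L, V], [B, L] -> [B*L].
loss_fn = torch.nn.CrossEntropyLoss()
loss = loss_fn(logits.view(-1, vocab_size), labels.view(-1))
print(loss.item())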
omnigenome/src/model/module_utils.py
@@ -1,232 +0,0 @@
- # -*- coding: utf-8 -*-
- # file: module_utils.py
- # time: 22:53 18/07/2024
- # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
- # github: https://github.com/yangheng95
- # huggingface: https://huggingface.co/yangheng
- # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
- # Copyright (C) 2019-2024. All Rights Reserved.
- """
- Module utilities for OmniGenome models.
-
- This module provides utility classes and functions for handling model inputs,
- pooling operations, and attention mechanisms used across different OmniGenome model types.
- """
- import torch
- import torch.nn as nn
-
- from transformers.models.bert.modeling_bert import BertPooler
- from transformers.tokenization_utils_base import BatchEncoding
-
-
- class OmniPooling(torch.nn.Module):
-     """
-     A flexible pooling layer for OmniGenome models that handles different input formats.
-
-     This class provides a unified interface for pooling operations across different
-     model architectures, supporting both causal language models and encoder-based models.
-     It can handle various input formats including tuples, dictionaries, BatchEncoding
-     objects, and tensors.
-
-     Attributes:
-         config: Model configuration object containing architecture and tokenizer settings
-         pooler: BertPooler instance for non-causal models, None for causal models
-     """
-
-     def __init__(self, config, *args, **kwargs):
-         """
-         Initialize the OmniPooling layer.
-
-         Args:
-             config: Model configuration object containing architecture information
-             *args: Additional positional arguments
-             **kwargs: Additional keyword arguments
-         """
-         super().__init__(*args, **kwargs)
-         self.config = config
-         self.pooler = BertPooler(self.config) if not self._is_causal_lm() else None
-
-     def forward(self, inputs, last_hidden_state):
-         """
-         Perform pooling operation on the last hidden state.
-
-         This method handles different input formats and applies appropriate pooling:
-         - For causal language models: Uses the last non-padded token
-         - For encoder models: Uses the BertPooler
-
-         Args:
-             inputs: Input data in various formats (tuple, dict, BatchEncoding, or tensor)
-             last_hidden_state (torch.Tensor): Hidden states from the model [batch_size, seq_len, hidden_size]
-
-         Returns:
-             torch.Tensor: Pooled representation [batch_size, hidden_size]
-
-         Raises:
-             ValueError: If input format is not supported or cannot be parsed
-         """
-         if isinstance(inputs, tuple):
-             input_ids = inputs[0]
-             attention_mask = inputs[1] if len(inputs) > 1 else None
-         elif isinstance(inputs, BatchEncoding) or isinstance(inputs, dict):
-             input_ids = inputs["input_ids"]
-             attention_mask = (
-                 inputs["attention_mask"] if "attention_mask" in inputs else None
-             )
-         elif isinstance(inputs, torch.Tensor):
-             shape = inputs.shape
-             try:
-                 if len(shape) == 3:
-                     # compatible with hf_trainer in AutoBenchmark
-                     if shape[1] == 2:
-                         input_ids = inputs[:, 0]
-                         attention_mask = inputs[:, 1]
-                     else:
-                         input_ids = inputs[0]
-                         attention_mask = inputs[1] if len(inputs) > 1 else None
-                 elif len(shape) == 2:
-                     input_ids = inputs
-                     attention_mask = None
-             except:
-                 raise ValueError(
-                     f"Failed to get the input_ids and attention_mask from the inputs, got shape {shape}."
-                 )
-         else:
-             raise ValueError(
-                 f"The inputs should be a tuple, BatchEncoding or a dictionary-like object, got {type(inputs)}."
-             )
-
-         if not self.pooler:
-             pad_token_id = getattr(self.config, "pad_token_id", -100)
-             sequence_lengths = input_ids.ne(pad_token_id).sum(dim=1) - 1
-             last_hidden_state = last_hidden_state[
-                 torch.arange(input_ids.size(0), device=last_hidden_state.device),
-                 sequence_lengths,
-             ]
-         else:
-             last_hidden_state = self.pooler(last_hidden_state)
-
-         return last_hidden_state
-
-     def _is_causal_lm(self):
-         """
-         Check if the model is a causal language model.
-
-         Determines if the model architecture is causal based on the configuration.
-
-         Returns:
-             bool: True if the model is a causal language model, False otherwise
-         """
-         if (
-             hasattr(self.config, "architectures")
-             and "CausalLM" in str(self.config.architectures)
-         ) or (
-             hasattr(self.config, "auto_map") and "CausalLM" in str(self.config.auto_map)
-         ):
-             return True
-         else:
-             return False
-
-
- # class InteractingAttention(nn.Module):
- #     def __init__(self, embed_size, num_heads=12):
- #         super(InteractingAttention, self).__init__()
- #         self.num_heads = num_heads
- #         self.embed_size = embed_size
- #
- #         assert embed_size % num_heads == 0, "Embedding size should be divisible by number of heads"
- #
- #         self.head_dim = embed_size // num_heads
- #
- #         self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
- #         self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
- #         self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
- #         self.layer_norm = nn.LayerNorm(num_heads * self.head_dim, eps=1e-6)
- #
- #         self.fc_out = nn.Linear(num_heads * self.head_dim, embed_size)
- #
- #     # def forward(self, query, keys, values):
- #     def forward(self, query, keys, values):
- #
- #         N = query.shape[0]
- #         value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
- #
- #         # Split embedding into self.num_heads pieces
- #         values = values.reshape(N, value_len, self.num_heads, self.head_dim)
- #         keys = keys.reshape(N, key_len, self.num_heads, self.head_dim)
- #         queries = query.reshape(N, query_len, self.num_heads, self.head_dim)
- #
- #         values = self.values(values) # (N, value_len, heads, head_dim)
- #         keys = self.keys(keys) # (N, key_len, heads, head_dim)
- #         queries = self.queries(queries) # (N, query_len, heads, head_dim)
- #
- #         energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
- #
- #         attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3) # (N, heads, query_len, key_len)
- #
- #         out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
- #             N, query_len, self.num_heads * self.head_dim
- #         )
- #         out = self.layer_norm(out + query)
- #         out = self.fc_out(out)
- #         out = self.layer_norm(out + query)
- #         return out
-
-
- class InteractingAttention(nn.Module):
-     """
-     An interacting attention mechanism for sequence modeling.
-
-     This class implements a multi-head attention mechanism with residual connections
-     and layer normalization. It's designed for processing sequences where different
-     parts of the sequence need to interact with each other.
-
-     Attributes:
-         attention: Multi-head attention layer
-         layer_norm: Layer normalization for residual connections
-         fc_out: Output projection layer
-     """
-
-     def __init__(self, embed_size, num_heads=24):
-         """
-         Initialize the InteractingAttention module.
-
-         Args:
-             embed_size (int): Size of the embedding dimension
-             num_heads (int): Number of attention heads (default: 24)
-
-         Raises:
-             AssertionError: If embed_size is not divisible by num_heads
-         """
-         super(InteractingAttention, self).__init__()
-         assert (
-             embed_size % num_heads == 0
-         ), "Embedding size should be divisible by number of heads"
-
-         self.attention = nn.MultiheadAttention(
-             embed_dim=embed_size, num_heads=num_heads, batch_first=True
-         )
-
-         self.layer_norm = nn.LayerNorm(embed_size, eps=1e-6)
-
-         self.fc_out = nn.Linear(embed_size, embed_size)
-
-     def forward(self, query, keys, values):
-         """
-         Forward pass through the interacting attention mechanism.
-
-         Args:
-             query (torch.Tensor): Query tensor [batch_size, query_len, embed_size]
-             keys (torch.Tensor): Key tensor [batch_size, key_len, embed_size]
-             values (torch.Tensor): Value tensor [batch_size, value_len, embed_size]
-
-         Returns:
-             torch.Tensor: Output tensor with same shape as query
-         """
-         att_output, _ = self.attention(query, keys, values)
-
-         query = self.layer_norm(att_output + query)
-
-         output = self.fc_out(query)
-         output = self.layer_norm(output + query)
-
-         return output
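
The removed OmniPooling layer selects last-non-padding-token pooling for causal LMs and a BertPooler for encoder-style models, based on the config. A minimal sketch of exercising it in isolation, assuming an earlier release where omnigenome.src.model.module_utils is still importable; the BertConfig below is only a stand-in for a real model config:

import torch
from transformers import BertConfig

# Import path assumed from the package layout of earlier releases.
from omnigenome.src.model.module_utils import OmniPooling

# Encoder-style config (no "CausalLM" in architectures) -> the BertPooler branch is used.
config = BertConfig(hidden_size=768, pad_token_id=0)
pooling = OmniPooling(config)

batch, seq_len = 2, 16
inputs = {
    "input_ids": torch.randint(1, 100, (batch, seq_len)),
    "attention_mask": torch.ones(batch, seq_len, dtype=torch.long),
}
last_hidden_state = torch.randn(batch, seq_len, config.hidden_size)

pooled = pooling(inputs, last_hidden_state)  # dict input -> [batch, hidden_size]
print(pooled.shape)  # torch.Size([2, 768])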
omnigenome/src/model/regression/__init__.py
@@ -1,12 +0,0 @@
- # -*- coding: utf-8 -*-
- # file: __init__.py
- # time: 21:11 08/04/2024
- # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
- # github: https://github.com/yangheng95
- # huggingface: https://huggingface.co/yangheng
- # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
- # Copyright (C) 2019-2024. All Rights Reserved.
- """
- This package contains modules for regression models.
- """
-