omnigenome 0.3.0a1__py3-none-any.whl → 1.0.0b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- omnigenome/__init__.py +26 -258
- {omnigenome-0.3.0a1.dist-info → omnigenome-1.0.0b0.dist-info}/METADATA +9 -10
- omnigenome-1.0.0b0.dist-info/RECORD +6 -0
- omnigenome/auto/__init__.py +0 -3
- omnigenome/auto/auto_bench/__init__.py +0 -12
- omnigenome/auto/auto_bench/auto_bench.py +0 -484
- omnigenome/auto/auto_bench/auto_bench_cli.py +0 -230
- omnigenome/auto/auto_bench/auto_bench_config.py +0 -216
- omnigenome/auto/auto_bench/config_check.py +0 -34
- omnigenome/auto/auto_train/__init__.py +0 -13
- omnigenome/auto/auto_train/auto_train.py +0 -430
- omnigenome/auto/auto_train/auto_train_cli.py +0 -222
- omnigenome/auto/bench_hub/__init__.py +0 -12
- omnigenome/auto/bench_hub/bench_hub.py +0 -25
- omnigenome/cli/__init__.py +0 -13
- omnigenome/cli/commands/__init__.py +0 -13
- omnigenome/cli/commands/base.py +0 -83
- omnigenome/cli/commands/bench/__init__.py +0 -13
- omnigenome/cli/commands/bench/bench_cli.py +0 -202
- omnigenome/cli/commands/rna/__init__.py +0 -13
- omnigenome/cli/commands/rna/rna_design.py +0 -178
- omnigenome/cli/omnigenome_cli.py +0 -128
- omnigenome/src/__init__.py +0 -12
- omnigenome/src/abc/__init__.py +0 -12
- omnigenome/src/abc/abstract_dataset.py +0 -622
- omnigenome/src/abc/abstract_metric.py +0 -114
- omnigenome/src/abc/abstract_model.py +0 -689
- omnigenome/src/abc/abstract_tokenizer.py +0 -267
- omnigenome/src/dataset/__init__.py +0 -16
- omnigenome/src/dataset/omni_dataset.py +0 -435
- omnigenome/src/lora/__init__.py +0 -13
- omnigenome/src/lora/lora_model.py +0 -294
- omnigenome/src/metric/__init__.py +0 -15
- omnigenome/src/metric/classification_metric.py +0 -184
- omnigenome/src/metric/metric.py +0 -199
- omnigenome/src/metric/ranking_metric.py +0 -142
- omnigenome/src/metric/regression_metric.py +0 -191
- omnigenome/src/misc/__init__.py +0 -3
- omnigenome/src/misc/utils.py +0 -499
- omnigenome/src/model/__init__.py +0 -19
- omnigenome/src/model/augmentation/__init__.py +0 -12
- omnigenome/src/model/augmentation/model.py +0 -219
- omnigenome/src/model/classification/__init__.py +0 -12
- omnigenome/src/model/classification/model.py +0 -642
- omnigenome/src/model/embedding/__init__.py +0 -12
- omnigenome/src/model/embedding/model.py +0 -263
- omnigenome/src/model/mlm/__init__.py +0 -12
- omnigenome/src/model/mlm/model.py +0 -177
- omnigenome/src/model/module_utils.py +0 -232
- omnigenome/src/model/regression/__init__.py +0 -12
- omnigenome/src/model/regression/model.py +0 -786
- omnigenome/src/model/regression/resnet.py +0 -483
- omnigenome/src/model/rna_design/__init__.py +0 -12
- omnigenome/src/model/rna_design/model.py +0 -469
- omnigenome/src/model/seq2seq/__init__.py +0 -12
- omnigenome/src/model/seq2seq/model.py +0 -44
- omnigenome/src/tokenizer/__init__.py +0 -16
- omnigenome/src/tokenizer/bpe_tokenizer.py +0 -226
- omnigenome/src/tokenizer/kmers_tokenizer.py +0 -247
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +0 -249
- omnigenome/src/trainer/__init__.py +0 -14
- omnigenome/src/trainer/accelerate_trainer.py +0 -739
- omnigenome/src/trainer/hf_trainer.py +0 -75
- omnigenome/src/trainer/trainer.py +0 -579
- omnigenome/utility/__init__.py +0 -3
- omnigenome/utility/dataset_hub/__init__.py +0 -13
- omnigenome/utility/dataset_hub/dataset_hub.py +0 -178
- omnigenome/utility/ensemble.py +0 -324
- omnigenome/utility/hub_utils.py +0 -517
- omnigenome/utility/model_hub/__init__.py +0 -12
- omnigenome/utility/model_hub/model_hub.py +0 -231
- omnigenome/utility/pipeline_hub/__init__.py +0 -12
- omnigenome/utility/pipeline_hub/pipeline.py +0 -483
- omnigenome/utility/pipeline_hub/pipeline_hub.py +0 -129
- omnigenome-0.3.0a1.dist-info/RECORD +0 -78
- omnigenome-0.3.0a1.dist-info/entry_points.txt +0 -3
- {omnigenome-0.3.0a1.dist-info → omnigenome-1.0.0b0.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-1.0.0b0.dist-info}/licenses/LICENSE +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-1.0.0b0.dist-info}/top_level.txt +0 -0
|
@@ -1,689 +0,0 @@
|
|
|
1
|
-
# -*- coding: utf-8 -*-
|
|
2
|
-
# file: omnigenome_model.py
|
|
3
|
-
# time: 18:36 06/04/2024
|
|
4
|
-
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
|
|
5
|
-
# github: https://github.com/yangheng95
|
|
6
|
-
# huggingface: https://huggingface.co/yangheng
|
|
7
|
-
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
|
|
8
|
-
# Copyright (C) 2019-2024. All Rights Reserved.
|
|
9
|
-
import json
|
|
10
|
-
import os
|
|
11
|
-
import shutil
|
|
12
|
-
import warnings
|
|
13
|
-
import inspect
|
|
14
|
-
from importlib import import_module
|
|
15
|
-
|
|
16
|
-
import dill
|
|
17
|
-
import findfile
|
|
18
|
-
import torch
|
|
19
|
-
from transformers import AutoModel, AutoConfig, AutoTokenizer, BatchEncoding
|
|
20
|
-
|
|
21
|
-
from ..misc.utils import fprint, env_meta_info
|
|
22
|
-
|
|
23
|
-
# NOTE(review): this runs at import time and changes the *process-wide*
# warnings filter: every matching warning in the program (not only this
# module's) is then shown only once, regardless of location. Consider scoping it with
# `warnings.catch_warnings()` around the noisy calls instead.
warnings.filterwarnings("once")
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
def count_parameters(model):
    """
    Return the number of trainable scalar parameters in ``model``.

    Only parameters whose ``requires_grad`` flag is set are counted; frozen
    parameters are skipped entirely.

    Args:
        model (torch.nn.Module): The module whose parameters are inspected.

    Returns:
        int: Sum of ``numel()`` over every trainable parameter (0 if none).

    Example:
        >>> model = OmniModelForSequenceClassification(config, tokenizer)
        >>> num_params = count_parameters(model)
        >>> print(f"Model has {num_params} trainable parameters")
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen weights do not count
        total += param.numel()
    return total
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
class OmniModel(torch.nn.Module):
    """
    Abstract base class for all models in OmniGenome.

    This class provides a unified interface for all genomic models in the OmniGenome
    framework. It handles model initialization, forward passes, loss computation,
    prediction, inference, and model persistence. Concrete subclasses are expected
    to implement ``forward()`` and ``loss_function()``.

    The class is designed to work with various types of genomic data and tasks,
    including sequence classification, token classification, regression, and more.

    Attributes:
        model (torch.nn.Module): The underlying (base) PyTorch model.
        config: The model configuration (the same object as ``model.config``).
        tokenizer: The tokenizer associated with the model.
        metadata (dict): Metadata from ``env_meta_info()`` plus the
            ``model_cls`` and ``tokenizer_cls`` names.
        loss_fn: Optional custom loss callable set through ``set_loss_fn``.
        dropout (torch.nn.Dropout): Dropout layer for regularization.
        activation (torch.nn.Tanh): Activation function.
        pad_token_id (int): ID of the padding token.
    """

    def __init__(self, config_or_model, tokenizer, *args, **kwargs):
        """
        Initializes the model.

        This method handles different types of model initialization:
        - From a pre-trained model path (string)
        - From a PyTorch model instance
        - From a configuration object (a ``PretrainedConfig``, e.g. the result
          of ``AutoConfig.from_pretrained``)

        Args:
            config_or_model: A model configuration, a pre-trained model path (str),
                or a `torch.nn.Module` instance.
            tokenizer: The tokenizer associated with the model.
            *args: Additional positional arguments, forwarded to
                ``torch.nn.Module.__init__``.
            **kwargs: Additional keyword arguments.
                - label2id (dict): Mapping from class labels to IDs.
                - num_labels (int): The number of labels.
                - trust_remote_code (bool): Whether to trust remote code when loading
                  from Hugging Face Hub. Defaults to True.
                - ignore_mismatched_sizes (bool): Whether to ignore size mismatches
                  when loading pre-trained weights. Defaults to False.
                - dropout (float): Dropout rate. Defaults to 0.0.
                Any other keyword is forwarded to ``torch.nn.Module.__init__``.

        Raises:
            ValueError: If config_or_model is not a valid type, if the config
                names no loadable model class, or if ``num_labels`` and
                ``label2id`` disagree.
            RuntimeError: If the hidden size cannot be determined from the config.

        Example:
            >>> # Initialize from a pre-trained model
            >>> model = OmniModelForSequenceClassification("model_path", tokenizer)

            >>> # Initialize from a configuration
            >>> config = AutoConfig.from_pretrained("model_path")
            >>> model = OmniModelForSequenceClassification(config, tokenizer)
        """
        label2id = kwargs.pop("label2id", None)
        trust_remote_code = kwargs.pop("trust_remote_code", True)
        num_labels = kwargs.pop("num_labels", None)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        # Consume ``dropout`` here: torch.nn.Module.__init__ rejects unknown
        # keyword arguments, so forwarding it would raise TypeError.
        dropout = kwargs.pop("dropout", 0.0)

        if label2id is not None and num_labels is None:
            num_labels = len(label2id)
        elif num_labels is not None and label2id is None:
            label2id = {str(i): i for i in range(num_labels)}

        # do not change the order of the following lines
        super().__init__(*args, **kwargs)
        self.loss_fn = None

        if isinstance(config_or_model, str):
            self.model = self._load_base_model(
                config_or_model,
                num_labels=num_labels,
                label2id=label2id,
                trust_remote_code=trust_remote_code,
                ignore_mismatched_sizes=ignore_mismatched_sizes,
            )
        elif isinstance(config_or_model, torch.nn.Module):
            self.model = config_or_model
            if label2id is not None:
                self.model.config.num_labels = len(label2id)
                self.model.config.label2id = label2id
        else:
            # Config objects are ``PretrainedConfig`` instances; ``AutoConfig`` is
            # only a factory, so ``isinstance(x, AutoConfig)`` is never true.
            from transformers import PretrainedConfig

            if not isinstance(config_or_model, PretrainedConfig):
                raise ValueError(
                    "The config_or_model should be either a string, a torch.nn.Module or a AutoConfig object."
                )
            config = config_or_model
            if label2id is not None:
                config.num_labels = len(label2id)
                config.label2id = label2id
            self.model = AutoModel.from_config(
                config, trust_remote_code=trust_remote_code
            )
            self.model.config = config

        # Update the config
        self.config = self.model.config
        if isinstance(label2id, dict):
            self.config.label2id = label2id
            self.config.id2label = {v: k for k, v in label2id.items()}
            if (
                not hasattr(self.config, "num_labels")
                or len(self.config.id2label) != self.config.num_labels
            ):
                fprint(
                    "Warning: The number of labels in the config is not equal to the number of labels in the label2id dictionary. "
                )
                fprint(
                    "Please check the label2id dictionary and the num_labels parameter in the config."
                )
                self.config.num_labels = len(self.config.id2label)

            # Explicit check (not ``assert``) so it survives ``python -O``.
            if len(self.config.label2id) != num_labels:
                raise ValueError(
                    f"Expected {num_labels} labels, but got {len(self.config.label2id)} in label2id dictionary."
                )

        # The metadata of the model
        self.metadata = env_meta_info()
        self.metadata["model_cls"] = self.__class__.__name__

        # Normalise the hidden size: different architectures name it differently.
        for size_attr in ("n_embd", "d_model", "hidden_size"):
            size = getattr(self.config, size_attr, None)
            if size:
                self.config.hidden_size = size
                break
        else:
            raise RuntimeError(
                "The hidden size of the model is not found in the config."
            )

        # The tokenizer of the model
        self.tokenizer = tokenizer
        self.metadata["tokenizer_cls"] = self.tokenizer.__class__.__name__
        if hasattr(self.tokenizer, "base_tokenizer"):
            self.pad_token_id = self.tokenizer.base_tokenizer.pad_token_id
        else:
            self.pad_token_id = self.tokenizer.pad_token_id

        self.dropout = torch.nn.Dropout(dropout)
        self.activation = torch.nn.Tanh()

    @staticmethod
    def _pick_model_class_name(config):
        """
        Choose the ``transformers`` class name used to load the base model.

        Candidates come from the ``AutoModel*`` keys of ``config.auto_map`` (in
        their original, deterministic order) and fall back to
        ``config.architectures``. ``"AutoModel"`` is preferred when present,
        otherwise the last candidate is used.

        Args:
            config: A configuration object (any object with optional
                ``auto_map`` / ``architectures`` attributes).

        Returns:
            str | None: The class name, or None if no candidate exists.
        """
        auto_map = getattr(config, "auto_map", None) or {}
        candidates = [name for name in auto_map if name.startswith("AutoModel")]
        if not candidates:
            candidates = list(getattr(config, "architectures", None) or [])
        if not candidates:
            return None
        return "AutoModel" if "AutoModel" in candidates else candidates[-1]

    def _load_base_model(
        self, path, num_labels, label2id, trust_remote_code, ignore_mismatched_sizes
    ):
        """
        Load the config and base model for a checkpoint name or path.

        Args:
            path (str): Hub id or local directory of the checkpoint.
            num_labels (int | None): Overrides ``config.num_labels`` when given.
            label2id (dict | None): Overrides ``config.label2id`` when given.
            trust_remote_code (bool): Forwarded to the Hugging Face loaders.
            ignore_mismatched_sizes (bool): Forwarded to ``from_pretrained``.

        Returns:
            torch.nn.Module: The ``base_model`` of the loaded checkpoint, with
            its ``config`` set to the loaded configuration.

        Raises:
            ValueError: If the config names no loadable model class.
        """
        # Only override label settings that were actually supplied; passing
        # None would overwrite the checkpoint's own values.
        label_overrides = {}
        if num_labels is not None:
            label_overrides["num_labels"] = num_labels
        if label2id is not None:
            label_overrides["label2id"] = label2id
        config = AutoConfig.from_pretrained(
            path, trust_remote_code=trust_remote_code, **label_overrides
        )

        model_cls_name = self._pick_model_class_name(config)
        if model_cls_name is None:
            raise ValueError(
                f"The model cannot be instantiated from {path}. "
                f"Please check the model configuration contains the architectures or auto_map."
            )
        model_cls = getattr(import_module("transformers"), model_cls_name)
        model = model_cls.from_pretrained(
            path,
            config=config,
            trust_remote_code=trust_remote_code,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
        ).base_model
        model.config = config
        return model

    def last_hidden_state_forward(self, **inputs):
        """
        Performs a forward pass to get the last hidden state from the base model.

        Only the keyword arguments that the base model's ``forward()`` declares
        are passed through; any other keys trigger a warning and are dropped.
        ``output_hidden_states=True`` is always requested.

        Args:
            **inputs: The inputs to the model, typically ``input_ids``,
                ``attention_mask`` and other model-specific parameters.

        Returns:
            torch.Tensor: The last hidden state tensor.

        Raises:
            ValueError: If no hidden state can be found in the model outputs.

        Example:
            >>> inputs = {
            ...     'input_ids': torch.tensor([[1, 2, 3, 4]]),
            ...     'attention_mask': torch.tensor([[1, 1, 1, 1]])
            ... }
            >>> hidden_states = model.last_hidden_state_forward(**inputs)
        """
        model = self.model
        inputs["output_hidden_states"] = True

        if "strippedhyena" in model.__class__.__name__.lower():
            inputs["x"] = inputs["input_ids"]  # For compatibility with Evo models

        # ``**inputs`` is always a plain dict here, so map it onto the
        # parameter names of the base model's forward() method.
        forward_params = inspect.signature(model.forward).parameters
        model_inputs = {
            name: inputs[name] for name in forward_params if name in inputs
        }
        # Warn about keys that the model's forward() signature does not declare.
        ignored_keys = set(inputs.keys()) - set(model_inputs.keys())
        if ignored_keys:
            warnings.warn(f"Warning: Ignored keys in inputs: {ignored_keys}")

        # Run the base model.
        outputs = model(**model_inputs)

        if hasattr(outputs, "last_hidden_state"):
            return outputs.last_hidden_state
        if isinstance(outputs, dict) and "last_hidden_state" in outputs:
            return outputs["last_hidden_state"]

        warnings.warn(
            f"last_hidden_state not found in the outputs from the {model.__class__.__name__} model."
        )
        # ``hidden_states`` may exist as an attribute yet be None.
        hidden_states = getattr(outputs, "hidden_states", None)
        if hidden_states is not None:
            return hidden_states[-1]
        if isinstance(outputs, (list, tuple, torch.Tensor)):
            # For Evo models that return a tuple of (last_hidden_state, logits)
            return outputs[0] if len(outputs) <= 2 else outputs[-1]
        raise ValueError(
            f"Cannot find the last hidden state in the outputs from the {model.__class__.__name__} model, "
            f"please check the model architecture."
        )

    def loss_function(self, logits, labels):
        """
        Calculates the loss. Must be implemented by subclasses.

        Args:
            logits (torch.Tensor): The model's output logits.
            labels (torch.Tensor): The ground truth labels.

        Returns:
            torch.Tensor: The calculated loss.

        Raises:
            NotImplementedError: If the method is not implemented by the subclass.

        Example:
            >>> # In a classification model
            >>> loss = model.loss_function(logits, labels)
        """
        raise NotImplementedError(
            "The loss_function() function should be implemented for your model."
        )

    def set_loss_fn(self, loss_function):
        """
        Sets a custom loss function for the model.

        Args:
            loss_function (callable): A callable loss function that takes
                logits and labels as arguments.

        Example:
            >>> import torch.nn as nn
            >>> model.set_loss_fn(nn.CrossEntropyLoss())
        """
        self.loss_fn = loss_function

    def predict(self, sequence_or_inputs, **kwargs):
        """
        Performs prediction on raw inputs. Returns raw model outputs.

        Subclasses may override this; the base implementation tokenizes (if
        needed) and runs a no-grad forward pass.

        Args:
            sequence_or_inputs: A sequence (str), list of sequences, or
                tokenized inputs (dict/BatchEncoding).
            **kwargs: Additional arguments for tokenization.

        Returns:
            dict: The raw outputs of ``self(...)`` plus the model ``inputs``.

        Example:
            >>> outputs = model.predict("ATCGATCG")
            >>> outputs = model.predict(["ATCGATCG", "GCTAGCTA"])
        """
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
        return raw_outputs

    def inference(self, sequence_or_inputs, **kwargs):
        """
        Performs inference on raw inputs. Returns processed, human-readable predictions.

        Subclasses override this to post-process logits (e.g. into labels);
        the base implementation returns the same raw outputs as ``predict``.

        Args:
            sequence_or_inputs: A sequence (str), list of sequences, or
                tokenized inputs (dict/BatchEncoding).
            **kwargs: Additional arguments for tokenization.

        Returns:
            dict: Model outputs (post-processed in subclasses).

        Example:
            >>> results = model.inference("ATCGATCG")
            >>> results = model.inference(["ATCGATCG", "GCTAGCTA"])
        """
        # Subclasses should override inference() with task-specific post-processing.
        raw_outputs = self._forward_from_raw_input(sequence_or_inputs, **kwargs)
        return raw_outputs

    def __call__(self, **inputs):
        """
        The main forward pass of the model, suitable for training loops.

        Accepted call shapes:
        - ``model(input_ids=..., labels=...)`` (native trainer);
        - ``model(inputs=<mapping>, labels=...)`` (HF trainer integration);
        - ``model(inputs=(<mapping>, labels))``.

        Labels are resolved in this order: the ``labels`` keyword, then the
        tuple's second element, then ``labels`` inside the mapping, then
        ``label`` inside the mapping.

        NOTE: this overrides ``torch.nn.Module.__call__``, so forward hooks
        registered on this wrapper module are not run.

        Args:
            **inputs: Tokenized inputs, potentially including labels.

        Returns:
            dict: The outputs of ``forward()`` with a ``loss`` entry (None when
            no labels were provided).

        Example:
            >>> outputs = model(
            ...     input_ids=torch.tensor([[1, 2, 3, 4]]),
            ...     attention_mask=torch.tensor([[1, 1, 1, 1]]),
            ...     labels=torch.tensor([0])
            ... )
            >>> loss = outputs['loss']
        """
        labels = inputs.pop("labels", None)
        inputs = inputs.pop("inputs", inputs)

        if isinstance(inputs, tuple):
            if labels is None and len(inputs) > 1:
                labels = inputs[1]
            inputs = inputs[0]

        # Do not overwrite labels already present in a nested ``inputs`` mapping.
        if labels is None:
            labels = inputs.get("labels", None)
        inputs["labels"] = labels
        if labels is None:
            labels = inputs.get("label", None)

        outputs = self.forward(**inputs)

        if labels is not None:
            outputs["loss"] = self._calculate_loss(outputs, labels)
        else:
            outputs["loss"] = None
        return outputs

    def _calculate_loss(self, outputs, labels):
        """
        Return the loss for ``outputs``, computing it only if not already present.

        Args:
            outputs (dict): The model outputs; may contain ``loss`` and/or ``logits``.
            labels: The ground truth labels.

        Returns:
            The existing ``outputs["loss"]`` if not None, otherwise
            ``self.loss_function(logits, labels)``.

        Raises:
            RuntimeError: If there is no loss and either logits or labels are missing.
        """
        loss = outputs.get("loss", None)
        if loss is not None:
            return loss

        logits = outputs.get("logits", None)
        if logits is not None and labels is not None:
            return self.loss_function(logits, labels)
        raise RuntimeError(
            "The output of the forward() function should be a dictionary-like objective"
            " and have either 'loss', or 'logits' and 'labels' attribute."
        )

    def save(self, path, overwrite=False, dtype=torch.float16, **kwargs):
        """
        Saves the model, tokenizer, and metadata to a directory.

        The base model is temporarily cast to ``dtype`` on CPU while writing;
        its original dtype/device and this module's training mode are restored
        afterwards, even if writing fails.

        Args:
            path (str): The directory to save the model to.
            overwrite (bool): Whether to write into an existing directory.
            dtype (torch.dtype): The data type to save the base model weights in.
            **kwargs: Accepted for API compatibility; unused.

        Raises:
            FileExistsError: If ``path`` exists and ``overwrite`` is False.
        """
        if os.path.exists(path) and not overwrite:
            raise FileExistsError(
                f"The path {path} already exists, please set overwrite=True to overwrite it."
            )
        os.makedirs(path, exist_ok=True)

        was_training = self.training
        self.eval()

        # Copy auxiliary files (configs, vocab, custom model scripts) from the
        # source checkpoint; only an existing local directory is searched.
        source_dir = self.config.name_or_path
        if source_dir and os.path.isdir(source_dir):
            for file in findfile.find_files(
                source_dir,
                or_key=["bin", "json", "txt", "py"],
                exclude_key=["pytorch_model.bin", "model.safetensors"],
            ):
                target = os.path.join(path, os.path.basename(file))
                # Saving back into the source directory would raise SameFileError.
                if os.path.abspath(file) != os.path.abspath(target):
                    shutil.copyfile(file, target)

        _device = self.model.device
        _dtype = self.model.dtype
        self.model.to(dtype).to("cpu")
        try:
            self.tokenizer.save_pretrained(path)

            # Save metadata including information about the loss function
            metadata = self.metadata.copy()
            if self.loss_fn is not None:
                metadata["loss_fn_class"] = self.loss_fn.__class__.__name__
                metadata["loss_fn_module"] = self.loss_fn.__class__.__module__

            with open(os.path.join(path, "metadata.json"), "w", encoding="utf8") as f:
                json.dump(metadata, f)
            with open(os.path.join(path, "tokenizer.bin"), "wb") as f:
                dill.dump(self.tokenizer, f)
            # do not remove this line, used to save customized model scripts
            self.model.save_pretrained(f"{path}", safe_serialization=False)

            # Save complete state dict including all components (overwrites the
            # pytorch_model.bin written by save_pretrained above).
            with open(os.path.join(path, "pytorch_model.bin"), "wb") as f:
                torch.save(self.state_dict(), f)
        finally:
            self.model.to(_dtype).to(_device)
            self.train(was_training)
        fprint(f"The model is saved to {path}.")

    def load(self, path, **kwargs):
        """
        Loads the model weights, tokenizer, and loss function from a directory.

        SECURITY NOTE: ``tokenizer.bin`` is read with ``dill.load`` and the
        weights with ``torch.load``; both can execute arbitrary code from
        untrusted files. Only load directories you trust.

        Args:
            path (str): The directory to load the model from.
            **kwargs: ``device`` selects the ``map_location`` for the weights
                (default ``"cpu"``); the rest is forwarded to
                ``AutoConfig.from_pretrained``.

        Returns:
            OmniModel: ``self``, with weights loaded.

        Raises:
            ValueError: If the saved model class differs from this class.
        """
        # ``device`` is only for torch.load; keep it out of the config kwargs.
        device = kwargs.pop("device", "cpu")

        with open(os.path.join(path, "metadata.json"), "r", encoding="utf8") as f:
            metadata = json.load(f)

        if metadata["model_cls"] != self.__class__.__name__:  # Check the model class
            raise ValueError(
                f"The model class in the loaded model is {metadata['model_cls']}, "
                f"but the current model class is {self.__class__.__name__}."
            )
        config = AutoConfig.from_pretrained(path, trust_remote_code=True, **kwargs)

        current_config = self.config.__dict__
        for key, value in config.__dict__.items():
            if key not in current_config or current_config[key] != value:
                fprint(
                    f"Warning: The value of the key {key} in the loaded model is {value}, "
                    f"but the current value is {current_config.get(key, None)}."
                )

        # Attempt to restore any saved loss function
        if "loss_fn_class" in metadata and "loss_fn_module" in metadata:
            try:
                loss_module = import_module(metadata["loss_fn_module"])
                loss_class = getattr(loss_module, metadata["loss_fn_class"])
                # Initialize loss function if possible (parameters will be loaded with state dict)
                self.loss_fn = loss_class()
                fprint(
                    f"Restored loss function: {metadata['loss_fn_class']} from {metadata['loss_fn_module']}"
                )
            except (ImportError, AttributeError, TypeError) as e:
                # TypeError: the loss class needs constructor arguments.
                warnings.warn(f"Could not restore loss function: {e}")

        with open(os.path.join(path, "pytorch_model.bin"), "rb") as f:
            loaded_state_dict = torch.load(f, map_location=device)

        # Check if keys match between current and loaded state dict
        current_keys = set(self.state_dict().keys())
        loaded_keys = set(loaded_state_dict.keys())
        missing_keys = current_keys - loaded_keys
        unexpected_keys = loaded_keys - current_keys

        if missing_keys:
            warnings.warn(f"Missing keys in loaded weights: {missing_keys}")
        if unexpected_keys:
            warnings.warn(f"Unexpected keys in loaded weights: {unexpected_keys}")

        self.load_state_dict(loaded_state_dict, strict=False)

        # Load the tokenizer
        tokenizer_file = os.path.join(path, "tokenizer.bin")
        if os.path.exists(tokenizer_file):
            with open(tokenizer_file, "rb") as f:
                self.tokenizer = dill.load(f)

        return self

    def _forward_from_raw_input(self, sequence_or_inputs, **kwargs):
        """
        Tokenizes raw input (if needed) and performs a forward pass in no_grad mode.

        Args:
            sequence_or_inputs: A sequence, list of sequences, or tokenized
                inputs (dict/BatchEncoding).
            **kwargs: Additional arguments for tokenization (``padding``,
                ``max_length``, ``truncation``, ``return_tensors`` have defaults).

        Returns:
            dict: The raw model outputs plus the device-placed ``inputs``.
        """
        if isinstance(sequence_or_inputs, (BatchEncoding, dict)):
            inputs = sequence_or_inputs
        else:
            inputs = self.tokenizer(
                sequence_or_inputs,
                padding=kwargs.pop("padding", True),
                max_length=kwargs.pop("max_length", 1024),
                truncation=kwargs.pop("truncation", True),
                return_tensors=kwargs.pop("return_tensors", "pt"),
                **kwargs,
            )

        device = self.model.device
        if hasattr(inputs, "to"):
            inputs = inputs.to(device)
        else:
            # Plain dicts have no .to(); move tensor values individually.
            inputs = {
                key: value.to(device) if isinstance(value, torch.Tensor) else value
                for key, value in inputs.items()
            }

        with torch.no_grad():
            raw_outputs = self(**inputs)
        raw_outputs["inputs"] = inputs
        return raw_outputs

    @staticmethod
    def from_pretrained(model_name_or_path, tokenizer, *args, **kwargs):
        """
        Loads a pre-trained base model (and tokenizer, if not given).

        Args:
            model_name_or_path (str): The name or path of the pre-trained model.
            tokenizer: The tokenizer to use; loaded from ``model_name_or_path``
                when None.
            *args: Additional positional arguments for ``OmniModel``.
            **kwargs: ``config`` (optional pre-built config), ``label2id``,
                ``num_labels`` and ``dropout`` go to ``OmniModel``; everything
                else is forwarded to the Hugging Face ``from_pretrained`` loaders.

        Returns:
            OmniModel: A model wrapping the loaded base model.
        """
        # Arguments consumed by OmniModel.__init__ rather than the HF loaders.
        model_kwargs = {
            key: kwargs.pop(key)
            for key in ("label2id", "num_labels", "dropout")
            if key in kwargs
        }
        config = kwargs.pop("config", None)
        if config is None:
            config = AutoConfig.from_pretrained(model_name_or_path, **kwargs)
        base_model = AutoModel.from_pretrained(
            model_name_or_path, config=config, **kwargs
        )
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **kwargs)
        return OmniModel(base_model, tokenizer, *args, **model_kwargs)

    def model_info(self):
        """
        Prints and returns detailed information about the model.

        Returns:
            str: A multi-line string with class name, metadata, base model
            name, type, architectures, trainable parameter count (millions)
            and config.
        """
        info = f"Model Name: {self.__class__.__name__}\n"
        info += f"Model Metadata: {self.metadata}\n"
        info += f"Base Model Name: {self.config.name_or_path}\n"
        info += f"Model Type: {self.config.model_type}\n"
        info += f"Model Architecture: {self.config.architectures}\n"
        info += f"Model Parameters: {count_parameters(self.model) / 1e6} M\n"
        info += f"Model Config: {self.config}\n"
        fprint(info)
        return info
|
|
689
|
-
|