omnigenome 0.3.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of omnigenome might be problematic.
- omnigenome/__init__.py +281 -0
- omnigenome/auto/__init__.py +3 -0
- omnigenome/auto/auto_bench/__init__.py +12 -0
- omnigenome/auto/auto_bench/auto_bench.py +484 -0
- omnigenome/auto/auto_bench/auto_bench_cli.py +230 -0
- omnigenome/auto/auto_bench/auto_bench_config.py +216 -0
- omnigenome/auto/auto_bench/config_check.py +34 -0
- omnigenome/auto/auto_train/__init__.py +13 -0
- omnigenome/auto/auto_train/auto_train.py +430 -0
- omnigenome/auto/auto_train/auto_train_cli.py +222 -0
- omnigenome/auto/bench_hub/__init__.py +12 -0
- omnigenome/auto/bench_hub/bench_hub.py +25 -0
- omnigenome/cli/__init__.py +13 -0
- omnigenome/cli/commands/__init__.py +13 -0
- omnigenome/cli/commands/base.py +83 -0
- omnigenome/cli/commands/bench/__init__.py +13 -0
- omnigenome/cli/commands/bench/bench_cli.py +202 -0
- omnigenome/cli/commands/rna/__init__.py +13 -0
- omnigenome/cli/commands/rna/rna_design.py +178 -0
- omnigenome/cli/omnigenome_cli.py +128 -0
- omnigenome/src/__init__.py +12 -0
- omnigenome/src/abc/__init__.py +12 -0
- omnigenome/src/abc/abstract_dataset.py +622 -0
- omnigenome/src/abc/abstract_metric.py +114 -0
- omnigenome/src/abc/abstract_model.py +689 -0
- omnigenome/src/abc/abstract_tokenizer.py +267 -0
- omnigenome/src/dataset/__init__.py +16 -0
- omnigenome/src/dataset/omni_dataset.py +435 -0
- omnigenome/src/lora/__init__.py +13 -0
- omnigenome/src/lora/lora_model.py +294 -0
- omnigenome/src/metric/__init__.py +15 -0
- omnigenome/src/metric/classification_metric.py +184 -0
- omnigenome/src/metric/metric.py +199 -0
- omnigenome/src/metric/ranking_metric.py +142 -0
- omnigenome/src/metric/regression_metric.py +191 -0
- omnigenome/src/misc/__init__.py +3 -0
- omnigenome/src/misc/utils.py +439 -0
- omnigenome/src/model/__init__.py +19 -0
- omnigenome/src/model/augmentation/__init__.py +12 -0
- omnigenome/src/model/augmentation/model.py +219 -0
- omnigenome/src/model/classification/__init__.py +12 -0
- omnigenome/src/model/classification/model.py +642 -0
- omnigenome/src/model/embedding/__init__.py +12 -0
- omnigenome/src/model/embedding/model.py +263 -0
- omnigenome/src/model/mlm/__init__.py +12 -0
- omnigenome/src/model/mlm/model.py +177 -0
- omnigenome/src/model/module_utils.py +232 -0
- omnigenome/src/model/regression/__init__.py +12 -0
- omnigenome/src/model/regression/model.py +786 -0
- omnigenome/src/model/regression/resnet.py +483 -0
- omnigenome/src/model/rna_design/__init__.py +12 -0
- omnigenome/src/model/rna_design/model.py +426 -0
- omnigenome/src/model/seq2seq/__init__.py +12 -0
- omnigenome/src/model/seq2seq/model.py +44 -0
- omnigenome/src/tokenizer/__init__.py +16 -0
- omnigenome/src/tokenizer/bpe_tokenizer.py +226 -0
- omnigenome/src/tokenizer/kmers_tokenizer.py +247 -0
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +249 -0
- omnigenome/src/trainer/__init__.py +14 -0
- omnigenome/src/trainer/accelerate_trainer.py +739 -0
- omnigenome/src/trainer/hf_trainer.py +75 -0
- omnigenome/src/trainer/trainer.py +579 -0
- omnigenome/utility/__init__.py +3 -0
- omnigenome/utility/dataset_hub/__init__.py +13 -0
- omnigenome/utility/dataset_hub/dataset_hub.py +178 -0
- omnigenome/utility/ensemble.py +324 -0
- omnigenome/utility/hub_utils.py +517 -0
- omnigenome/utility/model_hub/__init__.py +12 -0
- omnigenome/utility/model_hub/model_hub.py +231 -0
- omnigenome/utility/pipeline_hub/__init__.py +12 -0
- omnigenome/utility/pipeline_hub/pipeline.py +483 -0
- omnigenome/utility/pipeline_hub/pipeline_hub.py +129 -0
- omnigenome-0.3.0a0.dist-info/METADATA +224 -0
- omnigenome-0.3.0a0.dist-info/RECORD +85 -0
- omnigenome-0.3.0a0.dist-info/WHEEL +5 -0
- omnigenome-0.3.0a0.dist-info/entry_points.txt +3 -0
- omnigenome-0.3.0a0.dist-info/licenses/LICENSE +201 -0
- omnigenome-0.3.0a0.dist-info/top_level.txt +2 -0
- tests/__init__.py +9 -0
- tests/conftest.py +160 -0
- tests/test_dataset_patterns.py +291 -0
- tests/test_examples_syntax.py +83 -0
- tests/test_model_loading.py +183 -0
- tests/test_rna_functions.py +255 -0
- tests/test_training_patterns.py +302 -0
omnigenome/src/lora/lora_model.py
@@ -0,0 +1,294 @@
+# -*- coding: utf-8 -*-
+# file: lora_model.py
+# time: 12:36 11/06/2025
+# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+# homepage: https://yangheng95.github.io
+# github: https://github.com/yangheng95
+# huggingface: https://huggingface.co/yangheng
+# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+# Copyright (C) 2019-2025. All Rights Reserved.
+"""
+Low-Rank Adaptation (LoRA) models for OmniGenome.
+
+This module provides a LoRA implementation for efficient fine-tuning of large
+genomic language models. LoRA reduces the number of trainable parameters
+by adding low-rank adaptation layers to existing model weights.
+"""
+import torch
+from torch import nn
+from omnigenome.src.misc.utils import fprint
+
+def find_linear_target_modules(model, keyword_filter=None, use_full_path=True):
+    """
+    Find linear modules in a model that can be targeted for LoRA adaptation.
+
+    This function searches through a model's modules to identify linear layers
+    that can be adapted using LoRA. It supports filtering by keyword patterns
+    to target specific types of layers.
+
+    Args:
+        model: The model to search for linear modules
+        keyword_filter (str, list, tuple, optional): Keywords to filter modules by name
+        use_full_path (bool): Whether to return full module paths or just names (default: True)
+
+    Returns:
+        list: Sorted list of linear module names that can be targeted for LoRA
+
+    Raises:
+        TypeError: If keyword_filter is not None, str, or a list/tuple of str
+    """
+    import re
+    from torch import nn
+
+    if keyword_filter is not None:
+        if isinstance(keyword_filter, str):
+            keyword_filter = [keyword_filter]
+        elif not isinstance(keyword_filter, (list, tuple)):
+            raise TypeError("keyword_filter must be None, str, or a list/tuple of str")
+
+        pattern = '|'.join(map(re.escape, keyword_filter))
+
+    linear_modules = set()
+    for name, module in model.named_modules():
+        if isinstance(module, nn.Linear):
+            if keyword_filter is None or re.search(pattern, name, re.IGNORECASE):
+                linear_modules.add(name if use_full_path else name.split('.')[-1])
+
+    return sorted(linear_modules)
+
+def auto_lora_model(model, **kwargs):
+    """
+    Automatically create a LoRA-adapted model.
+
+    This function automatically identifies suitable target modules and creates
+    a LoRA-adapted version of the input model. It handles configuration
+    setup and parameter freezing for efficient fine-tuning.
+
+    Args:
+        model: The base model to adapt with LoRA
+        **kwargs: Additional LoRA configuration parameters
+
+    Returns:
+        The LoRA-adapted model
+
+    Raises:
+        AssertionError: If no target modules are found for LoRA injection
+    """
+    from peft import LoraConfig, get_peft_model
+    from transformers import PretrainedConfig
+
+    # A bad case for the EVO-1 model, which has a custom config class
+    ######################
+    if hasattr(model, 'config') and not isinstance(model.config, PretrainedConfig):
+        delattr(model.config, 'Loader')
+        model.config = PretrainedConfig.from_dict(dict(model.config))
+    #######################
+
+    target_modules = kwargs.pop("target_modules", None)
+    use_rslora = kwargs.pop("use_rslora", True)
+    bias = kwargs.pop("bias", "none")
+    r = kwargs.pop("r", 32)
+    lora_alpha = kwargs.pop("lora_alpha", 256)
+    lora_dropout = kwargs.pop("lora_dropout", 0.1)
+
+    if target_modules is None:
+        target_modules = find_linear_target_modules(model, keyword_filter=kwargs.get("keyword_filter", None))
+    assert target_modules is not None, "No target modules found for LoRA injection."
+    config = LoraConfig(
+        target_modules=target_modules,
+        r=r,
+        lora_alpha=lora_alpha,
+        lora_dropout=lora_dropout,
+        bias=bias,
+        use_rslora=use_rslora,
+        **kwargs,
+    )
+
+    for param in model.parameters():
+        param.requires_grad = False
+
+    lora_model = get_peft_model(model, config)
+    trainable_params, all_param = lora_model.get_nb_trainable_parameters()
+    fprint(
+        f"trainable params: {trainable_params:,d} || all params: {all_param:,d}"
+        f" || trainable%: {100 * trainable_params / all_param:.4f}"
+    )
+    return lora_model
+
+class OmniLoraModel(nn.Module):
+    """
+    LoRA-adapted model for OmniGenome.
+
+    This class provides a wrapper around LoRA-adapted models, enabling
+    efficient fine-tuning of large genomic language models while maintaining
+    compatibility with the OmniGenome framework.
+
+    Attributes:
+        lora_model: The underlying LoRA-adapted model
+        config: Model configuration
+        device: Device the model is running on
+        dtype: Data type of the model parameters
+    """
+
+    def __init__(self, model, **kwargs):
+        """
+        Initialize the LoRA-adapted model.
+
+        Args:
+            model: The base model to adapt with LoRA
+            **kwargs: LoRA configuration parameters
+
+        Raises:
+            ValueError: If no target modules are specified for LoRA injection
+        """
+        super(OmniLoraModel, self).__init__()
+        target_modules = kwargs.get("target_modules", None)
+        if target_modules is None:
+            raise ValueError(
+                "No target modules found for LoRA injection. To perform LoRA adaptation fine-tuning, "
+                "please specify the target modules using the 'target_modules' argument. "
+                "The target modules depend on the model architecture, such as 'query', 'value', etc. ")
+
+        self.lora_model = auto_lora_model(model, **kwargs)
+
+        fprint(
+            "To reduce GPU memory occupation, "
+            "you should avoid including non-trainable parameters in the optimizer, "
+            "e.g., optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), ...), "
+            "AVOIDING: optimizer = torch.optim.AdamW(model.parameters(), ...)"
+        )
+
+        self.config = model.config
+        self.to('cpu')  # Move the model to CPU initially
+        fprint(
+            "LoRA model initialized with the following configuration:\n",
+            self.lora_model
+        )
+
+    def to(self, *args, **kwargs):
+        """
+        Move the model to a specific device and data type.
+
+        This method overrides the default to() method to ensure the LoRA model
+        and its components are properly moved to the target device and dtype.
+
+        Args:
+            *args: Device specification (e.g., 'cuda', 'cpu')
+            **kwargs: Additional arguments including dtype
+
+        Returns:
+            self: The model instance
+        """
+        self.lora_model.to(*args, **kwargs)
+        try:
+            # For evo-1 and similar models, we need to set the device and dtype
+            for param in self.parameters():
+                self.device = param.device
+                self.dtype = param.dtype
+                break
+            for module in self.lora_model.modules():
+                module.device = self.device
+                if hasattr(module, 'dtype'):
+                    module.dtype = self.dtype
+        except Exception:
+            pass  # Ignore errors if parameters are not available
+        return self
+
+    def forward(self, *args, **kwargs):
+        """
+        Forward pass through the LoRA model.
+
+        Args:
+            *args: Positional arguments for the forward pass
+            **kwargs: Keyword arguments for the forward pass
+
+        Returns:
+            The output from the LoRA model
+        """
+        return self.lora_model(*args, **kwargs)
+
+    def predict(self, *args, **kwargs):
+        """
+        Generate predictions using the LoRA model.
+
+        Args:
+            *args: Positional arguments for prediction
+            **kwargs: Keyword arguments for prediction
+
+        Returns:
+            Model predictions
+        """
+        return self.lora_model.base_model.predict(*args, **kwargs)
+
+    def save(self, *args, **kwargs):
+        """
+        Save the LoRA model.
+
+        Args:
+            *args: Positional arguments for saving
+            **kwargs: Keyword arguments for saving
+
+        Returns:
+            Result of the save operation
+        """
+        return self.lora_model.base_model.save(*args, **kwargs)
+
+    def model_info(self):
+        """
+        Get information about the LoRA model.
+
+        Returns:
+            Model information from the base model
+        """
+        return self.lora_model.base_model.model_info()
+
+    def set_loss_fn(self, fn):
+        """
+        Set the loss function for the LoRA model.
+
+        Args:
+            fn: Loss function to set
+
+        Returns:
+            Result of setting the loss function
+        """
+        return self.lora_model.base_model.set_loss_fn(fn)
+
+    def last_hidden_state_forward(self, **kwargs):
+        """
+        Forward pass to get the last hidden state.
+
+        Args:
+            **kwargs: Keyword arguments for the forward pass
+
+        Returns:
+            Last hidden state from the base model
+        """
+        return self.lora_model.base_model.last_hidden_state_forward(**kwargs)
+
+    def tokenizer(self):
+        """
+        Get the tokenizer from the base model.

+        Returns:
+            The tokenizer from the base model
+        """
+        return self.lora_model.base_model.tokenizer
+
+    def config(self):
+        """
+        Get the configuration from the base model.
+
+        Returns:
+            The configuration from the base model
+        """
+        return self.lora_model.base_model.config
+
+    def model(self):
+        """
+        Get the base model.
+
+        Returns:
+            The base model
+        """
+        return self.lora_model.base_model.model
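
For orientation, a minimal usage sketch of the two helpers above. This is not taken from the package: the checkpoint name is a placeholder, and it assumes transformers and peft are installed alongside torch.

    import torch
    from transformers import AutoModel

    from omnigenome.src.lora.lora_model import auto_lora_model, find_linear_target_modules

    # Placeholder checkpoint; any Hugging Face encoder whose attention
    # projections are nn.Linear layers named "query"/"value" behaves the same.
    base = AutoModel.from_pretrained("bert-base-uncased")

    # List the adaptable nn.Linear layers, filtered by name.
    print(find_linear_target_modules(base, keyword_filter=["query", "value"], use_full_path=False))
    # -> ['query', 'value']

    # Inject LoRA adapters; unspecified options fall back to the defaults
    # above (r=32, lora_alpha=256, rsLoRA enabled, dropout 0.1).
    lora = auto_lora_model(base, target_modules=["query", "value"], r=8, lora_alpha=16)

    # As the module's own fprint message advises, hand the optimizer only the
    # parameters that remain trainable after the base weights are frozen.
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, lora.parameters()), lr=5e-4
    )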
omnigenome/src/metric/__init__.py
@@ -0,0 +1,15 @@
+# -*- coding: utf-8 -*-
+# file: __init__.py
+# time: 12:53 09/04/2024
+# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+# github: https://github.com/yangheng95
+# huggingface: https://huggingface.co/yangheng
+# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+# Copyright (C) 2019-2024. All Rights Reserved.
+"""
+This package contains modules for evaluation metrics.
+"""
+
+from .classification_metric import ClassificationMetric
+from .ranking_metric import RankingMetric
+from .regression_metric import RegressionMetric
omnigenome/src/metric/classification_metric.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+# file: classification_metric.py
+# time: 12:57 09/04/2024
+# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+# github: https://github.com/yangheng95
+# huggingface: https://huggingface.co/yangheng
+# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+# Copyright (C) 2019-2024. All Rights Reserved.
+
+import types
+import warnings
+
+import numpy as np
+import sklearn.metrics as metrics
+
+from ..abc.abstract_metric import OmniMetric
+
+
+class ClassificationMetric(OmniMetric):
+    """
+    Classification metric class for evaluating classification models.
+
+    This class provides a comprehensive interface for classification metrics
+    in the OmniGenome framework. It integrates with scikit-learn's classification
+    metrics and provides additional functionality for handling genomic classification
+    tasks.
+
+    The class automatically exposes all scikit-learn classification metrics as
+    callable attributes, making them easily accessible for evaluation. It also
+    handles special cases like Hugging Face's EvalPrediction objects and
+    provides proper handling of ignored labels.
+
+    Attributes:
+        metric_func (callable): A callable metric function from sklearn.metrics.
+        ignore_y (any): A value in the ground truth labels to be ignored during
+                        metric computation. Defaults to -100.
+        kwargs (dict): Additional keyword arguments for metric computation.
+    """
+
+    def __init__(self, metric_func=None, ignore_y=-100, *args, **kwargs):
+        """
+        Initializes the classification metric.
+
+        Args:
+            metric_func (callable, optional): A callable metric function from
+                                              sklearn.metrics. If None, subclasses
+                                              should implement their own compute method.
+            ignore_y (any, optional): A value in the ground truth labels to be
+                                      ignored during metric computation. Defaults to -100.
+            *args: Additional positional arguments.
+            **kwargs: Additional keyword arguments.
+
+        Example:
+            >>> # Initialize with a specific metric function
+            >>> metric = ClassificationMetric(metrics.accuracy_score)
+
+            >>> # Initialize with ignore value
+            >>> metric = ClassificationMetric(ignore_y=-100)
+        """
+        super().__init__(metric_func, ignore_y, *args, **kwargs)
+        self.kwargs = kwargs
+
+    # def __getattr__(self, name):
+    def __getattribute__(self, name):
+        """
+        Custom attribute getter that provides dynamic access to scikit-learn metrics.
+
+        This method provides transparent access to all scikit-learn classification
+        metrics. When a metric function is accessed, it returns a callable wrapper
+        that handles the metric computation with proper preprocessing.
+
+        Args:
+            name (str): The attribute name to get.
+
+        Returns:
+            callable: A wrapper function for the requested metric, or the original
+                      attribute if it's not a metric function.
+
+        Example:
+            >>> metric = ClassificationMetric()
+            >>> # Access any scikit-learn metric
+            >>> accuracy_fn = metric.accuracy_score
+            >>> result = accuracy_fn(y_true, y_pred)
+        """
+        # Get the metric function
+        metric_func = getattr(metrics, name, None)
+        if metric_func and isinstance(metric_func, types.FunctionType):
+            setattr(self, "compute", metric_func)
+            # If the metric function exists, return a wrapper function
+
+            def wrapper(y_true=None, y_pred=None, *args, **kwargs):
+                """
+                Compute the metric, based on the true and predicted values.
+
+                This wrapper function handles various input formats including
+                Hugging Face's EvalPrediction objects and provides proper
+                preprocessing for metric computation.
+
+                Args:
+                    y_true: The true values (ground truth labels).
+                    y_pred: The predicted values (model predictions).
+                    ignore_y: The value to ignore in the predictions and true
+                              values in corresponding positions.
+                    *args: Additional positional arguments for the metric function.
+                    **kwargs: Additional keyword arguments for the metric function.
+
+                Returns:
+                    dict: A dictionary with the metric name as key and its value.
+
+                Example:
+                    >>> # Standard usage
+                    >>> result = accuracy_fn(y_true, y_pred)
+                    >>> print(result)  # {'accuracy_score': 0.85}
+
+                    >>> # With Hugging Face EvalPrediction
+                    >>> result = accuracy_fn(eval_prediction)
+                    >>> print(result)  # {'accuracy_score': 0.85}
+                """
+
+                # This is an ugly method to handle the case when the predictions are in the form of a tuple
+                # for huggingface trainers
+                if y_true.__class__.__name__ == "EvalPrediction":
+                    eval_prediction = y_true
+                    if hasattr(eval_prediction, "label_ids"):
+                        y_true = eval_prediction.label_ids
+                    if hasattr(eval_prediction, "labels"):
+                        y_true = eval_prediction.labels
+                    predictions = eval_prediction.predictions
+                    for i in range(len(predictions)):
+                        if predictions[i].shape == y_true.shape and not np.all(
+                            predictions[i] == y_true
+                        ):
+                            y_score = predictions[i]
+                            break
+
+                y_true, y_pred = ClassificationMetric.flatten(y_true, y_pred)
+                y_true_mask_idx = np.where(y_true != self.ignore_y)
+                if self.ignore_y is not None:
+                    y_true = y_true[y_true_mask_idx]
+                    try:
+                        y_pred = y_pred[y_true_mask_idx]
+                    except Exception as e:
+                        warnings.warn(str(e))
+
+                kwargs.update(self.kwargs)
+                return {name: self.compute(y_true, y_pred, *args, **kwargs)}
+
+            return wrapper
+        else:
+            return super().__getattribute__(name)
+
+    def compute(self, y_true, y_pred, *args, **kwargs):
+        """
+        Compute the metric, based on the true and predicted values.
+
+        This method computes the classification metric using the provided
+        metric function. It handles preprocessing and applies any additional
+        keyword arguments.
+
+        Args:
+            y_true: The true values (ground truth labels).
+            y_pred: The predicted values (model predictions).
+            *args: Additional positional arguments for the metric function.
+            **kwargs: Additional keyword arguments for the metric function.
+
+        Returns:
+            dict: A dictionary with the metric name as key and its value.
+
+        Raises:
+            NotImplementedError: If no metric function is provided and the method
+                                 is not implemented by the subclass.
+
+        Example:
+            >>> metric = ClassificationMetric(metrics.accuracy_score)
+            >>> result = metric.compute(y_true, y_pred)
+            >>> print(result)  # {'accuracy_score': 0.85}
+        """
+        if self.metric_func is not None:
+            kwargs.update(self.kwargs)
+            return self.metric_func(y_true, y_pred, *args, **kwargs)
+        else:
+            raise NotImplementedError(
+                "Method compute() is not implemented in the child class."
+            )
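
A quick sanity check of the dynamic dispatch above. This is a sketch rather than package code; it assumes the OmniMetric base class provides the flatten static method and stores ignore_y, as the wrapper's usage implies.

    import numpy as np
    from omnigenome.src.metric.classification_metric import ClassificationMetric

    metric = ClassificationMetric(ignore_y=-100)
    y_true = np.array([0, 1, 2, -100, 1])  # the -100 position is masked out
    y_pred = np.array([0, 1, 1, 2, 1])

    # Attribute access resolves to sklearn.metrics.accuracy_score via
    # __getattribute__, wrapped so that ignored labels are dropped and the
    # result is keyed by the metric name.
    print(metric.accuracy_score(y_true, y_pred))  # {'accuracy_score': 0.75}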
omnigenome/src/metric/metric.py
@@ -0,0 +1,199 @@
+# -*- coding: utf-8 -*-
+# file: regression_metric.py
+# time: 12:57 09/04/2024
+# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+# github: https://github.com/yangheng95
+# huggingface: https://huggingface.co/yangheng
+# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+# Copyright (C) 2019-2024. All Rights Reserved.
+
+
+import types
+import warnings
+
+import numpy as np
+import sklearn.metrics as metrics
+
+from ..abc.abstract_metric import OmniMetric
+
+
+def mcrmse(y_true, y_pred):
+    """
+    Compute Mean Column Root Mean Square Error (MCRMSE).
+
+    MCRMSE is a multi-target regression metric that computes the RMSE for each target
+    column and then takes the mean across all targets.
+
+    Args:
+        y_true (np.ndarray): Ground truth values with shape (n_samples, n_targets)
+        y_pred (np.ndarray): Predicted values with shape (n_samples, n_targets)
+
+    Returns:
+        float: Mean Column Root Mean Square Error
+
+    Raises:
+        ValueError: If y_true and y_pred have different shapes
+
+    Example:
+        >>> y_true = np.array([[1, 2], [3, 4], [5, 6]])
+        >>> y_pred = np.array([[1.1, 2.1], [2.9, 4.1], [5.2, 5.8]])
+        >>> mcrmse(y_true, y_pred)
+        0.1414...
+    """
+    if y_true.shape != y_pred.shape:
+        raise ValueError("y_true and y_pred must have the same shape")
+    mask = y_true != -100
+    filtered_y_pred = y_pred[mask]
+    filtered_y_true = y_true[mask]
+    rmse_per_target = np.sqrt(np.mean((filtered_y_true - filtered_y_pred) ** 2, axis=0))
+    mcrmse_value = np.mean(rmse_per_target)
+    return mcrmse_value
+
+
+setattr(metrics, "mcrmse", mcrmse)
+
+
+class Metric(OmniMetric):
+    """
+    A flexible metric class that provides access to all scikit-learn metrics
+    and custom metrics for evaluation.
+
+    This class dynamically wraps scikit-learn metrics and provides a unified
+    interface for computing various evaluation metrics. It handles different
+    input formats including HuggingFace trainer outputs and supports
+    custom metric functions.
+
+    Attributes:
+        metric_func: Custom metric function if provided
+        ignore_y: Value to ignore in predictions and true values
+        kwargs: Additional keyword arguments for metric computation
+        metrics: Dictionary of available metrics including custom ones
+
+    Example:
+        >>> from omnigenome.src.metric import Metric
+        >>> metric = Metric(ignore_y=-100)
+        >>> y_true = [0, 1, 2, 0, 1]
+        >>> y_pred = [0, 1, 1, 0, 1]
+        >>> result = metric.accuracy(y_true, y_pred)
+        >>> print(result)
+        {'accuracy': 0.8}
+    """
+
+    def __init__(self, metric_func=None, ignore_y=-100, *args, **kwargs):
+        """
+        Initialize the Metric class.
+
+        Args:
+            metric_func (callable, optional): Custom metric function to use
+            ignore_y (int, optional): Value to ignore in predictions and true values. Defaults to -100
+            *args: Additional positional arguments
+            **kwargs: Additional keyword arguments for metric computation
+        """
+        super().__init__(metric_func, ignore_y, *args, **kwargs)
+        self.kwargs = kwargs
+        self.metrics = {"mcrmse": mcrmse}
+        for key, value in metrics.__dict__.items():
+            setattr(self, key, value)
+
+    def __getattribute__(self, name):
+        """
+        Dynamically create metric computation methods.
+
+        This method intercepts attribute access and creates wrapper functions
+        for scikit-learn metrics, handling different input formats and
+        preprocessing the data appropriately.
+
+        Args:
+            name (str): Name of the metric to access
+
+        Returns:
+            callable: Wrapper function for the requested metric
+        """
+        # Get the metric function
+        metric_func = getattr(metrics, name, None)
+
+        if metric_func and isinstance(metric_func, types.FunctionType):
+            setattr(self, "compute", metric_func)
+            # If the metric function exists, return a wrapper function
+
+            def wrapper(y_true=None, y_score=None, *args, **kwargs):
+                """
+                Compute the metric, based on the true and predicted values.
+
+                This wrapper handles different input formats including HuggingFace
+                trainer outputs and performs necessary preprocessing.
+
+                Args:
+                    y_true: The true values or HuggingFace EvalPrediction object
+                    y_score: The predicted values
+                    ignore_y: The value to ignore in the predictions and true values in corresponding positions
+                    *args: Additional positional arguments for the metric
+                    **kwargs: Additional keyword arguments for the metric
+
+                Returns:
+                    dict: Dictionary containing the metric name and computed value
+
+                Raises:
+                    ValueError: If neither y_true nor y_score is provided
+                """
+                # This is an ugly method to handle the case when the predictions are in the form of a tuple
+                # for huggingface trainers
+                if y_true is not None and y_score is None:
+                    if hasattr(y_true, "predictions"):
+                        y_score = y_true.predictions
+                    if hasattr(y_true, "label_ids"):
+                        y_true = y_true.label_ids
+                    if hasattr(y_true, "labels"):
+                        y_true = y_true.labels
+                    if len(y_score[0][1]) == np.max(y_true) + 1:
+                        y_score = y_score[0]
+                    else:
+                        y_score = y_score[1]
+                    y_score = np.argmax(y_score, axis=1)
+                elif y_true is not None and y_score is not None:
+                    pass  # y_true and y_score are provided
+                else:
+                    raise ValueError(
+                        "Please provide the true and predicted values or a dictionary with 'y_true' and 'y_score'."
+                    )
+
+                y_true, y_score = Metric.flatten(y_true, y_score)
+                y_true_mask_idx = np.where(y_true != self.ignore_y)
+                if self.ignore_y is not None:
+                    y_true = y_true[y_true_mask_idx]
+                    try:
+                        y_score = y_score[y_true_mask_idx]
+                    except Exception as e:
+                        warnings.warn(str(e))
+                kwargs.update(self.kwargs)
+
+                return {name: self.compute(y_true, y_score, *args, **kwargs)}
+
+            return wrapper
+        else:
+            return super().__getattribute__(name)
+
+    def compute(self, y_true, y_score, *args, **kwargs):
+        """
+        Compute the metric, based on the true and predicted values.
+
+        Args:
+            y_true: The true values
+            y_score: The predicted values
+            *args: Additional positional arguments for the metric
+            **kwargs: Additional keyword arguments for the metric
+
+        Returns:
+            The computed metric value
+
+        Raises:
+            NotImplementedError: If no metric function is provided and compute is not implemented
+        """
+        if self.metric_func is not None:
+            kwargs.update(self.kwargs)
+            return self.metric_func(y_true, y_score, *args, **kwargs)
+
+        else:
+            raise NotImplementedError(
+                "Method compute() is not implemented in the child class."
+            )
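
To make the mcrmse arithmetic concrete, a small hand-checked example (not from the package; it needs only numpy):

    import numpy as np
    from omnigenome.src.metric.metric import mcrmse

    y_true = np.array([[1.0, 2.0], [3.0, 4.0]])
    y_pred = np.array([[1.0, 2.0], [3.0, 5.0]])

    # One prediction is off by 1, so the squared errors are [0, 0, 0, 1].
    # Note that the boolean -100 mask flattens both arrays, so the result is
    # the RMSE over all unmasked entries: sqrt(1/4) = 0.5.
    print(mcrmse(y_true, y_pred))  # 0.5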