omnigenome 0.3.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of omnigenome might be problematic.

Files changed (85)
  1. omnigenome/__init__.py +281 -0
  2. omnigenome/auto/__init__.py +3 -0
  3. omnigenome/auto/auto_bench/__init__.py +12 -0
  4. omnigenome/auto/auto_bench/auto_bench.py +484 -0
  5. omnigenome/auto/auto_bench/auto_bench_cli.py +230 -0
  6. omnigenome/auto/auto_bench/auto_bench_config.py +216 -0
  7. omnigenome/auto/auto_bench/config_check.py +34 -0
  8. omnigenome/auto/auto_train/__init__.py +13 -0
  9. omnigenome/auto/auto_train/auto_train.py +430 -0
  10. omnigenome/auto/auto_train/auto_train_cli.py +222 -0
  11. omnigenome/auto/bench_hub/__init__.py +12 -0
  12. omnigenome/auto/bench_hub/bench_hub.py +25 -0
  13. omnigenome/cli/__init__.py +13 -0
  14. omnigenome/cli/commands/__init__.py +13 -0
  15. omnigenome/cli/commands/base.py +83 -0
  16. omnigenome/cli/commands/bench/__init__.py +13 -0
  17. omnigenome/cli/commands/bench/bench_cli.py +202 -0
  18. omnigenome/cli/commands/rna/__init__.py +13 -0
  19. omnigenome/cli/commands/rna/rna_design.py +178 -0
  20. omnigenome/cli/omnigenome_cli.py +128 -0
  21. omnigenome/src/__init__.py +12 -0
  22. omnigenome/src/abc/__init__.py +12 -0
  23. omnigenome/src/abc/abstract_dataset.py +622 -0
  24. omnigenome/src/abc/abstract_metric.py +114 -0
  25. omnigenome/src/abc/abstract_model.py +689 -0
  26. omnigenome/src/abc/abstract_tokenizer.py +267 -0
  27. omnigenome/src/dataset/__init__.py +16 -0
  28. omnigenome/src/dataset/omni_dataset.py +435 -0
  29. omnigenome/src/lora/__init__.py +13 -0
  30. omnigenome/src/lora/lora_model.py +294 -0
  31. omnigenome/src/metric/__init__.py +15 -0
  32. omnigenome/src/metric/classification_metric.py +184 -0
  33. omnigenome/src/metric/metric.py +199 -0
  34. omnigenome/src/metric/ranking_metric.py +142 -0
  35. omnigenome/src/metric/regression_metric.py +191 -0
  36. omnigenome/src/misc/__init__.py +3 -0
  37. omnigenome/src/misc/utils.py +439 -0
  38. omnigenome/src/model/__init__.py +19 -0
  39. omnigenome/src/model/augmentation/__init__.py +12 -0
  40. omnigenome/src/model/augmentation/model.py +219 -0
  41. omnigenome/src/model/classification/__init__.py +12 -0
  42. omnigenome/src/model/classification/model.py +642 -0
  43. omnigenome/src/model/embedding/__init__.py +12 -0
  44. omnigenome/src/model/embedding/model.py +263 -0
  45. omnigenome/src/model/mlm/__init__.py +12 -0
  46. omnigenome/src/model/mlm/model.py +177 -0
  47. omnigenome/src/model/module_utils.py +232 -0
  48. omnigenome/src/model/regression/__init__.py +12 -0
  49. omnigenome/src/model/regression/model.py +786 -0
  50. omnigenome/src/model/regression/resnet.py +483 -0
  51. omnigenome/src/model/rna_design/__init__.py +12 -0
  52. omnigenome/src/model/rna_design/model.py +426 -0
  53. omnigenome/src/model/seq2seq/__init__.py +12 -0
  54. omnigenome/src/model/seq2seq/model.py +44 -0
  55. omnigenome/src/tokenizer/__init__.py +16 -0
  56. omnigenome/src/tokenizer/bpe_tokenizer.py +226 -0
  57. omnigenome/src/tokenizer/kmers_tokenizer.py +247 -0
  58. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +249 -0
  59. omnigenome/src/trainer/__init__.py +14 -0
  60. omnigenome/src/trainer/accelerate_trainer.py +739 -0
  61. omnigenome/src/trainer/hf_trainer.py +75 -0
  62. omnigenome/src/trainer/trainer.py +579 -0
  63. omnigenome/utility/__init__.py +3 -0
  64. omnigenome/utility/dataset_hub/__init__.py +13 -0
  65. omnigenome/utility/dataset_hub/dataset_hub.py +178 -0
  66. omnigenome/utility/ensemble.py +324 -0
  67. omnigenome/utility/hub_utils.py +517 -0
  68. omnigenome/utility/model_hub/__init__.py +12 -0
  69. omnigenome/utility/model_hub/model_hub.py +231 -0
  70. omnigenome/utility/pipeline_hub/__init__.py +12 -0
  71. omnigenome/utility/pipeline_hub/pipeline.py +483 -0
  72. omnigenome/utility/pipeline_hub/pipeline_hub.py +129 -0
  73. omnigenome-0.3.0a0.dist-info/METADATA +224 -0
  74. omnigenome-0.3.0a0.dist-info/RECORD +85 -0
  75. omnigenome-0.3.0a0.dist-info/WHEEL +5 -0
  76. omnigenome-0.3.0a0.dist-info/entry_points.txt +3 -0
  77. omnigenome-0.3.0a0.dist-info/licenses/LICENSE +201 -0
  78. omnigenome-0.3.0a0.dist-info/top_level.txt +2 -0
  79. tests/__init__.py +9 -0
  80. tests/conftest.py +160 -0
  81. tests/test_dataset_patterns.py +291 -0
  82. tests/test_examples_syntax.py +83 -0
  83. tests/test_model_loading.py +183 -0
  84. tests/test_rna_functions.py +255 -0
  85. tests/test_training_patterns.py +302 -0
omnigenome/src/lora/lora_model.py
@@ -0,0 +1,294 @@
+ # -*- coding: utf-8 -*-
+ # file: lora_model.py
+ # time: 12:36 11/06/2025
+ # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+ # homepage: https://yangheng95.github.io
+ # github: https://github.com/yangheng95
+ # huggingface: https://huggingface.co/yangheng
+ # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+ # Copyright (C) 2019-2025. All Rights Reserved.
+ """
+ Low-Rank Adaptation (LoRA) models for OmniGenome.
+
+ This module provides a LoRA implementation for efficient fine-tuning of large
+ genomic language models. LoRA reduces the number of trainable parameters by
+ adding low-rank adaptation layers to existing model weights.
+ """
+ import torch
+ from torch import nn
+
+ from omnigenome.src.misc.utils import fprint
+
+
+ def find_linear_target_modules(model, keyword_filter=None, use_full_path=True):
+     """
+     Find linear modules in a model that can be targeted for LoRA adaptation.
+
+     This function searches a model's modules to identify linear layers that
+     can be adapted with LoRA. It supports filtering by keyword patterns to
+     target specific types of layers.
+
+     Args:
+         model: The model to search for linear modules.
+         keyword_filter (str, list, tuple, optional): Keywords to filter module names by.
+         use_full_path (bool): Whether to return full module paths or only the final
+             attribute names (default: True).
+
+     Returns:
+         list: Sorted list of linear module names that can be targeted for LoRA.
+
+     Raises:
+         TypeError: If keyword_filter is not None, str, or a list/tuple of str.
+     """
+     import re
+
+     pattern = None
+     if keyword_filter is not None:
+         if isinstance(keyword_filter, str):
+             keyword_filter = [keyword_filter]
+         elif not isinstance(keyword_filter, (list, tuple)):
+             raise TypeError("keyword_filter must be None, str, or a list/tuple of str")
+         pattern = "|".join(map(re.escape, keyword_filter))
+
+     linear_modules = set()
+     for name, module in model.named_modules():
+         if isinstance(module, nn.Linear):
+             if keyword_filter is None or re.search(pattern, name, re.IGNORECASE):
+                 linear_modules.add(name if use_full_path else name.split(".")[-1])
+
+     return sorted(linear_modules)
+
+
+ def auto_lora_model(model, **kwargs):
+     """
+     Automatically create a LoRA-adapted model.
+
+     This function automatically identifies suitable target modules and creates
+     a LoRA-adapted version of the input model. It handles configuration setup
+     and parameter freezing for efficient fine-tuning.
+
+     Args:
+         model: The base model to adapt with LoRA.
+         **kwargs: Additional LoRA configuration parameters.
+
+     Returns:
+         The LoRA-adapted model.
+
+     Raises:
+         AssertionError: If no target modules are found for LoRA injection.
+     """
+     from peft import LoraConfig, get_peft_model
+     from transformers import PretrainedConfig
+
+     # A bad case for the EVO-1 model, which has a custom config class
+     ######################
+     if hasattr(model, "config") and not isinstance(model.config, PretrainedConfig):
+         delattr(model.config, "Loader")
+         model.config = PretrainedConfig.from_dict(dict(model.config))
+     #######################
+
+     target_modules = kwargs.pop("target_modules", None)
+     use_rslora = kwargs.pop("use_rslora", True)
+     bias = kwargs.pop("bias", "none")
+     r = kwargs.pop("r", 32)
+     lora_alpha = kwargs.pop("lora_alpha", 256)
+     lora_dropout = kwargs.pop("lora_dropout", 0.1)
+     # Pop keyword_filter so it is not forwarded to LoraConfig below.
+     keyword_filter = kwargs.pop("keyword_filter", None)
+
+     if target_modules is None:
+         target_modules = find_linear_target_modules(model, keyword_filter=keyword_filter)
+     assert target_modules, "No target modules found for LoRA injection."
+     config = LoraConfig(
+         target_modules=target_modules,
+         r=r,
+         lora_alpha=lora_alpha,
+         lora_dropout=lora_dropout,
+         bias=bias,
+         use_rslora=use_rslora,
+         **kwargs,
+     )
+
+     # Freeze the base model; only the injected LoRA weights remain trainable.
+     for param in model.parameters():
+         param.requires_grad = False
+
+     lora_model = get_peft_model(model, config)
+     trainable_params, all_param = lora_model.get_nb_trainable_parameters()
+     fprint(
+         f"trainable params: {trainable_params:,d} || all params: {all_param:,d}"
+         f" || trainable%: {100 * trainable_params / all_param:.4f}"
+     )
+     return lora_model
+
+
+ class OmniLoraModel(nn.Module):
+     """
+     LoRA-adapted model for OmniGenome.
+
+     This class provides a wrapper around LoRA-adapted models, enabling
+     efficient fine-tuning of large genomic language models while maintaining
+     compatibility with the OmniGenome framework.
+
+     Attributes:
+         lora_model: The underlying LoRA-adapted model.
+         config: Model configuration.
+         device: Device the model is running on.
+         dtype: Data type of the model parameters.
+     """
+
+     def __init__(self, model, **kwargs):
+         """
+         Initialize the LoRA-adapted model.
+
+         Args:
+             model: The base model to adapt with LoRA.
+             **kwargs: LoRA configuration parameters.
+
+         Raises:
+             ValueError: If no target modules are specified for LoRA injection.
+         """
+         super(OmniLoraModel, self).__init__()
+         target_modules = kwargs.get("target_modules", None)
+         if target_modules is None:
+             raise ValueError(
+                 "No target modules found for LoRA injection. To perform LoRA adaptation fine-tuning, "
+                 "please specify the target modules using the 'target_modules' argument. "
+                 "The target modules depend on the model architecture, such as 'query', 'value', etc."
+             )
+
+         self.lora_model = auto_lora_model(model, **kwargs)
+
+         fprint(
+             "To reduce GPU memory occupation, "
+             "you should avoid including non-trainable parameters in optimizers, "
+             "e.g., optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), ...), "
+             "AVOIDING: optimizer = torch.optim.AdamW(model.parameters(), ...)"
+         )
+
+         self.config = model.config
+         self.to("cpu")  # Move the model to CPU initially
+         fprint(
+             "LoRA model initialized with the following configuration:\n",
+             self.lora_model,
+         )
+
+     def to(self, *args, **kwargs):
+         """
+         Move the model to a specific device and data type.
+
+         This method overrides the default to() method to ensure the LoRA model
+         and its components are properly moved to the target device and dtype.
+
+         Args:
+             *args: Device specification (e.g., 'cuda', 'cpu').
+             **kwargs: Additional arguments, including dtype.
+
+         Returns:
+             self: The model instance.
+         """
+         self.lora_model.to(*args, **kwargs)
+         try:
+             # For evo-1 and similar models, we need to set the device and dtype
+             for param in self.parameters():
+                 self.device = param.device
+                 self.dtype = param.dtype
+                 break
+             for module in self.lora_model.modules():
+                 module.device = self.device
+                 if hasattr(module, "dtype"):
+                     module.dtype = self.dtype
+         except Exception:
+             pass  # Ignore errors if parameters are not available
+         return self
+
+     def forward(self, *args, **kwargs):
+         """
+         Forward pass through the LoRA model.
+
+         Args:
+             *args: Positional arguments for the forward pass.
+             **kwargs: Keyword arguments for the forward pass.
+
+         Returns:
+             The output from the LoRA model.
+         """
+         return self.lora_model(*args, **kwargs)
+
+     def predict(self, *args, **kwargs):
+         """
+         Generate predictions using the LoRA model.
+
+         Args:
+             *args: Positional arguments for prediction.
+             **kwargs: Keyword arguments for prediction.
+
+         Returns:
+             Model predictions.
+         """
+         return self.lora_model.base_model.predict(*args, **kwargs)
+
+     def save(self, *args, **kwargs):
+         """
+         Save the LoRA model.
+
+         Args:
+             *args: Positional arguments for saving.
+             **kwargs: Keyword arguments for saving.
+
+         Returns:
+             Result of the save operation.
+         """
+         return self.lora_model.base_model.save(*args, **kwargs)
+
+     def model_info(self):
+         """
+         Get information about the LoRA model.
+
+         Returns:
+             Model information from the base model.
+         """
+         return self.lora_model.base_model.model_info()
+
+     def set_loss_fn(self, fn):
+         """
+         Set the loss function for the LoRA model.
+
+         Args:
+             fn: Loss function to set.
+
+         Returns:
+             Result of setting the loss function.
+         """
+         return self.lora_model.base_model.set_loss_fn(fn)
+
+     def last_hidden_state_forward(self, **kwargs):
+         """
+         Forward pass to get the last hidden state.
+
+         Args:
+             **kwargs: Keyword arguments for the forward pass.
+
+         Returns:
+             Last hidden state from the base model.
+         """
+         return self.lora_model.base_model.last_hidden_state_forward(**kwargs)
+
+     def tokenizer(self):
+         """
+         Get the tokenizer from the base model.
+
+         Returns:
+             The tokenizer from the base model.
+         """
+         return self.lora_model.base_model.tokenizer
+
+     def config(self):
+         """
+         Get the configuration from the base model.
+
+         Returns:
+             The configuration from the base model.
+         """
+         return self.lora_model.base_model.config
+
+     def model(self):
+         """
+         Get the base model.
+
+         Returns:
+             The base model.
+         """
+         return self.lora_model.base_model.model
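The sketch below is editorial, not part of the package: it shows one plausible way to wrap a Hugging Face encoder with OmniLoraModel and to build an optimizer over only the trainable LoRA parameters, as the module's own warning message recommends. The checkpoint name is a placeholder, and 'query'/'value' are merely typical attention-projection names.

import torch
from transformers import AutoModel

from omnigenome.src.lora.lora_model import OmniLoraModel

# Placeholder checkpoint; substitute whichever genomic language model you use.
base = AutoModel.from_pretrained("your-checkpoint-name")

# target_modules is mandatory for OmniLoraModel; pick names matching your architecture.
lora = OmniLoraModel(base, target_modules=["query", "value"], r=32, lora_alpha=256)

# Pass only trainable (LoRA) parameters to the optimizer, never model.parameters().
optimizer = torch.optim.AdamW(
    (p for p in lora.parameters() if p.requires_grad), lr=5e-4
)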
omnigenome/src/metric/__init__.py
@@ -0,0 +1,15 @@
+ # -*- coding: utf-8 -*-
+ # file: __init__.py
+ # time: 12:53 09/04/2024
+ # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+ # github: https://github.com/yangheng95
+ # huggingface: https://huggingface.co/yangheng
+ # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+ # Copyright (C) 2019-2024. All Rights Reserved.
+ """
+ This package contains modules for evaluation metrics.
+ """
+
+ from .classification_metric import ClassificationMetric
+ from .ranking_metric import RankingMetric
+ from .regression_metric import RegressionMetric
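For orientation, a hypothetical usage of the exported classes (assuming RegressionMetric follows the same dynamic-attribute pattern that classification_metric.py implements below, where any sklearn.metrics function name resolves to a wrapped, mask-aware callable):

from omnigenome.src.metric import ClassificationMetric, RegressionMetric

# Attribute access is intercepted, so these return wrapper functions.
accuracy_fn = ClassificationMetric(ignore_y=-100).accuracy_score
mse_fn = RegressionMetric(ignore_y=-100).mean_squared_error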
omnigenome/src/metric/classification_metric.py
@@ -0,0 +1,184 @@
+ # -*- coding: utf-8 -*-
+ # file: classification_metric.py
+ # time: 12:57 09/04/2024
+ # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+ # github: https://github.com/yangheng95
+ # huggingface: https://huggingface.co/yangheng
+ # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+ # Copyright (C) 2019-2024. All Rights Reserved.
+
+ import types
+ import warnings
+
+ import numpy as np
+ import sklearn.metrics as metrics
+
+ from ..abc.abstract_metric import OmniMetric
+
+
+ class ClassificationMetric(OmniMetric):
+     """
+     Classification metric class for evaluating classification models.
+
+     This class provides a comprehensive interface for classification metrics
+     in the OmniGenome framework. It integrates scikit-learn's classification
+     metrics and provides additional functionality for handling genomic
+     classification tasks.
+
+     The class automatically exposes all scikit-learn classification metrics as
+     callable attributes, making them easily accessible for evaluation. It also
+     handles special cases like Hugging Face's EvalPrediction objects and
+     provides proper handling of ignored labels.
+
+     Attributes:
+         metric_func (callable): A callable metric function from sklearn.metrics.
+         ignore_y (any): A value in the ground truth labels to be ignored during
+             metric computation. Defaults to -100.
+         kwargs (dict): Additional keyword arguments for metric computation.
+     """
+
+     def __init__(self, metric_func=None, ignore_y=-100, *args, **kwargs):
+         """
+         Initializes the classification metric.
+
+         Args:
+             metric_func (callable, optional): A callable metric function from
+                 sklearn.metrics. If None, subclasses should implement their own
+                 compute method.
+             ignore_y (any, optional): A value in the ground truth labels to be
+                 ignored during metric computation. Defaults to -100.
+             *args: Additional positional arguments.
+             **kwargs: Additional keyword arguments.
+
+         Example:
+             >>> # Initialize with a specific metric function
+             >>> metric = ClassificationMetric(metrics.accuracy_score)
+
+             >>> # Initialize with an ignore value
+             >>> metric = ClassificationMetric(ignore_y=-100)
+         """
+         super().__init__(metric_func, ignore_y, *args, **kwargs)
+         self.kwargs = kwargs
+
+     # def __getattr__(self, name):
+     def __getattribute__(self, name):
+         """
+         Custom attribute getter that provides dynamic access to scikit-learn metrics.
+
+         This method provides transparent access to all scikit-learn classification
+         metrics. When a metric function is accessed, it returns a callable wrapper
+         that handles the metric computation with proper preprocessing.
+
+         Args:
+             name (str): The attribute name to get.
+
+         Returns:
+             callable: A wrapper function for the requested metric, or the original
+                 attribute if it is not a metric function.
+
+         Example:
+             >>> metric = ClassificationMetric()
+             >>> # Access any scikit-learn metric
+             >>> accuracy_fn = metric.accuracy_score
+             >>> result = accuracy_fn(y_true, y_pred)
+         """
+         # Get the metric function
+         metric_func = getattr(metrics, name, None)
+         if metric_func and isinstance(metric_func, types.FunctionType):
+             setattr(self, "compute", metric_func)
+             # If the metric function exists, return a wrapper function
+
+             def wrapper(y_true=None, y_pred=None, *args, **kwargs):
+                 """
+                 Compute the metric based on the true and predicted values.
+
+                 This wrapper function handles various input formats, including
+                 Hugging Face's EvalPrediction objects, and masks positions whose
+                 label equals the instance's ignore_y value before computing the
+                 metric.
+
+                 Args:
+                     y_true: The true values (ground truth labels).
+                     y_pred: The predicted values (model predictions).
+                     *args: Additional positional arguments for the metric function.
+                     **kwargs: Additional keyword arguments for the metric function.
+
+                 Returns:
+                     dict: A dictionary with the metric name as key and its value.
+
+                 Example:
+                     >>> # Standard usage
+                     >>> result = accuracy_fn(y_true, y_pred)
+                     >>> print(result)  # {'accuracy_score': 0.85}
+
+                     >>> # With a Hugging Face EvalPrediction
+                     >>> result = accuracy_fn(eval_prediction)
+                     >>> print(result)  # {'accuracy_score': 0.85}
+                 """
+                 # This is an ugly way to handle the case when the predictions are
+                 # in the form of a tuple, as produced by Hugging Face trainers.
+                 if y_true.__class__.__name__ == "EvalPrediction":
+                     eval_prediction = y_true
+                     if hasattr(eval_prediction, "label_ids"):
+                         y_true = eval_prediction.label_ids
+                     if hasattr(eval_prediction, "labels"):
+                         y_true = eval_prediction.labels
+                     predictions = eval_prediction.predictions
+                     for i in range(len(predictions)):
+                         if predictions[i].shape == y_true.shape and not np.all(
+                             predictions[i] == y_true
+                         ):
+                             y_pred = predictions[i]
+                             break
+
+                 y_true, y_pred = ClassificationMetric.flatten(y_true, y_pred)
+                 y_true_mask_idx = np.where(y_true != self.ignore_y)
+                 if self.ignore_y is not None:
+                     y_true = y_true[y_true_mask_idx]
+                     try:
+                         y_pred = y_pred[y_true_mask_idx]
+                     except Exception as e:
+                         warnings.warn(str(e))
+
+                 kwargs.update(self.kwargs)
+                 return {name: self.compute(y_true, y_pred, *args, **kwargs)}
+
+             return wrapper
+         else:
+             return super().__getattribute__(name)
+
+     def compute(self, y_true, y_pred, *args, **kwargs):
+         """
+         Compute the metric based on the true and predicted values.
+
+         This method computes the classification metric using the provided
+         metric function. It handles preprocessing and applies any additional
+         keyword arguments.
+
+         Args:
+             y_true: The true values (ground truth labels).
+             y_pred: The predicted values (model predictions).
+             *args: Additional positional arguments for the metric function.
+             **kwargs: Additional keyword arguments for the metric function.
+
+         Returns:
+             The computed metric value.
+
+         Raises:
+             NotImplementedError: If no metric function is provided and the method
+                 is not implemented by the subclass.
+
+         Example:
+             >>> metric = ClassificationMetric(metrics.accuracy_score)
+             >>> result = metric.compute(y_true, y_pred)
+             >>> print(result)  # 0.85
+         """
+         if self.metric_func is not None:
+             kwargs.update(self.kwargs)
+             return self.metric_func(y_true, y_pred, *args, **kwargs)
+         else:
+             raise NotImplementedError(
+                 "Method compute() is not implemented in the child class."
+             )
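A small worked example of the masking behavior (an editorial sketch; flatten() is inherited from OmniMetric): labels equal to ignore_y are dropped from both arrays before the scikit-learn function runs, so the metric here is computed on [0, 1, 1] vs [0, 1, 0].

import numpy as np

from omnigenome.src.metric.classification_metric import ClassificationMetric

metric = ClassificationMetric(ignore_y=-100)
y_true = np.array([0, 1, -100, 1])  # the -100 position is ignored
y_pred = np.array([0, 1, 0, 0])
print(metric.f1_score(y_true, y_pred, average="macro"))
# -> {'f1_score': 0.666...}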
omnigenome/src/metric/metric.py
@@ -0,0 +1,199 @@
+ # -*- coding: utf-8 -*-
+ # file: regression_metric.py
+ # time: 12:57 09/04/2024
+ # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+ # github: https://github.com/yangheng95
+ # huggingface: https://huggingface.co/yangheng
+ # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+ # Copyright (C) 2019-2024. All Rights Reserved.
+
+
+ import types
+ import warnings
+
+ import numpy as np
+ import sklearn.metrics as metrics
+
+ from ..abc.abstract_metric import OmniMetric
+
+
+ def mcrmse(y_true, y_pred):
+     """
+     Compute Mean Column Root Mean Square Error (MCRMSE).
+
+     MCRMSE is a multi-target regression metric that computes the RMSE for each
+     target column and then takes the mean across all targets.
+
+     Args:
+         y_true (np.ndarray): Ground truth values with shape (n_samples, n_targets)
+         y_pred (np.ndarray): Predicted values with shape (n_samples, n_targets)
+
+     Returns:
+         float: Mean Column Root Mean Square Error
+
+     Raises:
+         ValueError: If y_true and y_pred have different shapes
+
+     Example:
+         >>> y_true = np.array([[1, 2], [3, 4], [5, 6]])
+         >>> y_pred = np.array([[1.1, 2.1], [2.9, 4.1], [5.2, 5.8]])
+         >>> mcrmse(y_true, y_pred)
+         0.1414...
+     """
+     if y_true.shape != y_pred.shape:
+         raise ValueError("y_true and y_pred must have the same shape")
+     # Note: boolean masking flattens the arrays, so in practice this computes
+     # the RMSE over all unmasked entries rather than a strict per-column mean.
+     mask = y_true != -100
+     filtered_y_pred = y_pred[mask]
+     filtered_y_true = y_true[mask]
+     rmse_per_target = np.sqrt(np.mean((filtered_y_true - filtered_y_pred) ** 2, axis=0))
+     mcrmse_value = np.mean(rmse_per_target)
+     return mcrmse_value
+
+
+ setattr(metrics, "mcrmse", mcrmse)
+
+
+ class Metric(OmniMetric):
+     """
+     A flexible metric class that provides access to all scikit-learn metrics
+     and custom metrics for evaluation.
+
+     This class dynamically wraps scikit-learn metrics and provides a unified
+     interface for computing various evaluation metrics. It handles different
+     input formats, including Hugging Face trainer outputs, and supports
+     custom metric functions.
+
+     Attributes:
+         metric_func: Custom metric function, if provided
+         ignore_y: Value to ignore in predictions and true values
+         kwargs: Additional keyword arguments for metric computation
+         metrics: Dictionary of available metrics, including custom ones
+
+     Example:
+         >>> from omnigenome.src.metric.metric import Metric
+         >>> metric = Metric(ignore_y=-100)
+         >>> y_true = [0, 1, 2, 0, 1]
+         >>> y_pred = [0, 1, 1, 0, 1]
+         >>> result = metric.accuracy_score(y_true, y_pred)
+         >>> print(result)
+         {'accuracy_score': 0.8}
+     """
+
+     def __init__(self, metric_func=None, ignore_y=-100, *args, **kwargs):
+         """
+         Initialize the Metric class.
+
+         Args:
+             metric_func (callable, optional): Custom metric function to use
+             ignore_y (int, optional): Value to ignore in predictions and true
+                 values. Defaults to -100
+             *args: Additional positional arguments
+             **kwargs: Additional keyword arguments for metric computation
+         """
+         super().__init__(metric_func, ignore_y, *args, **kwargs)
+         self.kwargs = kwargs
+         self.metrics = {"mcrmse": mcrmse}
+         for key, value in metrics.__dict__.items():
+             setattr(self, key, value)
+
+     def __getattribute__(self, name):
+         """
+         Dynamically create metric computation methods.
+
+         This method intercepts attribute access and creates wrapper functions
+         for scikit-learn metrics, handling different input formats and
+         preprocessing the data appropriately.
+
+         Args:
+             name (str): Name of the metric to access
+
+         Returns:
+             callable: Wrapper function for the requested metric
+         """
+         # Get the metric function
+         metric_func = getattr(metrics, name, None)
+
+         if metric_func and isinstance(metric_func, types.FunctionType):
+             setattr(self, "compute", metric_func)
+             # If the metric function exists, return a wrapper function
+
+             def wrapper(y_true=None, y_score=None, *args, **kwargs):
+                 """
+                 Compute the metric based on the true and predicted values.
+
+                 This wrapper handles different input formats, including Hugging
+                 Face trainer outputs, and masks positions whose label equals the
+                 instance's ignore_y value before computing the metric.
+
+                 Args:
+                     y_true: The true values or a Hugging Face EvalPrediction object
+                     y_score: The predicted values
+                     *args: Additional positional arguments for the metric
+                     **kwargs: Additional keyword arguments for the metric
+
+                 Returns:
+                     dict: Dictionary containing the metric name and computed value
+
+                 Raises:
+                     ValueError: If neither y_true nor y_score is provided
+                 """
+                 # This is an ugly way to handle the case when the predictions are
+                 # in the form of a tuple, as produced by Hugging Face trainers.
+                 if y_true is not None and y_score is None:
+                     if hasattr(y_true, "predictions"):
+                         y_score = y_true.predictions
+                     if hasattr(y_true, "label_ids"):
+                         y_true = y_true.label_ids
+                     if hasattr(y_true, "labels"):
+                         y_true = y_true.labels
+                     if len(y_score[0][1]) == np.max(y_true) + 1:
+                         y_score = y_score[0]
+                     else:
+                         y_score = y_score[1]
+                     y_score = np.argmax(y_score, axis=1)
+                 elif y_true is not None and y_score is not None:
+                     pass  # y_true and y_score are provided
+                 else:
+                     raise ValueError(
+                         "Please provide the true and predicted values or a dictionary "
+                         "with 'y_true' and 'y_score'."
+                     )
+
+                 y_true, y_score = Metric.flatten(y_true, y_score)
+                 y_true_mask_idx = np.where(y_true != self.ignore_y)
+                 if self.ignore_y is not None:
+                     y_true = y_true[y_true_mask_idx]
+                     try:
+                         y_score = y_score[y_true_mask_idx]
+                     except Exception as e:
+                         warnings.warn(str(e))
+                 kwargs.update(self.kwargs)
+
+                 return {name: self.compute(y_true, y_score, *args, **kwargs)}
+
+             return wrapper
+         else:
+             return super().__getattribute__(name)
+
+     def compute(self, y_true, y_score, *args, **kwargs):
+         """
+         Compute the metric based on the true and predicted values.
+
+         Args:
+             y_true: The true values
+             y_score: The predicted values
+             *args: Additional positional arguments for the metric
+             **kwargs: Additional keyword arguments for the metric
+
+         Returns:
+             The computed metric value
+
+         Raises:
+             NotImplementedError: If no metric function is provided and compute
+                 is not implemented
+         """
+         if self.metric_func is not None:
+             kwargs.update(self.kwargs)
+             return self.metric_func(y_true, y_score, *args, **kwargs)
+
+         else:
+             raise NotImplementedError(
+                 "Method compute() is not implemented in the child class."
+             )
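Finally, an editorial sketch of the two entry points defined in this file (values computed by hand from the definitions above): mcrmse masks -100 entries and returns the RMSE over the remainder, and Metric exposes any sklearn.metrics function through its mask-aware wrapper.

import numpy as np

from omnigenome.src.metric.metric import Metric, mcrmse

y_true = np.array([[0.5, -100.0], [1.0, 2.0]])  # -100 entries are masked out
y_pred = np.array([[0.4, 9.9], [1.1, 2.2]])
print(mcrmse(y_true, y_pred))  # ~0.1414, the RMSE over the three unmasked entries

metric = Metric(ignore_y=-100)
result = metric.mean_squared_error(y_true.flatten(), y_pred.flatten())
print(result)  # {'mean_squared_error': 0.02}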