omnigenome 0.3.0a0__py3-none-any.whl → 0.3.1a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- omnigenome/__init__.py +29 -44
- omnigenome/auto/auto_bench/__init__.py +0 -1
- omnigenome/auto/auto_bench/auto_bench.py +24 -14
- omnigenome/auto/auto_train/__init__.py +0 -1
- omnigenome/auto/auto_train/auto_train.py +11 -12
- omnigenome/auto/bench_hub/__init__.py +0 -1
- omnigenome/auto/bench_hub/bench_hub.py +1 -1
- omnigenome/cli/__init__.py +0 -1
- omnigenome/cli/commands/__init__.py +0 -1
- omnigenome/cli/commands/base.py +10 -10
- omnigenome/cli/commands/bench/__init__.py +0 -1
- omnigenome/cli/commands/bench/bench_cli.py +10 -10
- omnigenome/cli/commands/rna/__init__.py +0 -1
- omnigenome/cli/commands/rna/rna_design.py +10 -11
- omnigenome/src/__init__.py +0 -1
- omnigenome/src/abc/__init__.py +0 -1
- omnigenome/src/abc/abstract_dataset.py +38 -19
- omnigenome/src/abc/abstract_metric.py +7 -7
- omnigenome/src/abc/abstract_model.py +15 -14
- omnigenome/src/abc/abstract_tokenizer.py +9 -7
- omnigenome/src/dataset/omni_dataset.py +16 -14
- omnigenome/src/lora/__init__.py +0 -1
- omnigenome/src/lora/lora_model.py +47 -41
- omnigenome/src/metric/classification_metric.py +11 -11
- omnigenome/src/metric/metric.py +19 -19
- omnigenome/src/metric/ranking_metric.py +15 -15
- omnigenome/src/metric/regression_metric.py +18 -18
- omnigenome/src/misc/utils.py +214 -150
- omnigenome/src/model/augmentation/__init__.py +0 -1
- omnigenome/src/model/augmentation/model.py +17 -17
- omnigenome/src/model/classification/__init__.py +0 -1
- omnigenome/src/model/classification/model.py +28 -32
- omnigenome/src/model/embedding/__init__.py +0 -1
- omnigenome/src/model/embedding/model.py +35 -35
- omnigenome/src/model/mlm/__init__.py +0 -1
- omnigenome/src/model/mlm/model.py +13 -13
- omnigenome/src/model/module_utils.py +17 -17
- omnigenome/src/model/regression/__init__.py +0 -1
- omnigenome/src/model/regression/model.py +72 -77
- omnigenome/src/model/regression/resnet.py +32 -32
- omnigenome/src/model/rna_design/__init__.py +0 -1
- omnigenome/src/model/rna_design/model.py +168 -118
- omnigenome/src/model/seq2seq/__init__.py +0 -1
- omnigenome/src/model/seq2seq/model.py +4 -4
- omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
- omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
- omnigenome/src/trainer/accelerate_trainer.py +40 -32
- omnigenome/src/trainer/hf_trainer.py +8 -8
- omnigenome/src/trainer/trainer.py +37 -25
- omnigenome/utility/dataset_hub/__init__.py +0 -1
- omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
- omnigenome/utility/ensemble.py +26 -26
- omnigenome/utility/hub_utils.py +8 -8
- omnigenome/utility/model_hub/__init__.py +0 -1
- omnigenome/utility/model_hub/model_hub.py +26 -25
- omnigenome/utility/pipeline_hub/__init__.py +0 -1
- omnigenome/utility/pipeline_hub/pipeline.py +49 -49
- omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +3 -3
- omnigenome-0.3.1a0.dist-info/RECORD +78 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -1
- omnigenome-0.3.0a0.dist-info/RECORD +0 -85
- tests/__init__.py +0 -9
- tests/conftest.py +0 -160
- tests/test_dataset_patterns.py +0 -291
- tests/test_examples_syntax.py +0 -83
- tests/test_model_loading.py +0 -183
- tests/test_rna_functions.py +0 -255
- tests/test_training_patterns.py +0 -302
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
|
@@ -18,22 +18,23 @@ import torch
|
|
|
18
18
|
from torch import nn
|
|
19
19
|
from omnigenome.src.misc.utils import fprint
|
|
20
20
|
|
|
21
|
+
|
|
21
22
|
def find_linear_target_modules(model, keyword_filter=None, use_full_path=True):
|
|
22
23
|
"""
|
|
23
24
|
Find linear modules in a model that can be targeted for LoRA adaptation.
|
|
24
|
-
|
|
25
|
+
|
|
25
26
|
This function searches through a model's modules to identify linear layers
|
|
26
27
|
that can be adapted using LoRA. It supports filtering by keyword patterns
|
|
27
28
|
to target specific types of layers.
|
|
28
|
-
|
|
29
|
+
|
|
29
30
|
Args:
|
|
30
31
|
model: The model to search for linear modules
|
|
31
32
|
keyword_filter (str, list, tuple, optional): Keywords to filter modules by name
|
|
32
33
|
use_full_path (bool): Whether to return full module paths or just names (default: True)
|
|
33
|
-
|
|
34
|
+
|
|
34
35
|
Returns:
|
|
35
36
|
list: Sorted list of linear module names that can be targeted for LoRA
|
|
36
|
-
|
|
37
|
+
|
|
37
38
|
Raises:
|
|
38
39
|
TypeError: If keyword_filter is not None, str, or a list/tuple of str
|
|
39
40
|
"""
|
|
@@ -46,31 +47,32 @@ def find_linear_target_modules(model, keyword_filter=None, use_full_path=True):
|
|
|
46
47
|
elif not isinstance(keyword_filter, (list, tuple)):
|
|
47
48
|
raise TypeError("keyword_filter must be None, str, or a list/tuple of str")
|
|
48
49
|
|
|
49
|
-
pattern =
|
|
50
|
+
pattern = "|".join(map(re.escape, keyword_filter))
|
|
50
51
|
|
|
51
52
|
linear_modules = set()
|
|
52
53
|
for name, module in model.named_modules():
|
|
53
54
|
if isinstance(module, nn.Linear):
|
|
54
55
|
if keyword_filter is None or re.search(pattern, name, re.IGNORECASE):
|
|
55
|
-
linear_modules.add(name if use_full_path else name.split(
|
|
56
|
+
linear_modules.add(name if use_full_path else name.split(".")[-1])
|
|
56
57
|
|
|
57
58
|
return sorted(linear_modules)
|
|
58
59
|
|
|
60
|
+
|
|
59
61
|
def auto_lora_model(model, **kwargs):
|
|
60
62
|
"""
|
|
61
63
|
Automatically create a LoRA-adapted model.
|
|
62
|
-
|
|
64
|
+
|
|
63
65
|
This function automatically identifies suitable target modules and creates
|
|
64
66
|
a LoRA-adapted version of the input model. It handles configuration
|
|
65
67
|
setup and parameter freezing for efficient fine-tuning.
|
|
66
|
-
|
|
68
|
+
|
|
67
69
|
Args:
|
|
68
70
|
model: The base model to adapt with LoRA
|
|
69
71
|
**kwargs: Additional LoRA configuration parameters
|
|
70
|
-
|
|
72
|
+
|
|
71
73
|
Returns:
|
|
72
74
|
The LoRA-adapted model
|
|
73
|
-
|
|
75
|
+
|
|
74
76
|
Raises:
|
|
75
77
|
AssertionError: If no target modules are found for LoRA injection
|
|
76
78
|
"""
|
|
@@ -79,8 +81,8 @@ def auto_lora_model(model, **kwargs):
|
|
|
79
81
|
|
|
80
82
|
# A bad case for the EVO-1 model, which has a custom config class
|
|
81
83
|
######################
|
|
82
|
-
if hasattr(model,
|
|
83
|
-
delattr(model.config,
|
|
84
|
+
if hasattr(model, "config") and not isinstance(model.config, PretrainedConfig):
|
|
85
|
+
delattr(model.config, "Loader")
|
|
84
86
|
model.config = PretrainedConfig.from_dict(dict(model.config))
|
|
85
87
|
#######################
|
|
86
88
|
|
|
@@ -92,7 +94,9 @@ def auto_lora_model(model, **kwargs):
|
|
|
92
94
|
lora_dropout = kwargs.pop("lora_dropout", 0.1)
|
|
93
95
|
|
|
94
96
|
if target_modules is None:
|
|
95
|
-
target_modules = find_linear_target_modules(
|
|
97
|
+
target_modules = find_linear_target_modules(
|
|
98
|
+
model, keyword_filter=kwargs.get("keyword_filter", None)
|
|
99
|
+
)
|
|
96
100
|
assert target_modules is not None, "No target modules found for LoRA injection."
|
|
97
101
|
config = LoraConfig(
|
|
98
102
|
target_modules=target_modules,
|
|
@@ -115,29 +119,30 @@ def auto_lora_model(model, **kwargs):
|
|
|
115
119
|
)
|
|
116
120
|
return lora_model
|
|
117
121
|
|
|
122
|
+
|
|
118
123
|
class OmniLoraModel(nn.Module):
|
|
119
124
|
"""
|
|
120
125
|
LoRA-adapted model for OmniGenome.
|
|
121
|
-
|
|
126
|
+
|
|
122
127
|
This class provides a wrapper around LoRA-adapted models, enabling
|
|
123
128
|
efficient fine-tuning of large genomic language models while maintaining
|
|
124
129
|
compatibility with the OmniGenome framework.
|
|
125
|
-
|
|
130
|
+
|
|
126
131
|
Attributes:
|
|
127
132
|
lora_model: The underlying LoRA-adapted model
|
|
128
133
|
config: Model configuration
|
|
129
134
|
device: Device the model is running on
|
|
130
135
|
dtype: Data type of the model parameters
|
|
131
136
|
"""
|
|
132
|
-
|
|
137
|
+
|
|
133
138
|
def __init__(self, model, **kwargs):
|
|
134
139
|
"""
|
|
135
140
|
Initialize the LoRA-adapted model.
|
|
136
|
-
|
|
141
|
+
|
|
137
142
|
Args:
|
|
138
143
|
model: The base model to adapt with LoRA
|
|
139
144
|
**kwargs: LoRA configuration parameters
|
|
140
|
-
|
|
145
|
+
|
|
141
146
|
Raises:
|
|
142
147
|
ValueError: If no target modules are specified for LoRA injection
|
|
143
148
|
"""
|
|
@@ -147,7 +152,8 @@ class OmniLoraModel(nn.Module):
|
|
|
147
152
|
raise ValueError(
|
|
148
153
|
"No target modules found for LoRA injection. To perform LoRA adaptation fine-tuning, "
|
|
149
154
|
"please specify the target modules using the 'target_modules' argument. "
|
|
150
|
-
"The target modules depend on the model architecture, such as 'query', 'value', etc. "
|
|
155
|
+
"The target modules depend on the model architecture, such as 'query', 'value', etc. "
|
|
156
|
+
)
|
|
151
157
|
|
|
152
158
|
self.lora_model = auto_lora_model(model, **kwargs)
|
|
153
159
|
|
|
@@ -159,23 +165,23 @@ class OmniLoraModel(nn.Module):
|
|
|
159
165
|
)
|
|
160
166
|
|
|
161
167
|
self.config = model.config
|
|
162
|
-
self.to(
|
|
168
|
+
self.to("cpu") # Move the model to CPU initially
|
|
163
169
|
fprint(
|
|
164
170
|
"LoRA model initialized with the following configuration:\n",
|
|
165
|
-
self.lora_model
|
|
171
|
+
self.lora_model,
|
|
166
172
|
)
|
|
167
173
|
|
|
168
174
|
def to(self, *args, **kwargs):
|
|
169
175
|
"""
|
|
170
176
|
Move the model to a specific device and data type.
|
|
171
|
-
|
|
177
|
+
|
|
172
178
|
This method overrides the default to() method to ensure the LoRA model
|
|
173
179
|
and its components are properly moved to the target device and dtype.
|
|
174
|
-
|
|
180
|
+
|
|
175
181
|
Args:
|
|
176
182
|
*args: Device specification (e.g., 'cuda', 'cpu')
|
|
177
183
|
**kwargs: Additional arguments including dtype
|
|
178
|
-
|
|
184
|
+
|
|
179
185
|
Returns:
|
|
180
186
|
self: The model instance
|
|
181
187
|
"""
|
|
@@ -188,20 +194,20 @@ class OmniLoraModel(nn.Module):
|
|
|
188
194
|
break
|
|
189
195
|
for module in self.lora_model.modules():
|
|
190
196
|
module.device = self.device
|
|
191
|
-
if hasattr(module,
|
|
197
|
+
if hasattr(module, "dtype"):
|
|
192
198
|
module.dtype = self.dtype
|
|
193
199
|
except Exception as e:
|
|
194
|
-
pass
|
|
200
|
+
pass # Ignore errors if parameters are not available
|
|
195
201
|
return self
|
|
196
202
|
|
|
197
203
|
def forward(self, *args, **kwargs):
|
|
198
204
|
"""
|
|
199
205
|
Forward pass through the LoRA model.
|
|
200
|
-
|
|
206
|
+
|
|
201
207
|
Args:
|
|
202
208
|
*args: Positional arguments for the forward pass
|
|
203
209
|
**kwargs: Keyword arguments for the forward pass
|
|
204
|
-
|
|
210
|
+
|
|
205
211
|
Returns:
|
|
206
212
|
The output from the LoRA model
|
|
207
213
|
"""
|
|
@@ -210,11 +216,11 @@ class OmniLoraModel(nn.Module):
|
|
|
210
216
|
def predict(self, *args, **kwargs):
|
|
211
217
|
"""
|
|
212
218
|
Generate predictions using the LoRA model.
|
|
213
|
-
|
|
219
|
+
|
|
214
220
|
Args:
|
|
215
221
|
*args: Positional arguments for prediction
|
|
216
222
|
**kwargs: Keyword arguments for prediction
|
|
217
|
-
|
|
223
|
+
|
|
218
224
|
Returns:
|
|
219
225
|
Model predictions
|
|
220
226
|
"""
|
|
@@ -223,11 +229,11 @@ class OmniLoraModel(nn.Module):
|
|
|
223
229
|
def save(self, *args, **kwargs):
|
|
224
230
|
"""
|
|
225
231
|
Save the LoRA model.
|
|
226
|
-
|
|
232
|
+
|
|
227
233
|
Args:
|
|
228
234
|
*args: Positional arguments for saving
|
|
229
235
|
**kwargs: Keyword arguments for saving
|
|
230
|
-
|
|
236
|
+
|
|
231
237
|
Returns:
|
|
232
238
|
Result of the save operation
|
|
233
239
|
"""
|
|
@@ -236,7 +242,7 @@ class OmniLoraModel(nn.Module):
|
|
|
236
242
|
def model_info(self):
|
|
237
243
|
"""
|
|
238
244
|
Get information about the LoRA model.
|
|
239
|
-
|
|
245
|
+
|
|
240
246
|
Returns:
|
|
241
247
|
Model information from the base model
|
|
242
248
|
"""
|
|
@@ -245,10 +251,10 @@ class OmniLoraModel(nn.Module):
|
|
|
245
251
|
def set_loss_fn(self, fn):
|
|
246
252
|
"""
|
|
247
253
|
Set the loss function for the LoRA model.
|
|
248
|
-
|
|
254
|
+
|
|
249
255
|
Args:
|
|
250
256
|
fn: Loss function to set
|
|
251
|
-
|
|
257
|
+
|
|
252
258
|
Returns:
|
|
253
259
|
Result of setting the loss function
|
|
254
260
|
"""
|
|
@@ -257,10 +263,10 @@ class OmniLoraModel(nn.Module):
|
|
|
257
263
|
def last_hidden_state_forward(self, **kwargs):
|
|
258
264
|
"""
|
|
259
265
|
Forward pass to get the last hidden state.
|
|
260
|
-
|
|
266
|
+
|
|
261
267
|
Args:
|
|
262
268
|
**kwargs: Keyword arguments for the forward pass
|
|
263
|
-
|
|
269
|
+
|
|
264
270
|
Returns:
|
|
265
271
|
Last hidden state from the base model
|
|
266
272
|
"""
|
|
@@ -269,7 +275,7 @@ class OmniLoraModel(nn.Module):
|
|
|
269
275
|
def tokenizer(self):
|
|
270
276
|
"""
|
|
271
277
|
Get the tokenizer from the base model.
|
|
272
|
-
|
|
278
|
+
|
|
273
279
|
Returns:
|
|
274
280
|
The tokenizer from the base model
|
|
275
281
|
"""
|
|
@@ -278,7 +284,7 @@ class OmniLoraModel(nn.Module):
|
|
|
278
284
|
def config(self):
|
|
279
285
|
"""
|
|
280
286
|
Get the configuration from the base model.
|
|
281
|
-
|
|
287
|
+
|
|
282
288
|
Returns:
|
|
283
289
|
The configuration from the base model
|
|
284
290
|
"""
|
|
@@ -287,8 +293,8 @@ class OmniLoraModel(nn.Module):
|
|
|
287
293
|
def model(self):
|
|
288
294
|
"""
|
|
289
295
|
Get the base model.
|
|
290
|
-
|
|
296
|
+
|
|
291
297
|
Returns:
|
|
292
298
|
The base model
|
|
293
299
|
"""
|
|
294
|
-
return self.lora_model.base_model.model
|
|
300
|
+
return self.lora_model.base_model.model
|
|
@@ -19,17 +19,17 @@ from ..abc.abstract_metric import OmniMetric
|
|
|
19
19
|
class ClassificationMetric(OmniMetric):
|
|
20
20
|
"""
|
|
21
21
|
Classification metric class for evaluating classification models.
|
|
22
|
-
|
|
22
|
+
|
|
23
23
|
This class provides a comprehensive interface for classification metrics
|
|
24
24
|
in the OmniGenome framework. It integrates with scikit-learn's classification
|
|
25
25
|
metrics and provides additional functionality for handling genomic classification
|
|
26
26
|
tasks.
|
|
27
|
-
|
|
27
|
+
|
|
28
28
|
The class automatically exposes all scikit-learn classification metrics as
|
|
29
29
|
callable attributes, making them easily accessible for evaluation. It also
|
|
30
30
|
handles special cases like Hugging Face's EvalPrediction objects and
|
|
31
31
|
provides proper handling of ignored labels.
|
|
32
|
-
|
|
32
|
+
|
|
33
33
|
Attributes:
|
|
34
34
|
metric_func (callable): A callable metric function from sklearn.metrics.
|
|
35
35
|
ignore_y (any): A value in the ground truth labels to be ignored during
|
|
@@ -42,10 +42,10 @@ class ClassificationMetric(OmniMetric):
|
|
|
42
42
|
Initializes the classification metric.
|
|
43
43
|
|
|
44
44
|
Args:
|
|
45
|
-
metric_func (callable, optional): A callable metric function from
|
|
45
|
+
metric_func (callable, optional): A callable metric function from
|
|
46
46
|
sklearn.metrics. If None, subclasses
|
|
47
47
|
should implement their own compute method.
|
|
48
|
-
ignore_y (any, optional): A value in the ground truth labels to be
|
|
48
|
+
ignore_y (any, optional): A value in the ground truth labels to be
|
|
49
49
|
ignored during metric computation. Defaults to -100.
|
|
50
50
|
*args: Additional positional arguments.
|
|
51
51
|
**kwargs: Additional keyword arguments.
|
|
@@ -53,7 +53,7 @@ class ClassificationMetric(OmniMetric):
|
|
|
53
53
|
Example:
|
|
54
54
|
>>> # Initialize with a specific metric function
|
|
55
55
|
>>> metric = ClassificationMetric(metrics.accuracy_score)
|
|
56
|
-
|
|
56
|
+
|
|
57
57
|
>>> # Initialize with ignore value
|
|
58
58
|
>>> metric = ClassificationMetric(ignore_y=-100)
|
|
59
59
|
"""
|
|
@@ -64,7 +64,7 @@ class ClassificationMetric(OmniMetric):
|
|
|
64
64
|
def __getattribute__(self, name):
|
|
65
65
|
"""
|
|
66
66
|
Custom attribute getter that provides dynamic access to scikit-learn metrics.
|
|
67
|
-
|
|
67
|
+
|
|
68
68
|
This method provides transparent access to all scikit-learn classification
|
|
69
69
|
metrics. When a metric function is accessed, it returns a callable wrapper
|
|
70
70
|
that handles the metric computation with proper preprocessing.
|
|
@@ -91,7 +91,7 @@ class ClassificationMetric(OmniMetric):
|
|
|
91
91
|
def wrapper(y_true=None, y_pred=None, *args, **kwargs):
|
|
92
92
|
"""
|
|
93
93
|
Compute the metric, based on the true and predicted values.
|
|
94
|
-
|
|
94
|
+
|
|
95
95
|
This wrapper function handles various input formats including
|
|
96
96
|
Hugging Face's EvalPrediction objects and provides proper
|
|
97
97
|
preprocessing for metric computation.
|
|
@@ -99,7 +99,7 @@ class ClassificationMetric(OmniMetric):
|
|
|
99
99
|
Args:
|
|
100
100
|
y_true: The true values (ground truth labels).
|
|
101
101
|
y_pred: The predicted values (model predictions).
|
|
102
|
-
ignore_y: The value to ignore in the predictions and true
|
|
102
|
+
ignore_y: The value to ignore in the predictions and true
|
|
103
103
|
values in corresponding positions.
|
|
104
104
|
*args: Additional positional arguments for the metric function.
|
|
105
105
|
**kwargs: Additional keyword arguments for the metric function.
|
|
@@ -111,7 +111,7 @@ class ClassificationMetric(OmniMetric):
|
|
|
111
111
|
>>> # Standard usage
|
|
112
112
|
>>> result = accuracy_fn(y_true, y_pred)
|
|
113
113
|
>>> print(result) # {'accuracy_score': 0.85}
|
|
114
|
-
|
|
114
|
+
|
|
115
115
|
>>> # With Hugging Face EvalPrediction
|
|
116
116
|
>>> result = accuracy_fn(eval_prediction)
|
|
117
117
|
>>> print(result) # {'accuracy_score': 0.85}
|
|
@@ -152,7 +152,7 @@ class ClassificationMetric(OmniMetric):
|
|
|
152
152
|
def compute(self, y_true, y_pred, *args, **kwargs):
|
|
153
153
|
"""
|
|
154
154
|
Compute the metric, based on the true and predicted values.
|
|
155
|
-
|
|
155
|
+
|
|
156
156
|
This method computes the classification metric using the provided
|
|
157
157
|
metric function. It handles preprocessing and applies any additional
|
|
158
158
|
keyword arguments.
|
omnigenome/src/metric/metric.py
CHANGED
|
@@ -20,20 +20,20 @@ from ..abc.abstract_metric import OmniMetric
|
|
|
20
20
|
def mcrmse(y_true, y_pred):
|
|
21
21
|
"""
|
|
22
22
|
Compute Mean Column Root Mean Square Error (MCRMSE).
|
|
23
|
-
|
|
23
|
+
|
|
24
24
|
MCRMSE is a multi-target regression metric that computes the RMSE for each target
|
|
25
25
|
column and then takes the mean across all targets.
|
|
26
|
-
|
|
26
|
+
|
|
27
27
|
Args:
|
|
28
28
|
y_true (np.ndarray): Ground truth values with shape (n_samples, n_targets)
|
|
29
29
|
y_pred (np.ndarray): Predicted values with shape (n_samples, n_targets)
|
|
30
|
-
|
|
30
|
+
|
|
31
31
|
Returns:
|
|
32
32
|
float: Mean Column Root Mean Square Error
|
|
33
|
-
|
|
33
|
+
|
|
34
34
|
Raises:
|
|
35
35
|
ValueError: If y_true and y_pred have different shapes
|
|
36
|
-
|
|
36
|
+
|
|
37
37
|
Example:
|
|
38
38
|
>>> y_true = np.array([[1, 2], [3, 4], [5, 6]])
|
|
39
39
|
>>> y_pred = np.array([[1.1, 2.1], [2.9, 4.1], [5.2, 5.8]])
|
|
@@ -57,18 +57,18 @@ class Metric(OmniMetric):
|
|
|
57
57
|
"""
|
|
58
58
|
A flexible metric class that provides access to all scikit-learn metrics
|
|
59
59
|
and custom metrics for evaluation.
|
|
60
|
-
|
|
60
|
+
|
|
61
61
|
This class dynamically wraps scikit-learn metrics and provides a unified
|
|
62
62
|
interface for computing various evaluation metrics. It handles different
|
|
63
63
|
input formats including HuggingFace trainer outputs and supports
|
|
64
64
|
custom metric functions.
|
|
65
|
-
|
|
65
|
+
|
|
66
66
|
Attributes:
|
|
67
67
|
metric_func: Custom metric function if provided
|
|
68
68
|
ignore_y: Value to ignore in predictions and true values
|
|
69
69
|
kwargs: Additional keyword arguments for metric computation
|
|
70
70
|
metrics: Dictionary of available metrics including custom ones
|
|
71
|
-
|
|
71
|
+
|
|
72
72
|
Example:
|
|
73
73
|
>>> from omnigenome.src.metric import Metric
|
|
74
74
|
>>> metric = Metric(ignore_y=-100)
|
|
@@ -82,7 +82,7 @@ class Metric(OmniMetric):
|
|
|
82
82
|
def __init__(self, metric_func=None, ignore_y=-100, *args, **kwargs):
|
|
83
83
|
"""
|
|
84
84
|
Initialize the Metric class.
|
|
85
|
-
|
|
85
|
+
|
|
86
86
|
Args:
|
|
87
87
|
metric_func (callable, optional): Custom metric function to use
|
|
88
88
|
ignore_y (int, optional): Value to ignore in predictions and true values. Defaults to -100
|
|
@@ -98,14 +98,14 @@ class Metric(OmniMetric):
|
|
|
98
98
|
def __getattribute__(self, name):
|
|
99
99
|
"""
|
|
100
100
|
Dynamically create metric computation methods.
|
|
101
|
-
|
|
101
|
+
|
|
102
102
|
This method intercepts attribute access and creates wrapper functions
|
|
103
103
|
for scikit-learn metrics, handling different input formats and
|
|
104
104
|
preprocessing the data appropriately.
|
|
105
|
-
|
|
105
|
+
|
|
106
106
|
Args:
|
|
107
107
|
name (str): Name of the metric to access
|
|
108
|
-
|
|
108
|
+
|
|
109
109
|
Returns:
|
|
110
110
|
callable: Wrapper function for the requested metric
|
|
111
111
|
"""
|
|
@@ -119,20 +119,20 @@ class Metric(OmniMetric):
|
|
|
119
119
|
def wrapper(y_true=None, y_score=None, *args, **kwargs):
|
|
120
120
|
"""
|
|
121
121
|
Compute the metric, based on the true and predicted values.
|
|
122
|
-
|
|
122
|
+
|
|
123
123
|
This wrapper handles different input formats including HuggingFace
|
|
124
124
|
trainer outputs and performs necessary preprocessing.
|
|
125
|
-
|
|
125
|
+
|
|
126
126
|
Args:
|
|
127
127
|
y_true: The true values or HuggingFace EvalPrediction object
|
|
128
128
|
y_score: The predicted values
|
|
129
129
|
ignore_y: The value to ignore in the predictions and true values in corresponding positions
|
|
130
130
|
*args: Additional positional arguments for the metric
|
|
131
131
|
**kwargs: Additional keyword arguments for the metric
|
|
132
|
-
|
|
132
|
+
|
|
133
133
|
Returns:
|
|
134
134
|
dict: Dictionary containing the metric name and computed value
|
|
135
|
-
|
|
135
|
+
|
|
136
136
|
Raises:
|
|
137
137
|
ValueError: If neither y_true nor y_score is provided
|
|
138
138
|
"""
|
|
@@ -176,16 +176,16 @@ class Metric(OmniMetric):
|
|
|
176
176
|
def compute(self, y_true, y_score, *args, **kwargs):
|
|
177
177
|
"""
|
|
178
178
|
Compute the metric, based on the true and predicted values.
|
|
179
|
-
|
|
179
|
+
|
|
180
180
|
Args:
|
|
181
181
|
y_true: The true values
|
|
182
182
|
y_score: The predicted values
|
|
183
183
|
*args: Additional positional arguments for the metric
|
|
184
184
|
**kwargs: Additional keyword arguments for the metric
|
|
185
|
-
|
|
185
|
+
|
|
186
186
|
Returns:
|
|
187
187
|
The computed metric value
|
|
188
|
-
|
|
188
|
+
|
|
189
189
|
Raises:
|
|
190
190
|
NotImplementedError: If no metric function is provided and compute is not implemented
|
|
191
191
|
"""
|
|
@@ -20,16 +20,16 @@ from ..abc.abstract_metric import OmniMetric
|
|
|
20
20
|
class RankingMetric(OmniMetric):
|
|
21
21
|
"""
|
|
22
22
|
A specialized metric class for ranking tasks and evaluation.
|
|
23
|
-
|
|
23
|
+
|
|
24
24
|
This class provides access to ranking-specific metrics from scikit-learn
|
|
25
25
|
and handles different input formats including HuggingFace trainer outputs.
|
|
26
26
|
It dynamically wraps scikit-learn metrics and provides a unified interface
|
|
27
27
|
for computing various ranking evaluation metrics.
|
|
28
|
-
|
|
28
|
+
|
|
29
29
|
Attributes:
|
|
30
30
|
metric_func: Custom metric function if provided
|
|
31
31
|
ignore_y: Value to ignore in predictions and true values
|
|
32
|
-
|
|
32
|
+
|
|
33
33
|
Example:
|
|
34
34
|
>>> from omnigenome.src.metric import RankingMetric
|
|
35
35
|
>>> metric = RankingMetric(ignore_y=-100)
|
|
@@ -43,7 +43,7 @@ class RankingMetric(OmniMetric):
|
|
|
43
43
|
def __init__(self, *args, **kwargs):
|
|
44
44
|
"""
|
|
45
45
|
Initialize the RankingMetric class.
|
|
46
|
-
|
|
46
|
+
|
|
47
47
|
Args:
|
|
48
48
|
*args: Additional positional arguments passed to parent class
|
|
49
49
|
**kwargs: Additional keyword arguments passed to parent class
|
|
@@ -53,17 +53,17 @@ class RankingMetric(OmniMetric):
|
|
|
53
53
|
def __getattr__(self, name):
|
|
54
54
|
"""
|
|
55
55
|
Dynamically create ranking metric computation methods.
|
|
56
|
-
|
|
56
|
+
|
|
57
57
|
This method intercepts attribute access and creates wrapper functions
|
|
58
58
|
for scikit-learn ranking metrics, handling different input formats and
|
|
59
59
|
preprocessing the data appropriately.
|
|
60
|
-
|
|
60
|
+
|
|
61
61
|
Args:
|
|
62
62
|
name (str): Name of the ranking metric to access
|
|
63
|
-
|
|
63
|
+
|
|
64
64
|
Returns:
|
|
65
65
|
callable: Wrapper function for the requested ranking metric
|
|
66
|
-
|
|
66
|
+
|
|
67
67
|
Raises:
|
|
68
68
|
AttributeError: If the requested metric is not found
|
|
69
69
|
"""
|
|
@@ -74,17 +74,17 @@ class RankingMetric(OmniMetric):
|
|
|
74
74
|
def wrapper(y_true=None, y_score=None, *args, **kwargs):
|
|
75
75
|
"""
|
|
76
76
|
Compute the ranking metric, based on the true and predicted values.
|
|
77
|
-
|
|
77
|
+
|
|
78
78
|
This wrapper handles different input formats including HuggingFace
|
|
79
79
|
trainer outputs and performs necessary preprocessing for ranking tasks.
|
|
80
|
-
|
|
80
|
+
|
|
81
81
|
Args:
|
|
82
82
|
y_true: The true values or HuggingFace EvalPrediction object
|
|
83
83
|
y_score: The predicted values (scores for ranking)
|
|
84
84
|
ignore_y: The value to ignore in the predictions and true values in corresponding positions
|
|
85
85
|
*args: Additional positional arguments for the metric
|
|
86
86
|
**kwargs: Additional keyword arguments for the metric
|
|
87
|
-
|
|
87
|
+
|
|
88
88
|
Returns:
|
|
89
89
|
dict: Dictionary containing the metric name and computed value
|
|
90
90
|
"""
|
|
@@ -121,19 +121,19 @@ class RankingMetric(OmniMetric):
|
|
|
121
121
|
def compute(self, y_true, y_score, *args, **kwargs):
|
|
122
122
|
"""
|
|
123
123
|
Compute the ranking metric, based on the true and predicted values.
|
|
124
|
-
|
|
124
|
+
|
|
125
125
|
This method should be implemented by subclasses to provide specific
|
|
126
126
|
ranking metric computation logic.
|
|
127
|
-
|
|
127
|
+
|
|
128
128
|
Args:
|
|
129
129
|
y_true: The true values
|
|
130
130
|
y_score: The predicted values (scores for ranking)
|
|
131
131
|
*args: Additional positional arguments for the metric
|
|
132
132
|
**kwargs: Additional keyword arguments for the metric
|
|
133
|
-
|
|
133
|
+
|
|
134
134
|
Returns:
|
|
135
135
|
The computed ranking metric value
|
|
136
|
-
|
|
136
|
+
|
|
137
137
|
Raises:
|
|
138
138
|
NotImplementedError: If compute method is not implemented in the child class
|
|
139
139
|
"""
|