omnigenome 0.3.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of omnigenome might be problematic. Click here for more details.
- omnigenome/__init__.py +281 -0
- omnigenome/auto/__init__.py +3 -0
- omnigenome/auto/auto_bench/__init__.py +12 -0
- omnigenome/auto/auto_bench/auto_bench.py +484 -0
- omnigenome/auto/auto_bench/auto_bench_cli.py +230 -0
- omnigenome/auto/auto_bench/auto_bench_config.py +216 -0
- omnigenome/auto/auto_bench/config_check.py +34 -0
- omnigenome/auto/auto_train/__init__.py +13 -0
- omnigenome/auto/auto_train/auto_train.py +430 -0
- omnigenome/auto/auto_train/auto_train_cli.py +222 -0
- omnigenome/auto/bench_hub/__init__.py +12 -0
- omnigenome/auto/bench_hub/bench_hub.py +25 -0
- omnigenome/cli/__init__.py +13 -0
- omnigenome/cli/commands/__init__.py +13 -0
- omnigenome/cli/commands/base.py +83 -0
- omnigenome/cli/commands/bench/__init__.py +13 -0
- omnigenome/cli/commands/bench/bench_cli.py +202 -0
- omnigenome/cli/commands/rna/__init__.py +13 -0
- omnigenome/cli/commands/rna/rna_design.py +178 -0
- omnigenome/cli/omnigenome_cli.py +128 -0
- omnigenome/src/__init__.py +12 -0
- omnigenome/src/abc/__init__.py +12 -0
- omnigenome/src/abc/abstract_dataset.py +622 -0
- omnigenome/src/abc/abstract_metric.py +114 -0
- omnigenome/src/abc/abstract_model.py +689 -0
- omnigenome/src/abc/abstract_tokenizer.py +267 -0
- omnigenome/src/dataset/__init__.py +16 -0
- omnigenome/src/dataset/omni_dataset.py +435 -0
- omnigenome/src/lora/__init__.py +13 -0
- omnigenome/src/lora/lora_model.py +294 -0
- omnigenome/src/metric/__init__.py +15 -0
- omnigenome/src/metric/classification_metric.py +184 -0
- omnigenome/src/metric/metric.py +199 -0
- omnigenome/src/metric/ranking_metric.py +142 -0
- omnigenome/src/metric/regression_metric.py +191 -0
- omnigenome/src/misc/__init__.py +3 -0
- omnigenome/src/misc/utils.py +439 -0
- omnigenome/src/model/__init__.py +19 -0
- omnigenome/src/model/augmentation/__init__.py +12 -0
- omnigenome/src/model/augmentation/model.py +219 -0
- omnigenome/src/model/classification/__init__.py +12 -0
- omnigenome/src/model/classification/model.py +642 -0
- omnigenome/src/model/embedding/__init__.py +12 -0
- omnigenome/src/model/embedding/model.py +263 -0
- omnigenome/src/model/mlm/__init__.py +12 -0
- omnigenome/src/model/mlm/model.py +177 -0
- omnigenome/src/model/module_utils.py +232 -0
- omnigenome/src/model/regression/__init__.py +12 -0
- omnigenome/src/model/regression/model.py +786 -0
- omnigenome/src/model/regression/resnet.py +483 -0
- omnigenome/src/model/rna_design/__init__.py +12 -0
- omnigenome/src/model/rna_design/model.py +426 -0
- omnigenome/src/model/seq2seq/__init__.py +12 -0
- omnigenome/src/model/seq2seq/model.py +44 -0
- omnigenome/src/tokenizer/__init__.py +16 -0
- omnigenome/src/tokenizer/bpe_tokenizer.py +226 -0
- omnigenome/src/tokenizer/kmers_tokenizer.py +247 -0
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +249 -0
- omnigenome/src/trainer/__init__.py +14 -0
- omnigenome/src/trainer/accelerate_trainer.py +739 -0
- omnigenome/src/trainer/hf_trainer.py +75 -0
- omnigenome/src/trainer/trainer.py +579 -0
- omnigenome/utility/__init__.py +3 -0
- omnigenome/utility/dataset_hub/__init__.py +13 -0
- omnigenome/utility/dataset_hub/dataset_hub.py +178 -0
- omnigenome/utility/ensemble.py +324 -0
- omnigenome/utility/hub_utils.py +517 -0
- omnigenome/utility/model_hub/__init__.py +12 -0
- omnigenome/utility/model_hub/model_hub.py +231 -0
- omnigenome/utility/pipeline_hub/__init__.py +12 -0
- omnigenome/utility/pipeline_hub/pipeline.py +483 -0
- omnigenome/utility/pipeline_hub/pipeline_hub.py +129 -0
- omnigenome-0.3.0a0.dist-info/METADATA +224 -0
- omnigenome-0.3.0a0.dist-info/RECORD +85 -0
- omnigenome-0.3.0a0.dist-info/WHEEL +5 -0
- omnigenome-0.3.0a0.dist-info/entry_points.txt +3 -0
- omnigenome-0.3.0a0.dist-info/licenses/LICENSE +201 -0
- omnigenome-0.3.0a0.dist-info/top_level.txt +2 -0
- tests/__init__.py +9 -0
- tests/conftest.py +160 -0
- tests/test_dataset_patterns.py +291 -0
- tests/test_examples_syntax.py +83 -0
- tests/test_model_loading.py +183 -0
- tests/test_rna_functions.py +255 -0
- tests/test_training_patterns.py +302 -0
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# file: abstract_metric.py
|
|
3
|
+
# time: 12:58 09/04/2024
|
|
4
|
+
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
|
|
5
|
+
# github: https://github.com/yangheng95
|
|
6
|
+
# huggingface: https://huggingface.co/yangheng
|
|
7
|
+
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
|
|
8
|
+
# Copyright (C) 2019-2024. All Rights Reserved.
|
|
9
|
+
import numpy as np
|
|
10
|
+
import sklearn.metrics as metrics
|
|
11
|
+
|
|
12
|
+
from ..misc.utils import env_meta_info
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class OmniMetric:
    """
    Base class for all metrics in OmniGenome, built on top of scikit-learn.

    The class offers a single entry point (`compute`) that concrete metric
    classes override, plus a small helper (`flatten`) for turning nested
    label containers into 1D numpy arrays.

    On construction, every public and private (single-underscore) name found
    in the `sklearn.metrics` module namespace is copied onto the instance, so
    that e.g. `metric.accuracy_score` resolves to the scikit-learn function.
    Module dunder metadata (`__doc__`, `__name__`, `__file__`, ...) is not
    copied, because it is not a metric and would shadow the instance's own
    special attributes.

    Attributes:
        metric_func (callable): The metric function passed to the constructor
            (typically one from `sklearn.metrics`), or None.
        ignore_y (any): A ground-truth label value that subclasses may exclude
            from metric computation. The base class only stores it.
        metadata: Whatever `env_meta_info()` returns at construction time
            (presumably environment/version information; see that helper).
    """

    def __init__(self, metric_func=None, ignore_y=None, *args, **kwargs):
        """
        Initializes the metric.

        Args:
            metric_func (callable, optional): A callable metric function from
                `sklearn.metrics`. If None, subclasses should implement their
                own compute method.
            ignore_y (any, optional): A value in the ground truth labels to be
                ignored during metric computation.
            *args: Additional positional arguments. Accepted for signature
                compatibility; not used by the base class.
            **kwargs: Additional keyword arguments. Accepted for signature
                compatibility; not used by the base class.

        Example:
            >>> # Initialize with a specific metric function
            >>> metric = OmniMetric(metrics.accuracy_score)

            >>> # Initialize with ignore value
            >>> metric = OmniMetric(ignore_y=-100)
        """
        self.metric_func = metric_func
        self.ignore_y = ignore_y

        # Expose the contents of sklearn.metrics as instance attributes.
        # Dunder entries are module metadata, not metrics; copying them would
        # shadow this instance's own __doc__ and similar special attributes.
        for name, value in vars(metrics).items():
            if name.startswith("__") and name.endswith("__"):
                continue
            setattr(self, name, value)

        self.metadata = env_meta_info()

    def compute(self, y_true, y_pred) -> dict:
        """
        Computes the metric. This method must be implemented by subclasses.

        Concrete metric classes override this method to define how the metric
        is calculated for their specific evaluation task.

        Args:
            y_true: Ground truth labels.
            y_pred: Predicted labels.

        Returns:
            dict: A dictionary with the metric name as key and its value.

        Raises:
            NotImplementedError: Always, in the base class.

        Example:
            >>> # In a classification metric
            >>> result = metric.compute(y_true, y_pred)
            >>> print(result)  # {'accuracy': 0.85}
        """
        # Adjacent literals are concatenated verbatim, so each fragment must
        # carry its own trailing space.
        raise NotImplementedError(
            "Method compute() is not implemented in the child class. "
            "This function returns a dict containing the metric name and value. "
            "e.g. {'accuracy': 0.9}"
        )

    @staticmethod
    def flatten(y_true, y_pred):
        """
        Flattens the ground truth and prediction arrays.

        Both inputs are converted to numpy arrays and collapsed to one
        dimension. The returned arrays are always fresh copies, so callers
        may modify them without affecting the inputs.

        Args:
            y_true: Ground truth labels in any format that can be converted
                to a numpy array.
            y_pred: Predicted labels in any format that can be converted
                to a numpy array.

        Returns:
            tuple: A tuple of flattened `y_true` and `y_pred` as numpy arrays.

        Example:
            >>> y_true = [[1, 2], [3, 4]]
            >>> y_pred = [[1, 2], [3, 4]]
            >>> flat_true, flat_pred = OmniMetric.flatten(y_true, y_pred)
            >>> print(flat_true.shape)  # (4,)
        """
        # asarray avoids an extra copy for ndarray inputs; flatten() itself
        # always returns a copy, so the result never aliases the input.
        y_true = np.asarray(y_true).flatten()
        y_pred = np.asarray(y_pred).flatten()
        return y_true, y_pred
|