omnigenome 0.3.1a0__py3-none-any.whl → 1.0.0b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- omnigenome/__init__.py +26 -266
- {omnigenome-0.3.1a0.dist-info → omnigenome-1.0.0b0.dist-info}/METADATA +8 -9
- omnigenome-1.0.0b0.dist-info/RECORD +6 -0
- omnigenome/auto/__init__.py +0 -3
- omnigenome/auto/auto_bench/__init__.py +0 -11
- omnigenome/auto/auto_bench/auto_bench.py +0 -494
- omnigenome/auto/auto_bench/auto_bench_cli.py +0 -230
- omnigenome/auto/auto_bench/auto_bench_config.py +0 -216
- omnigenome/auto/auto_bench/config_check.py +0 -34
- omnigenome/auto/auto_train/__init__.py +0 -12
- omnigenome/auto/auto_train/auto_train.py +0 -429
- omnigenome/auto/auto_train/auto_train_cli.py +0 -222
- omnigenome/auto/bench_hub/__init__.py +0 -11
- omnigenome/auto/bench_hub/bench_hub.py +0 -25
- omnigenome/cli/__init__.py +0 -12
- omnigenome/cli/commands/__init__.py +0 -12
- omnigenome/cli/commands/base.py +0 -83
- omnigenome/cli/commands/bench/__init__.py +0 -12
- omnigenome/cli/commands/bench/bench_cli.py +0 -202
- omnigenome/cli/commands/rna/__init__.py +0 -12
- omnigenome/cli/commands/rna/rna_design.py +0 -177
- omnigenome/cli/omnigenome_cli.py +0 -128
- omnigenome/src/__init__.py +0 -11
- omnigenome/src/abc/__init__.py +0 -11
- omnigenome/src/abc/abstract_dataset.py +0 -641
- omnigenome/src/abc/abstract_metric.py +0 -114
- omnigenome/src/abc/abstract_model.py +0 -690
- omnigenome/src/abc/abstract_tokenizer.py +0 -269
- omnigenome/src/dataset/__init__.py +0 -16
- omnigenome/src/dataset/omni_dataset.py +0 -437
- omnigenome/src/lora/__init__.py +0 -12
- omnigenome/src/lora/lora_model.py +0 -300
- omnigenome/src/metric/__init__.py +0 -15
- omnigenome/src/metric/classification_metric.py +0 -184
- omnigenome/src/metric/metric.py +0 -199
- omnigenome/src/metric/ranking_metric.py +0 -142
- omnigenome/src/metric/regression_metric.py +0 -191
- omnigenome/src/misc/__init__.py +0 -3
- omnigenome/src/misc/utils.py +0 -503
- omnigenome/src/model/__init__.py +0 -19
- omnigenome/src/model/augmentation/__init__.py +0 -11
- omnigenome/src/model/augmentation/model.py +0 -219
- omnigenome/src/model/classification/__init__.py +0 -11
- omnigenome/src/model/classification/model.py +0 -638
- omnigenome/src/model/embedding/__init__.py +0 -11
- omnigenome/src/model/embedding/model.py +0 -263
- omnigenome/src/model/mlm/__init__.py +0 -11
- omnigenome/src/model/mlm/model.py +0 -177
- omnigenome/src/model/module_utils.py +0 -232
- omnigenome/src/model/regression/__init__.py +0 -11
- omnigenome/src/model/regression/model.py +0 -781
- omnigenome/src/model/regression/resnet.py +0 -483
- omnigenome/src/model/rna_design/__init__.py +0 -11
- omnigenome/src/model/rna_design/model.py +0 -476
- omnigenome/src/model/seq2seq/__init__.py +0 -11
- omnigenome/src/model/seq2seq/model.py +0 -44
- omnigenome/src/tokenizer/__init__.py +0 -16
- omnigenome/src/tokenizer/bpe_tokenizer.py +0 -226
- omnigenome/src/tokenizer/kmers_tokenizer.py +0 -247
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +0 -249
- omnigenome/src/trainer/__init__.py +0 -14
- omnigenome/src/trainer/accelerate_trainer.py +0 -747
- omnigenome/src/trainer/hf_trainer.py +0 -75
- omnigenome/src/trainer/trainer.py +0 -591
- omnigenome/utility/__init__.py +0 -3
- omnigenome/utility/dataset_hub/__init__.py +0 -12
- omnigenome/utility/dataset_hub/dataset_hub.py +0 -178
- omnigenome/utility/ensemble.py +0 -324
- omnigenome/utility/hub_utils.py +0 -517
- omnigenome/utility/model_hub/__init__.py +0 -11
- omnigenome/utility/model_hub/model_hub.py +0 -232
- omnigenome/utility/pipeline_hub/__init__.py +0 -11
- omnigenome/utility/pipeline_hub/pipeline.py +0 -483
- omnigenome/utility/pipeline_hub/pipeline_hub.py +0 -129
- omnigenome-0.3.1a0.dist-info/RECORD +0 -78
- omnigenome-0.3.1a0.dist-info/entry_points.txt +0 -3
- {omnigenome-0.3.1a0.dist-info → omnigenome-1.0.0b0.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.1a0.dist-info → omnigenome-1.0.0b0.dist-info}/licenses/LICENSE +0 -0
- {omnigenome-0.3.1a0.dist-info → omnigenome-1.0.0b0.dist-info}/top_level.txt +0 -0
|
@@ -1,178 +0,0 @@
|
|
|
1
|
-
# -*- coding: utf-8 -*-
|
|
2
|
-
# File: dataset_hub.py
|
|
3
|
-
# Time: 02:22 20/06/2025
|
|
4
|
-
# Author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
|
|
5
|
-
# Website: https://yangheng95.github.io
|
|
6
|
-
# GitHub: https://github.com/yangheng95
|
|
7
|
-
# HuggingFace: https://huggingface.co/yangheng
|
|
8
|
-
# Google Scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
|
|
9
|
-
# Copyright (C) 2019-2025. All rights reserved.
|
|
10
|
-
"""
|
|
11
|
-
Dataset Hub Module
|
|
12
|
-
|
|
13
|
-
This module provides utilities for loading benchmark datasets from the OmniGenome hub.
|
|
14
|
-
It handles automatic downloading, configuration loading, and dataset initialization
|
|
15
|
-
for various genomic benchmarks.
|
|
16
|
-
"""
|
|
17
|
-
|
|
18
|
-
import os
|
|
19
|
-
import warnings
|
|
20
|
-
|
|
21
|
-
import findfile
|
|
22
|
-
from typing_extensions import Union
|
|
23
|
-
|
|
24
|
-
from ... import OmniTokenizer, download_benchmark
|
|
25
|
-
from ...src.misc.utils import load_module_from_path, fprint
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
def load_benchmark_datasets(
    benchmark: str,
    tokenizer: Union["OmniTokenizer", str] = None,
    **kwargs: dict,
):
    """
    Load the train/test/valid datasets of every task in a benchmark.

    If ``benchmark`` is not an existing local path it is resolved through
    ``download_benchmark``. The task list is read from
    ``<benchmark>/metadata.py`` (``bench_list``), and each task's settings are
    read from the ``bench_config`` object of its config module.

    Args:
        benchmark (str): Local path of the benchmark, or a name that
            ``download_benchmark`` can resolve to a local path.
        tokenizer (Union[OmniTokenizer, str], optional): Tokenizer handed to the
            dataset classes. A string is turned into a tokenizer with
            ``OmniTokenizer.from_pretrained`` while the first task is processed
            (using that task's config) and the resulting object is reused for
            the remaining tasks. Any other value is forwarded unchanged.
        **kwargs: Overrides for the task configuration. Every entry is written
            into ``bench_config`` (with a warning when the key is new). The
            overrides are applied *after* a string tokenizer has been loaded,
            so they do not influence tokenizer loading.

    Returns:
        dict: Maps each task name from ``bench_list`` to a dictionary with the
        keys ``"train"``, ``"test"`` and ``"valid"``.

    Note:
        The lookups ``seeds``, ``dataset_cls``, ``max_length``, ``label2id``,
        ``train_file``, ``test_file`` and ``valid_file`` are unguarded, so a
        task config that lacks one of them fails at that lookup.

    Example:
        >>> from omnigenome import OmniSingleNucleotideTokenizer
        >>> tokenizer = OmniSingleNucleotideTokenizer.from_pretrained("model_name")
        >>> datasets = load_benchmark_datasets("RGB", tokenizer, max_length=512)
        >>> print(f"Loaded {len(datasets)} benchmark tasks")
        >>> for task_name, task_datasets in datasets.items():
        ...     print(f"{task_name}: {len(task_datasets['train'])} train samples")
    """
    if not os.path.exists(benchmark):
        fprint(
            "Benchmark:",
            benchmark,
            "does not exist. Search online for available benchmarks.",
        )
        benchmark = download_benchmark(benchmark)

    # Import benchmark list
    bench_metadata = load_module_from_path(
        "bench_metadata", f"{benchmark}/metadata.py"
    )
    datasets = {}
    for bench in bench_metadata.bench_list:
        bench_config_path = findfile.find_file(
            benchmark, f"{benchmark}.{bench}.config".split(".")
        )
        config = load_module_from_path("config", bench_config_path)
        bench_config = config.bench_config
        fprint(f"Loaded config for {bench} from {bench_config_path}")
        fprint(bench_config)

        # Init tokenizer (only while it is still given as a string identifier)
        if isinstance(tokenizer, str):
            # trust_remote_code is passed explicitly below; forwarding it a
            # second time through ** would raise "got multiple values for
            # keyword argument" whenever the config defines it.
            tokenizer_options = {
                key: value
                for key, value in bench_config.items()
                if key != "trust_remote_code"
            }
            tokenizer = OmniTokenizer.from_pretrained(
                tokenizer,
                trust_remote_code=bench_config.get("trust_remote_code", True),
                **tokenizer_options,
            )

        # Apply caller overrides on top of the task configuration.
        for key, value in kwargs.items():
            if key in bench_config:
                fprint("Override", key, "with", value, "according to the input kwargs")
            else:
                warnings.warn(
                    f"kwarg: {key} not found in bench_config while setting {key} = {value}"
                )
            bench_config.update({key: value})

        # Only overrides that did not end up in bench_config are forwarded
        # directly to the dataset constructors.
        extra_kwargs = {
            key: value for key, value in kwargs.items() if key not in bench_config
        }

        if not isinstance(bench_config["seeds"], list):
            bench_config["seeds"] = [bench_config["seeds"]]

        dataset_cls = bench_config["dataset_cls"]

        # Arguments shared by the three splits.
        shared_args = dict(
            tokenizer=tokenizer,
            label2id=bench_config["label2id"],
            max_length=bench_config["max_length"],
            structure_in=bench_config.get("structure_in", False),
            max_examples=bench_config.get("max_examples", None),
            drop_long_seq=bench_config.get("drop_long_seq", False),
            **extra_kwargs,
        )

        # Only the training split honours the configured shuffle flag; the
        # evaluation splits are never shuffled.
        train_set = dataset_cls(
            data_source=bench_config["train_file"],
            shuffle=bench_config.get("shuffle", True),
            **shared_args,
        )
        test_set = dataset_cls(
            data_source=bench_config["test_file"],
            shuffle=False,
            **shared_args,
        )
        valid_set = dataset_cls(
            data_source=bench_config["valid_file"],
            shuffle=False,
            **shared_args,
        )

        fprint(
            f"Loaded dataset for {bench} with {len(train_set)} train samples, "
            f"{len(test_set)} test samples and {len(valid_set)} valid samples."
        )

        datasets[bench] = {
            "train": train_set,
            "test": test_set,
            "valid": valid_set,
        }

    return datasets
|
omnigenome/utility/ensemble.py
DELETED
|
@@ -1,324 +0,0 @@
|
|
|
1
|
-
# -*- coding: utf-8 -*-
|
|
2
|
-
# file: ensemble.py
|
|
3
|
-
# time: 21:39 24/04/2024
|
|
4
|
-
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
|
|
5
|
-
# github: https://github.com/yangheng95
|
|
6
|
-
# huggingface: https://huggingface.co/yangheng
|
|
7
|
-
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
|
|
8
|
-
# Copyright (C) 2019-2024. All Rights Reserved.
|
|
9
|
-
from typing import List
|
|
10
|
-
|
|
11
|
-
import numpy as np
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
class VoteEnsemblePredictor:
    """
    An ensemble predictor that combines predictions from multiple models using voting.

    Every predictor is queried through its ``inference`` method. The values
    returned under the same key are collected into one list (each value repeated
    ``weight`` times) and the list is reduced: first with the numeric
    aggregation method and, if that raises, with the string aggregation method.

    Attributes:
        checkpoints: List of predictor names (the keys of ``predictors``)
        predictors: Dictionary of initialized predictors
        weights: List of integer repetition counts, aligned with ``checkpoints``
        numeric_agg_func: Name of the numeric aggregation method in use
        str_agg: Function for aggregating string predictions
        numeric_agg_methods: Dictionary of available numeric aggregation methods
        str_agg_methods: Dictionary of available string aggregation methods

    Example:
        >>> from omnigenome.utility import VoteEnsemblePredictor
        >>> predictors = {
        ...     "model1": predictor1,
        ...     "model2": predictor2,
        ...     "model3": predictor3
        ... }
        >>> weights = {"model1": 3, "model2": 2, "model3": 1}
        >>> ensemble = VoteEnsemblePredictor(predictors, weights, numeric_agg="average")
        >>> result = ensemble.predict("ACGUAGGUAUCGUAGA")
    """

    def __init__(
        self,
        predictors: [List, dict],
        weights: [List, dict] = None,
        numeric_agg="average",
        str_agg="max_vote",
    ):
        """
        Initialize the VoteEnsemblePredictor.

        Args:
            predictors (dict): Dictionary mapping a name to an initialized
                predictor. Other container types are rejected.
            weights (dict, optional): Dictionary of integer weights. A weight is
                the number of times a predictor's output is repeated before
                aggregation, so it must be usable with ``range()``. When the
                dictionary contains every predictor name the weights are matched
                by name; otherwise they are matched by position. Defaults to a
                weight of 1 for every predictor.
            numeric_agg (str, optional): The aggregation method for numeric data. Options are 'average', 'mean', 'max', 'min',
                'median', 'mode', and 'sum'. Defaults to 'average'
            str_agg (str, optional): The aggregation method for string data. Options are 'max_vote', 'min_vote', 'vote', and 'mode'. Defaults to 'max_vote'

        Raises:
            AssertionError: If predictors and weights have different lengths or types
            AssertionError: If predictors is empty
            AssertionError: If unsupported aggregation methods are provided
            NotImplementedError: If predictors is not a dictionary
        """
        if weights is not None:
            assert len(predictors) == len(
                weights
            ), "Checkpoints and weights should have the same length"
            assert type(predictors) == type(
                weights
            ), "Checkpoints and weights should have the same type"

        assert len(predictors) > 0, "Checkpoints should not be empty"

        self.numeric_agg_methods = {
            "average": np.mean,
            "mean": np.mean,
            "max": np.max,
            "min": np.min,
            "median": np.median,
            # Must accept the axis keyword like the NumPy reducers above.
            "mode": self._numeric_mode,
            "sum": np.sum,
        }
        self.str_agg_methods = {
            "max_vote": lambda x: max(set(x), key=x.count),
            "min_vote": lambda x: min(set(x), key=x.count),
            "vote": lambda x: max(set(x), key=x.count),
            "mode": lambda x: max(set(x), key=x.count),
        }
        assert (
            numeric_agg in self.numeric_agg_methods
        ), "numeric_agg should be either: " + str(self.numeric_agg_methods.keys())
        assert (
            str_agg in self.str_agg_methods
        ), "str_agg should be either: " + str(self.str_agg_methods.keys())

        self.numeric_agg_func = numeric_agg
        self.str_agg = self.str_agg_methods[str_agg]

        if not isinstance(predictors, dict):
            raise NotImplementedError(
                "Only support dict type for checkpoints and weights"
            )

        self.checkpoints = list(predictors.keys())
        self.predictors = predictors
        if not weights:
            self.weights = [1] * len(self.checkpoints)
        elif all(name in weights for name in self.checkpoints):
            # Match by name so the insertion order of the two dicts is irrelevant.
            self.weights = [weights[name] for name in self.checkpoints]
        else:
            # Names differ: keep the positional pairing.
            self.weights = list(weights.values())

    @staticmethod
    def _numeric_mode(values, axis=0):
        """
        Return the most frequent value along ``axis``.

        Ties resolve to the smallest value. Non-numeric input raises
        ``TypeError`` so that callers can fall back to string aggregation.
        """
        values = np.asarray(values)
        if not np.issubdtype(values.dtype, np.number):
            raise TypeError("numeric mode aggregation requires numeric data")

        def most_frequent(column):
            uniques, counts = np.unique(column, return_counts=True)
            return uniques[np.argmax(counts)]

        return np.apply_along_axis(most_frequent, axis, values)

    def numeric_agg(self, result: list):
        """
        Aggregate a list of numeric values.

        The values are stacked into one array and reduced along axis 0 with the
        configured numeric aggregation method.

        Args:
            result (list): A list of numeric values to aggregate

        Returns:
            The aggregated value using the specified numeric aggregation method

        Example:
            >>> ensemble = VoteEnsemblePredictor(predictors, numeric_agg="max")
            >>> result = ensemble.numeric_agg([0.8, 0.9, 0.7])
            >>> print(result)
            0.9
        """
        res = np.stack([np.array(x) for x in result])
        return self.numeric_agg_methods[self.numeric_agg_func](res, axis=0)

    def __aggregate_values(self, values: list):
        """
        Reduce ``values`` numerically, then by string voting, else return them.

        The broad ``except`` clauses are deliberate: aggregation is best-effort
        and values that cannot be reduced are passed through unchanged.
        """
        try:
            return self.numeric_agg(values)
        except Exception:
            try:
                return self.str_agg(values)
            except Exception:
                return values

    def __ensemble(self, result: dict):
        """
        Aggregate prediction results by calling the appropriate aggregation method.

        Args:
            result (dict): The collected prediction results. Lists are also
                accepted; any other object is returned unchanged.

        Returns:
            The aggregated prediction result
        """
        if isinstance(result, dict):
            return self.__dict_aggregate(result)
        elif isinstance(result, list):
            return self.__list_aggregate(result)
        else:
            return result

    def __dict_aggregate(self, result: dict):
        """
        Recursively aggregate a dictionary of prediction results.

        List values and nested dictionaries are aggregated; every other value
        is copied as is.

        Args:
            result (dict): A dictionary containing the prediction results

        Returns:
            dict: The aggregated prediction result
        """
        ensemble_result = {}
        for key, value in result.items():
            if isinstance(value, list):
                ensemble_result[key] = self.__list_aggregate(value)
            elif isinstance(value, dict):
                ensemble_result[key] = self.__dict_aggregate(value)
            else:
                ensemble_result[key] = value
        return ensemble_result

    def __list_aggregate(self, result: list):
        """
        Aggregate a list of prediction results.

        * a list of lists: every inner list is aggregated first, the reduced
          values are aggregated again and returned wrapped in a one-element list
        * a list of dicts: every dict is aggregated on its own
        * anything else: the list is reduced to a single value

        The input list is never modified, because the same prediction object
        can appear several times when a weight is larger than 1.

        Args:
            result (list): A list of prediction results to aggregate

        Returns:
            The aggregated result (an empty list for empty input)

        Raises:
            AssertionError: If all elements in the list are not of the same type
        """
        if not isinstance(result, list):
            result = [result]

        if not result:
            return result

        assert all(
            isinstance(x, (type(result[0]))) for x in result
        ), "all type of result should be the same"

        if isinstance(result[0], list):
            reduced = [self.__list_aggregate(item) for item in result]
            return [self.__aggregate_values(reduced)]

        if isinstance(result[0], dict):
            return [self.__dict_aggregate(item) for item in result]

        return self.__aggregate_values(result)

    def predict(self, text, ignore_error=False, print_result=False):
        """
        Predicts on a single text and returns the ensemble result.

        Args:
            text (str): The text to perform prediction on
            ignore_error (bool, optional): Forwarded to each predictor's ``inference``. Defaults to False
            print_result (bool, optional): Forwarded to each predictor's ``inference``. Defaults to False

        Returns:
            dict: The ensemble prediction result, keyed like the predictors' outputs

        Example:
            >>> result = ensemble.predict("ACGUAGGUAUCGUAGA", ignore_error=True)
        """
        result = {}
        for ckpt, predictor in self.predictors.items():
            raw_result = predictor.inference(
                text, ignore_error=ignore_error, print_result=print_result
            )
            weight = self.weights[self.checkpoints.index(ckpt)]
            for key, value in raw_result.items():
                # Repeat the value `weight` times so it counts that often.
                result.setdefault(key, []).extend(value for _ in range(weight))
        return self.__ensemble(result)

    def batch_predict(self, texts, ignore_error=False, print_result=False):
        """
        Predicts on a batch of texts using the ensemble of predictors.

        Each predictor's ``inference`` is called once with the whole batch and
        ``merge_results=False``; its i-th output is merged into the i-th result.

        Args:
            texts (list): A list of strings to predict on
            ignore_error (bool, optional): Forwarded to each predictor's ``inference``. Defaults to False
            print_result (bool, optional): Forwarded to each predictor's ``inference``. Defaults to False

        Returns:
            list: A list of dictionaries, each dictionary containing the aggregated results of the corresponding text in the input list

        Example:
            >>> texts = ["ACGUAGGUAUCGUAGA", "GGCTAGCTA", "TATCGCTA"]
            >>> results = ensemble.batch_predict(texts, ignore_error=True)
            >>> print(len(results))
            3
        """
        batch_results = []
        for ckpt, predictor in self.predictors.items():
            raw_results = predictor.inference(
                texts,
                ignore_error=ignore_error,
                print_result=print_result,
                merge_results=False,
            )
            # Look the weight up per predictor, while ckpt still names it.
            weight = self.weights[self.checkpoints.index(ckpt)]
            for i, result in enumerate(raw_results):
                if i >= len(batch_results):
                    batch_results.append({})
                for key, value in result.items():
                    batch_results[i].setdefault(key, []).extend(
                        value for _ in range(weight)
                    )

        return [self.__ensemble(result) for result in batch_results]
|
|
318
|
-
|
|
319
|
-
# def batch_predict(self, texts, ignore_error=False, print_result=False):
|
|
320
|
-
# batch_results = []
|
|
321
|
-
# for text in tqdm.tqdm(texts, desc='Batch predict: '):
|
|
322
|
-
# result = self.predict(text, ignore_error=ignore_error, print_result=print_result)
|
|
323
|
-
# batch_results.append(result)
|
|
324
|
-
# return batch_results
|