omnigenome 0.3.0a1__py3-none-any.whl → 0.3.1a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- omnigenome/__init__.py +16 -8
- omnigenome/auto/auto_bench/__init__.py +0 -1
- omnigenome/auto/auto_bench/auto_bench.py +24 -14
- omnigenome/auto/auto_train/__init__.py +0 -1
- omnigenome/auto/auto_train/auto_train.py +11 -12
- omnigenome/auto/bench_hub/__init__.py +0 -1
- omnigenome/auto/bench_hub/bench_hub.py +1 -1
- omnigenome/cli/__init__.py +0 -1
- omnigenome/cli/commands/__init__.py +0 -1
- omnigenome/cli/commands/base.py +10 -10
- omnigenome/cli/commands/bench/__init__.py +0 -1
- omnigenome/cli/commands/bench/bench_cli.py +10 -10
- omnigenome/cli/commands/rna/__init__.py +0 -1
- omnigenome/cli/commands/rna/rna_design.py +10 -11
- omnigenome/src/__init__.py +0 -1
- omnigenome/src/abc/__init__.py +0 -1
- omnigenome/src/abc/abstract_dataset.py +38 -19
- omnigenome/src/abc/abstract_metric.py +7 -7
- omnigenome/src/abc/abstract_model.py +15 -14
- omnigenome/src/abc/abstract_tokenizer.py +9 -7
- omnigenome/src/dataset/omni_dataset.py +16 -14
- omnigenome/src/lora/__init__.py +0 -1
- omnigenome/src/lora/lora_model.py +47 -41
- omnigenome/src/metric/classification_metric.py +11 -11
- omnigenome/src/metric/metric.py +19 -19
- omnigenome/src/metric/ranking_metric.py +15 -15
- omnigenome/src/metric/regression_metric.py +18 -18
- omnigenome/src/misc/utils.py +40 -36
- omnigenome/src/model/augmentation/__init__.py +0 -1
- omnigenome/src/model/augmentation/model.py +17 -17
- omnigenome/src/model/classification/__init__.py +0 -1
- omnigenome/src/model/classification/model.py +28 -32
- omnigenome/src/model/embedding/__init__.py +0 -1
- omnigenome/src/model/embedding/model.py +35 -35
- omnigenome/src/model/mlm/__init__.py +0 -1
- omnigenome/src/model/mlm/model.py +13 -13
- omnigenome/src/model/module_utils.py +17 -17
- omnigenome/src/model/regression/__init__.py +0 -1
- omnigenome/src/model/regression/model.py +72 -77
- omnigenome/src/model/regression/resnet.py +32 -32
- omnigenome/src/model/rna_design/__init__.py +0 -1
- omnigenome/src/model/rna_design/model.py +65 -58
- omnigenome/src/model/seq2seq/__init__.py +0 -1
- omnigenome/src/model/seq2seq/model.py +4 -4
- omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
- omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
- omnigenome/src/trainer/accelerate_trainer.py +40 -32
- omnigenome/src/trainer/hf_trainer.py +8 -8
- omnigenome/src/trainer/trainer.py +37 -25
- omnigenome/utility/dataset_hub/__init__.py +0 -1
- omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
- omnigenome/utility/ensemble.py +26 -26
- omnigenome/utility/hub_utils.py +8 -8
- omnigenome/utility/model_hub/__init__.py +0 -1
- omnigenome/utility/model_hub/model_hub.py +26 -25
- omnigenome/utility/pipeline_hub/__init__.py +0 -1
- omnigenome/utility/pipeline_hub/pipeline.py +49 -49
- omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +2 -2
- omnigenome-0.3.1a0.dist-info/RECORD +78 -0
- omnigenome-0.3.0a1.dist-info/RECORD +0 -78
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
- {omnigenome-0.3.0a1.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -0
|
@@ -29,14 +29,14 @@ from torch.cuda.amp import GradScaler
|
|
|
29
29
|
def _infer_optimization_direction(metrics, prev_metrics):
|
|
30
30
|
"""
|
|
31
31
|
Infer the optimization direction based on metric names and trends.
|
|
32
|
-
|
|
32
|
+
|
|
33
33
|
This function determines whether larger or smaller values are better for
|
|
34
34
|
the given metrics by analyzing metric names and their trends over time.
|
|
35
|
-
|
|
35
|
+
|
|
36
36
|
Args:
|
|
37
37
|
metrics (dict): Current metric values
|
|
38
38
|
prev_metrics (list): Previous metric values from multiple epochs
|
|
39
|
-
|
|
39
|
+
|
|
40
40
|
Returns:
|
|
41
41
|
str: Either "larger_is_better" or "smaller_is_better"
|
|
42
42
|
"""
|
|
@@ -98,11 +98,11 @@ def _infer_optimization_direction(metrics, prev_metrics):
|
|
|
98
98
|
class Trainer:
|
|
99
99
|
"""
|
|
100
100
|
Comprehensive trainer for OmniGenome models.
|
|
101
|
-
|
|
101
|
+
|
|
102
102
|
This trainer provides a complete training framework with automatic mixed precision,
|
|
103
103
|
early stopping, metric tracking, and model checkpointing. It supports various
|
|
104
104
|
training configurations and can handle different types of genomic sequence tasks.
|
|
105
|
-
|
|
105
|
+
|
|
106
106
|
Attributes:
|
|
107
107
|
model: The model to be trained
|
|
108
108
|
train_loader: DataLoader for training data
|
|
@@ -118,7 +118,7 @@ class Trainer:
|
|
|
118
118
|
metrics: Dictionary to store training metrics
|
|
119
119
|
predictions: Dictionary to store model predictions
|
|
120
120
|
"""
|
|
121
|
-
|
|
121
|
+
|
|
122
122
|
def __init__(
|
|
123
123
|
self,
|
|
124
124
|
model,
|
|
@@ -139,7 +139,7 @@ class Trainer:
|
|
|
139
139
|
):
|
|
140
140
|
"""
|
|
141
141
|
Initialize the trainer.
|
|
142
|
-
|
|
142
|
+
|
|
143
143
|
Args:
|
|
144
144
|
model: The model to be trained
|
|
145
145
|
train_dataset: Training dataset
|
|
@@ -191,7 +191,9 @@ class Trainer:
|
|
|
191
191
|
)
|
|
192
192
|
self.seed = seed
|
|
193
193
|
self.device = device if device else autocuda.auto_cuda()
|
|
194
|
-
self.device =
|
|
194
|
+
self.device = (
|
|
195
|
+
torch.device(self.device) if isinstance(self.device, str) else self.device
|
|
196
|
+
)
|
|
195
197
|
|
|
196
198
|
self.fast_dtype = {
|
|
197
199
|
"float32": torch.float32,
|
|
@@ -218,11 +220,11 @@ class Trainer:
|
|
|
218
220
|
def _is_metric_better(self, metrics, stage="valid"):
|
|
219
221
|
"""
|
|
220
222
|
Check if the current metrics are better than the best metrics so far.
|
|
221
|
-
|
|
223
|
+
|
|
222
224
|
Args:
|
|
223
225
|
metrics (dict): Current metric values
|
|
224
226
|
stage (str): Stage name ("valid" or "test")
|
|
225
|
-
|
|
227
|
+
|
|
226
228
|
Returns:
|
|
227
229
|
bool: True if current metrics are better than best metrics
|
|
228
230
|
"""
|
|
@@ -268,11 +270,11 @@ class Trainer:
|
|
|
268
270
|
def train(self, path_to_save=None, **kwargs):
|
|
269
271
|
"""
|
|
270
272
|
Train the model.
|
|
271
|
-
|
|
273
|
+
|
|
272
274
|
Args:
|
|
273
275
|
path_to_save (str, optional): Path to save the best model
|
|
274
276
|
**kwargs: Additional keyword arguments
|
|
275
|
-
|
|
277
|
+
|
|
276
278
|
Returns:
|
|
277
279
|
dict: Training metrics and results
|
|
278
280
|
"""
|
|
@@ -300,19 +302,29 @@ class Trainer:
|
|
|
300
302
|
self.optimizer.zero_grad()
|
|
301
303
|
|
|
302
304
|
if self.fast_dtype:
|
|
303
|
-
with torch.autocast(
|
|
305
|
+
with torch.autocast(
|
|
306
|
+
device_type=self.device.type, dtype=self.fast_dtype
|
|
307
|
+
):
|
|
304
308
|
outputs = self.model(**batch)
|
|
305
309
|
else:
|
|
306
310
|
outputs = self.model(**batch)
|
|
307
311
|
if "loss" not in outputs:
|
|
308
312
|
# Generally, the model should return a loss in the outputs via OmniGenBench
|
|
309
313
|
# For the Lora models, the loss is computed separately
|
|
310
|
-
if hasattr(self.model, "loss_function") and callable(
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
314
|
+
if hasattr(self.model, "loss_function") and callable(
|
|
315
|
+
self.model.loss_function
|
|
316
|
+
):
|
|
317
|
+
loss = self.model.loss_function(
|
|
318
|
+
outputs["logits"], outputs["labels"]
|
|
319
|
+
)
|
|
320
|
+
elif (
|
|
321
|
+
hasattr(self.model, "model")
|
|
322
|
+
and hasattr(self.model.model, "loss_function")
|
|
323
|
+
and callable(self.model.model.loss_function)
|
|
324
|
+
):
|
|
325
|
+
loss = self.model.model.loss_function(
|
|
326
|
+
outputs["logits"], outputs["labels"]
|
|
327
|
+
)
|
|
316
328
|
else:
|
|
317
329
|
raise ValueError(
|
|
318
330
|
"The model does not have a loss function defined. "
|
|
@@ -480,10 +492,10 @@ class Trainer:
|
|
|
480
492
|
def get_model(self, **kwargs):
|
|
481
493
|
"""
|
|
482
494
|
Get the trained model.
|
|
483
|
-
|
|
495
|
+
|
|
484
496
|
Args:
|
|
485
497
|
**kwargs: Additional keyword arguments
|
|
486
|
-
|
|
498
|
+
|
|
487
499
|
Returns:
|
|
488
500
|
The trained model
|
|
489
501
|
"""
|
|
@@ -492,7 +504,7 @@ class Trainer:
|
|
|
492
504
|
def compute_metrics(self):
|
|
493
505
|
"""
|
|
494
506
|
Get the metric computation functions.
|
|
495
|
-
|
|
507
|
+
|
|
496
508
|
Returns:
|
|
497
509
|
list: List of metric computation functions
|
|
498
510
|
"""
|
|
@@ -501,10 +513,10 @@ class Trainer:
|
|
|
501
513
|
def unwrap_model(self, model=None):
|
|
502
514
|
"""
|
|
503
515
|
Unwrap the model from any distributed training wrappers.
|
|
504
|
-
|
|
516
|
+
|
|
505
517
|
Args:
|
|
506
518
|
model: Model to unwrap (default: None, uses self.model)
|
|
507
|
-
|
|
519
|
+
|
|
508
520
|
Returns:
|
|
509
521
|
The unwrapped model
|
|
510
522
|
"""
|
|
@@ -538,7 +550,7 @@ class Trainer:
|
|
|
538
550
|
"""
|
|
539
551
|
if os.path.exists(self._model_state_dict_path):
|
|
540
552
|
self.unwrap_model().load_state_dict(
|
|
541
|
-
torch.load(self._model_state_dict_path, map_location=
|
|
553
|
+
torch.load(self._model_state_dict_path, map_location="cpu")
|
|
542
554
|
)
|
|
543
555
|
self.unwrap_model().to(self.device)
|
|
544
556
|
|
|
@@ -32,11 +32,11 @@ def load_benchmark_datasets(
|
|
|
32
32
|
):
|
|
33
33
|
"""
|
|
34
34
|
Load benchmark datasets from the OmniGenome hub.
|
|
35
|
-
|
|
35
|
+
|
|
36
36
|
This function automatically downloads benchmark datasets if they don't exist locally,
|
|
37
37
|
loads their configurations, and initializes train/validation/test datasets with
|
|
38
38
|
the specified tokenizer.
|
|
39
|
-
|
|
39
|
+
|
|
40
40
|
Args:
|
|
41
41
|
benchmark (str): Name or path of the benchmark to load. If the benchmark
|
|
42
42
|
doesn't exist locally, it will be downloaded from the hub.
|
|
@@ -46,17 +46,17 @@ def load_benchmark_datasets(
|
|
|
46
46
|
be loaded from the benchmark configuration.
|
|
47
47
|
**kwargs: Additional keyword arguments to override benchmark configuration.
|
|
48
48
|
These will be passed to the dataset classes and tokenizer initialization.
|
|
49
|
-
|
|
49
|
+
|
|
50
50
|
Returns:
|
|
51
51
|
dict: Dictionary containing datasets for each benchmark task, with keys
|
|
52
52
|
being benchmark names and values being dictionaries with 'train',
|
|
53
53
|
'valid', and 'test' datasets.
|
|
54
|
-
|
|
54
|
+
|
|
55
55
|
Raises:
|
|
56
56
|
FileNotFoundError: If the benchmark cannot be found or downloaded.
|
|
57
57
|
ValueError: If the benchmark configuration is invalid.
|
|
58
58
|
ImportError: If required dependencies are not available.
|
|
59
|
-
|
|
59
|
+
|
|
60
60
|
Example:
|
|
61
61
|
>>> from omnigenome import OmniSingleNucleotideTokenizer
|
|
62
62
|
>>> tokenizer = OmniSingleNucleotideTokenizer.from_pretrained("model_name")
|
|
@@ -64,7 +64,7 @@ def load_benchmark_datasets(
|
|
|
64
64
|
>>> print(f"Loaded {len(datasets)} benchmark tasks")
|
|
65
65
|
>>> for task_name, task_datasets in datasets.items():
|
|
66
66
|
... print(f"{task_name}: {len(task_datasets['train'])} train samples")
|
|
67
|
-
|
|
67
|
+
|
|
68
68
|
Note:
|
|
69
69
|
- The function automatically handles U/T conversion and other preprocessing
|
|
70
70
|
based on the benchmark configuration.
|
|
@@ -80,7 +80,7 @@ def load_benchmark_datasets(
|
|
|
80
80
|
"does not exist. Search online for available benchmarks.",
|
|
81
81
|
)
|
|
82
82
|
benchmark = download_benchmark(benchmark)
|
|
83
|
-
|
|
83
|
+
|
|
84
84
|
# Import benchmark list
|
|
85
85
|
bench_metadata = load_module_from_path(
|
|
86
86
|
f"bench_metadata", f"{benchmark}/metadata.py"
|
|
@@ -107,9 +107,7 @@ def load_benchmark_datasets(
|
|
|
107
107
|
|
|
108
108
|
for key, value in _kwargs.items():
|
|
109
109
|
if key in bench_config:
|
|
110
|
-
fprint(
|
|
111
|
-
"Override", key, "with", value, "according to the input kwargs"
|
|
112
|
-
)
|
|
110
|
+
fprint("Override", key, "with", value, "according to the input kwargs")
|
|
113
111
|
bench_config.update({key: value})
|
|
114
112
|
|
|
115
113
|
else:
|
|
@@ -170,9 +168,11 @@ def load_benchmark_datasets(
|
|
|
170
168
|
"valid": valid_set,
|
|
171
169
|
}
|
|
172
170
|
|
|
173
|
-
fprint(
|
|
174
|
-
|
|
171
|
+
fprint(
|
|
172
|
+
f"Loaded dataset for {bench} with {len(train_set)} train samples, "
|
|
173
|
+
f"{len(test_set)} test samples and {len(valid_set)} valid samples."
|
|
174
|
+
)
|
|
175
175
|
|
|
176
176
|
datasets[bench] = dataset
|
|
177
177
|
|
|
178
|
-
return datasets
|
|
178
|
+
return datasets
|
omnigenome/utility/ensemble.py
CHANGED
|
@@ -14,11 +14,11 @@ import numpy as np
|
|
|
14
14
|
class VoteEnsemblePredictor:
|
|
15
15
|
"""
|
|
16
16
|
An ensemble predictor that combines predictions from multiple models using voting.
|
|
17
|
-
|
|
17
|
+
|
|
18
18
|
This class implements ensemble methods for combining predictions from multiple
|
|
19
19
|
models or checkpoints. It supports both weighted and unweighted voting, and
|
|
20
20
|
provides various aggregation methods for different data types (numeric and string).
|
|
21
|
-
|
|
21
|
+
|
|
22
22
|
Attributes:
|
|
23
23
|
checkpoints: List of checkpoint names
|
|
24
24
|
predictors: Dictionary of initialized predictors
|
|
@@ -27,7 +27,7 @@ class VoteEnsemblePredictor:
|
|
|
27
27
|
str_agg: Function for aggregating string predictions
|
|
28
28
|
numeric_agg_methods: Dictionary of available numeric aggregation methods
|
|
29
29
|
str_agg_methods: Dictionary of available string aggregation methods
|
|
30
|
-
|
|
30
|
+
|
|
31
31
|
Example:
|
|
32
32
|
>>> from omnigenome.utility import VoteEnsemblePredictor
|
|
33
33
|
>>> predictors = {
|
|
@@ -51,14 +51,14 @@ class VoteEnsemblePredictor:
|
|
|
51
51
|
):
|
|
52
52
|
"""
|
|
53
53
|
Initialize the VoteEnsemblePredictor.
|
|
54
|
-
|
|
54
|
+
|
|
55
55
|
Args:
|
|
56
56
|
predictors (List or dict): A list of checkpoints, or a dictionary of initialized predictors
|
|
57
57
|
weights (List or dict, optional): A list of weights for each predictor, or a dictionary of weights for each predictor
|
|
58
58
|
numeric_agg (str, optional): The aggregation method for numeric data. Options are 'average', 'mean', 'max', 'min',
|
|
59
59
|
'median', 'mode', and 'sum'. Defaults to 'average'
|
|
60
60
|
str_agg (str, optional): The aggregation method for string data. Options are 'max_vote', 'min_vote', 'vote', and 'mode'. Defaults to 'max_vote'
|
|
61
|
-
|
|
61
|
+
|
|
62
62
|
Raises:
|
|
63
63
|
AssertionError: If predictors and weights have different lengths or types
|
|
64
64
|
AssertionError: If predictors list is empty
|
|
@@ -113,13 +113,13 @@ class VoteEnsemblePredictor:
|
|
|
113
113
|
def numeric_agg(self, result: list):
|
|
114
114
|
"""
|
|
115
115
|
Aggregate a list of numeric values.
|
|
116
|
-
|
|
116
|
+
|
|
117
117
|
Args:
|
|
118
118
|
result (list): A list of numeric values to aggregate
|
|
119
|
-
|
|
119
|
+
|
|
120
120
|
Returns:
|
|
121
121
|
The aggregated value using the specified numeric aggregation method
|
|
122
|
-
|
|
122
|
+
|
|
123
123
|
Example:
|
|
124
124
|
>>> ensemble = VoteEnsemblePredictor(predictors, numeric_agg="average")
|
|
125
125
|
>>> result = ensemble.numeric_agg([0.8, 0.9, 0.7])
|
|
@@ -132,13 +132,13 @@ class VoteEnsemblePredictor:
|
|
|
132
132
|
def __ensemble(self, result: dict):
|
|
133
133
|
"""
|
|
134
134
|
Aggregate prediction results by calling the appropriate aggregation method.
|
|
135
|
-
|
|
135
|
+
|
|
136
136
|
This method determines the type of result and calls the appropriate
|
|
137
137
|
aggregation method (numeric or string).
|
|
138
|
-
|
|
138
|
+
|
|
139
139
|
Args:
|
|
140
140
|
result (dict): A dictionary containing the prediction results
|
|
141
|
-
|
|
141
|
+
|
|
142
142
|
Returns:
|
|
143
143
|
The aggregated prediction result
|
|
144
144
|
"""
|
|
@@ -152,13 +152,13 @@ class VoteEnsemblePredictor:
|
|
|
152
152
|
def __dict_aggregate(self, result: dict):
|
|
153
153
|
"""
|
|
154
154
|
Recursively aggregate a dictionary of prediction results.
|
|
155
|
-
|
|
155
|
+
|
|
156
156
|
This method recursively processes nested dictionaries and applies
|
|
157
157
|
appropriate aggregation methods to each level.
|
|
158
|
-
|
|
158
|
+
|
|
159
159
|
Args:
|
|
160
160
|
result (dict): A dictionary containing the prediction results
|
|
161
|
-
|
|
161
|
+
|
|
162
162
|
Returns:
|
|
163
163
|
dict: The aggregated prediction result
|
|
164
164
|
"""
|
|
@@ -175,16 +175,16 @@ class VoteEnsemblePredictor:
|
|
|
175
175
|
def __list_aggregate(self, result: list):
|
|
176
176
|
"""
|
|
177
177
|
Aggregate a list of prediction results.
|
|
178
|
-
|
|
178
|
+
|
|
179
179
|
This method handles different types of list elements and applies
|
|
180
180
|
appropriate aggregation methods based on the data type.
|
|
181
|
-
|
|
181
|
+
|
|
182
182
|
Args:
|
|
183
183
|
result (list): A list of prediction results to aggregate
|
|
184
|
-
|
|
184
|
+
|
|
185
185
|
Returns:
|
|
186
186
|
The aggregated result
|
|
187
|
-
|
|
187
|
+
|
|
188
188
|
Raises:
|
|
189
189
|
AssertionError: If all elements in the list are not of the same type
|
|
190
190
|
"""
|
|
@@ -227,18 +227,18 @@ class VoteEnsemblePredictor:
|
|
|
227
227
|
def predict(self, text, ignore_error=False, print_result=False):
|
|
228
228
|
"""
|
|
229
229
|
Predicts on a single text and returns the ensemble result.
|
|
230
|
-
|
|
230
|
+
|
|
231
231
|
This method combines predictions from all predictors in the ensemble
|
|
232
232
|
using the specified weights and aggregation methods.
|
|
233
|
-
|
|
233
|
+
|
|
234
234
|
Args:
|
|
235
235
|
text (str): The text to perform prediction on
|
|
236
236
|
ignore_error (bool, optional): Whether to ignore any errors that occur during prediction. Defaults to False
|
|
237
237
|
print_result (bool, optional): Whether to print the prediction result. Defaults to False
|
|
238
|
-
|
|
238
|
+
|
|
239
239
|
Returns:
|
|
240
240
|
dict: The ensemble prediction result
|
|
241
|
-
|
|
241
|
+
|
|
242
242
|
Example:
|
|
243
243
|
>>> result = ensemble.predict("ACGUAGGUAUCGUAGA", ignore_error=True)
|
|
244
244
|
>>> print(result)
|
|
@@ -267,18 +267,18 @@ class VoteEnsemblePredictor:
|
|
|
267
267
|
def batch_predict(self, texts, ignore_error=False, print_result=False):
|
|
268
268
|
"""
|
|
269
269
|
Predicts on a batch of texts using the ensemble of predictors.
|
|
270
|
-
|
|
270
|
+
|
|
271
271
|
This method processes multiple texts efficiently by combining predictions
|
|
272
272
|
from all predictors in the ensemble for each text in the batch.
|
|
273
|
-
|
|
273
|
+
|
|
274
274
|
Args:
|
|
275
275
|
texts (list): A list of strings to predict on
|
|
276
276
|
ignore_error (bool, optional): Boolean indicating whether to ignore errors or raise exceptions when prediction fails. Defaults to False
|
|
277
277
|
print_result (bool, optional): Boolean indicating whether to print the raw results for each predictor. Defaults to False
|
|
278
|
-
|
|
278
|
+
|
|
279
279
|
Returns:
|
|
280
280
|
list: A list of dictionaries, each dictionary containing the aggregated results of the corresponding text in the input list
|
|
281
|
-
|
|
281
|
+
|
|
282
282
|
Example:
|
|
283
283
|
>>> texts = ["ACGUAGGUAUCGUAGA", "GGCTAGCTA", "TATCGCTA"]
|
|
284
284
|
>>> results = ensemble.batch_predict(texts, ignore_error=True)
|
omnigenome/utility/hub_utils.py
CHANGED
|
@@ -24,7 +24,7 @@ from omnigenome.src.misc.utils import fprint, default_omnigenome_repo
|
|
|
24
24
|
def unzip_checkpoint(checkpoint_path):
|
|
25
25
|
"""
|
|
26
26
|
Unzips a checkpoint file.
|
|
27
|
-
|
|
27
|
+
|
|
28
28
|
This function extracts a zipped checkpoint file to a directory,
|
|
29
29
|
making it ready for use by the model loading functions.
|
|
30
30
|
|
|
@@ -51,7 +51,7 @@ def query_models_info(
|
|
|
51
51
|
) -> Dict[str, Any]:
|
|
52
52
|
"""
|
|
53
53
|
Queries information about available models from the hub.
|
|
54
|
-
|
|
54
|
+
|
|
55
55
|
This function retrieves model information from the OmniGenome hub,
|
|
56
56
|
either from a remote repository or from a local cache. It supports
|
|
57
57
|
filtering by keywords to find specific models.
|
|
@@ -69,7 +69,7 @@ def query_models_info(
|
|
|
69
69
|
>>> # Query all models
|
|
70
70
|
>>> models = query_models_info("")
|
|
71
71
|
>>> print(len(models)) # Number of available models
|
|
72
|
-
|
|
72
|
+
|
|
73
73
|
>>> # Query specific models
|
|
74
74
|
>>> models = query_models_info("DNA")
|
|
75
75
|
>>> print(models.keys()) # Models containing "DNA"
|
|
@@ -108,7 +108,7 @@ def query_pipelines_info(
|
|
|
108
108
|
) -> Dict[str, Any]:
|
|
109
109
|
"""
|
|
110
110
|
Queries information about available pipelines from the hub.
|
|
111
|
-
|
|
111
|
+
|
|
112
112
|
This function retrieves pipeline information from the OmniGenome hub,
|
|
113
113
|
either from a remote repository or from a local cache. It supports
|
|
114
114
|
filtering by keywords to find specific pipelines.
|
|
@@ -126,7 +126,7 @@ def query_pipelines_info(
|
|
|
126
126
|
>>> # Query all pipelines
|
|
127
127
|
>>> pipelines = query_pipelines_info("")
|
|
128
128
|
>>> print(len(pipelines)) # Number of available pipelines
|
|
129
|
-
|
|
129
|
+
|
|
130
130
|
>>> # Query specific pipelines
|
|
131
131
|
>>> pipelines = query_pipelines_info("classification")
|
|
132
132
|
>>> print(pipelines.keys()) # Pipelines containing "classification"
|
|
@@ -165,7 +165,7 @@ def query_benchmarks_info(
|
|
|
165
165
|
) -> Dict[str, Any]:
|
|
166
166
|
"""
|
|
167
167
|
Queries information about available benchmarks from the hub.
|
|
168
|
-
|
|
168
|
+
|
|
169
169
|
This function retrieves benchmark information from the OmniGenome hub,
|
|
170
170
|
either from a remote repository or from a local cache. It supports
|
|
171
171
|
filtering by keywords to find specific benchmarks.
|
|
@@ -183,7 +183,7 @@ def query_benchmarks_info(
|
|
|
183
183
|
>>> # Query all benchmarks
|
|
184
184
|
>>> benchmarks = query_benchmarks_info("")
|
|
185
185
|
>>> print(len(benchmarks)) # Number of available benchmarks
|
|
186
|
-
|
|
186
|
+
|
|
187
187
|
>>> # Query specific benchmarks
|
|
188
188
|
>>> benchmarks = query_benchmarks_info("RGB")
|
|
189
189
|
>>> print(benchmarks.keys()) # Benchmarks containing "RGB"
|
|
@@ -468,7 +468,7 @@ def download_benchmark(
|
|
|
468
468
|
def check_version(repo: str = None) -> None:
|
|
469
469
|
"""
|
|
470
470
|
Checks the version compatibility between local and remote OmniGenome.
|
|
471
|
-
|
|
471
|
+
|
|
472
472
|
This function compares the local OmniGenome version with the version
|
|
473
473
|
available in the remote repository to ensure compatibility.
|
|
474
474
|
|
|
@@ -21,33 +21,33 @@ from ...src.misc.utils import env_meta_info, fprint
|
|
|
21
21
|
class ModelHub:
|
|
22
22
|
"""
|
|
23
23
|
A hub for loading and managing pre-trained genomic models.
|
|
24
|
-
|
|
24
|
+
|
|
25
25
|
This class provides a unified interface for loading pre-trained models
|
|
26
26
|
from the OmniGenome hub or local paths. It handles model downloading,
|
|
27
27
|
tokenizer loading, and device placement automatically.
|
|
28
|
-
|
|
28
|
+
|
|
29
29
|
The ModelHub supports various model types and can automatically
|
|
30
30
|
download models from the hub if they're not available locally.
|
|
31
|
-
|
|
31
|
+
|
|
32
32
|
Attributes:
|
|
33
33
|
metadata (dict): Environment metadata information
|
|
34
|
-
|
|
34
|
+
|
|
35
35
|
Example:
|
|
36
36
|
>>> from omnigenome import ModelHub
|
|
37
37
|
>>> hub = ModelHub()
|
|
38
|
-
|
|
38
|
+
|
|
39
39
|
>>> # Load a model from the hub
|
|
40
40
|
>>> model, tokenizer = ModelHub.load_model_and_tokenizer("model_name")
|
|
41
|
-
|
|
41
|
+
|
|
42
42
|
>>> # Check available models
|
|
43
43
|
>>> models = hub.available_models()
|
|
44
44
|
>>> print(list(models.keys()))
|
|
45
45
|
"""
|
|
46
|
-
|
|
46
|
+
|
|
47
47
|
def __init__(self, *args, **kwargs):
|
|
48
48
|
"""
|
|
49
49
|
Initialize the ModelHub instance.
|
|
50
|
-
|
|
50
|
+
|
|
51
51
|
Args:
|
|
52
52
|
*args: Additional positional arguments
|
|
53
53
|
**kwargs: Additional keyword arguments
|
|
@@ -66,21 +66,21 @@ class ModelHub:
|
|
|
66
66
|
):
|
|
67
67
|
"""
|
|
68
68
|
Load a model and its tokenizer from the hub or local path.
|
|
69
|
-
|
|
69
|
+
|
|
70
70
|
This method loads both the model and tokenizer, places them on the
|
|
71
71
|
specified device, and returns them as a tuple. It handles automatic
|
|
72
72
|
device selection if none is specified.
|
|
73
|
-
|
|
73
|
+
|
|
74
74
|
Args:
|
|
75
75
|
model_name_or_path (str): Name or path of the model to load
|
|
76
76
|
local_only (bool, optional): Whether to use only local cache. Defaults to False
|
|
77
77
|
device (str, optional): Device to load the model on. If None, uses auto-detection
|
|
78
78
|
dtype (torch.dtype, optional): Data type for the model. Defaults to torch.float16
|
|
79
79
|
**kwargs: Additional keyword arguments passed to the model loading functions
|
|
80
|
-
|
|
80
|
+
|
|
81
81
|
Returns:
|
|
82
82
|
tuple: A tuple containing (model, tokenizer)
|
|
83
|
-
|
|
83
|
+
|
|
84
84
|
Example:
|
|
85
85
|
>>> model, tokenizer = ModelHub.load_model_and_tokenizer("yangheng/OmniGenome-186M")
|
|
86
86
|
>>> print(f"Model loaded on device: {next(model.parameters()).device}")
|
|
@@ -108,24 +108,24 @@ class ModelHub:
|
|
|
108
108
|
):
|
|
109
109
|
"""
|
|
110
110
|
Load a model from the hub or local path.
|
|
111
|
-
|
|
111
|
+
|
|
112
112
|
This method handles model loading from various sources including
|
|
113
113
|
local paths and the OmniGenome hub. It automatically downloads
|
|
114
114
|
models if they're not available locally.
|
|
115
|
-
|
|
115
|
+
|
|
116
116
|
Args:
|
|
117
117
|
model_name_or_path (str): Name or path of the model to load
|
|
118
118
|
local_only (bool, optional): Whether to use only local cache. Defaults to False
|
|
119
119
|
device (str, optional): Device to load the model on. If None, uses auto-detection
|
|
120
120
|
dtype (torch.dtype, optional): Data type for the model. Defaults to torch.float16
|
|
121
121
|
**kwargs: Additional keyword arguments passed to the model loading functions
|
|
122
|
-
|
|
122
|
+
|
|
123
123
|
Returns:
|
|
124
124
|
torch.nn.Module: The loaded model
|
|
125
|
-
|
|
125
|
+
|
|
126
126
|
Raises:
|
|
127
127
|
ValueError: If model_name_or_path is not a string
|
|
128
|
-
|
|
128
|
+
|
|
129
129
|
Example:
|
|
130
130
|
>>> model = ModelHub.load("yangheng/OmniGenome-186M")
|
|
131
131
|
>>> print(f"Model type: {type(model)}")
|
|
@@ -152,6 +152,7 @@ class ModelHub:
|
|
|
152
152
|
tokenizer = tokenizer_cls.from_pretrained(path, **kwargs)
|
|
153
153
|
else:
|
|
154
154
|
from multimolecule import RnaTokenizer
|
|
155
|
+
|
|
155
156
|
tokenizer = RnaTokenizer.from_pretrained(path, **kwargs)
|
|
156
157
|
|
|
157
158
|
config.metadata = metadata
|
|
@@ -187,25 +188,25 @@ class ModelHub:
|
|
|
187
188
|
):
|
|
188
189
|
"""
|
|
189
190
|
Get information about available models in the hub.
|
|
190
|
-
|
|
191
|
+
|
|
191
192
|
This method queries the OmniGenome hub to retrieve information about
|
|
192
193
|
available models. It can filter models by name and supports both
|
|
193
194
|
local and remote queries.
|
|
194
|
-
|
|
195
|
+
|
|
195
196
|
Args:
|
|
196
197
|
model_name_or_path (str, optional): Filter models by name. Defaults to None
|
|
197
198
|
local_only (bool, optional): Whether to use only local cache. Defaults to False
|
|
198
199
|
repo (str, optional): Repository URL to query. Defaults to ""
|
|
199
200
|
**kwargs: Additional keyword arguments
|
|
200
|
-
|
|
201
|
+
|
|
201
202
|
Returns:
|
|
202
203
|
dict: Dictionary containing information about available models
|
|
203
|
-
|
|
204
|
+
|
|
204
205
|
Example:
|
|
205
206
|
>>> hub = ModelHub()
|
|
206
207
|
>>> models = hub.available_models()
|
|
207
208
|
>>> print(f"Available models: {len(models)}")
|
|
208
|
-
|
|
209
|
+
|
|
209
210
|
>>> # Filter models by name
|
|
210
211
|
>>> dna_models = hub.available_models("DNA")
|
|
211
212
|
>>> print(f"DNA models: {list(dna_models.keys())}")
|
|
@@ -218,13 +219,13 @@ class ModelHub:
|
|
|
218
219
|
def push(self, model, **kwargs):
|
|
219
220
|
"""
|
|
220
221
|
Push a model to the hub.
|
|
221
|
-
|
|
222
|
+
|
|
222
223
|
This method is not yet implemented and will raise a NotImplementedError.
|
|
223
|
-
|
|
224
|
+
|
|
224
225
|
Args:
|
|
225
226
|
model: The model to push to the hub
|
|
226
227
|
**kwargs: Additional keyword arguments
|
|
227
|
-
|
|
228
|
+
|
|
228
229
|
Raises:
|
|
229
230
|
NotImplementedError: This method has not been implemented yet
|
|
230
231
|
"""
|