omnigenome-0.3.0a0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of omnigenome might be problematic.
- omnigenome/__init__.py +281 -0
- omnigenome/auto/__init__.py +3 -0
- omnigenome/auto/auto_bench/__init__.py +12 -0
- omnigenome/auto/auto_bench/auto_bench.py +484 -0
- omnigenome/auto/auto_bench/auto_bench_cli.py +230 -0
- omnigenome/auto/auto_bench/auto_bench_config.py +216 -0
- omnigenome/auto/auto_bench/config_check.py +34 -0
- omnigenome/auto/auto_train/__init__.py +13 -0
- omnigenome/auto/auto_train/auto_train.py +430 -0
- omnigenome/auto/auto_train/auto_train_cli.py +222 -0
- omnigenome/auto/bench_hub/__init__.py +12 -0
- omnigenome/auto/bench_hub/bench_hub.py +25 -0
- omnigenome/cli/__init__.py +13 -0
- omnigenome/cli/commands/__init__.py +13 -0
- omnigenome/cli/commands/base.py +83 -0
- omnigenome/cli/commands/bench/__init__.py +13 -0
- omnigenome/cli/commands/bench/bench_cli.py +202 -0
- omnigenome/cli/commands/rna/__init__.py +13 -0
- omnigenome/cli/commands/rna/rna_design.py +178 -0
- omnigenome/cli/omnigenome_cli.py +128 -0
- omnigenome/src/__init__.py +12 -0
- omnigenome/src/abc/__init__.py +12 -0
- omnigenome/src/abc/abstract_dataset.py +622 -0
- omnigenome/src/abc/abstract_metric.py +114 -0
- omnigenome/src/abc/abstract_model.py +689 -0
- omnigenome/src/abc/abstract_tokenizer.py +267 -0
- omnigenome/src/dataset/__init__.py +16 -0
- omnigenome/src/dataset/omni_dataset.py +435 -0
- omnigenome/src/lora/__init__.py +13 -0
- omnigenome/src/lora/lora_model.py +294 -0
- omnigenome/src/metric/__init__.py +15 -0
- omnigenome/src/metric/classification_metric.py +184 -0
- omnigenome/src/metric/metric.py +199 -0
- omnigenome/src/metric/ranking_metric.py +142 -0
- omnigenome/src/metric/regression_metric.py +191 -0
- omnigenome/src/misc/__init__.py +3 -0
- omnigenome/src/misc/utils.py +439 -0
- omnigenome/src/model/__init__.py +19 -0
- omnigenome/src/model/augmentation/__init__.py +12 -0
- omnigenome/src/model/augmentation/model.py +219 -0
- omnigenome/src/model/classification/__init__.py +12 -0
- omnigenome/src/model/classification/model.py +642 -0
- omnigenome/src/model/embedding/__init__.py +12 -0
- omnigenome/src/model/embedding/model.py +263 -0
- omnigenome/src/model/mlm/__init__.py +12 -0
- omnigenome/src/model/mlm/model.py +177 -0
- omnigenome/src/model/module_utils.py +232 -0
- omnigenome/src/model/regression/__init__.py +12 -0
- omnigenome/src/model/regression/model.py +786 -0
- omnigenome/src/model/regression/resnet.py +483 -0
- omnigenome/src/model/rna_design/__init__.py +12 -0
- omnigenome/src/model/rna_design/model.py +426 -0
- omnigenome/src/model/seq2seq/__init__.py +12 -0
- omnigenome/src/model/seq2seq/model.py +44 -0
- omnigenome/src/tokenizer/__init__.py +16 -0
- omnigenome/src/tokenizer/bpe_tokenizer.py +226 -0
- omnigenome/src/tokenizer/kmers_tokenizer.py +247 -0
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +249 -0
- omnigenome/src/trainer/__init__.py +14 -0
- omnigenome/src/trainer/accelerate_trainer.py +739 -0
- omnigenome/src/trainer/hf_trainer.py +75 -0
- omnigenome/src/trainer/trainer.py +579 -0
- omnigenome/utility/__init__.py +3 -0
- omnigenome/utility/dataset_hub/__init__.py +13 -0
- omnigenome/utility/dataset_hub/dataset_hub.py +178 -0
- omnigenome/utility/ensemble.py +324 -0
- omnigenome/utility/hub_utils.py +517 -0
- omnigenome/utility/model_hub/__init__.py +12 -0
- omnigenome/utility/model_hub/model_hub.py +231 -0
- omnigenome/utility/pipeline_hub/__init__.py +12 -0
- omnigenome/utility/pipeline_hub/pipeline.py +483 -0
- omnigenome/utility/pipeline_hub/pipeline_hub.py +129 -0
- omnigenome-0.3.0a0.dist-info/METADATA +224 -0
- omnigenome-0.3.0a0.dist-info/RECORD +85 -0
- omnigenome-0.3.0a0.dist-info/WHEEL +5 -0
- omnigenome-0.3.0a0.dist-info/entry_points.txt +3 -0
- omnigenome-0.3.0a0.dist-info/licenses/LICENSE +201 -0
- omnigenome-0.3.0a0.dist-info/top_level.txt +2 -0
- tests/__init__.py +9 -0
- tests/conftest.py +160 -0
- tests/test_dataset_patterns.py +291 -0
- tests/test_examples_syntax.py +83 -0
- tests/test_model_loading.py +183 -0
- tests/test_rna_functions.py +255 -0
- tests/test_training_patterns.py +302 -0
omnigenome/utility/model_hub/model_hub.py
@@ -0,0 +1,231 @@
+# -*- coding: utf-8 -*-
+# file: model_hub.py
+# time: 18:13 12/04/2024
+# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+# github: https://github.com/yangheng95
+# huggingface: https://huggingface.co/yangheng
+# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+# Copyright (C) 2019-2024. All Rights Reserved.
+import json
+import os
+
+import autocuda
+import dill
+import torch
+from transformers import AutoConfig, AutoModel, AutoTokenizer
+
+from omnigenome.utility.hub_utils import query_models_info, download_model
+from ...src.misc.utils import env_meta_info, fprint
+
+
+class ModelHub:
+    """
+    A hub for loading and managing pre-trained genomic models.
+
+    This class provides a unified interface for loading pre-trained models
+    from the OmniGenome hub or local paths. It handles model downloading,
+    tokenizer loading, and device placement automatically.
+
+    The ModelHub supports various model types and can automatically
+    download models from the hub if they're not available locally.
+
+    Attributes:
+        metadata (dict): Environment metadata information
+
+    Example:
+        >>> from omnigenome import ModelHub
+        >>> hub = ModelHub()
+
+        >>> # Load a model from the hub
+        >>> model, tokenizer = ModelHub.load_model_and_tokenizer("model_name")
+
+        >>> # Check available models
+        >>> models = hub.available_models()
+        >>> print(list(models.keys()))
+    """
+
+    def __init__(self, *args, **kwargs):
+        """
+        Initialize the ModelHub instance.
+
+        Args:
+            *args: Additional positional arguments
+            **kwargs: Additional keyword arguments
+        """
+        super(ModelHub, self).__init__(*args, **kwargs)
+
+        self.metadata = env_meta_info()
+
+    @staticmethod
+    def load_model_and_tokenizer(
+        model_name_or_path,
+        local_only=False,
+        device=None,
+        dtype=torch.float16,
+        **kwargs,
+    ):
+        """
+        Load a model and its tokenizer from the hub or local path.
+
+        This method loads both the model and tokenizer, places them on the
+        specified device, and returns them as a tuple. It handles automatic
+        device selection if none is specified.
+
+        Args:
+            model_name_or_path (str): Name or path of the model to load
+            local_only (bool, optional): Whether to use only local cache. Defaults to False
+            device (str, optional): Device to load the model on. If None, uses auto-detection
+            dtype (torch.dtype, optional): Data type for the model. Defaults to torch.float16
+            **kwargs: Additional keyword arguments passed to the model loading functions
+
+        Returns:
+            tuple: A tuple containing (model, tokenizer)
+
+        Example:
+            >>> model, tokenizer = ModelHub.load_model_and_tokenizer("yangheng/OmniGenome-186M")
+            >>> print(f"Model loaded on device: {next(model.parameters()).device}")
+        """
+        model = ModelHub.load(model_name_or_path, local_only=local_only, **kwargs)
+        fprint(f"The model and tokenizer has been loaded from {model_name_or_path}.")
+        model.to(dtype)
+        if device is None:
+            device = autocuda.auto_cuda()
+            fprint(
+                f"No device is specified, the model will be loaded to the default device: {device}"
+            )
+            model.to(device)
+        else:
+            model.to(device)
+        return model, model.tokenizer
+
+    @staticmethod
+    def load(
+        model_name_or_path,
+        local_only=False,
+        device=None,
+        dtype=torch.float16,
+        **kwargs,
+    ):
+        """
+        Load a model from the hub or local path.
+
+        This method handles model loading from various sources including
+        local paths and the OmniGenome hub. It automatically downloads
+        models if they're not available locally.
+
+        Args:
+            model_name_or_path (str): Name or path of the model to load
+            local_only (bool, optional): Whether to use only local cache. Defaults to False
+            device (str, optional): Device to load the model on. If None, uses auto-detection
+            dtype (torch.dtype, optional): Data type for the model. Defaults to torch.float16
+            **kwargs: Additional keyword arguments passed to the model loading functions
+
+        Returns:
+            torch.nn.Module: The loaded model
+
+        Raises:
+            ValueError: If model_name_or_path is not a string
+
+        Example:
+            >>> model = ModelHub.load("yangheng/OmniGenome-186M")
+            >>> print(f"Model type: {type(model)}")
+        """
+        if isinstance(model_name_or_path, str) and os.path.exists(model_name_or_path):
+            path = model_name_or_path
+        elif isinstance(model_name_or_path, str) and not os.path.exists(
+            model_name_or_path
+        ):
+            path = download_model(model_name_or_path, local_only=local_only, **kwargs)
+        else:
+            raise ValueError("model_name_or_path must be a string.")
+
+        import importlib
+
+        config = AutoConfig.from_pretrained(path, trust_remote_code=True, **kwargs)
+
+        with open(f"{path}/metadata.json", "r", encoding="utf8") as f:
+            metadata = json.load(f)
+
+        if "Omni" in metadata["tokenizer_cls"]:
+            lib = importlib.import_module(metadata["library_name"].lower())
+            tokenizer_cls = getattr(lib, metadata["tokenizer_cls"])
+            tokenizer = tokenizer_cls.from_pretrained(path, **kwargs)
+        else:
+            from multimolecule import RnaTokenizer
+            tokenizer = RnaTokenizer.from_pretrained(path, **kwargs)
+
+        config.metadata = metadata
+
+        base_model = AutoModel.from_config(config, trust_remote_code=True, **kwargs)
+        model_lib = importlib.import_module(metadata["library_name"].lower()).model
+        model_cls = getattr(model_lib, metadata["model_cls"])
+        model = model_cls(
+            base_model,
+            tokenizer,
+            label2id=config.label2id,
+            num_labels=config.num_labels,
+            **kwargs,
+        )
+
+        with open(f"{path}/pytorch_model.bin", "rb") as f:
+            model.load_state_dict(
+                torch.load(f, map_location=kwargs.get("device", "cpu")), strict=False
+            )
+        model.to(dtype)
+        if device is None:
+            device = autocuda.auto_cuda()
+            fprint(
+                f"No device is specified, the model will be loaded to the default device: {device}"
+            )
+            model.to(device)
+        else:
+            model.to(device)
+        return model
+
+    def available_models(
+        self, model_name_or_path=None, local_only=False, repo="", **kwargs
+    ):
+        """
+        Get information about available models in the hub.
+
+        This method queries the OmniGenome hub to retrieve information about
+        available models. It can filter models by name and supports both
+        local and remote queries.
+
+        Args:
+            model_name_or_path (str, optional): Filter models by name. Defaults to None
+            local_only (bool, optional): Whether to use only local cache. Defaults to False
+            repo (str, optional): Repository URL to query. Defaults to ""
+            **kwargs: Additional keyword arguments
+
+        Returns:
+            dict: Dictionary containing information about available models
+
+        Example:
+            >>> hub = ModelHub()
+            >>> models = hub.available_models()
+            >>> print(f"Available models: {len(models)}")
+
+            >>> # Filter models by name
+            >>> dna_models = hub.available_models("DNA")
+            >>> print(f"DNA models: {list(dna_models.keys())}")
+        """
+        models_info = query_models_info(
+            model_name_or_path, local_only=local_only, repo=repo, **kwargs
+        )
+        return models_info
+
+    def push(self, model, **kwargs):
+        """
+        Push a model to the hub.
+
+        This method is not yet implemented and will raise a NotImplementedError.
+
+        Args:
+            model: The model to push to the hub
+            **kwargs: Additional keyword arguments
+
+        Raises:
+            NotImplementedError: This method has not been implemented yet
+        """
+        raise NotImplementedError("This method has not implemented yet.")
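For orientation, the snippet below is a minimal usage sketch assembled from the docstrings in model_hub.py above; it is not part of the released wheel. The checkpoint name "yangheng/OmniGenome-186M" is the one used in the file's own examples, and running the sketch assumes network access plus the autocuda and transformers dependencies imported above.

# Minimal usage sketch based on the docstrings in model_hub.py above (not part of the wheel).
from omnigenome import ModelHub

# Load a checkpoint and its tokenizer; with device=None the code auto-selects a
# device via autocuda and casts the weights to the default torch.float16.
model, tokenizer = ModelHub.load_model_and_tokenizer("yangheng/OmniGenome-186M")
print(next(model.parameters()).device)

# Query the hub for available checkpoints, optionally filtered by name.
hub = ModelHub()
dna_models = hub.available_models("DNA")
print(list(dna_models.keys()))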
omnigenome/utility/pipeline_hub/__init__.py
@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+# file: __init__.py
+# time: 14:09 06/04/2024
+# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
+# github: https://github.com/yangheng95
+# huggingface: https://huggingface.co/yangheng
+# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
+# Copyright (C) 2019-2024. All Rights Reserved.
+"""
+This package contains modules for the pipeline hub.
+"""
+
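A note on the loading path in model_hub.py above: ModelHub.load resolves the concrete tokenizer and model classes from a metadata.json stored alongside the checkpoint, reading the keys library_name, tokenizer_cls, and model_cls. The sketch below shows the shape such a file would need to satisfy that code; the key names come from the source, while the values are purely illustrative assumptions rather than names confirmed by this release.

# Hypothetical metadata.json contents for a checkpoint consumed by ModelHub.load.
# The key names match what the code reads; the class names are illustrative only.
import json

metadata = {
    "library_name": "OmniGenome",  # importlib.import_module(name.lower()) must resolve to the package
    "tokenizer_cls": "OmniSingleNucleotideTokenizer",  # hypothetical; looked up as an attribute of that package
    "model_cls": "OmniGenomeModelForSequenceClassification",  # hypothetical; looked up on <package>.model
}

with open("metadata.json", "w", encoding="utf8") as f:
    json.dump(metadata, f, indent=2)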