omnigenome 0.3.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of omnigenome might be problematic. Click here for more details.
- omnigenome/__init__.py +281 -0
- omnigenome/auto/__init__.py +3 -0
- omnigenome/auto/auto_bench/__init__.py +12 -0
- omnigenome/auto/auto_bench/auto_bench.py +484 -0
- omnigenome/auto/auto_bench/auto_bench_cli.py +230 -0
- omnigenome/auto/auto_bench/auto_bench_config.py +216 -0
- omnigenome/auto/auto_bench/config_check.py +34 -0
- omnigenome/auto/auto_train/__init__.py +13 -0
- omnigenome/auto/auto_train/auto_train.py +430 -0
- omnigenome/auto/auto_train/auto_train_cli.py +222 -0
- omnigenome/auto/bench_hub/__init__.py +12 -0
- omnigenome/auto/bench_hub/bench_hub.py +25 -0
- omnigenome/cli/__init__.py +13 -0
- omnigenome/cli/commands/__init__.py +13 -0
- omnigenome/cli/commands/base.py +83 -0
- omnigenome/cli/commands/bench/__init__.py +13 -0
- omnigenome/cli/commands/bench/bench_cli.py +202 -0
- omnigenome/cli/commands/rna/__init__.py +13 -0
- omnigenome/cli/commands/rna/rna_design.py +178 -0
- omnigenome/cli/omnigenome_cli.py +128 -0
- omnigenome/src/__init__.py +12 -0
- omnigenome/src/abc/__init__.py +12 -0
- omnigenome/src/abc/abstract_dataset.py +622 -0
- omnigenome/src/abc/abstract_metric.py +114 -0
- omnigenome/src/abc/abstract_model.py +689 -0
- omnigenome/src/abc/abstract_tokenizer.py +267 -0
- omnigenome/src/dataset/__init__.py +16 -0
- omnigenome/src/dataset/omni_dataset.py +435 -0
- omnigenome/src/lora/__init__.py +13 -0
- omnigenome/src/lora/lora_model.py +294 -0
- omnigenome/src/metric/__init__.py +15 -0
- omnigenome/src/metric/classification_metric.py +184 -0
- omnigenome/src/metric/metric.py +199 -0
- omnigenome/src/metric/ranking_metric.py +142 -0
- omnigenome/src/metric/regression_metric.py +191 -0
- omnigenome/src/misc/__init__.py +3 -0
- omnigenome/src/misc/utils.py +439 -0
- omnigenome/src/model/__init__.py +19 -0
- omnigenome/src/model/augmentation/__init__.py +12 -0
- omnigenome/src/model/augmentation/model.py +219 -0
- omnigenome/src/model/classification/__init__.py +12 -0
- omnigenome/src/model/classification/model.py +642 -0
- omnigenome/src/model/embedding/__init__.py +12 -0
- omnigenome/src/model/embedding/model.py +263 -0
- omnigenome/src/model/mlm/__init__.py +12 -0
- omnigenome/src/model/mlm/model.py +177 -0
- omnigenome/src/model/module_utils.py +232 -0
- omnigenome/src/model/regression/__init__.py +12 -0
- omnigenome/src/model/regression/model.py +786 -0
- omnigenome/src/model/regression/resnet.py +483 -0
- omnigenome/src/model/rna_design/__init__.py +12 -0
- omnigenome/src/model/rna_design/model.py +426 -0
- omnigenome/src/model/seq2seq/__init__.py +12 -0
- omnigenome/src/model/seq2seq/model.py +44 -0
- omnigenome/src/tokenizer/__init__.py +16 -0
- omnigenome/src/tokenizer/bpe_tokenizer.py +226 -0
- omnigenome/src/tokenizer/kmers_tokenizer.py +247 -0
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +249 -0
- omnigenome/src/trainer/__init__.py +14 -0
- omnigenome/src/trainer/accelerate_trainer.py +739 -0
- omnigenome/src/trainer/hf_trainer.py +75 -0
- omnigenome/src/trainer/trainer.py +579 -0
- omnigenome/utility/__init__.py +3 -0
- omnigenome/utility/dataset_hub/__init__.py +13 -0
- omnigenome/utility/dataset_hub/dataset_hub.py +178 -0
- omnigenome/utility/ensemble.py +324 -0
- omnigenome/utility/hub_utils.py +517 -0
- omnigenome/utility/model_hub/__init__.py +12 -0
- omnigenome/utility/model_hub/model_hub.py +231 -0
- omnigenome/utility/pipeline_hub/__init__.py +12 -0
- omnigenome/utility/pipeline_hub/pipeline.py +483 -0
- omnigenome/utility/pipeline_hub/pipeline_hub.py +129 -0
- omnigenome-0.3.0a0.dist-info/METADATA +224 -0
- omnigenome-0.3.0a0.dist-info/RECORD +85 -0
- omnigenome-0.3.0a0.dist-info/WHEEL +5 -0
- omnigenome-0.3.0a0.dist-info/entry_points.txt +3 -0
- omnigenome-0.3.0a0.dist-info/licenses/LICENSE +201 -0
- omnigenome-0.3.0a0.dist-info/top_level.txt +2 -0
- tests/__init__.py +9 -0
- tests/conftest.py +160 -0
- tests/test_dataset_patterns.py +291 -0
- tests/test_examples_syntax.py +83 -0
- tests/test_model_loading.py +183 -0
- tests/test_rna_functions.py +255 -0
- tests/test_training_patterns.py +302 -0
|
@@ -0,0 +1,517 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# file: hub_utils.py
|
|
3
|
+
# time: 16:54 13/04/2024
|
|
4
|
+
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
|
|
5
|
+
# github: https://github.com/yangheng95
|
|
6
|
+
# huggingface: https://huggingface.co/yangheng
|
|
7
|
+
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
|
|
8
|
+
# Copyright (C) 2019-2024. All Rights Reserved.
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
import os
|
|
12
|
+
from typing import Union, Dict, Any
|
|
13
|
+
|
|
14
|
+
import findfile
|
|
15
|
+
import requests
|
|
16
|
+
import tqdm
|
|
17
|
+
from packaging.version import Version
|
|
18
|
+
from termcolor import colored
|
|
19
|
+
|
|
20
|
+
from omnigenome import __version__ as current_version
|
|
21
|
+
from omnigenome.src.misc.utils import fprint, default_omnigenome_repo
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def unzip_checkpoint(checkpoint_path):
    """
    Unzips a checkpoint file.

    This function extracts a zipped checkpoint archive into a directory next
    to the archive, named after it without the ``.zip`` extension, making it
    ready for use by the model loading functions.

    Args:
        checkpoint_path (str): The path to the checkpoint ``.zip`` file.

    Returns:
        str: The path to the extracted checkpoint directory.

    Example:
        >>> extracted_path = unzip_checkpoint("model.zip")
        >>> print(extracted_path)  # "model"
    """
    import zipfile

    # Remove only a trailing ".zip" extension. ``str.strip(".zip")`` strips any
    # of the characters ".", "z", "i", "p" from BOTH ends of the string, e.g.
    # ".../skip.zip" -> ".../sk" and "./model.zip" -> "/model".
    root, ext = os.path.splitext(checkpoint_path)
    extract_dir = root if ext.lower() == ".zip" else checkpoint_path

    with zipfile.ZipFile(checkpoint_path, "r") as zip_ref:
        zip_ref.extractall(extract_dir)

    return extract_dir
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def query_models_info(
    keyword: Union[list, str], repo: str = None, local_only: bool = False, **kwargs
) -> Dict[str, Any]:
    """
    Queries information about available models from the hub.

    This function retrieves the model index (``models_info.json``) from the
    OmniGenome hub, or from the local copy in the current working directory.
    A successful remote fetch overwrites that local copy; if the remote fetch
    fails, the local copy is read instead.

    Args:
        keyword (Union[list, str]): A substring used to filter model names.
            Only a string keyword filters the result; any other value (e.g. a
            list) returns the unfiltered index.
        repo (str, optional): The hub repository URL. If None, uses
            ``default_omnigenome_repo``. ``"resolve/main/"`` is appended to it.
        local_only (bool): Whether to read only the local ``models_info.json``.
            Defaults to False.
        **kwargs: Additional keyword arguments (currently unused).

    Returns:
        Dict[str, Any]: Model information keyed by model name, restricted to
        names containing ``keyword`` when ``keyword`` is a string.

    Raises:
        FileNotFoundError: If the local ``models_info.json`` is needed (local
            mode, or remote fetch failed) but does not exist.

    Example:
        >>> # Query all models
        >>> models = query_models_info("")
        >>> print(len(models))  # Number of available models

        >>> # Query specific models
        >>> models = query_models_info("DNA")
        >>> print(models.keys())  # Models containing "DNA"
    """
    cache_file = "./models_info.json"
    if local_only:
        with open(cache_file, "r", encoding="utf8") as f:
            models_info = json.load(f)
    else:
        # Same hub location and raw-file route ("resolve/main/") used by the
        # other hub helpers in this module for the same index file.
        repo = (repo if repo else default_omnigenome_repo) + "resolve/main/"
        try:
            # The timeout keeps an unreachable hub from blocking forever.
            response = requests.get(repo + "models_info.json", timeout=30)
            # Never cache an HTTP error body as if it were the model index.
            response.raise_for_status()
            models_info = response.json()
            with open(cache_file, "w", encoding="utf8") as f:
                json.dump(models_info, f)
        except Exception as e:
            fprint(
                "Fail to download models info from huggingface space, the error is: {}".format(
                    e
                )
            )
            with open(cache_file, "r", encoding="utf8") as f:
                models_info = json.load(f)

    if isinstance(keyword, str):
        return {key: info for key, info in models_info.items() if keyword in key}
    return models_info
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def query_pipelines_info(
    keyword: Union[list, str], repo: str = None, local_only: bool = False, **kwargs
) -> Dict[str, Any]:
    """
    Queries information about available pipelines from the hub.

    This function retrieves the pipeline index (``pipelines_info.json``) from
    the OmniGenome hub, or from the local copy in the current working
    directory. A successful remote fetch overwrites that local copy; if the
    remote fetch fails, the local copy is read instead.

    Args:
        keyword (Union[list, str]): A substring used to filter pipeline names.
            Only a string keyword filters the result; any other value (e.g. a
            list) returns the unfiltered index.
        repo (str, optional): The hub repository URL. If None, uses
            ``default_omnigenome_repo``. ``"resolve/main/"`` is appended to it.
        local_only (bool): Whether to read only the local
            ``pipelines_info.json``. Defaults to False.
        **kwargs: Additional keyword arguments (currently unused).

    Returns:
        Dict[str, Any]: Pipeline information keyed by pipeline name, restricted
        to names containing ``keyword`` when ``keyword`` is a string.

    Raises:
        FileNotFoundError: If the local ``pipelines_info.json`` is needed
            (local mode, or remote fetch failed) but does not exist.

    Example:
        >>> # Query all pipelines
        >>> pipelines = query_pipelines_info("")
        >>> print(len(pipelines))  # Number of available pipelines

        >>> # Query specific pipelines
        >>> pipelines = query_pipelines_info("classification")
        >>> print(pipelines.keys())  # Pipelines containing "classification"
    """
    cache_file = "./pipelines_info.json"
    if local_only:
        with open(cache_file, "r", encoding="utf8") as f:
            pipelines_info = json.load(f)
    else:
        repo = (repo if repo else default_omnigenome_repo) + "resolve/main/"
        try:
            # The timeout keeps an unreachable hub from blocking forever.
            response = requests.get(repo + "pipelines_info.json", timeout=30)
            # Never cache an HTTP error body as if it were the pipeline index.
            response.raise_for_status()
            pipelines_info = response.json()
            with open(cache_file, "w", encoding="utf8") as f:
                json.dump(pipelines_info, f)
        except Exception as e:
            fprint(
                "Fail to download pipelines info from huggingface space, the error is: {}".format(
                    e
                )
            )
            with open(cache_file, "r", encoding="utf8") as f:
                pipelines_info = json.load(f)

    if isinstance(keyword, str):
        return {key: info for key, info in pipelines_info.items() if keyword in key}
    return pipelines_info
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def query_benchmarks_info(
    keyword: Union[list, str], repo: str = None, local_only: bool = False, **kwargs
) -> Dict[str, Any]:
    """
    Queries information about available benchmarks from the hub.

    This function retrieves the benchmark index (``benchmarks_info.json``)
    from the OmniGenome hub, or from the local copy in the current working
    directory. A successful remote fetch overwrites that local copy; if the
    remote fetch fails, the local copy is read instead.

    Args:
        keyword (Union[list, str]): A substring used to filter benchmark names.
            Only a string keyword filters the result; any other value (e.g. a
            list) returns the unfiltered index.
        repo (str, optional): The hub repository URL. If None, uses
            ``default_omnigenome_repo``. ``"resolve/main/"`` is appended to it.
        local_only (bool): Whether to read only the local
            ``benchmarks_info.json``. Defaults to False.
        **kwargs: Additional keyword arguments (currently unused).

    Returns:
        Dict[str, Any]: Benchmark information keyed by benchmark name,
        restricted to names containing ``keyword`` when it is a string.

    Raises:
        FileNotFoundError: If the local ``benchmarks_info.json`` is needed
            (local mode, or remote fetch failed) but does not exist.

    Example:
        >>> # Query all benchmarks
        >>> benchmarks = query_benchmarks_info("")
        >>> print(len(benchmarks))  # Number of available benchmarks

        >>> # Query specific benchmarks
        >>> benchmarks = query_benchmarks_info("RGB")
        >>> print(benchmarks.keys())  # Benchmarks containing "RGB"
    """
    cache_file = "./benchmarks_info.json"
    if local_only:
        with open(cache_file, "r", encoding="utf8") as f:
            benchmarks_info = json.load(f)
    else:
        repo = (repo if repo else default_omnigenome_repo) + "resolve/main/"
        try:
            # The timeout keeps an unreachable hub from blocking forever.
            response = requests.get(repo + "benchmarks_info.json", timeout=30)
            # Never cache an HTTP error body as if it were the benchmark index.
            response.raise_for_status()
            benchmarks_info = response.json()
            with open(cache_file, "w", encoding="utf8") as f:
                json.dump(benchmarks_info, f)
        except Exception as e:
            # This function fetches benchmarks info, not datasets info.
            fprint(
                "Fail to download benchmarks info from huggingface space, the error is: {}".format(
                    e
                )
            )
            with open(cache_file, "r", encoding="utf8") as f:
                benchmarks_info = json.load(f)

    if isinstance(keyword, str):
        return {key: info for key, info in benchmarks_info.items() if keyword in key}
    return benchmarks_info
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def download_model(
    model_name_or_path: str, local_only: bool = False, repo: str = None, cache_dir=None
) -> str:
    """
    Downloads a model checkpoint from the OmniGenome hub.

    Checkpoints are downloaded into ``<cache_dir>/models/`` and extracted
    there. An already-extracted checkpoint for the requested model is returned
    without downloading again.

    Args:
        model_name_or_path (str): The model name as listed in the hub's
            ``models_info.json``.
        local_only (bool): If True, read the model index from the local
            ``./models_info.json`` and never download; only an already-cached
            checkpoint can be returned. Defaults to False.
        repo (str, optional): The hub repository URL. If None, uses
            ``default_omnigenome_repo``. ``"resolve/main/"`` is appended to it.
        cache_dir (str, optional): Root cache directory. If None, uses
            ``"__OMNIGENOME_DATA__"``; models are stored in its ``models/``
            subfolder.

    Returns:
        str: The path to the cached model directory (the directory holding
        ``config.json``, or the freshly extracted checkpoint directory).

    Raises:
        ConnectionError: If the model download fails, or if ``local_only`` is
            True and the checkpoint is not already cached.
        ValueError: If the model is not found in the model index.

    Example:
        >>> # Download a model
        >>> model_path = download_model("DNABERT-2")
        >>> print(model_path)  # Path to the downloaded model

        >>> # Download with custom cache directory
        >>> model_path = download_model("DNABERT-2", cache_dir="./models")
    """
    cache_dir = (cache_dir if cache_dir else "__OMNIGENOME_DATA__") + "/models/"
    os.makedirs(cache_dir, exist_ok=True)

    # Only reuse a cached checkpoint that belongs to the requested model.
    # Matching on "config.json" alone would return whichever model happened to
    # be cached first, regardless of the name asked for.
    ckpt_config = findfile.find_files(cache_dir, [model_name_or_path, "config.json"])
    if ckpt_config:
        return os.path.dirname(ckpt_config[0])

    if local_only:
        with open("./models_info.json", "r", encoding="utf8") as f:
            models_info = json.load(f)
    else:
        repo = (repo if repo else default_omnigenome_repo) + "resolve/main/"
        try:
            response = requests.get(repo + "models_info.json", timeout=30)
            # Never cache an HTTP error body as if it were the model index.
            response.raise_for_status()
            models_info = response.json()
            with open("./models_info.json", "w", encoding="utf8") as f:
                json.dump(models_info, f)
        except Exception as e:
            fprint(
                "Fail to download models info from huggingface space, the error is: {}".format(
                    e
                )
            )
            with open("./models_info.json", "r", encoding="utf8") as f:
                models_info = json.load(f)

    if model_name_or_path not in models_info:
        raise ValueError("Model not found in the repository.")

    filename = models_info[model_name_or_path]["filename"]
    cache_path = os.path.join(cache_dir, filename)

    # The checkpoint may already be extracted under its archive name, which
    # need not contain the model name probed above.
    if cache_path.lower().endswith(".zip"):
        extracted_dir = cache_path[: -len(".zip")]
    else:
        extracted_dir = cache_path
    if os.path.isdir(extracted_dir):
        cached = findfile.find_files(extracted_dir, ["config.json"])
        if cached:
            return os.path.dirname(cached[0])

    if local_only:
        # local_only forbids network access, and no repo URL was resolved.
        raise ConnectionError(
            "Model {} is not cached in {} and local_only=True.".format(
                model_name_or_path, cache_dir
            )
        )

    try:
        # ``repo`` already ends with "/", so no extra separator is needed.
        model_url = f"{repo}models/{filename}"
        # The context manager releases the streamed connection when done.
        with requests.get(model_url, stream=True, timeout=30) as response:
            response.raise_for_status()
            # content-length may be absent (e.g. chunked transfer encoding).
            total_bytes = int(response.headers.get("content-length", 0))
            with open(cache_path, "wb") as f:
                for chunk in tqdm.tqdm(
                    response.iter_content(chunk_size=1024 * 1024),
                    unit="MB",
                    total=total_bytes // 1024 // 1024 or None,
                    desc="Downloading model",
                ):
                    f.write(chunk)
    except Exception as e:
        raise ConnectionError("Fail to download model: {}".format(e)) from e

    return unzip_checkpoint(cache_path)
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def download_pipeline(
    pipeline_name_or_path: str,
    local_only: bool = False,
    repo: str = None,
    cache_dir=None,
) -> str:
    """
    Downloads a pipeline from the OmniGenome hub.

    Pipelines are downloaded into ``<cache_dir>/pipelines/`` and extracted
    there. An already-extracted pipeline for the requested name is returned
    without downloading again.

    Args:
        pipeline_name_or_path (str): The pipeline name as listed in the hub's
            ``pipelines_info.json``.
        local_only (bool): If True, read the pipeline index from the local
            ``./pipelines_info.json`` and never download; only an
            already-cached pipeline can be returned. Defaults to False.
        repo (str, optional): The hub repository URL. If None, uses
            ``default_omnigenome_repo``. ``"resolve/main/"`` is appended to it.
        cache_dir (str, optional): Root cache directory. If None, uses
            ``"__OMNIGENOME_DATA__"``; pipelines are stored in its
            ``pipelines/`` subfolder.

    Returns:
        str: The path to the cached pipeline directory (the directory holding
        ``config.json``, or the freshly extracted pipeline directory).

    Raises:
        ConnectionError: If the pipeline download fails, or if ``local_only``
            is True and the pipeline is not already cached.
        ValueError: If the pipeline is not found in the pipeline index.

    Example:
        >>> # Download a pipeline
        >>> pipeline_path = download_pipeline("classification_pipeline")
        >>> print(pipeline_path)  # Path to the downloaded pipeline
    """
    cache_dir = (cache_dir if cache_dir else "__OMNIGENOME_DATA__") + "/pipelines/"
    os.makedirs(cache_dir, exist_ok=True)

    # Only reuse a cached pipeline that belongs to the requested name.
    # Matching on "config.json" alone would return whichever pipeline happened
    # to be cached first.
    ckpt_config = findfile.find_files(cache_dir, [pipeline_name_or_path, "config.json"])
    if ckpt_config:
        return os.path.dirname(ckpt_config[0])

    if local_only:
        with open("./pipelines_info.json", "r", encoding="utf8") as f:
            pipelines_info = json.load(f)
    else:
        repo = (repo if repo else default_omnigenome_repo) + "resolve/main/"
        try:
            response = requests.get(repo + "pipelines_info.json", timeout=30)
            # Never cache an HTTP error body as if it were the pipeline index.
            response.raise_for_status()
            pipelines_info = response.json()
            with open("./pipelines_info.json", "w", encoding="utf8") as f:
                json.dump(pipelines_info, f)
        except Exception as e:
            fprint(
                "Fail to download pipelines info from huggingface space, the error is: {}".format(
                    e
                )
            )
            with open("./pipelines_info.json", "r", encoding="utf8") as f:
                pipelines_info = json.load(f)

    if pipeline_name_or_path not in pipelines_info:
        raise ValueError("Pipeline not found in the repository.")

    filename = pipelines_info[pipeline_name_or_path]["filename"]
    cache_path = os.path.join(cache_dir, filename)

    # The pipeline may already be extracted under its archive name, which
    # need not contain the pipeline name probed above.
    if cache_path.lower().endswith(".zip"):
        extracted_dir = cache_path[: -len(".zip")]
    else:
        extracted_dir = cache_path
    if os.path.isdir(extracted_dir):
        cached = findfile.find_files(extracted_dir, ["config.json"])
        if cached:
            return os.path.dirname(cached[0])

    if local_only:
        # local_only forbids network access, and no repo URL was resolved.
        raise ConnectionError(
            "Pipeline {} is not cached in {} and local_only=True.".format(
                pipeline_name_or_path, cache_dir
            )
        )

    try:
        # ``repo`` already ends with "/", so no extra separator is needed.
        pipeline_url = f"{repo}pipelines/{filename}"
        # The context manager releases the streamed connection when done.
        with requests.get(pipeline_url, stream=True, timeout=30) as response:
            response.raise_for_status()
            # content-length may be absent (e.g. chunked transfer encoding).
            total_bytes = int(response.headers.get("content-length", 0))
            with open(cache_path, "wb") as f:
                for chunk in tqdm.tqdm(
                    response.iter_content(chunk_size=1024 * 1024),
                    unit="MB",
                    total=total_bytes // 1024 // 1024 or None,
                    desc="Downloading pipeline",
                ):
                    f.write(chunk)
    except Exception as e:
        raise ConnectionError("Fail to download pipeline: {}".format(e)) from e

    return unzip_checkpoint(cache_path)
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def download_benchmark(
    benchmark_name_or_path: str,
    local_only: bool = False,
    repo: str = None,
    cache_dir=None,
) -> str:
    """
    Downloads a benchmark from the OmniGenome hub.

    Benchmarks are downloaded into ``<cache_dir>/benchmarks/`` and extracted
    there. If a ``metadata.py`` whose path contains the benchmark name is
    already cached, its directory is returned without downloading again.

    Args:
        benchmark_name_or_path (str): The benchmark name as listed in the hub's
            ``benchmarks_info.json``.
        local_only (bool): If True, read the benchmark index from the local
            ``./benchmarks_info.json`` and never download; only an
            already-cached benchmark can be returned. Defaults to False.
        repo (str, optional): The hub repository URL. If None, uses
            ``default_omnigenome_repo``. ``"resolve/main/"`` is appended to it.
        cache_dir (str, optional): Root cache directory. If None, uses
            ``"__OMNIGENOME_DATA__"``; benchmarks are stored in its
            ``benchmarks/`` subfolder.

    Returns:
        str: The path to the cached benchmark directory (the directory holding
        ``metadata.py``, or the freshly extracted benchmark directory).

    Raises:
        ConnectionError: If the benchmark download fails, or if ``local_only``
            is True and the benchmark is not already cached.
        ValueError: If the benchmark is not found in the benchmark index.

    Example:
        >>> # Download a benchmark
        >>> benchmark_path = download_benchmark("RGB")
        >>> print(benchmark_path)  # Path to the downloaded benchmark

        >>> # Download with custom cache directory
        >>> benchmark_path = download_benchmark("RGB", cache_dir="./benchmarks")
    """
    cache_dir = (cache_dir if cache_dir else "__OMNIGENOME_DATA__") + "/benchmarks/"
    os.makedirs(cache_dir, exist_ok=True)

    bench_config = findfile.find_file(
        cache_dir, [benchmark_name_or_path, "metadata.py"]
    )
    if bench_config:
        return os.path.dirname(bench_config)

    if local_only:
        with open("./benchmarks_info.json", "r", encoding="utf8") as f:
            benchmarks_info = json.load(f)
    else:
        repo = (repo if repo else default_omnigenome_repo) + "resolve/main/"
        try:
            response = requests.get(repo + "benchmarks_info.json", timeout=30)
            # Never cache an HTTP error body as if it were the benchmark index.
            response.raise_for_status()
            benchmarks_info = response.json()
            with open("./benchmarks_info.json", "w", encoding="utf8") as f:
                json.dump(benchmarks_info, f)
        except Exception as e:
            # This function fetches benchmarks info, not datasets info.
            fprint(
                "Fail to download benchmarks info from huggingface space, the error is: {}".format(
                    e
                )
            )
            with open("./benchmarks_info.json", "r", encoding="utf8") as f:
                benchmarks_info = json.load(f)

    if benchmark_name_or_path not in benchmarks_info:
        raise ValueError("Benchmark not found in the repository.")

    if local_only:
        # local_only forbids network access, and no repo URL was resolved.
        raise ConnectionError(
            "Benchmark {} is not cached in {} and local_only=True.".format(
                benchmark_name_or_path, cache_dir
            )
        )

    filename = benchmarks_info[benchmark_name_or_path]["filename"]
    cache_path = os.path.join(cache_dir, filename)
    try:
        # ``repo`` already ends with "/", so no extra separator is needed.
        benchmark_url = f"{repo}benchmarks/{filename}"
        # The context manager releases the streamed connection when done.
        with requests.get(benchmark_url, stream=True, timeout=30) as response:
            response.raise_for_status()
            # content-length may be absent (e.g. chunked transfer encoding).
            total_bytes = int(response.headers.get("content-length", 0))
            with open(cache_path, "wb") as f:
                for chunk in tqdm.tqdm(
                    response.iter_content(chunk_size=1024 * 1024),
                    unit="MB",
                    total=total_bytes // 1024 // 1024 or None,
                    desc="Downloading benchmark",
                ):
                    f.write(chunk)
    except Exception as e:
        raise ConnectionError("Fail to download benchmark: {}".format(e)) from e

    return unzip_checkpoint(cache_path)
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
def check_version(repo: str = None) -> None:
    """
    Compares the local OmniGenome version with the hub's published version.

    Fetches ``version.json`` from the hub and prints (via ``fprint``) a yellow
    warning when the local version is older or newer than the remote one, or
    a green confirmation when they match. This check is best-effort: any
    failure (network, HTTP status, malformed JSON, unparsable version) is
    reported as a red message instead of being raised.

    Args:
        repo (str, optional): The hub repository URL. If None, uses
            ``default_omnigenome_repo``. ``"resolve/main/"`` is appended to it.

    Example:
        >>> check_version()  # Check version compatibility
    """
    repo = (repo if repo else default_omnigenome_repo) + "resolve/main/"
    try:
        # The timeout keeps an unreachable hub from hanging the caller.
        response = requests.get(repo + "version.json", timeout=30)
        response.raise_for_status()
        remote_version = response.json()["version"]

        local_parsed = Version(current_version)
        remote_parsed = Version(remote_version)
        if local_parsed < remote_parsed:
            fprint(
                colored(
                    f"Warning: Your local OmniGenome version ({current_version}) "
                    f"is older than the remote version ({remote_version}). "
                    f"Please consider updating.",
                    "yellow",
                )
            )
        elif local_parsed > remote_parsed:
            fprint(
                colored(
                    f"Warning: Your local OmniGenome version ({current_version}) "
                    f"is newer than the remote version ({remote_version}). "
                    f"This might cause compatibility issues.",
                    "yellow",
                )
            )
        else:
            fprint(
                colored(
                    f"OmniGenome version ({current_version}) is up to date.",
                    "green",
                )
            )
    except Exception as e:
        # Deliberately best-effort: a failed version check must not break callers.
        fprint(
            colored(
                f"Failed to check version: {e}",
                "red",
            )
        )
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# file: __init__.py
|
|
3
|
+
# time: 18:27 11/04/2024
|
|
4
|
+
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
|
|
5
|
+
# github: https://github.com/yangheng95
|
|
6
|
+
# huggingface: https://huggingface.co/yangheng
|
|
7
|
+
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
|
|
8
|
+
# Copyright (C) 2019-2024. All Rights Reserved.
|
|
9
|
+
"""
|
|
10
|
+
This package contains modules for the model hub.
|
|
11
|
+
"""
|
|
12
|
+
|