omnigenome 0.3.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of omnigenome might be problematic. Click here for more details.

Files changed (85)
  1. omnigenome/__init__.py +281 -0
  2. omnigenome/auto/__init__.py +3 -0
  3. omnigenome/auto/auto_bench/__init__.py +12 -0
  4. omnigenome/auto/auto_bench/auto_bench.py +484 -0
  5. omnigenome/auto/auto_bench/auto_bench_cli.py +230 -0
  6. omnigenome/auto/auto_bench/auto_bench_config.py +216 -0
  7. omnigenome/auto/auto_bench/config_check.py +34 -0
  8. omnigenome/auto/auto_train/__init__.py +13 -0
  9. omnigenome/auto/auto_train/auto_train.py +430 -0
  10. omnigenome/auto/auto_train/auto_train_cli.py +222 -0
  11. omnigenome/auto/bench_hub/__init__.py +12 -0
  12. omnigenome/auto/bench_hub/bench_hub.py +25 -0
  13. omnigenome/cli/__init__.py +13 -0
  14. omnigenome/cli/commands/__init__.py +13 -0
  15. omnigenome/cli/commands/base.py +83 -0
  16. omnigenome/cli/commands/bench/__init__.py +13 -0
  17. omnigenome/cli/commands/bench/bench_cli.py +202 -0
  18. omnigenome/cli/commands/rna/__init__.py +13 -0
  19. omnigenome/cli/commands/rna/rna_design.py +178 -0
  20. omnigenome/cli/omnigenome_cli.py +128 -0
  21. omnigenome/src/__init__.py +12 -0
  22. omnigenome/src/abc/__init__.py +12 -0
  23. omnigenome/src/abc/abstract_dataset.py +622 -0
  24. omnigenome/src/abc/abstract_metric.py +114 -0
  25. omnigenome/src/abc/abstract_model.py +689 -0
  26. omnigenome/src/abc/abstract_tokenizer.py +267 -0
  27. omnigenome/src/dataset/__init__.py +16 -0
  28. omnigenome/src/dataset/omni_dataset.py +435 -0
  29. omnigenome/src/lora/__init__.py +13 -0
  30. omnigenome/src/lora/lora_model.py +294 -0
  31. omnigenome/src/metric/__init__.py +15 -0
  32. omnigenome/src/metric/classification_metric.py +184 -0
  33. omnigenome/src/metric/metric.py +199 -0
  34. omnigenome/src/metric/ranking_metric.py +142 -0
  35. omnigenome/src/metric/regression_metric.py +191 -0
  36. omnigenome/src/misc/__init__.py +3 -0
  37. omnigenome/src/misc/utils.py +439 -0
  38. omnigenome/src/model/__init__.py +19 -0
  39. omnigenome/src/model/augmentation/__init__.py +12 -0
  40. omnigenome/src/model/augmentation/model.py +219 -0
  41. omnigenome/src/model/classification/__init__.py +12 -0
  42. omnigenome/src/model/classification/model.py +642 -0
  43. omnigenome/src/model/embedding/__init__.py +12 -0
  44. omnigenome/src/model/embedding/model.py +263 -0
  45. omnigenome/src/model/mlm/__init__.py +12 -0
  46. omnigenome/src/model/mlm/model.py +177 -0
  47. omnigenome/src/model/module_utils.py +232 -0
  48. omnigenome/src/model/regression/__init__.py +12 -0
  49. omnigenome/src/model/regression/model.py +786 -0
  50. omnigenome/src/model/regression/resnet.py +483 -0
  51. omnigenome/src/model/rna_design/__init__.py +12 -0
  52. omnigenome/src/model/rna_design/model.py +426 -0
  53. omnigenome/src/model/seq2seq/__init__.py +12 -0
  54. omnigenome/src/model/seq2seq/model.py +44 -0
  55. omnigenome/src/tokenizer/__init__.py +16 -0
  56. omnigenome/src/tokenizer/bpe_tokenizer.py +226 -0
  57. omnigenome/src/tokenizer/kmers_tokenizer.py +247 -0
  58. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +249 -0
  59. omnigenome/src/trainer/__init__.py +14 -0
  60. omnigenome/src/trainer/accelerate_trainer.py +739 -0
  61. omnigenome/src/trainer/hf_trainer.py +75 -0
  62. omnigenome/src/trainer/trainer.py +579 -0
  63. omnigenome/utility/__init__.py +3 -0
  64. omnigenome/utility/dataset_hub/__init__.py +13 -0
  65. omnigenome/utility/dataset_hub/dataset_hub.py +178 -0
  66. omnigenome/utility/ensemble.py +324 -0
  67. omnigenome/utility/hub_utils.py +517 -0
  68. omnigenome/utility/model_hub/__init__.py +12 -0
  69. omnigenome/utility/model_hub/model_hub.py +231 -0
  70. omnigenome/utility/pipeline_hub/__init__.py +12 -0
  71. omnigenome/utility/pipeline_hub/pipeline.py +483 -0
  72. omnigenome/utility/pipeline_hub/pipeline_hub.py +129 -0
  73. omnigenome-0.3.0a0.dist-info/METADATA +224 -0
  74. omnigenome-0.3.0a0.dist-info/RECORD +85 -0
  75. omnigenome-0.3.0a0.dist-info/WHEEL +5 -0
  76. omnigenome-0.3.0a0.dist-info/entry_points.txt +3 -0
  77. omnigenome-0.3.0a0.dist-info/licenses/LICENSE +201 -0
  78. omnigenome-0.3.0a0.dist-info/top_level.txt +2 -0
  79. tests/__init__.py +9 -0
  80. tests/conftest.py +160 -0
  81. tests/test_dataset_patterns.py +291 -0
  82. tests/test_examples_syntax.py +83 -0
  83. tests/test_model_loading.py +183 -0
  84. tests/test_rna_functions.py +255 -0
  85. tests/test_training_patterns.py +302 -0
@@ -0,0 +1,517 @@
1
+ # -*- coding: utf-8 -*-
2
+ # file: hub_utils.py
3
+ # time: 16:54 13/04/2024
4
+ # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
5
+ # github: https://github.com/yangheng95
6
+ # huggingface: https://huggingface.co/yangheng
7
+ # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
8
+ # Copyright (C) 2019-2024. All Rights Reserved.
9
+
10
+ import json
11
+ import os
12
+ from typing import Union, Dict, Any
13
+
14
+ import findfile
15
+ import requests
16
+ import tqdm
17
+ from packaging.version import Version
18
+ from termcolor import colored
19
+
20
+ from omnigenome import __version__ as current_version
21
+ from omnigenome.src.misc.utils import fprint, default_omnigenome_repo
22
+
23
+
24
def unzip_checkpoint(checkpoint_path):
    """
    Unzips a checkpoint file.

    This function extracts a zipped checkpoint archive into a directory whose
    path is the archive path with its trailing ``.zip`` suffix removed, making
    it ready for use by the model loading functions.

    Args:
        checkpoint_path (str): The path to the checkpoint file.

    Returns:
        str: The path to the extracted checkpoint directory.

    Example:
        >>> extracted_path = unzip_checkpoint("model.zip")
        >>> print(extracted_path)  # "model"
    """
    import zipfile

    # Remove only the literal ".zip" suffix. ``str.strip(".zip")`` removes any
    # of the characters ".", "z", "i", "p" from BOTH ends, which mangles names
    # such as "pip.zip" (-> "") or "zip_model.zip" (-> "_model").
    suffix = ".zip"
    if checkpoint_path.endswith(suffix):
        extract_dir = checkpoint_path[: -len(suffix)]
    else:
        extract_dir = checkpoint_path

    with zipfile.ZipFile(checkpoint_path, "r") as zip_ref:
        zip_ref.extractall(extract_dir)

    return extract_dir
47
+
48
+
49
def query_models_info(
    keyword: Union[list, str],
    repo: str = None,
    local_only: bool = False,
    *,
    timeout: float = 30,
    **kwargs,
) -> Dict[str, Any]:
    """
    Queries information about available models from the hub.

    The models index is read from ``./models_info.json`` (relative to the
    current working directory) when ``local_only`` is True. Otherwise it is
    downloaded from ``repo`` and written to that file; if the download fails,
    the existing ``./models_info.json`` is used as a fallback.

    Args:
        keyword (Union[list, str]): If a string, only entries whose key contains
            it are returned (an empty string matches everything). Any non-string
            value (e.g. a list) currently disables filtering and the full index
            is returned.
        repo (str, optional): Base URL that serves ``models_info.json``. If None,
            a default Hugging Face space URL is used.
        local_only (bool): Whether to use only the local cache. Defaults to False.
        timeout (float): Seconds to wait on the HTTP request (requests' connect /
            read timeout). Defaults to 30.
        **kwargs: Additional keyword arguments (ignored).

    Returns:
        Dict[str, Any]: Model information, filtered by ``keyword`` as described.

    Raises:
        FileNotFoundError: If the local cache must be read and does not exist.

    Example:
        >>> models = query_models_info("")     # all models
        >>> models = query_models_info("DNA")  # keys containing "DNA"
    """
    cache_file = "./models_info.json"
    if local_only:
        with open(cache_file, "r", encoding="utf8") as f:
            models_info = json.load(f)
    else:
        # NOTE(review): unlike the sibling query/download helpers, this endpoint
        # is not built from ``default_omnigenome_repo + "resolve/main/"`` —
        # confirm which URL actually serves models_info.json.
        repo = (repo if repo else "https://huggingface.co/spaces/anonymous8/gfm_hub/")
        repo = repo.rstrip("/") + "/"
        try:
            response = requests.get(repo + "models_info.json", timeout=timeout)
            # Reject HTTP errors so an error body is never parsed or cached.
            response.raise_for_status()
            models_info = response.json()
            with open(cache_file, "w", encoding="utf8") as f:
                json.dump(models_info, f)
        except Exception as e:
            # Deliberate best-effort: fall back to the last cached index.
            fprint(
                "Fail to download models info from huggingface space, the error is: {}".format(
                    e
                )
            )
            with open(cache_file, "r", encoding="utf8") as f:
                models_info = json.load(f)

    if isinstance(keyword, str):
        return {key: value for key, value in models_info.items() if keyword in key}
    return models_info
104
+
105
+
106
def query_pipelines_info(
    keyword: Union[list, str],
    repo: str = None,
    local_only: bool = False,
    *,
    timeout: float = 30,
    **kwargs,
) -> Dict[str, Any]:
    """
    Queries information about available pipelines from the hub.

    The pipelines index is read from ``./pipelines_info.json`` (relative to the
    current working directory) when ``local_only`` is True. Otherwise it is
    downloaded from ``<repo>/resolve/main/pipelines_info.json`` and written to
    that file; if the download fails, the existing file is used as a fallback.

    Args:
        keyword (Union[list, str]): If a string, only entries whose key contains
            it are returned (an empty string matches everything). Any non-string
            value (e.g. a list) currently disables filtering and the full index
            is returned.
        repo (str, optional): Hub repository URL. If None, uses the default hub.
        local_only (bool): Whether to use only the local cache. Defaults to False.
        timeout (float): Seconds to wait on the HTTP request (requests' connect /
            read timeout). Defaults to 30.
        **kwargs: Additional keyword arguments (ignored).

    Returns:
        Dict[str, Any]: Pipeline information, filtered by ``keyword`` as described.

    Raises:
        FileNotFoundError: If the local cache must be read and does not exist.

    Example:
        >>> pipelines = query_pipelines_info("")                # all pipelines
        >>> pipelines = query_pipelines_info("classification")  # filtered
    """
    cache_file = "./pipelines_info.json"
    if local_only:
        with open(cache_file, "r", encoding="utf8") as f:
            pipelines_info = json.load(f)
    else:
        # Normalise the trailing slash so "…/repo" and "…/repo/" both work.
        repo = (repo if repo else default_omnigenome_repo).rstrip("/") + "/resolve/main/"
        try:
            response = requests.get(repo + "pipelines_info.json", timeout=timeout)
            # Reject HTTP errors so an error body is never parsed or cached.
            response.raise_for_status()
            pipelines_info = response.json()
            with open(cache_file, "w", encoding="utf8") as f:
                json.dump(pipelines_info, f)
        except Exception as e:
            # Deliberate best-effort: fall back to the last cached index.
            fprint(
                "Fail to download pipelines info from huggingface space, the error is: {}".format(
                    e
                )
            )
            with open(cache_file, "r", encoding="utf8") as f:
                pipelines_info = json.load(f)

    if isinstance(keyword, str):
        return {key: value for key, value in pipelines_info.items() if keyword in key}
    return pipelines_info
161
+
162
+
163
def query_benchmarks_info(
    keyword: Union[list, str],
    repo: str = None,
    local_only: bool = False,
    *,
    timeout: float = 30,
    **kwargs,
) -> Dict[str, Any]:
    """
    Queries information about available benchmarks from the hub.

    The benchmarks index is read from ``./benchmarks_info.json`` (relative to
    the current working directory) when ``local_only`` is True. Otherwise it is
    downloaded from ``<repo>/resolve/main/benchmarks_info.json`` and written to
    that file; if the download fails, the existing file is used as a fallback.

    Args:
        keyword (Union[list, str]): If a string, only entries whose key contains
            it are returned (an empty string matches everything). Any non-string
            value (e.g. a list) currently disables filtering and the full index
            is returned.
        repo (str, optional): Hub repository URL. If None, uses the default hub.
        local_only (bool): Whether to use only the local cache. Defaults to False.
        timeout (float): Seconds to wait on the HTTP request (requests' connect /
            read timeout). Defaults to 30.
        **kwargs: Additional keyword arguments (ignored).

    Returns:
        Dict[str, Any]: Benchmark information, filtered by ``keyword`` as described.

    Raises:
        FileNotFoundError: If the local cache must be read and does not exist.

    Example:
        >>> benchmarks = query_benchmarks_info("")     # all benchmarks
        >>> benchmarks = query_benchmarks_info("RGB")  # keys containing "RGB"
    """
    cache_file = "./benchmarks_info.json"
    if local_only:
        with open(cache_file, "r", encoding="utf8") as f:
            benchmarks_info = json.load(f)
    else:
        # Normalise the trailing slash so "…/repo" and "…/repo/" both work.
        repo = (repo if repo else default_omnigenome_repo).rstrip("/") + "/resolve/main/"
        try:
            response = requests.get(repo + "benchmarks_info.json", timeout=timeout)
            # Reject HTTP errors so an error body is never parsed or cached.
            response.raise_for_status()
            benchmarks_info = response.json()
            with open(cache_file, "w", encoding="utf8") as f:
                json.dump(benchmarks_info, f)
        except Exception as e:
            # Deliberate best-effort: fall back to the last cached index.
            fprint(
                "Fail to download benchmarks info from huggingface space, the error is: {}".format(
                    e
                )
            )
            with open(cache_file, "r", encoding="utf8") as f:
                benchmarks_info = json.load(f)

    if isinstance(keyword, str):
        return {key: value for key, value in benchmarks_info.items() if keyword in key}
    return benchmarks_info
218
+
219
+
220
def download_model(
    model_name_or_path: str,
    local_only: bool = False,
    repo: str = None,
    cache_dir=None,
    timeout: float = 30,
) -> str:
    """
    Downloads a model from the OmniGenome hub.

    The function first looks under ``cache_dir`` for an already extracted
    checkpoint of the requested model and returns it if found. Otherwise it
    loads the models index (``./models_info.json``). With ``local_only=False``
    the index is downloaded from the repository and the local copy refreshed;
    with ``local_only=True`` only the local copy is read. It then downloads
    the model archive from the repository and extracts it.

    Args:
        model_name_or_path (str): The key of the model in the models index.
        local_only (bool): Read the models index only from the local
            ``./models_info.json``. Note that the archive itself is still
            fetched from ``repo`` if it is not already cached. Defaults to False.
        repo (str, optional): Hub repository URL. If None, uses the default hub.
        cache_dir (str, optional): Base directory for cached models. Models are
            stored under "<cache_dir>/models/". If None, "__OMNIGENOME_DATA__"
            is used.
        timeout (float): Seconds to wait on each HTTP request (requests'
            connect / read timeout). Defaults to 30.

    Returns:
        str: Path to the extracted model directory.

    Raises:
        ConnectionError: If the model archive download fails.
        ValueError: If the model is not found in the models index.
        FileNotFoundError: If the index must be read from ``./models_info.json``
            and that file does not exist.

    Example:
        >>> model_path = download_model("DNABERT-2")
        >>> model_path = download_model("DNABERT-2", cache_dir="./models")
    """
    cache_dir = (cache_dir if cache_dir else "__OMNIGENOME_DATA__") + "/models/"
    os.makedirs(cache_dir, exist_ok=True)

    # Only reuse a cached checkpoint whose path mentions the requested model.
    # Matching on "config.json" alone returned whichever model was cached first.
    ckpt_config = findfile.find_files(cache_dir, [model_name_or_path, "config.json"])
    if ckpt_config:
        return os.path.dirname(ckpt_config[0])

    # Resolve the base URL before branching. The archive download below needs it
    # even when the index comes from the local cache (it used to stay None there).
    repo = (repo if repo else default_omnigenome_repo).rstrip("/") + "/resolve/main/"

    if local_only:
        with open("./models_info.json", "r", encoding="utf8") as f:
            models_info = json.load(f)
    else:
        try:
            response = requests.get(repo + "models_info.json", timeout=timeout)
            # Reject HTTP errors so an error body is never parsed or cached.
            response.raise_for_status()
            models_info = response.json()
            with open("./models_info.json", "w", encoding="utf8") as f:
                json.dump(models_info, f)
        except Exception as e:
            # Deliberate best-effort: fall back to the last cached index.
            fprint(
                "Fail to download models info from huggingface space, the error is: {}".format(
                    e
                )
            )
            with open("./models_info.json", "r", encoding="utf8") as f:
                models_info = json.load(f)

    if model_name_or_path not in models_info:
        raise ValueError("Model not found in the repository.")

    try:
        filename = models_info[model_name_or_path]["filename"]
        cache_path = os.path.join(cache_dir, filename)

        # The archive may already be extracted into a directory whose name does
        # not contain the model key, which the lookup above would miss.
        extracted_dir = os.path.splitext(cache_path)[0]
        if os.path.isfile(os.path.join(extracted_dir, "config.json")):
            return extracted_dir

        # ``repo`` already ends with "/", so no extra separator is inserted.
        response = requests.get(f"{repo}models/{filename}", stream=True, timeout=timeout)
        response.raise_for_status()
        # content-length may be absent (e.g. chunked transfer), so treat it as unknown.
        total_bytes = int(response.headers.get("content-length", 0)) or None
        with open(cache_path, "wb") as f, tqdm.tqdm(
            total=total_bytes, unit="B", unit_scale=True, desc="Downloading model"
        ) as progress:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
                progress.update(len(chunk))
    except Exception as e:
        raise ConnectionError("Fail to download model: {}".format(e)) from e

    return unzip_checkpoint(cache_path)
299
+
300
+
301
def download_pipeline(
    pipeline_name_or_path: str,
    local_only: bool = False,
    repo: str = None,
    cache_dir=None,
    timeout: float = 30,
) -> str:
    """
    Downloads a pipeline from the OmniGenome hub.

    The function first looks under ``cache_dir`` for an already extracted
    checkpoint of the requested pipeline and returns it if found. Otherwise it
    loads the pipelines index (``./pipelines_info.json``). With
    ``local_only=False`` the index is downloaded from the repository and the
    local copy refreshed; with ``local_only=True`` only the local copy is read.
    It then downloads the pipeline archive from the repository and extracts it.

    Args:
        pipeline_name_or_path (str): The key of the pipeline in the pipelines index.
        local_only (bool): Read the pipelines index only from the local
            ``./pipelines_info.json``. Note that the archive itself is still
            fetched from ``repo`` if it is not already cached. Defaults to False.
        repo (str, optional): Hub repository URL. If None, uses the default hub.
        cache_dir (str, optional): Base directory for cached pipelines. Pipelines
            are stored under "<cache_dir>/pipelines/". If None,
            "__OMNIGENOME_DATA__" is used.
        timeout (float): Seconds to wait on each HTTP request (requests'
            connect / read timeout). Defaults to 30.

    Returns:
        str: Path to the extracted pipeline directory.

    Raises:
        ConnectionError: If the pipeline archive download fails.
        ValueError: If the pipeline is not found in the pipelines index.
        FileNotFoundError: If the index must be read from
            ``./pipelines_info.json`` and that file does not exist.

    Example:
        >>> pipeline_path = download_pipeline("classification_pipeline")
    """
    cache_dir = (cache_dir if cache_dir else "__OMNIGENOME_DATA__") + "/pipelines/"
    os.makedirs(cache_dir, exist_ok=True)

    # Only reuse a cached checkpoint whose path mentions the requested pipeline.
    # Matching on "config.json" alone returned whichever pipeline was cached first.
    ckpt_config = findfile.find_files(cache_dir, [pipeline_name_or_path, "config.json"])
    if ckpt_config:
        return os.path.dirname(ckpt_config[0])

    # Resolve the base URL before branching. The archive download below needs it
    # even when the index comes from the local cache (it used to stay None there).
    repo = (repo if repo else default_omnigenome_repo).rstrip("/") + "/resolve/main/"

    if local_only:
        with open("./pipelines_info.json", "r", encoding="utf8") as f:
            pipelines_info = json.load(f)
    else:
        try:
            response = requests.get(repo + "pipelines_info.json", timeout=timeout)
            # Reject HTTP errors so an error body is never parsed or cached.
            response.raise_for_status()
            pipelines_info = response.json()
            with open("./pipelines_info.json", "w", encoding="utf8") as f:
                json.dump(pipelines_info, f)
        except Exception as e:
            # Deliberate best-effort: fall back to the last cached index.
            fprint(
                "Fail to download pipelines info from huggingface space, the error is: {}".format(
                    e
                )
            )
            with open("./pipelines_info.json", "r", encoding="utf8") as f:
                pipelines_info = json.load(f)

    if pipeline_name_or_path not in pipelines_info:
        raise ValueError("Pipeline not found in the repository.")

    try:
        filename = pipelines_info[pipeline_name_or_path]["filename"]
        cache_path = os.path.join(cache_dir, filename)

        # The archive may already be extracted into a directory whose name does
        # not contain the pipeline key, which the lookup above would miss.
        extracted_dir = os.path.splitext(cache_path)[0]
        if os.path.isfile(os.path.join(extracted_dir, "config.json")):
            return extracted_dir

        # ``repo`` already ends with "/", so no extra separator is inserted.
        response = requests.get(f"{repo}pipelines/{filename}", stream=True, timeout=timeout)
        response.raise_for_status()
        # content-length may be absent (e.g. chunked transfer), so treat it as unknown.
        total_bytes = int(response.headers.get("content-length", 0)) or None
        with open(cache_path, "wb") as f, tqdm.tqdm(
            total=total_bytes, unit="B", unit_scale=True, desc="Downloading pipeline"
        ) as progress:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
                progress.update(len(chunk))
    except Exception as e:
        raise ConnectionError("Fail to download pipeline: {}".format(e)) from e

    return unzip_checkpoint(cache_path)
380
+
381
+
382
def download_benchmark(
    benchmark_name_or_path: str,
    local_only: bool = False,
    repo: str = None,
    cache_dir=None,
    timeout: float = 30,
) -> str:
    """
    Downloads a benchmark from the OmniGenome hub.

    The function first looks under ``cache_dir`` for a ``metadata.py`` whose
    path mentions the requested benchmark, and returns its directory if found.
    Otherwise it loads the benchmarks index (``./benchmarks_info.json``). With
    ``local_only=False`` the index is downloaded from the repository and the
    local copy refreshed; with ``local_only=True`` only the local copy is read.
    It then downloads the benchmark archive from the repository and extracts it.

    Args:
        benchmark_name_or_path (str): The key of the benchmark in the benchmarks index.
        local_only (bool): Read the benchmarks index only from the local
            ``./benchmarks_info.json``. Note that the archive itself is still
            fetched from ``repo`` if it is not already cached. Defaults to False.
        repo (str, optional): Hub repository URL. If None, uses the default hub.
        cache_dir (str, optional): Base directory for cached benchmarks.
            Benchmarks are stored under "<cache_dir>/benchmarks/". If None,
            "__OMNIGENOME_DATA__" is used.
        timeout (float): Seconds to wait on each HTTP request (requests'
            connect / read timeout). Defaults to 30.

    Returns:
        str: Path to the extracted benchmark directory.

    Raises:
        ConnectionError: If the benchmark archive download fails.
        ValueError: If the benchmark is not found in the benchmarks index.
        FileNotFoundError: If the index must be read from
            ``./benchmarks_info.json`` and that file does not exist.

    Example:
        >>> benchmark_path = download_benchmark("RGB")
        >>> benchmark_path = download_benchmark("RGB", cache_dir="./benchmarks")
    """
    cache_dir = (cache_dir if cache_dir else "__OMNIGENOME_DATA__") + "/benchmarks/"
    os.makedirs(cache_dir, exist_ok=True)

    bench_config = findfile.find_file(cache_dir, [benchmark_name_or_path, "metadata.py"])
    if bench_config:
        return os.path.dirname(bench_config)

    # Resolve the base URL before branching. The archive download below needs it
    # even when the index comes from the local cache (it used to stay None there).
    repo = (repo if repo else default_omnigenome_repo).rstrip("/") + "/resolve/main/"

    if local_only:
        with open("./benchmarks_info.json", "r", encoding="utf8") as f:
            benchmarks_info = json.load(f)
    else:
        try:
            response = requests.get(repo + "benchmarks_info.json", timeout=timeout)
            # Reject HTTP errors so an error body is never parsed or cached.
            response.raise_for_status()
            benchmarks_info = response.json()
            with open("./benchmarks_info.json", "w", encoding="utf8") as f:
                json.dump(benchmarks_info, f)
        except Exception as e:
            # Deliberate best-effort: fall back to the last cached index.
            fprint(
                "Fail to download benchmarks info from huggingface space, the error is: {}".format(
                    e
                )
            )
            with open("./benchmarks_info.json", "r", encoding="utf8") as f:
                benchmarks_info = json.load(f)

    if benchmark_name_or_path not in benchmarks_info:
        raise ValueError("Benchmark not found in the repository.")

    try:
        filename = benchmarks_info[benchmark_name_or_path]["filename"]
        cache_path = os.path.join(cache_dir, filename)
        # ``repo`` already ends with "/", so no extra separator is inserted.
        response = requests.get(f"{repo}benchmarks/{filename}", stream=True, timeout=timeout)
        response.raise_for_status()
        # content-length may be absent (e.g. chunked transfer), so treat it as unknown.
        total_bytes = int(response.headers.get("content-length", 0)) or None
        with open(cache_path, "wb") as f, tqdm.tqdm(
            total=total_bytes, unit="B", unit_scale=True, desc="Downloading benchmark"
        ) as progress:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
                progress.update(len(chunk))
    except Exception as e:
        raise ConnectionError("Fail to download benchmark: {}".format(e)) from e

    return unzip_checkpoint(cache_path)
466
+
467
+
468
def check_version(repo: str = None, timeout: float = 10) -> None:
    """
    Compares the installed OmniGenome version with the one published on the hub.

    Fetches ``<repo>/resolve/main/version.json`` and reads its ``"version"``
    field. It then prints a yellow warning if the local version is older or
    newer than that value, and a green confirmation if they are equal. Any
    failure (network, HTTP error, malformed JSON, unparsable version) is
    reported in red and never raised. This is a deliberate best-effort check.

    Args:
        repo (str, optional): Hub repository URL. If None, uses the default hub.
        timeout (float): Seconds to wait on the HTTP request (requests' connect /
            read timeout), so a slow network cannot block the caller
            indefinitely. Defaults to 10.

    Example:
        >>> check_version()  # Check version compatibility
    """
    base_url = (repo if repo else default_omnigenome_repo).rstrip("/") + "/resolve/main/"
    try:
        response = requests.get(base_url + "version.json", timeout=timeout)
        response.raise_for_status()
        remote_version = response.json()["version"]

        # Parse once and reuse for both comparisons.
        local_v, remote_v = Version(current_version), Version(remote_version)
        if local_v < remote_v:
            fprint(
                colored(
                    f"Warning: Your local OmniGenome version ({current_version}) "
                    f"is older than the remote version ({remote_version}). "
                    f"Please consider updating.",
                    "yellow",
                )
            )
        elif local_v > remote_v:
            fprint(
                colored(
                    f"Warning: Your local OmniGenome version ({current_version}) "
                    f"is newer than the remote version ({remote_version}). "
                    f"This might cause compatibility issues.",
                    "yellow",
                )
            )
        else:
            fprint(
                colored(
                    f"OmniGenome version ({current_version}) is up to date.",
                    "green",
                )
            )
    except Exception as e:
        fprint(
            colored(
                f"Failed to check version: {e}",
                "red",
            )
        )
@@ -0,0 +1,12 @@
1
+ # -*- coding: utf-8 -*-
2
+ # file: __init__.py
3
+ # time: 18:27 11/04/2024
4
+ # author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
5
+ # github: https://github.com/yangheng95
6
+ # huggingface: https://huggingface.co/yangheng
7
+ # google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
8
+ # Copyright (C) 2019-2024. All Rights Reserved.
9
+ """
10
+ This package contains modules for the model hub.
11
+ """
12
+