omnigenome 0.3.0a0__py3-none-any.whl → 0.3.1a0__py3-none-any.whl

This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
Files changed (73)
  1. omnigenome/__init__.py +29 -44
  2. omnigenome/auto/auto_bench/__init__.py +0 -1
  3. omnigenome/auto/auto_bench/auto_bench.py +24 -14
  4. omnigenome/auto/auto_train/__init__.py +0 -1
  5. omnigenome/auto/auto_train/auto_train.py +11 -12
  6. omnigenome/auto/bench_hub/__init__.py +0 -1
  7. omnigenome/auto/bench_hub/bench_hub.py +1 -1
  8. omnigenome/cli/__init__.py +0 -1
  9. omnigenome/cli/commands/__init__.py +0 -1
  10. omnigenome/cli/commands/base.py +10 -10
  11. omnigenome/cli/commands/bench/__init__.py +0 -1
  12. omnigenome/cli/commands/bench/bench_cli.py +10 -10
  13. omnigenome/cli/commands/rna/__init__.py +0 -1
  14. omnigenome/cli/commands/rna/rna_design.py +10 -11
  15. omnigenome/src/__init__.py +0 -1
  16. omnigenome/src/abc/__init__.py +0 -1
  17. omnigenome/src/abc/abstract_dataset.py +38 -19
  18. omnigenome/src/abc/abstract_metric.py +7 -7
  19. omnigenome/src/abc/abstract_model.py +15 -14
  20. omnigenome/src/abc/abstract_tokenizer.py +9 -7
  21. omnigenome/src/dataset/omni_dataset.py +16 -14
  22. omnigenome/src/lora/__init__.py +0 -1
  23. omnigenome/src/lora/lora_model.py +47 -41
  24. omnigenome/src/metric/classification_metric.py +11 -11
  25. omnigenome/src/metric/metric.py +19 -19
  26. omnigenome/src/metric/ranking_metric.py +15 -15
  27. omnigenome/src/metric/regression_metric.py +18 -18
  28. omnigenome/src/misc/utils.py +214 -150
  29. omnigenome/src/model/augmentation/__init__.py +0 -1
  30. omnigenome/src/model/augmentation/model.py +17 -17
  31. omnigenome/src/model/classification/__init__.py +0 -1
  32. omnigenome/src/model/classification/model.py +28 -32
  33. omnigenome/src/model/embedding/__init__.py +0 -1
  34. omnigenome/src/model/embedding/model.py +35 -35
  35. omnigenome/src/model/mlm/__init__.py +0 -1
  36. omnigenome/src/model/mlm/model.py +13 -13
  37. omnigenome/src/model/module_utils.py +17 -17
  38. omnigenome/src/model/regression/__init__.py +0 -1
  39. omnigenome/src/model/regression/model.py +72 -77
  40. omnigenome/src/model/regression/resnet.py +32 -32
  41. omnigenome/src/model/rna_design/__init__.py +0 -1
  42. omnigenome/src/model/rna_design/model.py +168 -118
  43. omnigenome/src/model/seq2seq/__init__.py +0 -1
  44. omnigenome/src/model/seq2seq/model.py +4 -4
  45. omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
  46. omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
  47. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
  48. omnigenome/src/trainer/accelerate_trainer.py +40 -32
  49. omnigenome/src/trainer/hf_trainer.py +8 -8
  50. omnigenome/src/trainer/trainer.py +37 -25
  51. omnigenome/utility/dataset_hub/__init__.py +0 -1
  52. omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
  53. omnigenome/utility/ensemble.py +26 -26
  54. omnigenome/utility/hub_utils.py +8 -8
  55. omnigenome/utility/model_hub/__init__.py +0 -1
  56. omnigenome/utility/model_hub/model_hub.py +26 -25
  57. omnigenome/utility/pipeline_hub/__init__.py +0 -1
  58. omnigenome/utility/pipeline_hub/pipeline.py +49 -49
  59. omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
  60. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +3 -3
  61. omnigenome-0.3.1a0.dist-info/RECORD +78 -0
  62. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -1
  63. omnigenome-0.3.0a0.dist-info/RECORD +0 -85
  64. tests/__init__.py +0 -9
  65. tests/conftest.py +0 -160
  66. tests/test_dataset_patterns.py +0 -291
  67. tests/test_examples_syntax.py +0 -83
  68. tests/test_model_loading.py +0 -183
  69. tests/test_rna_functions.py +0 -255
  70. tests/test_training_patterns.py +0 -302
  71. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
  72. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
  73. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0

omnigenome/auto/auto_bench/auto_bench.py

@@ -32,11 +32,11 @@ def load_benchmark_datasets(
 ):
     """
     Load benchmark datasets from the OmniGenome hub.
-
+
     This function automatically downloads benchmark datasets if they don't exist locally,
     loads their configurations, and initializes train/validation/test datasets with
     the specified tokenizer.
-
+
     Args:
         benchmark (str): Name or path of the benchmark to load. If the benchmark
             doesn't exist locally, it will be downloaded from the hub.
@@ -46,17 +46,17 @@ def load_benchmark_datasets(
            be loaded from the benchmark configuration.
        **kwargs: Additional keyword arguments to override benchmark configuration.
            These will be passed to the dataset classes and tokenizer initialization.
-
+
    Returns:
        dict: Dictionary containing datasets for each benchmark task, with keys
            being benchmark names and values being dictionaries with 'train',
            'valid', and 'test' datasets.
-
+
    Raises:
        FileNotFoundError: If the benchmark cannot be found or downloaded.
        ValueError: If the benchmark configuration is invalid.
        ImportError: If required dependencies are not available.
-
+
    Example:
        >>> from omnigenome import OmniSingleNucleotideTokenizer
        >>> tokenizer = OmniSingleNucleotideTokenizer.from_pretrained("model_name")
@@ -64,7 +64,7 @@ def load_benchmark_datasets(
        >>> print(f"Loaded {len(datasets)} benchmark tasks")
        >>> for task_name, task_datasets in datasets.items():
        ...     print(f"{task_name}: {len(task_datasets['train'])} train samples")
-
+
    Note:
        - The function automatically handles U/T conversion and other preprocessing
          based on the benchmark configuration.
@@ -80,7 +80,7 @@ def load_benchmark_datasets(
            "does not exist. Search online for available benchmarks.",
        )
        benchmark = download_benchmark(benchmark)
-
+
    # Import benchmark list
    bench_metadata = load_module_from_path(
        f"bench_metadata", f"{benchmark}/metadata.py"
@@ -107,9 +107,7 @@ def load_benchmark_datasets(
 
    for key, value in _kwargs.items():
        if key in bench_config:
-            fprint(
-                "Override", key, "with", value, "according to the input kwargs"
-            )
+            fprint("Override", key, "with", value, "according to the input kwargs")
            bench_config.update({key: value})
 
        else:
@@ -170,9 +168,11 @@ def load_benchmark_datasets(
            "valid": valid_set,
        }
 
-        fprint(f"Loaded dataset for {bench} with {len(train_set)} train samples, "
-               f"{len(test_set)} test samples and {len(valid_set)} valid samples.")
+        fprint(
+            f"Loaded dataset for {bench} with {len(train_set)} train samples, "
+            f"{len(test_set)} test samples and {len(valid_set)} valid samples."
+        )
 
        datasets[bench] = dataset
 
-    return datasets
+    return datasets
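
Read together, these hunks document the load_benchmark_datasets API: a benchmark name, an optional tokenizer, and configuration overrides via **kwargs. As a minimal sketch only, assuming a tokenizer keyword argument and using placeholder values for the benchmark ("RGB"), the model ID, and the max_length override:

    # Sketch assembled from the docstring fragments above; the benchmark name,
    # model ID, and max_length override are illustrative assumptions.
    from omnigenome import OmniSingleNucleotideTokenizer
    from omnigenome.auto.auto_bench.auto_bench import load_benchmark_datasets

    tokenizer = OmniSingleNucleotideTokenizer.from_pretrained("yangheng/OmniGenome-186M")
    datasets = load_benchmark_datasets("RGB", tokenizer=tokenizer, max_length=512)

    print(f"Loaded {len(datasets)} benchmark tasks")
    for task_name, task_datasets in datasets.items():
        print(f"{task_name}: {len(task_datasets['train'])} train samples")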

omnigenome/utility/ensemble.py

@@ -14,11 +14,11 @@ import numpy as np
 class VoteEnsemblePredictor:
     """
     An ensemble predictor that combines predictions from multiple models using voting.
-
+
     This class implements ensemble methods for combining predictions from multiple
     models or checkpoints. It supports both weighted and unweighted voting, and
     provides various aggregation methods for different data types (numeric and string).
-
+
     Attributes:
         checkpoints: List of checkpoint names
         predictors: Dictionary of initialized predictors
@@ -27,7 +27,7 @@ class VoteEnsemblePredictor:
        str_agg: Function for aggregating string predictions
        numeric_agg_methods: Dictionary of available numeric aggregation methods
        str_agg_methods: Dictionary of available string aggregation methods
-
+
    Example:
        >>> from omnigenome.utility import VoteEnsemblePredictor
        >>> predictors = {
@@ -51,14 +51,14 @@ class VoteEnsemblePredictor:
    ):
        """
        Initialize the VoteEnsemblePredictor.
-
+
        Args:
            predictors (List or dict): A list of checkpoints, or a dictionary of initialized predictors
            weights (List or dict, optional): A list of weights for each predictor, or a dictionary of weights for each predictor
            numeric_agg (str, optional): The aggregation method for numeric data. Options are 'average', 'mean', 'max', 'min',
                'median', 'mode', and 'sum'. Defaults to 'average'
            str_agg (str, optional): The aggregation method for string data. Options are 'max_vote', 'min_vote', 'vote', and 'mode'. Defaults to 'max_vote'
-
+
        Raises:
            AssertionError: If predictors and weights have different lengths or types
            AssertionError: If predictors list is empty
@@ -113,13 +113,13 @@ class VoteEnsemblePredictor:
    def numeric_agg(self, result: list):
        """
        Aggregate a list of numeric values.
-
+
        Args:
            result (list): A list of numeric values to aggregate
-
+
        Returns:
            The aggregated value using the specified numeric aggregation method
-
+
        Example:
            >>> ensemble = VoteEnsemblePredictor(predictors, numeric_agg="average")
            >>> result = ensemble.numeric_agg([0.8, 0.9, 0.7])
@@ -132,13 +132,13 @@ class VoteEnsemblePredictor:
    def __ensemble(self, result: dict):
        """
        Aggregate prediction results by calling the appropriate aggregation method.
-
+
        This method determines the type of result and calls the appropriate
        aggregation method (numeric or string).
-
+
        Args:
            result (dict): A dictionary containing the prediction results
-
+
        Returns:
            The aggregated prediction result
        """
@@ -152,13 +152,13 @@ class VoteEnsemblePredictor:
    def __dict_aggregate(self, result: dict):
        """
        Recursively aggregate a dictionary of prediction results.
-
+
        This method recursively processes nested dictionaries and applies
        appropriate aggregation methods to each level.
-
+
        Args:
            result (dict): A dictionary containing the prediction results
-
+
        Returns:
            dict: The aggregated prediction result
        """
@@ -175,16 +175,16 @@ class VoteEnsemblePredictor:
    def __list_aggregate(self, result: list):
        """
        Aggregate a list of prediction results.
-
+
        This method handles different types of list elements and applies
        appropriate aggregation methods based on the data type.
-
+
        Args:
            result (list): A list of prediction results to aggregate
-
+
        Returns:
            The aggregated result
-
+
        Raises:
            AssertionError: If all elements in the list are not of the same type
        """
@@ -227,18 +227,18 @@ class VoteEnsemblePredictor:
    def predict(self, text, ignore_error=False, print_result=False):
        """
        Predicts on a single text and returns the ensemble result.
-
+
        This method combines predictions from all predictors in the ensemble
        using the specified weights and aggregation methods.
-
+
        Args:
            text (str): The text to perform prediction on
            ignore_error (bool, optional): Whether to ignore any errors that occur during prediction. Defaults to False
            print_result (bool, optional): Whether to print the prediction result. Defaults to False
-
+
        Returns:
            dict: The ensemble prediction result
-
+
        Example:
            >>> result = ensemble.predict("ACGUAGGUAUCGUAGA", ignore_error=True)
            >>> print(result)
@@ -267,18 +267,18 @@ class VoteEnsemblePredictor:
    def batch_predict(self, texts, ignore_error=False, print_result=False):
        """
        Predicts on a batch of texts using the ensemble of predictors.
-
+
        This method processes multiple texts efficiently by combining predictions
        from all predictors in the ensemble for each text in the batch.
-
+
        Args:
            texts (list): A list of strings to predict on
            ignore_error (bool, optional): Boolean indicating whether to ignore errors or raise exceptions when prediction fails. Defaults to False
            print_result (bool, optional): Boolean indicating whether to print the raw results for each predictor. Defaults to False
-
+
        Returns:
            list: A list of dictionaries, each dictionary containing the aggregated results of the corresponding text in the input list
-
+
        Example:
            >>> texts = ["ACGUAGGUAUCGUAGA", "GGCTAGCTA", "TATCGCTA"]
            >>> results = ensemble.batch_predict(texts, ignore_error=True)
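
The constructor, predict, and batch_predict docstrings above imply roughly the following usage. This is a minimal sketch in which the checkpoint names, weights, and the predictor objects themselves are hypothetical placeholders:

    # Sketch based on the VoteEnsemblePredictor docstrings above; predictor_a and
    # predictor_b stand in for already-initialized predictor objects.
    from omnigenome.utility import VoteEnsemblePredictor

    predictors = {"ckpt_a": predictor_a, "ckpt_b": predictor_b}  # hypothetical predictors
    weights = {"ckpt_a": 2, "ckpt_b": 1}

    ensemble = VoteEnsemblePredictor(
        predictors, weights=weights, numeric_agg="average", str_agg="max_vote"
    )
    single = ensemble.predict("ACGUAGGUAUCGUAGA", ignore_error=True)
    batch = ensemble.batch_predict(["ACGUAGGUAUCGUAGA", "GGCTAGCTA", "TATCGCTA"], ignore_error=True)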

omnigenome/utility/hub_utils.py

@@ -24,7 +24,7 @@ from omnigenome.src.misc.utils import fprint, default_omnigenome_repo
 def unzip_checkpoint(checkpoint_path):
     """
     Unzips a checkpoint file.
-
+
     This function extracts a zipped checkpoint file to a directory,
     making it ready for use by the model loading functions.
 
@@ -51,7 +51,7 @@ def query_models_info(
 ) -> Dict[str, Any]:
     """
     Queries information about available models from the hub.
-
+
     This function retrieves model information from the OmniGenome hub,
     either from a remote repository or from a local cache. It supports
     filtering by keywords to find specific models.
@@ -69,7 +69,7 @@ def query_models_info(
        >>> # Query all models
        >>> models = query_models_info("")
        >>> print(len(models))  # Number of available models
-
+
        >>> # Query specific models
        >>> models = query_models_info("DNA")
        >>> print(models.keys())  # Models containing "DNA"
@@ -108,7 +108,7 @@ def query_pipelines_info(
 ) -> Dict[str, Any]:
     """
     Queries information about available pipelines from the hub.
-
+
     This function retrieves pipeline information from the OmniGenome hub,
     either from a remote repository or from a local cache. It supports
     filtering by keywords to find specific pipelines.
@@ -126,7 +126,7 @@ def query_pipelines_info(
        >>> # Query all pipelines
        >>> pipelines = query_pipelines_info("")
        >>> print(len(pipelines))  # Number of available pipelines
-
+
        >>> # Query specific pipelines
        >>> pipelines = query_pipelines_info("classification")
        >>> print(pipelines.keys())  # Pipelines containing "classification"
@@ -165,7 +165,7 @@ def query_benchmarks_info(
 ) -> Dict[str, Any]:
     """
     Queries information about available benchmarks from the hub.
-
+
     This function retrieves benchmark information from the OmniGenome hub,
     either from a remote repository or from a local cache. It supports
     filtering by keywords to find specific benchmarks.
@@ -183,7 +183,7 @@ def query_benchmarks_info(
        >>> # Query all benchmarks
        >>> benchmarks = query_benchmarks_info("")
        >>> print(len(benchmarks))  # Number of available benchmarks
-
+
        >>> # Query specific benchmarks
        >>> benchmarks = query_benchmarks_info("RGB")
        >>> print(benchmarks.keys())  # Benchmarks containing "RGB"
@@ -468,7 +468,7 @@ def download_benchmark(
 def check_version(repo: str = None) -> None:
     """
     Checks the version compatibility between local and remote OmniGenome.
-
+
     This function compares the local OmniGenome version with the version
     available in the remote repository to ensure compatibility.
 
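
The query helpers above share one pattern: a keyword filter string in, a dict of matching hub entries out. A minimal sketch, assuming the functions are imported directly from the hub_utils module listed in this diff and reusing the filter strings from the docstring examples:

    # Sketch of the hub query helpers documented above.
    from omnigenome.utility.hub_utils import (
        query_models_info,
        query_pipelines_info,
        query_benchmarks_info,
        check_version,
    )

    models = query_models_info("DNA")                   # models whose names contain "DNA"
    pipelines = query_pipelines_info("classification")  # matching pipelines
    benchmarks = query_benchmarks_info("RGB")           # matching benchmarks
    print(len(models), len(pipelines), len(benchmarks))

    check_version()  # compares the local OmniGenome version against the remote repository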

omnigenome/utility/model_hub/__init__.py

@@ -9,4 +9,3 @@
 """
 This package contains modules for the model hub.
 """
-

omnigenome/utility/model_hub/model_hub.py

@@ -21,33 +21,33 @@ from ...src.misc.utils import env_meta_info, fprint
 class ModelHub:
     """
     A hub for loading and managing pre-trained genomic models.
-
+
     This class provides a unified interface for loading pre-trained models
     from the OmniGenome hub or local paths. It handles model downloading,
     tokenizer loading, and device placement automatically.
-
+
     The ModelHub supports various model types and can automatically
     download models from the hub if they're not available locally.
-
+
     Attributes:
         metadata (dict): Environment metadata information
-
+
     Example:
        >>> from omnigenome import ModelHub
        >>> hub = ModelHub()
-
+
        >>> # Load a model from the hub
        >>> model, tokenizer = ModelHub.load_model_and_tokenizer("model_name")
-
+
        >>> # Check available models
        >>> models = hub.available_models()
        >>> print(list(models.keys()))
    """
-
+
    def __init__(self, *args, **kwargs):
        """
        Initialize the ModelHub instance.
-
+
        Args:
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments
@@ -66,21 +66,21 @@ class ModelHub:
    ):
        """
        Load a model and its tokenizer from the hub or local path.
-
+
        This method loads both the model and tokenizer, places them on the
        specified device, and returns them as a tuple. It handles automatic
        device selection if none is specified.
-
+
        Args:
            model_name_or_path (str): Name or path of the model to load
            local_only (bool, optional): Whether to use only local cache. Defaults to False
            device (str, optional): Device to load the model on. If None, uses auto-detection
            dtype (torch.dtype, optional): Data type for the model. Defaults to torch.float16
            **kwargs: Additional keyword arguments passed to the model loading functions
-
+
        Returns:
            tuple: A tuple containing (model, tokenizer)
-
+
        Example:
            >>> model, tokenizer = ModelHub.load_model_and_tokenizer("yangheng/OmniGenome-186M")
            >>> print(f"Model loaded on device: {next(model.parameters()).device}")
@@ -108,24 +108,24 @@ class ModelHub:
    ):
        """
        Load a model from the hub or local path.
-
+
        This method handles model loading from various sources including
        local paths and the OmniGenome hub. It automatically downloads
        models if they're not available locally.
-
+
        Args:
            model_name_or_path (str): Name or path of the model to load
            local_only (bool, optional): Whether to use only local cache. Defaults to False
            device (str, optional): Device to load the model on. If None, uses auto-detection
            dtype (torch.dtype, optional): Data type for the model. Defaults to torch.float16
            **kwargs: Additional keyword arguments passed to the model loading functions
-
+
        Returns:
            torch.nn.Module: The loaded model
-
+
        Raises:
            ValueError: If model_name_or_path is not a string
-
+
        Example:
            >>> model = ModelHub.load("yangheng/OmniGenome-186M")
            >>> print(f"Model type: {type(model)}")
@@ -152,6 +152,7 @@ class ModelHub:
            tokenizer = tokenizer_cls.from_pretrained(path, **kwargs)
        else:
            from multimolecule import RnaTokenizer
+
            tokenizer = RnaTokenizer.from_pretrained(path, **kwargs)
 
        config.metadata = metadata
@@ -187,25 +188,25 @@ class ModelHub:
    ):
        """
        Get information about available models in the hub.
-
+
        This method queries the OmniGenome hub to retrieve information about
        available models. It can filter models by name and supports both
        local and remote queries.
-
+
        Args:
            model_name_or_path (str, optional): Filter models by name. Defaults to None
            local_only (bool, optional): Whether to use only local cache. Defaults to False
            repo (str, optional): Repository URL to query. Defaults to ""
            **kwargs: Additional keyword arguments
-
+
        Returns:
            dict: Dictionary containing information about available models
-
+
        Example:
            >>> hub = ModelHub()
            >>> models = hub.available_models()
            >>> print(f"Available models: {len(models)}")
-
+
            >>> # Filter models by name
            >>> dna_models = hub.available_models("DNA")
            >>> print(f"DNA models: {list(dna_models.keys())}")
@@ -218,13 +219,13 @@ class ModelHub:
    def push(self, model, **kwargs):
        """
        Push a model to the hub.
-
+
        This method is not yet implemented and will raise a NotImplementedError.
-
+
        Args:
            model: The model to push to the hub
            **kwargs: Additional keyword arguments
-
+
        Raises:
            NotImplementedError: This method has not been implemented yet
        """

omnigenome/utility/pipeline_hub/__init__.py

@@ -9,4 +9,3 @@
 """
 This package contains modules for the pipeline hub.
 """
-
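
Finally, a minimal sketch of the ModelHub usage documented in the model_hub.py hunks above, keeping the model ID from the docstring example and the documented defaults for device and dtype:

    # Sketch based on the ModelHub docstrings above.
    import torch
    from omnigenome import ModelHub

    model, tokenizer = ModelHub.load_model_and_tokenizer(
        "yangheng/OmniGenome-186M", local_only=False, dtype=torch.float16
    )
    print(f"Model loaded on device: {next(model.parameters()).device}")

    hub = ModelHub()
    dna_models = hub.available_models("DNA")  # filter available models by name
    print(list(dna_models.keys()))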