omnigenome-0.3.0a0-py3-none-any.whl → omnigenome-0.3.1a0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. omnigenome/__init__.py +29 -44
  2. omnigenome/auto/auto_bench/__init__.py +0 -1
  3. omnigenome/auto/auto_bench/auto_bench.py +24 -14
  4. omnigenome/auto/auto_train/__init__.py +0 -1
  5. omnigenome/auto/auto_train/auto_train.py +11 -12
  6. omnigenome/auto/bench_hub/__init__.py +0 -1
  7. omnigenome/auto/bench_hub/bench_hub.py +1 -1
  8. omnigenome/cli/__init__.py +0 -1
  9. omnigenome/cli/commands/__init__.py +0 -1
  10. omnigenome/cli/commands/base.py +10 -10
  11. omnigenome/cli/commands/bench/__init__.py +0 -1
  12. omnigenome/cli/commands/bench/bench_cli.py +10 -10
  13. omnigenome/cli/commands/rna/__init__.py +0 -1
  14. omnigenome/cli/commands/rna/rna_design.py +10 -11
  15. omnigenome/src/__init__.py +0 -1
  16. omnigenome/src/abc/__init__.py +0 -1
  17. omnigenome/src/abc/abstract_dataset.py +38 -19
  18. omnigenome/src/abc/abstract_metric.py +7 -7
  19. omnigenome/src/abc/abstract_model.py +15 -14
  20. omnigenome/src/abc/abstract_tokenizer.py +9 -7
  21. omnigenome/src/dataset/omni_dataset.py +16 -14
  22. omnigenome/src/lora/__init__.py +0 -1
  23. omnigenome/src/lora/lora_model.py +47 -41
  24. omnigenome/src/metric/classification_metric.py +11 -11
  25. omnigenome/src/metric/metric.py +19 -19
  26. omnigenome/src/metric/ranking_metric.py +15 -15
  27. omnigenome/src/metric/regression_metric.py +18 -18
  28. omnigenome/src/misc/utils.py +214 -150
  29. omnigenome/src/model/augmentation/__init__.py +0 -1
  30. omnigenome/src/model/augmentation/model.py +17 -17
  31. omnigenome/src/model/classification/__init__.py +0 -1
  32. omnigenome/src/model/classification/model.py +28 -32
  33. omnigenome/src/model/embedding/__init__.py +0 -1
  34. omnigenome/src/model/embedding/model.py +35 -35
  35. omnigenome/src/model/mlm/__init__.py +0 -1
  36. omnigenome/src/model/mlm/model.py +13 -13
  37. omnigenome/src/model/module_utils.py +17 -17
  38. omnigenome/src/model/regression/__init__.py +0 -1
  39. omnigenome/src/model/regression/model.py +72 -77
  40. omnigenome/src/model/regression/resnet.py +32 -32
  41. omnigenome/src/model/rna_design/__init__.py +0 -1
  42. omnigenome/src/model/rna_design/model.py +168 -118
  43. omnigenome/src/model/seq2seq/__init__.py +0 -1
  44. omnigenome/src/model/seq2seq/model.py +4 -4
  45. omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
  46. omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
  47. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
  48. omnigenome/src/trainer/accelerate_trainer.py +40 -32
  49. omnigenome/src/trainer/hf_trainer.py +8 -8
  50. omnigenome/src/trainer/trainer.py +37 -25
  51. omnigenome/utility/dataset_hub/__init__.py +0 -1
  52. omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
  53. omnigenome/utility/ensemble.py +26 -26
  54. omnigenome/utility/hub_utils.py +8 -8
  55. omnigenome/utility/model_hub/__init__.py +0 -1
  56. omnigenome/utility/model_hub/model_hub.py +26 -25
  57. omnigenome/utility/pipeline_hub/__init__.py +0 -1
  58. omnigenome/utility/pipeline_hub/pipeline.py +49 -49
  59. omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
  60. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +3 -3
  61. omnigenome-0.3.1a0.dist-info/RECORD +78 -0
  62. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -1
  63. omnigenome-0.3.0a0.dist-info/RECORD +0 -85
  64. tests/__init__.py +0 -9
  65. tests/conftest.py +0 -160
  66. tests/test_dataset_patterns.py +0 -291
  67. tests/test_examples_syntax.py +0 -83
  68. tests/test_model_loading.py +0 -183
  69. tests/test_rna_functions.py +0 -255
  70. tests/test_training_patterns.py +0 -302
  71. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
  72. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
  73. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
Diff for tests/test_training_patterns.py (file deleted in 0.3.1a0):
@@ -1,302 +0,0 @@
1
- """
2
- Test training patterns and configurations based on examples.
3
- """
4
- import pytest
5
- from unittest.mock import patch, MagicMock
6
-
7
-
8
- class TestTrainingPatterns:
9
- """Test training patterns from examples."""
10
-
11
- def test_trainer_imports(self):
12
- """Test trainer imports as shown in quick_start.md."""
13
- try:
14
- from omnigenome import Trainer
15
- assert True
16
- except ImportError:
17
- pytest.skip("omnigenome not available or missing dependencies")
18
-
19
- def test_autobench_imports(self):
20
- """Test AutoBench imports from examples."""
21
- try:
22
- from omnigenome import AutoBench
23
- assert True
24
- except ImportError:
25
- pytest.skip("omnigenome not available or missing dependencies")
26
-
27
- def test_autocuda_import_pattern(self):
28
- """Test autocuda import pattern from examples."""
29
- try:
30
- import autocuda
31
- # Pattern from examples
32
- device = autocuda.auto_cuda()
33
- # Just verify the function exists and returns something
34
- assert device is not None
35
- except ImportError:
36
- # Skip if autocuda not available
37
- pytest.skip("autocuda not available")
38
-
39
- @patch('omnigenome.AutoBench')
40
- def test_autobench_initialization_pattern(self, mock_autobench):
41
- """Test AutoBench initialization pattern from quick_start.md."""
42
- mock_instance = MagicMock()
43
- mock_autobench.return_value = mock_instance
44
-
45
- from omnigenome import AutoBench
46
-
47
- # Pattern from quick_start.md
48
- auto_bench = AutoBench(
49
- benchmark="RGB",
50
- model_name_or_path="yangheng/OmniGenome-186M",
51
- device="cuda:0",
52
- overwrite=True
53
- )
54
-
55
- mock_autobench.assert_called_once_with(
56
- benchmark="RGB",
57
- model_name_or_path="yangheng/OmniGenome-186M",
58
- device="cuda:0",
59
- overwrite=True
60
- )
61
-
62
- def test_benchmark_names(self):
63
- """Test benchmark names from examples."""
64
- # Benchmarks from quick_start.md
65
- benchmarks = ["RGB", "GB", "PGB", "GUE", "BEACON"]
66
-
67
- for benchmark in benchmarks:
68
- assert isinstance(benchmark, str)
69
- assert len(benchmark) > 0
70
- assert benchmark.isupper()
71
-
72
- def test_trainer_names(self):
73
- """Test trainer types from examples."""
74
- # Trainers from autobench examples
75
- trainers = ["accelerate", "huggingface"]
76
-
77
- for trainer in trainers:
78
- assert isinstance(trainer, str)
79
- assert trainer in ["accelerate", "huggingface"]
80
-
81
- @patch('omnigenome.Trainer')
82
- def test_trainer_initialization_pattern(self, mock_trainer):
83
- """Test Trainer initialization pattern from quick_start.md."""
84
- mock_trainer.return_value = MagicMock()
85
-
86
- from omnigenome import Trainer
87
-
88
- # Mock training arguments
89
- mock_args = MagicMock()
90
- mock_args.output_dir = "./results"
91
- mock_args.num_train_epochs = 3
92
- mock_args.per_device_train_batch_size = 8
93
- mock_args.learning_rate = 2e-5
94
-
95
- # Pattern from quick_start.md
96
- trainer = Trainer(
97
- model=MagicMock(),
98
- train_dataset=MagicMock(),
99
- eval_dataset=MagicMock(),
100
- args=mock_args
101
- )
102
-
103
- mock_trainer.assert_called_once()
104
-
105
- def test_training_arguments_pattern(self):
106
- """Test training arguments patterns from examples."""
107
- # Common training parameters from examples
108
- training_configs = {
109
- "output_dir": "./results",
110
- "num_train_epochs": 3,
111
- "per_device_train_batch_size": 8,
112
- "learning_rate": 2e-5,
113
- "epochs": 3,
114
- "batch_size": 8,
115
- "seeds": [42, 43, 44]
116
- }
117
-
118
- # Verify types and ranges
119
- assert isinstance(training_configs["output_dir"], str)
120
- assert isinstance(training_configs["num_train_epochs"], int)
121
- assert training_configs["num_train_epochs"] > 0
122
- assert isinstance(training_configs["per_device_train_batch_size"], int)
123
- assert training_configs["per_device_train_batch_size"] > 0
124
- assert isinstance(training_configs["learning_rate"], float)
125
- assert training_configs["learning_rate"] > 0
126
- assert isinstance(training_configs["seeds"], list)
127
- assert all(isinstance(seed, int) for seed in training_configs["seeds"])
128
-
129
- def test_genetic_algorithm_parameters(self):
130
- """Test genetic algorithm parameters from RNA design examples."""
131
- # Parameters from easy_rna_design_emoo.py
132
- ga_params = {
133
- "mutation_ratio": 0.1,
134
- "num_population": 100,
135
- "num_generation": 50,
136
- "model": "anonymous8/OmniGenome-186M"
137
- }
138
-
139
- # Verify parameter types and ranges
140
- assert isinstance(ga_params["mutation_ratio"], float)
141
- assert 0.0 <= ga_params["mutation_ratio"] <= 1.0
142
- assert isinstance(ga_params["num_population"], int)
143
- assert ga_params["num_population"] > 0
144
- assert isinstance(ga_params["num_generation"], int)
145
- assert ga_params["num_generation"] > 0
146
- assert isinstance(ga_params["model"], str)
147
-
148
- def test_web_rna_design_parameters(self):
149
- """Test web RNA design parameters from web_rna_design.py."""
150
- # Parameters from web_rna_design.py
151
- web_params = {
152
- "mutation_ratio": 0.5,
153
- "num_population": 500,
154
- "num_generation": 10,
155
- "puzzle_id": 0
156
- }
157
-
158
- # Verify parameter types
159
- assert isinstance(web_params["mutation_ratio"], float)
160
- assert 0.0 <= web_params["mutation_ratio"] <= 1.0
161
- assert isinstance(web_params["num_population"], int)
162
- assert web_params["num_population"] > 0
163
- assert isinstance(web_params["num_generation"], int)
164
- assert web_params["num_generation"] > 0
165
- assert isinstance(web_params["puzzle_id"], int)
166
- assert web_params["puzzle_id"] >= 0
167
-
168
- def test_model_optimization_patterns(self):
169
- """Test model optimization patterns from examples."""
170
- # Patterns from examples for model optimization
171
- optimization_configs = {
172
- "torch_dtype": "float16",
173
- "device_map": "auto",
174
- "trust_remote_code": True,
175
- "gradient_checkpointing": True,
176
- "fp16": True
177
- }
178
-
179
- for key, value in optimization_configs.items():
180
- assert isinstance(key, str)
181
- # Value types vary, just ensure they exist
182
- assert value is not None
183
-
184
- @patch('torch.cuda.empty_cache')
185
- def test_memory_management_pattern(self, mock_empty_cache):
186
- """Test memory management patterns from web_rna_design.py."""
187
- try:
188
- import torch
189
- except ImportError:
190
- pytest.skip("torch not available")
191
-
192
- # Pattern from web_rna_design.py
193
- def cleanup_model_pattern():
194
- """Memory cleanup pattern from examples."""
195
- # del model, tokenizer # Would normally delete objects
196
- torch.cuda.empty_cache()
197
-
198
- cleanup_model_pattern()
199
- mock_empty_cache.assert_called_once()
200
-
201
- def test_random_seed_patterns(self):
202
- """Test random seed patterns from examples."""
203
- import random
204
-
205
- # Pattern from examples
206
- def set_random_seed_pattern():
207
- """Random seed pattern from easy_rna_design_emoo.py."""
208
- return random.randint(0, 99999999)
209
-
210
- # Test seed generation
211
- seed1 = set_random_seed_pattern()
212
- seed2 = set_random_seed_pattern()
213
-
214
- assert isinstance(seed1, int)
215
- assert isinstance(seed2, int)
216
- assert 0 <= seed1 <= 99999999
217
- assert 0 <= seed2 <= 99999999
218
-
219
- def test_evaluation_metrics_patterns(self):
220
- """Test evaluation metrics patterns from examples."""
221
- # Common metrics mentioned in examples
222
- metrics = [
223
- "accuracy",
224
- "f1_score",
225
- "precision",
226
- "recall",
227
- "mse",
228
- "mae",
229
- "r2_score"
230
- ]
231
-
232
- for metric in metrics:
233
- assert isinstance(metric, str)
234
- assert len(metric) > 0
235
-
236
- def test_device_selection_patterns(self):
237
- """Test device selection patterns from examples."""
238
- # Patterns from examples
239
- device_patterns = [
240
- "cuda:0",
241
- "cuda",
242
- "cpu",
243
- "auto"
244
- ]
245
-
246
- for device in device_patterns:
247
- assert isinstance(device, str)
248
- assert len(device) > 0
249
-
250
- def test_batch_size_patterns(self):
251
- """Test batch size patterns from examples."""
252
- # Common batch sizes from examples
253
- batch_sizes = [4, 8, 16, 32, 64]
254
-
255
- for batch_size in batch_sizes:
256
- assert isinstance(batch_size, int)
257
- assert batch_size > 0
258
- assert batch_size <= 1024 # Reasonable upper limit
259
-
260
- def test_learning_rate_patterns(self):
261
- """Test learning rate patterns from examples."""
262
- # Common learning rates from examples
263
- learning_rates = [1e-5, 2e-5, 5e-5, 1e-4, 2e-4]
264
-
265
- for lr in learning_rates:
266
- assert isinstance(lr, float)
267
- assert lr > 0
268
- assert lr < 1.0 # Should be small
269
-
270
- def test_epoch_patterns(self):
271
- """Test epoch patterns from examples."""
272
- # Common epoch counts from examples
273
- epoch_counts = [1, 3, 5, 10, 20]
274
-
275
- for epochs in epoch_counts:
276
- assert isinstance(epochs, int)
277
- assert epochs > 0
278
- assert epochs <= 100 # Reasonable upper limit
279
-
280
- def test_output_directory_patterns(self):
281
- """Test output directory patterns from examples."""
282
- # Common output directory patterns
283
- output_dirs = [
284
- "./results",
285
- "./output",
286
- "./checkpoints",
287
- "./models"
288
- ]
289
-
290
- for output_dir in output_dirs:
291
- assert isinstance(output_dir, str)
292
- assert output_dir.startswith("./") or output_dir.startswith("/")
293
-
294
- def test_model_saving_patterns(self):
295
- """Test model saving patterns from examples."""
296
- # File extensions for saved models
297
- model_extensions = [".pt", ".pth", ".bin", ".safetensors"]
298
-
299
- for ext in model_extensions:
300
- assert isinstance(ext, str)
301
- assert ext.startswith(".")
302
- assert len(ext) > 1