omnigenome 0.3.0a0__py3-none-any.whl → 0.3.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73):
  1. omnigenome/__init__.py +29 -44
  2. omnigenome/auto/auto_bench/__init__.py +0 -1
  3. omnigenome/auto/auto_bench/auto_bench.py +24 -14
  4. omnigenome/auto/auto_train/__init__.py +0 -1
  5. omnigenome/auto/auto_train/auto_train.py +11 -12
  6. omnigenome/auto/bench_hub/__init__.py +0 -1
  7. omnigenome/auto/bench_hub/bench_hub.py +1 -1
  8. omnigenome/cli/__init__.py +0 -1
  9. omnigenome/cli/commands/__init__.py +0 -1
  10. omnigenome/cli/commands/base.py +10 -10
  11. omnigenome/cli/commands/bench/__init__.py +0 -1
  12. omnigenome/cli/commands/bench/bench_cli.py +10 -10
  13. omnigenome/cli/commands/rna/__init__.py +0 -1
  14. omnigenome/cli/commands/rna/rna_design.py +10 -11
  15. omnigenome/src/__init__.py +0 -1
  16. omnigenome/src/abc/__init__.py +0 -1
  17. omnigenome/src/abc/abstract_dataset.py +38 -19
  18. omnigenome/src/abc/abstract_metric.py +7 -7
  19. omnigenome/src/abc/abstract_model.py +15 -14
  20. omnigenome/src/abc/abstract_tokenizer.py +9 -7
  21. omnigenome/src/dataset/omni_dataset.py +16 -14
  22. omnigenome/src/lora/__init__.py +0 -1
  23. omnigenome/src/lora/lora_model.py +47 -41
  24. omnigenome/src/metric/classification_metric.py +11 -11
  25. omnigenome/src/metric/metric.py +19 -19
  26. omnigenome/src/metric/ranking_metric.py +15 -15
  27. omnigenome/src/metric/regression_metric.py +18 -18
  28. omnigenome/src/misc/utils.py +214 -150
  29. omnigenome/src/model/augmentation/__init__.py +0 -1
  30. omnigenome/src/model/augmentation/model.py +17 -17
  31. omnigenome/src/model/classification/__init__.py +0 -1
  32. omnigenome/src/model/classification/model.py +28 -32
  33. omnigenome/src/model/embedding/__init__.py +0 -1
  34. omnigenome/src/model/embedding/model.py +35 -35
  35. omnigenome/src/model/mlm/__init__.py +0 -1
  36. omnigenome/src/model/mlm/model.py +13 -13
  37. omnigenome/src/model/module_utils.py +17 -17
  38. omnigenome/src/model/regression/__init__.py +0 -1
  39. omnigenome/src/model/regression/model.py +72 -77
  40. omnigenome/src/model/regression/resnet.py +32 -32
  41. omnigenome/src/model/rna_design/__init__.py +0 -1
  42. omnigenome/src/model/rna_design/model.py +168 -118
  43. omnigenome/src/model/seq2seq/__init__.py +0 -1
  44. omnigenome/src/model/seq2seq/model.py +4 -4
  45. omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
  46. omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
  47. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
  48. omnigenome/src/trainer/accelerate_trainer.py +40 -32
  49. omnigenome/src/trainer/hf_trainer.py +8 -8
  50. omnigenome/src/trainer/trainer.py +37 -25
  51. omnigenome/utility/dataset_hub/__init__.py +0 -1
  52. omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
  53. omnigenome/utility/ensemble.py +26 -26
  54. omnigenome/utility/hub_utils.py +8 -8
  55. omnigenome/utility/model_hub/__init__.py +0 -1
  56. omnigenome/utility/model_hub/model_hub.py +26 -25
  57. omnigenome/utility/pipeline_hub/__init__.py +0 -1
  58. omnigenome/utility/pipeline_hub/pipeline.py +49 -49
  59. omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
  60. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +3 -3
  61. omnigenome-0.3.1a0.dist-info/RECORD +78 -0
  62. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -1
  63. omnigenome-0.3.0a0.dist-info/RECORD +0 -85
  64. tests/__init__.py +0 -9
  65. tests/conftest.py +0 -160
  66. tests/test_dataset_patterns.py +0 -291
  67. tests/test_examples_syntax.py +0 -83
  68. tests/test_model_loading.py +0 -183
  69. tests/test_rna_functions.py +0 -255
  70. tests/test_training_patterns.py +0 -302
  71. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
  72. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
  73. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
@@ -1,183 +0,0 @@
1
- """
2
- Test model loading functionality based on examples.
3
- """
4
- import pytest
5
- import tempfile
6
- import os
7
- from unittest.mock import patch, MagicMock
8
-
9
- try:
10
- import torch
11
- except ImportError:
12
- torch = None
13
-
14
- # Skip heavy model loading tests by default - can be run with --run-slow
15
- pytestmark = pytest.mark.slow
16
-
17
-
18
- class TestModelLoading:
19
- """Test model loading similar to examples."""
20
-
21
- @pytest.fixture
22
- def mock_model_config(self):
23
- """Mock model config to avoid downloading real models."""
24
- config = MagicMock()
25
- config.hidden_size = 768
26
- config.num_labels = 2
27
- return config
28
-
29
- @pytest.fixture
30
- def mock_tokenizer(self):
31
- """Mock tokenizer for testing."""
32
- tokenizer = MagicMock()
33
- tokenizer.encode.return_value = [1, 2, 3, 4, 5]
34
- tokenizer.decode.return_value = "AUGC"
35
- tokenizer.convert_ids_to_tokens.return_value = ["A", "U", "G", "C"]
36
- return tokenizer
37
-
38
- def test_model_import_structure(self):
39
- """Test that model classes can be imported as shown in examples."""
40
- try:
41
- from omnigenome import (
42
- OmniModelForSequenceClassification,
43
- OmniModelForTokenClassification,
44
- OmniModelForSequenceRegression,
45
- OmniModelForTokenRegression,
46
- )
47
- # If import succeeds, test passes
48
- assert True
49
- except ImportError:
50
- pytest.skip("omnigenome not available or missing dependencies")
51
-
52
- def test_embedding_model_import(self):
53
- """Test embedding model import as shown in RNA_Embedding_Tutorial.ipynb."""
54
- try:
55
- from omnigenome import OmniGenomeModelForEmbedding
56
- assert True
57
- except ImportError:
58
- pytest.skip("omnigenome not available or missing dependencies")
59
-
60
- def test_pooling_import(self):
61
- """Test pooling import as shown in classification.ipynb."""
62
- try:
63
- from omnigenome import OmniModel, OmniPooling
64
- assert True
65
- except ImportError:
66
- pytest.skip("omnigenome not available or missing dependencies")
67
-
68
- def test_base_model_loading_pattern(self, mock_tokenizer):
69
- """Test the base model loading pattern from classification.ipynb."""
70
- try:
71
- from transformers import AutoTokenizer, AutoModel
72
- except ImportError:
73
- pytest.skip("transformers not available")
74
-
75
- with patch('transformers.AutoTokenizer.from_pretrained') as mock_auto_tokenizer, \
76
- patch('transformers.AutoModel.from_pretrained') as mock_auto_model:
77
-
78
- # Mock the returns
79
- mock_auto_tokenizer.return_value = mock_tokenizer
80
- mock_auto_model.return_value = MagicMock()
81
-
82
- # This pattern is from examples/custom_finetuning/classification.ipynb
83
- model_name = "yangheng/OmniGenome-52M"
84
- base_model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
85
- base_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
86
-
87
- # Verify the calls were made correctly
88
- mock_auto_model.assert_called_once_with(model_name, trust_remote_code=True)
89
- mock_auto_tokenizer.assert_called_once_with(model_name, trust_remote_code=True)
90
-
91
- def test_embedding_model_initialization_pattern(self):
92
- """Test embedding model initialization pattern from RNA_Embedding_Tutorial.ipynb."""
93
- if torch is None:
94
- pytest.skip("torch not available")
95
-
96
- try:
97
- from omnigenome import OmniGenomeModelForEmbedding
98
- except ImportError:
99
- pytest.skip("omnigenome not available")
100
-
101
- with patch('omnigenome.OmniGenomeModelForEmbedding') as mock_embedding_model:
102
- mock_instance = MagicMock()
103
- mock_instance.to.return_value = mock_instance
104
- mock_embedding_model.return_value = mock_instance
105
-
106
- model_name = "yangheng/OmniGenome-52M"
107
- embedding_model = OmniGenomeModelForEmbedding(model_name, trust_remote_code=True).to(torch.device("cuda:0")).to(torch.float16)
108
-
109
- # Verify initialization pattern
110
- mock_embedding_model.assert_called_once_with(model_name, trust_remote_code=True)
111
- assert mock_instance.to.call_count == 2 # Called twice for device and dtype
112
-
113
- def test_model_parameter_patterns(self):
114
- """Test that common model parameters are recognized."""
115
- # These are patterns seen across examples
116
- common_model_names = [
117
- "yangheng/OmniGenome-52M",
118
- "yangheng/OmniGenome-186M",
119
- "anonymous8/OmniGenome-186M",
120
- "anonymous8/OmniGenome-52M"
121
- ]
122
-
123
- for model_name in common_model_names:
124
- # Just verify the string patterns are valid
125
- assert isinstance(model_name, str)
126
- assert "/" in model_name
127
- assert "OmniGenome" in model_name
128
-
129
- def test_classification_model_initialization_pattern(self, mock_tokenizer):
130
- """Test classification model init pattern from examples."""
131
- try:
132
- from omnigenome import OmniModelForSequenceClassification
133
- except ImportError:
134
- pytest.skip("omnigenome not available")
135
-
136
- with patch('omnigenome.OmniModelForSequenceClassification') as mock_model_class:
137
- mock_model_class.return_value = MagicMock()
138
-
139
- # Pattern from classification.ipynb
140
- model_name = "test_model"
141
- tokenizer = mock_tokenizer
142
-
143
- model = OmniModelForSequenceClassification(
144
- config_or_model=model_name,
145
- tokenizer=tokenizer,
146
- num_labels=3,
147
- )
148
-
149
- mock_model_class.assert_called_once_with(
150
- config_or_model=model_name,
151
- tokenizer=tokenizer,
152
- num_labels=3,
153
- )
154
-
155
- def test_rna_sequence_patterns(self):
156
- """Test RNA sequence patterns used in examples."""
157
- # Patterns from RNA_Embedding_Tutorial.ipynb
158
- rna_sequences = [
159
- "AUGGCUACG",
160
- "CGGAUACGGC",
161
- "UGGCCAAGUC",
162
- "AUGCUGCUAUGCUA"
163
- ]
164
-
165
- for seq in rna_sequences:
166
- # Basic validation of RNA sequence format
167
- assert isinstance(seq, str)
168
- assert len(seq) > 0
169
- assert all(base in "AUCG" for base in seq)
170
-
171
- def test_device_patterns(self):
172
- """Test device usage patterns from examples."""
173
- if torch is None:
174
- pytest.skip("torch not available")
175
-
176
- # Pattern from examples: torch.device("cuda:0")
177
- device = torch.device("cuda:0")
178
- assert str(device) == "cuda:0"
179
-
180
- # Alternative pattern
181
- if torch.cuda.is_available():
182
- device = torch.device("cuda")
183
- assert "cuda" in str(device)
@@ -1,255 +0,0 @@
1
- """
2
- Test RNA-specific functionality based on examples.
3
- """
4
- import pytest
5
- import tempfile
6
- import os
7
- from unittest.mock import patch, MagicMock
8
-
9
-
10
- class TestRNAFunctions:
11
- """Test RNA functionality based on examples."""
12
-
13
- def test_rna_sequence_validity_checker(self):
14
- """Test ss_validity_loss function from Secondary_Structure_Prediction.py."""
15
- # Recreate the function from the example
16
- def ss_validity_loss(rna_strct: str) -> float:
17
- left = right = 0
18
- dots = rna_strct.count('.')
19
- for c in rna_strct:
20
- if c == '(':
21
- left += 1
22
- elif c == ')':
23
- if left:
24
- left -= 1
25
- else:
26
- right += 1
27
- elif c != '.':
28
- raise ValueError(f"Invalid char {c}")
29
- return (left + right) / (len(rna_strct) - dots + 1e-8)
30
-
31
- # Test valid structures
32
- assert ss_validity_loss("(())") == 0.0
33
- assert ss_validity_loss("((..))") == 0.0
34
- assert ss_validity_loss("....") == 0.0
35
-
36
- # Test invalid structures
37
- assert ss_validity_loss("(((") > 0.0 # Unmatched left
38
- assert ss_validity_loss(")))") > 0.0 # Unmatched right
39
- assert ss_validity_loss("())(") > 0.0 # Mixed unmatched
40
-
41
- # Test error case
42
- with pytest.raises(ValueError, match="Invalid char"):
43
- ss_validity_loss("((X))")
44
-
45
- def test_find_invalid_positions(self):
46
- """Test find_invalid_positions function from Secondary_Structure_Prediction.py."""
47
- # Recreate the function from the example
48
- def find_invalid_positions(struct: str) -> list:
49
- stack, invalid = [], []
50
- for i, c in enumerate(struct):
51
- if c == '(':
52
- stack.append(i)
53
- elif c == ')':
54
- if stack:
55
- stack.pop()
56
- else:
57
- invalid.append(i)
58
- invalid.extend(stack)
59
- return invalid
60
-
61
- # Test valid structures
62
- assert find_invalid_positions("(())") == []
63
- assert find_invalid_positions("((..))") == []
64
- assert find_invalid_positions("....") == []
65
-
66
- # Test invalid structures
67
- assert find_invalid_positions("(((") == [0, 1, 2] # All unmatched left
68
- assert find_invalid_positions(")))") == [0, 1, 2] # All unmatched right
69
- assert find_invalid_positions("())(") == [2, 3] # One unmatched right, one left
70
-
71
- def test_rna_structure_formats(self):
72
- """Test RNA structure format validation."""
73
- valid_structures = [
74
- "(())",
75
- "((()))",
76
- ".((.))",
77
- "....",
78
- "",
79
- "((..))",
80
- ]
81
-
82
- invalid_structures = [
83
- "((X))", # Invalid character
84
- "(()", # Unmatched
85
- "())", # Unmatched
86
- ")(", # Wrong order
87
- ]
88
-
89
- def is_valid_structure_format(struct: str) -> bool:
90
- """Check if structure contains only valid characters."""
91
- return all(c in "()." for c in struct)
92
-
93
- for struct in valid_structures:
94
- assert is_valid_structure_format(struct), f"Should be valid: {struct}"
95
-
96
- for struct in invalid_structures:
97
- if any(c not in "()." for c in struct):
98
- assert not is_valid_structure_format(struct), f"Should be invalid: {struct}"
99
-
100
- def test_sequence_replacement_patterns(self):
101
- """Test U/T replacement patterns from examples."""
102
- # Pattern from web_rna_design.py
103
- def rna_to_dna_pattern(sequence):
104
- return sequence.replace("U", "T")
105
-
106
- def dna_to_rna_pattern(sequence):
107
- return sequence.replace("T", "U")
108
-
109
- # Test RNA to DNA
110
- assert rna_to_dna_pattern("AUCG") == "ATCG"
111
- assert rna_to_dna_pattern("UUUU") == "TTTT"
112
- assert rna_to_dna_pattern("ACGU") == "ACGT"
113
-
114
- # Test DNA to RNA
115
- assert dna_to_rna_pattern("ATCG") == "AUCG"
116
- assert dna_to_rna_pattern("TTTT") == "UUUU"
117
- assert dna_to_rna_pattern("ACGT") == "ACGU"
118
-
119
- def test_random_base_generation_patterns(self):
120
- """Test random base generation patterns from RNA design examples."""
121
- import random
122
-
123
- def generate_random_rna_base():
124
- """Pattern from easy_rna_design_emoo.py."""
125
- return random.choice(["A", "U", "G", "C"])
126
-
127
- def generate_random_dna_base():
128
- """Pattern from easy_rna_design_emoo.py."""
129
- return random.choice(["A", "T", "G", "C"])
130
-
131
- # Test multiple generations to ensure valid bases
132
- for _ in range(10):
133
- rna_base = generate_random_rna_base()
134
- assert rna_base in ["A", "U", "G", "C"]
135
-
136
- dna_base = generate_random_dna_base()
137
- assert dna_base in ["A", "T", "G", "C"]
138
-
139
- def test_sequence_mutation_pattern(self):
140
- """Test sequence mutation pattern from mlm_mutate function."""
141
- try:
142
- import numpy as np
143
- except ImportError:
144
- pytest.skip("numpy not available")
145
-
146
- def mutate_sequence_pattern(sequence, mutation_rate=0.1):
147
- """Simplified version of mutation pattern from examples."""
148
- sequence_array = np.array(list(sequence), dtype=np.str_)
149
- probability_matrix = np.full(sequence_array.shape, mutation_rate)
150
- masked_indices = np.random.rand(*sequence_array.shape) < probability_matrix
151
- sequence_array[masked_indices] = "$" # Mask token
152
- return "".join(sequence_array.tolist())
153
-
154
- # Test mutation with 0% rate
155
- original = "AUCG"
156
- mutated_zero = mutate_sequence_pattern(original, 0.0)
157
- assert mutated_zero == original
158
-
159
- # Test mutation with 100% rate
160
- mutated_full = mutate_sequence_pattern(original, 1.0)
161
- assert mutated_full == "$$$$"
162
-
163
- # Test with moderate rate - should have some masks
164
- np.random.seed(42) # For reproducible test
165
- mutated_partial = mutate_sequence_pattern("AUCGAUCGAUCG", 0.5)
166
- assert "$" in mutated_partial
167
-
168
- @patch('tempfile.mkdtemp')
169
- def test_temp_directory_pattern(self, mock_mkdtemp):
170
- """Test temp directory usage pattern from Secondary_Structure_Prediction.py."""
171
- from pathlib import Path
172
-
173
- mock_mkdtemp.return_value = "/tmp/test_dir"
174
-
175
- # Pattern from Secondary_Structure_Prediction.py
176
- TEMP_DIR = Path(tempfile.mkdtemp())
177
-
178
- mock_mkdtemp.assert_called_once()
179
- assert isinstance(TEMP_DIR, Path)
180
-
181
- def test_rna_embedding_sequence_validation(self):
182
- """Test RNA sequence validation for embedding examples."""
183
- # RNA sequences from RNA_Embedding_Tutorial.ipynb
184
- rna_sequences = [
185
- "AUGGCUACG",
186
- "CGGAUACGGC",
187
- "UGGCCAAGUC",
188
- "AUGCUGCUAUGCUA"
189
- ]
190
-
191
- def validate_rna_sequence(seq):
192
- """Validate RNA sequence format."""
193
- return all(base in "AUCG" for base in seq) and len(seq) > 0
194
-
195
- for seq in rna_sequences:
196
- assert validate_rna_sequence(seq), f"Invalid RNA sequence: {seq}"
197
-
198
- def test_structure_prediction_mock_pattern(self):
199
- """Test structure prediction pattern without ViennaRNA dependency."""
200
- def mock_predict_structure_single(sequence):
201
- """Mock version of predict_structure_single from examples."""
202
- # Return a mock structure and energy
203
- return "." * len(sequence), -10.0
204
-
205
- # Test the pattern
206
- seq = "AUCG"
207
- struct, energy = mock_predict_structure_single(seq)
208
-
209
- assert len(struct) == len(seq)
210
- assert isinstance(energy, float)
211
- assert struct == "...."
212
-
213
- def test_base64_encoding_pattern(self):
214
- """Test base64 encoding pattern from SVG generation."""
215
- import base64
216
-
217
- def create_mock_svg_datauri(content="test"):
218
- """Mock version of SVG data URI creation."""
219
- svg_content = f'<svg>{content}</svg>'
220
- b64 = base64.b64encode(svg_content.encode()).decode('utf-8')
221
- return f"data:image/svg+xml;base64,{b64}"
222
-
223
- uri = create_mock_svg_datauri("test")
224
- assert uri.startswith("data:image/svg+xml;base64,")
225
-
226
- # Decode and verify
227
- _, b64_part = uri.split(",", 1)
228
- decoded = base64.b64decode(b64_part).decode('utf-8')
229
- assert decoded == "<svg>test</svg>"
230
-
231
- def test_longest_bp_span_function(self):
232
- """Test longest_bp_span function from easy_rna_design_emoo.py."""
233
- def longest_bp_span(structure):
234
- """Function from easy_rna_design_emoo.py."""
235
- stack = []
236
- max_span = 0
237
-
238
- for i, char in enumerate(structure):
239
- if char == '(':
240
- stack.append(i)
241
- elif char == ')':
242
- if stack:
243
- left_index = stack.pop()
244
- current_span = i - left_index
245
- max_span = max(max_span, current_span)
246
-
247
- return max_span
248
-
249
- # Test cases
250
- assert longest_bp_span("(())") == 3 # Outer pair spans 3 positions
251
- assert longest_bp_span("((()))") == 5 # Outer pair spans 5 positions
252
- assert longest_bp_span("()()") == 1 # Each pair spans 1 position
253
- assert longest_bp_span("....") == 0 # No pairs
254
- assert longest_bp_span("") == 0 # Empty structure
255
- assert longest_bp_span("((.))") == 4 # Outer pair spans 4 positions