omnigenome 0.3.0a0__py3-none-any.whl → 0.3.1a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. omnigenome/__init__.py +29 -44
  2. omnigenome/auto/auto_bench/__init__.py +0 -1
  3. omnigenome/auto/auto_bench/auto_bench.py +24 -14
  4. omnigenome/auto/auto_train/__init__.py +0 -1
  5. omnigenome/auto/auto_train/auto_train.py +11 -12
  6. omnigenome/auto/bench_hub/__init__.py +0 -1
  7. omnigenome/auto/bench_hub/bench_hub.py +1 -1
  8. omnigenome/cli/__init__.py +0 -1
  9. omnigenome/cli/commands/__init__.py +0 -1
  10. omnigenome/cli/commands/base.py +10 -10
  11. omnigenome/cli/commands/bench/__init__.py +0 -1
  12. omnigenome/cli/commands/bench/bench_cli.py +10 -10
  13. omnigenome/cli/commands/rna/__init__.py +0 -1
  14. omnigenome/cli/commands/rna/rna_design.py +10 -11
  15. omnigenome/src/__init__.py +0 -1
  16. omnigenome/src/abc/__init__.py +0 -1
  17. omnigenome/src/abc/abstract_dataset.py +38 -19
  18. omnigenome/src/abc/abstract_metric.py +7 -7
  19. omnigenome/src/abc/abstract_model.py +15 -14
  20. omnigenome/src/abc/abstract_tokenizer.py +9 -7
  21. omnigenome/src/dataset/omni_dataset.py +16 -14
  22. omnigenome/src/lora/__init__.py +0 -1
  23. omnigenome/src/lora/lora_model.py +47 -41
  24. omnigenome/src/metric/classification_metric.py +11 -11
  25. omnigenome/src/metric/metric.py +19 -19
  26. omnigenome/src/metric/ranking_metric.py +15 -15
  27. omnigenome/src/metric/regression_metric.py +18 -18
  28. omnigenome/src/misc/utils.py +214 -150
  29. omnigenome/src/model/augmentation/__init__.py +0 -1
  30. omnigenome/src/model/augmentation/model.py +17 -17
  31. omnigenome/src/model/classification/__init__.py +0 -1
  32. omnigenome/src/model/classification/model.py +28 -32
  33. omnigenome/src/model/embedding/__init__.py +0 -1
  34. omnigenome/src/model/embedding/model.py +35 -35
  35. omnigenome/src/model/mlm/__init__.py +0 -1
  36. omnigenome/src/model/mlm/model.py +13 -13
  37. omnigenome/src/model/module_utils.py +17 -17
  38. omnigenome/src/model/regression/__init__.py +0 -1
  39. omnigenome/src/model/regression/model.py +72 -77
  40. omnigenome/src/model/regression/resnet.py +32 -32
  41. omnigenome/src/model/rna_design/__init__.py +0 -1
  42. omnigenome/src/model/rna_design/model.py +168 -118
  43. omnigenome/src/model/seq2seq/__init__.py +0 -1
  44. omnigenome/src/model/seq2seq/model.py +4 -4
  45. omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
  46. omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
  47. omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
  48. omnigenome/src/trainer/accelerate_trainer.py +40 -32
  49. omnigenome/src/trainer/hf_trainer.py +8 -8
  50. omnigenome/src/trainer/trainer.py +37 -25
  51. omnigenome/utility/dataset_hub/__init__.py +0 -1
  52. omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
  53. omnigenome/utility/ensemble.py +26 -26
  54. omnigenome/utility/hub_utils.py +8 -8
  55. omnigenome/utility/model_hub/__init__.py +0 -1
  56. omnigenome/utility/model_hub/model_hub.py +26 -25
  57. omnigenome/utility/pipeline_hub/__init__.py +0 -1
  58. omnigenome/utility/pipeline_hub/pipeline.py +49 -49
  59. omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
  60. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +3 -3
  61. omnigenome-0.3.1a0.dist-info/RECORD +78 -0
  62. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -1
  63. omnigenome-0.3.0a0.dist-info/RECORD +0 -85
  64. tests/__init__.py +0 -9
  65. tests/conftest.py +0 -160
  66. tests/test_dataset_patterns.py +0 -291
  67. tests/test_examples_syntax.py +0 -83
  68. tests/test_model_loading.py +0 -183
  69. tests/test_rna_functions.py +0 -255
  70. tests/test_training_patterns.py +0 -302
  71. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
  72. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
  73. {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
tests/conftest.py DELETED
@@ -1,160 +0,0 @@
1
- """
2
- Pytest configuration and shared fixtures for OmniGenBench tests.
3
- """
4
- import pytest
5
- import sys
6
- import os
7
- from pathlib import Path
8
-
9
# Put the repository root (the parent of the directory holding this file)
# at the front of ``sys.path`` so the in-tree package is importable.
ROOT_DIR = Path(__file__).parent.parent
sys.path.insert(0, os.fspath(ROOT_DIR))
12
-
13
-
14
def pytest_configure(config):
    """Register the suite's custom markers (``slow``, ``gpu``, ``integration``).

    Args:
        config: the pytest ``Config`` object; each marker description is
            appended to its ``markers`` ini value.
    """
    custom_markers = (
        "slow: marks tests as slow (deselect with '-m \"not slow\"')",
        "gpu: marks tests that require GPU (deselect with '-m \"not gpu\"')",
        "integration: marks tests as integration tests",
    )
    for marker_line in custom_markers:
        config.addinivalue_line("markers", marker_line)
25
-
26
-
27
def pytest_collection_modifyitems(config, items):
    """Tag slow tests and skip GPU tests when CUDA is unavailable.

    * Any item whose node id contains ``"slow"`` or ``"model_loading"`` gets
      the ``slow`` marker.
    * Any item carrying a ``gpu`` marker is skipped when ``torch`` cannot be
      imported or ``torch.cuda.is_available()`` returns false.
    """
    try:
        import torch
        has_cuda = torch.cuda.is_available()
    except ImportError:
        has_cuda = False

    slow_keywords = ("slow", "model_loading")
    for test_item in items:
        node_id = test_item.nodeid
        if any(word in node_id for word in slow_keywords):
            test_item.add_marker(pytest.mark.slow)

        needs_gpu = test_item.get_closest_marker("gpu")
        if needs_gpu and not has_cuda:
            test_item.add_marker(pytest.mark.skip(reason="CUDA not available"))
43
-
44
-
45
@pytest.fixture
def sample_rna_sequences():
    """Provide a fresh list of short RNA sequences (A/U/G/C only)."""
    sequences = ["AUGGCUACG", "CGGAUACGGC", "UGGCCAAGUC", "AUGCUGCUAUGCUA"]
    return sequences
54
-
55
-
56
@pytest.fixture
def sample_rna_structures():
    """Provide a fresh list of dot-bracket secondary-structure strings."""
    structures = ["(((())))", "(((...)))", "........", "((..))"]
    return structures
65
-
66
-
67
@pytest.fixture
def sample_dataset_entries():
    """Provide dataset records shaped like the examples' ``{"seq", "label"}`` rows."""
    pairs = [
        ("AUCG", "(..)"),
        ("AUGC", "().."),
        ("CGAU", "(())"),
        ("GAUC", "...."),
    ]
    return [{"seq": seq, "label": label} for seq, label in pairs]
76
-
77
-
78
@pytest.fixture
def mock_model_config():
    """Provide a ``MagicMock`` standing in for a model config.

    Only ``hidden_size``, ``num_labels``, ``vocab_size`` and
    ``max_position_embeddings`` are set to concrete values.
    """
    from unittest.mock import MagicMock

    return MagicMock(
        hidden_size=768,
        num_labels=2,
        vocab_size=32,
        max_position_embeddings=512,
    )
88
-
89
-
90
@pytest.fixture
def mock_tokenizer():
    """Provide a ``MagicMock`` tokenizer with canned return values.

    ``encode``, ``decode`` and ``convert_ids_to_tokens`` return fixed values;
    ``vocab_size``, ``pad_token_id`` and ``eos_token_id`` are plain attributes.
    """
    from unittest.mock import MagicMock

    canned = {
        "encode.return_value": [1, 2, 3, 4, 5],
        "decode.return_value": "AUGC",
        "convert_ids_to_tokens.return_value": ["A", "U", "G", "C"],
        "vocab_size": 32,
        "pad_token_id": 0,
        "eos_token_id": 2,
    }
    return MagicMock(**canned)
102
-
103
-
104
@pytest.fixture
def temp_data_dir(tmp_path):
    """Create a temporary dataset directory laid out like the examples.

    The directory holds ``train.json`` and ``test.json`` (one JSON object per
    line with ``seq``/``label`` keys) plus a minimal ``config.py``.

    Args:
        tmp_path: pytest's per-test temporary directory fixture.

    Returns:
        pathlib.Path: the created ``data`` directory.
    """
    import json

    data_dir = tmp_path / "data"
    data_dir.mkdir()

    def _write_jsonl(path, records):
        # One JSON object per line; json.dumps' default separators produce the
        # same '{"seq": "...", "label": "..."}' text as hand-written lines.
        path.write_text("\n".join(json.dumps(r) for r in records), encoding="utf-8")

    # Every dot-bracket label has exactly one character per nucleotide.
    _write_jsonl(data_dir / "train.json", [
        {"seq": "AUCG", "label": "(..)"},
        {"seq": "AUGC", "label": "().."},
        {"seq": "CGAU", "label": "(())"},
    ])
    _write_jsonl(data_dir / "test.json", [
        {"seq": "GAUC", "label": "...."},
        {"seq": "UCGA", "label": "(..)"},
    ])

    config_content = '''
# Dataset configuration
max_length = 512
num_labels = 4
task_type = "classification"
'''
    (data_dir / "config.py").write_text(config_content, encoding="utf-8")

    return data_dir
138
-
139
-
140
@pytest.fixture(scope="session")
def examples_dir():
    """Return the repository's ``examples`` directory (a ``Path`` under ``ROOT_DIR``)."""
    return ROOT_DIR.joinpath("examples")
144
-
145
-
146
@pytest.fixture
def skip_if_no_omnigenome():
    """Skip the requesting test when ``omnigenome`` cannot be imported.

    Returns ``False`` when the import succeeds; otherwise ``pytest.skip`` is
    called with a short reason.
    """
    try:
        import omnigenome  # noqa: F401  (imported only to probe availability)
    except ImportError:
        pytest.skip("omnigenome package not available")
    else:
        return False
154
-
155
-
156
# Suite-wide warning filters. A module-level ``pytestmark`` only marks tests
# defined in that same module, and this conftest defines none, so the marks
# are also attached to every collected item by ``pytest_itemcollected``.
pytestmark = [
    pytest.mark.filterwarnings("ignore:.*:DeprecationWarning"),
    pytest.mark.filterwarnings("ignore:.*:UserWarning"),
]


def pytest_itemcollected(item):
    """Attach the suite-wide ``filterwarnings`` marks to each collected item."""
    for mark in pytestmark:
        item.add_marker(mark)
tests/test_dataset_patterns.py DELETED
@@ -1,291 +0,0 @@
1
- """
2
- Test dataset loading and processing patterns based on examples.
3
- """
4
- import pytest
5
- import json
6
- import tempfile
7
- import os
8
- from unittest.mock import patch, MagicMock, mock_open
9
-
10
-
11
class TestDatasetPatterns:
    """Test dataset patterns from examples."""

    def test_dataset_imports(self):
        """Check that the dataset classes used by the examples are importable."""
        try:
            from omnigenome import (  # noqa: F401  (the import itself is the check)
                OmniGenomeDatasetForSequenceClassification,
                OmniGenomeDatasetForSequenceRegression,
                OmniGenomeDatasetForTokenClassification,
                OmniGenomeDatasetForTokenRegression,
            )
        except ImportError:
            pytest.skip("omnigenome not available or missing dependencies")

    def test_json_dataset_format(self):
        """Test JSON dataset format used in examples."""
        # Sample data format from toy_datasets: dot-bracket labels carry one
        # character per nucleotide.
        sample_data = [
            {"seq": "AUCG", "label": "(..)"},
            {"seq": "AUGC", "label": "(..)"},
            {"seq": "CGAU", "label": "().."},
        ]

        for item in sample_data:
            assert "seq" in item
            assert "label" in item
            assert isinstance(item["seq"], str)
            assert len(item["seq"]) > 0
            assert len(item["label"]) == len(item["seq"])

    @patch("builtins.open", new_callable=mock_open)
    @patch("json.loads")
    def test_dataset_loading_pattern(self, mock_json_loads, mock_file):
        """Test dataset loading pattern from examples."""
        # Mock data similar to examples
        mock_data = [
            {"seq": "AUCG", "label": "(..)"},
            {"seq": "AUGC", "label": "()"},
        ]

        mock_json_loads.return_value = mock_data[0]
        # Iterating the mocked file handle yields two JSON lines.
        mock_file.return_value.__iter__ = lambda self: iter([
            '{"seq": "AUCG", "label": "(..)"}\n',
            '{"seq": "AUGC", "label": "()"}\n'
        ])

        # Pattern from examples for loading test data
        def load_test_data(file_path):
            """Pattern from Secondary_Structure_Prediction.py."""
            data = []
            with open(file_path) as f:
                for line in f:
                    data.append(json.loads(line))
            return data

        # Test the pattern
        result = load_test_data("test_file.json")
        assert len(result) == 2

    def test_config_file_structure(self):
        """Test config.py structure from toy_datasets."""
        # Common config patterns from examples
        config_patterns = {
            "max_length": [128, 256, 512, 1024],
            "num_labels": [2, 3, 4, 5],
            "task_type": ["classification", "regression", "token_classification"],
        }

        for key, valid_values in config_patterns.items():
            assert isinstance(key, str)
            assert isinstance(valid_values, list)
            assert len(valid_values) > 0

    def test_sample_data_extraction_pattern(self):
        """Test sample data extraction pattern from examples."""
        try:
            import numpy as np
        except ImportError:
            pytest.skip("numpy not available")

        def sample_rna_sequence_pattern():
            """Pattern from Secondary_Structure_Prediction.py."""
            try:
                # Mock data similar to toy_datasets/Archive2/test.json
                mock_examples = [
                    {"seq": "AUCG", "label": "(..)"},
                    {"seq": "AUGC", "label": "().."},
                    {"seq": "CGAU", "label": "(())"},
                ]
                ex = mock_examples[np.random.randint(len(mock_examples))]
                return ex['seq'], ex.get('label', '')
            except Exception as e:
                return f"Error loading sample: {e}", ""

        # Test the pattern
        seq, label = sample_rna_sequence_pattern()
        assert isinstance(seq, str)
        assert isinstance(label, str)

    def test_data_validation_patterns(self):
        """Test data validation patterns from examples."""
        def validate_sequence_label_pair(seq, label):
            """Validate sequence-label pair format."""
            if not isinstance(seq, str) or not isinstance(label, str):
                return False
            if len(seq) == 0:
                return False
            # RNA sequence validation
            if not all(base in "AUCG" for base in seq):
                return False
            # Structure validation (if applicable)
            if label and not all(c in "()." for c in label):
                return False
            return True

        # Test valid pairs
        valid_pairs = [
            ("AUCG", "(..)"),
            ("AUG", "..."),
            ("AU", "()"),
            ("A", "."),
        ]

        for seq, label in valid_pairs:
            assert validate_sequence_label_pair(seq, label)

        # Test invalid pairs
        invalid_pairs = [
            ("", ""),  # Empty sequence
            ("AUXG", "(..)"),  # Invalid base X
            ("AUCG", "(.)X"),  # Invalid structure char
            (123, "(..)"),  # Non-string sequence
            ("AUCG", 123),  # Non-string label
        ]

        for seq, label in invalid_pairs:
            assert not validate_sequence_label_pair(seq, label)

    def test_train_test_split_patterns(self):
        """Test train/test split patterns from examples."""
        # Mock dataset similar to toy_datasets structure
        mock_data = [
            {"seq": "AUCG", "label": "(..)"},
            {"seq": "AUGC", "label": "().."},
            {"seq": "CGAU", "label": "(())"},
            {"seq": "GAUC", "label": "...."},
        ]

        def split_data_pattern(data, train_ratio=0.8):
            """Simple train/test split pattern (shuffles ``data`` in place)."""
            import random
            random.shuffle(data)
            split_idx = int(len(data) * train_ratio)
            return data[:split_idx], data[split_idx:]

        train_data, test_data = split_data_pattern(mock_data.copy())

        # Verify split
        assert len(train_data) + len(test_data) == len(mock_data)
        assert len(train_data) >= len(test_data)  # With 80/20 split

    def test_dataset_file_patterns(self):
        """Test dataset file naming patterns from examples."""
        expected_files = ["train.json", "test.json", "valid.json", "config.py"]

        for filename in expected_files:
            # Verify naming patterns
            if filename.endswith(".json"):
                assert filename in ["train.json", "test.json", "valid.json"]
            elif filename.endswith(".py"):
                assert filename == "config.py"

    def test_dataset_initialization_pattern(self):
        """Test dataset initialization pattern from examples."""
        try:
            import omnigenome
            from omnigenome import OmniGenomeDatasetForSequenceClassification  # noqa: F401
        except ImportError:
            pytest.skip("omnigenome not available")

        with patch("omnigenome.OmniGenomeDatasetForSequenceClassification") as mock_dataset:
            mock_dataset.return_value = MagicMock()

            # Create a single mock tokenizer instance to use in both call and assertion
            mock_tokenizer_instance = MagicMock()

            # Pattern from examples. The class must be looked up through the
            # module *inside* the patch context: a name bound earlier by
            # ``from omnigenome import ...`` still refers to the real class,
            # so the mock would never be called.
            dataset = omnigenome.OmniGenomeDatasetForSequenceClassification(
                train_file="path/to/train.json",
                test_file="path/to/test.json",
                tokenizer=mock_tokenizer_instance,
                max_length=512
            )

            # Verify the call was made with the expected arguments
            mock_dataset.assert_called_once()
            call_args = mock_dataset.call_args
            assert call_args[1]["train_file"] == "path/to/train.json"
            assert call_args[1]["test_file"] == "path/to/test.json"
            assert call_args[1]["tokenizer"] is mock_tokenizer_instance
            assert call_args[1]["max_length"] == 512
            assert dataset is mock_dataset.return_value

    def test_benchmark_dataset_structure(self):
        """Test benchmark dataset structure from examples."""
        # RGB benchmark structure from examples
        rgb_tasks = [
            "RNA-mRNA",
            "RNA-SNMD",
            "RNA-SNMR",
            "RNA-SSP-Archive2",
            "RNA-SSP-bpRNA",
            "RNA-SSP-rnastralign"
        ]

        for task in rgb_tasks:
            assert isinstance(task, str)
            assert "RNA" in task
            assert len(task) > 3

    def test_eterna_dataset_pattern(self):
        """Test Eterna dataset pattern from RNA design examples."""
        # Pattern from eterna100_vienna2.txt usage
        def load_eterna_pattern():
            """Mock Eterna dataset loading pattern."""
            # This would normally read from eterna100_vienna2.txt
            mock_eterna_data = [
                "(((...)))",
                "(((())))",
                "........",
                "((..))"
            ]
            return mock_eterna_data

        eterna_structures = load_eterna_pattern()

        for structure in eterna_structures:
            assert isinstance(structure, str)
            assert all(c in "()." for c in structure)

    def test_solved_sequences_format(self):
        """Test solved sequences format from RNA design examples."""
        # Format from solved_sequences.json in RNA design; each structure has
        # one dot-bracket character per nucleotide of its sequence.
        solved_format = {
            "puzzle_1": {
                "sequence": "AUCG",
                "structure": "(..)",
                "energy": -5.2
            },
            "puzzle_2": {
                "sequence": "AUGC",
                "structure": "()..",
                "energy": -3.1
            }
        }

        for puzzle_id, data in solved_format.items():
            assert isinstance(puzzle_id, str)
            assert "sequence" in data
            assert "structure" in data
            assert "energy" in data
            assert isinstance(data["energy"], (int, float))

    def test_data_loading_error_handling(self):
        """Test error handling patterns from examples."""
        def safe_load_pattern(file_path):
            """Safe loading pattern from examples."""
            try:
                # Mock successful load
                return [{"seq": "AUCG", "label": "(..)"}]
            except FileNotFoundError:
                return []
            except json.JSONDecodeError:
                return []
            except Exception as e:
                print(f"Unexpected error: {e}")
                return []

        # Test error handling
        result = safe_load_pattern("nonexistent.json")
        assert isinstance(result, list)
tests/test_examples_syntax.py DELETED
@@ -1,83 +0,0 @@
1
- import os
2
- import glob
3
- import ast
4
- import py_compile
5
-
6
- import nbformat
7
- import pytest
8
-
9
# Repository root: the parent of the directory that contains this test file.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
EXAMPLES_DIR = os.path.join(ROOT_DIR, "examples")
12
-
13
- # -----------------------------------------------------------------------------
14
- # Helper collectors
15
- # -----------------------------------------------------------------------------
16
-
17
- def _collect_example_py_files():
18
- """Return list of all *.py files under examples/ recursively."""
19
- pattern = os.path.join(EXAMPLES_DIR, "**", "*.py")
20
- return [path for path in glob.glob(pattern, recursive=True) if os.path.isfile(path)]
21
-
22
-
23
- def _collect_example_notebooks():
24
- """Return list of all *.ipynb files under examples/ recursively."""
25
- pattern = os.path.join(EXAMPLES_DIR, "**", "*.ipynb")
26
- return [path for path in glob.glob(pattern, recursive=True) if os.path.isfile(path)]
27
-
28
-
29
- # -----------------------------------------------------------------------------
30
- # Tests for Python scripts
31
- # -----------------------------------------------------------------------------
32
-
33
@pytest.mark.parametrize("py_path", _collect_example_py_files())
def test_example_python_files_compile(py_path):
    """Ensure each example Python script has valid syntax.

    The source is compiled with the built-in ``compile`` and never executed,
    so the examples' heavy runtime dependencies are not needed.
    ``compile`` is used instead of ``py_compile.compile`` because the latter
    writes ``__pycache__`` bytecode files into the examples tree.
    """
    with open(py_path, "rb") as fh:
        source = fh.read()
    # Passing bytes lets ``compile`` apply the file's PEP 263 encoding
    # declaration, if any. A ``SyntaxError`` (or ``ValueError`` for source
    # containing null bytes) propagates and fails the test.
    compile(source, py_path, "exec", dont_inherit=True)
43
-
44
-
45
- # -----------------------------------------------------------------------------
46
- # Tests for Jupyter notebooks
47
- # -----------------------------------------------------------------------------
48
-
49
-
50
- def _clean_code(source: str) -> str:
51
- """Remove Jupyter magics / shell escapes so source can be parsed by ``ast``.
52
-
53
- Lines starting with ``%`` or ``!`` are stripped because they are not valid
54
- Python syntax outside a notebook environment.
55
- """
56
- cleaned_lines = []
57
- for line in source.splitlines():
58
- stripped = line.lstrip()
59
- if stripped.startswith("%") or stripped.startswith("!"):
60
- # Skip IPython magic or shell command
61
- continue
62
- cleaned_lines.append(line)
63
- return "\n".join(cleaned_lines)
64
-
65
-
66
@pytest.mark.parametrize("nb_path", _collect_example_notebooks())
def test_example_notebook_cells_parse(nb_path):
    """Check that every code cell of an example notebook is valid Python.

    Cells are never executed. Each code cell's source has its IPython
    magics and shell escapes stripped by ``_clean_code`` and is then handed
    to ``ast.parse``. Invalid code raises ``SyntaxError``, which fails the
    test.
    """
    notebook = nbformat.read(nb_path, as_version=4)
    code_cells = (cell for cell in notebook.cells if cell.cell_type == "code")
    for cell in code_cells:
        source = _clean_code(cell.source)
        if not source.strip():
            # Nothing left to check once magics/shell lines are removed.
            continue
        ast.parse(source, filename=nb_path)