omnigenome 0.3.0a0__py3-none-any.whl → 0.3.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of omnigenome might be problematic. Click here for more details.
- omnigenome/__init__.py +14 -37
- omnigenome/src/misc/utils.py +199 -139
- omnigenome/src/model/rna_design/model.py +139 -96
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.0a1.dist-info}/METADATA +3 -3
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.0a1.dist-info}/RECORD +9 -16
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.0a1.dist-info}/top_level.txt +0 -1
- tests/__init__.py +0 -9
- tests/conftest.py +0 -160
- tests/test_dataset_patterns.py +0 -291
- tests/test_examples_syntax.py +0 -83
- tests/test_model_loading.py +0 -183
- tests/test_rna_functions.py +0 -255
- tests/test_training_patterns.py +0 -302
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.0a1.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.0a1.dist-info}/entry_points.txt +0 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.0a1.dist-info}/licenses/LICENSE +0 -0
tests/test_dataset_patterns.py
DELETED
|
@@ -1,291 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Test dataset loading and processing patterns based on examples.
|
|
3
|
-
"""
|
|
4
|
-
import pytest
|
|
5
|
-
import json
|
|
6
|
-
import tempfile
|
|
7
|
-
import os
|
|
8
|
-
from unittest.mock import patch, MagicMock, mock_open
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
class TestDatasetPatterns:
|
|
12
|
-
"""Test dataset patterns from examples."""
|
|
13
|
-
|
|
14
|
-
def test_dataset_imports(self):
|
|
15
|
-
"""Test dataset class imports as shown in examples."""
|
|
16
|
-
try:
|
|
17
|
-
from omnigenome import (
|
|
18
|
-
OmniGenomeDatasetForSequenceClassification,
|
|
19
|
-
OmniGenomeDatasetForSequenceRegression,
|
|
20
|
-
OmniGenomeDatasetForTokenClassification,
|
|
21
|
-
OmniGenomeDatasetForTokenRegression,
|
|
22
|
-
)
|
|
23
|
-
assert True
|
|
24
|
-
except ImportError:
|
|
25
|
-
pytest.skip("omnigenome not available or missing dependencies")
|
|
26
|
-
|
|
27
|
-
def test_json_dataset_format(self):
|
|
28
|
-
"""Test JSON dataset format used in examples."""
|
|
29
|
-
# Sample data format from toy_datasets
|
|
30
|
-
sample_data = [
|
|
31
|
-
{"seq": "AUCG", "label": "(...)"},
|
|
32
|
-
{"seq": "AUGC", "label": "(..)"},
|
|
33
|
-
{"seq": "CGAU", "label": "().."},
|
|
34
|
-
]
|
|
35
|
-
|
|
36
|
-
# Verify format
|
|
37
|
-
for item in sample_data:
|
|
38
|
-
assert "seq" in item
|
|
39
|
-
assert "label" in item
|
|
40
|
-
assert isinstance(item["seq"], str)
|
|
41
|
-
assert len(item["seq"]) > 0
|
|
42
|
-
|
|
43
|
-
@patch("builtins.open", new_callable=mock_open)
|
|
44
|
-
@patch("json.loads")
|
|
45
|
-
def test_dataset_loading_pattern(self, mock_json_loads, mock_file):
|
|
46
|
-
"""Test dataset loading pattern from examples."""
|
|
47
|
-
# Mock data similar to examples
|
|
48
|
-
mock_data = [
|
|
49
|
-
{"seq": "AUCG", "label": "(..)"},
|
|
50
|
-
{"seq": "AUGC", "label": "()"},
|
|
51
|
-
]
|
|
52
|
-
|
|
53
|
-
mock_json_loads.return_value = mock_data[0]
|
|
54
|
-
mock_file.return_value.__iter__ = lambda self: iter([
|
|
55
|
-
'{"seq": "AUCG", "label": "(..)"}\n',
|
|
56
|
-
'{"seq": "AUGC", "label": "()"}\n'
|
|
57
|
-
])
|
|
58
|
-
|
|
59
|
-
# Pattern from examples for loading test data
|
|
60
|
-
def load_test_data(file_path):
|
|
61
|
-
"""Pattern from Secondary_Structure_Prediction.py."""
|
|
62
|
-
data = []
|
|
63
|
-
with open(file_path) as f:
|
|
64
|
-
for line in f:
|
|
65
|
-
data.append(json.loads(line))
|
|
66
|
-
return data
|
|
67
|
-
|
|
68
|
-
# Test the pattern
|
|
69
|
-
result = load_test_data("test_file.json")
|
|
70
|
-
assert len(result) == 2
|
|
71
|
-
|
|
72
|
-
def test_config_file_structure(self):
|
|
73
|
-
"""Test config.py structure from toy_datasets."""
|
|
74
|
-
# Common config patterns from examples
|
|
75
|
-
config_patterns = {
|
|
76
|
-
"max_length": [128, 256, 512, 1024],
|
|
77
|
-
"num_labels": [2, 3, 4, 5],
|
|
78
|
-
"task_type": ["classification", "regression", "token_classification"],
|
|
79
|
-
}
|
|
80
|
-
|
|
81
|
-
for key, valid_values in config_patterns.items():
|
|
82
|
-
assert isinstance(key, str)
|
|
83
|
-
assert isinstance(valid_values, list)
|
|
84
|
-
assert len(valid_values) > 0
|
|
85
|
-
|
|
86
|
-
def test_sample_data_extraction_pattern(self):
|
|
87
|
-
"""Test sample data extraction pattern from examples."""
|
|
88
|
-
import random
|
|
89
|
-
try:
|
|
90
|
-
import numpy as np
|
|
91
|
-
except ImportError:
|
|
92
|
-
pytest.skip("numpy not available")
|
|
93
|
-
|
|
94
|
-
def sample_rna_sequence_pattern():
|
|
95
|
-
"""Pattern from Secondary_Structure_Prediction.py."""
|
|
96
|
-
try:
|
|
97
|
-
# Mock data similar to toy_datasets/Archive2/test.json
|
|
98
|
-
mock_examples = [
|
|
99
|
-
{"seq": "AUCG", "label": "(..)"},
|
|
100
|
-
{"seq": "AUGC", "label": "().."},
|
|
101
|
-
{"seq": "CGAU", "label": "(())"},
|
|
102
|
-
]
|
|
103
|
-
ex = mock_examples[np.random.randint(len(mock_examples))]
|
|
104
|
-
return ex['seq'], ex.get('label', '')
|
|
105
|
-
except Exception as e:
|
|
106
|
-
return f"Error loading sample: {e}", ""
|
|
107
|
-
|
|
108
|
-
# Test the pattern
|
|
109
|
-
seq, label = sample_rna_sequence_pattern()
|
|
110
|
-
assert isinstance(seq, str)
|
|
111
|
-
assert isinstance(label, str)
|
|
112
|
-
|
|
113
|
-
def test_data_validation_patterns(self):
|
|
114
|
-
"""Test data validation patterns from examples."""
|
|
115
|
-
def validate_sequence_label_pair(seq, label):
|
|
116
|
-
"""Validate sequence-label pair format."""
|
|
117
|
-
if not isinstance(seq, str) or not isinstance(label, str):
|
|
118
|
-
return False
|
|
119
|
-
if len(seq) == 0:
|
|
120
|
-
return False
|
|
121
|
-
# RNA sequence validation
|
|
122
|
-
if not all(base in "AUCG" for base in seq):
|
|
123
|
-
return False
|
|
124
|
-
# Structure validation (if applicable)
|
|
125
|
-
if label and not all(c in "()." for c in label):
|
|
126
|
-
return False
|
|
127
|
-
return True
|
|
128
|
-
|
|
129
|
-
# Test valid pairs
|
|
130
|
-
valid_pairs = [
|
|
131
|
-
("AUCG", "(..)"),
|
|
132
|
-
("AUG", "..."),
|
|
133
|
-
("AU", "()"),
|
|
134
|
-
("A", "."),
|
|
135
|
-
]
|
|
136
|
-
|
|
137
|
-
for seq, label in valid_pairs:
|
|
138
|
-
assert validate_sequence_label_pair(seq, label)
|
|
139
|
-
|
|
140
|
-
# Test invalid pairs
|
|
141
|
-
invalid_pairs = [
|
|
142
|
-
("", ""), # Empty sequence
|
|
143
|
-
("AUXG", "(..)"), # Invalid base X
|
|
144
|
-
("AUCG", "(.)X"), # Invalid structure char
|
|
145
|
-
(123, "(..)"), # Non-string sequence
|
|
146
|
-
("AUCG", 123), # Non-string label
|
|
147
|
-
]
|
|
148
|
-
|
|
149
|
-
for seq, label in invalid_pairs:
|
|
150
|
-
assert not validate_sequence_label_pair(seq, label)
|
|
151
|
-
|
|
152
|
-
def test_train_test_split_patterns(self):
|
|
153
|
-
"""Test train/test split patterns from examples."""
|
|
154
|
-
# Mock dataset similar to toy_datasets structure
|
|
155
|
-
mock_data = [
|
|
156
|
-
{"seq": "AUCG", "label": "(..)"},
|
|
157
|
-
{"seq": "AUGC", "label": "().."},
|
|
158
|
-
{"seq": "CGAU", "label": "(())"},
|
|
159
|
-
{"seq": "GAUC", "label": "...."},
|
|
160
|
-
]
|
|
161
|
-
|
|
162
|
-
def split_data_pattern(data, train_ratio=0.8):
|
|
163
|
-
"""Simple train/test split pattern."""
|
|
164
|
-
import random
|
|
165
|
-
random.shuffle(data)
|
|
166
|
-
split_idx = int(len(data) * train_ratio)
|
|
167
|
-
return data[:split_idx], data[split_idx:]
|
|
168
|
-
|
|
169
|
-
train_data, test_data = split_data_pattern(mock_data.copy())
|
|
170
|
-
|
|
171
|
-
# Verify split
|
|
172
|
-
assert len(train_data) + len(test_data) == len(mock_data)
|
|
173
|
-
assert len(train_data) >= len(test_data) # With 80/20 split
|
|
174
|
-
|
|
175
|
-
def test_dataset_file_patterns(self):
|
|
176
|
-
"""Test dataset file naming patterns from examples."""
|
|
177
|
-
expected_files = ["train.json", "test.json", "valid.json", "config.py"]
|
|
178
|
-
|
|
179
|
-
for filename in expected_files:
|
|
180
|
-
# Verify naming patterns
|
|
181
|
-
if filename.endswith(".json"):
|
|
182
|
-
assert filename in ["train.json", "test.json", "valid.json"]
|
|
183
|
-
elif filename.endswith(".py"):
|
|
184
|
-
assert filename == "config.py"
|
|
185
|
-
|
|
186
|
-
def test_dataset_initialization_pattern(self):
|
|
187
|
-
"""Test dataset initialization pattern from examples."""
|
|
188
|
-
try:
|
|
189
|
-
from omnigenome import OmniGenomeDatasetForSequenceClassification
|
|
190
|
-
except ImportError:
|
|
191
|
-
pytest.skip("omnigenome not available")
|
|
192
|
-
|
|
193
|
-
with patch("omnigenome.OmniGenomeDatasetForSequenceClassification") as mock_dataset:
|
|
194
|
-
mock_dataset.return_value = MagicMock()
|
|
195
|
-
|
|
196
|
-
# Create a single mock tokenizer instance to use in both call and assertion
|
|
197
|
-
mock_tokenizer_instance = MagicMock()
|
|
198
|
-
|
|
199
|
-
# Pattern from examples
|
|
200
|
-
dataset = OmniGenomeDatasetForSequenceClassification(
|
|
201
|
-
train_file="path/to/train.json",
|
|
202
|
-
test_file="path/to/test.json",
|
|
203
|
-
tokenizer=mock_tokenizer_instance,
|
|
204
|
-
max_length=512
|
|
205
|
-
)
|
|
206
|
-
|
|
207
|
-
# Verify the call was made with the expected arguments
|
|
208
|
-
mock_dataset.assert_called_once()
|
|
209
|
-
call_args = mock_dataset.call_args
|
|
210
|
-
assert call_args[1]["train_file"] == "path/to/train.json"
|
|
211
|
-
assert call_args[1]["test_file"] == "path/to/test.json"
|
|
212
|
-
assert call_args[1]["max_length"] == 512
|
|
213
|
-
|
|
214
|
-
def test_benchmark_dataset_structure(self):
|
|
215
|
-
"""Test benchmark dataset structure from examples."""
|
|
216
|
-
# RGB benchmark structure from examples
|
|
217
|
-
rgb_tasks = [
|
|
218
|
-
"RNA-mRNA",
|
|
219
|
-
"RNA-SNMD",
|
|
220
|
-
"RNA-SNMR",
|
|
221
|
-
"RNA-SSP-Archive2",
|
|
222
|
-
"RNA-SSP-bpRNA",
|
|
223
|
-
"RNA-SSP-rnastralign"
|
|
224
|
-
]
|
|
225
|
-
|
|
226
|
-
for task in rgb_tasks:
|
|
227
|
-
assert isinstance(task, str)
|
|
228
|
-
assert "RNA" in task
|
|
229
|
-
assert len(task) > 3
|
|
230
|
-
|
|
231
|
-
def test_eterna_dataset_pattern(self):
|
|
232
|
-
"""Test Eterna dataset pattern from RNA design examples."""
|
|
233
|
-
# Pattern from eterna100_vienna2.txt usage
|
|
234
|
-
def load_eterna_pattern():
|
|
235
|
-
"""Mock Eterna dataset loading pattern."""
|
|
236
|
-
# This would normally read from eterna100_vienna2.txt
|
|
237
|
-
mock_eterna_data = [
|
|
238
|
-
"(((...)))",
|
|
239
|
-
"(((())))",
|
|
240
|
-
"........",
|
|
241
|
-
"((..))"
|
|
242
|
-
]
|
|
243
|
-
return mock_eterna_data
|
|
244
|
-
|
|
245
|
-
eterna_structures = load_eterna_pattern()
|
|
246
|
-
|
|
247
|
-
for structure in eterna_structures:
|
|
248
|
-
assert isinstance(structure, str)
|
|
249
|
-
assert all(c in "()." for c in structure)
|
|
250
|
-
|
|
251
|
-
def test_solved_sequences_format(self):
|
|
252
|
-
"""Test solved sequences format from RNA design examples."""
|
|
253
|
-
# Format from solved_sequences.json in RNA design
|
|
254
|
-
solved_format = {
|
|
255
|
-
"puzzle_1": {
|
|
256
|
-
"sequence": "AUCG",
|
|
257
|
-
"structure": "(..)",
|
|
258
|
-
"energy": -5.2
|
|
259
|
-
},
|
|
260
|
-
"puzzle_2": {
|
|
261
|
-
"sequence": "AUGC",
|
|
262
|
-
"structure": "().",
|
|
263
|
-
"energy": -3.1
|
|
264
|
-
}
|
|
265
|
-
}
|
|
266
|
-
|
|
267
|
-
for puzzle_id, data in solved_format.items():
|
|
268
|
-
assert isinstance(puzzle_id, str)
|
|
269
|
-
assert "sequence" in data
|
|
270
|
-
assert "structure" in data
|
|
271
|
-
assert "energy" in data
|
|
272
|
-
assert isinstance(data["energy"], (int, float))
|
|
273
|
-
|
|
274
|
-
def test_data_loading_error_handling(self):
|
|
275
|
-
"""Test error handling patterns from examples."""
|
|
276
|
-
def safe_load_pattern(file_path):
|
|
277
|
-
"""Safe loading pattern from examples."""
|
|
278
|
-
try:
|
|
279
|
-
# Mock successful load
|
|
280
|
-
return [{"seq": "AUCG", "label": "(..)"}]
|
|
281
|
-
except FileNotFoundError:
|
|
282
|
-
return []
|
|
283
|
-
except json.JSONDecodeError:
|
|
284
|
-
return []
|
|
285
|
-
except Exception as e:
|
|
286
|
-
print(f"Unexpected error: {e}")
|
|
287
|
-
return []
|
|
288
|
-
|
|
289
|
-
# Test error handling
|
|
290
|
-
result = safe_load_pattern("nonexistent.json")
|
|
291
|
-
assert isinstance(result, list)
|
tests/test_examples_syntax.py
DELETED
|
@@ -1,83 +0,0 @@
|
|
|
1
|
-
import os
|
|
2
|
-
import glob
|
|
3
|
-
import ast
|
|
4
|
-
import py_compile
|
|
5
|
-
|
|
6
|
-
import nbformat
|
|
7
|
-
import pytest
|
|
8
|
-
|
|
9
|
-
# Root directory of the repository (two levels up from this test file)
|
|
10
|
-
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
|
|
11
|
-
EXAMPLES_DIR = os.path.join(ROOT_DIR, "examples")
|
|
12
|
-
|
|
13
|
-
# -----------------------------------------------------------------------------
|
|
14
|
-
# Helper collectors
|
|
15
|
-
# -----------------------------------------------------------------------------
|
|
16
|
-
|
|
17
|
-
def _collect_example_py_files():
|
|
18
|
-
"""Return list of all *.py files under examples/ recursively."""
|
|
19
|
-
pattern = os.path.join(EXAMPLES_DIR, "**", "*.py")
|
|
20
|
-
return [path for path in glob.glob(pattern, recursive=True) if os.path.isfile(path)]
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def _collect_example_notebooks():
|
|
24
|
-
"""Return list of all *.ipynb files under examples/ recursively."""
|
|
25
|
-
pattern = os.path.join(EXAMPLES_DIR, "**", "*.ipynb")
|
|
26
|
-
return [path for path in glob.glob(pattern, recursive=True) if os.path.isfile(path)]
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
# -----------------------------------------------------------------------------
|
|
30
|
-
# Tests for Python scripts
|
|
31
|
-
# -----------------------------------------------------------------------------
|
|
32
|
-
|
|
33
|
-
@pytest.mark.parametrize("py_path", _collect_example_py_files())
|
|
34
|
-
def test_example_python_files_compile(py_path):
|
|
35
|
-
"""Ensure each example Python script has valid syntax.
|
|
36
|
-
|
|
37
|
-
This uses ``py_compile`` so the file is parsed by CPython without execution
|
|
38
|
-
of the module-level code, avoiding heavy runtime dependencies.
|
|
39
|
-
"""
|
|
40
|
-
# doraise=True raises a ``py_compile.PyCompileError`` on failure which
|
|
41
|
-
# pytest will treat as a test failure.
|
|
42
|
-
py_compile.compile(py_path, doraise=True)
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
# -----------------------------------------------------------------------------
|
|
46
|
-
# Tests for Jupyter notebooks
|
|
47
|
-
# -----------------------------------------------------------------------------
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
def _clean_code(source: str) -> str:
|
|
51
|
-
"""Remove Jupyter magics / shell escapes so source can be parsed by ``ast``.
|
|
52
|
-
|
|
53
|
-
Lines starting with ``%`` or ``!`` are stripped because they are not valid
|
|
54
|
-
Python syntax outside a notebook environment.
|
|
55
|
-
"""
|
|
56
|
-
cleaned_lines = []
|
|
57
|
-
for line in source.splitlines():
|
|
58
|
-
stripped = line.lstrip()
|
|
59
|
-
if stripped.startswith("%") or stripped.startswith("!"):
|
|
60
|
-
# Skip IPython magic or shell command
|
|
61
|
-
continue
|
|
62
|
-
cleaned_lines.append(line)
|
|
63
|
-
return "\n".join(cleaned_lines)
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
@pytest.mark.parametrize("nb_path", _collect_example_notebooks())
|
|
67
|
-
def test_example_notebook_cells_parse(nb_path):
|
|
68
|
-
"""Validate that each code cell in the example notebooks can be parsed.
|
|
69
|
-
|
|
70
|
-
Instead of executing potentially heavy code, we parse the cleaned source of
|
|
71
|
-
each code cell with the ``ast`` module to ensure syntactic correctness.
|
|
72
|
-
"""
|
|
73
|
-
nb = nbformat.read(nb_path, as_version=4)
|
|
74
|
-
for cell in nb.cells:
|
|
75
|
-
if cell.cell_type != "code":
|
|
76
|
-
continue
|
|
77
|
-
cleaned = _clean_code(cell.source)
|
|
78
|
-
if cleaned.strip() == "":
|
|
79
|
-
# Skip empty cells after cleaning
|
|
80
|
-
continue
|
|
81
|
-
# ``ast.parse`` raises ``SyntaxError`` on invalid Python code which will
|
|
82
|
-
# fail the test if encountered.
|
|
83
|
-
ast.parse(cleaned, filename=nb_path)
|
tests/test_model_loading.py
DELETED
|
@@ -1,183 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Test model loading functionality based on examples.
|
|
3
|
-
"""
|
|
4
|
-
import pytest
|
|
5
|
-
import tempfile
|
|
6
|
-
import os
|
|
7
|
-
from unittest.mock import patch, MagicMock
|
|
8
|
-
|
|
9
|
-
try:
|
|
10
|
-
import torch
|
|
11
|
-
except ImportError:
|
|
12
|
-
torch = None
|
|
13
|
-
|
|
14
|
-
# Skip heavy model loading tests by default - can be run with --run-slow
|
|
15
|
-
pytestmark = pytest.mark.slow
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
class TestModelLoading:
|
|
19
|
-
"""Test model loading similar to examples."""
|
|
20
|
-
|
|
21
|
-
@pytest.fixture
|
|
22
|
-
def mock_model_config(self):
|
|
23
|
-
"""Mock model config to avoid downloading real models."""
|
|
24
|
-
config = MagicMock()
|
|
25
|
-
config.hidden_size = 768
|
|
26
|
-
config.num_labels = 2
|
|
27
|
-
return config
|
|
28
|
-
|
|
29
|
-
@pytest.fixture
|
|
30
|
-
def mock_tokenizer(self):
|
|
31
|
-
"""Mock tokenizer for testing."""
|
|
32
|
-
tokenizer = MagicMock()
|
|
33
|
-
tokenizer.encode.return_value = [1, 2, 3, 4, 5]
|
|
34
|
-
tokenizer.decode.return_value = "AUGC"
|
|
35
|
-
tokenizer.convert_ids_to_tokens.return_value = ["A", "U", "G", "C"]
|
|
36
|
-
return tokenizer
|
|
37
|
-
|
|
38
|
-
def test_model_import_structure(self):
|
|
39
|
-
"""Test that model classes can be imported as shown in examples."""
|
|
40
|
-
try:
|
|
41
|
-
from omnigenome import (
|
|
42
|
-
OmniModelForSequenceClassification,
|
|
43
|
-
OmniModelForTokenClassification,
|
|
44
|
-
OmniModelForSequenceRegression,
|
|
45
|
-
OmniModelForTokenRegression,
|
|
46
|
-
)
|
|
47
|
-
# If import succeeds, test passes
|
|
48
|
-
assert True
|
|
49
|
-
except ImportError:
|
|
50
|
-
pytest.skip("omnigenome not available or missing dependencies")
|
|
51
|
-
|
|
52
|
-
def test_embedding_model_import(self):
|
|
53
|
-
"""Test embedding model import as shown in RNA_Embedding_Tutorial.ipynb."""
|
|
54
|
-
try:
|
|
55
|
-
from omnigenome import OmniGenomeModelForEmbedding
|
|
56
|
-
assert True
|
|
57
|
-
except ImportError:
|
|
58
|
-
pytest.skip("omnigenome not available or missing dependencies")
|
|
59
|
-
|
|
60
|
-
def test_pooling_import(self):
|
|
61
|
-
"""Test pooling import as shown in classification.ipynb."""
|
|
62
|
-
try:
|
|
63
|
-
from omnigenome import OmniModel, OmniPooling
|
|
64
|
-
assert True
|
|
65
|
-
except ImportError:
|
|
66
|
-
pytest.skip("omnigenome not available or missing dependencies")
|
|
67
|
-
|
|
68
|
-
def test_base_model_loading_pattern(self, mock_tokenizer):
|
|
69
|
-
"""Test the base model loading pattern from classification.ipynb."""
|
|
70
|
-
try:
|
|
71
|
-
from transformers import AutoTokenizer, AutoModel
|
|
72
|
-
except ImportError:
|
|
73
|
-
pytest.skip("transformers not available")
|
|
74
|
-
|
|
75
|
-
with patch('transformers.AutoTokenizer.from_pretrained') as mock_auto_tokenizer, \
|
|
76
|
-
patch('transformers.AutoModel.from_pretrained') as mock_auto_model:
|
|
77
|
-
|
|
78
|
-
# Mock the returns
|
|
79
|
-
mock_auto_tokenizer.return_value = mock_tokenizer
|
|
80
|
-
mock_auto_model.return_value = MagicMock()
|
|
81
|
-
|
|
82
|
-
# This pattern is from examples/custom_finetuning/classification.ipynb
|
|
83
|
-
model_name = "yangheng/OmniGenome-52M"
|
|
84
|
-
base_model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
|
|
85
|
-
base_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
|
86
|
-
|
|
87
|
-
# Verify the calls were made correctly
|
|
88
|
-
mock_auto_model.assert_called_once_with(model_name, trust_remote_code=True)
|
|
89
|
-
mock_auto_tokenizer.assert_called_once_with(model_name, trust_remote_code=True)
|
|
90
|
-
|
|
91
|
-
def test_embedding_model_initialization_pattern(self):
|
|
92
|
-
"""Test embedding model initialization pattern from RNA_Embedding_Tutorial.ipynb."""
|
|
93
|
-
if torch is None:
|
|
94
|
-
pytest.skip("torch not available")
|
|
95
|
-
|
|
96
|
-
try:
|
|
97
|
-
from omnigenome import OmniGenomeModelForEmbedding
|
|
98
|
-
except ImportError:
|
|
99
|
-
pytest.skip("omnigenome not available")
|
|
100
|
-
|
|
101
|
-
with patch('omnigenome.OmniGenomeModelForEmbedding') as mock_embedding_model:
|
|
102
|
-
mock_instance = MagicMock()
|
|
103
|
-
mock_instance.to.return_value = mock_instance
|
|
104
|
-
mock_embedding_model.return_value = mock_instance
|
|
105
|
-
|
|
106
|
-
model_name = "yangheng/OmniGenome-52M"
|
|
107
|
-
embedding_model = OmniGenomeModelForEmbedding(model_name, trust_remote_code=True).to(torch.device("cuda:0")).to(torch.float16)
|
|
108
|
-
|
|
109
|
-
# Verify initialization pattern
|
|
110
|
-
mock_embedding_model.assert_called_once_with(model_name, trust_remote_code=True)
|
|
111
|
-
assert mock_instance.to.call_count == 2 # Called twice for device and dtype
|
|
112
|
-
|
|
113
|
-
def test_model_parameter_patterns(self):
|
|
114
|
-
"""Test that common model parameters are recognized."""
|
|
115
|
-
# These are patterns seen across examples
|
|
116
|
-
common_model_names = [
|
|
117
|
-
"yangheng/OmniGenome-52M",
|
|
118
|
-
"yangheng/OmniGenome-186M",
|
|
119
|
-
"anonymous8/OmniGenome-186M",
|
|
120
|
-
"anonymous8/OmniGenome-52M"
|
|
121
|
-
]
|
|
122
|
-
|
|
123
|
-
for model_name in common_model_names:
|
|
124
|
-
# Just verify the string patterns are valid
|
|
125
|
-
assert isinstance(model_name, str)
|
|
126
|
-
assert "/" in model_name
|
|
127
|
-
assert "OmniGenome" in model_name
|
|
128
|
-
|
|
129
|
-
def test_classification_model_initialization_pattern(self, mock_tokenizer):
|
|
130
|
-
"""Test classification model init pattern from examples."""
|
|
131
|
-
try:
|
|
132
|
-
from omnigenome import OmniModelForSequenceClassification
|
|
133
|
-
except ImportError:
|
|
134
|
-
pytest.skip("omnigenome not available")
|
|
135
|
-
|
|
136
|
-
with patch('omnigenome.OmniModelForSequenceClassification') as mock_model_class:
|
|
137
|
-
mock_model_class.return_value = MagicMock()
|
|
138
|
-
|
|
139
|
-
# Pattern from classification.ipynb
|
|
140
|
-
model_name = "test_model"
|
|
141
|
-
tokenizer = mock_tokenizer
|
|
142
|
-
|
|
143
|
-
model = OmniModelForSequenceClassification(
|
|
144
|
-
config_or_model=model_name,
|
|
145
|
-
tokenizer=tokenizer,
|
|
146
|
-
num_labels=3,
|
|
147
|
-
)
|
|
148
|
-
|
|
149
|
-
mock_model_class.assert_called_once_with(
|
|
150
|
-
config_or_model=model_name,
|
|
151
|
-
tokenizer=tokenizer,
|
|
152
|
-
num_labels=3,
|
|
153
|
-
)
|
|
154
|
-
|
|
155
|
-
def test_rna_sequence_patterns(self):
|
|
156
|
-
"""Test RNA sequence patterns used in examples."""
|
|
157
|
-
# Patterns from RNA_Embedding_Tutorial.ipynb
|
|
158
|
-
rna_sequences = [
|
|
159
|
-
"AUGGCUACG",
|
|
160
|
-
"CGGAUACGGC",
|
|
161
|
-
"UGGCCAAGUC",
|
|
162
|
-
"AUGCUGCUAUGCUA"
|
|
163
|
-
]
|
|
164
|
-
|
|
165
|
-
for seq in rna_sequences:
|
|
166
|
-
# Basic validation of RNA sequence format
|
|
167
|
-
assert isinstance(seq, str)
|
|
168
|
-
assert len(seq) > 0
|
|
169
|
-
assert all(base in "AUCG" for base in seq)
|
|
170
|
-
|
|
171
|
-
def test_device_patterns(self):
|
|
172
|
-
"""Test device usage patterns from examples."""
|
|
173
|
-
if torch is None:
|
|
174
|
-
pytest.skip("torch not available")
|
|
175
|
-
|
|
176
|
-
# Pattern from examples: torch.device("cuda:0")
|
|
177
|
-
device = torch.device("cuda:0")
|
|
178
|
-
assert str(device) == "cuda:0"
|
|
179
|
-
|
|
180
|
-
# Alternative pattern
|
|
181
|
-
if torch.cuda.is_available():
|
|
182
|
-
device = torch.device("cuda")
|
|
183
|
-
assert "cuda" in str(device)
|