omnigenome 0.3.0a0__py3-none-any.whl → 0.3.1a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- omnigenome/__init__.py +29 -44
- omnigenome/auto/auto_bench/__init__.py +0 -1
- omnigenome/auto/auto_bench/auto_bench.py +24 -14
- omnigenome/auto/auto_train/__init__.py +0 -1
- omnigenome/auto/auto_train/auto_train.py +11 -12
- omnigenome/auto/bench_hub/__init__.py +0 -1
- omnigenome/auto/bench_hub/bench_hub.py +1 -1
- omnigenome/cli/__init__.py +0 -1
- omnigenome/cli/commands/__init__.py +0 -1
- omnigenome/cli/commands/base.py +10 -10
- omnigenome/cli/commands/bench/__init__.py +0 -1
- omnigenome/cli/commands/bench/bench_cli.py +10 -10
- omnigenome/cli/commands/rna/__init__.py +0 -1
- omnigenome/cli/commands/rna/rna_design.py +10 -11
- omnigenome/src/__init__.py +0 -1
- omnigenome/src/abc/__init__.py +0 -1
- omnigenome/src/abc/abstract_dataset.py +38 -19
- omnigenome/src/abc/abstract_metric.py +7 -7
- omnigenome/src/abc/abstract_model.py +15 -14
- omnigenome/src/abc/abstract_tokenizer.py +9 -7
- omnigenome/src/dataset/omni_dataset.py +16 -14
- omnigenome/src/lora/__init__.py +0 -1
- omnigenome/src/lora/lora_model.py +47 -41
- omnigenome/src/metric/classification_metric.py +11 -11
- omnigenome/src/metric/metric.py +19 -19
- omnigenome/src/metric/ranking_metric.py +15 -15
- omnigenome/src/metric/regression_metric.py +18 -18
- omnigenome/src/misc/utils.py +214 -150
- omnigenome/src/model/augmentation/__init__.py +0 -1
- omnigenome/src/model/augmentation/model.py +17 -17
- omnigenome/src/model/classification/__init__.py +0 -1
- omnigenome/src/model/classification/model.py +28 -32
- omnigenome/src/model/embedding/__init__.py +0 -1
- omnigenome/src/model/embedding/model.py +35 -35
- omnigenome/src/model/mlm/__init__.py +0 -1
- omnigenome/src/model/mlm/model.py +13 -13
- omnigenome/src/model/module_utils.py +17 -17
- omnigenome/src/model/regression/__init__.py +0 -1
- omnigenome/src/model/regression/model.py +72 -77
- omnigenome/src/model/regression/resnet.py +32 -32
- omnigenome/src/model/rna_design/__init__.py +0 -1
- omnigenome/src/model/rna_design/model.py +168 -118
- omnigenome/src/model/seq2seq/__init__.py +0 -1
- omnigenome/src/model/seq2seq/model.py +4 -4
- omnigenome/src/tokenizer/bpe_tokenizer.py +27 -27
- omnigenome/src/tokenizer/kmers_tokenizer.py +22 -22
- omnigenome/src/tokenizer/single_nucleotide_tokenizer.py +11 -11
- omnigenome/src/trainer/accelerate_trainer.py +40 -32
- omnigenome/src/trainer/hf_trainer.py +8 -8
- omnigenome/src/trainer/trainer.py +37 -25
- omnigenome/utility/dataset_hub/__init__.py +0 -1
- omnigenome/utility/dataset_hub/dataset_hub.py +13 -13
- omnigenome/utility/ensemble.py +26 -26
- omnigenome/utility/hub_utils.py +8 -8
- omnigenome/utility/model_hub/__init__.py +0 -1
- omnigenome/utility/model_hub/model_hub.py +26 -25
- omnigenome/utility/pipeline_hub/__init__.py +0 -1
- omnigenome/utility/pipeline_hub/pipeline.py +49 -49
- omnigenome/utility/pipeline_hub/pipeline_hub.py +17 -17
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/METADATA +3 -3
- omnigenome-0.3.1a0.dist-info/RECORD +78 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/top_level.txt +0 -1
- omnigenome-0.3.0a0.dist-info/RECORD +0 -85
- tests/__init__.py +0 -9
- tests/conftest.py +0 -160
- tests/test_dataset_patterns.py +0 -291
- tests/test_examples_syntax.py +0 -83
- tests/test_model_loading.py +0 -183
- tests/test_rna_functions.py +0 -255
- tests/test_training_patterns.py +0 -302
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/WHEEL +0 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/entry_points.txt +0 -0
- {omnigenome-0.3.0a0.dist-info → omnigenome-0.3.1a0.dist-info}/licenses/LICENSE +0 -0
tests/test_model_loading.py
DELETED
|
@@ -1,183 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Test model loading functionality based on examples.
|
|
3
|
-
"""
|
|
4
|
-
import pytest
|
|
5
|
-
import tempfile
|
|
6
|
-
import os
|
|
7
|
-
from unittest.mock import patch, MagicMock
|
|
8
|
-
|
|
9
|
-
# torch is an optional dependency: torch-dependent tests in this module check
# ``torch is None`` and skip themselves.
torch = None
try:
    import torch
except ImportError:
    pass  # keep the ``None`` placeholder assigned above

# Mark every test in this module as ``slow``. Whether slow tests are skipped by
# default (and re-enabled with ``--run-slow``) is decided by the project's pytest
# configuration — presumably conftest.py, which is not visible here.
pytestmark = pytest.mark.slow
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
class TestModelLoading:
    """Test model loading patterns taken from the OmniGenome examples.

    Pretrained weights, tokenizers and model classes are replaced with mocks,
    so these tests verify call patterns rather than performing real loading.
    """

    @staticmethod
    def _import_omnigenome(reason="omnigenome not available or missing dependencies"):
        """Import and return the ``omnigenome`` package, or skip the test.

        Args:
            reason: Skip message used when importing the package raises
                ``ImportError``.

        Returns:
            The imported ``omnigenome`` module.
        """
        try:
            import omnigenome
        except ImportError:
            pytest.skip(reason)
        return omnigenome

    def _assert_exports(self, *names):
        """Assert that ``omnigenome`` exposes every attribute in ``names``.

        Only an unimportable package is a skip. A missing public name is an API
        regression and fails, whereas ``from omnigenome import Name`` inside a
        ``try/except ImportError`` would have turned it into a silent skip.
        """
        omnigenome = self._import_omnigenome()
        missing = [name for name in names if not hasattr(omnigenome, name)]
        assert not missing, f"omnigenome does not export: {', '.join(missing)}"

    @pytest.fixture
    def mock_model_config(self):
        """Mock model config to avoid downloading real models."""
        config = MagicMock()
        config.hidden_size = 768
        config.num_labels = 2
        return config

    @pytest.fixture
    def mock_tokenizer(self):
        """Mock tokenizer for testing."""
        tokenizer = MagicMock()
        tokenizer.encode.return_value = [1, 2, 3, 4, 5]
        tokenizer.decode.return_value = "AUGC"
        tokenizer.convert_ids_to_tokens.return_value = ["A", "U", "G", "C"]
        return tokenizer

    def test_model_import_structure(self):
        """Test that model classes can be imported as shown in examples."""
        self._assert_exports(
            "OmniModelForSequenceClassification",
            "OmniModelForTokenClassification",
            "OmniModelForSequenceRegression",
            "OmniModelForTokenRegression",
        )

    def test_embedding_model_import(self):
        """Test embedding model import as shown in RNA_Embedding_Tutorial.ipynb."""
        self._assert_exports("OmniGenomeModelForEmbedding")

    def test_pooling_import(self):
        """Test pooling import as shown in classification.ipynb."""
        self._assert_exports("OmniModel", "OmniPooling")

    def test_base_model_loading_pattern(self, mock_tokenizer):
        """Test the base model loading pattern from classification.ipynb."""
        try:
            from transformers import AutoTokenizer, AutoModel
        except ImportError:
            pytest.skip("transformers not available")

        # Patch the very class objects used below, so no real download happens.
        with patch.object(AutoTokenizer, "from_pretrained") as mock_auto_tokenizer, \
                patch.object(AutoModel, "from_pretrained") as mock_auto_model:
            mock_auto_tokenizer.return_value = mock_tokenizer
            mock_auto_model.return_value = MagicMock()

            # This pattern is from examples/custom_finetuning/classification.ipynb
            model_name = "yangheng/OmniGenome-52M"
            base_model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
            base_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

            # Verify the calls were made correctly and the mocks were returned.
            mock_auto_model.assert_called_once_with(model_name, trust_remote_code=True)
            mock_auto_tokenizer.assert_called_once_with(model_name, trust_remote_code=True)
            assert base_model is mock_auto_model.return_value
            assert base_tokenizer is mock_tokenizer

    def test_embedding_model_initialization_pattern(self):
        """Test embedding model initialization pattern from RNA_Embedding_Tutorial.ipynb."""
        if torch is None:
            pytest.skip("torch not available")
        omnigenome = self._import_omnigenome("omnigenome not available")

        # Patch the package attribute AND call through the package. A name bound
        # via ``from omnigenome import ...`` before patching would still point at
        # the real class and trigger a real model download / CUDA transfer.
        with patch.object(omnigenome, "OmniGenomeModelForEmbedding") as mock_embedding_model:
            mock_instance = MagicMock()
            mock_instance.to.return_value = mock_instance
            mock_embedding_model.return_value = mock_instance

            model_name = "yangheng/OmniGenome-52M"
            embedding_model = (
                omnigenome.OmniGenomeModelForEmbedding(model_name, trust_remote_code=True)
                .to(torch.device("cuda:0"))
                .to(torch.float16)
            )

            # Verify initialization pattern
            mock_embedding_model.assert_called_once_with(model_name, trust_remote_code=True)
            assert embedding_model is mock_instance
            # ``.to`` is chained twice: first with the device, then with the dtype.
            moved_to = [c.args for c in mock_instance.to.call_args_list]
            assert moved_to == [(torch.device("cuda:0"),), (torch.float16,)]

    def test_model_parameter_patterns(self):
        """Test that common model parameters are recognized."""
        # These are patterns seen across examples
        common_model_names = [
            "yangheng/OmniGenome-52M",
            "yangheng/OmniGenome-186M",
            "anonymous8/OmniGenome-186M",
            "anonymous8/OmniGenome-52M",
        ]

        for model_name in common_model_names:
            # Just verify the string patterns are valid
            assert isinstance(model_name, str)
            assert "/" in model_name
            assert "OmniGenome" in model_name

    def test_classification_model_initialization_pattern(self, mock_tokenizer):
        """Test classification model init pattern from examples."""
        omnigenome = self._import_omnigenome("omnigenome not available")

        # As above: patch on the package and call through it so the mock is used.
        with patch.object(omnigenome, "OmniModelForSequenceClassification") as mock_model_class:
            mock_model_class.return_value = MagicMock()

            # Pattern from classification.ipynb
            model_name = "test_model"
            tokenizer = mock_tokenizer

            model = omnigenome.OmniModelForSequenceClassification(
                config_or_model=model_name,
                tokenizer=tokenizer,
                num_labels=3,
            )

            mock_model_class.assert_called_once_with(
                config_or_model=model_name,
                tokenizer=tokenizer,
                num_labels=3,
            )
            assert model is mock_model_class.return_value

    def test_rna_sequence_patterns(self):
        """Test RNA sequence patterns used in examples."""
        # Patterns from RNA_Embedding_Tutorial.ipynb
        rna_sequences = [
            "AUGGCUACG",
            "CGGAUACGGC",
            "UGGCCAAGUC",
            "AUGCUGCUAUGCUA",
        ]

        for seq in rna_sequences:
            # Basic validation of RNA sequence format
            assert isinstance(seq, str)
            assert len(seq) > 0
            assert all(base in "AUCG" for base in seq)

    def test_device_patterns(self):
        """Test device usage patterns from examples."""
        if torch is None:
            pytest.skip("torch not available")

        # Pattern from examples: torch.device("cuda:0") — constructing a device
        # object does not require CUDA to be present.
        device = torch.device("cuda:0")
        assert str(device) == "cuda:0"

        # Alternative pattern
        if torch.cuda.is_available():
            device = torch.device("cuda")
            assert "cuda" in str(device)
|
tests/test_rna_functions.py
DELETED
|
@@ -1,255 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Test RNA-specific functionality based on examples.
|
|
3
|
-
"""
|
|
4
|
-
import pytest
|
|
5
|
-
import tempfile
|
|
6
|
-
import os
|
|
7
|
-
from unittest.mock import patch, MagicMock
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
class TestRNAFunctions:
    """Test RNA helper functionality reproduced from the OmniGenome examples.

    Each test re-creates the helper it exercises locally, mirroring the code in
    the referenced example script, so no OmniGenome install is required.
    """

    def test_rna_sequence_validity_checker(self):
        """Test ss_validity_loss function from Secondary_Structure_Prediction.py."""
        # Recreate the function from the example
        def ss_validity_loss(rna_strct: str) -> float:
            """Return the fraction of bracket characters that are unmatched.

            Raises:
                ValueError: if ``rna_strct`` contains a character other than
                    ``(``, ``)`` or ``.``.
            """
            left = right = 0
            dots = rna_strct.count('.')
            for c in rna_strct:
                if c == '(':
                    left += 1
                elif c == ')':
                    if left:
                        left -= 1
                    else:
                        right += 1
                elif c != '.':
                    raise ValueError(f"Invalid char {c}")
            # The epsilon avoids division by zero for dot-only or empty input.
            return (left + right) / (len(rna_strct) - dots + 1e-8)

        # Test valid structures
        assert ss_validity_loss("(())") == 0.0
        assert ss_validity_loss("((..))") == 0.0
        assert ss_validity_loss("....") == 0.0

        # Test invalid structures
        assert ss_validity_loss("(((") > 0.0  # Unmatched left
        assert ss_validity_loss(")))") > 0.0  # Unmatched right
        assert ss_validity_loss("())(") > 0.0  # Mixed unmatched

        # Test error case
        with pytest.raises(ValueError, match="Invalid char"):
            ss_validity_loss("((X))")

    def test_find_invalid_positions(self):
        """Test find_invalid_positions function from Secondary_Structure_Prediction.py."""
        # Recreate the function from the example
        def find_invalid_positions(struct: str) -> list:
            """Return indices of unmatched ')' (in scan order) followed by unmatched '('."""
            stack, invalid = [], []
            for i, c in enumerate(struct):
                if c == '(':
                    stack.append(i)
                elif c == ')':
                    if stack:
                        stack.pop()
                    else:
                        invalid.append(i)
            # Whatever is still open at the end never found a partner.
            invalid.extend(stack)
            return invalid

        # Test valid structures
        assert find_invalid_positions("(())") == []
        assert find_invalid_positions("((..))") == []
        assert find_invalid_positions("....") == []

        # Test invalid structures
        assert find_invalid_positions("(((") == [0, 1, 2]  # All unmatched left
        assert find_invalid_positions(")))") == [0, 1, 2]  # All unmatched right
        assert find_invalid_positions("())(") == [2, 3]  # One unmatched right, one left

    def test_rna_structure_formats(self):
        """Test RNA structure format validation (alphabet and bracket pairing)."""
        valid_structures = [
            "(())",
            "((()))",
            ".((.))",
            "....",
            "",
            "((..))",
        ]

        # Contains a character outside the dot-bracket alphabet.
        invalid_char_structures = [
            "((X))",
        ]

        # Only legal characters, but the brackets do not pair up.
        unbalanced_structures = [
            "(()",  # Unmatched
            "())",  # Unmatched
            ")(",  # Wrong order
        ]

        def is_valid_structure_format(struct: str) -> bool:
            """Check if structure contains only valid characters."""
            return all(c in "()." for c in struct)

        def is_balanced(struct: str) -> bool:
            """Check that every '(' is closed by a later ')' and vice versa."""
            depth = 0
            for c in struct:
                if c == "(":
                    depth += 1
                elif c == ")":
                    depth -= 1
                    if depth < 0:
                        return False
            return depth == 0

        for struct in valid_structures:
            assert is_valid_structure_format(struct), f"Should be valid: {struct}"
            assert is_balanced(struct), f"Should be balanced: {struct}"

        for struct in invalid_char_structures:
            assert not is_valid_structure_format(struct), f"Should be invalid: {struct}"

        # An alphabet check alone cannot reject these, so pairing is checked explicitly.
        for struct in unbalanced_structures:
            assert is_valid_structure_format(struct), f"Should use valid characters: {struct}"
            assert not is_balanced(struct), f"Should be unbalanced: {struct}"

    def test_sequence_replacement_patterns(self):
        """Test U/T replacement patterns from examples."""
        # Pattern from web_rna_design.py
        def rna_to_dna_pattern(sequence):
            return sequence.replace("U", "T")

        def dna_to_rna_pattern(sequence):
            return sequence.replace("T", "U")

        # Test RNA to DNA
        assert rna_to_dna_pattern("AUCG") == "ATCG"
        assert rna_to_dna_pattern("UUUU") == "TTTT"
        assert rna_to_dna_pattern("ACGU") == "ACGT"

        # Test DNA to RNA
        assert dna_to_rna_pattern("ATCG") == "AUCG"
        assert dna_to_rna_pattern("TTTT") == "UUUU"
        assert dna_to_rna_pattern("ACGT") == "ACGU"

    def test_random_base_generation_patterns(self):
        """Test random base generation patterns from RNA design examples."""
        import random

        # A local, seeded generator keeps the test deterministic without
        # touching the global ``random`` state shared with other tests.
        rng = random.Random(0)

        def generate_random_rna_base():
            """Pattern from easy_rna_design_emoo.py."""
            return rng.choice(["A", "U", "G", "C"])

        def generate_random_dna_base():
            """Pattern from easy_rna_design_emoo.py."""
            return rng.choice(["A", "T", "G", "C"])

        # Test multiple generations to ensure valid bases
        for _ in range(10):
            rna_base = generate_random_rna_base()
            assert rna_base in ["A", "U", "G", "C"]

            dna_base = generate_random_dna_base()
            assert dna_base in ["A", "T", "G", "C"]

    def test_sequence_mutation_pattern(self):
        """Test sequence mutation pattern from mlm_mutate function."""
        try:
            import numpy as np
        except ImportError:
            pytest.skip("numpy not available")

        def mutate_sequence_pattern(sequence, mutation_rate=0.1, rng=None):
            """Simplified version of mutation pattern from examples.

            Each position is independently replaced by the mask token ``$`` with
            probability ``mutation_rate``; ``rng`` is a NumPy ``Generator``.
            """
            rng = np.random.default_rng() if rng is None else rng
            sequence_array = np.array(list(sequence), dtype=np.str_)
            masked_indices = rng.random(sequence_array.shape) < mutation_rate
            sequence_array[masked_indices] = "$"  # Mask token
            return "".join(sequence_array.tolist())

        # Test mutation with 0% rate (uniform draws in [0, 1) are never < 0)
        original = "AUCG"
        mutated_zero = mutate_sequence_pattern(original, 0.0)
        assert mutated_zero == original

        # Test mutation with 100% rate (uniform draws in [0, 1) are always < 1)
        mutated_full = mutate_sequence_pattern(original, 1.0)
        assert mutated_full == "$$$$"

        # Moderate rate with a local seeded generator: reproducible, and it does
        # not reseed NumPy's global RNG used by other tests.
        rng = np.random.default_rng(42)
        mutated_partial = mutate_sequence_pattern("AUCGAUCGAUCG", 0.5, rng=rng)
        assert "$" in mutated_partial
        assert len(mutated_partial) == len("AUCGAUCGAUCG")

    @patch('tempfile.mkdtemp')
    def test_temp_directory_pattern(self, mock_mkdtemp):
        """Test temp directory usage pattern from Secondary_Structure_Prediction.py."""
        from pathlib import Path

        mock_mkdtemp.return_value = "/tmp/test_dir"

        # Pattern from Secondary_Structure_Prediction.py
        TEMP_DIR = Path(tempfile.mkdtemp())

        mock_mkdtemp.assert_called_once()
        assert isinstance(TEMP_DIR, Path)
        assert TEMP_DIR == Path("/tmp/test_dir")

    def test_rna_embedding_sequence_validation(self):
        """Test RNA sequence validation for embedding examples."""
        # RNA sequences from RNA_Embedding_Tutorial.ipynb
        rna_sequences = [
            "AUGGCUACG",
            "CGGAUACGGC",
            "UGGCCAAGUC",
            "AUGCUGCUAUGCUA",
        ]

        def validate_rna_sequence(seq):
            """Validate RNA sequence format."""
            return all(base in "AUCG" for base in seq) and len(seq) > 0

        for seq in rna_sequences:
            assert validate_rna_sequence(seq), f"Invalid RNA sequence: {seq}"

    def test_structure_prediction_mock_pattern(self):
        """Test structure prediction pattern without ViennaRNA dependency."""
        def mock_predict_structure_single(sequence):
            """Mock version of predict_structure_single from examples."""
            # Return a mock structure and energy
            return "." * len(sequence), -10.0

        # Test the pattern
        seq = "AUCG"
        struct, energy = mock_predict_structure_single(seq)

        assert len(struct) == len(seq)
        assert isinstance(energy, float)
        assert struct == "...."

    def test_base64_encoding_pattern(self):
        """Test base64 encoding pattern from SVG generation."""
        import base64

        def create_mock_svg_datauri(content="test"):
            """Mock version of SVG data URI creation."""
            svg_content = f'<svg>{content}</svg>'
            b64 = base64.b64encode(svg_content.encode()).decode('utf-8')
            return f"data:image/svg+xml;base64,{b64}"

        uri = create_mock_svg_datauri("test")
        assert uri.startswith("data:image/svg+xml;base64,")

        # Decode and verify
        _, b64_part = uri.split(",", 1)
        decoded = base64.b64decode(b64_part).decode('utf-8')
        assert decoded == "<svg>test</svg>"

    def test_longest_bp_span_function(self):
        """Test longest_bp_span function from easy_rna_design_emoo.py."""
        def longest_bp_span(structure):
            """Return the largest index distance between paired brackets (0 if none)."""
            stack = []
            max_span = 0

            for i, char in enumerate(structure):
                if char == '(':
                    stack.append(i)
                elif char == ')':
                    # Unmatched ')' are ignored rather than treated as errors.
                    if stack:
                        left_index = stack.pop()
                        current_span = i - left_index
                        max_span = max(max_span, current_span)

            return max_span

        # Test cases
        assert longest_bp_span("(())") == 3  # Outer pair spans 3 positions
        assert longest_bp_span("((()))") == 5  # Outer pair spans 5 positions
        assert longest_bp_span("()()") == 1  # Each pair spans 1 position
        assert longest_bp_span("....") == 0  # No pairs
        assert longest_bp_span("") == 0  # Empty structure
        assert longest_bp_span("((.))") == 4  # Outer pair spans 4 positions
|