pytme 0.2.2__cp311-cp311-macosx_14_0_arm64.whl → 0.2.4__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/match_template.py +97 -148
- {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/postprocess.py +20 -29
- pytme-0.2.4.data/scripts/preprocess.py +148 -0
- {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/preprocessor_gui.py +15 -23
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/METADATA +11 -10
- pytme-0.2.4.dist-info/RECORD +119 -0
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/WHEEL +1 -1
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/top_level.txt +1 -0
- pytme-0.2.2.data/scripts/preprocess.py → scripts/eval.py +1 -1
- scripts/match_template.py +97 -148
- scripts/postprocess.py +20 -29
- scripts/preprocess.py +116 -61
- scripts/preprocessor_gui.py +15 -23
- tests/__init__.py +0 -0
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Blurring/blob_width18.npy +0 -0
- tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
- tests/data/Blurring/gaussian_sigma2.npy +0 -0
- tests/data/Blurring/hamming_width6.npy +0 -0
- tests/data/Blurring/kaiserb_width18.npy +0 -0
- tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
- tests/data/Blurring/mean_size5.npy +0 -0
- tests/data/Blurring/ntree_sigma0510.npy +0 -0
- tests/data/Blurring/rank_rank3.npy +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Maps/emd_8621.mrc.gz +0 -0
- tests/data/README.md +2 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Raw/em_map.map +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/data/Structures/1pdj.cif +3339 -0
- tests/data/Structures/1pdj.pdb +1429 -0
- tests/data/Structures/5khe.cif +3685 -0
- tests/data/Structures/5khe.ent +2210 -0
- tests/data/Structures/5khe.pdb +2210 -0
- tests/data/Structures/5uz4.cif +70548 -0
- tests/preprocessing/__init__.py +0 -0
- tests/preprocessing/test_compose.py +76 -0
- tests/preprocessing/test_frequency_filters.py +178 -0
- tests/preprocessing/test_preprocessor.py +136 -0
- tests/preprocessing/test_utils.py +79 -0
- tests/test_analyzer.py +310 -0
- tests/test_backends.py +375 -0
- tests/test_density.py +508 -0
- tests/test_extensions.py +130 -0
- tests/test_matching_cli.py +283 -0
- tests/test_matching_data.py +162 -0
- tests/test_matching_exhaustive.py +162 -0
- tests/test_matching_memory.py +30 -0
- tests/test_matching_optimization.py +276 -0
- tests/test_matching_utils.py +326 -0
- tests/test_orientations.py +173 -0
- tests/test_packaging.py +95 -0
- tests/test_parser.py +33 -0
- tests/test_structure.py +243 -0
- tme/__init__.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +9 -6
- tme/backends/__init__.py +1 -1
- tme/backends/_jax_utils.py +10 -8
- tme/backends/cupy_backend.py +2 -7
- tme/backends/jax_backend.py +35 -20
- tme/backends/npfftw_backend.py +3 -2
- tme/backends/pytorch_backend.py +10 -7
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +26 -12
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/external/bindings.cpp +332 -0
- tme/matching_data.py +33 -24
- tme/matching_exhaustive.py +39 -20
- tme/matching_scores.py +5 -2
- tme/matching_utils.py +8 -2
- tme/orientations.py +26 -9
- tme/preprocessing/_utils.py +14 -14
- tme/preprocessing/composable_filter.py +5 -4
- tme/preprocessing/compose.py +4 -4
- tme/preprocessing/frequency_filters.py +32 -35
- tme/preprocessing/tilt_series.py +210 -148
- tme/preprocessor.py +24 -246
- tme/structure.py +14 -14
- pytme-0.2.2.dist-info/RECORD +0 -74
- tme/matching_memory.py +0 -383
- {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/LICENSE +0 -0
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,173 @@
|
|
1
|
+
from tempfile import mkstemp
|
2
|
+
|
3
|
+
import pytest
|
4
|
+
import numpy as np
|
5
|
+
|
6
|
+
from tme import Orientations
|
7
|
+
|
8
|
+
|
9
|
+
class TestDensity:
    """Tests for :py:class:`tme.Orientations`.

    NOTE(review): the class name looks like a leftover from the density tests;
    it is kept unchanged so existing pytest node ids stay stable.
    """

    def setup_method(self):
        self.translations = np.random.rand(100, 3).astype(np.float32)
        self.rotations = np.random.rand(100, 3).astype(np.float32)
        self.scores = np.random.rand(100).astype(np.float32)
        self.details = np.random.rand(100).astype(np.float32)

        self.orientations = Orientations(
            translations=self.translations,
            rotations=self.rotations,
            scores=self.scores,
            details=self.details,
        )

    def teardown_method(self):
        self.translations = None
        self.rotations = None
        self.scores = None
        self.details = None
        self.orientations = None

    @staticmethod
    def _temporary_path(suffix: str = "") -> str:
        """Create an empty temporary file ending in ``suffix`` and return its path.

        ``mkstemp`` also returns an open OS-level file descriptor; it is closed
        immediately so that each test invocation does not leak one.
        """
        from os import close

        descriptor, path = mkstemp(suffix=suffix)
        close(descriptor)
        return path

    @staticmethod
    def _cleanup(*paths: str) -> None:
        """Delete temporary files created by a test, ignoring missing ones."""
        from contextlib import suppress
        from os import remove

        for path in paths:
            with suppress(FileNotFoundError):
                remove(path)

    def test_initialization(self):
        orientations = Orientations(
            translations=self.translations,
            rotations=self.rotations,
            scores=self.scores,
            details=self.details,
        )

        assert np.array_equal(self.translations, orientations.translations)
        assert np.array_equal(self.rotations, orientations.rotations)
        assert np.array_equal(self.scores, orientations.scores)
        assert np.array_equal(self.details, orientations.details)

    def test_initialization_type(self):
        # Integer inputs are expected to be stored as floating point arrays.
        orientations = Orientations(
            translations=self.translations.astype(int),
            rotations=self.rotations.astype(int),
            scores=self.scores.astype(int),
            details=self.details.astype(int),
        )
        assert np.issubdtype(orientations.translations.dtype, np.floating)
        assert np.issubdtype(orientations.rotations.dtype, np.floating)
        assert np.issubdtype(orientations.scores.dtype, np.floating)
        assert np.issubdtype(orientations.details.dtype, np.floating)

    def test_initialization_error(self):
        # Rotations with one more entry than the other parameters.
        with pytest.raises(ValueError):
            _ = Orientations(
                translations=self.translations,
                rotations=np.random.rand(self.translations.shape[0] + 1),
                scores=self.scores,
                details=self.details,
            )

        # One-dimensional translations combined with mismatched rotations.
        with pytest.raises(ValueError):
            _ = Orientations(
                translations=np.random.rand(self.translations.shape[0]),
                rotations=np.random.rand(self.translations.shape[0] + 1),
                scores=self.scores,
                details=self.details,
            )

        # One-dimensional rotations of matching length. This gets its own
        # ``pytest.raises`` block: placed after another raising call inside a
        # shared block it would never execute.
        # NOTE(review): assumes Orientations rejects non-2D rotations — confirm.
        with pytest.raises(ValueError):
            _ = Orientations(
                translations=self.translations,
                rotations=np.random.rand(self.translations.shape[0]),
                scores=self.scores,
                details=self.details,
            )

    @pytest.mark.parametrize("file_format", ("text", "star", "tbl"))
    def test_to_file(self, file_format: str):
        # ".star" matches the RELION suffix used by the other I/O tests here.
        # NOTE(review): the writer appears to be selected from the file suffix,
        # so a ".relion" suffix likely never reached the RELION writer — confirm.
        output_file = self._temporary_path(suffix=f".{file_format}")
        try:
            self.orientations.to_file(output_file)
        finally:
            self._cleanup(output_file)

    @pytest.mark.parametrize("file_format", ("text", "star", "tbl"))
    def test_from_file(self, file_format: str):
        output_file = self._temporary_path(suffix=f".{file_format}")
        try:
            self.orientations.to_file(output_file)
            orientations_new = Orientations.from_file(output_file)
        finally:
            self._cleanup(output_file)

        assert np.array_equal(
            self.orientations.translations, orientations_new.translations
        )

    @pytest.mark.parametrize("input_format", ("text", "star", "tbl"))
    @pytest.mark.parametrize("output_format", ("text", "star", "tbl"))
    def test_file_format_io(self, input_format: str, output_format: str):
        # Write in one format, read back, and re-write in another format.
        output_file = self._temporary_path(suffix=f".{input_format}")
        output_file2 = self._temporary_path(suffix=f".{output_format}")
        try:
            self.orientations.to_file(output_file)
            orientations_new = Orientations.from_file(output_file)
            orientations_new.to_file(output_file2)
        finally:
            self._cleanup(output_file, output_file2)

    @pytest.mark.parametrize("drop_oob", (True, False))
    @pytest.mark.parametrize("shape", (10, 40, 80))
    @pytest.mark.parametrize("odd", (True, False))
    def test_extraction(self, shape: int, drop_oob: bool, odd: bool):
        if odd:
            # Bump even sizes to the next odd size; odd sizes stay unchanged.
            shape = shape + (1 - shape % 2)

        data = np.random.rand(50, 50, 50)
        translations = np.array([[25, 25, 25], [15, 25, 35], [35, 25, 15], [0, 15, 49]])
        orientations = Orientations(
            translations=translations,
            rotations=np.random.rand(*translations.shape),
            scores=np.random.rand(translations.shape[0]),
            details=np.random.rand(translations.shape[0]),
        )
        extraction_shape = np.repeat(np.array(shape), data.ndim)
        orientations, cand_slices, obs_slices = orientations.get_extraction_slices(
            target_shape=data.shape,
            extraction_shape=extraction_shape,
            drop_out_of_box=drop_oob,
            return_orientations=True,
        )
        assert orientations.translations.shape[0] == len(cand_slices)
        assert len(cand_slices) == len(obs_slices)

        cand_slices2, obs_slices2 = orientations.get_extraction_slices(
            target_shape=data.shape,
            extraction_shape=extraction_shape,
            drop_out_of_box=drop_oob,
            return_orientations=False,
        )
        assert cand_slices == cand_slices2
        assert obs_slices == obs_slices2

        # Check whether extraction slices are pasted in center
        out = np.zeros(extraction_shape, dtype=data.dtype)
        center = np.divide(extraction_shape, 2) + np.mod(extraction_shape, 2)
        center = center.astype(int)
        for index, (cand_slice, obs_slice) in enumerate(zip(cand_slices, obs_slices)):
            out[cand_slice] = data[obs_slice]
            assert np.allclose(
                out[tuple(center)],
                data[tuple(orientations.translations[index].astype(int))],
            )

    @pytest.mark.parametrize(
        "order", (("x", "y", "z"), ("z", "y", "x"), ("y", "x", "z"))
    )
    def test_txt_sort(self, order: tuple[str, ...]):
        output_file = self._temporary_path(suffix=".tsv")
        translations = ((50, 30, 20), (10, 5, 30))

        try:
            # Tab-separated file whose header lists the axes in ``order``.
            with open(output_file, mode="w", encoding="utf-8") as ofile:
                ofile.write("\t".join([str(x) for x in order]) + "\n")
                for translation in translations:
                    ofile.write("\t".join([str(x) for x in translation]) + "\n")

            orientations = Orientations.from_file(output_file)
        finally:
            self._cleanup(output_file)

        translations = np.array(translations).astype(np.float32)

        # Column indices ordered by axis name, descending (z, y, x).
        out_order = zip(order, range(len(order)))
        out_order = tuple(
            x[1] for x in sorted(out_order, key=lambda x: x[0], reverse=True)
        )

        assert np.array_equal(translations[..., out_order], orientations.translations)
|
tests/test_packaging.py
ADDED
@@ -0,0 +1,95 @@
|
|
1
|
+
import pytest
|
2
|
+
import numpy as np
|
3
|
+
|
4
|
+
from tme.package import GaussianKernel, KernelFitting
|
5
|
+
from tme.analyzer import PeakCallerFast
|
6
|
+
|
7
|
+
class TestKernelFitting:
    """Smoke tests that ``KernelFitting`` can be constructed and called."""

    def setup_method(self):
        self.number_of_peaks, self.min_distance = 100, 5
        self.data = np.random.rand(100, 100, 100)
        self.rotation_matrix = np.eye(3)

        # One random edge length per axis, drawn from 5..14, and a random
        # template of that shape from which kernel parameters are estimated.
        self.kernel_box = np.random.choice(np.arange(5, 15), size=self.data.ndim)
        self.template = np.random.rand(*self.kernel_box)

    def _make_analyzer(self, kernel):
        """Return a ``KernelFitting`` for ``kernel`` built from the fixture state."""
        return KernelFitting(
            number_of_peaks=self.number_of_peaks,
            min_distance=self.min_distance,
            kernel_box=self.kernel_box,
            kernel_class=kernel,
            kernel_params=kernel.estimate_parameters(self.template),
            peak_caller=PeakCallerFast,
        )

    @pytest.mark.parametrize("kernel", [GaussianKernel])
    def test_initialization(self, kernel):
        self._make_analyzer(kernel)

    @pytest.mark.parametrize("kernel", [GaussianKernel])
    def test__call__(self, kernel):
        analyzer = self._make_analyzer(kernel)
        analyzer(self.data, rotation_matrix=self.rotation_matrix)
|
46
|
+
|
47
|
+
|
48
|
+
class TestKernel:
    """Tests for the fitting and scoring helpers of kernel classes."""

    def setup_method(self):
        self.number_of_peaks = 100
        self.min_distance = 5
        self.data = np.random.rand(100, 100, 100)
        self.rotation_matrix = np.eye(3)

    def test_initialization(self):
        GaussianKernel()

    @pytest.mark.parametrize("kernel", [GaussianKernel])
    def test_fit(self, kernel):
        # fit returns (params, success, final_error); only params is checked.
        params, _, _ = kernel.fit(self.data)
        assert isinstance(params, tuple)
        assert len(params) == 3

    @pytest.mark.parametrize("kernel", [GaussianKernel])
    def test_fit_height(self, kernel):
        (_, mean, cov), _, _ = kernel.fit(self.data)
        refit, _, _ = kernel.fit_height(data=self.data, mean=mean, cov=cov)
        # The mean and covariance handed to fit_height are expected back unchanged.
        assert np.allclose(refit[1], mean)
        assert np.allclose(refit[2], cov)

    @pytest.mark.parametrize("kernel", [GaussianKernel])
    def test_score(self, kernel):
        first, *_ = kernel.fit(self.data)
        second, *_ = kernel.fit(self.data)
        # Two fits of the same data are expected to score as a perfect match.
        assert kernel.score(first, second) == 1
|
tests/test_parser.py
ADDED
@@ -0,0 +1,33 @@
|
|
1
|
+
import pytest
|
2
|
+
from importlib_resources import files
|
3
|
+
|
4
|
+
from tme.parser import Parser, PDBParser
|
5
|
+
|
6
|
+
|
7
|
+
class TestParser:
    """Tests for the mapping-style accessors of ``PDBParser``."""

    def setup_method(self):
        self.pdb_file = str(files("tests.data").joinpath("Structures/5khe.pdb"))

    def teardown_method(self):
        self.pdb_file = None

    def _parse(self):
        """Parse the fixture PDB file into a fresh ``PDBParser``."""
        return PDBParser(self.pdb_file)

    def test_initialize_parser_error(self):
        # Instantiating the base ``Parser`` directly is expected to raise TypeError.
        with pytest.raises(TypeError):
            Parser(self.pdb_file)

    def test_parser_get(self):
        parsed = self._parse()
        assert parsed.get("NOTPRESENT", None) is None
        assert parsed.get("record_type", None) is not None

    def test_parser_keys(self):
        parsed = self._parse()
        assert parsed.keys() == parsed._data.keys()

    def test_parser_values(self):
        parsed = self._parse()
        # Compared via ``str`` — presumably because the values are arrays whose
        # ``==`` is element-wise rather than a single boolean.
        assert str(parsed.values()) == str(parsed._data.values())

    def test_parser_items(self):
        parsed = self._parse()
        assert parsed.items() == parsed._data.items()
|
tests/test_structure.py
ADDED
@@ -0,0 +1,243 @@
|
|
1
|
+
from os import remove
|
2
|
+
from tempfile import mkstemp
|
3
|
+
from importlib_resources import files
|
4
|
+
|
5
|
+
import pytest
|
6
|
+
import numpy as np
|
7
|
+
|
8
|
+
from tme import Structure
|
9
|
+
from tme.matching_utils import euler_to_rotationmatrix, minimum_enclosing_box
|
10
|
+
|
11
|
+
|
12
|
+
# Names of the ``Structure`` constructor arguments / attributes that the tests in
# this module compare one by one (in constructor order).
STRUCTURE_ATTRIBUTES = (
    "record_type atom_serial_number atom_name atom_coordinate "
    "alternate_location_indicator residue_name chain_identifier "
    "residue_sequence_number code_for_residue_insertion occupancy "
    "temperature_factor segment_identifier element_symbol charge metadata"
).split()
|
29
|
+
|
30
|
+
|
31
|
+
class TestStructure:
    """Tests for :py:class:`tme.Structure` using the bundled 5khe model."""

    def setup_method(self):
        self.structure = Structure.from_file(
            str(files("tests.data").joinpath("Structures/5khe.cif"))
        )
        # Unique scratch path; tests that write files append a suffix to it.
        self.path = self._temporary_path()

    def teardown_method(self):
        del self.structure
        remove(self.path)

    @staticmethod
    def _temporary_path() -> str:
        """Create an empty temporary file and return its path.

        ``mkstemp`` also returns an open OS-level file descriptor; it is closed
        immediately so that each test invocation does not leak one.
        """
        from os import close

        descriptor, path = mkstemp()
        close(descriptor)
        return path

    @staticmethod
    def _discard(path: str) -> None:
        """Remove ``path`` if a test created it; ignore it otherwise."""
        from contextlib import suppress

        with suppress(FileNotFoundError):
            remove(path)

    def compare_structures(self, structure1, structure2, exclude_attributes=None):
        """Assert that two structures agree on every name in STRUCTURE_ATTRIBUTES.

        Parameters
        ----------
        structure1, structure2 : Structure
            Structures to compare.
        exclude_attributes : iterable of str, optional
            Attribute names to skip, e.g. ``["metadata"]`` after a file round
            trip. Defaults to ``None`` (compare everything); a ``None`` default
            replaces the former shared mutable ``[]`` default.
        """
        excluded = () if exclude_attributes is None else exclude_attributes
        for attribute in STRUCTURE_ATTRIBUTES:
            if attribute in excluded:
                continue
            value = getattr(structure1, attribute)
            value_comparison = getattr(structure2, attribute)
            if isinstance(value, np.ndarray):
                assert np.all(value_comparison == value)
            else:
                assert value == value_comparison

    def test_initialization(self):
        # Rebuild the structure from its own attributes; it must round-trip exactly.
        structure = Structure(
            **{
                attribute: getattr(self.structure, attribute)
                for attribute in STRUCTURE_ATTRIBUTES
            }
        )
        self.compare_structures(self.structure, structure)

    @pytest.mark.parametrize(
        "modified_attribute",
        [
            "record_type",
            "atom_serial_number",
            "atom_name",
            "atom_coordinate",
            "alternate_location_indicator",
            "residue_name",
            "chain_identifier",
            "residue_sequence_number",
            "code_for_residue_insertion",
            "occupancy",
            "temperature_factor",
            "segment_identifier",
            "element_symbol",
        ],
    )
    def test_initialization_errors(self, modified_attribute):
        # Truncating a single per-atom attribute to length 1 must be rejected.
        kwargs = {
            attribute: getattr(self.structure, attribute)
            for attribute in STRUCTURE_ATTRIBUTES
            if attribute != modified_attribute
        }
        kwargs[modified_attribute] = getattr(self.structure, modified_attribute)[:1]

        with pytest.raises(ValueError):
            Structure(**kwargs)

    def test__getitem__(self):
        ret_single_index = self.structure[1]
        ret = self.structure[[1]]
        self.compare_structures(ret_single_index, ret)

        ret = self.structure[self.structure.record_type == "ATOM"]
        assert np.all(ret.record_type == "ATOM")

        ret = self.structure[self.structure.element_symbol == "C"]
        assert np.all(ret.element_symbol == "C")

    def test__repr__(self):
        unique_chains = "-".join(
            [
                ",".join([str(x) for x in entity])
                for entity in self.structure.metadata["unique_chains"]
            ]
        )

        min_atom = np.min(self.structure.atom_serial_number)
        max_atom = np.max(self.structure.atom_serial_number)
        n_atom = self.structure.atom_serial_number.size

        min_residue = np.min(self.structure.residue_sequence_number)
        max_residue = np.max(self.structure.residue_sequence_number)
        n_residue = np.unique(self.structure.residue_sequence_number).size

        repr_str = (
            f"Structure object at {id(self.structure)}\n"
            f"Unique Chains: {unique_chains}, "
            f"Atom Range: {min_atom}-{max_atom} [N = {n_atom}], "
            f"Residue Range: {min_residue}-{max_residue} [N = {n_residue}]"
        )
        assert repr_str == self.structure.__repr__()

    @pytest.mark.parametrize(
        "path",
        [
            str(files("tests.data").joinpath("Structures/5khe.cif")),
            str(files("tests.data").joinpath("Structures/5khe.pdb")),
        ],
    )
    def test_fromfile(self, path):
        _ = Structure.from_file(path)

    def test_fromfile_error(self):
        with pytest.raises(NotImplementedError):
            _ = Structure.from_file("madeup.extension")

    @pytest.mark.parametrize("file_format", ["cif", "pdb"])
    def test_to_file(self, file_format):
        path = f"{self.path}.{file_format}"
        try:
            self.structure.to_file(path)
            read = Structure.from_file(path)
        finally:
            self._discard(path)
        comparison = self.structure.copy()

        # Metadata is not expected to survive a file round trip unchanged.
        self.compare_structures(comparison, read, exclude_attributes=["metadata"])

    def test_to_file_error(self):
        path = f"{self.path}.RAISERROR"
        try:
            with pytest.raises(NotImplementedError):
                self.structure.to_file(path)
        finally:
            self._discard(path)

    def test_subset_by_chain(self):
        chain = "A"
        ret = self.structure.subset_by_chain(chain=chain)
        assert np.all(ret.chain_identifier == chain)

    def test_subset_by_chain_range(self):
        chain, start, stop = "A", 0, 20
        ret = self.structure.subset_by_range(chain=chain, start=start, stop=stop)
        assert np.all(ret.chain_identifier == chain)
        assert np.all(
            np.logical_and(
                ret.residue_sequence_number >= start,
                ret.residue_sequence_number <= stop,
            )
        )

    def test_center_of_mass(self):
        center_of_mass = self.structure.center_of_mass()
        assert center_of_mass.shape[0] == self.structure.atom_coordinate.shape[1]
        assert np.allclose(center_of_mass, [-0.89391639, 29.94908928, -2.64736741])

    def test_centered(self):
        ret, _ = self.structure.centered()
        box = minimum_enclosing_box(coordinates=self.structure.atom_coordinate.T)
        assert np.allclose(ret.center_of_mass(), np.divide(box, 2), atol=1)

    def test__get_atom_weights_error(self):
        with pytest.raises(NotImplementedError):
            self.structure._get_atom_weights(
                self.structure.atom_name, weight_type="RAISEERROR"
            )

    def test_compare_structures(self):
        rmsd = Structure.compare_structures(self.structure, self.structure)
        assert rmsd == 0

        rmsd = Structure.compare_structures(
            self.structure, self.structure, weighted=True
        )
        assert rmsd == 0

        # A pure translation shifts every atom equally, so the RMSD equals
        # the translation length.
        translation = (3, 0, 0)
        structure_transform = self.structure.rigid_transform(
            translation=translation,
            rotation_matrix=np.eye(self.structure.atom_coordinate.shape[1]),
        )
        rmsd = Structure.compare_structures(self.structure, structure_transform)
        assert np.allclose(rmsd, np.linalg.norm(translation))

    def test_comopare_structures_error(self):
        # Name (typo included) kept unchanged for stable test ids.
        ret = self.structure[[1, 2, 3, 4, 5]]
        with pytest.raises(ValueError):
            Structure.compare_structures(self.structure, ret)

    def test_align_structures(self):
        rotation_matrix = euler_to_rotationmatrix((20, -10, 45))
        translation = (10, 0, -15)

        structure_transform = self.structure.rigid_transform(
            rotation_matrix=rotation_matrix, translation=translation
        )
        _, final_rmsd = Structure.align_structures(
            self.structure, structure_transform
        )
        assert final_rmsd <= 0.1

        _, final_rmsd = Structure.align_structures(
            self.structure, structure_transform, sampling_rate=1
        )
        assert final_rmsd <= 1
|
tme/__init__.py
CHANGED
tme/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.2.
|
1
|
+
__version__ = "0.2.4"
|
tme/analyzer.py
CHANGED
@@ -459,14 +459,14 @@ class PeakCaller(ABC):
|
|
459
459
|
|
460
460
|
final_order = top_scores[
|
461
461
|
filter_points_indices(
|
462
|
-
coordinates=peak_positions[top_scores],
|
462
|
+
coordinates=peak_positions[top_scores, :],
|
463
463
|
min_distance=self.min_distance,
|
464
464
|
batch_dims=self.batch_dims,
|
465
465
|
)
|
466
466
|
]
|
467
467
|
|
468
|
-
self.peak_list[0] = peak_positions[final_order,]
|
469
|
-
self.peak_list[1] = rotations[final_order,]
|
468
|
+
self.peak_list[0] = peak_positions[final_order, :]
|
469
|
+
self.peak_list[1] = rotations[final_order, :]
|
470
470
|
self.peak_list[2] = peak_scores[final_order]
|
471
471
|
self.peak_list[3] = peak_details[final_order]
|
472
472
|
|
@@ -475,6 +475,7 @@ class PeakCaller(ABC):
|
|
475
475
|
fast_shape: Tuple[int],
|
476
476
|
targetshape: Tuple[int],
|
477
477
|
templateshape: Tuple[int],
|
478
|
+
convolution_shape: Tuple[int] = None,
|
478
479
|
fourier_shift: Tuple[int] = None,
|
479
480
|
convolution_mode: str = None,
|
480
481
|
shared_memory_handler=None,
|
@@ -488,6 +489,7 @@ class PeakCaller(ABC):
|
|
488
489
|
return self
|
489
490
|
|
490
491
|
# Wrap peaks around score space
|
492
|
+
convolution_shape = be.to_backend_array(convolution_shape)
|
491
493
|
fast_shape = be.to_backend_array(fast_shape)
|
492
494
|
if fourier_shift is not None:
|
493
495
|
fourier_shift = be.to_backend_array(fourier_shift)
|
@@ -501,10 +503,9 @@ class PeakCaller(ABC):
|
|
501
503
|
)
|
502
504
|
|
503
505
|
# Remove padding to fast Fourier (and potential full convolution) shape
|
506
|
+
output_shape = convolution_shape
|
504
507
|
targetshape = be.to_backend_array(targetshape)
|
505
508
|
templateshape = be.to_backend_array(templateshape)
|
506
|
-
fast_shape = be.minimum(be.add(targetshape, templateshape) - 1, fast_shape)
|
507
|
-
output_shape = fast_shape
|
508
509
|
if convolution_mode == "same":
|
509
510
|
output_shape = targetshape
|
510
511
|
elif convolution_mode == "valid":
|
@@ -515,7 +516,7 @@ class PeakCaller(ABC):
|
|
515
516
|
|
516
517
|
output_shape = be.to_backend_array(output_shape)
|
517
518
|
starts = be.astype(
|
518
|
-
be.divide(be.subtract(
|
519
|
+
be.divide(be.subtract(convolution_shape, output_shape), 2),
|
519
520
|
be._int_dtype,
|
520
521
|
)
|
521
522
|
stops = be.add(starts, output_shape)
|
@@ -1019,6 +1020,7 @@ class MaxScoreOverRotations:
|
|
1019
1020
|
self,
|
1020
1021
|
targetshape: Tuple[int],
|
1021
1022
|
templateshape: Tuple[int],
|
1023
|
+
convolution_shape: Tuple[int],
|
1022
1024
|
fourier_shift: Tuple[int] = None,
|
1023
1025
|
convolution_mode: str = None,
|
1024
1026
|
shared_memory_handler=None,
|
@@ -1039,6 +1041,7 @@ class MaxScoreOverRotations:
|
|
1039
1041
|
"s1": targetshape,
|
1040
1042
|
"s2": templateshape,
|
1041
1043
|
"convolution_mode": convolution_mode,
|
1044
|
+
"convolution_shape": convolution_shape,
|
1042
1045
|
}
|
1043
1046
|
if convolution_mode is not None:
|
1044
1047
|
scores = apply_convolution_mode(scores, **convargs)
|
tme/backends/__init__.py
CHANGED
tme/backends/_jax_utils.py
CHANGED
@@ -19,9 +19,9 @@ def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
|
|
19
19
|
"""
|
20
20
|
Computes :py:meth:`tme.matching_exhaustive.cc_setup`.
|
21
21
|
"""
|
22
|
-
template_ft = jnp.fft.rfftn(template)
|
22
|
+
template_ft = jnp.fft.rfftn(template, s=template.shape)
|
23
23
|
template_ft = template_ft.at[:].multiply(ft_target)
|
24
|
-
correlation = jnp.fft.irfftn(template_ft)
|
24
|
+
correlation = jnp.fft.irfftn(template_ft, s=template.shape)
|
25
25
|
return correlation
|
26
26
|
|
27
27
|
|
@@ -77,14 +77,15 @@ def _reciprocal_target_std(
|
|
77
77
|
--------
|
78
78
|
:py:meth:`tme.matching_exhaustive.flc_scoring`.
|
79
79
|
"""
|
80
|
-
|
80
|
+
ft_shape = template_mask.shape
|
81
|
+
ft_template_mask = jnp.fft.rfftn(template_mask, s=ft_shape)
|
81
82
|
|
82
83
|
# E(X^2)- E(X)^2
|
83
|
-
exp_sq = jnp.fft.irfftn(ft_target2 * ft_template_mask)
|
84
|
+
exp_sq = jnp.fft.irfftn(ft_target2 * ft_template_mask, s=ft_shape)
|
84
85
|
exp_sq = exp_sq.at[:].divide(n_observations)
|
85
86
|
|
86
87
|
ft_template_mask = ft_template_mask.at[:].multiply(ft_target)
|
87
|
-
sq_exp = jnp.fft.irfftn(ft_template_mask)
|
88
|
+
sq_exp = jnp.fft.irfftn(ft_template_mask, s=ft_shape)
|
88
89
|
sq_exp = sq_exp.at[:].divide(n_observations)
|
89
90
|
sq_exp = sq_exp.at[:].power(2)
|
90
91
|
|
@@ -99,7 +100,7 @@ def _reciprocal_target_std(
|
|
99
100
|
|
100
101
|
|
101
102
|
def _apply_fourier_filter(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
|
102
|
-
arr_ft = jnp.fft.rfftn(arr)
|
103
|
+
arr_ft = jnp.fft.rfftn(arr, s=arr.shape)
|
103
104
|
arr_ft = arr_ft.at[:].multiply(arr_filter)
|
104
105
|
return arr.at[:].set(jnp.fft.irfftn(arr_ft, s=arr.shape))
|
105
106
|
|
@@ -107,6 +108,7 @@ def _apply_fourier_filter(arr: BackendArray, arr_filter: BackendArray) -> Backen
|
|
107
108
|
def _identity(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
|
108
109
|
return arr
|
109
110
|
|
111
|
+
|
110
112
|
@partial(
|
111
113
|
pmap,
|
112
114
|
in_axes=(0,) + (None,) * 6,
|
@@ -127,8 +129,8 @@ def scan(
|
|
127
129
|
if hasattr(target_filter, "shape"):
|
128
130
|
target = _apply_fourier_filter(target, target_filter)
|
129
131
|
|
130
|
-
ft_target = jnp.fft.rfftn(target)
|
131
|
-
ft_target2 = jnp.fft.rfftn(jnp.square(target))
|
132
|
+
ft_target = jnp.fft.rfftn(target, s=fast_shape)
|
133
|
+
ft_target2 = jnp.fft.rfftn(jnp.square(target), s=fast_shape)
|
132
134
|
inv_denominator, target, scoring_func = None, None, _flc_scoring
|
133
135
|
if not rotate_mask:
|
134
136
|
n_observations = jnp.sum(template_mask)
|