pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
- pytme-0.2.9.data/scripts/match_template.py +1135 -0
- pytme-0.2.9.data/scripts/postprocess.py +622 -0
- pytme-0.2.9.data/scripts/preprocess.py +209 -0
- pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
- pytme-0.2.9.dist-info/METADATA +95 -0
- pytme-0.2.9.dist-info/RECORD +119 -0
- pytme-0.2.9.dist-info/WHEEL +5 -0
- pytme-0.2.9.dist-info/entry_points.txt +6 -0
- pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
- pytme-0.2.9.dist-info/top_level.txt +3 -0
- scripts/__init__.py +0 -0
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +1135 -0
- scripts/postprocess.py +622 -0
- scripts/preprocess.py +209 -0
- scripts/preprocessor_gui.py +1227 -0
- tests/__init__.py +0 -0
- tests/data/Blurring/blob_width18.npy +0 -0
- tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
- tests/data/Blurring/gaussian_sigma2.npy +0 -0
- tests/data/Blurring/hamming_width6.npy +0 -0
- tests/data/Blurring/kaiserb_width18.npy +0 -0
- tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
- tests/data/Blurring/mean_size5.npy +0 -0
- tests/data/Blurring/ntree_sigma0510.npy +0 -0
- tests/data/Blurring/rank_rank3.npy +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Maps/emd_8621.mrc.gz +0 -0
- tests/data/README.md +2 -0
- tests/data/Raw/em_map.map +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/data/Structures/1pdj.cif +3339 -0
- tests/data/Structures/1pdj.pdb +1429 -0
- tests/data/Structures/5khe.cif +3685 -0
- tests/data/Structures/5khe.ent +2210 -0
- tests/data/Structures/5khe.pdb +2210 -0
- tests/data/Structures/5uz4.cif +70548 -0
- tests/preprocessing/__init__.py +0 -0
- tests/preprocessing/test_compose.py +76 -0
- tests/preprocessing/test_frequency_filters.py +178 -0
- tests/preprocessing/test_preprocessor.py +136 -0
- tests/preprocessing/test_utils.py +79 -0
- tests/test_analyzer.py +216 -0
- tests/test_backends.py +446 -0
- tests/test_density.py +503 -0
- tests/test_extensions.py +130 -0
- tests/test_matching_cli.py +283 -0
- tests/test_matching_data.py +162 -0
- tests/test_matching_exhaustive.py +124 -0
- tests/test_matching_memory.py +30 -0
- tests/test_matching_optimization.py +226 -0
- tests/test_matching_utils.py +189 -0
- tests/test_orientations.py +175 -0
- tests/test_parser.py +33 -0
- tests/test_rotations.py +153 -0
- tests/test_structure.py +247 -0
- tme/__init__.py +6 -0
- tme/__version__.py +1 -0
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +186 -0
- tme/analyzer/aggregation.py +577 -0
- tme/analyzer/peaks.py +953 -0
- tme/backends/__init__.py +171 -0
- tme/backends/_cupy_utils.py +734 -0
- tme/backends/_jax_utils.py +188 -0
- tme/backends/cupy_backend.py +294 -0
- tme/backends/jax_backend.py +314 -0
- tme/backends/matching_backend.py +1270 -0
- tme/backends/mlx_backend.py +241 -0
- tme/backends/npfftw_backend.py +583 -0
- tme/backends/pytorch_backend.py +430 -0
- tme/data/__init__.py +0 -0
- tme/data/c48n309.npy +0 -0
- tme/data/c48n527.npy +0 -0
- tme/data/c48n9.npy +0 -0
- tme/data/c48u1.npy +0 -0
- tme/data/c48u1153.npy +0 -0
- tme/data/c48u1201.npy +0 -0
- tme/data/c48u1641.npy +0 -0
- tme/data/c48u181.npy +0 -0
- tme/data/c48u2219.npy +0 -0
- tme/data/c48u27.npy +0 -0
- tme/data/c48u2947.npy +0 -0
- tme/data/c48u3733.npy +0 -0
- tme/data/c48u4749.npy +0 -0
- tme/data/c48u5879.npy +0 -0
- tme/data/c48u7111.npy +0 -0
- tme/data/c48u815.npy +0 -0
- tme/data/c48u83.npy +0 -0
- tme/data/c48u8649.npy +0 -0
- tme/data/c600v.npy +0 -0
- tme/data/c600vc.npy +0 -0
- tme/data/metadata.yaml +80 -0
- tme/data/quat_to_numpy.py +42 -0
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +2263 -0
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/external/bindings.cpp +332 -0
- tme/filters/__init__.py +6 -0
- tme/filters/_utils.py +311 -0
- tme/filters/bandpass.py +230 -0
- tme/filters/compose.py +81 -0
- tme/filters/ctf.py +393 -0
- tme/filters/reconstruction.py +160 -0
- tme/filters/wedge.py +542 -0
- tme/filters/whitening.py +191 -0
- tme/matching_data.py +863 -0
- tme/matching_exhaustive.py +497 -0
- tme/matching_optimization.py +1311 -0
- tme/matching_scores.py +1183 -0
- tme/matching_utils.py +1188 -0
- tme/memory.py +337 -0
- tme/orientations.py +598 -0
- tme/parser.py +685 -0
- tme/preprocessor.py +1329 -0
- tme/rotations.py +350 -0
- tme/structure.py +1864 -0
- tme/types.py +13 -0
tests/test_rotations.py
ADDED
@@ -0,0 +1,153 @@
|
|
1
|
+
from importlib_resources import files
|
2
|
+
from itertools import combinations, chain, product
|
3
|
+
|
4
|
+
import pytest
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from tme import Density
|
8
|
+
from scipy.spatial.transform import Rotation
|
9
|
+
from scipy.signal import correlate
|
10
|
+
|
11
|
+
from tme.rotations import (
|
12
|
+
euler_from_rotationmatrix,
|
13
|
+
euler_to_rotationmatrix,
|
14
|
+
get_cone_rotations,
|
15
|
+
align_vectors,
|
16
|
+
get_rotation_matrices,
|
17
|
+
)
|
18
|
+
from tme.matching_utils import (
|
19
|
+
elliptical_mask,
|
20
|
+
split_shape,
|
21
|
+
compute_full_convolution_index,
|
22
|
+
)
|
23
|
+
|
24
|
+
# Package-relative root of the bundled test data (the ``tests/data`` directory),
# resolved through importlib resources.
BASEPATH = files("tests.data")
|
25
|
+
|
26
|
+
|
27
|
+
class TestRotations:
    """Tests for the rotation helpers in ``tme.rotations`` and the split
    correlation helpers in ``tme.matching_utils``.

    None of these tests need a density or structure fixture, so the class has
    no ``setup_method``; loading ``em_map.map`` and rasterizing ``5khe.cif`` per
    test only cost I/O and time.
    """

    @pytest.mark.parametrize(
        "initial_vector, target_vector, convention",
        [
            ([1, 0, 0], [0, 1, 0], None),
            ([0, 1, 0], [0, 0, 1], "zyx"),
            ([1, 1, 1], [1, 0, 0], "xyz"),
        ],
    )
    def test_align_vectors(self, initial_vector, target_vector, convention):
        """align_vectors yields a rotation taking initial_vector onto the
        direction of target_vector."""
        result = align_vectors(initial_vector, target_vector, convention)

        assert isinstance(result, np.ndarray)
        if convention is None:
            # Without a convention the result is a 3x3 rotation matrix.
            assert result.shape == (3, 3)
        else:
            # With a convention the result is three Euler angles in degrees.
            assert len(result) == 3
            result = Rotation.from_euler(convention, result, degrees=True).as_matrix()

        assert np.allclose(np.dot(result, result.T), np.eye(3), atol=1e-6)

        # Apply the matrix directly; no need to round-trip it through Rotation.
        rotated = np.dot(result, np.asarray(initial_vector, dtype=float))
        target = np.asarray(target_vector, dtype=float)
        assert np.allclose(
            rotated / np.linalg.norm(rotated),
            target / np.linalg.norm(target),
            atol=1e-6,
        )

    @pytest.mark.parametrize(
        "cone_angle, cone_sampling, axis_angle, axis_sampling, vector, n_symmetry, convention",
        [
            (30, 5, 360, None, (1, 0, 0), 1, None),
            (45, 10, 180, 15, (0, 1, 0), 2, "zyx"),
            (60, 15, 90, 30, (0, 0, 1), 4, "xyz"),
        ],
    )
    def test_get_cone_rotations(
        self,
        cone_angle,
        cone_sampling,
        axis_angle,
        axis_sampling,
        vector,
        n_symmetry,
        convention,
    ):
        """get_cone_rotations returns a stack of 3x3 matrices without a
        convention, and rows of three angles with one."""
        result = get_cone_rotations(
            cone_angle=cone_angle,
            cone_sampling=cone_sampling,
            axis_angle=axis_angle,
            axis_sampling=axis_sampling,
            reference=vector,
            n_symmetry=n_symmetry,
            seq=convention,
        )

        assert isinstance(result, np.ndarray)
        if convention is None:
            assert result.shape[1:] == (3, 3)
        else:
            assert result.shape[1] == 3

    def test_euler_conversion(self):
        """Rotation matrix -> Euler angles -> rotation matrix is lossless."""
        rotation_matrix_initial = np.array(
            [
                [0.35355339, 0.61237244, -0.70710678],
                [-0.8660254, 0.5, -0.0],
                [0.35355339, 0.61237244, 0.70710678],
            ]
        )
        euler_angles = euler_from_rotationmatrix(rotation_matrix_initial)
        rotation_matrix_converted = euler_to_rotationmatrix(euler_angles)
        assert np.allclose(
            rotation_matrix_initial, rotation_matrix_converted, atol=1e-6
        )

    # NOTE(review): this was ``range(1, 3, 5)``, which only ever yields ``1``.
    # It is spelled out explicitly so the covered dimensions are obvious.
    # TODO: confirm whether additional dimensions were intended.
    @pytest.mark.parametrize("dim", [1])
    @pytest.mark.parametrize("angular_sampling", [10, 15, 20])
    def test_get_rotation_matrices(self, dim, angular_sampling):
        """The first sampled rotation matrix is orthogonal."""
        rotation_matrices = get_rotation_matrices(
            angular_sampling=angular_sampling, dim=dim
        )
        assert np.allclose(rotation_matrices[0] @ rotation_matrices[0].T, np.eye(dim))

    def test_split_correlation(self):
        """Summing correlations of every pair of split sub-arrays, placed with
        compute_full_convolution_index, reproduces the full correlation."""
        arr1 = elliptical_mask(shape=(50, 51), center=(20, 30), radius=5)
        arr2 = elliptical_mask(shape=(41, 36), center=(25, 20), radius=5)

        def _axis_splits(ndim):
            # One split spec per subset of axes (including the empty subset);
            # each selected axis is split into two pieces.
            axes = range(ndim)
            subsets = chain.from_iterable(
                combinations(axes, r) for r in range(ndim + 1)
            )
            return [{axis: 2 for axis in subset} for subset in subsets]

        # The reference correlation does not depend on the split, so it is
        # computed once instead of once per split combination.
        full = correlate(arr1, arr2, method="direct", mode="full")

        for outer_split, inner_split in product(
            _axis_splits(arr1.ndim), _axis_splits(arr2.ndim)
        ):
            splits1 = split_shape(
                shape=arr1.shape, splits=outer_split, equal_shape=False
            )
            splits2 = split_shape(
                shape=arr2.shape, splits=inner_split, equal_shape=False
            )

            temp = np.zeros_like(full)
            for arr1_split, arr2_split in product(splits1, splits2):
                correlation = correlate(
                    arr1[arr1_split], arr2[arr2_split], method="direct", mode="full"
                )
                score_slice = compute_full_convolution_index(
                    arr1.shape, arr2.shape, arr1_split, arr2_split
                )
                temp[score_slice] += correlation

            assert np.allclose(temp, full)
|
tests/test_structure.py
ADDED
@@ -0,0 +1,247 @@
|
|
1
|
+
from contextlib import suppress
from os import close, remove
from tempfile import mkstemp

import numpy as np
import pytest
from importlib_resources import files

from tme import Structure
from tme.matching_utils import minimum_enclosing_box
from tme.rotations import euler_to_rotationmatrix
|
11
|
+
|
12
|
+
|
13
|
+
# Attribute / keyword names of ``Structure`` that the tests pass to the
# constructor and compare one by one between two structures.
STRUCTURE_ATTRIBUTES = """
    record_type atom_serial_number atom_name atom_coordinate
    alternate_location_indicator residue_name chain_identifier
    residue_sequence_number code_for_residue_insertion occupancy
    temperature_factor segment_identifier element_symbol charge metadata
""".split()
|
30
|
+
|
31
|
+
|
32
|
+
class TestStructure:
    """Tests for :py:class:`tme.Structure` construction, indexing, file I/O,
    subsetting and rigid-body geometry."""

    def setup_method(self):
        self.structure = Structure.from_file(
            str(files("tests.data").joinpath("Structures/5khe.cif"))
        )
        # mkstemp returns an open OS-level descriptor together with the path;
        # close it right away so every test does not leak a descriptor.
        fd, self.path = mkstemp()
        close(fd)

    def teardown_method(self):
        del self.structure
        remove(self.path)

    @staticmethod
    def _unique_path(suffix):
        """
        Return a fresh temporary path ending in ``suffix`` without creating it.

        A unique base name is reserved with ``mkstemp``. Its descriptor is
        closed and the placeholder file deleted, so only the suffixed path is
        handed to the code under test, and nothing is left behind once the
        caller removes that path.
        """
        fd, base = mkstemp()
        close(fd)
        remove(base)
        return f"{base}{suffix}"

    def compare_structures(self, structure1, structure2, exclude_attributes=None):
        """
        Assert that two structures agree on every name in STRUCTURE_ATTRIBUTES,
        skipping those listed in ``exclude_attributes``.

        Array-valued attributes are compared element-wise, all others with ``==``.
        """
        # None replaces the former mutable ``[]`` default and means "exclude nothing".
        excluded = () if exclude_attributes is None else exclude_attributes
        for attribute in STRUCTURE_ATTRIBUTES:
            if attribute in excluded:
                continue
            value = getattr(structure1, attribute)
            value_comparison = getattr(structure2, attribute)
            if isinstance(value, np.ndarray):
                assert np.all(value_comparison == value)
            else:
                assert value == value_comparison

    def test_initialization(self):
        structure = Structure(
            record_type=self.structure.record_type,
            atom_serial_number=self.structure.atom_serial_number,
            atom_name=self.structure.atom_name,
            atom_coordinate=self.structure.atom_coordinate,
            alternate_location_indicator=self.structure.alternate_location_indicator,
            residue_name=self.structure.residue_name,
            chain_identifier=self.structure.chain_identifier,
            residue_sequence_number=self.structure.residue_sequence_number,
            code_for_residue_insertion=self.structure.code_for_residue_insertion,
            occupancy=self.structure.occupancy,
            temperature_factor=self.structure.temperature_factor,
            segment_identifier=self.structure.segment_identifier,
            element_symbol=self.structure.element_symbol,
            charge=self.structure.charge,
            metadata=self.structure.metadata,
        )
        self.compare_structures(self.structure, structure)

    @pytest.mark.parametrize(
        "modified_attribute",
        [
            ("record_type"),
            ("atom_serial_number"),
            ("atom_name"),
            ("atom_coordinate"),
            ("alternate_location_indicator"),
            ("residue_name"),
            ("chain_identifier"),
            ("residue_sequence_number"),
            ("code_for_residue_insertion"),
            ("occupancy"),
            ("temperature_factor"),
            ("segment_identifier"),
            ("element_symbol"),
        ],
    )
    def test_initialization_errors(self, modified_attribute):
        # Truncating a single per-atom attribute to length 1 must be rejected.
        kwargs = {
            attribute: getattr(self.structure, attribute)
            for attribute in STRUCTURE_ATTRIBUTES
            if attribute != modified_attribute
        }
        kwargs[modified_attribute] = getattr(self.structure, modified_attribute)[:1]

        with pytest.raises(ValueError):
            Structure(**kwargs)

    def test__getitem__(self):
        ret_single_index = self.structure[1]
        ret = self.structure[[1]]
        self.compare_structures(ret_single_index, ret)

        ret = self.structure[self.structure.record_type == "ATOM"]
        assert np.all(ret.record_type == "ATOM")

        ret = self.structure[self.structure.element_symbol == "C"]
        assert np.all(ret.element_symbol == "C")

    def test__repr__(self):
        unique_chains = "-".join(
            [
                ",".join([str(x) for x in entity])
                for entity in self.structure.metadata["unique_chains"]
            ]
        )

        min_atom = np.min(self.structure.atom_serial_number)
        max_atom = np.max(self.structure.atom_serial_number)
        n_atom = self.structure.atom_serial_number.size

        min_residue = np.min(self.structure.residue_sequence_number)
        max_residue = np.max(self.structure.residue_sequence_number)
        n_residue = np.unique(self.structure.residue_sequence_number).size

        repr_str = (
            f"Structure object at {id(self.structure)}\n"
            f"Unique Chains: {unique_chains}, "
            f"Atom Range: {min_atom}-{max_atom} [N = {n_atom}], "
            f"Residue Range: {min_residue}-{max_residue} [N = {n_residue}]"
        )
        assert repr_str == self.structure.__repr__()

    @pytest.mark.parametrize(
        "path",
        [
            str(files("tests.data").joinpath("Structures/5khe.cif")),
            str(files("tests.data").joinpath("Structures/5khe.pdb")),
        ],
    )
    def test_fromfile(self, path):
        _ = Structure.from_file(path)

    def test_fromfile_error(self):
        with pytest.raises(NotImplementedError):
            _ = Structure.from_file("madeup.extension")

    @pytest.mark.parametrize("file_format", ["cif", "pdb", "gro"])
    def test_to_file(self, file_format):
        path = self._unique_path(f".{file_format}")
        try:
            self.structure.to_file(path)
            read = self.structure.from_file(path)
        finally:
            # Clean up the written file even when writing or reading fails.
            with suppress(FileNotFoundError):
                remove(path)

        comparison = self.structure.copy()
        if file_format != "gro":
            self.compare_structures(comparison, read, exclude_attributes=["metadata"])
        else:
            # NOTE(review): only coordinates are checked for .gro; presumably
            # the format does not round-trip the remaining fields. Confirm.
            assert np.allclose(comparison.atom_coordinate, read.atom_coordinate)

    def test_to_file_error(self):
        path = self._unique_path(".RAISERROR")
        try:
            with pytest.raises(NotImplementedError):
                self.structure.to_file(path)
        finally:
            # The writer may or may not have created the file before raising.
            with suppress(FileNotFoundError):
                remove(path)

    def test_subset_by_chain(self):
        chain = "A"
        ret = self.structure.subset_by_chain(chain=chain)
        assert np.all(ret.chain_identifier == chain)

    def test_subset_by_chain_range(self):
        chain, start, stop = "A", 0, 20
        ret = self.structure.subset_by_range(chain=chain, start=start, stop=stop)
        assert np.all(ret.chain_identifier == chain)
        assert np.all(
            np.logical_and(
                ret.residue_sequence_number >= start,
                ret.residue_sequence_number <= stop,
            )
        )

    def test_center_of_mass(self):
        center_of_mass = self.structure.center_of_mass()
        assert center_of_mass.shape[0] == self.structure.atom_coordinate.shape[1]
        assert np.allclose(center_of_mass, [-0.89391639, 29.94908928, -2.64736741])

    def test_centered(self):
        ret, _ = self.structure.centered()
        box = minimum_enclosing_box(coordinates=self.structure.atom_coordinate.T)
        assert np.allclose(ret.center_of_mass(), np.divide(box, 2), atol=1)

    def test__get_atom_weights_error(self):
        with pytest.raises(NotImplementedError):
            self.structure._get_atom_weights(
                self.structure.atom_name, weight_type="RAISEERROR"
            )

    def test_compare_structures(self):
        rmsd = Structure.compare_structures(self.structure, self.structure)
        assert rmsd == 0

        rmsd = Structure.compare_structures(
            self.structure, self.structure, weighted=True
        )
        assert rmsd == 0

        # A pure translation shifts every atom equally, so the RMSD equals
        # the length of the translation vector.
        translation = (3, 0, 0)
        structure_transform = self.structure.rigid_transform(
            translation=translation,
            rotation_matrix=np.eye(self.structure.atom_coordinate.shape[1]),
        )
        rmsd = Structure.compare_structures(self.structure, structure_transform)
        assert np.allclose(rmsd, np.linalg.norm(translation))

    def test_comopare_structures_error(self):
        # Structures with different atom counts cannot be compared.
        ret = self.structure[[1, 2, 3, 4, 5]]
        with pytest.raises(ValueError):
            Structure.compare_structures(self.structure, ret)

    def test_align_structures(self):
        rotation_matrix = euler_to_rotationmatrix((20, -10, 45))
        translation = (10, 0, -15)

        structure_transform = self.structure.rigid_transform(
            rotation_matrix=rotation_matrix, translation=translation
        )
        _, final_rmsd = Structure.align_structures(
            self.structure, structure_transform
        )
        assert final_rmsd <= 0.1

        _, final_rmsd = Structure.align_structures(
            self.structure, structure_transform, sampling_rate=1
        )
        assert final_rmsd <= 1
|
tme/__init__.py
ADDED
tme/__version__.py
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
# Release version string of the tme package.
__version__ = "0.2.9"
|
tme/analyzer/__init__.py
ADDED
tme/analyzer/_utils.py
ADDED
@@ -0,0 +1,186 @@
|
|
1
|
+
""" Analyzer utility functions.
|
2
|
+
|
3
|
+
Copyright (c) 2023-2025 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
"""
|
7
|
+
|
8
|
+
from typing import Tuple
|
9
|
+
|
10
|
+
from ..types import BackendArray
|
11
|
+
from ..backends import backend as be
|
12
|
+
|
13
|
+
# Public coordinate-mapping helpers; _convmode_to_shape is module-private.
__all__ = ["cart_to_score", "score_to_cart"]
|
14
|
+
|
15
|
+
|
16
|
+
def _convmode_to_shape(
    convolution_mode: str,
    targetshape: BackendArray,
    templateshape: BackendArray,
    convolution_shape: BackendArray,
) -> BackendArray:
    """
    Select the score-space shape that corresponds to a convolution mode.

    Parameters
    ----------
    convolution_mode : str
        ``'same'`` selects ``targetshape``. ``'valid'`` selects
        ``targetshape - templateshape + templateshape % 2``. Any other value
        (e.g. ``'full'`` or None) falls back to ``convolution_shape``.
    targetshape : BackendArray
        Shape of the target array.
    templateshape : BackendArray
        Shape of the template array.
    convolution_shape : BackendArray
        Shape of the convolution output.

    Returns
    -------
    BackendArray
        The selected shape, passed through ``be.to_backend_array``.
    """
    if convolution_mode == "same":
        return be.to_backend_array(targetshape)

    if convolution_mode == "valid":
        difference = be.subtract(targetshape, templateshape)
        # Odd template extents keep one extra voxel along that axis.
        odd_correction = be.mod(templateshape, 2)
        return be.to_backend_array(be.add(difference, odd_correction))

    return be.to_backend_array(convolution_shape)
|
54
|
+
|
55
|
+
|
56
|
+
def cart_to_score(
    positions: BackendArray,
    fast_shape: Tuple[int],
    targetshape: Tuple[int],
    templateshape: Tuple[int],
    convolution_shape: Tuple[int] = None,
    fourier_shift: Tuple[int] = None,
    convolution_mode: str = None,
    **kwargs,
) -> Tuple[BackendArray]:
    """
    Map peak positions from cartesian coordinates into padded score space.

    Parameters
    ----------
    positions : BackendArray
        Positions in cartesian coordinates, one position per row.
    fast_shape : tuple of int
        Shape of the score space padded to an efficient Fourier shape.
    targetshape : tuple of int
        Shape of the target array.
    templateshape : tuple of int
        Shape of the template array.
    convolution_shape : tuple of int, optional
        Non-padded convolution shape of template and target.
    fourier_shift : tuple of int, optional
        Translation offset of the coordinates.
    convolution_mode : str, optional
        Mode of convolution ('same', 'valid', or 'full').
    **kwargs
        Ignored; accepted so callers can pass a larger parameter dict.

    Returns
    -------
    Tuple of BackendArray
        The adjusted positions, and a boolean array marking which input
        positions lie inside the bounds implied by the convolution mode.
    """
    positions = be.to_backend_array(positions)
    fast_shape, targetshape, templateshape, convolution_shape = (
        be.to_backend_array(shape)
        for shape in (fast_shape, targetshape, templateshape, convolution_shape)
    )

    # Extent of the cropped score space that the cartesian positions refer to.
    output_shape = _convmode_to_shape(
        convolution_mode=convolution_mode,
        targetshape=targetshape,
        templateshape=templateshape,
        convolution_shape=convolution_shape,
    )

    # A row is valid only if every coordinate lies in [0, output_shape).
    in_bounds = be.multiply(positions >= 0, positions < output_shape)
    valid_positions = be.sum(in_bounds, axis=1) == positions.shape[1]

    # Offset of the cropped region inside the non-padded convolution.
    offset = be.astype(
        be.divide(be.subtract(convolution_shape, output_shape), 2),
        be._int_dtype,
    )
    positions = be.add(positions, offset)

    if fourier_shift is None:
        return positions, valid_positions

    # Undo the Fourier shift and wrap back into the padded score space.
    unshifted = be.subtract(positions, be.to_backend_array(fourier_shift))
    return be.mod(unshifted, fast_shape), valid_positions
|
120
|
+
|
121
|
+
|
122
|
+
def score_to_cart(
    positions,
    fast_shape: Tuple[int] = None,
    targetshape: Tuple[int] = None,
    templateshape: Tuple[int] = None,
    convolution_shape: Tuple[int] = None,
    fourier_shift: Tuple[int] = None,
    convolution_mode: str = None,
    **kwargs,
) -> Tuple[BackendArray]:
    """
    Map peak positions from padded score space to cartesian coordinates.

    This is the counterpart of ``cart_to_score``: the Fourier shift is
    re-applied and the crop offset removed.

    Parameters
    ----------
    positions : BackendArray
        Positions in the padded Fourier-space system, one position per row.
    fast_shape : tuple of int
        Shape of the score space padded to an efficient Fourier shape.
    targetshape : tuple of int
        Shape of the target array.
    templateshape : tuple of int
        Shape of the template array.
    convolution_shape : tuple of int, optional
        Non-padded convolution shape of template and target.
    fourier_shift : tuple of int, optional
        Translation offset of the coordinates.
    convolution_mode : str, optional
        Mode of convolution ('same', 'valid', or 'full').
    **kwargs
        Ignored; accepted so callers can pass a larger parameter dict.

    Returns
    -------
    Tuple of BackendArray
        The adjusted positions, and a boolean array marking which positions
        fall inside the cropped score space implied by the convolution mode.
    """
    positions = be.to_backend_array(positions)
    convolution_shape, fast_shape, targetshape, templateshape = (
        be.to_backend_array(shape)
        for shape in (convolution_shape, fast_shape, targetshape, templateshape)
    )

    if fourier_shift is not None:
        # Re-apply the Fourier shift, wrapping peaks around the padded space.
        shifted = be.add(positions, be.to_backend_array(fourier_shift))
        positions = be.mod(shifted, fast_shape)

    output_shape = _convmode_to_shape(
        convolution_mode=convolution_mode,
        targetshape=targetshape,
        templateshape=templateshape,
        convolution_shape=convolution_shape,
    )

    # The cropped region spans [offset, offset + output_shape) per axis.
    offset = be.astype(
        be.divide(be.subtract(convolution_shape, output_shape), 2),
        be._int_dtype,
    )
    upper = be.add(offset, output_shape)

    inside = be.multiply(positions >= offset, positions < upper)
    valid_positions = be.sum(inside, axis=1) == positions.shape[1]
    return be.subtract(positions, offset), valid_positions
|