pytme 0.2.2__cp311-cp311-macosx_14_0_arm64.whl → 0.2.4__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/match_template.py +97 -148
  2. {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/postprocess.py +20 -29
  3. pytme-0.2.4.data/scripts/preprocess.py +148 -0
  4. {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/preprocessor_gui.py +15 -23
  5. {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/METADATA +11 -10
  6. pytme-0.2.4.dist-info/RECORD +119 -0
  7. {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/WHEEL +1 -1
  8. {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/top_level.txt +1 -0
  9. pytme-0.2.2.data/scripts/preprocess.py → scripts/eval.py +1 -1
  10. scripts/match_template.py +97 -148
  11. scripts/postprocess.py +20 -29
  12. scripts/preprocess.py +116 -61
  13. scripts/preprocessor_gui.py +15 -23
  14. tests/__init__.py +0 -0
  15. tests/data/.DS_Store +0 -0
  16. tests/data/Blurring/.DS_Store +0 -0
  17. tests/data/Blurring/blob_width18.npy +0 -0
  18. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  19. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  20. tests/data/Blurring/hamming_width6.npy +0 -0
  21. tests/data/Blurring/kaiserb_width18.npy +0 -0
  22. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  23. tests/data/Blurring/mean_size5.npy +0 -0
  24. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  25. tests/data/Blurring/rank_rank3.npy +0 -0
  26. tests/data/Maps/.DS_Store +0 -0
  27. tests/data/Maps/emd_8621.mrc.gz +0 -0
  28. tests/data/README.md +2 -0
  29. tests/data/Raw/.DS_Store +0 -0
  30. tests/data/Raw/em_map.map +0 -0
  31. tests/data/Structures/.DS_Store +0 -0
  32. tests/data/Structures/1pdj.cif +3339 -0
  33. tests/data/Structures/1pdj.pdb +1429 -0
  34. tests/data/Structures/5khe.cif +3685 -0
  35. tests/data/Structures/5khe.ent +2210 -0
  36. tests/data/Structures/5khe.pdb +2210 -0
  37. tests/data/Structures/5uz4.cif +70548 -0
  38. tests/preprocessing/__init__.py +0 -0
  39. tests/preprocessing/test_compose.py +76 -0
  40. tests/preprocessing/test_frequency_filters.py +178 -0
  41. tests/preprocessing/test_preprocessor.py +136 -0
  42. tests/preprocessing/test_utils.py +79 -0
  43. tests/test_analyzer.py +310 -0
  44. tests/test_backends.py +375 -0
  45. tests/test_density.py +508 -0
  46. tests/test_extensions.py +130 -0
  47. tests/test_matching_cli.py +283 -0
  48. tests/test_matching_data.py +162 -0
  49. tests/test_matching_exhaustive.py +162 -0
  50. tests/test_matching_memory.py +30 -0
  51. tests/test_matching_optimization.py +276 -0
  52. tests/test_matching_utils.py +326 -0
  53. tests/test_orientations.py +173 -0
  54. tests/test_packaging.py +95 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_structure.py +243 -0
  57. tme/__init__.py +0 -1
  58. tme/__version__.py +1 -1
  59. tme/analyzer.py +9 -6
  60. tme/backends/__init__.py +1 -1
  61. tme/backends/_jax_utils.py +10 -8
  62. tme/backends/cupy_backend.py +2 -7
  63. tme/backends/jax_backend.py +35 -20
  64. tme/backends/npfftw_backend.py +3 -2
  65. tme/backends/pytorch_backend.py +10 -7
  66. tme/data/scattering_factors.pickle +0 -0
  67. tme/density.py +26 -12
  68. tme/extensions.cpython-311-darwin.so +0 -0
  69. tme/external/bindings.cpp +332 -0
  70. tme/matching_data.py +33 -24
  71. tme/matching_exhaustive.py +39 -20
  72. tme/matching_scores.py +5 -2
  73. tme/matching_utils.py +8 -2
  74. tme/orientations.py +26 -9
  75. tme/preprocessing/_utils.py +14 -14
  76. tme/preprocessing/composable_filter.py +5 -4
  77. tme/preprocessing/compose.py +4 -4
  78. tme/preprocessing/frequency_filters.py +32 -35
  79. tme/preprocessing/tilt_series.py +210 -148
  80. tme/preprocessor.py +24 -246
  81. tme/structure.py +14 -14
  82. pytme-0.2.2.dist-info/RECORD +0 -74
  83. tme/matching_memory.py +0 -383
  84. {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/estimate_ram_usage.py +0 -0
  85. {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/LICENSE +0 -0
  86. {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,173 @@
1
+ from tempfile import mkstemp
2
+
3
+ import pytest
4
+ import numpy as np
5
+
6
+ from tme import Orientations
7
+
8
+
9
class TestDensity:
    """Tests for :py:class:`tme.Orientations`.

    NOTE(review): despite its name, this class exercises ``Orientations``, not
    ``Density``; the name is kept so existing test selections keep matching.
    """

    def setup_method(self):
        # 100 random orientations whose per-field first dimensions agree.
        self.translations = np.random.rand(100, 3).astype(np.float32)
        self.rotations = np.random.rand(100, 3).astype(np.float32)
        self.scores = np.random.rand(100).astype(np.float32)
        self.details = np.random.rand(100).astype(np.float32)

        self.orientations = Orientations(
            translations=self.translations,
            rotations=self.rotations,
            scores=self.scores,
            details=self.details,
        )

    def teardown_method(self):
        self.translations = None
        self.rotations = None
        self.scores = None
        self.details = None
        self.orientations = None

    def test_initialization(self):
        orientations = Orientations(
            translations=self.translations,
            rotations=self.rotations,
            scores=self.scores,
            details=self.details,
        )

        assert np.array_equal(self.translations, orientations.translations)
        assert np.array_equal(self.rotations, orientations.rotations)
        assert np.array_equal(self.scores, orientations.scores)
        assert np.array_equal(self.details, orientations.details)

    def test_initialization_type(self):
        # Integer inputs are expected to be stored as floating point arrays.
        orientations = Orientations(
            translations=self.translations.astype(int),
            rotations=self.rotations.astype(int),
            scores=self.scores.astype(int),
            details=self.details.astype(int),
        )
        assert np.issubdtype(orientations.translations.dtype, np.floating)
        assert np.issubdtype(orientations.rotations.dtype, np.floating)
        assert np.issubdtype(orientations.scores.dtype, np.floating)
        assert np.issubdtype(orientations.details.dtype, np.floating)

    def test_initialization_error(self):
        n_orientations = self.translations.shape[0]

        # Rotations whose first dimension disagrees with the translations.
        with pytest.raises(ValueError):
            Orientations(
                translations=self.translations,
                rotations=np.random.rand(n_orientations + 1),
                scores=self.scores,
                details=self.details,
            )

        # One-dimensional translations (rotation count is also mismatched).
        with pytest.raises(ValueError):
            Orientations(
                translations=np.random.rand(n_orientations),
                rotations=np.random.rand(n_orientations + 1),
                scores=self.scores,
                details=self.details,
            )

        # One-dimensional rotations of matching length. This case needs its own
        # ``pytest.raises`` block: placed after another raising call inside a
        # shared block, it would never execute.
        # NOTE(review): assumes Orientations rejects non-2D rotations — confirm.
        with pytest.raises(ValueError):
            Orientations(
                translations=self.translations,
                rotations=np.random.rand(n_orientations),
                scores=self.scores,
                details=self.details,
            )

    # "star" (not "relion") so the extension matches the formats exercised by
    # the read/write tests below.
    @pytest.mark.parametrize("file_format", ("text", "star", "tbl"))
    def test_to_file(self, file_format: str, tmp_path):
        # tmp_path is removed by pytest, unlike mkstemp's open fd and file.
        output_path = tmp_path / f"orientations.{file_format}"
        self.orientations.to_file(str(output_path))
        assert output_path.exists()

    @pytest.mark.parametrize("file_format", ("text", "star", "tbl"))
    def test_from_file(self, file_format: str, tmp_path):
        output_file = str(tmp_path / f"orientations.{file_format}")
        self.orientations.to_file(output_file)
        orientations_new = Orientations.from_file(output_file)

        assert np.array_equal(
            self.orientations.translations, orientations_new.translations
        )

    @pytest.mark.parametrize("input_format", ("text", "star", "tbl"))
    @pytest.mark.parametrize("output_format", ("text", "star", "tbl"))
    def test_file_format_io(self, input_format: str, output_format: str, tmp_path):
        output_file = str(tmp_path / f"input.{input_format}")
        output_file2 = str(tmp_path / f"output.{output_format}")

        # Write -> read -> write in a (possibly) different format -> read again.
        self.orientations.to_file(output_file)
        orientations_new = Orientations.from_file(output_file)
        orientations_new.to_file(output_file2)
        orientations_final = Orientations.from_file(output_file2)

        assert np.array_equal(
            self.orientations.translations, orientations_final.translations
        )

    @pytest.mark.parametrize("drop_oob", (True, False))
    @pytest.mark.parametrize("shape", (10, 40, 80))
    @pytest.mark.parametrize("odd", (True, False))
    def test_extraction(self, shape: int, drop_oob: bool, odd: bool):
        if odd:
            # Bump even sizes to the next odd number.
            shape = shape + (1 - shape % 2)

        data = np.random.rand(50, 50, 50)
        translations = np.array([[25, 25, 25], [15, 25, 35], [35, 25, 15], [0, 15, 49]])
        orientations = Orientations(
            translations=translations,
            rotations=np.random.rand(*translations.shape),
            scores=np.random.rand(translations.shape[0]),
            details=np.random.rand(translations.shape[0]),
        )
        extraction_shape = np.repeat(np.array(shape), data.ndim)
        orientations, cand_slices, obs_slices = orientations.get_extraction_slices(
            target_shape=data.shape,
            extraction_shape=extraction_shape,
            drop_out_of_box=drop_oob,
            return_orientations=True,
        )
        assert orientations.translations.shape[0] == len(cand_slices)
        assert len(cand_slices) == len(obs_slices)

        # The slices must not depend on whether orientations are also returned.
        cand_slices2, obs_slices2 = orientations.get_extraction_slices(
            target_shape=data.shape,
            extraction_shape=extraction_shape,
            drop_out_of_box=drop_oob,
            return_orientations=False,
        )
        assert cand_slices == cand_slices2
        assert obs_slices == obs_slices2

        # Check whether extraction slices are pasted in center
        out = np.zeros(extraction_shape, dtype=data.dtype)
        center = np.divide(extraction_shape, 2) + np.mod(extraction_shape, 2)
        center = center.astype(int)
        for index, (cand_slice, obs_slice) in enumerate(zip(cand_slices, obs_slices)):
            out[cand_slice] = data[obs_slice]
            assert np.allclose(
                out[tuple(center)],
                data[tuple(orientations.translations[index].astype(int))],
            )

    @pytest.mark.parametrize(
        "order", (("x", "y", "z"), ("z", "y", "x"), ("y", "x", "z"))
    )
    def test_txt_sort(self, order: tuple, tmp_path):
        output_file = str(tmp_path / "orientations.tsv")
        translations = ((50, 30, 20), (10, 5, 30))

        # Header row names the axis of each column, in the parametrized order.
        with open(output_file, mode="w", encoding="utf-8") as ofile:
            ofile.write("\t".join([str(x) for x in order]) + "\n")
            for translation in translations:
                ofile.write("\t".join([str(x) for x in translation]) + "\n")

        translations = np.array(translations).astype(np.float32)
        orientations = Orientations.from_file(output_file)

        # Column indices sorted by axis name descending (z, y, x): the order in
        # which the parsed translations are expected to come back.
        out_order = zip(order, range(len(order)))
        out_order = tuple(
            x[1] for x in sorted(out_order, key=lambda x: x[0], reverse=True)
        )

        assert np.array_equal(translations[..., out_order], orientations.translations)
@@ -0,0 +1,95 @@
1
+ import pytest
2
+ import numpy as np
3
+
4
+ from tme.package import GaussianKernel, KernelFitting
5
+ from tme.analyzer import PeakCallerFast
6
+
7
class TestKernelFitting:
    """Construction and invocation of :py:class:`KernelFitting`."""

    def setup_method(self):
        self.number_of_peaks = 100
        self.min_distance = 5
        self.data = np.random.rand(100, 100, 100)
        self.rotation_matrix = np.eye(3)

        # Random per-axis box edge drawn from 5..14, plus a random template of
        # that box shape used to seed the kernel parameters.
        self.kernel_box = np.random.choice(np.arange(5, 15), size=self.data.ndim)
        self.template = np.random.rand(*self.kernel_box)

    def _build_analyzer(self, kernel):
        """Return a KernelFitting configured from the fixture state and *kernel*."""
        estimated = kernel.estimate_parameters(self.template)
        return KernelFitting(
            number_of_peaks=self.number_of_peaks,
            min_distance=self.min_distance,
            kernel_box=self.kernel_box,
            kernel_class=kernel,
            kernel_params=estimated,
            peak_caller=PeakCallerFast,
        )

    @pytest.mark.parametrize("kernel", [GaussianKernel])
    def test_initialization(self, kernel):
        self._build_analyzer(kernel)

    @pytest.mark.parametrize("kernel", [GaussianKernel])
    def test__call__(self, kernel):
        analyzer = self._build_analyzer(kernel)
        analyzer(self.data, rotation_matrix=self.rotation_matrix)
46
+
47
+
48
class TestKernel:
    """Fitting and scoring behaviour of the kernel classes."""

    # Kernel classes every parametrized test below is run against.
    _KERNELS = (GaussianKernel,)

    def setup_method(self):
        self.number_of_peaks = 100
        self.min_distance = 5
        self.data = np.random.rand(100, 100, 100)
        self.rotation_matrix = np.eye(3)

    def test_initialization(self):
        _ = GaussianKernel()

    @pytest.mark.parametrize("kernel", _KERNELS)
    def test_fit(self, kernel):
        fitted, _success, _error = kernel.fit(self.data)
        assert isinstance(fitted, tuple)
        assert len(fitted) == 3

    @pytest.mark.parametrize("kernel", _KERNELS)
    def test_fit_height(self, kernel):
        fitted, _success, _error = kernel.fit(self.data)
        _height, mean, cov = fitted

        # Refitting only the height must leave mean and covariance untouched.
        refitted, _success, _error = kernel.fit_height(
            data=self.data, mean=mean, cov=cov
        )
        assert np.allclose(refitted[1], mean)
        assert np.allclose(refitted[2], cov)

    @pytest.mark.parametrize("kernel", _KERNELS)
    def test_score(self, kernel):
        first, *_ = kernel.fit(self.data)
        second, *_ = kernel.fit(self.data)

        assert kernel.score(first, second) == 1
tests/test_parser.py ADDED
@@ -0,0 +1,33 @@
1
+ import pytest
2
+ from importlib_resources import files
3
+
4
+ from tme.parser import Parser, PDBParser
5
+
6
+
7
class TestParser:
    """Mapping-style access on :py:class:`tme.parser.PDBParser`."""

    def setup_method(self):
        self.pdb_file = str(files("tests.data").joinpath("Structures/5khe.pdb"))

    def teardown_method(self):
        self.pdb_file = None

    def _parse(self):
        """Parse the fixture PDB file."""
        return PDBParser(self.pdb_file)

    def test_initialize_parser_error(self):
        # Instantiating the base Parser directly is expected to raise TypeError
        # (presumably because it is abstract).
        with pytest.raises(TypeError):
            _ = Parser(self.pdb_file)

    def test_parser_get(self):
        parsed = self._parse()
        assert parsed.get("NOTPRESENT", None) is None
        assert parsed.get("record_type", None) is not None

    def test_parser_keys(self):
        parsed = self._parse()
        assert parsed.keys() == parsed._data.keys()

    def test_parser_values(self):
        parsed = self._parse()
        expected = str(parsed._data.values())
        assert str(parsed.values()) == expected

    def test_parser_items(self):
        parsed = self._parse()
        assert parsed.items() == parsed._data.items()
@@ -0,0 +1,243 @@
1
+ from os import remove
2
+ from tempfile import mkstemp
3
+ from importlib_resources import files
4
+
5
+ import pytest
6
+ import numpy as np
7
+
8
+ from tme import Structure
9
+ from tme.matching_utils import euler_to_rotationmatrix, minimum_enclosing_box
10
+
11
+
12
# Structure attribute names that the tests below compare one by one.
STRUCTURE_ATTRIBUTES = [
    "record_type",
    "atom_serial_number",
    "atom_name",
    "atom_coordinate",
    "alternate_location_indicator",
    "residue_name",
    "chain_identifier",
    "residue_sequence_number",
    "code_for_residue_insertion",
    "occupancy",
    "temperature_factor",
    "segment_identifier",
    "element_symbol",
    "charge",
    "metadata",
]
29
+
30
+
31
class TestStructure:
    """Tests for :py:class:`tme.Structure`, using the 5KHE mmCIF fixture."""

    def setup_method(self):
        self.structure = Structure.from_file(
            str(files("tests.data").joinpath("Structures/5khe.cif"))
        )

    def teardown_method(self):
        del self.structure

    def compare_structures(self, structure1, structure2, exclude_attributes=()):
        """Assert that two structures agree on every attribute in STRUCTURE_ATTRIBUTES.

        Parameters
        ----------
        structure1, structure2 : Structure
            Structures to compare; array attributes are compared elementwise.
        exclude_attributes : collection of str, optional
            Attribute names to skip. Defaults to an empty tuple instead of a
            shared mutable list.
        """
        for attribute in STRUCTURE_ATTRIBUTES:
            if attribute in exclude_attributes:
                continue
            value = getattr(structure1, attribute)
            value_comparison = getattr(structure2, attribute)
            if isinstance(value, np.ndarray):
                assert np.all(value_comparison == value)
            else:
                assert value == value_comparison

    def test_initialization(self):
        structure = Structure(
            record_type=self.structure.record_type,
            atom_serial_number=self.structure.atom_serial_number,
            atom_name=self.structure.atom_name,
            atom_coordinate=self.structure.atom_coordinate,
            alternate_location_indicator=self.structure.alternate_location_indicator,
            residue_name=self.structure.residue_name,
            chain_identifier=self.structure.chain_identifier,
            residue_sequence_number=self.structure.residue_sequence_number,
            code_for_residue_insertion=self.structure.code_for_residue_insertion,
            occupancy=self.structure.occupancy,
            temperature_factor=self.structure.temperature_factor,
            segment_identifier=self.structure.segment_identifier,
            element_symbol=self.structure.element_symbol,
            charge=self.structure.charge,
            metadata=self.structure.metadata,
        )
        self.compare_structures(self.structure, structure)

    @pytest.mark.parametrize(
        "modified_attribute",
        [
            "record_type",
            "atom_serial_number",
            "atom_name",
            "atom_coordinate",
            "alternate_location_indicator",
            "residue_name",
            "chain_identifier",
            "residue_sequence_number",
            "code_for_residue_insertion",
            "occupancy",
            "temperature_factor",
            "segment_identifier",
            "element_symbol",
        ],
    )
    def test_initialization_errors(self, modified_attribute):
        kwargs = {
            attribute: getattr(self.structure, attribute)
            for attribute in STRUCTURE_ATTRIBUTES
            if attribute != modified_attribute
        }
        # Truncating one per-atom attribute to a single entry should make its
        # length inconsistent with the others.
        kwargs[modified_attribute] = getattr(self.structure, modified_attribute)[:1]

        with pytest.raises(ValueError):
            Structure(**kwargs)

    def test__getitem__(self):
        ret_single_index = self.structure[1]
        ret = self.structure[[1]]
        self.compare_structures(ret_single_index, ret)

        ret = self.structure[self.structure.record_type == "ATOM"]
        assert np.all(ret.record_type == "ATOM")

        ret = self.structure[self.structure.element_symbol == "C"]
        assert np.all(ret.element_symbol == "C")

    def test__repr__(self):
        unique_chains = "-".join(
            [
                ",".join([str(x) for x in entity])
                for entity in self.structure.metadata["unique_chains"]
            ]
        )

        min_atom = np.min(self.structure.atom_serial_number)
        max_atom = np.max(self.structure.atom_serial_number)
        n_atom = self.structure.atom_serial_number.size

        min_residue = np.min(self.structure.residue_sequence_number)
        max_residue = np.max(self.structure.residue_sequence_number)
        n_residue = np.unique(self.structure.residue_sequence_number).size

        repr_str = (
            f"Structure object at {id(self.structure)}\n"
            f"Unique Chains: {unique_chains}, "
            f"Atom Range: {min_atom}-{max_atom} [N = {n_atom}], "
            f"Residue Range: {min_residue}-{max_residue} [N = {n_residue}]"
        )
        assert repr_str == self.structure.__repr__()

    @pytest.mark.parametrize(
        "path",
        [
            str(files("tests.data").joinpath("Structures/5khe.cif")),
            str(files("tests.data").joinpath("Structures/5khe.pdb")),
        ],
    )
    def test_fromfile(self, path):
        _ = Structure.from_file(path)

    def test_fromfile_error(self):
        with pytest.raises(NotImplementedError):
            _ = Structure.from_file("madeup.extension")

    @pytest.mark.parametrize("file_format", ["cif", "pdb"])
    def test_to_file(self, file_format, tmp_path):
        # tmp_path is cleaned up by pytest; mkstemp left an open fd plus both
        # the stub file and the suffixed output file behind.
        path = str(tmp_path / f"structure.{file_format}")
        self.structure.to_file(path)
        read = Structure.from_file(path)
        comparison = self.structure.copy()

        self.compare_structures(comparison, read, exclude_attributes=["metadata"])

    def test_to_file_error(self, tmp_path):
        path = str(tmp_path / "structure.RAISERROR")
        with pytest.raises(NotImplementedError):
            self.structure.to_file(path)

    def test_subset_by_chain(self):
        chain = "A"
        ret = self.structure.subset_by_chain(chain=chain)
        assert np.all(ret.chain_identifier == chain)

    def test_subset_by_chain_range(self):
        chain, start, stop = "A", 0, 20
        ret = self.structure.subset_by_range(chain=chain, start=start, stop=stop)
        assert np.all(ret.chain_identifier == chain)
        assert np.all(
            np.logical_and(
                ret.residue_sequence_number >= start,
                ret.residue_sequence_number <= stop,
            )
        )

    def test_center_of_mass(self):
        center_of_mass = self.structure.center_of_mass()
        assert center_of_mass.shape[0] == self.structure.atom_coordinate.shape[1]
        assert np.allclose(center_of_mass, [-0.89391639, 29.94908928, -2.64736741])

    def test_centered(self):
        ret, _translation = self.structure.centered()
        box = minimum_enclosing_box(coordinates=self.structure.atom_coordinate.T)
        assert np.allclose(ret.center_of_mass(), np.divide(box, 2), atol=1)

    def test__get_atom_weights_error(self):
        with pytest.raises(NotImplementedError):
            self.structure._get_atom_weights(
                self.structure.atom_name, weight_type="RAISEERROR"
            )

    def test_compare_structures(self):
        rmsd = Structure.compare_structures(self.structure, self.structure)
        assert rmsd == 0

        rmsd = Structure.compare_structures(
            self.structure, self.structure, weighted=True
        )
        assert rmsd == 0

        # A pure translation shifts every atom by the same vector, so the RMSD
        # equals that vector's length.
        translation = (3, 0, 0)
        structure_transform = self.structure.rigid_transform(
            translation=translation,
            rotation_matrix=np.eye(self.structure.atom_coordinate.shape[1]),
        )
        rmsd = Structure.compare_structures(self.structure, structure_transform)
        assert np.allclose(rmsd, np.linalg.norm(translation))

    def test_compare_structures_error(self):
        # Structures with different atom counts cannot be compared.
        ret = self.structure[[1, 2, 3, 4, 5]]
        with pytest.raises(ValueError):
            Structure.compare_structures(self.structure, ret)

    def test_align_structures(self):
        rotation_matrix = euler_to_rotationmatrix((20, -10, 45))
        translation = (10, 0, -15)

        structure_transform = self.structure.rigid_transform(
            rotation_matrix=rotation_matrix, translation=translation
        )
        aligned, final_rmsd = Structure.align_structures(
            self.structure, structure_transform
        )
        assert final_rmsd <= 0.1

        aligned, final_rmsd = Structure.align_structures(
            self.structure, structure_transform, sampling_rate=1
        )
        assert final_rmsd <= 1
tme/__init__.py CHANGED
@@ -1,4 +1,3 @@
1
- from . import extensions
2
1
  from .__version__ import __version__
3
2
  from .density import Density
4
3
  from .preprocessor import Preprocessor
tme/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.2.2"
1
+ __version__ = "0.2.4"
tme/analyzer.py CHANGED
@@ -459,14 +459,14 @@ class PeakCaller(ABC):
459
459
 
460
460
  final_order = top_scores[
461
461
  filter_points_indices(
462
- coordinates=peak_positions[top_scores],
462
+ coordinates=peak_positions[top_scores, :],
463
463
  min_distance=self.min_distance,
464
464
  batch_dims=self.batch_dims,
465
465
  )
466
466
  ]
467
467
 
468
- self.peak_list[0] = peak_positions[final_order,]
469
- self.peak_list[1] = rotations[final_order,]
468
+ self.peak_list[0] = peak_positions[final_order, :]
469
+ self.peak_list[1] = rotations[final_order, :]
470
470
  self.peak_list[2] = peak_scores[final_order]
471
471
  self.peak_list[3] = peak_details[final_order]
472
472
 
@@ -475,6 +475,7 @@ class PeakCaller(ABC):
475
475
  fast_shape: Tuple[int],
476
476
  targetshape: Tuple[int],
477
477
  templateshape: Tuple[int],
478
+ convolution_shape: Tuple[int] = None,
478
479
  fourier_shift: Tuple[int] = None,
479
480
  convolution_mode: str = None,
480
481
  shared_memory_handler=None,
@@ -488,6 +489,7 @@ class PeakCaller(ABC):
488
489
  return self
489
490
 
490
491
  # Wrap peaks around score space
492
+ convolution_shape = be.to_backend_array(convolution_shape)
491
493
  fast_shape = be.to_backend_array(fast_shape)
492
494
  if fourier_shift is not None:
493
495
  fourier_shift = be.to_backend_array(fourier_shift)
@@ -501,10 +503,9 @@ class PeakCaller(ABC):
501
503
  )
502
504
 
503
505
  # Remove padding to fast Fourier (and potential full convolution) shape
506
+ output_shape = convolution_shape
504
507
  targetshape = be.to_backend_array(targetshape)
505
508
  templateshape = be.to_backend_array(templateshape)
506
- fast_shape = be.minimum(be.add(targetshape, templateshape) - 1, fast_shape)
507
- output_shape = fast_shape
508
509
  if convolution_mode == "same":
509
510
  output_shape = targetshape
510
511
  elif convolution_mode == "valid":
@@ -515,7 +516,7 @@ class PeakCaller(ABC):
515
516
 
516
517
  output_shape = be.to_backend_array(output_shape)
517
518
  starts = be.astype(
518
- be.divide(be.subtract(fast_shape, output_shape), 2),
519
+ be.divide(be.subtract(convolution_shape, output_shape), 2),
519
520
  be._int_dtype,
520
521
  )
521
522
  stops = be.add(starts, output_shape)
@@ -1019,6 +1020,7 @@ class MaxScoreOverRotations:
1019
1020
  self,
1020
1021
  targetshape: Tuple[int],
1021
1022
  templateshape: Tuple[int],
1023
+ convolution_shape: Tuple[int],
1022
1024
  fourier_shift: Tuple[int] = None,
1023
1025
  convolution_mode: str = None,
1024
1026
  shared_memory_handler=None,
@@ -1039,6 +1041,7 @@ class MaxScoreOverRotations:
1039
1041
  "s1": targetshape,
1040
1042
  "s2": templateshape,
1041
1043
  "convolution_mode": convolution_mode,
1044
+ "convolution_shape": convolution_shape,
1042
1045
  }
1043
1046
  if convolution_mode is not None:
1044
1047
  scores = apply_convolution_mode(scores, **convargs)
tme/backends/__init__.py CHANGED
@@ -147,7 +147,7 @@ class BackendManager:
147
147
  _dependencies = {
148
148
  "numpyfftw": "numpy",
149
149
  "cupy": "cupy",
150
- "pytorch": "pytorch",
150
+ "pytorch": "torch",
151
151
  "mlx": "mlx",
152
152
  "jax": "jax",
153
153
  }
@@ -19,9 +19,9 @@ def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
19
19
  """
20
20
  Computes :py:meth:`tme.matching_exhaustive.cc_setup`.
21
21
  """
22
- template_ft = jnp.fft.rfftn(template)
22
+ template_ft = jnp.fft.rfftn(template, s=template.shape)
23
23
  template_ft = template_ft.at[:].multiply(ft_target)
24
- correlation = jnp.fft.irfftn(template_ft)
24
+ correlation = jnp.fft.irfftn(template_ft, s=template.shape)
25
25
  return correlation
26
26
 
27
27
 
@@ -77,14 +77,15 @@ def _reciprocal_target_std(
77
77
  --------
78
78
  :py:meth:`tme.matching_exhaustive.flc_scoring`.
79
79
  """
80
- ft_template_mask = jnp.fft.rfftn(template_mask)
80
+ ft_shape = template_mask.shape
81
+ ft_template_mask = jnp.fft.rfftn(template_mask, s=ft_shape)
81
82
 
82
83
  # E(X^2)- E(X)^2
83
- exp_sq = jnp.fft.irfftn(ft_target2 * ft_template_mask)
84
+ exp_sq = jnp.fft.irfftn(ft_target2 * ft_template_mask, s=ft_shape)
84
85
  exp_sq = exp_sq.at[:].divide(n_observations)
85
86
 
86
87
  ft_template_mask = ft_template_mask.at[:].multiply(ft_target)
87
- sq_exp = jnp.fft.irfftn(ft_template_mask)
88
+ sq_exp = jnp.fft.irfftn(ft_template_mask, s=ft_shape)
88
89
  sq_exp = sq_exp.at[:].divide(n_observations)
89
90
  sq_exp = sq_exp.at[:].power(2)
90
91
 
@@ -99,7 +100,7 @@ def _reciprocal_target_std(
99
100
 
100
101
 
101
102
  def _apply_fourier_filter(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
102
- arr_ft = jnp.fft.rfftn(arr)
103
+ arr_ft = jnp.fft.rfftn(arr, s=arr.shape)
103
104
  arr_ft = arr_ft.at[:].multiply(arr_filter)
104
105
  return arr.at[:].set(jnp.fft.irfftn(arr_ft, s=arr.shape))
105
106
 
@@ -107,6 +108,7 @@ def _apply_fourier_filter(arr: BackendArray, arr_filter: BackendArray) -> Backen
107
108
  def _identity(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
108
109
  return arr
109
110
 
111
+
110
112
  @partial(
111
113
  pmap,
112
114
  in_axes=(0,) + (None,) * 6,
@@ -127,8 +129,8 @@ def scan(
127
129
  if hasattr(target_filter, "shape"):
128
130
  target = _apply_fourier_filter(target, target_filter)
129
131
 
130
- ft_target = jnp.fft.rfftn(target)
131
- ft_target2 = jnp.fft.rfftn(jnp.square(target))
132
+ ft_target = jnp.fft.rfftn(target, s=fast_shape)
133
+ ft_target2 = jnp.fft.rfftn(jnp.square(target), s=fast_shape)
132
134
  inv_denominator, target, scoring_func = None, None, _flc_scoring
133
135
  if not rotate_mask:
134
136
  n_observations = jnp.sum(template_mask)