pytme-0.2.9-cp311-cp311-macosx_15_0_arm64.whl

This diff shows the content of package versions publicly released to one of the supported registries. It is provided for informational purposes only and reflects the differences between package versions as they appear in their respective public registries.
Files changed (119)
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
  2. pytme-0.2.9.data/scripts/match_template.py +1135 -0
  3. pytme-0.2.9.data/scripts/postprocess.py +622 -0
  4. pytme-0.2.9.data/scripts/preprocess.py +209 -0
  5. pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
  6. pytme-0.2.9.dist-info/METADATA +95 -0
  7. pytme-0.2.9.dist-info/RECORD +119 -0
  8. pytme-0.2.9.dist-info/WHEEL +5 -0
  9. pytme-0.2.9.dist-info/entry_points.txt +6 -0
  10. pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
  11. pytme-0.2.9.dist-info/top_level.txt +3 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +97 -0
  14. scripts/match_template.py +1135 -0
  15. scripts/postprocess.py +622 -0
  16. scripts/preprocess.py +209 -0
  17. scripts/preprocessor_gui.py +1227 -0
  18. tests/__init__.py +0 -0
  19. tests/data/Blurring/blob_width18.npy +0 -0
  20. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  21. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  22. tests/data/Blurring/hamming_width6.npy +0 -0
  23. tests/data/Blurring/kaiserb_width18.npy +0 -0
  24. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  25. tests/data/Blurring/mean_size5.npy +0 -0
  26. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  27. tests/data/Blurring/rank_rank3.npy +0 -0
  28. tests/data/Maps/.DS_Store +0 -0
  29. tests/data/Maps/emd_8621.mrc.gz +0 -0
  30. tests/data/README.md +2 -0
  31. tests/data/Raw/em_map.map +0 -0
  32. tests/data/Structures/.DS_Store +0 -0
  33. tests/data/Structures/1pdj.cif +3339 -0
  34. tests/data/Structures/1pdj.pdb +1429 -0
  35. tests/data/Structures/5khe.cif +3685 -0
  36. tests/data/Structures/5khe.ent +2210 -0
  37. tests/data/Structures/5khe.pdb +2210 -0
  38. tests/data/Structures/5uz4.cif +70548 -0
  39. tests/preprocessing/__init__.py +0 -0
  40. tests/preprocessing/test_compose.py +76 -0
  41. tests/preprocessing/test_frequency_filters.py +178 -0
  42. tests/preprocessing/test_preprocessor.py +136 -0
  43. tests/preprocessing/test_utils.py +79 -0
  44. tests/test_analyzer.py +216 -0
  45. tests/test_backends.py +446 -0
  46. tests/test_density.py +503 -0
  47. tests/test_extensions.py +130 -0
  48. tests/test_matching_cli.py +283 -0
  49. tests/test_matching_data.py +162 -0
  50. tests/test_matching_exhaustive.py +124 -0
  51. tests/test_matching_memory.py +30 -0
  52. tests/test_matching_optimization.py +226 -0
  53. tests/test_matching_utils.py +189 -0
  54. tests/test_orientations.py +175 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_rotations.py +153 -0
  57. tests/test_structure.py +247 -0
  58. tme/__init__.py +6 -0
  59. tme/__version__.py +1 -0
  60. tme/analyzer/__init__.py +2 -0
  61. tme/analyzer/_utils.py +186 -0
  62. tme/analyzer/aggregation.py +577 -0
  63. tme/analyzer/peaks.py +953 -0
  64. tme/backends/__init__.py +171 -0
  65. tme/backends/_cupy_utils.py +734 -0
  66. tme/backends/_jax_utils.py +188 -0
  67. tme/backends/cupy_backend.py +294 -0
  68. tme/backends/jax_backend.py +314 -0
  69. tme/backends/matching_backend.py +1270 -0
  70. tme/backends/mlx_backend.py +241 -0
  71. tme/backends/npfftw_backend.py +583 -0
  72. tme/backends/pytorch_backend.py +430 -0
  73. tme/data/__init__.py +0 -0
  74. tme/data/c48n309.npy +0 -0
  75. tme/data/c48n527.npy +0 -0
  76. tme/data/c48n9.npy +0 -0
  77. tme/data/c48u1.npy +0 -0
  78. tme/data/c48u1153.npy +0 -0
  79. tme/data/c48u1201.npy +0 -0
  80. tme/data/c48u1641.npy +0 -0
  81. tme/data/c48u181.npy +0 -0
  82. tme/data/c48u2219.npy +0 -0
  83. tme/data/c48u27.npy +0 -0
  84. tme/data/c48u2947.npy +0 -0
  85. tme/data/c48u3733.npy +0 -0
  86. tme/data/c48u4749.npy +0 -0
  87. tme/data/c48u5879.npy +0 -0
  88. tme/data/c48u7111.npy +0 -0
  89. tme/data/c48u815.npy +0 -0
  90. tme/data/c48u83.npy +0 -0
  91. tme/data/c48u8649.npy +0 -0
  92. tme/data/c600v.npy +0 -0
  93. tme/data/c600vc.npy +0 -0
  94. tme/data/metadata.yaml +80 -0
  95. tme/data/quat_to_numpy.py +42 -0
  96. tme/data/scattering_factors.pickle +0 -0
  97. tme/density.py +2263 -0
  98. tme/extensions.cpython-311-darwin.so +0 -0
  99. tme/external/bindings.cpp +332 -0
  100. tme/filters/__init__.py +6 -0
  101. tme/filters/_utils.py +311 -0
  102. tme/filters/bandpass.py +230 -0
  103. tme/filters/compose.py +81 -0
  104. tme/filters/ctf.py +393 -0
  105. tme/filters/reconstruction.py +160 -0
  106. tme/filters/wedge.py +542 -0
  107. tme/filters/whitening.py +191 -0
  108. tme/matching_data.py +863 -0
  109. tme/matching_exhaustive.py +497 -0
  110. tme/matching_optimization.py +1311 -0
  111. tme/matching_scores.py +1183 -0
  112. tme/matching_utils.py +1188 -0
  113. tme/memory.py +337 -0
  114. tme/orientations.py +598 -0
  115. tme/parser.py +685 -0
  116. tme/preprocessor.py +1329 -0
  117. tme/rotations.py +350 -0
  118. tme/structure.py +1864 -0
  119. tme/types.py +13 -0
tests/test_rotations.py ADDED
@@ -0,0 +1,153 @@
+from importlib_resources import files
+from itertools import combinations, chain, product
+
+import pytest
+import numpy as np
+
+from tme import Density
+from scipy.spatial.transform import Rotation
+from scipy.signal import correlate
+
+from tme.rotations import (
+    euler_from_rotationmatrix,
+    euler_to_rotationmatrix,
+    get_cone_rotations,
+    align_vectors,
+    get_rotation_matrices,
+)
+from tme.matching_utils import (
+    elliptical_mask,
+    split_shape,
+    compute_full_convolution_index,
+)
+
+BASEPATH = files("tests.data")
+
+
+class TestRotations:
+    def setup_method(self):
+        self.density = Density.from_file(str(BASEPATH.joinpath("Raw/em_map.map")))
+        self.structure_density = Density.from_structure(
+            filename_or_structure=str(BASEPATH.joinpath("Structures/5khe.cif")),
+            origin=self.density.origin,
+            shape=self.density.shape,
+            sampling_rate=self.density.sampling_rate,
+        )
+
+    @pytest.mark.parametrize(
+        "initial_vector, target_vector, convention",
+        [
+            ([1, 0, 0], [0, 1, 0], None),
+            ([0, 1, 0], [0, 0, 1], "zyx"),
+            ([1, 1, 1], [1, 0, 0], "xyz"),
+        ],
+    )
+    def test_align_vectors(self, initial_vector, target_vector, convention):
+        result = align_vectors(initial_vector, target_vector, convention)
+
+        assert isinstance(result, np.ndarray)
+        if convention is None:
+            assert result.shape == (3, 3)
+            assert np.allclose(np.dot(result, result.T), np.eye(3), atol=1e-6)
+        else:
+            assert len(result) == 3
+            result = Rotation.from_euler(convention, result, degrees=True).as_matrix()
+            assert np.allclose(np.dot(result, result.T), np.eye(3), atol=1e-6)
+
+        rotated = np.dot(Rotation.from_matrix(result).as_matrix(), initial_vector)
+        assert np.allclose(
+            rotated / np.linalg.norm(rotated),
+            target_vector / np.linalg.norm(target_vector),
+            atol=1e-6,
+        )
+
+    @pytest.mark.parametrize(
+        "cone_angle, cone_sampling, axis_angle, axis_sampling, vector, n_symmetry, convention",
+        [
+            (30, 5, 360, None, (1, 0, 0), 1, None),
+            (45, 10, 180, 15, (0, 1, 0), 2, "zyx"),
+            (60, 15, 90, 30, (0, 0, 1), 4, "xyz"),
+        ],
+    )
+    def test_get_cone_rotations(
+        self,
+        cone_angle,
+        cone_sampling,
+        axis_angle,
+        axis_sampling,
+        vector,
+        n_symmetry,
+        convention,
+    ):
+        result = get_cone_rotations(
+            cone_angle=cone_angle,
+            cone_sampling=cone_sampling,
+            axis_angle=axis_angle,
+            axis_sampling=axis_sampling,
+            reference=vector,
+            n_symmetry=n_symmetry,
+            seq=convention,
+        )
+
+        assert isinstance(result, np.ndarray)
+        if convention is None:
+            assert result.shape[1:] == (3, 3)
+        else:
+            assert result.shape[1] == 3
+
+    def test_euler_conversion(self):
+        rotation_matrix_initial = np.array(
+            [
+                [0.35355339, 0.61237244, -0.70710678],
+                [-0.8660254, 0.5, -0.0],
+                [0.35355339, 0.61237244, 0.70710678],
+            ]
+        )
+        euler_angles = euler_from_rotationmatrix(rotation_matrix_initial)
+        rotation_matrix_converted = euler_to_rotationmatrix(euler_angles)
+        assert np.allclose(
+            rotation_matrix_initial, rotation_matrix_converted, atol=1e-6
+        )
+
+    @pytest.mark.parametrize("dim", range(1, 3, 5))
+    @pytest.mark.parametrize("angular_sampling", [10, 15, 20])
+    def test_get_rotation_matrices(self, dim, angular_sampling):
+        rotation_matrices = get_rotation_matrices(
+            angular_sampling=angular_sampling, dim=dim
+        )
+        assert np.allclose(rotation_matrices[0] @ rotation_matrices[0].T, np.eye(dim))
+
+    def test_split_correlation(self):
+        arr1 = elliptical_mask(shape=(50, 51), center=(20, 30), radius=5)
+
+        arr2 = elliptical_mask(shape=(41, 36), center=(25, 20), radius=5)
+        s = range(arr1.ndim)
+        outer_split = chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
+
+        s = range(arr2.ndim)
+        inner_split = chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
+
+        outer_splits = [dict(zip(i, [2] * len(i))) for i in list(outer_split)]
+        inner_splits = [dict(zip(i, [2] * len(i))) for i in list(inner_split)]
+
+        for outer_split, inner_split in product(outer_splits, inner_splits):
+            splits1 = split_shape(
+                shape=arr1.shape, splits=outer_split, equal_shape=False
+            )
+            splits2 = split_shape(
+                shape=arr2.shape, splits=inner_split, equal_shape=False
+            )
+
+            full = correlate(arr1, arr2, method="direct", mode="full")
+            temp = np.zeros_like(full)
+
+            for arr1_split, arr2_split in product(splits1, splits2):
+                correlation = correlate(
+                    arr1[arr1_split], arr2[arr2_split], method="direct", mode="full"
+                )
+                score_slice = compute_full_convolution_index(
+                    arr1.shape, arr2.shape, arr1_split, arr2_split
+                )
+                temp[score_slice] += correlation
+
+            assert np.allclose(temp, full)
tests/test_structure.py ADDED
@@ -0,0 +1,247 @@
+from os import remove
+from tempfile import mkstemp
+from importlib_resources import files
+
+import pytest
+import numpy as np
+
+from tme import Structure
+from tme.matching_utils import minimum_enclosing_box
+from tme.rotations import euler_to_rotationmatrix
+
+
+STRUCTURE_ATTRIBUTES = [
+    "record_type",
+    "atom_serial_number",
+    "atom_name",
+    "atom_coordinate",
+    "alternate_location_indicator",
+    "residue_name",
+    "chain_identifier",
+    "residue_sequence_number",
+    "code_for_residue_insertion",
+    "occupancy",
+    "temperature_factor",
+    "segment_identifier",
+    "element_symbol",
+    "charge",
+    "metadata",
+]
+
+
+class TestStructure:
+    def setup_method(self):
+        self.structure = Structure.from_file(
+            str(files("tests.data").joinpath("Structures/5khe.cif"))
+        )
+        _, self.path = mkstemp()
+
+    def teardown_method(self):
+        del self.structure
+        remove(self.path)
+
+    def compare_structures(self, structure1, structure2, exclude_attributes=[]):
+        for attribute in STRUCTURE_ATTRIBUTES:
+            if attribute in exclude_attributes:
+                continue
+            value = getattr(structure1, attribute)
+            value_comparison = getattr(structure2, attribute)
+            if isinstance(value, np.ndarray):
+                assert np.all(value_comparison == value)
+            else:
+                assert value == value_comparison
+
+    def test_initialization(self):
+        structure = Structure(
+            record_type=self.structure.record_type,
+            atom_serial_number=self.structure.atom_serial_number,
+            atom_name=self.structure.atom_name,
+            atom_coordinate=self.structure.atom_coordinate,
+            alternate_location_indicator=self.structure.alternate_location_indicator,
+            residue_name=self.structure.residue_name,
+            chain_identifier=self.structure.chain_identifier,
+            residue_sequence_number=self.structure.residue_sequence_number,
+            code_for_residue_insertion=self.structure.code_for_residue_insertion,
+            occupancy=self.structure.occupancy,
+            temperature_factor=self.structure.temperature_factor,
+            segment_identifier=self.structure.segment_identifier,
+            element_symbol=self.structure.element_symbol,
+            charge=self.structure.charge,
+            metadata=self.structure.metadata,
+        )
+
+        for attribute in STRUCTURE_ATTRIBUTES:
+            value = getattr(self.structure, attribute)
+            value_comparison = getattr(structure, attribute)
+            if isinstance(value, np.ndarray):
+                assert np.all(value_comparison == value)
+            else:
+                assert value == value_comparison
+
+    @pytest.mark.parametrize(
+        "modified_attribute",
+        [
+            ("record_type"),
+            ("atom_serial_number"),
+            ("atom_name"),
+            ("atom_coordinate"),
+            ("alternate_location_indicator"),
+            ("residue_name"),
+            ("chain_identifier"),
+            ("residue_sequence_number"),
+            ("code_for_residue_insertion"),
+            ("occupancy"),
+            ("temperature_factor"),
+            ("segment_identifier"),
+            ("element_symbol"),
+        ],
+    )
+    def test_initialization_errors(self, modified_attribute):
+        kwargs = {
+            attribute: getattr(self.structure, attribute)
+            for attribute in STRUCTURE_ATTRIBUTES
+            if attribute != modified_attribute
+        }
+        kwargs[modified_attribute] = getattr(self.structure, modified_attribute)[:1]
+
+        with pytest.raises(ValueError):
+            Structure(**kwargs)
+
+    def test__getitem__(self):
+        ret_single_index = self.structure[1]
+        ret = self.structure[[1]]
+        self.compare_structures(ret_single_index, ret)
+
+        ret = self.structure[self.structure.record_type == "ATOM"]
+        assert np.all(ret.record_type == "ATOM")
+
+        ret = self.structure[self.structure.element_symbol == "C"]
+        assert np.all(ret.element_symbol == "C")
+
+    def test__repr__(self):
+        unique_chains = "-".join(
+            [
+                ",".join([str(x) for x in entity])
+                for entity in self.structure.metadata["unique_chains"]
+            ]
+        )
+
+        min_atom = np.min(self.structure.atom_serial_number)
+        max_atom = np.max(self.structure.atom_serial_number)
+        n_atom = self.structure.atom_serial_number.size
+
+        min_residue = np.min(self.structure.residue_sequence_number)
+        max_residue = np.max(self.structure.residue_sequence_number)
+        n_residue = np.unique(self.structure.residue_sequence_number).size
+
+        repr_str = (
+            f"Structure object at {id(self.structure)}\n"
+            f"Unique Chains: {unique_chains}, "
+            f"Atom Range: {min_atom}-{max_atom} [N = {n_atom}], "
+            f"Residue Range: {min_residue}-{max_residue} [N = {n_residue}]"
+        )
+        assert repr_str == self.structure.__repr__()
+
+    @pytest.mark.parametrize(
+        "path",
+        [
+            str(files("tests.data").joinpath("Structures/5khe.cif")),
+            str(files("tests.data").joinpath("Structures/5khe.pdb")),
+        ],
+    )
+    def test_fromfile(self, path):
+        _ = Structure.from_file(path)
+
+    def test_fromfile_error(self):
+        with pytest.raises(NotImplementedError):
+            _ = Structure.from_file("madeup.extension")
+
+    @pytest.mark.parametrize("file_format", ["cif", "pdb", "gro"])
+    def test_to_file(self, file_format):
+        _, path = mkstemp()
+        path = f"{path}.{file_format}"
+        self.structure.to_file(path)
+        read = self.structure.from_file(path)
+        comparison = self.structure.copy()
+
+        if file_format != "gro":
+            self.compare_structures(comparison, read, exclude_attributes=["metadata"])
+        else:
+            assert np.allclose(comparison.atom_coordinate, read.atom_coordinate)
+
+    def test_to_file_error(self):
+        _, path = mkstemp()
+        path = f"{path}.RAISERROR"
+        with pytest.raises(NotImplementedError):
+            self.structure.to_file(path)
+
+    def test_subset_by_chain(self):
+        chain = "A"
+        ret = self.structure.subset_by_chain(chain=chain)
+        assert np.all(ret.chain_identifier == chain)
+
+    def test_subset_by_chain_range(self):
+        chain, start, stop = "A", 0, 20
+        ret = self.structure.subset_by_range(chain=chain, start=start, stop=stop)
+        assert np.all(ret.chain_identifier == chain)
+        assert np.all(
+            np.logical_and(
+                ret.residue_sequence_number >= start,
+                ret.residue_sequence_number <= stop,
+            )
+        )
+
+    def test_center_of_mass(self):
+        center_of_mass = self.structure.center_of_mass()
+        assert center_of_mass.shape[0] == self.structure.atom_coordinate.shape[1]
+        assert np.allclose(center_of_mass, [-0.89391639, 29.94908928, -2.64736741])
+
+    def test_centered(self):
+        ret, translation = self.structure.centered()
+        box = minimum_enclosing_box(coordinates=self.structure.atom_coordinate.T)
+        assert np.allclose(ret.center_of_mass(), np.divide(box, 2), atol=1)
+
+    def test__get_atom_weights_error(self):
+        with pytest.raises(NotImplementedError):
+            self.structure._get_atom_weights(
+                self.structure.atom_name, weight_type="RAISEERROR"
+            )
+
+    def test_compare_structures(self):
+        rmsd = Structure.compare_structures(self.structure, self.structure)
+        assert rmsd == 0
+
+        rmsd = Structure.compare_structures(
+            self.structure, self.structure, weighted=True
+        )
+        assert rmsd == 0
+
+        translation = (3, 0, 0)
+        structure_transform = self.structure.rigid_transform(
+            translation=translation,
+            rotation_matrix=np.eye(self.structure.atom_coordinate.shape[1]),
+        )
+        rmsd = Structure.compare_structures(self.structure, structure_transform)
+        assert np.allclose(rmsd, np.linalg.norm(translation))
+
+    def test_comopare_structures_error(self):
+        ret = self.structure[[1, 2, 3, 4, 5]]
+        with pytest.raises(ValueError):
+            Structure.compare_structures(self.structure, ret)
+
+    def test_align_structures(self):
+        rotation_matrix = euler_to_rotationmatrix((20, -10, 45))
+        translation = (10, 0, -15)
+
+        structure_transform = self.structure.rigid_transform(
+            rotation_matrix=rotation_matrix, translation=translation
+        )
+        aligned, final_rmsd = Structure.align_structures(
+            self.structure, structure_transform
+        )
+        assert final_rmsd <= 0.1
+
+        aligned, final_rmsd = Structure.align_structures(
+            self.structure, structure_transform, sampling_rate=1
+        )
+        assert final_rmsd <= 1
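
As a standalone illustration of the RMSD check exercised in TestStructure.test_compare_structures above, the following sketch is not part of the wheel and uses a placeholder file path; it only restates the relationship the test asserts (RMSD of a purely translated copy equals the norm of the translation):

```python
# Hedged sketch, not shipped with the package: rigidly translate a structure
# and confirm compare_structures reports an RMSD equal to the translation norm.
import numpy as np
from tme import Structure

structure = Structure.from_file("tests/data/Structures/5khe.cif")  # placeholder path
moved = structure.rigid_transform(
    translation=(3, 0, 0),
    rotation_matrix=np.eye(3),  # identity rotation, pure translation
)
rmsd = Structure.compare_structures(structure, moved)
print(rmsd)  # expected to be close to np.linalg.norm((3, 0, 0)) == 3.0
```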
tme/__init__.py ADDED
@@ -0,0 +1,6 @@
+from .__version__ import __version__
+from .density import Density
+from .preprocessor import Preprocessor
+from .structure import Structure
+from .orientations import Orientations
+from .matching_data import MatchingData
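
For orientation, a minimal usage sketch of the classes re-exported above, modelled directly on the bundled tests (not part of the wheel; file paths are placeholders):

```python
# Hedged sketch following tests/test_rotations.py: load a map and a structure,
# then rasterize the structure on the map's grid via Density.from_structure.
from tme import Density, Structure

density = Density.from_file("tests/data/Raw/em_map.map")            # placeholder path
structure = Structure.from_file("tests/data/Structures/5khe.cif")   # placeholder path

structure_density = Density.from_structure(
    filename_or_structure="tests/data/Structures/5khe.cif",
    origin=density.origin,
    shape=density.shape,
    sampling_rate=density.sampling_rate,
)
print(density.shape, structure.atom_coordinate.shape, structure_density.shape)
```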
tme/__version__.py ADDED
@@ -0,0 +1 @@
+__version__ = "0.2.9"
tme/analyzer/__init__.py ADDED
@@ -0,0 +1,2 @@
+from .peaks import *
+from .aggregation import *
tme/analyzer/_utils.py ADDED
@@ -0,0 +1,186 @@
+""" Analyzer utility functions.
+
+Copyright (c) 2023-2025 European Molecular Biology Laboratory
+
+Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+"""
+
+from typing import Tuple
+
+from ..types import BackendArray
+from ..backends import backend as be
+
+__all__ = ["cart_to_score", "score_to_cart"]
+
+
+def _convmode_to_shape(
+    convolution_mode: str,
+    targetshape: BackendArray,
+    templateshape: BackendArray,
+    convolution_shape: BackendArray,
+) -> BackendArray:
+    """
+    Calculate convolution shape based on convolution mode.
+
+    Parameters
+    ----------
+    convolution_mode : str
+        Mode of convolution. Supported values are:
+        - 'same': Output shape will match target shape
+        - 'valid': Output shape will be target shape minus template shape plus
+          template shape modulo 2
+        - Other: Output shape will be equal to convolution_shape
+    targetshape : BackendArray
+        Shape of the target array.
+    templateshape : BackendArray
+        Shape of the template array.
+    convolution_shape : BackendArray
+        Shape of the convolution output.
+
+    Returns
+    -------
+    BackendArray
+        Convolution shape.
+    """
+    output_shape = convolution_shape
+    if convolution_mode == "same":
+        output_shape = targetshape
+    elif convolution_mode == "valid":
+        output_shape = be.add(
+            be.subtract(targetshape, templateshape),
+            be.mod(templateshape, 2),
+        )
+    return be.to_backend_array(output_shape)
+
+
+def cart_to_score(
+    positions: BackendArray,
+    fast_shape: Tuple[int],
+    targetshape: Tuple[int],
+    templateshape: Tuple[int],
+    convolution_shape: Tuple[int] = None,
+    fourier_shift: Tuple[int] = None,
+    convolution_mode: str = None,
+    **kwargs,
+) -> Tuple[BackendArray]:
+    """
+    Maps peak positions from cartesian to padded score space coordinates.
+
+    Parameters
+    ----------
+    positions : BackendArray
+        Positions in cartesian coordinates.
+    fast_shape : tuple of int
+        Shape of the score space padded to efficient Fourier shape.
+    targetshape : tuple of int
+        Shape of the target array.
+    templateshape : tuple of int
+        Shape of the template array.
+    convolution_shape : tuple of int, optional
+        Non-padded convolution_shape of template and target.
+    fourier_shift : tuple of int, optional
+        Translation offset of coordinates.
+    convolution_mode : str, optional
+        Mode of convolution ('same', 'valid', or 'full').
+
+    Returns
+    -------
+    Tuple of BackendArray
+        Adjusted positions and a boolean array indicating whether the
+        corresponding positions are valid w.r.t. the supplied bounds.
+    """
+    positions = be.to_backend_array(positions)
+    fast_shape = be.to_backend_array(fast_shape)
+    targetshape = be.to_backend_array(targetshape)
+    templateshape = be.to_backend_array(templateshape)
+    convolution_shape = be.to_backend_array(convolution_shape)
+
+    # Compute removed padding
+    output_shape = _convmode_to_shape(
+        convolution_mode=convolution_mode,
+        targetshape=targetshape,
+        templateshape=templateshape,
+        convolution_shape=convolution_shape,
+    )
+    valid_positions = be.multiply(positions >= 0, positions < output_shape)
+    valid_positions = be.sum(valid_positions, axis=1) == positions.shape[1]
+
+    starts = be.astype(
+        be.divide(be.subtract(convolution_shape, output_shape), 2),
+        be._int_dtype,
+    )
+
+    positions = be.add(positions, starts)
+    if fourier_shift is not None:
+        fourier_shift = be.to_backend_array(fourier_shift)
+        positions = be.subtract(positions, fourier_shift)
+        positions = be.mod(positions, fast_shape)
+
+    return positions, valid_positions
+
+
+def score_to_cart(
+    positions,
+    fast_shape: Tuple[int] = None,
+    targetshape: Tuple[int] = None,
+    templateshape: Tuple[int] = None,
+    convolution_shape: Tuple[int] = None,
+    fourier_shift: Tuple[int] = None,
+    convolution_mode: str = None,
+    **kwargs,
+) -> Tuple[BackendArray]:
+    """
+    Maps peak positions from padded score to cartesian coordinates.
+
+    Parameters
+    ----------
+    positions : BackendArray
+        Positions in padded Fourier space system.
+    fast_shape : tuple of int
+        Shape of the score space padded to efficient Fourier shape.
+    targetshape : tuple of int
+        Shape of the target array.
+    templateshape : tuple of int
+        Shape of the template array.
+    convolution_shape : tuple of int, optional
+        Non-padded convolution_shape of template and target.
+    fourier_shift : tuple of int, optional
+        Translation offset of coordinates.
+    convolution_mode : str, optional
+        Mode of convolution ('same', 'valid', or 'full').
+
+    Returns
+    -------
+    Tuple of BackendArray
+        Adjusted positions and a boolean array indicating whether the
+        corresponding positions are valid w.r.t. the supplied bounds.
+    """
+    positions = be.to_backend_array(positions)
+    convolution_shape = be.to_backend_array(convolution_shape)
+    fast_shape = be.to_backend_array(fast_shape)
+    targetshape = be.to_backend_array(targetshape)
+    templateshape = be.to_backend_array(templateshape)
+
+    # Wrap peaks around score space
+    if fourier_shift is not None:
+        fourier_shift = be.to_backend_array(fourier_shift)
+        positions = be.add(positions, fourier_shift)
+        positions = be.mod(positions, fast_shape)
+
+    output_shape = _convmode_to_shape(
+        convolution_mode=convolution_mode,
+        targetshape=targetshape,
+        templateshape=templateshape,
+        convolution_shape=convolution_shape,
+    )
+    starts = be.astype(
+        be.divide(be.subtract(convolution_shape, output_shape), 2),
+        be._int_dtype,
+    )
+    stops = be.add(starts, output_shape)
+
+    valid_positions = be.multiply(positions >= starts, positions < stops)
+    valid_positions = be.sum(valid_positions, axis=1) == positions.shape[1]
+    positions = be.subtract(positions, starts)
+
+    return positions, valid_positions
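
To make the coordinate mapping above concrete, a hedged round-trip sketch (not part of the wheel): tme.analyzer._utils is an internal module, and the shapes and shift value below are invented purely for illustration, assuming the default NumPy-based backend.

```python
# Hedged sketch: map a cartesian peak into padded score-space coordinates and
# back. A "full" convolution of a (100,)-target with a (16,)-template has shape
# (115,), padded here to a fast transform length of 120; fourier_shift=(8,) is
# an illustrative offset of half the template extent.
import numpy as np
from tme.analyzer._utils import cart_to_score, score_to_cart

kwargs = dict(
    fast_shape=(120,),
    targetshape=(100,),
    templateshape=(16,),
    convolution_shape=(115,),   # target + template - 1
    fourier_shift=(8,),
    convolution_mode="same",    # report scores on the target grid
)

peaks = np.array([[40], [99]])  # cartesian positions within the target
score_pos, valid = cart_to_score(positions=peaks, **kwargs)
cart_pos, valid = score_to_cart(positions=score_pos, **kwargs)
print(score_pos, cart_pos, valid)  # cart_pos should recover the original peaks
```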