pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
- pytme-0.2.9.data/scripts/match_template.py +1135 -0
- pytme-0.2.9.data/scripts/postprocess.py +622 -0
- pytme-0.2.9.data/scripts/preprocess.py +209 -0
- pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
- pytme-0.2.9.dist-info/METADATA +95 -0
- pytme-0.2.9.dist-info/RECORD +119 -0
- pytme-0.2.9.dist-info/WHEEL +5 -0
- pytme-0.2.9.dist-info/entry_points.txt +6 -0
- pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
- pytme-0.2.9.dist-info/top_level.txt +3 -0
- scripts/__init__.py +0 -0
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +1135 -0
- scripts/postprocess.py +622 -0
- scripts/preprocess.py +209 -0
- scripts/preprocessor_gui.py +1227 -0
- tests/__init__.py +0 -0
- tests/data/Blurring/blob_width18.npy +0 -0
- tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
- tests/data/Blurring/gaussian_sigma2.npy +0 -0
- tests/data/Blurring/hamming_width6.npy +0 -0
- tests/data/Blurring/kaiserb_width18.npy +0 -0
- tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
- tests/data/Blurring/mean_size5.npy +0 -0
- tests/data/Blurring/ntree_sigma0510.npy +0 -0
- tests/data/Blurring/rank_rank3.npy +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Maps/emd_8621.mrc.gz +0 -0
- tests/data/README.md +2 -0
- tests/data/Raw/em_map.map +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/data/Structures/1pdj.cif +3339 -0
- tests/data/Structures/1pdj.pdb +1429 -0
- tests/data/Structures/5khe.cif +3685 -0
- tests/data/Structures/5khe.ent +2210 -0
- tests/data/Structures/5khe.pdb +2210 -0
- tests/data/Structures/5uz4.cif +70548 -0
- tests/preprocessing/__init__.py +0 -0
- tests/preprocessing/test_compose.py +76 -0
- tests/preprocessing/test_frequency_filters.py +178 -0
- tests/preprocessing/test_preprocessor.py +136 -0
- tests/preprocessing/test_utils.py +79 -0
- tests/test_analyzer.py +216 -0
- tests/test_backends.py +446 -0
- tests/test_density.py +503 -0
- tests/test_extensions.py +130 -0
- tests/test_matching_cli.py +283 -0
- tests/test_matching_data.py +162 -0
- tests/test_matching_exhaustive.py +124 -0
- tests/test_matching_memory.py +30 -0
- tests/test_matching_optimization.py +226 -0
- tests/test_matching_utils.py +189 -0
- tests/test_orientations.py +175 -0
- tests/test_parser.py +33 -0
- tests/test_rotations.py +153 -0
- tests/test_structure.py +247 -0
- tme/__init__.py +6 -0
- tme/__version__.py +1 -0
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +186 -0
- tme/analyzer/aggregation.py +577 -0
- tme/analyzer/peaks.py +953 -0
- tme/backends/__init__.py +171 -0
- tme/backends/_cupy_utils.py +734 -0
- tme/backends/_jax_utils.py +188 -0
- tme/backends/cupy_backend.py +294 -0
- tme/backends/jax_backend.py +314 -0
- tme/backends/matching_backend.py +1270 -0
- tme/backends/mlx_backend.py +241 -0
- tme/backends/npfftw_backend.py +583 -0
- tme/backends/pytorch_backend.py +430 -0
- tme/data/__init__.py +0 -0
- tme/data/c48n309.npy +0 -0
- tme/data/c48n527.npy +0 -0
- tme/data/c48n9.npy +0 -0
- tme/data/c48u1.npy +0 -0
- tme/data/c48u1153.npy +0 -0
- tme/data/c48u1201.npy +0 -0
- tme/data/c48u1641.npy +0 -0
- tme/data/c48u181.npy +0 -0
- tme/data/c48u2219.npy +0 -0
- tme/data/c48u27.npy +0 -0
- tme/data/c48u2947.npy +0 -0
- tme/data/c48u3733.npy +0 -0
- tme/data/c48u4749.npy +0 -0
- tme/data/c48u5879.npy +0 -0
- tme/data/c48u7111.npy +0 -0
- tme/data/c48u815.npy +0 -0
- tme/data/c48u83.npy +0 -0
- tme/data/c48u8649.npy +0 -0
- tme/data/c600v.npy +0 -0
- tme/data/c600vc.npy +0 -0
- tme/data/metadata.yaml +80 -0
- tme/data/quat_to_numpy.py +42 -0
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +2263 -0
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/external/bindings.cpp +332 -0
- tme/filters/__init__.py +6 -0
- tme/filters/_utils.py +311 -0
- tme/filters/bandpass.py +230 -0
- tme/filters/compose.py +81 -0
- tme/filters/ctf.py +393 -0
- tme/filters/reconstruction.py +160 -0
- tme/filters/wedge.py +542 -0
- tme/filters/whitening.py +191 -0
- tme/matching_data.py +863 -0
- tme/matching_exhaustive.py +497 -0
- tme/matching_optimization.py +1311 -0
- tme/matching_scores.py +1183 -0
- tme/matching_utils.py +1188 -0
- tme/memory.py +337 -0
- tme/orientations.py +598 -0
- tme/parser.py +685 -0
- tme/preprocessor.py +1329 -0
- tme/rotations.py +350 -0
- tme/structure.py +1864 -0
- tme/types.py +13 -0
@@ -0,0 +1,1311 @@
|
|
1
|
+
""" Implements methods for non-exhaustive template matching.
|
2
|
+
|
3
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
"""
|
7
|
+
|
8
|
+
import warnings
|
9
|
+
from typing import Tuple, List, Dict
|
10
|
+
from abc import ABC, abstractmethod
|
11
|
+
|
12
|
+
import numpy as np
|
13
|
+
from scipy.ndimage import laplace, map_coordinates, sobel
|
14
|
+
from scipy.optimize import (
|
15
|
+
minimize,
|
16
|
+
basinhopping,
|
17
|
+
LinearConstraint,
|
18
|
+
differential_evolution,
|
19
|
+
)
|
20
|
+
|
21
|
+
from .backends import backend as be
|
22
|
+
from .types import ArrayLike, NDArray
|
23
|
+
from .matching_data import MatchingData
|
24
|
+
from .rotations import euler_to_rotationmatrix
|
25
|
+
from .matching_utils import rigid_transform, normalize_template
|
26
|
+
|
27
|
+
|
28
|
+
def _format_rigid_transform(x: Tuple[float]) -> Tuple[ArrayLike, ArrayLike]:
    """
    Split a flat parameter vector into a translation and a rotation matrix.

    Parameters
    ----------
    x : tuple of float
        Even-length tuple where the first half represents translations and the
        second half Euler angles in zyz convention for each dimension.

    Returns
    -------
    Tuple[ArrayLike, ArrayLike]
        Translation of length [d, ] and rotation matrix with dimension [d x d].
    """
    half = len(x) // 2
    shifts, euler_angles = x[:half], x[half:]

    shifts = be.to_backend_array(shifts)

    # The angles are passed through be.to_numpy_array before the matrix is
    # built; the resulting matrix is then handed back to the active backend.
    matrix = euler_to_rotationmatrix(be.to_numpy_array(euler_angles))
    return shifts, be.to_backend_array(matrix)
|
51
|
+
|
52
|
+
|
53
|
+
class _MatchDensityToDensity(ABC):
|
54
|
+
"""
|
55
|
+
Parameters
|
56
|
+
----------
|
57
|
+
target : array_like
|
58
|
+
The target density array.
|
59
|
+
template : array_like
|
60
|
+
The template density array.
|
61
|
+
template_mask : array_like, optional
|
62
|
+
Mask array for the template density.
|
63
|
+
target_mask : array_like, optional
|
64
|
+
Mask array for the target density.
|
65
|
+
pad_target_edges : bool, optional
|
66
|
+
Whether to pad the edges of the target density array. Default is False.
|
67
|
+
pad_fourier : bool, optional
|
68
|
+
Whether to pad the Fourier transform of the target and template densities.
|
69
|
+
rotate_mask : bool, optional
|
70
|
+
Whether to rotate the mask arrays along with the densities. Default is True.
|
71
|
+
interpolation_order : int, optional
|
72
|
+
The interpolation order for rigid transforms. Default is 1.
|
73
|
+
negate_score : bool, optional
|
74
|
+
Whether the final score should be multiplied by negative one. Default is True.
|
75
|
+
**kwargs : Dict, optional
|
76
|
+
Keyword arguments propagated to downstream functions.
|
77
|
+
"""
|
78
|
+
|
79
|
+
def __init__(
|
80
|
+
self,
|
81
|
+
target: ArrayLike,
|
82
|
+
template: ArrayLike,
|
83
|
+
template_mask: ArrayLike = None,
|
84
|
+
target_mask: ArrayLike = None,
|
85
|
+
pad_target_edges: bool = False,
|
86
|
+
pad_fourier: bool = False,
|
87
|
+
rotate_mask: bool = True,
|
88
|
+
interpolation_order: int = 1,
|
89
|
+
negate_score: bool = True,
|
90
|
+
**kwargs: Dict,
|
91
|
+
):
|
92
|
+
self.eps = be.eps(target.dtype)
|
93
|
+
self.rotate_mask = rotate_mask
|
94
|
+
self.interpolation_order = interpolation_order
|
95
|
+
|
96
|
+
matching_data = MatchingData(target=target, template=template)
|
97
|
+
if template_mask is not None:
|
98
|
+
matching_data.template_mask = template_mask
|
99
|
+
if target_mask is not None:
|
100
|
+
matching_data.target_mask = target_mask
|
101
|
+
|
102
|
+
self.target, self.target_mask = matching_data.target, matching_data.target_mask
|
103
|
+
|
104
|
+
self.template = matching_data._template
|
105
|
+
self.template_rot = be.zeros(template.shape, be._float_dtype)
|
106
|
+
|
107
|
+
self.template_mask, self.template_mask_rot = 1, 1
|
108
|
+
rotate_mask = False if matching_data._template_mask is None else rotate_mask
|
109
|
+
if matching_data.template_mask is not None:
|
110
|
+
self.template_mask = matching_data._template_mask
|
111
|
+
self.template_mask_rot = be.topleft_pad(
|
112
|
+
matching_data._template_mask, self.template_mask.shape
|
113
|
+
)
|
114
|
+
|
115
|
+
self.template_slices = tuple(slice(None) for _ in self.template.shape)
|
116
|
+
self.target_slices = tuple(slice(0, x) for x in self.template.shape)
|
117
|
+
|
118
|
+
self.score_sign = -1 if negate_score else 1
|
119
|
+
|
120
|
+
if hasattr(self, "_post_init"):
|
121
|
+
self._post_init(**kwargs)
|
122
|
+
|
123
|
+
def rotate_array(
|
124
|
+
self,
|
125
|
+
arr,
|
126
|
+
rotation_matrix,
|
127
|
+
translation,
|
128
|
+
arr_mask=None,
|
129
|
+
out=None,
|
130
|
+
out_mask=None,
|
131
|
+
order: int = 1,
|
132
|
+
**kwargs,
|
133
|
+
):
|
134
|
+
rotate_mask = arr_mask is not None
|
135
|
+
return_type = (out is None) + 2 * rotate_mask * (out_mask is None)
|
136
|
+
translation = np.zeros(arr.ndim) if translation is None else translation
|
137
|
+
|
138
|
+
center = np.floor(np.array(arr.shape) / 2)[:, None]
|
139
|
+
|
140
|
+
if not hasattr(self, "_previous_center"):
|
141
|
+
self._previous_center = arr.shape
|
142
|
+
|
143
|
+
if not hasattr(self, "grid") or not np.allclose(self._previous_center, center):
|
144
|
+
self.grid = np.indices(arr.shape, dtype=np.float32).reshape(arr.ndim, -1)
|
145
|
+
np.subtract(self.grid, center, out=self.grid)
|
146
|
+
self.grid_out = np.zeros_like(self.grid)
|
147
|
+
self._previous_center = center
|
148
|
+
|
149
|
+
np.matmul(rotation_matrix.T, self.grid, out=self.grid_out)
|
150
|
+
translation = np.add(translation[:, None], center)
|
151
|
+
np.add(self.grid_out, translation, out=self.grid_out)
|
152
|
+
|
153
|
+
if out is None:
|
154
|
+
out = np.zeros_like(arr)
|
155
|
+
|
156
|
+
self._interpolate(arr, self.grid_out, order=order, out=out.ravel())
|
157
|
+
|
158
|
+
if out_mask is None and arr_mask is not None:
|
159
|
+
out_mask = np.zeros_like(arr_mask)
|
160
|
+
|
161
|
+
if arr_mask is not None:
|
162
|
+
self._interpolate(arr_mask, self.grid_out, order=order, out=out.ravel())
|
163
|
+
|
164
|
+
match return_type:
|
165
|
+
case 0:
|
166
|
+
return None
|
167
|
+
case 1:
|
168
|
+
return out
|
169
|
+
case 2:
|
170
|
+
return out_mask
|
171
|
+
case 3:
|
172
|
+
return out, out_mask
|
173
|
+
|
174
|
+
@staticmethod
|
175
|
+
def _interpolate(data, positions, order: int = 1, out=None):
|
176
|
+
return map_coordinates(
|
177
|
+
data, positions, order=order, mode="constant", output=out
|
178
|
+
)
|
179
|
+
|
180
|
+
def score_translation(self, x: Tuple[float]) -> float:
|
181
|
+
"""
|
182
|
+
Computes the score after a given translation.
|
183
|
+
|
184
|
+
Parameters
|
185
|
+
----------
|
186
|
+
x : tuple of float
|
187
|
+
Tuple representing the translation transformation in each dimension.
|
188
|
+
|
189
|
+
Returns
|
190
|
+
-------
|
191
|
+
float
|
192
|
+
The score obtained for the translation transformation.
|
193
|
+
"""
|
194
|
+
return self.score((*x, *[0 for _ in range(len(x))]))
|
195
|
+
|
196
|
+
def score_angles(self, x: Tuple[float]) -> float:
|
197
|
+
"""
|
198
|
+
Computes the score after a given rotation.
|
199
|
+
|
200
|
+
Parameters
|
201
|
+
----------
|
202
|
+
x : tuple of float
|
203
|
+
Tuple of Euler angles in zyz convention for each dimension.
|
204
|
+
|
205
|
+
Returns
|
206
|
+
-------
|
207
|
+
float
|
208
|
+
The score obtained for the rotation transformation.
|
209
|
+
"""
|
210
|
+
return self.score((*[0 for _ in range(len(x))], *x))
|
211
|
+
|
212
|
+
def score(self, x: Tuple[float]) -> float:
|
213
|
+
"""
|
214
|
+
Compute the matching score for the given transformation parameters.
|
215
|
+
|
216
|
+
Parameters
|
217
|
+
----------
|
218
|
+
x : tuple of float
|
219
|
+
Even-length tuple where the first half represents translations and the
|
220
|
+
second half Euler angles in zyz convention for each dimension.
|
221
|
+
|
222
|
+
Returns
|
223
|
+
-------
|
224
|
+
float
|
225
|
+
The matching score obtained for the transformation.
|
226
|
+
"""
|
227
|
+
translation, rotation_matrix = _format_rigid_transform(x)
|
228
|
+
self.template_rot.fill(0)
|
229
|
+
|
230
|
+
voxel_translation = be.astype(translation, be._int_dtype)
|
231
|
+
subvoxel_translation = be.subtract(translation, voxel_translation)
|
232
|
+
|
233
|
+
center = be.astype(be.divide(self.template.shape, 2), be._int_dtype)
|
234
|
+
right_pad = be.subtract(self.template.shape, center)
|
235
|
+
|
236
|
+
translated_center = be.add(voxel_translation, center)
|
237
|
+
|
238
|
+
target_starts = be.subtract(translated_center, center)
|
239
|
+
target_stops = be.add(translated_center, right_pad)
|
240
|
+
|
241
|
+
template_starts = be.subtract(be.maximum(target_starts, 0), target_starts)
|
242
|
+
template_stops = be.subtract(
|
243
|
+
target_stops, be.minimum(target_stops, self.target.shape)
|
244
|
+
)
|
245
|
+
template_stops = be.subtract(self.template.shape, template_stops)
|
246
|
+
|
247
|
+
target_starts = be.maximum(target_starts, 0)
|
248
|
+
target_stops = be.minimum(target_stops, self.target.shape)
|
249
|
+
|
250
|
+
cand_start, cand_stop = template_starts.astype(int), template_stops.astype(int)
|
251
|
+
obs_start, obs_stop = target_starts.astype(int), target_stops.astype(int)
|
252
|
+
|
253
|
+
self.template_slices = tuple(slice(s, e) for s, e in zip(cand_start, cand_stop))
|
254
|
+
self.target_slices = tuple(slice(s, e) for s, e in zip(obs_start, obs_stop))
|
255
|
+
|
256
|
+
kw_dict = {
|
257
|
+
"arr": self.template,
|
258
|
+
"rotation_matrix": rotation_matrix,
|
259
|
+
"translation": subvoxel_translation,
|
260
|
+
"out": self.template_rot,
|
261
|
+
"order": self.interpolation_order,
|
262
|
+
"use_geometric_center": True,
|
263
|
+
}
|
264
|
+
if self.rotate_mask:
|
265
|
+
self.template_mask_rot.fill(0)
|
266
|
+
kw_dict["arr_mask"] = self.template_mask
|
267
|
+
kw_dict["out_mask"] = self.template_mask_rot
|
268
|
+
|
269
|
+
self.rotate_array(**kw_dict)
|
270
|
+
|
271
|
+
return self()
|
272
|
+
|
273
|
+
@abstractmethod
|
274
|
+
def __call__(self) -> float:
|
275
|
+
"""Returns the score of the current configuration."""
|
276
|
+
|
277
|
+
|
278
|
+
class _MatchCoordinatesToDensity(_MatchDensityToDensity):
    """
    Parameters
    ----------
    target : NDArray
        A d-dimensional target to match the template coordinate set to.
    template_coordinates : NDArray
        Template coordinate array with shape (d,n).
    template_weights : NDArray
        Template weight array with shape (n,).
    template_mask_coordinates : NDArray, optional
        Template mask coordinates with shape (d,n).
    target_mask : NDArray, optional
        A d-dimensional mask to be applied to the target.
    negate_score : bool, optional
        Whether the final score should be multiplied by negative one. Default is True.
    return_gradient : bool, optional
        Invoking __call__ returns a tuple of score and parameter gradient. Default is False.
    **kwargs : Dict, optional
        Keyword arguments propagated to downstream functions.
    """

    def __init__(
        self,
        target: NDArray,
        template_coordinates: NDArray,
        template_weights: NDArray,
        template_mask_coordinates: NDArray = None,
        target_mask: NDArray = None,
        negate_score: bool = True,
        return_gradient: bool = False,
        interpolation_order: int = 1,
        **kwargs: Dict,
    ):
        # The parent constructor is intentionally not invoked here; all state
        # used by this class is set up below.
        self.target = target.astype(np.float32)
        self.target_mask = (
            None if target_mask is None else target_mask.astype(np.float32)
        )

        self.eps = be.eps(self.target.dtype)

        # One Sobel-filtered copy of the target per axis, stacked on axis 0.
        per_axis = [sobel(self.target, axis=axis) for axis in range(self.target.ndim)]
        self.target_grad = np.stack(per_axis)

        self.n_points = template_coordinates.shape[1]
        self.template = template_coordinates.astype(np.float32)
        self.template_rotated = np.zeros_like(self.template)
        self.template_weights = template_weights.astype(np.float32)
        self.template_center = np.mean(self.template, axis=1)[:, None]

        self.template_mask = None
        self.template_mask_rotated = None
        if template_mask_coordinates is not None:
            self.template_mask = template_mask_coordinates.astype(np.float32)
            self.template_mask_rotated = np.empty_like(self.template_mask)

        self.denominator = 1
        self.score_sign = -1 if negate_score else 1
        self.interpolation_order = interpolation_order

        # Target values at the untransformed template positions.
        self._target_values = self._interpolate(
            self.target, self.template, order=self.interpolation_order
        )

        if return_gradient and not hasattr(self, "grad"):
            raise NotImplementedError(f"{type(self)} does not have grad method.")
        self.return_gradient = return_gradient

    def score(self, x: Tuple[float]):
        """
        Compute the matching score for the given transformation parameters.

        Parameters
        ----------
        x : tuple of float
            Even-length tuple where the first half represents translations and the
            second half Euler angles in zyz convention for each dimension.

        Returns
        -------
        float
            The matching score obtained for the transformation. If the instance
            was created with ``return_gradient=True`` a tuple of score and
            ``self.grad()`` is returned instead.
        """
        shift, rotmat = _format_rigid_transform(x)

        # Transformed coordinates are written into the preallocated buffers.
        rigid_transform(
            coordinates=self.template,
            coordinates_mask=self.template_mask,
            rotation_matrix=rotmat,
            translation=shift,
            out=self.template_rotated,
            out_mask=self.template_mask_rotated,
            use_geometric_center=False,
        )

        self._target_values = self._interpolate(
            self.target, self.template_rotated, order=self.interpolation_order
        )

        current = self()
        if self.return_gradient:
            return current, self.grad()
        return current

    def _interpolate_gradient(self, positions):
        """Sample every component of ``self.target_grad`` at ``positions``."""
        sampled = be.zeros(positions.shape, dtype=positions.dtype)

        for axis, component in enumerate(self.target_grad):
            sampled[axis, :] = self._interpolate(
                component, positions, order=self.interpolation_order
            )

        return sampled

    @staticmethod
    def _torques(positions, center, gradients):
        """Cross product of center-relative ``positions`` with ``gradients``."""
        relative = (positions - center).T
        return be.cross(relative, gradients.T).T
|
397
|
+
|
398
|
+
|
399
|
+
class _MatchCoordinatesToCoordinates(_MatchDensityToDensity):
    """
    Parameters
    ----------
    target_coordinates : NDArray
        The coordinates of the target with shape [d x N].
    template_coordinates : NDArray
        The coordinates of the template with shape [d x T].
    target_weights : NDArray
        The weights of the target with shape [N].
    template_weights : NDArray
        The weights of the template with shape [T].
    template_mask_coordinates : NDArray, optional
        The coordinates of the template mask with shape [d x T]. Default is None.
    target_mask_coordinates : NDArray, optional
        The coordinates of the target mask with shape [d X N]. Default is None.
    negate_score : bool, optional
        Whether the final score should be multiplied by negative one. Default is True.
    **kwargs : Dict, optional
        Keyword arguments propagated to downstream functions.
    """

    def __init__(
        self,
        target_coordinates: NDArray,
        template_coordinates: NDArray,
        target_weights: NDArray,
        template_weights: NDArray,
        template_mask_coordinates: NDArray = None,
        target_mask_coordinates: NDArray = None,
        negate_score: bool = True,
        **kwargs,
    ):
        # Target side: stored as passed in.
        self.target_coordinates = target_coordinates
        self.target_weights = target_weights
        self.target_mask_coordinates = target_mask_coordinates

        # Template side, plus a float32 buffer for the transformed coordinates.
        self.template_coordinates = template_coordinates
        self.template_weights = template_weights
        self.template_coordinates_rotated = np.empty(
            self.template_coordinates.shape, dtype=np.float32
        )

        # The mask buffer is only allocated when mask coordinates were given.
        self.template_mask_coordinates = template_mask_coordinates
        self.template_mask_coordinates_rotated = None
        if template_mask_coordinates is not None:
            self.template_mask_coordinates_rotated = np.empty(
                self.template_mask_coordinates.shape, dtype=np.float32
            )

        self.score_sign = -1 if negate_score else 1

        # Optional hook for subclasses; receives the remaining keyword arguments.
        if hasattr(self, "_post_init"):
            self._post_init(**kwargs)

    def score(self, x: Tuple[float]) -> float:
        """
        Compute the matching score for the given transformation parameters.

        Parameters
        ----------
        x : tuple of float
            Even-length tuple where the first half represents translations and the
            second half Euler angles in zyz convention for each dimension.

        Returns
        -------
        float
            The matching score obtained for the transformation.
        """
        shift, rotmat = _format_rigid_transform(x)

        moved = self.template_coordinates_rotated
        moved_mask = self.template_mask_coordinates_rotated

        # Results are written in place into the preallocated buffers.
        rigid_transform(
            coordinates=self.template_coordinates,
            coordinates_mask=self.template_mask_coordinates,
            rotation_matrix=rotmat,
            translation=shift,
            out=moved,
            out_mask=moved_mask,
            use_geometric_center=False,
        )

        return self(
            transformed_coordinates=moved,
            transformed_coordinates_mask=moved_mask,
        )
|
485
|
+
|
486
|
+
|
487
|
+
class FLC(_MatchDensityToDensity):
    """
    Computes a normalized cross-correlation score of a target f a template g
    and a mask m:

    .. math::

        \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
        {N_m * \\sqrt{
            \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
        }

    Where:

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)

    and Nm is the number of voxels within the template mask m.
    """

    __doc__ += _MatchDensityToDensity.__doc__

    def _post_init(self, **kwargs: Dict):
        """Apply the target mask, cache the squared target, normalize the template."""
        if self.target_mask is not None:
            # In-place: self.target holds the masked target afterwards.
            be.multiply(self.target, self.target_mask, out=self.target)

        self.target_square = be.square(self.target)

        normalize_template(
            template=self.template,
            mask=self.template_mask,
            n_observations=be.sum(self.template_mask),
        )

    def __call__(self) -> float:
        """
        Returns the score of the current configuration.

        Zero is returned if the rotated mask sums to zero or less, or if the
        local standard deviation of the target under the mask is below
        ``self.eps``.
        """
        n_obs = be.sum(self.template_mask_rot)

        # Every term below is divided by n_obs; an empty rotated mask would
        # otherwise yield a NaN score that is not caught by the eps check.
        if n_obs <= 0:
            return 0

        normalize_template(
            template=self.template_rot,
            mask=self.template_mask_rot,
            n_observations=n_obs,
        )
        overlap = be.sum(
            be.multiply(
                self.template_rot[self.template_slices], self.target[self.target_slices]
            )
        )

        # E[f^2] and E[f]^2 of the target under the rotated mask.
        mask_rot = self.template_mask_rot[self.template_slices]
        exp_sq = be.sum(self.target_square[self.target_slices] * mask_rot) / n_obs
        sq_exp = be.square(be.sum(self.target[self.target_slices] * mask_rot) / n_obs)

        # Clamp at zero before the square root to absorb rounding errors.
        denominator = be.maximum(be.subtract(exp_sq, sq_exp), 0.0)
        denominator = be.sqrt(denominator)
        if denominator < self.eps:
            return 0

        score = be.divide(overlap, denominator * n_obs) * self.score_sign
        return score
|
548
|
+
|
549
|
+
|
550
|
+
class CrossCorrelation(_MatchCoordinatesToDensity):
    """
    Computes the Cross-Correlation score as:

    .. math::

        \\text{score} = \\text{target_weights} \\cdot \\text{template_weights}
    """

    __doc__ += _MatchCoordinatesToDensity.__doc__

    def __call__(self) -> float:
        """Returns the score of the current configuration."""
        score = be.dot(self._target_values, self.template_weights)
        # A single division applies both the normalization and the sign.
        score /= self.denominator * self.score_sign
        return score

    def grad(self):
        """
        Calculate the gradient of the cost function w.r.t. translation and rotation.

        .. math::

            \\nabla f = -\\frac{1}{N} \\begin{bmatrix}
                \\sum_i w_i \\nabla v(x_i) \\\\
                \\sum_i w_i (r_i \\times \\nabla v(x_i))
            \\end{bmatrix}

        where :math:`N` is the number of points, :math:`w_i` are weights,
        :math:`x_i` are rotated template positions, and :math:`r_i` are
        positions relative to the template center.

        Returns
        -------
        np.ndarray
            Negative gradient of the cost function: [dx, dy, dz, dRx, dRy, dRz].

        """
        # NOTE(review): the result is always negated and divided by n_points,
        # whereas __call__ uses score_sign and denominator - confirm that this
        # is intended when negate_score is False.
        rotated = self.template_rotated
        weights = self.template_weights

        field = self._interpolate_gradient(positions=rotated)
        moments = self._torques(
            positions=rotated, gradients=field, center=self.template_center
        )

        # <u, dv/dx> / <u, r x dv/dx>
        parts = [
            be.sum(field * weights, axis=1),
            be.sum(moments * weights, axis=1),
        ]
        total_grad = be.concatenate(parts)
        total_grad = be.divide(total_grad, self.n_points, out=total_grad)
        return -total_grad
|
600
|
+
|
601
|
+
|
602
|
+
class LaplaceCrossCorrelation(CrossCorrelation):
    """
    Uses the same formalism as :py:class:`CrossCorrelation` but with Laplace
    filtered weights (:math:`\\nabla^{2}`):

    .. math::

        \\text{score} = \\nabla^{2} \\text{target_weights} \\cdot
                        \\nabla^{2} \\text{template_weights}
    """

    __doc__ += _MatchCoordinatesToDensity.__doc__

    def __init__(self, **kwargs):
        # All arguments have to be passed by keyword, since target, template
        # coordinates and weights are looked up by name below.
        kwargs["target"] = laplace(kwargs["target"])

        coords = kwargs["template_coordinates"]

        # Shift coordinates to a zero-based integer grid.
        lower = coords.min(axis=1)
        voxels = (coords - lower[:, None]).astype(int)
        index = tuple(voxels)

        # Accumulate the weights on a grid just large enough to hold all points.
        grid = np.zeros(voxels.max(axis=1) + 1, dtype=np.float32)
        np.add.at(grid, index, kwargs["template_weights"])

        # Filter the grid and read the values back at the point positions.
        kwargs["template_weights"] = laplace(grid)[index]
        super().__init__(**kwargs)
|
627
|
+
|
628
|
+
|
629
|
+
class NormalizedCrossCorrelation(CrossCorrelation):
    """
    Computes a normalized version of the :py:class:`CrossCorrelation` score based
    on the dot product of `target_weights` and `template_weights`, in order to
    reduce bias to regions of high local energy.

    .. math::

        \\text{score} = \\frac{\\text{target_weights} \\cdot \\text{template_weights}}
                        {\\text{max(target_norm} \\times \\text{template_norm, eps)}}

    Where:

    .. math::

        \\text{target_norm} = ||\\text{target_weights}||

    .. math::

        \\text{template_norm} = ||\\text{template_weights}||

    Here, :math:`||.||` denotes the L2 (Euclidean) norm.
    """

    __doc__ += _MatchCoordinatesToDensity.__doc__

    def __call__(self) -> float:
        """Returns the normalized score, or 0.0 if the norm product is not positive."""
        template_norm = np.linalg.norm(self.template_weights)
        target_norm = np.linalg.norm(self._target_values)
        denominator = be.multiply(template_norm, target_norm)

        if denominator <= 0:
            return 0.0

        # The parent implementation divides by self.denominator.
        self.denominator = denominator
        return super().__call__()

    def grad(self):
        """
        Calculate the normalized gradient of the cost function w.r.t. translation and rotation.

        .. math::

            \\nabla f = -\\frac{1}{N|w||v|^3} \\begin{bmatrix}
                (\\sum_i w_i \\nabla v(x_i))|v|^2 - (\\sum_i v(x_i)
                \\nabla v(x_i))(w \\cdot v) \\\\
                (\\sum_i w_i (r_i \\times \\nabla v(x_i)))|v|^2 - (\\sum_i v(x_i)
                (r_i \\times \\nabla v(x_i)))(w \\cdot v)
            \\end{bmatrix}

        where :math:`N` is the number of points, :math:`w` are weights,
        :math:`v` are target values, :math:`x_i` are rotated template positions,
        and :math:`r_i` are positions relative to the template center.

        Returns
        -------
        np.ndarray
            Negative normalized gradient: [dx, dy, dz, dRx, dRy, dRz].

        """
        values = self._target_values
        weights = self.template_weights

        field = self._interpolate_gradient(positions=self.template_rotated)
        moments = self._torques(
            positions=self.template_rotated,
            gradients=field,
            center=self.template_center,
        )

        # Scalars shared by the translation and the rotation part.
        values_sq = be.sum(be.square(values))
        weights_sq = be.sum(be.square(weights))
        overlap = be.sum(be.multiply(values, weights))

        norm = be.multiply(be.power(be.sqrt(values_sq), 3), be.sqrt(weights_sq))

        # (<u,dv/dx> * |v|**2 - <u,v> * <v,dv/dx>)/(|w|*|v|**3)
        translation_grad = be.multiply(
            be.sum(be.multiply(field, weights), axis=1), values_sq
        )
        translation_grad -= be.multiply(
            be.sum(be.multiply(field, values), axis=1), overlap
        )

        # (<u,r x dv/dx> * |v|**2 - <u,v> * <v,r x dv/dx>)/(|w|*|v|**3)
        torque_grad = be.multiply(
            be.sum(be.multiply(moments, weights), axis=1), values_sq
        )
        torque_grad -= be.multiply(
            be.sum(be.multiply(moments, values), axis=1), overlap
        )

        total_grad = be.concatenate([translation_grad, torque_grad])
        # The normalization is skipped when the norm product is not positive.
        if norm > 0:
            total_grad = be.divide(total_grad, norm, out=total_grad)

        total_grad = be.divide(total_grad, self.n_points, out=total_grad)
        return -total_grad
|
725
|
+
|
726
|
+
|
727
|
+
class NormalizedCrossCorrelationMean(NormalizedCrossCorrelation):
    """
    Computes a score similar to :py:class:`NormalizedCrossCorrelation`, but
    additionally factors in the mean of template and target.

    .. math::

        \\text{score} = \\frac{(\\text{target_weights} - \\text{mean(target_weights)})
        \\cdot (\\text{template_weights} -
        \\text{mean(template_weights)})}
        {\\text{max(target_norm} \\times \\text{template_norm, eps)}}

    Where:

    .. math::

        \\text{target_norm} = ||\\text{target_weights} - \\text{mean(target_weights)}||

    .. math::

        \\text{template_norm} = ||\\text{template_weights} -
        \\text{mean(template_weights)}||

    Here, :math:`||.||` denotes the L2 (Euclidean) norm, and :math:`\\text{mean(.)}`
    computes the mean of the respective weights.
    """

    # Append the parameter documentation of the coordinate-to-density base class.
    __doc__ += _MatchCoordinatesToDensity.__doc__

    def __init__(self, **kwargs):
        # Subtract the respective mean from the target and the template weights
        # before handing the keyword arguments to the parent constructor.
        target = kwargs["target"]
        weights = kwargs["template_weights"]
        kwargs["target"] = np.subtract(target, target.mean())
        kwargs["template_weights"] = np.subtract(weights, weights.mean())
        super().__init__(**kwargs)
|
762
|
+
|
763
|
+
|
764
|
+
class MaskedCrossCorrelation(_MatchCoordinatesToDensity):
    """
    The Masked Cross-Correlation computes the similarity between `target_weights`
    and `template_weights` under respective masks. The score provides a measure of
    similarity even in the presence of missing or masked data.

    The formula for the Masked Cross-Correlation is:

    .. math::
        \\text{numerator} = \\text{dot}(\\text{target_weights},
        \\text{template_weights}) -
        \\frac{\\text{sum}(\\text{mask_target}) \\times
        \\text{sum}(\\text{mask_template})}
        {\\text{mask_overlap}}

    .. math::
        \\text{denominator1} = \\text{sum}(\\text{mask_target}^2) -
        \\frac{\\text{sum}(\\text{mask_target})^2}
        {\\text{mask_overlap}}

    .. math::
        \\text{denominator2} = \\text{sum}(\\text{mask_template}^2) -
        \\frac{\\text{sum}(\\text{mask_template})^2}
        {\\text{mask_overlap}}

    .. math::
        \\text{denominator} = \\sqrt{\\text{denominator1} \\times \\text{denominator2}}

    .. math::
        \\text{score} = \\frac{\\text{numerator}}{\\text{denominator}}
        \\text{ if denominator } \\neq 0
        \\text{ else } 0

    Where:

    - mask_target and mask_template are binary masks for the target_weights
      and template_weights respectively.

    - mask_overlap represents the number of overlapping non-zero elements in
      the masks.

    References
    ----------
    .. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
    """

    # Append the parameter documentation of the coordinate-to-density base class.
    __doc__ += _MatchCoordinatesToDensity.__doc__

    def __call__(self) -> float:
        """Returns the score of the current configuration."""
        upper_bound = np.array(self.target.shape)[:, None]

        # Keep only the points whose coordinates fall inside the target volume.
        in_volume = np.logical_and(
            self.template_rotated < upper_bound, self.template_rotated >= 0
        ).min(axis=0)
        in_volume_mask = np.logical_and(
            self.template_mask_rotated < upper_bound, self.template_mask_rotated >= 0
        ).min(axis=0)

        # Integer index tuples used for every lookup into target and target_mask.
        template_index = tuple(self.template_rotated[:, in_volume].astype(int))
        mask_index = tuple(self.template_mask_rotated[:, in_volume_mask].astype(int))

        # Clamp the overlap so the divisions below never divide by zero.
        mask_overlap = np.fmax(
            np.sum(self.target_mask[mask_index]), np.finfo(float).eps
        )

        mask_target = self.target[mask_index]
        target_sum = np.sum(mask_target)
        denominator1 = np.sum(mask_target**2) - np.square(target_sum) / mask_overlap

        weights_in_volume = self.template_weights[in_volume]
        mask_template = np.multiply(
            weights_in_volume, self.target_mask[template_index]
        )
        template_sum = np.sum(mask_template)
        denominator2 = (
            np.sum(mask_template**2) - np.square(template_sum) / mask_overlap
        )

        # Negative variances can only stem from round-off; clip them to zero.
        denominator = np.sqrt(
            np.multiply(np.fmax(denominator1, 0.0), np.fmax(denominator2, 0.0))
        )

        numerator = np.dot(self.target[template_index], weights_in_volume)
        numerator = numerator - np.multiply(target_sum, template_sum) / mask_overlap

        if denominator == 0:
            return 0.0

        return float(numerator / denominator * self.score_sign)
|
865
|
+
|
866
|
+
|
867
|
+
class PartialLeastSquareDifference(_MatchCoordinatesToDensity):
    """
    The Partial Least Square Difference (PLSQ) between the target :math:`f` and the
    template :math:`g` is calculated as:

    .. math::

        \\text{d(f,g)} = \\sum_{i=1}^{n} \\| f(\\mathbf{p}_i) - g(\\mathbf{q}_i) \\|^2

    References
    ----------
    .. [1] Daven Vasishtan and Maya Topf, "Scoring functions for cryoEM density
           fitting", Journal of Structural Biology, vol. 174, no. 2,
           pp. 333--343, 2011. DOI: https://doi.org/10.1016/j.jsb.2011.01.012
    """

    # Append the parameter documentation of the coordinate-to-density base class.
    __doc__ += _MatchCoordinatesToDensity.__doc__

    def __call__(self) -> float:
        """Returns the score of the current configuration."""
        # Sum of squared residuals between sampled target values and weights.
        residual = be.subtract(self._target_values, self.template_weights)
        squared_error = be.sum(be.square(residual))
        return squared_error * self.score_sign
|
891
|
+
|
892
|
+
|
893
|
+
class MutualInformation(_MatchCoordinatesToDensity):
    """
    The Mutual Information (MI) score between the target :math:`f` and the
    template :math:`g` is calculated as:

    .. math::

        \\text{d(f,g)} = \\sum_{f,g} p(f,g) \\log \\frac{p(f,g)}{p(f)p(g)}

    The probabilities are estimated from a two-dimensional histogram of the
    sampled target values and the template weights. The natural logarithm is
    used, so the score is expressed in nats.

    References
    ----------
    .. [1] Daven Vasishtan and Maya Topf, "Scoring functions for cryoEM density
           fitting", Journal of Structural Biology, vol. 174, no. 2,
           pp. 333--343, 2011. DOI: https://doi.org/10.1016/j.jsb.2011.01.012

    """

    # Append the parameter documentation of the coordinate-to-density base class.
    __doc__ += _MatchCoordinatesToDensity.__doc__

    def __call__(self) -> float:
        """Returns the score of the current configuration."""
        # Joint histogram with numpy's default binning; the bin edges are unused.
        p_xy, _, _ = np.histogram2d(self._target_values, self.template_weights)

        total = p_xy.sum()
        if total == 0:
            # No samples fell into the histogram, so there is nothing to relate.
            return 0.0 * self.score_sign

        p_xy /= total
        # Marginals of the normalized joint distribution already sum to one.
        p_x, p_y = p_xy.sum(axis=1), p_xy.sum(axis=0)

        # Empty bins contribute 0 * log(0) := 0. Restricting the sum to occupied
        # bins avoids evaluating log(0); the marginals are strictly positive
        # wherever the joint probability is.
        occupied = p_xy > 0
        independent = np.outer(p_x, p_y)
        log_ratio = np.log(p_xy[occupied] / independent[occupied])
        score = np.sum(p_xy[occupied] * log_ratio)

        return score * self.score_sign
|
927
|
+
|
928
|
+
|
929
|
+
class Envelope(_MatchCoordinatesToDensity):
    """
    The Envelope score (ENV) between the target :math:`f` and the
    template :math:`g` is calculated as:

    .. math::

        \\text{d(f,g)} = \\sum_{\\mathbf{p} \\in P} f'(\\mathbf{p})
        \\cdot g'(\\mathbf{p})

    References
    ----------
    .. [1] Daven Vasishtan and Maya Topf, "Scoring functions for cryoEM density
           fitting", Journal of Structural Biology, vol. 174, no. 2,
           pp. 333--343, 2011. DOI: https://doi.org/10.1016/j.jsb.2011.01.012
    """

    # Append the parameter documentation of the coordinate-to-density base class.
    __doc__ += _MatchCoordinatesToDensity.__doc__

    def __init__(self, target_threshold: float = None, **kwargs):
        super().__init__(**kwargs)

        # Fall back to the mean target value when no threshold was given.
        threshold = target_threshold
        if threshold is None:
            threshold = np.mean(self.target)

        # Binarize the target: -1 above the threshold, 1 everywhere else.
        self.target = np.where(self.target > threshold, -1, 1)
        self.target_present = np.sum(self.target == -1)
        self.target_absent = np.sum(self.target == 1)

        # Every template point carries the same unit weight for this score.
        self.template_weights = np.ones_like(self.template_weights)

    def __call__(self) -> float:
        """Returns the score of the current configuration."""
        values = self._target_values
        covered = (values == -1).sum()
        unassigned_density = self.target_present - covered

        # Out of volume values will be set to 0
        out_of_volume = np.sum(np.logical_not(np.abs(self._target_values) > 0))

        raw_score = values.sum() - unassigned_density
        raw_score = raw_score - 2 * out_of_volume

        # Rescale with the worst attainable raw score.
        min_score = -self.target_present - 2 * self.target_absent
        scaled = (raw_score - 2 * min_score) / (2 * self.target_present - min_score)

        return scaled * self.score_sign
|
969
|
+
|
970
|
+
|
971
|
+
class Chamfer(_MatchCoordinatesToCoordinates):
    """
    The Chamfer distance between the target :math:`f` and the template :math:`g`
    is calculated as:

    .. math::

        \\text{d(f,g)} = \\frac{1}{|X|} \\sum_{\\mathbf{f}_i \\in X}
        \\inf_{\\mathbf{g} \\in Y} ||\\mathbf{f}_i - \\mathbf{g}||_2

    References
    ----------
    .. [1] Daven Vasishtan and Maya Topf, "Scoring functions for cryoEM density
           fitting", Journal of Structural Biology, vol. 174, no. 2,
           pp. 333--343, 2011. DOI: https://doi.org/10.1016/j.jsb.2011.01.012
    """

    # Append the parameter documentation of the actual base class. The
    # coordinate-to-density documentation describes arguments this class does
    # not take. ``or ""`` keeps the import working should the base class carry
    # no docstring.
    __doc__ += _MatchCoordinatesToCoordinates.__doc__ or ""

    def _post_init(self, **kwargs):
        """Build the KD-tree over the target coordinates used by each query."""
        from scipy.spatial import KDTree

        # KDTree expects one point per row, hence the transposition.
        self.target_tree = KDTree(self.target_coordinates.T)

    def __call__(self) -> float:
        """Returns the score of the current configuration."""
        # Distance from every rotated template point to its nearest target point.
        dist, _ = self.target_tree.query(self.template_coordinates_rotated.T)
        score = np.mean(dist)
        return score * self.score_sign
|
1000
|
+
|
1001
|
+
|
1002
|
+
class NormalVectorScore(_MatchCoordinatesToCoordinates):
    """
    The Normal Vector Score (NVS) between the target's :math:`f` and the template
    :math:`g`'s normal vectors is calculated as:

    .. math::

        \\text{d(f,g)} = \\frac{1}{N} \\sum_{i=1}^{N}
            \\frac{
                {\\vec{f}_i} \\cdot {\\vec{g}_i}
            }{
                ||\\vec{f}_i|| \\, ||\\vec{g}_i||
            }

    References
    ----------
    .. [1] Daven Vasishtan and Maya Topf, "Scoring functions for cryoEM density
           fitting", Journal of Structural Biology, vol. 174, no. 2,
           pp. 333--343, 2011. DOI: https://doi.org/10.1016/j.jsb.2011.01.012

    """

    # Append the parameter documentation of the actual base class. The
    # coordinate-to-density documentation describes arguments this class does
    # not take. ``or ""`` keeps the import working should the base class carry
    # no docstring.
    __doc__ += _MatchCoordinatesToCoordinates.__doc__ or ""

    def __call__(self) -> float:
        """Returns the score of the current configuration."""
        template = self.template_coordinates_rotated
        target = self.target_coordinates

        # Coordinates are laid out as (d, N), matching the transposition Chamfer
        # applies before building its KD-tree, so per-vector quantities reduce
        # over axis 0.
        numerator = np.sum(np.multiply(template, target), axis=0)
        denominator = np.multiply(
            np.linalg.norm(template, axis=0), np.linalg.norm(target, axis=0)
        )

        # Mean cosine between corresponding vectors, as stated in the formula.
        score = np.mean(numerator / denominator)

        # Honour negate_score like every other score in this module.
        return score * self.score_sign
|
1035
|
+
|
1036
|
+
|
1037
|
+
# Maps the public name of a score to the class implementing it. Entries are
# looked up by create_score_object and can be added at runtime through
# register_matching_optimization.
MATCHING_OPTIMIZATION_REGISTER = {
    "CrossCorrelation": CrossCorrelation,
    "LaplaceCrossCorrelation": LaplaceCrossCorrelation,
    "NormalizedCrossCorrelationMean": NormalizedCrossCorrelationMean,
    "NormalizedCrossCorrelation": NormalizedCrossCorrelation,
    "MaskedCrossCorrelation": MaskedCrossCorrelation,
    "PartialLeastSquareDifference": PartialLeastSquareDifference,
    "Envelope": Envelope,
    "Chamfer": Chamfer,
    "MutualInformation": MutualInformation,
    "NormalVectorScore": NormalVectorScore,
    "FLC": FLC,
}
|
1050
|
+
|
1051
|
+
|
1052
|
+
def _defines_method(candidate: object, method: str) -> bool:
    """Return True if ``candidate`` provides ``method`` to its instances."""
    if isinstance(candidate, type):
        # hasattr on a class also finds attributes of its metaclass, e.g.
        # ``type.__call__``, which every class has. Inspect the MRO instead so
        # only methods available on instances of the class are counted.
        return any(method in vars(base) for base in candidate.__mro__)
    return hasattr(candidate, method)


def register_matching_optimization(match_name: str, match_class: type):
    """
    Registers a new matching method.

    Parameters
    ----------
    match_name : str
        Name of the matching instance.
    match_class : type
        Class pointer.

    Raises
    ------
    ValueError
        If any of the required methods is not defined.
    """
    methods_to_check = ["__init__", "__call__"]

    for method in methods_to_check:
        if not _defines_method(match_class, method):
            raise ValueError(
                f"Method '{method}' is not defined in the provided class or object."
            )
    MATCHING_OPTIMIZATION_REGISTER[match_name] = match_class
|
1076
|
+
|
1077
|
+
|
1078
|
+
def create_score_object(score: str, **kwargs) -> object:
    """
    Look up the score class registered as ``score`` and instantiate it with
    ``**kwargs``.

    Parameters
    ----------
    score: str
        Name of the score.
    **kwargs: Dict
        Keyword arguments passed to the __init__ method of the score object.

    Returns
    -------
    object
        Initialized score object.

    Raises
    ------
    ValueError
        If ``score`` is not a key in MATCHING_OPTIMIZATION_REGISTER.

    See Also
    --------
    :py:meth:`register_matching_optimization`

    Examples
    --------
    >>> from tme import Density
    >>> from tme.matching_utils import create_mask, euler_to_rotationmatrix
    >>> from tme.matching_optimization import CrossCorrelation, optimize_match
    >>> translation, rotation = (5, -2, 7), (5, -10, 2)
    >>> target = create_mask(
    ...     mask_type="ellipse",
    ...     radius=(5,5,5),
    ...     shape=(51,51,51),
    ...     center=(25,25,25),
    ... ).astype(float)
    >>> template = Density(data=target)
    >>> template = template.rigid_transform(
    ...     translation=translation,
    ...     rotation_matrix=euler_to_rotationmatrix(rotation),
    ... )
    >>> template_coordinates = template.to_pointcloud(0)
    >>> template_weights = template.data[tuple(template_coordinates)]
    >>> score_object = CrossCorrelation(
    ...     target=target,
    ...     template_coordinates=template_coordinates,
    ...     template_weights=template_weights,
    ...     negate_score=True # Multiply returned score with -1 for minimization
    ... )
    """
    score_class = MATCHING_OPTIMIZATION_REGISTER.get(score)
    if score_class is not None:
        return score_class(**kwargs)

    # Unknown name: list every registered score in the error message.
    available = ", ".join(MATCHING_OPTIMIZATION_REGISTER)
    raise ValueError(f"{score} is not defined. Please pick from  {available}.")
|
1140
|
+
|
1141
|
+
|
1142
|
+
def optimize_match(
    score_object: object,
    bounds_translation: Tuple[Tuple[float]] = None,
    bounds_rotation: Tuple[Tuple[float]] = None,
    optimization_method: str = "basinhopping",
    maxiter: int = 50,
    x0: Tuple[float] = None,
) -> Tuple[ArrayLike, ArrayLike, float]:
    """
    Find the translation and rotation optimizing the score returned by ``score_object``
    with respect to provided bounds.

    Parameters
    ----------
    score_object: object
        Class object that defines a score method, which returns a floating point
        value given a tuple of floating points where the first half describes a
        translation and the second a rotation. The score will be minimized, i.e.
        it has to be negated if similarity should be optimized.
    bounds_translation : tuple of tuple float, optional
        Bounds on the evaluated translations. Has to be specified per dimension
        as tuple of (min, max). Default is None.
    bounds_rotation : tuple of tuple float, optional
        Bounds on the evaluated zyz Euler angles. Has to be specified per dimension
        as tuple of (min, max). Default is None.
    optimization_method : str, optional
        Optimizer that will be used, basinhopping by default. For further
        information refer to :doc:`scipy:reference/optimize`.

        +------------------------+-------------------------------------------+
        | differential_evolution | Highest accuracy but long runtime.        |
        |                        | Requires bounds on translation.           |
        +------------------------+-------------------------------------------+
        | basinhopping           | Decent accuracy, medium runtime.          |
        +------------------------+-------------------------------------------+
        | minimize               | If initial values are close to optimum    |
        |                        | acceptable accuracy and short runtime     |
        +------------------------+-------------------------------------------+

    maxiter : int, optional
        The maximum number of iterations, 50 by default.
    x0 : tuple of floats, optional
        Initial values for the optimizer, zero by default.

    Returns
    -------
    Tuple[ArrayLike, ArrayLike, float]
        Optimal translation, rotation matrix and corresponding score. If the
        optimizer did not improve on the score at ``x0``, the translation and
        rotation of ``x0`` are returned together with the initial score.

    Raises
    ------
    ValueError
        If ``optimization_method`` is not supported.

    Notes
    -----
    This function currently only supports three-dimensional optimization and
    ``score_object`` will be modified during this operation.

    Examples
    --------
    Having defined ``score_object``, for instance via :py:meth:`create_score_object`,
    non-exhaustive template matching can be performed as follows

    >>> translation_fit, rotation_fit, score = optimize_match(score_object)

    `translation_fit` and `rotation_fit` correspond to the inverse of the applied
    translation and rotation, so the following statements should hold within tolerance

    >>> np.allclose(translation, -translation_fit, atol = 1) # True
    >>> np.allclose(rotation, np.linalg.inv(rotation_fit), rtol = .1) # True

    Bounds on translation and rotation can be defined as follows

    >>> translation_fit, rotation_fit, score = optimize_match(
    ...     score_object=score_object,
    ...     bounds_translation=((-5,5),(-2,2),(0,0)),
    ...     bounds_rotation=((-10,10), (-5,5), (0,0)),
    ... )

    The optimization scheme and the initial parameter estimates can also be adapted

    >>> translation_fit, rotation_fit, score = optimize_match(
    ...     score_object=score_object,
    ...     optimization_method="minimize",
    ...     x0=(0,0,0,5,3,-5),
    ... )

    """
    ndim = 3
    _optimization_method = {
        "differential_evolution": differential_evolution,
        "basinhopping": basinhopping,
        "minimize": minimize,
    }
    if optimization_method not in _optimization_method:
        raise ValueError(
            f"{optimization_method} is not supported. "
            f"Pick from {', '.join(list(_optimization_method.keys()))}"
        )

    finfo = np.finfo(np.float32)

    # DE always requires bounds
    if optimization_method == "differential_evolution" and bounds_translation is None:
        bounds_translation = tuple((finfo.min, finfo.max) for _ in range(ndim))

    # If only one kind of bound was given, fill in permissive defaults for the
    # other so a complete set of bounds can be assembled below.
    if bounds_translation is None and bounds_rotation is not None:
        bounds_translation = tuple((finfo.min, finfo.max) for _ in range(ndim))

    if bounds_rotation is None and bounds_translation is not None:
        bounds_rotation = tuple((-180, 180) for _ in range(ndim))

    bounds, linear_constraint = None, ()
    if bounds_rotation is not None and bounds_translation is not None:
        uncertainty = (*bounds_translation, *bounds_rotation)
        # A (0, 0) bound pins the parameter; widen it by the float32 resolution
        # so the optimizers receive a non-degenerate interval.
        bounds = [
            bound if bound != (0, 0) else (-finfo.resolution, finfo.resolution)
            for bound in uncertainty
        ]
        linear_constraint = LinearConstraint(
            np.eye(len(bounds)), np.min(bounds, axis=1), np.max(bounds, axis=1)
        )

    x0 = np.zeros(2 * ndim) if x0 is None else x0

    return_gradient = getattr(score_object, "return_gradient", False)
    if optimization_method != "minimize" and return_gradient:
        warnings.warn("Gradient only considered for optimization_method='minimize'.")
        score_object.return_gradient = False

    # With return_gradient enabled the score method yields (score, gradient).
    initial_score = score_object.score(x=x0)
    if isinstance(initial_score, (list, tuple)):
        initial_score = initial_score[0]

    if optimization_method == "basinhopping":
        result = basinhopping(
            x0=x0,
            func=score_object.score,
            niter=maxiter,
            minimizer_kwargs={"method": "COBYLA", "constraints": linear_constraint},
        )
    elif optimization_method == "differential_evolution":
        result = differential_evolution(
            func=score_object.score,
            bounds=bounds,
            constraints=linear_constraint,
            maxiter=maxiter,
        )
    elif optimization_method == "minimize":
        if hasattr(score_object, "grad") and not return_gradient:
            warnings.warn(
                "Consider initializing score object with return_gradient=True."
            )
        result = minimize(
            x0=x0,
            fun=score_object.score,
            jac=return_gradient,
            bounds=bounds,
            constraints=linear_constraint,
            options={"maxiter": maxiter},
        )

    print(f"Niter: {result.nit}, success : {result.success} ({result.message}).")
    print(f"Initial score: {initial_score} - Refined score: {result.fun}")

    final_score = result.fun
    if initial_score < final_score:
        # The initial score was evaluated at x0, which is only the identity if
        # the caller did not pass start values. Fall back to x0 and report the
        # score that actually belongs to it.
        print("Initial score better than refined score. Returning initial values.")
        result.x = np.asarray(x0, dtype=float)
        final_score = initial_score

    translation, rotation = result.x[:ndim], result.x[ndim:]
    rotation_matrix = euler_to_rotationmatrix(rotation)
    return translation, rotation_matrix, final_score
|