pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
  2. pytme-0.2.9.data/scripts/match_template.py +1135 -0
  3. pytme-0.2.9.data/scripts/postprocess.py +622 -0
  4. pytme-0.2.9.data/scripts/preprocess.py +209 -0
  5. pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
  6. pytme-0.2.9.dist-info/METADATA +95 -0
  7. pytme-0.2.9.dist-info/RECORD +119 -0
  8. pytme-0.2.9.dist-info/WHEEL +5 -0
  9. pytme-0.2.9.dist-info/entry_points.txt +6 -0
  10. pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
  11. pytme-0.2.9.dist-info/top_level.txt +3 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +97 -0
  14. scripts/match_template.py +1135 -0
  15. scripts/postprocess.py +622 -0
  16. scripts/preprocess.py +209 -0
  17. scripts/preprocessor_gui.py +1227 -0
  18. tests/__init__.py +0 -0
  19. tests/data/Blurring/blob_width18.npy +0 -0
  20. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  21. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  22. tests/data/Blurring/hamming_width6.npy +0 -0
  23. tests/data/Blurring/kaiserb_width18.npy +0 -0
  24. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  25. tests/data/Blurring/mean_size5.npy +0 -0
  26. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  27. tests/data/Blurring/rank_rank3.npy +0 -0
  28. tests/data/Maps/.DS_Store +0 -0
  29. tests/data/Maps/emd_8621.mrc.gz +0 -0
  30. tests/data/README.md +2 -0
  31. tests/data/Raw/em_map.map +0 -0
  32. tests/data/Structures/.DS_Store +0 -0
  33. tests/data/Structures/1pdj.cif +3339 -0
  34. tests/data/Structures/1pdj.pdb +1429 -0
  35. tests/data/Structures/5khe.cif +3685 -0
  36. tests/data/Structures/5khe.ent +2210 -0
  37. tests/data/Structures/5khe.pdb +2210 -0
  38. tests/data/Structures/5uz4.cif +70548 -0
  39. tests/preprocessing/__init__.py +0 -0
  40. tests/preprocessing/test_compose.py +76 -0
  41. tests/preprocessing/test_frequency_filters.py +178 -0
  42. tests/preprocessing/test_preprocessor.py +136 -0
  43. tests/preprocessing/test_utils.py +79 -0
  44. tests/test_analyzer.py +216 -0
  45. tests/test_backends.py +446 -0
  46. tests/test_density.py +503 -0
  47. tests/test_extensions.py +130 -0
  48. tests/test_matching_cli.py +283 -0
  49. tests/test_matching_data.py +162 -0
  50. tests/test_matching_exhaustive.py +124 -0
  51. tests/test_matching_memory.py +30 -0
  52. tests/test_matching_optimization.py +226 -0
  53. tests/test_matching_utils.py +189 -0
  54. tests/test_orientations.py +175 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_rotations.py +153 -0
  57. tests/test_structure.py +247 -0
  58. tme/__init__.py +6 -0
  59. tme/__version__.py +1 -0
  60. tme/analyzer/__init__.py +2 -0
  61. tme/analyzer/_utils.py +186 -0
  62. tme/analyzer/aggregation.py +577 -0
  63. tme/analyzer/peaks.py +953 -0
  64. tme/backends/__init__.py +171 -0
  65. tme/backends/_cupy_utils.py +734 -0
  66. tme/backends/_jax_utils.py +188 -0
  67. tme/backends/cupy_backend.py +294 -0
  68. tme/backends/jax_backend.py +314 -0
  69. tme/backends/matching_backend.py +1270 -0
  70. tme/backends/mlx_backend.py +241 -0
  71. tme/backends/npfftw_backend.py +583 -0
  72. tme/backends/pytorch_backend.py +430 -0
  73. tme/data/__init__.py +0 -0
  74. tme/data/c48n309.npy +0 -0
  75. tme/data/c48n527.npy +0 -0
  76. tme/data/c48n9.npy +0 -0
  77. tme/data/c48u1.npy +0 -0
  78. tme/data/c48u1153.npy +0 -0
  79. tme/data/c48u1201.npy +0 -0
  80. tme/data/c48u1641.npy +0 -0
  81. tme/data/c48u181.npy +0 -0
  82. tme/data/c48u2219.npy +0 -0
  83. tme/data/c48u27.npy +0 -0
  84. tme/data/c48u2947.npy +0 -0
  85. tme/data/c48u3733.npy +0 -0
  86. tme/data/c48u4749.npy +0 -0
  87. tme/data/c48u5879.npy +0 -0
  88. tme/data/c48u7111.npy +0 -0
  89. tme/data/c48u815.npy +0 -0
  90. tme/data/c48u83.npy +0 -0
  91. tme/data/c48u8649.npy +0 -0
  92. tme/data/c600v.npy +0 -0
  93. tme/data/c600vc.npy +0 -0
  94. tme/data/metadata.yaml +80 -0
  95. tme/data/quat_to_numpy.py +42 -0
  96. tme/data/scattering_factors.pickle +0 -0
  97. tme/density.py +2263 -0
  98. tme/extensions.cpython-311-darwin.so +0 -0
  99. tme/external/bindings.cpp +332 -0
  100. tme/filters/__init__.py +6 -0
  101. tme/filters/_utils.py +311 -0
  102. tme/filters/bandpass.py +230 -0
  103. tme/filters/compose.py +81 -0
  104. tme/filters/ctf.py +393 -0
  105. tme/filters/reconstruction.py +160 -0
  106. tme/filters/wedge.py +542 -0
  107. tme/filters/whitening.py +191 -0
  108. tme/matching_data.py +863 -0
  109. tme/matching_exhaustive.py +497 -0
  110. tme/matching_optimization.py +1311 -0
  111. tme/matching_scores.py +1183 -0
  112. tme/matching_utils.py +1188 -0
  113. tme/memory.py +337 -0
  114. tme/orientations.py +598 -0
  115. tme/parser.py +685 -0
  116. tme/preprocessor.py +1329 -0
  117. tme/rotations.py +350 -0
  118. tme/structure.py +1864 -0
  119. tme/types.py +13 -0
@@ -0,0 +1,1311 @@
1
+ """ Implements methods for non-exhaustive template matching.
2
+
3
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ import warnings
9
+ from typing import Tuple, List, Dict
10
+ from abc import ABC, abstractmethod
11
+
12
+ import numpy as np
13
+ from scipy.ndimage import laplace, map_coordinates, sobel
14
+ from scipy.optimize import (
15
+ minimize,
16
+ basinhopping,
17
+ LinearConstraint,
18
+ differential_evolution,
19
+ )
20
+
21
+ from .backends import backend as be
22
+ from .types import ArrayLike, NDArray
23
+ from .matching_data import MatchingData
24
+ from .rotations import euler_to_rotationmatrix
25
+ from .matching_utils import rigid_transform, normalize_template
26
+
27
+
28
def _format_rigid_transform(x: Tuple[float]) -> Tuple[ArrayLike, ArrayLike]:
    """
    Split a flat parameter vector into a translation and a rotation matrix.

    Parameters
    ----------
    x : tuple of float
        Even-length tuple where the first half represents translations and the
        second half Euler angles in zyz convention for each dimension.

    Returns
    -------
    Tuple[ArrayLike, ArrayLike]
        Translation of length [d, ] and rotation matrix with dimension [d x d],
        both converted to the active backend's array type.
    """
    n_translation = len(x) // 2

    # Angles are converted on the host (numpy) before building the matrix.
    angles = be.to_numpy_array(x[n_translation:])
    rotation_matrix = be.to_backend_array(euler_to_rotationmatrix(angles))

    return be.to_backend_array(x[:n_translation]), rotation_matrix
51
+
52
+
53
+ class _MatchDensityToDensity(ABC):
54
+ """
55
+ Parameters
56
+ ----------
57
+ target : array_like
58
+ The target density array.
59
+ template : array_like
60
+ The template density array.
61
+ template_mask : array_like, optional
62
+ Mask array for the template density.
63
+ target_mask : array_like, optional
64
+ Mask array for the target density.
65
+ pad_target_edges : bool, optional
66
+ Whether to pad the edges of the target density array. Default is False.
67
+ pad_fourier : bool, optional
68
+ Whether to pad the Fourier transform of the target and template densities.
69
+ rotate_mask : bool, optional
70
+ Whether to rotate the mask arrays along with the densities. Default is True.
71
+ interpolation_order : int, optional
72
+ The interpolation order for rigid transforms. Default is 1.
73
+ negate_score : bool, optional
74
+ Whether the final score should be multiplied by negative one. Default is True.
75
+ **kwargs : Dict, optional
76
+ Keyword arguments propagated to downstream functions.
77
+ """
78
+
79
+ def __init__(
80
+ self,
81
+ target: ArrayLike,
82
+ template: ArrayLike,
83
+ template_mask: ArrayLike = None,
84
+ target_mask: ArrayLike = None,
85
+ pad_target_edges: bool = False,
86
+ pad_fourier: bool = False,
87
+ rotate_mask: bool = True,
88
+ interpolation_order: int = 1,
89
+ negate_score: bool = True,
90
+ **kwargs: Dict,
91
+ ):
92
+ self.eps = be.eps(target.dtype)
93
+ self.rotate_mask = rotate_mask
94
+ self.interpolation_order = interpolation_order
95
+
96
+ matching_data = MatchingData(target=target, template=template)
97
+ if template_mask is not None:
98
+ matching_data.template_mask = template_mask
99
+ if target_mask is not None:
100
+ matching_data.target_mask = target_mask
101
+
102
+ self.target, self.target_mask = matching_data.target, matching_data.target_mask
103
+
104
+ self.template = matching_data._template
105
+ self.template_rot = be.zeros(template.shape, be._float_dtype)
106
+
107
+ self.template_mask, self.template_mask_rot = 1, 1
108
+ rotate_mask = False if matching_data._template_mask is None else rotate_mask
109
+ if matching_data.template_mask is not None:
110
+ self.template_mask = matching_data._template_mask
111
+ self.template_mask_rot = be.topleft_pad(
112
+ matching_data._template_mask, self.template_mask.shape
113
+ )
114
+
115
+ self.template_slices = tuple(slice(None) for _ in self.template.shape)
116
+ self.target_slices = tuple(slice(0, x) for x in self.template.shape)
117
+
118
+ self.score_sign = -1 if negate_score else 1
119
+
120
+ if hasattr(self, "_post_init"):
121
+ self._post_init(**kwargs)
122
+
123
+ def rotate_array(
124
+ self,
125
+ arr,
126
+ rotation_matrix,
127
+ translation,
128
+ arr_mask=None,
129
+ out=None,
130
+ out_mask=None,
131
+ order: int = 1,
132
+ **kwargs,
133
+ ):
134
+ rotate_mask = arr_mask is not None
135
+ return_type = (out is None) + 2 * rotate_mask * (out_mask is None)
136
+ translation = np.zeros(arr.ndim) if translation is None else translation
137
+
138
+ center = np.floor(np.array(arr.shape) / 2)[:, None]
139
+
140
+ if not hasattr(self, "_previous_center"):
141
+ self._previous_center = arr.shape
142
+
143
+ if not hasattr(self, "grid") or not np.allclose(self._previous_center, center):
144
+ self.grid = np.indices(arr.shape, dtype=np.float32).reshape(arr.ndim, -1)
145
+ np.subtract(self.grid, center, out=self.grid)
146
+ self.grid_out = np.zeros_like(self.grid)
147
+ self._previous_center = center
148
+
149
+ np.matmul(rotation_matrix.T, self.grid, out=self.grid_out)
150
+ translation = np.add(translation[:, None], center)
151
+ np.add(self.grid_out, translation, out=self.grid_out)
152
+
153
+ if out is None:
154
+ out = np.zeros_like(arr)
155
+
156
+ self._interpolate(arr, self.grid_out, order=order, out=out.ravel())
157
+
158
+ if out_mask is None and arr_mask is not None:
159
+ out_mask = np.zeros_like(arr_mask)
160
+
161
+ if arr_mask is not None:
162
+ self._interpolate(arr_mask, self.grid_out, order=order, out=out.ravel())
163
+
164
+ match return_type:
165
+ case 0:
166
+ return None
167
+ case 1:
168
+ return out
169
+ case 2:
170
+ return out_mask
171
+ case 3:
172
+ return out, out_mask
173
+
174
+ @staticmethod
175
+ def _interpolate(data, positions, order: int = 1, out=None):
176
+ return map_coordinates(
177
+ data, positions, order=order, mode="constant", output=out
178
+ )
179
+
180
+ def score_translation(self, x: Tuple[float]) -> float:
181
+ """
182
+ Computes the score after a given translation.
183
+
184
+ Parameters
185
+ ----------
186
+ x : tuple of float
187
+ Tuple representing the translation transformation in each dimension.
188
+
189
+ Returns
190
+ -------
191
+ float
192
+ The score obtained for the translation transformation.
193
+ """
194
+ return self.score((*x, *[0 for _ in range(len(x))]))
195
+
196
+ def score_angles(self, x: Tuple[float]) -> float:
197
+ """
198
+ Computes the score after a given rotation.
199
+
200
+ Parameters
201
+ ----------
202
+ x : tuple of float
203
+ Tuple of Euler angles in zyz convention for each dimension.
204
+
205
+ Returns
206
+ -------
207
+ float
208
+ The score obtained for the rotation transformation.
209
+ """
210
+ return self.score((*[0 for _ in range(len(x))], *x))
211
+
212
+ def score(self, x: Tuple[float]) -> float:
213
+ """
214
+ Compute the matching score for the given transformation parameters.
215
+
216
+ Parameters
217
+ ----------
218
+ x : tuple of float
219
+ Even-length tuple where the first half represents translations and the
220
+ second half Euler angles in zyz convention for each dimension.
221
+
222
+ Returns
223
+ -------
224
+ float
225
+ The matching score obtained for the transformation.
226
+ """
227
+ translation, rotation_matrix = _format_rigid_transform(x)
228
+ self.template_rot.fill(0)
229
+
230
+ voxel_translation = be.astype(translation, be._int_dtype)
231
+ subvoxel_translation = be.subtract(translation, voxel_translation)
232
+
233
+ center = be.astype(be.divide(self.template.shape, 2), be._int_dtype)
234
+ right_pad = be.subtract(self.template.shape, center)
235
+
236
+ translated_center = be.add(voxel_translation, center)
237
+
238
+ target_starts = be.subtract(translated_center, center)
239
+ target_stops = be.add(translated_center, right_pad)
240
+
241
+ template_starts = be.subtract(be.maximum(target_starts, 0), target_starts)
242
+ template_stops = be.subtract(
243
+ target_stops, be.minimum(target_stops, self.target.shape)
244
+ )
245
+ template_stops = be.subtract(self.template.shape, template_stops)
246
+
247
+ target_starts = be.maximum(target_starts, 0)
248
+ target_stops = be.minimum(target_stops, self.target.shape)
249
+
250
+ cand_start, cand_stop = template_starts.astype(int), template_stops.astype(int)
251
+ obs_start, obs_stop = target_starts.astype(int), target_stops.astype(int)
252
+
253
+ self.template_slices = tuple(slice(s, e) for s, e in zip(cand_start, cand_stop))
254
+ self.target_slices = tuple(slice(s, e) for s, e in zip(obs_start, obs_stop))
255
+
256
+ kw_dict = {
257
+ "arr": self.template,
258
+ "rotation_matrix": rotation_matrix,
259
+ "translation": subvoxel_translation,
260
+ "out": self.template_rot,
261
+ "order": self.interpolation_order,
262
+ "use_geometric_center": True,
263
+ }
264
+ if self.rotate_mask:
265
+ self.template_mask_rot.fill(0)
266
+ kw_dict["arr_mask"] = self.template_mask
267
+ kw_dict["out_mask"] = self.template_mask_rot
268
+
269
+ self.rotate_array(**kw_dict)
270
+
271
+ return self()
272
+
273
+ @abstractmethod
274
+ def __call__(self) -> float:
275
+ """Returns the score of the current configuration."""
276
+
277
+
278
class _MatchCoordinatesToDensity(_MatchDensityToDensity):
    """
    Parameters
    ----------
    target : NDArray
        A d-dimensional target to match the template coordinate set to.
    template_coordinates : NDArray
        Template coordinate array with shape (d,n).
    template_weights : NDArray
        Template weight array with shape (n,).
    template_mask_coordinates : NDArray, optional
        Template mask coordinates with shape (d,n).
    target_mask : NDArray, optional
        A d-dimensional mask to be applied to the target.
    negate_score : bool, optional
        Whether the final score should be multiplied by negative one. Default is True.
    return_gradient : bool, optional
        Whether invoking score returns a tuple of score and parameter gradient
        instead of only the score. Default is False.
    interpolation_order : int, optional
        Spline order used when sampling the target. Default is 1.
    **kwargs : Dict, optional
        Additional keyword arguments; not used by this class.
    """

    def __init__(
        self,
        target: NDArray,
        template_coordinates: NDArray,
        template_weights: NDArray,
        template_mask_coordinates: NDArray = None,
        target_mask: NDArray = None,
        negate_score: bool = True,
        return_gradient: bool = False,
        interpolation_order: int = 1,
        **kwargs: Dict,
    ):
        # Does not call the parent __init__; all state is set up here.
        self.target = target.astype(np.float32)
        self.target_mask = (
            target_mask.astype(np.float32) if target_mask is not None else None
        )
        self.eps = be.eps(self.target.dtype)

        # Sobel derivative of the target along each axis; target_grad[k] is the
        # derivative along axis k.
        self.target_grad = np.stack(
            [sobel(self.target, axis=axis) for axis in range(self.target.ndim)]
        )

        self.n_points = template_coordinates.shape[1]
        self.template = template_coordinates.astype(np.float32)
        self.template_rotated = np.zeros_like(self.template)
        self.template_weights = template_weights.astype(np.float32)
        self.template_center = np.mean(self.template, axis=1)[:, None]

        if template_mask_coordinates is None:
            self.template_mask = None
            self.template_mask_rotated = None
        else:
            self.template_mask = template_mask_coordinates.astype(np.float32)
            self.template_mask_rotated = np.empty_like(self.template_mask)

        self.denominator = 1
        self.score_sign = -1 if negate_score else 1
        self.interpolation_order = interpolation_order

        # Target sampled at the untransformed template positions.
        self._target_values = self._interpolate(
            self.target, self.template, order=self.interpolation_order
        )

        if return_gradient and not hasattr(self, "grad"):
            raise NotImplementedError(f"{type(self)} does not have grad method.")
        self.return_gradient = return_gradient

    def score(self, x: Tuple[float]):
        """
        Compute the matching score for the given transformation parameters.

        Parameters
        ----------
        x : tuple of float
            Even-length tuple where the first half represents translations and the
            second half Euler angles in zyz convention for each dimension.

        Returns
        -------
        float or tuple
            The matching score obtained for the transformation, or a tuple of
            score and gradient when ``return_gradient`` was set.
        """
        translation, rotation_matrix = _format_rigid_transform(x)

        rigid_transform(
            coordinates=self.template,
            coordinates_mask=self.template_mask,
            rotation_matrix=rotation_matrix,
            translation=translation,
            out=self.template_rotated,
            out_mask=self.template_mask_rotated,
            use_geometric_center=False,
        )

        self._target_values = self._interpolate(
            self.target, self.template_rotated, order=self.interpolation_order
        )

        current_score = self()
        if self.return_gradient:
            return current_score, self.grad()
        return current_score

    def _interpolate_gradient(self, positions):
        """Sample every component of the target gradient at ``positions``."""
        gradient = be.zeros(positions.shape, dtype=positions.dtype)
        for axis, component in enumerate(self.target_grad):
            gradient[axis, :] = self._interpolate(
                component, positions, order=self.interpolation_order
            )
        return gradient

    @staticmethod
    def _torques(positions, center, gradients):
        """Per-point cross product of center-relative positions with gradients."""
        relative = positions - center
        return be.cross(relative.T, gradients.T).T
397
+
398
+
399
class _MatchCoordinatesToCoordinates(_MatchDensityToDensity):
    """
    Parameters
    ----------
    target_coordinates : NDArray
        The coordinates of the target with shape [d x N].
    template_coordinates : NDArray
        The coordinates of the template with shape [d x T].
    target_weights : NDArray
        The weights of the target with shape [N].
    template_weights : NDArray
        The weights of the template with shape [T].
    template_mask_coordinates : NDArray, optional
        The coordinates of the template mask with shape [d x T]. Default is None.
    target_mask_coordinates : NDArray, optional
        The coordinates of the target mask with shape [d X N]. Default is None.
    negate_score : bool, optional
        Whether the final score should be multiplied by negative one. Default is True.
    **kwargs : Dict, optional
        Keyword arguments forwarded to ``_post_init`` when a subclass defines it.
    """

    def __init__(
        self,
        target_coordinates: NDArray,
        template_coordinates: NDArray,
        target_weights: NDArray,
        template_weights: NDArray,
        template_mask_coordinates: NDArray = None,
        target_mask_coordinates: NDArray = None,
        negate_score: bool = True,
        **kwargs,
    ):
        # Target side is stored as given.
        self.target_coordinates = target_coordinates
        self.target_weights = target_weights
        self.target_mask_coordinates = target_mask_coordinates

        # Template side gets float32 buffers for the transformed coordinates.
        self.template_coordinates = template_coordinates
        self.template_weights = template_weights
        self.template_coordinates_rotated = np.empty(
            template_coordinates.shape, dtype=np.float32
        )

        self.template_mask_coordinates = template_mask_coordinates
        if template_mask_coordinates is None:
            self.template_mask_coordinates_rotated = None
        else:
            self.template_mask_coordinates_rotated = np.empty(
                template_mask_coordinates.shape, dtype=np.float32
            )

        self.score_sign = -1 if negate_score else 1

        if hasattr(self, "_post_init"):
            self._post_init(**kwargs)

    def score(self, x: Tuple[float]) -> float:
        """
        Compute the matching score for the given transformation parameters.

        Parameters
        ----------
        x : tuple of float
            Even-length tuple where the first half represents translations and the
            second half Euler angles in zyz convention for each dimension.

        Returns
        -------
        float
            The matching score obtained for the transformation.
        """
        translation, rotation_matrix = _format_rigid_transform(x)

        rigid_transform(
            coordinates=self.template_coordinates,
            coordinates_mask=self.template_mask_coordinates,
            rotation_matrix=rotation_matrix,
            translation=translation,
            out=self.template_coordinates_rotated,
            out_mask=self.template_mask_coordinates_rotated,
            use_geometric_center=False,
        )

        transformed = self.template_coordinates_rotated
        transformed_mask = self.template_mask_coordinates_rotated
        return self(
            transformed_coordinates=transformed,
            transformed_coordinates_mask=transformed_mask,
        )
485
+
486
+
487
class FLC(_MatchDensityToDensity):
    """
    Computes a normalized cross-correlation score of a target f, a template g,
    and a mask m:

    .. math::

        \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
        {N_m * \\sqrt{
            \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
        }

    Where:

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)

    and Nm is the number of voxels within the template mask m.
    """

    __doc__ += _MatchDensityToDensity.__doc__

    def _post_init(self, **kwargs: Dict):
        # Apply the target mask once, in place, before caching f^2.
        if self.target_mask is not None:
            be.multiply(self.target, self.target_mask, out=self.target)

        self.target_square = be.square(self.target)

        normalize_template(
            template=self.template,
            mask=self.template_mask,
            n_observations=be.sum(self.template_mask),
        )

    def __call__(self) -> float:
        """Returns the score of the current configuration."""
        n_obs = be.sum(self.template_mask_rot)

        # An (almost) empty rotated mask leaves nothing to normalize over;
        # dividing by it would produce NaN/inf, which also slips past the
        # denominator check below and derails the optimizer.
        if n_obs < self.eps:
            return 0

        normalize_template(
            template=self.template_rot,
            mask=self.template_mask_rot,
            n_observations=n_obs,
        )
        overlap = be.sum(
            be.multiply(
                self.template_rot[self.template_slices], self.target[self.target_slices]
            )
        )

        # Local variance of the target under the rotated mask: E[f^2] - E[f]^2.
        mask_rot = self.template_mask_rot[self.template_slices]
        exp_sq = be.sum(self.target_square[self.target_slices] * mask_rot) / n_obs
        sq_exp = be.square(be.sum(self.target[self.target_slices] * mask_rot) / n_obs)

        denominator = be.maximum(be.subtract(exp_sq, sq_exp), 0.0)
        denominator = be.sqrt(denominator)
        if denominator < self.eps:
            return 0

        score = be.divide(overlap, denominator * n_obs) * self.score_sign
        return score
548
+
549
+
550
class CrossCorrelation(_MatchCoordinatesToDensity):
    """
    Computes the Cross-Correlation score as:

    .. math::

        \\text{score} = \\text{target_weights} \\cdot \\text{template_weights}
    """

    __doc__ += _MatchCoordinatesToDensity.__doc__

    def __call__(self) -> float:
        """Returns the score of the current configuration."""
        score = be.dot(self._target_values, self.template_weights)
        # Dividing by score_sign (+1/-1) is equivalent to multiplying by it.
        score /= self.denominator * self.score_sign
        return score

    def grad(self):
        """
        Calculate the gradient of the score w.r.t. translation and rotation.

        .. math::

            \\nabla f = \\frac{s}{N} \\begin{bmatrix}
            \\sum_i w_i \\nabla v(x_i) \\\\
            \\sum_i w_i (r_i \\times \\nabla v(x_i))
            \\end{bmatrix}

        where :math:`s` is the score sign (-1 when ``negate_score`` is True, the
        same factor applied in ``__call__``), :math:`N` is the number of points,
        :math:`w_i` are weights, :math:`x_i` are rotated template positions, and
        :math:`r_i` are positions relative to the template center.

        Returns
        -------
        np.ndarray
            Gradient carrying the score's sign: [dx, dy, dz, dRx, dRy, dRz].
        """
        grad = self._interpolate_gradient(positions=self.template_rotated)
        torque = self._torques(
            positions=self.template_rotated, gradients=grad, center=self.template_center
        )

        translation_grad = be.sum(grad * self.template_weights, axis=1)
        torque_grad = be.sum(torque * self.template_weights, axis=1)

        # <u, dv/dx> / <u, r x dv/dx>
        total_grad = be.concatenate([translation_grad, torque_grad])
        total_grad = be.divide(total_grad, self.n_points, out=total_grad)

        # Use the same sign as the score. The default (negate_score=True) keeps
        # the previous -total_grad result, while negate_score=False no longer
        # returns a gradient pointing against the score.
        return total_grad * self.score_sign
600
+
601
+
602
class LaplaceCrossCorrelation(CrossCorrelation):
    """
    Uses the same formalism as :py:class:`CrossCorrelation` but with Laplace
    filtered weights (:math:`\\nabla^{2}`):

    .. math::

        \\text{score} = \\nabla^{2} \\text{target_weights} \\cdot
        \\nabla^{2} \\text{template_weights}
    """

    __doc__ += _MatchCoordinatesToDensity.__doc__

    def __init__(
        self,
        target: NDArray,
        template_coordinates: NDArray,
        template_weights: NDArray,
        *args,
        **kwargs,
    ):
        # Explicit parameters accept the same positional call style as the
        # parent class. Remaining arguments are forwarded unchanged.
        target = laplace(target)

        # Rasterize the template weights onto the smallest integer grid that
        # covers the (truncated) coordinates, Laplace-filter that grid, and
        # read the filtered values back at the same positions.
        origin = template_coordinates.min(axis=1)
        positions = (template_coordinates - origin[:, None]).astype(int)
        shape = positions.max(axis=1) + 1
        arr = np.zeros(shape, dtype=np.float32)
        np.add.at(arr, tuple(positions), template_weights)

        filtered_weights = laplace(arr)[tuple(positions)]
        super().__init__(
            target, template_coordinates, filtered_weights, *args, **kwargs
        )
627
+
628
+
629
class NormalizedCrossCorrelation(CrossCorrelation):
    """
    Computes a normalized version of the :py:class:`CrossCorrelation` score based
    on the dot product of `target_weights` and `template_weights`, in order to
    reduce bias to regions of high local energy.

    .. math::

        \\text{score} = \\frac{\\text{target_weights} \\cdot \\text{template_weights}}
        {\\text{target_norm} \\times \\text{template_norm}}

    Where:

    .. math::

        \\text{target_norm} = ||\\text{target_weights}||

    .. math::

        \\text{template_norm} = ||\\text{template_weights}||

    Here, :math:`||.||` denotes the L2 (Euclidean) norm. The score is 0 when
    the denominator is not positive.
    """

    __doc__ += _MatchCoordinatesToDensity.__doc__

    def __call__(self) -> float:
        """Returns the score of the current configuration."""
        denominator = be.multiply(
            np.linalg.norm(self.template_weights), np.linalg.norm(self._target_values)
        )

        # All-zero weights or target values: no meaningful normalization.
        if denominator <= 0:
            return 0.0

        self.denominator = denominator
        return super().__call__()

    def grad(self):
        """
        Calculate the normalized gradient of the score w.r.t. translation and rotation.

        .. math::

            \\nabla f = \\frac{s}{N|w||v|^3} \\begin{bmatrix}
            (\\sum_i w_i \\nabla v(x_i))|v|^2 - (\\sum_i v(x_i)
            \\nabla v(x_i))(w \\cdot v) \\\\
            (\\sum_i w_i (r_i \\times \\nabla v(x_i)))|v|^2 - (\\sum_i v(x_i)
            (r_i \\times \\nabla v(x_i)))(w \\cdot v)
            \\end{bmatrix}

        where :math:`s` is the score sign (-1 when ``negate_score`` is True),
        :math:`N` is the number of points, :math:`w` are weights,
        :math:`v` are target values, :math:`x_i` are rotated template positions,
        and :math:`r_i` are positions relative to the template center.

        Returns
        -------
        np.ndarray
            Normalized gradient carrying the score's sign:
            [dx, dy, dz, dRx, dRy, dRz].
        """
        grad = self._interpolate_gradient(positions=self.template_rotated)
        torque = self._torques(
            positions=self.template_rotated, gradients=grad, center=self.template_center
        )

        norm = be.multiply(
            be.power(be.sqrt(be.sum(be.square(self._target_values))), 3),
            be.sqrt(be.sum(be.square(self.template_weights))),
        )

        # (<u,dv/dx> * |v|**2 - <u,v> * <v,dv/dx>)/(|w|*|v|**3)
        translation_grad = be.multiply(
            be.sum(be.multiply(grad, self.template_weights), axis=1),
            be.sum(be.square(self._target_values)),
        )
        translation_grad -= be.multiply(
            be.sum(be.multiply(grad, self._target_values), axis=1),
            be.sum(be.multiply(self._target_values, self.template_weights)),
        )

        # (<u,r x dv/dx> * |v|**2 - <u,v> * <v,r x dv/dx>)/(|w|*|v|**3)
        torque_grad = be.multiply(
            be.sum(be.multiply(torque, self.template_weights), axis=1),
            be.sum(be.square(self._target_values)),
        )
        torque_grad -= be.multiply(
            be.sum(be.multiply(torque, self._target_values), axis=1),
            be.sum(be.multiply(self._target_values, self.template_weights)),
        )

        total_grad = be.concatenate([translation_grad, torque_grad])
        if norm > 0:
            total_grad = be.divide(total_grad, norm, out=total_grad)

        total_grad = be.divide(total_grad, self.n_points, out=total_grad)

        # Match the sign applied to the score. The default negate_score=True
        # reproduces the previous -total_grad result.
        return total_grad * self.score_sign
725
+
726
+
727
class NormalizedCrossCorrelationMean(NormalizedCrossCorrelation):
    """
    Computes a similar score than :py:class:`NormalizedCrossCorrelation`, but
    additionally factors in the mean of template and target.

    .. math::

        \\text{score} = \\frac{(\\text{target_weights} - \\text{mean(target_weights)})
            \\cdot (\\text{template_weights} -
                \\text{mean(template_weights)})}
            {\\text{target_norm} \\times \\text{template_norm}}

    Where:

    .. math::

        \\text{target_norm} = ||\\text{target_weights} - \\text{mean(target_weights)}||

    .. math::

        \\text{template_norm} = ||\\text{template_weights} -
            \\text{mean(template_weights)}||

    Here, :math:`||.||` denotes the L2 (Euclidean) norm, and :math:`\\text{mean(.)}`
    computes the mean of the respective weights. The score is 0 when the
    denominator is not positive.
    """

    __doc__ += _MatchCoordinatesToDensity.__doc__

    def __init__(
        self,
        target: NDArray,
        template_coordinates: NDArray,
        template_weights: NDArray,
        *args,
        **kwargs,
    ):
        # Explicit parameters accept the same positional call style as the
        # parent class. Remaining arguments are forwarded unchanged.
        # Center the whole target and the template weights on their means.
        target = np.subtract(target, target.mean())
        template_weights = np.subtract(template_weights, template_weights.mean())
        super().__init__(
            target, template_coordinates, template_weights, *args, **kwargs
        )
762
+
763
+
764
class MaskedCrossCorrelation(_MatchCoordinatesToDensity):
    """
    Masked cross-correlation between `target_weights` and `template_weights`,
    evaluated only where the respective masks are defined. The score remains
    a meaningful similarity measure when parts of the data are missing or
    masked out.

    The score is assembled from the following terms:

    .. math::
        \\text{numerator} = \\text{dot}(\\text{target_weights},
        \\text{template_weights}) -
        \\frac{\\text{sum}(\\text{mask_target}) \\times
        \\text{sum}(\\text{mask_template})}
        {\\text{mask_overlap}}

    .. math::
        \\text{denominator1} = \\text{sum}(\\text{mask_target}^2) -
        \\frac{\\text{sum}(\\text{mask_target})^2}
        {\\text{mask_overlap}}

    .. math::
        \\text{denominator2} = \\text{sum}(\\text{mask_template}^2) -
        \\frac{\\text{sum}(\\text{mask_template})^2}
        {\\text{mask_overlap}}

    .. math::
        \\text{denominator} = \\sqrt{\\text{denominator1} \\times \\text{denominator2}}

    .. math::
        \\text{score} = \\frac{\\text{numerator}}{\\text{denominator}}
        \\text{ if denominator } \\neq 0
        \\text{ else } 0

    Where:

    - mask_target are the target values sampled at the (in-volume) rotated
      template mask coordinates, and mask_template are the template weights
      multiplied by the target mask sampled at the rotated template
      coordinates.

    - mask_overlap is the sum of the target mask over the rotated template
      mask coordinates, clipped from below to machine epsilon.

    - Negative values of denominator1 and denominator2 are clipped to zero
      before taking the square root.

    References
    ----------
    .. [1] Masked FFT registration, Dirk Padfield, CVPR 2010 conference
    """

    __doc__ += _MatchCoordinatesToDensity.__doc__

    def __call__(self) -> float:
        """Return the masked cross-correlation of the current configuration."""
        upper = np.array(self.target.shape)[:, None]

        def _inside(points):
            # Per column: True when every coordinate lies within [0, shape).
            return np.logical_and(points < upper, points >= 0).min(axis=0)

        keep = _inside(self.template_rotated)
        keep_mask = _inside(self.template_mask_rotated)

        # Integer voxel indices of the in-volume template / template-mask points.
        template_idx = tuple(self.template_rotated[:, keep].astype(int))
        mask_idx = tuple(self.template_mask_rotated[:, keep_mask].astype(int))

        overlap = np.fmax(np.sum(self.target_mask[mask_idx]), np.finfo(float).eps)

        target_under_mask = self.target[mask_idx]
        template_under_mask = np.multiply(
            self.template_weights[keep], self.target_mask[template_idx]
        )
        target_sum = np.sum(target_under_mask)
        template_sum = np.sum(template_under_mask)

        target_var = np.fmax(
            np.sum(target_under_mask**2) - np.square(target_sum) / overlap, 0.0
        )
        template_var = np.fmax(
            np.sum(template_under_mask**2) - np.square(template_sum) / overlap, 0.0
        )
        denominator = np.sqrt(target_var * template_var)

        numerator = np.dot(self.target[template_idx], self.template_weights[keep])
        numerator = numerator - (target_sum * template_sum) / overlap

        # Degenerate variance: no meaningful correlation can be formed.
        if denominator == 0:
            return 0.0
        return float((numerator / denominator) * self.score_sign)
865
+
866
+
867
class PartialLeastSquareDifference(_MatchCoordinatesToDensity):
    """
    The Partial Least Square Difference (PLSQ) between the target :math:`f` and the
    template :math:`g` is the sum of squared residuals:

    .. math::

        \\text{d(f,g)} = \\sum_{i=1}^{n} \\| f(\\mathbf{p}_i) - g(\\mathbf{q}_i) \\|^2

    References
    ----------
    .. [1] Daven Vasishtan and Maya Topf, "Scoring functions for cryoEM density
           fitting", Journal of Structural Biology, vol. 174, no. 2,
           pp. 333--343, 2011. DOI: https://doi.org/10.1016/j.jsb.2011.01.012
    """

    __doc__ += _MatchCoordinatesToDensity.__doc__

    def __call__(self) -> float:
        """Return the squared-residual sum, multiplied by ``score_sign``."""
        residuals = be.subtract(self._target_values, self.template_weights)
        squared_error = be.sum(be.square(residuals))
        return squared_error * self.score_sign
891
+
892
+
893
class MutualInformation(_MatchCoordinatesToDensity):
    """
    The Mutual Information (MI) score between the target :math:`f` and the
    template :math:`g` is calculated as:

    .. math::

        \\text{d(f,g)} = \\sum_{f,g} p(f,g) \\log \\frac{p(f,g)}{p(f)p(g)}

    The joint distribution :math:`p(f,g)` is estimated from a two-dimensional
    histogram (numpy's default binning) of the sampled target values and the
    template weights; the marginals are its row and column sums. Empty
    histogram cells do not contribute (:math:`0 \\log 0 := 0`).

    References
    ----------
    .. [1] Daven Vasishtan and Maya Topf, "Scoring functions for cryoEM density
           fitting", Journal of Structural Biology, vol. 174, no. 2,
           pp. 333--343, 2011. DOI: https://doi.org/10.1016/j.jsb.2011.01.012

    """

    __doc__ += _MatchCoordinatesToDensity.__doc__

    def __call__(self) -> float:
        """Returns the score of the current configuration."""
        joint, _, _ = np.histogram2d(self._target_values, self.template_weights)
        total = joint.sum()
        if total == 0:
            # No samples: the distributions are undefined; report no shared
            # information instead of propagating NaN from 0 / 0.
            return 0.0

        p_xy = joint / total
        p_x = p_xy.sum(axis=1)
        p_y = p_xy.sum(axis=0)

        # Restrict to occupied cells. Wherever p_xy > 0 both marginals are
        # strictly positive, so the ratio and its logarithm are finite.
        occupied = p_xy > 0
        observed = p_xy[occupied]
        expected = np.outer(p_x, p_y)[occupied]
        score = np.sum(observed * np.log(observed / expected))

        return score * self.score_sign
927
+
928
+
929
class Envelope(_MatchCoordinatesToDensity):
    """
    The Envelope score (ENV) between the target :math:`f` and the
    template :math:`g` is calculated as:

    .. math::

        \\text{d(f,g)} = \\sum_{\\mathbf{p} \\in P} f'(\\mathbf{p})
        \\cdot g'(\\mathbf{p})

    Here :math:`f'` is the binarized target: voxels above ``target_threshold``
    (the target mean by default) are set to -1 and all other voxels to 1.
    :math:`g'` are the template weights, which are all set to one.

    References
    ----------
    .. [1] Daven Vasishtan and Maya Topf, "Scoring functions for cryoEM density
           fitting", Journal of Structural Biology, vol. 174, no. 2,
           pp. 333--343, 2011. DOI: https://doi.org/10.1016/j.jsb.2011.01.012
    """

    __doc__ += _MatchCoordinatesToDensity.__doc__

    def __init__(self, target_threshold: float = None, **kwargs):
        super().__init__(**kwargs)
        # Binarize the target: voxels above the threshold become -1, all
        # remaining voxels 1. The threshold defaults to the target mean.
        if target_threshold is None:
            target_threshold = np.mean(self.target)
        self.target = np.where(self.target > target_threshold, -1, 1)
        # Number of voxels classified as density (-1) and as background (1).
        self.target_present = np.sum(self.target == -1)
        self.target_absent = np.sum(self.target == 1)
        # Only template occupancy matters for this score, not its weights.
        self.template_weights = np.ones_like(self.template_weights)

    def __call__(self) -> float:
        """Returns the score of the current configuration."""
        # NOTE(review): presumably the binarized target sampled at the rotated
        # template coordinates (-1, 1, or 0 for out-of-volume points per the
        # comment below) -- confirm against _MatchCoordinatesToDensity.
        score = self._target_values
        # Density voxel count minus template points that landed on density.
        unassigned_density = self.target_present - (score == -1).sum()

        # Out of volume values will be set to 0
        score = score.sum() - unassigned_density
        # Penalize every template point whose sampled value is zero.
        score -= 2 * np.sum(np.invert(np.abs(self._target_values) > 0))
        # NOTE(review): normalization constants kept as-is; not verified
        # against the reference publication.
        min_score = -self.target_present - 2 * self.target_absent
        score = (score - 2 * min_score) / (2 * self.target_present - min_score)

        return score * self.score_sign
969
+
970
+
971
class Chamfer(_MatchCoordinatesToCoordinates):
    """
    The Chamfer distance between the template :math:`g` and the target :math:`f`
    is the mean Euclidean distance from every rotated template coordinate to
    its nearest target coordinate:

    .. math::

        \\text{d(f,g)} = \\frac{1}{|G|} \\sum_{\\mathbf{g}_i \\in G}
        \\inf_{\\mathbf{f} \\in F} ||\\mathbf{g}_i - \\mathbf{f}||_2

    where :math:`G` are the rotated template coordinates and :math:`F` the
    target coordinates. Nearest neighbours are found with a KD-tree built
    over the target coordinates once at initialization.

    References
    ----------
    .. [1] Daven Vasishtan and Maya Topf, "Scoring functions for cryoEM density
           fitting", Journal of Structural Biology, vol. 174, no. 2,
           pp. 333--343, 2011. DOI: https://doi.org/10.1016/j.jsb.2011.01.012
    """

    # Append the documentation of the actual base class; guard against a
    # missing base docstring so class creation cannot fail.
    __doc__ += _MatchCoordinatesToCoordinates.__doc__ or ""

    def _post_init(self, **kwargs):
        from scipy.spatial import KDTree

        # Coordinates are stored one point per column, KDTree expects rows.
        self.target_tree = KDTree(self.target_coordinates.T)

    def __call__(self) -> float:
        """Returns the score of the current configuration."""
        dist, _ = self.target_tree.query(self.template_coordinates_rotated.T)
        score = np.mean(dist)
        return score * self.score_sign
1000
+
1001
+
1002
class NormalVectorScore(_MatchCoordinatesToCoordinates):
    """
    The Normal Vector Score (NVS) between the target's :math:`f` and the template
    :math:`g`'s normal vectors is calculated as:

    .. math::

        \\text{d(f,g)} = \\frac{1}{N} \\sum_{i=1}^{N}
        \\frac{
            {\\vec{f}_i} \\cdot {\\vec{g}_i}
        }{
            ||\\vec{f}_i|| \\, ||\\vec{g}_i||
        }

    i.e. the mean cosine similarity of corresponding vector pairs. The i-th
    template vector is paired with the i-th target vector; vectors are
    presumably stored column-wise with shape ``(d, N)``. Zero-length vectors
    contribute a similarity of zero.

    References
    ----------
    .. [1] Daven Vasishtan and Maya Topf, "Scoring functions for cryoEM density
           fitting", Journal of Structural Biology, vol. 174, no. 2,
           pp. 333--343, 2011. DOI: https://doi.org/10.1016/j.jsb.2011.01.012

    """

    # Append the documentation of the actual base class; guard against a
    # missing base docstring so class creation cannot fail.
    __doc__ += _MatchCoordinatesToCoordinates.__doc__ or ""

    def __call__(self) -> float:
        """Returns the score of the current configuration."""
        template = self.template_coordinates_rotated
        target = self.target_coordinates

        # Per-pair dot products and norms, reducing over the coordinate axis.
        dot = np.sum(np.multiply(template, target), axis=0)
        norms = np.multiply(
            np.linalg.norm(template, axis=0), np.linalg.norm(target, axis=0)
        )
        cosine = np.divide(dot, np.fmax(norms, np.finfo(float).eps))
        score = np.mean(cosine)
        # Honour negate_score like every other score in this module.
        return score * self.score_sign
1035
+
1036
+
1037
# Public score name -> implementing class. Insertion order is preserved and is
# the order in which names are listed when an unknown score is requested.
MATCHING_OPTIMIZATION_REGISTER = dict(
    CrossCorrelation=CrossCorrelation,
    LaplaceCrossCorrelation=LaplaceCrossCorrelation,
    NormalizedCrossCorrelationMean=NormalizedCrossCorrelationMean,
    NormalizedCrossCorrelation=NormalizedCrossCorrelation,
    MaskedCrossCorrelation=MaskedCrossCorrelation,
    PartialLeastSquareDifference=PartialLeastSquareDifference,
    Envelope=Envelope,
    Chamfer=Chamfer,
    MutualInformation=MutualInformation,
    NormalVectorScore=NormalVectorScore,
    FLC=FLC,
)
1050
+
1051
+
1052
def register_matching_optimization(match_name: str, match_class: type):
    """
    Registers a new matching method.

    Parameters
    ----------
    match_name : str
        Name of the matching instance.
    match_class : type
        Class pointer. An instance is also accepted, in which case its class
        is inspected.

    Raises
    ------
    ValueError
        If any of the required methods is not defined.
    """
    # ``hasattr(SomeClass, "__call__")`` is always True because the metaclass
    # ``type`` is callable. Look the methods up in the class's own MRO instead,
    # which excludes the metaclass. ``__init__`` is always satisfied through
    # ``object``, matching the previous behaviour for that method.
    klass = match_class if isinstance(match_class, type) else type(match_class)
    methods_to_check = ["__init__", "__call__"]

    for method in methods_to_check:
        if not any(method in vars(base) for base in klass.__mro__):
            raise ValueError(
                f"Method '{method}' is not defined in the provided class or object."
            )
    MATCHING_OPTIMIZATION_REGISTER[match_name] = match_class
1076
+
1077
+
1078
def create_score_object(score: str, **kwargs) -> object:
    """
    Initialize score object with name ``score`` using ``**kwargs``.

    Parameters
    ----------
    score: str
        Name of the score, i.e. a key of MATCHING_OPTIMIZATION_REGISTER.
    **kwargs: Dict
        Keyword arguments passed to the __init__ method of the score object.

    Returns
    -------
    object
        Initialized score object.

    Raises
    ------
    ValueError
        If ``score`` is not a key in MATCHING_OPTIMIZATION_REGISTER.

    See Also
    --------
    :py:meth:`register_matching_optimization`

    Examples
    --------
    >>> from tme import Density
    >>> from tme.matching_utils import create_mask, euler_to_rotationmatrix
    >>> from tme.matching_optimization import create_score_object
    >>> translation, rotation = (5, -2, 7), (5, -10, 2)
    >>> target = create_mask(
    ...     mask_type="ellipse",
    ...     radius=(5, 5, 5),
    ...     shape=(51, 51, 51),
    ...     center=(25, 25, 25),
    ... ).astype(float)
    >>> template = Density(data=target)
    >>> template = template.rigid_transform(
    ...     translation=translation,
    ...     rotation_matrix=euler_to_rotationmatrix(rotation),
    ... )
    >>> template_coordinates = template.to_pointcloud(0)
    >>> template_weights = template.data[tuple(template_coordinates)]
    >>> score_object = create_score_object(
    ...     score="CrossCorrelation",
    ...     target=target,
    ...     template_coordinates=template_coordinates,
    ...     template_weights=template_weights,
    ...     negate_score=True,  # Multiply returned score with -1 for minimization
    ... )
    """
    score_class = MATCHING_OPTIMIZATION_REGISTER.get(score, None)

    if score_class is None:
        raise ValueError(
            f"{score} is not defined. Please pick from "
            f"{', '.join(MATCHING_OPTIMIZATION_REGISTER.keys())}."
        )

    return score_class(**kwargs)
1140
+
1141
+
1142
def optimize_match(
    score_object: object,
    bounds_translation: Tuple[Tuple[float]] = None,
    bounds_rotation: Tuple[Tuple[float]] = None,
    optimization_method: str = "basinhopping",
    maxiter: int = 50,
    x0: Tuple[float] = None,
) -> Tuple[ArrayLike, ArrayLike, float]:
    """
    Find the translation and rotation optimizing the score returned by ``score_object``
    with respect to provided bounds.

    Parameters
    ----------
    score_object: object
        Class object that defines a score method, which returns a floating point
        value given a tuple of floating points where the first half describes a
        translation and the second a rotation. The score will be minimized, i.e.
        it has to be negated if similarity should be optimized.
    bounds_translation : tuple of tuple float, optional
        Bounds on the evaluated translations. Has to be specified per dimension
        as tuple of (min, max). Default is None.
    bounds_rotation : tuple of tuple float, optional
        Bounds on the evaluated zyz Euler angles. Has to be specified per dimension
        as tuple of (min, max). Default is None.
    optimization_method : str, optional
        Optimizer that will be used, basinhopping by default. For further
        information refer to :doc:`scipy:reference/optimize`.

        +------------------------+-------------------------------------------+
        | differential_evolution | Highest accuracy but long runtime.        |
        |                        | Requires bounds on translation.           |
        +------------------------+-------------------------------------------+
        | basinhopping           | Decent accuracy, medium runtime.          |
        +------------------------+-------------------------------------------+
        | minimize               | If initial values are close to optimum    |
        |                        | acceptable accuracy and short runtime     |
        +------------------------+-------------------------------------------+

    maxiter : int, optional
        The maximum number of iterations, 50 by default.
    x0 : tuple of floats, optional
        Initial values for the optimizer, zero by default.

    Returns
    -------
    Tuple[ArrayLike, ArrayLike, float]
        Optimal translation, rotation matrix and corresponding score. If the
        optimizer does not improve on the score at ``x0``, the transformation
        described by ``x0`` and its score are returned instead.

    Raises
    ------
    ValueError
        If ``optimization_method`` is not supported.

    Notes
    -----
    This function currently only supports three-dimensional optimization and
    ``score_object`` will be modified during this operation.

    Examples
    --------
    Having defined ``score_object``, for instance via :py:meth:`create_score_object`,
    non-exhaustive template matching can be performed as follows

    >>> translation_fit, rotation_fit, score = optimize_match(score_object)

    `translation_fit` and `rotation_fit` correspond to the inverse of the applied
    translation and rotation, so the following statements should hold within tolerance

    >>> np.allclose(translation, -translation_fit, atol = 1) # True
    >>> np.allclose(rotation, np.linalg.inv(rotation_fit), rtol = .1) # True

    Bounds on translation and rotation can be defined as follows

    >>> translation_fit, rotation_fit, score = optimize_match(
    ...     score_object=score_object,
    ...     bounds_translation=((-5,5),(-2,2),(0,0)),
    ...     bounds_rotation=((-10,10), (-5,5), (0,0)),
    ... )

    The optimization scheme and the initial parameter estimates can also be adapted

    >>> translation_fit, rotation_fit, score = optimize_match(
    ...     score_object=score_object,
    ...     optimization_method="minimize",
    ...     x0=(0,0,0,5,3,-5),
    ... )

    """
    ndim = 3
    _optimization_method = {
        "differential_evolution": differential_evolution,
        "basinhopping": basinhopping,
        "minimize": minimize,
    }
    if optimization_method not in _optimization_method:
        raise ValueError(
            f"{optimization_method} is not supported. "
            f"Pick from {', '.join(list(_optimization_method.keys()))}"
        )

    finfo = np.finfo(np.float32)

    # DE always requires bounds
    if optimization_method == "differential_evolution" and bounds_translation is None:
        bounds_translation = tuple((finfo.min, finfo.max) for _ in range(ndim))

    if bounds_translation is None and bounds_rotation is not None:
        bounds_translation = tuple((finfo.min, finfo.max) for _ in range(ndim))

    if bounds_rotation is None and bounds_translation is not None:
        bounds_rotation = tuple((-180, 180) for _ in range(ndim))

    bounds, linear_constraint = None, ()
    if bounds_rotation is not None and bounds_translation is not None:
        uncertainty = (*bounds_translation, *bounds_rotation)
        # Degenerate (0, 0) bounds are replaced by a tiny interval around zero.
        # tuple() so list-typed bounds such as [0, 0] are recognized as well.
        bounds = [
            bound
            if tuple(bound) != (0, 0)
            else (-finfo.resolution, finfo.resolution)
            for bound in uncertainty
        ]
        linear_constraint = LinearConstraint(
            np.eye(len(bounds)), np.min(bounds, axis=1), np.max(bounds, axis=1)
        )

    x0 = np.zeros(2 * ndim) if x0 is None else x0

    return_gradient = getattr(score_object, "return_gradient", False)
    if optimization_method != "minimize" and return_gradient:
        warnings.warn("Gradient only considered for optimization_method='minimize'.")
        score_object.return_gradient = False

    # With return_gradient the score method yields (value, gradient).
    initial_score = score_object.score(x=x0)
    if isinstance(initial_score, (list, tuple)):
        initial_score = initial_score[0]

    if optimization_method == "basinhopping":
        result = basinhopping(
            x0=x0,
            func=score_object.score,
            niter=maxiter,
            minimizer_kwargs={"method": "COBYLA", "constraints": linear_constraint},
        )
    elif optimization_method == "differential_evolution":
        result = differential_evolution(
            func=score_object.score,
            bounds=bounds,
            constraints=linear_constraint,
            maxiter=maxiter,
        )
    elif optimization_method == "minimize":
        if hasattr(score_object, "grad") and not return_gradient:
            warnings.warn(
                "Consider initializing score object with return_gradient=True."
            )
        result = minimize(
            x0=x0,
            fun=score_object.score,
            jac=return_gradient,
            bounds=bounds,
            constraints=linear_constraint,
            options={"maxiter": maxiter},
        )
    print(f"Niter: {result.nit}, success : {result.success} ({result.message}).")
    print(f"Initial score: {initial_score} - Refined score: {result.fun}")
    if initial_score < result.fun:
        # Fall back to the starting point, and report the score that actually
        # belongs to it, so transformation and score stay consistent.
        print("Initial score better than refined score. Returning initial values.")
        result.x = np.array(x0, dtype=float)
        result.fun = initial_score
    translation, rotation = result.x[:ndim], result.x[ndim:]
    rotation_matrix = euler_to_rotationmatrix(rotation)
    return translation, rotation_matrix, result.fun