pytme 0.1.5__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (63)
  1. pytme-0.1.5.data/scripts/estimate_ram_usage.py +81 -0
  2. pytme-0.1.5.data/scripts/match_template.py +744 -0
  3. pytme-0.1.5.data/scripts/postprocess.py +279 -0
  4. pytme-0.1.5.data/scripts/preprocess.py +93 -0
  5. pytme-0.1.5.data/scripts/preprocessor_gui.py +729 -0
  6. pytme-0.1.5.dist-info/LICENSE +153 -0
  7. pytme-0.1.5.dist-info/METADATA +69 -0
  8. pytme-0.1.5.dist-info/RECORD +63 -0
  9. pytme-0.1.5.dist-info/WHEEL +5 -0
  10. pytme-0.1.5.dist-info/entry_points.txt +6 -0
  11. pytme-0.1.5.dist-info/top_level.txt +2 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +81 -0
  14. scripts/match_template.py +744 -0
  15. scripts/match_template_devel.py +788 -0
  16. scripts/postprocess.py +279 -0
  17. scripts/preprocess.py +93 -0
  18. scripts/preprocessor_gui.py +729 -0
  19. tme/__init__.py +6 -0
  20. tme/__version__.py +1 -0
  21. tme/analyzer.py +1144 -0
  22. tme/backends/__init__.py +134 -0
  23. tme/backends/cupy_backend.py +309 -0
  24. tme/backends/matching_backend.py +1154 -0
  25. tme/backends/npfftw_backend.py +763 -0
  26. tme/backends/pytorch_backend.py +526 -0
  27. tme/data/__init__.py +0 -0
  28. tme/data/c48n309.npy +0 -0
  29. tme/data/c48n527.npy +0 -0
  30. tme/data/c48n9.npy +0 -0
  31. tme/data/c48u1.npy +0 -0
  32. tme/data/c48u1153.npy +0 -0
  33. tme/data/c48u1201.npy +0 -0
  34. tme/data/c48u1641.npy +0 -0
  35. tme/data/c48u181.npy +0 -0
  36. tme/data/c48u2219.npy +0 -0
  37. tme/data/c48u27.npy +0 -0
  38. tme/data/c48u2947.npy +0 -0
  39. tme/data/c48u3733.npy +0 -0
  40. tme/data/c48u4749.npy +0 -0
  41. tme/data/c48u5879.npy +0 -0
  42. tme/data/c48u7111.npy +0 -0
  43. tme/data/c48u815.npy +0 -0
  44. tme/data/c48u83.npy +0 -0
  45. tme/data/c48u8649.npy +0 -0
  46. tme/data/c600v.npy +0 -0
  47. tme/data/c600vc.npy +0 -0
  48. tme/data/metadata.yaml +80 -0
  49. tme/data/quat_to_numpy.py +42 -0
  50. tme/data/scattering_factors.pickle +0 -0
  51. tme/density.py +2314 -0
  52. tme/extensions.cpython-311-darwin.so +0 -0
  53. tme/helpers.py +881 -0
  54. tme/matching_data.py +377 -0
  55. tme/matching_exhaustive.py +1553 -0
  56. tme/matching_memory.py +382 -0
  57. tme/matching_optimization.py +1123 -0
  58. tme/matching_utils.py +1180 -0
  59. tme/parser.py +429 -0
  60. tme/preprocessor.py +1291 -0
  61. tme/scoring.py +866 -0
  62. tme/structure.py +1428 -0
  63. tme/types.py +10 -0
tme/scoring.py ADDED
@@ -0,0 +1,866 @@
1
+ from copy import deepcopy
2
+ from typing import Tuple, Callable, Dict
3
+ from joblib import Parallel, delayed
4
+
5
+ import numpy as np
6
+
7
+ from . import Preprocessor
8
+ from .matching_data import MatchingData
9
+ from .backends import backend
10
+ from .types import NDArray, CallbackClass
11
+ from .matching_memory import CCMemoryUsage
12
+ from .analyzer import MaxScoreOverRotations
13
+ from .matching_exhaustive import register_matching_exhaustive, device_memory_handler
14
+ from .matching_utils import apply_convolution_mode, conditional_execute
15
+
16
+
17
+
18
+ from scipy.interpolate import RegularGridInterpolator
19
class ExtractProjection:
    """
    Sample central slices of a volume's Fourier transform.

    By the Fourier slice theorem, the central slice of a volume's Fourier
    transform taken perpendicular to its last axis equals the Fourier
    transform of the volume's projection (sum) along that axis. Rotating the
    slice before sampling yields projections along other directions. Samples
    are obtained by interpolating the zero-frequency-centered Fourier
    transform with :py:class:`scipy.interpolate.RegularGridInterpolator`.

    Parameters
    ----------
    data : NDArray
        Real-space volume, or its ``fftshift``-ed Fourier transform. Arrays
        with a complex dtype are assumed to already be the centered Fourier
        transform and are used as-is.
    interpolation_method : str, optional
        Interpolation method of the interpolator, ``"linear"`` by default.
    """

    def __init__(self, data: NDArray, interpolation_method: str = "linear"):
        # Check the dtype rather than every element: np.iscomplex is False for
        # entries whose imaginary part is exactly zero (e.g. the zero frequency
        # of a real signal), which caused genuine Fourier transforms to be
        # transformed a second time.
        if not np.iscomplexobj(data):
            data = np.fft.fftshift(np.fft.fftn(data))

        self.create_point_cloud(data.shape)
        # Normalized position of the zero frequency inside the interpolation
        # grid. Rotated sample offsets are added to it, so slices stay central
        # even if create_point_cloud is later called with a different shape.
        self._grid_center = np.copy(self._point_cloud_center)

        # Node i of an axis of length n sits at i / n, the same normalization
        # create_point_cloud applies to the sampling points, so unrotated
        # samples coincide with grid nodes.
        self._interpolator = RegularGridInterpolator(
            tuple(np.arange(n) / n for n in data.shape),
            data,
            method=interpolation_method,
            bounds_error=False,
            fill_value=0,
        )

    def __call__(
        self,
        rotation_matrix: NDArray,
        return_rfft: bool = False,
        center_zero_frequency: bool = False,
    ) -> NDArray:
        """
        Return the central slice for ``rotation_matrix``.

        Parameters
        ----------
        rotation_matrix : NDArray
            Square matrix with one row per data dimension, applied to the
            sampling points of the slice.
        return_rfft : bool, optional
            Keep only the first ``n // 2 + 1`` entries of the last slice axis.
        center_zero_frequency : bool, optional
            Keep the zero frequency at the slice center. By default it is
            moved to index 0 with ``ifftshift``.

        Returns
        -------
        NDArray
            Slice shaped like all but the last axis of the shape most recently
            passed to :py:meth:`create_point_cloud` (last axis truncated when
            ``return_rfft`` is set). Points outside the grid evaluate to 0.
        """
        self._rotate_points(rotation_matrix=rotation_matrix)
        fourier_slice = self._interpolator(self._point_cloud_transform.T)
        fourier_slice = fourier_slice.reshape(self._data_shape[:-1])

        if not center_zero_frequency:
            fourier_slice = np.fft.ifftshift(fourier_slice)

        if return_rfft:
            cutoff = fourier_slice.shape[-1] // 2 + 1
            fourier_slice = fourier_slice[..., :cutoff]

        return fourier_slice

    def create_point_cloud(self, shape: Tuple[int, ...]) -> None:
        """
        Prepare the sampling points of the central slice for ``shape``.

        The slice spans every index of the leading axes at index
        ``shape[-1] // 2`` of the last axis. Coordinates are divided by
        ``shape``, which allows sampling a slice of one size from a grid of
        another size.

        Parameters
        ----------
        shape : tuple of int
            Shape of the zero-frequency-centered volume the slice is cut from.
        """
        leading = np.ones(shape[:-1])
        point_cloud = np.vstack(
            [
                np.array(np.where(leading > 0)),
                np.full(leading.size, fill_value=shape[-1] // 2),
            ]
        )
        self._data_shape = np.array(shape)
        point_cloud = np.divide(point_cloud, self._data_shape[..., None])

        # Index of the zero frequency along each axis after fftshift.
        self._ifft_shift = (self._data_shape // 2)[..., None]
        # Rotate about the zero frequency instead of the mean sampling
        # position: for even sizes they differ by half a voxel, and only a
        # rotation about the origin keeps the slice central.
        self._point_cloud_center = np.divide(
            self._ifft_shift, self._data_shape[..., None]
        )
        self._point_cloud = np.subtract(point_cloud, self._point_cloud_center)
        self._point_cloud_transform = np.empty(
            self._point_cloud.shape, dtype=np.float32
        )

    def _rotate_points(self, rotation_matrix: NDArray) -> None:
        """Rotate the centered sampling points and place them around the
        zero frequency of the interpolation grid (in-place buffer update)."""
        np.matmul(rotation_matrix, self._point_cloud, out=self._point_cloud_transform)
        np.add(
            self._point_cloud_transform,
            self._grid_center,
            out=self._point_cloud_transform,
        )
81
+
82
def corr2_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: NDArray,
    template_mask: NDArray,
    target: NDArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    real_dtype: type,
    complex_dtype: type,
    shared_memory_handler: Callable,
    callback_class: Callable,
    callback_class_args: Dict,
    **kwargs,
) -> Dict:
    """
    Prepare shared inputs for :py:meth:`corr2_scoring`.

    The target is zero-padded to ``fast_shape``, and the real-input Fourier
    transforms of the padded target and of its square are computed. Both
    transforms, the template and the template mask are then moved to shared
    memory so that the per-rotation scoring jobs can reconstruct them.

    Per rotation, :py:meth:`corr2_scoring` scores the projection p of the
    rotated template along its last axis against the target f as

        (f (*) p - mean(p) * (f (*) w))
        / sqrt(SSD(p) * ((f^2 (*) w) - (f (*) w)^2 / N))

    where ``a (*) b`` is the inverse FFT of the product of the FFTs of a
    and b, w is a box window of the projection's extent and N its voxel
    count.

    References
    ----------
    .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
           Microsc. Microanal. 26, 2516 (2020)
    .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
           and F. Förster, J. Struct. Biol. 178, 177 (2012).

    Returns
    -------
    Dict
        Keyword arguments for :py:meth:`corr2_scoring`. Array entries are
        ``(shared buffer, shape, dtype)`` tuples.
    """
    padded_target = backend.topleft_pad(target, fast_shape)

    # Spectra of the target and of the squared target; their window sums
    # are formed in the scoring function.
    target_spectrum = backend.preallocate_array(fast_ft_shape, complex_dtype)
    target_sq_spectrum = backend.preallocate_array(fast_ft_shape, complex_dtype)
    rfftn(padded_target, target_spectrum)
    rfftn(backend.square(padded_target), target_sq_spectrum)

    def _share(array):
        # Move an array into memory that parallel scoring jobs can access.
        return backend.arr_to_sharedarr(
            arr=array, shared_memory_handler=shared_memory_handler
        )

    target_spectrum_buffer = _share(target_spectrum)
    target_sq_spectrum_buffer = _share(target_sq_spectrum)
    template_buffer = _share(template)
    template_mask_buffer = _share(template_mask)

    return {
        "template": (template_buffer, template.shape, real_dtype),
        "template_mask": (template_mask_buffer, template_mask.shape, real_dtype),
        "ft_target": (target_spectrum_buffer, fast_ft_shape, complex_dtype),
        "ft_target2": (target_sq_spectrum_buffer, fast_ft_shape, complex_dtype),
        "targetshape": target.shape,
        "templateshape": template.shape,
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "real_dtype": real_dtype,
        "complex_dtype": complex_dtype,
        "callback_class": callback_class,
        "callback_class_args": callback_class_args,
    }
172
+
173
+
174
def corr2_scoring(
    template: Tuple[type, Tuple[int], type],
    template_mask: Tuple[type, Tuple[int], type],
    ft_target: Tuple[type, Tuple[int], type],
    ft_target2: Tuple[type, Tuple[int], type],
    template_filter: Tuple[type, Tuple[int], type],
    targetshape: Tuple[int],
    templateshape: Tuple[int],
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: NDArray,
    real_dtype: type,
    complex_dtype: type,
    callback_class: CallbackClass,
    callback_class_args: Dict,
    interpolation_order: int,
    **kwargs,
) -> CallbackClass:
    """
    Score rotations of a template by correlating its projections with the target.

    For each rotation, the template is rotated with spline order
    ``interpolation_order`` and summed along its last axis. The resulting
    projection p is placed in plane 0 of the last axis of a zero array of
    ``fast_shape`` and scored as

        (f (*) p - mean(p) * (f (*) w))
        / sqrt(max(SSD(p) * ((f^2 (*) w) - (f (*) w)^2 / N), 0))

    where ``a (*) b`` is the inverse FFT of the product of the FFTs, w is a
    box window covering the template's leading axes with length one along
    the last axis, and N is the voxel count of the projection within that
    window. mean(p) and SSD(p) are also taken within that window. Positions
    where the denominator does not exceed machine epsilon score 0.

    Only the projection statistics depend on the rotation. The window sums
    of the target are therefore computed once, before the rotation loop.

    Parameters
    ----------
    template, ft_target, ft_target2 : tuple
        ``(shared buffer, shape, dtype)`` of the template and of the spectra
        of the padded target and squared target, as produced by
        :py:meth:`corr2_setup`.
    template_mask, template_filter : tuple
        Accepted for interface compatibility; not used by this score.
    rotations : NDArray
        Rotation matrices, indexed along the first axis.
    callback_class : CallbackClass
        Class (instantiated with ``callback_class_args``) or instance that
        receives each score, or None.
    convolution_mode : str, optional (via kwargs)
        Forwarded to ``apply_convolution_mode``, ``"full"`` by default.

    Returns
    -------
    CallbackClass
        The callback instance, or None if ``callback_class`` is None.
    """
    template_buffer, template_shape, template_dtype = template
    ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
    ft_target2_buffer, *_ = ft_target2

    callback = callback_class
    if isinstance(callback_class, type):
        callback = callback_class(**callback_class_args)

    # Retrieve objects from shared memory
    template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
    ft_target = backend.sharedarr_to_arr(
        ft_target_shape, ft_target_dtype, ft_target_buffer
    )
    ft_target2 = backend.sharedarr_to_arr(
        ft_target_shape, ft_target_dtype, ft_target2_buffer
    )

    arr = backend.preallocate_array(fast_shape, real_dtype)
    scratch = backend.preallocate_array(fast_shape, real_dtype)
    numerator = backend.preallocate_array(fast_shape, real_dtype)
    inv_denominator = backend.preallocate_array(fast_shape, real_dtype)
    ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
    ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)

    rfftn, irfftn = backend.build_fft(
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=real_dtype,
        complex_dtype=complex_dtype,
        fftargs=kwargs.get("fftargs", {}),
        temp_real=arr,
        temp_fft=ft_temp,
    )

    # Projections occupy a single plane along the last axis.
    templateshape = list(templateshape)
    templateshape[-1] = 1
    subset = tuple(slice(0, x) for x in templateshape[:-1])

    rotation_out = backend.preallocate_array(
        (*fast_shape[:-1], template.shape[-1]), real_dtype
    )
    template_volume = backend.prod(rotation_out[subset].shape[:-1])

    # Rotation-invariant local sums of the target and squared target under
    # the box window w.
    window = backend.topleft_pad(
        backend.full(templateshape, dtype=real_dtype, fill_value=1), fast_shape
    )
    rfftn(window, ft_temp)
    window_sum = backend.preallocate_array(fast_shape, real_dtype)
    backend.multiply(ft_target, ft_temp, out=ft_denom)
    irfftn(ft_denom, window_sum)

    target_variance = backend.preallocate_array(fast_shape, real_dtype)
    backend.multiply(ft_target2, ft_temp, out=ft_denom)
    irfftn(ft_denom, target_variance)

    # target_variance = sum(f^2) - sum(f)^2 / N
    backend.square(window_sum, out=scratch)
    backend.divide(scratch, template_volume, out=scratch)
    backend.subtract(target_variance, scratch, out=target_variance)
    window = None

    convolution_mode = kwargs.get("convolution_mode", "full")
    for index in range(rotations.shape[0]):
        rotation = rotations[index]
        backend.fill(arr, 0)
        backend.fill(rotation_out, 0)

        backend.rotate_array(
            arr=template,
            rotation_matrix=rotation,
            out=rotation_out,
            use_geometric_center=False,
            order=interpolation_order,
        )
        projection = backend.sum(rotation_out, axis=-1)
        arr[..., 0] = projection

        template_mean = backend.mean(projection[subset])
        template_ssd = backend.sum(
            backend.square(backend.subtract(projection[subset], template_mean))
        )

        backend.multiply(window_sum, template_mean, out=numerator)

        backend.multiply(target_variance, template_ssd, out=scratch)
        backend.maximum(scratch, 0.0, out=scratch)
        backend.sqrt(scratch, out=scratch)

        # Reset so positions masked out for this rotation do not keep values
        # from a previous rotation.
        backend.fill(inv_denominator, 0)
        denominator_mask = scratch > backend.eps(scratch.dtype)
        inv_denominator[denominator_mask] = 1 / scratch[denominator_mask]

        rfftn(arr, ft_temp)
        backend.multiply(ft_target, ft_temp, out=ft_temp)
        irfftn(ft_temp, arr)

        backend.subtract(arr, numerator, out=arr)
        backend.multiply(arr, inv_denominator, out=arr)

        score = apply_convolution_mode(
            arr, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
        )
        if callback_class is not None:
            callback(
                score,
                rotation_matrix=rotation,
                rotation_index=index,
                **callback_class_args,
            )

    return callback
316
+
317
+
318
def corr3_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: NDArray,
    template_mask: NDArray,
    target: NDArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    real_dtype: type,
    complex_dtype: type,
    shared_memory_handler: Callable,
    callback_class: Callable,
    callback_class_args: Dict,
    **kwargs,
) -> Dict:
    """
    Prepare shared inputs for the projection scores :py:meth:`corr3_scoring`
    and :py:meth:`corr4_scoring`.

    A box window is built that spans the leading axes of ``template_mask``
    and has length one along the last axis. It is used to compute local sums
    of the padded target f and of f^2. Combined with the mean and the sum of
    squared deviations (SSD) of the template's projection p along its last
    axis, these give the rotation-independent parts of the score:

    - ``numerator2``: local sum of f times mean(p).
    - ``inv_denominator``: reciprocal of
      ``sqrt(max(SSD(p) * (local sum of f^2 - (local sum of f)^2 / N), 0))``,
      or 0 where that value does not exceed machine epsilon.

    The projection statistics are taken from the unrotated template.

    Returns
    -------
    Dict
        Keyword arguments for the scoring functions. Array entries are
        ``(shared buffer, shape, dtype)`` tuples.
    """
    padded_target = backend.topleft_pad(target, fast_shape)

    # The exact composition of the denominator is debatable
    # scikit-image match_template multiplies the running sum of the target
    # with a scaling factor derived from the template. This is probably appropriate
    # in pattern matching situations where the template exists in the target
    box = backend.preallocate_array((*template_mask.shape[:-1], 1), real_dtype)
    box[:] = 1
    window_spectrum = backend.preallocate_array(fast_ft_shape, complex_dtype)
    rfftn(backend.topleft_pad(box, fast_shape), window_spectrum)

    # Target and squared target window sums
    target_spectrum = backend.preallocate_array(fast_ft_shape, complex_dtype)
    target_sq_spectrum = backend.preallocate_array(fast_ft_shape, complex_dtype)
    denominator = backend.preallocate_array(fast_shape, real_dtype)
    local_sum = backend.preallocate_array(fast_shape, real_dtype)
    rfftn(padded_target, target_spectrum)

    rfftn(backend.square(padded_target), target_sq_spectrum)
    backend.multiply(target_sq_spectrum, window_spectrum, out=target_sq_spectrum)
    irfftn(target_sq_spectrum, denominator)

    backend.multiply(target_spectrum, window_spectrum, out=window_spectrum)
    irfftn(window_spectrum, local_sum)
    del padded_target, target_sq_spectrum, window_spectrum

    # Statistics of the unrotated projection along the last axis
    projection = template.sum(axis=-1)
    projection_mean = backend.mean(projection)
    n_voxels = np.prod(projection.shape)
    projection_ssd = backend.sum(
        backend.square(backend.subtract(projection, projection_mean))
    )

    # Subtracted from the raw correlation during scoring
    numerator2 = backend.multiply(local_sum, projection_mean)

    # local_sum is reused in place as (local sum)^2 / N
    backend.multiply(local_sum, local_sum, out=local_sum)
    backend.divide(local_sum, n_voxels, out=local_sum)

    backend.subtract(denominator, local_sum, out=denominator)
    backend.multiply(denominator, projection_ssd, out=denominator)
    backend.maximum(denominator, 0, out=denominator)
    backend.sqrt(denominator, out=denominator)
    del local_sum

    # Store the reciprocal so scoring can multiply instead of divide
    valid = denominator > backend.eps(denominator.dtype)
    inv_denominator = backend.preallocate_array(fast_shape, real_dtype)
    inv_denominator[valid] = 1 / denominator[valid]

    (
        template_buffer,
        target_spectrum_buffer,
        inv_denominator_buffer,
        numerator2_buffer,
    ) = (
        backend.arr_to_sharedarr(arr=array, shared_memory_handler=shared_memory_handler)
        for array in (template, target_spectrum, inv_denominator, numerator2)
    )
    del target_spectrum, inv_denominator, numerator2

    return {
        "template": (template_buffer, deepcopy(template.shape), real_dtype),
        "ft_target": (target_spectrum_buffer, fast_ft_shape, complex_dtype),
        "inv_denominator": (inv_denominator_buffer, fast_shape, real_dtype),
        "numerator2": (numerator2_buffer, fast_shape, real_dtype),
        "targetshape": deepcopy(target.shape),
        "templateshape": deepcopy(template.shape),
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "real_dtype": real_dtype,
        "complex_dtype": complex_dtype,
        "callback_class": callback_class,
        "callback_class_args": callback_class_args,
        "template_mean": kwargs.get("template_mean", projection_mean),
    }
429
+
430
+
431
def corr3_scoring(
    template: Tuple[type, Tuple[int], type],
    ft_target: Tuple[type, Tuple[int], type],
    inv_denominator: Tuple[type, Tuple[int], type],
    numerator2: Tuple[type, Tuple[int], type],
    template_filter: Tuple[type, Tuple[int], type],
    targetshape: Tuple[int],
    templateshape: Tuple[int],
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: NDArray,
    real_dtype: type,
    complex_dtype: type,
    callback_class: CallbackClass,
    callback_class_args: Dict,
    interpolation_order: int,
    convolution_mode: str = "full",
    **kwargs,
) -> CallbackClass:
    """
    Score rotations by correlating real-space projections with the target.

    For each rotation, the template is rotated with spline order
    ``interpolation_order`` and summed along its last axis. The projection
    is placed in plane 0 of the last axis of a zero array of ``fast_shape``.
    Its spectrum is optionally multiplied by ``template_filter`` and then by
    the target spectrum, and the result is transformed back. Finally the
    precomputed ``numerator2`` is subtracted and the result is multiplied by
    ``inv_denominator`` (see :py:meth:`corr3_setup`). Each normalization step
    is skipped when the corresponding array is trivial: a single element, a
    zero-sum ``numerator2``, or an ``inv_denominator`` summing to 1.

    Parameters
    ----------
    template, ft_target, inv_denominator, numerator2, template_filter : tuple
        ``(shared buffer, shape, dtype)`` triples.
    rotations : NDArray
        Rotation matrices, indexed along the first axis.
    callback_class : CallbackClass
        Class (instantiated with ``callback_class_args``) or instance that
        receives each score, or None.
    convolution_mode : str, optional
        Forwarded to ``apply_convolution_mode``.

    Returns
    -------
    CallbackClass
        The callback instance, or None if ``callback_class`` is None.
    """
    template_buffer, template_shape, template_dtype = template
    ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
    inv_denominator_buffer, inv_denominator_shape, inv_denominator_dtype = (
        inv_denominator
    )
    numerator2_buffer, numerator2_shape, numerator2_dtype = numerator2
    filter_buffer, filter_shape, filter_dtype = template_filter

    callback = callback_class
    if isinstance(callback_class, type):
        callback = callback_class(**callback_class_args)

    # Retrieve objects from shared memory
    template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
    ft_target = backend.sharedarr_to_arr(
        ft_target_shape, ft_target_dtype, ft_target_buffer
    )
    inv_denominator = backend.sharedarr_to_arr(
        inv_denominator_shape, inv_denominator_dtype, inv_denominator_buffer
    )
    numerator2 = backend.sharedarr_to_arr(
        numerator2_shape, numerator2_dtype, numerator2_buffer
    )
    template_filter = backend.sharedarr_to_arr(
        filter_shape, filter_dtype, filter_buffer
    )

    arr = backend.preallocate_array(fast_shape, real_dtype)
    ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)

    rfftn, irfftn = backend.build_fft(
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=real_dtype,
        complex_dtype=complex_dtype,
        fftargs=kwargs.get("fftargs", {}),
        temp_real=arr,
        temp_fft=ft_temp,
    )

    # Skip normalization/filter steps whose operands are trivial.
    norm_numerator = (backend.sum(numerator2) != 0) & (backend.size(numerator2) != 1)
    norm_denominator = (backend.sum(inv_denominator) != 1) & (
        backend.size(inv_denominator) != 1
    )
    filter_template = backend.size(template_filter) != 0

    norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
    norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
    template_filter_func = conditional_execute(backend.multiply, filter_template)

    rotation_out = backend.preallocate_array(
        (*fast_shape[:-1], template.shape[-1]), real_dtype
    )
    # Projections occupy a single plane along the last axis.
    templateshape = list(templateshape)
    templateshape[-1] = 1

    for index in range(rotations.shape[0]):
        rotation = rotations[index]
        backend.fill(arr, 0)
        backend.rotate_array(
            arr=template,
            rotation_matrix=rotation,
            out=rotation_out,
            use_geometric_center=False,
            order=interpolation_order,
        )
        arr[..., 0] = backend.sum(rotation_out, axis=-1)

        rfftn(arr, ft_temp)
        template_filter_func(ft_temp, template_filter, out=ft_temp)

        backend.multiply(ft_target, ft_temp, out=ft_temp)
        irfftn(ft_temp, arr)

        norm_func_numerator(arr, numerator2, out=arr)
        norm_func_denominator(arr, inv_denominator, out=arr)

        score = apply_convolution_mode(
            arr, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
        )
        if callback_class is not None:
            callback(
                score,
                rotation_matrix=rotation,
                rotation_index=index,
                **callback_class_args,
            )

    return callback
543
+
544
+
545
+
546
def corr4_scoring(
    template: Tuple[type, Tuple[int], type],
    ft_target: Tuple[type, Tuple[int], type],
    inv_denominator: Tuple[type, Tuple[int], type],
    numerator2: Tuple[type, Tuple[int], type],
    template_filter: Tuple[type, Tuple[int], type],
    targetshape: Tuple[int],
    templateshape: Tuple[int],
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: NDArray,
    real_dtype: type,
    complex_dtype: type,
    callback_class: CallbackClass,
    callback_class_args: Dict,
    interpolation_order: int,
    convolution_mode: str = "full",
    **kwargs,
) -> CallbackClass:
    """
    Score rotations using projections extracted in Fourier space.

    This is the Fourier-slice counterpart of :py:meth:`corr3_scoring`.
    Instead of rotating and summing the template in real space, an
    :py:class:`ExtractProjection` samples the central slice of the
    template's Fourier transform for each rotation, at the resolution of
    ``fast_shape``. The slice is broadcast along the last spectral axis; this
    is the spectrum of an array whose only nonzero plane is plane 0 of the
    last axis. Filtering, correlation with the target and normalization then
    proceed as in :py:meth:`corr3_scoring`.

    NOTE(review): ExtractProjection uses NumPy/SciPy directly, so this score
    presumably requires a NumPy-based backend — TODO confirm.
    ``interpolation_order`` is accepted for interface compatibility but not
    used; ExtractProjection interpolates linearly by default.

    Returns
    -------
    CallbackClass
        The callback instance, or None if ``callback_class`` is None.
    """
    template_buffer, template_shape, template_dtype = template
    ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
    inv_denominator_buffer, inv_denominator_shape, inv_denominator_dtype = (
        inv_denominator
    )
    numerator2_buffer, numerator2_shape, numerator2_dtype = numerator2
    filter_buffer, filter_shape, filter_dtype = template_filter

    callback = callback_class
    if isinstance(callback_class, type):
        callback = callback_class(**callback_class_args)

    # Retrieve objects from shared memory
    template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
    ft_target = backend.sharedarr_to_arr(
        ft_target_shape, ft_target_dtype, ft_target_buffer
    )
    inv_denominator = backend.sharedarr_to_arr(
        inv_denominator_shape, inv_denominator_dtype, inv_denominator_buffer
    )
    numerator2 = backend.sharedarr_to_arr(
        numerator2_shape, numerator2_dtype, numerator2_buffer
    )
    template_filter = backend.sharedarr_to_arr(
        filter_shape, filter_dtype, filter_buffer
    )

    arr = backend.preallocate_array(fast_shape, real_dtype)
    ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)

    rfftn, irfftn = backend.build_fft(
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=real_dtype,
        complex_dtype=complex_dtype,
        fftargs=kwargs.get("fftargs", {}),
        temp_real=arr,
        temp_fft=ft_temp,
    )

    # Skip normalization/filter steps whose operands are trivial.
    norm_numerator = (backend.sum(numerator2) != 0) & (backend.size(numerator2) != 1)
    norm_denominator = (backend.sum(inv_denominator) != 1) & (
        backend.size(inv_denominator) != 1
    )
    filter_template = backend.size(template_filter) != 0

    norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
    norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
    template_filter_func = conditional_execute(backend.multiply, filter_template)

    # Projections occupy a single plane along the last axis.
    templateshape = list(templateshape)
    templateshape[-1] = 1

    # Sample slices from the template's spectrum at the padded resolution.
    extractor = ExtractProjection(template)
    extractor.create_point_cloud(fast_shape)

    for index in range(rotations.shape[0]):
        rotation = rotations[index]

        ft_temp[..., :] = extractor(rotation)[..., None]
        template_filter_func(ft_temp, template_filter, out=ft_temp)

        backend.multiply(ft_target, ft_temp, out=ft_temp)
        irfftn(ft_temp, arr)

        norm_func_numerator(arr, numerator2, out=arr)
        norm_func_denominator(arr, inv_denominator, out=arr)

        score = apply_convolution_mode(
            arr, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
        )
        if callback_class is not None:
            callback(
                score,
                rotation_matrix=rotation,
                rotation_index=index,
                **callback_class_args,
            )

    return callback
653
+
654
+
655
@device_memory_handler
def scan(
    matching_data: MatchingData,
    matching_setup: Callable,
    matching_score: Callable,
    n_jobs: int = 4,
    callback_class: CallbackClass = None,
    callback_class_args: Dict = None,
    fftargs: Dict = None,
    pad_fourier: bool = True,
    interpolation_order: int = 3,
    jobs_per_callback_class: int = 8,
    **kwargs,
) -> Tuple:
    """
    Perform template matching between target and template and sample
    different rotations of template.

    Parameters
    ----------
    matching_data : MatchingData
        Template matching data.
    matching_setup : Callable
        Function pointer to setup function.
    matching_score : Callable
        Function pointer to scoring function.
    n_jobs : int, optional
        Number of parallel jobs. Default is 4.
    callback_class : type, optional
        Analyzer class pointer to operate on computed scores.
    callback_class_args : dict, optional
        Arguments passed to the callback_class. A new empty dictionary is
        used when omitted. A dictionary passed in receives the additional
        keys ``fourier_shift``, ``translation_offset``, ``thread_safe`` and
        ``gpu_index``.
    fftargs : dict, optional
        Arguments for the FFT operations. A new empty dictionary is used when
        omitted.
    pad_fourier: bool, optional
        Whether to pad target and template to the full convolution shape.
    interpolation_order : int, optional
        Order of spline interpolation for rotations.
    jobs_per_callback_class : int, optional
        How many jobs should be processed by a single callback_class instance,
        if one is provided.
    **kwargs : various
        Additional arguments.

    Returns
    -------
    Tuple
        The merged results from callback_class if provided otherwise None.
    """
    # Avoid shared mutable defaults: the entries written into
    # callback_class_args below would otherwise persist across calls.
    if callback_class_args is None:
        callback_class_args = {}
    if fftargs is None:
        fftargs = {}

    matching_data.to_backend()
    # The scoring functions in this module correlate a single plane along
    # the template's last axis, so the template occupies one plane there.
    fourier_pad = list(matching_data._templateshape)
    fourier_pad[-1] = 1
    fourier_shift = backend.zeros(len(fourier_pad))
    if not pad_fourier:
        fourier_pad = backend.full(shape=fourier_shift.shape, fill_value=1, dtype=int)
        fourier_shift = 1 - backend.astype(
            backend.divide(matching_data._templateshape, 2), int
        )
    callback_class_args["fourier_shift"] = fourier_shift

    _, fast_shape, fast_ft_shape = backend.compute_convolution_shapes(
        matching_data._target.shape, fourier_pad
    )
    rfftn, irfftn = backend.build_fft(
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=matching_data._default_dtype,
        complex_dtype=matching_data._complex_dtype,
        fftargs=fftargs,
    )
    setup = matching_setup(
        rfftn=rfftn,
        irfftn=irfftn,
        template=matching_data.template,
        template_mask=matching_data.template_mask,
        target=matching_data.target,
        target_mask=matching_data.target_mask,
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=matching_data._default_dtype,
        complex_dtype=matching_data._complex_dtype,
        callback_class=callback_class,
        callback_class_args=callback_class_args,
        **kwargs,
    )
    rfftn, irfftn = None, None

    # Combine all requested template filters into a single multiplicative
    # filter. The first filter is taken as-is; later ones multiply into it.
    template_filter, preprocessor = None, Preprocessor()
    for method, parameters in matching_data.template_filter.items():
        parameters["shape"] = fast_shape
        parameters["omit_negative_frequencies"] = True
        out = preprocessor.apply_method(method=method, parameters=parameters)
        if template_filter is None:
            template_filter = out
        else:
            np.multiply(template_filter, out, out=template_filter)

    if template_filter is None:
        template_filter = backend.full(
            shape=(1,), fill_value=1, dtype=backend._default_dtype
        )
    else:
        template_filter = backend.to_backend_array(template_filter)

    template_filter = backend.astype(template_filter, backend._default_dtype)
    template_filter_buffer = backend.arr_to_sharedarr(
        arr=template_filter,
        shared_memory_handler=kwargs.get("shared_memory_handler", None),
    )
    setup["template_filter"] = (
        template_filter_buffer,
        template_filter.shape,
        template_filter.dtype,
    )

    callback_class_args["translation_offset"] = backend.astype(
        matching_data._translation_offset, int
    )
    callback_class_args["thread_safe"] = n_jobs > 1
    callback_class_args["gpu_index"] = kwargs.get("gpu_index", -1)

    n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
    callback_class = setup.pop("callback_class", callback_class)
    callback_class_args = setup.pop("callback_class_args", callback_class_args)
    callback_classes = [callback_class for _ in range(n_callback_classes)]
    if callback_class == MaxScoreOverRotations:
        score_space_shape = backend.subtract(
            matching_data.target.shape,
            matching_data._target_pad,
        )
        callback_classes = [
            class_name(
                score_space_shape=score_space_shape,
                score_space_dtype=matching_data._default_dtype,
                shared_memory_handler=kwargs.get("shared_memory_handler", None),
                rotation_space_dtype=backend._default_dtype_int,
                **callback_class_args,
            )
            for class_name in callback_classes
        ]

    # Release the input arrays; scoring jobs use the shared copies in setup.
    matching_data._target, matching_data._template = None, None
    matching_data._target_mask, matching_data._template_mask = None, None

    setup["fftargs"] = fftargs.copy()
    convolution_mode = "same"
    if backend.sum(matching_data._target_pad) > 0:
        convolution_mode = "valid"
    setup["convolution_mode"] = convolution_mode
    setup["interpolation_order"] = interpolation_order
    rotation_list = matching_data._split_rotations_on_jobs(n_jobs)

    backend.free_cache()

    def _run_scoring(backend_name, backend_args, rotations, **kwargs):
        # Each worker re-selects the backend before scoring its rotations.
        from tme.backends import backend

        backend.change_backend(backend_name, **backend_args)
        return matching_score(rotations=rotations, **kwargs)

    callbacks = Parallel(n_jobs=n_jobs)(
        delayed(_run_scoring)(
            backend_name=backend._backend_name,
            backend_args=backend._backend_args,
            rotations=rotation,
            callback_class=callback_classes[index % n_callback_classes],
            callback_class_args=callback_class_args,
            **setup,
        )
        for index, rotation in enumerate(rotation_list)
    )

    callbacks = [
        tuple(callback)
        for callback in callbacks[0:n_callback_classes]
        if callback is not None
    ]
    backend.free_cache()

    merged_callback = None
    if callback_class is not None:
        merged_callback = callback_class.merge(
            callbacks,
            **callback_class_args,
            score_indices=matching_data.indices,
            inner_merge=True,
        )

    return merged_callback
845
+
846
+
847
# Register the experimental projection-based scores with the exhaustive
# matching machinery. CC3 and CC4 share the same setup and differ only in
# how each rotated projection is produced.
for _matching, _setup, _scoring in (
    ("CC2", corr2_setup, corr2_scoring),
    ("CC3", corr3_setup, corr3_scoring),
    ("CC4", corr3_setup, corr4_scoring),
):
    register_matching_exhaustive(
        matching=_matching,
        matching_setup=_setup,
        matching_scoring=_scoring,
        memory_class=CCMemoryUsage,
    )
del _matching, _setup, _scoring