pytme 0.1.1__cp311-cp311-macosx_13_0_arm64.whl → 0.1.3__cp311-cp311-macosx_13_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tme/scoring.py ADDED
@@ -0,0 +1,679 @@
1
+ from copy import deepcopy
2
+ from typing import Tuple, Callable, Dict
3
+ from joblib import Parallel, delayed
4
+
5
+ import numpy as np
6
+
7
+ from . import Preprocessor, MatchingData
8
+ from .backends import backend
9
+ from .types import NDArray, CallbackClass
10
+ from .matching_memory import CCMemoryUsage
11
+ from .analyzer import MaxScoreOverRotations
12
+ from .matching_exhaustive import register_matching_exhaustive, device_memory_handler
13
+ from .matching_utils import apply_convolution_mode, conditional_execute
14
+
15
def corr2_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: NDArray,
    template_mask: NDArray,
    target: NDArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    real_dtype: type,
    complex_dtype: type,
    shared_memory_handler: Callable,
    callback_class: Callable,
    callback_class_args: Dict,
    **kwargs,
) -> Dict:
    """
    Prepare the shared inputs of :py:meth:`corr2_scoring`.

    The target is zero-padded (top-left aligned) to ``fast_shape``. The
    real-to-complex transforms of the padded target and of its elementwise
    square are then computed. Those two spectra, the template and the
    template mask are converted to shared-memory objects so that parallel
    scoring jobs can attach to them.

    The score evaluated downstream is a normalized cross-correlation of a
    target f with a template g and a mask m:

    .. math::

        \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
        {N_m * \\sqrt{
            \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
        }

    with

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)

    where :math:`N_m` is the number of voxels within the mask m.

    Parameters
    ----------
    rfftn : Callable
        Forward real FFT, called as ``rfftn(real_input, complex_output)``.
    irfftn : Callable
        Inverse real FFT. Accepted for interface compatibility; not used here.
    template : NDArray
        Template array, shared as-is.
    template_mask : NDArray
        Template mask, shared as-is.
    target : NDArray
        Target array.
    fast_shape, fast_ft_shape : tuple of int
        Real-space and Fourier-space shapes used for the FFTs.
    real_dtype, complex_dtype : type
        Dtypes recorded alongside the shared buffers.
    shared_memory_handler : Callable
        Passed through to ``backend.arr_to_sharedarr``.
    callback_class, callback_class_args
        Returned unchanged in the result dictionary.

    Returns
    -------
    dict
        Keyword arguments for the scoring function. Shared arrays are given
        as ``(buffer, shape, dtype)`` tuples.

    References
    ----------
    .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister
           Microsc. Microanal. 26, 2516 (2020)
    .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
           and F. Förster, J. Struct. Biol. 178, 177 (2012).

    See Also
    --------
    :py:meth:`corr2_scoring`
    """
    padded_target = backend.topleft_pad(target, fast_shape)

    # Spectrum of the target and of the squared target; the scoring step
    # uses both to form the local mean / variance of the denominator.
    spectra = []
    for real_input in (padded_target, backend.square(padded_target)):
        spectrum = backend.preallocate_array(fast_ft_shape, complex_dtype)
        rfftn(real_input, spectrum)
        spectra.append(spectrum)

    def _share(array):
        return backend.arr_to_sharedarr(
            arr=array, shared_memory_handler=shared_memory_handler
        )

    shared_ft_target, shared_ft_target2 = (_share(s) for s in spectra)

    return {
        "template": (_share(template), template.shape, real_dtype),
        "template_mask": (_share(template_mask), template_mask.shape, real_dtype),
        "ft_target": (shared_ft_target, fast_ft_shape, complex_dtype),
        "ft_target2": (shared_ft_target2, fast_ft_shape, complex_dtype),
        "targetshape": target.shape,
        "templateshape": template.shape,
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "real_dtype": real_dtype,
        "complex_dtype": complex_dtype,
        "callback_class": callback_class,
        "callback_class_args": callback_class_args,
    }
105
+
106
+
107
def corr2_scoring(
    template: Tuple[type, Tuple[int], type],
    template_mask: Tuple[type, Tuple[int], type],
    ft_target: Tuple[type, Tuple[int], type],
    ft_target2: Tuple[type, Tuple[int], type],
    template_filter: Tuple[type, Tuple[int], type],
    targetshape: Tuple[int],
    templateshape: Tuple[int],
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: NDArray,
    real_dtype: type,
    complex_dtype: type,
    callback_class: CallbackClass,
    callback_class_args: Dict,
    interpolation_order: int,
    **kwargs,
) -> CallbackClass:
    """
    Score every rotation with a projection-based normalized correlation.

    For each rotation, the template is rotated and summed along its last
    axis. The resulting projection ``p`` is written to index 0 of the last
    axis of a zero ``fast_shape`` buffer. With ``m`` a box of ones covering
    the projection footprint, ``N`` its voxel count and ``ssd`` the sum of
    squared deviations of ``p`` from its mean, the score is::

        (C(f, p) - mean(p) * C(f, m))
        / sqrt(max(ssd * (C(f**2, m) - C(f, m)**2 / N), 0))

    Here ``C(a, b) = irfftn(rfftn(a) * rfftn(b))`` and ``rfftn(f)`` /
    ``rfftn(f**2)`` are the precomputed target spectra. Positions where the
    denominator is not above machine epsilon are set to zero.

    Parameters
    ----------
    template, ft_target, ft_target2, template_filter : tuple
        ``(shared_buffer, shape, dtype)`` triplets, as produced by the setup
        function and ``scan``.
    template_mask : tuple
        Accepted for interface compatibility. The projection mask is always a
        box of ones, so the stored template mask is not read.
    targetshape, templateshape : tuple of int
        Shapes used to crop the score via ``apply_convolution_mode``.
    fast_shape, fast_ft_shape : tuple of int
        Real-space and Fourier-space FFT shapes.
    rotations : NDArray
        Stack of rotation matrices, indexed along axis 0.
    real_dtype, complex_dtype : type
        Dtypes of the working buffers.
    callback_class : type, instance or None
        A class is instantiated with ``callback_class_args``. An instance is
        used directly. With ``None``, scores are computed but not reported.
    callback_class_args : dict
        Constructor and call keyword arguments for the callback.
    interpolation_order : int
        Spline order passed to ``backend.rotate_array``.
    **kwargs
        ``fftargs`` (dict) and ``convolution_mode`` (str, default ``"full"``)
        are read.

    Returns
    -------
    CallbackClass
        The callback that received the scores, or None.
    """
    template_buffer, template_shape, template_dtype = template
    filter_buffer, filter_shape, filter_dtype = template_filter
    ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
    ft_target2_buffer, *_ = ft_target2

    callback = (
        callback_class(**callback_class_args)
        if isinstance(callback_class, type)
        else callback_class
    )

    # Retrieve objects from shared memory
    template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
    ft_target = backend.sharedarr_to_arr(
        ft_target_shape, ft_target_dtype, ft_target_buffer
    )
    ft_target2 = backend.sharedarr_to_arr(
        ft_target_shape, ft_target_dtype, ft_target2_buffer
    )
    template_filter = backend.sharedarr_to_arr(
        filter_shape, filter_dtype, filter_buffer
    )

    arr = backend.preallocate_array(fast_shape, real_dtype)
    ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)

    rfftn, irfftn = backend.build_fft(
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=real_dtype,
        complex_dtype=complex_dtype,
        fftargs=kwargs.get("fftargs", {}),
        temp_real=arr,
        temp_fft=ft_temp,
    )

    # The projection is 1 voxel thick along the last axis.
    templateshape = list(templateshape)
    templateshape[-1] = 1
    subset = tuple(slice(0, x) for x in templateshape[:-1])
    rotation_out = backend.preallocate_array(
        (*fast_shape[:-1], template.shape[-1]), real_dtype
    )
    convolution_mode = kwargs.get("convolution_mode", "full")
    template_filter_func = conditional_execute(
        backend.multiply, backend.size(template_filter) != 0
    )

    # The projection mask is a box of ones that does not depend on the
    # rotation. Its window sums over the target are computed once instead of
    # once per rotation.
    window_sum = backend.preallocate_array(fast_shape, real_dtype)
    local_variance = backend.preallocate_array(fast_shape, real_dtype)
    ft_window = backend.preallocate_array(fast_ft_shape, complex_dtype)
    ft_product = backend.preallocate_array(fast_ft_shape, complex_dtype)

    projection_mask = backend.full(templateshape, dtype=real_dtype, fill_value=1)
    rfftn(backend.topleft_pad(projection_mask, fast_shape), ft_window)
    backend.multiply(ft_target, ft_window, out=ft_product)
    irfftn(ft_product, window_sum)
    backend.multiply(ft_target2, ft_window, out=ft_product)
    irfftn(ft_product, local_variance)
    del ft_window, ft_product

    # Number of voxels of the projection window (same shape as projection[subset]).
    template_volume = backend.prod(rotation_out[..., 0][subset].shape)

    # local_variance <- C(f**2, m) - C(f, m)**2 / N
    denominator = backend.preallocate_array(fast_shape, real_dtype)
    backend.square(window_sum, out=denominator)
    backend.divide(denominator, template_volume, out=denominator)
    backend.subtract(local_variance, denominator, out=local_variance)

    numerator = backend.preallocate_array(fast_shape, real_dtype)
    inv_denominator = backend.preallocate_array(fast_shape, real_dtype)

    for index in range(rotations.shape[0]):
        rotation = rotations[index]
        backend.fill(arr, 0)
        backend.fill(rotation_out, 0)

        backend.rotate_array(
            arr=template,
            rotation_matrix=rotation,
            out=rotation_out,
            use_geometric_center=False,
            order=interpolation_order,
        )
        projection = backend.sum(rotation_out, axis=-1)
        arr[..., 0] = projection

        # Rotation-dependent template statistics.
        template_mean = backend.mean(projection[subset])
        template_ssd = backend.sum(
            backend.square(backend.subtract(projection[subset], template_mean))
        )

        backend.multiply(window_sum, template_mean, out=numerator)

        backend.multiply(local_variance, template_ssd, out=denominator)
        backend.maximum(denominator, 0.0, out=denominator)
        backend.sqrt(denominator, out=denominator)

        # Invert only well-conditioned positions; all others score zero.
        backend.fill(inv_denominator, 0)
        valid = denominator > backend.eps(denominator.dtype)
        inv_denominator[valid] = 1 / denominator[valid]

        rfftn(arr, ft_temp)
        template_filter_func(ft_temp, template_filter, out=ft_temp)
        backend.multiply(ft_target, ft_temp, out=ft_temp)
        irfftn(ft_temp, arr)

        backend.subtract(arr, numerator, out=arr)
        backend.multiply(arr, inv_denominator, out=arr)

        score = apply_convolution_mode(
            arr, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
        )
        if callback_class is not None:
            callback(
                score,
                rotation_matrix=rotation,
                rotation_index=index,
                **callback_class_args,
            )

    return callback
249
+
250
+
251
def corr3_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: NDArray,
    template_mask: NDArray,
    target: NDArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    real_dtype: type,
    complex_dtype: type,
    shared_memory_handler: Callable,
    callback_class: Callable,
    callback_class_args: Dict,
    **kwargs,
) -> Dict:
    """
    Precompute the normalization terms used by :py:meth:`corr3_scoring`.

    The template is summed along its last axis to form a projection ``p``.
    A box of ones ``m`` with the template's shape, reduced to length 1 along
    the last axis, is used as the window. With ``S1`` / ``S2`` the window
    sums of the target and of the squared target (computed via FFT), and
    ``N`` the number of voxels in ``p``, this function stores:

    * ``numerator2 = S1 * mean(p)``
    * ``inv_denominator = 1 / sqrt(max((S2 - S1**2 / N) * ssd(p), 0))``,
      set to zero where the square root is not above machine epsilon.

    These statistics are derived from the unrotated template.

    Parameters
    ----------
    rfftn, irfftn : Callable
        Forward / inverse real FFTs writing into their second argument.
    template : NDArray
        Template array.
    template_mask : NDArray
        Only its shape is used; the window is always a box of ones.
    target : NDArray
        Target array.
    fast_shape, fast_ft_shape : tuple of int
        Real-space and Fourier-space FFT shapes.
    real_dtype, complex_dtype : type
        Dtypes of the working buffers.
    shared_memory_handler : Callable
        Passed through to ``backend.arr_to_sharedarr``.
    callback_class, callback_class_args
        Returned unchanged in the result dictionary.
    **kwargs
        ``template_mean`` overrides the returned ``"template_mean"`` entry.

    Returns
    -------
    dict
        Keyword arguments for the scoring function. Shared arrays are given
        as ``(buffer, shape, dtype)`` tuples.
    """
    padded_target = backend.topleft_pad(target, fast_shape)

    # The exact composition of the denominator is debatable.
    # scikit-image match_template multiplies the running sum of the target
    # with a scaling factor derived from the template. This is probably
    # appropriate in pattern matching situations where the template exists
    # in the target.
    ones_window = backend.preallocate_array(
        (*template_mask.shape[:-1], 1), real_dtype
    )
    ones_window[:] = 1
    ft_window = backend.preallocate_array(fast_ft_shape, complex_dtype)
    rfftn(backend.topleft_pad(ones_window, fast_shape), ft_window)

    # Target and squared-target spectra plus their window sums.
    ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
    ft_target2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
    denominator = backend.preallocate_array(fast_shape, real_dtype)
    target_window_sum = backend.preallocate_array(fast_shape, real_dtype)

    rfftn(padded_target, ft_target)
    rfftn(backend.square(padded_target), ft_target2)

    # S2: the squared-target spectrum buffer is reused for the product.
    backend.multiply(ft_target2, ft_window, out=ft_target2)
    irfftn(ft_target2, denominator)

    # S1: ft_window is no longer needed and receives the product.
    backend.multiply(ft_target, ft_window, out=ft_window)
    irfftn(ft_window, target_window_sum)

    del padded_target, ft_target2, ft_window

    projection = template.sum(axis=-1)
    template_mean = backend.mean(projection)
    template_volume = np.prod(projection.shape)
    template_ssd = backend.sum(
        backend.square(backend.subtract(projection, template_mean))
    )

    # The scoring step subtracts this term from the raw correlation.
    numerator2 = backend.multiply(target_window_sum, template_mean)

    # denominator <- sqrt(max((S2 - S1**2 / N) * ssd, 0)), all in place.
    backend.multiply(target_window_sum, target_window_sum, out=target_window_sum)
    backend.divide(target_window_sum, template_volume, out=target_window_sum)
    backend.subtract(denominator, target_window_sum, out=denominator)
    backend.multiply(denominator, template_ssd, out=denominator)
    backend.maximum(denominator, 0, out=denominator)
    backend.sqrt(denominator, out=denominator)
    del target_window_sum

    # Store the reciprocal so the score can be formed by multiplication.
    well_conditioned = denominator > backend.eps(denominator.dtype)
    inv_denominator = backend.preallocate_array(fast_shape, real_dtype)
    inv_denominator[well_conditioned] = 1 / denominator[well_conditioned]

    def _share(array):
        return backend.arr_to_sharedarr(
            arr=array, shared_memory_handler=shared_memory_handler
        )

    template_tuple = (_share(template), deepcopy(template.shape), real_dtype)
    target_ft_tuple = (_share(ft_target), fast_ft_shape, complex_dtype)
    inv_denominator_tuple = (_share(inv_denominator), fast_shape, real_dtype)
    numerator2_tuple = (_share(numerator2), fast_shape, real_dtype)

    del ft_target, inv_denominator, numerator2

    return {
        "template": template_tuple,
        "ft_target": target_ft_tuple,
        "inv_denominator": inv_denominator_tuple,
        "numerator2": numerator2_tuple,
        "targetshape": deepcopy(target.shape),
        "templateshape": deepcopy(template.shape),
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "real_dtype": real_dtype,
        "complex_dtype": complex_dtype,
        "callback_class": callback_class,
        "callback_class_args": callback_class_args,
        "template_mean": kwargs.get("template_mean", template_mean),
    }
362
+
363
+
364
def corr3_scoring(
    template: Tuple[type, Tuple[int], type],
    ft_target: Tuple[type, Tuple[int], type],
    inv_denominator: Tuple[type, Tuple[int], type],
    numerator2: Tuple[type, Tuple[int], type],
    template_filter: Tuple[type, Tuple[int], type],
    targetshape: Tuple[int],
    templateshape: Tuple[int],
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: NDArray,
    real_dtype: type,
    complex_dtype: type,
    callback_class: CallbackClass,
    callback_class_args: Dict,
    interpolation_order: int,
    convolution_mode: str = "full",
    **kwargs,
) -> CallbackClass:
    """
    Score every rotation using precomputed normalization terms.

    For each rotation, the template is rotated and summed along its last
    axis. The projection is written to index 0 of the last axis of a zero
    ``fast_shape`` buffer. That buffer is transformed, optionally multiplied
    by ``template_filter``, multiplied by the target spectrum and transformed
    back. ``numerator2`` is then subtracted and the result is multiplied by
    ``inv_denominator``.

    Each normalization step is skipped when its precomputed array is
    trivial:

    * numerator: when ``numerator2`` sums to 0 or has size 1;
    * denominator: when ``inv_denominator`` sums to 1 or has size 1;
    * filter: when ``template_filter`` is empty.

    Parameters
    ----------
    template, ft_target, inv_denominator, numerator2, template_filter : tuple
        ``(shared_buffer, shape, dtype)`` triplets.
    targetshape, templateshape : tuple of int
        Shapes used to crop the score via ``apply_convolution_mode``.
    fast_shape, fast_ft_shape : tuple of int
        Real-space and Fourier-space FFT shapes.
    rotations : NDArray
        Stack of rotation matrices, indexed along axis 0.
    real_dtype, complex_dtype : type
        Dtypes of the working buffers.
    callback_class : type, instance or None
        A class is instantiated with ``callback_class_args``. An instance is
        used directly. With ``None``, scores are not reported.
    callback_class_args : dict
        Constructor and call keyword arguments for the callback.
    interpolation_order : int
        Spline order passed to ``backend.rotate_array``.
    convolution_mode : str, optional
        Passed to ``apply_convolution_mode``. Default ``"full"``.
    **kwargs
        ``fftargs`` (dict) is read.

    Returns
    -------
    CallbackClass
        The callback that received the scores, or None.
    """
    template_buffer, template_shape, template_dtype = template
    ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
    inv_denominator_buffer, inv_denominator_shape, _ = inv_denominator
    numerator2_buffer, numerator2_shape, _ = numerator2
    filter_buffer, filter_shape, filter_dtype = template_filter

    callback = (
        callback_class(**callback_class_args)
        if isinstance(callback_class, type)
        else callback_class
    )

    # Retrieve objects from shared memory
    template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
    ft_target = backend.sharedarr_to_arr(
        ft_target_shape, ft_target_dtype, ft_target_buffer
    )
    inv_denominator = backend.sharedarr_to_arr(
        inv_denominator_shape, template_dtype, inv_denominator_buffer
    )
    numerator2 = backend.sharedarr_to_arr(
        numerator2_shape, template_dtype, numerator2_buffer
    )
    template_filter = backend.sharedarr_to_arr(
        filter_shape, filter_dtype, filter_buffer
    )

    arr = backend.preallocate_array(fast_shape, real_dtype)
    ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)

    rfftn, irfftn = backend.build_fft(
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=real_dtype,
        complex_dtype=complex_dtype,
        fftargs=kwargs.get("fftargs", {}),
        temp_real=arr,
        temp_fft=ft_temp,
    )

    # Skip normalization steps whose precomputed arrays are trivial.
    norm_numerator = (backend.sum(numerator2) != 0) & (backend.size(numerator2) != 1)
    norm_denominator = (backend.sum(inv_denominator) != 1) & (
        backend.size(inv_denominator) != 1
    )
    filter_template = backend.size(template_filter) != 0

    norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
    norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
    template_filter_func = conditional_execute(backend.multiply, filter_template)

    rotation_out = backend.preallocate_array(
        (*fast_shape[:-1], template.shape[-1]), real_dtype
    )
    # The projection is 1 voxel thick along the last axis.
    templateshape = list(templateshape)
    templateshape[-1] = 1

    for index in range(rotations.shape[0]):
        rotation = rotations[index]
        backend.fill(arr, 0)
        backend.rotate_array(
            arr=template,
            rotation_matrix=rotation,
            out=rotation_out,
            use_geometric_center=False,
            order=interpolation_order,
        )
        projection = backend.sum(rotation_out, axis=-1)
        arr[..., 0] = projection

        rfftn(arr, ft_temp)
        template_filter_func(ft_temp, template_filter, out=ft_temp)

        backend.multiply(ft_target, ft_temp, out=ft_temp)
        irfftn(ft_temp, arr)

        norm_func_numerator(arr, numerator2, out=arr)
        norm_func_denominator(arr, inv_denominator, out=arr)

        score = apply_convolution_mode(
            arr, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
        )
        if callback_class is not None:
            callback(
                score,
                rotation_matrix=rotation,
                rotation_index=index,
                **callback_class_args,
            )

    return callback
475
+
476
def _scan_template_filter(matching_data: MatchingData, fast_shape: Tuple[int]):
    """
    Combine the template filters requested by ``matching_data`` into one array.

    Each ``(method, parameters)`` entry of ``matching_data.template_filter``
    is evaluated with ``Preprocessor.apply_method``. Every evaluation uses
    ``shape=fast_shape`` and ``omit_negative_frequencies=True``, and the
    results are multiplied elementwise. Without any entry, a single-element
    array of ones is returned. The result is cast to ``backend._default_dtype``.
    """
    template_filter, preprocessor = None, Preprocessor()
    for method, parameters in matching_data.template_filter.items():
        # Copy so that the caller's filter specification is left untouched.
        parameters = {
            **parameters,
            "shape": fast_shape,
            "omit_negative_frequencies": True,
        }
        out = preprocessor.apply_method(method=method, parameters=parameters)
        if template_filter is None:
            # The first filter is taken as-is; it must not be multiplied
            # with itself.
            template_filter = out
        else:
            np.multiply(template_filter, out, out=template_filter)

    if template_filter is None:
        template_filter = backend.full(
            shape=(1,), fill_value=1, dtype=backend._default_dtype
        )
    else:
        template_filter = backend.to_backend_array(template_filter)

    return backend.astype(template_filter, backend._default_dtype)


@device_memory_handler
def scan(
    matching_data: MatchingData,
    matching_setup: Callable,
    matching_score: Callable,
    n_jobs: int = 4,
    callback_class: CallbackClass = None,
    callback_class_args: Dict = None,
    fftargs: Dict = None,
    pad_fourier: bool = True,
    interpolation_order: int = 3,
    jobs_per_callback_class: int = 8,
    **kwargs,
) -> Tuple:
    """
    Perform template matching between target and template and sample
    different rotations of template.

    Parameters
    ----------
    matching_data : MatchingData
        Template matching data. Its target/template arrays are released
        (set to None) once the setup has been computed.
    matching_setup : Callable
        Function pointer to setup function.
    matching_score : Callable
        Function pointer to scoring function.
    n_jobs : int, optional
        Number of parallel jobs. Default is 4.
    callback_class : type, optional
        Analyzer class pointer to operate on computed scores.
    callback_class_args : dict, optional
        Arguments passed to the callback_class. Default is None, which is
        treated as an empty dictionary. The dictionary is copied; the
        caller's object is not modified.
    fftargs : dict, optional
        Arguments for the FFT operations. Default is None, which is treated
        as an empty dictionary.
    pad_fourier: bool, optional
        Whether to pad target and template to the full convolution shape.
    interpolation_order : int, optional
        Order of spline interpolation for rotations.
    jobs_per_callback_class : int, optional
        How many jobs should be processed by a single callback_class instance,
        if one is provided.
    **kwargs : various
        Additional arguments. ``shared_memory_handler`` and ``gpu_index`` are
        read here; all kwargs are also forwarded to ``matching_setup``.

    Returns
    -------
    Tuple
        The merged results from callback_class if provided otherwise None.
    """
    # This function adds keys to callback_class_args. Work on a private copy
    # so the caller's dict (or a shared default) does not accumulate state.
    callback_class_args = {} if callback_class_args is None else dict(callback_class_args)
    fftargs = {} if fftargs is None else fftargs
    shared_memory_handler = kwargs.get("shared_memory_handler", None)

    matching_data.to_backend()

    # Pad by the template shape except along the projected (last) axis.
    fourier_pad = list(matching_data._templateshape)
    fourier_pad[-1] = 1
    fourier_shift = backend.zeros(len(fourier_pad))
    if not pad_fourier:
        fourier_pad = backend.full(shape=fourier_shift.shape, fill_value=1, dtype=int)
        fourier_shift = 1 - backend.astype(
            backend.divide(matching_data._templateshape, 2), int
        )
    callback_class_args["fourier_shift"] = fourier_shift

    _, fast_shape, fast_ft_shape = backend.compute_convolution_shapes(
        matching_data._target.shape, fourier_pad
    )
    rfftn, irfftn = backend.build_fft(
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=matching_data._default_dtype,
        complex_dtype=matching_data._complex_dtype,
        fftargs=fftargs,
    )
    setup = matching_setup(
        rfftn=rfftn,
        irfftn=irfftn,
        template=matching_data.template,
        template_mask=matching_data.template_mask,
        target=matching_data.target,
        target_mask=matching_data.target_mask,
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=matching_data._default_dtype,
        complex_dtype=matching_data._complex_dtype,
        callback_class=callback_class,
        callback_class_args=callback_class_args,
        **kwargs,
    )
    rfftn, irfftn = None, None

    template_filter = _scan_template_filter(matching_data, fast_shape)
    template_filter_buffer = backend.arr_to_sharedarr(
        arr=template_filter,
        shared_memory_handler=shared_memory_handler,
    )
    setup["template_filter"] = (
        template_filter_buffer,
        template_filter.shape,
        template_filter.dtype,
    )

    callback_class_args["translation_offset"] = backend.astype(
        matching_data._translation_offset, int
    )
    callback_class_args["thread_safe"] = n_jobs > 1
    callback_class_args["gpu_index"] = kwargs.get("gpu_index", -1)

    n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
    callback_class = setup.pop("callback_class", callback_class)
    callback_class_args = setup.pop("callback_class_args", callback_class_args)
    callback_classes = [callback_class for _ in range(n_callback_classes)]
    if callback_class == MaxScoreOverRotations:
        score_space_shape = backend.subtract(
            matching_data.target.shape,
            matching_data._target_pad,
        )
        callback_classes = [
            class_name(
                score_space_shape=score_space_shape,
                score_space_dtype=matching_data._default_dtype,
                shared_memory_handler=shared_memory_handler,
                rotation_space_dtype=backend._default_dtype_int,
                **callback_class_args,
            )
            for class_name in callback_classes
        ]

    # Release the input arrays; the workers only need the shared setup data.
    matching_data._target, matching_data._template = None, None
    matching_data._target_mask, matching_data._template_mask = None, None

    setup["fftargs"] = fftargs.copy()
    convolution_mode = "same"
    if backend.sum(matching_data._target_pad) > 0:
        convolution_mode = "valid"
    setup["convolution_mode"] = convolution_mode
    setup["interpolation_order"] = interpolation_order
    rotation_list = matching_data._split_rotations_on_jobs(n_jobs)

    backend.free_cache()

    def _run_scoring(backend_name, backend_args, rotations, **kwargs):
        # Workers may be separate processes; re-activate the parent's backend.
        from tme.backends import backend

        backend.change_backend(backend_name, **backend_args)
        return matching_score(rotations=rotations, **kwargs)

    callbacks = Parallel(n_jobs=n_jobs)(
        delayed(_run_scoring)(
            backend_name=backend._backend_name,
            backend_args=backend._backend_args,
            rotations=rotation,
            callback_class=callback_classes[index % n_callback_classes],
            callback_class_args=callback_class_args,
            **setup,
        )
        for index, rotation in enumerate(rotation_list)
    )

    callbacks = [
        tuple(callback)
        for callback in callbacks[0:n_callback_classes]
        if callback is not None
    ]
    backend.free_cache()

    merged_callback = None
    if callback_class is not None:
        merged_callback = callback_class.merge(
            callbacks,
            **callback_class_args,
            score_indices=matching_data.indices,
            inner_merge=True,
        )

    return merged_callback
666
+
667
+
668
# Make the projection-based scores available to the exhaustive matcher.
register_matching_exhaustive(
    matching="CC2",
    matching_setup=corr2_setup,
    matching_scoring=corr2_scoring,
    memory_class=CCMemoryUsage,
)

register_matching_exhaustive(
    matching="CC3",
    matching_setup=corr3_setup,
    matching_scoring=corr3_scoring,
    memory_class=CCMemoryUsage,
)