pytme 0.1.1__cp311-cp311-macosx_13_0_arm64.whl → 0.1.3__cp311-cp311-macosx_13_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.1.1.data → pytme-0.1.3.data}/scripts/match_template.py +10 -9
- {pytme-0.1.1.data → pytme-0.1.3.data}/scripts/postprocess.py +6 -3
- {pytme-0.1.1.data → pytme-0.1.3.data}/scripts/preprocessor_gui.py +93 -14
- {pytme-0.1.1.dist-info → pytme-0.1.3.dist-info}/METADATA +2 -2
- {pytme-0.1.1.dist-info → pytme-0.1.3.dist-info}/RECORD +23 -22
- tme/__version__.py +1 -1
- tme/analyzer.py +2 -2
- tme/backends/pytorch_backend.py +7 -4
- tme/density.py +22 -13
- tme/matching_data.py +2 -4
- tme/matching_exhaustive.py +0 -2
- tme/matching_memory.py +1 -1
- tme/matching_optimization.py +5 -0
- tme/matching_utils.py +7 -1
- tme/preprocessor.py +62 -8
- tme/scoring.py +679 -0
- tme/structure.py +19 -20
- {pytme-0.1.1.data → pytme-0.1.3.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.1.1.data → pytme-0.1.3.data}/scripts/preprocess.py +0 -0
- {pytme-0.1.1.dist-info → pytme-0.1.3.dist-info}/LICENSE +0 -0
- {pytme-0.1.1.dist-info → pytme-0.1.3.dist-info}/WHEEL +0 -0
- {pytme-0.1.1.dist-info → pytme-0.1.3.dist-info}/entry_points.txt +0 -0
- {pytme-0.1.1.dist-info → pytme-0.1.3.dist-info}/top_level.txt +0 -0
tme/scoring.py
ADDED
@@ -0,0 +1,679 @@
from copy import deepcopy
from typing import Tuple, Callable, Dict

from joblib import Parallel, delayed

import numpy as np

from . import Preprocessor, MatchingData
from .backends import backend
from .types import NDArray, CallbackClass
from .matching_memory import CCMemoryUsage
from .analyzer import MaxScoreOverRotations
from .matching_exhaustive import register_matching_exhaustive, device_memory_handler
from .matching_utils import apply_convolution_mode, conditional_execute


def corr2_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: NDArray,
    template_mask: NDArray,
    target: NDArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    real_dtype: type,
    complex_dtype: type,
    shared_memory_handler: Callable,
    callback_class: Callable,
    callback_class_args: Dict,
    **kwargs,
) -> Dict:
    """
    Setup to compute a normalized cross-correlation score of a target f,
    a template g, and a mask m:

    .. math::

        \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
        {N_m * \\sqrt{
            \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
        }

    Where:

    .. math::

        CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)

    and :math:`N_m` is the number of voxels within the template mask m.

    References
    ----------
    .. [1] W. Wan, S. Khavnekar, J. Wagner, P. Erdmann, and W. Baumeister,
           Microsc. Microanal. 26, 2516 (2020).
    .. [2] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
           and F. Förster, J. Struct. Biol. 178, 177 (2012).

    See Also
    --------
    :py:meth:`flc_scoring`
    """
    target_pad = backend.topleft_pad(target, fast_shape)

    # Target and squared target window sums
    ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
    ft_target2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
    rfftn(target_pad, ft_target)
    rfftn(backend.square(target_pad), ft_target2)

    # Convert arrays used in subsequent fitting to SharedMemory objects
    ft_target = backend.arr_to_sharedarr(
        arr=ft_target, shared_memory_handler=shared_memory_handler
    )
    ft_target2 = backend.arr_to_sharedarr(
        arr=ft_target2, shared_memory_handler=shared_memory_handler
    )

    template_buffer = backend.arr_to_sharedarr(
        arr=template, shared_memory_handler=shared_memory_handler
    )
    template_mask_buffer = backend.arr_to_sharedarr(
        arr=template_mask, shared_memory_handler=shared_memory_handler
    )

    template_tuple = (template_buffer, template.shape, real_dtype)
    template_mask_tuple = (template_mask_buffer, template_mask.shape, real_dtype)

    target_ft_tuple = (ft_target, fast_ft_shape, complex_dtype)
    target_ft2_tuple = (ft_target2, fast_ft_shape, complex_dtype)

    ret = {
        "template": template_tuple,
        "template_mask": template_mask_tuple,
        "ft_target": target_ft_tuple,
        "ft_target2": target_ft2_tuple,
        "targetshape": target.shape,
        "templateshape": template.shape,
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "real_dtype": real_dtype,
        "complex_dtype": complex_dtype,
        "callback_class": callback_class,
        "callback_class_args": callback_class_args,
    }

    return ret
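The docstring above expresses the score through the cross-correlation identity CC(f, g) = F⁻¹(F(f) · F(g)*). As a quick, standalone illustration of that identity (a NumPy sketch for reference, not part of the packaged module; the array values are arbitrary):

    import numpy as np

    # Illustration of CC(f, g) = IFFT(FFT(f) * conj(FFT(g))) on a small 1D example.
    f = np.array([1.0, 2.0, 3.0, 4.0])
    g = np.array([0.0, 1.0, 0.5, 0.0])

    cc_fft = np.fft.ifft(np.fft.fft(f) * np.conj(np.fft.fft(g))).real

    # Direct circular cross-correlation for comparison.
    cc_direct = np.array([np.sum(f * np.roll(g, shift)) for shift in range(len(f))])

    assert np.allclose(cc_fft, cc_direct)

corr2_setup relies on this identity when it caches the Fourier transforms of the padded target and of its square, so that they can be reused across all sampled rotations.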


def corr2_scoring(
    template: Tuple[type, Tuple[int], type],
    template_mask: Tuple[type, Tuple[int], type],
    ft_target: Tuple[type, Tuple[int], type],
    ft_target2: Tuple[type, Tuple[int], type],
    template_filter: Tuple[type, Tuple[int], type],
    targetshape: Tuple[int],
    templateshape: Tuple[int],
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: NDArray,
    real_dtype: type,
    complex_dtype: type,
    callback_class: CallbackClass,
    callback_class_args: Dict,
    interpolation_order: int,
    **kwargs,
) -> CallbackClass:
    template_buffer, template_shape, template_dtype = template
    template_mask_buffer, *_ = template_mask
    filter_buffer, filter_shape, filter_dtype = template_filter

    ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
    ft_target2_buffer, *_ = ft_target2

    if callback_class is not None and isinstance(callback_class, type):
        callback = callback_class(**callback_class_args)
    elif not isinstance(callback_class, type):
        callback = callback_class

    # Retrieve objects from shared memory
    template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
    template_mask = backend.sharedarr_to_arr(
        template_shape, template_dtype, template_mask_buffer
    )
    ft_target = backend.sharedarr_to_arr(
        ft_target_shape, ft_target_dtype, ft_target_buffer
    )
    ft_target2 = backend.sharedarr_to_arr(
        ft_target_shape, ft_target_dtype, ft_target2_buffer
    )
    template_filter = backend.sharedarr_to_arr(
        filter_shape, filter_dtype, filter_buffer
    )

    arr = backend.preallocate_array(fast_shape, real_dtype)
    temp = backend.preallocate_array(fast_shape, real_dtype)
    temp2 = backend.preallocate_array(fast_shape, real_dtype)

    ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)
    ft_denom = backend.preallocate_array(fast_ft_shape, complex_dtype)

    rfftn, irfftn = backend.build_fft(
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=real_dtype,
        complex_dtype=complex_dtype,
        fftargs=kwargs.get("fftargs", {}),
        temp_real=arr,
        temp_fft=ft_temp,
    )

    templateshape = list(templateshape)
    templateshape[-1] = 1

    subset = [slice(0, x) for x in templateshape]
    subset.pop(-1)
    subset = tuple(subset)
    temp_shape = list(fast_shape)
    temp_shape[-1] = template.shape[-1]
    rotation_out = backend.preallocate_array(temp_shape, real_dtype)

    for index in range(rotations.shape[0]):
        rotation = rotations[index]
        backend.fill(arr, 0)
        backend.fill(temp, 0)
        backend.fill(rotation_out, 0)

        backend.rotate_array(
            arr=template,
            rotation_matrix=rotation,
            out=rotation_out,
            use_geometric_center=False,
            order=1,
        )
        projection = backend.sum(rotation_out, axis=-1)
        arr[..., 0] = projection

        projection_mask = backend.full(templateshape, dtype=real_dtype, fill_value=1)
        backend.fill(temp, 0)
        temp = backend.topleft_pad(projection_mask, temp.shape)

        template_mean = backend.mean(projection[subset])
        template_volume = backend.prod(projection[subset].shape)
        template_ssd = backend.sum(
            backend.square(backend.subtract(projection[subset], template_mean))
        )

        rfftn(temp, ft_temp)
        backend.multiply(ft_target, ft_temp, out=ft_denom)
        irfftn(ft_denom, temp)

        numerator = backend.multiply(temp, template_mean)

        backend.square(temp, out=temp)
        backend.divide(temp, template_volume, out=temp)
        backend.multiply(ft_target2, ft_temp, out=ft_denom)
        irfftn(ft_denom, temp2)

        backend.subtract(temp2, temp, out=temp)
        backend.multiply(temp, template_ssd, out=temp)
        backend.maximum(temp, 0.0, out=temp)
        backend.sqrt(temp, out=temp)

        denominator_mask = temp > backend.eps(temp.dtype)
        inv_denominator = backend.preallocate_array(fast_shape, real_dtype)
        inv_denominator[denominator_mask] = 1 / temp[denominator_mask]

        rfftn(arr, ft_temp)
        backend.multiply(ft_target, ft_temp, out=ft_temp)
        irfftn(ft_temp, arr)

        backend.subtract(arr, numerator, out=arr)
        backend.multiply(arr, inv_denominator, out=arr)

        convolution_mode = kwargs.get("convolution_mode", "full")
        score = apply_convolution_mode(
            arr, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
        )
        if callback_class is not None:
            callback(
                score,
                rotation_matrix=rotation,
                rotation_index=index,
                **callback_class_args,
            )

    return callback


def corr3_setup(
    rfftn: Callable,
    irfftn: Callable,
    template: NDArray,
    template_mask: NDArray,
    target: NDArray,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    real_dtype: type,
    complex_dtype: type,
    shared_memory_handler: Callable,
    callback_class: Callable,
    callback_class_args: Dict,
    **kwargs,
) -> Dict:
    target_pad = backend.topleft_pad(target, fast_shape)

    # The exact composition of the denominator is debatable.
    # scikit-image match_template multiplies the running sum of the target
    # with a scaling factor derived from the template. This is probably appropriate
    # in pattern matching situations where the template exists in the target.
    template_mask = backend.preallocate_array(
        (*template_mask.shape[:-1], 1), real_dtype
    )
    template_mask[:] = 1
    window_template = backend.topleft_pad(template_mask, fast_shape)
    ft_window_template = backend.preallocate_array(fast_ft_shape, complex_dtype)
    rfftn(window_template, ft_window_template)
    window_template = None

    # Target and squared target window sums
    ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
    ft_target2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
    denominator = backend.preallocate_array(fast_shape, real_dtype)
    target_window_sum = backend.preallocate_array(fast_shape, real_dtype)
    rfftn(target_pad, ft_target)

    rfftn(backend.square(target_pad), ft_target2)
    backend.multiply(ft_target2, ft_window_template, out=ft_target2)
    irfftn(ft_target2, denominator)

    backend.multiply(ft_target, ft_window_template, out=ft_window_template)
    irfftn(ft_window_template, target_window_sum)

    target_pad, ft_target2, ft_window_template = None, None, None

    projection = template.sum(axis=-1)
    # Normalizing constants
    template_mean = backend.mean(projection)
    template_volume = np.prod(projection.shape)
    template_ssd = backend.sum(
        backend.square(backend.subtract(projection, template_mean))
    )

    # Final numerator is score - numerator2
    numerator2 = backend.multiply(target_window_sum, template_mean)

    # Compute denominator
    backend.multiply(target_window_sum, target_window_sum, out=target_window_sum)
    backend.divide(target_window_sum, template_volume, out=target_window_sum)

    backend.subtract(denominator, target_window_sum, out=denominator)
    backend.multiply(denominator, template_ssd, out=denominator)
    backend.maximum(denominator, 0, out=denominator)
    backend.sqrt(denominator, out=denominator)
    target_window_sum = None

    # Invert denominator to compute final score as product
    denominator_mask = denominator > backend.eps(denominator.dtype)
    inv_denominator = backend.preallocate_array(fast_shape, real_dtype)
    inv_denominator[denominator_mask] = 1 / denominator[denominator_mask]

    # Convert arrays used in subsequent fitting to SharedMemory objects
    template_buffer = backend.arr_to_sharedarr(
        arr=template, shared_memory_handler=shared_memory_handler
    )
    target_ft_buffer = backend.arr_to_sharedarr(
        arr=ft_target, shared_memory_handler=shared_memory_handler
    )
    inv_denominator_buffer = backend.arr_to_sharedarr(
        arr=inv_denominator, shared_memory_handler=shared_memory_handler
    )
    numerator2_buffer = backend.arr_to_sharedarr(
        arr=numerator2, shared_memory_handler=shared_memory_handler
    )

    template_tuple = (template_buffer, deepcopy(template.shape), real_dtype)
    target_ft_tuple = (target_ft_buffer, fast_ft_shape, complex_dtype)

    inv_denominator_tuple = (inv_denominator_buffer, fast_shape, real_dtype)
    numerator2_tuple = (numerator2_buffer, fast_shape, real_dtype)

    ft_target, inv_denominator, numerator2 = None, None, None

    ret = {
        "template": template_tuple,
        "ft_target": target_ft_tuple,
        "inv_denominator": inv_denominator_tuple,
        "numerator2": numerator2_tuple,
        "targetshape": deepcopy(target.shape),
        "templateshape": deepcopy(template.shape),
        "fast_shape": fast_shape,
        "fast_ft_shape": fast_ft_shape,
        "real_dtype": real_dtype,
        "complex_dtype": complex_dtype,
        "callback_class": callback_class,
        "callback_class_args": callback_class_args,
        "template_mean": kwargs.get("template_mean", template_mean),
    }

    return ret
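The comment at the top of corr3_setup refers to scikit-image's match_template, which assembles its denominator from running window sums of the target and of its square. The underlying identity is that a windowed variance can be obtained from two convolutions via E[x²] − E[x]²; a minimal NumPy/SciPy sketch of that identity (illustrative only, variable names chosen here, not part of the module):

    import numpy as np
    from scipy.signal import fftconvolve

    # Toy 1D target and a flat window standing in for the template mask.
    target = np.random.default_rng(0).normal(size=64)
    window = np.ones(8)

    # Windowed sums of the target and of its square, obtained by FFT convolution.
    win_sum = fftconvolve(target, window, mode="full")
    win_sum2 = fftconvolve(target**2, window, mode="full")

    n = window.size
    # Windowed variance from E[x^2] - E[x]^2, clipped at zero before the square
    # root, mirroring backend.maximum(denominator, 0, ...) above.
    win_var = np.maximum(win_sum2 / n - (win_sum / n) ** 2, 0)
    denominator = n * np.sqrt(win_var)

corr3_setup performs the same two products in Fourier space (ft_target and ft_target2 multiplied with the window transform), scales by the template's sum of squared deviations, and stores the reciprocal so that each rotation later costs only one elementwise multiplication.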


def corr3_scoring(
    template: Tuple[type, Tuple[int], type],
    ft_target: Tuple[type, Tuple[int], type],
    inv_denominator: Tuple[type, Tuple[int], type],
    numerator2: Tuple[type, Tuple[int], type],
    template_filter: Tuple[type, Tuple[int], type],
    targetshape: Tuple[int],
    templateshape: Tuple[int],
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: NDArray,
    real_dtype: type,
    complex_dtype: type,
    callback_class: CallbackClass,
    callback_class_args: Dict,
    interpolation_order: int,
    convolution_mode: str = "full",
    **kwargs,
) -> CallbackClass:
    template_buffer, template_shape, template_dtype = template
    ft_target_buffer, ft_target_shape, ft_target_dtype = ft_target
    inv_denominator_buffer, inv_denominator_pointer_shape, _ = inv_denominator
    numerator2_buffer, numerator2_shape, _ = numerator2
    filter_buffer, filter_shape, filter_dtype = template_filter

    if callback_class is not None and isinstance(callback_class, type):
        callback = callback_class(**callback_class_args)
    elif not isinstance(callback_class, type):
        callback = callback_class

    # Retrieve objects from shared memory
    template = backend.sharedarr_to_arr(template_shape, template_dtype, template_buffer)
    ft_target = backend.sharedarr_to_arr(
        ft_target_shape, ft_target_dtype, ft_target_buffer
    )
    inv_denominator = backend.sharedarr_to_arr(
        inv_denominator_pointer_shape, template_dtype, inv_denominator_buffer
    )
    numerator2 = backend.sharedarr_to_arr(
        numerator2_shape, template_dtype, numerator2_buffer
    )
    template_filter = backend.sharedarr_to_arr(
        filter_shape, filter_dtype, filter_buffer
    )

    arr = backend.preallocate_array(fast_shape, real_dtype)
    ft_temp = backend.preallocate_array(fast_ft_shape, complex_dtype)

    rfftn, irfftn = backend.build_fft(
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=real_dtype,
        complex_dtype=complex_dtype,
        fftargs=kwargs.get("fftargs", {}),
        temp_real=arr,
        temp_fft=ft_temp,
    )

    norm_numerator = (backend.sum(numerator2) != 0) & (backend.size(numerator2) != 1)
    norm_denominator = (backend.sum(inv_denominator) != 1) & (
        backend.size(inv_denominator) != 1
    )
    filter_template = backend.size(template_filter) != 0

    norm_func_numerator = conditional_execute(backend.subtract, norm_numerator)
    norm_func_denominator = conditional_execute(backend.multiply, norm_denominator)
    template_filter_func = conditional_execute(backend.multiply, filter_template)

    rotation_out = backend.preallocate_array(
        (*fast_shape[:-1], template.shape[-1]), real_dtype
    )
    templateshape = list(templateshape)
    templateshape[-1] = 1

    for index in range(rotations.shape[0]):
        rotation = rotations[index]
        backend.fill(arr, 0)
        backend.rotate_array(
            arr=template,
            rotation_matrix=rotation,
            out=rotation_out,
            use_geometric_center=False,
            order=interpolation_order,
        )
        projection = backend.sum(rotation_out, axis=-1)
        arr[..., 0] = projection

        rfftn(arr, ft_temp)
        template_filter_func(ft_temp, template_filter, out=ft_temp)

        backend.multiply(ft_target, ft_temp, out=ft_temp)
        irfftn(ft_temp, arr)

        norm_func_numerator(arr, numerator2, out=arr)
        norm_func_denominator(arr, inv_denominator, out=arr)

        score = apply_convolution_mode(
            arr, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
        )
        if callback_class is not None:
            callback(
                score,
                rotation_matrix=rotation,
                rotation_index=index,
                **callback_class_args,
            )

    return callback


@device_memory_handler
def scan(
    matching_data: MatchingData,
    matching_setup: Callable,
    matching_score: Callable,
    n_jobs: int = 4,
    callback_class: CallbackClass = None,
    callback_class_args: Dict = {},
    fftargs: Dict = {},
    pad_fourier: bool = True,
    interpolation_order: int = 3,
    jobs_per_callback_class: int = 8,
    **kwargs,
) -> Tuple:
    """
    Perform template matching between target and template and sample
    different rotations of the template.

    Parameters
    ----------
    matching_data : MatchingData
        Template matching data.
    matching_setup : Callable
        Function pointer to setup function.
    matching_score : Callable
        Function pointer to scoring function.
    n_jobs : int, optional
        Number of parallel jobs. Default is 4.
    callback_class : type, optional
        Analyzer class pointer to operate on computed scores.
    callback_class_args : dict, optional
        Arguments passed to the callback_class. Default is an empty dictionary.
    fftargs : dict, optional
        Arguments for the FFT operations. Default is an empty dictionary.
    pad_fourier : bool, optional
        Whether to pad target and template to the full convolution shape.
    interpolation_order : int, optional
        Order of spline interpolation for rotations.
    jobs_per_callback_class : int, optional
        How many jobs should be processed by a single callback_class instance,
        if one is provided.
    **kwargs : various
        Additional arguments.

    Returns
    -------
    Tuple
        The merged results from callback_class if provided, otherwise None.
    """
    matching_data.to_backend()
    fourier_pad = list(matching_data._templateshape)
    fourier_pad[-1] = 1
    fourier_shift = backend.zeros(len(fourier_pad))
    if not pad_fourier:
        fourier_pad = backend.full(shape=fourier_shift.shape, fill_value=1, dtype=int)
        fourier_shift = 1 - backend.astype(
            backend.divide(matching_data._templateshape, 2), int
        )
        callback_class_args["fourier_shift"] = fourier_shift

    _, fast_shape, fast_ft_shape = backend.compute_convolution_shapes(
        matching_data._target.shape, fourier_pad
    )
    rfftn, irfftn = backend.build_fft(
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=matching_data._default_dtype,
        complex_dtype=matching_data._complex_dtype,
        fftargs=fftargs,
    )
    setup = matching_setup(
        rfftn=rfftn,
        irfftn=irfftn,
        template=matching_data.template,
        template_mask=matching_data.template_mask,
        target=matching_data.target,
        target_mask=matching_data.target_mask,
        fast_shape=fast_shape,
        fast_ft_shape=fast_ft_shape,
        real_dtype=matching_data._default_dtype,
        complex_dtype=matching_data._complex_dtype,
        callback_class=callback_class,
        callback_class_args=callback_class_args,
        **kwargs,
    )
    rfftn, irfftn = None, None

    template_filter, preprocessor = None, Preprocessor()
    for method, parameters in matching_data.template_filter.items():
        parameters["shape"] = fast_shape
        parameters["omit_negative_frequencies"] = True
        out = preprocessor.apply_method(method=method, parameters=parameters)
        if template_filter is None:
            template_filter = out
        else:
            np.multiply(template_filter, out, out=template_filter)

    if template_filter is None:
        template_filter = backend.full(
            shape=(1,), fill_value=1, dtype=backend._default_dtype
        )
    else:
        template_filter = backend.to_backend_array(template_filter)

    template_filter = backend.astype(template_filter, backend._default_dtype)
    template_filter_buffer = backend.arr_to_sharedarr(
        arr=template_filter,
        shared_memory_handler=kwargs.get("shared_memory_handler", None),
    )
    setup["template_filter"] = (
        template_filter_buffer,
        template_filter.shape,
        template_filter.dtype,
    )

    callback_class_args["translation_offset"] = backend.astype(
        matching_data._translation_offset, int
    )
    callback_class_args["thread_safe"] = n_jobs > 1
    callback_class_args["gpu_index"] = kwargs.get("gpu_index", -1)

    n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
    callback_class = setup.pop("callback_class", callback_class)
    callback_class_args = setup.pop("callback_class_args", callback_class_args)
    callback_classes = [callback_class for _ in range(n_callback_classes)]
    if callback_class == MaxScoreOverRotations:
        score_space_shape = backend.subtract(
            matching_data.target.shape,
            matching_data._target_pad,
        )
        callback_classes = [
            class_name(
                score_space_shape=score_space_shape,
                score_space_dtype=matching_data._default_dtype,
                shared_memory_handler=kwargs.get("shared_memory_handler", None),
                rotation_space_dtype=backend._default_dtype_int,
                **callback_class_args,
            )
            for class_name in callback_classes
        ]

    matching_data._target, matching_data._template = None, None
    matching_data._target_mask, matching_data._template_mask = None, None

    setup["fftargs"] = fftargs.copy()
    convolution_mode = "same"
    if backend.sum(matching_data._target_pad) > 0:
        convolution_mode = "valid"
    setup["convolution_mode"] = convolution_mode
    setup["interpolation_order"] = interpolation_order
    rotation_list = matching_data._split_rotations_on_jobs(n_jobs)

    backend.free_cache()

    def _run_scoring(backend_name, backend_args, rotations, **kwargs):
        from tme.backends import backend

        backend.change_backend(backend_name, **backend_args)
        return matching_score(rotations=rotations, **kwargs)

    callbacks = Parallel(n_jobs=n_jobs)(
        delayed(_run_scoring)(
            backend_name=backend._backend_name,
            backend_args=backend._backend_args,
            rotations=rotation,
            callback_class=callback_classes[index % n_callback_classes],
            callback_class_args=callback_class_args,
            **setup,
        )
        for index, rotation in enumerate(rotation_list)
    )

    callbacks = [
        tuple(callback)
        for callback in callbacks[0:n_callback_classes]
        if callback is not None
    ]
    backend.free_cache()

    merged_callback = None
    if callback_class is not None:
        merged_callback = callback_class.merge(
            callbacks,
            **callback_class_args,
            score_indices=matching_data.indices,
            inner_merge=True,
        )

    return merged_callback
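The scan driver documented above wires a setup/scoring pair together: it plans the FFTs, moves precomputed arrays into shared memory, splits the rotation set across joblib workers, and merges the per-worker callbacks. A rough usage sketch of the scores defined in this module (the shapes, the MatchingData construction, and the rotation assignment are assumptions modeled on the package's match_template.py script, not taken from this diff):

    import numpy as np

    from tme import MatchingData
    from tme.analyzer import MaxScoreOverRotations
    from tme.scoring import corr3_setup, corr3_scoring, scan

    # Hypothetical data: an image-like target with a trailing singleton axis and a
    # cubic template whose projections along the last axis are scored against it.
    target = np.random.rand(128, 128, 1).astype(np.float32)
    template = np.random.rand(32, 32, 32).astype(np.float32)

    matching_data = MatchingData(target=target, template=template)
    # Single identity rotation for brevity; real runs would pass a stack of 3x3 matrices.
    matching_data.rotations = np.eye(3)[None]

    result = scan(
        matching_data=matching_data,
        matching_setup=corr3_setup,
        matching_score=corr3_scoring,
        callback_class=MaxScoreOverRotations,
        n_jobs=1,
    )

With callback_class=MaxScoreOverRotations, the merged result is intended to hold, per position, the best score over all sampled rotations together with the index of the rotation that produced it.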


register_matching_exhaustive(
    matching="CC2",
    matching_setup=corr2_setup,
    matching_scoring=corr2_scoring,
    memory_class=CCMemoryUsage,
)

register_matching_exhaustive(
    matching="CC3",
    matching_setup=corr3_setup,
    matching_scoring=corr3_scoring,
    memory_class=CCMemoryUsage,
)