pytme 0.3.1.post2__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.2.data/scripts/estimate_ram_usage.py +97 -0
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/match_template.py +213 -196
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/postprocess.py +40 -78
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/preprocess.py +4 -5
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/preprocessor_gui.py +49 -103
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/pytme_runner.py +46 -69
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/METADATA +3 -2
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/RECORD +68 -65
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +213 -196
- scripts/match_template_devel.py +1339 -0
- scripts/postprocess.py +40 -78
- scripts/preprocess.py +4 -5
- scripts/preprocessor_gui.py +49 -103
- scripts/pytme_runner.py +46 -69
- tests/preprocessing/test_compose.py +31 -30
- tests/preprocessing/test_frequency_filters.py +17 -32
- tests/preprocessing/test_preprocessor.py +0 -19
- tests/preprocessing/test_utils.py +13 -1
- tests/test_analyzer.py +2 -10
- tests/test_backends.py +47 -18
- tests/test_density.py +72 -13
- tests/test_extensions.py +1 -0
- tests/test_matching_cli.py +23 -9
- tests/test_matching_exhaustive.py +5 -5
- tests/test_matching_utils.py +3 -3
- tests/test_orientations.py +12 -0
- tests/test_rotations.py +13 -23
- tests/test_structure.py +1 -7
- tme/__version__.py +1 -1
- tme/analyzer/aggregation.py +47 -16
- tme/analyzer/base.py +34 -0
- tme/analyzer/peaks.py +26 -13
- tme/analyzer/proxy.py +14 -0
- tme/backends/_jax_utils.py +91 -68
- tme/backends/cupy_backend.py +6 -19
- tme/backends/jax_backend.py +103 -98
- tme/backends/matching_backend.py +0 -17
- tme/backends/mlx_backend.py +0 -29
- tme/backends/npfftw_backend.py +100 -97
- tme/backends/pytorch_backend.py +65 -78
- tme/cli.py +2 -2
- tme/density.py +44 -57
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/_utils.py +52 -24
- tme/filters/bandpass.py +99 -105
- tme/filters/compose.py +133 -39
- tme/filters/ctf.py +51 -102
- tme/filters/reconstruction.py +67 -122
- tme/filters/wedge.py +296 -325
- tme/filters/whitening.py +39 -75
- tme/mask.py +2 -2
- tme/matching_data.py +87 -15
- tme/matching_exhaustive.py +70 -120
- tme/matching_optimization.py +9 -63
- tme/matching_scores.py +261 -100
- tme/matching_utils.py +150 -91
- tme/memory.py +1 -0
- tme/orientations.py +17 -3
- tme/preprocessor.py +0 -239
- tme/rotations.py +102 -70
- tme/structure.py +601 -631
- tme/types.py +1 -0
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/WHEEL +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/entry_points.txt +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/top_level.txt +0 -0
tme/matching_scores.py
CHANGED
@@ -10,14 +10,14 @@ import warnings
|
|
10
10
|
from typing import Callable, Tuple, Dict
|
11
11
|
|
12
12
|
import numpy as np
|
13
|
-
from scipy.ndimage import laplace
|
14
13
|
|
15
14
|
from .backends import backend as be
|
16
15
|
from .types import CallbackClass, BackendArray, shm_type
|
17
16
|
from .matching_utils import (
|
18
17
|
conditional_execute,
|
19
18
|
identity,
|
20
|
-
|
19
|
+
standardize,
|
20
|
+
to_padded,
|
21
21
|
)
|
22
22
|
|
23
23
|
|
@@ -46,7 +46,7 @@ def cc_setup(
|
|
46
46
|
)
|
47
47
|
axes = matching_data._batch_axis(matching_data._batch_mask)
|
48
48
|
|
49
|
-
|
49
|
+
return {
|
50
50
|
"template": be.to_sharedarr(matching_data.template, shm_handler),
|
51
51
|
"ft_target": be.to_sharedarr(be.rfftn(target_pad, axes=axes), shm_handler),
|
52
52
|
"inv_denominator": be.to_sharedarr(
|
@@ -55,8 +55,6 @@ def cc_setup(
|
|
55
55
|
"numerator": be.to_sharedarr(be.zeros(1, be._float_dtype), shm_handler),
|
56
56
|
}
|
57
57
|
|
58
|
-
return ret
|
59
|
-
|
60
58
|
|
61
59
|
def lcc_setup(matching_data, **kwargs) -> Dict:
|
62
60
|
"""
|
@@ -71,26 +69,8 @@ def lcc_setup(matching_data, **kwargs) -> Dict:
|
|
71
69
|
-----
|
72
70
|
To be used with :py:meth:`corr_scoring`.
|
73
71
|
"""
|
74
|
-
target =
|
75
|
-
template =
|
76
|
-
|
77
|
-
subsets = matching_data._batch_iter(
|
78
|
-
target.shape,
|
79
|
-
tuple(1 if i in matching_data._target_dim else 0 for i in range(target.ndim)),
|
80
|
-
)
|
81
|
-
for subset in subsets:
|
82
|
-
target[subset] = laplace(target[subset], mode="wrap")
|
83
|
-
|
84
|
-
subsets = matching_data._batch_iter(
|
85
|
-
template.shape,
|
86
|
-
tuple(1 if i in matching_data._template_dim else 0 for i in range(target.ndim)),
|
87
|
-
)
|
88
|
-
for subset in subsets:
|
89
|
-
template[subset] = laplace(template[subset], mode="wrap")
|
90
|
-
|
91
|
-
matching_data._target = be.to_backend_array(target)
|
92
|
-
matching_data._template = be.to_backend_array(template)
|
93
|
-
|
72
|
+
matching_data.target = matching_data.transform_target("laplace")
|
73
|
+
matching_data.template = matching_data.transform_template("laplace")
|
94
74
|
return cc_setup(matching_data=matching_data, **kwargs)
|
95
75
|
|
96
76
|
|
@@ -184,19 +164,17 @@ def corr_setup(
|
|
184
164
|
denominator = be.divide(1, denominator, out=denominator)
|
185
165
|
denominator = be.multiply(denominator, mask, out=denominator)
|
186
166
|
|
187
|
-
|
167
|
+
return {
|
188
168
|
"template": be.to_sharedarr(template, shm_handler),
|
189
169
|
"ft_target": be.to_sharedarr(ft_target, shm_handler),
|
190
170
|
"inv_denominator": be.to_sharedarr(denominator, shm_handler),
|
191
171
|
"numerator": be.to_sharedarr(numerator, shm_handler),
|
192
172
|
}
|
193
173
|
|
194
|
-
return ret
|
195
|
-
|
196
174
|
|
197
175
|
def cam_setup(matching_data, **kwargs) -> Dict:
|
198
176
|
"""
|
199
|
-
Like :py:meth:`corr_setup` but with standardized ``target
|
177
|
+
Like :py:meth:`corr_setup` but with standardized ``target`` and ``template``
|
200
178
|
|
201
179
|
.. math::
|
202
180
|
|
@@ -206,19 +184,14 @@ def cam_setup(matching_data, **kwargs) -> Dict:
|
|
206
184
|
-----
|
207
185
|
To be used with :py:meth:`corr_scoring`.
|
208
186
|
"""
|
209
|
-
|
210
|
-
|
211
|
-
matching_data
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
target = matching_data.
|
216
|
-
|
217
|
-
matching_data._target = be.divide(
|
218
|
-
be.subtract(target, be.mean(target, axis=axis, keepdims=True)),
|
219
|
-
be.std(target, axis=axis, keepdims=True),
|
220
|
-
)
|
221
|
-
return corr_setup(matching_data=matching_data, **kwargs)
|
187
|
+
matching_data.target = matching_data.transform_target("standardize")
|
188
|
+
matching_data.template = matching_data.transform_template("standardize")
|
189
|
+
return flcSphericalMask_setup(matching_data=matching_data, **kwargs)
|
190
|
+
|
191
|
+
|
192
|
+
def ncc_setup(matching_data, **kwargs) -> Dict:
    """
    Standardize the target and delegate the remaining setup to :py:meth:`cc_setup`.

    Parameters
    ----------
    matching_data : object
        Matching data container. Its ``target`` attribute is overwritten with
        the result of ``matching_data.transform_target("standardize")``.
    **kwargs : dict, optional
        Forwarded unchanged to :py:meth:`cc_setup`.

    Returns
    -------
    Dict
        The dictionary returned by :py:meth:`cc_setup`.

    Notes
    -----
    Only the target is transformed here; the template is left untouched.
    Presumably meant to be paired with :py:meth:`ncc_scoring` -- confirm
    against the score registry before relying on it.
    """
    standardized_target = matching_data.transform_target("standardize")
    matching_data.target = standardized_target
    return cc_setup(matching_data=matching_data, **kwargs)
|
222
195
|
|
223
196
|
|
224
197
|
def flc_setup(
|
@@ -242,15 +215,13 @@ def flc_setup(
|
|
242
215
|
target_pad = be.square(target_pad, out=target_pad)
|
243
216
|
ft_target2 = be.rfftn(target_pad, axes=data_axes)
|
244
217
|
|
245
|
-
|
218
|
+
return {
|
246
219
|
"template": be.to_sharedarr(matching_data.template, shm_handler),
|
247
220
|
"template_mask": be.to_sharedarr(matching_data.template_mask, shm_handler),
|
248
221
|
"ft_target": be.to_sharedarr(ft_target, shm_handler),
|
249
222
|
"ft_target2": be.to_sharedarr(ft_target2, shm_handler),
|
250
223
|
}
|
251
224
|
|
252
|
-
return ret
|
253
|
-
|
254
225
|
|
255
226
|
def flcSphericalMask_setup(
|
256
227
|
matching_data,
|
@@ -302,7 +273,7 @@ def flcSphericalMask_setup(
|
|
302
273
|
temp = be.irfftn(ft_temp, s=data_shape, axes=data_axes)
|
303
274
|
|
304
275
|
temp2 = be.norm_scores(1, temp2, temp, n_obs, be.eps(be._float_dtype), temp2)
|
305
|
-
|
276
|
+
return {
|
306
277
|
"template": be.to_sharedarr(matching_data.template, shm_handler),
|
307
278
|
"template_mask": be.to_sharedarr(template_mask, shm_handler),
|
308
279
|
"ft_target": be.to_sharedarr(ft_target, shm_handler),
|
@@ -310,8 +281,6 @@ def flcSphericalMask_setup(
|
|
310
281
|
"numerator": be.to_sharedarr(be.zeros(1, be._float_dtype), shm_handler),
|
311
282
|
}
|
312
283
|
|
313
|
-
return ret
|
314
|
-
|
315
284
|
|
316
285
|
def mcc_setup(
|
317
286
|
matching_data,
|
@@ -331,7 +300,7 @@ def mcc_setup(
|
|
331
300
|
target = be.topleft_pad(target, shape)
|
332
301
|
target_mask = be.topleft_pad(target_mask, shape)
|
333
302
|
|
334
|
-
|
303
|
+
return {
|
335
304
|
"template": be.to_sharedarr(matching_data.template, shm_handler),
|
336
305
|
"template_mask": be.to_sharedarr(matching_data.template_mask, shm_handler),
|
337
306
|
"ft_target": be.to_sharedarr(be.rfftn(target, axes=ax), shm_handler),
|
@@ -341,8 +310,6 @@ def mcc_setup(
|
|
341
310
|
"ft_target_mask": be.to_sharedarr(be.rfftn(target_mask, axes=ax), shm_handler),
|
342
311
|
}
|
343
312
|
|
344
|
-
return ret
|
345
|
-
|
346
313
|
|
347
314
|
def corr_scoring(
|
348
315
|
template: shm_type,
|
@@ -357,6 +324,7 @@ def corr_scoring(
|
|
357
324
|
interpolation_order: int,
|
358
325
|
template_mask: shm_type = None,
|
359
326
|
score_mask: shm_type = None,
|
327
|
+
template_background: shm_type = None,
|
360
328
|
) -> CallbackClass:
|
361
329
|
"""
|
362
330
|
Calculates a normalized cross-correlation between a target f and a template g.
|
@@ -414,7 +382,7 @@ def corr_scoring(
|
|
414
382
|
template_mask = be.from_sharedarr(template_mask)
|
415
383
|
n_obs = be.sum(template_mask) if template_mask is not None else None
|
416
384
|
|
417
|
-
norm_template = conditional_execute(
|
385
|
+
norm_template = conditional_execute(standardize, n_obs is not None)
|
418
386
|
norm_sub = conditional_execute(be.subtract, numerator.shape != (1,))
|
419
387
|
norm_mul = conditional_execute(be.multiply, inv_denominator.shape != (1))
|
420
388
|
norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
|
@@ -423,30 +391,41 @@ def corr_scoring(
|
|
423
391
|
ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)
|
424
392
|
template_rot = be.zeros(template.shape, be._float_dtype)
|
425
393
|
|
426
|
-
|
427
|
-
arr_shape=template.shape,
|
428
|
-
filter_shape=template_filter.shape,
|
429
|
-
)
|
394
|
+
tmpl_filter_func = _create_filter_func(template.shape, template_filter)
|
430
395
|
|
431
396
|
center = be.divide(be.to_backend_array(template.shape) - 1, 2)
|
432
397
|
unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
|
398
|
+
|
399
|
+
background_correction = template_background is not None
|
400
|
+
if background_correction:
|
401
|
+
scores_alt, compute_norm = _setup_background_correction(
|
402
|
+
fast_shape=fast_shape,
|
403
|
+
template_background=template_background,
|
404
|
+
rotation_buffer=template_rot,
|
405
|
+
unpadded_slice=unpadded_slice,
|
406
|
+
interpolation_order=interpolation_order,
|
407
|
+
tmpl_filter_func=tmpl_filter_func,
|
408
|
+
norm_template=norm_template,
|
409
|
+
)
|
410
|
+
|
433
411
|
for index in range(rotations.shape[0]):
|
434
412
|
rotation = rotations[index]
|
435
|
-
matrix = be.
|
413
|
+
matrix = be._build_transform_matrix(
|
414
|
+
rotation_matrix=rotation, center=center, shape=template.shape
|
415
|
+
)
|
436
416
|
_ = be.rigid_transform(
|
437
417
|
arr=template,
|
438
418
|
rotation_matrix=matrix,
|
439
419
|
out=template_rot,
|
440
420
|
order=interpolation_order,
|
441
421
|
cache=True,
|
422
|
+
use_geometric_center=True,
|
442
423
|
)
|
443
424
|
|
444
|
-
template_rot =
|
445
|
-
norm_template(template_rot, template_mask, n_obs)
|
446
|
-
|
447
|
-
arr = be.fill(arr, 0)
|
448
|
-
arr[unpadded_slice] = template_rot
|
425
|
+
template_rot = tmpl_filter_func(template_rot, ft_temp)
|
426
|
+
template_rot = norm_template(template_rot, template_mask, n_obs)
|
449
427
|
|
428
|
+
arr = to_padded(arr, template_rot, unpadded_slice)
|
450
429
|
ft_temp = be.rfftn(arr, s=fast_shape, out=ft_temp)
|
451
430
|
arr = _correlate_fts(ft_target, ft_temp, ft_temp, arr, fast_shape)
|
452
431
|
|
@@ -455,6 +434,17 @@ def corr_scoring(
|
|
455
434
|
arr = norm_mask(arr, score_mask, out=arr)
|
456
435
|
|
457
436
|
callback(arr, rotation_matrix=rotation)
|
437
|
+
if background_correction:
|
438
|
+
arr = compute_norm(arr, ft_target, ft_temp, matrix, template_mask, n_obs)
|
439
|
+
arr = norm_sub(arr, numerator, out=arr)
|
440
|
+
arr = norm_mul(arr, inv_denominator, out=arr)
|
441
|
+
arr = norm_mask(arr, score_mask, out=arr)
|
442
|
+
scores_alt = be.maximum(arr, scores_alt, out=scores_alt)
|
443
|
+
|
444
|
+
if background_correction:
|
445
|
+
scores_alt = norm_mask(scores_alt, score_mask, out=scores_alt)
|
446
|
+
scores_alt = be.subtract(scores_alt, be.mean(scores_alt), out=scores_alt)
|
447
|
+
callback.correct_background(scores_alt)
|
458
448
|
|
459
449
|
return callback
|
460
450
|
|
@@ -471,6 +461,7 @@ def flc_scoring(
|
|
471
461
|
callback: CallbackClass,
|
472
462
|
interpolation_order: int,
|
473
463
|
score_mask: shm_type = None,
|
464
|
+
template_background: shm_type = None,
|
474
465
|
) -> CallbackClass:
|
475
466
|
"""
|
476
467
|
Computes a normalized cross-correlation between ``target`` (f),
|
@@ -511,8 +502,6 @@ def flc_scoring(
|
|
511
502
|
Rotation matrices to be sampled (n, d, d).
|
512
503
|
callback : CallbackClass
|
513
504
|
A callable for processing the result of each rotation.
|
514
|
-
callback_class_args : Dict
|
515
|
-
Dictionary of arguments to be passed to ``callback``.
|
516
505
|
interpolation_order : int
|
517
506
|
Spline order for template rotations.
|
518
507
|
|
@@ -524,7 +513,6 @@ def flc_scoring(
|
|
524
513
|
----------
|
525
514
|
.. [1] Hrabe T. et al, J. Struct. Biol. 178, 177 (2012).
|
526
515
|
"""
|
527
|
-
float_dtype, complex_dtype = be._float_dtype, be._complex_dtype
|
528
516
|
template = be.from_sharedarr(template)
|
529
517
|
template_mask = be.from_sharedarr(template_mask)
|
530
518
|
ft_target = be.from_sharedarr(ft_target)
|
@@ -532,42 +520,55 @@ def flc_scoring(
|
|
532
520
|
template_filter = be.from_sharedarr(template_filter)
|
533
521
|
score_mask = be.from_sharedarr(score_mask)
|
534
522
|
|
535
|
-
arr = be.zeros(fast_shape,
|
536
|
-
temp = be.zeros(fast_shape,
|
537
|
-
temp2 = be.zeros(fast_shape,
|
538
|
-
ft_temp = be.zeros(fast_ft_shape,
|
539
|
-
ft_denom = be.zeros(fast_ft_shape,
|
523
|
+
arr = be.zeros(fast_shape, be._float_dtype)
|
524
|
+
temp = be.zeros(fast_shape, be._float_dtype)
|
525
|
+
temp2 = be.zeros(fast_shape, be._float_dtype)
|
526
|
+
ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)
|
527
|
+
ft_denom = be.zeros(fast_ft_shape, be._complex_dtype)
|
540
528
|
template_rot = be.zeros(template.shape, be._float_dtype)
|
541
529
|
template_mask_rot = be.zeros(template.shape, be._float_dtype)
|
542
530
|
|
543
|
-
tmpl_filter_func = _create_filter_func(template.shape, template_filter
|
531
|
+
tmpl_filter_func = _create_filter_func(template.shape, template_filter)
|
544
532
|
norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
|
545
533
|
|
546
|
-
eps = be.eps(
|
534
|
+
eps = be.eps(be._float_dtype)
|
547
535
|
center = be.divide(be.to_backend_array(template.shape) - 1, 2)
|
548
536
|
unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
|
537
|
+
|
538
|
+
background_correction = template_background is not None
|
539
|
+
if background_correction:
|
540
|
+
scores_alt, compute_norm = _setup_background_correction(
|
541
|
+
fast_shape=fast_shape,
|
542
|
+
template_background=template_background,
|
543
|
+
rotation_buffer=template_rot,
|
544
|
+
unpadded_slice=unpadded_slice,
|
545
|
+
interpolation_order=interpolation_order,
|
546
|
+
tmpl_filter_func=tmpl_filter_func,
|
547
|
+
norm_template=standardize,
|
548
|
+
)
|
549
|
+
|
549
550
|
for index in range(rotations.shape[0]):
|
550
551
|
rotation = rotations[index]
|
551
|
-
matrix = be.
|
552
|
+
matrix = be._build_transform_matrix(
|
553
|
+
rotation_matrix=rotation, center=center, shape=template.shape
|
554
|
+
)
|
552
555
|
_ = be.rigid_transform(
|
553
556
|
arr=template,
|
554
557
|
arr_mask=template_mask,
|
555
558
|
rotation_matrix=matrix,
|
556
559
|
out=template_rot,
|
557
560
|
out_mask=template_mask_rot,
|
558
|
-
use_geometric_center=True,
|
559
561
|
order=interpolation_order,
|
560
562
|
cache=True,
|
563
|
+
use_geometric_center=True,
|
561
564
|
)
|
562
565
|
|
563
566
|
n_obs = be.sum(template_mask_rot)
|
564
|
-
template_rot = tmpl_filter_func(template_rot, ft_temp
|
565
|
-
template_rot =
|
567
|
+
template_rot = tmpl_filter_func(template_rot, ft_temp)
|
568
|
+
template_rot = standardize(template_rot, template_mask_rot, n_obs)
|
566
569
|
|
567
|
-
arr =
|
568
|
-
temp =
|
569
|
-
arr[unpadded_slice] = template_rot
|
570
|
-
temp[unpadded_slice] = template_mask_rot
|
570
|
+
arr = to_padded(arr, template_rot, unpadded_slice)
|
571
|
+
temp = to_padded(temp, template_mask_rot, unpadded_slice)
|
571
572
|
|
572
573
|
ft_temp = be.rfftn(temp, out=ft_temp, s=fast_shape)
|
573
574
|
temp = _correlate_fts(ft_target, ft_temp, ft_denom, temp, fast_shape)
|
@@ -576,14 +577,112 @@ def flc_scoring(
|
|
576
577
|
ft_temp = be.rfftn(arr, out=ft_temp, s=fast_shape)
|
577
578
|
arr = _correlate_fts(ft_target, ft_temp, ft_temp, arr, fast_shape)
|
578
579
|
|
579
|
-
|
580
|
+
inv_sdev = be.norm_scores(1, temp2, temp, n_obs, eps, temp2)
|
581
|
+
arr = be.multiply(arr, inv_sdev, out=arr)
|
580
582
|
arr = norm_mask(arr, score_mask, out=arr)
|
581
583
|
|
582
584
|
callback(arr, rotation_matrix=rotation)
|
585
|
+
if background_correction:
|
586
|
+
arr = compute_norm(arr, ft_target, ft_temp, matrix, template_mask, n_obs)
|
587
|
+
arr = be.multiply(arr, inv_sdev, out=arr)
|
588
|
+
scores_alt = be.maximum(arr, scores_alt, out=scores_alt)
|
589
|
+
|
590
|
+
if background_correction:
|
591
|
+
scores_alt = norm_mask(scores_alt, score_mask, out=scores_alt)
|
592
|
+
scores_alt = be.subtract(scores_alt, be.mean(scores_alt), out=scores_alt)
|
593
|
+
callback.correct_background(scores_alt)
|
583
594
|
|
584
595
|
return callback
|
585
596
|
|
586
597
|
|
598
|
+
def ncc_scoring(
    template: shm_type,
    ft_target: shm_type,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: BackendArray,
    callback: CallbackClass,
    interpolation_order: int,
    template_filter: shm_type = None,
    score_mask: shm_type = None,
    template_background: shm_type = None,
    **kwargs,
) -> CallbackClass:
    """
    Correlate a rotated, standardized template with ``ft_target`` and track
    per-pixel score statistics over all rotations.

    For every rotation the template is transformed, filtered, standardized,
    padded and correlated with the target in Fourier space. The resulting
    score array is handed to ``callback`` and folded into a running per-pixel
    mean and sum of squared deviations (Welford update). After the loop, the
    mean of the per-pixel means and the square root of the mean per-pixel
    variance are passed to ``callback.correct_background``.

    Parameters
    ----------
    template : shm_type
        Template data, loaded via ``be.from_sharedarr``.
    ft_target : shm_type
        Fourier transform of the target, loaded via ``be.from_sharedarr``.
        It is scaled by ``1 / be.size(arr)`` before use.
    fast_shape : tuple of int
        Shape of the real-valued score buffers.
    fast_ft_shape : tuple of int
        Shape of the complex-valued Fourier buffer.
    rotations : BackendArray
        Rotation matrices to be sampled, indexed along the first axis.
    callback : CallbackClass
        Called as ``callback(scores, rotation_matrix=rotation)`` for each
        rotation; must also provide ``correct_background``.
    interpolation_order : int
        Spline order for template rotations.
    template_filter : shm_type, optional
        Filter forwarded to ``_create_filter_func``.
    score_mask : shm_type, optional
        Multiplied onto the scores unless its shape is ``(1,)``.
    template_background : shm_type, optional
        Only tested against ``None`` in this function. If not ``None``, a
        second ``correct_background`` call with per-pixel statistics is made.
    **kwargs : dict, optional
        Accepted and ignored.

    Returns
    -------
    CallbackClass
        The ``callback`` object that was passed in.
    """
    template = be.from_sharedarr(template)
    ft_target = be.from_sharedarr(ft_target)
    score_mask = be.from_sharedarr(score_mask)
    template_filter = be.from_sharedarr(template_filter)

    # Work buffers reused across rotations
    arr = be.zeros(fast_shape, be._float_dtype)
    ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)
    template_rot = be.zeros(template.shape, be._float_dtype)

    # Welford accumulators: running mean and sum of squared deviations
    pixel_mean = be.zeros(fast_shape, be._float_dtype)
    pixel_m2 = be.zeros(fast_shape, be._float_dtype)

    tmpl_filter_func = _create_filter_func(template.shape, template_filter)
    apply_score_mask = conditional_execute(be.multiply, score_mask.shape != (1,))

    n_template_voxels = be.size(template)
    center = be.divide(be.to_backend_array(template.shape) - 1, 2)
    unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
    n_rotations = rotations.shape[0]

    # Scale forward transform by 1/n i.e. norm 'forward'
    ft_target = be.multiply(ft_target, 1 / be.size(arr))

    for index in range(n_rotations):
        arr = be.fill(arr, 0)
        rotation = rotations[index]
        matrix = be._build_transform_matrix(
            rotation_matrix=rotation, center=center, shape=template.shape
        )

        be.rigid_transform(
            template,
            rotation_matrix=matrix,
            out=template_rot,
            order=interpolation_order,
            cache=True,
            use_geometric_center=True,
        )
        template_rot = tmpl_filter_func(template_rot, ft_temp)
        # A scalar 1 is passed where other scoring functions pass a mask
        template_rot = standardize(template_rot, 1, n_template_voxels)

        arr = to_padded(arr, template_rot, unpadded_slice)
        ft_temp = be.rfftn(arr, s=fast_shape, norm="forward")
        ft_temp = be.multiply(ft_temp, ft_target, out=ft_temp)

        arr = be.irfftn(ft_temp, s=fast_shape, norm="forward")
        arr = apply_score_mask(arr, score_mask, out=arr)
        callback(arr, rotation_matrix=rotation)

        # Welford update: the deviation from the old mean times the deviation
        # from the new mean is added to the squared-deviation accumulator.
        delta = be.subtract(arr, pixel_mean)
        pixel_mean = be.add(pixel_mean, be.divide(delta, index + 1), out=pixel_mean)
        delta = be.multiply(delta, be.subtract(arr, pixel_mean), out=delta)
        pixel_m2 = be.add(pixel_m2, delta, out=pixel_m2)

    global_mean = be.mean(pixel_mean)
    # NOTE(review): the divisor is n_rotations - 1, which is zero when a single
    # rotation is sampled -- confirm callers always pass more than one.
    pixel_variance = be.divide(pixel_m2, n_rotations - 1)
    global_std = be.sqrt(be.mean(pixel_variance))

    callback.correct_background(global_mean, global_std)
    if template_background is not None:
        # Adapt units for local normalization
        pixel_mean = be.subtract(pixel_mean, global_mean, out=pixel_mean)
        pixel_mean = be.divide(pixel_mean, global_std, out=pixel_mean)

        # Reuses the variance buffer for the standard deviation
        pixel_std = be.sqrt(pixel_variance, out=pixel_variance)
        pixel_std = be.divide(pixel_std, global_std, out=pixel_variance)

        # Inverse standard deviation, zeroed where the deviation is <= 1e-4
        pixel_std = be.where(pixel_std > 1e-4, 1 / pixel_std, 0.0)
        callback.correct_background(pixel_mean, pixel_std)
    return callback
|
684
|
+
|
685
|
+
|
587
686
|
def mcc_scoring(
|
588
687
|
template: shm_type,
|
589
688
|
template_mask: shm_type,
|
@@ -597,7 +696,7 @@ def mcc_scoring(
|
|
597
696
|
callback: CallbackClass,
|
598
697
|
interpolation_order: int,
|
599
698
|
overlap_ratio: float = 0.3,
|
600
|
-
|
699
|
+
**kwargs,
|
601
700
|
) -> CallbackClass:
|
602
701
|
"""
|
603
702
|
Computes a normalized cross-correlation score between ``target`` (f),
|
@@ -676,12 +775,11 @@ def mcc_scoring(
|
|
676
775
|
temp3 = be.zeros(fast_shape, float_dtype)
|
677
776
|
temp_ft = be.zeros(fast_ft_shape, complex_dtype)
|
678
777
|
|
679
|
-
|
778
|
+
tmpl_filter_func = _create_filter_func(
|
680
779
|
arr_shape=template.shape,
|
681
|
-
|
780
|
+
template_filter=template_filter,
|
682
781
|
arr_padded=True,
|
683
782
|
)
|
684
|
-
|
685
783
|
for index in range(rotations.shape[0]):
|
686
784
|
rotation = rotations[index]
|
687
785
|
template_rot = be.fill(template_rot, 0)
|
@@ -697,8 +795,8 @@ def mcc_scoring(
|
|
697
795
|
cache=True,
|
698
796
|
)
|
699
797
|
|
700
|
-
|
701
|
-
|
798
|
+
template_rot = tmpl_filter_func(template_rot, temp_ft)
|
799
|
+
template_rot = standardize(template_rot, temp, be.sum(temp))
|
702
800
|
|
703
801
|
temp_ft = be.rfftn(template_rot, out=temp_ft, s=fast_shape)
|
704
802
|
temp2 = be.irfftn(target_mask_ft * temp_ft, out=temp2, s=fast_shape)
|
@@ -769,6 +867,7 @@ def flc_scoring2(
|
|
769
867
|
callback: CallbackClass,
|
770
868
|
interpolation_order: int,
|
771
869
|
score_mask: shm_type = None,
|
870
|
+
template_background: shm_type = None,
|
772
871
|
) -> CallbackClass:
|
773
872
|
template = be.from_sharedarr(template)
|
774
873
|
template_mask = be.from_sharedarr(template_mask)
|
@@ -795,13 +894,25 @@ def flc_scoring2(
|
|
795
894
|
|
796
895
|
tmp_sqz, arr_sqz, ft_temp = temp[sqz_slice], arr[sqz_slice], ft_denom[sqz_slice]
|
797
896
|
|
798
|
-
|
897
|
+
tmpl_filter_func = _create_filter_func(
|
799
898
|
arr_shape=template.shape,
|
800
|
-
|
899
|
+
template_filter=template_filter,
|
801
900
|
arr_padded=True,
|
802
901
|
)
|
803
902
|
norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
|
804
903
|
|
904
|
+
background_correction = template_background is not None
|
905
|
+
if background_correction:
|
906
|
+
scores_alt, compute_norm = _setup_background_correction(
|
907
|
+
fast_shape=fast_shape,
|
908
|
+
template_background=template_background,
|
909
|
+
rotation_buffer=arr_sqz[tmpl_subset],
|
910
|
+
unpadded_slice=tmpl_subset,
|
911
|
+
interpolation_order=interpolation_order,
|
912
|
+
tmpl_filter_func=tmpl_filter_func,
|
913
|
+
norm_template=standardize,
|
914
|
+
)
|
915
|
+
|
805
916
|
eps = be.eps(be._float_dtype)
|
806
917
|
for index in range(rotations.shape[0]):
|
807
918
|
rotation = rotations[index]
|
@@ -821,8 +932,8 @@ def flc_scoring2(
|
|
821
932
|
)
|
822
933
|
|
823
934
|
n_obs = be.sum(tmp_sqz, axis=axes, keepdims=True)
|
824
|
-
arr_norm =
|
825
|
-
arr_norm =
|
935
|
+
arr_norm = tmpl_filter_func(arr_sqz, ft_temp)
|
936
|
+
arr_norm = standardize(arr_norm, tmp_sqz, n_obs, axis=axes)
|
826
937
|
|
827
938
|
ft_temp = be.rfftn(tmp_sqz, out=ft_temp, axes=axes, s=shape)
|
828
939
|
temp = _correlate_fts(ft_target, ft_temp, ft_denom, temp, shape, axes)
|
@@ -831,10 +942,20 @@ def flc_scoring2(
|
|
831
942
|
ft_temp = be.rfftn(arr_norm, out=ft_temp, axes=axes, s=shape)
|
832
943
|
arr = _correlate_fts(ft_target, ft_temp, ft_denom, arr, shape, axes)
|
833
944
|
|
834
|
-
|
945
|
+
inv_sdev = be.norm_scores(1, temp2, temp, n_obs, eps, temp2)
|
946
|
+
arr = be.multiply(arr, inv_sdev, out=arr)
|
835
947
|
arr = norm_mask(arr, score_mask, out=arr)
|
836
948
|
|
837
949
|
callback(arr, rotation_matrix=rotation)
|
950
|
+
if background_correction:
|
951
|
+
arr = compute_norm(arr, ft_target, ft_temp, rotation, template_mask, n_obs)
|
952
|
+
arr = be.multiply(arr, inv_sdev, out=arr)
|
953
|
+
scores_alt = be.maximum(arr, scores_alt, out=scores_alt)
|
954
|
+
|
955
|
+
if background_correction:
|
956
|
+
scores_alt = norm_mask(scores_alt, score_mask, out=scores_alt)
|
957
|
+
scores_alt = be.subtract(scores_alt, be.mean(scores_alt), out=scores_alt)
|
958
|
+
callback.correct_background(scores_alt)
|
838
959
|
|
839
960
|
return callback
|
840
961
|
|
@@ -853,6 +974,7 @@ def corr_scoring2(
|
|
853
974
|
target_filter: shm_type = None,
|
854
975
|
template_mask: shm_type = None,
|
855
976
|
score_mask: shm_type = None,
|
977
|
+
template_background: shm_type = None,
|
856
978
|
) -> CallbackClass:
|
857
979
|
template = be.from_sharedarr(template)
|
858
980
|
ft_target = be.from_sharedarr(ft_target)
|
@@ -886,14 +1008,14 @@ def corr_scoring2(
|
|
886
1008
|
template_mask = be.from_sharedarr(template_mask)
|
887
1009
|
n_obs = be.sum(template_mask, axis=axes, keepdims=True)
|
888
1010
|
|
889
|
-
norm_template = conditional_execute(
|
1011
|
+
norm_template = conditional_execute(standardize, n_obs is not None)
|
890
1012
|
norm_sub = conditional_execute(be.subtract, numerator.shape != (1,))
|
891
1013
|
norm_mul = conditional_execute(be.multiply, inv_denominator.shape != (1,))
|
892
1014
|
norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
|
893
1015
|
|
894
1016
|
template_filter_func = _create_filter_func(
|
895
1017
|
arr_shape=template.shape,
|
896
|
-
|
1018
|
+
template_filter=template_filter,
|
897
1019
|
arr_padded=True,
|
898
1020
|
)
|
899
1021
|
|
@@ -909,7 +1031,7 @@ def corr_scoring2(
|
|
909
1031
|
cache=False,
|
910
1032
|
batched=batched,
|
911
1033
|
)
|
912
|
-
arr_norm = template_filter_func(arr_sqz, ft_sqz
|
1034
|
+
arr_norm = template_filter_func(arr_sqz, ft_sqz)
|
913
1035
|
norm_template(arr_norm[unpadded_slice], template_mask, n_obs, axis=axes)
|
914
1036
|
|
915
1037
|
ft_sqz = be.rfftn(arr_norm, out=ft_sqz, axes=axes, s=shape)
|
@@ -942,7 +1064,7 @@ def _correlate_fts(ft_tar, ft_tmpl, ft_buffer, real_buffer, fast_shape, axes=Non
|
|
942
1064
|
|
943
1065
|
def _create_filter_func(
|
944
1066
|
arr_shape: Tuple[int],
|
945
|
-
|
1067
|
+
template_filter: BackendArray,
|
946
1068
|
arr_padded: bool = False,
|
947
1069
|
axes=None,
|
948
1070
|
) -> Callable:
|
@@ -960,7 +1082,7 @@ def _create_filter_func(
|
|
960
1082
|
----------
|
961
1083
|
arr_shape : tuple of ints
|
962
1084
|
Shape of the array to be filtered.
|
963
|
-
|
1085
|
+
template_filter : BackendArray
|
964
1086
|
Precomputed filter to apply in the frequency domain.
|
965
1087
|
arr_padded : bool, optional
|
966
1088
|
Whether the input template is padded and will need to be cropped
|
@@ -973,11 +1095,12 @@ def _create_filter_func(
|
|
973
1095
|
Callable
|
974
1096
|
Filter function with parameters template, ft_temp and template_filter.
|
975
1097
|
"""
|
1098
|
+
filter_shape = template_filter.shape
|
976
1099
|
if filter_shape == (1,):
|
977
1100
|
return conditional_execute(identity, execute_operation=True)
|
978
1101
|
|
979
1102
|
# Default case, all shapes are correctly matched
|
980
|
-
def _apply_filter(template, ft_temp
|
1103
|
+
def _apply_filter(template, ft_temp):
|
981
1104
|
ft_temp = be.rfftn(template, out=ft_temp, s=template.shape)
|
982
1105
|
ft_temp = be.multiply(ft_temp, template_filter, out=ft_temp)
|
983
1106
|
return be.irfftn(ft_temp, out=template, s=template.shape)
|
@@ -990,19 +1113,57 @@ def _create_filter_func(
|
|
990
1113
|
_template = be.zeros(arr_shape, be._float_dtype)
|
991
1114
|
_ft_temp = be.zeros(filter_shape, be._complex_dtype)
|
992
1115
|
|
993
|
-
def _apply_filter_subset(template, ft_temp
|
1116
|
+
def _apply_filter_subset(template, ft_temp):
|
994
1117
|
_template[:] = template[real_subset]
|
995
|
-
template[real_subset] = _apply_filter(_template, _ft_temp
|
1118
|
+
template[real_subset] = _apply_filter(_template, _ft_temp)
|
996
1119
|
return template
|
997
1120
|
|
998
1121
|
return _apply_filter_subset
|
999
1122
|
|
1000
1123
|
|
1124
|
+
def _setup_background_correction(
    fast_shape: Tuple[int],
    template_background: BackendArray,
    rotation_buffer: BackendArray,
    unpadded_slice: Tuple[slice],
    interpolation_order: int = 3,
    tmpl_filter_func: Callable = identity,
    norm_template: Callable = identity,
    axes=None,
    shape=None,
):
    """
    Create the score buffer and scoring closure used for background correction.

    Parameters
    ----------
    fast_shape : tuple of int
        Shape of the returned score buffer; also forwarded to
        ``_correlate_fts``.
    template_background : BackendArray
        Background template, loaded via ``be.from_sharedarr``.
    rotation_buffer : BackendArray
        Output buffer that receives the transformed background template.
        It is overwritten on every call of the returned closure.
    unpadded_slice : tuple of slice
        Forwarded to ``to_padded`` when placing the template in ``arr``.
    interpolation_order : int, optional
        Spline order passed to ``be.rigid_transform``, 3 by default.
    tmpl_filter_func : Callable, optional
        Called as ``tmpl_filter_func(rotation_buffer, ft_temp)``.
    norm_template : Callable, optional
        Called as ``norm_template(template, template_mask, n_obs, axis=axes)``.
    axes : optional
        Axes forwarded to ``norm_template``, ``be.rfftn`` and
        ``_correlate_fts``.
    shape : optional
        Passed unchanged as ``s`` to ``be.rfftn``, including ``None``.

    Returns
    -------
    tuple
        ``(scores_noise, compute_norm)``: a zero-initialized array of shape
        ``fast_shape`` and a closure
        ``compute_norm(arr, ft_target, ft_temp, matrix, template_mask, n_obs)``
        that returns the result of ``_correlate_fts`` for the transformed
        background template.
    """
    scores_noise = be.zeros(fast_shape, be._float_dtype)
    template_background = be.from_sharedarr(template_background)

    def compute_norm(arr, ft_target, ft_temp, matrix, template_mask, n_obs):
        # Transform the background template into the shared rotation buffer
        _ = be.rigid_transform(
            arr=template_background,
            rotation_matrix=matrix,
            out=rotation_buffer,
            use_geometric_center=True,
            order=interpolation_order,
            cache=True,
        )
        template_rot = tmpl_filter_func(rotation_buffer, ft_temp)
        template_rot = norm_template(template_rot, template_mask, n_obs, axis=axes)

        arr = to_padded(arr, template_rot, unpadded_slice)
        # ``shape`` is forwarded as-is; the former ``fwd_shape`` alias and its
        # ``if shape is not None`` re-assignment had no effect.
        ft_temp = be.rfftn(arr, out=ft_temp, axes=axes, s=shape)
        return _correlate_fts(ft_target, ft_temp, ft_temp, arr, fast_shape, axes)

    return scores_noise, compute_norm
|
1159
|
+
|
1160
|
+
|
1001
1161
|
MATCHING_EXHAUSTIVE_REGISTER = {
|
1002
1162
|
"CC": (cc_setup, corr_scoring),
|
1003
1163
|
"LCC": (lcc_setup, corr_scoring),
|
1004
1164
|
"CORR": (corr_setup, corr_scoring),
|
1005
1165
|
"CAM": (cam_setup, corr_scoring),
|
1166
|
+
# "NCC": (ncc_setup, ncc_scoring),
|
1006
1167
|
"FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
|
1007
1168
|
"FLC": (flc_setup, flc_scoring),
|
1008
1169
|
"MCC": (mcc_setup, mcc_scoring),
|