pytme 0.3.1.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2.dev0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. pytme-0.3.2.dev0.data/scripts/estimate_ram_usage.py +97 -0
  2. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/match_template.py +213 -196
  3. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/postprocess.py +40 -78
  4. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/preprocess.py +4 -5
  5. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/preprocessor_gui.py +50 -103
  6. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/pytme_runner.py +46 -69
  7. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/METADATA +2 -1
  8. pytme-0.3.2.dev0.dist-info/RECORD +136 -0
  9. scripts/estimate_ram_usage.py +97 -0
  10. scripts/match_template.py +213 -196
  11. scripts/match_template_devel.py +1339 -0
  12. scripts/postprocess.py +40 -78
  13. scripts/preprocess.py +4 -5
  14. scripts/preprocessor_gui.py +50 -103
  15. scripts/pytme_runner.py +46 -69
  16. scripts/refine_matches.py +5 -7
  17. tests/preprocessing/test_compose.py +31 -30
  18. tests/preprocessing/test_frequency_filters.py +17 -32
  19. tests/preprocessing/test_preprocessor.py +0 -19
  20. tests/preprocessing/test_utils.py +13 -1
  21. tests/test_analyzer.py +2 -10
  22. tests/test_backends.py +47 -18
  23. tests/test_density.py +72 -13
  24. tests/test_extensions.py +1 -0
  25. tests/test_matching_cli.py +23 -9
  26. tests/test_matching_exhaustive.py +5 -5
  27. tests/test_matching_utils.py +3 -3
  28. tests/test_rotations.py +13 -23
  29. tests/test_structure.py +1 -7
  30. tme/__version__.py +1 -1
  31. tme/analyzer/aggregation.py +47 -16
  32. tme/analyzer/base.py +34 -0
  33. tme/analyzer/peaks.py +26 -13
  34. tme/analyzer/proxy.py +14 -0
  35. tme/backends/_jax_utils.py +124 -71
  36. tme/backends/cupy_backend.py +6 -19
  37. tme/backends/jax_backend.py +110 -105
  38. tme/backends/matching_backend.py +0 -17
  39. tme/backends/mlx_backend.py +0 -29
  40. tme/backends/npfftw_backend.py +100 -97
  41. tme/backends/pytorch_backend.py +65 -78
  42. tme/cli.py +2 -2
  43. tme/density.py +102 -58
  44. tme/extensions.cpython-311-darwin.so +0 -0
  45. tme/filters/_utils.py +52 -24
  46. tme/filters/bandpass.py +99 -105
  47. tme/filters/compose.py +133 -39
  48. tme/filters/ctf.py +51 -102
  49. tme/filters/reconstruction.py +67 -122
  50. tme/filters/wedge.py +296 -325
  51. tme/filters/whitening.py +39 -75
  52. tme/mask.py +2 -2
  53. tme/matching_data.py +87 -15
  54. tme/matching_exhaustive.py +70 -120
  55. tme/matching_optimization.py +9 -63
  56. tme/matching_scores.py +261 -100
  57. tme/matching_utils.py +150 -91
  58. tme/memory.py +1 -0
  59. tme/orientations.py +28 -8
  60. tme/preprocessor.py +0 -239
  61. tme/rotations.py +102 -70
  62. tme/structure.py +601 -631
  63. tme/types.py +1 -0
  64. pytme-0.3.1.post1.dist-info/RECORD +0 -133
  65. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/estimate_memory_usage.py +0 -0
  66. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/WHEEL +0 -0
  67. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/entry_points.txt +0 -0
  68. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/licenses/LICENSE +0 -0
  69. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/top_level.txt +0 -0
tme/matching_scores.py CHANGED
@@ -10,14 +10,14 @@ import warnings
10
10
  from typing import Callable, Tuple, Dict
11
11
 
12
12
  import numpy as np
13
- from scipy.ndimage import laplace
14
13
 
15
14
  from .backends import backend as be
16
15
  from .types import CallbackClass, BackendArray, shm_type
17
16
  from .matching_utils import (
18
17
  conditional_execute,
19
18
  identity,
20
- normalize_template,
19
+ standardize,
20
+ to_padded,
21
21
  )
22
22
 
23
23
 
@@ -46,7 +46,7 @@ def cc_setup(
46
46
  )
47
47
  axes = matching_data._batch_axis(matching_data._batch_mask)
48
48
 
49
- ret = {
49
+ return {
50
50
  "template": be.to_sharedarr(matching_data.template, shm_handler),
51
51
  "ft_target": be.to_sharedarr(be.rfftn(target_pad, axes=axes), shm_handler),
52
52
  "inv_denominator": be.to_sharedarr(
@@ -55,8 +55,6 @@ def cc_setup(
55
55
  "numerator": be.to_sharedarr(be.zeros(1, be._float_dtype), shm_handler),
56
56
  }
57
57
 
58
- return ret
59
-
60
58
 
61
59
  def lcc_setup(matching_data, **kwargs) -> Dict:
62
60
  """
@@ -71,26 +69,8 @@ def lcc_setup(matching_data, **kwargs) -> Dict:
71
69
  -----
72
70
  To be used with :py:meth:`corr_scoring`.
73
71
  """
74
- target = be.to_numpy_array(matching_data._target)
75
- template = be.to_numpy_array(matching_data._template)
76
-
77
- subsets = matching_data._batch_iter(
78
- target.shape,
79
- tuple(1 if i in matching_data._target_dim else 0 for i in range(target.ndim)),
80
- )
81
- for subset in subsets:
82
- target[subset] = laplace(target[subset], mode="wrap")
83
-
84
- subsets = matching_data._batch_iter(
85
- template.shape,
86
- tuple(1 if i in matching_data._template_dim else 0 for i in range(target.ndim)),
87
- )
88
- for subset in subsets:
89
- template[subset] = laplace(template[subset], mode="wrap")
90
-
91
- matching_data._target = be.to_backend_array(target)
92
- matching_data._template = be.to_backend_array(template)
93
-
72
+ matching_data.target = matching_data.transform_target("laplace")
73
+ matching_data.template = matching_data.transform_template("laplace")
94
74
  return cc_setup(matching_data=matching_data, **kwargs)
95
75
 
96
76
 
@@ -184,19 +164,17 @@ def corr_setup(
184
164
  denominator = be.divide(1, denominator, out=denominator)
185
165
  denominator = be.multiply(denominator, mask, out=denominator)
186
166
 
187
- ret = {
167
+ return {
188
168
  "template": be.to_sharedarr(template, shm_handler),
189
169
  "ft_target": be.to_sharedarr(ft_target, shm_handler),
190
170
  "inv_denominator": be.to_sharedarr(denominator, shm_handler),
191
171
  "numerator": be.to_sharedarr(numerator, shm_handler),
192
172
  }
193
173
 
194
- return ret
195
-
196
174
 
197
175
  def cam_setup(matching_data, **kwargs) -> Dict:
198
176
  """
199
- Like :py:meth:`corr_setup` but with standardized ``target``, ``template``
177
+ Like :py:meth:`corr_setup` but with standardized ``target`` and ``template``
200
178
 
201
179
  .. math::
202
180
 
@@ -206,19 +184,14 @@ def cam_setup(matching_data, **kwargs) -> Dict:
206
184
  -----
207
185
  To be used with :py:meth:`corr_scoring`.
208
186
  """
209
- template = matching_data._template
210
- axis = matching_data._batch_axis(matching_data._target_batch)
211
- matching_data._template = be.divide(
212
- be.subtract(template, be.mean(template, axis=axis, keepdims=True)),
213
- be.std(template, axis=axis, keepdims=True),
214
- )
215
- target = matching_data._target
216
- axis = matching_data._batch_axis(matching_data._template_batch)
217
- matching_data._target = be.divide(
218
- be.subtract(target, be.mean(target, axis=axis, keepdims=True)),
219
- be.std(target, axis=axis, keepdims=True),
220
- )
221
- return corr_setup(matching_data=matching_data, **kwargs)
187
+ matching_data.target = matching_data.transform_target("standardize")
188
+ matching_data.template = matching_data.transform_template("standardize")
189
+ return flcSphericalMask_setup(matching_data=matching_data, **kwargs)
190
+
191
+
192
+ def ncc_setup(matching_data, **kwargs) -> Dict:
193
+ matching_data.target = matching_data.transform_target("standardize")
194
+ return cc_setup(matching_data=matching_data, **kwargs)
222
195
 
223
196
 
224
197
  def flc_setup(
@@ -242,15 +215,13 @@ def flc_setup(
242
215
  target_pad = be.square(target_pad, out=target_pad)
243
216
  ft_target2 = be.rfftn(target_pad, axes=data_axes)
244
217
 
245
- ret = {
218
+ return {
246
219
  "template": be.to_sharedarr(matching_data.template, shm_handler),
247
220
  "template_mask": be.to_sharedarr(matching_data.template_mask, shm_handler),
248
221
  "ft_target": be.to_sharedarr(ft_target, shm_handler),
249
222
  "ft_target2": be.to_sharedarr(ft_target2, shm_handler),
250
223
  }
251
224
 
252
- return ret
253
-
254
225
 
255
226
  def flcSphericalMask_setup(
256
227
  matching_data,
@@ -302,7 +273,7 @@ def flcSphericalMask_setup(
302
273
  temp = be.irfftn(ft_temp, s=data_shape, axes=data_axes)
303
274
 
304
275
  temp2 = be.norm_scores(1, temp2, temp, n_obs, be.eps(be._float_dtype), temp2)
305
- ret = {
276
+ return {
306
277
  "template": be.to_sharedarr(matching_data.template, shm_handler),
307
278
  "template_mask": be.to_sharedarr(template_mask, shm_handler),
308
279
  "ft_target": be.to_sharedarr(ft_target, shm_handler),
@@ -310,8 +281,6 @@ def flcSphericalMask_setup(
310
281
  "numerator": be.to_sharedarr(be.zeros(1, be._float_dtype), shm_handler),
311
282
  }
312
283
 
313
- return ret
314
-
315
284
 
316
285
  def mcc_setup(
317
286
  matching_data,
@@ -331,7 +300,7 @@ def mcc_setup(
331
300
  target = be.topleft_pad(target, shape)
332
301
  target_mask = be.topleft_pad(target_mask, shape)
333
302
 
334
- ret = {
303
+ return {
335
304
  "template": be.to_sharedarr(matching_data.template, shm_handler),
336
305
  "template_mask": be.to_sharedarr(matching_data.template_mask, shm_handler),
337
306
  "ft_target": be.to_sharedarr(be.rfftn(target, axes=ax), shm_handler),
@@ -341,8 +310,6 @@ def mcc_setup(
341
310
  "ft_target_mask": be.to_sharedarr(be.rfftn(target_mask, axes=ax), shm_handler),
342
311
  }
343
312
 
344
- return ret
345
-
346
313
 
347
314
  def corr_scoring(
348
315
  template: shm_type,
@@ -357,6 +324,7 @@ def corr_scoring(
357
324
  interpolation_order: int,
358
325
  template_mask: shm_type = None,
359
326
  score_mask: shm_type = None,
327
+ template_background: shm_type = None,
360
328
  ) -> CallbackClass:
361
329
  """
362
330
  Calculates a normalized cross-correlation between a target f and a template g.
@@ -414,7 +382,7 @@ def corr_scoring(
414
382
  template_mask = be.from_sharedarr(template_mask)
415
383
  n_obs = be.sum(template_mask) if template_mask is not None else None
416
384
 
417
- norm_template = conditional_execute(normalize_template, n_obs is not None)
385
+ norm_template = conditional_execute(standardize, n_obs is not None)
418
386
  norm_sub = conditional_execute(be.subtract, numerator.shape != (1,))
419
387
  norm_mul = conditional_execute(be.multiply, inv_denominator.shape != (1))
420
388
  norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
@@ -423,30 +391,41 @@ def corr_scoring(
423
391
  ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)
424
392
  template_rot = be.zeros(template.shape, be._float_dtype)
425
393
 
426
- template_filter_func = _create_filter_func(
427
- arr_shape=template.shape,
428
- filter_shape=template_filter.shape,
429
- )
394
+ tmpl_filter_func = _create_filter_func(template.shape, template_filter)
430
395
 
431
396
  center = be.divide(be.to_backend_array(template.shape) - 1, 2)
432
397
  unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
398
+
399
+ background_correction = template_background is not None
400
+ if background_correction:
401
+ scores_alt, compute_norm = _setup_background_correction(
402
+ fast_shape=fast_shape,
403
+ template_background=template_background,
404
+ rotation_buffer=template_rot,
405
+ unpadded_slice=unpadded_slice,
406
+ interpolation_order=interpolation_order,
407
+ tmpl_filter_func=tmpl_filter_func,
408
+ norm_template=norm_template,
409
+ )
410
+
433
411
  for index in range(rotations.shape[0]):
434
412
  rotation = rotations[index]
435
- matrix = be._rigid_transform_matrix(rotation_matrix=rotation, center=center)
413
+ matrix = be._build_transform_matrix(
414
+ rotation_matrix=rotation, center=center, shape=template.shape
415
+ )
436
416
  _ = be.rigid_transform(
437
417
  arr=template,
438
418
  rotation_matrix=matrix,
439
419
  out=template_rot,
440
420
  order=interpolation_order,
441
421
  cache=True,
422
+ use_geometric_center=True,
442
423
  )
443
424
 
444
- template_rot = template_filter_func(template_rot, ft_temp, template_filter)
445
- norm_template(template_rot, template_mask, n_obs)
446
-
447
- arr = be.fill(arr, 0)
448
- arr[unpadded_slice] = template_rot
425
+ template_rot = tmpl_filter_func(template_rot, ft_temp)
426
+ template_rot = norm_template(template_rot, template_mask, n_obs)
449
427
 
428
+ arr = to_padded(arr, template_rot, unpadded_slice)
450
429
  ft_temp = be.rfftn(arr, s=fast_shape, out=ft_temp)
451
430
  arr = _correlate_fts(ft_target, ft_temp, ft_temp, arr, fast_shape)
452
431
 
@@ -455,6 +434,17 @@ def corr_scoring(
455
434
  arr = norm_mask(arr, score_mask, out=arr)
456
435
 
457
436
  callback(arr, rotation_matrix=rotation)
437
+ if background_correction:
438
+ arr = compute_norm(arr, ft_target, ft_temp, matrix, template_mask, n_obs)
439
+ arr = norm_sub(arr, numerator, out=arr)
440
+ arr = norm_mul(arr, inv_denominator, out=arr)
441
+ arr = norm_mask(arr, score_mask, out=arr)
442
+ scores_alt = be.maximum(arr, scores_alt, out=scores_alt)
443
+
444
+ if background_correction:
445
+ scores_alt = norm_mask(scores_alt, score_mask, out=scores_alt)
446
+ scores_alt = be.subtract(scores_alt, be.mean(scores_alt), out=scores_alt)
447
+ callback.correct_background(scores_alt)
458
448
 
459
449
  return callback
460
450
 
@@ -471,6 +461,7 @@ def flc_scoring(
471
461
  callback: CallbackClass,
472
462
  interpolation_order: int,
473
463
  score_mask: shm_type = None,
464
+ template_background: shm_type = None,
474
465
  ) -> CallbackClass:
475
466
  """
476
467
  Computes a normalized cross-correlation between ``target`` (f),
@@ -511,8 +502,6 @@ def flc_scoring(
511
502
  Rotation matrices to be sampled (n, d, d).
512
503
  callback : CallbackClass
513
504
  A callable for processing the result of each rotation.
514
- callback_class_args : Dict
515
- Dictionary of arguments to be passed to ``callback``.
516
505
  interpolation_order : int
517
506
  Spline order for template rotations.
518
507
 
@@ -524,7 +513,6 @@ def flc_scoring(
524
513
  ----------
525
514
  .. [1] Hrabe T. et al, J. Struct. Biol. 178, 177 (2012).
526
515
  """
527
- float_dtype, complex_dtype = be._float_dtype, be._complex_dtype
528
516
  template = be.from_sharedarr(template)
529
517
  template_mask = be.from_sharedarr(template_mask)
530
518
  ft_target = be.from_sharedarr(ft_target)
@@ -532,42 +520,55 @@ def flc_scoring(
532
520
  template_filter = be.from_sharedarr(template_filter)
533
521
  score_mask = be.from_sharedarr(score_mask)
534
522
 
535
- arr = be.zeros(fast_shape, float_dtype)
536
- temp = be.zeros(fast_shape, float_dtype)
537
- temp2 = be.zeros(fast_shape, float_dtype)
538
- ft_temp = be.zeros(fast_ft_shape, complex_dtype)
539
- ft_denom = be.zeros(fast_ft_shape, complex_dtype)
523
+ arr = be.zeros(fast_shape, be._float_dtype)
524
+ temp = be.zeros(fast_shape, be._float_dtype)
525
+ temp2 = be.zeros(fast_shape, be._float_dtype)
526
+ ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)
527
+ ft_denom = be.zeros(fast_ft_shape, be._complex_dtype)
540
528
  template_rot = be.zeros(template.shape, be._float_dtype)
541
529
  template_mask_rot = be.zeros(template.shape, be._float_dtype)
542
530
 
543
- tmpl_filter_func = _create_filter_func(template.shape, template_filter.shape)
531
+ tmpl_filter_func = _create_filter_func(template.shape, template_filter)
544
532
  norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
545
533
 
546
- eps = be.eps(float_dtype)
534
+ eps = be.eps(be._float_dtype)
547
535
  center = be.divide(be.to_backend_array(template.shape) - 1, 2)
548
536
  unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
537
+
538
+ background_correction = template_background is not None
539
+ if background_correction:
540
+ scores_alt, compute_norm = _setup_background_correction(
541
+ fast_shape=fast_shape,
542
+ template_background=template_background,
543
+ rotation_buffer=template_rot,
544
+ unpadded_slice=unpadded_slice,
545
+ interpolation_order=interpolation_order,
546
+ tmpl_filter_func=tmpl_filter_func,
547
+ norm_template=standardize,
548
+ )
549
+
549
550
  for index in range(rotations.shape[0]):
550
551
  rotation = rotations[index]
551
- matrix = be._rigid_transform_matrix(rotation_matrix=rotation, center=center)
552
+ matrix = be._build_transform_matrix(
553
+ rotation_matrix=rotation, center=center, shape=template.shape
554
+ )
552
555
  _ = be.rigid_transform(
553
556
  arr=template,
554
557
  arr_mask=template_mask,
555
558
  rotation_matrix=matrix,
556
559
  out=template_rot,
557
560
  out_mask=template_mask_rot,
558
- use_geometric_center=True,
559
561
  order=interpolation_order,
560
562
  cache=True,
563
+ use_geometric_center=True,
561
564
  )
562
565
 
563
566
  n_obs = be.sum(template_mask_rot)
564
- template_rot = tmpl_filter_func(template_rot, ft_temp, template_filter)
565
- template_rot = normalize_template(template_rot, template_mask_rot, n_obs)
567
+ template_rot = tmpl_filter_func(template_rot, ft_temp)
568
+ template_rot = standardize(template_rot, template_mask_rot, n_obs)
566
569
 
567
- arr = be.fill(arr, 0)
568
- temp = be.fill(temp, 0)
569
- arr[unpadded_slice] = template_rot
570
- temp[unpadded_slice] = template_mask_rot
570
+ arr = to_padded(arr, template_rot, unpadded_slice)
571
+ temp = to_padded(temp, template_mask_rot, unpadded_slice)
571
572
 
572
573
  ft_temp = be.rfftn(temp, out=ft_temp, s=fast_shape)
573
574
  temp = _correlate_fts(ft_target, ft_temp, ft_denom, temp, fast_shape)
@@ -576,14 +577,112 @@ def flc_scoring(
576
577
  ft_temp = be.rfftn(arr, out=ft_temp, s=fast_shape)
577
578
  arr = _correlate_fts(ft_target, ft_temp, ft_temp, arr, fast_shape)
578
579
 
579
- arr = be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
580
+ inv_sdev = be.norm_scores(1, temp2, temp, n_obs, eps, temp2)
581
+ arr = be.multiply(arr, inv_sdev, out=arr)
580
582
  arr = norm_mask(arr, score_mask, out=arr)
581
583
 
582
584
  callback(arr, rotation_matrix=rotation)
585
+ if background_correction:
586
+ arr = compute_norm(arr, ft_target, ft_temp, matrix, template_mask, n_obs)
587
+ arr = be.multiply(arr, inv_sdev, out=arr)
588
+ scores_alt = be.maximum(arr, scores_alt, out=scores_alt)
589
+
590
+ if background_correction:
591
+ scores_alt = norm_mask(scores_alt, score_mask, out=scores_alt)
592
+ scores_alt = be.subtract(scores_alt, be.mean(scores_alt), out=scores_alt)
593
+ callback.correct_background(scores_alt)
583
594
 
584
595
  return callback
585
596
 
586
597
 
def ncc_scoring(
    template: shm_type,
    ft_target: shm_type,
    fast_shape: Tuple[int],
    fast_ft_shape: Tuple[int],
    rotations: BackendArray,
    callback: CallbackClass,
    interpolation_order: int,
    template_filter: shm_type = None,
    score_mask: shm_type = None,
    template_background: shm_type = None,
    **kwargs,
) -> CallbackClass:
    """
    Computes a cross-correlation between ``target`` and a standardized
    ``template`` for each rotation and derives background statistics from
    the scores observed across rotations.

    For every rotation, the template is rotated, optionally filtered,
    standardized, zero-padded to ``fast_shape`` and correlated with
    ``ft_target`` in Fourier space. The score map is optionally masked and
    passed to ``callback``. The per-voxel mean and variance of the scores
    over all rotations are tracked with Welford's online algorithm.

    After all rotations, ``callback.correct_background`` receives the global
    mean (mean of per-voxel means) and global standard deviation (square
    root of the mean per-voxel variance). If ``template_background`` is
    given, it is called a second time. That call receives the per-voxel mean
    and the reciprocal per-voxel standard deviation, both expressed relative
    to the global statistics; the reciprocal is set to 0 where the relative
    standard deviation is <= 1e-4.

    Parameters
    ----------
    template : shm_type
        Template, resolved with ``be.from_sharedarr``.
    ft_target : shm_type
        Fourier transform of the padded target, e.g. the ``ft_target`` entry
        produced by :py:meth:`ncc_setup`.
    fast_shape : tuple of ints
        Shape of the real-space score arrays.
    fast_ft_shape : tuple of ints
        Shape of the Fourier-space buffer.
    rotations : BackendArray
        Rotation matrices to be sampled (n, d, d).
    callback : CallbackClass
        Receives each score map and the background statistics.
    interpolation_order : int
        Spline order for template rotations.
    template_filter : shm_type, optional
        Frequency-domain filter applied to each rotated template. No
        filtering is done when it is ``None``.
    score_mask : shm_type, optional
        Multiplied into each score map unless it is ``None`` or has
        shape ``(1,)``.
    template_background : shm_type, optional
        Only its presence is checked. When not ``None``, per-voxel
        statistics are additionally passed to
        ``callback.correct_background``; its values are not used.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    CallbackClass
        ``callback`` after all rotations have been processed.

    Notes
    -----
    The per-voxel variance uses the sample estimator ``M2 / (n_angles - 1)``.
    With a single rotation this evaluates 0 / 0.
    NOTE(review): callers presumably pass at least two rotations — confirm.
    """
    template = be.from_sharedarr(template)
    ft_target = be.from_sharedarr(ft_target)
    score_mask = be.from_sharedarr(score_mask)
    template_filter = be.from_sharedarr(template_filter)

    arr = be.zeros(fast_shape, be._float_dtype)
    ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)
    template_rot = be.zeros(template.shape, be._float_dtype)

    # Welford arrays for per-voxel statistics over rotations
    pixel_mean = be.zeros(fast_shape, be._float_dtype)
    pixel_M2 = be.zeros(fast_shape, be._float_dtype)

    # Honor the ``None`` defaults of the optional inputs; reading ``.shape``
    # from ``None`` would otherwise raise an AttributeError.
    if template_filter is None:
        tmpl_filter_func = conditional_execute(identity, execute_operation=True)
    else:
        tmpl_filter_func = _create_filter_func(template.shape, template_filter)
    mask_scores = score_mask is not None and score_mask.shape != (1,)
    norm_mask = conditional_execute(be.multiply, mask_scores)

    size = be.size(template)
    center = be.divide(be.to_backend_array(template.shape) - 1, 2)
    unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
    n_angles = rotations.shape[0]

    # Scale forward transform by 1/n i.e. norm 'forward'
    ft_target = be.multiply(ft_target, 1 / be.size(arr))

    background_correction = template_background is not None
    for index in range(n_angles):
        # to_padded only writes the template region, so clear the rest first
        arr = be.fill(arr, 0)
        rotation = rotations[index]
        matrix = be._build_transform_matrix(
            rotation_matrix=rotation, center=center, shape=template.shape
        )

        be.rigid_transform(
            template,
            rotation_matrix=matrix,
            out=template_rot,
            order=interpolation_order,
            cache=True,
            use_geometric_center=True,
        )
        template_rot = tmpl_filter_func(template_rot, ft_temp)
        template_rot = standardize(template_rot, 1, size)

        arr = to_padded(arr, template_rot, unpadded_slice)
        ft_temp = be.rfftn(arr, s=fast_shape, norm="forward")
        ft_temp = be.multiply(ft_temp, ft_target, out=ft_temp)

        arr = be.irfftn(ft_temp, s=fast_shape, norm="forward")
        arr = norm_mask(arr, score_mask, out=arr)
        callback(arr, rotation_matrix=rotation)

        # Welford update: running mean and sum of squared deviations
        delta = be.subtract(arr, pixel_mean)
        pixel_mean = be.add(pixel_mean, be.divide(delta, index + 1), out=pixel_mean)
        delta2 = be.subtract(arr, pixel_mean)
        delta = be.multiply(delta, delta2, out=delta)
        pixel_M2 = be.add(pixel_M2, delta, out=pixel_M2)

    global_mean = be.mean(pixel_mean)
    pixel_variance = be.divide(pixel_M2, n_angles - 1)
    global_std = be.sqrt(be.mean(pixel_variance))

    callback.correct_background(global_mean, global_std)
    if background_correction:
        # Adapt units for local normalization
        pixel_mean = be.subtract(pixel_mean, global_mean, out=pixel_mean)
        pixel_mean = be.divide(pixel_mean, global_std, out=pixel_mean)

        # pixel_variance's buffer is reused for the standard deviation
        pixel_std = be.sqrt(pixel_variance, out=pixel_variance)
        pixel_std = be.divide(pixel_std, global_std, out=pixel_variance)

        pixel_std = be.where(pixel_std > 1e-4, 1 / pixel_std, 0.0)
        callback.correct_background(pixel_mean, pixel_std)
    return callback
684
+
685
+
587
686
  def mcc_scoring(
588
687
  template: shm_type,
589
688
  template_mask: shm_type,
@@ -597,7 +696,7 @@ def mcc_scoring(
597
696
  callback: CallbackClass,
598
697
  interpolation_order: int,
599
698
  overlap_ratio: float = 0.3,
600
- score_mask: shm_type = None,
699
+ **kwargs,
601
700
  ) -> CallbackClass:
602
701
  """
603
702
  Computes a normalized cross-correlation score between ``target`` (f),
@@ -676,12 +775,11 @@ def mcc_scoring(
676
775
  temp3 = be.zeros(fast_shape, float_dtype)
677
776
  temp_ft = be.zeros(fast_ft_shape, complex_dtype)
678
777
 
679
- template_filter_func = _create_filter_func(
778
+ tmpl_filter_func = _create_filter_func(
680
779
  arr_shape=template.shape,
681
- filter_shape=template_filter.shape,
780
+ template_filter=template_filter,
682
781
  arr_padded=True,
683
782
  )
684
-
685
783
  for index in range(rotations.shape[0]):
686
784
  rotation = rotations[index]
687
785
  template_rot = be.fill(template_rot, 0)
@@ -697,8 +795,8 @@ def mcc_scoring(
697
795
  cache=True,
698
796
  )
699
797
 
700
- template_filter_func(template_rot, temp_ft, template_filter)
701
- normalize_template(template_rot, temp, be.sum(temp))
798
+ template_rot = tmpl_filter_func(template_rot, temp_ft)
799
+ template_rot = standardize(template_rot, temp, be.sum(temp))
702
800
 
703
801
  temp_ft = be.rfftn(template_rot, out=temp_ft, s=fast_shape)
704
802
  temp2 = be.irfftn(target_mask_ft * temp_ft, out=temp2, s=fast_shape)
@@ -769,6 +867,7 @@ def flc_scoring2(
769
867
  callback: CallbackClass,
770
868
  interpolation_order: int,
771
869
  score_mask: shm_type = None,
870
+ template_background: shm_type = None,
772
871
  ) -> CallbackClass:
773
872
  template = be.from_sharedarr(template)
774
873
  template_mask = be.from_sharedarr(template_mask)
@@ -795,13 +894,25 @@ def flc_scoring2(
795
894
 
796
895
  tmp_sqz, arr_sqz, ft_temp = temp[sqz_slice], arr[sqz_slice], ft_denom[sqz_slice]
797
896
 
798
- template_filter_func = _create_filter_func(
897
+ tmpl_filter_func = _create_filter_func(
799
898
  arr_shape=template.shape,
800
- filter_shape=template_filter.shape,
899
+ template_filter=template_filter,
801
900
  arr_padded=True,
802
901
  )
803
902
  norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
804
903
 
904
+ background_correction = template_background is not None
905
+ if background_correction:
906
+ scores_alt, compute_norm = _setup_background_correction(
907
+ fast_shape=fast_shape,
908
+ template_background=template_background,
909
+ rotation_buffer=arr_sqz[tmpl_subset],
910
+ unpadded_slice=tmpl_subset,
911
+ interpolation_order=interpolation_order,
912
+ tmpl_filter_func=tmpl_filter_func,
913
+ norm_template=standardize,
914
+ )
915
+
805
916
  eps = be.eps(be._float_dtype)
806
917
  for index in range(rotations.shape[0]):
807
918
  rotation = rotations[index]
@@ -821,8 +932,8 @@ def flc_scoring2(
821
932
  )
822
933
 
823
934
  n_obs = be.sum(tmp_sqz, axis=axes, keepdims=True)
824
- arr_norm = template_filter_func(arr_sqz, ft_temp, template_filter)
825
- arr_norm = normalize_template(arr_norm, tmp_sqz, n_obs, axis=axes)
935
+ arr_norm = tmpl_filter_func(arr_sqz, ft_temp)
936
+ arr_norm = standardize(arr_norm, tmp_sqz, n_obs, axis=axes)
826
937
 
827
938
  ft_temp = be.rfftn(tmp_sqz, out=ft_temp, axes=axes, s=shape)
828
939
  temp = _correlate_fts(ft_target, ft_temp, ft_denom, temp, shape, axes)
@@ -831,10 +942,20 @@ def flc_scoring2(
831
942
  ft_temp = be.rfftn(arr_norm, out=ft_temp, axes=axes, s=shape)
832
943
  arr = _correlate_fts(ft_target, ft_temp, ft_denom, arr, shape, axes)
833
944
 
834
- arr = be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
945
+ inv_sdev = be.norm_scores(1, temp2, temp, n_obs, eps, temp2)
946
+ arr = be.multiply(arr, inv_sdev, out=arr)
835
947
  arr = norm_mask(arr, score_mask, out=arr)
836
948
 
837
949
  callback(arr, rotation_matrix=rotation)
950
+ if background_correction:
951
+ arr = compute_norm(arr, ft_target, ft_temp, rotation, template_mask, n_obs)
952
+ arr = be.multiply(arr, inv_sdev, out=arr)
953
+ scores_alt = be.maximum(arr, scores_alt, out=scores_alt)
954
+
955
+ if background_correction:
956
+ scores_alt = norm_mask(scores_alt, score_mask, out=scores_alt)
957
+ scores_alt = be.subtract(scores_alt, be.mean(scores_alt), out=scores_alt)
958
+ callback.correct_background(scores_alt)
838
959
 
839
960
  return callback
840
961
 
@@ -853,6 +974,7 @@ def corr_scoring2(
853
974
  target_filter: shm_type = None,
854
975
  template_mask: shm_type = None,
855
976
  score_mask: shm_type = None,
977
+ template_background: shm_type = None,
856
978
  ) -> CallbackClass:
857
979
  template = be.from_sharedarr(template)
858
980
  ft_target = be.from_sharedarr(ft_target)
@@ -886,14 +1008,14 @@ def corr_scoring2(
886
1008
  template_mask = be.from_sharedarr(template_mask)
887
1009
  n_obs = be.sum(template_mask, axis=axes, keepdims=True)
888
1010
 
889
- norm_template = conditional_execute(normalize_template, n_obs is not None)
1011
+ norm_template = conditional_execute(standardize, n_obs is not None)
890
1012
  norm_sub = conditional_execute(be.subtract, numerator.shape != (1,))
891
1013
  norm_mul = conditional_execute(be.multiply, inv_denominator.shape != (1,))
892
1014
  norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
893
1015
 
894
1016
  template_filter_func = _create_filter_func(
895
1017
  arr_shape=template.shape,
896
- filter_shape=template_filter.shape,
1018
+ template_filter=template_filter,
897
1019
  arr_padded=True,
898
1020
  )
899
1021
 
@@ -909,7 +1031,7 @@ def corr_scoring2(
909
1031
  cache=False,
910
1032
  batched=batched,
911
1033
  )
912
- arr_norm = template_filter_func(arr_sqz, ft_sqz, template_filter)
1034
+ arr_norm = template_filter_func(arr_sqz, ft_sqz)
913
1035
  norm_template(arr_norm[unpadded_slice], template_mask, n_obs, axis=axes)
914
1036
 
915
1037
  ft_sqz = be.rfftn(arr_norm, out=ft_sqz, axes=axes, s=shape)
@@ -942,7 +1064,7 @@ def _correlate_fts(ft_tar, ft_tmpl, ft_buffer, real_buffer, fast_shape, axes=Non
942
1064
 
943
1065
  def _create_filter_func(
944
1066
  arr_shape: Tuple[int],
945
- filter_shape: BackendArray,
1067
+ template_filter: BackendArray,
946
1068
  arr_padded: bool = False,
947
1069
  axes=None,
948
1070
  ) -> Callable:
@@ -960,7 +1082,7 @@ def _create_filter_func(
960
1082
  ----------
961
1083
  arr_shape : tuple of ints
962
1084
  Shape of the array to be filtered.
963
- filter_shape : BackendArray
1085
+ template_filter : BackendArray
964
1086
  Precomputed filter to apply in the frequency domain.
965
1087
  arr_padded : bool, optional
966
1088
  Whether the input template is padded and will need to be cropped
@@ -973,11 +1095,12 @@ def _create_filter_func(
973
1095
  Callable
974
1096
  Filter function with parameters template, ft_temp and template_filter.
975
1097
  """
1098
+ filter_shape = template_filter.shape
976
1099
  if filter_shape == (1,):
977
1100
  return conditional_execute(identity, execute_operation=True)
978
1101
 
979
1102
  # Default case, all shapes are correctly matched
980
- def _apply_filter(template, ft_temp, template_filter):
1103
+ def _apply_filter(template, ft_temp):
981
1104
  ft_temp = be.rfftn(template, out=ft_temp, s=template.shape)
982
1105
  ft_temp = be.multiply(ft_temp, template_filter, out=ft_temp)
983
1106
  return be.irfftn(ft_temp, out=template, s=template.shape)
@@ -990,19 +1113,57 @@ def _create_filter_func(
990
1113
  _template = be.zeros(arr_shape, be._float_dtype)
991
1114
  _ft_temp = be.zeros(filter_shape, be._complex_dtype)
992
1115
 
993
- def _apply_filter_subset(template, ft_temp, template_filter):
1116
+ def _apply_filter_subset(template, ft_temp):
994
1117
  _template[:] = template[real_subset]
995
- template[real_subset] = _apply_filter(_template, _ft_temp, template_filter)
1118
+ template[real_subset] = _apply_filter(_template, _ft_temp)
996
1119
  return template
997
1120
 
998
1121
  return _apply_filter_subset
999
1122
 
1000
1123
 
def _setup_background_correction(
    fast_shape: Tuple[int],
    template_background: BackendArray,
    rotation_buffer: BackendArray,
    unpadded_slice: Tuple[slice],
    interpolation_order: int = 3,
    tmpl_filter_func: Callable = identity,
    norm_template: Callable = identity,
    axes=None,
    shape=None,
):
    """
    Prepare a score buffer and a closure that correlates a rotated
    background template with the target.

    Parameters
    ----------
    fast_shape : tuple of ints
        Shape of the returned score buffer and of the inverse transform.
    template_background : BackendArray
        Background template, resolved with ``be.from_sharedarr``.
    rotation_buffer : BackendArray
        Receives the rotated background template; overwritten on every call
        of ``compute_norm``.
    unpadded_slice : tuple of slices
        Region of the padded array into which the template is written.
    interpolation_order : int, optional
        Spline order for the rotation.
    tmpl_filter_func : Callable, optional
        Called as ``tmpl_filter_func(template, ft_temp)``.
    norm_template : Callable, optional
        Called as ``norm_template(template, template_mask, n_obs, axis=axes)``.
    axes : optional
        Axes forwarded to the forward transform, normalization and
        correlation.
    shape : tuple of ints, optional
        ``s`` argument of the forward transform; ``None`` is passed through
        unchanged to ``be.rfftn``.

    Returns
    -------
    scores_noise : BackendArray
        Zero-initialized buffer of ``fast_shape``. The scoring functions in
        this module accumulate the maximum background score in it.
    compute_norm : Callable
        ``compute_norm(arr, ft_target, ft_temp, matrix, template_mask, n_obs)``
        does the following:

        - rotates the background template with ``matrix``;
        - filters and normalizes it;
        - pads it into ``arr``;
        - returns its correlation with ``ft_target``.

        It writes into the provided ``arr`` and ``ft_temp`` buffers.
    """
    scores_noise = be.zeros(fast_shape, be._float_dtype)
    template_background = be.from_sharedarr(template_background)

    # Forward transform shape; the previous conditional assigned ``shape``
    # in both branches, so it is used directly.
    fwd_shape = shape

    def compute_norm(arr, ft_target, ft_temp, matrix, template_mask, n_obs):
        _ = be.rigid_transform(
            arr=template_background,
            rotation_matrix=matrix,
            out=rotation_buffer,
            use_geometric_center=True,
            order=interpolation_order,
            cache=True,
        )
        template_rot = tmpl_filter_func(rotation_buffer, ft_temp)
        template_rot = norm_template(template_rot, template_mask, n_obs, axis=axes)

        arr = to_padded(arr, template_rot, unpadded_slice)
        ft_temp = be.rfftn(arr, out=ft_temp, axes=axes, s=fwd_shape)
        # NOTE(review): the inverse transform always uses ``fast_shape``,
        # even when ``shape`` is given — confirm this is intended.
        return _correlate_fts(ft_target, ft_temp, ft_temp, arr, fast_shape, axes)

    return scores_noise, compute_norm
1159
+
1160
+
1001
1161
  MATCHING_EXHAUSTIVE_REGISTER = {
1002
1162
  "CC": (cc_setup, corr_scoring),
1003
1163
  "LCC": (lcc_setup, corr_scoring),
1004
1164
  "CORR": (corr_setup, corr_scoring),
1005
1165
  "CAM": (cam_setup, corr_scoring),
1166
+ # "NCC": (ncc_setup, ncc_scoring),
1006
1167
  "FLCSphericalMask": (flcSphericalMask_setup, corr_scoring),
1007
1168
  "FLC": (flc_setup, flc_scoring),
1008
1169
  "MCC": (mcc_setup, mcc_scoring),