pytme 0.1.6__cp311-cp311-macosx_14_0_arm64.whl → 0.1.8__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tme/matching_data.py CHANGED
@@ -4,13 +4,14 @@
4
4
 
5
5
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
6
  """
7
-
7
+ import warnings
8
8
  from typing import Tuple, List
9
9
 
10
10
  import numpy as np
11
11
  from numpy.typing import NDArray
12
12
 
13
13
  from . import Density
14
+ from .types import ArrayLike
14
15
  from .backends import backend
15
16
  from .matching_utils import compute_full_convolution_index
16
17
 
@@ -80,7 +81,11 @@ class MatchingData:
80
81
  return arr
81
82
 
82
83
  def subset_array(
83
- self, arr: NDArray, arr_slice: Tuple[slice], padding: NDArray
84
+ self,
85
+ arr: NDArray,
86
+ arr_slice: Tuple[slice],
87
+ padding: NDArray,
88
+ invert: bool = False,
84
89
  ) -> NDArray:
85
90
  """
86
91
  Extract a subset of the input array according to the given slice and
@@ -95,19 +100,22 @@ class MatchingData:
95
100
  padding : NDArray
96
101
  Padding values for each dimension. If the padding exceeds the array
97
102
  dimensions, the extra regions are filled with the mean of the array
98
- values, otherwise, the
99
- values in ``arr`` are used.
103
+ values, otherwise, the values in ``arr`` are used.
104
+ invert : bool, optional
105
+ Whether the returned array should be inverted and normalized to the interval
106
+ [0, 1]. If available, uses the metadata information of the Density object,
107
+ otherwise computes min and max on the extracted subset.
100
108
 
101
109
  Returns
102
110
  -------
103
111
  NDArray
104
112
  Subset of the input array with padding applied.
105
113
  """
106
- padding = np.maximum(padding, 0)
114
+ padding = backend.to_numpy_array(padding)
115
+ padding = np.maximum(padding, 0).astype(int)
107
116
 
108
117
  slice_start = np.array([x.start for x in arr_slice], dtype=int)
109
118
  slice_stop = np.array([x.stop for x in arr_slice], dtype=int)
110
- slice_shape = np.subtract(slice_stop, slice_start)
111
119
 
112
120
  padding = np.add(padding, np.mod(padding, 2))
113
121
  left_pad = right_pad = np.divide(padding, 2).astype(int)
@@ -117,20 +125,18 @@ class MatchingData:
117
125
  np.subtract(arr.shape, slice_stop), right_pad
118
126
  ).astype(int)
119
127
 
120
- ret_shape = np.add(slice_shape, padding)
121
128
  arr_start = np.subtract(slice_start, data_voxels_left)
122
129
  arr_stop = np.add(slice_stop, data_voxels_right)
123
130
  arr_slice = tuple(slice(*pos) for pos in zip(arr_start, arr_stop))
124
131
  arr_mesh = self._slice_to_mesh(arr_slice, arr.shape)
125
132
 
126
- subset_start = np.subtract(left_pad, data_voxels_left)
127
- subset_stop = np.add(subset_start, np.subtract(arr_stop, arr_start))
128
- subset_slice = tuple(slice(*prod) for prod in zip(subset_start, subset_stop))
129
- subset_mesh = self._slice_to_mesh(subset_slice, ret_shape)
130
-
133
+ arr_min, arr_max = None, None
131
134
  if type(arr) == Density:
132
135
  if type(arr.data) == np.memmap:
133
- arr = Density.from_file(arr.data.filename, subset=arr_slice).data
136
+ dens = Density.from_file(arr.data.filename, subset=arr_slice)
137
+ arr = dens.data
138
+ arr_min = dens.metadata.get("min", None)
139
+ arr_max = dens.metadata.get("max", None)
134
140
  else:
135
141
  arr = np.asarray(arr.data[*arr_mesh])
136
142
  else:
@@ -139,10 +145,38 @@ class MatchingData:
139
145
  arr.filename, mode="r", shape=arr.shape, dtype=arr.dtype
140
146
  )
141
147
  arr = np.asarray(arr[*arr_mesh])
142
- ret = np.full(
143
- shape=np.add(slice_shape, padding), fill_value=arr.mean(), dtype=arr.dtype
148
+
149
+ def _warn_on_mismatch(
150
+ expectation: float, computation: float, name: str
151
+ ) -> float:
152
+ expectation, computation = float(expectation), float(computation)
153
+ if expectation is None:
154
+ expectation = computation
155
+
156
+ if abs(computation) > abs(expectation):
157
+ warnings.warn(
158
+ f"Computed {name} value is more extreme than value specified in file"
159
+ f" (|{computation}| > |{expectation}|). This may lead to issues"
160
+ " with padding and contrast inversion."
161
+ )
162
+
163
+ return expectation
164
+
165
+ padding = tuple(
166
+ (left, right)
167
+ for left, right in zip(
168
+ np.subtract(left_pad, data_voxels_left),
169
+ np.subtract(right_pad, data_voxels_right),
170
+ )
144
171
  )
145
- ret[*subset_mesh] = arr
172
+ ret = np.pad(arr, padding, mode="reflect")
173
+
174
+ if invert:
175
+ arr_min = _warn_on_mismatch(arr_min, arr.min(), "min")
176
+ arr_max = _warn_on_mismatch(arr_max, arr.max(), "max")
177
+
178
+ np.subtract(-ret, arr_min, out=ret)
179
+ np.divide(ret, arr_max - arr_min, out=ret)
146
180
 
147
181
  return ret
148
182
 
@@ -175,7 +209,7 @@ class MatchingData:
175
209
  MatchingData
176
210
  Newly allocated sliced class instance.
177
211
  """
178
- target_shape = self.target.shape
212
+ target_shape = self._target.shape
179
213
  template_shape = self._template.shape
180
214
 
181
215
  if target_slice is None:
@@ -184,9 +218,9 @@ class MatchingData:
184
218
  template_slice = self._shape_to_slice(template_shape)
185
219
 
186
220
  if target_pad is None:
187
- target_pad = np.zeros(len(self.target.shape), dtype=int)
221
+ target_pad = np.zeros(len(self._target.shape), dtype=int)
188
222
  if template_pad is None:
189
- template_pad = np.zeros(len(self.target.shape), dtype=int)
223
+ template_pad = np.zeros(len(self._target.shape), dtype=int)
190
224
 
191
225
  indices = compute_full_convolution_index(
192
226
  outer_shape=self._target.shape,
@@ -196,12 +230,12 @@ class MatchingData:
196
230
  )
197
231
 
198
232
  target_subset = self.subset_array(
199
- arr=self._target, arr_slice=target_slice, padding=target_pad
233
+ arr=self._target,
234
+ arr_slice=target_slice,
235
+ padding=target_pad,
236
+ invert=self._invert_target,
200
237
  )
201
- if self._invert_target:
202
- target_subset *= -1
203
- target_min, target_max = target_subset.min(), target_subset.max()
204
- target_subset = (target_subset - target_min) / (target_max - target_min)
238
+
205
239
  template_subset = self.subset_array(
206
240
  arr=self._template,
207
241
  arr_slice=template_slice,
@@ -230,6 +264,15 @@ class MatchingData:
230
264
  padding=template_pad,
231
265
  )
232
266
 
267
+ target_dims, template_dims = None, None
268
+ if hasattr(self, "_target_dims"):
269
+ target_dims = self._target_dims
270
+
271
+ if hasattr(self, "_template_dims"):
272
+ template_dims = self._template_dims
273
+
274
+ ret._set_batch_dimension(target_dims=target_dims, template_dims=template_dims)
275
+
233
276
  return ret
234
277
 
235
278
  def to_backend(self) -> None:
@@ -244,20 +287,223 @@ class MatchingData:
244
287
  self._default_dtype = backend._default_dtype
245
288
  self._complex_dtype = backend._complex_dtype
246
289
 
290
+ def _set_batch_dimension(
291
+ self, target_dims: Tuple[int] = None, template_dims: Tuple[int] = None
292
+ ) -> None:
293
+ """
294
+ Sets the shapes of target and template for template matching considering
295
+ their corresponding batch dimensions.
296
+
297
+ Parameters
298
+ ----------
299
+ target_dims : Tuple[int], optional
300
+ A tuple of integers specifying the batch dimensions of the target. If None,
301
+ the target is assumed not to have batch dimensions.
302
+ template_dims : Tuple[int], optional
303
+ A tuple of integers specifying the batch dimensions of the template. If None,
304
+ the template is assumed not to have batch dimensions.
305
+
306
+ Notes
307
+ -----
308
+
309
+ If the target and template share a batch dimension, the target will
310
+ take precendence and the template dimension will be shifted to the right.
311
+ If target and template have the same dimension, but target specifies batch
312
+ dimensions, the leftmost template dimensions are assumed to be a collapse
313
+ dimension that operates on a measurement dimension.
314
+ """
315
+ self._target_dims = target_dims
316
+ self._template_dims = template_dims
317
+
318
+ target_ndim = len(self._target.shape)
319
+ self._is_target_batch, target_dims = self._compute_batch_dimension(
320
+ batch_dims=target_dims, ndim=target_ndim
321
+ )
322
+ template_ndim = len(self._template.shape)
323
+ self._is_template_batch, template_dims = self._compute_batch_dimension(
324
+ batch_dims=template_dims, ndim=template_ndim
325
+ )
326
+
327
+ batch_dims = len(target_dims) + len(template_dims)
328
+ target_measurement_dims = target_ndim - len(target_dims)
329
+
330
+ collapse_dims = max(
331
+ template_ndim - len(template_dims) - target_measurement_dims, 0
332
+ )
333
+
334
+ matching_dims = target_measurement_dims + batch_dims
335
+
336
+ target_shape = backend.full(
337
+ shape=(matching_dims,), fill_value=1, dtype=backend._default_dtype_int
338
+ )
339
+ template_shape = backend.full(
340
+ shape=(matching_dims,), fill_value=1, dtype=backend._default_dtype_int
341
+ )
342
+ batch_mask = backend.full(shape=(matching_dims,), fill_value=1, dtype=bool)
343
+
344
+ target_index, template_index = 0, 0
345
+ for k in range(matching_dims):
346
+ target_dim = k - target_index
347
+ template_dim = k - template_index
348
+
349
+ if target_dim in target_dims:
350
+ target_shape[k] = self._target.shape[target_dim]
351
+ if target_index == len(template_dims) and collapse_dims > 0:
352
+ template_shape[k] = self._template.shape[template_dim]
353
+ collapse_dims -= 1
354
+ template_index += 1
355
+ continue
356
+
357
+ if template_dim in template_dims:
358
+ template_shape[k] = self._template.shape[template_dim]
359
+ target_index += 1
360
+ continue
361
+
362
+ batch_mask[k] = 0
363
+ if target_dim < target_ndim:
364
+ target_shape[k] = self._target.shape[target_dim]
365
+ if template_dim < template_ndim:
366
+ template_shape[k] = self._template.shape[template_dim]
367
+
368
+ self._output_target_shape = target_shape
369
+ self._output_template_shape = template_shape
370
+ self._batch_mask = batch_mask
371
+
372
+ @staticmethod
373
+ def _compute_batch_dimension(
374
+ batch_dims: Tuple[int], ndim: int
375
+ ) -> Tuple[ArrayLike, Tuple]:
376
+ """
377
+ Computes a mask for the batch dimensions and the validated batch dimensions.
378
+
379
+ Parameters
380
+ ----------
381
+ batch_dims : Tuple[int]
382
+ A tuple of integers representing the batch dimensions.
383
+ ndim : int
384
+ The number of dimensions of the array.
385
+
386
+ Returns
387
+ -------
388
+ Tuple[ArrayLike, Tuple]
389
+ A tuple containing the mask (as an ArrayLike) and the validated batch dimensions.
390
+
391
+ Raises
392
+ ------
393
+ ValueError
394
+ If any dimension in batch_dims is not less than ndim.
395
+ """
396
+ mask = backend.zeros(ndim, dtype=bool)
397
+ if batch_dims is None:
398
+ return mask, ()
399
+
400
+ if isinstance(batch_dims, int):
401
+ batch_dims = (batch_dims,)
402
+
403
+ for dim in batch_dims:
404
+ if dim < ndim:
405
+ mask[dim] = 1
406
+ continue
407
+ raise ValueError(f"Batch indices needs to be < {ndim}, got {dim}.")
408
+
409
+ return mask, batch_dims
410
+
411
+ def target_padding(self, pad_target: bool = False) -> ArrayLike:
412
+ """
413
+ Computes padding for the target based on the template's shape.
414
+
415
+ Parameters
416
+ ----------
417
+ pad_target : bool, default False
418
+ If True, computes the padding required for the target. If False,
419
+ an array of zeros is returned.
420
+
421
+ Returns
422
+ -------
423
+ ArrayLike
424
+ An array indicating the padding for each dimension of the target.
425
+ """
426
+ target_padding = backend.zeros(
427
+ len(self.target.shape), dtype=backend._default_dtype_int
428
+ )
429
+
430
+ if pad_target:
431
+ backend.subtract(
432
+ self._template.shape,
433
+ backend.mod(self._template.shape, 2),
434
+ out=target_padding,
435
+ )
436
+ if hasattr(self, "_is_target_batch"):
437
+ target_padding[self._is_target_batch] = 0
438
+
439
+ return target_padding
440
+
441
+ def fourier_padding(
442
+ self, pad_fourier: bool = False
443
+ ) -> Tuple[ArrayLike, ArrayLike, ArrayLike]:
444
+ """
445
+ Computes an efficient shape for the forward Fourier transform, the
446
+ corresponding shape of the real-valued FFT, and the associated
447
+ translation shift.
448
+
449
+ Parameters
450
+ ----------
451
+ pad_fourier : bool, default False
452
+ If true, returns the shape of the full-convolution defined as sum of target
453
+ shape and template shape minus one. By default, returns unpadded transform.
454
+
455
+ Returns
456
+ -------
457
+ Tuple[ArrayLike, ArrayLike, ArrayLike]
458
+ A tuple containing the calculated fast shape, fast Fourier transform shape,
459
+ and the Fourier shift values, respectively.
460
+ """
461
+ template_shape = self._template.shape
462
+ if hasattr(self, "_output_template_shape"):
463
+ template_shape = self._output_template_shape
464
+ template_shape = backend.to_backend_array(template_shape)
465
+
466
+ target_shape = self._target.shape
467
+ if hasattr(self, "_output_target_shape"):
468
+ target_shape = self._output_target_shape
469
+ target_shape = backend.to_backend_array(target_shape)
470
+
471
+ fourier_pad = backend.to_backend_array(template_shape)
472
+ fourier_shift = backend.zeros(len(fourier_pad))
473
+
474
+ if not pad_fourier:
475
+ fourier_pad = backend.full(shape=len(fourier_pad), fill_value=1, dtype=int)
476
+
477
+ fourier_pad = backend.to_backend_array(fourier_pad)
478
+ if hasattr(self, "_batch_mask"):
479
+ batch_mask = backend.to_backend_array(self._batch_mask)
480
+ fourier_pad[batch_mask] = 1
481
+
482
+ ret = backend.compute_convolution_shapes(target_shape, fourier_pad)
483
+ convolution_shape, fast_shape, fast_ft_shape = ret
484
+ if not pad_fourier:
485
+ fourier_shift = 1 - backend.astype(backend.divide(template_shape, 2), int)
486
+ fourier_shift -= backend.mod(template_shape, 2)
487
+ shape_diff = backend.subtract(fast_shape, convolution_shape)
488
+ shape_diff = backend.astype(backend.divide(shape_diff, 2), int)
489
+ backend.add(fourier_shift, shape_diff, out=fourier_shift)
490
+
491
+ return fast_shape, fast_ft_shape, fourier_shift
492
+
247
493
  @property
248
494
  def rotations(self):
249
- """Return rotation matrices used for fitting."""
495
+ """Return stored rotation matrices.."""
250
496
  return self._rotations
251
497
 
252
498
  @rotations.setter
253
499
  def rotations(self, rotations: NDArray):
254
500
  """
255
- Set and reshape the rotation matrices for fitting.
501
+ Set and reshape the rotation matrices for template matching.
256
502
 
257
503
  Parameters
258
504
  ----------
259
505
  rotations : NDArray
260
- Rotations in shape (3 x 3), (1 x 3 x 3), or (n x k x k).
506
+ Rotations in shape (k x k), or (n x k x k).
261
507
  """
262
508
  if rotations.__class__ != np.ndarray:
263
509
  raise ValueError("Rotation set has to be of type numpy ndarray.")
@@ -273,14 +519,14 @@ class MatchingData:
273
519
  @property
274
520
  def target(self):
275
521
  """Returns the target NDArray."""
276
- if type(self._target) == Density:
522
+ if isinstance(self._target, Density):
277
523
  return self._target.data
278
524
  return self._target
279
525
 
280
526
  @property
281
527
  def template(self):
282
528
  """Returns the reversed template NDArray."""
283
- if type(self._template) == Density:
529
+ if isinstance(self._template, Density):
284
530
  return backend.reverse(self._template.data)
285
531
  return backend.reverse(self._template)
286
532
 
@@ -309,7 +555,7 @@ class MatchingData:
309
555
  @property
310
556
  def target_mask(self):
311
557
  """Returns the target mask NDArray."""
312
- if type(self._target_mask) == Density:
558
+ if isinstance(self._target_mask, Density):
313
559
  return self._target_mask.data
314
560
  return self._target_mask
315
561
 
@@ -324,6 +570,7 @@ class MatchingData:
324
570
  self._target_mask = mask
325
571
  self._targetmaskshape = self._target_mask.shape[::-1]
326
572
  return None
573
+
327
574
  self._target_mask = mask.astype(self._default_dtype, copy=False)
328
575
  self._targetmaskshape = self._target_mask.shape
329
576
 
@@ -337,7 +584,7 @@ class MatchingData:
337
584
  template : NDArray
338
585
  Array to set as the template.
339
586
  """
340
- if type(self._template_mask) == Density:
587
+ if isinstance(self._template_mask, Density):
341
588
  return backend.reverse(self._template_mask.data)
342
589
  return backend.reverse(self._template_mask)
343
590
 
@@ -308,6 +308,49 @@ def cam_setup(**kwargs):
308
308
  return corr_setup(**kwargs)
309
309
 
310
310
 
311
+ def _normalize_under_mask(template: NDArray, mask: NDArray, mask_intensity) -> None:
312
+ """
313
+ Standardizes the values in in template by subtracting the mean and dividing by the
314
+ standard deviation based on the elements in mask. Subsequently, the template is
315
+ multiplied by the mask.
316
+
317
+ Parameters
318
+ ----------
319
+ template : NDArray
320
+ The data array to be normalized. This array is modified in-place.
321
+ mask : NDArray
322
+ A boolean array of the same shape as `template`. True values indicate the positions in `template`
323
+ to consider for normalization.
324
+ mask_intensity : float
325
+ Mask intensity used to compute expectations.
326
+
327
+ References
328
+ ----------
329
+ .. [1] T. Hrabe, Y. Chen, S. Pfeffer, L. Kuhn Cuellar, A.-V. Mangold,
330
+ and F. Förster, J. Struct. Biol. 178, 177 (2012).
331
+ .. [2] M. L. Chaillet, G. van der Schot, I. Gubins, S. Roet,
332
+ R. C. Veltkamp, and F. Förster, Int. J. Mol. Sci. 24,
333
+ 13375 (2023)
334
+
335
+ Returns
336
+ -------
337
+ None
338
+ This function modifies `template` in-place and does not return any value.
339
+ """
340
+ masked_mean = backend.sum(backend.multiply(template, mask))
341
+ masked_mean = backend.divide(masked_mean, mask_intensity)
342
+ masked_std = backend.sum(backend.multiply(backend.square(template), mask))
343
+ masked_std = backend.subtract(
344
+ masked_std / mask_intensity, backend.square(masked_mean)
345
+ )
346
+ masked_std = backend.sqrt(backend.maximum(masked_std, 0))
347
+
348
+ backend.subtract(template, masked_mean, out=template)
349
+ backend.divide(template, masked_std, out=template)
350
+ backend.multiply(template, mask, out=template)
351
+ return None
352
+
353
+
311
354
  def flc_setup(
312
355
  rfftn: Callable,
313
356
  irfftn: Callable,
@@ -359,7 +402,8 @@ def flc_setup(
359
402
  ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
360
403
  ft_target2 = backend.preallocate_array(fast_ft_shape, complex_dtype)
361
404
  rfftn(target_pad, ft_target)
362
- rfftn(backend.square(target_pad), ft_target2)
405
+ backend.square(target_pad, out=target_pad)
406
+ rfftn(target_pad, ft_target2)
363
407
 
364
408
  # Convert arrays used in subsequent fitting to SharedMemory objects
365
409
  ft_target = backend.arr_to_sharedarr(
@@ -369,13 +413,9 @@ def flc_setup(
369
413
  arr=ft_target2, shared_memory_handler=shared_memory_handler
370
414
  )
371
415
 
372
- template_mask = template_mask > 0
373
- template_mean = backend.mean(template[template_mask])
374
- template_std = backend.std(template[template_mask])
375
- template_mask = backend.astype(template_mask, real_dtype)
376
-
377
- backend.divide(template - template_mean, template_std, out=template)
378
- backend.multiply(template, template_mask, out=template)
416
+ _normalize_under_mask(
417
+ template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
418
+ )
379
419
 
380
420
  template_buffer = backend.arr_to_sharedarr(
381
421
  arr=template, shared_memory_handler=shared_memory_handler
@@ -455,7 +495,6 @@ def flcSphericalMask_setup(
455
495
  :py:meth:`corr_scoring`
456
496
  """
457
497
  target_pad = backend.topleft_pad(target, fast_shape)
458
- template_mask_pad = backend.topleft_pad(template_mask, fast_shape)
459
498
 
460
499
  # Target and squared target window sums
461
500
  ft_target = backend.preallocate_array(fast_ft_shape, complex_dtype)
@@ -467,16 +506,17 @@ def flcSphericalMask_setup(
467
506
  numerator2 = backend.preallocate_array(1, real_dtype)
468
507
 
469
508
  eps = backend.eps(real_dtype)
470
- n_observations = backend.sum(template_mask_pad > np.exp(-2))
509
+ n_observations = backend.sum(template_mask)
510
+
511
+ template_mask_pad = backend.topleft_pad(template_mask, fast_shape)
471
512
  rfftn(template_mask_pad, ft_template_mask)
472
513
 
473
- # Variance part denominator
514
+ # Denominator E(X^2) - E(X)^2
474
515
  rfftn(backend.square(target_pad), ft_target)
475
516
  backend.multiply(ft_target, ft_template_mask, out=ft_temp)
476
517
  irfftn(ft_temp, temp2)
477
518
  backend.divide(temp2, n_observations, out=temp2)
478
519
 
479
- # Mean part denominator
480
520
  rfftn(target_pad, ft_target)
481
521
  backend.multiply(ft_target, ft_template_mask, out=ft_temp)
482
522
  irfftn(ft_temp, temp)
@@ -495,14 +535,10 @@ def flcSphericalMask_setup(
495
535
  backend.fill(temp2, 0)
496
536
  temp2[nonzero_indices] = 1 / temp[nonzero_indices]
497
537
 
498
- template_mask = template_mask > np.exp(-2)
499
- template_mean = backend.mean(template[template_mask])
500
- template_std = backend.std(template[template_mask])
501
-
502
- template = backend.divide(backend.subtract(template, template_mean), template_std)
503
- backend.multiply(template, template_mask, out=template)
538
+ _normalize_under_mask(
539
+ template=template, mask=template_mask, mask_intensity=backend.sum(template_mask)
540
+ )
504
541
 
505
- # Convert arrays used in subsequent fitting to SharedMemory objects
506
542
  template_buffer = backend.arr_to_sharedarr(
507
543
  arr=template, shared_memory_handler=shared_memory_handler
508
544
  )
@@ -773,6 +809,7 @@ def corr_scoring(
773
809
  fourier_shift = callback_class_args.get("fourier_shift", backend.zeros(arr.ndim))
774
810
  fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
775
811
 
812
+ template_sum = backend.sum(template)
776
813
  for index in range(rotations.shape[0]):
777
814
  rotation = rotations[index]
778
815
  backend.fill(arr, 0)
@@ -783,6 +820,9 @@ def corr_scoring(
783
820
  use_geometric_center=False,
784
821
  order=interpolation_order,
785
822
  )
823
+ rotation_norm = template_sum / backend.sum(arr)
824
+ backend.multiply(arr, rotation_norm, out=arr)
825
+
786
826
  rfftn(arr, ft_temp)
787
827
  template_filter_func(ft_temp, template_filter, out=ft_temp)
788
828
 
@@ -905,6 +945,7 @@ def flc_scoring(
905
945
  fourier_shift = callback_class_args.get("fourier_shift", backend.zeros(arr.ndim))
906
946
  fourier_shift_scores = backend.sum(fourier_shift != 0) != 0
907
947
 
948
+ unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
908
949
  for index in range(rotations.shape[0]):
909
950
  rotation = rotations[index]
910
951
  backend.fill(arr, 0)
@@ -918,8 +959,15 @@ def flc_scoring(
918
959
  use_geometric_center=False,
919
960
  order=interpolation_order,
920
961
  )
962
+ # Given the amount of FFTs, might aswell normalize properly
921
963
  n_observations = backend.sum(temp)
922
964
 
965
+ _normalize_under_mask(
966
+ template=arr[unpadded_slice],
967
+ mask=temp[unpadded_slice],
968
+ mask_intensity=n_observations,
969
+ )
970
+
923
971
  rfftn(temp, ft_temp)
924
972
 
925
973
  backend.multiply(ft_target, ft_temp, out=ft_denom)
@@ -1244,7 +1292,7 @@ def scan(
1244
1292
  The merged results from callback_class if provided otherwise None.
1245
1293
  """
1246
1294
  shape_diff = backend.subtract(
1247
- matching_data.target.shape, matching_data._template.shape
1295
+ matching_data._target.shape, matching_data._template.shape
1248
1296
  )
1249
1297
  if backend.sum(shape_diff < 0) and not pad_fourier:
1250
1298
  warnings.warn(
@@ -1254,22 +1302,10 @@ def scan(
1254
1302
  )
1255
1303
 
1256
1304
  matching_data.to_backend()
1257
- fourier_pad = matching_data.template.shape
1258
- fourier_shift = backend.zeros(len(fourier_pad))
1259
- if not pad_fourier:
1260
- fourier_pad = backend.full(shape=fourier_shift.shape, fill_value=1, dtype=int)
1261
1305
 
1262
- convolution_shape, fast_shape, fast_ft_shape = backend.compute_convolution_shapes(
1263
- matching_data._target.shape, fourier_pad
1306
+ fast_shape, fast_ft_shape, fourier_shift = matching_data.fourier_padding(
1307
+ pad_fourier=pad_fourier
1264
1308
  )
1265
- if not pad_fourier:
1266
- fourier_shift = 1 - backend.astype(
1267
- backend.divide(matching_data._template.shape, 2), int
1268
- )
1269
- fourier_shift -= backend.mod(matching_data._template.shape, 2)
1270
- shape_diff = backend.subtract(fast_shape, convolution_shape)
1271
- shape_diff = backend.astype(backend.divide(shape_diff, 2), int)
1272
- backend.add(fourier_shift, shape_diff, out=fourier_shift)
1273
1309
 
1274
1310
  callback_class_args["fourier_shift"] = fourier_shift
1275
1311
  rfftn, irfftn = backend.build_fft(
@@ -1354,7 +1390,7 @@ def scan(
1354
1390
 
1355
1391
  setup["fftargs"] = fftargs.copy()
1356
1392
  convolution_mode = "same"
1357
- if backend.sum(matching_data._target_pad) > 0:
1393
+ if backend.sum(backend.to_backend_array(matching_data._target_pad)) > 0:
1358
1394
  convolution_mode = "valid"
1359
1395
  setup["convolution_mode"] = convolution_mode
1360
1396
  setup["interpolation_order"] = interpolation_order
@@ -1459,17 +1495,13 @@ def scan_subsets(
1459
1495
  The merged results from callback_class if provided otherwise None.
1460
1496
  """
1461
1497
  target_splits = split_numpy_array_slices(
1462
- matching_data.target.shape, splits=target_splits
1498
+ matching_data._target.shape, splits=target_splits
1463
1499
  )
1464
1500
  template_splits = split_numpy_array_slices(
1465
1501
  matching_data._template.shape, splits=template_splits
1466
1502
  )
1503
+ target_pad = matching_data.target_padding(pad_target=pad_target_edges)
1467
1504
 
1468
- target_pad = np.zeros(len(matching_data.target.shape), dtype=int)
1469
- if pad_target_edges:
1470
- target_pad = np.subtract(
1471
- matching_data._template.shape, np.mod(matching_data._template.shape, 2)
1472
- )
1473
1505
  outer_jobs, inner_jobs = job_schedule
1474
1506
  results = Parallel(n_jobs=outer_jobs)(
1475
1507
  delayed(_run_inner)(