pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/match_template.py +147 -93
  2. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/postprocess.py +67 -26
  3. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +175 -85
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +20 -13
  8. scripts/match_template.py +147 -93
  9. scripts/match_template_filters.py +154 -95
  10. scripts/postprocess.py +67 -26
  11. scripts/preprocessor_gui.py +175 -85
  12. scripts/refine_matches.py +265 -61
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +451 -809
  16. tme/backends/__init__.py +40 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +111 -223
  19. tme/backends/jax_backend.py +214 -150
  20. tme/backends/matching_backend.py +445 -384
  21. tme/backends/mlx_backend.py +32 -59
  22. tme/backends/npfftw_backend.py +239 -507
  23. tme/backends/pytorch_backend.py +21 -145
  24. tme/density.py +233 -363
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +322 -285
  27. tme/matching_exhaustive.py +172 -1493
  28. tme/matching_optimization.py +143 -106
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +280 -386
  31. tme/memory.py +377 -0
  32. tme/orientations.py +52 -12
  33. tme/parser.py +3 -4
  34. tme/preprocessing/_utils.py +61 -32
  35. tme/preprocessing/compose.py +7 -3
  36. tme/preprocessing/frequency_filters.py +49 -39
  37. tme/preprocessing/tilt_series.py +34 -40
  38. tme/preprocessor.py +560 -526
  39. tme/structure.py +491 -188
  40. tme/types.py +5 -3
  41. pytme-0.2.1.dist-info/METADATA +0 -73
  42. pytme-0.2.1.dist-info/RECORD +0 -73
  43. tme/helpers.py +0 -881
  44. tme/matching_constrained.py +0 -195
  45. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  46. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  47. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  48. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  49. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
tme/matching_data.py CHANGED
@@ -1,18 +1,19 @@
1
- """ Data class for holding template matching data.
1
+ """ Class representation of template matching data.
2
2
 
3
3
  Copyright (c) 2023 European Molecular Biology Laboratory
4
4
 
5
5
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
6
  """
7
7
  import warnings
8
- from typing import Tuple, List
8
+ from typing import Tuple, List, Optional
9
9
 
10
10
  import numpy as np
11
11
  from numpy.typing import NDArray
12
12
 
13
13
  from . import Density
14
14
  from .types import ArrayLike
15
- from .backends import backend
15
+ from .preprocessing import Compose
16
+ from .backends import backend as be
16
17
  from .matching_utils import compute_full_convolution_index
17
18
 
18
19
 
@@ -31,9 +32,9 @@ class MatchingData:
31
32
  template_mask : np.ndarray or :py:class:`tme.density.Density`, optional
32
33
  Template mask data.
33
34
  invert_target : bool, optional
34
- Whether to invert and rescale the target before template matching..
35
+ Whether to invert the target before template matching..
35
36
  rotations: np.ndarray, optional
36
- Template rotations to sample. Can be a single (d x d) or a stack (n x d x d)
37
+ Template rotations to sample. Can be a single (d, d) or a stack (n, d, d)
37
38
  of rotation matrices where d is the dimension of the template.
38
39
 
39
40
  Examples
@@ -57,26 +58,18 @@ class MatchingData:
57
58
  invert_target: bool = False,
58
59
  rotations: NDArray = None,
59
60
  ):
60
- self._target = target
61
- self._target_mask = target_mask
62
- self._template_mask = template_mask
63
- self._translation_offset = np.zeros(len(target.shape), dtype=int)
61
+ self.target = target
62
+ self.target_mask = target_mask
64
63
 
65
64
  self.template = template
65
+ if template_mask is not None:
66
+ self.template_mask = template_mask
66
67
 
67
- self._target_pad = np.zeros(len(target.shape), dtype=int)
68
- self._template_pad = np.zeros(len(template.shape), dtype=int)
69
-
70
- self.template_filter = {}
71
- self.target_filter = {}
72
-
68
+ self.rotations = rotations
69
+ self._translation_offset = np.zeros(len(target.shape), dtype=int)
73
70
  self._invert_target = invert_target
74
71
 
75
- self._rotations = rotations
76
- if rotations is not None:
77
- self.rotations = rotations
78
-
79
- self._set_batch_dimension()
72
+ self._set_matching_dimension()
80
73
 
81
74
  @staticmethod
82
75
  def _shape_to_slice(shape: Tuple[int]):
@@ -105,8 +98,7 @@ class MatchingData:
105
98
  NDArray
106
99
  Loaded array.
107
100
  """
108
-
109
- if type(arr) == np.memmap:
101
+ if isinstance(arr, np.memmap):
110
102
  return np.memmap(arr.filename, mode="r", shape=arr.shape, dtype=arr.dtype)
111
103
  return arr
112
104
 
@@ -141,7 +133,7 @@ class MatchingData:
141
133
  NDArray
142
134
  Subset of the input array with padding applied.
143
135
  """
144
- padding = backend.to_numpy_array(padding)
136
+ padding = be.to_numpy_array(padding)
145
137
  padding = np.maximum(padding, 0).astype(int)
146
138
 
147
139
  slice_start = np.array([x.start for x in arr_slice], dtype=int)
@@ -160,38 +152,18 @@ class MatchingData:
160
152
  arr_slice = tuple(slice(*pos) for pos in zip(arr_start, arr_stop))
161
153
  arr_mesh = self._slice_to_mesh(arr_slice, arr.shape)
162
154
 
163
- arr_min, arr_max = None, None
164
- if type(arr) == Density:
165
- if type(arr.data) == np.memmap:
166
- dens = Density.from_file(arr.data.filename, subset=arr_slice)
167
- arr = dens.data
168
- arr_min = dens.metadata.get("min", None)
169
- arr_max = dens.metadata.get("max", None)
155
+ if isinstance(arr, Density):
156
+ if isinstance(arr.data, np.memmap):
157
+ arr = Density.from_file(arr.data.filename, subset=arr_slice).data
170
158
  else:
171
159
  arr = np.asarray(arr.data[*arr_mesh])
172
160
  else:
173
- if type(arr) == np.memmap:
161
+ if isinstance(arr, np.memmap):
174
162
  arr = np.memmap(
175
163
  arr.filename, mode="r", shape=arr.shape, dtype=arr.dtype
176
164
  )
177
165
  arr = np.asarray(arr[*arr_mesh])
178
166
 
179
- def _warn_on_mismatch(
180
- expectation: float, computation: float, name: str
181
- ) -> float:
182
- if expectation is None:
183
- expectation = computation
184
- expectation, computation = float(expectation), float(computation)
185
-
186
- if abs(computation) > abs(expectation):
187
- warnings.warn(
188
- f"Computed {name} value is more extreme than value in file header"
189
- f" (|{computation}| > |{expectation}|). This may lead to issues"
190
- " with padding and contrast inversion."
191
- )
192
-
193
- return expectation
194
-
195
167
  padding = tuple(
196
168
  (left, right)
197
169
  for left, right in zip(
@@ -199,17 +171,11 @@ class MatchingData:
199
171
  np.subtract(right_pad, data_voxels_right),
200
172
  )
201
173
  )
202
- ret = np.pad(arr, padding, mode="reflect")
174
+ arr = np.pad(arr, padding, mode="reflect")
203
175
 
204
176
  if invert:
205
- arr_min = _warn_on_mismatch(arr_min, arr.min(), "min")
206
- arr_max = _warn_on_mismatch(arr_max, arr.max(), "max")
207
-
208
- # Avoid in-place operation in case ret is not floating point
209
- ret = (
210
- -np.divide(np.subtract(ret, arr_min), np.subtract(arr_max, arr_min)) + 1
211
- )
212
- return ret
177
+ arr = -arr
178
+ return arr
213
179
 
214
180
  def subset_by_slice(
215
181
  self,
@@ -220,97 +186,94 @@ class MatchingData:
220
186
  invert_target: bool = False,
221
187
  ) -> "MatchingData":
222
188
  """
223
- Slice the instance arrays based on the provided slices.
189
+ Subset class instance based on slices.
224
190
 
225
191
  Parameters
226
192
  ----------
227
193
  target_slice : tuple of slice, optional
228
- Slices for the target. If not provided, the full shape is used.
194
+ Target subset to use, all by default.
229
195
  template_slice : tuple of slice, optional
230
- Slices for the template. If not provided, the full shape is used.
196
+ Template subset to use, all by default.
231
197
  target_pad : NDArray, optional
232
- Padding for target. Defaults to zeros. If padding exceeds target,
233
- pad with mean.
198
+ Target padding, zero by default.
234
199
  template_pad : NDArray, optional
235
- Padding for template. Defaults to zeros. If padding exceeds template,
236
- pad with mean.
200
+ Template padding, zero by default.
237
201
 
238
202
  Returns
239
203
  -------
240
- MatchingData
241
- Newly allocated sliced class instance.
204
+ :py:class:`MatchingData`
205
+ Newly allocated subset of class instance.
206
+
207
+ Examples
208
+ --------
209
+ >>> import numpy as np
210
+ >>> from tme.matching_data import MatchingData
211
+ >>> target = np.random.rand(50,40,60)
212
+ >>> template = target[15:25, 10:20, 30:40]
213
+ >>> matching_data = MatchingData(target=target, template=template)
214
+ >>> subset = matching_data.subset_by_slice(
215
+ >>> target_slice=(slice(0, 10), slice(10,20), slice(15,35))
216
+ >>> )
242
217
  """
243
- target_shape = self._target.shape
244
- template_shape = self._template.shape
245
-
246
218
  if target_slice is None:
247
- target_slice = self._shape_to_slice(target_shape)
219
+ target_slice = self._shape_to_slice(self._target.shape)
248
220
  if template_slice is None:
249
- template_slice = self._shape_to_slice(template_shape)
221
+ template_slice = self._shape_to_slice(self._template.shape)
250
222
 
251
223
  if target_pad is None:
252
224
  target_pad = np.zeros(len(self._target.shape), dtype=int)
253
225
  if template_pad is None:
254
226
  template_pad = np.zeros(len(self._template.shape), dtype=int)
255
227
 
256
- indices = None
257
- if len(self._target.shape) == len(self._template.shape):
258
- indices = compute_full_convolution_index(
259
- outer_shape=self._target.shape,
260
- inner_shape=self._template.shape,
261
- outer_split=target_slice,
262
- inner_split=template_slice,
263
- )
264
-
228
+ target_mask, template_mask = None, None
265
229
  target_subset = self.subset_array(
266
- arr=self._target,
267
- arr_slice=target_slice,
268
- padding=target_pad,
269
- invert=self._invert_target,
230
+ self._target, target_slice, target_pad, invert=self._invert_target
270
231
  )
271
-
272
232
  template_subset = self.subset_array(
273
- arr=self._template,
274
- arr_slice=template_slice,
275
- padding=template_pad,
233
+ arr=self._template, arr_slice=template_slice, padding=template_pad
276
234
  )
277
- ret = self.__class__(target=target_subset, template=template_subset)
278
-
279
- target_offset = np.zeros(len(self._output_target_shape), dtype=int)
280
- target_offset[(target_offset.size - len(target_slice)) :] = [
281
- x.start for x in target_slice
282
- ]
283
- template_offset = np.zeros(len(self._output_target_shape), dtype=int)
284
- template_offset[(template_offset.size - len(template_slice)) :] = [
285
- x.start for x in template_slice
286
- ]
287
- ret._translation_offset = target_offset
288
-
289
- ret.template_filter = self.template_filter
290
- ret.target_filter = self.target_filter
291
- ret._rotations, ret.indices = self.rotations, indices
292
- ret._target_pad, ret._template_pad = target_pad, template_pad
293
- ret._invert_target = self._invert_target
294
-
295
235
  if self._target_mask is not None:
296
- ret._target_mask = self.subset_array(
236
+ target_mask = self.subset_array(
297
237
  arr=self._target_mask, arr_slice=target_slice, padding=target_pad
298
238
  )
299
239
  if self._template_mask is not None:
300
- ret.template_mask = self.subset_array(
301
- arr=self._template_mask,
302
- arr_slice=template_slice,
303
- padding=template_pad,
240
+ template_mask = self.subset_array(
241
+ arr=self._template_mask, arr_slice=template_slice, padding=template_pad
304
242
  )
305
243
 
306
- target_dims, template_dims = None, None
307
- if hasattr(self, "_target_dims"):
308
- target_dims = self._target_dims
244
+ ret = self.__class__(
245
+ target=target_subset,
246
+ template=template_subset,
247
+ template_mask=template_mask,
248
+ target_mask=target_mask,
249
+ rotations=self.rotations,
250
+ invert_target=self._invert_target,
251
+ )
309
252
 
310
- if hasattr(self, "_template_dims"):
311
- template_dims = self._template_dims
253
+ # Deal with splitting offsets
254
+ target_offset = np.zeros(len(self._output_target_shape), dtype=int)
255
+ offset = target_offset.size - len(target_slice)
256
+ target_offset[offset:] = [x.start for x in target_slice]
257
+ template_offset = np.zeros(len(self._output_target_shape), dtype=int)
258
+ offset = template_offset.size - len(template_slice)
259
+ template_offset[offset:] = [x.start for x in template_slice]
260
+ ret._translation_offset = target_offset
261
+ if len(self._target.shape) == len(self._template.shape):
262
+ ret.indices = compute_full_convolution_index(
263
+ outer_shape=self._target.shape,
264
+ inner_shape=self._template.shape,
265
+ outer_split=target_slice,
266
+ inner_split=template_slice,
267
+ )
312
268
 
313
- ret._set_batch_dimension(target_dims=target_dims, template_dims=template_dims)
269
+ ret._is_padded = be.sum(be.to_backend_array(target_pad)) > 0
270
+ ret.target_filter = self.target_filter
271
+ ret.template_filter = self.template_filter
272
+
273
+ ret._set_matching_dimension(
274
+ target_dims=getattr(self, "_target_dims", None),
275
+ template_dims=getattr(self, "_template_dims", None),
276
+ )
314
277
 
315
278
  return ret
316
279
 
@@ -318,47 +281,42 @@ class MatchingData:
318
281
  """
319
282
  Transfer and convert types of class instance's data arrays to the current backend
320
283
  """
321
- backend_arr = type(backend.zeros((1), dtype=backend._float_dtype))
284
+ backend_arr = type(be.zeros((1), dtype=be._float_dtype))
322
285
  for attr_name, attr_value in vars(self).items():
323
286
  converted_array = None
324
287
  if isinstance(attr_value, np.ndarray):
325
- converted_array = backend.to_backend_array(attr_value.copy())
288
+ converted_array = be.to_backend_array(attr_value.copy())
326
289
  elif isinstance(attr_value, backend_arr):
327
- converted_array = backend.to_backend_array(attr_value)
290
+ converted_array = be.to_backend_array(attr_value)
328
291
  else:
329
292
  continue
330
293
 
331
- current_dtype = backend.get_fundamental_dtype(converted_array)
332
- target_dtype = backend._fundamental_dtypes[current_dtype]
294
+ current_dtype = be.get_fundamental_dtype(converted_array)
295
+ target_dtype = be._fundamental_dtypes[current_dtype]
333
296
 
334
297
  # Optional, but scores are float so we avoid casting and potential issues
335
298
  if attr_name in ("_template", "_template_mask", "_target", "_target_mask"):
336
- target_dtype = backend._float_dtype
299
+ target_dtype = be._float_dtype
337
300
 
338
301
  if target_dtype != current_dtype:
339
- converted_array = backend.astype(converted_array, target_dtype)
302
+ converted_array = be.astype(converted_array, target_dtype)
340
303
 
341
304
  setattr(self, attr_name, converted_array)
342
305
 
343
- def _set_batch_dimension(
306
+ def _set_matching_dimension(
344
307
  self, target_dims: Tuple[int] = None, template_dims: Tuple[int] = None
345
308
  ) -> None:
346
309
  """
347
- Sets the shapes of target and template for template matching considering
348
- their corresponding batch dimensions.
349
-
310
+ Sets matching dimensions for target and template.
350
311
  Parameters
351
312
  ----------
352
- target_dims : Tuple[int], optional
353
- A tuple of integers specifying the batch dimensions of the target. If None,
354
- the target is assumed not to have batch dimensions.
355
- template_dims : Tuple[int], optional
356
- A tuple of integers specifying the batch dimensions of the template. If None,
357
- the template is assumed not to have batch dimensions.
313
+ target_dims : tuple of ints, optional
314
+ Target batch dimensions, None by default.
315
+ template_dims : tuple of ints, optional
316
+ Template batch dimensions, None by default.
358
317
 
359
318
  Notes
360
319
  -----
361
-
362
320
  If the target and template share a batch dimension, the target will
363
321
  take precendence and the template dimension will be shifted to the right.
364
322
  If target and template have the same dimension, but target specifies batch
@@ -386,15 +344,9 @@ class MatchingData:
386
344
 
387
345
  matching_dims = target_measurement_dims + batch_dims
388
346
 
389
- target_shape = backend.full(
390
- shape=(matching_dims,), fill_value=1, dtype=backend._int_dtype
391
- )
392
- template_shape = backend.full(
393
- shape=(matching_dims,), fill_value=1, dtype=backend._int_dtype
394
- )
395
- batch_mask = backend.full(
396
- shape=(matching_dims,), fill_value=1, dtype=backend._int_dtype
397
- )
347
+ target_shape = np.full(shape=matching_dims, fill_value=1, dtype=int)
348
+ template_shape = np.full(shape=matching_dims, fill_value=1, dtype=int)
349
+ batch_mask = np.full(shape=matching_dims, fill_value=1, dtype=int)
398
350
 
399
351
  target_index, template_index = 0, 0
400
352
  for k in range(matching_dims):
@@ -420,9 +372,9 @@ class MatchingData:
420
372
  if template_dim < template_ndim:
421
373
  template_shape[k] = self._template.shape[template_dim]
422
374
 
423
- self._output_target_shape = target_shape
424
- self._output_template_shape = template_shape
425
- self._batch_mask = batch_mask
375
+ self._output_target_shape = tuple(int(x) for x in target_shape)
376
+ self._output_template_shape = tuple(int(x) for x in template_shape)
377
+ self._batch_mask = tuple(int(x) for x in batch_mask)
426
378
 
427
379
  @staticmethod
428
380
  def _compute_batch_dimension(
@@ -433,22 +385,22 @@ class MatchingData:
433
385
 
434
386
  Parameters
435
387
  ----------
436
- batch_dims : Tuple[int]
388
+ batch_dims : tuple of ints
437
389
  A tuple of integers representing the batch dimensions.
438
390
  ndim : int
439
391
  The number of dimensions of the array.
440
392
 
441
393
  Returns
442
394
  -------
443
- Tuple[ArrayLike, Tuple]
444
- A tuple containing the mask (as an ArrayLike) and the validated batch dimensions.
395
+ Tuple[ArrayLike, tuple of ints]
396
+ Mask and the corresponding batch dimensions.
445
397
 
446
398
  Raises
447
399
  ------
448
400
  ValueError
449
401
  If any dimension in batch_dims is not less than ndim.
450
402
  """
451
- mask = backend.zeros(ndim, dtype=bool)
403
+ mask = np.zeros(ndim, dtype=int)
452
404
  if batch_dims is None:
453
405
  return mask, ()
454
406
 
@@ -463,215 +415,292 @@ class MatchingData:
463
415
 
464
416
  return mask, batch_dims
465
417
 
466
- def target_padding(self, pad_target: bool = False) -> ArrayLike:
418
+ def target_padding(self, pad_target: bool = False) -> Tuple[int]:
467
419
  """
468
420
  Computes padding for the target based on the template's shape.
469
421
 
470
422
  Parameters
471
423
  ----------
472
424
  pad_target : bool, default False
473
- If True, computes the padding required for the target. If False,
474
- an array of zeros is returned.
425
+ Whether to pad the target, default returns an array of zeros.
475
426
 
476
427
  Returns
477
428
  -------
478
- ArrayLike
479
- An array indicating the padding for each dimension of the target.
429
+ tuple of ints
430
+ Padding along each dimension of the target.
480
431
  """
481
- target_padding = backend.zeros(
482
- len(self._output_target_shape), dtype=backend._int_dtype
483
- )
484
-
432
+ target_padding = np.zeros(len(self._output_target_shape), dtype=int)
485
433
  if pad_target:
486
- backend.subtract(
434
+ target_padding = np.subtract(
487
435
  self._output_template_shape,
488
- backend.mod(self._output_template_shape, 2),
489
- out=target_padding,
436
+ np.mod(self._output_template_shape, 2),
490
437
  )
491
438
  if hasattr(self, "_is_target_batch"):
492
- target_padding[self._is_target_batch] = 0
439
+ target_padding = np.multiply(
440
+ target_padding,
441
+ np.subtract(1, self._is_target_batch),
442
+ )
493
443
 
494
- return target_padding
444
+ return tuple(int(x) for x in target_padding)
495
445
 
496
- def fourier_padding(
497
- self, pad_fourier: bool = False
498
- ) -> Tuple[ArrayLike, ArrayLike, ArrayLike]:
446
+ @staticmethod
447
+ def _fourier_padding(
448
+ target_shape: NDArray,
449
+ template_shape: NDArray,
450
+ batch_mask: NDArray = None,
451
+ pad_fourier: bool = False,
452
+ ) -> Tuple[Tuple, Tuple, Tuple]:
499
453
  """
500
- Computes an efficient shape for the forward Fourier transform, the
501
- corresponding shape of the real-valued FFT, and the associated
502
- translation shift.
503
-
504
- Parameters
505
- ----------
506
- pad_fourier : bool, default False
507
- If true, returns the shape of the full-convolution defined as sum of target
508
- shape and template shape minus one. By default, returns unpadded transform.
509
-
510
- Returns
511
- -------
512
- Tuple[ArrayLike, ArrayLike, ArrayLike]
513
- A tuple containing the calculated fast shape, fast Fourier transform shape,
514
- and the Fourier shift values, respectively.
454
+ Determines an efficient shape for Fourier transforms considering zero-padding.
515
455
  """
516
- template_shape = self._template.shape
517
- if hasattr(self, "_output_template_shape"):
518
- template_shape = self._output_template_shape
519
- template_shape = backend.to_backend_array(template_shape)
520
-
521
- target_shape = self._target.shape
522
- if hasattr(self, "_output_target_shape"):
523
- target_shape = self._output_target_shape
524
- target_shape = backend.to_backend_array(target_shape)
456
+ fourier_pad = template_shape
457
+ fourier_shift = np.zeros_like(template_shape)
525
458
 
526
- fourier_pad = backend.to_backend_array(template_shape)
527
- fourier_shift = backend.zeros(len(fourier_pad))
459
+ if batch_mask is None:
460
+ batch_mask = np.zeros_like(template_shape)
461
+ batch_mask = np.asarray(batch_mask)
528
462
 
529
463
  if not pad_fourier:
530
- fourier_pad = backend.full(
531
- shape=(len(fourier_pad),),
532
- fill_value=1,
533
- dtype=backend._int_dtype,
534
- )
464
+ fourier_pad = np.ones(len(fourier_pad), dtype=int)
465
+ fourier_pad = np.multiply(fourier_pad, 1 - batch_mask)
466
+ fourier_pad = np.add(fourier_pad, batch_mask)
535
467
 
536
- fourier_pad = backend.to_backend_array(fourier_pad)
537
- if hasattr(self, "_batch_mask"):
538
- batch_mask = backend.to_backend_array(self._batch_mask)
539
- backend.multiply(fourier_pad, 1 - batch_mask, out=fourier_pad)
540
- backend.add(fourier_pad, batch_mask, out=fourier_pad)
541
-
542
- pad_shape = backend.maximum(target_shape, template_shape)
543
- ret = backend.compute_convolution_shapes(pad_shape, fourier_pad)
468
+ pad_shape = np.maximum(target_shape, template_shape)
469
+ ret = be.compute_convolution_shapes(pad_shape, fourier_pad)
544
470
  convolution_shape, fast_shape, fast_ft_shape = ret
545
471
  if not pad_fourier:
546
- fourier_shift = 1 - backend.astype(backend.divide(template_shape, 2), int)
547
- fourier_shift -= backend.mod(template_shape, 2)
548
- shape_diff = backend.subtract(fast_shape, convolution_shape)
549
- shape_diff = backend.astype(backend.divide(shape_diff, 2), int)
472
+ fourier_shift = 1 - np.divide(template_shape, 2).astype(int)
473
+ fourier_shift -= np.mod(template_shape, 2)
474
+ shape_diff = np.subtract(fast_shape, convolution_shape)
475
+ shape_diff = np.divide(shape_diff, 2).astype(int)
476
+ shape_diff = np.multiply(shape_diff, 1 - batch_mask)
477
+ np.add(fourier_shift, shape_diff, out=fourier_shift)
478
+
479
+ fourier_shift = fourier_shift.astype(int)
480
+
481
+ shape_diff = np.subtract(target_shape, template_shape)
482
+ shape_diff = np.multiply(shape_diff, 1 - batch_mask)
483
+ if np.sum(shape_diff < 0) and not pad_fourier:
484
+ warnings.warn(
485
+ "Target is larger than template and Fourier padding is turned off. "
486
+ "This may lead to inaccurate results. Prefer swapping template and target, "
487
+ "enable padding or turn off template centering."
488
+ )
489
+ fourier_shift = np.subtract(fourier_shift, np.divide(shape_diff, 2))
490
+ fourier_shift = fourier_shift.astype(int)
550
491
 
551
- if hasattr(self, "_batch_mask"):
552
- batch_mask = backend.to_backend_array(self._batch_mask)
553
- backend.multiply(shape_diff, 1 - batch_mask, out=shape_diff)
492
+ return tuple(fast_shape), tuple(fast_ft_shape), tuple(fourier_shift)
554
493
 
555
- backend.add(fourier_shift, shape_diff, out=fourier_shift)
494
+ def fourier_padding(self, pad_fourier: bool = False) -> Tuple[Tuple, Tuple, Tuple]:
495
+ """
496
+ Computes efficient shape four Fourier transforms and potential associated shifts.
556
497
 
557
- fourier_shift = backend.astype(fourier_shift, backend._int_dtype)
498
+ Parameters
499
+ ----------
500
+ pad_fourier : bool, default False
501
+ If true, returns the shape of the full-convolution defined as sum of target
502
+ shape and template shape minus one, False by default.
558
503
 
559
- return fast_shape, fast_ft_shape, fourier_shift
504
+ Returns
505
+ -------
506
+ Tuple[tuple of int, tuple of int, tuple of int]
507
+ Tuple with real and complex Fourier transform shape, and corresponding shift.
508
+ """
509
+ return self._fourier_padding(
510
+ target_shape=be.to_numpy_array(self._output_target_shape),
511
+ template_shape=be.to_numpy_array(self._output_template_shape),
512
+ batch_mask=be.to_numpy_array(self._batch_mask),
513
+ pad_fourier=pad_fourier,
514
+ )
560
515
 
561
516
  @property
562
517
  def rotations(self):
563
- """Return stored rotation matrices.."""
518
+ """Return stored rotation matrices."""
564
519
  return self._rotations
565
520
 
566
521
  @rotations.setter
567
522
  def rotations(self, rotations: NDArray):
568
523
  """
569
- Set and reshape the rotation matrices for template matching.
524
+ Set :py:attr:`MatchingData.rotations`.
570
525
 
571
526
  Parameters
572
527
  ----------
573
528
  rotations : NDArray
574
- Rotations in shape (k x k), or (n x k x k).
529
+ Rotations matrices with shape (d, d) or (n, d, d).
575
530
  """
576
- if rotations.__class__ != np.ndarray:
577
- raise ValueError("Rotation set has to be of type numpy ndarray.")
578
- if rotations.ndim == 2:
531
+ if rotations is None:
532
+ print("No rotations provided, assuming identity for now.")
533
+ rotations = np.eye(len(self._target.shape))
534
+
535
+ if rotations.ndim not in (2, 3):
536
+ raise ValueError("Rotations have to be a rank 2 or 3 array.")
537
+ elif rotations.ndim == 2:
579
538
  print("Reshaping rotations array to rank 3.")
580
539
  rotations = rotations.reshape(1, *rotations.shape)
581
- elif rotations.ndim == 3:
582
- pass
583
- else:
584
- raise ValueError("Rotations have to be a rank 2 or 3 array.")
585
540
  self._rotations = rotations.astype(np.float32)
586
541
 
542
+ @staticmethod
543
+ def _get_data(attribute, output_shape: Tuple[int], reverse: bool = False):
544
+ if isinstance(attribute, Density):
545
+ attribute = attribute.data
546
+
547
+ if attribute is not None:
548
+ if reverse:
549
+ attribute = be.reverse(attribute)
550
+ attribute = attribute.reshape(tuple(int(x) for x in output_shape))
551
+
552
+ return attribute
553
+
587
554
  @property
588
555
  def target(self):
589
- """Returns the target."""
590
- target = self._target
591
- if isinstance(self._target, Density):
592
- target = self._target.data
556
+ """
557
+ Return the target.
558
+
559
+ Returns
560
+ -------
561
+ NDArray
562
+ Output data.
563
+ """
564
+ return self._get_data(self._target, self._output_target_shape, False)
565
+
566
+ @property
567
+ def target_mask(self):
568
+ """
569
+ Return the target mask.
593
570
 
594
- out_shape = tuple(int(x) for x in self._output_target_shape)
595
- return target.reshape(out_shape)
571
+ Returns
572
+ -------
573
+ NDArray
574
+ Output data.
575
+ """
576
+ target_mask = getattr(self, "_target_mask", None)
577
+ return self._get_data(target_mask, self._output_target_shape, False)
596
578
 
597
579
  @property
598
580
  def template(self):
599
- """Returns the reversed template."""
600
- template = self._template
601
- if isinstance(self._template, Density):
602
- template = self._template.data
603
- template = backend.reverse(template)
604
- out_shape = tuple(int(x) for x in self._output_template_shape)
605
- return template.reshape(out_shape)
581
+ """
582
+ Return the reversed template.
583
+
584
+ Returns
585
+ -------
586
+ NDArray
587
+ Output data.
588
+ """
589
+ return self._get_data(self._template, self._output_template_shape, True)
590
+
591
+ @property
592
+ def template_mask(self):
593
+ """
594
+ Return the reversed template mask.
595
+
596
+ Returns
597
+ -------
598
+ NDArray
599
+ Output data.
600
+ """
601
+ template_mask = getattr(self, "_template_mask", None)
602
+ return self._get_data(template_mask, self._output_template_shape, True)
603
+
604
+ @target.setter
605
+ def target(self, arr: NDArray):
606
+ """
607
+ Set :py:attr:`MatchingData.target`.
608
+
609
+ Parameters
610
+ ----------
611
+ arr : NDArray
612
+ Array to set as the target.
613
+ """
614
+ self._target = arr
606
615
 
607
616
  @template.setter
608
- def template(self, template: NDArray):
617
+ def template(self, arr: NDArray):
609
618
  """
610
- Set the template array. If not already defined, also initializes
611
- :py:attr:`MatchingData.template_mask` to an uninformative mask filled with
612
- ones.
619
+ Set :py:attr:`MatchingData.template` and initializes
620
+ :py:attr:`MatchingData.template_mask` to an to an uninformative
621
+ mask filled with ones if not already defined.
613
622
 
614
623
  Parameters
615
624
  ----------
616
- template : NDArray
625
+ arr : NDArray
617
626
  Array to set as the template.
618
627
  """
619
- self._templateshape = template.shape[::-1]
620
- if self._template_mask is None:
621
- self._template_mask = backend.full(
622
- shape=template.shape, dtype=float, fill_value=1
628
+ self._template = arr
629
+ if getattr(self, "_template_mask", None) is None:
630
+ self._template_mask = be.full(
631
+ shape=arr.shape, dtype=be._float_dtype, fill_value=1
623
632
  )
624
633
 
625
- self._template = template
634
+ @staticmethod
635
+ def _set_mask(mask, shape: Tuple[int]):
636
+ if mask is not None:
637
+ if mask.shape != shape:
638
+ raise ValueError(
639
+ "Mask and respective data have to have the same shape."
640
+ )
641
+ return mask
626
642
 
627
- @property
628
- def target_mask(self):
629
- """Returns the target mask NDArray."""
630
- target_mask = self._target_mask
631
- if isinstance(self._target_mask, Density):
632
- target_mask = self._target_mask.data
643
+ @target_mask.setter
644
+ def target_mask(self, arr: NDArray):
645
+ """
646
+ Set :py:attr:`MatchingData.target_mask`.
633
647
 
634
- if target_mask is not None:
635
- out_shape = tuple(int(x) for x in self._output_target_shape)
636
- target_mask = target_mask.reshape(out_shape)
648
+ Parameters
649
+ ----------
650
+ arr : NDArray
651
+ Array to set as the target_mask.
652
+ """
653
+ self._target_mask = self._set_mask(mask=arr, shape=self._target.shape)
637
654
 
638
- return target_mask
655
+ @template_mask.setter
656
+ def template_mask(self, arr: NDArray):
657
+ """
658
+ Set :py:attr:`MatchingData.template_mask`.
639
659
 
640
- @target_mask.setter
641
- def target_mask(self, mask: NDArray):
642
- """Sets the target mask."""
643
- if not np.all(self.target.shape == mask.shape):
644
- raise ValueError("Target and its mask have to have the same shape.")
660
+ Parameters
661
+ ----------
662
+ arr : NDArray
663
+ Array to set as the template_mask.
664
+ """
665
+ self._template_mask = self._set_mask(mask=arr, shape=self._template.shape)
645
666
 
646
- self._target_mask = mask
667
+ @staticmethod
668
+ def _set_filter(composable_filter) -> Optional[Compose]:
669
+ if isinstance(composable_filter, Compose):
670
+ return composable_filter
671
+ return None
647
672
 
648
673
  @property
649
- def template_mask(self):
674
+ def template_filter(self) -> Optional[Compose]:
650
675
  """
651
- Set the template mask array after reversing it.
676
+ Returns the composable template filter.
652
677
 
653
- Parameters
654
- ----------
655
- template : NDArray
656
- Array to set as the template.
678
+ Returns
679
+ -------
680
+ :py:class:`tme.preprocessing.Compose` | None
681
+ Composable template filter or None.
657
682
  """
658
- mask = self._template_mask
659
- if isinstance(self._template_mask, Density):
660
- mask = self._template_mask.data
683
+ return getattr(self, "_template_filter", None)
661
684
 
662
- if mask is not None:
663
- mask = backend.reverse(mask)
664
- out_shape = tuple(int(x) for x in self._output_template_shape)
665
- mask = mask.reshape(out_shape)
666
- return mask
685
+ @property
686
+ def target_filter(self) -> Optional[Compose]:
687
+ """
688
+ Returns the composable target filter.
667
689
 
668
- @template_mask.setter
669
- def template_mask(self, mask: NDArray):
670
- """Returns the reversed template mask NDArray."""
671
- if not np.all(self._templateshape[::-1] == mask.shape):
672
- raise ValueError("Template and its mask have to have the same shape.")
690
+ Returns
691
+ -------
692
+ :py:class:`tme.preprocessing.Compose` | None
693
+ Composable filter or None.
694
+ """
695
+ return getattr(self, "_target_filter", None)
673
696
 
674
- self._template_mask = mask
697
+ @template_filter.setter
698
+ def template_filter(self, composable_filter: Compose):
699
+ self._template_filter = self._set_filter(composable_filter)
700
+
701
+ @target_filter.setter
702
+ def target_filter(self, composable_filter: Compose):
703
+ self._target_filter = self._set_filter(composable_filter)
675
704
 
676
705
  def _split_rotations_on_jobs(self, n_jobs: int) -> List[NDArray]:
677
706
  """
@@ -696,3 +725,11 @@ class MatchingData:
696
725
  end_rot = None
697
726
  rot_list.append(self.rotations[init_rot:end_rot])
698
727
  return rot_list
728
+
729
+ def _free_data(self):
730
+ """
731
+ Free (dereference) data arrays owned by the class instance.
732
+ """
733
+ attrs = ("_target", "_template", "_template_mask", "_target_mask")
734
+ for attr in attrs:
735
+ setattr(self, attr, None)