pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. pytme-0.2.2.data/scripts/match_template.py +1187 -0
  2. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/postprocess.py +170 -71
  3. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +179 -86
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +126 -87
  8. scripts/match_template.py +596 -209
  9. scripts/match_template_filters.py +571 -223
  10. scripts/postprocess.py +170 -71
  11. scripts/preprocessor_gui.py +179 -86
  12. scripts/refine_matches.py +567 -159
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +627 -855
  16. tme/backends/__init__.py +41 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +120 -225
  19. tme/backends/jax_backend.py +282 -0
  20. tme/backends/matching_backend.py +464 -388
  21. tme/backends/mlx_backend.py +45 -68
  22. tme/backends/npfftw_backend.py +256 -514
  23. tme/backends/pytorch_backend.py +41 -154
  24. tme/density.py +312 -421
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +366 -303
  27. tme/matching_exhaustive.py +279 -1521
  28. tme/matching_optimization.py +234 -129
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +281 -387
  31. tme/memory.py +377 -0
  32. tme/orientations.py +226 -66
  33. tme/parser.py +3 -4
  34. tme/preprocessing/__init__.py +2 -0
  35. tme/preprocessing/_utils.py +217 -0
  36. tme/preprocessing/composable_filter.py +31 -0
  37. tme/preprocessing/compose.py +55 -0
  38. tme/preprocessing/frequency_filters.py +388 -0
  39. tme/preprocessing/tilt_series.py +1011 -0
  40. tme/preprocessor.py +574 -530
  41. tme/structure.py +495 -189
  42. tme/types.py +5 -3
  43. pytme-0.2.0b0.data/scripts/match_template.py +0 -800
  44. pytme-0.2.0b0.dist-info/METADATA +0 -73
  45. pytme-0.2.0b0.dist-info/RECORD +0 -66
  46. tme/helpers.py +0 -881
  47. tme/matching_constrained.py +0 -195
  48. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  49. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  50. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
tme/matching_data.py CHANGED
@@ -1,18 +1,19 @@
1
- """ Data class for holding template matching data.
1
+ """ Class representation of template matching data.
2
2
 
3
3
  Copyright (c) 2023 European Molecular Biology Laboratory
4
4
 
5
5
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
6
  """
7
7
  import warnings
8
- from typing import Tuple, List
8
+ from typing import Tuple, List, Optional
9
9
 
10
10
  import numpy as np
11
11
  from numpy.typing import NDArray
12
12
 
13
13
  from . import Density
14
14
  from .types import ArrayLike
15
- from .backends import backend
15
+ from .preprocessing import Compose
16
+ from .backends import backend as be
16
17
  from .matching_utils import compute_full_convolution_index
17
18
 
18
19
 
@@ -22,34 +23,53 @@ class MatchingData:
22
23
 
23
24
  Parameters
24
25
  ----------
25
- target : np.ndarray or Density
26
- Target data array for template matching.
27
- template : np.ndarray or Density
28
- Template data array for template matching.
26
+ target : np.ndarray or :py:class:`tme.density.Density`
27
+ Target data.
28
+ template : np.ndarray or :py:class:`tme.density.Density`
29
+ Template data.
30
+ target_mask : np.ndarray or :py:class:`tme.density.Density`, optional
31
+ Target mask data.
32
+ template_mask : np.ndarray or :py:class:`tme.density.Density`, optional
33
+ Template mask data.
34
+ invert_target : bool, optional
35
+ Whether to invert the target before template matching.
36
+ rotations: np.ndarray, optional
37
+ Template rotations to sample. Can be a single (d, d) or a stack (n, d, d)
38
+ of rotation matrices where d is the dimension of the template.
39
+
40
+ Examples
41
+ --------
42
+ The following achieves the minimal definition of a :py:class:`MatchingData` instance.
43
+
44
+ >>> import numpy as np
45
+ >>> from tme.matching_data import MatchingData
46
+ >>> target = np.random.rand(50,40,60)
47
+ >>> template = target[15:25, 10:20, 30:40]
48
+ >>> matching_data = MatchingData(target=target, template=template)
29
49
 
30
50
  """
31
51
 
32
- def __init__(self, target: NDArray, template: NDArray):
33
- self._default_dtype = np.float32
34
- self._complex_dtype = np.complex64
35
-
36
- self._target = target
37
- self._target_mask = None
38
- self._template_mask = None
39
- self._translation_offset = np.zeros(len(target.shape), dtype=int)
52
+ def __init__(
53
+ self,
54
+ target: NDArray,
55
+ template: NDArray,
56
+ template_mask: NDArray = None,
57
+ target_mask: NDArray = None,
58
+ invert_target: bool = False,
59
+ rotations: NDArray = None,
60
+ ):
61
+ self.target = target
62
+ self.target_mask = target_mask
40
63
 
41
64
  self.template = template
65
+ if template_mask is not None:
66
+ self.template_mask = template_mask
42
67
 
43
- self._target_pad = np.zeros(len(target.shape), dtype=int)
44
- self._template_pad = np.zeros(len(template.shape), dtype=int)
45
-
46
- self.template_filter = {}
47
- self.target_filter = {}
48
-
49
- self._invert_target = False
50
- self._rotations = None
68
+ self.rotations = rotations
69
+ self._translation_offset = np.zeros(len(target.shape), dtype=int)
70
+ self._invert_target = invert_target
51
71
 
52
- self._set_batch_dimension()
72
+ self._set_matching_dimension()
53
73
 
54
74
  @staticmethod
55
75
  def _shape_to_slice(shape: Tuple[int]):
@@ -78,8 +98,7 @@ class MatchingData:
78
98
  NDArray
79
99
  Loaded array.
80
100
  """
81
-
82
- if type(arr) == np.memmap:
101
+ if isinstance(arr, np.memmap):
83
102
  return np.memmap(arr.filename, mode="r", shape=arr.shape, dtype=arr.dtype)
84
103
  return arr
85
104
 
@@ -114,7 +133,7 @@ class MatchingData:
114
133
  NDArray
115
134
  Subset of the input array with padding applied.
116
135
  """
117
- padding = backend.to_numpy_array(padding)
136
+ padding = be.to_numpy_array(padding)
118
137
  padding = np.maximum(padding, 0).astype(int)
119
138
 
120
139
  slice_start = np.array([x.start for x in arr_slice], dtype=int)
@@ -133,38 +152,18 @@ class MatchingData:
133
152
  arr_slice = tuple(slice(*pos) for pos in zip(arr_start, arr_stop))
134
153
  arr_mesh = self._slice_to_mesh(arr_slice, arr.shape)
135
154
 
136
- arr_min, arr_max = None, None
137
- if type(arr) == Density:
138
- if type(arr.data) == np.memmap:
139
- dens = Density.from_file(arr.data.filename, subset=arr_slice)
140
- arr = dens.data
141
- arr_min = dens.metadata.get("min", None)
142
- arr_max = dens.metadata.get("max", None)
155
+ if isinstance(arr, Density):
156
+ if isinstance(arr.data, np.memmap):
157
+ arr = Density.from_file(arr.data.filename, subset=arr_slice).data
143
158
  else:
144
159
  arr = np.asarray(arr.data[*arr_mesh])
145
160
  else:
146
- if type(arr) == np.memmap:
161
+ if isinstance(arr, np.memmap):
147
162
  arr = np.memmap(
148
163
  arr.filename, mode="r", shape=arr.shape, dtype=arr.dtype
149
164
  )
150
165
  arr = np.asarray(arr[*arr_mesh])
151
166
 
152
- def _warn_on_mismatch(
153
- expectation: float, computation: float, name: str
154
- ) -> float:
155
- if expectation is None:
156
- expectation = computation
157
- expectation, computation = float(expectation), float(computation)
158
-
159
- if abs(computation) > abs(expectation):
160
- warnings.warn(
161
- f"Computed {name} value is more extreme than value in file header"
162
- f" (|{computation}| > |{expectation}|). This may lead to issues"
163
- " with padding and contrast inversion."
164
- )
165
-
166
- return expectation
167
-
168
167
  padding = tuple(
169
168
  (left, right)
170
169
  for left, right in zip(
@@ -172,17 +171,11 @@ class MatchingData:
172
171
  np.subtract(right_pad, data_voxels_right),
173
172
  )
174
173
  )
175
- ret = np.pad(arr, padding, mode="reflect")
174
+ arr = np.pad(arr, padding, mode="reflect")
176
175
 
177
176
  if invert:
178
- arr_min = _warn_on_mismatch(arr_min, arr.min(), "min")
179
- arr_max = _warn_on_mismatch(arr_max, arr.max(), "max")
180
-
181
- # Avoid in-place operation in case ret is not floating point
182
- ret = (
183
- -np.divide(np.subtract(ret, arr_min), np.subtract(arr_max, arr_min)) + 1
184
- )
185
- return ret
177
+ arr = -arr
178
+ return arr
186
179
 
187
180
  def subset_by_slice(
188
181
  self,
@@ -193,135 +186,137 @@ class MatchingData:
193
186
  invert_target: bool = False,
194
187
  ) -> "MatchingData":
195
188
  """
196
- Slice the instance arrays based on the provided slices.
189
+ Subset class instance based on slices.
197
190
 
198
191
  Parameters
199
192
  ----------
200
193
  target_slice : tuple of slice, optional
201
- Slices for the target. If not provided, the full shape is used.
194
+ Target subset to use, all by default.
202
195
  template_slice : tuple of slice, optional
203
- Slices for the template. If not provided, the full shape is used.
196
+ Template subset to use, all by default.
204
197
  target_pad : NDArray, optional
205
- Padding for target. Defaults to zeros. If padding exceeds target,
206
- pad with mean.
198
+ Target padding, zero by default.
207
199
  template_pad : NDArray, optional
208
- Padding for template. Defaults to zeros. If padding exceeds template,
209
- pad with mean.
200
+ Template padding, zero by default.
210
201
 
211
202
  Returns
212
203
  -------
213
- MatchingData
214
- Newly allocated sliced class instance.
204
+ :py:class:`MatchingData`
205
+ Newly allocated subset of class instance.
206
+
207
+ Examples
208
+ --------
209
+ >>> import numpy as np
210
+ >>> from tme.matching_data import MatchingData
211
+ >>> target = np.random.rand(50,40,60)
212
+ >>> template = target[15:25, 10:20, 30:40]
213
+ >>> matching_data = MatchingData(target=target, template=template)
214
+ >>> subset = matching_data.subset_by_slice(
215
+ >>> target_slice=(slice(0, 10), slice(10,20), slice(15,35))
216
+ >>> )
215
217
  """
216
- target_shape = self._target.shape
217
- template_shape = self._template.shape
218
-
219
218
  if target_slice is None:
220
- target_slice = self._shape_to_slice(target_shape)
219
+ target_slice = self._shape_to_slice(self._target.shape)
221
220
  if template_slice is None:
222
- template_slice = self._shape_to_slice(template_shape)
221
+ template_slice = self._shape_to_slice(self._template.shape)
223
222
 
224
223
  if target_pad is None:
225
224
  target_pad = np.zeros(len(self._target.shape), dtype=int)
226
225
  if template_pad is None:
227
226
  template_pad = np.zeros(len(self._template.shape), dtype=int)
228
227
 
229
- indices = None
230
- if len(self._target.shape) == len(self._template.shape):
231
- indices = compute_full_convolution_index(
232
- outer_shape=self._target.shape,
233
- inner_shape=self._template.shape,
234
- outer_split=target_slice,
235
- inner_split=template_slice,
236
- )
237
-
228
+ target_mask, template_mask = None, None
238
229
  target_subset = self.subset_array(
239
- arr=self._target,
240
- arr_slice=target_slice,
241
- padding=target_pad,
242
- invert=self._invert_target,
230
+ self._target, target_slice, target_pad, invert=self._invert_target
243
231
  )
244
-
245
232
  template_subset = self.subset_array(
246
- arr=self._template,
247
- arr_slice=template_slice,
248
- padding=template_pad,
233
+ arr=self._template, arr_slice=template_slice, padding=template_pad
249
234
  )
250
- ret = self.__class__(target=target_subset, template=template_subset)
251
-
252
- target_offset = np.zeros(len(self._output_target_shape), dtype=int)
253
- target_offset[(target_offset.size - len(target_slice)) :] = [
254
- x.start for x in target_slice
255
- ]
256
- template_offset = np.zeros(len(self._output_target_shape), dtype=int)
257
- template_offset[(template_offset.size - len(template_slice)) :] = [
258
- x.start for x in template_slice
259
- ]
260
- ret._translation_offset = np.add(target_offset, template_offset)
261
-
262
- ret.template_filter = self.template_filter
263
- ret.target_filter = self.target_filter
264
- ret._rotations, ret.indices = self.rotations, indices
265
- ret._target_pad, ret._template_pad = target_pad, template_pad
266
- ret._invert_target = self._invert_target
267
-
268
235
  if self._target_mask is not None:
269
- ret.target_mask = self.subset_array(
236
+ target_mask = self.subset_array(
270
237
  arr=self._target_mask, arr_slice=target_slice, padding=target_pad
271
238
  )
272
239
  if self._template_mask is not None:
273
- ret.template_mask = self.subset_array(
274
- arr=self._template_mask,
275
- arr_slice=template_slice,
276
- padding=template_pad,
240
+ template_mask = self.subset_array(
241
+ arr=self._template_mask, arr_slice=template_slice, padding=template_pad
277
242
  )
278
243
 
279
- target_dims, template_dims = None, None
280
- if hasattr(self, "_target_dims"):
281
- target_dims = self._target_dims
244
+ ret = self.__class__(
245
+ target=target_subset,
246
+ template=template_subset,
247
+ template_mask=template_mask,
248
+ target_mask=target_mask,
249
+ rotations=self.rotations,
250
+ invert_target=self._invert_target,
251
+ )
252
+
253
+ # Deal with splitting offsets
254
+ target_offset = np.zeros(len(self._output_target_shape), dtype=int)
255
+ offset = target_offset.size - len(target_slice)
256
+ target_offset[offset:] = [x.start for x in target_slice]
257
+ template_offset = np.zeros(len(self._output_target_shape), dtype=int)
258
+ offset = template_offset.size - len(template_slice)
259
+ template_offset[offset:] = [x.start for x in template_slice]
260
+ ret._translation_offset = target_offset
261
+ if len(self._target.shape) == len(self._template.shape):
262
+ ret.indices = compute_full_convolution_index(
263
+ outer_shape=self._target.shape,
264
+ inner_shape=self._template.shape,
265
+ outer_split=target_slice,
266
+ inner_split=template_slice,
267
+ )
282
268
 
283
- if hasattr(self, "_template_dims"):
284
- template_dims = self._template_dims
269
+ ret._is_padded = be.sum(be.to_backend_array(target_pad)) > 0
270
+ ret.target_filter = self.target_filter
271
+ ret.template_filter = self.template_filter
285
272
 
286
- ret._set_batch_dimension(target_dims=target_dims, template_dims=template_dims)
273
+ ret._set_matching_dimension(
274
+ target_dims=getattr(self, "_target_dims", None),
275
+ template_dims=getattr(self, "_template_dims", None),
276
+ )
287
277
 
288
278
  return ret
289
279
 
290
280
  def to_backend(self) -> None:
291
281
  """
292
- Transfer the class instance's numpy arrays to the current backend.
282
+ Transfer and convert types of class instance's data arrays to the current backend
293
283
  """
294
- backend_arr = type(backend.zeros((1), dtype=backend._default_dtype))
284
+ backend_arr = type(be.zeros((1), dtype=be._float_dtype))
295
285
  for attr_name, attr_value in vars(self).items():
286
+ converted_array = None
296
287
  if isinstance(attr_value, np.ndarray):
297
- converted_array = backend.to_backend_array(attr_value.copy())
298
- setattr(self, attr_name, converted_array)
288
+ converted_array = be.to_backend_array(attr_value.copy())
299
289
  elif isinstance(attr_value, backend_arr):
300
- converted_array = backend.to_backend_array(attr_value)
301
- setattr(self, attr_name, converted_array)
290
+ converted_array = be.to_backend_array(attr_value)
291
+ else:
292
+ continue
293
+
294
+ current_dtype = be.get_fundamental_dtype(converted_array)
295
+ target_dtype = be._fundamental_dtypes[current_dtype]
302
296
 
303
- self._default_dtype = backend._default_dtype
304
- self._complex_dtype = backend._complex_dtype
297
+ # Optional, but scores are float so we avoid casting and potential issues
298
+ if attr_name in ("_template", "_template_mask", "_target", "_target_mask"):
299
+ target_dtype = be._float_dtype
305
300
 
306
- def _set_batch_dimension(
301
+ if target_dtype != current_dtype:
302
+ converted_array = be.astype(converted_array, target_dtype)
303
+
304
+ setattr(self, attr_name, converted_array)
305
+
306
+ def _set_matching_dimension(
307
307
  self, target_dims: Tuple[int] = None, template_dims: Tuple[int] = None
308
308
  ) -> None:
309
309
  """
310
- Sets the shapes of target and template for template matching considering
311
- their corresponding batch dimensions.
312
-
310
+ Sets matching dimensions for target and template.
313
311
  Parameters
314
312
  ----------
315
- target_dims : Tuple[int], optional
316
- A tuple of integers specifying the batch dimensions of the target. If None,
317
- the target is assumed not to have batch dimensions.
318
- template_dims : Tuple[int], optional
319
- A tuple of integers specifying the batch dimensions of the template. If None,
320
- the template is assumed not to have batch dimensions.
313
+ target_dims : tuple of ints, optional
314
+ Target batch dimensions, None by default.
315
+ template_dims : tuple of ints, optional
316
+ Template batch dimensions, None by default.
321
317
 
322
318
  Notes
323
319
  -----
324
-
325
320
  If the target and template share a batch dimension, the target will
326
321
  take precedence and the template dimension will be shifted to the right.
327
322
  If target and template have the same dimension, but target specifies batch
@@ -349,13 +344,9 @@ class MatchingData:
349
344
 
350
345
  matching_dims = target_measurement_dims + batch_dims
351
346
 
352
- target_shape = backend.full(
353
- shape=(matching_dims,), fill_value=1, dtype=backend._default_dtype_int
354
- )
355
- template_shape = backend.full(
356
- shape=(matching_dims,), fill_value=1, dtype=backend._default_dtype_int
357
- )
358
- batch_mask = backend.full(shape=(matching_dims,), fill_value=1, dtype=bool)
347
+ target_shape = np.full(shape=matching_dims, fill_value=1, dtype=int)
348
+ template_shape = np.full(shape=matching_dims, fill_value=1, dtype=int)
349
+ batch_mask = np.full(shape=matching_dims, fill_value=1, dtype=int)
359
350
 
360
351
  target_index, template_index = 0, 0
361
352
  for k in range(matching_dims):
@@ -381,9 +372,9 @@ class MatchingData:
381
372
  if template_dim < template_ndim:
382
373
  template_shape[k] = self._template.shape[template_dim]
383
374
 
384
- self._output_target_shape = target_shape
385
- self._output_template_shape = template_shape
386
- self._batch_mask = batch_mask
375
+ self._output_target_shape = tuple(int(x) for x in target_shape)
376
+ self._output_template_shape = tuple(int(x) for x in template_shape)
377
+ self._batch_mask = tuple(int(x) for x in batch_mask)
387
378
 
388
379
  @staticmethod
389
380
  def _compute_batch_dimension(
@@ -394,22 +385,22 @@ class MatchingData:
394
385
 
395
386
  Parameters
396
387
  ----------
397
- batch_dims : Tuple[int]
388
+ batch_dims : tuple of ints
398
389
  A tuple of integers representing the batch dimensions.
399
390
  ndim : int
400
391
  The number of dimensions of the array.
401
392
 
402
393
  Returns
403
394
  -------
404
- Tuple[ArrayLike, Tuple]
405
- A tuple containing the mask (as an ArrayLike) and the validated batch dimensions.
395
+ Tuple[ArrayLike, tuple of ints]
396
+ Mask and the corresponding batch dimensions.
406
397
 
407
398
  Raises
408
399
  ------
409
400
  ValueError
410
401
  If any dimension in batch_dims is not less than ndim.
411
402
  """
412
- mask = backend.zeros(ndim, dtype=bool)
403
+ mask = np.zeros(ndim, dtype=int)
413
404
  if batch_dims is None:
414
405
  return mask, ()
415
406
 
@@ -424,228 +415,292 @@ class MatchingData:
424
415
 
425
416
  return mask, batch_dims
426
417
 
427
- def target_padding(self, pad_target: bool = False) -> ArrayLike:
418
+ def target_padding(self, pad_target: bool = False) -> Tuple[int]:
428
419
  """
429
420
  Computes padding for the target based on the template's shape.
430
421
 
431
422
  Parameters
432
423
  ----------
433
424
  pad_target : bool, default False
434
- If True, computes the padding required for the target. If False,
435
- an array of zeros is returned.
425
+ Whether to pad the target, default returns an array of zeros.
436
426
 
437
427
  Returns
438
428
  -------
439
- ArrayLike
440
- An array indicating the padding for each dimension of the target.
429
+ tuple of ints
430
+ Padding along each dimension of the target.
441
431
  """
442
- target_padding = backend.zeros(
443
- len(self._output_target_shape), dtype=backend._default_dtype_int
444
- )
445
-
432
+ target_padding = np.zeros(len(self._output_target_shape), dtype=int)
446
433
  if pad_target:
447
- backend.subtract(
434
+ target_padding = np.subtract(
448
435
  self._output_template_shape,
449
- backend.mod(self._output_template_shape, 2),
450
- out=target_padding,
436
+ np.mod(self._output_template_shape, 2),
451
437
  )
452
438
  if hasattr(self, "_is_target_batch"):
453
- target_padding[self._is_target_batch] = 0
439
+ target_padding = np.multiply(
440
+ target_padding,
441
+ np.subtract(1, self._is_target_batch),
442
+ )
454
443
 
455
- return target_padding
444
+ return tuple(int(x) for x in target_padding)
456
445
 
457
- def fourier_padding(
458
- self, pad_fourier: bool = False
459
- ) -> Tuple[ArrayLike, ArrayLike, ArrayLike]:
446
+ @staticmethod
447
+ def _fourier_padding(
448
+ target_shape: NDArray,
449
+ template_shape: NDArray,
450
+ batch_mask: NDArray = None,
451
+ pad_fourier: bool = False,
452
+ ) -> Tuple[Tuple, Tuple, Tuple]:
460
453
  """
461
- Computes an efficient shape for the forward Fourier transform, the
462
- corresponding shape of the real-valued FFT, and the associated
463
- translation shift.
464
-
465
- Parameters
466
- ----------
467
- pad_fourier : bool, default False
468
- If true, returns the shape of the full-convolution defined as sum of target
469
- shape and template shape minus one. By default, returns unpadded transform.
470
-
471
- Returns
472
- -------
473
- Tuple[ArrayLike, ArrayLike, ArrayLike]
474
- A tuple containing the calculated fast shape, fast Fourier transform shape,
475
- and the Fourier shift values, respectively.
454
+ Determines an efficient shape for Fourier transforms considering zero-padding.
476
455
  """
477
- template_shape = self._template.shape
478
- if hasattr(self, "_output_template_shape"):
479
- template_shape = self._output_template_shape
480
- template_shape = backend.to_backend_array(template_shape)
456
+ fourier_pad = template_shape
457
+ fourier_shift = np.zeros_like(template_shape)
481
458
 
482
- target_shape = self._target.shape
483
- if hasattr(self, "_output_target_shape"):
484
- target_shape = self._output_target_shape
485
- target_shape = backend.to_backend_array(target_shape)
486
-
487
- fourier_pad = backend.to_backend_array(template_shape)
488
- fourier_shift = backend.zeros(len(fourier_pad))
459
+ if batch_mask is None:
460
+ batch_mask = np.zeros_like(template_shape)
461
+ batch_mask = np.asarray(batch_mask)
489
462
 
490
463
  if not pad_fourier:
491
- fourier_pad = backend.full(
492
- shape=(len(fourier_pad),),
493
- fill_value=1,
494
- dtype=backend._default_dtype_int,
495
- )
464
+ fourier_pad = np.ones(len(fourier_pad), dtype=int)
465
+ fourier_pad = np.multiply(fourier_pad, 1 - batch_mask)
466
+ fourier_pad = np.add(fourier_pad, batch_mask)
496
467
 
497
- fourier_pad = backend.to_backend_array(fourier_pad)
498
- if hasattr(self, "_batch_mask"):
499
- batch_mask = backend.to_backend_array(self._batch_mask)
500
- fourier_pad[batch_mask] = 1
501
-
502
- pad_shape = backend.maximum(target_shape, template_shape)
503
- ret = backend.compute_convolution_shapes(pad_shape, fourier_pad)
468
+ pad_shape = np.maximum(target_shape, template_shape)
469
+ ret = be.compute_convolution_shapes(pad_shape, fourier_pad)
504
470
  convolution_shape, fast_shape, fast_ft_shape = ret
505
471
  if not pad_fourier:
506
- fourier_shift = 1 - backend.astype(backend.divide(template_shape, 2), int)
507
- fourier_shift -= backend.mod(template_shape, 2)
508
- shape_diff = backend.subtract(fast_shape, convolution_shape)
509
- shape_diff = backend.astype(backend.divide(shape_diff, 2), int)
472
+ fourier_shift = 1 - np.divide(template_shape, 2).astype(int)
473
+ fourier_shift -= np.mod(template_shape, 2)
474
+ shape_diff = np.subtract(fast_shape, convolution_shape)
475
+ shape_diff = np.divide(shape_diff, 2).astype(int)
476
+ shape_diff = np.multiply(shape_diff, 1 - batch_mask)
477
+ np.add(fourier_shift, shape_diff, out=fourier_shift)
478
+
479
+ fourier_shift = fourier_shift.astype(int)
480
+
481
+ shape_diff = np.subtract(target_shape, template_shape)
482
+ shape_diff = np.multiply(shape_diff, 1 - batch_mask)
483
+ if np.sum(shape_diff < 0) and not pad_fourier:
484
+ warnings.warn(
485
+ "Target is larger than template and Fourier padding is turned off. "
486
+ "This may lead to inaccurate results. Prefer swapping template and target, "
487
+ "enable padding or turn off template centering."
488
+ )
489
+ fourier_shift = np.subtract(fourier_shift, np.divide(shape_diff, 2))
490
+ fourier_shift = fourier_shift.astype(int)
510
491
 
511
- if hasattr(self, "_batch_mask"):
512
- batch_mask = backend.to_backend_array(self._batch_mask)
513
- shape_diff[batch_mask] = 0
492
+ return tuple(fast_shape), tuple(fast_ft_shape), tuple(fourier_shift)
514
493
 
515
- backend.add(fourier_shift, shape_diff, out=fourier_shift)
494
+ def fourier_padding(self, pad_fourier: bool = False) -> Tuple[Tuple, Tuple, Tuple]:
495
+ """
496
+ Computes efficient shapes for Fourier transforms and potential associated shifts.
497
+
498
+ Parameters
499
+ ----------
500
+ pad_fourier : bool, default False
501
+ If true, returns the shape of the full-convolution defined as sum of target
502
+ shape and template shape minus one, False by default.
516
503
 
517
- return fast_shape, fast_ft_shape, fourier_shift
504
+ Returns
505
+ -------
506
+ Tuple[tuple of int, tuple of int, tuple of int]
507
+ Tuple with real and complex Fourier transform shape, and corresponding shift.
508
+ """
509
+ return self._fourier_padding(
510
+ target_shape=be.to_numpy_array(self._output_target_shape),
511
+ template_shape=be.to_numpy_array(self._output_template_shape),
512
+ batch_mask=be.to_numpy_array(self._batch_mask),
513
+ pad_fourier=pad_fourier,
514
+ )
518
515
 
519
516
  @property
520
517
  def rotations(self):
521
- """Return stored rotation matrices.."""
518
+ """Return stored rotation matrices."""
522
519
  return self._rotations
523
520
 
524
521
  @rotations.setter
525
522
  def rotations(self, rotations: NDArray):
526
523
  """
527
- Set and reshape the rotation matrices for template matching.
524
+ Set :py:attr:`MatchingData.rotations`.
528
525
 
529
526
  Parameters
530
527
  ----------
531
528
  rotations : NDArray
532
- Rotations in shape (k x k), or (n x k x k).
529
+ Rotation matrices with shape (d, d) or (n, d, d).
533
530
  """
534
- if rotations.__class__ != np.ndarray:
535
- raise ValueError("Rotation set has to be of type numpy ndarray.")
536
- if rotations.ndim == 2:
531
+ if rotations is None:
532
+ print("No rotations provided, assuming identity for now.")
533
+ rotations = np.eye(len(self._target.shape))
534
+
535
+ if rotations.ndim not in (2, 3):
536
+ raise ValueError("Rotations have to be a rank 2 or 3 array.")
537
+ elif rotations.ndim == 2:
537
538
  print("Reshaping rotations array to rank 3.")
538
539
  rotations = rotations.reshape(1, *rotations.shape)
539
- elif rotations.ndim == 3:
540
- pass
541
- else:
542
- raise ValueError("Rotations have to be a rank 2 or 3 array.")
543
- self._rotations = rotations.astype(self._default_dtype)
540
+ self._rotations = rotations.astype(np.float32)
541
+
542
+ @staticmethod
543
+ def _get_data(attribute, output_shape: Tuple[int], reverse: bool = False):
544
+ if isinstance(attribute, Density):
545
+ attribute = attribute.data
546
+
547
+ if attribute is not None:
548
+ if reverse:
549
+ attribute = be.reverse(attribute)
550
+ attribute = attribute.reshape(tuple(int(x) for x in output_shape))
551
+
552
+ return attribute
544
553
 
545
554
  @property
546
555
  def target(self):
547
- """Returns the target NDArray."""
548
- if isinstance(self._target, Density):
549
- target = self._target.data
550
- else:
551
- target = self._target
552
- out_shape = backend.to_numpy_array(self._output_target_shape)
553
- return target.reshape(tuple(int(x) for x in out_shape))
556
+ """
557
+ Return the target.
558
+
559
+ Returns
560
+ -------
561
+ NDArray
562
+ Output data.
563
+ """
564
+ return self._get_data(self._target, self._output_target_shape, False)
565
+
566
+ @property
567
+ def target_mask(self):
568
+ """
569
+ Return the target mask.
570
+
571
+ Returns
572
+ -------
573
+ NDArray
574
+ Output data.
575
+ """
576
+ target_mask = getattr(self, "_target_mask", None)
577
+ return self._get_data(target_mask, self._output_target_shape, False)
554
578
 
555
579
  @property
556
580
  def template(self):
557
- """Returns the reversed template NDArray."""
558
- template = self._template
559
- if isinstance(self._template, Density):
560
- template = self._template.data
561
- template = backend.reverse(template)
562
- out_shape = backend.to_numpy_array(self._output_template_shape)
563
- return template.reshape(tuple(int(x) for x in out_shape))
581
+ """
582
+ Return the reversed template.
583
+
584
+ Returns
585
+ -------
586
+ NDArray
587
+ Output data.
588
+ """
589
+ return self._get_data(self._template, self._output_template_shape, True)
590
+
591
+ @property
592
+ def template_mask(self):
593
+ """
594
+ Return the reversed template mask.
595
+
596
+ Returns
597
+ -------
598
+ NDArray
599
+ Output data.
600
+ """
601
+ template_mask = getattr(self, "_template_mask", None)
602
+ return self._get_data(template_mask, self._output_template_shape, True)
603
+
604
+ @target.setter
605
+ def target(self, arr: NDArray):
606
+ """
607
+ Set :py:attr:`MatchingData.target`.
608
+
609
+ Parameters
610
+ ----------
611
+ arr : NDArray
612
+ Array to set as the target.
613
+ """
614
+ self._target = arr
564
615
 
565
616
  @template.setter
566
- def template(self, template: NDArray):
617
+ def template(self, arr: NDArray):
567
618
  """
568
- Set the template array. If not already defined, also initializes
569
- :py:attr:`MatchingData.template_mask` to an uninformative mask filled with
570
- ones.
619
+ Set :py:attr:`MatchingData.template` and initializes
620
+ :py:attr:`MatchingData.template_mask` to an uninformative
621
+ mask filled with ones if not already defined.
571
622
 
572
623
  Parameters
573
624
  ----------
574
- template : NDArray
625
+ arr : NDArray
575
626
  Array to set as the template.
576
627
  """
577
- self._templateshape = template.shape[::-1]
578
- if self._template_mask is None:
579
- self._template_mask = backend.full(
580
- shape=template.shape, dtype=float, fill_value=1
628
+ self._template = arr
629
+ if getattr(self, "_template_mask", None) is None:
630
+ self._template_mask = be.full(
631
+ shape=arr.shape, dtype=be._float_dtype, fill_value=1
581
632
  )
582
633
 
583
- if type(template) == Density:
584
- template = template.data
585
- self._template = template.astype(self._default_dtype, copy=False)
586
-
587
- @property
588
- def target_mask(self):
589
- """Returns the target mask NDArray."""
590
- target_mask = self._target_mask
591
- if isinstance(self._target_mask, Density):
592
- target_mask = self._target_mask.data
634
+ @staticmethod
635
+ def _set_mask(mask, shape: Tuple[int]):
636
+ if mask is not None:
637
+ if mask.shape != shape:
638
+ raise ValueError(
639
+ "Mask and respective data have to have the same shape."
640
+ )
641
+ return mask
593
642
 
594
- if target_mask is not None:
595
- out_shape = backend.to_numpy_array(self._output_target_shape)
596
- target_mask = target_mask.reshape(tuple(int(x) for x in out_shape))
643
+ @target_mask.setter
644
+ def target_mask(self, arr: NDArray):
645
+ """
646
+ Set :py:attr:`MatchingData.target_mask`.
597
647
 
598
- return target_mask
648
+ Parameters
649
+ ----------
650
+ arr : NDArray
651
+ Array to set as the target_mask.
652
+ """
653
+ self._target_mask = self._set_mask(mask=arr, shape=self._target.shape)
599
654
 
600
- @target_mask.setter
601
- def target_mask(self, mask: NDArray):
602
- """Sets the target mask."""
603
- if not np.all(self.target.shape == mask.shape):
604
- raise ValueError("Target and its mask have to have the same shape.")
655
+ @template_mask.setter
656
+ def template_mask(self, arr: NDArray):
657
+ """
658
+ Set :py:attr:`MatchingData.template_mask`.
605
659
 
606
- if type(mask) == Density:
607
- mask.data = mask.data.astype(self._default_dtype, copy=False)
608
- self._target_mask = mask
609
- self._targetmaskshape = self._target_mask.shape[::-1]
610
- return None
660
+ Parameters
661
+ ----------
662
+ arr : NDArray
663
+ Array to set as the template_mask.
664
+ """
665
+ self._template_mask = self._set_mask(mask=arr, shape=self._template.shape)
611
666
 
612
- self._target_mask = mask.astype(self._default_dtype, copy=False)
613
- self._targetmaskshape = self._target_mask.shape
667
+ @staticmethod
668
+ def _set_filter(composable_filter) -> Optional[Compose]:
669
+ if isinstance(composable_filter, Compose):
670
+ return composable_filter
671
+ return None
614
672
 
615
673
  @property
616
- def template_mask(self):
674
+ def template_filter(self) -> Optional[Compose]:
617
675
  """
618
- Set the template mask array after reversing it.
676
+ Returns the composable template filter.
619
677
 
620
- Parameters
621
- ----------
622
- template : NDArray
623
- Array to set as the template.
678
+ Returns
679
+ -------
680
+ :py:class:`tme.preprocessing.Compose` | None
681
+ Composable template filter or None.
624
682
  """
625
- mask = self._template_mask
626
- if isinstance(self._template_mask, Density):
627
- mask = self._template_mask.data
683
+ return getattr(self, "_template_filter", None)
628
684
 
629
- if mask is not None:
630
- mask = backend.reverse(mask)
631
- out_shape = backend.to_numpy_array(self._output_template_shape)
632
- mask = mask.reshape(tuple(int(x) for x in out_shape))
633
- return mask
685
+ @property
686
+ def target_filter(self) -> Optional[Compose]:
687
+ """
688
+ Returns the composable target filter.
634
689
 
635
- @template_mask.setter
636
- def template_mask(self, mask: NDArray):
637
- """Returns the reversed template mask NDArray."""
638
- if not np.all(self._templateshape[::-1] == mask.shape):
639
- raise ValueError("Target and its mask have to have the same shape.")
690
+ Returns
691
+ -------
692
+ :py:class:`tme.preprocessing.Compose` | None
693
+ Composable filter or None.
694
+ """
695
+ return getattr(self, "_target_filter", None)
640
696
 
641
- if type(mask) == Density:
642
- mask.data = mask.data.astype(self._default_dtype, copy=False)
643
- self._template_mask = mask
644
- self._templatemaskshape = self._template_mask.shape[::-1]
645
- return None
697
+ @template_filter.setter
698
+ def template_filter(self, composable_filter: Compose):
699
+ self._template_filter = self._set_filter(composable_filter)
646
700
 
647
- self._template_mask = mask.astype(self._default_dtype, copy=False)
648
- self._templatemaskshape = self._template_mask.shape[::-1]
701
+ @target_filter.setter
702
+ def target_filter(self, composable_filter: Compose):
703
+ self._target_filter = self._set_filter(composable_filter)
649
704
 
650
705
  def _split_rotations_on_jobs(self, n_jobs: int) -> List[NDArray]:
651
706
  """
@@ -670,3 +725,11 @@ class MatchingData:
670
725
  end_rot = None
671
726
  rot_list.append(self.rotations[init_rot:end_rot])
672
727
  return rot_list
728
+
729
+ def _free_data(self):
730
+ """
731
+ Free (dereference) data arrays owned by the class instance.
732
+ """
733
+ attrs = ("_target", "_template", "_template_mask", "_target_mask")
734
+ for attr in attrs:
735
+ setattr(self, attr, None)