pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.3__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/match_template.py +219 -216
  2. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/postprocess.py +86 -54
  3. pytme-0.2.3.data/scripts/preprocess.py +132 -0
  4. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +181 -94
  5. pytme-0.2.3.dist-info/METADATA +92 -0
  6. pytme-0.2.3.dist-info/RECORD +75 -0
  7. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/WHEEL +1 -1
  8. pytme-0.2.1.data/scripts/preprocess.py → scripts/eval.py +1 -1
  9. scripts/extract_candidates.py +20 -13
  10. scripts/match_template.py +219 -216
  11. scripts/match_template_filters.py +154 -95
  12. scripts/postprocess.py +86 -54
  13. scripts/preprocess.py +95 -56
  14. scripts/preprocessor_gui.py +181 -94
  15. scripts/refine_matches.py +265 -61
  16. tme/__init__.py +0 -1
  17. tme/__version__.py +1 -1
  18. tme/analyzer.py +458 -813
  19. tme/backends/__init__.py +40 -11
  20. tme/backends/_jax_utils.py +187 -0
  21. tme/backends/cupy_backend.py +109 -226
  22. tme/backends/jax_backend.py +230 -152
  23. tme/backends/matching_backend.py +445 -384
  24. tme/backends/mlx_backend.py +32 -59
  25. tme/backends/npfftw_backend.py +240 -507
  26. tme/backends/pytorch_backend.py +30 -151
  27. tme/density.py +248 -371
  28. tme/extensions.cpython-311-darwin.so +0 -0
  29. tme/matching_data.py +328 -284
  30. tme/matching_exhaustive.py +195 -1499
  31. tme/matching_optimization.py +143 -106
  32. tme/matching_scores.py +887 -0
  33. tme/matching_utils.py +287 -388
  34. tme/memory.py +377 -0
  35. tme/orientations.py +78 -21
  36. tme/parser.py +3 -4
  37. tme/preprocessing/_utils.py +61 -32
  38. tme/preprocessing/composable_filter.py +7 -4
  39. tme/preprocessing/compose.py +7 -3
  40. tme/preprocessing/frequency_filters.py +49 -39
  41. tme/preprocessing/tilt_series.py +44 -72
  42. tme/preprocessor.py +560 -526
  43. tme/structure.py +491 -188
  44. tme/types.py +5 -3
  45. pytme-0.2.1.dist-info/METADATA +0 -73
  46. pytme-0.2.1.dist-info/RECORD +0 -73
  47. tme/helpers.py +0 -881
  48. tme/matching_constrained.py +0 -195
  49. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
  50. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
tme/matching_data.py CHANGED
@@ -1,18 +1,19 @@
1
- """ Data class for holding template matching data.
1
+ """ Class representation of template matching data.
2
2
 
3
3
  Copyright (c) 2023 European Molecular Biology Laboratory
4
4
 
5
5
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
6
  """
7
7
  import warnings
8
- from typing import Tuple, List
8
+ from typing import Tuple, List, Optional
9
9
 
10
10
  import numpy as np
11
11
  from numpy.typing import NDArray
12
12
 
13
13
  from . import Density
14
14
  from .types import ArrayLike
15
- from .backends import backend
15
+ from .preprocessing import Compose
16
+ from .backends import backend as be
16
17
  from .matching_utils import compute_full_convolution_index
17
18
 
18
19
 
@@ -31,9 +32,9 @@ class MatchingData:
31
32
  template_mask : np.ndarray or :py:class:`tme.density.Density`, optional
32
33
  Template mask data.
33
34
  invert_target : bool, optional
34
- Whether to invert and rescale the target before template matching..
35
+ Whether to invert the target before template matching..
35
36
  rotations: np.ndarray, optional
36
- Template rotations to sample. Can be a single (d x d) or a stack (n x d x d)
37
+ Template rotations to sample. Can be a single (d, d) or a stack (n, d, d)
37
38
  of rotation matrices where d is the dimension of the template.
38
39
 
39
40
  Examples
@@ -57,26 +58,18 @@ class MatchingData:
57
58
  invert_target: bool = False,
58
59
  rotations: NDArray = None,
59
60
  ):
60
- self._target = target
61
- self._target_mask = target_mask
62
- self._template_mask = template_mask
63
- self._translation_offset = np.zeros(len(target.shape), dtype=int)
61
+ self.target = target
62
+ self.target_mask = target_mask
64
63
 
65
64
  self.template = template
65
+ if template_mask is not None:
66
+ self.template_mask = template_mask
66
67
 
67
- self._target_pad = np.zeros(len(target.shape), dtype=int)
68
- self._template_pad = np.zeros(len(template.shape), dtype=int)
69
-
70
- self.template_filter = {}
71
- self.target_filter = {}
72
-
68
+ self.rotations = rotations
69
+ self._translation_offset = np.zeros(len(target.shape), dtype=int)
73
70
  self._invert_target = invert_target
74
71
 
75
- self._rotations = rotations
76
- if rotations is not None:
77
- self.rotations = rotations
78
-
79
- self._set_batch_dimension()
72
+ self._set_matching_dimension()
80
73
 
81
74
  @staticmethod
82
75
  def _shape_to_slice(shape: Tuple[int]):
@@ -105,8 +98,7 @@ class MatchingData:
105
98
  NDArray
106
99
  Loaded array.
107
100
  """
108
-
109
- if type(arr) == np.memmap:
101
+ if isinstance(arr, np.memmap):
110
102
  return np.memmap(arr.filename, mode="r", shape=arr.shape, dtype=arr.dtype)
111
103
  return arr
112
104
 
@@ -141,7 +133,7 @@ class MatchingData:
141
133
  NDArray
142
134
  Subset of the input array with padding applied.
143
135
  """
144
- padding = backend.to_numpy_array(padding)
136
+ padding = be.to_numpy_array(padding)
145
137
  padding = np.maximum(padding, 0).astype(int)
146
138
 
147
139
  slice_start = np.array([x.start for x in arr_slice], dtype=int)
@@ -160,38 +152,18 @@ class MatchingData:
160
152
  arr_slice = tuple(slice(*pos) for pos in zip(arr_start, arr_stop))
161
153
  arr_mesh = self._slice_to_mesh(arr_slice, arr.shape)
162
154
 
163
- arr_min, arr_max = None, None
164
- if type(arr) == Density:
165
- if type(arr.data) == np.memmap:
166
- dens = Density.from_file(arr.data.filename, subset=arr_slice)
167
- arr = dens.data
168
- arr_min = dens.metadata.get("min", None)
169
- arr_max = dens.metadata.get("max", None)
155
+ if isinstance(arr, Density):
156
+ if isinstance(arr.data, np.memmap):
157
+ arr = Density.from_file(arr.data.filename, subset=arr_slice).data
170
158
  else:
171
159
  arr = np.asarray(arr.data[*arr_mesh])
172
160
  else:
173
- if type(arr) == np.memmap:
161
+ if isinstance(arr, np.memmap):
174
162
  arr = np.memmap(
175
163
  arr.filename, mode="r", shape=arr.shape, dtype=arr.dtype
176
164
  )
177
165
  arr = np.asarray(arr[*arr_mesh])
178
166
 
179
- def _warn_on_mismatch(
180
- expectation: float, computation: float, name: str
181
- ) -> float:
182
- if expectation is None:
183
- expectation = computation
184
- expectation, computation = float(expectation), float(computation)
185
-
186
- if abs(computation) > abs(expectation):
187
- warnings.warn(
188
- f"Computed {name} value is more extreme than value in file header"
189
- f" (|{computation}| > |{expectation}|). This may lead to issues"
190
- " with padding and contrast inversion."
191
- )
192
-
193
- return expectation
194
-
195
167
  padding = tuple(
196
168
  (left, right)
197
169
  for left, right in zip(
@@ -199,17 +171,12 @@ class MatchingData:
199
171
  np.subtract(right_pad, data_voxels_right),
200
172
  )
201
173
  )
202
- ret = np.pad(arr, padding, mode="reflect")
174
+ # The reflections are later cropped from the scores
175
+ arr = np.pad(arr, padding, mode="reflect")
203
176
 
204
177
  if invert:
205
- arr_min = _warn_on_mismatch(arr_min, arr.min(), "min")
206
- arr_max = _warn_on_mismatch(arr_max, arr.max(), "max")
207
-
208
- # Avoid in-place operation in case ret is not floating point
209
- ret = (
210
- -np.divide(np.subtract(ret, arr_min), np.subtract(arr_max, arr_min)) + 1
211
- )
212
- return ret
178
+ arr = -arr
179
+ return arr
213
180
 
214
181
  def subset_by_slice(
215
182
  self,
@@ -220,97 +187,94 @@ class MatchingData:
220
187
  invert_target: bool = False,
221
188
  ) -> "MatchingData":
222
189
  """
223
- Slice the instance arrays based on the provided slices.
190
+ Subset class instance based on slices.
224
191
 
225
192
  Parameters
226
193
  ----------
227
194
  target_slice : tuple of slice, optional
228
- Slices for the target. If not provided, the full shape is used.
195
+ Target subset to use, all by default.
229
196
  template_slice : tuple of slice, optional
230
- Slices for the template. If not provided, the full shape is used.
197
+ Template subset to use, all by default.
231
198
  target_pad : NDArray, optional
232
- Padding for target. Defaults to zeros. If padding exceeds target,
233
- pad with mean.
199
+ Target padding, zero by default.
234
200
  template_pad : NDArray, optional
235
- Padding for template. Defaults to zeros. If padding exceeds template,
236
- pad with mean.
201
+ Template padding, zero by default.
237
202
 
238
203
  Returns
239
204
  -------
240
- MatchingData
241
- Newly allocated sliced class instance.
205
+ :py:class:`MatchingData`
206
+ Newly allocated subset of class instance.
207
+
208
+ Examples
209
+ --------
210
+ >>> import numpy as np
211
+ >>> from tme.matching_data import MatchingData
212
+ >>> target = np.random.rand(50,40,60)
213
+ >>> template = target[15:25, 10:20, 30:40]
214
+ >>> matching_data = MatchingData(target=target, template=template)
215
+ >>> subset = matching_data.subset_by_slice(
216
+ >>> target_slice=(slice(0, 10), slice(10,20), slice(15,35))
217
+ >>> )
242
218
  """
243
- target_shape = self._target.shape
244
- template_shape = self._template.shape
245
-
246
219
  if target_slice is None:
247
- target_slice = self._shape_to_slice(target_shape)
220
+ target_slice = self._shape_to_slice(self._target.shape)
248
221
  if template_slice is None:
249
- template_slice = self._shape_to_slice(template_shape)
222
+ template_slice = self._shape_to_slice(self._template.shape)
250
223
 
251
224
  if target_pad is None:
252
225
  target_pad = np.zeros(len(self._target.shape), dtype=int)
253
226
  if template_pad is None:
254
227
  template_pad = np.zeros(len(self._template.shape), dtype=int)
255
228
 
256
- indices = None
257
- if len(self._target.shape) == len(self._template.shape):
258
- indices = compute_full_convolution_index(
259
- outer_shape=self._target.shape,
260
- inner_shape=self._template.shape,
261
- outer_split=target_slice,
262
- inner_split=template_slice,
263
- )
264
-
229
+ target_mask, template_mask = None, None
265
230
  target_subset = self.subset_array(
266
- arr=self._target,
267
- arr_slice=target_slice,
268
- padding=target_pad,
269
- invert=self._invert_target,
231
+ self._target, target_slice, target_pad, invert=self._invert_target
270
232
  )
271
-
272
233
  template_subset = self.subset_array(
273
- arr=self._template,
274
- arr_slice=template_slice,
275
- padding=template_pad,
234
+ arr=self._template, arr_slice=template_slice, padding=template_pad
276
235
  )
277
- ret = self.__class__(target=target_subset, template=template_subset)
278
-
279
- target_offset = np.zeros(len(self._output_target_shape), dtype=int)
280
- target_offset[(target_offset.size - len(target_slice)) :] = [
281
- x.start for x in target_slice
282
- ]
283
- template_offset = np.zeros(len(self._output_target_shape), dtype=int)
284
- template_offset[(template_offset.size - len(template_slice)) :] = [
285
- x.start for x in template_slice
286
- ]
287
- ret._translation_offset = target_offset
288
-
289
- ret.template_filter = self.template_filter
290
- ret.target_filter = self.target_filter
291
- ret._rotations, ret.indices = self.rotations, indices
292
- ret._target_pad, ret._template_pad = target_pad, template_pad
293
- ret._invert_target = self._invert_target
294
-
295
236
  if self._target_mask is not None:
296
- ret._target_mask = self.subset_array(
237
+ target_mask = self.subset_array(
297
238
  arr=self._target_mask, arr_slice=target_slice, padding=target_pad
298
239
  )
299
240
  if self._template_mask is not None:
300
- ret.template_mask = self.subset_array(
301
- arr=self._template_mask,
302
- arr_slice=template_slice,
303
- padding=template_pad,
241
+ template_mask = self.subset_array(
242
+ arr=self._template_mask, arr_slice=template_slice, padding=template_pad
304
243
  )
305
244
 
306
- target_dims, template_dims = None, None
307
- if hasattr(self, "_target_dims"):
308
- target_dims = self._target_dims
245
+ ret = self.__class__(
246
+ target=target_subset,
247
+ template=template_subset,
248
+ template_mask=template_mask,
249
+ target_mask=target_mask,
250
+ rotations=self.rotations,
251
+ invert_target=self._invert_target,
252
+ )
253
+
254
+ # Deal with splitting offsets
255
+ target_offset = np.zeros(len(self._output_target_shape), dtype=int)
256
+ offset = target_offset.size - len(target_slice)
257
+ target_offset[offset:] = [x.start for x in target_slice]
258
+ template_offset = np.zeros(len(self._output_target_shape), dtype=int)
259
+ offset = template_offset.size - len(template_slice)
260
+ template_offset[offset:] = [x.start for x in template_slice]
261
+ ret._translation_offset = target_offset
262
+ if len(self._target.shape) == len(self._template.shape):
263
+ ret.indices = compute_full_convolution_index(
264
+ outer_shape=self._target.shape,
265
+ inner_shape=self._template.shape,
266
+ outer_split=target_slice,
267
+ inner_split=template_slice,
268
+ )
309
269
 
310
- if hasattr(self, "_template_dims"):
311
- template_dims = self._template_dims
270
+ ret._is_padded = be.sum(be.to_backend_array(target_pad)) > 0
271
+ ret.target_filter = self.target_filter
272
+ ret.template_filter = self.template_filter
312
273
 
313
- ret._set_batch_dimension(target_dims=target_dims, template_dims=template_dims)
274
+ ret._set_matching_dimension(
275
+ target_dims=getattr(self, "_target_dims", None),
276
+ template_dims=getattr(self, "_template_dims", None),
277
+ )
314
278
 
315
279
  return ret
316
280
 
@@ -318,47 +282,42 @@ class MatchingData:
318
282
  """
319
283
  Transfer and convert types of class instance's data arrays to the current backend
320
284
  """
321
- backend_arr = type(backend.zeros((1), dtype=backend._float_dtype))
285
+ backend_arr = type(be.zeros((1), dtype=be._float_dtype))
322
286
  for attr_name, attr_value in vars(self).items():
323
287
  converted_array = None
324
288
  if isinstance(attr_value, np.ndarray):
325
- converted_array = backend.to_backend_array(attr_value.copy())
289
+ converted_array = be.to_backend_array(attr_value.copy())
326
290
  elif isinstance(attr_value, backend_arr):
327
- converted_array = backend.to_backend_array(attr_value)
291
+ converted_array = be.to_backend_array(attr_value)
328
292
  else:
329
293
  continue
330
294
 
331
- current_dtype = backend.get_fundamental_dtype(converted_array)
332
- target_dtype = backend._fundamental_dtypes[current_dtype]
295
+ current_dtype = be.get_fundamental_dtype(converted_array)
296
+ target_dtype = be._fundamental_dtypes[current_dtype]
333
297
 
334
298
  # Optional, but scores are float so we avoid casting and potential issues
335
299
  if attr_name in ("_template", "_template_mask", "_target", "_target_mask"):
336
- target_dtype = backend._float_dtype
300
+ target_dtype = be._float_dtype
337
301
 
338
302
  if target_dtype != current_dtype:
339
- converted_array = backend.astype(converted_array, target_dtype)
303
+ converted_array = be.astype(converted_array, target_dtype)
340
304
 
341
305
  setattr(self, attr_name, converted_array)
342
306
 
343
- def _set_batch_dimension(
307
+ def _set_matching_dimension(
344
308
  self, target_dims: Tuple[int] = None, template_dims: Tuple[int] = None
345
309
  ) -> None:
346
310
  """
347
- Sets the shapes of target and template for template matching considering
348
- their corresponding batch dimensions.
349
-
311
+ Sets matching dimensions for target and template.
350
312
  Parameters
351
313
  ----------
352
- target_dims : Tuple[int], optional
353
- A tuple of integers specifying the batch dimensions of the target. If None,
354
- the target is assumed not to have batch dimensions.
355
- template_dims : Tuple[int], optional
356
- A tuple of integers specifying the batch dimensions of the template. If None,
357
- the template is assumed not to have batch dimensions.
314
+ target_dims : tuple of ints, optional
315
+ Target batch dimensions, None by default.
316
+ template_dims : tuple of ints, optional
317
+ Template batch dimensions, None by default.
358
318
 
359
319
  Notes
360
320
  -----
361
-
362
321
  If the target and template share a batch dimension, the target will
363
322
  take precendence and the template dimension will be shifted to the right.
364
323
  If target and template have the same dimension, but target specifies batch
@@ -386,15 +345,9 @@ class MatchingData:
386
345
 
387
346
  matching_dims = target_measurement_dims + batch_dims
388
347
 
389
- target_shape = backend.full(
390
- shape=(matching_dims,), fill_value=1, dtype=backend._int_dtype
391
- )
392
- template_shape = backend.full(
393
- shape=(matching_dims,), fill_value=1, dtype=backend._int_dtype
394
- )
395
- batch_mask = backend.full(
396
- shape=(matching_dims,), fill_value=1, dtype=backend._int_dtype
397
- )
348
+ target_shape = np.full(shape=matching_dims, fill_value=1, dtype=int)
349
+ template_shape = np.full(shape=matching_dims, fill_value=1, dtype=int)
350
+ batch_mask = np.full(shape=matching_dims, fill_value=1, dtype=int)
398
351
 
399
352
  target_index, template_index = 0, 0
400
353
  for k in range(matching_dims):
@@ -420,9 +373,9 @@ class MatchingData:
420
373
  if template_dim < template_ndim:
421
374
  template_shape[k] = self._template.shape[template_dim]
422
375
 
423
- self._output_target_shape = target_shape
424
- self._output_template_shape = template_shape
425
- self._batch_mask = batch_mask
376
+ self._output_target_shape = tuple(int(x) for x in target_shape)
377
+ self._output_template_shape = tuple(int(x) for x in template_shape)
378
+ self._batch_mask = tuple(int(x) for x in batch_mask)
426
379
 
427
380
  @staticmethod
428
381
  def _compute_batch_dimension(
@@ -433,22 +386,22 @@ class MatchingData:
433
386
 
434
387
  Parameters
435
388
  ----------
436
- batch_dims : Tuple[int]
389
+ batch_dims : tuple of ints
437
390
  A tuple of integers representing the batch dimensions.
438
391
  ndim : int
439
392
  The number of dimensions of the array.
440
393
 
441
394
  Returns
442
395
  -------
443
- Tuple[ArrayLike, Tuple]
444
- A tuple containing the mask (as an ArrayLike) and the validated batch dimensions.
396
+ Tuple[ArrayLike, tuple of ints]
397
+ Mask and the corresponding batch dimensions.
445
398
 
446
399
  Raises
447
400
  ------
448
401
  ValueError
449
402
  If any dimension in batch_dims is not less than ndim.
450
403
  """
451
- mask = backend.zeros(ndim, dtype=bool)
404
+ mask = np.zeros(ndim, dtype=int)
452
405
  if batch_dims is None:
453
406
  return mask, ()
454
407
 
@@ -463,215 +416,298 @@ class MatchingData:
463
416
 
464
417
  return mask, batch_dims
465
418
 
466
- def target_padding(self, pad_target: bool = False) -> ArrayLike:
419
+ def target_padding(self, pad_target: bool = False) -> Tuple[int]:
467
420
  """
468
421
  Computes padding for the target based on the template's shape.
469
422
 
470
423
  Parameters
471
424
  ----------
472
425
  pad_target : bool, default False
473
- If True, computes the padding required for the target. If False,
474
- an array of zeros is returned.
426
+ Whether to pad the target, default returns an array of zeros.
475
427
 
476
428
  Returns
477
429
  -------
478
- ArrayLike
479
- An array indicating the padding for each dimension of the target.
430
+ tuple of ints
431
+ Padding along each dimension of the target.
480
432
  """
481
- target_padding = backend.zeros(
482
- len(self._output_target_shape), dtype=backend._int_dtype
483
- )
484
-
433
+ target_padding = np.zeros(len(self._output_target_shape), dtype=int)
485
434
  if pad_target:
486
- backend.subtract(
435
+ target_padding = np.subtract(
487
436
  self._output_template_shape,
488
- backend.mod(self._output_template_shape, 2),
489
- out=target_padding,
437
+ np.mod(self._output_template_shape, 2),
490
438
  )
491
439
  if hasattr(self, "_is_target_batch"):
492
- target_padding[self._is_target_batch] = 0
440
+ target_padding = np.multiply(
441
+ target_padding,
442
+ np.subtract(1, self._is_target_batch),
443
+ )
493
444
 
494
- return target_padding
445
+ return tuple(int(x) for x in target_padding)
495
446
 
496
- def fourier_padding(
497
- self, pad_fourier: bool = False
498
- ) -> Tuple[ArrayLike, ArrayLike, ArrayLike]:
447
+ @staticmethod
448
+ def _fourier_padding(
449
+ target_shape: NDArray,
450
+ template_shape: NDArray,
451
+ batch_mask: NDArray = None,
452
+ pad_fourier: bool = False,
453
+ ) -> Tuple[Tuple, Tuple, Tuple]:
499
454
  """
500
- Computes an efficient shape for the forward Fourier transform, the
501
- corresponding shape of the real-valued FFT, and the associated
502
- translation shift.
503
-
504
- Parameters
505
- ----------
506
- pad_fourier : bool, default False
507
- If true, returns the shape of the full-convolution defined as sum of target
508
- shape and template shape minus one. By default, returns unpadded transform.
509
-
510
- Returns
511
- -------
512
- Tuple[ArrayLike, ArrayLike, ArrayLike]
513
- A tuple containing the calculated fast shape, fast Fourier transform shape,
514
- and the Fourier shift values, respectively.
455
+ Determines an efficient shape for Fourier transforms considering zero-padding.
515
456
  """
516
- template_shape = self._template.shape
517
- if hasattr(self, "_output_template_shape"):
518
- template_shape = self._output_template_shape
519
- template_shape = backend.to_backend_array(template_shape)
457
+ fourier_pad = template_shape
458
+ fourier_shift = np.zeros_like(template_shape)
459
+
460
+ if batch_mask is None:
461
+ batch_mask = np.zeros_like(template_shape)
462
+ batch_mask = np.asarray(batch_mask)
520
463
 
521
- target_shape = self._target.shape
522
- if hasattr(self, "_output_target_shape"):
523
- target_shape = self._output_target_shape
524
- target_shape = backend.to_backend_array(target_shape)
464
+ if not pad_fourier:
465
+ fourier_pad = np.ones(len(fourier_pad), dtype=int)
466
+ fourier_pad = np.multiply(fourier_pad, 1 - batch_mask)
467
+ fourier_pad = np.add(fourier_pad, batch_mask)
525
468
 
526
- fourier_pad = backend.to_backend_array(template_shape)
527
- fourier_shift = backend.zeros(len(fourier_pad))
469
+ pad_shape = np.maximum(target_shape, template_shape)
470
+ ret = be.compute_convolution_shapes(pad_shape, fourier_pad)
471
+ conv_shape, fast_shape, fast_ft_shape = ret
528
472
 
473
+ template_mod = np.mod(template_shape, 2)
529
474
  if not pad_fourier:
530
- fourier_pad = backend.full(
531
- shape=(len(fourier_pad),),
532
- fill_value=1,
533
- dtype=backend._int_dtype,
475
+ fourier_shift = 1 - np.divide(template_shape, 2).astype(int)
476
+ fourier_shift = np.subtract(fourier_shift, template_mod)
477
+
478
+ shape_diff = np.multiply(
479
+ np.subtract(target_shape, template_shape), 1 - batch_mask
480
+ )
481
+ if np.sum(shape_diff < 0):
482
+ warnings.warn(
483
+ "Template is larger than target and padding is turned off. Consider "
484
+ "swapping them or activate padding. Correcting the shift for now."
534
485
  )
535
486
 
536
- fourier_pad = backend.to_backend_array(fourier_pad)
537
- if hasattr(self, "_batch_mask"):
538
- batch_mask = backend.to_backend_array(self._batch_mask)
539
- backend.multiply(fourier_pad, 1 - batch_mask, out=fourier_pad)
540
- backend.add(fourier_pad, batch_mask, out=fourier_pad)
487
+ shape_shift = np.divide(shape_diff, 2)
488
+ offset = np.mod(shape_diff, 2)
489
+ if pad_fourier:
490
+ offset = -np.subtract(
491
+ offset,
492
+ np.logical_and(np.mod(target_shape, 2) == 0, template_mod == 1),
493
+ )
541
494
 
542
- pad_shape = backend.maximum(target_shape, template_shape)
543
- ret = backend.compute_convolution_shapes(pad_shape, fourier_pad)
544
- convolution_shape, fast_shape, fast_ft_shape = ret
545
- if not pad_fourier:
546
- fourier_shift = 1 - backend.astype(backend.divide(template_shape, 2), int)
547
- fourier_shift -= backend.mod(template_shape, 2)
548
- shape_diff = backend.subtract(fast_shape, convolution_shape)
549
- shape_diff = backend.astype(backend.divide(shape_diff, 2), int)
495
+ shape_shift = np.add(shape_shift, offset)
496
+ fourier_shift = np.subtract(fourier_shift, shape_shift).astype(int)
550
497
 
551
- if hasattr(self, "_batch_mask"):
552
- batch_mask = backend.to_backend_array(self._batch_mask)
553
- backend.multiply(shape_diff, 1 - batch_mask, out=shape_diff)
498
+ fourier_shift = tuple(fourier_shift.astype(int))
499
+ return tuple(conv_shape), tuple(fast_shape), tuple(fast_ft_shape), fourier_shift
554
500
 
555
- backend.add(fourier_shift, shape_diff, out=fourier_shift)
501
+ def fourier_padding(self, pad_fourier: bool = False) -> Tuple[Tuple, Tuple, Tuple]:
502
+ """
503
+ Computes efficient shape four Fourier transforms and potential associated shifts.
556
504
 
557
- fourier_shift = backend.astype(fourier_shift, backend._int_dtype)
505
+ Parameters
506
+ ----------
507
+ pad_fourier : bool, default False
508
+ If true, returns the shape of the full-convolution defined as sum of target
509
+ shape and template shape minus one, False by default.
558
510
 
559
- return fast_shape, fast_ft_shape, fourier_shift
511
+ Returns
512
+ -------
513
+ Tuple[tuple of int, tuple of int, tuple of int]
514
+ Tuple with real and complex Fourier transform shape, and corresponding shift.
515
+ """
516
+ return self._fourier_padding(
517
+ target_shape=be.to_numpy_array(self._output_target_shape),
518
+ template_shape=be.to_numpy_array(self._output_template_shape),
519
+ batch_mask=be.to_numpy_array(self._batch_mask),
520
+ pad_fourier=pad_fourier,
521
+ )
560
522
 
561
523
  @property
562
524
  def rotations(self):
563
- """Return stored rotation matrices.."""
525
+ """Return stored rotation matrices."""
564
526
  return self._rotations
565
527
 
566
528
  @rotations.setter
567
529
  def rotations(self, rotations: NDArray):
568
530
  """
569
- Set and reshape the rotation matrices for template matching.
531
+ Set :py:attr:`MatchingData.rotations`.
570
532
 
571
533
  Parameters
572
534
  ----------
573
535
  rotations : NDArray
574
- Rotations in shape (k x k), or (n x k x k).
536
+ Rotations matrices with shape (d, d) or (n, d, d).
575
537
  """
576
- if rotations.__class__ != np.ndarray:
577
- raise ValueError("Rotation set has to be of type numpy ndarray.")
578
- if rotations.ndim == 2:
538
+ if rotations is None:
539
+ print("No rotations provided, assuming identity for now.")
540
+ rotations = np.eye(len(self._target.shape))
541
+
542
+ if rotations.ndim not in (2, 3):
543
+ raise ValueError("Rotations have to be a rank 2 or 3 array.")
544
+ elif rotations.ndim == 2:
579
545
  print("Reshaping rotations array to rank 3.")
580
546
  rotations = rotations.reshape(1, *rotations.shape)
581
- elif rotations.ndim == 3:
582
- pass
583
- else:
584
- raise ValueError("Rotations have to be a rank 2 or 3 array.")
585
547
  self._rotations = rotations.astype(np.float32)
586
548
 
549
+ @staticmethod
550
+ def _get_data(attribute, output_shape: Tuple[int], reverse: bool = False):
551
+ if isinstance(attribute, Density):
552
+ attribute = attribute.data
553
+
554
+ if attribute is not None:
555
+ if reverse:
556
+ attribute = be.reverse(attribute)
557
+ attribute = attribute.reshape(tuple(int(x) for x in output_shape))
558
+
559
+ return attribute
560
+
587
561
  @property
588
562
  def target(self):
589
- """Returns the target."""
590
- target = self._target
591
- if isinstance(self._target, Density):
592
- target = self._target.data
563
+ """
564
+ Return the target.
565
+
566
+ Returns
567
+ -------
568
+ NDArray
569
+ Output data.
570
+ """
571
+ return self._get_data(self._target, self._output_target_shape, False)
593
572
 
594
- out_shape = tuple(int(x) for x in self._output_target_shape)
595
- return target.reshape(out_shape)
573
+ @property
574
+ def target_mask(self):
575
+ """
576
+ Return the target mask.
577
+
578
+ Returns
579
+ -------
580
+ NDArray
581
+ Output data.
582
+ """
583
+ target_mask = getattr(self, "_target_mask", None)
584
+ return self._get_data(target_mask, self._output_target_shape, False)
596
585
 
597
586
  @property
598
587
  def template(self):
599
- """Returns the reversed template."""
600
- template = self._template
601
- if isinstance(self._template, Density):
602
- template = self._template.data
603
- template = backend.reverse(template)
604
- out_shape = tuple(int(x) for x in self._output_template_shape)
605
- return template.reshape(out_shape)
588
+ """
589
+ Return the reversed template.
590
+
591
+ Returns
592
+ -------
593
+ NDArray
594
+ Output data.
595
+ """
596
+ return self._get_data(self._template, self._output_template_shape, True)
597
+
598
+ @property
599
+ def template_mask(self):
600
+ """
601
+ Return the reversed template mask.
602
+
603
+ Returns
604
+ -------
605
+ NDArray
606
+ Output data.
607
+ """
608
+ template_mask = getattr(self, "_template_mask", None)
609
+ return self._get_data(template_mask, self._output_template_shape, True)
610
+
611
+ @target.setter
612
+ def target(self, arr: NDArray):
613
+ """
614
+ Set :py:attr:`MatchingData.target`.
615
+
616
+ Parameters
617
+ ----------
618
+ arr : NDArray
619
+ Array to set as the target.
620
+ """
621
+ self._target = arr
606
622
 
607
623
  @template.setter
608
- def template(self, template: NDArray):
624
+ def template(self, arr: NDArray):
609
625
  """
610
- Set the template array. If not already defined, also initializes
611
- :py:attr:`MatchingData.template_mask` to an uninformative mask filled with
612
- ones.
626
+ Set :py:attr:`MatchingData.template` and initializes
627
+ :py:attr:`MatchingData.template_mask` to an to an uninformative
628
+ mask filled with ones if not already defined.
613
629
 
614
630
  Parameters
615
631
  ----------
616
- template : NDArray
632
+ arr : NDArray
617
633
  Array to set as the template.
618
634
  """
619
- self._templateshape = template.shape[::-1]
620
- if self._template_mask is None:
621
- self._template_mask = backend.full(
622
- shape=template.shape, dtype=float, fill_value=1
635
+ self._template = arr
636
+ if getattr(self, "_template_mask", None) is None:
637
+ self._template_mask = be.full(
638
+ shape=arr.shape, dtype=be._float_dtype, fill_value=1
623
639
  )
624
640
 
625
- self._template = template
641
+ @staticmethod
642
+ def _set_mask(mask, shape: Tuple[int]):
643
+ if mask is not None:
644
+ if mask.shape != shape:
645
+ raise ValueError(
646
+ "Mask and respective data have to have the same shape."
647
+ )
648
+ return mask
626
649
 
627
- @property
628
- def target_mask(self):
629
- """Returns the target mask NDArray."""
630
- target_mask = self._target_mask
631
- if isinstance(self._target_mask, Density):
632
- target_mask = self._target_mask.data
650
+ @target_mask.setter
651
+ def target_mask(self, arr: NDArray):
652
+ """
653
+ Set :py:attr:`MatchingData.target_mask`.
633
654
 
634
- if target_mask is not None:
635
- out_shape = tuple(int(x) for x in self._output_target_shape)
636
- target_mask = target_mask.reshape(out_shape)
655
+ Parameters
656
+ ----------
657
+ arr : NDArray
658
+ Array to set as the target_mask.
659
+ """
660
+ self._target_mask = self._set_mask(mask=arr, shape=self._target.shape)
637
661
 
638
- return target_mask
662
+ @template_mask.setter
663
+ def template_mask(self, arr: NDArray):
664
+ """
665
+ Set :py:attr:`MatchingData.template_mask`.
639
666
 
640
- @target_mask.setter
641
- def target_mask(self, mask: NDArray):
642
- """Sets the target mask."""
643
- if not np.all(self.target.shape == mask.shape):
644
- raise ValueError("Target and its mask have to have the same shape.")
667
+ Parameters
668
+ ----------
669
+ arr : NDArray
670
+ Array to set as the template_mask.
671
+ """
672
+ self._template_mask = self._set_mask(mask=arr, shape=self._template.shape)
645
673
 
646
- self._target_mask = mask
674
+ @staticmethod
675
+ def _set_filter(composable_filter) -> Optional[Compose]:
676
+ if isinstance(composable_filter, Compose):
677
+ return composable_filter
678
+ return None
647
679
 
648
680
  @property
649
- def template_mask(self):
681
+ def template_filter(self) -> Optional[Compose]:
650
682
  """
651
- Set the template mask array after reversing it.
683
+ Returns the composable template filter.
652
684
 
653
- Parameters
654
- ----------
655
- template : NDArray
656
- Array to set as the template.
685
+ Returns
686
+ -------
687
+ :py:class:`tme.preprocessing.Compose` | None
688
+ Composable template filter or None.
657
689
  """
658
- mask = self._template_mask
659
- if isinstance(self._template_mask, Density):
660
- mask = self._template_mask.data
690
+ return getattr(self, "_template_filter", None)
661
691
 
662
- if mask is not None:
663
- mask = backend.reverse(mask)
664
- out_shape = tuple(int(x) for x in self._output_template_shape)
665
- mask = mask.reshape(out_shape)
666
- return mask
692
+ @property
693
+ def target_filter(self) -> Optional[Compose]:
694
+ """
695
+ Returns the composable target filter.
667
696
 
668
- @template_mask.setter
669
- def template_mask(self, mask: NDArray):
670
- """Returns the reversed template mask NDArray."""
671
- if not np.all(self._templateshape[::-1] == mask.shape):
672
- raise ValueError("Template and its mask have to have the same shape.")
697
+ Returns
698
+ -------
699
+ :py:class:`tme.preprocessing.Compose` | None
700
+ Composable filter or None.
701
+ """
702
+ return getattr(self, "_target_filter", None)
673
703
 
674
- self._template_mask = mask
704
+ @template_filter.setter
705
+ def template_filter(self, composable_filter: Compose):
706
+ self._template_filter = self._set_filter(composable_filter)
707
+
708
+ @target_filter.setter
709
+ def target_filter(self, composable_filter: Compose):
710
+ self._target_filter = self._set_filter(composable_filter)
675
711
 
676
712
  def _split_rotations_on_jobs(self, n_jobs: int) -> List[NDArray]:
677
713
  """
@@ -696,3 +732,11 @@ class MatchingData:
696
732
  end_rot = None
697
733
  rot_list.append(self.rotations[init_rot:end_rot])
698
734
  return rot_list
735
+
736
+ def _free_data(self):
737
+ """
738
+ Free (dereference) data arrays owned by the class instance.
739
+ """
740
+ attrs = ("_target", "_template", "_template_mask", "_target_mask")
741
+ for attr in attrs:
742
+ setattr(self, attr, None)