pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. pytme-0.2.2.data/scripts/match_template.py +1187 -0
  2. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/postprocess.py +170 -71
  3. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +179 -86
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +126 -87
  8. scripts/match_template.py +596 -209
  9. scripts/match_template_filters.py +571 -223
  10. scripts/postprocess.py +170 -71
  11. scripts/preprocessor_gui.py +179 -86
  12. scripts/refine_matches.py +567 -159
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +627 -855
  16. tme/backends/__init__.py +41 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +120 -225
  19. tme/backends/jax_backend.py +282 -0
  20. tme/backends/matching_backend.py +464 -388
  21. tme/backends/mlx_backend.py +45 -68
  22. tme/backends/npfftw_backend.py +256 -514
  23. tme/backends/pytorch_backend.py +41 -154
  24. tme/density.py +312 -421
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +366 -303
  27. tme/matching_exhaustive.py +279 -1521
  28. tme/matching_optimization.py +234 -129
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +281 -387
  31. tme/memory.py +377 -0
  32. tme/orientations.py +226 -66
  33. tme/parser.py +3 -4
  34. tme/preprocessing/__init__.py +2 -0
  35. tme/preprocessing/_utils.py +217 -0
  36. tme/preprocessing/composable_filter.py +31 -0
  37. tme/preprocessing/compose.py +55 -0
  38. tme/preprocessing/frequency_filters.py +388 -0
  39. tme/preprocessing/tilt_series.py +1011 -0
  40. tme/preprocessor.py +574 -530
  41. tme/structure.py +495 -189
  42. tme/types.py +5 -3
  43. pytme-0.2.0b0.data/scripts/match_template.py +0 -800
  44. pytme-0.2.0b0.dist-info/METADATA +0 -73
  45. pytme-0.2.0b0.dist-info/RECORD +0 -66
  46. tme/helpers.py +0 -881
  47. tme/matching_constrained.py +0 -195
  48. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  49. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  50. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
@@ -4,41 +4,96 @@
4
4
 
5
5
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
6
  """
7
-
7
+ import os
8
8
  from typing import Tuple, Dict, List
9
+ from contextlib import contextmanager
9
10
  from multiprocessing import shared_memory
10
11
  from multiprocessing.managers import SharedMemoryManager
11
- from contextlib import contextmanager
12
12
 
13
13
  import numpy as np
14
14
  from psutil import virtual_memory
15
- from numpy.typing import NDArray
15
+ from scipy.ndimage import maximum_filter, affine_transform
16
16
  from pyfftw import zeros_aligned, simd_alignment, FFTW, next_fast_len
17
17
  from pyfftw.builders import rfftn as rfftn_builder, irfftn as irfftn_builder
18
- from scipy.ndimage import maximum_filter, affine_transform
19
18
 
20
- from .matching_backend import MatchingBackend
21
- from ..matching_utils import rigid_transform
19
+ from ..types import NDArray, BackendArray, shm_type
20
+ from .matching_backend import MatchingBackend, _create_metafunction
21
+
22
+ os.environ["MKL_NUM_THREADS"] = "1"
23
+ os.environ["OMP_NUM_THREADS"] = "1"
24
+ os.environ["PYFFTW_NUM_THREADS"] = "1"
25
+ os.environ["OPENBLAS_NUM_THREADS"] = "1"
26
+
27
+
28
+ def create_ufuncs(obj):
29
+ ufuncs = [
30
+ "add",
31
+ "subtract",
32
+ "multiply",
33
+ "divide",
34
+ "mod",
35
+ "sum",
36
+ "where",
37
+ "einsum",
38
+ "mean",
39
+ "einsum",
40
+ "std",
41
+ "max",
42
+ "min",
43
+ "maximum",
44
+ "minimum",
45
+ "sqrt",
46
+ "square",
47
+ "abs",
48
+ "power",
49
+ "full",
50
+ "clip",
51
+ "arange",
52
+ "stack",
53
+ "concatenate",
54
+ "repeat",
55
+ "indices",
56
+ "unique",
57
+ "argsort",
58
+ "tril_indices",
59
+ "reshape",
60
+ "identity",
61
+ "dot",
62
+ ]
63
+ for ufunc in ufuncs:
64
+ setattr(obj, ufunc, _create_metafunction(ufunc))
65
+ return obj
66
+
67
+
68
+ @create_ufuncs
69
+ class _NumpyWrapper:
70
+ """
71
+ MatchingBackend prohibits using create_ufuncs on NumpyFFTWBackend directly.
72
+ """
22
73
 
74
+ pass
23
75
 
24
- class NumpyFFTWBackend(MatchingBackend):
76
+
77
+ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
25
78
  """
26
- A numpy and pyfftw based backend for template matching.
79
+ A numpy and pyfftw-based matching backend.
27
80
  """
28
81
 
29
82
  def __init__(
30
83
  self,
31
84
  array_backend=np,
32
- default_dtype=np.float32,
85
+ float_dtype=np.float32,
33
86
  complex_dtype=np.complex64,
34
- default_dtype_int=np.int32,
87
+ int_dtype=np.int32,
88
+ overflow_safe_dtype=np.float32,
35
89
  **kwargs,
36
90
  ):
37
91
  super().__init__(
38
92
  array_backend=array_backend,
39
- default_dtype=default_dtype,
93
+ float_dtype=float_dtype,
40
94
  complex_dtype=complex_dtype,
41
- default_dtype_int=default_dtype_int,
95
+ int_dtype=int_dtype,
96
+ overflow_safe_dtype=overflow_safe_dtype,
42
97
  )
43
98
  self.affine_transform = affine_transform
44
99
 
@@ -48,155 +103,48 @@ class NumpyFFTWBackend(MatchingBackend):
48
103
  return self._array_backend.asarray(arr)
49
104
 
50
105
  def to_numpy_array(self, arr: NDArray) -> NDArray:
51
- return arr
106
+ return np.array(arr)
52
107
 
53
108
  def to_cpu_array(self, arr: NDArray) -> NDArray:
54
109
  return arr
55
110
 
111
+ def get_fundamental_dtype(self, arr):
112
+ dt = arr.dtype
113
+ if self._array_backend.issubdtype(dt, self._array_backend.integer):
114
+ return int
115
+ elif self._array_backend.issubdtype(dt, self._array_backend.floating):
116
+ return float
117
+ elif self._array_backend.issubdtype(dt, self._array_backend.complexfloating):
118
+ return complex
119
+ return float
120
+
56
121
  def free_cache(self):
57
122
  pass
58
123
 
59
- def add(self, x1, x2, *args, **kwargs) -> NDArray:
60
- x1 = self.to_backend_array(x1)
61
- x2 = self.to_backend_array(x2)
62
- return self._array_backend.add(x1, x2, *args, **kwargs)
63
-
64
- def subtract(self, x1, x2, *args, **kwargs) -> NDArray:
65
- x1 = self.to_backend_array(x1)
66
- x2 = self.to_backend_array(x2)
67
- return self._array_backend.subtract(x1, x2, *args, **kwargs)
68
-
69
- def multiply(self, x1, x2, *args, **kwargs) -> NDArray:
70
- x1 = self.to_backend_array(x1)
71
- x2 = self.to_backend_array(x2)
72
- return self._array_backend.multiply(x1, x2, *args, **kwargs)
73
-
74
- def divide(self, x1, x2, *args, **kwargs) -> NDArray:
75
- x1 = self.to_backend_array(x1)
76
- x2 = self.to_backend_array(x2)
77
- return self._array_backend.divide(x1, x2, *args, **kwargs)
78
-
79
- def mod(self, x1, x2, *args, **kwargs):
80
- x1 = self.to_backend_array(x1)
81
- x2 = self.to_backend_array(x2)
82
- return self._array_backend.mod(x1, x2, *args, **kwargs)
83
-
84
- def sum(self, *args, **kwargs) -> NDArray:
85
- return self._array_backend.sum(*args, **kwargs)
86
-
87
- def einsum(self, *args, **kwargs) -> NDArray:
88
- return self._array_backend.einsum(*args, **kwargs)
89
-
90
- def mean(self, *args, **kwargs) -> NDArray:
91
- return self._array_backend.mean(*args, **kwargs)
92
-
93
- def std(self, *args, **kwargs) -> NDArray:
94
- return self._array_backend.std(*args, **kwargs)
95
-
96
- def max(self, *args, **kwargs) -> NDArray:
97
- return self._array_backend.max(*args, **kwargs)
98
-
99
- def min(self, *args, **kwargs) -> NDArray:
100
- return self._array_backend.min(*args, **kwargs)
101
-
102
- def maximum(self, x1, x2, *args, **kwargs) -> NDArray:
103
- x1 = self.to_backend_array(x1)
104
- x2 = self.to_backend_array(x2)
105
- return self._array_backend.maximum(x1, x2, *args, **kwargs)
106
-
107
- def minimum(self, x1, x2, *args, **kwargs) -> NDArray:
108
- x1 = self.to_backend_array(x1)
109
- x2 = self.to_backend_array(x2)
110
- return self._array_backend.minimum(x1, x2, *args, **kwargs)
111
-
112
- def sqrt(self, *args, **kwargs) -> NDArray:
113
- return self._array_backend.sqrt(*args, **kwargs)
114
-
115
- def square(self, *args, **kwargs) -> NDArray:
116
- return self._array_backend.square(*args, **kwargs)
117
-
118
- def abs(self, *args, **kwargs) -> NDArray:
119
- return self._array_backend.abs(*args, **kwargs)
120
-
121
124
  def transpose(self, arr):
122
125
  return arr.T
123
126
 
124
- def power(self, *args, **kwargs):
125
- return self._array_backend.power(*args, **kwargs)
126
-
127
127
  def tobytes(self, arr):
128
128
  return arr.tobytes()
129
129
 
130
130
  def size(self, arr):
131
131
  return arr.size
132
132
 
133
- def fill(self, arr: NDArray, value: float) -> None:
133
+ def fill(self, arr: NDArray, value: float) -> NDArray:
134
134
  arr.fill(value)
135
-
136
- def zeros(self, shape, dtype=np.float64) -> NDArray:
137
- return self._array_backend.zeros(shape=shape, dtype=dtype)
138
-
139
- def full(self, shape, fill_value, dtype=None, **kwargs) -> NDArray:
140
- return self._array_backend.full(
141
- shape, dtype=dtype, fill_value=fill_value, **kwargs
142
- )
135
+ return arr
143
136
 
144
137
  def eps(self, dtype: type) -> NDArray:
145
- """
146
- Returns the eps defined as diffeerence between 1.0 and the next
147
- representable floating point value larger than 1.0.
148
-
149
- Parameters
150
- ----------
151
- dtype : type
152
- Data type for which eps should be returned.
153
-
154
- Returns
155
- -------
156
- Scalar
157
- The eps for the given data type
158
- """
159
138
  return self._array_backend.finfo(dtype).eps
160
139
 
161
140
  def datatype_bytes(self, dtype: type) -> NDArray:
162
- """
163
- Return the number of bytes occupied by a given datatype.
164
-
165
- Parameters
166
- ----------
167
- dtype : type
168
- Datatype for which the number of bytes is to be determined.
169
-
170
- Returns
171
- -------
172
- int
173
- Number of bytes occupied by the datatype.
174
- """
175
141
  temp = self._array_backend.zeros(1, dtype=dtype)
176
142
  return temp.nbytes
177
143
 
178
- def clip(self, *args, **kwargs) -> NDArray:
179
- return self._array_backend.clip(*args, **kwargs)
180
-
181
- def flip(self, a, axis, **kwargs):
182
- return self._array_backend.flip(a, axis, **kwargs)
183
-
184
144
  @staticmethod
185
145
  def astype(arr, dtype):
186
146
  return arr.astype(dtype)
187
147
 
188
- def arange(self, *args, **kwargs):
189
- return self._array_backend.arange(*args, **kwargs)
190
-
191
- def stack(self, *args, **kwargs):
192
- return self._array_backend.stack(*args, **kwargs)
193
-
194
- def concatenate(self, *args, **kwargs):
195
- return self._array_backend.concatenate(*args, **kwargs)
196
-
197
- def repeat(self, *args, **kwargs):
198
- return self._array_backend.repeat(*args, **kwargs)
199
-
200
148
  def topk_indices(self, arr: NDArray, k: int):
201
149
  temp = arr.reshape(-1)
202
150
  indices = self._array_backend.argpartition(temp, -k)[-k:][:k]
@@ -215,18 +163,9 @@ class NumpyFFTWBackend(MatchingBackend):
215
163
  **kwargs,
216
164
  )
217
165
 
218
- def unique(self, *args, **kwargs):
219
- return self._array_backend.unique(*args, **kwargs)
220
-
221
- def argsort(self, *args, **kwargs):
222
- return self._array_backend.argsort(*args, **kwargs)
223
-
224
166
  def unravel_index(self, indices, shape):
225
167
  return self._array_backend.unravel_index(indices=indices, shape=shape)
226
168
 
227
- def tril_indices(self, *args, **kwargs):
228
- return self._array_backend.tril_indices(*args, **kwargs)
229
-
230
169
  def max_filter_coordinates(self, score_space, min_distance: Tuple[int]):
231
170
  score_box = tuple(min_distance for _ in range(score_space.ndim))
232
171
  max_filter = maximum_filter(score_space, size=score_box, mode="constant")
@@ -236,93 +175,27 @@ class NumpyFFTWBackend(MatchingBackend):
236
175
  return peaks
237
176
 
238
177
  @staticmethod
239
- def preallocate_array(shape: Tuple[int], dtype: type) -> NDArray:
240
- """
241
- Returns a byte-aligned array of zeros with specified shape and dtype.
242
-
243
- Parameters
244
- ----------
245
- shape : Tuple[int]
246
- Desired shape for the array.
247
- dtype : type
248
- Desired data type for the array.
249
-
250
- Returns
251
- -------
252
- NDArray
253
- Byte-aligned array of zeros with specified shape and dtype.
254
- """
178
+ def zeros(shape: Tuple[int], dtype: type = None) -> NDArray:
255
179
  arr = zeros_aligned(shape, dtype=dtype, n=simd_alignment)
256
180
  return arr
257
181
 
258
- def sharedarr_to_arr(
259
- self, shm: shared_memory.SharedMemory, shape: Tuple[int], dtype: str
260
- ) -> NDArray:
261
- """
262
- Returns an array of given shape and dtype from shared memory location.
263
-
264
- Parameters
265
- ----------
266
- shape : tuple
267
- Tuple of integers specifying the shape of the array.
268
- dtype : str
269
- String specifying the dtype of the array.
270
- shm : shared_memory.SharedMemory
271
- Shared memory object where the array is stored.
272
-
273
- Returns
274
- -------
275
- NDArray
276
- Array of the specified shape and dtype from the shared memory location.
277
- """
182
+ def from_sharedarr(self, args) -> NDArray:
183
+ shm, shape, dtype = args
278
184
  return self.ndarray(shape, dtype, shm.buf)
279
185
 
280
- def arr_to_sharedarr(
186
+ def to_sharedarr(
281
187
  self, arr: NDArray, shared_memory_handler: type = None
282
- ) -> shared_memory.SharedMemory:
283
- """
284
- Converts a numpy array to an object shared in memory.
285
-
286
- Parameters
287
- ----------
288
- arr : NDArray
289
- Numpy array to convert.
290
- shared_memory_handler : type, optional
291
- The type of shared memory handler. Default is None.
292
-
293
- Returns
294
- -------
295
- shared_memory.SharedMemory
296
- The shared memory object containing the numpy array.
297
- """
188
+ ) -> shm_type:
298
189
  if type(shared_memory_handler) == SharedMemoryManager:
299
190
  shm = shared_memory_handler.SharedMemory(size=arr.nbytes)
300
191
  else:
301
192
  shm = shared_memory.SharedMemory(create=True, size=arr.nbytes)
302
193
  np_array = self.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
303
194
  np_array[:] = arr[:].copy()
304
- return shm
195
+ return shm, arr.shape, arr.dtype
305
196
 
306
197
  def topleft_pad(self, arr: NDArray, shape: Tuple[int], padval: int = 0) -> NDArray:
307
- """
308
- Returns an array that has been padded to a specified shape with a padding
309
- value at the top-left corner.
310
-
311
- Parameters
312
- ----------
313
- arr : NDArray
314
- Input array to be padded.
315
- shape : Tuple[int]
316
- Desired shape for the output array.
317
- padval : int, optional
318
- Value to use for padding, default is 0.
319
-
320
- Returns
321
- -------
322
- NDArray
323
- Array that has been padded to the specified shape.
324
- """
325
- b = self.preallocate_array(shape, arr.dtype)
198
+ b = self.zeros(shape, arr.dtype)
326
199
  self.add(b, padval, out=b)
327
200
  aind = [slice(None, None)] * arr.ndim
328
201
  bind = [slice(None, None)] * arr.ndim
@@ -345,40 +218,10 @@ class NumpyFFTWBackend(MatchingBackend):
345
218
  temp_real: NDArray = None,
346
219
  temp_fft: NDArray = None,
347
220
  ) -> Tuple[FFTW, FFTW]:
348
- """
349
- Build pyFFTW builder functions.
350
-
351
- Parameters
352
- ----------
353
- fast_shape : tuple
354
- Tuple of integers corresponding to fast convolution shape
355
- (see `compute_convolution_shapes`).
356
- fast_ft_shape : tuple
357
- Tuple of integers corresponding to the shape of the fourier
358
- transform array (see `compute_convolution_shapes`).
359
- real_dtype : dtype
360
- Numpy dtype of the inverse fourier transform.
361
- complex_dtype : dtype
362
- Numpy dtype of the fourier transform.
363
- inverse_fast_shape : tuple, optional
364
- Output shape of the inverse Fourier transform. By default fast_shape.
365
- fftargs : dict, optional
366
- Dictionary passed to pyFFTW builders.
367
- temp_real : NDArray, optional
368
- Temporary real numpy array, by default None.
369
- temp_fft : NDArray, optional
370
- Temporary fft numpy array, by default None.
371
-
372
- Returns
373
- -------
374
- tuple
375
- Tuple containing callable pyFFTW objects for forward and inverse
376
- fourier transform.
377
- """
378
221
  if temp_real is None:
379
- temp_real = self.preallocate_array(fast_shape, real_dtype)
222
+ temp_real = self.zeros(fast_shape, real_dtype)
380
223
  if temp_fft is None:
381
- temp_fft = self.preallocate_array(fast_ft_shape, complex_dtype)
224
+ temp_fft = self.zeros(fast_ft_shape, complex_dtype)
382
225
  if inverse_fast_shape is None:
383
226
  inverse_fast_shape = fast_shape
384
227
 
@@ -407,30 +250,11 @@ class NumpyFFTWBackend(MatchingBackend):
407
250
  return rfftn, irfftn
408
251
 
409
252
  def extract_center(self, arr: NDArray, newshape: Tuple[int]) -> NDArray:
410
- """
411
- Extract the centered portion of an array based on a new shape.
412
-
413
- Parameters
414
- ----------
415
- arr : NDArray
416
- Input array.
417
- newshape : tuple
418
- Desired shape for the central portion.
419
-
420
- Returns
421
- -------
422
- NDArray
423
- Central portion of the array with shape `newshape`.
424
-
425
- References
426
- ----------
427
- .. [1] https://github.com/scipy/scipy/blob/v1.11.2/scipy/signal/_signaltools.py
428
- """
429
253
  new_shape = self.to_backend_array(newshape)
430
254
  current_shape = self.to_backend_array(arr.shape)
431
255
  starts = self.subtract(current_shape, new_shape)
432
- starts = self.astype(self.divide(starts, 2), self._default_dtype_int)
433
- stops = self.astype(self.add(starts, newshape), self._default_dtype_int)
256
+ starts = self.astype(self.divide(starts, 2), self._int_dtype)
257
+ stops = self.astype(self.add(starts, new_shape), self._int_dtype)
434
258
  box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
435
259
  return arr[box]
436
260
 
@@ -454,254 +278,110 @@ class NumpyFFTWBackend(MatchingBackend):
454
278
  fourier transform, shape of the forward fourier transform
455
279
  (see :py:meth:`build_fft`).
456
280
  """
457
- convolution_shape = [
458
- int(x) + int(y) - 1 for x, y in zip(arr1_shape, arr2_shape)
459
- ]
281
+ convolution_shape = [int(x + y - 1) for x, y in zip(arr1_shape, arr2_shape)]
460
282
  fast_shape = [next_fast_len(x) for x in convolution_shape]
461
283
  fast_ft_shape = list(fast_shape[:-1]) + [fast_shape[-1] // 2 + 1]
462
284
 
463
285
  return convolution_shape, fast_shape, fast_ft_shape
464
286
 
465
- def rotate_array(
287
+ def _rigid_transform_matrix(
466
288
  self,
467
- arr: NDArray,
468
- rotation_matrix: NDArray,
469
- arr_mask: NDArray = None,
289
+ rotation_matrix: NDArray = None,
470
290
  translation: NDArray = None,
471
- use_geometric_center: bool = False,
472
- out: NDArray = None,
473
- out_mask: NDArray = None,
474
- order: int = 3,
475
- ) -> None:
476
- """
477
- Rotates coordinates of arr according to rotation_matrix.
478
-
479
- If no output array is provided, this method will compute an array with
480
- sufficient space to hold all elements. If both `arr` and `arr_mask`
481
- are provided, `arr_mask` will be centered according to arr.
482
-
483
- Parameters
484
- ----------
485
- arr : NDArray
486
- The input array to be rotated.
487
- arr_mask : NDArray, optional
488
- The mask of `arr` that will be equivalently rotated.
489
- rotation_matrix : NDArray
490
- The rotation matrix to apply [d x d].
491
- translation : NDArray
492
- The translation to apply [d].
493
- use_geometric_center : bool, optional
494
- Whether the rotation should be centered around the geometric
495
- or mass center. Default is mass center.
496
- out : NDArray, optional
497
- The output array to write the rotation of `arr` to.
498
- out_mask : NDArray, optional
499
- The output array to write the rotation of `arr_mask` to.
500
- order : int, optional
501
- Spline interpolation order. Has to be in the range 0-5. Non-zero
502
- elements will be converted into a point-cloud and rotated according
503
- to ``rotation_matrix`` if order is None.
504
- """
291
+ center: NDArray = None,
292
+ ) -> NDArray:
293
+ ndim = rotation_matrix.shape[0]
294
+ matrix = self.identity(ndim + 1, dtype=self._float_dtype)
505
295
 
506
- if order is None:
507
- mask_coordinates = None
508
- if arr_mask is not None:
509
- mask_coordinates = np.array(np.where(arr_mask > 0))
510
- return self.rotate_array_coordinates(
511
- arr=arr,
512
- arr_mask=arr_mask,
513
- coordinates=np.array(np.where(arr > 0)),
514
- mask_coordinates=mask_coordinates,
515
- out=out,
516
- out_mask=out_mask,
517
- rotation_matrix=rotation_matrix,
518
- translation=translation,
519
- use_geometric_center=use_geometric_center,
520
- )
296
+ if translation is not None:
297
+ translation_matrix = self.identity(ndim + 1, dtype=self._float_dtype)
298
+ translation_matrix[:ndim, ndim] = -translation
299
+ self.dot(matrix, translation_matrix, out=matrix)
521
300
 
522
- rotate_mask = arr_mask is not None
523
- return_type = (out is None) + 2 * rotate_mask * (out_mask is None)
524
- translation = np.zeros(arr.ndim) if translation is None else translation
301
+ if center is not None:
302
+ center_matrix = self.identity(ndim + 1, dtype=self._float_dtype)
303
+ center_matrix[:ndim, ndim] = center
304
+ self.dot(matrix, center_matrix, out=matrix)
525
305
 
526
- center = np.divide(arr.shape, 2)
527
- if not use_geometric_center:
528
- center = self.center_of_mass(arr, cutoff=0)
306
+ if rotation_matrix is not None:
307
+ rmat = self.identity(ndim + 1, dtype=self._float_dtype)
308
+ rmat[:ndim, :ndim] = self._array_backend.linalg.inv(rotation_matrix)
309
+ self.dot(matrix, rmat, out=matrix)
529
310
 
530
- rotation_matrix_inverted = np.linalg.inv(rotation_matrix)
531
- transformed_center = rotation_matrix_inverted @ center.reshape(-1, 1)
532
- transformed_center = transformed_center.reshape(-1)
533
- base_offset = np.subtract(center, transformed_center)
534
- offset = np.subtract(base_offset, translation)
311
+ if center is not None:
312
+ center_matrix[:ndim, ndim] = -center_matrix[:ndim, ndim]
313
+ self.dot(matrix, center_matrix, out=matrix)
535
314
 
536
- out = np.zeros_like(arr) if out is None else out
537
- out_slice = tuple(slice(0, stop) for stop in arr.shape)
315
+ matrix /= matrix[ndim, ndim]
316
+ return matrix
538
317
 
539
- # Applying the prefilter can cause artifacts in the mask
540
- affine_transform(
541
- input=arr,
542
- matrix=rotation_matrix_inverted,
543
- offset=offset,
318
+ def _rigid_transform(
319
+ self,
320
+ data: NDArray,
321
+ matrix: NDArray,
322
+ output: NDArray,
323
+ prefilter: bool,
324
+ order: int,
325
+ cache: bool = False,
326
+ ) -> None:
327
+ out_slice = tuple(slice(0, stop) for stop in data.shape)
328
+ self.affine_transform(
329
+ input=data,
330
+ matrix=matrix,
544
331
  mode="constant",
545
- output=out[out_slice],
332
+ output=output[out_slice],
546
333
  order=order,
547
- prefilter=True,
334
+ prefilter=prefilter,
548
335
  )
549
336
 
550
- if rotate_mask:
551
- out_mask = np.zeros_like(arr_mask) if out_mask is None else out_mask
552
- out_mask_slice = tuple(slice(0, stop) for stop in arr_mask.shape)
553
- affine_transform(
554
- input=arr_mask,
555
- matrix=rotation_matrix_inverted,
556
- offset=offset,
557
- mode="constant",
558
- output=out_mask[out_mask_slice],
559
- order=order,
560
- prefilter=False,
561
- )
562
-
563
- match return_type:
564
- case 0:
565
- return None
566
- case 1:
567
- return out
568
- case 2:
569
- return out_mask
570
- case 3:
571
- return out, out_mask
572
-
573
- @staticmethod
574
- def rotate_array_coordinates(
337
+ def rigid_transform(
338
+ self,
575
339
  arr: NDArray,
576
- coordinates: NDArray,
577
340
  rotation_matrix: NDArray,
341
+ arr_mask: NDArray = None,
578
342
  translation: NDArray = None,
343
+ use_geometric_center: bool = False,
579
344
  out: NDArray = None,
580
- use_geometric_center: bool = True,
581
- arr_mask: NDArray = None,
582
- mask_coordinates: NDArray = None,
583
345
  out_mask: NDArray = None,
584
- ) -> None:
585
- """
586
- Rotates coordinates of arr according to rotation_matrix.
587
-
588
- If no output array is provided, this method will compute an array with
589
- sufficient space to hold all elements. If both `arr` and `arr_mask`
590
- are provided, `arr_mask` will be centered according to arr.
591
-
592
- No centering will be performed if the rotation matrix is the identity matrix.
593
-
594
- Parameters
595
- ----------
596
- arr : NDArray
597
- The input array to be rotated.
598
- coordinates : NDArray
599
- The pointcloud [d x N] containing elements of `arr` that should be rotated.
600
- See :py:meth:`Density.to_pointcloud` on how to obtain the coordinates.
601
- rotation_matrix : NDArray
602
- The rotation matrix to apply [d x d].
603
- rotation_matrix : NDArray
604
- The translation to apply [d].
605
- out : NDArray, optional
606
- The output array to write the rotation of `arr` to.
607
- use_geometric_center : bool, optional
608
- Whether the rotation should be centered around the geometric
609
- or mass center.
610
- arr_mask : NDArray, optional
611
- The mask of `arr` that will be equivalently rotated.
612
- mask_coordinates : NDArray, optional
613
- Equivalent to `coordinates`, but containing elements of `arr_mask`
614
- that should be rotated.
615
- out_mask : NDArray, optional
616
- The output array to write the rotation of `arr_mask` to.
617
- """
618
- rotate_mask = arr_mask is not None and mask_coordinates is not None
619
- return_type = (out is None) + 2 * rotate_mask * (out_mask is None)
620
-
621
- coordinates_rotated = np.empty(coordinates.shape, dtype=rotation_matrix.dtype)
622
- mask_rotated = (
623
- np.empty(mask_coordinates.shape, dtype=rotation_matrix.dtype)
624
- if rotate_mask
625
- else None
626
- )
627
-
628
- center = np.array(arr.shape) // 2 if use_geometric_center else None
629
- if translation is None:
630
- translation = np.zeros(coordinates_rotated.shape[0])
346
+ order: int = 3,
347
+ cache: bool = False,
348
+ ) -> Tuple[NDArray, NDArray]:
349
+ translation = self.zeros(arr.ndim) if translation is None else translation
350
+ center = self.divide(self.to_backend_array(arr.shape), 2)
351
+ if not use_geometric_center:
352
+ center = self.center_of_mass(arr, cutoff=0)
631
353
 
632
- rigid_transform(
633
- coordinates=coordinates,
634
- coordinates_mask=mask_coordinates,
635
- out=coordinates_rotated,
636
- out_mask=mask_rotated,
354
+ matrix = self._rigid_transform_matrix(
637
355
  rotation_matrix=rotation_matrix,
638
356
  translation=translation,
639
- use_geometric_center=use_geometric_center,
640
357
  center=center,
641
358
  )
359
+ out = self.zeros_like(arr) if out is None else out
642
360
 
643
- coordinates_rotated = coordinates_rotated.astype(int)
644
- offset = coordinates_rotated.min(axis=1)
645
- np.multiply(offset, offset < 0, out=offset)
646
- coordinates_rotated -= offset[:, None]
647
-
648
- out_offset = np.zeros(
649
- coordinates_rotated.shape[0], dtype=coordinates_rotated.dtype
361
+ self._rigid_transform(
362
+ data=arr,
363
+ matrix=matrix,
364
+ output=out,
365
+ order=order,
366
+ prefilter=True,
367
+ cache=cache,
650
368
  )
651
- if out is None:
652
- out_offset = coordinates_rotated.min(axis=1)
653
- coordinates_rotated -= out_offset[:, None]
654
- out = np.zeros(coordinates_rotated.max(axis=1) + 1, dtype=arr.dtype)
655
-
656
- if rotate_mask:
657
- mask_rotated = mask_rotated.astype(int)
658
- if out_mask is None:
659
- mask_rotated -= out_offset[:, None]
660
- out_mask = np.zeros(
661
- coordinates_rotated.max(axis=1) + 1, dtype=arr.dtype
662
- )
663
-
664
- in_box = np.logical_and(
665
- mask_rotated < np.array(out_mask.shape)[:, None],
666
- mask_rotated >= 0,
667
- ).min(axis=0)
668
- out_of_box = np.invert(in_box).sum()
669
- if out_of_box != 0:
670
- print(
671
- f"{out_of_box} elements out of bounds. Perhaps increase"
672
- " *arr_mask* size."
673
- )
674
-
675
- mask_coordinates = tuple(mask_coordinates[:, in_box])
676
- mask_rotated = tuple(mask_rotated[:, in_box])
677
- np.add.at(out_mask, mask_rotated, arr_mask[mask_coordinates])
678
-
679
- # Negative coordinates would be (mis)interpreted as reverse index
680
- in_box = np.logical_and(
681
- coordinates_rotated < np.array(out.shape)[:, None], coordinates_rotated >= 0
682
- ).min(axis=0)
683
- out_of_box = np.invert(in_box).sum()
684
- if out_of_box != 0:
685
- print(f"{out_of_box} elements out of bounds. Perhaps increase *out* size.")
686
-
687
- coordinates = coordinates[:, in_box]
688
- coordinates_rotated = coordinates_rotated[:, in_box]
689
-
690
- coordinates = tuple(coordinates)
691
- coordinates_rotated = tuple(coordinates_rotated)
692
- np.add.at(out, coordinates_rotated, arr[coordinates])
693
-
694
- match return_type:
695
- case 0:
696
- return None
697
- case 1:
698
- return out
699
- case 2:
700
- return out_mask
701
- case 3:
702
- return out, out_mask
703
-
704
- def center_of_mass(self, arr: NDArray, cutoff: float = None) -> NDArray:
369
+
370
+ # Applying the prefilter leads to artifacts in the mask.
371
+ if arr_mask is not None:
372
+ out_mask = self.zeros_like(arr_mask) if out_mask is None else out_mask
373
+ self._rigid_transform(
374
+ data=arr_mask,
375
+ matrix=matrix,
376
+ output=out_mask,
377
+ order=order,
378
+ prefilter=False,
379
+ cache=cache,
380
+ )
381
+
382
+ return out, out_mask
383
+
384
+ def center_of_mass(self, arr: BackendArray, cutoff: float = None) -> BackendArray:
705
385
  """
706
386
  Computes the center of mass of a numpy ndarray instance using all available
707
387
  elements. For template matching it typically makes sense to only input
@@ -709,7 +389,7 @@ class NumpyFFTWBackend(MatchingBackend):
709
389
 
710
390
  Parameters
711
391
  ----------
712
- arr : NDArray
392
+ arr : BackendArray
713
393
  Array to compute the center of mass of.
714
394
  cutoff : float, optional
715
395
  Densities less than or equal to cutoff are nullified for center
@@ -717,23 +397,24 @@ class NumpyFFTWBackend(MatchingBackend):
717
397
 
718
398
  Returns
719
399
  -------
720
- NDArray
400
+ BackendArray
721
401
  Center of mass with shape (arr.ndim).
722
402
  """
723
- cutoff = arr.min() - 1 if cutoff is None else cutoff
724
- arr = self._array_backend.where(arr > cutoff, arr, 0)
403
+ cutoff = self.min(arr) - 1 if cutoff is None else cutoff
404
+
405
+ arr = self.where(arr > cutoff, arr, 0)
725
406
  denominator = self.sum(arr)
726
- grids = self._array_backend.ogrid[tuple(slice(0, i) for i in arr.shape)]
727
- grids = [grid.astype(self._default_dtype) for grid in grids]
728
-
729
- center_of_mass = self.array(
730
- [
731
- self.sum(self.multiply(arr, grids[dim])) / denominator
732
- for dim in range(arr.ndim)
733
- ]
734
- )
735
407
 
736
- return center_of_mass
408
+ grids = []
409
+ for i, x in enumerate(arr.shape):
410
+ baseline_dims = tuple(1 if i != t else x for t in range(len(arr.shape)))
411
+ grids.append(
412
+ self.reshape(self.arange(x, dtype=self._float_dtype), baseline_dims)
413
+ )
414
+
415
+ center_of_mass = [self.sum((arr * grid) / denominator) for grid in grids]
416
+
417
+ return self.to_backend_array(center_of_mass)
737
418
 
738
419
  def get_available_memory(self) -> int:
739
420
  return virtual_memory().available
@@ -743,6 +424,7 @@ class NumpyFFTWBackend(MatchingBackend):
743
424
  yield None
744
425
 
745
426
  def device_count(self) -> int:
427
+ """Returns the number of available GPU devices."""
746
428
  return 1
747
429
 
748
430
  @staticmethod
@@ -764,26 +446,86 @@ class NumpyFFTWBackend(MatchingBackend):
764
446
 
765
447
  def max_score_over_rotations(
766
448
  self,
767
- score_space: NDArray,
768
- internal_scores: NDArray,
769
- internal_rotations: NDArray,
449
+ scores: BackendArray,
450
+ max_scores: BackendArray,
451
+ rotations: BackendArray,
770
452
  rotation_index: int,
771
453
  ) -> None:
772
454
  """
773
- Modify internal_scores and internal_rotations inplace with scores and rotation
774
- index respectively, wherever score_sapce is larger than internal scores.
455
+ Update elements in ``max_scores`` and ``rotations`` where scores is larger than
456
+ max_scores with score and rotation_index, respectivelty.
457
+
458
+ .. warning:: ``max_scores`` and ``rotations`` are modified in-place.
775
459
 
776
460
  Parameters
777
461
  ----------
778
- score_space : numpy.ndarray
779
- The score space to compare against internal_scores.
780
- internal_scores : numpy.ndarray
781
- The internal scores to update with maximum scores.
782
- internal_rotations : numpy.ndarray
783
- The internal rotations corresponding to the maximum scores.
462
+ scores : BackendArray
463
+ The score space to compare against max_scores.
464
+ max_scores : BackendArray
465
+ Maximum score observed for each element in an array.
466
+ rotations : BackendArray
467
+ Rotation used to achieve a given max_score.
784
468
  rotation_index : int
785
469
  The index representing the current rotation.
470
+
471
+ Returns
472
+ -------
473
+ Tuple[BackendArray, BackendArray]
474
+ Updated ``max_scores`` and ``rotations``.
475
+ """
476
+ indices = scores > max_scores
477
+ max_scores[indices] = scores[indices]
478
+ rotations[indices] = rotation_index
479
+ return max_scores, rotations
480
+
481
+ def norm_scores(
482
+ self,
483
+ arr: BackendArray,
484
+ exp_sq: BackendArray,
485
+ sq_exp: BackendArray,
486
+ n_obs: int,
487
+ eps: float,
488
+ out: BackendArray,
489
+ ) -> BackendArray:
490
+ """
491
+ Normalizes ``arr`` by the standard deviation ensuring numerical stability.
492
+
493
+ Parameters
494
+ ----------
495
+ arr : BackendArray
496
+ The input array to be normalized.
497
+ exp_sq : BackendArray
498
+ Non-normalized expectation square.
499
+ sq_exp : BackendArray
500
+ Non-normalized expectation.
501
+ n_obs : int
502
+ Number of observations for normalization.
503
+ eps : float
504
+ Numbers below this threshold will be ignored in division.
505
+ out : BackendArray
506
+ Output array to write the result to.
507
+
508
+ Returns
509
+ -------
510
+ BackendArray
511
+ The normalized array with the same shape as `arr`.
512
+
513
+ See Also
514
+ --------
515
+ :py:meth:`tme.matching_exhaustive.flc_scoring`
786
516
  """
787
- indices = score_space > internal_scores
788
- internal_scores[indices] = score_space[indices]
789
- internal_rotations[indices] = rotation_index
517
+ # Squared expected value (E(X)^2)
518
+ sq_exp = self.divide(sq_exp, n_obs, out=sq_exp)
519
+ sq_exp = self.square(sq_exp, out=sq_exp)
520
+ # Expected squared value (E(X^2))
521
+ exp_sq = self.divide(exp_sq, n_obs, out=exp_sq)
522
+ # Variance
523
+ sq_exp = self.subtract(exp_sq, sq_exp, out=sq_exp)
524
+ sq_exp = self.maximum(sq_exp, 0.0, out=sq_exp)
525
+ sq_exp = self.sqrt(sq_exp, out=sq_exp)
526
+
527
+ # Assume that low stdev regions also have low scores
528
+ # See :py:meth:`tme.matching_exhaustive.flcSphericalMask_setup` for correct norm
529
+ sq_exp[sq_exp < eps] = 1
530
+ sq_exp = self.multiply(sq_exp, n_obs, out=sq_exp)
531
+ return self.divide(arr, sq_exp, out=out)