pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.3__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/match_template.py +219 -216
  2. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/postprocess.py +86 -54
  3. pytme-0.2.3.data/scripts/preprocess.py +132 -0
  4. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +181 -94
  5. pytme-0.2.3.dist-info/METADATA +92 -0
  6. pytme-0.2.3.dist-info/RECORD +75 -0
  7. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/WHEEL +1 -1
  8. pytme-0.2.1.data/scripts/preprocess.py → scripts/eval.py +1 -1
  9. scripts/extract_candidates.py +20 -13
  10. scripts/match_template.py +219 -216
  11. scripts/match_template_filters.py +154 -95
  12. scripts/postprocess.py +86 -54
  13. scripts/preprocess.py +95 -56
  14. scripts/preprocessor_gui.py +181 -94
  15. scripts/refine_matches.py +265 -61
  16. tme/__init__.py +0 -1
  17. tme/__version__.py +1 -1
  18. tme/analyzer.py +458 -813
  19. tme/backends/__init__.py +40 -11
  20. tme/backends/_jax_utils.py +187 -0
  21. tme/backends/cupy_backend.py +109 -226
  22. tme/backends/jax_backend.py +230 -152
  23. tme/backends/matching_backend.py +445 -384
  24. tme/backends/mlx_backend.py +32 -59
  25. tme/backends/npfftw_backend.py +240 -507
  26. tme/backends/pytorch_backend.py +30 -151
  27. tme/density.py +248 -371
  28. tme/extensions.cpython-311-darwin.so +0 -0
  29. tme/matching_data.py +328 -284
  30. tme/matching_exhaustive.py +195 -1499
  31. tme/matching_optimization.py +143 -106
  32. tme/matching_scores.py +887 -0
  33. tme/matching_utils.py +287 -388
  34. tme/memory.py +377 -0
  35. tme/orientations.py +78 -21
  36. tme/parser.py +3 -4
  37. tme/preprocessing/_utils.py +61 -32
  38. tme/preprocessing/composable_filter.py +7 -4
  39. tme/preprocessing/compose.py +7 -3
  40. tme/preprocessing/frequency_filters.py +49 -39
  41. tme/preprocessing/tilt_series.py +44 -72
  42. tme/preprocessor.py +560 -526
  43. tme/structure.py +491 -188
  44. tme/types.py +5 -3
  45. pytme-0.2.1.dist-info/METADATA +0 -73
  46. pytme-0.2.1.dist-info/RECORD +0 -73
  47. tme/helpers.py +0 -881
  48. tme/matching_constrained.py +0 -195
  49. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
  50. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
@@ -4,26 +4,79 @@
4
4
 
5
5
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
6
  """
7
-
7
+ import os
8
8
  from typing import Tuple, Dict, List
9
+ from contextlib import contextmanager
9
10
  from multiprocessing import shared_memory
10
11
  from multiprocessing.managers import SharedMemoryManager
11
- from contextlib import contextmanager
12
12
 
13
13
  import numpy as np
14
14
  from psutil import virtual_memory
15
- from numpy.typing import NDArray
15
+ from scipy.ndimage import maximum_filter, affine_transform
16
16
  from pyfftw import zeros_aligned, simd_alignment, FFTW, next_fast_len
17
17
  from pyfftw.builders import rfftn as rfftn_builder, irfftn as irfftn_builder
18
- from scipy.ndimage import maximum_filter, affine_transform
19
18
 
20
- from .matching_backend import MatchingBackend
21
- from ..matching_utils import rigid_transform
19
+ from ..types import NDArray, BackendArray, shm_type
20
+ from .matching_backend import MatchingBackend, _create_metafunction
21
+
22
+ os.environ["MKL_NUM_THREADS"] = "1"
23
+ os.environ["OMP_NUM_THREADS"] = "1"
24
+ os.environ["PYFFTW_NUM_THREADS"] = "1"
25
+ os.environ["OPENBLAS_NUM_THREADS"] = "1"
26
+
27
+
28
+ def create_ufuncs(obj):
29
+ ufuncs = [
30
+ "add",
31
+ "subtract",
32
+ "multiply",
33
+ "divide",
34
+ "mod",
35
+ "sum",
36
+ "where",
37
+ "einsum",
38
+ "mean",
39
+ "einsum",
40
+ "std",
41
+ "max",
42
+ "min",
43
+ "maximum",
44
+ "minimum",
45
+ "sqrt",
46
+ "square",
47
+ "abs",
48
+ "power",
49
+ "full",
50
+ "clip",
51
+ "arange",
52
+ "stack",
53
+ "concatenate",
54
+ "repeat",
55
+ "indices",
56
+ "unique",
57
+ "argsort",
58
+ "tril_indices",
59
+ "reshape",
60
+ "identity",
61
+ "dot",
62
+ ]
63
+ for ufunc in ufuncs:
64
+ setattr(obj, ufunc, _create_metafunction(ufunc))
65
+ return obj
66
+
67
+
68
+ @create_ufuncs
69
+ class _NumpyWrapper:
70
+ """
71
+ MatchingBackend prohibits using create_ufuncs on NumpyFFTWBackend directly.
72
+ """
73
+
74
+ pass
22
75
 
23
76
 
24
- class NumpyFFTWBackend(MatchingBackend):
77
+ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
25
78
  """
26
- A numpy and pyfftw based backend for template matching.
79
+ A numpy and pyfftw-based matching backend.
27
80
  """
28
81
 
29
82
  def __init__(
@@ -50,7 +103,7 @@ class NumpyFFTWBackend(MatchingBackend):
50
103
  return self._array_backend.asarray(arr)
51
104
 
52
105
  def to_numpy_array(self, arr: NDArray) -> NDArray:
53
- return arr
106
+ return np.array(arr)
54
107
 
55
108
  def to_cpu_array(self, arr: NDArray) -> NDArray:
56
109
  return arr
@@ -68,147 +121,30 @@ class NumpyFFTWBackend(MatchingBackend):
68
121
  def free_cache(self):
69
122
  pass
70
123
 
71
- def add(self, x1, x2, *args, **kwargs) -> NDArray:
72
- x1 = self.to_backend_array(x1)
73
- x2 = self.to_backend_array(x2)
74
- return self._array_backend.add(x1, x2, *args, **kwargs)
75
-
76
- def subtract(self, x1, x2, *args, **kwargs) -> NDArray:
77
- x1 = self.to_backend_array(x1)
78
- x2 = self.to_backend_array(x2)
79
- return self._array_backend.subtract(x1, x2, *args, **kwargs)
80
-
81
- def multiply(self, x1, x2, *args, **kwargs) -> NDArray:
82
- x1 = self.to_backend_array(x1)
83
- x2 = self.to_backend_array(x2)
84
- return self._array_backend.multiply(x1, x2, *args, **kwargs)
85
-
86
- def divide(self, x1, x2, *args, **kwargs) -> NDArray:
87
- x1 = self.to_backend_array(x1)
88
- x2 = self.to_backend_array(x2)
89
- return self._array_backend.divide(x1, x2, *args, **kwargs)
90
-
91
- def mod(self, x1, x2, *args, **kwargs):
92
- x1 = self.to_backend_array(x1)
93
- x2 = self.to_backend_array(x2)
94
- return self._array_backend.mod(x1, x2, *args, **kwargs)
95
-
96
- def sum(self, *args, **kwargs) -> NDArray:
97
- return self._array_backend.sum(*args, **kwargs)
98
-
99
- def einsum(self, *args, **kwargs) -> NDArray:
100
- return self._array_backend.einsum(*args, **kwargs)
101
-
102
- def mean(self, *args, **kwargs) -> NDArray:
103
- return self._array_backend.mean(*args, **kwargs)
104
-
105
- def std(self, *args, **kwargs) -> NDArray:
106
- return self._array_backend.std(*args, **kwargs)
107
-
108
- def max(self, *args, **kwargs) -> NDArray:
109
- return self._array_backend.max(*args, **kwargs)
110
-
111
- def min(self, *args, **kwargs) -> NDArray:
112
- return self._array_backend.min(*args, **kwargs)
113
-
114
- def maximum(self, x1, x2, *args, **kwargs) -> NDArray:
115
- x1 = self.to_backend_array(x1)
116
- x2 = self.to_backend_array(x2)
117
- return self._array_backend.maximum(x1, x2, *args, **kwargs)
118
-
119
- def minimum(self, x1, x2, *args, **kwargs) -> NDArray:
120
- x1 = self.to_backend_array(x1)
121
- x2 = self.to_backend_array(x2)
122
- return self._array_backend.minimum(x1, x2, *args, **kwargs)
123
-
124
- def sqrt(self, *args, **kwargs) -> NDArray:
125
- return self._array_backend.sqrt(*args, **kwargs)
126
-
127
- def square(self, *args, **kwargs) -> NDArray:
128
- return self._array_backend.square(*args, **kwargs)
129
-
130
- def abs(self, *args, **kwargs) -> NDArray:
131
- return self._array_backend.abs(*args, **kwargs)
132
-
133
124
  def transpose(self, arr):
134
125
  return arr.T
135
126
 
136
- def power(self, *args, **kwargs):
137
- return self._array_backend.power(*args, **kwargs)
138
-
139
127
  def tobytes(self, arr):
140
128
  return arr.tobytes()
141
129
 
142
130
  def size(self, arr):
143
131
  return arr.size
144
132
 
145
- def fill(self, arr: NDArray, value: float) -> None:
133
+ def fill(self, arr: NDArray, value: float) -> NDArray:
146
134
  arr.fill(value)
147
-
148
- def zeros(self, shape, dtype=np.float64) -> NDArray:
149
- return self._array_backend.zeros(shape=shape, dtype=dtype)
150
-
151
- def full(self, shape, fill_value, dtype=None, **kwargs) -> NDArray:
152
- return self._array_backend.full(
153
- shape, dtype=dtype, fill_value=fill_value, **kwargs
154
- )
135
+ return arr
155
136
 
156
137
  def eps(self, dtype: type) -> NDArray:
157
- """
158
- Returns the eps defined as difference between 1.0 and the next
159
- representable floating point value larger than 1.0.
160
-
161
- Parameters
162
- ----------
163
- dtype : type
164
- Data type for which eps should be returned.
165
-
166
- Returns
167
- -------
168
- Scalar
169
- The eps for the given data type
170
- """
171
138
  return self._array_backend.finfo(dtype).eps
172
139
 
173
140
  def datatype_bytes(self, dtype: type) -> NDArray:
174
- """
175
- Return the number of bytes occupied by a given datatype.
176
-
177
- Parameters
178
- ----------
179
- dtype : type
180
- Datatype for which the number of bytes is to be determined.
181
-
182
- Returns
183
- -------
184
- int
185
- Number of bytes occupied by the datatype.
186
- """
187
141
  temp = self._array_backend.zeros(1, dtype=dtype)
188
142
  return temp.nbytes
189
143
 
190
- def clip(self, *args, **kwargs) -> NDArray:
191
- return self._array_backend.clip(*args, **kwargs)
192
-
193
- def flip(self, a, axis, **kwargs):
194
- return self._array_backend.flip(a, axis, **kwargs)
195
-
196
144
  @staticmethod
197
145
  def astype(arr, dtype):
198
146
  return arr.astype(dtype)
199
147
 
200
- def arange(self, *args, **kwargs):
201
- return self._array_backend.arange(*args, **kwargs)
202
-
203
- def stack(self, *args, **kwargs):
204
- return self._array_backend.stack(*args, **kwargs)
205
-
206
- def concatenate(self, *args, **kwargs):
207
- return self._array_backend.concatenate(*args, **kwargs)
208
-
209
- def repeat(self, *args, **kwargs):
210
- return self._array_backend.repeat(*args, **kwargs)
211
-
212
148
  def topk_indices(self, arr: NDArray, k: int):
213
149
  temp = arr.reshape(-1)
214
150
  indices = self._array_backend.argpartition(temp, -k)[-k:][:k]
@@ -227,18 +163,9 @@ class NumpyFFTWBackend(MatchingBackend):
227
163
  **kwargs,
228
164
  )
229
165
 
230
- def unique(self, *args, **kwargs):
231
- return self._array_backend.unique(*args, **kwargs)
232
-
233
- def argsort(self, *args, **kwargs):
234
- return self._array_backend.argsort(*args, **kwargs)
235
-
236
166
  def unravel_index(self, indices, shape):
237
167
  return self._array_backend.unravel_index(indices=indices, shape=shape)
238
168
 
239
- def tril_indices(self, *args, **kwargs):
240
- return self._array_backend.tril_indices(*args, **kwargs)
241
-
242
169
  def max_filter_coordinates(self, score_space, min_distance: Tuple[int]):
243
170
  score_box = tuple(min_distance for _ in range(score_space.ndim))
244
171
  max_filter = maximum_filter(score_space, size=score_box, mode="constant")
@@ -248,93 +175,27 @@ class NumpyFFTWBackend(MatchingBackend):
248
175
  return peaks
249
176
 
250
177
  @staticmethod
251
- def preallocate_array(shape: Tuple[int], dtype: type) -> NDArray:
252
- """
253
- Returns a byte-aligned array of zeros with specified shape and dtype.
254
-
255
- Parameters
256
- ----------
257
- shape : Tuple[int]
258
- Desired shape for the array.
259
- dtype : type
260
- Desired data type for the array.
261
-
262
- Returns
263
- -------
264
- NDArray
265
- Byte-aligned array of zeros with specified shape and dtype.
266
- """
178
+ def zeros(shape: Tuple[int], dtype: type = None) -> NDArray:
267
179
  arr = zeros_aligned(shape, dtype=dtype, n=simd_alignment)
268
180
  return arr
269
181
 
270
- def sharedarr_to_arr(
271
- self, shm: shared_memory.SharedMemory, shape: Tuple[int], dtype: str
272
- ) -> NDArray:
273
- """
274
- Returns an array of given shape and dtype from shared memory location.
275
-
276
- Parameters
277
- ----------
278
- shape : tuple
279
- Tuple of integers specifying the shape of the array.
280
- dtype : str
281
- String specifying the dtype of the array.
282
- shm : shared_memory.SharedMemory
283
- Shared memory object where the array is stored.
284
-
285
- Returns
286
- -------
287
- NDArray
288
- Array of the specified shape and dtype from the shared memory location.
289
- """
182
+ def from_sharedarr(self, args) -> NDArray:
183
+ shm, shape, dtype = args
290
184
  return self.ndarray(shape, dtype, shm.buf)
291
185
 
292
- def arr_to_sharedarr(
186
+ def to_sharedarr(
293
187
  self, arr: NDArray, shared_memory_handler: type = None
294
- ) -> shared_memory.SharedMemory:
295
- """
296
- Converts a numpy array to an object shared in memory.
297
-
298
- Parameters
299
- ----------
300
- arr : NDArray
301
- Numpy array to convert.
302
- shared_memory_handler : type, optional
303
- The type of shared memory handler. Default is None.
304
-
305
- Returns
306
- -------
307
- shared_memory.SharedMemory
308
- The shared memory object containing the numpy array.
309
- """
310
- if type(shared_memory_handler) == SharedMemoryManager:
188
+ ) -> shm_type:
189
+ if isinstance(shared_memory_handler, SharedMemoryManager):
311
190
  shm = shared_memory_handler.SharedMemory(size=arr.nbytes)
312
191
  else:
313
192
  shm = shared_memory.SharedMemory(create=True, size=arr.nbytes)
314
193
  np_array = self.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
315
194
  np_array[:] = arr[:].copy()
316
- return shm
195
+ return shm, arr.shape, arr.dtype
317
196
 
318
197
  def topleft_pad(self, arr: NDArray, shape: Tuple[int], padval: int = 0) -> NDArray:
319
- """
320
- Returns an array that has been padded to a specified shape with a padding
321
- value at the top-left corner.
322
-
323
- Parameters
324
- ----------
325
- arr : NDArray
326
- Input array to be padded.
327
- shape : Tuple[int]
328
- Desired shape for the output array.
329
- padval : int, optional
330
- Value to use for padding, default is 0.
331
-
332
- Returns
333
- -------
334
- NDArray
335
- Array that has been padded to the specified shape.
336
- """
337
- b = self.preallocate_array(shape, arr.dtype)
198
+ b = self.zeros(shape, arr.dtype)
338
199
  self.add(b, padval, out=b)
339
200
  aind = [slice(None, None)] * arr.ndim
340
201
  bind = [slice(None, None)] * arr.ndim
@@ -357,40 +218,10 @@ class NumpyFFTWBackend(MatchingBackend):
357
218
  temp_real: NDArray = None,
358
219
  temp_fft: NDArray = None,
359
220
  ) -> Tuple[FFTW, FFTW]:
360
- """
361
- Build pyFFTW builder functions.
362
-
363
- Parameters
364
- ----------
365
- fast_shape : tuple
366
- Tuple of integers corresponding to fast convolution shape
367
- (see `compute_convolution_shapes`).
368
- fast_ft_shape : tuple
369
- Tuple of integers corresponding to the shape of the fourier
370
- transform array (see `compute_convolution_shapes`).
371
- real_dtype : dtype
372
- Numpy dtype of the inverse fourier transform.
373
- complex_dtype : dtype
374
- Numpy dtype of the fourier transform.
375
- inverse_fast_shape : tuple, optional
376
- Output shape of the inverse Fourier transform. By default fast_shape.
377
- fftargs : dict, optional
378
- Dictionary passed to pyFFTW builders.
379
- temp_real : NDArray, optional
380
- Temporary real numpy array, by default None.
381
- temp_fft : NDArray, optional
382
- Temporary fft numpy array, by default None.
383
-
384
- Returns
385
- -------
386
- tuple
387
- Tuple containing callable pyFFTW objects for forward and inverse
388
- fourier transform.
389
- """
390
221
  if temp_real is None:
391
- temp_real = self.preallocate_array(fast_shape, real_dtype)
222
+ temp_real = self.zeros(fast_shape, real_dtype)
392
223
  if temp_fft is None:
393
- temp_fft = self.preallocate_array(fast_ft_shape, complex_dtype)
224
+ temp_fft = self.zeros(fast_ft_shape, complex_dtype)
394
225
  if inverse_fast_shape is None:
395
226
  inverse_fast_shape = fast_shape
396
227
 
@@ -419,30 +250,11 @@ class NumpyFFTWBackend(MatchingBackend):
419
250
  return rfftn, irfftn
420
251
 
421
252
  def extract_center(self, arr: NDArray, newshape: Tuple[int]) -> NDArray:
422
- """
423
- Extract the centered portion of an array based on a new shape.
424
-
425
- Parameters
426
- ----------
427
- arr : NDArray
428
- Input array.
429
- newshape : tuple
430
- Desired shape for the central portion.
431
-
432
- Returns
433
- -------
434
- NDArray
435
- Central portion of the array with shape `newshape`.
436
-
437
- References
438
- ----------
439
- .. [1] https://github.com/scipy/scipy/blob/v1.11.2/scipy/signal/_signaltools.py
440
- """
441
253
  new_shape = self.to_backend_array(newshape)
442
254
  current_shape = self.to_backend_array(arr.shape)
443
255
  starts = self.subtract(current_shape, new_shape)
444
256
  starts = self.astype(self.divide(starts, 2), self._int_dtype)
445
- stops = self.astype(self.add(starts, newshape), self._int_dtype)
257
+ stops = self.astype(self.add(starts, new_shape), self._int_dtype)
446
258
  box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
447
259
  return arr[box]
448
260
 
@@ -466,252 +278,111 @@ class NumpyFFTWBackend(MatchingBackend):
466
278
  fourier transform, shape of the forward fourier transform
467
279
  (see :py:meth:`build_fft`).
468
280
  """
469
- convolution_shape = [
470
- int(x) + int(y) - 1 for x, y in zip(arr1_shape, arr2_shape)
471
- ]
281
+ convolution_shape = [int(x + y - 1) for x, y in zip(arr1_shape, arr2_shape)]
472
282
  fast_shape = [next_fast_len(x) for x in convolution_shape]
473
283
  fast_ft_shape = list(fast_shape[:-1]) + [fast_shape[-1] // 2 + 1]
474
284
 
475
285
  return convolution_shape, fast_shape, fast_ft_shape
476
286
 
477
- def rotate_array(
287
+ def _rigid_transform_matrix(
478
288
  self,
479
- arr: NDArray,
480
- rotation_matrix: NDArray,
481
- arr_mask: NDArray = None,
289
+ rotation_matrix: NDArray = None,
482
290
  translation: NDArray = None,
483
- use_geometric_center: bool = False,
484
- out: NDArray = None,
485
- out_mask: NDArray = None,
486
- order: int = 3,
487
- ) -> None:
488
- """
489
- Rotates coordinates of arr according to rotation_matrix.
490
-
491
- If no output array is provided, this method will compute an array with
492
- sufficient space to hold all elements. If both `arr` and `arr_mask`
493
- are provided, `arr_mask` will be centered according to arr.
494
-
495
- Parameters
496
- ----------
497
- arr : NDArray
498
- The input array to be rotated.
499
- arr_mask : NDArray, optional
500
- The mask of `arr` that will be equivalently rotated.
501
- rotation_matrix : NDArray
502
- The rotation matrix to apply [d x d].
503
- translation : NDArray
504
- The translation to apply [d].
505
- use_geometric_center : bool, optional
506
- Whether the rotation should be centered around the geometric
507
- or mass center. Default is mass center.
508
- out : NDArray, optional
509
- The output array to write the rotation of `arr` to.
510
- out_mask : NDArray, optional
511
- The output array to write the rotation of `arr_mask` to.
512
- order : int, optional
513
- Spline interpolation order. Has to be in the range 0-5. Non-zero
514
- elements will be converted into a point-cloud and rotated according
515
- to ``rotation_matrix`` if order is None.
516
- """
291
+ center: NDArray = None,
292
+ ) -> NDArray:
293
+ ndim = rotation_matrix.shape[0]
294
+ matrix = self.identity(ndim + 1, dtype=self._float_dtype)
517
295
 
518
- if order is None:
519
- mask_coordinates = None
520
- if arr_mask is not None:
521
- mask_coordinates = np.array(np.where(arr_mask > 0))
522
- return self.rotate_array_coordinates(
523
- arr=arr,
524
- arr_mask=arr_mask,
525
- coordinates=np.array(np.where(arr > 0)),
526
- mask_coordinates=mask_coordinates,
527
- out=out,
528
- out_mask=out_mask,
529
- rotation_matrix=rotation_matrix,
530
- translation=translation,
531
- use_geometric_center=use_geometric_center,
532
- )
296
+ if translation is not None:
297
+ translation_matrix = self.identity(ndim + 1, dtype=self._float_dtype)
298
+ translation_matrix[:ndim, ndim] = -translation
299
+ self.dot(matrix, translation_matrix, out=matrix)
533
300
 
534
- rotate_mask = arr_mask is not None
535
- return_type = (out is None) + 2 * rotate_mask * (out_mask is None)
536
- translation = np.zeros(arr.ndim) if translation is None else translation
301
+ if center is not None:
302
+ center_matrix = self.identity(ndim + 1, dtype=self._float_dtype)
303
+ center_matrix[:ndim, ndim] = center
304
+ self.dot(matrix, center_matrix, out=matrix)
537
305
 
538
- center = np.divide(arr.shape, 2)
539
- if not use_geometric_center:
540
- center = self.center_of_mass(arr, cutoff=0)
306
+ if rotation_matrix is not None:
307
+ rmat = self.identity(ndim + 1, dtype=self._float_dtype)
308
+ rmat[:ndim, :ndim] = self._array_backend.linalg.inv(rotation_matrix)
309
+ self.dot(matrix, rmat, out=matrix)
541
310
 
542
- rotation_matrix_inverted = np.linalg.inv(rotation_matrix)
543
- transformed_center = rotation_matrix_inverted @ center.reshape(-1, 1)
544
- transformed_center = transformed_center.reshape(-1)
545
- base_offset = np.subtract(center, transformed_center)
546
- offset = np.subtract(base_offset, translation)
311
+ if center is not None:
312
+ center_matrix[:ndim, ndim] = -center_matrix[:ndim, ndim]
313
+ self.dot(matrix, center_matrix, out=matrix)
547
314
 
548
- out = np.zeros_like(arr) if out is None else out
549
- out_slice = tuple(slice(0, stop) for stop in arr.shape)
315
+ matrix /= matrix[ndim, ndim]
316
+ return matrix
550
317
 
551
- # Applying the prefilter can cause artifacts in the mask
552
- affine_transform(
553
- input=arr,
554
- matrix=rotation_matrix_inverted,
555
- offset=offset,
318
+ def _rigid_transform(
319
+ self,
320
+ data: NDArray,
321
+ matrix: NDArray,
322
+ output: NDArray,
323
+ prefilter: bool,
324
+ order: int,
325
+ cache: bool = False,
326
+ ) -> None:
327
+ out_slice = tuple(slice(0, stop) for stop in data.shape)
328
+ self.affine_transform(
329
+ input=data,
330
+ matrix=matrix,
556
331
  mode="constant",
557
- output=out[out_slice],
332
+ output=output[out_slice],
558
333
  order=order,
559
- prefilter=True,
334
+ prefilter=prefilter,
560
335
  )
561
336
 
562
- if rotate_mask:
563
- out_mask = np.zeros_like(arr_mask) if out_mask is None else out_mask
564
- out_mask_slice = tuple(slice(0, stop) for stop in arr_mask.shape)
565
- affine_transform(
566
- input=arr_mask,
567
- matrix=rotation_matrix_inverted,
568
- offset=offset,
569
- mode="constant",
570
- output=out_mask[out_mask_slice],
571
- order=order,
572
- prefilter=False,
573
- )
574
-
575
- match return_type:
576
- case 0:
577
- return None
578
- case 1:
579
- return out
580
- case 2:
581
- return out_mask
582
- case 3:
583
- return out, out_mask
584
-
585
- @staticmethod
586
- def rotate_array_coordinates(
337
+ def rigid_transform(
338
+ self,
587
339
  arr: NDArray,
588
- coordinates: NDArray,
589
340
  rotation_matrix: NDArray,
341
+ arr_mask: NDArray = None,
590
342
  translation: NDArray = None,
343
+ use_geometric_center: bool = False,
591
344
  out: NDArray = None,
592
- use_geometric_center: bool = True,
593
- arr_mask: NDArray = None,
594
- mask_coordinates: NDArray = None,
595
345
  out_mask: NDArray = None,
596
- ) -> None:
597
- """
598
- Rotates coordinates of arr according to rotation_matrix.
599
-
600
- If no output array is provided, this method will compute an array with
601
- sufficient space to hold all elements. If both `arr` and `arr_mask`
602
- are provided, `arr_mask` will be centered according to arr.
603
-
604
- No centering will be performed if the rotation matrix is the identity matrix.
605
-
606
- Parameters
607
- ----------
608
- arr : NDArray
609
- The input array to be rotated.
610
- coordinates : NDArray
611
- The pointcloud [d x N] containing elements of `arr` that should be rotated.
612
- See :py:meth:`Density.to_pointcloud` on how to obtain the coordinates.
613
- rotation_matrix : NDArray
614
- The rotation matrix to apply [d x d].
615
- rotation_matrix : NDArray
616
- The translation to apply [d].
617
- out : NDArray, optional
618
- The output array to write the rotation of `arr` to.
619
- use_geometric_center : bool, optional
620
- Whether the rotation should be centered around the geometric
621
- or mass center.
622
- arr_mask : NDArray, optional
623
- The mask of `arr` that will be equivalently rotated.
624
- mask_coordinates : NDArray, optional
625
- Equivalent to `coordinates`, but containing elements of `arr_mask`
626
- that should be rotated.
627
- out_mask : NDArray, optional
628
- The output array to write the rotation of `arr_mask` to.
629
- """
630
- rotate_mask = arr_mask is not None and mask_coordinates is not None
631
- return_type = (out is None) + 2 * rotate_mask * (out_mask is None)
632
-
633
- coordinates_rotated = np.empty(coordinates.shape, dtype=rotation_matrix.dtype)
634
- mask_rotated = (
635
- np.empty(mask_coordinates.shape, dtype=rotation_matrix.dtype)
636
- if rotate_mask
637
- else None
638
- )
346
+ order: int = 3,
347
+ cache: bool = False,
348
+ ) -> Tuple[NDArray, NDArray]:
349
+ translation = self.zeros(arr.ndim) if translation is None else translation
639
350
 
640
- center = np.array(arr.shape) // 2 if use_geometric_center else None
641
- if translation is None:
642
- translation = np.zeros(coordinates_rotated.shape[0])
351
+ center = self.divide(self.to_backend_array(arr.shape) - 1, 2)
352
+ if not use_geometric_center:
353
+ center = self.center_of_mass(arr, cutoff=0)
643
354
 
644
- rigid_transform(
645
- coordinates=coordinates,
646
- coordinates_mask=mask_coordinates,
647
- out=coordinates_rotated,
648
- out_mask=mask_rotated,
355
+ matrix = self._rigid_transform_matrix(
649
356
  rotation_matrix=rotation_matrix,
650
357
  translation=translation,
651
- use_geometric_center=use_geometric_center,
652
358
  center=center,
653
359
  )
360
+ out = self.zeros_like(arr) if out is None else out
654
361
 
655
- coordinates_rotated = coordinates_rotated.astype(int)
656
- offset = coordinates_rotated.min(axis=1)
657
- np.multiply(offset, offset < 0, out=offset)
658
- coordinates_rotated -= offset[:, None]
659
-
660
- out_offset = np.zeros(
661
- coordinates_rotated.shape[0], dtype=coordinates_rotated.dtype
362
+ self._rigid_transform(
363
+ data=arr,
364
+ matrix=matrix,
365
+ output=out,
366
+ order=order,
367
+ prefilter=True,
368
+ cache=cache,
662
369
  )
663
- if out is None:
664
- out_offset = coordinates_rotated.min(axis=1)
665
- coordinates_rotated -= out_offset[:, None]
666
- out = np.zeros(coordinates_rotated.max(axis=1) + 1, dtype=arr.dtype)
667
-
668
- if rotate_mask:
669
- mask_rotated = mask_rotated.astype(int)
670
- if out_mask is None:
671
- # mask_rotated -= out_offset[:, None]
672
- out_mask = np.zeros(mask_rotated.max(axis=1) + 1, dtype=arr.dtype)
673
-
674
- in_box = np.logical_and(
675
- mask_rotated < np.array(out_mask.shape)[:, None],
676
- mask_rotated >= 0,
677
- ).min(axis=0)
678
- out_of_box = np.invert(in_box).sum()
679
- if out_of_box != 0:
680
- print(
681
- f"{out_of_box} elements out of bounds. Perhaps increase"
682
- " *arr_mask* size."
683
- )
684
-
685
- mask_coordinates = tuple(mask_coordinates[:, in_box])
686
- mask_rotated = tuple(mask_rotated[:, in_box])
687
- np.add.at(out_mask, mask_rotated, arr_mask[mask_coordinates])
688
-
689
- # Negative coordinates would be (mis)interpreted as reverse index
690
- in_box = np.logical_and(
691
- coordinates_rotated < np.array(out.shape)[:, None], coordinates_rotated >= 0
692
- ).min(axis=0)
693
- out_of_box = np.invert(in_box).sum()
694
- if out_of_box != 0:
695
- print(f"{out_of_box} elements out of bounds. Perhaps increase *out* size.")
696
-
697
- coordinates = coordinates[:, in_box]
698
- coordinates_rotated = coordinates_rotated[:, in_box]
699
-
700
- coordinates = tuple(coordinates)
701
- coordinates_rotated = tuple(coordinates_rotated)
702
- np.add.at(out, coordinates_rotated, arr[coordinates])
703
-
704
- match return_type:
705
- case 0:
706
- return None
707
- case 1:
708
- return out
709
- case 2:
710
- return out_mask
711
- case 3:
712
- return out, out_mask
713
-
714
- def center_of_mass(self, arr: NDArray, cutoff: float = None) -> NDArray:
370
+
371
+ # Applying the prefilter leads to artifacts in the mask.
372
+ if arr_mask is not None:
373
+ out_mask = self.zeros_like(arr_mask) if out_mask is None else out_mask
374
+ self._rigid_transform(
375
+ data=arr_mask,
376
+ matrix=matrix,
377
+ output=out_mask,
378
+ order=order,
379
+ prefilter=False,
380
+ cache=cache,
381
+ )
382
+
383
+ return out, out_mask
384
+
385
+ def center_of_mass(self, arr: BackendArray, cutoff: float = None) -> BackendArray:
715
386
  """
716
387
  Computes the center of mass of a numpy ndarray instance using all available
717
388
  elements. For template matching it typically makes sense to only input
@@ -719,7 +390,7 @@ class NumpyFFTWBackend(MatchingBackend):
719
390
 
720
391
  Parameters
721
392
  ----------
722
- arr : NDArray
393
+ arr : BackendArray
723
394
  Array to compute the center of mass of.
724
395
  cutoff : float, optional
725
396
  Densities less than or equal to cutoff are nullified for center
@@ -727,23 +398,24 @@ class NumpyFFTWBackend(MatchingBackend):
727
398
 
728
399
  Returns
729
400
  -------
730
- NDArray
401
+ BackendArray
731
402
  Center of mass with shape (arr.ndim).
732
403
  """
733
- cutoff = arr.min() - 1 if cutoff is None else cutoff
734
- arr = self._array_backend.where(arr > cutoff, arr, 0)
404
+ cutoff = self.min(arr) - 1 if cutoff is None else cutoff
405
+
406
+ arr = self.where(arr > cutoff, arr, 0)
735
407
  denominator = self.sum(arr)
736
- grids = self._array_backend.ogrid[tuple(slice(0, i) for i in arr.shape)]
737
- grids = [grid.astype(self._float_dtype) for grid in grids]
738
-
739
- center_of_mass = self.array(
740
- [
741
- self.sum(self.multiply(arr, grids[dim]) / denominator)
742
- for dim in range(arr.ndim)
743
- ]
744
- )
745
408
 
746
- return center_of_mass
409
+ grids = []
410
+ for i, x in enumerate(arr.shape):
411
+ baseline_dims = tuple(1 if i != t else x for t in range(len(arr.shape)))
412
+ grids.append(
413
+ self.reshape(self.arange(x, dtype=self._float_dtype), baseline_dims)
414
+ )
415
+
416
+ center_of_mass = [self.sum((arr * grid) / denominator) for grid in grids]
417
+
418
+ return self.to_backend_array(center_of_mass)
747
419
 
748
420
  def get_available_memory(self) -> int:
749
421
  return virtual_memory().available
@@ -753,6 +425,7 @@ class NumpyFFTWBackend(MatchingBackend):
753
425
  yield None
754
426
 
755
427
  def device_count(self) -> int:
428
+ """Returns the number of available GPU devices."""
756
429
  return 1
757
430
 
758
431
  @staticmethod
@@ -774,26 +447,86 @@ class NumpyFFTWBackend(MatchingBackend):
774
447
 
775
448
  def max_score_over_rotations(
776
449
  self,
777
- score_space: NDArray,
778
- internal_scores: NDArray,
779
- internal_rotations: NDArray,
450
+ scores: BackendArray,
451
+ max_scores: BackendArray,
452
+ rotations: BackendArray,
780
453
  rotation_index: int,
781
454
  ) -> None:
782
455
  """
783
- Modify internal_scores and internal_rotations inplace with scores and rotation
784
- index respectively, wherever score_space is larger than internal scores.
456
+ Update elements in ``max_scores`` and ``rotations`` where scores is larger than
457
+ max_scores with score and rotation_index, respectively.
458
+
459
+ .. warning:: ``max_scores`` and ``rotations`` are modified in-place.
785
460
 
786
461
  Parameters
787
462
  ----------
788
- score_space : numpy.ndarray
789
- The score space to compare against internal_scores.
790
- internal_scores : numpy.ndarray
791
- The internal scores to update with maximum scores.
792
- internal_rotations : numpy.ndarray
793
- The internal rotations corresponding to the maximum scores.
463
+ scores : BackendArray
464
+ The score space to compare against max_scores.
465
+ max_scores : BackendArray
466
+ Maximum score observed for each element in an array.
467
+ rotations : BackendArray
468
+ Rotation used to achieve a given max_score.
794
469
  rotation_index : int
795
470
  The index representing the current rotation.
471
+
472
+ Returns
473
+ -------
474
+ Tuple[BackendArray, BackendArray]
475
+ Updated ``max_scores`` and ``rotations``.
476
+ """
477
+ indices = scores > max_scores
478
+ max_scores[indices] = scores[indices]
479
+ rotations[indices] = rotation_index
480
+ return max_scores, rotations
481
+
482
+ def norm_scores(
483
+ self,
484
+ arr: BackendArray,
485
+ exp_sq: BackendArray,
486
+ sq_exp: BackendArray,
487
+ n_obs: int,
488
+ eps: float,
489
+ out: BackendArray,
490
+ ) -> BackendArray:
491
+ """
492
+ Normalizes ``arr`` by the standard deviation ensuring numerical stability.
493
+
494
+ Parameters
495
+ ----------
496
+ arr : BackendArray
497
+ The input array to be normalized.
498
+ exp_sq : BackendArray
499
+ Non-normalized expectation square.
500
+ sq_exp : BackendArray
501
+ Non-normalized expectation.
502
+ n_obs : int
503
+ Number of observations for normalization.
504
+ eps : float
505
+ Numbers below this threshold will be ignored in division.
506
+ out : BackendArray
507
+ Output array to write the result to.
508
+
509
+ Returns
510
+ -------
511
+ BackendArray
512
+ The normalized array with the same shape as `arr`.
513
+
514
+ See Also
515
+ --------
516
+ :py:meth:`tme.matching_exhaustive.flc_scoring`
796
517
  """
797
- indices = score_space > internal_scores
798
- internal_scores[indices] = score_space[indices]
799
- internal_rotations[indices] = rotation_index
518
+ # Squared expected value (E(X)^2)
519
+ sq_exp = self.divide(sq_exp, n_obs, out=sq_exp)
520
+ sq_exp = self.square(sq_exp, out=sq_exp)
521
+ # Expected squared value (E(X^2))
522
+ exp_sq = self.divide(exp_sq, n_obs, out=exp_sq)
523
+ # Variance
524
+ sq_exp = self.subtract(exp_sq, sq_exp, out=sq_exp)
525
+ sq_exp = self.maximum(sq_exp, 0.0, out=sq_exp)
526
+ sq_exp = self.sqrt(sq_exp, out=sq_exp)
527
+
528
+ # Assume that low stdev regions also have low scores
529
+ # See :py:meth:`tme.matching_exhaustive.flcSphericalMask_setup` for correct norm
530
+ sq_exp[sq_exp < eps] = 1
531
+ sq_exp = self.multiply(sq_exp, n_obs, out=sq_exp)
532
+ return self.divide(arr, sq_exp, out=out)