pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
  2. pytme-0.2.9.data/scripts/match_template.py +1135 -0
  3. pytme-0.2.9.data/scripts/postprocess.py +622 -0
  4. pytme-0.2.9.data/scripts/preprocess.py +209 -0
  5. pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
  6. pytme-0.2.9.dist-info/METADATA +95 -0
  7. pytme-0.2.9.dist-info/RECORD +119 -0
  8. pytme-0.2.9.dist-info/WHEEL +5 -0
  9. pytme-0.2.9.dist-info/entry_points.txt +6 -0
  10. pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
  11. pytme-0.2.9.dist-info/top_level.txt +3 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +97 -0
  14. scripts/match_template.py +1135 -0
  15. scripts/postprocess.py +622 -0
  16. scripts/preprocess.py +209 -0
  17. scripts/preprocessor_gui.py +1227 -0
  18. tests/__init__.py +0 -0
  19. tests/data/Blurring/blob_width18.npy +0 -0
  20. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  21. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  22. tests/data/Blurring/hamming_width6.npy +0 -0
  23. tests/data/Blurring/kaiserb_width18.npy +0 -0
  24. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  25. tests/data/Blurring/mean_size5.npy +0 -0
  26. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  27. tests/data/Blurring/rank_rank3.npy +0 -0
  28. tests/data/Maps/.DS_Store +0 -0
  29. tests/data/Maps/emd_8621.mrc.gz +0 -0
  30. tests/data/README.md +2 -0
  31. tests/data/Raw/em_map.map +0 -0
  32. tests/data/Structures/.DS_Store +0 -0
  33. tests/data/Structures/1pdj.cif +3339 -0
  34. tests/data/Structures/1pdj.pdb +1429 -0
  35. tests/data/Structures/5khe.cif +3685 -0
  36. tests/data/Structures/5khe.ent +2210 -0
  37. tests/data/Structures/5khe.pdb +2210 -0
  38. tests/data/Structures/5uz4.cif +70548 -0
  39. tests/preprocessing/__init__.py +0 -0
  40. tests/preprocessing/test_compose.py +76 -0
  41. tests/preprocessing/test_frequency_filters.py +178 -0
  42. tests/preprocessing/test_preprocessor.py +136 -0
  43. tests/preprocessing/test_utils.py +79 -0
  44. tests/test_analyzer.py +216 -0
  45. tests/test_backends.py +446 -0
  46. tests/test_density.py +503 -0
  47. tests/test_extensions.py +130 -0
  48. tests/test_matching_cli.py +283 -0
  49. tests/test_matching_data.py +162 -0
  50. tests/test_matching_exhaustive.py +124 -0
  51. tests/test_matching_memory.py +30 -0
  52. tests/test_matching_optimization.py +226 -0
  53. tests/test_matching_utils.py +189 -0
  54. tests/test_orientations.py +175 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_rotations.py +153 -0
  57. tests/test_structure.py +247 -0
  58. tme/__init__.py +6 -0
  59. tme/__version__.py +1 -0
  60. tme/analyzer/__init__.py +2 -0
  61. tme/analyzer/_utils.py +186 -0
  62. tme/analyzer/aggregation.py +577 -0
  63. tme/analyzer/peaks.py +953 -0
  64. tme/backends/__init__.py +171 -0
  65. tme/backends/_cupy_utils.py +734 -0
  66. tme/backends/_jax_utils.py +188 -0
  67. tme/backends/cupy_backend.py +294 -0
  68. tme/backends/jax_backend.py +314 -0
  69. tme/backends/matching_backend.py +1270 -0
  70. tme/backends/mlx_backend.py +241 -0
  71. tme/backends/npfftw_backend.py +583 -0
  72. tme/backends/pytorch_backend.py +430 -0
  73. tme/data/__init__.py +0 -0
  74. tme/data/c48n309.npy +0 -0
  75. tme/data/c48n527.npy +0 -0
  76. tme/data/c48n9.npy +0 -0
  77. tme/data/c48u1.npy +0 -0
  78. tme/data/c48u1153.npy +0 -0
  79. tme/data/c48u1201.npy +0 -0
  80. tme/data/c48u1641.npy +0 -0
  81. tme/data/c48u181.npy +0 -0
  82. tme/data/c48u2219.npy +0 -0
  83. tme/data/c48u27.npy +0 -0
  84. tme/data/c48u2947.npy +0 -0
  85. tme/data/c48u3733.npy +0 -0
  86. tme/data/c48u4749.npy +0 -0
  87. tme/data/c48u5879.npy +0 -0
  88. tme/data/c48u7111.npy +0 -0
  89. tme/data/c48u815.npy +0 -0
  90. tme/data/c48u83.npy +0 -0
  91. tme/data/c48u8649.npy +0 -0
  92. tme/data/c600v.npy +0 -0
  93. tme/data/c600vc.npy +0 -0
  94. tme/data/metadata.yaml +80 -0
  95. tme/data/quat_to_numpy.py +42 -0
  96. tme/data/scattering_factors.pickle +0 -0
  97. tme/density.py +2263 -0
  98. tme/extensions.cpython-311-darwin.so +0 -0
  99. tme/external/bindings.cpp +332 -0
  100. tme/filters/__init__.py +6 -0
  101. tme/filters/_utils.py +311 -0
  102. tme/filters/bandpass.py +230 -0
  103. tme/filters/compose.py +81 -0
  104. tme/filters/ctf.py +393 -0
  105. tme/filters/reconstruction.py +160 -0
  106. tme/filters/wedge.py +542 -0
  107. tme/filters/whitening.py +191 -0
  108. tme/matching_data.py +863 -0
  109. tme/matching_exhaustive.py +497 -0
  110. tme/matching_optimization.py +1311 -0
  111. tme/matching_scores.py +1183 -0
  112. tme/matching_utils.py +1188 -0
  113. tme/memory.py +337 -0
  114. tme/orientations.py +598 -0
  115. tme/parser.py +685 -0
  116. tme/preprocessor.py +1329 -0
  117. tme/rotations.py +350 -0
  118. tme/structure.py +1864 -0
  119. tme/types.py +13 -0
@@ -0,0 +1,583 @@
1
+ """ Backend using numpy and pyFFTW for template matching.
2
+
3
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ import os
9
+ from psutil import virtual_memory
10
+ from contextlib import contextmanager
11
+ from typing import Tuple, Dict, List, Type
12
+
13
+ import scipy
14
+ import numpy as np
15
+ from scipy.ndimage import maximum_filter, affine_transform
16
+ from pyfftw.builders import rfftn as rfftn_builder, irfftn as irfftn_builder
17
+ from pyfftw import zeros_aligned, simd_alignment, FFTW, next_fast_len, interfaces
18
+
19
+ from ..types import NDArray, BackendArray, shm_type
20
+ from .matching_backend import MatchingBackend, _create_metafunction
21
+
22
# Pin the native math/FFT libraries to a single thread each; parallelism is
# expected to be managed by the caller rather than inside these libraries.
os.environ.update(
    dict.fromkeys(
        (
            "MKL_NUM_THREADS",
            "OMP_NUM_THREADS",
            "PYFFTW_NUM_THREADS",
            "OPENBLAS_NUM_THREADS",
        ),
        "1",
    )
)
26
+
27
+
28
def create_ufuncs(obj):
    """
    Attach metafunctions for a fixed set of array function names to ``obj``.

    For every name listed below, ``obj.<name>`` is set to the callable
    returned by ``_create_metafunction(<name>)`` (presumably a forwarder to
    the array backend's function of the same name - confirm in
    ``matching_backend``).

    Parameters
    ----------
    obj : type
        Class (or object) to decorate.

    Returns
    -------
    type
        ``obj`` itself, so this can be used as a class decorator.
    """
    # Each name appears exactly once; "einsum" was previously listed twice.
    ufuncs = (
        "add",
        "subtract",
        "multiply",
        "divide",
        "mod",
        "sum",
        "where",
        "einsum",
        "mean",
        "std",
        "max",
        "min",
        "maximum",
        "minimum",
        "sqrt",
        "square",
        "abs",
        "power",
        "full",
        "clip",
        "arange",
        "stack",
        "concatenate",
        "repeat",
        "indices",
        "unique",
        "argsort",
        "tril_indices",
        "reshape",
        "identity",
        "dot",
        "copy",
    )
    for ufunc in ufuncs:
        setattr(obj, ufunc, _create_metafunction(ufunc))
    return obj
67
+
68
+
69
@create_ufuncs
class _NumpyWrapper:
    """
    Base class that carries the metafunctions attached by ``create_ufuncs``.

    The decorator is applied here instead of on ``NumpyFFTWBackend`` because
    MatchingBackend prohibits using create_ufuncs on NumpyFFTWBackend directly;
    the backend inherits the attached functions from this class.
    """
76
+
77
+
78
class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
    """
    A numpy and pyfftw-based matching backend.

    Array operations are dispatched to ``array_backend`` (numpy by default),
    FFTs are planned with pyFFTW builders, and affine resampling uses
    :py:func:`scipy.ndimage.affine_transform`.
    """

    def __init__(
        self,
        array_backend=np,
        float_dtype=np.float32,
        complex_dtype=np.complex64,
        int_dtype=np.int32,
        overflow_safe_dtype=np.float32,
        **kwargs,
    ):
        super().__init__(
            array_backend=array_backend,
            float_dtype=float_dtype,
            complex_dtype=complex_dtype,
            int_dtype=int_dtype,
            overflow_safe_dtype=overflow_safe_dtype,
        )
        self.affine_transform = affine_transform

        self.cholesky = self._linalg_cholesky
        self.solve_triangular = self._solve_triangular
        # NOTE(review): if ``self.linalg`` resolves to the array backend's
        # linalg module, this assignment patches that module for every user
        # of it, not only this instance - confirm against MatchingBackend.
        self.linalg.solve_triangular = scipy.linalg.solve_triangular

    def _linalg_cholesky(self, arr, lower=False, *args, **kwargs):
        """
        Cholesky factor of ``arr``; the upper factor unless ``lower`` is True.

        The ``upper`` argument of ``linalg.cholesky`` is not supported until
        numpy 2.0, so the returned factor is transposed over its last two
        axes when the upper factor is requested.
        """
        ret = self._array_backend.linalg.cholesky(arr, *args, **kwargs)
        if not lower:
            axes = list(range(ret.ndim))
            axes[-2:] = (ret.ndim - 1, ret.ndim - 2)
            ret = self._array_backend.transpose(ret, axes)
        return ret

    def _solve_triangular(self, a, b, lower=True, *args, **kwargs):
        """
        Solve ``a @ x = b`` using only the lower (or upper) triangle of ``a``.

        The masked matrix is handed to the general ``linalg.solve`` rather
        than to a dedicated triangular solver.
        """
        mask = self._array_backend.tril if lower else self._array_backend.triu
        return self._array_backend.linalg.solve(mask(a), b, *args, **kwargs)

    def to_backend_array(self, arr: NDArray) -> NDArray:
        """Return ``arr`` unchanged if it already is a backend ndarray, else convert."""
        if isinstance(arr, self._array_backend.ndarray):
            return arr
        return self._array_backend.asarray(arr)

    def to_numpy_array(self, arr: NDArray) -> NDArray:
        """Return a numpy copy of ``arr``."""
        return np.array(arr)

    def to_cpu_array(self, arr: NDArray) -> NDArray:
        """Arrays already live on the CPU; return ``arr`` as-is."""
        return arr

    def get_fundamental_dtype(self, arr: NDArray) -> Type:
        """
        Map the dtype of ``arr`` to ``int``, ``float`` or ``complex``.

        Any dtype that is neither integer, floating nor complex maps to
        ``float``.
        """
        dt = arr.dtype
        if self._array_backend.issubdtype(dt, self._array_backend.integer):
            return int
        elif self._array_backend.issubdtype(dt, self._array_backend.floating):
            return float
        elif self._array_backend.issubdtype(dt, self._array_backend.complexfloating):
            return complex
        return float

    def free_cache(self):
        """No-op: this backend keeps no cache to release."""
        pass

    def transpose(self, arr: NDArray, *args, **kwargs) -> NDArray:
        """Forward to the array backend's ``transpose``."""
        return self._array_backend.transpose(arr, *args, **kwargs)

    def tobytes(self, arr: NDArray) -> bytes:
        """Raw buffer contents of ``arr`` as bytes."""
        return arr.tobytes()

    def size(self, arr: NDArray) -> int:
        """Total number of elements in ``arr``."""
        return arr.size

    def fill(self, arr: NDArray, value: float) -> NDArray:
        """Fill ``arr`` with ``value`` in-place and return it."""
        arr.fill(value)
        return arr

    def eps(self, dtype: type) -> float:
        """Machine epsilon of the floating point ``dtype``."""
        return self._array_backend.finfo(dtype).eps

    def datatype_bytes(self, dtype: type) -> int:
        """Number of bytes occupied by a single element of ``dtype``."""
        temp = self._array_backend.zeros(1, dtype=dtype)
        return temp.nbytes

    @staticmethod
    def astype(arr, dtype: Type) -> NDArray:
        """Return ``arr`` cast to ``dtype``."""
        return arr.astype(dtype)

    @staticmethod
    def at(arr, idx, value) -> NDArray:
        """Set ``arr[idx] = value`` in-place and return ``arr``."""
        arr[idx] = value
        return arr

    def addat(self, arr, indices, *args, **kwargs) -> NDArray:
        """Unbuffered in-place addition at ``indices`` (``add.at``); returns ``arr``."""
        self._array_backend.add.at(arr, indices, *args, **kwargs)
        return arr

    def topk_indices(self, arr: NDArray, k: int):
        """
        Indices of the ``k`` largest elements of ``arr``, largest first.

        Returns
        -------
        tuple of NDArray
            Unravelled indices into ``arr.shape``.
        """
        temp = arr.reshape(-1)
        # The trailing [:k] makes k == 0 yield no indices, since [-0:]
        # would otherwise select the whole partitioned array.
        indices = self._array_backend.argpartition(temp, -k)[-k:][:k]
        sorted_indices = indices[self._array_backend.argsort(temp[indices])][::-1]
        sorted_indices = self.unravel_index(indices=sorted_indices, shape=arr.shape)
        return sorted_indices

    def indices(self, *args, **kwargs) -> NDArray:
        """Forward to the array backend's ``indices``."""
        return self._array_backend.indices(*args, **kwargs)

    def roll(
        self, a: NDArray, shift: Tuple[int], axis: Tuple[int], **kwargs
    ) -> NDArray:
        """Forward to the array backend's ``roll``."""
        return self._array_backend.roll(
            a,
            shift=shift,
            axis=axis,
            **kwargs,
        )

    def unravel_index(self, indices: NDArray, shape: Tuple[int]) -> NDArray:
        """Convert flat ``indices`` into coordinate arrays for ``shape``."""
        return self._array_backend.unravel_index(indices=indices, shape=shape)

    def max_filter_coordinates(self, score_space: NDArray, min_distance: int):
        """
        Coordinates of local maxima of ``score_space``.

        A point is reported when it equals the maximum of the cubic window of
        edge length ``min_distance`` around it (zero-padded at the borders).

        Returns
        -------
        NDArray
            Array of shape (n_peaks, score_space.ndim).
        """
        score_box = tuple(min_distance for _ in range(score_space.ndim))
        max_filter = maximum_filter(score_space, size=score_box, mode="constant")
        max_filter = max_filter == score_space

        peaks = np.array(np.nonzero(max_filter)).T
        return peaks

    @staticmethod
    def zeros(shape: Tuple[int], dtype: type = None) -> NDArray:
        """Zero-initialised array aligned to pyFFTW's SIMD alignment."""
        arr = zeros_aligned(shape, dtype=dtype, n=simd_alignment)
        return arr

    def from_sharedarr(self, args) -> NDArray:
        """
        Inverse of :py:meth:`to_sharedarr`.

        ``args`` is either a 1-tuple holding the array itself, or a
        ``(shm, shape, dtype)`` triple whose buffer is wrapped without copying.
        """
        if len(args) == 1:
            return args[0]
        shm, shape, dtype = args
        return self.ndarray(shape, dtype, shm.buf)

    def to_sharedarr(
        self, arr: NDArray, shared_memory_handler: type = None
    ) -> shm_type:
        """
        Copy ``arr`` into a shared memory segment.

        Without ``shared_memory_handler`` the array is returned wrapped in a
        1-tuple. Otherwise a segment of ``arr.nbytes`` is allocated from the
        handler, ``arr`` is copied into it and ``(shm, shape, dtype)`` is
        returned.
        """
        if shared_memory_handler is None:
            return (arr,)

        shm = shared_memory_handler.SharedMemory(size=arr.nbytes)
        np_array = self.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
        # Assigning writes directly into the shared buffer; no temporary copy.
        np_array[...] = arr
        return shm, arr.shape, arr.dtype

    def topleft_pad(self, arr: NDArray, shape: Tuple[int], padval: int = 0) -> NDArray:
        """
        Place ``arr`` in the top-left corner of a new array of ``shape``.

        Axes where ``arr`` is larger than ``shape`` are cropped; the remaining
        elements are set to ``padval``.
        """
        b = self.zeros(shape, arr.dtype)
        self.add(b, padval, out=b)
        aind = [slice(None, None)] * arr.ndim
        bind = [slice(None, None)] * arr.ndim
        for i in range(arr.ndim):
            if arr.shape[i] > shape[i]:
                aind[i] = slice(0, shape[i])
            elif arr.shape[i] < shape[i]:
                bind[i] = slice(0, arr.shape[i])
        b[tuple(bind)] = arr[tuple(aind)]
        return b

    def build_fft(
        self,
        fwd_shape: Tuple[int],
        inv_shape: Tuple[int],
        real_dtype: type,
        cmpl_dtype: type,
        fftargs: Dict = None,
        inv_output_shape: Tuple[int] = None,
        temp_fwd: NDArray = None,
        temp_inv: NDArray = None,
        fwd_axes: Tuple[int] = None,
        inv_axes: Tuple[int] = None,
    ) -> Tuple[FFTW, FFTW]:
        """
        Plan a forward real and an inverse complex-to-real FFT with pyFFTW.

        Parameters
        ----------
        fwd_shape : tuple of int
            Shape of the real input of the forward transform.
        inv_shape : tuple of int
            Shape of the complex input of the inverse transform.
        real_dtype, cmpl_dtype : type
            Dtypes of the planning buffers allocated when ``temp_fwd`` /
            ``temp_inv`` are not given.
        fftargs : dict, optional
            Keyword arguments for the pyFFTW builders. Missing keys fall back
            to the defaults below. The dict is not modified.
        inv_output_shape : tuple of int, optional
            Real output shape of the inverse transform, ``fwd_shape`` if None.
        temp_fwd, temp_inv : NDArray, optional
            Buffers used for planning.
        fwd_axes, inv_axes : tuple of int, optional
            Axes to transform; all axes if None.

        Returns
        -------
        tuple of callable
            ``(rfftn, irfftn)``, each called as ``f(arr, out)``; further
            positional and keyword arguments are accepted and ignored.
        """
        if temp_fwd is None:
            temp_fwd = self.zeros(fwd_shape, real_dtype)
        if temp_inv is None:
            temp_inv = self.zeros(inv_shape, cmpl_dtype)

        # Build a fresh dict so neither the caller's dict nor a shared default
        # is mutated; caller-supplied values take precedence.
        fftargs = {
            "planner_effort": "FFTW_MEASURE",
            "auto_align_input": False,
            "auto_contiguous": False,
            "avoid_copy": True,
            "overwrite_input": True,
            "threads": 1,
            **(fftargs if fftargs is not None else {}),
        }

        rfft_shape = self._format_fft_shape(temp_fwd.shape, fwd_axes)
        _rfftn = rfftn_builder(temp_fwd, s=rfft_shape, axes=fwd_axes, **fftargs)

        # The inverse plan is built without ``overwrite_input``, leaving that
        # choice to the builder's default.
        fftargs.pop("overwrite_input", None)
        irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
        irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)
        _irfftn = irfftn_builder(temp_inv, s=irfft_shape, axes=inv_axes, **fftargs)

        def _rfftn_wrapper(arr, out, *args, **kwargs):
            return _rfftn(arr, out)

        def _irfftn_wrapper(arr, out, *args, **kwargs):
            return _irfftn(arr, out)

        return _rfftn_wrapper, _irfftn_wrapper

    @staticmethod
    def _format_fft_shape(shape: Tuple[int], axes: Tuple[int] = None):
        """
        Entries of ``shape`` along ``axes``, in the order given by ``axes``.

        numpy and pyFFTW pair ``s[i]`` with ``axes[i]``, so the order of
        ``axes`` must be preserved. Sorting it would swap sizes between axes
        whenever ``axes`` is not ascending, e.g. ``(-1, 0)``. Negative axes
        index from the end. If ``axes`` is None, ``shape`` is returned
        unchanged.
        """
        if axes is None:
            return shape
        return tuple(shape[axis] for axis in axes)

    def rfftn(self, arr: NDArray, *args, **kwargs) -> NDArray:
        """Unplanned real FFT via pyFFTW's numpy interface (positional extras ignored)."""
        return interfaces.numpy_fft.rfftn(arr, **kwargs)

    def irfftn(self, arr: NDArray, *args, **kwargs) -> NDArray:
        """Unplanned inverse real FFT via pyFFTW's numpy interface (positional extras ignored)."""
        return interfaces.numpy_fft.irfftn(arr, **kwargs)

    def extract_center(self, arr: NDArray, newshape: Tuple[int]) -> NDArray:
        """Central box of ``arr`` with shape ``newshape`` (start = (old - new) / 2, truncated)."""
        new_shape = self.to_backend_array(newshape)
        current_shape = self.to_backend_array(arr.shape)
        starts = self.subtract(current_shape, new_shape)
        starts = self.astype(self.divide(starts, 2), self._int_dtype)
        stops = self.astype(self.add(starts, new_shape), self._int_dtype)
        box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
        return arr[box]

    def compute_convolution_shapes(
        self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
    ) -> Tuple[List[int], List[int], List[int]]:
        """
        Shapes needed to convolve arrays of ``arr1_shape`` and ``arr2_shape``.

        Returns
        -------
        tuple of list of int
            The full linear convolution shape, that shape rounded up per axis
            with pyFFTW's ``next_fast_len``, and the matching half-spectrum
            shape of a real FFT (last axis ``n // 2 + 1``).
        """
        convolution_shape = [int(x + y - 1) for x, y in zip(arr1_shape, arr2_shape)]
        fast_shape = [next_fast_len(x) for x in convolution_shape]
        fast_ft_shape = list(fast_shape[:-1]) + [fast_shape[-1] // 2 + 1]

        return convolution_shape, fast_shape, fast_ft_shape

    def _rigid_transform_matrix(
        self,
        rotation_matrix: NDArray,
        translation: NDArray = None,
        center: NDArray = None,
    ) -> NDArray:
        """
        Homogeneous (ndim + 1, ndim + 1) matrix for a rigid transform.

        The result is composed as ``T(-translation) @ C(center) @ R^-1 @
        C(-center)``, where ``T``/``C`` are translations and ``R^-1`` is the
        inverse of ``rotation_matrix``. Translation and centering factors are
        omitted when the corresponding argument is None. Presumably this is
        the output-to-input mapping expected by ``affine_transform`` - confirm
        with callers.
        """
        ndim = rotation_matrix.shape[0]
        matrix = self.identity(ndim + 1, dtype=self._float_dtype)

        if translation is not None:
            translation_matrix = self.identity(ndim + 1, dtype=self._float_dtype)
            translation_matrix[:ndim, ndim] = -translation
            self.dot(matrix, translation_matrix, out=matrix)

        if center is not None:
            center_matrix = self.identity(ndim + 1, dtype=self._float_dtype)
            center_matrix[:ndim, ndim] = center
            self.dot(matrix, center_matrix, out=matrix)

        rmat = self.identity(ndim + 1, dtype=self._float_dtype)
        rmat[:ndim, :ndim] = self._array_backend.linalg.inv(rotation_matrix)
        self.dot(matrix, rmat, out=matrix)

        if center is not None:
            # Reuse the centering matrix with the opposite shift.
            center_matrix[:ndim, ndim] = -center_matrix[:ndim, ndim]
            self.dot(matrix, center_matrix, out=matrix)

        matrix /= matrix[ndim, ndim]
        return matrix

    def _rigid_transform(
        self,
        data: NDArray,
        matrix: NDArray,
        output: NDArray,
        prefilter: bool,
        order: int,
        cache: bool = False,
        batched=False,
    ) -> None:
        """
        Resample ``data`` with ``matrix`` into the leading region of ``output``.

        With ``batched`` the transform is applied independently to every slice
        along the first axis. ``cache`` is accepted for interface
        compatibility and is not used here.
        """
        if batched:
            for i in range(data.shape[0]):
                self._rigid_transform(
                    data=data[i],
                    matrix=matrix,
                    output=output[i],
                    prefilter=prefilter,
                    order=order,
                    cache=cache,
                    batched=False,
                )
            return None

        out_slice = tuple(slice(0, stop) for stop in data.shape)
        self.affine_transform(
            input=data,
            matrix=matrix,
            mode="constant",
            output=output[out_slice],
            order=order,
            prefilter=prefilter,
        )

    def rigid_transform(
        self,
        arr: NDArray,
        rotation_matrix: NDArray,
        arr_mask: NDArray = None,
        translation: NDArray = None,
        use_geometric_center: bool = False,
        out: NDArray = None,
        out_mask: NDArray = None,
        order: int = 3,
        cache: bool = False,
    ) -> Tuple[NDArray, NDArray]:
        """
        Rotate and translate ``arr`` (and optionally ``arr_mask``).

        The rotation is performed about the geometric center
        (``(shape - 1) / 2``) if ``use_geometric_center`` is set, otherwise
        about the center of mass of the values above 0. If ``arr`` has more
        dimensions than ``rotation_matrix``, the leading axis is treated as a
        batch axis and further leading axes are fixed at index 0.

        Returns
        -------
        tuple of NDArray
            ``(out, out_mask)``. ``out_mask`` is only written when
            ``arr_mask`` is given; otherwise it is returned as passed in.
        """
        out = self.zeros_like(arr) if out is None else out
        batched = arr.ndim != rotation_matrix.shape[0]

        if use_geometric_center:
            center = self.divide(self.to_backend_array(arr.shape) - 1, 2)
        else:
            center = self.center_of_mass(arr, cutoff=0)

        # Only the trailing axes covered by the rotation take part.
        offset = int(arr.ndim - rotation_matrix.shape[0])
        center = center[offset:]
        translation = self.zeros(center.size) if translation is None else translation
        matrix = self._rigid_transform_matrix(
            rotation_matrix=rotation_matrix,
            translation=translation,
            center=center,
        )

        subset = tuple(slice(None) for _ in range(arr.ndim))
        if offset > 1:
            subset = tuple(
                0 if i < (offset - 1) else slice(None) for i in range(arr.ndim)
            )

        self._rigid_transform(
            data=arr[subset],
            matrix=matrix,
            output=out[subset],
            order=order,
            prefilter=True,
            cache=cache,
            batched=batched,
        )

        # Applying the prefilter leads to artifacts in the mask.
        if arr_mask is not None:
            out_mask = self.zeros_like(arr_mask) if out_mask is None else out_mask
            self._rigid_transform(
                data=arr_mask[subset],
                matrix=matrix,
                output=out_mask[subset],
                order=order,
                prefilter=False,
                cache=cache,
                batched=batched,
            )

        return out, out_mask

    def center_of_mass(self, arr: BackendArray, cutoff: float = None) -> BackendArray:
        """
        Computes the center of mass of a numpy ndarray instance using all available
        elements. For template matching it typically makes sense to only input
        positive densities.

        Parameters
        ----------
        arr : BackendArray
            Array to compute the center of mass of.
        cutoff : float, optional
            Densities less than or equal to cutoff are nullified for center
            of mass computation. By default considers all values.

        Returns
        -------
        BackendArray
            Center of mass with shape (arr.ndim).
        """
        cutoff = self.min(arr) - 1 if cutoff is None else cutoff

        arr = self.where(arr > cutoff, arr, 0)
        denominator = self.sum(arr)

        # One broadcastable coordinate grid per axis.
        grids = []
        for i, x in enumerate(arr.shape):
            baseline_dims = tuple(1 if i != t else x for t in range(len(arr.shape)))
            grids.append(
                self.reshape(self.arange(x, dtype=self._float_dtype), baseline_dims)
            )

        center_of_mass = [self.sum((arr * grid) / denominator) for grid in grids]

        return self.to_backend_array(center_of_mass)

    def get_available_memory(self) -> int:
        """Available system memory as reported by ``psutil.virtual_memory``."""
        return virtual_memory().available

    @contextmanager
    def set_device(self, device_index: int):
        """No-op context manager; there is only a single CPU device."""
        yield None

    def device_count(self) -> int:
        """Always one (CPU)."""
        return 1

    @staticmethod
    def reverse(arr: NDArray, axis: Tuple[int] = None) -> NDArray:
        """
        Reverse ``arr`` along ``axis``.

        Parameters
        ----------
        arr : NDArray
            Array to reverse.
        axis : int or tuple of int, optional
            Axes to reverse. Negative values count from the last axis, as in
            :py:func:`numpy.flip`. Defaults to all axes.

        Returns
        -------
        NDArray
            ``arr`` indexed with step -1 along the selected axes.
        """
        if axis is None:
            axis = tuple(range(arr.ndim))
        elif isinstance(axis, (int, np.integer)):
            axis = (axis,)
        # Normalise negative axes so they are honoured instead of never
        # matching the non-negative loop index below.
        axis = {ax + arr.ndim if ax < 0 else ax for ax in axis}
        keep, rev = slice(None, None), slice(None, None, -1)
        return arr[tuple(rev if i in axis else keep for i in range(arr.ndim))]

    def max_score_over_rotations(
        self,
        scores: BackendArray,
        max_scores: BackendArray,
        rotations: BackendArray,
        rotation_index: int,
    ) -> Tuple[BackendArray, BackendArray]:
        """
        Update elements in ``max_scores`` and ``rotations`` where scores is larger than
        max_scores with score and rotation_index, respectively.

        .. warning:: ``max_scores`` and ``rotations`` are modified in-place.

        Parameters
        ----------
        scores : BackendArray
            The score space to compare against max_scores.
        max_scores : BackendArray
            Maximum score observed for each element in an array.
        rotations : BackendArray
            Rotation used to achieve a given max_score.
        rotation_index : int
            The index representing the current rotation.

        Returns
        -------
        Tuple[BackendArray, BackendArray]
            Updated ``max_scores`` and ``rotations``.
        """
        indices = scores > max_scores
        max_scores[indices] = scores[indices]
        rotations[indices] = rotation_index
        return max_scores, rotations

    def norm_scores(
        self,
        arr: BackendArray,
        exp_sq: BackendArray,
        sq_exp: BackendArray,
        n_obs: int,
        eps: float,
        out: BackendArray,
    ) -> BackendArray:
        """
        Normalizes ``arr`` by the standard deviation ensuring numerical stability.

        .. warning:: ``exp_sq`` and ``sq_exp`` are overwritten in-place.

        Parameters
        ----------
        arr : BackendArray
            The input array to be normalized.
        exp_sq : BackendArray
            Non-normalized expectation square (divided by ``n_obs`` here).
        sq_exp : BackendArray
            Non-normalized expectation (divided by ``n_obs`` and squared here).
        n_obs : int
            Number of observations for normalization.
        eps : float
            Standard deviations below this threshold are replaced by 1 before
            division.
        out : BackendArray
            Output array to write the result to.

        Returns
        -------
        BackendArray
            ``arr / (n_obs * std)``, with the same shape as ``arr``.

        See Also
        --------
        :py:meth:`tme.matching_exhaustive.flc_scoring`
        """
        # Squared expected value (E(X)^2)
        sq_exp = self.divide(sq_exp, n_obs, out=sq_exp)
        sq_exp = self.square(sq_exp, out=sq_exp)
        # Expected squared value (E(X^2))
        exp_sq = self.divide(exp_sq, n_obs, out=exp_sq)
        # Variance, clipped at zero to absorb round-off, then standard deviation
        sq_exp = self.subtract(exp_sq, sq_exp, out=sq_exp)
        sq_exp = self.maximum(sq_exp, 0.0, out=sq_exp)
        sq_exp = self.sqrt(sq_exp, out=sq_exp)

        # Assume that low stdev regions also have low scores
        # See :py:meth:`tme.matching_exhaustive.flcSphericalMask_setup` for correct norm
        sq_exp[sq_exp < eps] = 1
        sq_exp = self.multiply(sq_exp, n_obs, out=sq_exp)
        return self.divide(arr, sq_exp, out=out)