pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/match_template.py +147 -93
  2. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/postprocess.py +67 -26
  3. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +175 -85
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +20 -13
  8. scripts/match_template.py +147 -93
  9. scripts/match_template_filters.py +154 -95
  10. scripts/postprocess.py +67 -26
  11. scripts/preprocessor_gui.py +175 -85
  12. scripts/refine_matches.py +265 -61
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +451 -809
  16. tme/backends/__init__.py +40 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +111 -223
  19. tme/backends/jax_backend.py +214 -150
  20. tme/backends/matching_backend.py +445 -384
  21. tme/backends/mlx_backend.py +32 -59
  22. tme/backends/npfftw_backend.py +239 -507
  23. tme/backends/pytorch_backend.py +21 -145
  24. tme/density.py +233 -363
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +322 -285
  27. tme/matching_exhaustive.py +172 -1493
  28. tme/matching_optimization.py +143 -106
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +280 -386
  31. tme/memory.py +377 -0
  32. tme/orientations.py +52 -12
  33. tme/parser.py +3 -4
  34. tme/preprocessing/_utils.py +61 -32
  35. tme/preprocessing/compose.py +7 -3
  36. tme/preprocessing/frequency_filters.py +49 -39
  37. tme/preprocessing/tilt_series.py +34 -40
  38. tme/preprocessor.py +560 -526
  39. tme/structure.py +491 -188
  40. tme/types.py +5 -3
  41. pytme-0.2.1.dist-info/METADATA +0 -73
  42. pytme-0.2.1.dist-info/RECORD +0 -73
  43. tme/helpers.py +0 -881
  44. tme/matching_constrained.py +0 -195
  45. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  46. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  47. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  48. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  49. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
tme/memory.py ADDED
@@ -0,0 +1,377 @@
1
+ """ Compute memory consumption of template matching components.
2
+
3
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ from abc import ABC, abstractmethod
9
+ from typing import Tuple
10
+
11
+ import numpy as np
12
+ from pyfftw import next_fast_len
13
+
14
+
15
class MatchingMemoryUsage(ABC):
    """
    Class specification for estimating the memory requirements of template matching.

    Parameters
    ----------
    fast_shape : tuple of int
        Shape of the real array.
    ft_shape : tuple of int
        Shape of the complex array.
    float_nbytes : int
        Number of bytes of the used float, e.g. 4 for float32.
    complex_nbytes : int
        Number of bytes of the used complex, e.g. 8 for complex64.
    integer_nbytes : int
        Number of bytes of the used integer, e.g. 4 for int32.

    Attributes
    ----------
    real_array_size : int
        Number of elements in real array.
    complex_array_size : int
        Number of elements in complex array.
    float_nbytes : int
        Number of bytes of the used float, e.g. 4 for float32.
    complex_nbytes : int
        Number of bytes of the used complex, e.g. 8 for complex64.
    integer_nbytes : int
        Number of bytes of the used integer, e.g. 4 for int32.

    Methods
    -------
    base_usage():
        Returns the base memory usage in bytes.
    per_fork():
        Returns the memory usage in bytes per fork.
    """

    def __init__(
        self,
        fast_shape: Tuple[int, ...],
        ft_shape: Tuple[int, ...],
        float_nbytes: int,
        complex_nbytes: int,
        integer_nbytes: int,
    ):
        # np.prod defaults to the platform integer, which is 32 bit on some
        # platforms (e.g. Windows with NumPy < 2) and silently overflows for
        # large volumes. Accumulate in int64 and store a Python int so that all
        # byte counts derived from these sizes are exact and plain ints.
        self.real_array_size = int(np.prod(fast_shape, dtype=np.int64))
        self.complex_array_size = int(np.prod(ft_shape, dtype=np.int64))
        self.float_nbytes = float_nbytes
        self.complex_nbytes = complex_nbytes
        self.integer_nbytes = integer_nbytes

    @abstractmethod
    def base_usage(self) -> int:
        """Return the base memory usage in bytes."""

    @abstractmethod
    def per_fork(self) -> int:
        """Return the memory usage per fork in bytes."""
74
+
75
+
76
class CCMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for CC scoring.

    Both the shared setup and every fork account for one real-valued and
    one complex-valued array.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.cc_setup`.
    """

    def base_usage(self) -> int:
        real_bytes = self.float_nbytes * self.real_array_size
        complex_bytes = self.complex_nbytes * self.complex_array_size
        return real_bytes + complex_bytes

    def per_fork(self) -> int:
        real_bytes = self.float_nbytes * self.real_array_size
        complex_bytes = self.complex_nbytes * self.complex_array_size
        return real_bytes + complex_bytes
94
+
95
+
96
class LCCMemoryUsage(CCMemoryUsage):
    """
    Memory usage estimation for LCC scoring.

    Reuses the :py:class:`CCMemoryUsage` estimates unchanged.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.lcc_setup`.
    """
103
+
104
+
105
class CORRMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for CORR scoring.

    The shared setup accounts for four real-valued arrays and one
    complex-valued array; every fork adds one array of each kind.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.corr_setup`.
    """

    def base_usage(self) -> int:
        real_bytes = self.real_array_size * self.float_nbytes * 4
        complex_bytes = self.complex_array_size * self.complex_nbytes
        return real_bytes + complex_bytes

    def per_fork(self) -> int:
        real_bytes = self.real_array_size * self.float_nbytes
        complex_bytes = self.complex_array_size * self.complex_nbytes
        return real_bytes + complex_bytes
123
+
124
+
125
class CAMMemoryUsage(CORRMemoryUsage):
    """
    Memory usage estimation for CAM scoring.

    Inherits both estimates from :py:class:`CORRMemoryUsage` without changes.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.cam_setup`.
    """
133
+
134
+
135
class FLCSphericalMaskMemoryUsage(CORRMemoryUsage):
    """
    Memory usage estimation for FLCSphericalMask scoring.

    Reuses the :py:class:`CORRMemoryUsage` estimates unchanged.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.flcSphericalMask_setup`.
    """
143
+
144
+
145
class FLCMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for FLC scoring.

    The shared setup accounts for two real-valued and two complex-valued
    arrays; every fork adds three real-valued and two complex-valued arrays.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.flc_setup`.
    """

    def base_usage(self) -> int:
        real_bytes = self.real_array_size * self.float_nbytes * 2
        complex_bytes = self.complex_array_size * self.complex_nbytes * 2
        return real_bytes + complex_bytes

    def per_fork(self) -> int:
        real_bytes = self.real_array_size * self.float_nbytes * 3
        complex_bytes = self.complex_array_size * self.complex_nbytes * 2
        return real_bytes + complex_bytes
163
+
164
+
165
class MCCMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for MCC scoring.

    The shared setup accounts for two real-valued and three complex-valued
    arrays; every fork adds six real-valued arrays and one complex-valued
    array.

    See Also
    --------
    :py:meth:`tme.matching_exhaustive.mcc_setup`.
    """

    def base_usage(self) -> int:
        real_bytes = self.real_array_size * self.float_nbytes * 2
        complex_bytes = self.complex_array_size * self.complex_nbytes * 3
        return real_bytes + complex_bytes

    def per_fork(self) -> int:
        real_bytes = self.real_array_size * self.float_nbytes * 6
        complex_bytes = self.complex_array_size * self.complex_nbytes
        return real_bytes + complex_bytes
183
+
184
+
185
class MaxScoreOverRotationsMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for the MaxScoreOverRotations analyzer.

    Accounts for two real-valued arrays in the shared setup and nothing
    extra per fork.

    See Also
    --------
    :py:class:`tme.analyzer.MaxScoreOverRotations`.
    """

    def base_usage(self) -> int:
        return self.real_array_size * self.float_nbytes * 2

    def per_fork(self) -> int:
        # Forks do not allocate additional analyzer memory in this estimate.
        return 0
200
+
201
+
202
class PeakCallerMaximumFilterMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for the PeakCallerMaximumFilter analyzer.

    Accounts for one real-valued array in the shared setup and one
    additional real-valued array per fork.

    See Also
    --------
    :py:class:`tme.analyzer.PeakCallerMaximumFilter`.
    """

    def base_usage(self) -> int:
        float_arrays = self.real_array_size * self.float_nbytes
        return float_arrays

    def per_fork(self) -> int:
        float_arrays = self.real_array_size * self.float_nbytes
        return float_arrays
218
+
219
+
220
class CupyBackendMemoryUsage(MatchingMemoryUsage):
    """
    Memory usage estimation for CupyBackend.

    Adds a fixed backend overhead to the shared setup and nothing per fork.

    See Also
    --------
    :py:class:`tme.backends.CupyBackend`.
    """

    def base_usage(self) -> int:
        # FFT plans, overhead from assigning FFT result, rotation interpolation.
        # NOTE(review): the complex byte count is scaled by the real array size
        # and the float byte count by the complex array size; confirm that this
        # pairing is intended rather than swapped.
        complex_bytes = self.real_array_size * self.complex_nbytes * 3
        float_bytes = self.complex_array_size * self.float_nbytes * 2
        return float_bytes + complex_bytes

    def per_fork(self) -> int:
        # The backend overhead is not multiplied by the number of forks.
        return 0
237
+
238
+
239
def _compute_convolution_shapes(
    arr1_shape: Tuple[int], arr2_shape: Tuple[int]
) -> Tuple[Tuple[int], Tuple[int], Tuple[int]]:
    """
    Derive the linear, FFT-friendly and real Fourier convolution shapes.

    Parameters
    ----------
    arr1_shape : tuple
        Tuple of integers corresponding to array1 shape.
    arr2_shape : tuple
        Tuple of integers corresponding to array2 shape.

    Returns
    -------
    tuple
        The full linear convolution shape (``arr1 + arr2 - 1`` per axis), that
        shape padded per axis with ``next_fast_len``, and the shape of the
        real-input forward Fourier transform of the padded shape
        (see :py:meth:`build_fft`).
    """
    convolution_shape = np.add(arr1_shape, arr2_shape) - 1
    fast_shape = list(map(next_fast_len, convolution_shape))
    # A real-input FFT stores only the non-negative frequencies of the last axis.
    fast_ft_shape = [*fast_shape[:-1], fast_shape[-1] // 2 + 1]
    return convolution_shape, fast_shape, fast_ft_shape
264
+
265
+
266
# Lookup table from score, analyzer and backend identifiers to the class that
# estimates their memory footprint.
MATCHING_MEMORY_REGISTRY = {
    # Scoring methods
    "CC": CCMemoryUsage,
    "LCC": LCCMemoryUsage,
    "CORR": CORRMemoryUsage,
    "CAM": CAMMemoryUsage,
    "MCC": MCCMemoryUsage,
    "FLCSphericalMask": FLCSphericalMaskMemoryUsage,
    "FLC": FLCMemoryUsage,
    # Score analyzers
    "MaxScoreOverRotations": MaxScoreOverRotationsMemoryUsage,
    "PeakCallerMaximumFilter": PeakCallerMaximumFilterMemoryUsage,
    # Backends; "pytorch" reuses the CuPy backend estimate.
    "cupy": CupyBackendMemoryUsage,
    "pytorch": CupyBackendMemoryUsage,
}
279
+
280
+
281
def _component_nbytes(
    name: str,
    ncores: int,
    fast_shape,
    ft_shape,
    float_nbytes: int,
    complex_nbytes: int,
    integer_nbytes: int,
) -> int:
    """Return base plus per-fork bytes of registry entry ``name``, 0 if absent."""
    estimator_class = MATCHING_MEMORY_REGISTRY.get(name, None)
    if estimator_class is None:
        return 0
    estimator = estimator_class(
        fast_shape=fast_shape,
        ft_shape=ft_shape,
        float_nbytes=float_nbytes,
        complex_nbytes=complex_nbytes,
        integer_nbytes=integer_nbytes,
    )
    return estimator.base_usage() + estimator.per_fork() * ncores


def estimate_ram_usage(
    shape1: Tuple[int],
    shape2: Tuple[int],
    matching_method: str,
    ncores: int,
    analyzer_method: str = None,
    backend: str = None,
    float_nbytes: int = 4,
    complex_nbytes: int = 8,
    integer_nbytes: int = 4,
) -> int:
    """
    Estimate the RAM usage for a given convolution operation based on input shapes,
    matching_method, and number of cores.

    Parameters
    ----------
    shape1 : tuple
        The shape of the input target.
    shape2 : tuple
        The shape of the input template.
    matching_method : str
        The method used for the operation, a key of ``MATCHING_MEMORY_REGISTRY``.
    ncores : int
        The number of CPU cores used for the operation.
    analyzer_method : str, optional
        The method used for score analysis. Ignored if not registered.
    backend : str, optional
        Backend used for computation. Ignored if not registered.
    float_nbytes : int
        Number of bytes of the used float, e.g. 4 for float32.
    complex_nbytes : int
        Number of bytes of the used complex, e.g. 8 for complex64.
    integer_nbytes : int
        Number of bytes of the used integer, e.g. 4 for int32.

    Returns
    -------
    int
        The estimated RAM usage for the operation in bytes.

    Notes
    -----
    Residual memory from other objects that may remain allocated during
    template matching, e.g. the full sized target when using splitting,
    are not considered by this function.

    Raises
    ------
    ValueError
        If an unsupported matching_method is provided.
    """
    if matching_method not in MATCHING_MEMORY_REGISTRY:
        raise ValueError(
            f"Unsupported matching_method '{matching_method}'. "
            f"Supported options are {','.join(MATCHING_MEMORY_REGISTRY.keys())}"
        )

    _, fast_shape, ft_shape = _compute_convolution_shapes(shape1, shape2)
    estimator_kwargs = {
        "fast_shape": fast_shape,
        "ft_shape": ft_shape,
        "float_nbytes": float_nbytes,
        "complex_nbytes": complex_nbytes,
        "integer_nbytes": integer_nbytes,
    }

    # Score method is mandatory; analyzer and backend contribute only when
    # they are registered (unknown or None names add nothing).
    nbytes = _component_nbytes(matching_method, ncores, **estimator_kwargs)
    nbytes += _component_nbytes(analyzer_method, ncores, **estimator_kwargs)
    nbytes += _component_nbytes(backend, ncores, **estimator_kwargs)

    # Estimators may return numpy integers; honour the ``-> int`` contract.
    return int(nbytes)
tme/orientations.py CHANGED
@@ -62,16 +62,16 @@ class Orientations:
62
62
  Array with additional orientation details (n, ).
63
63
  """
64
64
 
65
- #: Return a numpy array with translations of each orientation (n x d).
65
+ #: Array with translations of each orientation (n, d).
66
66
  translations: np.ndarray
67
67
 
68
- #: Return a numpy array with euler angles of each orientation in zxy format (n x d).
68
+ #: Array with zyx euler angles of each orientation (n, d).
69
69
  rotations: np.ndarray
70
70
 
71
- #: Return a numpy array with the score of each orientation (n, ).
71
+ #: Array with scores of each orientation (n, ).
72
72
  scores: np.ndarray
73
73
 
74
- #: Return a numpy array with additional orientation details (n, ).
74
+ #: Array with additional details of each orientation(n, ).
75
75
  details: np.ndarray
76
76
 
77
77
  def __post_init__(self):
@@ -130,9 +130,21 @@ class Orientations:
130
130
  "scores",
131
131
  "details",
132
132
  )
133
- kwargs = {attr: getattr(self, attr)[indices] for attr in attributes}
133
+ kwargs = {attr: getattr(self, attr)[indices].copy() for attr in attributes}
134
134
  return self.__class__(**kwargs)
135
135
 
136
+ def copy(self) -> "Orientations":
137
+ """
138
+ Create a copy of the current class instance.
139
+
140
+ Returns
141
+ -------
142
+ :py:class:`Orientations`
143
+ Copy of the class instance.
144
+ """
145
+ indices = np.arange(self.scores.size)
146
+ return self[indices]
147
+
136
148
  def to_file(self, filename: str, file_format: type = None, **kwargs) -> None:
137
149
  """
138
150
  Save the current class instance to a file in the specified format.
@@ -146,7 +158,7 @@ class Orientations:
146
158
  the file_format from the typical extension. Supported formats are
147
159
 
148
160
  +---------------+----------------------------------------------------+
149
- | text | pyTME's standard tab-separated orientations file |
161
+ | text | pytme's standard tab-separated orientations file |
150
162
  +---------------+----------------------------------------------------+
151
163
  | relion | Creates a STAR file of orientations |
152
164
  +---------------+----------------------------------------------------+
@@ -207,11 +219,11 @@ class Orientations:
207
219
  with open(filename, mode="w", encoding="utf-8") as ofile:
208
220
  _ = ofile.write(f"{header}\n")
209
221
  for translation, angles, score, detail in self:
210
- translation_string = "\t".join([str(x) for x in translation])
211
- angle_string = "\t".join([str(x) for x in angles])
212
- _ = ofile.write(
213
- f"{translation_string}\t{angle_string}\t{score}\t{detail}\n"
222
+ out_string = (
223
+ "\t".join([str(x) for x in (*translation, *angles, score, detail)])
224
+ + "\n"
214
225
  )
226
+ _ = ofile.write(out_string)
215
227
  return None
216
228
 
217
229
  def _to_dynamo_tbl(
@@ -465,8 +477,10 @@ class Orientations:
465
477
 
466
478
  Notes
467
479
  -----
468
- The text file is expected to have a header and data in columns corresponding to
469
- z, y, x, euler_z, euler_y, euler_x, score, detail.
480
+ The text file is expected to have a header and data in columns. Colums containing
481
+ the name euler are considered to specify rotations. The second last and last
482
+ column correspond to score and detail. Its possible to only specify translations,
483
+ in this case the remaining columns will be filled with trivial values.
470
484
  """
471
485
  with open(filename, mode="r", encoding="utf-8") as infile:
472
486
  data = [x.strip().split("\t") for x in infile.read().split("\n")]
@@ -493,6 +507,32 @@ class Orientations:
493
507
  score = np.array(score)
494
508
  detail = np.array(detail)
495
509
 
510
+ if translation.shape[1] == len(header):
511
+ rotation = np.zeros(translation.shape, dtype=np.float32)
512
+ score = np.zeros(translation.shape[0], dtype=np.float32)
513
+ detail = np.zeros(translation.shape[0], dtype=np.float32) - 1
514
+
515
+ if rotation.size == 0 and translation.shape[0] != 0:
516
+ rotation = np.zeros(translation.shape, dtype=np.float32)
517
+
518
+ header_order = tuple(x for x in header if x in ascii_lowercase)
519
+ header_order = zip(header_order, range(len(header_order)))
520
+ sort_order = tuple(
521
+ x[1] for x in sorted(header_order, key=lambda x: x[0], reverse=True)
522
+ )
523
+ translation = translation[..., sort_order]
524
+
525
+ header_order = tuple(
526
+ x
527
+ for x in header
528
+ if "euler" in x and x.replace("euler_", "") in ascii_lowercase
529
+ )
530
+ header_order = zip(header_order, range(len(header_order)))
531
+ sort_order = tuple(
532
+ x[1] for x in sorted(header_order, key=lambda x: x[0], reverse=True)
533
+ )
534
+ rotation = rotation[..., sort_order]
535
+
496
536
  return translation, rotation, score, detail
497
537
 
498
538
  @staticmethod
tme/parser.py CHANGED
@@ -137,8 +137,7 @@ class Parser(ABC):
137
137
 
138
138
  class PDBParser(Parser):
139
139
  """
140
- A Parser subclass for converting PDB file data into a dictionary representation.
141
- This class is specifically designed to work with PDB file format.
140
+ Convert PDB file data into a dictionary representation [1]_.
142
141
 
143
142
  References
144
143
  ----------
@@ -228,8 +227,8 @@ class PDBParser(Parser):
228
227
 
229
228
  class MMCIFParser(Parser):
230
229
  """
231
- A Parser subclass for converting MMCIF file data into a dictionary representation.
232
- This implementation heavily relies on the atomium library [1]_.
230
+ Convert MMCIF file data into a dictionary representation. This implementation
231
+ heavily relies on the atomium library [1]_.
233
232
 
234
233
  References
235
234
  ----------
@@ -5,12 +5,13 @@
5
5
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
6
  """
7
7
 
8
- from typing import Tuple
8
+ from typing import Tuple, List
9
9
 
10
10
  import numpy as np
11
- from numpy.typing import NDArray
12
11
 
13
- from ..backends import backend
12
+ from ..backends import backend as be
13
+ from ..backends import NumpyFFTWBackend
14
+ from ..types import BackendArray, NDArray
14
15
  from ..matching_utils import euler_to_rotationmatrix
15
16
 
16
17
 
@@ -93,18 +94,27 @@ def frequency_grid_at_angle(
93
94
  tilt_shape = compute_tilt_shape(
94
95
  shape=shape, opening_axis=opening_axis, reduce_dim=False
95
96
  )
96
- index_grid = centered_grid(shape=tilt_shape)
97
+
98
+ if angle == 0:
99
+ index_grid = fftfreqn(
100
+ tuple(x for x in tilt_shape if x != 1),
101
+ sampling_rate=1,
102
+ compute_euclidean_norm=True,
103
+ )
104
+
97
105
  if angle != 0:
98
106
  angles = np.zeros(len(shape))
99
107
  angles[tilt_axis] = angle
100
108
  rotation_matrix = euler_to_rotationmatrix(np.roll(angles, opening_axis - 1))
109
+
110
+ index_grid = fftfreqn(tilt_shape, sampling_rate=None)
101
111
  index_grid = np.einsum("ij,j...->i...", rotation_matrix, index_grid)
112
+ norm = np.multiply(sampling_rate, shape).astype(int)
102
113
 
103
- norm = np.divide(1, 2 * sampling_rate * np.divide(shape, 2).astype(int))
114
+ index_grid = np.divide(index_grid.T, norm).T
115
+ index_grid = np.squeeze(index_grid)
116
+ index_grid = np.linalg.norm(index_grid, axis=(0))
104
117
 
105
- index_grid = np.multiply(index_grid.T, norm).T
106
- index_grid = np.squeeze(index_grid)
107
- index_grid = np.linalg.norm(index_grid, axis=(0))
108
118
  return index_grid
109
119
 
110
120
 
@@ -113,9 +123,10 @@ def fftfreqn(
113
123
  sampling_rate: Tuple[float],
114
124
  compute_euclidean_norm: bool = False,
115
125
  shape_is_real_fourier: bool = False,
126
+ return_sparse_grid: bool = False,
116
127
  ) -> NDArray:
117
128
  """
118
- Generate the n-dimensional discrete Fourier Transform sample frequencies.
129
+ Generate the n-dimensional discrete Fourier transform sample frequencies.
119
130
 
120
131
  Parameters:
121
132
  -----------
@@ -133,56 +144,74 @@ def fftfreqn(
133
144
  NDArray
134
145
  The sample frequencies.
135
146
  """
136
- center = backend.astype(backend.divide(shape, 2), backend._int_dtype)
137
-
138
- norm = np.ones(len(shape))
147
+ # There is no real need to have these operations on GPU right now
148
+ temp_backend = NumpyFFTWBackend()
149
+ norm = temp_backend.full(len(shape), fill_value=1)
150
+ center = temp_backend.astype(temp_backend.divide(shape, 2), temp_backend._int_dtype)
139
151
  if sampling_rate is not None:
140
- norm = backend.astype(backend.multiply(shape, sampling_rate), int)
152
+ norm = temp_backend.astype(temp_backend.multiply(shape, sampling_rate), int)
141
153
 
142
154
  if shape_is_real_fourier:
143
- center[-1] = 0
144
- norm[-1] = 1
155
+ center[-1], norm[-1] = 0, 1
145
156
  if sampling_rate is not None:
146
157
  norm[-1] = (shape[-1] - 1) * 2 * sampling_rate
147
158
 
148
- indices = backend.transpose(backend.indices(shape))
149
- indices -= center
150
- indices = backend.divide(indices, norm)
151
- indices = backend.transpose(indices)
159
+ grids = []
160
+ for i, x in enumerate(shape):
161
+ baseline_dims = tuple(1 if i != t else x for t in range(len(shape)))
162
+ grid = (temp_backend.arange(x) - center[i]) / norm[i]
163
+ grids.append(temp_backend.reshape(grid, baseline_dims))
152
164
 
153
165
  if compute_euclidean_norm:
154
- indices = backend.square(indices)
155
- indices = backend.sum(indices, axis=0)
156
- backend.sqrt(indices, out=indices)
166
+ grids = sum(temp_backend.square(x) for x in grids)
167
+ grids = temp_backend.sqrt(grids, out=grids)
168
+ return grids
169
+
170
+ if return_sparse_grid:
171
+ return grids
157
172
 
158
- return indices
173
+ grid_flesh = temp_backend.full(shape, fill_value=1)
174
+ grids = temp_backend.stack(tuple(grid * grid_flesh for grid in grids))
159
175
 
176
+ return grids
160
177
 
161
- def crop_real_fourier(data: NDArray) -> NDArray:
178
+
179
+ def crop_real_fourier(data: BackendArray) -> BackendArray:
162
180
  """
163
181
  Crop the real part of a Fourier transform.
164
182
 
165
183
  Parameters:
166
184
  -----------
167
- data : NDArray
185
+ data : BackendArray
168
186
  The Fourier transformed data.
169
187
 
170
188
  Returns:
171
189
  --------
172
- NDArray
190
+ BackendArray
173
191
  The cropped data.
174
192
  """
175
193
  stop = 1 + (data.shape[-1] // 2)
176
194
  return data[..., :stop]
177
195
 
178
196
 
179
- def shift_fourier(data: NDArray, shape_is_real_fourier: bool = False):
180
- shift = backend.add(
181
- backend.astype(backend.divide(data.shape, 2), int),
182
- backend.mod(data.shape, 2),
183
- )
197
def compute_fourier_shape(
    shape: Tuple[int], shape_is_real_fourier: bool = False
) -> List[int]:
    """
    Compute the shape of the real-input Fourier transform of an array.

    Parameters
    ----------
    shape : tuple of int
        Shape of the real-space array, or of its real Fourier transform if
        ``shape_is_real_fourier`` is True.
    shape_is_real_fourier : bool, optional
        Whether ``shape`` already is a real Fourier transform shape, in which
        case its values are returned unchanged.

    Returns
    -------
    list of int
        ``shape`` with the last axis reduced to ``1 + shape[-1] // 2``, or the
        unchanged values if ``shape_is_real_fourier`` is True. A new list of
        Python ints is returned in both cases.
    """
    fourier_shape = [int(x) for x in shape]
    if not shape_is_real_fourier:
        # Real-input FFTs keep only the non-negative frequencies of the last axis.
        fourier_shape[-1] = 1 + fourier_shape[-1] // 2
    return fourier_shape
205
+
206
+
207
+ def shift_fourier(
208
+ data: BackendArray, shape_is_real_fourier: bool = False
209
+ ) -> BackendArray:
210
+ shape = be.to_backend_array(data.shape)
211
+ shift = be.add(be.divide(shape, 2), be.mod(shape, 2))
212
+ shift = [int(x) for x in shift]
184
213
  if shape_is_real_fourier:
185
214
  shift[-1] = 0
186
215
 
187
- data = backend.roll(data, shift, tuple(i for i in range(len(shift))))
216
+ data = be.roll(data, shift, tuple(i for i in range(len(shift))))
188
217
  return data
@@ -7,7 +7,7 @@
7
7
 
8
8
  from typing import Tuple, Dict
9
9
 
10
- from tme.backends import backend
10
+ from tme.backends import backend as be
11
11
 
12
12
 
13
13
  class Compose:
@@ -42,9 +42,13 @@ class Compose:
42
42
  kwargs.update(meta)
43
43
  ret = transform(**kwargs)
44
44
 
45
+ if "data" not in ret:
46
+ continue
47
+
45
48
  if ret.get("is_multiplicative_filter", False):
46
- backend.multiply(ret["data"], meta["data"], out=ret["data"])
47
- ret["merge"] = None
49
+ prev_data = meta.pop("data")
50
+ ret["data"] = be.multiply(ret["data"], prev_data, out=ret["data"])
51
+ ret["merge"], prev_data = None, None
48
52
 
49
53
  meta = ret
50
54