pytme 0.2.2__cp311-cp311-macosx_14_0_arm64.whl → 0.2.4__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86) hide show
  1. {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/match_template.py +97 -148
  2. {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/postprocess.py +20 -29
  3. pytme-0.2.4.data/scripts/preprocess.py +148 -0
  4. {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/preprocessor_gui.py +15 -23
  5. {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/METADATA +11 -10
  6. pytme-0.2.4.dist-info/RECORD +119 -0
  7. {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/WHEEL +1 -1
  8. {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/top_level.txt +1 -0
  9. pytme-0.2.2.data/scripts/preprocess.py → scripts/eval.py +1 -1
  10. scripts/match_template.py +97 -148
  11. scripts/postprocess.py +20 -29
  12. scripts/preprocess.py +116 -61
  13. scripts/preprocessor_gui.py +15 -23
  14. tests/__init__.py +0 -0
  15. tests/data/.DS_Store +0 -0
  16. tests/data/Blurring/.DS_Store +0 -0
  17. tests/data/Blurring/blob_width18.npy +0 -0
  18. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  19. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  20. tests/data/Blurring/hamming_width6.npy +0 -0
  21. tests/data/Blurring/kaiserb_width18.npy +0 -0
  22. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  23. tests/data/Blurring/mean_size5.npy +0 -0
  24. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  25. tests/data/Blurring/rank_rank3.npy +0 -0
  26. tests/data/Maps/.DS_Store +0 -0
  27. tests/data/Maps/emd_8621.mrc.gz +0 -0
  28. tests/data/README.md +2 -0
  29. tests/data/Raw/.DS_Store +0 -0
  30. tests/data/Raw/em_map.map +0 -0
  31. tests/data/Structures/.DS_Store +0 -0
  32. tests/data/Structures/1pdj.cif +3339 -0
  33. tests/data/Structures/1pdj.pdb +1429 -0
  34. tests/data/Structures/5khe.cif +3685 -0
  35. tests/data/Structures/5khe.ent +2210 -0
  36. tests/data/Structures/5khe.pdb +2210 -0
  37. tests/data/Structures/5uz4.cif +70548 -0
  38. tests/preprocessing/__init__.py +0 -0
  39. tests/preprocessing/test_compose.py +76 -0
  40. tests/preprocessing/test_frequency_filters.py +178 -0
  41. tests/preprocessing/test_preprocessor.py +136 -0
  42. tests/preprocessing/test_utils.py +79 -0
  43. tests/test_analyzer.py +310 -0
  44. tests/test_backends.py +375 -0
  45. tests/test_density.py +508 -0
  46. tests/test_extensions.py +130 -0
  47. tests/test_matching_cli.py +283 -0
  48. tests/test_matching_data.py +162 -0
  49. tests/test_matching_exhaustive.py +162 -0
  50. tests/test_matching_memory.py +30 -0
  51. tests/test_matching_optimization.py +276 -0
  52. tests/test_matching_utils.py +326 -0
  53. tests/test_orientations.py +173 -0
  54. tests/test_packaging.py +95 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_structure.py +243 -0
  57. tme/__init__.py +0 -1
  58. tme/__version__.py +1 -1
  59. tme/analyzer.py +9 -6
  60. tme/backends/__init__.py +1 -1
  61. tme/backends/_jax_utils.py +10 -8
  62. tme/backends/cupy_backend.py +2 -7
  63. tme/backends/jax_backend.py +35 -20
  64. tme/backends/npfftw_backend.py +3 -2
  65. tme/backends/pytorch_backend.py +10 -7
  66. tme/data/scattering_factors.pickle +0 -0
  67. tme/density.py +26 -12
  68. tme/extensions.cpython-311-darwin.so +0 -0
  69. tme/external/bindings.cpp +332 -0
  70. tme/matching_data.py +33 -24
  71. tme/matching_exhaustive.py +39 -20
  72. tme/matching_scores.py +5 -2
  73. tme/matching_utils.py +8 -2
  74. tme/orientations.py +26 -9
  75. tme/preprocessing/_utils.py +14 -14
  76. tme/preprocessing/composable_filter.py +5 -4
  77. tme/preprocessing/compose.py +4 -4
  78. tme/preprocessing/frequency_filters.py +32 -35
  79. tme/preprocessing/tilt_series.py +210 -148
  80. tme/preprocessor.py +24 -246
  81. tme/structure.py +14 -14
  82. pytme-0.2.2.dist-info/RECORD +0 -74
  83. tme/matching_memory.py +0 -383
  84. {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/estimate_ram_usage.py +0 -0
  85. {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/LICENSE +0 -0
  86. {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/entry_points.txt +0 -0
tme/matching_data.py CHANGED
@@ -171,6 +171,7 @@ class MatchingData:
171
171
  np.subtract(right_pad, data_voxels_right),
172
172
  )
173
173
  )
174
+ # The reflections are later cropped from the scores
174
175
  arr = np.pad(arr, padding, mode="reflect")
175
176
 
176
177
  if invert:
@@ -449,7 +450,7 @@ class MatchingData:
449
450
  template_shape: NDArray,
450
451
  batch_mask: NDArray = None,
451
452
  pad_fourier: bool = False,
452
- ) -> Tuple[Tuple, Tuple, Tuple]:
453
+ ) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
453
454
  """
454
455
  Determines an efficient shape for Fourier transforms considering zero-padding.
455
456
  """
@@ -467,31 +468,39 @@ class MatchingData:
467
468
 
468
469
  pad_shape = np.maximum(target_shape, template_shape)
469
470
  ret = be.compute_convolution_shapes(pad_shape, fourier_pad)
470
- convolution_shape, fast_shape, fast_ft_shape = ret
471
+ conv_shape, fast_shape, fast_ft_shape = ret
472
+
473
+ template_mod = np.mod(template_shape, 2)
471
474
  if not pad_fourier:
472
475
  fourier_shift = 1 - np.divide(template_shape, 2).astype(int)
473
- fourier_shift -= np.mod(template_shape, 2)
474
- shape_diff = np.subtract(fast_shape, convolution_shape)
475
- shape_diff = np.divide(shape_diff, 2).astype(int)
476
- shape_diff = np.multiply(shape_diff, 1 - batch_mask)
477
- np.add(fourier_shift, shape_diff, out=fourier_shift)
478
-
479
- fourier_shift = fourier_shift.astype(int)
480
-
481
- shape_diff = np.subtract(target_shape, template_shape)
482
- shape_diff = np.multiply(shape_diff, 1 - batch_mask)
483
- if np.sum(shape_diff < 0) and not pad_fourier:
484
- warnings.warn(
485
- "Target is larger than template and Fourier padding is turned off. "
486
- "This may lead to inaccurate results. Prefer swapping template and target, "
487
- "enable padding or turn off template centering."
488
- )
489
- fourier_shift = np.subtract(fourier_shift, np.divide(shape_diff, 2))
490
- fourier_shift = fourier_shift.astype(int)
476
+ fourier_shift = np.subtract(fourier_shift, template_mod)
477
+
478
+ shape_diff = np.multiply(
479
+ np.subtract(target_shape, template_shape), 1 - batch_mask
480
+ )
481
+ if np.sum(shape_diff < 0):
482
+ shape_shift = np.divide(shape_diff, 2)
483
+ offset = np.mod(shape_diff, 2)
484
+ if pad_fourier:
485
+ offset = -np.subtract(
486
+ offset,
487
+ np.logical_and(np.mod(target_shape, 2) == 0, template_mod == 1),
488
+ )
489
+ else:
490
+ warnings.warn(
491
+ "Template is larger than target and padding is turned off. Consider "
492
+ "swapping them or activate padding. Correcting the shift for now."
493
+ )
494
+
495
+ shape_shift = np.add(shape_shift, offset)
496
+ fourier_shift = np.subtract(fourier_shift, shape_shift).astype(int)
491
497
 
492
- return tuple(fast_shape), tuple(fast_ft_shape), tuple(fourier_shift)
498
+ fourier_shift = tuple(fourier_shift.astype(int))
499
+ return tuple(conv_shape), tuple(fast_shape), tuple(fast_ft_shape), fourier_shift
493
500
 
494
- def fourier_padding(self, pad_fourier: bool = False) -> Tuple[Tuple, Tuple, Tuple]:
501
+ def fourier_padding(
502
+ self, pad_fourier: bool = False
503
+ ) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
495
504
  """
496
505
  Computes efficient shape four Fourier transforms and potential associated shifts.
497
506
 
@@ -503,8 +512,8 @@ class MatchingData:
503
512
 
504
513
  Returns
505
514
  -------
506
- Tuple[tuple of int, tuple of int, tuple of int]
507
- Tuple with real and complex Fourier transform shape, and corresponding shift.
515
+ Tuple[tuple of int, tuple of int, tuple of int, tuple of int]
516
+ Tuple with convolution, forward FT, inverse FT shape and corresponding shift.
508
517
  """
509
518
  return self._fourier_padding(
510
519
  target_shape=be.to_numpy_array(self._output_target_shape),
@@ -73,35 +73,48 @@ def _setup_template_filter_apply_target_filter(
73
73
  if not filter_template and not filter_target:
74
74
  return template_filter
75
75
 
76
- target_temp = be.topleft_pad(matching_data.target, fast_shape)
77
- target_temp_ft = be.zeros(fast_ft_shape, be._complex_dtype)
78
-
79
76
  inv_mask = be.subtract(1, be.to_backend_array(matching_data._batch_mask))
80
77
  filter_shape = be.multiply(be.to_backend_array(fast_ft_shape), inv_mask)
81
78
  filter_shape = tuple(int(x) if x != 0 else 1 for x in filter_shape)
82
-
83
79
  fast_shape = be.multiply(be.to_backend_array(fast_shape), inv_mask)
84
80
  fast_shape = tuple(int(x) for x in fast_shape if x != 0)
85
81
 
82
+ fastt_shape, fastt_ft_shape = fast_shape, filter_shape
83
+ if filter_template and not pad_template_filter:
84
+ # FFT shape acrobatics for faster filter application
85
+ # _, fastt_shape, _, _ = matching_data._fourier_padding(
86
+ # target_shape=be.to_numpy_array(matching_data._template.shape),
87
+ # template_shape=be.to_numpy_array(
88
+ # [1 for _ in matching_data._template.shape]
89
+ # ),
90
+ # pad_fourier=False,
91
+ # )
92
+ fastt_shape = matching_data._template.shape
93
+ matching_data.template = be.reverse(
94
+ be.topleft_pad(matching_data.template, fastt_shape)
95
+ )
96
+ matching_data.template_mask = be.reverse(
97
+ be.topleft_pad(matching_data.template_mask, fastt_shape)
98
+ )
99
+ matching_data._set_matching_dimension(
100
+ target_dims=matching_data._target_dims,
101
+ template_dims=matching_data._template_dims,
102
+ )
103
+ fastt_ft_shape = [int(x) for x in matching_data._output_template_shape]
104
+ fastt_ft_shape[-1] = fastt_ft_shape[-1] // 2 + 1
105
+
106
+ target_temp = be.topleft_pad(matching_data.target, fast_shape)
107
+ target_temp_ft = be.zeros(fast_ft_shape, be._complex_dtype)
86
108
  target_temp_ft = rfftn(target_temp, target_temp_ft)
87
109
  if filter_template:
88
- # TODO: Pad to fast shapes and adapt _setup_template_filtering accordingly
89
- template_fast_shape, template_filter_shape = fast_shape, filter_shape
90
- if not pad_template_filter:
91
- template_fast_shape = tuple(int(x) for x in matching_data._template.shape)
92
- template_filter_shape = [
93
- int(x) for x in matching_data._output_template_shape
94
- ]
95
- template_filter_shape[-1] = template_filter_shape[-1] // 2 + 1
96
-
97
110
  template_filter = matching_data.template_filter(
98
- shape=template_fast_shape,
111
+ shape=fastt_shape,
99
112
  return_real_fourier=True,
100
113
  shape_is_real_fourier=False,
101
114
  data_rfft=target_temp_ft,
102
115
  batch_dimension=matching_data._target_dims,
103
116
  )["data"]
104
- template_filter = be.reshape(template_filter, template_filter_shape)
117
+ template_filter = be.reshape(template_filter, fastt_ft_shape)
105
118
 
106
119
  if filter_target:
107
120
  target_filter = matching_data.target_filter(
@@ -212,9 +225,13 @@ def scan(
212
225
 
213
226
  """
214
227
  matching_data.to_backend()
215
- fast_shape, fast_ft_shape, fourier_shift = matching_data.fourier_padding(
216
- pad_fourier=pad_fourier
217
- )
228
+ (
229
+ conv_shape,
230
+ fast_shape,
231
+ fast_ft_shape,
232
+ fourier_shift,
233
+ ) = matching_data.fourier_padding(pad_fourier=pad_fourier)
234
+ template_shape = matching_data.template.shape
218
235
 
219
236
  rfftn, irfftn = be.build_fft(
220
237
  fast_shape=fast_shape,
@@ -256,7 +273,8 @@ def scan(
256
273
  "fourier_shift": fourier_shift,
257
274
  "convolution_mode": convmode,
258
275
  "targetshape": matching_data.target.shape,
259
- "templateshape": matching_data.template.shape,
276
+ "templateshape": template_shape,
277
+ "convolution_shape": conv_shape,
260
278
  "fast_shape": fast_shape,
261
279
  "indices": getattr(matching_data, "indices", None),
262
280
  "shared_memory_handler": shared_memory_handler,
@@ -382,7 +400,8 @@ def scan_subsets(
382
400
  The template matching procedure is determined by ``matching_setup`` and
383
401
  ``matching_score``, which are unique to each score. In the following,
384
402
  we will be using the `FLCSphericalMask` score, which is composed of
385
- :py:meth:`flcSphericalMask_setup` and :py:meth:`corr_scoring`
403
+ :py:meth:`tme.matching_scores.flcSphericalMask_setup` and
404
+ :py:meth:`tme.matching_scores.corr_scoring`
386
405
 
387
406
  >>> from tme.matching_exhaustive import MATCHING_EXHAUSTIVE_REGISTER
388
407
  >>> funcs = MATCHING_EXHAUSTIVE_REGISTER.get("FLCSphericalMask")
tme/matching_scores.py CHANGED
@@ -86,7 +86,7 @@ def _setup_template_filtering(
86
86
  forward_ft_shape = template_shape
87
87
  inverse_ft_shape = template_filter.shape
88
88
 
89
- if rfftn is not None and irfftn is not None:
89
+ if (rfftn is not None and irfftn is not None) or shape_mismatch:
90
90
  rfftn, irfftn = be.build_fft(
91
91
  fast_shape=forward_ft_shape,
92
92
  fast_ft_shape=inverse_ft_shape,
@@ -109,7 +109,10 @@ def _setup_template_filtering(
109
109
 
110
110
  def _apply_filter_shape_mismatch(template, ft_temp, template_filter):
111
111
  _template[:] = template[real_subset]
112
- return _apply_template_filter(_template, _ft_temp, template_filter)
112
+ template[real_subset] = _apply_template_filter(
113
+ _template, _ft_temp, template_filter
114
+ )
115
+ return template
113
116
 
114
117
  return _apply_filter_shape_mismatch
115
118
 
tme/matching_utils.py CHANGED
@@ -467,6 +467,7 @@ def apply_convolution_mode(
467
467
  convolution_mode: str,
468
468
  s1: Tuple[int],
469
469
  s2: Tuple[int],
470
+ convolution_shape: Tuple[int] = None,
470
471
  mask_output: bool = False,
471
472
  ) -> BackendArray:
472
473
  """
@@ -490,6 +491,8 @@ def apply_convolution_mode(
490
491
  Tuple of integers corresponding to shape of convolution array 1.
491
492
  s2 : tuple of ints
492
493
  Tuple of integers corresponding to shape of convolution array 2.
494
+ convolution_shape : tuple of ints, optional
495
+ Size of the actually computed convolution. s1 + s2 - 1 by default.
493
496
  mask_output : bool, optional
494
497
  Whether to mask values outside of convolution_mode rather than
495
498
  removing them. Defaults to False.
@@ -500,7 +503,9 @@ def apply_convolution_mode(
500
503
  The array after applying the convolution mode.
501
504
  """
502
505
  # Remove padding to next fast Fourier length
503
- arr = arr[tuple(slice(s1[i] + s2[i] - 1) for i in range(len(s1)))]
506
+ if convolution_shape is None:
507
+ convolution_shape = [s1[i] + s2[i] - 1 for i in range(len(s1))]
508
+ arr = arr[tuple(slice(0, x) for x in convolution_shape)]
504
509
 
505
510
  if convolution_mode not in ("full", "same", "valid"):
506
511
  raise ValueError("Supported convolution_mode are 'full', 'same' and 'valid'.")
@@ -640,6 +645,7 @@ def get_rotation_matrices(
640
645
  dets = np.linalg.det(ret)
641
646
  neg_dets = dets < 0
642
647
  ret[neg_dets, :, -1] *= -1
648
+ ret[0] = np.eye(dim, dtype = ret.dtype)
643
649
  return ret
644
650
 
645
651
 
@@ -1220,7 +1226,7 @@ def scramble_phases(
1220
1226
  arr: NDArray,
1221
1227
  noise_proportion: float = 0.5,
1222
1228
  seed: int = 42,
1223
- normalize_power: bool = True,
1229
+ normalize_power: bool = False,
1224
1230
  ) -> NDArray:
1225
1231
  """
1226
1232
  Perform random phase scrambling of ``arr``.
tme/orientations.py CHANGED
@@ -6,6 +6,7 @@
6
6
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
7
  """
8
8
  import re
9
+ import warnings
9
10
  from collections import deque
10
11
  from dataclasses import dataclass
11
12
  from string import ascii_lowercase
@@ -301,7 +302,7 @@ class Orientations:
301
302
  def _to_relion_star(
302
303
  self,
303
304
  filename: str,
304
- name_prefix: str = None,
305
+ name: str = None,
305
306
  ctf_image: str = None,
306
307
  sampling_rate: float = 1.0,
307
308
  subtomogram_size: int = 0,
@@ -313,8 +314,9 @@ class Orientations:
313
314
  ----------
314
315
  filename : str
315
316
  The name of the file to save the orientations.
316
- name_prefix : str, optional
317
- A prefix to add to the image names in the STAR file.
317
+ name : str or list of str, optional
318
+ Path to image file the orientation is in reference to. If name is a list
319
+ its assumed to correspond to _rlnImageName, otherwise _rlnMicrographName.
318
320
  ctf_image : str, optional
319
321
  Path to CTF or wedge mask RELION.
320
322
  sampling_rate : float, optional
@@ -352,6 +354,21 @@ class Orientations:
352
354
  optics_header = "\n".join(optics_header)
353
355
  optics_data = "\t".join(optics_data)
354
356
 
357
+ if name is None:
358
+ name = ""
359
+ warnings.warn(
360
+ "Consider specifying the name argument. A single string will be "
361
+ "interpreted as path to the original micrograph, a list of strings "
362
+ "as path to individual subsets."
363
+ )
364
+
365
+ name_reference = "_rlnImageName"
366
+ if isinstance(name, str):
367
+ name = [
368
+ name,
369
+ ] * self.translations.shape[0]
370
+ name_reference = "_rlnMicrographName"
371
+
355
372
  header = [
356
373
  "data_particles",
357
374
  "",
@@ -359,7 +376,7 @@ class Orientations:
359
376
  "_rlnCoordinateX",
360
377
  "_rlnCoordinateY",
361
378
  "_rlnCoordinateZ",
362
- "_rlnImageName",
379
+ name_reference,
363
380
  "_rlnAngleRot",
364
381
  "_rlnAngleTilt",
365
382
  "_rlnAnglePsi",
@@ -371,8 +388,6 @@ class Orientations:
371
388
  ctf_image = "" if ctf_image is None else f"\t{ctf_image}"
372
389
 
373
390
  header = "\n".join(header)
374
- name_prefix = "" if name_prefix is None else name_prefix
375
-
376
391
  with open(filename, mode="w", encoding="utf-8") as ofile:
377
392
  _ = ofile.write(f"{optics_header}\n")
378
393
  _ = ofile.write(f"{optics_data}\n")
@@ -387,9 +402,8 @@ class Orientations:
387
402
 
388
403
  translation_string = "\t".join([str(x) for x in translation][::-1])
389
404
  angle_string = "\t".join([str(x) for x in rotation])
390
- name = f"{name_prefix}_{index}.mrc"
391
405
  _ = ofile.write(
392
- f"{translation_string}\t{name}\t{angle_string}\t1{ctf_image}\n"
406
+ f"{translation_string}\t{name[index]}\t{angle_string}\t1{ctf_image}\n"
393
407
  )
394
408
 
395
409
  return None
@@ -584,7 +598,10 @@ class Orientations:
584
598
  cls, filename: str, delimiter: str = None
585
599
  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
586
600
  ret = cls._parse_star(filename=filename, delimiter=delimiter)
587
- ret = ret["data_particles"]
601
+
602
+ ret = ret.get("data_particles", None)
603
+ if ret is None:
604
+ raise ValueError(f"No data_particles section found in {filename}.")
588
605
 
589
606
  translation = np.vstack(
590
607
  (ret["_rlnCoordinateZ"], ret["_rlnCoordinateY"], ret["_rlnCoordinateX"])
@@ -19,8 +19,8 @@ def compute_tilt_shape(shape: Tuple[int], opening_axis: int, reduce_dim: bool =
19
19
  """
20
20
  Given an opening_axis, computes the shape of the remaining dimensions.
21
21
 
22
- Parameters:
23
- -----------
22
+ Parameters
23
+ ----------
24
24
  shape : Tuple[int]
25
25
  The shape of the input array.
26
26
  opening_axis : int
@@ -28,8 +28,8 @@ def compute_tilt_shape(shape: Tuple[int], opening_axis: int, reduce_dim: bool =
28
28
  reduce_dim : bool, optional (default=False)
29
29
  Whether to reduce the dimensionality after tilting.
30
30
 
31
- Returns:
32
- --------
31
+ Returns
32
+ -------
33
33
  Tuple[int]
34
34
  The shape of the array after tilting.
35
35
  """
@@ -44,13 +44,13 @@ def centered_grid(shape: Tuple[int]) -> NDArray:
44
44
  """
45
45
  Generate an integer valued grid centered around size // 2
46
46
 
47
- Parameters:
48
- -----------
47
+ Parameters
48
+ ----------
49
49
  shape : Tuple[int]
50
50
  The shape of the grid.
51
51
 
52
- Returns:
53
- --------
52
+ Returns
53
+ -------
54
54
  NDArray
55
55
  The centered grid.
56
56
  """
@@ -70,8 +70,8 @@ def frequency_grid_at_angle(
70
70
  """
71
71
  Generate a frequency grid from 0 to 1/(2 * sampling_rate) in each axis.
72
72
 
73
- Parameters:
74
- -----------
73
+ Parameters
74
+ ----------
75
75
  shape : Tuple[int]
76
76
  The shape of the grid.
77
77
  angle : float
@@ -128,8 +128,8 @@ def fftfreqn(
128
128
  """
129
129
  Generate the n-dimensional discrete Fourier transform sample frequencies.
130
130
 
131
- Parameters:
132
- -----------
131
+ Parameters
132
+ ----------
133
133
  shape : Tuple[int]
134
134
  The shape of the data.
135
135
  sampling_rate : float or Tuple[float]
@@ -180,8 +180,8 @@ def crop_real_fourier(data: BackendArray) -> BackendArray:
180
180
  """
181
181
  Crop the real part of a Fourier transform.
182
182
 
183
- Parameters:
184
- -----------
183
+ Parameters
184
+ ----------
185
185
  data : BackendArray
186
186
  The Fourier transformed data.
187
187
 
@@ -17,15 +17,16 @@ class ComposableFilter(ABC):
17
17
  @abstractmethod
18
18
  def __call__(self, *args, **kwargs) -> Dict:
19
19
  """
20
- Parameters:
21
- -----------
20
+
21
+ Parameters
22
+ ----------
22
23
  *args : tuple
23
24
  Variable length argument list.
24
25
  **kwargs : dict
25
26
  Arbitrary keyword arguments.
26
27
 
27
- Returns:
28
- --------
28
+ Returns
29
+ -------
29
30
  Dict
30
31
  A dictionary representing the result of the filtering operation.
31
32
  """
@@ -17,13 +17,13 @@ class Compose:
17
17
  This class allows composing multiple transformations together. Each transformation
18
18
  is expected to be a callable that accepts keyword arguments and returns metadata.
19
19
 
20
- Parameters:
21
- -----------
20
+ Parameters
21
+ ----------
22
22
  transforms : Tuple[object]
23
23
  A tuple containing transformation objects.
24
24
 
25
- Returns:
26
- --------
25
+ Returns
26
+ -------
27
27
  Dict
28
28
  Metadata resulting from the composed transformations.
29
29
 
@@ -18,12 +18,10 @@ from ._utils import fftfreqn, crop_real_fourier, shift_fourier, compute_fourier_
18
18
 
19
19
  class BandPassFilter:
20
20
  """
21
- This class provides methods to generate bandpass filters in Fourier space,
22
- either by directly specifying the frequency cutoffs (discrete_bandpass) or
23
- by using Gaussian functions (gaussian_bandpass).
21
+ Generate bandpass filters in Fourier space.
24
22
 
25
- Parameters:
26
- -----------
23
+ Parameters
24
+ ----------
27
25
  lowpass : float, optional
28
26
  The lowpass cutoff, defaults to None.
29
27
  highpass : float, optional
@@ -67,8 +65,8 @@ class BandPassFilter:
67
65
  """
68
66
  Generate a bandpass filter using discrete frequency cutoffs.
69
67
 
70
- Parameters:
71
- -----------
68
+ Parameters
69
+ ----------
72
70
  shape : tuple of int
73
71
  The shape of the bandpass filter.
74
72
  lowpass : float
@@ -84,8 +82,8 @@ class BandPassFilter:
84
82
  **kwargs : dict
85
83
  Additional keyword arguments.
86
84
 
87
- Returns:
88
- --------
85
+ Returns
86
+ -------
89
87
  BackendArray
90
88
  The bandpass filter in Fourier space.
91
89
  """
@@ -98,17 +96,18 @@ class BandPassFilter:
98
96
  shape_is_real_fourier=shape_is_real_fourier,
99
97
  compute_euclidean_norm=True,
100
98
  )
101
-
102
- lowpass = 0 if lowpass is None else lowpass
103
- highpass = 1e10 if highpass is None else highpass
99
+ grid = be.to_backend_array(grid)
100
+ sampling_rate = be.to_backend_array(sampling_rate)
104
101
 
105
102
  highcut = grid.max()
106
- if lowpass > 0:
107
- highcut = np.max(2 * sampling_rate / lowpass)
108
- lowcut = np.max(2 * sampling_rate / highpass)
103
+ if lowpass is not None:
104
+ highcut = be.max(2 * sampling_rate / lowpass)
109
105
 
110
- bandpass_filter = ((grid <= highcut) & (grid >= lowcut)) * 1.0
106
+ lowcut = 0
107
+ if highpass is not None:
108
+ lowcut = be.max(2 * sampling_rate / highpass)
111
109
 
110
+ bandpass_filter = ((grid <= highcut) & (grid >= lowcut)) * 1.0
112
111
  bandpass_filter = shift_fourier(
113
112
  data=bandpass_filter, shape_is_real_fourier=shape_is_real_fourier
114
113
  )
@@ -129,10 +128,10 @@ class BandPassFilter:
129
128
  **kwargs,
130
129
  ) -> BackendArray:
131
130
  """
132
- Generate a bandpass filter using Gaussian functions.
131
+ Generate a bandpass filter using Gaussians.
133
132
 
134
- Parameters:
135
- -----------
133
+ Parameters
134
+ ----------
136
135
  shape : tuple of int
137
136
  The shape of the bandpass filter.
138
137
  lowpass : float
@@ -148,8 +147,8 @@ class BandPassFilter:
148
147
  **kwargs : dict
149
148
  Additional keyword arguments.
150
149
 
151
- Returns:
152
- --------
150
+ Returns
151
+ -------
153
152
  BackendArray
154
153
  The bandpass filter in Fourier space.
155
154
  """
@@ -216,15 +215,13 @@ class BandPassFilter:
216
215
 
217
216
  class LinearWhiteningFilter:
218
217
  """
219
- This class provides methods to compute the spectrum of the input data and
220
- apply linear whitening to the Fourier coefficients.
218
+ Compute Fourier power spectrums and perform whitening.
221
219
 
222
- Parameters:
223
- -----------
220
+ Parameters
221
+ ----------
224
222
  **kwargs : Dict, optional
225
223
  Additional keyword arguments.
226
224
 
227
-
228
225
  References
229
226
  ----------
230
227
  .. [1] de Teresa-Trueba, I.; Goetz, S. K.; Mattausch, A.; Stojanovska, F.; Zimmerli, C. E.;
@@ -243,10 +240,10 @@ class LinearWhiteningFilter:
243
240
  data_rfft: BackendArray, n_bins: int = None, batch_dimension: int = None
244
241
  ) -> Tuple[BackendArray, BackendArray]:
245
242
  """
246
- Compute the spectrum of the input data.
243
+ Compute the power spectrum of the input data.
247
244
 
248
- Parameters:
249
- -----------
245
+ Parameters
246
+ ----------
250
247
  data_rfft : BackendArray
251
248
  The Fourier transform of the input data.
252
249
  n_bins : int, optional
@@ -254,8 +251,8 @@ class LinearWhiteningFilter:
254
251
  batch_dimension : int, optional
255
252
  Batch dimension to average over.
256
253
 
257
- Returns:
258
- --------
254
+ Returns
255
+ -------
259
256
  bins : BackendArray
260
257
  Array containing the bin indices for the spectrum.
261
258
  radial_averages : BackendArray
@@ -330,8 +327,8 @@ class LinearWhiteningFilter:
330
327
  """
331
328
  Apply linear whitening to the data and return the result.
332
329
 
333
- Parameters:
334
- -----------
330
+ Parameters
331
+ ----------
335
332
  data : BackendArray, optional
336
333
  The input data, defaults to None.
337
334
  data_rfft : BackendArray, optional
@@ -345,8 +342,8 @@ class LinearWhiteningFilter:
345
342
  **kwargs : Dict
346
343
  Additional keyword arguments.
347
344
 
348
- Returns:
349
- --------
345
+ Returns
346
+ -------
350
347
  Dict
351
348
  Filter data and associated parameters.
352
349
  """