pytme 0.2.2__cp311-cp311-macosx_14_0_arm64.whl → 0.2.4__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/match_template.py +97 -148
- {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/postprocess.py +20 -29
- pytme-0.2.4.data/scripts/preprocess.py +148 -0
- {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/preprocessor_gui.py +15 -23
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/METADATA +11 -10
- pytme-0.2.4.dist-info/RECORD +119 -0
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/WHEEL +1 -1
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/top_level.txt +1 -0
- pytme-0.2.2.data/scripts/preprocess.py → scripts/eval.py +1 -1
- scripts/match_template.py +97 -148
- scripts/postprocess.py +20 -29
- scripts/preprocess.py +116 -61
- scripts/preprocessor_gui.py +15 -23
- tests/__init__.py +0 -0
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Blurring/blob_width18.npy +0 -0
- tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
- tests/data/Blurring/gaussian_sigma2.npy +0 -0
- tests/data/Blurring/hamming_width6.npy +0 -0
- tests/data/Blurring/kaiserb_width18.npy +0 -0
- tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
- tests/data/Blurring/mean_size5.npy +0 -0
- tests/data/Blurring/ntree_sigma0510.npy +0 -0
- tests/data/Blurring/rank_rank3.npy +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Maps/emd_8621.mrc.gz +0 -0
- tests/data/README.md +2 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Raw/em_map.map +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/data/Structures/1pdj.cif +3339 -0
- tests/data/Structures/1pdj.pdb +1429 -0
- tests/data/Structures/5khe.cif +3685 -0
- tests/data/Structures/5khe.ent +2210 -0
- tests/data/Structures/5khe.pdb +2210 -0
- tests/data/Structures/5uz4.cif +70548 -0
- tests/preprocessing/__init__.py +0 -0
- tests/preprocessing/test_compose.py +76 -0
- tests/preprocessing/test_frequency_filters.py +178 -0
- tests/preprocessing/test_preprocessor.py +136 -0
- tests/preprocessing/test_utils.py +79 -0
- tests/test_analyzer.py +310 -0
- tests/test_backends.py +375 -0
- tests/test_density.py +508 -0
- tests/test_extensions.py +130 -0
- tests/test_matching_cli.py +283 -0
- tests/test_matching_data.py +162 -0
- tests/test_matching_exhaustive.py +162 -0
- tests/test_matching_memory.py +30 -0
- tests/test_matching_optimization.py +276 -0
- tests/test_matching_utils.py +326 -0
- tests/test_orientations.py +173 -0
- tests/test_packaging.py +95 -0
- tests/test_parser.py +33 -0
- tests/test_structure.py +243 -0
- tme/__init__.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +9 -6
- tme/backends/__init__.py +1 -1
- tme/backends/_jax_utils.py +10 -8
- tme/backends/cupy_backend.py +2 -7
- tme/backends/jax_backend.py +35 -20
- tme/backends/npfftw_backend.py +3 -2
- tme/backends/pytorch_backend.py +10 -7
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +26 -12
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/external/bindings.cpp +332 -0
- tme/matching_data.py +33 -24
- tme/matching_exhaustive.py +39 -20
- tme/matching_scores.py +5 -2
- tme/matching_utils.py +8 -2
- tme/orientations.py +26 -9
- tme/preprocessing/_utils.py +14 -14
- tme/preprocessing/composable_filter.py +5 -4
- tme/preprocessing/compose.py +4 -4
- tme/preprocessing/frequency_filters.py +32 -35
- tme/preprocessing/tilt_series.py +210 -148
- tme/preprocessor.py +24 -246
- tme/structure.py +14 -14
- pytme-0.2.2.dist-info/RECORD +0 -74
- tme/matching_memory.py +0 -383
- {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/LICENSE +0 -0
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/entry_points.txt +0 -0
tme/matching_data.py
CHANGED
@@ -171,6 +171,7 @@ class MatchingData:
|
|
171
171
|
np.subtract(right_pad, data_voxels_right),
|
172
172
|
)
|
173
173
|
)
|
174
|
+
# The reflections are later cropped from the scores
|
174
175
|
arr = np.pad(arr, padding, mode="reflect")
|
175
176
|
|
176
177
|
if invert:
|
@@ -449,7 +450,7 @@ class MatchingData:
|
|
449
450
|
template_shape: NDArray,
|
450
451
|
batch_mask: NDArray = None,
|
451
452
|
pad_fourier: bool = False,
|
452
|
-
) -> Tuple[Tuple, Tuple, Tuple]:
|
453
|
+
) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
|
453
454
|
"""
|
454
455
|
Determines an efficient shape for Fourier transforms considering zero-padding.
|
455
456
|
"""
|
@@ -467,31 +468,39 @@ class MatchingData:
|
|
467
468
|
|
468
469
|
pad_shape = np.maximum(target_shape, template_shape)
|
469
470
|
ret = be.compute_convolution_shapes(pad_shape, fourier_pad)
|
470
|
-
|
471
|
+
conv_shape, fast_shape, fast_ft_shape = ret
|
472
|
+
|
473
|
+
template_mod = np.mod(template_shape, 2)
|
471
474
|
if not pad_fourier:
|
472
475
|
fourier_shift = 1 - np.divide(template_shape, 2).astype(int)
|
473
|
-
fourier_shift
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
476
|
+
fourier_shift = np.subtract(fourier_shift, template_mod)
|
477
|
+
|
478
|
+
shape_diff = np.multiply(
|
479
|
+
np.subtract(target_shape, template_shape), 1 - batch_mask
|
480
|
+
)
|
481
|
+
if np.sum(shape_diff < 0):
|
482
|
+
shape_shift = np.divide(shape_diff, 2)
|
483
|
+
offset = np.mod(shape_diff, 2)
|
484
|
+
if pad_fourier:
|
485
|
+
offset = -np.subtract(
|
486
|
+
offset,
|
487
|
+
np.logical_and(np.mod(target_shape, 2) == 0, template_mod == 1),
|
488
|
+
)
|
489
|
+
else:
|
490
|
+
warnings.warn(
|
491
|
+
"Template is larger than target and padding is turned off. Consider "
|
492
|
+
"swapping them or activate padding. Correcting the shift for now."
|
493
|
+
)
|
494
|
+
|
495
|
+
shape_shift = np.add(shape_shift, offset)
|
496
|
+
fourier_shift = np.subtract(fourier_shift, shape_shift).astype(int)
|
491
497
|
|
492
|
-
|
498
|
+
fourier_shift = tuple(fourier_shift.astype(int))
|
499
|
+
return tuple(conv_shape), tuple(fast_shape), tuple(fast_ft_shape), fourier_shift
|
493
500
|
|
494
|
-
def fourier_padding(
|
501
|
+
def fourier_padding(
|
502
|
+
self, pad_fourier: bool = False
|
503
|
+
) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
|
495
504
|
"""
|
496
505
|
Computes efficient shape four Fourier transforms and potential associated shifts.
|
497
506
|
|
@@ -503,8 +512,8 @@ class MatchingData:
|
|
503
512
|
|
504
513
|
Returns
|
505
514
|
-------
|
506
|
-
Tuple[tuple of int, tuple of int, tuple of int]
|
507
|
-
Tuple with
|
515
|
+
Tuple[tuple of int, tuple of int, tuple of int, tuple of int]
|
516
|
+
Tuple with convolution, forward FT, inverse FT shape and corresponding shift.
|
508
517
|
"""
|
509
518
|
return self._fourier_padding(
|
510
519
|
target_shape=be.to_numpy_array(self._output_target_shape),
|
tme/matching_exhaustive.py
CHANGED
@@ -73,35 +73,48 @@ def _setup_template_filter_apply_target_filter(
|
|
73
73
|
if not filter_template and not filter_target:
|
74
74
|
return template_filter
|
75
75
|
|
76
|
-
target_temp = be.topleft_pad(matching_data.target, fast_shape)
|
77
|
-
target_temp_ft = be.zeros(fast_ft_shape, be._complex_dtype)
|
78
|
-
|
79
76
|
inv_mask = be.subtract(1, be.to_backend_array(matching_data._batch_mask))
|
80
77
|
filter_shape = be.multiply(be.to_backend_array(fast_ft_shape), inv_mask)
|
81
78
|
filter_shape = tuple(int(x) if x != 0 else 1 for x in filter_shape)
|
82
|
-
|
83
79
|
fast_shape = be.multiply(be.to_backend_array(fast_shape), inv_mask)
|
84
80
|
fast_shape = tuple(int(x) for x in fast_shape if x != 0)
|
85
81
|
|
82
|
+
fastt_shape, fastt_ft_shape = fast_shape, filter_shape
|
83
|
+
if filter_template and not pad_template_filter:
|
84
|
+
# FFT shape acrobatics for faster filter application
|
85
|
+
# _, fastt_shape, _, _ = matching_data._fourier_padding(
|
86
|
+
# target_shape=be.to_numpy_array(matching_data._template.shape),
|
87
|
+
# template_shape=be.to_numpy_array(
|
88
|
+
# [1 for _ in matching_data._template.shape]
|
89
|
+
# ),
|
90
|
+
# pad_fourier=False,
|
91
|
+
# )
|
92
|
+
fastt_shape = matching_data._template.shape
|
93
|
+
matching_data.template = be.reverse(
|
94
|
+
be.topleft_pad(matching_data.template, fastt_shape)
|
95
|
+
)
|
96
|
+
matching_data.template_mask = be.reverse(
|
97
|
+
be.topleft_pad(matching_data.template_mask, fastt_shape)
|
98
|
+
)
|
99
|
+
matching_data._set_matching_dimension(
|
100
|
+
target_dims=matching_data._target_dims,
|
101
|
+
template_dims=matching_data._template_dims,
|
102
|
+
)
|
103
|
+
fastt_ft_shape = [int(x) for x in matching_data._output_template_shape]
|
104
|
+
fastt_ft_shape[-1] = fastt_ft_shape[-1] // 2 + 1
|
105
|
+
|
106
|
+
target_temp = be.topleft_pad(matching_data.target, fast_shape)
|
107
|
+
target_temp_ft = be.zeros(fast_ft_shape, be._complex_dtype)
|
86
108
|
target_temp_ft = rfftn(target_temp, target_temp_ft)
|
87
109
|
if filter_template:
|
88
|
-
# TODO: Pad to fast shapes and adapt _setup_template_filtering accordingly
|
89
|
-
template_fast_shape, template_filter_shape = fast_shape, filter_shape
|
90
|
-
if not pad_template_filter:
|
91
|
-
template_fast_shape = tuple(int(x) for x in matching_data._template.shape)
|
92
|
-
template_filter_shape = [
|
93
|
-
int(x) for x in matching_data._output_template_shape
|
94
|
-
]
|
95
|
-
template_filter_shape[-1] = template_filter_shape[-1] // 2 + 1
|
96
|
-
|
97
110
|
template_filter = matching_data.template_filter(
|
98
|
-
shape=
|
111
|
+
shape=fastt_shape,
|
99
112
|
return_real_fourier=True,
|
100
113
|
shape_is_real_fourier=False,
|
101
114
|
data_rfft=target_temp_ft,
|
102
115
|
batch_dimension=matching_data._target_dims,
|
103
116
|
)["data"]
|
104
|
-
template_filter = be.reshape(template_filter,
|
117
|
+
template_filter = be.reshape(template_filter, fastt_ft_shape)
|
105
118
|
|
106
119
|
if filter_target:
|
107
120
|
target_filter = matching_data.target_filter(
|
@@ -212,9 +225,13 @@ def scan(
|
|
212
225
|
|
213
226
|
"""
|
214
227
|
matching_data.to_backend()
|
215
|
-
|
216
|
-
|
217
|
-
|
228
|
+
(
|
229
|
+
conv_shape,
|
230
|
+
fast_shape,
|
231
|
+
fast_ft_shape,
|
232
|
+
fourier_shift,
|
233
|
+
) = matching_data.fourier_padding(pad_fourier=pad_fourier)
|
234
|
+
template_shape = matching_data.template.shape
|
218
235
|
|
219
236
|
rfftn, irfftn = be.build_fft(
|
220
237
|
fast_shape=fast_shape,
|
@@ -256,7 +273,8 @@ def scan(
|
|
256
273
|
"fourier_shift": fourier_shift,
|
257
274
|
"convolution_mode": convmode,
|
258
275
|
"targetshape": matching_data.target.shape,
|
259
|
-
"templateshape":
|
276
|
+
"templateshape": template_shape,
|
277
|
+
"convolution_shape": conv_shape,
|
260
278
|
"fast_shape": fast_shape,
|
261
279
|
"indices": getattr(matching_data, "indices", None),
|
262
280
|
"shared_memory_handler": shared_memory_handler,
|
@@ -382,7 +400,8 @@ def scan_subsets(
|
|
382
400
|
The template matching procedure is determined by ``matching_setup`` and
|
383
401
|
``matching_score``, which are unique to each score. In the following,
|
384
402
|
we will be using the `FLCSphericalMask` score, which is composed of
|
385
|
-
:py:meth:`flcSphericalMask_setup` and
|
403
|
+
:py:meth:`tme.matching_scores.flcSphericalMask_setup` and
|
404
|
+
:py:meth:`tme.matching_scores.corr_scoring`
|
386
405
|
|
387
406
|
>>> from tme.matching_exhaustive import MATCHING_EXHAUSTIVE_REGISTER
|
388
407
|
>>> funcs = MATCHING_EXHAUSTIVE_REGISTER.get("FLCSphericalMask")
|
tme/matching_scores.py
CHANGED
@@ -86,7 +86,7 @@ def _setup_template_filtering(
|
|
86
86
|
forward_ft_shape = template_shape
|
87
87
|
inverse_ft_shape = template_filter.shape
|
88
88
|
|
89
|
-
if rfftn is not None and irfftn is not None:
|
89
|
+
if (rfftn is not None and irfftn is not None) or shape_mismatch:
|
90
90
|
rfftn, irfftn = be.build_fft(
|
91
91
|
fast_shape=forward_ft_shape,
|
92
92
|
fast_ft_shape=inverse_ft_shape,
|
@@ -109,7 +109,10 @@ def _setup_template_filtering(
|
|
109
109
|
|
110
110
|
def _apply_filter_shape_mismatch(template, ft_temp, template_filter):
|
111
111
|
_template[:] = template[real_subset]
|
112
|
-
|
112
|
+
template[real_subset] = _apply_template_filter(
|
113
|
+
_template, _ft_temp, template_filter
|
114
|
+
)
|
115
|
+
return template
|
113
116
|
|
114
117
|
return _apply_filter_shape_mismatch
|
115
118
|
|
tme/matching_utils.py
CHANGED
@@ -467,6 +467,7 @@ def apply_convolution_mode(
|
|
467
467
|
convolution_mode: str,
|
468
468
|
s1: Tuple[int],
|
469
469
|
s2: Tuple[int],
|
470
|
+
convolution_shape: Tuple[int] = None,
|
470
471
|
mask_output: bool = False,
|
471
472
|
) -> BackendArray:
|
472
473
|
"""
|
@@ -490,6 +491,8 @@ def apply_convolution_mode(
|
|
490
491
|
Tuple of integers corresponding to shape of convolution array 1.
|
491
492
|
s2 : tuple of ints
|
492
493
|
Tuple of integers corresponding to shape of convolution array 2.
|
494
|
+
convolution_shape : tuple of ints, optional
|
495
|
+
Size of the actually computed convolution. s1 + s2 - 1 by default.
|
493
496
|
mask_output : bool, optional
|
494
497
|
Whether to mask values outside of convolution_mode rather than
|
495
498
|
removing them. Defaults to False.
|
@@ -500,7 +503,9 @@ def apply_convolution_mode(
|
|
500
503
|
The array after applying the convolution mode.
|
501
504
|
"""
|
502
505
|
# Remove padding to next fast Fourier length
|
503
|
-
|
506
|
+
if convolution_shape is None:
|
507
|
+
convolution_shape = [s1[i] + s2[i] - 1 for i in range(len(s1))]
|
508
|
+
arr = arr[tuple(slice(0, x) for x in convolution_shape)]
|
504
509
|
|
505
510
|
if convolution_mode not in ("full", "same", "valid"):
|
506
511
|
raise ValueError("Supported convolution_mode are 'full', 'same' and 'valid'.")
|
@@ -640,6 +645,7 @@ def get_rotation_matrices(
|
|
640
645
|
dets = np.linalg.det(ret)
|
641
646
|
neg_dets = dets < 0
|
642
647
|
ret[neg_dets, :, -1] *= -1
|
648
|
+
ret[0] = np.eye(dim, dtype = ret.dtype)
|
643
649
|
return ret
|
644
650
|
|
645
651
|
|
@@ -1220,7 +1226,7 @@ def scramble_phases(
|
|
1220
1226
|
arr: NDArray,
|
1221
1227
|
noise_proportion: float = 0.5,
|
1222
1228
|
seed: int = 42,
|
1223
|
-
normalize_power: bool =
|
1229
|
+
normalize_power: bool = False,
|
1224
1230
|
) -> NDArray:
|
1225
1231
|
"""
|
1226
1232
|
Perform random phase scrambling of ``arr``.
|
tme/orientations.py
CHANGED
@@ -6,6 +6,7 @@
|
|
6
6
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
7
|
"""
|
8
8
|
import re
|
9
|
+
import warnings
|
9
10
|
from collections import deque
|
10
11
|
from dataclasses import dataclass
|
11
12
|
from string import ascii_lowercase
|
@@ -301,7 +302,7 @@ class Orientations:
|
|
301
302
|
def _to_relion_star(
|
302
303
|
self,
|
303
304
|
filename: str,
|
304
|
-
|
305
|
+
name: str = None,
|
305
306
|
ctf_image: str = None,
|
306
307
|
sampling_rate: float = 1.0,
|
307
308
|
subtomogram_size: int = 0,
|
@@ -313,8 +314,9 @@ class Orientations:
|
|
313
314
|
----------
|
314
315
|
filename : str
|
315
316
|
The name of the file to save the orientations.
|
316
|
-
|
317
|
-
|
317
|
+
name : str or list of str, optional
|
318
|
+
Path to image file the orientation is in reference to. If name is a list
|
319
|
+
its assumed to correspond to _rlnImageName, otherwise _rlnMicrographName.
|
318
320
|
ctf_image : str, optional
|
319
321
|
Path to CTF or wedge mask RELION.
|
320
322
|
sampling_rate : float, optional
|
@@ -352,6 +354,21 @@ class Orientations:
|
|
352
354
|
optics_header = "\n".join(optics_header)
|
353
355
|
optics_data = "\t".join(optics_data)
|
354
356
|
|
357
|
+
if name is None:
|
358
|
+
name = ""
|
359
|
+
warnings.warn(
|
360
|
+
"Consider specifying the name argument. A single string will be "
|
361
|
+
"interpreted as path to the original micrograph, a list of strings "
|
362
|
+
"as path to individual subsets."
|
363
|
+
)
|
364
|
+
|
365
|
+
name_reference = "_rlnImageName"
|
366
|
+
if isinstance(name, str):
|
367
|
+
name = [
|
368
|
+
name,
|
369
|
+
] * self.translations.shape[0]
|
370
|
+
name_reference = "_rlnMicrographName"
|
371
|
+
|
355
372
|
header = [
|
356
373
|
"data_particles",
|
357
374
|
"",
|
@@ -359,7 +376,7 @@ class Orientations:
|
|
359
376
|
"_rlnCoordinateX",
|
360
377
|
"_rlnCoordinateY",
|
361
378
|
"_rlnCoordinateZ",
|
362
|
-
|
379
|
+
name_reference,
|
363
380
|
"_rlnAngleRot",
|
364
381
|
"_rlnAngleTilt",
|
365
382
|
"_rlnAnglePsi",
|
@@ -371,8 +388,6 @@ class Orientations:
|
|
371
388
|
ctf_image = "" if ctf_image is None else f"\t{ctf_image}"
|
372
389
|
|
373
390
|
header = "\n".join(header)
|
374
|
-
name_prefix = "" if name_prefix is None else name_prefix
|
375
|
-
|
376
391
|
with open(filename, mode="w", encoding="utf-8") as ofile:
|
377
392
|
_ = ofile.write(f"{optics_header}\n")
|
378
393
|
_ = ofile.write(f"{optics_data}\n")
|
@@ -387,9 +402,8 @@ class Orientations:
|
|
387
402
|
|
388
403
|
translation_string = "\t".join([str(x) for x in translation][::-1])
|
389
404
|
angle_string = "\t".join([str(x) for x in rotation])
|
390
|
-
name = f"{name_prefix}_{index}.mrc"
|
391
405
|
_ = ofile.write(
|
392
|
-
f"{translation_string}\t{name}\t{angle_string}\t1{ctf_image}\n"
|
406
|
+
f"{translation_string}\t{name[index]}\t{angle_string}\t1{ctf_image}\n"
|
393
407
|
)
|
394
408
|
|
395
409
|
return None
|
@@ -584,7 +598,10 @@ class Orientations:
|
|
584
598
|
cls, filename: str, delimiter: str = None
|
585
599
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
586
600
|
ret = cls._parse_star(filename=filename, delimiter=delimiter)
|
587
|
-
|
601
|
+
|
602
|
+
ret = ret.get("data_particles", None)
|
603
|
+
if ret is None:
|
604
|
+
raise ValueError(f"No data_particles section found in {filename}.")
|
588
605
|
|
589
606
|
translation = np.vstack(
|
590
607
|
(ret["_rlnCoordinateZ"], ret["_rlnCoordinateY"], ret["_rlnCoordinateX"])
|
tme/preprocessing/_utils.py
CHANGED
@@ -19,8 +19,8 @@ def compute_tilt_shape(shape: Tuple[int], opening_axis: int, reduce_dim: bool =
|
|
19
19
|
"""
|
20
20
|
Given an opening_axis, computes the shape of the remaining dimensions.
|
21
21
|
|
22
|
-
Parameters
|
23
|
-
|
22
|
+
Parameters
|
23
|
+
----------
|
24
24
|
shape : Tuple[int]
|
25
25
|
The shape of the input array.
|
26
26
|
opening_axis : int
|
@@ -28,8 +28,8 @@ def compute_tilt_shape(shape: Tuple[int], opening_axis: int, reduce_dim: bool =
|
|
28
28
|
reduce_dim : bool, optional (default=False)
|
29
29
|
Whether to reduce the dimensionality after tilting.
|
30
30
|
|
31
|
-
Returns
|
32
|
-
|
31
|
+
Returns
|
32
|
+
-------
|
33
33
|
Tuple[int]
|
34
34
|
The shape of the array after tilting.
|
35
35
|
"""
|
@@ -44,13 +44,13 @@ def centered_grid(shape: Tuple[int]) -> NDArray:
|
|
44
44
|
"""
|
45
45
|
Generate an integer valued grid centered around size // 2
|
46
46
|
|
47
|
-
Parameters
|
48
|
-
|
47
|
+
Parameters
|
48
|
+
----------
|
49
49
|
shape : Tuple[int]
|
50
50
|
The shape of the grid.
|
51
51
|
|
52
|
-
Returns
|
53
|
-
|
52
|
+
Returns
|
53
|
+
-------
|
54
54
|
NDArray
|
55
55
|
The centered grid.
|
56
56
|
"""
|
@@ -70,8 +70,8 @@ def frequency_grid_at_angle(
|
|
70
70
|
"""
|
71
71
|
Generate a frequency grid from 0 to 1/(2 * sampling_rate) in each axis.
|
72
72
|
|
73
|
-
Parameters
|
74
|
-
|
73
|
+
Parameters
|
74
|
+
----------
|
75
75
|
shape : Tuple[int]
|
76
76
|
The shape of the grid.
|
77
77
|
angle : float
|
@@ -128,8 +128,8 @@ def fftfreqn(
|
|
128
128
|
"""
|
129
129
|
Generate the n-dimensional discrete Fourier transform sample frequencies.
|
130
130
|
|
131
|
-
Parameters
|
132
|
-
|
131
|
+
Parameters
|
132
|
+
----------
|
133
133
|
shape : Tuple[int]
|
134
134
|
The shape of the data.
|
135
135
|
sampling_rate : float or Tuple[float]
|
@@ -180,8 +180,8 @@ def crop_real_fourier(data: BackendArray) -> BackendArray:
|
|
180
180
|
"""
|
181
181
|
Crop the real part of a Fourier transform.
|
182
182
|
|
183
|
-
Parameters
|
184
|
-
|
183
|
+
Parameters
|
184
|
+
----------
|
185
185
|
data : BackendArray
|
186
186
|
The Fourier transformed data.
|
187
187
|
|
@@ -17,15 +17,16 @@ class ComposableFilter(ABC):
|
|
17
17
|
@abstractmethod
|
18
18
|
def __call__(self, *args, **kwargs) -> Dict:
|
19
19
|
"""
|
20
|
-
|
21
|
-
|
20
|
+
|
21
|
+
Parameters
|
22
|
+
----------
|
22
23
|
*args : tuple
|
23
24
|
Variable length argument list.
|
24
25
|
**kwargs : dict
|
25
26
|
Arbitrary keyword arguments.
|
26
27
|
|
27
|
-
Returns
|
28
|
-
|
28
|
+
Returns
|
29
|
+
-------
|
29
30
|
Dict
|
30
31
|
A dictionary representing the result of the filtering operation.
|
31
32
|
"""
|
tme/preprocessing/compose.py
CHANGED
@@ -17,13 +17,13 @@ class Compose:
|
|
17
17
|
This class allows composing multiple transformations together. Each transformation
|
18
18
|
is expected to be a callable that accepts keyword arguments and returns metadata.
|
19
19
|
|
20
|
-
Parameters
|
21
|
-
|
20
|
+
Parameters
|
21
|
+
----------
|
22
22
|
transforms : Tuple[object]
|
23
23
|
A tuple containing transformation objects.
|
24
24
|
|
25
|
-
Returns
|
26
|
-
|
25
|
+
Returns
|
26
|
+
-------
|
27
27
|
Dict
|
28
28
|
Metadata resulting from the composed transformations.
|
29
29
|
|
@@ -18,12 +18,10 @@ from ._utils import fftfreqn, crop_real_fourier, shift_fourier, compute_fourier_
|
|
18
18
|
|
19
19
|
class BandPassFilter:
|
20
20
|
"""
|
21
|
-
|
22
|
-
either by directly specifying the frequency cutoffs (discrete_bandpass) or
|
23
|
-
by using Gaussian functions (gaussian_bandpass).
|
21
|
+
Generate bandpass filters in Fourier space.
|
24
22
|
|
25
|
-
Parameters
|
26
|
-
|
23
|
+
Parameters
|
24
|
+
----------
|
27
25
|
lowpass : float, optional
|
28
26
|
The lowpass cutoff, defaults to None.
|
29
27
|
highpass : float, optional
|
@@ -67,8 +65,8 @@ class BandPassFilter:
|
|
67
65
|
"""
|
68
66
|
Generate a bandpass filter using discrete frequency cutoffs.
|
69
67
|
|
70
|
-
Parameters
|
71
|
-
|
68
|
+
Parameters
|
69
|
+
----------
|
72
70
|
shape : tuple of int
|
73
71
|
The shape of the bandpass filter.
|
74
72
|
lowpass : float
|
@@ -84,8 +82,8 @@ class BandPassFilter:
|
|
84
82
|
**kwargs : dict
|
85
83
|
Additional keyword arguments.
|
86
84
|
|
87
|
-
Returns
|
88
|
-
|
85
|
+
Returns
|
86
|
+
-------
|
89
87
|
BackendArray
|
90
88
|
The bandpass filter in Fourier space.
|
91
89
|
"""
|
@@ -98,17 +96,18 @@ class BandPassFilter:
|
|
98
96
|
shape_is_real_fourier=shape_is_real_fourier,
|
99
97
|
compute_euclidean_norm=True,
|
100
98
|
)
|
101
|
-
|
102
|
-
|
103
|
-
highpass = 1e10 if highpass is None else highpass
|
99
|
+
grid = be.to_backend_array(grid)
|
100
|
+
sampling_rate = be.to_backend_array(sampling_rate)
|
104
101
|
|
105
102
|
highcut = grid.max()
|
106
|
-
if lowpass
|
107
|
-
highcut =
|
108
|
-
lowcut = np.max(2 * sampling_rate / highpass)
|
103
|
+
if lowpass is not None:
|
104
|
+
highcut = be.max(2 * sampling_rate / lowpass)
|
109
105
|
|
110
|
-
|
106
|
+
lowcut = 0
|
107
|
+
if highpass is not None:
|
108
|
+
lowcut = be.max(2 * sampling_rate / highpass)
|
111
109
|
|
110
|
+
bandpass_filter = ((grid <= highcut) & (grid >= lowcut)) * 1.0
|
112
111
|
bandpass_filter = shift_fourier(
|
113
112
|
data=bandpass_filter, shape_is_real_fourier=shape_is_real_fourier
|
114
113
|
)
|
@@ -129,10 +128,10 @@ class BandPassFilter:
|
|
129
128
|
**kwargs,
|
130
129
|
) -> BackendArray:
|
131
130
|
"""
|
132
|
-
Generate a bandpass filter using
|
131
|
+
Generate a bandpass filter using Gaussians.
|
133
132
|
|
134
|
-
Parameters
|
135
|
-
|
133
|
+
Parameters
|
134
|
+
----------
|
136
135
|
shape : tuple of int
|
137
136
|
The shape of the bandpass filter.
|
138
137
|
lowpass : float
|
@@ -148,8 +147,8 @@ class BandPassFilter:
|
|
148
147
|
**kwargs : dict
|
149
148
|
Additional keyword arguments.
|
150
149
|
|
151
|
-
Returns
|
152
|
-
|
150
|
+
Returns
|
151
|
+
-------
|
153
152
|
BackendArray
|
154
153
|
The bandpass filter in Fourier space.
|
155
154
|
"""
|
@@ -216,15 +215,13 @@ class BandPassFilter:
|
|
216
215
|
|
217
216
|
class LinearWhiteningFilter:
|
218
217
|
"""
|
219
|
-
|
220
|
-
apply linear whitening to the Fourier coefficients.
|
218
|
+
Compute Fourier power spectrums and perform whitening.
|
221
219
|
|
222
|
-
Parameters
|
223
|
-
|
220
|
+
Parameters
|
221
|
+
----------
|
224
222
|
**kwargs : Dict, optional
|
225
223
|
Additional keyword arguments.
|
226
224
|
|
227
|
-
|
228
225
|
References
|
229
226
|
----------
|
230
227
|
.. [1] de Teresa-Trueba, I.; Goetz, S. K.; Mattausch, A.; Stojanovska, F.; Zimmerli, C. E.;
|
@@ -243,10 +240,10 @@ class LinearWhiteningFilter:
|
|
243
240
|
data_rfft: BackendArray, n_bins: int = None, batch_dimension: int = None
|
244
241
|
) -> Tuple[BackendArray, BackendArray]:
|
245
242
|
"""
|
246
|
-
Compute the spectrum of the input data.
|
243
|
+
Compute the power spectrum of the input data.
|
247
244
|
|
248
|
-
Parameters
|
249
|
-
|
245
|
+
Parameters
|
246
|
+
----------
|
250
247
|
data_rfft : BackendArray
|
251
248
|
The Fourier transform of the input data.
|
252
249
|
n_bins : int, optional
|
@@ -254,8 +251,8 @@ class LinearWhiteningFilter:
|
|
254
251
|
batch_dimension : int, optional
|
255
252
|
Batch dimension to average over.
|
256
253
|
|
257
|
-
Returns
|
258
|
-
|
254
|
+
Returns
|
255
|
+
-------
|
259
256
|
bins : BackendArray
|
260
257
|
Array containing the bin indices for the spectrum.
|
261
258
|
radial_averages : BackendArray
|
@@ -330,8 +327,8 @@ class LinearWhiteningFilter:
|
|
330
327
|
"""
|
331
328
|
Apply linear whitening to the data and return the result.
|
332
329
|
|
333
|
-
Parameters
|
334
|
-
|
330
|
+
Parameters
|
331
|
+
----------
|
335
332
|
data : BackendArray, optional
|
336
333
|
The input data, defaults to None.
|
337
334
|
data_rfft : BackendArray, optional
|
@@ -345,8 +342,8 @@ class LinearWhiteningFilter:
|
|
345
342
|
**kwargs : Dict
|
346
343
|
Additional keyword arguments.
|
347
344
|
|
348
|
-
Returns
|
349
|
-
|
345
|
+
Returns
|
346
|
+
-------
|
350
347
|
Dict
|
351
348
|
Filter data and associated parameters.
|
352
349
|
"""
|