pytme 0.3.1.post2__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2.dev0__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.2.dev0.data/scripts/estimate_ram_usage.py +97 -0
- {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/match_template.py +213 -196
- {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/postprocess.py +40 -78
- {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/preprocess.py +4 -5
- {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/preprocessor_gui.py +49 -103
- {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/pytme_runner.py +46 -69
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/METADATA +2 -1
- pytme-0.3.2.dev0.dist-info/RECORD +136 -0
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +213 -196
- scripts/match_template_devel.py +1339 -0
- scripts/postprocess.py +40 -78
- scripts/preprocess.py +4 -5
- scripts/preprocessor_gui.py +49 -103
- scripts/pytme_runner.py +46 -69
- tests/preprocessing/test_compose.py +31 -30
- tests/preprocessing/test_frequency_filters.py +17 -32
- tests/preprocessing/test_preprocessor.py +0 -19
- tests/preprocessing/test_utils.py +13 -1
- tests/test_analyzer.py +2 -10
- tests/test_backends.py +47 -18
- tests/test_density.py +72 -13
- tests/test_extensions.py +1 -0
- tests/test_matching_cli.py +23 -9
- tests/test_matching_exhaustive.py +5 -5
- tests/test_matching_utils.py +3 -3
- tests/test_orientations.py +12 -0
- tests/test_rotations.py +13 -23
- tests/test_structure.py +1 -7
- tme/__version__.py +1 -1
- tme/analyzer/aggregation.py +47 -16
- tme/analyzer/base.py +34 -0
- tme/analyzer/peaks.py +26 -13
- tme/analyzer/proxy.py +14 -0
- tme/backends/_jax_utils.py +91 -68
- tme/backends/cupy_backend.py +6 -19
- tme/backends/jax_backend.py +103 -98
- tme/backends/matching_backend.py +0 -17
- tme/backends/mlx_backend.py +0 -29
- tme/backends/npfftw_backend.py +100 -97
- tme/backends/pytorch_backend.py +65 -78
- tme/cli.py +2 -2
- tme/density.py +44 -57
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/_utils.py +52 -24
- tme/filters/bandpass.py +99 -105
- tme/filters/compose.py +133 -39
- tme/filters/ctf.py +51 -102
- tme/filters/reconstruction.py +67 -122
- tme/filters/wedge.py +296 -325
- tme/filters/whitening.py +39 -75
- tme/mask.py +2 -2
- tme/matching_data.py +87 -15
- tme/matching_exhaustive.py +70 -120
- tme/matching_optimization.py +9 -63
- tme/matching_scores.py +261 -100
- tme/matching_utils.py +150 -91
- tme/memory.py +1 -0
- tme/orientations.py +17 -3
- tme/preprocessor.py +0 -239
- tme/rotations.py +102 -70
- tme/structure.py +601 -631
- tme/types.py +1 -0
- pytme-0.3.1.post2.dist-info/RECORD +0 -133
- {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/WHEEL +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/entry_points.txt +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/top_level.txt +0 -0
scripts/pytme_runner.py
CHANGED
@@ -8,7 +8,7 @@ import subprocess
|
|
8
8
|
from abc import ABC, abstractmethod
|
9
9
|
|
10
10
|
from pathlib import Path
|
11
|
-
from dataclasses import dataclass
|
11
|
+
from dataclasses import dataclass, field, fields
|
12
12
|
from typing import Dict, List, Optional, Any
|
13
13
|
|
14
14
|
from tme.backends import backend as be
|
@@ -149,7 +149,7 @@ class AnalysisDatasetDiscovery(DatasetDiscovery):
|
|
149
149
|
|
150
150
|
#: Glob pattern for TM pickle files, e.g., "/data/results/*.pickle"
|
151
151
|
input_patterns: List[str]
|
152
|
-
#: List of glob patterns for background files, e.g., ["
|
152
|
+
#: List of glob patterns for background files, e.g., ["bg1/*.pickle", "bg2/*."]
|
153
153
|
background_patterns: List[str] = None
|
154
154
|
#: Target masks, e.g., "/data/masks/*.mrc"
|
155
155
|
mask_patterns: Optional[str] = None
|
@@ -230,46 +230,47 @@ class TMParameters:
|
|
230
230
|
axis_sampling: Optional[float] = None
|
231
231
|
axis_symmetry: int = 1
|
232
232
|
cone_axis: int = 2
|
233
|
-
invert_cone: bool = False
|
234
|
-
no_use_optimized_set: bool = False
|
233
|
+
invert_cone: bool = field(default=False, metadata={"flag": True})
|
234
|
+
no_use_optimized_set: bool = field(default=False, metadata={"flag": True})
|
235
235
|
|
236
236
|
# Microscope parameters
|
237
237
|
acceleration_voltage: float = 300.0 # kV
|
238
238
|
spherical_aberration: float = 2.7e7 # Å
|
239
239
|
amplitude_contrast: float = 0.07
|
240
240
|
defocus: Optional[float] = None # Å
|
241
|
-
phase_shift: float = 0.0 #
|
241
|
+
phase_shift: float = 0.0 # degrees
|
242
242
|
|
243
243
|
# Processing options
|
244
244
|
lowpass: Optional[float] = None # Å
|
245
245
|
highpass: Optional[float] = None # Å
|
246
246
|
pass_format: str = "sampling_rate" # "sampling_rate", "voxel", "frequency"
|
247
|
-
no_pass_smooth: bool = True
|
247
|
+
no_pass_smooth: bool = field(default=True, metadata={"flag": False})
|
248
248
|
interpolation_order: int = 3
|
249
249
|
score_threshold: float = 0.0
|
250
250
|
score: str = "FLCSphericalMask"
|
251
|
+
background_correction: Optional[str] = None
|
251
252
|
|
252
253
|
# Weighting and correction
|
253
254
|
tilt_weighting: Optional[str] = None # "angle", "relion", "grigorieff"
|
254
255
|
wedge_axes: str = "2,0"
|
255
|
-
whiten_spectrum: bool = False
|
256
|
-
scramble_phases: bool = False
|
257
|
-
invert_target_contrast: bool = False
|
256
|
+
whiten_spectrum: bool = field(default=False, metadata={"flag": True})
|
257
|
+
scramble_phases: bool = field(default=False, metadata={"flag": True})
|
258
|
+
invert_target_contrast: bool = field(default=False, metadata={"flag": True})
|
258
259
|
|
259
260
|
# CTF parameters
|
260
261
|
ctf_file: Optional[Path] = None
|
261
|
-
no_flip_phase: bool = True
|
262
|
-
correct_defocus_gradient: bool = False
|
262
|
+
no_flip_phase: bool = field(default=True, metadata={"flag": False})
|
263
|
+
correct_defocus_gradient: bool = field(default=False, metadata={"flag": True})
|
263
264
|
|
264
265
|
# Performance options
|
265
|
-
centering: bool = False
|
266
|
-
pad_edges: bool = False
|
267
|
-
pad_filter: bool = False
|
268
|
-
use_mixed_precision: bool = False
|
269
|
-
use_memmap: bool = False
|
266
|
+
centering: bool = field(default=False, metadata={"flag": True})
|
267
|
+
pad_edges: bool = field(default=False, metadata={"flag": True})
|
268
|
+
pad_filter: bool = field(default=False, metadata={"flag": True})
|
269
|
+
use_mixed_precision: bool = field(default=False, metadata={"flag": True})
|
270
|
+
use_memmap: bool = field(default=False, metadata={"flag": True})
|
270
271
|
|
271
272
|
# Analysis options
|
272
|
-
peak_calling: bool = False
|
273
|
+
peak_calling: bool = field(default=False, metadata={"flag": True})
|
273
274
|
num_peaks: int = 1000
|
274
275
|
|
275
276
|
# Backend selection
|
@@ -279,7 +280,7 @@ class TMParameters:
|
|
279
280
|
# Reconstruction
|
280
281
|
reconstruction_filter: str = "ramp"
|
281
282
|
reconstruction_interpolation_order: int = 1
|
282
|
-
no_filter_target: bool = False
|
283
|
+
no_filter_target: bool = field(default=False, metadata={"flag": True})
|
283
284
|
|
284
285
|
def __post_init__(self):
|
285
286
|
"""Validate parameters and convert units."""
|
@@ -337,6 +338,9 @@ class TMParameters:
|
|
337
338
|
args["ctf-file"] = str(files.metadata)
|
338
339
|
args["tilt-angles"] = str(files.metadata)
|
339
340
|
|
341
|
+
if self.background_correction:
|
342
|
+
args["background-correction"] = self.background_correction
|
343
|
+
|
340
344
|
# Optional parameters
|
341
345
|
if self.lowpass:
|
342
346
|
args["lowpass"] = self.lowpass
|
@@ -378,43 +382,26 @@ class TMParameters:
|
|
378
382
|
return {k: v for k, v in args.items() if v is not None}
|
379
383
|
|
380
384
|
def get_flags(self) -> List[str]:
|
381
|
-
"""Get boolean flags for pyTME command."""
|
382
385
|
flags = []
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
if self.use_mixed_precision:
|
398
|
-
flags.append("use-mixed-precision")
|
399
|
-
if self.use_memmap:
|
400
|
-
flags.append("use-memmap")
|
401
|
-
if self.peak_calling:
|
402
|
-
flags.append("peak-calling")
|
403
|
-
if not self.no_flip_phase:
|
404
|
-
flags.append("no-flip-phase")
|
405
|
-
if self.correct_defocus_gradient:
|
406
|
-
flags.append("correct-defocus-gradient")
|
407
|
-
if self.invert_cone:
|
408
|
-
flags.append("invert-cone")
|
409
|
-
if self.no_use_optimized_set:
|
410
|
-
flags.append("no-use-optimized-set")
|
411
|
-
if self.no_filter_target:
|
412
|
-
flags.append("no-filter-target")
|
386
|
+
|
387
|
+
for field_info in fields(self):
|
388
|
+
flag_meta = field_info.metadata.get("flag")
|
389
|
+
if flag_meta is None:
|
390
|
+
continue
|
391
|
+
|
392
|
+
value = getattr(self, field_info.name)
|
393
|
+
if not isinstance(value, bool):
|
394
|
+
continue
|
395
|
+
|
396
|
+
flag_name = field_info.name.replace("_", "-")
|
397
|
+
if (flag_meta is True and value) or (flag_meta is False and not value):
|
398
|
+
flags.append(flag_name)
|
399
|
+
|
413
400
|
return flags
|
414
401
|
|
415
402
|
|
416
403
|
@dataclass
|
417
|
-
class AnalysisParameters:
|
404
|
+
class AnalysisParameters(TMParameters):
|
418
405
|
"""Parameters for template matching analysis and peak calling."""
|
419
406
|
|
420
407
|
# Peak calling
|
@@ -424,13 +411,12 @@ class AnalysisParameters:
|
|
424
411
|
max_score: Optional[float] = None
|
425
412
|
min_distance: int = 5
|
426
413
|
min_boundary_distance: int = 0
|
427
|
-
mask_edges: bool = False
|
414
|
+
mask_edges: bool = field(default=False, metadata={"flag": True})
|
428
415
|
n_false_positives: Optional[int] = None
|
429
416
|
|
430
417
|
# Output format
|
431
418
|
output_format: str = "relion4"
|
432
419
|
output_directory: Optional[str] = None
|
433
|
-
angles_clockwise: bool = False
|
434
420
|
|
435
421
|
# Advanced options
|
436
422
|
extraction_box_size: Optional[int] = None
|
@@ -468,15 +454,6 @@ class AnalysisParameters:
|
|
468
454
|
|
469
455
|
return {k: v for k, v in args.items() if v is not None}
|
470
456
|
|
471
|
-
def get_flags(self) -> List[str]:
|
472
|
-
"""Get boolean flags for analyze_template_matching command."""
|
473
|
-
flags = []
|
474
|
-
if self.mask_edges:
|
475
|
-
flags.append("mask-edges")
|
476
|
-
if self.angles_clockwise:
|
477
|
-
flags.append("angles-clockwise")
|
478
|
-
return flags
|
479
|
-
|
480
457
|
|
481
458
|
@dataclass
|
482
459
|
class ComputeResources:
|
@@ -592,12 +569,10 @@ class ExecutionBackend(ABC):
|
|
592
569
|
@abstractmethod
|
593
570
|
def submit_job(self, task) -> str:
|
594
571
|
"""Submit a single job and return job ID or status."""
|
595
|
-
pass
|
596
572
|
|
597
573
|
@abstractmethod
|
598
574
|
def submit_jobs(self, tasks: List) -> List[str]:
|
599
575
|
"""Submit multiple jobs and return list of job IDs."""
|
600
|
-
pass
|
601
576
|
|
602
577
|
|
603
578
|
class SlurmBackend(ExecutionBackend):
|
@@ -841,6 +816,13 @@ def parse_args():
|
|
841
816
|
default="FLCSphericalMask",
|
842
817
|
help="Template matching scoring function. Use FLC if mask is not spherical.",
|
843
818
|
)
|
819
|
+
tm_group.add_argument(
|
820
|
+
"--background-correction",
|
821
|
+
choices=["phase-scrambling"],
|
822
|
+
required=False,
|
823
|
+
help="Transform cross-correlation into SNR-like values using a given method: "
|
824
|
+
"'phase-scrambling' uses a phase-scrambled template as background",
|
825
|
+
)
|
844
826
|
tm_group.add_argument(
|
845
827
|
"--score-threshold", type=float, default=0.0, help="Minimum score threshold"
|
846
828
|
)
|
@@ -1006,11 +988,6 @@ def parse_args():
|
|
1006
988
|
default="relion4",
|
1007
989
|
help="Output format for analysis results",
|
1008
990
|
)
|
1009
|
-
output_group.add_argument(
|
1010
|
-
"--angles-clockwise",
|
1011
|
-
action="store_true",
|
1012
|
-
help="Report Euler angles in clockwise format expected by RELION",
|
1013
|
-
)
|
1014
991
|
|
1015
992
|
advanced_group = analysis_parser.add_argument_group("Advanced Options")
|
1016
993
|
advanced_group.add_argument(
|
@@ -1077,6 +1054,7 @@ def run_matching(args, resources):
|
|
1077
1054
|
backend=args.backend,
|
1078
1055
|
whiten_spectrum=args.whiten_spectrum,
|
1079
1056
|
scramble_phases=args.scramble_phases,
|
1057
|
+
background_correction=args.background_correction,
|
1080
1058
|
)
|
1081
1059
|
print_params = params.to_command_args(files[0], "")
|
1082
1060
|
_ = print_params.pop("target")
|
@@ -1132,7 +1110,6 @@ def run_analysis(args, resources):
|
|
1132
1110
|
mask_edges=args.mask_edges,
|
1133
1111
|
n_false_positives=args.n_false_positives,
|
1134
1112
|
output_format=args.output_format,
|
1135
|
-
angles_clockwise=args.angles_clockwise,
|
1136
1113
|
extraction_box_size=args.extraction_box_size,
|
1137
1114
|
)
|
1138
1115
|
print_params = params.to_command_args(files[0], Path(""))
|
@@ -1,28 +1,38 @@
|
|
1
1
|
import pytest
|
2
2
|
|
3
|
-
from tme.filters import Compose
|
3
|
+
from tme.filters import Compose, ComposableFilter
|
4
4
|
from tme.backends import backend as be
|
5
5
|
|
6
6
|
|
7
|
-
|
8
|
-
|
7
|
+
class MockFilter(ComposableFilter):
|
8
|
+
def _evaluate(self, *args, **kwargs):
|
9
|
+
return {"data": be.ones((10, 10)), "shape": (10, 10)}
|
10
|
+
|
11
|
+
|
12
|
+
class MockFilterNoMult(ComposableFilter):
|
13
|
+
def _evaluate(self, *args, **kwargs):
|
14
|
+
return {
|
15
|
+
"data": be.ones((10, 10)) * 3,
|
16
|
+
"shape": (10, 10),
|
17
|
+
"is_multiplicative_filter": False,
|
18
|
+
}
|
9
19
|
|
10
20
|
|
11
|
-
|
12
|
-
|
21
|
+
mock_transform = MockFilter()
|
22
|
+
mock_transform_nomult = MockFilterNoMult()
|
13
23
|
|
14
24
|
|
15
|
-
def
|
16
|
-
return {"
|
25
|
+
def mock_transform_error(**kwargs):
|
26
|
+
return {"data": be.ones((10, 10)), "is_multiplicative_filter": True}
|
17
27
|
|
18
28
|
|
19
29
|
class TestCompose:
|
20
30
|
@pytest.fixture
|
21
31
|
def compose_instance(self):
|
22
|
-
return Compose((
|
32
|
+
return Compose((mock_transform, mock_transform, mock_transform))
|
23
33
|
|
24
34
|
def test_init(self):
|
25
|
-
transforms = (
|
35
|
+
transforms = (mock_transform, mock_transform)
|
26
36
|
compose = Compose(transforms)
|
27
37
|
assert compose.transforms == transforms
|
28
38
|
|
@@ -32,45 +42,36 @@ class TestCompose:
|
|
32
42
|
assert result == {}
|
33
43
|
|
34
44
|
def test_call_single_transform(self):
|
35
|
-
compose = Compose((
|
36
|
-
result = compose()
|
45
|
+
compose = Compose((mock_transform,))
|
46
|
+
result = compose(return_real_fourier=False)
|
37
47
|
assert "data" in result
|
38
|
-
assert result.get("is_multiplicative_filter", False)
|
39
48
|
assert be.allclose(result["data"], be.ones((10, 10)))
|
40
49
|
|
41
50
|
def test_call_multiple_transforms(self, compose_instance):
|
42
|
-
result = compose_instance()
|
51
|
+
result = compose_instance(return_real_fourier=False)
|
43
52
|
assert "data" in result
|
44
|
-
assert "
|
45
|
-
assert be.allclose(result["data"], be.ones((10, 10)) * 2)
|
53
|
+
assert be.allclose(result["data"], be.ones((10, 10)))
|
46
54
|
|
47
55
|
def test_multiplicative_filter_composition(self):
|
48
|
-
compose = Compose((
|
49
|
-
result = compose()
|
56
|
+
compose = Compose((mock_transform, mock_transform))
|
57
|
+
result = compose(return_real_fourier=False)
|
50
58
|
assert "data" in result
|
51
|
-
assert be.allclose(result["data"], be.ones((10, 10))
|
59
|
+
assert be.allclose(result["data"], be.ones((10, 10)))
|
52
60
|
|
53
61
|
@pytest.mark.parametrize(
|
54
62
|
"kwargs", [{}, {"extra_param": "test"}, {"data": be.zeros((5, 5))}]
|
55
63
|
)
|
56
64
|
def test_call_with_kwargs(self, compose_instance, kwargs):
|
57
|
-
result = compose_instance(**kwargs)
|
65
|
+
result = compose_instance(**kwargs, return_real_fourier=False)
|
58
66
|
assert "data" in result
|
59
67
|
assert "extra_info" not in result
|
60
68
|
|
61
69
|
def test_non_multiplicative_filter(self):
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
compose = Compose((mock_transform1, non_mult_transform))
|
66
|
-
result = compose()
|
70
|
+
compose = Compose((mock_transform, mock_transform_nomult))
|
71
|
+
result = compose(return_real_fourier=False)
|
67
72
|
assert "data" in result
|
68
73
|
assert be.allclose(result["data"], be.ones((10, 10)) * 3)
|
69
74
|
|
70
75
|
def test_error_handling(self):
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
compose = Compose((mock_transform1, error_transform))
|
75
|
-
with pytest.raises(ValueError, match="Test error"):
|
76
|
-
compose()
|
76
|
+
with pytest.raises(ValueError):
|
77
|
+
Compose((mock_transform, mock_transform_error))()
|
@@ -71,52 +71,43 @@ class TestBandPassFilter:
|
|
71
71
|
shape_is_real_fourier: bool,
|
72
72
|
):
|
73
73
|
band_pass_filter.use_gaussian = use_gaussian
|
74
|
-
band_pass_filter.return_real_fourier = return_real_fourier
|
75
74
|
band_pass_filter.shape_is_real_fourier = shape_is_real_fourier
|
76
75
|
|
77
76
|
result = band_pass_filter(shape=(10, 10), lowpass=0.2, highpass=0.8)
|
78
77
|
|
79
78
|
assert isinstance(result, dict)
|
80
79
|
assert "data" in result
|
81
|
-
assert "is_multiplicative_filter" in result
|
82
80
|
assert isinstance(result["data"], type(be.ones((1,))))
|
83
|
-
assert result["is_multiplicative_filter"] is True
|
84
81
|
|
85
82
|
def test_default_values(self, band_pass_filter: BandPassReconstructed):
|
86
83
|
assert band_pass_filter.lowpass is None
|
87
84
|
assert band_pass_filter.highpass is None
|
88
85
|
assert band_pass_filter.sampling_rate == 1
|
89
86
|
assert band_pass_filter.use_gaussian is True
|
90
|
-
assert band_pass_filter.return_real_fourier is False
|
91
87
|
|
92
88
|
@pytest.mark.parametrize("shape", ((10, 10), (20, 20, 20), (30, 30)))
|
93
89
|
def test_return_real_fourier(self, shape: Tuple[int]):
|
94
|
-
bpf = BandPassReconstructed(
|
90
|
+
bpf = BandPassReconstructed()
|
95
91
|
result = bpf(shape=shape, lowpass=0.2, highpass=0.8)
|
96
|
-
|
97
|
-
assert result["data"].shape == expected_shape
|
92
|
+
assert result["data"].shape == shape
|
98
93
|
|
99
94
|
|
100
95
|
class TestLinearWhiteningFilter:
|
101
96
|
@pytest.mark.parametrize(
|
102
|
-
"shape, n_bins
|
97
|
+
"shape, n_bins",
|
103
98
|
[
|
104
|
-
((10, 10), None
|
105
|
-
((20, 20, 20), 15
|
106
|
-
((30, 30, 30), 20
|
107
|
-
((40, 40, 40, 40), 25
|
99
|
+
((10, 10), None),
|
100
|
+
((20, 20, 20), 15),
|
101
|
+
((30, 30, 30), 20),
|
102
|
+
((40, 40, 40, 40), 25),
|
108
103
|
],
|
109
104
|
)
|
110
|
-
def test_compute_spectrum(
|
111
|
-
self, shape: Tuple[int], n_bins: int, batch_dimension: int
|
112
|
-
):
|
105
|
+
def test_compute_spectrum(self, shape: Tuple[int], n_bins: int):
|
113
106
|
data_rfft = be.fft.rfftn(be.random.random(shape))
|
114
107
|
bins, radial_averages = LinearWhiteningFilter._compute_spectrum(
|
115
|
-
data_rfft, n_bins
|
116
|
-
)
|
117
|
-
data_shape = tuple(
|
118
|
-
int(x) for i, x in enumerate(data_rfft.shape) if i != batch_dimension
|
108
|
+
data_rfft, n_bins
|
119
109
|
)
|
110
|
+
data_shape = tuple(int(x) for i, x in enumerate(data_rfft.shape))
|
120
111
|
|
121
112
|
assert isinstance(bins, np.ndarray)
|
122
113
|
assert isinstance(radial_averages, np.ndarray)
|
@@ -140,9 +131,8 @@ class TestLinearWhiteningFilter:
|
|
140
131
|
@pytest.mark.parametrize(
|
141
132
|
"shape, n_bins, batch_dimension, order",
|
142
133
|
[
|
143
|
-
((10, 10), None,
|
144
|
-
((20, 20, 20), 15, 0,
|
145
|
-
((30, 30, 30), 20, 1, None),
|
134
|
+
((10, 10), None, (), 1),
|
135
|
+
((20, 20, 20), 15, 0, 1),
|
146
136
|
],
|
147
137
|
)
|
148
138
|
def test_call_method(
|
@@ -154,21 +144,17 @@ class TestLinearWhiteningFilter:
|
|
154
144
|
):
|
155
145
|
data = be.random.random(shape)
|
156
146
|
result = LinearWhiteningFilter()(
|
157
|
-
shape=shape,
|
158
|
-
|
147
|
+
shape=tuple(x for i, x in enumerate(shape) if i != batch_dimension),
|
148
|
+
data_rfft=np.fft.rfftn(data),
|
159
149
|
n_bins=n_bins,
|
160
|
-
|
150
|
+
axes=batch_dimension,
|
161
151
|
order=order,
|
162
152
|
)
|
163
153
|
|
164
154
|
assert isinstance(result, dict)
|
165
155
|
assert result.get("data", False) is not False
|
166
|
-
assert result.get("is_multiplicative_filter", False)
|
167
156
|
assert isinstance(result["data"], type(be.ones((1,))))
|
168
|
-
|
169
|
-
int(x) for i, x in enumerate(data.shape) if i != batch_dimension
|
170
|
-
)
|
171
|
-
assert result["data"].shape == tuple(compute_fourier_shape(data_shape, False))
|
157
|
+
assert result["data"].shape == shape
|
172
158
|
|
173
159
|
def test_call_method_with_data_rfft(self):
|
174
160
|
shape = (30, 30, 30)
|
@@ -179,14 +165,13 @@ class TestLinearWhiteningFilter:
|
|
179
165
|
|
180
166
|
assert isinstance(result, dict)
|
181
167
|
assert result.get("data", False) is not False
|
182
|
-
assert result.get("is_multiplicative_filter", False)
|
183
168
|
assert isinstance(result["data"], type(be.ones((1,))))
|
184
169
|
assert result["data"].shape == data_rfft.shape
|
185
170
|
|
186
171
|
@pytest.mark.parametrize("shape", [(10, 10), (20, 20, 20), (30, 30, 30)])
|
187
172
|
def test_filter_mask_range(self, shape: Tuple[int]):
|
188
173
|
data = be.random.random(shape)
|
189
|
-
result = LinearWhiteningFilter()(shape=shape,
|
174
|
+
result = LinearWhiteningFilter()(shape=shape, data_rfft=np.fft.rfftn(data))
|
190
175
|
|
191
176
|
filter_mask = result["data"]
|
192
177
|
assert np.all(filter_mask >= 0) and np.all(filter_mask <= 1)
|
@@ -1,5 +1,4 @@
|
|
1
1
|
import pytest
|
2
|
-
import numpy as np
|
3
2
|
|
4
3
|
from tme import Density, Structure, Preprocessor
|
5
4
|
|
@@ -49,15 +48,6 @@ class TestPreprocessor:
|
|
49
48
|
high_sigma=high_sigma,
|
50
49
|
)
|
51
50
|
|
52
|
-
@pytest.mark.parametrize("smallest_size,largest_size", [(1, 10), (2, 20)])
|
53
|
-
def test_bandpass_filter(self, smallest_size, largest_size):
|
54
|
-
_ = self.preprocessor.bandpass_filter(
|
55
|
-
template=self.structure_density.data,
|
56
|
-
lowpass=smallest_size,
|
57
|
-
highpass=largest_size,
|
58
|
-
sampling_rate=1,
|
59
|
-
)
|
60
|
-
|
61
51
|
@pytest.mark.parametrize("lbd,sigma_range", [(1, (2, 4)), (20, (1, 6))])
|
62
52
|
def test_local_gaussian_alignment_filter(self, lbd, sigma_range):
|
63
53
|
_ = self.preprocessor.local_gaussian_alignment_filter(
|
@@ -125,12 +115,3 @@ class TestPreprocessor:
|
|
125
115
|
template=self.structure_density.data,
|
126
116
|
rank=rank,
|
127
117
|
)
|
128
|
-
|
129
|
-
@pytest.mark.parametrize("infinite_plane", [False, True])
|
130
|
-
def test_continuous_wedge_mask(self, infinite_plane):
|
131
|
-
_ = self.preprocessor.continuous_wedge_mask(
|
132
|
-
start_tilt=50,
|
133
|
-
stop_tilt=-40,
|
134
|
-
shape=(50, 50, 50),
|
135
|
-
infinite_plane=infinite_plane,
|
136
|
-
)
|
@@ -74,10 +74,22 @@ class TestPreprocessUtils:
|
|
74
74
|
def test_fftfreqn(self, n, sampling_rate):
|
75
75
|
assert np.allclose(
|
76
76
|
fftfreqn(
|
77
|
-
shape=(n,),
|
77
|
+
shape=(n,),
|
78
|
+
sampling_rate=sampling_rate,
|
79
|
+
compute_euclidean_norm=True,
|
80
|
+
fftshift=True,
|
78
81
|
),
|
79
82
|
np.abs(np.fft.ifftshift(np.fft.fftfreq(n=n, d=sampling_rate))),
|
80
83
|
)
|
84
|
+
assert np.allclose(
|
85
|
+
fftfreqn(
|
86
|
+
shape=(n,),
|
87
|
+
sampling_rate=sampling_rate,
|
88
|
+
compute_euclidean_norm=True,
|
89
|
+
fftshift=False,
|
90
|
+
),
|
91
|
+
np.abs(np.fft.fftfreq(n=n, d=sampling_rate)),
|
92
|
+
)
|
81
93
|
|
82
94
|
@pytest.mark.parametrize("shape", ((10,), (10, 15), (10, 15, 30)))
|
83
95
|
def test_crop_real_fourier(self, shape):
|
tests/test_analyzer.py
CHANGED
@@ -109,13 +109,11 @@ class TestRecursiveMasking:
|
|
109
109
|
|
110
110
|
@pytest.mark.parametrize("num_peaks", (1, 100))
|
111
111
|
@pytest.mark.parametrize("compute_rotation", (True, False))
|
112
|
-
|
113
|
-
def test__call__(self, num_peaks, compute_rotation, minimum_score):
|
112
|
+
def test__call__(self, num_peaks, compute_rotation):
|
114
113
|
peak_caller = PeakCallerRecursiveMasking(
|
115
114
|
shape=self.data.shape,
|
116
115
|
num_peaks=num_peaks,
|
117
116
|
min_distance=self.min_distance,
|
118
|
-
min_score=minimum_score,
|
119
117
|
)
|
120
118
|
rotation_space, rotation_mapping = None, None
|
121
119
|
if compute_rotation:
|
@@ -131,13 +129,7 @@ class TestRecursiveMasking:
|
|
131
129
|
rotation_space=rotation_space,
|
132
130
|
rotation_mapping=rotation_mapping,
|
133
131
|
)
|
134
|
-
|
135
|
-
if minimum_score is None:
|
136
|
-
assert len(state[0] <= num_peaks)
|
137
|
-
else:
|
138
|
-
peaks = state[0].astype(int)
|
139
|
-
assert np.all(self.data[tuple(peaks.T)] >= minimum_score)
|
140
|
-
|
132
|
+
assert len(state[0] <= num_peaks)
|
141
133
|
|
142
134
|
class TestMaxScoreOverRotations:
|
143
135
|
def setup_method(self):
|
tests/test_backends.py
CHANGED
@@ -331,16 +331,6 @@ class TestBackends:
|
|
331
331
|
real_arr = backend.irfftn(complex_arr)
|
332
332
|
assert np.allclose(arr, backend.to_numpy_array(real_arr), rtol=0.3)
|
333
333
|
|
334
|
-
@pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
|
335
|
-
def test_extract_center(self, backend):
|
336
|
-
new_shape = np.divide(self.x1.shape, 2).astype(int)
|
337
|
-
base = self.backend.extract_center(arr=self.x1, newshape=new_shape)
|
338
|
-
other = backend.extract_center(
|
339
|
-
arr=backend.to_backend_array(self.x1), newshape=new_shape
|
340
|
-
)
|
341
|
-
|
342
|
-
assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)
|
343
|
-
|
344
334
|
@pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
|
345
335
|
def test_compute_convolution_shapes(self, backend):
|
346
336
|
base = self.backend.compute_convolution_shapes(self.x1.shape, self.x2.shape)
|
@@ -359,32 +349,69 @@ class TestBackends:
|
|
359
349
|
elif dim == 3:
|
360
350
|
arr[20:25, 21:26, 26:31] = 1
|
361
351
|
|
362
|
-
|
363
|
-
|
352
|
+
from tme.rotations import get_rotation_matrices
|
353
|
+
|
354
|
+
np.random.seed(42)
|
355
|
+
rotation_matrix = get_rotation_matrices(
|
356
|
+
dim=dim, angular_sampling=10, use_optimized_set=False
|
357
|
+
)[-1]
|
364
358
|
|
365
359
|
out = np.zeros_like(arr)
|
360
|
+
out.setflags(write=True)
|
366
361
|
|
367
362
|
arr_mask, out_mask = None, None
|
368
363
|
if create_mask:
|
369
364
|
arr_mask = np.multiply(np.random.rand(*arr.shape) > 0.5, 1.0)
|
370
365
|
out_mask = np.zeros_like(arr_mask)
|
366
|
+
out_mask.setflags(write=True)
|
367
|
+
|
368
|
+
out, _ = NumpyFFTWBackend().rigid_transform(
|
369
|
+
arr=arr,
|
370
|
+
arr_mask=arr_mask,
|
371
|
+
rotation_matrix=rotation_matrix,
|
372
|
+
out=out,
|
373
|
+
out_mask=out_mask,
|
374
|
+
order=1,
|
375
|
+
use_geometric_center=True,
|
376
|
+
)
|
377
|
+
|
378
|
+
arr = backend.to_backend_array(arr.copy())
|
379
|
+
out_be = backend.to_backend_array(out.copy())
|
380
|
+
if create_mask:
|
371
381
|
arr_mask = backend.to_backend_array(arr_mask)
|
372
382
|
out_mask = backend.to_backend_array(out_mask)
|
373
383
|
|
374
|
-
arr = backend.to_backend_array(arr)
|
375
|
-
out = backend.to_backend_array(arr)
|
376
|
-
|
377
384
|
rotation_matrix = backend.to_backend_array(rotation_matrix)
|
378
385
|
|
379
|
-
backend.rigid_transform(
|
386
|
+
out_be, _ = backend.rigid_transform(
|
380
387
|
arr=arr,
|
381
388
|
arr_mask=arr_mask,
|
382
389
|
rotation_matrix=rotation_matrix,
|
383
|
-
out=
|
390
|
+
out=out_be,
|
384
391
|
out_mask=out_mask,
|
392
|
+
order=1,
|
393
|
+
use_geometric_center=True,
|
385
394
|
)
|
395
|
+
out_be = backend.to_numpy_array(out_be)
|
396
|
+
assert np.allclose(out, out_be, atol=0.3)
|
386
397
|
|
387
|
-
|
398
|
+
@pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
|
399
|
+
@pytest.mark.parametrize("create_mask", (False, True))
|
400
|
+
def test_rigid_transform_identity(self, backend, create_mask):
|
401
|
+
dim = 3
|
402
|
+
shape = tuple(50 for _ in range(dim))
|
403
|
+
|
404
|
+
arr = np.zeros(shape)
|
405
|
+
arr[20:25, 21:26, 26:31] = 1
|
406
|
+
|
407
|
+
rotation_matrix = backend.to_backend_array(np.eye(dim))
|
408
|
+
out, _ = backend.rigid_transform(
|
409
|
+
arr=backend.to_backend_array(arr),
|
410
|
+
rotation_matrix=backend.to_backend_array(rotation_matrix),
|
411
|
+
order=1,
|
412
|
+
use_geometric_center=True,
|
413
|
+
)
|
414
|
+
assert np.allclose(out, arr, atol=0.01)
|
388
415
|
|
389
416
|
@pytest.mark.parametrize("dim", (2, 3))
|
390
417
|
@pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
|
@@ -418,6 +445,7 @@ class TestBackends:
|
|
418
445
|
out=out,
|
419
446
|
out_mask=out_mask,
|
420
447
|
batched=True,
|
448
|
+
order=1,
|
421
449
|
)
|
422
450
|
|
423
451
|
arr_b = backend.to_backend_array(arr_b)
|
@@ -430,6 +458,7 @@ class TestBackends:
|
|
430
458
|
rotation_matrix=rotation_matrix,
|
431
459
|
out=out_b[i],
|
432
460
|
out_mask=out_mask if out_mask is None else out_mask[i],
|
461
|
+
order=1,
|
433
462
|
)
|
434
463
|
|
435
464
|
assert np.allclose(arr, arr_b)
|