pytme 0.3.1.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. pytme-0.3.2.data/scripts/estimate_ram_usage.py +97 -0
  2. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/match_template.py +213 -196
  3. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/postprocess.py +40 -78
  4. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/preprocess.py +4 -5
  5. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/preprocessor_gui.py +50 -103
  6. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/pytme_runner.py +46 -69
  7. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/METADATA +3 -2
  8. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/RECORD +68 -65
  9. scripts/estimate_ram_usage.py +97 -0
  10. scripts/match_template.py +213 -196
  11. scripts/match_template_devel.py +1339 -0
  12. scripts/postprocess.py +40 -78
  13. scripts/preprocess.py +4 -5
  14. scripts/preprocessor_gui.py +50 -103
  15. scripts/pytme_runner.py +46 -69
  16. scripts/refine_matches.py +5 -7
  17. tests/preprocessing/test_compose.py +31 -30
  18. tests/preprocessing/test_frequency_filters.py +17 -32
  19. tests/preprocessing/test_preprocessor.py +0 -19
  20. tests/preprocessing/test_utils.py +13 -1
  21. tests/test_analyzer.py +2 -10
  22. tests/test_backends.py +47 -18
  23. tests/test_density.py +72 -13
  24. tests/test_extensions.py +1 -0
  25. tests/test_matching_cli.py +23 -9
  26. tests/test_matching_exhaustive.py +5 -5
  27. tests/test_matching_utils.py +3 -3
  28. tests/test_rotations.py +13 -23
  29. tests/test_structure.py +1 -7
  30. tme/__version__.py +1 -1
  31. tme/analyzer/aggregation.py +47 -16
  32. tme/analyzer/base.py +34 -0
  33. tme/analyzer/peaks.py +26 -13
  34. tme/analyzer/proxy.py +14 -0
  35. tme/backends/_jax_utils.py +124 -71
  36. tme/backends/cupy_backend.py +6 -19
  37. tme/backends/jax_backend.py +110 -105
  38. tme/backends/matching_backend.py +0 -17
  39. tme/backends/mlx_backend.py +0 -29
  40. tme/backends/npfftw_backend.py +100 -97
  41. tme/backends/pytorch_backend.py +65 -78
  42. tme/cli.py +2 -2
  43. tme/density.py +102 -58
  44. tme/extensions.cpython-311-darwin.so +0 -0
  45. tme/filters/_utils.py +52 -24
  46. tme/filters/bandpass.py +99 -105
  47. tme/filters/compose.py +133 -39
  48. tme/filters/ctf.py +51 -102
  49. tme/filters/reconstruction.py +67 -122
  50. tme/filters/wedge.py +296 -325
  51. tme/filters/whitening.py +39 -75
  52. tme/mask.py +2 -2
  53. tme/matching_data.py +87 -15
  54. tme/matching_exhaustive.py +70 -120
  55. tme/matching_optimization.py +9 -63
  56. tme/matching_scores.py +261 -100
  57. tme/matching_utils.py +150 -91
  58. tme/memory.py +1 -0
  59. tme/orientations.py +28 -8
  60. tme/preprocessor.py +0 -239
  61. tme/rotations.py +102 -70
  62. tme/structure.py +601 -631
  63. tme/types.py +1 -0
  64. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/estimate_memory_usage.py +0 -0
  65. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/WHEEL +0 -0
  66. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/entry_points.txt +0 -0
  67. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/licenses/LICENSE +0 -0
  68. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/top_level.txt +0 -0
scripts/pytme_runner.py CHANGED
@@ -8,7 +8,7 @@ import subprocess
8
8
  from abc import ABC, abstractmethod
9
9
 
10
10
  from pathlib import Path
11
- from dataclasses import dataclass
11
+ from dataclasses import dataclass, field, fields
12
12
  from typing import Dict, List, Optional, Any
13
13
 
14
14
  from tme.backends import backend as be
@@ -149,7 +149,7 @@ class AnalysisDatasetDiscovery(DatasetDiscovery):
149
149
 
150
150
  #: Glob pattern for TM pickle files, e.g., "/data/results/*.pickle"
151
151
  input_patterns: List[str]
152
- #: List of glob patterns for background files, e.g., ["/data/bg1/*.pickle", "/data/bg2/*.pickle"]
152
+ #: List of glob patterns for background files, e.g., ["bg1/*.pickle", "bg2/*."]
153
153
  background_patterns: List[str] = None
154
154
  #: Target masks, e.g., "/data/masks/*.mrc"
155
155
  mask_patterns: Optional[str] = None
@@ -230,46 +230,47 @@ class TMParameters:
230
230
  axis_sampling: Optional[float] = None
231
231
  axis_symmetry: int = 1
232
232
  cone_axis: int = 2
233
- invert_cone: bool = False
234
- no_use_optimized_set: bool = False
233
+ invert_cone: bool = field(default=False, metadata={"flag": True})
234
+ no_use_optimized_set: bool = field(default=False, metadata={"flag": True})
235
235
 
236
236
  # Microscope parameters
237
237
  acceleration_voltage: float = 300.0 # kV
238
238
  spherical_aberration: float = 2.7e7 # Å
239
239
  amplitude_contrast: float = 0.07
240
240
  defocus: Optional[float] = None # Å
241
- phase_shift: float = 0.0 # Dg
241
+ phase_shift: float = 0.0 # degrees
242
242
 
243
243
  # Processing options
244
244
  lowpass: Optional[float] = None # Å
245
245
  highpass: Optional[float] = None # Å
246
246
  pass_format: str = "sampling_rate" # "sampling_rate", "voxel", "frequency"
247
- no_pass_smooth: bool = True
247
+ no_pass_smooth: bool = field(default=True, metadata={"flag": False})
248
248
  interpolation_order: int = 3
249
249
  score_threshold: float = 0.0
250
250
  score: str = "FLCSphericalMask"
251
+ background_correction: Optional[str] = None
251
252
 
252
253
  # Weighting and correction
253
254
  tilt_weighting: Optional[str] = None # "angle", "relion", "grigorieff"
254
255
  wedge_axes: str = "2,0"
255
- whiten_spectrum: bool = False
256
- scramble_phases: bool = False
257
- invert_target_contrast: bool = False
256
+ whiten_spectrum: bool = field(default=False, metadata={"flag": True})
257
+ scramble_phases: bool = field(default=False, metadata={"flag": True})
258
+ invert_target_contrast: bool = field(default=False, metadata={"flag": True})
258
259
 
259
260
  # CTF parameters
260
261
  ctf_file: Optional[Path] = None
261
- no_flip_phase: bool = True
262
- correct_defocus_gradient: bool = False
262
+ no_flip_phase: bool = field(default=True, metadata={"flag": False})
263
+ correct_defocus_gradient: bool = field(default=False, metadata={"flag": True})
263
264
 
264
265
  # Performance options
265
- centering: bool = False
266
- pad_edges: bool = False
267
- pad_filter: bool = False
268
- use_mixed_precision: bool = False
269
- use_memmap: bool = False
266
+ centering: bool = field(default=False, metadata={"flag": True})
267
+ pad_edges: bool = field(default=False, metadata={"flag": True})
268
+ pad_filter: bool = field(default=False, metadata={"flag": True})
269
+ use_mixed_precision: bool = field(default=False, metadata={"flag": True})
270
+ use_memmap: bool = field(default=False, metadata={"flag": True})
270
271
 
271
272
  # Analysis options
272
- peak_calling: bool = False
273
+ peak_calling: bool = field(default=False, metadata={"flag": True})
273
274
  num_peaks: int = 1000
274
275
 
275
276
  # Backend selection
@@ -279,7 +280,7 @@ class TMParameters:
279
280
  # Reconstruction
280
281
  reconstruction_filter: str = "ramp"
281
282
  reconstruction_interpolation_order: int = 1
282
- no_filter_target: bool = False
283
+ no_filter_target: bool = field(default=False, metadata={"flag": True})
283
284
 
284
285
  def __post_init__(self):
285
286
  """Validate parameters and convert units."""
@@ -337,6 +338,9 @@ class TMParameters:
337
338
  args["ctf-file"] = str(files.metadata)
338
339
  args["tilt-angles"] = str(files.metadata)
339
340
 
341
+ if self.background_correction:
342
+ args["background-correction"] = self.background_correction
343
+
340
344
  # Optional parameters
341
345
  if self.lowpass:
342
346
  args["lowpass"] = self.lowpass
@@ -378,43 +382,26 @@ class TMParameters:
378
382
  return {k: v for k, v in args.items() if v is not None}
379
383
 
380
384
  def get_flags(self) -> List[str]:
381
- """Get boolean flags for pyTME command."""
382
385
  flags = []
383
- if self.whiten_spectrum:
384
- flags.append("whiten-spectrum")
385
- if self.scramble_phases:
386
- flags.append("scramble-phases")
387
- if self.invert_target_contrast:
388
- flags.append("invert-target-contrast")
389
- if self.centering:
390
- flags.append("centering")
391
- if self.pad_edges:
392
- flags.append("pad-edges")
393
- if self.pad_filter:
394
- flags.append("pad-filter")
395
- if not self.no_pass_smooth:
396
- flags.append("no-pass-smooth")
397
- if self.use_mixed_precision:
398
- flags.append("use-mixed-precision")
399
- if self.use_memmap:
400
- flags.append("use-memmap")
401
- if self.peak_calling:
402
- flags.append("peak-calling")
403
- if not self.no_flip_phase:
404
- flags.append("no-flip-phase")
405
- if self.correct_defocus_gradient:
406
- flags.append("correct-defocus-gradient")
407
- if self.invert_cone:
408
- flags.append("invert-cone")
409
- if self.no_use_optimized_set:
410
- flags.append("no-use-optimized-set")
411
- if self.no_filter_target:
412
- flags.append("no-filter-target")
386
+
387
+ for field_info in fields(self):
388
+ flag_meta = field_info.metadata.get("flag")
389
+ if flag_meta is None:
390
+ continue
391
+
392
+ value = getattr(self, field_info.name)
393
+ if not isinstance(value, bool):
394
+ continue
395
+
396
+ flag_name = field_info.name.replace("_", "-")
397
+ if (flag_meta is True and value) or (flag_meta is False and not value):
398
+ flags.append(flag_name)
399
+
413
400
  return flags
414
401
 
415
402
 
416
403
  @dataclass
417
- class AnalysisParameters:
404
+ class AnalysisParameters(TMParameters):
418
405
  """Parameters for template matching analysis and peak calling."""
419
406
 
420
407
  # Peak calling
@@ -424,13 +411,12 @@ class AnalysisParameters:
424
411
  max_score: Optional[float] = None
425
412
  min_distance: int = 5
426
413
  min_boundary_distance: int = 0
427
- mask_edges: bool = False
414
+ mask_edges: bool = field(default=False, metadata={"flag": True})
428
415
  n_false_positives: Optional[int] = None
429
416
 
430
417
  # Output format
431
418
  output_format: str = "relion4"
432
419
  output_directory: Optional[str] = None
433
- angles_clockwise: bool = False
434
420
 
435
421
  # Advanced options
436
422
  extraction_box_size: Optional[int] = None
@@ -468,15 +454,6 @@ class AnalysisParameters:
468
454
 
469
455
  return {k: v for k, v in args.items() if v is not None}
470
456
 
471
- def get_flags(self) -> List[str]:
472
- """Get boolean flags for analyze_template_matching command."""
473
- flags = []
474
- if self.mask_edges:
475
- flags.append("mask-edges")
476
- if self.angles_clockwise:
477
- flags.append("angles-clockwise")
478
- return flags
479
-
480
457
 
481
458
  @dataclass
482
459
  class ComputeResources:
@@ -592,12 +569,10 @@ class ExecutionBackend(ABC):
592
569
  @abstractmethod
593
570
  def submit_job(self, task) -> str:
594
571
  """Submit a single job and return job ID or status."""
595
- pass
596
572
 
597
573
  @abstractmethod
598
574
  def submit_jobs(self, tasks: List) -> List[str]:
599
575
  """Submit multiple jobs and return list of job IDs."""
600
- pass
601
576
 
602
577
 
603
578
  class SlurmBackend(ExecutionBackend):
@@ -841,6 +816,13 @@ def parse_args():
841
816
  default="FLCSphericalMask",
842
817
  help="Template matching scoring function. Use FLC if mask is not spherical.",
843
818
  )
819
+ tm_group.add_argument(
820
+ "--background-correction",
821
+ choices=["phase-scrambling"],
822
+ required=False,
823
+ help="Transform cross-correlation into SNR-like values using a given method: "
824
+ "'phase-scrambling' uses a phase-scrambled template as background",
825
+ )
844
826
  tm_group.add_argument(
845
827
  "--score-threshold", type=float, default=0.0, help="Minimum score threshold"
846
828
  )
@@ -1006,11 +988,6 @@ def parse_args():
1006
988
  default="relion4",
1007
989
  help="Output format for analysis results",
1008
990
  )
1009
- output_group.add_argument(
1010
- "--angles-clockwise",
1011
- action="store_true",
1012
- help="Report Euler angles in clockwise format expected by RELION",
1013
- )
1014
991
 
1015
992
  advanced_group = analysis_parser.add_argument_group("Advanced Options")
1016
993
  advanced_group.add_argument(
@@ -1077,6 +1054,7 @@ def run_matching(args, resources):
1077
1054
  backend=args.backend,
1078
1055
  whiten_spectrum=args.whiten_spectrum,
1079
1056
  scramble_phases=args.scramble_phases,
1057
+ background_correction=args.background_correction,
1080
1058
  )
1081
1059
  print_params = params.to_command_args(files[0], "")
1082
1060
  _ = print_params.pop("target")
@@ -1132,7 +1110,6 @@ def run_analysis(args, resources):
1132
1110
  mask_edges=args.mask_edges,
1133
1111
  n_false_positives=args.n_false_positives,
1134
1112
  output_format=args.output_format,
1135
- angles_clockwise=args.angles_clockwise,
1136
1113
  extraction_box_size=args.extraction_box_size,
1137
1114
  )
1138
1115
  print_params = params.to_command_args(files[0], Path(""))
scripts/refine_matches.py CHANGED
@@ -10,11 +10,9 @@ import subprocess
10
10
  from sys import exit
11
11
  from os import unlink
12
12
  from time import time
13
- from os.path import join
14
13
  from typing import Tuple, List, Dict
15
14
 
16
15
  import numpy as np
17
- from scipy import optimize
18
16
  from sklearn.metrics import roc_auc_score
19
17
 
20
18
  from tme import Orientations, Density
@@ -66,7 +64,6 @@ def parse_args():
66
64
  matching_group.add_argument(
67
65
  "-i",
68
66
  "--template",
69
- dest="template",
70
67
  type=str,
71
68
  required=True,
72
69
  help="Path to a template in PDB/MMCIF or other supported formats (see target).",
@@ -102,7 +99,7 @@ def parse_args():
102
99
  )
103
100
  matching_group.add_argument(
104
101
  "-s",
105
- dest="score",
102
+ "--score",
106
103
  type=str,
107
104
  default="batchFLCSphericalMask",
108
105
  choices=list(MATCHING_EXHAUSTIVE_REGISTER.keys()),
@@ -197,6 +194,7 @@ def create_matching_argdict(args) -> Dict:
197
194
  "-n": args.cores,
198
195
  "--ctf-file": args.ctf_file,
199
196
  "--invert-target-contrast": args.invert_target_contrast,
197
+ "--backend" : args.backend,
200
198
  }
201
199
  return arg_dict
202
200
 
@@ -252,7 +250,7 @@ class DeepMatcher:
252
250
  if args.lowpass_range:
253
251
  self.filter_parameters["--lowpass"] = 0
254
252
  if args.highpass_range:
255
- self.filter_parameters["--highpass"] = 200
253
+ self.filter_parameters["--highpass"] = 0
256
254
 
257
255
  self.postprocess_args = create_postprocessing_argdict(args)
258
256
  self.log_file = f"{args.output_prefix}_optimization_log.txt"
@@ -309,14 +307,14 @@ class DeepMatcher:
309
307
 
310
308
  match_template = argdict_to_command(
311
309
  self.match_template_args,
312
- executable="match_template.py",
310
+ executable="match_template",
313
311
  )
314
312
  run_command(match_template)
315
313
 
316
314
  # Assume we get a new peak for each input in the same order
317
315
  postprocess = argdict_to_command(
318
316
  self.postprocess_args,
319
- executable="postprocess.py",
317
+ executable="postprocess",
320
318
  )
321
319
  run_command(postprocess)
322
320
 
@@ -1,28 +1,38 @@
1
1
  import pytest
2
2
 
3
- from tme.filters import Compose
3
+ from tme.filters import Compose, ComposableFilter
4
4
  from tme.backends import backend as be
5
5
 
6
6
 
7
- def mock_transform1(**kwargs):
8
- return {"data": be.ones((10, 10)), "is_multiplicative_filter": True}
7
+ class MockFilter(ComposableFilter):
8
+ def _evaluate(self, *args, **kwargs):
9
+ return {"data": be.ones((10, 10)), "shape": (10, 10)}
10
+
11
+
12
+ class MockFilterNoMult(ComposableFilter):
13
+ def _evaluate(self, *args, **kwargs):
14
+ return {
15
+ "data": be.ones((10, 10)) * 3,
16
+ "shape": (10, 10),
17
+ "is_multiplicative_filter": False,
18
+ }
9
19
 
10
20
 
11
- def mock_transform2(**kwargs):
12
- return {"data": be.ones((10, 10)) * 2, "is_multiplicative_filter": True}
21
+ mock_transform = MockFilter()
22
+ mock_transform_nomult = MockFilterNoMult()
13
23
 
14
24
 
15
- def mock_transform3(**kwargs):
16
- return {"extra_info": "test"}
25
+ def mock_transform_error(**kwargs):
26
+ return {"data": be.ones((10, 10)), "is_multiplicative_filter": True}
17
27
 
18
28
 
19
29
  class TestCompose:
20
30
  @pytest.fixture
21
31
  def compose_instance(self):
22
- return Compose((mock_transform1, mock_transform2, mock_transform3))
32
+ return Compose((mock_transform, mock_transform, mock_transform))
23
33
 
24
34
  def test_init(self):
25
- transforms = (mock_transform1, mock_transform2)
35
+ transforms = (mock_transform, mock_transform)
26
36
  compose = Compose(transforms)
27
37
  assert compose.transforms == transforms
28
38
 
@@ -32,45 +42,36 @@ class TestCompose:
32
42
  assert result == {}
33
43
 
34
44
  def test_call_single_transform(self):
35
- compose = Compose((mock_transform1,))
36
- result = compose()
45
+ compose = Compose((mock_transform,))
46
+ result = compose(return_real_fourier=False)
37
47
  assert "data" in result
38
- assert result.get("is_multiplicative_filter", False)
39
48
  assert be.allclose(result["data"], be.ones((10, 10)))
40
49
 
41
50
  def test_call_multiple_transforms(self, compose_instance):
42
- result = compose_instance()
51
+ result = compose_instance(return_real_fourier=False)
43
52
  assert "data" in result
44
- assert "extra_info" not in result
45
- assert be.allclose(result["data"], be.ones((10, 10)) * 2)
53
+ assert be.allclose(result["data"], be.ones((10, 10)))
46
54
 
47
55
  def test_multiplicative_filter_composition(self):
48
- compose = Compose((mock_transform1, mock_transform2))
49
- result = compose()
56
+ compose = Compose((mock_transform, mock_transform))
57
+ result = compose(return_real_fourier=False)
50
58
  assert "data" in result
51
- assert be.allclose(result["data"], be.ones((10, 10)) * 2)
59
+ assert be.allclose(result["data"], be.ones((10, 10)))
52
60
 
53
61
  @pytest.mark.parametrize(
54
62
  "kwargs", [{}, {"extra_param": "test"}, {"data": be.zeros((5, 5))}]
55
63
  )
56
64
  def test_call_with_kwargs(self, compose_instance, kwargs):
57
- result = compose_instance(**kwargs)
65
+ result = compose_instance(**kwargs, return_real_fourier=False)
58
66
  assert "data" in result
59
67
  assert "extra_info" not in result
60
68
 
61
69
  def test_non_multiplicative_filter(self):
62
- def non_mult_transform(**kwargs):
63
- return {"data": be.ones((10, 10)) * 3, "is_multiplicative_filter": False}
64
-
65
- compose = Compose((mock_transform1, non_mult_transform))
66
- result = compose()
70
+ compose = Compose((mock_transform, mock_transform_nomult))
71
+ result = compose(return_real_fourier=False)
67
72
  assert "data" in result
68
73
  assert be.allclose(result["data"], be.ones((10, 10)) * 3)
69
74
 
70
75
  def test_error_handling(self):
71
- def error_transform(**kwargs):
72
- raise ValueError("Test error")
73
-
74
- compose = Compose((mock_transform1, error_transform))
75
- with pytest.raises(ValueError, match="Test error"):
76
- compose()
76
+ with pytest.raises(ValueError):
77
+ Compose((mock_transform, mock_transform_error))()
@@ -71,52 +71,43 @@ class TestBandPassFilter:
71
71
  shape_is_real_fourier: bool,
72
72
  ):
73
73
  band_pass_filter.use_gaussian = use_gaussian
74
- band_pass_filter.return_real_fourier = return_real_fourier
75
74
  band_pass_filter.shape_is_real_fourier = shape_is_real_fourier
76
75
 
77
76
  result = band_pass_filter(shape=(10, 10), lowpass=0.2, highpass=0.8)
78
77
 
79
78
  assert isinstance(result, dict)
80
79
  assert "data" in result
81
- assert "is_multiplicative_filter" in result
82
80
  assert isinstance(result["data"], type(be.ones((1,))))
83
- assert result["is_multiplicative_filter"] is True
84
81
 
85
82
  def test_default_values(self, band_pass_filter: BandPassReconstructed):
86
83
  assert band_pass_filter.lowpass is None
87
84
  assert band_pass_filter.highpass is None
88
85
  assert band_pass_filter.sampling_rate == 1
89
86
  assert band_pass_filter.use_gaussian is True
90
- assert band_pass_filter.return_real_fourier is False
91
87
 
92
88
  @pytest.mark.parametrize("shape", ((10, 10), (20, 20, 20), (30, 30)))
93
89
  def test_return_real_fourier(self, shape: Tuple[int]):
94
- bpf = BandPassReconstructed(return_real_fourier=True)
90
+ bpf = BandPassReconstructed()
95
91
  result = bpf(shape=shape, lowpass=0.2, highpass=0.8)
96
- expected_shape = tuple(compute_fourier_shape(shape, False))
97
- assert result["data"].shape == expected_shape
92
+ assert result["data"].shape == shape
98
93
 
99
94
 
100
95
  class TestLinearWhiteningFilter:
101
96
  @pytest.mark.parametrize(
102
- "shape, n_bins, batch_dimension",
97
+ "shape, n_bins",
103
98
  [
104
- ((10, 10), None, None),
105
- ((20, 20, 20), 15, 0),
106
- ((30, 30, 30), 20, 1),
107
- ((40, 40, 40, 40), 25, 2),
99
+ ((10, 10), None),
100
+ ((20, 20, 20), 15),
101
+ ((30, 30, 30), 20),
102
+ ((40, 40, 40, 40), 25),
108
103
  ],
109
104
  )
110
- def test_compute_spectrum(
111
- self, shape: Tuple[int], n_bins: int, batch_dimension: int
112
- ):
105
+ def test_compute_spectrum(self, shape: Tuple[int], n_bins: int):
113
106
  data_rfft = be.fft.rfftn(be.random.random(shape))
114
107
  bins, radial_averages = LinearWhiteningFilter._compute_spectrum(
115
- data_rfft, n_bins, batch_dimension
116
- )
117
- data_shape = tuple(
118
- int(x) for i, x in enumerate(data_rfft.shape) if i != batch_dimension
108
+ data_rfft, n_bins
119
109
  )
110
+ data_shape = tuple(int(x) for i, x in enumerate(data_rfft.shape))
120
111
 
121
112
  assert isinstance(bins, np.ndarray)
122
113
  assert isinstance(radial_averages, np.ndarray)
@@ -140,9 +131,8 @@ class TestLinearWhiteningFilter:
140
131
  @pytest.mark.parametrize(
141
132
  "shape, n_bins, batch_dimension, order",
142
133
  [
143
- ((10, 10), None, None, 1),
144
- ((20, 20, 20), 15, 0, 2),
145
- ((30, 30, 30), 20, 1, None),
134
+ ((10, 10), None, (), 1),
135
+ ((20, 20, 20), 15, 0, 1),
146
136
  ],
147
137
  )
148
138
  def test_call_method(
@@ -154,21 +144,17 @@ class TestLinearWhiteningFilter:
154
144
  ):
155
145
  data = be.random.random(shape)
156
146
  result = LinearWhiteningFilter()(
157
- shape=shape,
158
- data=data,
147
+ shape=tuple(x for i, x in enumerate(shape) if i != batch_dimension),
148
+ data_rfft=np.fft.rfftn(data),
159
149
  n_bins=n_bins,
160
- batch_dimension=batch_dimension,
150
+ axes=batch_dimension,
161
151
  order=order,
162
152
  )
163
153
 
164
154
  assert isinstance(result, dict)
165
155
  assert result.get("data", False) is not False
166
- assert result.get("is_multiplicative_filter", False)
167
156
  assert isinstance(result["data"], type(be.ones((1,))))
168
- data_shape = tuple(
169
- int(x) for i, x in enumerate(data.shape) if i != batch_dimension
170
- )
171
- assert result["data"].shape == tuple(compute_fourier_shape(data_shape, False))
157
+ assert result["data"].shape == shape
172
158
 
173
159
  def test_call_method_with_data_rfft(self):
174
160
  shape = (30, 30, 30)
@@ -179,14 +165,13 @@ class TestLinearWhiteningFilter:
179
165
 
180
166
  assert isinstance(result, dict)
181
167
  assert result.get("data", False) is not False
182
- assert result.get("is_multiplicative_filter", False)
183
168
  assert isinstance(result["data"], type(be.ones((1,))))
184
169
  assert result["data"].shape == data_rfft.shape
185
170
 
186
171
  @pytest.mark.parametrize("shape", [(10, 10), (20, 20, 20), (30, 30, 30)])
187
172
  def test_filter_mask_range(self, shape: Tuple[int]):
188
173
  data = be.random.random(shape)
189
- result = LinearWhiteningFilter()(shape=shape, data=data)
174
+ result = LinearWhiteningFilter()(shape=shape, data_rfft=np.fft.rfftn(data))
190
175
 
191
176
  filter_mask = result["data"]
192
177
  assert np.all(filter_mask >= 0) and np.all(filter_mask <= 1)
@@ -1,5 +1,4 @@
1
1
  import pytest
2
- import numpy as np
3
2
 
4
3
  from tme import Density, Structure, Preprocessor
5
4
 
@@ -49,15 +48,6 @@ class TestPreprocessor:
49
48
  high_sigma=high_sigma,
50
49
  )
51
50
 
52
- @pytest.mark.parametrize("smallest_size,largest_size", [(1, 10), (2, 20)])
53
- def test_bandpass_filter(self, smallest_size, largest_size):
54
- _ = self.preprocessor.bandpass_filter(
55
- template=self.structure_density.data,
56
- lowpass=smallest_size,
57
- highpass=largest_size,
58
- sampling_rate=1,
59
- )
60
-
61
51
  @pytest.mark.parametrize("lbd,sigma_range", [(1, (2, 4)), (20, (1, 6))])
62
52
  def test_local_gaussian_alignment_filter(self, lbd, sigma_range):
63
53
  _ = self.preprocessor.local_gaussian_alignment_filter(
@@ -125,12 +115,3 @@ class TestPreprocessor:
125
115
  template=self.structure_density.data,
126
116
  rank=rank,
127
117
  )
128
-
129
- @pytest.mark.parametrize("infinite_plane", [False, True])
130
- def test_continuous_wedge_mask(self, infinite_plane):
131
- _ = self.preprocessor.continuous_wedge_mask(
132
- start_tilt=50,
133
- stop_tilt=-40,
134
- shape=(50, 50, 50),
135
- infinite_plane=infinite_plane,
136
- )
@@ -74,10 +74,22 @@ class TestPreprocessUtils:
74
74
  def test_fftfreqn(self, n, sampling_rate):
75
75
  assert np.allclose(
76
76
  fftfreqn(
77
- shape=(n,), sampling_rate=sampling_rate, compute_euclidean_norm=True
77
+ shape=(n,),
78
+ sampling_rate=sampling_rate,
79
+ compute_euclidean_norm=True,
80
+ fftshift=True,
78
81
  ),
79
82
  np.abs(np.fft.ifftshift(np.fft.fftfreq(n=n, d=sampling_rate))),
80
83
  )
84
+ assert np.allclose(
85
+ fftfreqn(
86
+ shape=(n,),
87
+ sampling_rate=sampling_rate,
88
+ compute_euclidean_norm=True,
89
+ fftshift=False,
90
+ ),
91
+ np.abs(np.fft.fftfreq(n=n, d=sampling_rate)),
92
+ )
81
93
 
82
94
  @pytest.mark.parametrize("shape", ((10,), (10, 15), (10, 15, 30)))
83
95
  def test_crop_real_fourier(self, shape):
tests/test_analyzer.py CHANGED
@@ -109,13 +109,11 @@ class TestRecursiveMasking:
109
109
 
110
110
  @pytest.mark.parametrize("num_peaks", (1, 100))
111
111
  @pytest.mark.parametrize("compute_rotation", (True, False))
112
- @pytest.mark.parametrize("minimum_score", (None, 0.5))
113
- def test__call__(self, num_peaks, compute_rotation, minimum_score):
112
+ def test__call__(self, num_peaks, compute_rotation):
114
113
  peak_caller = PeakCallerRecursiveMasking(
115
114
  shape=self.data.shape,
116
115
  num_peaks=num_peaks,
117
116
  min_distance=self.min_distance,
118
- min_score=minimum_score,
119
117
  )
120
118
  rotation_space, rotation_mapping = None, None
121
119
  if compute_rotation:
@@ -131,13 +129,7 @@ class TestRecursiveMasking:
131
129
  rotation_space=rotation_space,
132
130
  rotation_mapping=rotation_mapping,
133
131
  )
134
-
135
- if minimum_score is None:
136
- assert len(state[0] <= num_peaks)
137
- else:
138
- peaks = state[0].astype(int)
139
- assert np.all(self.data[tuple(peaks.T)] >= minimum_score)
140
-
132
+ assert len(state[0] <= num_peaks)
141
133
 
142
134
  class TestMaxScoreOverRotations:
143
135
  def setup_method(self):