pytme 0.3.1.post2__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2.dev0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. pytme-0.3.2.dev0.data/scripts/estimate_ram_usage.py +97 -0
  2. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/match_template.py +213 -196
  3. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/postprocess.py +40 -78
  4. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/preprocess.py +4 -5
  5. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/preprocessor_gui.py +49 -103
  6. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/pytme_runner.py +46 -69
  7. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/METADATA +2 -1
  8. pytme-0.3.2.dev0.dist-info/RECORD +136 -0
  9. scripts/estimate_ram_usage.py +97 -0
  10. scripts/match_template.py +213 -196
  11. scripts/match_template_devel.py +1339 -0
  12. scripts/postprocess.py +40 -78
  13. scripts/preprocess.py +4 -5
  14. scripts/preprocessor_gui.py +49 -103
  15. scripts/pytme_runner.py +46 -69
  16. tests/preprocessing/test_compose.py +31 -30
  17. tests/preprocessing/test_frequency_filters.py +17 -32
  18. tests/preprocessing/test_preprocessor.py +0 -19
  19. tests/preprocessing/test_utils.py +13 -1
  20. tests/test_analyzer.py +2 -10
  21. tests/test_backends.py +47 -18
  22. tests/test_density.py +72 -13
  23. tests/test_extensions.py +1 -0
  24. tests/test_matching_cli.py +23 -9
  25. tests/test_matching_exhaustive.py +5 -5
  26. tests/test_matching_utils.py +3 -3
  27. tests/test_orientations.py +12 -0
  28. tests/test_rotations.py +13 -23
  29. tests/test_structure.py +1 -7
  30. tme/__version__.py +1 -1
  31. tme/analyzer/aggregation.py +47 -16
  32. tme/analyzer/base.py +34 -0
  33. tme/analyzer/peaks.py +26 -13
  34. tme/analyzer/proxy.py +14 -0
  35. tme/backends/_jax_utils.py +91 -68
  36. tme/backends/cupy_backend.py +6 -19
  37. tme/backends/jax_backend.py +103 -98
  38. tme/backends/matching_backend.py +0 -17
  39. tme/backends/mlx_backend.py +0 -29
  40. tme/backends/npfftw_backend.py +100 -97
  41. tme/backends/pytorch_backend.py +65 -78
  42. tme/cli.py +2 -2
  43. tme/density.py +44 -57
  44. tme/extensions.cpython-311-darwin.so +0 -0
  45. tme/filters/_utils.py +52 -24
  46. tme/filters/bandpass.py +99 -105
  47. tme/filters/compose.py +133 -39
  48. tme/filters/ctf.py +51 -102
  49. tme/filters/reconstruction.py +67 -122
  50. tme/filters/wedge.py +296 -325
  51. tme/filters/whitening.py +39 -75
  52. tme/mask.py +2 -2
  53. tme/matching_data.py +87 -15
  54. tme/matching_exhaustive.py +70 -120
  55. tme/matching_optimization.py +9 -63
  56. tme/matching_scores.py +261 -100
  57. tme/matching_utils.py +150 -91
  58. tme/memory.py +1 -0
  59. tme/orientations.py +17 -3
  60. tme/preprocessor.py +0 -239
  61. tme/rotations.py +102 -70
  62. tme/structure.py +601 -631
  63. tme/types.py +1 -0
  64. pytme-0.3.1.post2.dist-info/RECORD +0 -133
  65. {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/estimate_memory_usage.py +0 -0
  66. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/WHEEL +0 -0
  67. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/entry_points.txt +0 -0
  68. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/licenses/LICENSE +0 -0
  69. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/top_level.txt +0 -0
scripts/pytme_runner.py CHANGED
@@ -8,7 +8,7 @@ import subprocess
8
8
  from abc import ABC, abstractmethod
9
9
 
10
10
  from pathlib import Path
11
- from dataclasses import dataclass
11
+ from dataclasses import dataclass, field, fields
12
12
  from typing import Dict, List, Optional, Any
13
13
 
14
14
  from tme.backends import backend as be
@@ -149,7 +149,7 @@ class AnalysisDatasetDiscovery(DatasetDiscovery):
149
149
 
150
150
  #: Glob pattern for TM pickle files, e.g., "/data/results/*.pickle"
151
151
  input_patterns: List[str]
152
- #: List of glob patterns for background files, e.g., ["/data/bg1/*.pickle", "/data/bg2/*.pickle"]
152
+ #: List of glob patterns for background files, e.g., ["bg1/*.pickle", "bg2/*."]
153
153
  background_patterns: List[str] = None
154
154
  #: Target masks, e.g., "/data/masks/*.mrc"
155
155
  mask_patterns: Optional[str] = None
@@ -230,46 +230,47 @@ class TMParameters:
230
230
  axis_sampling: Optional[float] = None
231
231
  axis_symmetry: int = 1
232
232
  cone_axis: int = 2
233
- invert_cone: bool = False
234
- no_use_optimized_set: bool = False
233
+ invert_cone: bool = field(default=False, metadata={"flag": True})
234
+ no_use_optimized_set: bool = field(default=False, metadata={"flag": True})
235
235
 
236
236
  # Microscope parameters
237
237
  acceleration_voltage: float = 300.0 # kV
238
238
  spherical_aberration: float = 2.7e7 # Å
239
239
  amplitude_contrast: float = 0.07
240
240
  defocus: Optional[float] = None # Å
241
- phase_shift: float = 0.0 # Dg
241
+ phase_shift: float = 0.0 # degrees
242
242
 
243
243
  # Processing options
244
244
  lowpass: Optional[float] = None # Å
245
245
  highpass: Optional[float] = None # Å
246
246
  pass_format: str = "sampling_rate" # "sampling_rate", "voxel", "frequency"
247
- no_pass_smooth: bool = True
247
+ no_pass_smooth: bool = field(default=True, metadata={"flag": False})
248
248
  interpolation_order: int = 3
249
249
  score_threshold: float = 0.0
250
250
  score: str = "FLCSphericalMask"
251
+ background_correction: Optional[str] = None
251
252
 
252
253
  # Weighting and correction
253
254
  tilt_weighting: Optional[str] = None # "angle", "relion", "grigorieff"
254
255
  wedge_axes: str = "2,0"
255
- whiten_spectrum: bool = False
256
- scramble_phases: bool = False
257
- invert_target_contrast: bool = False
256
+ whiten_spectrum: bool = field(default=False, metadata={"flag": True})
257
+ scramble_phases: bool = field(default=False, metadata={"flag": True})
258
+ invert_target_contrast: bool = field(default=False, metadata={"flag": True})
258
259
 
259
260
  # CTF parameters
260
261
  ctf_file: Optional[Path] = None
261
- no_flip_phase: bool = True
262
- correct_defocus_gradient: bool = False
262
+ no_flip_phase: bool = field(default=True, metadata={"flag": False})
263
+ correct_defocus_gradient: bool = field(default=False, metadata={"flag": True})
263
264
 
264
265
  # Performance options
265
- centering: bool = False
266
- pad_edges: bool = False
267
- pad_filter: bool = False
268
- use_mixed_precision: bool = False
269
- use_memmap: bool = False
266
+ centering: bool = field(default=False, metadata={"flag": True})
267
+ pad_edges: bool = field(default=False, metadata={"flag": True})
268
+ pad_filter: bool = field(default=False, metadata={"flag": True})
269
+ use_mixed_precision: bool = field(default=False, metadata={"flag": True})
270
+ use_memmap: bool = field(default=False, metadata={"flag": True})
270
271
 
271
272
  # Analysis options
272
- peak_calling: bool = False
273
+ peak_calling: bool = field(default=False, metadata={"flag": True})
273
274
  num_peaks: int = 1000
274
275
 
275
276
  # Backend selection
@@ -279,7 +280,7 @@ class TMParameters:
279
280
  # Reconstruction
280
281
  reconstruction_filter: str = "ramp"
281
282
  reconstruction_interpolation_order: int = 1
282
- no_filter_target: bool = False
283
+ no_filter_target: bool = field(default=False, metadata={"flag": True})
283
284
 
284
285
  def __post_init__(self):
285
286
  """Validate parameters and convert units."""
@@ -337,6 +338,9 @@ class TMParameters:
337
338
  args["ctf-file"] = str(files.metadata)
338
339
  args["tilt-angles"] = str(files.metadata)
339
340
 
341
+ if self.background_correction:
342
+ args["background-correction"] = self.background_correction
343
+
340
344
  # Optional parameters
341
345
  if self.lowpass:
342
346
  args["lowpass"] = self.lowpass
@@ -378,43 +382,26 @@ class TMParameters:
378
382
  return {k: v for k, v in args.items() if v is not None}
379
383
 
380
384
  def get_flags(self) -> List[str]:
381
- """Get boolean flags for pyTME command."""
382
385
  flags = []
383
- if self.whiten_spectrum:
384
- flags.append("whiten-spectrum")
385
- if self.scramble_phases:
386
- flags.append("scramble-phases")
387
- if self.invert_target_contrast:
388
- flags.append("invert-target-contrast")
389
- if self.centering:
390
- flags.append("centering")
391
- if self.pad_edges:
392
- flags.append("pad-edges")
393
- if self.pad_filter:
394
- flags.append("pad-filter")
395
- if not self.no_pass_smooth:
396
- flags.append("no-pass-smooth")
397
- if self.use_mixed_precision:
398
- flags.append("use-mixed-precision")
399
- if self.use_memmap:
400
- flags.append("use-memmap")
401
- if self.peak_calling:
402
- flags.append("peak-calling")
403
- if not self.no_flip_phase:
404
- flags.append("no-flip-phase")
405
- if self.correct_defocus_gradient:
406
- flags.append("correct-defocus-gradient")
407
- if self.invert_cone:
408
- flags.append("invert-cone")
409
- if self.no_use_optimized_set:
410
- flags.append("no-use-optimized-set")
411
- if self.no_filter_target:
412
- flags.append("no-filter-target")
386
+
387
+ for field_info in fields(self):
388
+ flag_meta = field_info.metadata.get("flag")
389
+ if flag_meta is None:
390
+ continue
391
+
392
+ value = getattr(self, field_info.name)
393
+ if not isinstance(value, bool):
394
+ continue
395
+
396
+ flag_name = field_info.name.replace("_", "-")
397
+ if (flag_meta is True and value) or (flag_meta is False and not value):
398
+ flags.append(flag_name)
399
+
413
400
  return flags
414
401
 
415
402
 
416
403
  @dataclass
417
- class AnalysisParameters:
404
+ class AnalysisParameters(TMParameters):
418
405
  """Parameters for template matching analysis and peak calling."""
419
406
 
420
407
  # Peak calling
@@ -424,13 +411,12 @@ class AnalysisParameters:
424
411
  max_score: Optional[float] = None
425
412
  min_distance: int = 5
426
413
  min_boundary_distance: int = 0
427
- mask_edges: bool = False
414
+ mask_edges: bool = field(default=False, metadata={"flag": True})
428
415
  n_false_positives: Optional[int] = None
429
416
 
430
417
  # Output format
431
418
  output_format: str = "relion4"
432
419
  output_directory: Optional[str] = None
433
- angles_clockwise: bool = False
434
420
 
435
421
  # Advanced options
436
422
  extraction_box_size: Optional[int] = None
@@ -468,15 +454,6 @@ class AnalysisParameters:
468
454
 
469
455
  return {k: v for k, v in args.items() if v is not None}
470
456
 
471
- def get_flags(self) -> List[str]:
472
- """Get boolean flags for analyze_template_matching command."""
473
- flags = []
474
- if self.mask_edges:
475
- flags.append("mask-edges")
476
- if self.angles_clockwise:
477
- flags.append("angles-clockwise")
478
- return flags
479
-
480
457
 
481
458
  @dataclass
482
459
  class ComputeResources:
@@ -592,12 +569,10 @@ class ExecutionBackend(ABC):
592
569
  @abstractmethod
593
570
  def submit_job(self, task) -> str:
594
571
  """Submit a single job and return job ID or status."""
595
- pass
596
572
 
597
573
  @abstractmethod
598
574
  def submit_jobs(self, tasks: List) -> List[str]:
599
575
  """Submit multiple jobs and return list of job IDs."""
600
- pass
601
576
 
602
577
 
603
578
  class SlurmBackend(ExecutionBackend):
@@ -841,6 +816,13 @@ def parse_args():
841
816
  default="FLCSphericalMask",
842
817
  help="Template matching scoring function. Use FLC if mask is not spherical.",
843
818
  )
819
+ tm_group.add_argument(
820
+ "--background-correction",
821
+ choices=["phase-scrambling"],
822
+ required=False,
823
+ help="Transform cross-correlation into SNR-like values using a given method: "
824
+ "'phase-scrambling' uses a phase-scrambled template as background",
825
+ )
844
826
  tm_group.add_argument(
845
827
  "--score-threshold", type=float, default=0.0, help="Minimum score threshold"
846
828
  )
@@ -1006,11 +988,6 @@ def parse_args():
1006
988
  default="relion4",
1007
989
  help="Output format for analysis results",
1008
990
  )
1009
- output_group.add_argument(
1010
- "--angles-clockwise",
1011
- action="store_true",
1012
- help="Report Euler angles in clockwise format expected by RELION",
1013
- )
1014
991
 
1015
992
  advanced_group = analysis_parser.add_argument_group("Advanced Options")
1016
993
  advanced_group.add_argument(
@@ -1077,6 +1054,7 @@ def run_matching(args, resources):
1077
1054
  backend=args.backend,
1078
1055
  whiten_spectrum=args.whiten_spectrum,
1079
1056
  scramble_phases=args.scramble_phases,
1057
+ background_correction=args.background_correction,
1080
1058
  )
1081
1059
  print_params = params.to_command_args(files[0], "")
1082
1060
  _ = print_params.pop("target")
@@ -1132,7 +1110,6 @@ def run_analysis(args, resources):
1132
1110
  mask_edges=args.mask_edges,
1133
1111
  n_false_positives=args.n_false_positives,
1134
1112
  output_format=args.output_format,
1135
- angles_clockwise=args.angles_clockwise,
1136
1113
  extraction_box_size=args.extraction_box_size,
1137
1114
  )
1138
1115
  print_params = params.to_command_args(files[0], Path(""))
@@ -1,28 +1,38 @@
1
1
  import pytest
2
2
 
3
- from tme.filters import Compose
3
+ from tme.filters import Compose, ComposableFilter
4
4
  from tme.backends import backend as be
5
5
 
6
6
 
7
- def mock_transform1(**kwargs):
8
- return {"data": be.ones((10, 10)), "is_multiplicative_filter": True}
7
+ class MockFilter(ComposableFilter):
8
+ def _evaluate(self, *args, **kwargs):
9
+ return {"data": be.ones((10, 10)), "shape": (10, 10)}
10
+
11
+
12
+ class MockFilterNoMult(ComposableFilter):
13
+ def _evaluate(self, *args, **kwargs):
14
+ return {
15
+ "data": be.ones((10, 10)) * 3,
16
+ "shape": (10, 10),
17
+ "is_multiplicative_filter": False,
18
+ }
9
19
 
10
20
 
11
- def mock_transform2(**kwargs):
12
- return {"data": be.ones((10, 10)) * 2, "is_multiplicative_filter": True}
21
+ mock_transform = MockFilter()
22
+ mock_transform_nomult = MockFilterNoMult()
13
23
 
14
24
 
15
- def mock_transform3(**kwargs):
16
- return {"extra_info": "test"}
25
+ def mock_transform_error(**kwargs):
26
+ return {"data": be.ones((10, 10)), "is_multiplicative_filter": True}
17
27
 
18
28
 
19
29
  class TestCompose:
20
30
  @pytest.fixture
21
31
  def compose_instance(self):
22
- return Compose((mock_transform1, mock_transform2, mock_transform3))
32
+ return Compose((mock_transform, mock_transform, mock_transform))
23
33
 
24
34
  def test_init(self):
25
- transforms = (mock_transform1, mock_transform2)
35
+ transforms = (mock_transform, mock_transform)
26
36
  compose = Compose(transforms)
27
37
  assert compose.transforms == transforms
28
38
 
@@ -32,45 +42,36 @@ class TestCompose:
32
42
  assert result == {}
33
43
 
34
44
  def test_call_single_transform(self):
35
- compose = Compose((mock_transform1,))
36
- result = compose()
45
+ compose = Compose((mock_transform,))
46
+ result = compose(return_real_fourier=False)
37
47
  assert "data" in result
38
- assert result.get("is_multiplicative_filter", False)
39
48
  assert be.allclose(result["data"], be.ones((10, 10)))
40
49
 
41
50
  def test_call_multiple_transforms(self, compose_instance):
42
- result = compose_instance()
51
+ result = compose_instance(return_real_fourier=False)
43
52
  assert "data" in result
44
- assert "extra_info" not in result
45
- assert be.allclose(result["data"], be.ones((10, 10)) * 2)
53
+ assert be.allclose(result["data"], be.ones((10, 10)))
46
54
 
47
55
  def test_multiplicative_filter_composition(self):
48
- compose = Compose((mock_transform1, mock_transform2))
49
- result = compose()
56
+ compose = Compose((mock_transform, mock_transform))
57
+ result = compose(return_real_fourier=False)
50
58
  assert "data" in result
51
- assert be.allclose(result["data"], be.ones((10, 10)) * 2)
59
+ assert be.allclose(result["data"], be.ones((10, 10)))
52
60
 
53
61
  @pytest.mark.parametrize(
54
62
  "kwargs", [{}, {"extra_param": "test"}, {"data": be.zeros((5, 5))}]
55
63
  )
56
64
  def test_call_with_kwargs(self, compose_instance, kwargs):
57
- result = compose_instance(**kwargs)
65
+ result = compose_instance(**kwargs, return_real_fourier=False)
58
66
  assert "data" in result
59
67
  assert "extra_info" not in result
60
68
 
61
69
  def test_non_multiplicative_filter(self):
62
- def non_mult_transform(**kwargs):
63
- return {"data": be.ones((10, 10)) * 3, "is_multiplicative_filter": False}
64
-
65
- compose = Compose((mock_transform1, non_mult_transform))
66
- result = compose()
70
+ compose = Compose((mock_transform, mock_transform_nomult))
71
+ result = compose(return_real_fourier=False)
67
72
  assert "data" in result
68
73
  assert be.allclose(result["data"], be.ones((10, 10)) * 3)
69
74
 
70
75
  def test_error_handling(self):
71
- def error_transform(**kwargs):
72
- raise ValueError("Test error")
73
-
74
- compose = Compose((mock_transform1, error_transform))
75
- with pytest.raises(ValueError, match="Test error"):
76
- compose()
76
+ with pytest.raises(ValueError):
77
+ Compose((mock_transform, mock_transform_error))()
@@ -71,52 +71,43 @@ class TestBandPassFilter:
71
71
  shape_is_real_fourier: bool,
72
72
  ):
73
73
  band_pass_filter.use_gaussian = use_gaussian
74
- band_pass_filter.return_real_fourier = return_real_fourier
75
74
  band_pass_filter.shape_is_real_fourier = shape_is_real_fourier
76
75
 
77
76
  result = band_pass_filter(shape=(10, 10), lowpass=0.2, highpass=0.8)
78
77
 
79
78
  assert isinstance(result, dict)
80
79
  assert "data" in result
81
- assert "is_multiplicative_filter" in result
82
80
  assert isinstance(result["data"], type(be.ones((1,))))
83
- assert result["is_multiplicative_filter"] is True
84
81
 
85
82
  def test_default_values(self, band_pass_filter: BandPassReconstructed):
86
83
  assert band_pass_filter.lowpass is None
87
84
  assert band_pass_filter.highpass is None
88
85
  assert band_pass_filter.sampling_rate == 1
89
86
  assert band_pass_filter.use_gaussian is True
90
- assert band_pass_filter.return_real_fourier is False
91
87
 
92
88
  @pytest.mark.parametrize("shape", ((10, 10), (20, 20, 20), (30, 30)))
93
89
  def test_return_real_fourier(self, shape: Tuple[int]):
94
- bpf = BandPassReconstructed(return_real_fourier=True)
90
+ bpf = BandPassReconstructed()
95
91
  result = bpf(shape=shape, lowpass=0.2, highpass=0.8)
96
- expected_shape = tuple(compute_fourier_shape(shape, False))
97
- assert result["data"].shape == expected_shape
92
+ assert result["data"].shape == shape
98
93
 
99
94
 
100
95
  class TestLinearWhiteningFilter:
101
96
  @pytest.mark.parametrize(
102
- "shape, n_bins, batch_dimension",
97
+ "shape, n_bins",
103
98
  [
104
- ((10, 10), None, None),
105
- ((20, 20, 20), 15, 0),
106
- ((30, 30, 30), 20, 1),
107
- ((40, 40, 40, 40), 25, 2),
99
+ ((10, 10), None),
100
+ ((20, 20, 20), 15),
101
+ ((30, 30, 30), 20),
102
+ ((40, 40, 40, 40), 25),
108
103
  ],
109
104
  )
110
- def test_compute_spectrum(
111
- self, shape: Tuple[int], n_bins: int, batch_dimension: int
112
- ):
105
+ def test_compute_spectrum(self, shape: Tuple[int], n_bins: int):
113
106
  data_rfft = be.fft.rfftn(be.random.random(shape))
114
107
  bins, radial_averages = LinearWhiteningFilter._compute_spectrum(
115
- data_rfft, n_bins, batch_dimension
116
- )
117
- data_shape = tuple(
118
- int(x) for i, x in enumerate(data_rfft.shape) if i != batch_dimension
108
+ data_rfft, n_bins
119
109
  )
110
+ data_shape = tuple(int(x) for i, x in enumerate(data_rfft.shape))
120
111
 
121
112
  assert isinstance(bins, np.ndarray)
122
113
  assert isinstance(radial_averages, np.ndarray)
@@ -140,9 +131,8 @@ class TestLinearWhiteningFilter:
140
131
  @pytest.mark.parametrize(
141
132
  "shape, n_bins, batch_dimension, order",
142
133
  [
143
- ((10, 10), None, None, 1),
144
- ((20, 20, 20), 15, 0, 2),
145
- ((30, 30, 30), 20, 1, None),
134
+ ((10, 10), None, (), 1),
135
+ ((20, 20, 20), 15, 0, 1),
146
136
  ],
147
137
  )
148
138
  def test_call_method(
@@ -154,21 +144,17 @@ class TestLinearWhiteningFilter:
154
144
  ):
155
145
  data = be.random.random(shape)
156
146
  result = LinearWhiteningFilter()(
157
- shape=shape,
158
- data=data,
147
+ shape=tuple(x for i, x in enumerate(shape) if i != batch_dimension),
148
+ data_rfft=np.fft.rfftn(data),
159
149
  n_bins=n_bins,
160
- batch_dimension=batch_dimension,
150
+ axes=batch_dimension,
161
151
  order=order,
162
152
  )
163
153
 
164
154
  assert isinstance(result, dict)
165
155
  assert result.get("data", False) is not False
166
- assert result.get("is_multiplicative_filter", False)
167
156
  assert isinstance(result["data"], type(be.ones((1,))))
168
- data_shape = tuple(
169
- int(x) for i, x in enumerate(data.shape) if i != batch_dimension
170
- )
171
- assert result["data"].shape == tuple(compute_fourier_shape(data_shape, False))
157
+ assert result["data"].shape == shape
172
158
 
173
159
  def test_call_method_with_data_rfft(self):
174
160
  shape = (30, 30, 30)
@@ -179,14 +165,13 @@ class TestLinearWhiteningFilter:
179
165
 
180
166
  assert isinstance(result, dict)
181
167
  assert result.get("data", False) is not False
182
- assert result.get("is_multiplicative_filter", False)
183
168
  assert isinstance(result["data"], type(be.ones((1,))))
184
169
  assert result["data"].shape == data_rfft.shape
185
170
 
186
171
  @pytest.mark.parametrize("shape", [(10, 10), (20, 20, 20), (30, 30, 30)])
187
172
  def test_filter_mask_range(self, shape: Tuple[int]):
188
173
  data = be.random.random(shape)
189
- result = LinearWhiteningFilter()(shape=shape, data=data)
174
+ result = LinearWhiteningFilter()(shape=shape, data_rfft=np.fft.rfftn(data))
190
175
 
191
176
  filter_mask = result["data"]
192
177
  assert np.all(filter_mask >= 0) and np.all(filter_mask <= 1)
@@ -1,5 +1,4 @@
1
1
  import pytest
2
- import numpy as np
3
2
 
4
3
  from tme import Density, Structure, Preprocessor
5
4
 
@@ -49,15 +48,6 @@ class TestPreprocessor:
49
48
  high_sigma=high_sigma,
50
49
  )
51
50
 
52
- @pytest.mark.parametrize("smallest_size,largest_size", [(1, 10), (2, 20)])
53
- def test_bandpass_filter(self, smallest_size, largest_size):
54
- _ = self.preprocessor.bandpass_filter(
55
- template=self.structure_density.data,
56
- lowpass=smallest_size,
57
- highpass=largest_size,
58
- sampling_rate=1,
59
- )
60
-
61
51
  @pytest.mark.parametrize("lbd,sigma_range", [(1, (2, 4)), (20, (1, 6))])
62
52
  def test_local_gaussian_alignment_filter(self, lbd, sigma_range):
63
53
  _ = self.preprocessor.local_gaussian_alignment_filter(
@@ -125,12 +115,3 @@ class TestPreprocessor:
125
115
  template=self.structure_density.data,
126
116
  rank=rank,
127
117
  )
128
-
129
- @pytest.mark.parametrize("infinite_plane", [False, True])
130
- def test_continuous_wedge_mask(self, infinite_plane):
131
- _ = self.preprocessor.continuous_wedge_mask(
132
- start_tilt=50,
133
- stop_tilt=-40,
134
- shape=(50, 50, 50),
135
- infinite_plane=infinite_plane,
136
- )
@@ -74,10 +74,22 @@ class TestPreprocessUtils:
74
74
  def test_fftfreqn(self, n, sampling_rate):
75
75
  assert np.allclose(
76
76
  fftfreqn(
77
- shape=(n,), sampling_rate=sampling_rate, compute_euclidean_norm=True
77
+ shape=(n,),
78
+ sampling_rate=sampling_rate,
79
+ compute_euclidean_norm=True,
80
+ fftshift=True,
78
81
  ),
79
82
  np.abs(np.fft.ifftshift(np.fft.fftfreq(n=n, d=sampling_rate))),
80
83
  )
84
+ assert np.allclose(
85
+ fftfreqn(
86
+ shape=(n,),
87
+ sampling_rate=sampling_rate,
88
+ compute_euclidean_norm=True,
89
+ fftshift=False,
90
+ ),
91
+ np.abs(np.fft.fftfreq(n=n, d=sampling_rate)),
92
+ )
81
93
 
82
94
  @pytest.mark.parametrize("shape", ((10,), (10, 15), (10, 15, 30)))
83
95
  def test_crop_real_fourier(self, shape):
tests/test_analyzer.py CHANGED
@@ -109,13 +109,11 @@ class TestRecursiveMasking:
109
109
 
110
110
  @pytest.mark.parametrize("num_peaks", (1, 100))
111
111
  @pytest.mark.parametrize("compute_rotation", (True, False))
112
- @pytest.mark.parametrize("minimum_score", (None, 0.5))
113
- def test__call__(self, num_peaks, compute_rotation, minimum_score):
112
+ def test__call__(self, num_peaks, compute_rotation):
114
113
  peak_caller = PeakCallerRecursiveMasking(
115
114
  shape=self.data.shape,
116
115
  num_peaks=num_peaks,
117
116
  min_distance=self.min_distance,
118
- min_score=minimum_score,
119
117
  )
120
118
  rotation_space, rotation_mapping = None, None
121
119
  if compute_rotation:
@@ -131,13 +129,7 @@ class TestRecursiveMasking:
131
129
  rotation_space=rotation_space,
132
130
  rotation_mapping=rotation_mapping,
133
131
  )
134
-
135
- if minimum_score is None:
136
- assert len(state[0] <= num_peaks)
137
- else:
138
- peaks = state[0].astype(int)
139
- assert np.all(self.data[tuple(peaks.T)] >= minimum_score)
140
-
132
+ assert len(state[0] <= num_peaks)
141
133
 
142
134
  class TestMaxScoreOverRotations:
143
135
  def setup_method(self):
tests/test_backends.py CHANGED
@@ -331,16 +331,6 @@ class TestBackends:
331
331
  real_arr = backend.irfftn(complex_arr)
332
332
  assert np.allclose(arr, backend.to_numpy_array(real_arr), rtol=0.3)
333
333
 
334
- @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
335
- def test_extract_center(self, backend):
336
- new_shape = np.divide(self.x1.shape, 2).astype(int)
337
- base = self.backend.extract_center(arr=self.x1, newshape=new_shape)
338
- other = backend.extract_center(
339
- arr=backend.to_backend_array(self.x1), newshape=new_shape
340
- )
341
-
342
- assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)
343
-
344
334
  @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
345
335
  def test_compute_convolution_shapes(self, backend):
346
336
  base = self.backend.compute_convolution_shapes(self.x1.shape, self.x2.shape)
@@ -359,32 +349,69 @@ class TestBackends:
359
349
  elif dim == 3:
360
350
  arr[20:25, 21:26, 26:31] = 1
361
351
 
362
- rotation_matrix = np.eye(dim)
363
- rotation_matrix[0, 0] = -1
352
+ from tme.rotations import get_rotation_matrices
353
+
354
+ np.random.seed(42)
355
+ rotation_matrix = get_rotation_matrices(
356
+ dim=dim, angular_sampling=10, use_optimized_set=False
357
+ )[-1]
364
358
 
365
359
  out = np.zeros_like(arr)
360
+ out.setflags(write=True)
366
361
 
367
362
  arr_mask, out_mask = None, None
368
363
  if create_mask:
369
364
  arr_mask = np.multiply(np.random.rand(*arr.shape) > 0.5, 1.0)
370
365
  out_mask = np.zeros_like(arr_mask)
366
+ out_mask.setflags(write=True)
367
+
368
+ out, _ = NumpyFFTWBackend().rigid_transform(
369
+ arr=arr,
370
+ arr_mask=arr_mask,
371
+ rotation_matrix=rotation_matrix,
372
+ out=out,
373
+ out_mask=out_mask,
374
+ order=1,
375
+ use_geometric_center=True,
376
+ )
377
+
378
+ arr = backend.to_backend_array(arr.copy())
379
+ out_be = backend.to_backend_array(out.copy())
380
+ if create_mask:
371
381
  arr_mask = backend.to_backend_array(arr_mask)
372
382
  out_mask = backend.to_backend_array(out_mask)
373
383
 
374
- arr = backend.to_backend_array(arr)
375
- out = backend.to_backend_array(arr)
376
-
377
384
  rotation_matrix = backend.to_backend_array(rotation_matrix)
378
385
 
379
- backend.rigid_transform(
386
+ out_be, _ = backend.rigid_transform(
380
387
  arr=arr,
381
388
  arr_mask=arr_mask,
382
389
  rotation_matrix=rotation_matrix,
383
- out=out,
390
+ out=out_be,
384
391
  out_mask=out_mask,
392
+ order=1,
393
+ use_geometric_center=True,
385
394
  )
395
+ out_be = backend.to_numpy_array(out_be)
396
+ assert np.allclose(out, out_be, atol=0.3)
386
397
 
387
- assert np.round(arr.sum(), 3) == np.round(out.sum(), 3)
398
+ @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
399
+ @pytest.mark.parametrize("create_mask", (False, True))
400
+ def test_rigid_transform_identity(self, backend, create_mask):
401
+ dim = 3
402
+ shape = tuple(50 for _ in range(dim))
403
+
404
+ arr = np.zeros(shape)
405
+ arr[20:25, 21:26, 26:31] = 1
406
+
407
+ rotation_matrix = backend.to_backend_array(np.eye(dim))
408
+ out, _ = backend.rigid_transform(
409
+ arr=backend.to_backend_array(arr),
410
+ rotation_matrix=backend.to_backend_array(rotation_matrix),
411
+ order=1,
412
+ use_geometric_center=True,
413
+ )
414
+ assert np.allclose(out, arr, atol=0.01)
388
415
 
389
416
  @pytest.mark.parametrize("dim", (2, 3))
390
417
  @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
@@ -418,6 +445,7 @@ class TestBackends:
418
445
  out=out,
419
446
  out_mask=out_mask,
420
447
  batched=True,
448
+ order=1,
421
449
  )
422
450
 
423
451
  arr_b = backend.to_backend_array(arr_b)
@@ -430,6 +458,7 @@ class TestBackends:
430
458
  rotation_matrix=rotation_matrix,
431
459
  out=out_b[i],
432
460
  out_mask=out_mask if out_mask is None else out_mask[i],
461
+ order=1,
433
462
  )
434
463
 
435
464
  assert np.allclose(arr, arr_b)