pytme 0.3.1.post2__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. pytme-0.3.2.data/scripts/estimate_ram_usage.py +97 -0
  2. {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/match_template.py +213 -196
  3. {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/postprocess.py +40 -78
  4. {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/preprocess.py +4 -5
  5. {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/preprocessor_gui.py +49 -103
  6. {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/pytme_runner.py +46 -69
  7. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/METADATA +3 -2
  8. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/RECORD +68 -65
  9. scripts/estimate_ram_usage.py +97 -0
  10. scripts/match_template.py +213 -196
  11. scripts/match_template_devel.py +1339 -0
  12. scripts/postprocess.py +40 -78
  13. scripts/preprocess.py +4 -5
  14. scripts/preprocessor_gui.py +49 -103
  15. scripts/pytme_runner.py +46 -69
  16. tests/preprocessing/test_compose.py +31 -30
  17. tests/preprocessing/test_frequency_filters.py +17 -32
  18. tests/preprocessing/test_preprocessor.py +0 -19
  19. tests/preprocessing/test_utils.py +13 -1
  20. tests/test_analyzer.py +2 -10
  21. tests/test_backends.py +47 -18
  22. tests/test_density.py +72 -13
  23. tests/test_extensions.py +1 -0
  24. tests/test_matching_cli.py +23 -9
  25. tests/test_matching_exhaustive.py +5 -5
  26. tests/test_matching_utils.py +3 -3
  27. tests/test_orientations.py +12 -0
  28. tests/test_rotations.py +13 -23
  29. tests/test_structure.py +1 -7
  30. tme/__version__.py +1 -1
  31. tme/analyzer/aggregation.py +47 -16
  32. tme/analyzer/base.py +34 -0
  33. tme/analyzer/peaks.py +26 -13
  34. tme/analyzer/proxy.py +14 -0
  35. tme/backends/_jax_utils.py +91 -68
  36. tme/backends/cupy_backend.py +6 -19
  37. tme/backends/jax_backend.py +103 -98
  38. tme/backends/matching_backend.py +0 -17
  39. tme/backends/mlx_backend.py +0 -29
  40. tme/backends/npfftw_backend.py +100 -97
  41. tme/backends/pytorch_backend.py +65 -78
  42. tme/cli.py +2 -2
  43. tme/density.py +44 -57
  44. tme/extensions.cpython-311-darwin.so +0 -0
  45. tme/filters/_utils.py +52 -24
  46. tme/filters/bandpass.py +99 -105
  47. tme/filters/compose.py +133 -39
  48. tme/filters/ctf.py +51 -102
  49. tme/filters/reconstruction.py +67 -122
  50. tme/filters/wedge.py +296 -325
  51. tme/filters/whitening.py +39 -75
  52. tme/mask.py +2 -2
  53. tme/matching_data.py +87 -15
  54. tme/matching_exhaustive.py +70 -120
  55. tme/matching_optimization.py +9 -63
  56. tme/matching_scores.py +261 -100
  57. tme/matching_utils.py +150 -91
  58. tme/memory.py +1 -0
  59. tme/orientations.py +17 -3
  60. tme/preprocessor.py +0 -239
  61. tme/rotations.py +102 -70
  62. tme/structure.py +601 -631
  63. tme/types.py +1 -0
  64. {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/estimate_memory_usage.py +0 -0
  65. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/WHEEL +0 -0
  66. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/entry_points.txt +0 -0
  67. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/licenses/LICENSE +0 -0
  68. {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/top_level.txt +0 -0
tests/test_density.py CHANGED
@@ -101,9 +101,10 @@ class TestDensity:
101
101
 
102
102
  def test_from_file_baseline(self):
103
103
  self.test_to_file(gzip=False)
104
- density = Density.from_file(str(BASEPATH.joinpath("Maps/emd_8621.mrc.gz")))
105
- assert np.allclose(density.origin, (4.35, 2.90, -1.45), rtol=0.1)
106
- assert np.allclose(density.sampling_rate, (1.45), rtol=0.3)
104
+ density = Density.from_file(str(BASEPATH.joinpath("Raw/em_map.map")))
105
+ assert density.shape == (19, 14, 20)
106
+ assert np.allclose(density.origin, (-52.8, -10.56, -52.8), rtol=0.1)
107
+ assert np.allclose(density.sampling_rate, (5.28), rtol=0.1)
107
108
 
108
109
  @pytest.mark.parametrize("extension", ("mrc", "em", "tiff", "h5"))
109
110
  @pytest.mark.parametrize("gzip", (True, False))
@@ -315,7 +316,7 @@ class TestDensity:
315
316
  "cutoff", [DEFAULT_DATA.min() - 1, 0, DEFAULT_DATA.max() - 0.1]
316
317
  )
317
318
  def test_centered(self, cutoff):
318
- centered_density, translation = self.density.centered(cutoff=cutoff)
319
+ centered_density = self.density.centered(cutoff=cutoff)
319
320
  com = centered_density.center_of_mass(centered_density.data, 0)
320
321
 
321
322
  difference = np.abs(
@@ -334,7 +335,7 @@ class TestDensity:
334
335
  box = temp.minimum_enclosing_box(cutoff=0, use_geometric_center=True)
335
336
  temp.adjust_box(box)
336
337
  else:
337
- temp, translation = temp.centered()
338
+ temp = temp.centered()
338
339
 
339
340
  swaps = set(permutations([0, 1, 2]))
340
341
  temp_matrix = np.eye(temp.data.ndim).astype(np.float32)
@@ -436,7 +437,7 @@ class TestDensity:
436
437
  arr = np.linalg.norm(np.indices(shape) - position, axis=0)
437
438
  arr = (arr <= radius).astype(np.float32)
438
439
 
439
- center_of_mass = Density.center_of_mass(arr)
440
+ center_of_mass = Density(arr).center_of_mass()
440
441
  assert np.allclose(center, center_of_mass)
441
442
 
442
443
  @pytest.mark.parametrize(
@@ -451,7 +452,7 @@ class TestDensity:
451
452
  target[5:10, 15:22, 10:13] = 1
452
453
 
453
454
  target = Density(target, sampling_rate=(1, 1, 1), origin=(0, 0, 0))
454
- target, translation = target.centered(cutoff=0)
455
+ target = target.centered(cutoff=0)
455
456
 
456
457
  template = target.copy()
457
458
 
@@ -474,21 +475,79 @@ class TestDensity:
474
475
  assert np.allclose(np.linalg.inv(rotation), initial_rotation, atol=0.2)
475
476
 
476
477
  def test_match_structure_to_density(self):
477
- density = Density.from_file("tests/data/Maps/emd_8621.mrc.gz")
478
- density = density.resample(density.sampling_rate * 4)
479
- structure = Structure.from_file(
480
- "tests/data/Structures/5uz4.cif", filter_by_residues=None
478
+ mask = create_mask(
479
+ mask_type="ellipse",
480
+ center=(20, 20, 20),
481
+ radius=(10, 5, 10),
482
+ shape=(50, 50, 50),
481
483
  )
482
484
 
485
+ coordinates = np.array(np.where(mask > 0), dtype=np.float32).T
486
+ structure = Structure(
487
+ record_type=[
488
+ "ATOM",
489
+ ]
490
+ * coordinates.shape[0],
491
+ atom_serial_number=list(range(coordinates.shape[0])),
492
+ atom_name=[
493
+ "C",
494
+ ]
495
+ * coordinates.shape[0],
496
+ atom_coordinate=coordinates,
497
+ alternate_location_indicator=[
498
+ ".",
499
+ ]
500
+ * coordinates.shape[0],
501
+ residue_name=[
502
+ "GLY",
503
+ ]
504
+ * coordinates.shape[0],
505
+ chain_identifier=[
506
+ "A",
507
+ ]
508
+ * coordinates.shape[0],
509
+ residue_sequence_number=[
510
+ 0,
511
+ ]
512
+ * coordinates.shape[0],
513
+ code_for_residue_insertion=[
514
+ "?",
515
+ ]
516
+ * coordinates.shape[0],
517
+ occupancy=[
518
+ 0,
519
+ ]
520
+ * coordinates.shape[0],
521
+ temperature_factor=[
522
+ 0,
523
+ ]
524
+ * coordinates.shape[0],
525
+ segment_identifier=[
526
+ "1",
527
+ ]
528
+ * coordinates.shape[0],
529
+ element_symbol=[
530
+ "C",
531
+ ]
532
+ * coordinates.shape[0],
533
+ charge=[
534
+ "?",
535
+ ]
536
+ * coordinates.shape[0],
537
+ metadata={},
538
+ )
539
+ density = Density(mask, sampling_rate=1.0)
540
+ density = density.resample(density.sampling_rate * 2)
541
+
483
542
  initial_translation = np.array([-1, 0, 5])
484
543
  initial_rotation = euler_to_rotationmatrix((-10, 2, 5))
485
- structure.rigid_transform(
544
+ structure_mod = structure.rigid_transform(
486
545
  translation=initial_translation, rotation_matrix=initial_rotation
487
546
  )
488
547
  np.random.seed(12)
489
548
  ret = Density.match_structure_to_density(
490
549
  target=density,
491
- template=structure,
550
+ template=structure_mod,
492
551
  cutoff_target=0,
493
552
  scoring_method="CrossCorrelation",
494
553
  maxiter=10,
tests/test_extensions.py CHANGED
@@ -53,6 +53,7 @@ class TestExtensions:
53
53
  @pytest.mark.parametrize("min_distance", [0, 5, 10])
54
54
  def test_find_candidate_indices(self, dimension, dtype, min_distance):
55
55
  coordinates = COORDINATES[dimension].astype(dtype)
56
+ print(coordinates.shape)
56
57
 
57
58
  min_distance = np.array([min_distance]).astype(dtype)[0]
58
59
 
tests/test_matching_cli.py CHANGED
@@ -10,7 +10,7 @@ from tme import Density, Orientations
10
10
  from tme.backends import backend as be
11
11
 
12
12
  np.random.seed(42)
13
- available_backends = (x for x in be.available_backends() if x != "mlx")
13
+ available_backends = tuple(x for x in be.available_backends() if x != "mlx")
14
14
 
15
15
 
16
16
  def argdict_to_command(input_args, executable: str):
@@ -29,7 +29,7 @@ def argdict_to_command(input_args, executable: str):
29
29
  return " ".join(ret)
30
30
 
31
31
 
32
- class TestMatchTemplate:
32
+ class TestSetup:
33
33
  @classmethod
34
34
  def setup_class(cls):
35
35
  target = np.random.rand(20, 20, 20)
@@ -98,6 +98,7 @@ class TestMatchTemplate:
98
98
  use_target_mask: bool = False,
99
99
  backend: str = "numpyfftw",
100
100
  test_rejection_sampling: bool = False,
101
+ background_correction: str = None,
101
102
  ):
102
103
  output_path = tempfile.NamedTemporaryFile(delete=False, suffix="pickle").name
103
104
 
@@ -120,6 +121,9 @@ class TestMatchTemplate:
120
121
  if test_rejection_sampling:
121
122
  argdict["--orientations"] = self.orientations_path
122
123
 
124
+ if background_correction is not None:
125
+ argdict["--background-correction"] = background_correction
126
+
123
127
  if test_filter:
124
128
  argdict["--lowpass"] = 30
125
129
  argdict["--defocus"] = 3000
@@ -136,11 +140,15 @@ class TestMatchTemplate:
136
140
  assert ret.returncode == 0
137
141
  return output_path
138
142
 
143
+
144
+ class TestMatchTemplate(TestSetup):
145
+
139
146
  @pytest.mark.parametrize("backend", available_backends)
140
147
  @pytest.mark.parametrize("call_peaks", (False, True))
141
148
  @pytest.mark.parametrize("use_template_mask", (False, True))
142
149
  @pytest.mark.parametrize("test_filter", (False, True))
143
150
  @pytest.mark.parametrize("test_rejection_sampling", (False, True))
151
+ @pytest.mark.parametrize("background_correction", (None, "phase-scrambling"))
144
152
  def test_match_template(
145
153
  self,
146
154
  backend: bool,
@@ -148,8 +156,14 @@ class TestMatchTemplate:
148
156
  use_template_mask: bool,
149
157
  test_filter: bool,
150
158
  test_rejection_sampling: bool,
159
+ background_correction: str,
151
160
  ):
152
- if backend == "jax" and (call_peaks or test_rejection_sampling):
161
+ # Jax does not support peak calling yet
162
+ if backend == "jax" and call_peaks:
163
+ return None
164
+
165
+ # These use different analyzers raising an error in the interface
166
+ if call_peaks and test_rejection_sampling:
153
167
  return None
154
168
 
155
169
  self.run_matching(
@@ -163,10 +177,11 @@ class TestMatchTemplate:
163
177
  template_mask_path=self.template_mask_path,
164
178
  target_mask_path=self.target_mask_path,
165
179
  test_rejection_sampling=test_rejection_sampling,
180
+ background_correction=background_correction,
166
181
  )
167
182
 
168
183
 
169
- class TestPostprocessing(TestMatchTemplate):
184
+ class TestPostprocessing(TestSetup):
170
185
  @classmethod
171
186
  def setup_class(cls):
172
187
  super().setup_class()
@@ -256,7 +271,6 @@ class TestPostprocessing(TestMatchTemplate):
256
271
  }
257
272
  cmd = argdict_to_command(argdict, executable="postprocess.py")
258
273
  ret = subprocess.run(cmd, capture_output=True, shell=True)
259
- print(ret)
260
274
 
261
275
  match output_format:
262
276
  case "orientations":
@@ -289,14 +303,14 @@ class TestPostprocessing(TestMatchTemplate):
289
303
  assert ret.returncode == 0
290
304
 
291
305
 
292
- class TestEstimateMemoryUsage(TestMatchTemplate):
306
+ class TestEstimateMemoryUsage(TestSetup):
293
307
  @classmethod
294
308
  def setup_class(cls):
295
309
  super().setup_class()
296
310
 
297
311
  @pytest.mark.parametrize("ncores", (1, 4, 8))
298
312
  @pytest.mark.parametrize("pad_edges", (False, True))
299
- def test_estimation(self, ncores, pad_edges):
313
+ def test_estimation_cli(self, ncores, pad_edges):
300
314
 
301
315
  argdict = {
302
316
  "-m": self.target_path,
@@ -311,7 +325,7 @@ class TestEstimateMemoryUsage(TestMatchTemplate):
311
325
  assert ret.returncode == 0
312
326
 
313
327
 
314
- class TestPreprocess(TestMatchTemplate):
328
+ class TestPreprocess(TestSetup):
315
329
  @classmethod
316
330
  def setup_class(cls):
317
331
  super().setup_class()
@@ -319,7 +333,7 @@ class TestPreprocess(TestMatchTemplate):
319
333
  @pytest.mark.parametrize("backend", available_backends)
320
334
  @pytest.mark.parametrize("align_axis", (False, True))
321
335
  @pytest.mark.parametrize("invert_contrast", (False, True))
322
- def test_estimation(self, backend, align_axis, invert_contrast):
336
+ def test_preprocess_cli(self, backend, align_axis, invert_contrast):
323
337
 
324
338
  argdict = {
325
339
  "-m": self.target_path,
tests/test_matching_exhaustive.py CHANGED
@@ -1,5 +1,5 @@
1
- import numpy as np
2
1
  import pytest
2
+ import numpy as np
3
3
 
4
4
  from scipy.ndimage import laplace
5
5
 
@@ -7,7 +7,7 @@ from tme.matching_data import MatchingData
7
7
  from tme.memory import MATCHING_MEMORY_REGISTRY
8
8
  from tme.analyzer import MaxScoreOverRotations, PeakCallerSort
9
9
  from tme.matching_exhaustive import (
10
- scan_subsets,
10
+ match_exhaustive,
11
11
  MATCHING_EXHAUSTIVE_REGISTER,
12
12
  register_matching_exhaustive,
13
13
  )
@@ -35,11 +35,11 @@ class TestMatchExhaustive:
35
35
  self.coordinates_weights = None
36
36
  self.rotations = None
37
37
 
38
- @pytest.mark.parametrize("evaluate_peak", (True,))
38
+ @pytest.mark.parametrize("evaluate_peak", (True, False))
39
39
  @pytest.mark.parametrize("score", tuple(MATCHING_EXHAUSTIVE_REGISTER.keys()))
40
40
  @pytest.mark.parametrize("job_schedule", ((2, 1),))
41
41
  @pytest.mark.parametrize("pad_edge", (False, True))
42
- def test_scan_subset(
42
+ def test_match_exhaustive(
43
43
  self,
44
44
  score: str,
45
45
  job_schedule: int,
@@ -64,7 +64,7 @@ class TestMatchExhaustive:
64
64
  if evaluate_peak:
65
65
  callback_class = PeakCallerSort
66
66
 
67
- ret = scan_subsets(
67
+ ret = match_exhaustive(
68
68
  matching_data=matching_data,
69
69
  matching_setup=setup,
70
70
  matching_score=process,
tests/test_matching_utils.py CHANGED
@@ -15,7 +15,7 @@ from tme.matching_utils import (
15
15
  apply_convolution_mode,
16
16
  write_pickle,
17
17
  load_pickle,
18
- _normalize_template_overflow_safe,
18
+ _standardize_safe,
19
19
  )
20
20
 
21
21
  BASEPATH = files("tests.data")
@@ -127,12 +127,12 @@ class TestMatchingUtils:
127
127
  loaded_data = load_pickle(filename)
128
128
  assert np.array_equal(loaded_data, data)
129
129
 
130
- def test_normalize_template_overflow_safe(self):
130
+ def test_standardize_safe(self):
131
131
  template = be.random.random((10, 10)).astype(be.float32)
132
132
  mask = be.ones_like(template)
133
133
  n_observations = 100.0
134
134
 
135
- result = _normalize_template_overflow_safe(template, mask, n_observations)
135
+ result = _standardize_safe(template, mask, n_observations)
136
136
  assert result.shape == template.shape
137
137
  assert result.dtype == template.dtype
138
138
  assert np.allclose(result.mean(), 0, atol=0.1)
tests/test_orientations.py CHANGED
@@ -95,6 +95,18 @@ class TestDensity:
95
95
  self.orientations.rotations, orientations_new.rotations, atol=1e-3
96
96
  )
97
97
 
98
+ @pytest.mark.parametrize("input_format", ("text", "star", "tbl"))
99
+ @pytest.mark.parametrize("output_format", ("text", "star", "tbl"))
100
+ def test_file_format_io(self, input_format: str, output_format: str):
101
+ _, output_file = mkstemp(suffix=f".{input_format}")
102
+ _, output_file2 = mkstemp(suffix=f".{output_format}")
103
+
104
+ self.orientations.to_file(output_file)
105
+ orientations_new = Orientations.from_file(output_file)
106
+ orientations_new.to_file(output_file2)
107
+
108
+ assert True
109
+
98
110
  @pytest.mark.parametrize("drop_oob", (True, False))
99
111
  @pytest.mark.parametrize("shape", (10, 40, 80))
100
112
  @pytest.mark.parametrize("odd", (True, False))
tests/test_rotations.py CHANGED
@@ -35,24 +35,19 @@ class TestRotations:
35
35
  )
36
36
 
37
37
  @pytest.mark.parametrize(
38
- "initial_vector, target_vector, convention",
38
+ "initial_vector, target_vector",
39
39
  [
40
- ([1, 0, 0], [0, 1, 0], None),
41
- ([0, 1, 0], [0, 0, 1], "zyx"),
42
- ([1, 1, 1], [1, 0, 0], "xyz"),
40
+ ([1, 0, 0], [0, 1, 0]),
41
+ ([0, 1, 0], [0, 0, 1]),
42
+ ([1, 1, 1], [1, 0, 0]),
43
43
  ],
44
44
  )
45
- def test_align_vectors(self, initial_vector, target_vector, convention):
46
- result = align_vectors(initial_vector, target_vector, convention)
45
+ def test_align_vectors(self, initial_vector, target_vector):
46
+ result = align_vectors(initial_vector, target_vector)
47
47
 
48
48
  assert isinstance(result, np.ndarray)
49
- if convention is None:
50
- assert result.shape == (3, 3)
51
- assert np.allclose(np.dot(result, result.T), np.eye(3), atol=1e-6)
52
- else:
53
- assert len(result) == 3
54
- result = Rotation.from_euler(convention, result, degrees=True).as_matrix()
55
- assert np.allclose(np.dot(result, result.T), np.eye(3), atol=1e-6)
49
+ assert result.shape == (3, 3)
50
+ assert np.allclose(np.dot(result, result.T), np.eye(3), atol=1e-6)
56
51
 
57
52
  rotated = np.dot(Rotation.from_matrix(result).as_matrix(), initial_vector)
58
53
  assert np.allclose(
@@ -62,11 +57,11 @@ class TestRotations:
62
57
  )
63
58
 
64
59
  @pytest.mark.parametrize(
65
- "cone_angle, cone_sampling, axis_angle, axis_sampling, vector, n_symmetry, convention",
60
+ "cone_angle, cone_sampling, axis_angle, axis_sampling, vector, n_symmetry",
66
61
  [
67
- (30, 5, 360, None, (1, 0, 0), 1, None),
68
- (45, 10, 180, 15, (0, 1, 0), 2, "zyx"),
69
- (60, 15, 90, 30, (0, 0, 1), 4, "xyz"),
62
+ (30, 5, 360, None, (1, 0, 0), 1),
63
+ (45, 10, 180, 15, (0, 1, 0), 2),
64
+ (60, 15, 90, 30, (0, 0, 1), 4),
70
65
  ],
71
66
  )
72
67
  def test_get_cone_rotations(
@@ -77,7 +72,6 @@ class TestRotations:
77
72
  axis_sampling,
78
73
  vector,
79
74
  n_symmetry,
80
- convention,
81
75
  ):
82
76
  result = get_cone_rotations(
83
77
  cone_angle=cone_angle,
@@ -86,14 +80,10 @@ class TestRotations:
86
80
  axis_sampling=axis_sampling,
87
81
  reference=vector,
88
82
  n_symmetry=n_symmetry,
89
- seq=convention,
90
83
  )
91
84
 
92
85
  assert isinstance(result, np.ndarray)
93
- if convention is None:
94
- assert result.shape[1:] == (3, 3)
95
- else:
96
- assert result.shape[1] == 3
86
+ assert result.shape[1:] == (3, 3)
97
87
 
98
88
  def test_euler_conversion(self):
99
89
  rotation_matrix_initial = np.array(
tests/test_structure.py CHANGED
@@ -6,7 +6,6 @@ import pytest
6
6
  import numpy as np
7
7
 
8
8
  from tme import Structure
9
- from tme.matching_utils import minimum_enclosing_box
10
9
  from tme.rotations import euler_to_rotationmatrix
11
10
 
12
11
 
@@ -196,11 +195,6 @@ class TestStructure:
196
195
  assert center_of_mass.shape[0] == self.structure.atom_coordinate.shape[1]
197
196
  assert np.allclose(center_of_mass, [-0.89391639, 29.94908928, -2.64736741])
198
197
 
199
- def test_centered(self):
200
- ret, translation = self.structure.centered()
201
- box = minimum_enclosing_box(coordinates=self.structure.atom_coordinate.T)
202
- assert np.allclose(ret.center_of_mass(), np.divide(box, 2), atol=1)
203
-
204
198
  def test__get_atom_weights_error(self):
205
199
  with pytest.raises(NotImplementedError):
206
200
  self.structure._get_atom_weights(
@@ -242,6 +236,6 @@ class TestStructure:
242
236
  assert final_rmsd <= 0.1
243
237
 
244
238
  aligned, final_rmsd = Structure.align_structures(
245
- self.structure, structure_transform, sampling_rate=1
239
+ self.structure, structure_transform
246
240
  )
247
241
  assert final_rmsd <= 1
tme/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.3.1"
1
+ __version__ = "0.3.2"
tme/analyzer/aggregation.py CHANGED
@@ -192,6 +192,15 @@ class MaxScoreOverRotations(AbstractAnalyzer):
192
192
  )
193
193
  return scores, rotations, rotation_mapping, ssum
194
194
 
195
+ def correct_background(self, state, mean=0, inv_std=1, **kwargs):
196
+ scores, rotations, rotation_mapping, ssum = state
197
+
198
+ scores = be.subtract(scores, mean, out=scores)
199
+ scores = be.multiply(scores, inv_std, out=scores)
200
+
201
+ scores = be.maximum(scores, self._score_threshold, out=scores)
202
+ return scores, rotations, rotation_mapping, ssum
203
+
195
204
  @staticmethod
196
205
  def _invert_rmap(rotation_mapping: dict) -> dict:
197
206
  """
@@ -201,7 +210,12 @@ class MaxScoreOverRotations(AbstractAnalyzer):
201
210
  new_map, ndim = {}, None
202
211
  for k, v in rotation_mapping.items():
203
212
  nbytes = be.datatype_bytes(be._float_dtype)
204
- dtype = np.float32 if nbytes == 4 else np.float16
213
+ if nbytes == 8:
214
+ dtype = np.float64
215
+ elif nbytes == 4:
216
+ dtype = np.float32
217
+ else:
218
+ np.float16
205
219
  rmat = np.frombuffer(k, dtype=dtype)
206
220
  if ndim is None:
207
221
  ndim = int(np.sqrt(rmat.size))
@@ -451,7 +465,7 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
451
465
  Maximum accepted rotational deviation in degrees.
452
466
  positions : BackendArray
453
467
  Array of shape (n, d) with n seed point translations.
454
- positions : BackendArray
468
+ rotations : BackendArray
455
469
  Array of shape (n, d, d) with n seed point rotation matrices.
456
470
  reference : BackendArray
457
471
  Reference orientation of the template, wlog defaults to (0,0,1).
@@ -489,6 +503,7 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
489
503
  be.reshape(be.to_backend_array(reference), (-1,)), be._float_dtype
490
504
  )
491
505
  positions = be.astype(be.to_backend_array(positions), be._int_dtype)
506
+ rotations = be.astype(be.to_backend_array(rotations), be._float_dtype)
492
507
 
493
508
  ndim = positions.shape[1]
494
509
  rotate_mask = len(set(acceptance_radius)) != 1
@@ -515,7 +530,13 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
515
530
  )
516
531
 
517
532
  self._positions = positions[valid_positions]
518
- rotations = be.to_backend_array(rotations)[valid_positions]
533
+ rotations = rotations[valid_positions]
534
+
535
+ # Convert to pull matrix to remain consistent with rotation convention
536
+ rotations = be.concatenate(
537
+ [rotations[i].T[None] for i in range(rotations.shape[0])]
538
+ )
539
+
519
540
  ex = be.astype(be.to_backend_array((1, 0, 0)), be._float_dtype)
520
541
  ey = be.astype(be.to_backend_array((0, 1, 0)), be._float_dtype)
521
542
  ez = be.astype(be.to_backend_array((0, 0, 1)), be._float_dtype)
@@ -524,6 +545,15 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
524
545
  self._normals_y = (rotations @ ey[..., None])[..., 0]
525
546
  self._normals_z = (rotations @ ez[..., None])[..., 0]
526
547
 
548
+ # All scores will be rejected in this case. We should think about a
549
+ # unified interface for checking analyzer validity to skip such runs
550
+ if self._positions.shape[0] == 0:
551
+
552
+ def _get_score_mask(*args, **kwargs):
553
+ return 0
554
+
555
+ self._get_score_mask = _get_score_mask
556
+
527
557
  # Periodic wrapping could be avoided by padding the target
528
558
  shape = be.to_backend_array(self._shape)
529
559
  starts = be.subtract(self._positions, extend)
@@ -539,9 +569,7 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
539
569
  self._mask_shape = tuple(1 if i != 0 else -1 for i in range(1 + ndim))
540
570
 
541
571
  if rotate_mask:
542
- self._score_mask = be.zeros(
543
- (rotations.shape[0], *self._score_mask.shape), dtype=be._float_dtype
544
- )
572
+ self._score_mask = []
545
573
  for i in range(rotations.shape[0]):
546
574
  mask = create_mask(
547
575
  mask_type="ellipse",
@@ -550,9 +578,10 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
550
578
  center=tuple(extend for _ in range(ndim)),
551
579
  orientation=be.to_numpy_array(rotations[i]),
552
580
  )
553
- self._score_mask[i] = be.astype(
554
- be.to_backend_array(mask), be._float_dtype
581
+ self._score_mask.append(
582
+ be.astype(be.to_backend_array(mask), be._float_dtype)[None]
555
583
  )
584
+ self._score_mask = be.concatenate(self._score_mask)
556
585
 
557
586
  def __call__(
558
587
  self,
@@ -573,7 +602,7 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
573
602
  """
574
603
  Determine whether the angle between projection of reference w.r.t to
575
604
  a given rotation matrix and a set of rotations fall within the set
576
- cone_angle cutoff.
605
+ cone angle cutoff.
577
606
 
578
607
  Parameters
579
608
  ----------
@@ -585,7 +614,7 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
585
614
  BackerndArray
586
615
  Boolean mask of shape (n, )
587
616
  """
588
- template_rot = rotation_matrix @ self._reference
617
+ template_rot = rotation_matrix.T @ self._reference
589
618
 
590
619
  x = be.sum(be.multiply(self._normals_x, template_rot), axis=1)
591
620
  y = be.sum(be.multiply(self._normals_y, template_rot), axis=1)
@@ -596,10 +625,9 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
596
625
  def _get_score_mask(self, mask: BackendArray, scores: BackendArray, **kwargs):
597
626
  score_mask = be.zeros(scores.shape, scores.dtype)
598
627
 
599
- if be.sum(mask) == 0:
600
- return score_mask
628
+ # The indexing could be improved to avoid expanding the mask to
629
+ # the number of seed points
601
630
  mask = be.reshape(mask, self._mask_shape)
602
-
603
631
  score_mask = be.addat(score_mask, self._index_grid, self._score_mask * mask)
604
632
  return score_mask > 0
605
633
 
@@ -663,13 +691,16 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
663
691
  rotation_index = len(rotation_mapping)
664
692
  if self._inversion_mapping:
665
693
  rotation_mapping[rotation_index] = rotation_matrix
694
+ elif self._jax_mode:
695
+ rotation_index = kwargs.get("rotation_index", 0)
666
696
  else:
667
697
  rotation = be.tobytes(rotation_matrix)
668
698
  rotation_index = rotation_mapping.setdefault(rotation, rotation_index)
669
- max_score = be.max(scores, axis=self._aggregate_axis)
670
699
 
671
- update = prev_scores[rotation_index]
672
- update = be.maximum(max_score, update, out=update)
700
+ scores = be.max(scores, axis=self._aggregate_axis)
701
+ scores = be.maximum(scores, prev_scores[rotation_index])
702
+ prev_scores = be.at(prev_scores, rotation_index, scores)
703
+
673
704
  return prev_scores, rotations, rotation_mapping
674
705
 
675
706
  @classmethod
tme/analyzer/base.py CHANGED
@@ -73,6 +73,40 @@ class AbstractAnalyzer(ABC):
73
73
  Updated analyzer state incorporating the new data.
74
74
  """
75
75
 
76
+ @abstractmethod
77
+ def correct_background(self, state, mean=0, inv_std=1, **kwargs):
78
+ """
79
+ Applies flat-fielding correction to scores f as
80
+
81
+ .. math::
82
+
83
+ f' = (f - \\text{mean}) \\cdot \\text{inv_std},
84
+
85
+ transforming raw correlations to SNR-like scores.
86
+
87
+ Parameters
88
+ ----------
89
+ state : tuple
90
+ Current analyzer state as returned :py:meth:`AbstractAnalyzer.init_state`
91
+ or previous invocations of :py:meth:`AbstractAnalyzer.__call__`.
92
+ mean : BackendArray, optional
93
+ Background mean (or equivalent), defaults to 0.
94
+ inv_std : BackendArray, optional
95
+ Reciprocal background standard deviation (or equivalent), defaults to 1.
96
+
97
+ Notes
98
+ -----
99
+ This method should be called after all rotations have been processed
100
+ but before calling :py:meth:`result`. The correction helps distinguish genuine
101
+ template matches from systematic background artifacts that may arise from
102
+ template edges, interpolation artifacts, or structured noise in the target.
103
+
104
+ Returns
105
+ -------
106
+ tuple
107
+ Updated analyzer state incorporating the new data.
108
+ """
109
+
76
110
  @abstractmethod
77
111
  def result(self, state: Tuple, **kwargs) -> Tuple:
78
112
  """