pytme 0.3.1.post2__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.2.data/scripts/estimate_ram_usage.py +97 -0
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/match_template.py +213 -196
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/postprocess.py +40 -78
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/preprocess.py +4 -5
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/preprocessor_gui.py +49 -103
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/pytme_runner.py +46 -69
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/METADATA +3 -2
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/RECORD +68 -65
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +213 -196
- scripts/match_template_devel.py +1339 -0
- scripts/postprocess.py +40 -78
- scripts/preprocess.py +4 -5
- scripts/preprocessor_gui.py +49 -103
- scripts/pytme_runner.py +46 -69
- tests/preprocessing/test_compose.py +31 -30
- tests/preprocessing/test_frequency_filters.py +17 -32
- tests/preprocessing/test_preprocessor.py +0 -19
- tests/preprocessing/test_utils.py +13 -1
- tests/test_analyzer.py +2 -10
- tests/test_backends.py +47 -18
- tests/test_density.py +72 -13
- tests/test_extensions.py +1 -0
- tests/test_matching_cli.py +23 -9
- tests/test_matching_exhaustive.py +5 -5
- tests/test_matching_utils.py +3 -3
- tests/test_orientations.py +12 -0
- tests/test_rotations.py +13 -23
- tests/test_structure.py +1 -7
- tme/__version__.py +1 -1
- tme/analyzer/aggregation.py +47 -16
- tme/analyzer/base.py +34 -0
- tme/analyzer/peaks.py +26 -13
- tme/analyzer/proxy.py +14 -0
- tme/backends/_jax_utils.py +91 -68
- tme/backends/cupy_backend.py +6 -19
- tme/backends/jax_backend.py +103 -98
- tme/backends/matching_backend.py +0 -17
- tme/backends/mlx_backend.py +0 -29
- tme/backends/npfftw_backend.py +100 -97
- tme/backends/pytorch_backend.py +65 -78
- tme/cli.py +2 -2
- tme/density.py +44 -57
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/_utils.py +52 -24
- tme/filters/bandpass.py +99 -105
- tme/filters/compose.py +133 -39
- tme/filters/ctf.py +51 -102
- tme/filters/reconstruction.py +67 -122
- tme/filters/wedge.py +296 -325
- tme/filters/whitening.py +39 -75
- tme/mask.py +2 -2
- tme/matching_data.py +87 -15
- tme/matching_exhaustive.py +70 -120
- tme/matching_optimization.py +9 -63
- tme/matching_scores.py +261 -100
- tme/matching_utils.py +150 -91
- tme/memory.py +1 -0
- tme/orientations.py +17 -3
- tme/preprocessor.py +0 -239
- tme/rotations.py +102 -70
- tme/structure.py +601 -631
- tme/types.py +1 -0
- {pytme-0.3.1.post2.data → pytme-0.3.2.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/WHEEL +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/entry_points.txt +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dist-info}/top_level.txt +0 -0
tests/test_density.py
CHANGED
@@ -101,9 +101,10 @@ class TestDensity:
|
|
101
101
|
|
102
102
|
def test_from_file_baseline(self):
|
103
103
|
self.test_to_file(gzip=False)
|
104
|
-
density = Density.from_file(str(BASEPATH.joinpath("
|
105
|
-
assert
|
106
|
-
assert np.allclose(density.
|
104
|
+
density = Density.from_file(str(BASEPATH.joinpath("Raw/em_map.map")))
|
105
|
+
assert density.shape == (19, 14, 20)
|
106
|
+
assert np.allclose(density.origin, (-52.8, -10.56, -52.8), rtol=0.1)
|
107
|
+
assert np.allclose(density.sampling_rate, (5.28), rtol=0.1)
|
107
108
|
|
108
109
|
@pytest.mark.parametrize("extension", ("mrc", "em", "tiff", "h5"))
|
109
110
|
@pytest.mark.parametrize("gzip", (True, False))
|
@@ -315,7 +316,7 @@ class TestDensity:
|
|
315
316
|
"cutoff", [DEFAULT_DATA.min() - 1, 0, DEFAULT_DATA.max() - 0.1]
|
316
317
|
)
|
317
318
|
def test_centered(self, cutoff):
|
318
|
-
centered_density
|
319
|
+
centered_density = self.density.centered(cutoff=cutoff)
|
319
320
|
com = centered_density.center_of_mass(centered_density.data, 0)
|
320
321
|
|
321
322
|
difference = np.abs(
|
@@ -334,7 +335,7 @@ class TestDensity:
|
|
334
335
|
box = temp.minimum_enclosing_box(cutoff=0, use_geometric_center=True)
|
335
336
|
temp.adjust_box(box)
|
336
337
|
else:
|
337
|
-
temp
|
338
|
+
temp = temp.centered()
|
338
339
|
|
339
340
|
swaps = set(permutations([0, 1, 2]))
|
340
341
|
temp_matrix = np.eye(temp.data.ndim).astype(np.float32)
|
@@ -436,7 +437,7 @@ class TestDensity:
|
|
436
437
|
arr = np.linalg.norm(np.indices(shape) - position, axis=0)
|
437
438
|
arr = (arr <= radius).astype(np.float32)
|
438
439
|
|
439
|
-
center_of_mass = Density.center_of_mass(
|
440
|
+
center_of_mass = Density(arr).center_of_mass()
|
440
441
|
assert np.allclose(center, center_of_mass)
|
441
442
|
|
442
443
|
@pytest.mark.parametrize(
|
@@ -451,7 +452,7 @@ class TestDensity:
|
|
451
452
|
target[5:10, 15:22, 10:13] = 1
|
452
453
|
|
453
454
|
target = Density(target, sampling_rate=(1, 1, 1), origin=(0, 0, 0))
|
454
|
-
target
|
455
|
+
target = target.centered(cutoff=0)
|
455
456
|
|
456
457
|
template = target.copy()
|
457
458
|
|
@@ -474,21 +475,79 @@ class TestDensity:
|
|
474
475
|
assert np.allclose(np.linalg.inv(rotation), initial_rotation, atol=0.2)
|
475
476
|
|
476
477
|
def test_match_structure_to_density(self):
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
478
|
+
mask = create_mask(
|
479
|
+
mask_type="ellipse",
|
480
|
+
center=(20, 20, 20),
|
481
|
+
radius=(10, 5, 10),
|
482
|
+
shape=(50, 50, 50),
|
481
483
|
)
|
482
484
|
|
485
|
+
coordinates = np.array(np.where(mask > 0), dtype=np.float32).T
|
486
|
+
structure = Structure(
|
487
|
+
record_type=[
|
488
|
+
"ATOM",
|
489
|
+
]
|
490
|
+
* coordinates.shape[0],
|
491
|
+
atom_serial_number=list(range(coordinates.shape[0])),
|
492
|
+
atom_name=[
|
493
|
+
"C",
|
494
|
+
]
|
495
|
+
* coordinates.shape[0],
|
496
|
+
atom_coordinate=coordinates,
|
497
|
+
alternate_location_indicator=[
|
498
|
+
".",
|
499
|
+
]
|
500
|
+
* coordinates.shape[0],
|
501
|
+
residue_name=[
|
502
|
+
"GLY",
|
503
|
+
]
|
504
|
+
* coordinates.shape[0],
|
505
|
+
chain_identifier=[
|
506
|
+
"A",
|
507
|
+
]
|
508
|
+
* coordinates.shape[0],
|
509
|
+
residue_sequence_number=[
|
510
|
+
0,
|
511
|
+
]
|
512
|
+
* coordinates.shape[0],
|
513
|
+
code_for_residue_insertion=[
|
514
|
+
"?",
|
515
|
+
]
|
516
|
+
* coordinates.shape[0],
|
517
|
+
occupancy=[
|
518
|
+
0,
|
519
|
+
]
|
520
|
+
* coordinates.shape[0],
|
521
|
+
temperature_factor=[
|
522
|
+
0,
|
523
|
+
]
|
524
|
+
* coordinates.shape[0],
|
525
|
+
segment_identifier=[
|
526
|
+
"1",
|
527
|
+
]
|
528
|
+
* coordinates.shape[0],
|
529
|
+
element_symbol=[
|
530
|
+
"C",
|
531
|
+
]
|
532
|
+
* coordinates.shape[0],
|
533
|
+
charge=[
|
534
|
+
"?",
|
535
|
+
]
|
536
|
+
* coordinates.shape[0],
|
537
|
+
metadata={},
|
538
|
+
)
|
539
|
+
density = Density(mask, sampling_rate=1.0)
|
540
|
+
density = density.resample(density.sampling_rate * 2)
|
541
|
+
|
483
542
|
initial_translation = np.array([-1, 0, 5])
|
484
543
|
initial_rotation = euler_to_rotationmatrix((-10, 2, 5))
|
485
|
-
structure.rigid_transform(
|
544
|
+
structure_mod = structure.rigid_transform(
|
486
545
|
translation=initial_translation, rotation_matrix=initial_rotation
|
487
546
|
)
|
488
547
|
np.random.seed(12)
|
489
548
|
ret = Density.match_structure_to_density(
|
490
549
|
target=density,
|
491
|
-
template=
|
550
|
+
template=structure_mod,
|
492
551
|
cutoff_target=0,
|
493
552
|
scoring_method="CrossCorrelation",
|
494
553
|
maxiter=10,
|
tests/test_extensions.py
CHANGED
@@ -53,6 +53,7 @@ class TestExtensions:
|
|
53
53
|
@pytest.mark.parametrize("min_distance", [0, 5, 10])
|
54
54
|
def test_find_candidate_indices(self, dimension, dtype, min_distance):
|
55
55
|
coordinates = COORDINATES[dimension].astype(dtype)
|
56
|
+
print(coordinates.shape)
|
56
57
|
|
57
58
|
min_distance = np.array([min_distance]).astype(dtype)[0]
|
58
59
|
|
tests/test_matching_cli.py
CHANGED
@@ -10,7 +10,7 @@ from tme import Density, Orientations
|
|
10
10
|
from tme.backends import backend as be
|
11
11
|
|
12
12
|
np.random.seed(42)
|
13
|
-
available_backends = (x for x in be.available_backends() if x != "mlx")
|
13
|
+
available_backends = tuple(x for x in be.available_backends() if x != "mlx")
|
14
14
|
|
15
15
|
|
16
16
|
def argdict_to_command(input_args, executable: str):
|
@@ -29,7 +29,7 @@ def argdict_to_command(input_args, executable: str):
|
|
29
29
|
return " ".join(ret)
|
30
30
|
|
31
31
|
|
32
|
-
class
|
32
|
+
class TestSetup:
|
33
33
|
@classmethod
|
34
34
|
def setup_class(cls):
|
35
35
|
target = np.random.rand(20, 20, 20)
|
@@ -98,6 +98,7 @@ class TestMatchTemplate:
|
|
98
98
|
use_target_mask: bool = False,
|
99
99
|
backend: str = "numpyfftw",
|
100
100
|
test_rejection_sampling: bool = False,
|
101
|
+
background_correction: str = None,
|
101
102
|
):
|
102
103
|
output_path = tempfile.NamedTemporaryFile(delete=False, suffix="pickle").name
|
103
104
|
|
@@ -120,6 +121,9 @@ class TestMatchTemplate:
|
|
120
121
|
if test_rejection_sampling:
|
121
122
|
argdict["--orientations"] = self.orientations_path
|
122
123
|
|
124
|
+
if background_correction is not None:
|
125
|
+
argdict["--background-correction"] = background_correction
|
126
|
+
|
123
127
|
if test_filter:
|
124
128
|
argdict["--lowpass"] = 30
|
125
129
|
argdict["--defocus"] = 3000
|
@@ -136,11 +140,15 @@ class TestMatchTemplate:
|
|
136
140
|
assert ret.returncode == 0
|
137
141
|
return output_path
|
138
142
|
|
143
|
+
|
144
|
+
class TestMatchTemplate(TestSetup):
|
145
|
+
|
139
146
|
@pytest.mark.parametrize("backend", available_backends)
|
140
147
|
@pytest.mark.parametrize("call_peaks", (False, True))
|
141
148
|
@pytest.mark.parametrize("use_template_mask", (False, True))
|
142
149
|
@pytest.mark.parametrize("test_filter", (False, True))
|
143
150
|
@pytest.mark.parametrize("test_rejection_sampling", (False, True))
|
151
|
+
@pytest.mark.parametrize("background_correction", (None, "phase-scrambling"))
|
144
152
|
def test_match_template(
|
145
153
|
self,
|
146
154
|
backend: bool,
|
@@ -148,8 +156,14 @@ class TestMatchTemplate:
|
|
148
156
|
use_template_mask: bool,
|
149
157
|
test_filter: bool,
|
150
158
|
test_rejection_sampling: bool,
|
159
|
+
background_correction: str,
|
151
160
|
):
|
152
|
-
|
161
|
+
# Jax does not support peak calling yet
|
162
|
+
if backend == "jax" and call_peaks:
|
163
|
+
return None
|
164
|
+
|
165
|
+
# These use different analyzers raising an error in the interface
|
166
|
+
if call_peaks and test_rejection_sampling:
|
153
167
|
return None
|
154
168
|
|
155
169
|
self.run_matching(
|
@@ -163,10 +177,11 @@ class TestMatchTemplate:
|
|
163
177
|
template_mask_path=self.template_mask_path,
|
164
178
|
target_mask_path=self.target_mask_path,
|
165
179
|
test_rejection_sampling=test_rejection_sampling,
|
180
|
+
background_correction=background_correction,
|
166
181
|
)
|
167
182
|
|
168
183
|
|
169
|
-
class TestPostprocessing(
|
184
|
+
class TestPostprocessing(TestSetup):
|
170
185
|
@classmethod
|
171
186
|
def setup_class(cls):
|
172
187
|
super().setup_class()
|
@@ -256,7 +271,6 @@ class TestPostprocessing(TestMatchTemplate):
|
|
256
271
|
}
|
257
272
|
cmd = argdict_to_command(argdict, executable="postprocess.py")
|
258
273
|
ret = subprocess.run(cmd, capture_output=True, shell=True)
|
259
|
-
print(ret)
|
260
274
|
|
261
275
|
match output_format:
|
262
276
|
case "orientations":
|
@@ -289,14 +303,14 @@ class TestPostprocessing(TestMatchTemplate):
|
|
289
303
|
assert ret.returncode == 0
|
290
304
|
|
291
305
|
|
292
|
-
class TestEstimateMemoryUsage(
|
306
|
+
class TestEstimateMemoryUsage(TestSetup):
|
293
307
|
@classmethod
|
294
308
|
def setup_class(cls):
|
295
309
|
super().setup_class()
|
296
310
|
|
297
311
|
@pytest.mark.parametrize("ncores", (1, 4, 8))
|
298
312
|
@pytest.mark.parametrize("pad_edges", (False, True))
|
299
|
-
def
|
313
|
+
def test_estimation_cli(self, ncores, pad_edges):
|
300
314
|
|
301
315
|
argdict = {
|
302
316
|
"-m": self.target_path,
|
@@ -311,7 +325,7 @@ class TestEstimateMemoryUsage(TestMatchTemplate):
|
|
311
325
|
assert ret.returncode == 0
|
312
326
|
|
313
327
|
|
314
|
-
class TestPreprocess(
|
328
|
+
class TestPreprocess(TestSetup):
|
315
329
|
@classmethod
|
316
330
|
def setup_class(cls):
|
317
331
|
super().setup_class()
|
@@ -319,7 +333,7 @@ class TestPreprocess(TestMatchTemplate):
|
|
319
333
|
@pytest.mark.parametrize("backend", available_backends)
|
320
334
|
@pytest.mark.parametrize("align_axis", (False, True))
|
321
335
|
@pytest.mark.parametrize("invert_contrast", (False, True))
|
322
|
-
def
|
336
|
+
def test_preprocess_cli(self, backend, align_axis, invert_contrast):
|
323
337
|
|
324
338
|
argdict = {
|
325
339
|
"-m": self.target_path,
|
@@ -1,5 +1,5 @@
|
|
1
|
-
import numpy as np
|
2
1
|
import pytest
|
2
|
+
import numpy as np
|
3
3
|
|
4
4
|
from scipy.ndimage import laplace
|
5
5
|
|
@@ -7,7 +7,7 @@ from tme.matching_data import MatchingData
|
|
7
7
|
from tme.memory import MATCHING_MEMORY_REGISTRY
|
8
8
|
from tme.analyzer import MaxScoreOverRotations, PeakCallerSort
|
9
9
|
from tme.matching_exhaustive import (
|
10
|
-
|
10
|
+
match_exhaustive,
|
11
11
|
MATCHING_EXHAUSTIVE_REGISTER,
|
12
12
|
register_matching_exhaustive,
|
13
13
|
)
|
@@ -35,11 +35,11 @@ class TestMatchExhaustive:
|
|
35
35
|
self.coordinates_weights = None
|
36
36
|
self.rotations = None
|
37
37
|
|
38
|
-
@pytest.mark.parametrize("evaluate_peak", (True,))
|
38
|
+
@pytest.mark.parametrize("evaluate_peak", (True, False))
|
39
39
|
@pytest.mark.parametrize("score", tuple(MATCHING_EXHAUSTIVE_REGISTER.keys()))
|
40
40
|
@pytest.mark.parametrize("job_schedule", ((2, 1),))
|
41
41
|
@pytest.mark.parametrize("pad_edge", (False, True))
|
42
|
-
def
|
42
|
+
def test_match_exhaustive(
|
43
43
|
self,
|
44
44
|
score: str,
|
45
45
|
job_schedule: int,
|
@@ -64,7 +64,7 @@ class TestMatchExhaustive:
|
|
64
64
|
if evaluate_peak:
|
65
65
|
callback_class = PeakCallerSort
|
66
66
|
|
67
|
-
ret =
|
67
|
+
ret = match_exhaustive(
|
68
68
|
matching_data=matching_data,
|
69
69
|
matching_setup=setup,
|
70
70
|
matching_score=process,
|
tests/test_matching_utils.py
CHANGED
@@ -15,7 +15,7 @@ from tme.matching_utils import (
|
|
15
15
|
apply_convolution_mode,
|
16
16
|
write_pickle,
|
17
17
|
load_pickle,
|
18
|
-
|
18
|
+
_standardize_safe,
|
19
19
|
)
|
20
20
|
|
21
21
|
BASEPATH = files("tests.data")
|
@@ -127,12 +127,12 @@ class TestMatchingUtils:
|
|
127
127
|
loaded_data = load_pickle(filename)
|
128
128
|
assert np.array_equal(loaded_data, data)
|
129
129
|
|
130
|
-
def
|
130
|
+
def test_standardize_safe(self):
|
131
131
|
template = be.random.random((10, 10)).astype(be.float32)
|
132
132
|
mask = be.ones_like(template)
|
133
133
|
n_observations = 100.0
|
134
134
|
|
135
|
-
result =
|
135
|
+
result = _standardize_safe(template, mask, n_observations)
|
136
136
|
assert result.shape == template.shape
|
137
137
|
assert result.dtype == template.dtype
|
138
138
|
assert np.allclose(result.mean(), 0, atol=0.1)
|
tests/test_orientations.py
CHANGED
@@ -95,6 +95,18 @@ class TestDensity:
|
|
95
95
|
self.orientations.rotations, orientations_new.rotations, atol=1e-3
|
96
96
|
)
|
97
97
|
|
98
|
+
@pytest.mark.parametrize("input_format", ("text", "star", "tbl"))
|
99
|
+
@pytest.mark.parametrize("output_format", ("text", "star", "tbl"))
|
100
|
+
def test_file_format_io(self, input_format: str, output_format: str):
|
101
|
+
_, output_file = mkstemp(suffix=f".{input_format}")
|
102
|
+
_, output_file2 = mkstemp(suffix=f".{output_format}")
|
103
|
+
|
104
|
+
self.orientations.to_file(output_file)
|
105
|
+
orientations_new = Orientations.from_file(output_file)
|
106
|
+
orientations_new.to_file(output_file2)
|
107
|
+
|
108
|
+
assert True
|
109
|
+
|
98
110
|
@pytest.mark.parametrize("drop_oob", (True, False))
|
99
111
|
@pytest.mark.parametrize("shape", (10, 40, 80))
|
100
112
|
@pytest.mark.parametrize("odd", (True, False))
|
tests/test_rotations.py
CHANGED
@@ -35,24 +35,19 @@ class TestRotations:
|
|
35
35
|
)
|
36
36
|
|
37
37
|
@pytest.mark.parametrize(
|
38
|
-
"initial_vector, target_vector
|
38
|
+
"initial_vector, target_vector",
|
39
39
|
[
|
40
|
-
([1, 0, 0], [0, 1, 0]
|
41
|
-
([0, 1, 0], [0, 0, 1]
|
42
|
-
([1, 1, 1], [1, 0, 0]
|
40
|
+
([1, 0, 0], [0, 1, 0]),
|
41
|
+
([0, 1, 0], [0, 0, 1]),
|
42
|
+
([1, 1, 1], [1, 0, 0]),
|
43
43
|
],
|
44
44
|
)
|
45
|
-
def test_align_vectors(self, initial_vector, target_vector
|
46
|
-
result = align_vectors(initial_vector, target_vector
|
45
|
+
def test_align_vectors(self, initial_vector, target_vector):
|
46
|
+
result = align_vectors(initial_vector, target_vector)
|
47
47
|
|
48
48
|
assert isinstance(result, np.ndarray)
|
49
|
-
|
50
|
-
|
51
|
-
assert np.allclose(np.dot(result, result.T), np.eye(3), atol=1e-6)
|
52
|
-
else:
|
53
|
-
assert len(result) == 3
|
54
|
-
result = Rotation.from_euler(convention, result, degrees=True).as_matrix()
|
55
|
-
assert np.allclose(np.dot(result, result.T), np.eye(3), atol=1e-6)
|
49
|
+
assert result.shape == (3, 3)
|
50
|
+
assert np.allclose(np.dot(result, result.T), np.eye(3), atol=1e-6)
|
56
51
|
|
57
52
|
rotated = np.dot(Rotation.from_matrix(result).as_matrix(), initial_vector)
|
58
53
|
assert np.allclose(
|
@@ -62,11 +57,11 @@ class TestRotations:
|
|
62
57
|
)
|
63
58
|
|
64
59
|
@pytest.mark.parametrize(
|
65
|
-
"cone_angle, cone_sampling, axis_angle, axis_sampling, vector, n_symmetry
|
60
|
+
"cone_angle, cone_sampling, axis_angle, axis_sampling, vector, n_symmetry",
|
66
61
|
[
|
67
|
-
(30, 5, 360, None, (1, 0, 0), 1
|
68
|
-
(45, 10, 180, 15, (0, 1, 0), 2
|
69
|
-
(60, 15, 90, 30, (0, 0, 1), 4
|
62
|
+
(30, 5, 360, None, (1, 0, 0), 1),
|
63
|
+
(45, 10, 180, 15, (0, 1, 0), 2),
|
64
|
+
(60, 15, 90, 30, (0, 0, 1), 4),
|
70
65
|
],
|
71
66
|
)
|
72
67
|
def test_get_cone_rotations(
|
@@ -77,7 +72,6 @@ class TestRotations:
|
|
77
72
|
axis_sampling,
|
78
73
|
vector,
|
79
74
|
n_symmetry,
|
80
|
-
convention,
|
81
75
|
):
|
82
76
|
result = get_cone_rotations(
|
83
77
|
cone_angle=cone_angle,
|
@@ -86,14 +80,10 @@ class TestRotations:
|
|
86
80
|
axis_sampling=axis_sampling,
|
87
81
|
reference=vector,
|
88
82
|
n_symmetry=n_symmetry,
|
89
|
-
seq=convention,
|
90
83
|
)
|
91
84
|
|
92
85
|
assert isinstance(result, np.ndarray)
|
93
|
-
|
94
|
-
assert result.shape[1:] == (3, 3)
|
95
|
-
else:
|
96
|
-
assert result.shape[1] == 3
|
86
|
+
assert result.shape[1:] == (3, 3)
|
97
87
|
|
98
88
|
def test_euler_conversion(self):
|
99
89
|
rotation_matrix_initial = np.array(
|
tests/test_structure.py
CHANGED
@@ -6,7 +6,6 @@ import pytest
|
|
6
6
|
import numpy as np
|
7
7
|
|
8
8
|
from tme import Structure
|
9
|
-
from tme.matching_utils import minimum_enclosing_box
|
10
9
|
from tme.rotations import euler_to_rotationmatrix
|
11
10
|
|
12
11
|
|
@@ -196,11 +195,6 @@ class TestStructure:
|
|
196
195
|
assert center_of_mass.shape[0] == self.structure.atom_coordinate.shape[1]
|
197
196
|
assert np.allclose(center_of_mass, [-0.89391639, 29.94908928, -2.64736741])
|
198
197
|
|
199
|
-
def test_centered(self):
|
200
|
-
ret, translation = self.structure.centered()
|
201
|
-
box = minimum_enclosing_box(coordinates=self.structure.atom_coordinate.T)
|
202
|
-
assert np.allclose(ret.center_of_mass(), np.divide(box, 2), atol=1)
|
203
|
-
|
204
198
|
def test__get_atom_weights_error(self):
|
205
199
|
with pytest.raises(NotImplementedError):
|
206
200
|
self.structure._get_atom_weights(
|
@@ -242,6 +236,6 @@ class TestStructure:
|
|
242
236
|
assert final_rmsd <= 0.1
|
243
237
|
|
244
238
|
aligned, final_rmsd = Structure.align_structures(
|
245
|
-
self.structure, structure_transform
|
239
|
+
self.structure, structure_transform
|
246
240
|
)
|
247
241
|
assert final_rmsd <= 1
|
tme/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.3.
|
1
|
+
__version__ = "0.3.2"
|
tme/analyzer/aggregation.py
CHANGED
@@ -192,6 +192,15 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
192
192
|
)
|
193
193
|
return scores, rotations, rotation_mapping, ssum
|
194
194
|
|
195
|
+
def correct_background(self, state, mean=0, inv_std=1, **kwargs):
|
196
|
+
scores, rotations, rotation_mapping, ssum = state
|
197
|
+
|
198
|
+
scores = be.subtract(scores, mean, out=scores)
|
199
|
+
scores = be.multiply(scores, inv_std, out=scores)
|
200
|
+
|
201
|
+
scores = be.maximum(scores, self._score_threshold, out=scores)
|
202
|
+
return scores, rotations, rotation_mapping, ssum
|
203
|
+
|
195
204
|
@staticmethod
|
196
205
|
def _invert_rmap(rotation_mapping: dict) -> dict:
|
197
206
|
"""
|
@@ -201,7 +210,12 @@ class MaxScoreOverRotations(AbstractAnalyzer):
|
|
201
210
|
new_map, ndim = {}, None
|
202
211
|
for k, v in rotation_mapping.items():
|
203
212
|
nbytes = be.datatype_bytes(be._float_dtype)
|
204
|
-
|
213
|
+
if nbytes == 8:
|
214
|
+
dtype = np.float64
|
215
|
+
elif nbytes == 4:
|
216
|
+
dtype = np.float32
|
217
|
+
else:
|
218
|
+
np.float16
|
205
219
|
rmat = np.frombuffer(k, dtype=dtype)
|
206
220
|
if ndim is None:
|
207
221
|
ndim = int(np.sqrt(rmat.size))
|
@@ -451,7 +465,7 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
|
|
451
465
|
Maximum accepted rotational deviation in degrees.
|
452
466
|
positions : BackendArray
|
453
467
|
Array of shape (n, d) with n seed point translations.
|
454
|
-
|
468
|
+
rotations : BackendArray
|
455
469
|
Array of shape (n, d, d) with n seed point rotation matrices.
|
456
470
|
reference : BackendArray
|
457
471
|
Reference orientation of the template, wlog defaults to (0,0,1).
|
@@ -489,6 +503,7 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
|
|
489
503
|
be.reshape(be.to_backend_array(reference), (-1,)), be._float_dtype
|
490
504
|
)
|
491
505
|
positions = be.astype(be.to_backend_array(positions), be._int_dtype)
|
506
|
+
rotations = be.astype(be.to_backend_array(rotations), be._float_dtype)
|
492
507
|
|
493
508
|
ndim = positions.shape[1]
|
494
509
|
rotate_mask = len(set(acceptance_radius)) != 1
|
@@ -515,7 +530,13 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
|
|
515
530
|
)
|
516
531
|
|
517
532
|
self._positions = positions[valid_positions]
|
518
|
-
rotations =
|
533
|
+
rotations = rotations[valid_positions]
|
534
|
+
|
535
|
+
# Convert to pull matrix to remain consistent with rotation convention
|
536
|
+
rotations = be.concatenate(
|
537
|
+
[rotations[i].T[None] for i in range(rotations.shape[0])]
|
538
|
+
)
|
539
|
+
|
519
540
|
ex = be.astype(be.to_backend_array((1, 0, 0)), be._float_dtype)
|
520
541
|
ey = be.astype(be.to_backend_array((0, 1, 0)), be._float_dtype)
|
521
542
|
ez = be.astype(be.to_backend_array((0, 0, 1)), be._float_dtype)
|
@@ -524,6 +545,15 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
|
|
524
545
|
self._normals_y = (rotations @ ey[..., None])[..., 0]
|
525
546
|
self._normals_z = (rotations @ ez[..., None])[..., 0]
|
526
547
|
|
548
|
+
# All scores will be rejected in this case. We should think about a
|
549
|
+
# unified interface for checking analyzer validity to skip such runs
|
550
|
+
if self._positions.shape[0] == 0:
|
551
|
+
|
552
|
+
def _get_score_mask(*args, **kwargs):
|
553
|
+
return 0
|
554
|
+
|
555
|
+
self._get_score_mask = _get_score_mask
|
556
|
+
|
527
557
|
# Periodic wrapping could be avoided by padding the target
|
528
558
|
shape = be.to_backend_array(self._shape)
|
529
559
|
starts = be.subtract(self._positions, extend)
|
@@ -539,9 +569,7 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
|
|
539
569
|
self._mask_shape = tuple(1 if i != 0 else -1 for i in range(1 + ndim))
|
540
570
|
|
541
571
|
if rotate_mask:
|
542
|
-
self._score_mask =
|
543
|
-
(rotations.shape[0], *self._score_mask.shape), dtype=be._float_dtype
|
544
|
-
)
|
572
|
+
self._score_mask = []
|
545
573
|
for i in range(rotations.shape[0]):
|
546
574
|
mask = create_mask(
|
547
575
|
mask_type="ellipse",
|
@@ -550,9 +578,10 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
|
|
550
578
|
center=tuple(extend for _ in range(ndim)),
|
551
579
|
orientation=be.to_numpy_array(rotations[i]),
|
552
580
|
)
|
553
|
-
self._score_mask
|
554
|
-
be.to_backend_array(mask), be._float_dtype
|
581
|
+
self._score_mask.append(
|
582
|
+
be.astype(be.to_backend_array(mask), be._float_dtype)[None]
|
555
583
|
)
|
584
|
+
self._score_mask = be.concatenate(self._score_mask)
|
556
585
|
|
557
586
|
def __call__(
|
558
587
|
self,
|
@@ -573,7 +602,7 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
|
|
573
602
|
"""
|
574
603
|
Determine whether the angle between projection of reference w.r.t to
|
575
604
|
a given rotation matrix and a set of rotations fall within the set
|
576
|
-
|
605
|
+
cone angle cutoff.
|
577
606
|
|
578
607
|
Parameters
|
579
608
|
----------
|
@@ -585,7 +614,7 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
|
|
585
614
|
BackerndArray
|
586
615
|
Boolean mask of shape (n, )
|
587
616
|
"""
|
588
|
-
template_rot = rotation_matrix @ self._reference
|
617
|
+
template_rot = rotation_matrix.T @ self._reference
|
589
618
|
|
590
619
|
x = be.sum(be.multiply(self._normals_x, template_rot), axis=1)
|
591
620
|
y = be.sum(be.multiply(self._normals_y, template_rot), axis=1)
|
@@ -596,10 +625,9 @@ class MaxScoreOverRotationsConstrained(MaxScoreOverRotations):
|
|
596
625
|
def _get_score_mask(self, mask: BackendArray, scores: BackendArray, **kwargs):
|
597
626
|
score_mask = be.zeros(scores.shape, scores.dtype)
|
598
627
|
|
599
|
-
|
600
|
-
|
628
|
+
# The indexing could be improved to avoid expanding the mask to
|
629
|
+
# the number of seed points
|
601
630
|
mask = be.reshape(mask, self._mask_shape)
|
602
|
-
|
603
631
|
score_mask = be.addat(score_mask, self._index_grid, self._score_mask * mask)
|
604
632
|
return score_mask > 0
|
605
633
|
|
@@ -663,13 +691,16 @@ class MaxScoreOverTranslations(MaxScoreOverRotations):
|
|
663
691
|
rotation_index = len(rotation_mapping)
|
664
692
|
if self._inversion_mapping:
|
665
693
|
rotation_mapping[rotation_index] = rotation_matrix
|
694
|
+
elif self._jax_mode:
|
695
|
+
rotation_index = kwargs.get("rotation_index", 0)
|
666
696
|
else:
|
667
697
|
rotation = be.tobytes(rotation_matrix)
|
668
698
|
rotation_index = rotation_mapping.setdefault(rotation, rotation_index)
|
669
|
-
max_score = be.max(scores, axis=self._aggregate_axis)
|
670
699
|
|
671
|
-
|
672
|
-
|
700
|
+
scores = be.max(scores, axis=self._aggregate_axis)
|
701
|
+
scores = be.maximum(scores, prev_scores[rotation_index])
|
702
|
+
prev_scores = be.at(prev_scores, rotation_index, scores)
|
703
|
+
|
673
704
|
return prev_scores, rotations, rotation_mapping
|
674
705
|
|
675
706
|
@classmethod
|
tme/analyzer/base.py
CHANGED
@@ -73,6 +73,40 @@ class AbstractAnalyzer(ABC):
|
|
73
73
|
Updated analyzer state incorporating the new data.
|
74
74
|
"""
|
75
75
|
|
76
|
+
@abstractmethod
|
77
|
+
def correct_background(self, state, mean=0, inv_std=1, **kwargs):
|
78
|
+
"""
|
79
|
+
Applies flat-fielding correction to scores f as
|
80
|
+
|
81
|
+
.. math::
|
82
|
+
|
83
|
+
f' = (f - \\text{mean}) \\cdot \\text{inv_std},
|
84
|
+
|
85
|
+
transforming raw correlations to SNR-like scores.
|
86
|
+
|
87
|
+
Parameters
|
88
|
+
----------
|
89
|
+
state : tuple
|
90
|
+
Current analyzer state as returned :py:meth:`AbstractAnalyzer.init_state`
|
91
|
+
or previous invocations of :py:meth:`AbstractAnalyzer.__call__`.
|
92
|
+
mean : BackendArray, optional
|
93
|
+
Background mean (or equivalent), defaults to 0.
|
94
|
+
inv_std : BackendArray, optional
|
95
|
+
Reciprocal background standard deviation (or equivalent), defaults to 1.
|
96
|
+
|
97
|
+
Notes
|
98
|
+
-----
|
99
|
+
This method should be called after all rotations have been processed
|
100
|
+
but before calling :py:meth:`result`. The correction helps distinguish genuine
|
101
|
+
template matches from systematic background artifacts that may arise from
|
102
|
+
template edges, interpolation artifacts, or structured noise in the target.
|
103
|
+
|
104
|
+
Returns
|
105
|
+
-------
|
106
|
+
tuple
|
107
|
+
Updated analyzer state incorporating the new data.
|
108
|
+
"""
|
109
|
+
|
76
110
|
@abstractmethod
|
77
111
|
def result(self, state: Tuple, **kwargs) -> Tuple:
|
78
112
|
"""
|