pytme 0.3b0__cp311-cp311-macosx_15_0_arm64.whl → 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/estimate_memory_usage.py +1 -5
- {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/match_template.py +163 -201
- {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/postprocess.py +48 -39
- {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/preprocess.py +10 -23
- {pytme-0.3b0.data → pytme-0.3b0.post1.data}/scripts/preprocessor_gui.py +3 -4
- pytme-0.3b0.post1.data/scripts/pytme_runner.py +769 -0
- {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/METADATA +14 -14
- {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/RECORD +54 -50
- {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/entry_points.txt +1 -0
- pytme-0.3b0.post1.dist-info/licenses/LICENSE +339 -0
- scripts/estimate_memory_usage.py +1 -5
- scripts/eval.py +93 -0
- scripts/match_template.py +163 -201
- scripts/match_template_filters.py +1200 -0
- scripts/postprocess.py +48 -39
- scripts/preprocess.py +10 -23
- scripts/preprocessor_gui.py +3 -4
- scripts/pytme_runner.py +769 -0
- scripts/refine_matches.py +0 -1
- tests/preprocessing/test_frequency_filters.py +19 -10
- tests/test_analyzer.py +122 -122
- tests/test_backends.py +1 -0
- tests/test_matching_cli.py +30 -30
- tests/test_matching_data.py +5 -5
- tests/test_matching_utils.py +1 -1
- tme/__version__.py +1 -1
- tme/analyzer/__init__.py +1 -1
- tme/analyzer/_utils.py +1 -4
- tme/analyzer/aggregation.py +15 -6
- tme/analyzer/base.py +25 -36
- tme/analyzer/peaks.py +39 -113
- tme/analyzer/proxy.py +1 -0
- tme/backends/_jax_utils.py +16 -15
- tme/backends/cupy_backend.py +9 -13
- tme/backends/jax_backend.py +19 -16
- tme/backends/npfftw_backend.py +27 -25
- tme/backends/pytorch_backend.py +4 -0
- tme/density.py +5 -4
- tme/filters/__init__.py +2 -2
- tme/filters/_utils.py +32 -7
- tme/filters/bandpass.py +225 -186
- tme/filters/ctf.py +117 -67
- tme/filters/reconstruction.py +38 -9
- tme/filters/wedge.py +88 -105
- tme/filters/whitening.py +1 -6
- tme/matching_data.py +24 -36
- tme/matching_exhaustive.py +14 -11
- tme/matching_scores.py +21 -12
- tme/matching_utils.py +13 -6
- tme/orientations.py +13 -3
- tme/parser.py +109 -29
- tme/preprocessor.py +2 -2
- pytme-0.3b0.dist-info/licenses/LICENSE +0 -153
- {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/WHEEL +0 -0
- {pytme-0.3b0.dist-info → pytme-0.3b0.post1.dist-info}/top_level.txt +0 -0
tme/filters/wedge.py
CHANGED
@@ -7,8 +7,8 @@ Copyright (c) 2024 European Molecular Biology Laboratory
|
|
7
7
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
8
8
|
"""
|
9
9
|
|
10
|
-
import warnings
|
11
10
|
from typing import Tuple, Dict
|
11
|
+
from dataclasses import dataclass
|
12
12
|
|
13
13
|
import numpy as np
|
14
14
|
|
@@ -16,8 +16,8 @@ from ..types import NDArray
|
|
16
16
|
from ..backends import backend as be
|
17
17
|
from .compose import ComposableFilter
|
18
18
|
from ..matching_utils import centered
|
19
|
-
from ..parser import XMLParser, StarParser
|
20
19
|
from ..rotations import euler_to_rotationmatrix
|
20
|
+
from ..parser import XMLParser, StarParser, MDOCParser
|
21
21
|
from ._utils import (
|
22
22
|
centered_grid,
|
23
23
|
frequency_grid_at_angle,
|
@@ -31,56 +31,28 @@ from ._utils import (
|
|
31
31
|
__all__ = ["Wedge", "WedgeReconstructed"]
|
32
32
|
|
33
33
|
|
34
|
+
@dataclass
|
34
35
|
class Wedge(ComposableFilter):
|
35
36
|
"""
|
36
37
|
Generate wedge mask for tomographic data.
|
37
|
-
|
38
|
-
Parameters
|
39
|
-
----------
|
40
|
-
shape : tuple of int
|
41
|
-
The shape of the reconstruction volume.
|
42
|
-
tilt_axis : int
|
43
|
-
Axis the plane is tilted over, defaults to 0 (x).
|
44
|
-
opening_axis : int
|
45
|
-
The projection axis, defaults to 2 (z).
|
46
|
-
angles : tuple of float
|
47
|
-
The tilt angles.
|
48
|
-
weights : tuple of float
|
49
|
-
The weights corresponding to each tilt angle.
|
50
|
-
weight_type : str, optional
|
51
|
-
The type of weighting to apply, defaults to None.
|
52
|
-
frequency_cutoff : float, optional
|
53
|
-
Frequency cutoff for created mask. Nyquist 0.5 by default.
|
54
|
-
|
55
|
-
Returns
|
56
|
-
-------
|
57
|
-
dict
|
58
|
-
data: BackendArray
|
59
|
-
The filter mask.
|
60
|
-
shape: tuple of ints
|
61
|
-
The requested filter shape
|
62
|
-
return_real_fourier: bool
|
63
|
-
Whether data is compliant with rfftn.
|
64
|
-
is_multiplicative_filter: bool
|
65
|
-
Whether the filter is multiplicative in Fourier space.
|
66
38
|
"""
|
67
39
|
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
40
|
+
#: The shape of the reconstruction volume.
|
41
|
+
shape: Tuple[int] = None
|
42
|
+
#: The tilt angles.
|
43
|
+
angles: Tuple[float] = None
|
44
|
+
#: The weights corresponding to each tilt angle.
|
45
|
+
weights: Tuple[float] = None
|
46
|
+
#: Axis the plane is tilted over, defaults to 0 (x).
|
47
|
+
tilt_axis: int = 0
|
48
|
+
#: The projection axis, defaults to 2 (z).
|
49
|
+
opening_axis: int = 2
|
50
|
+
#: The type of weighting to apply, defaults to None.
|
51
|
+
weight_type: str = None
|
52
|
+
#: Frequency cutoff for created mask. Nyquist 0.5 by default.
|
53
|
+
frequency_cutoff: float = 0.5
|
54
|
+
#: The sampling rate, defaults to 1 Ångstrom / voxel.
|
55
|
+
sampling_rate: Tuple[float] = 1
|
84
56
|
|
85
57
|
@classmethod
|
86
58
|
def from_file(cls, filename: str) -> "Wedge":
|
@@ -93,6 +65,8 @@ class Wedge(ComposableFilter):
|
|
93
65
|
+-------+---------------------------------------------------------+
|
94
66
|
| .xml | WARP/M XML file |
|
95
67
|
+-------+---------------------------------------------------------+
|
68
|
+
| .mdoc | SerialEM file |
|
69
|
+
+-------+---------------------------------------------------------+
|
96
70
|
| .* | Tab-separated file with optional column names |
|
97
71
|
+-------+---------------------------------------------------------+
|
98
72
|
|
@@ -111,6 +85,8 @@ class Wedge(ComposableFilter):
|
|
111
85
|
func = _from_xml
|
112
86
|
elif filename.lower().endswith("star"):
|
113
87
|
func = _from_star
|
88
|
+
elif filename.lower().endswith("mdoc"):
|
89
|
+
func = _from_mdoc
|
114
90
|
|
115
91
|
data = func(filename)
|
116
92
|
angles, weights = data.get("angles", None), data.get("weights", None)
|
@@ -132,6 +108,9 @@ class Wedge(ComposableFilter):
|
|
132
108
|
)
|
133
109
|
|
134
110
|
def __call__(self, **kwargs: Dict) -> NDArray:
|
111
|
+
"""
|
112
|
+
Returns a Wedge stack of chosen parameters with DC component in the center.
|
113
|
+
"""
|
135
114
|
func_args = vars(self).copy()
|
136
115
|
func_args.update(kwargs)
|
137
116
|
|
@@ -170,6 +149,7 @@ class Wedge(ComposableFilter):
|
|
170
149
|
return {
|
171
150
|
"data": ret,
|
172
151
|
"shape": func_args["shape"],
|
152
|
+
"return_real_fourier": func_args.get("return_real_fourier", False),
|
173
153
|
"is_multiplicative_filter": True,
|
174
154
|
}
|
175
155
|
|
@@ -196,7 +176,12 @@ class Wedge(ComposableFilter):
|
|
196
176
|
return wedges
|
197
177
|
|
198
178
|
def weight_relion(
|
199
|
-
self,
|
179
|
+
self,
|
180
|
+
shape: Tuple[int],
|
181
|
+
opening_axis: int,
|
182
|
+
tilt_axis: int,
|
183
|
+
sampling_rate: float = 1.0,
|
184
|
+
**kwargs,
|
200
185
|
) -> NDArray:
|
201
186
|
"""
|
202
187
|
Generate weighted wedges based on the RELION 1.4 formalism, weighting each
|
@@ -211,7 +196,6 @@ class Wedge(ComposableFilter):
|
|
211
196
|
tilt_shape = compute_tilt_shape(
|
212
197
|
shape=shape, opening_axis=opening_axis, reduce_dim=True
|
213
198
|
)
|
214
|
-
|
215
199
|
wedges = np.zeros((len(self.angles), *tilt_shape))
|
216
200
|
for index, angle in enumerate(self.angles):
|
217
201
|
frequency_grid = frequency_grid_at_angle(
|
@@ -219,7 +203,7 @@ class Wedge(ComposableFilter):
|
|
219
203
|
opening_axis=opening_axis,
|
220
204
|
tilt_axis=tilt_axis,
|
221
205
|
angle=angle,
|
222
|
-
sampling_rate=
|
206
|
+
sampling_rate=sampling_rate,
|
223
207
|
)
|
224
208
|
sigma = np.sqrt(self.weights[index] * 4 / (8 * np.pi**2))
|
225
209
|
sigma = -2 * np.pi**2 * sigma**2
|
@@ -239,6 +223,7 @@ class Wedge(ComposableFilter):
|
|
239
223
|
amplitude: float = 0.245,
|
240
224
|
power: float = -1.665,
|
241
225
|
offset: float = 2.81,
|
226
|
+
sampling_rate: float = 1.0,
|
242
227
|
**kwargs,
|
243
228
|
) -> NDArray:
|
244
229
|
"""
|
@@ -264,7 +249,7 @@ class Wedge(ComposableFilter):
|
|
264
249
|
opening_axis=opening_axis,
|
265
250
|
tilt_axis=tilt_axis,
|
266
251
|
angle=angle,
|
267
|
-
sampling_rate=
|
252
|
+
sampling_rate=sampling_rate,
|
268
253
|
)
|
269
254
|
|
270
255
|
with np.errstate(divide="ignore"):
|
@@ -283,55 +268,35 @@ class Wedge(ComposableFilter):
|
|
283
268
|
return wedges
|
284
269
|
|
285
270
|
|
271
|
+
@dataclass
|
286
272
|
class WedgeReconstructed:
|
287
273
|
"""
|
288
274
|
Initialize :py:class:`WedgeReconstructed`.
|
289
|
-
|
290
|
-
Parameters
|
291
|
-
----------
|
292
|
-
angles :tuple of float, optional
|
293
|
-
The tilt angles, defaults to None.
|
294
|
-
tilt_axis : int
|
295
|
-
Axis the plane is tilted over, defaults to 0 (x).
|
296
|
-
opening_axis : int
|
297
|
-
The projection axis, defaults to 2 (z).
|
298
|
-
weights : tuple of float, optional
|
299
|
-
Weights to assign to individual wedge components.
|
300
|
-
weight_wedge : bool, optional
|
301
|
-
Whether individual wedge components should be weighted. If True and weights
|
302
|
-
is None, uses the cosine of the angle otherwise weights.
|
303
|
-
create_continuous_wedge: bool, optional
|
304
|
-
Whether to create a continous wedge or a per-component wedge. Weights are only
|
305
|
-
considered for non-continuous wedges.
|
306
|
-
frequency_cutoff : float, optional
|
307
|
-
Filter window applied during reconstruction.
|
308
|
-
**kwargs : Dict
|
309
|
-
Additional keyword arguments.
|
310
275
|
"""
|
311
276
|
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
self.create_continuous_wedge
|
331
|
-
|
277
|
+
#: The tilt angles, defaults to None.
|
278
|
+
angles: Tuple[float] = None
|
279
|
+
#: Weights to assign to individual wedge components. Not considered for continuous wedge
|
280
|
+
weights: Tuple[float] = None
|
281
|
+
#: Whether individual wedge components should be weighted.
|
282
|
+
weight_wedge: bool = False
|
283
|
+
#: Whether to create a continous wedge or a per-component wedge.
|
284
|
+
create_continuous_wedge: bool = False
|
285
|
+
#: Frequency cutoff of filter
|
286
|
+
frequency_cutoff: float = 0.5
|
287
|
+
#: Axis the plane is tilted over, defaults to 0 (x).
|
288
|
+
tilt_axis: int = 0
|
289
|
+
#: The projection axis, defaults to 2 (z).
|
290
|
+
opening_axis: int = 2
|
291
|
+
#: Filter window applied during reconstruction.
|
292
|
+
reconstruction_filter: str = None
|
293
|
+
|
294
|
+
def __post_init__(self):
|
295
|
+
if self.create_continuous_wedge:
|
296
|
+
self.angles = (min(self.angles), max(self.angles))
|
332
297
|
|
333
298
|
def __call__(
|
334
|
-
self, shape: Tuple[int], return_real_fourier: bool = False, **kwargs
|
299
|
+
self, shape: Tuple[int], return_real_fourier: bool = False, **kwargs
|
335
300
|
) -> Dict:
|
336
301
|
"""
|
337
302
|
Generate the reconstructed wedge.
|
@@ -341,10 +306,8 @@ class WedgeReconstructed:
|
|
341
306
|
shape : tuple of int
|
342
307
|
The shape of the reconstruction volume.
|
343
308
|
return_real_fourier : tuple of int
|
344
|
-
Return a shape compliant with
|
345
|
-
|
346
|
-
to False.
|
347
|
-
**kwargs : Dict
|
309
|
+
Return a shape compliant with rfftn. Defaults to False.
|
310
|
+
**kwargs : dict
|
348
311
|
Additional keyword arguments.
|
349
312
|
|
350
313
|
Returns
|
@@ -373,7 +336,6 @@ class WedgeReconstructed:
|
|
373
336
|
)
|
374
337
|
|
375
338
|
ret = func(shape=shape, **func_args)
|
376
|
-
|
377
339
|
frequency_cutoff = func_args.get("frequency_cutoff", None)
|
378
340
|
if frequency_cutoff is not None:
|
379
341
|
frequency_mask = fftfreqn(
|
@@ -565,7 +527,31 @@ def _from_xml(filename: str, **kwargs) -> Dict:
|
|
565
527
|
|
566
528
|
def _from_star(filename: str, **kwargs) -> Dict:
|
567
529
|
"""
|
568
|
-
Read tilt data from a
|
530
|
+
Read tilt data from a STAR file.
|
531
|
+
|
532
|
+
Parameters
|
533
|
+
----------
|
534
|
+
filename : str
|
535
|
+
The path to the text file.
|
536
|
+
|
537
|
+
Returns
|
538
|
+
-------
|
539
|
+
Dict
|
540
|
+
A dictionary with one key for each column.
|
541
|
+
"""
|
542
|
+
data = StarParser(filename, delimiter=None)
|
543
|
+
if "data_stopgap_wedgelist" in data:
|
544
|
+
angles = data["data_stopgap_wedgelist"]["_tilt_angle"]
|
545
|
+
weights = data["data_stopgap_wedgelist"]["_exposure"]
|
546
|
+
else:
|
547
|
+
angles = data["data_"]["_wrpAxisAngle"]
|
548
|
+
weights = data["data_"]["_wrpDose"]
|
549
|
+
return {"angles": angles, "weights": weights}
|
550
|
+
|
551
|
+
|
552
|
+
def _from_mdoc(filename: str, **kwargs) -> Dict:
|
553
|
+
"""
|
554
|
+
Read tilt data from a SerialEM MDOC file.
|
569
555
|
|
570
556
|
Parameters
|
571
557
|
----------
|
@@ -577,8 +563,9 @@ def _from_star(filename: str, **kwargs) -> Dict:
|
|
577
563
|
Dict
|
578
564
|
A dictionary with one key for each column.
|
579
565
|
"""
|
580
|
-
data =
|
581
|
-
|
566
|
+
data = MDOCParser(filename)
|
567
|
+
cumulative_exposure = np.multiply(np.add(1, data["ZValue"]), data["ExposureDose"])
|
568
|
+
return {"angles": data["TiltAngle"], "weights": cumulative_exposure}
|
582
569
|
|
583
570
|
|
584
571
|
def _from_text(filename: str, **kwargs) -> Dict:
|
@@ -604,10 +591,6 @@ def _from_text(filename: str, **kwargs) -> Dict:
|
|
604
591
|
if "angles" in data[0]:
|
605
592
|
headers = data.pop(0)
|
606
593
|
else:
|
607
|
-
warnings.warn(
|
608
|
-
f"Did not find a column named 'angles' in {filename}. Assuming "
|
609
|
-
"first column specifies angles."
|
610
|
-
)
|
611
594
|
if len(data[0]) != 1:
|
612
595
|
raise ValueError(
|
613
596
|
"Found more than one column without column names. Please add "
|
tme/filters/whitening.py
CHANGED
@@ -24,11 +24,6 @@ class LinearWhiteningFilter(ComposableFilter):
|
|
24
24
|
"""
|
25
25
|
Compute Fourier power spectrums and perform whitening.
|
26
26
|
|
27
|
-
Parameters
|
28
|
-
----------
|
29
|
-
**kwargs : Dict, optional
|
30
|
-
Additional keyword arguments.
|
31
|
-
|
32
27
|
References
|
33
28
|
----------
|
34
29
|
.. [1] de Teresa-Trueba, I.; Goetz, S. K.; Mattausch, A.; Stojanovska, F.; Zimmerli, C. E.;
|
@@ -39,7 +34,7 @@ class LinearWhiteningFilter(ComposableFilter):
|
|
39
34
|
13375 (2023)
|
40
35
|
"""
|
41
36
|
|
42
|
-
def __init__(self, **kwargs):
|
37
|
+
def __init__(self, *args, **kwargs):
|
43
38
|
pass
|
44
39
|
|
45
40
|
@staticmethod
|
tme/matching_data.py
CHANGED
@@ -175,7 +175,7 @@ class MatchingData:
|
|
175
175
|
target_pad: NDArray = None,
|
176
176
|
template_pad: NDArray = None,
|
177
177
|
invert_target: bool = False,
|
178
|
-
) -> "MatchingData":
|
178
|
+
) -> Tuple["MatchingData", Tuple]:
|
179
179
|
"""
|
180
180
|
Subset class instance based on slices.
|
181
181
|
|
@@ -194,6 +194,8 @@ class MatchingData:
|
|
194
194
|
-------
|
195
195
|
:py:class:`MatchingData`
|
196
196
|
Newly allocated subset of class instance.
|
197
|
+
Tuple
|
198
|
+
Translation offset to merge analyzers.
|
197
199
|
|
198
200
|
Examples
|
199
201
|
--------
|
@@ -251,8 +253,9 @@ class MatchingData:
|
|
251
253
|
target_offset[mask] = [x.start for x in target_slice]
|
252
254
|
mask = np.subtract(1, self._target_batch).astype(bool)
|
253
255
|
template_offset = np.zeros(len(self._output_template_shape), dtype=int)
|
254
|
-
template_offset[mask] = [x.start for x in template_slice]
|
255
|
-
|
256
|
+
template_offset[mask] = [x.start for x, b in zip(template_slice, mask) if b]
|
257
|
+
|
258
|
+
translation_offset = tuple(x for x in target_offset)
|
256
259
|
|
257
260
|
ret.target_filter = self.target_filter
|
258
261
|
ret.template_filter = self.template_filter
|
@@ -262,7 +265,7 @@ class MatchingData:
|
|
262
265
|
template_dim=getattr(self, "_template_dim", None),
|
263
266
|
)
|
264
267
|
|
265
|
-
return ret
|
268
|
+
return ret, translation_offset
|
266
269
|
|
267
270
|
def to_backend(self):
|
268
271
|
"""
|
@@ -323,11 +326,6 @@ class MatchingData:
|
|
323
326
|
|
324
327
|
target_ndim -= len(target_dims)
|
325
328
|
template_ndim -= len(template_dims)
|
326
|
-
|
327
|
-
if target_ndim != template_ndim:
|
328
|
-
raise ValueError(
|
329
|
-
f"Dimension mismatch: Target ({target_ndim}) Template ({template_ndim})."
|
330
|
-
)
|
331
329
|
self._set_matching_dimension(
|
332
330
|
target_dims=target_dims, template_dims=template_dims
|
333
331
|
)
|
@@ -492,29 +490,26 @@ class MatchingData:
|
|
492
490
|
def _fourier_padding(
|
493
491
|
target_shape: Tuple[int],
|
494
492
|
template_shape: Tuple[int],
|
495
|
-
|
493
|
+
pad_target: bool = False,
|
496
494
|
batch_mask: Tuple[int] = None,
|
497
495
|
) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
|
498
|
-
fourier_pad = template_shape
|
499
|
-
fourier_shift = np.zeros_like(template_shape)
|
500
|
-
|
501
496
|
if batch_mask is None:
|
502
497
|
batch_mask = np.zeros_like(template_shape)
|
503
498
|
batch_mask = np.asarray(batch_mask)
|
504
499
|
|
505
|
-
|
506
|
-
fourier_pad = np.ones(len(fourier_pad), dtype=int)
|
500
|
+
fourier_pad = np.ones(len(template_shape), dtype=int)
|
507
501
|
fourier_pad = np.multiply(fourier_pad, 1 - batch_mask)
|
508
502
|
fourier_pad = np.add(fourier_pad, batch_mask)
|
509
503
|
|
504
|
+
# Avoid padding batch dimensions
|
510
505
|
pad_shape = np.maximum(target_shape, template_shape)
|
506
|
+
pad_shape = np.maximum(target_shape, np.multiply(1 - batch_mask, pad_shape))
|
511
507
|
ret = be.compute_convolution_shapes(pad_shape, fourier_pad)
|
512
508
|
conv_shape, fast_shape, fast_ft_shape = ret
|
513
509
|
|
514
510
|
template_mod = np.mod(template_shape, 2)
|
515
|
-
|
516
|
-
|
517
|
-
fourier_shift = np.subtract(fourier_shift, template_mod)
|
511
|
+
fourier_shift = 1 - np.divide(template_shape, 2).astype(int)
|
512
|
+
fourier_shift = np.subtract(fourier_shift, template_mod)
|
518
513
|
|
519
514
|
shape_diff = np.multiply(
|
520
515
|
np.subtract(target_shape, template_shape), 1 - batch_mask
|
@@ -523,34 +518,27 @@ class MatchingData:
|
|
523
518
|
if np.sum(shape_mask):
|
524
519
|
shape_shift = np.divide(shape_diff, 2)
|
525
520
|
offset = np.mod(shape_diff, 2)
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
)
|
531
|
-
else:
|
532
|
-
warnings.warn(
|
533
|
-
"Template is larger than target and padding is turned off. Consider "
|
534
|
-
"swapping them or activate padding. Correcting the shift for now."
|
535
|
-
)
|
521
|
+
warnings.warn(
|
522
|
+
"Template is larger than target and padding is turned off. Consider "
|
523
|
+
"swapping them or activate padding. Correcting the shift for now."
|
524
|
+
)
|
536
525
|
shape_shift = np.multiply(np.add(shape_shift, offset), shape_mask)
|
537
526
|
fourier_shift = np.subtract(fourier_shift, shape_shift).astype(int)
|
538
527
|
|
539
|
-
|
528
|
+
if pad_target:
|
529
|
+
fourier_shift = np.subtract(fourier_shift, np.subtract(1, template_mod))
|
540
530
|
|
531
|
+
fourier_shift = tuple(np.multiply(fourier_shift, 1 - batch_mask).astype(int))
|
541
532
|
return tuple(conv_shape), tuple(fast_shape), tuple(fast_ft_shape), fourier_shift
|
542
533
|
|
543
|
-
def fourier_padding(
|
544
|
-
self, pad_fourier: bool = False
|
545
|
-
) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
|
534
|
+
def fourier_padding(self, pad_target: bool = False) -> Tuple:
|
546
535
|
"""
|
547
536
|
Computes efficient shape four Fourier transforms and potential associated shifts.
|
548
537
|
|
549
538
|
Parameters
|
550
539
|
----------
|
551
|
-
|
552
|
-
|
553
|
-
shape and template shape minus one, False by default.
|
540
|
+
pad_target : bool, optional
|
541
|
+
Whether the target has been padded to the full convolution shape.
|
554
542
|
|
555
543
|
Returns
|
556
544
|
-------
|
@@ -565,7 +553,7 @@ class MatchingData:
|
|
565
553
|
target_shape=be.to_numpy_array(self._output_target_shape),
|
566
554
|
template_shape=be.to_numpy_array(self._output_template_shape),
|
567
555
|
batch_mask=be.to_numpy_array(self._batch_mask),
|
568
|
-
|
556
|
+
pad_target=pad_target,
|
569
557
|
)
|
570
558
|
|
571
559
|
def computation_schedule(
|
tme/matching_exhaustive.py
CHANGED
@@ -149,7 +149,7 @@ def scan(
|
|
149
149
|
n_jobs: int = 4,
|
150
150
|
callback_class: CallbackClass = None,
|
151
151
|
callback_class_args: Dict = {},
|
152
|
-
|
152
|
+
pad_target: bool = True,
|
153
153
|
pad_template_filter: bool = True,
|
154
154
|
interpolation_order: int = 3,
|
155
155
|
jobs_per_callback_class: int = 8,
|
@@ -176,8 +176,8 @@ def scan(
|
|
176
176
|
Analyzer class pointer to operate on computed scores.
|
177
177
|
callback_class_args : dict, optional
|
178
178
|
Arguments passed to the callback_class. Default is an empty dictionary.
|
179
|
-
|
180
|
-
Whether to pad target
|
179
|
+
pad_target: bool, optional
|
180
|
+
Whether to pad target to the full convolution shape.
|
181
181
|
pad_template_filter: bool, optional
|
182
182
|
Whether to pad potential template filters to the full convolution shape.
|
183
183
|
interpolation_order : int, optional
|
@@ -210,17 +210,17 @@ def scan(
|
|
210
210
|
>>> )
|
211
211
|
|
212
212
|
"""
|
213
|
-
matching_data = matching_data.subset_by_slice(
|
213
|
+
matching_data, translation_offset = matching_data.subset_by_slice(
|
214
214
|
target_slice=target_slice,
|
215
215
|
template_slice=template_slice,
|
216
|
-
target_pad=matching_data.target_padding(pad_target=
|
216
|
+
target_pad=matching_data.target_padding(pad_target=pad_target),
|
217
217
|
)
|
218
218
|
|
219
219
|
matching_data.to_backend()
|
220
220
|
template_shape = matching_data._batch_shape(
|
221
221
|
matching_data.template.shape, matching_data._template_batch
|
222
222
|
)
|
223
|
-
conv, fwd, inv, shift = matching_data.fourier_padding(
|
223
|
+
conv, fwd, inv, shift = matching_data.fourier_padding(pad_target=pad_target)
|
224
224
|
|
225
225
|
template_filter = _setup_template_filter_apply_target_filter(
|
226
226
|
matching_data=matching_data,
|
@@ -231,14 +231,14 @@ def scan(
|
|
231
231
|
|
232
232
|
default_callback_args = {
|
233
233
|
"shape": fwd,
|
234
|
-
"offset":
|
234
|
+
"offset": translation_offset,
|
235
235
|
"fourier_shift": shift,
|
236
236
|
"fast_shape": fwd,
|
237
237
|
"targetshape": matching_data._output_shape,
|
238
238
|
"templateshape": template_shape,
|
239
239
|
"convolution_shape": conv,
|
240
240
|
"thread_safe": n_jobs > 1,
|
241
|
-
"convolution_mode": "valid" if
|
241
|
+
"convolution_mode": "valid" if pad_target else "same",
|
242
242
|
"shm_handler": shm_handler,
|
243
243
|
"only_unique_rotations": True,
|
244
244
|
"aggregate_axis": matching_data._batch_axis(matching_data._batch_mask),
|
@@ -425,14 +425,17 @@ def scan_subsets(
|
|
425
425
|
splits = tuple(product(target_splits, template_splits))
|
426
426
|
|
427
427
|
outer_jobs, inner_jobs = job_schedule
|
428
|
-
if
|
428
|
+
if be._backend_name == "jax":
|
429
|
+
func = be.scan
|
430
|
+
|
429
431
|
corr_scoring = MATCHING_EXHAUSTIVE_REGISTER.get("CORR", (None, None))[1]
|
430
|
-
results =
|
432
|
+
results = func(
|
431
433
|
matching_data=matching_data,
|
432
434
|
splits=splits,
|
433
435
|
n_jobs=outer_jobs,
|
434
436
|
rotate_mask=matching_score != corr_scoring,
|
435
437
|
callback_class=callback_class,
|
438
|
+
callback_class_args=callback_class_args,
|
436
439
|
)
|
437
440
|
else:
|
438
441
|
results = Parallel(n_jobs=outer_jobs, verbose=verbose)(
|
@@ -447,7 +450,7 @@ def scan_subsets(
|
|
447
450
|
callback_class=callback_class,
|
448
451
|
callback_class_args=callback_class_args,
|
449
452
|
interpolation_order=interpolation_order,
|
450
|
-
|
453
|
+
pad_target=pad_target_edges,
|
451
454
|
gpu_index=index % outer_jobs,
|
452
455
|
pad_template_filter=pad_template_filter,
|
453
456
|
target_slice=target_split,
|
tme/matching_scores.py
CHANGED
@@ -593,20 +593,27 @@ def corr_scoring(
|
|
593
593
|
**_fftargs,
|
594
594
|
)
|
595
595
|
|
596
|
+
center = be.divide(be.to_backend_array(template.shape) - 1, 2)
|
596
597
|
unpadded_slice = tuple(slice(0, stop) for stop in template.shape)
|
598
|
+
|
599
|
+
template_rot = be.zeros(template.shape, be._float_dtype)
|
597
600
|
for index in range(rotations.shape[0]):
|
601
|
+
# d+1, d+1 rigid transform matrix from d,d rotation matrix
|
598
602
|
rotation = rotations[index]
|
599
|
-
|
600
|
-
|
603
|
+
matrix = be._rigid_transform_matrix(rotation_matrix=rotation, center=center)
|
604
|
+
template_rot, _ = be.rigid_transform(
|
601
605
|
arr=template,
|
602
|
-
rotation_matrix=
|
603
|
-
out=
|
604
|
-
use_geometric_center=True,
|
606
|
+
rotation_matrix=matrix,
|
607
|
+
out=template_rot,
|
605
608
|
order=interpolation_order,
|
606
|
-
cache=
|
609
|
+
cache=True,
|
607
610
|
)
|
608
|
-
|
609
|
-
|
611
|
+
|
612
|
+
template_rot = template_filter_func(template_rot, ft_temp, template_filter)
|
613
|
+
norm_template(template_rot, template_mask, mask_sum)
|
614
|
+
|
615
|
+
arr = be.fill(arr, 0)
|
616
|
+
arr[unpadded_slice] = template_rot
|
610
617
|
|
611
618
|
ft_temp = rfftn(arr, ft_temp)
|
612
619
|
ft_temp = be.multiply(ft_target, ft_temp, out=ft_temp)
|
@@ -729,7 +736,7 @@ def flc_scoring(
|
|
729
736
|
out_mask=temp,
|
730
737
|
use_geometric_center=True,
|
731
738
|
order=interpolation_order,
|
732
|
-
cache=
|
739
|
+
cache=True,
|
733
740
|
)
|
734
741
|
|
735
742
|
n_obs = be.sum(temp)
|
@@ -875,7 +882,7 @@ def mcc_scoring(
|
|
875
882
|
out_mask=temp,
|
876
883
|
use_geometric_center=True,
|
877
884
|
order=interpolation_order,
|
878
|
-
cache=
|
885
|
+
cache=True,
|
879
886
|
)
|
880
887
|
|
881
888
|
template_filter_func(template_rot, temp_ft, template_filter)
|
@@ -1035,7 +1042,8 @@ def flc_scoring2(
|
|
1035
1042
|
out_mask=tmp_sqz,
|
1036
1043
|
use_geometric_center=True,
|
1037
1044
|
order=interpolation_order,
|
1038
|
-
cache=
|
1045
|
+
cache=True,
|
1046
|
+
batched=True,
|
1039
1047
|
)
|
1040
1048
|
n_obs = be.sum(tmp_sqz, axis=data_axes, keepdims=True)
|
1041
1049
|
arr_norm = template_filter_func(arr_sqz, ft_temp, template_filter)
|
@@ -1155,7 +1163,8 @@ def corr_scoring2(
|
|
1155
1163
|
out=arr_sqz,
|
1156
1164
|
use_geometric_center=True,
|
1157
1165
|
order=interpolation_order,
|
1158
|
-
cache=
|
1166
|
+
cache=True,
|
1167
|
+
batched=True,
|
1159
1168
|
)
|
1160
1169
|
arr_norm = template_filter_func(arr_sqz, ft_sqz, template_filter)
|
1161
1170
|
norm_template(arr_norm[unpadded_slice], template_mask, mask_sum, axis=data_axes)
|