pytme-0.2.2-cp311-cp311-macosx_14_0_arm64.whl → pytme-0.2.3-cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.2.data → pytme-0.2.3.data}/scripts/match_template.py +91 -142
- {pytme-0.2.2.data → pytme-0.2.3.data}/scripts/postprocess.py +20 -29
- pytme-0.2.3.data/scripts/preprocess.py +132 -0
- {pytme-0.2.2.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +6 -9
- {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/METADATA +11 -10
- {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/RECORD +33 -32
- pytme-0.2.2.data/scripts/preprocess.py → scripts/eval.py +1 -1
- scripts/match_template.py +91 -142
- scripts/postprocess.py +20 -29
- scripts/preprocess.py +95 -56
- scripts/preprocessor_gui.py +6 -9
- tme/__version__.py +1 -1
- tme/analyzer.py +9 -6
- tme/backends/__init__.py +1 -1
- tme/backends/_jax_utils.py +10 -8
- tme/backends/cupy_backend.py +2 -7
- tme/backends/jax_backend.py +34 -20
- tme/backends/npfftw_backend.py +3 -2
- tme/backends/pytorch_backend.py +10 -7
- tme/density.py +15 -8
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +24 -17
- tme/matching_exhaustive.py +36 -19
- tme/matching_scores.py +5 -2
- tme/matching_utils.py +7 -2
- tme/orientations.py +26 -9
- tme/preprocessing/composable_filter.py +7 -4
- tme/preprocessing/tilt_series.py +10 -32
- {pytme-0.2.2.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
- {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/WHEEL +0 -0
- {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
scripts/preprocess.py
CHANGED
@@ -1,93 +1,132 @@
 #!python3
-"""
-on a provided yaml configuration obtaiend from preprocessor_gui.py.
+""" Preprocessing routines for template matching.
 
 Copyright (c) 2023 European Molecular Biology Laboratory
 
 Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
 """
-import
+import warnings
 import argparse
-import
-
+import numpy as np
+
+from tme import Density, Structure
+from tme.backends import backend as be
+from tme.preprocessing.frequency_filters import BandPassFilter
 
 
 def parse_args():
     parser = argparse.ArgumentParser(
-        description=
-
-        Apply preprocessing to an input file based on a provided YAML configuration.
-
-        Expected YAML file format:
-        ```yaml
-        <method_name>:
-            <parameter1>: <value1>
-            <parameter2>: <value2>
-            ...
-        ```
-        """
-        ),
-        formatter_class=argparse.RawDescriptionHelpFormatter,
+        description="Perform template matching preprocessing.",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
-
-
-
+
+    io_group = parser.add_argument_group("Input / Output")
+    io_group.add_argument(
+        "-m",
+        "--data",
+        dest="data",
         type=str,
         required=True,
-        help="Path to
+        help="Path to a file in PDB/MMCIF, CCP4/MRC, EM, H5 or a format supported by "
+        "tme.density.Density.from_file "
+        "https://kosinskilab.github.io/pyTME/reference/api/tme.density.Density.from_file.html",
     )
-
-        "-
-        "--
+    io_group.add_argument(
+        "-o",
+        "--output",
+        dest="output",
         type=str,
         required=True,
-        help="Path
+        help="Path the output should be written to.",
     )
-
-
-
-
+
+    box_group = parser.add_argument_group("Box")
+    box_group.add_argument(
+        "--box_size",
+        dest="box_size",
+        type=int,
         required=True,
-        help="
+        help="Box size of the output",
     )
-
-        "--
+    box_group.add_argument(
+        "--sampling_rate",
+        dest="sampling_rate",
+        type=float,
+        required=True,
+        help="Sampling rate of the output file.",
     )
 
+    modulation_group = parser.add_argument_group("Modulation")
+    modulation_group.add_argument(
+        "--invert_contrast",
+        dest="invert_contrast",
+        action="store_true",
+        required=False,
+        help="Inverts the template contrast.",
+    )
+    modulation_group.add_argument(
+        "--lowpass",
+        dest="lowpass",
+        type=float,
+        required=False,
+        default=None,
+        help="Lowpass filter the template to the given resolution. Nyquist by default. "
+        "A value of 0 disables the filter.",
+    )
+    modulation_group.add_argument(
+        "--no_centering",
+        dest="no_centering",
+        action="store_true",
+        help="Assumes the template is already centered and omits centering.",
+    )
     args = parser.parse_args()
-
     return args
 
 
 def main():
     args = parse_args()
-    with open(args.yaml_file, "r") as f:
-        preprocess_settings = yaml.safe_load(f)
 
-
-
-
-
+    try:
+        data = Structure.from_file(args.data)
+        data = Density.from_structure(data, sampling_rate=args.sampling_rate)
+    except NotImplementedError:
+        data = Density.from_file(args.data)
+
+    if not args.no_centering:
+        data, _ = data.centered(0)
+
+    recommended_box = be.compute_convolution_shapes([args.box_size], [1])[1][0]
+    if recommended_box != args.box_size:
+        warnings.warn(
+            f"Consider using --box_size {recommended_box} instead of {args.box_size}."
         )
 
-
-
-
+    data.pad(
+        np.multiply(args.box_size, np.divide(args.sampling_rate, data.sampling_rate)),
+        center=True,
+    )
+
+    bpf_mask = 1
+    lowpass = 2 * args.sampling_rate if args.lowpass is None else args.lowpass
+    if args.lowpass != 0:
+        bpf_mask = BandPassFilter(
+            lowpass=lowpass,
+            highpass=None,
+            use_gaussian=True,
+            return_real_fourier=True,
+            shape_is_real_fourier=False,
+        )(shape=data.shape)["data"]
 
-
-
+    data_ft = np.fft.rfftn(data.data, s=data.shape)
+    data_ft = np.multiply(data_ft, bpf_mask, out=data_ft)
+    data.data = np.fft.irfftn(data_ft, s=data.shape).real
 
-
-    preprocessor = Preprocessor()
-    method = getattr(preprocessor, method_name, None)
-    if not method:
-        raise ValueError(
-            f"{method} does not exist in dge.preprocessor.Preprocessor class."
-        )
+    data = data.resample(args.sampling_rate, method="spline", order=3)
 
-
-
+    if args.invert_contrast:
+        data.data = data.data * -1
 
+    data.to_file(args.output)
 
 if __name__ == "__main__":
-    main()
+    main()
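The filtering step added in main() amounts to a forward real FFT, a mask multiplication, and an inverse FFT with explicit output shapes. A minimal NumPy sketch of that step, assuming a precomputed mask (the array contents below are hypothetical; the script itself builds the mask with tme.preprocessing.frequency_filters.BandPassFilter):

import numpy as np

data = np.random.rand(64, 64, 64).astype(np.float32)  # stand-in for the template volume
bpf_mask = np.ones((64, 64, 33))                       # stand-in for the band-pass mask on the rfftn grid

# Passing s=data.shape keeps odd dimensions intact on the inverse transform.
data_ft = np.fft.rfftn(data, s=data.shape)
data_ft = np.multiply(data_ft, bpf_mask, out=data_ft)
filtered = np.fft.irfftn(data_ft, s=data.shape).real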
scripts/preprocessor_gui.py
CHANGED
@@ -132,14 +132,6 @@ def local_gaussian_filter(
     )
 
 
-def ntree(
-    template: NDArray,
-    sigma_range: Tuple[float, float],
-    **kwargs: dict,
-) -> NDArray:
-    return preprocessor.ntree_filter(template=template, sigma_range=sigma_range)
-
-
 def mean(
     template: NDArray,
     width: int,
@@ -197,6 +189,10 @@ def compute_power_spectrum(template: NDArray) -> NDArray:
     return np.fft.fftshift(np.log(np.abs(np.fft.fftn(template))))
 
 
+def invert_contrast(template: NDArray) -> NDArray:
+    return template * -1
+
+
 def widgets_from_function(function: Callable, exclude_params: List = ["self"]):
     """
     Creates list of magicui widgets by inspecting function typing annotations
@@ -252,13 +248,13 @@ WRAPPED_FUNCTIONS = {
     "gaussian_filter": gaussian_filter,
     "bandpass_filter": bandpass_filter,
     "edge_gaussian_filter": edge_gaussian_filter,
-    "ntree_filter": ntree,
     "local_gaussian_filter": local_gaussian_filter,
     "difference_of_gaussian_filter": difference_of_gaussian_filter,
     "mean_filter": mean,
     "wedge_filter": wedge,
     "power_spectrum": compute_power_spectrum,
     "ctf": ctf_filter,
+    "invert_contrast": invert_contrast,
 }
 
 EXCLUDED_FUNCTIONS = [
@@ -634,6 +630,7 @@ class MaskWidget(widgets.Container):
 
         data = active_layer.data.copy()
         cutoff = np.quantile(data, self.percentile_range_edit.value / 100)
+        cutoff = max(cutoff, np.finfo(np.float32).resolution)
        data[data < cutoff] = 0
 
         center_of_mass = Density.center_of_mass(np.abs(data), 0)
tme/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.2.2"
+__version__ = "0.2.3"
tme/analyzer.py
CHANGED
@@ -459,14 +459,14 @@ class PeakCaller(ABC):
 
         final_order = top_scores[
             filter_points_indices(
-                coordinates=peak_positions[top_scores],
+                coordinates=peak_positions[top_scores, :],
                 min_distance=self.min_distance,
                 batch_dims=self.batch_dims,
             )
         ]
 
-        self.peak_list[0] = peak_positions[final_order,]
-        self.peak_list[1] = rotations[final_order,]
+        self.peak_list[0] = peak_positions[final_order, :]
+        self.peak_list[1] = rotations[final_order, :]
         self.peak_list[2] = peak_scores[final_order]
         self.peak_list[3] = peak_details[final_order]
 
@@ -475,6 +475,7 @@ class PeakCaller(ABC):
         fast_shape: Tuple[int],
         targetshape: Tuple[int],
         templateshape: Tuple[int],
+        convolution_shape: Tuple[int] = None,
         fourier_shift: Tuple[int] = None,
         convolution_mode: str = None,
         shared_memory_handler=None,
@@ -488,6 +489,7 @@ class PeakCaller(ABC):
             return self
 
         # Wrap peaks around score space
+        convolution_shape = be.to_backend_array(convolution_shape)
         fast_shape = be.to_backend_array(fast_shape)
         if fourier_shift is not None:
             fourier_shift = be.to_backend_array(fourier_shift)
@@ -501,10 +503,9 @@ class PeakCaller(ABC):
             )
 
         # Remove padding to fast Fourier (and potential full convolution) shape
+        output_shape = convolution_shape
         targetshape = be.to_backend_array(targetshape)
         templateshape = be.to_backend_array(templateshape)
-        fast_shape = be.minimum(be.add(targetshape, templateshape) - 1, fast_shape)
-        output_shape = fast_shape
         if convolution_mode == "same":
             output_shape = targetshape
         elif convolution_mode == "valid":
@@ -515,7 +516,7 @@ class PeakCaller(ABC):
 
         output_shape = be.to_backend_array(output_shape)
         starts = be.astype(
-            be.divide(be.subtract(
+            be.divide(be.subtract(convolution_shape, output_shape), 2),
             be._int_dtype,
         )
         stops = be.add(starts, output_shape)
@@ -1019,6 +1020,7 @@ class MaxScoreOverRotations:
         self,
         targetshape: Tuple[int],
         templateshape: Tuple[int],
+        convolution_shape: Tuple[int],
         fourier_shift: Tuple[int] = None,
         convolution_mode: str = None,
         shared_memory_handler=None,
@@ -1039,6 +1041,7 @@ class MaxScoreOverRotations:
             "s1": targetshape,
             "s2": templateshape,
             "convolution_mode": convolution_mode,
+            "convolution_shape": convolution_shape,
         }
         if convolution_mode is not None:
             scores = apply_convolution_mode(scores, **convargs)
tme/backends/__init__.py
CHANGED
tme/backends/_jax_utils.py
CHANGED
@@ -19,9 +19,9 @@ def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
     """
     Computes :py:meth:`tme.matching_exhaustive.cc_setup`.
     """
-    template_ft = jnp.fft.rfftn(template)
+    template_ft = jnp.fft.rfftn(template, s=template.shape)
     template_ft = template_ft.at[:].multiply(ft_target)
-    correlation = jnp.fft.irfftn(template_ft)
+    correlation = jnp.fft.irfftn(template_ft, s=template.shape)
     return correlation
 
 
@@ -77,14 +77,15 @@ def _reciprocal_target_std(
     --------
     :py:meth:`tme.matching_exhaustive.flc_scoring`.
     """
-
+    ft_shape = template_mask.shape
+    ft_template_mask = jnp.fft.rfftn(template_mask, s=ft_shape)
 
     # E(X^2)- E(X)^2
-    exp_sq = jnp.fft.irfftn(ft_target2 * ft_template_mask)
+    exp_sq = jnp.fft.irfftn(ft_target2 * ft_template_mask, s=ft_shape)
     exp_sq = exp_sq.at[:].divide(n_observations)
 
     ft_template_mask = ft_template_mask.at[:].multiply(ft_target)
-    sq_exp = jnp.fft.irfftn(ft_template_mask)
+    sq_exp = jnp.fft.irfftn(ft_template_mask, s=ft_shape)
     sq_exp = sq_exp.at[:].divide(n_observations)
     sq_exp = sq_exp.at[:].power(2)
 
@@ -99,7 +100,7 @@ def _reciprocal_target_std(
 
 
 def _apply_fourier_filter(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
-    arr_ft = jnp.fft.rfftn(arr)
+    arr_ft = jnp.fft.rfftn(arr, s=arr.shape)
     arr_ft = arr_ft.at[:].multiply(arr_filter)
     return arr.at[:].set(jnp.fft.irfftn(arr_ft, s=arr.shape))
 
@@ -107,6 +108,7 @@ def _apply_fourier_filter(arr: BackendArray, arr_filter: BackendArray) -> Backen
 def _identity(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
     return arr
 
+
 @partial(
     pmap,
     in_axes=(0,) + (None,) * 6,
@@ -127,8 +129,8 @@ def scan(
     if hasattr(target_filter, "shape"):
         target = _apply_fourier_filter(target, target_filter)
 
-    ft_target = jnp.fft.rfftn(target)
-    ft_target2 = jnp.fft.rfftn(jnp.square(target))
+    ft_target = jnp.fft.rfftn(target, s=fast_shape)
+    ft_target2 = jnp.fft.rfftn(jnp.square(target), s=fast_shape)
     inv_denominator, target, scoring_func = None, None, _flc_scoring
     if not rotate_mask:
         n_observations = jnp.sum(template_mask)
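The s= arguments added throughout these helpers matter whenever a transformed dimension is odd: without an explicit shape, irfftn assumes an even last axis and silently returns a differently sized array. A minimal JAX illustration:

import jax.numpy as jnp

x = jnp.ones((4, 5))                                               # odd last dimension
implicit = jnp.fft.irfftn(jnp.fft.rfftn(x))                        # infers an even size -> (4, 4)
explicit = jnp.fft.irfftn(jnp.fft.rfftn(x, s=x.shape), s=x.shape)  # round-trips to (4, 5)
print(implicit.shape, explicit.shape)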
tme/backends/cupy_backend.py
CHANGED
@@ -149,10 +149,10 @@ class CupyBackend(NumpyFFTWBackend):
             cache.clear()
 
         def rfftn(arr: CupyArray, out: CupyArray) -> CupyArray:
-            return cufft.rfftn(arr)
+            return cufft.rfftn(arr, s=fast_shape)
 
         def irfftn(arr: CupyArray, out: CupyArray) -> CupyArray:
-            return cufft.irfftn(arr)
+            return cufft.irfftn(arr, s=fast_shape)
 
         PLAN_CACHE[current_device] = [fast_shape, fast_ft_shape]
 
@@ -167,11 +167,6 @@ class CupyBackend(NumpyFFTWBackend):
         fast_shape = [next_fast_len(x, real=True) for x in convolution_shape]
         fast_ft_shape = list(fast_shape[:-1]) + [fast_shape[-1] // 2 + 1]
 
-        # This almost never happens but avoid cuFFT casting errors
-        is_odd = fast_shape[-1] % 2
-        fast_shape[-1] += is_odd
-        fast_ft_shape[-1] += is_odd
-
         return convolution_shape, fast_shape, fast_ft_shape
 
     def max_filter_coordinates(self, score_space, min_distance: Tuple[int]):
tme/backends/jax_backend.py
CHANGED
@@ -119,19 +119,6 @@ class JaxBackend(NumpyFFTWBackend):
 
         return rfftn, irfftn
 
-    def compute_convolution_shapes(
-        self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
-    ) -> Tuple[List[int], List[int], List[int]]:
-        conv_shape, fast_shape, fast_ft_shape = super().compute_convolution_shapes(
-            arr1_shape, arr2_shape
-        )
-
-        is_odd = fast_shape[-1] % 2
-        fast_shape[-1] += is_odd
-        fast_ft_shape[-1] += is_odd
-
-        return conv_shape, fast_shape, fast_ft_shape
-
     def rigid_transform(
         self,
         arr: BackendArray,
@@ -144,8 +131,8 @@ class JaxBackend(NumpyFFTWBackend):
         **kwargs,
     ) -> Tuple[BackendArray, BackendArray]:
         rotate_mask = arr_mask is not None
-        center = self.divide(self.to_backend_array(arr.shape), 2)[:, None]
 
+        center = self.divide(self.to_backend_array(arr.shape) - 1, 2)[:, None]
         indices = self._array_backend.indices(arr.shape, dtype=self._float_dtype)
         indices = indices.reshape((arr.ndim, -1))
         indices = indices.at[:].add(-center)
@@ -200,7 +187,7 @@ class JaxBackend(NumpyFFTWBackend):
         target_shape = tuple(
             (x.stop - x.start + p) for x, p in zip(splits[0][0], target_pad)
         )
-        fast_shape, fast_ft_shape, shift = matching_data._fourier_padding(
+        conv_shape, fast_shape, fast_ft_shape, shift = matching_data._fourier_padding(
             target_shape=self.to_numpy_array(target_shape),
             template_shape=self.to_numpy_array(matching_data._template.shape),
             pad_fourier=False,
@@ -210,13 +197,25 @@ class JaxBackend(NumpyFFTWBackend):
             "convolution_mode": convolution_mode,
             "fourier_shift": shift,
             "targetshape": target_shape,
-            "templateshape": matching_data.
+            "templateshape": matching_data.template.shape,
+            "convolution_shape": conv_shape,
         }
 
         create_target_filter = matching_data.target_filter is not None
         create_template_filter = matching_data.template_filter is not None
         create_filter = create_target_filter or create_template_filter
 
+        # Applying the filter leads to more FFTs
+        fastt_shape = matching_data._template.shape
+        if create_template_filter:
+            _, fastt_shape, _, tshift = matching_data._fourier_padding(
+                target_shape=self.to_numpy_array(matching_data._template.shape),
+                template_shape=self.to_numpy_array(
+                    [1 for _ in matching_data._template.shape]
+                ),
+                pad_fourier=False,
+            )
+
         ret, template_filter, target_filter = [], 1, 1
         rotation_mapping = {
             self.tobytes(matching_data.rotations[i]): i
@@ -246,12 +245,12 @@ class JaxBackend(NumpyFFTWBackend):
 
         if create_template_filter:
             template_filter = matching_data.template_filter(
-                shape=
+                shape=fastt_shape, **filter_args
             )["data"]
             template_filter = template_filter.at[(0,) * template_filter.ndim].set(0)
 
         if create_target_filter:
-            target_filter = matching_data.
+            target_filter = matching_data.target_filter(
                 shape=fast_shape, **filter_args
             )["data"]
             target_filter = target_filter.at[(0,) * target_filter.ndim].set(0)
@@ -260,8 +259,8 @@ class JaxBackend(NumpyFFTWBackend):
         base, targets = None, self._array_backend.stack(targets)
         scores, rotations = scan_inner(
             targets,
-            matching_data.template,
-            matching_data.template_mask,
+            self.topleft_pad(matching_data.template, fastt_shape),
+            self.topleft_pad(matching_data.template_mask, fastt_shape),
             matching_data.rotations,
             template_filter,
             target_filter,
@@ -280,3 +279,18 @@ class JaxBackend(NumpyFFTWBackend):
             ret.append(tuple(temp._postprocess(**analyzer_args)))
 
         return ret
+
+    def get_available_memory(self) -> int:
+        import jax
+
+        _memory = {"cpu": 0, "gpu": 0}
+        for device in jax.devices():
+            if device.platform == "cpu":
+                _memory["cpu"] = super().get_available_memory()
+            else:
+                mem_stats = device.memory_stats()
+                _memory["gpu"] += mem_stats.get("bytes_limit", 0)
+
+        if _memory["gpu"] > 0:
+            return _memory["gpu"]
+        return _memory["cpu"]
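The new get_available_memory walks jax.devices() and sums the reported byte limits. A rough stand-alone sketch of the same pattern; the psutil fallback for CPU-only setups is an assumption of this sketch, not what the backend itself uses:

import jax
import psutil  # assumption: used here only as a host-memory fallback

def available_memory_bytes() -> int:
    gpu_bytes = 0
    for device in jax.devices():
        if device.platform != "cpu":
            stats = device.memory_stats() or {}   # may be None on some platforms
            gpu_bytes += stats.get("bytes_limit", 0)
    # Prefer accelerator memory when present, otherwise report host memory.
    if gpu_bytes > 0:
        return gpu_bytes
    return psutil.virtual_memory().available

print(available_memory_bytes())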
tme/backends/npfftw_backend.py
CHANGED
@@ -186,7 +186,7 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
     def to_sharedarr(
         self, arr: NDArray, shared_memory_handler: type = None
     ) -> shm_type:
-        if
+        if isinstance(shared_memory_handler, SharedMemoryManager):
             shm = shared_memory_handler.SharedMemory(size=arr.nbytes)
         else:
             shm = shared_memory.SharedMemory(create=True, size=arr.nbytes)
@@ -347,7 +347,8 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
         cache: bool = False,
     ) -> Tuple[NDArray, NDArray]:
         translation = self.zeros(arr.ndim) if translation is None else translation
-
+
+        center = self.divide(self.to_backend_array(arr.shape) - 1, 2)
         if not use_geometric_center:
             center = self.center_of_mass(arr, cutoff=0)
 
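The (shape - 1) / 2 convention used for the geometric center picks the midpoint of the index grid rather than shape / 2, which differs by half a voxel on even-sized axes. For example:

import numpy as np

shape = np.array([4, 5])
print(np.divide(shape, 2))      # [2.  2.5] - not the symmetric midpoint for the even axis
print(np.divide(shape - 1, 2))  # [1.5 2. ] - midpoint between the first and last index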
tme/backends/pytorch_backend.py
CHANGED
@@ -81,13 +81,13 @@ class PytorchBackend(NumpyFFTWBackend):
 
     def max(self, *args, **kwargs) -> NDArray:
         ret = self._array_backend.amax(*args, **kwargs)
-        if
+        if isinstance(ret, self._array_backend.Tensor):
             return ret
         return ret[0]
 
     def min(self, *args, **kwargs) -> NDArray:
         ret = self._array_backend.amin(*args, **kwargs)
-        if
+        if isinstance(ret, self._array_backend.Tensor):
             return ret
         return ret[0]
 
@@ -154,7 +154,7 @@ class PytorchBackend(NumpyFFTWBackend):
             1, -1
         )
         if unraveled_coords.size(0) == 1:
-            return
+            return (unraveled_coords[0, :],)
 
         else:
             return tuple(unraveled_coords.T)
@@ -206,7 +206,9 @@ class PytorchBackend(NumpyFFTWBackend):
         else:
             raise NotImplementedError("Operation only implemented for 2 and 3D inputs.")
 
-        pool = func(
+        pool = func(
+            kernel_size=min_distance, padding=min_distance // 2, return_indices=True
+        )
         _, indices = pool(score_space.reshape(1, 1, *score_space.shape))
         coordinates = self.unravel_index(indices.reshape(-1), score_space.shape)
         coordinates = self.transpose(self.stack(coordinates))
@@ -217,7 +219,7 @@ class PytorchBackend(NumpyFFTWBackend):
 
     def from_sharedarr(self, args) -> TorchTensor:
         if self.device == "cuda":
-            return args
+            return args
 
         shm, shape, dtype = args
         required_size = int(self._array_backend.prod(self.to_backend_array(shape)))
@@ -235,13 +237,12 @@ class PytorchBackend(NumpyFFTWBackend):
 
         nbytes = arr.numel() * arr.element_size()
 
-        if
+        if isinstance(shared_memory_handler, SharedMemoryManager):
             shm = shared_memory_handler.SharedMemory(size=nbytes)
         else:
             shm = shared_memory.SharedMemory(create=True, size=nbytes)
 
         shm.buf[:nbytes] = arr.numpy().tobytes()
-
         return shm, arr.shape, arr.dtype
 
     def transpose(self, arr):
@@ -415,6 +416,8 @@ class PytorchBackend(NumpyFFTWBackend):
             yield None
 
     def device_count(self) -> int:
+        if self.device == "cpu":
+            return 1
         return self._array_backend.cuda.device_count()
 
     def reverse(self, arr: TorchTensor) -> TorchTensor:
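The expanded pool = func(...) call spells out the pooling parameters used for peak picking. A self-contained 2D sketch of the same max-pool-with-indices pattern (min_distance and the score array are hypothetical):

import torch

def peak_candidates(score_space: torch.Tensor, min_distance: int) -> torch.Tensor:
    pool = torch.nn.MaxPool2d(
        kernel_size=min_distance, padding=min_distance // 2, return_indices=True
    )
    # Indices are flat offsets into score_space; convert them back to 2D coordinates.
    _, indices = pool(score_space.reshape(1, 1, *score_space.shape))
    flat = indices.reshape(-1)
    height, width = score_space.shape
    return torch.stack((flat // width, flat % width), dim=1)

scores = torch.rand(64, 64)
print(peak_candidates(scores, min_distance=8).shape)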
tme/density.py
CHANGED
@@ -116,8 +116,8 @@ class Density:
         response = "Density object at {}\nOrigin: {}, Sampling Rate: {}, Shape: {}"
         return response.format(
             hex(id(self)),
-            tuple(
-            tuple(
+            tuple(round(float(x), 3) for x in self.origin),
+            tuple(round(float(x), 3) for x in self.sampling_rate),
             self.shape,
         )
 
@@ -306,6 +306,10 @@ class Density:
             "std": float(mrc.header.rms),
         }
 
+        non_standard_crs = not np.all(crs_index == (0, 1, 2))
+        if non_standard_crs:
+            warnings.warn("Non standard MAPC, MAPR, MAPS, adapting data and origin.")
+
         if is_gzipped(filename):
             if use_memmap:
                 warnings.warn(
@@ -315,6 +319,10 @@ class Density:
                 use_memmap = False
 
         if subset is not None:
+            subset = tuple(
+                subset[i] if i < len(subset) else slice(0, data_shape[i])
+                for i in crs_index
+            )
             subset_shape = tuple(x.stop - x.start for x in subset)
             if np.allclose(subset_shape, data_shape):
                 return cls._load_mrc(
@@ -328,18 +336,16 @@ class Density:
                 dtype=data_type,
                 header_size=1024 + extended_header,
             )
-
-
-        if not use_memmap:
+        elif subset is None and not use_memmap:
             with mrcfile.open(filename, header_only=False) as mrc:
                 data = mrc.data.astype(np.float32, copy=False)
         else:
             with mrcfile.mrcmemmap.MrcMemmap(filename, header_only=False) as mrc:
                 data = mrc.data
 
-        if
+        if non_standard_crs:
             data = np.transpose(data, crs_index)
-
+            origin = np.take(origin, crs_index)
 
         return data, origin, sampling_rate, metadata
 
@@ -873,6 +879,7 @@ class Density:
         mrc.header.nzstart, mrc.header.nystart, mrc.header.nxstart = np.rint(
             np.divide(self.origin, self.sampling_rate)
         )
+        mrc.header.origin = tuple(x for x in self.origin)
         # mrcfile library expects origin to be in xyz format
         mrc.header.mapc, mrc.header.mapr, mrc.header.maps = (1, 2, 3)
         mrc.header["origin"] = tuple(self.origin[::-1])
@@ -1594,7 +1601,7 @@ class Density:
         rotation_matrix: NDArray,
         translation: NDArray = None,
         order: int = 3,
-        use_geometric_center: bool =
+        use_geometric_center: bool = True,
     ) -> "Density":
         """
         Performs a rigid transform of the class instance.
Binary file
|