pytme 0.1.7__cp311-cp311-macosx_14_0_arm64.whl → 0.1.8__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.1.7.data → pytme-0.1.8.data}/scripts/match_template.py +12 -10
- {pytme-0.1.7.data → pytme-0.1.8.data}/scripts/postprocess.py +38 -12
- {pytme-0.1.7.data → pytme-0.1.8.data}/scripts/preprocessor_gui.py +74 -52
- {pytme-0.1.7.dist-info → pytme-0.1.8.dist-info}/METADATA +10 -8
- {pytme-0.1.7.dist-info → pytme-0.1.8.dist-info}/RECORD +21 -21
- scripts/match_template.py +12 -10
- scripts/postprocess.py +38 -12
- scripts/preprocessor_gui.py +74 -52
- tme/__version__.py +1 -1
- tme/backends/npfftw_backend.py +0 -7
- tme/density.py +17 -8
- tme/matching_data.py +55 -24
- tme/matching_exhaustive.py +67 -19
- tme/matching_utils.py +39 -13
- tme/preprocessor.py +113 -5
- {pytme-0.1.7.data → pytme-0.1.8.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.1.7.data → pytme-0.1.8.data}/scripts/preprocess.py +0 -0
- {pytme-0.1.7.dist-info → pytme-0.1.8.dist-info}/LICENSE +0 -0
- {pytme-0.1.7.dist-info → pytme-0.1.8.dist-info}/WHEEL +0 -0
- {pytme-0.1.7.dist-info → pytme-0.1.8.dist-info}/entry_points.txt +0 -0
- {pytme-0.1.7.dist-info → pytme-0.1.8.dist-info}/top_level.txt +0 -0
scripts/postprocess.py
CHANGED
@@ -15,6 +15,7 @@ from dataclasses import dataclass
|
|
15
15
|
|
16
16
|
import numpy as np
|
17
17
|
from scipy.spatial.transform import Rotation
|
18
|
+
from numpy.typing import NDArray
|
18
19
|
|
19
20
|
from tme import Density, Structure
|
20
21
|
from tme.analyzer import (
|
@@ -28,6 +29,7 @@ from tme.matching_utils import (
|
|
28
29
|
load_pickle,
|
29
30
|
euler_to_rotationmatrix,
|
30
31
|
euler_from_rotationmatrix,
|
32
|
+
centered_mask,
|
31
33
|
)
|
32
34
|
|
33
35
|
PEAK_CALLERS = {
|
@@ -75,6 +77,13 @@ def parse_args():
|
|
75
77
|
help="Minimum distance from target boundaries. Ignored when --orientations "
|
76
78
|
"is provided.",
|
77
79
|
)
|
80
|
+
parser.add_argument(
    "--mask_edges",
    action="store_true",
    default=False,
    # Implicit string concatenation: each fragment ends with a space so the
    # rendered help text reads correctly.
    help="Whether to mask edges of the input score array according to the "
    "template shape. Uses twice the value of --min_boundary_distance if both "
    "are provided.",
)
|
78
87
|
parser.add_argument(
|
79
88
|
"--wedge_mask",
|
80
89
|
type=str,
|
@@ -403,11 +412,25 @@ class Orientations:
|
|
403
412
|
return translation, rotation, score, detail
|
404
413
|
|
405
414
|
|
415
|
+
def load_template(filepath: str, sampling_rate: NDArray) -> "Tuple[Density, NDArray]":
    """
    Load a template from a density map or an atomic structure file.

    Parameters
    ----------
    filepath : str
        Path to the template. It is first read with ``Density.from_file``;
        if that raises ValueError, it is read with ``Structure.from_file``.
    sampling_rate : NDArray
        Sampling rate used to rasterize a structure into a density. Unused
        when ``filepath`` is read as a density.

    Returns
    -------
    Tuple[Density, NDArray]
        The template as a Density and its center of mass. Density inputs are
        centered via ``centered(0)`` before the center of mass is computed.
        Structures are rasterized without centering, and their center of mass
        is reversed (``[::-1]``), presumably to match the Density axis order.

    Notes
    -----
    Structures are converted to a Density, so callers cannot tell from the
    type of the returned object whether the input file was a structure.
    """
    try:
        template = Density.from_file(filepath)
    except ValueError:
        # Not readable as a density map: treat the file as an atomic structure.
        structure = Structure.from_file(filepath)
        center_of_mass = structure.center_of_mass()[::-1]
        template = Density.from_structure(structure, sampling_rate=sampling_rate)
        return template, center_of_mass

    # Kept outside the try so that errors raised here are not mistaken for
    # "file is not a density" and silently re-parsed as a structure.
    template, _ = template.centered(0)
    center_of_mass = template.center_of_mass(template.data)
    return template, center_of_mass
|
426
|
+
|
427
|
+
|
406
428
|
def main():
|
407
429
|
args = parse_args()
|
408
430
|
data = load_pickle(args.input_file)
|
409
431
|
|
410
432
|
meta = data[-1]
|
433
|
+
target_origin, _, sampling_rate, cli_args = meta
|
411
434
|
|
412
435
|
if args.orientations is not None:
|
413
436
|
orientations = Orientations.from_file(
|
@@ -419,6 +442,17 @@ def main():
|
|
419
442
|
# Output is MaxScoreOverRotations
|
420
443
|
if data[0].ndim == data[2].ndim:
|
421
444
|
scores, offset, rotation_array, rotation_mapping, meta = data
|
445
|
+
if args.mask_edges:
|
446
|
+
template, center_of_mass = load_template(
|
447
|
+
cli_args.template, sampling_rate=sampling_rate
|
448
|
+
)
|
449
|
+
if not cli_args.no_centering:
|
450
|
+
template, *_ = template.centered(0)
|
451
|
+
mask_size = template.shape
|
452
|
+
if args.min_boundary_distance > 0:
|
453
|
+
mask_size = 2 * args.min_boundary_distance
|
454
|
+
scores = centered_mask(scores, np.subtract(scores.shape, mask_size) + 1)
|
455
|
+
|
422
456
|
peak_caller = PEAK_CALLERS[args.peak_caller](
|
423
457
|
number_of_peaks=args.number_of_peaks,
|
424
458
|
min_distance=args.min_distance,
|
@@ -458,19 +492,11 @@ def main():
|
|
458
492
|
orientations.to_file(filename=f"{args.output_prefix}.tsv", file_format="text")
|
459
493
|
exit(0)
|
460
494
|
|
461
|
-
target_origin, _, sampling_rate, cli_args = meta
|
462
|
-
|
463
|
-
template_is_density, index = True, 0
|
464
495
|
_, template_extension = splitext(cli_args.template)
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
except ValueError:
|
470
|
-
template_is_density = False
|
471
|
-
template = Structure.from_file(cli_args.template)
|
472
|
-
center_of_mass = template.center_of_mass()[::-1]
|
473
|
-
template = Density.from_structure(template, sampling_rate=sampling_rate)
|
496
|
+
template, center_of_mass = load_template(
|
497
|
+
filepath=cli_args.template, sampling_rate=sampling_rate
|
498
|
+
)
|
499
|
+
template_is_density, index = isinstance(template, Density), 0
|
474
500
|
|
475
501
|
if args.output_format == "relion":
|
476
502
|
new_shape = np.add(template.shape, np.mod(template.shape, 2))
|
scripts/preprocessor_gui.py
CHANGED
@@ -29,17 +29,17 @@ preprocessor = Preprocessor()
|
|
29
29
|
SLIDER_MIN, SLIDER_MAX = 0, 25
|
30
30
|
|
31
31
|
|
32
|
-
def gaussian_filter(template, sigma: float, **kwargs: dict):
|
32
|
+
def gaussian_filter(template: NDArray, sigma: float, **kwargs: dict) -> NDArray:
|
33
33
|
return preprocessor.gaussian_filter(template=template, sigma=sigma, **kwargs)
|
34
34
|
|
35
35
|
|
36
36
|
def bandpass_filter(
|
37
|
-
template,
|
37
|
+
template: NDArray,
|
38
38
|
minimum_frequency: float,
|
39
39
|
maximum_frequency: float,
|
40
40
|
gaussian_sigma: float,
|
41
41
|
**kwargs: dict,
|
42
|
-
):
|
42
|
+
) -> NDArray:
|
43
43
|
return preprocessor.bandpass_filter(
|
44
44
|
template=template,
|
45
45
|
minimum_frequency=minimum_frequency,
|
@@ -51,8 +51,8 @@ def bandpass_filter(
|
|
51
51
|
|
52
52
|
|
53
53
|
def difference_of_gaussian_filter(
|
54
|
-
template, sigmas: Tuple[float, float], **kwargs: dict
|
55
|
-
):
|
54
|
+
template: NDArray, sigmas: Tuple[float, float], **kwargs: dict
|
55
|
+
) -> NDArray:
|
56
56
|
low_sigma, high_sigma = sigmas
|
57
57
|
return preprocessor.difference_of_gaussian_filter(
|
58
58
|
template=template, low_sigma=low_sigma, high_sigma=high_sigma, **kwargs
|
@@ -60,7 +60,7 @@ def difference_of_gaussian_filter(
|
|
60
60
|
|
61
61
|
|
62
62
|
def edge_gaussian_filter(
|
63
|
-
template,
|
63
|
+
template: NDArray,
|
64
64
|
sigma: float,
|
65
65
|
edge_algorithm: Annotated[
|
66
66
|
str,
|
@@ -68,7 +68,7 @@ def edge_gaussian_filter(
|
|
68
68
|
],
|
69
69
|
reverse: bool = False,
|
70
70
|
**kwargs: dict,
|
71
|
-
):
|
71
|
+
) -> NDArray:
|
72
72
|
return preprocessor.edge_gaussian_filter(
|
73
73
|
template=template,
|
74
74
|
sigma=sigma,
|
@@ -78,13 +78,13 @@ def edge_gaussian_filter(
|
|
78
78
|
|
79
79
|
|
80
80
|
def local_gaussian_filter(
|
81
|
-
template,
|
81
|
+
template: NDArray,
|
82
82
|
lbd: float,
|
83
83
|
sigma_range: Tuple[float, float],
|
84
84
|
gaussian_sigma: float,
|
85
85
|
reverse: bool = False,
|
86
86
|
**kwargs: dict,
|
87
|
-
):
|
87
|
+
) -> NDArray:
|
88
88
|
return preprocessor.local_gaussian_filter(
|
89
89
|
template=template,
|
90
90
|
lbd=lbd,
|
@@ -94,22 +94,30 @@ def local_gaussian_filter(
|
|
94
94
|
|
95
95
|
|
96
96
|
def ntree(
|
97
|
-
template,
|
97
|
+
template: NDArray,
|
98
98
|
sigma_range: Tuple[float, float],
|
99
99
|
**kwargs: dict,
|
100
|
-
):
|
100
|
+
) -> NDArray:
|
101
101
|
return preprocessor.ntree_filter(template=template, sigma_range=sigma_range)
|
102
102
|
|
103
103
|
|
104
104
|
def mean(
|
105
|
-
template,
|
105
|
+
template: NDArray,
|
106
106
|
width: int,
|
107
107
|
**kwargs: dict,
|
108
|
-
):
|
108
|
+
) -> NDArray:
|
109
109
|
return preprocessor.mean_filter(template=template, width=width)
|
110
110
|
|
111
111
|
|
112
|
-
def resolution_sphere(
|
112
|
+
def resolution_sphere(
|
113
|
+
template: NDArray,
|
114
|
+
cutoff_angstrom: float,
|
115
|
+
highpass: bool = False,
|
116
|
+
sampling_rate=None,
|
117
|
+
) -> NDArray:
|
118
|
+
if cutoff_angstrom == 0:
|
119
|
+
return template
|
120
|
+
|
113
121
|
cutoff_frequency = np.max(2 * sampling_rate / cutoff_angstrom)
|
114
122
|
|
115
123
|
min_freq, max_freq = 0, cutoff_frequency
|
@@ -123,34 +131,39 @@ def resolution_sphere(template: NDArray, cutoff_angstrom: float, highpass : bool
|
|
123
131
|
omit_negative_frequencies=False,
|
124
132
|
)
|
125
133
|
|
126
|
-
mask = np.fft.ifftshift(mask)
|
127
134
|
template_ft = np.fft.fftn(template)
|
128
|
-
np.multiply(template_ft, mask, out
|
129
|
-
|
135
|
+
np.multiply(template_ft, mask, out=template_ft)
|
130
136
|
return np.fft.ifftn(template_ft).real
|
131
137
|
|
132
138
|
|
133
|
-
def resolution_gaussian(
|
134
|
-
|
139
|
+
def resolution_gaussian(
    template: NDArray,
    cutoff_angstrom: float,
    highpass: bool = False,
    sampling_rate=None,
) -> NDArray:
    """
    Apply a Gaussian low- or high-pass filter defined by a resolution cutoff.

    Parameters
    ----------
    template : NDArray
        Input array to filter.
    cutoff_angstrom : float
        Resolution cutoff. A value of zero returns ``template`` unchanged.
    highpass : bool, optional
        If True, use the complement ``1 - mask`` of the Gaussian (high-pass).
    sampling_rate : array_like, optional
        Sampling rate per axis. Defaults to one for every axis of ``template``.

    Returns
    -------
    NDArray
        Real part of the inverse Fourier transform of the filtered spectrum.
    """
    if cutoff_angstrom == 0:
        return template

    # Fall back to unit sampling; array conversion allows list/tuple input.
    if sampling_rate is None:
        sampling_rate = np.ones(template.ndim)
    sampling_rate = np.asarray(sampling_rate, dtype=float)

    grid = preprocessor.fftfreqn(
        shape=template.shape, sampling_rate=sampling_rate / sampling_rate.max()
    )

    # exp(-f^2 / (2 sigma^2)) equals 1/2 at f = sigma * sqrt(2 ln 2), so this
    # sigma places the half maximum at the cutoff frequency.
    sigma_fourier = np.divide(
        np.max(2 * sampling_rate / cutoff_angstrom), np.sqrt(2 * np.log(2))
    )

    mask = np.exp(-(grid**2) / (2 * sigma_fourier**2))
    if highpass:
        mask = 1 - mask

    # Move the zero frequency to index 0 to match the unshifted fftn layout.
    mask = np.fft.ifftshift(mask)

    template_ft = np.fft.fftn(template)
    np.multiply(template_ft, mask, out=template_ft)
    return np.fft.ifftn(template_ft).real
|
153
165
|
|
166
|
+
|
154
167
|
def wedge(
|
155
168
|
template: NDArray,
|
156
169
|
tilt_start: float,
|
@@ -162,7 +175,7 @@ def wedge(
|
|
162
175
|
omit_negative_frequencies: bool = True,
|
163
176
|
extrude_plane: bool = True,
|
164
177
|
infinite_plane: bool = True,
|
165
|
-
):
|
178
|
+
) -> NDArray:
|
166
179
|
template_ft = np.fft.rfftn(template)
|
167
180
|
|
168
181
|
if tilt_step <= 0:
|
@@ -181,15 +194,14 @@ def wedge(
|
|
181
194
|
template = np.real(np.fft.irfftn(template_ft))
|
182
195
|
return template
|
183
196
|
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
197
|
+
wedge_mask = preprocessor.step_wedge_mask(
|
198
|
+
start_tilt=tilt_start,
|
199
|
+
stop_tilt=tilt_stop,
|
200
|
+
tilt_axis=tilt_axis,
|
201
|
+
tilt_step=tilt_step,
|
202
|
+
opening_axis=opening_axis,
|
190
203
|
shape=template.shape,
|
191
204
|
sigma=gaussian_sigma,
|
192
|
-
opening_axes=opening_axis,
|
193
205
|
omit_negative_frequencies=omit_negative_frequencies,
|
194
206
|
)
|
195
207
|
np.multiply(template_ft, wedge_mask, out=template_ft)
|
@@ -197,7 +209,7 @@ def wedge(
|
|
197
209
|
return template
|
198
210
|
|
199
211
|
|
200
|
-
def compute_power_spectrum(template: NDArray):
|
212
|
+
def compute_power_spectrum(template: NDArray) -> NDArray:
    """Return the log-magnitude Fourier spectrum of ``template``, zero frequency centered."""
    spectrum = np.fft.fftn(template)
    log_magnitude = np.log(np.abs(spectrum))
    return np.fft.fftshift(log_magnitude)
|
202
214
|
|
203
215
|
|
@@ -262,8 +274,8 @@ WRAPPED_FUNCTIONS = {
|
|
262
274
|
"mean_filter": mean,
|
263
275
|
"wedge_filter": wedge,
|
264
276
|
"power_spectrum": compute_power_spectrum,
|
265
|
-
"resolution_gaussian"
|
266
|
-
"resolution_sphere"
|
277
|
+
"resolution_gaussian": resolution_gaussian,
|
278
|
+
"resolution_sphere": resolution_sphere,
|
267
279
|
}
|
268
280
|
|
269
281
|
EXCLUDED_FUNCTIONS = [
|
@@ -405,7 +417,7 @@ class FilterWidget(widgets.Container):
|
|
405
417
|
|
406
418
|
def sphere_mask(
|
407
419
|
template: NDArray, center_x: float, center_y: float, center_z: float, radius: float
|
408
|
-
):
|
420
|
+
) -> NDArray:
|
409
421
|
return create_mask(
|
410
422
|
mask_type="ellipse",
|
411
423
|
shape=template.shape,
|
@@ -422,7 +434,7 @@ def ellipsod_mask(
|
|
422
434
|
radius_x: float,
|
423
435
|
radius_y: float,
|
424
436
|
radius_z: float,
|
425
|
-
):
|
437
|
+
) -> NDArray:
|
426
438
|
return create_mask(
|
427
439
|
mask_type="ellipse",
|
428
440
|
shape=template.shape,
|
@@ -439,7 +451,7 @@ def box_mask(
|
|
439
451
|
height_x: int,
|
440
452
|
height_y: int,
|
441
453
|
height_z: int,
|
442
|
-
):
|
454
|
+
) -> NDArray:
|
443
455
|
return create_mask(
|
444
456
|
mask_type="box",
|
445
457
|
shape=template.shape,
|
@@ -457,7 +469,7 @@ def tube_mask(
|
|
457
469
|
inner_radius: float,
|
458
470
|
outer_radius: float,
|
459
471
|
height: int,
|
460
|
-
):
|
472
|
+
) -> NDArray:
|
461
473
|
return create_mask(
|
462
474
|
mask_type="tube",
|
463
475
|
shape=template.shape,
|
@@ -480,7 +492,7 @@ def wedge_mask(
|
|
480
492
|
omit_negative_frequencies: bool = False,
|
481
493
|
extrude_plane: bool = True,
|
482
494
|
infinite_plane: bool = True,
|
483
|
-
):
|
495
|
+
) -> NDArray:
|
484
496
|
if tilt_step <= 0:
|
485
497
|
wedge_mask = preprocessor.continuous_wedge_mask(
|
486
498
|
start_tilt=tilt_start,
|
@@ -496,25 +508,24 @@ def wedge_mask(
|
|
496
508
|
wedge_mask = np.fft.fftshift(wedge_mask)
|
497
509
|
return wedge_mask
|
498
510
|
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
tilt_angles=angles,
|
511
|
+
wedge_mask = preprocessor.step_wedge_mask(
|
512
|
+
start_tilt=tilt_start,
|
513
|
+
stop_tilt=tilt_stop,
|
514
|
+
tilt_axis=tilt_axis,
|
515
|
+
tilt_step=tilt_step,
|
516
|
+
opening_axis=opening_axis,
|
506
517
|
shape=template.shape,
|
507
518
|
sigma=gaussian_sigma,
|
508
|
-
opening_axes=opening_axis,
|
509
519
|
omit_negative_frequencies=omit_negative_frequencies,
|
510
520
|
)
|
521
|
+
|
511
522
|
wedge_mask = np.fft.fftshift(wedge_mask)
|
512
523
|
return wedge_mask
|
513
524
|
|
514
525
|
|
515
526
|
def threshold_mask(
|
516
527
|
template: NDArray, standard_deviation: float = 5.0, invert: bool = False
|
517
|
-
):
|
528
|
+
) -> NDArray:
|
518
529
|
template_mean = template.mean()
|
519
530
|
template_deviation = standard_deviation * template.std()
|
520
531
|
upper = template_mean + template_deviation
|
@@ -526,6 +537,15 @@ def threshold_mask(
|
|
526
537
|
return mask
|
527
538
|
|
528
539
|
|
540
|
+
def lowpass_mask(template: NDArray, sigma: float = 1.0) -> NDArray:
    """
    Create a smoothed binary mask around the high-intensity region of a template.

    The template is scaled by its maximum and voxels above exp(-2) are set to
    128. The result is Gaussian-filtered with ``sigma`` and thresholded at
    exp(-2) again. Because of the large value relative to the threshold, this
    extends the selection beyond its original boundary by an amount that
    grows with ``sigma``.

    Parameters
    ----------
    template : NDArray
        Input array.
    sigma : float, optional
        Standard deviation passed to the Gaussian filter.

    Returns
    -------
    NDArray
        Boolean mask with the shape of ``template``.
    """
    template_max = template.max()
    if template_max == 0:
        # 0 / 0 would produce NaN (and a RuntimeWarning) and an empty mask anyway.
        return np.zeros(template.shape, dtype=bool)

    scaled = template / template_max
    seed = (scaled > np.exp(-2)) * 128.0
    smoothed = preprocessor.gaussian_filter(template=seed, sigma=sigma)
    return smoothed > np.exp(-2)
|
547
|
+
|
548
|
+
|
529
549
|
class MaskWidget(widgets.Container):
|
530
550
|
def __init__(self, viewer):
|
531
551
|
super().__init__(layout="vertical")
|
@@ -543,6 +563,7 @@ class MaskWidget(widgets.Container):
|
|
543
563
|
"Box": box_mask,
|
544
564
|
"Wedge": wedge_mask,
|
545
565
|
"Threshold": threshold_mask,
|
566
|
+
"Lowpass": lowpass_mask,
|
546
567
|
}
|
547
568
|
|
548
569
|
self.method_dropdown = widgets.ComboBox(
|
@@ -606,6 +627,7 @@ class MaskWidget(widgets.Container):
|
|
606
627
|
arr=active_layer.data,
|
607
628
|
rotation_matrix=rotation_matrix,
|
608
629
|
use_geometric_center=False,
|
630
|
+
order=1,
|
609
631
|
)
|
610
632
|
eps = np.finfo(rotated_data.dtype).eps
|
611
633
|
rotated_data[rotated_data < eps] = 0
|
@@ -636,10 +658,10 @@ class MaskWidget(widgets.Container):
|
|
636
658
|
dict(zip(["height_x", "height_y", "height_z"], coordinates_heights))
|
637
659
|
)
|
638
660
|
|
639
|
-
defaults["radius"] = np.
|
661
|
+
defaults["radius"] = np.max(coordinate_radius)
|
640
662
|
defaults["inner_radius"] = np.min(coordinate_radius)
|
641
663
|
defaults["outer_radius"] = np.max(coordinate_radius)
|
642
|
-
defaults["height"] =
|
664
|
+
defaults["height"] = np.max(coordinates_heights)
|
643
665
|
|
644
666
|
for widget in self.action_widgets:
|
645
667
|
if widget.name in defaults:
|
tme/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.1.
|
1
|
+
__version__ = "0.1.8"
|
tme/backends/npfftw_backend.py
CHANGED
@@ -613,13 +613,6 @@ class NumpyFFTWBackend(MatchingBackend):
|
|
613
613
|
rotate_mask = arr_mask is not None and mask_coordinates is not None
|
614
614
|
return_type = (out is None) + 2 * rotate_mask * (out_mask is None)
|
615
615
|
|
616
|
-
# Otherwise array might be slightly shifted by centering
|
617
|
-
if np.allclose(
|
618
|
-
rotation_matrix,
|
619
|
-
np.eye(rotation_matrix.shape[0], dtype=rotation_matrix.dtype),
|
620
|
-
):
|
621
|
-
center_rotation = False
|
622
|
-
|
623
616
|
coordinates_rotated = np.empty(coordinates.shape, dtype=rotation_matrix.dtype)
|
624
617
|
mask_rotated = (
|
625
618
|
np.empty(mask_coordinates.shape, dtype=rotation_matrix.dtype)
|
tme/density.py
CHANGED
@@ -199,16 +199,16 @@ class Density:
|
|
199
199
|
func = cls._load_mrc
|
200
200
|
if filename.endswith(".em") or filename.endswith(".em.gz"):
|
201
201
|
func = cls._load_em
|
202
|
-
data, origin, sampling_rate = func(
|
202
|
+
data, origin, sampling_rate, meta = func(
|
203
203
|
filename=filename, subset=subset, use_memmap=use_memmap
|
204
204
|
)
|
205
205
|
except ValueError:
|
206
|
-
data, origin, sampling_rate = cls._load_skio(filename=filename)
|
206
|
+
data, origin, sampling_rate, meta = cls._load_skio(filename=filename)
|
207
207
|
if subset is not None:
|
208
208
|
cls._validate_slices(slices=subset, shape=data.shape)
|
209
209
|
data = data[subset].copy()
|
210
210
|
|
211
|
-
return cls(data=data, origin=origin, sampling_rate=sampling_rate)
|
211
|
+
return cls(data=data, origin=origin, sampling_rate=sampling_rate, metadata=meta)
|
212
212
|
|
213
213
|
@classmethod
|
214
214
|
def _load_mrc(
|
@@ -307,6 +307,13 @@ class Density:
|
|
307
307
|
|
308
308
|
extended_header = mrc.header.nsymbt
|
309
309
|
|
310
|
+
metadata = {
|
311
|
+
"min": float(mrc.header.dmin),
|
312
|
+
"max": float(mrc.header.dmax),
|
313
|
+
"mean": float(mrc.header.dmean),
|
314
|
+
"std": float(mrc.header.rms),
|
315
|
+
}
|
316
|
+
|
310
317
|
if is_gzipped(filename):
|
311
318
|
if use_memmap:
|
312
319
|
warnings.warn(
|
@@ -329,7 +336,7 @@ class Density:
|
|
329
336
|
dtype=data_type,
|
330
337
|
header_size=1024 + extended_header,
|
331
338
|
)
|
332
|
-
return data, origin, sampling_rate
|
339
|
+
return data, origin, sampling_rate, metadata
|
333
340
|
|
334
341
|
if not use_memmap:
|
335
342
|
with mrcfile.open(filename, header_only=False) as mrc:
|
@@ -343,7 +350,7 @@ class Density:
|
|
343
350
|
data = np.transpose(data, crs_index)
|
344
351
|
start = np.take(start, crs_index)
|
345
352
|
|
346
|
-
return data, origin, sampling_rate
|
353
|
+
return data, origin, sampling_rate, metadata
|
347
354
|
|
348
355
|
@classmethod
|
349
356
|
def _load_em(
|
@@ -448,7 +455,7 @@ class Density:
|
|
448
455
|
pixel_size = 1
|
449
456
|
sampling_rate = np.repeat(pixel_size, data.ndim).astype(data.dtype)
|
450
457
|
|
451
|
-
return data, origin, sampling_rate
|
458
|
+
return data, origin, sampling_rate, {}
|
452
459
|
|
453
460
|
@staticmethod
|
454
461
|
def _validate_slices(slices: Tuple[slice], shape: Tuple[int]):
|
@@ -592,7 +599,7 @@ class Density:
|
|
592
599
|
warnings.warn(
|
593
600
|
"origin and sampling_rate are not yet extracted from non CCP4/MRC files."
|
594
601
|
)
|
595
|
-
return data, np.zeros(data.ndim), np.ones(data.ndim)
|
602
|
+
return data, np.zeros(data.ndim), np.ones(data.ndim), {}
|
596
603
|
|
597
604
|
@classmethod
|
598
605
|
def from_structure(
|
@@ -909,7 +916,8 @@ class Density:
|
|
909
916
|
Returns a copy of the current class instance with all elements in
|
910
917
|
:py:attr:`Density.data` set to zero. :py:attr:`Density.origin` and
|
911
918
|
:py:attr:`Density.sampling_rate` will be copied, while
|
912
|
-
:py:attr:`Density.metadata` will be initialized to
|
919
|
+
:py:attr:`Density.metadata` will be initialized to contain min, max,
|
920
|
+
mean and standard deviation of :py:attr:`Density.data`.
|
913
921
|
|
914
922
|
Examples
|
915
923
|
--------
|
@@ -924,6 +932,7 @@ class Density:
|
|
924
932
|
data=np.zeros_like(self.data),
|
925
933
|
origin=deepcopy(self.origin),
|
926
934
|
sampling_rate=deepcopy(self.sampling_rate),
|
935
|
+
metadata={"min": 0, "max": 0, "mean": 0, "std": 0},
|
927
936
|
)
|
928
937
|
|
929
938
|
def copy(self) -> "Density":
|
tme/matching_data.py
CHANGED
@@ -4,7 +4,7 @@
|
|
4
4
|
|
5
5
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
6
|
"""
|
7
|
-
|
7
|
+
import warnings
|
8
8
|
from typing import Tuple, List
|
9
9
|
|
10
10
|
import numpy as np
|
@@ -81,7 +81,11 @@ class MatchingData:
|
|
81
81
|
return arr
|
82
82
|
|
83
83
|
def subset_array(
|
84
|
-
self,
|
84
|
+
self,
|
85
|
+
arr: NDArray,
|
86
|
+
arr_slice: Tuple[slice],
|
87
|
+
padding: NDArray,
|
88
|
+
invert: bool = False,
|
85
89
|
) -> NDArray:
|
86
90
|
"""
|
87
91
|
Extract a subset of the input array according to the given slice and
|
@@ -96,8 +100,11 @@ class MatchingData:
|
|
96
100
|
padding : NDArray
|
97
101
|
Padding values for each dimension. If the padding exceeds the array
|
98
102
|
dimensions, the extra regions are filled with the mean of the array
|
99
|
-
values, otherwise, the
|
100
|
-
|
103
|
+
values, otherwise, the values in ``arr`` are used.
|
104
|
+
invert : bool, optional
|
105
|
+
Whether the returned array should be inverted and normalized to the interval
|
106
|
+
[0, 1]. If available, uses the metadata information of the Density object,
|
107
|
+
otherwise computes min and max on the extracted subset.
|
101
108
|
|
102
109
|
Returns
|
103
110
|
-------
|
@@ -109,7 +116,6 @@ class MatchingData:
|
|
109
116
|
|
110
117
|
slice_start = np.array([x.start for x in arr_slice], dtype=int)
|
111
118
|
slice_stop = np.array([x.stop for x in arr_slice], dtype=int)
|
112
|
-
slice_shape = np.subtract(slice_stop, slice_start)
|
113
119
|
|
114
120
|
padding = np.add(padding, np.mod(padding, 2))
|
115
121
|
left_pad = right_pad = np.divide(padding, 2).astype(int)
|
@@ -119,20 +125,18 @@ class MatchingData:
|
|
119
125
|
np.subtract(arr.shape, slice_stop), right_pad
|
120
126
|
).astype(int)
|
121
127
|
|
122
|
-
ret_shape = np.add(slice_shape, padding)
|
123
128
|
arr_start = np.subtract(slice_start, data_voxels_left)
|
124
129
|
arr_stop = np.add(slice_stop, data_voxels_right)
|
125
130
|
arr_slice = tuple(slice(*pos) for pos in zip(arr_start, arr_stop))
|
126
131
|
arr_mesh = self._slice_to_mesh(arr_slice, arr.shape)
|
127
132
|
|
128
|
-
|
129
|
-
subset_stop = np.add(subset_start, np.subtract(arr_stop, arr_start))
|
130
|
-
subset_slice = tuple(slice(*prod) for prod in zip(subset_start, subset_stop))
|
131
|
-
subset_mesh = self._slice_to_mesh(subset_slice, ret_shape)
|
132
|
-
|
133
|
+
arr_min, arr_max = None, None
|
133
134
|
if type(arr) == Density:
|
134
135
|
if type(arr.data) == np.memmap:
|
135
|
-
|
136
|
+
dens = Density.from_file(arr.data.filename, subset=arr_slice)
|
137
|
+
arr = dens.data
|
138
|
+
arr_min = dens.metadata.get("min", None)
|
139
|
+
arr_max = dens.metadata.get("max", None)
|
136
140
|
else:
|
137
141
|
arr = np.asarray(arr.data[*arr_mesh])
|
138
142
|
else:
|
@@ -141,10 +145,38 @@ class MatchingData:
|
|
141
145
|
arr.filename, mode="r", shape=arr.shape, dtype=arr.dtype
|
142
146
|
)
|
143
147
|
arr = np.asarray(arr[*arr_mesh])
|
144
|
-
|
145
|
-
|
148
|
+
|
149
|
+
def _warn_on_mismatch(
|
150
|
+
expectation: float, computation: float, name: str
|
151
|
+
) -> float:
|
152
|
+
expectation, computation = float(expectation), float(computation)
|
153
|
+
if expectation is None:
|
154
|
+
expectation = computation
|
155
|
+
|
156
|
+
if abs(computation) > abs(expectation):
|
157
|
+
warnings.warn(
|
158
|
+
f"Computed {name} value is more extreme than value specified in file"
|
159
|
+
f" (|{computation}| > |{expectation}|). This may lead to issues"
|
160
|
+
" with padding and contrast inversion."
|
161
|
+
)
|
162
|
+
|
163
|
+
return expectation
|
164
|
+
|
165
|
+
padding = tuple(
|
166
|
+
(left, right)
|
167
|
+
for left, right in zip(
|
168
|
+
np.subtract(left_pad, data_voxels_left),
|
169
|
+
np.subtract(right_pad, data_voxels_right),
|
170
|
+
)
|
146
171
|
)
|
147
|
-
ret
|
172
|
+
ret = np.pad(arr, padding, mode="reflect")
|
173
|
+
|
174
|
+
if invert:
    arr_min = _warn_on_mismatch(arr_min, arr.min(), "min")
    arr_max = _warn_on_mismatch(arr_max, arr.max(), "max")

    # The in-place ufuncs below write fractional values into ``ret``; integer
    # maps (e.g. MRC mode 0) would otherwise fail the same_kind cast.
    if not np.issubdtype(ret.dtype, np.floating):
        ret = ret.astype(np.float32)

    # Invert contrast and rescale to [0, 1]: arr_max -> 0, arr_min -> 1,
    # i.e. (arr_max - x) / (arr_max - arr_min).
    np.subtract(arr_max, ret, out=ret)
    value_range = arr_max - arr_min
    if value_range != 0:
        np.divide(ret, value_range, out=ret)
|
148
180
|
|
149
181
|
return ret
|
150
182
|
|
@@ -198,12 +230,11 @@ class MatchingData:
|
|
198
230
|
)
|
199
231
|
|
200
232
|
target_subset = self.subset_array(
|
201
|
-
arr=self._target,
|
233
|
+
arr=self._target,
|
234
|
+
arr_slice=target_slice,
|
235
|
+
padding=target_pad,
|
236
|
+
invert=self._invert_target,
|
202
237
|
)
|
203
|
-
if self._invert_target:
|
204
|
-
target_subset *= -1
|
205
|
-
target_min, target_max = target_subset.min(), target_subset.max()
|
206
|
-
target_subset = (target_subset - target_min) / (target_max - target_min)
|
207
238
|
|
208
239
|
template_subset = self.subset_array(
|
209
240
|
arr=self._template,
|
@@ -488,14 +519,14 @@ class MatchingData:
|
|
488
519
|
@property
|
489
520
|
def target(self):
|
490
521
|
"""Returns the target NDArray."""
|
491
|
-
if
|
522
|
+
if isinstance(self._target, Density):
|
492
523
|
return self._target.data
|
493
524
|
return self._target
|
494
525
|
|
495
526
|
@property
|
496
527
|
def template(self):
|
497
528
|
"""Returns the reversed template NDArray."""
|
498
|
-
if
|
529
|
+
if isinstance(self._template, Density):
|
499
530
|
return backend.reverse(self._template.data)
|
500
531
|
return backend.reverse(self._template)
|
501
532
|
|
@@ -524,7 +555,7 @@ class MatchingData:
|
|
524
555
|
@property
|
525
556
|
def target_mask(self):
|
526
557
|
"""Returns the target mask NDArray."""
|
527
|
-
if
|
558
|
+
if isinstance(self._target_mask, Density):
|
528
559
|
return self._target_mask.data
|
529
560
|
return self._target_mask
|
530
561
|
|
@@ -553,7 +584,7 @@ class MatchingData:
|
|
553
584
|
template : NDArray
|
554
585
|
Array to set as the template.
|
555
586
|
"""
|
556
|
-
if
|
587
|
+
if isinstance(self._template_mask, Density):
|
557
588
|
return backend.reverse(self._template_mask.data)
|
558
589
|
return backend.reverse(self._template_mask)
|
559
590
|
|