pytme 0.1.1__tar.gz → 0.1.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.1.1 → pytme-0.1.3}/PKG-INFO +2 -2
- {pytme-0.1.1 → pytme-0.1.3}/pyproject.toml +2 -2
- {pytme-0.1.1 → pytme-0.1.3}/pytme.egg-info/SOURCES.txt +1 -0
- {pytme-0.1.1 → pytme-0.1.3}/scripts/match_template.py +10 -9
- {pytme-0.1.1 → pytme-0.1.3}/scripts/postprocess.py +6 -3
- {pytme-0.1.1 → pytme-0.1.3}/scripts/preprocessor_gui.py +93 -14
- pytme-0.1.3/tme/__version__.py +1 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/analyzer.py +2 -2
- {pytme-0.1.1 → pytme-0.1.3}/tme/backends/pytorch_backend.py +7 -4
- {pytme-0.1.1 → pytme-0.1.3}/tme/density.py +22 -13
- {pytme-0.1.1 → pytme-0.1.3}/tme/matching_data.py +2 -4
- {pytme-0.1.1 → pytme-0.1.3}/tme/matching_exhaustive.py +0 -2
- {pytme-0.1.1 → pytme-0.1.3}/tme/matching_memory.py +1 -1
- {pytme-0.1.1 → pytme-0.1.3}/tme/matching_optimization.py +5 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/matching_utils.py +7 -1
- {pytme-0.1.1 → pytme-0.1.3}/tme/preprocessor.py +62 -8
- pytme-0.1.3/tme/scoring.py +679 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/structure.py +19 -20
- pytme-0.1.1/tme/__version__.py +0 -1
- {pytme-0.1.1 → pytme-0.1.3}/LICENSE +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/MANIFEST.in +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/README.md +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/scripts/preprocess.py +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/setup.cfg +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/setup.py +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/src/extensions.cpp +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/__init__.py +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/backends/__init__.py +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/backends/cupy_backend.py +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/backends/matching_backend.py +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/backends/npfftw_backend.py +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/data/__init__.py +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48n309.npy +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48n527.npy +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48n9.npy +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u1.npy +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u1153.npy +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u1201.npy +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u1641.npy +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u181.npy +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u2219.npy +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u27.npy +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u2947.npy +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u3733.npy +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u4749.npy +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u5879.npy +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u7111.npy +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u815.npy +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u83.npy +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u8649.npy +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/data/c600v.npy +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/data/c600vc.npy +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/data/metadata.yaml +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/data/quat_to_numpy.py +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/helpers.py +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/parser.py +0 -0
- {pytme-0.1.1 → pytme-0.1.3}/tme/types.py +0 -0
@@ -1,11 +1,11 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: pytme
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.3
|
4
4
|
Summary: Python Template Matching Engine
|
5
5
|
Author: Valentin Maurer
|
6
6
|
Author-email: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
7
|
License: Proprietary
|
8
|
-
Project-URL: Homepage, https://
|
8
|
+
Project-URL: Homepage, https://github.com/KosinskiLab/pyTME
|
9
9
|
Classifier: Programming Language :: Python :: 3
|
10
10
|
Classifier: Operating System :: OS Independent
|
11
11
|
Requires-Python: >=3.11
|
@@ -7,7 +7,7 @@ name="pytme"
|
|
7
7
|
authors = [
|
8
8
|
{ name = "Valentin Maurer", email = "valentin.maurer@embl-hamburg.de" },
|
9
9
|
]
|
10
|
-
version="0.1.
|
10
|
+
version="0.1.3"
|
11
11
|
description="Python Template Matching Engine"
|
12
12
|
readme="README.md"
|
13
13
|
requires-python = ">=3.11"
|
@@ -38,7 +38,7 @@ preprocess = "tme.scripts:preprocess"
|
|
38
38
|
postprocess = "tme.scripts:postprocess"
|
39
39
|
|
40
40
|
[project.urls]
|
41
|
-
"Homepage" = "https://
|
41
|
+
"Homepage" = "https://github.com/KosinskiLab/pyTME"
|
42
42
|
|
43
43
|
[tool.setuptools]
|
44
44
|
include-package-data = true
|
@@ -43,7 +43,7 @@ def print_block(name: str, data: dict, label_width=20) -> None:
|
|
43
43
|
"""Prints a formatted block of information."""
|
44
44
|
print(f"\n> {name}")
|
45
45
|
for key, value in data.items():
|
46
|
-
formatted_value = str(value)
|
46
|
+
formatted_value = str(value)
|
47
47
|
print(f" - {key + ':':<{label_width}} {formatted_value}")
|
48
48
|
|
49
49
|
|
@@ -426,7 +426,7 @@ def main():
|
|
426
426
|
|
427
427
|
if not np.allclose(target.sampling_rate, template.sampling_rate):
|
428
428
|
print(
|
429
|
-
f"Resampling template to {target.sampling_rate}."
|
429
|
+
f"Resampling template to {target.sampling_rate}. "
|
430
430
|
"Consider providing a template with the same sampling rate as the target."
|
431
431
|
)
|
432
432
|
template = template.resample(target.sampling_rate, order=3)
|
@@ -506,7 +506,7 @@ def main():
|
|
506
506
|
-tilt_start, tilt_stop + args.tilt_step, args.tilt_step
|
507
507
|
)
|
508
508
|
angles = np.zeros((template.data.ndim, tilt_angles.size))
|
509
|
-
angles[
|
509
|
+
angles[2, :] = tilt_angles
|
510
510
|
template_filter["wedge_mask"] = {
|
511
511
|
"tilt_angles": angles,
|
512
512
|
"sigma": args.wedge_smooth,
|
@@ -516,6 +516,7 @@ def main():
|
|
516
516
|
"start_tilt": tilt_start,
|
517
517
|
"stop_tilt": tilt_stop,
|
518
518
|
"tilt_axis": 1,
|
519
|
+
"infinite_plane": True,
|
519
520
|
"sigma": args.wedge_smooth,
|
520
521
|
}
|
521
522
|
|
@@ -597,6 +598,10 @@ def main():
|
|
597
598
|
if not args.pad_fourier:
|
598
599
|
template_box = np.ones(len(template_box), dtype=int)
|
599
600
|
|
601
|
+
callback_class = MaxScoreOverRotations
|
602
|
+
if args.peak_calling:
|
603
|
+
callback_class = PeakCallerMaximumFilter
|
604
|
+
|
600
605
|
splits, schedule = compute_parallelization_schedule(
|
601
606
|
shape1=target.shape,
|
602
607
|
shape2=template_box,
|
@@ -605,7 +610,7 @@ def main():
|
|
605
610
|
max_ram=args.ram,
|
606
611
|
split_only_outer=args.use_gpu,
|
607
612
|
matching_method=args.score,
|
608
|
-
analyzer_method=
|
613
|
+
analyzer_method=callback_class.__name__,
|
609
614
|
backend=backend._backend_name,
|
610
615
|
float_nbytes=backend.datatype_bytes(backend._default_dtype),
|
611
616
|
complex_nbytes=backend.datatype_bytes(backend._complex_dtype),
|
@@ -627,16 +632,12 @@ def main():
|
|
627
632
|
}
|
628
633
|
|
629
634
|
matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
|
630
|
-
callback_class = MaxScoreOverRotations
|
631
|
-
if args.peak_calling:
|
632
|
-
callback_class = PeakCallerMaximumFilter
|
633
|
-
|
634
635
|
matching_data = MatchingData(target=target, template=template.data)
|
635
636
|
matching_data.rotations = get_rotation_matrices(
|
636
637
|
angular_sampling=args.angular_sampling, dim=target.data.ndim
|
637
638
|
)
|
638
|
-
matching_data.template_filter = template_filter
|
639
639
|
|
640
|
+
matching_data.template_filter = template_filter
|
640
641
|
if target_mask is not None:
|
641
642
|
matching_data.target_mask = target_mask
|
642
643
|
if template_mask is not None:
|
@@ -100,9 +100,12 @@ def main():
|
|
100
100
|
orientations.append((translation, angles, score, detail))
|
101
101
|
else:
|
102
102
|
candidates = data
|
103
|
-
|
104
|
-
|
105
|
-
|
103
|
+
translation, rotation, score, detail, *_ = data
|
104
|
+
for i in range(translation.shape[0]):
|
105
|
+
angles = euler_from_rotationmatrix(rotation[i])
|
106
|
+
orientations.append(
|
107
|
+
(np.array(translation[i]), angles, score[i], detail[i])
|
108
|
+
)
|
106
109
|
else:
|
107
110
|
with open(args.orientations, mode="r", encoding="utf-8") as infile:
|
108
111
|
data = [x.strip().split("\t") for x in infile.read().split("\n")]
|
@@ -108,6 +108,31 @@ def mean(
|
|
108
108
|
return preprocessor.mean_filter(template=template, width=width)
|
109
109
|
|
110
110
|
|
111
|
+
def wedge(
|
112
|
+
template: NDArray,
|
113
|
+
tilt_start: float,
|
114
|
+
tilt_stop: float,
|
115
|
+
gaussian_sigma: float,
|
116
|
+
tilt_axis: int = 1,
|
117
|
+
infinite_plane: bool = True,
|
118
|
+
extrude_plane: bool = True,
|
119
|
+
):
|
120
|
+
template_ft = np.fft.rfftn(template)
|
121
|
+
wedge_mask = preprocessor.continuous_wedge_mask(
|
122
|
+
start_tilt=tilt_start,
|
123
|
+
stop_tilt=tilt_stop,
|
124
|
+
tilt_axis=tilt_axis,
|
125
|
+
shape=template.shape,
|
126
|
+
sigma=gaussian_sigma,
|
127
|
+
omit_negative_frequencies=True,
|
128
|
+
infinite_plane=infinite_plane,
|
129
|
+
extrude_plane=extrude_plane,
|
130
|
+
)
|
131
|
+
np.multiply(template_ft, wedge_mask, out=template_ft)
|
132
|
+
template = np.real(np.fft.irfftn(template_ft))
|
133
|
+
return template
|
134
|
+
|
135
|
+
|
111
136
|
def widgets_from_function(function: Callable, exclude_params: List = ["self"]):
|
112
137
|
"""
|
113
138
|
Creates list of magicui widgets by inspecting function typing ann
|
@@ -166,7 +191,8 @@ WRAPPED_FUNCTIONS = {
|
|
166
191
|
"ntree_filter": ntree,
|
167
192
|
"local_gaussian_filter": local_gaussian_filter,
|
168
193
|
"difference_of_gaussian_filter": difference_of_gaussian_filter,
|
169
|
-
"mean_filter"
|
194
|
+
"mean_filter": mean,
|
195
|
+
"continuous_wedge_mask": wedge,
|
170
196
|
}
|
171
197
|
|
172
198
|
EXCLUDED_FUNCTIONS = [
|
@@ -178,7 +204,7 @@ EXCLUDED_FUNCTIONS = [
|
|
178
204
|
"interpolate_box",
|
179
205
|
"molmap",
|
180
206
|
"local_gaussian_alignment_filter",
|
181
|
-
"continuous_wedge_mask",
|
207
|
+
# "continuous_wedge_mask",
|
182
208
|
"wedge_mask",
|
183
209
|
"bandpass_mask",
|
184
210
|
]
|
@@ -372,19 +398,22 @@ def wedge_mask(
|
|
372
398
|
tilt_stop: float,
|
373
399
|
gaussian_sigma: float,
|
374
400
|
tilt_axis: int = 1,
|
401
|
+
omit_negative_frequencies: bool = True,
|
402
|
+
extrude_plane: bool = True,
|
403
|
+
infinite_plane: bool = True,
|
375
404
|
):
|
376
|
-
template_ft = np.fft.fftn(template)
|
377
405
|
wedge_mask = preprocessor.continuous_wedge_mask(
|
378
406
|
start_tilt=tilt_start,
|
379
407
|
stop_tilt=tilt_stop,
|
380
408
|
tilt_axis=tilt_axis,
|
381
|
-
shape=
|
409
|
+
shape=template.shape,
|
382
410
|
sigma=gaussian_sigma,
|
411
|
+
omit_negative_frequencies=omit_negative_frequencies,
|
412
|
+
extrude_plane=extrude_plane,
|
413
|
+
infinite_plane=infinite_plane,
|
383
414
|
)
|
384
|
-
|
385
|
-
|
386
|
-
template = np.real(np.fft.ifftn(template_ft))
|
387
|
-
return template
|
415
|
+
wedge_mask = np.fft.fftshift(wedge_mask)
|
416
|
+
return wedge_mask
|
388
417
|
|
389
418
|
|
390
419
|
class MaskWidget(widgets.Container):
|
@@ -410,34 +439,73 @@ class MaskWidget(widgets.Container):
|
|
410
439
|
)
|
411
440
|
self.method_dropdown.changed.connect(self._on_method_changed)
|
412
441
|
|
413
|
-
self.adapt_button = widgets.PushButton(
|
414
|
-
text="Adapt to current layer", enabled=False
|
415
|
-
)
|
442
|
+
self.adapt_button = widgets.PushButton(text="Adapt to layer", enabled=False)
|
416
443
|
self.adapt_button.changed.connect(self._update_initial_values)
|
417
444
|
|
418
445
|
self.viewer.layers.selection.events.active.connect(
|
419
446
|
self._update_action_button_state
|
420
447
|
)
|
421
448
|
|
449
|
+
self.align_button = widgets.PushButton(text="Align to axis", enabled=False)
|
450
|
+
self.align_button.changed.connect(self._align_with_axis)
|
451
|
+
self.density_field = widgets.Label()
|
452
|
+
# self.density_field.value = f"Positive Density in Mask: {0:.2f}%"
|
453
|
+
|
422
454
|
self.append(self.method_dropdown)
|
423
455
|
self.append(self.adapt_button)
|
456
|
+
self.append(self.align_button)
|
424
457
|
self.append(self.action_button)
|
458
|
+
self.append(self.density_field)
|
425
459
|
|
426
460
|
# Create GUI for initially selected filtering method
|
427
461
|
self._on_method_changed(None)
|
428
462
|
|
429
463
|
def _update_action_button_state(self, event):
|
464
|
+
self.align_button.enabled = bool(self.viewer.layers.selection.active)
|
430
465
|
self.action_button.enabled = bool(self.viewer.layers.selection.active)
|
431
466
|
self.adapt_button.enabled = bool(self.viewer.layers.selection.active)
|
432
467
|
|
468
|
+
def _align_with_axis(self):
|
469
|
+
active_layer = self.viewer.layers.selection.active
|
470
|
+
|
471
|
+
if active_layer.metadata.get("is_aligned", False):
|
472
|
+
return
|
473
|
+
|
474
|
+
coords = np.array(np.where(active_layer.data > 0)).T
|
475
|
+
centered_coords = coords - np.mean(coords, axis=0)
|
476
|
+
cov_matrix = np.cov(centered_coords, rowvar=False)
|
477
|
+
|
478
|
+
eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
|
479
|
+
principal_eigenvector = eigenvectors[:, np.argmax(eigenvalues)]
|
480
|
+
|
481
|
+
rotation_axis = np.cross(principal_eigenvector, [1, 0, 0])
|
482
|
+
rotation_angle = np.arccos(np.dot(principal_eigenvector, [1, 0, 0]))
|
483
|
+
k = rotation_axis / np.linalg.norm(rotation_axis)
|
484
|
+
K = np.array([[0, -k[2], k[1]], [k[2], 0, -k[0]], [-k[1], k[0], 0]])
|
485
|
+
rotation_matrix = np.eye(3)
|
486
|
+
rotation_matrix += np.sin(rotation_angle) * K
|
487
|
+
rotation_matrix += (1 - np.cos(rotation_angle)) * np.dot(K, K)
|
488
|
+
|
489
|
+
rotated_data = Density.rotate_array(
|
490
|
+
arr=active_layer.data,
|
491
|
+
rotation_matrix=rotation_matrix,
|
492
|
+
use_geometric_center=False,
|
493
|
+
)
|
494
|
+
eps = np.finfo(rotated_data.dtype).eps
|
495
|
+
rotated_data[rotated_data < eps] = 0
|
496
|
+
|
497
|
+
active_layer.metadata["is_aligned"] = True
|
498
|
+
active_layer.data = rotated_data
|
499
|
+
|
433
500
|
def _update_initial_values(self, event=None):
|
434
501
|
active_layer = self.viewer.layers.selection.active
|
435
|
-
center_of_mass = Density.center_of_mass(np.abs(active_layer.data))
|
502
|
+
center_of_mass = Density.center_of_mass(np.abs(active_layer.data), 0)
|
436
503
|
coordinates = np.array(np.where(active_layer.data > 0))
|
437
504
|
coordinates_min = coordinates.min(axis=1)
|
438
505
|
coordinates_max = coordinates.max(axis=1)
|
439
506
|
coordinates_heights = coordinates_max - coordinates_min
|
440
507
|
coordinate_radius = np.divide(coordinates_heights, 2)
|
508
|
+
center_of_mass = coordinate_radius + coordinates_min
|
441
509
|
|
442
510
|
defaults = dict(zip(["center_x", "center_y", "center_z"], center_of_mass))
|
443
511
|
defaults.update(
|
@@ -465,7 +533,7 @@ class MaskWidget(widgets.Container):
|
|
465
533
|
widgets = widgets_from_function(function)
|
466
534
|
for widget in widgets:
|
467
535
|
self.action_widgets.append(widget)
|
468
|
-
self.insert(
|
536
|
+
self.insert(1, widget)
|
469
537
|
|
470
538
|
def _action(self):
|
471
539
|
function = self.methods.get(self.method_dropdown.value)
|
@@ -479,7 +547,8 @@ class MaskWidget(widgets.Container):
|
|
479
547
|
selected_layer = self.viewer.layers[new_layer_name]
|
480
548
|
|
481
549
|
processed_data = processed_data.astype(np.float32)
|
482
|
-
|
550
|
+
metadata = selected_layer.metadata
|
551
|
+
mask = metadata.get("mask", False)
|
483
552
|
if mask == self.method_dropdown.value:
|
484
553
|
selected_layer.data = processed_data
|
485
554
|
else:
|
@@ -490,8 +559,18 @@ class MaskWidget(widgets.Container):
|
|
490
559
|
metadata = selected_layer.metadata.copy()
|
491
560
|
metadata["filter_parameters"] = {self.method_dropdown.value: kwargs.copy()}
|
492
561
|
metadata["mask"] = self.method_dropdown.value
|
562
|
+
metadata["origin_layer"] = selected_layer.name
|
493
563
|
new_layer.metadata = metadata
|
494
564
|
|
565
|
+
origin_layer = metadata["origin_layer"]
|
566
|
+
if origin_layer in self.viewer.layers:
|
567
|
+
origin_layer = self.viewer.layers[origin_layer]
|
568
|
+
if np.allclose(origin_layer.data.shape, processed_data.shape):
|
569
|
+
in_mask = np.sum(np.fmax(origin_layer.data * processed_data, 0))
|
570
|
+
in_mask /= np.sum(np.fmax(origin_layer.data, 0))
|
571
|
+
in_mask *= 100
|
572
|
+
self.density_field.value = f"Positive Density in Mask: {in_mask:.2f}%"
|
573
|
+
|
495
574
|
|
496
575
|
class ExportWidget(widgets.Container):
|
497
576
|
def __init__(self, viewer):
|
@@ -0,0 +1 @@
|
|
1
|
+
__version__ = "0.1.3"
|
@@ -121,7 +121,7 @@ class PeakCaller(ABC):
|
|
121
121
|
fourier_shift = kwargs.get(
|
122
122
|
"fourier_shift", backend.zeros(peak_positions.shape[1], dtype=int)
|
123
123
|
)
|
124
|
-
if
|
124
|
+
if backend.sum(fourier_shift != 0) != 0:
|
125
125
|
peak_positions = backend.mod(
|
126
126
|
backend.add(peak_positions, fourier_shift), score_space.shape
|
127
127
|
)
|
@@ -197,6 +197,7 @@ class PeakCaller(ABC):
|
|
197
197
|
if len(candidate) == 0:
|
198
198
|
continue
|
199
199
|
peak_positions, rotations, peak_scores, peak_details = candidate
|
200
|
+
kwargs["translation_offset"] = backend.zeros(peak_positions.shape[1])
|
200
201
|
base._update(
|
201
202
|
peak_positions=backend.to_backend_array(peak_positions),
|
202
203
|
peak_details=backend.to_backend_array(peak_details),
|
@@ -237,7 +238,6 @@ class PeakCaller(ABC):
|
|
237
238
|
translation_offset = backend.astype(translation_offset, peak_positions.dtype)
|
238
239
|
|
239
240
|
backend.add(peak_positions, translation_offset, out=peak_positions)
|
240
|
-
|
241
241
|
if not len(self.peak_list):
|
242
242
|
self.peak_list = [peak_positions, rotations, peak_scores, peak_details]
|
243
243
|
return None
|
@@ -491,8 +491,11 @@ class PytorchBackend(NumpyFFTWBackend):
|
|
491
491
|
Operates as a context manager, yielding None and providing
|
492
492
|
the set GPU context for enclosed operations.
|
493
493
|
"""
|
494
|
-
|
495
|
-
|
494
|
+
if self.device == "cuda":
|
495
|
+
with self._array_backend.cuda.device(device_index):
|
496
|
+
yield
|
497
|
+
else:
|
498
|
+
yield None
|
496
499
|
|
497
500
|
def device_count(self) -> int:
|
498
501
|
"""
|
@@ -505,7 +508,7 @@ class PytorchBackend(NumpyFFTWBackend):
|
|
505
508
|
"""
|
506
509
|
return self._array_backend.cuda.device_count()
|
507
510
|
|
508
|
-
def reverse(arr: TorchTensor) -> TorchTensor:
|
511
|
+
def reverse(self, arr: TorchTensor) -> TorchTensor:
|
509
512
|
"""
|
510
513
|
Reverse the order of elements in a tensor along all its axes.
|
511
514
|
|
@@ -519,4 +522,4 @@ class PytorchBackend(NumpyFFTWBackend):
|
|
519
522
|
TorchTensor
|
520
523
|
Reversed tensor.
|
521
524
|
"""
|
522
|
-
return self._array_backend.flip(arr, [i for i in range(arr.
|
525
|
+
return self._array_backend.flip(arr, [i for i in range(arr.ndim)])
|
@@ -46,10 +46,10 @@ class Density:
|
|
46
46
|
----------
|
47
47
|
data : NDArray
|
48
48
|
Electron density data.
|
49
|
-
origin : NDArray
|
50
|
-
Origin of the coordinate system.
|
51
|
-
sampling_rate : NDArray
|
52
|
-
Sampling rate along data axis.
|
49
|
+
origin : NDArray, optional
|
50
|
+
Origin of the coordinate system. Defaults to zero.
|
51
|
+
sampling_rate : NDArray, optional
|
52
|
+
Sampling rate along data axis. Defaults to one.
|
53
53
|
metadata : dict, optional
|
54
54
|
Dictionary with metadata information, empty by default.
|
55
55
|
|
@@ -62,16 +62,18 @@ class Density:
|
|
62
62
|
--------
|
63
63
|
>>> import numpy as np
|
64
64
|
>>> data = np.random.rand(50,50,50)
|
65
|
-
>>> Density(data = data, origin = (0, 0, 0), sampling_rate = (
|
65
|
+
>>> Density(data = data, origin = (0, 0, 0), sampling_rate = (1, 1, 1))
|
66
66
|
"""
|
67
67
|
|
68
68
|
def __init__(
|
69
69
|
self,
|
70
70
|
data: NDArray,
|
71
|
-
origin: NDArray,
|
72
|
-
sampling_rate: NDArray,
|
71
|
+
origin: NDArray = None,
|
72
|
+
sampling_rate: NDArray = None,
|
73
73
|
metadata: Dict = {},
|
74
74
|
):
|
75
|
+
origin = 0 if origin is None else origin
|
76
|
+
sampling_rate = 1 if sampling_rate is None else sampling_rate
|
75
77
|
sampling_rate, origin = np.asarray(sampling_rate), np.asarray(origin)
|
76
78
|
sampling_rate = np.repeat(sampling_rate, data.ndim // sampling_rate.size)
|
77
79
|
|
@@ -127,7 +129,7 @@ class Density:
|
|
127
129
|
|
128
130
|
Examples
|
129
131
|
--------
|
130
|
-
>>> density = Density.from_file("/path/to/
|
132
|
+
>>> density = Density.from_file("/path/to/file")
|
131
133
|
|
132
134
|
Notes
|
133
135
|
-----
|
@@ -548,8 +550,9 @@ class Density:
|
|
548
550
|
Which weight should be given to individual atoms. For valid values
|
549
551
|
see :py:meth:`Structure.from_file`.
|
550
552
|
chain : str, optional
|
551
|
-
|
552
|
-
|
553
|
+
The chain identifier. If multiple chains should be selected they need
|
554
|
+
to be a comma separated string, e.g. 'A,B,CE'. If chain None,
|
555
|
+
all chains are returned. Default is None.
|
553
556
|
filter_by_elements : set, optional
|
554
557
|
Set of atomic elements to keep. Default is all atoms.
|
555
558
|
filter_by_residues : set, optional
|
@@ -876,7 +879,6 @@ class Density:
|
|
876
879
|
box_start = np.array([b.start for b in box])
|
877
880
|
box_stop = np.array([b.stop for b in box])
|
878
881
|
left_pad = -np.minimum(box_start, np.zeros(len(box), dtype=int))
|
879
|
-
has_extension = box_start < 0
|
880
882
|
|
881
883
|
right_pad = box_stop - box_start * (box_start > 0)
|
882
884
|
right_pad -= np.array(self.shape, dtype=int)
|
@@ -905,8 +907,6 @@ class Density:
|
|
905
907
|
self.data = self.data[crop_box].copy()
|
906
908
|
|
907
909
|
# In case the box is larger than the current map
|
908
|
-
before_shape = self.data.shape
|
909
|
-
after_shape = [b.stop - b.start for b in box]
|
910
910
|
self.data = self._pad_slice(box)
|
911
911
|
|
912
912
|
# Adjust the origin
|
@@ -934,10 +934,19 @@ class Density:
|
|
934
934
|
tuple
|
935
935
|
A tuple containing slice objects representing the box.
|
936
936
|
|
937
|
+
Raises
|
938
|
+
------
|
939
|
+
ValueError
|
940
|
+
If the cutoff is larger than or equal to the maximum density value.
|
941
|
+
|
937
942
|
See Also
|
938
943
|
--------
|
939
944
|
:py:meth:`Density.adjust_box`
|
940
945
|
"""
|
946
|
+
if cutoff >= self.data.max():
|
947
|
+
raise ValueError(
|
948
|
+
f"Cutoff exceeds data range ({cutoff} >= {self.data.max()})."
|
949
|
+
)
|
941
950
|
starts, stops = [], []
|
942
951
|
for axis in range(self.data.ndim):
|
943
952
|
projected_max = np.max(
|
@@ -116,7 +116,6 @@ class MatchingData:
|
|
116
116
|
).astype(int)
|
117
117
|
|
118
118
|
ret_shape = np.add(slice_shape, padding)
|
119
|
-
|
120
119
|
arr_start = np.subtract(slice_start, data_voxels_left)
|
121
120
|
arr_stop = np.add(slice_stop, data_voxels_right)
|
122
121
|
arr_slice = tuple(slice(*pos) for pos in zip(arr_start, arr_stop))
|
@@ -138,7 +137,6 @@ class MatchingData:
|
|
138
137
|
arr.filename, mode="r", shape=arr.shape, dtype=arr.dtype
|
139
138
|
)
|
140
139
|
arr = np.asarray(arr[*arr_mesh])
|
141
|
-
|
142
140
|
ret = np.full(
|
143
141
|
shape=np.add(slice_shape, padding), fill_value=arr.mean(), dtype=arr.dtype
|
144
142
|
)
|
@@ -188,8 +186,8 @@ class MatchingData:
|
|
188
186
|
template_pad = np.zeros(len(self.target.shape), dtype=int)
|
189
187
|
|
190
188
|
indices = compute_full_convolution_index(
|
191
|
-
outer_shape=self.
|
192
|
-
inner_shape=self.
|
189
|
+
outer_shape=self._target.shape,
|
190
|
+
inner_shape=self._template.shape,
|
193
191
|
outer_split=target_slice,
|
194
192
|
inner_split=template_slice,
|
195
193
|
)
|
@@ -1136,7 +1136,6 @@ def mcc_scoring(
|
|
1136
1136
|
mask_overlap, axis=axes, keepdims=True
|
1137
1137
|
)
|
1138
1138
|
temp[mask_overlap < number_px_threshold] = 0.0
|
1139
|
-
|
1140
1139
|
convolution_mode = kwargs.get("convolution_mode", "full")
|
1141
1140
|
score = apply_convolution_mode(
|
1142
1141
|
temp, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
|
@@ -1165,7 +1164,6 @@ def device_memory_handler(func: Callable):
|
|
1165
1164
|
return_value = func(shared_memory_handler=smh, *args, **kwargs)
|
1166
1165
|
except Exception as e:
|
1167
1166
|
print(e)
|
1168
|
-
return None
|
1169
1167
|
last_type, last_value, last_traceback = sys.exc_info()
|
1170
1168
|
finally:
|
1171
1169
|
handle_traceback(last_type, last_value, last_traceback)
|
@@ -327,7 +327,7 @@ def estimate_ram_usage(
|
|
327
327
|
"""
|
328
328
|
if matching_method not in MATCHING_MEMORY_REGISTRY:
|
329
329
|
raise ValueError(
|
330
|
-
f"Supported
|
330
|
+
f"Supported options are {','.join(MATCHING_MEMORY_REGISTRY.keys())}"
|
331
331
|
)
|
332
332
|
|
333
333
|
convolution_shape, fast_shape, ft_shape = _compute_convolution_shapes(
|
@@ -246,6 +246,7 @@ class CrossCorrelation(MatchCoordinatesToDensity):
|
|
246
246
|
.. math::
|
247
247
|
|
248
248
|
\\text{score} = \\text{target_weights} \\cdot \\text{template_weights}
|
249
|
+
|
249
250
|
"""
|
250
251
|
|
251
252
|
def __init__(self, **kwargs):
|
@@ -298,6 +299,7 @@ class LaplaceCrossCorrelation(CrossCorrelation):
|
|
298
299
|
|
299
300
|
\\text{score} = \\nabla^{2} \\text{target_weights} \\cdot
|
300
301
|
\\nabla^{2} \\text{template_weights}
|
302
|
+
|
301
303
|
"""
|
302
304
|
|
303
305
|
def __init__(self, **kwargs):
|
@@ -617,6 +619,7 @@ class Chamfer(MatchCoordinatesToCoordinates):
|
|
617
619
|
-------
|
618
620
|
float
|
619
621
|
The negative of the Chamfer distance score.
|
622
|
+
|
620
623
|
"""
|
621
624
|
dist, _ = self.target_tree.query(self.template_coordinates_rotated.T)
|
622
625
|
score = np.mean(dist)
|
@@ -638,6 +641,7 @@ class MutualInformation(MatchCoordinatesToDensity):
|
|
638
641
|
.. [1] Daven Vasishtan and Maya Topf, "Scoring functions for cryoEM density
|
639
642
|
fitting", Journal of Structural Biology, vol. 174, no. 2,
|
640
643
|
pp. 333--343, 2011. DOI: https://doi.org/10.1016/j.jsb.2011.01.012
|
644
|
+
|
641
645
|
"""
|
642
646
|
|
643
647
|
def __init__(self, **kwargs):
|
@@ -776,6 +780,7 @@ class NormalVectorScore(MatchCoordinatesToCoordinates):
|
|
776
780
|
.. [1] Daven Vasishtan and Maya Topf, "Scoring functions for cryoEM density
|
777
781
|
fitting", Journal of Structural Biology, vol. 174, no. 2,
|
778
782
|
pp. 333--343, 2011. DOI: https://doi.org/10.1016/j.jsb.2011.01.012
|
783
|
+
|
779
784
|
"""
|
780
785
|
|
781
786
|
def __init__(self, **kwargs):
|
@@ -762,6 +762,8 @@ def euler_to_rotationmatrix(angles: Tuple[float]) -> NDArray:
|
|
762
762
|
NDArray
|
763
763
|
The generated rotation matrix.
|
764
764
|
"""
|
765
|
+
if len(angles) == 1:
|
766
|
+
angles = (angles, 0, 0)
|
765
767
|
rotation_matrix = (
|
766
768
|
Rotation.from_euler("zyx", angles, degrees=True).as_matrix().astype(np.float32)
|
767
769
|
)
|
@@ -775,13 +777,17 @@ def euler_from_rotationmatrix(rotation_matrix: NDArray) -> Tuple:
|
|
775
777
|
Parameters
|
776
778
|
----------
|
777
779
|
rotation_matrix : NDArray
|
778
|
-
A 3 x 3 rotation matrix in z y x form.
|
780
|
+
A 2 x 2 or 3 x 3 rotation matrix in z y x form.
|
779
781
|
|
780
782
|
Returns
|
781
783
|
-------
|
782
784
|
Tuple
|
783
785
|
The generate euler angles in degrees
|
784
786
|
"""
|
787
|
+
if rotation_matrix.shape[0] == 2:
|
788
|
+
temp_matrix = np.eye(3)
|
789
|
+
temp_matrix[:2, :2] = rotation_matrix
|
790
|
+
rotation_matrix = temp_matrix
|
785
791
|
euler_angles = (
|
786
792
|
Rotation.from_matrix(rotation_matrix)
|
787
793
|
.as_euler("zyx", degrees=True)
|