pytme-0.2.1-cp311-cp311-macosx_14_0_arm64.whl → pytme-0.2.2-cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/match_template.py +147 -93
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/postprocess.py +67 -26
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +175 -85
- pytme-0.2.2.dist-info/METADATA +91 -0
- pytme-0.2.2.dist-info/RECORD +74 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
- scripts/extract_candidates.py +20 -13
- scripts/match_template.py +147 -93
- scripts/match_template_filters.py +154 -95
- scripts/postprocess.py +67 -26
- scripts/preprocessor_gui.py +175 -85
- scripts/refine_matches.py +265 -61
- tme/__init__.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +451 -809
- tme/backends/__init__.py +40 -11
- tme/backends/_jax_utils.py +185 -0
- tme/backends/cupy_backend.py +111 -223
- tme/backends/jax_backend.py +214 -150
- tme/backends/matching_backend.py +445 -384
- tme/backends/mlx_backend.py +32 -59
- tme/backends/npfftw_backend.py +239 -507
- tme/backends/pytorch_backend.py +21 -145
- tme/density.py +233 -363
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +322 -285
- tme/matching_exhaustive.py +172 -1493
- tme/matching_optimization.py +143 -106
- tme/matching_scores.py +884 -0
- tme/matching_utils.py +280 -386
- tme/memory.py +377 -0
- tme/orientations.py +52 -12
- tme/parser.py +3 -4
- tme/preprocessing/_utils.py +61 -32
- tme/preprocessing/compose.py +7 -3
- tme/preprocessing/frequency_filters.py +49 -39
- tme/preprocessing/tilt_series.py +34 -40
- tme/preprocessor.py +560 -526
- tme/structure.py +491 -188
- tme/types.py +5 -3
- pytme-0.2.1.dist-info/METADATA +0 -73
- pytme-0.2.1.dist-info/RECORD +0 -73
- tme/helpers.py +0 -881
- tme/matching_constrained.py +0 -195
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
scripts/postprocess.py
CHANGED
@@ -8,9 +8,8 @@
 import argparse
 from sys import exit
 from os import getcwd
-from os.path import join, abspath
 from typing import List, Tuple
-from os.path import splitext
+from os.path import join, abspath, splitext

 import numpy as np
 from numpy.typing import NDArray
@@ -26,6 +25,7 @@ from tme.analyzer import (
 )
 from tme.matching_utils import (
     load_pickle,
+    centered_mask,
     euler_to_rotationmatrix,
     euler_from_rotationmatrix,
 )
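Note: centered_mask, new in this import list, is used further down to zero out scores within min_boundary_distance of the volume edge without changing the array shape. A minimal standalone sketch of that behavior (illustrative only; the real implementation lives in tme.matching_utils):

import numpy as np

def centered_mask_sketch(arr, new_shape):
    # Zero everything outside a centered box of size new_shape.
    out = np.zeros_like(arr)
    box = tuple(
        slice((dim - new) // 2, (dim - new) // 2 + new)
        for dim, new in zip(arr.shape, new_shape)
    )
    out[box] = arr[box]
    return out

scores = np.random.rand(64, 64, 64)
masked = centered_mask_sketch(scores, new_shape=(48, 48, 48))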
@@ -41,9 +41,7 @@ PEAK_CALLERS = {


 def parse_args():
-    parser = argparse.ArgumentParser(
-        description="Peak Calling for Template Matching Outputs"
-    )
+    parser = argparse.ArgumentParser(description="Analyze Template Matching Outputs")

     input_group = parser.add_argument_group("Input")
     output_group = parser.add_argument_group("Output")
@@ -56,6 +54,13 @@ def parse_args():
         nargs="+",
         help="Path to the output of match_template.py.",
     )
+    input_group.add_argument(
+        "--background_file",
+        required=False,
+        nargs="+",
+        help="Path to an output of match_template.py used for normalization. "
+        "For instance from --scramble_phases or a different template.",
+    )
     input_group.add_argument(
         "--target_mask",
         required=False,
@@ -87,7 +92,7 @@
             "average",
         ],
         default="orientations",
-        help="Available output formats:"
+        help="Available output formats: "
         "orientations (translation, rotation, and score), "
         "alignment (aligned template to target based on orientations), "
        "extraction (extract regions around peaks from targets, i.e. subtomograms), "
@@ -206,6 +211,15 @@
     elif args.number_of_peaks is None:
         args.number_of_peaks = 1000

+    if args.background_file is None:
+        args.background_file = [None]
+    if len(args.background_file) == 1:
+        args.background_file = args.background_file * len(args.input_file)
+    elif len(args.background_file) not in (0, len(args.input_file)):
+        raise ValueError(
+            "--background_file needs to be specified once or for each --input_file."
+        )
+
     return args


@@ -233,8 +247,8 @@ def load_template(
     return template, center, translation, template_is_density


-def merge_outputs(data, filepaths: List[str], args):
-    if len(filepaths) == 0:
+def merge_outputs(data, foreground_paths: List[str], background_paths: List[str], args):
+    if len(foreground_paths) == 0:
         return data, 1

     if data[0].ndim != data[2].ndim:
@@ -275,8 +289,11 @@ def merge_outputs(data, filepaths: List[str], args):

     entities = np.zeros_like(data[0])
     data[0] = _norm_scores(data=data, args=args)
-    for index, filepath in enumerate(
-        new_scores = _norm_scores(
+    for index, filepath in enumerate(foreground_paths):
+        new_scores = _norm_scores(
+            data=load_match_template_output(filepath, background_paths[index]),
+            args=args,
+        )
         indices = new_scores > data[0]
         entities[indices] = index + 1
         data[0][indices] = new_scores[indices]
@@ -284,9 +301,18 @@ def merge_outputs(data, filepaths: List[str], args):
     return data, entities


+def load_match_template_output(foreground_path, background_path):
+    data = load_pickle(foreground_path)
+    if background_path is not None:
+        data_background = load_pickle(background_path)
+        data[0] = (data[0] - data_background[0]) / (1 - data_background[0])
+        np.fmax(data[0], 0, out=data[0])
+    return data
+
+
 def main():
     args = parse_args()
-    data =
+    data = load_match_template_output(args.input_file[0], args.background_file[0])

     target_origin, _, sampling_rate, cli_args = data[-1]

@@ -326,7 +352,12 @@

     entities = None
     if len(args.input_file) > 1:
-        data, entities = merge_outputs(
+        data, entities = merge_outputs(
+            data=data,
+            foreground_paths=args.input_file,
+            background_paths=args.background_file,
+            args=args,
+        )

     orientations = args.orientations
     if orientations is None:
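Foreground and background runs are paired element-wise; load_match_template_output rescales the foreground score map against the background ceiling and clips negatives. A standalone NumPy sketch of that normalization with made-up arrays (in practice these are data[0] of two match_template.py pickles):

import numpy as np

foreground = np.random.rand(32, 32, 32)        # e.g. scores of the real template
background = np.random.rand(32, 32, 32) * 0.5  # e.g. scores of a phase-scrambled template

normalized = (foreground - background) / (1 - background)
np.fmax(normalized, 0, out=normalized)         # negative values are clipped to zero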
@@ -339,24 +370,27 @@
         target_mask = Density.from_file(args.target_mask)
         scores = scores * target_mask.data

-
-
-
-            scores.shape, np.multiply(args.min_boundary_distance, 2)
-        ).astype(int)
+    cropped_shape = np.subtract(
+        scores.shape, np.multiply(args.min_boundary_distance, 2)
+    ).astype(int)

-
+    if args.min_boundary_distance > 0:
+        scores = centered_mask(scores, new_shape=cropped_shape)
+
+    if args.n_false_positives is not None:
+        # Rickgauer et al. 2017
+        cropped_slice = tuple(
             slice(
                 int(args.min_boundary_distance),
                 int(x - args.min_boundary_distance),
             )
             for x in scores.shape
         )
-
-        n_correlations = np.size(scores[
+        args.n_false_positives = max(args.n_false_positives, 1)
+        n_correlations = np.size(scores[cropped_slice]) * len(rotation_mapping)
         minimum_score = np.multiply(
             erfcinv(2 * args.n_false_positives / n_correlations),
-            np.sqrt(2) * np.std(scores[
+            np.sqrt(2) * np.std(scores[cropped_slice]),
         )
         print(f"Determined minimum score cutoff: {minimum_score}.")
         minimum_score = max(minimum_score, 0)
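The cutoff follows the Rickgauer et al. 2017 argument: given the number of sampled correlation values (voxels in the cropped score map times rotations), pick the score above which at most --n_false_positives values are expected under a Gaussian noise model. A worked sketch with made-up numbers:

import numpy as np
from scipy.special import erfcinv

n_false_positives = 1         # tolerated false positives
n_voxels = 400 * 400 * 100    # voxels inside the cropped score map
n_rotations = 15000           # len(rotation_mapping)
score_std = 0.02              # np.std of the cropped score map

n_correlations = n_voxels * n_rotations
minimum_score = erfcinv(2 * n_false_positives / n_correlations) * np.sqrt(2) * score_std
print(f"Determined minimum score cutoff: {minimum_score:.4f}")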
@@ -371,6 +405,8 @@
         "min_distance": args.min_distance,
         "min_boundary_distance": args.min_boundary_distance,
         "batch_dims": args.batch_dims,
+        "minimum_score": args.minimum_score,
+        "maximum_score": args.maximum_score,
     }

     peak_caller = PEAK_CALLERS[args.peak_caller](**peak_caller_kwargs)
@@ -380,7 +416,6 @@
             mask=template.data,
             rotation_mapping=rotation_mapping,
             rotation_array=rotation_array,
-            minimum_score=args.minimum_score,
         )
         candidates = peak_caller.merge(
             candidates=[tuple(peak_caller)], **peak_caller_kwargs
@@ -388,10 +423,16 @@
         if len(candidates) == 0:
             candidates = [[], [], [], []]
             print("Found no peaks, consider changing peak calling parameters.")
-            exit(
+            exit(-1)

         for translation, _, score, detail in zip(*candidates):
-
+            rotation_index = rotation_array[tuple(translation)]
+            rotation = rotation_mapping.get(
+                rotation_index, np.zeros(template.data.ndim, int)
+            )
+            if rotation.ndim == 2:
+                rotation = euler_from_rotationmatrix(rotation)
+            rotations.append(rotation)

     else:
         candidates = data
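Each peak's orientation is recovered by indexing rotation_array at the peak position and looking that index up in rotation_mapping; values stored as rotation matrices are converted to Euler angles. A minimal sketch of the lookup using scipy directly (hypothetical mapping contents; the zyz convention here is an assumption, the script itself uses tme's euler_from_rotationmatrix):

import numpy as np
from scipy.spatial.transform import Rotation

# Hypothetical inputs: per-voxel rotation indices and an index -> rotation matrix map.
rotation_array = np.zeros((16, 16, 16), dtype=int)
rotation_mapping = {0: Rotation.from_euler("zyz", (10, 20, 30), degrees=True).as_matrix()}

translation = (4, 5, 6)
rotation = rotation_mapping.get(rotation_array[translation], np.zeros(3))
if rotation.ndim == 2:  # stored as a matrix -> convert to Euler angles
    rotation = Rotation.from_matrix(rotation).as_euler("zyz", degrees=True)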
@@ -430,7 +471,7 @@
             )
             exit(-1)
         orientations.translations = peak_caller.oversample_peaks(
-
+            scores=data[0],
             peak_positions=orientations.translations,
             oversampling_factor=args.peak_oversampling,
         )
@@ -570,7 +611,7 @@
             return_orientations=True,
         )
         out = np.zeros_like(template.data)
-        out = np.zeros(np.multiply(template.shape, 2).astype(int))
+        # out = np.zeros(np.multiply(template.shape, 2).astype(int))
         for index in range(len(cand_slices)):
             from scipy.spatial.transform import Rotation

scripts/preprocessor_gui.py
CHANGED
@@ -1,7 +1,5 @@
 #!python3
-"""
-Exposes tme.preprocessor.Preprocessor and tme.fitter_utils member functions
-to achieve this aim.
+""" GUI for identifying adequate template matching filter and masks.

 Copyright (c) 2023 European Molecular Biology Laboratory

@@ -12,17 +10,20 @@ import argparse
 from typing import Tuple, Callable, List
 from typing_extensions import Annotated

+import napari
 import numpy as np
 import pandas as pd
-import napari
+from scipy.fft import next_fast_len
 from napari.layers import Image
 from napari.utils.events import EventedList
-
 from magicgui import widgets
 from qtpy.QtWidgets import QFileDialog
 from numpy.typing import NDArray

+from tme.backends import backend
 from tme import Preprocessor, Density
+from tme.preprocessing import BandPassFilter
+from tme.preprocessing.tilt_series import CTF
 from tme.matching_utils import create_mask, load_pickle

 preprocessor = Preprocessor()
@@ -35,19 +36,57 @@ def gaussian_filter(template: NDArray, sigma: float, **kwargs: dict) -> NDArray:

 def bandpass_filter(
     template: NDArray,
-
-
-
-
+    lowpass_angstrom: float = 30,
+    highpass_angstrom: float = 140,
+    hard_edges: bool = False,
+    sampling_rate=None,
 ) -> NDArray:
-
-
-
-
-
-
-
+    bpf = BandPassFilter(
+        lowpass=lowpass_angstrom,
+        highpass=highpass_angstrom,
+        sampling_rate=np.max(sampling_rate),
+        use_gaussian=not hard_edges,
+        shape_is_real_fourier=True,
+        return_real_fourier=True,
     )
+    template_ft = np.fft.rfftn(template, s=template.shape)
+
+    mask = bpf(shape=template_ft.shape)["data"]
+    np.multiply(template_ft, mask, out=template_ft)
+    return np.fft.irfftn(template_ft, s=template.shape).real
+
+
+def ctf_filter(
+    template: NDArray,
+    defocus_angstrom: float = 30000,
+    acceleration_voltage: float = 300,
+    spherical_aberration: float = 2.7,
+    amplitude_contrast: float = 0.07,
+    phase_shift: float = 0,
+    defocus_angle: float = 0,
+    sampling_rate=None,
+    flip_phase: bool = False,
+) -> NDArray:
+    fast_shape = [next_fast_len(x) for x in np.multiply(template.shape, 2)]
+    template_pad = backend.topleft_pad(template, fast_shape)
+    template_ft = np.fft.rfftn(template_pad, s=template_pad.shape)
+    ctf = CTF(
+        angles=[0],
+        shape=fast_shape,
+        defocus_x=[defocus_angstrom],
+        acceleration_voltage=acceleration_voltage * 1e3,
+        spherical_aberration=spherical_aberration * 1e7,
+        amplitude_contrast=amplitude_contrast,
+        phase_shift=[phase_shift],
+        defocus_angle=[defocus_angle],
+        sampling_rate=np.max(sampling_rate),
+        return_real_fourier=True,
+        flip_phase=flip_phase,
+    )
+    np.multiply(template_ft, ctf()["data"], out=template_ft)
+    template_pad = np.fft.irfftn(template_ft, s=template_pad.shape).real
+    template = backend.topleft_pad(template_pad, template.shape)
+    return template


 def difference_of_gaussian_filter(
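Both new wrappers follow the same pattern: build a Fourier-space mask (BandPassFilter or CTF from tme.preprocessing), multiply it onto the rfftn of the (padded) template, and transform back. A standalone NumPy sketch of that pattern with a hard-edged band-pass mask in voxel units (illustrative only, not the tme.preprocessing implementation, which works in Angstrom and supports Gaussian edges):

import numpy as np

def fourier_bandpass_sketch(volume, lowpass_voxels, highpass_voxels):
    # Keep spatial frequencies between 1 / highpass_voxels and 1 / lowpass_voxels.
    grids = np.meshgrid(
        *[np.fft.fftfreq(n) for n in volume.shape[:-1]],
        np.fft.rfftfreq(volume.shape[-1]),
        indexing="ij",
    )
    freq = np.sqrt(sum(g ** 2 for g in grids))
    mask = (freq >= 1 / highpass_voxels) & (freq <= 1 / lowpass_voxels)

    volume_ft = np.fft.rfftn(volume)
    np.multiply(volume_ft, mask, out=volume_ft)
    return np.fft.irfftn(volume_ft, s=volume.shape)

filtered = fourier_bandpass_sketch(np.random.rand(48, 48, 48), lowpass_voxels=3, highpass_voxels=24)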
@@ -109,61 +148,6 @@ def mean(
     return preprocessor.mean_filter(template=template, width=width)


-def resolution_sphere(
-    template: NDArray,
-    cutoff_angstrom: float,
-    highpass: bool = False,
-    sampling_rate=None,
-) -> NDArray:
-    if cutoff_angstrom == 0:
-        return template
-
-    cutoff_frequency = np.max(2 * sampling_rate / cutoff_angstrom)
-
-    min_freq, max_freq = 0, cutoff_frequency
-    if highpass:
-        min_freq, max_freq = cutoff_frequency, 1e10
-
-    mask = preprocessor.bandpass_mask(
-        shape=template.shape,
-        minimum_frequency=min_freq,
-        maximum_frequency=max_freq,
-        omit_negative_frequencies=False,
-    )
-
-    template_ft = np.fft.fftn(template)
-    np.multiply(template_ft, mask, out=template_ft)
-    return np.fft.ifftn(template_ft).real
-
-
-def resolution_gaussian(
-    template: NDArray,
-    cutoff_angstrom: float,
-    highpass: bool = False,
-    sampling_rate=None,
-) -> NDArray:
-    if cutoff_angstrom == 0:
-        return template
-
-    grid = preprocessor.fftfreqn(
-        shape=template.shape, sampling_rate=sampling_rate / sampling_rate.max()
-    )
-
-    sigma_fourier = np.divide(
-        np.max(2 * sampling_rate / cutoff_angstrom), np.sqrt(2 * np.log(2))
-    )
-
-    mask = np.exp(-(grid**2) / (2 * sigma_fourier**2))
-    if highpass:
-        mask = 1 - mask
-
-    mask = np.fft.ifftshift(mask)
-
-    template_ft = np.fft.fftn(template)
-    np.multiply(template_ft, mask, out=template_ft)
-    return np.fft.ifftn(template_ft).real
-
-
 def wedge(
     template: NDArray,
     tilt_start: float,
@@ -274,8 +258,7 @@ WRAPPED_FUNCTIONS = {
     "mean_filter": mean,
     "wedge_filter": wedge,
     "power_spectrum": compute_power_spectrum,
-    "resolution_gaussian": resolution_gaussian,
-    "resolution_sphere": resolution_sphere,
+    "ctf": ctf_filter,
 }

 EXCLUDED_FUNCTIONS = [
@@ -421,6 +404,7 @@ def sphere_mask(
     center_y: float,
     center_z: float,
     radius: float,
+    sigma_decay: float = 0,
     **kwargs,
 ) -> NDArray:
     return create_mask(
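The sigma_decay parameter, threaded through every mask wrapper below, replaces the hard mask edge with a Gaussian fall-off. A standalone sketch of the idea for a sphere (illustrative only; the actual behavior is implemented by tme.matching_utils.create_mask):

import numpy as np

def soft_sphere_mask_sketch(shape, center, radius, sigma_decay=0.0):
    grid = np.indices(shape).astype(float)
    distance = np.sqrt(sum((g - c) ** 2 for g, c in zip(grid, center)))
    mask = (distance <= radius).astype(float)
    if sigma_decay > 0:
        # Outside the radius, decay smoothly instead of dropping to zero.
        decay = np.exp(-((distance - radius) ** 2) / (2 * sigma_decay ** 2))
        mask = np.where(distance <= radius, 1.0, decay)
    return mask

mask = soft_sphere_mask_sketch((32, 32, 32), center=(16, 16, 16), radius=8, sigma_decay=2)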
@@ -428,6 +412,7 @@ def sphere_mask(
         shape=template.shape,
         center=(center_x, center_y, center_z),
         radius=radius,
+        sigma_decay=sigma_decay,
     )


@@ -439,6 +424,7 @@ def ellipsod_mask(
     radius_x: float,
     radius_y: float,
     radius_z: float,
+    sigma_decay: float = 0,
     **kwargs,
 ) -> NDArray:
     return create_mask(
@@ -446,6 +432,7 @@ def ellipsod_mask(
         shape=template.shape,
         center=(center_x, center_y, center_z),
         radius=(radius_x, radius_y, radius_z),
+        sigma_decay=sigma_decay,
     )


@@ -457,6 +444,7 @@ def box_mask(
     height_x: int,
     height_y: int,
     height_z: int,
+    sigma_decay: float = 0,
     **kwargs,
 ) -> NDArray:
     return create_mask(
@@ -464,6 +452,7 @@ def box_mask(
         shape=template.shape,
         center=(center_x, center_y, center_z),
         height=(height_x, height_y, height_z),
+        sigma_decay=sigma_decay,
     )


@@ -476,6 +465,7 @@ def tube_mask(
     inner_radius: float,
     outer_radius: float,
     height: int,
+    sigma_decay: float = 0,
     **kwargs,
 ) -> NDArray:
     return create_mask(
@@ -486,6 +476,7 @@ def tube_mask(
         inner_radius=inner_radius,
         outer_radius=outer_radius,
         height=height,
+        sigma_decay=sigma_decay,
     )


@@ -533,13 +524,23 @@ def wedge_mask(


 def threshold_mask(
-    template: NDArray,
+    template: NDArray,
+    invert: bool = False,
+    standard_deviation: float = 5.0,
+    sigma: float = 0.0,
+    **kwargs,
 ) -> NDArray:
     template_mean = template.mean()
     template_deviation = standard_deviation * template.std()
     upper = template_mean + template_deviation
     lower = template_mean - template_deviation
-    mask = np.
+    mask = np.logical_or(template <= lower, template >= upper)
+
+    if sigma != 0:
+        mask_filter = preprocessor.gaussian_filter(template=mask * 1.0, sigma=sigma)
+        mask = np.add(mask, (1 - mask) * mask_filter)
+        mask[mask < np.exp(-np.square(sigma))] = 0
+
     if invert:
         np.invert(mask, out=mask)

@@ -890,6 +891,7 @@ class PointCloudWidget(widgets.Container):

         self.viewer = viewer
         self.dataframes = {}
+        self.selected_category = -1

         self.import_button = widgets.PushButton(
             name="Import", text="Import Point Cloud"
@@ -902,10 +904,98 @@ class PointCloudWidget(widgets.Container):
         self.export_button.clicked.connect(self._export_point_cloud)
         self.export_button.enabled = False

+        self.annotation_container = widgets.Container(name="Label", layout="horizontal")
+        self.positive_button = widgets.PushButton(name="Positive", text="Positive")
+        self.negative_button = widgets.PushButton(name="Negative", text="Negative")
+        self.positive_button.clicked.connect(self._set_positive)
+        self.negative_button.clicked.connect(self._set_negative)
+        self.annotation_container.append(self.positive_button)
+        self.annotation_container.append(self.negative_button)
+
+        self.face_color_select = widgets.ComboBox(
+            name="Color", choices=["Label", "Score"], value=None, nullable=True
+        )
+        self.face_color_select.changed.connect(self._update_face_color_mode)
+
         self.append(self.import_button)
         self.append(self.export_button)
+        self.append(self.annotation_container)
+        self.append(self.face_color_select)
+
         self.viewer.layers.selection.events.changed.connect(self._update_buttons)

+        self.viewer.layers.events.inserted.connect(self._initialize_points_layer)
+
+    def _update_face_color_mode(self, event: str = None):
+        for layer in self.viewer.layers:
+            if not isinstance(layer, napari.layers.Points):
+                continue
+
+            layer.face_color = "white"
+            if event == "Label":
+                if len(layer.properties.get("detail", ())) == 0:
+                    continue
+                layer.face_color = "detail"
+                layer.face_color_cycle = {
+                    -1: "grey",
+                    0: "red",
+                    1: "green",
+                }
+                layer.face_color_mode = "cycle"
+            elif event == "Score":
+                if len(layer.properties.get("score_scaled", ())) == 0:
+                    continue
+                layer.face_color = "score_scaled"
+                layer.face_colormap = "turbo"
+                layer.face_color_mode = "colormap"
+
+            layer.refresh_colors()
+
+        return None
+
+    def _set_positive(self, event):
+        self.selected_category = 1 if self.selected_category != 1 else -1
+        self._update_annotation_buttons()
+
+    def _set_negative(self, event):
+        self.selected_category = 0 if self.selected_category != 0 else -1
+        self._update_annotation_buttons()
+
+    def _update_annotation_buttons(self):
+        selected_style = "background-color: darkgrey"
+        default_style = "background-color: none"
+
+        self.positive_button.native.setStyleSheet(
+            selected_style if self.selected_category == 1 else default_style
+        )
+        self.negative_button.native.setStyleSheet(
+            selected_style if self.selected_category == 0 else default_style
+        )
+
+    def _initialize_points_layer(self, event):
+        layer = event.value
+        if not isinstance(layer, napari.layers.Points):
+            return
+        if len(layer.properties) == 0:
+            layer.properties = {"detail": [-1]}
+
+        if "detail" not in layer.properties:
+            layer["detail"] = [-1]
+
+        layer.mouse_drag_callbacks.append(self._on_click)
+        return None
+
+    def _on_click(self, layer, event):
+        if layer.mode == "add":
+            layer.current_properties["detail"][-1] = self.selected_category
+        elif layer.mode == "select":
+            for index in layer.selected_data:
+                layer.properties["detail"][index] = self.selected_category
+
+        # TODO: Check whether current face color is the desired one already
+        self._update_face_color_mode(self.face_color_select.value)
+        layer.refresh_colors()
+
     def _update_buttons(self, event):
         is_pointcloud = isinstance(
             self.viewer.layers.selection.active, napari.layers.Points
@@ -951,9 +1041,7 @@ class PointCloudWidget(widgets.Container):

         if "score" in merged_data.columns:
             merged_data["score"] = merged_data["score"].fillna(1)
-
-            merged_data["detail"] = merged_data["detail"].fillna(2)
-
+        merged_data["detail"] = layer.properties["detail"]
         merged_data.to_csv(filename, sep="\t", index=False)

     def _get_load_path(self, event):
@@ -977,7 +1065,7 @@ class PointCloudWidget(widgets.Container):
             dataframe["score"] = 1

         if "detail" not in dataframe.columns:
-            dataframe["detail"] = -
+            dataframe["detail"] = -1

         point_properties = {
             "score": np.array(dataframe["score"].values),
@@ -991,8 +1079,6 @@ class PointCloudWidget(widgets.Container):
             points,
             size=10,
             properties=point_properties,
-            face_color="score_scaled",
-            face_colormap="turbo",
             name=layer_name,
         )
         self.dataframes[layer_name] = dataframe
@@ -1025,9 +1111,14 @@ class MatchingWidget(widgets.Container):
     def _load_data(self, filename):
         data = load_pickle(filename)

-
+        metadata = {"origin": data[-1][1], "sampling_rate": data[-1][2]}
+        _ = self.viewer.add_image(
+            data=data[2], name="Rotations", colormap="orange", metadata=metadata
+        )

-        _ = self.viewer.add_image(
+        _ = self.viewer.add_image(
+            data=data[0], name="Scores", colormap="turbo", metadata=metadata
+        )


 def main():
@@ -1045,11 +1136,10 @@ def main():
         widget=alignment_widget, name="Alignment", area="right"
     )
     viewer.window.add_dock_widget(widget=mask_widget, name="Mask", area="right")
+    viewer.window.add_dock_widget(widget=export_widget, name="Export", area="right")
     viewer.window.add_dock_widget(widget=point_cloud, name="PointCloud", area="left")
     viewer.window.add_dock_widget(widget=matching_widget, name="Matching", area="left")

-    viewer.window.add_dock_widget(widget=export_widget, name="Export", area="right")
-
     napari.run()

