pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/match_template.py +147 -93
  2. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/postprocess.py +67 -26
  3. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +175 -85
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +20 -13
  8. scripts/match_template.py +147 -93
  9. scripts/match_template_filters.py +154 -95
  10. scripts/postprocess.py +67 -26
  11. scripts/preprocessor_gui.py +175 -85
  12. scripts/refine_matches.py +265 -61
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +451 -809
  16. tme/backends/__init__.py +40 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +111 -223
  19. tme/backends/jax_backend.py +214 -150
  20. tme/backends/matching_backend.py +445 -384
  21. tme/backends/mlx_backend.py +32 -59
  22. tme/backends/npfftw_backend.py +239 -507
  23. tme/backends/pytorch_backend.py +21 -145
  24. tme/density.py +233 -363
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +322 -285
  27. tme/matching_exhaustive.py +172 -1493
  28. tme/matching_optimization.py +143 -106
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +280 -386
  31. tme/memory.py +377 -0
  32. tme/orientations.py +52 -12
  33. tme/parser.py +3 -4
  34. tme/preprocessing/_utils.py +61 -32
  35. tme/preprocessing/compose.py +7 -3
  36. tme/preprocessing/frequency_filters.py +49 -39
  37. tme/preprocessing/tilt_series.py +34 -40
  38. tme/preprocessor.py +560 -526
  39. tme/structure.py +491 -188
  40. tme/types.py +5 -3
  41. pytme-0.2.1.dist-info/METADATA +0 -73
  42. pytme-0.2.1.dist-info/RECORD +0 -73
  43. tme/helpers.py +0 -881
  44. tme/matching_constrained.py +0 -195
  45. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  46. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  47. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  48. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  49. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
scripts/postprocess.py CHANGED
@@ -8,9 +8,8 @@
8
8
  import argparse
9
9
  from sys import exit
10
10
  from os import getcwd
11
- from os.path import join, abspath
12
11
  from typing import List, Tuple
13
- from os.path import splitext
12
+ from os.path import join, abspath, splitext
14
13
 
15
14
  import numpy as np
16
15
  from numpy.typing import NDArray
@@ -26,6 +25,7 @@ from tme.analyzer import (
26
25
  )
27
26
  from tme.matching_utils import (
28
27
  load_pickle,
28
+ centered_mask,
29
29
  euler_to_rotationmatrix,
30
30
  euler_from_rotationmatrix,
31
31
  )
@@ -41,9 +41,7 @@ PEAK_CALLERS = {
41
41
 
42
42
 
43
43
  def parse_args():
44
- parser = argparse.ArgumentParser(
45
- description="Peak Calling for Template Matching Outputs"
46
- )
44
+ parser = argparse.ArgumentParser(description="Analyze Template Matching Outputs")
47
45
 
48
46
  input_group = parser.add_argument_group("Input")
49
47
  output_group = parser.add_argument_group("Output")
@@ -56,6 +54,13 @@ def parse_args():
56
54
  nargs="+",
57
55
  help="Path to the output of match_template.py.",
58
56
  )
57
+ input_group.add_argument(
58
+ "--background_file",
59
+ required=False,
60
+ nargs="+",
61
+ help="Path to an output of match_template.py used for normalization. "
62
+ "For instance from --scramble_phases or a different template.",
63
+ )
59
64
  input_group.add_argument(
60
65
  "--target_mask",
61
66
  required=False,
@@ -87,7 +92,7 @@ def parse_args():
87
92
  "average",
88
93
  ],
89
94
  default="orientations",
90
- help="Available output formats:"
95
+ help="Available output formats: "
91
96
  "orientations (translation, rotation, and score), "
92
97
  "alignment (aligned template to target based on orientations), "
93
98
  "extraction (extract regions around peaks from targets, i.e. subtomograms), "
@@ -206,6 +211,15 @@ def parse_args():
206
211
  elif args.number_of_peaks is None:
207
212
  args.number_of_peaks = 1000
208
213
 
214
+ if args.background_file is None:
215
+ args.background_file = [None]
216
+ if len(args.background_file) == 1:
217
+ args.background_file = args.background_file * len(args.input_file)
218
+ elif len(args.background_file) not in (0, len(args.input_file)):
219
+ raise ValueError(
220
+ "--background_file needs to be specified once or for each --input_file."
221
+ )
222
+
209
223
  return args
210
224
 
211
225
 
@@ -233,8 +247,8 @@ def load_template(
233
247
  return template, center, translation, template_is_density
234
248
 
235
249
 
236
- def merge_outputs(data, filepaths: List[str], args):
237
- if len(filepaths) == 0:
250
+ def merge_outputs(data, foreground_paths: List[str], background_paths: List[str], args):
251
+ if len(foreground_paths) == 0:
238
252
  return data, 1
239
253
 
240
254
  if data[0].ndim != data[2].ndim:
@@ -275,8 +289,11 @@ def merge_outputs(data, filepaths: List[str], args):
275
289
 
276
290
  entities = np.zeros_like(data[0])
277
291
  data[0] = _norm_scores(data=data, args=args)
278
- for index, filepath in enumerate(filepaths):
279
- new_scores = _norm_scores(data=load_pickle(filepath), args=args)
292
+ for index, filepath in enumerate(foreground_paths):
293
+ new_scores = _norm_scores(
294
+ data=load_match_template_output(filepath, background_paths[index]),
295
+ args=args,
296
+ )
280
297
  indices = new_scores > data[0]
281
298
  entities[indices] = index + 1
282
299
  data[0][indices] = new_scores[indices]
@@ -284,9 +301,18 @@ def merge_outputs(data, filepaths: List[str], args):
284
301
  return data, entities
285
302
 
286
303
 
304
+ def load_match_template_output(foreground_path, background_path):
305
+ data = load_pickle(foreground_path)
306
+ if background_path is not None:
307
+ data_background = load_pickle(background_path)
308
+ data[0] = (data[0] - data_background[0]) / (1 - data_background[0])
309
+ np.fmax(data[0], 0, out=data[0])
310
+ return data
311
+
312
+
287
313
  def main():
288
314
  args = parse_args()
289
- data = load_pickle(args.input_file[0])
315
+ data = load_match_template_output(args.input_file[0], args.background_file[0])
290
316
 
291
317
  target_origin, _, sampling_rate, cli_args = data[-1]
292
318
 
@@ -326,7 +352,12 @@ def main():
326
352
 
327
353
  entities = None
328
354
  if len(args.input_file) > 1:
329
- data, entities = merge_outputs(data=data, filepaths=args.input_file, args=args)
355
+ data, entities = merge_outputs(
356
+ data=data,
357
+ foreground_paths=args.input_file,
358
+ background_paths=args.background_file,
359
+ args=args,
360
+ )
330
361
 
331
362
  orientations = args.orientations
332
363
  if orientations is None:
@@ -339,24 +370,27 @@ def main():
339
370
  target_mask = Density.from_file(args.target_mask)
340
371
  scores = scores * target_mask.data
341
372
 
342
- if args.n_false_positives is not None:
343
- args.n_false_positives = max(args.n_false_positives, 1)
344
- cropped_shape = np.subtract(
345
- scores.shape, np.multiply(args.min_boundary_distance, 2)
346
- ).astype(int)
373
+ cropped_shape = np.subtract(
374
+ scores.shape, np.multiply(args.min_boundary_distance, 2)
375
+ ).astype(int)
347
376
 
348
- cropped_shape = tuple(
377
+ if args.min_boundary_distance > 0:
378
+ scores = centered_mask(scores, new_shape=cropped_shape)
379
+
380
+ if args.n_false_positives is not None:
381
+ # Rickgauer et al. 2017
382
+ cropped_slice = tuple(
349
383
  slice(
350
384
  int(args.min_boundary_distance),
351
385
  int(x - args.min_boundary_distance),
352
386
  )
353
387
  for x in scores.shape
354
388
  )
355
- # Rickgauer et al. 2017
356
- n_correlations = np.size(scores[cropped_shape]) * len(rotation_mapping)
389
+ args.n_false_positives = max(args.n_false_positives, 1)
390
+ n_correlations = np.size(scores[cropped_slice]) * len(rotation_mapping)
357
391
  minimum_score = np.multiply(
358
392
  erfcinv(2 * args.n_false_positives / n_correlations),
359
- np.sqrt(2) * np.std(scores[cropped_shape]),
393
+ np.sqrt(2) * np.std(scores[cropped_slice]),
360
394
  )
361
395
  print(f"Determined minimum score cutoff: {minimum_score}.")
362
396
  minimum_score = max(minimum_score, 0)
@@ -371,6 +405,8 @@ def main():
371
405
  "min_distance": args.min_distance,
372
406
  "min_boundary_distance": args.min_boundary_distance,
373
407
  "batch_dims": args.batch_dims,
408
+ "minimum_score": args.minimum_score,
409
+ "maximum_score": args.maximum_score,
374
410
  }
375
411
 
376
412
  peak_caller = PEAK_CALLERS[args.peak_caller](**peak_caller_kwargs)
@@ -380,7 +416,6 @@ def main():
380
416
  mask=template.data,
381
417
  rotation_mapping=rotation_mapping,
382
418
  rotation_array=rotation_array,
383
- minimum_score=args.minimum_score,
384
419
  )
385
420
  candidates = peak_caller.merge(
386
421
  candidates=[tuple(peak_caller)], **peak_caller_kwargs
@@ -388,10 +423,16 @@ def main():
388
423
  if len(candidates) == 0:
389
424
  candidates = [[], [], [], []]
390
425
  print("Found no peaks, consider changing peak calling parameters.")
391
- exit(0)
426
+ exit(-1)
392
427
 
393
428
  for translation, _, score, detail in zip(*candidates):
394
- rotations.append(rotation_mapping[rotation_array[tuple(translation)]])
429
+ rotation_index = rotation_array[tuple(translation)]
430
+ rotation = rotation_mapping.get(
431
+ rotation_index, np.zeros(template.data.ndim, int)
432
+ )
433
+ if rotation.ndim == 2:
434
+ rotation = euler_from_rotationmatrix(rotation)
435
+ rotations.append(rotation)
395
436
 
396
437
  else:
397
438
  candidates = data
@@ -430,7 +471,7 @@ def main():
430
471
  )
431
472
  exit(-1)
432
473
  orientations.translations = peak_caller.oversample_peaks(
433
- score_space=data[0],
474
+ scores=data[0],
434
475
  peak_positions=orientations.translations,
435
476
  oversampling_factor=args.peak_oversampling,
436
477
  )
@@ -570,7 +611,7 @@ def main():
570
611
  return_orientations=True,
571
612
  )
572
613
  out = np.zeros_like(template.data)
573
- out = np.zeros(np.multiply(template.shape, 2).astype(int))
614
+ # out = np.zeros(np.multiply(template.shape, 2).astype(int))
574
615
  for index in range(len(cand_slices)):
575
616
  from scipy.spatial.transform import Rotation
576
617
 
@@ -1,7 +1,5 @@
1
1
  #!python3
2
- """ Simplify picking adequate filtering and masking parameters using a GUI.
3
- Exposes tme.preprocessor.Preprocessor and tme.fitter_utils member functions
4
- to achieve this aim.
2
+ """ GUI for identifying adequate template matching filter and masks.
5
3
 
6
4
  Copyright (c) 2023 European Molecular Biology Laboratory
7
5
 
@@ -12,17 +10,20 @@ import argparse
12
10
  from typing import Tuple, Callable, List
13
11
  from typing_extensions import Annotated
14
12
 
13
+ import napari
15
14
  import numpy as np
16
15
  import pandas as pd
17
- import napari
16
+ from scipy.fft import next_fast_len
18
17
  from napari.layers import Image
19
18
  from napari.utils.events import EventedList
20
-
21
19
  from magicgui import widgets
22
20
  from qtpy.QtWidgets import QFileDialog
23
21
  from numpy.typing import NDArray
24
22
 
23
+ from tme.backends import backend
25
24
  from tme import Preprocessor, Density
25
+ from tme.preprocessing import BandPassFilter
26
+ from tme.preprocessing.tilt_series import CTF
26
27
  from tme.matching_utils import create_mask, load_pickle
27
28
 
28
29
  preprocessor = Preprocessor()
@@ -35,19 +36,57 @@ def gaussian_filter(template: NDArray, sigma: float, **kwargs: dict) -> NDArray:
35
36
 
36
37
  def bandpass_filter(
37
38
  template: NDArray,
38
- minimum_frequency: float,
39
- maximum_frequency: float,
40
- gaussian_sigma: float,
41
- **kwargs: dict,
39
+ lowpass_angstrom: float = 30,
40
+ highpass_angstrom: float = 140,
41
+ hard_edges: bool = False,
42
+ sampling_rate=None,
42
43
  ) -> NDArray:
43
- return preprocessor.bandpass_filter(
44
- template=template,
45
- minimum_frequency=minimum_frequency,
46
- maximum_frequency=maximum_frequency,
47
- sampling_rate=1,
48
- gaussian_sigma=gaussian_sigma,
49
- **kwargs,
44
+ bpf = BandPassFilter(
45
+ lowpass=lowpass_angstrom,
46
+ highpass=highpass_angstrom,
47
+ sampling_rate=np.max(sampling_rate),
48
+ use_gaussian=not hard_edges,
49
+ shape_is_real_fourier=True,
50
+ return_real_fourier=True,
50
51
  )
52
+ template_ft = np.fft.rfftn(template, s=template.shape)
53
+
54
+ mask = bpf(shape=template_ft.shape)["data"]
55
+ np.multiply(template_ft, mask, out=template_ft)
56
+ return np.fft.irfftn(template_ft, s=template.shape).real
57
+
58
+
59
+ def ctf_filter(
60
+ template: NDArray,
61
+ defocus_angstrom: float = 30000,
62
+ acceleration_voltage: float = 300,
63
+ spherical_aberration: float = 2.7,
64
+ amplitude_contrast: float = 0.07,
65
+ phase_shift: float = 0,
66
+ defocus_angle: float = 0,
67
+ sampling_rate=None,
68
+ flip_phase: bool = False,
69
+ ) -> NDArray:
70
+ fast_shape = [next_fast_len(x) for x in np.multiply(template.shape, 2)]
71
+ template_pad = backend.topleft_pad(template, fast_shape)
72
+ template_ft = np.fft.rfftn(template_pad, s=template_pad.shape)
73
+ ctf = CTF(
74
+ angles=[0],
75
+ shape=fast_shape,
76
+ defocus_x=[defocus_angstrom],
77
+ acceleration_voltage=acceleration_voltage * 1e3,
78
+ spherical_aberration=spherical_aberration * 1e7,
79
+ amplitude_contrast=amplitude_contrast,
80
+ phase_shift=[phase_shift],
81
+ defocus_angle=[defocus_angle],
82
+ sampling_rate=np.max(sampling_rate),
83
+ return_real_fourier=True,
84
+ flip_phase=flip_phase,
85
+ )
86
+ np.multiply(template_ft, ctf()["data"], out=template_ft)
87
+ template_pad = np.fft.irfftn(template_ft, s=template_pad.shape).real
88
+ template = backend.topleft_pad(template_pad, template.shape)
89
+ return template
51
90
 
52
91
 
53
92
  def difference_of_gaussian_filter(
@@ -109,61 +148,6 @@ def mean(
109
148
  return preprocessor.mean_filter(template=template, width=width)
110
149
 
111
150
 
112
- def resolution_sphere(
113
- template: NDArray,
114
- cutoff_angstrom: float,
115
- highpass: bool = False,
116
- sampling_rate=None,
117
- ) -> NDArray:
118
- if cutoff_angstrom == 0:
119
- return template
120
-
121
- cutoff_frequency = np.max(2 * sampling_rate / cutoff_angstrom)
122
-
123
- min_freq, max_freq = 0, cutoff_frequency
124
- if highpass:
125
- min_freq, max_freq = cutoff_frequency, 1e10
126
-
127
- mask = preprocessor.bandpass_mask(
128
- shape=template.shape,
129
- minimum_frequency=min_freq,
130
- maximum_frequency=max_freq,
131
- omit_negative_frequencies=False,
132
- )
133
-
134
- template_ft = np.fft.fftn(template)
135
- np.multiply(template_ft, mask, out=template_ft)
136
- return np.fft.ifftn(template_ft).real
137
-
138
-
139
- def resolution_gaussian(
140
- template: NDArray,
141
- cutoff_angstrom: float,
142
- highpass: bool = False,
143
- sampling_rate=None,
144
- ) -> NDArray:
145
- if cutoff_angstrom == 0:
146
- return template
147
-
148
- grid = preprocessor.fftfreqn(
149
- shape=template.shape, sampling_rate=sampling_rate / sampling_rate.max()
150
- )
151
-
152
- sigma_fourier = np.divide(
153
- np.max(2 * sampling_rate / cutoff_angstrom), np.sqrt(2 * np.log(2))
154
- )
155
-
156
- mask = np.exp(-(grid**2) / (2 * sigma_fourier**2))
157
- if highpass:
158
- mask = 1 - mask
159
-
160
- mask = np.fft.ifftshift(mask)
161
-
162
- template_ft = np.fft.fftn(template)
163
- np.multiply(template_ft, mask, out=template_ft)
164
- return np.fft.ifftn(template_ft).real
165
-
166
-
167
151
  def wedge(
168
152
  template: NDArray,
169
153
  tilt_start: float,
@@ -274,8 +258,7 @@ WRAPPED_FUNCTIONS = {
274
258
  "mean_filter": mean,
275
259
  "wedge_filter": wedge,
276
260
  "power_spectrum": compute_power_spectrum,
277
- "resolution_gaussian": resolution_gaussian,
278
- "resolution_sphere": resolution_sphere,
261
+ "ctf": ctf_filter,
279
262
  }
280
263
 
281
264
  EXCLUDED_FUNCTIONS = [
@@ -421,6 +404,7 @@ def sphere_mask(
421
404
  center_y: float,
422
405
  center_z: float,
423
406
  radius: float,
407
+ sigma_decay: float = 0,
424
408
  **kwargs,
425
409
  ) -> NDArray:
426
410
  return create_mask(
@@ -428,6 +412,7 @@ def sphere_mask(
428
412
  shape=template.shape,
429
413
  center=(center_x, center_y, center_z),
430
414
  radius=radius,
415
+ sigma_decay=sigma_decay,
431
416
  )
432
417
 
433
418
 
@@ -439,6 +424,7 @@ def ellipsod_mask(
439
424
  radius_x: float,
440
425
  radius_y: float,
441
426
  radius_z: float,
427
+ sigma_decay: float = 0,
442
428
  **kwargs,
443
429
  ) -> NDArray:
444
430
  return create_mask(
@@ -446,6 +432,7 @@ def ellipsod_mask(
446
432
  shape=template.shape,
447
433
  center=(center_x, center_y, center_z),
448
434
  radius=(radius_x, radius_y, radius_z),
435
+ sigma_decay=sigma_decay,
449
436
  )
450
437
 
451
438
 
@@ -457,6 +444,7 @@ def box_mask(
457
444
  height_x: int,
458
445
  height_y: int,
459
446
  height_z: int,
447
+ sigma_decay: float = 0,
460
448
  **kwargs,
461
449
  ) -> NDArray:
462
450
  return create_mask(
@@ -464,6 +452,7 @@ def box_mask(
464
452
  shape=template.shape,
465
453
  center=(center_x, center_y, center_z),
466
454
  height=(height_x, height_y, height_z),
455
+ sigma_decay=sigma_decay,
467
456
  )
468
457
 
469
458
 
@@ -476,6 +465,7 @@ def tube_mask(
476
465
  inner_radius: float,
477
466
  outer_radius: float,
478
467
  height: int,
468
+ sigma_decay: float = 0,
479
469
  **kwargs,
480
470
  ) -> NDArray:
481
471
  return create_mask(
@@ -486,6 +476,7 @@ def tube_mask(
486
476
  inner_radius=inner_radius,
487
477
  outer_radius=outer_radius,
488
478
  height=height,
479
+ sigma_decay=sigma_decay,
489
480
  )
490
481
 
491
482
 
@@ -533,13 +524,23 @@ def wedge_mask(
533
524
 
534
525
 
535
526
  def threshold_mask(
536
- template: NDArray, standard_deviation: float = 5.0, invert: bool = False, **kwargs
527
+ template: NDArray,
528
+ invert: bool = False,
529
+ standard_deviation: float = 5.0,
530
+ sigma: float = 0.0,
531
+ **kwargs,
537
532
  ) -> NDArray:
538
533
  template_mean = template.mean()
539
534
  template_deviation = standard_deviation * template.std()
540
535
  upper = template_mean + template_deviation
541
536
  lower = template_mean - template_deviation
542
- mask = np.logical_and(template > lower, template < upper)
537
+ mask = np.logical_or(template <= lower, template >= upper)
538
+
539
+ if sigma != 0:
540
+ mask_filter = preprocessor.gaussian_filter(template=mask * 1.0, sigma=sigma)
541
+ mask = np.add(mask, (1 - mask) * mask_filter)
542
+ mask[mask < np.exp(-np.square(sigma))] = 0
543
+
543
544
  if invert:
544
545
  np.invert(mask, out=mask)
545
546
 
@@ -890,6 +891,7 @@ class PointCloudWidget(widgets.Container):
890
891
 
891
892
  self.viewer = viewer
892
893
  self.dataframes = {}
894
+ self.selected_category = -1
893
895
 
894
896
  self.import_button = widgets.PushButton(
895
897
  name="Import", text="Import Point Cloud"
@@ -902,10 +904,98 @@ class PointCloudWidget(widgets.Container):
902
904
  self.export_button.clicked.connect(self._export_point_cloud)
903
905
  self.export_button.enabled = False
904
906
 
907
+ self.annotation_container = widgets.Container(name="Label", layout="horizontal")
908
+ self.positive_button = widgets.PushButton(name="Positive", text="Positive")
909
+ self.negative_button = widgets.PushButton(name="Negative", text="Negative")
910
+ self.positive_button.clicked.connect(self._set_positive)
911
+ self.negative_button.clicked.connect(self._set_negative)
912
+ self.annotation_container.append(self.positive_button)
913
+ self.annotation_container.append(self.negative_button)
914
+
915
+ self.face_color_select = widgets.ComboBox(
916
+ name="Color", choices=["Label", "Score"], value=None, nullable=True
917
+ )
918
+ self.face_color_select.changed.connect(self._update_face_color_mode)
919
+
905
920
  self.append(self.import_button)
906
921
  self.append(self.export_button)
922
+ self.append(self.annotation_container)
923
+ self.append(self.face_color_select)
924
+
907
925
  self.viewer.layers.selection.events.changed.connect(self._update_buttons)
908
926
 
927
+ self.viewer.layers.events.inserted.connect(self._initialize_points_layer)
928
+
929
+ def _update_face_color_mode(self, event: str = None):
930
+ for layer in self.viewer.layers:
931
+ if not isinstance(layer, napari.layers.Points):
932
+ continue
933
+
934
+ layer.face_color = "white"
935
+ if event == "Label":
936
+ if len(layer.properties.get("detail", ())) == 0:
937
+ continue
938
+ layer.face_color = "detail"
939
+ layer.face_color_cycle = {
940
+ -1: "grey",
941
+ 0: "red",
942
+ 1: "green",
943
+ }
944
+ layer.face_color_mode = "cycle"
945
+ elif event == "Score":
946
+ if len(layer.properties.get("score_scaled", ())) == 0:
947
+ continue
948
+ layer.face_color = "score_scaled"
949
+ layer.face_colormap = "turbo"
950
+ layer.face_color_mode = "colormap"
951
+
952
+ layer.refresh_colors()
953
+
954
+ return None
955
+
956
+ def _set_positive(self, event):
957
+ self.selected_category = 1 if self.selected_category != 1 else -1
958
+ self._update_annotation_buttons()
959
+
960
+ def _set_negative(self, event):
961
+ self.selected_category = 0 if self.selected_category != 0 else -1
962
+ self._update_annotation_buttons()
963
+
964
+ def _update_annotation_buttons(self):
965
+ selected_style = "background-color: darkgrey"
966
+ default_style = "background-color: none"
967
+
968
+ self.positive_button.native.setStyleSheet(
969
+ selected_style if self.selected_category == 1 else default_style
970
+ )
971
+ self.negative_button.native.setStyleSheet(
972
+ selected_style if self.selected_category == 0 else default_style
973
+ )
974
+
975
+ def _initialize_points_layer(self, event):
976
+ layer = event.value
977
+ if not isinstance(layer, napari.layers.Points):
978
+ return
979
+ if len(layer.properties) == 0:
980
+ layer.properties = {"detail": [-1]}
981
+
982
+ if "detail" not in layer.properties:
983
+ layer["detail"] = [-1]
984
+
985
+ layer.mouse_drag_callbacks.append(self._on_click)
986
+ return None
987
+
988
+ def _on_click(self, layer, event):
989
+ if layer.mode == "add":
990
+ layer.current_properties["detail"][-1] = self.selected_category
991
+ elif layer.mode == "select":
992
+ for index in layer.selected_data:
993
+ layer.properties["detail"][index] = self.selected_category
994
+
995
+ # TODO: Check whether current face color is the desired one already
996
+ self._update_face_color_mode(self.face_color_select.value)
997
+ layer.refresh_colors()
998
+
909
999
  def _update_buttons(self, event):
910
1000
  is_pointcloud = isinstance(
911
1001
  self.viewer.layers.selection.active, napari.layers.Points
@@ -951,9 +1041,7 @@ class PointCloudWidget(widgets.Container):
951
1041
 
952
1042
  if "score" in merged_data.columns:
953
1043
  merged_data["score"] = merged_data["score"].fillna(1)
954
- if "detail" in merged_data.columns:
955
- merged_data["detail"] = merged_data["detail"].fillna(2)
956
-
1044
+ merged_data["detail"] = layer.properties["detail"]
957
1045
  merged_data.to_csv(filename, sep="\t", index=False)
958
1046
 
959
1047
  def _get_load_path(self, event):
@@ -977,7 +1065,7 @@ class PointCloudWidget(widgets.Container):
977
1065
  dataframe["score"] = 1
978
1066
 
979
1067
  if "detail" not in dataframe.columns:
980
- dataframe["detail"] = -2
1068
+ dataframe["detail"] = -1
981
1069
 
982
1070
  point_properties = {
983
1071
  "score": np.array(dataframe["score"].values),
@@ -991,8 +1079,6 @@ class PointCloudWidget(widgets.Container):
991
1079
  points,
992
1080
  size=10,
993
1081
  properties=point_properties,
994
- face_color="score_scaled",
995
- face_colormap="turbo",
996
1082
  name=layer_name,
997
1083
  )
998
1084
  self.dataframes[layer_name] = dataframe
@@ -1025,9 +1111,14 @@ class MatchingWidget(widgets.Container):
1025
1111
  def _load_data(self, filename):
1026
1112
  data = load_pickle(filename)
1027
1113
 
1028
- _ = self.viewer.add_image(data=data[2], name="Rotations", colormap="orange")
1114
+ metadata = {"origin": data[-1][1], "sampling_rate": data[-1][2]}
1115
+ _ = self.viewer.add_image(
1116
+ data=data[2], name="Rotations", colormap="orange", metadata=metadata
1117
+ )
1029
1118
 
1030
- _ = self.viewer.add_image(data=data[0], name="Scores", colormap="turbo")
1119
+ _ = self.viewer.add_image(
1120
+ data=data[0], name="Scores", colormap="turbo", metadata=metadata
1121
+ )
1031
1122
 
1032
1123
 
1033
1124
  def main():
@@ -1045,11 +1136,10 @@ def main():
1045
1136
  widget=alignment_widget, name="Alignment", area="right"
1046
1137
  )
1047
1138
  viewer.window.add_dock_widget(widget=mask_widget, name="Mask", area="right")
1139
+ viewer.window.add_dock_widget(widget=export_widget, name="Export", area="right")
1048
1140
  viewer.window.add_dock_widget(widget=point_cloud, name="PointCloud", area="left")
1049
1141
  viewer.window.add_dock_widget(widget=matching_widget, name="Matching", area="left")
1050
1142
 
1051
- viewer.window.add_dock_widget(widget=export_widget, name="Export", area="right")
1052
-
1053
1143
  napari.run()
1054
1144
 
1055
1145