pytme 0.1.1__tar.gz → 0.1.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. {pytme-0.1.1 → pytme-0.1.3}/PKG-INFO +2 -2
  2. {pytme-0.1.1 → pytme-0.1.3}/pyproject.toml +2 -2
  3. {pytme-0.1.1 → pytme-0.1.3}/pytme.egg-info/SOURCES.txt +1 -0
  4. {pytme-0.1.1 → pytme-0.1.3}/scripts/match_template.py +10 -9
  5. {pytme-0.1.1 → pytme-0.1.3}/scripts/postprocess.py +6 -3
  6. {pytme-0.1.1 → pytme-0.1.3}/scripts/preprocessor_gui.py +93 -14
  7. pytme-0.1.3/tme/__version__.py +1 -0
  8. {pytme-0.1.1 → pytme-0.1.3}/tme/analyzer.py +2 -2
  9. {pytme-0.1.1 → pytme-0.1.3}/tme/backends/pytorch_backend.py +7 -4
  10. {pytme-0.1.1 → pytme-0.1.3}/tme/density.py +22 -13
  11. {pytme-0.1.1 → pytme-0.1.3}/tme/matching_data.py +2 -4
  12. {pytme-0.1.1 → pytme-0.1.3}/tme/matching_exhaustive.py +0 -2
  13. {pytme-0.1.1 → pytme-0.1.3}/tme/matching_memory.py +1 -1
  14. {pytme-0.1.1 → pytme-0.1.3}/tme/matching_optimization.py +5 -0
  15. {pytme-0.1.1 → pytme-0.1.3}/tme/matching_utils.py +7 -1
  16. {pytme-0.1.1 → pytme-0.1.3}/tme/preprocessor.py +62 -8
  17. pytme-0.1.3/tme/scoring.py +679 -0
  18. {pytme-0.1.1 → pytme-0.1.3}/tme/structure.py +19 -20
  19. pytme-0.1.1/tme/__version__.py +0 -1
  20. {pytme-0.1.1 → pytme-0.1.3}/LICENSE +0 -0
  21. {pytme-0.1.1 → pytme-0.1.3}/MANIFEST.in +0 -0
  22. {pytme-0.1.1 → pytme-0.1.3}/README.md +0 -0
  23. {pytme-0.1.1 → pytme-0.1.3}/scripts/estimate_ram_usage.py +0 -0
  24. {pytme-0.1.1 → pytme-0.1.3}/scripts/preprocess.py +0 -0
  25. {pytme-0.1.1 → pytme-0.1.3}/setup.cfg +0 -0
  26. {pytme-0.1.1 → pytme-0.1.3}/setup.py +0 -0
  27. {pytme-0.1.1 → pytme-0.1.3}/src/extensions.cpp +0 -0
  28. {pytme-0.1.1 → pytme-0.1.3}/tme/__init__.py +0 -0
  29. {pytme-0.1.1 → pytme-0.1.3}/tme/backends/__init__.py +0 -0
  30. {pytme-0.1.1 → pytme-0.1.3}/tme/backends/cupy_backend.py +0 -0
  31. {pytme-0.1.1 → pytme-0.1.3}/tme/backends/matching_backend.py +0 -0
  32. {pytme-0.1.1 → pytme-0.1.3}/tme/backends/npfftw_backend.py +0 -0
  33. {pytme-0.1.1 → pytme-0.1.3}/tme/data/__init__.py +0 -0
  34. {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48n309.npy +0 -0
  35. {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48n527.npy +0 -0
  36. {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48n9.npy +0 -0
  37. {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u1.npy +0 -0
  38. {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u1153.npy +0 -0
  39. {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u1201.npy +0 -0
  40. {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u1641.npy +0 -0
  41. {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u181.npy +0 -0
  42. {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u2219.npy +0 -0
  43. {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u27.npy +0 -0
  44. {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u2947.npy +0 -0
  45. {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u3733.npy +0 -0
  46. {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u4749.npy +0 -0
  47. {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u5879.npy +0 -0
  48. {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u7111.npy +0 -0
  49. {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u815.npy +0 -0
  50. {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u83.npy +0 -0
  51. {pytme-0.1.1 → pytme-0.1.3}/tme/data/c48u8649.npy +0 -0
  52. {pytme-0.1.1 → pytme-0.1.3}/tme/data/c600v.npy +0 -0
  53. {pytme-0.1.1 → pytme-0.1.3}/tme/data/c600vc.npy +0 -0
  54. {pytme-0.1.1 → pytme-0.1.3}/tme/data/metadata.yaml +0 -0
  55. {pytme-0.1.1 → pytme-0.1.3}/tme/data/quat_to_numpy.py +0 -0
  56. {pytme-0.1.1 → pytme-0.1.3}/tme/helpers.py +0 -0
  57. {pytme-0.1.1 → pytme-0.1.3}/tme/parser.py +0 -0
  58. {pytme-0.1.1 → pytme-0.1.3}/tme/types.py +0 -0
@@ -1,11 +1,11 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pytme
3
- Version: 0.1.1
3
+ Version: 0.1.3
4
4
  Summary: Python Template Matching Engine
5
5
  Author: Valentin Maurer
6
6
  Author-email: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
7
  License: Proprietary
8
- Project-URL: Homepage, https://git.embl.de/vmaurer/pytme
8
+ Project-URL: Homepage, https://github.com/KosinskiLab/pyTME
9
9
  Classifier: Programming Language :: Python :: 3
10
10
  Classifier: Operating System :: OS Independent
11
11
  Requires-Python: >=3.11
@@ -7,7 +7,7 @@ name="pytme"
7
7
  authors = [
8
8
  { name = "Valentin Maurer", email = "valentin.maurer@embl-hamburg.de" },
9
9
  ]
10
- version="0.1.1"
10
+ version="0.1.3"
11
11
  description="Python Template Matching Engine"
12
12
  readme="README.md"
13
13
  requires-python = ">=3.11"
@@ -38,7 +38,7 @@ preprocess = "tme.scripts:preprocess"
38
38
  postprocess = "tme.scripts:postprocess"
39
39
 
40
40
  [project.urls]
41
- "Homepage" = "https://git.embl.de/vmaurer/pytme"
41
+ "Homepage" = "https://github.com/KosinskiLab/pyTME"
42
42
 
43
43
  [tool.setuptools]
44
44
  include-package-data = true
@@ -21,6 +21,7 @@ tme/matching_optimization.py
21
21
  tme/matching_utils.py
22
22
  tme/parser.py
23
23
  tme/preprocessor.py
24
+ tme/scoring.py
24
25
  tme/structure.py
25
26
  tme/types.py
26
27
  tme/backends/__init__.py
@@ -43,7 +43,7 @@ def print_block(name: str, data: dict, label_width=20) -> None:
43
43
  """Prints a formatted block of information."""
44
44
  print(f"\n> {name}")
45
45
  for key, value in data.items():
46
- formatted_value = str(value) # Convert non-string values to string
46
+ formatted_value = str(value)
47
47
  print(f" - {key + ':':<{label_width}} {formatted_value}")
48
48
 
49
49
 
@@ -426,7 +426,7 @@ def main():
426
426
 
427
427
  if not np.allclose(target.sampling_rate, template.sampling_rate):
428
428
  print(
429
- f"Resampling template to {target.sampling_rate}."
429
+ f"Resampling template to {target.sampling_rate}. "
430
430
  "Consider providing a template with the same sampling rate as the target."
431
431
  )
432
432
  template = template.resample(target.sampling_rate, order=3)
@@ -506,7 +506,7 @@ def main():
506
506
  -tilt_start, tilt_stop + args.tilt_step, args.tilt_step
507
507
  )
508
508
  angles = np.zeros((template.data.ndim, tilt_angles.size))
509
- angles[1, :] = tilt_angles
509
+ angles[2, :] = tilt_angles
510
510
  template_filter["wedge_mask"] = {
511
511
  "tilt_angles": angles,
512
512
  "sigma": args.wedge_smooth,
@@ -516,6 +516,7 @@ def main():
516
516
  "start_tilt": tilt_start,
517
517
  "stop_tilt": tilt_stop,
518
518
  "tilt_axis": 1,
519
+ "infinite_plane": True,
519
520
  "sigma": args.wedge_smooth,
520
521
  }
521
522
 
@@ -597,6 +598,10 @@ def main():
597
598
  if not args.pad_fourier:
598
599
  template_box = np.ones(len(template_box), dtype=int)
599
600
 
601
+ callback_class = MaxScoreOverRotations
602
+ if args.peak_calling:
603
+ callback_class = PeakCallerMaximumFilter
604
+
600
605
  splits, schedule = compute_parallelization_schedule(
601
606
  shape1=target.shape,
602
607
  shape2=template_box,
@@ -605,7 +610,7 @@ def main():
605
610
  max_ram=args.ram,
606
611
  split_only_outer=args.use_gpu,
607
612
  matching_method=args.score,
608
- analyzer_method="MaxScoreOverRotations",
613
+ analyzer_method=callback_class.__name__,
609
614
  backend=backend._backend_name,
610
615
  float_nbytes=backend.datatype_bytes(backend._default_dtype),
611
616
  complex_nbytes=backend.datatype_bytes(backend._complex_dtype),
@@ -627,16 +632,12 @@ def main():
627
632
  }
628
633
 
629
634
  matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
630
- callback_class = MaxScoreOverRotations
631
- if args.peak_calling:
632
- callback_class = PeakCallerMaximumFilter
633
-
634
635
  matching_data = MatchingData(target=target, template=template.data)
635
636
  matching_data.rotations = get_rotation_matrices(
636
637
  angular_sampling=args.angular_sampling, dim=target.data.ndim
637
638
  )
638
- matching_data.template_filter = template_filter
639
639
 
640
+ matching_data.template_filter = template_filter
640
641
  if target_mask is not None:
641
642
  matching_data.target_mask = target_mask
642
643
  if template_mask is not None:
@@ -100,9 +100,12 @@ def main():
100
100
  orientations.append((translation, angles, score, detail))
101
101
  else:
102
102
  candidates = data
103
- for translation, rotation, score, detail in zip(*candidates):
104
- angles = euler_from_rotationmatrix(rotation)
105
- orientations.append((translation, angles, score, detail))
103
+ translation, rotation, score, detail, *_ = data
104
+ for i in range(translation.shape[0]):
105
+ angles = euler_from_rotationmatrix(rotation[i])
106
+ orientations.append(
107
+ (np.array(translation[i]), angles, score[i], detail[i])
108
+ )
106
109
  else:
107
110
  with open(args.orientations, mode="r", encoding="utf-8") as infile:
108
111
  data = [x.strip().split("\t") for x in infile.read().split("\n")]
@@ -108,6 +108,31 @@ def mean(
108
108
  return preprocessor.mean_filter(template=template, width=width)
109
109
 
110
110
 
111
+ def wedge(
112
+ template: NDArray,
113
+ tilt_start: float,
114
+ tilt_stop: float,
115
+ gaussian_sigma: float,
116
+ tilt_axis: int = 1,
117
+ infinite_plane: bool = True,
118
+ extrude_plane: bool = True,
119
+ ):
120
+ template_ft = np.fft.rfftn(template)
121
+ wedge_mask = preprocessor.continuous_wedge_mask(
122
+ start_tilt=tilt_start,
123
+ stop_tilt=tilt_stop,
124
+ tilt_axis=tilt_axis,
125
+ shape=template.shape,
126
+ sigma=gaussian_sigma,
127
+ omit_negative_frequencies=True,
128
+ infinite_plane=infinite_plane,
129
+ extrude_plane=extrude_plane,
130
+ )
131
+ np.multiply(template_ft, wedge_mask, out=template_ft)
132
+ template = np.real(np.fft.irfftn(template_ft))
133
+ return template
134
+
135
+
111
136
  def widgets_from_function(function: Callable, exclude_params: List = ["self"]):
112
137
  """
113
138
  Creates list of magicui widgets by inspecting function typing ann
@@ -166,7 +191,8 @@ WRAPPED_FUNCTIONS = {
166
191
  "ntree_filter": ntree,
167
192
  "local_gaussian_filter": local_gaussian_filter,
168
193
  "difference_of_gaussian_filter": difference_of_gaussian_filter,
169
- "mean_filter" : mean,
194
+ "mean_filter": mean,
195
+ "continuous_wedge_mask": wedge,
170
196
  }
171
197
 
172
198
  EXCLUDED_FUNCTIONS = [
@@ -178,7 +204,7 @@ EXCLUDED_FUNCTIONS = [
178
204
  "interpolate_box",
179
205
  "molmap",
180
206
  "local_gaussian_alignment_filter",
181
- "continuous_wedge_mask",
207
+ # "continuous_wedge_mask",
182
208
  "wedge_mask",
183
209
  "bandpass_mask",
184
210
  ]
@@ -372,19 +398,22 @@ def wedge_mask(
372
398
  tilt_stop: float,
373
399
  gaussian_sigma: float,
374
400
  tilt_axis: int = 1,
401
+ omit_negative_frequencies: bool = True,
402
+ extrude_plane: bool = True,
403
+ infinite_plane: bool = True,
375
404
  ):
376
- template_ft = np.fft.fftn(template)
377
405
  wedge_mask = preprocessor.continuous_wedge_mask(
378
406
  start_tilt=tilt_start,
379
407
  stop_tilt=tilt_stop,
380
408
  tilt_axis=tilt_axis,
381
- shape=template_ft.shape,
409
+ shape=template.shape,
382
410
  sigma=gaussian_sigma,
411
+ omit_negative_frequencies=omit_negative_frequencies,
412
+ extrude_plane=extrude_plane,
413
+ infinite_plane=infinite_plane,
383
414
  )
384
-
385
- np.multiply(template_ft, wedge_mask, out=template_ft)
386
- template = np.real(np.fft.ifftn(template_ft))
387
- return template
415
+ wedge_mask = np.fft.fftshift(wedge_mask)
416
+ return wedge_mask
388
417
 
389
418
 
390
419
  class MaskWidget(widgets.Container):
@@ -410,34 +439,73 @@ class MaskWidget(widgets.Container):
410
439
  )
411
440
  self.method_dropdown.changed.connect(self._on_method_changed)
412
441
 
413
- self.adapt_button = widgets.PushButton(
414
- text="Adapt to current layer", enabled=False
415
- )
442
+ self.adapt_button = widgets.PushButton(text="Adapt to layer", enabled=False)
416
443
  self.adapt_button.changed.connect(self._update_initial_values)
417
444
 
418
445
  self.viewer.layers.selection.events.active.connect(
419
446
  self._update_action_button_state
420
447
  )
421
448
 
449
+ self.align_button = widgets.PushButton(text="Align to axis", enabled=False)
450
+ self.align_button.changed.connect(self._align_with_axis)
451
+ self.density_field = widgets.Label()
452
+ # self.density_field.value = f"Positive Density in Mask: {0:.2f}%"
453
+
422
454
  self.append(self.method_dropdown)
423
455
  self.append(self.adapt_button)
456
+ self.append(self.align_button)
424
457
  self.append(self.action_button)
458
+ self.append(self.density_field)
425
459
 
426
460
  # Create GUI for initially selected filtering method
427
461
  self._on_method_changed(None)
428
462
 
429
463
  def _update_action_button_state(self, event):
464
+ self.align_button.enabled = bool(self.viewer.layers.selection.active)
430
465
  self.action_button.enabled = bool(self.viewer.layers.selection.active)
431
466
  self.adapt_button.enabled = bool(self.viewer.layers.selection.active)
432
467
 
468
+ def _align_with_axis(self):
469
+ active_layer = self.viewer.layers.selection.active
470
+
471
+ if active_layer.metadata.get("is_aligned", False):
472
+ return
473
+
474
+ coords = np.array(np.where(active_layer.data > 0)).T
475
+ centered_coords = coords - np.mean(coords, axis=0)
476
+ cov_matrix = np.cov(centered_coords, rowvar=False)
477
+
478
+ eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
479
+ principal_eigenvector = eigenvectors[:, np.argmax(eigenvalues)]
480
+
481
+ rotation_axis = np.cross(principal_eigenvector, [1, 0, 0])
482
+ rotation_angle = np.arccos(np.dot(principal_eigenvector, [1, 0, 0]))
483
+ k = rotation_axis / np.linalg.norm(rotation_axis)
484
+ K = np.array([[0, -k[2], k[1]], [k[2], 0, -k[0]], [-k[1], k[0], 0]])
485
+ rotation_matrix = np.eye(3)
486
+ rotation_matrix += np.sin(rotation_angle) * K
487
+ rotation_matrix += (1 - np.cos(rotation_angle)) * np.dot(K, K)
488
+
489
+ rotated_data = Density.rotate_array(
490
+ arr=active_layer.data,
491
+ rotation_matrix=rotation_matrix,
492
+ use_geometric_center=False,
493
+ )
494
+ eps = np.finfo(rotated_data.dtype).eps
495
+ rotated_data[rotated_data < eps] = 0
496
+
497
+ active_layer.metadata["is_aligned"] = True
498
+ active_layer.data = rotated_data
499
+
433
500
  def _update_initial_values(self, event=None):
434
501
  active_layer = self.viewer.layers.selection.active
435
- center_of_mass = Density.center_of_mass(np.abs(active_layer.data))
502
+ center_of_mass = Density.center_of_mass(np.abs(active_layer.data), 0)
436
503
  coordinates = np.array(np.where(active_layer.data > 0))
437
504
  coordinates_min = coordinates.min(axis=1)
438
505
  coordinates_max = coordinates.max(axis=1)
439
506
  coordinates_heights = coordinates_max - coordinates_min
440
507
  coordinate_radius = np.divide(coordinates_heights, 2)
508
+ center_of_mass = coordinate_radius + coordinates_min
441
509
 
442
510
  defaults = dict(zip(["center_x", "center_y", "center_z"], center_of_mass))
443
511
  defaults.update(
@@ -465,7 +533,7 @@ class MaskWidget(widgets.Container):
465
533
  widgets = widgets_from_function(function)
466
534
  for widget in widgets:
467
535
  self.action_widgets.append(widget)
468
- self.insert(-2, widget)
536
+ self.insert(1, widget)
469
537
 
470
538
  def _action(self):
471
539
  function = self.methods.get(self.method_dropdown.value)
@@ -479,7 +547,8 @@ class MaskWidget(widgets.Container):
479
547
  selected_layer = self.viewer.layers[new_layer_name]
480
548
 
481
549
  processed_data = processed_data.astype(np.float32)
482
- mask = selected_layer.metadata.get("mask", False)
550
+ metadata = selected_layer.metadata
551
+ mask = metadata.get("mask", False)
483
552
  if mask == self.method_dropdown.value:
484
553
  selected_layer.data = processed_data
485
554
  else:
@@ -490,8 +559,18 @@ class MaskWidget(widgets.Container):
490
559
  metadata = selected_layer.metadata.copy()
491
560
  metadata["filter_parameters"] = {self.method_dropdown.value: kwargs.copy()}
492
561
  metadata["mask"] = self.method_dropdown.value
562
+ metadata["origin_layer"] = selected_layer.name
493
563
  new_layer.metadata = metadata
494
564
 
565
+ origin_layer = metadata["origin_layer"]
566
+ if origin_layer in self.viewer.layers:
567
+ origin_layer = self.viewer.layers[origin_layer]
568
+ if np.allclose(origin_layer.data.shape, processed_data.shape):
569
+ in_mask = np.sum(np.fmax(origin_layer.data * processed_data, 0))
570
+ in_mask /= np.sum(np.fmax(origin_layer.data, 0))
571
+ in_mask *= 100
572
+ self.density_field.value = f"Positive Density in Mask: {in_mask:.2f}%"
573
+
495
574
 
496
575
  class ExportWidget(widgets.Container):
497
576
  def __init__(self, viewer):
@@ -0,0 +1 @@
1
+ __version__ = "0.1.3"
@@ -121,7 +121,7 @@ class PeakCaller(ABC):
121
121
  fourier_shift = kwargs.get(
122
122
  "fourier_shift", backend.zeros(peak_positions.shape[1], dtype=int)
123
123
  )
124
- if np.any(fourier_shift) != 0:
124
+ if backend.sum(fourier_shift != 0) != 0:
125
125
  peak_positions = backend.mod(
126
126
  backend.add(peak_positions, fourier_shift), score_space.shape
127
127
  )
@@ -197,6 +197,7 @@ class PeakCaller(ABC):
197
197
  if len(candidate) == 0:
198
198
  continue
199
199
  peak_positions, rotations, peak_scores, peak_details = candidate
200
+ kwargs["translation_offset"] = backend.zeros(peak_positions.shape[1])
200
201
  base._update(
201
202
  peak_positions=backend.to_backend_array(peak_positions),
202
203
  peak_details=backend.to_backend_array(peak_details),
@@ -237,7 +238,6 @@ class PeakCaller(ABC):
237
238
  translation_offset = backend.astype(translation_offset, peak_positions.dtype)
238
239
 
239
240
  backend.add(peak_positions, translation_offset, out=peak_positions)
240
-
241
241
  if not len(self.peak_list):
242
242
  self.peak_list = [peak_positions, rotations, peak_scores, peak_details]
243
243
  return None
@@ -491,8 +491,11 @@ class PytorchBackend(NumpyFFTWBackend):
491
491
  Operates as a context manager, yielding None and providing
492
492
  the set GPU context for enclosed operations.
493
493
  """
494
- with self._array_backend.cuda.device(device_index):
495
- yield
494
+ if self.device == "cuda":
495
+ with self._array_backend.cuda.device(device_index):
496
+ yield
497
+ else:
498
+ yield None
496
499
 
497
500
  def device_count(self) -> int:
498
501
  """
@@ -505,7 +508,7 @@ class PytorchBackend(NumpyFFTWBackend):
505
508
  """
506
509
  return self._array_backend.cuda.device_count()
507
510
 
508
- def reverse(arr: TorchTensor) -> TorchTensor:
511
+ def reverse(self, arr: TorchTensor) -> TorchTensor:
509
512
  """
510
513
  Reverse the order of elements in a tensor along all its axes.
511
514
 
@@ -519,4 +522,4 @@ class PytorchBackend(NumpyFFTWBackend):
519
522
  TorchTensor
520
523
  Reversed tensor.
521
524
  """
522
- return self._array_backend.flip(arr, [i for i in range(arr.dim())])
525
+ return self._array_backend.flip(arr, [i for i in range(arr.ndim)])
@@ -46,10 +46,10 @@ class Density:
46
46
  ----------
47
47
  data : NDArray
48
48
  Electron density data.
49
- origin : NDArray
50
- Origin of the coordinate system.
51
- sampling_rate : NDArray
52
- Sampling rate along data axis.
49
+ origin : NDArray, optional
50
+ Origin of the coordinate system. Defaults to zero.
51
+ sampling_rate : NDArray, optional
52
+ Sampling rate along data axis. Defaults to one.
53
53
  metadata : dict, optional
54
54
  Dictionary with metadata information, empty by default.
55
55
 
@@ -62,16 +62,18 @@ class Density:
62
62
  --------
63
63
  >>> import numpy as np
64
64
  >>> data = np.random.rand(50,50,50)
65
- >>> Density(data = data, origin = (0, 0, 0), sampling_rate = (0, 0, 0))
65
+ >>> Density(data = data, origin = (0, 0, 0), sampling_rate = (1, 1, 1))
66
66
  """
67
67
 
68
68
  def __init__(
69
69
  self,
70
70
  data: NDArray,
71
- origin: NDArray,
72
- sampling_rate: NDArray,
71
+ origin: NDArray = None,
72
+ sampling_rate: NDArray = None,
73
73
  metadata: Dict = {},
74
74
  ):
75
+ origin = 0 if origin is None else origin
76
+ sampling_rate = 1 if sampling_rate is None else sampling_rate
75
77
  sampling_rate, origin = np.asarray(sampling_rate), np.asarray(origin)
76
78
  sampling_rate = np.repeat(sampling_rate, data.ndim // sampling_rate.size)
77
79
 
@@ -127,7 +129,7 @@ class Density:
127
129
 
128
130
  Examples
129
131
  --------
130
- >>> density = Density.from_file("/path/to/mrc")
132
+ >>> density = Density.from_file("/path/to/file")
131
133
 
132
134
  Notes
133
135
  -----
@@ -548,8 +550,9 @@ class Density:
548
550
  Which weight should be given to individual atoms. For valid values
549
551
  see :py:meth:`Structure.from_file`.
550
552
  chain : str, optional
551
- Which chain of the protein should be considered. Default value
552
- corresponds to using all chains.
553
+ The chain identifier. If multiple chains should be selected they need
554
+ to be a comma separated string, e.g. 'A,B,CE'. If chain None,
555
+ all chains are returned. Default is None.
553
556
  filter_by_elements : set, optional
554
557
  Set of atomic elements to keep. Default is all atoms.
555
558
  filter_by_residues : set, optional
@@ -876,7 +879,6 @@ class Density:
876
879
  box_start = np.array([b.start for b in box])
877
880
  box_stop = np.array([b.stop for b in box])
878
881
  left_pad = -np.minimum(box_start, np.zeros(len(box), dtype=int))
879
- has_extension = box_start < 0
880
882
 
881
883
  right_pad = box_stop - box_start * (box_start > 0)
882
884
  right_pad -= np.array(self.shape, dtype=int)
@@ -905,8 +907,6 @@ class Density:
905
907
  self.data = self.data[crop_box].copy()
906
908
 
907
909
  # In case the box is larger than the current map
908
- before_shape = self.data.shape
909
- after_shape = [b.stop - b.start for b in box]
910
910
  self.data = self._pad_slice(box)
911
911
 
912
912
  # Adjust the origin
@@ -934,10 +934,19 @@ class Density:
934
934
  tuple
935
935
  A tuple containing slice objects representing the box.
936
936
 
937
+ Raises
938
+ ------
939
+ ValueError
940
+ If the cutoff is larger than or equal to the maximum density value.
941
+
937
942
  See Also
938
943
  --------
939
944
  :py:meth:`Density.adjust_box`
940
945
  """
946
+ if cutoff >= self.data.max():
947
+ raise ValueError(
948
+ f"Cutoff exceeds data range ({cutoff} >= {self.data.max()})."
949
+ )
941
950
  starts, stops = [], []
942
951
  for axis in range(self.data.ndim):
943
952
  projected_max = np.max(
@@ -116,7 +116,6 @@ class MatchingData:
116
116
  ).astype(int)
117
117
 
118
118
  ret_shape = np.add(slice_shape, padding)
119
-
120
119
  arr_start = np.subtract(slice_start, data_voxels_left)
121
120
  arr_stop = np.add(slice_stop, data_voxels_right)
122
121
  arr_slice = tuple(slice(*pos) for pos in zip(arr_start, arr_stop))
@@ -138,7 +137,6 @@ class MatchingData:
138
137
  arr.filename, mode="r", shape=arr.shape, dtype=arr.dtype
139
138
  )
140
139
  arr = np.asarray(arr[*arr_mesh])
141
-
142
140
  ret = np.full(
143
141
  shape=np.add(slice_shape, padding), fill_value=arr.mean(), dtype=arr.dtype
144
142
  )
@@ -188,8 +186,8 @@ class MatchingData:
188
186
  template_pad = np.zeros(len(self.target.shape), dtype=int)
189
187
 
190
188
  indices = compute_full_convolution_index(
191
- outer_shape=self.target.shape,
192
- inner_shape=self.template.shape,
189
+ outer_shape=self._target.shape,
190
+ inner_shape=self._template.shape,
193
191
  outer_split=target_slice,
194
192
  inner_split=template_slice,
195
193
  )
@@ -1136,7 +1136,6 @@ def mcc_scoring(
1136
1136
  mask_overlap, axis=axes, keepdims=True
1137
1137
  )
1138
1138
  temp[mask_overlap < number_px_threshold] = 0.0
1139
-
1140
1139
  convolution_mode = kwargs.get("convolution_mode", "full")
1141
1140
  score = apply_convolution_mode(
1142
1141
  temp, convolution_mode=convolution_mode, s1=targetshape, s2=templateshape
@@ -1165,7 +1164,6 @@ def device_memory_handler(func: Callable):
1165
1164
  return_value = func(shared_memory_handler=smh, *args, **kwargs)
1166
1165
  except Exception as e:
1167
1166
  print(e)
1168
- return None
1169
1167
  last_type, last_value, last_traceback = sys.exc_info()
1170
1168
  finally:
1171
1169
  handle_traceback(last_type, last_value, last_traceback)
@@ -327,7 +327,7 @@ def estimate_ram_usage(
327
327
  """
328
328
  if matching_method not in MATCHING_MEMORY_REGISTRY:
329
329
  raise ValueError(
330
- f"Supported fitters are {','.join(MATCHING_MEMORY_REGISTRY.keys())}"
330
+ f"Supported options are {','.join(MATCHING_MEMORY_REGISTRY.keys())}"
331
331
  )
332
332
 
333
333
  convolution_shape, fast_shape, ft_shape = _compute_convolution_shapes(
@@ -246,6 +246,7 @@ class CrossCorrelation(MatchCoordinatesToDensity):
246
246
  .. math::
247
247
 
248
248
  \\text{score} = \\text{target_weights} \\cdot \\text{template_weights}
249
+
249
250
  """
250
251
 
251
252
  def __init__(self, **kwargs):
@@ -298,6 +299,7 @@ class LaplaceCrossCorrelation(CrossCorrelation):
298
299
 
299
300
  \\text{score} = \\nabla^{2} \\text{target_weights} \\cdot
300
301
  \\nabla^{2} \\text{template_weights}
302
+
301
303
  """
302
304
 
303
305
  def __init__(self, **kwargs):
@@ -617,6 +619,7 @@ class Chamfer(MatchCoordinatesToCoordinates):
617
619
  -------
618
620
  float
619
621
  The negative of the Chamfer distance score.
622
+
620
623
  """
621
624
  dist, _ = self.target_tree.query(self.template_coordinates_rotated.T)
622
625
  score = np.mean(dist)
@@ -638,6 +641,7 @@ class MutualInformation(MatchCoordinatesToDensity):
638
641
  .. [1] Daven Vasishtan and Maya Topf, "Scoring functions for cryoEM density
639
642
  fitting", Journal of Structural Biology, vol. 174, no. 2,
640
643
  pp. 333--343, 2011. DOI: https://doi.org/10.1016/j.jsb.2011.01.012
644
+
641
645
  """
642
646
 
643
647
  def __init__(self, **kwargs):
@@ -776,6 +780,7 @@ class NormalVectorScore(MatchCoordinatesToCoordinates):
776
780
  .. [1] Daven Vasishtan and Maya Topf, "Scoring functions for cryoEM density
777
781
  fitting", Journal of Structural Biology, vol. 174, no. 2,
778
782
  pp. 333--343, 2011. DOI: https://doi.org/10.1016/j.jsb.2011.01.012
783
+
779
784
  """
780
785
 
781
786
  def __init__(self, **kwargs):
@@ -762,6 +762,8 @@ def euler_to_rotationmatrix(angles: Tuple[float]) -> NDArray:
762
762
  NDArray
763
763
  The generated rotation matrix.
764
764
  """
765
+ if len(angles) == 1:
766
+ angles = (angles, 0, 0)
765
767
  rotation_matrix = (
766
768
  Rotation.from_euler("zyx", angles, degrees=True).as_matrix().astype(np.float32)
767
769
  )
@@ -775,13 +777,17 @@ def euler_from_rotationmatrix(rotation_matrix: NDArray) -> Tuple:
775
777
  Parameters
776
778
  ----------
777
779
  rotation_matrix : NDArray
778
- A 3 x 3 rotation matrix in z y x form.
780
+ A 2 x 2 or 3 x 3 rotation matrix in z y x form.
779
781
 
780
782
  Returns
781
783
  -------
782
784
  Tuple
783
785
  The generate euler angles in degrees
784
786
  """
787
+ if rotation_matrix.shape[0] == 2:
788
+ temp_matrix = np.eye(3)
789
+ temp_matrix[:2, :2] = rotation_matrix
790
+ rotation_matrix = temp_matrix
785
791
  euler_angles = (
786
792
  Rotation.from_matrix(rotation_matrix)
787
793
  .as_euler("zyx", degrees=True)