pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/match_template.py +147 -93
  2. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/postprocess.py +67 -26
  3. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +175 -85
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +20 -13
  8. scripts/match_template.py +147 -93
  9. scripts/match_template_filters.py +154 -95
  10. scripts/postprocess.py +67 -26
  11. scripts/preprocessor_gui.py +175 -85
  12. scripts/refine_matches.py +265 -61
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +451 -809
  16. tme/backends/__init__.py +40 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +111 -223
  19. tme/backends/jax_backend.py +214 -150
  20. tme/backends/matching_backend.py +445 -384
  21. tme/backends/mlx_backend.py +32 -59
  22. tme/backends/npfftw_backend.py +239 -507
  23. tme/backends/pytorch_backend.py +21 -145
  24. tme/density.py +233 -363
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +322 -285
  27. tme/matching_exhaustive.py +172 -1493
  28. tme/matching_optimization.py +143 -106
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +280 -386
  31. tme/memory.py +377 -0
  32. tme/orientations.py +52 -12
  33. tme/parser.py +3 -4
  34. tme/preprocessing/_utils.py +61 -32
  35. tme/preprocessing/compose.py +7 -3
  36. tme/preprocessing/frequency_filters.py +49 -39
  37. tme/preprocessing/tilt_series.py +34 -40
  38. tme/preprocessor.py +560 -526
  39. tme/structure.py +491 -188
  40. tme/types.py +5 -3
  41. pytme-0.2.1.dist-info/METADATA +0 -73
  42. pytme-0.2.1.dist-info/RECORD +0 -73
  43. tme/helpers.py +0 -881
  44. tme/matching_constrained.py +0 -195
  45. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  46. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  47. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  48. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  49. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
tme/matching_utils.py CHANGED
@@ -5,7 +5,7 @@
5
5
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
6
  """
7
7
  import os
8
- import traceback
8
+ import yaml
9
9
  import pickle
10
10
  from shutil import move
11
11
  from tempfile import mkstemp
@@ -14,47 +14,100 @@ from typing import Tuple, Dict, Callable
14
14
  from concurrent.futures import ThreadPoolExecutor
15
15
 
16
16
  import numpy as np
17
- from numpy.typing import NDArray
18
17
  from scipy.spatial import ConvexHull
19
18
  from scipy.ndimage import gaussian_filter
20
19
  from scipy.spatial.transform import Rotation
21
20
 
21
+ from .backends import backend as be
22
+ from .memory import estimate_ram_usage
23
+ from .types import NDArray, BackendArray
22
24
  from .extensions import max_euclidean_distance
23
- from .matching_memory import estimate_ram_usage
24
- from .helpers import quaternion_to_rotation_matrix, load_quaternions_by_angle
25
25
 
26
26
 
27
- def handle_traceback(last_type, last_value, last_traceback):
27
+ def noop(*args, **kwargs):
28
+ pass
29
+
30
+
31
+ def identity(arr, *args):
32
+ return arr
33
+
34
+
35
+ def conditional_execute(
36
+ func: Callable,
37
+ execute_operation: bool,
38
+ alt_func: Callable = noop,
39
+ ) -> Callable:
28
40
  """
29
- Handle sys.exc_info().
41
+ Return the given function or a no-op function based on execute_operation.
30
42
 
31
43
  Parameters
32
44
  ----------
33
- last_type : type
34
- The type of the last exception.
35
- last_value :
36
- The value of the last exception.
37
- last_traceback : traceback
38
- The traceback object encapsulating the call stack at the point
39
- where the exception originally occurred.
45
+ func : Callable
46
+ Callable.
47
+ alt_func : Callable
48
+ Callable to return if ``execute_operation`` is False, no-op by default.
49
+ execute_operation : bool
50
+ Whether to return ``func`` or a ``alt_func`` function.
40
51
 
41
- Raises
42
- ------
43
- Exception
44
- Re-raises the last exception.
52
+ Returns
53
+ -------
54
+ Callable
55
+ ``func`` if ``execute_operation`` else ``alt_func``.
56
+ """
57
+
58
+ return func if execute_operation else alt_func
59
+
60
+
61
+ def normalize_template(
62
+ template: BackendArray, mask: BackendArray, n_observations: float
63
+ ) -> BackendArray:
64
+ """
65
+ Standardizes ``template`` to zero mean and unit standard deviation in ``mask``.
66
+
67
+ .. warning:: ``template`` is modified during the operation.
68
+
69
+ Parameters
70
+ ----------
71
+ template : BackendArray
72
+ Input data.
73
+ mask : BackendArray
74
+ Mask of the same shape as ``template``.
75
+ n_observations : float
76
+ Sum of mask elements.
77
+
78
+ Returns
79
+ -------
80
+ BackendArray
81
+ Standardized input data.
82
+
83
+ References
84
+ ----------
85
+ .. [1] Hrabe T. et al, J. Struct. Biol. 178, 177 (2012).
45
86
  """
46
- if last_type is None:
47
- return None
48
- traceback.print_tb(last_traceback)
49
- raise Exception(last_value)
50
- # raise last_type(last_value)
87
+ masked_mean = be.sum(be.multiply(template, mask)) / n_observations
88
+ masked_std = be.sum(be.multiply(be.square(template), mask))
89
+ masked_std = be.subtract(masked_std / n_observations, be.square(masked_mean))
90
+ masked_std = be.sqrt(be.maximum(masked_std, 0))
91
+
92
+ template = be.subtract(template, masked_mean, out=template)
93
+ template = be.divide(template, masked_std, out=template)
94
+ return be.multiply(template, mask, out=template)
95
+
96
+
97
+ def _normalize_template_overflow_safe(
98
+ template: BackendArray, mask: BackendArray, n_observations: float
99
+ ) -> BackendArray:
100
+ _template = be.astype(template, be._overflow_safe_dtype)
101
+ _mask = be.astype(mask, be._overflow_safe_dtype)
102
+ normalize_template(template=_template, mask=_mask, n_observations=n_observations)
103
+ template[:] = be.astype(_template, template.dtype)
104
+ return template
51
105
 
52
106
 
53
- def generate_tempfile_name(suffix=None):
107
+ def generate_tempfile_name(suffix: str = None) -> str:
54
108
  """
55
- Returns the path to a potential temporary file location. If the environment
56
- variable TME_TMPDIR is defined, the temporary file will be created there.
57
- Otherwise the default tmp directory will be used.
109
+ Returns the path to a temporary file with given suffix. If defined. the
110
+ environment variable TMPDIR is used as base.
58
111
 
59
112
  Parameters
60
113
  ----------
@@ -73,26 +126,19 @@ def generate_tempfile_name(suffix=None):
73
126
 
74
127
  def array_to_memmap(arr: NDArray, filename: str = None) -> str:
75
128
  """
76
- Converts a numpy array to a np.memmap.
129
+ Converts a obj:`numpy.ndarray` to a obj:`numpy.memmap`.
77
130
 
78
131
  Parameters
79
132
  ----------
80
- arr : np.ndarray
81
- The numpy array to be converted.
133
+ arr : obj:`numpy.ndarray`
134
+ Input data.
82
135
  filename : str, optional
83
- Desired filename for the memmap. If not provided, a temporary
84
- file will be created.
85
-
86
- Notes
87
- -----
88
- If the environment variable TME_TMPDIR is defined, the temporary
89
- file will be created there. Otherwise the default tmp directory
90
- will be used.
136
+ Path to new memmap, :py:meth:`generate_tempfile_name` is used by default.
91
137
 
92
138
  Returns
93
139
  -------
94
140
  str
95
- The filename where the memmap was written to.
141
+ Path to the memmap.
96
142
  """
97
143
  if filename is None:
98
144
  filename = generate_tempfile_name()
@@ -108,47 +154,28 @@ def array_to_memmap(arr: NDArray, filename: str = None) -> str:
108
154
 
109
155
  def memmap_to_array(arr: NDArray) -> NDArray:
110
156
  """
111
- Converts a np.memmap into an numpy array.
157
+ Convert a obj:`numpy.memmap` to a obj:`numpy.ndarray` and delete the memmap.
112
158
 
113
159
  Parameters
114
160
  ----------
115
- arr : np.memmap
116
- The numpy array to be converted.
161
+ arr : obj:`numpy.memmap`
162
+ Input data.
117
163
 
118
164
  Returns
119
165
  -------
120
- np.ndarray
121
- The converted array.
166
+ obj:`numpy.ndarray`
167
+ In-memory version of ``arr``.
122
168
  """
123
- if type(arr) == np.memmap:
169
+ if isinstance(arr, np.memmap):
124
170
  memmap_filepath = arr.filename
125
171
  arr = np.array(arr)
126
172
  os.remove(memmap_filepath)
127
173
  return arr
128
174
 
129
175
 
130
- def close_memmap(arr: np.ndarray) -> None:
131
- """
132
- Remove the file associated with a numpy memmap array.
133
-
134
- Parameters
135
- ----------
136
- arr : np.ndarray
137
- The numpy array which might be a memmap.
138
- """
139
- try:
140
- os.remove(arr.filename)
141
- # arr._mmap.close()
142
- except Exception:
143
- pass
144
-
145
-
146
176
  def write_pickle(data: object, filename: str) -> None:
147
177
  """
148
- Serialize and write data to a file invalidating the input data in
149
- the process. This function uses type-specific serialization for
150
- certain objects, such as np.memmap, for optimized storage. Other
151
- objects are serialized using standard pickle.
178
+ Serialize and write data to a file invalidating the input data.
152
179
 
153
180
  Parameters
154
181
  ----------
@@ -316,7 +343,7 @@ def compute_parallelization_schedule(
316
343
  split_factor, n_splits = [1 for _ in range(len(shape1))], 0
317
344
  while n_splits <= max_splits:
318
345
  splits = {k: split_factor[k] for k in range(len(split_factor))}
319
- array_slices = split_numpy_array_slices(shape=shape1, splits=splits)
346
+ array_slices = split_shape(shape=shape1, splits=splits)
320
347
  array_widths = [
321
348
  tuple(x.stop - x.start for x in split) for split in array_slices
322
349
  ]
@@ -378,55 +405,57 @@ def compute_parallelization_schedule(
378
405
  return splits, core_assignment
379
406
 
380
407
 
381
- def centered(arr: NDArray, newshape: Tuple[int]) -> NDArray:
408
+ def _center_slice(current_shape: Tuple[int], new_shape: Tuple[int]) -> Tuple[slice]:
409
+ """Extract the center slice of ``current_shape`` to retrieve ``new_shape``."""
410
+ new_shape = tuple(int(x) for x in new_shape)
411
+ current_shape = tuple(int(x) for x in current_shape)
412
+ starts = tuple((x - y) // 2 for x, y in zip(current_shape, new_shape))
413
+ stops = tuple(sum(stop) for stop in zip(starts, new_shape))
414
+ box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
415
+ return box
416
+
417
+
418
+ def centered(arr: BackendArray, new_shape: Tuple[int]) -> BackendArray:
382
419
  """
383
420
  Extract the centered portion of an array based on a new shape.
384
421
 
385
422
  Parameters
386
423
  ----------
387
- arr : NDArray
388
- Input array.
389
- newshape : tuple
424
+ arr : BackendArray
425
+ Input data.
426
+ new_shape : tuple of ints
390
427
  Desired shape for the central portion.
391
428
 
392
429
  Returns
393
430
  -------
394
- NDArray
395
- Central portion of the array with shape `newshape`.
431
+ BackendArray
432
+ Central portion of the array with shape ``new_shape``.
396
433
 
397
434
  References
398
435
  ----------
399
436
  .. [1] https://github.com/scipy/scipy/blob/v1.11.2/scipy/signal/_signaltools.py#L388
400
437
  """
401
- new_shape = np.asarray(newshape)
402
- current_shape = np.array(arr.shape)
403
- starts = (current_shape - new_shape) // 2
404
- stops = starts + newshape
405
- box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
438
+ box = _center_slice(arr.shape, new_shape=new_shape)
406
439
  return arr[box]
407
440
 
408
441
 
409
- def centered_mask(arr: NDArray, newshape: Tuple[int]) -> NDArray:
442
+ def centered_mask(arr: BackendArray, new_shape: Tuple[int]) -> BackendArray:
410
443
  """
411
444
  Mask the centered portion of an array based on a new shape.
412
445
 
413
446
  Parameters
414
447
  ----------
415
- arr : NDArray
416
- Input array.
417
- newshape : tuple
448
+ arr : BackendArray
449
+ Input data.
450
+ new_shape : tuple of ints
418
451
  Desired shape for the mask.
419
452
 
420
453
  Returns
421
454
  -------
422
- NDArray
455
+ BackendArray
423
456
  Array with central portion unmasked and the rest set to 0.
424
457
  """
425
- new_shape = np.asarray(newshape)
426
- current_shape = np.array(arr.shape)
427
- starts = (current_shape - new_shape) // 2
428
- stops = starts + newshape
429
- box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
458
+ box = _center_slice(arr.shape, new_shape=new_shape)
430
459
  mask = np.zeros_like(arr)
431
460
  mask[box] = 1
432
461
  arr *= mask
@@ -434,21 +463,21 @@ def centered_mask(arr: NDArray, newshape: Tuple[int]) -> NDArray:
434
463
 
435
464
 
436
465
  def apply_convolution_mode(
437
- arr: NDArray,
466
+ arr: BackendArray,
438
467
  convolution_mode: str,
439
468
  s1: Tuple[int],
440
469
  s2: Tuple[int],
441
470
  mask_output: bool = False,
442
- ) -> NDArray:
471
+ ) -> BackendArray:
443
472
  """
444
- Applies convolution_mode to arr.
473
+ Applies convolution_mode to ``arr``.
445
474
 
446
475
  Parameters
447
476
  ----------
448
- arr : NDArray
449
- Numpy array containing convolution result of arrays with shape s1 and s2.
477
+ arr : BackendArray
478
+ Array containing convolution result of arrays with shape s1 and s2.
450
479
  convolution_mode : str
451
- Analogous to mode in ``scipy.signal.convolve``:
480
+ Analogous to mode in obj:`scipy.signal.convolve`:
452
481
 
453
482
  +---------+----------------------------------------------------------+
454
483
  | 'full' | returns full template matching result of the inputs. |
@@ -457,9 +486,9 @@ def apply_convolution_mode(
457
486
  +---------+----------------------------------------------------------+
458
487
  | 'same' | output is the same size as s1. |
459
488
  +---------+----------------------------------------------------------+
460
- s1 : tuple
489
+ s1 : tuple of ints
461
490
  Tuple of integers corresponding to shape of convolution array 1.
462
- s2 : tuple
491
+ s2 : tuple of ints
463
492
  Tuple of integers corresponding to shape of convolution array 2.
464
493
  mask_output : bool, optional
465
494
  Whether to mask values outside of convolution_mode rather than
@@ -467,14 +496,10 @@ def apply_convolution_mode(
467
496
 
468
497
  Returns
469
498
  -------
470
- NDArray
471
- The numpy array after applying the convolution mode.
472
-
473
- References
474
- ----------
475
- .. [1] https://github.com/scipy/scipy/blob/v1.11.2/scipy/signal/_signaltools.py#L519
499
+ BackendArray
500
+ The array after applying the convolution mode.
476
501
  """
477
- # This removes padding to next fast fourier length
502
+ # Remove padding to next fast Fourier length
478
503
  arr = arr[tuple(slice(s1[i] + s2[i] - 1) for i in range(len(s1)))]
479
504
 
480
505
  if convolution_mode not in ("full", "same", "valid"):
@@ -506,11 +531,9 @@ def compute_full_convolution_index(
506
531
  inner_shape : tuple
507
532
  Tuple of integers corresponding to the shape of the inner array.
508
533
  outer_split : tuple
509
- Tuple of slices used to split outer array
510
- (see :py:meth:`split_numpy_array_slices`).
534
+ Tuple of slices used to split outer array (see :py:meth:`split_shape`).
511
535
  inner_split : tuple
512
- Tuple of slices used to split inner array
513
- (see :py:meth:`split_numpy_array_slices`).
536
+ Tuple of slices used to split inner array (see :py:meth:`split_shape`).
514
537
 
515
538
  Returns
516
539
  -------
@@ -538,41 +561,43 @@ def compute_full_convolution_index(
538
561
  return score_slice
539
562
 
540
563
 
541
- def split_numpy_array_slices(
542
- shape: NDArray, splits: Dict, margin: NDArray = None
564
+ def split_shape(
565
+ shape: Tuple[int], splits: Dict, equal_shape: bool = True
543
566
  ) -> Tuple[slice]:
544
567
  """
545
- Returns a tuple of slices to subset a numpy array into pieces along multiple axes.
568
+ Splits ``shape`` into equally sized and potentially overlapping subsets.
546
569
 
547
570
  Parameters
548
571
  ----------
549
- shape : NDArray
550
- Shape of the array to split.
572
+ shape : tuple of ints
573
+ Shape to split.
551
574
  splits : dict
552
- A dictionary where the keys are the axis numbers and the values
553
- are the number of splits along that axis.
554
- margin : NDArray, optional
555
- Padding on the left hand side of the array.
575
+ Dictionary mapping axis number to number of splits.
576
+ equal_shape : dict
577
+ Whether the subsets should be of equal shape, True by default.
556
578
 
557
579
  Returns
558
580
  -------
559
581
  tuple
560
- A tuple of slices, where each slice corresponds to a split along an axis.
582
+ Tuple of slice with requested split combinations.
561
583
  """
562
584
  ndim = len(shape)
563
- if margin is None:
564
- margin = np.zeros(ndim, dtype=int)
565
- splits = {k: max(splits.get(k, 0), 1) for k in range(ndim)}
566
- new_shape = np.divide(shape, [splits.get(i, 1) for i in range(ndim)]).astype(int)
585
+ splits = {k: max(splits.get(k, 1), 1) for k in range(ndim)}
586
+ ret_shape = np.divide(shape, tuple(splits[i] for i in range(ndim)))
587
+ if equal_shape:
588
+ ret_shape = np.ceil(ret_shape).astype(int)
589
+ ret_shape = ret_shape.astype(int)
567
590
 
568
591
  slice_list = [
569
592
  tuple(
570
- (slice(max((n_splits * length) - margin[axis], 0), (n_splits + 1) * length))
593
+ (slice((n_splits * length), (n_splits + 1) * length))
571
594
  if n_splits < splits.get(axis, 1) - 1
572
- else (slice(max((n_splits * length) - margin[axis], 0), shape[axis]))
595
+ else (slice(shape[axis] - length, shape[axis]))
596
+ if equal_shape
597
+ else (slice((n_splits * length), shape[axis]))
573
598
  for n_splits in range(splits.get(axis, 1))
574
599
  )
575
- for length, axis in zip(new_shape, splits.keys())
600
+ for length, axis in zip(ret_shape, splits.keys())
576
601
  ]
577
602
 
578
603
  splits = tuple(product(*slice_list))
@@ -584,28 +609,25 @@ def get_rotation_matrices(
584
609
  angular_sampling: float, dim: int = 3, use_optimized_set: bool = True
585
610
  ) -> NDArray:
586
611
  """
587
- Returns rotation matrices in format k x dim x dim, where k is determined
588
- by ``angular_sampling``.
612
+ Returns rotation matrices with desired ``angular_sampling`` rate.
589
613
 
590
614
  Parameters
591
615
  ----------
592
616
  angular_sampling : float
593
- The angle in degrees used for the generation of rotation matrices.
617
+ The desired angular sampling in degrees.
594
618
  dim : int, optional
595
619
  Dimension of the rotation matrices.
596
620
  use_optimized_set : bool, optional
597
- Whether to use pre-computed rotational sets with more optimal sampling.
598
- Currently only available when dim=3.
621
+ Use optimized rotational sets, True by default and available for dim=3.
599
622
 
600
623
  Notes
601
624
  -----
602
- For the case of dim = 3 optimized rotational sets are used, otherwise
603
- QR-decomposition.
625
+ For dim = 3 optimized sets are used, otherwise QR-decomposition.
604
626
 
605
627
  Returns
606
628
  -------
607
629
  NDArray
608
- Array of shape (k, dim, dim) containing k rotation matrices.
630
+ Array of shape (n, d, d) containing n rotation matrices.
609
631
  """
610
632
  if dim == 3 and use_optimized_set:
611
633
  quaternions, *_ = load_quaternions_by_angle(angular_sampling)
@@ -706,144 +728,82 @@ def get_rotations_around_vector(
706
728
  return rotation_angles
707
729
 
708
730
 
709
- def minimum_enclosing_box(
710
- coordinates: NDArray,
711
- margin: NDArray = None,
712
- use_geometric_center: bool = False,
713
- ) -> Tuple[int]:
731
+ def load_quaternions_by_angle(
732
+ angular_sampling: float,
733
+ ) -> Tuple[NDArray, NDArray, float]:
714
734
  """
715
- Computes the minimal enclosing box around coordinates with margin.
735
+ Get orientations and weights proportional to the given angular_sampling.
716
736
 
717
737
  Parameters
718
738
  ----------
719
- coordinates : NDArray
720
- Coordinates of which the enclosing box should be computed. The shape
721
- of this array should be [d, n] with d dimensions and n coordinates.
722
- margin : NDArray, optional
723
- Box margin. Defaults to None.
724
- use_geometric_center : bool, optional
725
- Whether the box should accommodate the geometric or the coordinate
726
- center. Defaults to False.
739
+ angular_sampling : float
740
+ Requested angular sampling.
727
741
 
728
742
  Returns
729
743
  -------
730
- tuple
731
- Integers corresponding to the minimum enclosing box shape.
744
+ Tuple[NDArray, NDArray, float]
745
+ Quaternion representations of orientations, weights associated with each
746
+ quaternion and closest angular sampling to the requested sampling.
732
747
  """
733
- point_cloud = np.asarray(coordinates)
734
- dim = point_cloud.shape[0]
735
- point_cloud = point_cloud - point_cloud.min(axis=1)[:, None]
748
+ # Metadata contains (N orientations, rotational sampling, coverage as values)
749
+ with open(
750
+ os.path.join(os.path.dirname(__file__), "data", "metadata.yaml"), "r"
751
+ ) as infile:
752
+ metadata = yaml.full_load(infile)
753
+
754
+ set_diffs = {
755
+ setname: abs(angular_sampling - set_angle)
756
+ for setname, (_, set_angle, _) in metadata.items()
757
+ }
758
+ fname = min(set_diffs, key=set_diffs.get)
736
759
 
737
- margin = np.zeros(dim) if margin is None else margin
738
- margin = np.asarray(margin).astype(int)
760
+ infile = os.path.join(os.path.dirname(__file__), "data", fname)
761
+ quat_weights = np.load(infile)
739
762
 
740
- norm_cloud = point_cloud - point_cloud.mean(axis=1)[:, None]
741
- # Adding one avoids clipping during scipy.ndimage.affine_transform
742
- shape = np.repeat(
743
- np.ceil(2 * np.linalg.norm(norm_cloud, axis=0).max()) + 1, dim
744
- ).astype(int)
745
- if use_geometric_center:
746
- hull = ConvexHull(point_cloud.T)
747
- distance, _ = max_euclidean_distance(point_cloud[:, hull.vertices].T)
748
- distance += np.linalg.norm(np.ones(dim))
749
- shape = np.repeat(np.rint(distance).astype(int), dim)
763
+ quat = quat_weights[:, :4]
764
+ weights = quat_weights[:, -1]
765
+ angle = metadata[fname][0]
750
766
 
751
- return shape
767
+ return quat, weights, angle
752
768
 
753
769
 
754
- def crop_input(
755
- target: "Density",
756
- template: "Density",
757
- target_mask: "Density" = None,
758
- template_mask: "Density" = None,
759
- map_cutoff: float = 0,
760
- template_cutoff: float = 0,
761
- ) -> Tuple[int]:
770
+ def quaternion_to_rotation_matrix(quaternions: NDArray) -> NDArray:
762
771
  """
763
- Crop target and template maps for efficient fitting. Input densities
764
- are cropped in place.
772
+ Convert quaternions to rotation matrices.
765
773
 
766
774
  Parameters
767
775
  ----------
768
- target : Density
769
- Target to be fitted on.
770
- template : Density
771
- Template to fit onto the target.
772
- target_mask : Density, optional
773
- Path to mask of target. Will be croppped like target.
774
- template_mask : Density, optional
775
- Path to mask of template. Will be cropped like template.
776
- map_cutoff : float, optional
777
- Cutoff value for trimming the target Density. Default is 0.
778
- map_cutoff : float, optional
779
- Cutoff value for trimming the template Density. Default is 0.
776
+ quaternions : NDArray
777
+ Quaternion data of shape (n, 4).
780
778
 
781
779
  Returns
782
780
  -------
783
- Tuple[int]
784
- Tuple containing reference fit index
781
+ NDArray
782
+ Rotation matrices corresponding to the given quaternions.
785
783
  """
786
- convolution_shape_init = np.add(target.shape, template.shape) - 1
787
- # If target and template are aligned, fitting should return this index
788
- reference_fit = np.subtract(template.shape, 1)
789
-
790
- target_box = tuple(slice(0, x) for x in target.shape)
791
- if map_cutoff is not None:
792
- target_box = target.trim_box(cutoff=map_cutoff)
793
-
794
- target_mask_box = target_box
795
- if target_mask is not None and map_cutoff is not None:
796
- target_mask_box = target_mask.trim_box(cutoff=map_cutoff)
797
- target_box = tuple(
798
- slice(min(arr.start, mask.start), max(arr.stop, mask.stop))
799
- for arr, mask in zip(target_box, target_mask_box)
800
- )
801
-
802
- template_box = tuple(slice(0, x) for x in template.shape)
803
- if template_cutoff is not None:
804
- template_box = template.trim_box(cutoff=template_cutoff)
805
-
806
- template_mask_box = template_box
807
- if template_mask is not None and template_cutoff is not None:
808
- template_mask_box = template_mask.trim_box(cutoff=template_cutoff)
809
- template_box = tuple(
810
- slice(min(arr.start, mask.start), max(arr.stop, mask.stop))
811
- for arr, mask in zip(template_box, template_mask_box)
812
- )
784
+ q0 = quaternions[:, 0]
785
+ q1 = quaternions[:, 1]
786
+ q2 = quaternions[:, 2]
787
+ q3 = quaternions[:, 3]
813
788
 
814
- cut_right = np.array(
815
- [shape - x.stop for shape, x in zip(template.shape, template_box)]
816
- )
817
- cut_left = np.array([x.start for x in target_box])
789
+ s = np.linalg.norm(quaternions, axis=1) * 2
790
+ rotmat = np.zeros((quaternions.shape[0], 3, 3), dtype=np.float64)
818
791
 
819
- origin_difference = np.divide(target.origin - template.origin, target.sampling_rate)
820
- origin_difference = origin_difference.astype(int)
792
+ rotmat[:, 0, 0] = 1.0 - s * ((q2 * q2) + (q3 * q3))
793
+ rotmat[:, 0, 1] = s * ((q1 * q2) - (q0 * q3))
794
+ rotmat[:, 0, 2] = s * ((q1 * q3) + (q0 * q2))
821
795
 
822
- target.adjust_box(target_box)
823
- template.adjust_box(template_box)
796
+ rotmat[:, 1, 0] = s * ((q2 * q1) + (q0 * q3))
797
+ rotmat[:, 1, 1] = 1.0 - s * ((q3 * q3) + (q1 * q1))
798
+ rotmat[:, 1, 2] = s * ((q2 * q3) - (q0 * q1))
824
799
 
825
- if target_mask is not None:
826
- target_mask.adjust_box(target_box)
827
- if template_mask is not None:
828
- template_mask.adjust_box(template_box)
800
+ rotmat[:, 2, 0] = s * ((q3 * q1) - (q0 * q2))
801
+ rotmat[:, 2, 1] = s * ((q3 * q2) + (q0 * q1))
802
+ rotmat[:, 2, 2] = 1.0 - s * ((q1 * q1) + (q2 * q2))
829
803
 
830
- reference_fit -= cut_right + cut_left + origin_difference
804
+ np.around(rotmat, decimals=8, out=rotmat)
831
805
 
832
- convolution_shape = np.array(target.shape)
833
- convolution_shape += np.array(template.shape) - 1
834
-
835
- print(f"Cropped volume of target is: {target.shape}")
836
- print(f"Cropped volume of template is: {template.shape}")
837
- saving = 1 - (np.prod(convolution_shape)) / np.prod(convolution_shape_init)
838
- saving *= 100
839
-
840
- print(
841
- "Cropping changed array size from "
842
- f"{round(4*np.prod(convolution_shape_init)/1e6, 3)} MB "
843
- f"to {round(4*np.prod(convolution_shape)/1e6, 3)} MB "
844
- f"({'-' if saving > 0 else ''}{abs(round(saving, 2))}%)"
845
- )
846
- return reference_fit
806
+ return rotmat
847
807
 
848
808
 
849
809
  def euler_to_rotationmatrix(angles: Tuple[float], convention: str = "zyx") -> NDArray:
@@ -866,12 +826,8 @@ def euler_to_rotationmatrix(angles: Tuple[float], convention: str = "zyx") -> ND
866
826
  angle_convention = convention[:n_angles]
867
827
  if n_angles == 1:
868
828
  angles = (angles, 0, 0)
869
- rotation_matrix = (
870
- Rotation.from_euler(angle_convention, angles, degrees=True)
871
- .as_matrix()
872
- .astype(np.float32)
873
- )
874
- return rotation_matrix
829
+ rotation_matrix = Rotation.from_euler(angle_convention, angles, degrees=True)
830
+ return rotation_matrix.as_matrix().astype(np.float32)
875
831
 
876
832
 
877
833
  def euler_from_rotationmatrix(
@@ -883,9 +839,10 @@ def euler_from_rotationmatrix(
883
839
  Parameters
884
840
  ----------
885
841
  rotation_matrix : NDArray
886
- A 2 x 2 or 3 x 3 rotation matrix in z y x form.
842
+ A 2 x 2 or 3 x 3 rotation matrix in zyx form.
887
843
  convention : str, optional
888
- Euler angle convention.
844
+ Euler angle convention, zyx by default.
845
+
889
846
  Returns
890
847
  -------
891
848
  Tuple
@@ -895,12 +852,8 @@ def euler_from_rotationmatrix(
895
852
  temp_matrix = np.eye(3)
896
853
  temp_matrix[:2, :2] = rotation_matrix
897
854
  rotation_matrix = temp_matrix
898
- euler_angles = (
899
- Rotation.from_matrix(rotation_matrix)
900
- .as_euler(convention, degrees=True)
901
- .astype(np.float32)
902
- )
903
- return euler_angles
855
+ rotation = Rotation.from_matrix(rotation_matrix)
856
+ return rotation.as_euler(convention, degrees=True).astype(np.float32)
904
857
 
905
858
 
906
859
  def rotation_aligning_vectors(
@@ -961,23 +914,19 @@ def rigid_transform(
961
914
  Parameters
962
915
  ----------
963
916
  coordinates : NDArray
964
- An array representing the coordinates to be transformed [d x N].
917
+ An array representing the coordinates to be transformed (d,n).
965
918
  rotation_matrix : NDArray
966
- The rotation matrix to be applied [d x d].
919
+ The rotation matrix to be applied (d,d).
967
920
  translation : NDArray
968
- The translation vector to be applied [d].
921
+ The translation vector to be applied (d,).
969
922
  out : NDArray
970
- The output array to store the transformed coordinates.
923
+ The output array to store the transformed coordinates (d,n).
971
924
  coordinates_mask : NDArray, optional
972
- An array representing the mask for the coordinates [d x T].
925
+ An array representing the mask for the coordinates (d,t).
973
926
  out_mask : NDArray, optional
974
- The output array to store the transformed coordinates mask.
927
+ The output array to store the transformed coordinates mask (d,t).
975
928
  use_geometric_center : bool, optional
976
929
  Whether to use geometric or coordinate center.
977
-
978
- Returns
979
- -------
980
- None
981
930
  """
982
931
  coordinate_dtype = coordinates.dtype
983
932
  center = coordinates.mean(axis=1) if center is None else center
@@ -1004,71 +953,67 @@ def rigid_transform(
1004
953
  out += translation[:, None]
1005
954
 
1006
955
 
1007
- def _format_string(string: str) -> str:
956
+ def minimum_enclosing_box(
957
+ coordinates: NDArray, margin: NDArray = None, use_geometric_center: bool = False
958
+ ) -> Tuple[int]:
1008
959
  """
1009
- Formats a string by adding quotation marks if it contains white spaces.
960
+ Computes the minimal enclosing box around coordinates with margin.
1010
961
 
1011
962
  Parameters
1012
963
  ----------
1013
- string : str
1014
- Input string to be formatted.
964
+ coordinates : NDArray
965
+ Coordinates of shape (d,n) to compute the enclosing box of.
966
+ margin : NDArray, optional
967
+ Box margin, zero by default.
968
+ use_geometric_center : bool, optional
969
+ Whether box accommodates the geometric or coordinate center, False by default.
1015
970
 
1016
971
  Returns
1017
972
  -------
1018
- str
1019
- Formatted string with added quotation marks if needed.
973
+ tuple of ints
974
+ Minimum enclosing box shape.
1020
975
  """
1021
- if " " in string:
1022
- return f"'{string}'"
1023
- # Occurs e.g. for C1' atoms. The trailing whitespace is necessary.
1024
- if string.count("'") == 1:
1025
- return f'"{string}"'
1026
- return string
1027
-
976
+ point_cloud = np.asarray(coordinates)
977
+ dim = point_cloud.shape[0]
978
+ point_cloud = point_cloud - point_cloud.min(axis=1)[:, None]
1028
979
 
1029
- def _format_mmcif_colunns(subdict: Dict) -> Dict:
1030
- """
1031
- Formats the columns of a mmcif dictionary.
980
+ margin = np.zeros(dim) if margin is None else margin
981
+ margin = np.asarray(margin).astype(int)
1032
982
 
1033
- Parameters
1034
- ----------
1035
- subdict : dict
1036
- Input dictionary where each key corresponds to a column and the
1037
- values are iterables containing the column values.
983
+ norm_cloud = point_cloud - point_cloud.mean(axis=1)[:, None]
984
+ # Adding one avoids clipping during scipy.ndimage.affine_transform
985
+ shape = np.repeat(
986
+ np.ceil(2 * np.linalg.norm(norm_cloud, axis=0).max()) + 1, dim
987
+ ).astype(int)
988
+ if use_geometric_center:
989
+ hull = ConvexHull(point_cloud.T)
990
+ distance, _ = max_euclidean_distance(point_cloud[:, hull.vertices].T)
991
+ distance += np.linalg.norm(np.ones(dim))
992
+ shape = np.repeat(np.rint(distance).astype(int), dim)
1038
993
 
1039
- Returns
1040
- -------
1041
- dict
1042
- Formatted dictionary with the columns of the mmcif file.
1043
- """
1044
- subdict = {k: [_format_string(s) for s in v] for k, v in subdict.items()}
1045
- key_length = {
1046
- key: len(max(value, key=lambda x: len(x), default=""))
1047
- for key, value in subdict.items()
1048
- }
1049
- padded_subdict = {
1050
- key: [s.ljust(key_length[key] + 1) for s in values]
1051
- for key, values in subdict.items()
1052
- }
1053
- return padded_subdict
994
+ return shape
1054
995
 
1055
996
 
1056
- def create_mask(
1057
- mask_type: str, sigma_decay: float = 0, mask_cutoff: float = 0.135, **kwargs
1058
- ) -> NDArray:
997
+ def create_mask(mask_type: str, sigma_decay: float = 0, **kwargs) -> NDArray:
1059
998
  """
1060
999
  Creates a mask of the specified type.
1061
1000
 
1062
1001
  Parameters
1063
1002
  ----------
1064
1003
  mask_type : str
1065
- Type of the mask to be created. Can be "ellipse", "box", or "tube".
1004
+ Type of the mask to be created. Can be one of:
1005
+
1006
+ +---------+----------------------------------------------------------+
1007
+ | box | Box mask (see :py:meth:`box_mask`) |
1008
+ +---------+----------------------------------------------------------+
1009
+ | tube | Cylindrical mask (see :py:meth:`tube_mask`) |
1010
+ +---------+----------------------------------------------------------+
1011
+ | ellipse | Ellipsoidal mask (see :py:meth:`elliptical_mask`) |
1012
+ +---------+----------------------------------------------------------+
1066
1013
  sigma_decay : float, optional
1067
- Standard deviation of an optionally applied Gaussian filter.
1068
- mask_cutoff : float, optional
1069
- Values below mask_cutoff will be set to zero. By default, exp(-2).
1014
+ Smoothing along mask edges using a Gaussian filter, 0 by default.
1070
1015
  kwargs : dict
1071
- Additional parameters required by the mask creating functions.
1016
+ Parameters passed to the individual mask creation functions.
1072
1017
 
1073
1018
  Returns
1074
1019
  -------
@@ -1079,12 +1024,6 @@ def create_mask(
1079
1024
  ------
1080
1025
  ValueError
1081
1026
  If the mask_type is invalid.
1082
-
1083
- See Also
1084
- --------
1085
- :py:meth:`elliptical_mask`
1086
- :py:meth:`box_mask`
1087
- :py:meth:`tube_mask`
1088
1027
  """
1089
1028
  mapping = {"ellipse": elliptical_mask, "box": box_mask, "tube": tube_mask}
1090
1029
  if mask_type not in mapping:
@@ -1092,9 +1031,9 @@ def create_mask(
1092
1031
 
1093
1032
  mask = mapping[mask_type](**kwargs)
1094
1033
  if sigma_decay > 0:
1095
- mask = gaussian_filter(mask.astype(np.float32), sigma=sigma_decay)
1096
-
1097
- mask[mask < mask_cutoff] = 0
1034
+ mask_filter = gaussian_filter(mask.astype(np.float32), sigma=sigma_decay)
1035
+ mask = np.add(mask, (1 - mask) * mask_filter)
1036
+ mask[mask < np.exp(-np.square(sigma_decay))] = 0
1098
1037
 
1099
1038
  return mask
1100
1039
 
@@ -1126,6 +1065,7 @@ def elliptical_mask(
1126
1065
 
1127
1066
  Examples
1128
1067
  --------
1068
+ >>> from tme.matching_utils import elliptical_mask
1129
1069
  >>> mask = elliptical_mask(shape = (20,20), radius = (5,5), center = (10,10))
1130
1070
  """
1131
1071
  center, shape, radius = np.asarray(center), np.asarray(shape), np.asarray(radius)
@@ -1154,17 +1094,23 @@ def box_mask(shape: Tuple[int], center: Tuple[int], height: Tuple[int]) -> np.nd
1154
1094
 
1155
1095
  Parameters
1156
1096
  ----------
1157
- shape : Tuple[int]
1097
+ shape : tuple of ints
1158
1098
  Shape of the output array.
1159
- center : Tuple[int]
1099
+ center : tuple of ints
1160
1100
  Center point coordinates of the box.
1161
- height : Tuple[int]
1101
+ height : tuple of ints
1162
1102
  Height (side length) of the box along each axis.
1163
1103
 
1164
1104
  Returns
1165
1105
  -------
1166
1106
  NDArray
1167
1107
  The created box mask.
1108
+
1109
+ Raises
1110
+ ------
1111
+ ValueError
1112
+ If ``shape`` and ``center`` do not have the same length.
1113
+ If ``center`` and ``height`` do not have the same length.
1168
1114
  """
1169
1115
  if len(shape) != len(center) or len(center) != len(height):
1170
1116
  raise ValueError("The length of shape, center, and height must be consistent.")
@@ -1216,9 +1162,9 @@ def tube_mask(
1216
1162
  Raises
1217
1163
  ------
1218
1164
  ValueError
1219
- If the inner radius is larger than the outer radius, height is larger
1220
- than the symmetry axis shape, or if base_center and shape do not have the
1221
- same length.
1165
+ If ``inner_radius`` is larger than ``outer_radius``.
1166
+ If ``height`` is larger than the symmetry axis.
1167
+ If ``base_center`` and ``shape`` do not have the same length.
1222
1168
  """
1223
1169
  if inner_radius > outer_radius:
1224
1170
  raise ValueError("inner_radius should be smaller than outer_radius.")
@@ -1277,91 +1223,39 @@ def scramble_phases(
1277
1223
  normalize_power: bool = True,
1278
1224
  ) -> NDArray:
1279
1225
  """
1280
- Applies random phase scrambling to a given array.
1281
-
1282
- This function takes an input array, applies a Fourier transform, then scrambles the
1283
- phase with a given proportion of noise, and finally applies an
1284
- inverse Fourier transform to the scrambled data. The phase scrambling
1285
- is controlled by a random seed.
1226
+ Perform random phase scrambling of ``arr``.
1286
1227
 
1287
1228
  Parameters
1288
1229
  ----------
1289
1230
  arr : NDArray
1290
- The input array to be scrambled.
1231
+ Input data.
1291
1232
  noise_proportion : float, optional
1292
- The proportion of noise in the phase scrambling, by default 0.5.
1233
+ Proportion of scrambled phases, 0.5 by default.
1293
1234
  seed : int, optional
1294
- The seed for the random phase scrambling, by default 42.
1235
+ The seed for the random phase scrambling, 42 by default.
1295
1236
  normalize_power : bool, optional
1296
- Whether the returned template should have the same sum of squares as arr.
1237
+ Rescale the return value so its sum of absolute values matches ``arr``.
1297
1238
 
1298
1239
  Returns
1299
1240
  -------
1300
1241
  NDArray
1301
- The array with scrambled phases.
1302
-
1303
- Raises
1304
- ------
1305
- ValueError
1306
- If noise_proportion is not within [0, 1].
1242
+ Phase scrambled version of ``arr``.
1307
1243
  """
1308
- if noise_proportion < 0 or noise_proportion > 1:
1309
- raise ValueError("noise_proportion has to be within [0, 1].")
1244
+ np.random.seed(seed)
1245
+ noise_proportion = max(min(noise_proportion, 1), 0)
1310
1246
 
1311
1247
  arr_fft = np.fft.fftn(arr)
1248
+ amp, ph = np.abs(arr_fft), np.angle(arr_fft)
1312
1249
 
1313
- amp = np.abs(arr_fft)
1314
- ph = np.angle(arr_fft)
1315
-
1316
- np.random.seed(seed)
1317
1250
  ph_noise = np.random.permutation(ph)
1318
1251
  ph_new = ph * (1 - noise_proportion) + ph_noise * noise_proportion
1319
1252
  ret = np.real(np.fft.ifftn(amp * np.exp(1j * ph_new)))
1320
1253
 
1321
1254
  if normalize_power:
1322
- np.divide(
1323
- np.subtract(ret, ret.min()), np.subtract(ret.max(), ret.min()), out=ret
1324
- )
1255
+ np.divide(ret - ret.min(), ret.max() - ret.min(), out=ret)
1325
1256
  np.multiply(ret, np.subtract(arr.max(), arr.min()), out=ret)
1326
1257
  np.add(ret, arr.min(), out=ret)
1327
-
1328
1258
  scaling = np.divide(np.abs(arr).sum(), np.abs(ret).sum())
1329
1259
  np.multiply(ret, scaling, out=ret)
1330
1260
 
1331
1261
  return ret
1332
-
1333
-
1334
- def conditional_execute(func: Callable, execute_operation: bool = True) -> Callable:
1335
- """
1336
- Return the given function or a no-operation function based on execute_operation.
1337
-
1338
- Parameters
1339
- ----------
1340
- func : callable
1341
- The function to be executed if execute_operation is True.
1342
- execute_operation : bool, optional
1343
- A flag that determines whether to return `func` or a no-operation function.
1344
- Default is True.
1345
-
1346
- Returns
1347
- -------
1348
- callable
1349
- Either the given function `func` or a no-operation function.
1350
-
1351
- Examples
1352
- --------
1353
- >>> def greet(name):
1354
- ... return f"Hello, {name}!"
1355
- ...
1356
- >>> operation = conditional_execute(greet, False)
1357
- >>> operation("Alice")
1358
- >>> operation = conditional_execute(greet, True)
1359
- >>> operation("Alice")
1360
- 'Hello, Alice!'
1361
- """
1362
-
1363
- def noop(*args, **kwargs):
1364
- """No operation function."""
1365
- pass
1366
-
1367
- return func if execute_operation else noop