pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.3__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/match_template.py +219 -216
  2. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/postprocess.py +86 -54
  3. pytme-0.2.3.data/scripts/preprocess.py +132 -0
  4. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +181 -94
  5. pytme-0.2.3.dist-info/METADATA +92 -0
  6. pytme-0.2.3.dist-info/RECORD +75 -0
  7. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/WHEEL +1 -1
  8. pytme-0.2.1.data/scripts/preprocess.py → scripts/eval.py +1 -1
  9. scripts/extract_candidates.py +20 -13
  10. scripts/match_template.py +219 -216
  11. scripts/match_template_filters.py +154 -95
  12. scripts/postprocess.py +86 -54
  13. scripts/preprocess.py +95 -56
  14. scripts/preprocessor_gui.py +181 -94
  15. scripts/refine_matches.py +265 -61
  16. tme/__init__.py +0 -1
  17. tme/__version__.py +1 -1
  18. tme/analyzer.py +458 -813
  19. tme/backends/__init__.py +40 -11
  20. tme/backends/_jax_utils.py +187 -0
  21. tme/backends/cupy_backend.py +109 -226
  22. tme/backends/jax_backend.py +230 -152
  23. tme/backends/matching_backend.py +445 -384
  24. tme/backends/mlx_backend.py +32 -59
  25. tme/backends/npfftw_backend.py +240 -507
  26. tme/backends/pytorch_backend.py +30 -151
  27. tme/density.py +248 -371
  28. tme/extensions.cpython-311-darwin.so +0 -0
  29. tme/matching_data.py +328 -284
  30. tme/matching_exhaustive.py +195 -1499
  31. tme/matching_optimization.py +143 -106
  32. tme/matching_scores.py +887 -0
  33. tme/matching_utils.py +287 -388
  34. tme/memory.py +377 -0
  35. tme/orientations.py +78 -21
  36. tme/parser.py +3 -4
  37. tme/preprocessing/_utils.py +61 -32
  38. tme/preprocessing/composable_filter.py +7 -4
  39. tme/preprocessing/compose.py +7 -3
  40. tme/preprocessing/frequency_filters.py +49 -39
  41. tme/preprocessing/tilt_series.py +44 -72
  42. tme/preprocessor.py +560 -526
  43. tme/structure.py +491 -188
  44. tme/types.py +5 -3
  45. pytme-0.2.1.dist-info/METADATA +0 -73
  46. pytme-0.2.1.dist-info/RECORD +0 -73
  47. tme/helpers.py +0 -881
  48. tme/matching_constrained.py +0 -195
  49. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
  50. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
tme/matching_utils.py CHANGED
@@ -5,7 +5,7 @@
 Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
 """
 import os
-import traceback
+import yaml
 import pickle
 from shutil import move
 from tempfile import mkstemp
@@ -14,47 +14,100 @@ from typing import Tuple, Dict, Callable
 from concurrent.futures import ThreadPoolExecutor
 
 import numpy as np
-from numpy.typing import NDArray
 from scipy.spatial import ConvexHull
 from scipy.ndimage import gaussian_filter
 from scipy.spatial.transform import Rotation
 
+from .backends import backend as be
+from .memory import estimate_ram_usage
+from .types import NDArray, BackendArray
 from .extensions import max_euclidean_distance
-from .matching_memory import estimate_ram_usage
-from .helpers import quaternion_to_rotation_matrix, load_quaternions_by_angle
 
 
-def handle_traceback(last_type, last_value, last_traceback):
+def noop(*args, **kwargs):
+    pass
+
+
+def identity(arr, *args):
+    return arr
+
+
+def conditional_execute(
+    func: Callable,
+    execute_operation: bool,
+    alt_func: Callable = noop,
+) -> Callable:
     """
-    Handle sys.exc_info().
+    Return the given function or a no-op function based on execute_operation.
 
     Parameters
     ----------
-    last_type : type
-        The type of the last exception.
-    last_value :
-        The value of the last exception.
-    last_traceback : traceback
-        The traceback object encapsulating the call stack at the point
-        where the exception originally occurred.
+    func : Callable
+        Callable.
+    alt_func : Callable
+        Callable to return if ``execute_operation`` is False, no-op by default.
+    execute_operation : bool
+        Whether to return ``func`` or a ``alt_func`` function.
 
-    Raises
-    ------
-    Exception
-        Re-raises the last exception.
+    Returns
+    -------
+    Callable
+        ``func`` if ``execute_operation`` else ``alt_func``.
+    """
+
+    return func if execute_operation else alt_func
+
+
+def normalize_template(
+    template: BackendArray, mask: BackendArray, n_observations: float
+) -> BackendArray:
+    """
+    Standardizes ``template`` to zero mean and unit standard deviation in ``mask``.
+
+    .. warning:: ``template`` is modified during the operation.
+
+    Parameters
+    ----------
+    template : BackendArray
+        Input data.
+    mask : BackendArray
+        Mask of the same shape as ``template``.
+    n_observations : float
+        Sum of mask elements.
+
+    Returns
+    -------
+    BackendArray
+        Standardized input data.
+
+    References
+    ----------
+    .. [1] Hrabe T. et al, J. Struct. Biol. 178, 177 (2012).
     """
-    if last_type is None:
-        return None
-    traceback.print_tb(last_traceback)
-    raise Exception(last_value)
-    # raise last_type(last_value)
+    masked_mean = be.sum(be.multiply(template, mask)) / n_observations
+    masked_std = be.sum(be.multiply(be.square(template), mask))
+    masked_std = be.subtract(masked_std / n_observations, be.square(masked_mean))
+    masked_std = be.sqrt(be.maximum(masked_std, 0))
+
+    template = be.subtract(template, masked_mean, out=template)
+    template = be.divide(template, masked_std, out=template)
+    return be.multiply(template, mask, out=template)
+
+
+def _normalize_template_overflow_safe(
+    template: BackendArray, mask: BackendArray, n_observations: float
+) -> BackendArray:
+    _template = be.astype(template, be._overflow_safe_dtype)
+    _mask = be.astype(mask, be._overflow_safe_dtype)
+    normalize_template(template=_template, mask=_mask, n_observations=n_observations)
+    template[:] = be.astype(_template, template.dtype)
+    return template
 
 
-def generate_tempfile_name(suffix=None):
+def generate_tempfile_name(suffix: str = None) -> str:
     """
-    Returns the path to a potential temporary file location. If the environment
-    variable TME_TMPDIR is defined, the temporary file will be created there.
-    Otherwise the default tmp directory will be used.
+    Returns the path to a temporary file with given suffix. If defined. the
+    environment variable TMPDIR is used as base.
 
     Parameters
     ----------
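Editor's note: the `conditional_execute`/`normalize_template` helpers added above route all arithmetic through the `be` backend abstraction. As a hedged illustration only, here is a minimal NumPy-only sketch of the masked standardization that `normalize_template` documents (the helper function name `normalize_template_np` is hypothetical and not part of the package):

```python
# Minimal NumPy sketch of the masked zero-mean / unit-std standardization
# described by normalize_template; pytme performs the same steps via its
# backend layer so it also works on GPU arrays.
import numpy as np

def normalize_template_np(template, mask, n_observations):
    masked_mean = np.sum(template * mask) / n_observations
    masked_std = np.sum(np.square(template) * mask) / n_observations
    masked_std = np.sqrt(max(masked_std - masked_mean**2, 0))
    template -= masked_mean
    template /= masked_std
    return template * mask

rng = np.random.default_rng(42)
template = rng.normal(size=(8, 8)).astype(np.float32)
mask = np.zeros_like(template)
mask[2:6, 2:6] = 1
out = normalize_template_np(template.copy(), mask, mask.sum())
# Within the mask the result has (approximately) zero mean and unit std.
print(out[mask == 1].mean(), out[mask == 1].std())
```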
@@ -73,26 +126,19 @@ def generate_tempfile_name(suffix=None):
 
 def array_to_memmap(arr: NDArray, filename: str = None) -> str:
     """
-    Converts a numpy array to a np.memmap.
+    Converts a obj:`numpy.ndarray` to a obj:`numpy.memmap`.
 
     Parameters
     ----------
-    arr : np.ndarray
-        The numpy array to be converted.
+    arr : obj:`numpy.ndarray`
+        Input data.
     filename : str, optional
-        Desired filename for the memmap. If not provided, a temporary
-        file will be created.
-
-    Notes
-    -----
-    If the environment variable TME_TMPDIR is defined, the temporary
-    file will be created there. Otherwise the default tmp directory
-    will be used.
+        Path to new memmap, :py:meth:`generate_tempfile_name` is used by default.
 
     Returns
     -------
     str
-        The filename where the memmap was written to.
+        Path to the memmap.
     """
     if filename is None:
         filename = generate_tempfile_name()
@@ -108,47 +154,28 @@ def array_to_memmap(arr: NDArray, filename: str = None) -> str:
 
 def memmap_to_array(arr: NDArray) -> NDArray:
     """
-    Converts a np.memmap into an numpy array.
+    Convert a obj:`numpy.memmap` to a obj:`numpy.ndarray` and delete the memmap.
 
     Parameters
     ----------
-    arr : np.memmap
-        The numpy array to be converted.
+    arr : obj:`numpy.memmap`
+        Input data.
 
     Returns
     -------
-    np.ndarray
-        The converted array.
+    obj:`numpy.ndarray`
+        In-memory version of ``arr``.
     """
-    if type(arr) == np.memmap:
+    if isinstance(arr, np.memmap):
         memmap_filepath = arr.filename
         arr = np.array(arr)
         os.remove(memmap_filepath)
     return arr
 
 
-def close_memmap(arr: np.ndarray) -> None:
-    """
-    Remove the file associated with a numpy memmap array.
-
-    Parameters
-    ----------
-    arr : np.ndarray
-        The numpy array which might be a memmap.
-    """
-    try:
-        os.remove(arr.filename)
-        # arr._mmap.close()
-    except Exception:
-        pass
-
-
 def write_pickle(data: object, filename: str) -> None:
     """
-    Serialize and write data to a file invalidating the input data in
-    the process. This function uses type-specific serialization for
-    certain objects, such as np.memmap, for optimized storage. Other
-    objects are serialized using standard pickle.
+    Serialize and write data to a file invalidating the input data.
 
     Parameters
     ----------
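Editor's note: the `array_to_memmap`/`memmap_to_array` pair documented above amounts to a disk round trip for large arrays. A rough, self-contained NumPy sketch of that round trip (not the pytme implementation itself, which additionally derives the filename via `generate_tempfile_name`):

```python
# Write an array to a temporary memmap, then load it back and delete the file.
import os
import numpy as np
from tempfile import mkstemp

arr = np.random.rand(64, 64).astype(np.float32)

fd, filename = mkstemp(suffix=".mm")
os.close(fd)
memmap = np.memmap(filename, mode="w+", dtype=arr.dtype, shape=arr.shape)
memmap[:] = arr[:]
memmap.flush()  # array_to_memmap would return `filename` at this point

# memmap_to_array: copy back into memory and remove the backing file
reopened = np.memmap(filename, mode="r", dtype=arr.dtype, shape=arr.shape)
in_memory = np.array(reopened)
os.remove(filename)
assert np.allclose(arr, in_memory)
```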
@@ -316,7 +343,7 @@ def compute_parallelization_schedule(
     split_factor, n_splits = [1 for _ in range(len(shape1))], 0
     while n_splits <= max_splits:
         splits = {k: split_factor[k] for k in range(len(split_factor))}
-        array_slices = split_numpy_array_slices(shape=shape1, splits=splits)
+        array_slices = split_shape(shape=shape1, splits=splits)
         array_widths = [
             tuple(x.stop - x.start for x in split) for split in array_slices
         ]
@@ -378,55 +405,57 @@
 
     return splits, core_assignment
 
-def centered(arr: NDArray, newshape: Tuple[int]) -> NDArray:
+def _center_slice(current_shape: Tuple[int], new_shape: Tuple[int]) -> Tuple[slice]:
+    """Extract the center slice of ``current_shape`` to retrieve ``new_shape``."""
+    new_shape = tuple(int(x) for x in new_shape)
+    current_shape = tuple(int(x) for x in current_shape)
+    starts = tuple((x - y) // 2 for x, y in zip(current_shape, new_shape))
+    stops = tuple(sum(stop) for stop in zip(starts, new_shape))
+    box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
+    return box
+
+
+def centered(arr: BackendArray, new_shape: Tuple[int]) -> BackendArray:
     """
     Extract the centered portion of an array based on a new shape.
 
     Parameters
     ----------
-    arr : NDArray
-        Input array.
-    newshape : tuple
+    arr : BackendArray
+        Input data.
+    new_shape : tuple of ints
         Desired shape for the central portion.
 
     Returns
     -------
-    NDArray
-        Central portion of the array with shape `newshape`.
+    BackendArray
+        Central portion of the array with shape ``new_shape``.
 
     References
     ----------
     .. [1] https://github.com/scipy/scipy/blob/v1.11.2/scipy/signal/_signaltools.py#L388
     """
-    new_shape = np.asarray(newshape)
-    current_shape = np.array(arr.shape)
-    starts = (current_shape - new_shape) // 2
-    stops = starts + newshape
-    box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
+    box = _center_slice(arr.shape, new_shape=new_shape)
     return arr[box]
 
 
-def centered_mask(arr: NDArray, newshape: Tuple[int]) -> NDArray:
+def centered_mask(arr: BackendArray, new_shape: Tuple[int]) -> BackendArray:
     """
     Mask the centered portion of an array based on a new shape.
 
     Parameters
     ----------
-    arr : NDArray
-        Input array.
-    newshape : tuple
+    arr : BackendArray
+        Input data.
+    new_shape : tuple of ints
         Desired shape for the mask.
 
     Returns
     -------
-    NDArray
+    BackendArray
         Array with central portion unmasked and the rest set to 0.
     """
-    new_shape = np.asarray(newshape)
-    current_shape = np.array(arr.shape)
-    starts = (current_shape - new_shape) // 2
-    stops = starts + newshape
-    box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
+    box = _center_slice(arr.shape, new_shape=new_shape)
     mask = np.zeros_like(arr)
     mask[box] = 1
     arr *= mask
@@ -434,21 +463,22 @@ def centered_mask(arr: NDArray, newshape: Tuple[int]) -> NDArray:
 
 
 def apply_convolution_mode(
-    arr: NDArray,
+    arr: BackendArray,
     convolution_mode: str,
     s1: Tuple[int],
     s2: Tuple[int],
+    convolution_shape: Tuple[int] = None,
     mask_output: bool = False,
-) -> NDArray:
+) -> BackendArray:
     """
-    Applies convolution_mode to arr.
+    Applies convolution_mode to ``arr``.
 
     Parameters
     ----------
-    arr : NDArray
-        Numpy array containing convolution result of arrays with shape s1 and s2.
+    arr : BackendArray
+        Array containing convolution result of arrays with shape s1 and s2.
     convolution_mode : str
-        Analogous to mode in ``scipy.signal.convolve``:
+        Analogous to mode in obj:`scipy.signal.convolve`:
 
        +---------+----------------------------------------------------------+
        | 'full'  | returns full template matching result of the inputs.    |
@@ -457,25 +487,25 @@ def apply_convolution_mode(
        +---------+----------------------------------------------------------+
        | 'same'  | output is the same size as s1.                           |
        +---------+----------------------------------------------------------+
-    s1 : tuple
+    s1 : tuple of ints
        Tuple of integers corresponding to shape of convolution array 1.
-    s2 : tuple
+    s2 : tuple of ints
        Tuple of integers corresponding to shape of convolution array 2.
+    convolution_shape : tuple of ints, optional
+        Size of the actually computed convolution. s1 + s2 - 1 by default.
    mask_output : bool, optional
        Whether to mask values outside of convolution_mode rather than
        removing them. Defaults to False.
 
    Returns
    -------
-    NDArray
-        The numpy array after applying the convolution mode.
-
-    References
-    ----------
-    .. [1] https://github.com/scipy/scipy/blob/v1.11.2/scipy/signal/_signaltools.py#L519
+    BackendArray
+        The array after applying the convolution mode.
    """
-    # This removes padding to next fast fourier length
-    arr = arr[tuple(slice(s1[i] + s2[i] - 1) for i in range(len(s1)))]
+    # Remove padding to next fast Fourier length
+    if convolution_shape is None:
+        convolution_shape = [s1[i] + s2[i] - 1 for i in range(len(s1))]
+    arr = arr[tuple(slice(0, x) for x in convolution_shape)]
 
    if convolution_mode not in ("full", "same", "valid"):
        raise ValueError("Supported convolution_mode are 'full', 'same' and 'valid'.")
@@ -506,11 +536,9 @@ def compute_full_convolution_index(
    inner_shape : tuple
        Tuple of integers corresponding to the shape of the inner array.
    outer_split : tuple
-        Tuple of slices used to split outer array
-        (see :py:meth:`split_numpy_array_slices`).
+        Tuple of slices used to split outer array (see :py:meth:`split_shape`).
    inner_split : tuple
-        Tuple of slices used to split inner array
-        (see :py:meth:`split_numpy_array_slices`).
+        Tuple of slices used to split inner array (see :py:meth:`split_shape`).
 
    Returns
    -------
@@ -538,41 +566,43 @@
    return score_slice
 
 
-def split_numpy_array_slices(
-    shape: NDArray, splits: Dict, margin: NDArray = None
+def split_shape(
+    shape: Tuple[int], splits: Dict, equal_shape: bool = True
 ) -> Tuple[slice]:
    """
-    Returns a tuple of slices to subset a numpy array into pieces along multiple axes.
+    Splits ``shape`` into equally sized and potentially overlapping subsets.
 
    Parameters
    ----------
-    shape : NDArray
-        Shape of the array to split.
+    shape : tuple of ints
+        Shape to split.
    splits : dict
-        A dictionary where the keys are the axis numbers and the values
-        are the number of splits along that axis.
-    margin : NDArray, optional
-        Padding on the left hand side of the array.
+        Dictionary mapping axis number to number of splits.
+    equal_shape : dict
+        Whether the subsets should be of equal shape, True by default.
 
    Returns
    -------
    tuple
-        A tuple of slices, where each slice corresponds to a split along an axis.
+        Tuple of slice with requested split combinations.
    """
    ndim = len(shape)
-    if margin is None:
-        margin = np.zeros(ndim, dtype=int)
-    splits = {k: max(splits.get(k, 0), 1) for k in range(ndim)}
-    new_shape = np.divide(shape, [splits.get(i, 1) for i in range(ndim)]).astype(int)
+    splits = {k: max(splits.get(k, 1), 1) for k in range(ndim)}
+    ret_shape = np.divide(shape, tuple(splits[i] for i in range(ndim)))
+    if equal_shape:
+        ret_shape = np.ceil(ret_shape).astype(int)
+    ret_shape = ret_shape.astype(int)
 
    slice_list = [
        tuple(
-            (slice(max((n_splits * length) - margin[axis], 0), (n_splits + 1) * length))
+            (slice((n_splits * length), (n_splits + 1) * length))
            if n_splits < splits.get(axis, 1) - 1
-            else (slice(max((n_splits * length) - margin[axis], 0), shape[axis]))
+            else (slice(shape[axis] - length, shape[axis]))
+            if equal_shape
+            else (slice((n_splits * length), shape[axis]))
            for n_splits in range(splits.get(axis, 1))
        )
-        for length, axis in zip(new_shape, splits.keys())
    ]
 
    splits = tuple(product(*slice_list))
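Editor's note: `split_shape` replaces `split_numpy_array_slices` and, with `equal_shape=True`, shifts the last subset so every subset has the same size (possibly overlapping its neighbour). A simplified one-axis re-implementation, written only to illustrate that behaviour (the helper `split_axis` is not part of the package):

```python
# One-axis sketch of the slicing scheme used by split_shape.
import numpy as np

def split_axis(length, n_splits, equal_shape=True):
    step = int(np.ceil(length / n_splits)) if equal_shape else int(length / n_splits)
    slices = []
    for i in range(n_splits):
        if i < n_splits - 1:
            slices.append(slice(i * step, (i + 1) * step))
        elif equal_shape:
            # Last subset keeps the same width and may overlap the previous one.
            slices.append(slice(length - step, length))
        else:
            slices.append(slice(i * step, length))
    return tuple(slices)

print(split_axis(10, 3))                     # (slice(0, 4), slice(4, 8), slice(6, 10))
print(split_axis(10, 3, equal_shape=False))  # (slice(0, 3), slice(3, 6), slice(6, 10))
```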
@@ -584,28 +614,25 @@ def get_rotation_matrices(
     angular_sampling: float, dim: int = 3, use_optimized_set: bool = True
 ) -> NDArray:
     """
-    Returns rotation matrices in format k x dim x dim, where k is determined
-    by ``angular_sampling``.
+    Returns rotation matrices with desired ``angular_sampling`` rate.
 
     Parameters
     ----------
     angular_sampling : float
-        The angle in degrees used for the generation of rotation matrices.
+        The desired angular sampling in degrees.
     dim : int, optional
         Dimension of the rotation matrices.
     use_optimized_set : bool, optional
-        Whether to use pre-computed rotational sets with more optimal sampling.
-        Currently only available when dim=3.
+        Use optimized rotational sets, True by default and available for dim=3.
 
     Notes
     -----
-    For the case of dim = 3 optimized rotational sets are used, otherwise
-    QR-decomposition.
+    For dim = 3 optimized sets are used, otherwise QR-decomposition.
 
     Returns
     -------
     NDArray
-        Array of shape (k, dim, dim) containing k rotation matrices.
+        Array of shape (n, d, d) containing n rotation matrices.
     """
     if dim == 3 and use_optimized_set:
         quaternions, *_ = load_quaternions_by_angle(angular_sampling)
@@ -706,144 +733,82 @@
     return rotation_angles
 
 
-def minimum_enclosing_box(
-    coordinates: NDArray,
-    margin: NDArray = None,
-    use_geometric_center: bool = False,
-) -> Tuple[int]:
+def load_quaternions_by_angle(
+    angular_sampling: float,
+) -> Tuple[NDArray, NDArray, float]:
     """
-    Computes the minimal enclosing box around coordinates with margin.
+    Get orientations and weights proportional to the given angular_sampling.
 
     Parameters
     ----------
-    coordinates : NDArray
-        Coordinates of which the enclosing box should be computed. The shape
-        of this array should be [d, n] with d dimensions and n coordinates.
-    margin : NDArray, optional
-        Box margin. Defaults to None.
-    use_geometric_center : bool, optional
-        Whether the box should accommodate the geometric or the coordinate
-        center. Defaults to False.
+    angular_sampling : float
+        Requested angular sampling.
 
     Returns
     -------
-    tuple
-        Integers corresponding to the minimum enclosing box shape.
+    Tuple[NDArray, NDArray, float]
+        Quaternion representations of orientations, weights associated with each
+        quaternion and closest angular sampling to the requested sampling.
     """
-    point_cloud = np.asarray(coordinates)
-    dim = point_cloud.shape[0]
-    point_cloud = point_cloud - point_cloud.min(axis=1)[:, None]
+    # Metadata contains (N orientations, rotational sampling, coverage as values)
+    with open(
+        os.path.join(os.path.dirname(__file__), "data", "metadata.yaml"), "r"
+    ) as infile:
+        metadata = yaml.full_load(infile)
+
+    set_diffs = {
+        setname: abs(angular_sampling - set_angle)
+        for setname, (_, set_angle, _) in metadata.items()
+    }
+    fname = min(set_diffs, key=set_diffs.get)
 
-    margin = np.zeros(dim) if margin is None else margin
-    margin = np.asarray(margin).astype(int)
+    infile = os.path.join(os.path.dirname(__file__), "data", fname)
+    quat_weights = np.load(infile)
 
-    norm_cloud = point_cloud - point_cloud.mean(axis=1)[:, None]
-    # Adding one avoids clipping during scipy.ndimage.affine_transform
-    shape = np.repeat(
-        np.ceil(2 * np.linalg.norm(norm_cloud, axis=0).max()) + 1, dim
-    ).astype(int)
-    if use_geometric_center:
-        hull = ConvexHull(point_cloud.T)
-        distance, _ = max_euclidean_distance(point_cloud[:, hull.vertices].T)
-        distance += np.linalg.norm(np.ones(dim))
-        shape = np.repeat(np.rint(distance).astype(int), dim)
+    quat = quat_weights[:, :4]
+    weights = quat_weights[:, -1]
+    angle = metadata[fname][0]
 
-    return shape
+    return quat, weights, angle
 
 
-def crop_input(
-    target: "Density",
-    template: "Density",
-    target_mask: "Density" = None,
-    template_mask: "Density" = None,
-    map_cutoff: float = 0,
-    template_cutoff: float = 0,
-) -> Tuple[int]:
+def quaternion_to_rotation_matrix(quaternions: NDArray) -> NDArray:
     """
-    Crop target and template maps for efficient fitting. Input densities
-    are cropped in place.
+    Convert quaternions to rotation matrices.
 
     Parameters
     ----------
-    target : Density
-        Target to be fitted on.
-    template : Density
-        Template to fit onto the target.
-    target_mask : Density, optional
-        Path to mask of target. Will be croppped like target.
-    template_mask : Density, optional
-        Path to mask of template. Will be cropped like template.
-    map_cutoff : float, optional
-        Cutoff value for trimming the target Density. Default is 0.
-    map_cutoff : float, optional
-        Cutoff value for trimming the template Density. Default is 0.
+    quaternions : NDArray
+        Quaternion data of shape (n, 4).
 
     Returns
     -------
-    Tuple[int]
-        Tuple containing reference fit index
+    NDArray
+        Rotation matrices corresponding to the given quaternions.
     """
-    convolution_shape_init = np.add(target.shape, template.shape) - 1
-    # If target and template are aligned, fitting should return this index
-    reference_fit = np.subtract(template.shape, 1)
-
-    target_box = tuple(slice(0, x) for x in target.shape)
-    if map_cutoff is not None:
-        target_box = target.trim_box(cutoff=map_cutoff)
-
-    target_mask_box = target_box
-    if target_mask is not None and map_cutoff is not None:
-        target_mask_box = target_mask.trim_box(cutoff=map_cutoff)
-    target_box = tuple(
-        slice(min(arr.start, mask.start), max(arr.stop, mask.stop))
-        for arr, mask in zip(target_box, target_mask_box)
-    )
-
-    template_box = tuple(slice(0, x) for x in template.shape)
-    if template_cutoff is not None:
-        template_box = template.trim_box(cutoff=template_cutoff)
-
-    template_mask_box = template_box
-    if template_mask is not None and template_cutoff is not None:
-        template_mask_box = template_mask.trim_box(cutoff=template_cutoff)
-    template_box = tuple(
-        slice(min(arr.start, mask.start), max(arr.stop, mask.stop))
-        for arr, mask in zip(template_box, template_mask_box)
-    )
+    q0 = quaternions[:, 0]
+    q1 = quaternions[:, 1]
+    q2 = quaternions[:, 2]
+    q3 = quaternions[:, 3]
 
-    cut_right = np.array(
-        [shape - x.stop for shape, x in zip(template.shape, template_box)]
-    )
-    cut_left = np.array([x.start for x in target_box])
+    s = np.linalg.norm(quaternions, axis=1) * 2
+    rotmat = np.zeros((quaternions.shape[0], 3, 3), dtype=np.float64)
 
-    origin_difference = np.divide(target.origin - template.origin, target.sampling_rate)
-    origin_difference = origin_difference.astype(int)
+    rotmat[:, 0, 0] = 1.0 - s * ((q2 * q2) + (q3 * q3))
+    rotmat[:, 0, 1] = s * ((q1 * q2) - (q0 * q3))
+    rotmat[:, 0, 2] = s * ((q1 * q3) + (q0 * q2))
 
-    target.adjust_box(target_box)
-    template.adjust_box(template_box)
+    rotmat[:, 1, 0] = s * ((q2 * q1) + (q0 * q3))
+    rotmat[:, 1, 1] = 1.0 - s * ((q3 * q3) + (q1 * q1))
+    rotmat[:, 1, 2] = s * ((q2 * q3) - (q0 * q1))
 
-    if target_mask is not None:
-        target_mask.adjust_box(target_box)
-    if template_mask is not None:
-        template_mask.adjust_box(template_box)
+    rotmat[:, 2, 0] = s * ((q3 * q1) - (q0 * q2))
+    rotmat[:, 2, 1] = s * ((q3 * q2) + (q0 * q1))
+    rotmat[:, 2, 2] = 1.0 - s * ((q1 * q1) + (q2 * q2))
 
-    reference_fit -= cut_right + cut_left + origin_difference
+    np.around(rotmat, decimals=8, out=rotmat)
 
-    convolution_shape = np.array(target.shape)
-    convolution_shape += np.array(template.shape) - 1
-
-    print(f"Cropped volume of target is: {target.shape}")
-    print(f"Cropped volume of template is: {template.shape}")
-    saving = 1 - (np.prod(convolution_shape)) / np.prod(convolution_shape_init)
-    saving *= 100
-
-    print(
-        "Cropping changed array size from "
-        f"{round(4*np.prod(convolution_shape_init)/1e6, 3)} MB "
-        f"to {round(4*np.prod(convolution_shape)/1e6, 3)} MB "
-        f"({'-' if saving > 0 else ''}{abs(round(saving, 2))}%)"
-    )
-    return reference_fit
+    return rotmat
 
 
 def euler_to_rotationmatrix(angles: Tuple[float], convention: str = "zyx") -> NDArray:
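Editor's note: the closed-form conversion added in `quaternion_to_rotation_matrix` uses what appears to be scalar-first (w, x, y, z) quaternions (an inference from the formula, not stated in the diff). A small standalone check, under that assumption, that the formula yields a proper rotation matrix for a unit quaternion:

```python
# For a unit quaternion the factor s equals 2, and the resulting matrix should
# be orthonormal with determinant +1.
import numpy as np

rng = np.random.default_rng(0)
q = rng.normal(size=4)
q /= np.linalg.norm(q)
w, x, y, z = q

s = 2.0  # 2 * |q| for a unit quaternion
R = np.array([
    [1 - s * (y * y + z * z), s * (x * y - w * z), s * (x * z + w * y)],
    [s * (y * x + w * z), 1 - s * (z * z + x * x), s * (y * z - w * x)],
    [s * (z * x - w * y), s * (z * y + w * x), 1 - s * (x * x + y * y)],
])

assert np.allclose(R @ R.T, np.eye(3), atol=1e-12)
assert np.isclose(np.linalg.det(R), 1.0)
```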
@@ -866,12 +831,8 @@ def euler_to_rotationmatrix(angles: Tuple[float], convention: str = "zyx") -> ND
     angle_convention = convention[:n_angles]
     if n_angles == 1:
         angles = (angles, 0, 0)
-    rotation_matrix = (
-        Rotation.from_euler(angle_convention, angles, degrees=True)
-        .as_matrix()
-        .astype(np.float32)
-    )
-    return rotation_matrix
+    rotation_matrix = Rotation.from_euler(angle_convention, angles, degrees=True)
+    return rotation_matrix.as_matrix().astype(np.float32)
 
 
 def euler_from_rotationmatrix(
@@ -883,9 +844,10 @@
     Parameters
     ----------
     rotation_matrix : NDArray
-        A 2 x 2 or 3 x 3 rotation matrix in z y x form.
+        A 2 x 2 or 3 x 3 rotation matrix in zyx form.
     convention : str, optional
-        Euler angle convention.
+        Euler angle convention, zyx by default.
+
     Returns
     -------
     Tuple
@@ -895,12 +857,8 @@ def euler_from_rotationmatrix(
         temp_matrix = np.eye(3)
         temp_matrix[:2, :2] = rotation_matrix
         rotation_matrix = temp_matrix
-    euler_angles = (
-        Rotation.from_matrix(rotation_matrix)
-        .as_euler(convention, degrees=True)
-        .astype(np.float32)
-    )
-    return euler_angles
+    rotation = Rotation.from_matrix(rotation_matrix)
+    return rotation.as_euler(convention, degrees=True).astype(np.float32)
 
 
 def rotation_aligning_vectors(
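Editor's note: the two hunks above only restructure the scipy calls; the Euler-angle behaviour is unchanged. A quick round-trip sanity check using scipy directly (not the pytme wrappers):

```python
# Euler angles (degrees, "zyx") -> rotation matrix -> Euler angles.
import numpy as np
from scipy.spatial.transform import Rotation

angles = (30.0, -15.0, 45.0)
matrix = Rotation.from_euler("zyx", angles, degrees=True).as_matrix().astype(np.float32)
recovered = Rotation.from_matrix(matrix).as_euler("zyx", degrees=True)
assert np.allclose(recovered, angles, atol=1e-4)  # float32 cast limits precision
```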
@@ -961,23 +919,19 @@
     Parameters
     ----------
     coordinates : NDArray
-        An array representing the coordinates to be transformed [d x N].
+        An array representing the coordinates to be transformed (d,n).
     rotation_matrix : NDArray
-        The rotation matrix to be applied [d x d].
+        The rotation matrix to be applied (d,d).
     translation : NDArray
-        The translation vector to be applied [d].
+        The translation vector to be applied (d,).
     out : NDArray
-        The output array to store the transformed coordinates.
+        The output array to store the transformed coordinates (d,n).
     coordinates_mask : NDArray, optional
-        An array representing the mask for the coordinates [d x T].
+        An array representing the mask for the coordinates (d,t).
     out_mask : NDArray, optional
-        The output array to store the transformed coordinates mask.
+        The output array to store the transformed coordinates mask (d,t).
     use_geometric_center : bool, optional
         Whether to use geometric or coordinate center.
-
-    Returns
-    -------
-    None
     """
     coordinate_dtype = coordinates.dtype
     center = coordinates.mean(axis=1) if center is None else center
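Editor's note: the updated `rigid_transform` docstring standardizes on a (d, n) coordinate layout. As a hedged illustration of that convention only (the real function writes into `out` in place and also handles coordinate masks), a minimal NumPy sketch:

```python
# Rotate (d, n) coordinates about their coordinate center, then translate.
import numpy as np

coordinates = np.array([[0.0, 1.0, 2.0], [0.0, 0.0, 1.0]])  # d = 2, n = 3
theta = np.deg2rad(90)
rotation_matrix = np.array(
    [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
)
translation = np.array([1.0, -1.0])

center = coordinates.mean(axis=1, keepdims=True)
out = rotation_matrix @ (coordinates - center) + center + translation[:, None]
print(out.round(3))
```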
@@ -1004,71 +958,67 @@
     out += translation[:, None]
 
 
-def _format_string(string: str) -> str:
+def minimum_enclosing_box(
+    coordinates: NDArray, margin: NDArray = None, use_geometric_center: bool = False
+) -> Tuple[int]:
     """
-    Formats a string by adding quotation marks if it contains white spaces.
+    Computes the minimal enclosing box around coordinates with margin.
 
     Parameters
     ----------
-    string : str
-        Input string to be formatted.
+    coordinates : NDArray
+        Coordinates of shape (d,n) to compute the enclosing box of.
+    margin : NDArray, optional
+        Box margin, zero by default.
+    use_geometric_center : bool, optional
+        Whether box accommodates the geometric or coordinate center, False by default.
 
     Returns
     -------
-    str
-        Formatted string with added quotation marks if needed.
+    tuple of ints
+        Minimum enclosing box shape.
     """
-    if " " in string:
-        return f"'{string}'"
-    # Occurs e.g. for C1' atoms. The trailing whitespace is necessary.
-    if string.count("'") == 1:
-        return f'"{string}"'
-    return string
-
+    point_cloud = np.asarray(coordinates)
+    dim = point_cloud.shape[0]
+    point_cloud = point_cloud - point_cloud.min(axis=1)[:, None]
 
-def _format_mmcif_colunns(subdict: Dict) -> Dict:
-    """
-    Formats the columns of a mmcif dictionary.
+    margin = np.zeros(dim) if margin is None else margin
+    margin = np.asarray(margin).astype(int)
 
-    Parameters
-    ----------
-    subdict : dict
-        Input dictionary where each key corresponds to a column and the
-        values are iterables containing the column values.
+    norm_cloud = point_cloud - point_cloud.mean(axis=1)[:, None]
+    # Adding one avoids clipping during scipy.ndimage.affine_transform
+    shape = np.repeat(
+        np.ceil(2 * np.linalg.norm(norm_cloud, axis=0).max()) + 1, dim
+    ).astype(int)
+    if use_geometric_center:
+        hull = ConvexHull(point_cloud.T)
+        distance, _ = max_euclidean_distance(point_cloud[:, hull.vertices].T)
+        distance += np.linalg.norm(np.ones(dim))
+        shape = np.repeat(np.rint(distance).astype(int), dim)
 
-    Returns
-    -------
-    dict
-        Formatted dictionary with the columns of the mmcif file.
-    """
-    subdict = {k: [_format_string(s) for s in v] for k, v in subdict.items()}
-    key_length = {
-        key: len(max(value, key=lambda x: len(x), default=""))
-        for key, value in subdict.items()
-    }
-    padded_subdict = {
-        key: [s.ljust(key_length[key] + 1) for s in values]
-        for key, values in subdict.items()
-    }
-    return padded_subdict
+    return shape
 
 
-def create_mask(
-    mask_type: str, sigma_decay: float = 0, mask_cutoff: float = 0.135, **kwargs
-) -> NDArray:
+def create_mask(mask_type: str, sigma_decay: float = 0, **kwargs) -> NDArray:
     """
     Creates a mask of the specified type.
 
     Parameters
     ----------
     mask_type : str
-        Type of the mask to be created. Can be "ellipse", "box", or "tube".
+        Type of the mask to be created. Can be one of:
+
+        +---------+----------------------------------------------------------+
+        | box     | Box mask (see :py:meth:`box_mask`)                       |
+        +---------+----------------------------------------------------------+
+        | tube    | Cylindrical mask (see :py:meth:`tube_mask`)              |
+        +---------+----------------------------------------------------------+
+        | ellipse | Ellipsoidal mask (see :py:meth:`elliptical_mask`)        |
+        +---------+----------------------------------------------------------+
     sigma_decay : float, optional
-        Standard deviation of an optionally applied Gaussian filter.
-    mask_cutoff : float, optional
-        Values below mask_cutoff will be set to zero. By default, exp(-2).
+        Smoothing along mask edges using a Gaussian filter, 0 by default.
     kwargs : dict
-        Additional parameters required by the mask creating functions.
+        Parameters passed to the indivdual mask creation funcitons.
 
     Returns
     -------
@@ -1079,12 +1029,6 @@ def create_mask(
     ------
     ValueError
         If the mask_type is invalid.
-
-    See Also
-    --------
-    :py:meth:`elliptical_mask`
-    :py:meth:`box_mask`
-    :py:meth:`tube_mask`
     """
     mapping = {"ellipse": elliptical_mask, "box": box_mask, "tube": tube_mask}
     if mask_type not in mapping:
@@ -1092,9 +1036,9 @@
 
     mask = mapping[mask_type](**kwargs)
     if sigma_decay > 0:
-        mask = gaussian_filter(mask.astype(np.float32), sigma=sigma_decay)
-
-    mask[mask < mask_cutoff] = 0
+        mask_filter = gaussian_filter(mask.astype(np.float32), sigma=sigma_decay)
+        mask = np.add(mask, (1 - mask) * mask_filter)
+        mask[mask < np.exp(-np.square(sigma_decay))] = 0
 
     return mask
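Editor's note: as I read the new `create_mask` code, the Gaussian smoothing now only adds intensity outside the original support and the cutoff scales with `sigma_decay`. A small scipy/NumPy sketch of that behaviour (illustrative, not the package function):

```python
# New-style edge smoothing: interior values stay exactly 1, the filtered mask
# only contributes outside the support, and values below exp(-sigma_decay**2)
# are truncated to zero.
import numpy as np
from scipy.ndimage import gaussian_filter

sigma_decay = 2.0
mask = np.zeros((32, 32), dtype=np.float32)
mask[12:20, 12:20] = 1

mask_filter = gaussian_filter(mask, sigma=sigma_decay)
smoothed = mask + (1 - mask) * mask_filter
smoothed[smoothed < np.exp(-np.square(sigma_decay))] = 0

assert np.all(smoothed[14:18, 14:18] == 1)
```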
 
@@ -1126,6 +1070,7 @@ def elliptical_mask(
 
     Examples
     --------
+    >>> from tme.matching_utils import elliptical_mask
     >>> mask = elliptical_mask(shape = (20,20), radius = (5,5), center = (10,10))
     """
     center, shape, radius = np.asarray(center), np.asarray(shape), np.asarray(radius)
@@ -1154,17 +1099,23 @@ def box_mask(shape: Tuple[int], center: Tuple[int], height: Tuple[int]) -> np.nd
 
     Parameters
     ----------
-    shape : Tuple[int]
+    shape : tuple of ints
         Shape of the output array.
-    center : Tuple[int]
+    center : tuple of ints
         Center point coordinates of the box.
-    height : Tuple[int]
+    height : tuple of ints
         Height (side length) of the box along each axis.
 
     Returns
     -------
     NDArray
         The created box mask.
+
+    Raises
+    ------
+    ValueError
+        If ``shape`` and ``center`` do not have the same length.
+        If ``center`` and ``height`` do not have the same length.
     """
     if len(shape) != len(center) or len(center) != len(height):
         raise ValueError("The length of shape, center, and height must be consistent.")
@@ -1216,9 +1167,9 @@ def tube_mask(
     Raises
     ------
     ValueError
-        If the inner radius is larger than the outer radius, height is larger
-        than the symmetry axis shape, or if base_center and shape do not have the
-        same length.
+        If ``inner_radius`` is larger than ``outer_radius``.
+        If ``height`` is larger than the symmetry axis.
+        If ``base_center`` and ``shape`` do not have the same length.
     """
     if inner_radius > outer_radius:
         raise ValueError("inner_radius should be smaller than outer_radius.")
@@ -1274,94 +1225,42 @@ def scramble_phases(
     arr: NDArray,
     noise_proportion: float = 0.5,
     seed: int = 42,
-    normalize_power: bool = True,
+    normalize_power: bool = False,
 ) -> NDArray:
     """
-    Applies random phase scrambling to a given array.
-
-    This function takes an input array, applies a Fourier transform, then scrambles the
-    phase with a given proportion of noise, and finally applies an
-    inverse Fourier transform to the scrambled data. The phase scrambling
-    is controlled by a random seed.
+    Perform random phase scrambling of ``arr``.
 
     Parameters
     ----------
     arr : NDArray
-        The input array to be scrambled.
+        Input data.
     noise_proportion : float, optional
-        The proportion of noise in the phase scrambling, by default 0.5.
+        Proportion of scrambled phases, 0.5 by default.
     seed : int, optional
-        The seed for the random phase scrambling, by default 42.
+        The seed for the random phase scrambling, 42 by default.
     normalize_power : bool, optional
-        Whether the returned template should have the same sum of squares as arr.
+        Return value has same sum of squares as ``arr``.
 
     Returns
     -------
     NDArray
-        The array with scrambled phases.
-
-    Raises
-    ------
-    ValueError
-        If noise_proportion is not within [0, 1].
+        Phase scrambled version of ``arr``.
     """
-    if noise_proportion < 0 or noise_proportion > 1:
-        raise ValueError("noise_proportion has to be within [0, 1].")
+    np.random.seed(seed)
+    noise_proportion = max(min(noise_proportion, 1), 0)
 
     arr_fft = np.fft.fftn(arr)
+    amp, ph = np.abs(arr_fft), np.angle(arr_fft)
 
-    amp = np.abs(arr_fft)
-    ph = np.angle(arr_fft)
-
-    np.random.seed(seed)
     ph_noise = np.random.permutation(ph)
     ph_new = ph * (1 - noise_proportion) + ph_noise * noise_proportion
     ret = np.real(np.fft.ifftn(amp * np.exp(1j * ph_new)))
 
     if normalize_power:
-        np.divide(
-            np.subtract(ret, ret.min()), np.subtract(ret.max(), ret.min()), out=ret
-        )
+        np.divide(ret - ret.min(), ret.max() - ret.min(), out=ret)
         np.multiply(ret, np.subtract(arr.max(), arr.min()), out=ret)
         np.add(ret, arr.min(), out=ret)
-
         scaling = np.divide(np.abs(arr).sum(), np.abs(ret).sum())
         np.multiply(ret, scaling, out=ret)
 
     return ret
-
-
-def conditional_execute(func: Callable, execute_operation: bool = True) -> Callable:
-    """
-    Return the given function or a no-operation function based on execute_operation.
-
-    Parameters
-    ----------
-    func : callable
-        The function to be executed if execute_operation is True.
-    execute_operation : bool, optional
-        A flag that determines whether to return `func` or a no-operation function.
-        Default is True.
-
-    Returns
-    -------
-    callable
-        Either the given function `func` or a no-operation function.
-
-    Examples
-    --------
-    >>> def greet(name):
-    ...     return f"Hello, {name}!"
-    ...
-    >>> operation = conditional_execute(greet, False)
-    >>> operation("Alice")
-    >>> operation = conditional_execute(greet, True)
-    >>> operation("Alice")
-    'Hello, Alice!'
-    """
-
-    def noop(*args, **kwargs):
-        """No operation function."""
-        pass
-
-    return func if execute_operation else noop
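
Editor's note: to illustrate the phase-scrambling technique documented above, here is a self-contained sketch (the helper `scramble_phases_np` is hypothetical; it permutes the flattened phases and omits the optional power normalization):

```python
# Scramble FFT phases while keeping the amplitude spectrum, then invert.
import numpy as np

def scramble_phases_np(arr, noise_proportion=0.5, seed=42):
    rng = np.random.default_rng(seed)
    noise_proportion = max(min(noise_proportion, 1), 0)
    arr_fft = np.fft.fftn(arr)
    amp, ph = np.abs(arr_fft), np.angle(arr_fft)
    ph_noise = rng.permutation(ph.ravel()).reshape(ph.shape)
    ph_new = ph * (1 - noise_proportion) + ph_noise * noise_proportion
    return np.real(np.fft.ifftn(amp * np.exp(1j * ph_new)))

template = np.random.rand(16, 16, 16)
scrambled = scramble_phases_np(template, noise_proportion=1.0)
# Real-space correlation with the input drops as noise_proportion increases.
print(np.corrcoef(template.ravel(), scrambled.ravel())[0, 1])
```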