pytme 0.2.3__cp311-cp311-macosx_14_0_arm64.whl → 0.2.5__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/match_template.py +8 -8
  2. {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/preprocess.py +22 -6
  3. {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/preprocessor_gui.py +9 -14
  4. {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/METADATA +1 -1
  5. pytme-0.2.5.dist-info/RECORD +119 -0
  6. {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/WHEEL +1 -1
  7. {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/top_level.txt +1 -0
  8. scripts/match_template.py +8 -8
  9. scripts/preprocess.py +22 -6
  10. scripts/preprocessor_gui.py +9 -14
  11. tests/__init__.py +0 -0
  12. tests/data/.DS_Store +0 -0
  13. tests/data/Blurring/.DS_Store +0 -0
  14. tests/data/Blurring/blob_width18.npy +0 -0
  15. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  16. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  17. tests/data/Blurring/hamming_width6.npy +0 -0
  18. tests/data/Blurring/kaiserb_width18.npy +0 -0
  19. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  20. tests/data/Blurring/mean_size5.npy +0 -0
  21. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  22. tests/data/Blurring/rank_rank3.npy +0 -0
  23. tests/data/Maps/.DS_Store +0 -0
  24. tests/data/Maps/emd_8621.mrc.gz +0 -0
  25. tests/data/README.md +2 -0
  26. tests/data/Raw/.DS_Store +0 -0
  27. tests/data/Raw/em_map.map +0 -0
  28. tests/data/Structures/.DS_Store +0 -0
  29. tests/data/Structures/1pdj.cif +3339 -0
  30. tests/data/Structures/1pdj.pdb +1429 -0
  31. tests/data/Structures/5khe.cif +3685 -0
  32. tests/data/Structures/5khe.ent +2210 -0
  33. tests/data/Structures/5khe.pdb +2210 -0
  34. tests/data/Structures/5uz4.cif +70548 -0
  35. tests/preprocessing/__init__.py +0 -0
  36. tests/preprocessing/test_compose.py +76 -0
  37. tests/preprocessing/test_frequency_filters.py +178 -0
  38. tests/preprocessing/test_preprocessor.py +136 -0
  39. tests/preprocessing/test_utils.py +79 -0
  40. tests/test_analyzer.py +310 -0
  41. tests/test_backends.py +375 -0
  42. tests/test_density.py +508 -0
  43. tests/test_extensions.py +130 -0
  44. tests/test_matching_cli.py +283 -0
  45. tests/test_matching_data.py +162 -0
  46. tests/test_matching_exhaustive.py +162 -0
  47. tests/test_matching_memory.py +30 -0
  48. tests/test_matching_optimization.py +226 -0
  49. tests/test_matching_utils.py +326 -0
  50. tests/test_orientations.py +173 -0
  51. tests/test_packaging.py +95 -0
  52. tests/test_parser.py +33 -0
  53. tests/test_structure.py +243 -0
  54. tme/__init__.py +0 -1
  55. tme/__version__.py +1 -1
  56. tme/backends/jax_backend.py +3 -9
  57. tme/data/scattering_factors.pickle +0 -0
  58. tme/density.py +14 -10
  59. tme/external/bindings.cpp +332 -0
  60. tme/matching_data.py +14 -12
  61. tme/matching_exhaustive.py +17 -15
  62. tme/matching_optimization.py +215 -208
  63. tme/matching_utils.py +1 -0
  64. tme/preprocessing/_utils.py +14 -14
  65. tme/preprocessing/composable_filter.py +0 -2
  66. tme/preprocessing/compose.py +4 -4
  67. tme/preprocessing/frequency_filters.py +32 -35
  68. tme/preprocessing/tilt_series.py +198 -117
  69. tme/preprocessor.py +24 -246
  70. tme/structure.py +22 -22
  71. pytme-0.2.3.dist-info/RECORD +0 -75
  72. tme/matching_memory.py +0 -383
  73. {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/estimate_ram_usage.py +0 -0
  74. {pytme-0.2.3.data → pytme-0.2.5.data}/scripts/postprocess.py +0 -0
  75. {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/LICENSE +0 -0
  76. {pytme-0.2.3.dist-info → pytme-0.2.5.dist-info}/entry_points.txt +0 -0
tme/matching_optimization.py CHANGED
@@ -1,17 +1,16 @@
- """ Implements various methods for non-exhaustive template matching
- based on numerical optimization.
+ """ Implements methods for non-exhaustive template matching.

  Copyright (c) 2023 European Molecular Biology Laboratory

  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
  """
-
- from typing import Tuple, Dict
+ import warnings
+ from typing import Tuple, List, Dict
  from abc import ABC, abstractmethod

  import numpy as np
  from scipy.spatial import KDTree
- from scipy.ndimage import laplace, map_coordinates
+ from scipy.ndimage import laplace, map_coordinates, sobel
  from scipy.optimize import (
  minimize,
  basinhopping,
@@ -157,15 +156,13 @@ class _MatchDensityToDensity(ABC):
  if out is None:
  out = np.zeros_like(arr)

- map_coordinates(arr, self.grid_out, order=order, output=out.ravel())
+ self._interpolate(arr, self.grid_out, order=order, out=out.ravel())

  if out_mask is None and arr_mask is not None:
  out_mask = np.zeros_like(arr_mask)

  if arr_mask is not None:
- map_coordinates(
- arr_mask, self.grid_out, order=order, output=out_mask.ravel()
- )
+ self._interpolate(arr_mask, self.grid_out, order=order, out=out.ravel())

  match return_type:
  case 0:
@@ -177,6 +174,12 @@ class _MatchDensityToDensity(ABC):
  case 3:
  return out, out_mask

+ @staticmethod
+ def _interpolate(data, positions, order: int = 1, out=None):
+ return map_coordinates(
+ data, positions, order=order, mode="constant", output=out
+ )
+
  def score_translation(self, x: Tuple[float]) -> float:
  """
  Computes the score after a given translation.
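The new `_interpolate` helper above delegates to `scipy.ndimage.map_coordinates`, which samples an array at fractional coordinates supplied as a (ndim, n_points) array. A minimal sketch of that behaviour, using made-up data rather than pytme objects:

```python
# Minimal sketch of what _interpolate delegates to: linear interpolation of an
# array at fractional coordinates via scipy.ndimage.map_coordinates.
import numpy as np
from scipy.ndimage import map_coordinates

arr = np.arange(16, dtype=np.float32).reshape(4, 4)
coords = np.array([[0.5, 1.0], [0.5, 2.5]])  # shape (ndim, n_points)
values = map_coordinates(arr, coords, order=1, mode="constant")
print(values)  # [2.5 6.5]: bilinear samples at (0.5, 0.5) and (1.0, 2.5)
```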
@@ -291,6 +294,8 @@ class _MatchCoordinatesToDensity(_MatchDensityToDensity):
  A d-dimensional mask to be applied to the target.
  negate_score : bool, optional
  Whether the final score should be multiplied by negative one. Default is True.
+ return_gradient : bool, optional
+ Invoking __call_ returns a tuple of score and parameter gradient. Default is False.
  **kwargs : Dict, optional
  Keyword arguments propagated to downstream functions.
  """
@@ -303,40 +308,43 @@ class _MatchCoordinatesToDensity(_MatchDensityToDensity):
  template_mask_coordinates: NDArray = None,
  target_mask: NDArray = None,
  negate_score: bool = True,
+ return_gradient: bool = False,
+ interpolation_order: int = 1,
  **kwargs: Dict,
  ):
- self.eps = be.eps(target.dtype)
- self.target_density = target
- self.target_mask_density = target_mask
+ self.target = target.astype(np.float32)
+ self.target_mask = None
+ if target_mask is not None:
+ self.target_mask = target_mask.astype(np.float32)

- self.template_weights = template_weights
- self.template_coordinates = template_coordinates
- self.template_coordinates_rotated = np.copy(self.template_coordinates).astype(
- np.float32
+ self.eps = be.eps(self.target.dtype)
+
+ self.target_grad = np.stack(
+ [sobel(self.target, axis=i) for i in range(self.target.ndim)]
  )
- if template_mask_coordinates is None:
- template_mask_coordinates = template_coordinates.copy()

- self.template_mask_coordinates = template_mask_coordinates
- self.template_mask_coordinates_rotated = template_mask_coordinates
+ self.n_points = template_coordinates.shape[1]
+ self.template = template_coordinates.astype(np.float32)
+ self.template_rotated = np.zeros_like(self.template)
+ self.template_weights = template_weights.astype(np.float32)
+ self.template_center = np.mean(self.template, axis=1)[:, None]
+
+ self.template_mask, self.template_mask_rotated = None, None
  if template_mask_coordinates is not None:
- self.template_mask_coordinates_rotated = np.copy(
- self.template_mask_coordinates
- ).astype(np.float32)
+ self.template_mask = template_mask_coordinates.astype(np.float32)
+ self.template_mask_rotated = np.empty_like(self.template_mask)

  self.denominator = 1
  self.score_sign = -1 if negate_score else 1
+ self.interpolation_order = interpolation_order

- self.in_volume, self.in_volume_mask = self.map_coordinates_to_array(
- coordinates=self.template_coordinates_rotated,
- coordinates_mask=self.template_mask_coordinates_rotated,
- array_origin=be.zeros(target.ndim),
- array_shape=self.target_density.shape,
- sampling_rate=be.full(target.ndim, fill_value=1),
+ self._target_values = self._interpolate(
+ self.target, self.template, order=self.interpolation_order
  )

- if hasattr(self, "_post_init"):
- self._post_init(**kwargs)
+ if return_gradient and not hasattr(self, "grad"):
+ raise NotImplementedError(f"{type(self)} does not have grad method.")
+ self.return_gradient = return_gradient

  def score(self, x: Tuple[float]):
  """
@@ -356,119 +364,39 @@ class _MatchCoordinatesToDensity(_MatchDensityToDensity):
  translation, rotation_matrix = _format_rigid_transform(x)

  rigid_transform(
- coordinates=self.template_coordinates,
- coordinates_mask=self.template_mask_coordinates,
+ coordinates=self.template,
+ coordinates_mask=self.template_mask,
  rotation_matrix=rotation_matrix,
  translation=translation,
- out=self.template_coordinates_rotated,
- out_mask=self.template_mask_coordinates_rotated,
+ out=self.template_rotated,
+ out_mask=self.template_mask_rotated,
  use_geometric_center=False,
  )

- self.in_volume, self.in_volume_mask = self.map_coordinates_to_array(
- coordinates=self.template_coordinates_rotated,
- coordinates_mask=self.template_mask_coordinates_rotated,
- array_origin=be.zeros(rotation_matrix.shape[0]),
- array_shape=self.target_density.shape,
- sampling_rate=be.full(rotation_matrix.shape[0], fill_value=1),
+ self._target_values = self._interpolate(
+ self.target, self.template_rotated, order=self.interpolation_order
  )

- return self()
+ score = self()
+ if not self.return_gradient:
+ return score

- @staticmethod
- def array_from_coordinates(
- coordinates: NDArray,
- weights: NDArray,
- sampling_rate: NDArray,
- origin: NDArray = None,
- shape: NDArray = None,
- ) -> Tuple[NDArray, NDArray, NDArray]:
- """
- Create a volume from coordinates, using given weights and voxel size.
+ return score, self.grad()

- Parameters
- ----------
- coordinates : NDArray
- An array representing the coordinates [d x N].
- weights : NDArray
- An array representing the weights for each coordinate [N].
- sampling_rate : NDArray
- The size of a voxel in the volume.
- origin : NDArray, optional
- The origin of the volume.
- shape : NDArray, optional
- The shape of the volume.
+ def _interpolate_gradient(self, positions):
+ ret = be.zeros(positions.shape, dtype=positions.dtype)

- Returns
- -------
- tuple
- Returns the generated volume, positions of coordinates, and origin.
- """
- if origin is None:
- origin = coordinates.min(axis=1)
-
- positions = np.divide(coordinates - origin[:, None], sampling_rate[:, None])
- positions = positions.astype(int)
-
- if shape is None:
- shape = positions.max(axis=1) + 1
+ for k in range(self.target_grad.shape[0]):
+ ret[k, :] = self._interpolate(
+ self.target_grad[k], positions, order=self.interpolation_order
+ )

- arr = np.zeros(shape, dtype=np.float32)
- np.add.at(arr, tuple(positions), weights)
- return arr, positions, origin
+ return ret

  @staticmethod
- def map_coordinates_to_array(
- coordinates: NDArray,
- array_shape: NDArray,
- array_origin: NDArray,
- sampling_rate: NDArray,
- coordinates_mask: NDArray = None,
- ) -> Tuple[NDArray, NDArray]:
- """
- Map coordinates to a volume based on given voxel size and origin.
-
- Parameters
- ----------
- coordinates : NDArray
- An array representing the coordinates to be mapped [d x N].
- array_shape : NDArray
- The shape of the array to which the coordinates are mapped.
- array_origin : NDArray
- The origin of the array to which the coordinates are mapped.
- sampling_rate : NDArray
- The size of a voxel in the array.
- coordinates_mask : NDArray, optional
- An array representing the mask for the coordinates [d x T].
-
- Returns
- -------
- tuple
- Returns transformed coordinates, transformed coordinates mask,
- mask for in_volume points, and mask for in_volume points in mask.
- """
- np.divide(
- coordinates - array_origin[:, None], sampling_rate[:, None], out=coordinates
- )
-
- in_volume = np.logical_and(
- coordinates < np.array(array_shape)[:, None],
- coordinates >= 0,
- ).min(axis=0)
-
- in_volume_mask = None
- if coordinates_mask is not None:
- np.divide(
- coordinates_mask - array_origin[:, None],
- sampling_rate[:, None],
- out=coordinates_mask,
- )
- in_volume_mask = np.logical_and(
- coordinates_mask < np.array(array_shape)[:, None],
- coordinates_mask >= 0,
- ).min(axis=0)
-
- return in_volume, in_volume_mask
+ def _torques(positions, center, gradients):
+ positions_center = (positions - center).T
+ return be.cross(positions_center, gradients.T).T


  class _MatchCoordinatesToCoordinates(_MatchDensityToDensity):
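The new `_torques` helper combines each rotated template point with the target gradient sampled at that point, forming r_i × ∇v(x_i) with r_i taken relative to the template center. A numpy-only sketch for coordinates stored as (d, N), the layout used throughout this module (the arrays below are illustrative):

```python
# Sketch of the torque term r_i x grad_i for (3, N) coordinate and gradient arrays.
import numpy as np

positions = np.random.rand(3, 5)                # rotated template points
center = positions.mean(axis=1, keepdims=True)  # template center, shape (3, 1)
gradients = np.random.rand(3, 5)                # target gradient sampled at each point

torques = np.cross((positions - center).T, gradients.T).T  # shape (3, N)
print(torques.shape)
```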
@@ -635,14 +563,43 @@ class CrossCorrelation(_MatchCoordinatesToDensity):

  def __call__(self) -> float:
  """Returns the score of the current configuration."""
- score = np.dot(
- self.target_density[
- tuple(self.template_coordinates_rotated[:, self.in_volume].astype(int))
- ],
- self.template_weights[self.in_volume],
+ score = be.dot(self._target_values, self.template_weights)
+ score /= self.denominator * self.score_sign
+ return score
+
+ def grad(self):
+ """
+ Calculate the gradient of the cost function w.r.t. translation and rotation.
+
+ .. math::
+
+ \\nabla f = -\\frac{1}{N} \\begin{bmatrix}
+ \\sum_i w_i \\nabla v(x_i) \\\\
+ \\sum_i w_i (r_i \\times \\nabla v(x_i))
+ \\end{bmatrix}
+
+ where :math:`N` is the number of points, :math:`w_i` are weights,
+ :math:`x_i` are rotated template positions, and :math:`r_i` are
+ positions relative to the template center.
+
+ Returns
+ -------
+ np.ndarray
+ Negative gradient of the cost function: [dx, dy, dz, dRx, dRy, dRz].
+
+ """
+ grad = self._interpolate_gradient(positions=self.template_rotated)
+ torque = self._torques(
+ positions=self.template_rotated, gradients=grad, center=self.template_center
  )
- score /= self.denominator
- return score * self.score_sign
+
+ translation_grad = be.sum(grad * self.template_weights, axis=1)
+ torque_grad = be.sum(torque * self.template_weights, axis=1)
+
+ # <u, dv/dx> / <u, r x dv/dx>
+ total_grad = be.concatenate([translation_grad, torque_grad])
+ total_grad = be.divide(total_grad, self.n_points, out=total_grad)
+ return -total_grad


  class LaplaceCrossCorrelation(CrossCorrelation):
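A gradient such as the one added above can be sanity-checked against central finite differences of the score. Below is a generic sketch for any objective returning a (value, gradient) pair; the commented usage assumes a score object built with return_gradient=True and the six-parameter layout from the docstring, which is not verified here:

```python
# Generic central-difference check for a callable returning (value, gradient).
import numpy as np

def check_gradient(fun, x0, eps=1e-4):
    _, analytic = fun(x0)
    numeric = np.zeros_like(analytic)
    for i in range(x0.size):
        step = np.zeros_like(x0)
        step[i] = eps
        numeric[i] = (fun(x0 + step)[0] - fun(x0 - step)[0]) / (2 * eps)
    return np.max(np.abs(analytic - numeric))

# Assumed usage: a score object created with return_gradient=True, whose
# score(x) returns (value, gradient) for x = [dx, dy, dz, dRx, dRy, dRz].
# print(check_gradient(score_object.score, np.zeros(6)))
```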
@@ -658,15 +615,18 @@ class LaplaceCrossCorrelation(CrossCorrelation):

  __doc__ += _MatchCoordinatesToDensity.__doc__

- def _post_init(self, **kwargs):
- self.target_density = laplace(self.target_density)
+ def __init__(self, **kwargs):
+ kwargs["target"] = laplace(kwargs["target"])

- arr, positions, _ = self.array_from_coordinates(
- self.template_coordinates,
- self.template_weights,
- np.ones(self.template_coordinates.shape[0]),
- )
- self.template_weights = laplace(arr)[tuple(positions)]
+ coordinates = kwargs["template_coordinates"]
+ origin = coordinates.min(axis=1)
+ positions = (coordinates - origin[:, None]).astype(int)
+ shape = positions.max(axis=1) + 1
+ arr = np.zeros(shape, dtype=np.float32)
+ np.add.at(arr, tuple(positions), kwargs["template_weights"])
+
+ kwargs["template_weights"] = laplace(arr)[tuple(positions)]
+ super().__init__(**kwargs)


  class NormalizedCrossCorrelation(CrossCorrelation):
@@ -696,24 +656,76 @@ class NormalizedCrossCorrelation(CrossCorrelation):

  __doc__ += _MatchCoordinatesToDensity.__doc__
  def __call__(self) -> float:
- n_observations = be.sum(self.in_volume_mask)
- target_coordinates = be.astype(
- self.template_mask_coordinates_rotated[:, self.in_volume_mask], int
+ denominator = be.multiply(
+ np.linalg.norm(self.template_weights), np.linalg.norm(self._target_values)
  )
- target_weight = self.target_density[tuple(target_coordinates)]
- ex2 = be.divide(be.sum(be.square(target_weight)), n_observations)
- e2x = be.square(be.divide(be.sum(target_weight), n_observations))
-
- denominator = be.maximum(be.subtract(ex2, e2x), 0.0)
- denominator = be.sqrt(denominator)
- denominator = be.multiply(denominator, n_observations)

- if denominator <= self.eps:
+ if denominator <= 0:
  return 0.0

  self.denominator = denominator
  return super().__call__()

+ def grad(self):
+ """
+ Calculate the normalized gradient of the cost function w.r.t. translation and rotation.
+
+ .. math::
+
+ \\nabla f = -\\frac{1}{N|w||v|^3} \\begin{bmatrix}
+ (\\sum_i w_i \\nabla v(x_i))|v|^2 - (\\sum_i v(x_i)
+ \\nabla v(x_i))(w \\cdot v) \\\\
+ (\\sum_i w_i (r_i \\times \\nabla v(x_i)))|v|^2 - (\\sum_i v(x_i)
+ (r_i \\times \\nabla v(x_i)))(w \\cdot v)
+ \\end{bmatrix}
+
+ where :math:`N` is the number of points, :math:`w` are weights,
+ :math:`v` are target values, :math:`x_i` are rotated template positions,
+ and :math:`r_i` are positions relative to the template center.
+
+ Returns
+ -------
+ np.ndarray
+ Negative normalized gradient: [dx, dy, dz, dRx, dRy, dRz].
+
+ """
+ grad = self._interpolate_gradient(positions=self.template_rotated)
+ torque = self._torques(
+ positions=self.template_rotated, gradients=grad, center=self.template_center
+ )
+
+ norm = be.multiply(
+ be.power(be.sqrt(be.sum(be.square(self._target_values))), 3),
+ be.sqrt(be.sum(be.square(self.template_weights))),
+ )
+
+ # (<u,dv/dx> * |v|**2 - <u,v> * <v,dv/dx>)/(|w|*|v|**3)
+ translation_grad = be.multiply(
+ be.sum(be.multiply(grad, self.template_weights), axis=1),
+ be.sum(be.square(self._target_values)),
+ )
+ translation_grad -= be.multiply(
+ be.sum(be.multiply(grad, self._target_values), axis=1),
+ be.sum(be.multiply(self._target_values, self.template_weights)),
+ )
+
+ # (<u,r x dv/dx> * |v|**2 - <u,v> * <v,r x dv/dx>)/(|w|*|v|**3)
+ torque_grad = be.multiply(
+ be.sum(be.multiply(torque, self.template_weights), axis=1),
+ be.sum(be.square(self._target_values)),
+ )
+ torque_grad -= be.multiply(
+ be.sum(be.multiply(torque, self._target_values), axis=1),
+ be.sum(be.multiply(self._target_values, self.template_weights)),
+ )
+
+ total_grad = be.concatenate([translation_grad, torque_grad])
+ if norm > 0:
+ total_grad = be.divide(total_grad, norm, out=total_grad)
+
+ total_grad = be.divide(total_grad, self.n_points, out=total_grad)
+ return -total_grad
+

  class NormalizedCrossCorrelationMean(NormalizedCrossCorrelation):
  """
@@ -802,33 +814,33 @@ class MaskedCrossCorrelation(_MatchCoordinatesToDensity):

  def __call__(self) -> float:
  """Returns the score of the current configuration."""
+
+ in_volume = np.logical_and(
+ self.template_rotated < np.array(self.target.shape)[:, None],
+ self.template_rotated >= 0,
+ ).min(axis=0)
+ in_volume_mask = np.logical_and(
+ self.template_mask_rotated < np.array(self.target.shape)[:, None],
+ self.template_mask_rotated >= 0,
+ ).min(axis=0)
+
  mask_overlap = np.sum(
- self.target_mask_density[
- tuple(
- self.template_mask_coordinates_rotated[
- :, self.in_volume_mask
- ].astype(int)
- )
+ self.target_mask[
+ tuple(self.template_mask_rotated[:, in_volume_mask].astype(int))
  ],
  )
  mask_overlap = np.fmax(mask_overlap, np.finfo(float).eps)

- mask_target = self.target_density[
- tuple(
- self.template_mask_coordinates_rotated[:, self.in_volume_mask].astype(
- int
- )
- )
+ mask_target = self.target[
+ tuple(self.template_mask_rotated[:, in_volume_mask].astype(int))
  ]
  denominator1 = np.subtract(
  np.sum(mask_target**2),
  np.divide(np.square(np.sum(mask_target)), mask_overlap),
  )
  mask_template = np.multiply(
- self.template_weights[self.in_volume],
- self.target_mask_density[
- tuple(self.template_coordinates_rotated[:, self.in_volume].astype(int))
- ],
+ self.template_weights[in_volume],
+ self.target_mask[tuple(self.template_rotated[:, in_volume].astype(int))],
  )
  denominator2 = np.subtract(
  np.sum(mask_template**2),
@@ -840,10 +852,8 @@ class MaskedCrossCorrelation(_MatchCoordinatesToDensity):
  denominator = np.sqrt(np.multiply(denominator1, denominator2))

  numerator = np.dot(
- self.target_density[
- tuple(self.template_coordinates_rotated[:, self.in_volume].astype(int))
- ],
- self.template_weights[self.in_volume],
+ self.target[tuple(self.template_rotated[:, in_volume].astype(int))],
+ self.template_weights[in_volume],
  )

  numerator -= np.divide(
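The rewritten MaskedCrossCorrelation now derives the in-volume masks inline: a point contributes only if every coordinate lies in [0, shape) along its axis. A tiny sketch of that mask for (d, N) coordinates (values are illustrative):

```python
# Sketch of the in-volume mask used above.
import numpy as np

shape = np.array([4, 4])
coords = np.array([[1.5, -0.2, 3.9],
                   [2.0, 1.0, 4.1]])  # (d, N)
in_volume = np.logical_and(coords < shape[:, None], coords >= 0).min(axis=0)
print(in_volume)  # [ True False False]
```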
@@ -877,21 +887,9 @@ class PartialLeastSquareDifference(_MatchCoordinatesToDensity):

  def __call__(self) -> float:
  """Returns the score of the current configuration."""
- score = np.sum(
- np.square(
- np.subtract(
- self.target_density[
- tuple(
- self.template_coordinates_rotated[:, self.in_volume].astype(
- int
- )
- )
- ],
- self.template_weights[self.in_volume],
- )
- )
+ score = be.sum(
+ be.square(be.subtract(self._target_values, self.template_weights))
  )
- score += np.sum(np.square(self.template_weights[np.invert(self.in_volume)]))
  return score * self.score_sign


@@ -917,10 +915,7 @@ class MutualInformation(_MatchCoordinatesToDensity):
  def __call__(self) -> float:
  """Returns the score of the current configuration."""
  p_xy, target, template = np.histogram2d(
- self.target_density[
- tuple(self.template_coordinates_rotated[:, self.in_volume].astype(int))
- ],
- self.template_weights[self.in_volume],
+ self._target_values, self.template_weights
  )
  p_x, p_y = np.sum(p_xy, axis=1), np.sum(p_xy, axis=0)

@@ -947,7 +942,7 @@ class Envelope(_MatchCoordinatesToDensity):
  References
  ----------
  .. [1] Daven Vasishtan and Maya Topf, "Scoring functions for cryoEM density
- fitting", Journal of Structural Biology, vol. 174, no. 2,
+ fitting", Journal of Structural Biology, vol. 1174, no. 2,
  pp. 333--343, 2011. DOI: https://doi.org/10.1016/j.jsb.2011.01.012
  """

@@ -956,22 +951,22 @@
  def __init__(self, target_threshold: float = None, **kwargs):
  super().__init__(**kwargs)
  if target_threshold is None:
- target_threshold = np.mean(self.target_density)
- self.target_density = np.where(self.target_density > target_threshold, -1, 1)
- self.target_density_present = np.sum(self.target_density == -1)
- self.target_density_absent = np.sum(self.target_density == 1)
+ target_threshold = np.mean(self.target)
+ self.target = np.where(self.target > target_threshold, -1, 1)
+ self.target_present = np.sum(self.target == -1)
+ self.target_absent = np.sum(self.target == 1)
  self.template_weights = np.ones_like(self.template_weights)

  def __call__(self) -> float:
  """Returns the score of the current configuration."""
- score = self.target_density[
- tuple(self.template_coordinates_rotated[:, self.in_volume].astype(int))
- ]
- unassigned_density = self.target_density_present - (score == -1).sum()
+ score = self._target_values
+ unassigned_density = self.target_present - (score == -1).sum()

- score = score.sum() - unassigned_density - 2 * np.sum(np.invert(self.in_volume))
- min_score = -self.target_density_present - 2 * self.target_density_absent
- score = (score - 2 * min_score) / (2 * self.target_density_present - min_score)
+ # Out of volume values will be set to 0
+ score = score.sum() - unassigned_density
+ score -= 2 * np.sum(np.invert(np.abs(self._target_values) > 0))
+ min_score = -self.target_present - 2 * self.target_absent
+ score = (score - 2 * min_score) / (2 * self.target_present - min_score)

  return score * self.score_sign

@@ -1271,7 +1266,15 @@ def optimize_match(

  x0 = np.zeros(2 * ndim) if x0 is None else x0

+ return_gradient = getattr(score_object, "return_gradient", False)
+ if optimization_method != "minimize" and return_gradient:
+ warnings.warn("Gradient only considered for optimization_method='minimize'.")
+ score_object.return_gradient = False
+
  initial_score = score_object.score(x=x0)
+ if isinstance(initial_score, (List, Tuple)):
+ initial_score = initial_score[0]
+
  if optimization_method == "basinhopping":
  result = basinhopping(
  x0=x0,
@@ -1287,10 +1290,14 @@
  maxiter=maxiter,
  )
  elif optimization_method == "minimize":
- print(maxiter)
+ if hasattr(score_object, "grad") and not return_gradient:
+ warnings.warn(
+ "Consider initializing score object with return_gradient=True."
+ )
  result = minimize(
  x0=x0,
  fun=score_object.score,
+ jac=return_gradient,
  bounds=bounds,
  constraints=linear_constraint,
  options={"maxiter": maxiter},
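The `jac=return_gradient` argument relies on SciPy's convention that, when `jac=True`, the objective returns a (value, gradient) pair. A minimal, self-contained illustration with a stand-in quadratic objective (plain SciPy, not pytme code):

```python
# SciPy convention relied on above: with jac=True, fun(x) returns (value, gradient).
import numpy as np
from scipy.optimize import minimize

def fun(x):
    value = np.sum((x - 1.0) ** 2)
    grad = 2.0 * (x - 1.0)
    return value, grad

result = minimize(fun=fun, x0=np.zeros(6), jac=True, options={"maxiter": 100})
print(result.x)  # approaches a vector of ones
```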
tme/matching_utils.py CHANGED
@@ -645,6 +645,7 @@ def get_rotation_matrices(
  dets = np.linalg.det(ret)
  neg_dets = dets < 0
  ret[neg_dets, :, -1] *= -1
+ ret[0] = np.eye(dim, dtype=ret.dtype)
  return ret

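The added line pins the first rotation matrix to the identity so the unrotated orientation is always sampled, complementing the existing sign flip that turns improper rotations (determinant -1) into proper ones. A toy illustration of both steps with made-up matrices:

```python
# Toy illustration: flipping the last column of improper matrices restores det = +1,
# and the first entry is then overwritten with the identity.
import numpy as np

ret = np.stack([np.diag([1.0, 1.0, -1.0]), np.eye(3)])
neg_dets = np.linalg.det(ret) < 0
ret[neg_dets, :, -1] *= -1
ret[0] = np.eye(3, dtype=ret.dtype)
print(np.linalg.det(ret))  # [1. 1.]
```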
tme/preprocessing/_utils.py CHANGED
@@ -19,8 +19,8 @@ def compute_tilt_shape(shape: Tuple[int], opening_axis: int, reduce_dim: bool =
  """
  Given an opening_axis, computes the shape of the remaining dimensions.

- Parameters:
- -----------
+ Parameters
+ ----------
  shape : Tuple[int]
  The shape of the input array.
  opening_axis : int
@@ -28,8 +28,8 @@ def compute_tilt_shape(shape: Tuple[int], opening_axis: int, reduce_dim: bool =
  reduce_dim : bool, optional (default=False)
  Whether to reduce the dimensionality after tilting.

- Returns:
- --------
+ Returns
+ -------
  Tuple[int]
  The shape of the array after tilting.
  """
@@ -44,13 +44,13 @@ def centered_grid(shape: Tuple[int]) -> NDArray:
  """
  Generate an integer valued grid centered around size // 2

- Parameters:
- -----------
+ Parameters
+ ----------
  shape : Tuple[int]
  The shape of the grid.

- Returns:
- --------
+ Returns
+ -------
  NDArray
  The centered grid.
  """
@@ -70,8 +70,8 @@ def frequency_grid_at_angle(
  """
  Generate a frequency grid from 0 to 1/(2 * sampling_rate) in each axis.

- Parameters:
- -----------
+ Parameters
+ ----------
  shape : Tuple[int]
  The shape of the grid.
  angle : float
@@ -128,8 +128,8 @@ def fftfreqn(
  """
  Generate the n-dimensional discrete Fourier transform sample frequencies.

- Parameters:
- -----------
+ Parameters
+ ----------
  shape : Tuple[int]
  The shape of the data.
  sampling_rate : float or Tuple[float]
@@ -180,8 +180,8 @@ def crop_real_fourier(data: BackendArray) -> BackendArray:
  """
  Crop the real part of a Fourier transform.

- Parameters:
- -----------
+ Parameters
+ ----------
  data : BackendArray
  The Fourier transformed data.