pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. pytme-0.2.2.data/scripts/match_template.py +1187 -0
  2. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/postprocess.py +170 -71
  3. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +179 -86
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +126 -87
  8. scripts/match_template.py +596 -209
  9. scripts/match_template_filters.py +571 -223
  10. scripts/postprocess.py +170 -71
  11. scripts/preprocessor_gui.py +179 -86
  12. scripts/refine_matches.py +567 -159
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +627 -855
  16. tme/backends/__init__.py +41 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +120 -225
  19. tme/backends/jax_backend.py +282 -0
  20. tme/backends/matching_backend.py +464 -388
  21. tme/backends/mlx_backend.py +45 -68
  22. tme/backends/npfftw_backend.py +256 -514
  23. tme/backends/pytorch_backend.py +41 -154
  24. tme/density.py +312 -421
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +366 -303
  27. tme/matching_exhaustive.py +279 -1521
  28. tme/matching_optimization.py +234 -129
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +281 -387
  31. tme/memory.py +377 -0
  32. tme/orientations.py +226 -66
  33. tme/parser.py +3 -4
  34. tme/preprocessing/__init__.py +2 -0
  35. tme/preprocessing/_utils.py +217 -0
  36. tme/preprocessing/composable_filter.py +31 -0
  37. tme/preprocessing/compose.py +55 -0
  38. tme/preprocessing/frequency_filters.py +388 -0
  39. tme/preprocessing/tilt_series.py +1011 -0
  40. tme/preprocessor.py +574 -530
  41. tme/structure.py +495 -189
  42. tme/types.py +5 -3
  43. pytme-0.2.0b0.data/scripts/match_template.py +0 -800
  44. pytme-0.2.0b0.dist-info/METADATA +0 -73
  45. pytme-0.2.0b0.dist-info/RECORD +0 -66
  46. tme/helpers.py +0 -881
  47. tme/matching_constrained.py +0 -195
  48. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  49. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  50. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
@@ -10,21 +10,23 @@ from typing import Tuple, Dict
10
10
  from abc import ABC, abstractmethod
11
11
 
12
12
  import numpy as np
13
- from numpy.typing import NDArray
13
+ from scipy.spatial import KDTree
14
+ from scipy.ndimage import laplace, map_coordinates
14
15
  from scipy.optimize import (
15
- differential_evolution,
16
- LinearConstraint,
17
- basinhopping,
18
16
  minimize,
17
+ basinhopping,
18
+ LinearConstraint,
19
+ differential_evolution,
19
20
  )
20
- from scipy.ndimage import laplace, map_coordinates
21
- from scipy.spatial import KDTree
22
21
 
23
- from .types import ArrayLike
24
- from .backends import backend
22
+ from .backends import backend as be
23
+ from .types import ArrayLike, NDArray
25
24
  from .matching_data import MatchingData
26
- from .matching_utils import rigid_transform, euler_to_rotationmatrix
27
- from .matching_exhaustive import normalize_under_mask
25
+ from .matching_utils import (
26
+ rigid_transform,
27
+ euler_to_rotationmatrix,
28
+ normalize_template,
29
+ )
28
30
 
29
31
 
30
32
  def _format_rigid_transform(x: Tuple[float]) -> Tuple[ArrayLike, ArrayLike]:
@@ -45,9 +47,9 @@ def _format_rigid_transform(x: Tuple[float]) -> Tuple[ArrayLike, ArrayLike]:
45
47
  split = len(x) // 2
46
48
  translation, angles = x[:split], x[split:]
47
49
 
48
- translation = backend.to_backend_array(translation)
49
- rotation_matrix = euler_to_rotationmatrix(backend.to_numpy_array(angles))
50
- rotation_matrix = backend.to_backend_array(rotation_matrix)
50
+ translation = be.to_backend_array(translation)
51
+ rotation_matrix = euler_to_rotationmatrix(be.to_numpy_array(angles))
52
+ rotation_matrix = be.to_backend_array(rotation_matrix)
51
53
 
52
54
  return translation, rotation_matrix
53
55
 
@@ -91,6 +93,7 @@ class _MatchDensityToDensity(ABC):
91
93
  negate_score: bool = True,
92
94
  **kwargs: Dict,
93
95
  ):
96
+ self.eps = be.eps(target.dtype)
94
97
  self.rotate_mask = rotate_mask
95
98
  self.interpolation_order = interpolation_order
96
99
 
@@ -100,36 +103,29 @@ class _MatchDensityToDensity(ABC):
100
103
  if target_mask is not None:
101
104
  matching_data.target_mask = target_mask
102
105
 
103
- target_pad = matching_data.target_padding(pad_target=pad_target_edges)
104
- matching_data = matching_data.subset_by_slice(target_pad=target_pad)
106
+ self.target, self.target_mask = matching_data.target, matching_data.target_mask
105
107
 
106
- fast_shape, fast_ft_shape, fourier_shift = matching_data.fourier_padding(
107
- pad_fourier=pad_fourier
108
- )
109
-
110
- self.target = backend.topleft_pad(matching_data.target, fast_shape)
111
- self.target_mask = matching_data.target_mask
112
-
113
- self.template = matching_data.template
114
- self.template_rot = backend.preallocate_array(
115
- fast_shape, backend._default_dtype
116
- )
108
+ self.template = matching_data._template
109
+ self.template_rot = be.zeros(template.shape, be._float_dtype)
117
110
 
118
111
  self.template_mask, self.template_mask_rot = 1, 1
119
- rotate_mask = False if matching_data.template_mask is None else rotate_mask
112
+ rotate_mask = False if matching_data._template_mask is None else rotate_mask
120
113
  if matching_data.template_mask is not None:
121
- self.template_mask = matching_data.template_mask
122
- self.template_mask_rot = backend.topleft_pad(
123
- matching_data.template_mask, fast_shape
114
+ self.template_mask = matching_data._template_mask
115
+ self.template_mask_rot = be.topleft_pad(
116
+ matching_data._template_mask, self.template_mask.shape
124
117
  )
125
118
 
119
+ self.template_slices = tuple(slice(None) for _ in self.template.shape)
120
+ self.target_slices = tuple(slice(0, x) for x in self.template.shape)
121
+
126
122
  self.score_sign = -1 if negate_score else 1
127
123
 
128
124
  if hasattr(self, "_post_init"):
129
125
  self._post_init(**kwargs)
130
126
 
131
- @staticmethod
132
- def rigid_transform(
127
+ def rotate_array(
128
+ self,
133
129
  arr,
134
130
  rotation_matrix,
135
131
  translation,
@@ -137,28 +133,39 @@ class _MatchDensityToDensity(ABC):
137
133
  out=None,
138
134
  out_mask=None,
139
135
  order: int = 1,
140
- use_geometric_center: bool = False,
136
+ **kwargs,
141
137
  ):
142
138
  rotate_mask = arr_mask is not None
143
139
  return_type = (out is None) + 2 * rotate_mask * (out_mask is None)
144
140
  translation = np.zeros(arr.ndim) if translation is None else translation
145
141
 
146
142
  center = np.floor(np.array(arr.shape) / 2)[:, None]
147
- grid = np.indices(arr.shape, dtype=np.float32).reshape(arr.ndim, -1)
148
- np.subtract(grid, center, out=grid)
149
- np.matmul(rotation_matrix.T, grid, out=grid)
150
- np.add(grid, center, out=grid)
143
+
144
+ if not hasattr(self, "_previous_center"):
145
+ self._previous_center = arr.shape
146
+
147
+ if not hasattr(self, "grid") or not np.allclose(self._previous_center, center):
148
+ self.grid = np.indices(arr.shape, dtype=np.float32).reshape(arr.ndim, -1)
149
+ np.subtract(self.grid, center, out=self.grid)
150
+ self.grid_out = np.zeros_like(self.grid)
151
+ self._previous_center = center
152
+
153
+ np.matmul(rotation_matrix.T, self.grid, out=self.grid_out)
154
+ translation = np.add(translation[:, None], center)
155
+ np.add(self.grid_out, translation, out=self.grid_out)
151
156
 
152
157
  if out is None:
153
158
  out = np.zeros_like(arr)
154
159
 
155
- map_coordinates(arr, grid, order=order, output=out.ravel())
160
+ map_coordinates(arr, self.grid_out, order=order, output=out.ravel())
156
161
 
157
162
  if out_mask is None and arr_mask is not None:
158
163
  out_mask = np.zeros_like(arr_mask)
159
164
 
160
165
  if arr_mask is not None:
161
- map_coordinates(arr_mask, grid, order=order, output=out_mask.ravel())
166
+ map_coordinates(
167
+ arr_mask, self.grid_out, order=order, output=out_mask.ravel()
168
+ )
162
169
 
163
170
  match return_type:
164
171
  case 0:
@@ -218,19 +225,48 @@ class _MatchDensityToDensity(ABC):
218
225
  The matching score obtained for the transformation.
219
226
  """
220
227
  translation, rotation_matrix = _format_rigid_transform(x)
228
+ self.template_rot.fill(0)
229
+
230
+ voxel_translation = be.astype(translation, be._int_dtype)
231
+ subvoxel_translation = be.subtract(translation, voxel_translation)
232
+
233
+ center = be.astype(be.divide(self.template.shape, 2), be._int_dtype)
234
+ right_pad = be.subtract(self.template.shape, center)
235
+
236
+ translated_center = be.add(voxel_translation, center)
237
+
238
+ target_starts = be.subtract(translated_center, center)
239
+ target_stops = be.add(translated_center, right_pad)
240
+
241
+ template_starts = be.subtract(be.maximum(target_starts, 0), target_starts)
242
+ template_stops = be.subtract(
243
+ target_stops, be.minimum(target_stops, self.target.shape)
244
+ )
245
+ template_stops = be.subtract(self.template.shape, template_stops)
246
+
247
+ target_starts = be.maximum(target_starts, 0)
248
+ target_stops = be.minimum(target_stops, self.target.shape)
249
+
250
+ cand_start, cand_stop = template_starts.astype(int), template_stops.astype(int)
251
+ obs_start, obs_stop = target_starts.astype(int), target_stops.astype(int)
252
+
253
+ self.template_slices = tuple(slice(s, e) for s, e in zip(cand_start, cand_stop))
254
+ self.target_slices = tuple(slice(s, e) for s, e in zip(obs_start, obs_stop))
255
+
221
256
  kw_dict = {
222
257
  "arr": self.template,
223
258
  "rotation_matrix": rotation_matrix,
224
- "translation": translation,
259
+ "translation": subvoxel_translation,
225
260
  "out": self.template_rot,
226
- "use_geometric_center": False,
227
261
  "order": self.interpolation_order,
262
+ "use_geometric_center": True,
228
263
  }
229
264
  if self.rotate_mask:
265
+ self.template_mask_rot.fill(0)
230
266
  kw_dict["arr_mask"] = self.template_mask
231
267
  kw_dict["out_mask"] = self.template_mask_rot
232
268
 
233
- self.rigid_transform(**kw_dict)
269
+ self.rotate_array(**kw_dict)
234
270
 
235
271
  return self()
236
272
 
@@ -246,11 +282,11 @@ class _MatchCoordinatesToDensity(_MatchDensityToDensity):
246
282
  target : NDArray
247
283
  A d-dimensional target to match the template coordinate set to.
248
284
  template_coordinates : NDArray
249
- Template coordinate array with shape [d x N].
285
+ Template coordinate array with shape (d,n).
250
286
  template_weights : NDArray
251
- Template weight array with shape [N].
287
+ Template weight array with shape (n,).
252
288
  template_mask_coordinates : NDArray, optional
253
- Template mask coordinates with shape [d x N].
289
+ Template mask coordinates with shape (d,n).
254
290
  target_mask : NDArray, optional
255
291
  A d-dimensional mask to be applied to the target.
256
292
  negate_score : bool, optional
@@ -269,6 +305,7 @@ class _MatchCoordinatesToDensity(_MatchDensityToDensity):
269
305
  negate_score: bool = True,
270
306
  **kwargs: Dict,
271
307
  ):
308
+ self.eps = be.eps(target.dtype)
272
309
  self.target_density = target
273
310
  self.target_mask_density = target_mask
274
311
 
@@ -277,6 +314,8 @@ class _MatchCoordinatesToDensity(_MatchDensityToDensity):
277
314
  self.template_coordinates_rotated = np.copy(self.template_coordinates).astype(
278
315
  np.float32
279
316
  )
317
+ if template_mask_coordinates is None:
318
+ template_mask_coordinates = template_coordinates.copy()
280
319
 
281
320
  self.template_mask_coordinates = template_mask_coordinates
282
321
  self.template_mask_coordinates_rotated = template_mask_coordinates
@@ -291,9 +330,9 @@ class _MatchCoordinatesToDensity(_MatchDensityToDensity):
291
330
  self.in_volume, self.in_volume_mask = self.map_coordinates_to_array(
292
331
  coordinates=self.template_coordinates_rotated,
293
332
  coordinates_mask=self.template_mask_coordinates_rotated,
294
- array_origin=backend.zeros(target.ndim),
333
+ array_origin=be.zeros(target.ndim),
295
334
  array_shape=self.target_density.shape,
296
- sampling_rate=backend.full(target.ndim, fill_value=1),
335
+ sampling_rate=be.full(target.ndim, fill_value=1),
297
336
  )
298
337
 
299
338
  if hasattr(self, "_post_init"):
@@ -329,9 +368,9 @@ class _MatchCoordinatesToDensity(_MatchDensityToDensity):
329
368
  self.in_volume, self.in_volume_mask = self.map_coordinates_to_array(
330
369
  coordinates=self.template_coordinates_rotated,
331
370
  coordinates_mask=self.template_mask_coordinates_rotated,
332
- array_origin=backend.zeros(rotation_matrix.shape[0]),
371
+ array_origin=be.zeros(rotation_matrix.shape[0]),
333
372
  array_shape=self.target_density.shape,
334
- sampling_rate=backend.full(rotation_matrix.shape[0], fill_value=1),
373
+ sampling_rate=be.full(rotation_matrix.shape[0], fill_value=1),
335
374
  )
336
375
 
337
376
  return self()
@@ -521,55 +560,65 @@ class _MatchCoordinatesToCoordinates(_MatchDensityToDensity):
521
560
 
522
561
 
523
562
  class FLC(_MatchDensityToDensity):
563
+ """
564
+ Computes a normalized cross-correlation score of a target f, a template g
565
+ and a mask m:
566
+
567
+ .. math::
568
+
569
+ \\frac{CC(f, \\frac{g*m - \\overline{g*m}}{\\sigma_{g*m}})}
570
+ {N_m * \\sqrt{
571
+ \\frac{CC(f^2, m)}{N_m} - (\\frac{CC(f, m)}{N_m})^2}
572
+ }
573
+
574
+ Where:
575
+
576
+ .. math::
577
+
578
+ CC(f,g) = \\mathcal{F}^{-1}(\\mathcal{F}(f) \\cdot \\mathcal{F}(g)^*)
579
+
580
+ and :math:`N_m` is the number of voxels within the template mask m.
581
+ """
582
+
524
583
  __doc__ += _MatchDensityToDensity.__doc__
525
584
 
526
585
  def _post_init(self, **kwargs: Dict):
527
586
  if self.target_mask is not None:
528
- backend.multiply(self.target, self.target_mask, out=self.target)
587
+ be.multiply(self.target, self.target_mask, out=self.target)
529
588
 
530
- self.target_square = backend.square(self.target)
589
+ self.target_square = be.square(self.target)
531
590
 
532
- normalize_under_mask(
591
+ normalize_template(
533
592
  template=self.template,
534
593
  mask=self.template_mask,
535
- mask_intensity=backend.sum(self.template_mask),
594
+ n_observations=be.sum(self.template_mask),
536
595
  )
537
596
 
538
- self.template = backend.reverse(self.template)
539
- self.template_mask = backend.reverse(self.template_mask)
540
-
541
597
  def __call__(self) -> float:
542
598
  """Returns the score of the current configuration."""
543
- n_observations = backend.sum(self.template_mask_rot)
599
+ n_obs = be.sum(self.template_mask_rot)
544
600
 
545
- normalize_under_mask(
601
+ normalize_template(
546
602
  template=self.template_rot,
547
603
  mask=self.template_mask_rot,
548
- mask_intensity=n_observations,
604
+ n_observations=n_obs,
549
605
  )
550
-
551
- ex2 = backend.sum(
552
- backend.divide(
553
- backend.sum(
554
- backend.multiply(self.target_square, self.template_mask_rot),
555
- ),
556
- n_observations,
557
- )
558
- )
559
- e2x = backend.square(
560
- backend.divide(
561
- backend.sum(backend.multiply(self.target, self.template_mask_rot)),
562
- n_observations,
606
+ overlap = be.sum(
607
+ be.multiply(
608
+ self.template_rot[self.template_slices], self.target[self.target_slices]
563
609
  )
564
610
  )
565
611
 
566
- denominator = backend.maximum(backend.subtract(ex2, e2x), 0.0)
567
- denominator = backend.sqrt(denominator)
568
- denominator = backend.multiply(denominator, n_observations)
612
+ mask_rot = self.template_mask_rot[self.template_slices]
613
+ exp_sq = be.sum(self.target_square[self.target_slices] * mask_rot) / n_obs
614
+ sq_exp = be.square(be.sum(self.target[self.target_slices] * mask_rot) / n_obs)
569
615
 
570
- overlap = backend.sum(backend.multiply(self.template_rot, self.target))
616
+ denominator = be.maximum(be.subtract(exp_sq, sq_exp), 0.0)
617
+ denominator = be.sqrt(denominator)
618
+ if denominator < self.eps:
619
+ return 0
571
620
 
572
- score = backend.divide(overlap, denominator) * self.score_sign
621
+ score = be.divide(overlap, denominator * n_obs) * self.score_sign
573
622
  return score
574
623
 
575
624
 
@@ -586,29 +635,12 @@ class CrossCorrelation(_MatchCoordinatesToDensity):
586
635
 
587
636
  def __call__(self) -> float:
588
637
  """Returns the score of the current configuration."""
589
- try:
590
- score = np.dot(
591
- self.target_density[
592
- tuple(
593
- self.template_coordinates_rotated[:, self.in_volume].astype(int)
594
- )
595
- ],
596
- self.template_weights[self.in_volume],
597
- )
598
- except:
599
- print(self.template_coordinates_rotated[:, self.in_volume].astype(int))
600
- print(self.target_density.shape)
601
- print(self.in_volume)
602
- coordinates = self.template_coordinates_rotated[:, self.in_volume].astype(
603
- int
604
- )
605
- in_volume = np.logical_and(
606
- coordinates < np.array(self.target_density.shape)[:, None],
607
- coordinates >= 0,
608
- ).min(axis=0)
609
- print(in_volume)
610
-
611
- raise ValueError()
638
+ score = np.dot(
639
+ self.target_density[
640
+ tuple(self.template_coordinates_rotated[:, self.in_volume].astype(int))
641
+ ],
642
+ self.template_weights[self.in_volume],
643
+ )
612
644
  score /= self.denominator
613
645
  return score * self.score_sign
614
646
 
@@ -663,10 +695,24 @@ class NormalizedCrossCorrelation(CrossCorrelation):
663
695
 
664
696
  __doc__ += _MatchCoordinatesToDensity.__doc__
665
697
 
666
- def _post_init(self, **kwargs):
667
- target_norm = np.linalg.norm(self.target_density[self.target_density != 0])
668
- template_norm = np.linalg.norm(self.template_weights)
669
- self.denominator = np.fmax(target_norm * template_norm, np.finfo(float).eps)
698
+ def __call__(self) -> float:
699
+ n_observations = be.sum(self.in_volume_mask)
700
+ target_coordinates = be.astype(
701
+ self.template_mask_coordinates_rotated[:, self.in_volume_mask], int
702
+ )
703
+ target_weight = self.target_density[tuple(target_coordinates)]
704
+ ex2 = be.divide(be.sum(be.square(target_weight)), n_observations)
705
+ e2x = be.square(be.divide(be.sum(target_weight), n_observations))
706
+
707
+ denominator = be.maximum(be.subtract(ex2, e2x), 0.0)
708
+ denominator = be.sqrt(denominator)
709
+ denominator = be.multiply(denominator, n_observations)
710
+
711
+ if denominator <= self.eps:
712
+ return 0.0
713
+
714
+ self.denominator = denominator
715
+ return super().__call__()
670
716
 
671
717
 
672
718
  class NormalizedCrossCorrelationMean(NormalizedCrossCorrelation):
@@ -1037,7 +1083,7 @@ def register_matching_optimization(match_name: str, match_class: type):
1037
1083
 
1038
1084
  def create_score_object(score: str, **kwargs) -> object:
1039
1085
  """
1040
- Initialize score object with name ``score`` using `**kwargs``.
1086
+ Initialize score object with name ``score`` using ``**kwargs``.
1041
1087
 
1042
1088
  Parameters
1043
1089
  ----------
@@ -1059,6 +1105,32 @@ def create_score_object(score: str, **kwargs) -> object:
1059
1105
  See Also
1060
1106
  --------
1061
1107
  :py:meth:`register_matching_optimization`
1108
+
1109
+ Examples
1110
+ --------
1111
+ >>> from tme import Density
1112
+ >>> from tme.matching_utils import create_mask, euler_to_rotationmatrix
1113
+ >>> from tme.matching_optimization import CrossCorrelation, optimize_match
1114
+ >>> translation, rotation = (5, -2, 7), (5, -10, 2)
1115
+ >>> target = create_mask(
1116
+ ... mask_type="ellipse",
1117
+ ... radius=(5,5,5),
1118
+ ... shape=(51,51,51),
1119
+ ... center=(25,25,25),
1120
+ ... ).astype(float)
1121
+ >>> template = Density(data=target)
1122
+ >>> template = template.rigid_transform(
1123
+ ... translation=translation,
1124
+ ... rotation_matrix=euler_to_rotationmatrix(rotation),
1125
+ ... )
1126
+ >>> template_coordinates = template.to_pointcloud(0)
1127
+ >>> template_weights = template.data[tuple(template_coordinates)]
1128
+ >>> score_object = CrossCorrelation(
1129
+ ... target=target,
1130
+ ... template_coordinates=template_coordinates,
1131
+ ... template_weights=template_weights,
1132
+ ... negate_score=True # Multiply returned score with -1 for minimization
1133
+ ... )
1062
1134
  """
1063
1135
 
1064
1136
  score_object = MATCHING_OPTIMIZATION_REGISTER.get(score, None)
@@ -1078,10 +1150,11 @@ def optimize_match(
1078
1150
  bounds_translation: Tuple[Tuple[float]] = None,
1079
1151
  bounds_rotation: Tuple[Tuple[float]] = None,
1080
1152
  optimization_method: str = "basinhopping",
1081
- maxiter: int = 500,
1153
+ maxiter: int = 50,
1154
+ x0: Tuple[float] = None,
1082
1155
  ) -> Tuple[ArrayLike, ArrayLike, float]:
1083
1156
  """
1084
- Find the translation and rotation optimizing the score returned by `score_object`
1157
+ Find the translation and rotation optimizing the score returned by ``score_object``
1085
1158
  with respect to provided bounds.
1086
1159
 
1087
1160
  Parameters
@@ -1098,36 +1171,68 @@ def optimize_match(
1098
1171
  Bounds on the evaluated zyx Euler angles. Has to be specified per dimension
1099
1172
  as tuple of (min, max). Default is None.
1100
1173
  optimization_method : str, optional
1101
- Optimizer that will be used, by default basinhopping. For further
1174
+ Optimizer that will be used, basinhopping by default. For further
1102
1175
  information refer to :doc:`scipy:reference/optimize`.
1103
1176
 
1104
- +--------------------------+-----------------------------------------+
1105
- | 'differential_evolution' | Highest accuracy but long runtime. |
1106
- | | Requires bounds on translation. |
1107
- +--------------------------+-----------------------------------------+
1108
- | 'basinhopping' | Decent accuracy, medium runtime. |
1109
- +--------------------------+-----------------------------------------+
1110
- | 'minimize' | If initial values are closed to optimum |
1111
- | | decent performance, short runtime. |
1112
- +--------------------------+-----------------------------------------+
1177
+ +------------------------+-------------------------------------------+
1178
+ | differential_evolution | Highest accuracy but long runtime. |
1179
+ | | Requires bounds on translation. |
1180
+ +------------------------+-------------------------------------------+
1181
+ | basinhopping | Decent accuracy, medium runtime. |
1182
+ +------------------------+-------------------------------------------+
1183
+ | minimize | If initial values are close to optimum |
1184
+ | | acceptable accuracy and short runtime |
1185
+ +------------------------+-------------------------------------------+
1186
+
1113
1187
  maxiter : int, optional
1114
- The maximum number of iterations. Default is 500. Not considered for
1115
- `optimization_method` 'minimize'.
1188
+ The maximum number of iterations, 50 by default.
1189
+ x0 : tuple of floats, optional
1190
+ Initial values for the optimizer, zero by default.
1116
1191
 
1117
1192
  Returns
1118
1193
  -------
1119
1194
  Tuple[ArrayLike, ArrayLike, float]
1120
- Translation and rotation matrix yielding final score.
1195
+ Optimal translation, rotation matrix and corresponding score.
1121
1196
 
1122
1197
  Raises
1123
1198
  ------
1124
1199
  ValueError
1125
- If `optimization_method` is not supported.
1200
+ If ``optimization_method`` is not supported.
1126
1201
 
1127
1202
  Notes
1128
1203
  -----
1129
1204
  This function currently only supports three-dimensional optimization and
1130
- `score_object` will be modified during this operation.
1205
+ ``score_object`` will be modified during this operation.
1206
+
1207
+ Examples
1208
+ --------
1209
+ Having defined ``score_object``, for instance via :py:meth:`create_score_object`,
1210
+ non-exhaustive template matching can be performed as follows
1211
+
1212
+ >>> translation_fit, rotation_fit, score = optimize_match(score_object)
1213
+
1214
+ ``translation_fit`` and ``rotation_fit`` correspond to the inverse of the applied
1215
+ translation and rotation, so the following statements should hold within tolerance
1216
+
1217
+ >>> np.allclose(translation, -translation_fit, atol = 1) # True
1218
+ >>> np.allclose(rotation, np.linalg.inv(rotation_fit), rtol = .1) # True
1219
+
1220
+ Bounds on translation and rotation can be defined as follows
1221
+
1222
+ >>> translation_fit, rotation_fit, score = optimize_match(
1223
+ ... score_object=score_object,
1224
+ ... bounds_translation=((-5,5),(-2,2),(0,0)),
1225
+ ... bounds_rotation=((-10,10), (-5,5), (0,0)),
1226
+ ... )
1227
+
1228
+ The optimization scheme and the initial parameter estimates can also be adapted
1229
+
1230
+ >>> translation_fit, rotation_fit, score = optimize_match(
1231
+ ... score_object=score_object,
1232
+ ... optimization_method="minimize",
1233
+ ... x0=(0,0,0,5,3,-5),
1234
+ ... )
1235
+
1131
1236
  """
1132
1237
  ndim = 3
1133
1238
  _optimization_method = {
@@ -1164,10 +1269,12 @@ def optimize_match(
1164
1269
  np.eye(len(bounds)), np.min(bounds, axis=1), np.max(bounds, axis=1)
1165
1270
  )
1166
1271
 
1167
- initial_score = score_object()
1272
+ x0 = np.zeros(2 * ndim) if x0 is None else x0
1273
+
1274
+ initial_score = score_object.score(x=x0)
1168
1275
  if optimization_method == "basinhopping":
1169
1276
  result = basinhopping(
1170
- x0=np.zeros(2 * ndim),
1277
+ x0=x0,
1171
1278
  func=score_object.score,
1172
1279
  niter=maxiter,
1173
1280
  minimizer_kwargs={"method": "COBYLA", "constraints": linear_constraint},
@@ -1180,11 +1287,13 @@ def optimize_match(
1180
1287
  maxiter=maxiter,
1181
1288
  )
1182
1289
  elif optimization_method == "minimize":
1290
+ print(maxiter)
1183
1291
  result = minimize(
1184
- x0=np.zeros(2 * ndim),
1292
+ x0=x0,
1185
1293
  fun=score_object.score,
1186
1294
  bounds=bounds,
1187
1295
  constraints=linear_constraint,
1296
+ options={"maxiter": maxiter},
1188
1297
  )
1189
1298
  print(f"Niter: {result.nit}, success : {result.success} ({result.message}).")
1190
1299
  print(f"Initial score: {initial_score} - Refined score: {result.fun}")
@@ -1194,7 +1303,3 @@ def optimize_match(
1194
1303
  translation, rotation = result.x[:ndim], result.x[ndim:]
1195
1304
  rotation_matrix = euler_to_rotationmatrix(rotation)
1196
1305
  return translation, rotation_matrix, result.fun
1197
-
1198
-
1199
- class FitRefinement:
1200
- pass