pytme 0.1.8__cp311-cp311-macosx_14_0_arm64.whl → 0.2.0__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. pytme-0.2.0.data/scripts/match_template.py +1019 -0
  2. pytme-0.2.0.data/scripts/postprocess.py +570 -0
  3. {pytme-0.1.8.data → pytme-0.2.0.data}/scripts/preprocessor_gui.py +244 -60
  4. {pytme-0.1.8.dist-info → pytme-0.2.0.dist-info}/METADATA +3 -1
  5. pytme-0.2.0.dist-info/RECORD +72 -0
  6. {pytme-0.1.8.dist-info → pytme-0.2.0.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +218 -0
  8. scripts/match_template.py +459 -218
  9. pytme-0.1.8.data/scripts/match_template.py → scripts/match_template_filters.py +459 -218
  10. scripts/postprocess.py +380 -435
  11. scripts/preprocessor_gui.py +244 -60
  12. scripts/refine_matches.py +218 -0
  13. tme/__init__.py +2 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +533 -78
  16. tme/backends/cupy_backend.py +80 -15
  17. tme/backends/npfftw_backend.py +35 -6
  18. tme/backends/pytorch_backend.py +15 -7
  19. tme/density.py +173 -78
  20. tme/extensions.cpython-311-darwin.so +0 -0
  21. tme/matching_constrained.py +195 -0
  22. tme/matching_data.py +78 -32
  23. tme/matching_exhaustive.py +369 -221
  24. tme/matching_memory.py +1 -0
  25. tme/matching_optimization.py +753 -649
  26. tme/matching_utils.py +152 -8
  27. tme/orientations.py +561 -0
  28. tme/preprocessing/__init__.py +2 -0
  29. tme/preprocessing/_utils.py +176 -0
  30. tme/preprocessing/composable_filter.py +30 -0
  31. tme/preprocessing/compose.py +52 -0
  32. tme/preprocessing/frequency_filters.py +322 -0
  33. tme/preprocessing/tilt_series.py +967 -0
  34. tme/preprocessor.py +35 -25
  35. tme/structure.py +2 -37
  36. pytme-0.1.8.data/scripts/postprocess.py +0 -625
  37. pytme-0.1.8.dist-info/RECORD +0 -61
  38. {pytme-0.1.8.data → pytme-0.2.0.data}/scripts/estimate_ram_usage.py +0 -0
  39. {pytme-0.1.8.data → pytme-0.2.0.data}/scripts/preprocess.py +0 -0
  40. {pytme-0.1.8.dist-info → pytme-0.2.0.dist-info}/LICENSE +0 -0
  41. {pytme-0.1.8.dist-info → pytme-0.2.0.dist-info}/entry_points.txt +0 -0
  42. {pytme-0.1.8.dist-info → pytme-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,195 @@
1
+ import numpy as np
2
+ from typing import Tuple, Dict
3
+
4
+ from scipy.ndimage import map_coordinates
5
+
6
+ from tme.types import ArrayLike
7
+ from tme.backends import backend
8
+ from tme.matching_data import MatchingData
9
+ from tme.matching_exhaustive import _normalize_under_mask
10
+
11
+
12
class MatchDensityToDensity:
    """
    Base class for scoring a rigid placement of a template inside a target.

    The constructor prepares the target, template and optional masks from a
    ``MatchingData`` instance. Instances are expected to be callable as
    ``self(translation=..., rotation_matrix=...)``; this class does not
    define ``__call__`` itself, subclasses have to provide it.

    Parameters
    ----------
    matching_data : MatchingData
        Source of target, template and their masks.
    pad_target_edges : bool, optional
        Forwarded to ``matching_data.target_padding`` as ``pad_target``.
    pad_fourier : bool, optional
        Forwarded to ``matching_data.fourier_padding``.
    rotate_mask : bool, optional
        Whether the template mask should be transformed together with the
        template. Forced to False when ``matching_data`` has no template mask.
    interpolation_order : int, optional
        Spline order handed to the interpolation in :py:meth:`rigid_transform`.
    negate_score : bool, optional
        If True, ``score_sign`` is -1, otherwise 1.
    """

    def __init__(
        self,
        matching_data: "MatchingData",
        pad_target_edges: bool = False,
        pad_fourier: bool = False,
        rotate_mask: bool = True,
        interpolation_order: int = 1,
        negate_score: bool = False,
    ):
        self.interpolation_order = interpolation_order

        target_pad = matching_data.target_padding(pad_target=pad_target_edges)
        matching_data = matching_data.subset_by_slice(target_pad=target_pad)

        # Only the padded real-space shape is needed here.
        fast_shape, _, _ = matching_data.fourier_padding(pad_fourier=pad_fourier)

        self.target = backend.topleft_pad(matching_data.target, fast_shape)
        self.target_mask = matching_data.target_mask

        self.template = matching_data.template
        self.template_rot = backend.preallocate_array(
            fast_shape, backend._default_dtype
        )

        # Scalar 1 acts as a neutral placeholder when no template mask exists.
        self.template_mask, self.template_mask_rot = 1, 1
        template_mask = matching_data.template_mask
        has_template_mask = template_mask is not None

        # Previously the mask-aware value was computed into a local and then
        # discarded, leaving self.rotate_mask True even without a mask, which
        # made subclasses pass the placeholder 1 to the interpolation routine.
        self.rotate_mask = bool(rotate_mask) and has_template_mask

        if has_template_mask:
            self.template_mask = template_mask
            self.template_mask_rot = backend.topleft_pad(template_mask, fast_shape)

        self.score_sign = -1 if negate_score else 1

    @staticmethod
    def _interpolate_into(arr, grid, out, order: int) -> None:
        """Sample ``arr`` at ``grid`` and store the values in ``out``."""
        flat = out.ravel()
        map_coordinates(arr, grid, order=order, output=flat)
        # ravel() yields a copy for non-contiguous buffers; copy the values
        # back in that case, otherwise the result would be silently dropped.
        if not np.may_share_memory(flat, out):
            out[...] = flat.reshape(out.shape)

    @staticmethod
    def rigid_transform(
        arr,
        rotation_matrix,
        translation,
        arr_mask=None,
        out=None,
        out_mask=None,
        order: int = 1,
        use_geometric_center: bool = False,
    ):
        """
        Rotate ``arr`` about ``floor(shape / 2)`` and shift it by ``translation``.

        Output voxel ``x`` is interpolated from the input position
        ``R.T @ (x - center - translation) + center``.

        Parameters
        ----------
        arr : ndarray
            Array to transform.
        rotation_matrix : ndarray
            Square matrix with ``arr.ndim`` rows; its transpose is applied to
            the output grid.
        translation : array-like or None
            Shift in voxels, one value per axis. None means no shift.
        arr_mask : ndarray, optional
            Mask sampled on the same grid as ``arr``.
        out : ndarray, optional
            Buffer receiving the transformed ``arr``. Must hold as many
            elements as ``arr``. Allocated with ``np.zeros_like`` if None.
        out_mask : ndarray, optional
            Buffer receiving the transformed mask. Allocated if None and
            ``arr_mask`` is given.
        order : int, optional
            Spline interpolation order passed to ``map_coordinates``.
        use_geometric_center : bool, optional
            Accepted for interface compatibility; not used by this
            implementation.

        Returns
        -------
        None, ndarray or tuple of ndarray
            Only arrays allocated by this call are returned: ``out``,
            ``out_mask``, both as ``(out, out_mask)``, or None if the caller
            supplied every required buffer.
        """
        has_mask = arr_mask is not None
        return_out = out is None
        return_mask = has_mask and out_mask is None

        if translation is None:
            shift = np.zeros(arr.ndim)
        else:
            shift = np.asarray(translation, dtype=float).reshape(-1)

        center = np.floor(np.array(arr.shape) / 2)[:, None]
        grid = np.indices(arr.shape, dtype=np.float32).reshape(arr.ndim, -1)
        # Pull-back mapping: undo the shift, rotate about the center, restore.
        np.subtract(grid, center + shift[:, None], out=grid)
        np.matmul(rotation_matrix.T, grid, out=grid)
        np.add(grid, center, out=grid)

        if return_out:
            out = np.zeros_like(arr)
        MatchDensityToDensity._interpolate_into(arr, grid, out, order)

        if has_mask:
            if return_mask:
                out_mask = np.zeros_like(arr_mask)
            MatchDensityToDensity._interpolate_into(arr_mask, grid, out_mask, order)

        if return_out and return_mask:
            return out, out_mask
        if return_out:
            return out
        if return_mask:
            return out_mask
        return None

    @staticmethod
    def angles_to_rotationmatrix(angles: Tuple[float]) -> "ArrayLike":
        """
        Convert Euler angles to a rotation matrix in the backend's array type.

        Parameters
        ----------
        angles : tuple of float
            Euler angles, interpreted by ``euler_to_rotationmatrix``.

        Returns
        -------
        ArrayLike
            Rotation matrix converted with ``backend.to_backend_array``.
        """
        # Imported here because the name is missing from the module-level
        # imports; NOTE(review): confirm tme.matching_utils is the right home.
        from tme.matching_utils import euler_to_rotationmatrix

        angles = backend.to_numpy_array(angles)
        rotation_matrix = euler_to_rotationmatrix(angles)
        return backend.to_backend_array(rotation_matrix)

    def format_translation(self, translation: Tuple[float] = None) -> "ArrayLike":
        """
        Return ``translation`` as a backend array.

        Parameters
        ----------
        translation : tuple of float, optional
            Translation to convert. None yields a zero vector with one entry
            per template dimension.

        Returns
        -------
        ArrayLike
            Translation in the backend's array type.
        """
        if translation is None:
            return backend.zeros(self.template.ndim, backend._default_dtype)

        return backend.to_backend_array(translation)

    def score_translation(self, x: Tuple[float]) -> float:
        """
        Score translation ``x`` with all rotation angles set to zero.

        Parameters
        ----------
        x : tuple of float
            Translation to score.

        Returns
        -------
        float
            Value returned by calling the instance.
        """
        translation = self.format_translation(x)
        # Three zero angles are hard-coded, i.e. a 3D template is assumed.
        rotation_matrix = self.angles_to_rotationmatrix((0, 0, 0))

        return self(translation=translation, rotation_matrix=rotation_matrix)

    def score_angles(self, x: Tuple[float]) -> float:
        """
        Score the rotation given by Euler angles ``x`` with zero translation.

        Parameters
        ----------
        x : tuple of float
            Euler angles to score.

        Returns
        -------
        float
            Value returned by calling the instance.
        """
        translation = self.format_translation(None)
        rotation_matrix = self.angles_to_rotationmatrix(x)

        return self(translation=translation, rotation_matrix=rotation_matrix)

    def score(self, x: Tuple[float]) -> float:
        """
        Score a combined parameter vector.

        Parameters
        ----------
        x : tuple of float
            First half is the translation, second half the Euler angles; the
            split position is ``len(x) // 2``.

        Returns
        -------
        float
            Value returned by calling the instance.
        """
        split = len(x) // 2
        translation, angles = x[:split], x[split:]

        translation = self.format_translation(translation)
        rotation_matrix = self.angles_to_rotationmatrix(angles)

        return self(translation=translation, rotation_matrix=rotation_matrix)
124
+
125
+
126
class FLC(MatchDensityToDensity):
    """
    Scoring object dividing the template/target overlap by a masked
    standard deviation of the target.

    Parameters
    ----------
    **kwargs : Dict
        Forwarded unchanged to the parent class constructor.
    """

    def __init__(self, **kwargs: Dict):
        super().__init__(**kwargs)

        # Apply the target mask in place before deriving the squared target.
        if self.target_mask is not None:
            backend.multiply(self.target, self.target_mask, out=self.target)

        self.target_square = backend.square(self.target)

        mask_total = backend.sum(self.template_mask)
        _normalize_under_mask(
            template=self.template,
            mask=self.template_mask,
            mask_intensity=mask_total,
        )

        self.template_mask = backend.reverse(self.template_mask)
        self.template = backend.reverse(self.template)

    def __call__(self, translation: ArrayLike, rotation_matrix: ArrayLike) -> float:
        """
        Score the template after applying the given rigid transform.

        Parameters
        ----------
        translation : ArrayLike
            Translation handed to ``rigid_transform``.
        rotation_matrix : ArrayLike
            Rotation matrix handed to ``rigid_transform``.

        Returns
        -------
        float
            ``overlap / (sqrt(max(E[x^2] - E[x]^2, 0)) * n)`` multiplied by
            ``score_sign``, where the expectations are taken over the target
            weighted by the transformed template mask and ``n`` is the sum of
            that mask.
        """
        transform_kwargs = {
            "arr": self.template,
            "rotation_matrix": rotation_matrix,
            "translation": translation,
            "out": self.template_rot,
            "use_geometric_center": False,
            "order": self.interpolation_order,
        }
        # The mask is only transformed alongside the template on request.
        if self.rotate_mask:
            transform_kwargs["arr_mask"] = self.template_mask
            transform_kwargs["out_mask"] = self.template_mask_rot
        self.rigid_transform(**transform_kwargs)

        n_observations = backend.sum(self.template_mask_rot)

        _normalize_under_mask(
            template=self.template_rot,
            mask=self.template_mask_rot,
            mask_intensity=n_observations,
        )

        # Second moment of the target under the mask.
        masked_square_sum = backend.sum(
            backend.multiply(self.target_square, self.template_mask_rot)
        )
        ex2 = backend.sum(backend.divide(masked_square_sum, n_observations))

        # Squared first moment of the target under the mask.
        masked_sum = backend.sum(
            backend.multiply(self.target, self.template_mask_rot)
        )
        e2x = backend.square(backend.divide(masked_sum, n_observations))

        # Clamp at zero so rounding cannot produce a negative variance.
        variance = backend.maximum(backend.subtract(ex2, e2x), 0.0)
        denominator = backend.multiply(backend.sqrt(variance), n_observations)

        overlap = backend.sum(backend.multiply(self.template_rot, self.target))

        return backend.divide(overlap, denominator) * self.score_sign
tme/matching_data.py CHANGED
@@ -47,6 +47,9 @@ class MatchingData:
47
47
  self.target_filter = {}
48
48
 
49
49
  self._invert_target = False
50
+ self._rotations = None
51
+
52
+ self._set_batch_dimension()
50
53
 
51
54
  @staticmethod
52
55
  def _shape_to_slice(shape: Tuple[int]):
@@ -149,13 +152,13 @@ class MatchingData:
149
152
  def _warn_on_mismatch(
150
153
  expectation: float, computation: float, name: str
151
154
  ) -> float:
152
- expectation, computation = float(expectation), float(computation)
153
155
  if expectation is None:
154
156
  expectation = computation
157
+ expectation, computation = float(expectation), float(computation)
155
158
 
156
159
  if abs(computation) > abs(expectation):
157
160
  warnings.warn(
158
- f"Computed {name} value is more extreme than value specified in file"
161
+ f"Computed {name} value is more extreme than value in file header"
159
162
  f" (|{computation}| > |{expectation}|). This may lead to issues"
160
163
  " with padding and contrast inversion."
161
164
  )
@@ -175,9 +178,10 @@ class MatchingData:
175
178
  arr_min = _warn_on_mismatch(arr_min, arr.min(), "min")
176
179
  arr_max = _warn_on_mismatch(arr_max, arr.max(), "max")
177
180
 
178
- np.subtract(-ret, arr_min, out=ret)
179
- np.divide(ret, arr_max - arr_min, out=ret)
180
-
181
+ # Avoid in-place operation in case ret is not floating point
182
+ ret = (
183
+ -np.divide(np.subtract(ret, arr_min), np.subtract(arr_max, arr_min)) + 1
184
+ )
181
185
  return ret
182
186
 
183
187
  def subset_by_slice(
@@ -220,14 +224,16 @@ class MatchingData:
220
224
  if target_pad is None:
221
225
  target_pad = np.zeros(len(self._target.shape), dtype=int)
222
226
  if template_pad is None:
223
- template_pad = np.zeros(len(self._target.shape), dtype=int)
224
-
225
- indices = compute_full_convolution_index(
226
- outer_shape=self._target.shape,
227
- inner_shape=self._template.shape,
228
- outer_split=target_slice,
229
- inner_split=template_slice,
230
- )
227
+ template_pad = np.zeros(len(self._template.shape), dtype=int)
228
+
229
+ indices = None
230
+ if len(self._target.shape) == len(self._template.shape):
231
+ indices = compute_full_convolution_index(
232
+ outer_shape=self._target.shape,
233
+ inner_shape=self._template.shape,
234
+ outer_split=target_slice,
235
+ inner_split=template_slice,
236
+ )
231
237
 
232
238
  target_subset = self.subset_array(
233
239
  arr=self._target,
@@ -243,13 +249,19 @@ class MatchingData:
243
249
  )
244
250
  ret = self.__class__(target=target_subset, template=template_subset)
245
251
 
246
- ret._translation_offset = np.add(
247
- [x.start for x in target_slice],
248
- [x.start for x in template_slice],
249
- )
250
- ret.template_filter = self.template_filter
252
+ target_offset = np.zeros(len(self._output_target_shape), dtype=int)
253
+ target_offset[(target_offset.size - len(target_slice)) :] = [
254
+ x.start for x in target_slice
255
+ ]
256
+ template_offset = np.zeros(len(self._output_target_shape), dtype=int)
257
+ template_offset[(template_offset.size - len(template_slice)) :] = [
258
+ x.start for x in template_slice
259
+ ]
260
+ ret._translation_offset = np.add(target_offset, template_offset)
251
261
 
252
- ret.rotations, ret.indices = self.rotations, indices
262
+ ret.template_filter = self.template_filter
263
+ ret.target_filter = self.target_filter
264
+ ret._rotations, ret.indices = self.rotations, indices
253
265
  ret._target_pad, ret._template_pad = target_pad, template_pad
254
266
  ret._invert_target = self._invert_target
255
267
 
@@ -279,10 +291,14 @@ class MatchingData:
279
291
  """
280
292
  Transfer the class instance's numpy arrays to the current backend.
281
293
  """
294
+ backend_arr = type(backend.zeros((1), dtype=backend._default_dtype))
282
295
  for attr_name, attr_value in vars(self).items():
283
296
  if isinstance(attr_value, np.ndarray):
284
297
  converted_array = backend.to_backend_array(attr_value.copy())
285
298
  setattr(self, attr_name, converted_array)
299
+ elif isinstance(attr_value, backend_arr):
300
+ converted_array = backend.to_backend_array(attr_value)
301
+ setattr(self, attr_name, converted_array)
286
302
 
287
303
  self._default_dtype = backend._default_dtype
288
304
  self._complex_dtype = backend._complex_dtype
@@ -424,13 +440,13 @@ class MatchingData:
424
440
  An array indicating the padding for each dimension of the target.
425
441
  """
426
442
  target_padding = backend.zeros(
427
- len(self.target.shape), dtype=backend._default_dtype_int
443
+ len(self._output_target_shape), dtype=backend._default_dtype_int
428
444
  )
429
445
 
430
446
  if pad_target:
431
447
  backend.subtract(
432
- self._template.shape,
433
- backend.mod(self._template.shape, 2),
448
+ self._output_template_shape,
449
+ backend.mod(self._output_template_shape, 2),
434
450
  out=target_padding,
435
451
  )
436
452
  if hasattr(self, "_is_target_batch"):
@@ -472,22 +488,34 @@ class MatchingData:
472
488
  fourier_shift = backend.zeros(len(fourier_pad))
473
489
 
474
490
  if not pad_fourier:
475
- fourier_pad = backend.full(shape=len(fourier_pad), fill_value=1, dtype=int)
491
+ fourier_pad = backend.full(
492
+ shape=(len(fourier_pad),),
493
+ fill_value=1,
494
+ dtype=backend._default_dtype_int,
495
+ )
476
496
 
477
497
  fourier_pad = backend.to_backend_array(fourier_pad)
478
498
  if hasattr(self, "_batch_mask"):
479
499
  batch_mask = backend.to_backend_array(self._batch_mask)
480
500
  fourier_pad[batch_mask] = 1
481
501
 
482
- ret = backend.compute_convolution_shapes(target_shape, fourier_pad)
502
+ pad_shape = backend.maximum(target_shape, template_shape)
503
+ ret = backend.compute_convolution_shapes(pad_shape, fourier_pad)
483
504
  convolution_shape, fast_shape, fast_ft_shape = ret
484
505
  if not pad_fourier:
485
506
  fourier_shift = 1 - backend.astype(backend.divide(template_shape, 2), int)
486
507
  fourier_shift -= backend.mod(template_shape, 2)
487
508
  shape_diff = backend.subtract(fast_shape, convolution_shape)
488
509
  shape_diff = backend.astype(backend.divide(shape_diff, 2), int)
510
+
511
+ if hasattr(self, "_batch_mask"):
512
+ batch_mask = backend.to_backend_array(self._batch_mask)
513
+ shape_diff[batch_mask] = 0
514
+
489
515
  backend.add(fourier_shift, shape_diff, out=fourier_shift)
490
516
 
517
+ fourier_shift = backend.astype(fourier_shift, backend._default_dtype_int)
518
+
491
519
  return fast_shape, fast_ft_shape, fourier_shift
492
520
 
493
521
  @property
@@ -520,15 +548,21 @@ class MatchingData:
520
548
  def target(self):
521
549
  """Returns the target NDArray."""
522
550
  if isinstance(self._target, Density):
523
- return self._target.data
524
- return self._target
551
+ target = self._target.data
552
+ else:
553
+ target = self._target
554
+ out_shape = backend.to_numpy_array(self._output_target_shape)
555
+ return target.reshape(tuple(int(x) for x in out_shape))
525
556
 
526
557
  @property
527
558
  def template(self):
528
559
  """Returns the reversed template NDArray."""
560
+ template = self._template
529
561
  if isinstance(self._template, Density):
530
- return backend.reverse(self._template.data)
531
- return backend.reverse(self._template)
562
+ template = self._template.data
563
+ template = backend.reverse(template)
564
+ out_shape = backend.to_numpy_array(self._output_template_shape)
565
+ return template.reshape(tuple(int(x) for x in out_shape))
532
566
 
533
567
  @template.setter
534
568
  def template(self, template: NDArray):
@@ -555,9 +589,15 @@ class MatchingData:
555
589
  @property
556
590
  def target_mask(self):
557
591
  """Returns the target mask NDArray."""
592
+ target_mask = self._target_mask
558
593
  if isinstance(self._target_mask, Density):
559
- return self._target_mask.data
560
- return self._target_mask
594
+ target_mask = self._target_mask.data
595
+
596
+ if target_mask is not None:
597
+ out_shape = backend.to_numpy_array(self._output_target_shape)
598
+ target_mask = target_mask.reshape(tuple(int(x) for x in out_shape))
599
+
600
+ return target_mask
561
601
 
562
602
  @target_mask.setter
563
603
  def target_mask(self, mask: NDArray):
@@ -584,9 +624,15 @@ class MatchingData:
584
624
  template : NDArray
585
625
  Array to set as the template.
586
626
  """
627
+ mask = self._template_mask
587
628
  if isinstance(self._template_mask, Density):
588
- return backend.reverse(self._template_mask.data)
589
- return backend.reverse(self._template_mask)
629
+ mask = self._template_mask.data
630
+
631
+ if mask is not None:
632
+ mask = backend.reverse(mask)
633
+ out_shape = backend.to_numpy_array(self._output_template_shape)
634
+ mask = mask.reshape(tuple(int(x) for x in out_shape))
635
+ return mask
590
636
 
591
637
  @template_mask.setter
592
638
  def template_mask(self, mask: NDArray):