pytme 0.1.9__cp311-cp311-macosx_14_0_arm64.whl → 0.2.0__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. pytme-0.2.0.data/scripts/match_template.py +1019 -0
  2. pytme-0.2.0.data/scripts/postprocess.py +570 -0
  3. {pytme-0.1.9.data → pytme-0.2.0.data}/scripts/preprocessor_gui.py +244 -60
  4. {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/METADATA +3 -1
  5. pytme-0.2.0.dist-info/RECORD +72 -0
  6. {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +218 -0
  8. scripts/match_template.py +459 -218
  9. pytme-0.1.9.data/scripts/match_template.py → scripts/match_template_filters.py +459 -218
  10. scripts/postprocess.py +380 -435
  11. scripts/preprocessor_gui.py +244 -60
  12. scripts/refine_matches.py +218 -0
  13. tme/__init__.py +2 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +533 -78
  16. tme/backends/cupy_backend.py +80 -15
  17. tme/backends/npfftw_backend.py +35 -6
  18. tme/backends/pytorch_backend.py +15 -7
  19. tme/density.py +173 -78
  20. tme/extensions.cpython-311-darwin.so +0 -0
  21. tme/matching_constrained.py +195 -0
  22. tme/matching_data.py +76 -33
  23. tme/matching_exhaustive.py +354 -225
  24. tme/matching_memory.py +1 -0
  25. tme/matching_optimization.py +753 -649
  26. tme/matching_utils.py +152 -8
  27. tme/orientations.py +561 -0
  28. tme/preprocessing/__init__.py +2 -0
  29. tme/preprocessing/_utils.py +176 -0
  30. tme/preprocessing/composable_filter.py +30 -0
  31. tme/preprocessing/compose.py +52 -0
  32. tme/preprocessing/frequency_filters.py +322 -0
  33. tme/preprocessing/tilt_series.py +967 -0
  34. tme/preprocessor.py +35 -25
  35. tme/structure.py +2 -37
  36. pytme-0.1.9.data/scripts/postprocess.py +0 -625
  37. pytme-0.1.9.dist-info/RECORD +0 -61
  38. {pytme-0.1.9.data → pytme-0.2.0.data}/scripts/estimate_ram_usage.py +0 -0
  39. {pytme-0.1.9.data → pytme-0.2.0.data}/scripts/preprocess.py +0 -0
  40. {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/LICENSE +0 -0
  41. {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/entry_points.txt +0 -0
  42. {pytme-0.1.9.dist-info → pytme-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,195 @@
1
+ import numpy as np
2
+ from typing import Tuple, Dict
3
+
4
+ from scipy.ndimage import map_coordinates
5
+
6
+ from tme.types import ArrayLike
7
+ from tme.backends import backend
8
+ from tme.matching_data import MatchingData
9
+ from tme.matching_exhaustive import _normalize_under_mask
10
+
11
+
12
class MatchDensityToDensity:
    """
    Base class for scoring a rigidly transformed template against a target.

    The constructor prepares the (padded) target, the template and their
    masks. The ``score*`` helpers turn a flat parameter vector into a
    translation / rotation-matrix pair and forward it to ``self(...)``, so
    subclasses have to implement ``__call__`` with the actual score.

    Parameters
    ----------
    matching_data : MatchingData
        Container providing target, template and the corresponding masks.
    pad_target_edges : bool, optional
        Forwarded to ``matching_data.target_padding`` as ``pad_target``.
    pad_fourier : bool, optional
        Forwarded to ``matching_data.fourier_padding``.
    rotate_mask : bool, optional
        Whether the template mask should be transformed together with the
        template. Forced to False when ``matching_data`` has no template mask.
    interpolation_order : int, optional
        Spline order stored on the instance for use in rigid transforms.
    negate_score : bool, optional
        If True ``score_sign`` is -1, otherwise 1.
    """

    def __init__(
        self,
        matching_data: "MatchingData",
        pad_target_edges: bool = False,
        pad_fourier: bool = False,
        rotate_mask: bool = True,
        interpolation_order: int = 1,
        negate_score: bool = False,
    ):
        self.interpolation_order = interpolation_order

        target_pad = matching_data.target_padding(pad_target=pad_target_edges)
        matching_data = matching_data.subset_by_slice(target_pad=target_pad)

        # Only the padded real-space shape is needed here.
        fast_shape, _, _ = matching_data.fourier_padding(pad_fourier=pad_fourier)

        self.target = backend.topleft_pad(matching_data.target, fast_shape)
        self.target_mask = matching_data.target_mask

        self.template = matching_data.template
        self.template_rot = backend.preallocate_array(
            fast_shape, backend._default_dtype
        )

        template_mask = matching_data.template_mask
        # Without a template mask there is nothing to rotate. Storing the
        # corrected flag keeps subclasses from handing the scalar placeholder
        # below to rigid_transform as if it were an array.
        self.rotate_mask = bool(rotate_mask) and template_mask is not None

        # Scalar 1 acts as a neutral placeholder when no mask was provided.
        self.template_mask, self.template_mask_rot = 1, 1
        if template_mask is not None:
            self.template_mask = template_mask
            self.template_mask_rot = backend.topleft_pad(template_mask, fast_shape)

        self.score_sign = -1 if negate_score else 1

    def __call__(self, translation: ArrayLike, rotation_matrix: ArrayLike) -> float:
        """
        Score the template under the given rigid transform.

        The base class defines no score; subclasses must override this method.

        Raises
        ------
        NotImplementedError
            Always, when called on a class that does not override it.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement __call__."
        )

    @staticmethod
    def rigid_transform(
        arr,
        rotation_matrix,
        translation,
        arr_mask=None,
        out=None,
        out_mask=None,
        order: int = 1,
        use_geometric_center: bool = False,
    ):
        """
        Rotate ``arr`` about its center and shift it by ``translation``.

        Every output position ``x`` is filled by interpolating ``arr`` at
        ``rotation_matrix.T @ (x - center) + center + translation`` where
        ``center = floor(shape / 2)``. Positions sampled outside of ``arr``
        use the default boundary handling of
        :py:func:`scipy.ndimage.map_coordinates`.

        Parameters
        ----------
        arr : ndarray
            Array to transform.
        rotation_matrix : ndarray
            Square rotation matrix with ``arr.ndim`` rows.
        translation : array-like or None
            Offset per axis that is added to the sampling coordinates.
            None is treated as zero translation.
        arr_mask : ndarray, optional
            Mask that is transformed with the same coordinates as ``arr``.
        out : ndarray, optional
            Destination for the transformed array. Needs to hold as many
            elements as ``arr``. Allocated with ``np.zeros_like`` if omitted.
        out_mask : ndarray, optional
            Destination for the transformed mask. Allocated with
            ``np.zeros_like`` if omitted and ``arr_mask`` is given.
        order : int, optional
            Spline interpolation order, 1 by default.
        use_geometric_center : bool, optional
            Currently unused, accepted for interface compatibility.

        Returns
        -------
        None, ndarray or tuple of ndarray
            Only arrays allocated by this function are returned: nothing if
            all required outputs were passed in, the array or the mask if one
            of them was allocated and ``(out, out_mask)`` if both were.
        """
        # Decide what to return before the output arguments get replaced.
        return_out = out is None
        return_mask = arr_mask is not None and out_mask is None

        if translation is None:
            translation = np.zeros(arr.ndim)
        translation = np.asarray(translation, dtype=np.float64).reshape(-1, 1)

        center = np.floor(np.array(arr.shape) / 2)[:, None]
        grid = np.indices(arr.shape, dtype=np.float32).reshape(arr.ndim, -1)
        np.subtract(grid, center, out=grid)
        # Writing the product to a fresh array avoids aliasing input and output.
        coordinates = np.matmul(rotation_matrix.T, grid)
        np.add(coordinates, center + translation, out=coordinates)

        def _sample(source, destination):
            # ravel only yields a view for contiguous arrays. For any other
            # layout the interpolated values have to be copied back.
            flat = destination.ravel()
            map_coordinates(source, coordinates, order=order, output=flat)
            if not np.may_share_memory(flat, destination):
                destination[...] = flat.reshape(destination.shape)

        if out is None:
            out = np.zeros_like(arr)
        _sample(arr, out)

        if arr_mask is not None:
            if out_mask is None:
                out_mask = np.zeros_like(arr_mask)
            _sample(arr_mask, out_mask)

        if return_out and return_mask:
            return out, out_mask
        if return_out:
            return out
        if return_mask:
            return out_mask
        return None

    @staticmethod
    def angles_to_rotationmatrix(angles: Tuple[float]) -> ArrayLike:
        """
        Convert Euler angles to a rotation matrix on the current backend.

        Parameters
        ----------
        angles : tuple of float
            Euler angles as understood by
            ``tme.matching_utils.euler_to_rotationmatrix``.

        Returns
        -------
        ArrayLike
            Rotation matrix converted with ``backend.to_backend_array``.
        """
        # Imported here because the module-level imports do not provide it.
        from tme.matching_utils import euler_to_rotationmatrix

        angles = backend.to_numpy_array(angles)
        rotation_matrix = euler_to_rotationmatrix(angles)
        return backend.to_backend_array(rotation_matrix)

    def format_translation(self, translation: Tuple[float] = None) -> ArrayLike:
        """
        Return ``translation`` as backend array.

        None yields a zero vector with one entry per template dimension.
        """
        if translation is None:
            return backend.zeros(self.template.ndim, backend._default_dtype)

        return backend.to_backend_array(translation)

    def score_translation(self, x: Tuple[float]) -> float:
        """Score translation ``x`` using the rotation matrix of angles (0, 0, 0)."""
        translation = self.format_translation(x)
        rotation_matrix = self.angles_to_rotationmatrix((0, 0, 0))

        return self(translation=translation, rotation_matrix=rotation_matrix)

    def score_angles(self, x: Tuple[float]) -> float:
        """Score the Euler angles ``x`` with zero translation."""
        translation = self.format_translation(None)
        rotation_matrix = self.angles_to_rotationmatrix(x)

        return self(translation=translation, rotation_matrix=rotation_matrix)

    def score(self, x: Tuple[float]) -> float:
        """
        Score a combined parameter vector.

        The first half of ``x`` is interpreted as translation, the remainder
        as Euler angles.
        """
        split = len(x) // 2
        translation, angles = x[:split], x[split:]

        translation = self.format_translation(translation)
        rotation_matrix = self.angles_to_rotationmatrix(angles)

        return self(translation=translation, rotation_matrix=rotation_matrix)
124
+
125
+
126
class FLC(MatchDensityToDensity):
    """
    Mask-normalized correlation score between rotated template and target.

    NOTE(review): the name presumably abbreviates "fast local correlation";
    confirm against the project documentation.

    For a rigid transform the score evaluates to::

        sum(template_rot * target) / (n * sqrt(max(E[t^2] - E[t]^2, 0)))

    where ``n`` is the sum of the (rotated) template mask and ``E[.]`` are
    the mask-weighted sums over the target divided by ``n``. The result is
    multiplied with ``score_sign``.

    Parameters
    ----------
    **kwargs : Dict
        Forwarded unchanged to :py:class:`MatchDensityToDensity`.
    """

    def __init__(self, **kwargs: Dict):
        super().__init__(**kwargs)

        # Restrict the target to the masked region before squaring it.
        if self.target_mask is not None:
            backend.multiply(self.target, self.target_mask, out=self.target)
        self.target_square = backend.square(self.target)

        # The return value is ignored, so the helper presumably normalizes
        # the template in place -- TODO confirm in tme.matching_exhaustive.
        mask_total = backend.sum(self.template_mask)
        _normalize_under_mask(
            template=self.template,
            mask=self.template_mask,
            mask_intensity=mask_total,
        )

        self.template = backend.reverse(self.template)
        self.template_mask = backend.reverse(self.template_mask)

    def __call__(self, translation: ArrayLike, rotation_matrix: ArrayLike) -> float:
        """
        Compute the score for a rigid transform of the template.

        Parameters
        ----------
        translation : ArrayLike
            Translation forwarded to ``rigid_transform``.
        rotation_matrix : ArrayLike
            Rotation matrix forwarded to ``rigid_transform``.

        Returns
        -------
        float
            Normalized overlap of transformed template and target, multiplied
            with ``score_sign``.
        """
        transform_kwargs = {
            "arr": self.template,
            "rotation_matrix": rotation_matrix,
            "translation": translation,
            "out": self.template_rot,
            "use_geometric_center": False,
            "order": self.interpolation_order,
        }
        # The mask is only transformed alongside the template when requested.
        if self.rotate_mask:
            transform_kwargs["arr_mask"] = self.template_mask
            transform_kwargs["out_mask"] = self.template_mask_rot
        self.rigid_transform(**transform_kwargs)

        n_observations = backend.sum(self.template_mask_rot)
        _normalize_under_mask(
            template=self.template_rot,
            mask=self.template_mask_rot,
            mask_intensity=n_observations,
        )

        # Mask-weighted second moment of the target.
        masked_square_sum = backend.sum(
            backend.multiply(self.target_square, self.template_mask_rot)
        )
        ex2 = backend.sum(backend.divide(masked_square_sum, n_observations))

        # Squared mask-weighted mean of the target.
        masked_sum = backend.sum(
            backend.multiply(self.target, self.template_mask_rot)
        )
        e2x = backend.square(backend.divide(masked_sum, n_observations))

        # Clamp at zero so round-off can not produce a negative variance.
        variance = backend.maximum(backend.subtract(ex2, e2x), 0.0)
        denominator = backend.multiply(backend.sqrt(variance), n_observations)

        overlap = backend.sum(backend.multiply(self.template_rot, self.target))

        return backend.divide(overlap, denominator) * self.score_sign
tme/matching_data.py CHANGED
@@ -47,6 +47,9 @@ class MatchingData:
47
47
  self.target_filter = {}
48
48
 
49
49
  self._invert_target = False
50
+ self._rotations = None
51
+
52
+ self._set_batch_dimension()
50
53
 
51
54
  @staticmethod
52
55
  def _shape_to_slice(shape: Tuple[int]):
@@ -149,13 +152,13 @@ class MatchingData:
149
152
  def _warn_on_mismatch(
150
153
  expectation: float, computation: float, name: str
151
154
  ) -> float:
152
- expectation, computation = float(expectation), float(computation)
153
155
  if expectation is None:
154
156
  expectation = computation
157
+ expectation, computation = float(expectation), float(computation)
155
158
 
156
159
  if abs(computation) > abs(expectation):
157
160
  warnings.warn(
158
- f"Computed {name} value is more extreme than value specified in file"
161
+ f"Computed {name} value is more extreme than value in file header"
159
162
  f" (|{computation}| > |{expectation}|). This may lead to issues"
160
163
  " with padding and contrast inversion."
161
164
  )
@@ -176,11 +179,9 @@ class MatchingData:
176
179
  arr_max = _warn_on_mismatch(arr_max, arr.max(), "max")
177
180
 
178
181
  # Avoid in-place operation in case ret is not floating point
179
- ret = np.divide(
180
- np.subtract(-ret, arr_min),
181
- np.subtract(arr_max, arr_min)
182
+ ret = (
183
+ -np.divide(np.subtract(ret, arr_min), np.subtract(arr_max, arr_min)) + 1
182
184
  )
183
-
184
185
  return ret
185
186
 
186
187
  def subset_by_slice(
@@ -223,14 +224,16 @@ class MatchingData:
223
224
  if target_pad is None:
224
225
  target_pad = np.zeros(len(self._target.shape), dtype=int)
225
226
  if template_pad is None:
226
- template_pad = np.zeros(len(self._target.shape), dtype=int)
227
-
228
- indices = compute_full_convolution_index(
229
- outer_shape=self._target.shape,
230
- inner_shape=self._template.shape,
231
- outer_split=target_slice,
232
- inner_split=template_slice,
233
- )
227
+ template_pad = np.zeros(len(self._template.shape), dtype=int)
228
+
229
+ indices = None
230
+ if len(self._target.shape) == len(self._template.shape):
231
+ indices = compute_full_convolution_index(
232
+ outer_shape=self._target.shape,
233
+ inner_shape=self._template.shape,
234
+ outer_split=target_slice,
235
+ inner_split=template_slice,
236
+ )
234
237
 
235
238
  target_subset = self.subset_array(
236
239
  arr=self._target,
@@ -246,13 +249,19 @@ class MatchingData:
246
249
  )
247
250
  ret = self.__class__(target=target_subset, template=template_subset)
248
251
 
249
- ret._translation_offset = np.add(
250
- [x.start for x in target_slice],
251
- [x.start for x in template_slice],
252
- )
253
- ret.template_filter = self.template_filter
252
+ target_offset = np.zeros(len(self._output_target_shape), dtype=int)
253
+ target_offset[(target_offset.size - len(target_slice)) :] = [
254
+ x.start for x in target_slice
255
+ ]
256
+ template_offset = np.zeros(len(self._output_target_shape), dtype=int)
257
+ template_offset[(template_offset.size - len(template_slice)) :] = [
258
+ x.start for x in template_slice
259
+ ]
260
+ ret._translation_offset = np.add(target_offset, template_offset)
254
261
 
255
- ret.rotations, ret.indices = self.rotations, indices
262
+ ret.template_filter = self.template_filter
263
+ ret.target_filter = self.target_filter
264
+ ret._rotations, ret.indices = self.rotations, indices
256
265
  ret._target_pad, ret._template_pad = target_pad, template_pad
257
266
  ret._invert_target = self._invert_target
258
267
 
@@ -282,10 +291,14 @@ class MatchingData:
282
291
  """
283
292
  Transfer the class instance's numpy arrays to the current backend.
284
293
  """
294
+ backend_arr = type(backend.zeros((1), dtype=backend._default_dtype))
285
295
  for attr_name, attr_value in vars(self).items():
286
296
  if isinstance(attr_value, np.ndarray):
287
297
  converted_array = backend.to_backend_array(attr_value.copy())
288
298
  setattr(self, attr_name, converted_array)
299
+ elif isinstance(attr_value, backend_arr):
300
+ converted_array = backend.to_backend_array(attr_value)
301
+ setattr(self, attr_name, converted_array)
289
302
 
290
303
  self._default_dtype = backend._default_dtype
291
304
  self._complex_dtype = backend._complex_dtype
@@ -427,13 +440,13 @@ class MatchingData:
427
440
  An array indicating the padding for each dimension of the target.
428
441
  """
429
442
  target_padding = backend.zeros(
430
- len(self.target.shape), dtype=backend._default_dtype_int
443
+ len(self._output_target_shape), dtype=backend._default_dtype_int
431
444
  )
432
445
 
433
446
  if pad_target:
434
447
  backend.subtract(
435
- self._template.shape,
436
- backend.mod(self._template.shape, 2),
448
+ self._output_template_shape,
449
+ backend.mod(self._output_template_shape, 2),
437
450
  out=target_padding,
438
451
  )
439
452
  if hasattr(self, "_is_target_batch"):
@@ -475,22 +488,34 @@ class MatchingData:
475
488
  fourier_shift = backend.zeros(len(fourier_pad))
476
489
 
477
490
  if not pad_fourier:
478
- fourier_pad = backend.full(shape=len(fourier_pad), fill_value=1, dtype=int)
491
+ fourier_pad = backend.full(
492
+ shape=(len(fourier_pad),),
493
+ fill_value=1,
494
+ dtype=backend._default_dtype_int,
495
+ )
479
496
 
480
497
  fourier_pad = backend.to_backend_array(fourier_pad)
481
498
  if hasattr(self, "_batch_mask"):
482
499
  batch_mask = backend.to_backend_array(self._batch_mask)
483
500
  fourier_pad[batch_mask] = 1
484
501
 
485
- ret = backend.compute_convolution_shapes(target_shape, fourier_pad)
502
+ pad_shape = backend.maximum(target_shape, template_shape)
503
+ ret = backend.compute_convolution_shapes(pad_shape, fourier_pad)
486
504
  convolution_shape, fast_shape, fast_ft_shape = ret
487
505
  if not pad_fourier:
488
506
  fourier_shift = 1 - backend.astype(backend.divide(template_shape, 2), int)
489
507
  fourier_shift -= backend.mod(template_shape, 2)
490
508
  shape_diff = backend.subtract(fast_shape, convolution_shape)
491
509
  shape_diff = backend.astype(backend.divide(shape_diff, 2), int)
510
+
511
+ if hasattr(self, "_batch_mask"):
512
+ batch_mask = backend.to_backend_array(self._batch_mask)
513
+ shape_diff[batch_mask] = 0
514
+
492
515
  backend.add(fourier_shift, shape_diff, out=fourier_shift)
493
516
 
517
+ fourier_shift = backend.astype(fourier_shift, backend._default_dtype_int)
518
+
494
519
  return fast_shape, fast_ft_shape, fourier_shift
495
520
 
496
521
  @property
@@ -523,15 +548,21 @@ class MatchingData:
523
548
  def target(self):
524
549
  """Returns the target NDArray."""
525
550
  if isinstance(self._target, Density):
526
- return self._target.data
527
- return self._target
551
+ target = self._target.data
552
+ else:
553
+ target = self._target
554
+ out_shape = backend.to_numpy_array(self._output_target_shape)
555
+ return target.reshape(tuple(int(x) for x in out_shape))
528
556
 
529
557
  @property
530
558
  def template(self):
531
559
  """Returns the reversed template NDArray."""
560
+ template = self._template
532
561
  if isinstance(self._template, Density):
533
- return backend.reverse(self._template.data)
534
- return backend.reverse(self._template)
562
+ template = self._template.data
563
+ template = backend.reverse(template)
564
+ out_shape = backend.to_numpy_array(self._output_template_shape)
565
+ return template.reshape(tuple(int(x) for x in out_shape))
535
566
 
536
567
  @template.setter
537
568
  def template(self, template: NDArray):
@@ -558,9 +589,15 @@ class MatchingData:
558
589
  @property
559
590
  def target_mask(self):
560
591
  """Returns the target mask NDArray."""
592
+ target_mask = self._target_mask
561
593
  if isinstance(self._target_mask, Density):
562
- return self._target_mask.data
563
- return self._target_mask
594
+ target_mask = self._target_mask.data
595
+
596
+ if target_mask is not None:
597
+ out_shape = backend.to_numpy_array(self._output_target_shape)
598
+ target_mask = target_mask.reshape(tuple(int(x) for x in out_shape))
599
+
600
+ return target_mask
564
601
 
565
602
  @target_mask.setter
566
603
  def target_mask(self, mask: NDArray):
@@ -587,9 +624,15 @@ class MatchingData:
587
624
  template : NDArray
588
625
  Array to set as the template.
589
626
  """
627
+ mask = self._template_mask
590
628
  if isinstance(self._template_mask, Density):
591
- return backend.reverse(self._template_mask.data)
592
- return backend.reverse(self._template_mask)
629
+ mask = self._template_mask.data
630
+
631
+ if mask is not None:
632
+ mask = backend.reverse(mask)
633
+ out_shape = backend.to_numpy_array(self._output_template_shape)
634
+ mask = mask.reshape(tuple(int(x) for x in out_shape))
635
+ return mask
593
636
 
594
637
  @template_mask.setter
595
638
  def template_mask(self, mask: NDArray):