pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1.post1__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/match_template.py +28 -39
  2. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/postprocess.py +35 -21
  3. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/preprocessor_gui.py +95 -24
  4. pytme-0.3.1.post1.data/scripts/pytme_runner.py +1223 -0
  5. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/METADATA +5 -7
  6. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/RECORD +55 -48
  7. scripts/extract_candidates.py +118 -99
  8. scripts/match_template.py +28 -39
  9. scripts/postprocess.py +35 -21
  10. scripts/preprocessor_gui.py +95 -24
  11. scripts/pytme_runner.py +644 -190
  12. scripts/refine_matches.py +156 -386
  13. tests/data/.DS_Store +0 -0
  14. tests/data/Blurring/.DS_Store +0 -0
  15. tests/data/Maps/.DS_Store +0 -0
  16. tests/data/Raw/.DS_Store +0 -0
  17. tests/data/Structures/.DS_Store +0 -0
  18. tests/preprocessing/test_utils.py +18 -0
  19. tests/test_analyzer.py +2 -3
  20. tests/test_backends.py +3 -9
  21. tests/test_density.py +0 -1
  22. tests/test_extensions.py +0 -1
  23. tests/test_matching_utils.py +10 -60
  24. tests/test_rotations.py +1 -1
  25. tme/__version__.py +1 -1
  26. tme/analyzer/_utils.py +4 -4
  27. tme/analyzer/aggregation.py +35 -15
  28. tme/analyzer/peaks.py +11 -10
  29. tme/backends/_jax_utils.py +26 -13
  30. tme/backends/_numpyfftw_utils.py +270 -0
  31. tme/backends/cupy_backend.py +16 -55
  32. tme/backends/jax_backend.py +76 -37
  33. tme/backends/matching_backend.py +17 -51
  34. tme/backends/mlx_backend.py +1 -27
  35. tme/backends/npfftw_backend.py +71 -65
  36. tme/backends/pytorch_backend.py +1 -26
  37. tme/density.py +2 -6
  38. tme/extensions.cpython-311-darwin.so +0 -0
  39. tme/filters/ctf.py +22 -21
  40. tme/filters/wedge.py +10 -7
  41. tme/mask.py +341 -0
  42. tme/matching_data.py +31 -19
  43. tme/matching_exhaustive.py +37 -47
  44. tme/matching_optimization.py +2 -1
  45. tme/matching_scores.py +229 -411
  46. tme/matching_utils.py +73 -422
  47. tme/memory.py +1 -1
  48. tme/orientations.py +13 -8
  49. tme/rotations.py +1 -1
  50. pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
  51. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/estimate_memory_usage.py +0 -0
  52. {pytme-0.3b0.post1.data → pytme-0.3.1.post1.data}/scripts/preprocess.py +0 -0
  53. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/WHEEL +0 -0
  54. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/entry_points.txt +0 -0
  55. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/licenses/LICENSE +0 -0
  56. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.post1.dist-info}/top_level.txt +0 -0
tme/mask.py ADDED
@@ -0,0 +1,341 @@
+ """
+ Utility functions for generating template matching masks.
+
+ Copyright (c) 2023 European Molecular Biology Laboratory
+
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+ """
+
+ import numpy as np
+ from typing import Tuple, Optional
+
+ from .types import NDArray
+ from scipy.ndimage import gaussian_filter
+ from .matching_utils import rigid_transform
+
+ __all__ = ["elliptical_mask", "tube_mask", "box_mask", "membrane_mask"]
+
+
+ def elliptical_mask(
+     shape: Tuple[int],
+     radius: Tuple[float],
+     center: Optional[Tuple[float]] = None,
+     orientation: Optional[NDArray] = None,
+     sigma_decay: float = 0.0,
+     cutoff_sigma: float = 3,
+     **kwargs,
+ ) -> NDArray:
+     """
+     Creates an ellipsoidal mask.
+
+     Parameters
+     ----------
+     shape : tuple of ints
+         Shape of the mask to be created.
+     radius : tuple of floats
+         Radius of the mask.
+     center : tuple of floats, optional
+         Center of the mask, default to shape // 2.
+     orientation : NDArray, optional.
+         Orientation of the mask as rotation matrix with shape (d,d).
+
+     Returns
+     -------
+     NDArray
+         The created ellipsoidal mask.
+
+     Raises
+     ------
+     ValueError
+         If the length of center and radius is not one or the same as shape.
+
+     Examples
+     --------
+     >>> from tme.matching_utils import elliptical_mask
+     >>> mask = elliptical_mask(shape=(20,20), radius=(5,5), center=(10,10))
+     """
+     shape, radius = np.asarray(shape), np.asarray(radius)
+
+     shape = shape.astype(int)
+     if center is None:
+         center = np.divide(shape, 2).astype(int)
+
+     center = np.asarray(center, dtype=np.float32)
+     radius = np.repeat(radius, shape.size // radius.size)
+     center = np.repeat(center, shape.size // center.size)
+     if radius.size != shape.size:
+         raise ValueError("Length of radius has to be either one or match shape.")
+     if center.size != shape.size:
+         raise ValueError("Length of center has to be either one or match shape.")
+
+     n = shape.size
+     center = center.reshape((-1,) + (1,) * n)
+     radius = radius.reshape((-1,) + (1,) * n)
+
+     indices = np.indices(shape, dtype=np.float32) - center
+     if orientation is not None:
+         return_shape = indices.shape
+         indices = indices.reshape(n, -1)
+         rigid_transform(
+             coordinates=indices,
+             rotation_matrix=np.asarray(orientation),
+             out=indices,
+             translation=np.zeros(n),
+             use_geometric_center=False,
+         )
+         indices = indices.reshape(*return_shape)
+
+     dist = np.linalg.norm(indices / radius, axis=0)
+     if sigma_decay > 0:
+         sigma_decay = 2 * (sigma_decay / np.mean(radius)) ** 2
+         mask = np.maximum(0, dist - 1)
+         mask = np.exp(-(mask**2) / sigma_decay)
+         mask *= mask > np.exp(-(cutoff_sigma**2) / 2)
+     else:
+         mask = (dist <= 1).astype(int)
+     return mask
+
+
+ def box_mask(
+     shape: Tuple[int],
+     center: Tuple[int],
+     size: Tuple[int],
+     sigma_decay: float = 0.0,
+     cutoff_sigma: float = 3.0,
+     **kwargs,
+ ) -> np.ndarray:
+     """
+     Creates a box mask centered around the provided center point.
+
+     Parameters
+     ----------
+     shape : tuple of ints
+         Shape of the output array.
+     center : tuple of ints
+         Center point coordinates of the box.
+     size : tuple of ints
+         Side length of the box along each axis.
+
+     Returns
+     -------
+     NDArray
+         The created box mask.
+
+     Raises
+     ------
+     ValueError
+         If ``shape`` and ``center`` do not have the same length.
+         If ``center`` and ``height`` do not have the same length.
+     """
+     if len(shape) != len(center) or len(center) != len(size):
+         raise ValueError("The length of shape, center, and height must be consistent.")
+
+     shape = tuple(int(x) for x in shape)
+     center, size = np.array(center, dtype=int), np.array(size, dtype=int)
+
+     half_heights = size // 2
+     starts = np.maximum(center - half_heights, 0)
+     stops = np.minimum(center + half_heights + np.mod(size, 2) + 1, shape)
+     slice_indices = tuple(slice(*coord) for coord in zip(starts, stops))
+
+     out = np.zeros(shape)
+     out[slice_indices] = 1
+
+     if sigma_decay > 0:
+         mask_filter = gaussian_filter(
+             out.astype(np.float32), sigma=sigma_decay, truncate=cutoff_sigma
+         )
+         out = np.add(out, (1 - out) * mask_filter)
+         out *= out > np.exp(-(cutoff_sigma**2) / 2)
+     return out
+
+
+ def tube_mask(
+     shape: Tuple[int],
+     symmetry_axis: int,
+     center: Tuple[int],
+     inner_radius: float,
+     outer_radius: float,
+     height: int,
+     sigma_decay: float = 0.0,
+     cutoff_sigma: float = 3.0,
+     **kwargs,
+ ) -> NDArray:
+     """
+     Creates a tube mask.
+
+     Parameters
+     ----------
+     shape : tuple
+         Shape of the mask to be created.
+     symmetry_axis : int
+         The axis of symmetry for the tube.
+     base_center : tuple
+         Center of the tube.
+     inner_radius : float
+         Inner radius of the tube.
+     outer_radius : float
+         Outer radius of the tube.
+     height : int
+         Height of the tube.
+
+     Returns
+     -------
+     NDArray
+         The created tube mask.
+
+     Raises
+     ------
+     ValueError
+         If ``inner_radius`` is larger than ``outer_radius``.
+         If ``height`` is larger than the symmetry axis.
+         If ``base_center`` and ``shape`` do not have the same length.
+     """
+     if inner_radius > outer_radius:
+         raise ValueError("inner_radius should be smaller than outer_radius.")
+
+     if height > shape[symmetry_axis]:
+         raise ValueError(f"Height can be no larger than {shape[symmetry_axis]}.")
+
+     if symmetry_axis > len(shape):
+         raise ValueError(f"symmetry_axis can be not larger than {len(shape)}.")
+
+     if len(center) != len(shape):
+         raise ValueError("shape and base_center need to have the same length.")
+
+     shape = tuple(int(x) for x in shape)
+     circle_shape = tuple(b for ix, b in enumerate(shape) if ix != symmetry_axis)
+     circle_center = tuple(b for ix, b in enumerate(center) if ix != symmetry_axis)
+
+     inner_circle = np.zeros(circle_shape)
+     outer_circle = np.zeros_like(inner_circle)
+     if inner_radius > 0:
+         inner_circle = elliptical_mask(
+             shape=circle_shape,
+             radius=inner_radius,
+             center=circle_center,
+             sigma_decay=sigma_decay,
+             cutoff_sigma=cutoff_sigma,
+         )
+     if outer_radius > 0:
+         outer_circle = elliptical_mask(
+             shape=circle_shape,
+             radius=outer_radius,
+             center=circle_center,
+             sigma_decay=sigma_decay,
+             cutoff_sigma=cutoff_sigma,
+         )
+     circle = outer_circle - inner_circle
+     circle = np.expand_dims(circle, axis=symmetry_axis)
+
+     center = center[symmetry_axis]
+     start_idx = int(center - height // 2)
+     stop_idx = int(center + height // 2 + height % 2)
+
+     start_idx, stop_idx = max(start_idx, 0), min(stop_idx, shape[symmetry_axis])
+
+     slice_indices = tuple(
+         slice(None) if i != symmetry_axis else slice(start_idx, stop_idx)
+         for i in range(len(shape))
+     )
+     tube = np.zeros(shape)
+     tube[slice_indices] = circle
+
+     return tube
+
+
+ def membrane_mask(
+     shape: Tuple[int],
+     radius: float,
+     thickness: float,
+     separation: float,
+     symmetry_axis: int = 2,
+     center: Optional[Tuple[float]] = None,
+     sigma_decay: float = 0.5,
+     cutoff_sigma: float = 3,
+     **kwargs,
+ ) -> NDArray:
+     """
+     Creates a membrane mask consisting of two parallel disks with Gaussian intensity profile.
+     Uses efficient broadcasting approach: flat disk mask × height profile.
+
+     Parameters
+     ----------
+     shape : tuple of ints
+         Shape of the mask to be created.
+     radius : float
+         Radius of the membrane disks.
+     thickness : float
+         Thickness of each disk in the membrane.
+     separation : float
+         Distance between the centers of the two disks.
+     symmetry_axis : int, optional
+         The axis perpendicular to the membrane disks, defaults to 2.
+     center : tuple of floats, optional
+         Center of the membrane (midpoint between the two disks), defaults to shape // 2.
+     sigma_decay : float, optional
+         Controls edge sharpness relative to radius, defaults to 0.5.
+     cutoff_sigma : float, optional
+         Cutoff for height profile in standard deviations, defaults to 3.
+
+     Returns
+     -------
+     NDArray
+         The created membrane mask with Gaussian intensity profile.
+
+     Raises
+     ------
+     ValueError
+         If ``thickness`` is negative.
+         If ``separation`` is negative.
+         If ``center`` and ``shape`` do not have the same length.
+         If ``symmetry_axis`` is out of bounds.
+
+     Examples
+     --------
+     >>> from tme.matching_utils import membrane_mask
+     >>> mask = membrane_mask(shape=(50,50,50), radius=10, thickness=2, separation=15)
+     """
+     shape = np.asarray(shape, dtype=int)
+
+     if center is None:
+         center = np.divide(shape, 2).astype(float)
+
+     center = np.asarray(center, dtype=np.float32)
+     center = np.repeat(center, shape.size // center.size)
+
+     if thickness < 0:
+         raise ValueError("thickness must be non-negative.")
+     if separation < 0:
+         raise ValueError("separation must be non-negative.")
+     if symmetry_axis >= len(shape):
+         raise ValueError(f"symmetry_axis must be less than {len(shape)}.")
+     if center.size != shape.size:
+         raise ValueError("Length of center has to be either one or match shape.")
+
+     disk_mask = elliptical_mask(
+         shape=[x for i, x in enumerate(shape) if i != symmetry_axis],
+         radius=radius,
+         sigma_decay=sigma_decay,
+         cutoff_sigma=cutoff_sigma,
+     )
+
+     axial_coord = np.arange(shape[symmetry_axis]) - center[symmetry_axis]
+     height_profile = np.zeros((shape[symmetry_axis],), dtype=np.float32)
+     for leaflet_pos in [-separation / 2, separation / 2]:
+         leaflet_profile = np.exp(
+             -((axial_coord - leaflet_pos) ** 2) / (2 * (thickness / 3) ** 2)
+         )
+         cutoff_threshold = np.exp(-(cutoff_sigma**2) / 2)
+         leaflet_profile *= leaflet_profile > cutoff_threshold
+
+         height_profile = np.maximum(height_profile, leaflet_profile)
+
+     disk_mask = disk_mask.reshape(
+         [x if i != symmetry_axis else 1 for i, x in enumerate(shape)]
+     )
+     height_profile = height_profile.reshape(
+         [1 if i != symmetry_axis else x for i, x in enumerate(shape)]
+     )
+
+     return disk_mask * height_profile
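For orientation, the new helpers above can be exercised directly once the wheel is installed. The following is a minimal usage sketch based only on the signatures and defaults shown in this diff; the array shapes, radii, and other values are illustrative, not taken from the package:

from tme.mask import elliptical_mask, box_mask, tube_mask, membrane_mask

# Soft-edged ellipsoid; sigma_decay > 0 switches from a binary mask to a Gaussian falloff
ellipse = elliptical_mask(shape=(32, 32, 32), radius=(8, 10, 12), sigma_decay=1.0)

# Axis-aligned box with side length 10 around the volume center
box = box_mask(shape=(32, 32, 32), center=(16, 16, 16), size=(10, 10, 10))

# Hollow cylinder of height 20 along axis 0
tube = tube_mask(
    shape=(32, 32, 32), symmetry_axis=0, center=(16, 16, 16),
    inner_radius=4, outer_radius=8, height=20,
)

# Two parallel soft disks separated along axis 2, mimicking a membrane
membrane = membrane_mask(shape=(48, 48, 48), radius=15, thickness=4, separation=12)

print(ellipse.shape, box.sum(), tube.max(), membrane.max())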
tme/matching_data.py CHANGED
@@ -128,11 +128,8 @@ class MatchingData:
  slice_start = np.array([x.start for x in arr_slice], dtype=int)
  slice_stop = np.array([x.stop for x in arr_slice], dtype=int)

- # We are deviating from our typical right_pad + mod here
- # because cropping from full convolution mode to target shape
- # is defined from the perspective of the origin
- right_pad = np.divide(padding, 2).astype(int)
- left_pad = np.add(right_pad, np.mod(padding, 2))
+ left_pad = np.divide(padding, 2).astype(int)
+ right_pad = np.add(left_pad, np.mod(padding, 2))

  data_voxels_left = np.minimum(slice_start, left_pad)
  data_voxels_right = np.minimum(
@@ -253,7 +250,7 @@ class MatchingData:
  target_offset[mask] = [x.start for x in target_slice]
  mask = np.subtract(1, self._target_batch).astype(bool)
  template_offset = np.zeros(len(self._output_template_shape), dtype=int)
- template_offset[mask] = [x.start for x, b in zip(template_slice, mask) if b]
+ template_offset[mask] = [x.start for x in template_slice]

  translation_offset = tuple(x for x in target_offset)

@@ -299,7 +296,7 @@ class MatchingData:

  def set_matching_dimension(self, target_dim: int = None, template_dim: int = None):
  """
- Sets matching dimensions for target and template.
+ Sets matching batch dimensions for target and template.

  Parameters
  ----------
@@ -490,8 +487,8 @@ class MatchingData:
  def _fourier_padding(
  target_shape: Tuple[int],
  template_shape: Tuple[int],
- pad_target: bool = False,
  batch_mask: Tuple[int] = None,
+ **kwargs,
  ) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
  if batch_mask is None:
  batch_mask = np.zeros_like(template_shape)
@@ -503,7 +500,7 @@ class MatchingData:

  # Avoid padding batch dimensions
  pad_shape = np.maximum(target_shape, template_shape)
- pad_shape = np.maximum(target_shape, np.multiply(1 - batch_mask, pad_shape))
+ pad_shape = np.maximum(pad_shape, np.multiply(1 - batch_mask, pad_shape))
  ret = be.compute_convolution_shapes(pad_shape, fourier_pad)
  conv_shape, fast_shape, fast_ft_shape = ret

@@ -525,21 +522,13 @@ class MatchingData:
  shape_shift = np.multiply(np.add(shape_shift, offset), shape_mask)
  fourier_shift = np.subtract(fourier_shift, shape_shift).astype(int)

- if pad_target:
- fourier_shift = np.subtract(fourier_shift, np.subtract(1, template_mod))
-
  fourier_shift = tuple(np.multiply(fourier_shift, 1 - batch_mask).astype(int))
  return tuple(conv_shape), tuple(fast_shape), tuple(fast_ft_shape), fourier_shift

- def fourier_padding(self, pad_target: bool = False) -> Tuple:
+ def fourier_padding(self, **kwargs) -> Tuple:
  """
  Computes efficient shape four Fourier transforms and potential associated shifts.

- Parameters
- ----------
- pad_target : bool, optional
- Whether the target has been padded to the full convolution shape.
-
  Returns
  -------
  Tuple[tuple of int, tuple of int, tuple of int, tuple of int]
@@ -553,9 +542,32 @@ class MatchingData:
  target_shape=be.to_numpy_array(self._output_target_shape),
  template_shape=be.to_numpy_array(self._output_template_shape),
  batch_mask=be.to_numpy_array(self._batch_mask),
- pad_target=pad_target,
  )

+ def _score_mask(self, fast_shape: Tuple[int], shift: Tuple[int]) -> BackendArray:
+ """
+ Create a boolean mask to exclude scores derived from padding in template matching.
+ """
+ padding = self.target_padding(True)
+ offset = tuple(x // 2 for x in padding)
+ shape = tuple(y - x for x, y in zip(padding, self.target.shape))
+
+ subset = []
+ for i in range(len(offset)):
+ if self._batch_mask[i]:
+ subset.append(slice(None))
+ else:
+ subset.append(slice(offset[i], offset[i] + shape[i]))
+
+ score_mask = np.zeros(fast_shape, dtype=bool)
+ score_mask[tuple(subset)] = 1
+ score_mask = np.roll(
+ score_mask,
+ shift=tuple(-x for x in shift),
+ axis=tuple(i for i in range(len(shift))),
+ )
+ return be.to_backend_array(score_mask)
+
  def computation_schedule(
  self,
  matching_method: str = "FLCSphericalMask",
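The _score_mask helper added above keeps only scores that originate from real data: it marks the unpadded region of the padded score volume and rolls that mark by the negative Fourier shift so it stays aligned with the scores. A standalone numpy sketch of the same arithmetic, ignoring batch dimensions and using illustrative sizes rather than values from the package:

import numpy as np

fast_shape = (64,)    # padded score volume shape (illustrative)
target_shape = (50,)  # target shape as held by MatchingData (illustrative)
padding = (10,)       # result of target_padding(True) (illustrative)
shift = (5,)          # Fourier shift returned by fourier_padding() (illustrative)

# Valid region: skip half the padding, keep the target extent minus the padding
offset = tuple(p // 2 for p in padding)
extent = tuple(t - p for p, t in zip(padding, target_shape))
subset = tuple(slice(o, o + e) for o, e in zip(offset, extent))

score_mask = np.zeros(fast_shape, dtype=bool)
score_mask[subset] = True

# Roll by the negative shift so the mask lines up with the shifted scores
score_mask = np.roll(score_mask, shift=tuple(-s for s in shift), axis=tuple(range(len(shift))))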
@@ -40,7 +40,7 @@ def _setup_template_filter_apply_target_filter(
  matching_data: MatchingData,
  fast_shape: Tuple[int],
  fast_ft_shape: Tuple[int],
- pad_template_filter: bool = True,
+ pad_template_filter: bool = False,
  ):
  target_filter = None
  backend_arr = type(be.zeros((1), dtype=be._float_dtype))
@@ -146,11 +146,10 @@ def scan(
  matching_data: MatchingData,
  matching_setup: Callable,
  matching_score: Callable,
- n_jobs: int = 4,
- callback_class: CallbackClass = None,
+ callback_class: CallbackClass,
  callback_class_args: Dict = {},
+ n_jobs: int = 4,
  pad_target: bool = True,
- pad_template_filter: bool = True,
  interpolation_order: int = 3,
  jobs_per_callback_class: int = 8,
  shm_handler=None,
@@ -172,20 +171,22 @@ def scan(
  Function pointer to scoring function.
  n_jobs : int, optional
  Number of parallel jobs. Default is 4.
- callback_class : type, optional
+ callback_class : type
  Analyzer class pointer to operate on computed scores.
  callback_class_args : dict, optional
  Arguments passed to the callback_class. Default is an empty dictionary.
  pad_target: bool, optional
  Whether to pad target to the full convolution shape.
- pad_template_filter: bool, optional
- Whether to pad potential template filters to the full convolution shape.
  interpolation_order : int, optional
  Order of spline interpolation for rotations.
  jobs_per_callback_class : int, optional
  Number of jobs a callback_class instance is shared between, 8 by default.
  shm_handler : type, optional
  Manager for shared memory objects, None by default.
+ target_slice : tuple of slice, optional
+ Target subset to process.
+ template_slice : tuple of slice, optional
+ Template subset to process.

  Returns
  -------
@@ -220,13 +221,17 @@ def scan(
  template_shape = matching_data._batch_shape(
  matching_data.template.shape, matching_data._template_batch
  )
- conv, fwd, inv, shift = matching_data.fourier_padding(pad_target=pad_target)
+ conv, fwd, inv, shift = matching_data.fourier_padding()
+
+ score_mask = be.full(shape=(1,), fill_value=1, dtype=bool)
+ if pad_target:
+ score_mask = matching_data._score_mask(fwd, shift)

  template_filter = _setup_template_filter_apply_target_filter(
  matching_data=matching_data,
  fast_shape=fwd,
  fast_ft_shape=inv,
- pad_template_filter=pad_template_filter,
+ pad_template_filter=False,
  )

  default_callback_args = {
@@ -240,11 +245,10 @@ def scan(
  "thread_safe": n_jobs > 1,
  "convolution_mode": "valid" if pad_target else "same",
  "shm_handler": shm_handler,
- "only_unique_rotations": True,
  "aggregate_axis": matching_data._batch_axis(matching_data._batch_mask),
  "n_rotations": matching_data.rotations.shape[0],
+ "inversion_mapping": n_jobs == 1,
  }
- callback_class_args["inversion_mapping"] = n_jobs == 1
  default_callback_args.update(callback_class_args)

  setup = matching_setup(
@@ -254,22 +258,14 @@ def scan(
  fast_ft_shape=inv,
  shm_handler=shm_handler,
  )
- setup["interpolation_order"] = interpolation_order
- setup["template_filter"] = be.to_sharedarr(template_filter, shm_handler)

  matching_data._free_data()
- be.free_cache()
-
  n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
  callback_classes = [
- (
- SharedAnalyzerProxy(
- callback_class,
- default_callback_args,
- shm_handler=shm_handler if n_jobs > 1 else None,
- )
- if callback_class
- else None
+ SharedAnalyzerProxy(
+ callback_class,
+ default_callback_args,
+ shm_handler=shm_handler if n_jobs > 1 else None,
  )
  for _ in range(n_callback_classes)
  ]
@@ -277,35 +273,33 @@ def scan(
  delayed(_wrap_backend(matching_score))(
  backend_name=be._backend_name,
  backend_args=be._backend_args,
+ fast_shape=fwd,
+ fast_ft_shape=inv,
  rotations=rotation,
  callback=callback_classes[index % n_callback_classes],
+ interpolation_order=interpolation_order,
+ template_filter=be.to_sharedarr(template_filter, shm_handler),
+ score_mask=be.to_sharedarr(score_mask, shm_handler),
  **setup,
  )
  for index, rotation in enumerate(matching_data._split_rotations_on_jobs(n_jobs))
  )
- callbacks = [
- callback.result(**default_callback_args)
- for callback in ret[:n_callback_classes]
- if callback
- ]
  be.free_cache()

- if callback_class:
- ret = callback_class.merge(callbacks, **default_callback_args)
- return ret
+ callbacks = [x.result(**default_callback_args) for x in ret[:n_callback_classes]]
+ return callback_class.merge(callbacks, **default_callback_args)


  def scan_subsets(
  matching_data: MatchingData,
  matching_score: Callable,
  matching_setup: Callable,
- callback_class: CallbackClass = None,
+ callback_class: CallbackClass,
  callback_class_args: Dict = {},
  job_schedule: Tuple[int] = (1, 1),
  target_splits: Dict = {},
  template_splits: Dict = {},
  pad_target_edges: bool = False,
- pad_template_filter: bool = True,
  interpolation_order: int = 3,
  jobs_per_callback_class: int = 8,
  backend_name: str = None,
@@ -325,7 +319,7 @@ def scan_subsets(
  Function pointer to setup function.
  matching_score : type
  Function pointer to scoring function.
- callback_class : type, optional
+ callback_class : type
  Analyzer class pointer to operate on computed scores.
  callback_class_args : dict, optional
  Arguments passed to the callback_class. Default is an empty dictionary.
@@ -341,8 +335,6 @@ def scan_subsets(
  See :py:meth:`tme.matching_utils.compute_parallelization_schedule`.
  pad_target_edges : bool, optional
  Pad the target boundaries to avoid edge effects.
- pad_template_filter: bool, optional
- Whether to pad potential template filters to the full convolution shape.
  interpolation_order : int, optional
  Order of spline interpolation for rotations.
  jobs_per_callback_class : int, optional
@@ -424,18 +416,22 @@ def scan_subsets(
  )
  splits = tuple(product(target_splits, template_splits))

+ kwargs = {
+ "matching_data": matching_data,
+ "callback_class": callback_class,
+ "callback_class_args": callback_class_args,
+ }
+
  outer_jobs, inner_jobs = job_schedule
  if be._backend_name == "jax":
  func = be.scan

  corr_scoring = MATCHING_EXHAUSTIVE_REGISTER.get("CORR", (None, None))[1]
  results = func(
- matching_data=matching_data,
  splits=splits,
  n_jobs=outer_jobs,
  rotate_mask=matching_score != corr_scoring,
- callback_class=callback_class,
- callback_class_args=callback_class_args,
+ **kwargs,
  )
  else:
  results = Parallel(n_jobs=outer_jobs, verbose=verbose)(
@@ -443,26 +439,20 @@ def scan_subsets(
  delayed(_wrap_backend(scan))(
  backend_name=be._backend_name,
  backend_args=be._backend_args,
- matching_data=matching_data,
  matching_score=matching_score,
  matching_setup=matching_setup,
  n_jobs=inner_jobs,
- callback_class=callback_class,
- callback_class_args=callback_class_args,
  interpolation_order=interpolation_order,
  pad_target=pad_target_edges,
  gpu_index=index % outer_jobs,
- pad_template_filter=pad_template_filter,
  target_slice=target_split,
  template_slice=template_split,
+ **kwargs,
  )
  for index, (target_split, template_split) in enumerate(splits)
  ]
  )
- matching_data._free_data()
- if callback_class is not None:
- return callback_class.merge(results, **callback_class_args)
- return None
+ return callback_class.merge(results, **callback_class_args)


  def register_matching_exhaustive(
@@ -1104,7 +1104,8 @@ def create_score_object(score: str, **kwargs) -> object:
  Examples
  --------
  >>> from tme import Density
- >>> from tme.matching_utils import create_mask, euler_to_rotationmatrix
+ >>> from tme.mask import create_mask
+ >>> from tme.matching_utils import euler_to_rotationmatrix
  >>> from tme.matching_optimization import CrossCorrelation, optimize_match
  >>> translation, rotation = (5, -2, 7), (5, -10, 2)
  >>> target = create_mask(