pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/match_template.py +28 -39
  2. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/postprocess.py +23 -10
  3. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +95 -24
  4. pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
  5. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/METADATA +5 -5
  6. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/RECORD +53 -46
  7. scripts/extract_candidates.py +118 -99
  8. scripts/match_template.py +28 -39
  9. scripts/postprocess.py +23 -10
  10. scripts/preprocessor_gui.py +95 -24
  11. scripts/pytme_runner.py +644 -190
  12. scripts/refine_matches.py +156 -386
  13. tests/data/.DS_Store +0 -0
  14. tests/data/Blurring/.DS_Store +0 -0
  15. tests/data/Maps/.DS_Store +0 -0
  16. tests/data/Raw/.DS_Store +0 -0
  17. tests/data/Structures/.DS_Store +0 -0
  18. tests/preprocessing/test_utils.py +18 -0
  19. tests/test_backends.py +3 -9
  20. tests/test_density.py +0 -1
  21. tests/test_matching_utils.py +10 -60
  22. tests/test_rotations.py +1 -1
  23. tme/__version__.py +1 -1
  24. tme/analyzer/_utils.py +4 -4
  25. tme/analyzer/aggregation.py +13 -3
  26. tme/analyzer/peaks.py +11 -10
  27. tme/backends/_jax_utils.py +15 -13
  28. tme/backends/_numpyfftw_utils.py +270 -0
  29. tme/backends/cupy_backend.py +5 -44
  30. tme/backends/jax_backend.py +58 -37
  31. tme/backends/matching_backend.py +6 -51
  32. tme/backends/mlx_backend.py +1 -27
  33. tme/backends/npfftw_backend.py +68 -65
  34. tme/backends/pytorch_backend.py +1 -26
  35. tme/density.py +2 -6
  36. tme/extensions.cpython-311-darwin.so +0 -0
  37. tme/filters/ctf.py +22 -21
  38. tme/filters/wedge.py +10 -7
  39. tme/mask.py +341 -0
  40. tme/matching_data.py +7 -19
  41. tme/matching_exhaustive.py +34 -47
  42. tme/matching_optimization.py +2 -1
  43. tme/matching_scores.py +206 -411
  44. tme/matching_utils.py +73 -422
  45. tme/memory.py +1 -1
  46. tme/orientations.py +4 -6
  47. tme/rotations.py +1 -1
  48. pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
  49. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +0 -0
  50. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/preprocess.py +0 -0
  51. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
  52. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +0 -0
  53. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/licenses/LICENSE +0 -0
  54. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
tme/filters/ctf.py CHANGED
@@ -36,7 +36,7 @@ class CTF(ComposableFilter):
36
36
 
37
37
  #: The shape of the to-be created mask.
38
38
  shape: Tuple[int] = None
39
- #: The defocus value in x direction (in units of sampling rate).
39
+ #: The defocus in x direction (in units of sampling rate).
40
40
  defocus_x: Tuple[float] = None
41
41
  #: The tilt angles.
42
42
  angles: Tuple[float] = None
@@ -164,7 +164,7 @@ class CTF(ComposableFilter):
164
164
  shape : tuple of int
165
165
  The shape of the CTF.
166
166
  defocus_x : tuple of float
167
- The defocus value in x direction.
167
+ The defocus in x direction (in units of sampling rate).
168
168
  angles : tuple of float
169
169
  The tilt angles.
170
170
  opening_axis : int, optional
@@ -178,7 +178,7 @@ class CTF(ComposableFilter):
178
178
  defocus_angle : tuple of float, optional
179
179
  The defocus angle in radians, defaults to 0.
180
180
  defocus_y : tuple of float, optional
181
- The defocus value in y direction, defaults to None.
181
+ The defocus in x direction (in units of sampling rate).
182
182
  correct_defocus_gradient : bool, optional
183
183
  Whether to correct defocus gradient, defaults to False.
184
184
  sampling_rate : tuple of float, optional
@@ -219,14 +219,12 @@ class CTF(ComposableFilter):
219
219
  corrected_tilt_axis -= 1
220
220
 
221
221
  for index, angle in enumerate(angles):
222
- defocus_x, defocus_y = defoci_x[index], defoci_y[index]
223
-
224
222
  correction = correct_defocus_gradient and angle is not None
225
223
  chi = create_ctf(
226
224
  angle=angle,
227
225
  shape=ctf_shape,
228
- defocus_x=defocus_x,
229
- defocus_y=defocus_y,
226
+ defocus_x=defoci_x[index],
227
+ defocus_y=defoci_y[index],
230
228
  sampling_rate=sampling_rate,
231
229
  acceleration_voltage=acceleration_voltage[index],
232
230
  correct_defocus_gradient=correction,
@@ -243,12 +241,10 @@ class CTF(ComposableFilter):
243
241
  stack[index] = chi
244
242
 
245
243
  # Avoid contrast inversion
246
- np.negative(stack, out=stack)
244
+ stack = np.negative(stack, out=stack)
247
245
  if flip_phase:
248
- np.abs(stack, out=stack)
249
-
250
- stack = be.to_backend_array(np.squeeze(stack))
251
- return stack
246
+ stack = np.abs(stack, out=stack)
247
+ return be.to_backend_array(np.squeeze(stack))
252
248
 
253
249
 
254
250
  class CTFReconstructed(CTF):
@@ -281,7 +277,7 @@ class CTFReconstructed(CTF):
281
277
  shape : tuple of int
282
278
  The shape of the CTF.
283
279
  defocus_x : tuple of float
284
- The defocus value in x direction.
280
+ The defocus in x direction in units of sampling rate.
285
281
  opening_axis : int, optional
286
282
  The axis around which the wedge is opened, defaults to 2.
287
283
  amplitude_contrast : float, optional
@@ -291,7 +287,7 @@ class CTFReconstructed(CTF):
291
287
  defocus_angle : tuple of float, optional
292
288
  The defocus angle in radians, defaults to 0.
293
289
  defocus_y : tuple of float, optional
294
- The defocus value in y direction, defaults to None.
290
+ The defocus in y direction in units of sampling rate.
295
291
  sampling_rate : tuple of float, optional
296
292
  The sampling rate, defaults to 1.
297
293
  acceleration_voltage : float, optional
@@ -321,18 +317,15 @@ class CTFReconstructed(CTF):
321
317
  defocus_angle=defocus_angle,
322
318
  amplitude_contrast=amplitude_contrast,
323
319
  )
324
- stack = shift_fourier(data=stack, shape_is_real_fourier=False)
325
-
326
320
  # Avoid contrast inversion
327
321
  np.negative(stack, out=stack)
328
322
  if flip_phase:
329
323
  np.abs(stack, out=stack)
330
324
 
331
- stack = be.to_backend_array(np.squeeze(stack))
325
+ stack = shift_fourier(data=stack, shape_is_real_fourier=False)
332
326
  if return_real_fourier:
333
327
  stack = crop_real_fourier(stack)
334
-
335
- return stack
328
+ return be.to_backend_array(np.squeeze(stack))
336
329
 
337
330
 
338
331
  def _from_xml(filename: str) -> Dict:
@@ -436,6 +429,9 @@ def _from_ctffind(filename: str) -> Dict:
436
429
  output[key] = np.array(output[key])
437
430
 
438
431
  output["additional_phase_shift"] = np.degrees(output["additional_phase_shift"])
432
+ cs = output.get("spherical_aberration")
433
+ if cs is not None:
434
+ output["spherical_aberration"] = float(cs) * 1e7
439
435
  return output
440
436
 
441
437
 
@@ -566,7 +562,7 @@ def create_ctf(
566
562
  amplitude_contrast : float, optional
567
563
  Amplitude contrast of microscope, defaults to 0.07.
568
564
  spherical_aberration : float, optional
569
- Spherical aberration of microscope in Angstrom.
565
+ Spherical aberration of microscope in units of sampling rate.
570
566
  angle : float, optional
571
567
  Assume the created CTF is a projection over opening_axis observed at angle.
572
568
  opening_axis : int, optional
@@ -590,10 +586,14 @@ def create_ctf(
590
586
  electron_wavelength = _compute_electron_wavelength(acceleration_voltage)
591
587
  electron_wavelength /= sampling_rate
592
588
  aberration = (spherical_aberration / sampling_rate) * electron_wavelength**2
589
+
590
+ defocus_x = defocus_x / sampling_rate if defocus_x is not None else None
591
+ defocus_y = defocus_y / sampling_rate if defocus_y is not None else None
593
592
  if correct_defocus_gradient or defocus_y is not None:
594
593
  if len(shape) < 2:
595
594
  raise ValueError(f"Length of shape needs to be at least 2, got {shape}")
596
595
 
596
+ # Axial distance from grid center in multiples of sampling rate
597
597
  sampling = tuple(float(x) for x in np.divide(sampling_rate, shape))
598
598
  grid = fftfreqn(
599
599
  shape=shape,
@@ -619,6 +619,7 @@ def create_ctf(
619
619
  defocus_sum = np.add(defocus_x, defocus_y)
620
620
  defocus_difference = np.subtract(defocus_x, defocus_y)
621
621
 
622
+ # Reusing grid, but in principle pure frequencies would suffice
622
623
  angular_grid = np.arctan2(grid[1], grid[0])
623
624
  defocus_difference = np.multiply(
624
625
  defocus_difference,
@@ -627,7 +628,7 @@ def create_ctf(
627
628
  defocus_x = np.add(defocus_sum, defocus_difference)
628
629
  defocus_x *= 0.5
629
630
 
630
- frequency_grid = fftfreqn(shape, sampling_rate=True, compute_euclidean_norm=True)
631
+ frequency_grid = fftfreqn(shape, sampling_rate=1, compute_euclidean_norm=True)
631
632
  if angle is not None and opening_axis is not None and full_shape is not None:
632
633
  frequency_grid = frequency_grid_at_angle(
633
634
  shape=full_shape,
tme/filters/wedge.py CHANGED
@@ -15,7 +15,7 @@ import numpy as np
15
15
  from ..types import NDArray
16
16
  from ..backends import backend as be
17
17
  from .compose import ComposableFilter
18
- from ..matching_utils import centered
18
+ from ..matching_utils import center_slice
19
19
  from ..rotations import euler_to_rotationmatrix
20
20
  from ..parser import XMLParser, StarParser, MDOCParser
21
21
  from ._utils import (
@@ -207,11 +207,10 @@ class Wedge(ComposableFilter):
207
207
  )
208
208
  sigma = np.sqrt(self.weights[index] * 4 / (8 * np.pi**2))
209
209
  sigma = -2 * np.pi**2 * sigma**2
210
- np.square(frequency_grid, out=frequency_grid)
211
- np.multiply(sigma, frequency_grid, out=frequency_grid)
212
- np.exp(frequency_grid, out=frequency_grid)
213
- np.multiply(frequency_grid, np.cos(np.radians(angle)), out=frequency_grid)
214
- wedges[index] = frequency_grid
210
+ frequency_grid = np.square(frequency_grid, out=frequency_grid)
211
+ frequency_grid = np.multiply(sigma, frequency_grid, out=frequency_grid)
212
+ frequency_grid = np.exp(frequency_grid, out=frequency_grid)
213
+ wedges[index] = np.multiply(frequency_grid, np.cos(np.radians(angle)))
215
214
 
216
215
  return wedges
217
216
 
@@ -490,7 +489,11 @@ class WedgeReconstructed:
490
489
  )
491
490
  wedge_volume += plane_rotated * weights[index]
492
491
 
493
- wedge_volume = centered(wedge_volume, (shape[opening_axis], shape[tilt_axis]))
492
+ subset = center_slice(
493
+ wedge_volume.shape, (shape[opening_axis], shape[tilt_axis])
494
+ )
495
+ wedge_volume = wedge_volume[subset]
496
+
494
497
  np.fmin(wedge_volume, np.max(weights), wedge_volume)
495
498
 
496
499
  if opening_axis > tilt_axis:
tme/mask.py ADDED
@@ -0,0 +1,341 @@
1
+ """
2
+ Utility functions for generating template matching masks.
3
+
4
+ Copyright (c) 2023 European Molecular Biology Laboratory
5
+
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
+ """
8
+
9
+ import numpy as np
10
+ from typing import Tuple, Optional
11
+
12
+ from .types import NDArray
13
+ from scipy.ndimage import gaussian_filter
14
+ from .matching_utils import rigid_transform
15
+
16
+ __all__ = ["elliptical_mask", "tube_mask", "box_mask", "membrane_mask"]
17
+
18
+
19
def elliptical_mask(
    shape: Tuple[int],
    radius: Tuple[float],
    center: Optional[Tuple[float]] = None,
    orientation: Optional[NDArray] = None,
    sigma_decay: float = 0.0,
    cutoff_sigma: float = 3,
    **kwargs,
) -> NDArray:
    """
    Creates an ellipsoidal mask.

    Parameters
    ----------
    shape : tuple of ints
        Shape of the mask to be created.
    radius : float or tuple of floats
        Radius of the mask, either a single value used for all axes or one
        value per axis.
    center : float or tuple of floats, optional
        Center of the mask, either a single value used for all axes or one
        value per axis, defaults to shape // 2.
    orientation : NDArray, optional
        Orientation of the mask as rotation matrix with shape (d, d).
    sigma_decay : float, optional
        Standard deviation of a Gaussian falloff applied outside the ellipsoid
        surface. It is scaled by the mean radius, so for a sphere it is given
        in voxels. Defaults to 0, which yields a binary mask.
    cutoff_sigma : float, optional
        Falloff values below exp(-cutoff_sigma**2 / 2) are set to zero. Only
        used if ``sigma_decay`` > 0, defaults to 3.
    **kwargs
        Additional keyword arguments are accepted and ignored.

    Returns
    -------
    NDArray
        The created ellipsoidal mask. Integer-valued (0/1) if ``sigma_decay``
        is 0, floating point otherwise.

    Raises
    ------
    ValueError
        If the length of center or radius is neither one nor the length of shape.

    Examples
    --------
    >>> from tme.mask import elliptical_mask
    >>> mask = elliptical_mask(shape=(20,20), radius=(5,5), center=(10,10))
    """
    shape = np.asarray(shape).astype(int)
    n = shape.size

    radius = np.asarray(radius).ravel()
    if center is None:
        center = np.divide(shape, 2).astype(int)
    center = np.asarray(center, dtype=np.float32).ravel()

    # Validate before broadcasting: repeating by n // size would otherwise
    # silently expand e.g. two radii onto four axes, or divide by zero for
    # empty inputs.
    if radius.size not in (1, n):
        raise ValueError("Length of radius has to be either one or match shape.")
    if center.size not in (1, n):
        raise ValueError("Length of center has to be either one or match shape.")
    if radius.size == 1:
        radius = np.repeat(radius, n)
    if center.size == 1:
        center = np.repeat(center, n)

    # Reshape to (n, 1, ..., 1) so they broadcast against np.indices output
    center = center.reshape((-1,) + (1,) * n)
    radius = radius.reshape((-1,) + (1,) * n)

    # Per-axis offset of every voxel from the center, shape (n, *shape)
    indices = np.indices(shape, dtype=np.float32) - center
    if orientation is not None:
        return_shape = indices.shape
        indices = indices.reshape(n, -1)
        rigid_transform(
            coordinates=indices,
            rotation_matrix=np.asarray(orientation),
            out=indices,
            translation=np.zeros(n),
            use_geometric_center=False,
        )
        indices = indices.reshape(*return_shape)

    # Radius-normalized distance: values <= 1 lie inside the ellipsoid
    dist = np.linalg.norm(indices / radius, axis=0)
    if sigma_decay > 0:
        sigma_decay = 2 * (sigma_decay / np.mean(radius)) ** 2
        mask = np.maximum(0, dist - 1)
        mask = np.exp(-(mask**2) / sigma_decay)
        mask *= mask > np.exp(-(cutoff_sigma**2) / 2)
    else:
        mask = (dist <= 1).astype(int)
    return mask
97
+
98
+
99
def box_mask(
    shape: Tuple[int],
    center: Tuple[int],
    size: Tuple[int],
    sigma_decay: float = 0.0,
    cutoff_sigma: float = 3.0,
    **kwargs,
) -> np.ndarray:
    """
    Creates a box mask centered around the provided center point.

    Parameters
    ----------
    shape : tuple of ints
        Shape of the output array.
    center : tuple of ints
        Center point coordinates of the box.
    size : tuple of ints
        Nominal side length of the box along each axis. Along each axis the
        box covers indices ``center - size // 2`` up to and including
        ``center + size // 2 + size % 2``, clipped to ``shape``.
    sigma_decay : float, optional
        Standard deviation of the Gaussian filter used to soften the box
        edges, defaults to 0 (binary mask).
    cutoff_sigma : float, optional
        Truncation of the Gaussian filter in standard deviations. Values below
        exp(-cutoff_sigma**2 / 2) are set to zero. Defaults to 3.
    **kwargs
        Additional keyword arguments are accepted and ignored.

    Returns
    -------
    NDArray
        The created box mask as floating point array.

    Raises
    ------
    ValueError
        If ``shape``, ``center`` and ``size`` do not have the same length.
    """
    if not len(shape) == len(center) == len(size):
        raise ValueError("The length of shape, center, and size must be consistent.")

    shape = tuple(int(x) for x in shape)
    center, size = np.array(center, dtype=int), np.array(size, dtype=int)

    half_size = size // 2
    starts = np.maximum(center - half_size, 0)
    # NOTE(review): the trailing + 1 makes an unclipped box size + 1 voxels
    # wide (and off-center for odd sizes). Kept for backward compatibility;
    # confirm whether this is intended.
    stops = np.minimum(center + half_size + np.mod(size, 2) + 1, shape)
    slice_indices = tuple(slice(*coord) for coord in zip(starts, stops))

    out = np.zeros(shape)
    out[slice_indices] = 1

    if sigma_decay > 0:
        # Keep the box at 1 and fill its surroundings with the blurred box,
        # then drop values below the cutoff.
        mask_filter = gaussian_filter(
            out.astype(np.float32), sigma=sigma_decay, truncate=cutoff_sigma
        )
        out = np.add(out, (1 - out) * mask_filter)
        out *= out > np.exp(-(cutoff_sigma**2) / 2)
    return out
151
+
152
+
153
def tube_mask(
    shape: Tuple[int],
    symmetry_axis: int,
    center: Tuple[int],
    inner_radius: float,
    outer_radius: float,
    height: int,
    sigma_decay: float = 0.0,
    cutoff_sigma: float = 3.0,
    **kwargs,
) -> NDArray:
    """
    Creates a tube mask.

    Parameters
    ----------
    shape : tuple of ints
        Shape of the mask to be created.
    symmetry_axis : int
        The axis of symmetry for the tube. Negative values count from the end.
    center : tuple of ints
        Center of the tube, one value per axis of ``shape``.
    inner_radius : float
        Inner radius of the tube; 0 yields a solid cylinder.
    outer_radius : float
        Outer radius of the tube.
    height : int
        Number of slices along ``symmetry_axis`` covered by the tube, centered
        on ``center[symmetry_axis]`` and clipped to ``shape``.
    sigma_decay : float, optional
        Standard deviation of the Gaussian falloff at the tube walls, passed
        to ``elliptical_mask``. Defaults to 0 (binary mask).
    cutoff_sigma : float, optional
        Falloff cutoff in standard deviations, passed to ``elliptical_mask``.
        Defaults to 3.
    **kwargs
        Additional keyword arguments are accepted and ignored.

    Returns
    -------
    NDArray
        The created tube mask as floating point array.

    Raises
    ------
    ValueError
        If ``inner_radius`` is larger than ``outer_radius``.
        If ``symmetry_axis`` is out of bounds for ``shape``.
        If ``height`` is larger than ``shape[symmetry_axis]``.
        If ``center`` and ``shape`` do not have the same length.
    """
    ndim = len(shape)
    if inner_radius > outer_radius:
        raise ValueError("inner_radius should be smaller than outer_radius.")

    # Validate the axis before using it to index shape; an out-of-range axis
    # would otherwise surface as an IndexError.
    if not -ndim <= symmetry_axis < ndim:
        raise ValueError(
            f"symmetry_axis has to be in [{-ndim}, {ndim}), got {symmetry_axis}."
        )
    symmetry_axis = int(symmetry_axis) % ndim

    if height > shape[symmetry_axis]:
        raise ValueError(f"Height can be no larger than {shape[symmetry_axis]}.")

    if len(center) != ndim:
        raise ValueError("shape and center need to have the same length.")

    shape = tuple(int(x) for x in shape)
    circle_shape = tuple(b for ix, b in enumerate(shape) if ix != symmetry_axis)
    circle_center = tuple(b for ix, b in enumerate(center) if ix != symmetry_axis)

    inner_circle = np.zeros(circle_shape)
    outer_circle = np.zeros_like(inner_circle)
    if inner_radius > 0:
        inner_circle = elliptical_mask(
            shape=circle_shape,
            radius=inner_radius,
            center=circle_center,
            sigma_decay=sigma_decay,
            cutoff_sigma=cutoff_sigma,
        )
    if outer_radius > 0:
        outer_circle = elliptical_mask(
            shape=circle_shape,
            radius=outer_radius,
            center=circle_center,
            sigma_decay=sigma_decay,
            cutoff_sigma=cutoff_sigma,
        )
    # Ring cross-section, broadcast along symmetry_axis when assigned below
    circle = np.expand_dims(outer_circle - inner_circle, axis=symmetry_axis)

    axial_center = center[symmetry_axis]
    start_idx = int(axial_center - height // 2)
    stop_idx = int(axial_center + height // 2 + height % 2)
    start_idx, stop_idx = max(start_idx, 0), min(stop_idx, shape[symmetry_axis])

    slice_indices = tuple(
        slice(start_idx, stop_idx) if i == symmetry_axis else slice(None)
        for i in range(ndim)
    )
    tube = np.zeros(shape)
    tube[slice_indices] = circle
    return tube
245
+
246
+
247
def membrane_mask(
    shape: Tuple[int],
    radius: float,
    thickness: float,
    separation: float,
    symmetry_axis: int = 2,
    center: Optional[Tuple[float]] = None,
    sigma_decay: float = 0.5,
    cutoff_sigma: float = 3,
    **kwargs,
) -> NDArray:
    """
    Creates a membrane mask consisting of two parallel disks with Gaussian intensity profile.

    The mask is the broadcast product of a flat disk in the plane orthogonal
    to ``symmetry_axis`` and a one-dimensional height profile along it.

    Parameters
    ----------
    shape : tuple of ints
        Shape of the mask to be created.
    radius : float
        Radius of the membrane disks.
    thickness : float
        Thickness of each disk in the membrane. Each leaflet is modeled as a
        Gaussian with standard deviation ``thickness / 3``.
    separation : float
        Distance between the centers of the two disks.
    symmetry_axis : int, optional
        The axis perpendicular to the membrane disks, defaults to 2. Negative
        values count from the end.
    center : float or tuple of floats, optional
        Center of the membrane (midpoint between the two disks). Its in-plane
        components position the disk; its component along ``symmetry_axis``
        positions the leaflets. Defaults to shape / 2, in which case the disk
        uses the default center of ``elliptical_mask``.
    sigma_decay : float, optional
        Standard deviation of the Gaussian falloff at the disk rim, passed to
        ``elliptical_mask``. Defaults to 0.5.
    cutoff_sigma : float, optional
        Cutoff in standard deviations for both the disk rim and the height
        profile, defaults to 3.
    **kwargs
        Additional keyword arguments are accepted and ignored.

    Returns
    -------
    NDArray
        The created membrane mask with Gaussian intensity profile.

    Raises
    ------
    ValueError
        If ``thickness`` is not positive.
        If ``separation`` is negative.
        If ``symmetry_axis`` is out of bounds.
        If the length of ``center`` is neither one nor the length of ``shape``.

    Examples
    --------
    >>> from tme.mask import membrane_mask
    >>> mask = membrane_mask(shape=(50,50,50), radius=10, thickness=2, separation=15)
    """
    shape = np.asarray(shape, dtype=int)
    ndim = shape.size

    # thickness ends up in the Gaussian denominator; 0 would produce NaNs
    if thickness <= 0:
        raise ValueError("thickness must be positive.")
    if separation < 0:
        raise ValueError("separation must be non-negative.")
    if not -ndim <= symmetry_axis < ndim:
        raise ValueError(
            f"symmetry_axis has to be in [{-ndim}, {ndim}), got {symmetry_axis}."
        )
    symmetry_axis = int(symmetry_axis) % ndim

    custom_center = center is not None
    if center is None:
        center = np.divide(shape, 2).astype(float)
    center = np.asarray(center, dtype=np.float32).ravel()
    if center.size not in (1, ndim):
        raise ValueError("Length of center has to be either one or match shape.")
    if center.size == 1:
        center = np.repeat(center, ndim)

    in_plane = [i for i in range(ndim) if i != symmetry_axis]
    disk_mask = elliptical_mask(
        shape=[int(shape[i]) for i in in_plane],
        radius=radius,
        # Forward only an explicit center, so the default placement matches
        # elliptical_mask's own default.
        center=center[in_plane] if custom_center else None,
        sigma_decay=sigma_decay,
        cutoff_sigma=cutoff_sigma,
    )

    # Axial profile: maximum of two Gaussian leaflets at +/- separation / 2
    axial_coord = np.arange(shape[symmetry_axis]) - center[symmetry_axis]
    leaflet_sigma = thickness / 3
    cutoff_threshold = np.exp(-(cutoff_sigma**2) / 2)
    height_profile = np.zeros((shape[symmetry_axis],), dtype=np.float32)
    for leaflet_pos in (-separation / 2, separation / 2):
        leaflet_profile = np.exp(
            -((axial_coord - leaflet_pos) ** 2) / (2 * leaflet_sigma**2)
        )
        leaflet_profile *= leaflet_profile > cutoff_threshold
        height_profile = np.maximum(height_profile, leaflet_profile)

    # Broadcast: disk is constant along symmetry_axis, profile constant in-plane
    disk_mask = disk_mask.reshape(
        [x if i != symmetry_axis else 1 for i, x in enumerate(shape)]
    )
    height_profile = height_profile.reshape(
        [1 if i != symmetry_axis else x for i, x in enumerate(shape)]
    )

    return disk_mask * height_profile
tme/matching_data.py CHANGED
@@ -128,11 +128,8 @@ class MatchingData:
128
128
  slice_start = np.array([x.start for x in arr_slice], dtype=int)
129
129
  slice_stop = np.array([x.stop for x in arr_slice], dtype=int)
130
130
 
131
- # We are deviating from our typical right_pad + mod here
132
- # because cropping from full convolution mode to target shape
133
- # is defined from the perspective of the origin
134
- right_pad = np.divide(padding, 2).astype(int)
135
- left_pad = np.add(right_pad, np.mod(padding, 2))
131
+ left_pad = np.divide(padding, 2).astype(int)
132
+ right_pad = np.add(left_pad, np.mod(padding, 2))
136
133
 
137
134
  data_voxels_left = np.minimum(slice_start, left_pad)
138
135
  data_voxels_right = np.minimum(
@@ -253,7 +250,7 @@ class MatchingData:
253
250
  target_offset[mask] = [x.start for x in target_slice]
254
251
  mask = np.subtract(1, self._target_batch).astype(bool)
255
252
  template_offset = np.zeros(len(self._output_template_shape), dtype=int)
256
- template_offset[mask] = [x.start for x, b in zip(template_slice, mask) if b]
253
+ template_offset[mask] = [x.start for x in template_slice]
257
254
 
258
255
  translation_offset = tuple(x for x in target_offset)
259
256
 
@@ -299,7 +296,7 @@ class MatchingData:
299
296
 
300
297
  def set_matching_dimension(self, target_dim: int = None, template_dim: int = None):
301
298
  """
302
- Sets matching dimensions for target and template.
299
+ Sets matching batch dimensions for target and template.
303
300
 
304
301
  Parameters
305
302
  ----------
@@ -490,8 +487,8 @@ class MatchingData:
490
487
  def _fourier_padding(
491
488
  target_shape: Tuple[int],
492
489
  template_shape: Tuple[int],
493
- pad_target: bool = False,
494
490
  batch_mask: Tuple[int] = None,
491
+ **kwargs,
495
492
  ) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
496
493
  if batch_mask is None:
497
494
  batch_mask = np.zeros_like(template_shape)
@@ -503,7 +500,7 @@ class MatchingData:
503
500
 
504
501
  # Avoid padding batch dimensions
505
502
  pad_shape = np.maximum(target_shape, template_shape)
506
- pad_shape = np.maximum(target_shape, np.multiply(1 - batch_mask, pad_shape))
503
+ pad_shape = np.maximum(pad_shape, np.multiply(1 - batch_mask, pad_shape))
507
504
  ret = be.compute_convolution_shapes(pad_shape, fourier_pad)
508
505
  conv_shape, fast_shape, fast_ft_shape = ret
509
506
 
@@ -525,21 +522,13 @@ class MatchingData:
525
522
  shape_shift = np.multiply(np.add(shape_shift, offset), shape_mask)
526
523
  fourier_shift = np.subtract(fourier_shift, shape_shift).astype(int)
527
524
 
528
- if pad_target:
529
- fourier_shift = np.subtract(fourier_shift, np.subtract(1, template_mod))
530
-
531
525
  fourier_shift = tuple(np.multiply(fourier_shift, 1 - batch_mask).astype(int))
532
526
  return tuple(conv_shape), tuple(fast_shape), tuple(fast_ft_shape), fourier_shift
533
527
 
534
- def fourier_padding(self, pad_target: bool = False) -> Tuple:
528
+ def fourier_padding(self, **kwargs) -> Tuple:
535
529
  """
536
530
  Computes efficient shape four Fourier transforms and potential associated shifts.
537
531
 
538
- Parameters
539
- ----------
540
- pad_target : bool, optional
541
- Whether the target has been padded to the full convolution shape.
542
-
543
532
  Returns
544
533
  -------
545
534
  Tuple[tuple of int, tuple of int, tuple of int, tuple of int]
@@ -553,7 +542,6 @@ class MatchingData:
553
542
  target_shape=be.to_numpy_array(self._output_target_shape),
554
543
  template_shape=be.to_numpy_array(self._output_template_shape),
555
544
  batch_mask=be.to_numpy_array(self._batch_mask),
556
- pad_target=pad_target,
557
545
  )
558
546
 
559
547
  def computation_schedule(