pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.1__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/match_template.py +473 -140
  2. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/postprocess.py +107 -49
  3. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/preprocessor_gui.py +4 -1
  4. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/METADATA +2 -2
  5. pytme-0.2.1.dist-info/RECORD +73 -0
  6. scripts/extract_candidates.py +117 -85
  7. scripts/match_template.py +473 -140
  8. scripts/match_template_filters.py +458 -169
  9. scripts/postprocess.py +107 -49
  10. scripts/preprocessor_gui.py +4 -1
  11. scripts/refine_matches.py +364 -160
  12. tme/__version__.py +1 -1
  13. tme/analyzer.py +278 -148
  14. tme/backends/__init__.py +1 -0
  15. tme/backends/cupy_backend.py +20 -13
  16. tme/backends/jax_backend.py +218 -0
  17. tme/backends/matching_backend.py +25 -10
  18. tme/backends/mlx_backend.py +13 -9
  19. tme/backends/npfftw_backend.py +22 -12
  20. tme/backends/pytorch_backend.py +20 -9
  21. tme/density.py +85 -64
  22. tme/extensions.cpython-311-darwin.so +0 -0
  23. tme/matching_data.py +86 -60
  24. tme/matching_exhaustive.py +245 -166
  25. tme/matching_optimization.py +137 -69
  26. tme/matching_utils.py +1 -1
  27. tme/orientations.py +175 -55
  28. tme/preprocessing/__init__.py +2 -0
  29. tme/preprocessing/_utils.py +188 -0
  30. tme/preprocessing/composable_filter.py +31 -0
  31. tme/preprocessing/compose.py +51 -0
  32. tme/preprocessing/frequency_filters.py +378 -0
  33. tme/preprocessing/tilt_series.py +1017 -0
  34. tme/preprocessor.py +17 -7
  35. tme/structure.py +4 -1
  36. pytme-0.2.0b0.dist-info/RECORD +0 -66
  37. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/estimate_ram_usage.py +0 -0
  38. {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/preprocess.py +0 -0
  39. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/LICENSE +0 -0
  40. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/WHEEL +0 -0
  41. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/entry_points.txt +0 -0
  42. {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/top_level.txt +0 -0
tme/density.py CHANGED
@@ -40,12 +40,12 @@ from .backends import NumpyFFTWBackend
40
40
 
41
41
  class Density:
42
42
  """
43
- Contains electron density data and implements operations on it.
43
+ Abstract representation of N-dimensional densities.
44
44
 
45
45
  Parameters
46
46
  ----------
47
47
  data : NDArray
48
- Electron density data.
48
+ Array of data values.
49
49
  origin : NDArray, optional
50
50
  Origin of the coordinate system. Defaults to zero.
51
51
  sampling_rate : NDArray, optional
@@ -267,9 +267,9 @@ class Density:
267
267
  # nx := column; ny := row; nz := section
268
268
  start = np.array(
269
269
  [
270
- int(mrc.header["nxstart"]),
271
- int(mrc.header["nystart"]),
272
- int(mrc.header["nzstart"]),
270
+ mrc.header["nzstart"],
271
+ mrc.header["nystart"],
272
+ mrc.header["nxstart"],
273
273
  ]
274
274
  )
275
275
 
@@ -291,17 +291,9 @@ class Density:
291
291
  sampling_rate = mrc.voxel_size.astype(
292
292
  [("x", "<f4"), ("y", "<f4"), ("z", "<f4")]
293
293
  ).view(("<f4", 3))
294
- sampling_rate = sampling_rate[::-1]
295
- sampling_rate = np.array(sampling_rate)
294
+ sampling_rate = np.array(sampling_rate)[::-1]
296
295
 
297
- if np.all(origin == start):
298
- pass
299
- elif np.all(origin == 0) and not np.all(start == 0):
300
- origin = np.multiply(start, sampling_rate)
301
- elif np.all(
302
- np.abs(origin.astype(int))
303
- != np.abs((start * sampling_rate).astype(int))
304
- ) and not np.all(start == 0):
296
+ if np.allclose(origin, 0) and not np.allclose(start, 0):
305
297
  origin = np.multiply(start, sampling_rate)
306
298
 
307
299
  extended_header = mrc.header.nsymbt
@@ -317,7 +309,7 @@ class Density:
317
309
  if use_memmap:
318
310
  warnings.warn(
319
311
  f"Cannot open gzipped file {filename} as memmap."
320
- f" Please gunzip {filename} to use memmap functionality."
312
+ f" Please run 'gunzip {filename}' to use memmap functionality."
321
313
  )
322
314
  use_memmap = False
323
315
 
@@ -878,7 +870,7 @@ class Density:
878
870
  compression = "gzip" if gzip else None
879
871
  with mrcfile.new(filename, overwrite=True, compression=compression) as mrc:
880
872
  mrc.set_data(self.data.astype("float32"))
881
- mrc.header.nzstart, mrc.header.nystart, mrc.header.nxstart = np.ceil(
873
+ mrc.header.nzstart, mrc.header.nystart, mrc.header.nxstart = np.rint(
882
874
  np.divide(self.origin, self.sampling_rate)
883
875
  )
884
876
  # mrcfile library expects origin to be in xyz format
@@ -1529,11 +1521,10 @@ class Density:
1529
1521
 
1530
1522
  Returns
1531
1523
  -------
1532
- Density
1533
- A copy of the class instance whose data center of mass is in the
1534
- center of the data array.
1524
+ :py:class:`Density`
1525
+ A centered copy of the current class instance.
1535
1526
  NDArray
1536
- The coordinate translation.
1527
+ The offset between array center and center of mass.
1537
1528
 
1538
1529
  See Also
1539
1530
  --------
@@ -1545,48 +1536,37 @@ class Density:
1545
1536
  --------
1546
1537
  :py:meth:`Density.centered` returns a tuple containing a centered version
1547
1538
  of the current :py:class:`Density` instance, as well as an array with
1548
- translations. The translation corresponds to the shift that between the
1549
- center of mass and the center of the internal :py:attr:`Density.data` attribute.
1539
+ translations. The translation corresponds to the shift between the original and
1540
+ current center of mass.
1550
1541
 
1551
1542
  >>> import numpy as np
1552
1543
  >>> from tme import Density
1553
- >>> dens = Density(np.ones((5,5)))
1544
+ >>> dens = Density(np.ones((5,5,5)))
1554
1545
  >>> centered_dens, translation = dens.centered(0)
1555
1546
  >>> translation
1556
- array([-4.4408921e-16, 4.4408921e-16])
1547
+ array([0., 0., 0.])
1557
1548
 
1558
1549
  :py:meth:`Density.centered` extended the :py:attr:`Density.data` attribute
1559
1550
  of the current :py:class:`Density` instance and modified
1560
1551
  :py:attr:`Density.origin` accordingly.
1561
1552
 
1562
1553
  >>> centered_dens
1563
- Origin: (-1.0, -1.0), sampling_rate: (1, 1), Shape: (7, 7)
1564
-
1565
- :py:meth:`Density.centered` achieves centering via zero-padding the
1566
- internal :py:attr:`Density.data` attribute:
1567
-
1568
- >>> centered_dens.data
1569
- array([[0., 0., 0., 0., 0., 0., 0.],
1570
- [0., 1., 1., 1., 1., 1., 0.],
1571
- [0., 1., 1., 1., 1., 1., 0.],
1572
- [0., 1., 1., 1., 1., 1., 0.],
1573
- [0., 1., 1., 1., 1., 1., 0.],
1574
- [0., 1., 1., 1., 1., 1., 0.],
1575
- [0., 0., 0., 0., 0., 0., 0.]])
1576
-
1577
- `centered_dens` is sufficiently large to represent all rotations that
1578
- could be applied to the :py:attr:`Density.data` attribute. Lets look
1579
- at a random rotation obtained from
1554
+ Origin: (-2.0, -2.0, -2.0), sampling_rate: (1, 1, 1), Shape: (9, 9, 9)
1555
+
1556
+ :py:meth:`Density.centered` achieves centering via zero-padding and
1557
+ rigid-transform of the internal :py:attr:`Density.data` attribute.
1558
+ `centered_dens` is sufficiently large to represent all rotations of the
1559
+ :py:attr:`Density.data` attribute, such as ones obtained from
1580
1560
  :py:meth:`tme.matching_utils.get_rotation_matrices`.
1581
1561
 
1582
1562
  >>> from tme.matching_utils import get_rotation_matrices
1583
- >>> rotation_matrix = get_rotation_matrices(dim = 2 ,angular_sampling = 10)[0]
1563
+ >>> rotation_matrix = get_rotation_matrices(dim = 3 ,angular_sampling = 10)[0]
1584
1564
  >>> rotated_centered_dens = centered_dens.rigid_transform(
1585
1565
  >>> rotation_matrix = rotation_matrix,
1586
1566
  >>> order = None
1587
1567
  >>> )
1588
1568
  >>> print(centered_dens.data.sum(), rotated_centered_dens.data.sum())
1589
- 25.000000000000007 25.000000000000007
1569
+ 125.0 125.0
1590
1570
 
1591
1571
  """
1592
1572
  ret = self.copy()
@@ -1595,10 +1575,11 @@ class Density:
1595
1575
  ret.adjust_box(box)
1596
1576
 
1597
1577
  new_shape = np.maximum(ret.shape, self.shape)
1578
+ new_shape = np.add(new_shape, 1 - np.mod(new_shape, 2))
1598
1579
  ret.pad(new_shape)
1599
1580
 
1600
1581
  center = self.center_of_mass(ret.data, cutoff)
1601
- shift = np.subtract(np.divide(ret.shape, 2), center)
1582
+ shift = np.subtract(np.divide(np.subtract(ret.shape, 1), 2), center)
1602
1583
 
1603
1584
  ret = ret.rigid_transform(
1604
1585
  translation=shift,
@@ -1606,9 +1587,9 @@ class Density:
1606
1587
  use_geometric_center=False,
1607
1588
  order=1,
1608
1589
  )
1609
- offset = np.subtract(center, self.center_of_mass(ret.data, cutoff))
1610
1590
 
1611
- return ret, offset
1591
+ shift = np.subtract(center, self.center_of_mass(ret.data, cutoff))
1592
+ return ret, shift
1612
1593
 
1613
1594
  @classmethod
1614
1595
  def rotate_array(
@@ -1731,31 +1712,45 @@ class Density:
1731
1712
  Parameters
1732
1713
  ----------
1733
1714
  rotation_matrix : NDArray
1734
- Rotation matrix to apply to the `Density` instance.
1715
+ Rotation matrix to apply.
1735
1716
  translation : NDArray
1736
- Translation to apply to the `Density` instance.
1717
+ Translation to apply.
1737
1718
  order : int, optional
1738
- Order of spline interpolation.
1719
+ Interpolation order to use. Default is 3, has to be in range 0-5.
1739
1720
  use_geometric_center : bool, optional
1740
- Whether to use geometric or coordinate center. If False,
1741
- class instance should be centered using :py:meth:`Density.centered`.
1721
+ Use geometric or mass center as rotation center.
1742
1722
 
1743
1723
  Returns
1744
1724
  -------
1745
1725
  Density
1746
- The transformed instance of :py:class:`tme.density.Density`.
1726
+ The transformed instance of :py:class:`Density`.
1747
1727
 
1748
1728
  Examples
1749
1729
  --------
1730
+ Define the :py:class:`Density` instance
1731
+
1750
1732
  >>> import numpy as np
1751
- >>> rotation_matrix = np.eye(3)
1752
- >>> rotation_matrix[0] = -1
1753
- >>> density.rotate(rotation_matrix = rotation_matrix)
1733
+ >>> from tme import Density
1734
+ >>> dens = Density(np.arange(9).reshape(3,3).astype(np.float32))
1735
+ >>> dens, translation = dens.centered(0)
1736
+
1737
+ and apply the rotation, in this case a mirror around the z-axis
1738
+
1739
+ >>> rotation_matrix = np.eye(dens.data.ndim)
1740
+ >>> rotation_matrix[0, 0] = -1
1741
+ >>> dens_transform = dens.rigid_transform(rotation_matrix = rotation_matrix)
1742
+ >>> dens_transform.data
1743
+ array([[0. , 0. , 0. , 0. , 0. ],
1744
+ [0.5 , 3.0833333 , 3.5833333 , 3.3333333 , 0. ],
1745
+ [0.75 , 4.6666665 , 5.6666665 , 5.4166665 , 0. ],
1746
+ [0.25 , 1.6666666 , 2.6666667 , 2.9166667 , 0. ],
1747
+ [0. , 0.08333334, 0.5833333 , 0.8333333 , 0. ]],
1748
+ dtype=float32)
1754
1749
 
1755
1750
  Notes
1756
1751
  -----
1757
- :py:meth:`Density.rigid_transform` that the internal data array is
1758
- sufficiently sized to accomodate the transform.
1752
+ This function assumes the internal :py:attr:`Density.data` attribute is
1753
+ sufficiently sized to hold the transformation.
1759
1754
 
1760
1755
  See Also
1761
1756
  --------
@@ -2153,6 +2148,7 @@ class Density:
2153
2148
  cutoff_target: float = 0,
2154
2149
  cutoff_template: float = 0,
2155
2150
  scoring_method: str = "NormalizedCrossCorrelation",
2151
+ **kwargs,
2156
2152
  ) -> Tuple["Density", NDArray, NDArray, NDArray]:
2157
2153
  """
2158
2154
  Aligns two :py:class:`Density` instances target and template and returns
@@ -2177,6 +2173,9 @@ class Density:
2177
2173
  The scoring method to use for alignment. See
2178
2174
  :py:class:`tme.matching_optimization.create_score_object` for available methods,
2179
2175
  by default "NormalizedCrossCorrelation".
2176
+ kwargs : dict, optional
2177
+ Optional keyword arguments passed to
2178
+ :py:meth:`tme.matching_optimization.optimize_match`.
2180
2179
 
2181
2180
  Returns
2182
2181
  -------
@@ -2188,8 +2187,18 @@ class Density:
2188
2187
  -----
2189
2188
  No densities below cutoff_template are present in the returned Density object.
2190
2189
  """
2190
+ from .matching_exhaustive import normalize_under_mask
2191
2191
  from .matching_optimization import optimize_match, create_score_object
2192
2192
 
2193
+ template_mask = template.empty
2194
+ template_mask.data[:] = 1
2195
+
2196
+ normalize_under_mask(
2197
+ template=template.data,
2198
+ mask=template_mask.data,
2199
+ mask_intensity=template_mask.data.sum(),
2200
+ )
2201
+
2193
2202
  target_sampling_rate = np.array(target.sampling_rate)
2194
2203
  template_sampling_rate = np.array(template.sampling_rate)
2195
2204
 
@@ -2224,16 +2233,21 @@ class Density:
2224
2233
  ).astype(int)
2225
2234
  template_coordinates += mass_center_difference[:, None]
2226
2235
 
2236
+ coordinates_mask = template_mask.to_pointcloud()
2237
+ coordinates_mask = coordinates_mask * template_scaling[:, None]
2238
+ coordinates_mask += mass_center_difference[:, None]
2239
+
2227
2240
  score_object = create_score_object(
2228
2241
  score=scoring_method,
2229
2242
  target=target.data,
2230
2243
  template_coordinates=template_coordinates,
2244
+ template_mask_coordinates=coordinates_mask,
2231
2245
  template_weights=template_weights,
2232
2246
  sampling_rate=np.ones(template.data.ndim),
2233
2247
  )
2234
2248
 
2235
2249
  translation, rotation_matrix, score = optimize_match(
2236
- score_object=score_object, optimization_method="basinhopping"
2250
+ score_object=score_object, **kwargs
2237
2251
  )
2238
2252
 
2239
2253
  translation += mass_center_difference
@@ -2255,6 +2269,8 @@ class Density:
2255
2269
  template: "Structure",
2256
2270
  cutoff_target: float = 0,
2257
2271
  scoring_method: str = "NormalizedCrossCorrelation",
2272
+ optimization_method: str = "basinhopping",
2273
+ maxiter: int = 500,
2258
2274
  ) -> Tuple["Structure", NDArray, NDArray]:
2259
2275
  """
2260
2276
  Aligns a :py:class:`tme.structure.Structure` template to :py:class:`Density`
@@ -2279,6 +2295,12 @@ class Density:
2279
2295
  The scoring method to use for template matching. See
2280
2296
  :py:class:`tme.matching_optimization.create_score_object` for available methods,
2281
2297
  by default "NormalizedCrossCorrelation".
2298
+ optimization_method : str, optional
2299
+ Optimizer that is used.
2300
+ See :py:meth:`tme.matching_optimization.optimize_match`.
2301
+ maxiter : int, optional
2302
+ Maximum number of iterations for the optimizer.
2303
+ See :py:meth:`tme.matching_optimization.optimize_match`.
2282
2304
 
2283
2305
  Returns
2284
2306
  -------
@@ -2302,18 +2324,17 @@ class Density:
2302
2324
  cutoff_target=cutoff_target,
2303
2325
  cutoff_template=0,
2304
2326
  scoring_method=scoring_method,
2327
+ optimization_method=optimization_method,
2328
+ maxiter=maxiter,
2305
2329
  )
2306
2330
  out = template.copy()
2307
- final_translation = np.add(
2308
- -template_density.origin,
2309
- np.multiply(translation, template_density.sampling_rate),
2310
- )
2331
+ final_translation = np.subtract(ret.origin, template_density.origin)
2311
2332
 
2312
2333
  # Atom coordinates are in xyz
2313
2334
  final_translation = final_translation[::-1]
2314
2335
  rotation_matrix = rotation_matrix[::-1, ::-1]
2315
2336
 
2316
- out.rigid_transform(
2337
+ out = out.rigid_transform(
2317
2338
  translation=final_translation, rotation_matrix=rotation_matrix
2318
2339
  )
2319
2340
 
Binary file
tme/matching_data.py CHANGED
@@ -22,20 +22,44 @@ class MatchingData:
22
22
 
23
23
  Parameters
24
24
  ----------
25
- target : np.ndarray or Density
26
- Target data array for template matching.
27
- template : np.ndarray or Density
28
- Template data array for template matching.
25
+ target : np.ndarray or :py:class:`tme.density.Density`
26
+ Target data.
27
+ template : np.ndarray or :py:class:`tme.density.Density`
28
+ Template data.
29
+ target_mask : np.ndarray or :py:class:`tme.density.Density`, optional
30
+ Target mask data.
31
+ template_mask : np.ndarray or :py:class:`tme.density.Density`, optional
32
+ Template mask data.
33
+ invert_target : bool, optional
34
+ Whether to invert and rescale the target before template matching..
35
+ rotations: np.ndarray, optional
36
+ Template rotations to sample. Can be a single (d x d) or a stack (n x d x d)
37
+ of rotation matrices where d is the dimension of the template.
38
+
39
+ Examples
40
+ --------
41
+ The following achieves the minimal definition of a :py:class:`MatchingData` instance.
42
+
43
+ >>> import numpy as np
44
+ >>> from tme.matching_data import MatchingData
45
+ >>> target = np.random.rand(50,40,60)
46
+ >>> template = target[15:25, 10:20, 30:40]
47
+ >>> matching_data = MatchingData(target=target, template=template)
29
48
 
30
49
  """
31
50
 
32
- def __init__(self, target: NDArray, template: NDArray):
33
- self._default_dtype = np.float32
34
- self._complex_dtype = np.complex64
35
-
51
+ def __init__(
52
+ self,
53
+ target: NDArray,
54
+ template: NDArray,
55
+ template_mask: NDArray = None,
56
+ target_mask: NDArray = None,
57
+ invert_target: bool = False,
58
+ rotations: NDArray = None,
59
+ ):
36
60
  self._target = target
37
- self._target_mask = None
38
- self._template_mask = None
61
+ self._target_mask = target_mask
62
+ self._template_mask = template_mask
39
63
  self._translation_offset = np.zeros(len(target.shape), dtype=int)
40
64
 
41
65
  self.template = template
@@ -46,8 +70,11 @@ class MatchingData:
46
70
  self.template_filter = {}
47
71
  self.target_filter = {}
48
72
 
49
- self._invert_target = False
50
- self._rotations = None
73
+ self._invert_target = invert_target
74
+
75
+ self._rotations = rotations
76
+ if rotations is not None:
77
+ self.rotations = rotations
51
78
 
52
79
  self._set_batch_dimension()
53
80
 
@@ -257,7 +284,7 @@ class MatchingData:
257
284
  template_offset[(template_offset.size - len(template_slice)) :] = [
258
285
  x.start for x in template_slice
259
286
  ]
260
- ret._translation_offset = np.add(target_offset, template_offset)
287
+ ret._translation_offset = target_offset
261
288
 
262
289
  ret.template_filter = self.template_filter
263
290
  ret.target_filter = self.target_filter
@@ -266,7 +293,7 @@ class MatchingData:
266
293
  ret._invert_target = self._invert_target
267
294
 
268
295
  if self._target_mask is not None:
269
- ret.target_mask = self.subset_array(
296
+ ret._target_mask = self.subset_array(
270
297
  arr=self._target_mask, arr_slice=target_slice, padding=target_pad
271
298
  )
272
299
  if self._template_mask is not None:
@@ -289,19 +316,29 @@ class MatchingData:
289
316
 
290
317
  def to_backend(self) -> None:
291
318
  """
292
- Transfer the class instance's numpy arrays to the current backend.
319
+ Transfer and convert types of class instance's data arrays to the current backend
293
320
  """
294
- backend_arr = type(backend.zeros((1), dtype=backend._default_dtype))
321
+ backend_arr = type(backend.zeros((1), dtype=backend._float_dtype))
295
322
  for attr_name, attr_value in vars(self).items():
323
+ converted_array = None
296
324
  if isinstance(attr_value, np.ndarray):
297
325
  converted_array = backend.to_backend_array(attr_value.copy())
298
- setattr(self, attr_name, converted_array)
299
326
  elif isinstance(attr_value, backend_arr):
300
327
  converted_array = backend.to_backend_array(attr_value)
301
- setattr(self, attr_name, converted_array)
328
+ else:
329
+ continue
330
+
331
+ current_dtype = backend.get_fundamental_dtype(converted_array)
332
+ target_dtype = backend._fundamental_dtypes[current_dtype]
333
+
334
+ # Optional, but scores are float so we avoid casting and potential issues
335
+ if attr_name in ("_template", "_template_mask", "_target", "_target_mask"):
336
+ target_dtype = backend._float_dtype
302
337
 
303
- self._default_dtype = backend._default_dtype
304
- self._complex_dtype = backend._complex_dtype
338
+ if target_dtype != current_dtype:
339
+ converted_array = backend.astype(converted_array, target_dtype)
340
+
341
+ setattr(self, attr_name, converted_array)
305
342
 
306
343
  def _set_batch_dimension(
307
344
  self, target_dims: Tuple[int] = None, template_dims: Tuple[int] = None
@@ -350,12 +387,14 @@ class MatchingData:
350
387
  matching_dims = target_measurement_dims + batch_dims
351
388
 
352
389
  target_shape = backend.full(
353
- shape=(matching_dims,), fill_value=1, dtype=backend._default_dtype_int
390
+ shape=(matching_dims,), fill_value=1, dtype=backend._int_dtype
354
391
  )
355
392
  template_shape = backend.full(
356
- shape=(matching_dims,), fill_value=1, dtype=backend._default_dtype_int
393
+ shape=(matching_dims,), fill_value=1, dtype=backend._int_dtype
394
+ )
395
+ batch_mask = backend.full(
396
+ shape=(matching_dims,), fill_value=1, dtype=backend._int_dtype
357
397
  )
358
- batch_mask = backend.full(shape=(matching_dims,), fill_value=1, dtype=bool)
359
398
 
360
399
  target_index, template_index = 0, 0
361
400
  for k in range(matching_dims):
@@ -440,7 +479,7 @@ class MatchingData:
440
479
  An array indicating the padding for each dimension of the target.
441
480
  """
442
481
  target_padding = backend.zeros(
443
- len(self._output_target_shape), dtype=backend._default_dtype_int
482
+ len(self._output_target_shape), dtype=backend._int_dtype
444
483
  )
445
484
 
446
485
  if pad_target:
@@ -491,13 +530,14 @@ class MatchingData:
491
530
  fourier_pad = backend.full(
492
531
  shape=(len(fourier_pad),),
493
532
  fill_value=1,
494
- dtype=backend._default_dtype_int,
533
+ dtype=backend._int_dtype,
495
534
  )
496
535
 
497
536
  fourier_pad = backend.to_backend_array(fourier_pad)
498
537
  if hasattr(self, "_batch_mask"):
499
538
  batch_mask = backend.to_backend_array(self._batch_mask)
500
- fourier_pad[batch_mask] = 1
539
+ backend.multiply(fourier_pad, 1 - batch_mask, out=fourier_pad)
540
+ backend.add(fourier_pad, batch_mask, out=fourier_pad)
501
541
 
502
542
  pad_shape = backend.maximum(target_shape, template_shape)
503
543
  ret = backend.compute_convolution_shapes(pad_shape, fourier_pad)
@@ -510,10 +550,12 @@ class MatchingData:
510
550
 
511
551
  if hasattr(self, "_batch_mask"):
512
552
  batch_mask = backend.to_backend_array(self._batch_mask)
513
- shape_diff[batch_mask] = 0
553
+ backend.multiply(shape_diff, 1 - batch_mask, out=shape_diff)
514
554
 
515
555
  backend.add(fourier_shift, shape_diff, out=fourier_shift)
516
556
 
557
+ fourier_shift = backend.astype(fourier_shift, backend._int_dtype)
558
+
517
559
  return fast_shape, fast_ft_shape, fourier_shift
518
560
 
519
561
  @property
@@ -540,27 +582,27 @@ class MatchingData:
540
582
  pass
541
583
  else:
542
584
  raise ValueError("Rotations have to be a rank 2 or 3 array.")
543
- self._rotations = rotations.astype(self._default_dtype)
585
+ self._rotations = rotations.astype(np.float32)
544
586
 
545
587
  @property
546
588
  def target(self):
547
- """Returns the target NDArray."""
589
+ """Returns the target."""
590
+ target = self._target
548
591
  if isinstance(self._target, Density):
549
592
  target = self._target.data
550
- else:
551
- target = self._target
552
- out_shape = backend.to_numpy_array(self._output_target_shape)
553
- return target.reshape(tuple(int(x) for x in out_shape))
593
+
594
+ out_shape = tuple(int(x) for x in self._output_target_shape)
595
+ return target.reshape(out_shape)
554
596
 
555
597
  @property
556
598
  def template(self):
557
- """Returns the reversed template NDArray."""
599
+ """Returns the reversed template."""
558
600
  template = self._template
559
601
  if isinstance(self._template, Density):
560
602
  template = self._template.data
561
603
  template = backend.reverse(template)
562
- out_shape = backend.to_numpy_array(self._output_template_shape)
563
- return template.reshape(tuple(int(x) for x in out_shape))
604
+ out_shape = tuple(int(x) for x in self._output_template_shape)
605
+ return template.reshape(out_shape)
564
606
 
565
607
  @template.setter
566
608
  def template(self, template: NDArray):
@@ -580,9 +622,7 @@ class MatchingData:
580
622
  shape=template.shape, dtype=float, fill_value=1
581
623
  )
582
624
 
583
- if type(template) == Density:
584
- template = template.data
585
- self._template = template.astype(self._default_dtype, copy=False)
625
+ self._template = template
586
626
 
587
627
  @property
588
628
  def target_mask(self):
@@ -592,8 +632,8 @@ class MatchingData:
592
632
  target_mask = self._target_mask.data
593
633
 
594
634
  if target_mask is not None:
595
- out_shape = backend.to_numpy_array(self._output_target_shape)
596
- target_mask = target_mask.reshape(tuple(int(x) for x in out_shape))
635
+ out_shape = tuple(int(x) for x in self._output_target_shape)
636
+ target_mask = target_mask.reshape(out_shape)
597
637
 
598
638
  return target_mask
599
639
 
@@ -603,14 +643,7 @@ class MatchingData:
603
643
  if not np.all(self.target.shape == mask.shape):
604
644
  raise ValueError("Target and its mask have to have the same shape.")
605
645
 
606
- if type(mask) == Density:
607
- mask.data = mask.data.astype(self._default_dtype, copy=False)
608
- self._target_mask = mask
609
- self._targetmaskshape = self._target_mask.shape[::-1]
610
- return None
611
-
612
- self._target_mask = mask.astype(self._default_dtype, copy=False)
613
- self._targetmaskshape = self._target_mask.shape
646
+ self._target_mask = mask
614
647
 
615
648
  @property
616
649
  def template_mask(self):
@@ -628,24 +661,17 @@ class MatchingData:
628
661
 
629
662
  if mask is not None:
630
663
  mask = backend.reverse(mask)
631
- out_shape = backend.to_numpy_array(self._output_template_shape)
632
- mask = mask.reshape(tuple(int(x) for x in out_shape))
664
+ out_shape = tuple(int(x) for x in self._output_template_shape)
665
+ mask = mask.reshape(out_shape)
633
666
  return mask
634
667
 
635
668
  @template_mask.setter
636
669
  def template_mask(self, mask: NDArray):
637
670
  """Returns the reversed template mask NDArray."""
638
671
  if not np.all(self._templateshape[::-1] == mask.shape):
639
- raise ValueError("Target and its mask have to have the same shape.")
640
-
641
- if type(mask) == Density:
642
- mask.data = mask.data.astype(self._default_dtype, copy=False)
643
- self._template_mask = mask
644
- self._templatemaskshape = self._template_mask.shape[::-1]
645
- return None
672
+ raise ValueError("Template and its mask have to have the same shape.")
646
673
 
647
- self._template_mask = mask.astype(self._default_dtype, copy=False)
648
- self._templatemaskshape = self._template_mask.shape[::-1]
674
+ self._template_mask = mask
649
675
 
650
676
  def _split_rotations_on_jobs(self, n_jobs: int) -> List[NDArray]:
651
677
  """