pytme 0.2.0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.1__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40):
  1. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/match_template.py +183 -69
  2. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/postprocess.py +107 -49
  3. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/preprocessor_gui.py +4 -1
  4. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/METADATA +1 -1
  5. pytme-0.2.1.dist-info/RECORD +73 -0
  6. scripts/extract_candidates.py +117 -85
  7. scripts/match_template.py +183 -69
  8. scripts/match_template_filters.py +193 -71
  9. scripts/postprocess.py +107 -49
  10. scripts/preprocessor_gui.py +4 -1
  11. scripts/refine_matches.py +364 -160
  12. tme/__version__.py +1 -1
  13. tme/analyzer.py +259 -117
  14. tme/backends/__init__.py +1 -0
  15. tme/backends/cupy_backend.py +20 -13
  16. tme/backends/jax_backend.py +218 -0
  17. tme/backends/matching_backend.py +25 -10
  18. tme/backends/mlx_backend.py +13 -9
  19. tme/backends/npfftw_backend.py +20 -8
  20. tme/backends/pytorch_backend.py +20 -9
  21. tme/density.py +79 -60
  22. tme/extensions.cpython-311-darwin.so +0 -0
  23. tme/matching_data.py +85 -61
  24. tme/matching_exhaustive.py +222 -129
  25. tme/matching_optimization.py +117 -76
  26. tme/orientations.py +175 -55
  27. tme/preprocessing/_utils.py +17 -5
  28. tme/preprocessing/composable_filter.py +2 -1
  29. tme/preprocessing/compose.py +1 -2
  30. tme/preprocessing/frequency_filters.py +97 -41
  31. tme/preprocessing/tilt_series.py +137 -87
  32. tme/preprocessor.py +3 -0
  33. tme/structure.py +4 -1
  34. pytme-0.2.0.dist-info/RECORD +0 -72
  35. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/estimate_ram_usage.py +0 -0
  36. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/preprocess.py +0 -0
  37. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/LICENSE +0 -0
  38. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/WHEEL +0 -0
  39. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/entry_points.txt +0 -0
  40. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/top_level.txt +0 -0
tme/density.py CHANGED
@@ -40,12 +40,12 @@ from .backends import NumpyFFTWBackend
40
40
 
41
41
  class Density:
42
42
  """
43
- Contains electron density data and implements operations on it.
43
+ Abstract representation of N-dimensional densities.
44
44
 
45
45
  Parameters
46
46
  ----------
47
47
  data : NDArray
48
- Electron density data.
48
+ Array of data values.
49
49
  origin : NDArray, optional
50
50
  Origin of the coordinate system. Defaults to zero.
51
51
  sampling_rate : NDArray, optional
@@ -267,9 +267,9 @@ class Density:
267
267
  # nx := column; ny := row; nz := section
268
268
  start = np.array(
269
269
  [
270
- int(mrc.header["nxstart"]),
271
- int(mrc.header["nystart"]),
272
- int(mrc.header["nzstart"]),
270
+ mrc.header["nzstart"],
271
+ mrc.header["nystart"],
272
+ mrc.header["nxstart"],
273
273
  ]
274
274
  )
275
275
 
@@ -291,17 +291,9 @@ class Density:
291
291
  sampling_rate = mrc.voxel_size.astype(
292
292
  [("x", "<f4"), ("y", "<f4"), ("z", "<f4")]
293
293
  ).view(("<f4", 3))
294
- sampling_rate = sampling_rate[::-1]
295
- sampling_rate = np.array(sampling_rate)
294
+ sampling_rate = np.array(sampling_rate)[::-1]
296
295
 
297
- if np.all(origin == start):
298
- pass
299
- elif np.all(origin == 0) and not np.all(start == 0):
300
- origin = np.multiply(start, sampling_rate)
301
- elif np.all(
302
- np.abs(origin.astype(int))
303
- != np.abs((start * sampling_rate).astype(int))
304
- ) and not np.all(start == 0):
296
+ if np.allclose(origin, 0) and not np.allclose(start, 0):
305
297
  origin = np.multiply(start, sampling_rate)
306
298
 
307
299
  extended_header = mrc.header.nsymbt
@@ -878,7 +870,7 @@ class Density:
878
870
  compression = "gzip" if gzip else None
879
871
  with mrcfile.new(filename, overwrite=True, compression=compression) as mrc:
880
872
  mrc.set_data(self.data.astype("float32"))
881
- mrc.header.nzstart, mrc.header.nystart, mrc.header.nxstart = np.ceil(
873
+ mrc.header.nzstart, mrc.header.nystart, mrc.header.nxstart = np.rint(
882
874
  np.divide(self.origin, self.sampling_rate)
883
875
  )
884
876
  # mrcfile library expects origin to be in xyz format
@@ -1529,11 +1521,10 @@ class Density:
1529
1521
 
1530
1522
  Returns
1531
1523
  -------
1532
- Density
1533
- A copy of the class instance whose data center of mass is in the
1534
- center of the data array.
1524
+ :py:class:`Density`
1525
+ A centered copy of the current class instance.
1535
1526
  NDArray
1536
- The coordinate translation.
1527
+ The offset between array center and center of mass.
1537
1528
 
1538
1529
  See Also
1539
1530
  --------
@@ -1550,44 +1541,32 @@ class Density:
1550
1541
 
1551
1542
  >>> import numpy as np
1552
1543
  >>> from tme import Density
1553
- >>> dens = Density(np.ones((5,5)))
1544
+ >>> dens = Density(np.ones((5,5,5)))
1554
1545
  >>> centered_dens, translation = dens.centered(0)
1555
1546
  >>> translation
1556
- array([-0.5, -0.5])
1547
+ array([0., 0., 0.])
1557
1548
 
1558
1549
  :py:meth:`Density.centered` extended the :py:attr:`Density.data` attribute
1559
1550
  of the current :py:class:`Density` instance and modified
1560
1551
  :py:attr:`Density.origin` accordingly.
1561
1552
 
1562
1553
  >>> centered_dens
1563
- Origin: (-1.0, -1.0), sampling_rate: (1, 1), Shape: (8, 8)
1554
+ Origin: (-2.0, -2.0, -2.0), sampling_rate: (1, 1, 1), Shape: (9, 9, 9)
1564
1555
 
1565
1556
  :py:meth:`Density.centered` achieves centering via zero-padding and
1566
- transforming the internal :py:attr:`Density.data` attribute:
1567
-
1568
- >>> centered_dens.data
1569
- array([[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
1570
- [0. , 0.25, 0.5 , 0.5 , 0.5 , 0.5 , 0.25, 0. ],
1571
- [0. , 0.5 , 1. , 1. , 1. , 1. , 0.5 , 0. ],
1572
- [0. , 0.5 , 1. , 1. , 1. , 1. , 0.5 , 0. ],
1573
- [0. , 0.5 , 1. , 1. , 1. , 1. , 0.5 , 0. ],
1574
- [0. , 0.5 , 1. , 1. , 1. , 1. , 0.5 , 0. ],
1575
- [0. , 0.25, 0.5 , 0.5 , 0.5 , 0.5 , 0.25, 0. ],
1576
- [0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ]])
1577
-
1578
- `centered_dens` is sufficiently large to represent all rotations that
1579
- could be applied to the :py:attr:`Density.data` attribute. Lets look
1580
- at a random rotation obtained from
1557
+ rigid-transform of the internal :py:attr:`Density.data` attribute.
1558
+ `centered_dens` is sufficiently large to represent all rotations of the
1559
+ :py:attr:`Density.data` attribute, such as ones obtained from
1581
1560
  :py:meth:`tme.matching_utils.get_rotation_matrices`.
1582
1561
 
1583
1562
  >>> from tme.matching_utils import get_rotation_matrices
1584
- >>> rotation_matrix = get_rotation_matrices(dim = 2 ,angular_sampling = 10)[0]
1563
+ >>> rotation_matrix = get_rotation_matrices(dim = 3 ,angular_sampling = 10)[0]
1585
1564
  >>> rotated_centered_dens = centered_dens.rigid_transform(
1586
1565
  >>> rotation_matrix = rotation_matrix,
1587
1566
  >>> order = None
1588
1567
  >>> )
1589
1568
  >>> print(centered_dens.data.sum(), rotated_centered_dens.data.sum())
1590
- 25.000000000000007 25.000000000000007
1569
+ 125.0 125.0
1591
1570
 
1592
1571
  """
1593
1572
  ret = self.copy()
@@ -1596,7 +1575,7 @@ class Density:
1596
1575
  ret.adjust_box(box)
1597
1576
 
1598
1577
  new_shape = np.maximum(ret.shape, self.shape)
1599
- new_shape = np.add(new_shape, np.mod(new_shape, 2))
1578
+ new_shape = np.add(new_shape, 1 - np.mod(new_shape, 2))
1600
1579
  ret.pad(new_shape)
1601
1580
 
1602
1581
  center = self.center_of_mass(ret.data, cutoff)
@@ -1608,9 +1587,9 @@ class Density:
1608
1587
  use_geometric_center=False,
1609
1588
  order=1,
1610
1589
  )
1611
- offset = np.subtract(center, self.center_of_mass(ret.data, cutoff))
1612
1590
 
1613
- return ret, offset
1591
+ shift = np.subtract(center, self.center_of_mass(ret.data, cutoff))
1592
+ return ret, shift
1614
1593
 
1615
1594
  @classmethod
1616
1595
  def rotate_array(
@@ -1733,31 +1712,45 @@ class Density:
1733
1712
  Parameters
1734
1713
  ----------
1735
1714
  rotation_matrix : NDArray
1736
- Rotation matrix to apply to the `Density` instance.
1715
+ Rotation matrix to apply.
1737
1716
  translation : NDArray
1738
- Translation to apply to the `Density` instance.
1717
+ Translation to apply.
1739
1718
  order : int, optional
1740
- Order of spline interpolation.
1719
+ Interpolation order to use. Default is 3, has to be in range 0-5.
1741
1720
  use_geometric_center : bool, optional
1742
- Whether to use geometric or coordinate center. If False,
1743
- class instance should be centered using :py:meth:`Density.centered`.
1721
+ Use geometric or mass center as rotation center.
1744
1722
 
1745
1723
  Returns
1746
1724
  -------
1747
1725
  Density
1748
- The transformed instance of :py:class:`tme.density.Density`.
1726
+ The transformed instance of :py:class:`Density`.
1749
1727
 
1750
1728
  Examples
1751
1729
  --------
1730
+ Define the :py:class:`Density` instance
1731
+
1752
1732
  >>> import numpy as np
1753
- >>> rotation_matrix = np.eye(3)
1754
- >>> rotation_matrix[0] = -1
1755
- >>> density.rotate(rotation_matrix = rotation_matrix)
1733
+ >>> from tme import Density
1734
+ >>> dens = Density(np.arange(9).reshape(3,3).astype(np.float32))
1735
+ >>> dens, translation = dens.centered(0)
1736
+
1737
+ and apply the rotation, in this case a mirror around the z-axis
1738
+
1739
+ >>> rotation_matrix = np.eye(dens.data.ndim)
1740
+ >>> rotation_matrix[0, 0] = -1
1741
+ >>> dens_transform = dens.rigid_transform(rotation_matrix = rotation_matrix)
1742
+ >>> dens_transform.data
1743
+ array([[0. , 0. , 0. , 0. , 0. ],
1744
+ [0.5 , 3.0833333 , 3.5833333 , 3.3333333 , 0. ],
1745
+ [0.75 , 4.6666665 , 5.6666665 , 5.4166665 , 0. ],
1746
+ [0.25 , 1.6666666 , 2.6666667 , 2.9166667 , 0. ],
1747
+ [0. , 0.08333334, 0.5833333 , 0.8333333 , 0. ]],
1748
+ dtype=float32)
1756
1749
 
1757
1750
  Notes
1758
1751
  -----
1759
- :py:meth:`Density.rigid_transform` that the internal data array is
1760
- sufficiently sized to accomodate the transform.
1752
+ This function assumes the internal :py:attr:`Density.data` attribute is
1753
+ sufficiently sized to hold the transformation.
1761
1754
 
1762
1755
  See Also
1763
1756
  --------
@@ -2155,6 +2148,7 @@ class Density:
2155
2148
  cutoff_target: float = 0,
2156
2149
  cutoff_template: float = 0,
2157
2150
  scoring_method: str = "NormalizedCrossCorrelation",
2151
+ **kwargs,
2158
2152
  ) -> Tuple["Density", NDArray, NDArray, NDArray]:
2159
2153
  """
2160
2154
  Aligns two :py:class:`Density` instances target and template and returns
@@ -2179,6 +2173,9 @@ class Density:
2179
2173
  The scoring method to use for alignment. See
2180
2174
  :py:class:`tme.matching_optimization.create_score_object` for available methods,
2181
2175
  by default "NormalizedCrossCorrelation".
2176
+ kwargs : dict, optional
2177
+ Optional keyword arguments passed to
2178
+ :py:meth:`tme.matching_optimization.optimize_match`.
2182
2179
 
2183
2180
  Returns
2184
2181
  -------
@@ -2190,8 +2187,18 @@ class Density:
2190
2187
  -----
2191
2188
  No densities below cutoff_template are present in the returned Density object.
2192
2189
  """
2190
+ from .matching_exhaustive import normalize_under_mask
2193
2191
  from .matching_optimization import optimize_match, create_score_object
2194
2192
 
2193
+ template_mask = template.empty
2194
+ template_mask.data[:] = 1
2195
+
2196
+ normalize_under_mask(
2197
+ template=template.data,
2198
+ mask=template_mask.data,
2199
+ mask_intensity=template_mask.data.sum(),
2200
+ )
2201
+
2195
2202
  target_sampling_rate = np.array(target.sampling_rate)
2196
2203
  template_sampling_rate = np.array(template.sampling_rate)
2197
2204
 
@@ -2226,16 +2233,21 @@ class Density:
2226
2233
  ).astype(int)
2227
2234
  template_coordinates += mass_center_difference[:, None]
2228
2235
 
2236
+ coordinates_mask = template_mask.to_pointcloud()
2237
+ coordinates_mask = coordinates_mask * template_scaling[:, None]
2238
+ coordinates_mask += mass_center_difference[:, None]
2239
+
2229
2240
  score_object = create_score_object(
2230
2241
  score=scoring_method,
2231
2242
  target=target.data,
2232
2243
  template_coordinates=template_coordinates,
2244
+ template_mask_coordinates=coordinates_mask,
2233
2245
  template_weights=template_weights,
2234
2246
  sampling_rate=np.ones(template.data.ndim),
2235
2247
  )
2236
2248
 
2237
2249
  translation, rotation_matrix, score = optimize_match(
2238
- score_object=score_object, optimization_method="basinhopping"
2250
+ score_object=score_object, **kwargs
2239
2251
  )
2240
2252
 
2241
2253
  translation += mass_center_difference
@@ -2257,6 +2269,8 @@ class Density:
2257
2269
  template: "Structure",
2258
2270
  cutoff_target: float = 0,
2259
2271
  scoring_method: str = "NormalizedCrossCorrelation",
2272
+ optimization_method: str = "basinhopping",
2273
+ maxiter: int = 500,
2260
2274
  ) -> Tuple["Structure", NDArray, NDArray]:
2261
2275
  """
2262
2276
  Aligns a :py:class:`tme.structure.Structure` template to :py:class:`Density`
@@ -2281,6 +2295,12 @@ class Density:
2281
2295
  The scoring method to use for template matching. See
2282
2296
  :py:class:`tme.matching_optimization.create_score_object` for available methods,
2283
2297
  by default "NormalizedCrossCorrelation".
2298
+ optimization_method : str, optional
2299
+ Optimizer that is used.
2300
+ See :py:meth:`tme.matching_optimization.optimize_match`.
2301
+ maxiter : int, optional
2302
+ Maximum number of iterations for the optimizer.
2303
+ See :py:meth:`tme.matching_optimization.optimize_match`.
2284
2304
 
2285
2305
  Returns
2286
2306
  -------
@@ -2304,18 +2324,17 @@ class Density:
2304
2324
  cutoff_target=cutoff_target,
2305
2325
  cutoff_template=0,
2306
2326
  scoring_method=scoring_method,
2327
+ optimization_method=optimization_method,
2328
+ maxiter=maxiter,
2307
2329
  )
2308
2330
  out = template.copy()
2309
- final_translation = np.add(
2310
- -template_density.origin,
2311
- np.multiply(translation, template_density.sampling_rate),
2312
- )
2331
+ final_translation = np.subtract(ret.origin, template_density.origin)
2313
2332
 
2314
2333
  # Atom coordinates are in xyz
2315
2334
  final_translation = final_translation[::-1]
2316
2335
  rotation_matrix = rotation_matrix[::-1, ::-1]
2317
2336
 
2318
- out.rigid_transform(
2337
+ out = out.rigid_transform(
2319
2338
  translation=final_translation, rotation_matrix=rotation_matrix
2320
2339
  )
2321
2340
 
Binary file
tme/matching_data.py CHANGED
@@ -22,20 +22,44 @@ class MatchingData:
22
22
 
23
23
  Parameters
24
24
  ----------
25
- target : np.ndarray or Density
26
- Target data array for template matching.
27
- template : np.ndarray or Density
28
- Template data array for template matching.
25
+ target : np.ndarray or :py:class:`tme.density.Density`
26
+ Target data.
27
+ template : np.ndarray or :py:class:`tme.density.Density`
28
+ Template data.
29
+ target_mask : np.ndarray or :py:class:`tme.density.Density`, optional
30
+ Target mask data.
31
+ template_mask : np.ndarray or :py:class:`tme.density.Density`, optional
32
+ Template mask data.
33
+ invert_target : bool, optional
34
+ Whether to invert and rescale the target before template matching..
35
+ rotations: np.ndarray, optional
36
+ Template rotations to sample. Can be a single (d x d) or a stack (n x d x d)
37
+ of rotation matrices where d is the dimension of the template.
38
+
39
+ Examples
40
+ --------
41
+ The following achieves the minimal definition of a :py:class:`MatchingData` instance.
42
+
43
+ >>> import numpy as np
44
+ >>> from tme.matching_data import MatchingData
45
+ >>> target = np.random.rand(50,40,60)
46
+ >>> template = target[15:25, 10:20, 30:40]
47
+ >>> matching_data = MatchingData(target=target, template=template)
29
48
 
30
49
  """
31
50
 
32
- def __init__(self, target: NDArray, template: NDArray):
33
- self._default_dtype = np.float32
34
- self._complex_dtype = np.complex64
35
-
51
+ def __init__(
52
+ self,
53
+ target: NDArray,
54
+ template: NDArray,
55
+ template_mask: NDArray = None,
56
+ target_mask: NDArray = None,
57
+ invert_target: bool = False,
58
+ rotations: NDArray = None,
59
+ ):
36
60
  self._target = target
37
- self._target_mask = None
38
- self._template_mask = None
61
+ self._target_mask = target_mask
62
+ self._template_mask = template_mask
39
63
  self._translation_offset = np.zeros(len(target.shape), dtype=int)
40
64
 
41
65
  self.template = template
@@ -46,8 +70,11 @@ class MatchingData:
46
70
  self.template_filter = {}
47
71
  self.target_filter = {}
48
72
 
49
- self._invert_target = False
50
- self._rotations = None
73
+ self._invert_target = invert_target
74
+
75
+ self._rotations = rotations
76
+ if rotations is not None:
77
+ self.rotations = rotations
51
78
 
52
79
  self._set_batch_dimension()
53
80
 
@@ -257,7 +284,7 @@ class MatchingData:
257
284
  template_offset[(template_offset.size - len(template_slice)) :] = [
258
285
  x.start for x in template_slice
259
286
  ]
260
- ret._translation_offset = np.add(target_offset, template_offset)
287
+ ret._translation_offset = target_offset
261
288
 
262
289
  ret.template_filter = self.template_filter
263
290
  ret.target_filter = self.target_filter
@@ -266,7 +293,7 @@ class MatchingData:
266
293
  ret._invert_target = self._invert_target
267
294
 
268
295
  if self._target_mask is not None:
269
- ret.target_mask = self.subset_array(
296
+ ret._target_mask = self.subset_array(
270
297
  arr=self._target_mask, arr_slice=target_slice, padding=target_pad
271
298
  )
272
299
  if self._template_mask is not None:
@@ -289,19 +316,29 @@ class MatchingData:
289
316
 
290
317
  def to_backend(self) -> None:
291
318
  """
292
- Transfer the class instance's numpy arrays to the current backend.
319
+ Transfer and convert types of class instance's data arrays to the current backend
293
320
  """
294
- backend_arr = type(backend.zeros((1), dtype=backend._default_dtype))
321
+ backend_arr = type(backend.zeros((1), dtype=backend._float_dtype))
295
322
  for attr_name, attr_value in vars(self).items():
323
+ converted_array = None
296
324
  if isinstance(attr_value, np.ndarray):
297
325
  converted_array = backend.to_backend_array(attr_value.copy())
298
- setattr(self, attr_name, converted_array)
299
326
  elif isinstance(attr_value, backend_arr):
300
327
  converted_array = backend.to_backend_array(attr_value)
301
- setattr(self, attr_name, converted_array)
328
+ else:
329
+ continue
330
+
331
+ current_dtype = backend.get_fundamental_dtype(converted_array)
332
+ target_dtype = backend._fundamental_dtypes[current_dtype]
333
+
334
+ # Optional, but scores are float so we avoid casting and potential issues
335
+ if attr_name in ("_template", "_template_mask", "_target", "_target_mask"):
336
+ target_dtype = backend._float_dtype
302
337
 
303
- self._default_dtype = backend._default_dtype
304
- self._complex_dtype = backend._complex_dtype
338
+ if target_dtype != current_dtype:
339
+ converted_array = backend.astype(converted_array, target_dtype)
340
+
341
+ setattr(self, attr_name, converted_array)
305
342
 
306
343
  def _set_batch_dimension(
307
344
  self, target_dims: Tuple[int] = None, template_dims: Tuple[int] = None
@@ -350,12 +387,14 @@ class MatchingData:
350
387
  matching_dims = target_measurement_dims + batch_dims
351
388
 
352
389
  target_shape = backend.full(
353
- shape=(matching_dims,), fill_value=1, dtype=backend._default_dtype_int
390
+ shape=(matching_dims,), fill_value=1, dtype=backend._int_dtype
354
391
  )
355
392
  template_shape = backend.full(
356
- shape=(matching_dims,), fill_value=1, dtype=backend._default_dtype_int
393
+ shape=(matching_dims,), fill_value=1, dtype=backend._int_dtype
394
+ )
395
+ batch_mask = backend.full(
396
+ shape=(matching_dims,), fill_value=1, dtype=backend._int_dtype
357
397
  )
358
- batch_mask = backend.full(shape=(matching_dims,), fill_value=1, dtype=bool)
359
398
 
360
399
  target_index, template_index = 0, 0
361
400
  for k in range(matching_dims):
@@ -440,7 +479,7 @@ class MatchingData:
440
479
  An array indicating the padding for each dimension of the target.
441
480
  """
442
481
  target_padding = backend.zeros(
443
- len(self._output_target_shape), dtype=backend._default_dtype_int
482
+ len(self._output_target_shape), dtype=backend._int_dtype
444
483
  )
445
484
 
446
485
  if pad_target:
@@ -491,13 +530,14 @@ class MatchingData:
491
530
  fourier_pad = backend.full(
492
531
  shape=(len(fourier_pad),),
493
532
  fill_value=1,
494
- dtype=backend._default_dtype_int,
533
+ dtype=backend._int_dtype,
495
534
  )
496
535
 
497
536
  fourier_pad = backend.to_backend_array(fourier_pad)
498
537
  if hasattr(self, "_batch_mask"):
499
538
  batch_mask = backend.to_backend_array(self._batch_mask)
500
- fourier_pad[batch_mask] = 1
539
+ backend.multiply(fourier_pad, 1 - batch_mask, out=fourier_pad)
540
+ backend.add(fourier_pad, batch_mask, out=fourier_pad)
501
541
 
502
542
  pad_shape = backend.maximum(target_shape, template_shape)
503
543
  ret = backend.compute_convolution_shapes(pad_shape, fourier_pad)
@@ -510,11 +550,11 @@ class MatchingData:
510
550
 
511
551
  if hasattr(self, "_batch_mask"):
512
552
  batch_mask = backend.to_backend_array(self._batch_mask)
513
- shape_diff[batch_mask] = 0
553
+ backend.multiply(shape_diff, 1 - batch_mask, out=shape_diff)
514
554
 
515
555
  backend.add(fourier_shift, shape_diff, out=fourier_shift)
516
556
 
517
- fourier_shift = backend.astype(fourier_shift, backend._default_dtype_int)
557
+ fourier_shift = backend.astype(fourier_shift, backend._int_dtype)
518
558
 
519
559
  return fast_shape, fast_ft_shape, fourier_shift
520
560
 
@@ -542,27 +582,27 @@ class MatchingData:
542
582
  pass
543
583
  else:
544
584
  raise ValueError("Rotations have to be a rank 2 or 3 array.")
545
- self._rotations = rotations.astype(self._default_dtype)
585
+ self._rotations = rotations.astype(np.float32)
546
586
 
547
587
  @property
548
588
  def target(self):
549
- """Returns the target NDArray."""
589
+ """Returns the target."""
590
+ target = self._target
550
591
  if isinstance(self._target, Density):
551
592
  target = self._target.data
552
- else:
553
- target = self._target
554
- out_shape = backend.to_numpy_array(self._output_target_shape)
555
- return target.reshape(tuple(int(x) for x in out_shape))
593
+
594
+ out_shape = tuple(int(x) for x in self._output_target_shape)
595
+ return target.reshape(out_shape)
556
596
 
557
597
  @property
558
598
  def template(self):
559
- """Returns the reversed template NDArray."""
599
+ """Returns the reversed template."""
560
600
  template = self._template
561
601
  if isinstance(self._template, Density):
562
602
  template = self._template.data
563
603
  template = backend.reverse(template)
564
- out_shape = backend.to_numpy_array(self._output_template_shape)
565
- return template.reshape(tuple(int(x) for x in out_shape))
604
+ out_shape = tuple(int(x) for x in self._output_template_shape)
605
+ return template.reshape(out_shape)
566
606
 
567
607
  @template.setter
568
608
  def template(self, template: NDArray):
@@ -582,9 +622,7 @@ class MatchingData:
582
622
  shape=template.shape, dtype=float, fill_value=1
583
623
  )
584
624
 
585
- if type(template) == Density:
586
- template = template.data
587
- self._template = template.astype(self._default_dtype, copy=False)
625
+ self._template = template
588
626
 
589
627
  @property
590
628
  def target_mask(self):
@@ -594,8 +632,8 @@ class MatchingData:
594
632
  target_mask = self._target_mask.data
595
633
 
596
634
  if target_mask is not None:
597
- out_shape = backend.to_numpy_array(self._output_target_shape)
598
- target_mask = target_mask.reshape(tuple(int(x) for x in out_shape))
635
+ out_shape = tuple(int(x) for x in self._output_target_shape)
636
+ target_mask = target_mask.reshape(out_shape)
599
637
 
600
638
  return target_mask
601
639
 
@@ -605,14 +643,7 @@ class MatchingData:
605
643
  if not np.all(self.target.shape == mask.shape):
606
644
  raise ValueError("Target and its mask have to have the same shape.")
607
645
 
608
- if type(mask) == Density:
609
- mask.data = mask.data.astype(self._default_dtype, copy=False)
610
- self._target_mask = mask
611
- self._targetmaskshape = self._target_mask.shape[::-1]
612
- return None
613
-
614
- self._target_mask = mask.astype(self._default_dtype, copy=False)
615
- self._targetmaskshape = self._target_mask.shape
646
+ self._target_mask = mask
616
647
 
617
648
  @property
618
649
  def template_mask(self):
@@ -630,24 +661,17 @@ class MatchingData:
630
661
 
631
662
  if mask is not None:
632
663
  mask = backend.reverse(mask)
633
- out_shape = backend.to_numpy_array(self._output_template_shape)
634
- mask = mask.reshape(tuple(int(x) for x in out_shape))
664
+ out_shape = tuple(int(x) for x in self._output_template_shape)
665
+ mask = mask.reshape(out_shape)
635
666
  return mask
636
667
 
637
668
  @template_mask.setter
638
669
  def template_mask(self, mask: NDArray):
639
670
  """Returns the reversed template mask NDArray."""
640
671
  if not np.all(self._templateshape[::-1] == mask.shape):
641
- raise ValueError("Target and its mask have to have the same shape.")
642
-
643
- if type(mask) == Density:
644
- mask.data = mask.data.astype(self._default_dtype, copy=False)
645
- self._template_mask = mask
646
- self._templatemaskshape = self._template_mask.shape[::-1]
647
- return None
672
+ raise ValueError("Template and its mask have to have the same shape.")
648
673
 
649
- self._template_mask = mask.astype(self._default_dtype, copy=False)
650
- self._templatemaskshape = self._template_mask.shape[::-1]
674
+ self._template_mask = mask
651
675
 
652
676
  def _split_rotations_on_jobs(self, n_jobs: int) -> List[NDArray]:
653
677
  """