pytme 0.2.0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.1__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/match_template.py +183 -69
- {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/postprocess.py +107 -49
- {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/preprocessor_gui.py +4 -1
- {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/METADATA +1 -1
- pytme-0.2.1.dist-info/RECORD +73 -0
- scripts/extract_candidates.py +117 -85
- scripts/match_template.py +183 -69
- scripts/match_template_filters.py +193 -71
- scripts/postprocess.py +107 -49
- scripts/preprocessor_gui.py +4 -1
- scripts/refine_matches.py +364 -160
- tme/__version__.py +1 -1
- tme/analyzer.py +259 -117
- tme/backends/__init__.py +1 -0
- tme/backends/cupy_backend.py +20 -13
- tme/backends/jax_backend.py +218 -0
- tme/backends/matching_backend.py +25 -10
- tme/backends/mlx_backend.py +13 -9
- tme/backends/npfftw_backend.py +20 -8
- tme/backends/pytorch_backend.py +20 -9
- tme/density.py +79 -60
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +85 -61
- tme/matching_exhaustive.py +222 -129
- tme/matching_optimization.py +117 -76
- tme/orientations.py +175 -55
- tme/preprocessing/_utils.py +17 -5
- tme/preprocessing/composable_filter.py +2 -1
- tme/preprocessing/compose.py +1 -2
- tme/preprocessing/frequency_filters.py +97 -41
- tme/preprocessing/tilt_series.py +137 -87
- tme/preprocessor.py +3 -0
- tme/structure.py +4 -1
- pytme-0.2.0.dist-info/RECORD +0 -72
- {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/preprocess.py +0 -0
- {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/LICENSE +0 -0
- {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/WHEEL +0 -0
- {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/top_level.txt +0 -0
tme/density.py
CHANGED
@@ -40,12 +40,12 @@ from .backends import NumpyFFTWBackend
|
|
40
40
|
|
41
41
|
class Density:
|
42
42
|
"""
|
43
|
-
|
43
|
+
Abstract representation of N-dimensional densities.
|
44
44
|
|
45
45
|
Parameters
|
46
46
|
----------
|
47
47
|
data : NDArray
|
48
|
-
|
48
|
+
Array of data values.
|
49
49
|
origin : NDArray, optional
|
50
50
|
Origin of the coordinate system. Defaults to zero.
|
51
51
|
sampling_rate : NDArray, optional
|
@@ -267,9 +267,9 @@ class Density:
|
|
267
267
|
# nx := column; ny := row; nz := section
|
268
268
|
start = np.array(
|
269
269
|
[
|
270
|
-
|
271
|
-
|
272
|
-
|
270
|
+
mrc.header["nzstart"],
|
271
|
+
mrc.header["nystart"],
|
272
|
+
mrc.header["nxstart"],
|
273
273
|
]
|
274
274
|
)
|
275
275
|
|
@@ -291,17 +291,9 @@ class Density:
|
|
291
291
|
sampling_rate = mrc.voxel_size.astype(
|
292
292
|
[("x", "<f4"), ("y", "<f4"), ("z", "<f4")]
|
293
293
|
).view(("<f4", 3))
|
294
|
-
sampling_rate = sampling_rate[::-1]
|
295
|
-
sampling_rate = np.array(sampling_rate)
|
294
|
+
sampling_rate = np.array(sampling_rate)[::-1]
|
296
295
|
|
297
|
-
if np.
|
298
|
-
pass
|
299
|
-
elif np.all(origin == 0) and not np.all(start == 0):
|
300
|
-
origin = np.multiply(start, sampling_rate)
|
301
|
-
elif np.all(
|
302
|
-
np.abs(origin.astype(int))
|
303
|
-
!= np.abs((start * sampling_rate).astype(int))
|
304
|
-
) and not np.all(start == 0):
|
296
|
+
if np.allclose(origin, 0) and not np.allclose(start, 0):
|
305
297
|
origin = np.multiply(start, sampling_rate)
|
306
298
|
|
307
299
|
extended_header = mrc.header.nsymbt
|
@@ -878,7 +870,7 @@ class Density:
|
|
878
870
|
compression = "gzip" if gzip else None
|
879
871
|
with mrcfile.new(filename, overwrite=True, compression=compression) as mrc:
|
880
872
|
mrc.set_data(self.data.astype("float32"))
|
881
|
-
mrc.header.nzstart, mrc.header.nystart, mrc.header.nxstart = np.
|
873
|
+
mrc.header.nzstart, mrc.header.nystart, mrc.header.nxstart = np.rint(
|
882
874
|
np.divide(self.origin, self.sampling_rate)
|
883
875
|
)
|
884
876
|
# mrcfile library expects origin to be in xyz format
|
@@ -1529,11 +1521,10 @@ class Density:
|
|
1529
1521
|
|
1530
1522
|
Returns
|
1531
1523
|
-------
|
1532
|
-
Density
|
1533
|
-
A copy of the class instance
|
1534
|
-
center of the data array.
|
1524
|
+
:py:class:`Density`
|
1525
|
+
A centered copy of the current class instance.
|
1535
1526
|
NDArray
|
1536
|
-
The
|
1527
|
+
The offset between array center and center of mass.
|
1537
1528
|
|
1538
1529
|
See Also
|
1539
1530
|
--------
|
@@ -1550,44 +1541,32 @@ class Density:
|
|
1550
1541
|
|
1551
1542
|
>>> import numpy as np
|
1552
1543
|
>>> from tme import Density
|
1553
|
-
>>> dens = Density(np.ones((5,5)))
|
1544
|
+
>>> dens = Density(np.ones((5,5,5)))
|
1554
1545
|
>>> centered_dens, translation = dens.centered(0)
|
1555
1546
|
>>> translation
|
1556
|
-
array([
|
1547
|
+
array([0., 0., 0.])
|
1557
1548
|
|
1558
1549
|
:py:meth:`Density.centered` extended the :py:attr:`Density.data` attribute
|
1559
1550
|
of the current :py:class:`Density` instance and modified
|
1560
1551
|
:py:attr:`Density.origin` accordingly.
|
1561
1552
|
|
1562
1553
|
>>> centered_dens
|
1563
|
-
Origin: (-
|
1554
|
+
Origin: (-2.0, -2.0, -2.0), sampling_rate: (1, 1, 1), Shape: (9, 9, 9)
|
1564
1555
|
|
1565
1556
|
:py:meth:`Density.centered` achieves centering via zero-padding and
|
1566
|
-
|
1567
|
-
|
1568
|
-
|
1569
|
-
array([[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
|
1570
|
-
[0. , 0.25, 0.5 , 0.5 , 0.5 , 0.5 , 0.25, 0. ],
|
1571
|
-
[0. , 0.5 , 1. , 1. , 1. , 1. , 0.5 , 0. ],
|
1572
|
-
[0. , 0.5 , 1. , 1. , 1. , 1. , 0.5 , 0. ],
|
1573
|
-
[0. , 0.5 , 1. , 1. , 1. , 1. , 0.5 , 0. ],
|
1574
|
-
[0. , 0.5 , 1. , 1. , 1. , 1. , 0.5 , 0. ],
|
1575
|
-
[0. , 0.25, 0.5 , 0.5 , 0.5 , 0.5 , 0.25, 0. ],
|
1576
|
-
[0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ]])
|
1577
|
-
|
1578
|
-
`centered_dens` is sufficiently large to represent all rotations that
|
1579
|
-
could be applied to the :py:attr:`Density.data` attribute. Lets look
|
1580
|
-
at a random rotation obtained from
|
1557
|
+
rigid-transform of the internal :py:attr:`Density.data` attribute.
|
1558
|
+
`centered_dens` is sufficiently large to represent all rotations of the
|
1559
|
+
:py:attr:`Density.data` attribute, such as ones obtained from
|
1581
1560
|
:py:meth:`tme.matching_utils.get_rotation_matrices`.
|
1582
1561
|
|
1583
1562
|
>>> from tme.matching_utils import get_rotation_matrices
|
1584
|
-
>>> rotation_matrix = get_rotation_matrices(dim =
|
1563
|
+
>>> rotation_matrix = get_rotation_matrices(dim = 3 ,angular_sampling = 10)[0]
|
1585
1564
|
>>> rotated_centered_dens = centered_dens.rigid_transform(
|
1586
1565
|
>>> rotation_matrix = rotation_matrix,
|
1587
1566
|
>>> order = None
|
1588
1567
|
>>> )
|
1589
1568
|
>>> print(centered_dens.data.sum(), rotated_centered_dens.data.sum())
|
1590
|
-
|
1569
|
+
125.0 125.0
|
1591
1570
|
|
1592
1571
|
"""
|
1593
1572
|
ret = self.copy()
|
@@ -1596,7 +1575,7 @@ class Density:
|
|
1596
1575
|
ret.adjust_box(box)
|
1597
1576
|
|
1598
1577
|
new_shape = np.maximum(ret.shape, self.shape)
|
1599
|
-
new_shape = np.add(new_shape, np.mod(new_shape, 2))
|
1578
|
+
new_shape = np.add(new_shape, 1 - np.mod(new_shape, 2))
|
1600
1579
|
ret.pad(new_shape)
|
1601
1580
|
|
1602
1581
|
center = self.center_of_mass(ret.data, cutoff)
|
@@ -1608,9 +1587,9 @@ class Density:
|
|
1608
1587
|
use_geometric_center=False,
|
1609
1588
|
order=1,
|
1610
1589
|
)
|
1611
|
-
offset = np.subtract(center, self.center_of_mass(ret.data, cutoff))
|
1612
1590
|
|
1613
|
-
|
1591
|
+
shift = np.subtract(center, self.center_of_mass(ret.data, cutoff))
|
1592
|
+
return ret, shift
|
1614
1593
|
|
1615
1594
|
@classmethod
|
1616
1595
|
def rotate_array(
|
@@ -1733,31 +1712,45 @@ class Density:
|
|
1733
1712
|
Parameters
|
1734
1713
|
----------
|
1735
1714
|
rotation_matrix : NDArray
|
1736
|
-
Rotation matrix to apply
|
1715
|
+
Rotation matrix to apply.
|
1737
1716
|
translation : NDArray
|
1738
|
-
Translation to apply
|
1717
|
+
Translation to apply.
|
1739
1718
|
order : int, optional
|
1740
|
-
|
1719
|
+
Interpolation order to use. Default is 3, has to be in range 0-5.
|
1741
1720
|
use_geometric_center : bool, optional
|
1742
|
-
|
1743
|
-
class instance should be centered using :py:meth:`Density.centered`.
|
1721
|
+
Use geometric or mass center as rotation center.
|
1744
1722
|
|
1745
1723
|
Returns
|
1746
1724
|
-------
|
1747
1725
|
Density
|
1748
|
-
The transformed instance of :py:class:`
|
1726
|
+
The transformed instance of :py:class:`Density`.
|
1749
1727
|
|
1750
1728
|
Examples
|
1751
1729
|
--------
|
1730
|
+
Define the :py:class:`Density` instance
|
1731
|
+
|
1752
1732
|
>>> import numpy as np
|
1753
|
-
>>>
|
1754
|
-
>>>
|
1755
|
-
>>>
|
1733
|
+
>>> from tme import Density
|
1734
|
+
>>> dens = Density(np.arange(9).reshape(3,3).astype(np.float32))
|
1735
|
+
>>> dens, translation = dens.centered(0)
|
1736
|
+
|
1737
|
+
and apply the rotation, in this case a mirror around the z-axis
|
1738
|
+
|
1739
|
+
>>> rotation_matrix = np.eye(dens.data.ndim)
|
1740
|
+
>>> rotation_matrix[0, 0] = -1
|
1741
|
+
>>> dens_transform = dens.rigid_transform(rotation_matrix = rotation_matrix)
|
1742
|
+
>>> dens_transform.data
|
1743
|
+
array([[0. , 0. , 0. , 0. , 0. ],
|
1744
|
+
[0.5 , 3.0833333 , 3.5833333 , 3.3333333 , 0. ],
|
1745
|
+
[0.75 , 4.6666665 , 5.6666665 , 5.4166665 , 0. ],
|
1746
|
+
[0.25 , 1.6666666 , 2.6666667 , 2.9166667 , 0. ],
|
1747
|
+
[0. , 0.08333334, 0.5833333 , 0.8333333 , 0. ]],
|
1748
|
+
dtype=float32)
|
1756
1749
|
|
1757
1750
|
Notes
|
1758
1751
|
-----
|
1759
|
-
:py:
|
1760
|
-
sufficiently sized to
|
1752
|
+
This function assumes the internal :py:attr:`Density.data` attribute is
|
1753
|
+
sufficiently sized to hold the transformation.
|
1761
1754
|
|
1762
1755
|
See Also
|
1763
1756
|
--------
|
@@ -2155,6 +2148,7 @@ class Density:
|
|
2155
2148
|
cutoff_target: float = 0,
|
2156
2149
|
cutoff_template: float = 0,
|
2157
2150
|
scoring_method: str = "NormalizedCrossCorrelation",
|
2151
|
+
**kwargs,
|
2158
2152
|
) -> Tuple["Density", NDArray, NDArray, NDArray]:
|
2159
2153
|
"""
|
2160
2154
|
Aligns two :py:class:`Density` instances target and template and returns
|
@@ -2179,6 +2173,9 @@ class Density:
|
|
2179
2173
|
The scoring method to use for alignment. See
|
2180
2174
|
:py:class:`tme.matching_optimization.create_score_object` for available methods,
|
2181
2175
|
by default "NormalizedCrossCorrelation".
|
2176
|
+
kwargs : dict, optional
|
2177
|
+
Optional keyword arguments passed to
|
2178
|
+
:py:meth:`tme.matching_optimization.optimize_match`.
|
2182
2179
|
|
2183
2180
|
Returns
|
2184
2181
|
-------
|
@@ -2190,8 +2187,18 @@ class Density:
|
|
2190
2187
|
-----
|
2191
2188
|
No densities below cutoff_template are present in the returned Density object.
|
2192
2189
|
"""
|
2190
|
+
from .matching_exhaustive import normalize_under_mask
|
2193
2191
|
from .matching_optimization import optimize_match, create_score_object
|
2194
2192
|
|
2193
|
+
template_mask = template.empty
|
2194
|
+
template_mask.data[:] = 1
|
2195
|
+
|
2196
|
+
normalize_under_mask(
|
2197
|
+
template=template.data,
|
2198
|
+
mask=template_mask.data,
|
2199
|
+
mask_intensity=template_mask.data.sum(),
|
2200
|
+
)
|
2201
|
+
|
2195
2202
|
target_sampling_rate = np.array(target.sampling_rate)
|
2196
2203
|
template_sampling_rate = np.array(template.sampling_rate)
|
2197
2204
|
|
@@ -2226,16 +2233,21 @@ class Density:
|
|
2226
2233
|
).astype(int)
|
2227
2234
|
template_coordinates += mass_center_difference[:, None]
|
2228
2235
|
|
2236
|
+
coordinates_mask = template_mask.to_pointcloud()
|
2237
|
+
coordinates_mask = coordinates_mask * template_scaling[:, None]
|
2238
|
+
coordinates_mask += mass_center_difference[:, None]
|
2239
|
+
|
2229
2240
|
score_object = create_score_object(
|
2230
2241
|
score=scoring_method,
|
2231
2242
|
target=target.data,
|
2232
2243
|
template_coordinates=template_coordinates,
|
2244
|
+
template_mask_coordinates=coordinates_mask,
|
2233
2245
|
template_weights=template_weights,
|
2234
2246
|
sampling_rate=np.ones(template.data.ndim),
|
2235
2247
|
)
|
2236
2248
|
|
2237
2249
|
translation, rotation_matrix, score = optimize_match(
|
2238
|
-
score_object=score_object,
|
2250
|
+
score_object=score_object, **kwargs
|
2239
2251
|
)
|
2240
2252
|
|
2241
2253
|
translation += mass_center_difference
|
@@ -2257,6 +2269,8 @@ class Density:
|
|
2257
2269
|
template: "Structure",
|
2258
2270
|
cutoff_target: float = 0,
|
2259
2271
|
scoring_method: str = "NormalizedCrossCorrelation",
|
2272
|
+
optimization_method: str = "basinhopping",
|
2273
|
+
maxiter: int = 500,
|
2260
2274
|
) -> Tuple["Structure", NDArray, NDArray]:
|
2261
2275
|
"""
|
2262
2276
|
Aligns a :py:class:`tme.structure.Structure` template to :py:class:`Density`
|
@@ -2281,6 +2295,12 @@ class Density:
|
|
2281
2295
|
The scoring method to use for template matching. See
|
2282
2296
|
:py:class:`tme.matching_optimization.create_score_object` for available methods,
|
2283
2297
|
by default "NormalizedCrossCorrelation".
|
2298
|
+
optimization_method : str, optional
|
2299
|
+
Optimizer that is used.
|
2300
|
+
See :py:meth:`tme.matching_optimization.optimize_match`.
|
2301
|
+
maxiter : int, optional
|
2302
|
+
Maximum number of iterations for the optimizer.
|
2303
|
+
See :py:meth:`tme.matching_optimization.optimize_match`.
|
2284
2304
|
|
2285
2305
|
Returns
|
2286
2306
|
-------
|
@@ -2304,18 +2324,17 @@ class Density:
|
|
2304
2324
|
cutoff_target=cutoff_target,
|
2305
2325
|
cutoff_template=0,
|
2306
2326
|
scoring_method=scoring_method,
|
2327
|
+
optimization_method=optimization_method,
|
2328
|
+
maxiter=maxiter,
|
2307
2329
|
)
|
2308
2330
|
out = template.copy()
|
2309
|
-
final_translation = np.
|
2310
|
-
-template_density.origin,
|
2311
|
-
np.multiply(translation, template_density.sampling_rate),
|
2312
|
-
)
|
2331
|
+
final_translation = np.subtract(ret.origin, template_density.origin)
|
2313
2332
|
|
2314
2333
|
# Atom coordinates are in xyz
|
2315
2334
|
final_translation = final_translation[::-1]
|
2316
2335
|
rotation_matrix = rotation_matrix[::-1, ::-1]
|
2317
2336
|
|
2318
|
-
out.rigid_transform(
|
2337
|
+
out = out.rigid_transform(
|
2319
2338
|
translation=final_translation, rotation_matrix=rotation_matrix
|
2320
2339
|
)
|
2321
2340
|
|
Binary file
|
tme/matching_data.py
CHANGED
@@ -22,20 +22,44 @@ class MatchingData:
|
|
22
22
|
|
23
23
|
Parameters
|
24
24
|
----------
|
25
|
-
target : np.ndarray or Density
|
26
|
-
Target data
|
27
|
-
template : np.ndarray or Density
|
28
|
-
Template data
|
25
|
+
target : np.ndarray or :py:class:`tme.density.Density`
|
26
|
+
Target data.
|
27
|
+
template : np.ndarray or :py:class:`tme.density.Density`
|
28
|
+
Template data.
|
29
|
+
target_mask : np.ndarray or :py:class:`tme.density.Density`, optional
|
30
|
+
Target mask data.
|
31
|
+
template_mask : np.ndarray or :py:class:`tme.density.Density`, optional
|
32
|
+
Template mask data.
|
33
|
+
invert_target : bool, optional
|
34
|
+
Whether to invert and rescale the target before template matching..
|
35
|
+
rotations: np.ndarray, optional
|
36
|
+
Template rotations to sample. Can be a single (d x d) or a stack (n x d x d)
|
37
|
+
of rotation matrices where d is the dimension of the template.
|
38
|
+
|
39
|
+
Examples
|
40
|
+
--------
|
41
|
+
The following achieves the minimal definition of a :py:class:`MatchingData` instance.
|
42
|
+
|
43
|
+
>>> import numpy as np
|
44
|
+
>>> from tme.matching_data import MatchingData
|
45
|
+
>>> target = np.random.rand(50,40,60)
|
46
|
+
>>> template = target[15:25, 10:20, 30:40]
|
47
|
+
>>> matching_data = MatchingData(target=target, template=template)
|
29
48
|
|
30
49
|
"""
|
31
50
|
|
32
|
-
def __init__(
|
33
|
-
self
|
34
|
-
|
35
|
-
|
51
|
+
def __init__(
|
52
|
+
self,
|
53
|
+
target: NDArray,
|
54
|
+
template: NDArray,
|
55
|
+
template_mask: NDArray = None,
|
56
|
+
target_mask: NDArray = None,
|
57
|
+
invert_target: bool = False,
|
58
|
+
rotations: NDArray = None,
|
59
|
+
):
|
36
60
|
self._target = target
|
37
|
-
self._target_mask =
|
38
|
-
self._template_mask =
|
61
|
+
self._target_mask = target_mask
|
62
|
+
self._template_mask = template_mask
|
39
63
|
self._translation_offset = np.zeros(len(target.shape), dtype=int)
|
40
64
|
|
41
65
|
self.template = template
|
@@ -46,8 +70,11 @@ class MatchingData:
|
|
46
70
|
self.template_filter = {}
|
47
71
|
self.target_filter = {}
|
48
72
|
|
49
|
-
self._invert_target =
|
50
|
-
|
73
|
+
self._invert_target = invert_target
|
74
|
+
|
75
|
+
self._rotations = rotations
|
76
|
+
if rotations is not None:
|
77
|
+
self.rotations = rotations
|
51
78
|
|
52
79
|
self._set_batch_dimension()
|
53
80
|
|
@@ -257,7 +284,7 @@ class MatchingData:
|
|
257
284
|
template_offset[(template_offset.size - len(template_slice)) :] = [
|
258
285
|
x.start for x in template_slice
|
259
286
|
]
|
260
|
-
ret._translation_offset =
|
287
|
+
ret._translation_offset = target_offset
|
261
288
|
|
262
289
|
ret.template_filter = self.template_filter
|
263
290
|
ret.target_filter = self.target_filter
|
@@ -266,7 +293,7 @@ class MatchingData:
|
|
266
293
|
ret._invert_target = self._invert_target
|
267
294
|
|
268
295
|
if self._target_mask is not None:
|
269
|
-
ret.
|
296
|
+
ret._target_mask = self.subset_array(
|
270
297
|
arr=self._target_mask, arr_slice=target_slice, padding=target_pad
|
271
298
|
)
|
272
299
|
if self._template_mask is not None:
|
@@ -289,19 +316,29 @@ class MatchingData:
|
|
289
316
|
|
290
317
|
def to_backend(self) -> None:
|
291
318
|
"""
|
292
|
-
Transfer
|
319
|
+
Transfer and convert types of class instance's data arrays to the current backend
|
293
320
|
"""
|
294
|
-
backend_arr = type(backend.zeros((1), dtype=backend.
|
321
|
+
backend_arr = type(backend.zeros((1), dtype=backend._float_dtype))
|
295
322
|
for attr_name, attr_value in vars(self).items():
|
323
|
+
converted_array = None
|
296
324
|
if isinstance(attr_value, np.ndarray):
|
297
325
|
converted_array = backend.to_backend_array(attr_value.copy())
|
298
|
-
setattr(self, attr_name, converted_array)
|
299
326
|
elif isinstance(attr_value, backend_arr):
|
300
327
|
converted_array = backend.to_backend_array(attr_value)
|
301
|
-
|
328
|
+
else:
|
329
|
+
continue
|
330
|
+
|
331
|
+
current_dtype = backend.get_fundamental_dtype(converted_array)
|
332
|
+
target_dtype = backend._fundamental_dtypes[current_dtype]
|
333
|
+
|
334
|
+
# Optional, but scores are float so we avoid casting and potential issues
|
335
|
+
if attr_name in ("_template", "_template_mask", "_target", "_target_mask"):
|
336
|
+
target_dtype = backend._float_dtype
|
302
337
|
|
303
|
-
|
304
|
-
|
338
|
+
if target_dtype != current_dtype:
|
339
|
+
converted_array = backend.astype(converted_array, target_dtype)
|
340
|
+
|
341
|
+
setattr(self, attr_name, converted_array)
|
305
342
|
|
306
343
|
def _set_batch_dimension(
|
307
344
|
self, target_dims: Tuple[int] = None, template_dims: Tuple[int] = None
|
@@ -350,12 +387,14 @@ class MatchingData:
|
|
350
387
|
matching_dims = target_measurement_dims + batch_dims
|
351
388
|
|
352
389
|
target_shape = backend.full(
|
353
|
-
shape=(matching_dims,), fill_value=1, dtype=backend.
|
390
|
+
shape=(matching_dims,), fill_value=1, dtype=backend._int_dtype
|
354
391
|
)
|
355
392
|
template_shape = backend.full(
|
356
|
-
shape=(matching_dims,), fill_value=1, dtype=backend.
|
393
|
+
shape=(matching_dims,), fill_value=1, dtype=backend._int_dtype
|
394
|
+
)
|
395
|
+
batch_mask = backend.full(
|
396
|
+
shape=(matching_dims,), fill_value=1, dtype=backend._int_dtype
|
357
397
|
)
|
358
|
-
batch_mask = backend.full(shape=(matching_dims,), fill_value=1, dtype=bool)
|
359
398
|
|
360
399
|
target_index, template_index = 0, 0
|
361
400
|
for k in range(matching_dims):
|
@@ -440,7 +479,7 @@ class MatchingData:
|
|
440
479
|
An array indicating the padding for each dimension of the target.
|
441
480
|
"""
|
442
481
|
target_padding = backend.zeros(
|
443
|
-
len(self._output_target_shape), dtype=backend.
|
482
|
+
len(self._output_target_shape), dtype=backend._int_dtype
|
444
483
|
)
|
445
484
|
|
446
485
|
if pad_target:
|
@@ -491,13 +530,14 @@ class MatchingData:
|
|
491
530
|
fourier_pad = backend.full(
|
492
531
|
shape=(len(fourier_pad),),
|
493
532
|
fill_value=1,
|
494
|
-
dtype=backend.
|
533
|
+
dtype=backend._int_dtype,
|
495
534
|
)
|
496
535
|
|
497
536
|
fourier_pad = backend.to_backend_array(fourier_pad)
|
498
537
|
if hasattr(self, "_batch_mask"):
|
499
538
|
batch_mask = backend.to_backend_array(self._batch_mask)
|
500
|
-
fourier_pad
|
539
|
+
backend.multiply(fourier_pad, 1 - batch_mask, out=fourier_pad)
|
540
|
+
backend.add(fourier_pad, batch_mask, out=fourier_pad)
|
501
541
|
|
502
542
|
pad_shape = backend.maximum(target_shape, template_shape)
|
503
543
|
ret = backend.compute_convolution_shapes(pad_shape, fourier_pad)
|
@@ -510,11 +550,11 @@ class MatchingData:
|
|
510
550
|
|
511
551
|
if hasattr(self, "_batch_mask"):
|
512
552
|
batch_mask = backend.to_backend_array(self._batch_mask)
|
513
|
-
shape_diff
|
553
|
+
backend.multiply(shape_diff, 1 - batch_mask, out=shape_diff)
|
514
554
|
|
515
555
|
backend.add(fourier_shift, shape_diff, out=fourier_shift)
|
516
556
|
|
517
|
-
fourier_shift = backend.astype(fourier_shift, backend.
|
557
|
+
fourier_shift = backend.astype(fourier_shift, backend._int_dtype)
|
518
558
|
|
519
559
|
return fast_shape, fast_ft_shape, fourier_shift
|
520
560
|
|
@@ -542,27 +582,27 @@ class MatchingData:
|
|
542
582
|
pass
|
543
583
|
else:
|
544
584
|
raise ValueError("Rotations have to be a rank 2 or 3 array.")
|
545
|
-
self._rotations = rotations.astype(
|
585
|
+
self._rotations = rotations.astype(np.float32)
|
546
586
|
|
547
587
|
@property
|
548
588
|
def target(self):
|
549
|
-
"""Returns the target
|
589
|
+
"""Returns the target."""
|
590
|
+
target = self._target
|
550
591
|
if isinstance(self._target, Density):
|
551
592
|
target = self._target.data
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
return target.reshape(tuple(int(x) for x in out_shape))
|
593
|
+
|
594
|
+
out_shape = tuple(int(x) for x in self._output_target_shape)
|
595
|
+
return target.reshape(out_shape)
|
556
596
|
|
557
597
|
@property
|
558
598
|
def template(self):
|
559
|
-
"""Returns the reversed template
|
599
|
+
"""Returns the reversed template."""
|
560
600
|
template = self._template
|
561
601
|
if isinstance(self._template, Density):
|
562
602
|
template = self._template.data
|
563
603
|
template = backend.reverse(template)
|
564
|
-
out_shape =
|
565
|
-
return template.reshape(
|
604
|
+
out_shape = tuple(int(x) for x in self._output_template_shape)
|
605
|
+
return template.reshape(out_shape)
|
566
606
|
|
567
607
|
@template.setter
|
568
608
|
def template(self, template: NDArray):
|
@@ -582,9 +622,7 @@ class MatchingData:
|
|
582
622
|
shape=template.shape, dtype=float, fill_value=1
|
583
623
|
)
|
584
624
|
|
585
|
-
|
586
|
-
template = template.data
|
587
|
-
self._template = template.astype(self._default_dtype, copy=False)
|
625
|
+
self._template = template
|
588
626
|
|
589
627
|
@property
|
590
628
|
def target_mask(self):
|
@@ -594,8 +632,8 @@ class MatchingData:
|
|
594
632
|
target_mask = self._target_mask.data
|
595
633
|
|
596
634
|
if target_mask is not None:
|
597
|
-
out_shape =
|
598
|
-
target_mask = target_mask.reshape(
|
635
|
+
out_shape = tuple(int(x) for x in self._output_target_shape)
|
636
|
+
target_mask = target_mask.reshape(out_shape)
|
599
637
|
|
600
638
|
return target_mask
|
601
639
|
|
@@ -605,14 +643,7 @@ class MatchingData:
|
|
605
643
|
if not np.all(self.target.shape == mask.shape):
|
606
644
|
raise ValueError("Target and its mask have to have the same shape.")
|
607
645
|
|
608
|
-
|
609
|
-
mask.data = mask.data.astype(self._default_dtype, copy=False)
|
610
|
-
self._target_mask = mask
|
611
|
-
self._targetmaskshape = self._target_mask.shape[::-1]
|
612
|
-
return None
|
613
|
-
|
614
|
-
self._target_mask = mask.astype(self._default_dtype, copy=False)
|
615
|
-
self._targetmaskshape = self._target_mask.shape
|
646
|
+
self._target_mask = mask
|
616
647
|
|
617
648
|
@property
|
618
649
|
def template_mask(self):
|
@@ -630,24 +661,17 @@ class MatchingData:
|
|
630
661
|
|
631
662
|
if mask is not None:
|
632
663
|
mask = backend.reverse(mask)
|
633
|
-
out_shape =
|
634
|
-
mask = mask.reshape(
|
664
|
+
out_shape = tuple(int(x) for x in self._output_template_shape)
|
665
|
+
mask = mask.reshape(out_shape)
|
635
666
|
return mask
|
636
667
|
|
637
668
|
@template_mask.setter
|
638
669
|
def template_mask(self, mask: NDArray):
|
639
670
|
"""Returns the reversed template mask NDArray."""
|
640
671
|
if not np.all(self._templateshape[::-1] == mask.shape):
|
641
|
-
raise ValueError("
|
642
|
-
|
643
|
-
if type(mask) == Density:
|
644
|
-
mask.data = mask.data.astype(self._default_dtype, copy=False)
|
645
|
-
self._template_mask = mask
|
646
|
-
self._templatemaskshape = self._template_mask.shape[::-1]
|
647
|
-
return None
|
672
|
+
raise ValueError("Template and its mask have to have the same shape.")
|
648
673
|
|
649
|
-
self._template_mask = mask
|
650
|
-
self._templatemaskshape = self._template_mask.shape[::-1]
|
674
|
+
self._template_mask = mask
|
651
675
|
|
652
676
|
def _split_rotations_on_jobs(self, n_jobs: int) -> List[NDArray]:
|
653
677
|
"""
|