pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.1__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/match_template.py +473 -140
- {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/postprocess.py +107 -49
- {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/preprocessor_gui.py +4 -1
- {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/METADATA +2 -2
- pytme-0.2.1.dist-info/RECORD +73 -0
- scripts/extract_candidates.py +117 -85
- scripts/match_template.py +473 -140
- scripts/match_template_filters.py +458 -169
- scripts/postprocess.py +107 -49
- scripts/preprocessor_gui.py +4 -1
- scripts/refine_matches.py +364 -160
- tme/__version__.py +1 -1
- tme/analyzer.py +278 -148
- tme/backends/__init__.py +1 -0
- tme/backends/cupy_backend.py +20 -13
- tme/backends/jax_backend.py +218 -0
- tme/backends/matching_backend.py +25 -10
- tme/backends/mlx_backend.py +13 -9
- tme/backends/npfftw_backend.py +22 -12
- tme/backends/pytorch_backend.py +20 -9
- tme/density.py +85 -64
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +86 -60
- tme/matching_exhaustive.py +245 -166
- tme/matching_optimization.py +137 -69
- tme/matching_utils.py +1 -1
- tme/orientations.py +175 -55
- tme/preprocessing/__init__.py +2 -0
- tme/preprocessing/_utils.py +188 -0
- tme/preprocessing/composable_filter.py +31 -0
- tme/preprocessing/compose.py +51 -0
- tme/preprocessing/frequency_filters.py +378 -0
- tme/preprocessing/tilt_series.py +1017 -0
- tme/preprocessor.py +17 -7
- tme/structure.py +4 -1
- pytme-0.2.0b0.dist-info/RECORD +0 -66
- {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.0b0.data → pytme-0.2.1.data}/scripts/preprocess.py +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/LICENSE +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/WHEEL +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.0b0.dist-info → pytme-0.2.1.dist-info}/top_level.txt +0 -0
tme/density.py
CHANGED
@@ -40,12 +40,12 @@ from .backends import NumpyFFTWBackend
|
|
40
40
|
|
41
41
|
class Density:
|
42
42
|
"""
|
43
|
-
|
43
|
+
Abstract representation of N-dimensional densities.
|
44
44
|
|
45
45
|
Parameters
|
46
46
|
----------
|
47
47
|
data : NDArray
|
48
|
-
|
48
|
+
Array of data values.
|
49
49
|
origin : NDArray, optional
|
50
50
|
Origin of the coordinate system. Defaults to zero.
|
51
51
|
sampling_rate : NDArray, optional
|
@@ -267,9 +267,9 @@ class Density:
|
|
267
267
|
# nx := column; ny := row; nz := section
|
268
268
|
start = np.array(
|
269
269
|
[
|
270
|
-
|
271
|
-
|
272
|
-
|
270
|
+
mrc.header["nzstart"],
|
271
|
+
mrc.header["nystart"],
|
272
|
+
mrc.header["nxstart"],
|
273
273
|
]
|
274
274
|
)
|
275
275
|
|
@@ -291,17 +291,9 @@ class Density:
|
|
291
291
|
sampling_rate = mrc.voxel_size.astype(
|
292
292
|
[("x", "<f4"), ("y", "<f4"), ("z", "<f4")]
|
293
293
|
).view(("<f4", 3))
|
294
|
-
sampling_rate = sampling_rate[::-1]
|
295
|
-
sampling_rate = np.array(sampling_rate)
|
294
|
+
sampling_rate = np.array(sampling_rate)[::-1]
|
296
295
|
|
297
|
-
if np.
|
298
|
-
pass
|
299
|
-
elif np.all(origin == 0) and not np.all(start == 0):
|
300
|
-
origin = np.multiply(start, sampling_rate)
|
301
|
-
elif np.all(
|
302
|
-
np.abs(origin.astype(int))
|
303
|
-
!= np.abs((start * sampling_rate).astype(int))
|
304
|
-
) and not np.all(start == 0):
|
296
|
+
if np.allclose(origin, 0) and not np.allclose(start, 0):
|
305
297
|
origin = np.multiply(start, sampling_rate)
|
306
298
|
|
307
299
|
extended_header = mrc.header.nsymbt
|
@@ -317,7 +309,7 @@ class Density:
|
|
317
309
|
if use_memmap:
|
318
310
|
warnings.warn(
|
319
311
|
f"Cannot open gzipped file {filename} as memmap."
|
320
|
-
f" Please gunzip {filename} to use memmap functionality."
|
312
|
+
f" Please run 'gunzip {filename}' to use memmap functionality."
|
321
313
|
)
|
322
314
|
use_memmap = False
|
323
315
|
|
@@ -878,7 +870,7 @@ class Density:
|
|
878
870
|
compression = "gzip" if gzip else None
|
879
871
|
with mrcfile.new(filename, overwrite=True, compression=compression) as mrc:
|
880
872
|
mrc.set_data(self.data.astype("float32"))
|
881
|
-
mrc.header.nzstart, mrc.header.nystart, mrc.header.nxstart = np.
|
873
|
+
mrc.header.nzstart, mrc.header.nystart, mrc.header.nxstart = np.rint(
|
882
874
|
np.divide(self.origin, self.sampling_rate)
|
883
875
|
)
|
884
876
|
# mrcfile library expects origin to be in xyz format
|
@@ -1529,11 +1521,10 @@ class Density:
|
|
1529
1521
|
|
1530
1522
|
Returns
|
1531
1523
|
-------
|
1532
|
-
Density
|
1533
|
-
A copy of the class instance
|
1534
|
-
center of the data array.
|
1524
|
+
:py:class:`Density`
|
1525
|
+
A centered copy of the current class instance.
|
1535
1526
|
NDArray
|
1536
|
-
The
|
1527
|
+
The offset between array center and center of mass.
|
1537
1528
|
|
1538
1529
|
See Also
|
1539
1530
|
--------
|
@@ -1545,48 +1536,37 @@ class Density:
|
|
1545
1536
|
--------
|
1546
1537
|
:py:meth:`Density.centered` returns a tuple containing a centered version
|
1547
1538
|
of the current :py:class:`Density` instance, as well as an array with
|
1548
|
-
translations. The translation corresponds to the shift
|
1549
|
-
center of mass
|
1539
|
+
translations. The translation corresponds to the shift between the original and
|
1540
|
+
current center of mass.
|
1550
1541
|
|
1551
1542
|
>>> import numpy as np
|
1552
1543
|
>>> from tme import Density
|
1553
|
-
>>> dens = Density(np.ones((5,5)))
|
1544
|
+
>>> dens = Density(np.ones((5,5,5)))
|
1554
1545
|
>>> centered_dens, translation = dens.centered(0)
|
1555
1546
|
>>> translation
|
1556
|
-
array([
|
1547
|
+
array([0., 0., 0.])
|
1557
1548
|
|
1558
1549
|
:py:meth:`Density.centered` extended the :py:attr:`Density.data` attribute
|
1559
1550
|
of the current :py:class:`Density` instance and modified
|
1560
1551
|
:py:attr:`Density.origin` accordingly.
|
1561
1552
|
|
1562
1553
|
>>> centered_dens
|
1563
|
-
Origin: (-
|
1564
|
-
|
1565
|
-
:py:meth:`Density.centered` achieves centering via zero-padding
|
1566
|
-
internal :py:attr:`Density.data` attribute
|
1567
|
-
|
1568
|
-
|
1569
|
-
array([[0., 0., 0., 0., 0., 0., 0.],
|
1570
|
-
[0., 1., 1., 1., 1., 1., 0.],
|
1571
|
-
[0., 1., 1., 1., 1., 1., 0.],
|
1572
|
-
[0., 1., 1., 1., 1., 1., 0.],
|
1573
|
-
[0., 1., 1., 1., 1., 1., 0.],
|
1574
|
-
[0., 1., 1., 1., 1., 1., 0.],
|
1575
|
-
[0., 0., 0., 0., 0., 0., 0.]])
|
1576
|
-
|
1577
|
-
`centered_dens` is sufficiently large to represent all rotations that
|
1578
|
-
could be applied to the :py:attr:`Density.data` attribute. Lets look
|
1579
|
-
at a random rotation obtained from
|
1554
|
+
Origin: (-2.0, -2.0, -2.0), sampling_rate: (1, 1, 1), Shape: (9, 9, 9)
|
1555
|
+
|
1556
|
+
:py:meth:`Density.centered` achieves centering via zero-padding and
|
1557
|
+
rigid-transform of the internal :py:attr:`Density.data` attribute.
|
1558
|
+
`centered_dens` is sufficiently large to represent all rotations of the
|
1559
|
+
:py:attr:`Density.data` attribute, such as ones obtained from
|
1580
1560
|
:py:meth:`tme.matching_utils.get_rotation_matrices`.
|
1581
1561
|
|
1582
1562
|
>>> from tme.matching_utils import get_rotation_matrices
|
1583
|
-
>>> rotation_matrix = get_rotation_matrices(dim =
|
1563
|
+
>>> rotation_matrix = get_rotation_matrices(dim = 3 ,angular_sampling = 10)[0]
|
1584
1564
|
>>> rotated_centered_dens = centered_dens.rigid_transform(
|
1585
1565
|
>>> rotation_matrix = rotation_matrix,
|
1586
1566
|
>>> order = None
|
1587
1567
|
>>> )
|
1588
1568
|
>>> print(centered_dens.data.sum(), rotated_centered_dens.data.sum())
|
1589
|
-
|
1569
|
+
125.0 125.0
|
1590
1570
|
|
1591
1571
|
"""
|
1592
1572
|
ret = self.copy()
|
@@ -1595,10 +1575,11 @@ class Density:
|
|
1595
1575
|
ret.adjust_box(box)
|
1596
1576
|
|
1597
1577
|
new_shape = np.maximum(ret.shape, self.shape)
|
1578
|
+
new_shape = np.add(new_shape, 1 - np.mod(new_shape, 2))
|
1598
1579
|
ret.pad(new_shape)
|
1599
1580
|
|
1600
1581
|
center = self.center_of_mass(ret.data, cutoff)
|
1601
|
-
shift = np.subtract(np.divide(ret.shape, 2), center)
|
1582
|
+
shift = np.subtract(np.divide(np.subtract(ret.shape, 1), 2), center)
|
1602
1583
|
|
1603
1584
|
ret = ret.rigid_transform(
|
1604
1585
|
translation=shift,
|
@@ -1606,9 +1587,9 @@ class Density:
|
|
1606
1587
|
use_geometric_center=False,
|
1607
1588
|
order=1,
|
1608
1589
|
)
|
1609
|
-
offset = np.subtract(center, self.center_of_mass(ret.data, cutoff))
|
1610
1590
|
|
1611
|
-
|
1591
|
+
shift = np.subtract(center, self.center_of_mass(ret.data, cutoff))
|
1592
|
+
return ret, shift
|
1612
1593
|
|
1613
1594
|
@classmethod
|
1614
1595
|
def rotate_array(
|
@@ -1731,31 +1712,45 @@ class Density:
|
|
1731
1712
|
Parameters
|
1732
1713
|
----------
|
1733
1714
|
rotation_matrix : NDArray
|
1734
|
-
Rotation matrix to apply
|
1715
|
+
Rotation matrix to apply.
|
1735
1716
|
translation : NDArray
|
1736
|
-
Translation to apply
|
1717
|
+
Translation to apply.
|
1737
1718
|
order : int, optional
|
1738
|
-
|
1719
|
+
Interpolation order to use. Default is 3, has to be in range 0-5.
|
1739
1720
|
use_geometric_center : bool, optional
|
1740
|
-
|
1741
|
-
class instance should be centered using :py:meth:`Density.centered`.
|
1721
|
+
Use geometric or mass center as rotation center.
|
1742
1722
|
|
1743
1723
|
Returns
|
1744
1724
|
-------
|
1745
1725
|
Density
|
1746
|
-
The transformed instance of :py:class:`
|
1726
|
+
The transformed instance of :py:class:`Density`.
|
1747
1727
|
|
1748
1728
|
Examples
|
1749
1729
|
--------
|
1730
|
+
Define the :py:class:`Density` instance
|
1731
|
+
|
1750
1732
|
>>> import numpy as np
|
1751
|
-
>>>
|
1752
|
-
>>>
|
1753
|
-
>>>
|
1733
|
+
>>> from tme import Density
|
1734
|
+
>>> dens = Density(np.arange(9).reshape(3,3).astype(np.float32))
|
1735
|
+
>>> dens, translation = dens.centered(0)
|
1736
|
+
|
1737
|
+
and apply the rotation, in this case a mirror around the z-axis
|
1738
|
+
|
1739
|
+
>>> rotation_matrix = np.eye(dens.data.ndim)
|
1740
|
+
>>> rotation_matrix[0, 0] = -1
|
1741
|
+
>>> dens_transform = dens.rigid_transform(rotation_matrix = rotation_matrix)
|
1742
|
+
>>> dens_transform.data
|
1743
|
+
array([[0. , 0. , 0. , 0. , 0. ],
|
1744
|
+
[0.5 , 3.0833333 , 3.5833333 , 3.3333333 , 0. ],
|
1745
|
+
[0.75 , 4.6666665 , 5.6666665 , 5.4166665 , 0. ],
|
1746
|
+
[0.25 , 1.6666666 , 2.6666667 , 2.9166667 , 0. ],
|
1747
|
+
[0. , 0.08333334, 0.5833333 , 0.8333333 , 0. ]],
|
1748
|
+
dtype=float32)
|
1754
1749
|
|
1755
1750
|
Notes
|
1756
1751
|
-----
|
1757
|
-
:py:
|
1758
|
-
sufficiently sized to
|
1752
|
+
This function assumes the internal :py:attr:`Density.data` attribute is
|
1753
|
+
sufficiently sized to hold the transformation.
|
1759
1754
|
|
1760
1755
|
See Also
|
1761
1756
|
--------
|
@@ -2153,6 +2148,7 @@ class Density:
|
|
2153
2148
|
cutoff_target: float = 0,
|
2154
2149
|
cutoff_template: float = 0,
|
2155
2150
|
scoring_method: str = "NormalizedCrossCorrelation",
|
2151
|
+
**kwargs,
|
2156
2152
|
) -> Tuple["Density", NDArray, NDArray, NDArray]:
|
2157
2153
|
"""
|
2158
2154
|
Aligns two :py:class:`Density` instances target and template and returns
|
@@ -2177,6 +2173,9 @@ class Density:
|
|
2177
2173
|
The scoring method to use for alignment. See
|
2178
2174
|
:py:class:`tme.matching_optimization.create_score_object` for available methods,
|
2179
2175
|
by default "NormalizedCrossCorrelation".
|
2176
|
+
kwargs : dict, optional
|
2177
|
+
Optional keyword arguments passed to
|
2178
|
+
:py:meth:`tme.matching_optimization.optimize_match`.
|
2180
2179
|
|
2181
2180
|
Returns
|
2182
2181
|
-------
|
@@ -2188,8 +2187,18 @@ class Density:
|
|
2188
2187
|
-----
|
2189
2188
|
No densities below cutoff_template are present in the returned Density object.
|
2190
2189
|
"""
|
2190
|
+
from .matching_exhaustive import normalize_under_mask
|
2191
2191
|
from .matching_optimization import optimize_match, create_score_object
|
2192
2192
|
|
2193
|
+
template_mask = template.empty
|
2194
|
+
template_mask.data[:] = 1
|
2195
|
+
|
2196
|
+
normalize_under_mask(
|
2197
|
+
template=template.data,
|
2198
|
+
mask=template_mask.data,
|
2199
|
+
mask_intensity=template_mask.data.sum(),
|
2200
|
+
)
|
2201
|
+
|
2193
2202
|
target_sampling_rate = np.array(target.sampling_rate)
|
2194
2203
|
template_sampling_rate = np.array(template.sampling_rate)
|
2195
2204
|
|
@@ -2224,16 +2233,21 @@ class Density:
|
|
2224
2233
|
).astype(int)
|
2225
2234
|
template_coordinates += mass_center_difference[:, None]
|
2226
2235
|
|
2236
|
+
coordinates_mask = template_mask.to_pointcloud()
|
2237
|
+
coordinates_mask = coordinates_mask * template_scaling[:, None]
|
2238
|
+
coordinates_mask += mass_center_difference[:, None]
|
2239
|
+
|
2227
2240
|
score_object = create_score_object(
|
2228
2241
|
score=scoring_method,
|
2229
2242
|
target=target.data,
|
2230
2243
|
template_coordinates=template_coordinates,
|
2244
|
+
template_mask_coordinates=coordinates_mask,
|
2231
2245
|
template_weights=template_weights,
|
2232
2246
|
sampling_rate=np.ones(template.data.ndim),
|
2233
2247
|
)
|
2234
2248
|
|
2235
2249
|
translation, rotation_matrix, score = optimize_match(
|
2236
|
-
score_object=score_object,
|
2250
|
+
score_object=score_object, **kwargs
|
2237
2251
|
)
|
2238
2252
|
|
2239
2253
|
translation += mass_center_difference
|
@@ -2255,6 +2269,8 @@ class Density:
|
|
2255
2269
|
template: "Structure",
|
2256
2270
|
cutoff_target: float = 0,
|
2257
2271
|
scoring_method: str = "NormalizedCrossCorrelation",
|
2272
|
+
optimization_method: str = "basinhopping",
|
2273
|
+
maxiter: int = 500,
|
2258
2274
|
) -> Tuple["Structure", NDArray, NDArray]:
|
2259
2275
|
"""
|
2260
2276
|
Aligns a :py:class:`tme.structure.Structure` template to :py:class:`Density`
|
@@ -2279,6 +2295,12 @@ class Density:
|
|
2279
2295
|
The scoring method to use for template matching. See
|
2280
2296
|
:py:class:`tme.matching_optimization.create_score_object` for available methods,
|
2281
2297
|
by default "NormalizedCrossCorrelation".
|
2298
|
+
optimization_method : str, optional
|
2299
|
+
Optimizer that is used.
|
2300
|
+
See :py:meth:`tme.matching_optimization.optimize_match`.
|
2301
|
+
maxiter : int, optional
|
2302
|
+
Maximum number of iterations for the optimizer.
|
2303
|
+
See :py:meth:`tme.matching_optimization.optimize_match`.
|
2282
2304
|
|
2283
2305
|
Returns
|
2284
2306
|
-------
|
@@ -2302,18 +2324,17 @@ class Density:
|
|
2302
2324
|
cutoff_target=cutoff_target,
|
2303
2325
|
cutoff_template=0,
|
2304
2326
|
scoring_method=scoring_method,
|
2327
|
+
optimization_method=optimization_method,
|
2328
|
+
maxiter=maxiter,
|
2305
2329
|
)
|
2306
2330
|
out = template.copy()
|
2307
|
-
final_translation = np.
|
2308
|
-
-template_density.origin,
|
2309
|
-
np.multiply(translation, template_density.sampling_rate),
|
2310
|
-
)
|
2331
|
+
final_translation = np.subtract(ret.origin, template_density.origin)
|
2311
2332
|
|
2312
2333
|
# Atom coordinates are in xyz
|
2313
2334
|
final_translation = final_translation[::-1]
|
2314
2335
|
rotation_matrix = rotation_matrix[::-1, ::-1]
|
2315
2336
|
|
2316
|
-
out.rigid_transform(
|
2337
|
+
out = out.rigid_transform(
|
2317
2338
|
translation=final_translation, rotation_matrix=rotation_matrix
|
2318
2339
|
)
|
2319
2340
|
|
Binary file
|
tme/matching_data.py
CHANGED
@@ -22,20 +22,44 @@ class MatchingData:
|
|
22
22
|
|
23
23
|
Parameters
|
24
24
|
----------
|
25
|
-
target : np.ndarray or Density
|
26
|
-
Target data
|
27
|
-
template : np.ndarray or Density
|
28
|
-
Template data
|
25
|
+
target : np.ndarray or :py:class:`tme.density.Density`
|
26
|
+
Target data.
|
27
|
+
template : np.ndarray or :py:class:`tme.density.Density`
|
28
|
+
Template data.
|
29
|
+
target_mask : np.ndarray or :py:class:`tme.density.Density`, optional
|
30
|
+
Target mask data.
|
31
|
+
template_mask : np.ndarray or :py:class:`tme.density.Density`, optional
|
32
|
+
Template mask data.
|
33
|
+
invert_target : bool, optional
|
34
|
+
Whether to invert and rescale the target before template matching..
|
35
|
+
rotations: np.ndarray, optional
|
36
|
+
Template rotations to sample. Can be a single (d x d) or a stack (n x d x d)
|
37
|
+
of rotation matrices where d is the dimension of the template.
|
38
|
+
|
39
|
+
Examples
|
40
|
+
--------
|
41
|
+
The following achieves the minimal definition of a :py:class:`MatchingData` instance.
|
42
|
+
|
43
|
+
>>> import numpy as np
|
44
|
+
>>> from tme.matching_data import MatchingData
|
45
|
+
>>> target = np.random.rand(50,40,60)
|
46
|
+
>>> template = target[15:25, 10:20, 30:40]
|
47
|
+
>>> matching_data = MatchingData(target=target, template=template)
|
29
48
|
|
30
49
|
"""
|
31
50
|
|
32
|
-
def __init__(
|
33
|
-
self
|
34
|
-
|
35
|
-
|
51
|
+
def __init__(
|
52
|
+
self,
|
53
|
+
target: NDArray,
|
54
|
+
template: NDArray,
|
55
|
+
template_mask: NDArray = None,
|
56
|
+
target_mask: NDArray = None,
|
57
|
+
invert_target: bool = False,
|
58
|
+
rotations: NDArray = None,
|
59
|
+
):
|
36
60
|
self._target = target
|
37
|
-
self._target_mask =
|
38
|
-
self._template_mask =
|
61
|
+
self._target_mask = target_mask
|
62
|
+
self._template_mask = template_mask
|
39
63
|
self._translation_offset = np.zeros(len(target.shape), dtype=int)
|
40
64
|
|
41
65
|
self.template = template
|
@@ -46,8 +70,11 @@ class MatchingData:
|
|
46
70
|
self.template_filter = {}
|
47
71
|
self.target_filter = {}
|
48
72
|
|
49
|
-
self._invert_target =
|
50
|
-
|
73
|
+
self._invert_target = invert_target
|
74
|
+
|
75
|
+
self._rotations = rotations
|
76
|
+
if rotations is not None:
|
77
|
+
self.rotations = rotations
|
51
78
|
|
52
79
|
self._set_batch_dimension()
|
53
80
|
|
@@ -257,7 +284,7 @@ class MatchingData:
|
|
257
284
|
template_offset[(template_offset.size - len(template_slice)) :] = [
|
258
285
|
x.start for x in template_slice
|
259
286
|
]
|
260
|
-
ret._translation_offset =
|
287
|
+
ret._translation_offset = target_offset
|
261
288
|
|
262
289
|
ret.template_filter = self.template_filter
|
263
290
|
ret.target_filter = self.target_filter
|
@@ -266,7 +293,7 @@ class MatchingData:
|
|
266
293
|
ret._invert_target = self._invert_target
|
267
294
|
|
268
295
|
if self._target_mask is not None:
|
269
|
-
ret.
|
296
|
+
ret._target_mask = self.subset_array(
|
270
297
|
arr=self._target_mask, arr_slice=target_slice, padding=target_pad
|
271
298
|
)
|
272
299
|
if self._template_mask is not None:
|
@@ -289,19 +316,29 @@ class MatchingData:
|
|
289
316
|
|
290
317
|
def to_backend(self) -> None:
|
291
318
|
"""
|
292
|
-
Transfer
|
319
|
+
Transfer and convert types of class instance's data arrays to the current backend
|
293
320
|
"""
|
294
|
-
backend_arr = type(backend.zeros((1), dtype=backend.
|
321
|
+
backend_arr = type(backend.zeros((1), dtype=backend._float_dtype))
|
295
322
|
for attr_name, attr_value in vars(self).items():
|
323
|
+
converted_array = None
|
296
324
|
if isinstance(attr_value, np.ndarray):
|
297
325
|
converted_array = backend.to_backend_array(attr_value.copy())
|
298
|
-
setattr(self, attr_name, converted_array)
|
299
326
|
elif isinstance(attr_value, backend_arr):
|
300
327
|
converted_array = backend.to_backend_array(attr_value)
|
301
|
-
|
328
|
+
else:
|
329
|
+
continue
|
330
|
+
|
331
|
+
current_dtype = backend.get_fundamental_dtype(converted_array)
|
332
|
+
target_dtype = backend._fundamental_dtypes[current_dtype]
|
333
|
+
|
334
|
+
# Optional, but scores are float so we avoid casting and potential issues
|
335
|
+
if attr_name in ("_template", "_template_mask", "_target", "_target_mask"):
|
336
|
+
target_dtype = backend._float_dtype
|
302
337
|
|
303
|
-
|
304
|
-
|
338
|
+
if target_dtype != current_dtype:
|
339
|
+
converted_array = backend.astype(converted_array, target_dtype)
|
340
|
+
|
341
|
+
setattr(self, attr_name, converted_array)
|
305
342
|
|
306
343
|
def _set_batch_dimension(
|
307
344
|
self, target_dims: Tuple[int] = None, template_dims: Tuple[int] = None
|
@@ -350,12 +387,14 @@ class MatchingData:
|
|
350
387
|
matching_dims = target_measurement_dims + batch_dims
|
351
388
|
|
352
389
|
target_shape = backend.full(
|
353
|
-
shape=(matching_dims,), fill_value=1, dtype=backend.
|
390
|
+
shape=(matching_dims,), fill_value=1, dtype=backend._int_dtype
|
354
391
|
)
|
355
392
|
template_shape = backend.full(
|
356
|
-
shape=(matching_dims,), fill_value=1, dtype=backend.
|
393
|
+
shape=(matching_dims,), fill_value=1, dtype=backend._int_dtype
|
394
|
+
)
|
395
|
+
batch_mask = backend.full(
|
396
|
+
shape=(matching_dims,), fill_value=1, dtype=backend._int_dtype
|
357
397
|
)
|
358
|
-
batch_mask = backend.full(shape=(matching_dims,), fill_value=1, dtype=bool)
|
359
398
|
|
360
399
|
target_index, template_index = 0, 0
|
361
400
|
for k in range(matching_dims):
|
@@ -440,7 +479,7 @@ class MatchingData:
|
|
440
479
|
An array indicating the padding for each dimension of the target.
|
441
480
|
"""
|
442
481
|
target_padding = backend.zeros(
|
443
|
-
len(self._output_target_shape), dtype=backend.
|
482
|
+
len(self._output_target_shape), dtype=backend._int_dtype
|
444
483
|
)
|
445
484
|
|
446
485
|
if pad_target:
|
@@ -491,13 +530,14 @@ class MatchingData:
|
|
491
530
|
fourier_pad = backend.full(
|
492
531
|
shape=(len(fourier_pad),),
|
493
532
|
fill_value=1,
|
494
|
-
dtype=backend.
|
533
|
+
dtype=backend._int_dtype,
|
495
534
|
)
|
496
535
|
|
497
536
|
fourier_pad = backend.to_backend_array(fourier_pad)
|
498
537
|
if hasattr(self, "_batch_mask"):
|
499
538
|
batch_mask = backend.to_backend_array(self._batch_mask)
|
500
|
-
fourier_pad
|
539
|
+
backend.multiply(fourier_pad, 1 - batch_mask, out=fourier_pad)
|
540
|
+
backend.add(fourier_pad, batch_mask, out=fourier_pad)
|
501
541
|
|
502
542
|
pad_shape = backend.maximum(target_shape, template_shape)
|
503
543
|
ret = backend.compute_convolution_shapes(pad_shape, fourier_pad)
|
@@ -510,10 +550,12 @@ class MatchingData:
|
|
510
550
|
|
511
551
|
if hasattr(self, "_batch_mask"):
|
512
552
|
batch_mask = backend.to_backend_array(self._batch_mask)
|
513
|
-
shape_diff
|
553
|
+
backend.multiply(shape_diff, 1 - batch_mask, out=shape_diff)
|
514
554
|
|
515
555
|
backend.add(fourier_shift, shape_diff, out=fourier_shift)
|
516
556
|
|
557
|
+
fourier_shift = backend.astype(fourier_shift, backend._int_dtype)
|
558
|
+
|
517
559
|
return fast_shape, fast_ft_shape, fourier_shift
|
518
560
|
|
519
561
|
@property
|
@@ -540,27 +582,27 @@ class MatchingData:
|
|
540
582
|
pass
|
541
583
|
else:
|
542
584
|
raise ValueError("Rotations have to be a rank 2 or 3 array.")
|
543
|
-
self._rotations = rotations.astype(
|
585
|
+
self._rotations = rotations.astype(np.float32)
|
544
586
|
|
545
587
|
@property
|
546
588
|
def target(self):
|
547
|
-
"""Returns the target
|
589
|
+
"""Returns the target."""
|
590
|
+
target = self._target
|
548
591
|
if isinstance(self._target, Density):
|
549
592
|
target = self._target.data
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
return target.reshape(tuple(int(x) for x in out_shape))
|
593
|
+
|
594
|
+
out_shape = tuple(int(x) for x in self._output_target_shape)
|
595
|
+
return target.reshape(out_shape)
|
554
596
|
|
555
597
|
@property
|
556
598
|
def template(self):
|
557
|
-
"""Returns the reversed template
|
599
|
+
"""Returns the reversed template."""
|
558
600
|
template = self._template
|
559
601
|
if isinstance(self._template, Density):
|
560
602
|
template = self._template.data
|
561
603
|
template = backend.reverse(template)
|
562
|
-
out_shape =
|
563
|
-
return template.reshape(
|
604
|
+
out_shape = tuple(int(x) for x in self._output_template_shape)
|
605
|
+
return template.reshape(out_shape)
|
564
606
|
|
565
607
|
@template.setter
|
566
608
|
def template(self, template: NDArray):
|
@@ -580,9 +622,7 @@ class MatchingData:
|
|
580
622
|
shape=template.shape, dtype=float, fill_value=1
|
581
623
|
)
|
582
624
|
|
583
|
-
|
584
|
-
template = template.data
|
585
|
-
self._template = template.astype(self._default_dtype, copy=False)
|
625
|
+
self._template = template
|
586
626
|
|
587
627
|
@property
|
588
628
|
def target_mask(self):
|
@@ -592,8 +632,8 @@ class MatchingData:
|
|
592
632
|
target_mask = self._target_mask.data
|
593
633
|
|
594
634
|
if target_mask is not None:
|
595
|
-
out_shape =
|
596
|
-
target_mask = target_mask.reshape(
|
635
|
+
out_shape = tuple(int(x) for x in self._output_target_shape)
|
636
|
+
target_mask = target_mask.reshape(out_shape)
|
597
637
|
|
598
638
|
return target_mask
|
599
639
|
|
@@ -603,14 +643,7 @@ class MatchingData:
|
|
603
643
|
if not np.all(self.target.shape == mask.shape):
|
604
644
|
raise ValueError("Target and its mask have to have the same shape.")
|
605
645
|
|
606
|
-
|
607
|
-
mask.data = mask.data.astype(self._default_dtype, copy=False)
|
608
|
-
self._target_mask = mask
|
609
|
-
self._targetmaskshape = self._target_mask.shape[::-1]
|
610
|
-
return None
|
611
|
-
|
612
|
-
self._target_mask = mask.astype(self._default_dtype, copy=False)
|
613
|
-
self._targetmaskshape = self._target_mask.shape
|
646
|
+
self._target_mask = mask
|
614
647
|
|
615
648
|
@property
|
616
649
|
def template_mask(self):
|
@@ -628,24 +661,17 @@ class MatchingData:
|
|
628
661
|
|
629
662
|
if mask is not None:
|
630
663
|
mask = backend.reverse(mask)
|
631
|
-
out_shape =
|
632
|
-
mask = mask.reshape(
|
664
|
+
out_shape = tuple(int(x) for x in self._output_template_shape)
|
665
|
+
mask = mask.reshape(out_shape)
|
633
666
|
return mask
|
634
667
|
|
635
668
|
@template_mask.setter
|
636
669
|
def template_mask(self, mask: NDArray):
|
637
670
|
"""Returns the reversed template mask NDArray."""
|
638
671
|
if not np.all(self._templateshape[::-1] == mask.shape):
|
639
|
-
raise ValueError("
|
640
|
-
|
641
|
-
if type(mask) == Density:
|
642
|
-
mask.data = mask.data.astype(self._default_dtype, copy=False)
|
643
|
-
self._template_mask = mask
|
644
|
-
self._templatemaskshape = self._template_mask.shape[::-1]
|
645
|
-
return None
|
672
|
+
raise ValueError("Template and its mask have to have the same shape.")
|
646
673
|
|
647
|
-
self._template_mask = mask
|
648
|
-
self._templatemaskshape = self._template_mask.shape[::-1]
|
674
|
+
self._template_mask = mask
|
649
675
|
|
650
676
|
def _split_rotations_on_jobs(self, n_jobs: int) -> List[NDArray]:
|
651
677
|
"""
|