pytme 0.1.9__cp311-cp311-macosx_14_0_arm64.whl → 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.1.9.data → pytme-0.2.0b0.data}/scripts/match_template.py +148 -126
- pytme-0.2.0b0.data/scripts/postprocess.py +570 -0
- {pytme-0.1.9.data → pytme-0.2.0b0.data}/scripts/preprocessor_gui.py +244 -60
- {pytme-0.1.9.dist-info → pytme-0.2.0b0.dist-info}/METADATA +3 -1
- pytme-0.2.0b0.dist-info/RECORD +66 -0
- {pytme-0.1.9.dist-info → pytme-0.2.0b0.dist-info}/WHEEL +1 -1
- scripts/extract_candidates.py +218 -0
- scripts/match_template.py +148 -126
- scripts/match_template_filters.py +852 -0
- scripts/postprocess.py +380 -435
- scripts/preprocessor_gui.py +244 -60
- scripts/refine_matches.py +218 -0
- tme/__init__.py +2 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +545 -78
- tme/backends/cupy_backend.py +80 -15
- tme/backends/npfftw_backend.py +33 -2
- tme/backends/pytorch_backend.py +15 -7
- tme/density.py +156 -63
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_constrained.py +195 -0
- tme/matching_data.py +74 -33
- tme/matching_exhaustive.py +351 -208
- tme/matching_memory.py +1 -0
- tme/matching_optimization.py +728 -651
- tme/matching_utils.py +152 -8
- tme/orientations.py +561 -0
- tme/preprocessor.py +21 -18
- tme/structure.py +2 -37
- pytme-0.1.9.data/scripts/postprocess.py +0 -625
- pytme-0.1.9.dist-info/RECORD +0 -61
- {pytme-0.1.9.data → pytme-0.2.0b0.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.1.9.data → pytme-0.2.0b0.data}/scripts/preprocess.py +0 -0
- {pytme-0.1.9.dist-info → pytme-0.2.0b0.dist-info}/LICENSE +0 -0
- {pytme-0.1.9.dist-info → pytme-0.2.0b0.dist-info}/entry_points.txt +0 -0
- {pytme-0.1.9.dist-info → pytme-0.2.0b0.dist-info}/top_level.txt +0 -0
tme/density.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
"""
|
1
|
+
""" Representation of N-dimensional densities
|
2
2
|
|
3
3
|
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
4
|
|
@@ -12,6 +12,7 @@ from gzip import open as gzip_open
|
|
12
12
|
from typing import Tuple, Dict, Set
|
13
13
|
from os.path import splitext, basename
|
14
14
|
|
15
|
+
import h5py
|
15
16
|
import mrcfile
|
16
17
|
import numpy as np
|
17
18
|
import skimage.io as skio
|
@@ -26,7 +27,6 @@ from scipy.ndimage import (
|
|
26
27
|
)
|
27
28
|
from scipy.spatial import ConvexHull
|
28
29
|
|
29
|
-
from .matching_optimization import FitRefinement
|
30
30
|
from .structure import Structure
|
31
31
|
from .matching_utils import (
|
32
32
|
minimum_enclosing_box,
|
@@ -129,8 +129,8 @@ class Density:
|
|
129
129
|
Parameters
|
130
130
|
----------
|
131
131
|
filename : str
|
132
|
-
Path to a file in CCP4/MRC, EM or a format supported by
|
133
|
-
The file can be gzip compressed.
|
132
|
+
Path to a file in CCP4/MRC, EM, HDF5 or a format supported by
|
133
|
+
skimage.io.imread. The file can be gzip compressed.
|
134
134
|
subset : tuple of slices, optional
|
135
135
|
Slices representing the desired subset along each dimension.
|
136
136
|
use_memmap : bool, optional
|
@@ -185,8 +185,9 @@ class Density:
|
|
185
185
|
|
186
186
|
Notes
|
187
187
|
-----
|
188
|
-
If ``filename`` ends with ".em" or ".em.gz" the method will parse it as EM file
|
189
|
-
|
188
|
+
If ``filename`` ends with ".em" or ".em.gz" the method will parse it as EM file,
|
189
|
+
if it ends with "h5" or "h5.gz" the method will parse the file as HDF5.
|
190
|
+
Otherwise the method defaults to the CCP4/MRC format and on failure, switches to
|
190
191
|
:obj:`skimage.io.imread` regardless of the extension. Currently, the later does not
|
191
192
|
extract origin or sampling_rate information from the file.
|
192
193
|
|
@@ -197,8 +198,10 @@ class Density:
|
|
197
198
|
"""
|
198
199
|
try:
|
199
200
|
func = cls._load_mrc
|
200
|
-
if filename.endswith("
|
201
|
+
if filename.endswith("em") or filename.endswith("em.gz"):
|
201
202
|
func = cls._load_em
|
203
|
+
elif filename.endswith("h5") or filename.endswith("h5.gz"):
|
204
|
+
func = cls._load_hdf5
|
202
205
|
data, origin, sampling_rate, meta = func(
|
203
206
|
filename=filename, subset=subset, use_memmap=use_memmap
|
204
207
|
)
|
@@ -213,7 +216,7 @@ class Density:
|
|
213
216
|
@classmethod
|
214
217
|
def _load_mrc(
|
215
218
|
cls, filename: str, subset: Tuple[int] = None, use_memmap: bool = False
|
216
|
-
) -> Tuple[NDArray]:
|
219
|
+
) -> Tuple[NDArray, NDArray, NDArray, Dict]:
|
217
220
|
"""
|
218
221
|
Extracts data from a CCP4/MRC file.
|
219
222
|
|
@@ -228,12 +231,8 @@ class Density:
|
|
228
231
|
|
229
232
|
Returns
|
230
233
|
-------
|
231
|
-
NDArray
|
232
|
-
|
233
|
-
NDArray
|
234
|
-
The coordinate origin of the data.
|
235
|
-
NDArray
|
236
|
-
The sampling rate of the data.
|
234
|
+
Tuple[NDArray, NDArray, NDArray, Dict]
|
235
|
+
File data, coordinate origin, sampling rate array and metadata dictionary.
|
237
236
|
|
238
237
|
References
|
239
238
|
----------
|
@@ -355,7 +354,7 @@ class Density:
|
|
355
354
|
@classmethod
|
356
355
|
def _load_em(
|
357
356
|
cls, filename: str, subset: Tuple[int] = None, use_memmap: bool = False
|
358
|
-
) -> Tuple[NDArray]:
|
357
|
+
) -> Tuple[NDArray, NDArray, NDArray, Dict]:
|
359
358
|
"""
|
360
359
|
Extracts data from a EM file.
|
361
360
|
|
@@ -370,12 +369,8 @@ class Density:
|
|
370
369
|
|
371
370
|
Returns
|
372
371
|
-------
|
373
|
-
NDArray
|
374
|
-
|
375
|
-
NDArray
|
376
|
-
The coordinate origin of the data.
|
377
|
-
NDArray
|
378
|
-
The sampling rate of the data.
|
372
|
+
Tuple[NDArray, NDArray, NDArray, Dict]
|
373
|
+
File data, coordinate origin, sampling rate array and metadata dictionary.
|
379
374
|
|
380
375
|
References
|
381
376
|
----------
|
@@ -383,11 +378,11 @@ class Density:
|
|
383
378
|
|
384
379
|
Warns
|
385
380
|
-----
|
386
|
-
|
381
|
+
If the sampling rate is zero.
|
387
382
|
|
388
383
|
Notes
|
389
384
|
-----
|
390
|
-
A
|
385
|
+
A sampling rate of zero will be treated as missing value and changed to one. This
|
391
386
|
function does not yet extract an origin like :py:meth:`Density._load_mrc`.
|
392
387
|
|
393
388
|
See Also
|
@@ -483,10 +478,15 @@ class Density:
|
|
483
478
|
f"Expected length of slices : {n_dims}, got : {len(slices)}"
|
484
479
|
)
|
485
480
|
|
486
|
-
if any(
|
481
|
+
if any(
|
482
|
+
[
|
483
|
+
slices[i].stop > shape[i] or slices[i].start > shape[i]
|
484
|
+
for i in range(n_dims)
|
485
|
+
]
|
486
|
+
):
|
487
487
|
raise ValueError(f"Subset exceeds data dimensions ({shape}).")
|
488
488
|
|
489
|
-
if any([slices[i].stop < 0 for i in range(n_dims)]):
|
489
|
+
if any([slices[i].stop < 0 or slices[i].start < 0 for i in range(n_dims)]):
|
490
490
|
raise ValueError("Subsets have to be non-negative.")
|
491
491
|
|
492
492
|
@classmethod
|
@@ -560,7 +560,7 @@ class Density:
|
|
560
560
|
return subset_data
|
561
561
|
|
562
562
|
@staticmethod
|
563
|
-
def _load_skio(filename: str) -> Tuple[NDArray]:
|
563
|
+
def _load_skio(filename: str) -> Tuple[NDArray, NDArray, NDArray, Dict]:
|
564
564
|
"""
|
565
565
|
Uses :obj:`skimage.io.imread` to extract data from filename [1]_.
|
566
566
|
|
@@ -571,12 +571,8 @@ class Density:
|
|
571
571
|
|
572
572
|
Returns
|
573
573
|
-------
|
574
|
-
NDArray
|
575
|
-
|
576
|
-
NDArray
|
577
|
-
The coordinate origin of the data.
|
578
|
-
NDArray
|
579
|
-
The sampling rate of the data.
|
574
|
+
Tuple[NDArray, NDArray, NDArray, Dict]
|
575
|
+
File data, coordinate origin, sampling rate array and metadata dictionary.
|
580
576
|
|
581
577
|
References
|
582
578
|
----------
|
@@ -601,6 +597,51 @@ class Density:
|
|
601
597
|
)
|
602
598
|
return data, np.zeros(data.ndim), np.ones(data.ndim), {}
|
603
599
|
|
600
|
+
@staticmethod
|
601
|
+
def _load_hdf5(
|
602
|
+
filename: str, subset: Tuple[slice], use_memmap: bool = False, **kwargs
|
603
|
+
) -> "Density":
|
604
|
+
"""
|
605
|
+
Extracts data from an H5 file.
|
606
|
+
|
607
|
+
Parameters
|
608
|
+
----------
|
609
|
+
filename : str
|
610
|
+
Path to a file in CCP4/MRC format.
|
611
|
+
subset : tuple of slices, optional
|
612
|
+
Slices representing the desired subset along each dimension.
|
613
|
+
use_memmap : bool, optional
|
614
|
+
Whether the Density objects data attribute should be memmory mapped.
|
615
|
+
|
616
|
+
Returns
|
617
|
+
-------
|
618
|
+
Density
|
619
|
+
An instance of the Density class populated with the data from the HDF5 file.
|
620
|
+
|
621
|
+
See Also
|
622
|
+
--------
|
623
|
+
:py:meth:`Density._save_hdf5`
|
624
|
+
"""
|
625
|
+
subset = ... if subset is None else subset
|
626
|
+
|
627
|
+
with h5py.File(filename, mode="r") as infile:
|
628
|
+
data = infile["data"]
|
629
|
+
data_attributes = [
|
630
|
+
infile["data"].id.get_offset(),
|
631
|
+
infile["data"].shape,
|
632
|
+
infile["data"].dtype,
|
633
|
+
]
|
634
|
+
origin = infile["origin"][...].copy()
|
635
|
+
sampling_rate = infile["sampling_rate"][...].copy()
|
636
|
+
metadata = {key: val for key, val in infile.attrs.items()}
|
637
|
+
if not use_memmap:
|
638
|
+
return data[subset], origin, sampling_rate, metadata
|
639
|
+
|
640
|
+
offset, shape, dtype = data_attributes
|
641
|
+
data = np.memmap(filename, dtype=dtype, shape=shape, offset=offset)[subset]
|
642
|
+
|
643
|
+
return data, origin, sampling_rate, metadata
|
644
|
+
|
604
645
|
@classmethod
|
605
646
|
def from_structure(
|
606
647
|
cls,
|
@@ -798,9 +839,9 @@ class Density:
|
|
798
839
|
|
799
840
|
Notes
|
800
841
|
-----
|
801
|
-
If ``filename`` ends with "
|
802
|
-
Otherwise,
|
803
|
-
to :obj:`skimage.io.imsave`.
|
842
|
+
If ``filename`` ends with "em" or "em.gz" will create an EM file, "h5" or
|
843
|
+
"h5.gz" will create a HDF5 file. Otherwise, the method defaults to the CCP4/MRC
|
844
|
+
format, and on failure, falls back to :obj:`skimage.io.imsave`.
|
804
845
|
|
805
846
|
See Also
|
806
847
|
--------
|
@@ -811,13 +852,15 @@ class Density:
|
|
811
852
|
|
812
853
|
try:
|
813
854
|
func = self._save_mrc
|
814
|
-
if filename.endswith("
|
855
|
+
if filename.endswith("em") or filename.endswith("em.gz"):
|
815
856
|
func = self._save_em
|
857
|
+
elif filename.endswith("h5") or filename.endswith("h5.gz"):
|
858
|
+
func = self._save_hdf5
|
816
859
|
_ = func(filename=filename, gzip=gzip)
|
817
860
|
except ValueError:
|
818
861
|
_ = self._save_skio(filename=filename, gzip=gzip)
|
819
862
|
|
820
|
-
def _save_mrc(self, filename: str, gzip: bool) -> None:
|
863
|
+
def _save_mrc(self, filename: str, gzip: bool = False) -> None:
|
821
864
|
"""
|
822
865
|
Writes current class instance to disk as mrc file.
|
823
866
|
|
@@ -843,20 +886,16 @@ class Density:
|
|
843
886
|
mrc.header["origin"] = tuple(self.origin[::-1])
|
844
887
|
mrc.voxel_size = tuple(self.sampling_rate[::-1])
|
845
888
|
|
846
|
-
def _save_em(self, filename: str, gzip: bool) -> None:
|
889
|
+
def _save_em(self, filename: str, gzip: bool = False) -> None:
|
847
890
|
"""
|
848
891
|
Writes data to disk as an .em file.
|
849
892
|
|
850
893
|
Parameters
|
851
894
|
----------
|
852
895
|
filename : str
|
853
|
-
Path to write
|
854
|
-
|
855
|
-
|
856
|
-
origin : NDArray
|
857
|
-
Coordinate origin of the data.
|
858
|
-
sampling_rate : NDArray
|
859
|
-
Sampling rate of the data.
|
896
|
+
Path to write to.
|
897
|
+
gzip : bool, optional
|
898
|
+
If True, the output will be gzip compressed.
|
860
899
|
|
861
900
|
References
|
862
901
|
----------
|
@@ -886,7 +925,7 @@ class Density:
|
|
886
925
|
f.write(b" " * 256)
|
887
926
|
f.write(self.data.tobytes())
|
888
927
|
|
889
|
-
def _save_skio(self, filename: str, gzip: bool) -> None:
|
928
|
+
def _save_skio(self, filename: str, gzip: bool = False) -> None:
|
890
929
|
"""
|
891
930
|
Uses :obj:`skimage.io.imsave` to write data to filename [1]_.
|
892
931
|
|
@@ -904,12 +943,54 @@ class Density:
|
|
904
943
|
swap, kwargs = filename, {}
|
905
944
|
if gzip:
|
906
945
|
swap = BytesIO()
|
907
|
-
kwargs["format"] = splitext(basename(filename.replace(".gz", "")))[
|
946
|
+
kwargs["format"] = splitext(basename(filename.replace(".gz", "")))[
|
947
|
+
1
|
948
|
+
].replace(".", "")
|
908
949
|
skio.imsave(fname=swap, arr=self.data.astype("float32"), **kwargs)
|
909
950
|
if gzip:
|
910
951
|
with gzip_open(filename, "wb") as outfile:
|
911
952
|
outfile.write(swap.getvalue())
|
912
953
|
|
954
|
+
def _save_hdf5(self, filename: str, gzip: bool = False) -> None:
|
955
|
+
"""
|
956
|
+
Saves the Density instance data to an HDF5 file, with optional compression.
|
957
|
+
|
958
|
+
Parameters
|
959
|
+
----------
|
960
|
+
filename : str
|
961
|
+
Path to write to.
|
962
|
+
gzip : bool, optional
|
963
|
+
If True, the output will be gzip compressed.
|
964
|
+
|
965
|
+
See Also
|
966
|
+
--------
|
967
|
+
:py:meth:`Density._load_hdf5`
|
968
|
+
"""
|
969
|
+
compression = "gzip" if gzip else None
|
970
|
+
with h5py.File(filename, mode="w") as f:
|
971
|
+
f.create_dataset(
|
972
|
+
"data",
|
973
|
+
data=self.data,
|
974
|
+
shape=self.data.shape,
|
975
|
+
dtype=self.data.dtype,
|
976
|
+
compression=compression,
|
977
|
+
)
|
978
|
+
f.create_dataset("origin", data=self.origin)
|
979
|
+
f.create_dataset("sampling_rate", data=self.sampling_rate)
|
980
|
+
|
981
|
+
self.metadata["mean"] = self.metadata.get("mean", 0)
|
982
|
+
self.metadata["std"] = self.metadata.get("std", 0)
|
983
|
+
self.metadata["min"] = self.metadata.get("min", 0)
|
984
|
+
self.metadata["max"] = self.metadata.get("max", 0)
|
985
|
+
if type(self.data) != np.memmap:
|
986
|
+
self.metadata["mean"] = self.data.mean()
|
987
|
+
self.metadata["std"] = self.data.std()
|
988
|
+
self.metadata["min"] = self.data.min()
|
989
|
+
self.metadata["max"] = self.data.max()
|
990
|
+
|
991
|
+
for key, val in self.metadata.items():
|
992
|
+
f.attrs[key] = val
|
993
|
+
|
913
994
|
@property
|
914
995
|
def empty(self) -> "Density":
|
915
996
|
"""
|
@@ -1098,7 +1179,13 @@ class Density:
|
|
1098
1179
|
@property
|
1099
1180
|
def sampling_rate(self) -> NDArray:
|
1100
1181
|
"""
|
1101
|
-
Returns
|
1182
|
+
Returns the value of the current instance's :py:attr:`Density.sampling_rate`
|
1183
|
+
attribute.
|
1184
|
+
|
1185
|
+
Returns
|
1186
|
+
-------
|
1187
|
+
NDArray
|
1188
|
+
Sampling rate along axis.
|
1102
1189
|
"""
|
1103
1190
|
return self._sampling_rate
|
1104
1191
|
|
@@ -1114,7 +1201,12 @@ class Density:
|
|
1114
1201
|
@property
|
1115
1202
|
def metadata(self) -> Dict:
|
1116
1203
|
"""
|
1117
|
-
Returns
|
1204
|
+
Returns the current instance's :py:attr:`Density.metadata` dictionary attribute.
|
1205
|
+
|
1206
|
+
Returns
|
1207
|
+
-------
|
1208
|
+
Dict
|
1209
|
+
Metadata dictionary. Empty by default.
|
1118
1210
|
"""
|
1119
1211
|
return self._metadata
|
1120
1212
|
|
@@ -2068,7 +2160,7 @@ class Density:
|
|
2068
2160
|
|
2069
2161
|
If voxel sizes of target and template dont match coordinates are scaled
|
2070
2162
|
to the numerically smaller voxel size. Instances are prealigned based on their
|
2071
|
-
center of mass. Finally :py:
|
2163
|
+
center of mass. Finally :py:meth:`tme.matching_optimization.optimize_match` is
|
2072
2164
|
used to determine translation and rotation to map template to target.
|
2073
2165
|
|
2074
2166
|
Parameters
|
@@ -2083,7 +2175,7 @@ class Density:
|
|
2083
2175
|
The cutoff value for the template map, by default 0.
|
2084
2176
|
scoring_method : str, optional
|
2085
2177
|
The scoring method to use for alignment. See
|
2086
|
-
:py:class:`tme.matching_optimization.
|
2178
|
+
:py:class:`tme.matching_optimization.create_score_object` for available methods,
|
2087
2179
|
by default "NormalizedCrossCorrelation".
|
2088
2180
|
|
2089
2181
|
Returns
|
@@ -2096,6 +2188,8 @@ class Density:
|
|
2096
2188
|
-----
|
2097
2189
|
No densities below cutoff_template are present in the returned Density object.
|
2098
2190
|
"""
|
2191
|
+
from .matching_optimization import optimize_match, create_score_object
|
2192
|
+
|
2099
2193
|
target_sampling_rate = np.array(target.sampling_rate)
|
2100
2194
|
template_sampling_rate = np.array(template.sampling_rate)
|
2101
2195
|
|
@@ -2105,7 +2199,6 @@ class Density:
|
|
2105
2199
|
template_sampling_rate = np.repeat(
|
2106
2200
|
template_sampling_rate, template.data.ndim // template_sampling_rate.size
|
2107
2201
|
)
|
2108
|
-
|
2109
2202
|
if not np.allclose(target_sampling_rate, template_sampling_rate):
|
2110
2203
|
print(
|
2111
2204
|
"Voxel size of target and template do not match. "
|
@@ -2113,7 +2206,6 @@ class Density:
|
|
2113
2206
|
)
|
2114
2207
|
|
2115
2208
|
target_coordinates = target.to_pointcloud(cutoff_target)
|
2116
|
-
target_weights = target.data[tuple(target_coordinates)]
|
2117
2209
|
|
2118
2210
|
template_coordinates = template.to_pointcloud(cutoff_template)
|
2119
2211
|
template_weights = template.data[tuple(template_coordinates)]
|
@@ -2126,23 +2218,24 @@ class Density:
|
|
2126
2218
|
target_coordinates = target_coordinates * target_scaling[:, None]
|
2127
2219
|
template_coordinates = template_coordinates * template_scaling[:, None]
|
2128
2220
|
|
2129
|
-
target_mass_center = cls.center_of_mass(target.data, cutoff_target)
|
2130
|
-
template_mass_center = cls.center_of_mass(template.data, cutoff_template)
|
2131
2221
|
mass_center_difference = np.subtract(
|
2132
|
-
|
2222
|
+
cls.center_of_mass(target.data, cutoff_target),
|
2223
|
+
cls.center_of_mass(template.data, cutoff_template),
|
2133
2224
|
).astype(int)
|
2134
2225
|
template_coordinates += mass_center_difference[:, None]
|
2135
2226
|
|
2136
|
-
|
2137
|
-
|
2138
|
-
|
2227
|
+
score_object = create_score_object(
|
2228
|
+
score=scoring_method,
|
2229
|
+
target=target.data,
|
2139
2230
|
template_coordinates=template_coordinates,
|
2140
|
-
target_weights=target_weights,
|
2141
2231
|
template_weights=template_weights,
|
2142
|
-
scoring_class=scoring_method,
|
2143
2232
|
sampling_rate=np.ones(template.data.ndim),
|
2144
2233
|
)
|
2145
2234
|
|
2235
|
+
translation, rotation_matrix, score = optimize_match(
|
2236
|
+
score_object=score_object, optimization_method="basinhopping"
|
2237
|
+
)
|
2238
|
+
|
2146
2239
|
translation += mass_center_difference
|
2147
2240
|
translation = np.divide(translation, template_scaling)
|
2148
2241
|
|
@@ -2169,7 +2262,7 @@ class Density:
|
|
2169
2262
|
|
2170
2263
|
If voxel sizes of target and template dont match coordinates are scaled
|
2171
2264
|
to the numerically smaller voxel size. Prealignment is done by center's
|
2172
|
-
of mass. Finally :py:class:`tme.matching_optimization.
|
2265
|
+
of mass. Finally :py:class:`tme.matching_optimization.optimize_match` is used to
|
2173
2266
|
determine translation and rotation to match a template to target.
|
2174
2267
|
|
2175
2268
|
Parameters
|
@@ -2184,7 +2277,7 @@ class Density:
|
|
2184
2277
|
The cutoff value for the template map, by default 0.
|
2185
2278
|
scoring_method : str, optional
|
2186
2279
|
The scoring method to use for template matching. See
|
2187
|
-
:py:class:`tme.matching_optimization.
|
2280
|
+
:py:class:`tme.matching_optimization.create_score_object` for available methods,
|
2188
2281
|
by default "NormalizedCrossCorrelation".
|
2189
2282
|
|
2190
2283
|
Returns
|
Binary file
|
@@ -0,0 +1,195 @@
|
|
1
|
+
import numpy as np
|
2
|
+
from typing import Tuple, Dict
|
3
|
+
|
4
|
+
from scipy.ndimage import map_coordinates
|
5
|
+
|
6
|
+
from tme.types import ArrayLike
|
7
|
+
from tme.backends import backend
|
8
|
+
from tme.matching_data import MatchingData
|
9
|
+
from tme.matching_exhaustive import _normalize_under_mask
|
10
|
+
|
11
|
+
|
12
|
+
class MatchDensityToDensity:
|
13
|
+
def __init__(
|
14
|
+
self,
|
15
|
+
matching_data: "MatchingData",
|
16
|
+
pad_target_edges: bool = False,
|
17
|
+
pad_fourier: bool = False,
|
18
|
+
rotate_mask: bool = True,
|
19
|
+
interpolation_order: int = 1,
|
20
|
+
negate_score: bool = False,
|
21
|
+
):
|
22
|
+
self.rotate_mask = rotate_mask
|
23
|
+
self.interpolation_order = interpolation_order
|
24
|
+
|
25
|
+
target_pad = matching_data.target_padding(pad_target=pad_target_edges)
|
26
|
+
matching_data = matching_data.subset_by_slice(target_pad=target_pad)
|
27
|
+
|
28
|
+
fast_shape, fast_ft_shape, fourier_shift = matching_data.fourier_padding(
|
29
|
+
pad_fourier=pad_fourier
|
30
|
+
)
|
31
|
+
|
32
|
+
self.target = backend.topleft_pad(matching_data.target, fast_shape)
|
33
|
+
self.target_mask = matching_data.target_mask
|
34
|
+
|
35
|
+
self.template = matching_data.template
|
36
|
+
self.template_rot = backend.preallocate_array(
|
37
|
+
fast_shape, backend._default_dtype
|
38
|
+
)
|
39
|
+
|
40
|
+
self.template_mask, self.template_mask_rot = 1, 1
|
41
|
+
rotate_mask = False if matching_data.template_mask is None else rotate_mask
|
42
|
+
if matching_data.template_mask is not None:
|
43
|
+
self.template_mask = matching_data.template_mask
|
44
|
+
self.template_mask_rot = backend.topleft_pad(
|
45
|
+
matching_data.template_mask, fast_shape
|
46
|
+
)
|
47
|
+
|
48
|
+
self.score_sign = -1 if negate_score else 1
|
49
|
+
|
50
|
+
@staticmethod
|
51
|
+
def rigid_transform(
|
52
|
+
arr,
|
53
|
+
rotation_matrix,
|
54
|
+
translation,
|
55
|
+
arr_mask=None,
|
56
|
+
out=None,
|
57
|
+
out_mask=None,
|
58
|
+
order: int = 1,
|
59
|
+
use_geometric_center: bool = False,
|
60
|
+
):
|
61
|
+
rotate_mask = arr_mask is not None
|
62
|
+
return_type = (out is None) + 2 * rotate_mask * (out_mask is None)
|
63
|
+
translation = np.zeros(arr.ndim) if translation is None else translation
|
64
|
+
|
65
|
+
center = np.floor(np.array(arr.shape) / 2)[:, None]
|
66
|
+
grid = np.indices(arr.shape, dtype=np.float32).reshape(arr.ndim, -1)
|
67
|
+
np.subtract(grid, center, out=grid)
|
68
|
+
np.matmul(rotation_matrix.T, grid, out=grid)
|
69
|
+
np.add(grid, center, out=grid)
|
70
|
+
|
71
|
+
if out is None:
|
72
|
+
out = np.zeros_like(arr)
|
73
|
+
|
74
|
+
map_coordinates(arr, grid, order=order, output=out.ravel())
|
75
|
+
|
76
|
+
if out_mask is None and arr_mask is not None:
|
77
|
+
out_mask = np.zeros_like(arr_mask)
|
78
|
+
|
79
|
+
if arr_mask is not None:
|
80
|
+
map_coordinates(arr_mask, grid, order=order, output=out_mask.ravel())
|
81
|
+
|
82
|
+
match return_type:
|
83
|
+
case 0:
|
84
|
+
return None
|
85
|
+
case 1:
|
86
|
+
return out
|
87
|
+
case 2:
|
88
|
+
return out_mask
|
89
|
+
case 3:
|
90
|
+
return out, out_mask
|
91
|
+
|
92
|
+
@staticmethod
|
93
|
+
def angles_to_rotationmatrix(angles: Tuple[float]) -> ArrayLike:
|
94
|
+
angles = backend.to_numpy_array(angles)
|
95
|
+
rotation_matrix = euler_to_rotationmatrix(angles)
|
96
|
+
return backend.to_backend_array(rotation_matrix)
|
97
|
+
|
98
|
+
def format_translation(self, translation: Tuple[float] = None) -> ArrayLike:
|
99
|
+
if translation is None:
|
100
|
+
return backend.zeros(self.template.ndim, backend._default_dtype)
|
101
|
+
|
102
|
+
return backend.to_backend_array(translation)
|
103
|
+
|
104
|
+
def score_translation(self, x: Tuple[float]) -> float:
|
105
|
+
translation = self.format_translation(x)
|
106
|
+
rotation_matrix = self.angles_to_rotationmatrix((0, 0, 0))
|
107
|
+
|
108
|
+
return self(translation=translation, rotation_matrix=rotation_matrix)
|
109
|
+
|
110
|
+
def score_angles(self, x: Tuple[float]) -> float:
|
111
|
+
translation = self.format_translation(None)
|
112
|
+
rotation_matrix = self.angles_to_rotationmatrix(x)
|
113
|
+
|
114
|
+
return self(translation=translation, rotation_matrix=rotation_matrix)
|
115
|
+
|
116
|
+
def score(self, x: Tuple[float]) -> float:
|
117
|
+
split = len(x) // 2
|
118
|
+
translation, angles = x[:split], x[split:]
|
119
|
+
|
120
|
+
translation = self.format_translation(translation)
|
121
|
+
rotation_matrix = self.angles_to_rotationmatrix(angles)
|
122
|
+
|
123
|
+
return self(translation=translation, rotation_matrix=rotation_matrix)
|
124
|
+
|
125
|
+
|
126
|
+
class FLC(MatchDensityToDensity):
|
127
|
+
def __init__(self, **kwargs: Dict):
|
128
|
+
super().__init__(**kwargs)
|
129
|
+
|
130
|
+
if self.target_mask is not None:
|
131
|
+
backend.multiply(self.target, self.target_mask, out=self.target)
|
132
|
+
|
133
|
+
self.target_square = backend.square(self.target)
|
134
|
+
|
135
|
+
_normalize_under_mask(
|
136
|
+
template=self.template,
|
137
|
+
mask=self.template_mask,
|
138
|
+
mask_intensity=backend.sum(self.template_mask),
|
139
|
+
)
|
140
|
+
|
141
|
+
self.template = backend.reverse(self.template)
|
142
|
+
self.template_mask = backend.reverse(self.template_mask)
|
143
|
+
|
144
|
+
def __call__(self, translation: ArrayLike, rotation_matrix: ArrayLike) -> float:
|
145
|
+
if self.rotate_mask:
|
146
|
+
self.rigid_transform(
|
147
|
+
arr=self.template,
|
148
|
+
arr_mask=self.template_mask,
|
149
|
+
rotation_matrix=rotation_matrix,
|
150
|
+
translation=translation,
|
151
|
+
out=self.template_rot,
|
152
|
+
out_mask=self.template_mask_rot,
|
153
|
+
use_geometric_center=False,
|
154
|
+
order=self.interpolation_order,
|
155
|
+
)
|
156
|
+
else:
|
157
|
+
self.rigid_transform(
|
158
|
+
arr=self.template,
|
159
|
+
rotation_matrix=rotation_matrix,
|
160
|
+
translation=translation,
|
161
|
+
out=self.template_rot,
|
162
|
+
use_geometric_center=False,
|
163
|
+
order=self.interpolation_order,
|
164
|
+
)
|
165
|
+
n_observations = backend.sum(self.template_mask_rot)
|
166
|
+
|
167
|
+
_normalize_under_mask(
|
168
|
+
template=self.template_rot,
|
169
|
+
mask=self.template_mask_rot,
|
170
|
+
mask_intensity=n_observations,
|
171
|
+
)
|
172
|
+
|
173
|
+
ex2 = backend.sum(
|
174
|
+
backend.divide(
|
175
|
+
backend.sum(
|
176
|
+
backend.multiply(self.target_square, self.template_mask_rot),
|
177
|
+
),
|
178
|
+
n_observations,
|
179
|
+
)
|
180
|
+
)
|
181
|
+
e2x = backend.square(
|
182
|
+
backend.divide(
|
183
|
+
backend.sum(backend.multiply(self.target, self.template_mask_rot)),
|
184
|
+
n_observations,
|
185
|
+
)
|
186
|
+
)
|
187
|
+
|
188
|
+
denominator = backend.maximum(backend.subtract(ex2, e2x), 0.0)
|
189
|
+
denominator = backend.sqrt(denominator)
|
190
|
+
denominator = backend.multiply(denominator, n_observations)
|
191
|
+
|
192
|
+
overlap = backend.sum(backend.multiply(self.template_rot, self.target))
|
193
|
+
|
194
|
+
score = backend.divide(overlap, denominator) * self.score_sign
|
195
|
+
return score
|