pytme 0.2.0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.1__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/match_template.py +183 -69
- {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/postprocess.py +107 -49
- {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/preprocessor_gui.py +4 -1
- {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/METADATA +1 -1
- pytme-0.2.1.dist-info/RECORD +73 -0
- scripts/extract_candidates.py +117 -85
- scripts/match_template.py +183 -69
- scripts/match_template_filters.py +193 -71
- scripts/postprocess.py +107 -49
- scripts/preprocessor_gui.py +4 -1
- scripts/refine_matches.py +364 -160
- tme/__version__.py +1 -1
- tme/analyzer.py +259 -117
- tme/backends/__init__.py +1 -0
- tme/backends/cupy_backend.py +20 -13
- tme/backends/jax_backend.py +218 -0
- tme/backends/matching_backend.py +25 -10
- tme/backends/mlx_backend.py +13 -9
- tme/backends/npfftw_backend.py +20 -8
- tme/backends/pytorch_backend.py +20 -9
- tme/density.py +79 -60
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +85 -61
- tme/matching_exhaustive.py +222 -129
- tme/matching_optimization.py +117 -76
- tme/orientations.py +175 -55
- tme/preprocessing/_utils.py +17 -5
- tme/preprocessing/composable_filter.py +2 -1
- tme/preprocessing/compose.py +1 -2
- tme/preprocessing/frequency_filters.py +97 -41
- tme/preprocessing/tilt_series.py +137 -87
- tme/preprocessor.py +3 -0
- tme/structure.py +4 -1
- pytme-0.2.0.dist-info/RECORD +0 -72
- {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/preprocess.py +0 -0
- {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/LICENSE +0 -0
- {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/WHEEL +0 -0
- {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/top_level.txt +0 -0
tme/backends/cupy_backend.py
CHANGED
@@ -25,29 +25,37 @@ class CupyBackend(NumpyFFTWBackend):
|
|
25
25
|
"""
|
26
26
|
|
27
27
|
def __init__(
|
28
|
-
self,
|
28
|
+
self,
|
29
|
+
float_dtype=None,
|
30
|
+
complex_dtype=None,
|
31
|
+
int_dtype=None,
|
32
|
+
overflow_safe_dtype=None,
|
33
|
+
**kwargs,
|
29
34
|
):
|
30
35
|
import cupy as cp
|
31
36
|
from cupyx.scipy.fft import get_fft_plan
|
32
37
|
from cupyx.scipy.ndimage import affine_transform
|
33
38
|
from cupyx.scipy.ndimage import maximum_filter
|
34
39
|
|
35
|
-
|
40
|
+
float_dtype = cp.float32 if float_dtype is None else float_dtype
|
36
41
|
complex_dtype = cp.complex64 if complex_dtype is None else complex_dtype
|
37
|
-
|
42
|
+
int_dtype = cp.int32 if int_dtype is None else int_dtype
|
43
|
+
if overflow_safe_dtype is None:
|
44
|
+
overflow_safe_dtype = cp.float32
|
38
45
|
|
39
46
|
super().__init__(
|
40
47
|
array_backend=cp,
|
41
|
-
|
48
|
+
float_dtype=float_dtype,
|
42
49
|
complex_dtype=complex_dtype,
|
43
|
-
|
50
|
+
int_dtype=int_dtype,
|
51
|
+
overflow_safe_dtype=overflow_safe_dtype,
|
44
52
|
)
|
45
53
|
self.get_fft_plan = get_fft_plan
|
46
54
|
self.affine_transform = affine_transform
|
47
55
|
self.maximum_filter = maximum_filter
|
48
56
|
|
49
|
-
floating = f"float{self.datatype_bytes(
|
50
|
-
integer = f"int{self.datatype_bytes(
|
57
|
+
floating = f"float{self.datatype_bytes(float_dtype) * 8}"
|
58
|
+
integer = f"int{self.datatype_bytes(int_dtype) * 8}"
|
51
59
|
self._max_score_over_rotations = self._array_backend.ElementwiseKernel(
|
52
60
|
f"{floating} internal_scores, {floating} scores, {integer} rot_index",
|
53
61
|
f"{floating} out1, {integer} rotations",
|
@@ -119,12 +127,11 @@ class CupyBackend(NumpyFFTWBackend):
|
|
119
127
|
fast_ft_shape: Tuple[int],
|
120
128
|
real_dtype: type,
|
121
129
|
complex_dtype: type,
|
122
|
-
fftargs: Dict = {},
|
123
130
|
inverse_fast_shape: Tuple[int] = None,
|
124
131
|
**kwargs,
|
125
132
|
) -> Tuple[Callable, Callable]:
|
126
133
|
"""
|
127
|
-
Build
|
134
|
+
Build rfftn and irfftn functions.
|
128
135
|
|
129
136
|
Parameters
|
130
137
|
----------
|
@@ -140,8 +147,6 @@ class CupyBackend(NumpyFFTWBackend):
|
|
140
147
|
Numpy dtype of the fourier transform.
|
141
148
|
inverse_fast_shape : tuple, optional
|
142
149
|
Output shape of the inverse Fourier transform. By default fast_shape.
|
143
|
-
fftargs : dict, optional
|
144
|
-
Dictionary passed to pyFFTW builders.
|
145
150
|
**kwargs: dict, optional
|
146
151
|
Unused keyword arguments.
|
147
152
|
|
@@ -260,11 +265,13 @@ class CupyBackend(NumpyFFTWBackend):
|
|
260
265
|
return_type = (out is None) + 2 * rotate_mask * (out_mask is None)
|
261
266
|
translation = self.zeros(arr.ndim) if translation is None else translation
|
262
267
|
|
263
|
-
center = self.divide(
|
268
|
+
center = self.divide(arr.shape, 2)
|
264
269
|
if not use_geometric_center:
|
265
270
|
center = self.center_of_mass(arr, cutoff=0)
|
266
271
|
|
267
|
-
rotation_matrix_inverted = self.linalg.inv(
|
272
|
+
rotation_matrix_inverted = self.linalg.inv(
|
273
|
+
rotation_matrix.astype(self._overflow_safe_dtype)
|
274
|
+
).astype(self._float_dtype)
|
268
275
|
transformed_center = rotation_matrix_inverted @ center.reshape(-1, 1)
|
269
276
|
transformed_center = transformed_center.reshape(-1)
|
270
277
|
base_offset = self.subtract(center, transformed_center)
|
@@ -0,0 +1,218 @@
|
|
1
|
+
""" Backend using jax for template matching.
|
2
|
+
|
3
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
"""
|
7
|
+
from typing import Tuple, Callable
|
8
|
+
|
9
|
+
from .npfftw_backend import NumpyFFTWBackend
|
10
|
+
|
11
|
+
class JaxBackend(NumpyFFTWBackend):
|
12
|
+
def __init__(
|
13
|
+
self, float_dtype=None, complex_dtype=None, int_dtype=None, **kwargs
|
14
|
+
):
|
15
|
+
import jax.scipy as jsp
|
16
|
+
import jax.numpy as jnp
|
17
|
+
|
18
|
+
float_dtype = jnp.float32 if float_dtype is None else float_dtype
|
19
|
+
complex_dtype = jnp.complex64 if complex_dtype is None else complex_dtype
|
20
|
+
int_dtype = jnp.int32 if int_dtype is None else int_dtype
|
21
|
+
|
22
|
+
self.scipy = jsp
|
23
|
+
super().__init__(
|
24
|
+
array_backend=jnp,
|
25
|
+
float_dtype=float_dtype,
|
26
|
+
complex_dtype=complex_dtype,
|
27
|
+
int_dtype=int_dtype,
|
28
|
+
)
|
29
|
+
|
30
|
+
def to_backend_array(self, arr):
|
31
|
+
return self._array_backend.asarray(arr)
|
32
|
+
|
33
|
+
def preallocate_array(self, shape: Tuple[int], dtype: type):
|
34
|
+
arr = self._array_backend.zeros(shape, dtype=dtype)
|
35
|
+
return arr
|
36
|
+
|
37
|
+
def topleft_pad(self, arr, shape: Tuple[int], padval: int = 0):
|
38
|
+
b = self.preallocate_array(shape, arr.dtype)
|
39
|
+
self.add(b, padval, out=b)
|
40
|
+
aind = [slice(None, None)] * arr.ndim
|
41
|
+
bind = [slice(None, None)] * arr.ndim
|
42
|
+
for i in range(arr.ndim):
|
43
|
+
if arr.shape[i] > shape[i]:
|
44
|
+
aind[i] = slice(0, shape[i])
|
45
|
+
elif arr.shape[i] < shape[i]:
|
46
|
+
bind[i] = slice(0, arr.shape[i])
|
47
|
+
b = b.at[tuple(bind)].set(arr[tuple(aind)])
|
48
|
+
return b
|
49
|
+
|
50
|
+
|
51
|
+
def add(self, x1, x2, out = None, *args, **kwargs):
|
52
|
+
x1 = self.to_backend_array(x1)
|
53
|
+
x2 = self.to_backend_array(x2)
|
54
|
+
ret = self._array_backend.add(x1, x2, *args, **kwargs)
|
55
|
+
|
56
|
+
if out is not None:
|
57
|
+
out = out.at[:].set(ret)
|
58
|
+
return ret
|
59
|
+
|
60
|
+
def subtract(self, x1, x2, out = None, *args, **kwargs):
|
61
|
+
x1 = self.to_backend_array(x1)
|
62
|
+
x2 = self.to_backend_array(x2)
|
63
|
+
ret = self._array_backend.subtract(x1, x2, *args, **kwargs)
|
64
|
+
if out is not None:
|
65
|
+
out = out.at[:].set(ret)
|
66
|
+
return ret
|
67
|
+
|
68
|
+
def multiply(self, x1, x2, out = None, *args, **kwargs):
|
69
|
+
x1 = self.to_backend_array(x1)
|
70
|
+
x2 = self.to_backend_array(x2)
|
71
|
+
ret = self._array_backend.multiply(x1, x2, *args, **kwargs)
|
72
|
+
if out is not None:
|
73
|
+
out = out.at[:].set(ret)
|
74
|
+
return ret
|
75
|
+
|
76
|
+
def divide(self, x1, x2, out = None, *args, **kwargs):
|
77
|
+
x1 = self.to_backend_array(x1)
|
78
|
+
x2 = self.to_backend_array(x2)
|
79
|
+
ret = self._array_backend.divide(x1, x2, *args, **kwargs)
|
80
|
+
if out is not None:
|
81
|
+
out = out.at[:].set(ret)
|
82
|
+
return ret
|
83
|
+
|
84
|
+
def fill(self, arr, value: float) -> None:
|
85
|
+
arr.at[:].set(value)
|
86
|
+
|
87
|
+
|
88
|
+
def build_fft(
|
89
|
+
self,
|
90
|
+
fast_shape: Tuple[int],
|
91
|
+
fast_ft_shape: Tuple[int],
|
92
|
+
inverse_fast_shape: Tuple[int] = None,
|
93
|
+
**kwargs,
|
94
|
+
) -> Tuple[Callable, Callable]:
|
95
|
+
"""
|
96
|
+
Build fft builder functions.
|
97
|
+
|
98
|
+
Parameters
|
99
|
+
----------
|
100
|
+
fast_shape : tuple
|
101
|
+
Tuple of integers corresponding to fast convolution shape
|
102
|
+
(see :py:meth:`PytorchBackend.compute_convolution_shapes`).
|
103
|
+
fast_ft_shape : tuple
|
104
|
+
Tuple of integers corresponding to the shape of the Fourier
|
105
|
+
transform array (see :py:meth:`PytorchBackend.compute_convolution_shapes`).
|
106
|
+
inverse_fast_shape : tuple, optional
|
107
|
+
Output shape of the inverse Fourier transform. By default fast_shape.
|
108
|
+
**kwargs : dict, optional
|
109
|
+
Unused keyword arguments.
|
110
|
+
|
111
|
+
Returns
|
112
|
+
-------
|
113
|
+
tuple
|
114
|
+
Tupple containing callable rfft and irfft object.
|
115
|
+
"""
|
116
|
+
if inverse_fast_shape is None:
|
117
|
+
inverse_fast_shape = fast_shape
|
118
|
+
|
119
|
+
def rfftn(
|
120
|
+
arr, out, shape: Tuple[int] = fast_shape
|
121
|
+
) -> None:
|
122
|
+
out = out.at[:].set(self._array_backend.fft.rfftn(arr, s=shape))
|
123
|
+
|
124
|
+
def irfftn(
|
125
|
+
arr, out, shape: Tuple[int] = inverse_fast_shape
|
126
|
+
) -> None:
|
127
|
+
out = out.at[:].set(self._array_backend.fft.irfftn(arr, s=shape))
|
128
|
+
|
129
|
+
return rfftn, irfftn
|
130
|
+
|
131
|
+
def sharedarr_to_arr(self, shm, shape: Tuple[int], dtype: str):
|
132
|
+
return shm
|
133
|
+
|
134
|
+
@staticmethod
|
135
|
+
def arr_to_sharedarr(arr, shared_memory_handler: type = None):
|
136
|
+
return arr
|
137
|
+
|
138
|
+
def rotate_array(
|
139
|
+
self,
|
140
|
+
arr,
|
141
|
+
rotation_matrix,
|
142
|
+
arr_mask = None,
|
143
|
+
translation = None,
|
144
|
+
use_geometric_center: bool = False,
|
145
|
+
out = None,
|
146
|
+
out_mask = None,
|
147
|
+
order: int = 3,
|
148
|
+
) -> None:
|
149
|
+
rotate_mask = arr_mask is not None
|
150
|
+
return_type = (out is None) + 2 * rotate_mask * (out_mask is None)
|
151
|
+
translation = self.zeros(arr.ndim) if translation is None else translation
|
152
|
+
|
153
|
+
indices = self._array_backend.indices(arr.shape).reshape(
|
154
|
+
(len(arr.shape), -1)
|
155
|
+
).astype(self._float_dtype)
|
156
|
+
|
157
|
+
center = self.divide(arr.shape, 2)
|
158
|
+
if not use_geometric_center:
|
159
|
+
center = self.center_of_mass(arr, cutoff=0)
|
160
|
+
center = center[:, None]
|
161
|
+
indices = indices.at[:].add(-center)
|
162
|
+
rotation_matrix = self._array_backend.linalg.inv(rotation_matrix)
|
163
|
+
indices = self._array_backend.matmul(rotation_matrix, indices)
|
164
|
+
indices = indices.at[:].add(center)
|
165
|
+
|
166
|
+
out = self.zeros_like(arr) if out is None else out
|
167
|
+
out_slice = tuple(slice(0, stop) for stop in arr.shape)
|
168
|
+
|
169
|
+
out = out.at[out_slice].set(
|
170
|
+
self.scipy.ndimage.map_coordinates(
|
171
|
+
arr, indices, order=order
|
172
|
+
).reshape(arr.shape)
|
173
|
+
)
|
174
|
+
|
175
|
+
if rotate_mask:
|
176
|
+
out_mask = self.zeros_like(arr_mask) if out_mask is None else out_mask
|
177
|
+
out_mask_slice = tuple(slice(0, stop) for stop in arr_mask.shape)
|
178
|
+
out_mask = out_mask.at[out_mask_slice].set(
|
179
|
+
self.scipy.ndimage.map_coordinates(
|
180
|
+
arr_mask, indices, order=order
|
181
|
+
).reshape(arr.shape)
|
182
|
+
)
|
183
|
+
|
184
|
+
match return_type:
|
185
|
+
case 0:
|
186
|
+
return None
|
187
|
+
case 1:
|
188
|
+
return out
|
189
|
+
case 2:
|
190
|
+
return out_mask
|
191
|
+
case 3:
|
192
|
+
return out, out_mask
|
193
|
+
|
194
|
+
def max_score_over_rotations(
|
195
|
+
self,
|
196
|
+
score_space,
|
197
|
+
internal_scores,
|
198
|
+
internal_rotations,
|
199
|
+
rotation_index: int,
|
200
|
+
):
|
201
|
+
"""
|
202
|
+
Modify internal_scores and internal_rotations inplace with scores and rotation
|
203
|
+
index respectively, wherever score_sapce is larger than internal scores.
|
204
|
+
|
205
|
+
Parameters
|
206
|
+
----------
|
207
|
+
score_space : CupyArray
|
208
|
+
The score space to compare against internal_scores.
|
209
|
+
internal_scores : CupyArray
|
210
|
+
The internal scores to update with maximum scores.
|
211
|
+
internal_rotations : CupyArray
|
212
|
+
The internal rotations corresponding to the maximum scores.
|
213
|
+
rotation_index : int
|
214
|
+
The index representing the current rotation.
|
215
|
+
"""
|
216
|
+
indices = score_space > internal_scores
|
217
|
+
internal_scores.at[indices].set(score_space[indices])
|
218
|
+
internal_rotations.at[indices].set(rotation_index)
|
tme/backends/matching_backend.py
CHANGED
@@ -29,32 +29,39 @@ class MatchingBackend(ABC):
|
|
29
29
|
----------
|
30
30
|
array_backend : object
|
31
31
|
The backend object providing array functionalities.
|
32
|
-
|
32
|
+
float_dtype : type
|
33
33
|
Data type of real array instances, e.g. np.float32.
|
34
34
|
complex_dtype : type
|
35
35
|
Data type of complex array instances, e.g. np.complex64.
|
36
|
-
|
36
|
+
int_dtype : type
|
37
37
|
Data type of integer array instances, e.g. np.int32.
|
38
|
+
overflow_safe_dtype : type
|
39
|
+
Data type than can be used for reduction operations to avoid overflows.
|
38
40
|
|
39
41
|
Attributes
|
40
42
|
----------
|
41
43
|
_array_backend : object
|
42
44
|
The backend object used to delegate method and attribute calls.
|
43
|
-
|
45
|
+
_float_dtype : type
|
44
46
|
Data type of real array instances, e.g. np.float32.
|
45
47
|
_complex_dtype : type
|
46
48
|
Data type of complex array instances, e.g. np.complex64.
|
47
|
-
|
49
|
+
_int_dtype : type
|
48
50
|
Data type of integer array instances, e.g. np.int32.
|
51
|
+
_overflow_safe_dtype : type
|
52
|
+
Data type than can be used for reduction operations to avoid overflows.
|
53
|
+
_fundamental_dtypes : Dict
|
54
|
+
Mapping between fundamental int, float and complex python types to
|
55
|
+
array backend specific data types.
|
49
56
|
|
50
57
|
Examples
|
51
58
|
--------
|
52
59
|
>>> import numpy as np
|
53
60
|
>>> backend = MatchingBackend(
|
54
61
|
array_backend = np,
|
55
|
-
|
62
|
+
float_dtype = np.float32,
|
56
63
|
complex_dtype = np.complex64,
|
57
|
-
|
64
|
+
int_dtype = np.int32
|
58
65
|
)
|
59
66
|
>>> arr = backend.array([1, 2, 3])
|
60
67
|
>>> print(arr)
|
@@ -69,14 +76,22 @@ class MatchingBackend(ABC):
|
|
69
76
|
def __init__(
    self,
    array_backend,
    float_dtype: type,
    complex_dtype: type,
    int_dtype: type,
    overflow_safe_dtype: type,
):
    # Module that attribute lookups unknown to the backend are delegated to.
    self._array_backend = array_backend

    # Concrete dtypes used when creating arrays on this backend.
    self._float_dtype = float_dtype
    self._complex_dtype = complex_dtype
    self._int_dtype = int_dtype
    # Wider dtype for reductions that could otherwise overflow.
    self._overflow_safe_dtype = overflow_safe_dtype

    # Translate Python's built-in numeric types into backend dtypes.
    python_types = (int, float, complex)
    backend_types = (int_dtype, float_dtype, complex_dtype)
    self._fundamental_dtypes = dict(zip(python_types, backend_types))
|
80
95
|
|
81
96
|
def __getattr__(self, name: str):
|
82
97
|
"""
|
tme/backends/mlx_backend.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
""" Backend Apple's MLX library for template matching.
|
1
|
+
""" Backend using Apple's MLX library for template matching.
|
2
2
|
|
3
3
|
Copyright (c) 2024 European Molecular Biology Laboratory
|
4
4
|
|
@@ -20,23 +20,27 @@ class MLXBackend(NumpyFFTWBackend):
|
|
20
20
|
def __init__(
|
21
21
|
self,
|
22
22
|
device="cpu",
|
23
|
-
|
23
|
+
float_dtype=None,
|
24
24
|
complex_dtype=None,
|
25
|
-
|
25
|
+
int_dtype=None,
|
26
|
+
overflow_safe_dtype=None,
|
26
27
|
**kwargs,
|
27
28
|
):
|
28
29
|
import mlx.core as mx
|
29
30
|
|
30
31
|
device = mx.cpu if device == "cpu" else mx.gpu
|
31
|
-
|
32
|
+
float_dtype = mx.float32 if float_dtype is None else float_dtype
|
32
33
|
complex_dtype = mx.complex64 if complex_dtype is None else complex_dtype
|
33
|
-
|
34
|
+
int_dtype = mx.int32 if int_dtype is None else int_dtype
|
35
|
+
if overflow_safe_dtype is None:
|
36
|
+
overflow_safe_dtype = mx.float32
|
34
37
|
|
35
38
|
super().__init__(
|
36
39
|
array_backend=mx,
|
37
|
-
|
40
|
+
float_dtype=float_dtype,
|
38
41
|
complex_dtype=complex_dtype,
|
39
|
-
|
42
|
+
int_dtype=int_dtype,
|
43
|
+
overflow_safe_dtype=overflow_safe_dtype,
|
40
44
|
)
|
41
45
|
|
42
46
|
self.device = device
|
@@ -141,8 +145,8 @@ class MLXBackend(NumpyFFTWBackend):
|
|
141
145
|
new_shape = self.to_backend_array(newshape)
|
142
146
|
current_shape = self.to_backend_array(arr.shape)
|
143
147
|
starts = self.subtract(current_shape, new_shape)
|
144
|
-
starts = self.astype(self.divide(starts, 2), self.
|
145
|
-
stops = self.astype(self.add(starts, newshape), self.
|
148
|
+
starts = self.astype(self.divide(starts, 2), self._int_dtype)
|
149
|
+
stops = self.astype(self.add(starts, newshape), self._int_dtype)
|
146
150
|
starts, stops = starts.tolist(), stops.tolist()
|
147
151
|
box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
|
148
152
|
return arr[box]
|
tme/backends/npfftw_backend.py
CHANGED
@@ -29,16 +29,18 @@ class NumpyFFTWBackend(MatchingBackend):
|
|
29
29
|
def __init__(
|
30
30
|
self,
|
31
31
|
array_backend=np,
|
32
|
-
|
32
|
+
float_dtype=np.float32,
|
33
33
|
complex_dtype=np.complex64,
|
34
|
-
|
34
|
+
int_dtype=np.int32,
|
35
|
+
overflow_safe_dtype=np.float32,
|
35
36
|
**kwargs,
|
36
37
|
):
|
37
38
|
super().__init__(
|
38
39
|
array_backend=array_backend,
|
39
|
-
|
40
|
+
float_dtype=float_dtype,
|
40
41
|
complex_dtype=complex_dtype,
|
41
|
-
|
42
|
+
int_dtype=int_dtype,
|
43
|
+
overflow_safe_dtype=overflow_safe_dtype,
|
42
44
|
)
|
43
45
|
self.affine_transform = affine_transform
|
44
46
|
|
@@ -53,6 +55,16 @@ class NumpyFFTWBackend(MatchingBackend):
|
|
53
55
|
def to_cpu_array(self, arr: NDArray) -> NDArray:
    """Return ``arr`` as is; NumPy arrays need no device transfer."""
    host_array = arr
    return host_array
|
55
57
|
|
58
|
+
def get_fundamental_dtype(self, arr):
    """
    Classify the dtype of ``arr`` as a built-in Python numeric type.

    Returns ``int`` for integer dtypes, ``float`` for floating dtypes,
    ``complex`` for complex dtypes, and ``float`` for anything else
    (for example boolean arrays).
    """
    xp = self._array_backend
    dtype = arr.dtype
    # The abstract dtype categories are looked up lazily and in this order.
    for category, python_type in (
        ("integer", int),
        ("floating", float),
        ("complexfloating", complex),
    ):
        if xp.issubdtype(dtype, getattr(xp, category)):
            return python_type
    return float
|
67
|
+
|
56
68
|
def free_cache(self):
    """Release cached memory; there is nothing to release for this backend."""
    return None
|
58
70
|
|
@@ -429,8 +441,8 @@ class NumpyFFTWBackend(MatchingBackend):
|
|
429
441
|
new_shape = self.to_backend_array(newshape)
|
430
442
|
current_shape = self.to_backend_array(arr.shape)
|
431
443
|
starts = self.subtract(current_shape, new_shape)
|
432
|
-
starts = self.astype(self.divide(starts, 2), self.
|
433
|
-
stops = self.astype(self.add(starts, newshape), self.
|
444
|
+
starts = self.astype(self.divide(starts, 2), self._int_dtype)
|
445
|
+
stops = self.astype(self.add(starts, newshape), self._int_dtype)
|
434
446
|
box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
|
435
447
|
return arr[box]
|
436
448
|
|
@@ -722,11 +734,11 @@ class NumpyFFTWBackend(MatchingBackend):
|
|
722
734
|
arr = self._array_backend.where(arr > cutoff, arr, 0)
|
723
735
|
denominator = self.sum(arr)
|
724
736
|
grids = self._array_backend.ogrid[tuple(slice(0, i) for i in arr.shape)]
|
725
|
-
grids = [grid.astype(self.
|
737
|
+
grids = [grid.astype(self._float_dtype) for grid in grids]
|
726
738
|
|
727
739
|
center_of_mass = self.array(
|
728
740
|
[
|
729
|
-
self.sum(self.multiply(arr, grids[dim])
|
741
|
+
self.sum(self.multiply(arr, grids[dim]) / denominator)
|
730
742
|
for dim in range(arr.ndim)
|
731
743
|
]
|
732
744
|
)
|
tme/backends/pytorch_backend.py
CHANGED
@@ -24,25 +24,27 @@ class PytorchBackend(NumpyFFTWBackend):
|
|
24
24
|
def __init__(
|
25
25
|
self,
|
26
26
|
device="cuda",
|
27
|
-
|
27
|
+
float_dtype=None,
|
28
28
|
complex_dtype=None,
|
29
|
-
|
29
|
+
int_dtype=None,
|
30
|
+
overflow_safe_dtype=None,
|
30
31
|
**kwargs,
|
31
32
|
):
|
32
33
|
import torch
|
33
34
|
import torch.nn.functional as F
|
34
35
|
|
35
|
-
|
36
|
+
float_dtype = torch.float32 if float_dtype is None else float_dtype
|
36
37
|
complex_dtype = torch.complex64 if complex_dtype is None else complex_dtype
|
37
|
-
|
38
|
-
|
39
|
-
|
38
|
+
int_dtype = torch.int32 if int_dtype is None else int_dtype
|
39
|
+
if overflow_safe_dtype is None:
|
40
|
+
overflow_safe_dtype = torch.float32
|
40
41
|
|
41
42
|
super().__init__(
|
42
43
|
array_backend=torch,
|
43
|
-
|
44
|
+
float_dtype=float_dtype,
|
44
45
|
complex_dtype=complex_dtype,
|
45
|
-
|
46
|
+
int_dtype=int_dtype,
|
47
|
+
overflow_safe_dtype=overflow_safe_dtype,
|
46
48
|
)
|
47
49
|
self.device = device
|
48
50
|
self.F = F
|
@@ -57,11 +59,20 @@ class PytorchBackend(NumpyFFTWBackend):
|
|
57
59
|
def to_numpy_array(self, arr: TorchTensor) -> NDArray:
    """
    Convert ``arr`` to a NumPy array.

    NumPy arrays are returned unchanged. Torch tensors are detached from the
    autograd graph and moved to the CPU before conversion. Any other input
    is passed to ``np.array``.
    """
    if isinstance(arr, np.ndarray):
        return arr
    elif isinstance(arr, self._array_backend.Tensor):
        # Tensor.numpy() raises on tensors that require grad; detach() first.
        return arr.detach().cpu().numpy()
    return np.array(arr)
|
61
65
|
|
62
66
|
def to_cpu_array(self, arr: TorchTensor) -> NDArray:
    """Move ``arr`` to host memory via ``Tensor.cpu()`` (the result is still a tensor)."""
    host_tensor = arr.cpu()
    return host_tensor
|
64
68
|
|
69
|
+
def get_fundamental_dtype(self, arr):
    """
    Classify the dtype of tensor ``arr`` as a built-in Python numeric type.

    Returns ``float`` for floating-point tensors, ``complex`` for complex
    tensors, and ``int`` for everything else (integers and booleans).
    """
    xp = self._array_backend
    if xp.is_floating_point(arr):
        return float
    return complex if xp.is_complex(arr) else int
|
75
|
+
|
65
76
|
def free_cache(self):
    """Ask PyTorch to release cached CUDA memory via ``torch.cuda.empty_cache``."""
    cuda_module = self._array_backend.cuda
    cuda_module.empty_cache()
|
67
78
|
|