pytme 0.2.0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.1__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/match_template.py +183 -69
  2. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/postprocess.py +107 -49
  3. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/preprocessor_gui.py +4 -1
  4. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/METADATA +1 -1
  5. pytme-0.2.1.dist-info/RECORD +73 -0
  6. scripts/extract_candidates.py +117 -85
  7. scripts/match_template.py +183 -69
  8. scripts/match_template_filters.py +193 -71
  9. scripts/postprocess.py +107 -49
  10. scripts/preprocessor_gui.py +4 -1
  11. scripts/refine_matches.py +364 -160
  12. tme/__version__.py +1 -1
  13. tme/analyzer.py +259 -117
  14. tme/backends/__init__.py +1 -0
  15. tme/backends/cupy_backend.py +20 -13
  16. tme/backends/jax_backend.py +218 -0
  17. tme/backends/matching_backend.py +25 -10
  18. tme/backends/mlx_backend.py +13 -9
  19. tme/backends/npfftw_backend.py +20 -8
  20. tme/backends/pytorch_backend.py +20 -9
  21. tme/density.py +79 -60
  22. tme/extensions.cpython-311-darwin.so +0 -0
  23. tme/matching_data.py +85 -61
  24. tme/matching_exhaustive.py +222 -129
  25. tme/matching_optimization.py +117 -76
  26. tme/orientations.py +175 -55
  27. tme/preprocessing/_utils.py +17 -5
  28. tme/preprocessing/composable_filter.py +2 -1
  29. tme/preprocessing/compose.py +1 -2
  30. tme/preprocessing/frequency_filters.py +97 -41
  31. tme/preprocessing/tilt_series.py +137 -87
  32. tme/preprocessor.py +3 -0
  33. tme/structure.py +4 -1
  34. pytme-0.2.0.dist-info/RECORD +0 -72
  35. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/estimate_ram_usage.py +0 -0
  36. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/preprocess.py +0 -0
  37. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/LICENSE +0 -0
  38. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/WHEEL +0 -0
  39. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/entry_points.txt +0 -0
  40. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/top_level.txt +0 -0
@@ -25,29 +25,37 @@ class CupyBackend(NumpyFFTWBackend):
25
25
  """
26
26
 
27
27
  def __init__(
28
- self, default_dtype=None, complex_dtype=None, default_dtype_int=None, **kwargs
28
+ self,
29
+ float_dtype=None,
30
+ complex_dtype=None,
31
+ int_dtype=None,
32
+ overflow_safe_dtype=None,
33
+ **kwargs,
29
34
  ):
30
35
  import cupy as cp
31
36
  from cupyx.scipy.fft import get_fft_plan
32
37
  from cupyx.scipy.ndimage import affine_transform
33
38
  from cupyx.scipy.ndimage import maximum_filter
34
39
 
35
- default_dtype = cp.float32 if default_dtype is None else default_dtype
40
+ float_dtype = cp.float32 if float_dtype is None else float_dtype
36
41
  complex_dtype = cp.complex64 if complex_dtype is None else complex_dtype
37
- default_dtype_int = cp.int32 if default_dtype_int is None else default_dtype_int
42
+ int_dtype = cp.int32 if int_dtype is None else int_dtype
43
+ if overflow_safe_dtype is None:
44
+ overflow_safe_dtype = cp.float32
38
45
 
39
46
  super().__init__(
40
47
  array_backend=cp,
41
- default_dtype=default_dtype,
48
+ float_dtype=float_dtype,
42
49
  complex_dtype=complex_dtype,
43
- default_dtype_int=default_dtype_int,
50
+ int_dtype=int_dtype,
51
+ overflow_safe_dtype=overflow_safe_dtype,
44
52
  )
45
53
  self.get_fft_plan = get_fft_plan
46
54
  self.affine_transform = affine_transform
47
55
  self.maximum_filter = maximum_filter
48
56
 
49
- floating = f"float{self.datatype_bytes(default_dtype) * 8}"
50
- integer = f"int{self.datatype_bytes(default_dtype_int) * 8}"
57
+ floating = f"float{self.datatype_bytes(float_dtype) * 8}"
58
+ integer = f"int{self.datatype_bytes(int_dtype) * 8}"
51
59
  self._max_score_over_rotations = self._array_backend.ElementwiseKernel(
52
60
  f"{floating} internal_scores, {floating} scores, {integer} rot_index",
53
61
  f"{floating} out1, {integer} rotations",
@@ -119,12 +127,11 @@ class CupyBackend(NumpyFFTWBackend):
119
127
  fast_ft_shape: Tuple[int],
120
128
  real_dtype: type,
121
129
  complex_dtype: type,
122
- fftargs: Dict = {},
123
130
  inverse_fast_shape: Tuple[int] = None,
124
131
  **kwargs,
125
132
  ) -> Tuple[Callable, Callable]:
126
133
  """
127
- Build pyFFTW builder functions.
134
+ Build rfftn and irfftn functions.
128
135
 
129
136
  Parameters
130
137
  ----------
@@ -140,8 +147,6 @@ class CupyBackend(NumpyFFTWBackend):
140
147
  Numpy dtype of the fourier transform.
141
148
  inverse_fast_shape : tuple, optional
142
149
  Output shape of the inverse Fourier transform. By default fast_shape.
143
- fftargs : dict, optional
144
- Dictionary passed to pyFFTW builders.
145
150
  **kwargs: dict, optional
146
151
  Unused keyword arguments.
147
152
 
@@ -260,11 +265,13 @@ class CupyBackend(NumpyFFTWBackend):
260
265
  return_type = (out is None) + 2 * rotate_mask * (out_mask is None)
261
266
  translation = self.zeros(arr.ndim) if translation is None else translation
262
267
 
263
- center = self.divide(self.to_backend_array(arr.shape), 2)
268
+ center = self.divide(arr.shape, 2)
264
269
  if not use_geometric_center:
265
270
  center = self.center_of_mass(arr, cutoff=0)
266
271
 
267
- rotation_matrix_inverted = self.linalg.inv(rotation_matrix)
272
+ rotation_matrix_inverted = self.linalg.inv(
273
+ rotation_matrix.astype(self._overflow_safe_dtype)
274
+ ).astype(self._float_dtype)
268
275
  transformed_center = rotation_matrix_inverted @ center.reshape(-1, 1)
269
276
  transformed_center = transformed_center.reshape(-1)
270
277
  base_offset = self.subtract(center, transformed_center)
@@ -0,0 +1,218 @@
1
+ """ Backend using jax for template matching.
2
+
3
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+ from typing import Tuple, Callable
8
+
9
+ from .npfftw_backend import NumpyFFTWBackend
10
+
11
+ class JaxBackend(NumpyFFTWBackend):
12
+ def __init__(
13
+ self, float_dtype=None, complex_dtype=None, int_dtype=None, **kwargs
14
+ ):
15
+ import jax.scipy as jsp
16
+ import jax.numpy as jnp
17
+
18
+ float_dtype = jnp.float32 if float_dtype is None else float_dtype
19
+ complex_dtype = jnp.complex64 if complex_dtype is None else complex_dtype
20
+ int_dtype = jnp.int32 if int_dtype is None else int_dtype
21
+
22
+ self.scipy = jsp
23
+ super().__init__(
24
+ array_backend=jnp,
25
+ float_dtype=float_dtype,
26
+ complex_dtype=complex_dtype,
27
+ int_dtype=int_dtype,
28
+ )
29
+
30
+ def to_backend_array(self, arr):
31
+ return self._array_backend.asarray(arr)
32
+
33
+ def preallocate_array(self, shape: Tuple[int], dtype: type):
34
+ arr = self._array_backend.zeros(shape, dtype=dtype)
35
+ return arr
36
+
37
+ def topleft_pad(self, arr, shape: Tuple[int], padval: int = 0):
38
+ b = self.preallocate_array(shape, arr.dtype)
39
+ self.add(b, padval, out=b)
40
+ aind = [slice(None, None)] * arr.ndim
41
+ bind = [slice(None, None)] * arr.ndim
42
+ for i in range(arr.ndim):
43
+ if arr.shape[i] > shape[i]:
44
+ aind[i] = slice(0, shape[i])
45
+ elif arr.shape[i] < shape[i]:
46
+ bind[i] = slice(0, arr.shape[i])
47
+ b = b.at[tuple(bind)].set(arr[tuple(aind)])
48
+ return b
49
+
50
+
51
+ def add(self, x1, x2, out = None, *args, **kwargs):
52
+ x1 = self.to_backend_array(x1)
53
+ x2 = self.to_backend_array(x2)
54
+ ret = self._array_backend.add(x1, x2, *args, **kwargs)
55
+
56
+ if out is not None:
57
+ out = out.at[:].set(ret)
58
+ return ret
59
+
60
+ def subtract(self, x1, x2, out = None, *args, **kwargs):
61
+ x1 = self.to_backend_array(x1)
62
+ x2 = self.to_backend_array(x2)
63
+ ret = self._array_backend.subtract(x1, x2, *args, **kwargs)
64
+ if out is not None:
65
+ out = out.at[:].set(ret)
66
+ return ret
67
+
68
+ def multiply(self, x1, x2, out = None, *args, **kwargs):
69
+ x1 = self.to_backend_array(x1)
70
+ x2 = self.to_backend_array(x2)
71
+ ret = self._array_backend.multiply(x1, x2, *args, **kwargs)
72
+ if out is not None:
73
+ out = out.at[:].set(ret)
74
+ return ret
75
+
76
+ def divide(self, x1, x2, out = None, *args, **kwargs):
77
+ x1 = self.to_backend_array(x1)
78
+ x2 = self.to_backend_array(x2)
79
+ ret = self._array_backend.divide(x1, x2, *args, **kwargs)
80
+ if out is not None:
81
+ out = out.at[:].set(ret)
82
+ return ret
83
+
84
+ def fill(self, arr, value: float) -> None:
85
+ arr.at[:].set(value)
86
+
87
+
88
+ def build_fft(
89
+ self,
90
+ fast_shape: Tuple[int],
91
+ fast_ft_shape: Tuple[int],
92
+ inverse_fast_shape: Tuple[int] = None,
93
+ **kwargs,
94
+ ) -> Tuple[Callable, Callable]:
95
+ """
96
+ Build fft builder functions.
97
+
98
+ Parameters
99
+ ----------
100
+ fast_shape : tuple
101
+ Tuple of integers corresponding to fast convolution shape
102
+ (see :py:meth:`PytorchBackend.compute_convolution_shapes`).
103
+ fast_ft_shape : tuple
104
+ Tuple of integers corresponding to the shape of the Fourier
105
+ transform array (see :py:meth:`PytorchBackend.compute_convolution_shapes`).
106
+ inverse_fast_shape : tuple, optional
107
+ Output shape of the inverse Fourier transform. By default fast_shape.
108
+ **kwargs : dict, optional
109
+ Unused keyword arguments.
110
+
111
+ Returns
112
+ -------
113
+ tuple
114
+ Tupple containing callable rfft and irfft object.
115
+ """
116
+ if inverse_fast_shape is None:
117
+ inverse_fast_shape = fast_shape
118
+
119
+ def rfftn(
120
+ arr, out, shape: Tuple[int] = fast_shape
121
+ ) -> None:
122
+ out = out.at[:].set(self._array_backend.fft.rfftn(arr, s=shape))
123
+
124
+ def irfftn(
125
+ arr, out, shape: Tuple[int] = inverse_fast_shape
126
+ ) -> None:
127
+ out = out.at[:].set(self._array_backend.fft.irfftn(arr, s=shape))
128
+
129
+ return rfftn, irfftn
130
+
131
+ def sharedarr_to_arr(self, shm, shape: Tuple[int], dtype: str):
132
+ return shm
133
+
134
+ @staticmethod
135
+ def arr_to_sharedarr(arr, shared_memory_handler: type = None):
136
+ return arr
137
+
138
+ def rotate_array(
139
+ self,
140
+ arr,
141
+ rotation_matrix,
142
+ arr_mask = None,
143
+ translation = None,
144
+ use_geometric_center: bool = False,
145
+ out = None,
146
+ out_mask = None,
147
+ order: int = 3,
148
+ ) -> None:
149
+ rotate_mask = arr_mask is not None
150
+ return_type = (out is None) + 2 * rotate_mask * (out_mask is None)
151
+ translation = self.zeros(arr.ndim) if translation is None else translation
152
+
153
+ indices = self._array_backend.indices(arr.shape).reshape(
154
+ (len(arr.shape), -1)
155
+ ).astype(self._float_dtype)
156
+
157
+ center = self.divide(arr.shape, 2)
158
+ if not use_geometric_center:
159
+ center = self.center_of_mass(arr, cutoff=0)
160
+ center = center[:, None]
161
+ indices = indices.at[:].add(-center)
162
+ rotation_matrix = self._array_backend.linalg.inv(rotation_matrix)
163
+ indices = self._array_backend.matmul(rotation_matrix, indices)
164
+ indices = indices.at[:].add(center)
165
+
166
+ out = self.zeros_like(arr) if out is None else out
167
+ out_slice = tuple(slice(0, stop) for stop in arr.shape)
168
+
169
+ out = out.at[out_slice].set(
170
+ self.scipy.ndimage.map_coordinates(
171
+ arr, indices, order=order
172
+ ).reshape(arr.shape)
173
+ )
174
+
175
+ if rotate_mask:
176
+ out_mask = self.zeros_like(arr_mask) if out_mask is None else out_mask
177
+ out_mask_slice = tuple(slice(0, stop) for stop in arr_mask.shape)
178
+ out_mask = out_mask.at[out_mask_slice].set(
179
+ self.scipy.ndimage.map_coordinates(
180
+ arr_mask, indices, order=order
181
+ ).reshape(arr.shape)
182
+ )
183
+
184
+ match return_type:
185
+ case 0:
186
+ return None
187
+ case 1:
188
+ return out
189
+ case 2:
190
+ return out_mask
191
+ case 3:
192
+ return out, out_mask
193
+
194
+ def max_score_over_rotations(
195
+ self,
196
+ score_space,
197
+ internal_scores,
198
+ internal_rotations,
199
+ rotation_index: int,
200
+ ):
201
+ """
202
+ Modify internal_scores and internal_rotations inplace with scores and rotation
203
+ index respectively, wherever score_sapce is larger than internal scores.
204
+
205
+ Parameters
206
+ ----------
207
+ score_space : CupyArray
208
+ The score space to compare against internal_scores.
209
+ internal_scores : CupyArray
210
+ The internal scores to update with maximum scores.
211
+ internal_rotations : CupyArray
212
+ The internal rotations corresponding to the maximum scores.
213
+ rotation_index : int
214
+ The index representing the current rotation.
215
+ """
216
+ indices = score_space > internal_scores
217
+ internal_scores.at[indices].set(score_space[indices])
218
+ internal_rotations.at[indices].set(rotation_index)
@@ -29,32 +29,39 @@ class MatchingBackend(ABC):
29
29
  ----------
30
30
  array_backend : object
31
31
  The backend object providing array functionalities.
32
- default_dtype : type
32
+ float_dtype : type
33
33
  Data type of real array instances, e.g. np.float32.
34
34
  complex_dtype : type
35
35
  Data type of complex array instances, e.g. np.complex64.
36
- default_dtype_int : type
36
+ int_dtype : type
37
37
  Data type of integer array instances, e.g. np.int32.
38
+ overflow_safe_dtype : type
39
+ Data type than can be used for reduction operations to avoid overflows.
38
40
 
39
41
  Attributes
40
42
  ----------
41
43
  _array_backend : object
42
44
  The backend object used to delegate method and attribute calls.
43
- _default_dtype : type
45
+ _float_dtype : type
44
46
  Data type of real array instances, e.g. np.float32.
45
47
  _complex_dtype : type
46
48
  Data type of complex array instances, e.g. np.complex64.
47
- _default_dtype_int : type
49
+ _int_dtype : type
48
50
  Data type of integer array instances, e.g. np.int32.
51
+ _overflow_safe_dtype : type
52
+ Data type than can be used for reduction operations to avoid overflows.
53
+ _fundamental_dtypes : Dict
54
+ Mapping between fundamental int, float and complex python types to
55
+ array backend specific data types.
49
56
 
50
57
  Examples
51
58
  --------
52
59
  >>> import numpy as np
53
60
  >>> backend = MatchingBackend(
54
61
  array_backend = np,
55
- default_dtype = np.float32,
62
+ float_dtype = np.float32,
56
63
  complex_dtype = np.complex64,
57
- default_dtype_int = np.int32
64
+ int_dtype = np.int32
58
65
  )
59
66
  >>> arr = backend.array([1, 2, 3])
60
67
  >>> print(arr)
@@ -69,14 +76,22 @@ class MatchingBackend(ABC):
69
76
  def __init__(
70
77
  self,
71
78
  array_backend,
72
- default_dtype: type,
79
+ float_dtype: type,
73
80
  complex_dtype: type,
74
- default_dtype_int: type,
81
+ int_dtype: type,
82
+ overflow_safe_dtype: type,
75
83
  ):
76
84
  self._array_backend = array_backend
77
- self._default_dtype = default_dtype
85
+ self._float_dtype = float_dtype
78
86
  self._complex_dtype = complex_dtype
79
- self._default_dtype_int = default_dtype_int
87
+ self._int_dtype = int_dtype
88
+ self._overflow_safe_dtype = overflow_safe_dtype
89
+
90
+ self._fundamental_dtypes = {
91
+ int: self._int_dtype,
92
+ float: self._float_dtype,
93
+ complex: self._complex_dtype,
94
+ }
80
95
 
81
96
  def __getattr__(self, name: str):
82
97
  """
@@ -1,4 +1,4 @@
1
- """ Backend Apple's MLX library for template matching.
1
+ """ Backend using Apple's MLX library for template matching.
2
2
 
3
3
  Copyright (c) 2024 European Molecular Biology Laboratory
4
4
 
@@ -20,23 +20,27 @@ class MLXBackend(NumpyFFTWBackend):
20
20
  def __init__(
21
21
  self,
22
22
  device="cpu",
23
- default_dtype=None,
23
+ float_dtype=None,
24
24
  complex_dtype=None,
25
- default_dtype_int=None,
25
+ int_dtype=None,
26
+ overflow_safe_dtype=None,
26
27
  **kwargs,
27
28
  ):
28
29
  import mlx.core as mx
29
30
 
30
31
  device = mx.cpu if device == "cpu" else mx.gpu
31
- default_dtype = mx.float32 if default_dtype is None else default_dtype
32
+ float_dtype = mx.float32 if float_dtype is None else float_dtype
32
33
  complex_dtype = mx.complex64 if complex_dtype is None else complex_dtype
33
- default_dtype_int = mx.int32 if default_dtype_int is None else default_dtype_int
34
+ int_dtype = mx.int32 if int_dtype is None else int_dtype
35
+ if overflow_safe_dtype is None:
36
+ overflow_safe_dtype = mx.float32
34
37
 
35
38
  super().__init__(
36
39
  array_backend=mx,
37
- default_dtype=default_dtype,
40
+ float_dtype=float_dtype,
38
41
  complex_dtype=complex_dtype,
39
- default_dtype_int=default_dtype_int,
42
+ int_dtype=int_dtype,
43
+ overflow_safe_dtype=overflow_safe_dtype,
40
44
  )
41
45
 
42
46
  self.device = device
@@ -141,8 +145,8 @@ class MLXBackend(NumpyFFTWBackend):
141
145
  new_shape = self.to_backend_array(newshape)
142
146
  current_shape = self.to_backend_array(arr.shape)
143
147
  starts = self.subtract(current_shape, new_shape)
144
- starts = self.astype(self.divide(starts, 2), self._default_dtype_int)
145
- stops = self.astype(self.add(starts, newshape), self._default_dtype_int)
148
+ starts = self.astype(self.divide(starts, 2), self._int_dtype)
149
+ stops = self.astype(self.add(starts, newshape), self._int_dtype)
146
150
  starts, stops = starts.tolist(), stops.tolist()
147
151
  box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
148
152
  return arr[box]
@@ -29,16 +29,18 @@ class NumpyFFTWBackend(MatchingBackend):
29
29
  def __init__(
30
30
  self,
31
31
  array_backend=np,
32
- default_dtype=np.float32,
32
+ float_dtype=np.float32,
33
33
  complex_dtype=np.complex64,
34
- default_dtype_int=np.int32,
34
+ int_dtype=np.int32,
35
+ overflow_safe_dtype=np.float32,
35
36
  **kwargs,
36
37
  ):
37
38
  super().__init__(
38
39
  array_backend=array_backend,
39
- default_dtype=default_dtype,
40
+ float_dtype=float_dtype,
40
41
  complex_dtype=complex_dtype,
41
- default_dtype_int=default_dtype_int,
42
+ int_dtype=int_dtype,
43
+ overflow_safe_dtype=overflow_safe_dtype,
42
44
  )
43
45
  self.affine_transform = affine_transform
44
46
 
@@ -53,6 +55,16 @@ class NumpyFFTWBackend(MatchingBackend):
53
55
  def to_cpu_array(self, arr: NDArray) -> NDArray:
54
56
  return arr
55
57
 
58
def get_fundamental_dtype(self, arr):
    """
    Classify the dtype of ``arr`` as one of the Python scalar types.

    Parameters
    ----------
    arr : array
        Array whose ``dtype`` is inspected with the array backend's
        ``issubdtype``.

    Returns
    -------
    type
        ``int`` for integer dtypes, ``float`` for floating dtypes, ``complex``
        for complex floating dtypes. Any other dtype (e.g. bool) maps to
        ``float``.
    """
    backend = self._array_backend
    checks = (
        ("integer", int),
        ("floating", float),
        ("complexfloating", complex),
    )
    for abstract_name, python_type in checks:
        # Resolve each abstract dtype lazily, in the same order as before.
        if backend.issubdtype(arr.dtype, getattr(backend, abstract_name)):
            return python_type
    return float
67
+
56
68
  def free_cache(self):
57
69
  pass
58
70
 
@@ -429,8 +441,8 @@ class NumpyFFTWBackend(MatchingBackend):
429
441
  new_shape = self.to_backend_array(newshape)
430
442
  current_shape = self.to_backend_array(arr.shape)
431
443
  starts = self.subtract(current_shape, new_shape)
432
- starts = self.astype(self.divide(starts, 2), self._default_dtype_int)
433
- stops = self.astype(self.add(starts, newshape), self._default_dtype_int)
444
+ starts = self.astype(self.divide(starts, 2), self._int_dtype)
445
+ stops = self.astype(self.add(starts, newshape), self._int_dtype)
434
446
  box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
435
447
  return arr[box]
436
448
 
@@ -722,11 +734,11 @@ class NumpyFFTWBackend(MatchingBackend):
722
734
  arr = self._array_backend.where(arr > cutoff, arr, 0)
723
735
  denominator = self.sum(arr)
724
736
  grids = self._array_backend.ogrid[tuple(slice(0, i) for i in arr.shape)]
725
- grids = [grid.astype(self._default_dtype) for grid in grids]
737
+ grids = [grid.astype(self._float_dtype) for grid in grids]
726
738
 
727
739
  center_of_mass = self.array(
728
740
  [
729
- self.sum(self.multiply(arr, grids[dim])) / denominator
741
+ self.sum(self.multiply(arr, grids[dim]) / denominator)
730
742
  for dim in range(arr.ndim)
731
743
  ]
732
744
  )
@@ -24,25 +24,27 @@ class PytorchBackend(NumpyFFTWBackend):
24
24
  def __init__(
25
25
  self,
26
26
  device="cuda",
27
- default_dtype=None,
27
+ float_dtype=None,
28
28
  complex_dtype=None,
29
- default_dtype_int=None,
29
+ int_dtype=None,
30
+ overflow_safe_dtype=None,
30
31
  **kwargs,
31
32
  ):
32
33
  import torch
33
34
  import torch.nn.functional as F
34
35
 
35
- default_dtype = torch.float32 if default_dtype is None else default_dtype
36
+ float_dtype = torch.float32 if float_dtype is None else float_dtype
36
37
  complex_dtype = torch.complex64 if complex_dtype is None else complex_dtype
37
- default_dtype_int = (
38
- torch.int32 if default_dtype_int is None else default_dtype_int
39
- )
38
+ int_dtype = torch.int32 if int_dtype is None else int_dtype
39
+ if overflow_safe_dtype is None:
40
+ overflow_safe_dtype = torch.float32
40
41
 
41
42
  super().__init__(
42
43
  array_backend=torch,
43
- default_dtype=default_dtype,
44
+ float_dtype=float_dtype,
44
45
  complex_dtype=complex_dtype,
45
- default_dtype_int=default_dtype_int,
46
+ int_dtype=int_dtype,
47
+ overflow_safe_dtype=overflow_safe_dtype,
46
48
  )
47
49
  self.device = device
48
50
  self.F = F
@@ -57,11 +59,20 @@ class PytorchBackend(NumpyFFTWBackend):
57
59
  def to_numpy_array(self, arr: TorchTensor) -> NDArray:
58
60
  if isinstance(arr, np.ndarray):
59
61
  return arr
60
- return arr.cpu().numpy()
62
+ elif isinstance(arr, self._array_backend.Tensor):
63
+ return arr.cpu().numpy()
64
+ return np.array(arr)
61
65
 
62
66
  def to_cpu_array(self, arr: TorchTensor) -> NDArray:
63
67
  return arr.cpu()
64
68
 
69
def get_fundamental_dtype(self, arr):
    """
    Classify a torch tensor's dtype as one of the Python scalar types.

    Parameters
    ----------
    arr : torch.Tensor
        Tensor to inspect.

    Returns
    -------
    type
        ``float`` for floating point tensors, ``complex`` for complex tensors
        and ``int`` for everything else (including bool tensors).
    """
    torch_module = self._array_backend
    if torch_module.is_floating_point(arr):
        return float
    return complex if torch_module.is_complex(arr) else int
75
+
65
76
  def free_cache(self):
66
77
  self._array_backend.cuda.empty_cache()
67
78