pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
- pytme-0.2.9.data/scripts/match_template.py +1135 -0
- pytme-0.2.9.data/scripts/postprocess.py +622 -0
- pytme-0.2.9.data/scripts/preprocess.py +209 -0
- pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
- pytme-0.2.9.dist-info/METADATA +95 -0
- pytme-0.2.9.dist-info/RECORD +119 -0
- pytme-0.2.9.dist-info/WHEEL +5 -0
- pytme-0.2.9.dist-info/entry_points.txt +6 -0
- pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
- pytme-0.2.9.dist-info/top_level.txt +3 -0
- scripts/__init__.py +0 -0
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +1135 -0
- scripts/postprocess.py +622 -0
- scripts/preprocess.py +209 -0
- scripts/preprocessor_gui.py +1227 -0
- tests/__init__.py +0 -0
- tests/data/Blurring/blob_width18.npy +0 -0
- tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
- tests/data/Blurring/gaussian_sigma2.npy +0 -0
- tests/data/Blurring/hamming_width6.npy +0 -0
- tests/data/Blurring/kaiserb_width18.npy +0 -0
- tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
- tests/data/Blurring/mean_size5.npy +0 -0
- tests/data/Blurring/ntree_sigma0510.npy +0 -0
- tests/data/Blurring/rank_rank3.npy +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Maps/emd_8621.mrc.gz +0 -0
- tests/data/README.md +2 -0
- tests/data/Raw/em_map.map +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/data/Structures/1pdj.cif +3339 -0
- tests/data/Structures/1pdj.pdb +1429 -0
- tests/data/Structures/5khe.cif +3685 -0
- tests/data/Structures/5khe.ent +2210 -0
- tests/data/Structures/5khe.pdb +2210 -0
- tests/data/Structures/5uz4.cif +70548 -0
- tests/preprocessing/__init__.py +0 -0
- tests/preprocessing/test_compose.py +76 -0
- tests/preprocessing/test_frequency_filters.py +178 -0
- tests/preprocessing/test_preprocessor.py +136 -0
- tests/preprocessing/test_utils.py +79 -0
- tests/test_analyzer.py +216 -0
- tests/test_backends.py +446 -0
- tests/test_density.py +503 -0
- tests/test_extensions.py +130 -0
- tests/test_matching_cli.py +283 -0
- tests/test_matching_data.py +162 -0
- tests/test_matching_exhaustive.py +124 -0
- tests/test_matching_memory.py +30 -0
- tests/test_matching_optimization.py +226 -0
- tests/test_matching_utils.py +189 -0
- tests/test_orientations.py +175 -0
- tests/test_parser.py +33 -0
- tests/test_rotations.py +153 -0
- tests/test_structure.py +247 -0
- tme/__init__.py +6 -0
- tme/__version__.py +1 -0
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +186 -0
- tme/analyzer/aggregation.py +577 -0
- tme/analyzer/peaks.py +953 -0
- tme/backends/__init__.py +171 -0
- tme/backends/_cupy_utils.py +734 -0
- tme/backends/_jax_utils.py +188 -0
- tme/backends/cupy_backend.py +294 -0
- tme/backends/jax_backend.py +314 -0
- tme/backends/matching_backend.py +1270 -0
- tme/backends/mlx_backend.py +241 -0
- tme/backends/npfftw_backend.py +583 -0
- tme/backends/pytorch_backend.py +430 -0
- tme/data/__init__.py +0 -0
- tme/data/c48n309.npy +0 -0
- tme/data/c48n527.npy +0 -0
- tme/data/c48n9.npy +0 -0
- tme/data/c48u1.npy +0 -0
- tme/data/c48u1153.npy +0 -0
- tme/data/c48u1201.npy +0 -0
- tme/data/c48u1641.npy +0 -0
- tme/data/c48u181.npy +0 -0
- tme/data/c48u2219.npy +0 -0
- tme/data/c48u27.npy +0 -0
- tme/data/c48u2947.npy +0 -0
- tme/data/c48u3733.npy +0 -0
- tme/data/c48u4749.npy +0 -0
- tme/data/c48u5879.npy +0 -0
- tme/data/c48u7111.npy +0 -0
- tme/data/c48u815.npy +0 -0
- tme/data/c48u83.npy +0 -0
- tme/data/c48u8649.npy +0 -0
- tme/data/c600v.npy +0 -0
- tme/data/c600vc.npy +0 -0
- tme/data/metadata.yaml +80 -0
- tme/data/quat_to_numpy.py +42 -0
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +2263 -0
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/external/bindings.cpp +332 -0
- tme/filters/__init__.py +6 -0
- tme/filters/_utils.py +311 -0
- tme/filters/bandpass.py +230 -0
- tme/filters/compose.py +81 -0
- tme/filters/ctf.py +393 -0
- tme/filters/reconstruction.py +160 -0
- tme/filters/wedge.py +542 -0
- tme/filters/whitening.py +191 -0
- tme/matching_data.py +863 -0
- tme/matching_exhaustive.py +497 -0
- tme/matching_optimization.py +1311 -0
- tme/matching_scores.py +1183 -0
- tme/matching_utils.py +1188 -0
- tme/memory.py +337 -0
- tme/orientations.py +598 -0
- tme/parser.py +685 -0
- tme/preprocessor.py +1329 -0
- tme/rotations.py +350 -0
- tme/structure.py +1864 -0
- tme/types.py +13 -0
@@ -0,0 +1,430 @@
|
|
1
|
+
""" Backend using pytorch and optionally GPU acceleration for
|
2
|
+
template matching.
|
3
|
+
|
4
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
5
|
+
|
6
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
|
+
"""
|
8
|
+
|
9
|
+
from typing import Tuple, Callable
|
10
|
+
from contextlib import contextmanager
|
11
|
+
from multiprocessing import shared_memory
|
12
|
+
from multiprocessing.managers import SharedMemoryManager
|
13
|
+
|
14
|
+
import numpy as np
|
15
|
+
from .npfftw_backend import NumpyFFTWBackend
|
16
|
+
from ..types import NDArray, TorchTensor, shm_type
|
17
|
+
|
18
|
+
|
19
|
+
class PytorchBackend(NumpyFFTWBackend):
    """
    A pytorch-based matching backend.

    Array operations are dispatched to :mod:`torch`, and tensors created by
    this class are placed on ``self.device``. Operations that are not
    overridden here are inherited from :py:class:`NumpyFFTWBackend`, which is
    initialized with ``torch`` as its array backend.
    """

    def __init__(
        self,
        device="cuda",
        float_dtype=None,
        complex_dtype=None,
        int_dtype=None,
        overflow_safe_dtype=None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        device : str, optional
            Device newly created tensors are placed on, ``"cuda"`` by default.
        float_dtype : torch.dtype, optional
            Floating point dtype, ``torch.float32`` if None.
        complex_dtype : torch.dtype, optional
            Complex dtype, ``torch.complex64`` if None.
        int_dtype : torch.dtype, optional
            Integer dtype, ``torch.int32`` if None.
        overflow_safe_dtype : torch.dtype, optional
            Passed on as ``overflow_safe_dtype``, ``torch.float32`` if None.
        **kwargs
            Accepted for signature compatibility; not used here.
        """
        # Imported here rather than at module level, presumably so importing
        # this module does not require torch.
        import torch
        import torch.nn.functional as F

        float_dtype = torch.float32 if float_dtype is None else float_dtype
        complex_dtype = torch.complex64 if complex_dtype is None else complex_dtype
        int_dtype = torch.int32 if int_dtype is None else int_dtype
        if overflow_safe_dtype is None:
            overflow_safe_dtype = torch.float32

        super().__init__(
            array_backend=torch,
            float_dtype=float_dtype,
            complex_dtype=complex_dtype,
            int_dtype=int_dtype,
            overflow_safe_dtype=overflow_safe_dtype,
        )
        self.device = device
        self.F = F

    def to_backend_array(self, arr: NDArray, check_device: bool = True) -> TorchTensor:
        """
        Convert ``arr`` to a torch tensor.

        Tensors are returned unchanged when they already live on
        ``self.device`` or when ``check_device`` is False; otherwise they are
        moved to ``self.device``. Non-tensor inputs are converted with
        ``self.tensor`` on ``self.device``.
        """
        if isinstance(arr, self._array_backend.Tensor):
            if arr.device == self.device or not check_device:
                return arr
            return arr.to(self.device)
        return self.tensor(arr, device=self.device)

    def to_numpy_array(self, arr: TorchTensor) -> NDArray:
        """Return ``arr`` as a numpy array, copying tensors to the CPU first."""
        if isinstance(arr, np.ndarray):
            return arr
        elif isinstance(arr, self._array_backend.Tensor):
            return arr.cpu().numpy()
        return np.array(arr)

    def to_cpu_array(self, arr: TorchTensor) -> NDArray:
        """Return a CPU copy of the tensor ``arr`` (still a torch tensor)."""
        return arr.cpu()

    def get_fundamental_dtype(self, arr):
        """Map the dtype of ``arr`` to one of the Python types float, complex or int."""
        if self._array_backend.is_floating_point(arr):
            return float
        elif self._array_backend.is_complex(arr):
            return complex
        return int

    def free_cache(self):
        """Release cached, unused CUDA memory held by torch's allocator."""
        self._array_backend.cuda.empty_cache()

    def mod(self, x1, x2, *args, **kwargs):
        """Element-wise remainder via ``torch.remainder``."""
        return self._array_backend.remainder(x1, x2, *args, **kwargs)

    def max(self, *args, **kwargs) -> NDArray:
        """Maximum via ``torch.amax``; non-tensor results are unwrapped with ``[0]``."""
        ret = self._array_backend.amax(*args, **kwargs)
        if isinstance(ret, self._array_backend.Tensor):
            return ret
        return ret[0]

    def min(self, *args, **kwargs) -> NDArray:
        """Minimum via ``torch.amin``; non-tensor results are unwrapped with ``[0]``."""
        ret = self._array_backend.amin(*args, **kwargs)
        if isinstance(ret, self._array_backend.Tensor):
            return ret
        return ret[0]

    def maximum(self, x1, x2, *args, **kwargs) -> NDArray:
        """
        Element-wise maximum of ``x1`` and ``x2``.

        A tensor ``x1`` stays on its current device (non-tensors are created on
        ``self.device``) and ``x2`` is moved to the device of ``x1``.
        """
        x1 = self.to_backend_array(x1, check_device=False)
        x2 = self.to_backend_array(x2, check_device=False).to(x1.device)
        return self._array_backend.maximum(input=x1, other=x2, *args, **kwargs)

    def minimum(self, x1, x2, *args, **kwargs) -> NDArray:
        """
        Element-wise minimum of ``x1`` and ``x2``.

        A tensor ``x1`` stays on its current device (non-tensors are created on
        ``self.device``) and ``x2`` is moved to the device of ``x1``.
        """
        # Same device handling as ``maximum``: align x2 with x1 instead of
        # moving both operands to ``self.device``.
        x1 = self.to_backend_array(x1, check_device=False)
        x2 = self.to_backend_array(x2, check_device=False).to(x1.device)
        return self._array_backend.minimum(input=x1, other=x2, *args, **kwargs)

    def tobytes(self, arr):
        """Return the raw bytes of ``arr`` after copying it to the CPU."""
        return arr.cpu().numpy().tobytes()

    def size(self, arr):
        """Return the number of elements in ``arr``."""
        return arr.numel()

    def zeros(self, shape, dtype=None):
        """Return a zero-filled tensor of ``shape`` on ``self.device``."""
        return self._array_backend.zeros(shape, dtype=dtype, device=self.device)

    def copy(self, arr: TorchTensor) -> TorchTensor:
        """Return a copy of ``arr`` via ``torch.clone``."""
        return self._array_backend.clone(arr)

    def full(self, shape, fill_value, dtype=None):
        """Return a tensor of ``shape`` filled with ``fill_value`` on ``self.device``."""
        if isinstance(shape, int):
            shape = (shape,)
        return self._array_backend.full(
            size=shape, dtype=dtype, fill_value=fill_value, device=self.device
        )

    def arange(self, *args, **kwargs):
        """``torch.arange`` on ``self.device``."""
        return self._array_backend.arange(*args, **kwargs, device=self.device)

    def datatype_bytes(self, dtype: type) -> int:
        """Return the size in bytes of a single element of ``dtype``."""
        temp = self.zeros(1, dtype=dtype)
        return temp.element_size()

    def fill(self, arr: TorchTensor, value: float) -> TorchTensor:
        """Fill ``arr`` with ``value`` in place and return it."""
        arr.fill_(value)
        return arr

    def astype(self, arr: TorchTensor, dtype: type) -> TorchTensor:
        """Return ``arr`` converted to ``dtype``."""
        return arr.to(dtype)

    def flip(self, a, axis, **kwargs):
        """Reverse ``a`` along the dimensions in ``axis`` via ``torch.flip``."""
        return self._array_backend.flip(input=a, dims=axis, **kwargs)

    def topk_indices(self, arr, k):
        """Return the coordinates of the ``k`` largest entries of ``arr`` (see ``unravel_index``)."""
        _, flat_indices = self._array_backend.topk(arr.reshape(-1), k)
        return self.unravel_index(indices=flat_indices, shape=arr.shape)

    def indices(self, shape: Tuple[int], dtype: type = int) -> TorchTensor:
        """Return a grid of indices of ``shape``, stacked along a new first axis ("ij" indexing)."""
        grids = [self.arange(x, dtype=dtype) for x in shape]
        mesh = self._array_backend.meshgrid(*grids, indexing="ij")
        return self._array_backend.stack(mesh)

    def unravel_index(self, indices, shape):
        """
        Convert flat ``indices`` into coordinates for an array of ``shape``.

        Returns a tuple with one tensor per dimension, each holding the
        coordinates of all indices along that dimension. NOTE: when exactly
        one index is given, the result is instead a 1-tuple containing that
        index's coordinate vector (one entry per dimension).
        """
        indices = self.to_backend_array(indices)
        shape = self.to_backend_array(shape)
        # Row-major strides: product of all trailing dimension sizes.
        strides = self._array_backend.cumprod(shape.flip(0), dim=0).flip(0)
        strides = self._array_backend.cat(
            (strides[1:], self.to_backend_array([1])),
        )
        unraveled_coords = (indices.view(-1, 1) // strides.view(1, -1)) % shape.view(
            1, -1
        )
        if unraveled_coords.size(0) == 1:
            return (unraveled_coords[0, :],)

        else:
            return tuple(unraveled_coords.T)

    def roll(self, a, shift, axis, **kwargs):
        """Roll ``a`` by ``shift`` along ``axis`` via ``torch.roll``."""
        shift = tuple(shift)
        return self._array_backend.roll(input=a, shifts=shift, dims=axis, **kwargs)

    def unique(
        self,
        ar,
        return_index: bool = False,
        return_inverse: bool = False,
        return_counts: bool = False,
        axis: int = None,
        sorted: bool = True,
    ):
        """
        Numpy-style ``unique`` built on ``torch.unique``.

        Returns the unique values alone, or a list ``[unique, index, inverse,
        counts]`` restricted to the requested entries if any ``return_*``
        flag is set. ``index`` holds the position of the first occurrence of
        each unique value.
        """
        # https://github.com/pytorch/pytorch/issues/36748#issuecomment-1478913448
        unique, inverse, counts = self._array_backend.unique(
            ar, return_inverse=True, return_counts=True, dim=axis, sorted=sorted
        )
        inverse = inverse.reshape(-1)

        if return_index:
            # A stable sort of the inverse groups equal values in original
            # order; the start offset of each group is its first occurrence.
            inv_sorted = inverse.argsort(stable=True)
            tot_counts = self._array_backend.cat(
                (counts.new_zeros(1), counts.cumsum(dim=0))
            )[:-1]
            index = inv_sorted[tot_counts]

        ret = unique
        if return_index or return_inverse or return_counts:
            ret = [unique]

        if return_index:
            ret.append(index)
        if return_inverse:
            ret.append(inverse)
        if return_counts:
            ret.append(counts)

        return ret

    def max_filter_coordinates(self, score_space, min_distance: Tuple[int]):
        """
        Return the coordinates of the maxima of non-overlapping pooling windows.

        Parameters
        ----------
        score_space : TorchTensor
            2D or 3D score array.
        min_distance : int or tuple of int
            Pooling window size, either shared by all axes or given per axis.
            Windows are padded by half their size.

        Returns
        -------
        TorchTensor
            Array of shape ``(n_windows, score_space.ndim)`` with one
            coordinate per pooling window.

        Raises
        ------
        NotImplementedError
            If ``score_space`` is neither 2D nor 3D.
        """
        if score_space.ndim == 3:
            func = self._array_backend.nn.MaxPool3d
        elif score_space.ndim == 2:
            func = self._array_backend.nn.MaxPool2d
        else:
            raise NotImplementedError("Operation only implemented for 2 and 3D inputs.")

        # Per-axis window sizes need per-axis padding; ``tuple // 2`` would
        # raise a TypeError.
        try:
            kernel_size = tuple(int(x) for x in min_distance)
        except TypeError:
            # Scalar window size shared by all axes.
            kernel_size = min_distance
            padding = min_distance // 2
        else:
            padding = tuple(x // 2 for x in kernel_size)

        pool = func(kernel_size=kernel_size, padding=padding, return_indices=True)
        _, indices = pool(score_space.reshape(1, 1, *score_space.shape))
        coordinates = self.unravel_index(indices.reshape(-1), score_space.shape)
        if len(coordinates) == 1:
            # With a single window, unravel_index returns a 1-tuple holding one
            # coordinate vector; stacking and transposing it would give shape
            # (ndim, 1) instead of (1, ndim).
            return coordinates[0].reshape(1, -1)
        return self.transpose(self.stack(coordinates))

    def repeat(self, *args, **kwargs):
        """Numpy-style ``repeat`` via ``torch.repeat_interleave``."""
        return self._array_backend.repeat_interleave(*args, **kwargs)

    def from_sharedarr(self, args) -> TorchTensor:
        """
        Reconstruct a tensor from the output of :py:meth:`to_sharedarr`.

        On ``"cuda"`` the argument is returned unchanged. Otherwise ``args`` is
        a ``(shm, shape, dtype)`` triple, and the result views the first
        ``prod(shape)`` elements of the shared buffer, reshaped to ``shape``.
        """
        if self.device == "cuda":
            return args

        shm, shape, dtype = args
        required_size = int(self._array_backend.prod(self.to_backend_array(shape)))

        ret = self._array_backend.frombuffer(shm.buf, dtype=dtype)[
            :required_size
        ].reshape(shape)
        return ret

    def to_sharedarr(
        self, arr: TorchTensor, shared_memory_handler: type = None
    ) -> shm_type:
        """
        Place ``arr`` in shared memory.

        On ``"cuda"`` the tensor is returned unchanged. Otherwise its bytes
        are copied into a new shared memory block, which is allocated through
        ``shared_memory_handler`` when that is a SharedMemoryManager. The
        method then returns ``(shm, shape, dtype)``. ``arr.numpy()`` requires
        ``arr`` to be a CPU tensor.
        """
        if self.device == "cuda":
            return arr

        nbytes = arr.numel() * arr.element_size()

        if isinstance(shared_memory_handler, SharedMemoryManager):
            shm = shared_memory_handler.SharedMemory(size=nbytes)
        else:
            shm = shared_memory.SharedMemory(create=True, size=nbytes)

        shm.buf[:nbytes] = arr.numpy().tobytes()
        return shm, arr.shape, arr.dtype

    def transpose(self, arr, axes=None):
        """Permute the axes of ``arr``; reverses them when ``axes`` is None."""
        if axes is None:
            axes = tuple(range(arr.ndim - 1, -1, -1))
        return arr.permute(axes)

    def power(self, *args, **kwargs):
        """Element-wise power via ``torch.pow``."""
        return self._array_backend.pow(*args, **kwargs)

    def eye(self, *args, **kwargs):
        """``torch.eye`` placed on ``self.device`` unless ``device`` is given."""
        if "device" not in kwargs:
            kwargs["device"] = self.device
        return self._array_backend.eye(*args, **kwargs)

    def build_fft(
        self,
        fwd_shape: Tuple[int],
        inv_shape: Tuple[int],
        inv_output_shape: Tuple[int] = None,
        fwd_axes: Tuple[int] = None,
        inv_axes: Tuple[int] = None,
        **kwargs,
    ) -> Tuple[Callable, Callable]:
        """
        Return ``(rfftn, irfftn)`` closures with shapes and axes bound.

        The forward transform uses ``fwd_shape``/``fwd_axes``. The inverse uses
        ``inv_output_shape`` (or ``fwd_shape`` if None) with ``inv_axes``.
        ``inv_shape`` and ``**kwargs`` are not used here.
        """
        rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
        irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
        irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)

        def rfftn(
            arr: TorchTensor, out: TorchTensor, s=rfft_shape, axes=fwd_axes
        ) -> TorchTensor:
            return self._array_backend.fft.rfftn(arr, s=s, out=out, dim=axes)

        def irfftn(
            arr: TorchTensor, out: TorchTensor = None, s=irfft_shape, axes=inv_axes
        ) -> TorchTensor:
            return self._array_backend.fft.irfftn(arr, s=s, out=out, dim=axes)

        return rfftn, irfftn

    def rfftn(self, arr: NDArray, *args, **kwargs) -> NDArray:
        """
        N-dimensional real FFT accepting numpy's ``axes`` keyword.

        ``axes`` is forwarded as torch's ``dim``; an explicit ``dim`` is kept
        when ``axes`` is absent. Extra positional arguments are ignored.
        """
        if "axes" in kwargs:
            kwargs["dim"] = kwargs.pop("axes")
        return self._array_backend.fft.rfftn(arr, **kwargs)

    def irfftn(self, arr: NDArray, *args, **kwargs) -> NDArray:
        """
        N-dimensional inverse real FFT accepting numpy's ``axes`` keyword.

        ``axes`` is forwarded as torch's ``dim``; an explicit ``dim`` is kept
        when ``axes`` is absent. Extra positional arguments are ignored.
        """
        if "axes" in kwargs:
            kwargs["dim"] = kwargs.pop("axes")
        return self._array_backend.fft.irfftn(arr, **kwargs)

    def rigid_transform(
        self,
        arr: TorchTensor,
        rotation_matrix: TorchTensor,
        arr_mask: TorchTensor = None,
        translation: TorchTensor = None,
        use_geometric_center: bool = False,
        out: TorchTensor = None,
        out_mask: TorchTensor = None,
        order: int = 1,
        cache: bool = False,
    ):
        """
        Rotate and translate ``arr`` (and optionally ``arr_mask``).

        Parameters
        ----------
        arr : TorchTensor
            Array to transform. If ``arr`` has more axes than
            ``rotation_matrix.shape[0]``, the extra leading axis directly before
            the spatial axes is treated as a batch axis. Any axes further in
            front are indexed at 0.
        rotation_matrix : TorchTensor
            Square rotation matrix. It is axis-flipped before use with
            ``affine_grid``, presumably because it is given in array-axis order.
        arr_mask : TorchTensor, optional
            Mask transformed with the same parameters as ``arr``.
        translation : TorchTensor, optional
            Translation per axis of ``arr``, zeros if None. It is divided by
            ``arr.shape``, presumably so it can be given in voxels.
        use_geometric_center : bool, optional
            Not used by this implementation.
        out : TorchTensor, optional
            Output buffer; a zero array like ``arr`` is allocated if None.
        out_mask : TorchTensor, optional
            Mask output buffer; allocated if None and ``arr_mask`` is given.
        order : int, optional
            Interpolation order: 0 (nearest), 1 (bilinear) or 3 (bicubic).
            torch's grid_sample supports bicubic only for 2D data.
        cache : bool, optional
            Not used by this implementation.

        Returns
        -------
        tuple
            ``(out, out_mask)``; ``out_mask`` is returned as passed when
            ``arr_mask`` is None.

        Raises
        ------
        ValueError
            If ``order`` is not a supported interpolation order.
        """
        _mode_mapping = {0: "nearest", 1: "bilinear", 3: "bicubic"}
        mode = _mode_mapping.get(order, None)
        if mode is None:
            modes = ", ".join([str(x) for x in _mode_mapping.keys()])
            raise ValueError(
                f"Got {order} but supported interpolation orders are: {modes}."
            )

        out = self.zeros_like(arr) if out is None else out

        if translation is None:
            translation = self._array_backend.zeros(arr.ndim, device=arr.device)

        # Negate and rescale the translation into affine_grid's normalized
        # coordinate range (which spans 2 units over each axis).
        normalized_translation = self.divide(
            -2.0 * translation, self.tensor(arr.shape, device=arr.device)
        )
        # affine_grid orders coordinates last-axis first, hence the flip.
        rotation_matrix_pull = self.linalg.inv(self.flip(rotation_matrix, [0, 1]))

        out_slice = tuple(slice(0, x) for x in arr.shape)
        subset = tuple(slice(None) for _ in range(arr.ndim))
        offset = max(int(arr.ndim - rotation_matrix.shape[0]) - 1, 0)
        if offset > 0:
            normalized_translation = normalized_translation[offset:]
            subset = tuple(0 if i < offset else slice(None) for i in range(arr.ndim))
            out_slice = tuple(
                slice(0, 1) if i < offset else slice(0, x)
                for i, x in enumerate(arr.shape)
            )

        out[out_slice] = self._affine_transform(
            arr=arr[subset],
            rotation_matrix=rotation_matrix_pull,
            translation=normalized_translation,
            mode=mode,
        )

        if arr_mask is not None:
            out_mask_slice = tuple(slice(0, x) for x in arr_mask.shape)
            if out_mask is None:
                out_mask = self._array_backend.zeros_like(arr_mask)
            out_mask[out_mask_slice] = self._affine_transform(
                arr=arr_mask[subset],
                rotation_matrix=rotation_matrix_pull,
                translation=normalized_translation,
                mode=mode,
            )

        return out, out_mask

    def _affine_transform(
        self,
        arr: TorchTensor,
        rotation_matrix: TorchTensor,
        translation: TorchTensor,
        mode,
    ) -> TorchTensor:
        """
        Resample ``arr`` using ``affine_grid`` and ``grid_sample``.

        ``arr`` is treated as batched along its first axis when its ndim
        differs from the matrix size. In that case the first translation
        entry is dropped. Returns a tensor with the same shape as ``arr``.
        """
        batched = arr.ndim != rotation_matrix.shape[0]

        batch_size, spatial_dims = 1, arr.shape
        if batched:
            translation = translation[1:]
            batch_size, *spatial_dims = arr.shape

        n_dims = len(spatial_dims)
        transformation_matrix = self._array_backend.zeros(
            n_dims, n_dims + 1, device=arr.device, dtype=arr.dtype
        )

        transformation_matrix[:, :n_dims] = rotation_matrix
        transformation_matrix[:, n_dims] = translation
        transformation_matrix = transformation_matrix.unsqueeze(0).expand(
            batch_size, -1, -1
        )

        if not batched:
            arr = arr.unsqueeze(0)

        size = self.Size([batch_size, 1, *spatial_dims])
        grid = self.F.affine_grid(
            theta=transformation_matrix, size=size, align_corners=False
        )
        output = self.F.grid_sample(
            input=arr.unsqueeze(1),
            grid=grid,
            mode=mode,
            align_corners=False,
        )

        # Remove the channel axis first, then the synthetic batch axis.
        # Squeezing axis 0 first would shift the channel axis to position 0.
        # The following squeeze(1) would then target the first spatial axis.
        output = output[:, 0]
        if not batched:
            output = output[0]
        return output

    def get_available_memory(self) -> int:
        """Defer to the parent on ``"cpu"``, else return free CUDA memory from ``mem_get_info``."""
        if self.device == "cpu":
            return super().get_available_memory()
        return self._array_backend.cuda.mem_get_info()[0]

    @contextmanager
    def set_device(self, device_index: int):
        """Activate CUDA device ``device_index`` when ``self.device == "cuda"``; no-op otherwise."""
        if self.device == "cuda":
            with self._array_backend.cuda.device(device_index):
                yield
        else:
            yield None

    def device_count(self) -> int:
        """Return 1 on ``"cpu"``, else the number of visible CUDA devices."""
        if self.device == "cpu":
            return 1
        return self._array_backend.cuda.device_count()

    def reverse(self, arr: TorchTensor, axis: Tuple[int] = None) -> TorchTensor:
        """
        Reverse ``arr`` along ``axis``.

        ``axis`` may be None (all axes), a single int, or a sequence of ints.
        Negative values count from the last axis as in numpy. Axes outside
        the array's range are ignored.
        """
        if axis is None:
            axis = tuple(range(arr.ndim))
        elif isinstance(axis, (int, np.integer)):
            axis = (axis,)
        # Normalize negative axes so that e.g. -1 flips the last dimension
        # instead of being silently skipped.
        normalized = {int(ax) + arr.ndim if ax < 0 else int(ax) for ax in axis}
        dims = [i for i in range(arr.ndim) if i in normalized]
        return self._array_backend.flip(arr, dims)

    def triu_indices(self, n: int, k: int = 0, m: int = None) -> TorchTensor:
        """Return upper-triangle indices of an ``n`` x ``m`` matrix (offset ``k``) via ``torch.triu_indices``."""
        if m is None:
            m = n
        return self._array_backend.triu_indices(n, m, k)
|
tme/data/__init__.py
ADDED
File without changes
|
tme/data/c48n309.npy
ADDED
Binary file
|
tme/data/c48n527.npy
ADDED
Binary file
|
tme/data/c48n9.npy
ADDED
Binary file
|
tme/data/c48u1.npy
ADDED
Binary file
|
tme/data/c48u1153.npy
ADDED
Binary file
|
tme/data/c48u1201.npy
ADDED
Binary file
|
tme/data/c48u1641.npy
ADDED
Binary file
|
tme/data/c48u181.npy
ADDED
Binary file
|
tme/data/c48u2219.npy
ADDED
Binary file
|
tme/data/c48u27.npy
ADDED
Binary file
|
tme/data/c48u2947.npy
ADDED
Binary file
|
tme/data/c48u3733.npy
ADDED
Binary file
|
tme/data/c48u4749.npy
ADDED
Binary file
|
tme/data/c48u5879.npy
ADDED
Binary file
|
tme/data/c48u7111.npy
ADDED
Binary file
|
tme/data/c48u815.npy
ADDED
Binary file
|
tme/data/c48u83.npy
ADDED
Binary file
|
tme/data/c48u8649.npy
ADDED
Binary file
|
tme/data/c600v.npy
ADDED
Binary file
|
tme/data/c600vc.npy
ADDED
Binary file
|
tme/data/metadata.yaml
ADDED
@@ -0,0 +1,80 @@
|
|
1
|
+
c48n309.npy:
|
2
|
+
- 7416
|
3
|
+
- 9.72
|
4
|
+
- 1.91567
|
5
|
+
c48n527.npy:
|
6
|
+
- 12648
|
7
|
+
- 8.17
|
8
|
+
- 1.94334
|
9
|
+
c48n9.npy:
|
10
|
+
- 216
|
11
|
+
- 36.47
|
12
|
+
- 2.89689
|
13
|
+
c48u1.npy:
|
14
|
+
- 24
|
15
|
+
- 62.8
|
16
|
+
- 1.57514
|
17
|
+
c48u1153.npy:
|
18
|
+
- 27672
|
19
|
+
- 6.6
|
20
|
+
- 2.23735
|
21
|
+
c48u1201.npy:
|
22
|
+
- 28824
|
23
|
+
- 6.48
|
24
|
+
- 2.20918
|
25
|
+
c48u1641.npy:
|
26
|
+
- 39384
|
27
|
+
- 5.75
|
28
|
+
- 2.10646
|
29
|
+
c48u181.npy:
|
30
|
+
- 4344
|
31
|
+
- 12.29
|
32
|
+
- 2.27013
|
33
|
+
c48u2219.npy:
|
34
|
+
- 53256
|
35
|
+
- 5.27
|
36
|
+
- 2.20117
|
37
|
+
c48u27.npy:
|
38
|
+
- 648
|
39
|
+
- 20.83
|
40
|
+
- 1.64091
|
41
|
+
c48u2947.npy:
|
42
|
+
- 70728
|
43
|
+
- 4.71
|
44
|
+
- 2.07843
|
45
|
+
c48u3733.npy:
|
46
|
+
- 89592
|
47
|
+
- 4.37
|
48
|
+
- 2.11197
|
49
|
+
c48u4749.npy:
|
50
|
+
- 113976
|
51
|
+
- 4.0
|
52
|
+
- 2.053
|
53
|
+
c48u5879.npy:
|
54
|
+
- 141096
|
55
|
+
- 3.74
|
56
|
+
- 2.07325
|
57
|
+
c48u7111.npy:
|
58
|
+
- 170664
|
59
|
+
- 3.53
|
60
|
+
- 2.11481
|
61
|
+
c48u815.npy:
|
62
|
+
- 19560
|
63
|
+
- 7.4
|
64
|
+
- 2.23719
|
65
|
+
c48u83.npy:
|
66
|
+
- 1992
|
67
|
+
- 16.29
|
68
|
+
- 2.42065
|
69
|
+
c48u8649.npy:
|
70
|
+
- 207576
|
71
|
+
- 3.26
|
72
|
+
- 2.02898
|
73
|
+
c600v.npy:
|
74
|
+
- 60
|
75
|
+
- 44.48
|
76
|
+
- 1.4448
|
77
|
+
c600vc.npy:
|
78
|
+
- 360
|
79
|
+
- 27.78
|
80
|
+
- 2.15246
|
@@ -0,0 +1,42 @@
|
|
1
|
+
#!python3
|
2
|
+
|
3
|
+
from os import listdir
|
4
|
+
from os.path import dirname, join, basename
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
import yaml
|
8
|
+
|
9
|
+
|
10
|
+
def quat_to_numpy(filepath):
    """
    Parse a ``.quat`` orientation file into a numpy array.

    Lines starting with ``#`` or ``format quaternion`` are skipped, and so are
    blank lines. A line with three fields is read as the ``n angle c`` header;
    if there are several, the last one wins. Every other line becomes one
    row of data.

    Parameters
    ----------
    filepath : str
        Path to the ``.quat`` file, absolute or relative to the working
        directory.

    Returns
    -------
    tuple
        ``(data, n, angle, c)``. ``data`` is a float array with one row per
        data line. ``n``, ``angle`` and ``c`` are the header values as int,
        float and float.

    Raises
    ------
    ValueError
        If the file contains no three-field header line.
    """
    data = []
    header = None
    # Open the path exactly as given so relative paths resolve correctly.
    with open(filepath, "r", encoding="utf-8") as ifile:
        for line in ifile:
            if line.startswith(("#", "format quaternion")):
                continue
            line_split = line.strip().split()
            if not line_split:
                # Blank lines would otherwise produce ragged, empty rows.
                continue
            if len(line_split) == 3:
                header = line_split
                continue
            data.append(line_split)

    if header is None:
        raise ValueError(f"No 'n angle c' header line found in {filepath}.")

    n, angle, c = header
    data = np.array(data).astype(float)
    return data, int(n), float(angle), float(c)
|
24
|
+
|
25
|
+
|
26
|
+
if __name__ == "__main__":
    # Convert every ``.quat`` file next to this script into a ``.npy`` file
    # with the same stem, and record the header values in metadata.yaml.
    # dirname() returns "" when the script is run from its own directory,
    # and listdir("") raises, hence the "." fallback.
    current_directory = dirname(__file__) or "."

    # Derive output names from the bare file name only. Replacing "quat" in
    # the joined path would also rewrite directory names containing it.
    quat_names = [
        name for name in listdir(current_directory) if name.endswith(".quat")
    ]

    metadata = {}
    for quat_name in quat_names:
        npy_name = quat_name[: -len(".quat")] + ".npy"
        quaternions, n, angle, c = quat_to_numpy(join(current_directory, quat_name))
        np.save(join(current_directory, npy_name), quaternions)
        metadata[npy_name] = [n, angle, c]

    with open(join(current_directory, "metadata.yaml"), "w", encoding="utf-8") as ofile:
        yaml.dump(metadata, ofile, default_flow_style=False)
|
Binary file
|