pytme 0.1.5__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- pytme-0.1.5.data/scripts/estimate_ram_usage.py +81 -0
- pytme-0.1.5.data/scripts/match_template.py +744 -0
- pytme-0.1.5.data/scripts/postprocess.py +279 -0
- pytme-0.1.5.data/scripts/preprocess.py +93 -0
- pytme-0.1.5.data/scripts/preprocessor_gui.py +729 -0
- pytme-0.1.5.dist-info/LICENSE +153 -0
- pytme-0.1.5.dist-info/METADATA +69 -0
- pytme-0.1.5.dist-info/RECORD +63 -0
- pytme-0.1.5.dist-info/WHEEL +5 -0
- pytme-0.1.5.dist-info/entry_points.txt +6 -0
- pytme-0.1.5.dist-info/top_level.txt +2 -0
- scripts/__init__.py +0 -0
- scripts/estimate_ram_usage.py +81 -0
- scripts/match_template.py +744 -0
- scripts/match_template_devel.py +788 -0
- scripts/postprocess.py +279 -0
- scripts/preprocess.py +93 -0
- scripts/preprocessor_gui.py +729 -0
- tme/__init__.py +6 -0
- tme/__version__.py +1 -0
- tme/analyzer.py +1144 -0
- tme/backends/__init__.py +134 -0
- tme/backends/cupy_backend.py +309 -0
- tme/backends/matching_backend.py +1154 -0
- tme/backends/npfftw_backend.py +763 -0
- tme/backends/pytorch_backend.py +526 -0
- tme/data/__init__.py +0 -0
- tme/data/c48n309.npy +0 -0
- tme/data/c48n527.npy +0 -0
- tme/data/c48n9.npy +0 -0
- tme/data/c48u1.npy +0 -0
- tme/data/c48u1153.npy +0 -0
- tme/data/c48u1201.npy +0 -0
- tme/data/c48u1641.npy +0 -0
- tme/data/c48u181.npy +0 -0
- tme/data/c48u2219.npy +0 -0
- tme/data/c48u27.npy +0 -0
- tme/data/c48u2947.npy +0 -0
- tme/data/c48u3733.npy +0 -0
- tme/data/c48u4749.npy +0 -0
- tme/data/c48u5879.npy +0 -0
- tme/data/c48u7111.npy +0 -0
- tme/data/c48u815.npy +0 -0
- tme/data/c48u83.npy +0 -0
- tme/data/c48u8649.npy +0 -0
- tme/data/c600v.npy +0 -0
- tme/data/c600vc.npy +0 -0
- tme/data/metadata.yaml +80 -0
- tme/data/quat_to_numpy.py +42 -0
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +2314 -0
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/helpers.py +881 -0
- tme/matching_data.py +377 -0
- tme/matching_exhaustive.py +1553 -0
- tme/matching_memory.py +382 -0
- tme/matching_optimization.py +1123 -0
- tme/matching_utils.py +1180 -0
- tme/parser.py +429 -0
- tme/preprocessor.py +1291 -0
- tme/scoring.py +866 -0
- tme/structure.py +1428 -0
- tme/types.py +10 -0
tme/backends/pytorch_backend.py
ADDED
@@ -0,0 +1,526 @@
+""" Backend using pytorch and optionally GPU acceleration for
+    template matching.
+
+    Copyright (c) 2023 European Molecular Biology Laboratory
+
+    Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+"""
+
+from typing import Tuple, Dict, Callable
+from contextlib import contextmanager
+from multiprocessing import shared_memory
+from multiprocessing.managers import SharedMemoryManager
+
+from numpy.typing import NDArray
+
+from .npfftw_backend import NumpyFFTWBackend
+from ..types import NDArray, TorchTensor
+
+
+class PytorchBackend(NumpyFFTWBackend):
+    """
+    A pytorch based backend for template matching
+    """
+
+    def __init__(
+        self,
+        device="cuda",
+        default_dtype=None,
+        complex_dtype=None,
+        default_dtype_int=None,
+        **kwargs,
+    ):
+        import torch
+        import torch.nn.functional as F
+
+        default_dtype = torch.float32 if default_dtype is None else default_dtype
+        complex_dtype = torch.complex64 if complex_dtype is None else complex_dtype
+        default_dtype_int = (
+            torch.int32 if default_dtype_int is None else default_dtype_int
+        )
+
+        super().__init__(
+            array_backend=torch,
+            default_dtype=default_dtype,
+            complex_dtype=complex_dtype,
+            default_dtype_int=default_dtype_int,
+        )
+        self.device = device
+        self.F = F
+        self._default_dtype_int = torch.int32
+
+    def to_backend_array(self, arr: NDArray) -> TorchTensor:
+        if isinstance(arr, self._array_backend.Tensor):
+            if arr.device == self.device:
+                return arr
+            return arr.to(self.device)
+        return self.tensor(arr, device=self.device)
+
+    def to_numpy_array(self, arr: TorchTensor) -> NDArray:
+        return arr.cpu().numpy()
+
+    def to_cpu_array(self, arr: TorchTensor) -> NDArray:
+        return arr.cpu()
+
+    def free_cache(self):
+        self._array_backend.cuda.empty_cache()
+
+    def mod(self, x1, x2, *args, **kwargs):
+        x1 = self.to_backend_array(x1)
+        x2 = self.to_backend_array(x2)
+        return self._array_backend.remainder(x1, x2, *args, **kwargs)
+
+    def sum(self, *args, **kwargs) -> NDArray:
+        return self._array_backend.sum(*args, **kwargs)
+
+    def mean(self, *args, **kwargs) -> NDArray:
+        return self._array_backend.mean(*args, **kwargs)
+
+    def std(self, *args, **kwargs) -> NDArray:
+        return self._array_backend.std(*args, **kwargs)
+
+    def max(self, *args, **kwargs) -> NDArray:
+        ret = self._array_backend.amax(*args, **kwargs)
+        if type(ret) == self._array_backend.Tensor:
+            return ret
+        return ret[0]
+
+    def min(self, *args, **kwargs) -> NDArray:
+        ret = self._array_backend.amin(*args, **kwargs)
+        if type(ret) == self._array_backend.Tensor:
+            return ret
+        return ret[0]
+
+    def maximum(self, x1, x2, *args, **kwargs) -> NDArray:
+        x1 = self.to_backend_array(x1)
+        x2 = self.to_backend_array(x2)
+        return self._array_backend.maximum(input=x1, other=x2, *args, **kwargs)
+
+    def minimum(self, x1, x2, *args, **kwargs) -> NDArray:
+        x1 = self.to_backend_array(x1)
+        x2 = self.to_backend_array(x2)
+        return self._array_backend.minimum(input=x1, other=x2, *args, **kwargs)
+
+    def tobytes(self, arr):
+        return arr.cpu().numpy().tobytes()
+
+    def size(self, arr):
+        return arr.numel()
+
+    def zeros(self, shape, dtype=None):
+        return self._array_backend.zeros(shape, dtype=dtype, device=self.device)
+
+    def preallocate_array(self, shape: Tuple[int], dtype: type) -> NDArray:
+        """
+        Returns a byte-aligned array of zeros with specified shape and dtype.
+
+        Parameters
+        ----------
+        shape : Tuple[int]
+            Desired shape for the array.
+        dtype : type
+            Desired data type for the array.
+
+        Returns
+        -------
+        NDArray
+            Byte-aligned array of zeros with specified shape and dtype.
+        """
+        arr = self._array_backend.zeros(shape, dtype=dtype, device=self.device)
+        return arr
+
+    def full(self, shape, fill_value, dtype=None):
+        return self._array_backend.full(
+            shape, dtype=dtype, fill_value=fill_value, device=self.device
+        )
+
+    def datatype_bytes(self, dtype: type) -> int:
+        temp = self.zeros(1, dtype=dtype)
+        return temp.element_size()
+
+    def fill(self, arr: TorchTensor, value: float):
+        arr.fill_(value)
+
+    def astype(self, arr, dtype):
+        return arr.to(dtype)
+
+    def flip(self, a, axis, **kwargs):
+        return self._array_backend.flip(input=a, dims=axis, **kwargs)
+
+    def arange(self, *args, **kwargs):
+        return self._array_backend.arange(*args, **kwargs, device=self.device)
+
+    def stack(self, *args, **kwargs):
+        return self._array_backend.stack(*args, **kwargs)
+
+    def topk_indices(self, arr, k):
+        temp = arr.reshape(-1)
+        values, indices = self._array_backend.topk(temp, k)
+        indices = self.unravel_index(indices=indices, shape=arr.shape)
+        return indices
+
+    def indices(self, shape: Tuple[int]) -> TorchTensor:
+        grids = [self.arange(x) for x in shape]
+        mesh = self._array_backend.meshgrid(*grids, indexing="ij")
+        return self._array_backend.stack(mesh)
+
+    def unravel_index(self, indices, shape):
+        indices = self.to_backend_array(indices)
+        shape = self.to_backend_array(shape)
+        strides = self._array_backend.cumprod(shape.flip(0), dim=0).flip(0)
+        strides = self._array_backend.cat(
+            (strides[1:], self.to_backend_array([1])),
+        )
+        unraveled_coords = (indices.view(-1, 1) // strides.view(1, -1)) % shape.view(
+            1, -1
+        )
+        if unraveled_coords.size(0) == 1:
+            return tuple(unraveled_coords[0, :].tolist())
+
+        else:
+            return tuple(unraveled_coords.T)
+
+    def roll(self, a, shift, axis, **kwargs):
+        shift = tuple(shift)
+        return self._array_backend.roll(input=a, shifts=shift, dims=axis, **kwargs)
+
+    def unique(
+        self,
+        ar,
+        return_index: bool = False,
+        return_inverse: bool = False,
+        return_counts: bool = False,
+        axis: int = None,
+        sorted: bool = True,
+    ):
+        # https://github.com/pytorch/pytorch/issues/36748#issuecomment-1478913448
+        unique, inverse, counts = self._array_backend.unique(
+            ar, return_inverse=True, return_counts=True, dim=axis, sorted=sorted
+        )
+        inverse = inverse.reshape(-1)
+
+        if return_index:
+            inv_sorted = inverse.argsort(stable=True)
+            tot_counts = self._array_backend.cat(
+                (counts.new_zeros(1), counts.cumsum(dim=0))
+            )[:-1]
+            index = inv_sorted[tot_counts]
+
+        ret = unique
+        if return_index or return_inverse or return_counts:
+            ret = [unique]
+
+        if return_index:
+            ret.append(index)
+        if return_inverse:
+            ret.append(inverse)
+        if return_counts:
+            ret.append(counts)
+
+        return ret
+
+    def max_filter_coordinates(self, score_space, min_distance: Tuple[int]):
+        if score_space.ndim == 3:
+            func = self._array_backend.nn.MaxPool3d
+        elif score_space.ndim == 2:
+            func = self._array_backend.nn.MaxPool2d
+        else:
+            raise NotImplementedError("Operation only implemented for 2 and 3D inputs.")
+
+        pool = func(kernel_size=min_distance, return_indices=True)
+        _, indices = pool(score_space.reshape(1, 1, *score_space.shape))
+        coordinates = self.unravel_index(indices.reshape(-1), score_space.shape)
+        coordinates = self.transpose(self.stack(coordinates))
+        return coordinates
+
+    def repeat(self, *args, **kwargs):
+        return self._array_backend.repeat_interleave(*args, **kwargs)
+
+    def sharedarr_to_arr(
+        self, shape: Tuple[int], dtype: str, shm: TorchTensor
+    ) -> TorchTensor:
+        if self.device == "cuda":
+            return shm
+
+        required_size = int(self._array_backend.prod(self.to_backend_array(shape)))
+
+        ret = self._array_backend.frombuffer(shm.buf, dtype=dtype)[
+            :required_size
+        ].reshape(shape)
+        return ret
+
+    def arr_to_sharedarr(
+        self, arr: TorchTensor, shared_memory_handler: type = None
+    ) -> TorchTensor:
+        if self.device == "cuda":
+            return arr
+
+        nbytes = arr.numel() * arr.element_size()
+
+        if type(shared_memory_handler) == SharedMemoryManager:
+            shm = shared_memory_handler.SharedMemory(size=nbytes)
+        else:
+            shm = shared_memory.SharedMemory(create=True, size=nbytes)
+
+        shm.buf[:nbytes] = arr.numpy().tobytes()
+
+        return shm
+
+    def transpose(self, arr):
+        return arr.permute(*self._array_backend.arange(arr.ndim - 1, -1, -1))
+
+    def power(self, *args, **kwargs):
+        return self._array_backend.pow(*args, **kwargs)
+
+    def rotate_array(
+        self,
+        arr: TorchTensor,
+        rotation_matrix: TorchTensor,
+        arr_mask: TorchTensor = None,
+        translation: TorchTensor = None,
+        out: TorchTensor = None,
+        out_mask: TorchTensor = None,
+        order: int = 1,
+        **kwargs,
+    ):
+        """
+        Rotates the given tensor `arr` based on the provided `rotation_matrix`.
+
+        This function optionally allows for rotating an accompanying mask
+        tensor (`arr_mask`) alongside the main tensor. The rotation is defined
+        by the `rotation_matrix` and the optional `translation`.
+
+        Parameters
+        ----------
+        arr : TorchTensor
+            The input tensor to be rotated.
+        rotation_matrix : TorchTensor
+            The rotation matrix to apply. Must be square and of shape [d x d].
+        arr_mask : TorchTensor, optional
+            The mask of `arr` to be equivalently rotated.
+        translation : TorchTensor, optional
+            The translation to apply after rotation. Shape should
+            match tensor dimensions [d].
+        out : TorchTensor, optional
+            The output tensor to hold the rotated `arr`. If not provided, a new
+            tensor will be created.
+        out_mask : TorchTensor, optional
+            The output tensor to hold the rotated `arr_mask`. If not provided and
+            `arr_mask` is given, a new tensor will be created.
+        order : int, optional
+            Spline interpolation order. Supports orders:
+
+            +-------+---------------------------------------------------------------+
+            |   0   | Use 'nearest' neighbor interpolation.                         |
+            +-------+---------------------------------------------------------------+
+            |   1   | Use 'bilinear' interpolation for smoother results.            |
+            +-------+---------------------------------------------------------------+
+            |   3   | Use 'bicubic' interpolation for higher order smoothness.      |
+            +-------+---------------------------------------------------------------+
+
+        Returns
+        -------
+        out, out_mask : TorchTensor or Tuple[TorchTensor, TorchTensor] or None
+            Returns the rotated tensor(s). If `out` and `out_mask` are provided, the
+            function will return `None`.
+            If only `arr` is provided without `out`, it returns rotated `arr`.
+            If both `arr` and `arr_mask` are given without `out` and `out_mask`, it
+            returns a tuple of rotated tensors.
+
+        Notes
+        -----
+        Only a region of the size of `arr` and `arr_mask` is considered for
+        interpolation in `out` and `out_mask` respectively.
+
+        Currently bicubic interpolation is not supported for 3D inputs.
+        """
+        device = arr.device
+        mode_mapping = {0: "nearest", 1: "bilinear", 3: "bicubic"}
+        mode = mode_mapping.get(order, None)
+        if mode is None:
+            modes = ", ".join([str(x) for x in mode_mapping.keys()])
+            raise ValueError(
+                f"Got {order} but supported interpolation orders are: {modes}."
+            )
+        rotate_mask = arr_mask is not None
+        return_type = (out is None) + 2 * rotate_mask * (out_mask is None)
+
+        out = self.zeros_like(arr) if out is None else out
+        if translation is None:
+            translation = self._array_backend.zeros(arr.ndim, device=device)
+
+        normalized_translation = self.divide(
+            -2.0 * translation, self.tensor(arr.shape, device=arr.device)
+        )
+
+        rotation_matrix_pull = self.linalg.inv(self.flip(rotation_matrix, [0, 1]))
+
+        out_slice = tuple(slice(0, x) for x in arr.shape)
+        out[out_slice] = self._affine_transform(
+            arr=arr,
+            rotation_matrix=rotation_matrix_pull,
+            translation=normalized_translation,
+            mode=mode,
+        )
+
+        if rotate_mask:
+            out_mask_slice = tuple(slice(0, x) for x in arr_mask.shape)
+            if out_mask is None:
+                out_mask = self._array_backend.zeros_like(arr_mask)
+            out_mask[out_mask_slice] = self._affine_transform(
+                arr=arr_mask,
+                rotation_matrix=rotation_matrix_pull,
+                translation=normalized_translation,
+                mode=mode,
+            )
+
+        match return_type:
+            case 0:
+                return None
+            case 1:
+                return out
+            case 2:
+                return out_mask
+            case 3:
+                return out, out_mask
+
+    def build_fft(
+        self, fast_shape: Tuple[int], fast_ft_shape: Tuple[int], **kwargs
+    ) -> Tuple[Callable, Callable]:
+        """
+        Build fft builder functions.
+
+        Parameters
+        ----------
+        fast_shape : tuple
+            Tuple of integers corresponding to fast convolution shape
+            (see `compute_convolution_shapes`).
+        fast_ft_shape : tuple
+            Tuple of integers corresponding to the shape of the fourier
+            transform array (see `compute_convolution_shapes`).
+        **kwargs : dict, optional
+            Additional parameters that are not used for now.
+
+        Returns
+        -------
+        tuple
+            Tupple containing callable rfft and irfft object.
+        """
+
+        def rfftn(
+            arr: TorchTensor, out: TorchTensor, shape: Tuple[int] = fast_shape
+        ) -> None:
+            return self._array_backend.fft.rfftn(arr, s=shape, out=out)
+
+        def irfftn(
+            arr: TorchTensor, out: TorchTensor, shape: Tuple[int] = fast_shape
+        ) -> None:
+            return self._array_backend.fft.irfftn(arr, s=shape, out=out)
+
+        return rfftn, irfftn
+
+    def _affine_transform(
+        self,
+        arr: TorchTensor,
+        rotation_matrix: TorchTensor,
+        translation: TorchTensor,
+        mode,
+    ) -> TorchTensor:
+        """
+        Performs an affine transformation on the given tensor.
+
+        The affine transformation is defined by the provided `rotation_matrix`
+        and the `translation` vector. The transformation is applied to the
+        input tensor `arr`.
+
+        Parameters
+        ----------
+        arr : TorchTensor
+            The input tensor on which the transformation will be applied.
+        rotation_matrix : TorchTensor
+            The matrix defining the rotation component of the transformation.
+        translation : TorchTensor
+            The vector defining the translation to be applied post rotation.
+        mode : str
+            Interpolation mode to use. Options are: 'nearest', 'bilinear', 'bicubic'.
+
+        Returns
+        -------
+        TorchTensor
+            The tensor after applying the affine transformation.
+        """
+
+        transformation_matrix = self._array_backend.zeros(
+            arr.ndim, arr.ndim + 1, device=arr.device, dtype=arr.dtype
+        )
+        transformation_matrix[:, : arr.ndim] = rotation_matrix
+        transformation_matrix[:, arr.ndim] = translation
+
+        size = self.Size([1, 1, *arr.shape])
+        grid = self.F.affine_grid(
+            theta=transformation_matrix.unsqueeze(0), size=size, align_corners=False
+        )
+        output = self.F.grid_sample(
+            input=arr.unsqueeze(0).unsqueeze(0),
+            grid=grid,
+            mode=mode,
+            align_corners=False,
+        )
+
+        return output.squeeze()
+
+    def get_available_memory(self) -> int:
+        if self.device == "cpu":
+            return super().get_available_memory()
+        return self._array_backend.cuda.mem_get_info()[0]
+
+    @contextmanager
+    def set_device(self, device_index: int):
+        """
+        Set the active GPU device as a context.
+
+        This method sets the active GPU device for operations within the context.
+
+        Parameters
+        ----------
+        device_index : int
+            Index of the GPU device to be set as active.
+
+        Yields
+        ------
+        None
+            Operates as a context manager, yielding None and providing
+            the set GPU context for enclosed operations.
+        """
+        if self.device == "cuda":
+            with self._array_backend.cuda.device(device_index):
+                yield
+
+        else:
+            yield None
+
+    def device_count(self) -> int:
+        """
+        Return the number of available GPU devices.
+
+        Returns
+        -------
+        int
+            Number of available GPU devices.
+        """
+        return self._array_backend.cuda.device_count()
+
+    def reverse(self, arr: TorchTensor) -> TorchTensor:
+        """
+        Reverse the order of elements in a tensor along all its axes.
+
+        Parameters
+        ----------
+        tensor : TorchTensor
+            Input tensor.
+
+        Returns
+        -------
+        TorchTensor
+            Reversed tensor.
+        """
+        return self._array_backend.flip(arr, [i for i in range(arr.ndim)])
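
A minimal usage sketch of the backend added above, assuming the wheel is installed alongside torch and numpy; the random volume and identity rotation are illustrative, not taken from the package:

import numpy as np

from tme.backends.pytorch_backend import PytorchBackend

# Instantiate on CPU; pass device="cuda" for GPU acceleration.
backend = PytorchBackend(device="cpu")
volume = backend.to_backend_array(np.random.rand(32, 32, 32))
rotation = backend.to_backend_array(np.eye(3))  # identity rotation, illustrative

# order=1 selects 'bilinear' interpolation per mode_mapping in rotate_array;
# with out/out_mask omitted, the rotated tensor is returned.
rotated = backend.rotate_array(arr=volume, rotation_matrix=rotation, order=1)
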
tme/data/__init__.py
ADDED
File without changes

tme/data/c48n309.npy
ADDED
Binary file

tme/data/c48n527.npy
ADDED
Binary file

tme/data/c48n9.npy
ADDED
Binary file

tme/data/c48u1.npy
ADDED
Binary file

tme/data/c48u1153.npy
ADDED
Binary file

tme/data/c48u1201.npy
ADDED
Binary file

tme/data/c48u1641.npy
ADDED
Binary file

tme/data/c48u181.npy
ADDED
Binary file

tme/data/c48u2219.npy
ADDED
Binary file

tme/data/c48u27.npy
ADDED
Binary file

tme/data/c48u2947.npy
ADDED
Binary file

tme/data/c48u3733.npy
ADDED
Binary file

tme/data/c48u4749.npy
ADDED
Binary file

tme/data/c48u5879.npy
ADDED
Binary file

tme/data/c48u7111.npy
ADDED
Binary file

tme/data/c48u815.npy
ADDED
Binary file

tme/data/c48u83.npy
ADDED
Binary file

tme/data/c48u8649.npy
ADDED
Binary file

tme/data/c600v.npy
ADDED
Binary file

tme/data/c600vc.npy
ADDED
Binary file

tme/data/metadata.yaml
ADDED
@@ -0,0 +1,80 @@
+c48n309.npy:
+- 7416
+- 9.72
+- 1.91567
+c48n527.npy:
+- 12648
+- 8.17
+- 1.94334
+c48n9.npy:
+- 216
+- 36.47
+- 2.89689
+c48u1.npy:
+- 24
+- 62.8
+- 1.57514
+c48u1153.npy:
+- 27672
+- 6.6
+- 2.23735
+c48u1201.npy:
+- 28824
+- 6.48
+- 2.20918
+c48u1641.npy:
+- 39384
+- 5.75
+- 2.10646
+c48u181.npy:
+- 4344
+- 12.29
+- 2.27013
+c48u2219.npy:
+- 53256
+- 5.27
+- 2.20117
+c48u27.npy:
+- 648
+- 20.83
+- 1.64091
+c48u2947.npy:
+- 70728
+- 4.71
+- 2.07843
+c48u3733.npy:
+- 89592
+- 4.37
+- 2.11197
+c48u4749.npy:
+- 113976
+- 4.0
+- 2.053
+c48u5879.npy:
+- 141096
+- 3.74
+- 2.07325
+c48u7111.npy:
+- 170664
+- 3.53
+- 2.11481
+c48u815.npy:
+- 19560
+- 7.4
+- 2.23719
+c48u83.npy:
+- 1992
+- 16.29
+- 2.42065
+c48u8649.npy:
+- 207576
+- 3.26
+- 2.02898
+c600v.npy:
+- 60
+- 44.48
+- 1.4448
+c600vc.npy:
+- 360
+- 27.78
+- 2.15246
tme/data/quat_to_numpy.py
ADDED
@@ -0,0 +1,42 @@
+#!python3
+
+from os import listdir
+from os.path import dirname, join, basename
+
+import numpy as np
+import yaml
+
+
+def quat_to_numpy(filepath):
+    data = []
+    with open(join(filepath, filepath), "r", encoding="utf-8") as ifile:
+        for line in ifile:
+            if line.startswith("#") or line.startswith("format quaternion"):
+                continue
+            line_split = line.strip().split()
+            if len(line_split) == 3:
+                n, angle, c = line_split
+                continue
+            data.append(line_split)
+
+    data = np.array(data).astype(float)
+    return data, int(n), float(angle), float(c)
+
+
+if __name__ == "__main__":
+    current_directory = dirname(__file__)
+    files = listdir(current_directory)
+
+    files = [file for file in files if file.endswith(".quat")]
+    files = [join(current_directory, file) for file in files]
+    numpy_names = [
+        join(current_directory, file.replace("quat", "npy")) for file in files
+    ]
+
+    metadata = {}
+    for file, np_out in zip(files, numpy_names):
+        quaternions, n, angle, c = quat_to_numpy(file)
+        np.save(np_out, quaternions)
+        metadata[basename(np_out)] = [n, angle, c]
+    with open(join(current_directory, "metadata.yaml"), "w", encoding="utf-8") as ofile:
+        yaml.dump(metadata, ofile, default_flow_style=False)
tme/data/scattering_factors.pickle
ADDED
Binary file