pytme 0.3.1.post2__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2.dev0__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.2.dev0.data/scripts/estimate_ram_usage.py +97 -0
- {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/match_template.py +213 -196
- {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/postprocess.py +40 -78
- {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/preprocess.py +4 -5
- {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/preprocessor_gui.py +49 -103
- {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/pytme_runner.py +46 -69
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/METADATA +2 -1
- pytme-0.3.2.dev0.dist-info/RECORD +136 -0
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +213 -196
- scripts/match_template_devel.py +1339 -0
- scripts/postprocess.py +40 -78
- scripts/preprocess.py +4 -5
- scripts/preprocessor_gui.py +49 -103
- scripts/pytme_runner.py +46 -69
- tests/preprocessing/test_compose.py +31 -30
- tests/preprocessing/test_frequency_filters.py +17 -32
- tests/preprocessing/test_preprocessor.py +0 -19
- tests/preprocessing/test_utils.py +13 -1
- tests/test_analyzer.py +2 -10
- tests/test_backends.py +47 -18
- tests/test_density.py +72 -13
- tests/test_extensions.py +1 -0
- tests/test_matching_cli.py +23 -9
- tests/test_matching_exhaustive.py +5 -5
- tests/test_matching_utils.py +3 -3
- tests/test_orientations.py +12 -0
- tests/test_rotations.py +13 -23
- tests/test_structure.py +1 -7
- tme/__version__.py +1 -1
- tme/analyzer/aggregation.py +47 -16
- tme/analyzer/base.py +34 -0
- tme/analyzer/peaks.py +26 -13
- tme/analyzer/proxy.py +14 -0
- tme/backends/_jax_utils.py +91 -68
- tme/backends/cupy_backend.py +6 -19
- tme/backends/jax_backend.py +103 -98
- tme/backends/matching_backend.py +0 -17
- tme/backends/mlx_backend.py +0 -29
- tme/backends/npfftw_backend.py +100 -97
- tme/backends/pytorch_backend.py +65 -78
- tme/cli.py +2 -2
- tme/density.py +44 -57
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/_utils.py +52 -24
- tme/filters/bandpass.py +99 -105
- tme/filters/compose.py +133 -39
- tme/filters/ctf.py +51 -102
- tme/filters/reconstruction.py +67 -122
- tme/filters/wedge.py +296 -325
- tme/filters/whitening.py +39 -75
- tme/mask.py +2 -2
- tme/matching_data.py +87 -15
- tme/matching_exhaustive.py +70 -120
- tme/matching_optimization.py +9 -63
- tme/matching_scores.py +261 -100
- tme/matching_utils.py +150 -91
- tme/memory.py +1 -0
- tme/orientations.py +17 -3
- tme/preprocessor.py +0 -239
- tme/rotations.py +102 -70
- tme/structure.py +601 -631
- tme/types.py +1 -0
- pytme-0.3.1.post2.dist-info/RECORD +0 -133
- {pytme-0.3.1.post2.data → pytme-0.3.2.dev0.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/WHEEL +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/entry_points.txt +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3.1.post2.dist-info → pytme-0.3.2.dev0.dist-info}/top_level.txt +0 -0
tme/backends/npfftw_backend.py
CHANGED
@@ -6,7 +6,6 @@ Copyright (c) 2023 European Molecular Biology Laboratory
|
|
6
6
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
7
7
|
"""
|
8
8
|
|
9
|
-
import os
|
10
9
|
from psutil import virtual_memory
|
11
10
|
from contextlib import contextmanager
|
12
11
|
from typing import Tuple, List, Type
|
@@ -26,12 +25,6 @@ from ..types import NDArray, BackendArray, shm_type
|
|
26
25
|
from .matching_backend import MatchingBackend, _create_metafunction
|
27
26
|
|
28
27
|
|
29
|
-
os.environ["MKL_NUM_THREADS"] = "1"
|
30
|
-
os.environ["OMP_NUM_THREADS"] = "1"
|
31
|
-
os.environ["PYFFTW_NUM_THREADS"] = "1"
|
32
|
-
os.environ["OPENBLAS_NUM_THREADS"] = "1"
|
33
|
-
|
34
|
-
|
35
28
|
def create_ufuncs(obj):
|
36
29
|
ufuncs = [
|
37
30
|
"add",
|
@@ -247,12 +240,13 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
247
240
|
|
248
241
|
shm = shared_memory_handler.SharedMemory(size=arr.nbytes)
|
249
242
|
np_array = self.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
|
250
|
-
np_array[:] = arr[:]
|
243
|
+
np_array[:] = arr[:]
|
251
244
|
return shm, arr.shape, arr.dtype
|
252
245
|
|
253
|
-
def topleft_pad(
|
254
|
-
|
255
|
-
|
246
|
+
def topleft_pad(
|
247
|
+
self, arr: NDArray, shape: Tuple[int], padval: float = 0
|
248
|
+
) -> NDArray:
|
249
|
+
b = self.full(shape, fill_value=padval, dtype=arr.dtype)
|
256
250
|
aind = [slice(None, None)] * arr.ndim
|
257
251
|
bind = [slice(None, None)] * arr.ndim
|
258
252
|
for i in range(arr.ndim):
|
@@ -311,15 +305,6 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
311
305
|
**kwargs,
|
312
306
|
)
|
313
307
|
|
314
|
-
def extract_center(self, arr: NDArray, newshape: Tuple[int]) -> NDArray:
|
315
|
-
new_shape = self.to_backend_array(newshape)
|
316
|
-
current_shape = self.to_backend_array(arr.shape)
|
317
|
-
starts = self.subtract(current_shape, new_shape)
|
318
|
-
starts = self.astype(self.divide(starts, 2), self._int_dtype)
|
319
|
-
stops = self.astype(self.add(starts, new_shape), self._int_dtype)
|
320
|
-
box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
|
321
|
-
return arr[box]
|
322
|
-
|
323
308
|
def compute_convolution_shapes(
|
324
309
|
self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
|
325
310
|
) -> Tuple[List[int], List[int], List[int]]:
|
@@ -329,62 +314,65 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
329
314
|
|
330
315
|
return convolution_shape, fast_shape, fast_ft_shape
|
331
316
|
|
332
|
-
def
|
317
|
+
def _build_transform_matrix(
|
333
318
|
self,
|
334
|
-
rotation_matrix:
|
335
|
-
translation:
|
336
|
-
center:
|
337
|
-
|
319
|
+
rotation_matrix: BackendArray,
|
320
|
+
translation: BackendArray = None,
|
321
|
+
center: BackendArray = None,
|
322
|
+
**kwargs,
|
323
|
+
) -> BackendArray:
|
338
324
|
ndim = rotation_matrix.shape[0]
|
339
|
-
matrix = self.identity(ndim + 1, dtype=self._float_dtype)
|
340
325
|
|
326
|
+
spatial_slice = slice(0, ndim)
|
327
|
+
matrix = self.eye(ndim + 1, dtype=self._float_dtype)
|
328
|
+
|
329
|
+
rotation_matrix = self.astype(rotation_matrix, self._float_dtype)
|
330
|
+
matrix = self.at(matrix, (spatial_slice, spatial_slice), rotation_matrix)
|
331
|
+
|
332
|
+
total_translation = self.zeros(ndim, dtype=self._float_dtype)
|
341
333
|
if translation is not None:
|
342
|
-
|
343
|
-
|
344
|
-
self.dot(matrix, translation_matrix, out=matrix)
|
334
|
+
translation = self.astype(translation, self._float_dtype)
|
335
|
+
total_translation = self.subtract(total_translation, translation)
|
345
336
|
|
346
337
|
if center is not None:
|
347
|
-
|
348
|
-
|
349
|
-
self.
|
338
|
+
total_translation = self.add(total_translation, center)
|
339
|
+
rotated_center = self.matmul(rotation_matrix, center)
|
340
|
+
total_translation = self.subtract(total_translation, rotated_center)
|
350
341
|
|
351
|
-
|
352
|
-
|
353
|
-
rmat[:ndim, :ndim] = self._array_backend.linalg.inv(rotation_matrix)
|
354
|
-
self.dot(matrix, rmat, out=matrix)
|
342
|
+
matrix = self.at(matrix, (spatial_slice, ndim), total_translation)
|
343
|
+
return self.to_backend_array(matrix)
|
355
344
|
|
356
|
-
|
357
|
-
|
358
|
-
self.dot(matrix, center_matrix, out=matrix)
|
345
|
+
def _batch_transform_matrix(self, matrix: NDArray) -> NDArray:
|
346
|
+
ndim = matrix.shape[0] + 1
|
359
347
|
|
360
|
-
|
361
|
-
|
348
|
+
ret = self.zeros((ndim, ndim), dtype=matrix.dtype)
|
349
|
+
ret = self.at(ret, (0, 0), 1)
|
362
350
|
|
363
|
-
|
351
|
+
spatial_slice = slice(1, ndim)
|
352
|
+
ret = self.at(ret, (spatial_slice, spatial_slice), matrix)
|
353
|
+
return ret
|
354
|
+
|
355
|
+
def _compute_transform_center(
|
356
|
+
self, arr: NDArray, use_geometric_center: bool, batched: bool = False
|
357
|
+
) -> NDArray:
|
358
|
+
center = self.divide(self.to_backend_array(arr.shape) - 1, 2)
|
359
|
+
if not use_geometric_center:
|
360
|
+
center = self.center_of_mass(arr, cutoff=0)
|
361
|
+
if batched:
|
362
|
+
return center[1:]
|
363
|
+
return center
|
364
|
+
|
365
|
+
def _transform(
|
364
366
|
self,
|
365
367
|
data: NDArray,
|
366
368
|
matrix: NDArray,
|
367
369
|
output: NDArray,
|
368
370
|
prefilter: bool,
|
369
371
|
order: int,
|
370
|
-
|
371
|
-
|
372
|
-
) -> None:
|
373
|
-
if batched:
|
374
|
-
for i in range(data.shape[0]):
|
375
|
-
self._rigid_transform(
|
376
|
-
data=data[i],
|
377
|
-
matrix=matrix,
|
378
|
-
output=output[i],
|
379
|
-
prefilter=prefilter,
|
380
|
-
order=order,
|
381
|
-
cache=cache,
|
382
|
-
batched=False,
|
383
|
-
)
|
384
|
-
return None
|
385
|
-
|
372
|
+
**kwargs,
|
373
|
+
) -> NDArray:
|
386
374
|
out_slice = tuple(slice(0, stop) for stop in data.shape)
|
387
|
-
self.affine_transform(
|
375
|
+
return self.affine_transform(
|
388
376
|
input=data,
|
389
377
|
matrix=matrix,
|
390
378
|
mode="constant",
|
@@ -393,72 +381,87 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
393
381
|
prefilter=prefilter,
|
394
382
|
)
|
395
383
|
|
396
|
-
def
|
384
|
+
def _rigid_transform(
|
397
385
|
self,
|
398
386
|
arr: NDArray,
|
399
|
-
|
387
|
+
matrix: NDArray,
|
400
388
|
arr_mask: NDArray = None,
|
401
|
-
translation: NDArray = None,
|
402
|
-
use_geometric_center: bool = False,
|
403
389
|
out: NDArray = None,
|
404
390
|
out_mask: NDArray = None,
|
405
391
|
order: int = 3,
|
406
392
|
cache: bool = False,
|
407
|
-
|
393
|
+
**kwargs,
|
408
394
|
) -> Tuple[NDArray, NDArray]:
|
409
395
|
if out is None:
|
410
396
|
out = self.zeros_like(arr)
|
411
397
|
|
412
|
-
|
413
|
-
matrix = rotation_matrix
|
414
|
-
if matrix.shape[-1] == (arr.ndim - int(batched)):
|
415
|
-
center = self.divide(self.to_backend_array(arr.shape) - 1, 2)
|
416
|
-
if not use_geometric_center:
|
417
|
-
center = self.center_of_mass(arr, cutoff=0)
|
418
|
-
|
419
|
-
offset = int(arr.ndim - rotation_matrix.shape[0])
|
420
|
-
center = center[offset:]
|
421
|
-
translation = (
|
422
|
-
self.zeros(center.size) if translation is None else translation
|
423
|
-
)
|
424
|
-
matrix = self._rigid_transform_matrix(
|
425
|
-
rotation_matrix=rotation_matrix,
|
426
|
-
translation=translation,
|
427
|
-
center=center,
|
428
|
-
)
|
429
|
-
|
430
|
-
self._rigid_transform(
|
398
|
+
out = self._transform(
|
431
399
|
data=arr,
|
432
400
|
matrix=matrix,
|
433
401
|
output=out,
|
434
402
|
order=order,
|
435
403
|
prefilter=True,
|
436
404
|
cache=cache,
|
437
|
-
batched=batched,
|
438
405
|
)
|
439
406
|
|
440
|
-
# Applying the prefilter leads to artifacts in the mask.
|
441
407
|
if arr_mask is not None:
|
442
408
|
if out_mask is None:
|
443
|
-
out_mask = self.zeros_like(
|
409
|
+
out_mask = self.zeros_like(arr)
|
444
410
|
|
445
|
-
|
411
|
+
# Applying the prefilter leads to artifacts in the mask.
|
412
|
+
out_mask = self._transform(
|
446
413
|
data=arr_mask,
|
447
414
|
matrix=matrix,
|
448
415
|
output=out_mask,
|
449
416
|
order=order,
|
450
417
|
prefilter=False,
|
451
418
|
cache=cache,
|
452
|
-
batched=batched,
|
453
419
|
)
|
454
420
|
|
455
421
|
return out, out_mask
|
456
422
|
|
423
|
+
def rigid_transform(
|
424
|
+
self,
|
425
|
+
arr: NDArray,
|
426
|
+
rotation_matrix: NDArray,
|
427
|
+
arr_mask: NDArray = None,
|
428
|
+
translation: NDArray = None,
|
429
|
+
use_geometric_center: bool = False,
|
430
|
+
out: NDArray = None,
|
431
|
+
out_mask: NDArray = None,
|
432
|
+
order: int = 3,
|
433
|
+
cache: bool = False,
|
434
|
+
batched: bool = False,
|
435
|
+
) -> Tuple[NDArray, NDArray]:
|
436
|
+
matrix = rotation_matrix
|
437
|
+
|
438
|
+
# Build transformation matrix from rotation matrix
|
439
|
+
if matrix.shape[-1] == (arr.ndim - int(batched)):
|
440
|
+
center = self._compute_transform_center(arr, use_geometric_center, batched)
|
441
|
+
matrix = self._build_transform_matrix(
|
442
|
+
rotation_matrix=rotation_matrix,
|
443
|
+
translation=translation,
|
444
|
+
center=self.astype(center, self._float_dtype),
|
445
|
+
shape=arr.shape[1:] if batched else arr.shape,
|
446
|
+
)
|
447
|
+
|
448
|
+
if batched:
|
449
|
+
matrix = self._batch_transform_matrix(matrix)
|
450
|
+
|
451
|
+
return self._rigid_transform(
|
452
|
+
arr=arr,
|
453
|
+
arr_mask=arr_mask,
|
454
|
+
out=out,
|
455
|
+
out_mask=out_mask,
|
456
|
+
matrix=matrix,
|
457
|
+
cache=cache,
|
458
|
+
order=order,
|
459
|
+
batched=batched,
|
460
|
+
)
|
461
|
+
|
457
462
|
def center_of_mass(self, arr: BackendArray, cutoff: float = None) -> BackendArray:
|
458
463
|
"""
|
459
|
-
Computes the center of mass of
|
460
|
-
elements. For template matching it typically makes sense to only input
|
461
|
-
positive densities.
|
464
|
+
Computes the center of mass of an array larger than cutoff.
|
462
465
|
|
463
466
|
Parameters
|
464
467
|
----------
|
@@ -466,19 +469,19 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
466
469
|
Array to compute the center of mass of.
|
467
470
|
cutoff : float, optional
|
468
471
|
Densities less than or equal to cutoff are nullified for center
|
469
|
-
of mass computation.
|
472
|
+
of mass computation. Defaults to None.
|
470
473
|
|
471
474
|
Returns
|
472
475
|
-------
|
473
476
|
BackendArray
|
474
477
|
Center of mass with shape (arr.ndim).
|
475
478
|
"""
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
denominator = self.sum(arr)
|
479
|
+
arr = self.abs(arr)
|
480
|
+
if cutoff is not None:
|
481
|
+
arr = self.where(arr > cutoff, arr, 0)
|
480
482
|
|
481
483
|
grids = []
|
484
|
+
denominator = self.sum(arr)
|
482
485
|
for i, x in enumerate(arr.shape):
|
483
486
|
baseline_dims = tuple(1 if i != t else x for t in range(len(arr.shape)))
|
484
487
|
grids.append(
|
@@ -587,7 +590,7 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
587
590
|
sq_exp = self.sqrt(sq_exp, out=sq_exp)
|
588
591
|
|
589
592
|
# Assume that low stdev regions also have low scores
|
590
|
-
# See :py:meth:`tme.
|
593
|
+
# See :py:meth:`tme.matching_scores.flcSphericalMask_setup` for correct norm
|
591
594
|
sq_exp[sq_exp < eps] = 1
|
592
595
|
sq_exp = self.multiply(sq_exp, n_obs, out=sq_exp)
|
593
596
|
return self.divide(arr, sq_exp, out=out)
|
tme/backends/pytorch_backend.py
CHANGED
@@ -281,22 +281,45 @@ class PytorchBackend(NumpyFFTWBackend):
|
|
281
281
|
kwargs["dim"] = kwargs.pop("axes", None)
|
282
282
|
return self._array_backend.fft.irfftn(arr, **kwargs)
|
283
283
|
|
284
|
-
def
|
285
|
-
|
284
|
+
def _build_transform_matrix(
|
285
|
+
self,
|
286
|
+
shape: Tuple[int],
|
287
|
+
rotation_matrix: TorchTensor,
|
288
|
+
translation: TorchTensor = None,
|
289
|
+
center: TorchTensor = None,
|
290
|
+
**kwargs,
|
291
|
+
) -> TorchTensor:
|
292
|
+
"""
|
293
|
+
Express the transform matrix in normalized coordinates.
|
294
|
+
"""
|
295
|
+
shape = self.to_backend_array(shape) - 1
|
296
|
+
|
297
|
+
scale_factors = 2.0 / shape
|
298
|
+
if center is not None:
|
299
|
+
center = center - shape / 2
|
300
|
+
center = center * scale_factors
|
301
|
+
|
302
|
+
if translation is not None:
|
303
|
+
translation = translation * scale_factors
|
304
|
+
|
305
|
+
return super()._build_transform_matrix(
|
306
|
+
rotation_matrix=self.flip(rotation_matrix, [0, 1]),
|
307
|
+
translation=translation,
|
308
|
+
center=center,
|
309
|
+
)
|
286
310
|
|
287
|
-
def
|
311
|
+
def _rigid_transform(
|
288
312
|
self,
|
289
313
|
arr: TorchTensor,
|
290
|
-
|
314
|
+
matrix: TorchTensor,
|
291
315
|
arr_mask: TorchTensor = None,
|
292
|
-
translation: TorchTensor = None,
|
293
|
-
use_geometric_center: bool = False,
|
294
316
|
out: TorchTensor = None,
|
295
317
|
out_mask: TorchTensor = None,
|
296
318
|
order: int = 1,
|
297
|
-
|
319
|
+
batched: bool = False,
|
298
320
|
**kwargs,
|
299
|
-
):
|
321
|
+
) -> Tuple[TorchTensor, TorchTensor]:
|
322
|
+
"""Apply rigid transformation using homogeneous transformation matrix."""
|
300
323
|
_mode_mapping = {0: "nearest", 1: "bilinear", 3: "bicubic"}
|
301
324
|
mode = _mode_mapping.get(order, None)
|
302
325
|
if mode is None:
|
@@ -305,90 +328,54 @@ class PytorchBackend(NumpyFFTWBackend):
|
|
305
328
|
f"Got {order} but supported interpolation orders are: {modes}."
|
306
329
|
)
|
307
330
|
|
308
|
-
out = self.zeros_like(arr) if out is None else out
|
309
|
-
|
310
|
-
if translation is None:
|
311
|
-
translation = self._array_backend.zeros(arr.ndim, device=arr.device)
|
312
|
-
|
313
|
-
normalized_translation = self.divide(
|
314
|
-
-2.0 * translation, self.tensor(arr.shape, device=arr.device)
|
315
|
-
)
|
316
|
-
rotation_matrix_pull = self.linalg.inv(self.flip(rotation_matrix, [0, 1]))
|
317
|
-
|
318
|
-
out_slice = tuple(slice(0, x) for x in arr.shape)
|
319
|
-
subset = tuple(slice(None) for _ in range(arr.ndim))
|
320
|
-
offset = max(int(arr.ndim - rotation_matrix.shape[0]) - 1, 0)
|
321
|
-
if offset > 0:
|
322
|
-
normalized_translation = normalized_translation[offset:]
|
323
|
-
subset = tuple(0 if i < offset else slice(None) for i in range(arr.ndim))
|
324
|
-
out_slice = tuple(
|
325
|
-
slice(0, 1) if i < offset else slice(0, x)
|
326
|
-
for i, x in enumerate(arr.shape)
|
327
|
-
)
|
328
|
-
|
329
|
-
out[out_slice] = self._affine_transform(
|
330
|
-
arr=arr[subset],
|
331
|
-
rotation_matrix=rotation_matrix_pull,
|
332
|
-
translation=normalized_translation,
|
333
|
-
mode=mode,
|
334
|
-
)
|
335
|
-
|
336
|
-
if arr_mask is not None:
|
337
|
-
out_mask_slice = tuple(slice(0, x) for x in arr_mask.shape)
|
338
|
-
if out_mask is None:
|
339
|
-
out_mask = self._array_backend.zeros_like(arr_mask)
|
340
|
-
out_mask[out_mask_slice] = self._affine_transform(
|
341
|
-
arr=arr_mask[subset],
|
342
|
-
rotation_matrix=rotation_matrix_pull,
|
343
|
-
translation=normalized_translation,
|
344
|
-
mode=mode,
|
345
|
-
)
|
346
|
-
|
347
|
-
return out, out_mask
|
348
|
-
|
349
|
-
def _affine_transform(
|
350
|
-
self,
|
351
|
-
arr: TorchTensor,
|
352
|
-
rotation_matrix: TorchTensor,
|
353
|
-
translation: TorchTensor,
|
354
|
-
mode,
|
355
|
-
) -> TorchTensor:
|
356
|
-
batched = arr.ndim != rotation_matrix.shape[0]
|
357
|
-
|
358
331
|
batch_size, spatial_dims = 1, arr.shape
|
332
|
+
out_slice = tuple(slice(0, x) for x in arr.shape)
|
359
333
|
if batched:
|
360
|
-
|
334
|
+
matrix = matrix[1:, 1:]
|
361
335
|
batch_size, *spatial_dims = arr.shape
|
362
336
|
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
)
|
367
|
-
|
368
|
-
transformation_matrix[:, :n_dims] = rotation_matrix
|
369
|
-
transformation_matrix[:, n_dims] = translation
|
370
|
-
transformation_matrix = transformation_matrix.unsqueeze(0).expand(
|
371
|
-
batch_size, -1, -1
|
372
|
-
)
|
337
|
+
# Remove homogeneous row and expand for batch processing
|
338
|
+
matrix = matrix[:-1, :].to(arr.dtype)
|
339
|
+
matrix = matrix.unsqueeze(0).expand(batch_size, -1, -1)
|
373
340
|
|
374
|
-
if not batched:
|
375
|
-
arr = arr.unsqueeze(0)
|
376
|
-
|
377
|
-
size = self.Size([batch_size, 1, *spatial_dims])
|
378
341
|
grid = self.F.affine_grid(
|
379
|
-
theta=
|
342
|
+
theta=matrix.to(arr.dtype),
|
343
|
+
size=self.Size([batch_size, 1, *spatial_dims]),
|
344
|
+
align_corners=False,
|
380
345
|
)
|
381
|
-
|
346
|
+
|
347
|
+
arr = arr.unsqueeze(0) if not batched else arr
|
348
|
+
ret = self.F.grid_sample(
|
382
349
|
input=arr.unsqueeze(1),
|
383
350
|
grid=grid,
|
384
351
|
mode=mode,
|
385
352
|
align_corners=False,
|
386
|
-
)
|
353
|
+
).squeeze(1)
|
354
|
+
|
355
|
+
ret_mask = None
|
356
|
+
if arr_mask is not None:
|
357
|
+
arr_mask = arr_mask.unsqueeze(0) if not batched else arr_mask
|
358
|
+
ret_mask = self.F.grid_sample(
|
359
|
+
input=arr_mask.unsqueeze(1),
|
360
|
+
grid=grid,
|
361
|
+
mode=mode,
|
362
|
+
align_corners=False,
|
363
|
+
).squeeze(1)
|
387
364
|
|
388
365
|
if not batched:
|
389
|
-
|
366
|
+
ret = ret.squeeze(0)
|
367
|
+
ret_mask = ret_mask.squeeze(0) if arr_mask is not None else None
|
368
|
+
|
369
|
+
if out is not None:
|
370
|
+
out[out_slice] = ret
|
371
|
+
else:
|
372
|
+
out = ret
|
390
373
|
|
391
|
-
|
374
|
+
if out_mask is not None:
|
375
|
+
out_mask[out_slice] = ret_mask
|
376
|
+
else:
|
377
|
+
out_mask = ret_mask
|
378
|
+
return out, out_mask
|
392
379
|
|
393
380
|
def get_available_memory(self) -> int:
|
394
381
|
if self.device == "cpu":
|
tme/cli.py
CHANGED
@@ -52,7 +52,7 @@ def match_template(
|
|
52
52
|
"""
|
53
53
|
from .matching_data import MatchingData
|
54
54
|
from .analyzer import MaxScoreOverRotations
|
55
|
-
from .matching_exhaustive import
|
55
|
+
from .matching_exhaustive import match_exhaustive, MATCHING_EXHAUSTIVE_REGISTER
|
56
56
|
|
57
57
|
if rotations is None:
|
58
58
|
rotations = np.eye(target.ndim).reshape(1, target.ndim, target.ndim)
|
@@ -73,7 +73,7 @@ def match_template(
|
|
73
73
|
matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[score]
|
74
74
|
|
75
75
|
candidates = list(
|
76
|
-
|
76
|
+
match_exhaustive(
|
77
77
|
matching_data=matching_data,
|
78
78
|
matching_score=matching_score,
|
79
79
|
matching_setup=matching_setup,
|