pytme 0.3.1.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. pytme-0.3.2.data/scripts/estimate_ram_usage.py +97 -0
  2. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/match_template.py +213 -196
  3. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/postprocess.py +40 -78
  4. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/preprocess.py +4 -5
  5. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/preprocessor_gui.py +50 -103
  6. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/pytme_runner.py +46 -69
  7. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/METADATA +3 -2
  8. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/RECORD +68 -65
  9. scripts/estimate_ram_usage.py +97 -0
  10. scripts/match_template.py +213 -196
  11. scripts/match_template_devel.py +1339 -0
  12. scripts/postprocess.py +40 -78
  13. scripts/preprocess.py +4 -5
  14. scripts/preprocessor_gui.py +50 -103
  15. scripts/pytme_runner.py +46 -69
  16. scripts/refine_matches.py +5 -7
  17. tests/preprocessing/test_compose.py +31 -30
  18. tests/preprocessing/test_frequency_filters.py +17 -32
  19. tests/preprocessing/test_preprocessor.py +0 -19
  20. tests/preprocessing/test_utils.py +13 -1
  21. tests/test_analyzer.py +2 -10
  22. tests/test_backends.py +47 -18
  23. tests/test_density.py +72 -13
  24. tests/test_extensions.py +1 -0
  25. tests/test_matching_cli.py +23 -9
  26. tests/test_matching_exhaustive.py +5 -5
  27. tests/test_matching_utils.py +3 -3
  28. tests/test_rotations.py +13 -23
  29. tests/test_structure.py +1 -7
  30. tme/__version__.py +1 -1
  31. tme/analyzer/aggregation.py +47 -16
  32. tme/analyzer/base.py +34 -0
  33. tme/analyzer/peaks.py +26 -13
  34. tme/analyzer/proxy.py +14 -0
  35. tme/backends/_jax_utils.py +124 -71
  36. tme/backends/cupy_backend.py +6 -19
  37. tme/backends/jax_backend.py +110 -105
  38. tme/backends/matching_backend.py +0 -17
  39. tme/backends/mlx_backend.py +0 -29
  40. tme/backends/npfftw_backend.py +100 -97
  41. tme/backends/pytorch_backend.py +65 -78
  42. tme/cli.py +2 -2
  43. tme/density.py +102 -58
  44. tme/extensions.cpython-311-darwin.so +0 -0
  45. tme/filters/_utils.py +52 -24
  46. tme/filters/bandpass.py +99 -105
  47. tme/filters/compose.py +133 -39
  48. tme/filters/ctf.py +51 -102
  49. tme/filters/reconstruction.py +67 -122
  50. tme/filters/wedge.py +296 -325
  51. tme/filters/whitening.py +39 -75
  52. tme/mask.py +2 -2
  53. tme/matching_data.py +87 -15
  54. tme/matching_exhaustive.py +70 -120
  55. tme/matching_optimization.py +9 -63
  56. tme/matching_scores.py +261 -100
  57. tme/matching_utils.py +150 -91
  58. tme/memory.py +1 -0
  59. tme/orientations.py +28 -8
  60. tme/preprocessor.py +0 -239
  61. tme/rotations.py +102 -70
  62. tme/structure.py +601 -631
  63. tme/types.py +1 -0
  64. {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/estimate_memory_usage.py +0 -0
  65. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/WHEEL +0 -0
  66. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/entry_points.txt +0 -0
  67. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/licenses/LICENSE +0 -0
  68. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/top_level.txt +0 -0
tme/backends/npfftw_backend.py CHANGED
@@ -6,7 +6,6 @@ Copyright (c) 2023 European Molecular Biology Laboratory
6
6
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
7
  """
8
8
 
9
- import os
10
9
  from psutil import virtual_memory
11
10
  from contextlib import contextmanager
12
11
  from typing import Tuple, List, Type
@@ -26,12 +25,6 @@ from ..types import NDArray, BackendArray, shm_type
26
25
  from .matching_backend import MatchingBackend, _create_metafunction
27
26
 
28
27
 
29
- os.environ["MKL_NUM_THREADS"] = "1"
30
- os.environ["OMP_NUM_THREADS"] = "1"
31
- os.environ["PYFFTW_NUM_THREADS"] = "1"
32
- os.environ["OPENBLAS_NUM_THREADS"] = "1"
33
-
34
-
35
28
  def create_ufuncs(obj):
36
29
  ufuncs = [
37
30
  "add",
@@ -247,12 +240,13 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
247
240
 
248
241
  shm = shared_memory_handler.SharedMemory(size=arr.nbytes)
249
242
  np_array = self.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
250
- np_array[:] = arr[:].copy()
243
+ np_array[:] = arr[:]
251
244
  return shm, arr.shape, arr.dtype
252
245
 
253
- def topleft_pad(self, arr: NDArray, shape: Tuple[int], padval: int = 0) -> NDArray:
254
- b = self.zeros(shape, arr.dtype)
255
- self.add(b, padval, out=b)
246
+ def topleft_pad(
247
+ self, arr: NDArray, shape: Tuple[int], padval: float = 0
248
+ ) -> NDArray:
249
+ b = self.full(shape, fill_value=padval, dtype=arr.dtype)
256
250
  aind = [slice(None, None)] * arr.ndim
257
251
  bind = [slice(None, None)] * arr.ndim
258
252
  for i in range(arr.ndim):
@@ -311,15 +305,6 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
311
305
  **kwargs,
312
306
  )
313
307
 
314
- def extract_center(self, arr: NDArray, newshape: Tuple[int]) -> NDArray:
315
- new_shape = self.to_backend_array(newshape)
316
- current_shape = self.to_backend_array(arr.shape)
317
- starts = self.subtract(current_shape, new_shape)
318
- starts = self.astype(self.divide(starts, 2), self._int_dtype)
319
- stops = self.astype(self.add(starts, new_shape), self._int_dtype)
320
- box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
321
- return arr[box]
322
-
323
308
  def compute_convolution_shapes(
324
309
  self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
325
310
  ) -> Tuple[List[int], List[int], List[int]]:
@@ -329,62 +314,65 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
329
314
 
330
315
  return convolution_shape, fast_shape, fast_ft_shape
331
316
 
332
- def _rigid_transform_matrix(
317
+ def _build_transform_matrix(
333
318
  self,
334
- rotation_matrix: NDArray,
335
- translation: NDArray = None,
336
- center: NDArray = None,
337
- ) -> NDArray:
319
+ rotation_matrix: BackendArray,
320
+ translation: BackendArray = None,
321
+ center: BackendArray = None,
322
+ **kwargs,
323
+ ) -> BackendArray:
338
324
  ndim = rotation_matrix.shape[0]
339
- matrix = self.identity(ndim + 1, dtype=self._float_dtype)
340
325
 
326
+ spatial_slice = slice(0, ndim)
327
+ matrix = self.eye(ndim + 1, dtype=self._float_dtype)
328
+
329
+ rotation_matrix = self.astype(rotation_matrix, self._float_dtype)
330
+ matrix = self.at(matrix, (spatial_slice, spatial_slice), rotation_matrix)
331
+
332
+ total_translation = self.zeros(ndim, dtype=self._float_dtype)
341
333
  if translation is not None:
342
- translation_matrix = self.identity(ndim + 1, dtype=self._float_dtype)
343
- translation_matrix[:ndim, ndim] = -translation
344
- self.dot(matrix, translation_matrix, out=matrix)
334
+ translation = self.astype(translation, self._float_dtype)
335
+ total_translation = self.subtract(total_translation, translation)
345
336
 
346
337
  if center is not None:
347
- center_matrix = self.identity(ndim + 1, dtype=self._float_dtype)
348
- center_matrix[:ndim, ndim] = center
349
- self.dot(matrix, center_matrix, out=matrix)
338
+ total_translation = self.add(total_translation, center)
339
+ rotated_center = self.matmul(rotation_matrix, center)
340
+ total_translation = self.subtract(total_translation, rotated_center)
350
341
 
351
- if rotation_matrix is not None:
352
- rmat = self.identity(ndim + 1, dtype=self._float_dtype)
353
- rmat[:ndim, :ndim] = self._array_backend.linalg.inv(rotation_matrix)
354
- self.dot(matrix, rmat, out=matrix)
342
+ matrix = self.at(matrix, (spatial_slice, ndim), total_translation)
343
+ return self.to_backend_array(matrix)
355
344
 
356
- if center is not None:
357
- center_matrix[:ndim, ndim] = -center_matrix[:ndim, ndim]
358
- self.dot(matrix, center_matrix, out=matrix)
345
+ def _batch_transform_matrix(self, matrix: NDArray) -> NDArray:
346
+ ndim = matrix.shape[0] + 1
359
347
 
360
- matrix /= matrix[ndim, ndim]
361
- return matrix
348
+ ret = self.zeros((ndim, ndim), dtype=matrix.dtype)
349
+ ret = self.at(ret, (0, 0), 1)
362
350
 
363
- def _rigid_transform(
351
+ spatial_slice = slice(1, ndim)
352
+ ret = self.at(ret, (spatial_slice, spatial_slice), matrix)
353
+ return ret
354
+
355
+ def _compute_transform_center(
356
+ self, arr: NDArray, use_geometric_center: bool, batched: bool = False
357
+ ) -> NDArray:
358
+ center = self.divide(self.to_backend_array(arr.shape) - 1, 2)
359
+ if not use_geometric_center:
360
+ center = self.center_of_mass(arr, cutoff=0)
361
+ if batched:
362
+ return center[1:]
363
+ return center
364
+
365
+ def _transform(
364
366
  self,
365
367
  data: NDArray,
366
368
  matrix: NDArray,
367
369
  output: NDArray,
368
370
  prefilter: bool,
369
371
  order: int,
370
- cache: bool = False,
371
- batched=False,
372
- ) -> None:
373
- if batched:
374
- for i in range(data.shape[0]):
375
- self._rigid_transform(
376
- data=data[i],
377
- matrix=matrix,
378
- output=output[i],
379
- prefilter=prefilter,
380
- order=order,
381
- cache=cache,
382
- batched=False,
383
- )
384
- return None
385
-
372
+ **kwargs,
373
+ ) -> NDArray:
386
374
  out_slice = tuple(slice(0, stop) for stop in data.shape)
387
- self.affine_transform(
375
+ return self.affine_transform(
388
376
  input=data,
389
377
  matrix=matrix,
390
378
  mode="constant",
@@ -393,72 +381,87 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
393
381
  prefilter=prefilter,
394
382
  )
395
383
 
396
- def rigid_transform(
384
+ def _rigid_transform(
397
385
  self,
398
386
  arr: NDArray,
399
- rotation_matrix: NDArray,
387
+ matrix: NDArray,
400
388
  arr_mask: NDArray = None,
401
- translation: NDArray = None,
402
- use_geometric_center: bool = False,
403
389
  out: NDArray = None,
404
390
  out_mask: NDArray = None,
405
391
  order: int = 3,
406
392
  cache: bool = False,
407
- batched: bool = False,
393
+ **kwargs,
408
394
  ) -> Tuple[NDArray, NDArray]:
409
395
  if out is None:
410
396
  out = self.zeros_like(arr)
411
397
 
412
- # Check whether rotation_matrix is already a rigid transform matrix
413
- matrix = rotation_matrix
414
- if matrix.shape[-1] == (arr.ndim - int(batched)):
415
- center = self.divide(self.to_backend_array(arr.shape) - 1, 2)
416
- if not use_geometric_center:
417
- center = self.center_of_mass(arr, cutoff=0)
418
-
419
- offset = int(arr.ndim - rotation_matrix.shape[0])
420
- center = center[offset:]
421
- translation = (
422
- self.zeros(center.size) if translation is None else translation
423
- )
424
- matrix = self._rigid_transform_matrix(
425
- rotation_matrix=rotation_matrix,
426
- translation=translation,
427
- center=center,
428
- )
429
-
430
- self._rigid_transform(
398
+ out = self._transform(
431
399
  data=arr,
432
400
  matrix=matrix,
433
401
  output=out,
434
402
  order=order,
435
403
  prefilter=True,
436
404
  cache=cache,
437
- batched=batched,
438
405
  )
439
406
 
440
- # Applying the prefilter leads to artifacts in the mask.
441
407
  if arr_mask is not None:
442
408
  if out_mask is None:
443
- out_mask = self.zeros_like(arr_mask)
409
+ out_mask = self.zeros_like(arr)
444
410
 
445
- self._rigid_transform(
411
+ # Applying the prefilter leads to artifacts in the mask.
412
+ out_mask = self._transform(
446
413
  data=arr_mask,
447
414
  matrix=matrix,
448
415
  output=out_mask,
449
416
  order=order,
450
417
  prefilter=False,
451
418
  cache=cache,
452
- batched=batched,
453
419
  )
454
420
 
455
421
  return out, out_mask
456
422
 
423
+ def rigid_transform(
424
+ self,
425
+ arr: NDArray,
426
+ rotation_matrix: NDArray,
427
+ arr_mask: NDArray = None,
428
+ translation: NDArray = None,
429
+ use_geometric_center: bool = False,
430
+ out: NDArray = None,
431
+ out_mask: NDArray = None,
432
+ order: int = 3,
433
+ cache: bool = False,
434
+ batched: bool = False,
435
+ ) -> Tuple[NDArray, NDArray]:
436
+ matrix = rotation_matrix
437
+
438
+ # Build transformation matrix from rotation matrix
439
+ if matrix.shape[-1] == (arr.ndim - int(batched)):
440
+ center = self._compute_transform_center(arr, use_geometric_center, batched)
441
+ matrix = self._build_transform_matrix(
442
+ rotation_matrix=rotation_matrix,
443
+ translation=translation,
444
+ center=self.astype(center, self._float_dtype),
445
+ shape=arr.shape[1:] if batched else arr.shape,
446
+ )
447
+
448
+ if batched:
449
+ matrix = self._batch_transform_matrix(matrix)
450
+
451
+ return self._rigid_transform(
452
+ arr=arr,
453
+ arr_mask=arr_mask,
454
+ out=out,
455
+ out_mask=out_mask,
456
+ matrix=matrix,
457
+ cache=cache,
458
+ order=order,
459
+ batched=batched,
460
+ )
461
+
457
462
  def center_of_mass(self, arr: BackendArray, cutoff: float = None) -> BackendArray:
458
463
  """
459
- Computes the center of mass of a numpy ndarray instance using all available
460
- elements. For template matching it typically makes sense to only input
461
- positive densities.
464
+ Computes the center of mass of an array larger than cutoff.
462
465
 
463
466
  Parameters
464
467
  ----------
@@ -466,19 +469,19 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
466
469
  Array to compute the center of mass of.
467
470
  cutoff : float, optional
468
471
  Densities less than or equal to cutoff are nullified for center
469
- of mass computation. By default considers all values.
472
+ of mass computation. Defaults to None.
470
473
 
471
474
  Returns
472
475
  -------
473
476
  BackendArray
474
477
  Center of mass with shape (arr.ndim).
475
478
  """
476
- cutoff = self.min(arr) - 1 if cutoff is None else cutoff
477
-
478
- arr = self.where(arr > cutoff, arr, 0)
479
- denominator = self.sum(arr)
479
+ arr = self.abs(arr)
480
+ if cutoff is not None:
481
+ arr = self.where(arr > cutoff, arr, 0)
480
482
 
481
483
  grids = []
484
+ denominator = self.sum(arr)
482
485
  for i, x in enumerate(arr.shape):
483
486
  baseline_dims = tuple(1 if i != t else x for t in range(len(arr.shape)))
484
487
  grids.append(
@@ -587,7 +590,7 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
587
590
  sq_exp = self.sqrt(sq_exp, out=sq_exp)
588
591
 
589
592
  # Assume that low stdev regions also have low scores
590
- # See :py:meth:`tme.matching_exhaustive.flcSphericalMask_setup` for correct norm
593
+ # See :py:meth:`tme.matching_scores.flcSphericalMask_setup` for correct norm
591
594
  sq_exp[sq_exp < eps] = 1
592
595
  sq_exp = self.multiply(sq_exp, n_obs, out=sq_exp)
593
596
  return self.divide(arr, sq_exp, out=out)
tme/backends/pytorch_backend.py CHANGED
@@ -281,22 +281,45 @@ class PytorchBackend(NumpyFFTWBackend):
281
281
  kwargs["dim"] = kwargs.pop("axes", None)
282
282
  return self._array_backend.fft.irfftn(arr, **kwargs)
283
283
 
284
- def _rigid_transform_matrix(self, rotation_matrix, *args, **kwargs):
285
- return rotation_matrix
284
+ def _build_transform_matrix(
285
+ self,
286
+ shape: Tuple[int],
287
+ rotation_matrix: TorchTensor,
288
+ translation: TorchTensor = None,
289
+ center: TorchTensor = None,
290
+ **kwargs,
291
+ ) -> TorchTensor:
292
+ """
293
+ Express the transform matrix in normalized coordinates.
294
+ """
295
+ shape = self.to_backend_array(shape) - 1
296
+
297
+ scale_factors = 2.0 / shape
298
+ if center is not None:
299
+ center = center - shape / 2
300
+ center = center * scale_factors
301
+
302
+ if translation is not None:
303
+ translation = translation * scale_factors
304
+
305
+ return super()._build_transform_matrix(
306
+ rotation_matrix=self.flip(rotation_matrix, [0, 1]),
307
+ translation=translation,
308
+ center=center,
309
+ )
286
310
 
287
- def rigid_transform(
311
+ def _rigid_transform(
288
312
  self,
289
313
  arr: TorchTensor,
290
- rotation_matrix: TorchTensor,
314
+ matrix: TorchTensor,
291
315
  arr_mask: TorchTensor = None,
292
- translation: TorchTensor = None,
293
- use_geometric_center: bool = False,
294
316
  out: TorchTensor = None,
295
317
  out_mask: TorchTensor = None,
296
318
  order: int = 1,
297
- cache: bool = False,
319
+ batched: bool = False,
298
320
  **kwargs,
299
- ):
321
+ ) -> Tuple[TorchTensor, TorchTensor]:
322
+ """Apply rigid transformation using homogeneous transformation matrix."""
300
323
  _mode_mapping = {0: "nearest", 1: "bilinear", 3: "bicubic"}
301
324
  mode = _mode_mapping.get(order, None)
302
325
  if mode is None:
@@ -305,90 +328,54 @@ class PytorchBackend(NumpyFFTWBackend):
305
328
  f"Got {order} but supported interpolation orders are: {modes}."
306
329
  )
307
330
 
308
- out = self.zeros_like(arr) if out is None else out
309
-
310
- if translation is None:
311
- translation = self._array_backend.zeros(arr.ndim, device=arr.device)
312
-
313
- normalized_translation = self.divide(
314
- -2.0 * translation, self.tensor(arr.shape, device=arr.device)
315
- )
316
- rotation_matrix_pull = self.linalg.inv(self.flip(rotation_matrix, [0, 1]))
317
-
318
- out_slice = tuple(slice(0, x) for x in arr.shape)
319
- subset = tuple(slice(None) for _ in range(arr.ndim))
320
- offset = max(int(arr.ndim - rotation_matrix.shape[0]) - 1, 0)
321
- if offset > 0:
322
- normalized_translation = normalized_translation[offset:]
323
- subset = tuple(0 if i < offset else slice(None) for i in range(arr.ndim))
324
- out_slice = tuple(
325
- slice(0, 1) if i < offset else slice(0, x)
326
- for i, x in enumerate(arr.shape)
327
- )
328
-
329
- out[out_slice] = self._affine_transform(
330
- arr=arr[subset],
331
- rotation_matrix=rotation_matrix_pull,
332
- translation=normalized_translation,
333
- mode=mode,
334
- )
335
-
336
- if arr_mask is not None:
337
- out_mask_slice = tuple(slice(0, x) for x in arr_mask.shape)
338
- if out_mask is None:
339
- out_mask = self._array_backend.zeros_like(arr_mask)
340
- out_mask[out_mask_slice] = self._affine_transform(
341
- arr=arr_mask[subset],
342
- rotation_matrix=rotation_matrix_pull,
343
- translation=normalized_translation,
344
- mode=mode,
345
- )
346
-
347
- return out, out_mask
348
-
349
- def _affine_transform(
350
- self,
351
- arr: TorchTensor,
352
- rotation_matrix: TorchTensor,
353
- translation: TorchTensor,
354
- mode,
355
- ) -> TorchTensor:
356
- batched = arr.ndim != rotation_matrix.shape[0]
357
-
358
331
  batch_size, spatial_dims = 1, arr.shape
332
+ out_slice = tuple(slice(0, x) for x in arr.shape)
359
333
  if batched:
360
- translation = translation[1:]
334
+ matrix = matrix[1:, 1:]
361
335
  batch_size, *spatial_dims = arr.shape
362
336
 
363
- n_dims = len(spatial_dims)
364
- transformation_matrix = self._array_backend.zeros(
365
- n_dims, n_dims + 1, device=arr.device, dtype=arr.dtype
366
- )
367
-
368
- transformation_matrix[:, :n_dims] = rotation_matrix
369
- transformation_matrix[:, n_dims] = translation
370
- transformation_matrix = transformation_matrix.unsqueeze(0).expand(
371
- batch_size, -1, -1
372
- )
337
+ # Remove homogeneous row and expand for batch processing
338
+ matrix = matrix[:-1, :].to(arr.dtype)
339
+ matrix = matrix.unsqueeze(0).expand(batch_size, -1, -1)
373
340
 
374
- if not batched:
375
- arr = arr.unsqueeze(0)
376
-
377
- size = self.Size([batch_size, 1, *spatial_dims])
378
341
  grid = self.F.affine_grid(
379
- theta=transformation_matrix, size=size, align_corners=False
342
+ theta=matrix.to(arr.dtype),
343
+ size=self.Size([batch_size, 1, *spatial_dims]),
344
+ align_corners=False,
380
345
  )
381
- output = self.F.grid_sample(
346
+
347
+ arr = arr.unsqueeze(0) if not batched else arr
348
+ ret = self.F.grid_sample(
382
349
  input=arr.unsqueeze(1),
383
350
  grid=grid,
384
351
  mode=mode,
385
352
  align_corners=False,
386
- )
353
+ ).squeeze(1)
354
+
355
+ ret_mask = None
356
+ if arr_mask is not None:
357
+ arr_mask = arr_mask.unsqueeze(0) if not batched else arr_mask
358
+ ret_mask = self.F.grid_sample(
359
+ input=arr_mask.unsqueeze(1),
360
+ grid=grid,
361
+ mode=mode,
362
+ align_corners=False,
363
+ ).squeeze(1)
387
364
 
388
365
  if not batched:
389
- output = output.squeeze(0)
366
+ ret = ret.squeeze(0)
367
+ ret_mask = ret_mask.squeeze(0) if arr_mask is not None else None
368
+
369
+ if out is not None:
370
+ out[out_slice] = ret
371
+ else:
372
+ out = ret
390
373
 
391
- return output.squeeze(1)
374
+ if out_mask is not None:
375
+ out_mask[out_slice] = ret_mask
376
+ else:
377
+ out_mask = ret_mask
378
+ return out, out_mask
392
379
 
393
380
  def get_available_memory(self) -> int:
394
381
  if self.device == "cpu":
tme/cli.py CHANGED
@@ -52,7 +52,7 @@ def match_template(
52
52
  """
53
53
  from .matching_data import MatchingData
54
54
  from .analyzer import MaxScoreOverRotations
55
- from .matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER
55
+ from .matching_exhaustive import match_exhaustive, MATCHING_EXHAUSTIVE_REGISTER
56
56
 
57
57
  if rotations is None:
58
58
  rotations = np.eye(target.ndim).reshape(1, target.ndim, target.ndim)
@@ -73,7 +73,7 @@ def match_template(
73
73
  matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[score]
74
74
 
75
75
  candidates = list(
76
- scan_subsets(
76
+ match_exhaustive(
77
77
  matching_data=matching_data,
78
78
  matching_score=matching_score,
79
79
  matching_setup=matching_setup,