pytme-0.3.1.post1-cp311-cp311-macosx_15_0_arm64.whl → pytme-0.3.2-cp311-cp311-macosx_15_0_arm64.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- pytme-0.3.2.data/scripts/estimate_ram_usage.py +97 -0
- {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/match_template.py +213 -196
- {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/postprocess.py +40 -78
- {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/preprocess.py +4 -5
- {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/preprocessor_gui.py +50 -103
- {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/pytme_runner.py +46 -69
- {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/METADATA +3 -2
- {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/RECORD +68 -65
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +213 -196
- scripts/match_template_devel.py +1339 -0
- scripts/postprocess.py +40 -78
- scripts/preprocess.py +4 -5
- scripts/preprocessor_gui.py +50 -103
- scripts/pytme_runner.py +46 -69
- scripts/refine_matches.py +5 -7
- tests/preprocessing/test_compose.py +31 -30
- tests/preprocessing/test_frequency_filters.py +17 -32
- tests/preprocessing/test_preprocessor.py +0 -19
- tests/preprocessing/test_utils.py +13 -1
- tests/test_analyzer.py +2 -10
- tests/test_backends.py +47 -18
- tests/test_density.py +72 -13
- tests/test_extensions.py +1 -0
- tests/test_matching_cli.py +23 -9
- tests/test_matching_exhaustive.py +5 -5
- tests/test_matching_utils.py +3 -3
- tests/test_rotations.py +13 -23
- tests/test_structure.py +1 -7
- tme/__version__.py +1 -1
- tme/analyzer/aggregation.py +47 -16
- tme/analyzer/base.py +34 -0
- tme/analyzer/peaks.py +26 -13
- tme/analyzer/proxy.py +14 -0
- tme/backends/_jax_utils.py +124 -71
- tme/backends/cupy_backend.py +6 -19
- tme/backends/jax_backend.py +110 -105
- tme/backends/matching_backend.py +0 -17
- tme/backends/mlx_backend.py +0 -29
- tme/backends/npfftw_backend.py +100 -97
- tme/backends/pytorch_backend.py +65 -78
- tme/cli.py +2 -2
- tme/density.py +102 -58
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/_utils.py +52 -24
- tme/filters/bandpass.py +99 -105
- tme/filters/compose.py +133 -39
- tme/filters/ctf.py +51 -102
- tme/filters/reconstruction.py +67 -122
- tme/filters/wedge.py +296 -325
- tme/filters/whitening.py +39 -75
- tme/mask.py +2 -2
- tme/matching_data.py +87 -15
- tme/matching_exhaustive.py +70 -120
- tme/matching_optimization.py +9 -63
- tme/matching_scores.py +261 -100
- tme/matching_utils.py +150 -91
- tme/memory.py +1 -0
- tme/orientations.py +28 -8
- tme/preprocessor.py +0 -239
- tme/rotations.py +102 -70
- tme/structure.py +601 -631
- tme/types.py +1 -0
- {pytme-0.3.1.post1.data → pytme-0.3.2.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/WHEEL +0 -0
- {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/entry_points.txt +0 -0
- {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dist-info}/top_level.txt +0 -0
tme/matching_exhaustive.py
CHANGED
@@ -8,7 +8,6 @@ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
 
 import sys
 import warnings
-from math import prod
 from functools import wraps
 from itertools import product
 from typing import Callable, Tuple, Dict, Optional
@@ -16,14 +15,15 @@ from typing import Callable, Tuple, Dict, Optional
 from joblib import Parallel, delayed
 from multiprocessing.managers import SharedMemoryManager
 
-from .filters import Compose
 from .backends import backend as be
-from .matching_utils import split_shape
+from .matching_utils import split_shape, setup_filter
 from .types import CallbackClass, MatchingData
 from .analyzer.proxy import SharedAnalyzerProxy
 from .matching_scores import MATCHING_EXHAUSTIVE_REGISTER
 from .memory import MatchingMemoryUsage, MATCHING_MEMORY_REGISTRY
 
+__all__ = ["match_exhaustive"]
+
 
 def _wrap_backend(func):
     @wraps(func)
@@ -36,89 +36,6 @@ def _wrap_backend(func):
     return wrapper
 
 
-def _setup_template_filter_apply_target_filter(
-    matching_data: MatchingData,
-    fast_shape: Tuple[int],
-    fast_ft_shape: Tuple[int],
-    pad_template_filter: bool = False,
-):
-    target_filter = None
-    backend_arr = type(be.zeros((1), dtype=be._float_dtype))
-    template_filter = be.full(shape=(1,), fill_value=1, dtype=be._float_dtype)
-    if isinstance(matching_data.template_filter, backend_arr):
-        template_filter = matching_data.template_filter
-
-    if isinstance(matching_data.target_filter, backend_arr):
-        target_filter = matching_data.target_filter
-
-    filter_template = isinstance(matching_data.template_filter, Compose)
-    filter_target = isinstance(matching_data.target_filter, Compose)
-
-    # For now assume user-supplied template_filter is correctly padded
-    if filter_target is None and target_filter is None:
-        return template_filter
-
-    cmpl_template_shape_full, batch_mask = fast_ft_shape, matching_data._batch_mask
-    real_shape = matching_data._batch_shape(fast_shape, batch_mask, keepdims=False)
-    cmpl_shape = matching_data._batch_shape(fast_ft_shape, batch_mask, keepdims=True)
-
-    real_template_shape, cmpl_template_shape = real_shape, cmpl_shape
-    cmpl_template_shape_full = matching_data._batch_shape(
-        fast_ft_shape, matching_data._target_batch, keepdims=True
-    )
-    cmpl_target_shape_full = matching_data._batch_shape(
-        fast_ft_shape, matching_data._template_batch, keepdims=True
-    )
-    if filter_template and not pad_template_filter:
-        out_shape = matching_data._output_template_shape
-        real_template_shape = matching_data._batch_shape(
-            out_shape, batch_mask, keepdims=False
-        )
-        cmpl_template_shape = list(
-            matching_data._batch_shape(out_shape, batch_mask, keepdims=True)
-        )
-        cmpl_template_shape_full = list(out_shape)
-        cmpl_template_shape[-1] = cmpl_template_shape[-1] // 2 + 1
-        cmpl_template_shape_full[-1] = cmpl_template_shape_full[-1] // 2 + 1
-
-    # Setup composable filters
-    target_temp = be.topleft_pad(matching_data.target, fast_shape)
-    target_temp_ft = be.rfftn(target_temp)
-    filter_kwargs = {
-        "return_real_fourier": True,
-        "shape_is_real_fourier": False,
-        "data_rfft": target_temp_ft,
-        "batch_dimension": matching_data._target_dim,
-    }
-
-    if filter_template:
-        template_filter = matching_data.template_filter(
-            shape=real_template_shape, **filter_kwargs
-        )["data"]
-        template_filter_size = int(be.size(template_filter))
-
-        if template_filter_size == prod(cmpl_template_shape_full):
-            cmpl_template_shape = cmpl_template_shape_full
-        elif template_filter_size == prod(cmpl_shape):
-            cmpl_template_shape = cmpl_shape
-        template_filter = be.reshape(template_filter, cmpl_template_shape)
-
-    if filter_target:
-        target_filter = matching_data.target_filter(
-            shape=real_shape, weight_type=None, **filter_kwargs
-        )["data"]
-        if int(be.size(target_filter)) == prod(cmpl_target_shape_full):
-            cmpl_shape = cmpl_target_shape_full
-
-        target_filter = be.reshape(target_filter, cmpl_shape)
-        target_temp_ft = be.multiply(target_temp_ft, target_filter, out=target_temp_ft)
-
-    target_temp = be.irfftn(target_temp_ft, s=target_temp.shape)
-    matching_data._target = be.topleft_pad(target_temp, matching_data.target.shape)
-
-    return be.astype(be.to_backend_array(template_filter), be._float_dtype)
-
-
 def device_memory_handler(func: Callable):
     """Decorator function providing SharedMemory Handler."""
 
@@ -142,7 +59,7 @@ def device_memory_handler(func: Callable):
 
 
 @device_memory_handler
-def scan(
+def _match_exhaustive(
     matching_data: MatchingData,
     matching_setup: Callable,
     matching_score: Callable,
@@ -155,6 +72,8 @@ def scan(
     shm_handler=None,
     target_slice=None,
     template_slice=None,
+    background_correction: str = None,
+    **kwargs,
 ) -> Optional[Tuple]:
     """
     Run template matching.
@@ -187,29 +106,19 @@ def scan(
         Target subset to process.
     template_slice : tuple of slice, optional
        Template subset to process.
+    background_correction : str, optional
+        Background correctoin use use. Supported methods are 'phase-scrambling'.
 
     Returns
     -------
     Optional[Tuple]
         The merged results from callback_class if provided otherwise None.
 
-
-
-    Schematically,
+    Notes
+    -----
+    Schematically, this function is identical to :py:meth:`match_exhaustive`,
     with the distinction that the objects contained in ``matching_data`` are not
     split and the search is only parallelized over angles.
-    Assuming you have followed the example in :py:meth:`scan_subsets`, :py:meth:`scan`
-    can be invoked like so
-
-    >>> from tme.matching_exhaustive import scan
-    >>> results = scan(
-    >>>     matching_data=matching_data,
-    >>>     matching_score=matching_score,
-    >>>     matching_setup=matching_setup,
-    >>>     callback_class=callback_class,
-    >>>     callback_class_args=callback_class_args,
-    >>> )
-
     """
     matching_data, translation_offset = matching_data.subset_by_slice(
         target_slice=target_slice,
@@ -219,19 +128,21 @@ def scan(
 
     matching_data.to_backend()
     template_shape = matching_data._batch_shape(
-        matching_data.
+        matching_data._template.shape, matching_data._template_batch
     )
     conv, fwd, inv, shift = matching_data.fourier_padding()
 
+    # Mask invalid scores from padding to not skew score statistics
     score_mask = be.full(shape=(1,), fill_value=1, dtype=bool)
     if pad_target:
         score_mask = matching_data._score_mask(fwd, shift)
 
-    template_filter =
+    template_filter, _ = setup_filter(
         matching_data=matching_data,
         fast_shape=fwd,
         fast_ft_shape=inv,
         pad_template_filter=False,
+        apply_target_filter=True,
     )
 
     default_callback_args = {
@@ -259,7 +170,15 @@ def scan(
         shm_handler=shm_handler,
     )
 
-
+    if background_correction == "phase-scrambling":
+        # Use getter to make sure template is reversed correctly
+        matching_data.template = matching_data.transform_template("phase_randomization")
+        setup["template_background"] = be.to_sharedarr(matching_data.template)
+
+    matching_data.free()
+    if not callback_class.shareable:
+        jobs_per_callback_class = 1
+
     n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
     callback_classes = [
         SharedAnalyzerProxy(
@@ -286,11 +205,14 @@ def scan(
     )
     be.free_cache()
 
-
+    # Background correction creates individual non-shared arrays
+    if background_correction is None:
+        ret = ret[:n_callback_classes]
+
     callbacks = [x.result(**default_callback_args) for x in ret]
     return callback_class.merge(callbacks, **default_callback_args)
 
-def scan_subsets(
+def match_exhaustive(
     matching_data: MatchingData,
     matching_score: Callable,
     matching_setup: Callable,
@@ -305,11 +227,12 @@ def scan_subsets(
     backend_name: str = None,
     backend_args: Dict = {},
     verbose: bool = False,
+    background_correction: str = None,
     **kwargs,
 ) -> Optional[Tuple]:
     """
-
-
+    Run exhaustive template matching over all translations and a subset of rotations
+    specified in `matching_data`.
 
     Parameters
     ----------
@@ -341,7 +264,9 @@ def scan_subsets(
        How many jobs should be processed by a single callback_class instance,
        if ones is provided.
     verbose : bool, optional
-        Indicate matching progress.
+        Indicate matching progress, defaults to False.
+    background_correction : str, optional
+        Background correctoin use use. Supported methods are 'phase-scrambling'.
 
     Returns
     -------
@@ -355,7 +280,7 @@ def scan_subsets(
 
     >>> import numpy as np
     >>> from tme.matching_data import MatchingData
-    >>> from tme.
+    >>> from tme.rotations import get_rotation_matrices
     >>> target = np.random.rand(50,40,60)
     >>> template = target[15:25, 10:20, 30:40]
     >>> matching_data = MatchingData(target, template)
@@ -391,8 +316,8 @@ def scan_subsets(
     Finally, we can perform template matching. Note that the data
     contained in ``matching_data`` will be destroyed when running the following
 
-    >>> from tme.matching_exhaustive import
-    >>> results =
+    >>> from tme.matching_exhaustive import match_exhaustive
+    >>> results = match_exhaustive(
    >>>     matching_data=matching_data,
    >>>     matching_score=matching_score,
    >>>     matching_setup=matching_setup,
@@ -407,6 +332,12 @@ def scan_subsets(
     --------
     :py:meth:`tme.matching_utils.compute_parallelization_schedule`
     """
+    if background_correction not in (None, "phase-scrambling"):
+        raise ValueError(
+            "Argument background_correction can be either None or "
+            f"'phase-scrambling', got {background_correction}."
+        )
+
     template_splits = split_shape(matching_data._template.shape, splits=template_splits)
     target_splits = split_shape(matching_data._target.shape, splits=target_splits)
     if (len(target_splits) > 1) and not pad_target_edges:
@@ -417,26 +348,25 @@ def scan_subsets(
     splits = tuple(product(target_splits, template_splits))
 
     kwargs = {
+        "match_projection": kwargs.get("match_projection", False),
         "matching_data": matching_data,
         "callback_class": callback_class,
         "callback_class_args": callback_class_args,
     }
-
     outer_jobs, inner_jobs = job_schedule
     if be._backend_name == "jax":
-
-
-        corr_scoring = MATCHING_EXHAUSTIVE_REGISTER.get("CORR", (None, None))[1]
-        results = func(
+        score = MATCHING_EXHAUSTIVE_REGISTER.get("FLC", (None, None))[1]
+        results = be.scan(
            splits=splits,
            n_jobs=outer_jobs,
-           rotate_mask=matching_score
+           rotate_mask=matching_score == score,
+           background_correction=background_correction,
            **kwargs,
        )
    else:
        results = Parallel(n_jobs=outer_jobs, verbose=verbose)(
            [
-                delayed(_wrap_backend(
+                delayed(_wrap_backend(_match_exhaustive))(
                    backend_name=be._backend_name,
                    backend_args=be._backend_args,
                    matching_score=matching_score,
@@ -447,6 +377,7 @@ def scan_subsets(
                    gpu_index=index % outer_jobs,
                    target_slice=target_split,
                    template_slice=template_split,
+                    background_correction=background_correction,
                    **kwargs,
                )
                for index, (target_split, template_split) in enumerate(splits)
@@ -489,3 +420,22 @@ def register_matching_exhaustive(
 
     MATCHING_EXHAUSTIVE_REGISTER[matching] = (matching_setup, matching_scoring)
     MATCHING_MEMORY_REGISTRY[matching] = memory_class
+
+
+def scan(*args, **kwargs):
+    warnings.warn(
+        "Using scan directly is deprecated and will raise an error "
+        "in future releases. Please use match_exhaustive instead.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+    return _match_exhaustive(*args, **kwargs)
+
+
+def scan_subsets(*args, **kwargs):
+    warnings.warn(
+        "Using scan_subsets directly is deprecated and will raise an error "
+        "in future releases. Please use match_exhaustive instead.",
+        DeprecationWarning,
+    )
+    return match_exhaustive(*args, **kwargs)
tme/matching_optimization.py
CHANGED
@@ -23,7 +23,7 @@ from .backends import backend as be
 from .types import ArrayLike, NDArray
 from .matching_data import MatchingData
 from .rotations import euler_to_rotationmatrix
-from .matching_utils import
+from .matching_utils import _rigid_transform, standardize
 
 
 def _format_rigid_transform(x: Tuple[float]) -> Tuple[ArrayLike, ArrayLike]:
@@ -45,10 +45,8 @@ def _format_rigid_transform(x: Tuple[float]) -> Tuple[ArrayLike, ArrayLike]:
     translation, angles = x[:split], x[split:]
 
     translation = be.to_backend_array(translation)
-    rotation_matrix = euler_to_rotationmatrix(be.to_numpy_array(angles))
-
-
-    return translation, rotation_matrix
+    rotation_matrix = euler_to_rotationmatrix(be.to_numpy_array(angles), "ZYZ")
+    return translation, be.to_backend_array(rotation_matrix)
 
 
 class _MatchDensityToDensity(ABC):
@@ -121,57 +119,6 @@ class _MatchDensityToDensity(ABC):
         if hasattr(self, "_post_init"):
             self._post_init(**kwargs)
 
-    def rotate_array(
-        self,
-        arr,
-        rotation_matrix,
-        translation,
-        arr_mask=None,
-        out=None,
-        out_mask=None,
-        order: int = 1,
-        **kwargs,
-    ):
-        rotate_mask = arr_mask is not None
-        return_type = (out is None) + 2 * rotate_mask * (out_mask is None)
-        translation = np.zeros(arr.ndim) if translation is None else translation
-
-        center = np.floor(np.array(arr.shape) / 2)[:, None]
-
-        if not hasattr(self, "_previous_center"):
-            self._previous_center = arr.shape
-
-        if not hasattr(self, "grid") or not np.allclose(self._previous_center, center):
-            self.grid = np.indices(arr.shape, dtype=np.float32).reshape(arr.ndim, -1)
-            np.subtract(self.grid, center, out=self.grid)
-            self.grid_out = np.zeros_like(self.grid)
-            self._previous_center = center
-
-        np.matmul(rotation_matrix.T, self.grid, out=self.grid_out)
-        translation = np.add(translation[:, None], center)
-        np.add(self.grid_out, translation, out=self.grid_out)
-
-        if out is None:
-            out = np.zeros_like(arr)
-
-        self._interpolate(arr, self.grid_out, order=order, out=out.ravel())
-
-        if out_mask is None and arr_mask is not None:
-            out_mask = np.zeros_like(arr_mask)
-
-        if arr_mask is not None:
-            self._interpolate(arr_mask, self.grid_out, order=order, out=out.ravel())
-
-        match return_type:
-            case 0:
-                return None
-            case 1:
-                return out
-            case 2:
-                return out_mask
-            case 3:
-                return out, out_mask
-
     @staticmethod
     def _interpolate(data, positions, order: int = 1, out=None):
         return map_coordinates(
@@ -266,8 +213,7 @@ class _MatchDensityToDensity(ABC):
             self.template_mask_rot.fill(0)
             kw_dict["arr_mask"] = self.template_mask
             kw_dict["out_mask"] = self.template_mask_rot
-
-        self.rotate_array(**kw_dict)
+        be.rigid_transform(**kw_dict)
 
         return self()
 
@@ -361,7 +307,7 @@ class _MatchCoordinatesToDensity(_MatchDensityToDensity):
         """
         translation, rotation_matrix = _format_rigid_transform(x)
 
-
+        _rigid_transform(
             coordinates=self.template,
             coordinates_mask=self.template_mask,
             rotation_matrix=rotation_matrix,
@@ -469,7 +415,7 @@ class _MatchCoordinatesToCoordinates(_MatchDensityToDensity):
         """
         translation, rotation_matrix = _format_rigid_transform(x)
 
-
+        _rigid_transform(
             coordinates=self.template_coordinates,
             coordinates_mask=self.template_mask_coordinates,
             rotation_matrix=rotation_matrix,
@@ -514,7 +460,7 @@ class FLC(_MatchDensityToDensity):
 
         self.target_square = be.square(self.target)
 
-
+        standardize(
             template=self.template,
             mask=self.template_mask,
             n_observations=be.sum(self.template_mask),
@@ -524,7 +470,7 @@ class FLC(_MatchDensityToDensity):
         """Returns the score of the current configuration."""
         n_obs = be.sum(self.template_mask_rot)
 
-
+        standardize(
             template=self.template_rot,
             mask=self.template_mask_rot,
             n_observations=n_obs,
@@ -1309,5 +1255,5 @@ def optimize_match(
         print("Initial score better than refined score. Returning identity.")
         result.x = np.zeros_like(result.x)
     translation, rotation = result.x[:ndim], result.x[ndim:]
-    rotation_matrix = euler_to_rotationmatrix(rotation)
+    rotation_matrix = euler_to_rotationmatrix(rotation, "ZYZ")
     return translation, rotation_matrix, float(result.fun)