pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.3__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/match_template.py +219 -216
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/postprocess.py +86 -54
- pytme-0.2.3.data/scripts/preprocess.py +132 -0
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +181 -94
- pytme-0.2.3.dist-info/METADATA +92 -0
- pytme-0.2.3.dist-info/RECORD +75 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/WHEEL +1 -1
- pytme-0.2.1.data/scripts/preprocess.py → scripts/eval.py +1 -1
- scripts/extract_candidates.py +20 -13
- scripts/match_template.py +219 -216
- scripts/match_template_filters.py +154 -95
- scripts/postprocess.py +86 -54
- scripts/preprocess.py +95 -56
- scripts/preprocessor_gui.py +181 -94
- scripts/refine_matches.py +265 -61
- tme/__init__.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +458 -813
- tme/backends/__init__.py +40 -11
- tme/backends/_jax_utils.py +187 -0
- tme/backends/cupy_backend.py +109 -226
- tme/backends/jax_backend.py +230 -152
- tme/backends/matching_backend.py +445 -384
- tme/backends/mlx_backend.py +32 -59
- tme/backends/npfftw_backend.py +240 -507
- tme/backends/pytorch_backend.py +30 -151
- tme/density.py +248 -371
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +328 -284
- tme/matching_exhaustive.py +195 -1499
- tme/matching_optimization.py +143 -106
- tme/matching_scores.py +887 -0
- tme/matching_utils.py +287 -388
- tme/memory.py +377 -0
- tme/orientations.py +78 -21
- tme/parser.py +3 -4
- tme/preprocessing/_utils.py +61 -32
- tme/preprocessing/composable_filter.py +7 -4
- tme/preprocessing/compose.py +7 -3
- tme/preprocessing/frequency_filters.py +49 -39
- tme/preprocessing/tilt_series.py +44 -72
- tme/preprocessor.py +560 -526
- tme/structure.py +491 -188
- tme/types.py +5 -3
- pytme-0.2.1.dist-info/METADATA +0 -73
- pytme-0.2.1.dist-info/RECORD +0 -73
- tme/helpers.py +0 -881
- tme/matching_constrained.py +0 -195
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
tme/matching_data.py
CHANGED
@@ -1,18 +1,19 @@
|
|
1
|
-
"""
|
1
|
+
""" Class representation of template matching data.
|
2
2
|
|
3
3
|
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
4
|
|
5
5
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
6
|
"""
|
7
7
|
import warnings
|
8
|
-
from typing import Tuple, List
|
8
|
+
from typing import Tuple, List, Optional
|
9
9
|
|
10
10
|
import numpy as np
|
11
11
|
from numpy.typing import NDArray
|
12
12
|
|
13
13
|
from . import Density
|
14
14
|
from .types import ArrayLike
|
15
|
-
from .
|
15
|
+
from .preprocessing import Compose
|
16
|
+
from .backends import backend as be
|
16
17
|
from .matching_utils import compute_full_convolution_index
|
17
18
|
|
18
19
|
|
@@ -31,9 +32,9 @@ class MatchingData:
|
|
31
32
|
template_mask : np.ndarray or :py:class:`tme.density.Density`, optional
|
32
33
|
Template mask data.
|
33
34
|
invert_target : bool, optional
|
34
|
-
Whether to invert
|
35
|
+
Whether to invert the target before template matching..
|
35
36
|
rotations: np.ndarray, optional
|
36
|
-
Template rotations to sample. Can be a single (d
|
37
|
+
Template rotations to sample. Can be a single (d, d) or a stack (n, d, d)
|
37
38
|
of rotation matrices where d is the dimension of the template.
|
38
39
|
|
39
40
|
Examples
|
@@ -57,26 +58,18 @@ class MatchingData:
|
|
57
58
|
invert_target: bool = False,
|
58
59
|
rotations: NDArray = None,
|
59
60
|
):
|
60
|
-
self.
|
61
|
-
self.
|
62
|
-
self._template_mask = template_mask
|
63
|
-
self._translation_offset = np.zeros(len(target.shape), dtype=int)
|
61
|
+
self.target = target
|
62
|
+
self.target_mask = target_mask
|
64
63
|
|
65
64
|
self.template = template
|
65
|
+
if template_mask is not None:
|
66
|
+
self.template_mask = template_mask
|
66
67
|
|
67
|
-
self.
|
68
|
-
self.
|
69
|
-
|
70
|
-
self.template_filter = {}
|
71
|
-
self.target_filter = {}
|
72
|
-
|
68
|
+
self.rotations = rotations
|
69
|
+
self._translation_offset = np.zeros(len(target.shape), dtype=int)
|
73
70
|
self._invert_target = invert_target
|
74
71
|
|
75
|
-
self.
|
76
|
-
if rotations is not None:
|
77
|
-
self.rotations = rotations
|
78
|
-
|
79
|
-
self._set_batch_dimension()
|
72
|
+
self._set_matching_dimension()
|
80
73
|
|
81
74
|
@staticmethod
|
82
75
|
def _shape_to_slice(shape: Tuple[int]):
|
@@ -105,8 +98,7 @@ class MatchingData:
|
|
105
98
|
NDArray
|
106
99
|
Loaded array.
|
107
100
|
"""
|
108
|
-
|
109
|
-
if type(arr) == np.memmap:
|
101
|
+
if isinstance(arr, np.memmap):
|
110
102
|
return np.memmap(arr.filename, mode="r", shape=arr.shape, dtype=arr.dtype)
|
111
103
|
return arr
|
112
104
|
|
@@ -141,7 +133,7 @@ class MatchingData:
|
|
141
133
|
NDArray
|
142
134
|
Subset of the input array with padding applied.
|
143
135
|
"""
|
144
|
-
padding =
|
136
|
+
padding = be.to_numpy_array(padding)
|
145
137
|
padding = np.maximum(padding, 0).astype(int)
|
146
138
|
|
147
139
|
slice_start = np.array([x.start for x in arr_slice], dtype=int)
|
@@ -160,38 +152,18 @@ class MatchingData:
|
|
160
152
|
arr_slice = tuple(slice(*pos) for pos in zip(arr_start, arr_stop))
|
161
153
|
arr_mesh = self._slice_to_mesh(arr_slice, arr.shape)
|
162
154
|
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
dens = Density.from_file(arr.data.filename, subset=arr_slice)
|
167
|
-
arr = dens.data
|
168
|
-
arr_min = dens.metadata.get("min", None)
|
169
|
-
arr_max = dens.metadata.get("max", None)
|
155
|
+
if isinstance(arr, Density):
|
156
|
+
if isinstance(arr.data, np.memmap):
|
157
|
+
arr = Density.from_file(arr.data.filename, subset=arr_slice).data
|
170
158
|
else:
|
171
159
|
arr = np.asarray(arr.data[*arr_mesh])
|
172
160
|
else:
|
173
|
-
if
|
161
|
+
if isinstance(arr, np.memmap):
|
174
162
|
arr = np.memmap(
|
175
163
|
arr.filename, mode="r", shape=arr.shape, dtype=arr.dtype
|
176
164
|
)
|
177
165
|
arr = np.asarray(arr[*arr_mesh])
|
178
166
|
|
179
|
-
def _warn_on_mismatch(
|
180
|
-
expectation: float, computation: float, name: str
|
181
|
-
) -> float:
|
182
|
-
if expectation is None:
|
183
|
-
expectation = computation
|
184
|
-
expectation, computation = float(expectation), float(computation)
|
185
|
-
|
186
|
-
if abs(computation) > abs(expectation):
|
187
|
-
warnings.warn(
|
188
|
-
f"Computed {name} value is more extreme than value in file header"
|
189
|
-
f" (|{computation}| > |{expectation}|). This may lead to issues"
|
190
|
-
" with padding and contrast inversion."
|
191
|
-
)
|
192
|
-
|
193
|
-
return expectation
|
194
|
-
|
195
167
|
padding = tuple(
|
196
168
|
(left, right)
|
197
169
|
for left, right in zip(
|
@@ -199,17 +171,12 @@ class MatchingData:
|
|
199
171
|
np.subtract(right_pad, data_voxels_right),
|
200
172
|
)
|
201
173
|
)
|
202
|
-
|
174
|
+
# The reflections are later cropped from the scores
|
175
|
+
arr = np.pad(arr, padding, mode="reflect")
|
203
176
|
|
204
177
|
if invert:
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
# Avoid in-place operation in case ret is not floating point
|
209
|
-
ret = (
|
210
|
-
-np.divide(np.subtract(ret, arr_min), np.subtract(arr_max, arr_min)) + 1
|
211
|
-
)
|
212
|
-
return ret
|
178
|
+
arr = -arr
|
179
|
+
return arr
|
213
180
|
|
214
181
|
def subset_by_slice(
|
215
182
|
self,
|
@@ -220,97 +187,94 @@ class MatchingData:
|
|
220
187
|
invert_target: bool = False,
|
221
188
|
) -> "MatchingData":
|
222
189
|
"""
|
223
|
-
|
190
|
+
Subset class instance based on slices.
|
224
191
|
|
225
192
|
Parameters
|
226
193
|
----------
|
227
194
|
target_slice : tuple of slice, optional
|
228
|
-
|
195
|
+
Target subset to use, all by default.
|
229
196
|
template_slice : tuple of slice, optional
|
230
|
-
|
197
|
+
Template subset to use, all by default.
|
231
198
|
target_pad : NDArray, optional
|
232
|
-
|
233
|
-
pad with mean.
|
199
|
+
Target padding, zero by default.
|
234
200
|
template_pad : NDArray, optional
|
235
|
-
|
236
|
-
pad with mean.
|
201
|
+
Template padding, zero by default.
|
237
202
|
|
238
203
|
Returns
|
239
204
|
-------
|
240
|
-
MatchingData
|
241
|
-
Newly allocated
|
205
|
+
:py:class:`MatchingData`
|
206
|
+
Newly allocated subset of class instance.
|
207
|
+
|
208
|
+
Examples
|
209
|
+
--------
|
210
|
+
>>> import numpy as np
|
211
|
+
>>> from tme.matching_data import MatchingData
|
212
|
+
>>> target = np.random.rand(50,40,60)
|
213
|
+
>>> template = target[15:25, 10:20, 30:40]
|
214
|
+
>>> matching_data = MatchingData(target=target, template=template)
|
215
|
+
>>> subset = matching_data.subset_by_slice(
|
216
|
+
>>> target_slice=(slice(0, 10), slice(10,20), slice(15,35))
|
217
|
+
>>> )
|
242
218
|
"""
|
243
|
-
target_shape = self._target.shape
|
244
|
-
template_shape = self._template.shape
|
245
|
-
|
246
219
|
if target_slice is None:
|
247
|
-
target_slice = self._shape_to_slice(
|
220
|
+
target_slice = self._shape_to_slice(self._target.shape)
|
248
221
|
if template_slice is None:
|
249
|
-
template_slice = self._shape_to_slice(
|
222
|
+
template_slice = self._shape_to_slice(self._template.shape)
|
250
223
|
|
251
224
|
if target_pad is None:
|
252
225
|
target_pad = np.zeros(len(self._target.shape), dtype=int)
|
253
226
|
if template_pad is None:
|
254
227
|
template_pad = np.zeros(len(self._template.shape), dtype=int)
|
255
228
|
|
256
|
-
|
257
|
-
if len(self._target.shape) == len(self._template.shape):
|
258
|
-
indices = compute_full_convolution_index(
|
259
|
-
outer_shape=self._target.shape,
|
260
|
-
inner_shape=self._template.shape,
|
261
|
-
outer_split=target_slice,
|
262
|
-
inner_split=template_slice,
|
263
|
-
)
|
264
|
-
|
229
|
+
target_mask, template_mask = None, None
|
265
230
|
target_subset = self.subset_array(
|
266
|
-
|
267
|
-
arr_slice=target_slice,
|
268
|
-
padding=target_pad,
|
269
|
-
invert=self._invert_target,
|
231
|
+
self._target, target_slice, target_pad, invert=self._invert_target
|
270
232
|
)
|
271
|
-
|
272
233
|
template_subset = self.subset_array(
|
273
|
-
arr=self._template,
|
274
|
-
arr_slice=template_slice,
|
275
|
-
padding=template_pad,
|
234
|
+
arr=self._template, arr_slice=template_slice, padding=template_pad
|
276
235
|
)
|
277
|
-
ret = self.__class__(target=target_subset, template=template_subset)
|
278
|
-
|
279
|
-
target_offset = np.zeros(len(self._output_target_shape), dtype=int)
|
280
|
-
target_offset[(target_offset.size - len(target_slice)) :] = [
|
281
|
-
x.start for x in target_slice
|
282
|
-
]
|
283
|
-
template_offset = np.zeros(len(self._output_target_shape), dtype=int)
|
284
|
-
template_offset[(template_offset.size - len(template_slice)) :] = [
|
285
|
-
x.start for x in template_slice
|
286
|
-
]
|
287
|
-
ret._translation_offset = target_offset
|
288
|
-
|
289
|
-
ret.template_filter = self.template_filter
|
290
|
-
ret.target_filter = self.target_filter
|
291
|
-
ret._rotations, ret.indices = self.rotations, indices
|
292
|
-
ret._target_pad, ret._template_pad = target_pad, template_pad
|
293
|
-
ret._invert_target = self._invert_target
|
294
|
-
|
295
236
|
if self._target_mask is not None:
|
296
|
-
|
237
|
+
target_mask = self.subset_array(
|
297
238
|
arr=self._target_mask, arr_slice=target_slice, padding=target_pad
|
298
239
|
)
|
299
240
|
if self._template_mask is not None:
|
300
|
-
|
301
|
-
arr=self._template_mask,
|
302
|
-
arr_slice=template_slice,
|
303
|
-
padding=template_pad,
|
241
|
+
template_mask = self.subset_array(
|
242
|
+
arr=self._template_mask, arr_slice=template_slice, padding=template_pad
|
304
243
|
)
|
305
244
|
|
306
|
-
|
307
|
-
|
308
|
-
|
245
|
+
ret = self.__class__(
|
246
|
+
target=target_subset,
|
247
|
+
template=template_subset,
|
248
|
+
template_mask=template_mask,
|
249
|
+
target_mask=target_mask,
|
250
|
+
rotations=self.rotations,
|
251
|
+
invert_target=self._invert_target,
|
252
|
+
)
|
253
|
+
|
254
|
+
# Deal with splitting offsets
|
255
|
+
target_offset = np.zeros(len(self._output_target_shape), dtype=int)
|
256
|
+
offset = target_offset.size - len(target_slice)
|
257
|
+
target_offset[offset:] = [x.start for x in target_slice]
|
258
|
+
template_offset = np.zeros(len(self._output_target_shape), dtype=int)
|
259
|
+
offset = template_offset.size - len(template_slice)
|
260
|
+
template_offset[offset:] = [x.start for x in template_slice]
|
261
|
+
ret._translation_offset = target_offset
|
262
|
+
if len(self._target.shape) == len(self._template.shape):
|
263
|
+
ret.indices = compute_full_convolution_index(
|
264
|
+
outer_shape=self._target.shape,
|
265
|
+
inner_shape=self._template.shape,
|
266
|
+
outer_split=target_slice,
|
267
|
+
inner_split=template_slice,
|
268
|
+
)
|
309
269
|
|
310
|
-
|
311
|
-
|
270
|
+
ret._is_padded = be.sum(be.to_backend_array(target_pad)) > 0
|
271
|
+
ret.target_filter = self.target_filter
|
272
|
+
ret.template_filter = self.template_filter
|
312
273
|
|
313
|
-
ret.
|
274
|
+
ret._set_matching_dimension(
|
275
|
+
target_dims=getattr(self, "_target_dims", None),
|
276
|
+
template_dims=getattr(self, "_template_dims", None),
|
277
|
+
)
|
314
278
|
|
315
279
|
return ret
|
316
280
|
|
@@ -318,47 +282,42 @@ class MatchingData:
|
|
318
282
|
"""
|
319
283
|
Transfer and convert types of class instance's data arrays to the current backend
|
320
284
|
"""
|
321
|
-
backend_arr = type(
|
285
|
+
backend_arr = type(be.zeros((1), dtype=be._float_dtype))
|
322
286
|
for attr_name, attr_value in vars(self).items():
|
323
287
|
converted_array = None
|
324
288
|
if isinstance(attr_value, np.ndarray):
|
325
|
-
converted_array =
|
289
|
+
converted_array = be.to_backend_array(attr_value.copy())
|
326
290
|
elif isinstance(attr_value, backend_arr):
|
327
|
-
converted_array =
|
291
|
+
converted_array = be.to_backend_array(attr_value)
|
328
292
|
else:
|
329
293
|
continue
|
330
294
|
|
331
|
-
current_dtype =
|
332
|
-
target_dtype =
|
295
|
+
current_dtype = be.get_fundamental_dtype(converted_array)
|
296
|
+
target_dtype = be._fundamental_dtypes[current_dtype]
|
333
297
|
|
334
298
|
# Optional, but scores are float so we avoid casting and potential issues
|
335
299
|
if attr_name in ("_template", "_template_mask", "_target", "_target_mask"):
|
336
|
-
target_dtype =
|
300
|
+
target_dtype = be._float_dtype
|
337
301
|
|
338
302
|
if target_dtype != current_dtype:
|
339
|
-
converted_array =
|
303
|
+
converted_array = be.astype(converted_array, target_dtype)
|
340
304
|
|
341
305
|
setattr(self, attr_name, converted_array)
|
342
306
|
|
343
|
-
def
|
307
|
+
def _set_matching_dimension(
|
344
308
|
self, target_dims: Tuple[int] = None, template_dims: Tuple[int] = None
|
345
309
|
) -> None:
|
346
310
|
"""
|
347
|
-
Sets
|
348
|
-
their corresponding batch dimensions.
|
349
|
-
|
311
|
+
Sets matching dimensions for target and template.
|
350
312
|
Parameters
|
351
313
|
----------
|
352
|
-
target_dims :
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
A tuple of integers specifying the batch dimensions of the template. If None,
|
357
|
-
the template is assumed not to have batch dimensions.
|
314
|
+
target_dims : tuple of ints, optional
|
315
|
+
Target batch dimensions, None by default.
|
316
|
+
template_dims : tuple of ints, optional
|
317
|
+
Template batch dimensions, None by default.
|
358
318
|
|
359
319
|
Notes
|
360
320
|
-----
|
361
|
-
|
362
321
|
If the target and template share a batch dimension, the target will
|
363
322
|
take precendence and the template dimension will be shifted to the right.
|
364
323
|
If target and template have the same dimension, but target specifies batch
|
@@ -386,15 +345,9 @@ class MatchingData:
|
|
386
345
|
|
387
346
|
matching_dims = target_measurement_dims + batch_dims
|
388
347
|
|
389
|
-
target_shape =
|
390
|
-
|
391
|
-
)
|
392
|
-
template_shape = backend.full(
|
393
|
-
shape=(matching_dims,), fill_value=1, dtype=backend._int_dtype
|
394
|
-
)
|
395
|
-
batch_mask = backend.full(
|
396
|
-
shape=(matching_dims,), fill_value=1, dtype=backend._int_dtype
|
397
|
-
)
|
348
|
+
target_shape = np.full(shape=matching_dims, fill_value=1, dtype=int)
|
349
|
+
template_shape = np.full(shape=matching_dims, fill_value=1, dtype=int)
|
350
|
+
batch_mask = np.full(shape=matching_dims, fill_value=1, dtype=int)
|
398
351
|
|
399
352
|
target_index, template_index = 0, 0
|
400
353
|
for k in range(matching_dims):
|
@@ -420,9 +373,9 @@ class MatchingData:
|
|
420
373
|
if template_dim < template_ndim:
|
421
374
|
template_shape[k] = self._template.shape[template_dim]
|
422
375
|
|
423
|
-
self._output_target_shape = target_shape
|
424
|
-
self._output_template_shape = template_shape
|
425
|
-
self._batch_mask = batch_mask
|
376
|
+
self._output_target_shape = tuple(int(x) for x in target_shape)
|
377
|
+
self._output_template_shape = tuple(int(x) for x in template_shape)
|
378
|
+
self._batch_mask = tuple(int(x) for x in batch_mask)
|
426
379
|
|
427
380
|
@staticmethod
|
428
381
|
def _compute_batch_dimension(
|
@@ -433,22 +386,22 @@ class MatchingData:
|
|
433
386
|
|
434
387
|
Parameters
|
435
388
|
----------
|
436
|
-
batch_dims :
|
389
|
+
batch_dims : tuple of ints
|
437
390
|
A tuple of integers representing the batch dimensions.
|
438
391
|
ndim : int
|
439
392
|
The number of dimensions of the array.
|
440
393
|
|
441
394
|
Returns
|
442
395
|
-------
|
443
|
-
Tuple[ArrayLike,
|
444
|
-
|
396
|
+
Tuple[ArrayLike, tuple of ints]
|
397
|
+
Mask and the corresponding batch dimensions.
|
445
398
|
|
446
399
|
Raises
|
447
400
|
------
|
448
401
|
ValueError
|
449
402
|
If any dimension in batch_dims is not less than ndim.
|
450
403
|
"""
|
451
|
-
mask =
|
404
|
+
mask = np.zeros(ndim, dtype=int)
|
452
405
|
if batch_dims is None:
|
453
406
|
return mask, ()
|
454
407
|
|
@@ -463,215 +416,298 @@ class MatchingData:
|
|
463
416
|
|
464
417
|
return mask, batch_dims
|
465
418
|
|
466
|
-
def target_padding(self, pad_target: bool = False) ->
|
419
|
+
def target_padding(self, pad_target: bool = False) -> Tuple[int]:
|
467
420
|
"""
|
468
421
|
Computes padding for the target based on the template's shape.
|
469
422
|
|
470
423
|
Parameters
|
471
424
|
----------
|
472
425
|
pad_target : bool, default False
|
473
|
-
|
474
|
-
an array of zeros is returned.
|
426
|
+
Whether to pad the target, default returns an array of zeros.
|
475
427
|
|
476
428
|
Returns
|
477
429
|
-------
|
478
|
-
|
479
|
-
|
430
|
+
tuple of ints
|
431
|
+
Padding along each dimension of the target.
|
480
432
|
"""
|
481
|
-
target_padding =
|
482
|
-
len(self._output_target_shape), dtype=backend._int_dtype
|
483
|
-
)
|
484
|
-
|
433
|
+
target_padding = np.zeros(len(self._output_target_shape), dtype=int)
|
485
434
|
if pad_target:
|
486
|
-
|
435
|
+
target_padding = np.subtract(
|
487
436
|
self._output_template_shape,
|
488
|
-
|
489
|
-
out=target_padding,
|
437
|
+
np.mod(self._output_template_shape, 2),
|
490
438
|
)
|
491
439
|
if hasattr(self, "_is_target_batch"):
|
492
|
-
target_padding
|
440
|
+
target_padding = np.multiply(
|
441
|
+
target_padding,
|
442
|
+
np.subtract(1, self._is_target_batch),
|
443
|
+
)
|
493
444
|
|
494
|
-
return target_padding
|
445
|
+
return tuple(int(x) for x in target_padding)
|
495
446
|
|
496
|
-
|
497
|
-
|
498
|
-
|
447
|
+
@staticmethod
|
448
|
+
def _fourier_padding(
|
449
|
+
target_shape: NDArray,
|
450
|
+
template_shape: NDArray,
|
451
|
+
batch_mask: NDArray = None,
|
452
|
+
pad_fourier: bool = False,
|
453
|
+
) -> Tuple[Tuple, Tuple, Tuple]:
|
499
454
|
"""
|
500
|
-
|
501
|
-
corresponding shape of the real-valued FFT, and the associated
|
502
|
-
translation shift.
|
503
|
-
|
504
|
-
Parameters
|
505
|
-
----------
|
506
|
-
pad_fourier : bool, default False
|
507
|
-
If true, returns the shape of the full-convolution defined as sum of target
|
508
|
-
shape and template shape minus one. By default, returns unpadded transform.
|
509
|
-
|
510
|
-
Returns
|
511
|
-
-------
|
512
|
-
Tuple[ArrayLike, ArrayLike, ArrayLike]
|
513
|
-
A tuple containing the calculated fast shape, fast Fourier transform shape,
|
514
|
-
and the Fourier shift values, respectively.
|
455
|
+
Determines an efficient shape for Fourier transforms considering zero-padding.
|
515
456
|
"""
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
457
|
+
fourier_pad = template_shape
|
458
|
+
fourier_shift = np.zeros_like(template_shape)
|
459
|
+
|
460
|
+
if batch_mask is None:
|
461
|
+
batch_mask = np.zeros_like(template_shape)
|
462
|
+
batch_mask = np.asarray(batch_mask)
|
520
463
|
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
464
|
+
if not pad_fourier:
|
465
|
+
fourier_pad = np.ones(len(fourier_pad), dtype=int)
|
466
|
+
fourier_pad = np.multiply(fourier_pad, 1 - batch_mask)
|
467
|
+
fourier_pad = np.add(fourier_pad, batch_mask)
|
525
468
|
|
526
|
-
|
527
|
-
|
469
|
+
pad_shape = np.maximum(target_shape, template_shape)
|
470
|
+
ret = be.compute_convolution_shapes(pad_shape, fourier_pad)
|
471
|
+
conv_shape, fast_shape, fast_ft_shape = ret
|
528
472
|
|
473
|
+
template_mod = np.mod(template_shape, 2)
|
529
474
|
if not pad_fourier:
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
475
|
+
fourier_shift = 1 - np.divide(template_shape, 2).astype(int)
|
476
|
+
fourier_shift = np.subtract(fourier_shift, template_mod)
|
477
|
+
|
478
|
+
shape_diff = np.multiply(
|
479
|
+
np.subtract(target_shape, template_shape), 1 - batch_mask
|
480
|
+
)
|
481
|
+
if np.sum(shape_diff < 0):
|
482
|
+
warnings.warn(
|
483
|
+
"Template is larger than target and padding is turned off. Consider "
|
484
|
+
"swapping them or activate padding. Correcting the shift for now."
|
534
485
|
)
|
535
486
|
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
487
|
+
shape_shift = np.divide(shape_diff, 2)
|
488
|
+
offset = np.mod(shape_diff, 2)
|
489
|
+
if pad_fourier:
|
490
|
+
offset = -np.subtract(
|
491
|
+
offset,
|
492
|
+
np.logical_and(np.mod(target_shape, 2) == 0, template_mod == 1),
|
493
|
+
)
|
541
494
|
|
542
|
-
|
543
|
-
|
544
|
-
convolution_shape, fast_shape, fast_ft_shape = ret
|
545
|
-
if not pad_fourier:
|
546
|
-
fourier_shift = 1 - backend.astype(backend.divide(template_shape, 2), int)
|
547
|
-
fourier_shift -= backend.mod(template_shape, 2)
|
548
|
-
shape_diff = backend.subtract(fast_shape, convolution_shape)
|
549
|
-
shape_diff = backend.astype(backend.divide(shape_diff, 2), int)
|
495
|
+
shape_shift = np.add(shape_shift, offset)
|
496
|
+
fourier_shift = np.subtract(fourier_shift, shape_shift).astype(int)
|
550
497
|
|
551
|
-
|
552
|
-
|
553
|
-
backend.multiply(shape_diff, 1 - batch_mask, out=shape_diff)
|
498
|
+
fourier_shift = tuple(fourier_shift.astype(int))
|
499
|
+
return tuple(conv_shape), tuple(fast_shape), tuple(fast_ft_shape), fourier_shift
|
554
500
|
|
555
|
-
|
501
|
+
def fourier_padding(self, pad_fourier: bool = False) -> Tuple[Tuple, Tuple, Tuple]:
|
502
|
+
"""
|
503
|
+
Computes efficient shape four Fourier transforms and potential associated shifts.
|
556
504
|
|
557
|
-
|
505
|
+
Parameters
|
506
|
+
----------
|
507
|
+
pad_fourier : bool, default False
|
508
|
+
If true, returns the shape of the full-convolution defined as sum of target
|
509
|
+
shape and template shape minus one, False by default.
|
558
510
|
|
559
|
-
|
511
|
+
Returns
|
512
|
+
-------
|
513
|
+
Tuple[tuple of int, tuple of int, tuple of int]
|
514
|
+
Tuple with real and complex Fourier transform shape, and corresponding shift.
|
515
|
+
"""
|
516
|
+
return self._fourier_padding(
|
517
|
+
target_shape=be.to_numpy_array(self._output_target_shape),
|
518
|
+
template_shape=be.to_numpy_array(self._output_template_shape),
|
519
|
+
batch_mask=be.to_numpy_array(self._batch_mask),
|
520
|
+
pad_fourier=pad_fourier,
|
521
|
+
)
|
560
522
|
|
561
523
|
@property
|
562
524
|
def rotations(self):
|
563
|
-
"""Return stored rotation matrices
|
525
|
+
"""Return stored rotation matrices."""
|
564
526
|
return self._rotations
|
565
527
|
|
566
528
|
@rotations.setter
|
567
529
|
def rotations(self, rotations: NDArray):
|
568
530
|
"""
|
569
|
-
Set
|
531
|
+
Set :py:attr:`MatchingData.rotations`.
|
570
532
|
|
571
533
|
Parameters
|
572
534
|
----------
|
573
535
|
rotations : NDArray
|
574
|
-
Rotations
|
536
|
+
Rotations matrices with shape (d, d) or (n, d, d).
|
575
537
|
"""
|
576
|
-
if rotations
|
577
|
-
|
578
|
-
|
538
|
+
if rotations is None:
|
539
|
+
print("No rotations provided, assuming identity for now.")
|
540
|
+
rotations = np.eye(len(self._target.shape))
|
541
|
+
|
542
|
+
if rotations.ndim not in (2, 3):
|
543
|
+
raise ValueError("Rotations have to be a rank 2 or 3 array.")
|
544
|
+
elif rotations.ndim == 2:
|
579
545
|
print("Reshaping rotations array to rank 3.")
|
580
546
|
rotations = rotations.reshape(1, *rotations.shape)
|
581
|
-
elif rotations.ndim == 3:
|
582
|
-
pass
|
583
|
-
else:
|
584
|
-
raise ValueError("Rotations have to be a rank 2 or 3 array.")
|
585
547
|
self._rotations = rotations.astype(np.float32)
|
586
548
|
|
549
|
+
@staticmethod
|
550
|
+
def _get_data(attribute, output_shape: Tuple[int], reverse: bool = False):
|
551
|
+
if isinstance(attribute, Density):
|
552
|
+
attribute = attribute.data
|
553
|
+
|
554
|
+
if attribute is not None:
|
555
|
+
if reverse:
|
556
|
+
attribute = be.reverse(attribute)
|
557
|
+
attribute = attribute.reshape(tuple(int(x) for x in output_shape))
|
558
|
+
|
559
|
+
return attribute
|
560
|
+
|
587
561
|
@property
|
588
562
|
def target(self):
|
589
|
-
"""
|
590
|
-
|
591
|
-
|
592
|
-
|
563
|
+
"""
|
564
|
+
Return the target.
|
565
|
+
|
566
|
+
Returns
|
567
|
+
-------
|
568
|
+
NDArray
|
569
|
+
Output data.
|
570
|
+
"""
|
571
|
+
return self._get_data(self._target, self._output_target_shape, False)
|
593
572
|
|
594
|
-
|
595
|
-
|
573
|
+
@property
|
574
|
+
def target_mask(self):
|
575
|
+
"""
|
576
|
+
Return the target mask.
|
577
|
+
|
578
|
+
Returns
|
579
|
+
-------
|
580
|
+
NDArray
|
581
|
+
Output data.
|
582
|
+
"""
|
583
|
+
target_mask = getattr(self, "_target_mask", None)
|
584
|
+
return self._get_data(target_mask, self._output_target_shape, False)
|
596
585
|
|
597
586
|
@property
|
598
587
|
def template(self):
|
599
|
-
"""
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
588
|
+
"""
|
589
|
+
Return the reversed template.
|
590
|
+
|
591
|
+
Returns
|
592
|
+
-------
|
593
|
+
NDArray
|
594
|
+
Output data.
|
595
|
+
"""
|
596
|
+
return self._get_data(self._template, self._output_template_shape, True)
|
597
|
+
|
598
|
+
@property
|
599
|
+
def template_mask(self):
|
600
|
+
"""
|
601
|
+
Return the reversed template mask.
|
602
|
+
|
603
|
+
Returns
|
604
|
+
-------
|
605
|
+
NDArray
|
606
|
+
Output data.
|
607
|
+
"""
|
608
|
+
template_mask = getattr(self, "_template_mask", None)
|
609
|
+
return self._get_data(template_mask, self._output_template_shape, True)
|
610
|
+
|
611
|
+
@target.setter
|
612
|
+
def target(self, arr: NDArray):
|
613
|
+
"""
|
614
|
+
Set :py:attr:`MatchingData.target`.
|
615
|
+
|
616
|
+
Parameters
|
617
|
+
----------
|
618
|
+
arr : NDArray
|
619
|
+
Array to set as the target.
|
620
|
+
"""
|
621
|
+
self._target = arr
|
606
622
|
|
607
623
|
@template.setter
|
608
|
-
def template(self,
|
624
|
+
def template(self, arr: NDArray):
|
609
625
|
"""
|
610
|
-
Set
|
611
|
-
:py:attr:`MatchingData.template_mask` to an
|
612
|
-
ones.
|
626
|
+
Set :py:attr:`MatchingData.template` and initializes
|
627
|
+
:py:attr:`MatchingData.template_mask` to an to an uninformative
|
628
|
+
mask filled with ones if not already defined.
|
613
629
|
|
614
630
|
Parameters
|
615
631
|
----------
|
616
|
-
|
632
|
+
arr : NDArray
|
617
633
|
Array to set as the template.
|
618
634
|
"""
|
619
|
-
self.
|
620
|
-
if self
|
621
|
-
self._template_mask =
|
622
|
-
shape=
|
635
|
+
self._template = arr
|
636
|
+
if getattr(self, "_template_mask", None) is None:
|
637
|
+
self._template_mask = be.full(
|
638
|
+
shape=arr.shape, dtype=be._float_dtype, fill_value=1
|
623
639
|
)
|
624
640
|
|
625
|
-
|
641
|
+
@staticmethod
|
642
|
+
def _set_mask(mask, shape: Tuple[int]):
|
643
|
+
if mask is not None:
|
644
|
+
if mask.shape != shape:
|
645
|
+
raise ValueError(
|
646
|
+
"Mask and respective data have to have the same shape."
|
647
|
+
)
|
648
|
+
return mask
|
626
649
|
|
627
|
-
@
|
628
|
-
def target_mask(self):
|
629
|
-
"""
|
630
|
-
|
631
|
-
if isinstance(self._target_mask, Density):
|
632
|
-
target_mask = self._target_mask.data
|
650
|
+
@target_mask.setter
|
651
|
+
def target_mask(self, arr: NDArray):
|
652
|
+
"""
|
653
|
+
Set :py:attr:`MatchingData.target_mask`.
|
633
654
|
|
634
|
-
|
635
|
-
|
636
|
-
|
655
|
+
Parameters
|
656
|
+
----------
|
657
|
+
arr : NDArray
|
658
|
+
Array to set as the target_mask.
|
659
|
+
"""
|
660
|
+
self._target_mask = self._set_mask(mask=arr, shape=self._target.shape)
|
637
661
|
|
638
|
-
|
662
|
+
@template_mask.setter
|
663
|
+
def template_mask(self, arr: NDArray):
|
664
|
+
"""
|
665
|
+
Set :py:attr:`MatchingData.template_mask`.
|
639
666
|
|
640
|
-
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
667
|
+
Parameters
|
668
|
+
----------
|
669
|
+
arr : NDArray
|
670
|
+
Array to set as the template_mask.
|
671
|
+
"""
|
672
|
+
self._template_mask = self._set_mask(mask=arr, shape=self._template.shape)
|
645
673
|
|
646
|
-
|
674
|
+
@staticmethod
|
675
|
+
def _set_filter(composable_filter) -> Optional[Compose]:
|
676
|
+
if isinstance(composable_filter, Compose):
|
677
|
+
return composable_filter
|
678
|
+
return None
|
647
679
|
|
648
680
|
@property
|
649
|
-
def
|
681
|
+
def template_filter(self) -> Optional[Compose]:
|
650
682
|
"""
|
651
|
-
|
683
|
+
Returns the composable template filter.
|
652
684
|
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
|
685
|
+
Returns
|
686
|
+
-------
|
687
|
+
:py:class:`tme.preprocessing.Compose` | None
|
688
|
+
Composable template filter or None.
|
657
689
|
"""
|
658
|
-
|
659
|
-
if isinstance(self._template_mask, Density):
|
660
|
-
mask = self._template_mask.data
|
690
|
+
return getattr(self, "_template_filter", None)
|
661
691
|
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
return mask
|
692
|
+
@property
|
693
|
+
def target_filter(self) -> Optional[Compose]:
|
694
|
+
"""
|
695
|
+
Returns the composable target filter.
|
667
696
|
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
697
|
+
Returns
|
698
|
+
-------
|
699
|
+
:py:class:`tme.preprocessing.Compose` | None
|
700
|
+
Composable filter or None.
|
701
|
+
"""
|
702
|
+
return getattr(self, "_target_filter", None)
|
673
703
|
|
674
|
-
|
704
|
+
@template_filter.setter
|
705
|
+
def template_filter(self, composable_filter: Compose):
|
706
|
+
self._template_filter = self._set_filter(composable_filter)
|
707
|
+
|
708
|
+
@target_filter.setter
|
709
|
+
def target_filter(self, composable_filter: Compose):
|
710
|
+
self._target_filter = self._set_filter(composable_filter)
|
675
711
|
|
676
712
|
def _split_rotations_on_jobs(self, n_jobs: int) -> List[NDArray]:
|
677
713
|
"""
|
@@ -696,3 +732,11 @@ class MatchingData:
|
|
696
732
|
end_rot = None
|
697
733
|
rot_list.append(self.rotations[init_rot:end_rot])
|
698
734
|
return rot_list
|
735
|
+
|
736
|
+
def _free_data(self):
|
737
|
+
"""
|
738
|
+
Free (dereference) data arrays owned by the class instance.
|
739
|
+
"""
|
740
|
+
attrs = ("_target", "_template", "_template_mask", "_target_mask")
|
741
|
+
for attr in attrs:
|
742
|
+
setattr(self, attr, None)
|