pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
- pytme-0.2.9.data/scripts/match_template.py +1135 -0
- pytme-0.2.9.data/scripts/postprocess.py +622 -0
- pytme-0.2.9.data/scripts/preprocess.py +209 -0
- pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
- pytme-0.2.9.dist-info/METADATA +95 -0
- pytme-0.2.9.dist-info/RECORD +119 -0
- pytme-0.2.9.dist-info/WHEEL +5 -0
- pytme-0.2.9.dist-info/entry_points.txt +6 -0
- pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
- pytme-0.2.9.dist-info/top_level.txt +3 -0
- scripts/__init__.py +0 -0
- scripts/estimate_ram_usage.py +97 -0
- scripts/match_template.py +1135 -0
- scripts/postprocess.py +622 -0
- scripts/preprocess.py +209 -0
- scripts/preprocessor_gui.py +1227 -0
- tests/__init__.py +0 -0
- tests/data/Blurring/blob_width18.npy +0 -0
- tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
- tests/data/Blurring/gaussian_sigma2.npy +0 -0
- tests/data/Blurring/hamming_width6.npy +0 -0
- tests/data/Blurring/kaiserb_width18.npy +0 -0
- tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
- tests/data/Blurring/mean_size5.npy +0 -0
- tests/data/Blurring/ntree_sigma0510.npy +0 -0
- tests/data/Blurring/rank_rank3.npy +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Maps/emd_8621.mrc.gz +0 -0
- tests/data/README.md +2 -0
- tests/data/Raw/em_map.map +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/data/Structures/1pdj.cif +3339 -0
- tests/data/Structures/1pdj.pdb +1429 -0
- tests/data/Structures/5khe.cif +3685 -0
- tests/data/Structures/5khe.ent +2210 -0
- tests/data/Structures/5khe.pdb +2210 -0
- tests/data/Structures/5uz4.cif +70548 -0
- tests/preprocessing/__init__.py +0 -0
- tests/preprocessing/test_compose.py +76 -0
- tests/preprocessing/test_frequency_filters.py +178 -0
- tests/preprocessing/test_preprocessor.py +136 -0
- tests/preprocessing/test_utils.py +79 -0
- tests/test_analyzer.py +216 -0
- tests/test_backends.py +446 -0
- tests/test_density.py +503 -0
- tests/test_extensions.py +130 -0
- tests/test_matching_cli.py +283 -0
- tests/test_matching_data.py +162 -0
- tests/test_matching_exhaustive.py +124 -0
- tests/test_matching_memory.py +30 -0
- tests/test_matching_optimization.py +226 -0
- tests/test_matching_utils.py +189 -0
- tests/test_orientations.py +175 -0
- tests/test_parser.py +33 -0
- tests/test_rotations.py +153 -0
- tests/test_structure.py +247 -0
- tme/__init__.py +6 -0
- tme/__version__.py +1 -0
- tme/analyzer/__init__.py +2 -0
- tme/analyzer/_utils.py +186 -0
- tme/analyzer/aggregation.py +577 -0
- tme/analyzer/peaks.py +953 -0
- tme/backends/__init__.py +171 -0
- tme/backends/_cupy_utils.py +734 -0
- tme/backends/_jax_utils.py +188 -0
- tme/backends/cupy_backend.py +294 -0
- tme/backends/jax_backend.py +314 -0
- tme/backends/matching_backend.py +1270 -0
- tme/backends/mlx_backend.py +241 -0
- tme/backends/npfftw_backend.py +583 -0
- tme/backends/pytorch_backend.py +430 -0
- tme/data/__init__.py +0 -0
- tme/data/c48n309.npy +0 -0
- tme/data/c48n527.npy +0 -0
- tme/data/c48n9.npy +0 -0
- tme/data/c48u1.npy +0 -0
- tme/data/c48u1153.npy +0 -0
- tme/data/c48u1201.npy +0 -0
- tme/data/c48u1641.npy +0 -0
- tme/data/c48u181.npy +0 -0
- tme/data/c48u2219.npy +0 -0
- tme/data/c48u27.npy +0 -0
- tme/data/c48u2947.npy +0 -0
- tme/data/c48u3733.npy +0 -0
- tme/data/c48u4749.npy +0 -0
- tme/data/c48u5879.npy +0 -0
- tme/data/c48u7111.npy +0 -0
- tme/data/c48u815.npy +0 -0
- tme/data/c48u83.npy +0 -0
- tme/data/c48u8649.npy +0 -0
- tme/data/c600v.npy +0 -0
- tme/data/c600vc.npy +0 -0
- tme/data/metadata.yaml +80 -0
- tme/data/quat_to_numpy.py +42 -0
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +2263 -0
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/external/bindings.cpp +332 -0
- tme/filters/__init__.py +6 -0
- tme/filters/_utils.py +311 -0
- tme/filters/bandpass.py +230 -0
- tme/filters/compose.py +81 -0
- tme/filters/ctf.py +393 -0
- tme/filters/reconstruction.py +160 -0
- tme/filters/wedge.py +542 -0
- tme/filters/whitening.py +191 -0
- tme/matching_data.py +863 -0
- tme/matching_exhaustive.py +497 -0
- tme/matching_optimization.py +1311 -0
- tme/matching_scores.py +1183 -0
- tme/matching_utils.py +1188 -0
- tme/memory.py +337 -0
- tme/orientations.py +598 -0
- tme/parser.py +685 -0
- tme/preprocessor.py +1329 -0
- tme/rotations.py +350 -0
- tme/structure.py +1864 -0
- tme/types.py +13 -0
tme/matching_utils.py
ADDED
@@ -0,0 +1,1188 @@
|
|
1
|
+
""" Utility functions for template matching.
|
2
|
+
|
3
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
"""
|
7
|
+
|
8
|
+
import os
|
9
|
+
import pickle
|
10
|
+
from shutil import move
|
11
|
+
from joblib import Parallel
|
12
|
+
from tempfile import mkstemp
|
13
|
+
from itertools import product
|
14
|
+
from concurrent.futures import ThreadPoolExecutor
|
15
|
+
from typing import Tuple, Dict, Callable, Optional
|
16
|
+
|
17
|
+
import numpy as np
|
18
|
+
from tqdm import tqdm
|
19
|
+
from scipy.spatial import ConvexHull
|
20
|
+
from scipy.ndimage import gaussian_filter
|
21
|
+
|
22
|
+
from .backends import backend as be
|
23
|
+
from .memory import estimate_memory_usage
|
24
|
+
from .types import NDArray, BackendArray
|
25
|
+
|
26
|
+
|
27
|
+
def noop(*args, **kwargs):
    """Accept arbitrary positional and keyword arguments, do nothing, return None."""
    return None
|
29
|
+
|
30
|
+
|
31
|
+
def identity(arr, *args):
    """Return ``arr`` unchanged; any extra positional arguments are ignored."""
    unchanged = arr
    return unchanged
|
33
|
+
|
34
|
+
|
35
|
+
def conditional_execute(
    func: Callable,
    execute_operation: bool,
    alt_func: Callable = noop,
) -> Callable:
    """
    Select ``func`` or a fallback callable depending on ``execute_operation``.

    Parameters
    ----------
    func : Callable
        Callable returned when ``execute_operation`` is truthy.
    execute_operation : bool
        Selector deciding which callable is returned.
    alt_func : Callable, optional
        Callable returned when ``execute_operation`` is falsy,
        :py:meth:`noop` by default.

    Returns
    -------
    Callable
        ``func`` if ``execute_operation`` else ``alt_func``.
    """
    if not execute_operation:
        return alt_func
    return func
|
59
|
+
|
60
|
+
|
61
|
+
def normalize_template(
    template: BackendArray, mask: BackendArray, n_observations: float, axis=None
) -> BackendArray:
    """
    Standardize ``template`` to zero mean and unit standard deviation within ``mask``.

    .. warning:: ``template`` is used as the output buffer of the operation.

    Parameters
    ----------
    template : BackendArray
        Input data.
    mask : BackendArray
        Mask of the same shape as ``template``.
    n_observations : float
        Sum of mask elements, used as the normalization constant.
    axis : tuple of ints, optional
        Axes to normalize over, all axes by default.

    Returns
    -------
    BackendArray
        Standardized input data multiplied by ``mask``.

    Notes
    -----
    The standard deviation is obtained as ``sqrt(max(E[x^2] - E[x]^2, 0))``.
    Regions with zero masked standard deviation are divided by zero; the
    resulting values depend on the backend (presumably non-finite).

    References
    ----------
    .. [1] Hrabe T. et al, J. Struct. Biol. 178, 177 (2012).
    """
    # First raw moment (masked mean).
    weighted_sum = be.sum(be.multiply(template, mask), axis=axis, keepdims=True)
    mean = be.divide(weighted_sum, n_observations)

    # Second raw moment; Var = E[x^2] - E[x]^2, clipped at zero against rounding.
    weighted_sq_sum = be.sum(
        be.multiply(be.square(template), mask), axis=axis, keepdims=True
    )
    variance = be.subtract(weighted_sq_sum / n_observations, be.square(mean))
    std = be.sqrt(be.maximum(variance, 0))

    # Reuse the template buffer for all remaining element-wise operations.
    result = be.subtract(template, mean, out=template)
    result = be.divide(result, std, out=result)
    return be.multiply(result, mask, out=result)
|
100
|
+
|
101
|
+
|
102
|
+
def _normalize_template_overflow_safe(
    template: BackendArray, mask: BackendArray, n_observations: float, axis=None
) -> BackendArray:
    """
    Overflow-safe variant of :py:meth:`normalize_template`.

    ``template`` and ``mask`` are cast to ``be._overflow_safe_dtype``, normalized
    in that precision and the result is written back into ``template`` in its
    original dtype.

    Parameters
    ----------
    template : BackendArray
        Input data, overwritten with the standardized values.
    mask : BackendArray
        Mask of the same shape as ``template``.
    n_observations : float
        Sum of mask elements.
    axis : tuple of ints, optional
        Axes to normalize over, all axes by default.

    Returns
    -------
    BackendArray
        ``template`` holding the standardized data.
    """
    safe_dtype = be._overflow_safe_dtype
    work_template = be.astype(template, safe_dtype)
    work_mask = be.astype(mask, safe_dtype)

    # Use the returned array rather than relying on in-place modification of
    # work_template, so backends whose ops return new arrays are handled too.
    normalized = normalize_template(
        template=work_template,
        mask=work_mask,
        n_observations=n_observations,
        axis=axis,
    )
    template[:] = be.astype(normalized, template.dtype)
    return template
|
112
|
+
|
113
|
+
|
114
|
+
def generate_tempfile_name(suffix: str = None) -> str:
    """
    Create an empty temporary file and return its path.

    The file is created by :py:func:`tempfile.mkstemp`, so the environment
    variable TMPDIR is honored as base directory if defined.

    Parameters
    ----------
    suffix : str, optional
        File suffix. By default the file has no suffix.

    Returns
    -------
    str
        Path to the generated file. The caller is responsible for removing it.
    """
    file_descriptor, path = mkstemp(suffix=suffix)
    # mkstemp returns an open OS-level descriptor; close it to avoid leaking
    # one file descriptor per call.
    os.close(file_descriptor)
    return path
|
130
|
+
|
131
|
+
|
132
|
+
def array_to_memmap(arr: NDArray, filename: str = None, mode: str = "r") -> np.memmap:
    """
    Write a :obj:`numpy.ndarray` to disk and reopen it as a :obj:`numpy.memmap`.

    Parameters
    ----------
    arr : :obj:`numpy.ndarray`
        Input data.
    filename : str, optional
        Destination path; :py:meth:`generate_tempfile_name` is used by default.
    mode : str, optional
        Mode used to open the returned memmap, defaults to 'r'.

    Returns
    -------
    :obj:`numpy.memmap`
        Memory-mapped view of the written data with the dtype and shape of ``arr``.
    """
    target = generate_tempfile_name() if filename is None else filename
    arr.tofile(target)
    return np.memmap(target, mode=mode, dtype=arr.dtype, shape=arr.shape)
|
155
|
+
|
156
|
+
|
157
|
+
def memmap_to_array(arr: NDArray) -> NDArray:
    """
    Convert a :obj:`numpy.memmap` to an in-memory array and delete its file.

    Non-memmap inputs are returned unchanged.

    Parameters
    ----------
    arr : :obj:`numpy.memmap`
        Input data.

    Returns
    -------
    :obj:`numpy.ndarray`
        In-memory copy of ``arr`` (or ``arr`` itself if it is not a memmap).
    """
    if not isinstance(arr, np.memmap):
        return arr

    backing_file = arr.filename
    in_memory = np.array(arr)
    # memmap instances that do not share an mmap (e.g. ``mm.copy()``) or that
    # were created from unnamed file objects have ``filename=None``; there is
    # nothing to delete in that case.
    if backing_file is not None:
        os.remove(backing_file)
    return in_memory
|
176
|
+
|
177
|
+
|
178
|
+
def write_pickle(data: object, filename: str) -> None:
    """
    Serialize and write data to a file, invalidating memmaps in the input.

    Each element of ``data`` (or ``data`` itself if it is not a list/tuple) is
    pickled sequentially into ``filename``. :obj:`numpy.memmap` elements are not
    pickled by value: their backing file is moved next to ``filename`` and a
    ``("np.memmap", shape, dtype, path)`` record is stored instead, so the
    original memmap file no longer exists afterwards.

    Parameters
    ----------
    data : iterable or object
        The data to be serialized.
    filename : str
        The name of the file where the serialized data will be written.

    Raises
    ------
    OSError
        If moving any memmap backing file fails.

    See Also
    --------
    :py:meth:`load_pickle`
    """
    if type(data) not in (list, tuple):
        data = (data,)

    dirname = os.path.dirname(filename)
    # Collect futures for every item so no failed move goes unnoticed.
    futures = []
    with open(filename, "wb") as ofile, ThreadPoolExecutor() as executor:
        for item in data:
            if isinstance(item, np.memmap):
                file_descriptor, new_filename = mkstemp(suffix=".mm", dir=dirname)
                # Only the path is needed; close the descriptor mkstemp opened.
                os.close(file_descriptor)
                new_item = ("np.memmap", item.shape, item.dtype, new_filename)
                futures.append(executor.submit(move, item.filename, new_filename))
                item = new_item
            pickle.dump(item, ofile)

        # Re-raise exceptions from any of the background moves.
        for future in futures:
            future.result()
|
209
|
+
|
210
|
+
|
211
|
+
def load_pickle(filename: str) -> object:
    """
    Load and deserialize data written by :py:meth:`write_pickle`.

    All pickled objects in ``filename`` are read sequentially. Records of the
    form ``("np.memmap", shape, dtype, path)`` are reopened as
    :obj:`numpy.memmap`.

    .. warning:: Unpickling can execute arbitrary code; only load trusted files.

    Parameters
    ----------
    filename : str
        The name of the file to read and deserialize data from.

    Returns
    -------
    object or iterable
        The single deserialized object if the file holds exactly one,
        otherwise a list of all deserialized objects.

    See Also
    --------
    :py:meth:`write_pickle`
    """

    def _load_pickle(file_handle):
        # Yield consecutive pickles until the end of the file is reached.
        try:
            while True:
                yield pickle.load(file_handle)
        except EOFError:
            pass

    def _is_pickle_memmap(data):
        # Matches the ("np.memmap", shape, dtype, path) records of write_pickle;
        # the length check also guards against empty tuples.
        return len(data) == 4 and isinstance(data[0], str) and data[0] == "np.memmap"

    items = []
    with open(filename, "rb") as ifile:
        for data in _load_pickle(ifile):
            if isinstance(data, tuple) and _is_pickle_memmap(data):
                _, shape, dtype, memmap_filename = data
                data = np.memmap(memmap_filename, shape=shape, dtype=dtype)
            items.append(data)
    return items[0] if len(items) == 1 else items
|
253
|
+
|
254
|
+
|
255
|
+
def compute_parallelization_schedule(
    shape1: NDArray,
    shape2: NDArray,
    max_cores: int,
    max_ram: int,
    matching_method: str,
    split_axes: Tuple[int] = None,
    backend: str = None,
    split_only_outer: bool = False,
    shape1_padding: NDArray = None,
    analyzer_method: str = None,
    max_splits: int = 256,
    float_nbytes: int = 4,
    complex_nbytes: int = 8,
    integer_nbytes: int = 4,
) -> Tuple[Optional[Dict], Optional[Tuple[int, int]]]:
    """
    Computes a parallelization schedule for a given computation.

    This function estimates the amount of memory that would be used by a computation
    and breaks down the computation into smaller parts that can be executed in parallel
    without exceeding the specified limits on the number of cores and memory.

    Parameters
    ----------
    shape1 : NDArray
        The shape of the first input array.
    shape2 : NDArray
        The shape of the second input array.
    max_cores : int
        The maximum number of cores that can be used.
    max_ram : int
        The maximum amount of memory that can be used.
    matching_method : str
        The metric used for scoring the computations.
    split_axes : tuple, optional
        Axes that can be used for splitting. By default all are considered.
    backend : str, optional
        Backend used for computations.
    split_only_outer : bool, optional
        Whether only outer splits should be considered.
    shape1_padding : NDArray, optional
        Padding added to each split of shape1, None (no padding) by default.
    analyzer_method : str, optional
        The method used for score analysis.
    max_splits : int, optional
        The maximum number of parts that the computation can be split into,
        by default 256. Schedules with more parts are never considered.
    float_nbytes : int
        Number of bytes of the used float, e.g. 4 for float32.
    complex_nbytes : int
        Number of bytes of the used complex, e.g. 8 for complex64.
    integer_nbytes : int
        Number of bytes of the used integer, e.g. 4 for int32.

    Notes
    -----
    This function assumes that no residual memory remains after each split,
    which not always holds true, e.g. when using
    :py:class:`tme.analyzer.MaxScoreOverRotations`.

    Returns
    -------
    dict
        The optimal splits for each axis of the first input tensor.
    tuple of ints
        (number of outer jobs, number of inner jobs per outer job).

    If no assignment fits into ``max_ram``, a message is printed and
    ``(None, None)`` is returned.
    """
    shape1 = tuple(int(x) for x in shape1)
    shape2 = tuple(int(x) for x in shape2)

    if shape1_padding is None:
        shape1_padding = np.zeros_like(shape1)

    # Enumerate factorizations max_cores = inner * outer. dict.fromkeys drops
    # the duplicate (i, i) pair produced for perfect squares, keeping order.
    core_assignments = []
    for i in range(1, int(max_cores**0.5) + 1):
        if max_cores % i == 0:
            core_assignments.append((i, max_cores // i))
            core_assignments.append((max_cores // i, i))
    core_assignments = list(dict.fromkeys(core_assignments))

    if split_only_outer:
        core_assignments = [(1, max_cores)]

    possible_params, split_axis = [], np.argmax(shape1)

    split_axis_index = split_axis
    if split_axes is not None:
        split_axis, split_axis_index = split_axes[0], 0
    else:
        split_axes = tuple(i for i in range(len(shape1)))

    split_factor = [1 for _ in range(len(shape1))]
    while True:
        splits = {k: split_factor[k] for k in range(len(split_factor))}
        n_splits = np.prod(list(splits.values()))
        # n_splits grows with every iteration; stop before evaluating a
        # schedule that exceeds the requested maximum number of parts.
        if n_splits > max_splits:
            break

        array_slices = split_shape(shape=shape1, splits=splits)
        array_widths = [
            tuple(x.stop - x.start for x in split) for split in array_slices
        ]

        for inner_cores, outer_cores in core_assignments:
            if outer_cores > n_splits:
                continue
            ram_usage = [
                estimate_memory_usage(
                    shape1=tuple(sum(x) for x in zip(shp, shape1_padding)),
                    shape2=shape2,
                    matching_method=matching_method,
                    analyzer_method=analyzer_method,
                    backend=backend,
                    ncores=inner_cores,
                    float_nbytes=float_nbytes,
                    complex_nbytes=complex_nbytes,
                    integer_nbytes=integer_nbytes,
                )
                for shp in array_widths
            ]
            # Peak usage of any batch of outer_cores splits running concurrently.
            max_usage = 0
            for i in range(0, len(ram_usage), outer_cores):
                usage = np.sum(ram_usage[i : (i + outer_cores)])
                if usage > max_usage:
                    max_usage = usage

            inits = n_splits // outer_cores
            if max_usage < max_ram:
                possible_params.append(
                    (*split_factor, outer_cores, inner_cores, n_splits, inits)
                )

        # Round-robin over the permitted split axes.
        split_factor[split_axis] += 1
        split_axis_index += 1
        if split_axis_index == len(split_axes):
            split_axis_index = 0
        split_axis = split_axes[split_axis_index]

    possible_params = np.array(possible_params)
    if not len(possible_params):
        print(
            "No suitable assignment found. Consider increasing "
            "max_ram or decrease max_cores."
        )
        return None, None

    # Prefer the fewest splits, then the fewest sequential rounds (inits).
    init = possible_params.shape[1] - 1
    possible_params = possible_params[
        np.lexsort((possible_params[:, init], possible_params[:, (init - 1)]))
    ]
    splits = {k: possible_params[0, k] for k in range(len(shape1))}
    core_assignment = (
        possible_params[0, len(shape1)],
        possible_params[0, (len(shape1) + 1)],
    )

    return splits, core_assignment
|
410
|
+
|
411
|
+
|
412
|
+
def _center_slice(current_shape: Tuple[int], new_shape: Tuple[int]) -> Tuple[slice]:
|
413
|
+
"""Extract the center slice of ``current_shape`` to retrieve ``new_shape``."""
|
414
|
+
new_shape = tuple(int(x) for x in new_shape)
|
415
|
+
current_shape = tuple(int(x) for x in current_shape)
|
416
|
+
starts = tuple((x - y) // 2 for x, y in zip(current_shape, new_shape))
|
417
|
+
stops = tuple(sum(stop) for stop in zip(starts, new_shape))
|
418
|
+
box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
|
419
|
+
return box
|
420
|
+
|
421
|
+
|
422
|
+
def centered(arr: BackendArray, new_shape: Tuple[int]) -> BackendArray:
    """
    Crop ``arr`` to its central region of shape ``new_shape``.

    Parameters
    ----------
    arr : BackendArray
        Input data.
    new_shape : tuple of ints
        Desired shape of the central region.

    Returns
    -------
    BackendArray
        ``arr`` indexed with the centered slices (a view for basic NumPy slicing).

    References
    ----------
    .. [1] https://github.com/scipy/scipy/blob/v1.11.2/scipy/signal/_signaltools.py#L388
    """
    center_box = _center_slice(arr.shape, new_shape=new_shape)
    cropped = arr[center_box]
    return cropped
|
444
|
+
|
445
|
+
|
446
|
+
def centered_mask(arr: BackendArray, new_shape: Tuple[int]) -> BackendArray:
    """
    Zero everything in ``arr`` outside its central region of shape ``new_shape``.

    ``arr`` is multiplied in place and returned; its shape is unchanged.

    Parameters
    ----------
    arr : BackendArray
        Input data, modified in place.
    new_shape : tuple of ints
        Shape of the central region that is kept.

    Returns
    -------
    BackendArray
        ``arr`` with the central portion kept and the rest set to 0.
    """
    center_box = _center_slice(arr.shape, new_shape=new_shape)
    keep = np.zeros_like(arr)
    keep[center_box] = 1
    arr *= keep
    return arr
|
467
|
+
|
468
|
+
|
469
|
+
def apply_convolution_mode(
    arr: BackendArray,
    convolution_mode: str,
    s1: Tuple[int],
    s2: Tuple[int],
    convolution_shape: Tuple[int] = None,
    mask_output: bool = False,
) -> BackendArray:
    """
    Crop (or mask) a convolution result according to ``convolution_mode``.

    Parameters
    ----------
    arr : BackendArray
        Array containing the convolution result of arrays with shape s1 and s2.
    convolution_mode : str
        Analogous to mode in :obj:`scipy.signal.convolve`:

        +---------+----------------------------------------------------------+
        | 'full'  | returns full template matching result of the inputs.     |
        +---------+----------------------------------------------------------+
        | 'valid' | returns elements that do not rely on zero-padding.       |
        +---------+----------------------------------------------------------+
        | 'same'  | output is the same size as s1.                           |
        +---------+----------------------------------------------------------+
    s1 : tuple of ints
        Shape of convolution array 1.
    s2 : tuple of ints
        Shape of convolution array 2.
    convolution_shape : tuple of ints, optional
        Size of the actually computed convolution, s1 + s2 - 1 by default.
    mask_output : bool, optional
        Whether to zero values outside of ``convolution_mode`` rather than
        removing them. Defaults to False. Masking is done in place on the
        cropped ``arr`` (a view of the input for NumPy basic slicing).

    Returns
    -------
    BackendArray
        The array after applying the convolution mode.

    Raises
    ------
    ValueError
        If ``convolution_mode`` is not one of 'full', 'same' or 'valid'.
    """
    # Drop the padding that was added to reach a fast Fourier length.
    if convolution_shape is None:
        convolution_shape = [s1[dim] + s2[dim] - 1 for dim in range(len(s1))]
    arr = arr[tuple(slice(0, extent) for extent in convolution_shape)]

    if convolution_mode not in ("full", "same", "valid"):
        raise ValueError("Supported convolution_mode are 'full', 'same' and 'valid'.")

    if convolution_mode == "full":
        return arr

    crop = centered_mask if mask_output else centered
    if convolution_mode == "same":
        return crop(arr, s1)

    valid_shape = [s1[dim] - s2[dim] + s2[dim] % 2 for dim in range(arr.ndim)]
    return crop(arr, valid_shape)
|
525
|
+
|
526
|
+
|
527
|
+
def compute_full_convolution_index(
    outer_shape: Tuple[int],
    inner_shape: Tuple[int],
    outer_split: Tuple[slice],
    inner_split: Tuple[slice],
) -> Tuple[slice]:
    """
    Locate the convolution of two array pieces within the full convolution.

    Parameters
    ----------
    outer_shape : tuple
        Shape of the outer array (not needed for the computation).
    inner_shape : tuple
        Shape of the inner array.
    outer_split : tuple
        Slices used to split the outer array (see :py:meth:`split_shape`).
    inner_split : tuple
        Slices used to split the inner array (see :py:meth:`split_shape`).

    Returns
    -------
    tuple
        One slice per axis giving the position of the piece-wise convolution
        in the full convolution.
    """
    inner_extent = np.asarray(inner_shape)

    outer_width = np.array([piece.stop - piece.start for piece in outer_split])
    inner_width = np.array([piece.stop - piece.start for piece in inner_split])
    piece_shape = outer_width + inner_width - 1

    inner_stop = np.array([piece.stop for piece in inner_split]).astype(int)
    outer_start = np.array([piece.start for piece in outer_split]).astype(int)
    offsets = outer_start + inner_extent - inner_stop

    return tuple(
        slice(begin, begin + extent) for begin, extent in zip(offsets, piece_shape)
    )
|
571
|
+
|
572
|
+
|
573
|
+
def split_shape(
    shape: Tuple[int], splits: Dict, equal_shape: bool = True
) -> Tuple[slice]:
    """
    Split ``shape`` into (potentially overlapping) subsets.

    Parameters
    ----------
    shape : tuple of ints
        Shape to split.
    splits : dict
        Mapping of axis number to number of splits. Missing axes and values
        below one are treated as a single split.
    equal_shape : bool, optional
        Whether all subsets should have equal shape, True by default. If True,
        the per-axis length is rounded up and the last subset is shifted to end
        at the boundary (overlapping its predecessor). Otherwise the length is
        rounded down and the last subset extends to the boundary.

    Returns
    -------
    tuple
        Tuple of slice tuples, one per combination of per-axis splits.
    """
    ndim = len(shape)
    counts = [max(splits.get(axis, 1), 1) for axis in range(ndim)]
    lengths = np.divide(shape, tuple(counts))
    if equal_shape:
        lengths = np.ceil(lengths)
    lengths = lengths.astype(int)

    per_axis = []
    for axis, (length, count) in enumerate(zip(lengths, counts)):
        last = count - 1
        axis_slices = []
        for index in range(count):
            if index < last:
                axis_slices.append(slice(index * length, (index + 1) * length))
            elif equal_shape:
                axis_slices.append(slice(shape[axis] - length, shape[axis]))
            else:
                axis_slices.append(slice(index * length, shape[axis]))
        per_axis.append(tuple(axis_slices))

    return tuple(product(*per_axis))
|
619
|
+
|
620
|
+
|
621
|
+
def rigid_transform(
    coordinates: NDArray,
    rotation_matrix: NDArray,
    out: NDArray,
    translation: NDArray,
    use_geometric_center: bool = False,
    coordinates_mask: NDArray = None,
    out_mask: NDArray = None,
    center: NDArray = None,
) -> None:
    """
    Apply a rigid transformation (rotation and translation) to given coordinates.

    The result is written into ``out`` (and ``out_mask``) in place; nothing is
    returned.

    Parameters
    ----------
    coordinates : NDArray
        An array representing the coordinates to be transformed (d,n).
    rotation_matrix : NDArray
        The rotation matrix to be applied (d,d).
    out : NDArray
        The output array to store the transformed coordinates (d,n).
    translation : NDArray
        The translation vector to be applied (d,). It is not modified; a new
        array is derived from it.
    use_geometric_center : bool, optional
        If False (default), coordinates are rotated around ``center``. If True,
        the uncentered coordinates are rotated and the translation is derived
        from the bounding box (max/min) of the rotated coordinates.
    coordinates_mask : NDArray, optional
        An array representing the mask coordinates (d,t). Only transformed if
        ``out_mask`` is given as well.
    out_mask : NDArray, optional
        The output array to store the transformed mask coordinates (d,t).
    center : NDArray, optional
        Rotation center (d,), ``coordinates.mean(axis=1)`` by default.
    """
    coordinate_dtype = coordinates.dtype
    center = coordinates.mean(axis=1) if center is None else center
    if not use_geometric_center:
        # Creates a new array; the caller's coordinates are not modified.
        coordinates = coordinates - center[:, None]

    np.matmul(rotation_matrix, coordinates, out=out)
    if use_geometric_center:
        axis_max, axis_min = out.max(axis=1), out.min(axis=1)
        axis_difference = axis_max - axis_min
        translation = np.add(translation, center - axis_max + (axis_difference // 2))
    else:
        # Shift the rotated cloud so its mean lands back on ``center``.
        translation = np.add(translation, np.subtract(center, out.mean(axis=1)))

    out += translation[:, None]
    if coordinates_mask is not None and out_mask is not None:
        if not use_geometric_center:
            coordinates_mask = coordinates_mask - center[:, None]
        np.matmul(rotation_matrix, coordinates_mask, out=out_mask)
        # The mask reuses the translation computed for the coordinates.
        out_mask += translation[:, None]

    # NOTE(review): when out's dtype differs from the input dtype, this adds the
    # difference between the mean of out and the mean of its int-cast version,
    # presumably to compensate truncation in a later integer cast — confirm.
    # ``translation`` here is the local array created by np.add above.
    if not use_geometric_center and coordinate_dtype != out.dtype:
        np.subtract(out.mean(axis=1), out.astype(int).mean(axis=1), out=translation)
        out += translation[:, None]
|
674
|
+
|
675
|
+
|
676
|
+
def minimum_enclosing_box(
    coordinates: NDArray, margin: NDArray = None, use_geometric_center: bool = False
) -> Tuple[int]:
    """
    Computes the minimal enclosing box around coordinates with margin.

    Parameters
    ----------
    coordinates : NDArray
        Coordinates of shape (d,n) to compute the enclosing box of.
    margin : NDArray or int, optional
        Margin added to the box along each axis (cast to int), zero by default.
    use_geometric_center : bool, optional
        Whether box accommodates the geometric or coordinate center, False by default.

    Returns
    -------
    NDArray of ints
        Minimum enclosing box shape (d,), including ``margin``.
    """
    point_cloud = np.asarray(coordinates)
    dim = point_cloud.shape[0]
    point_cloud = point_cloud - point_cloud.min(axis=1)[:, None]

    margin = np.zeros(dim) if margin is None else margin
    margin = np.asarray(margin).astype(int)

    # Box edge of twice the largest distance to the coordinate mean, so any
    # rotation around the mean stays inside the box.
    norm_cloud = point_cloud - point_cloud.mean(axis=1)[:, None]
    # Adding one avoids clipping during scipy.ndimage.affine_transform
    shape = np.repeat(
        np.ceil(2 * np.linalg.norm(norm_cloud, axis=0).max()) + 1, dim
    ).astype(int)
    if use_geometric_center:
        # The compiled extension is only needed for the geometric-center path.
        from .extensions import max_euclidean_distance

        hull = ConvexHull(point_cloud.T)
        distance, _ = max_euclidean_distance(point_cloud[:, hull.vertices].T)
        distance += np.linalg.norm(np.ones(dim))
        shape = np.repeat(np.rint(distance).astype(int), dim)

    # Apply the documented margin, which was previously parsed but ignored.
    return np.add(shape, margin).astype(int)
|
717
|
+
|
718
|
+
|
719
|
+
def create_mask(mask_type: str, sigma_decay: float = 0, **kwargs) -> NDArray:
    """
    Create a mask of the requested type, optionally with smoothed edges.

    Parameters
    ----------
    mask_type : str
        Type of the mask to be created. Can be one of:

        +---------+----------------------------------------------------------+
        | box     | Box mask (see :py:meth:`box_mask`)                       |
        +---------+----------------------------------------------------------+
        | tube    | Cylindrical mask (see :py:meth:`tube_mask`)              |
        +---------+----------------------------------------------------------+
        | ellipse | Ellipsoidal mask (see :py:meth:`elliptical_mask`)        |
        +---------+----------------------------------------------------------+
    sigma_decay : float, optional
        Standard deviation of a Gaussian used to smooth the mask edges; values
        <= 0 (default) disable smoothing. After smoothing, values below
        ``exp(-sigma_decay**2)`` are set to zero.
    kwargs : dict
        Parameters passed to the individual mask creation functions.

    Returns
    -------
    NDArray
        The created mask.

    Raises
    ------
    ValueError
        If the mask_type is invalid.
    """
    builders = {"ellipse": elliptical_mask, "box": box_mask, "tube": tube_mask}
    if mask_type not in builders:
        raise ValueError(f"mask_type has to be one of {','.join(builders.keys())}")

    mask = builders[mask_type](**kwargs)
    if not sigma_decay > 0:
        return mask

    # Only the region outside the mask receives the smoothed contribution.
    smoothed = gaussian_filter(mask.astype(np.float32), sigma=sigma_decay)
    mask = np.add(mask, (1 - mask) * smoothed)
    cutoff = np.exp(-np.square(sigma_decay))
    mask[mask < cutoff] = 0
    return mask
|
761
|
+
|
762
|
+
|
763
|
+
def elliptical_mask(
    shape: Tuple[int],
    radius: Tuple[float],
    center: Optional[Tuple[float]] = None,
    orientation: Optional[NDArray] = None,
) -> NDArray:
    """
    Create an ellipsoidal mask on a regular grid.

    Parameters
    ----------
    shape : tuple of ints
        Shape of the mask to be created.
    radius : tuple of floats
        Semi-axis length per dimension. A single value is used for all axes.
    center : tuple of floats, optional
        Center of the ellipsoid. Defaults to ``shape // 2``. A single value
        is used for all axes.
    orientation : NDArray, optional.
        Rotation matrix with shape (d,d) applied to the grid offsets before
        the ellipsoid test.

    Returns
    -------
    NDArray
        Integer array with 1 inside (or on) the ellipsoid and 0 elsewhere.

    Raises
    ------
    ValueError
        If the length of center and radius is not one or the same as shape.

    Examples
    --------
    >>> from tme.matching_utils import elliptical_mask
    >>> mask = elliptical_mask(shape=(20,20), radius=(5,5), center=(10,10))
    """
    shape = np.asarray(shape).astype(int)
    radius = np.asarray(radius)
    ndim = shape.size

    if center is None:
        center = np.divide(shape, 2).astype(int)
    center = np.asarray(center, dtype=np.float32)

    # Broadcast scalar radius / center to one value per axis.
    radius = np.repeat(radius, ndim // radius.size)
    center = np.repeat(center, ndim // center.size)
    if radius.size != ndim:
        raise ValueError("Length of radius has to be either one or match shape.")
    if center.size != ndim:
        raise ValueError("Length of center has to be either one or match shape.")

    # Shape (ndim, 1, ..., 1) so per-axis values broadcast over the grid.
    per_axis = (-1,) + (1,) * ndim
    offsets = np.indices(shape, dtype=np.float32) - center.reshape(per_axis)

    if orientation is not None:
        grid_shape = offsets.shape
        flat = offsets.reshape(ndim, -1)
        rigid_transform(
            coordinates=flat,
            rotation_matrix=np.asarray(orientation),
            out=flat,
            translation=np.zeros(ndim),
            use_geometric_center=False,
        )
        offsets = flat.reshape(*grid_shape)

    scaled_distance = np.linalg.norm(offsets / radius.reshape(per_axis), axis=0)
    return (scaled_distance <= 1).astype(int)
|
833
|
+
|
834
|
+
|
835
|
+
def tube_mask2(
    shape: Tuple[int],
    inner_radius: float,
    outer_radius: float,
    height: int,
    symmetry_axis: Optional[int] = 2,
    center: Optional[Tuple[float]] = None,
    orientation: Optional[NDArray] = None,
    epsilon: float = 0.5,
) -> NDArray:
    """
    Creates a (hollow) tube mask by thresholding distances on a coordinate grid.

    Parameters
    ----------
    shape : tuple
        Shape of the mask to be created.
    inner_radius : float
        Inner radius of the tube. Values not larger than ``epsilon`` produce a
        solid cylinder (no inner cut-out).
    outer_radius : float
        Outer radius of the tube.
    height : int
        Height of the tube along ``symmetry_axis``.
    symmetry_axis : int, optional
        The axis of symmetry for the tube, defaults to 2. Negative values
        count from the last axis.
    center : tuple of float, optional.
        Center of the mask, defaults to shape // 2.
    orientation : NDArray, optional.
        Orientation of the mask as rotation matrix with shape (d,d).
    epsilon : float, optional
        Tolerance to handle discretization errors, defaults to 0.5. It is
        added to the squared outer radius and the half height, and subtracted
        from the squared inner radius.

    Returns
    -------
    NDArray
        Boolean tube mask with the given shape.

    Raises
    ------
    ValueError
        If ``inner_radius`` is larger than ``outer_radius``.
        If ``symmetry_axis`` is not a valid axis of ``shape``.
        If ``center`` and ``shape`` do not have the same length.
    """
    shape = np.asarray(shape, dtype=int)
    n = shape.size

    if center is None:
        center = np.divide(shape, 2).astype(int)

    center = np.asarray(center, dtype=np.float32)
    center = np.repeat(center, n // center.size)
    if inner_radius > outer_radius:
        raise ValueError("inner_radius should be smaller than outer_radius.")
    # Valid axes are -n..n-1; the former ``> len(shape)`` check let
    # symmetry_axis == n through, which then failed with an IndexError.
    if not -n <= symmetry_axis < n:
        raise ValueError(
            f"symmetry_axis has to be in [{-n}, {n - 1}] for a {n}D shape."
        )
    # Normalize negative axes so the radial loop below skips the right axis.
    symmetry_axis = symmetry_axis % n
    if center.size != n:
        raise ValueError("Length of center has to be either one or match shape.")

    center = center.reshape((-1,) + (1,) * n)
    indices = np.indices(shape, dtype=np.float32) - center
    if orientation is not None:
        return_shape = indices.shape
        indices = indices.reshape(n, -1)
        rigid_transform(
            coordinates=indices,
            rotation_matrix=np.asarray(orientation),
            out=indices,
            translation=np.zeros(n),
            use_geometric_center=False,
        )
        indices = indices.reshape(*return_shape)

    # Squared distance from the symmetry axis (sum over all other axes).
    sq_dist = np.zeros(shape)
    for axis in range(n):
        if axis != symmetry_axis:
            sq_dist += indices[axis] ** 2

    height_mask = np.abs(indices[symmetry_axis]) <= (height / 2 + epsilon)
    outer_mask = sq_dist <= (outer_radius**2 + epsilon)

    # ``True`` (not ``1``) keeps the combined result boolean when there is no
    # inner cut-out.
    inner_mask = True
    if inner_radius > epsilon:
        inner_mask = sq_dist >= (inner_radius**2 - epsilon)

    return height_mask & inner_mask & outer_mask
|
927
|
+
|
928
|
+
|
929
|
+
def box_mask(shape: Tuple[int], center: Tuple[int], height: Tuple[int]) -> np.ndarray:
    """
    Creates a box mask centered around the provided center point.

    Parameters
    ----------
    shape : tuple of ints
        Shape of the output array.
    center : tuple of ints
        Center point coordinates of the box.
    height : tuple of ints
        Nominal side length of the box along each axis.

    Returns
    -------
    NDArray
        Float array that is 1 inside the box and 0 elsewhere.

    Raises
    ------
    ValueError
        If ``shape``, ``center`` and ``height`` do not have the same length.

    Notes
    -----
    Along each axis the box covers the indices from ``center - height // 2``
    up to and including ``center + height // 2 + height % 2``, clipped to
    ``shape``. When not clipped, that is ``height + 1`` elements.
    NOTE(review): this is one more than ``height``; confirm it is intended.
    """
    if len(shape) != len(center) or len(center) != len(height):
        raise ValueError("The length of shape, center, and height must be consistent.")

    out_shape = tuple(int(dim) for dim in shape)
    mid = np.array(center, dtype=int)
    extent = np.array(height, dtype=int)

    lower = np.maximum(mid - extent // 2, 0)
    upper = np.minimum(mid + extent // 2 + extent % 2 + 1, out_shape)
    region = tuple(slice(lo, hi) for lo, hi in zip(lower, upper))

    box = np.zeros(out_shape)
    box[region] = 1
    return box
|
967
|
+
|
968
|
+
|
969
|
+
def tube_mask(
    shape: Tuple[int],
    symmetry_axis: int,
    base_center: Tuple[int],
    inner_radius: float,
    outer_radius: float,
    height: int,
) -> NDArray:
    """
    Creates a tube mask by stacking a 2D annulus along ``symmetry_axis``.

    Parameters
    ----------
    shape : tuple
        Shape of the mask to be created.
    symmetry_axis : int
        The axis of symmetry for the tube. Negative values count from the
        last axis.
    base_center : tuple
        Center of the tube.
    inner_radius : float
        Inner radius of the tube. No inner cut-out is made if it is <= 0.
    outer_radius : float
        Outer radius of the tube.
    height : int
        Height of the tube along ``symmetry_axis``.

    Returns
    -------
    NDArray
        The created tube mask.

    Raises
    ------
    ValueError
        If ``inner_radius`` is larger than ``outer_radius``.
        If ``symmetry_axis`` is not a valid axis of ``shape``.
        If ``height`` is larger than ``shape`` along the symmetry axis.
        If ``base_center`` and ``shape`` do not have the same length.
    """
    if inner_radius > outer_radius:
        raise ValueError("inner_radius should be smaller than outer_radius.")

    ndim = len(shape)
    # Validate the axis before using it to index ``shape``; the former
    # ``> len(shape)`` check also let symmetry_axis == len(shape) through.
    if not -ndim <= symmetry_axis < ndim:
        raise ValueError(
            f"symmetry_axis has to be in [{-ndim}, {ndim - 1}] for a {ndim}D shape."
        )
    # Normalize negative axes so the ``ix != symmetry_axis`` filters below
    # drop the correct dimension.
    symmetry_axis = symmetry_axis % ndim

    if height > shape[symmetry_axis]:
        raise ValueError(f"Height can be no larger than {shape[symmetry_axis]}.")

    if len(base_center) != ndim:
        raise ValueError("shape and base_center need to have the same length.")

    shape = tuple(int(x) for x in shape)
    circle_shape = tuple(b for ix, b in enumerate(shape) if ix != symmetry_axis)
    circle_center = tuple(b for ix, b in enumerate(base_center) if ix != symmetry_axis)

    inner_circle = np.zeros(circle_shape)
    outer_circle = np.zeros_like(inner_circle)
    if inner_radius > 0:
        inner_circle = create_mask(
            mask_type="ellipse",
            shape=circle_shape,
            radius=inner_radius,
            center=circle_center,
        )
    if outer_radius > 0:
        outer_circle = create_mask(
            mask_type="ellipse",
            shape=circle_shape,
            radius=outer_radius,
            center=circle_center,
        )
    # Singleton dimension along the symmetry axis so the annulus broadcasts
    # over the selected height range.
    circle = np.expand_dims(outer_circle - inner_circle, axis=symmetry_axis)

    axis_center = base_center[symmetry_axis]
    start_idx = max(int(axis_center - height // 2), 0)
    stop_idx = min(int(axis_center + height // 2 + height % 2), shape[symmetry_axis])

    slice_indices = tuple(
        slice(start_idx, stop_idx) if i == symmetry_axis else slice(None)
        for i in range(ndim)
    )
    tube = np.zeros(shape)
    tube[slice_indices] = circle

    return tube
|
1056
|
+
|
1057
|
+
|
1058
|
+
def scramble_phases(
    arr: NDArray,
    noise_proportion: float = 0.5,
    seed: int = 42,
    normalize_power: bool = False,
) -> NDArray:
    """
    Perform random phase scrambling of ``arr``.

    The Fourier amplitudes of ``arr`` are kept. The phases are blended with a
    randomly permuted copy of themselves.

    Parameters
    ----------
    arr : NDArray
        Input data.
    noise_proportion : float, optional
        Weight of the permuted phases in the blend, clipped to [0, 1].
        0.5 by default.
    seed : int, optional
        The seed for the random phase scrambling, 42 by default. A private
        random generator is used, so numpy's global random state is left
        untouched.
    normalize_power : bool, optional
        If True, rescale the result to the value range of ``arr``, then scale
        it so that its sum of absolute values equals that of ``arr``.

    Returns
    -------
    NDArray
        Phase scrambled version of ``arr``.
    """
    # A dedicated legacy RandomState yields the same stream as
    # ``np.random.seed(seed)`` followed by ``np.random.permutation`` without
    # mutating the global generator as a side effect.
    rng = np.random.RandomState(seed)
    noise_proportion = max(min(noise_proportion, 1), 0)

    arr_fft = np.fft.fftn(arr)
    amp, ph = np.abs(arr_fft), np.angle(arr_fft)

    # NOTE(review): permutation() shuffles only along the first axis of ``ph``,
    # so for N-D input whole slices of phases are exchanged rather than
    # individual phases. Confirm this is intended.
    ph_noise = rng.permutation(ph)
    ph_new = ph * (1 - noise_proportion) + ph_noise * noise_proportion
    ret = np.real(np.fft.ifftn(amp * np.exp(1j * ph_new)))

    if normalize_power:
        # Min-max rescale onto the value range of ``arr`` ...
        np.divide(ret - ret.min(), ret.max() - ret.min(), out=ret)
        np.multiply(ret, np.subtract(arr.max(), arr.min()), out=ret)
        np.add(ret, arr.min(), out=ret)
        # ... then match the sum of absolute values of ``arr``.
        scaling = np.divide(np.abs(arr).sum(), np.abs(ret).sum())
        np.multiply(ret, scaling, out=ret)

    return ret
|
1101
|
+
|
1102
|
+
|
1103
|
+
def compute_extraction_box(
    centers: BackendArray, extraction_shape: Tuple[int], original_shape: Tuple[int]
):
    """
    Compute coordinates for extracting fixed-size regions around points.

    Parameters
    ----------
    centers : BackendArray
        Array of shape (n, d) containing n center coordinates in d dimensions.
    extraction_shape : tuple of int
        Desired shape of the extraction box.
    original_shape : tuple of int
        Shape of the original array from which extractions will be made.

    Returns
    -------
    obs_beg : BackendArray
        Starting coordinates for extraction, clipped to ``original_shape``,
        shape (n, d).
    obs_end : BackendArray
        Ending coordinates for extraction, clipped to ``original_shape``,
        shape (n, d).
    cand_beg : BackendArray
        Starting coordinates in output array, shape (n, d).
    cand_end : BackendArray
        Ending coordinates in output array, shape (n, d).
    keep : BackendArray
        Boolean mask of valid extraction boxes, shape (n,). An entry is True
        only if the full box fits without any clipping.
    """
    bounds = be.to_backend_array(original_shape)
    box_shape = be.to_backend_array(extraction_shape)

    # Integer half-widths; for odd box sizes the extra element goes to the
    # right-hand side.
    pad_lo = be.astype(be.divide(box_shape, 2), int)
    pad_hi = be.astype(be.add(pad_lo, be.mod(box_shape, 2)), int)

    start = be.subtract(centers, pad_lo)
    stop = be.add(centers, pad_hi)
    start_clamped = be.maximum(start, 0)
    stop_clamped = be.minimum(stop, bounds)

    # Per-center count of coordinates that had to be clipped at either end.
    n_clamped = be.sum(be.add(start != start_clamped, stop != stop_clamped), axis=1)

    # Where the (possibly clipped) region lands inside the output box.
    out_start = pad_lo - be.subtract(centers, start_clamped)
    out_stop = pad_lo + be.subtract(stop_clamped, centers)

    fits = be.multiply(out_start == 0, be.subtract(out_stop, box_shape) == 0)
    keep = be.sum(fits, axis=1) == centers.shape[1]
    keep = be.multiply(keep, n_clamped == 0)

    return start_clamped, stop_clamped, out_start, out_stop, keep
|
1154
|
+
|
1155
|
+
|
1156
|
+
class TqdmParallel(Parallel):
    """
    A minimal :py:class:`joblib.Parallel` subclass reporting progress via tqdm.

    Parameters
    ----------
    tqdm_args : dict, optional
        Keyword arguments passed to ``tqdm.tqdm``. None (the default) is
        treated as an empty dict.
    *args, **kwargs
        Arguments passed to ``joblib.Parallel``.
    """

    def __init__(self, tqdm_args: Optional[Dict] = None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # ``None`` default instead of a shared mutable ``{}`` default.
        self.pbar = tqdm(**(tqdm_args or {}))
        # Defined up front so print_progress never reads a missing attribute;
        # __call__ sets the real value.
        self.n_tasks = None

    def __call__(self, iterable, *args, **kwargs):
        # Only sized iterables give a known total for the progress bar.
        self.n_tasks = len(iterable) if hasattr(iterable, "__len__") else None
        return super().__call__(iterable, *args, **kwargs)

    def print_progress(self):
        # Unknown total: defer to joblib's own progress reporting.
        if self.n_tasks is None:
            return super().print_progress()

        if self.n_tasks != self.pbar.total:
            self.pbar.total = self.n_tasks
            self.pbar.refresh()

        self.pbar.n = self.n_completed_tasks
        self.pbar.refresh()

        if self.n_completed_tasks >= self.n_tasks:
            self.pbar.close()
|