pytme 0.1.5__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.1.5.data/scripts/estimate_ram_usage.py +81 -0
- pytme-0.1.5.data/scripts/match_template.py +744 -0
- pytme-0.1.5.data/scripts/postprocess.py +279 -0
- pytme-0.1.5.data/scripts/preprocess.py +93 -0
- pytme-0.1.5.data/scripts/preprocessor_gui.py +729 -0
- pytme-0.1.5.dist-info/LICENSE +153 -0
- pytme-0.1.5.dist-info/METADATA +69 -0
- pytme-0.1.5.dist-info/RECORD +63 -0
- pytme-0.1.5.dist-info/WHEEL +5 -0
- pytme-0.1.5.dist-info/entry_points.txt +6 -0
- pytme-0.1.5.dist-info/top_level.txt +2 -0
- scripts/__init__.py +0 -0
- scripts/estimate_ram_usage.py +81 -0
- scripts/match_template.py +744 -0
- scripts/match_template_devel.py +788 -0
- scripts/postprocess.py +279 -0
- scripts/preprocess.py +93 -0
- scripts/preprocessor_gui.py +729 -0
- tme/__init__.py +6 -0
- tme/__version__.py +1 -0
- tme/analyzer.py +1144 -0
- tme/backends/__init__.py +134 -0
- tme/backends/cupy_backend.py +309 -0
- tme/backends/matching_backend.py +1154 -0
- tme/backends/npfftw_backend.py +763 -0
- tme/backends/pytorch_backend.py +526 -0
- tme/data/__init__.py +0 -0
- tme/data/c48n309.npy +0 -0
- tme/data/c48n527.npy +0 -0
- tme/data/c48n9.npy +0 -0
- tme/data/c48u1.npy +0 -0
- tme/data/c48u1153.npy +0 -0
- tme/data/c48u1201.npy +0 -0
- tme/data/c48u1641.npy +0 -0
- tme/data/c48u181.npy +0 -0
- tme/data/c48u2219.npy +0 -0
- tme/data/c48u27.npy +0 -0
- tme/data/c48u2947.npy +0 -0
- tme/data/c48u3733.npy +0 -0
- tme/data/c48u4749.npy +0 -0
- tme/data/c48u5879.npy +0 -0
- tme/data/c48u7111.npy +0 -0
- tme/data/c48u815.npy +0 -0
- tme/data/c48u83.npy +0 -0
- tme/data/c48u8649.npy +0 -0
- tme/data/c600v.npy +0 -0
- tme/data/c600vc.npy +0 -0
- tme/data/metadata.yaml +80 -0
- tme/data/quat_to_numpy.py +42 -0
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +2314 -0
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/helpers.py +881 -0
- tme/matching_data.py +377 -0
- tme/matching_exhaustive.py +1553 -0
- tme/matching_memory.py +382 -0
- tme/matching_optimization.py +1123 -0
- tme/matching_utils.py +1180 -0
- tme/parser.py +429 -0
- tme/preprocessor.py +1291 -0
- tme/scoring.py +866 -0
- tme/structure.py +1428 -0
- tme/types.py +10 -0
tme/matching_utils.py
ADDED
@@ -0,0 +1,1180 @@
|
|
1
|
+
""" Utility functions for template matching.
|
2
|
+
|
3
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
"""
|
7
|
+
import os
|
8
|
+
import traceback
|
9
|
+
import pickle
|
10
|
+
from shutil import move
|
11
|
+
from tempfile import mkstemp
|
12
|
+
from itertools import product
|
13
|
+
from typing import Tuple, Dict, Callable
|
14
|
+
from concurrent.futures import ThreadPoolExecutor
|
15
|
+
|
16
|
+
import numpy as np
|
17
|
+
from numpy.typing import NDArray
|
18
|
+
from scipy.spatial import ConvexHull
|
19
|
+
from scipy.spatial.transform import Rotation
|
20
|
+
|
21
|
+
from .helpers import quaternion_to_rotation_matrix, load_quaternions_by_angle
|
22
|
+
from .extensions import max_euclidean_distance
|
23
|
+
from .matching_memory import estimate_ram_usage
|
24
|
+
|
25
|
+
|
26
|
+
def handle_traceback(last_type, last_value, last_traceback):
    """
    Print the traceback of a captured exception and raise it again.

    The three arguments correspond to the triple returned by
    ``sys.exc_info()``.

    Parameters
    ----------
    last_type : type
        The type of the last exception, or None if no exception is pending.
    last_value :
        The value of the last exception.
    last_traceback : traceback
        The traceback object encapsulating the call stack at the point
        where the exception originally occurred.

    Returns
    -------
    None
        Only when ``last_type`` is None, i.e. there is nothing to handle.

    Raises
    ------
    Exception
        A generic ``Exception`` wrapping ``last_value``.
    """
    if last_type is not None:
        traceback.print_tb(last_traceback)
        # NOTE: the error is re-raised as a plain Exception carrying
        # last_value, not as an instance of last_type.
        raise Exception(last_value)
    return None
|
50
|
+
|
51
|
+
|
52
|
+
def generate_tempfile_name(suffix=None):
    """
    Create an empty temporary file and return its path.

    If the environment variable ``TMPDIR`` is defined, the temporary file
    is created in that directory. Otherwise the default directory chosen by
    :py:func:`tempfile.mkstemp` is used.

    Parameters
    ----------
    suffix : str, optional
        File suffix. By default the file has no suffix.

    Returns
    -------
    str
        Path of the created file. The file exists on disk and the caller
        is responsible for removing it.
    """
    tmp_dir = os.environ.get("TMPDIR", None)
    file_descriptor, filename = mkstemp(suffix=suffix, dir=tmp_dir)
    # mkstemp hands back an open OS-level descriptor. Only the path is
    # needed, so close it right away instead of leaking one per call.
    os.close(file_descriptor)
    return filename
|
71
|
+
|
72
|
+
|
73
|
+
def array_to_memmap(arr: NDArray, filename: str = None) -> str:
    """
    Write a numpy array to disk as a np.memmap.

    Parameters
    ----------
    arr : np.ndarray
        The numpy array to be written.
    filename : str, optional
        Desired filename for the memmap. If not provided, a temporary
        file obtained from :py:meth:`generate_tempfile_name` is used.

    Returns
    -------
    str
        The filename where the memmap was written to.
    """
    target = generate_tempfile_name() if filename is None else filename

    # The on-disk buffer mirrors the shape and dtype of the input array.
    mapped = np.memmap(target, mode="w+", dtype=arr.dtype, shape=arr.shape)
    mapped[:] = arr[:]
    mapped.flush()

    return target
|
106
|
+
|
107
|
+
|
108
|
+
def memmap_to_array(arr: NDArray) -> NDArray:
    """
    Convert a np.memmap into an in-memory numpy array.

    The file backing the memmap is deleted after its content has been
    copied. Inputs that are not a np.memmap are returned unchanged.

    Parameters
    ----------
    arr : np.memmap
        The array to be converted.

    Returns
    -------
    np.ndarray
        The converted array.
    """
    # isinstance also covers np.memmap subclasses.
    if not isinstance(arr, np.memmap):
        return arr

    memmap_filepath = arr.filename
    in_memory = np.array(arr)
    # A memmap built from an unnamed file object has no path to delete.
    if memmap_filepath is not None:
        os.remove(memmap_filepath)
    return in_memory
|
127
|
+
|
128
|
+
|
129
|
+
def close_memmap(arr: np.ndarray) -> None:
    """
    Remove the file associated with a numpy memmap array.

    This is a best-effort cleanup. Arrays without a ``filename`` attribute
    are ignored, and a backing file that is already gone or cannot be
    removed does not raise.

    Parameters
    ----------
    arr : np.ndarray
        The numpy array which might be a memmap.
    """
    filename = getattr(arr, "filename", None)
    if filename is None:
        # Plain arrays, and memmaps without a backing path, have no file.
        return None

    try:
        os.remove(filename)
    except (OSError, TypeError, ValueError):
        # Swallow only the errors an unlink can raise (missing file,
        # permissions, malformed path) rather than every Exception.
        pass
|
143
|
+
|
144
|
+
|
145
|
+
def write_pickle(data: object, filename: str) -> None:
    """
    Serialize and write data to a file invalidating the input data in
    the process. This function uses type-specific serialization for
    certain objects, such as np.memmap, for optimized storage. Other
    objects are serialized using standard pickle.

    A np.memmap item is not pickled by value. Its backing file is moved
    next to ``filename`` (suffix ``.mm``) and the tuple
    ``("np.memmap", shape, dtype, new_filename)`` is pickled instead.

    Parameters
    ----------
    data : iterable or object
        The data to be serialized. Anything that is not exactly a list or
        tuple is treated as a single item.
    filename : str
        The name of the file where the serialized data will be written.

    Raises
    ------
    Exception
        Any error raised while moving a memmap file is propagated.

    See Also
    --------
    :py:meth:`load_pickle`
    """
    if type(data) not in (list, tuple):
        data = (data,)

    dirname = os.path.dirname(filename)
    # Collected across all items so every move is awaited, and so an empty
    # input still leaves the name defined.
    futures = []
    with open(filename, "wb") as ofile, ThreadPoolExecutor() as executor:
        for item in data:
            if isinstance(item, np.memmap):
                descriptor, new_filename = mkstemp(suffix=".mm", dir=dirname)
                # Only the reserved path is needed, not the open handle.
                os.close(descriptor)
                futures.append(executor.submit(move, item.filename, new_filename))
                item = ("np.memmap", item.shape, item.dtype, new_filename)
            pickle.dump(item, ofile)

        # Surface any exception raised by the background moves.
        for future in futures:
            future.result()
|
179
|
+
|
180
|
+
|
181
|
+
def load_pickle(filename: str) -> object:
    """
    Load and deserialize data written by :py:meth:`write_pickle`.

    Tuples of the form ``("np.memmap", shape, dtype, path)`` are turned
    back into a np.memmap opened on ``path``. All other objects are
    returned as unpickled.

    Parameters
    ----------
    filename : str
        The name of the file to read and deserialize data from.

    Returns
    -------
    object or iterable
        The deserialized data. A single stored item is returned directly,
        otherwise a list of items is returned.

    Notes
    -----
    Unpickling can execute arbitrary code. Only load files from trusted
    sources.

    See Also
    --------
    :py:meth:`write_pickle`
    """

    def _load_pickle(file_handle):
        # A file may hold several consecutive pickles; read until EOF.
        try:
            while True:
                yield pickle.load(file_handle)
        except EOFError:
            pass

    def _is_pickle_memmap(data):
        # The length check protects against empty tuples and against
        # unrelated tuples that happen to start with the marker string.
        return len(data) == 4 and isinstance(data[0], str) and data[0] == "np.memmap"

    items = []
    with open(filename, "rb") as ifile:
        for data in _load_pickle(ifile):
            if isinstance(data, tuple) and _is_pickle_memmap(data):
                _, shape, dtype, memmap_filename = data
                data = np.memmap(memmap_filename, shape=shape, dtype=dtype)
            items.append(data)
    return items[0] if len(items) == 1 else items
|
223
|
+
|
224
|
+
|
225
|
+
def compute_parallelization_schedule(
    shape1: NDArray,
    shape2: NDArray,
    max_cores: int,
    max_ram: int,
    matching_method: str,
    backend: str = None,
    split_only_outer: bool = False,
    shape1_padding: NDArray = None,
    analyzer_method: str = None,
    max_splits: int = 256,
    float_nbytes: int = 4,
    complex_nbytes: int = 8,
    integer_nbytes: int = 4,
) -> Tuple[Dict, int, int]:
    """
    Computes a parallelization schedule for a given computation.

    This function estimates the amount of memory that would be used by a computation
    and breaks down the computation into smaller parts that can be executed in parallel
    without exceeding the specified limits on the number of cores and memory.

    Parameters
    ----------
    shape1 : NDArray
        The shape of the first input tensor.
    shape2 : NDArray
        The shape of the second input tensor.
    max_cores : int
        The maximum number of cores that can be used.
    max_ram : int
        The maximum amount of memory that can be used. Compared directly
        against the values returned by ``estimate_ram_usage``.
    matching_method : str
        The metric used for scoring the computations.
    backend : str, optional
        Backend used for computations.
    split_only_outer : bool, optional
        Whether only outer splits should be considered.
    shape1_padding : NDArray, optional
        Padding for shape1 used for each split. None by default, in which
        case no padding is added.
    analyzer_method : str
        The method used for score analysis.
    max_splits : int, optional
        The maximum number of parts that the computation can be split into,
        by default 256.
    float_nbytes : int
        Number of bytes of the used float, e.g. 4 for float32.
    complex_nbytes : int
        Number of bytes of the used complex, e.g. 8 for complex64.
    integer_nbytes : int
        Number of bytes of the used integer, e.g. 4 for int32.

    Notes
    -----
    This function assumes that no residual memory remains after each split,
    which not always holds true, e.g. when using
    :py:class:`tme.analyzer.MaxScoreOverRotations`.

    Returns
    -------
    dict
        The selected number of splits for each axis of the first input
        tensor, keyed by axis index.
    tuple of int
        The pair ``(outer_jobs, inner_jobs)``, i.e. the number of outer jobs
        and the number of inner jobs per outer job.

    NOTE(review): the function returns a 2-tuple ``(splits, core_assignment)``
    and ``(None, None)`` when no feasible assignment exists, which differs
    from the three-element return annotation.
    """
    shape1, shape2 = np.array(shape1), np.array(shape2)
    if shape1_padding is None:
        shape1_padding = np.zeros_like(shape1)
    # Enumerate factor pairs of max_cores; each pair is later unpacked as
    # (inner_cores, outer_cores). Both orderings of a pair are added.
    core_assignments = []
    for i in range(1, int(max_cores**0.5) + 1):
        if max_cores % i == 0:
            core_assignments.append((i, max_cores // i))
            core_assignments.append((max_cores // i, i))

    if split_only_outer:
        # One inner core, all cores assigned to outer jobs.
        core_assignments = [(1, max_cores)]

    # Splitting starts at the largest axis of shape1 and then cycles
    # through the axes, adding one split per iteration.
    possible_params, split_axis = [], np.argmax(shape1)
    split_factor, n_splits = [1 for _ in range(len(shape1))], 0
    # The condition is evaluated with n_splits from the previous iteration.
    while n_splits <= max_splits:
        splits = {k: split_factor[k] for k in range(len(split_factor))}
        array_slices = split_numpy_array_slices(shape=shape1, splits=splits)
        array_widths = [
            tuple(x.stop - x.start for x in split) for split in array_slices
        ]
        n_splits = np.prod(list(splits.values()))

        for inner_cores, outer_cores in core_assignments:
            # More outer jobs than pieces cannot be kept busy.
            if outer_cores > n_splits:
                continue
            ram_usage = [
                estimate_ram_usage(
                    shape1=np.add(shp, shape1_padding),
                    shape2=shape2,
                    matching_method=matching_method,
                    analyzer_method=analyzer_method,
                    backend=backend,
                    ncores=inner_cores,
                    float_nbytes=float_nbytes,
                    complex_nbytes=complex_nbytes,
                    integer_nbytes=integer_nbytes,
                )
                for shp in array_widths
            ]
            # Peak usage: largest summed estimate over consecutive groups of
            # outer_cores pieces.
            max_usage = 0
            for i in range(0, len(ram_usage), outer_cores):
                usage = np.sum(ram_usage[i : (i + outer_cores)])
                if usage > max_usage:
                    max_usage = usage

            inits = n_splits // outer_cores
            if max_usage < max_ram:
                possible_params.append(
                    (*split_factor, outer_cores, inner_cores, n_splits, inits)
                )
        split_factor[split_axis] += 1
        split_axis += 1
        if split_axis == shape1.size:
            split_axis = 0

    possible_params = np.array(possible_params)
    if not len(possible_params):
        print(
            "No suitable assignment found. Consider increasing "
            "max_ram or decrease max_cores."
        )
        return None, None

    # Sort by n_splits (primary key, second-to-last column) and then by
    # inits (last column); the first row after sorting is selected.
    init = possible_params.shape[1] - 1
    possible_params = possible_params[
        np.lexsort((possible_params[:, init], possible_params[:, (init - 1)]))
    ]
    splits = {k: possible_params[0, k] for k in range(shape1.size)}
    # Columns following the split factors hold (outer_cores, inner_cores).
    core_assignment = (
        possible_params[0, shape1.size],
        possible_params[0, (shape1.size + 1)],
    )

    return splits, core_assignment
|
366
|
+
|
367
|
+
|
368
|
+
def centered(arr: NDArray, newshape: Tuple[int]) -> NDArray:
    """
    Return the central region of an array.

    Parameters
    ----------
    arr : NDArray
        Input array.
    newshape : tuple
        Shape of the central region to extract.

    Returns
    -------
    NDArray
        Central portion of ``arr`` with shape ``newshape``, obtained by
        slicing ``arr``.

    References
    ----------
    .. [1] https://github.com/scipy/scipy/blob/v1.11.2/scipy/signal/_signaltools.py#L388
    """
    target_shape = np.asarray(newshape)
    source_shape = np.array(arr.shape)
    # Surplus along each axis is split evenly, rounding down on the left.
    lower = (source_shape - target_shape) // 2
    upper = lower + newshape
    window = tuple(map(slice, lower, upper))
    return arr[window]
|
394
|
+
|
395
|
+
|
396
|
+
def centered_mask(arr: NDArray, newshape: Tuple[int]) -> NDArray:
    """
    Zero out everything except the central region of an array.

    The input array is modified in place and also returned.

    Parameters
    ----------
    arr : NDArray
        Input array.
    newshape : tuple
        Shape of the central region that is kept.

    Returns
    -------
    NDArray
        ``arr`` with the central portion unchanged and the rest multiplied
        by 0.
    """
    target_shape = np.asarray(newshape)
    source_shape = np.array(arr.shape)
    lower = (source_shape - target_shape) // 2
    upper = lower + newshape
    window = tuple(map(slice, lower, upper))

    # Multiplicative mask: one inside the window, zero elsewhere.
    keep = np.zeros_like(arr)
    keep[window] = 1
    arr *= keep
    return arr
|
421
|
+
|
422
|
+
|
423
|
+
def apply_convolution_mode(
    arr: NDArray,
    convolution_mode: str,
    s1: Tuple[int],
    s2: Tuple[int],
    mask_output: bool = False,
) -> NDArray:
    """
    Restrict a convolution result to the requested convolution mode.

    Parameters
    ----------
    arr : NDArray
        Numpy array containing convolution result of arrays with shape s1 and s2.
    convolution_mode : str
        Analogous to mode in ``scipy.signal.convolve``:

        +---------+----------------------------------------------------------+
        | 'full'  | returns full template matching result of the inputs.     |
        +---------+----------------------------------------------------------+
        | 'valid' | returns elements that do not rely on zero-padding.       |
        +---------+----------------------------------------------------------+
        | 'same'  | output is the same size as s1.                           |
        +---------+----------------------------------------------------------+
    s1 : tuple
        Tuple of integers corresponding to shape of convolution array 1.
    s2 : tuple
        Tuple of integers corresponding to shape of convolution array 2.
    mask_output : bool, optional
        Whether to mask values outside of convolution_mode rather than
        removing them. Defaults to False.

    Returns
    -------
    NDArray
        The numpy array after applying the convolution mode.

    Raises
    ------
    ValueError
        If ``convolution_mode`` is not 'full', 'same' or 'valid'.

    References
    ----------
    .. [1] https://github.com/scipy/scipy/blob/v1.11.2/scipy/signal/_signaltools.py#L519
    """
    # This removes padding to next fast fourier length
    full_extent = tuple(slice(s1[axis] + s2[axis] - 1) for axis in range(len(s1)))
    arr = arr[full_extent]

    if convolution_mode not in ("full", "same", "valid"):
        raise ValueError("Supported convolution_mode are 'full', 'same' and 'valid'.")

    if convolution_mode == "full":
        return arr

    # Either crop to the central region or zero out everything around it.
    restrict = centered_mask if mask_output else centered
    if convolution_mode == "same":
        return restrict(arr, s1)

    # Remaining mode is 'valid'.
    valid_shape = [s1[i] - s2[i] + s2[i] % 2 for i in range(arr.ndim)]
    return restrict(arr, valid_shape)
|
478
|
+
|
479
|
+
|
480
|
+
def compute_full_convolution_index(
    outer_shape: Tuple[int],
    inner_shape: Tuple[int],
    outer_split: Tuple[slice],
    inner_split: Tuple[slice],
) -> Tuple[slice]:
    """
    Locate the convolution of two array pieces within the full convolution.

    Parameters
    ----------
    outer_shape : tuple
        Tuple of integers corresponding to the shape of the outer array.
    inner_shape : tuple
        Tuple of integers corresponding to the shape of the inner array.
    outer_split : tuple
        Tuple of slices used to split outer array
        (see :py:meth:`split_numpy_array_slices`).
    inner_split : tuple
        Tuple of slices used to split inner array
        (see :py:meth:`split_numpy_array_slices`).

    Returns
    -------
    tuple
        Tuple of slices corresponding to the position of the given convolution
        in the full convolution.
    """
    outer_shape = np.asarray(outer_shape)
    inner_shape = np.asarray(inner_shape)

    outer_start = np.array([piece.start for piece in outer_split]).astype(int)
    outer_width = np.array([piece.stop - piece.start for piece in outer_split])
    inner_stop = np.array([piece.stop for piece in inner_split]).astype(int)
    inner_width = np.array([piece.stop - piece.start for piece in inner_split])

    # Extent of the full convolution of the two pieces along every axis.
    extent = outer_width + inner_width - 1
    # Position of the piece: start of the outer piece, shifted by how much
    # of the inner array lies beyond the end of the inner piece.
    begin = outer_start + inner_shape - inner_stop

    return tuple(slice(low, low + size) for low, size in zip(begin, extent))
|
526
|
+
|
527
|
+
|
528
|
+
def split_numpy_array_slices(
    shape: NDArray, splits: Dict, margin: NDArray = None
) -> Tuple[slice]:
    """
    Build the slices that cut an array into pieces along multiple axes.

    Parameters
    ----------
    shape : NDArray
        Shape of the array to split.
    splits : dict
        A dictionary where the keys are the axis numbers and the values
        are the number of splits along that axis. Missing axes and values
        below one are treated as a single piece.
    margin : NDArray, optional
        Padding on the left hand side of each piece, per axis. The start
        of a piece is moved left by this amount but never below zero.

    Returns
    -------
    tuple
        One tuple of per-axis slices for every piece, covering all
        combinations of the per-axis pieces.
    """
    ndim = len(shape)
    if margin is None:
        margin = np.zeros(ndim, dtype=int)
    splits = {axis: max(splits.get(axis, 0), 1) for axis in range(ndim)}
    piece_lengths = np.divide(shape, [splits[axis] for axis in range(ndim)]).astype(
        int
    )

    per_axis = []
    for axis, length in enumerate(piece_lengths):
        n_pieces = splits[axis]
        axis_slices = []
        for index in range(n_pieces):
            start = max((index * length) - margin[axis], 0)
            # The last piece absorbs the remainder of the integer division.
            if index < n_pieces - 1:
                stop = (index + 1) * length
            else:
                stop = shape[axis]
            axis_slices.append(slice(start, stop))
        per_axis.append(tuple(axis_slices))

    return tuple(product(*per_axis))
|
568
|
+
|
569
|
+
|
570
|
+
def get_rotation_matrices(
    angular_sampling: float, dim: int = 3, use_optimized_set: bool = True
) -> NDArray:
    """
    Returns rotation matrices in format k x dim x dim, where k is determined
    by ``angular_sampling``.

    Parameters
    ----------
    angular_sampling : float
        The angle in degrees used for the generation of rotation matrices.
    dim : int, optional
        Dimension of the rotation matrices.
    use_optimized_set : bool, optional
        Whether to use pre-computed rotational sets with more optimal sampling.
        Currently only available when dim=3.

    Notes
    -----
    For dim = 3 with ``use_optimized_set`` the matrices are derived from
    pre-computed quaternion sets. Otherwise they are obtained from the
    QR-decomposition of random normal matrices, so the result differs
    between calls.

    Returns
    -------
    NDArray
        Array of shape (k, dim, dim) containing k rotation matrices.
    """
    if dim == 3 and use_optimized_set:
        quaternions, *_ = load_quaternions_by_angle(angular_sampling)
        return quaternion_to_rotation_matrix(quaternions)

    # Number of independent rotation planes in ``dim`` dimensions.
    n_planes = dim * (dim - 1) // 2
    n_matrices = int((360 / angular_sampling) ** n_planes)

    random_matrices = np.random.randn(n_matrices, dim, dim)
    ret, _ = np.linalg.qr(random_matrices)

    # Flip the last column of matrices with negative determinant so the
    # determinant becomes positive.
    flipped = np.linalg.det(ret) < 0
    ret[flipped, :, -1] *= -1
    return ret
|
609
|
+
|
610
|
+
|
611
|
+
def minimum_enclosing_box(
    coordinates: NDArray,
    margin: NDArray = None,
    use_geometric_center: bool = False,
) -> Tuple[int]:
    """
    Computes a cubic box that encloses the coordinates under rotation.

    Parameters
    ----------
    coordinates : NDArray
        Coordinates of which the enclosing box should be computed. The shape
        of this array should be [d, n] with d dimensions and n coordinates.
    margin : NDArray, optional
        Box margin. Defaults to None. The value is converted to an integer
        array but does not influence the returned shape.
    use_geometric_center : bool, optional
        Whether the box should accommodate the geometric or the coordinate
        center. Defaults to False.

    Returns
    -------
    NDArray
        Integer array of length d with the box edge length repeated for
        every dimension.
    """
    points = np.asarray(coordinates)
    n_dim = points.shape[0]
    points = points - points.min(axis=1)[:, None]

    # NOTE(review): margin is normalised here but never added to the box.
    margin = np.zeros(n_dim) if margin is None else margin
    margin = np.asarray(margin).astype(int)

    if use_geometric_center:
        # Edge length from the largest pairwise distance between convex
        # hull vertices, plus the diagonal of a unit cell.
        hull = ConvexHull(points.T)
        distance, _ = max_euclidean_distance(points[:, hull.vertices].T)
        distance += np.linalg.norm(np.ones(n_dim))
        return np.repeat(np.rint(distance).astype(int), n_dim)

    # Edge length from twice the largest distance to the coordinate mean.
    mean_centered = points - points.mean(axis=1)[:, None]
    radius = np.linalg.norm(mean_centered, axis=0).max()
    # Adding one avoids clipping during scipy.ndimage.affine_transform
    return np.repeat(np.ceil(2 * radius) + 1, n_dim).astype(int)
|
654
|
+
|
655
|
+
|
656
|
+
def crop_input(
    target: "Density",
    template: "Density",
    target_mask: "Density" = None,
    template_mask: "Density" = None,
    map_cutoff: float = 0,
    template_cutoff: float = 0,
) -> Tuple[int]:
    """
    Crop target and template maps for efficient fitting. Input densities
    are cropped in place.

    Parameters
    ----------
    target : Density
        Target to be fitted on.
    template : Density
        Template to fit onto the target.
    target_mask : Density, optional
        Mask of the target. Will be cropped like target.
    template_mask : Density, optional
        Mask of the template. Will be cropped like template.
    map_cutoff : float, optional
        Cutoff value for trimming the target Density. Default is 0.
        If None, the target is not trimmed.
    template_cutoff : float, optional
        Cutoff value for trimming the template Density. Default is 0.
        If None, the template is not trimmed.

    Returns
    -------
    Tuple[int]
        Tuple containing reference fit index
    """
    # Size of the full convolution before cropping, used for the report below.
    convolution_shape_init = np.add(target.shape, template.shape) - 1
    # If target and template are aligned, fitting should return this index
    reference_fit = np.subtract(template.shape, 1)

    target_box = tuple(slice(0, x) for x in target.shape)
    if map_cutoff is not None:
        target_box = target.trim_box(cutoff=map_cutoff)

    target_mask_box = target_box
    if target_mask is not None and map_cutoff is not None:
        target_mask_box = target_mask.trim_box(cutoff=map_cutoff)
        # Use the union of both boxes so neither target nor mask is clipped.
        target_box = tuple(
            slice(min(arr.start, mask.start), max(arr.stop, mask.stop))
            for arr, mask in zip(target_box, target_mask_box)
        )

    template_box = tuple(slice(0, x) for x in template.shape)
    if template_cutoff is not None:
        template_box = template.trim_box(cutoff=template_cutoff)

    template_mask_box = template_box
    if template_mask is not None and template_cutoff is not None:
        template_mask_box = template_mask.trim_box(cutoff=template_cutoff)
        # Use the union of both boxes so neither template nor mask is clipped.
        template_box = tuple(
            slice(min(arr.start, mask.start), max(arr.stop, mask.stop))
            for arr, mask in zip(template_box, template_mask_box)
        )

    # These offsets are taken before adjust_box is called, i.e. relative to
    # the uncropped shapes and origins.
    cut_right = np.array(
        [shape - x.stop for shape, x in zip(template.shape, template_box)]
    )
    cut_left = np.array([x.start for x in target_box])

    origin_difference = np.divide(target.origin - template.origin, target.sampling_rate)
    origin_difference = origin_difference.astype(int)

    target.adjust_box(target_box)
    template.adjust_box(template_box)

    if target_mask is not None:
        target_mask.adjust_box(target_box)
    if template_mask is not None:
        template_mask.adjust_box(template_box)

    # Shift the reference index by what was removed and by the origin offset.
    reference_fit -= cut_right + cut_left + origin_difference

    # Convolution size after cropping, computed from the adjusted shapes.
    convolution_shape = np.array(target.shape)
    convolution_shape += np.array(template.shape) - 1

    print(f"Cropped volume of target is: {target.shape}")
    print(f"Cropped volume of template is: {template.shape}")
    saving = 1 - (np.prod(convolution_shape)) / np.prod(convolution_shape_init)
    saving *= 100

    # Sizes are reported assuming 4 bytes per element.
    print(
        "Cropping changed array size from "
        f"{round(4*np.prod(convolution_shape_init)/1e6, 3)} MB "
        f"to {round(4*np.prod(convolution_shape)/1e6, 3)} MB "
        f"({'-' if saving > 0 else ''}{abs(round(saving, 2))}%)"
    )
    return reference_fit
|
749
|
+
|
750
|
+
|
751
|
+
def euler_to_rotationmatrix(angles: Tuple[float]) -> NDArray:
    """
    Convert Euler angles to a rotation matrix.

    Parameters
    ----------
    angles : tuple
        A tuple representing the Euler angles in degrees, interpreted in
        "zyx" order. A single angle is treated as a rotation about z with
        the remaining two angles set to zero.

    Returns
    -------
    NDArray
        The generated 3 x 3 rotation matrix as float32.
    """
    if len(angles) == 1:
        # Unpack the single angle. Nesting the sequence itself would give a
        # ragged input that Rotation.from_euler cannot interpret.
        angles = (angles[0], 0, 0)
    rotation = Rotation.from_euler("zyx", angles, degrees=True)
    return rotation.as_matrix().astype(np.float32)
|
771
|
+
|
772
|
+
|
773
|
+
def euler_from_rotationmatrix(rotation_matrix: NDArray) -> Tuple:
    """
    Convert a rotation matrix to euler angles.

    Parameters
    ----------
    rotation_matrix : NDArray
        A 2 x 2 or 3 x 3 rotation matrix in z y x form.

    Returns
    -------
    NDArray
        The three euler angles in degrees ("zyx" order) as float32.
    """
    matrix = rotation_matrix
    if matrix.shape[0] == 2:
        # Embed the planar rotation in the upper-left block of a 3 x 3
        # identity matrix.
        matrix = np.eye(3)
        matrix[:2, :2] = rotation_matrix

    rotation = Rotation.from_matrix(matrix)
    return rotation.as_euler("zyx", degrees=True).astype(np.float32)
|
797
|
+
|
798
|
+
|
799
|
+
def rigid_transform(
    coordinates: NDArray,
    rotation_matrix: NDArray,
    out: NDArray,
    translation: NDArray,
    use_geometric_center: bool = False,
    coordinates_mask: NDArray = None,
    out_mask: NDArray = None,
    center: NDArray = None,
) -> None:
    """
    Apply a rigid transformation (rotation and translation) to given coordinates.

    The result is written into ``out`` (and ``out_mask``) in place; nothing is
    returned.

    Parameters
    ----------
    coordinates : NDArray
        An array representing the coordinates to be transformed [d x N].
    rotation_matrix : NDArray
        The rotation matrix to be applied [d x d].
    out : NDArray
        The output array to store the transformed coordinates.
    translation : NDArray
        The translation vector to be applied [d].
    use_geometric_center : bool, optional
        Whether to use geometric or coordinate center. If False (default),
        the coordinates are shifted by ``center`` before being rotated.
    coordinates_mask : NDArray, optional
        An array representing the mask for the coordinates [d x T].
    out_mask : NDArray, optional
        The output array to store the transformed coordinates mask. The mask
        is only transformed if both ``coordinates_mask`` and ``out_mask`` are
        given.
    center : NDArray, optional
        Center used for the rotation [d]. Defaults to the mean of
        ``coordinates`` along axis 1.

    Returns
    -------
    None
    """
    # Remember the input dtype, it decides on the final correction below.
    coordinate_dtype = coordinates.dtype
    center = coordinates.mean(axis=1) if center is None else center
    if not use_geometric_center:
        # Creates a shifted copy, the caller's array is not modified.
        coordinates = coordinates - center[:, None]

    np.matmul(rotation_matrix, coordinates, out=out)
    if use_geometric_center:
        # Offset derived from the per-axis extent of the rotated coordinates.
        axis_max, axis_min = out.max(axis=1), out.min(axis=1)
        axis_difference = axis_max - axis_min
        translation = np.add(translation, center - axis_max + (axis_difference // 2))
    else:
        # Moves the mean of the rotated coordinates back onto center.
        translation = np.add(translation, np.subtract(center, out.mean(axis=1)))

    # np.add returned a new array above, so the in-place write to translation
    # further down does not touch the array passed in by the caller.
    out += translation[:, None]
    if coordinates_mask is not None and out_mask is not None:
        # The mask is transformed with the same center and total translation.
        if not use_geometric_center:
            coordinates_mask = coordinates_mask - center[:, None]
        np.matmul(rotation_matrix, coordinates_mask, out=out_mask)
        out_mask += translation[:, None]

    if not use_geometric_center and coordinate_dtype != out.dtype:
        # Shift out by the difference between its mean and the mean of its
        # integer-truncated values. NOTE(review): presumably compensates for a
        # later cast of out to integers - confirm. out_mask is not adjusted.
        np.subtract(out.mean(axis=1), out.astype(int).mean(axis=1), out=translation)
        out += translation[:, None]
|
856
|
+
|
857
|
+
|
858
|
+
def _format_string(string: str) -> str:
    """
    Quote a string if its content requires it.

    Parameters
    ----------
    string : str
        Input string to be formatted.

    Returns
    -------
    str
        The string in single quotes if it contains a space, in double quotes
        if it contains exactly one single quote, and unchanged otherwise.
    """
    contains_space = " " in string
    if contains_space:
        return f"'{string}'"

    # A lone single quote occurs e.g. for C1' atoms.
    has_lone_quote = string.count("'") == 1
    if not has_lone_quote:
        return string
    return f'"{string}"'
|
878
|
+
|
879
|
+
|
880
|
+
def _format_mmcif_colunns(subdict: Dict) -> Dict:
    """
    Quote and pad the columns of a mmcif dictionary.

    Parameters
    ----------
    subdict : dict
        Input dictionary where each key corresponds to a column and the
        values are iterables containing the column values as strings.

    Returns
    -------
    dict
        Dictionary with the same keys. Every value is passed through
        ``_format_string`` and left-justified to the width of the longest
        entry of its column plus one.
    """
    quoted_columns = {
        column: [_format_string(entry) for entry in entries]
        for column, entries in subdict.items()
    }

    padded_subdict = {}
    for column, entries in quoted_columns.items():
        # Empty columns have a width of one and stay empty.
        column_width = max((len(entry) for entry in entries), default=0) + 1
        padded_subdict[column] = [entry.ljust(column_width) for entry in entries]
    return padded_subdict
|
905
|
+
|
906
|
+
|
907
|
+
def create_mask(mask_type: str, **kwargs) -> NDArray:
    """
    Dispatch to the mask creating function registered for ``mask_type``.

    Parameters
    ----------
    mask_type : str
        Type of the mask to be created. Can be "ellipse", "box", or "tube".
    kwargs : dict
        Keyword arguments forwarded unchanged to the selected mask function.

    Returns
    -------
    NDArray
        The created mask.

    Raises
    ------
    ValueError
        If the mask_type is invalid.

    See Also
    --------
    :py:meth:`elliptical_mask`
    :py:meth:`box_mask`
    :py:meth:`tube_mask`
    """
    mapping = {"ellipse": elliptical_mask, "box": box_mask, "tube": tube_mask}

    mask_function = mapping.get(mask_type)
    if mask_function is None:
        raise ValueError(f"mask_type has to be one of {','.join(mapping.keys())}")

    return mask_function(**kwargs)
|
939
|
+
|
940
|
+
|
941
|
+
def elliptical_mask(
    shape: Tuple[int], radius: Tuple[float], center: Tuple[int]
) -> NDArray:
    """
    Creates an ellipsoidal mask.

    Parameters
    ----------
    shape : tuple
        Shape of the mask to be created.
    radius : tuple
        Radius of the ellipse. Either a single value used for all axes, or
        one value per axis of ``shape``.
    center : tuple
        Center of the ellipse. Either a single value used for all axes, or
        one value per axis of ``shape``.

    Returns
    -------
    NDArray
        The created ellipsoidal mask as integer array, which is one for all
        elements on or within the ellipsoid and zero elsewhere.

    Raises
    ------
    ValueError
        If the length of center and radius is not one or the same as shape.

    Examples
    --------
    >>> mask = elliptical_mask(shape = (20,20), radius = (5,5), center = (10,10))
    """
    center, shape, radius = np.asarray(center), np.asarray(shape), np.asarray(radius)
    n = shape.size

    # Only a single value is broadcast across all axes. Lengths that merely
    # divide the number of axes, or empty inputs, are rejected explicitly.
    if radius.size == 1:
        radius = np.repeat(radius, n)
    elif radius.size != n:
        raise ValueError("Length of radius has to be either one or match shape.")

    if center.size == 1:
        center = np.repeat(center, n)
    elif center.size != n:
        raise ValueError("Length of center has to be either one or match shape.")

    # Reshape to (n, 1, ..., 1) to broadcast against the (n, *shape) index grid.
    center = center.reshape((-1,) + (1,) * n)
    radius = radius.reshape((-1,) + (1,) * n)

    mask = np.linalg.norm((np.indices(shape) - center) / radius, axis=0)
    mask = (mask <= 1).astype(int)

    return mask
|
988
|
+
|
989
|
+
|
990
|
+
def box_mask(shape: Tuple[int], center: Tuple[int], height: Tuple[int]) -> np.ndarray:
    """
    Creates a box mask around the provided center point.

    Parameters
    ----------
    shape : Tuple[int]
        Shape of the output array.
    center : Tuple[int]
        Center point coordinates of the box. Converted to integers.
    height : Tuple[int]
        Height of the box along each axis. Converted to integers.

    Returns
    -------
    NDArray
        Float array of the given shape that is one inside the box and zero
        elsewhere.

    Raises
    ------
    ValueError
        If shape, center and height do not have the same length.

    Notes
    -----
    Along each axis the box covers the indices from ``center - height // 2``
    to ``center + height // 2 + height % 2`` inclusive, i.e. ``height + 1``
    elements, unless it is clipped at the array boundary.
    """
    if not (len(shape) == len(center) == len(height)):
        raise ValueError("The length of shape, center, and height must be consistent.")

    center = np.array(center, dtype=int)
    height = np.array(height, dtype=int)

    half_extent = height // 2
    # Clip the box to the valid index range of the output array.
    lower = np.maximum(center - half_extent, 0)
    upper = np.minimum(center + half_extent + np.mod(height, 2) + 1, shape)

    out = np.zeros(shape)
    box = tuple(slice(low, high) for low, high in zip(lower, upper))
    out[box] = 1
    return out
|
1022
|
+
|
1023
|
+
|
1024
|
+
def tube_mask(
    shape: Tuple[int],
    symmetry_axis: int,
    base_center: Tuple[int],
    inner_radius: float,
    outer_radius: float,
    height: int,
) -> NDArray:
    """
    Creates a tube mask.

    Parameters
    ----------
    shape : tuple
        Shape of the mask to be created.
    symmetry_axis : int
        The axis of symmetry for the tube. Has to be a non-negative index
        into ``shape``.
    base_center : tuple
        Center of the base circle of the tube. The entry belonging to
        ``symmetry_axis`` is ignored, the tube is always centered along
        its symmetry axis.
    inner_radius : float
        Inner radius of the tube.
    outer_radius : float
        Outer radius of the tube.
    height : int
        Height of the tube.

    Returns
    -------
    NDArray
        The created tube mask.

    Raises
    ------
    ValueError
        If the inner radius is larger than the outer radius, symmetry_axis
        is not a valid axis of shape, or height is larger than the symmetry
        axis shape.
    """
    if inner_radius > outer_radius:
        raise ValueError("inner_radius should be smaller than outer_radius.")

    # Validate the axis before it is used to index shape, otherwise invalid
    # values surface as IndexError or as an obscure broadcasting error.
    n_dims = len(shape)
    if not 0 <= symmetry_axis < n_dims:
        raise ValueError(f"symmetry_axis has to be within [0, {n_dims - 1}].")

    if height > shape[symmetry_axis]:
        raise ValueError(f"Height can be no larger than {shape[symmetry_axis]}.")

    # The tube cross-section lives in the axes orthogonal to symmetry_axis.
    circle_shape = tuple(b for ix, b in enumerate(shape) if ix != symmetry_axis)
    base_center = tuple(b for ix, b in enumerate(base_center) if ix != symmetry_axis)

    inner_circle = create_mask(
        mask_type="ellipse",
        shape=circle_shape,
        radius=inner_radius,
        center=base_center,
    )
    outer_circle = create_mask(
        mask_type="ellipse",
        shape=circle_shape,
        radius=outer_radius,
        center=base_center,
    )
    circle = outer_circle - inner_circle
    circle = np.expand_dims(circle, axis=symmetry_axis)

    # Place height slices of the cross-section around the axis midpoint.
    center = shape[symmetry_axis] // 2
    start_idx = center - height // 2
    stop_idx = center + height // 2 + height % 2

    slice_indices = tuple(
        slice(None) if i != symmetry_axis else slice(start_idx, stop_idx)
        for i in range(n_dims)
    )
    tube = np.zeros(shape)
    tube[slice_indices] = np.repeat(circle, height, axis=symmetry_axis)

    return tube
|
1100
|
+
|
1101
|
+
|
1102
|
+
def scramble_phases(
    arr: NDArray, noise_proportion: float = 0.5, seed: int = 42
) -> NDArray:
    """
    Applies random phase scrambling to a given array.

    This function takes an input array, applies a Fourier transform, then scrambles the
    phase with a given proportion of noise, and finally applies an
    inverse Fourier transform to the scrambled data. The phase scrambling
    is controlled by a random seed and does not alter the state of the
    global numpy random number generator.

    Parameters
    ----------
    arr : NDArray
        The input array to be scrambled.
    noise_proportion : float, optional
        The proportion of noise in the phase scrambling, by default 0.5.
    seed : int, optional
        The seed for the random phase scrambling, by default 42.

    Returns
    -------
    NDArray
        The array with scrambled phases.

    Raises
    ------
    ValueError
        If noise_proportion is not within [0, 1].
    """
    if noise_proportion < 0 or noise_proportion > 1:
        raise ValueError("noise_proportion has to be within [0, 1].")

    arr_fft = np.fft.fftn(arr)

    amp = np.abs(arr_fft)
    ph = np.angle(arr_fft)

    # A dedicated generator yields the same permutation as seeding the global
    # one, without resetting the random state of the rest of the program.
    rng = np.random.RandomState(seed)
    # permutation only shuffles multi-dimensional arrays along the first axis.
    ph_noise = rng.permutation(ph)
    ph_new = ph * (1 - noise_proportion) + ph_noise * noise_proportion
    ret = np.real(np.fft.ifftn(amp * np.exp(1j * ph_new)))
    return ret
|
1145
|
+
|
1146
|
+
|
1147
|
+
def conditional_execute(func: Callable, execute_operation: bool = True) -> Callable:
    """
    Select between a function and a stand-in that does nothing.

    Parameters
    ----------
    func : callable
        The function handed back if execute_operation is truthy.
    execute_operation : bool, optional
        Whether to return `func` (default) or a no-operation function.

    Returns
    -------
    callable
        `func` itself, or a function that accepts arbitrary positional and
        keyword arguments and returns None.

    Examples
    --------
    >>> def greet(name):
    ...     return f"Hello, {name}!"
    ...
    >>> operation = conditional_execute(greet, False)
    >>> operation("Alice")
    >>> operation = conditional_execute(greet, True)
    >>> operation("Alice")
    'Hello, Alice!'
    """
    if execute_operation:
        return func

    def noop(*args, **kwargs):
        """No operation function."""
        return None

    return noop
|