pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.3__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/match_template.py +219 -216
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/postprocess.py +86 -54
- pytme-0.2.3.data/scripts/preprocess.py +132 -0
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +181 -94
- pytme-0.2.3.dist-info/METADATA +92 -0
- pytme-0.2.3.dist-info/RECORD +75 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/WHEEL +1 -1
- pytme-0.2.1.data/scripts/preprocess.py → scripts/eval.py +1 -1
- scripts/extract_candidates.py +20 -13
- scripts/match_template.py +219 -216
- scripts/match_template_filters.py +154 -95
- scripts/postprocess.py +86 -54
- scripts/preprocess.py +95 -56
- scripts/preprocessor_gui.py +181 -94
- scripts/refine_matches.py +265 -61
- tme/__init__.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +458 -813
- tme/backends/__init__.py +40 -11
- tme/backends/_jax_utils.py +187 -0
- tme/backends/cupy_backend.py +109 -226
- tme/backends/jax_backend.py +230 -152
- tme/backends/matching_backend.py +445 -384
- tme/backends/mlx_backend.py +32 -59
- tme/backends/npfftw_backend.py +240 -507
- tme/backends/pytorch_backend.py +30 -151
- tme/density.py +248 -371
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/matching_data.py +328 -284
- tme/matching_exhaustive.py +195 -1499
- tme/matching_optimization.py +143 -106
- tme/matching_scores.py +887 -0
- tme/matching_utils.py +287 -388
- tme/memory.py +377 -0
- tme/orientations.py +78 -21
- tme/parser.py +3 -4
- tme/preprocessing/_utils.py +61 -32
- tme/preprocessing/composable_filter.py +7 -4
- tme/preprocessing/compose.py +7 -3
- tme/preprocessing/frequency_filters.py +49 -39
- tme/preprocessing/tilt_series.py +44 -72
- tme/preprocessor.py +560 -526
- tme/structure.py +491 -188
- tme/types.py +5 -3
- pytme-0.2.1.dist-info/METADATA +0 -73
- pytme-0.2.1.dist-info/RECORD +0 -73
- tme/helpers.py +0 -881
- tme/matching_constrained.py +0 -195
- {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
- {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
tme/matching_utils.py
CHANGED
@@ -5,7 +5,7 @@
|
|
5
5
|
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
6
|
"""
|
7
7
|
import os
|
8
|
-
import
|
8
|
+
import yaml
|
9
9
|
import pickle
|
10
10
|
from shutil import move
|
11
11
|
from tempfile import mkstemp
|
@@ -14,47 +14,100 @@ from typing import Tuple, Dict, Callable
|
|
14
14
|
from concurrent.futures import ThreadPoolExecutor
|
15
15
|
|
16
16
|
import numpy as np
|
17
|
-
from numpy.typing import NDArray
|
18
17
|
from scipy.spatial import ConvexHull
|
19
18
|
from scipy.ndimage import gaussian_filter
|
20
19
|
from scipy.spatial.transform import Rotation
|
21
20
|
|
21
|
+
from .backends import backend as be
|
22
|
+
from .memory import estimate_ram_usage
|
23
|
+
from .types import NDArray, BackendArray
|
22
24
|
from .extensions import max_euclidean_distance
|
23
|
-
from .matching_memory import estimate_ram_usage
|
24
|
-
from .helpers import quaternion_to_rotation_matrix, load_quaternions_by_angle
|
25
25
|
|
26
26
|
|
27
|
-
def
|
27
|
+
def noop(*args, **kwargs):
|
28
|
+
pass
|
29
|
+
|
30
|
+
|
31
|
+
def identity(arr, *args):
|
32
|
+
return arr
|
33
|
+
|
34
|
+
|
35
|
+
def conditional_execute(
|
36
|
+
func: Callable,
|
37
|
+
execute_operation: bool,
|
38
|
+
alt_func: Callable = noop,
|
39
|
+
) -> Callable:
|
28
40
|
"""
|
29
|
-
|
41
|
+
Return the given function or a no-op function based on execute_operation.
|
30
42
|
|
31
43
|
Parameters
|
32
44
|
----------
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
where the exception originally occurred.
|
45
|
+
func : Callable
|
46
|
+
Callable.
|
47
|
+
alt_func : Callable
|
48
|
+
Callable to return if ``execute_operation`` is False, no-op by default.
|
49
|
+
execute_operation : bool
|
50
|
+
Whether to return ``func`` or a ``alt_func`` function.
|
40
51
|
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
52
|
+
Returns
|
53
|
+
-------
|
54
|
+
Callable
|
55
|
+
``func`` if ``execute_operation`` else ``alt_func``.
|
56
|
+
"""
|
57
|
+
|
58
|
+
return func if execute_operation else alt_func
|
59
|
+
|
60
|
+
|
61
|
+
def normalize_template(
|
62
|
+
template: BackendArray, mask: BackendArray, n_observations: float
|
63
|
+
) -> BackendArray:
|
64
|
+
"""
|
65
|
+
Standardizes ``template`` to zero mean and unit standard deviation in ``mask``.
|
66
|
+
|
67
|
+
.. warning:: ``template`` is modified during the operation.
|
68
|
+
|
69
|
+
Parameters
|
70
|
+
----------
|
71
|
+
template : BackendArray
|
72
|
+
Input data.
|
73
|
+
mask : BackendArray
|
74
|
+
Mask of the same shape as ``template``.
|
75
|
+
n_observations : float
|
76
|
+
Sum of mask elements.
|
77
|
+
|
78
|
+
Returns
|
79
|
+
-------
|
80
|
+
BackendArray
|
81
|
+
Standardized input data.
|
82
|
+
|
83
|
+
References
|
84
|
+
----------
|
85
|
+
.. [1] Hrabe T. et al, J. Struct. Biol. 178, 177 (2012).
|
45
86
|
"""
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
87
|
+
masked_mean = be.sum(be.multiply(template, mask)) / n_observations
|
88
|
+
masked_std = be.sum(be.multiply(be.square(template), mask))
|
89
|
+
masked_std = be.subtract(masked_std / n_observations, be.square(masked_mean))
|
90
|
+
masked_std = be.sqrt(be.maximum(masked_std, 0))
|
91
|
+
|
92
|
+
template = be.subtract(template, masked_mean, out=template)
|
93
|
+
template = be.divide(template, masked_std, out=template)
|
94
|
+
return be.multiply(template, mask, out=template)
|
95
|
+
|
96
|
+
|
97
|
+
def _normalize_template_overflow_safe(
|
98
|
+
template: BackendArray, mask: BackendArray, n_observations: float
|
99
|
+
) -> BackendArray:
|
100
|
+
_template = be.astype(template, be._overflow_safe_dtype)
|
101
|
+
_mask = be.astype(mask, be._overflow_safe_dtype)
|
102
|
+
normalize_template(template=_template, mask=_mask, n_observations=n_observations)
|
103
|
+
template[:] = be.astype(_template, template.dtype)
|
104
|
+
return template
|
51
105
|
|
52
106
|
|
53
|
-
def generate_tempfile_name(suffix=None):
|
107
|
+
def generate_tempfile_name(suffix: str = None) -> str:
|
54
108
|
"""
|
55
|
-
Returns the path to a
|
56
|
-
variable
|
57
|
-
Otherwise the default tmp directory will be used.
|
109
|
+
Returns the path to a temporary file with given suffix. If defined. the
|
110
|
+
environment variable TMPDIR is used as base.
|
58
111
|
|
59
112
|
Parameters
|
60
113
|
----------
|
@@ -73,26 +126,19 @@ def generate_tempfile_name(suffix=None):
|
|
73
126
|
|
74
127
|
def array_to_memmap(arr: NDArray, filename: str = None) -> str:
|
75
128
|
"""
|
76
|
-
Converts a numpy
|
129
|
+
Converts a obj:`numpy.ndarray` to a obj:`numpy.memmap`.
|
77
130
|
|
78
131
|
Parameters
|
79
132
|
----------
|
80
|
-
arr :
|
81
|
-
|
133
|
+
arr : obj:`numpy.ndarray`
|
134
|
+
Input data.
|
82
135
|
filename : str, optional
|
83
|
-
|
84
|
-
file will be created.
|
85
|
-
|
86
|
-
Notes
|
87
|
-
-----
|
88
|
-
If the environment variable TME_TMPDIR is defined, the temporary
|
89
|
-
file will be created there. Otherwise the default tmp directory
|
90
|
-
will be used.
|
136
|
+
Path to new memmap, :py:meth:`generate_tempfile_name` is used by default.
|
91
137
|
|
92
138
|
Returns
|
93
139
|
-------
|
94
140
|
str
|
95
|
-
|
141
|
+
Path to the memmap.
|
96
142
|
"""
|
97
143
|
if filename is None:
|
98
144
|
filename = generate_tempfile_name()
|
@@ -108,47 +154,28 @@ def array_to_memmap(arr: NDArray, filename: str = None) -> str:
|
|
108
154
|
|
109
155
|
def memmap_to_array(arr: NDArray) -> NDArray:
|
110
156
|
"""
|
111
|
-
|
157
|
+
Convert a obj:`numpy.memmap` to a obj:`numpy.ndarray` and delete the memmap.
|
112
158
|
|
113
159
|
Parameters
|
114
160
|
----------
|
115
|
-
arr :
|
116
|
-
|
161
|
+
arr : obj:`numpy.memmap`
|
162
|
+
Input data.
|
117
163
|
|
118
164
|
Returns
|
119
165
|
-------
|
120
|
-
|
121
|
-
|
166
|
+
obj:`numpy.ndarray`
|
167
|
+
In-memory version of ``arr``.
|
122
168
|
"""
|
123
|
-
if
|
169
|
+
if isinstance(arr, np.memmap):
|
124
170
|
memmap_filepath = arr.filename
|
125
171
|
arr = np.array(arr)
|
126
172
|
os.remove(memmap_filepath)
|
127
173
|
return arr
|
128
174
|
|
129
175
|
|
130
|
-
def close_memmap(arr: np.ndarray) -> None:
|
131
|
-
"""
|
132
|
-
Remove the file associated with a numpy memmap array.
|
133
|
-
|
134
|
-
Parameters
|
135
|
-
----------
|
136
|
-
arr : np.ndarray
|
137
|
-
The numpy array which might be a memmap.
|
138
|
-
"""
|
139
|
-
try:
|
140
|
-
os.remove(arr.filename)
|
141
|
-
# arr._mmap.close()
|
142
|
-
except Exception:
|
143
|
-
pass
|
144
|
-
|
145
|
-
|
146
176
|
def write_pickle(data: object, filename: str) -> None:
|
147
177
|
"""
|
148
|
-
Serialize and write data to a file invalidating the input data
|
149
|
-
the process. This function uses type-specific serialization for
|
150
|
-
certain objects, such as np.memmap, for optimized storage. Other
|
151
|
-
objects are serialized using standard pickle.
|
178
|
+
Serialize and write data to a file invalidating the input data.
|
152
179
|
|
153
180
|
Parameters
|
154
181
|
----------
|
@@ -316,7 +343,7 @@ def compute_parallelization_schedule(
|
|
316
343
|
split_factor, n_splits = [1 for _ in range(len(shape1))], 0
|
317
344
|
while n_splits <= max_splits:
|
318
345
|
splits = {k: split_factor[k] for k in range(len(split_factor))}
|
319
|
-
array_slices =
|
346
|
+
array_slices = split_shape(shape=shape1, splits=splits)
|
320
347
|
array_widths = [
|
321
348
|
tuple(x.stop - x.start for x in split) for split in array_slices
|
322
349
|
]
|
@@ -378,55 +405,57 @@ def compute_parallelization_schedule(
|
|
378
405
|
return splits, core_assignment
|
379
406
|
|
380
407
|
|
381
|
-
def
|
408
|
+
def _center_slice(current_shape: Tuple[int], new_shape: Tuple[int]) -> Tuple[slice]:
|
409
|
+
"""Extract the center slice of ``current_shape`` to retrieve ``new_shape``."""
|
410
|
+
new_shape = tuple(int(x) for x in new_shape)
|
411
|
+
current_shape = tuple(int(x) for x in current_shape)
|
412
|
+
starts = tuple((x - y) // 2 for x, y in zip(current_shape, new_shape))
|
413
|
+
stops = tuple(sum(stop) for stop in zip(starts, new_shape))
|
414
|
+
box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
|
415
|
+
return box
|
416
|
+
|
417
|
+
|
418
|
+
def centered(arr: BackendArray, new_shape: Tuple[int]) -> BackendArray:
|
382
419
|
"""
|
383
420
|
Extract the centered portion of an array based on a new shape.
|
384
421
|
|
385
422
|
Parameters
|
386
423
|
----------
|
387
|
-
arr :
|
388
|
-
Input
|
389
|
-
|
424
|
+
arr : BackendArray
|
425
|
+
Input data.
|
426
|
+
new_shape : tuple of ints
|
390
427
|
Desired shape for the central portion.
|
391
428
|
|
392
429
|
Returns
|
393
430
|
-------
|
394
|
-
|
395
|
-
Central portion of the array with shape
|
431
|
+
BackendArray
|
432
|
+
Central portion of the array with shape ``new_shape``.
|
396
433
|
|
397
434
|
References
|
398
435
|
----------
|
399
436
|
.. [1] https://github.com/scipy/scipy/blob/v1.11.2/scipy/signal/_signaltools.py#L388
|
400
437
|
"""
|
401
|
-
|
402
|
-
current_shape = np.array(arr.shape)
|
403
|
-
starts = (current_shape - new_shape) // 2
|
404
|
-
stops = starts + newshape
|
405
|
-
box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
|
438
|
+
box = _center_slice(arr.shape, new_shape=new_shape)
|
406
439
|
return arr[box]
|
407
440
|
|
408
441
|
|
409
|
-
def centered_mask(arr:
|
442
|
+
def centered_mask(arr: BackendArray, new_shape: Tuple[int]) -> BackendArray:
|
410
443
|
"""
|
411
444
|
Mask the centered portion of an array based on a new shape.
|
412
445
|
|
413
446
|
Parameters
|
414
447
|
----------
|
415
|
-
arr :
|
416
|
-
Input
|
417
|
-
|
448
|
+
arr : BackendArray
|
449
|
+
Input data.
|
450
|
+
new_shape : tuple of ints
|
418
451
|
Desired shape for the mask.
|
419
452
|
|
420
453
|
Returns
|
421
454
|
-------
|
422
|
-
|
455
|
+
BackendArray
|
423
456
|
Array with central portion unmasked and the rest set to 0.
|
424
457
|
"""
|
425
|
-
|
426
|
-
current_shape = np.array(arr.shape)
|
427
|
-
starts = (current_shape - new_shape) // 2
|
428
|
-
stops = starts + newshape
|
429
|
-
box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
|
458
|
+
box = _center_slice(arr.shape, new_shape=new_shape)
|
430
459
|
mask = np.zeros_like(arr)
|
431
460
|
mask[box] = 1
|
432
461
|
arr *= mask
|
@@ -434,21 +463,22 @@ def centered_mask(arr: NDArray, newshape: Tuple[int]) -> NDArray:
|
|
434
463
|
|
435
464
|
|
436
465
|
def apply_convolution_mode(
|
437
|
-
arr:
|
466
|
+
arr: BackendArray,
|
438
467
|
convolution_mode: str,
|
439
468
|
s1: Tuple[int],
|
440
469
|
s2: Tuple[int],
|
470
|
+
convolution_shape: Tuple[int] = None,
|
441
471
|
mask_output: bool = False,
|
442
|
-
) ->
|
472
|
+
) -> BackendArray:
|
443
473
|
"""
|
444
|
-
Applies convolution_mode to arr
|
474
|
+
Applies convolution_mode to ``arr``.
|
445
475
|
|
446
476
|
Parameters
|
447
477
|
----------
|
448
|
-
arr :
|
449
|
-
|
478
|
+
arr : BackendArray
|
479
|
+
Array containing convolution result of arrays with shape s1 and s2.
|
450
480
|
convolution_mode : str
|
451
|
-
Analogous to mode in
|
481
|
+
Analogous to mode in obj:`scipy.signal.convolve`:
|
452
482
|
|
453
483
|
+---------+----------------------------------------------------------+
|
454
484
|
| 'full' | returns full template matching result of the inputs. |
|
@@ -457,25 +487,25 @@ def apply_convolution_mode(
|
|
457
487
|
+---------+----------------------------------------------------------+
|
458
488
|
| 'same' | output is the same size as s1. |
|
459
489
|
+---------+----------------------------------------------------------+
|
460
|
-
s1 : tuple
|
490
|
+
s1 : tuple of ints
|
461
491
|
Tuple of integers corresponding to shape of convolution array 1.
|
462
|
-
s2 : tuple
|
492
|
+
s2 : tuple of ints
|
463
493
|
Tuple of integers corresponding to shape of convolution array 2.
|
494
|
+
convolution_shape : tuple of ints, optional
|
495
|
+
Size of the actually computed convolution. s1 + s2 - 1 by default.
|
464
496
|
mask_output : bool, optional
|
465
497
|
Whether to mask values outside of convolution_mode rather than
|
466
498
|
removing them. Defaults to False.
|
467
499
|
|
468
500
|
Returns
|
469
501
|
-------
|
470
|
-
|
471
|
-
The
|
472
|
-
|
473
|
-
References
|
474
|
-
----------
|
475
|
-
.. [1] https://github.com/scipy/scipy/blob/v1.11.2/scipy/signal/_signaltools.py#L519
|
502
|
+
BackendArray
|
503
|
+
The array after applying the convolution mode.
|
476
504
|
"""
|
477
|
-
#
|
478
|
-
|
505
|
+
# Remove padding to next fast Fourier length
|
506
|
+
if convolution_shape is None:
|
507
|
+
convolution_shape = [s1[i] + s2[i] - 1 for i in range(len(s1))]
|
508
|
+
arr = arr[tuple(slice(0, x) for x in convolution_shape)]
|
479
509
|
|
480
510
|
if convolution_mode not in ("full", "same", "valid"):
|
481
511
|
raise ValueError("Supported convolution_mode are 'full', 'same' and 'valid'.")
|
@@ -506,11 +536,9 @@ def compute_full_convolution_index(
|
|
506
536
|
inner_shape : tuple
|
507
537
|
Tuple of integers corresponding to the shape of the inner array.
|
508
538
|
outer_split : tuple
|
509
|
-
Tuple of slices used to split outer array
|
510
|
-
(see :py:meth:`split_numpy_array_slices`).
|
539
|
+
Tuple of slices used to split outer array (see :py:meth:`split_shape`).
|
511
540
|
inner_split : tuple
|
512
|
-
Tuple of slices used to split inner array
|
513
|
-
(see :py:meth:`split_numpy_array_slices`).
|
541
|
+
Tuple of slices used to split inner array (see :py:meth:`split_shape`).
|
514
542
|
|
515
543
|
Returns
|
516
544
|
-------
|
@@ -538,41 +566,43 @@ def compute_full_convolution_index(
|
|
538
566
|
return score_slice
|
539
567
|
|
540
568
|
|
541
|
-
def
|
542
|
-
shape:
|
569
|
+
def split_shape(
|
570
|
+
shape: Tuple[int], splits: Dict, equal_shape: bool = True
|
543
571
|
) -> Tuple[slice]:
|
544
572
|
"""
|
545
|
-
|
573
|
+
Splits ``shape`` into equally sized and potentially overlapping subsets.
|
546
574
|
|
547
575
|
Parameters
|
548
576
|
----------
|
549
|
-
shape :
|
550
|
-
Shape
|
577
|
+
shape : tuple of ints
|
578
|
+
Shape to split.
|
551
579
|
splits : dict
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
Padding on the left hand side of the array.
|
580
|
+
Dictionary mapping axis number to number of splits.
|
581
|
+
equal_shape : dict
|
582
|
+
Whether the subsets should be of equal shape, True by default.
|
556
583
|
|
557
584
|
Returns
|
558
585
|
-------
|
559
586
|
tuple
|
560
|
-
|
587
|
+
Tuple of slice with requested split combinations.
|
561
588
|
"""
|
562
589
|
ndim = len(shape)
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
590
|
+
splits = {k: max(splits.get(k, 1), 1) for k in range(ndim)}
|
591
|
+
ret_shape = np.divide(shape, tuple(splits[i] for i in range(ndim)))
|
592
|
+
if equal_shape:
|
593
|
+
ret_shape = np.ceil(ret_shape).astype(int)
|
594
|
+
ret_shape = ret_shape.astype(int)
|
567
595
|
|
568
596
|
slice_list = [
|
569
597
|
tuple(
|
570
|
-
(slice(
|
598
|
+
(slice((n_splits * length), (n_splits + 1) * length))
|
571
599
|
if n_splits < splits.get(axis, 1) - 1
|
572
|
-
else (slice(
|
600
|
+
else (slice(shape[axis] - length, shape[axis]))
|
601
|
+
if equal_shape
|
602
|
+
else (slice((n_splits * length), shape[axis]))
|
573
603
|
for n_splits in range(splits.get(axis, 1))
|
574
604
|
)
|
575
|
-
for length, axis in zip(
|
605
|
+
for length, axis in zip(ret_shape, splits.keys())
|
576
606
|
]
|
577
607
|
|
578
608
|
splits = tuple(product(*slice_list))
|
@@ -584,28 +614,25 @@ def get_rotation_matrices(
|
|
584
614
|
angular_sampling: float, dim: int = 3, use_optimized_set: bool = True
|
585
615
|
) -> NDArray:
|
586
616
|
"""
|
587
|
-
Returns rotation matrices
|
588
|
-
by ``angular_sampling``.
|
617
|
+
Returns rotation matrices with desired ``angular_sampling`` rate.
|
589
618
|
|
590
619
|
Parameters
|
591
620
|
----------
|
592
621
|
angular_sampling : float
|
593
|
-
The
|
622
|
+
The desired angular sampling in degrees.
|
594
623
|
dim : int, optional
|
595
624
|
Dimension of the rotation matrices.
|
596
625
|
use_optimized_set : bool, optional
|
597
|
-
|
598
|
-
Currently only available when dim=3.
|
626
|
+
Use optimized rotational sets, True by default and available for dim=3.
|
599
627
|
|
600
628
|
Notes
|
601
629
|
-----
|
602
|
-
For
|
603
|
-
QR-decomposition.
|
630
|
+
For dim = 3 optimized sets are used, otherwise QR-decomposition.
|
604
631
|
|
605
632
|
Returns
|
606
633
|
-------
|
607
634
|
NDArray
|
608
|
-
Array of shape (
|
635
|
+
Array of shape (n, d, d) containing n rotation matrices.
|
609
636
|
"""
|
610
637
|
if dim == 3 and use_optimized_set:
|
611
638
|
quaternions, *_ = load_quaternions_by_angle(angular_sampling)
|
@@ -706,144 +733,82 @@ def get_rotations_around_vector(
|
|
706
733
|
return rotation_angles
|
707
734
|
|
708
735
|
|
709
|
-
def
|
710
|
-
|
711
|
-
|
712
|
-
use_geometric_center: bool = False,
|
713
|
-
) -> Tuple[int]:
|
736
|
+
def load_quaternions_by_angle(
|
737
|
+
angular_sampling: float,
|
738
|
+
) -> Tuple[NDArray, NDArray, float]:
|
714
739
|
"""
|
715
|
-
|
740
|
+
Get orientations and weights proportional to the given angular_sampling.
|
716
741
|
|
717
742
|
Parameters
|
718
743
|
----------
|
719
|
-
|
720
|
-
|
721
|
-
of this array should be [d, n] with d dimensions and n coordinates.
|
722
|
-
margin : NDArray, optional
|
723
|
-
Box margin. Defaults to None.
|
724
|
-
use_geometric_center : bool, optional
|
725
|
-
Whether the box should accommodate the geometric or the coordinate
|
726
|
-
center. Defaults to False.
|
744
|
+
angular_sampling : float
|
745
|
+
Requested angular sampling.
|
727
746
|
|
728
747
|
Returns
|
729
748
|
-------
|
730
|
-
|
731
|
-
|
749
|
+
Tuple[NDArray, NDArray, float]
|
750
|
+
Quaternion representations of orientations, weights associated with each
|
751
|
+
quaternion and closest angular sampling to the requested sampling.
|
732
752
|
"""
|
733
|
-
|
734
|
-
|
735
|
-
|
753
|
+
# Metadata contains (N orientations, rotational sampling, coverage as values)
|
754
|
+
with open(
|
755
|
+
os.path.join(os.path.dirname(__file__), "data", "metadata.yaml"), "r"
|
756
|
+
) as infile:
|
757
|
+
metadata = yaml.full_load(infile)
|
758
|
+
|
759
|
+
set_diffs = {
|
760
|
+
setname: abs(angular_sampling - set_angle)
|
761
|
+
for setname, (_, set_angle, _) in metadata.items()
|
762
|
+
}
|
763
|
+
fname = min(set_diffs, key=set_diffs.get)
|
736
764
|
|
737
|
-
|
738
|
-
|
765
|
+
infile = os.path.join(os.path.dirname(__file__), "data", fname)
|
766
|
+
quat_weights = np.load(infile)
|
739
767
|
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
np.ceil(2 * np.linalg.norm(norm_cloud, axis=0).max()) + 1, dim
|
744
|
-
).astype(int)
|
745
|
-
if use_geometric_center:
|
746
|
-
hull = ConvexHull(point_cloud.T)
|
747
|
-
distance, _ = max_euclidean_distance(point_cloud[:, hull.vertices].T)
|
748
|
-
distance += np.linalg.norm(np.ones(dim))
|
749
|
-
shape = np.repeat(np.rint(distance).astype(int), dim)
|
768
|
+
quat = quat_weights[:, :4]
|
769
|
+
weights = quat_weights[:, -1]
|
770
|
+
angle = metadata[fname][0]
|
750
771
|
|
751
|
-
return
|
772
|
+
return quat, weights, angle
|
752
773
|
|
753
774
|
|
754
|
-
def
|
755
|
-
target: "Density",
|
756
|
-
template: "Density",
|
757
|
-
target_mask: "Density" = None,
|
758
|
-
template_mask: "Density" = None,
|
759
|
-
map_cutoff: float = 0,
|
760
|
-
template_cutoff: float = 0,
|
761
|
-
) -> Tuple[int]:
|
775
|
+
def quaternion_to_rotation_matrix(quaternions: NDArray) -> NDArray:
|
762
776
|
"""
|
763
|
-
|
764
|
-
are cropped in place.
|
777
|
+
Convert quaternions to rotation matrices.
|
765
778
|
|
766
779
|
Parameters
|
767
780
|
----------
|
768
|
-
|
769
|
-
|
770
|
-
template : Density
|
771
|
-
Template to fit onto the target.
|
772
|
-
target_mask : Density, optional
|
773
|
-
Path to mask of target. Will be croppped like target.
|
774
|
-
template_mask : Density, optional
|
775
|
-
Path to mask of template. Will be cropped like template.
|
776
|
-
map_cutoff : float, optional
|
777
|
-
Cutoff value for trimming the target Density. Default is 0.
|
778
|
-
map_cutoff : float, optional
|
779
|
-
Cutoff value for trimming the template Density. Default is 0.
|
781
|
+
quaternions : NDArray
|
782
|
+
Quaternion data of shape (n, 4).
|
780
783
|
|
781
784
|
Returns
|
782
785
|
-------
|
783
|
-
|
784
|
-
|
786
|
+
NDArray
|
787
|
+
Rotation matrices corresponding to the given quaternions.
|
785
788
|
"""
|
786
|
-
|
787
|
-
|
788
|
-
|
789
|
-
|
790
|
-
target_box = tuple(slice(0, x) for x in target.shape)
|
791
|
-
if map_cutoff is not None:
|
792
|
-
target_box = target.trim_box(cutoff=map_cutoff)
|
793
|
-
|
794
|
-
target_mask_box = target_box
|
795
|
-
if target_mask is not None and map_cutoff is not None:
|
796
|
-
target_mask_box = target_mask.trim_box(cutoff=map_cutoff)
|
797
|
-
target_box = tuple(
|
798
|
-
slice(min(arr.start, mask.start), max(arr.stop, mask.stop))
|
799
|
-
for arr, mask in zip(target_box, target_mask_box)
|
800
|
-
)
|
801
|
-
|
802
|
-
template_box = tuple(slice(0, x) for x in template.shape)
|
803
|
-
if template_cutoff is not None:
|
804
|
-
template_box = template.trim_box(cutoff=template_cutoff)
|
805
|
-
|
806
|
-
template_mask_box = template_box
|
807
|
-
if template_mask is not None and template_cutoff is not None:
|
808
|
-
template_mask_box = template_mask.trim_box(cutoff=template_cutoff)
|
809
|
-
template_box = tuple(
|
810
|
-
slice(min(arr.start, mask.start), max(arr.stop, mask.stop))
|
811
|
-
for arr, mask in zip(template_box, template_mask_box)
|
812
|
-
)
|
789
|
+
q0 = quaternions[:, 0]
|
790
|
+
q1 = quaternions[:, 1]
|
791
|
+
q2 = quaternions[:, 2]
|
792
|
+
q3 = quaternions[:, 3]
|
813
793
|
|
814
|
-
|
815
|
-
|
816
|
-
)
|
817
|
-
cut_left = np.array([x.start for x in target_box])
|
794
|
+
s = np.linalg.norm(quaternions, axis=1) * 2
|
795
|
+
rotmat = np.zeros((quaternions.shape[0], 3, 3), dtype=np.float64)
|
818
796
|
|
819
|
-
|
820
|
-
|
797
|
+
rotmat[:, 0, 0] = 1.0 - s * ((q2 * q2) + (q3 * q3))
|
798
|
+
rotmat[:, 0, 1] = s * ((q1 * q2) - (q0 * q3))
|
799
|
+
rotmat[:, 0, 2] = s * ((q1 * q3) + (q0 * q2))
|
821
800
|
|
822
|
-
|
823
|
-
|
801
|
+
rotmat[:, 1, 0] = s * ((q2 * q1) + (q0 * q3))
|
802
|
+
rotmat[:, 1, 1] = 1.0 - s * ((q3 * q3) + (q1 * q1))
|
803
|
+
rotmat[:, 1, 2] = s * ((q2 * q3) - (q0 * q1))
|
824
804
|
|
825
|
-
|
826
|
-
|
827
|
-
|
828
|
-
template_mask.adjust_box(template_box)
|
805
|
+
rotmat[:, 2, 0] = s * ((q3 * q1) - (q0 * q2))
|
806
|
+
rotmat[:, 2, 1] = s * ((q3 * q2) + (q0 * q1))
|
807
|
+
rotmat[:, 2, 2] = 1.0 - s * ((q1 * q1) + (q2 * q2))
|
829
808
|
|
830
|
-
|
809
|
+
np.around(rotmat, decimals=8, out=rotmat)
|
831
810
|
|
832
|
-
|
833
|
-
convolution_shape += np.array(template.shape) - 1
|
834
|
-
|
835
|
-
print(f"Cropped volume of target is: {target.shape}")
|
836
|
-
print(f"Cropped volume of template is: {template.shape}")
|
837
|
-
saving = 1 - (np.prod(convolution_shape)) / np.prod(convolution_shape_init)
|
838
|
-
saving *= 100
|
839
|
-
|
840
|
-
print(
|
841
|
-
"Cropping changed array size from "
|
842
|
-
f"{round(4*np.prod(convolution_shape_init)/1e6, 3)} MB "
|
843
|
-
f"to {round(4*np.prod(convolution_shape)/1e6, 3)} MB "
|
844
|
-
f"({'-' if saving > 0 else ''}{abs(round(saving, 2))}%)"
|
845
|
-
)
|
846
|
-
return reference_fit
|
811
|
+
return rotmat
|
847
812
|
|
848
813
|
|
849
814
|
def euler_to_rotationmatrix(angles: Tuple[float], convention: str = "zyx") -> NDArray:
|
@@ -866,12 +831,8 @@ def euler_to_rotationmatrix(angles: Tuple[float], convention: str = "zyx") -> ND
|
|
866
831
|
angle_convention = convention[:n_angles]
|
867
832
|
if n_angles == 1:
|
868
833
|
angles = (angles, 0, 0)
|
869
|
-
rotation_matrix = (
|
870
|
-
|
871
|
-
.as_matrix()
|
872
|
-
.astype(np.float32)
|
873
|
-
)
|
874
|
-
return rotation_matrix
|
834
|
+
rotation_matrix = Rotation.from_euler(angle_convention, angles, degrees=True)
|
835
|
+
return rotation_matrix.as_matrix().astype(np.float32)
|
875
836
|
|
876
837
|
|
877
838
|
def euler_from_rotationmatrix(
|
@@ -883,9 +844,10 @@ def euler_from_rotationmatrix(
|
|
883
844
|
Parameters
|
884
845
|
----------
|
885
846
|
rotation_matrix : NDArray
|
886
|
-
A 2 x 2 or 3 x 3 rotation matrix in
|
847
|
+
A 2 x 2 or 3 x 3 rotation matrix in zyx form.
|
887
848
|
convention : str, optional
|
888
|
-
Euler angle convention.
|
849
|
+
Euler angle convention, zyx by default.
|
850
|
+
|
889
851
|
Returns
|
890
852
|
-------
|
891
853
|
Tuple
|
@@ -895,12 +857,8 @@ def euler_from_rotationmatrix(
|
|
895
857
|
temp_matrix = np.eye(3)
|
896
858
|
temp_matrix[:2, :2] = rotation_matrix
|
897
859
|
rotation_matrix = temp_matrix
|
898
|
-
|
899
|
-
|
900
|
-
.as_euler(convention, degrees=True)
|
901
|
-
.astype(np.float32)
|
902
|
-
)
|
903
|
-
return euler_angles
|
860
|
+
rotation = Rotation.from_matrix(rotation_matrix)
|
861
|
+
return rotation.as_euler(convention, degrees=True).astype(np.float32)
|
904
862
|
|
905
863
|
|
906
864
|
def rotation_aligning_vectors(
|
@@ -961,23 +919,19 @@ def rigid_transform(
|
|
961
919
|
Parameters
|
962
920
|
----------
|
963
921
|
coordinates : NDArray
|
964
|
-
An array representing the coordinates to be transformed
|
922
|
+
An array representing the coordinates to be transformed (d,n).
|
965
923
|
rotation_matrix : NDArray
|
966
|
-
The rotation matrix to be applied
|
924
|
+
The rotation matrix to be applied (d,d).
|
967
925
|
translation : NDArray
|
968
|
-
The translation vector to be applied
|
926
|
+
The translation vector to be applied (d,).
|
969
927
|
out : NDArray
|
970
|
-
The output array to store the transformed coordinates.
|
928
|
+
The output array to store the transformed coordinates (d,n).
|
971
929
|
coordinates_mask : NDArray, optional
|
972
|
-
An array representing the mask for the coordinates
|
930
|
+
An array representing the mask for the coordinates (d,t).
|
973
931
|
out_mask : NDArray, optional
|
974
|
-
The output array to store the transformed coordinates mask.
|
932
|
+
The output array to store the transformed coordinates mask (d,t).
|
975
933
|
use_geometric_center : bool, optional
|
976
934
|
Whether to use geometric or coordinate center.
|
977
|
-
|
978
|
-
Returns
|
979
|
-
-------
|
980
|
-
None
|
981
935
|
"""
|
982
936
|
coordinate_dtype = coordinates.dtype
|
983
937
|
center = coordinates.mean(axis=1) if center is None else center
|
@@ -1004,71 +958,67 @@ def rigid_transform(
|
|
1004
958
|
out += translation[:, None]
|
1005
959
|
|
1006
960
|
|
1007
|
-
def
|
961
|
+
def minimum_enclosing_box(
|
962
|
+
coordinates: NDArray, margin: NDArray = None, use_geometric_center: bool = False
|
963
|
+
) -> Tuple[int]:
|
1008
964
|
"""
|
1009
|
-
|
965
|
+
Computes the minimal enclosing box around coordinates with margin.
|
1010
966
|
|
1011
967
|
Parameters
|
1012
968
|
----------
|
1013
|
-
|
1014
|
-
|
969
|
+
coordinates : NDArray
|
970
|
+
Coordinates of shape (d,n) to compute the enclosing box of.
|
971
|
+
margin : NDArray, optional
|
972
|
+
Box margin, zero by default.
|
973
|
+
use_geometric_center : bool, optional
|
974
|
+
Whether box accommodates the geometric or coordinate center, False by default.
|
1015
975
|
|
1016
976
|
Returns
|
1017
977
|
-------
|
1018
|
-
|
1019
|
-
|
978
|
+
tuple of ints
|
979
|
+
Minimum enclosing box shape.
|
1020
980
|
"""
|
1021
|
-
|
1022
|
-
|
1023
|
-
|
1024
|
-
if string.count("'") == 1:
|
1025
|
-
return f'"{string}"'
|
1026
|
-
return string
|
1027
|
-
|
981
|
+
point_cloud = np.asarray(coordinates)
|
982
|
+
dim = point_cloud.shape[0]
|
983
|
+
point_cloud = point_cloud - point_cloud.min(axis=1)[:, None]
|
1028
984
|
|
1029
|
-
|
1030
|
-
|
1031
|
-
Formats the columns of a mmcif dictionary.
|
985
|
+
margin = np.zeros(dim) if margin is None else margin
|
986
|
+
margin = np.asarray(margin).astype(int)
|
1032
987
|
|
1033
|
-
|
1034
|
-
|
1035
|
-
|
1036
|
-
|
1037
|
-
|
988
|
+
norm_cloud = point_cloud - point_cloud.mean(axis=1)[:, None]
|
989
|
+
# Adding one avoids clipping during scipy.ndimage.affine_transform
|
990
|
+
shape = np.repeat(
|
991
|
+
np.ceil(2 * np.linalg.norm(norm_cloud, axis=0).max()) + 1, dim
|
992
|
+
).astype(int)
|
993
|
+
if use_geometric_center:
|
994
|
+
hull = ConvexHull(point_cloud.T)
|
995
|
+
distance, _ = max_euclidean_distance(point_cloud[:, hull.vertices].T)
|
996
|
+
distance += np.linalg.norm(np.ones(dim))
|
997
|
+
shape = np.repeat(np.rint(distance).astype(int), dim)
|
1038
998
|
|
1039
|
-
|
1040
|
-
-------
|
1041
|
-
dict
|
1042
|
-
Formatted dictionary with the columns of the mmcif file.
|
1043
|
-
"""
|
1044
|
-
subdict = {k: [_format_string(s) for s in v] for k, v in subdict.items()}
|
1045
|
-
key_length = {
|
1046
|
-
key: len(max(value, key=lambda x: len(x), default=""))
|
1047
|
-
for key, value in subdict.items()
|
1048
|
-
}
|
1049
|
-
padded_subdict = {
|
1050
|
-
key: [s.ljust(key_length[key] + 1) for s in values]
|
1051
|
-
for key, values in subdict.items()
|
1052
|
-
}
|
1053
|
-
return padded_subdict
|
999
|
+
return shape
|
1054
1000
|
|
1055
1001
|
|
1056
|
-
def create_mask(
|
1057
|
-
mask_type: str, sigma_decay: float = 0, mask_cutoff: float = 0.135, **kwargs
|
1058
|
-
) -> NDArray:
|
1002
|
+
def create_mask(mask_type: str, sigma_decay: float = 0, **kwargs) -> NDArray:
|
1059
1003
|
"""
|
1060
1004
|
Creates a mask of the specified type.
|
1061
1005
|
|
1062
1006
|
Parameters
|
1063
1007
|
----------
|
1064
1008
|
mask_type : str
|
1065
|
-
Type of the mask to be created. Can be
|
1009
|
+
Type of the mask to be created. Can be one of:
|
1010
|
+
|
1011
|
+
+---------+----------------------------------------------------------+
|
1012
|
+
| box | Box mask (see :py:meth:`box_mask`) |
|
1013
|
+
+---------+----------------------------------------------------------+
|
1014
|
+
| tube | Cylindrical mask (see :py:meth:`tube_mask`) |
|
1015
|
+
+---------+----------------------------------------------------------+
|
1016
|
+
| ellipse | Ellipsoidal mask (see :py:meth:`elliptical_mask`) |
|
1017
|
+
+---------+----------------------------------------------------------+
|
1066
1018
|
sigma_decay : float, optional
|
1067
|
-
|
1068
|
-
mask_cutoff : float, optional
|
1069
|
-
Values below mask_cutoff will be set to zero. By default, exp(-2).
|
1019
|
+
Smoothing along mask edges using a Gaussian filter, 0 by default.
|
1070
1020
|
kwargs : dict
|
1071
|
-
|
1021
|
+
Parameters passed to the indivdual mask creation funcitons.
|
1072
1022
|
|
1073
1023
|
Returns
|
1074
1024
|
-------
|
@@ -1079,12 +1029,6 @@ def create_mask(
|
|
1079
1029
|
------
|
1080
1030
|
ValueError
|
1081
1031
|
If the mask_type is invalid.
|
1082
|
-
|
1083
|
-
See Also
|
1084
|
-
--------
|
1085
|
-
:py:meth:`elliptical_mask`
|
1086
|
-
:py:meth:`box_mask`
|
1087
|
-
:py:meth:`tube_mask`
|
1088
1032
|
"""
|
1089
1033
|
mapping = {"ellipse": elliptical_mask, "box": box_mask, "tube": tube_mask}
|
1090
1034
|
if mask_type not in mapping:
|
@@ -1092,9 +1036,9 @@ def create_mask(
|
|
1092
1036
|
|
1093
1037
|
mask = mapping[mask_type](**kwargs)
|
1094
1038
|
if sigma_decay > 0:
|
1095
|
-
|
1096
|
-
|
1097
|
-
|
1039
|
+
mask_filter = gaussian_filter(mask.astype(np.float32), sigma=sigma_decay)
|
1040
|
+
mask = np.add(mask, (1 - mask) * mask_filter)
|
1041
|
+
mask[mask < np.exp(-np.square(sigma_decay))] = 0
|
1098
1042
|
|
1099
1043
|
return mask
|
1100
1044
|
|
@@ -1126,6 +1070,7 @@ def elliptical_mask(
|
|
1126
1070
|
|
1127
1071
|
Examples
|
1128
1072
|
--------
|
1073
|
+
>>> from tme.matching_utils import elliptical_mask
|
1129
1074
|
>>> mask = elliptical_mask(shape = (20,20), radius = (5,5), center = (10,10))
|
1130
1075
|
"""
|
1131
1076
|
center, shape, radius = np.asarray(center), np.asarray(shape), np.asarray(radius)
|
@@ -1154,17 +1099,23 @@ def box_mask(shape: Tuple[int], center: Tuple[int], height: Tuple[int]) -> np.nd
|
|
1154
1099
|
|
1155
1100
|
Parameters
|
1156
1101
|
----------
|
1157
|
-
shape :
|
1102
|
+
shape : tuple of ints
|
1158
1103
|
Shape of the output array.
|
1159
|
-
center :
|
1104
|
+
center : tuple of ints
|
1160
1105
|
Center point coordinates of the box.
|
1161
|
-
height :
|
1106
|
+
height : tuple of ints
|
1162
1107
|
Height (side length) of the box along each axis.
|
1163
1108
|
|
1164
1109
|
Returns
|
1165
1110
|
-------
|
1166
1111
|
NDArray
|
1167
1112
|
The created box mask.
|
1113
|
+
|
1114
|
+
Raises
|
1115
|
+
------
|
1116
|
+
ValueError
|
1117
|
+
If ``shape`` and ``center`` do not have the same length.
|
1118
|
+
If ``center`` and ``height`` do not have the same length.
|
1168
1119
|
"""
|
1169
1120
|
if len(shape) != len(center) or len(center) != len(height):
|
1170
1121
|
raise ValueError("The length of shape, center, and height must be consistent.")
|
@@ -1216,9 +1167,9 @@ def tube_mask(
|
|
1216
1167
|
Raises
|
1217
1168
|
------
|
1218
1169
|
ValueError
|
1219
|
-
If
|
1220
|
-
than the symmetry axis
|
1221
|
-
same length.
|
1170
|
+
If ``inner_radius`` is larger than ``outer_radius``.
|
1171
|
+
If ``height`` is larger than the symmetry axis.
|
1172
|
+
If ``base_center`` and ``shape`` do not have the same length.
|
1222
1173
|
"""
|
1223
1174
|
if inner_radius > outer_radius:
|
1224
1175
|
raise ValueError("inner_radius should be smaller than outer_radius.")
|
@@ -1274,94 +1225,42 @@ def scramble_phases(
|
|
1274
1225
|
arr: NDArray,
|
1275
1226
|
noise_proportion: float = 0.5,
|
1276
1227
|
seed: int = 42,
|
1277
|
-
normalize_power: bool =
|
1228
|
+
normalize_power: bool = False,
|
1278
1229
|
) -> NDArray:
|
1279
1230
|
"""
|
1280
|
-
|
1281
|
-
|
1282
|
-
This function takes an input array, applies a Fourier transform, then scrambles the
|
1283
|
-
phase with a given proportion of noise, and finally applies an
|
1284
|
-
inverse Fourier transform to the scrambled data. The phase scrambling
|
1285
|
-
is controlled by a random seed.
|
1231
|
+
Perform random phase scrambling of ``arr``.
|
1286
1232
|
|
1287
1233
|
Parameters
|
1288
1234
|
----------
|
1289
1235
|
arr : NDArray
|
1290
|
-
|
1236
|
+
Input data.
|
1291
1237
|
noise_proportion : float, optional
|
1292
|
-
|
1238
|
+
Proportion of scrambled phases, 0.5 by default.
|
1293
1239
|
seed : int, optional
|
1294
|
-
The seed for the random phase scrambling, by default
|
1240
|
+
The seed for the random phase scrambling, 42 by default.
|
1295
1241
|
normalize_power : bool, optional
|
1296
|
-
|
1242
|
+
Return value has same sum of squares as ``arr``.
|
1297
1243
|
|
1298
1244
|
Returns
|
1299
1245
|
-------
|
1300
1246
|
NDArray
|
1301
|
-
|
1302
|
-
|
1303
|
-
Raises
|
1304
|
-
------
|
1305
|
-
ValueError
|
1306
|
-
If noise_proportion is not within [0, 1].
|
1247
|
+
Phase scrambled version of ``arr``.
|
1307
1248
|
"""
|
1308
|
-
|
1309
|
-
|
1249
|
+
np.random.seed(seed)
|
1250
|
+
noise_proportion = max(min(noise_proportion, 1), 0)
|
1310
1251
|
|
1311
1252
|
arr_fft = np.fft.fftn(arr)
|
1253
|
+
amp, ph = np.abs(arr_fft), np.angle(arr_fft)
|
1312
1254
|
|
1313
|
-
amp = np.abs(arr_fft)
|
1314
|
-
ph = np.angle(arr_fft)
|
1315
|
-
|
1316
|
-
np.random.seed(seed)
|
1317
1255
|
ph_noise = np.random.permutation(ph)
|
1318
1256
|
ph_new = ph * (1 - noise_proportion) + ph_noise * noise_proportion
|
1319
1257
|
ret = np.real(np.fft.ifftn(amp * np.exp(1j * ph_new)))
|
1320
1258
|
|
1321
1259
|
if normalize_power:
|
1322
|
-
np.divide(
|
1323
|
-
np.subtract(ret, ret.min()), np.subtract(ret.max(), ret.min()), out=ret
|
1324
|
-
)
|
1260
|
+
np.divide(ret - ret.min(), ret.max() - ret.min(), out=ret)
|
1325
1261
|
np.multiply(ret, np.subtract(arr.max(), arr.min()), out=ret)
|
1326
1262
|
np.add(ret, arr.min(), out=ret)
|
1327
|
-
|
1328
1263
|
scaling = np.divide(np.abs(arr).sum(), np.abs(ret).sum())
|
1329
1264
|
np.multiply(ret, scaling, out=ret)
|
1330
1265
|
|
1331
1266
|
return ret
|
1332
|
-
|
1333
|
-
|
1334
|
-
def conditional_execute(func: Callable, execute_operation: bool = True) -> Callable:
|
1335
|
-
"""
|
1336
|
-
Return the given function or a no-operation function based on execute_operation.
|
1337
|
-
|
1338
|
-
Parameters
|
1339
|
-
----------
|
1340
|
-
func : callable
|
1341
|
-
The function to be executed if execute_operation is True.
|
1342
|
-
execute_operation : bool, optional
|
1343
|
-
A flag that determines whether to return `func` or a no-operation function.
|
1344
|
-
Default is True.
|
1345
|
-
|
1346
|
-
Returns
|
1347
|
-
-------
|
1348
|
-
callable
|
1349
|
-
Either the given function `func` or a no-operation function.
|
1350
|
-
|
1351
|
-
Examples
|
1352
|
-
--------
|
1353
|
-
>>> def greet(name):
|
1354
|
-
... return f"Hello, {name}!"
|
1355
|
-
...
|
1356
|
-
>>> operation = conditional_execute(greet, False)
|
1357
|
-
>>> operation("Alice")
|
1358
|
-
>>> operation = conditional_execute(greet, True)
|
1359
|
-
>>> operation("Alice")
|
1360
|
-
'Hello, Alice!'
|
1361
|
-
"""
|
1362
|
-
|
1363
|
-
def noop(*args, **kwargs):
|
1364
|
-
"""No operation function."""
|
1365
|
-
pass
|
1366
|
-
|
1367
|
-
return func if execute_operation else noop
|