pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
  2. pytme-0.2.9.data/scripts/match_template.py +1135 -0
  3. pytme-0.2.9.data/scripts/postprocess.py +622 -0
  4. pytme-0.2.9.data/scripts/preprocess.py +209 -0
  5. pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
  6. pytme-0.2.9.dist-info/METADATA +95 -0
  7. pytme-0.2.9.dist-info/RECORD +119 -0
  8. pytme-0.2.9.dist-info/WHEEL +5 -0
  9. pytme-0.2.9.dist-info/entry_points.txt +6 -0
  10. pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
  11. pytme-0.2.9.dist-info/top_level.txt +3 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +97 -0
  14. scripts/match_template.py +1135 -0
  15. scripts/postprocess.py +622 -0
  16. scripts/preprocess.py +209 -0
  17. scripts/preprocessor_gui.py +1227 -0
  18. tests/__init__.py +0 -0
  19. tests/data/Blurring/blob_width18.npy +0 -0
  20. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  21. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  22. tests/data/Blurring/hamming_width6.npy +0 -0
  23. tests/data/Blurring/kaiserb_width18.npy +0 -0
  24. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  25. tests/data/Blurring/mean_size5.npy +0 -0
  26. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  27. tests/data/Blurring/rank_rank3.npy +0 -0
  28. tests/data/Maps/.DS_Store +0 -0
  29. tests/data/Maps/emd_8621.mrc.gz +0 -0
  30. tests/data/README.md +2 -0
  31. tests/data/Raw/em_map.map +0 -0
  32. tests/data/Structures/.DS_Store +0 -0
  33. tests/data/Structures/1pdj.cif +3339 -0
  34. tests/data/Structures/1pdj.pdb +1429 -0
  35. tests/data/Structures/5khe.cif +3685 -0
  36. tests/data/Structures/5khe.ent +2210 -0
  37. tests/data/Structures/5khe.pdb +2210 -0
  38. tests/data/Structures/5uz4.cif +70548 -0
  39. tests/preprocessing/__init__.py +0 -0
  40. tests/preprocessing/test_compose.py +76 -0
  41. tests/preprocessing/test_frequency_filters.py +178 -0
  42. tests/preprocessing/test_preprocessor.py +136 -0
  43. tests/preprocessing/test_utils.py +79 -0
  44. tests/test_analyzer.py +216 -0
  45. tests/test_backends.py +446 -0
  46. tests/test_density.py +503 -0
  47. tests/test_extensions.py +130 -0
  48. tests/test_matching_cli.py +283 -0
  49. tests/test_matching_data.py +162 -0
  50. tests/test_matching_exhaustive.py +124 -0
  51. tests/test_matching_memory.py +30 -0
  52. tests/test_matching_optimization.py +226 -0
  53. tests/test_matching_utils.py +189 -0
  54. tests/test_orientations.py +175 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_rotations.py +153 -0
  57. tests/test_structure.py +247 -0
  58. tme/__init__.py +6 -0
  59. tme/__version__.py +1 -0
  60. tme/analyzer/__init__.py +2 -0
  61. tme/analyzer/_utils.py +186 -0
  62. tme/analyzer/aggregation.py +577 -0
  63. tme/analyzer/peaks.py +953 -0
  64. tme/backends/__init__.py +171 -0
  65. tme/backends/_cupy_utils.py +734 -0
  66. tme/backends/_jax_utils.py +188 -0
  67. tme/backends/cupy_backend.py +294 -0
  68. tme/backends/jax_backend.py +314 -0
  69. tme/backends/matching_backend.py +1270 -0
  70. tme/backends/mlx_backend.py +241 -0
  71. tme/backends/npfftw_backend.py +583 -0
  72. tme/backends/pytorch_backend.py +430 -0
  73. tme/data/__init__.py +0 -0
  74. tme/data/c48n309.npy +0 -0
  75. tme/data/c48n527.npy +0 -0
  76. tme/data/c48n9.npy +0 -0
  77. tme/data/c48u1.npy +0 -0
  78. tme/data/c48u1153.npy +0 -0
  79. tme/data/c48u1201.npy +0 -0
  80. tme/data/c48u1641.npy +0 -0
  81. tme/data/c48u181.npy +0 -0
  82. tme/data/c48u2219.npy +0 -0
  83. tme/data/c48u27.npy +0 -0
  84. tme/data/c48u2947.npy +0 -0
  85. tme/data/c48u3733.npy +0 -0
  86. tme/data/c48u4749.npy +0 -0
  87. tme/data/c48u5879.npy +0 -0
  88. tme/data/c48u7111.npy +0 -0
  89. tme/data/c48u815.npy +0 -0
  90. tme/data/c48u83.npy +0 -0
  91. tme/data/c48u8649.npy +0 -0
  92. tme/data/c600v.npy +0 -0
  93. tme/data/c600vc.npy +0 -0
  94. tme/data/metadata.yaml +80 -0
  95. tme/data/quat_to_numpy.py +42 -0
  96. tme/data/scattering_factors.pickle +0 -0
  97. tme/density.py +2263 -0
  98. tme/extensions.cpython-311-darwin.so +0 -0
  99. tme/external/bindings.cpp +332 -0
  100. tme/filters/__init__.py +6 -0
  101. tme/filters/_utils.py +311 -0
  102. tme/filters/bandpass.py +230 -0
  103. tme/filters/compose.py +81 -0
  104. tme/filters/ctf.py +393 -0
  105. tme/filters/reconstruction.py +160 -0
  106. tme/filters/wedge.py +542 -0
  107. tme/filters/whitening.py +191 -0
  108. tme/matching_data.py +863 -0
  109. tme/matching_exhaustive.py +497 -0
  110. tme/matching_optimization.py +1311 -0
  111. tme/matching_scores.py +1183 -0
  112. tme/matching_utils.py +1188 -0
  113. tme/memory.py +337 -0
  114. tme/orientations.py +598 -0
  115. tme/parser.py +685 -0
  116. tme/preprocessor.py +1329 -0
  117. tme/rotations.py +350 -0
  118. tme/structure.py +1864 -0
  119. tme/types.py +13 -0
tme/matching_utils.py ADDED
@@ -0,0 +1,1188 @@
1
+ """ Utility functions for template matching.
2
+
3
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ import os
9
+ import pickle
10
+ from shutil import move
11
+ from joblib import Parallel
12
+ from tempfile import mkstemp
13
+ from itertools import product
14
+ from concurrent.futures import ThreadPoolExecutor
15
+ from typing import Tuple, Dict, Callable, Optional
16
+
17
+ import numpy as np
18
+ from tqdm import tqdm
19
+ from scipy.spatial import ConvexHull
20
+ from scipy.ndimage import gaussian_filter
21
+
22
+ from .backends import backend as be
23
+ from .memory import estimate_memory_usage
24
+ from .types import NDArray, BackendArray
25
+
26
+
27
+ def noop(*args, **kwargs):
28
+ pass
29
+
30
+
31
+ def identity(arr, *args):
32
+ return arr
33
+
34
+
35
+ def conditional_execute(
36
+ func: Callable,
37
+ execute_operation: bool,
38
+ alt_func: Callable = noop,
39
+ ) -> Callable:
40
+ """
41
+ Return the given function or a no-op function based on execute_operation.
42
+
43
+ Parameters
44
+ ----------
45
+ func : Callable
46
+ Callable.
47
+ alt_func : Callable
48
+ Callable to return if ``execute_operation`` is False, no-op by default.
49
+ execute_operation : bool
50
+ Whether to return ``func`` or ``alt_func``.
51
+
52
+ Returns
53
+ -------
54
+ Callable
55
+ ``func`` if ``execute_operation`` else ``alt_func``.
56
+ """
57
+
58
+ return func if execute_operation else alt_func
59
+
60
+
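A minimal usage sketch for conditional_execute and noop (illustration only, not part of the packaged file):

    from tme.matching_utils import conditional_execute, noop

    def log_message(msg):
        print(msg)

    # Returns log_message when execute_operation is True, otherwise the no-op.
    maybe_log = conditional_execute(log_message, execute_operation=False)
    maybe_log("skipped")                                # prints nothing
    conditional_execute(log_message, True)("printed")   # prints "printed"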
61
+ def normalize_template(
62
+ template: BackendArray, mask: BackendArray, n_observations: float, axis=None
63
+ ) -> BackendArray:
64
+ """
65
+ Standardizes ``template`` to zero mean and unit standard deviation in ``mask``.
66
+
67
+ .. warning:: ``template`` is modified during the operation.
68
+
69
+ Parameters
70
+ ----------
71
+ template : BackendArray
72
+ Input data.
73
+ mask : BackendArray
74
+ Mask of the same shape as ``template``.
75
+ n_observations : float
76
+ Sum of mask elements.
77
+ axis : tuple of ints, optional
78
+ Axes to normalize over, all axes by default.
79
+
80
+ Returns
81
+ -------
82
+ BackendArray
83
+ Standardized input data.
84
+
85
+ References
86
+ ----------
87
+ .. [1] Hrabe T. et al, J. Struct. Biol. 178, 177 (2012).
88
+ """
89
+ masked_mean = be.sum(be.multiply(template, mask), axis=axis, keepdims=True)
90
+ masked_mean = be.divide(masked_mean, n_observations)
91
+ masked_std = be.sum(
92
+ be.multiply(be.square(template), mask), axis=axis, keepdims=True
93
+ )
94
+ masked_std = be.subtract(masked_std / n_observations, be.square(masked_mean))
95
+ masked_std = be.sqrt(be.maximum(masked_std, 0))
96
+
97
+ template = be.subtract(template, masked_mean, out=template)
98
+ template = be.divide(template, masked_std, out=template)
99
+ return be.multiply(template, mask, out=template)
100
+
101
+
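A usage sketch for normalize_template, assuming the default NumPy-based backend so plain numpy arrays can be passed directly (illustration only):

    import numpy as np
    from tme.matching_utils import normalize_template

    template = np.random.rand(8, 8).astype(np.float32)
    mask = np.ones_like(template)

    # template is modified in place; within the mask it now has
    # zero mean and unit standard deviation.
    out = normalize_template(template, mask, n_observations=float(mask.sum()))
    print(out.mean(), out.std())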
102
+ def _normalize_template_overflow_safe(
103
+ template: BackendArray, mask: BackendArray, n_observations: float, axis=None
104
+ ) -> BackendArray:
105
+ _template = be.astype(template, be._overflow_safe_dtype)
106
+ _mask = be.astype(mask, be._overflow_safe_dtype)
107
+ normalize_template(
108
+ template=_template, mask=_mask, n_observations=n_observations, axis=axis
109
+ )
110
+ template[:] = be.astype(_template, template.dtype)
111
+ return template
112
+
113
+
114
+ def generate_tempfile_name(suffix: str = None) -> str:
115
+ """
116
+ Returns the path to a temporary file with given suffix. If defined, the
117
+ environment variable TMPDIR is used as base.
118
+
119
+ Parameters
120
+ ----------
121
+ suffix : str, optional
122
+ File suffix. By default the file has no suffix.
123
+
124
+ Returns
125
+ -------
126
+ str
127
+ The generated filename
128
+ """
129
+ return mkstemp(suffix=suffix)[1]
130
+
131
+
132
+ def array_to_memmap(arr: NDArray, filename: str = None, mode: str = "r") -> np.memmap:
133
+ """
134
+ Converts an :obj:`numpy.ndarray` to an :obj:`numpy.memmap`.
135
+
136
+ Parameters
137
+ ----------
138
+ arr : obj:`numpy.ndarray`
139
+ Input data.
140
+ filename : str, optional
141
+ Path to new memmap, :py:meth:`generate_tempfile_name` is used by default.
142
+ mode : str, optional
143
+ Mode to open the returned memmap object in, defaults to 'r'.
144
+
145
+ Returns
146
+ -------
147
+ obj:`numpy.memmap`
148
+ Memory-mapped array opened in ``mode``.
149
+ """
150
+ if filename is None:
151
+ filename = generate_tempfile_name()
152
+
153
+ arr.tofile(filename)
154
+ return np.memmap(filename, mode=mode, dtype=arr.dtype, shape=arr.shape)
155
+
156
+
157
+ def memmap_to_array(arr: NDArray) -> NDArray:
158
+ """
159
+ Convert an :obj:`numpy.memmap` to an :obj:`numpy.ndarray` and delete its backing file.
160
+
161
+ Parameters
162
+ ----------
163
+ arr : obj:`numpy.memmap`
164
+ Input data.
165
+
166
+ Returns
167
+ -------
168
+ obj:`numpy.ndarray`
169
+ In-memory version of ``arr``.
170
+ """
171
+ if isinstance(arr, np.memmap):
172
+ memmap_filepath = arr.filename
173
+ arr = np.array(arr)
174
+ os.remove(memmap_filepath)
175
+ return arr
176
+
177
+
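A round-trip sketch for array_to_memmap and memmap_to_array (illustration only):

    import numpy as np
    from tme.matching_utils import array_to_memmap, memmap_to_array

    arr = np.arange(12, dtype=np.float32).reshape(3, 4)
    mm = array_to_memmap(arr)            # data is written to a temporary file
    restored = memmap_to_array(mm)       # loaded back, backing file is removed
    assert np.array_equal(arr, restored)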
178
+ def write_pickle(data: object, filename: str) -> None:
179
+ """
180
+ Serialize and write data to a file. Input memmaps are invalidated because their backing files are moved.
181
+
182
+ Parameters
183
+ ----------
184
+ data : iterable or object
185
+ The data to be serialized.
186
+ filename : str
187
+ The name of the file where the serialized data will be written.
188
+
189
+ See Also
190
+ --------
191
+ :py:meth:`load_pickle`
192
+ """
193
+ if type(data) not in (list, tuple):
194
+ data = (data,)
195
+
196
+ dirname = os.path.dirname(filename)
197
+ with open(filename, "wb") as ofile, ThreadPoolExecutor() as executor:
198
+ for i in range(len(data)):
199
+ futures = []
200
+ item = data[i]
201
+ if isinstance(item, np.memmap):
202
+ _, new_filename = mkstemp(suffix=".mm", dir=dirname)
203
+ new_item = ("np.memmap", item.shape, item.dtype, new_filename)
204
+ futures.append(executor.submit(move, item.filename, new_filename))
205
+ item = new_item
206
+ pickle.dump(item, ofile)
207
+ for future in futures:
208
+ future.result()
209
+
210
+
211
+ def load_pickle(filename: str) -> object:
212
+ """
213
+ Load and deserialize data written by :py:meth:`write_pickle`.
214
+
215
+ Parameters
216
+ ----------
217
+ filename : str
218
+ The name of the file to read and deserialize data from.
219
+
220
+ Returns
221
+ -------
222
+ object or iterable
223
+ The deserialized data.
224
+
225
+ See Also
226
+ --------
227
+ :py:meth:`write_pickle`
228
+ """
229
+
230
+ def _load_pickle(file_handle):
231
+ try:
232
+ while True:
233
+ yield pickle.load(file_handle)
234
+ except EOFError:
235
+ pass
236
+
237
+ def _is_pickle_memmap(data):
238
+ ret = False
239
+ if isinstance(data[0], str):
240
+ if data[0] == "np.memmap":
241
+ ret = True
242
+ return ret
243
+
244
+ items = []
245
+ with open(filename, "rb") as ifile:
246
+ for data in _load_pickle(ifile):
247
+ if isinstance(data, tuple):
248
+ if _is_pickle_memmap(data):
249
+ _, shape, dtype, filename = data
250
+ data = np.memmap(filename, shape=shape, dtype=dtype)
251
+ items.append(data)
252
+ return items[0] if len(items) == 1 else items
253
+
254
+
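A sketch of the write_pickle/load_pickle round trip; the filename is arbitrary (illustration only):

    import numpy as np
    from tme.matching_utils import write_pickle, load_pickle

    data = (np.arange(5), {"label": "peak_list"})
    write_pickle(data, "state.pickle")       # one pickle record per item
    arr, meta = load_pickle("state.pickle")  # items are returned in order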
255
+ def compute_parallelization_schedule(
256
+ shape1: NDArray,
257
+ shape2: NDArray,
258
+ max_cores: int,
259
+ max_ram: int,
260
+ matching_method: str,
261
+ split_axes: Tuple[int] = None,
262
+ backend: str = None,
263
+ split_only_outer: bool = False,
264
+ shape1_padding: NDArray = None,
265
+ analyzer_method: str = None,
266
+ max_splits: int = 256,
267
+ float_nbytes: int = 4,
268
+ complex_nbytes: int = 8,
269
+ integer_nbytes: int = 4,
270
+ ) -> Tuple[Dict, int, int]:
271
+ """
272
+ Computes a parallelization schedule for a given computation.
273
+
274
+ This function estimates the amount of memory that would be used by a computation
275
+ and breaks down the computation into smaller parts that can be executed in parallel
276
+ without exceeding the specified limits on the number of cores and memory.
277
+
278
+ Parameters
279
+ ----------
280
+ shape1 : NDArray
281
+ The shape of the first input array.
282
+ shape1_padding : NDArray, optional
283
+ Padding for shape1, None by default.
284
+ shape2 : NDArray
285
+ The shape of the second input array.
286
+ max_cores : int
287
+ The maximum number of cores that can be used.
288
+ max_ram : int
289
+ The maximum amount of memory that can be used.
290
+ matching_method : str
291
+ The metric used for scoring the computations.
292
+ split_axes : tuple
293
+ Axes that can be used for splitting. By default all are considered.
294
+ backend : str, optional
295
+ Backend used for computations.
296
+ split_only_outer : bool, optional
297
+ Whether only outer splits should be considered.
298
+ analyzer_method : str
299
+ The method used for score analysis.
300
+ max_splits : int, optional
301
+ The maximum number of parts that the computation can be split into,
302
+ by default 256.
303
+ float_nbytes : int
304
+ Number of bytes of the used float, e.g. 4 for float32.
305
+ complex_nbytes : int
306
+ Number of bytes of the used complex, e.g. 8 for complex64.
307
+ integer_nbytes : int
308
+ Number of bytes of the used integer, e.g. 4 for int32.
309
+
310
+ Notes
311
+ -----
312
+ This function assumes that no residual memory remains after each split,
313
+ which does not always hold true, e.g. when using
314
+ :py:class:`tme.analyzer.MaxScoreOverRotations`.
315
+
316
+ Returns
317
+ -------
318
+ dict
319
+ The optimal splits for each axis of the first input tensor.
320
+ tuple of ints
321
+ The number of outer jobs and the number of inner jobs
322
+ per outer job. (None, None) is returned when no suitable
323
+ schedule can be found.
324
+ """
325
+ shape1 = tuple(int(x) for x in shape1)
326
+ shape2 = tuple(int(x) for x in shape2)
327
+
328
+ if shape1_padding is None:
329
+ shape1_padding = np.zeros_like(shape1)
330
+ core_assignments = []
331
+ for i in range(1, int(max_cores**0.5) + 1):
332
+ if max_cores % i == 0:
333
+ core_assignments.append((i, max_cores // i))
334
+ core_assignments.append((max_cores // i, i))
335
+
336
+ if split_only_outer:
337
+ core_assignments = [(1, max_cores)]
338
+
339
+ possible_params, split_axis = [], np.argmax(shape1)
340
+
341
+ split_axis_index = split_axis
342
+ if split_axes is not None:
343
+ split_axis, split_axis_index = split_axes[0], 0
344
+ else:
345
+ split_axes = tuple(i for i in range(len(shape1)))
346
+
347
+ split_factor, n_splits = [1 for _ in range(len(shape1))], 0
348
+ while n_splits <= max_splits:
349
+ splits = {k: split_factor[k] for k in range(len(split_factor))}
350
+ array_slices = split_shape(shape=shape1, splits=splits)
351
+ array_widths = [
352
+ tuple(x.stop - x.start for x in split) for split in array_slices
353
+ ]
354
+ n_splits = np.prod(list(splits.values()))
355
+
356
+ for inner_cores, outer_cores in core_assignments:
357
+ if outer_cores > n_splits:
358
+ continue
359
+ ram_usage = [
360
+ estimate_memory_usage(
361
+ shape1=tuple(sum(x) for x in zip(shp, shape1_padding)),
362
+ shape2=shape2,
363
+ matching_method=matching_method,
364
+ analyzer_method=analyzer_method,
365
+ backend=backend,
366
+ ncores=inner_cores,
367
+ float_nbytes=float_nbytes,
368
+ complex_nbytes=complex_nbytes,
369
+ integer_nbytes=integer_nbytes,
370
+ )
371
+ for shp in array_widths
372
+ ]
373
+ max_usage = 0
374
+ for i in range(0, len(ram_usage), outer_cores):
375
+ usage = np.sum(ram_usage[i : (i + outer_cores)])
376
+ if usage > max_usage:
377
+ max_usage = usage
378
+
379
+ inits = n_splits // outer_cores
380
+ if max_usage < max_ram:
381
+ possible_params.append(
382
+ (*split_factor, outer_cores, inner_cores, n_splits, inits)
383
+ )
384
+ split_factor[split_axis] += 1
385
+
386
+ split_axis_index += 1
387
+ if split_axis_index == len(split_axes):
388
+ split_axis_index = 0
389
+ split_axis = split_axes[split_axis_index]
390
+
391
+ possible_params = np.array(possible_params)
392
+ if not len(possible_params):
393
+ print(
394
+ "No suitable assignment found. Consider increasing "
395
+ "max_ram or decrease max_cores."
396
+ )
397
+ return None, None
398
+
399
+ init = possible_params.shape[1] - 1
400
+ possible_params = possible_params[
401
+ np.lexsort((possible_params[:, init], possible_params[:, (init - 1)]))
402
+ ]
403
+ splits = {k: possible_params[0, k] for k in range(len(shape1))}
404
+ core_assignment = (
405
+ possible_params[0, len(shape1)],
406
+ possible_params[0, (len(shape1) + 1)],
407
+ )
408
+
409
+ return splits, core_assignment
410
+
411
+
412
+ def _center_slice(current_shape: Tuple[int], new_shape: Tuple[int]) -> Tuple[slice]:
413
+ """Extract the center slice of ``current_shape`` to retrieve ``new_shape``."""
414
+ new_shape = tuple(int(x) for x in new_shape)
415
+ current_shape = tuple(int(x) for x in current_shape)
416
+ starts = tuple((x - y) // 2 for x, y in zip(current_shape, new_shape))
417
+ stops = tuple(sum(stop) for stop in zip(starts, new_shape))
418
+ box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
419
+ return box
420
+
421
+
422
+ def centered(arr: BackendArray, new_shape: Tuple[int]) -> BackendArray:
423
+ """
424
+ Extract the centered portion of an array based on a new shape.
425
+
426
+ Parameters
427
+ ----------
428
+ arr : BackendArray
429
+ Input data.
430
+ new_shape : tuple of ints
431
+ Desired shape for the central portion.
432
+
433
+ Returns
434
+ -------
435
+ BackendArray
436
+ Central portion of the array with shape ``new_shape``.
437
+
438
+ References
439
+ ----------
440
+ .. [1] https://github.com/scipy/scipy/blob/v1.11.2/scipy/signal/_signaltools.py#L388
441
+ """
442
+ box = _center_slice(arr.shape, new_shape=new_shape)
443
+ return arr[box]
444
+
445
+
446
+ def centered_mask(arr: BackendArray, new_shape: Tuple[int]) -> BackendArray:
447
+ """
448
+ Mask the centered portion of an array based on a new shape.
449
+
450
+ Parameters
451
+ ----------
452
+ arr : BackendArray
453
+ Input data.
454
+ new_shape : tuple of ints
455
+ Desired shape for the mask.
456
+
457
+ Returns
458
+ -------
459
+ BackendArray
460
+ Array with central portion unmasked and the rest set to 0.
461
+ """
462
+ box = _center_slice(arr.shape, new_shape=new_shape)
463
+ mask = np.zeros_like(arr)
464
+ mask[box] = 1
465
+ arr *= mask
466
+ return arr
467
+
468
+
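A small sketch contrasting centered and centered_mask (illustration only):

    import numpy as np
    from tme.matching_utils import centered, centered_mask

    arr = np.arange(25).reshape(5, 5)
    print(centered(arr, (3, 3)))               # central 3 x 3 block
    print(centered_mask(arr.copy(), (3, 3)))   # same block kept, rest zeroed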
469
+ def apply_convolution_mode(
470
+ arr: BackendArray,
471
+ convolution_mode: str,
472
+ s1: Tuple[int],
473
+ s2: Tuple[int],
474
+ convolution_shape: Tuple[int] = None,
475
+ mask_output: bool = False,
476
+ ) -> BackendArray:
477
+ """
478
+ Applies convolution_mode to ``arr``.
479
+
480
+ Parameters
481
+ ----------
482
+ arr : BackendArray
483
+ Array containing convolution result of arrays with shape s1 and s2.
484
+ convolution_mode : str
485
+ Analogous to mode in :obj:`scipy.signal.convolve`:
486
+
487
+ +---------+----------------------------------------------------------+
488
+ | 'full' | returns full template matching result of the inputs. |
489
+ +---------+----------------------------------------------------------+
490
+ | 'valid' | returns elements that do not rely on zero-padding. |
491
+ +---------+----------------------------------------------------------+
492
+ | 'same' | output is the same size as s1. |
493
+ +---------+----------------------------------------------------------+
494
+ s1 : tuple of ints
495
+ Tuple of integers corresponding to shape of convolution array 1.
496
+ s2 : tuple of ints
497
+ Tuple of integers corresponding to shape of convolution array 2.
498
+ convolution_shape : tuple of ints, optional
499
+ Size of the actually computed convolution. s1 + s2 - 1 by default.
500
+ mask_output : bool, optional
501
+ Whether to mask values outside of convolution_mode rather than
502
+ removing them. Defaults to False.
503
+
504
+ Returns
505
+ -------
506
+ BackendArray
507
+ The array after applying the convolution mode.
508
+ """
509
+ # Remove padding to next fast Fourier length
510
+ if convolution_shape is None:
511
+ convolution_shape = [s1[i] + s2[i] - 1 for i in range(len(s1))]
512
+ arr = arr[tuple(slice(0, x) for x in convolution_shape)]
513
+
514
+ if convolution_mode not in ("full", "same", "valid"):
515
+ raise ValueError("Supported convolution_mode are 'full', 'same' and 'valid'.")
516
+
517
+ func = centered_mask if mask_output else centered
518
+ if convolution_mode == "full":
519
+ return arr
520
+ elif convolution_mode == "same":
521
+ return func(arr, s1)
522
+ elif convolution_mode == "valid":
523
+ valid_shape = [s1[i] - s2[i] + s2[i] % 2 for i in range(arr.ndim)]
524
+ return func(arr, valid_shape)
525
+
526
+
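A sketch showing how apply_convolution_mode trims a full convolution, using scipy.signal.fftconvolve to produce the input (illustration only):

    import numpy as np
    from scipy.signal import fftconvolve
    from tme.matching_utils import apply_convolution_mode

    s1, s2 = (10, 10), (4, 4)
    full = fftconvolve(np.random.rand(*s1), np.random.rand(*s2), mode="full")

    same = apply_convolution_mode(full, convolution_mode="same", s1=s1, s2=s2)
    valid = apply_convolution_mode(full, convolution_mode="valid", s1=s1, s2=s2)
    print(same.shape, valid.shape)   # (10, 10) (6, 6)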
527
+ def compute_full_convolution_index(
528
+ outer_shape: Tuple[int],
529
+ inner_shape: Tuple[int],
530
+ outer_split: Tuple[slice],
531
+ inner_split: Tuple[slice],
532
+ ) -> Tuple[slice]:
533
+ """
534
+ Computes the position of the convolution of pieces in the full convolution.
535
+
536
+ Parameters
537
+ ----------
538
+ outer_shape : tuple
539
+ Tuple of integers corresponding to the shape of the outer array.
540
+ inner_shape : tuple
541
+ Tuple of integers corresponding to the shape of the inner array.
542
+ outer_split : tuple
543
+ Tuple of slices used to split outer array (see :py:meth:`split_shape`).
544
+ inner_split : tuple
545
+ Tuple of slices used to split inner array (see :py:meth:`split_shape`).
546
+
547
+ Returns
548
+ -------
549
+ tuple
550
+ Tuple of slices corresponding to the position of the given convolution
551
+ in the full convolution.
552
+ """
553
+ outer_shape = np.asarray(outer_shape)
554
+ inner_shape = np.asarray(inner_shape)
555
+
556
+ outer_width = np.array([outer.stop - outer.start for outer in outer_split])
557
+ inner_width = np.array([inner.stop - inner.start for inner in inner_split])
558
+ convolution_shape = outer_width + inner_width - 1
559
+
560
+ end_inner = np.array([inner.stop for inner in inner_split]).astype(int)
561
+ start_outer = np.array([outer.start for outer in outer_split]).astype(int)
562
+
563
+ offsets = start_outer + inner_shape - end_inner
564
+
565
+ score_slice = tuple(
566
+ (slice(offset, offset + shape))
567
+ for offset, shape in zip(offsets, convolution_shape)
568
+ )
569
+
570
+ return score_slice
571
+
572
+
573
+ def split_shape(
574
+ shape: Tuple[int], splits: Dict, equal_shape: bool = True
575
+ ) -> Tuple[slice]:
576
+ """
577
+ Splits ``shape`` into equally sized and potentially overlapping subsets.
578
+
579
+ Parameters
580
+ ----------
581
+ shape : tuple of ints
582
+ Shape to split.
583
+ splits : dict
584
+ Dictionary mapping axis number to number of splits.
585
+ equal_shape : dict
586
+ Whether the subsets should be of equal shape, True by default.
587
+
588
+ Returns
589
+ -------
590
+ tuple
591
+ Tuple of slice with requested split combinations.
592
+ """
593
+ ndim = len(shape)
594
+ splits = {k: max(splits.get(k, 1), 1) for k in range(ndim)}
595
+ ret_shape = np.divide(shape, tuple(splits[i] for i in range(ndim)))
596
+ if equal_shape:
597
+ ret_shape = np.ceil(ret_shape).astype(int)
598
+ ret_shape = ret_shape.astype(int)
599
+
600
+ slice_list = [
601
+ tuple(
602
+ (
603
+ (slice((n_splits * length), (n_splits + 1) * length))
604
+ if n_splits < splits.get(axis, 1) - 1
605
+ else (
606
+ (slice(shape[axis] - length, shape[axis]))
607
+ if equal_shape
608
+ else (slice((n_splits * length), shape[axis]))
609
+ )
610
+ )
611
+ for n_splits in range(splits.get(axis, 1))
612
+ )
613
+ for length, axis in zip(ret_shape, splits.keys())
614
+ ]
615
+
616
+ splits = tuple(product(*slice_list))
617
+
618
+ return splits
619
+
620
+
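A sketch of split_shape splitting a (100, 100) shape into two halves along axis 0 (illustration only):

    from tme.matching_utils import split_shape

    slices = split_shape(shape=(100, 100), splits={0: 2})
    for block in slices:
        print(block)
    # (slice(0, 50, None), slice(0, 100, None))
    # (slice(50, 100, None), slice(0, 100, None))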
621
+ def rigid_transform(
622
+ coordinates: NDArray,
623
+ rotation_matrix: NDArray,
624
+ out: NDArray,
625
+ translation: NDArray,
626
+ use_geometric_center: bool = False,
627
+ coordinates_mask: NDArray = None,
628
+ out_mask: NDArray = None,
629
+ center: NDArray = None,
630
+ ) -> None:
631
+ """
632
+ Apply a rigid transformation (rotation and translation) to given coordinates.
633
+
634
+ Parameters
635
+ ----------
636
+ coordinates : NDArray
637
+ An array representing the coordinates to be transformed (d,n).
638
+ rotation_matrix : NDArray
639
+ The rotation matrix to be applied (d,d).
640
+ translation : NDArray
641
+ The translation vector to be applied (d,).
642
+ out : NDArray
643
+ The output array to store the transformed coordinates (d,n).
644
+ coordinates_mask : NDArray, optional
645
+ An array representing the mask for the coordinates (d,t).
646
+ out_mask : NDArray, optional
647
+ The output array to store the transformed coordinates mask (d,t).
648
+ use_geometric_center : bool, optional
649
+ Whether to use geometric or coordinate center.
650
+ """
651
+ coordinate_dtype = coordinates.dtype
652
+ center = coordinates.mean(axis=1) if center is None else center
653
+ if not use_geometric_center:
654
+ coordinates = coordinates - center[:, None]
655
+
656
+ np.matmul(rotation_matrix, coordinates, out=out)
657
+ if use_geometric_center:
658
+ axis_max, axis_min = out.max(axis=1), out.min(axis=1)
659
+ axis_difference = axis_max - axis_min
660
+ translation = np.add(translation, center - axis_max + (axis_difference // 2))
661
+ else:
662
+ translation = np.add(translation, np.subtract(center, out.mean(axis=1)))
663
+
664
+ out += translation[:, None]
665
+ if coordinates_mask is not None and out_mask is not None:
666
+ if not use_geometric_center:
667
+ coordinates_mask = coordinates_mask - center[:, None]
668
+ np.matmul(rotation_matrix, coordinates_mask, out=out_mask)
669
+ out_mask += translation[:, None]
670
+
671
+ if not use_geometric_center and coordinate_dtype != out.dtype:
672
+ np.subtract(out.mean(axis=1), out.astype(int).mean(axis=1), out=translation)
673
+ out += translation[:, None]
674
+
675
+
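A sketch applying rigid_transform to a small point cloud; the rotation matrix is an arbitrary 90-degree rotation (illustration only):

    import numpy as np
    from tme.matching_utils import rigid_transform

    coordinates = np.random.rand(3, 5) * 10          # (d, n) layout
    rotation = np.array([[0.0, -1.0, 0.0],
                         [1.0,  0.0, 0.0],
                         [0.0,  0.0, 1.0]])
    out = np.empty_like(coordinates)
    rigid_transform(coordinates=coordinates, rotation_matrix=rotation,
                    out=out, translation=np.zeros(3))
    # The rotation is applied about the coordinate center, so the mean is preserved.
    print(coordinates.mean(axis=1), out.mean(axis=1))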
676
+ def minimum_enclosing_box(
677
+ coordinates: NDArray, margin: NDArray = None, use_geometric_center: bool = False
678
+ ) -> Tuple[int]:
679
+ """
680
+ Computes the minimal enclosing box around coordinates with margin.
681
+
682
+ Parameters
683
+ ----------
684
+ coordinates : NDArray
685
+ Coordinates of shape (d,n) to compute the enclosing box of.
686
+ margin : NDArray, optional
687
+ Box margin, zero by default.
688
+ use_geometric_center : bool, optional
689
+ Whether box accommodates the geometric or coordinate center, False by default.
690
+
691
+ Returns
692
+ -------
693
+ tuple of ints
694
+ Minimum enclosing box shape.
695
+ """
696
+ from .extensions import max_euclidean_distance
697
+
698
+ point_cloud = np.asarray(coordinates)
699
+ dim = point_cloud.shape[0]
700
+ point_cloud = point_cloud - point_cloud.min(axis=1)[:, None]
701
+
702
+ margin = np.zeros(dim) if margin is None else margin
703
+ margin = np.asarray(margin).astype(int)
704
+
705
+ norm_cloud = point_cloud - point_cloud.mean(axis=1)[:, None]
706
+ # Adding one avoids clipping during scipy.ndimage.affine_transform
707
+ shape = np.repeat(
708
+ np.ceil(2 * np.linalg.norm(norm_cloud, axis=0).max()) + 1, dim
709
+ ).astype(int)
710
+ if use_geometric_center:
711
+ hull = ConvexHull(point_cloud.T)
712
+ distance, _ = max_euclidean_distance(point_cloud[:, hull.vertices].T)
713
+ distance += np.linalg.norm(np.ones(dim))
714
+ shape = np.repeat(np.rint(distance).astype(int), dim)
715
+
716
+ return shape
717
+
718
+
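A usage sketch for minimum_enclosing_box with a random point cloud in (d, n) layout (illustration only; requires the compiled tme.extensions module shipped with the wheel):

    import numpy as np
    from tme.matching_utils import minimum_enclosing_box

    coordinates = np.random.rand(3, 100) * 20    # d = 3 dimensions, n = 100 points
    print(minimum_enclosing_box(coordinates))
    print(minimum_enclosing_box(coordinates, use_geometric_center=True))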
719
+ def create_mask(mask_type: str, sigma_decay: float = 0, **kwargs) -> NDArray:
720
+ """
721
+ Creates a mask of the specified type.
722
+
723
+ Parameters
724
+ ----------
725
+ mask_type : str
726
+ Type of the mask to be created. Can be one of:
727
+
728
+ +---------+----------------------------------------------------------+
729
+ | box | Box mask (see :py:meth:`box_mask`) |
730
+ +---------+----------------------------------------------------------+
731
+ | tube | Cylindrical mask (see :py:meth:`tube_mask`) |
732
+ +---------+----------------------------------------------------------+
733
+ | ellipse | Ellipsoidal mask (see :py:meth:`elliptical_mask`) |
734
+ +---------+----------------------------------------------------------+
735
+ sigma_decay : float, optional
736
+ Smoothing along mask edges using a Gaussian filter, 0 by default.
737
+ kwargs : dict
738
+ Parameters passed to the individual mask creation functions.
739
+
740
+ Returns
741
+ -------
742
+ NDArray
743
+ The created mask.
744
+
745
+ Raises
746
+ ------
747
+ ValueError
748
+ If the mask_type is invalid.
749
+ """
750
+ mapping = {"ellipse": elliptical_mask, "box": box_mask, "tube": tube_mask}
751
+ if mask_type not in mapping:
752
+ raise ValueError(f"mask_type has to be one of {','.join(mapping.keys())}")
753
+
754
+ mask = mapping[mask_type](**kwargs)
755
+ if sigma_decay > 0:
756
+ mask_filter = gaussian_filter(mask.astype(np.float32), sigma=sigma_decay)
757
+ mask = np.add(mask, (1 - mask) * mask_filter)
758
+ mask[mask < np.exp(-np.square(sigma_decay))] = 0
759
+
760
+ return mask
761
+
762
+
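A sketch creating a soft-edged spherical mask through create_mask; parameter values are arbitrary (illustration only):

    from tme.matching_utils import create_mask

    mask = create_mask(
        mask_type="ellipse",
        shape=(32, 32, 32),
        radius=8,
        sigma_decay=2.0,   # Gaussian smoothing of the mask edge
    )
    print(mask.shape, float(mask.max()))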
763
+ def elliptical_mask(
764
+ shape: Tuple[int],
765
+ radius: Tuple[float],
766
+ center: Optional[Tuple[float]] = None,
767
+ orientation: Optional[NDArray] = None,
768
+ ) -> NDArray:
769
+ """
770
+ Creates an ellipsoidal mask.
771
+
772
+ Parameters
773
+ ----------
774
+ shape : tuple of ints
775
+ Shape of the mask to be created.
776
+ radius : tuple of floats
777
+ Radius of the mask.
778
+ center : tuple of floats, optional
779
+ Center of the mask, defaults to shape // 2.
780
+ orientation : NDArray, optional.
781
+ Orientation of the mask as rotation matrix with shape (d,d).
782
+
783
+ Returns
784
+ -------
785
+ NDArray
786
+ The created ellipsoidal mask.
787
+
788
+ Raises
789
+ ------
790
+ ValueError
791
+ If the length of center and radius is not one or the same as shape.
792
+
793
+ Examples
794
+ --------
795
+ >>> from tme.matching_utils import elliptical_mask
796
+ >>> mask = elliptical_mask(shape=(20,20), radius=(5,5), center=(10,10))
797
+ """
798
+ shape, radius = np.asarray(shape), np.asarray(radius)
799
+
800
+ shape = shape.astype(int)
801
+ if center is None:
802
+ center = np.divide(shape, 2).astype(int)
803
+
804
+ center = np.asarray(center, dtype=np.float32)
805
+ radius = np.repeat(radius, shape.size // radius.size)
806
+ center = np.repeat(center, shape.size // center.size)
807
+ if radius.size != shape.size:
808
+ raise ValueError("Length of radius has to be either one or match shape.")
809
+ if center.size != shape.size:
810
+ raise ValueError("Length of center has to be either one or match shape.")
811
+
812
+ n = shape.size
813
+ center = center.reshape((-1,) + (1,) * n)
814
+ radius = radius.reshape((-1,) + (1,) * n)
815
+
816
+ indices = np.indices(shape, dtype=np.float32) - center
817
+ if orientation is not None:
818
+ return_shape = indices.shape
819
+ indices = indices.reshape(n, -1)
820
+ rigid_transform(
821
+ coordinates=indices,
822
+ rotation_matrix=np.asarray(orientation),
823
+ out=indices,
824
+ translation=np.zeros(n),
825
+ use_geometric_center=False,
826
+ )
827
+ indices = indices.reshape(*return_shape)
828
+
829
+ mask = np.linalg.norm(indices / radius, axis=0)
830
+ mask = (mask <= 1).astype(int)
831
+
832
+ return mask
833
+
834
+
835
+ def tube_mask2(
836
+ shape: Tuple[int],
837
+ inner_radius: float,
838
+ outer_radius: float,
839
+ height: int,
840
+ symmetry_axis: Optional[int] = 2,
841
+ center: Optional[Tuple[float]] = None,
842
+ orientation: Optional[NDArray] = None,
843
+ epsilon: float = 0.5,
844
+ ) -> NDArray:
845
+ """
846
+ Creates a tube mask.
847
+
848
+ Parameters
849
+ ----------
850
+ shape : tuple
851
+ Shape of the mask to be created.
852
+ inner_radius : float
853
+ Inner radius of the tube.
854
+ outer_radius : float
855
+ Outer radius of the tube.
856
+ height : int
857
+ Height of the tube.
858
+ symmetry_axis : int, optional
859
+ The axis of symmetry for the tube, defaults to 2.
860
+ center : tuple of float, optional.
861
+ Center of the mask, defaults to shape // 2.
862
+ orientation : NDArray, optional.
863
+ Orientation of the mask as rotation matrix with shape (d,d).
864
+ epsilon : float, optional
865
+ Tolerance to handle discretization errors, defaults to 0.5.
866
+
867
+ Returns
868
+ -------
869
+ NDArray
870
+ The created tube mask.
871
+
872
+ Raises
873
+ ------
874
+ ValueError
875
+ If ``inner_radius`` is larger than ``outer_radius``.
876
+ If ``center`` and ``shape`` do not have the same length.
877
+ """
878
+ shape = np.asarray(shape, dtype=int)
879
+
880
+ if center is None:
881
+ center = np.divide(shape, 2).astype(int)
882
+
883
+ center = np.asarray(center, dtype=np.float32)
884
+ center = np.repeat(center, shape.size // center.size)
885
+ if inner_radius > outer_radius:
886
+ raise ValueError("inner_radius should be smaller than outer_radius.")
887
+ if symmetry_axis > len(shape):
888
+ raise ValueError(f"symmetry_axis can be not larger than {len(shape)}.")
889
+ if center.size != shape.size:
890
+ raise ValueError("Length of center has to be either one or match shape.")
891
+
892
+ n = shape.size
893
+ center = center.reshape((-1,) + (1,) * n)
894
+ indices = np.indices(shape, dtype=np.float32) - center
895
+ if orientation is not None:
896
+ return_shape = indices.shape
897
+ indices = indices.reshape(n, -1)
898
+ rigid_transform(
899
+ coordinates=indices,
900
+ rotation_matrix=np.asarray(orientation),
901
+ out=indices,
902
+ translation=np.zeros(n),
903
+ use_geometric_center=False,
904
+ )
905
+ indices = indices.reshape(*return_shape)
906
+
907
+ mask = np.zeros(shape, dtype=bool)
908
+ sq_dist = np.zeros(shape)
909
+ for i in range(len(shape)):
910
+ if i == symmetry_axis:
911
+ continue
912
+ sq_dist += indices[i] ** 2
913
+
914
+ sym_coord = indices[symmetry_axis]
915
+ half_height = height / 2
916
+ height_mask = np.abs(sym_coord) <= half_height
917
+
918
+ inner_mask = 1
919
+ if inner_radius > epsilon:
920
+ inner_mask = sq_dist >= ((inner_radius) ** 2 - epsilon)
921
+
922
+ height_mask = np.abs(sym_coord) <= (half_height + epsilon)
923
+ outer_mask = sq_dist <= ((outer_radius) ** 2 + epsilon)
924
+
925
+ mask = height_mask & inner_mask & outer_mask
926
+ return mask
927
+
928
+
929
+ def box_mask(shape: Tuple[int], center: Tuple[int], height: Tuple[int]) -> np.ndarray:
930
+ """
931
+ Creates a box mask centered around the provided center point.
932
+
933
+ Parameters
934
+ ----------
935
+ shape : tuple of ints
936
+ Shape of the output array.
937
+ center : tuple of ints
938
+ Center point coordinates of the box.
939
+ height : tuple of ints
940
+ Height (side length) of the box along each axis.
941
+
942
+ Returns
943
+ -------
944
+ NDArray
945
+ The created box mask.
946
+
947
+ Raises
948
+ ------
949
+ ValueError
950
+ If ``shape`` and ``center`` do not have the same length.
951
+ If ``center`` and ``height`` do not have the same length.
952
+ """
953
+ if len(shape) != len(center) or len(center) != len(height):
954
+ raise ValueError("The length of shape, center, and height must be consistent.")
955
+
956
+ shape = tuple(int(x) for x in shape)
957
+ center, height = np.array(center, dtype=int), np.array(height, dtype=int)
958
+
959
+ half_heights = height // 2
960
+ starts = np.maximum(center - half_heights, 0)
961
+ stops = np.minimum(center + half_heights + np.mod(height, 2) + 1, shape)
962
+ slice_indices = tuple(slice(*coord) for coord in zip(starts, stops))
963
+
964
+ out = np.zeros(shape)
965
+ out[slice_indices] = 1
966
+ return out
967
+
968
+
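A sketch for box_mask (illustration only):

    from tme.matching_utils import box_mask

    mask = box_mask(shape=(10, 10), center=(5, 5), height=(4, 4))
    print(mask.sum())   # number of array elements inside the box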
969
+ def tube_mask(
970
+ shape: Tuple[int],
971
+ symmetry_axis: int,
972
+ base_center: Tuple[int],
973
+ inner_radius: float,
974
+ outer_radius: float,
975
+ height: int,
976
+ ) -> NDArray:
977
+ """
978
+ Creates a tube mask.
979
+
980
+ Parameters
981
+ ----------
982
+ shape : tuple
983
+ Shape of the mask to be created.
984
+ symmetry_axis : int
985
+ The axis of symmetry for the tube.
986
+ base_center : tuple
987
+ Center of the tube.
988
+ inner_radius : float
989
+ Inner radius of the tube.
990
+ outer_radius : float
991
+ Outer radius of the tube.
992
+ height : int
993
+ Height of the tube.
994
+
995
+ Returns
996
+ -------
997
+ NDArray
998
+ The created tube mask.
999
+
1000
+ Raises
1001
+ ------
1002
+ ValueError
1003
+ If ``inner_radius`` is larger than ``outer_radius``.
1004
+ If ``height`` is larger than the symmetry axis.
1005
+ If ``base_center`` and ``shape`` do not have the same length.
1006
+ """
1007
+ if inner_radius > outer_radius:
1008
+ raise ValueError("inner_radius should be smaller than outer_radius.")
1009
+
1010
+ if height > shape[symmetry_axis]:
1011
+ raise ValueError(f"Height can be no larger than {shape[symmetry_axis]}.")
1012
+
1013
+ if symmetry_axis > len(shape):
1014
+ raise ValueError(f"symmetry_axis can be not larger than {len(shape)}.")
1015
+
1016
+ if len(base_center) != len(shape):
1017
+ raise ValueError("shape and base_center need to have the same length.")
1018
+
1019
+ shape = tuple(int(x) for x in shape)
1020
+ circle_shape = tuple(b for ix, b in enumerate(shape) if ix != symmetry_axis)
1021
+ circle_center = tuple(b for ix, b in enumerate(base_center) if ix != symmetry_axis)
1022
+
1023
+ inner_circle = np.zeros(circle_shape)
1024
+ outer_circle = np.zeros_like(inner_circle)
1025
+ if inner_radius > 0:
1026
+ inner_circle = create_mask(
1027
+ mask_type="ellipse",
1028
+ shape=circle_shape,
1029
+ radius=inner_radius,
1030
+ center=circle_center,
1031
+ )
1032
+ if outer_radius > 0:
1033
+ outer_circle = create_mask(
1034
+ mask_type="ellipse",
1035
+ shape=circle_shape,
1036
+ radius=outer_radius,
1037
+ center=circle_center,
1038
+ )
1039
+ circle = outer_circle - inner_circle
1040
+ circle = np.expand_dims(circle, axis=symmetry_axis)
1041
+
1042
+ center = base_center[symmetry_axis]
1043
+ start_idx = int(center - height // 2)
1044
+ stop_idx = int(center + height // 2 + height % 2)
1045
+
1046
+ start_idx, stop_idx = max(start_idx, 0), min(stop_idx, shape[symmetry_axis])
1047
+
1048
+ slice_indices = tuple(
1049
+ slice(None) if i != symmetry_axis else slice(start_idx, stop_idx)
1050
+ for i in range(len(shape))
1051
+ )
1052
+ tube = np.zeros(shape)
1053
+ tube[slice_indices] = circle
1054
+
1055
+ return tube
1056
+
1057
+
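A sketch for tube_mask creating a hollow cylinder along the last axis (illustration only):

    from tme.matching_utils import tube_mask

    mask = tube_mask(
        shape=(32, 32, 32),
        symmetry_axis=2,
        base_center=(16, 16, 16),
        inner_radius=4,
        outer_radius=8,
        height=20,
    )
    print(mask.shape, mask.sum())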
1058
+ def scramble_phases(
1059
+ arr: NDArray,
1060
+ noise_proportion: float = 0.5,
1061
+ seed: int = 42,
1062
+ normalize_power: bool = False,
1063
+ ) -> NDArray:
1064
+ """
1065
+ Perform random phase scrambling of ``arr``.
1066
+
1067
+ Parameters
1068
+ ----------
1069
+ arr : NDArray
1070
+ Input data.
1071
+ noise_proportion : float, optional
1072
+ Proportion of scrambled phases, 0.5 by default.
1073
+ seed : int, optional
1074
+ The seed for the random phase scrambling, 42 by default.
1075
+ normalize_power : bool, optional
1076
+ Return value has same sum of squares as ``arr``.
1077
+
1078
+ Returns
1079
+ -------
1080
+ NDArray
1081
+ Phase scrambled version of ``arr``.
1082
+ """
1083
+ np.random.seed(seed)
1084
+ noise_proportion = max(min(noise_proportion, 1), 0)
1085
+
1086
+ arr_fft = np.fft.fftn(arr)
1087
+ amp, ph = np.abs(arr_fft), np.angle(arr_fft)
1088
+
1089
+ ph_noise = np.random.permutation(ph)
1090
+ ph_new = ph * (1 - noise_proportion) + ph_noise * noise_proportion
1091
+ ret = np.real(np.fft.ifftn(amp * np.exp(1j * ph_new)))
1092
+
1093
+ if normalize_power:
1094
+ np.divide(ret - ret.min(), ret.max() - ret.min(), out=ret)
1095
+ np.multiply(ret, np.subtract(arr.max(), arr.min()), out=ret)
1096
+ np.add(ret, arr.min(), out=ret)
1097
+ scaling = np.divide(np.abs(arr).sum(), np.abs(ret).sum())
1098
+ np.multiply(ret, scaling, out=ret)
1099
+
1100
+ return ret
1101
+
1102
+
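A sketch using scramble_phases to build a noise reference from an input map (illustration only):

    import numpy as np
    from tme.matching_utils import scramble_phases

    arr = np.random.rand(16, 16).astype(np.float32)
    scrambled = scramble_phases(arr, noise_proportion=0.5, normalize_power=True)
    print(arr.shape == scrambled.shape)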
1103
+ def compute_extraction_box(
1104
+ centers: BackendArray, extraction_shape: Tuple[int], original_shape: Tuple[int]
1105
+ ):
1106
+ """Compute coordinates for extracting fixed-size regions around points.
1107
+
1108
+ Parameters
1109
+ ----------
1110
+ centers : BackendArray
1111
+ Array of shape (n, d) containing n center coordinates in d dimensions.
1112
+ extraction_shape : tuple of int
1113
+ Desired shape of the extraction box.
1114
+ original_shape : tuple of int
1115
+ Shape of the original array from which extractions will be made.
1116
+
1117
+ Returns
1118
+ -------
1119
+ obs_beg : BackendArray
1120
+ Starting coordinates for extraction, shape (n, d).
1121
+ obs_end : BackendArray
1122
+ Ending coordinates for extraction, shape (n, d).
1123
+ cand_beg : BackendArray
1124
+ Starting coordinates in output array, shape (n, d).
1125
+ cand_end : BackendArray
1126
+ Ending coordinates in output array, shape (n, d).
1127
+ keep : BackendArray
1128
+ Boolean mask of valid extraction boxes, shape (n,).
1129
+ """
1130
+ target_shape = be.to_backend_array(original_shape)
1131
+ extraction_shape = be.to_backend_array(extraction_shape)
1132
+
1133
+ left_pad = be.astype(be.divide(extraction_shape, 2), int)
1134
+ right_pad = be.astype(be.add(left_pad, be.mod(extraction_shape, 2)), int)
1135
+
1136
+ obs_beg = be.subtract(centers, left_pad)
1137
+ obs_end = be.add(centers, right_pad)
1138
+
1139
+ obs_beg_clamp = be.maximum(obs_beg, 0)
1140
+ obs_end_clamp = be.minimum(obs_end, target_shape)
1141
+
1142
+ clamp_change = be.sum(
1143
+ be.add(obs_beg != obs_beg_clamp, obs_end != obs_end_clamp), axis=1
1144
+ )
1145
+
1146
+ cand_beg = left_pad - be.subtract(centers, obs_beg_clamp)
1147
+ cand_end = left_pad + be.subtract(obs_end_clamp, centers)
1148
+
1149
+ stops = be.subtract(cand_end, extraction_shape)
1150
+ keep = be.sum(be.multiply(cand_beg == 0, stops == 0), axis=1) == centers.shape[1]
1151
+ keep = be.multiply(keep, clamp_change == 0)
1152
+
1153
+ return obs_beg_clamp, obs_end_clamp, cand_beg, cand_end, keep
1154
+
1155
+
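A sketch for compute_extraction_box, assuming the default NumPy-based backend so plain numpy arrays are accepted (illustration only):

    import numpy as np
    from tme.matching_utils import compute_extraction_box

    centers = np.array([[10, 10, 10], [2, 30, 15]])
    obs_beg, obs_end, cand_beg, cand_end, keep = compute_extraction_box(
        centers, extraction_shape=(8, 8, 8), original_shape=(32, 32, 32)
    )
    print(keep)   # the second box hits the volume boundary and is flagged invalid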
1156
+ class TqdmParallel(Parallel):
1157
+ """
1158
+ A minimal Parallel implementation using tqdm for progress reporting.
1159
+
1160
+ Parameters
1161
+ ----------
1162
+ tqdm_args : dict, optional
1163
+ Dictionary of arguments passed to tqdm.tqdm
1164
+ *args, **kwargs:
1165
+ Arguments to pass to joblib.Parallel
1166
+ """
1167
+
1168
+ def __init__(self, tqdm_args: Dict = {}, *args, **kwargs):
1169
+ super().__init__(*args, **kwargs)
1170
+ self.pbar = tqdm(**tqdm_args)
1171
+
1172
+ def __call__(self, iterable, *args, **kwargs):
1173
+ self.n_tasks = len(iterable) if hasattr(iterable, "__len__") else None
1174
+ return super().__call__(iterable, *args, **kwargs)
1175
+
1176
+ def print_progress(self):
1177
+ if self.n_tasks is None:
1178
+ return super().print_progress()
1179
+
1180
+ if self.n_tasks != self.pbar.total:
1181
+ self.pbar.total = self.n_tasks
1182
+ self.pbar.refresh()
1183
+
1184
+ self.pbar.n = self.n_completed_tasks
1185
+ self.pbar.refresh()
1186
+
1187
+ if self.n_completed_tasks >= self.n_tasks:
1188
+ self.pbar.close()
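A sketch for TqdmParallel, mirroring standard joblib.Parallel usage with a progress bar (illustration only):

    from math import sqrt
    from joblib import delayed
    from tme.matching_utils import TqdmParallel

    tasks = [delayed(sqrt)(i) for i in range(100)]
    results = TqdmParallel(tqdm_args={"desc": "sqrt"}, n_jobs=2)(tasks)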