pytme 0.1.5__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. pytme-0.1.5.data/scripts/estimate_ram_usage.py +81 -0
  2. pytme-0.1.5.data/scripts/match_template.py +744 -0
  3. pytme-0.1.5.data/scripts/postprocess.py +279 -0
  4. pytme-0.1.5.data/scripts/preprocess.py +93 -0
  5. pytme-0.1.5.data/scripts/preprocessor_gui.py +729 -0
  6. pytme-0.1.5.dist-info/LICENSE +153 -0
  7. pytme-0.1.5.dist-info/METADATA +69 -0
  8. pytme-0.1.5.dist-info/RECORD +63 -0
  9. pytme-0.1.5.dist-info/WHEEL +5 -0
  10. pytme-0.1.5.dist-info/entry_points.txt +6 -0
  11. pytme-0.1.5.dist-info/top_level.txt +2 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +81 -0
  14. scripts/match_template.py +744 -0
  15. scripts/match_template_devel.py +788 -0
  16. scripts/postprocess.py +279 -0
  17. scripts/preprocess.py +93 -0
  18. scripts/preprocessor_gui.py +729 -0
  19. tme/__init__.py +6 -0
  20. tme/__version__.py +1 -0
  21. tme/analyzer.py +1144 -0
  22. tme/backends/__init__.py +134 -0
  23. tme/backends/cupy_backend.py +309 -0
  24. tme/backends/matching_backend.py +1154 -0
  25. tme/backends/npfftw_backend.py +763 -0
  26. tme/backends/pytorch_backend.py +526 -0
  27. tme/data/__init__.py +0 -0
  28. tme/data/c48n309.npy +0 -0
  29. tme/data/c48n527.npy +0 -0
  30. tme/data/c48n9.npy +0 -0
  31. tme/data/c48u1.npy +0 -0
  32. tme/data/c48u1153.npy +0 -0
  33. tme/data/c48u1201.npy +0 -0
  34. tme/data/c48u1641.npy +0 -0
  35. tme/data/c48u181.npy +0 -0
  36. tme/data/c48u2219.npy +0 -0
  37. tme/data/c48u27.npy +0 -0
  38. tme/data/c48u2947.npy +0 -0
  39. tme/data/c48u3733.npy +0 -0
  40. tme/data/c48u4749.npy +0 -0
  41. tme/data/c48u5879.npy +0 -0
  42. tme/data/c48u7111.npy +0 -0
  43. tme/data/c48u815.npy +0 -0
  44. tme/data/c48u83.npy +0 -0
  45. tme/data/c48u8649.npy +0 -0
  46. tme/data/c600v.npy +0 -0
  47. tme/data/c600vc.npy +0 -0
  48. tme/data/metadata.yaml +80 -0
  49. tme/data/quat_to_numpy.py +42 -0
  50. tme/data/scattering_factors.pickle +0 -0
  51. tme/density.py +2314 -0
  52. tme/extensions.cpython-311-darwin.so +0 -0
  53. tme/helpers.py +881 -0
  54. tme/matching_data.py +377 -0
  55. tme/matching_exhaustive.py +1553 -0
  56. tme/matching_memory.py +382 -0
  57. tme/matching_optimization.py +1123 -0
  58. tme/matching_utils.py +1180 -0
  59. tme/parser.py +429 -0
  60. tme/preprocessor.py +1291 -0
  61. tme/scoring.py +866 -0
  62. tme/structure.py +1428 -0
  63. tme/types.py +10 -0
tme/matching_utils.py ADDED
@@ -0,0 +1,1180 @@
1
+ """ Utility functions for template matching.
2
+
3
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+ import os
8
+ import traceback
9
+ import pickle
10
+ from shutil import move
11
+ from tempfile import mkstemp
12
+ from itertools import product
13
+ from typing import Tuple, Dict, Callable
14
+ from concurrent.futures import ThreadPoolExecutor
15
+
16
+ import numpy as np
17
+ from numpy.typing import NDArray
18
+ from scipy.spatial import ConvexHull
19
+ from scipy.spatial.transform import Rotation
20
+
21
+ from .helpers import quaternion_to_rotation_matrix, load_quaternions_by_angle
22
+ from .extensions import max_euclidean_distance
23
+ from .matching_memory import estimate_ram_usage
24
+
25
+
26
def handle_traceback(last_type, last_value, last_traceback):
    """
    Print a stored traceback and re-raise the associated exception value.

    Intended to be fed with the three elements of ``sys.exc_info()``.

    Parameters
    ----------
    last_type : type
        The type of the last exception. If None, nothing happens.
    last_value :
        The value of the last exception.
    last_traceback : traceback
        The traceback object encapsulating the call stack at the point
        where the exception originally occurred.

    Raises
    ------
    Exception
        A generic Exception wrapping ``last_value`` whenever ``last_type``
        is not None. The original exception type is not preserved.
    """
    if last_type is not None:
        traceback.print_tb(last_traceback)
        # The value is wrapped in a plain Exception rather than last_type.
        raise Exception(last_value)
    return None
50
+
51
+
52
def generate_tempfile_name(suffix=None):
    """
    Create an empty temporary file and return its path.

    The directory is taken from the environment variable TME_TMPDIR if it is
    defined, otherwise from TMPDIR. If neither is defined the default tmp
    directory is used.

    Parameters
    ----------
    suffix : str, optional
        File suffix. By default the file has no suffix.

    Returns
    -------
    str
        The generated filename. The file exists on disk and is empty; the
        caller is responsible for removing it.
    """
    tmp_dir = os.environ.get("TME_TMPDIR", os.environ.get("TMPDIR", None))
    file_descriptor, filename = mkstemp(suffix=suffix, dir=tmp_dir)
    # mkstemp hands back an open OS-level descriptor. Only the path is
    # needed here, so close it to avoid leaking one descriptor per call.
    os.close(file_descriptor)
    return filename
71
+
72
+
73
def array_to_memmap(arr: NDArray, filename: str = None) -> str:
    """
    Write a numpy array to disk as a np.memmap and return the file path.

    Parameters
    ----------
    arr : np.ndarray
        The numpy array to be written.
    filename : str, optional
        Desired filename for the memmap. If not provided, a temporary
        file is obtained from :py:meth:`generate_tempfile_name`.

    Returns
    -------
    str
        The filename where the memmap was written to.
    """
    target = generate_tempfile_name() if filename is None else filename

    # The memmap mirrors the shape and dtype of the input array.
    on_disk = np.memmap(target, mode="w+", dtype=arr.dtype, shape=arr.shape)
    on_disk[:] = arr[:]
    on_disk.flush()

    return target
106
+
107
+
108
def memmap_to_array(arr: NDArray) -> NDArray:
    """
    Load a np.memmap into memory and delete its backing file.

    Parameters
    ----------
    arr : np.memmap
        The array to be converted. Objects that are not np.memmap
        instances are returned unchanged.

    Returns
    -------
    np.ndarray
        In-memory copy of the memmap, or ``arr`` itself if it was not
        a np.memmap.
    """
    # isinstance also covers subclasses of np.memmap.
    if isinstance(arr, np.memmap):
        memmap_filepath = arr.filename
        arr = np.array(arr)
        # filename can be None for memmaps without a named backing file.
        if memmap_filepath is not None:
            os.remove(memmap_filepath)
    return arr
127
+
128
+
129
def close_memmap(arr: np.ndarray) -> None:
    """
    Delete the file backing a numpy memmap array, if there is one.

    This is a best-effort operation: any error raised while looking up or
    removing the file is silently ignored.

    Parameters
    ----------
    arr : np.ndarray
        The numpy array which might be a memmap.
    """
    try:
        backing_file = arr.filename
        os.remove(backing_file)
    except Exception:
        # Objects without a filename or files that cannot be removed
        # are ignored on purpose.
        return None
    return None
143
+
144
+
145
def write_pickle(data: object, filename: str) -> None:
    """
    Serialize and write data to a file invalidating the input data in
    the process. np.memmap items are not pickled; instead their backing
    file is moved next to ``filename`` and a descriptive tuple
    ``("np.memmap", shape, dtype, new_filename)`` is pickled in their place.
    Other objects are serialized using standard pickle.

    Parameters
    ----------
    data : iterable or object
        The data to be serialized. Anything that is not a list or tuple
        is treated as a single item.
    filename : str
        The name of the file where the serialized data will be written.

    Raises
    ------
    Exception
        Any error raised while moving a memmap backing file.

    See Also
    --------
    :py:meth:`load_pickle`
    """
    if type(data) not in (list, tuple):
        data = (data,)

    # Resolve the directory so the stored memmap paths are absolute and do
    # not depend on the working directory at load time.
    dirname = os.path.dirname(os.path.abspath(filename))

    # Collected across all items so every move is awaited and its errors
    # surface; also defined when data is empty.
    futures = []
    with open(filename, "wb") as ofile, ThreadPoolExecutor() as executor:
        for item in data:
            if isinstance(item, np.memmap):
                file_descriptor, new_filename = mkstemp(suffix=".mm", dir=dirname)
                # Only the reserved path is needed, not the open descriptor.
                os.close(file_descriptor)
                futures.append(executor.submit(move, item.filename, new_filename))
                item = ("np.memmap", item.shape, item.dtype, new_filename)
            pickle.dump(item, ofile)
        for future in futures:
            future.result()
179
+
180
+
181
def load_pickle(filename: str) -> object:
    """
    Load and deserialize data written by :py:meth:`write_pickle`.

    Tuples of the form ``("np.memmap", shape, dtype, path)`` are replaced by
    a np.memmap opened on ``path``. All other objects are returned as
    unpickled.

    Parameters
    ----------
    filename : str
        The name of the file to read and deserialize data from.

    Returns
    -------
    object or iterable
        The deserialized data. A single stored item is returned as is,
        otherwise a list of all items is returned.

    Notes
    -----
    This function uses :py:mod:`pickle` and must only be used on files
    from trusted sources.

    See Also
    --------
    :py:meth:`write_pickle`
    """

    def _load_pickle(file_handle):
        # Yield consecutive pickled objects until the file is exhausted.
        while True:
            try:
                yield pickle.load(file_handle)
            except EOFError:
                return

    def _is_pickle_memmap(data):
        # Length and type are checked first so empty tuples or tuples
        # holding arrays do not raise.
        return (
            isinstance(data, tuple)
            and len(data) == 4
            and isinstance(data[0], str)
            and data[0] == "np.memmap"
        )

    items = []
    with open(filename, "rb") as ifile:
        for data in _load_pickle(ifile):
            if _is_pickle_memmap(data):
                _, shape, dtype, memmap_filename = data
                data = np.memmap(memmap_filename, shape=shape, dtype=dtype)
            items.append(data)
    return items[0] if len(items) == 1 else items
223
+
224
+
225
def compute_parallelization_schedule(
    shape1: NDArray,
    shape2: NDArray,
    max_cores: int,
    max_ram: int,
    matching_method: str,
    backend: str = None,
    split_only_outer: bool = False,
    shape1_padding: NDArray = None,
    analyzer_method: str = None,
    max_splits: int = 256,
    float_nbytes: int = 4,
    complex_nbytes: int = 8,
    integer_nbytes: int = 4,
) -> Tuple[Dict, Tuple[int, int]]:
    """
    Computes a parallelization schedule for a given computation.

    This function estimates the amount of memory that would be used by a computation
    and breaks down the computation into smaller parts that can be executed in parallel
    without exceeding the specified limits on the number of cores and memory.

    Parameters
    ----------
    shape1 : NDArray
        The shape of the first input tensor.
    shape2 : NDArray
        The shape of the second input tensor.
    max_cores : int
        The maximum number of cores that can be used.
    max_ram : int
        The maximum amount of memory that can be used.
    matching_method : str
        The metric used for scoring the computations.
    backend : str, optional
        Backend used for computations.
    split_only_outer : bool, optional
        Whether only outer splits should be considered. If True, the only
        core assignment evaluated is one inner core and ``max_cores``
        outer cores.
    shape1_padding : NDArray, optional
        Padding added to each split of shape1 for the memory estimate.
        None by default, which corresponds to no padding.
    analyzer_method : str
        The method used for score analysis.
    max_splits : int, optional
        The maximum number of parts that the computation can be split into,
        by default 256.
    float_nbytes : int
        Number of bytes of the used float, e.g. 4 for float32.
    complex_nbytes : int
        Number of bytes of the used complex, e.g. 8 for complex64.
    integer_nbytes : int
        Number of bytes of the used integer, e.g. 4 for int32.

    Notes
    -----
    This function assumes that no residual memory remains after each split,
    which not always holds true, e.g. when using
    :py:class:`tme.analyzer.MaxScoreOverRotations`.

    Returns
    -------
    dict
        The chosen number of splits for each axis of the first input tensor.
        None if no suitable assignment was found.
    tuple
        Tuple of (outer jobs, inner jobs per outer job). None if no suitable
        assignment was found.
    """
    shape1, shape2 = np.array(shape1), np.array(shape2)
    if shape1_padding is None:
        shape1_padding = np.zeros_like(shape1)
    # Enumerate factor pairs of max_cores, each unpacked below as
    # (inner_cores, outer_cores). Both orderings are added.
    core_assignments = []
    for i in range(1, int(max_cores**0.5) + 1):
        if max_cores % i == 0:
            core_assignments.append((i, max_cores // i))
            core_assignments.append((max_cores // i, i))

    if split_only_outer:
        core_assignments = [(1, max_cores)]

    # Splitting starts along the largest axis of shape1 and cycles through
    # the axes, adding one split per iteration.
    possible_params, split_axis = [], np.argmax(shape1)
    split_factor, n_splits = [1 for _ in range(len(shape1))], 0
    while n_splits <= max_splits:
        splits = {k: split_factor[k] for k in range(len(split_factor))}
        array_slices = split_numpy_array_slices(shape=shape1, splits=splits)
        array_widths = [
            tuple(x.stop - x.start for x in split) for split in array_slices
        ]
        n_splits = np.prod(list(splits.values()))

        for inner_cores, outer_cores in core_assignments:
            # More outer jobs than splits would leave outer jobs without work.
            if outer_cores > n_splits:
                continue
            ram_usage = [
                estimate_ram_usage(
                    shape1=np.add(shp, shape1_padding),
                    shape2=shape2,
                    matching_method=matching_method,
                    analyzer_method=analyzer_method,
                    backend=backend,
                    ncores=inner_cores,
                    float_nbytes=float_nbytes,
                    complex_nbytes=complex_nbytes,
                    integer_nbytes=integer_nbytes,
                )
                for shp in array_widths
            ]
            # Peak usage is the largest summed estimate over groups of
            # outer_cores consecutive splits.
            max_usage = 0
            for i in range(0, len(ram_usage), outer_cores):
                usage = np.sum(ram_usage[i : (i + outer_cores)])
                if usage > max_usage:
                    max_usage = usage

            inits = n_splits // outer_cores
            if max_usage < max_ram:
                possible_params.append(
                    (*split_factor, outer_cores, inner_cores, n_splits, inits)
                )
        split_factor[split_axis] += 1
        split_axis += 1
        if split_axis == shape1.size:
            split_axis = 0

    possible_params = np.array(possible_params)
    if not len(possible_params):
        print(
            "No suitable assignment found. Consider increasing "
            "max_ram or decrease max_cores."
        )
        return None, None

    # Column layout: split factors, outer_cores, inner_cores, n_splits, inits.
    # lexsort uses the last key as primary, i.e. sort by n_splits, then inits.
    init = possible_params.shape[1] - 1
    possible_params = possible_params[
        np.lexsort((possible_params[:, init], possible_params[:, (init - 1)]))
    ]
    splits = {k: possible_params[0, k] for k in range(shape1.size)}
    # (outer_cores, inner_cores) of the best candidate.
    core_assignment = (
        possible_params[0, shape1.size],
        possible_params[0, (shape1.size + 1)],
    )

    return splits, core_assignment
366
+
367
+
368
def centered(arr: NDArray, newshape: Tuple[int]) -> NDArray:
    """
    Return the central region of an array with a given shape.

    Parameters
    ----------
    arr : NDArray
        Input array.
    newshape : tuple
        Desired shape for the central portion.

    Returns
    -------
    NDArray
        Central portion of the array with shape `newshape`.

    References
    ----------
    .. [1] https://github.com/scipy/scipy/blob/v1.11.2/scipy/signal/_signaltools.py#L388
    """
    target_shape = np.asarray(newshape)
    # Integer division places the window start at or left of the exact center.
    offsets = (np.array(arr.shape) - target_shape) // 2
    window = tuple(
        slice(begin, end) for begin, end in zip(offsets, offsets + newshape)
    )
    return arr[window]
394
+
395
+
396
def centered_mask(arr: NDArray, newshape: Tuple[int]) -> NDArray:
    """
    Zero out everything outside the central region of an array.

    The input array is modified in place and also returned.

    Parameters
    ----------
    arr : NDArray
        Input array.
    newshape : tuple
        Shape of the central region that is kept.

    Returns
    -------
    NDArray
        Array with central portion unmasked and the rest set to 0.
    """
    target_shape = np.asarray(newshape)
    offsets = (np.array(arr.shape) - target_shape) // 2
    window = tuple(slice(lo, hi) for lo, hi in zip(offsets, offsets + newshape))

    # Ones inside the central window, zeros elsewhere.
    keep = np.zeros_like(arr)
    keep[window] = 1

    arr *= keep
    return arr
421
+
422
+
423
def apply_convolution_mode(
    arr: NDArray,
    convolution_mode: str,
    s1: Tuple[int],
    s2: Tuple[int],
    mask_output: bool = False,
) -> NDArray:
    """
    Restrict a convolution result to the requested convolution mode.

    Parameters
    ----------
    arr : NDArray
        Numpy array containing convolution result of arrays with shape s1 and s2.
    convolution_mode : str
        Analogous to mode in ``scipy.signal.convolve``:

        +---------+----------------------------------------------------------+
        | 'full'  | returns full template matching result of the inputs.     |
        +---------+----------------------------------------------------------+
        | 'valid' | returns elements that do not rely on zero-padding..      |
        +---------+----------------------------------------------------------+
        | 'same'  | output is the same size as s1.                           |
        +---------+----------------------------------------------------------+
    s1 : tuple
        Tuple of integers corresponding to shape of convolution array 1.
    s2 : tuple
        Tuple of integers corresponding to shape of convolution array 2.
    mask_output : bool, optional
        Whether to mask values outside of convolution_mode rather than
        removing them. Defaults to False.

    Returns
    -------
    NDArray
        The numpy array after applying the convolution mode.

    Raises
    ------
    ValueError
        If convolution_mode is not one of 'full', 'same' or 'valid'.

    References
    ----------
    .. [1] https://github.com/scipy/scipy/blob/v1.11.2/scipy/signal/_signaltools.py#L519
    """
    # Strip the padding to the next fast fourier length, leaving s1 + s2 - 1.
    full_box = tuple(slice(s1[axis] + s2[axis] - 1) for axis in range(len(s1)))
    arr = arr[full_box]

    if convolution_mode not in ("full", "same", "valid"):
        raise ValueError("Supported convolution_mode are 'full', 'same' and 'valid'.")

    # Either zero out or cut away everything outside the requested region.
    restrict = centered_mask if mask_output else centered

    if convolution_mode == "full":
        return arr
    if convolution_mode == "same":
        return restrict(arr, s1)

    valid_shape = [s1[axis] - s2[axis] + s2[axis] % 2 for axis in range(arr.ndim)]
    return restrict(arr, valid_shape)
478
+
479
+
480
def compute_full_convolution_index(
    outer_shape: Tuple[int],
    inner_shape: Tuple[int],
    outer_split: Tuple[slice],
    inner_split: Tuple[slice],
) -> Tuple[slice]:
    """
    Locate the convolution of two array pieces within the full convolution.

    Parameters
    ----------
    outer_shape : tuple
        Tuple of integers corresponding to the shape of the outer array.
    inner_shape : tuple
        Tuple of integers corresponding to the shape of the inner array.
    outer_split : tuple
        Tuple of slices used to split outer array
        (see :py:meth:`split_numpy_array_slices`).
    inner_split : tuple
        Tuple of slices used to split inner array
        (see :py:meth:`split_numpy_array_slices`).

    Returns
    -------
    tuple
        Tuple of slices corresponding to the position of the given convolution
        in the full convolution.
    """
    outer_shape = np.asarray(outer_shape)
    inner_shape = np.asarray(inner_shape)

    outer_start = np.array([piece.start for piece in outer_split]).astype(int)
    inner_stop = np.array([piece.stop for piece in inner_split]).astype(int)

    outer_extent = np.array([piece.stop - piece.start for piece in outer_split])
    inner_extent = np.array([piece.stop - piece.start for piece in inner_split])
    # Full convolution of the two pieces spans width1 + width2 - 1 elements.
    piece_shape = outer_extent + inner_extent - 1

    first_index = outer_start + inner_shape - inner_stop

    return tuple(
        slice(begin, begin + extent)
        for begin, extent in zip(first_index, piece_shape)
    )
526
+
527
+
528
def split_numpy_array_slices(
    shape: NDArray, splits: Dict, margin: NDArray = None
) -> Tuple[slice]:
    """
    Build the slices that cut an array into pieces along multiple axes.

    Parameters
    ----------
    shape : NDArray
        Shape of the array to split.
    splits : dict
        A dictionary where the keys are the axis numbers and the values
        are the number of splits along that axis. Missing axes and values
        below one are treated as a single split.
    margin : NDArray, optional
        Padding on the left hand side of each piece, clipped at zero.

    Returns
    -------
    tuple
        A tuple with one entry per piece, each entry being a tuple of slices
        with one slice per axis.
    """
    ndim = len(shape)
    if margin is None:
        margin = np.zeros(ndim, dtype=int)

    splits = {axis: max(splits.get(axis, 0), 1) for axis in range(ndim)}
    piece_lengths = np.divide(shape, [splits[axis] for axis in range(ndim)])
    piece_lengths = piece_lengths.astype(int)

    per_axis = []
    for axis, length in enumerate(piece_lengths):
        n_pieces = splits[axis]
        axis_slices = []
        for index in range(n_pieces):
            start = max((index * length) - margin[axis], 0)
            # The last piece absorbs the remainder of the axis.
            if index < n_pieces - 1:
                stop = (index + 1) * length
            else:
                stop = shape[axis]
            axis_slices.append(slice(start, stop))
        per_axis.append(tuple(axis_slices))

    return tuple(product(*per_axis))
568
+
569
+
570
def get_rotation_matrices(
    angular_sampling: float, dim: int = 3, use_optimized_set: bool = True
) -> NDArray:
    """
    Returns rotation matrices in format k x dim x dim, where k is determined
    by ``angular_sampling``.

    Parameters
    ----------
    angular_sampling : float
        The angle in degrees used for the generation of rotation matrices.
    dim : int, optional
        Dimension of the rotation matrices.
    use_optimized_set : bool, optional
        Whether to use pre-computed rotational sets with more optimal sampling.
        Currently only available when dim=3.

    Notes
    -----
    For dim = 3 with ``use_optimized_set`` the matrices are derived from
    pre-computed quaternions. Otherwise they are obtained from the
    QR-decomposition of normally distributed random matrices, so the result
    differs between calls unless the numpy random state is fixed.

    Returns
    -------
    NDArray
        Array of shape (k, dim, dim) containing k rotation matrices.
    """
    if use_optimized_set and dim == 3:
        quaternions, *_ = load_quaternions_by_angle(angular_sampling)
        return quaternion_to_rotation_matrix(quaternions)

    degrees_of_freedom = dim * (dim - 1) // 2
    n_matrices = int((360 / angular_sampling) ** degrees_of_freedom)

    random_matrices = np.random.randn(n_matrices, dim, dim)
    rotations, _ = np.linalg.qr(random_matrices)

    # Flip the last column of matrices with negative determinant so that
    # all returned matrices have determinant +1.
    improper = np.linalg.det(rotations) < 0
    rotations[improper, :, -1] *= -1
    return rotations
609
+
610
+
611
def minimum_enclosing_box(
    coordinates: NDArray,
    margin: NDArray = None,
    use_geometric_center: bool = False,
) -> Tuple[int]:
    """
    Computes a cubic box large enough to enclose coordinates.

    Parameters
    ----------
    coordinates : NDArray
        Coordinates of which the enclosing box should be computed. The shape
        of this array should be [d, n] with d dimensions and n coordinates.
    margin : NDArray, optional
        Box margin. Defaults to None. The value is converted to an integer
        array but does not enter the returned shape.
    use_geometric_center : bool, optional
        Whether the box should accommodate the geometric or the coordinate
        center. Defaults to False.

    Returns
    -------
    NDArray
        Integer array of length d holding the same edge length for every axis.
    """
    cloud = np.asarray(coordinates)
    dim = cloud.shape[0]
    cloud = cloud - cloud.min(axis=1)[:, None]

    margin = np.zeros(dim) if margin is None else margin
    margin = np.asarray(margin).astype(int)

    # Twice the largest distance from the mean coordinate. Adding one avoids
    # clipping during scipy.ndimage.affine_transform.
    offsets = cloud - cloud.mean(axis=1)[:, None]
    edge_length = np.ceil(2 * np.linalg.norm(offsets, axis=0).max()) + 1
    box = np.repeat(edge_length, dim).astype(int)
    if not use_geometric_center:
        return box

    # Largest pairwise distance among convex hull vertices plus the length
    # of the unit diagonal.
    hull = ConvexHull(cloud.T)
    distance, _ = max_euclidean_distance(cloud[:, hull.vertices].T)
    distance += np.linalg.norm(np.ones(dim))
    return np.repeat(np.rint(distance).astype(int), dim)
654
+
655
+
656
def crop_input(
    target: "Density",
    template: "Density",
    target_mask: "Density" = None,
    template_mask: "Density" = None,
    map_cutoff: float = 0,
    template_cutoff: float = 0,
) -> Tuple[int]:
    """
    Crop target and template maps for efficient fitting. Input densities
    are cropped in place.

    Parameters
    ----------
    target : Density
        Target to be fitted on.
    template : Density
        Template to fit onto the target.
    target_mask : Density, optional
        Mask of target. Will be cropped like target.
    template_mask : Density, optional
        Mask of template. Will be cropped like template.
    map_cutoff : float, optional
        Cutoff value for trimming the target Density. Default is 0.
        If None, the target is not trimmed.
    template_cutoff : float, optional
        Cutoff value for trimming the template Density. Default is 0.
        If None, the template is not trimmed.

    Returns
    -------
    NDArray
        Reference fit index, one entry per axis.
    """
    convolution_shape_init = np.add(target.shape, template.shape) - 1
    # If target and template are aligned, fitting should return this index
    reference_fit = np.subtract(template.shape, 1)

    target_box = tuple(slice(0, x) for x in target.shape)
    if map_cutoff is not None:
        target_box = target.trim_box(cutoff=map_cutoff)

    target_mask_box = target_box
    if target_mask is not None and map_cutoff is not None:
        target_mask_box = target_mask.trim_box(cutoff=map_cutoff)
    # Union of the target and target mask boxes along each axis.
    target_box = tuple(
        slice(min(arr.start, mask.start), max(arr.stop, mask.stop))
        for arr, mask in zip(target_box, target_mask_box)
    )

    template_box = tuple(slice(0, x) for x in template.shape)
    if template_cutoff is not None:
        template_box = template.trim_box(cutoff=template_cutoff)

    template_mask_box = template_box
    if template_mask is not None and template_cutoff is not None:
        template_mask_box = template_mask.trim_box(cutoff=template_cutoff)
    # Union of the template and template mask boxes along each axis.
    template_box = tuple(
        slice(min(arr.start, mask.start), max(arr.stop, mask.stop))
        for arr, mask in zip(template_box, template_mask_box)
    )

    # Number of voxels removed on the right of the template and on the left
    # of the target, both computed from the shapes before cropping.
    cut_right = np.array(
        [shape - x.stop for shape, x in zip(template.shape, template_box)]
    )
    cut_left = np.array([x.start for x in target_box])

    # Origin offset in voxels, evaluated before adjust_box is called.
    origin_difference = np.divide(target.origin - template.origin, target.sampling_rate)
    origin_difference = origin_difference.astype(int)

    target.adjust_box(target_box)
    template.adjust_box(template_box)

    if target_mask is not None:
        target_mask.adjust_box(target_box)
    if template_mask is not None:
        template_mask.adjust_box(template_box)

    reference_fit -= cut_right + cut_left + origin_difference

    convolution_shape = np.array(target.shape)
    convolution_shape += np.array(template.shape) - 1

    print(f"Cropped volume of target is: {target.shape}")
    print(f"Cropped volume of template is: {template.shape}")
    # Relative reduction of the full convolution size in percent.
    saving = 1 - (np.prod(convolution_shape)) / np.prod(convolution_shape_init)
    saving *= 100

    # The reported sizes assume four bytes per element.
    print(
        "Cropping changed array size from "
        f"{round(4*np.prod(convolution_shape_init)/1e6, 3)} MB "
        f"to {round(4*np.prod(convolution_shape)/1e6, 3)} MB "
        f"({'-' if saving > 0 else ''}{abs(round(saving, 2))}%)"
    )
    return reference_fit
749
+
750
+
751
def euler_to_rotationmatrix(angles: Tuple[float]) -> NDArray:
    """
    Convert Euler angles to a rotation matrix.

    Parameters
    ----------
    angles : tuple
        A tuple representing the Euler angles in degrees, given in
        z, y, x order. A single angle is interpreted as a rotation around
        the z axis.

    Returns
    -------
    NDArray
        The generated 3 x 3 rotation matrix with dtype float32.
    """
    if len(angles) == 1:
        # Unpack the single angle; nesting the tuple itself would hand a
        # ragged sequence to Rotation.from_euler.
        angles = (angles[0], 0, 0)
    rotation_matrix = (
        Rotation.from_euler("zyx", angles, degrees=True).as_matrix().astype(np.float32)
    )
    return rotation_matrix
771
+
772
+
773
def euler_from_rotationmatrix(rotation_matrix: NDArray) -> Tuple:
    """
    Convert a rotation matrix to euler angles.

    Parameters
    ----------
    rotation_matrix : NDArray
        A 2 x 2 or 3 x 3 rotation matrix in z y x form.

    Returns
    -------
    NDArray
        The three euler angles in degrees as float32, in z, y, x order.
    """
    matrix = rotation_matrix
    if matrix.shape[0] == 2:
        # Embed the 2 x 2 matrix into the upper left of a 3 x 3 identity.
        embedded = np.eye(3)
        embedded[:2, :2] = matrix
        matrix = embedded

    rotation = Rotation.from_matrix(matrix)
    angles = rotation.as_euler("zyx", degrees=True)
    return angles.astype(np.float32)
797
+
798
+
799
def rigid_transform(
    coordinates: NDArray,
    rotation_matrix: NDArray,
    out: NDArray,
    translation: NDArray,
    use_geometric_center: bool = False,
    coordinates_mask: NDArray = None,
    out_mask: NDArray = None,
    center: NDArray = None,
) -> None:
    """
    Rotate and translate coordinates, writing the result into ``out``.

    Parameters
    ----------
    coordinates : NDArray
        An array representing the coordinates to be transformed [d x N].
    rotation_matrix : NDArray
        The rotation matrix to be applied [d x d].
    out : NDArray
        The output array to store the transformed coordinates.
    translation : NDArray
        The translation vector to be applied [d].
    use_geometric_center : bool, optional
        Whether to use geometric or coordinate center.
    coordinates_mask : NDArray, optional
        An array representing the mask for the coordinates [d x T].
    out_mask : NDArray, optional
        The output array to store the transformed coordinates mask. The mask
        is only transformed if both mask arguments are given.
    center : NDArray, optional
        Center of rotation [d]. Defaults to the mean of ``coordinates``.

    Returns
    -------
    None
    """
    input_dtype = coordinates.dtype
    if center is None:
        center = coordinates.mean(axis=1)

    rotate_about_center = not use_geometric_center
    if rotate_about_center:
        coordinates = coordinates - center[:, None]

    np.matmul(rotation_matrix, coordinates, out=out)

    if rotate_about_center:
        # Move the mean of the rotated coordinates back onto the center.
        shift = np.add(translation, np.subtract(center, out.mean(axis=1)))
    else:
        upper, lower = out.max(axis=1), out.min(axis=1)
        extent = upper - lower
        shift = np.add(translation, center - upper + (extent // 2))

    out += shift[:, None]

    if coordinates_mask is not None and out_mask is not None:
        if rotate_about_center:
            coordinates_mask = coordinates_mask - center[:, None]
        np.matmul(rotation_matrix, coordinates_mask, out=out_mask)
        out_mask += shift[:, None]

    if rotate_about_center and input_dtype != out.dtype:
        # Shift by the difference between the mean of out and the mean of
        # its integer-truncated values.
        np.subtract(out.mean(axis=1), out.astype(int).mean(axis=1), out=shift)
        out += shift[:, None]
856
+
857
+
858
def _format_string(string: str) -> str:
    """
    Quote a string if it contains a space or exactly one single quote.

    Parameters
    ----------
    string : str
        Input string to be formatted.

    Returns
    -------
    str
        The string wrapped in single quotes if it contains a space, wrapped
        in double quotes if it contains exactly one single quote, otherwise
        the unchanged string.
    """
    contains_space = " " in string
    if contains_space:
        return f"'{string}'"

    # Occurs e.g. for C1' atoms.
    has_lone_quote = string.count("'") == 1
    return f'"{string}"' if has_lone_quote else string
878
+
879
+
880
def _format_mmcif_colunns(subdict: Dict) -> Dict:
    """
    Quote and pad the column values of a mmcif dictionary.

    Parameters
    ----------
    subdict : dict
        Input dictionary where each key corresponds to a column and the
        values are iterables containing the column values.

    Returns
    -------
    dict
        Dictionary with the same keys, where every value of a column is
        left-justified to the width of its longest entry plus one.
    """
    padded_subdict = {}
    for column, raw_values in subdict.items():
        formatted = [_format_string(entry) for entry in raw_values]
        # Width of the longest formatted entry plus one separating blank.
        width = max((len(entry) for entry in formatted), default=0) + 1
        padded_subdict[column] = [entry.ljust(width) for entry in formatted]
    return padded_subdict
905
+
906
+
907
def create_mask(mask_type: str, **kwargs) -> NDArray:
    """
    Creates a mask of the specified type.

    Parameters
    ----------
    mask_type : str
        Type of the mask to be created. Can be "ellipse", "box", or "tube".
    kwargs : dict
        Additional parameters passed on to the mask creating function.

    Returns
    -------
    NDArray
        The created mask.

    Raises
    ------
    ValueError
        If the mask_type is invalid.

    See Also
    --------
    :py:meth:`elliptical_mask`
    :py:meth:`box_mask`
    :py:meth:`tube_mask`
    """
    builders = {"ellipse": elliptical_mask, "box": box_mask, "tube": tube_mask}

    builder = builders.get(mask_type)
    if builder is None:
        raise ValueError(f"mask_type has to be one of {','.join(builders.keys())}")

    return builder(**kwargs)
939
+
940
+
941
def elliptical_mask(
    shape: Tuple[int], radius: Tuple[float], center: Tuple[int]
) -> NDArray:
    """
    Creates an ellipsoidal mask.

    Parameters
    ----------
    shape : tuple
        Shape of the mask to be created.
    radius : tuple
        Radius of the ellipse along each axis. A single value is used for
        all axes.
    center : tuple
        Center of the ellipse along each axis. A single value is used for
        all axes.

    Returns
    -------
    NDArray
        Integer array of the given shape that is 1 inside or on the ellipsoid
        and 0 elsewhere.

    Raises
    ------
    ValueError
        If the length of center and radius is not one or the same as shape.

    Examples
    --------
    >>> mask = elliptical_mask(shape = (20,20), radius = (5,5), center = (10,10))
    """
    shape = tuple(int(extent) for extent in np.atleast_1d(shape))
    n = len(shape)

    radius = np.asarray(radius).reshape(-1)
    center = np.asarray(center).reshape(-1)

    # Only a single value is broadcast to all axes. Any other length has to
    # match the number of dimensions exactly, so e.g. a length-2 radius is
    # no longer silently stretched over a 4D shape.
    if radius.size == 1:
        radius = np.repeat(radius, n)
    if center.size == 1:
        center = np.repeat(center, n)

    if radius.size != n:
        raise ValueError("Length of radius has to be either one or match shape.")
    if center.size != n:
        raise ValueError("Length of center has to be either one or match shape.")

    # Reshape to (n, 1, ..., 1) so both broadcast against np.indices(shape).
    center = center.reshape((-1,) + (1,) * n)
    radius = radius.reshape((-1,) + (1,) * n)

    # Normalized distance from the center; values <= 1 lie within the ellipsoid.
    mask = np.linalg.norm((np.indices(shape) - center) / radius, axis=0)
    mask = (mask <= 1).astype(int)

    return mask
988
+
989
+
990
def box_mask(shape: Tuple[int], center: Tuple[int], height: Tuple[int]) -> np.ndarray:
    """
    Creates a box mask centered around the provided center point.

    Parameters
    ----------
    shape : Tuple[int]
        Shape of the output array.
    center : Tuple[int]
        Center point coordinates of the box.
    height : Tuple[int]
        Height (side length) of the box along each axis.

    Returns
    -------
    NDArray
        Array of the given shape that is 1 inside the box and 0 elsewhere.
        Parts of the box outside the array are clipped.

    Raises
    ------
    ValueError
        If shape, center and height do not have the same length.
    """
    if len(shape) != len(center) or len(center) != len(height):
        raise ValueError("The length of shape, center, and height must be consistent.")

    # Calculate min and max coordinates for the box using the center and half-heights
    center, height = np.array(center, dtype=int), np.array(height, dtype=int)

    half_heights = height // 2
    starts = np.maximum(center - half_heights, 0)
    # Unclipped, the box covers height + 1 voxels along each axis.
    stops = np.minimum(center + half_heights + np.mod(height, 2) + 1, shape)
    # A negative stop would be interpreted relative to the end of the axis and
    # mark voxels that are not part of the box, so clamp it to zero.
    stops = np.maximum(stops, 0)
    slice_indices = tuple(slice(*coord) for coord in zip(starts, stops))

    out = np.zeros(shape)
    out[slice_indices] = 1
    return out
1022
+
1023
+
1024
def tube_mask(
    shape: Tuple[int],
    symmetry_axis: int,
    base_center: Tuple[int],
    inner_radius: float,
    outer_radius: float,
    height: int,
) -> NDArray:
    """
    Creates a tube mask.

    Parameters
    ----------
    shape : tuple
        Shape of the mask to be created.
    symmetry_axis : int
        The axis of symmetry for the tube.
    base_center : tuple
        Center of the base circle of the tube. The entry at position
        ``symmetry_axis`` is ignored; along that axis the tube is centered
        within ``shape``.
    inner_radius : float
        Inner radius of the tube.
    outer_radius : float
        Outer radius of the tube.
    height : int
        Height of the tube.

    Returns
    -------
    NDArray
        The created tube mask.

    Raises
    ------
    ValueError
        If the inner radius is larger than the outer radius, if symmetry_axis
        is not a valid axis of shape, or if height is larger than the
        symmetry axis shape.
    """
    if inner_radius > outer_radius:
        raise ValueError("inner_radius should be smaller than outer_radius.")

    # Validate the axis before it is used to index shape, so an out-of-range
    # axis raises the documented ValueError rather than an IndexError.
    if symmetry_axis >= len(shape):
        raise ValueError(f"symmetry_axis has to be smaller than {len(shape)}.")

    if height > shape[symmetry_axis]:
        raise ValueError(f"Height can be no larger than {shape[symmetry_axis]}.")

    # Cross-section geometry: drop the symmetry axis from shape and center.
    circle_shape = tuple(b for ix, b in enumerate(shape) if ix != symmetry_axis)
    base_center = tuple(b for ix, b in enumerate(base_center) if ix != symmetry_axis)

    inner_circle = create_mask(
        mask_type="ellipse",
        shape=circle_shape,
        radius=inner_radius,
        center=base_center,
    )
    outer_circle = create_mask(
        mask_type="ellipse",
        shape=circle_shape,
        radius=outer_radius,
        center=base_center,
    )
    # Ring between the two radii, with the symmetry axis re-inserted.
    circle = outer_circle - inner_circle
    circle = np.expand_dims(circle, axis=symmetry_axis)

    # Place the stack of rings around the middle of the symmetry axis.
    center = shape[symmetry_axis] // 2
    start_idx = center - height // 2
    stop_idx = center + height // 2 + height % 2

    slice_indices = tuple(
        slice(None) if i != symmetry_axis else slice(start_idx, stop_idx)
        for i in range(len(shape))
    )
    tube = np.zeros(shape)
    tube[slice_indices] = np.repeat(circle, height, axis=symmetry_axis)

    return tube
1100
+
1101
+
1102
def scramble_phases(
    arr: NDArray, noise_proportion: float = 0.5, seed: int = 42
) -> NDArray:
    """
    Applies random phase scrambling to a given array.

    This function takes an input array, applies a Fourier transform, then scrambles the
    phase with a given proportion of noise, and finally applies an
    inverse Fourier transform to the scrambled data. The phase scrambling
    is controlled by a random seed.

    Parameters
    ----------
    arr : NDArray
        The input array to be scrambled.
    noise_proportion : float, optional
        The proportion of noise in the phase scrambling, by default 0.5.
    seed : int, optional
        The seed for the random phase scrambling, by default 42.

    Returns
    -------
    NDArray
        The array with scrambled phases.

    Raises
    ------
    ValueError
        If noise_proportion is not within [0, 1].

    Notes
    -----
    The phase noise is a permutation of the original phases along the first
    axis only. The global NumPy random state is left untouched.
    """
    if noise_proportion < 0 or noise_proportion > 1:
        raise ValueError("noise_proportion has to be within [0, 1].")

    arr_fft = np.fft.fftn(arr)

    amp = np.abs(arr_fft)
    ph = np.angle(arr_fft)

    # A private legacy generator yields the same stream as seeding the global
    # one with np.random.seed(seed), without reseeding the caller's RNG.
    rng = np.random.RandomState(seed)
    ph_noise = rng.permutation(ph)

    # Linear blend between the original phases and their shuffled copy.
    ph_new = ph * (1 - noise_proportion) + ph_noise * noise_proportion
    ret = np.real(np.fft.ifftn(amp * np.exp(1j * ph_new)))
    return ret
1145
+
1146
+
1147
def conditional_execute(func: Callable, execute_operation: bool = True) -> Callable:
    """
    Select between a function and a do-nothing stand-in.

    Parameters
    ----------
    func : callable
        Function handed back when execute_operation is truthy.
    execute_operation : bool, optional
        Whether `func` should be returned. If falsy, a function that accepts
        arbitrary arguments and returns None is returned instead.
        Default is True.

    Returns
    -------
    callable
        `func` itself, or a no-operation function.

    Examples
    --------
    >>> def greet(name):
    ...     return f"Hello, {name}!"
    ...
    >>> operation = conditional_execute(greet, False)
    >>> operation("Alice")
    >>> operation = conditional_execute(greet, True)
    >>> operation("Alice")
    'Hello, Alice!'
    """
    if execute_operation:
        return func

    def noop(*args, **kwargs):
        """No operation function."""
        return None

    return noop