pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.3__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (52)
  1. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/match_template.py +219 -216
  2. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/postprocess.py +86 -54
  3. pytme-0.2.3.data/scripts/preprocess.py +132 -0
  4. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +181 -94
  5. pytme-0.2.3.dist-info/METADATA +92 -0
  6. pytme-0.2.3.dist-info/RECORD +75 -0
  7. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/WHEEL +1 -1
  8. pytme-0.2.1.data/scripts/preprocess.py → scripts/eval.py +1 -1
  9. scripts/extract_candidates.py +20 -13
  10. scripts/match_template.py +219 -216
  11. scripts/match_template_filters.py +154 -95
  12. scripts/postprocess.py +86 -54
  13. scripts/preprocess.py +95 -56
  14. scripts/preprocessor_gui.py +181 -94
  15. scripts/refine_matches.py +265 -61
  16. tme/__init__.py +0 -1
  17. tme/__version__.py +1 -1
  18. tme/analyzer.py +458 -813
  19. tme/backends/__init__.py +40 -11
  20. tme/backends/_jax_utils.py +187 -0
  21. tme/backends/cupy_backend.py +109 -226
  22. tme/backends/jax_backend.py +230 -152
  23. tme/backends/matching_backend.py +445 -384
  24. tme/backends/mlx_backend.py +32 -59
  25. tme/backends/npfftw_backend.py +240 -507
  26. tme/backends/pytorch_backend.py +30 -151
  27. tme/density.py +248 -371
  28. tme/extensions.cpython-311-darwin.so +0 -0
  29. tme/matching_data.py +328 -284
  30. tme/matching_exhaustive.py +195 -1499
  31. tme/matching_optimization.py +143 -106
  32. tme/matching_scores.py +887 -0
  33. tme/matching_utils.py +287 -388
  34. tme/memory.py +377 -0
  35. tme/orientations.py +78 -21
  36. tme/parser.py +3 -4
  37. tme/preprocessing/_utils.py +61 -32
  38. tme/preprocessing/composable_filter.py +7 -4
  39. tme/preprocessing/compose.py +7 -3
  40. tme/preprocessing/frequency_filters.py +49 -39
  41. tme/preprocessing/tilt_series.py +44 -72
  42. tme/preprocessor.py +560 -526
  43. tme/structure.py +491 -188
  44. tme/types.py +5 -3
  45. pytme-0.2.1.dist-info/METADATA +0 -73
  46. pytme-0.2.1.dist-info/RECORD +0 -73
  47. tme/helpers.py +0 -881
  48. tme/matching_constrained.py +0 -195
  49. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
  50. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
tme/helpers.py DELETED
@@ -1,881 +0,0 @@
1
- """ General utility functions.
2
-
3
- Copyright (c) 2023 European Molecular Biology Laboratory
4
-
5
- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
- """
7
-
8
- import os
9
- import yaml
10
- import pickle
11
- from itertools import product
12
- from typing import Tuple, Dict
13
-
14
- import numpy as np
15
- from numpy.typing import NDArray
16
- from scipy.special import iv as scipy_special_iv
17
- from scipy.ndimage import correlate1d, gaussian_filter
18
- from scipy.optimize import minimize
19
- from scipy.signal import convolve
20
- from scipy.interpolate import splrep, BSpline
21
- from scipy.stats import entropy
22
-
23
-
24
def is_gzipped(filename: str) -> bool:
    """
    Determine whether a file is gzip-compressed.

    Parameters
    ----------
    filename : str
        Path to the file to inspect.

    Returns
    -------
    bool
        True if the first two bytes of the file equal the gzip magic number
        ``0x1f 0x8b``, False otherwise.
    """
    with open(filename, "rb") as stream:
        header = stream.read(2)
    return header == b"\x1f\x8b"
28
-
29
-
30
def window_to_volume(window: NDArray) -> NDArray:
    """
    Convert a 1D window to a separable 3D volume.

    The window is normalized by its trapezoidal integral, and the volume is
    the product of the normalized window along each of the three axes. The
    input array is not modified.

    Parameters
    ----------
    window : numpy.ndarray
        1D window.

    Returns
    -------
    numpy.ndarray
        3D volume of shape ``(n, n, n)`` with ``n = window.size``.
    """
    # np.trapz was renamed to np.trapezoid in NumPy 2.0; support both names.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    # Divide into a new array rather than in place so the caller's window is
    # left untouched and integer-typed windows do not fail.
    normalized = np.asarray(window) / integrate(window)
    return (
        normalized[:, np.newaxis, np.newaxis]
        * normalized[np.newaxis, :, np.newaxis]
        * normalized[np.newaxis, np.newaxis, :]
    )
50
-
51
-
52
def window_kaiserb(width: int, beta: float = 3.2, order: int = 0) -> NDArray:
    """
    Create a Kaiser-Bessel window.

    Parameters
    ----------
    width : int
        Width of the window.
    beta : float, optional
        Beta parameter of the Kaiser-Bessel window. Default is 3.2.
    order : int, optional
        Order of the Bessel function. Default is 0.

    Returns
    -------
    NDArray
        Kaiser-Bessel window of length ``width``. For odd widths, the centre
        sample is 1. A single-sample window is ``[1.0]``.

    References
    ----------
    .. [1] Sorzano, Carlos et al (Mar. 2015). Fast and accurate conversion
       of atomic models into electron density maps. AIMS Biophysics
       2, 8–20.
    """
    window = np.arange(0, width)
    alpha = (width - 1) / 2.0
    if alpha > 0:
        # Position of each sample relative to the centre, scaled to [-1, 1].
        relative = (window - alpha) / alpha
    else:
        # width <= 1: the only sample (if any) sits on the centre; avoid 0 / 0.
        relative = np.zeros(window.shape)
    arr = beta * np.sqrt(1 - relative**2.0)

    return bessel(order, arr) / bessel(order, beta)
81
-
82
-
83
def window_blob(width: int, beta: float = 3.2, order: int = 2) -> NDArray:
    """
    Generate a blob window built from modified Bessel functions.

    Each sample evaluates ``r**order * I_order(r) / I_order(beta)``, where
    ``r = beta * sqrt(1 - x**2)`` and ``x`` is the sample position relative to
    the window centre, scaled to [-1, 1]. Samples that evaluate to NaN are
    set to 0; for example, the 0/0 produced by a single-sample window.

    Parameters
    ----------
    width : int
        Width of the window.
    beta : float, optional
        Beta parameter. Default is 3.2.
    order : int, optional
        Order of the Bessel function. Default is 2.

    Returns
    -------
    NDArray
        Blob window.

    References
    ----------
    .. [1] Sorzano, Carlos et al (Mar. 2015). Fast and accurate conversion
       of atomic models into electron density maps. AIMS Biophysics
       2, 8–20.
    """
    half_span = (width - 1) / 2.0
    samples = np.arange(0, width)
    radial = beta * np.sqrt(1 - ((samples - half_span) / half_span) ** 2.0)

    numerator = np.power(radial, order) * bessel(order, radial)
    blob = np.divide(numerator, bessel(order, beta))
    return np.where(np.isnan(blob), 0, blob)
114
-
115
-
116
def window_sinckb(omega: float, d: float, dw: float):
    """
    Build a unit-sum low-pass filter by tapering a sinc kernel with a Kaiser window.

    The Kaiser window length is determined by ``d`` and ``dw``. A sinc kernel
    of the same length, with cutoff ``omega``, is multiplied by the window,
    and the product is normalized to sum to 1.

    Parameters
    ----------
    omega : float
        Reduction factor (cutoff passed to ``sinc_mask``).
    d : float
        Ripple.
    dw : float
        Delta w (transition width).

    Returns
    -------
    ndarray
        Impulse response of the low-pass filter.

    References
    ----------
    .. [1] Sorzano, Carlos et al (Mar. 2015). Fast and accurate conversion
       of atomic models into electron density maps. AIMS Biophysics
       2, 8–20.
    """
    taper = kaiser_mask(d, dw)
    ideal = sinc_mask(np.zeros(taper.shape), omega)
    response = ideal * taper
    return response / response.sum()
146
-
147
-
148
def apply_window_filter(
    arr: NDArray,
    filter_window: NDArray,
    mode: str = "reflect",
    cval: float = 0.0,
    origin: int = 0,
):
    """
    Convolve ``arr`` with a 1D window along every axis, in place.

    Correlating with the reversed window is a convolution. The filter is
    applied separably, one axis after another, and each pass writes its
    result back into ``arr``.

    Parameters
    ----------
    arr : NDArray
        Input array; it is overwritten with the filtered result.
    filter_window : NDArray
        Window filter to apply.
    mode : str, optional
        Boundary mode for the filtering, default is "reflect".
    cval : float, optional
        Value to fill when mode is "constant", default is 0.0.
    origin : int, optional
        Origin of the filter window, default is 0.

    Returns
    -------
    NDArray
        ``arr`` itself, after filtering.
    """
    flipped = filter_window[::-1]
    for axis in range(arr.ndim):
        correlate1d(
            arr,
            flipped,
            axis=axis,
            output=arr,
            mode=mode,
            cval=cval,
            origin=origin,
        )
    return arr
189
-
190
-
191
def sinc_mask(mask: NDArray, omega: float) -> NDArray:
    """
    Create a sinc kernel sized like ``mask``.

    Only ``mask.size`` is used; the values in ``mask`` are ignored. The kernel
    is sampled at integer offsets ``-c..c``, where ``c = int((mask.size - 1) / 2)``,
    so an odd-sized mask yields a kernel of the same size.

    Parameters
    ----------
    mask : NDArray
        Input mask (only its size is used).
    omega : float
        Reduction factor (cutoff in radians per sample).

    Returns
    -------
    NDArray
        Sinc mask ``(omega / pi) * sinc((omega / pi) * offset)``.
    """
    center = int((mask.size - 1) / 2)
    offsets = np.arange(-center, center + 1)
    cutoff = omega / np.pi
    return cutoff * np.sinc(cutoff * offsets)
212
-
213
-
214
def kaiser_mask(d: float, dw: float) -> NDArray:
    """
    Create a Kaiser window from ripple and transition-width specifications.

    Uses the standard Kaiser design formulas. The attenuation
    ``A = -20 * log10(d)`` determines the shape parameter ``beta``. Together
    with the transition width, it also determines the half-length ``M`` of
    the window, which is at least 1.

    Parameters
    ----------
    d : float
        Ripple.
    dw : float
        Delta-w (transition width), given as a frequency normalized to 1.

    Returns
    -------
    NDArray
        Kaiser window of length ``2 * M + 1`` whose centre sample is 1.
    """
    # Rescale dw from a frequency normalized to 1 to one normalized to pi.
    dw *= np.pi
    attenuation = -20 * np.log10(d)
    half_length = max(1, np.ceil((attenuation - 8) / (2.285 * dw)))

    if attenuation > 50:
        beta = 0.1102 * (attenuation - 8.7)
    elif attenuation >= 21:
        excess = attenuation - 21
        beta = 0.5842 * np.power(excess, 0.4) + 0.07886 * excess
    else:
        beta = 0

    taps = np.abs(np.arange(-half_length, half_length + 1))
    taper = np.sqrt(1 - np.power(taps / half_length, 2))

    return np.divide(bessel(0, beta * taper), bessel(0, beta))
245
-
246
-
247
def bessel(order: int, arr: NDArray) -> NDArray:
    """
    Evaluate the modified Bessel function of the first kind, ``I_order(arr)``.

    Thin wrapper around ``scipy.special.iv``. ``arr`` may be an array or a
    scalar; in this module it is also called with scalar ``beta`` values.

    Parameters
    ----------
    order : int
        Order of the Bessel function.
    arr : NDArray
        Argument(s) at which to evaluate the function.

    Returns
    -------
    NDArray
        Bessel function values, broadcast like ``arr``.
    """
    values = scipy_special_iv(order, arr)
    return values
265
-
266
-
267
def electron_factor(
    dist: NDArray, method: str, atom: str, fourier: bool = False
) -> NDArray:
    """
    Compute the electron factor as a sum of Gaussian terms.

    Coefficients are looked up in ``get_scattering_factors(method)[atom]``:
    the first half are amplitudes ``a`` and the second half exponents ``b``
    (for ``"dt1969"``, the exponents are read from offset 1 instead).

    - Fourier space: each term is ``a * exp(-b * dist**2)``.
    - Real space: ``b`` is rescaled to ``b' = b / (4 pi^2)``, and each term is
      ``a * sqrt(pi / b') * exp(-dist**2 / (4 b'))``.

    Terms containing NaN are skipped. The total is divided by ``2 pi``. An
    atom missing from the table yields ``0.0``.

    Parameters
    ----------
    dist : NDArray
        Distance (or frequency, if ``fourier`` is True).
    method : str
        Method name.
    atom : str
        Atom type.
    fourier : bool, optional
        Whether to compute the electron factor in Fourier space.

    Returns
    -------
    NDArray
        Computed electron factor.
    """
    table = get_scattering_factors(method)
    coeffs = table.get(atom, [])
    n_terms = len(coeffs) // 2

    amplitudes = coeffs[:n_terms]
    if method == "dt1969":
        # NOTE(review): these exponents overlap the amplitude slice; confirm
        # this matches the dt1969 table layout.
        exponents = coeffs[1 : (n_terms + 1)]
    else:
        exponents = coeffs[n_terms : 2 * n_terms]

    total = 0.0
    for a, b in zip(amplitudes, exponents):
        if fourier:
            term = a * np.exp(-b * np.power(dist, 2))
        else:
            b_real = b / (4 * np.power(np.pi, 2))
            term = (
                a * np.sqrt(np.pi / b_real) * np.exp(-np.power(dist, 2) / (4 * b_real))
            )

        if np.isnan(term).any():
            continue
        total += term

    return total / (2 * np.pi)
314
-
315
-
316
def optimize_hlfp(profile, M, T, atom, method, filter_method):
    """
    Optimize the low-pass (sinc * Kaiser) filter applied to an atom profile.

    The filter parameters ``(omega, d, dw)`` are chosen to minimize
    ``_hlpf_fitness``:

    - ``filter_method == "brute"``: a fixed parameter grid is searched.
    - Any other value: SLSQP is run from ``(1.0, 0.01, 0.125)``.

    If the optimizer returns NaN parameters, a message is printed and the
    starting parameters are used instead.

    Parameters
    ----------
    profile : NDArray
        Input profile.
    M : float
        Downsampling factor, in units of profile samples.
    T : float
        Sampling rate of the profile (angstroms per sample).
    atom : str
        Atom type.
    method : str
        Scattering factor method name.
    filter_method : str
        Filter optimization method ("brute" or anything else for SLSQP).

    Returns
    -------
    NDArray
        Impulse response of the optimized filter. If the profile is longer
        than the filter, the filter is zero-padded to the profile's length.

    References
    ----------
    .. [1] Sorzano, Carlos et al (Mar. 2015). Fast and accurate conversion
       of atomic models into electron density maps. AIMS Biophysics
       2, 8–20.
    """
    # omega, d, dw: starting point for SLSQP and fallback for invalid results.
    default_params = [1.0, 0.01, 1.0 / 8.0]
    if filter_method == "brute":
        best_fitness = float("inf")
        best_params = list(default_params)
        OMEGA, D, DW = np.meshgrid(
            np.arange(0.7, 1.3, 0.015),
            np.arange(0.01, 0.2, 0.015),
            np.arange(0.05, 0.2, 0.015),
        )
        for omega, d, dw in zip(OMEGA.ravel(), D.ravel(), DW.ravel()):
            current_fitness = _hlpf_fitness([omega, d, dw], T, M, profile, atom, method)
            if current_fitness < best_fitness:
                best_fitness = current_fitness
                best_params = [omega, d, dw]
        final_params = np.array(best_params, dtype=float)
    else:
        res = minimize(
            _hlpf_fitness,
            default_params,
            args=(T, M, profile, atom, method),
            method="SLSQP",
            bounds=([0.2, 2], [1e-3, 2], [1e-3, 1]),
        )
        final_params = np.array(res.x, dtype=float)

    if np.any(np.isnan(final_params)):
        print(f"Solver returned NAs for atom {atom} at {M}")
        final_params = np.array(default_params, dtype=float)

    # The fitness is evaluated with omega scaled by pi / M; apply the same here.
    final_params[0] *= np.pi / M
    mask = window_sinckb(*final_params)

    if profile.shape[0] > mask.shape[0]:
        profile_origin = int((profile.size - 1) / 2)
        mask = window(mask, profile_origin, profile_origin)

    return mask
382
-
383
-
384
def _hlpf_fitness(
    params: Tuple[float], T: float, M: float, profile: NDArray, atom: str, method: str
) -> float:
    """
    Fitness function for the low-pass filter optimization.

    The profile is filtered with the candidate sinc * Kaiser filter, resampled
    at the coarse sampling, and Fourier transformed. The error is the mean
    squared difference between the log-magnitude of that transform and the
    log of ``T`` divided by the Fourier-space electron factor, taken over
    non-negative frequencies. Parameters outside the admissible region
    receive a very large random penalty.

    Parameters
    ----------
    params : tuple of float
        Parameters [omega, d, dw] for optimization.
    T : float
        Sampling rate of the profile (angstroms per sample).
    M : float
        Downsampling factor, in units of profile samples.
    profile : NDArray
        Input profile.
    atom : str
        Atom type.
    method : str
        Scattering factor method name.

    Returns
    -------
    float
        Fitness value (lower is better).

    References
    ----------
    .. [1] Sorzano, Carlos et al (Mar. 2015). Fast and accurate conversion
       of atomic models into electron density maps. AIMS Biophysics
       2, 8–20.
    .. [2] https://github.com/I2PC/xmipp/blob/707f921dfd29cacf5a161535034d28153b58215a/src/xmipp/libraries/data/pdb.cpp#L1344
    """
    omega, d, dw = params

    # Negate the whole conjunction: any single out-of-range parameter is
    # penalized, not only omega.
    if not (0.7 <= omega <= 1.3 and 0 <= d <= 0.2 and 1e-3 <= dw <= 0.2):
        return 1e38 * np.random.randint(1, 100)

    mask = window_sinckb(omega=omega * np.pi / M, d=d, dw=dw)

    # Bring profile and filter to a common support before convolving.
    if profile.shape[0] > mask.shape[0]:
        profile_origin = int((profile.size - 1) / 2)
        mask = window(mask, profile_origin, profile_origin)
    else:
        filter_origin = int((mask.size - 1) / 2)
        profile = window(profile, filter_origin, filter_origin)

    f_mask = convolve(profile, mask)

    orig = int((f_mask.size - 1) / 2)
    dist = np.arange(-orig, orig + 1) * T
    t, c, k = splrep(x=dist, y=f_mask, k=3)
    # Use a scalar bound; np.arange does not accept size-1 arrays cleanly.
    i_max = np.ceil(f_mask.shape[0] / M)
    # NOTE(review): dist is index * T, whereas coarse_mask is index * M;
    # confirm whether the coarse positions should also be scaled by T.
    coarse_mask = np.arange(-i_max, i_max + 1) * M
    spline = BSpline(t, c, k)
    coarse_values = spline(coarse_mask)

    # padding to retain longer fourier response
    aux = window(
        coarse_values, x0=10 * coarse_values.shape[0], xf=10 * coarse_values.shape[0]
    )
    f_filter = np.fft.fftn(aux)
    f_filter_mag = np.abs(f_filter)
    freq = np.fft.fftfreq(f_filter.size)
    freq /= M * T
    amplitude_f = mask.sum() / coarse_values.sum()

    size_f = f_filter_mag.shape[0] * amplitude_f
    fourier_form_f = electron_factor(dist=freq, atom=atom, method=method, fourier=True)

    valid_freq_mask = freq >= 0
    f1_values = np.log10(f_filter_mag[valid_freq_mask] * size_f)
    f2_values = np.log10(np.divide(T, fourier_form_f[valid_freq_mask]))
    squared_differences = np.square(f1_values - f2_values)
    error = np.sum(squared_differences)
    error /= np.sum(valid_freq_mask)

    return error
462
-
463
-
464
def window(arr, x0, xf, constant_values=0):
    """
    Window a 1D array to the logical index range ``[-x0, xf]``.

    Indices are measured relative to the array origin
    ``int((arr.size - 1) / 2)``. Samples inside the range are kept, and
    positions outside the array are filled with ``constant_values``. For
    non-negative ``x0`` and ``xf``, the result has ``x0 + xf + 1`` elements,
    so the window can shrink or enlarge the input independently on each side.

    Parameters
    ----------
    arr : ndarray
        Input array to be windowed.
    x0 : int
        Number of samples to keep left of the origin.
    xf : int
        Number of samples to keep right of the origin.
    constant_values : int or float, optional
        The constant values to use for padding, by default 0.

    Returns
    -------
    ndarray
        Windowed array.
    """
    arr = np.asarray(arr)
    x0, xf = int(x0), int(xf)
    origin = int((arr.size - 1) / 2)

    # Samples available left and right of the origin.
    left_extent = origin
    right_extent = arr.shape[0] - 1 - origin

    # Crop samples beyond the requested range on either side...
    start = max(left_extent - x0, 0)
    stop = arr.shape[0] - max(right_extent - xf, 0)
    # ...and pad where the requested range extends past the array.
    pad_left = max(x0 - left_extent, 0)
    pad_right = max(xf - right_extent, 0)

    return np.pad(
        arr[start:stop],
        (pad_left, pad_right),
        mode="constant",
        constant_values=constant_values,
    )
507
-
508
-
509
def atom_profile(
    M, atom, T=0.08333333, method="peng1995", lfilter=True, filter_method="minimize"
):
    """
    Generate an atom profile using a variety of methods.

    The real-space electron factor is sampled finely at ``T``. If ``lfilter``
    is set, it is convolved with an optimized low-pass filter. It is then
    trimmed to the samples above 1e-3 and returned as a cubic spline over
    distance in angstroms.

    Parameters
    ----------
    M : float
        Down sampling factor.
    atom : Any
        Type or representation of the atom.
    T : float, optional
        Sampling rate in angstroms/pixel, by default 0.08333333.
    method : str, optional
        Method to be used for generating the profile, by default "peng1995".
    lfilter : bool, optional
        Whether to apply filter on the profile, by default True.
    filter_method : str, optional
        The method for the filter, by default "minimize".

    Returns
    -------
    BSpline
        A spline representation of the atom profile.

    Raises
    ------
    ValueError
        If no profile sample exceeds 1e-3.

    References
    ----------
    .. [1] Sorzano, Carlos et al (Mar. 2015). Fast and accurate conversion
       of atomic models into electron density maps. AIMS Biophysics
       2, 8–20.
    .. [2] https://github.com/I2PC/xmipp/blob/707f921dfd29cacf5a161535034d28153b58215a/src/xmipp/libraries/data/pdb.cpp#L1344
    """
    # Express the downsampling factor in units of fine profile samples.
    M = M / T
    imax = np.ceil(4 / T * np.sqrt(76.7309 / (2 * np.power(np.pi, 2))))
    dist = np.arange(-imax, imax + 1) * T

    profile = electron_factor(dist, method, atom)

    if lfilter:
        # Named `lowpass` so it does not shadow the module-level `window`.
        lowpass = optimize_hlfp(
            profile=profile,
            M=M,
            T=T,
            atom=atom,
            method=method,
            filter_method=filter_method,
        )
        profile = convolve(profile, lowpass)

    above = np.flatnonzero(profile > 1e-3)
    if above.size == 0:
        raise ValueError(
            f"Profile of atom {atom} has no samples above 1e-3; cannot fit a spline."
        )
    profile = profile[above[0] : above[-1] + 1]

    # Distances relative to the profile centre; matches x to y length even
    # if trimming leaves an even number of samples.
    profile_origin = int((profile.size - 1) / 2)
    dist = (np.arange(profile.size) - profile_origin) * T
    t, c, k = splrep(x=dist, y=profile, k=3)

    return BSpline(t, c, k)
570
-
571
-
572
def get_scattering_factors(method: str) -> Dict:
    """
    Load the scattering factor table for ``method``.

    The table is read from ``data/scattering_factors.pickle`` next to this
    module, which maps method names to per-atom coefficient tables.

    Parameters
    ----------
    method : str
        Method name used to get the scattering factors.

    Returns
    -------
    Dict
        Dictionary containing scattering factors for the given method.

    Raises
    ------
    ValueError
        If the method is not found in the stored data.
    """
    data_dir = os.path.join(os.path.dirname(__file__), "data")
    # pickle executes arbitrary code on load; only use it on trusted files
    # such as this package-local data file.
    with open(os.path.join(data_dir, "scattering_factors.pickle"), "rb") as handle:
        factors = pickle.load(handle)

    if method in factors:
        return factors[method]
    raise ValueError(f"{method} is not valid. Use {', '.join(factors.keys())}.")
599
-
600
-
601
def load_quaternions_by_angle(angle: float) -> Tuple[NDArray, NDArray, float]:
    """
    Load the precomputed orientation set whose sampling is closest to ``angle``.

    Parameters
    ----------
    angle : float
        Requested rotational sampling angle.

    Returns
    -------
    tuple
        quaternions : NDArray
            Quaternion representations of orientations (first four columns
            of the stored array).
        weights : NDArray
            Weights associated with each orientation (last column).
        angle : float
            Rotational sampling of the selected set, i.e. the metadata angle
            closest to the requested one.
    """
    data_dir = os.path.join(os.path.dirname(__file__), "data")
    # Metadata maps file names to (N orientations, rotational sampling, coverage).
    with open(os.path.join(data_dir, "metadata.yaml"), "r") as infile:
        metadata = yaml.full_load(infile)

    set_diffs = {
        setname: abs(angle - set_angle)
        for setname, (_, set_angle, _) in metadata.items()
    }
    fname = min(set_diffs, key=set_diffs.get)

    quat_weights = np.load(os.path.join(data_dir, fname))

    quat = quat_weights[:, :4]
    weights = quat_weights[:, -1]
    # Index 1 is the sampling angle used to pick the set above; index 0 is
    # the number of orientations.
    angle = metadata[fname][1]

    return quat, weights, angle
640
-
641
-
642
def quaternion_to_rotation_matrix(quaternions: NDArray) -> NDArray:
    """
    Convert quaternions to rotation matrices.

    Quaternions are read as ``(q0, q1, q2, q3)``, with ``q0`` the scalar part.
    They need not be normalized: the scale factor ``2 / |q|^2`` yields a
    proper rotation for any non-zero quaternion. A zero quaternion maps to
    the identity matrix.

    Parameters
    ----------
    quaternions : NDArray
        Array of shape ``(n, 4)`` containing quaternions.

    Returns
    -------
    NDArray
        Array of shape ``(n, 3, 3)`` with the rotation matrices corresponding
        to the given quaternions, rounded to 8 decimals.
    """
    quaternions = np.asarray(quaternions)
    q0 = quaternions[:, 0]
    q1 = quaternions[:, 1]
    q2 = quaternions[:, 2]
    q3 = quaternions[:, 3]

    # 2 / |q|^2 normalizes implicitly; 2 * |q| is only correct for unit input.
    norm_sq = q0 * q0 + q1 * q1 + q2 * q2 + q3 * q3
    s = np.divide(
        2.0,
        norm_sq,
        out=np.zeros(norm_sq.shape, dtype=np.float64),
        where=norm_sq > 0,
    )
    rotmat = np.zeros((quaternions.shape[0], 3, 3), dtype=np.float64)

    rotmat[:, 0, 0] = 1.0 - s * ((q2 * q2) + (q3 * q3))
    rotmat[:, 0, 1] = s * ((q1 * q2) - (q0 * q3))
    rotmat[:, 0, 2] = s * ((q1 * q3) + (q0 * q2))

    rotmat[:, 1, 0] = s * ((q2 * q1) + (q0 * q3))
    rotmat[:, 1, 1] = 1.0 - s * ((q3 * q3) + (q1 * q1))
    rotmat[:, 1, 2] = s * ((q2 * q3) - (q0 * q1))

    rotmat[:, 2, 0] = s * ((q3 * q1) - (q0 * q2))
    rotmat[:, 2, 1] = s * ((q3 * q2) + (q0 * q1))
    rotmat[:, 2, 2] = 1.0 - s * ((q1 * q1) + (q2 * q2))

    np.around(rotmat, decimals=8, out=rotmat)

    return rotmat
679
-
680
-
681
def reverse(arr: NDArray) -> NDArray:
    """
    Flip an array along every one of its axes.

    Parameters
    ----------
    arr : NDArray
        Input array.

    Returns
    -------
    NDArray
        Reversed view of ``arr`` (same shape, every axis traversed backwards).
    """
    flip_all = tuple(slice(None, None, -1) for _ in range(arr.ndim))
    return arr[flip_all]
696
-
697
-
698
class Ntree:
    """
    N-dimensional dyadic tree.

    Each array dimension is split into two similarly sized halves, so every
    split produces 2**n subvolumes for an n-dimensional input. A chunk is
    split further only if all of the following hold:

    - it has nonzero values and nonzero standard deviation;
    - it spans more than 3 elements along every axis;
    - it either has a lower standard deviation than its parent or yields a
      higher information gain.

    Attributes
    ----------
    nleaves : int
        Number of leaves in the Ntree.

    """

    def __init__(self, arr: NDArray):
        """
        Initialize the Ntree with the given array.

        Parameters
        ----------
        arr : np.ndarray
            Input array to build the N-dimensional dyadic tree.

        """
        arr = np.asarray(arr)
        # Leaf subvolumes, each a tuple of slices into the full array.
        self._subvolumes = []
        # Per-leaf count of nonzero elements; used to scale filter widths.
        self._sd = []
        self._arr = arr.copy()
        # Shift all values by |min|; for arrays with a negative minimum this
        # makes every value non-negative before entropies are computed.
        self._arr += np.abs(np.min(self._arr))

        # scipy's entropy normalizes its input, which divides by zero for
        # chunks summing to zero. Silence those warnings only while building,
        # instead of overwriting the process-wide NumPy error settings.
        with np.errstate(divide="ignore", invalid="ignore"):
            self._create_node(self._arr, np.zeros(arr.ndim, dtype=int), 0)

    @property
    def nleaves(self):
        return len(self._subvolumes)

    def _create_node(self, arr: NDArray, offset: NDArray, ig: float):
        """
        Recursively split the array into nodes based on specific criteria.

        Parameters
        ----------
        arr : NDArray
            The array to split into nodes.
        offset : NDArray
            The offset of ``arr`` within the full array.
        ig : float
            Information gain of the parent split.

        """
        coordinates = self._split_arr(arr)
        sd_arr = np.std(arr)

        for chunk in coordinates:
            chunk_arr = arr[chunk]
            sd_chunk = np.std(chunk_arr)
            split_needed = False
            igo = ig

            splittable = (
                np.count_nonzero(chunk_arr)
                and sd_chunk != 0
                and np.all(np.greater(chunk_arr.shape, 3))
            )
            if splittable:
                igo = self._information_gain(chunk_arr, self._split_arr(chunk_arr))
                split_needed = sd_chunk < sd_arr or igo > ig

            if split_needed:
                new_offset = np.add(offset, [n.start for n in chunk])
                self._create_node(chunk_arr, new_offset, igo)
            else:
                final_coordinates = tuple(
                    slice(n.start + offset[i], n.stop + offset[i])
                    for i, n in enumerate(chunk)
                )
                self._subvolumes.append(final_coordinates)
                self._sd.append(np.sum(chunk_arr != 0))

    @staticmethod
    def _information_gain(arr: NDArray, chunks: Tuple[NDArray]):
        """
        Calculate the information gain of splitting the array.

        Parameters
        ----------
        arr : NDArray
            The array from which to calculate information gain.
        chunks : Tuple
            Slice tuples describing the sub-arrays created by splitting.

        Returns
        -------
        float
            Entropy of ``arr`` minus the size-weighted entropies of the chunks.

        """
        if not isinstance(chunks, list) and not isinstance(chunks, tuple):
            chunks = [chunks]

        arr_entropy = entropy(arr.ravel())
        weighted_split_entropy = [
            (arr[tuple(i)].size / arr.size) * entropy(arr[tuple(i)].ravel())
            for i in chunks
        ]
        return arr_entropy - np.sum(weighted_split_entropy)

    @staticmethod
    def _split_arr(arr: NDArray) -> Tuple[NDArray]:
        """
        Split the given array into 2**ndim similarly sized chunks.

        Parameters
        ----------
        arr : NDArray
            The array to split.

        Returns
        -------
        tuple
            Tuple of slice tuples, one per chunk; each axis is cut at
            ``shape // 2``.

        """
        full_shape = [int(n) for n in arr.shape]
        halves = [n // 2 for n in full_shape]
        per_axis = [
            (slice(0, half), slice(half, full))
            for half, full in zip(halves, full_shape)
        ]
        return tuple(product(*per_axis))

    def _sd_to_range(self, scale_range: Tuple[float, float] = (0.1, 20)) -> NDArray:
        """
        Linearly rescale the per-leaf values in ``self._sd`` to ``scale_range``.

        Parameters
        ----------
        scale_range : tuple of float, optional
            The target range of the rescaled values.

        Returns
        -------
        NDArray
            Rescaled values, rounded to whole numbers.

        """
        scaled_sd = np.interp(
            self._sd, (np.min(self._sd), np.max(self._sd)), scale_range
        )
        return np.round(scaled_sd, decimals=0)

    def filter_chunks(
        self, arr: NDArray = None, sigma_range: Tuple[float, float] = (0.2, 10)
    ) -> NDArray:
        """
        Apply a Gaussian filter to each leaf chunk and return the filtered array.

        Parameters
        ----------
        arr : NDArray, optional
            The array to filter. If None, the tree's internal (shifted) copy
            of the input array is used.
        sigma_range : tuple of float, optional
            Range of sigma values for the Gaussian filter.

        Returns
        -------
        NDArray
            The filtered array.

        """
        if arr is None:
            arr = self._arr
        result = np.zeros_like(arr)
        chunk_sigmas = self._sd_to_range(sigma_range)

        for chunk, sigma in zip(self._subvolumes, chunk_sigmas):
            result[chunk] = gaussian_filter(arr[chunk], sigma)

        return result