pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
  2. pytme-0.2.9.data/scripts/match_template.py +1135 -0
  3. pytme-0.2.9.data/scripts/postprocess.py +622 -0
  4. pytme-0.2.9.data/scripts/preprocess.py +209 -0
  5. pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
  6. pytme-0.2.9.dist-info/METADATA +95 -0
  7. pytme-0.2.9.dist-info/RECORD +119 -0
  8. pytme-0.2.9.dist-info/WHEEL +5 -0
  9. pytme-0.2.9.dist-info/entry_points.txt +6 -0
  10. pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
  11. pytme-0.2.9.dist-info/top_level.txt +3 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +97 -0
  14. scripts/match_template.py +1135 -0
  15. scripts/postprocess.py +622 -0
  16. scripts/preprocess.py +209 -0
  17. scripts/preprocessor_gui.py +1227 -0
  18. tests/__init__.py +0 -0
  19. tests/data/Blurring/blob_width18.npy +0 -0
  20. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  21. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  22. tests/data/Blurring/hamming_width6.npy +0 -0
  23. tests/data/Blurring/kaiserb_width18.npy +0 -0
  24. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  25. tests/data/Blurring/mean_size5.npy +0 -0
  26. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  27. tests/data/Blurring/rank_rank3.npy +0 -0
  28. tests/data/Maps/.DS_Store +0 -0
  29. tests/data/Maps/emd_8621.mrc.gz +0 -0
  30. tests/data/README.md +2 -0
  31. tests/data/Raw/em_map.map +0 -0
  32. tests/data/Structures/.DS_Store +0 -0
  33. tests/data/Structures/1pdj.cif +3339 -0
  34. tests/data/Structures/1pdj.pdb +1429 -0
  35. tests/data/Structures/5khe.cif +3685 -0
  36. tests/data/Structures/5khe.ent +2210 -0
  37. tests/data/Structures/5khe.pdb +2210 -0
  38. tests/data/Structures/5uz4.cif +70548 -0
  39. tests/preprocessing/__init__.py +0 -0
  40. tests/preprocessing/test_compose.py +76 -0
  41. tests/preprocessing/test_frequency_filters.py +178 -0
  42. tests/preprocessing/test_preprocessor.py +136 -0
  43. tests/preprocessing/test_utils.py +79 -0
  44. tests/test_analyzer.py +216 -0
  45. tests/test_backends.py +446 -0
  46. tests/test_density.py +503 -0
  47. tests/test_extensions.py +130 -0
  48. tests/test_matching_cli.py +283 -0
  49. tests/test_matching_data.py +162 -0
  50. tests/test_matching_exhaustive.py +124 -0
  51. tests/test_matching_memory.py +30 -0
  52. tests/test_matching_optimization.py +226 -0
  53. tests/test_matching_utils.py +189 -0
  54. tests/test_orientations.py +175 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_rotations.py +153 -0
  57. tests/test_structure.py +247 -0
  58. tme/__init__.py +6 -0
  59. tme/__version__.py +1 -0
  60. tme/analyzer/__init__.py +2 -0
  61. tme/analyzer/_utils.py +186 -0
  62. tme/analyzer/aggregation.py +577 -0
  63. tme/analyzer/peaks.py +953 -0
  64. tme/backends/__init__.py +171 -0
  65. tme/backends/_cupy_utils.py +734 -0
  66. tme/backends/_jax_utils.py +188 -0
  67. tme/backends/cupy_backend.py +294 -0
  68. tme/backends/jax_backend.py +314 -0
  69. tme/backends/matching_backend.py +1270 -0
  70. tme/backends/mlx_backend.py +241 -0
  71. tme/backends/npfftw_backend.py +583 -0
  72. tme/backends/pytorch_backend.py +430 -0
  73. tme/data/__init__.py +0 -0
  74. tme/data/c48n309.npy +0 -0
  75. tme/data/c48n527.npy +0 -0
  76. tme/data/c48n9.npy +0 -0
  77. tme/data/c48u1.npy +0 -0
  78. tme/data/c48u1153.npy +0 -0
  79. tme/data/c48u1201.npy +0 -0
  80. tme/data/c48u1641.npy +0 -0
  81. tme/data/c48u181.npy +0 -0
  82. tme/data/c48u2219.npy +0 -0
  83. tme/data/c48u27.npy +0 -0
  84. tme/data/c48u2947.npy +0 -0
  85. tme/data/c48u3733.npy +0 -0
  86. tme/data/c48u4749.npy +0 -0
  87. tme/data/c48u5879.npy +0 -0
  88. tme/data/c48u7111.npy +0 -0
  89. tme/data/c48u815.npy +0 -0
  90. tme/data/c48u83.npy +0 -0
  91. tme/data/c48u8649.npy +0 -0
  92. tme/data/c600v.npy +0 -0
  93. tme/data/c600vc.npy +0 -0
  94. tme/data/metadata.yaml +80 -0
  95. tme/data/quat_to_numpy.py +42 -0
  96. tme/data/scattering_factors.pickle +0 -0
  97. tme/density.py +2263 -0
  98. tme/extensions.cpython-311-darwin.so +0 -0
  99. tme/external/bindings.cpp +332 -0
  100. tme/filters/__init__.py +6 -0
  101. tme/filters/_utils.py +311 -0
  102. tme/filters/bandpass.py +230 -0
  103. tme/filters/compose.py +81 -0
  104. tme/filters/ctf.py +393 -0
  105. tme/filters/reconstruction.py +160 -0
  106. tme/filters/wedge.py +542 -0
  107. tme/filters/whitening.py +191 -0
  108. tme/matching_data.py +863 -0
  109. tme/matching_exhaustive.py +497 -0
  110. tme/matching_optimization.py +1311 -0
  111. tme/matching_scores.py +1183 -0
  112. tme/matching_utils.py +1188 -0
  113. tme/memory.py +337 -0
  114. tme/orientations.py +598 -0
  115. tme/parser.py +685 -0
  116. tme/preprocessor.py +1329 -0
  117. tme/rotations.py +350 -0
  118. tme/structure.py +1864 -0
  119. tme/types.py +13 -0
tme/matching_data.py ADDED
@@ -0,0 +1,863 @@
""" Class representation of template matching data.

    Copyright (c) 2023 European Molecular Biology Laboratory

    Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
"""

import warnings
from typing import Tuple, List, Optional, Generator, Dict

import numpy as np

from . import Density
from .filters import Compose
from .backends import backend as be
from .types import BackendArray, NDArray
from .matching_utils import compute_parallelization_schedule

__all__ = ["MatchingData"]


class MatchingData:
    """
    Contains data required for template matching.

    Parameters
    ----------
    target : np.ndarray or :py:class:`tme.density.Density`
        Target data.
    template : np.ndarray or :py:class:`tme.density.Density`
        Template data.
    target_mask : np.ndarray or :py:class:`tme.density.Density`, optional
        Target mask data.
    template_mask : np.ndarray or :py:class:`tme.density.Density`, optional
        Template mask data.
    invert_target : bool, optional
        Whether to invert the target before template matching.
    rotations : np.ndarray, optional
        Template rotations to sample. Can be a single (d, d) or a stack (n, d, d)
        of rotation matrices where d is the dimension of the template.

    Examples
    --------
    The following achieves the minimal definition of a :py:class:`MatchingData` instance.

    >>> import numpy as np
    >>> from tme.matching_data import MatchingData
    >>> target = np.random.rand(50,40,60)
    >>> template = target[15:25, 10:20, 30:40]
    >>> matching_data = MatchingData(target=target, template=template)

    """

    def __init__(
        self,
        target: NDArray,
        template: NDArray,
        template_mask: NDArray = None,
        target_mask: NDArray = None,
        invert_target: bool = False,
        rotations: NDArray = None,
    ):
        self.target = target
        self.target_mask = target_mask

        self.template = template
        if template_mask is not None:
            self.template_mask = template_mask

        self.rotations = rotations
        self._invert_target = invert_target
        self._translation_offset = tuple(0 for _ in range(len(target.shape)))

        self._set_matching_dimension()

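    # A slightly fuller construction than the docstring example, shown as a
    # sketch assuming 3D numpy inputs; the single identity matrix merely stands
    # in for a real rotation sampling:
    #
    #   import numpy as np
    #   from tme.matching_data import MatchingData
    #
    #   target = np.random.rand(50, 40, 60)
    #   template = target[15:25, 10:20, 30:40]
    #   matching_data = MatchingData(
    #       target=target,
    #       template=template,
    #       template_mask=np.ones_like(template),
    #       rotations=np.eye(3).reshape(1, 3, 3),
    #   )
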
    @staticmethod
    def _shape_to_slice(shape: Tuple[int]) -> Tuple[slice]:
        return tuple(slice(0, dim) for dim in shape)

    @classmethod
    def _slice_to_mesh(cls, slice_variable: Tuple[slice], shape: Tuple[int]) -> NDArray:
        if slice_variable is None:
            slice_variable = cls._shape_to_slice(shape)
        ranges = [range(slc.start, slc.stop) for slc in slice_variable]
        indices = np.meshgrid(*ranges, sparse=True, indexing="ij")
        return indices

    @staticmethod
    def _load_array(arr: BackendArray) -> BackendArray:
        """Load ``arr``; if ``arr`` is a :obj:`numpy.memmap`, reload it from disk."""
        if isinstance(arr, np.memmap):
            return np.memmap(arr.filename, mode="r", shape=arr.shape, dtype=arr.dtype)
        return arr

    def subset_array(
        self,
        arr: NDArray,
        arr_slice: Tuple[slice],
        padding: NDArray,
        invert: bool = False,
    ) -> NDArray:
        """
        Extract a subset of the input array according to the given slice and
        apply padding. If the padding exceeds the array dimensions, the
        padded regions are filled by reflection of the boundaries. Otherwise,
        the values in ``arr`` are used.

        Parameters
        ----------
        arr : NDArray
            The input array from which a subset is extracted.
        arr_slice : tuple of slice
            Defines the region of the input array to be extracted.
        padding : NDArray
            Padding values for each dimension.
        invert : bool, optional
            Whether the returned array should be inverted.

        Returns
        -------
        NDArray
            Subset of the input array with padding applied.
        """
        padding = be.to_numpy_array(padding)
        padding = np.maximum(padding, 0).astype(int)

        slice_start = np.array([x.start for x in arr_slice], dtype=int)
        slice_stop = np.array([x.stop for x in arr_slice], dtype=int)

        # We are deviating from our typical right_pad + mod here
        # because cropping from full convolution mode to target shape
        # is defined from the perspective of the origin
        right_pad = np.divide(padding, 2).astype(int)
        left_pad = np.add(right_pad, np.mod(padding, 2))

        data_voxels_left = np.minimum(slice_start, left_pad)
        data_voxels_right = np.minimum(
            np.subtract(arr.shape, slice_stop), right_pad
        ).astype(int)

        arr_start = np.subtract(slice_start, data_voxels_left)
        arr_stop = np.add(slice_stop, data_voxels_right)
        arr_slice = tuple(slice(*pos) for pos in zip(arr_start, arr_stop))
        arr_mesh = self._slice_to_mesh(arr_slice, arr.shape)

        # Note: unlike joblib memmaps, the memmaps created by Density
        # are guaranteed to only contain the array of interest
        if isinstance(arr, Density):
            if isinstance(arr.data, np.memmap):
                arr = Density.from_file(arr.data.filename, subset=arr_slice).data
            else:
                arr = np.asarray(arr.data[*arr_mesh])
        else:
            arr = np.asarray(arr[*arr_mesh])

        padding = tuple(
            (left, right)
            for left, right in zip(
                np.subtract(left_pad, data_voxels_left),
                np.subtract(right_pad, data_voxels_right),
            )
        )
        # The reflections are later cropped from the scores
        arr = np.pad(arr, padding, mode="reflect")

        if invert:
            arr = -arr
        return arr

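    # Sketch of the padding behaviour above, assuming a 1D array for brevity:
    # the requested slice is grown by roughly half the padding on either side,
    # values available in ``arr`` are taken from ``arr`` and only the part that
    # falls outside its boundary is filled by reflection.
    #
    #   arr = np.arange(10)
    #   matching_data.subset_array(arr=arr, arr_slice=(slice(0, 4),), padding=np.array([4]))
    #   # -> array([2, 1, 0, 1, 2, 3, 4, 5]); the right half of the padding comes
    #   #    from arr itself, the left half is reflected at the boundary.
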
    def subset_by_slice(
        self,
        target_slice: Tuple[slice] = None,
        template_slice: Tuple[slice] = None,
        target_pad: NDArray = None,
        template_pad: NDArray = None,
        invert_target: bool = False,
    ) -> "MatchingData":
        """
        Subset class instance based on slices.

        Parameters
        ----------
        target_slice : tuple of slice, optional
            Target subset to use, all by default.
        template_slice : tuple of slice, optional
            Template subset to use, all by default.
        target_pad : BackendArray, optional
            Target padding, zero by default.
        template_pad : BackendArray, optional
            Template padding, zero by default.

        Returns
        -------
        :py:class:`MatchingData`
            Newly allocated subset of class instance.

        Examples
        --------
        >>> import numpy as np
        >>> from tme.matching_data import MatchingData
        >>> target = np.random.rand(50,40,60)
        >>> template = target[15:25, 10:20, 30:40]
        >>> matching_data = MatchingData(target=target, template=template)
        >>> subset = matching_data.subset_by_slice(
        ...     target_slice=(slice(0, 10), slice(10,20), slice(15,35))
        ... )
        """
        if target_slice is None:
            target_slice = self._shape_to_slice(self._target.shape)
        if template_slice is None:
            template_slice = self._shape_to_slice(self._template.shape)

        if target_pad is None:
            target_pad = np.zeros(len(self._target.shape), dtype=int)
        if template_pad is None:
            template_pad = np.zeros(len(self._template.shape), dtype=int)

        target_mask, template_mask = None, None
        target_subset = self.subset_array(
            self._target, target_slice, target_pad, invert=self._invert_target
        )
        template_subset = self.subset_array(
            arr=self._template, arr_slice=template_slice, padding=template_pad
        )
        if self._target_mask is not None:
            mask_slice = zip(target_slice, self._target_mask.shape)
            mask_slice = tuple(x if t != 1 else slice(0, 1) for x, t in mask_slice)
            target_mask = self.subset_array(
                arr=self._target_mask, arr_slice=mask_slice, padding=target_pad
            )
        if self._template_mask is not None:
            mask_slice = zip(template_slice, self._template_mask.shape)
            mask_slice = tuple(x if t != 1 else slice(0, 1) for x, t in mask_slice)
            template_mask = self.subset_array(
                arr=self._template_mask, arr_slice=mask_slice, padding=template_pad
            )

        ret = self.__class__(
            target=target_subset,
            template=template_subset,
            template_mask=template_mask,
            target_mask=target_mask,
            rotations=self.rotations,
            invert_target=self._invert_target,
        )

        # Deal with splitting offsets
        mask = np.subtract(1, self._template_batch).astype(bool)
        target_offset = np.zeros(len(self._output_target_shape), dtype=int)
        target_offset[mask] = [x.start for x in target_slice]
        mask = np.subtract(1, self._target_batch).astype(bool)
        template_offset = np.zeros(len(self._output_template_shape), dtype=int)
        template_offset[mask] = [x.start for x in template_slice]
        ret._translation_offset = tuple(x for x in target_offset)

        ret.target_filter = self.target_filter
        ret.template_filter = self.template_filter

        ret.set_matching_dimension(
            target_dim=getattr(self, "_target_dim", None),
            template_dim=getattr(self, "_template_dim", None),
        )

        return ret

    def to_backend(self):
        """
        Transfer and convert types of internal data arrays to the current backend.

        Examples
        --------
        >>> matching_data.to_backend()
        """
        backend_arr = type(be.zeros((1), dtype=be._float_dtype))
        for attr_name, attr_value in vars(self).items():
            converted_array = None
            if isinstance(attr_value, np.ndarray):
                converted_array = be.to_backend_array(attr_value.copy())
            elif isinstance(attr_value, backend_arr):
                converted_array = be.to_backend_array(attr_value)
            else:
                continue

            current_dtype = be.get_fundamental_dtype(converted_array)
            target_dtype = be._fundamental_dtypes[current_dtype]

            # Optional, but scores are float so we avoid casting and potential issues
            if attr_name in ("_template", "_template_mask", "_target", "_target_mask"):
                target_dtype = be._float_dtype

            if target_dtype != current_dtype:
                converted_array = be.astype(converted_array, target_dtype)

            setattr(self, attr_name, converted_array)

    def set_matching_dimension(self, target_dim: int = None, template_dim: int = None):
        """
        Sets matching dimensions for target and template.

        Parameters
        ----------
        target_dim : int, optional
            Target batch dimension, None by default.
        template_dim : int, optional
            Template batch dimension, None by default.

        Examples
        --------
        >>> matching_data.set_matching_dimension(target_dim=0, template_dim=None)

        Notes
        -----
        If target and template share a batch dimension, the target will take
        precedence and the template dimension will be shifted to the right. If target
        and template have the same dimension, but target specifies batch dimensions,
        the leftmost template dimensions are assumed to be collapse dimensions.
        """
        target_ndim = len(self._target.shape)
        _, target_dims = self._compute_batch_dims(target_dim, ndim=target_ndim)
        template_ndim = len(self._template.shape)
        _, template_dims = self._compute_batch_dims(template_dim, ndim=template_ndim)

        target_ndim -= len(target_dims)
        template_ndim -= len(template_dims)

        if target_ndim != template_ndim:
            raise ValueError(
                f"Dimension mismatch: Target ({target_ndim}) Template ({template_ndim})."
            )
        self._set_matching_dimension(
            target_dims=target_dims, template_dims=template_dims
        )

    def _set_matching_dimension(
        self, target_dims: Tuple[int] = (), template_dims: Tuple[int] = ()
    ):
        self._target_dim, self._template_dim = target_dims, template_dims

        target_ndim, template_ndim = len(self._target.shape), len(self._template.shape)
        batch_dims = len(target_dims) + len(template_dims)
        target_measurement_dims = target_ndim - len(target_dims)
        collapse_dims = max(
            template_ndim - len(template_dims) - target_measurement_dims, 0
        )
        matching_dims = target_measurement_dims + batch_dims

        target_shape = np.full(shape=matching_dims, fill_value=1, dtype=int)
        template_shape = np.full(shape=matching_dims, fill_value=1, dtype=int)
        template_batch = np.full(shape=matching_dims, fill_value=1, dtype=int)
        target_batch = np.full(shape=matching_dims, fill_value=1, dtype=int)

        target_index, template_index = 0, 0
        for k in range(matching_dims):
            target_dim = k - target_index
            template_dim = k - template_index

            if target_dim in target_dims:
                target_shape[k] = self._target.shape[target_dim]
                template_batch[k] = 0
                if target_index == len(template_dims) and collapse_dims > 0:
                    template_shape[k] = self._template.shape[template_dim]
                    collapse_dims -= 1
                template_index += 1
                continue

            if template_dim in template_dims:
                template_shape[k] = self._template.shape[template_dim]
                target_batch[k] = 0
                target_index += 1
                continue

            target_batch[k] = template_batch[k] = 0
            if target_dim < target_ndim:
                target_shape[k] = self._target.shape[target_dim]
            if template_dim < template_ndim:
                template_shape[k] = self._template.shape[template_dim]

        batch_mask = np.logical_or(target_batch, template_batch)
        self._output_target_shape = tuple(int(x) for x in target_shape)
        self._output_template_shape = tuple(int(x) for x in template_shape)
        self._batch_mask = tuple(int(x) for x in batch_mask)
        self._template_batch = tuple(int(x) for x in template_batch)
        self._target_batch = tuple(int(x) for x in target_batch)

        output_shape = np.add(
            self._output_target_shape,
            np.multiply(self._template_batch, self._output_template_shape),
        )
        output_shape = np.subtract(output_shape, self._template_batch)
        self._output_shape = tuple(int(x) for x in output_shape)

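    # Illustration of the resulting bookkeeping, assuming a stack of five 3D
    # targets matched against one 3D template (internal attributes shown only
    # to make the batch handling concrete):
    #
    #   data = MatchingData(
    #       target=np.zeros((5, 64, 64, 64)), template=np.zeros((32, 32, 32))
    #   )
    #   data.set_matching_dimension(target_dim=0, template_dim=None)
    #   # data._output_target_shape   == (5, 64, 64, 64)
    #   # data._output_template_shape == (1, 32, 32, 32)
    #   # data._batch_mask            == (1, 0, 0, 0)
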
    @staticmethod
    def _compute_batch_dims(batch_dims: Tuple[int], ndim: int) -> Tuple:
        """
        Computes a mask for the batch dimensions and the validated batch dimensions.

        Parameters
        ----------
        batch_dims : tuple of int
            A tuple of integers representing the batch dimensions.
        ndim : int
            The number of dimensions of the array.

        Returns
        -------
        Tuple[ArrayLike, tuple of int]
            Mask and the corresponding batch dimensions.

        Raises
        ------
        ValueError
            If any dimension in batch_dims is not less than ndim.
        """
        mask = np.zeros(ndim, dtype=int)
        if batch_dims is None:
            return mask, ()

        if isinstance(batch_dims, int):
            batch_dims = (batch_dims,)

        for dim in batch_dims:
            if dim < ndim:
                mask[dim] = 1
                continue
            raise ValueError(f"Batch indices needs to be < {ndim}, got {dim}.")

        return mask, batch_dims

    @staticmethod
    def _batch_shape(shape: Tuple[int], mask: Tuple[int], keepdims=True) -> Tuple[int]:
        if keepdims:
            return tuple(x if y == 0 else 1 for x, y in zip(shape, mask))
        return tuple(x for x, y in zip(shape, mask) if y == 0)

    @staticmethod
    def _batch_iter(shape: Tuple[int], mask: Tuple[int]) -> Generator:
        def _recursive_gen(current_shape, current_mask, current_slices):
            if not current_shape:
                yield current_slices
                return

            if current_mask[0] == 1:
                for i in range(current_shape[0]):
                    new_slices = current_slices + (slice(i, i + 1),)
                    yield from _recursive_gen(
                        current_shape[1:], current_mask[1:], new_slices
                    )
            else:
                new_slices = current_slices + (slice(None),)
                yield from _recursive_gen(
                    current_shape[1:], current_mask[1:], new_slices
                )

        return _recursive_gen(shape, mask, ())

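    # What ``_batch_iter`` yields, for orientation: axes flagged in ``mask`` are
    # enumerated one slab at a time, all remaining axes are kept whole.
    #
    #   list(MatchingData._batch_iter(shape=(2, 5, 5), mask=(1, 0, 0)))
    #   # [(slice(0, 1), slice(None), slice(None)),
    #   #  (slice(1, 2), slice(None), slice(None))]
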
    @staticmethod
    def _batch_axis(mask: Tuple[int]) -> Tuple[int]:
        return tuple(i for i in range(len(mask)) if mask[i] == 0)

    def target_padding(self, pad_target: bool = False) -> Tuple[int]:
        """
        Computes the padding of the target to the full convolution
        shape given the registered template.

        Parameters
        ----------
        pad_target : bool, optional
            Whether to pad the target, defaults to False.

        Returns
        -------
        tuple of int
            Padding along each dimension.

        Examples
        --------
        >>> matching_data.target_padding(pad_target=True)
        """
        padding = np.zeros(len(self._output_target_shape), dtype=int)
        if pad_target:
            padding = np.subtract(self._output_template_shape, 1)
            if hasattr(self, "_target_batch"):
                padding = np.multiply(padding, np.subtract(1, self._target_batch))

        if hasattr(self, "_template_batch"):
            padding = tuple(x for x, i in zip(padding, self._template_batch) if i == 0)

        return tuple(int(x) for x in padding)

    @staticmethod
    def _fourier_padding(
        target_shape: Tuple[int],
        template_shape: Tuple[int],
        pad_fourier: bool,
        batch_mask: Tuple[int] = None,
    ) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
        fourier_pad = template_shape
        fourier_shift = np.zeros_like(template_shape)

        if batch_mask is None:
            batch_mask = np.zeros_like(template_shape)
        batch_mask = np.asarray(batch_mask)

        if not pad_fourier:
            fourier_pad = np.ones(len(fourier_pad), dtype=int)
        fourier_pad = np.multiply(fourier_pad, 1 - batch_mask)
        fourier_pad = np.add(fourier_pad, batch_mask)

        pad_shape = np.maximum(target_shape, template_shape)
        ret = be.compute_convolution_shapes(pad_shape, fourier_pad)
        conv_shape, fast_shape, fast_ft_shape = ret

        template_mod = np.mod(template_shape, 2)
        if not pad_fourier:
            fourier_shift = 1 - np.divide(template_shape, 2).astype(int)
            fourier_shift = np.subtract(fourier_shift, template_mod)

        shape_diff = np.multiply(
            np.subtract(target_shape, template_shape), 1 - batch_mask
        )
        shape_mask = shape_diff < 0
        if np.sum(shape_mask):
            shape_shift = np.divide(shape_diff, 2)
            offset = np.mod(shape_diff, 2)
            if pad_fourier:
                offset = -np.subtract(
                    offset,
                    np.logical_and(np.mod(target_shape, 2) == 0, template_mod == 1),
                )
            else:
                warnings.warn(
                    "Template is larger than target and padding is turned off. Consider "
                    "swapping them or activate padding. Correcting the shift for now."
                )
            shape_shift = np.multiply(np.add(shape_shift, offset), shape_mask)
            fourier_shift = np.subtract(fourier_shift, shape_shift).astype(int)

        fourier_shift = tuple(np.multiply(fourier_shift, 1 - batch_mask).astype(int))

        return tuple(conv_shape), tuple(fast_shape), tuple(fast_ft_shape), fourier_shift

    def fourier_padding(
        self, pad_fourier: bool = False
    ) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
        """
        Computes efficient shapes for Fourier transforms and potential associated shifts.

        Parameters
        ----------
        pad_fourier : bool, optional
            If True, returns the shape of the full convolution, defined as the sum of
            target shape and template shape minus one, False by default.

        Returns
        -------
        Tuple[tuple of int, tuple of int, tuple of int, tuple of int]
            Tuple with convolution, forward FT, inverse FT shape and corresponding shift.

        Examples
        --------
        >>> conv, fwd, inv, shift = matching_data.fourier_padding(pad_fourier=True)
        """
        return self._fourier_padding(
            target_shape=be.to_numpy_array(self._output_target_shape),
            template_shape=be.to_numpy_array(self._output_template_shape),
            batch_mask=be.to_numpy_array(self._batch_mask),
            pad_fourier=pad_fourier,
        )

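    # Reading the return values, assuming a target at least as large as the
    # template and no batch axes: with pad_fourier=True the convolution shape
    # per axis is target + template - 1, while pad_fourier=False keeps the
    # target shape and the returned shift maps scores back to target coordinates.
    #
    #   conv, fast, fast_ft, shift = matching_data.fourier_padding(pad_fourier=False)
    #   # fast is the shape used for the forward transform, fast_ft the
    #   # corresponding Fourier-space shape.
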
    def computation_schedule(
        self,
        matching_method: str = "FLCSphericalMask",
        max_cores: int = 1,
        use_gpu: bool = False,
        pad_fourier: bool = False,
        pad_target_edges: bool = False,
        analyzer_method: str = None,
        available_memory: int = None,
        max_splits: int = 256,
    ) -> Tuple[Dict, Tuple]:
        """
        Computes a parallelization schedule for a given template matching operation.

        Parameters
        ----------
        matching_method : str
            Matching method to use, default "FLCSphericalMask".
        max_cores : int, optional
            Maximum number of CPU cores to use, default 1.
        use_gpu : bool, optional
            Whether to utilize GPU acceleration, default False.
        pad_fourier : bool, optional
            Apply Fourier padding, default False.
        pad_target_edges : bool, optional
            Apply padding to target edges, default False.
        analyzer_method : str, optional
            Method used for score analysis, default None.
        available_memory : int, optional
            Available memory in bytes. If None, uses all available system memory.
        max_splits : int, optional
            Maximum number of splits to consider, default 256.

        Returns
        -------
        target_splits : dict
            Optimal splits for each axis of the target tensor.
        schedule : tuple
            (n_outer_jobs, n_inner_jobs_per_outer) defining the parallelization schedule.
        """

        if available_memory is None:
            available_memory = be.get_available_memory() * be.device_count()

        _template = self._output_template_shape
        shape1 = np.broadcast_shapes(
            self._output_target_shape,
            self._batch_shape(_template, np.subtract(1, self._template_batch)),
        )

        shape2 = tuple(0 for _ in _template)
        if pad_fourier:
            shape2 = np.multiply(_template, np.subtract(1, self._batch_mask))

        padding = tuple(0 for _ in self._output_target_shape)
        if pad_target_edges:
            padding = tuple(
                x if y == 0 else 1 for x, y in zip(_template, self._template_batch)
            )

        return compute_parallelization_schedule(
            shape1=shape1,
            shape2=shape2,
            shape1_padding=padding,
            max_cores=max_cores,
            max_ram=available_memory,
            matching_method=matching_method,
            analyzer_method=analyzer_method,
            backend=be._backend_name,
            float_nbytes=be.datatype_bytes(be._float_dtype),
            complex_nbytes=be.datatype_bytes(be._complex_dtype),
            integer_nbytes=be.datatype_bytes(be._int_dtype),
            split_only_outer=use_gpu,
            split_axes=self._target_dim if len(self._target_dim) else None,
            max_splits=max_splits,
        )

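    # Hypothetical call on a single workstation; the exact numbers depend on the
    # available memory reported by the backend:
    #
    #   splits, schedule = matching_data.computation_schedule(max_cores=4)
    #   # e.g. splits == {0: 2} would split target axis 0 in two, and
    #   # schedule == (2, 2) would run two outer jobs with two cores each.
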
    @property
    def rotations(self):
        """Return stored rotation matrices."""
        return self._rotations

    @rotations.setter
    def rotations(self, rotations: BackendArray):
        """
        Set :py:attr:`MatchingData.rotations`.

        Parameters
        ----------
        rotations : BackendArray
            Rotation matrices with shape (d, d) or (n, d, d).
        """
        if rotations is None:
            print("No rotations provided, assuming identity for now.")
            rotations = np.eye(len(self._target.shape))

        if rotations.ndim not in (2, 3):
            raise ValueError("Rotations have to be a rank 2 or 3 array.")
        elif rotations.ndim == 2:
            print("Reshaping rotations array to rank 3.")
            rotations = rotations.reshape(1, *rotations.shape)
        self._rotations = rotations.astype(np.float32)

    @staticmethod
    def _get_data(
        attribute,
        output_shape: Tuple[int],
        reverse: bool = False,
        axis: Tuple[int] = None,
    ):
        if isinstance(attribute, Density):
            attribute = attribute.data

        if attribute is not None:
            if reverse:
                rev_axis = tuple(i for i in range(attribute.ndim) if i not in axis)
                attribute = be.reverse(attribute, axis=rev_axis)
            attribute = attribute.reshape(tuple(int(x) for x in output_shape))

        return attribute

    @property
    def target(self) -> BackendArray:
        """Return the target."""
        return self._get_data(self._target, self._output_target_shape, False)

    @property
    def target_mask(self) -> BackendArray:
        """Return the target mask."""
        target_mask = getattr(self, "_target_mask", None)
        if target_mask is None:
            return None

        _output_shape = self._output_target_shape
        if be.size(target_mask) != np.prod(_output_shape):
            _output_shape = self._batch_shape(_output_shape, self._target_batch, True)

        return self._get_data(target_mask, _output_shape, False)

    @property
    def template(self) -> BackendArray:
        """Return the reversed template."""
        _output_shape = self._output_template_shape
        return self._get_data(self._template, _output_shape, True, self._template_dim)

    @property
    def template_mask(self) -> BackendArray:
        """Return the reversed template mask."""
        template_mask = getattr(self, "_template_mask", None)
        if template_mask is None:
            return None

        _output_shape = self._output_template_shape
        if np.prod([int(i) for i in template_mask.shape]) != np.prod(_output_shape):
            _output_shape = self._batch_shape(_output_shape, self._template_batch, True)

        return self._get_data(template_mask, _output_shape, True, self._template_dim)

    @target.setter
    def target(self, arr: NDArray):
        """
        Set :py:attr:`MatchingData.target`.

        Parameters
        ----------
        arr : NDArray
            Array to set as the target.
        """
        self._target = arr

    @template.setter
    def template(self, arr: NDArray):
        """
        Set :py:attr:`MatchingData.template` and initialize
        :py:attr:`MatchingData.template_mask` to an uninformative
        mask filled with ones if not already defined.

        Parameters
        ----------
        arr : NDArray
            Array to set as the template.
        """
        self._template = arr
        if getattr(self, "_template_mask", None) is None:
            self._template_mask = np.full(
                shape=arr.shape, dtype=np.float32, fill_value=1
            )

    @staticmethod
    def _set_mask(mask, shape: Tuple[int]):
        if mask is not None:
            if np.broadcast_shapes(mask.shape, shape) != shape:
                raise ValueError("Mask and data shape need to be broadcastable.")
        return mask

    @target_mask.setter
    def target_mask(self, arr: NDArray):
        """
        Set :py:attr:`MatchingData.target_mask`.

        Parameters
        ----------
        arr : NDArray
            Array to set as the target_mask.
        """
        self._target_mask = self._set_mask(mask=arr, shape=self._target.shape)

    @template_mask.setter
    def template_mask(self, arr: NDArray):
        """
        Set :py:attr:`MatchingData.template_mask`.

        Parameters
        ----------
        arr : NDArray
            Array to set as the template_mask.
        """
        self._template_mask = self._set_mask(mask=arr, shape=self._template.shape)

    @staticmethod
    def _set_filter(composable_filter) -> Optional[Compose]:
        if composable_filter is None:
            return None

        if not isinstance(composable_filter, Compose):
            warnings.warn(
                "Custom filters are not sanitized and need to be correctly shaped."
            )

        return composable_filter

    @property
    def template_filter(self) -> Optional[Compose]:
        """
        Returns the template filter.

        Returns
        -------
        :py:class:`tme.filters.compose.Compose` | BackendArray | None
            Composable filter, a backend array or None.
        """
        return getattr(self, "_template_filter", None)

    @property
    def target_filter(self) -> Optional[Compose]:
        """
        Returns the target filter.

        Returns
        -------
        :py:class:`tme.filters.compose.Compose` | BackendArray | None
            Composable filter, a backend array or None.
        """
        return getattr(self, "_target_filter", None)

    @template_filter.setter
    def template_filter(self, template_filter):
        self._template_filter = self._set_filter(template_filter)

    @target_filter.setter
    def target_filter(self, target_filter):
        self._target_filter = self._set_filter(target_filter)

    def _split_rotations_on_jobs(self, n_jobs: int) -> List[NDArray]:
        """
        Split the rotation matrices into parts based on the number of jobs.

        Parameters
        ----------
        n_jobs : int
            Number of jobs for splitting.

        Returns
        -------
        list of NDArray
            List of split rotation matrices.
        """
        nrot_per_job = self.rotations.shape[0] // n_jobs
        rot_list = []
        for n in range(n_jobs):
            init_rot = n * nrot_per_job
            end_rot = init_rot + nrot_per_job
            if n == n_jobs - 1:
                end_rot = None
            rot_list.append(self.rotations[init_rot:end_rot])
        return rot_list

    def _free_data(self):
        """
        Dereference data arrays owned by the class instance.
        """
        attrs = ("_target", "_template", "_template_mask", "_target_mask")
        for attr in attrs:
            setattr(self, attr, None)
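
# A rough end-to-end sketch of how this class is typically driven, using only
# the methods defined above; the downstream matching routines that consume the
# prepared instance live elsewhere in the package:
#
#   import numpy as np
#   from tme.matching_data import MatchingData
#
#   target = np.random.rand(64, 64, 64)
#   template = target[16:32, 16:32, 16:32]
#   matching_data = MatchingData(target=target, template=template)
#   matching_data.rotations = np.eye(3).reshape(1, 3, 3)
#
#   conv, fast, fast_ft, shift = matching_data.fourier_padding()
#   matching_data.to_backend()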