pytme-0.2.9-cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
  2. pytme-0.2.9.data/scripts/match_template.py +1135 -0
  3. pytme-0.2.9.data/scripts/postprocess.py +622 -0
  4. pytme-0.2.9.data/scripts/preprocess.py +209 -0
  5. pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
  6. pytme-0.2.9.dist-info/METADATA +95 -0
  7. pytme-0.2.9.dist-info/RECORD +119 -0
  8. pytme-0.2.9.dist-info/WHEEL +5 -0
  9. pytme-0.2.9.dist-info/entry_points.txt +6 -0
  10. pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
  11. pytme-0.2.9.dist-info/top_level.txt +3 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +97 -0
  14. scripts/match_template.py +1135 -0
  15. scripts/postprocess.py +622 -0
  16. scripts/preprocess.py +209 -0
  17. scripts/preprocessor_gui.py +1227 -0
  18. tests/__init__.py +0 -0
  19. tests/data/Blurring/blob_width18.npy +0 -0
  20. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  21. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  22. tests/data/Blurring/hamming_width6.npy +0 -0
  23. tests/data/Blurring/kaiserb_width18.npy +0 -0
  24. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  25. tests/data/Blurring/mean_size5.npy +0 -0
  26. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  27. tests/data/Blurring/rank_rank3.npy +0 -0
  28. tests/data/Maps/.DS_Store +0 -0
  29. tests/data/Maps/emd_8621.mrc.gz +0 -0
  30. tests/data/README.md +2 -0
  31. tests/data/Raw/em_map.map +0 -0
  32. tests/data/Structures/.DS_Store +0 -0
  33. tests/data/Structures/1pdj.cif +3339 -0
  34. tests/data/Structures/1pdj.pdb +1429 -0
  35. tests/data/Structures/5khe.cif +3685 -0
  36. tests/data/Structures/5khe.ent +2210 -0
  37. tests/data/Structures/5khe.pdb +2210 -0
  38. tests/data/Structures/5uz4.cif +70548 -0
  39. tests/preprocessing/__init__.py +0 -0
  40. tests/preprocessing/test_compose.py +76 -0
  41. tests/preprocessing/test_frequency_filters.py +178 -0
  42. tests/preprocessing/test_preprocessor.py +136 -0
  43. tests/preprocessing/test_utils.py +79 -0
  44. tests/test_analyzer.py +216 -0
  45. tests/test_backends.py +446 -0
  46. tests/test_density.py +503 -0
  47. tests/test_extensions.py +130 -0
  48. tests/test_matching_cli.py +283 -0
  49. tests/test_matching_data.py +162 -0
  50. tests/test_matching_exhaustive.py +124 -0
  51. tests/test_matching_memory.py +30 -0
  52. tests/test_matching_optimization.py +226 -0
  53. tests/test_matching_utils.py +189 -0
  54. tests/test_orientations.py +175 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_rotations.py +153 -0
  57. tests/test_structure.py +247 -0
  58. tme/__init__.py +6 -0
  59. tme/__version__.py +1 -0
  60. tme/analyzer/__init__.py +2 -0
  61. tme/analyzer/_utils.py +186 -0
  62. tme/analyzer/aggregation.py +577 -0
  63. tme/analyzer/peaks.py +953 -0
  64. tme/backends/__init__.py +171 -0
  65. tme/backends/_cupy_utils.py +734 -0
  66. tme/backends/_jax_utils.py +188 -0
  67. tme/backends/cupy_backend.py +294 -0
  68. tme/backends/jax_backend.py +314 -0
  69. tme/backends/matching_backend.py +1270 -0
  70. tme/backends/mlx_backend.py +241 -0
  71. tme/backends/npfftw_backend.py +583 -0
  72. tme/backends/pytorch_backend.py +430 -0
  73. tme/data/__init__.py +0 -0
  74. tme/data/c48n309.npy +0 -0
  75. tme/data/c48n527.npy +0 -0
  76. tme/data/c48n9.npy +0 -0
  77. tme/data/c48u1.npy +0 -0
  78. tme/data/c48u1153.npy +0 -0
  79. tme/data/c48u1201.npy +0 -0
  80. tme/data/c48u1641.npy +0 -0
  81. tme/data/c48u181.npy +0 -0
  82. tme/data/c48u2219.npy +0 -0
  83. tme/data/c48u27.npy +0 -0
  84. tme/data/c48u2947.npy +0 -0
  85. tme/data/c48u3733.npy +0 -0
  86. tme/data/c48u4749.npy +0 -0
  87. tme/data/c48u5879.npy +0 -0
  88. tme/data/c48u7111.npy +0 -0
  89. tme/data/c48u815.npy +0 -0
  90. tme/data/c48u83.npy +0 -0
  91. tme/data/c48u8649.npy +0 -0
  92. tme/data/c600v.npy +0 -0
  93. tme/data/c600vc.npy +0 -0
  94. tme/data/metadata.yaml +80 -0
  95. tme/data/quat_to_numpy.py +42 -0
  96. tme/data/scattering_factors.pickle +0 -0
  97. tme/density.py +2263 -0
  98. tme/extensions.cpython-311-darwin.so +0 -0
  99. tme/external/bindings.cpp +332 -0
  100. tme/filters/__init__.py +6 -0
  101. tme/filters/_utils.py +311 -0
  102. tme/filters/bandpass.py +230 -0
  103. tme/filters/compose.py +81 -0
  104. tme/filters/ctf.py +393 -0
  105. tme/filters/reconstruction.py +160 -0
  106. tme/filters/wedge.py +542 -0
  107. tme/filters/whitening.py +191 -0
  108. tme/matching_data.py +863 -0
  109. tme/matching_exhaustive.py +497 -0
  110. tme/matching_optimization.py +1311 -0
  111. tme/matching_scores.py +1183 -0
  112. tme/matching_utils.py +1188 -0
  113. tme/memory.py +337 -0
  114. tme/orientations.py +598 -0
  115. tme/parser.py +685 -0
  116. tme/preprocessor.py +1329 -0
  117. tme/rotations.py +350 -0
  118. tme/structure.py +1864 -0
  119. tme/types.py +13 -0
@@ -0,0 +1,314 @@
1
+ """ Backend using jax for template matching.
2
+
3
+ Copyright (c) 2023-2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ from functools import wraps
9
+ from typing import Tuple, List, Callable
10
+
11
+ from ..types import BackendArray
12
+ from .npfftw_backend import NumpyFFTWBackend, shm_type
13
+
14
+
15
def emulate_out(func):
    """
    Add an ``out`` keyword argument to ``func``.

    jax arrays are immutable, so ``out`` cannot be written in place. When
    ``out`` is given, the result of ``func`` is written into a copy of ``out``
    (via ``out.at[...].set``), keeping the shape and dtype of ``out``, and that
    copy is returned. Callers must use the return value. Without ``out`` the
    raw result of ``func`` is returned.

    Parameters
    ----------
    func : callable
        Function whose output should be redirectable to ``out``.

    Returns
    -------
    callable
        Wrapper accepting the same arguments as ``func`` plus ``out``.
    """

    @wraps(func)
    def inner(*args, out=None, **kwargs):
        ret = func(*args, **kwargs)
        if out is None:
            return ret
        # Use ``...`` rather than ``:`` so 0-d ``out`` arrays work as well.
        return out.at[...].set(ret)

    return inner
29
+
30
+
31
class JaxBackend(NumpyFFTWBackend):
    """
    A jax-based matching backend.

    jax arrays are immutable, so methods that conceptually write into an
    ``out`` buffer return a new array instead. Callers must use the returned
    value.
    """

    def __init__(self, float_dtype=None, complex_dtype=None, int_dtype=None, **kwargs):
        """
        Parameters
        ----------
        float_dtype : dtype, optional
            Floating point dtype, defaults to ``jnp.float32``. Also used as the
            overflow-safe dtype.
        complex_dtype : dtype, optional
            Complex dtype, defaults to ``jnp.complex64``.
        int_dtype : dtype, optional
            Integer dtype, defaults to ``jnp.int32``.
        **kwargs
            Accepted for interface compatibility; not used here.
        """
        import jax.scipy as jsp
        import jax.numpy as jnp

        float_dtype = jnp.float32 if float_dtype is None else float_dtype
        complex_dtype = jnp.complex64 if complex_dtype is None else complex_dtype
        int_dtype = jnp.int32 if int_dtype is None else int_dtype

        super().__init__(
            array_backend=jnp,
            float_dtype=float_dtype,
            complex_dtype=complex_dtype,
            int_dtype=int_dtype,
            overflow_safe_dtype=float_dtype,
        )
        self.scipy = jsp
        self._create_ufuncs()
        try:
            from ._jax_utils import scan as _

            # Expose the jax-native scan only if its helper module is importable.
            self.scan = self._scan
        except Exception:
            # Deliberate best effort: keep the inherited scan implementation.
            pass

    def from_sharedarr(self, arr: BackendArray) -> BackendArray:
        """Return ``arr`` unchanged; this backend does not use shared memory."""
        return arr

    @staticmethod
    def to_sharedarr(arr: BackendArray, shared_memory_handler: type = None) -> shm_type:
        """Return ``arr`` unchanged; ``shared_memory_handler`` is ignored."""
        return arr

    @staticmethod
    def at(arr, idx, value) -> BackendArray:
        """Return a copy of ``arr`` with ``arr[idx]`` set to ``value``."""
        arr = arr.at[idx].set(value)
        return arr

    def topleft_pad(
        self, arr: BackendArray, shape: Tuple[int], padval: int = 0
    ) -> BackendArray:
        """
        Place ``arr`` in the top-left corner of a ``shape``-sized array.

        Axes where ``arr`` is larger than ``shape`` are cropped, and the
        remaining entries are filled with ``padval``.
        """
        b = self.full(shape=shape, dtype=arr.dtype, fill_value=padval)
        aind = [slice(None, None)] * arr.ndim
        bind = [slice(None, None)] * arr.ndim
        for i in range(arr.ndim):
            if arr.shape[i] > shape[i]:
                aind[i] = slice(0, shape[i])
            elif arr.shape[i] < shape[i]:
                bind[i] = slice(0, arr.shape[i])
        b = b.at[tuple(bind)].set(arr[tuple(aind)])
        return b

    def _create_ufuncs(self):
        """Bind array-backend functions as instance attributes."""
        # These elementwise functions gain an emulated ``out`` argument.
        ufuncs = (
            "add",
            "subtract",
            "multiply",
            "divide",
            "square",
            "sqrt",
            "maximum",
            "exp",
        )
        # Instance attributes bypass the descriptor protocol, so the functions
        # are stored directly rather than wrapped in ``staticmethod``. Such
        # wrappers are not callable before Python 3.10.
        for ufunc in ufuncs:
            setattr(self, ufunc, emulate_out(getattr(self._array_backend, ufunc)))

        for ufunc in ("zeros", "full"):
            setattr(self, ufunc, getattr(self._array_backend, ufunc))

    def fill(self, arr: BackendArray, value: float) -> BackendArray:
        """Return a new array shaped like ``arr`` filled with ``value``."""
        return self._array_backend.full(
            shape=arr.shape, dtype=arr.dtype, fill_value=value
        )

    def build_fft(
        self,
        fwd_shape: Tuple[int],
        inv_shape: Tuple[int] = None,
        inv_output_shape: Tuple[int] = None,
        fwd_axes: Tuple[int] = None,
        inv_axes: Tuple[int] = None,
        **kwargs,
    ) -> Tuple[Callable, Callable]:
        """
        Return ``(rfftn, irfftn)`` closures with shapes and axes fixed.

        ``inv_shape`` is accepted for interface compatibility but is unused.
        The ``out`` argument of the returned callables is ignored; they return
        new arrays.
        """
        rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
        irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
        irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)

        def rfftn(arr, out=None, s=rfft_shape, axes=fwd_axes):
            return self._array_backend.fft.rfftn(arr, s=s, axes=axes)

        def irfftn(arr, out=None, s=irfft_shape, axes=inv_axes):
            return self._array_backend.fft.irfftn(arr, s=s, axes=axes)

        return rfftn, irfftn

    def rfftn(self, arr: BackendArray, *args, **kwargs) -> BackendArray:
        """Forward real FFT; positional extras are ignored, kwargs forwarded."""
        return self._array_backend.fft.rfftn(arr, **kwargs)

    def irfftn(self, arr: BackendArray, *args, **kwargs) -> BackendArray:
        """Inverse real FFT; positional extras are ignored, kwargs forwarded."""
        return self._array_backend.fft.irfftn(arr, **kwargs)

    def rigid_transform(
        self,
        arr: BackendArray,
        rotation_matrix: BackendArray,
        out: BackendArray = None,
        out_mask: BackendArray = None,
        translation: BackendArray = None,
        arr_mask: BackendArray = None,
        order: int = 1,
        **kwargs,
    ) -> Tuple[BackendArray, BackendArray]:
        """
        Rotate (and optionally translate) ``arr`` about its geometric center.

        Output position ``x`` samples ``arr`` at
        ``rotation_matrix.T @ (x - c) + c + translation``, where
        ``c = (arr.shape - 1) / 2``. If ``rotation_matrix`` has one dimension
        fewer than ``arr``, it is embedded with an identity on the leading
        axis. ``arr_mask`` is transformed the same way when given. ``out`` and
        ``out_mask`` are ignored; new arrays are returned.

        Returns
        -------
        tuple
            Transformed ``arr`` and transformed ``arr_mask`` (or ``None``).
        """
        rotate_mask = arr_mask is not None

        # This approach is only valid for order <= 1
        if arr.ndim != rotation_matrix.shape[0]:
            matrix = self._array_backend.zeros((arr.ndim, arr.ndim))
            matrix = matrix.at[0, 0].set(1)
            matrix = matrix.at[1:, 1:].add(rotation_matrix)
            rotation_matrix = matrix

        center = self.divide(self.to_backend_array(arr.shape) - 1, 2)[:, None]
        indices = self._array_backend.indices(arr.shape, dtype=self._float_dtype)
        indices = indices.reshape((arr.ndim, -1))
        indices = indices.at[:].add(-center)
        indices = self._array_backend.matmul(rotation_matrix.T, indices)
        indices = indices.at[:].add(center)
        if translation is not None:
            # ``indices`` is (ndim, N); the translation must be a column vector
            # to be applied per axis rather than broadcast along N.
            translation = self._array_backend.asarray(translation).reshape(-1, 1)
            indices = indices.at[:].add(translation)

        out = self.scipy.ndimage.map_coordinates(arr, indices, order=order).reshape(
            arr.shape
        )

        out_mask = arr_mask
        if rotate_mask:
            out_mask = self.scipy.ndimage.map_coordinates(
                arr_mask, indices, order=order
            ).reshape(arr_mask.shape)

        return out, out_mask

    def max_score_over_rotations(
        self,
        scores: BackendArray,
        max_scores: BackendArray,
        rotations: BackendArray,
        rotation_index: int,
    ) -> Tuple[BackendArray, BackendArray]:
        """
        Merge ``scores`` into the running maxima.

        Wherever ``scores`` is at least ``max_scores`` (ties included), the
        score is taken and ``rotations`` is set to ``rotation_index``. Returns
        the updated ``(max_scores, rotations)`` as new arrays.
        """
        update = self.greater(max_scores, scores)
        max_scores = max_scores.at[:].set(self.where(update, max_scores, scores))
        rotations = rotations.at[:].set(self.where(update, rotations, rotation_index))
        return max_scores, rotations

    def _scan(
        self,
        matching_data: type,
        splits: Tuple[Tuple[slice, slice]],
        n_jobs: int,
        callback_class,
        rotate_mask: bool = False,
        **kwargs,
    ) -> List:
        """
        Emulates output of :py:meth:`tme.matching_exhaustive.scan` using
        :py:class:`tme.analyzer.MaxScoreOverRotations`.

        Splits are processed in batches of ``n_jobs``. The padded targets of a
        batch are stacked and passed to the jax ``scan`` helper in one call.

        Parameters
        ----------
        matching_data
            Matching data object providing targets, template, template mask,
            rotations and optional filters.
        splits : tuple of (target_slice, template_slice)
            Target and template subsets to process.
        n_jobs : int
            Number of splits stacked into a single batch.
        callback_class : type
            Analyzer class instantiated once per split to post-process results.
        rotate_mask : bool, optional
            Forwarded to the jax scan helper.

        Returns
        -------
        list of tuple
            One post-processed analyzer result per split.
        """
        from ._jax_utils import scan as scan_inner

        pad_target = len(splits) > 1
        convolution_mode = "valid" if pad_target else "same"
        target_pad = matching_data.target_padding(pad_target=pad_target)

        # The shape of the first split is used for every split.
        target_shape = tuple(
            (x.stop - x.start + p) for x, p in zip(splits[0][0], target_pad)
        )
        conv_shape, fast_shape, fast_ft_shape, shift = matching_data._fourier_padding(
            target_shape=self.to_numpy_array(target_shape),
            template_shape=self.to_numpy_array(matching_data._template.shape),
            pad_fourier=False,
        )

        analyzer_args = {
            "convolution_mode": convolution_mode,
            "fourier_shift": shift,
            "targetshape": target_shape,
            "templateshape": matching_data.template.shape,
            "convolution_shape": conv_shape,
        }

        create_target_filter = matching_data.target_filter is not None
        create_template_filter = matching_data.template_filter is not None
        create_filter = create_target_filter or create_template_filter

        # Template filters are evaluated on the unpadded template shape.
        fastt_shape = matching_data._template.shape

        ret, template_filter, target_filter = [], 1, 1
        rotation_mapping = {
            self.tobytes(matching_data.rotations[i]): i
            for i in range(matching_data.rotations.shape[0])
        }
        for split_start in range(0, len(splits), n_jobs):
            split_subset = splits[split_start : (split_start + n_jobs)]
            if not split_subset:
                continue

            targets, translation_offsets = [], []
            for target_split, template_split in split_subset:
                base = matching_data.subset_by_slice(
                    target_slice=target_split,
                    target_pad=target_pad,
                    template_slice=template_split,
                )
                translation_offsets.append(base._translation_offset)
                targets.append(self.topleft_pad(base._target, fast_shape))

            if create_filter:
                # Filters are computed once, from the first target of the first
                # batch, and reused for all subsequent batches.
                filter_args = {
                    "data_rfft": self.fft.rfftn(targets[0]),
                    "return_real_fourier": True,
                    "shape_is_real_fourier": False,
                }

                if create_template_filter:
                    template_filter = matching_data.template_filter(
                        shape=fastt_shape, **filter_args
                    )["data"]
                    template_filter = template_filter.at[
                        (0,) * template_filter.ndim
                    ].set(0)

                if create_target_filter:
                    target_filter = matching_data.target_filter(
                        shape=fast_shape, **filter_args
                    )["data"]
                    target_filter = target_filter.at[(0,) * target_filter.ndim].set(0)

                create_filter, create_template_filter, create_target_filter = (
                    False,
                ) * 3

            base, targets = None, self._array_backend.stack(targets)
            scores, rotations = scan_inner(
                self.astype(targets, self._float_dtype),
                matching_data.template,
                matching_data.template_mask,
                matching_data.rotations,
                template_filter,
                target_filter,
                fast_shape,
                rotate_mask,
            )

            for index in range(scores.shape[0]):
                # Each analyzer holds a single split, so its shape is the
                # per-split score shape rather than the stacked batch shape.
                temp = callback_class(
                    shape=scores[index].shape,
                    scores=scores[index],
                    rotations=rotations[index],
                    thread_safe=False,
                    offset=translation_offsets[index],
                )
                temp.rotation_mapping = rotation_mapping
                ret.append(tuple(temp._postprocess(**analyzer_args)))

        return ret

    def get_available_memory(self) -> int:
        """
        Return the memory budget in bytes.

        Returns the summed ``bytes_limit`` of all non-CPU jax devices that
        report one. Otherwise returns the host memory reported by the parent
        backend.
        """
        import jax

        gpu_memory = 0
        for device in jax.devices():
            if device.platform == "cpu":
                continue
            # memory_stats() may return None on backends without memory tracking.
            mem_stats = device.memory_stats() or {}
            gpu_memory += mem_stats.get("bytes_limit", 0)

        if gpu_memory > 0:
            return gpu_memory
        return super().get_available_memory()