pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119) hide show
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
  2. pytme-0.2.9.data/scripts/match_template.py +1135 -0
  3. pytme-0.2.9.data/scripts/postprocess.py +622 -0
  4. pytme-0.2.9.data/scripts/preprocess.py +209 -0
  5. pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
  6. pytme-0.2.9.dist-info/METADATA +95 -0
  7. pytme-0.2.9.dist-info/RECORD +119 -0
  8. pytme-0.2.9.dist-info/WHEEL +5 -0
  9. pytme-0.2.9.dist-info/entry_points.txt +6 -0
  10. pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
  11. pytme-0.2.9.dist-info/top_level.txt +3 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +97 -0
  14. scripts/match_template.py +1135 -0
  15. scripts/postprocess.py +622 -0
  16. scripts/preprocess.py +209 -0
  17. scripts/preprocessor_gui.py +1227 -0
  18. tests/__init__.py +0 -0
  19. tests/data/Blurring/blob_width18.npy +0 -0
  20. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  21. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  22. tests/data/Blurring/hamming_width6.npy +0 -0
  23. tests/data/Blurring/kaiserb_width18.npy +0 -0
  24. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  25. tests/data/Blurring/mean_size5.npy +0 -0
  26. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  27. tests/data/Blurring/rank_rank3.npy +0 -0
  28. tests/data/Maps/.DS_Store +0 -0
  29. tests/data/Maps/emd_8621.mrc.gz +0 -0
  30. tests/data/README.md +2 -0
  31. tests/data/Raw/em_map.map +0 -0
  32. tests/data/Structures/.DS_Store +0 -0
  33. tests/data/Structures/1pdj.cif +3339 -0
  34. tests/data/Structures/1pdj.pdb +1429 -0
  35. tests/data/Structures/5khe.cif +3685 -0
  36. tests/data/Structures/5khe.ent +2210 -0
  37. tests/data/Structures/5khe.pdb +2210 -0
  38. tests/data/Structures/5uz4.cif +70548 -0
  39. tests/preprocessing/__init__.py +0 -0
  40. tests/preprocessing/test_compose.py +76 -0
  41. tests/preprocessing/test_frequency_filters.py +178 -0
  42. tests/preprocessing/test_preprocessor.py +136 -0
  43. tests/preprocessing/test_utils.py +79 -0
  44. tests/test_analyzer.py +216 -0
  45. tests/test_backends.py +446 -0
  46. tests/test_density.py +503 -0
  47. tests/test_extensions.py +130 -0
  48. tests/test_matching_cli.py +283 -0
  49. tests/test_matching_data.py +162 -0
  50. tests/test_matching_exhaustive.py +124 -0
  51. tests/test_matching_memory.py +30 -0
  52. tests/test_matching_optimization.py +226 -0
  53. tests/test_matching_utils.py +189 -0
  54. tests/test_orientations.py +175 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_rotations.py +153 -0
  57. tests/test_structure.py +247 -0
  58. tme/__init__.py +6 -0
  59. tme/__version__.py +1 -0
  60. tme/analyzer/__init__.py +2 -0
  61. tme/analyzer/_utils.py +186 -0
  62. tme/analyzer/aggregation.py +577 -0
  63. tme/analyzer/peaks.py +953 -0
  64. tme/backends/__init__.py +171 -0
  65. tme/backends/_cupy_utils.py +734 -0
  66. tme/backends/_jax_utils.py +188 -0
  67. tme/backends/cupy_backend.py +294 -0
  68. tme/backends/jax_backend.py +314 -0
  69. tme/backends/matching_backend.py +1270 -0
  70. tme/backends/mlx_backend.py +241 -0
  71. tme/backends/npfftw_backend.py +583 -0
  72. tme/backends/pytorch_backend.py +430 -0
  73. tme/data/__init__.py +0 -0
  74. tme/data/c48n309.npy +0 -0
  75. tme/data/c48n527.npy +0 -0
  76. tme/data/c48n9.npy +0 -0
  77. tme/data/c48u1.npy +0 -0
  78. tme/data/c48u1153.npy +0 -0
  79. tme/data/c48u1201.npy +0 -0
  80. tme/data/c48u1641.npy +0 -0
  81. tme/data/c48u181.npy +0 -0
  82. tme/data/c48u2219.npy +0 -0
  83. tme/data/c48u27.npy +0 -0
  84. tme/data/c48u2947.npy +0 -0
  85. tme/data/c48u3733.npy +0 -0
  86. tme/data/c48u4749.npy +0 -0
  87. tme/data/c48u5879.npy +0 -0
  88. tme/data/c48u7111.npy +0 -0
  89. tme/data/c48u815.npy +0 -0
  90. tme/data/c48u83.npy +0 -0
  91. tme/data/c48u8649.npy +0 -0
  92. tme/data/c600v.npy +0 -0
  93. tme/data/c600vc.npy +0 -0
  94. tme/data/metadata.yaml +80 -0
  95. tme/data/quat_to_numpy.py +42 -0
  96. tme/data/scattering_factors.pickle +0 -0
  97. tme/density.py +2263 -0
  98. tme/extensions.cpython-311-darwin.so +0 -0
  99. tme/external/bindings.cpp +332 -0
  100. tme/filters/__init__.py +6 -0
  101. tme/filters/_utils.py +311 -0
  102. tme/filters/bandpass.py +230 -0
  103. tme/filters/compose.py +81 -0
  104. tme/filters/ctf.py +393 -0
  105. tme/filters/reconstruction.py +160 -0
  106. tme/filters/wedge.py +542 -0
  107. tme/filters/whitening.py +191 -0
  108. tme/matching_data.py +863 -0
  109. tme/matching_exhaustive.py +497 -0
  110. tme/matching_optimization.py +1311 -0
  111. tme/matching_scores.py +1183 -0
  112. tme/matching_utils.py +1188 -0
  113. tme/memory.py +337 -0
  114. tme/orientations.py +598 -0
  115. tme/parser.py +685 -0
  116. tme/preprocessor.py +1329 -0
  117. tme/rotations.py +350 -0
  118. tme/structure.py +1864 -0
  119. tme/types.py +13 -0
@@ -0,0 +1,241 @@
1
+ """ Backend using Apple's MLX library for template matching.
2
+
3
+ Copyright (c) 2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+
8
+ from typing import Tuple, List, Callable
9
+
10
+ import numpy as np
11
+
12
+ from .npfftw_backend import NumpyFFTWBackend
13
+ from ..types import NDArray, MlxArray, Scalar, shm_type
14
+
15
+
16
class MLXBackend(NumpyFFTWBackend):
    """
    A mlx-based matching backend.

    Array creation and elementwise arithmetic are delegated to ``mlx.core``.
    Operations that are not implemented on MLX arrays here (``roll``,
    ``topk_indices``, ``rigid_transform``, ``indices``, ``unique``) are
    computed on NumPy copies and the results are converted back to MLX arrays.

    Parameters
    ----------
    device : str, optional
        ``"cpu"`` selects ``mlx.core.cpu``, any other value selects
        ``mlx.core.gpu``. The selection is stored in ``self.device``.
    float_dtype : type, optional
        Floating point dtype, ``mlx.core.float32`` if None.
    complex_dtype : type, optional
        Complex dtype, ``mlx.core.complex64`` if None.
    int_dtype : type, optional
        Integer dtype, ``mlx.core.int32`` if None.
    overflow_safe_dtype : type, optional
        Overflow safe dtype, ``mlx.core.float32`` if None.
    **kwargs : dict, optional
        Accepted and ignored.
    """

    def __init__(
        self,
        device="cpu",
        float_dtype=None,
        complex_dtype=None,
        int_dtype=None,
        overflow_safe_dtype=None,
        **kwargs,
    ):
        # Imported lazily so the module can be imported without mlx installed.
        import mlx.core as mx

        device = mx.cpu if device == "cpu" else mx.gpu
        float_dtype = mx.float32 if float_dtype is None else float_dtype
        complex_dtype = mx.complex64 if complex_dtype is None else complex_dtype
        int_dtype = mx.int32 if int_dtype is None else int_dtype
        if overflow_safe_dtype is None:
            overflow_safe_dtype = mx.float32

        super().__init__(
            array_backend=mx,
            float_dtype=float_dtype,
            complex_dtype=complex_dtype,
            int_dtype=int_dtype,
            overflow_safe_dtype=overflow_safe_dtype,
        )

        self.device = device

    def to_backend_array(self, arr: NDArray) -> MlxArray:
        """Convert ``arr`` to an MLX array via ``mlx.core.array``."""
        return self._array_backend.array(arr)

    def to_numpy_array(self, arr: MlxArray) -> NDArray:
        """Convert ``arr`` to a NumPy array via ``numpy.array``."""
        return np.array(arr)

    def to_cpu_array(self, arr: MlxArray) -> NDArray:
        """Return ``arr`` unchanged."""
        return arr

    def free_cache(self):
        """Do nothing; this backend keeps no cache to release."""
        pass

    def mod(self, arr1: MlxArray, arr2: MlxArray, out: MlxArray = None) -> MlxArray:
        """
        Compute ``arr1 % arr2``.

        Parameters
        ----------
        arr1, arr2 : MlxArray
            Operands of the modulo operation.
        out : MlxArray, optional
            If given, the result is written into it via slice assignment.

        Returns
        -------
        MlxArray
            ``out`` if it was provided, otherwise a new array.
        """
        ret = arr1 % arr2
        if out is None:
            return ret
        out[:] = ret
        # Return the buffer (NumPy ``out=`` semantics) instead of None so
        # callers can write ``x = be.mod(a, b, out=x)``.
        return out

    def add(self, x1, x2, out: MlxArray = None, **kwargs) -> MlxArray:
        """
        Add ``x1`` and ``x2`` after converting both to MLX arrays.

        Parameters
        ----------
        x1, x2 : array-like
            Summands, passed through :py:meth:`to_backend_array`.
        out : MlxArray, optional
            If given, the result is written into it via slice assignment.
        **kwargs : dict, optional
            Forwarded to ``mlx.core.add``.

        Returns
        -------
        MlxArray
            ``out`` if it was provided, otherwise a new array.
        """
        x1 = self.to_backend_array(x1)
        x2 = self.to_backend_array(x2)

        ret = self._array_backend.add(x1, x2, **kwargs)
        if out is None:
            return ret
        out[:] = ret
        return out

    def multiply(self, x1, x2, out: MlxArray = None, **kwargs) -> MlxArray:
        """
        Multiply ``x1`` and ``x2`` after converting both to MLX arrays.

        Parameters
        ----------
        x1, x2 : array-like
            Factors, passed through :py:meth:`to_backend_array`.
        out : MlxArray, optional
            If given, the result is written into it via slice assignment.
        **kwargs : dict, optional
            Forwarded to ``mlx.core.multiply``.

        Returns
        -------
        MlxArray
            ``out`` if it was provided, otherwise a new array.
        """
        x1 = self.to_backend_array(x1)
        x2 = self.to_backend_array(x2)

        ret = self._array_backend.multiply(x1, x2, **kwargs)
        if out is None:
            return ret
        out[:] = ret
        return out

    def std(self, arr: MlxArray, axis) -> Scalar:
        """Return the square root of ``arr.var(axis=axis)``."""
        return self._array_backend.sqrt(arr.var(axis=axis))

    def unique(self, *args, **kwargs):
        """
        Compute unique elements using ``numpy.unique``.

        All arguments are forwarded to ``numpy.unique``.

        Returns
        -------
        MlxArray or list of MlxArray
            A list of MLX arrays when ``numpy.unique`` returns a tuple,
            otherwise a single MLX array.
        """
        ret = np.unique(*args, **kwargs)
        if isinstance(ret, tuple):
            return [self.to_backend_array(x) for x in ret]
        # Convert the single-result case as well so the return type does not
        # depend on which optional outputs were requested.
        return self.to_backend_array(ret)

    def tobytes(self, arr):
        """Return the bytes of ``arr`` after conversion to a NumPy array."""
        return self.to_numpy_array(arr).tobytes()

    def full(self, shape, fill_value, dtype=None):
        """Return an array of ``shape`` filled with ``fill_value``."""
        return self._array_backend.full(shape=shape, dtype=dtype, vals=fill_value)

    def fill(self, arr: MlxArray, value: Scalar) -> MlxArray:
        """Assign ``value`` to every element of ``arr`` and return ``arr``."""
        arr[:] = value
        return arr

    def zeros(self, shape: Tuple[int], dtype: type = None) -> MlxArray:
        """Return an array of ``shape`` created by ``mlx.core.zeros``."""
        return self._array_backend.zeros(shape=shape, dtype=dtype)

    def roll(self, a: MlxArray, shift, axis, **kwargs):
        """
        Roll ``a`` along ``axis`` by ``shift``.

        The computation runs through :py:meth:`NumpyFFTWBackend.roll` on a
        NumPy copy of ``a``; the result is converted back to an MLX array.
        """
        a = self.to_numpy_array(a)
        ret = NumpyFFTWBackend().roll(
            a,
            shift=shift,
            axis=axis,
            **kwargs,
        )
        return self.to_backend_array(ret)

    def extract_center(self, arr: NDArray, newshape: Tuple[int]) -> NDArray:
        """
        Extract the centered portion of an array based on a new shape.

        Parameters
        ----------
        arr : NDArray
            Input array.
        newshape : tuple
            Desired shape for the central portion.

        Returns
        -------
        NDArray
            Central portion of the array with shape `newshape`.

        References
        ----------
        .. [1] https://github.com/scipy/scipy/blob/v1.11.2/scipy/signal/_signaltools.py
        """
        new_shape = self.to_backend_array(newshape)
        current_shape = self.to_backend_array(arr.shape)
        starts = self.subtract(current_shape, new_shape)
        starts = self.astype(self.divide(starts, 2), self._int_dtype)
        stops = self.astype(self.add(starts, newshape), self._int_dtype)
        # Slices need plain Python integers rather than MLX scalars.
        starts, stops = starts.tolist(), stops.tolist()
        box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
        return arr[box]

    def build_fft(
        self,
        fwd_shape: Tuple[int],
        inv_shape: Tuple[int] = None,
        inv_output_shape: Tuple[int] = None,
        fwd_axes: Tuple[int] = None,
        inv_axes: Tuple[int] = None,
        **kwargs,
    ) -> Tuple[Callable, Callable]:
        """
        Build forward and inverse real FFT callables.

        Parameters
        ----------
        fwd_shape : tuple of int
            Shape of the forward transform input.
        inv_shape : tuple of int, optional
            Not used by this implementation.
        inv_output_shape : tuple of int, optional
            Output shape of the inverse transform, ``fwd_shape`` if None.
        fwd_axes : tuple of int, optional
            Axes of the forward transform.
        inv_axes : tuple of int, optional
            Axes of the inverse transform.
        **kwargs : dict, optional
            Accepted and ignored.

        Returns
        -------
        tuple of callable
            ``(rfftn, irfftn)``. Each takes ``(arr, out=None)``, writes the
            transform into ``out`` when given and returns the result.
        """
        # Runs on mlx.core.cpu until Metal support is available
        rfft_shape = self._format_fft_shape(fwd_shape, fwd_axes)
        irfft_shape = fwd_shape if inv_output_shape is None else inv_output_shape
        irfft_shape = self._format_fft_shape(irfft_shape, inv_axes)

        def rfftn(arr: MlxArray, out: MlxArray = None, s=rfft_shape, axes=fwd_axes):
            ret = self._array_backend.fft.rfftn(
                arr, s=s, axes=axes, stream=self._array_backend.cpu
            )
            # ``out`` defaults to None, so it must be optional in practice.
            if out is None:
                return ret
            out[:] = ret
            return out

        def irfftn(arr: MlxArray, out: MlxArray = None, s=irfft_shape, axes=inv_axes):
            ret = self._array_backend.fft.irfftn(
                arr, s=s, axes=axes, stream=self._array_backend.cpu
            )
            if out is None:
                return ret
            out[:] = ret
            return out

        return rfftn, irfftn

    def rfftn(self, arr, *args, **kwargs):
        """
        Compute the real n-dimensional FFT of ``arr`` on the MLX CPU stream.

        Additional positional and keyword arguments are forwarded to
        ``mlx.core.fft.rfftn``.
        """
        return self._array_backend.fft.rfftn(
            arr, *args, stream=self._array_backend.cpu, **kwargs
        )

    def irfftn(self, arr, *args, **kwargs):
        """
        Compute the inverse real n-dimensional FFT of ``arr`` on the MLX CPU
        stream.

        Additional positional and keyword arguments are forwarded to
        ``mlx.core.fft.irfftn``.
        """
        return self._array_backend.fft.irfftn(
            arr, *args, stream=self._array_backend.cpu, **kwargs
        )

    def from_sharedarr(self, arr: MlxArray) -> MlxArray:
        """Return ``arr`` unchanged."""
        return arr

    @staticmethod
    def to_sharedarr(arr: MlxArray, shared_memory_handler: type = None) -> shm_type:
        """Return ``arr`` unchanged; ``shared_memory_handler`` is not used."""
        return arr

    def topk_indices(self, arr: NDArray, k: int):
        """
        Return the result of :py:meth:`NumpyFFTWBackend.topk_indices` for
        ``arr`` and ``k`` as a list of MLX arrays.
        """
        arr = self.to_numpy_array(arr)
        ret = NumpyFFTWBackend().topk_indices(arr=arr, k=k)
        return [self.to_backend_array(x) for x in ret]

    def rigid_transform(
        self,
        arr: NDArray,
        rotation_matrix: NDArray,
        arr_mask: NDArray = None,
        translation: NDArray = None,
        use_geometric_center: bool = False,
        out: NDArray = None,
        out_mask: NDArray = None,
        order: int = 3,
        **kwargs,
    ) -> Tuple[MlxArray, MlxArray]:
        """
        Apply a rigid transformation using the NumPy backend.

        Inputs are converted to NumPy arrays, transformed with
        :py:meth:`NumpyFFTWBackend.rigid_transform` and the results are
        converted back to MLX arrays.

        Parameters
        ----------
        arr : NDArray
            Array to transform.
        rotation_matrix : NDArray
            Rotation matrix passed to the NumPy backend.
        arr_mask : NDArray, optional
            Mask transformed alongside ``arr``.
        translation : NDArray, optional
            Translation passed to the NumPy backend.
        use_geometric_center : bool, optional
            Passed to the NumPy backend.
        out : NDArray, optional
            Buffer receiving the transformed ``arr``. A zero-initialized array
            of ``arr.shape`` is created if None.
        out_mask : NDArray, optional
            Buffer receiving the transformed mask.
        order : int, optional
            Passed to the NumPy backend, 3 by default.
        **kwargs : dict, optional
            Accepted and ignored.

        Returns
        -------
        tuple
            ``(out, out_mask)``. ``out_mask`` is None when neither a mask
            result nor an ``out_mask`` buffer is available.
        """
        arr = self.to_numpy_array(arr)
        rotation_matrix = self.to_numpy_array(rotation_matrix)

        if arr_mask is not None:
            arr_mask = self.to_numpy_array(arr_mask)

        if translation is not None:
            translation = self.to_numpy_array(translation)

        out_pass, out_mask_pass = NumpyFFTWBackend().rigid_transform(
            arr=arr,
            rotation_matrix=rotation_matrix,
            arr_mask=arr_mask,
            translation=translation,
            use_geometric_center=use_geometric_center,
            order=order,
        )

        if out is None:
            out = self.zeros(arr.shape)
        out[:] = self.to_backend_array(out_pass)

        # Only touch the mask buffer when the NumPy backend produced a mask;
        # assigning None into a caller-provided buffer would fail.
        if out_mask_pass is not None:
            out_mask_pass = self.to_backend_array(out_mask_pass)
            if out_mask is None:
                out_mask = out_mask_pass
            else:
                out_mask[:] = out_mask_pass

        return out, out_mask

    def indices(self, arr: List) -> MlxArray:
        """
        Return the result of :py:meth:`NumpyFFTWBackend.indices` for ``arr``
        as an MLX array.
        """
        ret = NumpyFFTWBackend().indices(arr)
        return self.to_backend_array(ret)