pytme 0.2.2__cp311-cp311-macosx_14_0_arm64.whl → 0.2.4__cp311-cp311-macosx_14_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/match_template.py +97 -148
- {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/postprocess.py +20 -29
- pytme-0.2.4.data/scripts/preprocess.py +148 -0
- {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/preprocessor_gui.py +15 -23
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/METADATA +11 -10
- pytme-0.2.4.dist-info/RECORD +119 -0
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/WHEEL +1 -1
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/top_level.txt +1 -0
- pytme-0.2.2.data/scripts/preprocess.py → scripts/eval.py +1 -1
- scripts/match_template.py +97 -148
- scripts/postprocess.py +20 -29
- scripts/preprocess.py +116 -61
- scripts/preprocessor_gui.py +15 -23
- tests/__init__.py +0 -0
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Blurring/blob_width18.npy +0 -0
- tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
- tests/data/Blurring/gaussian_sigma2.npy +0 -0
- tests/data/Blurring/hamming_width6.npy +0 -0
- tests/data/Blurring/kaiserb_width18.npy +0 -0
- tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
- tests/data/Blurring/mean_size5.npy +0 -0
- tests/data/Blurring/ntree_sigma0510.npy +0 -0
- tests/data/Blurring/rank_rank3.npy +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Maps/emd_8621.mrc.gz +0 -0
- tests/data/README.md +2 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Raw/em_map.map +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/data/Structures/1pdj.cif +3339 -0
- tests/data/Structures/1pdj.pdb +1429 -0
- tests/data/Structures/5khe.cif +3685 -0
- tests/data/Structures/5khe.ent +2210 -0
- tests/data/Structures/5khe.pdb +2210 -0
- tests/data/Structures/5uz4.cif +70548 -0
- tests/preprocessing/__init__.py +0 -0
- tests/preprocessing/test_compose.py +76 -0
- tests/preprocessing/test_frequency_filters.py +178 -0
- tests/preprocessing/test_preprocessor.py +136 -0
- tests/preprocessing/test_utils.py +79 -0
- tests/test_analyzer.py +310 -0
- tests/test_backends.py +375 -0
- tests/test_density.py +508 -0
- tests/test_extensions.py +130 -0
- tests/test_matching_cli.py +283 -0
- tests/test_matching_data.py +162 -0
- tests/test_matching_exhaustive.py +162 -0
- tests/test_matching_memory.py +30 -0
- tests/test_matching_optimization.py +276 -0
- tests/test_matching_utils.py +326 -0
- tests/test_orientations.py +173 -0
- tests/test_packaging.py +95 -0
- tests/test_parser.py +33 -0
- tests/test_structure.py +243 -0
- tme/__init__.py +0 -1
- tme/__version__.py +1 -1
- tme/analyzer.py +9 -6
- tme/backends/__init__.py +1 -1
- tme/backends/_jax_utils.py +10 -8
- tme/backends/cupy_backend.py +2 -7
- tme/backends/jax_backend.py +35 -20
- tme/backends/npfftw_backend.py +3 -2
- tme/backends/pytorch_backend.py +10 -7
- tme/data/scattering_factors.pickle +0 -0
- tme/density.py +26 -12
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/external/bindings.cpp +332 -0
- tme/matching_data.py +33 -24
- tme/matching_exhaustive.py +39 -20
- tme/matching_scores.py +5 -2
- tme/matching_utils.py +8 -2
- tme/orientations.py +26 -9
- tme/preprocessing/_utils.py +14 -14
- tme/preprocessing/composable_filter.py +5 -4
- tme/preprocessing/compose.py +4 -4
- tme/preprocessing/frequency_filters.py +32 -35
- tme/preprocessing/tilt_series.py +210 -148
- tme/preprocessor.py +24 -246
- tme/structure.py +14 -14
- pytme-0.2.2.dist-info/RECORD +0 -74
- tme/matching_memory.py +0 -383
- {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/estimate_ram_usage.py +0 -0
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/LICENSE +0 -0
- {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/entry_points.txt +0 -0
tme/backends/cupy_backend.py
CHANGED
@@ -149,10 +149,10 @@ class CupyBackend(NumpyFFTWBackend):
|
|
149
149
|
cache.clear()
|
150
150
|
|
151
151
|
def rfftn(arr: CupyArray, out: CupyArray) -> CupyArray:
|
152
|
-
return cufft.rfftn(arr)
|
152
|
+
return cufft.rfftn(arr, s=fast_shape)
|
153
153
|
|
154
154
|
def irfftn(arr: CupyArray, out: CupyArray) -> CupyArray:
|
155
|
-
return cufft.irfftn(arr)
|
155
|
+
return cufft.irfftn(arr, s=fast_shape)
|
156
156
|
|
157
157
|
PLAN_CACHE[current_device] = [fast_shape, fast_ft_shape]
|
158
158
|
|
@@ -167,11 +167,6 @@ class CupyBackend(NumpyFFTWBackend):
|
|
167
167
|
fast_shape = [next_fast_len(x, real=True) for x in convolution_shape]
|
168
168
|
fast_ft_shape = list(fast_shape[:-1]) + [fast_shape[-1] // 2 + 1]
|
169
169
|
|
170
|
-
# This almost never happens but avoid cuFFT casting errors
|
171
|
-
is_odd = fast_shape[-1] % 2
|
172
|
-
fast_shape[-1] += is_odd
|
173
|
-
fast_ft_shape[-1] += is_odd
|
174
|
-
|
175
170
|
return convolution_shape, fast_shape, fast_ft_shape
|
176
171
|
|
177
172
|
def max_filter_coordinates(self, score_space, min_distance: Tuple[int]):
|
tme/backends/jax_backend.py
CHANGED
@@ -119,19 +119,6 @@ class JaxBackend(NumpyFFTWBackend):
|
|
119
119
|
|
120
120
|
return rfftn, irfftn
|
121
121
|
|
122
|
-
def compute_convolution_shapes(
|
123
|
-
self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
|
124
|
-
) -> Tuple[List[int], List[int], List[int]]:
|
125
|
-
conv_shape, fast_shape, fast_ft_shape = super().compute_convolution_shapes(
|
126
|
-
arr1_shape, arr2_shape
|
127
|
-
)
|
128
|
-
|
129
|
-
is_odd = fast_shape[-1] % 2
|
130
|
-
fast_shape[-1] += is_odd
|
131
|
-
fast_ft_shape[-1] += is_odd
|
132
|
-
|
133
|
-
return conv_shape, fast_shape, fast_ft_shape
|
134
|
-
|
135
122
|
def rigid_transform(
|
136
123
|
self,
|
137
124
|
arr: BackendArray,
|
@@ -144,8 +131,8 @@ class JaxBackend(NumpyFFTWBackend):
|
|
144
131
|
**kwargs,
|
145
132
|
) -> Tuple[BackendArray, BackendArray]:
|
146
133
|
rotate_mask = arr_mask is not None
|
147
|
-
center = self.divide(self.to_backend_array(arr.shape), 2)[:, None]
|
148
134
|
|
135
|
+
center = self.divide(self.to_backend_array(arr.shape) - 1, 2)[:, None]
|
149
136
|
indices = self._array_backend.indices(arr.shape, dtype=self._float_dtype)
|
150
137
|
indices = indices.reshape((arr.ndim, -1))
|
151
138
|
indices = indices.at[:].add(-center)
|
@@ -200,7 +187,7 @@ class JaxBackend(NumpyFFTWBackend):
|
|
200
187
|
target_shape = tuple(
|
201
188
|
(x.stop - x.start + p) for x, p in zip(splits[0][0], target_pad)
|
202
189
|
)
|
203
|
-
fast_shape, fast_ft_shape, shift = matching_data._fourier_padding(
|
190
|
+
conv_shape, fast_shape, fast_ft_shape, shift = matching_data._fourier_padding(
|
204
191
|
target_shape=self.to_numpy_array(target_shape),
|
205
192
|
template_shape=self.to_numpy_array(matching_data._template.shape),
|
206
193
|
pad_fourier=False,
|
@@ -210,13 +197,26 @@ class JaxBackend(NumpyFFTWBackend):
|
|
210
197
|
"convolution_mode": convolution_mode,
|
211
198
|
"fourier_shift": shift,
|
212
199
|
"targetshape": target_shape,
|
213
|
-
"templateshape": matching_data.
|
200
|
+
"templateshape": matching_data.template.shape,
|
201
|
+
"convolution_shape": conv_shape,
|
214
202
|
}
|
215
203
|
|
216
204
|
create_target_filter = matching_data.target_filter is not None
|
217
205
|
create_template_filter = matching_data.template_filter is not None
|
218
206
|
create_filter = create_target_filter or create_template_filter
|
219
207
|
|
208
|
+
# Applying the filter leads to more FFTs
|
209
|
+
fastt_shape = matching_data._template.shape
|
210
|
+
if create_template_filter:
|
211
|
+
# _, fastt_shape, _, tshift = matching_data._fourier_padding(
|
212
|
+
# target_shape=self.to_numpy_array(matching_data._template.shape),
|
213
|
+
# template_shape=self.to_numpy_array(
|
214
|
+
# [1 for _ in matching_data._template.shape]
|
215
|
+
# ),
|
216
|
+
# pad_fourier=False,
|
217
|
+
# )
|
218
|
+
fastt_shape = matching_data._template.shape
|
219
|
+
|
220
220
|
ret, template_filter, target_filter = [], 1, 1
|
221
221
|
rotation_mapping = {
|
222
222
|
self.tobytes(matching_data.rotations[i]): i
|
@@ -246,12 +246,12 @@ class JaxBackend(NumpyFFTWBackend):
|
|
246
246
|
|
247
247
|
if create_template_filter:
|
248
248
|
template_filter = matching_data.template_filter(
|
249
|
-
shape=
|
249
|
+
shape=fastt_shape, **filter_args
|
250
250
|
)["data"]
|
251
251
|
template_filter = template_filter.at[(0,) * template_filter.ndim].set(0)
|
252
252
|
|
253
253
|
if create_target_filter:
|
254
|
-
target_filter = matching_data.
|
254
|
+
target_filter = matching_data.target_filter(
|
255
255
|
shape=fast_shape, **filter_args
|
256
256
|
)["data"]
|
257
257
|
target_filter = target_filter.at[(0,) * target_filter.ndim].set(0)
|
@@ -260,8 +260,8 @@ class JaxBackend(NumpyFFTWBackend):
|
|
260
260
|
base, targets = None, self._array_backend.stack(targets)
|
261
261
|
scores, rotations = scan_inner(
|
262
262
|
targets,
|
263
|
-
matching_data.template,
|
264
|
-
matching_data.template_mask,
|
263
|
+
self.topleft_pad(matching_data.template, fastt_shape),
|
264
|
+
self.topleft_pad(matching_data.template_mask, fastt_shape),
|
265
265
|
matching_data.rotations,
|
266
266
|
template_filter,
|
267
267
|
target_filter,
|
@@ -280,3 +280,18 @@ class JaxBackend(NumpyFFTWBackend):
|
|
280
280
|
ret.append(tuple(temp._postprocess(**analyzer_args)))
|
281
281
|
|
282
282
|
return ret
|
283
|
+
|
284
|
+
def get_available_memory(self) -> int:
|
285
|
+
import jax
|
286
|
+
|
287
|
+
_memory = {"cpu": 0, "gpu": 0}
|
288
|
+
for device in jax.devices():
|
289
|
+
if device.platform == "cpu":
|
290
|
+
_memory["cpu"] = super().get_available_memory()
|
291
|
+
else:
|
292
|
+
mem_stats = device.memory_stats()
|
293
|
+
_memory["gpu"] += mem_stats.get("bytes_limit", 0)
|
294
|
+
|
295
|
+
if _memory["gpu"] > 0:
|
296
|
+
return _memory["gpu"]
|
297
|
+
return _memory["cpu"]
|
tme/backends/npfftw_backend.py
CHANGED
@@ -186,7 +186,7 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
186
186
|
def to_sharedarr(
|
187
187
|
self, arr: NDArray, shared_memory_handler: type = None
|
188
188
|
) -> shm_type:
|
189
|
-
if
|
189
|
+
if isinstance(shared_memory_handler, SharedMemoryManager):
|
190
190
|
shm = shared_memory_handler.SharedMemory(size=arr.nbytes)
|
191
191
|
else:
|
192
192
|
shm = shared_memory.SharedMemory(create=True, size=arr.nbytes)
|
@@ -347,7 +347,8 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
|
|
347
347
|
cache: bool = False,
|
348
348
|
) -> Tuple[NDArray, NDArray]:
|
349
349
|
translation = self.zeros(arr.ndim) if translation is None else translation
|
350
|
-
|
350
|
+
|
351
|
+
center = self.divide(self.to_backend_array(arr.shape) - 1, 2)
|
351
352
|
if not use_geometric_center:
|
352
353
|
center = self.center_of_mass(arr, cutoff=0)
|
353
354
|
|
tme/backends/pytorch_backend.py
CHANGED
@@ -81,13 +81,13 @@ class PytorchBackend(NumpyFFTWBackend):
|
|
81
81
|
|
82
82
|
def max(self, *args, **kwargs) -> NDArray:
|
83
83
|
ret = self._array_backend.amax(*args, **kwargs)
|
84
|
-
if
|
84
|
+
if isinstance(ret, self._array_backend.Tensor):
|
85
85
|
return ret
|
86
86
|
return ret[0]
|
87
87
|
|
88
88
|
def min(self, *args, **kwargs) -> NDArray:
|
89
89
|
ret = self._array_backend.amin(*args, **kwargs)
|
90
|
-
if
|
90
|
+
if isinstance(ret, self._array_backend.Tensor):
|
91
91
|
return ret
|
92
92
|
return ret[0]
|
93
93
|
|
@@ -154,7 +154,7 @@ class PytorchBackend(NumpyFFTWBackend):
|
|
154
154
|
1, -1
|
155
155
|
)
|
156
156
|
if unraveled_coords.size(0) == 1:
|
157
|
-
return
|
157
|
+
return (unraveled_coords[0, :],)
|
158
158
|
|
159
159
|
else:
|
160
160
|
return tuple(unraveled_coords.T)
|
@@ -206,7 +206,9 @@ class PytorchBackend(NumpyFFTWBackend):
|
|
206
206
|
else:
|
207
207
|
raise NotImplementedError("Operation only implemented for 2 and 3D inputs.")
|
208
208
|
|
209
|
-
pool = func(
|
209
|
+
pool = func(
|
210
|
+
kernel_size=min_distance, padding=min_distance // 2, return_indices=True
|
211
|
+
)
|
210
212
|
_, indices = pool(score_space.reshape(1, 1, *score_space.shape))
|
211
213
|
coordinates = self.unravel_index(indices.reshape(-1), score_space.shape)
|
212
214
|
coordinates = self.transpose(self.stack(coordinates))
|
@@ -217,7 +219,7 @@ class PytorchBackend(NumpyFFTWBackend):
|
|
217
219
|
|
218
220
|
def from_sharedarr(self, args) -> TorchTensor:
|
219
221
|
if self.device == "cuda":
|
220
|
-
return args
|
222
|
+
return args
|
221
223
|
|
222
224
|
shm, shape, dtype = args
|
223
225
|
required_size = int(self._array_backend.prod(self.to_backend_array(shape)))
|
@@ -235,13 +237,12 @@ class PytorchBackend(NumpyFFTWBackend):
|
|
235
237
|
|
236
238
|
nbytes = arr.numel() * arr.element_size()
|
237
239
|
|
238
|
-
if
|
240
|
+
if isinstance(shared_memory_handler, SharedMemoryManager):
|
239
241
|
shm = shared_memory_handler.SharedMemory(size=nbytes)
|
240
242
|
else:
|
241
243
|
shm = shared_memory.SharedMemory(create=True, size=nbytes)
|
242
244
|
|
243
245
|
shm.buf[:nbytes] = arr.numpy().tobytes()
|
244
|
-
|
245
246
|
return shm, arr.shape, arr.dtype
|
246
247
|
|
247
248
|
def transpose(self, arr):
|
@@ -415,6 +416,8 @@ class PytorchBackend(NumpyFFTWBackend):
|
|
415
416
|
yield None
|
416
417
|
|
417
418
|
def device_count(self) -> int:
|
419
|
+
if self.device == "cpu":
|
420
|
+
return 1
|
418
421
|
return self._array_backend.cuda.device_count()
|
419
422
|
|
420
423
|
def reverse(self, arr: TorchTensor) -> TorchTensor:
|
Binary file
|
tme/density.py
CHANGED
@@ -116,8 +116,8 @@ class Density:
|
|
116
116
|
response = "Density object at {}\nOrigin: {}, Sampling Rate: {}, Shape: {}"
|
117
117
|
return response.format(
|
118
118
|
hex(id(self)),
|
119
|
-
tuple(
|
120
|
-
tuple(
|
119
|
+
tuple(round(float(x), 3) for x in self.origin),
|
120
|
+
tuple(round(float(x), 3) for x in self.sampling_rate),
|
121
121
|
self.shape,
|
122
122
|
)
|
123
123
|
|
@@ -306,6 +306,10 @@ class Density:
|
|
306
306
|
"std": float(mrc.header.rms),
|
307
307
|
}
|
308
308
|
|
309
|
+
non_standard_crs = not np.all(crs_index == (0, 1, 2))
|
310
|
+
if non_standard_crs:
|
311
|
+
warnings.warn("Non standard MAPC, MAPR, MAPS, adapting data and origin.")
|
312
|
+
|
309
313
|
if is_gzipped(filename):
|
310
314
|
if use_memmap:
|
311
315
|
warnings.warn(
|
@@ -315,6 +319,10 @@ class Density:
|
|
315
319
|
use_memmap = False
|
316
320
|
|
317
321
|
if subset is not None:
|
322
|
+
subset = tuple(
|
323
|
+
subset[i] if i < len(subset) else slice(0, data_shape[i])
|
324
|
+
for i in crs_index
|
325
|
+
)
|
318
326
|
subset_shape = tuple(x.stop - x.start for x in subset)
|
319
327
|
if np.allclose(subset_shape, data_shape):
|
320
328
|
return cls._load_mrc(
|
@@ -328,18 +336,16 @@ class Density:
|
|
328
336
|
dtype=data_type,
|
329
337
|
header_size=1024 + extended_header,
|
330
338
|
)
|
331
|
-
|
332
|
-
|
333
|
-
if not use_memmap:
|
339
|
+
elif subset is None and not use_memmap:
|
334
340
|
with mrcfile.open(filename, header_only=False) as mrc:
|
335
341
|
data = mrc.data.astype(np.float32, copy=False)
|
336
342
|
else:
|
337
343
|
with mrcfile.mrcmemmap.MrcMemmap(filename, header_only=False) as mrc:
|
338
344
|
data = mrc.data
|
339
345
|
|
340
|
-
if
|
346
|
+
if non_standard_crs:
|
341
347
|
data = np.transpose(data, crs_index)
|
342
|
-
|
348
|
+
origin = np.take(origin, crs_index)
|
343
349
|
|
344
350
|
return data, origin, sampling_rate, metadata
|
345
351
|
|
@@ -738,9 +744,16 @@ class Density:
|
|
738
744
|
>>> )
|
739
745
|
|
740
746
|
:py:meth:`Density.from_structure` supports a variety of methods to convert
|
741
|
-
atoms into densities
|
742
|
-
|
743
|
-
|
747
|
+
atoms into densities
|
748
|
+
|
749
|
+
>>> density = Density.from_structure(
|
750
|
+
>>> filename_or_structure = path_to_structure,
|
751
|
+
>>> weight_type = "gaussian",
|
752
|
+
>>> weight_type_args={"resolution": "20"}
|
753
|
+
>>> )
|
754
|
+
|
755
|
+
In addition its possible to use experimentally determined scattering factors
|
756
|
+
from various sources:
|
744
757
|
|
745
758
|
>>> density = Density.from_structure(
|
746
759
|
>>> filename_or_structure = path_to_structure,
|
@@ -748,7 +761,7 @@ class Density:
|
|
748
761
|
>>> weight_type_args={"source": "dt1969"}
|
749
762
|
>>> )
|
750
763
|
|
751
|
-
|
764
|
+
and their lowpass filtered representation introduced in [1]_:
|
752
765
|
|
753
766
|
>>> density = Density.from_structure(
|
754
767
|
>>> filename_or_structure = path_to_structure,
|
@@ -873,6 +886,7 @@ class Density:
|
|
873
886
|
mrc.header.nzstart, mrc.header.nystart, mrc.header.nxstart = np.rint(
|
874
887
|
np.divide(self.origin, self.sampling_rate)
|
875
888
|
)
|
889
|
+
mrc.header.origin = tuple(x for x in self.origin)
|
876
890
|
# mrcfile library expects origin to be in xyz format
|
877
891
|
mrc.header.mapc, mrc.header.mapr, mrc.header.maps = (1, 2, 3)
|
878
892
|
mrc.header["origin"] = tuple(self.origin[::-1])
|
@@ -1594,7 +1608,7 @@ class Density:
|
|
1594
1608
|
rotation_matrix: NDArray,
|
1595
1609
|
translation: NDArray = None,
|
1596
1610
|
order: int = 3,
|
1597
|
-
use_geometric_center: bool =
|
1611
|
+
use_geometric_center: bool = True,
|
1598
1612
|
) -> "Density":
|
1599
1613
|
"""
|
1600
1614
|
Performs a rigid transform of the class instance.
|
Binary file
|
@@ -0,0 +1,332 @@
|
|
1
|
+
/* Pybind extensions for template matching score space analyzers.
|
2
|
+
|
3
|
+
Copyright (c) 2023 European Molecular Biology Laboratory
|
4
|
+
|
5
|
+
Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
|
6
|
+
*/
|
7
|
+
|
8
|
+
#include <vector>
|
9
|
+
#include <iostream>
|
10
|
+
#include <limits>
|
11
|
+
|
12
|
+
#include <pybind11/stl.h>
|
13
|
+
#include <pybind11/numpy.h>
|
14
|
+
#include <pybind11/pybind11.h>
|
15
|
+
|
16
|
+
namespace py = pybind11;
|
17
|
+
|
18
|
+
template <typename T>
|
19
|
+
void absolute_minimum_deviation(
|
20
|
+
py::array_t<T, py::array::c_style> coordinates,
|
21
|
+
py::array_t<T, py::array::c_style> output) {
|
22
|
+
auto coordinates_data = coordinates.data();
|
23
|
+
auto output_data = output.mutable_data();
|
24
|
+
int n = coordinates.shape(0);
|
25
|
+
int k = coordinates.shape(1);
|
26
|
+
int ik, jk, in, jn;
|
27
|
+
|
28
|
+
for (int i = 0; i < n; ++i) {
|
29
|
+
ik = i * k;
|
30
|
+
in = i * n;
|
31
|
+
for (int j = i + 1; j < n; ++j) {
|
32
|
+
jk = j * k;
|
33
|
+
jn = j * n;
|
34
|
+
T min_distance = std::abs(coordinates_data[ik] - coordinates_data[jk]);
|
35
|
+
for (int p = 1; p < k; ++p) {
|
36
|
+
min_distance = std::min(min_distance,
|
37
|
+
std::abs(coordinates_data[ik + p] - coordinates_data[jk + p]));
|
38
|
+
}
|
39
|
+
output_data[in + j] = min_distance;
|
40
|
+
output_data[jn + i] = min_distance;
|
41
|
+
}
|
42
|
+
output_data[in + i] = 0;
|
43
|
+
}
|
44
|
+
}
|
45
|
+
|
46
|
+
template <typename T>
|
47
|
+
std::pair<double, std::pair<int, int>> max_euclidean_distance(
|
48
|
+
py::array_t<T, py::array::c_style> coordinates) {
|
49
|
+
auto coordinates_data = coordinates.data();
|
50
|
+
int n = coordinates.shape(0);
|
51
|
+
int k = coordinates.shape(1);
|
52
|
+
|
53
|
+
double distance = 0.0;
|
54
|
+
double difference = 0.0;
|
55
|
+
double max_distance = -1;
|
56
|
+
double squared_distances = 0.0;
|
57
|
+
|
58
|
+
int ik, jk;
|
59
|
+
int max_i = -1, max_j = -1;
|
60
|
+
|
61
|
+
for (int i = 0; i < n; ++i) {
|
62
|
+
ik = i * k;
|
63
|
+
for (int j = i + 1; j < n; ++j) {
|
64
|
+
jk = j * k;
|
65
|
+
squared_distances = 0.0;
|
66
|
+
for (int p = 0; p < k; ++p) {
|
67
|
+
difference = static_cast<double>(
|
68
|
+
coordinates_data[ik + p] - coordinates_data[jk + p]
|
69
|
+
);
|
70
|
+
squared_distances += (difference * difference);
|
71
|
+
}
|
72
|
+
distance = std::sqrt(squared_distances);
|
73
|
+
if (distance > max_distance) {
|
74
|
+
max_distance = distance;
|
75
|
+
max_i = i;
|
76
|
+
max_j = j;
|
77
|
+
}
|
78
|
+
}
|
79
|
+
}
|
80
|
+
|
81
|
+
return std::make_pair(max_distance, std::make_pair(max_i, max_j));
|
82
|
+
}
|
83
|
+
|
84
|
+
|
85
|
+
template <typename T>
|
86
|
+
inline py::array_t<int, py::array::c_style> find_candidate_indices(
|
87
|
+
py::array_t<T, py::array::c_style> coordinates,
|
88
|
+
T min_distance) {
|
89
|
+
auto coordinates_data = coordinates.data();
|
90
|
+
int n = coordinates.shape(0);
|
91
|
+
int k = coordinates.shape(1);
|
92
|
+
int ik, jk;
|
93
|
+
|
94
|
+
std::vector<int> candidate_indices;
|
95
|
+
candidate_indices.reserve(n / 2);
|
96
|
+
candidate_indices.push_back(0);
|
97
|
+
|
98
|
+
for (int i = 1; i < n; ++i) {
|
99
|
+
bool is_candidate = true;
|
100
|
+
ik = i * k;
|
101
|
+
for (int candidate_index : candidate_indices) {
|
102
|
+
jk = candidate_index * k;
|
103
|
+
T distance = std::pow(coordinates_data[ik] - coordinates_data[jk], 2);
|
104
|
+
for (int p = 1; p < k; ++p) {
|
105
|
+
distance += std::pow(coordinates_data[ik + p] - coordinates_data[jk + p], 2);
|
106
|
+
}
|
107
|
+
distance = std::sqrt(distance);
|
108
|
+
if (distance <= min_distance) {
|
109
|
+
is_candidate = false;
|
110
|
+
break;
|
111
|
+
}
|
112
|
+
}
|
113
|
+
if (is_candidate) {
|
114
|
+
candidate_indices.push_back(i);
|
115
|
+
}
|
116
|
+
}
|
117
|
+
|
118
|
+
py::array_t<int, py::array::c_style> output({(int)candidate_indices.size()});
|
119
|
+
auto output_data = output.mutable_data();
|
120
|
+
|
121
|
+
for (size_t i = 0; i < candidate_indices.size(); ++i) {
|
122
|
+
output_data[i] = candidate_indices[i];
|
123
|
+
}
|
124
|
+
|
125
|
+
return output;
|
126
|
+
}
|
127
|
+
|
128
|
+
template <typename T>
|
129
|
+
py::array_t<T, py::array::c_style> find_candidate_coordinates(
|
130
|
+
py::array_t<T, py::array::c_style> coordinates,
|
131
|
+
T min_distance) {
|
132
|
+
|
133
|
+
py::array_t<int, py::array::c_style> candidate_indices_array = find_candidate_indices(
|
134
|
+
coordinates, min_distance);
|
135
|
+
auto candidate_indices_data = candidate_indices_array.data();
|
136
|
+
int num_candidates = candidate_indices_array.shape(0);
|
137
|
+
int k = coordinates.shape(1);
|
138
|
+
auto coordinates_data = coordinates.data();
|
139
|
+
|
140
|
+
py::array_t<T, py::array::c_style> output({num_candidates, k});
|
141
|
+
auto output_data = output.mutable_data();
|
142
|
+
|
143
|
+
for (int i = 0; i < num_candidates; ++i) {
|
144
|
+
int candidate_index = candidate_indices_data[i] * k;
|
145
|
+
std::copy(
|
146
|
+
coordinates_data + candidate_index,
|
147
|
+
coordinates_data + candidate_index + k,
|
148
|
+
output_data + i * k
|
149
|
+
);
|
150
|
+
}
|
151
|
+
|
152
|
+
return output;
|
153
|
+
}
|
154
|
+
|
155
|
+
template <typename U, typename T>
|
156
|
+
py::dict max_index_by_label(
|
157
|
+
py::array_t<U, py::array::c_style> labels,
|
158
|
+
py::array_t<T, py::array::c_style> scores
|
159
|
+
) {
|
160
|
+
|
161
|
+
const U* labels_ptr = labels.data();
|
162
|
+
const T* scores_ptr = scores.data();
|
163
|
+
|
164
|
+
std::unordered_map<U, std::pair<T, ssize_t>> max_scores;
|
165
|
+
|
166
|
+
U label;
|
167
|
+
T score;
|
168
|
+
for (ssize_t i = 0; i < labels.size(); ++i) {
|
169
|
+
label = labels_ptr[i];
|
170
|
+
score = scores_ptr[i];
|
171
|
+
|
172
|
+
auto it = max_scores.insert({label, {score, i}});
|
173
|
+
|
174
|
+
if (score > it.first->second.first) {
|
175
|
+
it.first->second = {score, i};
|
176
|
+
}
|
177
|
+
}
|
178
|
+
|
179
|
+
py::dict ret;
|
180
|
+
for (auto& item: max_scores) {
|
181
|
+
ret[py::cast(item.first)] = py::cast(item.second.second);
|
182
|
+
}
|
183
|
+
|
184
|
+
return ret;
|
185
|
+
}
|
186
|
+
|
187
|
+
|
188
|
+
template <typename T>
|
189
|
+
py::tuple online_statistics(
|
190
|
+
py::array_t<T, py::array::c_style> arr,
|
191
|
+
unsigned long long int n = 0,
|
192
|
+
double rmean = 0,
|
193
|
+
double ssqd = 0,
|
194
|
+
T reference = 0) {
|
195
|
+
|
196
|
+
auto in = arr.data();
|
197
|
+
int size = arr.size();
|
198
|
+
|
199
|
+
T max_value = std::numeric_limits<T>::lowest();
|
200
|
+
T min_value = std::numeric_limits<T>::max();
|
201
|
+
double delta, delta_prime;
|
202
|
+
|
203
|
+
unsigned long long int nbetter_or_equal = 0;
|
204
|
+
|
205
|
+
for(int i = 0; i < size; i++){
|
206
|
+
n++;
|
207
|
+
delta = in[i] - rmean;
|
208
|
+
rmean += delta / n;
|
209
|
+
delta_prime = in[i] - rmean;
|
210
|
+
ssqd += delta * delta_prime;
|
211
|
+
|
212
|
+
max_value = std::max(in[i], max_value);
|
213
|
+
min_value = std::min(in[i], min_value);
|
214
|
+
if (in[i] >= reference)
|
215
|
+
nbetter_or_equal++;
|
216
|
+
}
|
217
|
+
|
218
|
+
return py::make_tuple(n, rmean, ssqd, nbetter_or_equal, max_value, min_value);
|
219
|
+
}
|
220
|
+
|
221
|
+
PYBIND11_MODULE(extensions, m) {
|
222
|
+
|
223
|
+
m.def("absolute_minimum_deviation", absolute_minimum_deviation<double>,
|
224
|
+
"Compute pairwise absolute minimum deviation for a set of coordinates (float64).",
|
225
|
+
py::arg("coordinates"), py::arg("output"));
|
226
|
+
m.def("absolute_minimum_deviation", absolute_minimum_deviation<float>,
|
227
|
+
"Compute pairwise absolute minimum deviation for a set of coordinates (float32).",
|
228
|
+
py::arg("coordinates"), py::arg("output"));
|
229
|
+
m.def("absolute_minimum_deviation", absolute_minimum_deviation<int64_t>,
|
230
|
+
"Compute pairwise absolute minimum deviation for a set of coordinates (int64).",
|
231
|
+
py::arg("coordinates"), py::arg("output"));
|
232
|
+
m.def("absolute_minimum_deviation", absolute_minimum_deviation<int32_t>,
|
233
|
+
"Compute pairwise absolute minimum deviation for a set of coordinates (int32).",
|
234
|
+
py::arg("coordinates"), py::arg("output"));
|
235
|
+
|
236
|
+
|
237
|
+
m.def("max_euclidean_distance", max_euclidean_distance<double>,
|
238
|
+
"Identify pair of points with maximal euclidean distance (float64).",
|
239
|
+
py::arg("coordinates"));
|
240
|
+
m.def("max_euclidean_distance", max_euclidean_distance<float>,
|
241
|
+
"Identify pair of points with maximal euclidean distance (float32).",
|
242
|
+
py::arg("coordinates"));
|
243
|
+
m.def("max_euclidean_distance", max_euclidean_distance<int64_t>,
|
244
|
+
"Identify pair of points with maximal euclidean distance (int64).",
|
245
|
+
py::arg("coordinates"));
|
246
|
+
m.def("max_euclidean_distance", max_euclidean_distance<int32_t>,
|
247
|
+
"Identify pair of points with maximal euclidean distance (int32).",
|
248
|
+
py::arg("coordinates"));
|
249
|
+
|
250
|
+
|
251
|
+
m.def("find_candidate_indices", &find_candidate_indices<double>,
|
252
|
+
"Finds candidate indices with minimum distance (float64).",
|
253
|
+
py::arg("coordinates"), py::arg("min_distance"));
|
254
|
+
m.def("find_candidate_indices", &find_candidate_indices<float>,
|
255
|
+
"Finds candidate indices with minimum distance (float32).",
|
256
|
+
py::arg("coordinates"), py::arg("min_distance"));
|
257
|
+
m.def("find_candidate_indices", &find_candidate_indices<int64_t>,
|
258
|
+
"Finds candidate indices with minimum distance (int64).",
|
259
|
+
py::arg("coordinates"), py::arg("min_distance"));
|
260
|
+
m.def("find_candidate_indices", &find_candidate_indices<int32_t>,
|
261
|
+
"Finds candidate indices with minimum distance (int32).",
|
262
|
+
py::arg("coordinates"), py::arg("min_distance"));
|
263
|
+
|
264
|
+
|
265
|
+
m.def("find_candidate_coordinates", &find_candidate_coordinates<double>,
|
266
|
+
"Finds candidate coordinates with minimum distance (float64).",
|
267
|
+
py::arg("coordinates"), py::arg("min_distance"));
|
268
|
+
m.def("find_candidate_coordinates", &find_candidate_coordinates<float>,
|
269
|
+
"Finds candidate coordinates with minimum distance (float32).",
|
270
|
+
py::arg("coordinates"), py::arg("min_distance"));
|
271
|
+
m.def("find_candidate_coordinates", &find_candidate_coordinates<int64_t>,
|
272
|
+
"Finds candidate coordinates with minimum distance (int64).",
|
273
|
+
py::arg("coordinates"), py::arg("min_distance"));
|
274
|
+
m.def("find_candidate_coordinates", &find_candidate_coordinates<int32_t>,
|
275
|
+
"Finds candidate coordinates with minimum distance (int32).",
|
276
|
+
py::arg("coordinates"), py::arg("min_distance"));
|
277
|
+
|
278
|
+
|
279
|
+
m.def("max_index_by_label", &max_index_by_label<double, double>,
|
280
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
281
|
+
m.def("max_index_by_label", &max_index_by_label<double, float>,
|
282
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
283
|
+
m.def("max_index_by_label", &max_index_by_label<double, int64_t>,
|
284
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
285
|
+
m.def("max_index_by_label", &max_index_by_label<double, int32_t>,
|
286
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
287
|
+
|
288
|
+
m.def("max_index_by_label", &max_index_by_label<float, double>,
|
289
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
290
|
+
m.def("max_index_by_label", &max_index_by_label<float, float>,
|
291
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
292
|
+
m.def("max_index_by_label", &max_index_by_label<float, int64_t>,
|
293
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
294
|
+
m.def("max_index_by_label", &max_index_by_label<float, int32_t>,
|
295
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
296
|
+
|
297
|
+
m.def("max_index_by_label", &max_index_by_label<int64_t, double>,
|
298
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
299
|
+
m.def("max_index_by_label", &max_index_by_label<int64_t, float>,
|
300
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
301
|
+
m.def("max_index_by_label", &max_index_by_label<int64_t, int64_t>,
|
302
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
303
|
+
m.def("max_index_by_label", &max_index_by_label<int64_t, int32_t>,
|
304
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
305
|
+
|
306
|
+
m.def("max_index_by_label", &max_index_by_label<int32_t, double>,
|
307
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
308
|
+
m.def("max_index_by_label", &max_index_by_label<int32_t, float>,
|
309
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
310
|
+
m.def("max_index_by_label", &max_index_by_label<int32_t, int64_t>,
|
311
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
312
|
+
m.def("max_index_by_label", &max_index_by_label<int32_t, int32_t>,
|
313
|
+
"Maximum value by label", py::arg("labels"), py::arg("scores"));
|
314
|
+
|
315
|
+
|
316
|
+
m.def("online_statistics", &online_statistics<double>, py::arg("arr"),
|
317
|
+
py::arg("n") = 0, py::arg("rmean") = 0,
|
318
|
+
py::arg("ssqd") = 0, py::arg("reference") = 0,
|
319
|
+
"Compute running online statistics on a numpy array.");
|
320
|
+
m.def("online_statistics", &online_statistics<float>, py::arg("arr"),
|
321
|
+
py::arg("n") = 0, py::arg("rmean") = 0,
|
322
|
+
py::arg("ssqd") = 0, py::arg("reference") = 0,
|
323
|
+
"Compute running online statistics on a numpy array.");
|
324
|
+
m.def("online_statistics", &online_statistics<int64_t>, py::arg("arr"),
|
325
|
+
py::arg("n") = 0, py::arg("rmean") = 0,
|
326
|
+
py::arg("ssqd") = 0, py::arg("reference") = 0,
|
327
|
+
"Compute running online statistics on a numpy array.");
|
328
|
+
m.def("online_statistics", &online_statistics<int32_t>, py::arg("arr"),
|
329
|
+
py::arg("n") = 0, py::arg("rmean") = 0,
|
330
|
+
py::arg("ssqd") = 0, py::arg("reference") = 0,
|
331
|
+
"Compute running online statistics on a numpy array.");
|
332
|
+
}
|