pytme 0.2.2__cp311-cp311-macosx_14_0_arm64.whl → 0.2.4__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86) hide show
  1. {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/match_template.py +97 -148
  2. {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/postprocess.py +20 -29
  3. pytme-0.2.4.data/scripts/preprocess.py +148 -0
  4. {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/preprocessor_gui.py +15 -23
  5. {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/METADATA +11 -10
  6. pytme-0.2.4.dist-info/RECORD +119 -0
  7. {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/WHEEL +1 -1
  8. {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/top_level.txt +1 -0
  9. pytme-0.2.2.data/scripts/preprocess.py → scripts/eval.py +1 -1
  10. scripts/match_template.py +97 -148
  11. scripts/postprocess.py +20 -29
  12. scripts/preprocess.py +116 -61
  13. scripts/preprocessor_gui.py +15 -23
  14. tests/__init__.py +0 -0
  15. tests/data/.DS_Store +0 -0
  16. tests/data/Blurring/.DS_Store +0 -0
  17. tests/data/Blurring/blob_width18.npy +0 -0
  18. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  19. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  20. tests/data/Blurring/hamming_width6.npy +0 -0
  21. tests/data/Blurring/kaiserb_width18.npy +0 -0
  22. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  23. tests/data/Blurring/mean_size5.npy +0 -0
  24. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  25. tests/data/Blurring/rank_rank3.npy +0 -0
  26. tests/data/Maps/.DS_Store +0 -0
  27. tests/data/Maps/emd_8621.mrc.gz +0 -0
  28. tests/data/README.md +2 -0
  29. tests/data/Raw/.DS_Store +0 -0
  30. tests/data/Raw/em_map.map +0 -0
  31. tests/data/Structures/.DS_Store +0 -0
  32. tests/data/Structures/1pdj.cif +3339 -0
  33. tests/data/Structures/1pdj.pdb +1429 -0
  34. tests/data/Structures/5khe.cif +3685 -0
  35. tests/data/Structures/5khe.ent +2210 -0
  36. tests/data/Structures/5khe.pdb +2210 -0
  37. tests/data/Structures/5uz4.cif +70548 -0
  38. tests/preprocessing/__init__.py +0 -0
  39. tests/preprocessing/test_compose.py +76 -0
  40. tests/preprocessing/test_frequency_filters.py +178 -0
  41. tests/preprocessing/test_preprocessor.py +136 -0
  42. tests/preprocessing/test_utils.py +79 -0
  43. tests/test_analyzer.py +310 -0
  44. tests/test_backends.py +375 -0
  45. tests/test_density.py +508 -0
  46. tests/test_extensions.py +130 -0
  47. tests/test_matching_cli.py +283 -0
  48. tests/test_matching_data.py +162 -0
  49. tests/test_matching_exhaustive.py +162 -0
  50. tests/test_matching_memory.py +30 -0
  51. tests/test_matching_optimization.py +276 -0
  52. tests/test_matching_utils.py +326 -0
  53. tests/test_orientations.py +173 -0
  54. tests/test_packaging.py +95 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_structure.py +243 -0
  57. tme/__init__.py +0 -1
  58. tme/__version__.py +1 -1
  59. tme/analyzer.py +9 -6
  60. tme/backends/__init__.py +1 -1
  61. tme/backends/_jax_utils.py +10 -8
  62. tme/backends/cupy_backend.py +2 -7
  63. tme/backends/jax_backend.py +35 -20
  64. tme/backends/npfftw_backend.py +3 -2
  65. tme/backends/pytorch_backend.py +10 -7
  66. tme/data/scattering_factors.pickle +0 -0
  67. tme/density.py +26 -12
  68. tme/extensions.cpython-311-darwin.so +0 -0
  69. tme/external/bindings.cpp +332 -0
  70. tme/matching_data.py +33 -24
  71. tme/matching_exhaustive.py +39 -20
  72. tme/matching_scores.py +5 -2
  73. tme/matching_utils.py +8 -2
  74. tme/orientations.py +26 -9
  75. tme/preprocessing/_utils.py +14 -14
  76. tme/preprocessing/composable_filter.py +5 -4
  77. tme/preprocessing/compose.py +4 -4
  78. tme/preprocessing/frequency_filters.py +32 -35
  79. tme/preprocessing/tilt_series.py +210 -148
  80. tme/preprocessor.py +24 -246
  81. tme/structure.py +14 -14
  82. pytme-0.2.2.dist-info/RECORD +0 -74
  83. tme/matching_memory.py +0 -383
  84. {pytme-0.2.2.data → pytme-0.2.4.data}/scripts/estimate_ram_usage.py +0 -0
  85. {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/LICENSE +0 -0
  86. {pytme-0.2.2.dist-info → pytme-0.2.4.dist-info}/entry_points.txt +0 -0
@@ -149,10 +149,10 @@ class CupyBackend(NumpyFFTWBackend):
149
149
  cache.clear()
150
150
 
151
151
  def rfftn(arr: CupyArray, out: CupyArray) -> CupyArray:
152
- return cufft.rfftn(arr)
152
+ return cufft.rfftn(arr, s=fast_shape)
153
153
 
154
154
  def irfftn(arr: CupyArray, out: CupyArray) -> CupyArray:
155
- return cufft.irfftn(arr)
155
+ return cufft.irfftn(arr, s=fast_shape)
156
156
 
157
157
  PLAN_CACHE[current_device] = [fast_shape, fast_ft_shape]
158
158
 
@@ -167,11 +167,6 @@ class CupyBackend(NumpyFFTWBackend):
167
167
  fast_shape = [next_fast_len(x, real=True) for x in convolution_shape]
168
168
  fast_ft_shape = list(fast_shape[:-1]) + [fast_shape[-1] // 2 + 1]
169
169
 
170
- # This almost never happens but avoid cuFFT casting errors
171
- is_odd = fast_shape[-1] % 2
172
- fast_shape[-1] += is_odd
173
- fast_ft_shape[-1] += is_odd
174
-
175
170
  return convolution_shape, fast_shape, fast_ft_shape
176
171
 
177
172
  def max_filter_coordinates(self, score_space, min_distance: Tuple[int]):
@@ -119,19 +119,6 @@ class JaxBackend(NumpyFFTWBackend):
119
119
 
120
120
  return rfftn, irfftn
121
121
 
122
- def compute_convolution_shapes(
123
- self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
124
- ) -> Tuple[List[int], List[int], List[int]]:
125
- conv_shape, fast_shape, fast_ft_shape = super().compute_convolution_shapes(
126
- arr1_shape, arr2_shape
127
- )
128
-
129
- is_odd = fast_shape[-1] % 2
130
- fast_shape[-1] += is_odd
131
- fast_ft_shape[-1] += is_odd
132
-
133
- return conv_shape, fast_shape, fast_ft_shape
134
-
135
122
  def rigid_transform(
136
123
  self,
137
124
  arr: BackendArray,
@@ -144,8 +131,8 @@ class JaxBackend(NumpyFFTWBackend):
144
131
  **kwargs,
145
132
  ) -> Tuple[BackendArray, BackendArray]:
146
133
  rotate_mask = arr_mask is not None
147
- center = self.divide(self.to_backend_array(arr.shape), 2)[:, None]
148
134
 
135
+ center = self.divide(self.to_backend_array(arr.shape) - 1, 2)[:, None]
149
136
  indices = self._array_backend.indices(arr.shape, dtype=self._float_dtype)
150
137
  indices = indices.reshape((arr.ndim, -1))
151
138
  indices = indices.at[:].add(-center)
@@ -200,7 +187,7 @@ class JaxBackend(NumpyFFTWBackend):
200
187
  target_shape = tuple(
201
188
  (x.stop - x.start + p) for x, p in zip(splits[0][0], target_pad)
202
189
  )
203
- fast_shape, fast_ft_shape, shift = matching_data._fourier_padding(
190
+ conv_shape, fast_shape, fast_ft_shape, shift = matching_data._fourier_padding(
204
191
  target_shape=self.to_numpy_array(target_shape),
205
192
  template_shape=self.to_numpy_array(matching_data._template.shape),
206
193
  pad_fourier=False,
@@ -210,13 +197,26 @@ class JaxBackend(NumpyFFTWBackend):
210
197
  "convolution_mode": convolution_mode,
211
198
  "fourier_shift": shift,
212
199
  "targetshape": target_shape,
213
- "templateshape": matching_data._template.shape,
200
+ "templateshape": matching_data.template.shape,
201
+ "convolution_shape": conv_shape,
214
202
  }
215
203
 
216
204
  create_target_filter = matching_data.target_filter is not None
217
205
  create_template_filter = matching_data.template_filter is not None
218
206
  create_filter = create_target_filter or create_template_filter
219
207
 
208
+ # Applying the filter leads to more FFTs
209
+ fastt_shape = matching_data._template.shape
210
+ if create_template_filter:
211
+ # _, fastt_shape, _, tshift = matching_data._fourier_padding(
212
+ # target_shape=self.to_numpy_array(matching_data._template.shape),
213
+ # template_shape=self.to_numpy_array(
214
+ # [1 for _ in matching_data._template.shape]
215
+ # ),
216
+ # pad_fourier=False,
217
+ # )
218
+ fastt_shape = matching_data._template.shape
219
+
220
220
  ret, template_filter, target_filter = [], 1, 1
221
221
  rotation_mapping = {
222
222
  self.tobytes(matching_data.rotations[i]): i
@@ -246,12 +246,12 @@ class JaxBackend(NumpyFFTWBackend):
246
246
 
247
247
  if create_template_filter:
248
248
  template_filter = matching_data.template_filter(
249
- shape=matching_data._template.shape, **filter_args
249
+ shape=fastt_shape, **filter_args
250
250
  )["data"]
251
251
  template_filter = template_filter.at[(0,) * template_filter.ndim].set(0)
252
252
 
253
253
  if create_target_filter:
254
- target_filter = matching_data.template_filter(
254
+ target_filter = matching_data.target_filter(
255
255
  shape=fast_shape, **filter_args
256
256
  )["data"]
257
257
  target_filter = target_filter.at[(0,) * target_filter.ndim].set(0)
@@ -260,8 +260,8 @@ class JaxBackend(NumpyFFTWBackend):
260
260
  base, targets = None, self._array_backend.stack(targets)
261
261
  scores, rotations = scan_inner(
262
262
  targets,
263
- matching_data.template,
264
- matching_data.template_mask,
263
+ self.topleft_pad(matching_data.template, fastt_shape),
264
+ self.topleft_pad(matching_data.template_mask, fastt_shape),
265
265
  matching_data.rotations,
266
266
  template_filter,
267
267
  target_filter,
@@ -280,3 +280,18 @@ class JaxBackend(NumpyFFTWBackend):
280
280
  ret.append(tuple(temp._postprocess(**analyzer_args)))
281
281
 
282
282
  return ret
283
+
284
+ def get_available_memory(self) -> int:
285
+ import jax
286
+
287
+ _memory = {"cpu": 0, "gpu": 0}
288
+ for device in jax.devices():
289
+ if device.platform == "cpu":
290
+ _memory["cpu"] = super().get_available_memory()
291
+ else:
292
+ mem_stats = device.memory_stats()
293
+ _memory["gpu"] += mem_stats.get("bytes_limit", 0)
294
+
295
+ if _memory["gpu"] > 0:
296
+ return _memory["gpu"]
297
+ return _memory["cpu"]
@@ -186,7 +186,7 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
186
186
  def to_sharedarr(
187
187
  self, arr: NDArray, shared_memory_handler: type = None
188
188
  ) -> shm_type:
189
- if type(shared_memory_handler) == SharedMemoryManager:
189
+ if isinstance(shared_memory_handler, SharedMemoryManager):
190
190
  shm = shared_memory_handler.SharedMemory(size=arr.nbytes)
191
191
  else:
192
192
  shm = shared_memory.SharedMemory(create=True, size=arr.nbytes)
@@ -347,7 +347,8 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
347
347
  cache: bool = False,
348
348
  ) -> Tuple[NDArray, NDArray]:
349
349
  translation = self.zeros(arr.ndim) if translation is None else translation
350
- center = self.divide(self.to_backend_array(arr.shape), 2)
350
+
351
+ center = self.divide(self.to_backend_array(arr.shape) - 1, 2)
351
352
  if not use_geometric_center:
352
353
  center = self.center_of_mass(arr, cutoff=0)
353
354
 
@@ -81,13 +81,13 @@ class PytorchBackend(NumpyFFTWBackend):
81
81
 
82
82
  def max(self, *args, **kwargs) -> NDArray:
83
83
  ret = self._array_backend.amax(*args, **kwargs)
84
- if type(ret) == self._array_backend.Tensor:
84
+ if isinstance(ret, self._array_backend.Tensor):
85
85
  return ret
86
86
  return ret[0]
87
87
 
88
88
  def min(self, *args, **kwargs) -> NDArray:
89
89
  ret = self._array_backend.amin(*args, **kwargs)
90
- if type(ret) == self._array_backend.Tensor:
90
+ if isinstance(ret, self._array_backend.Tensor):
91
91
  return ret
92
92
  return ret[0]
93
93
 
@@ -154,7 +154,7 @@ class PytorchBackend(NumpyFFTWBackend):
154
154
  1, -1
155
155
  )
156
156
  if unraveled_coords.size(0) == 1:
157
- return tuple(unraveled_coords[0, :].tolist())
157
+ return (unraveled_coords[0, :],)
158
158
 
159
159
  else:
160
160
  return tuple(unraveled_coords.T)
@@ -206,7 +206,9 @@ class PytorchBackend(NumpyFFTWBackend):
206
206
  else:
207
207
  raise NotImplementedError("Operation only implemented for 2 and 3D inputs.")
208
208
 
209
- pool = func(kernel_size=min_distance, return_indices=True)
209
+ pool = func(
210
+ kernel_size=min_distance, padding=min_distance // 2, return_indices=True
211
+ )
210
212
  _, indices = pool(score_space.reshape(1, 1, *score_space.shape))
211
213
  coordinates = self.unravel_index(indices.reshape(-1), score_space.shape)
212
214
  coordinates = self.transpose(self.stack(coordinates))
@@ -217,7 +219,7 @@ class PytorchBackend(NumpyFFTWBackend):
217
219
 
218
220
  def from_sharedarr(self, args) -> TorchTensor:
219
221
  if self.device == "cuda":
220
- return args[0]
222
+ return args
221
223
 
222
224
  shm, shape, dtype = args
223
225
  required_size = int(self._array_backend.prod(self.to_backend_array(shape)))
@@ -235,13 +237,12 @@ class PytorchBackend(NumpyFFTWBackend):
235
237
 
236
238
  nbytes = arr.numel() * arr.element_size()
237
239
 
238
- if type(shared_memory_handler) == SharedMemoryManager:
240
+ if isinstance(shared_memory_handler, SharedMemoryManager):
239
241
  shm = shared_memory_handler.SharedMemory(size=nbytes)
240
242
  else:
241
243
  shm = shared_memory.SharedMemory(create=True, size=nbytes)
242
244
 
243
245
  shm.buf[:nbytes] = arr.numpy().tobytes()
244
-
245
246
  return shm, arr.shape, arr.dtype
246
247
 
247
248
  def transpose(self, arr):
@@ -415,6 +416,8 @@ class PytorchBackend(NumpyFFTWBackend):
415
416
  yield None
416
417
 
417
418
  def device_count(self) -> int:
419
+ if self.device == "cpu":
420
+ return 1
418
421
  return self._array_backend.cuda.device_count()
419
422
 
420
423
  def reverse(self, arr: TorchTensor) -> TorchTensor:
Binary file
tme/density.py CHANGED
@@ -116,8 +116,8 @@ class Density:
116
116
  response = "Density object at {}\nOrigin: {}, Sampling Rate: {}, Shape: {}"
117
117
  return response.format(
118
118
  hex(id(self)),
119
- tuple(np.round(self.origin, 3)),
120
- tuple(np.round(self.sampling_rate, 3)),
119
+ tuple(round(float(x), 3) for x in self.origin),
120
+ tuple(round(float(x), 3) for x in self.sampling_rate),
121
121
  self.shape,
122
122
  )
123
123
 
@@ -306,6 +306,10 @@ class Density:
306
306
  "std": float(mrc.header.rms),
307
307
  }
308
308
 
309
+ non_standard_crs = not np.all(crs_index == (0, 1, 2))
310
+ if non_standard_crs:
311
+ warnings.warn("Non standard MAPC, MAPR, MAPS, adapting data and origin.")
312
+
309
313
  if is_gzipped(filename):
310
314
  if use_memmap:
311
315
  warnings.warn(
@@ -315,6 +319,10 @@ class Density:
315
319
  use_memmap = False
316
320
 
317
321
  if subset is not None:
322
+ subset = tuple(
323
+ subset[i] if i < len(subset) else slice(0, data_shape[i])
324
+ for i in crs_index
325
+ )
318
326
  subset_shape = tuple(x.stop - x.start for x in subset)
319
327
  if np.allclose(subset_shape, data_shape):
320
328
  return cls._load_mrc(
@@ -328,18 +336,16 @@ class Density:
328
336
  dtype=data_type,
329
337
  header_size=1024 + extended_header,
330
338
  )
331
- return data, origin, sampling_rate, metadata
332
-
333
- if not use_memmap:
339
+ elif subset is None and not use_memmap:
334
340
  with mrcfile.open(filename, header_only=False) as mrc:
335
341
  data = mrc.data.astype(np.float32, copy=False)
336
342
  else:
337
343
  with mrcfile.mrcmemmap.MrcMemmap(filename, header_only=False) as mrc:
338
344
  data = mrc.data
339
345
 
340
- if not np.all(crs_index == (0, 1, 2)):
346
+ if non_standard_crs:
341
347
  data = np.transpose(data, crs_index)
342
- start = np.take(start, crs_index)
348
+ origin = np.take(origin, crs_index)
343
349
 
344
350
  return data, origin, sampling_rate, metadata
345
351
 
@@ -738,9 +744,16 @@ class Density:
738
744
  >>> )
739
745
 
740
746
  :py:meth:`Density.from_structure` supports a variety of methods to convert
741
- atoms into densities. In addition to 'atomic_weight', 'atomic_number',
742
- and 'van_der_waals_radius', its possible to use experimentally determined
743
- scattering factors from various sources:
747
+ atoms into densities
748
+
749
+ >>> density = Density.from_structure(
750
+ >>> filename_or_structure = path_to_structure,
751
+ >>> weight_type = "gaussian",
752
+ >>> weight_type_args={"resolution": "20"}
753
+ >>> )
754
+
755
+ In addition its possible to use experimentally determined scattering factors
756
+ from various sources:
744
757
 
745
758
  >>> density = Density.from_structure(
746
759
  >>> filename_or_structure = path_to_structure,
@@ -748,7 +761,7 @@ class Density:
748
761
  >>> weight_type_args={"source": "dt1969"}
749
762
  >>> )
750
763
 
751
- or a lowpass filtered representation introduced in [1]_:
764
+ and their lowpass filtered representation introduced in [1]_:
752
765
 
753
766
  >>> density = Density.from_structure(
754
767
  >>> filename_or_structure = path_to_structure,
@@ -873,6 +886,7 @@ class Density:
873
886
  mrc.header.nzstart, mrc.header.nystart, mrc.header.nxstart = np.rint(
874
887
  np.divide(self.origin, self.sampling_rate)
875
888
  )
889
+ mrc.header.origin = tuple(x for x in self.origin)
876
890
  # mrcfile library expects origin to be in xyz format
877
891
  mrc.header.mapc, mrc.header.mapr, mrc.header.maps = (1, 2, 3)
878
892
  mrc.header["origin"] = tuple(self.origin[::-1])
@@ -1594,7 +1608,7 @@ class Density:
1594
1608
  rotation_matrix: NDArray,
1595
1609
  translation: NDArray = None,
1596
1610
  order: int = 3,
1597
- use_geometric_center: bool = False,
1611
+ use_geometric_center: bool = True,
1598
1612
  ) -> "Density":
1599
1613
  """
1600
1614
  Performs a rigid transform of the class instance.
Binary file
@@ -0,0 +1,332 @@
1
+ /* Pybind extensions for template matching score space analyzers.
2
+
3
+ Copyright (c) 2023 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ */
7
+
8
+ #include <vector>
9
+ #include <iostream>
10
+ #include <limits>
11
+
12
+ #include <pybind11/stl.h>
13
+ #include <pybind11/numpy.h>
14
+ #include <pybind11/pybind11.h>
15
+
16
+ namespace py = pybind11;
17
+
18
+ template <typename T>
19
+ void absolute_minimum_deviation(
20
+ py::array_t<T, py::array::c_style> coordinates,
21
+ py::array_t<T, py::array::c_style> output) {
22
+ auto coordinates_data = coordinates.data();
23
+ auto output_data = output.mutable_data();
24
+ int n = coordinates.shape(0);
25
+ int k = coordinates.shape(1);
26
+ int ik, jk, in, jn;
27
+
28
+ for (int i = 0; i < n; ++i) {
29
+ ik = i * k;
30
+ in = i * n;
31
+ for (int j = i + 1; j < n; ++j) {
32
+ jk = j * k;
33
+ jn = j * n;
34
+ T min_distance = std::abs(coordinates_data[ik] - coordinates_data[jk]);
35
+ for (int p = 1; p < k; ++p) {
36
+ min_distance = std::min(min_distance,
37
+ std::abs(coordinates_data[ik + p] - coordinates_data[jk + p]));
38
+ }
39
+ output_data[in + j] = min_distance;
40
+ output_data[jn + i] = min_distance;
41
+ }
42
+ output_data[in + i] = 0;
43
+ }
44
+ }
45
+
46
+ template <typename T>
47
+ std::pair<double, std::pair<int, int>> max_euclidean_distance(
48
+ py::array_t<T, py::array::c_style> coordinates) {
49
+ auto coordinates_data = coordinates.data();
50
+ int n = coordinates.shape(0);
51
+ int k = coordinates.shape(1);
52
+
53
+ double distance = 0.0;
54
+ double difference = 0.0;
55
+ double max_distance = -1;
56
+ double squared_distances = 0.0;
57
+
58
+ int ik, jk;
59
+ int max_i = -1, max_j = -1;
60
+
61
+ for (int i = 0; i < n; ++i) {
62
+ ik = i * k;
63
+ for (int j = i + 1; j < n; ++j) {
64
+ jk = j * k;
65
+ squared_distances = 0.0;
66
+ for (int p = 0; p < k; ++p) {
67
+ difference = static_cast<double>(
68
+ coordinates_data[ik + p] - coordinates_data[jk + p]
69
+ );
70
+ squared_distances += (difference * difference);
71
+ }
72
+ distance = std::sqrt(squared_distances);
73
+ if (distance > max_distance) {
74
+ max_distance = distance;
75
+ max_i = i;
76
+ max_j = j;
77
+ }
78
+ }
79
+ }
80
+
81
+ return std::make_pair(max_distance, std::make_pair(max_i, max_j));
82
+ }
83
+
84
+
85
+ template <typename T>
86
+ inline py::array_t<int, py::array::c_style> find_candidate_indices(
87
+ py::array_t<T, py::array::c_style> coordinates,
88
+ T min_distance) {
89
+ auto coordinates_data = coordinates.data();
90
+ int n = coordinates.shape(0);
91
+ int k = coordinates.shape(1);
92
+ int ik, jk;
93
+
94
+ std::vector<int> candidate_indices;
95
+ candidate_indices.reserve(n / 2);
96
+ candidate_indices.push_back(0);
97
+
98
+ for (int i = 1; i < n; ++i) {
99
+ bool is_candidate = true;
100
+ ik = i * k;
101
+ for (int candidate_index : candidate_indices) {
102
+ jk = candidate_index * k;
103
+ T distance = std::pow(coordinates_data[ik] - coordinates_data[jk], 2);
104
+ for (int p = 1; p < k; ++p) {
105
+ distance += std::pow(coordinates_data[ik + p] - coordinates_data[jk + p], 2);
106
+ }
107
+ distance = std::sqrt(distance);
108
+ if (distance <= min_distance) {
109
+ is_candidate = false;
110
+ break;
111
+ }
112
+ }
113
+ if (is_candidate) {
114
+ candidate_indices.push_back(i);
115
+ }
116
+ }
117
+
118
+ py::array_t<int, py::array::c_style> output({(int)candidate_indices.size()});
119
+ auto output_data = output.mutable_data();
120
+
121
+ for (size_t i = 0; i < candidate_indices.size(); ++i) {
122
+ output_data[i] = candidate_indices[i];
123
+ }
124
+
125
+ return output;
126
+ }
127
+
128
+ template <typename T>
129
+ py::array_t<T, py::array::c_style> find_candidate_coordinates(
130
+ py::array_t<T, py::array::c_style> coordinates,
131
+ T min_distance) {
132
+
133
+ py::array_t<int, py::array::c_style> candidate_indices_array = find_candidate_indices(
134
+ coordinates, min_distance);
135
+ auto candidate_indices_data = candidate_indices_array.data();
136
+ int num_candidates = candidate_indices_array.shape(0);
137
+ int k = coordinates.shape(1);
138
+ auto coordinates_data = coordinates.data();
139
+
140
+ py::array_t<T, py::array::c_style> output({num_candidates, k});
141
+ auto output_data = output.mutable_data();
142
+
143
+ for (int i = 0; i < num_candidates; ++i) {
144
+ int candidate_index = candidate_indices_data[i] * k;
145
+ std::copy(
146
+ coordinates_data + candidate_index,
147
+ coordinates_data + candidate_index + k,
148
+ output_data + i * k
149
+ );
150
+ }
151
+
152
+ return output;
153
+ }
154
+
155
+ template <typename U, typename T>
156
+ py::dict max_index_by_label(
157
+ py::array_t<U, py::array::c_style> labels,
158
+ py::array_t<T, py::array::c_style> scores
159
+ ) {
160
+
161
+ const U* labels_ptr = labels.data();
162
+ const T* scores_ptr = scores.data();
163
+
164
+ std::unordered_map<U, std::pair<T, ssize_t>> max_scores;
165
+
166
+ U label;
167
+ T score;
168
+ for (ssize_t i = 0; i < labels.size(); ++i) {
169
+ label = labels_ptr[i];
170
+ score = scores_ptr[i];
171
+
172
+ auto it = max_scores.insert({label, {score, i}});
173
+
174
+ if (score > it.first->second.first) {
175
+ it.first->second = {score, i};
176
+ }
177
+ }
178
+
179
+ py::dict ret;
180
+ for (auto& item: max_scores) {
181
+ ret[py::cast(item.first)] = py::cast(item.second.second);
182
+ }
183
+
184
+ return ret;
185
+ }
186
+
187
+
188
+ template <typename T>
189
+ py::tuple online_statistics(
190
+ py::array_t<T, py::array::c_style> arr,
191
+ unsigned long long int n = 0,
192
+ double rmean = 0,
193
+ double ssqd = 0,
194
+ T reference = 0) {
195
+
196
+ auto in = arr.data();
197
+ int size = arr.size();
198
+
199
+ T max_value = std::numeric_limits<T>::lowest();
200
+ T min_value = std::numeric_limits<T>::max();
201
+ double delta, delta_prime;
202
+
203
+ unsigned long long int nbetter_or_equal = 0;
204
+
205
+ for(int i = 0; i < size; i++){
206
+ n++;
207
+ delta = in[i] - rmean;
208
+ rmean += delta / n;
209
+ delta_prime = in[i] - rmean;
210
+ ssqd += delta * delta_prime;
211
+
212
+ max_value = std::max(in[i], max_value);
213
+ min_value = std::min(in[i], min_value);
214
+ if (in[i] >= reference)
215
+ nbetter_or_equal++;
216
+ }
217
+
218
+ return py::make_tuple(n, rmean, ssqd, nbetter_or_equal, max_value, min_value);
219
+ }
220
+
221
+ PYBIND11_MODULE(extensions, m) {
222
+
223
+ m.def("absolute_minimum_deviation", absolute_minimum_deviation<double>,
224
+ "Compute pairwise absolute minimum deviation for a set of coordinates (float64).",
225
+ py::arg("coordinates"), py::arg("output"));
226
+ m.def("absolute_minimum_deviation", absolute_minimum_deviation<float>,
227
+ "Compute pairwise absolute minimum deviation for a set of coordinates (float32).",
228
+ py::arg("coordinates"), py::arg("output"));
229
+ m.def("absolute_minimum_deviation", absolute_minimum_deviation<int64_t>,
230
+ "Compute pairwise absolute minimum deviation for a set of coordinates (int64).",
231
+ py::arg("coordinates"), py::arg("output"));
232
+ m.def("absolute_minimum_deviation", absolute_minimum_deviation<int32_t>,
233
+ "Compute pairwise absolute minimum deviation for a set of coordinates (int32).",
234
+ py::arg("coordinates"), py::arg("output"));
235
+
236
+
237
+ m.def("max_euclidean_distance", max_euclidean_distance<double>,
238
+ "Identify pair of points with maximal euclidean distance (float64).",
239
+ py::arg("coordinates"));
240
+ m.def("max_euclidean_distance", max_euclidean_distance<float>,
241
+ "Identify pair of points with maximal euclidean distance (float32).",
242
+ py::arg("coordinates"));
243
+ m.def("max_euclidean_distance", max_euclidean_distance<int64_t>,
244
+ "Identify pair of points with maximal euclidean distance (int64).",
245
+ py::arg("coordinates"));
246
+ m.def("max_euclidean_distance", max_euclidean_distance<int32_t>,
247
+ "Identify pair of points with maximal euclidean distance (int32).",
248
+ py::arg("coordinates"));
249
+
250
+
251
+ m.def("find_candidate_indices", &find_candidate_indices<double>,
252
+ "Finds candidate indices with minimum distance (float64).",
253
+ py::arg("coordinates"), py::arg("min_distance"));
254
+ m.def("find_candidate_indices", &find_candidate_indices<float>,
255
+ "Finds candidate indices with minimum distance (float32).",
256
+ py::arg("coordinates"), py::arg("min_distance"));
257
+ m.def("find_candidate_indices", &find_candidate_indices<int64_t>,
258
+ "Finds candidate indices with minimum distance (int64).",
259
+ py::arg("coordinates"), py::arg("min_distance"));
260
+ m.def("find_candidate_indices", &find_candidate_indices<int32_t>,
261
+ "Finds candidate indices with minimum distance (int32).",
262
+ py::arg("coordinates"), py::arg("min_distance"));
263
+
264
+
265
+ m.def("find_candidate_coordinates", &find_candidate_coordinates<double>,
266
+ "Finds candidate coordinates with minimum distance (float64).",
267
+ py::arg("coordinates"), py::arg("min_distance"));
268
+ m.def("find_candidate_coordinates", &find_candidate_coordinates<float>,
269
+ "Finds candidate coordinates with minimum distance (float32).",
270
+ py::arg("coordinates"), py::arg("min_distance"));
271
+ m.def("find_candidate_coordinates", &find_candidate_coordinates<int64_t>,
272
+ "Finds candidate coordinates with minimum distance (int64).",
273
+ py::arg("coordinates"), py::arg("min_distance"));
274
+ m.def("find_candidate_coordinates", &find_candidate_coordinates<int32_t>,
275
+ "Finds candidate coordinates with minimum distance (int32).",
276
+ py::arg("coordinates"), py::arg("min_distance"));
277
+
278
+
279
+ m.def("max_index_by_label", &max_index_by_label<double, double>,
280
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
281
+ m.def("max_index_by_label", &max_index_by_label<double, float>,
282
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
283
+ m.def("max_index_by_label", &max_index_by_label<double, int64_t>,
284
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
285
+ m.def("max_index_by_label", &max_index_by_label<double, int32_t>,
286
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
287
+
288
+ m.def("max_index_by_label", &max_index_by_label<float, double>,
289
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
290
+ m.def("max_index_by_label", &max_index_by_label<float, float>,
291
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
292
+ m.def("max_index_by_label", &max_index_by_label<float, int64_t>,
293
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
294
+ m.def("max_index_by_label", &max_index_by_label<float, int32_t>,
295
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
296
+
297
+ m.def("max_index_by_label", &max_index_by_label<int64_t, double>,
298
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
299
+ m.def("max_index_by_label", &max_index_by_label<int64_t, float>,
300
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
301
+ m.def("max_index_by_label", &max_index_by_label<int64_t, int64_t>,
302
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
303
+ m.def("max_index_by_label", &max_index_by_label<int64_t, int32_t>,
304
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
305
+
306
+ m.def("max_index_by_label", &max_index_by_label<int32_t, double>,
307
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
308
+ m.def("max_index_by_label", &max_index_by_label<int32_t, float>,
309
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
310
+ m.def("max_index_by_label", &max_index_by_label<int32_t, int64_t>,
311
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
312
+ m.def("max_index_by_label", &max_index_by_label<int32_t, int32_t>,
313
+ "Maximum value by label", py::arg("labels"), py::arg("scores"));
314
+
315
+
316
+ m.def("online_statistics", &online_statistics<double>, py::arg("arr"),
317
+ py::arg("n") = 0, py::arg("rmean") = 0,
318
+ py::arg("ssqd") = 0, py::arg("reference") = 0,
319
+ "Compute running online statistics on a numpy array.");
320
+ m.def("online_statistics", &online_statistics<float>, py::arg("arr"),
321
+ py::arg("n") = 0, py::arg("rmean") = 0,
322
+ py::arg("ssqd") = 0, py::arg("reference") = 0,
323
+ "Compute running online statistics on a numpy array.");
324
+ m.def("online_statistics", &online_statistics<int64_t>, py::arg("arr"),
325
+ py::arg("n") = 0, py::arg("rmean") = 0,
326
+ py::arg("ssqd") = 0, py::arg("reference") = 0,
327
+ "Compute running online statistics on a numpy array.");
328
+ m.def("online_statistics", &online_statistics<int32_t>, py::arg("arr"),
329
+ py::arg("n") = 0, py::arg("rmean") = 0,
330
+ py::arg("ssqd") = 0, py::arg("reference") = 0,
331
+ "Compute running online statistics on a numpy array.");
332
+ }