multipers 2.2.3__cp310-cp310-win_amd64.whl → 2.3.0__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic. Click here for more details.

Files changed (182) hide show
  1. multipers/__init__.py +33 -31
  2. multipers/_signed_measure_meta.py +430 -430
  3. multipers/_slicer_meta.py +211 -212
  4. multipers/data/MOL2.py +458 -458
  5. multipers/data/UCR.py +18 -18
  6. multipers/data/graphs.py +466 -466
  7. multipers/data/immuno_regions.py +27 -27
  8. multipers/data/pytorch2simplextree.py +90 -90
  9. multipers/data/shape3d.py +101 -101
  10. multipers/data/synthetic.py +113 -111
  11. multipers/distances.py +198 -198
  12. multipers/filtration_conversions.pxd.tp +84 -84
  13. multipers/filtrations/__init__.py +18 -0
  14. multipers/filtrations/filtrations.py +289 -0
  15. multipers/filtrations.pxd +224 -224
  16. multipers/function_rips.cp310-win_amd64.pyd +0 -0
  17. multipers/function_rips.pyx +105 -105
  18. multipers/grids.cp310-win_amd64.pyd +0 -0
  19. multipers/grids.pyx +350 -350
  20. multipers/gudhi/Persistence_slices_interface.h +132 -132
  21. multipers/gudhi/Simplex_tree_interface.h +239 -245
  22. multipers/gudhi/Simplex_tree_multi_interface.h +516 -561
  23. multipers/gudhi/cubical_to_boundary.h +59 -59
  24. multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -450
  25. multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -1070
  26. multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -579
  27. multipers/gudhi/gudhi/Debug_utils.h +45 -45
  28. multipers/gudhi/gudhi/Fields/Multi_field.h +484 -484
  29. multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -455
  30. multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -450
  31. multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -531
  32. multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -507
  33. multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -531
  34. multipers/gudhi/gudhi/Fields/Z2_field.h +355 -355
  35. multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -376
  36. multipers/gudhi/gudhi/Fields/Zp_field.h +420 -420
  37. multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -400
  38. multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -418
  39. multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -337
  40. multipers/gudhi/gudhi/Matrix.h +2107 -2107
  41. multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -1038
  42. multipers/gudhi/gudhi/Multi_persistence/Box.h +171 -171
  43. multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -282
  44. multipers/gudhi/gudhi/Off_reader.h +173 -173
  45. multipers/gudhi/gudhi/One_critical_filtration.h +1432 -1431
  46. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -769
  47. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -686
  48. multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -842
  49. multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -1350
  50. multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -1105
  51. multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -859
  52. multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -910
  53. multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -139
  54. multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -230
  55. multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -211
  56. multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -60
  57. multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -60
  58. multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -136
  59. multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -190
  60. multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -616
  61. multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -150
  62. multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -106
  63. multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -219
  64. multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -327
  65. multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -1140
  66. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -934
  67. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -934
  68. multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -980
  69. multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -1092
  70. multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -192
  71. multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -921
  72. multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -1093
  73. multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -1012
  74. multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -1244
  75. multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -186
  76. multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -164
  77. multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -156
  78. multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -376
  79. multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -540
  80. multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -118
  81. multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -173
  82. multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -128
  83. multipers/gudhi/gudhi/Persistent_cohomology.h +745 -745
  84. multipers/gudhi/gudhi/Points_off_io.h +171 -171
  85. multipers/gudhi/gudhi/Simple_object_pool.h +69 -69
  86. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -463
  87. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -83
  88. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -106
  89. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -277
  90. multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -62
  91. multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -27
  92. multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -62
  93. multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -157
  94. multipers/gudhi/gudhi/Simplex_tree.h +2794 -2794
  95. multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -163
  96. multipers/gudhi/gudhi/distance_functions.h +62 -62
  97. multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -104
  98. multipers/gudhi/gudhi/persistence_interval.h +253 -253
  99. multipers/gudhi/gudhi/persistence_matrix_options.h +170 -170
  100. multipers/gudhi/gudhi/reader_utils.h +367 -367
  101. multipers/gudhi/mma_interface_coh.h +256 -255
  102. multipers/gudhi/mma_interface_h0.h +223 -231
  103. multipers/gudhi/mma_interface_matrix.h +284 -282
  104. multipers/gudhi/naive_merge_tree.h +536 -575
  105. multipers/gudhi/scc_io.h +310 -289
  106. multipers/gudhi/truc.h +890 -888
  107. multipers/io.cp310-win_amd64.pyd +0 -0
  108. multipers/io.pyx +711 -711
  109. multipers/ml/accuracies.py +90 -90
  110. multipers/ml/convolutions.py +520 -520
  111. multipers/ml/invariants_with_persistable.py +79 -79
  112. multipers/ml/kernels.py +176 -176
  113. multipers/ml/mma.py +713 -714
  114. multipers/ml/one.py +472 -472
  115. multipers/ml/point_clouds.py +352 -346
  116. multipers/ml/signed_measures.py +1589 -1589
  117. multipers/ml/sliced_wasserstein.py +461 -461
  118. multipers/ml/tools.py +113 -113
  119. multipers/mma_structures.cp310-win_amd64.pyd +0 -0
  120. multipers/mma_structures.pxd +127 -127
  121. multipers/mma_structures.pyx +4 -4
  122. multipers/mma_structures.pyx.tp +1085 -1085
  123. multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -93
  124. multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -97
  125. multipers/multi_parameter_rank_invariant/function_rips.h +322 -322
  126. multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -769
  127. multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -148
  128. multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -369
  129. multipers/multiparameter_edge_collapse.py +41 -41
  130. multipers/multiparameter_module_approximation/approximation.h +2296 -2295
  131. multipers/multiparameter_module_approximation/combinatory.h +129 -129
  132. multipers/multiparameter_module_approximation/debug.h +107 -107
  133. multipers/multiparameter_module_approximation/format_python-cpp.h +286 -286
  134. multipers/multiparameter_module_approximation/heap_column.h +238 -238
  135. multipers/multiparameter_module_approximation/images.h +79 -79
  136. multipers/multiparameter_module_approximation/list_column.h +174 -174
  137. multipers/multiparameter_module_approximation/list_column_2.h +232 -232
  138. multipers/multiparameter_module_approximation/ru_matrix.h +347 -347
  139. multipers/multiparameter_module_approximation/set_column.h +135 -135
  140. multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -36
  141. multipers/multiparameter_module_approximation/unordered_set_column.h +166 -166
  142. multipers/multiparameter_module_approximation/utilities.h +403 -419
  143. multipers/multiparameter_module_approximation/vector_column.h +223 -223
  144. multipers/multiparameter_module_approximation/vector_matrix.h +331 -331
  145. multipers/multiparameter_module_approximation/vineyards.h +464 -464
  146. multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -649
  147. multipers/multiparameter_module_approximation.cp310-win_amd64.pyd +0 -0
  148. multipers/multiparameter_module_approximation.pyx +216 -217
  149. multipers/pickle.py +90 -53
  150. multipers/plots.py +342 -334
  151. multipers/point_measure.cp310-win_amd64.pyd +0 -0
  152. multipers/point_measure.pyx +322 -320
  153. multipers/simplex_tree_multi.cp310-win_amd64.pyd +0 -0
  154. multipers/simplex_tree_multi.pxd +133 -133
  155. multipers/simplex_tree_multi.pyx +18 -15
  156. multipers/simplex_tree_multi.pyx.tp +1939 -1935
  157. multipers/slicer.cp310-win_amd64.pyd +0 -0
  158. multipers/slicer.pxd +81 -20
  159. multipers/slicer.pxd.tp +215 -214
  160. multipers/slicer.pyx +1091 -308
  161. multipers/slicer.pyx.tp +924 -914
  162. multipers/tensor/tensor.h +672 -672
  163. multipers/tensor.pxd +13 -13
  164. multipers/test.pyx +44 -44
  165. multipers/tests/__init__.py +57 -57
  166. multipers/torch/diff_grids.py +217 -217
  167. multipers/torch/rips_density.py +310 -304
  168. {multipers-2.2.3.dist-info → multipers-2.3.0.dist-info}/LICENSE +21 -21
  169. {multipers-2.2.3.dist-info → multipers-2.3.0.dist-info}/METADATA +21 -11
  170. multipers-2.3.0.dist-info/RECORD +182 -0
  171. multipers/tests/test_diff_helper.py +0 -73
  172. multipers/tests/test_hilbert_function.py +0 -82
  173. multipers/tests/test_mma.py +0 -83
  174. multipers/tests/test_point_clouds.py +0 -49
  175. multipers/tests/test_python-cpp_conversion.py +0 -82
  176. multipers/tests/test_signed_betti.py +0 -181
  177. multipers/tests/test_signed_measure.py +0 -89
  178. multipers/tests/test_simplextreemulti.py +0 -221
  179. multipers/tests/test_slicer.py +0 -221
  180. multipers-2.2.3.dist-info/RECORD +0 -189
  181. {multipers-2.2.3.dist-info → multipers-2.3.0.dist-info}/WHEEL +0 -0
  182. {multipers-2.2.3.dist-info → multipers-2.3.0.dist-info}/top_level.txt +0 -0
@@ -1,520 +1,520 @@
1
- from collections.abc import Callable, Iterable
2
- from typing import Any, Literal, Union
3
-
4
- import numpy as np
5
-
6
- global available_kernels
7
- available_kernels = Union[
8
- Literal[
9
- "gaussian", "exponential", "exponential_kernel", "multivariate_gaussian", "sinc"
10
- ],
11
- Callable,
12
- ]
13
-
14
-
15
- def convolution_signed_measures(
16
- iterable_of_signed_measures,
17
- filtrations,
18
- bandwidth,
19
- flatten: bool = True,
20
- n_jobs: int = 1,
21
- backend="pykeops",
22
- kernel: available_kernels = "gaussian",
23
- **kwargs,
24
- ):
25
- """
26
- Evaluates the convolution of the signed measures Iterable(pts, weights) with a gaussian measure of bandwidth bandwidth, on a grid given by the filtrations
27
-
28
- Parameters
29
- ----------
30
-
31
- - iterable_of_signed_measures : (num_signed_measure) x [ (npts) x (num_parameters), (npts)]
32
- - filtrations : (num_parameter) x (filtration values)
33
- - flatten : bool
34
- - n_jobs : int
35
-
36
- Outputs
37
- -------
38
-
39
- The concatenated images, for each signed measure (num_signed_measures) x (len(f) for f in filtration_values)
40
- """
41
- from multipers.grids import todense
42
-
43
- grid_iterator = todense(filtrations, product_order=True)
44
- match backend:
45
- case "sklearn":
46
-
47
- def convolution_signed_measures_on_grid(
48
- signed_measures: Iterable[tuple[np.ndarray, np.ndarray]],
49
- ):
50
- return np.concatenate(
51
- [
52
- _pts_convolution_sparse_old(
53
- pts=pts,
54
- pts_weights=weights,
55
- grid_iterator=grid_iterator,
56
- bandwidth=bandwidth,
57
- kernel=kernel,
58
- **kwargs,
59
- )
60
- for pts, weights in signed_measures
61
- ],
62
- axis=0,
63
- )
64
-
65
- case "pykeops":
66
-
67
- def convolution_signed_measures_on_grid(
68
- signed_measures: Iterable[tuple[np.ndarray, np.ndarray]],
69
- ) -> np.ndarray:
70
- return np.concatenate(
71
- [
72
- _pts_convolution_pykeops(
73
- pts=pts,
74
- pts_weights=weights,
75
- grid_iterator=grid_iterator,
76
- bandwidth=bandwidth,
77
- kernel=kernel,
78
- **kwargs,
79
- )
80
- for pts, weights in signed_measures
81
- ],
82
- axis=0,
83
- )
84
-
85
- # compiles first once
86
- pts, weights = iterable_of_signed_measures[0][0]
87
- small_pts, small_weights = pts[:2], weights[:2]
88
-
89
- _pts_convolution_pykeops(
90
- small_pts,
91
- small_weights,
92
- grid_iterator=grid_iterator,
93
- bandwidth=bandwidth,
94
- kernel=kernel,
95
- **kwargs,
96
- )
97
-
98
- if n_jobs > 1 or n_jobs == -1:
99
- prefer = "processes" if backend == "sklearn" else "threads"
100
- from joblib import Parallel, delayed
101
-
102
- convolutions = Parallel(n_jobs=n_jobs, prefer=prefer)(
103
- delayed(convolution_signed_measures_on_grid)(sms)
104
- for sms in iterable_of_signed_measures
105
- )
106
- else:
107
- convolutions = [
108
- convolution_signed_measures_on_grid(sms)
109
- for sms in iterable_of_signed_measures
110
- ]
111
- if not flatten:
112
- out_shape = [-1] + [len(f) for f in filtrations] # Degree
113
- convolutions = [x.reshape(out_shape) for x in convolutions]
114
- return np.asarray(convolutions)
115
-
116
-
117
- # def _test(r=1000, b=0.5, plot=True, kernel=0):
118
- # import matplotlib.pyplot as plt
119
- # pts, weigths = np.array([[1.,1.], [1.1,1.1]]), np.array([1,-1])
120
- # pt_list = np.array(list(product(*[np.linspace(0,2,r)]*2)))
121
- # img = _pts_convolution_sparse_pts(pts,weigths, pt_list,b,kernel=kernel)
122
- # if plot:
123
- # plt.imshow(img.reshape(r,-1).T, origin="lower")
124
- # plt.show()
125
-
126
-
127
- def _pts_convolution_sparse_old(
128
- pts: np.ndarray,
129
- pts_weights: np.ndarray,
130
- grid_iterator,
131
- kernel: available_kernels = "gaussian",
132
- bandwidth=0.1,
133
- **more_kde_args,
134
- ):
135
- """
136
- Old version of `convolution_signed_measures`. Scikitlearn's convolution is slower than the code above.
137
- """
138
- from sklearn.neighbors import KernelDensity
139
-
140
- if len(pts) == 0:
141
- # warn("Found a trivial signed measure !")
142
- return np.zeros(len(grid_iterator))
143
- kde = KernelDensity(
144
- kernel=kernel, bandwidth=bandwidth, rtol=1e-4, **more_kde_args
145
- ) # TODO : check rtol
146
- pos_indices = pts_weights > 0
147
- neg_indices = pts_weights < 0
148
- img_pos = (
149
- np.zeros(len(grid_iterator))
150
- if pos_indices.sum() == 0
151
- else kde.fit(
152
- pts[pos_indices], sample_weight=pts_weights[pos_indices]
153
- ).score_samples(grid_iterator)
154
- )
155
- img_neg = (
156
- np.zeros(len(grid_iterator))
157
- if neg_indices.sum() == 0
158
- else kde.fit(
159
- pts[neg_indices], sample_weight=-pts_weights[neg_indices]
160
- ).score_samples(grid_iterator)
161
- )
162
- return np.exp(img_pos) - np.exp(img_neg)
163
-
164
-
165
- def _pts_convolution_pykeops(
166
- pts: np.ndarray,
167
- pts_weights: np.ndarray,
168
- grid_iterator,
169
- kernel: available_kernels = "gaussian",
170
- bandwidth=0.1,
171
- **more_kde_args,
172
- ):
173
- """
174
- Pykeops convolution
175
- """
176
- kde = KDE(kernel=kernel, bandwidth=bandwidth, **more_kde_args)
177
- return kde.fit(
178
- pts, sample_weights=np.asarray(pts_weights, dtype=pts.dtype)
179
- ).score_samples(np.asarray(grid_iterator, dtype=pts.dtype))
180
-
181
-
182
- def gaussian_kernel(x_i, y_j, bandwidth):
183
- exponent = -(((x_i - y_j) / bandwidth) ** 2).sum(dim=-1) / 2
184
- # float is necessary for some reason (pykeops fails)
185
- kernel = (exponent).exp() / (bandwidth * float(np.sqrt(2 * np.pi)))
186
- return kernel
187
-
188
-
189
- def multivariate_gaussian_kernel(x_i, y_j, covariance_matrix_inverse):
190
- # 1 / \sqrt(2 \pi^dim * \Sigma.det()) * exp( -(x-y).T @ \Sigma ^{-1} @ (x-y))
191
- # CF https://www.kernel-operations.io/keops/_auto_examples/pytorch/plot_anisotropic_kernels.html#sphx-glr-auto-examples-pytorch-plot-anisotropic-kernels-py
192
- # and https://www.kernel-operations.io/keops/api/math-operations.html
193
- dim = x_i.shape[-1]
194
- z = x_i - y_j
195
- exponent = -(z.weightedsqnorm(covariance_matrix_inverse.flatten()) / 2)
196
- return (
197
- float((2 * np.pi) ** (-dim / 2))
198
- * (covariance_matrix_inverse.det().sqrt())
199
- * exponent.exp()
200
- )
201
-
202
-
203
- def exponential_kernel(x_i, y_j, bandwidth):
204
- # 1 / \sigma * exp( norm(x-y, dim=-1))
205
- exponent = -(((((x_i - y_j) ** 2)).sum(dim=-1) ** 1 / 2) / bandwidth)
206
- kernel = exponent.exp() / bandwidth
207
- return kernel
208
-
209
-
210
- def sinc_kernel(x_i, y_j, bandwidth):
211
- norm = ((((x_i - y_j) ** 2)).sum(dim=-1) ** 1 / 2) / bandwidth
212
- sinc = type(x_i).sinc
213
- kernel = 2 * sinc(2 * norm) - sinc(norm)
214
- return kernel
215
-
216
-
217
- def _kernel(
218
- kernel: available_kernels = "gaussian",
219
- ):
220
- match kernel:
221
- case "gaussian":
222
- return gaussian_kernel
223
- case "exponential":
224
- return exponential_kernel
225
- case "multivariate_gaussian":
226
- return multivariate_gaussian_kernel
227
- case "sinc":
228
- return sinc_kernel
229
- case _:
230
- assert callable(
231
- kernel
232
- ), f"""
233
- --------------------------
234
- Unknown kernel {kernel}.
235
- --------------------------
236
- Custom kernel has to be callable,
237
- (x:LazyTensor(n,1,D),y:LazyTensor(1,m,D),bandwidth:float) ---> kernel matrix
238
-
239
- Valid operations are given here:
240
- https://www.kernel-operations.io/keops/python/api/index.html
241
- """
242
- return kernel
243
-
244
-
245
- # TODO : multiple bandwidths at once with lazy tensors
246
- class KDE:
247
- """
248
- Fast, scikit-style, and differentiable kernel density estimation, using PyKeops.
249
- """
250
-
251
- def __init__(
252
- self,
253
- bandwidth: Any = 1,
254
- kernel: available_kernels = "gaussian",
255
- return_log: bool = False,
256
- ):
257
- """
258
- bandwidth : numeric
259
- bandwidth for Gaussian kernel
260
- """
261
- self.X = None
262
- self.bandwidth = bandwidth
263
- self.kernel: available_kernels = kernel
264
- self._kernel = None
265
- self._backend = None
266
- self._sample_weights = None
267
- self.return_log = return_log
268
-
269
- def fit(self, X, sample_weights=None, y=None):
270
- self.X = X
271
- self._sample_weights = sample_weights
272
- if isinstance(X, np.ndarray):
273
- self._backend = np
274
- else:
275
- import torch
276
-
277
- if isinstance(X, torch.Tensor):
278
- self._backend = torch
279
- else:
280
- raise Exception("Unsupported backend.")
281
- self._kernel = _kernel(self.kernel)
282
- return self
283
-
284
- @staticmethod
285
- def to_lazy(X, Y, x_weights):
286
- if isinstance(X, np.ndarray):
287
- from pykeops.numpy import LazyTensor
288
-
289
- lazy_x = LazyTensor(
290
- X.reshape((X.shape[0], 1, X.shape[1]))
291
- ) # numpts, 1, dim
292
- lazy_y = LazyTensor(
293
- Y.reshape((1, Y.shape[0], Y.shape[1]))
294
- ) # 1, numpts, dim
295
- if x_weights is not None:
296
- w = LazyTensor(x_weights[:, None], axis=0)
297
- return lazy_x, lazy_y, w
298
- return lazy_x, lazy_y, None
299
- import torch
300
-
301
- if isinstance(X, torch.Tensor):
302
- from pykeops.torch import LazyTensor
303
-
304
- lazy_x = LazyTensor(X.view(X.shape[0], 1, X.shape[1]))
305
- lazy_y = LazyTensor(Y.view(1, Y.shape[0], Y.shape[1]))
306
- if x_weights is not None:
307
- w = LazyTensor(x_weights[:, None], axis=0)
308
- return lazy_x, lazy_y, w
309
- return lazy_x, lazy_y, None
310
- raise Exception("Bad tensor type.")
311
-
312
- def score_samples(self, Y, X=None, return_kernel=False):
313
- """Returns the kernel density estimates of each point in `Y`.
314
-
315
- Parameters
316
- ----------
317
- Y : tensor (m, d)
318
- `m` points with `d` dimensions for which the probability density will
319
- be calculated
320
- X : tensor (n, d), optional
321
- `n` points with `d` dimensions to which KDE will be fit. Provided to
322
- allow batch calculations in `log_prob`. By default, `X` is None and
323
- all points used to initialize KernelDensityEstimator are included.
324
-
325
-
326
- Returns
327
- -------
328
- log_probs : tensor (m)
329
- log probability densities for each of the queried points in `Y`
330
- """
331
- assert self._backend is not None and self._kernel is not None, "Fit first."
332
- X = self.X if X is None else X
333
- if X.shape[0] == 0:
334
- return self._backend.zeros((Y.shape[0]))
335
- assert Y.shape[1] == X.shape[1] and X.ndim == Y.ndim == 2
336
- lazy_x, lazy_y, w = self.to_lazy(X, Y, x_weights=self._sample_weights)
337
- kernel = self._kernel(lazy_x, lazy_y, self.bandwidth)
338
- if w is not None:
339
- kernel *= w
340
- if return_kernel:
341
- return kernel
342
- density_estimation = kernel.sum(dim=0).ravel() / kernel.shape[0] # mean
343
- return (
344
- self._backend.log(density_estimation)
345
- if self.return_log
346
- else density_estimation
347
- )
348
-
349
-
350
- def batch_signed_measure_convolutions(
351
- signed_measures, # array of shape (num_data,num_pts,D)
352
- x, # array of shape (num_x, D) or (num_data, num_x, D)
353
- bandwidth, # either float or matrix if multivariate kernel
354
- kernel: available_kernels,
355
- ):
356
- """
357
- Input
358
- -----
359
- - signed_measures: unragged, of shape (num_data, num_pts, D+1)
360
- where last coord is weights, (0 for dummy points)
361
- - x : the points to convolve (num_x,D)
362
- - bandwidth : the bandwidths or covariance matrix inverse or ... of the kernel
363
- - kernel : "gaussian", "multivariate_gaussian", "exponential", or Callable (x_i, y_i, bandwidth)->float
364
-
365
- Output
366
- ------
367
- Array of shape (num_convolutions, (num_axis), num_data,
368
- Array of shape (num_convolutions, (num_axis), num_data, max_x_size)
369
- """
370
- if signed_measures.ndim == 2:
371
- signed_measures = signed_measures[None, :, :]
372
- sms = signed_measures[..., :-1]
373
- weights = signed_measures[..., -1]
374
- if isinstance(signed_measures, np.ndarray):
375
- from pykeops.numpy import LazyTensor
376
- else:
377
- import torch
378
-
379
- assert isinstance(signed_measures, torch.Tensor)
380
- from pykeops.torch import LazyTensor
381
-
382
- _sms = LazyTensor(sms[..., None, :].contiguous())
383
- _x = x[..., None, :, :].contiguous()
384
-
385
- sms_kernel = _kernel(kernel)(_sms, _x, bandwidth)
386
- out = (sms_kernel * weights[..., None, None].contiguous()).sum(
387
- signed_measures.ndim - 2
388
- )
389
- assert out.shape[-1] == 1, "Pykeops bug fixed, TODO : refix this "
390
- out = out[..., 0] ## pykeops bug + ensures its a tensor
391
- # assert out.shape == (x.shape[0], x.shape[1]), f"{x.shape=}, {out.shape=}"
392
- return out
393
-
394
-
395
- class DTM:
396
- """
397
- Distance To Measure
398
- """
399
-
400
- def __init__(self, masses, metric: str = "euclidean", **_kdtree_kwargs):
401
- """
402
- mass : float in [0,1]
403
- The mass threshold
404
- metric :
405
- The distance between points to consider
406
- """
407
- self.masses = masses
408
- self.metric = metric
409
- self._kdtree_kwargs = _kdtree_kwargs
410
- self._ks = None
411
- self._kdtree = None
412
- self._X = None
413
- self._backend = None
414
-
415
- def fit(self, X, sample_weights=None, y=None):
416
- if len(self.masses) == 0:
417
- return self
418
- assert np.max(self.masses) <= 1, "All masses should be in (0,1]."
419
- from sklearn.neighbors import KDTree
420
-
421
- if not isinstance(X, np.ndarray):
422
- import torch
423
-
424
- assert isinstance(X, torch.Tensor), "Backend has to be numpy of torch"
425
- _X = X.detach()
426
- self._backend = "torch"
427
- else:
428
- _X = X
429
- self._backend = "numpy"
430
- self._ks = np.array([int(mass * X.shape[0]) + 1 for mass in self.masses])
431
- self._kdtree = KDTree(_X, metric=self.metric, **self._kdtree_kwargs)
432
- self._X = X
433
- return self
434
-
435
- def score_samples(self, Y, X=None):
436
- """Returns the kernel density estimates of each point in `Y`.
437
-
438
- Parameters
439
- ----------
440
- Y : tensor (m, d)
441
- `m` points with `d` dimensions for which the probability density will
442
- be calculated
443
-
444
-
445
- Returns
446
- -------
447
- the DTMs of Y, for each mass in masses.
448
- """
449
- if len(self.masses) == 0:
450
- return np.empty((0, len(Y)))
451
- assert (
452
- self._ks is not None and self._kdtree is not None and self._X is not None
453
- ), f"Fit first. Got {self._ks=}, {self._kdtree=}, {self._X=}."
454
- assert Y.ndim == 2
455
- if self._backend == "torch":
456
- _Y = Y.detach().numpy()
457
- else:
458
- _Y = Y
459
- NN_Dist, NN = self._kdtree.query(_Y, self._ks.max(), return_distance=True)
460
- DTMs = np.array([((NN_Dist**2)[:, :k].mean(1)) ** 0.5 for k in self._ks])
461
- return DTMs
462
-
463
- def score_samples_diff(self, Y):
464
- """Returns the kernel density estimates of each point in `Y`.
465
-
466
- Parameters
467
- ----------
468
- Y : tensor (m, d)
469
- `m` points with `d` dimensions for which the probability density will
470
- be calculated
471
- X : tensor (n, d), optional
472
- `n` points with `d` dimensions to which KDE will be fit. Provided to
473
- allow batch calculations in `log_prob`. By default, `X` is None and
474
- all points used to initialize KernelDensityEstimator are included.
475
-
476
-
477
- Returns
478
- -------
479
- log_probs : tensor (m)
480
- log probability densities for each of the queried points in `Y`
481
- """
482
- import torch
483
-
484
- if len(self.masses) == 0:
485
- return torch.empty(0, len(Y))
486
-
487
- assert Y.ndim == 2
488
- assert self._backend == "torch", "Use the non-diff version with numpy."
489
- assert (
490
- self._ks is not None and self._kdtree is not None and self._X is not None
491
- ), f"Fit first. Got {self._ks=}, {self._kdtree=}, {self._X=}."
492
- NN = self._kdtree.query(Y.detach(), self._ks.max(), return_distance=False)
493
- DTMs = tuple(
494
- (((self._X[NN] - Y[:, None, :]) ** 2)[:, :k].sum(dim=(1, 2)) / k) ** 0.5
495
- for k in self._ks
496
- ) # TODO : kdtree already computes distance, find implementation of kdtree that is pytorch differentiable
497
- return DTMs
498
-
499
-
500
- # def _pts_convolution_sparse(pts:np.ndarray, pts_weights:np.ndarray, filtration_grid:Iterable[np.ndarray], kernel="gaussian", bandwidth=0.1, **more_kde_args):
501
- # """
502
- # Old version of `convolution_signed_measures`. Scikitlearn's convolution is slower than the code above.
503
- # """
504
- # from sklearn.neighbors import KernelDensity
505
- # grid_iterator = np.asarray(list(product(*filtration_grid)))
506
- # grid_shape = [len(f) for f in filtration_grid]
507
- # if len(pts) == 0:
508
- # # warn("Found a trivial signed measure !")
509
- # return np.zeros(shape=grid_shape)
510
- # kde = KernelDensity(kernel=kernel, bandwidth=bandwidth, rtol = 1e-4, **more_kde_args) # TODO : check rtol
511
-
512
- # pos_indices = pts_weights>0
513
- # neg_indices = pts_weights<0
514
- # img_pos = kde.fit(pts[pos_indices], sample_weight=pts_weights[pos_indices]).score_samples(grid_iterator).reshape(grid_shape)
515
- # img_neg = kde.fit(pts[neg_indices], sample_weight=-pts_weights[neg_indices]).score_samples(grid_iterator).reshape(grid_shape)
516
- # return np.exp(img_pos) - np.exp(img_neg)
517
-
518
-
519
- # Precompiles the convolution
520
- # _test(r=2,b=.5, plot=False)
1
+ from collections.abc import Callable, Iterable
2
+ from typing import Any, Literal, Union
3
+
4
+ import numpy as np
5
+
6
+ global available_kernels
7
+ available_kernels = Union[
8
+ Literal[
9
+ "gaussian", "exponential", "exponential_kernel", "multivariate_gaussian", "sinc"
10
+ ],
11
+ Callable,
12
+ ]
13
+
14
+
15
+ def convolution_signed_measures(
16
+ iterable_of_signed_measures,
17
+ filtrations,
18
+ bandwidth,
19
+ flatten: bool = True,
20
+ n_jobs: int = 1,
21
+ backend="pykeops",
22
+ kernel: available_kernels = "gaussian",
23
+ **kwargs,
24
+ ):
25
+ """
26
+ Evaluates the convolution of the signed measures Iterable(pts, weights) with a gaussian measure of bandwidth bandwidth, on a grid given by the filtrations
27
+
28
+ Parameters
29
+ ----------
30
+
31
+ - iterable_of_signed_measures : (num_signed_measure) x [ (npts) x (num_parameters), (npts)]
32
+ - filtrations : (num_parameter) x (filtration values)
33
+ - flatten : bool
34
+ - n_jobs : int
35
+
36
+ Outputs
37
+ -------
38
+
39
+ The concatenated images, for each signed measure (num_signed_measures) x (len(f) for f in filtration_values)
40
+ """
41
+ from multipers.grids import todense
42
+
43
+ grid_iterator = todense(filtrations, product_order=True)
44
+ match backend:
45
+ case "sklearn":
46
+
47
+ def convolution_signed_measures_on_grid(
48
+ signed_measures: Iterable[tuple[np.ndarray, np.ndarray]],
49
+ ):
50
+ return np.concatenate(
51
+ [
52
+ _pts_convolution_sparse_old(
53
+ pts=pts,
54
+ pts_weights=weights,
55
+ grid_iterator=grid_iterator,
56
+ bandwidth=bandwidth,
57
+ kernel=kernel,
58
+ **kwargs,
59
+ )
60
+ for pts, weights in signed_measures
61
+ ],
62
+ axis=0,
63
+ )
64
+
65
+ case "pykeops":
66
+
67
+ def convolution_signed_measures_on_grid(
68
+ signed_measures: Iterable[tuple[np.ndarray, np.ndarray]],
69
+ ) -> np.ndarray:
70
+ return np.concatenate(
71
+ [
72
+ _pts_convolution_pykeops(
73
+ pts=pts,
74
+ pts_weights=weights,
75
+ grid_iterator=grid_iterator,
76
+ bandwidth=bandwidth,
77
+ kernel=kernel,
78
+ **kwargs,
79
+ )
80
+ for pts, weights in signed_measures
81
+ ],
82
+ axis=0,
83
+ )
84
+
85
+ # compiles first once
86
+ pts, weights = iterable_of_signed_measures[0][0]
87
+ small_pts, small_weights = pts[:2], weights[:2]
88
+
89
+ _pts_convolution_pykeops(
90
+ small_pts,
91
+ small_weights,
92
+ grid_iterator=grid_iterator,
93
+ bandwidth=bandwidth,
94
+ kernel=kernel,
95
+ **kwargs,
96
+ )
97
+
98
+ if n_jobs > 1 or n_jobs == -1:
99
+ prefer = "processes" if backend == "sklearn" else "threads"
100
+ from joblib import Parallel, delayed
101
+
102
+ convolutions = Parallel(n_jobs=n_jobs, prefer=prefer)(
103
+ delayed(convolution_signed_measures_on_grid)(sms)
104
+ for sms in iterable_of_signed_measures
105
+ )
106
+ else:
107
+ convolutions = [
108
+ convolution_signed_measures_on_grid(sms)
109
+ for sms in iterable_of_signed_measures
110
+ ]
111
+ if not flatten:
112
+ out_shape = [-1] + [len(f) for f in filtrations] # Degree
113
+ convolutions = [x.reshape(out_shape) for x in convolutions]
114
+ return np.asarray(convolutions)
115
+
116
+
117
+ # def _test(r=1000, b=0.5, plot=True, kernel=0):
118
+ # import matplotlib.pyplot as plt
119
+ # pts, weigths = np.array([[1.,1.], [1.1,1.1]]), np.array([1,-1])
120
+ # pt_list = np.array(list(product(*[np.linspace(0,2,r)]*2)))
121
+ # img = _pts_convolution_sparse_pts(pts,weigths, pt_list,b,kernel=kernel)
122
+ # if plot:
123
+ # plt.imshow(img.reshape(r,-1).T, origin="lower")
124
+ # plt.show()
125
+
126
+
127
def _pts_convolution_sparse_old(
    pts: np.ndarray,
    pts_weights: np.ndarray,
    grid_iterator,
    kernel: available_kernels = "gaussian",
    bandwidth=0.1,
    **more_kde_args,
):
    """
    Old version of `convolution_signed_measures`, based on scikit-learn's
    KernelDensity. Scikit-learn's convolution is slower than the PyKeops code.

    Parameters
    ----------
    pts : (num_pts, D) array
        Support points of the signed measure.
    pts_weights : (num_pts,) array
        Signed weights attached to `pts`.
    grid_iterator : (num_grid_pts, D) array-like
        Points at which the convolution is evaluated.
    kernel :
        Kernel name, forwarded to `sklearn.neighbors.KernelDensity`.
    bandwidth : float
        Kernel bandwidth.

    Returns
    -------
    (num_grid_pts,) array: convolution of the signed measure on the grid.
    """
    from sklearn.neighbors import KernelDensity

    if len(pts) == 0:
        # Trivial signed measure: the convolution is identically zero.
        return np.zeros(len(grid_iterator))
    kde = KernelDensity(
        kernel=kernel, bandwidth=bandwidth, rtol=1e-4, **more_kde_args
    )  # TODO : check rtol

    def _half_density(indices: np.ndarray, sign: int) -> np.ndarray:
        # Density of one (positive or negative) part of the measure.
        # BUGFIX: an empty part must contribute a zero *density*. The previous
        # code returned a zero *log*-density placeholder and exponentiated it,
        # which produced a spurious constant 1 over the whole grid.
        if indices.sum() == 0:
            return np.zeros(len(grid_iterator))
        log_density = kde.fit(
            pts[indices], sample_weight=sign * pts_weights[indices]
        ).score_samples(grid_iterator)
        return np.exp(log_density)

    img_pos = _half_density(pts_weights > 0, 1)
    img_neg = _half_density(pts_weights < 0, -1)
    return img_pos - img_neg
163
+
164
+
165
def _pts_convolution_pykeops(
    pts: np.ndarray,
    pts_weights: np.ndarray,
    grid_iterator,
    kernel: available_kernels = "gaussian",
    bandwidth=0.1,
    **more_kde_args,
):
    """
    Convolve one signed measure (`pts`, `pts_weights`) with `kernel` on the
    points of `grid_iterator`, using the PyKeops-backed `KDE` estimator.
    """
    dtype = pts.dtype
    estimator = KDE(kernel=kernel, bandwidth=bandwidth, **more_kde_args)
    # Weights and grid are cast to the dtype of the points, as required by pykeops.
    signed_weights = np.asarray(pts_weights, dtype=dtype)
    estimator.fit(pts, sample_weights=signed_weights)
    return estimator.score_samples(np.asarray(grid_iterator, dtype=dtype))
180
+
181
+
182
def gaussian_kernel(x_i, y_j, bandwidth):
    """One-dimensional-bandwidth Gaussian kernel on the (lazy) pair (x_i, y_j)."""
    scaled_sq_dist = (((x_i - y_j) / bandwidth) ** 2).sum(dim=-1)
    # float is necessary for some reason (pykeops fails)
    normalization = bandwidth * float(np.sqrt(2 * np.pi))
    return (-scaled_sq_dist / 2).exp() / normalization
187
+
188
+
189
def multivariate_gaussian_kernel(x_i, y_j, covariance_matrix_inverse):
    """
    Multivariate Gaussian kernel:
    (2 pi)^{-dim/2} * det(Sigma^{-1})^{1/2} * exp(-(x-y)^T Sigma^{-1} (x-y) / 2).

    CF https://www.kernel-operations.io/keops/_auto_examples/pytorch/plot_anisotropic_kernels.html#sphx-glr-auto-examples-pytorch-plot-anisotropic-kernels-py
    and https://www.kernel-operations.io/keops/api/math-operations.html
    """
    dim = x_i.shape[-1]
    diff = x_i - y_j
    # Quadratic form (x-y)^T Sigma^{-1} (x-y), via pykeops' weighted squared norm.
    quad_form = diff.weightedsqnorm(covariance_matrix_inverse.flatten())
    constant = float((2 * np.pi) ** (-dim / 2))
    return constant * covariance_matrix_inverse.det().sqrt() * (-(quad_form / 2)).exp()
201
+
202
+
203
def exponential_kernel(x_i, y_j, bandwidth):
    """
    Exponential (Laplacian) kernel: exp(-||x - y|| / bandwidth) / bandwidth.

    BUGFIX: the distance was previously written `(...) ** 1 / 2`, which Python
    parses as `((...) ** 1) / 2` — half the *squared* distance instead of the
    Euclidean norm. It now uses `** 0.5` as the comment intended.
    """
    distance = (((x_i - y_j) ** 2).sum(dim=-1)) ** 0.5
    return (-(distance / bandwidth)).exp() / bandwidth
208
+
209
+
210
def sinc_kernel(x_i, y_j, bandwidth):
    """
    Sinc kernel: 2*sinc(2*||x-y||/bandwidth) - sinc(||x-y||/bandwidth).

    BUGFIX: the norm was previously written `(...) ** 1 / 2`, which Python
    parses as `((...) ** 1) / 2` — half the *squared* distance instead of the
    Euclidean norm. It now uses `** 0.5`.
    """
    norm = ((((x_i - y_j) ** 2).sum(dim=-1)) ** 0.5) / bandwidth
    # Resolve sinc on the input's own type (works for torch tensors and
    # pykeops LazyTensors alike).
    sinc = type(x_i).sinc
    return 2 * sinc(2 * norm) - sinc(norm)
215
+
216
+
217
+ def _kernel(
218
+ kernel: available_kernels = "gaussian",
219
+ ):
220
+ match kernel:
221
+ case "gaussian":
222
+ return gaussian_kernel
223
+ case "exponential":
224
+ return exponential_kernel
225
+ case "multivariate_gaussian":
226
+ return multivariate_gaussian_kernel
227
+ case "sinc":
228
+ return sinc_kernel
229
+ case _:
230
+ assert callable(
231
+ kernel
232
+ ), f"""
233
+ --------------------------
234
+ Unknown kernel {kernel}.
235
+ --------------------------
236
+ Custom kernel has to be callable,
237
+ (x:LazyTensor(n,1,D),y:LazyTensor(1,m,D),bandwidth:float) ---> kernel matrix
238
+
239
+ Valid operations are given here:
240
+ https://www.kernel-operations.io/keops/python/api/index.html
241
+ """
242
+ return kernel
243
+
244
+
245
+ # TODO : multiple bandwidths at once with lazy tensors
246
class KDE:
    """
    Fast, scikit-style, and differentiable kernel density estimation, using PyKeops.

    Works with both numpy arrays and torch tensors; the backend is detected at
    `fit` time and reused in `score_samples`.
    """

    def __init__(
        self,
        bandwidth: Any = 1,
        kernel: available_kernels = "gaussian",
        return_log: bool = False,
    ):
        """
        bandwidth : numeric
            bandwidth for Gaussian kernel
        kernel :
            kernel name or custom callable, resolved by `_kernel` at fit time
        return_log : bool
            if True, `score_samples` returns log-densities
        """
        # Fitted sample points (set by `fit`).
        self.X = None
        self.bandwidth = bandwidth
        self.kernel: available_kernels = kernel
        # Resolved kernel callable (set by `fit`).
        self._kernel = None
        # numpy module or torch module, depending on the fitted data.
        self._backend = None
        self._sample_weights = None
        self.return_log = return_log

    def fit(self, X, sample_weights=None, y=None):
        """Store the sample points/weights and detect the array backend."""
        self.X = X
        self._sample_weights = sample_weights
        if isinstance(X, np.ndarray):
            self._backend = np
        else:
            import torch

            if isinstance(X, torch.Tensor):
                self._backend = torch
            else:
                raise Exception("Unsupported backend.")
        self._kernel = _kernel(self.kernel)
        return self

    @staticmethod
    def to_lazy(X, Y, x_weights):
        """Wrap X, Y (and optional per-X weights) into pykeops LazyTensors.

        Returns a triple (lazy_x, lazy_y, lazy_weights_or_None) with shapes
        (n, 1, d), (1, m, d) and (n, 1) respectively.
        """
        if isinstance(X, np.ndarray):
            from pykeops.numpy import LazyTensor

            lazy_x = LazyTensor(
                X.reshape((X.shape[0], 1, X.shape[1]))
            )  # numpts, 1, dim
            lazy_y = LazyTensor(
                Y.reshape((1, Y.shape[0], Y.shape[1]))
            )  # 1, numpts, dim
            if x_weights is not None:
                w = LazyTensor(x_weights[:, None], axis=0)
                return lazy_x, lazy_y, w
            return lazy_x, lazy_y, None
        import torch

        if isinstance(X, torch.Tensor):
            from pykeops.torch import LazyTensor

            lazy_x = LazyTensor(X.view(X.shape[0], 1, X.shape[1]))
            lazy_y = LazyTensor(Y.view(1, Y.shape[0], Y.shape[1]))
            if x_weights is not None:
                w = LazyTensor(x_weights[:, None], axis=0)
                return lazy_x, lazy_y, w
            return lazy_x, lazy_y, None
        raise Exception("Bad tensor type.")

    def score_samples(self, Y, X=None, return_kernel=False):
        """Returns the kernel density estimates of each point in `Y`.

        Parameters
        ----------
        Y : tensor (m, d)
            `m` points with `d` dimensions for which the probability density will
            be calculated
        X : tensor (n, d), optional
            `n` points with `d` dimensions to which KDE will be fit. Provided to
            allow batch calculations in `log_prob`. By default, `X` is None and
            all points used to initialize KernelDensityEstimator are included.
        return_kernel : bool
            if True, returns the raw (weighted) kernel LazyTensor instead of
            the reduced density.

        Returns
        -------
        log_probs : tensor (m)
            log probability densities for each of the queried points in `Y`
            (plain densities unless `return_log` was set at construction)
        """
        assert self._backend is not None and self._kernel is not None, "Fit first."
        X = self.X if X is None else X
        if X.shape[0] == 0:
            # No fitted points: the estimated density is identically zero.
            return self._backend.zeros((Y.shape[0]))
        assert Y.shape[1] == X.shape[1] and X.ndim == Y.ndim == 2
        lazy_x, lazy_y, w = self.to_lazy(X, Y, x_weights=self._sample_weights)
        kernel = self._kernel(lazy_x, lazy_y, self.bandwidth)
        if w is not None:
            kernel *= w
        if return_kernel:
            return kernel
        density_estimation = kernel.sum(dim=0).ravel() / kernel.shape[0]  # mean
        return (
            self._backend.log(density_estimation)
            if self.return_log
            else density_estimation
        )
348
+
349
+
350
def batch_signed_measure_convolutions(
    signed_measures,  # array of shape (num_data,num_pts,D)
    x,  # array of shape (num_x, D) or (num_data, num_x, D)
    bandwidth,  # either float or matrix if multivariate kernel
    kernel: available_kernels,
):
    """
    Batched convolution of dense signed measures against evaluation points,
    using pykeops LazyTensors.

    Input
    -----
    - signed_measures: unragged, of shape (num_data, num_pts, D+1)
      where last coord is weights, (0 for dummy points)
    - x : the points to convolve (num_x,D)
    - bandwidth : the bandwidths or covariance matrix inverse or ... of the kernel
    - kernel : "gaussian", "multivariate_gaussian", "exponential", or Callable (x_i, y_i, bandwidth)->float

    Output
    ------
    Array of shape (num_convolutions, (num_axis), num_data, max_x_size)
    """
    # Promote a single measure (num_pts, D+1) to a batch of size 1.
    if signed_measures.ndim == 2:
        signed_measures = signed_measures[None, :, :]
    # Split point coordinates from their signed weights (last column).
    sms = signed_measures[..., :-1]
    weights = signed_measures[..., -1]
    # Pick the LazyTensor flavour matching the input backend.
    # NOTE(review): `.contiguous()` below exists on torch tensors only — the
    # numpy branch looks untested; confirm against callers.
    if isinstance(signed_measures, np.ndarray):
        from pykeops.numpy import LazyTensor
    else:
        import torch

        assert isinstance(signed_measures, torch.Tensor)
        from pykeops.torch import LazyTensor

    _sms = LazyTensor(sms[..., None, :].contiguous())
    _x = x[..., None, :, :].contiguous()

    sms_kernel = _kernel(kernel)(_sms, _x, bandwidth)
    # Weighted sum over the measure points; dummy points carry weight 0 and
    # therefore do not contribute.
    out = (sms_kernel * weights[..., None, None].contiguous()).sum(
        signed_measures.ndim - 2
    )
    assert out.shape[-1] == 1, "Pykeops bug fixed, TODO : refix this "
    out = out[..., 0]  ## pykeops bug + ensures its a tensor
    # assert out.shape == (x.shape[0], x.shape[1]), f"{x.shape=}, {out.shape=}"
    return out
393
+
394
+
395
class DTM:
    """
    Distance To Measure

    For each mass threshold in `masses`, estimates the distance-to-measure of
    query points as the root-mean-square distance to their k nearest fitted
    points, with k derived from the mass.
    """

    def __init__(self, masses, metric: str = "euclidean", **_kdtree_kwargs):
        """
        masses : iterable of floats in (0,1]
            The mass thresholds, one DTM is computed per threshold
        metric :
            The distance between points to consider
        **_kdtree_kwargs :
            forwarded to `sklearn.neighbors.KDTree`
        """
        self.masses = masses
        self.metric = metric
        self._kdtree_kwargs = _kdtree_kwargs
        # Numbers of nearest neighbors, one per mass (set by `fit`).
        self._ks = None
        self._kdtree = None
        # Fitted points, kept differentiable for `score_samples_diff`.
        self._X = None
        # "numpy" or "torch", detected at `fit` time.
        self._backend = None

    def fit(self, X, sample_weights=None, y=None):
        """Build the KDTree on X and precompute the neighbor counts."""
        if len(self.masses) == 0:
            return self
        assert np.max(self.masses) <= 1, "All masses should be in (0,1]."
        from sklearn.neighbors import KDTree

        if not isinstance(X, np.ndarray):
            import torch

            assert isinstance(X, torch.Tensor), "Backend has to be numpy of torch"
            # Detached copy for the (non-differentiable) KDTree build.
            _X = X.detach()
            self._backend = "torch"
        else:
            _X = X
            self._backend = "numpy"
        # k = floor(mass * n) + 1 neighbors per mass threshold.
        self._ks = np.array([int(mass * X.shape[0]) + 1 for mass in self.masses])
        self._kdtree = KDTree(_X, metric=self.metric, **self._kdtree_kwargs)
        self._X = X
        return self

    def score_samples(self, Y, X=None):
        """Returns the distances-to-measure of each point in `Y`.

        Parameters
        ----------
        Y : tensor (m, d)
            `m` points with `d` dimensions for which the distance-to-measure
            will be calculated

        Returns
        -------
        the DTMs of Y, for each mass in masses.
        """
        if len(self.masses) == 0:
            return np.empty((0, len(Y)))
        assert (
            self._ks is not None and self._kdtree is not None and self._X is not None
        ), f"Fit first. Got {self._ks=}, {self._kdtree=}, {self._X=}."
        assert Y.ndim == 2
        if self._backend == "torch":
            _Y = Y.detach().numpy()
        else:
            _Y = Y
        # Query once with the largest k; each DTM then uses its own prefix.
        NN_Dist, NN = self._kdtree.query(_Y, self._ks.max(), return_distance=True)
        # Root of the mean squared distance to the k nearest neighbors.
        DTMs = np.array([((NN_Dist**2)[:, :k].mean(1)) ** 0.5 for k in self._ks])
        return DTMs

    def score_samples_diff(self, Y):
        """Returns the distances-to-measure of each point in `Y`, differentiably.

        Torch-only variant: neighbor indices come from the (non-differentiable)
        KDTree, but the distances are recomputed with torch so that gradients
        flow through `Y` and the fitted points.

        Parameters
        ----------
        Y : tensor (m, d)
            `m` points with `d` dimensions for which the distance-to-measure
            will be calculated

        Returns
        -------
        tuple of tensors (m,), one per mass in masses.
        """
        import torch

        if len(self.masses) == 0:
            return torch.empty(0, len(Y))

        assert Y.ndim == 2
        assert self._backend == "torch", "Use the non-diff version with numpy."
        assert (
            self._ks is not None and self._kdtree is not None and self._X is not None
        ), f"Fit first. Got {self._ks=}, {self._kdtree=}, {self._X=}."
        NN = self._kdtree.query(Y.detach(), self._ks.max(), return_distance=False)
        DTMs = tuple(
            (((self._X[NN] - Y[:, None, :]) ** 2)[:, :k].sum(dim=(1, 2)) / k) ** 0.5
            for k in self._ks
        )  # TODO : kdtree already computes distance, find implementation of kdtree that is pytorch differentiable
        return DTMs
498
+
499
+
500
+ # def _pts_convolution_sparse(pts:np.ndarray, pts_weights:np.ndarray, filtration_grid:Iterable[np.ndarray], kernel="gaussian", bandwidth=0.1, **more_kde_args):
501
+ # """
502
+ # Old version of `convolution_signed_measures`. Scikitlearn's convolution is slower than the code above.
503
+ # """
504
+ # from sklearn.neighbors import KernelDensity
505
+ # grid_iterator = np.asarray(list(product(*filtration_grid)))
506
+ # grid_shape = [len(f) for f in filtration_grid]
507
+ # if len(pts) == 0:
508
+ # # warn("Found a trivial signed measure !")
509
+ # return np.zeros(shape=grid_shape)
510
+ # kde = KernelDensity(kernel=kernel, bandwidth=bandwidth, rtol = 1e-4, **more_kde_args) # TODO : check rtol
511
+
512
+ # pos_indices = pts_weights>0
513
+ # neg_indices = pts_weights<0
514
+ # img_pos = kde.fit(pts[pos_indices], sample_weight=pts_weights[pos_indices]).score_samples(grid_iterator).reshape(grid_shape)
515
+ # img_neg = kde.fit(pts[neg_indices], sample_weight=-pts_weights[neg_indices]).score_samples(grid_iterator).reshape(grid_shape)
516
+ # return np.exp(img_pos) - np.exp(img_neg)
517
+
518
+
519
+ # Precompiles the convolution
520
+ # _test(r=2,b=.5, plot=False)