multipers-2.3.3-cp313-cp313-manylinux_2_39_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of multipers has been flagged as a potentially problematic release.

Files changed (182)
  1. multipers/__init__.py +33 -0
  2. multipers/_signed_measure_meta.py +450 -0
  3. multipers/_slicer_meta.py +211 -0
  4. multipers/array_api/__init__.py +62 -0
  5. multipers/array_api/numpy.py +104 -0
  6. multipers/array_api/torch.py +117 -0
  7. multipers/data/MOL2.py +458 -0
  8. multipers/data/UCR.py +18 -0
  9. multipers/data/__init__.py +1 -0
  10. multipers/data/graphs.py +466 -0
  11. multipers/data/immuno_regions.py +27 -0
  12. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  13. multipers/data/pytorch2simplextree.py +91 -0
  14. multipers/data/shape3d.py +101 -0
  15. multipers/data/synthetic.py +113 -0
  16. multipers/distances.py +202 -0
  17. multipers/filtration_conversions.pxd +229 -0
  18. multipers/filtration_conversions.pxd.tp +84 -0
  19. multipers/filtrations/__init__.py +18 -0
  20. multipers/filtrations/density.py +533 -0
  21. multipers/filtrations/filtrations.py +361 -0
  22. multipers/filtrations.pxd +224 -0
  23. multipers/function_rips.cpython-313-x86_64-linux-gnu.so +0 -0
  24. multipers/function_rips.pyx +105 -0
  25. multipers/grids.cpython-313-x86_64-linux-gnu.so +0 -0
  26. multipers/grids.pyx +481 -0
  27. multipers/gudhi/Persistence_slices_interface.h +132 -0
  28. multipers/gudhi/Simplex_tree_interface.h +239 -0
  29. multipers/gudhi/Simplex_tree_multi_interface.h +551 -0
  30. multipers/gudhi/cubical_to_boundary.h +59 -0
  31. multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -0
  32. multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -0
  33. multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -0
  34. multipers/gudhi/gudhi/Debug_utils.h +45 -0
  35. multipers/gudhi/gudhi/Fields/Multi_field.h +484 -0
  36. multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -0
  37. multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -0
  38. multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -0
  39. multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -0
  40. multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -0
  41. multipers/gudhi/gudhi/Fields/Z2_field.h +355 -0
  42. multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -0
  43. multipers/gudhi/gudhi/Fields/Zp_field.h +420 -0
  44. multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -0
  45. multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -0
  46. multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -0
  47. multipers/gudhi/gudhi/Matrix.h +2107 -0
  48. multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -0
  49. multipers/gudhi/gudhi/Multi_persistence/Box.h +174 -0
  50. multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -0
  51. multipers/gudhi/gudhi/Off_reader.h +173 -0
  52. multipers/gudhi/gudhi/One_critical_filtration.h +1441 -0
  53. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -0
  54. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -0
  55. multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -0
  56. multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -0
  57. multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -0
  58. multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -0
  59. multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -0
  60. multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -0
  61. multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -0
  62. multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -0
  63. multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -0
  64. multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -0
  65. multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -0
  66. multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -0
  67. multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -0
  68. multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -0
  69. multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -0
  70. multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -0
  71. multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -0
  72. multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -0
  73. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -0
  74. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -0
  75. multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -0
  76. multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -0
  77. multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -0
  78. multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -0
  79. multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -0
  80. multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -0
  81. multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -0
  82. multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -0
  83. multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -0
  84. multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -0
  85. multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -0
  86. multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -0
  87. multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -0
  88. multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -0
  89. multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -0
  90. multipers/gudhi/gudhi/Persistent_cohomology.h +745 -0
  91. multipers/gudhi/gudhi/Points_off_io.h +171 -0
  92. multipers/gudhi/gudhi/Simple_object_pool.h +69 -0
  93. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -0
  94. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -0
  95. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -0
  96. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -0
  97. multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -0
  98. multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -0
  99. multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -0
  100. multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -0
  101. multipers/gudhi/gudhi/Simplex_tree.h +2794 -0
  102. multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -0
  103. multipers/gudhi/gudhi/distance_functions.h +62 -0
  104. multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -0
  105. multipers/gudhi/gudhi/persistence_interval.h +253 -0
  106. multipers/gudhi/gudhi/persistence_matrix_options.h +170 -0
  107. multipers/gudhi/gudhi/reader_utils.h +367 -0
  108. multipers/gudhi/mma_interface_coh.h +256 -0
  109. multipers/gudhi/mma_interface_h0.h +223 -0
  110. multipers/gudhi/mma_interface_matrix.h +293 -0
  111. multipers/gudhi/naive_merge_tree.h +536 -0
  112. multipers/gudhi/scc_io.h +310 -0
  113. multipers/gudhi/truc.h +1403 -0
  114. multipers/io.cpython-313-x86_64-linux-gnu.so +0 -0
  115. multipers/io.pyx +644 -0
  116. multipers/ml/__init__.py +0 -0
  117. multipers/ml/accuracies.py +90 -0
  118. multipers/ml/invariants_with_persistable.py +79 -0
  119. multipers/ml/kernels.py +176 -0
  120. multipers/ml/mma.py +713 -0
  121. multipers/ml/one.py +472 -0
  122. multipers/ml/point_clouds.py +352 -0
  123. multipers/ml/signed_measures.py +1667 -0
  124. multipers/ml/sliced_wasserstein.py +461 -0
  125. multipers/ml/tools.py +113 -0
  126. multipers/mma_structures.cpython-313-x86_64-linux-gnu.so +0 -0
  127. multipers/mma_structures.pxd +128 -0
  128. multipers/mma_structures.pyx +2786 -0
  129. multipers/mma_structures.pyx.tp +1094 -0
  130. multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -0
  131. multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -0
  132. multipers/multi_parameter_rank_invariant/function_rips.h +322 -0
  133. multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -0
  134. multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -0
  135. multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -0
  136. multipers/multiparameter_edge_collapse.py +41 -0
  137. multipers/multiparameter_module_approximation/approximation.h +2330 -0
  138. multipers/multiparameter_module_approximation/combinatory.h +129 -0
  139. multipers/multiparameter_module_approximation/debug.h +107 -0
  140. multipers/multiparameter_module_approximation/euler_curves.h +0 -0
  141. multipers/multiparameter_module_approximation/format_python-cpp.h +286 -0
  142. multipers/multiparameter_module_approximation/heap_column.h +238 -0
  143. multipers/multiparameter_module_approximation/images.h +79 -0
  144. multipers/multiparameter_module_approximation/list_column.h +174 -0
  145. multipers/multiparameter_module_approximation/list_column_2.h +232 -0
  146. multipers/multiparameter_module_approximation/ru_matrix.h +347 -0
  147. multipers/multiparameter_module_approximation/set_column.h +135 -0
  148. multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -0
  149. multipers/multiparameter_module_approximation/unordered_set_column.h +166 -0
  150. multipers/multiparameter_module_approximation/utilities.h +403 -0
  151. multipers/multiparameter_module_approximation/vector_column.h +223 -0
  152. multipers/multiparameter_module_approximation/vector_matrix.h +331 -0
  153. multipers/multiparameter_module_approximation/vineyards.h +464 -0
  154. multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -0
  155. multipers/multiparameter_module_approximation.cpython-313-x86_64-linux-gnu.so +0 -0
  156. multipers/multiparameter_module_approximation.pyx +235 -0
  157. multipers/pickle.py +90 -0
  158. multipers/plots.py +470 -0
  159. multipers/point_measure.cpython-313-x86_64-linux-gnu.so +0 -0
  160. multipers/point_measure.pyx +395 -0
  161. multipers/simplex_tree_multi.cpython-313-x86_64-linux-gnu.so +0 -0
  162. multipers/simplex_tree_multi.pxd +134 -0
  163. multipers/simplex_tree_multi.pyx +10980 -0
  164. multipers/simplex_tree_multi.pyx.tp +2007 -0
  165. multipers/slicer.cpython-313-x86_64-linux-gnu.so +0 -0
  166. multipers/slicer.pxd +3034 -0
  167. multipers/slicer.pxd.tp +234 -0
  168. multipers/slicer.pyx +20481 -0
  169. multipers/slicer.pyx.tp +1088 -0
  170. multipers/tensor/tensor.h +672 -0
  171. multipers/tensor.pxd +13 -0
  172. multipers/test.pyx +44 -0
  173. multipers/tests/__init__.py +62 -0
  174. multipers/torch/__init__.py +1 -0
  175. multipers/torch/diff_grids.py +240 -0
  176. multipers/torch/rips_density.py +310 -0
  177. multipers-2.3.3.dist-info/METADATA +128 -0
  178. multipers-2.3.3.dist-info/RECORD +182 -0
  179. multipers-2.3.3.dist-info/WHEEL +5 -0
  180. multipers-2.3.3.dist-info/licenses/LICENSE +21 -0
  181. multipers-2.3.3.dist-info/top_level.txt +1 -0
  182. multipers.libs/libtbb-ca48af5c.so.12.16 +0 -0
multipers/filtrations/density.py
@@ -0,0 +1,533 @@
+ from collections.abc import Callable, Iterable
+ from typing import Any, Literal, Union
+
+ import numpy as np
+
+ from multipers.array_api import api_from_tensor, api_from_tensors
+
+ global available_kernels
+ available_kernels = Union[
+     Literal[
+         "gaussian", "exponential", "exponential_kernel", "multivariate_gaussian", "sinc"
+     ],
+     Callable,
+ ]
+
+
+ def convolution_signed_measures(
+     iterable_of_signed_measures,
+     filtrations,
+     bandwidth,
+     flatten: bool = True,
+     n_jobs: int = 1,
+     backend="pykeops",
+     kernel: available_kernels = "gaussian",
+     **kwargs,
+ ):
+     """
+     Evaluates the convolution of each signed measure (pts, weights) with a kernel (Gaussian by default) of the given bandwidth, on the grid defined by `filtrations`.
+
+     Parameters
+     ----------
+
+     - iterable_of_signed_measures : (num_signed_measure) x [ (npts) x (num_parameters), (npts)]
+     - filtrations : (num_parameter) x (filtration values)
+     - flatten : bool
+     - n_jobs : int
+
+     Outputs
+     -------
+
+     The concatenated images, for each signed measure (num_signed_measures) x (len(f) for f in filtration_values)
+     """
+     from multipers.grids import todense
+
+     grid_iterator = todense(filtrations, product_order=True)
+     api = api_from_tensor(iterable_of_signed_measures[0][0][0])
+     match backend:
+         case "sklearn":
+
+             def convolution_signed_measures_on_grid(
+                 signed_measures,
+             ):
+                 return api.cat(
+                     [
+                         _pts_convolution_sparse_old(
+                             pts=pts,
+                             pts_weights=weights,
+                             grid_iterator=grid_iterator,
+                             bandwidth=bandwidth,
+                             kernel=kernel,
+                             **kwargs,
+                         )
+                         for pts, weights in signed_measures
+                     ],
+                     axis=0,
+                 )
+
+         case "pykeops":
+
+             def convolution_signed_measures_on_grid(
+                 signed_measures: Iterable[tuple[np.ndarray, np.ndarray]],
+             ) -> np.ndarray:
+                 return api.cat(
+                     [
+                         _pts_convolution_pykeops(
+                             pts=pts,
+                             pts_weights=weights,
+                             grid_iterator=grid_iterator,
+                             bandwidth=bandwidth,
+                             kernel=kernel,
+                             **kwargs,
+                         )
+                         for pts, weights in signed_measures
+                     ],
+                     axis=0,
+                 )
+
+             # compiles the pykeops formula once, on a tiny input
+             pts, weights = iterable_of_signed_measures[0][0]
+             small_pts, small_weights = pts[:2], weights[:2]
+
+             _pts_convolution_pykeops(
+                 small_pts,
+                 small_weights,
+                 grid_iterator=grid_iterator,
+                 bandwidth=bandwidth,
+                 kernel=kernel,
+                 **kwargs,
+             )
+
+     if n_jobs > 1 or n_jobs == -1:
+         prefer = "processes" if backend == "sklearn" else "threads"
+         from joblib import Parallel, delayed
+
+         convolutions = Parallel(n_jobs=n_jobs, prefer=prefer)(
+             delayed(convolution_signed_measures_on_grid)(sms)
+             for sms in iterable_of_signed_measures
+         )
+     else:
+         convolutions = [
+             convolution_signed_measures_on_grid(sms)
+             for sms in iterable_of_signed_measures
+         ]
+     if not flatten:
+         out_shape = [-1] + [len(f) for f in filtrations]  # Degree
+         convolutions = [x.reshape(out_shape) for x in convolutions]
+     return api.cat([x[None] for x in convolutions])
+
+
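A minimal usage sketch of `convolution_signed_measures` (illustrative toy data only; it assumes a working install of multipers and pykeops, and follows the shape conventions of the docstring above: each signed measure is a list of (points, weights) pairs, and `filtrations` is one 1d array of grid values per parameter):

    import numpy as np
    from multipers.filtrations.density import convolution_signed_measures

    # one signed measure on two filtration parameters: 3 points with +/-1 weights
    pts = np.array([[0.2, 0.3], [0.5, 0.7], [0.8, 0.1]])
    weights = np.array([1.0, -1.0, 1.0])
    signed_measure = [(pts, weights)]  # one (points, weights) pair per degree

    # evaluation grid: one 1d array of filtration values per parameter
    filtrations = [np.linspace(0, 1, 50), np.linspace(0, 1, 50)]

    imgs = convolution_signed_measures(
        [signed_measure],       # iterable over signed measures
        filtrations=filtrations,
        bandwidth=0.1,
        flatten=False,          # keep the (50, 50) grid shape
    )
    # expected shape: (num_measures, num_degrees, 50, 50) == (1, 1, 50, 50)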
+ # def _test(r=1000, b=0.5, plot=True, kernel=0):
+ #     import matplotlib.pyplot as plt
+ #     pts, weights = np.array([[1., 1.], [1.1, 1.1]]), np.array([1, -1])
+ #     pt_list = np.array(list(product(*[np.linspace(0, 2, r)] * 2)))
+ #     img = _pts_convolution_sparse_pts(pts, weights, pt_list, b, kernel=kernel)
+ #     if plot:
+ #         plt.imshow(img.reshape(r, -1).T, origin="lower")
+ #         plt.show()
+
+
+ def _pts_convolution_sparse_old(
+     pts: np.ndarray,
+     pts_weights: np.ndarray,
+     grid_iterator,
+     kernel: available_kernels = "gaussian",
+     bandwidth=0.1,
+     **more_kde_args,
+ ):
+     """
+     Old version of `convolution_signed_measures`. Scikit-learn's convolution is slower than the pykeops code above.
+     """
+     from sklearn.neighbors import KernelDensity
+
+     if len(pts) == 0:
+         # warn("Found a trivial signed measure !")
+         return np.zeros(len(grid_iterator))
+     kde = KernelDensity(
+         kernel=kernel, bandwidth=bandwidth, rtol=1e-4, **more_kde_args
+     )  # TODO : check rtol
+     pos_indices = pts_weights > 0
+     neg_indices = pts_weights < 0
+     img_pos = (
+         np.zeros(len(grid_iterator))
+         if pos_indices.sum() == 0
+         else kde.fit(
+             pts[pos_indices], sample_weight=pts_weights[pos_indices]
+         ).score_samples(grid_iterator)
+     )
+     img_neg = (
+         np.zeros(len(grid_iterator))
+         if neg_indices.sum() == 0
+         else kde.fit(
+             pts[neg_indices], sample_weight=-pts_weights[neg_indices]
+         ).score_samples(grid_iterator)
+     )
+     return np.exp(img_pos) - np.exp(img_neg)
+
+
+ def _pts_convolution_pykeops(
+     pts: np.ndarray,
+     pts_weights: np.ndarray,
+     grid_iterator,
+     kernel: available_kernels = "gaussian",
+     bandwidth=0.1,
+     **more_kde_args,
+ ):
+     """
+     PyKeOps convolution of a single signed measure.
+     """
+     if isinstance(pts, np.ndarray):
+         _asarray_weights = lambda x: np.asarray(x, dtype=pts.dtype)
+         _asarray_grid = _asarray_weights
+     else:
+         import torch
+
+         _asarray_weights = lambda x: torch.from_numpy(x).type(pts.dtype)
+         _asarray_grid = lambda x: x.type(pts.dtype)
+     kde = KDE(kernel=kernel, bandwidth=bandwidth, **more_kde_args)
+     return kde.fit(pts, sample_weights=_asarray_weights(pts_weights)).score_samples(
+         _asarray_grid(grid_iterator)
+     )
+
+
+ def gaussian_kernel(x_i, y_j, bandwidth):
+     D = x_i.shape[-1]
+     exponent = -(((x_i - y_j) / bandwidth) ** 2).sum(dim=-1) / 2
+     # float is necessary for some reason (pykeops fails otherwise)
+     kernel = (exponent).exp() / float((bandwidth * np.sqrt(2 * np.pi)) ** D)
+     return kernel
+
+
+ def multivariate_gaussian_kernel(x_i, y_j, covariance_matrix_inverse):
+     # (2*pi)^(-dim/2) * det(Sigma)^(-1/2) * exp(-(x-y)^T Sigma^{-1} (x-y) / 2)
+     # CF https://www.kernel-operations.io/keops/_auto_examples/pytorch/plot_anisotropic_kernels.html#sphx-glr-auto-examples-pytorch-plot-anisotropic-kernels-py
+     # and https://www.kernel-operations.io/keops/api/math-operations.html
+     dim = x_i.shape[-1]
+     z = x_i - y_j
+     exponent = -(z.weightedsqnorm(covariance_matrix_inverse.flatten()) / 2)
+     return (
+         float((2 * np.pi) ** (-dim / 2))
+         * (covariance_matrix_inverse.det().sqrt())
+         * exponent.exp()
+     )
+
+
+ def exponential_kernel(x_i, y_j, bandwidth):
+     # (1 / bandwidth) * exp(-norm(x - y) / bandwidth)
+     exponent = -((((x_i - y_j) ** 2).sum(dim=-1) ** 0.5) / bandwidth)
+     kernel = exponent.exp() / bandwidth
+     return kernel
+
+
+ def sinc_kernel(x_i, y_j, bandwidth):
+     norm = (((x_i - y_j) ** 2).sum(dim=-1) ** 0.5) / bandwidth
+     sinc = type(x_i).sinc
+     kernel = 2 * sinc(2 * norm) - sinc(norm)
+     return kernel
+
+
+ def _kernel(
+     kernel: available_kernels = "gaussian",
+ ):
+     match kernel:
+         case "gaussian":
+             return gaussian_kernel
+         case "exponential":
+             return exponential_kernel
+         case "multivariate_gaussian":
+             return multivariate_gaussian_kernel
+         case "sinc":
+             return sinc_kernel
+         case _:
+             assert callable(
+                 kernel
+             ), f"""
+             --------------------------
+             Unknown kernel {kernel}.
+             --------------------------
+             A custom kernel has to be callable,
+             (x: LazyTensor(n,1,D), y: LazyTensor(1,m,D), bandwidth: float) ---> kernel matrix
+
+             Valid operations are given here:
+             https://www.kernel-operations.io/keops/python/api/index.html
+             """
+             return kernel
+
+
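The kernels above all share the signature expected by `_kernel`: two pykeops LazyTensors of shapes (n, 1, D) and (1, m, D) plus a bandwidth, returning a symbolic kernel matrix, and any callable with that signature can be passed as a custom kernel. A hedged sketch of evaluating `gaussian_kernel` directly on numpy data (pykeops assumed installed; shapes and data are illustrative):

    import numpy as np
    from pykeops.numpy import LazyTensor
    from multipers.filtrations.density import gaussian_kernel

    x = np.random.rand(100, 2)       # sample points
    y = np.random.rand(400, 2)       # evaluation grid
    x_i = LazyTensor(x[:, None, :])  # (100, 1, 2)
    y_j = LazyTensor(y[None, :, :])  # (1, 400, 2)

    K = gaussian_kernel(x_i, y_j, bandwidth=0.1)   # symbolic (100, 400) kernel matrix
    density = K.sum(dim=0).squeeze() / x.shape[0]  # unweighted density estimate at the grid points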
+ # TODO : multiple bandwidths at once with lazy tensors
+ class KDE:
+     """
+     Fast, scikit-learn-style, and differentiable kernel density estimation, using PyKeOps.
+     """
+
+     def __init__(
+         self,
+         bandwidth: Any = 1,
+         kernel: available_kernels = "gaussian",
+         return_log: bool = False,
+     ):
+         """
+         bandwidth : numeric
+             bandwidth of the kernel
+         kernel : str or callable
+             one of `available_kernels`, or a custom callable
+         return_log : bool
+             if True, `score_samples` returns log densities
+         """
+         self.X = None
+         self.bandwidth = bandwidth
+         self.kernel: available_kernels = kernel
+         self._kernel = None
+         self._backend = None
+         self._sample_weights = None
+         self.return_log = return_log
+
+     def fit(self, X, sample_weights=None, y=None):
+         self.X = X
+         self._sample_weights = sample_weights
+         if isinstance(X, np.ndarray):
+             self._backend = np
+         else:
+             import torch
+
+             if isinstance(X, torch.Tensor):
+                 self._backend = torch
+             else:
+                 raise Exception("Unsupported backend.")
+         self._kernel = _kernel(self.kernel)
+         return self
+
+     @staticmethod
+     def to_lazy(X, Y, x_weights):
+         if isinstance(X, np.ndarray):
+             from pykeops.numpy import LazyTensor
+
+             lazy_x = LazyTensor(
+                 X.reshape((X.shape[0], 1, X.shape[1]))
+             )  # numpts, 1, dim
+             lazy_y = LazyTensor(
+                 Y.reshape((1, Y.shape[0], Y.shape[1])).astype(X.dtype)
+             )  # 1, numpts, dim
+             if x_weights is not None:
+                 w = LazyTensor(np.asarray(x_weights, dtype=X.dtype)[:, None], axis=0)
+                 return lazy_x, lazy_y, w
+             return lazy_x, lazy_y, None
+         import torch
+
+         if isinstance(X, torch.Tensor):
+             from pykeops.torch import LazyTensor
+
+             lazy_x = LazyTensor(X.view(X.shape[0], 1, X.shape[1]))
+             lazy_y = LazyTensor(Y.type(X.dtype).view(1, Y.shape[0], Y.shape[1]))
+             if x_weights is not None:
+                 if isinstance(x_weights, np.ndarray):
+                     x_weights = torch.from_numpy(x_weights)
+                 w = LazyTensor(x_weights[:, None].type(X.dtype), axis=0)
+                 return lazy_x, lazy_y, w
+             return lazy_x, lazy_y, None
+         raise Exception("Bad tensor type.")
+
+     def score_samples(self, Y, X=None, return_kernel=False):
+         """Returns the kernel density estimates of each point in `Y`.
+
+         Parameters
+         ----------
+         Y : tensor (m, d)
+             `m` points with `d` dimensions for which the probability density will
+             be calculated
+         X : tensor (n, d), optional
+             `n` points with `d` dimensions to which KDE will be fit. Provided to
+             allow batch calculations. By default, `X` is None and all points
+             passed to `fit` are used.
+
+
+         Returns
+         -------
+         densities : tensor (m)
+             density estimates for each of the queried points in `Y`
+             (log densities if `return_log` is True)
+         """
+         assert self._backend is not None and self._kernel is not None, "Fit first."
+         X = self.X if X is None else X
+         if X.shape[0] == 0:
+             return self._backend.zeros((Y.shape[0]))
+         assert Y.shape[1] == X.shape[1] and X.ndim == Y.ndim == 2
+         lazy_x, lazy_y, w = self.to_lazy(X, Y, x_weights=self._sample_weights)
+         kernel = self._kernel(lazy_x, lazy_y, self.bandwidth)
+         if w is not None:
+             kernel *= w
+         if return_kernel:
+             return kernel
+         density_estimation = kernel.sum(dim=0).squeeze() / kernel.shape[0]  # mean
+         return (
+             self._backend.log(density_estimation)
+             if self.return_log
+             else density_estimation
+         )
+
+
+
+
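A short usage sketch of the `KDE` estimator above (synthetic data, numpy backend; pykeops assumed installed):

    import numpy as np
    from multipers.filtrations.density import KDE

    rng = np.random.default_rng(0)
    X = rng.normal(size=(500, 2))                # points whose density is estimated
    queries = rng.uniform(-3, 3, size=(200, 2))  # where to evaluate it

    kde = KDE(bandwidth=0.2, kernel="gaussian").fit(X)
    densities = kde.score_samples(queries)       # shape (200,)

    # log densities instead:
    log_densities = KDE(bandwidth=0.2, return_log=True).fit(X).score_samples(queries)

With torch tensors passed to `fit` and `score_samples`, the same estimate goes through pykeops.torch LazyTensors and stays differentiable with respect to the inputs, as the class docstring indicates.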
+ class DTM:
+     """
+     Distance To Measure
+     """
+
+     def __init__(self, masses, metric: str = "euclidean", **_kdtree_kwargs):
+         """
+         masses : iterable of floats in (0, 1]
+             The mass thresholds
+         metric : str
+             The metric used to compare points (passed to sklearn's KDTree)
+         """
+         self.masses = masses
+         self.metric = metric
+         self._kdtree_kwargs = _kdtree_kwargs
+         self._ks = None
+         self._kdtree = None
+         self._X = None
+         self._backend = None
+
+     def fit(self, X, sample_weights=None, y=None):
+         if len(self.masses) == 0:
+             return self
+         assert np.max(self.masses) <= 1, "All masses should be in (0,1]."
+         from sklearn.neighbors import KDTree
+
+         if not isinstance(X, np.ndarray):
+             import torch
+
+             assert isinstance(X, torch.Tensor), "Backend has to be numpy or torch"
+             _X = X.detach()
+             self._backend = "torch"
+         else:
+             _X = X
+             self._backend = "numpy"
+         self._ks = np.array([int(mass * X.shape[0]) + 1 for mass in self.masses])
+         self._kdtree = KDTree(_X, metric=self.metric, **self._kdtree_kwargs)
+         self._X = X
+         return self
+
+     def score_samples(self, Y, X=None):
+         """Returns the distance to measure of each point in `Y`.
+
+         Parameters
+         ----------
+         Y : tensor (m, d)
+             `m` points with `d` dimensions at which the distance to measure
+             will be evaluated
+
+
+         Returns
+         -------
+         the DTMs of Y, for each mass in masses.
+         """
+         if len(self.masses) == 0:
+             return np.empty((0, len(Y)))
+         assert (
+             self._ks is not None and self._kdtree is not None and self._X is not None
+         ), f"Fit first. Got {self._ks=}, {self._kdtree=}, {self._X=}."
+         assert Y.ndim == 2
+         if self._backend == "torch":
+             _Y = Y.detach().numpy()
+         else:
+             _Y = Y
+         NN_Dist, NN = self._kdtree.query(_Y, self._ks.max(), return_distance=True)
+         DTMs = np.array([((NN_Dist**2)[:, :k].mean(1)) ** 0.5 for k in self._ks])
+         return DTMs
+
+     def score_samples_diff(self, Y):
+         """Differentiable (torch) version of `score_samples`.
+
+         Parameters
+         ----------
+         Y : tensor (m, d)
+             `m` points with `d` dimensions at which the distance to measure
+             will be evaluated
+
+
+         Returns
+         -------
+         the DTMs of Y, for each mass in masses.
+         """
+         import torch
+
+         if len(self.masses) == 0:
+             return torch.empty(0, len(Y))
+
+         assert Y.ndim == 2
+         assert self._backend == "torch", "Use the non-diff version with numpy."
+         assert (
+             self._ks is not None and self._kdtree is not None and self._X is not None
+         ), f"Fit first. Got {self._ks=}, {self._kdtree=}, {self._X=}."
+         NN = self._kdtree.query(Y.detach(), self._ks.max(), return_distance=False)
+         DTMs = tuple(
+             (((self._X[NN] - Y[:, None, :]) ** 2)[:, :k].sum(dim=(1, 2)) / k) ** 0.5
+             for k in self._ks
+         )  # TODO : the kdtree already computes distances; find a kdtree implementation that is pytorch-differentiable
+         return DTMs
+
+
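A minimal sketch of the `DTM` helper above on synthetic data (masses are fractions of the reference cloud, and the output has one row per mass):

    import numpy as np
    from multipers.filtrations.density import DTM

    rng = np.random.default_rng(0)
    X = rng.normal(size=(300, 2))          # reference point cloud
    Y = rng.uniform(-2, 2, size=(50, 2))   # query points

    dtm = DTM(masses=[0.05, 0.2]).fit(X)
    dtms = dtm.score_samples(Y)            # shape (2, 50): one distance-to-measure per mass and query point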
+ ## code taken from pykeops doc (https://www.kernel-operations.io/keops/_auto_benchmarks/benchmark_KNN.html)
+ class KNNmean:
+     def __init__(self, k: int, metric: str = "euclidean"):
+         self.k = k
+         self.metric = metric
+         self._KNN_fun = None
+         self._x = None
+
+     def fit(self, x):
+         if isinstance(x, np.ndarray):
+             from pykeops.numpy import Vi, Vj
+         else:
+             import torch
+
+             assert isinstance(x, torch.Tensor), "Backend has to be numpy or torch"
+             from pykeops.torch import Vi, Vj
+
+         D = x.shape[1]
+         X_i = Vi(0, D)
+         X_j = Vj(1, D)
+
+         # Symbolic distance matrix:
+         if self.metric == "euclidean":
+             D_ij = ((X_i - X_j) ** 2).sum(-1) ** (1 / 2)
+         elif self.metric == "manhattan":
+             D_ij = (X_i - X_j).abs().sum(-1)
+         elif self.metric == "angular":
+             D_ij = -(X_i | X_j)
+         elif self.metric == "hyperbolic":
+             D_ij = ((X_i - X_j) ** 2).sum(-1) / (X_i[0] * X_j[0])
+         else:
+             raise NotImplementedError(f"The '{self.metric}' distance is not supported.")
+
+         self._x = x
+         self._KNN_fun = D_ij.Kmin(self.k, dim=1)
+         return self
+
+     def score_samples(self, x):
+         assert self._x is not None and self._KNN_fun is not None, "Fit first."
+         return self._KNN_fun(x, self._x).sum(axis=1) / self.k
+
+
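And a corresponding sketch for `KNNmean`, which returns the mean distance from each query point to its k nearest neighbours in the fitted cloud (pykeops assumed installed; data is synthetic):

    import numpy as np
    from multipers.filtrations.density import KNNmean

    rng = np.random.default_rng(0)
    X = rng.normal(size=(1000, 3))
    knn = KNNmean(k=10).fit(X)
    mean_dists = knn.score_samples(X)   # shape (1000,): mean distance to the 10 nearest points of X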
+ # def _pts_convolution_sparse(pts:np.ndarray, pts_weights:np.ndarray, filtration_grid:Iterable[np.ndarray], kernel="gaussian", bandwidth=0.1, **more_kde_args):
+ #     """
+ #     Old version of `convolution_signed_measures`. Scikit-learn's convolution is slower than the code above.
+ #     """
+ #     from sklearn.neighbors import KernelDensity
+ #     grid_iterator = np.asarray(list(product(*filtration_grid)))
+ #     grid_shape = [len(f) for f in filtration_grid]
+ #     if len(pts) == 0:
+ #         # warn("Found a trivial signed measure !")
+ #         return np.zeros(shape=grid_shape)
+ #     kde = KernelDensity(kernel=kernel, bandwidth=bandwidth, rtol=1e-4, **more_kde_args)  # TODO : check rtol
+
+ #     pos_indices = pts_weights > 0
+ #     neg_indices = pts_weights < 0
+ #     img_pos = kde.fit(pts[pos_indices], sample_weight=pts_weights[pos_indices]).score_samples(grid_iterator).reshape(grid_shape)
+ #     img_neg = kde.fit(pts[neg_indices], sample_weight=-pts_weights[neg_indices]).score_samples(grid_iterator).reshape(grid_shape)
+ #     return np.exp(img_pos) - np.exp(img_neg)
+
+
+ # Precompiles the convolution
+ # _test(r=2, b=.5, plot=False)