multipers 2.3.3b6__cp311-cp311-manylinux_2_39_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic; see the registry's release page for more details.

Files changed (182)
  1. multipers/__init__.py +33 -0
  2. multipers/_signed_measure_meta.py +453 -0
  3. multipers/_slicer_meta.py +211 -0
  4. multipers/array_api/__init__.py +45 -0
  5. multipers/array_api/numpy.py +41 -0
  6. multipers/array_api/torch.py +58 -0
  7. multipers/data/MOL2.py +458 -0
  8. multipers/data/UCR.py +18 -0
  9. multipers/data/__init__.py +1 -0
  10. multipers/data/graphs.py +466 -0
  11. multipers/data/immuno_regions.py +27 -0
  12. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  13. multipers/data/pytorch2simplextree.py +91 -0
  14. multipers/data/shape3d.py +101 -0
  15. multipers/data/synthetic.py +113 -0
  16. multipers/distances.py +202 -0
  17. multipers/filtration_conversions.pxd +229 -0
  18. multipers/filtration_conversions.pxd.tp +84 -0
  19. multipers/filtrations/__init__.py +18 -0
  20. multipers/filtrations/density.py +574 -0
  21. multipers/filtrations/filtrations.py +361 -0
  22. multipers/filtrations.pxd +224 -0
  23. multipers/function_rips.cpython-311-x86_64-linux-gnu.so +0 -0
  24. multipers/function_rips.pyx +105 -0
  25. multipers/grids.cpython-311-x86_64-linux-gnu.so +0 -0
  26. multipers/grids.pyx +433 -0
  27. multipers/gudhi/Persistence_slices_interface.h +132 -0
  28. multipers/gudhi/Simplex_tree_interface.h +239 -0
  29. multipers/gudhi/Simplex_tree_multi_interface.h +551 -0
  30. multipers/gudhi/cubical_to_boundary.h +59 -0
  31. multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -0
  32. multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -0
  33. multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -0
  34. multipers/gudhi/gudhi/Debug_utils.h +45 -0
  35. multipers/gudhi/gudhi/Fields/Multi_field.h +484 -0
  36. multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -0
  37. multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -0
  38. multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -0
  39. multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -0
  40. multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -0
  41. multipers/gudhi/gudhi/Fields/Z2_field.h +355 -0
  42. multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -0
  43. multipers/gudhi/gudhi/Fields/Zp_field.h +420 -0
  44. multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -0
  45. multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -0
  46. multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -0
  47. multipers/gudhi/gudhi/Matrix.h +2107 -0
  48. multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -0
  49. multipers/gudhi/gudhi/Multi_persistence/Box.h +174 -0
  50. multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -0
  51. multipers/gudhi/gudhi/Off_reader.h +173 -0
  52. multipers/gudhi/gudhi/One_critical_filtration.h +1441 -0
  53. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -0
  54. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -0
  55. multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -0
  56. multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -0
  57. multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -0
  58. multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -0
  59. multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -0
  60. multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -0
  61. multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -0
  62. multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -0
  63. multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -0
  64. multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -0
  65. multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -0
  66. multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -0
  67. multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -0
  68. multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -0
  69. multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -0
  70. multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -0
  71. multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -0
  72. multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -0
  73. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -0
  74. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -0
  75. multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -0
  76. multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -0
  77. multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -0
  78. multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -0
  79. multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -0
  80. multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -0
  81. multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -0
  82. multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -0
  83. multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -0
  84. multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -0
  85. multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -0
  86. multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -0
  87. multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -0
  88. multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -0
  89. multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -0
  90. multipers/gudhi/gudhi/Persistent_cohomology.h +745 -0
  91. multipers/gudhi/gudhi/Points_off_io.h +171 -0
  92. multipers/gudhi/gudhi/Simple_object_pool.h +69 -0
  93. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -0
  94. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -0
  95. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -0
  96. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -0
  97. multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -0
  98. multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -0
  99. multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -0
  100. multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -0
  101. multipers/gudhi/gudhi/Simplex_tree.h +2794 -0
  102. multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -0
  103. multipers/gudhi/gudhi/distance_functions.h +62 -0
  104. multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -0
  105. multipers/gudhi/gudhi/persistence_interval.h +253 -0
  106. multipers/gudhi/gudhi/persistence_matrix_options.h +170 -0
  107. multipers/gudhi/gudhi/reader_utils.h +367 -0
  108. multipers/gudhi/mma_interface_coh.h +256 -0
  109. multipers/gudhi/mma_interface_h0.h +223 -0
  110. multipers/gudhi/mma_interface_matrix.h +293 -0
  111. multipers/gudhi/naive_merge_tree.h +536 -0
  112. multipers/gudhi/scc_io.h +310 -0
  113. multipers/gudhi/truc.h +1403 -0
  114. multipers/io.cpython-311-x86_64-linux-gnu.so +0 -0
  115. multipers/io.pyx +644 -0
  116. multipers/ml/__init__.py +0 -0
  117. multipers/ml/accuracies.py +90 -0
  118. multipers/ml/invariants_with_persistable.py +79 -0
  119. multipers/ml/kernels.py +176 -0
  120. multipers/ml/mma.py +713 -0
  121. multipers/ml/one.py +472 -0
  122. multipers/ml/point_clouds.py +352 -0
  123. multipers/ml/signed_measures.py +1589 -0
  124. multipers/ml/sliced_wasserstein.py +461 -0
  125. multipers/ml/tools.py +113 -0
  126. multipers/mma_structures.cpython-311-x86_64-linux-gnu.so +0 -0
  127. multipers/mma_structures.pxd +128 -0
  128. multipers/mma_structures.pyx +2786 -0
  129. multipers/mma_structures.pyx.tp +1094 -0
  130. multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -0
  131. multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -0
  132. multipers/multi_parameter_rank_invariant/function_rips.h +322 -0
  133. multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -0
  134. multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -0
  135. multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -0
  136. multipers/multiparameter_edge_collapse.py +41 -0
  137. multipers/multiparameter_module_approximation/approximation.h +2330 -0
  138. multipers/multiparameter_module_approximation/combinatory.h +129 -0
  139. multipers/multiparameter_module_approximation/debug.h +107 -0
  140. multipers/multiparameter_module_approximation/euler_curves.h +0 -0
  141. multipers/multiparameter_module_approximation/format_python-cpp.h +286 -0
  142. multipers/multiparameter_module_approximation/heap_column.h +238 -0
  143. multipers/multiparameter_module_approximation/images.h +79 -0
  144. multipers/multiparameter_module_approximation/list_column.h +174 -0
  145. multipers/multiparameter_module_approximation/list_column_2.h +232 -0
  146. multipers/multiparameter_module_approximation/ru_matrix.h +347 -0
  147. multipers/multiparameter_module_approximation/set_column.h +135 -0
  148. multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -0
  149. multipers/multiparameter_module_approximation/unordered_set_column.h +166 -0
  150. multipers/multiparameter_module_approximation/utilities.h +403 -0
  151. multipers/multiparameter_module_approximation/vector_column.h +223 -0
  152. multipers/multiparameter_module_approximation/vector_matrix.h +331 -0
  153. multipers/multiparameter_module_approximation/vineyards.h +464 -0
  154. multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -0
  155. multipers/multiparameter_module_approximation.cpython-311-x86_64-linux-gnu.so +0 -0
  156. multipers/multiparameter_module_approximation.pyx +235 -0
  157. multipers/pickle.py +90 -0
  158. multipers/plots.py +456 -0
  159. multipers/point_measure.cpython-311-x86_64-linux-gnu.so +0 -0
  160. multipers/point_measure.pyx +395 -0
  161. multipers/simplex_tree_multi.cpython-311-x86_64-linux-gnu.so +0 -0
  162. multipers/simplex_tree_multi.pxd +134 -0
  163. multipers/simplex_tree_multi.pyx +10840 -0
  164. multipers/simplex_tree_multi.pyx.tp +2009 -0
  165. multipers/slicer.cpython-311-x86_64-linux-gnu.so +0 -0
  166. multipers/slicer.pxd +3034 -0
  167. multipers/slicer.pxd.tp +234 -0
  168. multipers/slicer.pyx +20481 -0
  169. multipers/slicer.pyx.tp +1088 -0
  170. multipers/tensor/tensor.h +672 -0
  171. multipers/tensor.pxd +13 -0
  172. multipers/test.pyx +44 -0
  173. multipers/tests/__init__.py +62 -0
  174. multipers/torch/__init__.py +1 -0
  175. multipers/torch/diff_grids.py +240 -0
  176. multipers/torch/rips_density.py +310 -0
  177. multipers-2.3.3b6.dist-info/METADATA +128 -0
  178. multipers-2.3.3b6.dist-info/RECORD +182 -0
  179. multipers-2.3.3b6.dist-info/WHEEL +5 -0
  180. multipers-2.3.3b6.dist-info/licenses/LICENSE +21 -0
  181. multipers-2.3.3b6.dist-info/top_level.txt +1 -0
  182. multipers.libs/libtbb-ca48af5c.so.12.16 +0 -0
@@ -0,0 +1,574 @@
1
+ from collections.abc import Callable, Iterable
2
+ from typing import Any, Literal, Union
3
+ import numpy as np
4
+
5
+
6
+ from multipers.array_api import api_from_tensor
7
# Type of the `kernel` arguments of this module: one of the built-in kernel
# names below, or a custom callable ``(x_i, y_j, bandwidth) -> kernel``.
# (A module-level ``global`` statement is a no-op, so none is needed here.)
available_kernels = Union[
    Literal[
        "gaussian", "exponential", "exponential_kernel", "multivariate_gaussian", "sinc"
    ],
    Callable,
]
14
+
15
+
16
+ def convolution_signed_measures(
17
+ iterable_of_signed_measures,
18
+ filtrations,
19
+ bandwidth,
20
+ flatten: bool = True,
21
+ n_jobs: int = 1,
22
+ backend="pykeops",
23
+ kernel: available_kernels = "gaussian",
24
+ **kwargs,
25
+ ):
26
+ """
27
+ Evaluates the convolution of the signed measures Iterable(pts, weights) with a gaussian measure of bandwidth bandwidth, on a grid given by the filtrations
28
+
29
+ Parameters
30
+ ----------
31
+
32
+ - iterable_of_signed_measures : (num_signed_measure) x [ (npts) x (num_parameters), (npts)]
33
+ - filtrations : (num_parameter) x (filtration values)
34
+ - flatten : bool
35
+ - n_jobs : int
36
+
37
+ Outputs
38
+ -------
39
+
40
+ The concatenated images, for each signed measure (num_signed_measures) x (len(f) for f in filtration_values)
41
+ """
42
+ from multipers.grids import todense
43
+
44
+ grid_iterator = todense(filtrations, product_order=True)
45
+ api = api_from_tensor(iterable_of_signed_measures[0][0][0])
46
+ match backend:
47
+ case "sklearn":
48
+
49
+ def convolution_signed_measures_on_grid(
50
+ signed_measures,
51
+ ):
52
+ return api.cat(
53
+ [
54
+ _pts_convolution_sparse_old(
55
+ pts=pts,
56
+ pts_weights=weights,
57
+ grid_iterator=grid_iterator,
58
+ bandwidth=bandwidth,
59
+ kernel=kernel,
60
+ **kwargs,
61
+ )
62
+ for pts, weights in signed_measures
63
+ ],
64
+ axis=0,
65
+ )
66
+
67
+ case "pykeops":
68
+
69
+ def convolution_signed_measures_on_grid(
70
+ signed_measures: Iterable[tuple[np.ndarray, np.ndarray]],
71
+ ) -> np.ndarray:
72
+ return api.cat(
73
+ [
74
+ _pts_convolution_pykeops(
75
+ pts=pts,
76
+ pts_weights=weights,
77
+ grid_iterator=grid_iterator,
78
+ bandwidth=bandwidth,
79
+ kernel=kernel,
80
+ **kwargs,
81
+ )
82
+ for pts, weights in signed_measures
83
+ ],
84
+ axis=0,
85
+ )
86
+
87
+ # compiles first once
88
+ pts, weights = iterable_of_signed_measures[0][0]
89
+ small_pts, small_weights = pts[:2], weights[:2]
90
+
91
+ _pts_convolution_pykeops(
92
+ small_pts,
93
+ small_weights,
94
+ grid_iterator=grid_iterator,
95
+ bandwidth=bandwidth,
96
+ kernel=kernel,
97
+ **kwargs,
98
+ )
99
+
100
+ if n_jobs > 1 or n_jobs == -1:
101
+ prefer = "processes" if backend == "sklearn" else "threads"
102
+ from joblib import Parallel, delayed
103
+
104
+ convolutions = Parallel(n_jobs=n_jobs, prefer=prefer)(
105
+ delayed(convolution_signed_measures_on_grid)(sms)
106
+ for sms in iterable_of_signed_measures
107
+ )
108
+ else:
109
+ convolutions = [
110
+ convolution_signed_measures_on_grid(sms)
111
+ for sms in iterable_of_signed_measures
112
+ ]
113
+ if not flatten:
114
+ out_shape = [-1] + [len(f) for f in filtrations] # Degree
115
+ convolutions = [x.reshape(out_shape) for x in convolutions]
116
+ return api.cat([x[None] for x in convolutions])
117
+
118
+
119
+ # def _test(r=1000, b=0.5, plot=True, kernel=0):
120
+ # import matplotlib.pyplot as plt
121
+ # pts, weigths = np.array([[1.,1.], [1.1,1.1]]), np.array([1,-1])
122
+ # pt_list = np.array(list(product(*[np.linspace(0,2,r)]*2)))
123
+ # img = _pts_convolution_sparse_pts(pts,weigths, pt_list,b,kernel=kernel)
124
+ # if plot:
125
+ # plt.imshow(img.reshape(r,-1).T, origin="lower")
126
+ # plt.show()
127
+
128
+
129
def _pts_convolution_sparse_old(
    pts: np.ndarray,
    pts_weights: np.ndarray,
    grid_iterator,
    kernel: available_kernels = "gaussian",
    bandwidth=0.1,
    **more_kde_args,
):
    """
    Old, scikit-learn based version of the convolution used by
    `convolution_signed_measures` (slower than the pykeops one).

    The positive and negative parts of the signed measure are fitted
    separately with `sklearn.neighbors.KernelDensity`. The returned image is
    the difference of the two densities evaluated at each point of
    `grid_iterator`. Note that scikit-learn normalizes each fitted density by
    the total sample weight of that part.

    Parameters
    ----------
    pts : (npts, num_parameters) array
    pts_weights : (npts,) array of signed weights
    grid_iterator : (num_grid_points, num_parameters) array of evaluation points
    kernel, bandwidth, **more_kde_args : forwarded to `KernelDensity`

    Returns
    -------
    np.ndarray of shape (len(grid_iterator),)
    """
    num_grid_points = len(grid_iterator)
    if len(pts) == 0:
        # warn("Found a trivial signed measure !")
        return np.zeros(num_grid_points)
    pos_indices = pts_weights > 0
    neg_indices = pts_weights < 0
    if not (pos_indices.any() or neg_indices.any()):
        # Only zero weights: the measure is trivial.
        return np.zeros(num_grid_points)

    from sklearn.neighbors import KernelDensity

    kde = KernelDensity(
        kernel=kernel, bandwidth=bandwidth, rtol=1e-4, **more_kde_args
    )  # TODO : check rtol

    def _part_density(mask, part_weights):
        # An empty part contributes a density of 0. A *log*-density of 0 here
        # would mean exp(0) = 1 and would shift the whole image by +-1.
        if not mask.any():
            return np.zeros(num_grid_points)
        log_density = kde.fit(
            pts[mask], sample_weight=part_weights[mask]
        ).score_samples(grid_iterator)
        return np.exp(log_density)

    return _part_density(pos_indices, pts_weights) - _part_density(
        neg_indices, -pts_weights
    )
165
+
166
+
167
def _pts_convolution_pykeops(
    pts: np.ndarray,
    pts_weights: np.ndarray,
    grid_iterator,
    kernel: available_kernels = "gaussian",
    bandwidth=0.1,
    **more_kde_args,
):
    """
    Pykeops convolution of the signed measure ``(pts, pts_weights)``,
    evaluated at the points of `grid_iterator` with `KDE`.

    `pts` may be a numpy array or a torch tensor. The weights and the grid are
    cast to the dtype of `pts`; in the torch case they are also moved to its
    device, and numpy or tensor inputs are both accepted.
    """
    if isinstance(pts, np.ndarray):

        def _cast(x):
            return np.asarray(x, dtype=pts.dtype)

    else:
        import torch

        def _cast(x):
            # Unlike `torch.from_numpy`, `as_tensor` accepts numpy arrays and
            # tensors alike (keeping the autograd graph of tensor inputs).
            return torch.as_tensor(x, device=pts.device).type(pts.dtype)

    kde = KDE(kernel=kernel, bandwidth=bandwidth, **more_kde_args)
    return kde.fit(pts, sample_weights=_cast(pts_weights)).score_samples(
        _cast(grid_iterator)
    )
189
+
190
+
191
def gaussian_kernel(x_i, y_j, bandwidth):
    """
    Isotropic Gaussian kernel with standard deviation `bandwidth`:
    ``exp(-||x - y||^2 / (2 * bandwidth^2)) / (bandwidth * sqrt(2 pi))^D``,
    where ``D`` is the size of the last axis of `x_i`.
    """
    dim = x_i.shape[-1]
    scaled_sq_dist = (((x_i - y_j) / bandwidth) ** 2).sum(dim=-1)
    # pykeops fails unless the normalization is a plain Python float
    normalization = float((bandwidth * np.sqrt(2 * np.pi)) ** dim)
    return (-scaled_sq_dist / 2).exp() / normalization
197
+
198
+
199
def multivariate_gaussian_kernel(x_i, y_j, covariance_matrix_inverse):
    """
    Anisotropic Gaussian kernel

        (2 pi)^(-D/2) * sqrt(det(Sigma^{-1})) * exp(-(x - y)^T Sigma^{-1} (x - y) / 2)

    i.e. ``1 / sqrt((2 pi)^D det(Sigma)) * exp(...)``, where
    `covariance_matrix_inverse` is ``Sigma^{-1}`` (presumably a (D, D) matrix,
    given to pykeops flattened) as a torch tensor or a numpy array.
    `x_i` and `y_j` must be pykeops LazyTensors: `weightedsqnorm` is a
    LazyTensor operation.

    CF https://www.kernel-operations.io/keops/_auto_examples/pytorch/plot_anisotropic_kernels.html#sphx-glr-auto-examples-pytorch-plot-anisotropic-kernels-py
    and https://www.kernel-operations.io/keops/api/math-operations.html
    """
    dim = x_i.shape[-1]
    z = x_i - y_j
    exponent = -(z.weightedsqnorm(covariance_matrix_inverse.flatten()) / 2)
    if isinstance(covariance_matrix_inverse, np.ndarray):
        # numpy arrays have no `.det()` method
        sqrt_det = float(np.sqrt(np.linalg.det(covariance_matrix_inverse)))
    else:
        # keep the torch determinant so gradients can flow through it
        sqrt_det = covariance_matrix_inverse.det().sqrt()
    return float((2 * np.pi) ** (-dim / 2)) * sqrt_det * exponent.exp()
211
+
212
+
213
def exponential_kernel(x_i, y_j, bandwidth):
    """
    Exponential kernel ``exp(-||x - y|| / bandwidth) / bandwidth``, where
    ``||.||`` is the Euclidean norm over the last axis.

    `x_i` and `y_j` must support ``.sum(dim=...)``, ``.sqrt()`` and ``.exp()``
    (e.g. pykeops LazyTensors or torch tensors).
    """
    # Explicit sqrt: `s ** 1 / 2` parses as `(s ** 1) / 2`, i.e. s / 2.
    # Note: the norm is not differentiable where x == y.
    norm = ((x_i - y_j) ** 2).sum(dim=-1).sqrt()
    return (-norm / bandwidth).exp() / bandwidth
218
+
219
+
220
def sinc_kernel(x_i, y_j, bandwidth):
    """
    Sinc-based kernel ``2 * sinc(2 r) - sinc(r)``, with
    ``r = ||x - y|| / bandwidth`` (Euclidean norm over the last axis).

    ``sinc`` is looked up on the type of `x_i` (e.g. ``torch.Tensor.sinc`` or
    the pykeops LazyTensor one), so both arguments must be of that kind.
    """
    sinc = type(x_i).sinc
    # Explicit sqrt: `s ** 1 / 2` parses as `(s ** 1) / 2`, i.e. s / 2.
    norm = ((x_i - y_j) ** 2).sum(dim=-1).sqrt() / bandwidth
    return 2 * sinc(2 * norm) - sinc(norm)
225
+
226
+
227
+ def _kernel(
228
+ kernel: available_kernels = "gaussian",
229
+ ):
230
+ match kernel:
231
+ case "gaussian":
232
+ return gaussian_kernel
233
+ case "exponential":
234
+ return exponential_kernel
235
+ case "multivariate_gaussian":
236
+ return multivariate_gaussian_kernel
237
+ case "sinc":
238
+ return sinc_kernel
239
+ case _:
240
+ assert callable(
241
+ kernel
242
+ ), f"""
243
+ --------------------------
244
+ Unknown kernel {kernel}.
245
+ --------------------------
246
+ Custom kernel has to be callable,
247
+ (x:LazyTensor(n,1,D),y:LazyTensor(1,m,D),bandwidth:float) ---> kernel matrix
248
+
249
+ Valid operations are given here:
250
+ https://www.kernel-operations.io/keops/python/api/index.html
251
+ """
252
+ return kernel
253
+
254
+
255
+ # TODO : multiple bandwidths at once with lazy tensors
256
class KDE:
    """
    Fast, scikit-style, and differentiable kernel density estimation, using PyKeops.

    Accepts numpy arrays or torch tensors. At a query point ``y`` the estimate
    is ``sum_i w_i * kernel(x_i, y, bandwidth) / n``, where ``n`` is the number
    of fitted points and ``w_i`` are the optional sample weights.
    """

    def __init__(
        self,
        bandwidth: Any = 1,
        kernel: available_kernels = "gaussian",
        return_log: bool = False,
    ):
        """
        bandwidth : numeric
            Forwarded as the third argument of the kernel (for
            "multivariate_gaussian", the inverse covariance matrix).
        kernel : a kernel name understood by `_kernel`, or a custom callable.
        return_log : bool
            If True, `score_samples` returns the log of the estimate.
        """
        self.X = None
        self.bandwidth = bandwidth
        self.kernel: available_kernels = kernel
        self._kernel = None  # resolved kernel function, set by `fit`
        self._backend = None  # numpy or torch module, set by `fit`
        self._sample_weights = None
        self.return_log = return_log

    def fit(self, X, sample_weights=None, y=None):
        """
        Stores the points `X` (numpy array or torch tensor) and the optional
        `sample_weights`, then resolves the kernel. `y` is ignored.

        Raises
        ------
        Exception
            If `X` is neither a numpy array nor a torch tensor.
        """
        self.X = X
        self._sample_weights = sample_weights
        if isinstance(X, np.ndarray):
            self._backend = np
        else:
            import torch

            if isinstance(X, torch.Tensor):
                self._backend = torch
            else:
                raise Exception("Unsupported backend.")
        self._kernel = _kernel(self.kernel)
        return self

    @staticmethod
    def to_lazy(X, Y, x_weights):
        """
        Wraps `X` (n, d) as a LazyTensor of shape (n, 1, d) and `Y` (m, d) as
        one of shape (1, m, d), cast to the dtype of `X`. The optional
        `x_weights` (n,) are wrapped as a LazyTensor indexed along axis 0.
        Returns ``(lazy_x, lazy_y, lazy_weights_or_None)``.

        Raises
        ------
        Exception
            If `X` is neither a numpy array nor a torch tensor.
        """
        if isinstance(X, np.ndarray):
            from pykeops.numpy import LazyTensor

            lazy_x = LazyTensor(
                X.reshape((X.shape[0], 1, X.shape[1]))
            )  # numpts, 1, dim
            lazy_y = LazyTensor(
                Y.reshape((1, Y.shape[0], Y.shape[1])).astype(X.dtype)
            )  # 1, numpts, dim
            if x_weights is not None:
                w = LazyTensor(np.asarray(x_weights, dtype=X.dtype)[:, None], axis=0)
                return lazy_x, lazy_y, w
            return lazy_x, lazy_y, None
        import torch

        if isinstance(X, torch.Tensor):
            from pykeops.torch import LazyTensor

            lazy_x = LazyTensor(X.view(X.shape[0], 1, X.shape[1]))
            lazy_y = LazyTensor(Y.type(X.dtype).view(1, Y.shape[0], Y.shape[1]))
            if x_weights is not None:
                if isinstance(x_weights, np.ndarray):
                    x_weights = torch.from_numpy(x_weights)
                w = LazyTensor(x_weights[:, None].type(X.dtype), axis=0)
                return lazy_x, lazy_y, w
            return lazy_x, lazy_y, None
        raise Exception("Bad tensor type.")

    def score_samples(self, Y, X=None, return_kernel=False):
        """Returns the kernel density estimates of each point in `Y`.

        Parameters
        ----------
        Y : tensor (m, d)
            `m` points with `d` dimensions at which the density is evaluated.
        X : tensor (n, d), optional
            Points to use instead of the fitted ones. The sample weights
            given to `fit` are still used, so their length must match `X`.
        return_kernel : bool
            If True, return the (weighted) lazy kernel matrix instead of
            its reduction.

        Returns
        -------
        tensor (m)
            The density estimates, or their log if `return_log` is set.
        """
        assert self._backend is not None and self._kernel is not None, "Fit first."
        X = self.X if X is None else X
        if X.shape[0] == 0:
            # No points: the density is 0 everywhere, so its log is -inf.
            if self.return_log:
                return self._backend.full((Y.shape[0],), -np.inf)
            return self._backend.zeros((Y.shape[0]))
        assert Y.shape[1] == X.shape[1] and X.ndim == Y.ndim == 2
        lazy_x, lazy_y, w = self.to_lazy(X, Y, x_weights=self._sample_weights)
        kernel = self._kernel(lazy_x, lazy_y, self.bandwidth)
        if w is not None:
            kernel *= w
        if return_kernel:
            return kernel
        # Reduce over the fitted points (axis 0).
        density_estimation = kernel.sum(dim=0)
        if density_estimation.ndim == 2 and density_estimation.shape[-1] == 1:
            # Drop only the trailing kernel axis. A bare `squeeze()` would also
            # collapse the query axis when a single point is queried.
            density_estimation = density_estimation[..., 0]
        density_estimation = density_estimation / kernel.shape[0]  # mean
        return (
            self._backend.log(density_estimation)
            if self.return_log
            else density_estimation
        )
360
+
361
+
362
def batch_signed_measure_convolutions(
    signed_measures,  # array of shape (num_data, num_pts, D+1), last coord = weights
    x,  # array of shape (num_x, D) or (num_data, num_x, D)
    bandwidth,  # either float or matrix if multivariate kernel
    kernel: available_kernels,
):
    """
    Input
    -----
    - signed_measures: unragged, of shape (num_data, num_pts, D+1)
        where last coord is weights, (0 for dummy points).
        A single measure of shape (num_pts, D+1) is also accepted.
    - x : the points to convolve (num_x, D), or (num_data, num_x, D)
    - bandwidth : the bandwidths or covariance matrix inverse or ... of the kernel
    - kernel : "gaussian", "multivariate_gaussian", "exponential", or Callable (x_i, y_i, bandwidth)->float

    Works with numpy arrays (pykeops.numpy) or torch tensors (pykeops.torch).

    Output
    ------
    Array of shape (num_convolutions, (num_axis), num_data, max_x_size)
    """
    if signed_measures.ndim == 2:
        # single measure: add a leading num_data axis of size 1
        signed_measures = signed_measures[None, :, :]
    sms = signed_measures[..., :-1]
    weights = signed_measures[..., -1]
    if isinstance(signed_measures, np.ndarray):
        from pykeops.numpy import LazyTensor

        # numpy arrays have no `.contiguous()` method
        contiguous = np.ascontiguousarray
    else:
        import torch

        assert isinstance(signed_measures, torch.Tensor)
        from pykeops.torch import LazyTensor

        def contiguous(t):
            return t.contiguous()

    _sms = LazyTensor(contiguous(sms[..., None, :]))
    _x = contiguous(x[..., None, :, :])

    sms_kernel = _kernel(kernel)(_sms, _x, bandwidth)
    # weighted sum over the points of each measure
    out = (sms_kernel * contiguous(weights[..., None, None])).sum(
        signed_measures.ndim - 2
    )
    assert out.shape[-1] == 1, "Pykeops bug fixed, TODO : refix this "
    out = out[..., 0]  ## pykeops bug + ensures its a tensor
    return out
405
+
406
+
407
class DTM:
    """
    Distance To Measure.

    For each mass ``m`` in `masses`, the DTM of a query point ``y`` is the
    root mean squared distance from ``y`` to its ``k`` nearest fitted points,
    where ``k = min(int(m * n) + 1, n)`` and ``n`` is the number of fitted points.
    """

    def __init__(self, masses, metric: str = "euclidean", **_kdtree_kwargs):
        """
        masses : sequence of float in (0,1]
            The mass thresholds; one DTM is computed per mass.
        metric :
            The distance between points to consider (passed to sklearn's KDTree).
        **_kdtree_kwargs :
            Extra keyword arguments passed to sklearn's KDTree.
        """
        self.masses = masses
        self.metric = metric
        self._kdtree_kwargs = _kdtree_kwargs
        self._ks = None
        self._kdtree = None
        self._X = None
        self._backend = None

    def fit(self, X, sample_weights=None, y=None):
        """
        Builds a KDTree on `X` (numpy array or torch tensor). A torch `X` is
        detached for the tree but kept as-is for `score_samples_diff`.
        `sample_weights` and `y` are ignored.
        """
        if len(self.masses) == 0:
            return self
        assert np.max(self.masses) <= 1, "All masses should be in (0,1]."
        from sklearn.neighbors import KDTree

        if not isinstance(X, np.ndarray):
            import torch

            assert isinstance(X, torch.Tensor), "Backend has to be numpy of torch"
            _X = X.detach().cpu().numpy()
            self._backend = "torch"
        else:
            _X = X
            self._backend = "numpy"
        num_points = X.shape[0]
        # Cap at n: with mass == 1, int(n) + 1 = n + 1 neighbours would exceed
        # the number of points the tree can return.
        self._ks = np.array(
            [min(int(mass * num_points) + 1, num_points) for mass in self.masses]
        )
        self._kdtree = KDTree(_X, metric=self.metric, **self._kdtree_kwargs)
        self._X = X
        return self

    def score_samples(self, Y, X=None):
        """Returns the DTMs of each point in `Y`, for each mass in `masses`.

        Parameters
        ----------
        Y : tensor (m, d)
            `m` query points with `d` dimensions.
        X : ignored.

        Returns
        -------
        np.ndarray of shape (len(masses), m)
        """
        if len(self.masses) == 0:
            return np.empty((0, len(Y)))
        assert (
            self._ks is not None and self._kdtree is not None and self._X is not None
        ), f"Fit first. Got {self._ks=}, {self._kdtree=}, {self._X=}."
        assert Y.ndim == 2
        if self._backend == "torch":
            _Y = Y.detach().cpu().numpy()
        else:
            _Y = Y
        NN_Dist, NN = self._kdtree.query(_Y, self._ks.max(), return_distance=True)
        squared_dists = NN_Dist**2  # sorted by increasing distance
        DTMs = np.array([squared_dists[:, :k].mean(1) ** 0.5 for k in self._ks])
        return DTMs

    def score_samples_diff(self, Y):
        """Differentiable (torch) version of `score_samples`.

        The neighbours are found with the KDTree (without gradient). The
        distances are then recomputed in torch from the fitted `X`, so that
        gradients flow to `Y` and to the fitted points.

        Parameters
        ----------
        Y : torch tensor (m, d)

        Returns
        -------
        tuple of torch tensors of shape (m,), one per mass in `masses`
        (``torch.empty(0, len(Y))`` if `masses` is empty).
        """
        import torch

        if len(self.masses) == 0:
            return torch.empty(0, len(Y))

        assert Y.ndim == 2
        assert self._backend == "torch", "Use the non-diff version with numpy."
        assert (
            self._ks is not None and self._kdtree is not None and self._X is not None
        ), f"Fit first. Got {self._ks=}, {self._kdtree=}, {self._X=}."
        NN = self._kdtree.query(
            Y.detach().cpu().numpy(), self._ks.max(), return_distance=False
        )
        DTMs = tuple(
            (((self._X[NN] - Y[:, None, :]) ** 2)[:, :k].sum(dim=(1, 2)) / k) ** 0.5
            for k in self._ks
        )  # TODO : kdtree already computes distance, find implementation of kdtree that is pytorch differentiable
        return DTMs
510
+
511
+
512
+ ## code taken from pykeops doc (https://www.kernel-operations.io/keops/_auto_benchmarks/benchmark_KNN.html)
513
+ class KNNmean:
514
+ def __init__(self, k: int, metric: str = "euclidean"):
515
+ self.k = k
516
+ self.metric = metric
517
+ self._KNN_fun = None
518
+ self._x = None
519
+
520
+ def fit(self, x):
521
+ if isinstance(x, np.ndarray):
522
+ from pykeops.numpy import Vi, Vj
523
+ else:
524
+ import torch
525
+
526
+ assert isinstance(x, torch.Tensor), "Backend has to be numpy or torch"
527
+ from pykeops.torch import Vi, Vj
528
+
529
+ D = x.shape[1]
530
+ X_i = Vi(0, D)
531
+ X_j = Vj(1, D)
532
+
533
+ # Symbolic distance matrix:
534
+ if self.metric == "euclidean":
535
+ D_ij = ((X_i - X_j) ** 2).sum(-1) ** (1/2)
536
+ elif self.metric == "manhattan":
537
+ D_ij = (X_i - X_j).abs().sum(-1)
538
+ elif self.metric == "angular":
539
+ D_ij = -(X_i | X_j)
540
+ elif self.metric == "hyperbolic":
541
+ D_ij = ((X_i - X_j) ** 2).sum(-1) / (X_i[0] * X_j[0])
542
+ else:
543
+ raise NotImplementedError(f"The '{self.metric}' distance is not supported.")
544
+
545
+ self._x = x
546
+ self._KNN_fun = D_ij.Kmin(self.k, dim=1)
547
+ return self
548
+
549
+ def score_samples(self, x):
550
+ assert self._x is not None and self._KNN_fun is not None, "Fit first."
551
+ return self._KNN_fun(x, self._x).sum(axis=1) / self.k
552
+
553
+
554
+ # def _pts_convolution_sparse(pts:np.ndarray, pts_weights:np.ndarray, filtration_grid:Iterable[np.ndarray], kernel="gaussian", bandwidth=0.1, **more_kde_args):
555
+ # """
556
+ # Old version of `convolution_signed_measures`. Scikitlearn's convolution is slower than the code above.
557
+ # """
558
+ # from sklearn.neighbors import KernelDensity
559
+ # grid_iterator = np.asarray(list(product(*filtration_grid)))
560
+ # grid_shape = [len(f) for f in filtration_grid]
561
+ # if len(pts) == 0:
562
+ # # warn("Found a trivial signed measure !")
563
+ # return np.zeros(shape=grid_shape)
564
+ # kde = KernelDensity(kernel=kernel, bandwidth=bandwidth, rtol = 1e-4, **more_kde_args) # TODO : check rtol
565
+
566
+ # pos_indices = pts_weights>0
567
+ # neg_indices = pts_weights<0
568
+ # img_pos = kde.fit(pts[pos_indices], sample_weight=pts_weights[pos_indices]).score_samples(grid_iterator).reshape(grid_shape)
569
+ # img_neg = kde.fit(pts[neg_indices], sample_weight=-pts_weights[neg_indices]).score_samples(grid_iterator).reshape(grid_shape)
570
+ # return np.exp(img_pos) - np.exp(img_neg)
571
+
572
+
573
+ # Precompiles the convolution
574
+ # _test(r=2,b=.5, plot=False)