multipers 2.3.1__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic. Click here for more details.

Files changed (182) hide show
  1. multipers/__init__.py +33 -0
  2. multipers/_signed_measure_meta.py +430 -0
  3. multipers/_slicer_meta.py +211 -0
  4. multipers/data/MOL2.py +458 -0
  5. multipers/data/UCR.py +18 -0
  6. multipers/data/__init__.py +1 -0
  7. multipers/data/graphs.py +466 -0
  8. multipers/data/immuno_regions.py +27 -0
  9. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  10. multipers/data/pytorch2simplextree.py +91 -0
  11. multipers/data/shape3d.py +101 -0
  12. multipers/data/synthetic.py +113 -0
  13. multipers/distances.py +198 -0
  14. multipers/filtration_conversions.pxd +229 -0
  15. multipers/filtration_conversions.pxd.tp +84 -0
  16. multipers/filtrations/__init__.py +18 -0
  17. multipers/filtrations/density.py +563 -0
  18. multipers/filtrations/filtrations.py +289 -0
  19. multipers/filtrations.pxd +224 -0
  20. multipers/function_rips.cp313-win_amd64.pyd +0 -0
  21. multipers/function_rips.pyx +105 -0
  22. multipers/grids.cp313-win_amd64.pyd +0 -0
  23. multipers/grids.pyx +350 -0
  24. multipers/gudhi/Persistence_slices_interface.h +132 -0
  25. multipers/gudhi/Simplex_tree_interface.h +239 -0
  26. multipers/gudhi/Simplex_tree_multi_interface.h +516 -0
  27. multipers/gudhi/cubical_to_boundary.h +59 -0
  28. multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -0
  29. multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -0
  30. multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -0
  31. multipers/gudhi/gudhi/Debug_utils.h +45 -0
  32. multipers/gudhi/gudhi/Fields/Multi_field.h +484 -0
  33. multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -0
  34. multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -0
  35. multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -0
  36. multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -0
  37. multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -0
  38. multipers/gudhi/gudhi/Fields/Z2_field.h +355 -0
  39. multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -0
  40. multipers/gudhi/gudhi/Fields/Zp_field.h +420 -0
  41. multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -0
  42. multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -0
  43. multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -0
  44. multipers/gudhi/gudhi/Matrix.h +2107 -0
  45. multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -0
  46. multipers/gudhi/gudhi/Multi_persistence/Box.h +171 -0
  47. multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -0
  48. multipers/gudhi/gudhi/Off_reader.h +173 -0
  49. multipers/gudhi/gudhi/One_critical_filtration.h +1433 -0
  50. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -0
  51. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -0
  52. multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -0
  53. multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -0
  54. multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -0
  55. multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -0
  56. multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -0
  57. multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -0
  58. multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -0
  59. multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -0
  60. multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -0
  61. multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -0
  62. multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -0
  63. multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -0
  64. multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -0
  65. multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -0
  66. multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -0
  67. multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -0
  68. multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -0
  69. multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -0
  70. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -0
  71. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -0
  72. multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -0
  73. multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -0
  74. multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -0
  75. multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -0
  76. multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -0
  77. multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -0
  78. multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -0
  79. multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -0
  80. multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -0
  81. multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -0
  82. multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -0
  83. multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -0
  84. multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -0
  85. multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -0
  86. multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -0
  87. multipers/gudhi/gudhi/Persistent_cohomology.h +745 -0
  88. multipers/gudhi/gudhi/Points_off_io.h +171 -0
  89. multipers/gudhi/gudhi/Simple_object_pool.h +69 -0
  90. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -0
  91. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -0
  92. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -0
  93. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -0
  94. multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -0
  95. multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -0
  96. multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -0
  97. multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -0
  98. multipers/gudhi/gudhi/Simplex_tree.h +2794 -0
  99. multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -0
  100. multipers/gudhi/gudhi/distance_functions.h +62 -0
  101. multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -0
  102. multipers/gudhi/gudhi/persistence_interval.h +253 -0
  103. multipers/gudhi/gudhi/persistence_matrix_options.h +170 -0
  104. multipers/gudhi/gudhi/reader_utils.h +367 -0
  105. multipers/gudhi/mma_interface_coh.h +256 -0
  106. multipers/gudhi/mma_interface_h0.h +223 -0
  107. multipers/gudhi/mma_interface_matrix.h +291 -0
  108. multipers/gudhi/naive_merge_tree.h +536 -0
  109. multipers/gudhi/scc_io.h +310 -0
  110. multipers/gudhi/truc.h +957 -0
  111. multipers/io.cp313-win_amd64.pyd +0 -0
  112. multipers/io.pyx +714 -0
  113. multipers/ml/__init__.py +0 -0
  114. multipers/ml/accuracies.py +90 -0
  115. multipers/ml/invariants_with_persistable.py +79 -0
  116. multipers/ml/kernels.py +176 -0
  117. multipers/ml/mma.py +713 -0
  118. multipers/ml/one.py +472 -0
  119. multipers/ml/point_clouds.py +352 -0
  120. multipers/ml/signed_measures.py +1589 -0
  121. multipers/ml/sliced_wasserstein.py +461 -0
  122. multipers/ml/tools.py +113 -0
  123. multipers/mma_structures.cp313-win_amd64.pyd +0 -0
  124. multipers/mma_structures.pxd +127 -0
  125. multipers/mma_structures.pyx +2742 -0
  126. multipers/mma_structures.pyx.tp +1083 -0
  127. multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -0
  128. multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -0
  129. multipers/multi_parameter_rank_invariant/function_rips.h +322 -0
  130. multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -0
  131. multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -0
  132. multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -0
  133. multipers/multiparameter_edge_collapse.py +41 -0
  134. multipers/multiparameter_module_approximation/approximation.h +2298 -0
  135. multipers/multiparameter_module_approximation/combinatory.h +129 -0
  136. multipers/multiparameter_module_approximation/debug.h +107 -0
  137. multipers/multiparameter_module_approximation/euler_curves.h +0 -0
  138. multipers/multiparameter_module_approximation/format_python-cpp.h +286 -0
  139. multipers/multiparameter_module_approximation/heap_column.h +238 -0
  140. multipers/multiparameter_module_approximation/images.h +79 -0
  141. multipers/multiparameter_module_approximation/list_column.h +174 -0
  142. multipers/multiparameter_module_approximation/list_column_2.h +232 -0
  143. multipers/multiparameter_module_approximation/ru_matrix.h +347 -0
  144. multipers/multiparameter_module_approximation/set_column.h +135 -0
  145. multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -0
  146. multipers/multiparameter_module_approximation/unordered_set_column.h +166 -0
  147. multipers/multiparameter_module_approximation/utilities.h +403 -0
  148. multipers/multiparameter_module_approximation/vector_column.h +223 -0
  149. multipers/multiparameter_module_approximation/vector_matrix.h +331 -0
  150. multipers/multiparameter_module_approximation/vineyards.h +464 -0
  151. multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -0
  152. multipers/multiparameter_module_approximation.cp313-win_amd64.pyd +0 -0
  153. multipers/multiparameter_module_approximation.pyx +218 -0
  154. multipers/pickle.py +90 -0
  155. multipers/plots.py +342 -0
  156. multipers/point_measure.cp313-win_amd64.pyd +0 -0
  157. multipers/point_measure.pyx +322 -0
  158. multipers/simplex_tree_multi.cp313-win_amd64.pyd +0 -0
  159. multipers/simplex_tree_multi.pxd +133 -0
  160. multipers/simplex_tree_multi.pyx +10402 -0
  161. multipers/simplex_tree_multi.pyx.tp +1947 -0
  162. multipers/slicer.cp313-win_amd64.pyd +0 -0
  163. multipers/slicer.pxd +2552 -0
  164. multipers/slicer.pxd.tp +218 -0
  165. multipers/slicer.pyx +16530 -0
  166. multipers/slicer.pyx.tp +931 -0
  167. multipers/tbb12.dll +0 -0
  168. multipers/tbbbind_2_5.dll +0 -0
  169. multipers/tbbmalloc.dll +0 -0
  170. multipers/tbbmalloc_proxy.dll +0 -0
  171. multipers/tensor/tensor.h +672 -0
  172. multipers/tensor.pxd +13 -0
  173. multipers/test.pyx +44 -0
  174. multipers/tests/__init__.py +57 -0
  175. multipers/torch/__init__.py +1 -0
  176. multipers/torch/diff_grids.py +217 -0
  177. multipers/torch/rips_density.py +310 -0
  178. multipers-2.3.1.dist-info/LICENSE +21 -0
  179. multipers-2.3.1.dist-info/METADATA +144 -0
  180. multipers-2.3.1.dist-info/RECORD +182 -0
  181. multipers-2.3.1.dist-info/WHEEL +5 -0
  182. multipers-2.3.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,563 @@
1
from collections.abc import Callable, Iterable
from typing import Any, Literal, Union

import numpy as np

# Type alias: either one of the built-in kernel names understood by `_kernel`,
# or a custom callable (x_i, y_j, bandwidth) -> kernel matrix.
# NOTE: the original `global available_kernels` statement was a no-op at
# module scope and has been removed.
available_kernels = Union[
    Literal[
        "gaussian", "exponential", "exponential_kernel", "multivariate_gaussian", "sinc"
    ],
    Callable,
]
13
+
14
+
15
+ def convolution_signed_measures(
16
+ iterable_of_signed_measures,
17
+ filtrations,
18
+ bandwidth,
19
+ flatten: bool = True,
20
+ n_jobs: int = 1,
21
+ backend="pykeops",
22
+ kernel: available_kernels = "gaussian",
23
+ **kwargs,
24
+ ):
25
+ """
26
+ Evaluates the convolution of the signed measures Iterable(pts, weights) with a gaussian measure of bandwidth bandwidth, on a grid given by the filtrations
27
+
28
+ Parameters
29
+ ----------
30
+
31
+ - iterable_of_signed_measures : (num_signed_measure) x [ (npts) x (num_parameters), (npts)]
32
+ - filtrations : (num_parameter) x (filtration values)
33
+ - flatten : bool
34
+ - n_jobs : int
35
+
36
+ Outputs
37
+ -------
38
+
39
+ The concatenated images, for each signed measure (num_signed_measures) x (len(f) for f in filtration_values)
40
+ """
41
+ from multipers.grids import todense
42
+
43
+ grid_iterator = todense(filtrations, product_order=True)
44
+ match backend:
45
+ case "sklearn":
46
+
47
+ def convolution_signed_measures_on_grid(
48
+ signed_measures: Iterable[tuple[np.ndarray, np.ndarray]],
49
+ ):
50
+ return np.concatenate(
51
+ [
52
+ _pts_convolution_sparse_old(
53
+ pts=pts,
54
+ pts_weights=weights,
55
+ grid_iterator=grid_iterator,
56
+ bandwidth=bandwidth,
57
+ kernel=kernel,
58
+ **kwargs,
59
+ )
60
+ for pts, weights in signed_measures
61
+ ],
62
+ axis=0,
63
+ )
64
+
65
+ case "pykeops":
66
+
67
+ def convolution_signed_measures_on_grid(
68
+ signed_measures: Iterable[tuple[np.ndarray, np.ndarray]],
69
+ ) -> np.ndarray:
70
+ return np.concatenate(
71
+ [
72
+ _pts_convolution_pykeops(
73
+ pts=pts,
74
+ pts_weights=weights,
75
+ grid_iterator=grid_iterator,
76
+ bandwidth=bandwidth,
77
+ kernel=kernel,
78
+ **kwargs,
79
+ )
80
+ for pts, weights in signed_measures
81
+ ],
82
+ axis=0,
83
+ )
84
+
85
+ # compiles first once
86
+ pts, weights = iterable_of_signed_measures[0][0]
87
+ small_pts, small_weights = pts[:2], weights[:2]
88
+
89
+ _pts_convolution_pykeops(
90
+ small_pts,
91
+ small_weights,
92
+ grid_iterator=grid_iterator,
93
+ bandwidth=bandwidth,
94
+ kernel=kernel,
95
+ **kwargs,
96
+ )
97
+
98
+ if n_jobs > 1 or n_jobs == -1:
99
+ prefer = "processes" if backend == "sklearn" else "threads"
100
+ from joblib import Parallel, delayed
101
+
102
+ convolutions = Parallel(n_jobs=n_jobs, prefer=prefer)(
103
+ delayed(convolution_signed_measures_on_grid)(sms)
104
+ for sms in iterable_of_signed_measures
105
+ )
106
+ else:
107
+ convolutions = [
108
+ convolution_signed_measures_on_grid(sms)
109
+ for sms in iterable_of_signed_measures
110
+ ]
111
+ if not flatten:
112
+ out_shape = [-1] + [len(f) for f in filtrations] # Degree
113
+ convolutions = [x.reshape(out_shape) for x in convolutions]
114
+ return np.asarray(convolutions)
115
+
116
+
117
+ # def _test(r=1000, b=0.5, plot=True, kernel=0):
118
+ # import matplotlib.pyplot as plt
119
+ # pts, weigths = np.array([[1.,1.], [1.1,1.1]]), np.array([1,-1])
120
+ # pt_list = np.array(list(product(*[np.linspace(0,2,r)]*2)))
121
+ # img = _pts_convolution_sparse_pts(pts,weigths, pt_list,b,kernel=kernel)
122
+ # if plot:
123
+ # plt.imshow(img.reshape(r,-1).T, origin="lower")
124
+ # plt.show()
125
+
126
+
127
def _pts_convolution_sparse_old(
    pts: np.ndarray,
    pts_weights: np.ndarray,
    grid_iterator,
    kernel: available_kernels = "gaussian",
    bandwidth=0.1,
    **more_kde_args,
):
    """
    Old version of `convolution_signed_measures`. Scikitlearn's convolution is slower than the code above.

    Evaluates (positive part) - (negative part) of the signed measure's kernel
    density on `grid_iterator`; an empty measure yields zeros.
    """
    from sklearn.neighbors import KernelDensity

    if len(pts) == 0:
        # warn("Found a trivial signed measure !")
        return np.zeros(len(grid_iterator))
    kde = KernelDensity(
        kernel=kernel, bandwidth=bandwidth, rtol=1e-4, **more_kde_args
    )  # TODO : check rtol
    pos_indices = pts_weights > 0
    neg_indices = pts_weights < 0
    # BUGFIX: exponentiate only inside the non-empty branch. score_samples
    # returns *log*-densities; the previous code np.exp'd a zero array when
    # one side of the measure was empty, wrongly contributing a constant 1.
    img_pos = (
        np.zeros(len(grid_iterator))
        if pos_indices.sum() == 0
        else np.exp(
            kde.fit(
                pts[pos_indices], sample_weight=pts_weights[pos_indices]
            ).score_samples(grid_iterator)
        )
    )
    img_neg = (
        np.zeros(len(grid_iterator))
        if neg_indices.sum() == 0
        else np.exp(
            kde.fit(
                pts[neg_indices], sample_weight=-pts_weights[neg_indices]
            ).score_samples(grid_iterator)
        )
    )
    return img_pos - img_neg
163
+
164
+
165
def _pts_convolution_pykeops(
    pts: np.ndarray,
    pts_weights: np.ndarray,
    grid_iterator,
    kernel: available_kernels = "gaussian",
    bandwidth=0.1,
    **more_kde_args,
):
    """
    Evaluate the weighted kernel density of (pts, pts_weights) on
    `grid_iterator`, via the pykeops-backed `KDE` estimator.
    """
    dtype = pts.dtype
    estimator = KDE(kernel=kernel, bandwidth=bandwidth, **more_kde_args)
    estimator.fit(pts, sample_weights=np.asarray(pts_weights, dtype=dtype))
    return estimator.score_samples(np.asarray(grid_iterator, dtype=dtype))
180
+
181
+
182
def gaussian_kernel(x_i, y_j, bandwidth):
    """Isotropic Gaussian kernel between two (lazy) tensors.

    Normalized by (bandwidth * sqrt(2*pi))**D, D being the last-axis dimension.
    """
    dim = x_i.shape[-1]
    # the float() conversion is necessary (pykeops fails otherwise)
    normalization = float((bandwidth * np.sqrt(2 * np.pi)) ** dim)
    scaled_diff = (x_i - y_j) / bandwidth
    log_kernel = -(scaled_diff**2).sum(dim=-1) / 2
    return log_kernel.exp() / normalization
188
+
189
+
190
def multivariate_gaussian_kernel(x_i, y_j, covariance_matrix_inverse):
    """Anisotropic Gaussian kernel with a full inverse covariance matrix.

    Computes (2*pi)^(-d/2) * sqrt(det(Sigma^-1)) * exp(-(x-y)^T Sigma^-1 (x-y) / 2).

    See:
    https://www.kernel-operations.io/keops/_auto_examples/pytorch/plot_anisotropic_kernels.html#sphx-glr-auto-examples-pytorch-plot-anisotropic-kernels-py
    https://www.kernel-operations.io/keops/api/math-operations.html
    """
    dim = x_i.shape[-1]
    difference = x_i - y_j
    # Quadratic form (x-y)^T Sigma^-1 (x-y), computed lazily by pykeops.
    quadratic_form = difference.weightedsqnorm(covariance_matrix_inverse.flatten())
    exponent = -(quadratic_form / 2)
    constant = float((2 * np.pi) ** (-dim / 2))
    return constant * (covariance_matrix_inverse.det().sqrt()) * exponent.exp()
202
+
203
+
204
def exponential_kernel(x_i, y_j, bandwidth):
    """Exponential kernel: exp(-||x - y|| / bandwidth) / bandwidth.

    BUGFIX: the original wrote `** 1 / 2`, which Python parses as
    `(.. ** 1) / 2` — half the *squared* distance — instead of the square
    root that the intended formula (norm of x - y) requires.
    """
    distance = (((x_i - y_j) ** 2).sum(dim=-1)) ** 0.5
    exponent = -(distance / bandwidth)
    return exponent.exp() / bandwidth
209
+
210
+
211
def sinc_kernel(x_i, y_j, bandwidth):
    """Sinc-based kernel: 2*sinc(2r) - sinc(r), with r = ||x - y|| / bandwidth.

    BUGFIX: the original wrote `** 1 / 2`, which Python parses as
    `(.. ** 1) / 2` — half the *squared* distance — instead of the square
    root needed for the norm.
    """
    norm = ((((x_i - y_j) ** 2).sum(dim=-1)) ** 0.5) / bandwidth
    # Resolve sinc on the tensor's own type (torch.Tensor / LazyTensor).
    sinc = type(x_i).sinc
    return 2 * sinc(2 * norm) - sinc(norm)
216
+
217
+
218
+ def _kernel(
219
+ kernel: available_kernels = "gaussian",
220
+ ):
221
+ match kernel:
222
+ case "gaussian":
223
+ return gaussian_kernel
224
+ case "exponential":
225
+ return exponential_kernel
226
+ case "multivariate_gaussian":
227
+ return multivariate_gaussian_kernel
228
+ case "sinc":
229
+ return sinc_kernel
230
+ case _:
231
+ assert callable(
232
+ kernel
233
+ ), f"""
234
+ --------------------------
235
+ Unknown kernel {kernel}.
236
+ --------------------------
237
+ Custom kernel has to be callable,
238
+ (x:LazyTensor(n,1,D),y:LazyTensor(1,m,D),bandwidth:float) ---> kernel matrix
239
+
240
+ Valid operations are given here:
241
+ https://www.kernel-operations.io/keops/python/api/index.html
242
+ """
243
+ return kernel
244
+
245
+
246
+ # TODO : multiple bandwidths at once with lazy tensors
247
class KDE:
    """
    Fast, scikit-style, and differentiable kernel density estimation, using PyKeops.

    Supports numpy arrays and torch tensors; the backend is detected in `fit`.
    """

    def __init__(
        self,
        bandwidth: Any = 1,
        kernel: available_kernels = "gaussian",
        return_log: bool = False,
    ):
        """
        bandwidth : numeric
            bandwidth for Gaussian kernel; forwarded verbatim to the kernel
            function (may be e.g. an inverse covariance matrix for the
            "multivariate_gaussian" kernel).
        kernel :
            built-in kernel name or custom callable, see `available_kernels`.
        return_log : bool
            if True, `score_samples` returns the log of the density estimate.
        """
        self.X = None  # fitted points, set by `fit`
        self.bandwidth = bandwidth
        self.kernel: available_kernels = kernel
        self._kernel = None  # resolved kernel callable, set by `fit`
        self._backend = None  # the `np` module or the `torch` module
        self._sample_weights = None  # optional per-point weights
        self.return_log = return_log

    def fit(self, X, sample_weights=None, y=None):
        """Store the sample points/weights and detect the array backend.

        `y` is ignored (kept for scikit-learn API compatibility).
        Raises for inputs that are neither numpy arrays nor torch tensors.
        """
        self.X = X
        self._sample_weights = sample_weights
        if isinstance(X, np.ndarray):
            self._backend = np
        else:
            import torch

            if isinstance(X, torch.Tensor):
                self._backend = torch
            else:
                raise Exception("Unsupported backend.")
        self._kernel = _kernel(self.kernel)
        return self

    @staticmethod
    def to_lazy(X, Y, x_weights):
        """Wrap X (n, d) and Y (m, d) into pykeops LazyTensors shaped
        (n, 1, d) and (1, m, d) for pairwise reduction, plus the optional
        weights as an (n, 1) LazyTensor on axis 0.

        Returns (lazy_x, lazy_y, lazy_weights_or_None).
        """
        if isinstance(X, np.ndarray):
            from pykeops.numpy import LazyTensor

            lazy_x = LazyTensor(
                X.reshape((X.shape[0], 1, X.shape[1]))
            )  # numpts, 1, dim
            lazy_y = LazyTensor(
                Y.reshape((1, Y.shape[0], Y.shape[1]))
            )  # 1, numpts, dim
            if x_weights is not None:
                w = LazyTensor(x_weights[:, None], axis=0)
                return lazy_x, lazy_y, w
            return lazy_x, lazy_y, None
        import torch

        if isinstance(X, torch.Tensor):
            from pykeops.torch import LazyTensor

            lazy_x = LazyTensor(X.view(X.shape[0], 1, X.shape[1]))
            lazy_y = LazyTensor(Y.view(1, Y.shape[0], Y.shape[1]))
            if x_weights is not None:
                w = LazyTensor(x_weights[:, None], axis=0)
                return lazy_x, lazy_y, w
            return lazy_x, lazy_y, None
        raise Exception("Bad tensor type.")

    def score_samples(self, Y, X=None, return_kernel=False):
        """Returns the kernel density estimates of each point in `Y`.

        Parameters
        ----------
        Y : tensor (m, d)
            `m` points with `d` dimensions for which the probability density will
            be calculated
        X : tensor (n, d), optional
            `n` points with `d` dimensions to which KDE will be fit. Provided to
            allow batch calculations in `log_prob`. By default, `X` is None and
            all points used to initialize KernelDensityEstimator are included.
        return_kernel : bool
            if True, return the (weighted) lazy kernel matrix instead of the
            reduced density estimate.

        Returns
        -------
        log_probs : tensor (m)
            density estimates for each of the queried points in `Y`
            (log-densities if `return_log` was set)
        """
        assert self._backend is not None and self._kernel is not None, "Fit first."
        X = self.X if X is None else X
        if X.shape[0] == 0:
            # Empty fit set: the estimate is identically zero.
            return self._backend.zeros((Y.shape[0]))
        assert Y.shape[1] == X.shape[1] and X.ndim == Y.ndim == 2
        lazy_x, lazy_y, w = self.to_lazy(X, Y, x_weights=self._sample_weights)
        kernel = self._kernel(lazy_x, lazy_y, self.bandwidth)
        if w is not None:
            kernel *= w
        if return_kernel:
            return kernel
        # (Weighted) mean over the n fitted points: reduce axis 0 of the
        # (n, m) lazy kernel, then normalize by n.
        density_estimation = kernel.sum(dim=0).squeeze() / kernel.shape[0]  # mean
        return (
            self._backend.log(density_estimation)
            if self.return_log
            else density_estimation
        )
349
+
350
+
351
def batch_signed_measure_convolutions(
    signed_measures,  # array of shape (num_data,num_pts,D)
    x,  # array of shape (num_x, D) or (num_data, num_x, D)
    bandwidth,  # either float or matrix if multivariate kernel
    kernel: available_kernels,
):
    """
    Batched convolution of (padded) signed measures against the points `x`.

    Input
    -----
    - signed_measures: unragged, of shape (num_data, num_pts, D+1)
      where last coord is weights, (0 for dummy/padding points)
    - x : the points to convolve (num_x,D)
    - bandwidth : the bandwidths or covariance matrix inverse or ... of the kernel
    - kernel : "gaussian", "multivariate_gaussian", "exponential", or Callable (x_i, y_i, bandwidth)->float

    Output
    ------
    Array of shape (num_convolutions, (num_axis), num_data, max_x_size)
    """
    if signed_measures.ndim == 2:
        # Single measure: promote to a batch of one.
        signed_measures = signed_measures[None, :, :]
    sms = signed_measures[..., :-1]  # point coordinates (..., num_pts, D)
    weights = signed_measures[..., -1]  # signed weights; 0 marks padding points
    if isinstance(signed_measures, np.ndarray):
        from pykeops.numpy import LazyTensor
    else:
        import torch

        assert isinstance(signed_measures, torch.Tensor)
        from pykeops.torch import LazyTensor

    # NOTE(review): np.ndarray has no `.contiguous()` method — despite the
    # numpy LazyTensor import above, this line would raise AttributeError on
    # numpy input. Verify which backends are actually exercised.
    _sms = LazyTensor(sms[..., None, :].contiguous())
    _x = x[..., None, :, :].contiguous()

    # Kernel matrix between measure points and query points, weighted by the
    # signed weights, then reduced over the points axis (ndim - 2).
    sms_kernel = _kernel(kernel)(_sms, _x, bandwidth)
    out = (sms_kernel * weights[..., None, None].contiguous()).sum(
        signed_measures.ndim - 2
    )
    assert out.shape[-1] == 1, "Pykeops bug fixed, TODO : refix this "
    out = out[..., 0]  ## pykeops bug + ensures its a tensor
    # assert out.shape == (x.shape[0], x.shape[1]), f"{x.shape=}, {out.shape=}"
    return out
394
+
395
+
396
class DTM:
    """
    Distance To Measure

    For each mass `m` in `masses`, the DTM of a query point is the root mean
    squared distance to its k = int(m * n) + 1 nearest fitted points.
    """

    def __init__(self, masses, metric: str = "euclidean", **_kdtree_kwargs):
        """
        masses : iterable of floats in (0,1]
            The mass thresholds (one DTM is computed per mass)
        metric :
            The distance between points to consider
        _kdtree_kwargs :
            forwarded to `sklearn.neighbors.KDTree`
        """
        self.masses = masses
        self.metric = metric
        self._kdtree_kwargs = _kdtree_kwargs
        self._ks = None  # neighbor count per mass, set by `fit`
        self._kdtree = None
        self._X = None  # fitted points (original backend, possibly torch)
        self._backend = None  # "numpy" or "torch"

    def fit(self, X, sample_weights=None, y=None):
        """Build the KDTree on `X`; `sample_weights` and `y` are ignored
        (kept for scikit-learn API compatibility)."""
        if len(self.masses) == 0:
            # Nothing to compute; leave the estimator unfitted.
            return self
        assert np.max(self.masses) <= 1, "All masses should be in (0,1]."
        from sklearn.neighbors import KDTree

        if not isinstance(X, np.ndarray):
            import torch

            # NOTE(review): message typo — should read "numpy or torch".
            assert isinstance(X, torch.Tensor), "Backend has to be numpy of torch"
            # The KDTree needs a detached view; gradients are recovered in
            # `score_samples_diff` by recomputing distances with torch.
            _X = X.detach()
            self._backend = "torch"
        else:
            _X = X
            self._backend = "numpy"
        # Number of nearest neighbors corresponding to each mass threshold.
        self._ks = np.array([int(mass * X.shape[0]) + 1 for mass in self.masses])
        self._kdtree = KDTree(_X, metric=self.metric, **self._kdtree_kwargs)
        self._X = X
        return self

    def score_samples(self, Y, X=None):
        """Returns the distance-to-measure of each point in `Y`.

        Parameters
        ----------
        Y : tensor (m, d)
            `m` points with `d` dimensions for which the DTM will be calculated

        Returns
        -------
        the DTMs of Y, for each mass in masses,
        as an array of shape (len(masses), m).
        """
        # `X` is accepted for API symmetry but not used here.
        if len(self.masses) == 0:
            return np.empty((0, len(Y)))
        assert (
            self._ks is not None and self._kdtree is not None and self._X is not None
        ), f"Fit first. Got {self._ks=}, {self._kdtree=}, {self._X=}."
        assert Y.ndim == 2
        if self._backend == "torch":
            _Y = Y.detach().numpy()
        else:
            _Y = Y
        # Query once with the largest k; smaller masses reuse a prefix of it.
        NN_Dist, NN = self._kdtree.query(_Y, self._ks.max(), return_distance=True)
        # Root mean squared distance over the k nearest neighbors.
        DTMs = np.array([((NN_Dist**2)[:, :k].mean(1)) ** 0.5 for k in self._ks])
        return DTMs

    def score_samples_diff(self, Y):
        """Differentiable (torch) version of `score_samples`.

        Parameters
        ----------
        Y : tensor (m, d)
            `m` points with `d` dimensions for which the DTM will be calculated

        Returns
        -------
        tuple of torch tensors, one DTM tensor of shape (m,) per mass in
        masses; gradients flow through `Y` and the fitted points.
        """
        import torch

        if len(self.masses) == 0:
            return torch.empty(0, len(Y))

        assert Y.ndim == 2
        assert self._backend == "torch", "Use the non-diff version with numpy."
        assert (
            self._ks is not None and self._kdtree is not None and self._X is not None
        ), f"Fit first. Got {self._ks=}, {self._kdtree=}, {self._X=}."
        # Neighbor indices come from the (non-differentiable) kdtree; the
        # distances are then recomputed with torch so autograd can trace them.
        NN = self._kdtree.query(Y.detach(), self._ks.max(), return_distance=False)
        DTMs = tuple(
            (((self._X[NN] - Y[:, None, :]) ** 2)[:, :k].sum(dim=(1, 2)) / k) ** 0.5
            for k in self._ks
        )  # TODO : kdtree already computes distance, find implementation of kdtree that is pytorch differentiable
        return DTMs
499
+
500
+
501
+ ## code taken from pykeops doc (https://www.kernel-operations.io/keops/_auto_benchmarks/benchmark_KNN.html)
502
class KNNmean:
    """Mean distance to the k nearest neighbors, computed symbolically.

    Code adapted from the pykeops documentation:
    https://www.kernel-operations.io/keops/_auto_benchmarks/benchmark_KNN.html
    """

    def __init__(self, k: int, metric: str = "euclidean"):
        self.k = k
        self.metric = metric
        self._KNN_fun = None  # symbolic k-smallest-distances reduction
        self._x = None  # fitted points

    def fit(self, x):
        """Build the symbolic KNN reduction on the fitted points `x`."""
        if isinstance(x, np.ndarray):
            from pykeops.numpy import Vi, Vj
        else:
            import torch

            assert isinstance(x, torch.Tensor), "Backend has to be numpy or torch"
            from pykeops.torch import Vi, Vj

        dim = x.shape[1]
        xi = Vi(0, dim)
        xj = Vj(1, dim)

        # Symbolic pairwise distance, depending on the chosen metric.
        if self.metric == "euclidean":
            dist_ij = ((xi - xj) ** 2).sum(-1)
        elif self.metric == "manhattan":
            dist_ij = (xi - xj).abs().sum(-1)
        elif self.metric == "angular":
            dist_ij = -(xi | xj)
        elif self.metric == "hyperbolic":
            dist_ij = ((xi - xj) ** 2).sum(-1) / (xi[0] * xj[0])
        else:
            raise NotImplementedError(f"The '{self.metric}' distance is not supported.")

        self._x = x
        self._KNN_fun = dist_ij.Kmin(self.k, dim=1)
        return self

    def score_samples(self, x):
        """Average of the k smallest distances from each point of `x` to the fitted points."""
        assert self._x is not None and self._KNN_fun is not None, "Fit first."
        smallest = self._KNN_fun(x, self._x)
        return smallest.sum(axis=1) / self.k
541
+
542
+
543
+ # def _pts_convolution_sparse(pts:np.ndarray, pts_weights:np.ndarray, filtration_grid:Iterable[np.ndarray], kernel="gaussian", bandwidth=0.1, **more_kde_args):
544
+ # """
545
+ # Old version of `convolution_signed_measures`. Scikitlearn's convolution is slower than the code above.
546
+ # """
547
+ # from sklearn.neighbors import KernelDensity
548
+ # grid_iterator = np.asarray(list(product(*filtration_grid)))
549
+ # grid_shape = [len(f) for f in filtration_grid]
550
+ # if len(pts) == 0:
551
+ # # warn("Found a trivial signed measure !")
552
+ # return np.zeros(shape=grid_shape)
553
+ # kde = KernelDensity(kernel=kernel, bandwidth=bandwidth, rtol = 1e-4, **more_kde_args) # TODO : check rtol
554
+
555
+ # pos_indices = pts_weights>0
556
+ # neg_indices = pts_weights<0
557
+ # img_pos = kde.fit(pts[pos_indices], sample_weight=pts_weights[pos_indices]).score_samples(grid_iterator).reshape(grid_shape)
558
+ # img_neg = kde.fit(pts[neg_indices], sample_weight=-pts_weights[neg_indices]).score_samples(grid_iterator).reshape(grid_shape)
559
+ # return np.exp(img_pos) - np.exp(img_neg)
560
+
561
+
562
+ # Precompiles the convolution
563
+ # _test(r=2,b=.5, plot=False)