multipers-2.4.0b1-cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. multipers/.dylibs/libboost_timer.dylib +0 -0
  2. multipers/.dylibs/libc++.1.0.dylib +0 -0
  3. multipers/.dylibs/libtbb.12.17.dylib +0 -0
  4. multipers/__init__.py +33 -0
  5. multipers/_signed_measure_meta.py +426 -0
  6. multipers/_slicer_meta.py +231 -0
  7. multipers/array_api/__init__.py +62 -0
  8. multipers/array_api/numpy.py +124 -0
  9. multipers/array_api/torch.py +133 -0
  10. multipers/data/MOL2.py +458 -0
  11. multipers/data/UCR.py +18 -0
  12. multipers/data/__init__.py +1 -0
  13. multipers/data/graphs.py +466 -0
  14. multipers/data/immuno_regions.py +27 -0
  15. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  16. multipers/data/pytorch2simplextree.py +91 -0
  17. multipers/data/shape3d.py +101 -0
  18. multipers/data/synthetic.py +113 -0
  19. multipers/distances.py +202 -0
  20. multipers/filtration_conversions.pxd +736 -0
  21. multipers/filtration_conversions.pxd.tp +226 -0
  22. multipers/filtrations/__init__.py +21 -0
  23. multipers/filtrations/density.py +529 -0
  24. multipers/filtrations/filtrations.py +480 -0
  25. multipers/filtrations.pxd +534 -0
  26. multipers/filtrations.pxd.tp +332 -0
  27. multipers/function_rips.cpython-312-darwin.so +0 -0
  28. multipers/function_rips.pyx +104 -0
  29. multipers/grids.cpython-312-darwin.so +0 -0
  30. multipers/grids.pyx +538 -0
  31. multipers/gudhi/Persistence_slices_interface.h +213 -0
  32. multipers/gudhi/Simplex_tree_interface.h +274 -0
  33. multipers/gudhi/Simplex_tree_multi_interface.h +648 -0
  34. multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -0
  35. multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -0
  36. multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -0
  37. multipers/gudhi/gudhi/Debug_utils.h +52 -0
  38. multipers/gudhi/gudhi/Degree_rips_bifiltration.h +2307 -0
  39. multipers/gudhi/gudhi/Dynamic_multi_parameter_filtration.h +2524 -0
  40. multipers/gudhi/gudhi/Fields/Multi_field.h +453 -0
  41. multipers/gudhi/gudhi/Fields/Multi_field_operators.h +460 -0
  42. multipers/gudhi/gudhi/Fields/Multi_field_shared.h +444 -0
  43. multipers/gudhi/gudhi/Fields/Multi_field_small.h +584 -0
  44. multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +490 -0
  45. multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +580 -0
  46. multipers/gudhi/gudhi/Fields/Z2_field.h +391 -0
  47. multipers/gudhi/gudhi/Fields/Z2_field_operators.h +389 -0
  48. multipers/gudhi/gudhi/Fields/Zp_field.h +493 -0
  49. multipers/gudhi/gudhi/Fields/Zp_field_operators.h +384 -0
  50. multipers/gudhi/gudhi/Fields/Zp_field_shared.h +492 -0
  51. multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -0
  52. multipers/gudhi/gudhi/Matrix.h +2200 -0
  53. multipers/gudhi/gudhi/Multi_filtration/Multi_parameter_generator.h +1712 -0
  54. multipers/gudhi/gudhi/Multi_filtration/multi_filtration_conversions.h +237 -0
  55. multipers/gudhi/gudhi/Multi_filtration/multi_filtration_utils.h +225 -0
  56. multipers/gudhi/gudhi/Multi_parameter_filtered_complex.h +485 -0
  57. multipers/gudhi/gudhi/Multi_parameter_filtration.h +2643 -0
  58. multipers/gudhi/gudhi/Multi_persistence/Box.h +233 -0
  59. multipers/gudhi/gudhi/Multi_persistence/Line.h +309 -0
  60. multipers/gudhi/gudhi/Multi_persistence/Multi_parameter_filtered_complex_pcoh_interface.h +268 -0
  61. multipers/gudhi/gudhi/Multi_persistence/Persistence_interface_cohomology.h +159 -0
  62. multipers/gudhi/gudhi/Multi_persistence/Persistence_interface_matrix.h +463 -0
  63. multipers/gudhi/gudhi/Multi_persistence/Point.h +853 -0
  64. multipers/gudhi/gudhi/Off_reader.h +173 -0
  65. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +834 -0
  66. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +838 -0
  67. multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +833 -0
  68. multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1367 -0
  69. multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1157 -0
  70. multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +869 -0
  71. multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +905 -0
  72. multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +122 -0
  73. multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +260 -0
  74. multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +288 -0
  75. multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +170 -0
  76. multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +247 -0
  77. multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +571 -0
  78. multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +182 -0
  79. multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +130 -0
  80. multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +235 -0
  81. multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +312 -0
  82. multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1092 -0
  83. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +923 -0
  84. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +914 -0
  85. multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +930 -0
  86. multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1071 -0
  87. multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +203 -0
  88. multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +886 -0
  89. multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +984 -0
  90. multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1213 -0
  91. multipers/gudhi/gudhi/Persistence_matrix/index_mapper.h +58 -0
  92. multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +227 -0
  93. multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +200 -0
  94. multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +166 -0
  95. multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +319 -0
  96. multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +562 -0
  97. multipers/gudhi/gudhi/Persistence_on_a_line.h +152 -0
  98. multipers/gudhi/gudhi/Persistence_on_rectangle.h +617 -0
  99. multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -0
  100. multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -0
  101. multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -0
  102. multipers/gudhi/gudhi/Persistent_cohomology.h +769 -0
  103. multipers/gudhi/gudhi/Points_off_io.h +171 -0
  104. multipers/gudhi/gudhi/Projective_cover_kernel.h +379 -0
  105. multipers/gudhi/gudhi/Simple_object_pool.h +69 -0
  106. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +559 -0
  107. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -0
  108. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +121 -0
  109. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -0
  110. multipers/gudhi/gudhi/Simplex_tree/filtration_value_utils.h +155 -0
  111. multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -0
  112. multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -0
  113. multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +60 -0
  114. multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +105 -0
  115. multipers/gudhi/gudhi/Simplex_tree.h +3170 -0
  116. multipers/gudhi/gudhi/Slicer.h +848 -0
  117. multipers/gudhi/gudhi/Thread_safe_slicer.h +393 -0
  118. multipers/gudhi/gudhi/distance_functions.h +62 -0
  119. multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -0
  120. multipers/gudhi/gudhi/multi_simplex_tree_helpers.h +147 -0
  121. multipers/gudhi/gudhi/persistence_interval.h +263 -0
  122. multipers/gudhi/gudhi/persistence_matrix_options.h +188 -0
  123. multipers/gudhi/gudhi/reader_utils.h +367 -0
  124. multipers/gudhi/gudhi/simple_mdspan.h +484 -0
  125. multipers/gudhi/gudhi/slicer_helpers.h +779 -0
  126. multipers/gudhi/tmp_h0_pers/mma_interface_h0.h +223 -0
  127. multipers/gudhi/tmp_h0_pers/naive_merge_tree.h +536 -0
  128. multipers/io.cpython-312-darwin.so +0 -0
  129. multipers/io.pyx +472 -0
  130. multipers/ml/__init__.py +0 -0
  131. multipers/ml/accuracies.py +90 -0
  132. multipers/ml/invariants_with_persistable.py +79 -0
  133. multipers/ml/kernels.py +176 -0
  134. multipers/ml/mma.py +713 -0
  135. multipers/ml/one.py +472 -0
  136. multipers/ml/point_clouds.py +352 -0
  137. multipers/ml/signed_measures.py +1667 -0
  138. multipers/ml/sliced_wasserstein.py +461 -0
  139. multipers/ml/tools.py +113 -0
  140. multipers/mma_structures.cpython-312-darwin.so +0 -0
  141. multipers/mma_structures.pxd +134 -0
  142. multipers/mma_structures.pyx +1483 -0
  143. multipers/mma_structures.pyx.tp +1126 -0
  144. multipers/multi_parameter_rank_invariant/diff_helpers.h +85 -0
  145. multipers/multi_parameter_rank_invariant/euler_characteristic.h +95 -0
  146. multipers/multi_parameter_rank_invariant/function_rips.h +317 -0
  147. multipers/multi_parameter_rank_invariant/hilbert_function.h +761 -0
  148. multipers/multi_parameter_rank_invariant/persistence_slices.h +149 -0
  149. multipers/multi_parameter_rank_invariant/rank_invariant.h +350 -0
  150. multipers/multiparameter_edge_collapse.py +41 -0
  151. multipers/multiparameter_module_approximation/approximation.h +2541 -0
  152. multipers/multiparameter_module_approximation/debug.h +107 -0
  153. multipers/multiparameter_module_approximation/format_python-cpp.h +292 -0
  154. multipers/multiparameter_module_approximation/utilities.h +428 -0
  155. multipers/multiparameter_module_approximation.cpython-312-darwin.so +0 -0
  156. multipers/multiparameter_module_approximation.pyx +286 -0
  157. multipers/ops.cpython-312-darwin.so +0 -0
  158. multipers/ops.pyx +231 -0
  159. multipers/pickle.py +89 -0
  160. multipers/plots.py +550 -0
  161. multipers/point_measure.cpython-312-darwin.so +0 -0
  162. multipers/point_measure.pyx +409 -0
  163. multipers/simplex_tree_multi.cpython-312-darwin.so +0 -0
  164. multipers/simplex_tree_multi.pxd +136 -0
  165. multipers/simplex_tree_multi.pyx +11719 -0
  166. multipers/simplex_tree_multi.pyx.tp +2102 -0
  167. multipers/slicer.cpython-312-darwin.so +0 -0
  168. multipers/slicer.pxd +2097 -0
  169. multipers/slicer.pxd.tp +263 -0
  170. multipers/slicer.pyx +13042 -0
  171. multipers/slicer.pyx.tp +1259 -0
  172. multipers/tensor/tensor.h +672 -0
  173. multipers/tensor.pxd +13 -0
  174. multipers/test.pyx +44 -0
  175. multipers/tests/__init__.py +70 -0
  176. multipers/torch/__init__.py +1 -0
  177. multipers/torch/diff_grids.py +240 -0
  178. multipers/torch/rips_density.py +310 -0
  179. multipers/vector_interface.pxd +46 -0
  180. multipers-2.4.0b1.dist-info/METADATA +131 -0
  181. multipers-2.4.0b1.dist-info/RECORD +184 -0
  182. multipers-2.4.0b1.dist-info/WHEEL +6 -0
  183. multipers-2.4.0b1.dist-info/licenses/LICENSE +21 -0
  184. multipers-2.4.0b1.dist-info/top_level.txt +1 -0
multipers/filtrations/density.py (+529 lines; matches entry 23 above)
@@ -0,0 +1,529 @@
+ from collections.abc import Callable, Iterable
+ from typing import Any, Literal, Union
+
+ import numpy as np
+
+ from multipers.array_api import api_from_tensor, api_from_tensors
+
+ global available_kernels
+ available_kernels = Union[
+     Literal[
+         "gaussian", "exponential", "exponential_kernel", "multivariate_gaussian", "sinc"
+     ],
+     Callable,
+ ]
+
+
+ def convolution_signed_measures(
+     iterable_of_signed_measures,
+     filtrations,
+     bandwidth,
+     flatten: bool = True,
+     n_jobs: int = 1,
+     backend="pykeops",
+     kernel: available_kernels = "gaussian",
+     **kwargs,
+ ):
+     """
+     Evaluates the convolution of the signed measures Iterable(pts, weights) with a Gaussian measure of bandwidth `bandwidth`, on the grid given by `filtrations`.
+
+     Parameters
+     ----------
+
+     - iterable_of_signed_measures : (num_signed_measure) x [ (npts) x (num_parameters), (npts)]
+     - filtrations : (num_parameter) x (filtration values)
+     - flatten : bool
+     - n_jobs : int
+
+     Outputs
+     -------
+
+     The concatenated images, for each signed measure: (num_signed_measures) x (len(f) for f in filtration_values)
+     """
+     from multipers.grids import todense
+
+     grid_iterator = todense(filtrations, product_order=True)
+     api = api_from_tensor(iterable_of_signed_measures[0][0][0])
+     match backend:
+         case "sklearn":
+
+             def convolution_signed_measures_on_grid(
+                 signed_measures,
+             ):
+                 return api.cat(
+                     [
+                         _pts_convolution_sparse_old(
+                             pts=pts,
+                             pts_weights=weights,
+                             grid_iterator=grid_iterator,
+                             bandwidth=bandwidth,
+                             kernel=kernel,
+                             **kwargs,
+                         )
+                         for pts, weights in signed_measures
+                     ],
+                     axis=0,
+                 )
+
+         case "pykeops":
+
+             def convolution_signed_measures_on_grid(
+                 signed_measures: Iterable[tuple[np.ndarray, np.ndarray]],
+             ) -> np.ndarray:
+                 return api.cat(
+                     [
+                         _pts_convolution_pykeops(
+                             pts=pts,
+                             pts_weights=weights,
+                             grid_iterator=grid_iterator,
+                             bandwidth=bandwidth,
+                             kernel=kernel,
+                             **kwargs,
+                         )
+                         for pts, weights in signed_measures
+                     ],
+                     axis=0,
+                 )
+
+             # compiles first once
+             pts, weights = iterable_of_signed_measures[0][0]
+             small_pts, small_weights = pts[:2], weights[:2]
+
+             _pts_convolution_pykeops(
+                 small_pts,
+                 small_weights,
+                 grid_iterator=grid_iterator,
+                 bandwidth=bandwidth,
+                 kernel=kernel,
+                 **kwargs,
+             )
+
+     if n_jobs > 1 or n_jobs == -1:
+         prefer = "processes" if backend == "sklearn" else "threads"
+         from joblib import Parallel, delayed
+
+         convolutions = Parallel(n_jobs=n_jobs, prefer=prefer)(
+             delayed(convolution_signed_measures_on_grid)(sms)
+             for sms in iterable_of_signed_measures
+         )
+     else:
+         convolutions = [
+             convolution_signed_measures_on_grid(sms)
+             for sms in iterable_of_signed_measures
+         ]
+     if not flatten:
+         out_shape = [-1] + [len(f) for f in filtrations]  # Degree
+         convolutions = [x.reshape(out_shape) for x in convolutions]
+     return api.cat([x[None] for x in convolutions])
+
+
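Below is a minimal usage sketch of `convolution_signed_measures` (illustrative only, not part of the diff). It assumes the hunk above is `multipers/filtrations/density.py`, that scikit-learn is available for the `"sklearn"` backend, that `filtrations` may be given as a list of 1D grids as the docstring indicates, and that a signed measure is a `(points, weights)` pair of NumPy arrays; all variable names are made up.

import numpy as np
from multipers.filtrations.density import convolution_signed_measures

rng = np.random.default_rng(0)
# One sample holding one signed measure: 50 points in 2 parameters with +/-1 weights.
signed_measure = (rng.random((50, 2)), rng.choice([-1.0, 1.0], size=50))
samples = [[signed_measure]]            # (num_samples) x (num_measures) x (pts, weights)
grid = [np.linspace(0, 1, 20)] * 2      # one 1d filtration grid per parameter

vectors = convolution_signed_measures(
    samples, filtrations=grid, bandwidth=0.1, backend="sklearn", kernel="gaussian"
)
print(vectors.shape)                    # expected: (1, 400) since flatten=True by default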
120
+ # def _test(r=1000, b=0.5, plot=True, kernel=0):
+ #     import matplotlib.pyplot as plt
+ #     pts, weights = np.array([[1.,1.], [1.1,1.1]]), np.array([1,-1])
+ #     pt_list = np.array(list(product(*[np.linspace(0,2,r)]*2)))
+ #     img = _pts_convolution_sparse_pts(pts,weights, pt_list,b,kernel=kernel)
+ #     if plot:
+ #         plt.imshow(img.reshape(r,-1).T, origin="lower")
+ #         plt.show()
+
+
+ def _pts_convolution_sparse_old(
+     pts: np.ndarray,
+     pts_weights: np.ndarray,
+     grid_iterator,
+     kernel: available_kernels = "gaussian",
+     bandwidth=0.1,
+     **more_kde_args,
+ ):
+     """
+     Old version of `convolution_signed_measures`. Scikit-learn's convolution is slower than the code above.
+     """
+     from sklearn.neighbors import KernelDensity
+
+     if len(pts) == 0:
+         # warn("Found a trivial signed measure !")
+         return np.zeros(len(grid_iterator))
+     kde = KernelDensity(
+         kernel=kernel, bandwidth=bandwidth, rtol=1e-4, **more_kde_args
+     )  # TODO : check rtol
+     pos_indices = pts_weights > 0
+     neg_indices = pts_weights < 0
+     img_pos = (
+         np.zeros(len(grid_iterator))
+         if pos_indices.sum() == 0
+         else kde.fit(
+             pts[pos_indices], sample_weight=pts_weights[pos_indices]
+         ).score_samples(grid_iterator)
+     )
+     img_neg = (
+         np.zeros(len(grid_iterator))
+         if neg_indices.sum() == 0
+         else kde.fit(
+             pts[neg_indices], sample_weight=-pts_weights[neg_indices]
+         ).score_samples(grid_iterator)
+     )
+     return np.exp(img_pos) - np.exp(img_neg)
+
+
+ def _pts_convolution_pykeops(
+     pts: np.ndarray,
+     pts_weights: np.ndarray,
+     grid_iterator,
+     kernel: available_kernels = "gaussian",
+     bandwidth=0.1,
+     **more_kde_args,
+ ):
+     """
+     PyKeops convolution.
+     """
+     if isinstance(pts, np.ndarray):
+         _asarray_weights = lambda x: np.asarray(x, dtype=pts.dtype)
+         _asarray_grid = _asarray_weights
+     else:
+         import torch
+
+         _asarray_weights = lambda x: torch.from_numpy(x).type(pts.dtype)
+         _asarray_grid = lambda x: x.type(pts.dtype)
+     kde = KDE(kernel=kernel, bandwidth=bandwidth, **more_kde_args)
+     return kde.fit(pts, sample_weights=_asarray_weights(pts_weights)).score_samples(
+         _asarray_grid(grid_iterator)
+     )
+
+
+ def gaussian_kernel(x_i, y_j, bandwidth):
+     D = x_i.shape[-1]
+     exponent = -(((x_i - y_j) / bandwidth) ** 2).sum(dim=-1) / 2
+     # float is necessary for some reason (pykeops fails)
+     kernel = (exponent).exp() / float((bandwidth * np.sqrt(2 * np.pi)) ** D)
+     return kernel
+
+
+ def multivariate_gaussian_kernel(x_i, y_j, covariance_matrix_inverse):
+     # 1 / \sqrt((2 \pi)^dim * \Sigma.det()) * exp( -(x-y).T @ \Sigma^{-1} @ (x-y) / 2)
+     # CF https://www.kernel-operations.io/keops/_auto_examples/pytorch/plot_anisotropic_kernels.html#sphx-glr-auto-examples-pytorch-plot-anisotropic-kernels-py
+     # and https://www.kernel-operations.io/keops/api/math-operations.html
+     dim = x_i.shape[-1]
+     z = x_i - y_j
+     exponent = -(z.weightedsqnorm(covariance_matrix_inverse.flatten()) / 2)
+     return (
+         float((2 * np.pi) ** (-dim / 2))
+         * (covariance_matrix_inverse.det().sqrt())
+         * exponent.exp()
+     )
+
+
+ def exponential_kernel(x_i, y_j, bandwidth):
+     # 1 / \sigma * exp(- norm(x - y) / \sigma)
+     exponent = -((((x_i - y_j) ** 2).sum(dim=-1) ** 1 / 2) / bandwidth)
+     kernel = exponent.exp() / bandwidth
+     return kernel
+
+
+ def sinc_kernel(x_i, y_j, bandwidth):
+     norm = (((x_i - y_j) ** 2).sum(dim=-1) ** 1 / 2) / bandwidth
+     sinc = type(x_i).sinc
+     kernel = 2 * sinc(2 * norm) - sinc(norm)
+     return kernel
+
+
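The kernels above are written for KeOps LazyTensors, but `gaussian_kernel` only uses broadcasting, `sum(dim=-1)` and `exp()`, so it can be sanity-checked with plain torch tensors shaped like the `(n, 1, D)` / `(1, m, D)` lazy operands. A small check of the normalization constant (illustrative; assumes torch is installed and the module path guessed above):

import numpy as np
import torch
from multipers.filtrations.density import gaussian_kernel

D, h = 2, 0.3
x_i = torch.zeros(1, 1, D)   # stands in for a LazyTensor of shape (n, 1, D)
y_j = torch.zeros(1, 1, D)   # stands in for a LazyTensor of shape (1, m, D)

# At x == y the kernel equals its peak value 1 / (h * sqrt(2 * pi))**D.
value = gaussian_kernel(x_i, y_j, bandwidth=h)
expected = 1.0 / (h * np.sqrt(2 * np.pi)) ** D
assert torch.allclose(value, torch.tensor(expected, dtype=value.dtype))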
229
+ def _kernel(
+     kernel: available_kernels = "gaussian",
+ ):
+     match kernel:
+         case "gaussian":
+             return gaussian_kernel
+         case "exponential":
+             return exponential_kernel
+         case "multivariate_gaussian":
+             return multivariate_gaussian_kernel
+         case "sinc":
+             return sinc_kernel
+         case _:
+             assert callable(kernel), f"""
+             --------------------------
+             Unknown kernel {kernel}.
+             --------------------------
+             Custom kernel has to be callable,
+             (x:LazyTensor(n,1,D),y:LazyTensor(1,m,D),bandwidth:float) ---> kernel matrix
+
+             Valid operations are given here:
+             https://www.kernel-operations.io/keops/python/api/index.html
+             """
+             return kernel
+
+
+ # TODO : multiple bandwidths at once with lazy tensors
+ class KDE:
+     """
+     Fast, scikit-style, and differentiable kernel density estimation, using PyKeops.
+     """
+
+     def __init__(
+         self,
+         bandwidth: Any = 1,
+         kernel: available_kernels = "gaussian",
+         return_log: bool = False,
+     ):
+         """
+         bandwidth : numeric
+             bandwidth for Gaussian kernel
+         """
+         self.X = None
+         self.bandwidth = bandwidth
+         self.kernel: available_kernels = kernel
+         self._kernel = None
+         self._backend = None
+         self._sample_weights = None
+         self.return_log = return_log
+
+     def fit(self, X, sample_weights=None, y=None):
+         self.X = X
+         self._sample_weights = sample_weights
+         if isinstance(X, np.ndarray):
+             self._backend = np
+         else:
+             import torch
+
+             if isinstance(X, torch.Tensor):
+                 self._backend = torch
+             else:
+                 raise Exception("Unsupported backend.")
+         self._kernel = _kernel(self.kernel)
+         return self
+
+     @staticmethod
+     def to_lazy(X, Y, x_weights):
+         if isinstance(X, np.ndarray):
+             from pykeops.numpy import LazyTensor
+
+             lazy_x = LazyTensor(
+                 X.reshape((X.shape[0], 1, X.shape[1]))
+             )  # numpts, 1, dim
+             lazy_y = LazyTensor(
+                 Y.reshape((1, Y.shape[0], Y.shape[1])).astype(X.dtype)
+             )  # 1, numpts, dim
+             if x_weights is not None:
+                 w = LazyTensor(np.asarray(x_weights, dtype=X.dtype)[:, None], axis=0)
+                 return lazy_x, lazy_y, w
+             return lazy_x, lazy_y, None
+         import torch
+
+         if isinstance(X, torch.Tensor):
+             from pykeops.torch import LazyTensor
+
+             lazy_x = LazyTensor(X.view(X.shape[0], 1, X.shape[1]))
+             lazy_y = LazyTensor(Y.type(X.dtype).view(1, Y.shape[0], Y.shape[1]))
+             if x_weights is not None:
+                 if isinstance(x_weights, np.ndarray):
+                     x_weights = torch.from_numpy(x_weights)
+                 w = LazyTensor(x_weights[:, None].type(X.dtype), axis=0)
+                 return lazy_x, lazy_y, w
+             return lazy_x, lazy_y, None
+         raise Exception("Bad tensor type.")
+
+     def score_samples(self, Y, X=None, return_kernel=False):
+         """Returns the kernel density estimates of each point in `Y`.
+
+         Parameters
+         ----------
+         Y : tensor (m, d)
+             `m` points with `d` dimensions for which the probability density will
+             be calculated
+         X : tensor (n, d), optional
+             `n` points with `d` dimensions to which KDE will be fit. Provided to
+             allow batch calculations. By default, `X` is None and all points
+             passed to `fit` are used.
+
+
+         Returns
+         -------
+         probs : tensor (m)
+             (weighted) density estimates for each of the queried points in `Y`;
+             log-densities if `return_log` is set.
+         """
+         assert self._backend is not None and self._kernel is not None, "Fit first."
+         X = self.X if X is None else X
+         if X.shape[0] == 0:
+             return self._backend.zeros((Y.shape[0]))
+         assert Y.shape[1] == X.shape[1] and X.ndim == Y.ndim == 2
+         lazy_x, lazy_y, w = self.to_lazy(X, Y, x_weights=self._sample_weights)
+         kernel = self._kernel(lazy_x, lazy_y, self.bandwidth)
+         if w is not None:
+             kernel *= w
+         if return_kernel:
+             return kernel
+         density_estimation = kernel.sum(dim=0).squeeze() / kernel.shape[0]  # mean
+         return (
+             self._backend.log(density_estimation)
+             if self.return_log
+             else density_estimation
+         )
+
+
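A short sketch of the estimator's fit/score_samples cycle, assuming pykeops is installed (the class builds a KeOps reduction on first use) and the same module-path guess as above; the data is random and illustrative.

import numpy as np
from multipers.filtrations.density import KDE

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 2))           # points carrying the measure
Y = rng.uniform(-3, 3, size=(200, 2))    # where the density is evaluated
weights = np.ones(len(X))                # optional (possibly signed) weights

kde = KDE(bandwidth=0.2, kernel="gaussian")
density = kde.fit(X, sample_weights=weights).score_samples(Y)
print(density.shape)                     # expected: (200,)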
362
+ class DTM:
+     """
+     Distance To Measure
+     """
+
+     def __init__(self, masses, metric: str = "euclidean", **_kdtree_kwargs):
+         """
+         masses : list of floats in (0,1]
+             The mass thresholds
+         metric :
+             The distance between points to consider
+         """
+         self.masses = masses
+         self.metric = metric
+         self._kdtree_kwargs = _kdtree_kwargs
+         self._ks = None
+         self._kdtree = None
+         self._X = None
+         self._backend = None
+
+     def fit(self, X, sample_weights=None, y=None):
+         if len(self.masses) == 0:
+             return self
+         assert np.max(self.masses) <= 1, "All masses should be in (0,1]."
+         from sklearn.neighbors import KDTree
+
+         if not isinstance(X, np.ndarray):
+             import torch
+
+             assert isinstance(X, torch.Tensor), "Backend has to be numpy or torch"
+             _X = X.detach()
+             self._backend = "torch"
+         else:
+             _X = X
+             self._backend = "numpy"
+         self._ks = np.array([int(mass * X.shape[0]) + 1 for mass in self.masses])
+         self._kdtree = KDTree(_X, metric=self.metric, **self._kdtree_kwargs)
+         self._X = X
+         return self
+
+     def score_samples(self, Y, X=None):
+         """Returns the distance to measure of each point in `Y`.
+
+         Parameters
+         ----------
+         Y : tensor (m, d)
+             `m` points with `d` dimensions for which the distance to measure will
+             be calculated
+
+
+         Returns
+         -------
+         the DTMs of Y, for each mass in masses.
+         """
+         if len(self.masses) == 0:
+             return np.empty((0, len(Y)))
+         assert (
+             self._ks is not None and self._kdtree is not None and self._X is not None
+         ), f"Fit first. Got {self._ks=}, {self._kdtree=}, {self._X=}."
+         assert Y.ndim == 2
+         if self._backend == "torch":
+             _Y = Y.detach().numpy()
+         else:
+             _Y = Y
+         NN_Dist, NN = self._kdtree.query(_Y, self._ks.max(), return_distance=True)
+         DTMs = np.array([((NN_Dist**2)[:, :k].mean(1)) ** 0.5 for k in self._ks])
+         return DTMs
+
+     def score_samples_diff(self, Y):
+         """Differentiable (torch) version of `score_samples`.
+
+         Parameters
+         ----------
+         Y : tensor (m, d)
+             `m` points with `d` dimensions for which the distance to measure will
+             be calculated
+
+
+         Returns
+         -------
+         the DTMs of Y, for each mass in masses.
+         """
+         import torch
+
+         if len(self.masses) == 0:
+             return torch.empty(0, len(Y))
+
+         assert Y.ndim == 2
+         assert self._backend == "torch", "Use the non-diff version with numpy."
+         assert (
+             self._ks is not None and self._kdtree is not None and self._X is not None
+         ), f"Fit first. Got {self._ks=}, {self._kdtree=}, {self._X=}."
+         NN = self._kdtree.query(Y.detach(), self._ks.max(), return_distance=False)
+         DTMs = tuple(
+             (((self._X[NN] - Y[:, None, :]) ** 2)[:, :k].sum(dim=(1, 2)) / k) ** 0.5
+             for k in self._ks
+         )  # TODO : kdtree already computes distance, find implementation of kdtree that is pytorch differentiable
+         return DTMs
+
+
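As the code above shows, for a mass m the distance to measure at a query point y is the square root of the mean of the k smallest squared distances from y to the fitted cloud, with k = int(m * n) + 1. A sketch of its use on random data (illustrative, same module-path assumption as above):

import numpy as np
from multipers.filtrations.density import DTM

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 2))    # reference point cloud (n = 500)
Y = rng.normal(size=(10, 2))     # query points

dtm = DTM(masses=[0.05, 0.2]).fit(X)
values = dtm.score_samples(Y)
print(values.shape)              # expected: (2, 10), one row per mass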
467
+ ## code taken from pykeops doc (https://www.kernel-operations.io/keops/_auto_benchmarks/benchmark_KNN.html)
+ class KNNmean:
+     def __init__(self, k: int, metric: str = "euclidean"):
+         self.k = k
+         self.metric = metric
+         self._KNN_fun = None
+         self._x = None
+
+     def fit(self, x):
+         if isinstance(x, np.ndarray):
+             from pykeops.numpy import Vi, Vj
+         else:
+             import torch
+
+             assert isinstance(x, torch.Tensor), "Backend has to be numpy or torch"
+             from pykeops.torch import Vi, Vj
+
+         D = x.shape[1]
+         X_i = Vi(0, D)
+         X_j = Vj(1, D)
+
+         # Symbolic distance matrix:
+         if self.metric == "euclidean":
+             D_ij = ((X_i - X_j) ** 2).sum(-1) ** (1 / 2)
+         elif self.metric == "manhattan":
+             D_ij = (X_i - X_j).abs().sum(-1)
+         elif self.metric == "angular":
+             D_ij = -(X_i | X_j)
+         elif self.metric == "hyperbolic":
+             D_ij = ((X_i - X_j) ** 2).sum(-1) / (X_i[0] * X_j[0])
+         else:
+             raise NotImplementedError(f"The '{self.metric}' distance is not supported.")
+
+         self._x = x
+         self._KNN_fun = D_ij.Kmin(self.k, dim=1)
+         return self
+
+     def score_samples(self, x):
+         assert self._x is not None and self._KNN_fun is not None, "Fit first."
+         return self._KNN_fun(x, self._x).sum(axis=1) / self.k
+
+
+ # def _pts_convolution_sparse(pts:np.ndarray, pts_weights:np.ndarray, filtration_grid:Iterable[np.ndarray], kernel="gaussian", bandwidth=0.1, **more_kde_args):
+ #     """
+ #     Old version of `convolution_signed_measures`. Scikit-learn's convolution is slower than the code above.
+ #     """
+ #     from sklearn.neighbors import KernelDensity
+ #     grid_iterator = np.asarray(list(product(*filtration_grid)))
+ #     grid_shape = [len(f) for f in filtration_grid]
+ #     if len(pts) == 0:
+ #         # warn("Found a trivial signed measure !")
+ #         return np.zeros(shape=grid_shape)
+ #     kde = KernelDensity(kernel=kernel, bandwidth=bandwidth, rtol = 1e-4, **more_kde_args) # TODO : check rtol
+ #
+ #     pos_indices = pts_weights>0
+ #     neg_indices = pts_weights<0
+ #     img_pos = kde.fit(pts[pos_indices], sample_weight=pts_weights[pos_indices]).score_samples(grid_iterator).reshape(grid_shape)
+ #     img_neg = kde.fit(pts[neg_indices], sample_weight=-pts_weights[neg_indices]).score_samples(grid_iterator).reshape(grid_shape)
+ #     return np.exp(img_pos) - np.exp(img_neg)
+
+
+ # Precompiles the convolution
+ # _test(r=2,b=.5, plot=False)
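Finally, the KeOps k-NN helper above follows the same fit/score_samples pattern; a sketch under the same assumptions (pykeops installed, guessed module path), with scoring done on the fitted cloud itself:

import numpy as np
from multipers.filtrations.density import KNNmean

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 2))

knn = KNNmean(k=10, metric="euclidean").fit(X)
scores = knn.score_samples(X)    # mean distance to the 10 nearest neighbours (self included)
print(scores.shape)              # expected: (1000,)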