multipers 2.2.3__cp310-cp310-win_amd64.whl → 2.3.1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic. Click here for more details.

Files changed (182)
  1. multipers/__init__.py +33 -31
  2. multipers/_signed_measure_meta.py +430 -430
  3. multipers/_slicer_meta.py +211 -212
  4. multipers/data/MOL2.py +458 -458
  5. multipers/data/UCR.py +18 -18
  6. multipers/data/graphs.py +466 -466
  7. multipers/data/immuno_regions.py +27 -27
  8. multipers/data/pytorch2simplextree.py +90 -90
  9. multipers/data/shape3d.py +101 -101
  10. multipers/data/synthetic.py +113 -111
  11. multipers/distances.py +198 -198
  12. multipers/filtration_conversions.pxd.tp +84 -84
  13. multipers/filtrations/__init__.py +18 -0
  14. multipers/{ml/convolutions.py → filtrations/density.py} +563 -520
  15. multipers/filtrations/filtrations.py +289 -0
  16. multipers/filtrations.pxd +224 -224
  17. multipers/function_rips.cp310-win_amd64.pyd +0 -0
  18. multipers/function_rips.pyx +105 -105
  19. multipers/grids.cp310-win_amd64.pyd +0 -0
  20. multipers/grids.pyx +350 -350
  21. multipers/gudhi/Persistence_slices_interface.h +132 -132
  22. multipers/gudhi/Simplex_tree_interface.h +239 -245
  23. multipers/gudhi/Simplex_tree_multi_interface.h +516 -561
  24. multipers/gudhi/cubical_to_boundary.h +59 -59
  25. multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -450
  26. multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -1070
  27. multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -579
  28. multipers/gudhi/gudhi/Debug_utils.h +45 -45
  29. multipers/gudhi/gudhi/Fields/Multi_field.h +484 -484
  30. multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -455
  31. multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -450
  32. multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -531
  33. multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -507
  34. multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -531
  35. multipers/gudhi/gudhi/Fields/Z2_field.h +355 -355
  36. multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -376
  37. multipers/gudhi/gudhi/Fields/Zp_field.h +420 -420
  38. multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -400
  39. multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -418
  40. multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -337
  41. multipers/gudhi/gudhi/Matrix.h +2107 -2107
  42. multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -1038
  43. multipers/gudhi/gudhi/Multi_persistence/Box.h +171 -171
  44. multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -282
  45. multipers/gudhi/gudhi/Off_reader.h +173 -173
  46. multipers/gudhi/gudhi/One_critical_filtration.h +1433 -1431
  47. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -769
  48. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -686
  49. multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -842
  50. multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -1350
  51. multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -1105
  52. multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -859
  53. multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -910
  54. multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -139
  55. multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -230
  56. multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -211
  57. multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -60
  58. multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -60
  59. multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -136
  60. multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -190
  61. multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -616
  62. multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -150
  63. multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -106
  64. multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -219
  65. multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -327
  66. multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -1140
  67. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -934
  68. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -934
  69. multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -980
  70. multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -1092
  71. multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -192
  72. multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -921
  73. multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -1093
  74. multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -1012
  75. multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -1244
  76. multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -186
  77. multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -164
  78. multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -156
  79. multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -376
  80. multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -540
  81. multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -118
  82. multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -173
  83. multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -128
  84. multipers/gudhi/gudhi/Persistent_cohomology.h +745 -745
  85. multipers/gudhi/gudhi/Points_off_io.h +171 -171
  86. multipers/gudhi/gudhi/Simple_object_pool.h +69 -69
  87. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -463
  88. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -83
  89. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -106
  90. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -277
  91. multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -62
  92. multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -27
  93. multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -62
  94. multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -157
  95. multipers/gudhi/gudhi/Simplex_tree.h +2794 -2794
  96. multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -163
  97. multipers/gudhi/gudhi/distance_functions.h +62 -62
  98. multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -104
  99. multipers/gudhi/gudhi/persistence_interval.h +253 -253
  100. multipers/gudhi/gudhi/persistence_matrix_options.h +170 -170
  101. multipers/gudhi/gudhi/reader_utils.h +367 -367
  102. multipers/gudhi/mma_interface_coh.h +256 -255
  103. multipers/gudhi/mma_interface_h0.h +223 -231
  104. multipers/gudhi/mma_interface_matrix.h +291 -282
  105. multipers/gudhi/naive_merge_tree.h +536 -575
  106. multipers/gudhi/scc_io.h +310 -289
  107. multipers/gudhi/truc.h +957 -888
  108. multipers/io.cp310-win_amd64.pyd +0 -0
  109. multipers/io.pyx +714 -711
  110. multipers/ml/accuracies.py +90 -90
  111. multipers/ml/invariants_with_persistable.py +79 -79
  112. multipers/ml/kernels.py +176 -176
  113. multipers/ml/mma.py +713 -714
  114. multipers/ml/one.py +472 -472
  115. multipers/ml/point_clouds.py +352 -346
  116. multipers/ml/signed_measures.py +1589 -1589
  117. multipers/ml/sliced_wasserstein.py +461 -461
  118. multipers/ml/tools.py +113 -113
  119. multipers/mma_structures.cp310-win_amd64.pyd +0 -0
  120. multipers/mma_structures.pxd +127 -127
  121. multipers/mma_structures.pyx +4 -8
  122. multipers/mma_structures.pyx.tp +1083 -1085
  123. multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -93
  124. multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -97
  125. multipers/multi_parameter_rank_invariant/function_rips.h +322 -322
  126. multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -769
  127. multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -148
  128. multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -369
  129. multipers/multiparameter_edge_collapse.py +41 -41
  130. multipers/multiparameter_module_approximation/approximation.h +2298 -2295
  131. multipers/multiparameter_module_approximation/combinatory.h +129 -129
  132. multipers/multiparameter_module_approximation/debug.h +107 -107
  133. multipers/multiparameter_module_approximation/format_python-cpp.h +286 -286
  134. multipers/multiparameter_module_approximation/heap_column.h +238 -238
  135. multipers/multiparameter_module_approximation/images.h +79 -79
  136. multipers/multiparameter_module_approximation/list_column.h +174 -174
  137. multipers/multiparameter_module_approximation/list_column_2.h +232 -232
  138. multipers/multiparameter_module_approximation/ru_matrix.h +347 -347
  139. multipers/multiparameter_module_approximation/set_column.h +135 -135
  140. multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -36
  141. multipers/multiparameter_module_approximation/unordered_set_column.h +166 -166
  142. multipers/multiparameter_module_approximation/utilities.h +403 -419
  143. multipers/multiparameter_module_approximation/vector_column.h +223 -223
  144. multipers/multiparameter_module_approximation/vector_matrix.h +331 -331
  145. multipers/multiparameter_module_approximation/vineyards.h +464 -464
  146. multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -649
  147. multipers/multiparameter_module_approximation.cp310-win_amd64.pyd +0 -0
  148. multipers/multiparameter_module_approximation.pyx +218 -217
  149. multipers/pickle.py +90 -53
  150. multipers/plots.py +342 -334
  151. multipers/point_measure.cp310-win_amd64.pyd +0 -0
  152. multipers/point_measure.pyx +322 -320
  153. multipers/simplex_tree_multi.cp310-win_amd64.pyd +0 -0
  154. multipers/simplex_tree_multi.pxd +133 -133
  155. multipers/simplex_tree_multi.pyx +115 -48
  156. multipers/simplex_tree_multi.pyx.tp +1947 -1935
  157. multipers/slicer.cp310-win_amd64.pyd +0 -0
  158. multipers/slicer.pxd +301 -120
  159. multipers/slicer.pxd.tp +218 -214
  160. multipers/slicer.pyx +1570 -507
  161. multipers/slicer.pyx.tp +931 -914
  162. multipers/tensor/tensor.h +672 -672
  163. multipers/tensor.pxd +13 -13
  164. multipers/test.pyx +44 -44
  165. multipers/tests/__init__.py +57 -57
  166. multipers/torch/diff_grids.py +217 -217
  167. multipers/torch/rips_density.py +310 -304
  168. {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/LICENSE +21 -21
  169. {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/METADATA +21 -11
  170. multipers-2.3.1.dist-info/RECORD +182 -0
  171. {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/WHEEL +1 -1
  172. multipers/tests/test_diff_helper.py +0 -73
  173. multipers/tests/test_hilbert_function.py +0 -82
  174. multipers/tests/test_mma.py +0 -83
  175. multipers/tests/test_point_clouds.py +0 -49
  176. multipers/tests/test_python-cpp_conversion.py +0 -82
  177. multipers/tests/test_signed_betti.py +0 -181
  178. multipers/tests/test_signed_measure.py +0 -89
  179. multipers/tests/test_simplextreemulti.py +0 -221
  180. multipers/tests/test_slicer.py +0 -221
  181. multipers-2.2.3.dist-info/RECORD +0 -189
  182. {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/top_level.txt +0 -0
@@ -1,346 +1,352 @@
1
- from collections.abc import Callable, Iterable
2
- from typing import Literal, Optional
3
-
4
- import gudhi as gd
5
- import numpy as np
6
- from joblib import Parallel, delayed
7
- from sklearn.base import BaseEstimator, TransformerMixin
8
- from tqdm import tqdm
9
-
10
- import multipers as mp
11
- import multipers.slicer as mps
12
- from multipers.ml.convolutions import DTM, KDE, available_kernels
13
-
14
-
15
- class PointCloud2FilteredComplex(BaseEstimator, TransformerMixin):
16
- def __init__(
17
- self,
18
- bandwidths=[],
19
- masses=[],
20
- threshold: float = -np.inf,
21
- complex: Literal["alpha", "rips", "delaunay"] = "rips",
22
- sparse: Optional[float] = None,
23
- num_collapses: int = -2,
24
- kernel: available_kernels = "gaussian",
25
- log_density: bool = True,
26
- expand_dim: int = 1,
27
- progress: bool = False,
28
- n_jobs: Optional[int] = None,
29
- fit_fraction: float = 1,
30
- verbose: bool = False,
31
- safe_conversion: bool = False,
32
- output_type: Optional[
33
- Literal["slicer", "simplextree", "slicer_vine", "slicer_novine"]
34
- ] = None,
35
- reduce_degrees: Optional[Iterable[int]] = None,
36
- ) -> None:
37
- """
38
- (Rips or Alpha or Delaunay) + (Density Estimation or DTM) 1-critical 2-filtration.
39
-
40
- Parameters
41
- ----------
42
- - bandwidth : real : The kernel density estimation bandwidth, or the DTM mass. If negative, it replaced by abs(bandwidth)*(radius of the dataset)
43
- - threshold : real, max edge lenfth of the rips or max alpha square of the alpha
44
- - sparse : real, sparse rips (c.f. rips doc) WARNING : ONLY FOR RIPS
45
- - num_collapse : int, Number of edge collapses applied to the simplextrees, WARNING : ONLY FOR RIPS
46
- - expand_dim : int, expand the rips complex to this dimension. WARNING : ONLY FOR RIPS
47
- - kernel : the kernel used for density estimation. Available ones are, e.g., "dtm", "gaussian", "exponential".
48
- - progress : bool, shows the calculus status
49
- - n_jobs : number of processes
50
- - fit_fraction : real, the fraction of data on which to fit
51
- - verbose : bool, Shows more information if true.
52
-
53
- Output
54
- ------
55
- A list of SimplexTreeMulti whose first parameter is a rips and the second is the codensity.
56
- """
57
- super().__init__()
58
- self.bandwidths = bandwidths
59
- self.masses = masses
60
- self.num_collapses = num_collapses
61
- self.kernel = kernel
62
- self.log_density = log_density
63
- self.progress = progress
64
- self._bandwidths = np.empty((0,))
65
- self._threshold = np.inf
66
- self.n_jobs = n_jobs
67
- self._scale = np.empty((0,))
68
- self.fit_fraction = fit_fraction
69
- self.expand_dim = expand_dim
70
- self.verbose = verbose
71
- self.complex = complex
72
- self.threshold = threshold
73
- self.sparse = sparse
74
- self._get_sts = lambda: Exception("Fit first")
75
- self.safe_conversion = safe_conversion
76
- self.output_type = output_type
77
- self._output_type = None
78
- self.reduce_degrees = reduce_degrees
79
- self._vineyard = None
80
-
81
- assert (
82
- output_type != "simplextree" or reduce_degrees is None
83
- ), "Reduced complex are not simplicial. Cannot return a simplextree."
84
- return
85
-
86
- def _get_distance_quantiles_and_threshold(self, X, qs):
87
- ## if we dont need to compute a distance matrix
88
- if len(qs) == 0 and self.threshold >= 0:
89
- self._scale = []
90
- return []
91
- if self.progress:
92
- print("Estimating scale...", flush=True, end="")
93
- ## subsampling
94
- indices = np.random.choice(
95
- len(X), min(len(X), int(self.fit_fraction * len(X)) + 1), replace=False
96
- )
97
-
98
- def compute_max_scale(x):
99
- from pykeops.numpy import LazyTensor
100
-
101
- a = LazyTensor(x[None, :, :])
102
- b = LazyTensor(x[:, None, :])
103
- return np.sqrt(((a - b) ** 2).sum(2).max(1).min(0)[0])
104
-
105
- diameter = np.max([compute_max_scale(x) for x in (X[i] for i in indices)])
106
- self._scale = diameter * np.array(qs)
107
-
108
- if self.threshold == -np.inf:
109
- self._threshold = diameter
110
- elif self.threshold > 0:
111
- self._threshold = self.threshold
112
- else:
113
- self._threshold = -diameter * self.threshold
114
-
115
- if self.threshold > 0:
116
- self._scale[self._scale > self.threshold] = self.threshold
117
-
118
- if self.progress:
119
- print(f"Done. Chosen scales {qs} are {self._scale}", flush=True)
120
- return self._scale
121
-
122
- def _get_sts_rips(self, x):
123
- assert self._output_type is not None and self._vineyard is not None
124
- st_init = gd.RipsComplex(
125
- points=x, max_edge_length=self._threshold, sparse=self.sparse
126
- ).create_simplex_tree(max_dimension=1)
127
- st_init = mp.simplex_tree_multi.SimplexTreeMulti(
128
- st_init, num_parameters=2, safe_conversion=self.safe_conversion
129
- )
130
- codensities = self._get_codensities(x_fit=x, x_sample=x)
131
- num_axes = codensities.shape[0]
132
- sts = [st_init] + [st_init.copy() for _ in range(num_axes - 1)]
133
- # no need to multithread here, most operations are memory
134
- for codensity, st_copy in zip(codensities, sts):
135
- # RIPS has contigus vertices, so vertices are ordered.
136
- st_copy.fill_lowerstar(codensity, parameter=1)
137
-
138
- def reduce(st):
139
- if self.verbose:
140
- print("Num simplices :", st.num_simplices)
141
- if isinstance(self.num_collapses, int):
142
- st.collapse_edges(num=self.num_collapses)
143
- if self.verbose:
144
- print(", after collapse :", st.num_simplices, end="")
145
- elif self.num_collapses == "full":
146
- st.collapse_edges(full=True)
147
- if self.verbose:
148
- print(", after collapse :", st.num_simplices, end="")
149
- if self.expand_dim > 1:
150
- st.expansion(self.expand_dim)
151
- if self.verbose:
152
- print(", after expansion :", st.num_simplices, end="")
153
- if self.verbose:
154
- print("")
155
- if self._output_type == "slicer":
156
- st = mp.Slicer(st, vineyard=self._vineyard)
157
- if self.reduce_degrees is not None:
158
- st = mp.slicer.minimal_presentation(
159
- st, degrees=self.reduce_degrees, vineyard=self._vineyard
160
- )
161
- return st
162
-
163
- return Parallel(backend="threading", n_jobs=self.n_jobs)(
164
- delayed(reduce)(st) for st in sts
165
- )
166
-
167
- def _get_sts_alpha(self, x: np.ndarray, return_alpha=False):
168
- assert self._output_type is not None and self._vineyard is not None
169
- alpha_complex = gd.AlphaComplex(points=x)
170
- st = alpha_complex.create_simplex_tree(max_alpha_square=self._threshold**2)
171
- vertices = np.array([i for (i,), _ in st.get_skeleton(0)])
172
- new_points = np.asarray(
173
- [alpha_complex.get_point(i) for i in vertices]
174
- ) # Seems to be unsafe for some reason
175
- # new_points = x
176
- st = mp.simplex_tree_multi.SimplexTreeMulti(
177
- st, num_parameters=2, safe_conversion=self.safe_conversion
178
- )
179
- codensities = self._get_codensities(x_fit=x, x_sample=new_points)
180
- num_axes = codensities.shape[0]
181
- sts = [st] + [st.copy() for _ in range(num_axes - 1)]
182
- # no need to multithread here, most operations are memory
183
- max_vertices = vertices.max() + 2 # +1 to be safe
184
- for codensity, st_copy in zip(codensities, sts):
185
- alligned_codensity = np.array([np.nan] * max_vertices)
186
- alligned_codensity[vertices] = codensity
187
- # alligned_codensity = np.array([codensity[i] if i in vertices else np.nan for i in range(max_vertices)])
188
- st_copy.fill_lowerstar(alligned_codensity, parameter=1)
189
- if "slicer" in self._output_type:
190
- sts2 = (mp.Slicer(st, vineyard=self._vineyard) for st in sts)
191
- if self.reduce_degrees is not None:
192
- sts = tuple(
193
- mp.slicer.minimal_presentation(
194
- s, degrees=self.reduce_degrees, vineyard=self._vineyard
195
- )
196
- for s in sts2
197
- )
198
- else:
199
- sts = tuple(sts2)
200
- if return_alpha:
201
- return alpha_complex, sts
202
- return sts
203
-
204
- def _get_sts_delaunay(self, x: np.ndarray):
205
- codensities = self._get_codensities(x_fit=x, x_sample=x)
206
-
207
- def get_st(c):
208
- slicer = mps.from_function_delaunay(
209
- x,
210
- c,
211
- verbose=self.verbose,
212
- clear=not self.verbose,
213
- vineyard=self._vineyard,
214
- )
215
- if self._output_type == "simplextree":
216
- slicer = mps.to_simplextree(slicer)
217
- elif self.reduce_degrees is not None:
218
- slicer = mp.slicer.minimal_presentation(
219
- slicer,
220
- degrees=self.reduce_degrees,
221
- vineyard=self._vineyard,
222
- )
223
- else:
224
- slicer = slicer
225
- return slicer
226
-
227
- sts = Parallel(backend="threading", n_jobs=self.n_jobs)(
228
- delayed(get_st)(c) for c in codensities
229
- )
230
- return sts
231
-
232
- def _get_codensities(self, x_fit, x_sample):
233
- x_fit = np.asarray(x_fit, dtype=np.float64)
234
- x_sample = np.asarray(x_sample, dtype=np.float64)
235
- codensities_kde = np.asarray(
236
- [
237
- -KDE(
238
- bandwidth=bandwidth, kernel=self.kernel, return_log=self.log_density
239
- )
240
- .fit(x_fit)
241
- .score_samples(x_sample)
242
- for bandwidth in self._bandwidths
243
- ],
244
- ).reshape(len(self._bandwidths), len(x_sample))
245
- codensities_dtm = (
246
- DTM(masses=self.masses)
247
- .fit(x_fit)
248
- .score_samples(x_sample)
249
- .reshape(len(self.masses), len(x_sample))
250
- )
251
- return np.concatenate([codensities_kde, codensities_dtm])
252
-
253
- def _define_sts(self):
254
- match self.complex:
255
- case "rips":
256
- self._get_sts = self._get_sts_rips
257
- _pref_output = "simplextree"
258
- case "alpha":
259
- self._get_sts = self._get_sts_alpha
260
- _pref_output = "simplextree"
261
- case "delaunay":
262
- self._get_sts = self._get_sts_delaunay
263
- _pref_output = "slicer"
264
- case _:
265
- raise ValueError(
266
- f"Invalid complex \
267
- {self.complex}. Possible choises are rips, delaunay, or alpha."
268
- )
269
- self._vineyard = (
270
- False if self.output_type is None else "novine" not in self.output_type
271
- )
272
- self._output_type = (
273
- _pref_output
274
- if self.output_type is None
275
- else (
276
- "simplextree"
277
- if (
278
- self.output_type == "simplextree" or self.reduce_degrees is not None
279
- )
280
- else "slicer"
281
- )
282
- )
283
-
284
- def _define_bandwidths(self, X):
285
- qs = [q for q in [*-np.asarray(self.bandwidths)] if 0 <= q <= 1]
286
- self._get_distance_quantiles_and_threshold(X, qs=qs)
287
- self._bandwidths = np.array(self.bandwidths)
288
- count = 0
289
- for i in range(len(self._bandwidths)):
290
- if self.bandwidths[i] < 0:
291
- self._bandwidths[i] = self._scale[count]
292
- count += 1
293
-
294
- def fit(self, X: np.ndarray | list, y=None):
295
- # self.bandwidth = "silverman" ## not good, as is can make bandwidth not constant
296
- self._define_sts()
297
- self._define_bandwidths(X)
298
- # PRECOMPILE FIRST
299
- self._get_codensities(X[0][:2], X[0][:2])
300
- return self
301
-
302
- def transform(self, X):
303
- # precompile first
304
- # self._get_sts(X[0][:5])
305
- self._get_codensities(X[0][:2], X[0][:2])
306
- with tqdm(
307
- X, desc="Filling simplextrees", disable=not self.progress, total=len(X)
308
- ) as data:
309
- stss = Parallel(backend="threading", n_jobs=self.n_jobs)(
310
- delayed(self._get_sts)(x) for x in data
311
- )
312
- return stss
313
-
314
-
315
- class PointCloud2SimplexTree(PointCloud2FilteredComplex):
316
- def __init__(
317
- self,
318
- bandwidths=[],
319
- masses=[],
320
- threshold: float = np.inf,
321
- complex: Literal["alpha", "rips", "delaunay"] = "rips",
322
- sparse: float | None = None,
323
- num_collapses: int = -2,
324
- kernel: available_kernels = "gaussian",
325
- log_density: bool = True,
326
- expand_dim: int = 1,
327
- progress: bool = False,
328
- n_jobs: Optional[int] = None,
329
- fit_fraction: float = 1,
330
- verbose: bool = False,
331
- safe_conversion: bool = False,
332
- output_type: Optional[
333
- Literal["slicer", "simplextree", "slicer_vine", "slicer_novine"]
334
- ] = None,
335
- reduce_degrees: Optional[Iterable[int]] = None,
336
- ) -> None:
337
- stuff = locals()
338
- stuff.pop("self")
339
- keys = list(stuff.keys())
340
- for key in keys:
341
- if key.startswith("__"):
342
- stuff.pop(key)
343
- super().__init__(**stuff)
344
- from warnings import warn
345
-
346
- warn("This class is deprecated, use PointCloud2FilteredComplex instead.")
1
+ from collections.abc import Iterable
2
+ from typing import Literal, Optional
3
+
4
+ import gudhi as gd
5
+ import numpy as np
6
+ from joblib import Parallel, delayed
7
+ from sklearn.base import BaseEstimator, TransformerMixin
8
+ from scipy.spatial.distance import cdist
9
+ from tqdm import tqdm
10
+
11
+ import multipers as mp
12
+ import multipers.slicer as mps
13
+ from multipers.filtrations.density import DTM, KDE, available_kernels
14
+
15
+
16
+ class PointCloud2FilteredComplex(BaseEstimator, TransformerMixin):
17
+ def __init__(
18
+ self,
19
+ bandwidths=[],
20
+ masses=[],
21
+ threshold: float = -np.inf,
22
+ complex: Literal["alpha", "rips", "delaunay"] = "rips",
23
+ sparse: Optional[float] = None,
24
+ num_collapses: int = -2,
25
+ kernel: available_kernels = "gaussian",
26
+ log_density: bool = True,
27
+ expand_dim: int = 1,
28
+ progress: bool = False,
29
+ n_jobs: Optional[int] = None,
30
+ fit_fraction: float = 1,
31
+ verbose: bool = False,
32
+ safe_conversion: bool = False,
33
+ output_type: Optional[
34
+ Literal["slicer", "simplextree", "slicer_vine", "slicer_novine"]
35
+ ] = None,
36
+ reduce_degrees: Optional[Iterable[int]] = None,
37
+ ) -> None:
38
+ """
39
+ (Rips or Alpha or Delaunay) + (Density Estimation or DTM) 1-critical 2-filtration.
40
+
41
+ Parameters
42
+ ----------
43
+ - bandwidth : real : The kernel density estimation bandwidth, or the DTM mass. If negative, it replaced by abs(bandwidth)*(radius of the dataset)
44
+ - threshold : real, max edge lenfth of the rips or max alpha square of the alpha
45
+ - sparse : real, sparse rips (c.f. rips doc) WARNING : ONLY FOR RIPS
46
+ - num_collapse : int, Number of edge collapses applied to the simplextrees, WARNING : ONLY FOR RIPS
47
+ - expand_dim : int, expand the rips complex to this dimension. WARNING : ONLY FOR RIPS
48
+ - kernel : the kernel used for density estimation. Available ones are, e.g., "dtm", "gaussian", "exponential".
49
+ - progress : bool, shows the calculus status
50
+ - n_jobs : number of processes
51
+ - fit_fraction : real, the fraction of data on which to fit
52
+ - verbose : bool, Shows more information if true.
53
+
54
+ Output
55
+ ------
56
+ A list of SimplexTreeMulti whose first parameter is a rips and the second is the codensity.
57
+ """
58
+ super().__init__()
59
+ self.bandwidths = bandwidths
60
+ self.masses = masses
61
+ self.num_collapses = num_collapses
62
+ self.kernel = kernel
63
+ self.log_density = log_density
64
+ self.progress = progress
65
+ self._bandwidths = np.empty((0,))
66
+ self._threshold = np.inf
67
+ self.n_jobs = n_jobs
68
+ self._scale = np.empty((0,))
69
+ self.fit_fraction = fit_fraction
70
+ self.expand_dim = expand_dim
71
+ self.verbose = verbose
72
+ self.complex = complex
73
+ self.threshold = threshold
74
+ self.sparse = sparse
75
+ self._get_sts = lambda: Exception("Fit first")
76
+ self.safe_conversion = safe_conversion
77
+ self.output_type = output_type
78
+ self._output_type = None
79
+ self.reduce_degrees = reduce_degrees
80
+ self._vineyard = None
81
+
82
+ assert (
83
+ output_type != "simplextree" or reduce_degrees is None
84
+ ), "Reduced complex are not simplicial. Cannot return a simplextree."
85
+ return
86
+
87
+ def _get_distance_quantiles_and_threshold(self, X, qs):
88
+ ## if we dont need to compute a distance matrix
89
+ if len(qs) == 0 and self.threshold >= 0:
90
+ self._scale = []
91
+ return []
92
+ if self.progress:
93
+ print("Estimating scale...", flush=True, end="")
94
+ ## subsampling
95
+ indices = np.random.choice(
96
+ len(X), min(len(X), int(self.fit_fraction * len(X)) + 1), replace=False
97
+ )
98
+
99
+ def compute_max_scale(x):
100
+ from pykeops.numpy import LazyTensor
101
+
102
+ a = LazyTensor(x[None, :, :])
103
+ b = LazyTensor(x[:, None, :])
104
+ return np.sqrt(((a - b) ** 2).sum(2).max(1).min(0)[0])
105
+
106
+ diameter = np.max([compute_max_scale(x) for x in (X[i] for i in indices)])
107
+ self._scale = diameter * np.array(qs)
108
+
109
+ if self.threshold == -np.inf:
110
+ self._threshold = diameter
111
+ elif self.threshold > 0:
112
+ self._threshold = self.threshold
113
+ else:
114
+ self._threshold = -diameter * self.threshold
115
+
116
+ if self.threshold > 0:
117
+ self._scale[self._scale > self.threshold] = self.threshold
118
+
119
+ if self.progress:
120
+ print(f"Done. Chosen scales {qs} are {self._scale}", flush=True)
121
+ return self._scale
122
+
123
+ def _get_sts_rips(self, x):
124
+ assert self._output_type is not None and self._vineyard is not None
125
+ if self.sparse is None:
126
+ st_init = gd.SimplexTree.create_from_array(
127
+ cdist(x,x), max_filtration=self._threshold
128
+ )
129
+ else:
130
+ st_init = gd.RipsComplex(
131
+ points=x, max_edge_length=self._threshold, sparse=self.sparse
132
+ ).create_simplex_tree(max_dimension=1)
133
+ st_init = mp.simplex_tree_multi.SimplexTreeMulti(
134
+ st_init, num_parameters=2, safe_conversion=self.safe_conversion
135
+ )
136
+ codensities = self._get_codensities(x_fit=x, x_sample=x)
137
+ num_axes = codensities.shape[0]
138
+ sts = [st_init] + [st_init.copy() for _ in range(num_axes - 1)]
139
+ # no need to multithread here, most operations are memory
140
+ for codensity, st_copy in zip(codensities, sts):
141
+ # RIPS has contigus vertices, so vertices are ordered.
142
+ st_copy.fill_lowerstar(codensity, parameter=1)
143
+
144
+ def reduce(st):
145
+ if self.verbose:
146
+ print("Num simplices :", st.num_simplices)
147
+ if isinstance(self.num_collapses, int):
148
+ st.collapse_edges(num=self.num_collapses)
149
+ if self.verbose:
150
+ print(", after collapse :", st.num_simplices, end="")
151
+ elif self.num_collapses == "full":
152
+ st.collapse_edges(full=True)
153
+ if self.verbose:
154
+ print(", after collapse :", st.num_simplices, end="")
155
+ if self.expand_dim > 1:
156
+ st.expansion(self.expand_dim)
157
+ if self.verbose:
158
+ print(", after expansion :", st.num_simplices, end="")
159
+ if self.verbose:
160
+ print("")
161
+ if self._output_type == "slicer":
162
+ st = mp.Slicer(st, vineyard=self._vineyard)
163
+ if self.reduce_degrees is not None:
164
+ st = mp.slicer.minimal_presentation(
165
+ st, degrees=self.reduce_degrees, vineyard=self._vineyard
166
+ )
167
+ return st
168
+
169
+ return Parallel(backend="threading", n_jobs=self.n_jobs)(
170
+ delayed(reduce)(st) for st in sts
171
+ )
172
+
173
def _get_sts_alpha(self, x: np.ndarray, return_alpha=False):
    """Build the alpha-complex filtrations of ``x``, one per codensity row.

    A Gudhi alpha complex of ``x`` is created (cut at
    ``self._threshold ** 2``) and converted to a 2-parameter
    ``SimplexTreeMulti``. One copy is made per row returned by
    ``self._get_codensities`` and each row is written into parameter 1
    with ``fill_lowerstar``.

    Parameters
    ----------
    x
        Point cloud handed to ``gd.AlphaComplex(points=x)``.
    return_alpha
        When true, the ``gd.AlphaComplex`` object is returned as well.

    Returns
    -------
    A list of ``SimplexTreeMulti``; or, when ``"slicer"`` occurs in
    ``self._output_type``, a tuple of slicers (passed through
    ``minimal_presentation`` when ``self.reduce_degrees`` is set).
    With ``return_alpha`` the result is ``(alpha_complex, filtrations)``.
    """
    assert self._output_type is not None and self._vineyard is not None
    alpha_complex = gd.AlphaComplex(points=x)
    alpha_tree = alpha_complex.create_simplex_tree(
        max_alpha_square=self._threshold**2
    )
    vertex_ids = np.array([v for (v,), _ in alpha_tree.get_skeleton(0)])
    # Query the complex for the coordinates of the vertices it actually kept
    # instead of reusing ``x`` directly.
    # NOTE: the original author flagged this lookup as "unsafe for some reason".
    kept_points = np.asarray(
        [alpha_complex.get_point(int(v)) for v in vertex_ids]
    )
    base_tree = mp.simplex_tree_multi.SimplexTreeMulti(
        alpha_tree, num_parameters=2, safe_conversion=self.safe_conversion
    )
    codensities = self._get_codensities(x_fit=x, x_sample=kept_points)

    # One filtered complex per codensity row; the first row reuses the base tree.
    filtrations = [base_tree]
    filtrations.extend(base_tree.copy() for _ in range(codensities.shape[0] - 1))

    # Sequential on purpose: the original notes these operations are
    # memory-bound, so multithreading brings nothing here.
    padded_size = vertex_ids.max() + 2  # margin past the largest vertex id
    for codensity, tree in zip(codensities, filtrations):
        # Vertex ids may have gaps: unused slots stay NaN.
        aligned_codensity = np.full(padded_size, np.nan)
        aligned_codensity[vertex_ids] = codensity
        tree.fill_lowerstar(aligned_codensity, parameter=1)

    if "slicer" in self._output_type:
        slicers = (mp.Slicer(tree, vineyard=self._vineyard) for tree in filtrations)
        if self.reduce_degrees is None:
            filtrations = tuple(slicers)
        else:
            filtrations = tuple(
                mp.slicer.minimal_presentation(
                    slicer, degrees=self.reduce_degrees, vineyard=self._vineyard
                )
                for slicer in slicers
            )

    return (alpha_complex, filtrations) if return_alpha else filtrations
209
+
210
def _get_sts_delaunay(self, x: np.ndarray):
    """Build one function-Delaunay filtration of ``x`` per codensity row.

    Each row of ``self._get_codensities(x_fit=x, x_sample=x)`` is passed,
    together with ``x``, to ``mps.from_function_delaunay``. The result is
    converted to a simplex tree when ``self._output_type`` is
    ``"simplextree"``; otherwise it is reduced with
    ``minimal_presentation`` when ``self.reduce_degrees`` is set, and
    returned untouched when it is not.

    Returns
    -------
    The output of the threaded ``Parallel`` call, one entry per row.
    """
    codensities = self._get_codensities(x_fit=x, x_sample=x)

    def build(codensity):
        complex_ = mps.from_function_delaunay(
            x,
            codensity,
            verbose=self.verbose,
            clear=not self.verbose,
            vineyard=self._vineyard,
        )
        if self._output_type == "simplextree":
            return mps.to_simplextree(complex_)
        if self.reduce_degrees is not None:
            return mp.slicer.minimal_presentation(
                complex_,
                degrees=self.reduce_degrees,
                vineyard=self._vineyard,
            )
        return complex_

    run = Parallel(backend="threading", n_jobs=self.n_jobs)
    return run(delayed(build)(codensity) for codensity in codensities)
237
+
238
def _get_codensities(self, x_fit, x_sample):
    """Evaluate every configured codensity function on ``x_sample``.

    Estimators are fitted on ``x_fit`` and scored on ``x_sample`` (both
    converted to float64 arrays first).

    Returns
    -------
    A 2-D array with ``len(x_sample)`` columns: first one row per entry of
    ``self._bandwidths`` (negated ``KDE`` scores), then one row per entry
    of ``self.masses`` (``DTM`` scores).
    """
    x_fit = np.asarray(x_fit, dtype=np.float64)
    x_sample = np.asarray(x_sample, dtype=np.float64)
    num_samples = len(x_sample)

    kde_rows = []
    for bandwidth in self._bandwidths:
        estimator = KDE(
            bandwidth=bandwidth, kernel=self.kernel, return_log=self.log_density
        )
        # Negate the score so that it becomes a *co*density.
        kde_rows.append(-estimator.fit(x_fit).score_samples(x_sample))
    # The explicit reshape keeps a (0, num_samples) shape when there is no bandwidth.
    codensities_kde = np.asarray(kde_rows).reshape(
        len(self._bandwidths), num_samples
    )

    dtm_scores = DTM(masses=self.masses).fit(x_fit).score_samples(x_sample)
    codensities_dtm = dtm_scores.reshape(len(self.masses), num_samples)

    return np.concatenate([codensities_kde, codensities_dtm])
258
+
259
+ def _define_sts(self):
260
+ match self.complex:
261
+ case "rips":
262
+ self._get_sts = self._get_sts_rips
263
+ _pref_output = "simplextree"
264
+ case "alpha":
265
+ self._get_sts = self._get_sts_alpha
266
+ _pref_output = "simplextree"
267
+ case "delaunay":
268
+ self._get_sts = self._get_sts_delaunay
269
+ _pref_output = "slicer"
270
+ case _:
271
+ raise ValueError(
272
+ f"Invalid complex \
273
+ {self.complex}. Possible choises are rips, delaunay, or alpha."
274
+ )
275
+ self._vineyard = (
276
+ False if self.output_type is None else "novine" not in self.output_type
277
+ )
278
+ self._output_type = (
279
+ _pref_output
280
+ if self.output_type is None
281
+ else (
282
+ "simplextree"
283
+ if (
284
+ self.output_type == "simplextree" or self.reduce_degrees is not None
285
+ )
286
+ else "slicer"
287
+ )
288
+ )
289
+
290
+ def _define_bandwidths(self, X):
291
+ qs = [q for q in [*-np.asarray(self.bandwidths)] if 0 <= q <= 1]
292
+ self._get_distance_quantiles_and_threshold(X, qs=qs)
293
+ self._bandwidths = np.array(self.bandwidths)
294
+ count = 0
295
+ for i in range(len(self._bandwidths)):
296
+ if self.bandwidths[i] < 0:
297
+ self._bandwidths[i] = self._scale[count]
298
+ count += 1
299
+
300
+ def fit(self, X: np.ndarray | list, y=None):
301
+ # self.bandwidth = "silverman" ## not good, as is can make bandwidth not constant
302
+ self._define_sts()
303
+ self._define_bandwidths(X)
304
+ # PRECOMPILE FIRST
305
+ self._get_codensities(X[0][:2], X[0][:2])
306
+ return self
307
+
308
def transform(self, X):
    """Turn every point cloud of ``X`` into its filtered complexes.

    After a warm-up codensity computation on the first two points of
    ``X[0]``, ``self._get_sts`` is applied to each element of ``X`` through
    a threaded ``Parallel`` call, with a ``tqdm`` progress bar that is
    disabled unless ``self.progress`` is true.

    Returns
    -------
    The output of the ``Parallel`` call, one entry per element of ``X``.
    """
    # Warm-up on a tiny sample (the original comment says "precompile first").
    warmup_points = X[0][:2]
    self._get_codensities(warmup_points, warmup_points)

    progress_bar = tqdm(
        X, desc="Filling simplextrees", disable=not self.progress, total=len(X)
    )
    with progress_bar as clouds:
        run = Parallel(backend="threading", n_jobs=self.n_jobs)
        complexes = run(delayed(self._get_sts)(cloud) for cloud in clouds)
    return complexes
319
+
320
+
321
class PointCloud2SimplexTree(PointCloud2FilteredComplex):
    """Deprecated name for ``PointCloud2FilteredComplex``.

    The constructor forwards all of its arguments, unchanged and by keyword,
    to the parent class and then emits a warning pointing users to
    ``PointCloud2FilteredComplex``.
    """

    def __init__(
        self,
        # The mutable defaults are kept on purpose: they are forwarded verbatim
        # to the parent constructor, so changing them would alter the interface.
        bandwidths=[],
        masses=[],
        threshold: float = np.inf,
        complex: Literal["alpha", "rips", "delaunay"] = "rips",
        sparse: float | None = None,
        num_collapses: int = -2,
        kernel: available_kernels = "gaussian",
        log_density: bool = True,
        expand_dim: int = 1,
        progress: bool = False,
        n_jobs: Optional[int] = None,
        fit_fraction: float = 1,
        verbose: bool = False,
        safe_conversion: bool = False,
        output_type: Optional[
            Literal["slicer", "simplextree", "slicer_vine", "slicer_novine"]
        ] = None,
        reduce_degrees: Optional[Iterable[int]] = None,
    ) -> None:
        # Must remain the first statement so that the snapshot holds only the
        # constructor arguments (plus ``self`` and dunder entries such as the
        # ``__class__`` cell that zero-argument ``super()`` relies on).
        init_args = locals()
        forwarded = {
            name: value
            for name, value in init_args.items()
            if name != "self" and not name.startswith("__")
        }
        super().__init__(**forwarded)

        # Imported after the ``locals()`` snapshot so it cannot leak into it.
        from warnings import warn

        # stacklevel=2 attributes the warning to the code instantiating the
        # class rather than to this constructor.
        warn(
            "This class is deprecated, use PointCloud2FilteredComplex instead.",
            stacklevel=2,
        )