multipers 2.2.3__cp310-cp310-win_amd64.whl → 2.3.0__cp310-cp310-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic. Click here for more details.

Files changed (182)
  1. multipers/__init__.py +33 -31
  2. multipers/_signed_measure_meta.py +430 -430
  3. multipers/_slicer_meta.py +211 -212
  4. multipers/data/MOL2.py +458 -458
  5. multipers/data/UCR.py +18 -18
  6. multipers/data/graphs.py +466 -466
  7. multipers/data/immuno_regions.py +27 -27
  8. multipers/data/pytorch2simplextree.py +90 -90
  9. multipers/data/shape3d.py +101 -101
  10. multipers/data/synthetic.py +113 -111
  11. multipers/distances.py +198 -198
  12. multipers/filtration_conversions.pxd.tp +84 -84
  13. multipers/filtrations/__init__.py +18 -0
  14. multipers/filtrations/filtrations.py +289 -0
  15. multipers/filtrations.pxd +224 -224
  16. multipers/function_rips.cp310-win_amd64.pyd +0 -0
  17. multipers/function_rips.pyx +105 -105
  18. multipers/grids.cp310-win_amd64.pyd +0 -0
  19. multipers/grids.pyx +350 -350
  20. multipers/gudhi/Persistence_slices_interface.h +132 -132
  21. multipers/gudhi/Simplex_tree_interface.h +239 -245
  22. multipers/gudhi/Simplex_tree_multi_interface.h +516 -561
  23. multipers/gudhi/cubical_to_boundary.h +59 -59
  24. multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -450
  25. multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -1070
  26. multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -579
  27. multipers/gudhi/gudhi/Debug_utils.h +45 -45
  28. multipers/gudhi/gudhi/Fields/Multi_field.h +484 -484
  29. multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -455
  30. multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -450
  31. multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -531
  32. multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -507
  33. multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -531
  34. multipers/gudhi/gudhi/Fields/Z2_field.h +355 -355
  35. multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -376
  36. multipers/gudhi/gudhi/Fields/Zp_field.h +420 -420
  37. multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -400
  38. multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -418
  39. multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -337
  40. multipers/gudhi/gudhi/Matrix.h +2107 -2107
  41. multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -1038
  42. multipers/gudhi/gudhi/Multi_persistence/Box.h +171 -171
  43. multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -282
  44. multipers/gudhi/gudhi/Off_reader.h +173 -173
  45. multipers/gudhi/gudhi/One_critical_filtration.h +1432 -1431
  46. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -769
  47. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -686
  48. multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -842
  49. multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -1350
  50. multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -1105
  51. multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -859
  52. multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -910
  53. multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -139
  54. multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -230
  55. multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -211
  56. multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -60
  57. multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -60
  58. multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -136
  59. multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -190
  60. multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -616
  61. multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -150
  62. multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -106
  63. multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -219
  64. multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -327
  65. multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -1140
  66. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -934
  67. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -934
  68. multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -980
  69. multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -1092
  70. multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -192
  71. multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -921
  72. multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -1093
  73. multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -1012
  74. multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -1244
  75. multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -186
  76. multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -164
  77. multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -156
  78. multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -376
  79. multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -540
  80. multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -118
  81. multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -173
  82. multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -128
  83. multipers/gudhi/gudhi/Persistent_cohomology.h +745 -745
  84. multipers/gudhi/gudhi/Points_off_io.h +171 -171
  85. multipers/gudhi/gudhi/Simple_object_pool.h +69 -69
  86. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -463
  87. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -83
  88. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -106
  89. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -277
  90. multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -62
  91. multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -27
  92. multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -62
  93. multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -157
  94. multipers/gudhi/gudhi/Simplex_tree.h +2794 -2794
  95. multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -163
  96. multipers/gudhi/gudhi/distance_functions.h +62 -62
  97. multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -104
  98. multipers/gudhi/gudhi/persistence_interval.h +253 -253
  99. multipers/gudhi/gudhi/persistence_matrix_options.h +170 -170
  100. multipers/gudhi/gudhi/reader_utils.h +367 -367
  101. multipers/gudhi/mma_interface_coh.h +256 -255
  102. multipers/gudhi/mma_interface_h0.h +223 -231
  103. multipers/gudhi/mma_interface_matrix.h +284 -282
  104. multipers/gudhi/naive_merge_tree.h +536 -575
  105. multipers/gudhi/scc_io.h +310 -289
  106. multipers/gudhi/truc.h +890 -888
  107. multipers/io.cp310-win_amd64.pyd +0 -0
  108. multipers/io.pyx +711 -711
  109. multipers/ml/accuracies.py +90 -90
  110. multipers/ml/convolutions.py +520 -520
  111. multipers/ml/invariants_with_persistable.py +79 -79
  112. multipers/ml/kernels.py +176 -176
  113. multipers/ml/mma.py +713 -714
  114. multipers/ml/one.py +472 -472
  115. multipers/ml/point_clouds.py +352 -346
  116. multipers/ml/signed_measures.py +1589 -1589
  117. multipers/ml/sliced_wasserstein.py +461 -461
  118. multipers/ml/tools.py +113 -113
  119. multipers/mma_structures.cp310-win_amd64.pyd +0 -0
  120. multipers/mma_structures.pxd +127 -127
  121. multipers/mma_structures.pyx +4 -4
  122. multipers/mma_structures.pyx.tp +1085 -1085
  123. multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -93
  124. multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -97
  125. multipers/multi_parameter_rank_invariant/function_rips.h +322 -322
  126. multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -769
  127. multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -148
  128. multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -369
  129. multipers/multiparameter_edge_collapse.py +41 -41
  130. multipers/multiparameter_module_approximation/approximation.h +2296 -2295
  131. multipers/multiparameter_module_approximation/combinatory.h +129 -129
  132. multipers/multiparameter_module_approximation/debug.h +107 -107
  133. multipers/multiparameter_module_approximation/format_python-cpp.h +286 -286
  134. multipers/multiparameter_module_approximation/heap_column.h +238 -238
  135. multipers/multiparameter_module_approximation/images.h +79 -79
  136. multipers/multiparameter_module_approximation/list_column.h +174 -174
  137. multipers/multiparameter_module_approximation/list_column_2.h +232 -232
  138. multipers/multiparameter_module_approximation/ru_matrix.h +347 -347
  139. multipers/multiparameter_module_approximation/set_column.h +135 -135
  140. multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -36
  141. multipers/multiparameter_module_approximation/unordered_set_column.h +166 -166
  142. multipers/multiparameter_module_approximation/utilities.h +403 -419
  143. multipers/multiparameter_module_approximation/vector_column.h +223 -223
  144. multipers/multiparameter_module_approximation/vector_matrix.h +331 -331
  145. multipers/multiparameter_module_approximation/vineyards.h +464 -464
  146. multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -649
  147. multipers/multiparameter_module_approximation.cp310-win_amd64.pyd +0 -0
  148. multipers/multiparameter_module_approximation.pyx +216 -217
  149. multipers/pickle.py +90 -53
  150. multipers/plots.py +342 -334
  151. multipers/point_measure.cp310-win_amd64.pyd +0 -0
  152. multipers/point_measure.pyx +322 -320
  153. multipers/simplex_tree_multi.cp310-win_amd64.pyd +0 -0
  154. multipers/simplex_tree_multi.pxd +133 -133
  155. multipers/simplex_tree_multi.pyx +18 -15
  156. multipers/simplex_tree_multi.pyx.tp +1939 -1935
  157. multipers/slicer.cp310-win_amd64.pyd +0 -0
  158. multipers/slicer.pxd +81 -20
  159. multipers/slicer.pxd.tp +215 -214
  160. multipers/slicer.pyx +1091 -308
  161. multipers/slicer.pyx.tp +924 -914
  162. multipers/tensor/tensor.h +672 -672
  163. multipers/tensor.pxd +13 -13
  164. multipers/test.pyx +44 -44
  165. multipers/tests/__init__.py +57 -57
  166. multipers/torch/diff_grids.py +217 -217
  167. multipers/torch/rips_density.py +310 -304
  168. {multipers-2.2.3.dist-info → multipers-2.3.0.dist-info}/LICENSE +21 -21
  169. {multipers-2.2.3.dist-info → multipers-2.3.0.dist-info}/METADATA +21 -11
  170. multipers-2.3.0.dist-info/RECORD +182 -0
  171. multipers/tests/test_diff_helper.py +0 -73
  172. multipers/tests/test_hilbert_function.py +0 -82
  173. multipers/tests/test_mma.py +0 -83
  174. multipers/tests/test_point_clouds.py +0 -49
  175. multipers/tests/test_python-cpp_conversion.py +0 -82
  176. multipers/tests/test_signed_betti.py +0 -181
  177. multipers/tests/test_signed_measure.py +0 -89
  178. multipers/tests/test_simplextreemulti.py +0 -221
  179. multipers/tests/test_slicer.py +0 -221
  180. multipers-2.2.3.dist-info/RECORD +0 -189
  181. {multipers-2.2.3.dist-info → multipers-2.3.0.dist-info}/WHEEL +0 -0
  182. {multipers-2.2.3.dist-info → multipers-2.3.0.dist-info}/top_level.txt +0 -0
multipers/ml/point_clouds.py
@@ -1,346 +1,352 @@
1
- from collections.abc import Callable, Iterable
2
- from typing import Literal, Optional
3
-
4
- import gudhi as gd
5
- import numpy as np
6
- from joblib import Parallel, delayed
7
- from sklearn.base import BaseEstimator, TransformerMixin
8
- from tqdm import tqdm
9
-
10
- import multipers as mp
11
- import multipers.slicer as mps
12
- from multipers.ml.convolutions import DTM, KDE, available_kernels
13
-
14
-
15
- class PointCloud2FilteredComplex(BaseEstimator, TransformerMixin):
16
- def __init__(
17
- self,
18
- bandwidths=[],
19
- masses=[],
20
- threshold: float = -np.inf,
21
- complex: Literal["alpha", "rips", "delaunay"] = "rips",
22
- sparse: Optional[float] = None,
23
- num_collapses: int = -2,
24
- kernel: available_kernels = "gaussian",
25
- log_density: bool = True,
26
- expand_dim: int = 1,
27
- progress: bool = False,
28
- n_jobs: Optional[int] = None,
29
- fit_fraction: float = 1,
30
- verbose: bool = False,
31
- safe_conversion: bool = False,
32
- output_type: Optional[
33
- Literal["slicer", "simplextree", "slicer_vine", "slicer_novine"]
34
- ] = None,
35
- reduce_degrees: Optional[Iterable[int]] = None,
36
- ) -> None:
37
- """
38
- (Rips or Alpha or Delaunay) + (Density Estimation or DTM) 1-critical 2-filtration.
39
-
40
- Parameters
41
- ----------
42
- - bandwidth : real : The kernel density estimation bandwidth, or the DTM mass. If negative, it replaced by abs(bandwidth)*(radius of the dataset)
43
- - threshold : real, max edge lenfth of the rips or max alpha square of the alpha
44
- - sparse : real, sparse rips (c.f. rips doc) WARNING : ONLY FOR RIPS
45
- - num_collapse : int, Number of edge collapses applied to the simplextrees, WARNING : ONLY FOR RIPS
46
- - expand_dim : int, expand the rips complex to this dimension. WARNING : ONLY FOR RIPS
47
- - kernel : the kernel used for density estimation. Available ones are, e.g., "dtm", "gaussian", "exponential".
48
- - progress : bool, shows the calculus status
49
- - n_jobs : number of processes
50
- - fit_fraction : real, the fraction of data on which to fit
51
- - verbose : bool, Shows more information if true.
52
-
53
- Output
54
- ------
55
- A list of SimplexTreeMulti whose first parameter is a rips and the second is the codensity.
56
- """
57
- super().__init__()
58
- self.bandwidths = bandwidths
59
- self.masses = masses
60
- self.num_collapses = num_collapses
61
- self.kernel = kernel
62
- self.log_density = log_density
63
- self.progress = progress
64
- self._bandwidths = np.empty((0,))
65
- self._threshold = np.inf
66
- self.n_jobs = n_jobs
67
- self._scale = np.empty((0,))
68
- self.fit_fraction = fit_fraction
69
- self.expand_dim = expand_dim
70
- self.verbose = verbose
71
- self.complex = complex
72
- self.threshold = threshold
73
- self.sparse = sparse
74
- self._get_sts = lambda: Exception("Fit first")
75
- self.safe_conversion = safe_conversion
76
- self.output_type = output_type
77
- self._output_type = None
78
- self.reduce_degrees = reduce_degrees
79
- self._vineyard = None
80
-
81
- assert (
82
- output_type != "simplextree" or reduce_degrees is None
83
- ), "Reduced complex are not simplicial. Cannot return a simplextree."
84
- return
85
-
86
- def _get_distance_quantiles_and_threshold(self, X, qs):
87
- ## if we dont need to compute a distance matrix
88
- if len(qs) == 0 and self.threshold >= 0:
89
- self._scale = []
90
- return []
91
- if self.progress:
92
- print("Estimating scale...", flush=True, end="")
93
- ## subsampling
94
- indices = np.random.choice(
95
- len(X), min(len(X), int(self.fit_fraction * len(X)) + 1), replace=False
96
- )
97
-
98
- def compute_max_scale(x):
99
- from pykeops.numpy import LazyTensor
100
-
101
- a = LazyTensor(x[None, :, :])
102
- b = LazyTensor(x[:, None, :])
103
- return np.sqrt(((a - b) ** 2).sum(2).max(1).min(0)[0])
104
-
105
- diameter = np.max([compute_max_scale(x) for x in (X[i] for i in indices)])
106
- self._scale = diameter * np.array(qs)
107
-
108
- if self.threshold == -np.inf:
109
- self._threshold = diameter
110
- elif self.threshold > 0:
111
- self._threshold = self.threshold
112
- else:
113
- self._threshold = -diameter * self.threshold
114
-
115
- if self.threshold > 0:
116
- self._scale[self._scale > self.threshold] = self.threshold
117
-
118
- if self.progress:
119
- print(f"Done. Chosen scales {qs} are {self._scale}", flush=True)
120
- return self._scale
121
-
122
- def _get_sts_rips(self, x):
123
- assert self._output_type is not None and self._vineyard is not None
124
- st_init = gd.RipsComplex(
125
- points=x, max_edge_length=self._threshold, sparse=self.sparse
126
- ).create_simplex_tree(max_dimension=1)
127
- st_init = mp.simplex_tree_multi.SimplexTreeMulti(
128
- st_init, num_parameters=2, safe_conversion=self.safe_conversion
129
- )
130
- codensities = self._get_codensities(x_fit=x, x_sample=x)
131
- num_axes = codensities.shape[0]
132
- sts = [st_init] + [st_init.copy() for _ in range(num_axes - 1)]
133
- # no need to multithread here, most operations are memory
134
- for codensity, st_copy in zip(codensities, sts):
135
- # RIPS has contigus vertices, so vertices are ordered.
136
- st_copy.fill_lowerstar(codensity, parameter=1)
137
-
138
- def reduce(st):
139
- if self.verbose:
140
- print("Num simplices :", st.num_simplices)
141
- if isinstance(self.num_collapses, int):
142
- st.collapse_edges(num=self.num_collapses)
143
- if self.verbose:
144
- print(", after collapse :", st.num_simplices, end="")
145
- elif self.num_collapses == "full":
146
- st.collapse_edges(full=True)
147
- if self.verbose:
148
- print(", after collapse :", st.num_simplices, end="")
149
- if self.expand_dim > 1:
150
- st.expansion(self.expand_dim)
151
- if self.verbose:
152
- print(", after expansion :", st.num_simplices, end="")
153
- if self.verbose:
154
- print("")
155
- if self._output_type == "slicer":
156
- st = mp.Slicer(st, vineyard=self._vineyard)
157
- if self.reduce_degrees is not None:
158
- st = mp.slicer.minimal_presentation(
159
- st, degrees=self.reduce_degrees, vineyard=self._vineyard
160
- )
161
- return st
162
-
163
- return Parallel(backend="threading", n_jobs=self.n_jobs)(
164
- delayed(reduce)(st) for st in sts
165
- )
166
-
167
- def _get_sts_alpha(self, x: np.ndarray, return_alpha=False):
168
- assert self._output_type is not None and self._vineyard is not None
169
- alpha_complex = gd.AlphaComplex(points=x)
170
- st = alpha_complex.create_simplex_tree(max_alpha_square=self._threshold**2)
171
- vertices = np.array([i for (i,), _ in st.get_skeleton(0)])
172
- new_points = np.asarray(
173
- [alpha_complex.get_point(i) for i in vertices]
174
- ) # Seems to be unsafe for some reason
175
- # new_points = x
176
- st = mp.simplex_tree_multi.SimplexTreeMulti(
177
- st, num_parameters=2, safe_conversion=self.safe_conversion
178
- )
179
- codensities = self._get_codensities(x_fit=x, x_sample=new_points)
180
- num_axes = codensities.shape[0]
181
- sts = [st] + [st.copy() for _ in range(num_axes - 1)]
182
- # no need to multithread here, most operations are memory
183
- max_vertices = vertices.max() + 2 # +1 to be safe
184
- for codensity, st_copy in zip(codensities, sts):
185
- alligned_codensity = np.array([np.nan] * max_vertices)
186
- alligned_codensity[vertices] = codensity
187
- # alligned_codensity = np.array([codensity[i] if i in vertices else np.nan for i in range(max_vertices)])
188
- st_copy.fill_lowerstar(alligned_codensity, parameter=1)
189
- if "slicer" in self._output_type:
190
- sts2 = (mp.Slicer(st, vineyard=self._vineyard) for st in sts)
191
- if self.reduce_degrees is not None:
192
- sts = tuple(
193
- mp.slicer.minimal_presentation(
194
- s, degrees=self.reduce_degrees, vineyard=self._vineyard
195
- )
196
- for s in sts2
197
- )
198
- else:
199
- sts = tuple(sts2)
200
- if return_alpha:
201
- return alpha_complex, sts
202
- return sts
203
-
204
- def _get_sts_delaunay(self, x: np.ndarray):
205
- codensities = self._get_codensities(x_fit=x, x_sample=x)
206
-
207
- def get_st(c):
208
- slicer = mps.from_function_delaunay(
209
- x,
210
- c,
211
- verbose=self.verbose,
212
- clear=not self.verbose,
213
- vineyard=self._vineyard,
214
- )
215
- if self._output_type == "simplextree":
216
- slicer = mps.to_simplextree(slicer)
217
- elif self.reduce_degrees is not None:
218
- slicer = mp.slicer.minimal_presentation(
219
- slicer,
220
- degrees=self.reduce_degrees,
221
- vineyard=self._vineyard,
222
- )
223
- else:
224
- slicer = slicer
225
- return slicer
226
-
227
- sts = Parallel(backend="threading", n_jobs=self.n_jobs)(
228
- delayed(get_st)(c) for c in codensities
229
- )
230
- return sts
231
-
232
- def _get_codensities(self, x_fit, x_sample):
233
- x_fit = np.asarray(x_fit, dtype=np.float64)
234
- x_sample = np.asarray(x_sample, dtype=np.float64)
235
- codensities_kde = np.asarray(
236
- [
237
- -KDE(
238
- bandwidth=bandwidth, kernel=self.kernel, return_log=self.log_density
239
- )
240
- .fit(x_fit)
241
- .score_samples(x_sample)
242
- for bandwidth in self._bandwidths
243
- ],
244
- ).reshape(len(self._bandwidths), len(x_sample))
245
- codensities_dtm = (
246
- DTM(masses=self.masses)
247
- .fit(x_fit)
248
- .score_samples(x_sample)
249
- .reshape(len(self.masses), len(x_sample))
250
- )
251
- return np.concatenate([codensities_kde, codensities_dtm])
252
-
253
- def _define_sts(self):
254
- match self.complex:
255
- case "rips":
256
- self._get_sts = self._get_sts_rips
257
- _pref_output = "simplextree"
258
- case "alpha":
259
- self._get_sts = self._get_sts_alpha
260
- _pref_output = "simplextree"
261
- case "delaunay":
262
- self._get_sts = self._get_sts_delaunay
263
- _pref_output = "slicer"
264
- case _:
265
- raise ValueError(
266
- f"Invalid complex \
267
- {self.complex}. Possible choises are rips, delaunay, or alpha."
268
- )
269
- self._vineyard = (
270
- False if self.output_type is None else "novine" not in self.output_type
271
- )
272
- self._output_type = (
273
- _pref_output
274
- if self.output_type is None
275
- else (
276
- "simplextree"
277
- if (
278
- self.output_type == "simplextree" or self.reduce_degrees is not None
279
- )
280
- else "slicer"
281
- )
282
- )
283
-
284
- def _define_bandwidths(self, X):
285
- qs = [q for q in [*-np.asarray(self.bandwidths)] if 0 <= q <= 1]
286
- self._get_distance_quantiles_and_threshold(X, qs=qs)
287
- self._bandwidths = np.array(self.bandwidths)
288
- count = 0
289
- for i in range(len(self._bandwidths)):
290
- if self.bandwidths[i] < 0:
291
- self._bandwidths[i] = self._scale[count]
292
- count += 1
293
-
294
- def fit(self, X: np.ndarray | list, y=None):
295
- # self.bandwidth = "silverman" ## not good, as is can make bandwidth not constant
296
- self._define_sts()
297
- self._define_bandwidths(X)
298
- # PRECOMPILE FIRST
299
- self._get_codensities(X[0][:2], X[0][:2])
300
- return self
301
-
302
- def transform(self, X):
303
- # precompile first
304
- # self._get_sts(X[0][:5])
305
- self._get_codensities(X[0][:2], X[0][:2])
306
- with tqdm(
307
- X, desc="Filling simplextrees", disable=not self.progress, total=len(X)
308
- ) as data:
309
- stss = Parallel(backend="threading", n_jobs=self.n_jobs)(
310
- delayed(self._get_sts)(x) for x in data
311
- )
312
- return stss
313
-
314
-
315
- class PointCloud2SimplexTree(PointCloud2FilteredComplex):
316
- def __init__(
317
- self,
318
- bandwidths=[],
319
- masses=[],
320
- threshold: float = np.inf,
321
- complex: Literal["alpha", "rips", "delaunay"] = "rips",
322
- sparse: float | None = None,
323
- num_collapses: int = -2,
324
- kernel: available_kernels = "gaussian",
325
- log_density: bool = True,
326
- expand_dim: int = 1,
327
- progress: bool = False,
328
- n_jobs: Optional[int] = None,
329
- fit_fraction: float = 1,
330
- verbose: bool = False,
331
- safe_conversion: bool = False,
332
- output_type: Optional[
333
- Literal["slicer", "simplextree", "slicer_vine", "slicer_novine"]
334
- ] = None,
335
- reduce_degrees: Optional[Iterable[int]] = None,
336
- ) -> None:
337
- stuff = locals()
338
- stuff.pop("self")
339
- keys = list(stuff.keys())
340
- for key in keys:
341
- if key.startswith("__"):
342
- stuff.pop(key)
343
- super().__init__(**stuff)
344
- from warnings import warn
345
-
346
- warn("This class is deprecated, use PointCloud2FilteredComplex instead.")
1
+ from collections.abc import Iterable
2
+ from typing import Literal, Optional
3
+
4
+ import gudhi as gd
5
+ import numpy as np
6
+ from joblib import Parallel, delayed
7
+ from sklearn.base import BaseEstimator, TransformerMixin
8
+ from scipy.spatial.distance import cdist
9
+ from tqdm import tqdm
10
+
11
+ import multipers as mp
12
+ import multipers.slicer as mps
13
+ from multipers.ml.convolutions import DTM, KDE, available_kernels
14
+
15
+
16
+ class PointCloud2FilteredComplex(BaseEstimator, TransformerMixin):
17
+ def __init__(
18
+ self,
19
+ bandwidths=[],
20
+ masses=[],
21
+ threshold: float = -np.inf,
22
+ complex: Literal["alpha", "rips", "delaunay"] = "rips",
23
+ sparse: Optional[float] = None,
24
+ num_collapses: int = -2,
25
+ kernel: available_kernels = "gaussian",
26
+ log_density: bool = True,
27
+ expand_dim: int = 1,
28
+ progress: bool = False,
29
+ n_jobs: Optional[int] = None,
30
+ fit_fraction: float = 1,
31
+ verbose: bool = False,
32
+ safe_conversion: bool = False,
33
+ output_type: Optional[
34
+ Literal["slicer", "simplextree", "slicer_vine", "slicer_novine"]
35
+ ] = None,
36
+ reduce_degrees: Optional[Iterable[int]] = None,
37
+ ) -> None:
38
+ """
39
+ (Rips or Alpha or Delaunay) + (Density Estimation or DTM) 1-critical 2-filtration.
40
+
41
+ Parameters
42
+ ----------
43
+ - bandwidth : real : The kernel density estimation bandwidth, or the DTM mass. If negative, it replaced by abs(bandwidth)*(radius of the dataset)
44
+ - threshold : real, max edge lenfth of the rips or max alpha square of the alpha
45
+ - sparse : real, sparse rips (c.f. rips doc) WARNING : ONLY FOR RIPS
46
+ - num_collapse : int, Number of edge collapses applied to the simplextrees, WARNING : ONLY FOR RIPS
47
+ - expand_dim : int, expand the rips complex to this dimension. WARNING : ONLY FOR RIPS
48
+ - kernel : the kernel used for density estimation. Available ones are, e.g., "dtm", "gaussian", "exponential".
49
+ - progress : bool, shows the calculus status
50
+ - n_jobs : number of processes
51
+ - fit_fraction : real, the fraction of data on which to fit
52
+ - verbose : bool, Shows more information if true.
53
+
54
+ Output
55
+ ------
56
+ A list of SimplexTreeMulti whose first parameter is a rips and the second is the codensity.
57
+ """
58
+ super().__init__()
59
+ self.bandwidths = bandwidths
60
+ self.masses = masses
61
+ self.num_collapses = num_collapses
62
+ self.kernel = kernel
63
+ self.log_density = log_density
64
+ self.progress = progress
65
+ self._bandwidths = np.empty((0,))
66
+ self._threshold = np.inf
67
+ self.n_jobs = n_jobs
68
+ self._scale = np.empty((0,))
69
+ self.fit_fraction = fit_fraction
70
+ self.expand_dim = expand_dim
71
+ self.verbose = verbose
72
+ self.complex = complex
73
+ self.threshold = threshold
74
+ self.sparse = sparse
75
+ self._get_sts = lambda: Exception("Fit first")
76
+ self.safe_conversion = safe_conversion
77
+ self.output_type = output_type
78
+ self._output_type = None
79
+ self.reduce_degrees = reduce_degrees
80
+ self._vineyard = None
81
+
82
+ assert (
83
+ output_type != "simplextree" or reduce_degrees is None
84
+ ), "Reduced complex are not simplicial. Cannot return a simplextree."
85
+ return
86
+
87
+ def _get_distance_quantiles_and_threshold(self, X, qs):
88
+ ## if we dont need to compute a distance matrix
89
+ if len(qs) == 0 and self.threshold >= 0:
90
+ self._scale = []
91
+ return []
92
+ if self.progress:
93
+ print("Estimating scale...", flush=True, end="")
94
+ ## subsampling
95
+ indices = np.random.choice(
96
+ len(X), min(len(X), int(self.fit_fraction * len(X)) + 1), replace=False
97
+ )
98
+
99
+ def compute_max_scale(x):
100
+ from pykeops.numpy import LazyTensor
101
+
102
+ a = LazyTensor(x[None, :, :])
103
+ b = LazyTensor(x[:, None, :])
104
+ return np.sqrt(((a - b) ** 2).sum(2).max(1).min(0)[0])
105
+
106
+ diameter = np.max([compute_max_scale(x) for x in (X[i] for i in indices)])
107
+ self._scale = diameter * np.array(qs)
108
+
109
+ if self.threshold == -np.inf:
110
+ self._threshold = diameter
111
+ elif self.threshold > 0:
112
+ self._threshold = self.threshold
113
+ else:
114
+ self._threshold = -diameter * self.threshold
115
+
116
+ if self.threshold > 0:
117
+ self._scale[self._scale > self.threshold] = self.threshold
118
+
119
+ if self.progress:
120
+ print(f"Done. Chosen scales {qs} are {self._scale}", flush=True)
121
+ return self._scale
122
+
123
+ def _get_sts_rips(self, x):
124
+ assert self._output_type is not None and self._vineyard is not None
125
+ if self.sparse is None:
126
+ st_init = gd.SimplexTree.create_from_array(
127
+ cdist(x,x), max_filtration=self._threshold
128
+ )
129
+ else:
130
+ st_init = gd.RipsComplex(
131
+ points=x, max_edge_length=self._threshold, sparse=self.sparse
132
+ ).create_simplex_tree(max_dimension=1)
133
+ st_init = mp.simplex_tree_multi.SimplexTreeMulti(
134
+ st_init, num_parameters=2, safe_conversion=self.safe_conversion
135
+ )
136
+ codensities = self._get_codensities(x_fit=x, x_sample=x)
137
+ num_axes = codensities.shape[0]
138
+ sts = [st_init] + [st_init.copy() for _ in range(num_axes - 1)]
139
+ # no need to multithread here, most operations are memory
140
+ for codensity, st_copy in zip(codensities, sts):
141
+ # RIPS has contigus vertices, so vertices are ordered.
142
+ st_copy.fill_lowerstar(codensity, parameter=1)
143
+
144
+ def reduce(st):
145
+ if self.verbose:
146
+ print("Num simplices :", st.num_simplices)
147
+ if isinstance(self.num_collapses, int):
148
+ st.collapse_edges(num=self.num_collapses)
149
+ if self.verbose:
150
+ print(", after collapse :", st.num_simplices, end="")
151
+ elif self.num_collapses == "full":
152
+ st.collapse_edges(full=True)
153
+ if self.verbose:
154
+ print(", after collapse :", st.num_simplices, end="")
155
+ if self.expand_dim > 1:
156
+ st.expansion(self.expand_dim)
157
+ if self.verbose:
158
+ print(", after expansion :", st.num_simplices, end="")
159
+ if self.verbose:
160
+ print("")
161
+ if self._output_type == "slicer":
162
+ st = mp.Slicer(st, vineyard=self._vineyard)
163
+ if self.reduce_degrees is not None:
164
+ st = mp.slicer.minimal_presentation(
165
+ st, degrees=self.reduce_degrees, vineyard=self._vineyard
166
+ )
167
+ return st
168
+
169
+ return Parallel(backend="threading", n_jobs=self.n_jobs)(
170
+ delayed(reduce)(st) for st in sts
171
+ )
172
+
173
def _get_sts_alpha(self, x: np.ndarray, return_alpha=False):
    """Build one two-parameter alpha-complex structure per codensity row.

    A Gudhi alpha complex is built on ``x`` (capped at
    ``self._threshold**2``), converted to a 2-parameter ``SimplexTreeMulti``
    and duplicated once per row returned by ``_get_codensities``.  Each copy
    has its parameter ``1`` filled via ``fill_lowerstar`` with that row.

    Args:
        x: Point cloud handed to ``gd.AlphaComplex`` and used to fit the
            codensity estimators.
        return_alpha: When true, also return the ``AlphaComplex`` object.

    Returns:
        ``sts`` or ``(alpha_complex, sts)``.  ``sts`` is a list of simplex
        trees, or a tuple of slicers when ``"slicer"`` occurs in
        ``self._output_type`` (minimal presentations if
        ``self.reduce_degrees`` is set).
    """
    assert self._output_type is not None and self._vineyard is not None
    alpha_complex = gd.AlphaComplex(points=x)
    base_tree = alpha_complex.create_simplex_tree(
        max_alpha_square=self._threshold**2
    )
    # Vertex labels actually present in the complex (its 0-skeleton).
    vertices = np.array([vertex for (vertex,), _ in base_tree.get_skeleton(0)])
    # Coordinates are read back from the complex rather than taken from `x`.
    # NOTE: the original author flagged this lookup as "unsafe for some
    # reason" and kept `new_points = x` around as an alternative.
    kept_points = np.asarray([alpha_complex.get_point(v) for v in vertices])
    multi_tree = mp.simplex_tree_multi.SimplexTreeMulti(
        base_tree, num_parameters=2, safe_conversion=self.safe_conversion
    )
    codensities = self._get_codensities(x_fit=x, x_sample=kept_points)
    num_axes = codensities.shape[0]
    sts = [multi_tree]
    sts.extend(multi_tree.copy() for _ in range(num_axes - 1))
    # Sequential on purpose: the original notes these operations are
    # memory-bound, so threading would not help.
    # max label + 1 slots are needed to index by label; one more is padding.
    slot_count = vertices.max() + 2
    for codensity, tree in zip(codensities, sts):
        # Labels without a vertex in the complex stay NaN.
        aligned_codensity = np.full(slot_count, np.nan)
        aligned_codensity[vertices] = codensity
        tree.fill_lowerstar(aligned_codensity, parameter=1)
    if "slicer" in self._output_type:
        slicers = (mp.Slicer(tree, vineyard=self._vineyard) for tree in sts)
        if self.reduce_degrees is None:
            sts = tuple(slicers)
        else:
            sts = tuple(
                mp.slicer.minimal_presentation(
                    slicer, degrees=self.reduce_degrees, vineyard=self._vineyard
                )
                for slicer in slicers
            )
    return (alpha_complex, sts) if return_alpha else sts
209
+
210
def _get_sts_delaunay(self, x: np.ndarray):
    """Build one function-Delaunay structure per codensity row of ``x``.

    Codensities are evaluated on ``x`` itself, then each row is passed to
    ``mps.from_function_delaunay`` in a joblib threading pool.

    Args:
        x: Point cloud used both to fit and to sample the codensities.

    Returns:
        The list produced by ``Parallel``: per row, a simplex tree when
        ``self._output_type == "simplextree"``, otherwise a slicer (its
        minimal presentation if ``self.reduce_degrees`` is set).
    """
    codensities = self._get_codensities(x_fit=x, x_sample=x)

    def build(codensity):
        # One worker call: construct the slicer, then convert/reduce it.
        result = mps.from_function_delaunay(
            x,
            codensity,
            verbose=self.verbose,
            clear=not self.verbose,
            vineyard=self._vineyard,
        )
        if self._output_type == "simplextree":
            # reduce_degrees is not applied on the simplex-tree path.
            return mps.to_simplextree(result)
        if self.reduce_degrees is None:
            return result
        return mp.slicer.minimal_presentation(
            result,
            degrees=self.reduce_degrees,
            vineyard=self._vineyard,
        )

    pool = Parallel(backend="threading", n_jobs=self.n_jobs)
    return pool(delayed(build)(codensity) for codensity in codensities)
237
+
238
def _get_codensities(self, x_fit, x_sample):
    """Evaluate every configured codensity on ``x_sample``.

    One row is produced per entry of ``self._bandwidths`` (negated ``KDE``
    scores) followed by one row per entry of ``self.masses`` (``DTM``
    scores); all estimators are fitted on ``x_fit``.

    Args:
        x_fit: Points the estimators are fitted on (cast to float64).
        x_sample: Points the estimators are scored on (cast to float64).

    Returns:
        Array of shape
        ``(len(self._bandwidths) + len(self.masses), len(x_sample))``.
    """
    x_fit = np.asarray(x_fit, dtype=np.float64)
    x_sample = np.asarray(x_sample, dtype=np.float64)
    num_samples = len(x_sample)

    kde_rows = []
    for bandwidth in self._bandwidths:
        estimator = KDE(
            bandwidth=bandwidth, kernel=self.kernel, return_log=self.log_density
        )
        # Sign flip: the KDE score is negated before being stored.
        kde_rows.append(-estimator.fit(x_fit).score_samples(x_sample))
    # Explicit reshape keeps a well-formed (0, n) block when there are no
    # bandwidths, so the concatenation below still works.
    codensities_kde = np.asarray(kde_rows).reshape(
        len(self._bandwidths), num_samples
    )

    dtm_scores = DTM(masses=self.masses).fit(x_fit).score_samples(x_sample)
    codensities_dtm = dtm_scores.reshape(len(self.masses), num_samples)
    return np.concatenate([codensities_kde, codensities_dtm])
258
+
259
+ def _define_sts(self):
260
+ match self.complex:
261
+ case "rips":
262
+ self._get_sts = self._get_sts_rips
263
+ _pref_output = "simplextree"
264
+ case "alpha":
265
+ self._get_sts = self._get_sts_alpha
266
+ _pref_output = "simplextree"
267
+ case "delaunay":
268
+ self._get_sts = self._get_sts_delaunay
269
+ _pref_output = "slicer"
270
+ case _:
271
+ raise ValueError(
272
+ f"Invalid complex \
273
+ {self.complex}. Possible choises are rips, delaunay, or alpha."
274
+ )
275
+ self._vineyard = (
276
+ False if self.output_type is None else "novine" not in self.output_type
277
+ )
278
+ self._output_type = (
279
+ _pref_output
280
+ if self.output_type is None
281
+ else (
282
+ "simplextree"
283
+ if (
284
+ self.output_type == "simplextree" or self.reduce_degrees is not None
285
+ )
286
+ else "slicer"
287
+ )
288
+ )
289
+
290
+ def _define_bandwidths(self, X):
291
+ qs = [q for q in [*-np.asarray(self.bandwidths)] if 0 <= q <= 1]
292
+ self._get_distance_quantiles_and_threshold(X, qs=qs)
293
+ self._bandwidths = np.array(self.bandwidths)
294
+ count = 0
295
+ for i in range(len(self._bandwidths)):
296
+ if self.bandwidths[i] < 0:
297
+ self._bandwidths[i] = self._scale[count]
298
+ count += 1
299
+
300
+ def fit(self, X: np.ndarray | list, y=None):
301
+ # self.bandwidth = "silverman" ## not good, as is can make bandwidth not constant
302
+ self._define_sts()
303
+ self._define_bandwidths(X)
304
+ # PRECOMPILE FIRST
305
+ self._get_codensities(X[0][:2], X[0][:2])
306
+ return self
307
+
308
def transform(self, X):
    """Build the filtered complexes for every point cloud in ``X``.

    After a small warm-up codensity evaluation, ``self._get_sts`` is applied
    to each element of ``X`` in a joblib threading pool, with an optional
    tqdm progress bar.

    Args:
        X: Sized collection of point clouds; ``X[0]`` must exist.

    Returns:
        The list returned by ``Parallel``: one ``self._get_sts`` result per
        element of ``X``.
    """
    # Warm-up on two points only (the original labels it "precompile
    # first"; a fuller `self._get_sts(X[0][:5])` warm-up was left disabled).
    warmup = X[0][:2]
    self._get_codensities(warmup, warmup)
    progress_bar = tqdm(
        X, desc="Filling simplextrees", disable=not self.progress, total=len(X)
    )
    with progress_bar as data:
        pool = Parallel(backend="threading", n_jobs=self.n_jobs)
        stss = pool(delayed(self._get_sts)(x) for x in data)
    return stss
319
+
320
+
321
class PointCloud2SimplexTree(PointCloud2FilteredComplex):
    """Deprecated name for ``PointCloud2FilteredComplex``.

    Accepts the arguments listed in ``__init__``, forwards all of them
    unchanged to the parent constructor, and emits a warning telling the
    caller to use ``PointCloud2FilteredComplex`` instead.
    """

    def __init__(
        self,
        bandwidths=[],
        masses=[],
        threshold: float = np.inf,
        complex: Literal["alpha", "rips", "delaunay"] = "rips",
        sparse: float | None = None,
        num_collapses: int = -2,
        kernel: available_kernels = "gaussian",
        log_density: bool = True,
        expand_dim: int = 1,
        progress: bool = False,
        n_jobs: Optional[int] = None,
        fit_fraction: float = 1,
        verbose: bool = False,
        safe_conversion: bool = False,
        output_type: Optional[
            Literal["slicer", "simplextree", "slicer_vine", "slicer_novine"]
        ] = None,
        reduce_degrees: Optional[Iterable[int]] = None,
    ) -> None:
        # The snapshot must be the very first statement: any local bound
        # before it (including the `warn` import below) would be forwarded
        # to the parent constructor as an unexpected keyword argument.
        # NOTE(review): the `[]` defaults are shared list objects; they are
        # kept to leave the signature unchanged and are only forwarded here.
        snapshot = dict(locals())
        # Drop `self` and compiler-injected names such as `__class__`
        # (present because `super()` is used in this method).
        init_kwargs = {
            name: value
            for name, value in snapshot.items()
            if name != "self" and not name.startswith("__")
        }
        super().__init__(**init_kwargs)
        from warnings import warn

        # stacklevel=2 attributes the warning to the code instantiating the
        # class instead of this line inside the library.
        warn(
            "This class is deprecated, use PointCloud2FilteredComplex instead.",
            stacklevel=2,
        )