multipers 2.4.0b1__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. multipers/.dylibs/libboost_timer.dylib +0 -0
  2. multipers/.dylibs/libc++.1.0.dylib +0 -0
  3. multipers/.dylibs/libtbb.12.17.dylib +0 -0
  4. multipers/__init__.py +33 -0
  5. multipers/_signed_measure_meta.py +426 -0
  6. multipers/_slicer_meta.py +231 -0
  7. multipers/array_api/__init__.py +62 -0
  8. multipers/array_api/numpy.py +124 -0
  9. multipers/array_api/torch.py +133 -0
  10. multipers/data/MOL2.py +458 -0
  11. multipers/data/UCR.py +18 -0
  12. multipers/data/__init__.py +1 -0
  13. multipers/data/graphs.py +466 -0
  14. multipers/data/immuno_regions.py +27 -0
  15. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  16. multipers/data/pytorch2simplextree.py +91 -0
  17. multipers/data/shape3d.py +101 -0
  18. multipers/data/synthetic.py +113 -0
  19. multipers/distances.py +202 -0
  20. multipers/filtration_conversions.pxd +736 -0
  21. multipers/filtration_conversions.pxd.tp +226 -0
  22. multipers/filtrations/__init__.py +21 -0
  23. multipers/filtrations/density.py +529 -0
  24. multipers/filtrations/filtrations.py +480 -0
  25. multipers/filtrations.pxd +534 -0
  26. multipers/filtrations.pxd.tp +332 -0
  27. multipers/function_rips.cpython-312-darwin.so +0 -0
  28. multipers/function_rips.pyx +104 -0
  29. multipers/grids.cpython-312-darwin.so +0 -0
  30. multipers/grids.pyx +538 -0
  31. multipers/gudhi/Persistence_slices_interface.h +213 -0
  32. multipers/gudhi/Simplex_tree_interface.h +274 -0
  33. multipers/gudhi/Simplex_tree_multi_interface.h +648 -0
  34. multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -0
  35. multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -0
  36. multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -0
  37. multipers/gudhi/gudhi/Debug_utils.h +52 -0
  38. multipers/gudhi/gudhi/Degree_rips_bifiltration.h +2307 -0
  39. multipers/gudhi/gudhi/Dynamic_multi_parameter_filtration.h +2524 -0
  40. multipers/gudhi/gudhi/Fields/Multi_field.h +453 -0
  41. multipers/gudhi/gudhi/Fields/Multi_field_operators.h +460 -0
  42. multipers/gudhi/gudhi/Fields/Multi_field_shared.h +444 -0
  43. multipers/gudhi/gudhi/Fields/Multi_field_small.h +584 -0
  44. multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +490 -0
  45. multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +580 -0
  46. multipers/gudhi/gudhi/Fields/Z2_field.h +391 -0
  47. multipers/gudhi/gudhi/Fields/Z2_field_operators.h +389 -0
  48. multipers/gudhi/gudhi/Fields/Zp_field.h +493 -0
  49. multipers/gudhi/gudhi/Fields/Zp_field_operators.h +384 -0
  50. multipers/gudhi/gudhi/Fields/Zp_field_shared.h +492 -0
  51. multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -0
  52. multipers/gudhi/gudhi/Matrix.h +2200 -0
  53. multipers/gudhi/gudhi/Multi_filtration/Multi_parameter_generator.h +1712 -0
  54. multipers/gudhi/gudhi/Multi_filtration/multi_filtration_conversions.h +237 -0
  55. multipers/gudhi/gudhi/Multi_filtration/multi_filtration_utils.h +225 -0
  56. multipers/gudhi/gudhi/Multi_parameter_filtered_complex.h +485 -0
  57. multipers/gudhi/gudhi/Multi_parameter_filtration.h +2643 -0
  58. multipers/gudhi/gudhi/Multi_persistence/Box.h +233 -0
  59. multipers/gudhi/gudhi/Multi_persistence/Line.h +309 -0
  60. multipers/gudhi/gudhi/Multi_persistence/Multi_parameter_filtered_complex_pcoh_interface.h +268 -0
  61. multipers/gudhi/gudhi/Multi_persistence/Persistence_interface_cohomology.h +159 -0
  62. multipers/gudhi/gudhi/Multi_persistence/Persistence_interface_matrix.h +463 -0
  63. multipers/gudhi/gudhi/Multi_persistence/Point.h +853 -0
  64. multipers/gudhi/gudhi/Off_reader.h +173 -0
  65. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +834 -0
  66. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +838 -0
  67. multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +833 -0
  68. multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1367 -0
  69. multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1157 -0
  70. multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +869 -0
  71. multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +905 -0
  72. multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +122 -0
  73. multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +260 -0
  74. multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +288 -0
  75. multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +170 -0
  76. multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +247 -0
  77. multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +571 -0
  78. multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +182 -0
  79. multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +130 -0
  80. multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +235 -0
  81. multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +312 -0
  82. multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1092 -0
  83. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +923 -0
  84. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +914 -0
  85. multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +930 -0
  86. multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1071 -0
  87. multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +203 -0
  88. multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +886 -0
  89. multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +984 -0
  90. multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1213 -0
  91. multipers/gudhi/gudhi/Persistence_matrix/index_mapper.h +58 -0
  92. multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +227 -0
  93. multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +200 -0
  94. multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +166 -0
  95. multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +319 -0
  96. multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +562 -0
  97. multipers/gudhi/gudhi/Persistence_on_a_line.h +152 -0
  98. multipers/gudhi/gudhi/Persistence_on_rectangle.h +617 -0
  99. multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -0
  100. multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -0
  101. multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -0
  102. multipers/gudhi/gudhi/Persistent_cohomology.h +769 -0
  103. multipers/gudhi/gudhi/Points_off_io.h +171 -0
  104. multipers/gudhi/gudhi/Projective_cover_kernel.h +379 -0
  105. multipers/gudhi/gudhi/Simple_object_pool.h +69 -0
  106. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +559 -0
  107. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -0
  108. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +121 -0
  109. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -0
  110. multipers/gudhi/gudhi/Simplex_tree/filtration_value_utils.h +155 -0
  111. multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -0
  112. multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -0
  113. multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +60 -0
  114. multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +105 -0
  115. multipers/gudhi/gudhi/Simplex_tree.h +3170 -0
  116. multipers/gudhi/gudhi/Slicer.h +848 -0
  117. multipers/gudhi/gudhi/Thread_safe_slicer.h +393 -0
  118. multipers/gudhi/gudhi/distance_functions.h +62 -0
  119. multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -0
  120. multipers/gudhi/gudhi/multi_simplex_tree_helpers.h +147 -0
  121. multipers/gudhi/gudhi/persistence_interval.h +263 -0
  122. multipers/gudhi/gudhi/persistence_matrix_options.h +188 -0
  123. multipers/gudhi/gudhi/reader_utils.h +367 -0
  124. multipers/gudhi/gudhi/simple_mdspan.h +484 -0
  125. multipers/gudhi/gudhi/slicer_helpers.h +779 -0
  126. multipers/gudhi/tmp_h0_pers/mma_interface_h0.h +223 -0
  127. multipers/gudhi/tmp_h0_pers/naive_merge_tree.h +536 -0
  128. multipers/io.cpython-312-darwin.so +0 -0
  129. multipers/io.pyx +472 -0
  130. multipers/ml/__init__.py +0 -0
  131. multipers/ml/accuracies.py +90 -0
  132. multipers/ml/invariants_with_persistable.py +79 -0
  133. multipers/ml/kernels.py +176 -0
  134. multipers/ml/mma.py +713 -0
  135. multipers/ml/one.py +472 -0
  136. multipers/ml/point_clouds.py +352 -0
  137. multipers/ml/signed_measures.py +1667 -0
  138. multipers/ml/sliced_wasserstein.py +461 -0
  139. multipers/ml/tools.py +113 -0
  140. multipers/mma_structures.cpython-312-darwin.so +0 -0
  141. multipers/mma_structures.pxd +134 -0
  142. multipers/mma_structures.pyx +1483 -0
  143. multipers/mma_structures.pyx.tp +1126 -0
  144. multipers/multi_parameter_rank_invariant/diff_helpers.h +85 -0
  145. multipers/multi_parameter_rank_invariant/euler_characteristic.h +95 -0
  146. multipers/multi_parameter_rank_invariant/function_rips.h +317 -0
  147. multipers/multi_parameter_rank_invariant/hilbert_function.h +761 -0
  148. multipers/multi_parameter_rank_invariant/persistence_slices.h +149 -0
  149. multipers/multi_parameter_rank_invariant/rank_invariant.h +350 -0
  150. multipers/multiparameter_edge_collapse.py +41 -0
  151. multipers/multiparameter_module_approximation/approximation.h +2541 -0
  152. multipers/multiparameter_module_approximation/debug.h +107 -0
  153. multipers/multiparameter_module_approximation/format_python-cpp.h +292 -0
  154. multipers/multiparameter_module_approximation/utilities.h +428 -0
  155. multipers/multiparameter_module_approximation.cpython-312-darwin.so +0 -0
  156. multipers/multiparameter_module_approximation.pyx +286 -0
  157. multipers/ops.cpython-312-darwin.so +0 -0
  158. multipers/ops.pyx +231 -0
  159. multipers/pickle.py +89 -0
  160. multipers/plots.py +550 -0
  161. multipers/point_measure.cpython-312-darwin.so +0 -0
  162. multipers/point_measure.pyx +409 -0
  163. multipers/simplex_tree_multi.cpython-312-darwin.so +0 -0
  164. multipers/simplex_tree_multi.pxd +136 -0
  165. multipers/simplex_tree_multi.pyx +11719 -0
  166. multipers/simplex_tree_multi.pyx.tp +2102 -0
  167. multipers/slicer.cpython-312-darwin.so +0 -0
  168. multipers/slicer.pxd +2097 -0
  169. multipers/slicer.pxd.tp +263 -0
  170. multipers/slicer.pyx +13042 -0
  171. multipers/slicer.pyx.tp +1259 -0
  172. multipers/tensor/tensor.h +672 -0
  173. multipers/tensor.pxd +13 -0
  174. multipers/test.pyx +44 -0
  175. multipers/tests/__init__.py +70 -0
  176. multipers/torch/__init__.py +1 -0
  177. multipers/torch/diff_grids.py +240 -0
  178. multipers/torch/rips_density.py +310 -0
  179. multipers/vector_interface.pxd +46 -0
  180. multipers-2.4.0b1.dist-info/METADATA +131 -0
  181. multipers-2.4.0b1.dist-info/RECORD +184 -0
  182. multipers-2.4.0b1.dist-info/WHEEL +6 -0
  183. multipers-2.4.0b1.dist-info/licenses/LICENSE +21 -0
  184. multipers-2.4.0b1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,352 @@
1
+ from collections.abc import Iterable
2
+ from typing import Literal, Optional
3
+
4
+ import gudhi as gd
5
+ import numpy as np
6
+ from joblib import Parallel, delayed
7
+ from sklearn.base import BaseEstimator, TransformerMixin
8
+ from scipy.spatial.distance import cdist
9
+ from tqdm import tqdm
10
+
11
+ import multipers as mp
12
+ import multipers.slicer as mps
13
+ from multipers.filtrations.density import DTM, KDE, available_kernels
14
+
15
+
16
+ class PointCloud2FilteredComplex(BaseEstimator, TransformerMixin):
17
+ def __init__(
18
+ self,
19
+ bandwidths=[],
20
+ masses=[],
21
+ threshold: float = -np.inf,
22
+ complex: Literal["alpha", "rips", "delaunay"] = "rips",
23
+ sparse: Optional[float] = None,
24
+ num_collapses: int = -2,
25
+ kernel: available_kernels = "gaussian",
26
+ log_density: bool = True,
27
+ expand_dim: int = 1,
28
+ progress: bool = False,
29
+ n_jobs: Optional[int] = None,
30
+ fit_fraction: float = 1,
31
+ verbose: bool = False,
32
+ safe_conversion: bool = False,
33
+ output_type: Optional[
34
+ Literal["slicer", "simplextree", "slicer_vine", "slicer_novine"]
35
+ ] = None,
36
+ reduce_degrees: Optional[Iterable[int]] = None,
37
+ ) -> None:
38
+ """
39
+ (Rips or Alpha or Delaunay) + (Density Estimation or DTM) 1-critical 2-filtration.
40
+
41
+ Parameters
42
+ ----------
43
+ - bandwidth : real : The kernel density estimation bandwidth, or the DTM mass. If negative, it replaced by abs(bandwidth)*(radius of the dataset)
44
+ - threshold : real, max edge lenfth of the rips or max alpha square of the alpha
45
+ - sparse : real, sparse rips (c.f. rips doc) WARNING : ONLY FOR RIPS
46
+ - num_collapse : int, Number of edge collapses applied to the simplextrees, WARNING : ONLY FOR RIPS
47
+ - expand_dim : int, expand the rips complex to this dimension. WARNING : ONLY FOR RIPS
48
+ - kernel : the kernel used for density estimation. Available ones are, e.g., "dtm", "gaussian", "exponential".
49
+ - progress : bool, shows the calculus status
50
+ - n_jobs : number of processes
51
+ - fit_fraction : real, the fraction of data on which to fit
52
+ - verbose : bool, Shows more information if true.
53
+
54
+ Output
55
+ ------
56
+ A list of SimplexTreeMulti whose first parameter is a rips and the second is the codensity.
57
+ """
58
+ super().__init__()
59
+ self.bandwidths = bandwidths
60
+ self.masses = masses
61
+ self.num_collapses = num_collapses
62
+ self.kernel = kernel
63
+ self.log_density = log_density
64
+ self.progress = progress
65
+ self._bandwidths = np.empty((0,))
66
+ self._threshold = np.inf
67
+ self.n_jobs = n_jobs
68
+ self._scale = np.empty((0,))
69
+ self.fit_fraction = fit_fraction
70
+ self.expand_dim = expand_dim
71
+ self.verbose = verbose
72
+ self.complex = complex
73
+ self.threshold = threshold
74
+ self.sparse = sparse
75
+ self._get_sts = lambda: Exception("Fit first")
76
+ self.safe_conversion = safe_conversion
77
+ self.output_type = output_type
78
+ self._output_type = None
79
+ self.reduce_degrees = reduce_degrees
80
+ self._vineyard = None
81
+
82
+ assert (
83
+ output_type != "simplextree" or reduce_degrees is None
84
+ ), "Reduced complex are not simplicial. Cannot return a simplextree."
85
+ return
86
+
87
+ def _get_distance_quantiles_and_threshold(self, X, qs):
88
+ ## if we dont need to compute a distance matrix
89
+ if len(qs) == 0 and self.threshold >= 0:
90
+ self._scale = []
91
+ return []
92
+ if self.progress:
93
+ print("Estimating scale...", flush=True, end="")
94
+ ## subsampling
95
+ indices = np.random.choice(
96
+ len(X), min(len(X), int(self.fit_fraction * len(X)) + 1), replace=False
97
+ )
98
+
99
+ def compute_max_scale(x):
100
+ from pykeops.numpy import LazyTensor
101
+
102
+ a = LazyTensor(x[None, :, :])
103
+ b = LazyTensor(x[:, None, :])
104
+ return np.sqrt(((a - b) ** 2).sum(2).max(1).min(0)[0])
105
+
106
+ diameter = np.max([compute_max_scale(x) for x in (X[i] for i in indices)])
107
+ self._scale = diameter * np.array(qs)
108
+
109
+ if self.threshold == -np.inf:
110
+ self._threshold = diameter
111
+ elif self.threshold > 0:
112
+ self._threshold = self.threshold
113
+ else:
114
+ self._threshold = -diameter * self.threshold
115
+
116
+ if self.threshold > 0:
117
+ self._scale[self._scale > self.threshold] = self.threshold
118
+
119
+ if self.progress:
120
+ print(f"Done. Chosen scales {qs} are {self._scale}", flush=True)
121
+ return self._scale
122
+
123
+ def _get_sts_rips(self, x):
124
+ assert self._output_type is not None and self._vineyard is not None
125
+ if self.sparse is None:
126
+ st_init = gd.SimplexTree.create_from_array(
127
+ cdist(x,x), max_filtration=self._threshold
128
+ )
129
+ else:
130
+ st_init = gd.RipsComplex(
131
+ points=x, max_edge_length=self._threshold, sparse=self.sparse
132
+ ).create_simplex_tree(max_dimension=1)
133
+ st_init = mp.simplex_tree_multi.SimplexTreeMulti(
134
+ st_init, num_parameters=2, safe_conversion=self.safe_conversion
135
+ )
136
+ codensities = self._get_codensities(x_fit=x, x_sample=x)
137
+ num_axes = codensities.shape[0]
138
+ sts = [st_init] + [st_init.copy() for _ in range(num_axes - 1)]
139
+ # no need to multithread here, most operations are memory
140
+ for codensity, st_copy in zip(codensities, sts):
141
+ # RIPS has contigus vertices, so vertices are ordered.
142
+ st_copy.fill_lowerstar(codensity, parameter=1)
143
+
144
+ def reduce(st):
145
+ if self.verbose:
146
+ print("Num simplices :", st.num_simplices)
147
+ if isinstance(self.num_collapses, int):
148
+ st.collapse_edges(num=self.num_collapses)
149
+ if self.verbose:
150
+ print(", after collapse :", st.num_simplices, end="")
151
+ elif self.num_collapses == "full":
152
+ st.collapse_edges(full=True)
153
+ if self.verbose:
154
+ print(", after collapse :", st.num_simplices, end="")
155
+ if self.expand_dim > 1:
156
+ st.expansion(self.expand_dim)
157
+ if self.verbose:
158
+ print(", after expansion :", st.num_simplices, end="")
159
+ if self.verbose:
160
+ print("")
161
+ if self._output_type == "slicer":
162
+ st = mp.Slicer(st, vineyard=self._vineyard)
163
+ if self.reduce_degrees is not None:
164
+ st = mp.slicer.minimal_presentation(
165
+ st, degrees=self.reduce_degrees, vineyard=self._vineyard
166
+ )
167
+ return st
168
+
169
+ return Parallel(backend="threading", n_jobs=self.n_jobs)(
170
+ delayed(reduce)(st) for st in sts
171
+ )
172
+
173
+ def _get_sts_alpha(self, x: np.ndarray, return_alpha=False):
174
+ assert self._output_type is not None and self._vineyard is not None
175
+ alpha_complex = gd.AlphaComplex(points=x)
176
+ st = alpha_complex.create_simplex_tree(max_alpha_square=self._threshold**2)
177
+ vertices = np.array([i for (i,), _ in st.get_skeleton(0)])
178
+ new_points = np.asarray(
179
+ [alpha_complex.get_point(int(i)) for i in vertices]
180
+ ) # Seems to be unsafe for some reason
181
+ # new_points = x
182
+ st = mp.simplex_tree_multi.SimplexTreeMulti(
183
+ st, num_parameters=2, safe_conversion=self.safe_conversion
184
+ )
185
+ codensities = self._get_codensities(x_fit=x, x_sample=new_points)
186
+ num_axes = codensities.shape[0]
187
+ sts = [st] + [st.copy() for _ in range(num_axes - 1)]
188
+ # no need to multithread here, most operations are memory
189
+ max_vertices = vertices.max() + 2 # +1 to be safe
190
+ for codensity, st_copy in zip(codensities, sts):
191
+ alligned_codensity = np.array([np.nan] * max_vertices)
192
+ alligned_codensity[vertices] = codensity
193
+ # alligned_codensity = np.array([codensity[i] if i in vertices else np.nan for i in range(max_vertices)])
194
+ st_copy.fill_lowerstar(alligned_codensity, parameter=1)
195
+ if "slicer" in self._output_type:
196
+ sts2 = (mp.Slicer(st, vineyard=self._vineyard) for st in sts)
197
+ if self.reduce_degrees is not None:
198
+ sts = tuple(
199
+ mp.slicer.minimal_presentation(
200
+ s, degrees=self.reduce_degrees, vineyard=self._vineyard
201
+ )
202
+ for s in sts2
203
+ )
204
+ else:
205
+ sts = tuple(sts2)
206
+ if return_alpha:
207
+ return alpha_complex, sts
208
+ return sts
209
+
210
+ def _get_sts_delaunay(self, x: np.ndarray):
211
+ codensities = self._get_codensities(x_fit=x, x_sample=x)
212
+
213
+ def get_st(c):
214
+ slicer = mps.from_function_delaunay(
215
+ x,
216
+ c,
217
+ verbose=self.verbose,
218
+ clear=not self.verbose,
219
+ vineyard=self._vineyard,
220
+ )
221
+ if self._output_type == "simplextree":
222
+ slicer = mps.to_simplextree(slicer)
223
+ elif self.reduce_degrees is not None:
224
+ slicer = mp.slicer.minimal_presentation(
225
+ slicer,
226
+ degrees=self.reduce_degrees,
227
+ vineyard=self._vineyard,
228
+ )
229
+ else:
230
+ slicer = slicer
231
+ return slicer
232
+
233
+ sts = Parallel(backend="threading", n_jobs=self.n_jobs)(
234
+ delayed(get_st)(c) for c in codensities
235
+ )
236
+ return sts
237
+
238
+ def _get_codensities(self, x_fit, x_sample):
239
+ x_fit = np.asarray(x_fit, dtype=np.float64)
240
+ x_sample = np.asarray(x_sample, dtype=np.float64)
241
+ codensities_kde = np.asarray(
242
+ [
243
+ -KDE(
244
+ bandwidth=bandwidth, kernel=self.kernel, return_log=self.log_density
245
+ )
246
+ .fit(x_fit)
247
+ .score_samples(x_sample)
248
+ for bandwidth in self._bandwidths
249
+ ],
250
+ ).reshape(len(self._bandwidths), len(x_sample))
251
+ codensities_dtm = (
252
+ DTM(masses=self.masses)
253
+ .fit(x_fit)
254
+ .score_samples(x_sample)
255
+ .reshape(len(self.masses), len(x_sample))
256
+ )
257
+ return np.concatenate([codensities_kde, codensities_dtm])
258
+
259
+ def _define_sts(self):
260
+ match self.complex:
261
+ case "rips":
262
+ self._get_sts = self._get_sts_rips
263
+ _pref_output = "simplextree"
264
+ case "alpha":
265
+ self._get_sts = self._get_sts_alpha
266
+ _pref_output = "simplextree"
267
+ case "delaunay":
268
+ self._get_sts = self._get_sts_delaunay
269
+ _pref_output = "slicer"
270
+ case _:
271
+ raise ValueError(
272
+ f"Invalid complex \
273
+ {self.complex}. Possible choises are rips, delaunay, or alpha."
274
+ )
275
+ self._vineyard = (
276
+ False if self.output_type is None else "novine" not in self.output_type
277
+ )
278
+ self._output_type = (
279
+ _pref_output
280
+ if self.output_type is None
281
+ else (
282
+ "simplextree"
283
+ if (
284
+ self.output_type == "simplextree" or self.reduce_degrees is not None
285
+ )
286
+ else "slicer"
287
+ )
288
+ )
289
+
290
+ def _define_bandwidths(self, X):
291
+ qs = [q for q in [*-np.asarray(self.bandwidths)] if 0 <= q <= 1]
292
+ self._get_distance_quantiles_and_threshold(X, qs=qs)
293
+ self._bandwidths = np.array(self.bandwidths)
294
+ count = 0
295
+ for i in range(len(self._bandwidths)):
296
+ if self.bandwidths[i] < 0:
297
+ self._bandwidths[i] = self._scale[count]
298
+ count += 1
299
+
300
+ def fit(self, X: np.ndarray | list, y=None):
301
+ # self.bandwidth = "silverman" ## not good, as is can make bandwidth not constant
302
+ self._define_sts()
303
+ self._define_bandwidths(X)
304
+ # PRECOMPILE FIRST
305
+ self._get_codensities(X[0][:2], X[0][:2])
306
+ return self
307
+
308
+ def transform(self, X):
309
+ # precompile first
310
+ # self._get_sts(X[0][:5])
311
+ self._get_codensities(X[0][:2], X[0][:2])
312
+ with tqdm(
313
+ X, desc="Filling simplextrees", disable=not self.progress, total=len(X)
314
+ ) as data:
315
+ stss = Parallel(backend="threading", n_jobs=self.n_jobs)(
316
+ delayed(self._get_sts)(x) for x in data
317
+ )
318
+ return stss
319
+
320
+
321
class PointCloud2SimplexTree(PointCloud2FilteredComplex):
    """
    Deprecated alias of PointCloud2FilteredComplex.

    The only difference from the parent is the default ``threshold=np.inf``.
    Instantiating it emits a deprecation warning.
    """

    def __init__(
        self,
        bandwidths=[],
        masses=[],
        threshold: float = np.inf,
        complex: Literal["alpha", "rips", "delaunay"] = "rips",
        sparse: float | None = None,
        num_collapses: int = -2,
        kernel: available_kernels = "gaussian",
        log_density: bool = True,
        expand_dim: int = 1,
        progress: bool = False,
        n_jobs: Optional[int] = None,
        fit_fraction: float = 1,
        verbose: bool = False,
        safe_conversion: bool = False,
        output_type: Optional[
            Literal["slicer", "simplextree", "slicer_vine", "slicer_novine"]
        ] = None,
        reduce_degrees: Optional[Iterable[int]] = None,
    ) -> None:
        # Forward every argument explicitly. The previous `locals()` trick
        # depended on filtering out the implicit `__class__` cell.
        super().__init__(
            bandwidths=bandwidths,
            masses=masses,
            threshold=threshold,
            complex=complex,
            sparse=sparse,
            num_collapses=num_collapses,
            kernel=kernel,
            log_density=log_density,
            expand_dim=expand_dim,
            progress=progress,
            n_jobs=n_jobs,
            fit_fraction=fit_fraction,
            verbose=verbose,
            safe_conversion=safe_conversion,
            output_type=output_type,
            reduce_degrees=reduce_degrees,
        )
        from warnings import warn

        # stacklevel=2 attributes the warning to the caller's instantiation line.
        warn(
            "This class is deprecated, use PointCloud2FilteredComplex instead.",
            stacklevel=2,
        )