multipers 2.2.3__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic. Click here for more details.

Files changed (189) hide show
  1. multipers/__init__.py +31 -0
  2. multipers/_signed_measure_meta.py +430 -0
  3. multipers/_slicer_meta.py +212 -0
  4. multipers/data/MOL2.py +458 -0
  5. multipers/data/UCR.py +18 -0
  6. multipers/data/__init__.py +1 -0
  7. multipers/data/graphs.py +466 -0
  8. multipers/data/immuno_regions.py +27 -0
  9. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  10. multipers/data/pytorch2simplextree.py +91 -0
  11. multipers/data/shape3d.py +101 -0
  12. multipers/data/synthetic.py +111 -0
  13. multipers/distances.py +198 -0
  14. multipers/filtration_conversions.pxd +229 -0
  15. multipers/filtration_conversions.pxd.tp +84 -0
  16. multipers/filtrations.pxd +224 -0
  17. multipers/function_rips.cp310-win_amd64.pyd +0 -0
  18. multipers/function_rips.pyx +105 -0
  19. multipers/grids.cp310-win_amd64.pyd +0 -0
  20. multipers/grids.pyx +350 -0
  21. multipers/gudhi/Persistence_slices_interface.h +132 -0
  22. multipers/gudhi/Simplex_tree_interface.h +245 -0
  23. multipers/gudhi/Simplex_tree_multi_interface.h +561 -0
  24. multipers/gudhi/cubical_to_boundary.h +59 -0
  25. multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -0
  26. multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -0
  27. multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -0
  28. multipers/gudhi/gudhi/Debug_utils.h +45 -0
  29. multipers/gudhi/gudhi/Fields/Multi_field.h +484 -0
  30. multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -0
  31. multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -0
  32. multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -0
  33. multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -0
  34. multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -0
  35. multipers/gudhi/gudhi/Fields/Z2_field.h +355 -0
  36. multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -0
  37. multipers/gudhi/gudhi/Fields/Zp_field.h +420 -0
  38. multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -0
  39. multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -0
  40. multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -0
  41. multipers/gudhi/gudhi/Matrix.h +2107 -0
  42. multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -0
  43. multipers/gudhi/gudhi/Multi_persistence/Box.h +171 -0
  44. multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -0
  45. multipers/gudhi/gudhi/Off_reader.h +173 -0
  46. multipers/gudhi/gudhi/One_critical_filtration.h +1431 -0
  47. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -0
  48. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -0
  49. multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -0
  50. multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -0
  51. multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -0
  52. multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -0
  53. multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -0
  54. multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -0
  55. multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -0
  56. multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -0
  57. multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -0
  58. multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -0
  59. multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -0
  60. multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -0
  61. multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -0
  62. multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -0
  63. multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -0
  64. multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -0
  65. multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -0
  66. multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -0
  67. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -0
  68. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -0
  69. multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -0
  70. multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -0
  71. multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -0
  72. multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -0
  73. multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -0
  74. multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -0
  75. multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -0
  76. multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -0
  77. multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -0
  78. multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -0
  79. multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -0
  80. multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -0
  81. multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -0
  82. multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -0
  83. multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -0
  84. multipers/gudhi/gudhi/Persistent_cohomology.h +745 -0
  85. multipers/gudhi/gudhi/Points_off_io.h +171 -0
  86. multipers/gudhi/gudhi/Simple_object_pool.h +69 -0
  87. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -0
  88. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -0
  89. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -0
  90. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -0
  91. multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -0
  92. multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -0
  93. multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -0
  94. multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -0
  95. multipers/gudhi/gudhi/Simplex_tree.h +2794 -0
  96. multipers/gudhi/gudhi/Simplex_tree_multi.h +163 -0
  97. multipers/gudhi/gudhi/distance_functions.h +62 -0
  98. multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -0
  99. multipers/gudhi/gudhi/persistence_interval.h +253 -0
  100. multipers/gudhi/gudhi/persistence_matrix_options.h +170 -0
  101. multipers/gudhi/gudhi/reader_utils.h +367 -0
  102. multipers/gudhi/mma_interface_coh.h +255 -0
  103. multipers/gudhi/mma_interface_h0.h +231 -0
  104. multipers/gudhi/mma_interface_matrix.h +282 -0
  105. multipers/gudhi/naive_merge_tree.h +575 -0
  106. multipers/gudhi/scc_io.h +289 -0
  107. multipers/gudhi/truc.h +888 -0
  108. multipers/io.cp310-win_amd64.pyd +0 -0
  109. multipers/io.pyx +711 -0
  110. multipers/ml/__init__.py +0 -0
  111. multipers/ml/accuracies.py +90 -0
  112. multipers/ml/convolutions.py +520 -0
  113. multipers/ml/invariants_with_persistable.py +79 -0
  114. multipers/ml/kernels.py +176 -0
  115. multipers/ml/mma.py +714 -0
  116. multipers/ml/one.py +472 -0
  117. multipers/ml/point_clouds.py +346 -0
  118. multipers/ml/signed_measures.py +1589 -0
  119. multipers/ml/sliced_wasserstein.py +461 -0
  120. multipers/ml/tools.py +113 -0
  121. multipers/mma_structures.cp310-win_amd64.pyd +0 -0
  122. multipers/mma_structures.pxd +127 -0
  123. multipers/mma_structures.pyx +2746 -0
  124. multipers/mma_structures.pyx.tp +1085 -0
  125. multipers/multi_parameter_rank_invariant/diff_helpers.h +93 -0
  126. multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -0
  127. multipers/multi_parameter_rank_invariant/function_rips.h +322 -0
  128. multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -0
  129. multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -0
  130. multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -0
  131. multipers/multiparameter_edge_collapse.py +41 -0
  132. multipers/multiparameter_module_approximation/approximation.h +2295 -0
  133. multipers/multiparameter_module_approximation/combinatory.h +129 -0
  134. multipers/multiparameter_module_approximation/debug.h +107 -0
  135. multipers/multiparameter_module_approximation/euler_curves.h +0 -0
  136. multipers/multiparameter_module_approximation/format_python-cpp.h +286 -0
  137. multipers/multiparameter_module_approximation/heap_column.h +238 -0
  138. multipers/multiparameter_module_approximation/images.h +79 -0
  139. multipers/multiparameter_module_approximation/list_column.h +174 -0
  140. multipers/multiparameter_module_approximation/list_column_2.h +232 -0
  141. multipers/multiparameter_module_approximation/ru_matrix.h +347 -0
  142. multipers/multiparameter_module_approximation/set_column.h +135 -0
  143. multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -0
  144. multipers/multiparameter_module_approximation/unordered_set_column.h +166 -0
  145. multipers/multiparameter_module_approximation/utilities.h +419 -0
  146. multipers/multiparameter_module_approximation/vector_column.h +223 -0
  147. multipers/multiparameter_module_approximation/vector_matrix.h +331 -0
  148. multipers/multiparameter_module_approximation/vineyards.h +464 -0
  149. multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -0
  150. multipers/multiparameter_module_approximation.cp310-win_amd64.pyd +0 -0
  151. multipers/multiparameter_module_approximation.pyx +217 -0
  152. multipers/pickle.py +53 -0
  153. multipers/plots.py +334 -0
  154. multipers/point_measure.cp310-win_amd64.pyd +0 -0
  155. multipers/point_measure.pyx +320 -0
  156. multipers/simplex_tree_multi.cp310-win_amd64.pyd +0 -0
  157. multipers/simplex_tree_multi.pxd +133 -0
  158. multipers/simplex_tree_multi.pyx +10335 -0
  159. multipers/simplex_tree_multi.pyx.tp +1935 -0
  160. multipers/slicer.cp310-win_amd64.pyd +0 -0
  161. multipers/slicer.pxd +2371 -0
  162. multipers/slicer.pxd.tp +214 -0
  163. multipers/slicer.pyx +15467 -0
  164. multipers/slicer.pyx.tp +914 -0
  165. multipers/tbb12.dll +0 -0
  166. multipers/tbbbind_2_5.dll +0 -0
  167. multipers/tbbmalloc.dll +0 -0
  168. multipers/tbbmalloc_proxy.dll +0 -0
  169. multipers/tensor/tensor.h +672 -0
  170. multipers/tensor.pxd +13 -0
  171. multipers/test.pyx +44 -0
  172. multipers/tests/__init__.py +57 -0
  173. multipers/tests/test_diff_helper.py +73 -0
  174. multipers/tests/test_hilbert_function.py +82 -0
  175. multipers/tests/test_mma.py +83 -0
  176. multipers/tests/test_point_clouds.py +49 -0
  177. multipers/tests/test_python-cpp_conversion.py +82 -0
  178. multipers/tests/test_signed_betti.py +181 -0
  179. multipers/tests/test_signed_measure.py +89 -0
  180. multipers/tests/test_simplextreemulti.py +221 -0
  181. multipers/tests/test_slicer.py +221 -0
  182. multipers/torch/__init__.py +1 -0
  183. multipers/torch/diff_grids.py +217 -0
  184. multipers/torch/rips_density.py +304 -0
  185. multipers-2.2.3.dist-info/LICENSE +21 -0
  186. multipers-2.2.3.dist-info/METADATA +134 -0
  187. multipers-2.2.3.dist-info/RECORD +189 -0
  188. multipers-2.2.3.dist-info/WHEEL +5 -0
  189. multipers-2.2.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,346 @@
1
+ from collections.abc import Callable, Iterable
2
+ from typing import Literal, Optional
3
+
4
+ import gudhi as gd
5
+ import numpy as np
6
+ from joblib import Parallel, delayed
7
+ from sklearn.base import BaseEstimator, TransformerMixin
8
+ from tqdm import tqdm
9
+
10
+ import multipers as mp
11
+ import multipers.slicer as mps
12
+ from multipers.ml.convolutions import DTM, KDE, available_kernels
13
+
14
+
15
+ class PointCloud2FilteredComplex(BaseEstimator, TransformerMixin):
16
+ def __init__(
17
+ self,
18
+ bandwidths=[],
19
+ masses=[],
20
+ threshold: float = -np.inf,
21
+ complex: Literal["alpha", "rips", "delaunay"] = "rips",
22
+ sparse: Optional[float] = None,
23
+ num_collapses: int = -2,
24
+ kernel: available_kernels = "gaussian",
25
+ log_density: bool = True,
26
+ expand_dim: int = 1,
27
+ progress: bool = False,
28
+ n_jobs: Optional[int] = None,
29
+ fit_fraction: float = 1,
30
+ verbose: bool = False,
31
+ safe_conversion: bool = False,
32
+ output_type: Optional[
33
+ Literal["slicer", "simplextree", "slicer_vine", "slicer_novine"]
34
+ ] = None,
35
+ reduce_degrees: Optional[Iterable[int]] = None,
36
+ ) -> None:
37
+ """
38
+ (Rips or Alpha or Delaunay) + (Density Estimation or DTM) 1-critical 2-filtration.
39
+
40
+ Parameters
41
+ ----------
42
+ - bandwidth : real : The kernel density estimation bandwidth, or the DTM mass. If negative, it replaced by abs(bandwidth)*(radius of the dataset)
43
+ - threshold : real, max edge lenfth of the rips or max alpha square of the alpha
44
+ - sparse : real, sparse rips (c.f. rips doc) WARNING : ONLY FOR RIPS
45
+ - num_collapse : int, Number of edge collapses applied to the simplextrees, WARNING : ONLY FOR RIPS
46
+ - expand_dim : int, expand the rips complex to this dimension. WARNING : ONLY FOR RIPS
47
+ - kernel : the kernel used for density estimation. Available ones are, e.g., "dtm", "gaussian", "exponential".
48
+ - progress : bool, shows the calculus status
49
+ - n_jobs : number of processes
50
+ - fit_fraction : real, the fraction of data on which to fit
51
+ - verbose : bool, Shows more information if true.
52
+
53
+ Output
54
+ ------
55
+ A list of SimplexTreeMulti whose first parameter is a rips and the second is the codensity.
56
+ """
57
+ super().__init__()
58
+ self.bandwidths = bandwidths
59
+ self.masses = masses
60
+ self.num_collapses = num_collapses
61
+ self.kernel = kernel
62
+ self.log_density = log_density
63
+ self.progress = progress
64
+ self._bandwidths = np.empty((0,))
65
+ self._threshold = np.inf
66
+ self.n_jobs = n_jobs
67
+ self._scale = np.empty((0,))
68
+ self.fit_fraction = fit_fraction
69
+ self.expand_dim = expand_dim
70
+ self.verbose = verbose
71
+ self.complex = complex
72
+ self.threshold = threshold
73
+ self.sparse = sparse
74
+ self._get_sts = lambda: Exception("Fit first")
75
+ self.safe_conversion = safe_conversion
76
+ self.output_type = output_type
77
+ self._output_type = None
78
+ self.reduce_degrees = reduce_degrees
79
+ self._vineyard = None
80
+
81
+ assert (
82
+ output_type != "simplextree" or reduce_degrees is None
83
+ ), "Reduced complex are not simplicial. Cannot return a simplextree."
84
+ return
85
+
86
+ def _get_distance_quantiles_and_threshold(self, X, qs):
87
+ ## if we dont need to compute a distance matrix
88
+ if len(qs) == 0 and self.threshold >= 0:
89
+ self._scale = []
90
+ return []
91
+ if self.progress:
92
+ print("Estimating scale...", flush=True, end="")
93
+ ## subsampling
94
+ indices = np.random.choice(
95
+ len(X), min(len(X), int(self.fit_fraction * len(X)) + 1), replace=False
96
+ )
97
+
98
+ def compute_max_scale(x):
99
+ from pykeops.numpy import LazyTensor
100
+
101
+ a = LazyTensor(x[None, :, :])
102
+ b = LazyTensor(x[:, None, :])
103
+ return np.sqrt(((a - b) ** 2).sum(2).max(1).min(0)[0])
104
+
105
+ diameter = np.max([compute_max_scale(x) for x in (X[i] for i in indices)])
106
+ self._scale = diameter * np.array(qs)
107
+
108
+ if self.threshold == -np.inf:
109
+ self._threshold = diameter
110
+ elif self.threshold > 0:
111
+ self._threshold = self.threshold
112
+ else:
113
+ self._threshold = -diameter * self.threshold
114
+
115
+ if self.threshold > 0:
116
+ self._scale[self._scale > self.threshold] = self.threshold
117
+
118
+ if self.progress:
119
+ print(f"Done. Chosen scales {qs} are {self._scale}", flush=True)
120
+ return self._scale
121
+
122
+ def _get_sts_rips(self, x):
123
+ assert self._output_type is not None and self._vineyard is not None
124
+ st_init = gd.RipsComplex(
125
+ points=x, max_edge_length=self._threshold, sparse=self.sparse
126
+ ).create_simplex_tree(max_dimension=1)
127
+ st_init = mp.simplex_tree_multi.SimplexTreeMulti(
128
+ st_init, num_parameters=2, safe_conversion=self.safe_conversion
129
+ )
130
+ codensities = self._get_codensities(x_fit=x, x_sample=x)
131
+ num_axes = codensities.shape[0]
132
+ sts = [st_init] + [st_init.copy() for _ in range(num_axes - 1)]
133
+ # no need to multithread here, most operations are memory
134
+ for codensity, st_copy in zip(codensities, sts):
135
+ # RIPS has contigus vertices, so vertices are ordered.
136
+ st_copy.fill_lowerstar(codensity, parameter=1)
137
+
138
+ def reduce(st):
139
+ if self.verbose:
140
+ print("Num simplices :", st.num_simplices)
141
+ if isinstance(self.num_collapses, int):
142
+ st.collapse_edges(num=self.num_collapses)
143
+ if self.verbose:
144
+ print(", after collapse :", st.num_simplices, end="")
145
+ elif self.num_collapses == "full":
146
+ st.collapse_edges(full=True)
147
+ if self.verbose:
148
+ print(", after collapse :", st.num_simplices, end="")
149
+ if self.expand_dim > 1:
150
+ st.expansion(self.expand_dim)
151
+ if self.verbose:
152
+ print(", after expansion :", st.num_simplices, end="")
153
+ if self.verbose:
154
+ print("")
155
+ if self._output_type == "slicer":
156
+ st = mp.Slicer(st, vineyard=self._vineyard)
157
+ if self.reduce_degrees is not None:
158
+ st = mp.slicer.minimal_presentation(
159
+ st, degrees=self.reduce_degrees, vineyard=self._vineyard
160
+ )
161
+ return st
162
+
163
+ return Parallel(backend="threading", n_jobs=self.n_jobs)(
164
+ delayed(reduce)(st) for st in sts
165
+ )
166
+
167
+ def _get_sts_alpha(self, x: np.ndarray, return_alpha=False):
168
+ assert self._output_type is not None and self._vineyard is not None
169
+ alpha_complex = gd.AlphaComplex(points=x)
170
+ st = alpha_complex.create_simplex_tree(max_alpha_square=self._threshold**2)
171
+ vertices = np.array([i for (i,), _ in st.get_skeleton(0)])
172
+ new_points = np.asarray(
173
+ [alpha_complex.get_point(i) for i in vertices]
174
+ ) # Seems to be unsafe for some reason
175
+ # new_points = x
176
+ st = mp.simplex_tree_multi.SimplexTreeMulti(
177
+ st, num_parameters=2, safe_conversion=self.safe_conversion
178
+ )
179
+ codensities = self._get_codensities(x_fit=x, x_sample=new_points)
180
+ num_axes = codensities.shape[0]
181
+ sts = [st] + [st.copy() for _ in range(num_axes - 1)]
182
+ # no need to multithread here, most operations are memory
183
+ max_vertices = vertices.max() + 2 # +1 to be safe
184
+ for codensity, st_copy in zip(codensities, sts):
185
+ alligned_codensity = np.array([np.nan] * max_vertices)
186
+ alligned_codensity[vertices] = codensity
187
+ # alligned_codensity = np.array([codensity[i] if i in vertices else np.nan for i in range(max_vertices)])
188
+ st_copy.fill_lowerstar(alligned_codensity, parameter=1)
189
+ if "slicer" in self._output_type:
190
+ sts2 = (mp.Slicer(st, vineyard=self._vineyard) for st in sts)
191
+ if self.reduce_degrees is not None:
192
+ sts = tuple(
193
+ mp.slicer.minimal_presentation(
194
+ s, degrees=self.reduce_degrees, vineyard=self._vineyard
195
+ )
196
+ for s in sts2
197
+ )
198
+ else:
199
+ sts = tuple(sts2)
200
+ if return_alpha:
201
+ return alpha_complex, sts
202
+ return sts
203
+
204
+ def _get_sts_delaunay(self, x: np.ndarray):
205
+ codensities = self._get_codensities(x_fit=x, x_sample=x)
206
+
207
+ def get_st(c):
208
+ slicer = mps.from_function_delaunay(
209
+ x,
210
+ c,
211
+ verbose=self.verbose,
212
+ clear=not self.verbose,
213
+ vineyard=self._vineyard,
214
+ )
215
+ if self._output_type == "simplextree":
216
+ slicer = mps.to_simplextree(slicer)
217
+ elif self.reduce_degrees is not None:
218
+ slicer = mp.slicer.minimal_presentation(
219
+ slicer,
220
+ degrees=self.reduce_degrees,
221
+ vineyard=self._vineyard,
222
+ )
223
+ else:
224
+ slicer = slicer
225
+ return slicer
226
+
227
+ sts = Parallel(backend="threading", n_jobs=self.n_jobs)(
228
+ delayed(get_st)(c) for c in codensities
229
+ )
230
+ return sts
231
+
232
+ def _get_codensities(self, x_fit, x_sample):
233
+ x_fit = np.asarray(x_fit, dtype=np.float64)
234
+ x_sample = np.asarray(x_sample, dtype=np.float64)
235
+ codensities_kde = np.asarray(
236
+ [
237
+ -KDE(
238
+ bandwidth=bandwidth, kernel=self.kernel, return_log=self.log_density
239
+ )
240
+ .fit(x_fit)
241
+ .score_samples(x_sample)
242
+ for bandwidth in self._bandwidths
243
+ ],
244
+ ).reshape(len(self._bandwidths), len(x_sample))
245
+ codensities_dtm = (
246
+ DTM(masses=self.masses)
247
+ .fit(x_fit)
248
+ .score_samples(x_sample)
249
+ .reshape(len(self.masses), len(x_sample))
250
+ )
251
+ return np.concatenate([codensities_kde, codensities_dtm])
252
+
253
+ def _define_sts(self):
254
+ match self.complex:
255
+ case "rips":
256
+ self._get_sts = self._get_sts_rips
257
+ _pref_output = "simplextree"
258
+ case "alpha":
259
+ self._get_sts = self._get_sts_alpha
260
+ _pref_output = "simplextree"
261
+ case "delaunay":
262
+ self._get_sts = self._get_sts_delaunay
263
+ _pref_output = "slicer"
264
+ case _:
265
+ raise ValueError(
266
+ f"Invalid complex \
267
+ {self.complex}. Possible choises are rips, delaunay, or alpha."
268
+ )
269
+ self._vineyard = (
270
+ False if self.output_type is None else "novine" not in self.output_type
271
+ )
272
+ self._output_type = (
273
+ _pref_output
274
+ if self.output_type is None
275
+ else (
276
+ "simplextree"
277
+ if (
278
+ self.output_type == "simplextree" or self.reduce_degrees is not None
279
+ )
280
+ else "slicer"
281
+ )
282
+ )
283
+
284
+ def _define_bandwidths(self, X):
285
+ qs = [q for q in [*-np.asarray(self.bandwidths)] if 0 <= q <= 1]
286
+ self._get_distance_quantiles_and_threshold(X, qs=qs)
287
+ self._bandwidths = np.array(self.bandwidths)
288
+ count = 0
289
+ for i in range(len(self._bandwidths)):
290
+ if self.bandwidths[i] < 0:
291
+ self._bandwidths[i] = self._scale[count]
292
+ count += 1
293
+
294
+ def fit(self, X: np.ndarray | list, y=None):
295
+ # self.bandwidth = "silverman" ## not good, as is can make bandwidth not constant
296
+ self._define_sts()
297
+ self._define_bandwidths(X)
298
+ # PRECOMPILE FIRST
299
+ self._get_codensities(X[0][:2], X[0][:2])
300
+ return self
301
+
302
+ def transform(self, X):
303
+ # precompile first
304
+ # self._get_sts(X[0][:5])
305
+ self._get_codensities(X[0][:2], X[0][:2])
306
+ with tqdm(
307
+ X, desc="Filling simplextrees", disable=not self.progress, total=len(X)
308
+ ) as data:
309
+ stss = Parallel(backend="threading", n_jobs=self.n_jobs)(
310
+ delayed(self._get_sts)(x) for x in data
311
+ )
312
+ return stss
313
+
314
+
315
class PointCloud2SimplexTree(PointCloud2FilteredComplex):
    """
    Deprecated: use PointCloud2FilteredComplex instead.

    Accepts the same parameters and forwards them unchanged; the only
    difference visible here is the default `threshold` (`np.inf`).
    A warning is emitted on construction.
    """

    def __init__(
        self,
        bandwidths=[],
        masses=[],
        threshold: float = np.inf,
        complex: Literal["alpha", "rips", "delaunay"] = "rips",
        sparse: float | None = None,
        num_collapses: int = -2,
        kernel: available_kernels = "gaussian",
        log_density: bool = True,
        expand_dim: int = 1,
        progress: bool = False,
        n_jobs: Optional[int] = None,
        fit_fraction: float = 1,
        verbose: bool = False,
        safe_conversion: bool = False,
        output_type: Optional[
            Literal["slicer", "simplextree", "slicer_vine", "slicer_novine"]
        ] = None,
        reduce_degrees: Optional[Iterable[int]] = None,
    ) -> None:
        from warnings import warn

        # Explicit forwarding instead of introspecting `locals()`, which
        # silently depends on which names happen to be bound at that point.
        super().__init__(
            bandwidths=bandwidths,
            masses=masses,
            threshold=threshold,
            complex=complex,
            sparse=sparse,
            num_collapses=num_collapses,
            kernel=kernel,
            log_density=log_density,
            expand_dim=expand_dim,
            progress=progress,
            n_jobs=n_jobs,
            fit_fraction=fit_fraction,
            verbose=verbose,
            safe_conversion=safe_conversion,
            output_type=output_type,
            reduce_degrees=reduce_degrees,
        )
        # stacklevel=2 attributes the warning to the code instantiating the class.
        warn(
            "This class is deprecated, use PointCloud2FilteredComplex instead.",
            stacklevel=2,
        )