multipers 2.3.3b6__cp313-cp313-macosx_10_13_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of multipers might be problematic. Click here for more details.
- multipers/.dylibs/libc++.1.0.dylib +0 -0
- multipers/.dylibs/libtbb.12.16.dylib +0 -0
- multipers/__init__.py +33 -0
- multipers/_signed_measure_meta.py +453 -0
- multipers/_slicer_meta.py +211 -0
- multipers/array_api/__init__.py +45 -0
- multipers/array_api/numpy.py +41 -0
- multipers/array_api/torch.py +58 -0
- multipers/data/MOL2.py +458 -0
- multipers/data/UCR.py +18 -0
- multipers/data/__init__.py +1 -0
- multipers/data/graphs.py +466 -0
- multipers/data/immuno_regions.py +27 -0
- multipers/data/minimal_presentation_to_st_bf.py +0 -0
- multipers/data/pytorch2simplextree.py +91 -0
- multipers/data/shape3d.py +101 -0
- multipers/data/synthetic.py +113 -0
- multipers/distances.py +202 -0
- multipers/filtration_conversions.pxd +229 -0
- multipers/filtration_conversions.pxd.tp +84 -0
- multipers/filtrations/__init__.py +18 -0
- multipers/filtrations/density.py +574 -0
- multipers/filtrations/filtrations.py +361 -0
- multipers/filtrations.pxd +224 -0
- multipers/function_rips.cpython-313-darwin.so +0 -0
- multipers/function_rips.pyx +105 -0
- multipers/grids.cpython-313-darwin.so +0 -0
- multipers/grids.pyx +433 -0
- multipers/gudhi/Persistence_slices_interface.h +132 -0
- multipers/gudhi/Simplex_tree_interface.h +239 -0
- multipers/gudhi/Simplex_tree_multi_interface.h +551 -0
- multipers/gudhi/cubical_to_boundary.h +59 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -0
- multipers/gudhi/gudhi/Debug_utils.h +45 -0
- multipers/gudhi/gudhi/Fields/Multi_field.h +484 -0
- multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -0
- multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -0
- multipers/gudhi/gudhi/Fields/Z2_field.h +355 -0
- multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -0
- multipers/gudhi/gudhi/Fields/Zp_field.h +420 -0
- multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -0
- multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -0
- multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -0
- multipers/gudhi/gudhi/Matrix.h +2107 -0
- multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -0
- multipers/gudhi/gudhi/Multi_persistence/Box.h +174 -0
- multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -0
- multipers/gudhi/gudhi/Off_reader.h +173 -0
- multipers/gudhi/gudhi/One_critical_filtration.h +1441 -0
- multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -0
- multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -0
- multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -0
- multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -0
- multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -0
- multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -0
- multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -0
- multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -0
- multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -0
- multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -0
- multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -0
- multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -0
- multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -0
- multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -0
- multipers/gudhi/gudhi/Persistent_cohomology.h +745 -0
- multipers/gudhi/gudhi/Points_off_io.h +171 -0
- multipers/gudhi/gudhi/Simple_object_pool.h +69 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -0
- multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -0
- multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -0
- multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -0
- multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -0
- multipers/gudhi/gudhi/Simplex_tree.h +2794 -0
- multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -0
- multipers/gudhi/gudhi/distance_functions.h +62 -0
- multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -0
- multipers/gudhi/gudhi/persistence_interval.h +253 -0
- multipers/gudhi/gudhi/persistence_matrix_options.h +170 -0
- multipers/gudhi/gudhi/reader_utils.h +367 -0
- multipers/gudhi/mma_interface_coh.h +256 -0
- multipers/gudhi/mma_interface_h0.h +223 -0
- multipers/gudhi/mma_interface_matrix.h +293 -0
- multipers/gudhi/naive_merge_tree.h +536 -0
- multipers/gudhi/scc_io.h +310 -0
- multipers/gudhi/truc.h +1403 -0
- multipers/io.cpython-313-darwin.so +0 -0
- multipers/io.pyx +644 -0
- multipers/ml/__init__.py +0 -0
- multipers/ml/accuracies.py +90 -0
- multipers/ml/invariants_with_persistable.py +79 -0
- multipers/ml/kernels.py +176 -0
- multipers/ml/mma.py +713 -0
- multipers/ml/one.py +472 -0
- multipers/ml/point_clouds.py +352 -0
- multipers/ml/signed_measures.py +1589 -0
- multipers/ml/sliced_wasserstein.py +461 -0
- multipers/ml/tools.py +113 -0
- multipers/mma_structures.cpython-313-darwin.so +0 -0
- multipers/mma_structures.pxd +128 -0
- multipers/mma_structures.pyx +2786 -0
- multipers/mma_structures.pyx.tp +1094 -0
- multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -0
- multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -0
- multipers/multi_parameter_rank_invariant/function_rips.h +322 -0
- multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -0
- multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -0
- multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -0
- multipers/multiparameter_edge_collapse.py +41 -0
- multipers/multiparameter_module_approximation/approximation.h +2330 -0
- multipers/multiparameter_module_approximation/combinatory.h +129 -0
- multipers/multiparameter_module_approximation/debug.h +107 -0
- multipers/multiparameter_module_approximation/euler_curves.h +0 -0
- multipers/multiparameter_module_approximation/format_python-cpp.h +286 -0
- multipers/multiparameter_module_approximation/heap_column.h +238 -0
- multipers/multiparameter_module_approximation/images.h +79 -0
- multipers/multiparameter_module_approximation/list_column.h +174 -0
- multipers/multiparameter_module_approximation/list_column_2.h +232 -0
- multipers/multiparameter_module_approximation/ru_matrix.h +347 -0
- multipers/multiparameter_module_approximation/set_column.h +135 -0
- multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -0
- multipers/multiparameter_module_approximation/unordered_set_column.h +166 -0
- multipers/multiparameter_module_approximation/utilities.h +403 -0
- multipers/multiparameter_module_approximation/vector_column.h +223 -0
- multipers/multiparameter_module_approximation/vector_matrix.h +331 -0
- multipers/multiparameter_module_approximation/vineyards.h +464 -0
- multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -0
- multipers/multiparameter_module_approximation.cpython-313-darwin.so +0 -0
- multipers/multiparameter_module_approximation.pyx +235 -0
- multipers/pickle.py +90 -0
- multipers/plots.py +456 -0
- multipers/point_measure.cpython-313-darwin.so +0 -0
- multipers/point_measure.pyx +395 -0
- multipers/simplex_tree_multi.cpython-313-darwin.so +0 -0
- multipers/simplex_tree_multi.pxd +134 -0
- multipers/simplex_tree_multi.pyx +10840 -0
- multipers/simplex_tree_multi.pyx.tp +2009 -0
- multipers/slicer.cpython-313-darwin.so +0 -0
- multipers/slicer.pxd +3034 -0
- multipers/slicer.pxd.tp +234 -0
- multipers/slicer.pyx +20481 -0
- multipers/slicer.pyx.tp +1088 -0
- multipers/tensor/tensor.h +672 -0
- multipers/tensor.pxd +13 -0
- multipers/test.pyx +44 -0
- multipers/tests/__init__.py +62 -0
- multipers/torch/__init__.py +1 -0
- multipers/torch/diff_grids.py +240 -0
- multipers/torch/rips_density.py +310 -0
- multipers-2.3.3b6.dist-info/METADATA +128 -0
- multipers-2.3.3b6.dist-info/RECORD +183 -0
- multipers-2.3.3b6.dist-info/WHEEL +6 -0
- multipers-2.3.3b6.dist-info/licenses/LICENSE +21 -0
- multipers-2.3.3b6.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,352 @@
|
|
|
1
|
+
from collections.abc import Iterable
|
|
2
|
+
from typing import Literal, Optional
|
|
3
|
+
|
|
4
|
+
import gudhi as gd
|
|
5
|
+
import numpy as np
|
|
6
|
+
from joblib import Parallel, delayed
|
|
7
|
+
from sklearn.base import BaseEstimator, TransformerMixin
|
|
8
|
+
from scipy.spatial.distance import cdist
|
|
9
|
+
from tqdm import tqdm
|
|
10
|
+
|
|
11
|
+
import multipers as mp
|
|
12
|
+
import multipers.slicer as mps
|
|
13
|
+
from multipers.filtrations.density import DTM, KDE, available_kernels
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class PointCloud2FilteredComplex(BaseEstimator, TransformerMixin):
|
|
17
|
+
def __init__(
|
|
18
|
+
self,
|
|
19
|
+
bandwidths=[],
|
|
20
|
+
masses=[],
|
|
21
|
+
threshold: float = -np.inf,
|
|
22
|
+
complex: Literal["alpha", "rips", "delaunay"] = "rips",
|
|
23
|
+
sparse: Optional[float] = None,
|
|
24
|
+
num_collapses: int = -2,
|
|
25
|
+
kernel: available_kernels = "gaussian",
|
|
26
|
+
log_density: bool = True,
|
|
27
|
+
expand_dim: int = 1,
|
|
28
|
+
progress: bool = False,
|
|
29
|
+
n_jobs: Optional[int] = None,
|
|
30
|
+
fit_fraction: float = 1,
|
|
31
|
+
verbose: bool = False,
|
|
32
|
+
safe_conversion: bool = False,
|
|
33
|
+
output_type: Optional[
|
|
34
|
+
Literal["slicer", "simplextree", "slicer_vine", "slicer_novine"]
|
|
35
|
+
] = None,
|
|
36
|
+
reduce_degrees: Optional[Iterable[int]] = None,
|
|
37
|
+
) -> None:
|
|
38
|
+
"""
|
|
39
|
+
(Rips or Alpha or Delaunay) + (Density Estimation or DTM) 1-critical 2-filtration.
|
|
40
|
+
|
|
41
|
+
Parameters
|
|
42
|
+
----------
|
|
43
|
+
- bandwidth : real : The kernel density estimation bandwidth, or the DTM mass. If negative, it replaced by abs(bandwidth)*(radius of the dataset)
|
|
44
|
+
- threshold : real, max edge lenfth of the rips or max alpha square of the alpha
|
|
45
|
+
- sparse : real, sparse rips (c.f. rips doc) WARNING : ONLY FOR RIPS
|
|
46
|
+
- num_collapse : int, Number of edge collapses applied to the simplextrees, WARNING : ONLY FOR RIPS
|
|
47
|
+
- expand_dim : int, expand the rips complex to this dimension. WARNING : ONLY FOR RIPS
|
|
48
|
+
- kernel : the kernel used for density estimation. Available ones are, e.g., "dtm", "gaussian", "exponential".
|
|
49
|
+
- progress : bool, shows the calculus status
|
|
50
|
+
- n_jobs : number of processes
|
|
51
|
+
- fit_fraction : real, the fraction of data on which to fit
|
|
52
|
+
- verbose : bool, Shows more information if true.
|
|
53
|
+
|
|
54
|
+
Output
|
|
55
|
+
------
|
|
56
|
+
A list of SimplexTreeMulti whose first parameter is a rips and the second is the codensity.
|
|
57
|
+
"""
|
|
58
|
+
super().__init__()
|
|
59
|
+
self.bandwidths = bandwidths
|
|
60
|
+
self.masses = masses
|
|
61
|
+
self.num_collapses = num_collapses
|
|
62
|
+
self.kernel = kernel
|
|
63
|
+
self.log_density = log_density
|
|
64
|
+
self.progress = progress
|
|
65
|
+
self._bandwidths = np.empty((0,))
|
|
66
|
+
self._threshold = np.inf
|
|
67
|
+
self.n_jobs = n_jobs
|
|
68
|
+
self._scale = np.empty((0,))
|
|
69
|
+
self.fit_fraction = fit_fraction
|
|
70
|
+
self.expand_dim = expand_dim
|
|
71
|
+
self.verbose = verbose
|
|
72
|
+
self.complex = complex
|
|
73
|
+
self.threshold = threshold
|
|
74
|
+
self.sparse = sparse
|
|
75
|
+
self._get_sts = lambda: Exception("Fit first")
|
|
76
|
+
self.safe_conversion = safe_conversion
|
|
77
|
+
self.output_type = output_type
|
|
78
|
+
self._output_type = None
|
|
79
|
+
self.reduce_degrees = reduce_degrees
|
|
80
|
+
self._vineyard = None
|
|
81
|
+
|
|
82
|
+
assert (
|
|
83
|
+
output_type != "simplextree" or reduce_degrees is None
|
|
84
|
+
), "Reduced complex are not simplicial. Cannot return a simplextree."
|
|
85
|
+
return
|
|
86
|
+
|
|
87
|
+
def _get_distance_quantiles_and_threshold(self, X, qs):
|
|
88
|
+
## if we dont need to compute a distance matrix
|
|
89
|
+
if len(qs) == 0 and self.threshold >= 0:
|
|
90
|
+
self._scale = []
|
|
91
|
+
return []
|
|
92
|
+
if self.progress:
|
|
93
|
+
print("Estimating scale...", flush=True, end="")
|
|
94
|
+
## subsampling
|
|
95
|
+
indices = np.random.choice(
|
|
96
|
+
len(X), min(len(X), int(self.fit_fraction * len(X)) + 1), replace=False
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
def compute_max_scale(x):
|
|
100
|
+
from pykeops.numpy import LazyTensor
|
|
101
|
+
|
|
102
|
+
a = LazyTensor(x[None, :, :])
|
|
103
|
+
b = LazyTensor(x[:, None, :])
|
|
104
|
+
return np.sqrt(((a - b) ** 2).sum(2).max(1).min(0)[0])
|
|
105
|
+
|
|
106
|
+
diameter = np.max([compute_max_scale(x) for x in (X[i] for i in indices)])
|
|
107
|
+
self._scale = diameter * np.array(qs)
|
|
108
|
+
|
|
109
|
+
if self.threshold == -np.inf:
|
|
110
|
+
self._threshold = diameter
|
|
111
|
+
elif self.threshold > 0:
|
|
112
|
+
self._threshold = self.threshold
|
|
113
|
+
else:
|
|
114
|
+
self._threshold = -diameter * self.threshold
|
|
115
|
+
|
|
116
|
+
if self.threshold > 0:
|
|
117
|
+
self._scale[self._scale > self.threshold] = self.threshold
|
|
118
|
+
|
|
119
|
+
if self.progress:
|
|
120
|
+
print(f"Done. Chosen scales {qs} are {self._scale}", flush=True)
|
|
121
|
+
return self._scale
|
|
122
|
+
|
|
123
|
+
def _get_sts_rips(self, x):
|
|
124
|
+
assert self._output_type is not None and self._vineyard is not None
|
|
125
|
+
if self.sparse is None:
|
|
126
|
+
st_init = gd.SimplexTree.create_from_array(
|
|
127
|
+
cdist(x,x), max_filtration=self._threshold
|
|
128
|
+
)
|
|
129
|
+
else:
|
|
130
|
+
st_init = gd.RipsComplex(
|
|
131
|
+
points=x, max_edge_length=self._threshold, sparse=self.sparse
|
|
132
|
+
).create_simplex_tree(max_dimension=1)
|
|
133
|
+
st_init = mp.simplex_tree_multi.SimplexTreeMulti(
|
|
134
|
+
st_init, num_parameters=2, safe_conversion=self.safe_conversion
|
|
135
|
+
)
|
|
136
|
+
codensities = self._get_codensities(x_fit=x, x_sample=x)
|
|
137
|
+
num_axes = codensities.shape[0]
|
|
138
|
+
sts = [st_init] + [st_init.copy() for _ in range(num_axes - 1)]
|
|
139
|
+
# no need to multithread here, most operations are memory
|
|
140
|
+
for codensity, st_copy in zip(codensities, sts):
|
|
141
|
+
# RIPS has contigus vertices, so vertices are ordered.
|
|
142
|
+
st_copy.fill_lowerstar(codensity, parameter=1)
|
|
143
|
+
|
|
144
|
+
def reduce(st):
|
|
145
|
+
if self.verbose:
|
|
146
|
+
print("Num simplices :", st.num_simplices)
|
|
147
|
+
if isinstance(self.num_collapses, int):
|
|
148
|
+
st.collapse_edges(num=self.num_collapses)
|
|
149
|
+
if self.verbose:
|
|
150
|
+
print(", after collapse :", st.num_simplices, end="")
|
|
151
|
+
elif self.num_collapses == "full":
|
|
152
|
+
st.collapse_edges(full=True)
|
|
153
|
+
if self.verbose:
|
|
154
|
+
print(", after collapse :", st.num_simplices, end="")
|
|
155
|
+
if self.expand_dim > 1:
|
|
156
|
+
st.expansion(self.expand_dim)
|
|
157
|
+
if self.verbose:
|
|
158
|
+
print(", after expansion :", st.num_simplices, end="")
|
|
159
|
+
if self.verbose:
|
|
160
|
+
print("")
|
|
161
|
+
if self._output_type == "slicer":
|
|
162
|
+
st = mp.Slicer(st, vineyard=self._vineyard)
|
|
163
|
+
if self.reduce_degrees is not None:
|
|
164
|
+
st = mp.slicer.minimal_presentation(
|
|
165
|
+
st, degrees=self.reduce_degrees, vineyard=self._vineyard
|
|
166
|
+
)
|
|
167
|
+
return st
|
|
168
|
+
|
|
169
|
+
return Parallel(backend="threading", n_jobs=self.n_jobs)(
|
|
170
|
+
delayed(reduce)(st) for st in sts
|
|
171
|
+
)
|
|
172
|
+
|
|
173
|
+
def _get_sts_alpha(self, x: np.ndarray, return_alpha=False):
|
|
174
|
+
assert self._output_type is not None and self._vineyard is not None
|
|
175
|
+
alpha_complex = gd.AlphaComplex(points=x)
|
|
176
|
+
st = alpha_complex.create_simplex_tree(max_alpha_square=self._threshold**2)
|
|
177
|
+
vertices = np.array([i for (i,), _ in st.get_skeleton(0)])
|
|
178
|
+
new_points = np.asarray(
|
|
179
|
+
[alpha_complex.get_point(int(i)) for i in vertices]
|
|
180
|
+
) # Seems to be unsafe for some reason
|
|
181
|
+
# new_points = x
|
|
182
|
+
st = mp.simplex_tree_multi.SimplexTreeMulti(
|
|
183
|
+
st, num_parameters=2, safe_conversion=self.safe_conversion
|
|
184
|
+
)
|
|
185
|
+
codensities = self._get_codensities(x_fit=x, x_sample=new_points)
|
|
186
|
+
num_axes = codensities.shape[0]
|
|
187
|
+
sts = [st] + [st.copy() for _ in range(num_axes - 1)]
|
|
188
|
+
# no need to multithread here, most operations are memory
|
|
189
|
+
max_vertices = vertices.max() + 2 # +1 to be safe
|
|
190
|
+
for codensity, st_copy in zip(codensities, sts):
|
|
191
|
+
alligned_codensity = np.array([np.nan] * max_vertices)
|
|
192
|
+
alligned_codensity[vertices] = codensity
|
|
193
|
+
# alligned_codensity = np.array([codensity[i] if i in vertices else np.nan for i in range(max_vertices)])
|
|
194
|
+
st_copy.fill_lowerstar(alligned_codensity, parameter=1)
|
|
195
|
+
if "slicer" in self._output_type:
|
|
196
|
+
sts2 = (mp.Slicer(st, vineyard=self._vineyard) for st in sts)
|
|
197
|
+
if self.reduce_degrees is not None:
|
|
198
|
+
sts = tuple(
|
|
199
|
+
mp.slicer.minimal_presentation(
|
|
200
|
+
s, degrees=self.reduce_degrees, vineyard=self._vineyard
|
|
201
|
+
)
|
|
202
|
+
for s in sts2
|
|
203
|
+
)
|
|
204
|
+
else:
|
|
205
|
+
sts = tuple(sts2)
|
|
206
|
+
if return_alpha:
|
|
207
|
+
return alpha_complex, sts
|
|
208
|
+
return sts
|
|
209
|
+
|
|
210
|
+
def _get_sts_delaunay(self, x: np.ndarray):
|
|
211
|
+
codensities = self._get_codensities(x_fit=x, x_sample=x)
|
|
212
|
+
|
|
213
|
+
def get_st(c):
|
|
214
|
+
slicer = mps.from_function_delaunay(
|
|
215
|
+
x,
|
|
216
|
+
c,
|
|
217
|
+
verbose=self.verbose,
|
|
218
|
+
clear=not self.verbose,
|
|
219
|
+
vineyard=self._vineyard,
|
|
220
|
+
)
|
|
221
|
+
if self._output_type == "simplextree":
|
|
222
|
+
slicer = mps.to_simplextree(slicer)
|
|
223
|
+
elif self.reduce_degrees is not None:
|
|
224
|
+
slicer = mp.slicer.minimal_presentation(
|
|
225
|
+
slicer,
|
|
226
|
+
degrees=self.reduce_degrees,
|
|
227
|
+
vineyard=self._vineyard,
|
|
228
|
+
)
|
|
229
|
+
else:
|
|
230
|
+
slicer = slicer
|
|
231
|
+
return slicer
|
|
232
|
+
|
|
233
|
+
sts = Parallel(backend="threading", n_jobs=self.n_jobs)(
|
|
234
|
+
delayed(get_st)(c) for c in codensities
|
|
235
|
+
)
|
|
236
|
+
return sts
|
|
237
|
+
|
|
238
|
+
def _get_codensities(self, x_fit, x_sample):
|
|
239
|
+
x_fit = np.asarray(x_fit, dtype=np.float64)
|
|
240
|
+
x_sample = np.asarray(x_sample, dtype=np.float64)
|
|
241
|
+
codensities_kde = np.asarray(
|
|
242
|
+
[
|
|
243
|
+
-KDE(
|
|
244
|
+
bandwidth=bandwidth, kernel=self.kernel, return_log=self.log_density
|
|
245
|
+
)
|
|
246
|
+
.fit(x_fit)
|
|
247
|
+
.score_samples(x_sample)
|
|
248
|
+
for bandwidth in self._bandwidths
|
|
249
|
+
],
|
|
250
|
+
).reshape(len(self._bandwidths), len(x_sample))
|
|
251
|
+
codensities_dtm = (
|
|
252
|
+
DTM(masses=self.masses)
|
|
253
|
+
.fit(x_fit)
|
|
254
|
+
.score_samples(x_sample)
|
|
255
|
+
.reshape(len(self.masses), len(x_sample))
|
|
256
|
+
)
|
|
257
|
+
return np.concatenate([codensities_kde, codensities_dtm])
|
|
258
|
+
|
|
259
|
+
def _define_sts(self):
|
|
260
|
+
match self.complex:
|
|
261
|
+
case "rips":
|
|
262
|
+
self._get_sts = self._get_sts_rips
|
|
263
|
+
_pref_output = "simplextree"
|
|
264
|
+
case "alpha":
|
|
265
|
+
self._get_sts = self._get_sts_alpha
|
|
266
|
+
_pref_output = "simplextree"
|
|
267
|
+
case "delaunay":
|
|
268
|
+
self._get_sts = self._get_sts_delaunay
|
|
269
|
+
_pref_output = "slicer"
|
|
270
|
+
case _:
|
|
271
|
+
raise ValueError(
|
|
272
|
+
f"Invalid complex \
|
|
273
|
+
{self.complex}. Possible choises are rips, delaunay, or alpha."
|
|
274
|
+
)
|
|
275
|
+
self._vineyard = (
|
|
276
|
+
False if self.output_type is None else "novine" not in self.output_type
|
|
277
|
+
)
|
|
278
|
+
self._output_type = (
|
|
279
|
+
_pref_output
|
|
280
|
+
if self.output_type is None
|
|
281
|
+
else (
|
|
282
|
+
"simplextree"
|
|
283
|
+
if (
|
|
284
|
+
self.output_type == "simplextree" or self.reduce_degrees is not None
|
|
285
|
+
)
|
|
286
|
+
else "slicer"
|
|
287
|
+
)
|
|
288
|
+
)
|
|
289
|
+
|
|
290
|
+
def _define_bandwidths(self, X):
|
|
291
|
+
qs = [q for q in [*-np.asarray(self.bandwidths)] if 0 <= q <= 1]
|
|
292
|
+
self._get_distance_quantiles_and_threshold(X, qs=qs)
|
|
293
|
+
self._bandwidths = np.array(self.bandwidths)
|
|
294
|
+
count = 0
|
|
295
|
+
for i in range(len(self._bandwidths)):
|
|
296
|
+
if self.bandwidths[i] < 0:
|
|
297
|
+
self._bandwidths[i] = self._scale[count]
|
|
298
|
+
count += 1
|
|
299
|
+
|
|
300
|
+
def fit(self, X: np.ndarray | list, y=None):
|
|
301
|
+
# self.bandwidth = "silverman" ## not good, as is can make bandwidth not constant
|
|
302
|
+
self._define_sts()
|
|
303
|
+
self._define_bandwidths(X)
|
|
304
|
+
# PRECOMPILE FIRST
|
|
305
|
+
self._get_codensities(X[0][:2], X[0][:2])
|
|
306
|
+
return self
|
|
307
|
+
|
|
308
|
+
def transform(self, X):
|
|
309
|
+
# precompile first
|
|
310
|
+
# self._get_sts(X[0][:5])
|
|
311
|
+
self._get_codensities(X[0][:2], X[0][:2])
|
|
312
|
+
with tqdm(
|
|
313
|
+
X, desc="Filling simplextrees", disable=not self.progress, total=len(X)
|
|
314
|
+
) as data:
|
|
315
|
+
stss = Parallel(backend="threading", n_jobs=self.n_jobs)(
|
|
316
|
+
delayed(self._get_sts)(x) for x in data
|
|
317
|
+
)
|
|
318
|
+
return stss
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
class PointCloud2SimplexTree(PointCloud2FilteredComplex):
    """Deprecated alias of `PointCloud2FilteredComplex` (warns on construction)."""

    def __init__(
        self,
        bandwidths=[],
        masses=[],
        threshold: float = np.inf,
        complex: Literal["alpha", "rips", "delaunay"] = "rips",
        sparse: float | None = None,
        num_collapses: int = -2,
        kernel: available_kernels = "gaussian",
        log_density: bool = True,
        expand_dim: int = 1,
        progress: bool = False,
        n_jobs: Optional[int] = None,
        fit_fraction: float = 1,
        verbose: bool = False,
        safe_conversion: bool = False,
        output_type: Optional[
            Literal["slicer", "simplextree", "slicer_vine", "slicer_novine"]
        ] = None,
        reduce_degrees: Optional[Iterable[int]] = None,
    ) -> None:
        """Forward every argument, unchanged, to the parent constructor."""
        super().__init__(
            bandwidths=bandwidths,
            masses=masses,
            threshold=threshold,
            complex=complex,
            sparse=sparse,
            num_collapses=num_collapses,
            kernel=kernel,
            log_density=log_density,
            expand_dim=expand_dim,
            progress=progress,
            n_jobs=n_jobs,
            fit_fraction=fit_fraction,
            verbose=verbose,
            safe_conversion=safe_conversion,
            output_type=output_type,
            reduce_degrees=reduce_degrees,
        )
        # Emitted after the parent initialisation.
        import warnings

        warnings.warn("This class is deprecated, use PointCloud2FilteredComplex instead.")
|