multipers 2.3.3b6__cp313-cp313-macosx_10_13_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of multipers might be problematic. Click here for more details.
- multipers/.dylibs/libc++.1.0.dylib +0 -0
- multipers/.dylibs/libtbb.12.16.dylib +0 -0
- multipers/__init__.py +33 -0
- multipers/_signed_measure_meta.py +453 -0
- multipers/_slicer_meta.py +211 -0
- multipers/array_api/__init__.py +45 -0
- multipers/array_api/numpy.py +41 -0
- multipers/array_api/torch.py +58 -0
- multipers/data/MOL2.py +458 -0
- multipers/data/UCR.py +18 -0
- multipers/data/__init__.py +1 -0
- multipers/data/graphs.py +466 -0
- multipers/data/immuno_regions.py +27 -0
- multipers/data/minimal_presentation_to_st_bf.py +0 -0
- multipers/data/pytorch2simplextree.py +91 -0
- multipers/data/shape3d.py +101 -0
- multipers/data/synthetic.py +113 -0
- multipers/distances.py +202 -0
- multipers/filtration_conversions.pxd +229 -0
- multipers/filtration_conversions.pxd.tp +84 -0
- multipers/filtrations/__init__.py +18 -0
- multipers/filtrations/density.py +574 -0
- multipers/filtrations/filtrations.py +361 -0
- multipers/filtrations.pxd +224 -0
- multipers/function_rips.cpython-313-darwin.so +0 -0
- multipers/function_rips.pyx +105 -0
- multipers/grids.cpython-313-darwin.so +0 -0
- multipers/grids.pyx +433 -0
- multipers/gudhi/Persistence_slices_interface.h +132 -0
- multipers/gudhi/Simplex_tree_interface.h +239 -0
- multipers/gudhi/Simplex_tree_multi_interface.h +551 -0
- multipers/gudhi/cubical_to_boundary.h +59 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -0
- multipers/gudhi/gudhi/Debug_utils.h +45 -0
- multipers/gudhi/gudhi/Fields/Multi_field.h +484 -0
- multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -0
- multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -0
- multipers/gudhi/gudhi/Fields/Z2_field.h +355 -0
- multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -0
- multipers/gudhi/gudhi/Fields/Zp_field.h +420 -0
- multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -0
- multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -0
- multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -0
- multipers/gudhi/gudhi/Matrix.h +2107 -0
- multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -0
- multipers/gudhi/gudhi/Multi_persistence/Box.h +174 -0
- multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -0
- multipers/gudhi/gudhi/Off_reader.h +173 -0
- multipers/gudhi/gudhi/One_critical_filtration.h +1441 -0
- multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -0
- multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -0
- multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -0
- multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -0
- multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -0
- multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -0
- multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -0
- multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -0
- multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -0
- multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -0
- multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -0
- multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -0
- multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -0
- multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -0
- multipers/gudhi/gudhi/Persistent_cohomology.h +745 -0
- multipers/gudhi/gudhi/Points_off_io.h +171 -0
- multipers/gudhi/gudhi/Simple_object_pool.h +69 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -0
- multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -0
- multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -0
- multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -0
- multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -0
- multipers/gudhi/gudhi/Simplex_tree.h +2794 -0
- multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -0
- multipers/gudhi/gudhi/distance_functions.h +62 -0
- multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -0
- multipers/gudhi/gudhi/persistence_interval.h +253 -0
- multipers/gudhi/gudhi/persistence_matrix_options.h +170 -0
- multipers/gudhi/gudhi/reader_utils.h +367 -0
- multipers/gudhi/mma_interface_coh.h +256 -0
- multipers/gudhi/mma_interface_h0.h +223 -0
- multipers/gudhi/mma_interface_matrix.h +293 -0
- multipers/gudhi/naive_merge_tree.h +536 -0
- multipers/gudhi/scc_io.h +310 -0
- multipers/gudhi/truc.h +1403 -0
- multipers/io.cpython-313-darwin.so +0 -0
- multipers/io.pyx +644 -0
- multipers/ml/__init__.py +0 -0
- multipers/ml/accuracies.py +90 -0
- multipers/ml/invariants_with_persistable.py +79 -0
- multipers/ml/kernels.py +176 -0
- multipers/ml/mma.py +713 -0
- multipers/ml/one.py +472 -0
- multipers/ml/point_clouds.py +352 -0
- multipers/ml/signed_measures.py +1589 -0
- multipers/ml/sliced_wasserstein.py +461 -0
- multipers/ml/tools.py +113 -0
- multipers/mma_structures.cpython-313-darwin.so +0 -0
- multipers/mma_structures.pxd +128 -0
- multipers/mma_structures.pyx +2786 -0
- multipers/mma_structures.pyx.tp +1094 -0
- multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -0
- multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -0
- multipers/multi_parameter_rank_invariant/function_rips.h +322 -0
- multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -0
- multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -0
- multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -0
- multipers/multiparameter_edge_collapse.py +41 -0
- multipers/multiparameter_module_approximation/approximation.h +2330 -0
- multipers/multiparameter_module_approximation/combinatory.h +129 -0
- multipers/multiparameter_module_approximation/debug.h +107 -0
- multipers/multiparameter_module_approximation/euler_curves.h +0 -0
- multipers/multiparameter_module_approximation/format_python-cpp.h +286 -0
- multipers/multiparameter_module_approximation/heap_column.h +238 -0
- multipers/multiparameter_module_approximation/images.h +79 -0
- multipers/multiparameter_module_approximation/list_column.h +174 -0
- multipers/multiparameter_module_approximation/list_column_2.h +232 -0
- multipers/multiparameter_module_approximation/ru_matrix.h +347 -0
- multipers/multiparameter_module_approximation/set_column.h +135 -0
- multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -0
- multipers/multiparameter_module_approximation/unordered_set_column.h +166 -0
- multipers/multiparameter_module_approximation/utilities.h +403 -0
- multipers/multiparameter_module_approximation/vector_column.h +223 -0
- multipers/multiparameter_module_approximation/vector_matrix.h +331 -0
- multipers/multiparameter_module_approximation/vineyards.h +464 -0
- multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -0
- multipers/multiparameter_module_approximation.cpython-313-darwin.so +0 -0
- multipers/multiparameter_module_approximation.pyx +235 -0
- multipers/pickle.py +90 -0
- multipers/plots.py +456 -0
- multipers/point_measure.cpython-313-darwin.so +0 -0
- multipers/point_measure.pyx +395 -0
- multipers/simplex_tree_multi.cpython-313-darwin.so +0 -0
- multipers/simplex_tree_multi.pxd +134 -0
- multipers/simplex_tree_multi.pyx +10840 -0
- multipers/simplex_tree_multi.pyx.tp +2009 -0
- multipers/slicer.cpython-313-darwin.so +0 -0
- multipers/slicer.pxd +3034 -0
- multipers/slicer.pxd.tp +234 -0
- multipers/slicer.pyx +20481 -0
- multipers/slicer.pyx.tp +1088 -0
- multipers/tensor/tensor.h +672 -0
- multipers/tensor.pxd +13 -0
- multipers/test.pyx +44 -0
- multipers/tests/__init__.py +62 -0
- multipers/torch/__init__.py +1 -0
- multipers/torch/diff_grids.py +240 -0
- multipers/torch/rips_density.py +310 -0
- multipers-2.3.3b6.dist-info/METADATA +128 -0
- multipers-2.3.3b6.dist-info/RECORD +183 -0
- multipers-2.3.3b6.dist-info/WHEEL +6 -0
- multipers-2.3.3b6.dist-info/licenses/LICENSE +21 -0
- multipers-2.3.3b6.dist-info/top_level.txt +1 -0
multipers/ml/one.py
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
1
|
+
from sklearn.base import BaseEstimator, TransformerMixin
|
|
2
|
+
import gudhi as gd
|
|
3
|
+
from os.path import exists
|
|
4
|
+
import networkx as nx
|
|
5
|
+
from joblib import Parallel, delayed
|
|
6
|
+
import numpy as np
|
|
7
|
+
from tqdm import tqdm
|
|
8
|
+
from warnings import warn
|
|
9
|
+
from sklearn.neighbors import KernelDensity
|
|
10
|
+
from typing import Iterable
|
|
11
|
+
from gudhi.representations import Landscape
|
|
12
|
+
from gudhi.representations.vector_methods import PersistenceImage
|
|
13
|
+
from gudhi.representations.kernel_methods import SlicedWassersteinDistance
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
from types import FunctionType
|
|
17
|
+
def get_simplextree(x)->gd.SimplexTree:
    """Resolve ``x`` into a simplextree.

    ``x`` may be:
      - a ``gd.SimplexTree``, returned unchanged;
      - a plain function taking no argument, which is called;
      - a triple ``(f, args, kwargs)`` whose first entry is a plain function,
        evaluated as ``f(*args, **kwargs)``.

    Raises:
        TypeError: if ``x`` matches none of the forms above.
    """
    if isinstance(x, gd.SimplexTree):
        return x
    if isinstance(x, FunctionType):
        return x()
    is_delayed_call = len(x) == 3 and isinstance(x[0], FunctionType)
    if not is_delayed_call:
        raise TypeError("Not a valid SimplexTree")
    builder, positional, keywords = x
    return builder(*positional, **keywords)
|
|
26
|
+
def get_simplextrees(X)->Iterable[gd.SimplexTree]:
    """Resolve ``X`` into an iterable of simplextrees.

    ``X`` is either a pair ``(f, data)`` whose first entry is a plain function
    (a generator of ``f(x)`` for ``x`` in ``data`` is returned), an empty
    container (an empty list is returned), or a container whose first element
    is a ``gd.SimplexTree`` (returned unchanged).

    Raises:
        TypeError: if ``X`` is non-empty and fits neither form.
    """
    size = len(X)
    if size == 2 and isinstance(X[0], FunctionType):
        to_simplextree, items = X
        return (to_simplextree(item) for item in items)
    if size == 0:
        return []
    if isinstance(X[0], gd.SimplexTree):
        return X
    raise TypeError
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
############## INTERVALS (for sliced wasserstein)
|
|
39
|
+
class Graph2SimplexTree(BaseEstimator,TransformerMixin):
    """Build one simplextree per graph, filtered by a node/edge attribute."""

    def __init__(self, f:str="ricciCurvature",dtype=gd.SimplexTree, reverse_filtration:bool=False):
        super().__init__()
        # Name of the node/edge attribute holding the filtration values.
        self.f = f
        # When None, transform returns a (builder, graphs) pair instead of
        # computing the simplextrees right away.
        self.dtype = dtype
        # Stored only; transform does not read it (TODO in the original code).
        self.reverse_filtration = reverse_filtration

    def fit(self, X, y=None):
        """Nothing to learn; returns ``self``."""
        return self

    def transform(self,X:list[nx.Graph]):
        """Return the simplextrees of ``X``, or ``[builder, X]`` if ``dtype`` is None."""
        def build(graph, attribute=self.f) -> gd.SimplexTree:  # TODO : use batch insert
            simplextree = gd.SimplexTree()
            for node in graph.nodes:
                simplextree.insert([node], graph.nodes[node][attribute])
            for source, target in graph.edges:
                simplextree.insert([source, target], graph[source][target][attribute])
            return simplextree

        if self.dtype is None:
            return [build, X]
        return Parallel(n_jobs=-1, prefer="threads")(delayed(build)(graph) for graph in X)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class PointCloud2SimplexTree(BaseEstimator,TransformerMixin):
    """Build one alpha-complex simplextree per point cloud.

    Parameters:
        delayed: when true, ``transform`` does not compute anything and
            returns the pair ``[builder, X]`` so that a later pipeline step can
            build the simplextrees itself.
        threshold: the alpha complex is built with
            ``max_alpha_square = threshold**2``. A negative value is replaced
            in ``fit`` by the largest point-cloud diameter found in ``X``.
    """

    def __init__(self, delayed:bool = False, threshold = np.inf):
        super().__init__()
        self.delayed = delayed
        self.threshold=threshold

    @staticmethod
    def _get_point_cloud_diameter(x):
        """Return the largest pairwise distance between the points of ``x``."""
        from scipy.spatial import distance_matrix
        return np.max(distance_matrix(x,x))

    def fit(self, X, y=None):
        """Resolve a negative ``threshold`` from the data; returns ``self``."""
        if self.threshold < 0:
            self.threshold = max(self._get_point_cloud_diameter(x) for x in X)
        return self

    def transform(self,X:Iterable):
        """Return the simplextrees of the point clouds in ``X``.

        Returns ``[builder, X]`` instead when ``self.delayed`` is true.
        """
        def todo(point_cloud) -> gd.SimplexTree: # TODO : use batch insert
            return gd.AlphaComplex(points=point_cloud).create_simplex_tree(max_alpha_square = self.threshold**2)

        # The original test was `self.delayed is None`, which a boolean flag
        # never satisfies, so `delayed=True` was silently ignored. `None` is
        # still accepted as "delayed" for backward compatibility.
        if self.delayed is None or self.delayed:
            return [todo, X]
        # `delayed` below is joblib's helper (module level), not the attribute.
        return Parallel(n_jobs=-1, prefer="threads")(delayed(todo)(point_cloud) for point_cloud in X)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
#################### FILVEC
|
|
78
|
+
def get_filtration_values(g:nx.Graph, f:str)->np.ndarray:
    """Collect attribute ``f`` of every node, then of every edge, of ``g``.

    Returns a single array: node values first, edge values after.
    """
    node_values = [g.nodes[node][f] for node in g.nodes]
    edge_values = [g[source][target][f] for source, target in g.edges]
    return np.array(node_values + edge_values)
|
|
85
|
+
def graph2filvec(g:nx.Graph, f:str, range:tuple, bins:int)->np.ndarray:
    """Histogram (counts only) of the ``f`` values found on the nodes and edges of ``g``.

    ``range`` and ``bins`` are forwarded to ``np.histogram``.
    """
    counts, _bin_edges = np.histogram(get_filtration_values(g, f), bins=bins, range=range)
    return counts
|
|
88
|
+
class FilvecGetter(BaseEstimator, TransformerMixin):
    """Turn graphs into histograms of their node/edge filtration values.

    Parameters:
        f: name of the node/edge attribute holding the filtration values.
        quantile: the histogram range is the ``[quantile, 1 - quantile]``
            quantile interval of all the values seen in ``fit``.
        bins: number of histogram bins.
        n_jobs: forwarded to ``joblib.Parallel``.
    """

    def __init__(self, f:str="ricciCurvature",quantile:float=0., bins:int=100, n_jobs:int=1):
        super().__init__()
        self.f=f
        self.quantile=quantile
        self.bins=bins
        # (low, high) histogram bounds, set by fit.
        self.range:tuple[float, float]|None=None
        self.n_jobs=n_jobs

    def fit(self, X, y=None):
        """Compute the histogram range from the values of every graph in ``X``."""
        filtration_values = np.concatenate(Parallel(n_jobs=self.n_jobs)(delayed(get_filtration_values)(g,f=self.f) for g in X))
        self.range= tuple(np.quantile(filtration_values, [self.quantile, 1-self.quantile]))
        return self

    def transform(self,X):
        """Return one histogram per graph of ``X``.

        Prints a message and returns ``None`` when called before ``fit``.
        """
        # Identity test: `== None` would be evaluated elementwise (and then
        # fail in the `if`) should `range` ever hold an array.
        if self.range is None:
            print("Fit first")
            return
        return Parallel(n_jobs=self.n_jobs)(delayed(graph2filvec)(g,f=self.f, range=self.range, bins=self.bins) for g in X)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
############# Filvec from SimplexTree
|
|
110
|
+
# Input: a simplextree; outputs the histogram of the filtration values of its simplices
|
|
111
|
+
def simplextree2hist(simplextree, range:tuple[float, float], bins:int, density:bool)->np.ndarray: #TODO : Anything to histogram
    """Histogram of the filtration values of every simplex of ``simplextree``.

    ``simplextree.get_simplices()`` must yield ``(simplex, filtration)`` pairs.
    ``range``, ``bins`` and ``density`` are forwarded to ``np.histogram``; only
    the histogram values are returned (the bin edges are dropped).
    """
    filtrations = [filtration for _simplex, filtration in simplextree.get_simplices()]
    histogram, _bin_edges = np.histogram(np.array(filtrations), bins=bins, range=range, density=density)
    return histogram
|
|
114
|
+
class SimplexTree2Histogram(BaseEstimator, TransformerMixin):
    """Turn simplextrees into histograms of their filtration values.

    The input ``X`` is either a sequence of ``gd.SimplexTree`` or a pair
    ``(to_st, data)`` where ``to_st(x)`` builds the simplextree of ``x``.

    Parameters:
        quantile: the histogram range is the ``[quantile, 1 - quantile]``
            quantile interval of the finite filtration values seen in ``fit``.
        bins: number of histogram bins.
        n_jobs: forwarded to ``joblib.Parallel`` (``(to_st, data)`` input only).
        progress: show a tqdm progress bar.
        density: forwarded to ``np.histogram``.
    """

    def __init__(self, quantile:float=0., bins:int=100, n_jobs:int=1, progress:bool=False, density:bool=True):
        super().__init__()
        # Histogram bounds [low, high], set by fit.
        self.range:np.ndarray | None=None
        self.quantile:float=quantile
        self.bins:int=bins
        self.n_jobs=n_jobs
        self.density=density
        self.progress = progress
        # self.max_dimension=None # TODO: maybe use it

    def fit(self, X, y=None):
        """Compute the histogram range from every filtration value in ``X``."""
        if len(X) == 0:
            return self
        if isinstance(X[0], gd.SimplexTree): # X already holds simplextrees
            data = X
            to_st = lambda x : x
        else: # X is a pair (to_st, data)
            to_st, data = X
        persistence_values = np.array([f for st in data for s,f in to_st(st).get_simplices()])
        persistence_values = persistence_values[persistence_values<np.inf]
        self.range = np.quantile(persistence_values, [self.quantile, 1-self.quantile])
        return self

    def transform(self,X):
        """Return one histogram per simplextree of ``X``."""
        if len(X) == 0:
            # The original returned `self` here, which is not a valid output.
            return []
        if isinstance(X[0], gd.SimplexTree):
            if self.n_jobs > 1:
                warn("Cannot pickle simplextrees, reducing to 1 thread to compute the simplextrees")
            return [simplextree2hist(g,range=self.range, bins=self.bins, density=self.density) for g in tqdm(X, desc="Computing diagrams", disable=not self.progress)]

        to_st, data = X # requires len(X) == 2
        def pickle_able_todo(x, **kwargs):
            # Build the simplextree inside the worker: only `x` is shipped.
            simplextree = to_st(x)
            return simplextree2hist(simplextree=simplextree, **kwargs)
        return Parallel(n_jobs=self.n_jobs)(delayed(pickle_able_todo)(g,range=self.range, bins=self.bins, density=self.density) for g in tqdm(data, desc="Computing simplextrees and their diagrams", disable=not self.progress))
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
############# PERVEC
|
|
153
|
+
# Input list of [list of diagrams], outputs histogram of persistence values (x and y coord mixed)
|
|
154
|
+
def dgm2pervec(dgms, range:tuple[float, float], bins:int)->np.ndarray: #TODO : Anything to histogram
    """Histogram of all the coordinates (births and deaths mixed) of ``dgms``.

    ``dgms`` is a non-empty sequence of arrays; ``range`` and ``bins`` are
    forwarded to ``np.histogram``. Only the counts are returned.
    """
    flattened = [dgm.flatten() for dgm in dgms]
    all_coordinates = np.concatenate(flattened)
    counts, _bin_edges = np.histogram(all_coordinates, bins=bins, range=range)
    return counts
|
|
157
|
+
class Dgm2Histogram(BaseEstimator, TransformerMixin):
    """Turn each list of diagrams into one histogram of its coordinates."""

    def __init__(self, quantile:float=0., bins:int=100, n_jobs:int=1):
        super().__init__()
        # Histogram bounds [low, high], set by fit.
        self.range:np.ndarray | None=None
        self.quantile:float=quantile
        self.bins:int=bins
        self.n_jobs=n_jobs

    def fit(self, X, y=None): # X:list[diagrams]
        """Set ``range`` to the quantile interval of the finite coordinates of ``X``."""
        flattened = []
        for dgms in X:
            for dgm in dgms:
                flattened.append(dgm.flatten())
        coordinates = np.concatenate(flattened, axis=0)
        finite_coordinates = coordinates[coordinates < np.inf]
        self.range = np.quantile(finite_coordinates, [self.quantile, 1 - self.quantile])
        return self

    def transform(self,X):
        """Return one histogram (see ``dgm2pervec``) per element of ``X``."""
        workers = Parallel(n_jobs=self.n_jobs)
        return workers(delayed(dgm2pervec)(dgms, range=self.range, bins=self.bins) for dgms in X)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
################# SignedMeasureImage
|
|
179
|
+
class Dgms2SignedMeasureImage(BaseEstimator, TransformerMixin):
    """Turn lists of diagrams into 1D "signed measure" images.

    For each degree, a kernel density estimate is fitted on the births and
    another on the deaths; the image is the difference of their
    ``score_samples`` on a regular grid, and the images of all degrees are
    concatenated.

    Parameters:
        ranges: evaluation grids, one per degree; overwritten by ``fit``.
        resolution: number of grid points per degree.
        quantile: the grid of a degree spans the ``[quantile, 1 - quantile]``
            quantile interval of its finite coordinates.
        bandwidth, kernel: forwarded to ``sklearn.neighbors.KernelDensity``.
    """

    def __init__(self, ranges:None|Iterable[Iterable[float]]=None, resolution:int=100, quantile:float=0, bandwidth:float=1, kernel:str="gaussian") -> None:
        super().__init__()
        self.ranges=ranges
        self.resolution=resolution
        self.quantile = quantile
        self.bandwidth = bandwidth
        self.kernel = kernel

    def fit(self, X, y=None): # X:list[diagrams]
        """Compute one evaluation grid per degree from the diagrams of ``X``."""
        num_degrees = len(X[0])
        persistence_values = [np.concatenate([dgms[i].flatten() for dgms in X], axis=0) for i in range(num_degrees)] # values per degree
        persistence_values = [degrees_values[(-np.inf<degrees_values) * (degrees_values<np.inf)] for degrees_values in persistence_values] # non-trivial values
        quantiles = [np.quantile(degree_values, [self.quantile, 1-self.quantile]) for degree_values in persistence_values] # quantiles
        # Column-shaped grids: score_samples expects 2D input.
        self.ranges = np.array([np.linspace(start=[a], stop=[b], num=self.resolution) for a,b in quantiles])
        return self

    def _dgm2smi(self, dgms:Iterable[np.ndarray]):
        """Return the concatenated signed-measure image of one list of diagrams."""
        def log_density(points, grid):
            # Same bandwidth AND kernel for births and deaths; the original
            # dropped `kernel` for the deaths, which then always used the
            # KernelDensity default kernel.
            return KernelDensity(bandwidth=self.bandwidth, kernel=self.kernel).fit(points).score_samples(grid)

        return np.concatenate(
            [
                log_density(dgm[:,[0]], grid) - log_density(dgm[:,[1]], grid)
                for dgm, grid in zip(dgms, self.ranges)
            ],
            axis=0)

    def transform(self,X): # X is a list (data) of list of diagrams
        """Return one signed-measure image per element of ``X``."""
        assert self.ranges is not None
        out = Parallel(n_jobs=1, prefer="threads")(
            delayed(Dgms2SignedMeasureImage._dgm2smi)(self=self, dgms=dgms)
            for dgms in X
        )
        return out
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
################# SignedMeasureHistogram
|
|
217
|
+
class Dgms2SignedMeasureHistogram(BaseEstimator, TransformerMixin):
    """Turn lists of diagrams into signed histograms (births minus deaths)."""

    def __init__(self, ranges:None|list[tuple[float,float]]=None, bins:int=100, quantile:float=0) -> None:
        super().__init__()
        # One (low, high) histogram range per degree; overwritten by fit.
        self.ranges=ranges
        self.bins=bins
        self.quantile = quantile

    def fit(self, X, y=None): # X:list[diagrams]
        """Set, for each degree, the range to the quantile interval of its finite values."""
        num_degrees = len(X[0])
        per_degree_values = [
            np.concatenate([dgms[degree].flatten() for dgms in X], axis=0)
            for degree in range(num_degrees)
        ]
        ranges = []
        for values in per_degree_values:
            # Drop the infinite coordinates before taking quantiles.
            finite_values = values[(-np.inf < values) & (values < np.inf)]
            ranges.append(np.quantile(finite_values, [self.quantile, 1 - self.quantile]))
        self.ranges = ranges
        return self

    def transform(self,X): # X is a list (data) of list of diagrams
        """Return, per element of ``X``, the concatenation over degrees of
        (birth histogram - death histogram)."""
        assert self.ranges is not None
        out = []
        for dgms in X:
            signed_histograms = []
            for dgm, bounds in zip(dgms, self.ranges):
                birth_counts = np.histogram(dgm[:, 0], bins=self.bins, range=bounds)[0]
                death_counts = np.histogram(dgm[:, 1], bins=self.bins, range=bounds)[0]
                signed_histograms.append(birth_counts - death_counts)
            out.append(np.concatenate(signed_histograms))
        return out
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
################## Signed Measure Kernel 1D
|
|
247
|
+
# input : list of [list of diagrams], outputs: the kernel to feed to an svm
|
|
248
|
+
|
|
249
|
+
# TODO : optimize ?
|
|
250
|
+
## TODO : np.triu
|
|
251
|
+
class Dgms2SignedMeasureDistance(BaseEstimator, TransformerMixin):
    """Distance matrices between lists of diagrams seen as 1D signed measures.

    ``fit`` stores the reference diagrams; ``transform(X)`` returns an array
    indexed as ``[degree, i, j]`` holding the distance between ``X[i]`` and the
    ``j``-th fitted element.

    Parameters:
        n_jobs: forwarded to ``joblib.Parallel``.
        distance_matrix_path: when set, the matrix of each degree is loaded
            from / saved to the file ``f"{distance_matrix_path}_{degree}"``.
        progress: show a tqdm progress bar.
    """

    def __init__(self, n_jobs:int=1, distance_matrix_path:str|None=None, progress:bool = False) -> None:
        super().__init__()
        self.degrees:list[int]|None=None
        self.X:None|list[np.ndarray] = None
        self.n_jobs=n_jobs
        self.distance_matrix_path = distance_matrix_path
        self.progress=progress

    def fit(self, X:list[np.ndarray], y=None):
        """Store ``X`` as the reference set; always returns ``self``."""
        if len(X) <= 0:
            warn("Fit a nontrivial vector")
            # The original returned None here, breaking `fit(...).transform(...)`.
            return self
        self.X = X
        self.degrees = list(range(len(X[0]))) # Assumes that all x \in X have the same number of diagrams
        return self

    @staticmethod
    def wasserstein_1(a:np.ndarray,b:np.ndarray)->float:
        """Mean absolute difference between the sorted values of ``a`` and ``b``."""
        return np.abs(np.sort(a) - np.sort(b)).mean() # norm 1

    @staticmethod
    def OSWdistance(mu:list[np.ndarray], nu:list[np.ndarray], dim:int)->float:
        """Distance between ``mu[dim]`` and ``nu[dim]``.

        Compares (births of mu, deaths of nu) with (births of nu, deaths of mu),
        so that both compared vectors have the same length.
        """
        return Dgms2SignedMeasureDistance.wasserstein_1(np.hstack([mu[dim][:,0], nu[dim][:,1]]), np.hstack([nu[dim][:,0], mu[dim][:,1]])) # TODO : check: do we want to sum the kernels or the distances ? add weights ?

    @staticmethod
    def _ds(mu:list[np.ndarray], nus:list[list[np.ndarray]], dim:int): # mu and nu are lists of diagrams seen as signed measures (birth = +, death = -)
        """Distances between ``mu`` and every element of ``nus`` in degree ``dim``."""
        return [Dgms2SignedMeasureDistance.OSWdistance(mu,nu, dim) for nu in nus]

    def _compute_distance_matrix(self, X, degree:int, prefer:str|None)->np.ndarray:
        """Compute the (len(X), len(self.X)) distance matrix of one degree."""
        with tqdm(X, desc=f"Computing distance matrix of degree {degree}", disable=not self.progress) as diagrams_iterator:
            return np.array(Parallel(n_jobs=self.n_jobs, prefer=prefer)(delayed(self._ds)(mu, self.X, degree) for mu in diagrams_iterator))

    def transform(self,X): # X is a list (data) of list of diagrams
        """Return the stacked per-degree distance matrices between ``X`` and the fitted set.

        Warns and returns ``np.array([[]])`` when called before ``fit``.
        """
        if self.X is None or self.degrees is None:
            warn("Fit first !")
            return np.array([[]])
        # Cannot use sklearn / scipy pairwise helpers: the measures do not
        # have the same size, so they do not fit in a single numpy array.
        distances_matrices = []
        for degree in self.degrees:
            if self.distance_matrix_path is None:
                distances_matrices.append(self._compute_distance_matrix(X, degree, prefer="threads"))
                continue
            matrix_path = f"{self.distance_matrix_path}_{degree}"
            if exists(matrix_path):
                with open(matrix_path, "rb") as matrix_file:
                    distance_matrix = np.load(matrix_file)
            else:
                distance_matrix = self._compute_distance_matrix(X, degree, prefer=None)
                # Save through a file object: given a bare path, np.save would
                # append ".npy" and the `exists` check above would never hit.
                with open(matrix_path, "wb") as matrix_file:
                    np.save(matrix_file, distance_matrix)
            distances_matrices.append(distance_matrix)
        return np.asarray(distances_matrices)
|
|
302
|
+
# kernels = [np.exp(-distance_matrix / (2*self.sigma**2)) for distance_matrix in distances_matrices]
|
|
303
|
+
# return np.sum(kernels, axis=0)
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
## Wrapper for SW, in order to take as an input a list of (list of diagrams)
|
|
310
|
+
## Wrapper for SW, in order to take as an input a list of (list of diagrams)
class Dgms2SWK(BaseEstimator, TransformerMixin):
    """Sliced Wasserstein kernel between lists of diagrams.

    One ``SlicedWassersteinDistance`` is fitted per degree; ``transform``
    returns the sum over degrees of ``exp(-distance / (2 * bandwidth**2))``.

    Parameters:
        num_directions, n_jobs: forwarded to ``SlicedWassersteinDistance``.
        bandwidth: kernel bandwidth.
        distance_matrix_path: when set, the distance matrix of degree ``i`` is
            loaded from / saved to the file ``f"{distance_matrix_path}_{i}"``.
        progress: stored only; not read by this class.
    """

    def __init__(self, num_directions:int=10, bandwidth:float=1.0, n_jobs:int=1, distance_matrix_path:str|None = None, progress:bool = False) -> None:
        super().__init__()
        self.num_directions:int=num_directions
        self.bandwidth:float = bandwidth
        self.n_jobs=n_jobs
        # Fitted distance estimators, one per degree.
        self.SW_:list = []
        self.distance_matrix_path = distance_matrix_path
        self.progress = progress

    def fit(self, X:list[list[np.ndarray]], y=None):
        """Fit one sliced Wasserstein distance per degree on ``X``."""
        # Assumes that all x \in X have the same size
        self.SW_ = [
            SlicedWassersteinDistance(num_directions=self.num_directions, n_jobs = self.n_jobs) for _ in range(len(X[0]))
        ]
        for i, sw in enumerate(self.SW_):
            self.SW_[i]=sw.fit([dgms[i] for dgms in X]) # TODO : check : Not sure copy is necessary here
        return self

    def transform(self,X)->np.ndarray:
        """Return the kernel matrix between ``X`` and the fitted diagrams."""
        if self.distance_matrix_path is not None:
            distance_matrices = []
            for i, sw in enumerate(self.SW_):
                SW_i_path = f"{self.distance_matrix_path}_{i}"
                if exists(SW_i_path):
                    with open(SW_i_path, "rb") as matrix_file:
                        distance_matrix = np.load(matrix_file)
                else:
                    distance_matrix = sw.transform([dgms[i] for dgms in X])
                    # Save through a file object so that np.save does not
                    # append ".npy" to the cache file name.
                    with open(SW_i_path, "wb") as matrix_file:
                        np.save(matrix_file, distance_matrix)
                # The original forgot this append on a cache miss, so freshly
                # computed matrices never reached the kernel sum.
                distance_matrices.append(distance_matrix)
        else:
            distance_matrices = [sw.transform([dgms[i] for dgms in X]) for i, sw in enumerate(self.SW_)]
        kernels = [np.exp(-distance_matrix / (2*self.bandwidth**2)) for distance_matrix in distance_matrices]
        return np.sum(kernels, axis=0) # TODO fix this, we may want to sum the distances instead of the kernels.
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
class Dgms2SlicedWassersteinDistanceMatrices(BaseEstimator, TransformerMixin):
    """Per-degree sliced Wasserstein distance matrices between lists of diagrams."""

    def __init__(self, num_directions:int=10, n_jobs:int=1) -> None:
        super().__init__()
        self.num_directions:int=num_directions
        self.n_jobs=n_jobs
        # Fitted distance estimators, one per degree.
        self.SW_:list = []

    def fit(self, X:list[list[np.ndarray]], y=None):
        """Fit one ``SlicedWassersteinDistance`` per degree on ``X``."""
        # Assumes that all x \in X have the same size
        num_degrees = len(X[0])
        estimators = [
            SlicedWassersteinDistance(num_directions=self.num_directions, n_jobs=self.n_jobs)
            for _ in range(num_degrees)
        ]
        self.SW_ = estimators
        for degree in range(num_degrees):
            estimators[degree] = estimators[degree].fit([dgms[degree] for dgms in X])
        return self

    @staticmethod
    def _get_distance(diagrams, SWD):
        """Apply the fitted estimator ``SWD`` to ``diagrams``."""
        return SWD.transform(diagrams)

    def transform(self,X):
        """Return the stacked distance matrices, one per degree."""
        jobs = (
            delayed(self._get_distance)([dgms[degree] for dgms in X], swd)
            for degree, swd in enumerate(self.SW_)
        )
        return np.asarray(Parallel(n_jobs=self.n_jobs)(jobs))
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
# Gudhi simplexTree to list of diagrams
|
|
368
|
+
class SimplexTree2Dgm(BaseEstimator, TransformerMixin):
    """Compute persistence diagrams of gudhi simplextrees.

    The input ``X`` is either a sequence of ``gd.SimplexTree`` or a pair
    ``(to_st, data)`` where ``to_st(x)`` builds the simplextree of ``x``.

    Parameters:
        degrees: homological degrees to compute. When empty/None, defaults to
            ``[0]``, or to the degrees covered by ``extended``.
        extended: falsy for standard persistence; a list of ints selecting
            which extended-persistence diagrams to keep (there are 4 per
            dimension, e.g. ``[0, 2, 5, 7]`` is Ord0, Ext+0, Rel1, Ext-1); any
            other truthy value selects ``[0, 2, 5, 7]``.
        n_jobs: forwarded to ``joblib.Parallel`` (thread backend).
        progress: show a tqdm progress bar.
        threshold: finite values clip diagram coordinates to
            ``[-threshold, threshold]``; a value ``<= 0`` is replaced in
            ``fit`` by the largest absolute filtration value of ``X``.
    """

    def __init__(self, degrees:list[int]|None = None, extended:list[int]|bool=[], n_jobs=1, progress:bool=False, threshold:float=np.inf) -> None:
        super().__init__()
        # NOTE: the `[]` default is never mutated (a falsy value is replaced by False).
        if not extended:
            self.extended:list[int]|bool = False
        elif type(extended) is list:
            self.extended = extended
        else:
            self.extended = [0,2,5,7]
        if degrees:
            self.degrees:list[int] = degrees
        elif self.extended:
            self.degrees = list(range((max(self.extended) // 4)+1))
        else:
            self.degrees = [0]
        self.n_jobs=n_jobs
        self.progress = progress # progress bar
        self.threshold = threshold # Threshold value

    def fit(self, X:list[gd.SimplexTree], y=None):
        """Resolve a non-positive ``threshold`` from the data; returns ``self``."""
        if self.threshold <= 0:
            self.threshold = max( (abs(f) for simplextree in get_simplextrees(X) for s,f in simplextree.get_simplices()) ) ## MAX FILTRATION VALUE
            print(f"Setting threshold to {self.threshold}.")
        return self

    def transform(self,X:list[gd.SimplexTree]):
        """Return, for each simplextree, its list of diagrams.

        Standard persistence yields one ``(n, 2)`` array per degree; extended
        persistence yields a single array merging the selected diagrams.
        """
        def reshape(dgm:np.ndarray|list)->np.ndarray:
            # Empty diagrams still get a (0, 2) shape.
            out = np.array(dgm) if len(dgm) > 0 else np.empty((0,2))
            if self.threshold != np.inf:
                out[out>self.threshold] = self.threshold
                out[out<-self.threshold] = -self.threshold
            return out

        def todo_standard(st):
            st.compute_persistence()
            return [reshape(st.persistence_intervals_in_dimension(d)) for d in self.degrees]

        def todo_extended(st):
            st.extend_filtration()
            dgms = st.extended_persistence()
            return [reshape([bar for j,dgm in enumerate(dgms) for d, bar in dgm if d in self.degrees and j+4*d in self.extended])]

        todo = todo_extended if self.extended else todo_standard

        if len(X) == 0:
            # Nothing to compute; the original raised IndexError on X[0].
            return []
        if isinstance(X[0],gd.SimplexTree):
            return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(todo)(x) for x in tqdm(X, disable=not self.progress, desc="Computing diagrams"))

        # Otherwise X is a pair (to_st, dataset): build each simplextree
        # inside the worker. (The original had an unreachable "Bad input."
        # warning after this branch; it has been removed.)
        to_st = X[0]
        dataset = X[1]
        def pickleable_todo(x):
            return todo(to_st(x))
        return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(pickleable_todo)(x) for x in tqdm(dataset, disable=not self.progress, desc="Computing simplextrees and diagrams"))
|
|
412
|
+
|
|
413
|
+
# Shuffles a diagram shaped array. Input : list of (list of diagrams), output, list of (list of shuffled diagrams)
|
|
414
|
+
class DiagramShuffle(BaseEstimator, TransformerMixin):
    """Randomly permute the coordinates of every diagram (shape preserved)."""

    def __init__(self, ) -> None:
        super().__init__()
        return

    def fit(self, X:list[list[np.ndarray]], y=None):
        """Nothing to learn; returns ``self``."""
        return self

    def transform(self,X:list[list[np.ndarray]]):
        """Return a list (data) of lists of shuffled copies of the diagrams.

        The inputs are left untouched: the shuffle runs on a flattened copy,
        using the global ``np.random`` state.
        """
        shuffled_data = []
        for dgms in X:
            shuffled_dgms = []
            for dgm in dgms:
                values = dgm.flatten()  # flatten() copies
                np.random.shuffle(values)
                shuffled_dgms.append(values.reshape(dgm.shape))
            shuffled_data.append(shuffled_dgms)
        return shuffled_data
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
class Dgms2Landscapes(BaseEstimator, TransformerMixin):
    """Concatenate, over degrees, the gudhi ``Landscape`` vectorization of diagrams."""

    def __init__(self, num:int=5, resolution:int=100, n_jobs:int=1) -> None:
        super().__init__()
        self.degrees:list[int] = []
        self.num:int= num
        self.resolution:int = resolution
        # Fitted Landscape transformers, one per degree.
        self.landscapes:list[Landscape]= []
        self.n_jobs=n_jobs
        return

    def fit(self, X, y=None):
        """Fit one ``Landscape`` per degree; a no-op on empty ``X``."""
        if len(X) == 0:
            return self
        self.degrees = list(range(len(X[0])))
        self.landscapes = []
        for degree in self.degrees:
            degree_diagrams = [dgms[degree] for dgms in X]
            fitted = Landscape(num_landscapes=self.num, resolution=self.resolution).fit(degree_diagrams)
            self.landscapes.append(fitted)
        return self

    def transform(self,X):
        """Return the landscapes of all degrees side by side (``[]`` on empty ``X``)."""
        if len(X) == 0:
            return []
        per_degree = []
        for degree, landscape in enumerate(self.landscapes):
            per_degree.append(landscape.transform([dgms[degree] for dgms in X]))
        return np.concatenate(per_degree, axis=1)
|
|
451
|
+
|
|
452
|
+
class Dgms2Image(BaseEstimator, TransformerMixin):
    """Concatenate, over degrees, the gudhi ``PersistenceImage`` vectorization of diagrams."""

    def __init__(self, bandwidth:float=1, resolution:tuple[int,int]=(20,20), n_jobs:int=1) -> None:
        super().__init__()
        self.degrees:list[int] = []
        self.bandwidth:float= bandwidth
        self.resolution = resolution
        # Fitted PersistenceImage transformers, one per degree.
        self.PI:list[PersistenceImage]= []
        self.n_jobs=n_jobs
        return

    def fit(self, X, y=None):
        """Fit one ``PersistenceImage`` per degree; a no-op on empty ``X``."""
        if len(X) == 0:
            return self
        self.degrees = list(range(len(X[0])))
        self.PI = []
        for degree in self.degrees:
            degree_diagrams = [dgms[degree] for dgms in X]
            fitted = PersistenceImage(bandwidth=self.bandwidth, resolution=self.resolution).fit(degree_diagrams)
            self.PI.append(fitted)
        return self

    def transform(self,X):
        """Return the images of all degrees side by side (``[]`` on empty ``X``)."""
        if len(X) == 0:
            return []
        per_degree = []
        for degree, pers_image in enumerate(self.PI):
            per_degree.append(pers_image.transform([dgms[degree] for dgms in X]))
        return np.concatenate(per_degree, axis=1)
|
|
471
|
+
|
|
472
|
+
|