multipers 1.1.3__cp310-cp310-macosx_11_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of multipers might be problematic. Click here for more details.
- multipers/.dylibs/libtbb.12.12.dylib +0 -0
- multipers/.dylibs/libtbbmalloc.2.12.dylib +0 -0
- multipers/__init__.py +5 -0
- multipers/_old_rank_invariant.pyx +328 -0
- multipers/_signed_measure_meta.py +193 -0
- multipers/data/MOL2.py +350 -0
- multipers/data/UCR.py +18 -0
- multipers/data/__init__.py +1 -0
- multipers/data/graphs.py +466 -0
- multipers/data/immuno_regions.py +27 -0
- multipers/data/minimal_presentation_to_st_bf.py +0 -0
- multipers/data/pytorch2simplextree.py +91 -0
- multipers/data/shape3d.py +101 -0
- multipers/data/synthetic.py +68 -0
- multipers/distances.py +172 -0
- multipers/euler_characteristic.cpython-310-darwin.so +0 -0
- multipers/euler_characteristic.pyx +137 -0
- multipers/function_rips.cpython-310-darwin.so +0 -0
- multipers/function_rips.pyx +102 -0
- multipers/hilbert_function.cpython-310-darwin.so +0 -0
- multipers/hilbert_function.pyi +46 -0
- multipers/hilbert_function.pyx +151 -0
- multipers/io.cpython-310-darwin.so +0 -0
- multipers/io.pyx +176 -0
- multipers/ml/__init__.py +0 -0
- multipers/ml/accuracies.py +61 -0
- multipers/ml/convolutions.py +510 -0
- multipers/ml/invariants_with_persistable.py +79 -0
- multipers/ml/kernels.py +128 -0
- multipers/ml/mma.py +657 -0
- multipers/ml/one.py +472 -0
- multipers/ml/point_clouds.py +191 -0
- multipers/ml/signed_betti.py +50 -0
- multipers/ml/signed_measures.py +1479 -0
- multipers/ml/sliced_wasserstein.py +313 -0
- multipers/ml/tools.py +116 -0
- multipers/mma_structures.cpython-310-darwin.so +0 -0
- multipers/mma_structures.pxd +155 -0
- multipers/mma_structures.pyx +651 -0
- multipers/multiparameter_edge_collapse.py +29 -0
- multipers/multiparameter_module_approximation.cpython-310-darwin.so +0 -0
- multipers/multiparameter_module_approximation.pyi +439 -0
- multipers/multiparameter_module_approximation.pyx +311 -0
- multipers/pickle.py +53 -0
- multipers/plots.py +292 -0
- multipers/point_measure_integration.cpython-310-darwin.so +0 -0
- multipers/point_measure_integration.pyx +59 -0
- multipers/rank_invariant.cpython-310-darwin.so +0 -0
- multipers/rank_invariant.pyx +154 -0
- multipers/simplex_tree_multi.cpython-310-darwin.so +0 -0
- multipers/simplex_tree_multi.pxd +121 -0
- multipers/simplex_tree_multi.pyi +715 -0
- multipers/simplex_tree_multi.pyx +1417 -0
- multipers/slicer.cpython-310-darwin.so +0 -0
- multipers/slicer.pxd +94 -0
- multipers/slicer.pyx +276 -0
- multipers/tensor.pxd +13 -0
- multipers/test.pyx +44 -0
- multipers-1.1.3.dist-info/LICENSE +21 -0
- multipers-1.1.3.dist-info/METADATA +22 -0
- multipers-1.1.3.dist-info/RECORD +63 -0
- multipers-1.1.3.dist-info/WHEEL +5 -0
- multipers-1.1.3.dist-info/top_level.txt +1 -0
multipers/ml/one.py
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
1
|
+
from sklearn.base import BaseEstimator, TransformerMixin
|
|
2
|
+
import gudhi as gd
|
|
3
|
+
from os.path import exists
|
|
4
|
+
import networkx as nx
|
|
5
|
+
from joblib import Parallel, delayed
|
|
6
|
+
import numpy as np
|
|
7
|
+
from tqdm import tqdm
|
|
8
|
+
from warnings import warn
|
|
9
|
+
from sklearn.neighbors import KernelDensity
|
|
10
|
+
from typing import Iterable
|
|
11
|
+
from gudhi.representations import Landscape
|
|
12
|
+
from gudhi.representations.vector_methods import PersistenceImage
|
|
13
|
+
from gudhi.representations.kernel_methods import SlicedWassersteinDistance
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
from types import FunctionType
|
|
17
|
+
def get_simplextree(x)->gd.SimplexTree:
    """Resolve ``x`` into a ``gudhi.SimplexTree``.

    Accepted inputs:
      - a ``gd.SimplexTree``: returned as is;
      - a plain Python function: called with no argument;
      - a ``(f, args, kwargs)`` triple whose first item is a function: returns ``f(*args, **kwargs)``.

    Raises
    ------
    TypeError
        For any other input (including objects without a length).
    """
    if isinstance(x, gd.SimplexTree):
        return x
    if isinstance(x, FunctionType):
        return x()
    # Objects without ``len``/indexing cannot be a (f, args, kwargs) triple; report them with the
    # same error as every other invalid input instead of a confusing "has no len()" message.
    try:
        is_triple = len(x) == 3 and isinstance(x[0], FunctionType)
    except (TypeError, KeyError):
        is_triple = False
    if is_triple:
        f, args, kwargs = x
        return f(*args, **kwargs)
    raise TypeError("Not a valid SimplexTree")
|
|
26
|
+
def get_simplextrees(X)->Iterable[gd.SimplexTree]:
    """Resolve ``X`` into an iterable of ``gudhi.SimplexTree``.

    ``X`` is either a ``(f, data)`` pair with ``f`` a function, in which case a lazy generator
    ``(f(x) for x in data)`` is returned, or a sequence of ``gd.SimplexTree`` returned as is
    (only its first element is type-checked).

    Raises
    ------
    TypeError
        If ``X`` is a non-empty sequence whose first element is not a ``gd.SimplexTree``.
    """
    if len(X) == 2 and isinstance(X[0], FunctionType):
        f, data = X
        return (f(x) for x in data)
    if len(X) == 0:
        return []
    if not isinstance(X[0], gd.SimplexTree):
        raise TypeError(
            "Expected a sequence of gudhi SimplexTree or a (function, data) pair, "
            f"got a sequence starting with {type(X[0]).__name__}"
        )
    return X
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
############## INTERVALS (for sliced wasserstein)
|
|
39
|
+
class Graph2SimplexTree(BaseEstimator,TransformerMixin):
    """Turn networkx graphs into gudhi SimplexTrees filtered by a node/edge attribute.

    Parameters
    ----------
    f : name of the attribute read on every node and every edge as filtration value.
    dtype : if ``None``, ``transform`` builds nothing and returns ``[builder, X]`` so a later
        pipeline step can build the trees itself; otherwise the trees are built right away.
    reverse_filtration : stored only; not used yet (TODO).
    """
    def __init__(self, f:str="ricciCurvature",dtype=gd.SimplexTree, reverse_filtration:bool=False):
        super().__init__()
        self.f = f
        self.dtype = dtype
        self.reverse_filtration = reverse_filtration  # TODO: not used yet

    def fit(self, X, y=None):
        """Nothing to learn; returns ``self``."""
        return self

    def transform(self,X:list[nx.Graph]):
        """Build one SimplexTree per graph (or defer the construction when ``dtype is None``)."""
        def graph_to_simplextree(graph, f=self.f) -> gd.SimplexTree:
            # Vertices first, then edges, each with the value of attribute ``f``.
            simplextree = gd.SimplexTree()
            for vertex in graph.nodes:
                simplextree.insert([vertex], graph.nodes[vertex][f])
            for a, b in graph.edges:
                simplextree.insert([a, b], graph[a][b][f])
            return simplextree

        if self.dtype is None:
            return [graph_to_simplextree, X]
        jobs = (delayed(graph_to_simplextree)(graph) for graph in X)
        return Parallel(n_jobs=-1, prefer="threads")(jobs)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class PointCloud2SimplexTree(BaseEstimator,TransformerMixin):
    """Build a gudhi alpha-complex SimplexTree from each point cloud.

    Parameters
    ----------
    delayed : if true, ``transform`` returns ``[builder, X]`` instead of the trees, so that a
        later step can build them lazily. ``None`` also delays (legacy behaviour).
    threshold : the complex is truncated at ``max_alpha_square = threshold**2``. If negative,
        ``fit`` replaces it by the largest point-cloud diameter found in ``X``.
    """
    def __init__(self, delayed:bool = False, threshold = np.inf):
        super().__init__()
        self.delayed = delayed
        self.threshold=threshold

    @staticmethod
    def _get_point_cloud_diameter(x):
        """Largest pairwise (Euclidean) distance between the points of ``x``."""
        from scipy.spatial import distance_matrix
        return np.max(distance_matrix(x,x))

    def fit(self, X, y=None):
        if self.threshold < 0:
            self.threshold = max(self._get_point_cloud_diameter(x) for x in X)
        return self

    def transform(self,X:Iterable[np.ndarray]):
        """Alpha complexes of the point clouds of ``X`` (or a deferred ``[builder, X]`` pair)."""
        def todo(point_cloud) -> gd.SimplexTree:
            return gd.AlphaComplex(points=point_cloud).create_simplex_tree(max_alpha_square = self.threshold**2)
        # The previous test ``self.delayed is None`` was never true for the boolean parameter,
        # so ``delayed=True`` silently computed everything eagerly.
        if self.delayed or self.delayed is None:
            return [todo, X]
        return Parallel(n_jobs=-1, prefer="threads")(delayed(todo)(point_cloud) for point_cloud in X)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
#################### FILVEC
|
|
78
|
+
def get_filtration_values(g:nx.Graph, f:str)->np.ndarray:
    """Collect attribute ``f`` of every node, then of every edge, of ``g``.

    Returns a numpy array with the node values first (in ``g.nodes`` order) followed by the
    edge values (in ``g.edges`` order).
    """
    node_values = [g.nodes[vertex][f] for vertex in g.nodes]
    edge_values = [g[a][b][f] for a, b in g.edges]
    return np.array(node_values + edge_values)
|
|
85
|
+
def graph2filvec(g:nx.Graph, f:str, range:tuple, bins:int)->np.ndarray:
    """Histogram counts of the node and edge values of attribute ``f`` of ``g``.

    ``range`` and ``bins`` are forwarded to ``np.histogram``.
    """
    values = get_filtration_values(g, f)
    counts, _bin_edges = np.histogram(values, bins=bins, range=range)
    return counts
|
|
88
|
+
class FilvecGetter(BaseEstimator, TransformerMixin):
    """Histogram ("filvec") of the node and edge filtration values of networkx graphs.

    Parameters
    ----------
    f : name of the node/edge attribute holding the filtration value.
    quantile : ``fit`` sets the histogram range to the ``[quantile, 1-quantile]`` quantiles of
        all filtration values of the fitted graphs.
    bins : number of histogram bins.
    n_jobs : number of joblib workers.
    """
    def __init__(self, f:str="ricciCurvature",quantile:float=0., bins:int=100, n_jobs:int=1):
        super().__init__()
        self.f=f
        self.quantile=quantile
        self.bins=bins
        self.range:tuple[float]|None=None
        self.n_jobs=n_jobs

    def fit(self, X, y=None):
        per_graph = Parallel(n_jobs=self.n_jobs)(delayed(get_filtration_values)(g,f=self.f) for g in X)
        if len(per_graph) == 0:
            # np.concatenate would raise on an empty list.
            warn("Cannot fit on an empty dataset; range left unset.")
            return self
        filtration_values = np.concatenate(per_graph)
        self.range = tuple(np.quantile(filtration_values, [self.quantile, 1-self.quantile]))
        return self

    def transform(self,X):
        """One histogram (``np.ndarray`` of counts) per graph; ``None`` (with a warning) if not fitted."""
        if self.range is None:
            warn("Fit first")
            return None
        return Parallel(n_jobs=self.n_jobs)(delayed(graph2filvec)(g,f=self.f, range=self.range, bins=self.bins) for g in X)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
############# Filvec from SimplexTree
|
|
110
|
+
# Input: one simplex tree; output: histogram of the filtration values of all its simplices
|
|
111
|
+
def simplextree2hist(simplextree, range:tuple[float, float], bins:int, density:bool)->np.ndarray: #TODO : Anything to histogram
    """Histogram of the filtration values of every simplex of ``simplextree``.

    ``simplextree`` only needs a ``get_simplices()`` method yielding ``(simplex, value)`` pairs.
    ``range``, ``bins`` and ``density`` are forwarded to ``np.histogram``; only the histogram
    values (not the bin edges) are returned.
    """
    values = np.array([value for _simplex, value in simplextree.get_simplices()])
    histogram, _edges = np.histogram(values, bins=bins, range=range, density=density)
    return histogram
|
|
114
|
+
class SimplexTree2Histogram(BaseEstimator, TransformerMixin):
    """Histogram of the filtration values of gudhi SimplexTrees.

    Input ``X`` is either a list of ``gd.SimplexTree`` or a ``(to_st, data)`` pair, in which case
    each tree is built as ``to_st(x)`` for ``x`` in ``data``.

    Parameters
    ----------
    quantile : ``fit`` sets the histogram range to the ``[quantile, 1-quantile]`` quantiles of
        the (non ``+inf``) filtration values of the fitted trees.
    bins, density : forwarded to ``np.histogram``.
    n_jobs : joblib workers, only used for the ``(to_st, data)`` input.
    progress : show a tqdm progress bar.
    """
    def __init__(self, quantile:float=0., bins:int=100, n_jobs:int=1, progress:bool=False, density:bool=True):
        super().__init__()
        self.range:np.ndarray | None=None
        self.quantile:float=quantile
        self.bins:int=bins
        self.n_jobs=n_jobs
        self.density=density
        self.progress = progress
        # self.max_dimension=None # TODO: maybe use it

    def fit(self, X, y=None):
        if len(X) == 0:
            return self
        if isinstance(X[0], gd.SimplexTree):  # X already contains simplextrees
            data = X
            to_st = lambda x : x
        else:  # (to_st, data) pair: simplextrees are to_st(x) for x in data
            to_st, data = X
        persistence_values = np.array([f for st in data for s,f in to_st(st).get_simplices()])
        persistence_values = persistence_values[persistence_values<np.inf]
        self.range = np.quantile(persistence_values, [self.quantile, 1-self.quantile])
        return self

    def transform(self,X):
        """One histogram per simplextree (an empty list for an empty input)."""
        if len(X) == 0:
            return []  # previously returned ``self``, which is not a valid transform output
        if isinstance(X[0], gd.SimplexTree):
            if self.n_jobs > 1:
                warn("Cannot pickle simplextrees, reducing to 1 thread to compute the simplextrees")
            return [simplextree2hist(g,range=self.range, bins=self.bins, density=self.density) for g in tqdm(X, desc="Computing diagrams", disable=not self.progress)]
        to_st, data = X
        def pickle_able_todo(x, **kwargs):
            # Build the tree inside the worker, since simplextrees themselves are not shipped.
            simplextree = to_st(x)
            return simplextree2hist(simplextree=simplextree, **kwargs)
        return Parallel(n_jobs=self.n_jobs)(delayed(pickle_able_todo)(g,range=self.range, bins=self.bins, density=self.density) for g in tqdm(data, desc="Computing simplextrees and their diagrams", disable=not self.progress))
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
############# PERVEC
|
|
153
|
+
# Input: the list of diagrams of one sample; output: histogram of all their persistence values (birth and death coordinates mixed)
|
|
154
|
+
def dgm2pervec(dgms, range:tuple[float, float], bins:int)->np.ndarray: #TODO : Anything to histogram
    """Histogram of every coordinate (births and deaths mixed) of the diagrams in ``dgms``.

    Parameters
    ----------
    dgms : iterable of array-likes (diagrams); all their entries are pooled together.
    range, bins : forwarded to ``np.histogram``; values outside ``range`` (e.g. ``inf``) are not counted.

    Returns
    -------
    The histogram counts. An empty ``dgms`` yields ``bins`` zeros instead of raising.
    """
    flattened = [np.asarray(dgm).ravel() for dgm in dgms]
    # np.concatenate raises on an empty list, so fall back to an empty array.
    values = np.concatenate(flattened) if flattened else np.empty(0)
    return np.histogram(values, bins=bins, range=range)[0]
|
|
157
|
+
class Dgm2Histogram(BaseEstimator, TransformerMixin):
    """Histogram of all persistence values of each sample's list of diagrams.

    ``fit`` sets the common histogram range to the ``[quantile, 1-quantile]`` quantiles of every
    value below ``+inf`` seen in the fitted diagrams; ``transform`` applies ``dgm2pervec``.
    """
    def __init__(self, quantile:float=0., bins:int=100, n_jobs:int=1):
        super().__init__()
        self.n_jobs = n_jobs
        self.bins:int = bins
        self.quantile:float = quantile
        self.range:np.ndarray | None = None

    def fit(self, X, y=None):
        chunks = [dgm.flatten() for sample in X for dgm in sample]
        values = np.concatenate(chunks, axis=0)
        below_inf = values < np.inf
        values = values[below_inf]
        self.range = np.quantile(values, [self.quantile, 1 - self.quantile])
        return self

    def transform(self,X):
        jobs = (delayed(dgm2pervec)(sample, range=self.range, bins=self.bins) for sample in X)
        return Parallel(n_jobs=self.n_jobs)(jobs)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
################# SignedMeasureImage
|
|
179
|
+
class Dgms2SignedMeasureImage(BaseEstimator, TransformerMixin):
    """Kernel-density vector of diagrams seen as signed measures (birth = +, death = -).

    For each degree, births and deaths are each fitted with a sklearn ``KernelDensity`` and their
    ``score_samples`` (log-densities) are evaluated on a 1D grid and subtracted; the per-degree
    vectors are concatenated.

    Parameters
    ----------
    ranges : per-degree evaluation grids; recomputed (overwritten) by ``fit``.
    resolution : number of grid points per degree.
    quantile : each grid spans the ``[quantile, 1-quantile]`` quantiles of the finite fitted values.
    bandwidth, kernel : forwarded to ``KernelDensity`` for both births and deaths.
    """
    def __init__(self, ranges:None|Iterable[Iterable[float]]=None, resolution:int=100, quantile:float=0, bandwidth:float=1, kernel:str="gaussian") -> None:
        super().__init__()
        self.ranges=ranges
        self.resolution=resolution
        self.quantile = quantile
        self.bandwidth = bandwidth
        self.kernel = kernel

    def fit(self, X, y=None): # X:list[diagrams]
        num_degrees = len(X[0])
        persistence_values = [np.concatenate([dgms[i].flatten() for dgms in X], axis=0) for i in range(num_degrees)] # values per degree
        persistence_values = [degree_values[np.isfinite(degree_values)] for degree_values in persistence_values] # finite values only
        quantiles = [np.quantile(degree_values, [self.quantile, 1-self.quantile]) for degree_values in persistence_values]
        # Shape (num_degrees, resolution, 1): score_samples expects 2D (n_samples, n_features) inputs.
        self.ranges = np.array([np.linspace(start=[a], stop=[b], num=self.resolution) for a,b in quantiles])
        return self

    def _dgm2smi(self, dgms:Iterable[np.ndarray]):
        """Concatenated per-degree difference of birth and death log-densities on the fitted grids."""
        def kde():
            # Same estimator settings for births and deaths; the death side previously ignored ``kernel``.
            return KernelDensity(bandwidth=self.bandwidth, kernel=self.kernel)
        smi = np.concatenate(
            [
                kde().fit(dgm[:,[0]]).score_samples(grid)
                - kde().fit(dgm[:,[1]]).score_samples(grid)
                for dgm, grid in zip(dgms, self.ranges)
            ],
            axis=0)
        return smi

    def transform(self,X): # X is a list (data) of list of diagrams
        assert self.ranges is not None
        return [self._dgm2smi(dgms) for dgms in X]
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
################# SignedMeasureHistogram
|
|
217
|
+
class Dgms2SignedMeasureHistogram(BaseEstimator, TransformerMixin):
    """Signed histogram of diagrams: births counted positively, deaths negatively.

    ``fit`` sets, per degree, the range ``[quantile, 1-quantile]`` of the finite fitted values;
    ``transform`` concatenates the per-degree ``hist(births) - hist(deaths)`` of each sample.
    """
    def __init__(self, ranges:None|list[tuple[float,float]]=None, bins:int=100, quantile:float=0) -> None:
        super().__init__()
        self.quantile = quantile
        self.bins = bins
        self.ranges = ranges

    def fit(self, X, y=None): # X:list[diagrams]
        per_degree = [
            np.concatenate([sample[degree].flatten() for sample in X], axis=0)
            for degree in range(len(X[0]))
        ]
        finite = [values[np.isfinite(values)] for values in per_degree]
        self.ranges = [np.quantile(values, [self.quantile, 1 - self.quantile]) for values in finite]
        return self

    def transform(self,X): # X is a list (data) of list of diagrams
        assert self.ranges is not None

        def signed_histogram(dgms):
            parts = []
            for dgm, bounds in zip(dgms, self.ranges):
                births, _ = np.histogram(dgm[:, 0], bins=self.bins, range=bounds)
                deaths, _ = np.histogram(dgm[:, 1], bins=self.bins, range=bounds)
                parts.append(births - deaths)
            return np.concatenate(parts)

        return [signed_histogram(dgms) for dgms in X]
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
################## Signed Measure Kernel 1D
|
|
247
|
+
# Input: list of [list of diagrams]; output: one distance matrix per diagram degree (turning them into an SVM kernel is left to the caller)
|
|
248
|
+
|
|
249
|
+
# TODO : optimize ?
|
|
250
|
+
## TODO : np.triu
|
|
251
|
+
class Dgms2SignedMeasureDistance(BaseEstimator, TransformerMixin):
    """Distances between diagrams seen as 1D signed measures (birth = +, death = -).

    ``fit`` stores the reference samples; ``transform`` returns an array with, for each degree,
    the matrix of distances between each input sample (rows) and each reference sample (columns).

    Parameters
    ----------
    n_jobs : number of joblib workers.
    distance_matrix_path : if not None, the matrix of degree ``d`` is cached with ``np.save`` in
        ``f"{distance_matrix_path}_{d}"`` and reloaded from that file when it exists.
    progress : show a tqdm progress bar.
    """
    def __init__(self, n_jobs:int=1, distance_matrix_path:str|None=None, progress:bool = False) -> None:
        super().__init__()
        self.degrees:list[int]|None=None
        self.X:None|list[np.ndarray] = None
        self.n_jobs=n_jobs
        self.distance_matrix_path = distance_matrix_path
        self.progress=progress

    def fit(self, X:list[np.ndarray], y=None):
        if len(X) <= 0:
            warn("Fit a nontrivial vector")
            return self  # was a bare ``return``, which broke ``fit(...).transform(...)`` chaining
        self.X = X
        self.degrees = list(range(len(X[0]))) # Assumes that all x in X have the same number of diagrams
        return self

    @staticmethod
    def wasserstein_1(a:np.ndarray,b:np.ndarray)->float:
        """Mean absolute difference between the sorted values of ``a`` and ``b`` (same length)."""
        return np.abs(np.sort(a) - np.sort(b)).mean() # norm 1

    @staticmethod
    def OSWdistance(mu:list[np.ndarray], nu:list[np.ndarray], dim:int)->float:
        """Distance between the degree-``dim`` diagrams of ``mu`` and ``nu``.

        Compares births(mu) + deaths(nu) with births(nu) + deaths(mu): both sides contain
        ``len(mu[dim]) + len(nu[dim])`` values, as ``wasserstein_1`` requires.
        """
        return Dgms2SignedMeasureDistance.wasserstein_1(np.hstack([mu[dim][:,0], nu[dim][:,1]]), np.hstack([nu[dim][:,0], mu[dim][:,1]])) # TODO : check: do we want to sum the kernels or the distances ? add weights ?

    @staticmethod
    def _ds(mu:list[np.ndarray], nus:list[list[np.ndarray]], dim:int): # mu and nu are lists of diagrams seen as signed measures (birth = +, death = -)
        return [Dgms2SignedMeasureDistance.OSWdistance(mu,nu, dim) for nu in nus]

    def _compute_distance_matrix(self, X, degree:int, prefer:str|None)->np.ndarray:
        """Distance matrix of one degree between the samples of ``X`` and the fitted samples."""
        with tqdm(X, desc=f"Computing distance matrix of degree {degree}", disable=not self.progress) as diagrams_iterator:
            return np.array(Parallel(n_jobs=self.n_jobs, prefer=prefer)(delayed(self._ds)(mu, self.X, degree) for mu in diagrams_iterator))

    def transform(self,X): # X is a list (data) of list of diagrams
        if self.X is None or self.degrees is None:
            warn("Fit first !")
            return np.array([[]])
        # sklearn / scipy pairwise helpers cannot be used: measures have different sizes.
        distances_matrices = []
        for degree in self.degrees:
            if self.distance_matrix_path is None:
                distances_matrices.append(self._compute_distance_matrix(X, degree, prefer="threads"))
                continue
            matrix_path = f"{self.distance_matrix_path}_{degree}"
            if exists(matrix_path):
                # Context managers close the files (the previous bare open() calls leaked them).
                with open(matrix_path, "rb") as file:
                    distance_matrix = np.load(file)
            else:
                distance_matrix = self._compute_distance_matrix(X, degree, prefer=None)
                with open(matrix_path, "wb") as file:
                    np.save(file, distance_matrix)
            distances_matrices.append(distance_matrix)
        return np.asarray(distances_matrices)
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
## Wrapper for SW, in order to take as an input a list of (list of diagrams)
|
|
310
|
+
class Dgms2SWK(BaseEstimator, TransformerMixin):
    """Sum over degrees of ``exp(-SW / (2*bandwidth**2))``, with SW the gudhi sliced Wasserstein distance.

    Parameters
    ----------
    num_directions, n_jobs : forwarded to each per-degree ``SlicedWassersteinDistance``.
    bandwidth : kernel bandwidth.
    distance_matrix_path : if not None, the distance matrix of degree ``i`` is cached with
        ``np.save`` in ``f"{distance_matrix_path}_{i}"`` and reloaded from there when it exists.
    progress : stored; not used by this class.
    """
    def __init__(self, num_directions:int=10, bandwidth:float=1.0, n_jobs:int=1, distance_matrix_path:str|None = None, progress:bool = False) -> None:
        super().__init__()
        self.num_directions:int=num_directions
        self.bandwidth:float = bandwidth
        self.n_jobs=n_jobs
        self.SW_:list = []
        self.distance_matrix_path = distance_matrix_path
        self.progress = progress

    def fit(self, X:list[list[np.ndarray]], y=None):
        # Assumes that all x in X have the same number of diagrams
        self.SW_ = [
            SlicedWassersteinDistance(num_directions=self.num_directions, n_jobs = self.n_jobs) for _ in range(len(X[0]))
        ]
        for i, sw in enumerate(self.SW_):
            self.SW_[i]=sw.fit([dgms[i] for dgms in X])
        return self

    def _distance_matrix(self, X, i:int)->np.ndarray:
        """SW distance matrix of degree ``i`` between ``X`` and the fitted diagrams."""
        return self.SW_[i].transform([dgms[i] for dgms in X])

    def transform(self,X)->np.ndarray:
        distance_matrices = []
        for i in range(len(self.SW_)):
            if self.distance_matrix_path is None:
                distance_matrices.append(self._distance_matrix(X, i))
                continue
            SW_i_path = f"{self.distance_matrix_path}_{i}"
            if exists(SW_i_path):
                with open(SW_i_path, "rb") as file:
                    distance_matrix = np.load(file)
            else:
                distance_matrix = self._distance_matrix(X, i)
                with open(SW_i_path, "wb") as file:
                    np.save(file, distance_matrix)
            # Freshly computed matrices used to be saved but never appended, silently
            # dropping those degrees from the kernel on the first (uncached) run.
            distance_matrices.append(distance_matrix)
        kernels = [np.exp(-distance_matrix / (2*self.bandwidth**2)) for distance_matrix in distance_matrices]
        return np.sum(kernels, axis=0) # TODO fix this, we may want to sum the distances instead of the kernels.
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
class Dgms2SlicedWassersteinDistanceMatrices(BaseEstimator, TransformerMixin):
    """One gudhi ``SlicedWassersteinDistance`` per degree; ``transform`` stacks the per-degree distance matrices."""
    def __init__(self, num_directions:int=10, n_jobs:int=1) -> None:
        super().__init__()
        self.SW_:list = []
        self.n_jobs = n_jobs
        self.num_directions:int = num_directions

    def fit(self, X:list[list[np.ndarray]], y=None):
        n_degrees = len(X[0])  # assumes every sample has the same number of diagrams
        self.SW_ = [
            SlicedWassersteinDistance(num_directions=self.num_directions, n_jobs=self.n_jobs)
            for _ in range(n_degrees)
        ]
        for degree in range(n_degrees):
            self.SW_[degree] = self.SW_[degree].fit([sample[degree] for sample in X])
        return self

    @staticmethod
    def _get_distance(diagrams, SWD):
        """Distance matrix of ``diagrams`` against the diagrams ``SWD`` was fitted on."""
        return SWD.transform(diagrams)

    def transform(self,X):
        tasks = [
            delayed(self._get_distance)([sample[degree] for sample in X], swd)
            for degree, swd in enumerate(self.SW_)
        ]
        return np.asarray(Parallel(n_jobs=self.n_jobs)(tasks))
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
# Gudhi simplexTree to list of diagrams
|
|
368
|
+
class SimplexTree2Dgm(BaseEstimator, TransformerMixin):
    """Compute persistence diagrams of gudhi SimplexTrees.

    Input ``X`` is a list of ``gd.SimplexTree`` or a ``(to_st, data)`` pair whose trees are
    ``to_st(x)`` for ``x`` in ``data``. The trees are used in place (``compute_persistence`` /
    ``extend_filtration`` are called on them).

    Parameters
    ----------
    degrees : homological degrees. Default: ``[0]``, or ``0..max(extended)//4`` in extended mode.
    extended : falsy for standard persistence. Otherwise extended persistence is used: the 4
        gudhi diagrams (Ord, Rel, Ext+, Ext-) of dimension ``d`` get indices ``j + 4*d``, and a
        list/tuple of such indices selects the kept ones, e.g. ``[0,2,5,7]`` is Ord0, Ext+0, Rel1,
        Ext-1 (the selection used for any other truthy value). Extended mode returns a single
        diagram mixing all selected bars.
    n_jobs : number of joblib threads.
    progress : show a tqdm progress bar.
    threshold : diagram coordinates are clipped to ``[-threshold, threshold]``; if ``<= 0``,
        ``fit`` sets it to the largest absolute filtration value of the fitted trees.
    """
    def __init__(self, degrees:list[int]|None = None, extended:list[int]|bool|None=None, n_jobs=1, progress:bool=False, threshold:float=np.inf) -> None:
        super().__init__()
        # ``None`` default instead of a shared mutable ``[]``; both mean "standard persistence".
        if not extended:
            self.extended:list[int]|bool = False
        elif isinstance(extended, (list, tuple)):
            self.extended = list(extended)
        else:
            self.extended = [0,2,5,7]
        self.degrees:list[int] = degrees if degrees else list(range((max(self.extended) // 4)+1)) if self.extended else [0] # homological degrees
        self.n_jobs=n_jobs
        self.progress = progress # progress bar
        self.threshold = threshold # Threshold value

    def fit(self, X:list[gd.SimplexTree], y=None):
        if self.threshold <= 0:
            self.threshold = max( (abs(f) for simplextree in get_simplextrees(X) for s,f in simplextree.get_simplices()) ) ## MAX FILTRATION VALUE
            print(f"Setting threshold to {self.threshold}.")
        return self

    def transform(self,X:list[gd.SimplexTree]):
        """One list of diagrams (``(n, 2)`` arrays) per simplextree; ``[]`` for an empty input."""
        if len(X) == 0:
            return []

        def reshape(dgm:np.ndarray|list)->np.ndarray:
            # Empty diagrams become (0, 2) arrays; coordinates are clipped to the threshold.
            out = np.array(dgm) if len(dgm) > 0 else np.empty((0,2))
            if self.threshold != np.inf:
                out[out>self.threshold] = self.threshold
                out[out<-self.threshold] = -self.threshold
            return out

        def todo_standard(st):
            st.compute_persistence()
            return [reshape(st.persistence_intervals_in_dimension(d)) for d in self.degrees]

        def todo_extended(st):
            st.extend_filtration()
            dgms = st.extended_persistence()
            return [reshape([bar for j,dgm in enumerate(dgms) for d, bar in dgm if d in self.degrees and j+4*d in self.extended])]

        todo = todo_extended if self.extended else todo_standard

        if isinstance(X[0],gd.SimplexTree): # threads only: simplextrees are not pickleable
            return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(todo)(x) for x in tqdm(X, disable=not self.progress, desc="Computing diagrams"))
        to_st = X[0]
        dataset = X[1]
        def build_then_compute(x):
            return todo(to_st(x))
        return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(build_then_compute)(x) for x in tqdm(dataset, disable=not self.progress, desc="Computing simplextrees and diagrams"))
|
|
412
|
+
|
|
413
|
+
# Shuffles the entries of each diagram-shaped array. Input: list of (list of diagrams); output: list of (list of shuffled diagrams)
|
|
414
|
+
class DiagramShuffle(BaseEstimator, TransformerMixin):
    """Randomly permute all entries of each diagram (births and deaths mixed), keeping its shape.

    Uses the global ``np.random`` state; returns shuffled copies, inputs are left untouched.
    """
    def __init__(self, ) -> None:
        super().__init__()

    def fit(self, X:list[list[np.ndarray]], y=None):
        return self

    def transform(self,X:list[list[np.ndarray]]):
        def shuffled_copy(dgm):
            flat = dgm.flatten()  # flatten copies, so the caller's array is not modified
            np.random.shuffle(flat)
            return flat.reshape(dgm.shape)
        return [[shuffled_copy(dgm) for dgm in dgms] for dgms in X]
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
class Dgms2Landscapes(BaseEstimator, TransformerMixin):
    """Per-degree gudhi ``Landscape`` vectorisation, concatenated over degrees.

    Parameters
    ----------
    num : number of landscapes (``num_landscapes``).
    resolution : landscape resolution.
    n_jobs : stored; not used by this class.
    """
    def __init__(self, num:int=5, resolution:int=100, n_jobs:int=1) -> None:
        super().__init__()
        self.n_jobs = n_jobs
        self.num:int = num
        self.resolution:int = resolution
        self.degrees:list[int] = []
        self.landscapes:list[Landscape] = []

    def fit(self, X, y=None):
        if not len(X):
            return self
        self.degrees = list(range(len(X[0])))
        self.landscapes = []
        for degree in self.degrees:
            estimator = Landscape(num_landscapes=self.num, resolution=self.resolution)
            self.landscapes.append(estimator.fit([sample[degree] for sample in X]))
        return self

    def transform(self,X):
        if not len(X):
            return []
        blocks = []
        for degree, landscape in enumerate(self.landscapes):
            blocks.append(landscape.transform([sample[degree] for sample in X]))
        return np.concatenate(blocks, axis=1)
|
|
451
|
+
|
|
452
|
+
class Dgms2Image(BaseEstimator, TransformerMixin):
    """Per-degree gudhi ``PersistenceImage`` vectorisation, concatenated over degrees.

    Parameters
    ----------
    bandwidth, resolution : forwarded to ``PersistenceImage``.
    n_jobs : stored; not used by this class.
    """
    def __init__(self, bandwidth:float=1, resolution:tuple[int,int]=(20,20), n_jobs:int=1) -> None:
        super().__init__()
        self.n_jobs = n_jobs
        self.bandwidth:float = bandwidth
        self.resolution = resolution
        self.degrees:list[int] = []
        self.PI:list[PersistenceImage] = []

    def fit(self, X, y=None):
        if not len(X):
            return self
        self.degrees = list(range(len(X[0])))
        self.PI = []
        for degree in self.degrees:
            estimator = PersistenceImage(bandwidth=self.bandwidth, resolution=self.resolution)
            self.PI.append(estimator.fit([sample[degree] for sample in X]))
        return self

    def transform(self,X):
        if not len(X):
            return []
        blocks = []
        for degree, pers_image in enumerate(self.PI):
            blocks.append(pers_image.transform([sample[degree] for sample in X]))
        return np.concatenate(blocks, axis=1)
|
|
471
|
+
|
|
472
|
+
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from numpy.core.multiarray import concatenate
|
|
3
|
+
from numpy.lib import copy
|
|
4
|
+
import gudhi as gd
|
|
5
|
+
import multipers as mp
|
|
6
|
+
from sklearn.base import BaseEstimator, TransformerMixin
|
|
7
|
+
from multipers.ml.convolutions import KDE, DTM
|
|
8
|
+
from joblib import Parallel, delayed
|
|
9
|
+
from sklearn.metrics import pairwise_distances
|
|
10
|
+
from tqdm import tqdm
|
|
11
|
+
from typing import Literal,Optional
|
|
12
|
+
|
|
13
|
+
from multipers.simplex_tree_multi import SimplexTreeMulti
|
|
14
|
+
|
|
15
|
+
def _throw_nofit(any):
|
|
16
|
+
raise Exception("Fit first")
|
|
17
|
+
|
|
18
|
+
class PointCloud2SimplexTree(BaseEstimator, TransformerMixin):
|
|
19
|
+
def __init__(self,
|
|
20
|
+
bandwidths=[],
|
|
21
|
+
masses = [],
|
|
22
|
+
threshold:float=np.inf,
|
|
23
|
+
complex='rips',
|
|
24
|
+
sparse:float|None=None,
|
|
25
|
+
num_collapses:int|Literal['full']='full',
|
|
26
|
+
kernel:str="gaussian",
|
|
27
|
+
expand_dim:int=1,
|
|
28
|
+
progress:bool=False,
|
|
29
|
+
n_jobs:Optional[int]=None,
|
|
30
|
+
fit_fraction:float=1,
|
|
31
|
+
verbose:bool=False,
|
|
32
|
+
safe_conversion:bool=False,
|
|
33
|
+
) -> None:
|
|
34
|
+
"""
|
|
35
|
+
(Rips or Alpha) + (Density Estimation or DTM) 1-critical 2-filtration.
|
|
36
|
+
|
|
37
|
+
Parameters
|
|
38
|
+
----------
|
|
39
|
+
- bandwidth : real : The kernel density estimation bandwidth, or the DTM mass. If negative, it replaced by abs(bandwidth)*(radius of the dataset)
|
|
40
|
+
- threshold : real, max edge lenfth of the rips or max alpha square of the alpha
|
|
41
|
+
- sparse : real, sparse rips (c.f. rips doc) WARNING : ONLY FOR RIPS
|
|
42
|
+
- num_collapse : int, Number of edge collapses applied to the simplextrees, WARNING : ONLY FOR RIPS
|
|
43
|
+
- expand_dim : int, expand the rips complex to this dimension. WARNING : ONLY FOR RIPS
|
|
44
|
+
- kernel : the kernel used for density estimation. Available ones are, e.g., "dtm", "gaussian", "exponential".
|
|
45
|
+
- progress : bool, shows the calculus status
|
|
46
|
+
- n_jobs : number of processes
|
|
47
|
+
- fit_fraction : real, the fraction of data on which to fit
|
|
48
|
+
- verbose : bool, Shows more information if true.
|
|
49
|
+
|
|
50
|
+
Output
|
|
51
|
+
------
|
|
52
|
+
A list of SimplexTreeMulti whose first parameter is a rips and the second is the codensity.
|
|
53
|
+
"""
|
|
54
|
+
super().__init__()
|
|
55
|
+
self.bandwidths = bandwidths
|
|
56
|
+
self.masses=masses
|
|
57
|
+
self.num_collapses=num_collapses
|
|
58
|
+
self.kernel = kernel
|
|
59
|
+
self.progress=progress
|
|
60
|
+
self._bandwidths= np.empty((0,))
|
|
61
|
+
self._threshold=np.inf
|
|
62
|
+
self.n_jobs = n_jobs
|
|
63
|
+
self._scale=np.empty((0,))
|
|
64
|
+
self.fit_fraction=fit_fraction
|
|
65
|
+
self.expand_dim=expand_dim
|
|
66
|
+
self.verbose=verbose
|
|
67
|
+
self.complex=complex
|
|
68
|
+
self.threshold=threshold
|
|
69
|
+
self.sparse=sparse
|
|
70
|
+
self._get_sts = _throw_nofit
|
|
71
|
+
self.safe_conversion=safe_conversion
|
|
72
|
+
return
|
|
73
|
+
def _get_distance_quantiles(self, X, qs):
|
|
74
|
+
if len(qs) == 0:
|
|
75
|
+
self._scale = []
|
|
76
|
+
return []
|
|
77
|
+
if self.progress: print("Estimating scale...", flush=True, end="")
|
|
78
|
+
indices = np.random.choice(len(X),min(len(X), int(self.fit_fraction*len(X))+1) ,replace=False)
|
|
79
|
+
# diameter = np.asarray([distance_matrix(x,x).max() for x in (X[i] for i in indices)]).max()
|
|
80
|
+
diameter = np.max([pairwise_distances(X = x).max() for x in (X[i] for i in indices)])
|
|
81
|
+
self._scale = diameter * np.asarray(qs)
|
|
82
|
+
if self.threshold > 0: self._scale[self._scale>self.threshold] = self.threshold
|
|
83
|
+
if self.progress: print(f"Done. Chosen scales {qs} are {self._scale}", flush=True)
|
|
84
|
+
return self._scale
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _get_sts_rips(self,x):
|
|
88
|
+
st_init = gd.RipsComplex(points=x, max_edge_length=self._threshold, sparse=self.sparse).create_simplex_tree(max_dimension=1)
|
|
89
|
+
st_init = mp.simplex_tree_multi.SimplexTreeMulti(st_init, num_parameters = 2, safe_conversion=self.safe_conversion)
|
|
90
|
+
codensities = self._get_codensities(x_fit=x,x_sample=x)
|
|
91
|
+
num_axes = codensities.shape[0]
|
|
92
|
+
sts = [st_init] + [
|
|
93
|
+
st_init.copy() for _ in range(num_axes -1)
|
|
94
|
+
]
|
|
95
|
+
# no need to multithread here, most operations are memory
|
|
96
|
+
for codensity,st_copy in zip(codensities,sts):
|
|
97
|
+
# RIPS has contigus vertices, so vertices are ordered.
|
|
98
|
+
st_copy.fill_lowerstar(codensity,parameter=1)
|
|
99
|
+
|
|
100
|
+
def collapse_edges(st):
|
|
101
|
+
if self.verbose:
|
|
102
|
+
print("Num simplices :", st.num_simplices)
|
|
103
|
+
if isinstance(self.num_collapses, int):
|
|
104
|
+
st.collapse_edges(num=self.num_collapses)
|
|
105
|
+
if self.verbose:
|
|
106
|
+
print(", after collapse :", st.num_simplices, end="")
|
|
107
|
+
elif self.num_collapses == "full":
|
|
108
|
+
st.collapse_edges(full=True)
|
|
109
|
+
if self.verbose:
|
|
110
|
+
print(", after collapse :", st.num_simplices, end="")
|
|
111
|
+
if self.expand_dim > 1:
|
|
112
|
+
st.expansion(self.expand_dim)
|
|
113
|
+
if self.verbose:
|
|
114
|
+
print(", after expansion :", st.num_simplices, end="")
|
|
115
|
+
if self.verbose:
|
|
116
|
+
print("")
|
|
117
|
+
return st
|
|
118
|
+
return Parallel(
|
|
119
|
+
backend='threading', n_jobs=self.n_jobs
|
|
120
|
+
)(delayed(collapse_edges)(st) for st in sts)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _get_sts_alpha(self,x:np.ndarray, return_alpha=False):
|
|
125
|
+
alpha_complex = gd.AlphaComplex(points=x)
|
|
126
|
+
st = alpha_complex.create_simplex_tree(max_alpha_square = self._threshold**2)
|
|
127
|
+
vertices = np.array([i for (i,),_ in st.get_skeleton(0)])
|
|
128
|
+
new_points = np.asarray([alpha_complex.get_point(i) for i in vertices]) ## Seems to be unsafe for some reason
|
|
129
|
+
# new_points = x
|
|
130
|
+
st = mp.simplex_tree_multi.SimplexTreeMulti(st, num_parameters = 2,safe_conversion=self.safe_conversion)
|
|
131
|
+
codensities = self._get_codensities(x_fit=x,x_sample=new_points)
|
|
132
|
+
num_axes = codensities.shape[0]
|
|
133
|
+
sts = [st] + [
|
|
134
|
+
st.copy() for _ in range(num_axes -1)
|
|
135
|
+
]
|
|
136
|
+
# no need to multithread here, most operations are memory
|
|
137
|
+
max_vertices = vertices.max()+2 # +1 to be safe
|
|
138
|
+
for codensity,st_copy in zip(codensities,sts):
|
|
139
|
+
alligned_codensity = np.array([np.nan]*max_vertices)
|
|
140
|
+
alligned_codensity[vertices] = codensity
|
|
141
|
+
# alligned_codensity = np.array([codensity[i] if i in vertices else np.nan for i in range(max_vertices)])
|
|
142
|
+
st_copy.fill_lowerstar(alligned_codensity, parameter=1)
|
|
143
|
+
if return_alpha:
|
|
144
|
+
return alpha_complex,sts
|
|
145
|
+
return sts
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def _get_codensities(self,x_fit,x_sample):
|
|
149
|
+
x_fit = np.asarray(x_fit, dtype=np.float32)
|
|
150
|
+
x_sample = np.asarray(x_sample,dtype=np.float32)
|
|
151
|
+
codensities_kde = np.asarray([- KDE(
|
|
152
|
+
bandwidth=bandwidth, kernel=self.kernel).fit(x_fit).score_samples(x_sample)
|
|
153
|
+
for bandwidth in self._bandwidths],
|
|
154
|
+
).reshape(len(self._bandwidths), len(x_sample))
|
|
155
|
+
codensities_dtm = DTM(
|
|
156
|
+
masses=self.masses
|
|
157
|
+
).fit(x_fit).score_samples(x_sample).reshape(len(self.masses), len(x_sample))
|
|
158
|
+
return np.concatenate([codensities_kde,codensities_dtm])
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def fit(self, X:np.ndarray|list, y=None):
|
|
163
|
+
# self.bandwidth = "silverman" ## not good, as is can make bandwidth not constant
|
|
164
|
+
match self.complex:
|
|
165
|
+
case 'rips':
|
|
166
|
+
self._get_sts = self._get_sts_rips
|
|
167
|
+
case 'alpha':
|
|
168
|
+
self._get_sts = self._get_sts_alpha
|
|
169
|
+
case _:
|
|
170
|
+
raise ValueError(f"Invalid complex {self.complex}. Possible choises are rips or alpha.")
|
|
171
|
+
|
|
172
|
+
qs = [q for q in [*-np.asarray(self.bandwidths), -self.threshold] if 0 <= q <= 1]
|
|
173
|
+
self._get_distance_quantiles(X, qs=qs)
|
|
174
|
+
self._bandwidths = np.array(self.bandwidths)
|
|
175
|
+
count=0
|
|
176
|
+
for i in range(len(self._bandwidths)):
|
|
177
|
+
if self.bandwidths[i] < 0:
|
|
178
|
+
self._bandwidths[i] = self._scale[count]
|
|
179
|
+
count+=1
|
|
180
|
+
self._threshold = self.threshold if self.threshold > 0 else self._scale[-1]
|
|
181
|
+
|
|
182
|
+
##PRECOMPILE FIRST
|
|
183
|
+
self._get_codensities(X[0][:1],X[0][:1])
|
|
184
|
+
return self
|
|
185
|
+
|
|
186
|
+
def transform(self,X):
|
|
187
|
+
## precompile first
|
|
188
|
+
self._get_sts(X[0][:2])
|
|
189
|
+
with tqdm(X, desc="Filling simplextrees", disable = not self.progress, total=len(X)) as data:
|
|
190
|
+
stss = Parallel(backend="threading", n_jobs=self.n_jobs)(delayed(self._get_sts)(x) for x in data)
|
|
191
|
+
return stss
|