multipers 1.1.3__cp310-cp310-macosx_11_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic. Click here for more details.

Files changed (63) hide show
  1. multipers/.dylibs/libtbb.12.12.dylib +0 -0
  2. multipers/.dylibs/libtbbmalloc.2.12.dylib +0 -0
  3. multipers/__init__.py +5 -0
  4. multipers/_old_rank_invariant.pyx +328 -0
  5. multipers/_signed_measure_meta.py +193 -0
  6. multipers/data/MOL2.py +350 -0
  7. multipers/data/UCR.py +18 -0
  8. multipers/data/__init__.py +1 -0
  9. multipers/data/graphs.py +466 -0
  10. multipers/data/immuno_regions.py +27 -0
  11. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  12. multipers/data/pytorch2simplextree.py +91 -0
  13. multipers/data/shape3d.py +101 -0
  14. multipers/data/synthetic.py +68 -0
  15. multipers/distances.py +172 -0
  16. multipers/euler_characteristic.cpython-310-darwin.so +0 -0
  17. multipers/euler_characteristic.pyx +137 -0
  18. multipers/function_rips.cpython-310-darwin.so +0 -0
  19. multipers/function_rips.pyx +102 -0
  20. multipers/hilbert_function.cpython-310-darwin.so +0 -0
  21. multipers/hilbert_function.pyi +46 -0
  22. multipers/hilbert_function.pyx +151 -0
  23. multipers/io.cpython-310-darwin.so +0 -0
  24. multipers/io.pyx +176 -0
  25. multipers/ml/__init__.py +0 -0
  26. multipers/ml/accuracies.py +61 -0
  27. multipers/ml/convolutions.py +510 -0
  28. multipers/ml/invariants_with_persistable.py +79 -0
  29. multipers/ml/kernels.py +128 -0
  30. multipers/ml/mma.py +657 -0
  31. multipers/ml/one.py +472 -0
  32. multipers/ml/point_clouds.py +191 -0
  33. multipers/ml/signed_betti.py +50 -0
  34. multipers/ml/signed_measures.py +1479 -0
  35. multipers/ml/sliced_wasserstein.py +313 -0
  36. multipers/ml/tools.py +116 -0
  37. multipers/mma_structures.cpython-310-darwin.so +0 -0
  38. multipers/mma_structures.pxd +155 -0
  39. multipers/mma_structures.pyx +651 -0
  40. multipers/multiparameter_edge_collapse.py +29 -0
  41. multipers/multiparameter_module_approximation.cpython-310-darwin.so +0 -0
  42. multipers/multiparameter_module_approximation.pyi +439 -0
  43. multipers/multiparameter_module_approximation.pyx +311 -0
  44. multipers/pickle.py +53 -0
  45. multipers/plots.py +292 -0
  46. multipers/point_measure_integration.cpython-310-darwin.so +0 -0
  47. multipers/point_measure_integration.pyx +59 -0
  48. multipers/rank_invariant.cpython-310-darwin.so +0 -0
  49. multipers/rank_invariant.pyx +154 -0
  50. multipers/simplex_tree_multi.cpython-310-darwin.so +0 -0
  51. multipers/simplex_tree_multi.pxd +121 -0
  52. multipers/simplex_tree_multi.pyi +715 -0
  53. multipers/simplex_tree_multi.pyx +1417 -0
  54. multipers/slicer.cpython-310-darwin.so +0 -0
  55. multipers/slicer.pxd +94 -0
  56. multipers/slicer.pyx +276 -0
  57. multipers/tensor.pxd +13 -0
  58. multipers/test.pyx +44 -0
  59. multipers-1.1.3.dist-info/LICENSE +21 -0
  60. multipers-1.1.3.dist-info/METADATA +22 -0
  61. multipers-1.1.3.dist-info/RECORD +63 -0
  62. multipers-1.1.3.dist-info/WHEEL +5 -0
  63. multipers-1.1.3.dist-info/top_level.txt +1 -0
multipers/ml/one.py ADDED
@@ -0,0 +1,472 @@
1
+ from sklearn.base import BaseEstimator, TransformerMixin
2
+ import gudhi as gd
3
+ from os.path import exists
4
+ import networkx as nx
5
+ from joblib import Parallel, delayed
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ from warnings import warn
9
+ from sklearn.neighbors import KernelDensity
10
+ from typing import Iterable
11
+ from gudhi.representations import Landscape
12
+ from gudhi.representations.vector_methods import PersistenceImage
13
+ from gudhi.representations.kernel_methods import SlicedWassersteinDistance
14
+
15
+
16
+ from types import FunctionType
17
def get_simplextree(x) -> gd.SimplexTree:
    """Resolve ``x`` into a simplex tree.

    Accepted forms:
      - a ``gd.SimplexTree``: returned unchanged;
      - a plain function (``FunctionType``): called with no arguments and its
        result returned;
      - a length-3 sequence ``(f, args, kwargs)`` whose first item is a plain
        function: ``f(*args, **kwargs)`` is returned.

    Raises:
        TypeError: if ``x`` matches none of the forms above.
    """
    if isinstance(x, gd.SimplexTree):
        return x
    if isinstance(x, FunctionType):
        return x()
    is_deferred_call = len(x) == 3 and isinstance(x[0], FunctionType)
    if not is_deferred_call:
        raise TypeError("Not a valid SimplexTree")
    func, call_args, call_kwargs = x
    return func(*call_args, **call_kwargs)
26
def get_simplextrees(X) -> Iterable[gd.SimplexTree]:
    """Resolve ``X`` into an iterable of simplex trees.

    Accepted forms:
      - a length-2 ``(f, data)`` pair whose first item is a plain function:
        a lazy generator ``(f(x) for x in data)`` is returned;
      - an empty sequence: ``[]`` is returned;
      - a sequence of ``gd.SimplexTree``: returned as is. Only the first
        element is type-checked.

    Raises:
        TypeError: if ``X`` is a non-empty sequence whose first element is not
            a ``gd.SimplexTree``.
    """
    if len(X) == 2 and isinstance(X[0], FunctionType):
        f, data = X
        return (f(x) for x in data)
    if len(X) == 0:
        return []
    if not isinstance(X[0], gd.SimplexTree):
        raise TypeError(
            "Expected a sequence of gd.SimplexTree or a (function, data) pair, "
            f"got a first element of type {type(X[0]).__name__}."
        )
    return X
34
+
35
+
36
+
37
+
38
+ ############## INTERVALS (for sliced wasserstein)
39
class Graph2SimplexTree(BaseEstimator, TransformerMixin):
    """Turn networkx graphs into ``gd.SimplexTree`` objects (nodes and edges only).

    Every node and every edge is inserted with the value stored under the
    graph attribute ``f``.

    Parameters:
        f: name of the node/edge attribute holding the filtration value.
        dtype: if ``None``, ``transform`` returns ``[builder, X]`` so that the
            construction can be deferred to a later pipeline step (for
            parallelism); any other value builds the simplex trees immediately.
        reverse_filtration: stored but not used yet (TODO upstream).
    """

    def __init__(self, f: str = "ricciCurvature", dtype=gd.SimplexTree, reverse_filtration: bool = False):
        super().__init__()
        self.f = f
        self.dtype = dtype
        self.reverse_filtration = reverse_filtration

    def fit(self, X, y=None):
        """Nothing to learn; returns ``self``."""
        return self

    def transform(self, X: list[nx.Graph]):
        """Build one simplex tree per graph, or defer the build (see ``dtype``)."""

        def build(graph, f=self.f) -> gd.SimplexTree:
            # TODO : use batch insert
            tree = gd.SimplexTree()
            for node in graph.nodes:
                tree.insert([node], graph.nodes[node][f])
            for u, v in graph.edges:
                tree.insert([u, v], graph[u][v][f])
            return tree

        if self.dtype is None:
            return [build, X]
        jobs = (delayed(build)(graph) for graph in X)
        return Parallel(n_jobs=-1, prefer="threads")(jobs)
54
+
55
+
56
class PointCloud2SimplexTree(BaseEstimator, TransformerMixin):
    """Turn point clouds into alpha-complex ``gd.SimplexTree`` objects.

    Parameters:
        delayed: if true, ``transform`` returns ``[builder, X]`` so that the
            construction is deferred to a later pipeline step. ``None`` is
            also treated as "defer" for backward compatibility.
        threshold: the alpha complex is built with
            ``max_alpha_square = threshold**2``. A negative value is replaced
            in ``fit`` by the largest point-cloud diameter of the training data.
    """

    def __init__(self, delayed: bool = False, threshold=np.inf):
        super().__init__()
        self.delayed = delayed
        self.threshold = threshold

    @staticmethod
    def _get_point_cloud_diameter(x):
        """Return the largest pairwise Euclidean distance between points of ``x``."""
        from scipy.spatial import distance_matrix
        return np.max(distance_matrix(x, x))

    def fit(self, X, y=None):
        if self.threshold < 0:
            self.threshold = max(self._get_point_cloud_diameter(x) for x in X)
        return self

    def transform(self, X: list[np.ndarray]):
        """Build one alpha simplex tree per point cloud, or defer (see ``delayed``)."""

        def todo(point_cloud) -> gd.SimplexTree:  # TODO : use batch insert
            return gd.AlphaComplex(points=point_cloud).create_simplex_tree(
                max_alpha_square=self.threshold**2
            )

        # Defer when the flag is set; ``None`` still means "defer" as before.
        if self.delayed is None or self.delayed:
            return [todo, X]
        return Parallel(n_jobs=-1, prefer="threads")(
            delayed(todo)(point_cloud) for point_cloud in X
        )
74
+
75
+
76
+
77
+ #################### FILVEC
78
def get_filtration_values(g: nx.Graph, f: str) -> np.ndarray:
    """Collect the ``f`` attribute of every node, then of every edge, of ``g``.

    Node values come first (in ``g.nodes`` order), followed by edge values
    (in ``g.edges`` order), returned together as one numpy array.
    """
    node_values = [g.nodes[node][f] for node in g.nodes]
    edge_values = [g[u][v][f] for u, v in g.edges]
    return np.array(node_values + edge_values)
85
def graph2filvec(g: nx.Graph, f: str, range: tuple, bins: int) -> np.ndarray:
    """Histogram counts of all node and edge ``f`` values of ``g``.

    ``range`` and ``bins`` are forwarded to ``np.histogram``; only the counts
    (not the bin edges) are returned.
    """
    values = get_filtration_values(g, f)
    counts, _bin_edges = np.histogram(values, bins=bins, range=range)
    return counts
88
class FilvecGetter(BaseEstimator, TransformerMixin):
    """Vectorize graphs as histograms of their node/edge filtration values.

    ``fit`` picks a histogram range shared by all graphs: the ``quantile`` and
    ``1 - quantile`` quantiles of every value in the training graphs.
    ``transform`` then returns one length-``bins`` count vector per graph.

    Parameters:
        f: node/edge attribute holding the filtration value.
        quantile: lower quantile used to trim the fitted range.
        bins: number of histogram bins.
        n_jobs: number of joblib workers.
    """

    def __init__(self, f: str = "ricciCurvature", quantile: float = 0., bins: int = 100, n_jobs: int = 1):
        super().__init__()
        self.f = f
        self.quantile = quantile
        self.bins = bins
        self.range: tuple[float, float] | None = None
        self.n_jobs = n_jobs

    def fit(self, X, y=None):
        filtration_values = np.concatenate(
            Parallel(n_jobs=self.n_jobs)(delayed(get_filtration_values)(g, f=self.f) for g in X)
        )
        self.range = tuple(np.quantile(filtration_values, [self.quantile, 1 - self.quantile]))
        return self

    def transform(self, X):
        # Not fitted yet: keep the historical behaviour (message, then None).
        if self.range is None:
            print("Fit first")
            return None
        return Parallel(n_jobs=self.n_jobs)(
            delayed(graph2filvec)(g, f=self.f, range=self.range, bins=self.bins) for g in X
        )
105
+
106
+
107
+
108
+
109
+ ############# Filvec from SimplexTree
110
+ # Input: a simplex tree; output: histogram of its filtration values
111
def simplextree2hist(simplextree, range: tuple[float, float], bins: int, density: bool) -> np.ndarray:  # TODO : Anything to histogram
    """Histogram of every filtration value stored in ``simplextree``.

    ``simplextree.get_simplices()`` must yield ``(simplex, filtration)`` pairs.
    The filtrations are binned by ``np.histogram`` with ``bins``, ``range`` and
    ``density``; only the histogram values (not the edges) are returned.
    """
    values = [filtration for _simplex, filtration in simplextree.get_simplices()]
    hist, _bin_edges = np.histogram(np.array(values), bins=bins, range=range, density=density)
    return hist
114
class SimplexTree2Histogram(BaseEstimator, TransformerMixin):
    """Vectorize simplex trees as histograms of their filtration values.

    ``X`` is either a list of ``gd.SimplexTree``, or a ``(to_st, data)`` pair
    in which case simplex trees are produced lazily as ``to_st(x)`` for ``x``
    in ``data``.

    Parameters:
        quantile: ``fit`` sets the histogram range to the ``quantile`` and
            ``1 - quantile`` quantiles of the finite training values.
        bins: number of histogram bins.
        n_jobs: joblib workers (only used with the ``(to_st, data)`` form).
        progress: show a tqdm progress bar.
        density: forwarded to ``np.histogram``.
    """

    def __init__(self, quantile: float = 0., bins: int = 100, n_jobs: int = 1, progress: bool = False, density: bool = True):
        super().__init__()
        self.range: np.ndarray | None = None
        self.quantile: float = quantile
        self.bins: int = bins
        self.n_jobs = n_jobs
        self.density = density
        self.progress = progress
        # TODO: maybe use a max_dimension parameter.

    def fit(self, X, y=None):
        if len(X) == 0:
            return self
        if isinstance(X[0], gd.SimplexTree):
            data = X

            def to_st(x):
                return x
        else:
            to_st, data = X
        values = np.array([f for st in data for _s, f in to_st(st).get_simplices()])
        # Drop +/-inf and NaN: np.histogram requires a finite range.
        values = values[np.isfinite(values)]
        self.range = np.quantile(values, [self.quantile, 1 - self.quantile])
        return self

    def transform(self, X):
        # Empty input produces no histograms.
        if len(X) == 0:
            return []
        if isinstance(X[0], gd.SimplexTree):
            # Simplex trees cannot be shipped to worker processes.
            if self.n_jobs > 1:
                warn("Cannot pickle simplextrees, reducing to 1 thread to compute the simplextrees")
            return [
                simplextree2hist(g, range=self.range, bins=self.bins, density=self.density)
                for g in tqdm(X, desc="Computing diagrams", disable=not self.progress)
            ]
        to_st, data = X

        def pickle_able_todo(x, **kwargs):
            # Build the simplex tree inside the worker, then histogram it.
            simplextree = to_st(x)
            return simplextree2hist(simplextree=simplextree, **kwargs)

        return Parallel(n_jobs=self.n_jobs)(
            delayed(pickle_able_todo)(g, range=self.range, bins=self.bins, density=self.density)
            for g in tqdm(data, desc="Computing simplextrees and their diagrams", disable=not self.progress)
        )
148
+
149
+
150
+
151
+
152
+ ############# PERVEC
153
+ # Input: list of [list of diagrams]; output: histogram of persistence values (x and y coordinates pooled)
154
def dgm2pervec(dgms, range: tuple[float, float], bins: int) -> np.ndarray:  # TODO : Anything to histogram
    """Histogram counts of every birth and death value in ``dgms``.

    Args:
        dgms: iterable of persistence diagrams (array-likes). All coordinates
            of all diagrams are pooled into one histogram.
        range: histogram range, forwarded to ``np.histogram``.
        bins: number of bins, forwarded to ``np.histogram``.

    Returns:
        The ``bins`` histogram counts. An empty ``dgms`` yields all zeros
        instead of making ``np.concatenate`` raise.
    """
    flat = [np.asarray(dgm).ravel() for dgm in dgms]
    dgm_union = np.concatenate(flat) if flat else np.empty(0)
    counts, _bin_edges = np.histogram(dgm_union, bins=bins, range=range)
    return counts
157
class Dgm2Histogram(BaseEstimator, TransformerMixin):
    """Vectorize lists of diagrams as histograms of all their coordinates.

    ``fit`` sets ``range`` to the ``quantile`` and ``1 - quantile`` quantiles
    of the finite coordinates of every training diagram. ``transform`` returns
    one ``dgm2pervec`` histogram per sample.

    Parameters:
        quantile: lower quantile used to trim the fitted range.
        bins: number of histogram bins.
        n_jobs: number of joblib workers.
    """

    def __init__(self, quantile: float = 0., bins: int = 100, n_jobs: int = 1):
        super().__init__()
        self.range: np.ndarray | None = None
        self.quantile: float = quantile
        self.bins: int = bins
        self.n_jobs = n_jobs

    def fit(self, X, y=None):  # X: list of (list of diagrams)
        persistence_values = np.concatenate([dgm.flatten() for dgms in X for dgm in dgms], axis=0)
        # Drop +/-inf (essential classes) and NaN: np.histogram needs a finite
        # range, and this matches the filtering done by the other transformers.
        persistence_values = persistence_values[np.isfinite(persistence_values)]
        self.range = np.quantile(persistence_values, [self.quantile, 1 - self.quantile])
        return self

    def transform(self, X):
        return Parallel(n_jobs=self.n_jobs)(
            delayed(dgm2pervec)(g, range=self.range, bins=self.bins) for g in X
        )
171
+
172
+
173
+
174
+
175
+
176
+
177
+
178
+ ################# SignedMeasureImage
179
class Dgms2SignedMeasureImage(BaseEstimator, TransformerMixin):
    """Smooth each degree's births (+) and deaths (-) with kernel density estimates.

    ``fit`` sets ``ranges`` to one evaluation grid per degree: ``resolution``
    points, with shape ``(resolution, 1)``, spanning the ``quantile`` and
    ``1 - quantile`` quantiles of that degree's finite diagram coordinates.
    ``transform`` returns, per sample, the concatenation over degrees of
    ``KDE(births).score_samples(grid) - KDE(deaths).score_samples(grid)``,
    where both KDEs use the same ``bandwidth`` and ``kernel``.
    """

    def __init__(self, ranges: None | Iterable[Iterable[float]] = None, resolution: int = 100, quantile: float = 0, bandwidth: float = 1, kernel: str = "gaussian") -> None:
        super().__init__()
        self.ranges = ranges
        self.resolution = resolution
        self.quantile = quantile
        self.bandwidth = bandwidth
        self.kernel = kernel

    def fit(self, X, y=None):  # X: list of (list of diagrams)
        num_degrees = len(X[0])
        per_degree = [np.concatenate([dgms[i].flatten() for dgms in X], axis=0) for i in range(num_degrees)]
        per_degree = [values[np.isfinite(values)] for values in per_degree]
        quantiles = [np.quantile(values, [self.quantile, 1 - self.quantile]) for values in per_degree]
        self.ranges = np.array([np.linspace(start=[a], stop=[b], num=self.resolution) for a, b in quantiles])
        return self

    def _dgm2smi(self, dgms: Iterable[np.ndarray]):
        """Signed smoothing of one sample (one diagram per degree)."""

        def kde(values):
            # Births and deaths must share the same kernel and bandwidth.
            return KernelDensity(bandwidth=self.bandwidth, kernel=self.kernel).fit(values)

        # NOTE(review): score_samples returns log-densities, so each term is a
        # log-ratio rather than a difference of densities; confirm intended.
        smi = np.concatenate(
            [
                kde(dgm[:, [0]]).score_samples(grid) - kde(dgm[:, [1]]).score_samples(grid)
                for dgm, grid in zip(dgms, self.ranges)
            ],
            axis=0,
        )
        return smi

    def transform(self, X):  # X: list (data) of list of diagrams
        assert self.ranges is not None
        out = Parallel(n_jobs=1, prefer="threads")(
            delayed(Dgms2SignedMeasureImage._dgm2smi)(self=self, dgms=dgms)
            for dgms in X
        )
        return out
213
+
214
+
215
+
216
+ ################# SignedMeasureHistogram
217
class Dgms2SignedMeasureHistogram(BaseEstimator, TransformerMixin):
    """Signed histogram per homological degree: births counted +, deaths counted -.

    ``fit`` sets ``ranges`` to one ``[quantile, 1 - quantile]`` quantile pair
    per degree, computed on the finite coordinates of the training diagrams.
    ``transform`` returns, per sample, the concatenation over degrees of
    ``hist(births) - hist(deaths)``, each with ``bins`` bins.
    """

    def __init__(self, ranges: None | list[tuple[float, float]] = None, bins: int = 100, quantile: float = 0) -> None:
        super().__init__()
        self.ranges = ranges
        self.bins = bins
        self.quantile = quantile

    def fit(self, X, y=None):  # X: list of (list of diagrams)
        degree_count = len(X[0])
        fitted_ranges = []
        for degree in range(degree_count):
            pooled = np.concatenate([sample[degree].flatten() for sample in X], axis=0)
            finite = pooled[np.logical_and(pooled > -np.inf, pooled < np.inf)]
            fitted_ranges.append(np.quantile(finite, [self.quantile, 1 - self.quantile]))
        self.ranges = fitted_ranges
        return self

    def transform(self, X):  # X: list (data) of list of diagrams
        assert self.ranges is not None

        def signed_hist(dgm, bounds):
            births, _ = np.histogram(dgm[:, 0], bins=self.bins, range=bounds)
            deaths, _ = np.histogram(dgm[:, 1], bins=self.bins, range=bounds)
            return births - deaths

        return [
            np.concatenate([signed_hist(dgm, bounds) for dgm, bounds in zip(sample, self.ranges)])
            for sample in X
        ]
238
+
239
+
240
+
241
+
242
+
243
+
244
+
245
+
246
+ ################## Signed Measure Kernel 1D
247
+ # Input: list of [list of diagrams]; output: one distance matrix per homological degree (from which a kernel can be built, e.g. for an SVM)
248
+
249
+ # TODO : optimize ?
250
+ ## TODO : np.triu
251
class Dgms2SignedMeasureDistance(BaseEstimator, TransformerMixin):
    """Per-degree distance matrices between diagrams seen as 1-D signed measures.

    Births are counted positively and deaths negatively. ``transform(X)``
    returns ``np.asarray`` of one matrix per degree; each matrix holds the
    ``OSWdistance`` between every sample of ``X`` (rows) and every fitted
    sample (columns).

    Parameters:
        n_jobs: number of joblib workers.
        distance_matrix_path: if given, the matrix of degree ``d`` is cached in
            ``f"{distance_matrix_path}_{d}"``. It is loaded when that file
            exists, and computed and saved otherwise.
        progress: show tqdm progress bars.
    """

    def __init__(self, n_jobs: int = 1, distance_matrix_path: str | None = None, progress: bool = False) -> None:
        super().__init__()
        self.degrees: list[int] | None = None
        self.X: None | list[np.ndarray] = None
        self.n_jobs = n_jobs
        self.distance_matrix_path = distance_matrix_path
        self.progress = progress

    def fit(self, X: list[np.ndarray], y=None):
        if len(X) <= 0:
            warn("Fit a nontrivial vector")
            return self  # keep sklearn-style chaining even on empty input
        self.X = X
        # Assumes that all x in X have the same number of diagrams.
        self.degrees = list(range(len(X[0])))
        return self

    @staticmethod
    def wasserstein_1(a: np.ndarray, b: np.ndarray) -> float:
        """Mean absolute difference of the sorted values (1-D W1, equal sizes)."""
        return np.abs(np.sort(a) - np.sort(b)).mean()

    @staticmethod
    def OSWdistance(mu: list[np.ndarray], nu: list[np.ndarray], dim: int) -> float:
        """W1 between mu+ + nu- and nu+ + mu- for the diagrams of index ``dim``."""
        # TODO : check: do we want to sum the kernels or the distances ? add weights ?
        return Dgms2SignedMeasureDistance.wasserstein_1(
            np.hstack([mu[dim][:, 0], nu[dim][:, 1]]),
            np.hstack([nu[dim][:, 0], mu[dim][:, 1]]),
        )

    @staticmethod
    def _ds(mu: list[np.ndarray], nus: list[list[np.ndarray]], dim: int):
        """Distances from ``mu`` to every element of ``nus`` in degree ``dim``."""
        return [Dgms2SignedMeasureDistance.OSWdistance(mu, nu, dim) for nu in nus]

    def transform(self, X):  # X: list (data) of list of diagrams
        if self.X is None or self.degrees is None:
            warn("Fit first !")
            return np.array([[]])
        # Measures have different sizes, so sklearn/scipy pairwise helpers
        # (which need a rectangular array) cannot be used here.
        distances_matrices = []
        for degree in self.degrees:
            with tqdm(X, desc=f"Computing distance matrix of degree {degree}", disable=not self.progress) as diagrams_iterator:
                if self.distance_matrix_path is None:
                    distance_matrix = np.array(Parallel(n_jobs=self.n_jobs, prefer="threads")(
                        delayed(self._ds)(mu, self.X, degree) for mu in diagrams_iterator
                    ))
                else:
                    matrix_path = f"{self.distance_matrix_path}_{degree}"
                    if exists(matrix_path):
                        with open(matrix_path, "rb") as file:
                            distance_matrix = np.load(file)
                    else:
                        distance_matrix = np.array(Parallel(n_jobs=self.n_jobs)(
                            delayed(self._ds)(mu, self.X, degree) for mu in diagrams_iterator
                        ))
                        # Close (and flush) the cache file deterministically.
                        with open(matrix_path, "wb") as file:
                            np.save(file, distance_matrix)
                distances_matrices.append(distance_matrix)
        return np.asarray(distances_matrices)
304
+
305
+
306
+
307
+
308
+
309
+ ## Wrapper for SW, in order to take as an input a list of (list of diagrams)
310
class Dgms2SWK(BaseEstimator, TransformerMixin):
    """Sum over degrees of Gaussian kernels built on sliced Wasserstein distances.

    ``fit`` fits one gudhi ``SlicedWassersteinDistance`` per diagram position
    (degree) of the first sample. ``transform`` computes each per-degree
    distance matrix ``D`` and returns ``sum(exp(-D / (2 * bandwidth**2)))``.

    Parameters:
        num_directions: number of slicing directions.
        bandwidth: Gaussian kernel bandwidth.
        n_jobs: forwarded to ``SlicedWassersteinDistance``.
        distance_matrix_path: if given, the matrix of degree ``i`` is cached in
            ``f"{distance_matrix_path}_{i}"``. It is loaded when that file
            exists, and computed and saved otherwise.
        progress: stored; not used here.
    """

    def __init__(self, num_directions: int = 10, bandwidth: float = 1.0, n_jobs: int = 1, distance_matrix_path: str | None = None, progress: bool = False) -> None:
        super().__init__()
        self.num_directions: int = num_directions
        self.bandwidth: float = bandwidth
        self.n_jobs = n_jobs
        self.SW_: list = []
        self.distance_matrix_path = distance_matrix_path
        self.progress = progress

    def fit(self, X: list[list[np.ndarray]], y=None):
        # Assumes that all x in X hold the same number of diagrams.
        self.SW_ = [
            SlicedWassersteinDistance(num_directions=self.num_directions, n_jobs=self.n_jobs)
            for _ in range(len(X[0]))
        ]
        for i, sw in enumerate(self.SW_):
            self.SW_[i] = sw.fit([dgms[i] for dgms in X])
        return self

    def _degree_distances(self, X, i: int) -> np.ndarray:
        """Distance matrix of degree ``i``, going through the file cache if configured."""
        if self.distance_matrix_path is None:
            return self.SW_[i].transform([dgms[i] for dgms in X])
        SW_i_path = f"{self.distance_matrix_path}_{i}"
        if exists(SW_i_path):
            with open(SW_i_path, "rb") as file:
                return np.load(file)
        distance_matrix = self.SW_[i].transform([dgms[i] for dgms in X])
        with open(SW_i_path, "wb") as file:
            np.save(file, distance_matrix)
        return distance_matrix

    def transform(self, X) -> np.ndarray:
        # Every degree contributes, whether its matrix was cached or freshly computed.
        distance_matrices = [self._degree_distances(X, i) for i in range(len(self.SW_))]
        kernels = [np.exp(-distance_matrix / (2 * self.bandwidth**2)) for distance_matrix in distance_matrices]
        return np.sum(kernels, axis=0)  # TODO fix this, we may want to sum the distances instead of the kernels.
341
+
342
+
343
class Dgms2SlicedWassersteinDistanceMatrices(BaseEstimator, TransformerMixin):
    """One sliced Wasserstein distance matrix per homological degree.

    ``fit`` fits one gudhi ``SlicedWassersteinDistance`` per diagram position
    (degree) of the first sample. ``transform`` returns ``np.asarray`` of the
    per-degree distance matrices between ``X`` and the fitted diagrams.
    """

    def __init__(self, num_directions: int = 10, n_jobs: int = 1) -> None:
        super().__init__()
        self.num_directions: int = num_directions
        self.n_jobs = n_jobs
        self.SW_: list = []

    def fit(self, X: list[list[np.ndarray]], y=None):
        # Assumes that all x in X hold the same number of diagrams.
        self.SW_ = [
            SlicedWassersteinDistance(num_directions=self.num_directions, n_jobs=self.n_jobs)
            for _ in range(len(X[0]))
        ]
        for degree in range(len(self.SW_)):
            self.SW_[degree] = self.SW_[degree].fit([sample[degree] for sample in X])
        return self

    @staticmethod
    def _get_distance(diagrams, SWD):
        """Distances from ``diagrams`` to the diagrams ``SWD`` was fitted on."""
        return SWD.transform(diagrams)

    def transform(self, X):
        tasks = (
            delayed(self._get_distance)([sample[degree] for sample in X], swd)
            for degree, swd in enumerate(self.SW_)
        )
        return np.asarray(Parallel(n_jobs=self.n_jobs)(tasks))
364
+
365
+
366
+
367
+ # Gudhi simplexTree to list of diagrams
368
class SimplexTree2Dgm(BaseEstimator, TransformerMixin):
    """Compute persistence diagrams of gudhi simplex trees.

    Parameters:
        degrees: homological degrees to keep. If falsy, defaults to ``[0]`` for
            standard persistence, or to ``0 .. max(extended) // 4`` when
            ``extended`` is set.
        extended: falsy for standard persistence. Otherwise extended
            persistence is used, giving 4 diagrams per dimension. A list of ints
            selects the diagrams to keep by index ``j + 4 * d``; for example,
            ``[0, 2, 5, 7]`` is Ord0, Ext+0, Rel1, Ext-1. Any truthy non-list
            value selects ``[0, 2, 5, 7]``.
        n_jobs: number of joblib (thread) workers.
        progress: show a tqdm progress bar.
        threshold: diagram values above ``threshold`` (below ``-threshold``)
            are clipped. A non-positive value is replaced in ``fit`` by the
            largest absolute filtration value of the training data.
    """

    def __init__(self, degrees: list[int] | None = None, extended: list[int] | bool | None = None, n_jobs=1, progress: bool = False, threshold: float = np.inf) -> None:
        super().__init__()
        # ``None`` (the default) and ``[]`` both mean "standard persistence".
        self.extended: list[int] | bool = False if not extended else extended if type(extended) is list else [0, 2, 5, 7]
        self.degrees: list[int] = degrees if degrees else list(range((max(self.extended) // 4) + 1)) if self.extended else [0]
        self.n_jobs = n_jobs
        self.progress = progress
        self.threshold = threshold

    def fit(self, X: list[gd.SimplexTree], y=None):
        if self.threshold <= 0:
            # Largest absolute filtration value over all training simplex trees.
            self.threshold = max(abs(f) for simplextree in get_simplextrees(X) for _s, f in simplextree.get_simplices())
            print(f"Setting threshold to {self.threshold}.")
        return self

    def transform(self, X: list[gd.SimplexTree]):
        if len(X) == 0:
            return []

        def reshape(dgm: np.ndarray | list) -> np.ndarray:
            # Empty diagrams become (0, 2) arrays so that column indexing still works.
            out = np.array(dgm) if len(dgm) > 0 else np.empty((0, 2))
            if self.threshold != np.inf:
                out[out > self.threshold] = self.threshold
                out[out < -self.threshold] = -self.threshold
            return out

        def todo_standard(st):
            st.compute_persistence()
            return [reshape(st.persistence_intervals_in_dimension(d)) for d in self.degrees]

        def todo_extended(st):
            st.extend_filtration()
            dgms = st.extended_persistence()
            # All selected bars are merged into a single diagram.
            return [reshape([
                bar
                for j, dgm in enumerate(dgms)
                for d, bar in dgm
                if d in self.degrees and j + 4 * d in self.extended
            ])]

        todo = todo_extended if self.extended else todo_standard

        if isinstance(X[0], gd.SimplexTree):
            # Simplex trees are not pickleable, hence the threading backend.
            return Parallel(n_jobs=self.n_jobs, prefer="threads")(
                delayed(todo)(x) for x in tqdm(X, disable=not self.progress, desc="Computing diagrams")
            )
        to_st = X[0]
        dataset = X[1]

        def pickleable_todo(x):
            return todo(to_st(x))

        return Parallel(n_jobs=self.n_jobs, prefer="threads")(
            delayed(pickleable_todo)(x)
            for x in tqdm(dataset, disable=not self.progress, desc="Computing simplextrees and diagrams")
        )
412
+
413
+ # Shuffles the entries of diagram-shaped arrays. Input: list of (list of diagrams); output: list of (list of shuffled diagrams)
414
class DiagramShuffle(BaseEstimator, TransformerMixin):
    """Randomly permute all entries of every diagram, keeping its shape.

    Input and output are lists of (lists of diagrams). Each diagram is
    flattened into a copy, shuffled in place with numpy's global RNG, and then
    reshaped, so the input arrays are left untouched.
    """

    def __init__(self) -> None:
        super().__init__()

    def fit(self, X: list[list[np.ndarray]], y=None):
        """Nothing to learn; returns ``self``."""
        return self

    def transform(self, X: list[list[np.ndarray]]):
        def shuffled(dgm):
            flat = dgm.flatten()
            np.random.shuffle(flat)
            return flat.reshape(dgm.shape)

        return [[shuffled(dgm) for dgm in dgms] for dgms in X]
430
+
431
+
432
class Dgms2Landscapes(BaseEstimator, TransformerMixin):
    """Vectorize each degree's diagrams as persistence landscapes and concatenate.

    ``fit`` fits one gudhi ``Landscape`` per diagram position (degree) of the
    first sample. ``transform`` concatenates the per-degree landscape vectors
    of each sample along axis 1. ``n_jobs`` is stored but not used here.
    """

    def __init__(self, num: int = 5, resolution: int = 100, n_jobs: int = 1) -> None:
        super().__init__()
        self.degrees: list[int] = []
        self.num: int = num
        self.resolution: int = resolution
        self.landscapes: list[Landscape] = []
        self.n_jobs = n_jobs

    def fit(self, X, y=None):
        if len(X) == 0:
            return self
        self.degrees = list(range(len(X[0])))
        self.landscapes = []
        for dim in self.degrees:
            landscape = Landscape(num_landscapes=self.num, resolution=self.resolution)
            per_degree = [dgms[dim] for dgms in X]
            self.landscapes.append(landscape.fit(per_degree))
        return self

    def transform(self, X):
        if len(X) == 0:
            return []
        per_degree = [
            landscape.transform([dgms[degree] for dgms in X])
            for degree, landscape in enumerate(self.landscapes)
        ]
        return np.concatenate(per_degree, axis=1)
451
+
452
class Dgms2Image(BaseEstimator, TransformerMixin):
    """Vectorize each degree's diagrams as persistence images and concatenate.

    ``fit`` fits one gudhi ``PersistenceImage`` per diagram position (degree)
    of the first sample. ``transform`` concatenates the per-degree image
    vectors of each sample along axis 1. ``n_jobs`` is stored but not used here.
    """

    def __init__(self, bandwidth: float = 1, resolution: tuple[int, int] = (20, 20), n_jobs: int = 1) -> None:
        super().__init__()
        self.degrees: list[int] = []
        self.bandwidth: float = bandwidth
        self.resolution = resolution
        self.PI: list[PersistenceImage] = []
        self.n_jobs = n_jobs

    def fit(self, X, y=None):
        if len(X) == 0:
            return self
        self.degrees = list(range(len(X[0])))
        self.PI = []
        for dim in self.degrees:
            image = PersistenceImage(bandwidth=self.bandwidth, resolution=self.resolution)
            per_degree = [dgms[dim] for dgms in X]
            self.PI.append(image.fit(per_degree))
        return self

    def transform(self, X):
        if len(X) == 0:
            return []
        blocks = [
            pers_image.transform([dgms[degree] for dgms in X])
            for degree, pers_image in enumerate(self.PI)
        ]
        return np.concatenate(blocks, axis=1)
471
+
472
+
@@ -0,0 +1,191 @@
1
+ import numpy as np
2
+ from numpy.core.multiarray import concatenate
3
+ from numpy.lib import copy
4
+ import gudhi as gd
5
+ import multipers as mp
6
+ from sklearn.base import BaseEstimator, TransformerMixin
7
+ from multipers.ml.convolutions import KDE, DTM
8
+ from joblib import Parallel, delayed
9
+ from sklearn.metrics import pairwise_distances
10
+ from tqdm import tqdm
11
+ from typing import Literal,Optional
12
+
13
+ from multipers.simplex_tree_multi import SimplexTreeMulti
14
+
15
def _throw_nofit(any):
    """Stand-in for ``_get_sts`` before ``fit`` is called: always raises.

    Raises:
        Exception: always, with the message ``"Fit first"``.
    """
    message = "Fit first"
    raise Exception(message)
17
+
18
class PointCloud2SimplexTree(BaseEstimator, TransformerMixin):
    """Scikit-learn transformer turning point clouds into 2-parameter SimplexTreeMulti.

    The first parameter is a Rips or Alpha filtration; the second (index 1) is
    a codensity filled by lower-star. One simplex tree is produced per resolved
    bandwidth (KDE), followed by one per DTM mass.
    """
    def __init__(self,
        bandwidths=[],
        masses = [],
        threshold:float=np.inf,
        complex='rips',
        sparse:float|None=None,
        num_collapses:int|Literal['full']='full',
        kernel:str="gaussian",
        expand_dim:int=1,
        progress:bool=False,
        n_jobs:Optional[int]=None,
        fit_fraction:float=1,
        verbose:bool=False,
        safe_conversion:bool=False,
        ) -> None:
        """
        (Rips or Alpha) + (Density Estimation or DTM) 1-critical 2-filtration.

        Parameters
        ----------
         - bandwidths : list of reals : kernel density estimation bandwidths, one output simplex tree per entry. A negative value b is replaced in `fit` by abs(b) * (estimated diameter of the dataset). NOTE(review): this rescaling only lines up correctly for -1 <= b < 0 (see `fit`).
         - masses : list of reals : DTM masses, one additional output simplex tree per entry (after the bandwidth ones).
         - threshold : real, max edge length of the rips, or max alpha of the alpha complex (its square is used as max alpha square). A value in [-1, 0] is replaced in `fit` by abs(threshold) * (estimated diameter of the dataset).
         - complex : "rips" or "alpha".
         - sparse : real, sparse rips (c.f. rips doc) WARNING : ONLY FOR RIPS
         - num_collapses : int or "full", number of edge collapses applied to the simplextrees, WARNING : ONLY FOR RIPS
         - expand_dim : int, expand the rips complex to this dimension. WARNING : ONLY FOR RIPS
         - kernel : the kernel forwarded to the density estimator (KDE), e.g. "gaussian", "exponential".
         - progress : bool, shows the computation status
         - n_jobs : number of joblib workers (threading backend)
         - fit_fraction : real, the fraction of point clouds used to estimate the dataset diameter
         - verbose : bool, shows more information if true.
         - safe_conversion : bool, forwarded to the SimplexTreeMulti conversion.

        Output
        ------
        `transform` returns a list with one entry per point cloud; each entry is a list of SimplexTreeMulti (one per bandwidth, then one per mass) whose first parameter is the rips/alpha filtration and whose second is the codensity.
        """
        super().__init__()
        self.bandwidths = bandwidths
        self.masses=masses
        self.num_collapses=num_collapses
        self.kernel = kernel
        self.progress=progress
        self._bandwidths= np.empty((0,))  # resolved (absolute) bandwidths, set in `fit`
        self._threshold=np.inf  # resolved threshold, set in `fit`
        self.n_jobs = n_jobs
        self._scale=np.empty((0,))  # diameter * requested quantiles, set in `fit`
        self.fit_fraction=fit_fraction
        self.expand_dim=expand_dim
        self.verbose=verbose
        self.complex=complex
        self.threshold=threshold
        self.sparse=sparse
        self._get_sts = _throw_nofit  # replaced by the rips/alpha builder in `fit`
        self.safe_conversion=safe_conversion
        return
    def _get_distance_quantiles(self, X, qs):
        """Return (and store in self._scale) diameter * qs.

        The diameter is the largest pairwise distance over a random subset of
        the point clouds (about fit_fraction of them, at least one). Scales
        above a positive `threshold` are clipped to it.
        """
        if len(qs) == 0:
            self._scale = []
            return []
        if self.progress: print("Estimating scale...", flush=True, end="")
        indices = np.random.choice(len(X),min(len(X), int(self.fit_fraction*len(X))+1) ,replace=False)
        # diameter = np.asarray([distance_matrix(x,x).max() for x in (X[i] for i in indices)]).max()
        diameter = np.max([pairwise_distances(X = x).max() for x in (X[i] for i in indices)])
        self._scale = diameter * np.asarray(qs)
        if self.threshold > 0: self._scale[self._scale>self.threshold] = self.threshold
        if self.progress: print(f"Done. Chosen scales {qs} are {self._scale}", flush=True)
        return self._scale


    def _get_sts_rips(self,x):
        """Rips-based simplex trees of point cloud `x`, one per codensity.

        NOTE(review): with no bandwidths and no masses, a single tree with an
        unfilled second parameter is returned (zip over empty codensities).
        """
        st_init = gd.RipsComplex(points=x, max_edge_length=self._threshold, sparse=self.sparse).create_simplex_tree(max_dimension=1)
        st_init = mp.simplex_tree_multi.SimplexTreeMulti(st_init, num_parameters = 2, safe_conversion=self.safe_conversion)
        codensities = self._get_codensities(x_fit=x,x_sample=x)
        num_axes = codensities.shape[0]
        sts = [st_init] + [
            st_init.copy() for _ in range(num_axes -1)
        ]
        # no need to multithread here, most operations are memory
        for codensity,st_copy in zip(codensities,sts):
            # RIPS has contiguous vertices, so vertices are ordered.
            st_copy.fill_lowerstar(codensity,parameter=1)

        def collapse_edges(st):
            # Optional edge collapses (`num_collapses`: int count or "full"),
            # then expansion to `expand_dim`; prints sizes when verbose.
            if self.verbose:
                print("Num simplices :", st.num_simplices)
            if isinstance(self.num_collapses, int):
                st.collapse_edges(num=self.num_collapses)
                if self.verbose:
                    print(", after collapse :", st.num_simplices, end="")
            elif self.num_collapses == "full":
                st.collapse_edges(full=True)
                if self.verbose:
                    print(", after collapse :", st.num_simplices, end="")
            if self.expand_dim > 1:
                st.expansion(self.expand_dim)
                if self.verbose:
                    print(", after expansion :", st.num_simplices, end="")
            if self.verbose:
                print("")
            return st
        return Parallel(
            backend='threading', n_jobs=self.n_jobs
        )(delayed(collapse_edges)(st) for st in sts)



    def _get_sts_alpha(self,x:np.ndarray, return_alpha=False):
        """Alpha-based simplex trees of point cloud `x`, one per codensity.

        Codensities are evaluated at the points of the alpha complex's vertices
        and written into an array indexed by vertex id (NaN for ids with no
        vertex) before the lower-star fill. If `return_alpha`, the gudhi
        AlphaComplex is returned alongside the trees.
        """
        alpha_complex = gd.AlphaComplex(points=x)
        st = alpha_complex.create_simplex_tree(max_alpha_square = self._threshold**2)
        vertices = np.array([i for (i,),_ in st.get_skeleton(0)])
        new_points = np.asarray([alpha_complex.get_point(i) for i in vertices]) ## Seems to be unsafe for some reason
        # new_points = x
        st = mp.simplex_tree_multi.SimplexTreeMulti(st, num_parameters = 2,safe_conversion=self.safe_conversion)
        codensities = self._get_codensities(x_fit=x,x_sample=new_points)
        num_axes = codensities.shape[0]
        sts = [st] + [
            st.copy() for _ in range(num_axes -1)
        ]
        # no need to multithread here, most operations are memory
        max_vertices = vertices.max()+2 # max id + 1 would suffice; one extra slot kept to be safe
        for codensity,st_copy in zip(codensities,sts):
            alligned_codensity = np.array([np.nan]*max_vertices)
            alligned_codensity[vertices] = codensity
            # alligned_codensity = np.array([codensity[i] if i in vertices else np.nan for i in range(max_vertices)])
            st_copy.fill_lowerstar(alligned_codensity, parameter=1)
        if return_alpha:
            return alpha_complex,sts
        return sts


    def _get_codensities(self,x_fit,x_sample):
        """Codensities fitted on `x_fit`, evaluated at `x_sample`.

        Returns an array of shape (len(self._bandwidths) + len(self.masses),
        len(x_sample)): first the negated KDE `score_samples`, one row per
        resolved bandwidth, then the DTM scores, one row per mass.
        Inputs are cast to float32.
        """
        x_fit = np.asarray(x_fit, dtype=np.float32)
        x_sample = np.asarray(x_sample,dtype=np.float32)
        codensities_kde = np.asarray([- KDE(
            bandwidth=bandwidth, kernel=self.kernel).fit(x_fit).score_samples(x_sample)
            for bandwidth in self._bandwidths],
            ).reshape(len(self._bandwidths), len(x_sample))
        codensities_dtm = DTM(
            masses=self.masses
            ).fit(x_fit).score_samples(x_sample).reshape(len(self.masses), len(x_sample))
        return np.concatenate([codensities_kde,codensities_dtm])



    def fit(self, X:np.ndarray|list, y=None):
        """Select the complex builder and resolve relative bandwidths/threshold.

        Negative bandwidths and a threshold in [-1, 0] are turned into absolute
        scales using the estimated dataset diameter. Returns self.

        Raises:
            ValueError: if `complex` is neither "rips" nor "alpha".
        """
        # self.bandwidth = "silverman" ## not good, as is can make bandwidth not constant
        match self.complex:
            case 'rips':
                self._get_sts = self._get_sts_rips
            case 'alpha':
                self._get_sts = self._get_sts_alpha
            case _:
                raise ValueError(f"Invalid complex {self.complex}. Possible choises are rips or alpha.")

        # Relative scales to resolve: -b for each bandwidth, then -threshold (last),
        # kept only when within [0, 1].
        # NOTE(review): the loop below rescales every b < 0, while qs only keeps
        # -1 <= b <= 0; a bandwidth equal to 0 or below -1 misaligns `count` with
        # self._scale (wrong scale or IndexError). Confirm the intended input range.
        qs = [q for q in [*-np.asarray(self.bandwidths), -self.threshold] if 0 <= q <= 1]
        self._get_distance_quantiles(X, qs=qs)
        # NOTE(review): integer bandwidths (e.g. [-1]) give an int array here, so
        # the float scales assigned below would be truncated.
        self._bandwidths = np.array(self.bandwidths)
        count=0
        for i in range(len(self._bandwidths)):
            if self.bandwidths[i] < 0:
                self._bandwidths[i] = self._scale[count]
                count+=1
        # The threshold quantile, when present, was appended last to qs.
        self._threshold = self.threshold if self.threshold > 0 else self._scale[-1]

        ##PRECOMPILE FIRST
        self._get_codensities(X[0][:1],X[0][:1])
        return self

    def transform(self,X):
        """Return one list of SimplexTreeMulti per point cloud of X (threading backend)."""
        ## precompile first
        self._get_sts(X[0][:2])
        with tqdm(X, desc="Filling simplextrees", disable = not self.progress, total=len(X)) as data:
            stss = Parallel(backend="threading", n_jobs=self.n_jobs)(delayed(self._get_sts)(x) for x in data)
        return stss