multipers 2.3.3b6__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic. Click here for more details.

Files changed (183) hide show
  1. multipers/.dylibs/libc++.1.0.dylib +0 -0
  2. multipers/.dylibs/libtbb.12.16.dylib +0 -0
  3. multipers/__init__.py +33 -0
  4. multipers/_signed_measure_meta.py +453 -0
  5. multipers/_slicer_meta.py +211 -0
  6. multipers/array_api/__init__.py +45 -0
  7. multipers/array_api/numpy.py +41 -0
  8. multipers/array_api/torch.py +58 -0
  9. multipers/data/MOL2.py +458 -0
  10. multipers/data/UCR.py +18 -0
  11. multipers/data/__init__.py +1 -0
  12. multipers/data/graphs.py +466 -0
  13. multipers/data/immuno_regions.py +27 -0
  14. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  15. multipers/data/pytorch2simplextree.py +91 -0
  16. multipers/data/shape3d.py +101 -0
  17. multipers/data/synthetic.py +113 -0
  18. multipers/distances.py +202 -0
  19. multipers/filtration_conversions.pxd +229 -0
  20. multipers/filtration_conversions.pxd.tp +84 -0
  21. multipers/filtrations/__init__.py +18 -0
  22. multipers/filtrations/density.py +574 -0
  23. multipers/filtrations/filtrations.py +361 -0
  24. multipers/filtrations.pxd +224 -0
  25. multipers/function_rips.cpython-310-darwin.so +0 -0
  26. multipers/function_rips.pyx +105 -0
  27. multipers/grids.cpython-310-darwin.so +0 -0
  28. multipers/grids.pyx +433 -0
  29. multipers/gudhi/Persistence_slices_interface.h +132 -0
  30. multipers/gudhi/Simplex_tree_interface.h +239 -0
  31. multipers/gudhi/Simplex_tree_multi_interface.h +551 -0
  32. multipers/gudhi/cubical_to_boundary.h +59 -0
  33. multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -0
  34. multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -0
  35. multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -0
  36. multipers/gudhi/gudhi/Debug_utils.h +45 -0
  37. multipers/gudhi/gudhi/Fields/Multi_field.h +484 -0
  38. multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -0
  39. multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -0
  40. multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -0
  41. multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -0
  42. multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -0
  43. multipers/gudhi/gudhi/Fields/Z2_field.h +355 -0
  44. multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -0
  45. multipers/gudhi/gudhi/Fields/Zp_field.h +420 -0
  46. multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -0
  47. multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -0
  48. multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -0
  49. multipers/gudhi/gudhi/Matrix.h +2107 -0
  50. multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -0
  51. multipers/gudhi/gudhi/Multi_persistence/Box.h +174 -0
  52. multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -0
  53. multipers/gudhi/gudhi/Off_reader.h +173 -0
  54. multipers/gudhi/gudhi/One_critical_filtration.h +1441 -0
  55. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -0
  56. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -0
  57. multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -0
  58. multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -0
  59. multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -0
  60. multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -0
  61. multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -0
  62. multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -0
  63. multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -0
  64. multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -0
  65. multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -0
  66. multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -0
  67. multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -0
  68. multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -0
  69. multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -0
  70. multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -0
  71. multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -0
  72. multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -0
  73. multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -0
  74. multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -0
  75. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -0
  76. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -0
  77. multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -0
  78. multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -0
  79. multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -0
  80. multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -0
  81. multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -0
  82. multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -0
  83. multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -0
  84. multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -0
  85. multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -0
  86. multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -0
  87. multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -0
  88. multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -0
  89. multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -0
  90. multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -0
  91. multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -0
  92. multipers/gudhi/gudhi/Persistent_cohomology.h +745 -0
  93. multipers/gudhi/gudhi/Points_off_io.h +171 -0
  94. multipers/gudhi/gudhi/Simple_object_pool.h +69 -0
  95. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -0
  96. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -0
  97. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -0
  98. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -0
  99. multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -0
  100. multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -0
  101. multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -0
  102. multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -0
  103. multipers/gudhi/gudhi/Simplex_tree.h +2794 -0
  104. multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -0
  105. multipers/gudhi/gudhi/distance_functions.h +62 -0
  106. multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -0
  107. multipers/gudhi/gudhi/persistence_interval.h +253 -0
  108. multipers/gudhi/gudhi/persistence_matrix_options.h +170 -0
  109. multipers/gudhi/gudhi/reader_utils.h +367 -0
  110. multipers/gudhi/mma_interface_coh.h +256 -0
  111. multipers/gudhi/mma_interface_h0.h +223 -0
  112. multipers/gudhi/mma_interface_matrix.h +293 -0
  113. multipers/gudhi/naive_merge_tree.h +536 -0
  114. multipers/gudhi/scc_io.h +310 -0
  115. multipers/gudhi/truc.h +1403 -0
  116. multipers/io.cpython-310-darwin.so +0 -0
  117. multipers/io.pyx +644 -0
  118. multipers/ml/__init__.py +0 -0
  119. multipers/ml/accuracies.py +90 -0
  120. multipers/ml/invariants_with_persistable.py +79 -0
  121. multipers/ml/kernels.py +176 -0
  122. multipers/ml/mma.py +713 -0
  123. multipers/ml/one.py +472 -0
  124. multipers/ml/point_clouds.py +352 -0
  125. multipers/ml/signed_measures.py +1589 -0
  126. multipers/ml/sliced_wasserstein.py +461 -0
  127. multipers/ml/tools.py +113 -0
  128. multipers/mma_structures.cpython-310-darwin.so +0 -0
  129. multipers/mma_structures.pxd +128 -0
  130. multipers/mma_structures.pyx +2786 -0
  131. multipers/mma_structures.pyx.tp +1094 -0
  132. multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -0
  133. multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -0
  134. multipers/multi_parameter_rank_invariant/function_rips.h +322 -0
  135. multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -0
  136. multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -0
  137. multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -0
  138. multipers/multiparameter_edge_collapse.py +41 -0
  139. multipers/multiparameter_module_approximation/approximation.h +2330 -0
  140. multipers/multiparameter_module_approximation/combinatory.h +129 -0
  141. multipers/multiparameter_module_approximation/debug.h +107 -0
  142. multipers/multiparameter_module_approximation/euler_curves.h +0 -0
  143. multipers/multiparameter_module_approximation/format_python-cpp.h +286 -0
  144. multipers/multiparameter_module_approximation/heap_column.h +238 -0
  145. multipers/multiparameter_module_approximation/images.h +79 -0
  146. multipers/multiparameter_module_approximation/list_column.h +174 -0
  147. multipers/multiparameter_module_approximation/list_column_2.h +232 -0
  148. multipers/multiparameter_module_approximation/ru_matrix.h +347 -0
  149. multipers/multiparameter_module_approximation/set_column.h +135 -0
  150. multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -0
  151. multipers/multiparameter_module_approximation/unordered_set_column.h +166 -0
  152. multipers/multiparameter_module_approximation/utilities.h +403 -0
  153. multipers/multiparameter_module_approximation/vector_column.h +223 -0
  154. multipers/multiparameter_module_approximation/vector_matrix.h +331 -0
  155. multipers/multiparameter_module_approximation/vineyards.h +464 -0
  156. multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -0
  157. multipers/multiparameter_module_approximation.cpython-310-darwin.so +0 -0
  158. multipers/multiparameter_module_approximation.pyx +235 -0
  159. multipers/pickle.py +90 -0
  160. multipers/plots.py +456 -0
  161. multipers/point_measure.cpython-310-darwin.so +0 -0
  162. multipers/point_measure.pyx +395 -0
  163. multipers/simplex_tree_multi.cpython-310-darwin.so +0 -0
  164. multipers/simplex_tree_multi.pxd +134 -0
  165. multipers/simplex_tree_multi.pyx +10840 -0
  166. multipers/simplex_tree_multi.pyx.tp +2009 -0
  167. multipers/slicer.cpython-310-darwin.so +0 -0
  168. multipers/slicer.pxd +3034 -0
  169. multipers/slicer.pxd.tp +234 -0
  170. multipers/slicer.pyx +20481 -0
  171. multipers/slicer.pyx.tp +1088 -0
  172. multipers/tensor/tensor.h +672 -0
  173. multipers/tensor.pxd +13 -0
  174. multipers/test.pyx +44 -0
  175. multipers/tests/__init__.py +62 -0
  176. multipers/torch/__init__.py +1 -0
  177. multipers/torch/diff_grids.py +240 -0
  178. multipers/torch/rips_density.py +310 -0
  179. multipers-2.3.3b6.dist-info/METADATA +128 -0
  180. multipers-2.3.3b6.dist-info/RECORD +183 -0
  181. multipers-2.3.3b6.dist-info/WHEEL +6 -0
  182. multipers-2.3.3b6.dist-info/licenses/LICENSE +21 -0
  183. multipers-2.3.3b6.dist-info/top_level.txt +1 -0
multipers/ml/one.py ADDED
@@ -0,0 +1,472 @@
1
+ from sklearn.base import BaseEstimator, TransformerMixin
2
+ import gudhi as gd
3
+ from os.path import exists
4
+ import networkx as nx
5
+ from joblib import Parallel, delayed
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ from warnings import warn
9
+ from sklearn.neighbors import KernelDensity
10
+ from typing import Iterable
11
+ from gudhi.representations import Landscape
12
+ from gudhi.representations.vector_methods import PersistenceImage
13
+ from gudhi.representations.kernel_methods import SlicedWassersteinDistance
14
+
15
+
16
+ from types import FunctionType
17
def get_simplextree(x)->gd.SimplexTree:
    """Resolve ``x`` into a simplex tree.

    Accepted inputs:
      * a ``gd.SimplexTree`` instance, returned unchanged;
      * a plain function, which is called without arguments;
      * a length-3 sequence ``(f, args, kwargs)`` whose first item is a plain
        function, evaluated as ``f(*args, **kwargs)``.

    Raises ``TypeError`` for any other input.
    """
    if isinstance(x, gd.SimplexTree):
        return x
    if isinstance(x, FunctionType):
        return x()
    is_deferred_call = len(x) == 3 and isinstance(x[0], FunctionType)
    if not is_deferred_call:
        raise TypeError("Not a valid SimplexTree")
    builder, positional, named = x
    return builder(*positional, **named)
26
def get_simplextrees(X)->Iterable[gd.SimplexTree]:
    """Return an iterable of simplex trees described by ``X``.

    ``X`` is either a pair ``(f, data)`` whose first item is a plain function
    (the trees are then produced lazily as ``f(x)`` for each ``x`` in ``data``),
    an empty container (an empty list is returned), or a container whose first
    element is a ``gd.SimplexTree`` (returned unchanged).

    Raises ``TypeError`` when the first element is neither of those.
    """
    size = len(X)
    if size == 2 and isinstance(X[0], FunctionType):
        builder, items = X
        return (builder(item) for item in items)
    if size == 0:
        return []
    if isinstance(X[0], gd.SimplexTree):
        return X
    raise TypeError
34
+
35
+
36
+
37
+
38
+ ############## INTERVALS (for sliced wasserstein)
39
class Graph2SimplexTree(BaseEstimator,TransformerMixin):
    """Convert graphs into simplex trees filtered by a node/edge attribute.

    ``f`` is the attribute name looked up on every node and edge. When
    ``dtype`` is ``None`` the conversion is not performed by ``transform``;
    the pair ``[todo, X]`` is returned instead so that a later step can build
    the trees itself. ``reverse_filtration`` is stored but not used by the
    code of this class.
    """

    def __init__(self, f:str="ricciCurvature",dtype=gd.SimplexTree, reverse_filtration:bool=False):
        super().__init__()
        # Name of the attribute holding the filtration value in the graph.
        self.f = f
        # None delays the computation to a later pipeline step.
        self.dtype = dtype
        # Not implemented yet (TODO kept from the original code).
        self.reverse_filtration = reverse_filtration

    def fit(self, X, y=None):
        """Nothing to learn; returns ``self``."""
        return self

    def transform(self,X:list[nx.Graph]):
        """Build one simplex tree per graph, or return ``[todo, X]`` when delayed."""
        def todo(graph, f=self.f) -> gd.SimplexTree: # TODO : use batch insert
            tree = gd.SimplexTree()
            # Vertices first, then edges, each with its own attribute value.
            for node in graph.nodes:
                tree.insert([node], graph.nodes[node][f])
            for source, target in graph.edges:
                tree.insert([source, target], graph[source][target][f])
            return tree

        if self.dtype is None:
            return [todo, X]
        runner = Parallel(n_jobs=-1, prefer="threads")
        return runner(delayed(todo)(graph) for graph in X)
54
+
55
+
56
class PointCloud2SimplexTree(BaseEstimator,TransformerMixin):
    """Build one alpha-complex simplex tree per point cloud.

    Parameters
    ----------
    delayed:
        When truthy, ``transform`` builds nothing and returns ``[todo, X]`` so
        that a later pipeline step can build the trees itself. ``None`` is
        still treated as "delayed" for backward compatibility.
    threshold:
        Squared and passed as ``max_alpha_square`` to ``create_simplex_tree``.
        A negative value is replaced during ``fit`` by the largest point cloud
        diameter found in the training data.
    """

    def __init__(self, delayed:bool = False, threshold = np.inf):
        super().__init__()
        self.delayed = delayed
        self.threshold=threshold

    @staticmethod
    def _get_point_cloud_diameter(x):
        """Return the largest pairwise distance between the points of ``x``."""
        from scipy.spatial import distance_matrix
        return np.max(distance_matrix(x,x))

    def fit(self, X, y=None):
        """Resolve a negative ``threshold`` from the data; returns ``self``."""
        if self.threshold < 0:
            self.threshold = max(self._get_point_cloud_diameter(x) for x in X)
        return self

    def transform(self, X):
        """Return the simplex trees of ``X``, or ``[todo, X]`` when delayed.

        ``X`` is an iterable of point clouds, each one being passed as
        ``points`` to ``gd.AlphaComplex``.
        """
        def todo(point_cloud) -> gd.SimplexTree: # TODO : use batch insert
            complex_ = gd.AlphaComplex(points=point_cloud)
            return complex_.create_simplex_tree(max_alpha_square = self.threshold**2)

        # The original test was `self.delayed is None`, which is never true for
        # the documented boolean flag, so `delayed=True` was silently ignored.
        if self.delayed is None or self.delayed:
            return [todo, X]
        # `delayed` below is the module-level joblib helper; the constructor
        # argument of the same name only lives in `self.delayed`.
        return Parallel(n_jobs=-1, prefer="threads")(delayed(todo)(point_cloud) for point_cloud in X)
74
+
75
+
76
+
77
+ #################### FILVEC
78
def get_filtration_values(g:nx.Graph, f:str)->np.ndarray:
    """Collect attribute ``f`` of every node, then of every edge, of ``g``.

    Returns the values as one array, nodes first and edges after.
    """
    collected = []
    for node in g.nodes:
        collected.append(g.nodes[node][f])
    for u, v in g.edges:
        collected.append(g[u][v][f])
    return np.array(collected)
85
def graph2filvec(g:nx.Graph, f:str, range:tuple, bins:int)->np.ndarray:
    """Histogram (counts only) of the ``f`` values of ``g`` over ``range`` with ``bins`` bins."""
    values = get_filtration_values(g, f)
    counts, _edges = np.histogram(values, bins=bins, range=range)
    return counts
88
class FilvecGetter(BaseEstimator, TransformerMixin):
    """Histogram the filtration values stored on the nodes and edges of graphs.

    Parameters
    ----------
    f:
        Name of the node/edge attribute holding the filtration value.
    quantile:
        The histogram range learned in ``fit`` is the
        ``[quantile, 1 - quantile]`` quantile interval of all training values.
    bins:
        Number of histogram bins.
    n_jobs:
        Forwarded to ``joblib.Parallel``.
    """

    def __init__(self, f:str="ricciCurvature",quantile:float=0., bins:int=100, n_jobs:int=1):
        super().__init__()
        self.f=f
        self.quantile=quantile
        self.bins=bins
        # Learned in fit; None until then.
        self.range:tuple[float]|None=None
        self.n_jobs=n_jobs

    def fit(self, X, y=None):
        """Learn the histogram range from the graphs in ``X``; returns ``self``."""
        per_graph_values = Parallel(n_jobs=self.n_jobs)(
            delayed(get_filtration_values)(g,f=self.f) for g in X
        )
        filtration_values = np.concatenate(per_graph_values)
        self.range= tuple(np.quantile(filtration_values, [self.quantile, 1-self.quantile]))
        return self

    def transform(self,X):
        """Return one histogram per graph, or ``None`` (after printing) if not fitted."""
        # Identity test: `== None` would be element-wise (and ambiguous) if
        # `range` were ever set to an array.
        if self.range is None:
            print("Fit first")
            return
        return Parallel(n_jobs=self.n_jobs)(
            delayed(graph2filvec)(g,f=self.f, range=self.range, bins=self.bins) for g in X
        )
105
+
106
+
107
+
108
+
109
+ ############# Filvec from SimplexTree
110
+ # Input: a simplextree; outputs the histogram of the filtration values of its simplices
111
def simplextree2hist(simplextree, range:tuple[float, float], bins:int, density:bool)->np.ndarray: #TODO : Anything to histogram
    """Histogram the filtration values of every simplex of ``simplextree``.

    ``range``, ``bins`` and ``density`` are forwarded to ``np.histogram``;
    only the histogram values (not the bin edges) are returned.
    """
    values = np.array([filtration for _simplex, filtration in simplextree.get_simplices()])
    histogram, _edges = np.histogram(values, bins=bins, range=range, density=density)
    return histogram
114
class SimplexTree2Histogram(BaseEstimator, TransformerMixin):
    """Histogram the filtration values of simplex trees.

    The input ``X`` is either a sequence of ``gd.SimplexTree`` or a pair
    ``(to_st, data)`` where ``to_st(x)`` builds the simplex tree of each ``x``
    in ``data``.

    Parameters
    ----------
    quantile:
        The histogram range learned in ``fit`` is the
        ``[quantile, 1 - quantile]`` quantile interval of the finite
        filtration values.
    bins, density:
        Forwarded to ``np.histogram``.
    n_jobs:
        Forwarded to ``joblib.Parallel`` for the ``(to_st, data)`` input form.
    progress:
        Show a tqdm progress bar.
    """

    def __init__(self, quantile:float=0., bins:int=100, n_jobs:int=1, progress:bool=False, density:bool=True):
        super().__init__()
        # Learned in fit; None until then.
        self.range:np.ndarray | None=None
        self.quantile:float=quantile
        self.bins:int=bins
        self.n_jobs=n_jobs
        self.density=density
        self.progress = progress
        # self.max_dimension=None # TODO: maybe use it

    def fit(self, X, y=None):
        """Learn the histogram range; returns ``self`` (unchanged if ``X`` is empty)."""
        if len(X) == 0:
            return self
        if isinstance(X[0], gd.SimplexTree):
            simplextrees = X
        else:
            # Delayed form: X is (to_st, data).
            to_st, data = X
            simplextrees = (to_st(x) for x in data)
        persistence_values = np.array(
            [f for st in simplextrees for _, f in st.get_simplices()]
        )
        # Infinite values would make the quantiles meaningless.
        persistence_values = persistence_values[persistence_values<np.inf]
        self.range = np.quantile(persistence_values, [self.quantile, 1-self.quantile])
        return self

    def transform(self,X):
        """Return one histogram per simplex tree (an empty list for empty input)."""
        if len(X) == 0:
            # The original returned `self` here, i.e. the estimator itself
            # instead of a (possibly empty) collection of histograms.
            return []
        if isinstance(X[0], gd.SimplexTree):
            if self.n_jobs > 1:
                warn("Cannot pickle simplextrees, reducing to 1 thread to compute the simplextrees")
            return [
                simplextree2hist(g,range=self.range, bins=self.bins, density=self.density)
                for g in tqdm(X, desc="Computing diagrams", disable=not self.progress)
            ]
        # Delayed form: X is (to_st, data); trees are built inside the workers.
        to_st, data = X
        def pickle_able_todo(x, **kwargs):
            simplextree = to_st(x)
            return simplextree2hist(simplextree=simplextree, **kwargs)
        return Parallel(n_jobs=self.n_jobs)(
            delayed(pickle_able_todo)(g,range=self.range, bins=self.bins, density=self.density)
            for g in tqdm(data, desc="Computing simplextrees and their diagrams", disable=not self.progress)
        )
148
+
149
+
150
+
151
+
152
+ ############# PERVEC
153
+ # Input: list of [list of diagrams]; outputs a histogram of persistence values (x and y coordinates mixed)
154
def dgm2pervec(dgms, range:tuple[float, float], bins:int)->np.ndarray: #TODO : Anything to histogram
    """Histogram all coordinates of the diagrams in ``dgms`` together.

    Every diagram is flattened, so first and second coordinates are mixed.
    Only the histogram counts (not the bin edges) are returned.
    """
    flattened = []
    for dgm in dgms:
        flattened.append(dgm.flatten())
    all_values = np.concatenate(flattened).flatten()
    counts, _edges = np.histogram(all_values, bins=bins, range=range)
    return counts
157
class Dgm2Histogram(BaseEstimator, TransformerMixin):
    """Histogram the coordinates of lists of diagrams.

    ``fit`` learns the histogram range as the ``[quantile, 1 - quantile]``
    quantile interval of all values below ``np.inf``; ``transform`` returns
    one histogram per list of diagrams, computed by ``dgm2pervec``.
    """

    def __init__(self, quantile:float=0., bins:int=100, n_jobs:int=1):
        super().__init__()
        # Learned in fit; None until then.
        self.range:np.ndarray | None=None
        self.quantile:float=quantile
        self.bins:int=bins
        self.n_jobs=n_jobs

    def fit(self, X, y=None): # X:list[diagrams]
        """Learn the histogram range from every diagram in ``X``; returns ``self``."""
        pieces = [dgm.flatten() for dgms in X for dgm in dgms]
        values = np.concatenate(pieces, axis=0).flatten()
        bounded = values[values<np.inf]
        low = self.quantile
        self.range = np.quantile(bounded, [low, 1-low])
        return self

    def transform(self,X):
        """Return the histogram of every element of ``X``."""
        runner = Parallel(n_jobs=self.n_jobs)
        return runner(
            delayed(dgm2pervec)(g,range=self.range, bins=self.bins) for g in X
        )
171
+
172
+
173
+
174
+
175
+
176
+
177
+
178
+ ################# SignedMeasureImage
179
class Dgms2SignedMeasureImage(BaseEstimator, TransformerMixin):
    """Vectorize lists of diagrams with 1D kernel density estimates.

    For every degree, a ``KernelDensity`` is fitted on the first column of the
    diagram and another one on the second column; the difference of their
    ``score_samples`` outputs on a per-degree grid is the image of that
    degree. The images of all degrees are concatenated.

    Parameters
    ----------
    ranges:
        Per-degree evaluation grids. Recomputed by ``fit``.
    resolution:
        Number of grid points per degree.
    quantile:
        The grid of a degree spans the ``[quantile, 1 - quantile]`` quantile
        interval of its finite values.
    bandwidth, kernel:
        Forwarded to ``sklearn.neighbors.KernelDensity``.
    """

    def __init__(self, ranges:None|Iterable[Iterable[float]]=None, resolution:int=100, quantile:float=0, bandwidth:float=1, kernel:str="gaussian") -> None:
        super().__init__()
        self.ranges=ranges
        self.resolution=resolution
        self.quantile = quantile
        self.bandwidth = bandwidth
        self.kernel = kernel

    def fit(self, X, y=None): # X:list[diagrams]
        """Compute the per-degree evaluation grids from ``X``; returns ``self``."""
        num_degrees = len(X[0])
        # All coordinates of a given degree, over the whole dataset.
        persistence_values = [
            np.concatenate([dgms[i].flatten() for dgms in X], axis=0)
            for i in range(num_degrees)
        ]
        # Keep finite values only.
        persistence_values = [
            degree_values[(-np.inf<degree_values) & (degree_values<np.inf)]
            for degree_values in persistence_values
        ]
        quantiles = [
            np.quantile(degree_values, [self.quantile, 1-self.quantile])
            for degree_values in persistence_values
        ]
        # start/stop are 1-element lists so that each grid is a column of
        # samples, the layout passed to score_samples in _dgm2smi.
        self.ranges = np.array([
            np.linspace(start=[a], stop=[b], num=self.resolution) for a,b in quantiles
        ])
        return self

    def _dgm2smi(self, dgms:Iterable[np.ndarray]):
        """Return the concatenated signed images of one list of diagrams."""
        images = []
        for dgm, grid in zip(dgms, self.ranges):
            births = KernelDensity(bandwidth=self.bandwidth, kernel=self.kernel).fit(dgm[:,[0]])
            # The original omitted `kernel` here, so deaths always used the
            # default kernel even when another one was configured.
            deaths = KernelDensity(bandwidth=self.bandwidth, kernel=self.kernel).fit(dgm[:,[1]])
            images.append(births.score_samples(grid) - deaths.score_samples(grid))
        return np.concatenate(images, axis=0)

    def transform(self,X): # X is a list (data) of list of diagrams
        """Return one signed image per element of ``X``; requires ``ranges`` to be set."""
        assert self.ranges is not None
        out = Parallel(n_jobs=1, prefer="threads")(
            delayed(self._dgm2smi)(dgms) for dgms in X
        )
        return out
213
+
214
+
215
+
216
+ ################# SignedMeasureHistogram
217
class Dgms2SignedMeasureHistogram(BaseEstimator, TransformerMixin):
    """Vectorize lists of diagrams as signed histograms.

    For every degree, the histogram of the second column of the diagram is
    subtracted from the histogram of its first column; the results of all
    degrees are concatenated. ``fit`` sets the per-degree histogram range to
    the ``[quantile, 1 - quantile]`` quantile interval of the finite values.
    """

    def __init__(self, ranges:None|list[tuple[float,float]]=None, bins:int=100, quantile:float=0) -> None:
        super().__init__()
        self.ranges=ranges
        self.bins=bins
        self.quantile = quantile

    def fit(self, X, y=None): # X:list[diagrams]
        """Compute the per-degree histogram ranges from ``X``; returns ``self``."""
        num_degrees = len(X[0])
        learned_ranges = []
        for degree in range(num_degrees):
            # Every coordinate of this degree over the whole dataset ...
            values = np.concatenate([dgms[degree].flatten() for dgms in X], axis=0)
            # ... restricted to the finite ones.
            finite = values[(-np.inf<values) * (values<np.inf)]
            learned_ranges.append(np.quantile(finite, [self.quantile, 1-self.quantile]))
        self.ranges = learned_ranges
        return self

    def transform(self,X): # X is a list (data) of list of diagrams
        """Return one concatenated signed histogram per element of ``X``."""
        assert self.ranges is not None

        def signed_histogram(dgm, bounds):
            first_column, _ = np.histogram(dgm[:,0], bins=self.bins,range=bounds)
            second_column, _ = np.histogram(dgm[:,1], bins=self.bins,range=bounds)
            return first_column - second_column

        out = []
        for dgms in X:
            per_degree = [
                signed_histogram(dgm, bounds) for dgm, bounds in zip(dgms, self.ranges)
            ]
            out.append(np.concatenate(per_degree))
        return out
238
+
239
+
240
+
241
+
242
+
243
+
244
+
245
+
246
+ ################## Signed Measure Kernel 1D
247
+ # input : list of [list of diagrams], outputs: the kernel to feed to an svm
248
+
249
+ # TODO : optimize ?
250
+ ## TODO : np.triu
251
class Dgms2SignedMeasureDistance(BaseEstimator, TransformerMixin):
    """Pairwise distances between lists of diagrams seen as signed measures.

    ``fit`` stores the training data; ``transform`` returns, for every degree,
    the matrix of distances between each element of its input and each
    training element.

    Parameters
    ----------
    n_jobs:
        Forwarded to ``joblib.Parallel``.
    distance_matrix_path:
        When set, the matrix of degree ``d`` is read from
        ``f"{distance_matrix_path}_{d}"`` if that file exists, and computed
        then written there otherwise.
    progress:
        Show a tqdm progress bar while computing.
    """

    def __init__(self, n_jobs:int=1, distance_matrix_path:str|None=None, progress:bool = False) -> None:
        super().__init__()
        self.degrees:list[int]|None=None
        self.X:None|list[np.ndarray] = None
        self.n_jobs=n_jobs
        self.distance_matrix_path = distance_matrix_path
        self.progress=progress

    def fit(self, X:list[np.ndarray], y=None):
        """Store ``X`` as the reference dataset; returns ``self``.

        Empty input only triggers a warning and leaves the estimator unfitted.
        """
        if len(X) <= 0:
            warn("Fit a nontrivial vector")
            # Return self (not None) so that chained calls keep working.
            return self
        self.X = X
        self.degrees = list(range(len(X[0]))) # Assumes that all x \in X have the same number of diagrams
        return self

    @staticmethod
    def wasserstein_1(a:np.ndarray,b:np.ndarray)->float:
        """Mean absolute difference between the sorted values of ``a`` and ``b``."""
        return np.abs(np.sort(a) - np.sort(b)).mean() # norm 1

    @staticmethod
    def OSWdistance(mu:list[np.ndarray], nu:list[np.ndarray], dim:int)->float:
        """``wasserstein_1`` between the mixed columns of ``mu[dim]`` and ``nu[dim]``.

        The first column of one diagram is stacked with the second column of
        the other one, and conversely.
        """
        # TODO : check: do we want to sum the kernels or the distances ? add weights ?
        left = np.hstack([mu[dim][:,0], nu[dim][:,1]])
        right = np.hstack([nu[dim][:,0], mu[dim][:,1]])
        return Dgms2SignedMeasureDistance.wasserstein_1(left, right)

    @staticmethod
    def _ds(mu:list[np.ndarray], nus:list[list[np.ndarray]], dim:int): # mu and nu are lists of diagrams seen as signed measures (birth = +, death = -)
        """Distances from ``mu`` to every element of ``nus`` in degree ``dim``."""
        return [Dgms2SignedMeasureDistance.OSWdistance(mu,nu, dim) for nu in nus]

    def _compute_distance_matrix(self, X, degree:int, **parallel_kwargs)->np.ndarray:
        """Compute the distance matrix of one degree between ``X`` and the fitted data."""
        # `progress` was ignored by the original code, which always showed the bar.
        with tqdm(X, desc=f"Computing distance matrix of degree {degree}", disable=not self.progress) as diagrams_iterator:
            rows = Parallel(n_jobs=self.n_jobs, **parallel_kwargs)(
                delayed(self._ds)(mu, self.X, degree) for mu in diagrams_iterator
            )
        return np.array(rows)

    def transform(self,X): # X is a list (data) of list of diagrams
        """Return the array of per-degree distance matrices.

        An unfitted estimator warns and returns ``np.array([[]])``.
        """
        if self.X is None or self.degrees is None:
            warn("Fit first !")
            return np.array([[]])
        # Cannot use sklearn / scipy, measures don't have the same size, -> no numpy array
        distances_matrices = []
        if self.distance_matrix_path is None:
            for degree in self.degrees:
                distances_matrices.append(
                    self._compute_distance_matrix(X, degree, prefer="threads")
                )
            return np.asarray(distances_matrices)

        for degree in self.degrees:
            matrix_path = f"{self.distance_matrix_path}_{degree}"
            if exists(matrix_path):
                # Context managers close the handles the original left open.
                with open(matrix_path, "rb") as matrix_file:
                    distance_matrix = np.load(matrix_file)
            else:
                distance_matrix = self._compute_distance_matrix(X, degree)
                with open(matrix_path, "wb") as matrix_file:
                    np.save(matrix_file, distance_matrix)
            distances_matrices.append(distance_matrix)
        return np.asarray(distances_matrices)
302
+ # kernels = [np.exp(-distance_matrix / (2*self.sigma**2)) for distance_matrix in distances_matrices]
303
+ # return np.sum(kernels, axis=0)
304
+
305
+
306
+
307
+
308
+
309
+ ## Wrapper for SW, in order to take as an input a list of (list of diagrams)
310
class Dgms2SWK(BaseEstimator, TransformerMixin):
    """Sliced Wasserstein kernel on lists of diagrams.

    One ``SlicedWassersteinDistance`` is fitted per degree. ``transform``
    turns each per-degree distance matrix into
    ``exp(-distance / (2 * bandwidth**2))`` and sums these kernels over the
    degrees.

    Parameters
    ----------
    num_directions, n_jobs:
        Forwarded to ``SlicedWassersteinDistance``.
    bandwidth:
        Bandwidth of the exponential applied to the distances.
    distance_matrix_path:
        When set, the distance matrix of degree ``i`` is read from
        ``f"{distance_matrix_path}_{i}"`` if that file exists, and computed
        then written there otherwise.
    progress:
        Stored but not used by the code of this class.
    """

    def __init__(self, num_directions:int=10, bandwidth:float=1.0, n_jobs:int=1, distance_matrix_path:str|None = None, progress:bool = False) -> None:
        super().__init__()
        self.num_directions:int=num_directions
        self.bandwidth:float = bandwidth
        self.n_jobs=n_jobs
        self.SW_:list = []
        self.distance_matrix_path = distance_matrix_path
        self.progress = progress

    def fit(self, X:list[list[np.ndarray]], y=None):
        """Fit one sliced Wasserstein distance per degree; returns ``self``."""
        # Assumes that all x \in X have the same size
        self.SW_ = [
            SlicedWassersteinDistance(num_directions=self.num_directions, n_jobs = self.n_jobs) for _ in range(len(X[0]))
        ]
        for i, sw in enumerate(self.SW_):
            self.SW_[i]=sw.fit([dgms[i] for dgms in X]) # TODO : check : Not sure copy is necessary here
        return self

    def transform(self,X)->np.ndarray:
        """Return the sum over degrees of the kernels between ``X`` and the fitted data."""
        if self.distance_matrix_path is None:
            distance_matrices = [
                sw.transform([dgms[i] for dgms in X]) for i, sw in enumerate(self.SW_)
            ]
        else:
            distance_matrices = []
            for i, sw in enumerate(self.SW_):
                SW_i_path = f"{self.distance_matrix_path}_{i}"
                if exists(SW_i_path):
                    # Context managers close the handles the original left open.
                    with open(SW_i_path, "rb") as matrix_file:
                        distance_matrix = np.load(matrix_file)
                else:
                    distance_matrix = sw.transform([dgms[i] for dgms in X])
                    with open(SW_i_path, "wb") as matrix_file:
                        np.save(matrix_file, distance_matrix)
                # The original only appended matrices loaded from disk, so a
                # freshly computed degree was saved but left out of the kernel.
                distance_matrices.append(distance_matrix)
        kernels = [np.exp(-distance_matrix / (2*self.bandwidth**2)) for distance_matrix in distance_matrices]
        return np.sum(kernels, axis=0) # TODO fix this, we may want to sum the distances instead of the kernels.
341
+
342
+
343
class Dgms2SlicedWassersteinDistanceMatrices(BaseEstimator, TransformerMixin):
    """Per-degree sliced Wasserstein distance matrices of lists of diagrams.

    One ``SlicedWassersteinDistance`` is fitted per degree; ``transform``
    stacks the matrices returned by each of them into one array.
    """

    def __init__(self, num_directions:int=10, n_jobs:int=1) -> None:
        super().__init__()
        self.num_directions:int=num_directions
        self.n_jobs=n_jobs
        # One fitted distance object per degree, filled in by fit.
        self.SW_:list = []

    def fit(self, X:list[list[np.ndarray]], y=None):
        """Fit one distance object per degree; returns ``self``."""
        # Assumes that all x \in X have the same size
        num_degrees = len(X[0])
        fitted = []
        for degree in range(num_degrees):
            swd = SlicedWassersteinDistance(num_directions=self.num_directions, n_jobs = self.n_jobs)
            fitted.append(swd)
        self.SW_ = fitted
        for degree, swd in enumerate(self.SW_):
            self.SW_[degree] = swd.fit([dgms[degree] for dgms in X]) # TODO : check : Not sure copy is necessary here
        return self

    @staticmethod
    def _get_distance(diagrams, SWD):
        """Apply the fitted distance object ``SWD`` to ``diagrams``."""
        return SWD.transform(diagrams)

    def transform(self,X):
        """Return the array of distance matrices, one per degree."""
        jobs = (
            delayed(self._get_distance)([dgms[degree] for dgms in X], swd)
            for degree, swd in enumerate(self.SW_)
        )
        distance_matrices = Parallel(n_jobs = self.n_jobs)(jobs)
        return np.asarray(distance_matrices)
364
+
365
+
366
+
367
+ # Gudhi simplexTree to list of diagrams
368
class SimplexTree2Dgm(BaseEstimator, TransformerMixin):
    """Compute persistence diagrams from simplex trees.

    The input of ``transform`` is either a sequence of ``gd.SimplexTree`` or a
    pair ``(to_st, data)`` where ``to_st(x)`` builds the simplex tree of each
    ``x`` in ``data``.

    Parameters
    ----------
    degrees:
        Homological degrees to compute. Defaults to ``[0]``, or to the degrees
        covered by ``extended`` when extended persistence is requested.
    extended:
        Falsy for standard persistence. A list of ints selects which extended
        diagrams to keep (4 diagrams per dimension, e.g. ``[0,2,5,7]`` is
        Ord0, Ext+0, Rel1, Ext-1); any other truthy value means ``[0,2,5,7]``.
    n_jobs:
        Forwarded to ``joblib.Parallel`` (thread backend).
    progress:
        Show a tqdm progress bar.
    threshold:
        Diagram values are clipped to ``[-threshold, threshold]`` unless it is
        ``np.inf``. A value ``<= 0`` is replaced during ``fit`` by the largest
        absolute filtration value of the training data.
    """

    def __init__(self, degrees:list[int]|None = None, extended:list[int]|bool=[], n_jobs=1, progress:bool=False, threshold:float=np.inf) -> None:
        super().__init__()
        # Extended persistence selection (False when disabled).
        if not extended:
            self.extended:list[int]|bool = False
        elif type(extended) is list:
            self.extended = extended
        else:
            self.extended = [0,2,5,7]
        # Homological degrees.
        if degrees:
            self.degrees:list[int] = degrees
        elif self.extended:
            self.degrees = list(range((max(self.extended) // 4)+1))
        else:
            self.degrees = [0]
        self.n_jobs=n_jobs
        self.progress = progress # progress bar
        self.threshold = threshold # Threshold value
        return

    def fit(self, X:list[gd.SimplexTree], y=None):
        """Resolve a non-positive ``threshold`` from the data; returns ``self``."""
        if self.threshold <= 0:
            self.threshold = max( (abs(f) for simplextree in get_simplextrees(X) for s,f in simplextree.get_simplices()) ) ## MAX FILTRATION VALUE
            print(f"Setting threshold to {self.threshold}.")
        return self

    def transform(self,X:list[gd.SimplexTree]):
        """Return, for every simplex tree, its list of diagrams.

        Empty input yields an empty list.
        """
        if len(X) == 0:
            # The original indexed X[0] unconditionally and raised IndexError.
            return []

        def reshape(dgm:np.ndarray|list)->np.ndarray:
            # Empty diagrams get an explicit (0, 2) shape.
            out = np.array(dgm) if len(dgm) > 0 else np.empty((0,2))
            if self.threshold != np.inf:
                out[out>self.threshold] = self.threshold
                out[out<-self.threshold] = -self.threshold
            return out

        def todo_standard(st):
            st.compute_persistence()
            return [reshape(st.persistence_intervals_in_dimension(d)) for d in self.degrees]

        def todo_extended(st):
            st.extend_filtration()
            dgms = st.extended_persistence()
            # A single diagram gathers the bars of every selected (type, degree).
            return [reshape([bar for j,dgm in enumerate(dgms) for d, bar in dgm if d in self.degrees and j+4*d in self.extended])]

        todo = todo_extended if self.extended else todo_standard
        runner = Parallel(n_jobs=self.n_jobs, prefer="threads")

        if isinstance(X[0],gd.SimplexTree): # simplextree aren't pickleable, hence the thread backend
            return runner(
                delayed(todo)(x)
                for x in tqdm(X, disable=not self.progress, desc="Computing diagrams")
            )

        # Delayed form: X is (to_st, data).
        to_st = X[0]
        dataset = X[1]
        def pickleable_todo(x):
            return todo(to_st(x))
        return runner(
            delayed(pickleable_todo)(x)
            for x in tqdm(dataset, disable=not self.progress, desc="Computing simplextrees and diagrams")
        )
412
+
413
+ # Shuffles a diagram shaped array. Input : list of (list of diagrams), output, list of (list of shuffled diagrams)
414
class DiagramShuffle(BaseEstimator, TransformerMixin):
    """Randomly permute the entries of every diagram.

    Input: list of (list of diagrams); output: list of (list of shuffled
    diagrams). Each diagram is flattened, shuffled with ``np.random.shuffle``
    and put back into its original shape.
    """

    def __init__(self) -> None:
        super().__init__()

    def fit(self, X:list[list[np.ndarray]], y=None):
        """Nothing to learn; returns ``self``."""
        return self

    def transform(self,X:list[list[np.ndarray]]):
        """Return shuffled copies of the diagrams of ``X``."""
        def shuffle(dgm):
            original_shape = dgm.shape
            # flatten() copies, so the input diagram is left untouched.
            entries = dgm.flatten()
            np.random.shuffle(entries)
            return entries.reshape(original_shape)

        shuffled = []
        for dgms in X:
            shuffled.append([shuffle(dgm) for dgm in dgms])
        return shuffled
430
+
431
+
432
class Dgms2Landscapes(BaseEstimator, TransformerMixin):
    """Vectorize lists of diagrams with one ``Landscape`` per degree.

    ``num`` and ``resolution`` are passed to ``Landscape`` as
    ``num_landscapes`` and ``resolution``. The per-degree outputs are
    concatenated along axis 1. ``n_jobs`` is stored but not used by the code
    of this class.
    """

    def __init__(self, num:int=5, resolution:int=100, n_jobs:int=1) -> None:
        super().__init__()
        self.degrees:list[int] = []
        self.num:int= num
        self.resolution:int = resolution
        # One fitted Landscape per degree, filled in by fit.
        self.landscapes:list[Landscape]= []
        self.n_jobs=n_jobs

    def fit(self, X, y=None):
        """Fit one ``Landscape`` per degree; returns ``self`` (unchanged if ``X`` is empty)."""
        if len(X) == 0:
            return self
        self.degrees = list(range(len(X[0])))
        self.landscapes = []
        for dim in self.degrees:
            diagrams_of_degree = [dgms[dim] for dgms in X]
            vectorizer = Landscape(num_landscapes=self.num,resolution=self.resolution)
            self.landscapes.append(vectorizer.fit(diagrams_of_degree))
        return self

    def transform(self,X):
        """Return the concatenated landscapes of ``X`` (an empty list for empty input)."""
        if len(X) == 0:
            return []
        per_degree = []
        for degree, landscape in enumerate(self.landscapes):
            per_degree.append(landscape.transform([dgms[degree] for dgms in X]))
        return np.concatenate(per_degree, axis=1)
451
+
452
class Dgms2Image(BaseEstimator, TransformerMixin):
    """Vectorize lists of diagrams with one ``PersistenceImage`` per degree.

    ``bandwidth`` and ``resolution`` are forwarded to ``PersistenceImage``.
    The per-degree outputs are concatenated along axis 1. ``n_jobs`` is stored
    but not used by the code of this class.
    """

    def __init__(self, bandwidth:float=1, resolution:tuple[int,int]=(20,20), n_jobs:int=1) -> None:
        super().__init__()
        self.degrees:list[int] = []
        self.bandwidth:float= bandwidth
        self.resolution = resolution
        # One fitted PersistenceImage per degree, filled in by fit.
        self.PI:list[PersistenceImage]= []
        self.n_jobs=n_jobs

    def fit(self, X, y=None):
        """Fit one ``PersistenceImage`` per degree; returns ``self`` (unchanged if ``X`` is empty)."""
        if len(X) == 0:
            return self
        self.degrees = list(range(len(X[0])))
        self.PI = []
        for dim in self.degrees:
            diagrams_of_degree = [dgms[dim] for dgms in X]
            vectorizer = PersistenceImage(bandwidth=self.bandwidth,resolution=self.resolution)
            self.PI.append(vectorizer.fit(diagrams_of_degree))
        return self

    def transform(self,X):
        """Return the concatenated persistence images of ``X`` (an empty list for empty input)."""
        if len(X) == 0:
            return []
        per_degree = []
        for degree, pers_image in enumerate(self.PI):
            per_degree.append(pers_image.transform([dgms[degree] for dgms in X]))
        return np.concatenate(per_degree, axis=1)
471
+
472
+