multipers 1.1.3__cp311-cp311-macosx_11_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of multipers might be problematic. Click here for more details.
- multipers/.dylibs/libtbb.12.12.dylib +0 -0
- multipers/.dylibs/libtbbmalloc.2.12.dylib +0 -0
- multipers/__init__.py +5 -0
- multipers/_old_rank_invariant.pyx +328 -0
- multipers/_signed_measure_meta.py +193 -0
- multipers/data/MOL2.py +350 -0
- multipers/data/UCR.py +18 -0
- multipers/data/__init__.py +1 -0
- multipers/data/graphs.py +466 -0
- multipers/data/immuno_regions.py +27 -0
- multipers/data/minimal_presentation_to_st_bf.py +0 -0
- multipers/data/pytorch2simplextree.py +91 -0
- multipers/data/shape3d.py +101 -0
- multipers/data/synthetic.py +68 -0
- multipers/distances.py +172 -0
- multipers/euler_characteristic.cpython-311-darwin.so +0 -0
- multipers/euler_characteristic.pyx +137 -0
- multipers/function_rips.cpython-311-darwin.so +0 -0
- multipers/function_rips.pyx +102 -0
- multipers/hilbert_function.cpython-311-darwin.so +0 -0
- multipers/hilbert_function.pyi +46 -0
- multipers/hilbert_function.pyx +151 -0
- multipers/io.cpython-311-darwin.so +0 -0
- multipers/io.pyx +176 -0
- multipers/ml/__init__.py +0 -0
- multipers/ml/accuracies.py +61 -0
- multipers/ml/convolutions.py +510 -0
- multipers/ml/invariants_with_persistable.py +79 -0
- multipers/ml/kernels.py +128 -0
- multipers/ml/mma.py +657 -0
- multipers/ml/one.py +472 -0
- multipers/ml/point_clouds.py +191 -0
- multipers/ml/signed_betti.py +50 -0
- multipers/ml/signed_measures.py +1479 -0
- multipers/ml/sliced_wasserstein.py +313 -0
- multipers/ml/tools.py +116 -0
- multipers/mma_structures.cpython-311-darwin.so +0 -0
- multipers/mma_structures.pxd +155 -0
- multipers/mma_structures.pyx +651 -0
- multipers/multiparameter_edge_collapse.py +29 -0
- multipers/multiparameter_module_approximation.cpython-311-darwin.so +0 -0
- multipers/multiparameter_module_approximation.pyi +439 -0
- multipers/multiparameter_module_approximation.pyx +311 -0
- multipers/pickle.py +53 -0
- multipers/plots.py +292 -0
- multipers/point_measure_integration.cpython-311-darwin.so +0 -0
- multipers/point_measure_integration.pyx +59 -0
- multipers/rank_invariant.cpython-311-darwin.so +0 -0
- multipers/rank_invariant.pyx +154 -0
- multipers/simplex_tree_multi.cpython-311-darwin.so +0 -0
- multipers/simplex_tree_multi.pxd +121 -0
- multipers/simplex_tree_multi.pyi +715 -0
- multipers/simplex_tree_multi.pyx +1417 -0
- multipers/slicer.cpython-311-darwin.so +0 -0
- multipers/slicer.pxd +94 -0
- multipers/slicer.pyx +276 -0
- multipers/tensor.pxd +13 -0
- multipers/test.pyx +44 -0
- multipers-1.1.3.dist-info/LICENSE +21 -0
- multipers-1.1.3.dist-info/METADATA +22 -0
- multipers-1.1.3.dist-info/RECORD +63 -0
- multipers-1.1.3.dist-info/WHEEL +5 -0
- multipers-1.1.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,311 @@
|
|
|
1
|
+
"""!
|
|
2
|
+
@package mma
|
|
3
|
+
@brief Files containing the C++ cythonized functions.
|
|
4
|
+
@author David Loiseaux
|
|
5
|
+
@copyright Copyright (c) 2022 Inria.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
# distutils: language = c++
|
|
9
|
+
|
|
10
|
+
###########################################################################
|
|
11
|
+
## PYTHON LIBRARIES
|
|
12
|
+
import gudhi as gd
|
|
13
|
+
import numpy as np
|
|
14
|
+
from typing import List
|
|
15
|
+
import pickle as pk
|
|
16
|
+
|
|
17
|
+
###########################################################################
|
|
18
|
+
## CPP CLASSES
|
|
19
|
+
from libc.stdint cimport intptr_t
|
|
20
|
+
from libc.stdint cimport uintptr_t
|
|
21
|
+
|
|
22
|
+
###########################################################################
|
|
23
|
+
## CYTHON TYPES
|
|
24
|
+
from libcpp.vector cimport vector
|
|
25
|
+
from libcpp.utility cimport pair
|
|
26
|
+
#from libcpp.list cimport list as clist
|
|
27
|
+
from libcpp cimport bool
|
|
28
|
+
from libcpp cimport int
|
|
29
|
+
from typing import Iterable
|
|
30
|
+
from cython.operator import dereference
|
|
31
|
+
#########################################################################
|
|
32
|
+
## Multipersistence Module Approximation Classes
|
|
33
|
+
from multipers.mma_structures cimport Module, Box, pair, boundary_matrix
|
|
34
|
+
from multipers.slicer cimport *
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
#########################################################################
|
|
38
|
+
## Small hack for typing
|
|
39
|
+
from gudhi import SimplexTree
|
|
40
|
+
from multipers.simplex_tree_multi import SimplexTreeMulti
|
|
41
|
+
from multipers.slicer import Slicer, SlicerClement,SlicerVineGraph,SlicerVineSimplcial
|
|
42
|
+
from multipers.mma_structures import PyModule
|
|
43
|
+
# cimport numpy as cnp
|
|
44
|
+
# cnp.import_array()
|
|
45
|
+
|
|
46
|
+
###################################### MMA
|
|
47
|
+
cdef extern from "multiparameter_module_approximation/approximation.h" namespace "Gudhi::multiparameter::mma":
|
|
48
|
+
Module compute_vineyard_barcode_approximation(boundary_matrix, vector[Finitely_critical_multi_filtration] , value_type precision, Box[value_type] &, bool threshold, bool complete, bool multithread, bool verbose) nogil
|
|
49
|
+
# TODO : tempita
|
|
50
|
+
Module multiparameter_module_approximation(SimplicialVineGraphTruc&, value_type, Box[value_type]&, bool, bool, bool) nogil
|
|
51
|
+
|
|
52
|
+
Module multiparameter_module_approximation(SimplicialVineMatrixTruc&, value_type, Box[value_type]&, bool, bool, bool) nogil
|
|
53
|
+
Module multiparameter_module_approximation(GeneralVineTruc&, value_type, Box[value_type]&, bool, bool, bool) nogil
|
|
54
|
+
Module multiparameter_module_approximation(GeneralVineClementTruc&, value_type, Box[value_type]&, bool, bool, bool) nogil
|
|
55
|
+
|
|
56
|
+
# TODO : remove when old is deprecated
|
|
57
|
+
cdef extern from "multiparameter_module_approximation/format_python-cpp.h" namespace "Gudhi::multiparameter::mma":
|
|
58
|
+
pair[boundary_matrix, vector[Finitely_critical_multi_filtration]] simplextree_to_boundary_filtration(uintptr_t)
|
|
59
|
+
|
|
60
|
+
def module_approximation_old(
    st:SimplexTreeMulti|None=None,
    max_error:float|None = None,
    box:list|np.ndarray|None = None,
    threshold:bool = False,
    complete:bool = True,
    multithread:bool = False,
    verbose:bool = False,
    ignore_warning:bool = False,
    nlines:int = 500,
    max_dimension=np.inf,
    boundary = None,
    filtration = None,
    return_timings:bool = False,
    **kwargs
    ):
    """Computes an interval module approximation of a multiparameter filtration
    (legacy boundary-matrix backend).

    Parameters
    ----------
    st : n-filtered Simplextree, or None if boundary and filtration are provided.
        Defines the n-filtration on which to compute the homology.
    max_error: positive float, or None
        Trade-off between approximation and computational complexity.
        Upper bound of the module approximation, in bottleneck distance,
        for interval-decomposable modules.
        When None, it is derived from `nlines` and the size of the box.
    nlines: int
        Alternative to `max_error`; only used when `max_error` is None.
    box : pair of list of floats
        Defines a rectangle on which to compute the approximation.
        Format : [x,y], where x,y defines the rectangle {z : x ≤ z ≤ y}
        When None, the filtration bounds enlarged by 2*max_error are used.
    threshold: bool
        When true, intersects the module support with the box.
    complete, multithread: bool
        Forwarded unchanged to the C++ routine.
    verbose: bool
        Prints C++ infos.
    ignore_warning : bool
        Unless set to true, a size estimate of 20 000 or more emits a warning
        and returns the trivial module instead of computing.
        Useful to prevent a segmentation fault due to "infinite" recursion.
    max_dimension :
        When finite, simplices of dimension above this value are removed
        (in place) from `boundary` and `filtration` before computing.
    boundary, filtration :
        Precomputed boundary matrix and filtration values (one sequence per
        parameter). Recomputed from `st` when either is None.
    return_timings : bool
        If true, will return the time to compute instead, in seconds
        (computed in python, using perf_counter_ns).
    **kwargs :
        Accepted and ignored.

    Returns
    -------
    PyModule
        An interval decomposable module approximation of the module defined by the
        homology of this multi-filtration.
        For a 1-parameter `st`, the result of `st.project_on_line(0).persistence()`
        is returned instead; with `return_timings`, a float is returned.
    """

    if boundary is None or filtration is None:
        from multipers.io import simplex_tree2boundary_filtrations
        boundary,filtration = simplex_tree2boundary_filtrations(st) # TODO : recomputed each time... maybe store this somewhere ?
    if max_dimension < np.inf: # TODO : make it more efficient
        nsplx = len(boundary)
        # Walk backwards so that popping does not shift the indices still to visit.
        for i in range(nsplx-1,-1,-1):
            b = boundary[i]
            dim=len(b) -1
            if dim>max_dimension:
                boundary.pop(i)
                for f in filtration:
                    f.pop(i)
    nfiltration = len(filtration)
    if nfiltration <= 0:
        return PyModule()
    if nfiltration == 1 and not(st is None):
        st = st.project_on_line(0)
        return st.persistence()

    if box is None and not(st is None):
        m,M = st.filtration_bounds()
    elif box is not None:
        m,M = np.asarray(box)
    else:
        # `filtration` holds one sequence per parameter (see len(filtration)
        # above and the box computation below), so bounds are per sequence.
        # The previous `axis=0` reduction produced one bound per simplex.
        m = [np.min(f) for f in filtration]
        M = [np.max(f) for f in filtration]
    # Size estimate compared against the 20_000 threshold below.
    prod = 1
    h = M[-1] - m[-1]
    for i, [a,b] in enumerate(zip(m,M)):
        if i == len(M)-1: continue
        prod *= (b-a + h)

    if max_error is None:
        max_error:float = (prod/nlines)**(1/(nfiltration-1))

    if box is None:
        M = [np.max(f)+2*max_error for f in filtration]
        m = [np.min(f)-2*max_error for f in filtration]
        box = [m,M]

    # The guard must trigger when the caller has NOT opted out of the warning
    # (it was inverted, unlike the same check in `module_approximation`).
    if not ignore_warning and prod >= 20_000:
        from warnings import warn
        warn(f"Warning : the number of lines (around {np.round(prod)}) may be too high. Try to increase the precision parameter, or set `ignore_warning=True` to compute this module. Returning the trivial module.")
        return PyModule()

    approx_mod = PyModule()
    cdef vector[Finitely_critical_multi_filtration] c_filtration = Finitely_critical_multi_filtration.from_python(filtration)
    cdef boundary_matrix c_boundary = boundary
    cdef value_type c_max_error = max_error
    cdef bool c_threshold = threshold
    cdef bool c_complete = complete
    cdef bool c_multithread = multithread
    cdef bool c_verbose = verbose
    cdef Box[value_type] c_box = Box[value_type](box)
    if return_timings:
        from time import perf_counter_ns
        t = perf_counter_ns()
    with nogil:
        c_mod = compute_vineyard_barcode_approximation(c_boundary,c_filtration,c_max_error, c_box, c_threshold, c_complete, c_multithread,c_verbose)
    if return_timings:
        t = perf_counter_ns() -t
        t /= 10**9
        return t
    approx_mod._set_from_ptr(<intptr_t>(&c_mod))
    return approx_mod
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def module_approximation(
    st:SimplexTreeMulti,
    value_type max_error = -1,
    box:list|np.ndarray|None = None,
    bool threshold:bool = False,
    bool complete:bool = True,
    bool verbose:bool = False,
    bool ignore_warning:bool = False,
    int nlines = 500,
    backend:str="matrix",
    # max_dimension=np.inf,
    # return_timings:bool = False,
    **kwargs
    ):
    """Computes an interval module approximation of a multiparameter filtration.

    Parameters
    ----------
    st : n-filtered Simplextree
        Defines the n-filtration on which to compute the homology.
    max_error: positive float
        Trade-off between approximation and computational complexity.
        Upper bound of the module approximation, in bottleneck distance,
        for interval-decomposable modules.
        A value <= 0 (the default, -1) means "derive it from `nlines`".
    nlines: int
        Alternative to `max_error`; only used when `max_error` <= 0.
    box : pair of list of floats
        Defines a rectangle on which to compute the approximation.
        Format : [x,y], where x,y defines the rectangle {z : x ≤ z ≤ y}
        When None, `st.filtration_bounds()` is used.
    threshold: bool
        When true, intersects the module support with the box.
    complete: bool
        Forwarded unchanged to the C++ routine.
    verbose: bool
        Prints C++ infos.
    ignore_warning : bool
        Unless set to true, a size estimate of 20 000 or more emits a warning
        and returns the trivial module instead of computing.
        Useful to prevent a segmentation fault due to "infinite" recursion.
    backend : str
        "matrix" or "graph" select the slicer structure built from `st`;
        "old" delegates to `module_approximation_old`.
    **kwargs :
        Accepted and ignored.

    Returns
    -------
    PyModule
        An interval decomposable module approximation of the module defined by the
        homology of this multi-filtration.
        For a 1-parameter `st`, the result of `st.project_on_line(0).persistence()`
        is returned instead.

    Raises
    ------
    ValueError
        If `backend` is none of "matrix", "graph" or "old".
    """
    if backend == "old":
        # The legacy entry point uses None, not -1, for "no max_error given".
        if max_error == -1:
            max_error_=None
        else:
            max_error_=max_error
        return module_approximation_old(st, max_error=max_error_, box=box,threshold=threshold,complete=complete,verbose=verbose,ignore_warning=ignore_warning,nlines=nlines)



    # Reinterpret the address stored in `st.thisptr` as the C++ simplex tree.
    cdef intptr_t ptr = st.thisptr
    cdef Simplex_tree_multi_interface* st_ptr = <Simplex_tree_multi_interface*>(ptr)
    cdef SimplicialVineGraphTruc graphtruc# copy ?
    cdef SimplicialVineMatrixTruc matrixtruc# copy ?
    cdef GeneralVineTruc generaltruc

    cdef int num_parameters = st.num_parameters

    if num_parameters <= 0:
        return PyModule()
    # NOTE(review): `st` was already dereferenced above, so the None test is redundant here.
    if num_parameters == 1 and not(st is None):
        st = st.project_on_line(0)
        return st.persistence()

    if box is not None:
        m,M = np.asarray(box)
    else:
        m,M = st.filtration_bounds()
    box =np.asarray([m,M])
    # Size estimate compared against the 20_000 threshold below.
    prod = 1
    h = M[-1] - m[-1]
    for i, [a,b] in enumerate(zip(m,M)):
        if i == len(M)-1: continue
        prod *= (b-a + h)

    if max_error <= 0:
        max_error = (prod/nlines)**(1/(num_parameters-1))

    if not ignore_warning and prod >= 20_000:
        from warnings import warn
        warn(f"Warning : the number of lines (around {np.round(prod)}) may be too high. Try to increase the precision parameter, or set `ignore_warning=True` to compute this module. Returning the trivial module.")
        return PyModule()

    cdef Module mod
    cdef Box[value_type] c_box = Box[value_type](box)
    # Module multiparameter_module_approximation(Slicer &slicer, const value_type precision,
    #                                         Box<value_type> &box, const bool threshold,
    #                                         const bool complete, const bool verbose)
    if backend == "matrix":
        matrixtruc = SimplicialVineMatrixTruc(st_ptr)
        # The C++ computation runs with the GIL released.
        with nogil:
            mod = multiparameter_module_approximation(matrixtruc, max_error,c_box,threshold, complete, verbose)
    elif backend == "graph":
        graphtruc = SimplicialVineGraphTruc(st_ptr)
        with nogil:
            mod = multiparameter_module_approximation(graphtruc, max_error,c_box,threshold, complete, verbose)
    else:
        raise ValueError("Invalid backend.")

    # Hand the address of the C++ result to the Python wrapper.
    approx_mod = PyModule()
    approx_mod._set_from_ptr(<intptr_t>(&mod))
    return approx_mod
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def multiparameter_module_approximation_from_slicer(slicer, box, int num_parameters, value_type max_error, bool complete, bool threshold, bool verbose):
    """Computes an interval module approximation from an existing slicer object.

    Parameters
    ----------
    slicer : Slicer, SlicerClement, SlicerVineGraph or SlicerVineSimplcial
        Python slicer wrapper; its C++ object is reached through `slicer.get_ptr()`
        and copied into a local variable of the matching C++ type.
    box :
        Converted to a `Box[value_type]` and forwarded to the C++ routine.
    num_parameters : int
        Not used by this function.
    max_error, complete, threshold, verbose :
        Forwarded to the C++ `multiparameter_module_approximation`
        (in the order max_error, box, threshold, complete, verbose).

    Returns
    -------
    PyModule
        Wrapper set from the computed C++ module.

    Raises
    ------
    ValueError
        If `slicer` is not an instance of one of the supported slicer classes.
    """
    cdef intptr_t slicer_ptr = <intptr_t>(slicer.get_ptr())
    cdef GeneralVineTruc cslicer
    cdef GeneralVineClementTruc generalclementtruc
    cdef SimplicialVineGraphTruc graphtruc
    cdef SimplicialVineMatrixTruc matrixtruc
    cdef Module mod
    cdef Box[value_type] c_box = Box[value_type](box)
    # Dispatch on the Python wrapper class to pick the C++ type behind the pointer.
    if isinstance(slicer,Slicer):
        cslicer = dereference(<GeneralVineTruc*>(slicer_ptr))
        with nogil:
            mod = multiparameter_module_approximation(cslicer, max_error,c_box,threshold, complete, verbose)
    elif isinstance(slicer,SlicerClement):
        generalclementtruc = dereference(<GeneralVineClementTruc*>(slicer_ptr))
        with nogil:
            mod = multiparameter_module_approximation(generalclementtruc, max_error,c_box,threshold, complete, verbose)
    elif isinstance(slicer,SlicerVineGraph):
        graphtruc = dereference(<SimplicialVineGraphTruc*>(slicer_ptr))
        with nogil:
            mod = multiparameter_module_approximation(graphtruc, max_error,c_box,threshold, complete, verbose)
    elif isinstance(slicer,SlicerVineSimplcial):
        matrixtruc = dereference(<SimplicialVineMatrixTruc*>(slicer_ptr))
        with nogil:
            mod = multiparameter_module_approximation(matrixtruc, max_error,c_box,threshold, complete, verbose)
    else:
        raise ValueError("Unimplemeted slicer / Invalid slicer.")


    # Hand the address of the C++ result to the Python wrapper.
    approx_mod = PyModule()
    approx_mod._set_from_ptr(<intptr_t>(&mod))
    return approx_mod
|
multipers/pickle.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
def save_with_axis(path:str, signed_measures):
    """Write signed measures indexed by (data, axis, degree) to an ``.npz`` archive.

    Each measure is a pair ``(points, weights)``; it is stored as one 2-D
    array whose last column holds the weights, under the key
    ``"{data}_{axis}_{degree}"``.
    """
    arrays = {}
    for data_idx, per_axis in enumerate(signed_measures):
        for axis_idx, per_degree in enumerate(per_axis):
            for degree_idx, measure in enumerate(per_degree):
                weight_column = measure[1][:, np.newaxis]
                key = f"{data_idx}_{axis_idx}_{degree_idx}"
                arrays[key] = np.c_[measure[0], weight_column]
    np.savez(path, **arrays)
|
|
7
|
+
|
|
8
|
+
def save_without_axis(path:str, signed_measures):
    """Write signed measures indexed by (data, degree) to an ``.npz`` archive.

    Each measure is a pair ``(points, weights)``; it is stored as one 2-D
    array whose last column holds the weights, under the key
    ``"{data}_{degree}"``.
    """
    arrays = {}
    for data_idx, per_degree in enumerate(signed_measures):
        for degree_idx, measure in enumerate(per_degree):
            weight_column = measure[1][:, np.newaxis]
            arrays[f"{data_idx}_{degree_idx}"] = np.c_[measure[0], weight_column]
    np.savez(path, **arrays)
|
|
12
|
+
|
|
13
|
+
def get_sm_with_axis(sms,idx,axis,degree):
    """Return ``(points, weights)`` stored under the key ``"{idx}_{axis}_{degree}"``.

    The stored 2-D array is split into all columns but the last (points)
    and the last column (weights).
    """
    packed = sms[f"{idx}_{axis}_{degree}"]
    points = packed[:, :-1]
    weights = packed[:, -1]
    return (points, weights)
|
|
16
|
+
def get_sm_without_axis(sms,idx,degree):
    """Return ``(points, weights)`` stored under the key ``"{idx}_{degree}"``.

    The stored 2-D array is split into all columns but the last (points)
    and the last column (weights).
    """
    packed = sms[f"{idx}_{degree}"]
    points = packed[:, :-1]
    weights = packed[:, -1]
    return (points, weights)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def load_without_axis(sms):
    """Rebuild the nested list ``[data][degree] -> (points, weights)``.

    The numbers of data items and degrees are taken as the largest index
    found in the ``"{data}_{degree}"`` keys of ``sms``, plus one.
    """
    parsed_keys = [[int(token) for token in name.split('_')] for name in sms.keys()]
    num_data, num_degrees = np.array(parsed_keys, dtype=int).max(axis=0) + 1
    rebuilt = []
    for idx in range(num_data):
        rebuilt.append(
            [get_sm_without_axis(sms, idx, degree) for degree in range(num_degrees)]
        )
    return rebuilt
|
|
26
|
+
# test : np.all([np.array_equal(a[0],b[0]) and np.array_equal(a[1],b[1]) and len(a) == len(b) == 2 for x,y in zip(signed_measures_reconstructed,signed_measures_reconstructed) for a,b in zip(x,y)])
|
|
27
|
+
|
|
28
|
+
def load_with_axis(sms):
    """Rebuild the nested list ``[data][axis][degree] -> (points, weights)``.

    The numbers of data items, axes and degrees are taken as the largest
    index found in the ``"{data}_{axis}_{degree}"`` keys of ``sms``, plus one.
    """
    parsed_keys = [[int(token) for token in name.split('_')] for name in sms.keys()]
    num_data, num_axis, num_degrees = np.array(parsed_keys, dtype=int).max(axis=0) + 1
    rebuilt = []
    for idx in range(num_data):
        per_axis = []
        for axis in range(num_axis):
            per_axis.append(
                [get_sm_with_axis(sms, idx, axis, degree) for degree in range(num_degrees)]
            )
        rebuilt.append(per_axis)
    return rebuilt
|
|
33
|
+
|
|
34
|
+
def save(path:str, signed_measures):
    """Save signed measures, choosing the layout from the first element.

    If ``signed_measures[0][0]`` is a tuple the data is treated as
    ``[data][degree]`` and written without an axis level; otherwise it is
    written with the ``[data][axis][degree]`` layout.
    """
    first_entry = signed_measures[0][0]
    writer = save_without_axis if isinstance(first_entry, tuple) else save_with_axis
    writer(path=path, signed_measures=signed_measures)
|
|
39
|
+
|
|
40
|
+
def load(path:str):
    """Load signed measures written by ``save``.

    The layout is detected from the number of ``_``-separated fields in the
    first key of the archive: two fields are read with ``load_without_axis``,
    three with ``load_with_axis``.

    Raises
    ------
    ValueError
        If the archive is empty or its keys have neither two nor three fields.
    """
    # The archive is closed on exit; the loaders read every array they return
    # while it is still open.
    with np.load(path) as sms:
        first_key = next(iter(sms.keys()), None)
        if first_key is None:
            # An empty archive used to fail with AttributeError on None.split.
            raise ValueError("Invalid Signed Measure !")
        num_fields = len(first_key.split('_'))
        if num_fields == 2:
            return load_without_axis(sms)
        if num_fields == 3:
            return load_with_axis(sms)
        raise ValueError("Invalid Signed Measure !")
|
multipers/plots.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
1
|
+
import matplotlib.pyplot as plt
|
|
2
|
+
from typing import Optional
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def _plot_rectangle(rectangle: np.ndarray, weight, **plt_kwargs):
    """Draw the segment joining the two corners of ``rectangle``.

    ``rectangle`` is read as ``(x0, y0, x1, y1)``; the line from ``(x0, y0)``
    to ``(x1, y1)`` is plotted in blue when ``weight > 0`` and in red
    otherwise. Extra keyword arguments go to ``plt.plot``.
    """
    coords = np.asarray(rectangle)
    xs = coords[[0, 2]]
    ys = coords[[1, 3]]
    if weight > 0:
        line_color = "blue"
    else:
        line_color = "red"
    plt.plot(xs, ys, c=line_color, **plt_kwargs)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _plot_signed_measure_2(pts, weights, temp_alpha=0.7, **plt_kwargs):
    """Scatter-plot a signed measure supported on 2-D points.

    Negative weights are rescaled into ``[-2, -1)`` and positive weights into
    ``(1, 2]`` (each sign normalised by its largest absolute value), then
    mapped through a red / white / blue colormap. ``temp_alpha`` is the alpha
    of the two intermediate colours. Extra keyword arguments go to every
    ``plt.scatter`` call, including the two empty ones that only feed the legend.
    """
    from matplotlib.colors import LinearSegmentedColormap

    weights = np.asarray(weights)
    shades = np.array(weights, dtype=float)

    negative = weights < 0
    if np.any(negative):
        largest_negative = np.max(-weights[negative])
        shades[negative] = shades[negative] / largest_negative - 1

    positive = weights > 0
    if np.any(positive):
        largest_positive = np.max(weights[positive])
        shades[positive] = shades[positive] / largest_positive + 1

    red_rgb = [0.70567316, 0.01555616, 0.15023281]
    blue_rgb = [0.2298057, 0.29871797, 0.75368315]
    bordeaux = np.array(red_rgb + [1])
    light_bordeaux = np.array(red_rgb + [temp_alpha])
    bleu = np.array(blue_rgb + [1])
    light_bleu = np.array(blue_rgb + [temp_alpha])

    palette = LinearSegmentedColormap.from_list(
        "", [bordeaux, light_bordeaux, "white", light_bleu, bleu]
    )
    scale = plt.Normalize(-2, 2)
    plt.scatter(
        pts[:, 0], pts[:, 1], c=shades, cmap=palette, norm=scale, **plt_kwargs
    )
    # Empty scatters: they only create the two legend entries.
    for legend_color, legend_label in ((bleu, "positive mass"), (bordeaux, "negative mass")):
        plt.scatter([], [], color=legend_color, label=legend_label, **plt_kwargs)
    plt.legend()
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _plot_signed_measure_4(
    pts, weights, x_smoothing: float = 1, area_alpha: bool = True
):
    """Plot a signed measure whose points are rectangles ``(x0, y0, x1, y1)``.

    Each rectangle has its third coordinate divided by ``x_smoothing``;
    rectangles with ``x1 <= x_smoothing * x0`` are skipped. The remaining ones
    are drawn with ``_plot_rectangle``.

    Parameters
    ----------
    pts, weights :
        Rectangles and their weights, iterated in lockstep.
    x_smoothing : float
        Divisor applied to the third coordinate of every rectangle.
    area_alpha : bool
        When true, the alpha of a rectangle is its area divided by the largest
        area among the drawn rectangles (clipped to ``[0, 1]``). When false,
        or when no drawn rectangle has a positive area (the ratio would be a
        division by zero), everything is drawn fully opaque.
    """
    # Collect the rectangles that have not been reduced to the empty set,
    # together with their smoothed coordinates and area.
    drawable = []
    for rectangle, weight in zip(pts, weights):
        if rectangle[2] > x_smoothing * rectangle[0]:
            smoothed = [
                rectangle[0],
                rectangle[1],
                rectangle[2] / x_smoothing,
                rectangle[3],
            ]
            area = (smoothed[2] - smoothed[0]) * (smoothed[3] - smoothed[1])
            drawable.append((smoothed, weight, area))

    largest_area = max((area for _, _, area in drawable), default=0)
    use_area = area_alpha and largest_area > 0

    for smoothed, weight, area in drawable:
        # Clip so that a rectangle with negative height cannot yield an alpha
        # outside the range matplotlib accepts.
        alpha = min(max(area / largest_area, 0), 1) if use_area else 1
        _plot_rectangle(rectangle=smoothed, weight=weight, alpha=alpha)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def plot_signed_measure(signed_measure, ax=None, **plt_kwargs):
    """Plot one signed measure on ``ax`` (or on the current axes).

    Parameters
    ----------
    signed_measure :
        Pair ``(pts, weights)``. ``pts`` is a 2-D numpy array or torch tensor
        with 2 columns (points) or 4 columns (rectangles).
    ax :
        Axes made current before plotting; the current axes are used when None.
    **plt_kwargs :
        Forwarded to the 2-column or 4-column plotting helper.

    Raises
    ------
    Exception
        If ``pts`` is neither a numpy array nor a torch tensor.
    ValueError
        If ``pts`` has neither 2 nor 4 columns.
    """
    pts, weights = signed_measure

    # Validate and convert before touching any pyplot state.
    if not isinstance(pts, np.ndarray):
        import torch

        if not isinstance(pts, torch.Tensor):
            raise Exception("Invalid measure type.")
        pts = pts.detach().cpu().numpy()
    if hasattr(weights, "detach"):
        # Tensor weights (possibly tracking gradients) cannot go through np.asarray.
        weights = weights.detach().cpu().numpy()

    num_parameters = pts.shape[1]
    if num_parameters not in (2, 4):
        # Explicit raise: an assert would be stripped under `python -O`.
        raise ValueError(
            f"Expected points with 2 or 4 columns, got {num_parameters}."
        )

    if ax is not None:
        plt.sca(ax)

    if num_parameters == 2:
        _plot_signed_measure_2(pts=pts, weights=weights, **plt_kwargs)
    else:
        _plot_signed_measure_4(pts=pts, weights=weights, **plt_kwargs)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def plot_signed_measures(signed_measures, size=4):
    """Plot several signed measures side by side, one subplot each.

    A single-row figure with ``len(signed_measures)`` columns is created;
    every panel is ``size`` by ``size``.
    """
    num_degrees = len(signed_measures)
    fig, axes = plt.subplots(
        nrows=1, ncols=num_degrees, figsize=(num_degrees * size, size)
    )
    # With one column, plt.subplots returns a bare Axes instead of a sequence.
    panels = [axes] if num_degrees == 1 else axes
    for panel, measure in zip(panels, signed_measures):
        plot_signed_measure(signed_measure=measure, ax=panel)
    plt.tight_layout()
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def plot_surface(
    grid,
    hf,
    fig=None,
    ax=None,
    cmap: Optional[str] = None,
    discrete_surface=False,
    **plt_args,
):
    """Draw a 2-parameter surface with ``pcolormesh``.

    Parameters
    ----------
    grid :
        ``grid[0]`` and ``grid[1]`` are the x and y coordinates.
    hf :
        2-D array of values (its transpose is plotted); an array of shape
        ``(1, n, m)`` is accepted and squeezed.
    fig, ax :
        Figure and axes to draw on; the current ones are used when None.
    cmap :
        Colormap name or object. Defaults to ``"gray_r"`` for a discrete
        surface and ``"plasma"`` otherwise.
    discrete_surface : bool
        When true, values are binned on the integer bounds 0..10 and a
        colorbar with those ticks is added to ``fig``.
    **plt_args :
        Forwarded to ``pcolormesh``.
    """
    import matplotlib

    if ax is None:
        ax = plt.gca()
    else:
        plt.sca(ax)
    if hf.ndim == 3 and hf.shape[0] == 1:
        hf = hf[0]
    assert hf.ndim == 2, "Can only plot a 2d surface"
    fig = plt.gcf() if fig is None else fig

    if cmap is None:
        cmap = matplotlib.colormaps["gray_r" if discrete_surface else "plasma"]
    elif isinstance(cmap, str):
        # The discrete branch reads `cmap.N`, which a bare name does not have.
        cmap = matplotlib.colormaps[cmap]

    if discrete_surface:
        bounds = np.arange(0, 11, 1, dtype=int)
        norm = matplotlib.colors.BoundaryNorm(bounds, cmap.N, extend="max")
        ax.pcolormesh(grid[0], grid[1], hf.T, cmap=cmap, norm=norm, **plt_args)
        cbar = fig.colorbar(
            matplotlib.cm.ScalarMappable(cmap=cmap, norm=norm),
            spacing="proportional",
            ax=ax,
        )
        cbar.set_ticks(ticks=bounds, labels=bounds)
        return
    ax.pcolormesh(grid[0], grid[1], hf.T, cmap=cmap, **plt_args)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def plot_surfaces(HF, size=4, **plt_args):
    """Plot one surface per degree, side by side.

    ``HF`` is a pair ``(grid, hf)`` where ``hf`` is 3-dimensional with the
    degree as first axis. Each panel is ``size`` by ``size``; extra keyword
    arguments go to ``plot_surface``.
    """
    grid, hf = HF
    assert (
        hf.ndim == 3
    ), f"Found hf.shape = {hf.shape}, expected ndim = 3 : degree, 2-parameter surface."
    num_degrees = hf.shape[0]
    fig, axes = plt.subplots(
        nrows=1, ncols=num_degrees, figsize=(num_degrees * size, size)
    )
    # With one column, plt.subplots returns a bare Axes instead of a sequence.
    panels = [axes] if num_degrees == 1 else axes
    for panel, surface in zip(panels, hf):
        plot_surface(grid=grid, hf=surface, fig=fig, ax=panel, **plt_args)
    plt.tight_layout()
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _rectangle(x, y, color, alpha):
    """
    Build a matplotlib rectangle patch for {z | x ≤ z ≤ y} with the given
    color and alpha. Width and height are clamped at 0.
    """
    from matplotlib.patches import Rectangle as RectanglePatch

    width = max(y[0] - x[0], 0)
    height = max(y[1] - x[1], 0)
    return RectanglePatch(x, width, height, color=color, alpha=alpha)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _d_inf(a, b):
|
|
199
|
+
if type(a) != np.ndarray or type(b) != np.ndarray:
|
|
200
|
+
a = np.array(a)
|
|
201
|
+
b = np.array(b)
|
|
202
|
+
return np.min(np.abs(b - a))
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def plot2d_PyModule(
    corners,
    box=[],
    *,
    dimension=-1,
    separated=False,
    min_persistence=0,
    alpha=1,
    verbose=False,
    save=False,
    dpi=200,
    shapely=True,
    xlabel=None,
    ylabel=None,
    cmap=None,
):
    """Plot the summands of a 2-parameter module as unions of rectangles.

    Parameters
    ----------
    corners :
        One entry per summand; ``corners[i][0]`` are its birth corners and
        ``corners[i][1]`` its death corners. The input is not modified.
    box :
        ``[lower, upper]`` corners of the plotting window; also used to clip
        the death corners.
    dimension : int
        When >= 0, used in the plot title.
    separated : bool
        When true, each non-trivial summand is drawn in its own new figure;
        otherwise everything is drawn on the current axes.
    min_persistence :
        A summand is drawn only if at least one of its (birth, death)
        rectangles has ``_d_inf(birth, death) > min_persistence``.
    alpha :
        Alpha of the filled shapes.
    verbose, save, dpi :
        Not used by this function.
    shapely : bool
        Use shapely to merge the rectangles of a summand when it can be
        imported; otherwise matplotlib rectangle patches are used.
    xlabel, ylabel :
        Optional axis labels.
    cmap :
        Colormap name; defaults to ``"Spectral"``.
    """
    import matplotlib

    try:
        from shapely.geometry import Polygon as _Polygon
        from shapely.geometry import box as _rectangle_box
        from shapely import union_all
    except ImportError:
        # Only a failed import triggers the fallback; other errors propagate.
        from warnings import warn

        shapely = False
        warn(
            "Shapely not installed. Fallbacking to matplotlib. The plots may be inacurate."
        )
    cmap = (
        matplotlib.colormaps["Spectral"] if cmap is None else matplotlib.colormaps[cmap]
    )
    if not separated:
        # fig, ax = plt.subplots()
        ax = plt.gca()
        ax.set(xlim=[box[0][0], box[1][0]], ylim=[box[0][1], box[1][1]])
    n_summands = len(corners)
    for i, summand in enumerate(corners):
        color = cmap(i / n_summands)
        trivial_summand = True
        list_of_rect = []
        for birth in summand[0]:
            for death in summand[1]:
                # Clip a copy of the death corner to the box: writing into
                # `death` would alter the caller's `corners`.
                clipped = list(death)
                clipped[0] = min(death[0], box[1][0])
                clipped[1] = min(death[1], box[1][1])
                if clipped[1] > birth[1] and clipped[0] > birth[0]:
                    if trivial_summand and _d_inf(birth, clipped) > min_persistence:
                        trivial_summand = False
                    if shapely:
                        list_of_rect.append(
                            _rectangle_box(birth[0], birth[1], clipped[0], clipped[1])
                        )
                    else:
                        list_of_rect.append(_rectangle(birth, clipped, color, alpha))
        if trivial_summand:
            continue
        if separated:
            fig, ax = plt.subplots()
            ax.set(xlim=[box[0][0], box[1][0]], ylim=[box[0][1], box[1][1]])
        if shapely:
            summand_shape = union_all(list_of_rect)
            # A union is either a single polygon or a collection of them.
            if isinstance(summand_shape, _Polygon):
                polygons = [summand_shape]
            else:
                polygons = summand_shape.geoms
            for polygon in polygons:
                xs, ys = polygon.exterior.xy
                ax.fill(xs, ys, alpha=alpha, fc=color, ec="None")
        else:
            for rectangle in list_of_rect:
                ax.add_patch(rectangle)
        if separated:
            if xlabel:
                plt.xlabel(xlabel)
            if ylabel:
                plt.ylabel(ylabel)
            if dimension >= 0:
                plt.title(rf"$H_{dimension}$ $2$-persistence")
    if not separated:
        if xlabel is not None:
            plt.xlabel(xlabel)
        if ylabel is not None:
            plt.ylabel(ylabel)
        if dimension >= 0:
            plt.title(rf"$H_{dimension}$ $2$-persistence")
    return
|
|
Binary file
|