multipers 1.1.3__cp310-cp310-macosx_11_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic; see the package's registry page for more details.

Files changed (63)
  1. multipers/.dylibs/libtbb.12.12.dylib +0 -0
  2. multipers/.dylibs/libtbbmalloc.2.12.dylib +0 -0
  3. multipers/__init__.py +5 -0
  4. multipers/_old_rank_invariant.pyx +328 -0
  5. multipers/_signed_measure_meta.py +193 -0
  6. multipers/data/MOL2.py +350 -0
  7. multipers/data/UCR.py +18 -0
  8. multipers/data/__init__.py +1 -0
  9. multipers/data/graphs.py +466 -0
  10. multipers/data/immuno_regions.py +27 -0
  11. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  12. multipers/data/pytorch2simplextree.py +91 -0
  13. multipers/data/shape3d.py +101 -0
  14. multipers/data/synthetic.py +68 -0
  15. multipers/distances.py +172 -0
  16. multipers/euler_characteristic.cpython-310-darwin.so +0 -0
  17. multipers/euler_characteristic.pyx +137 -0
  18. multipers/function_rips.cpython-310-darwin.so +0 -0
  19. multipers/function_rips.pyx +102 -0
  20. multipers/hilbert_function.cpython-310-darwin.so +0 -0
  21. multipers/hilbert_function.pyi +46 -0
  22. multipers/hilbert_function.pyx +151 -0
  23. multipers/io.cpython-310-darwin.so +0 -0
  24. multipers/io.pyx +176 -0
  25. multipers/ml/__init__.py +0 -0
  26. multipers/ml/accuracies.py +61 -0
  27. multipers/ml/convolutions.py +510 -0
  28. multipers/ml/invariants_with_persistable.py +79 -0
  29. multipers/ml/kernels.py +128 -0
  30. multipers/ml/mma.py +657 -0
  31. multipers/ml/one.py +472 -0
  32. multipers/ml/point_clouds.py +191 -0
  33. multipers/ml/signed_betti.py +50 -0
  34. multipers/ml/signed_measures.py +1479 -0
  35. multipers/ml/sliced_wasserstein.py +313 -0
  36. multipers/ml/tools.py +116 -0
  37. multipers/mma_structures.cpython-310-darwin.so +0 -0
  38. multipers/mma_structures.pxd +155 -0
  39. multipers/mma_structures.pyx +651 -0
  40. multipers/multiparameter_edge_collapse.py +29 -0
  41. multipers/multiparameter_module_approximation.cpython-310-darwin.so +0 -0
  42. multipers/multiparameter_module_approximation.pyi +439 -0
  43. multipers/multiparameter_module_approximation.pyx +311 -0
  44. multipers/pickle.py +53 -0
  45. multipers/plots.py +292 -0
  46. multipers/point_measure_integration.cpython-310-darwin.so +0 -0
  47. multipers/point_measure_integration.pyx +59 -0
  48. multipers/rank_invariant.cpython-310-darwin.so +0 -0
  49. multipers/rank_invariant.pyx +154 -0
  50. multipers/simplex_tree_multi.cpython-310-darwin.so +0 -0
  51. multipers/simplex_tree_multi.pxd +121 -0
  52. multipers/simplex_tree_multi.pyi +715 -0
  53. multipers/simplex_tree_multi.pyx +1417 -0
  54. multipers/slicer.cpython-310-darwin.so +0 -0
  55. multipers/slicer.pxd +94 -0
  56. multipers/slicer.pyx +276 -0
  57. multipers/tensor.pxd +13 -0
  58. multipers/test.pyx +44 -0
  59. multipers-1.1.3.dist-info/LICENSE +21 -0
  60. multipers-1.1.3.dist-info/METADATA +22 -0
  61. multipers-1.1.3.dist-info/RECORD +63 -0
  62. multipers-1.1.3.dist-info/WHEEL +5 -0
  63. multipers-1.1.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,311 @@
1
+ """!
2
+ @package mma
3
+ @brief Files containing the C++ cythonized functions.
4
+ @author David Loiseaux
5
+ @copyright Copyright (c) 2022 Inria.
6
+ """
7
+
8
+ # distutils: language = c++
9
+
10
+ ###########################################################################
11
+ ## PYTHON LIBRARIES
12
+ import gudhi as gd
13
+ import numpy as np
14
+ from typing import List
15
+ import pickle as pk
16
+
17
+ ###########################################################################
18
+ ## CPP CLASSES
19
+ from libc.stdint cimport intptr_t
20
+ from libc.stdint cimport uintptr_t
21
+
22
+ ###########################################################################
23
+ ## CYTHON TYPES
24
+ from libcpp.vector cimport vector
25
+ from libcpp.utility cimport pair
26
+ #from libcpp.list cimport list as clist
27
+ from libcpp cimport bool
28
+ from libcpp cimport int
29
+ from typing import Iterable
30
+ from cython.operator import dereference
31
+ #########################################################################
32
+ ## Multipersistence Module Approximation Classes
33
+ from multipers.mma_structures cimport Module, Box, pair, boundary_matrix
34
+ from multipers.slicer cimport *
35
+
36
+
37
+ #########################################################################
38
+ ## Small hack for typing
39
+ from gudhi import SimplexTree
40
+ from multipers.simplex_tree_multi import SimplexTreeMulti
41
+ from multipers.slicer import Slicer, SlicerClement,SlicerVineGraph,SlicerVineSimplcial
42
+ from multipers.mma_structures import PyModule
43
+ # cimport numpy as cnp
44
+ # cnp.import_array()
45
+
46
+ ###################################### MMA
47
# C++ approximation routines. Every one of them returns a `Module` and is
# declared `nogil`, which is what lets the Python wrappers below call them
# inside `with nogil:` blocks.
cdef extern from "multiparameter_module_approximation/approximation.h" namespace "Gudhi::multiparameter::mma":
    # Legacy entry point fed with an explicit boundary matrix and a vector of
    # multi-filtration values; used by `module_approximation_old`.
    # Argument order: boundary, filtration, precision, box, threshold,
    # complete, multithread, verbose.
    Module compute_vineyard_barcode_approximation(boundary_matrix, vector[Finitely_critical_multi_filtration] , value_type precision, Box[value_type] &, bool threshold, bool complete, bool multithread, bool verbose) nogil
    # TODO : tempita
    # One overload per slicer type. Argument order (as passed by the wrappers
    # below): slicer, precision (max_error), box, threshold, complete, verbose.
    Module multiparameter_module_approximation(SimplicialVineGraphTruc&, value_type, Box[value_type]&, bool, bool, bool) nogil

    Module multiparameter_module_approximation(SimplicialVineMatrixTruc&, value_type, Box[value_type]&, bool, bool, bool) nogil
    Module multiparameter_module_approximation(GeneralVineTruc&, value_type, Box[value_type]&, bool, bool, bool) nogil
    Module multiparameter_module_approximation(GeneralVineClementTruc&, value_type, Box[value_type]&, bool, bool, bool) nogil

# TODO : remove when old is deprecated
cdef extern from "multiparameter_module_approximation/format_python-cpp.h" namespace "Gudhi::multiparameter::mma":
    # Builds a (boundary matrix, filtration values) pair from an address passed
    # as uintptr_t (presumably a simplex tree pointer -- TODO confirm).
    # NOTE(review): not called anywhere in this module.
    pair[boundary_matrix, vector[Finitely_critical_multi_filtration]] simplextree_to_boundary_filtration(uintptr_t)
59
+
60
def module_approximation_old(
    st:SimplexTreeMulti|None=None,
    max_error:float|None = None,
    box:list|np.ndarray|None = None,
    threshold:bool = False,
    complete:bool = True,
    multithread:bool = False,
    verbose:bool = False,
    ignore_warning:bool = False,
    nlines:int = 500,
    max_dimension=np.inf,
    boundary = None,
    filtration = None,
    return_timings:bool = False,
    **kwargs
):
    """Computes an interval module approximation of a multiparameter filtration
    (legacy boundary-matrix backend).

    Parameters
    ----------
    st : n-filtered Simplextree, or None if boundary and filtration are provided.
        Defines the n-filtration on which to compute the homology.
    max_error: positive float, or None
        Trade-off between approximation and computational complexity.
        Upper bound of the module approximation, in bottleneck distance,
        for interval-decomposable modules. When None, it is derived from `nlines`.
    nlines: int
        Alternative to `max_error`: approximate number of lines, only used
        when `max_error` is None.
    box : pair of list of floats
        Defines a rectangle on which to compute the approximation.
        Format : [x,y], where x,y defines the rectangle {z : x ≤ z ≤ y}.
        When None, it is the bounding box of the filtration padded by
        2 * max_error.
    threshold: bool
        When true, intersects the module support with the box.
    complete, multithread: bool
        Forwarded to the C++ approximation routine.
    verbose: bool
        Prints C++ infos.
    ignore_warning : bool
        Unless set to true, returns the trivial module (with a warning) when
        the estimated number of lines is at least 20 000. Useful to prevent a
        segmentation fault due to "infinite" recursion.
    max_dimension : int or np.inf
        Simplices of dimension greater than this are discarded before the
        computation.
    boundary, filtration : optional
        Precomputed boundary matrix and filtration values (indexed
        [parameter][simplex]); computed from `st` when either is None.
        They are never modified in place.
    return_timings : bool
        If true, returns the computation time in seconds (measured in python
        with perf_counter_ns) instead of the module.
    **kwargs
        Ignored.

    Returns
    -------
    PyModule
        An interval decomposable module approximation of the module defined by the
        homology of this multi-filtration. For a 1-parameter `st`, the result of
        `st.project_on_line(0).persistence()` is returned instead.
    """
    if boundary is None or filtration is None:
        from multipers.io import simplex_tree2boundary_filtrations
        boundary, filtration = simplex_tree2boundary_filtrations(st)  # TODO : recomputed each time... maybe store this somewhere ?
    if max_dimension < np.inf:
        # Keep only simplices of dimension <= max_dimension. New lists are
        # built in a single pass: popping in place was quadratic and silently
        # modified the `boundary` / `filtration` objects given by the caller.
        keep = [len(b) - 1 <= max_dimension for b in boundary]
        boundary = [b for b, kept in zip(boundary, keep) if kept]
        filtration = [[x for x, kept in zip(f, keep) if kept] for f in filtration]
    nfiltration = len(filtration)  # number of filtration parameters
    if nfiltration <= 0:
        return PyModule()
    if nfiltration == 1 and st is not None:
        # A single parameter: plain persistence is enough.
        st = st.project_on_line(0)
        return st.persistence()

    if box is None and st is not None:
        m, M = st.filtration_bounds()
    elif box is not None:
        m, M = np.asarray(box)
    else:
        # `filtration` is indexed [parameter][simplex] (cf. `len(filtration)`
        # above and the box padding below): bound each parameter separately.
        m = np.asarray([np.min(f) for f in filtration])
        M = np.asarray([np.max(f) for f in filtration])
    # Estimate of the number of lines: product over all parameters but the
    # last of (side length + length of the last side).
    prod = 1
    h = M[-1] - m[-1]
    for i, (a, b) in enumerate(zip(m, M)):
        if i == len(M) - 1:
            continue
        prod *= (b - a + h)

    if max_error is None:
        max_error = (prod / nlines) ** (1 / (nfiltration - 1))

    if box is None:
        M = [np.max(f) + 2 * max_error for f in filtration]
        m = [np.min(f) - 2 * max_error for f in filtration]
        box = [m, M]

    # Refuse expensive computations unless explicitly requested. (This test
    # used to be inverted, blocking exactly the calls with ignore_warning=True.)
    if not ignore_warning and prod >= 20_000:
        from warnings import warn
        warn(f"Warning : the number of lines (around {np.round(prod)}) may be too high. Try to increase the precision parameter, or set `ignore_warning=True` to compute this module. Returning the trivial module.")
        return PyModule()

    approx_mod = PyModule()
    cdef vector[Finitely_critical_multi_filtration] c_filtration = Finitely_critical_multi_filtration.from_python(filtration)
    cdef boundary_matrix c_boundary = boundary
    cdef value_type c_max_error = max_error
    cdef bool c_threshold = threshold
    cdef bool c_complete = complete
    cdef bool c_multithread = multithread
    cdef bool c_verbose = verbose
    cdef Box[value_type] c_box = Box[value_type](box)
    if return_timings:
        from time import perf_counter_ns
        t = perf_counter_ns()
    with nogil:
        c_mod = compute_vineyard_barcode_approximation(c_boundary, c_filtration, c_max_error, c_box, c_threshold, c_complete, c_multithread, c_verbose)
    if return_timings:
        t = perf_counter_ns() - t
        t /= 10**9
        return t
    approx_mod._set_from_ptr(<intptr_t>(&c_mod))
    return approx_mod
170
+
171
+
172
+
173
def module_approximation(
    st:SimplexTreeMulti,
    value_type max_error = -1,
    box:list|np.ndarray|None = None,
    bool threshold:bool = False,
    bool complete:bool = True,
    bool verbose:bool = False,
    bool ignore_warning:bool = False,
    int nlines = 500,
    backend:str="matrix",
    # max_dimension=np.inf,
    # return_timings:bool = False,
    **kwargs
):
    """Computes an interval module approximation of a multiparameter filtration.

    Parameters
    ----------
    st : n-filtered SimplexTreeMulti
        Defines the n-filtration on which to compute the homology.
    max_error: positive float
        Trade-off between approximation and computational complexity.
        Upper bound of the module approximation, in bottleneck distance,
        for interval-decomposable modules. Any non-positive value (the
        default is -1) means "derive it from `nlines`".
    nlines: int
        Alternative to `max_error`; only used when `max_error <= 0`.
    box : pair of list of floats
        Defines a rectangle on which to compute the approximation.
        Format : [x,y], where x,y defines the rectangle {z : x ≤ z ≤ y}.
        Defaults to `st.filtration_bounds()`.
    threshold: bool
        When true, intersects the module support with the box.
    complete: bool
        Forwarded to the C++ approximation routine.
    verbose: bool
        Prints C++ infos.
    ignore_warning : bool
        Unless set to true, returns the trivial module (with a warning) when
        the estimated number of lines is at least 20 000. Useful to prevent a
        segmentation fault due to "infinite" recursion.
    backend : {"matrix", "graph", "old"}
        "matrix" / "graph" build a `SimplicialVineMatrixTruc` /
        `SimplicialVineGraphTruc` from `st`; "old" delegates to
        `module_approximation_old`.
    **kwargs
        Ignored.

    Returns
    -------
    PyModule
        An interval decomposable module approximation of the module defined by the
        homology of this multi-filtration. For a 1-parameter `st`, the result of
        `st.project_on_line(0).persistence()` is returned instead.

    Raises
    ------
    ValueError
        If `backend` is not supported (checked only once an approximation is
        actually about to be computed).
    """
    if backend == "old":
        # Same convention as below: any non-positive max_error means "automatic".
        max_error_ = None if max_error <= 0 else max_error
        return module_approximation_old(st, max_error=max_error_, box=box,threshold=threshold,complete=complete,verbose=verbose,ignore_warning=ignore_warning,nlines=nlines)

    cdef intptr_t ptr = st.thisptr
    cdef Simplex_tree_multi_interface* st_ptr = <Simplex_tree_multi_interface*>(ptr)
    cdef SimplicialVineGraphTruc graphtruc  # copy ?
    cdef SimplicialVineMatrixTruc matrixtruc  # copy ?

    cdef int num_parameters = st.num_parameters

    if num_parameters <= 0:
        return PyModule()
    if num_parameters == 1:
        # A single parameter: plain persistence is enough. (`st` cannot be
        # None here, `st.thisptr` has already been read.)
        st = st.project_on_line(0)
        return st.persistence()

    if box is not None:
        m, M = np.asarray(box)
    else:
        m, M = st.filtration_bounds()
        box = np.asarray([m, M])
    # Estimate of the number of lines: product over all parameters but the
    # last of (side length + length of the last side).
    prod = 1
    h = M[-1] - m[-1]
    for i, (a, b) in enumerate(zip(m, M)):
        if i == len(M) - 1:
            continue
        prod *= (b - a + h)

    if max_error <= 0:
        max_error = (prod / nlines) ** (1 / (num_parameters - 1))

    if not ignore_warning and prod >= 20_000:
        from warnings import warn
        warn(f"Warning : the number of lines (around {np.round(prod)}) may be too high. Try to increase the precision parameter, or set `ignore_warning=True` to compute this module. Returning the trivial module.")
        return PyModule()

    cdef Module mod
    cdef Box[value_type] c_box = Box[value_type](box)
    # Module multiparameter_module_approximation(Slicer &slicer, const value_type precision,
    #     Box<value_type> &box, const bool threshold,
    #     const bool complete, const bool verbose)
    if backend == "matrix":
        matrixtruc = SimplicialVineMatrixTruc(st_ptr)
        with nogil:
            mod = multiparameter_module_approximation(matrixtruc, max_error, c_box, threshold, complete, verbose)
    elif backend == "graph":
        graphtruc = SimplicialVineGraphTruc(st_ptr)
        with nogil:
            mod = multiparameter_module_approximation(graphtruc, max_error, c_box, threshold, complete, verbose)
    else:
        raise ValueError("Invalid backend.")

    approx_mod = PyModule()
    approx_mod._set_from_ptr(<intptr_t>(&mod))
    return approx_mod
277
+
278
+
279
+
280
+
281
def multiparameter_module_approximation_from_slicer(slicer, box, int num_parameters, value_type max_error, bool complete, bool threshold, bool verbose):
    """Computes an interval module approximation directly from a slicer object.

    Parameters
    ----------
    slicer : Slicer, SlicerClement, SlicerVineGraph or SlicerVineSimplcial
        Python wrapper whose `get_ptr()` returns the address of the
        underlying C++ slicer.
    box : pair of list of floats
        Rectangle [x, y] on which to compute the approximation; converted to
        a `Box[value_type]`.
    num_parameters : int
        NOTE(review): not used by this function.
    max_error : float
        Precision of the approximation, forwarded to the C++ routine.
    complete, threshold, verbose : bool
        Forwarded to the C++ routine.

    Returns
    -------
    PyModule

    Raises
    ------
    ValueError
        If `slicer` is none of the supported wrapper types.
    """
    cdef intptr_t slicer_ptr = <intptr_t>(slicer.get_ptr())
    # One C++ local per supported slicer type; only the matching one is filled.
    cdef GeneralVineTruc cslicer
    cdef GeneralVineClementTruc generalclementtruc
    cdef SimplicialVineGraphTruc graphtruc
    cdef SimplicialVineMatrixTruc matrixtruc
    cdef Module mod
    cdef Box[value_type] c_box = Box[value_type](box)
    # Dispatch on the Python wrapper type; each branch copy-assigns the C++
    # slicer into a local (presumably so the caller's slicer is left untouched
    # -- TODO confirm) and runs the approximation without the GIL.
    if isinstance(slicer,Slicer):
        cslicer = dereference(<GeneralVineTruc*>(slicer_ptr))
        with nogil:
            mod = multiparameter_module_approximation(cslicer, max_error,c_box,threshold, complete, verbose)
    elif isinstance(slicer,SlicerClement):
        generalclementtruc = dereference(<GeneralVineClementTruc*>(slicer_ptr))
        with nogil:
            mod = multiparameter_module_approximation(generalclementtruc, max_error,c_box,threshold, complete, verbose)
    elif isinstance(slicer,SlicerVineGraph):
        graphtruc = dereference(<SimplicialVineGraphTruc*>(slicer_ptr))
        with nogil:
            mod = multiparameter_module_approximation(graphtruc, max_error,c_box,threshold, complete, verbose)
    elif isinstance(slicer,SlicerVineSimplcial):
        matrixtruc = dereference(<SimplicialVineMatrixTruc*>(slicer_ptr))
        with nogil:
            mod = multiparameter_module_approximation(matrixtruc, max_error,c_box,threshold, complete, verbose)
    else:
        raise ValueError("Unimplemeted slicer / Invalid slicer.")


    # Wrap the computed C++ module into a Python PyModule.
    approx_mod = PyModule()
    approx_mod._set_from_ptr(<intptr_t>(&mod))
    return approx_mod
multipers/pickle.py ADDED
@@ -0,0 +1,53 @@
1
+ import numpy as np
2
+
3
def save_with_axis(path:str, signed_measures):
    """Write signed measures nested as ``[data][axis][degree]`` to an ``.npz`` file.

    Every measure ``(points, weights)`` is stored as one 2D array made of the
    point coordinates followed by a last column of weights, under the key
    ``"{data}_{axis}_{degree}"``.
    """
    arrays = {}
    for data_idx, per_axis in enumerate(signed_measures):
        for axis_idx, per_degree in enumerate(per_axis):
            for degree_idx, measure in enumerate(per_degree):
                points = measure[0]
                weights_column = measure[1][:, np.newaxis]
                arrays[f"{data_idx}_{axis_idx}_{degree_idx}"] = np.c_[points, weights_column]
    np.savez(path, **arrays)
7
+
8
def save_without_axis(path:str, signed_measures):
    """Write signed measures nested as ``[data][degree]`` to an ``.npz`` file.

    Every measure ``(points, weights)`` is stored as one 2D array made of the
    point coordinates followed by a last column of weights, under the key
    ``"{data}_{degree}"``.
    """
    arrays = {}
    for data_idx, per_degree in enumerate(signed_measures):
        for degree_idx, measure in enumerate(per_degree):
            points = measure[0]
            weights_column = measure[1][:, np.newaxis]
            arrays[f"{data_idx}_{degree_idx}"] = np.c_[points, weights_column]
    np.savez(path, **arrays)
12
+
13
def get_sm_with_axis(sms,idx,axis,degree):
    """Return the ``(points, weights)`` measure stored under ``"{idx}_{axis}_{degree}"``.

    The stored array holds the point coordinates in all columns but the last,
    and the weights in the last column.
    """
    stacked = sms[f"{idx}_{axis}_{degree}"]
    points = stacked[:, :-1]
    weights = stacked[:, -1]
    return points, weights
16
def get_sm_without_axis(sms,idx,degree):
    """Return the ``(points, weights)`` measure stored under ``"{idx}_{degree}"``.

    The stored array holds the point coordinates in all columns but the last,
    and the weights in the last column.
    """
    stacked = sms[f"{idx}_{degree}"]
    points = stacked[:, :-1]
    weights = stacked[:, -1]
    return points, weights
19
+
20
+
21
def load_without_axis(sms):
    """Rebuild signed measures nested as ``[data][degree]`` from an archive.

    ``sms`` maps keys ``"{data}_{degree}"`` to stacked arrays (as written by
    `save_without_axis`). The number of data and of degrees is taken from the
    largest indices appearing in the keys.
    """
    key_indices = np.array(
        [[int(part) for part in key.split('_')] for key in sms.keys()], dtype=int
    )
    num_data, num_degrees = key_indices.max(axis=0) + 1
    reconstructed = []
    for idx in range(num_data):
        reconstructed.append(
            [get_sm_without_axis(sms, idx, degree) for degree in range(num_degrees)]
        )
    return reconstructed
27
+
28
def load_with_axis(sms):
    """Rebuild signed measures nested as ``[data][axis][degree]`` from an archive.

    ``sms`` maps keys ``"{data}_{axis}_{degree}"`` to stacked arrays (as
    written by `save_with_axis`). The sizes of the three levels are taken from
    the largest indices appearing in the keys.
    """
    key_indices = np.array(
        [[int(part) for part in key.split('_')] for key in sms.keys()], dtype=int
    )
    num_data, num_axis, num_degrees = key_indices.max(axis=0) + 1
    reconstructed = []
    for idx in range(num_data):
        per_axis = []
        for axis in range(num_axis):
            per_axis.append(
                [get_sm_with_axis(sms, idx, axis, degree) for degree in range(num_degrees)]
            )
        reconstructed.append(per_axis)
    return reconstructed
33
+
34
def save(path:str, signed_measures):
    """Write signed measures to ``path``, choosing the layout from their nesting.

    When ``signed_measures[0][0]`` is a tuple, the input is read as
    ``[data][degree]`` and saved with `save_without_axis`; otherwise it is
    read as ``[data][axis][degree]`` and saved with `save_with_axis`.
    """
    first_entry = signed_measures[0][0]
    saver = save_without_axis if isinstance(first_entry, tuple) else save_with_axis
    saver(path=path, signed_measures=signed_measures)
39
+
40
+ def load(path:str):
41
+ sms = np.load(path)
42
+ item=None
43
+ for i in sms.keys():
44
+ item=i
45
+ break
46
+ n = len(item.split('_'))
47
+ match n:
48
+ case 2:
49
+ return load_without_axis(sms)
50
+ case 3:
51
+ return load_with_axis(sms)
52
+ case _:
53
+ raise Exception("Invalid Signed Measure !")
multipers/plots.py ADDED
@@ -0,0 +1,292 @@
1
+ import matplotlib.pyplot as plt
2
+ from typing import Optional
3
+ import numpy as np
4
+
5
+
6
def _plot_rectangle(rectangle: np.ndarray, weight, **plt_kwargs):
    """Plot the line joining the two corners of ``rectangle``.

    ``rectangle`` is read as ``(x0, y0, x1, y1)``: the line goes from
    ``(x0, y0)`` to ``(x1, y1)``. It is blue for a positive ``weight`` and red
    otherwise; extra keyword arguments go to ``plt.plot``.
    """
    coords = np.asarray(rectangle)
    line_color = "red"
    if weight > 0:
        line_color = "blue"
    plt.plot(coords[[0, 2]], coords[[1, 3]], c=line_color, **plt_kwargs)
12
+
13
+
14
def _plot_signed_measure_2(pts, weights, temp_alpha=0.7, **plt_kwargs):
    """Scatter-plot a 2-parameter signed measure on the current axes.

    Colors come from a diverging bordeaux / white / blue colormap on [-2, 2].
    Negative weights are divided by the largest absolute negative weight and
    shifted to [-2, -1), positive weights are divided by the largest positive
    weight and shifted to (1, 2], and zero weights stay at 0 (white).
    ``temp_alpha`` is the alpha of the two intermediate colormap stops.
    """
    import matplotlib.colors

    weights = np.asarray(weights)
    color_values = np.array(weights, dtype=float)
    negative = weights < 0
    positive = weights > 0
    if negative.any():
        largest_negative = np.max(-weights[negative])
        color_values[negative] /= largest_negative
        color_values[negative] -= 1
    if positive.any():
        largest_positive = np.max(weights[positive])
        color_values[positive] /= largest_positive
        color_values[positive] += 1

    red_rgb = [0.70567316, 0.01555616, 0.15023281]
    blue_rgb = [0.2298057, 0.29871797, 0.75368315]
    dark_red = np.array([*red_rgb, 1])
    faded_red = np.array([*red_rgb, temp_alpha])
    dark_blue = np.array([*blue_rgb, 1])
    faded_blue = np.array([*blue_rgb, temp_alpha])
    normalization = plt.Normalize(-2, 2)
    colormap = matplotlib.colors.LinearSegmentedColormap.from_list(
        "", [dark_red, faded_red, "white", faded_blue, dark_blue]
    )
    plt.scatter(
        pts[:, 0], pts[:, 1], c=color_values, cmap=colormap, norm=normalization, **plt_kwargs
    )
    # Empty scatters: they only create the legend entries.
    plt.scatter([], [], color=dark_blue, label="positive mass", **plt_kwargs)
    plt.scatter([], [], color=dark_red, label="negative mass", **plt_kwargs)
    plt.legend()
51
+
52
+
53
def _plot_signed_measure_4(
    pts, weights, x_smoothing: float = 1, area_alpha: bool = True
):
    """Plot a 4-parameter signed measure, one `_plot_rectangle` call per point.

    Each point ``(x0, y0, x1, y1)`` is drawn as ``(x0, y0, x1 / x_smoothing, y1)``.
    Points with ``x1 <= x_smoothing * x0`` (reduced to the empty set) are skipped.

    When ``area_alpha`` is true, the alpha of a kept point is its area
    ``(x1 / x_smoothing - x0) * (y1 - y0)`` divided by the largest such area,
    clipped below at 0. When no kept point has a positive area, everything is
    drawn with alpha 1; this case used to divide by zero. Without
    ``area_alpha``, alpha is always 1.
    """
    kept = [
        (rectangle, weight)
        for rectangle, weight in zip(pts, weights)
        if rectangle[2] > x_smoothing * rectangle[0]
    ]
    areas = [
        (rectangle[2] / x_smoothing - rectangle[0]) * (rectangle[3] - rectangle[1])
        for rectangle, _ in kept
    ]
    # The scale starts at 0 so that only positive areas can set it.
    max_area = max([0, *areas])
    for (rectangle, weight), area in zip(kept, areas):
        if area_alpha and max_area > 0:
            # Negative areas (y1 < y0) would give an invalid negative alpha.
            alpha = max(area / max_area, 0)
        else:
            alpha = 1
        _plot_rectangle(
            rectangle=[
                rectangle[0],
                rectangle[1],
                rectangle[2] / x_smoothing,
                rectangle[3],
            ],
            weight=weight,
            alpha=alpha,
        )
94
+
95
+
96
def plot_signed_measure(signed_measure, ax=None, **plt_kwargs):
    """Plot a signed measure ``(points, weights)`` on ``ax`` (default: current axes).

    ``points`` must be a NumPy array or a torch tensor with 2 or 4 columns.
    Two columns give a colored scatter plot; four columns give one segment per
    point. Torch tensors are detached and moved to the CPU before plotting.
    When ``points`` is a tensor, this also applies to ``weights``; previously,
    weights requiring grad made ``np.asarray`` fail. Remaining keyword
    arguments are forwarded to the 2- or 4-parameter plotting helper.

    Raises
    ------
    Exception
        If ``points`` is neither a NumPy array nor a torch tensor.
    AssertionError
        If ``points`` does not have 2 or 4 columns.
    """
    if ax is None:
        ax = plt.gca()
    else:
        plt.sca(ax)
    pts, weights = signed_measure
    num_parameters = pts.shape[1]

    if not isinstance(pts, np.ndarray):
        import torch

        if not isinstance(pts, torch.Tensor):
            raise Exception("Invalid measure type.")
        # `.numpy()` refuses tensors that require grad or that live on a GPU.
        pts = pts.detach().cpu().numpy()
        if isinstance(weights, torch.Tensor):
            weights = weights.detach().cpu().numpy()

    assert num_parameters in (2, 4)
    if num_parameters == 2:
        _plot_signed_measure_2(pts=pts, weights=weights, **plt_kwargs)
    else:
        _plot_signed_measure_4(pts=pts, weights=weights, **plt_kwargs)
119
+
120
+
121
def plot_signed_measures(signed_measures, size=4):
    """Plot one signed measure per degree, side by side in a single row.

    Each subplot is ``size`` x ``size`` inches.
    """
    num_degrees = len(signed_measures)
    _, axes = plt.subplots(
        nrows=1, ncols=num_degrees, figsize=(num_degrees * size, size)
    )
    # With a single column, plt.subplots returns a bare Axes, not an array.
    axes_list = [axes] if num_degrees == 1 else axes
    for ax, measure in zip(axes_list, signed_measures):
        plot_signed_measure(signed_measure=measure, ax=ax)
    plt.tight_layout()
131
+
132
+
133
def plot_surface(
    grid,
    hf,
    fig=None,
    ax=None,
    cmap: Optional[str] = None,
    discrete_surface=False,
    **plt_args,
):
    """Plot a 2-parameter function ``hf`` sampled on ``grid`` with ``pcolormesh``.

    Parameters
    ----------
    grid : pair of 1D arrays
        Coordinates along the first and second parameter.
    hf : 2D array, or 3D array whose leading axis has size 1
        Values to plot; ``hf`` is transposed before ``pcolormesh``, so its first
        axis follows ``grid[0]`` and its second axis ``grid[1]``.
    fig, ax : optional
        Figure (for the colorbar) and axes to draw on; default to the current ones.
    cmap : str, Colormap or None
        Colormap or registered colormap name. Defaults to "gray_r" for
        discrete surfaces and "plasma" otherwise. Names are now resolved before
        use; previously a string crashed the discrete branch on ``cmap.N``.
    discrete_surface : bool
        If true, values are binned on the integers 0..10 (larger values use
        the "max" extension color) and a colorbar is added.
    **plt_args
        Forwarded to ``ax.pcolormesh``.
    """
    import matplotlib

    if ax is None:
        ax = plt.gca()
    else:
        plt.sca(ax)
    if hf.ndim == 3 and hf.shape[0] == 1:
        hf = hf[0]
    assert hf.ndim == 2, "Can only plot a 2d surface"
    fig = plt.gcf() if fig is None else fig
    if cmap is None:
        cmap = "gray_r" if discrete_surface else "plasma"
    if isinstance(cmap, str):
        # BoundaryNorm below needs a Colormap object (it reads `cmap.N`).
        cmap = matplotlib.colormaps[cmap]
    if discrete_surface:
        bounds = np.arange(0, 11, 1, dtype=int)
        norm = matplotlib.colors.BoundaryNorm(bounds, cmap.N, extend="max")
        ax.pcolormesh(grid[0], grid[1], hf.T, cmap=cmap, norm=norm, **plt_args)
        cbar = fig.colorbar(
            matplotlib.cm.ScalarMappable(cmap=cmap, norm=norm),
            spacing="proportional",
            ax=ax,
        )
        cbar.set_ticks(ticks=bounds, labels=bounds)
        return
    ax.pcolormesh(grid[0], grid[1], hf.T, cmap=cmap, **plt_args)
169
+
170
+
171
def plot_surfaces(HF, size=4, **plt_args):
    """Plot one 2-parameter surface per degree, side by side in a single row.

    ``HF`` is a ``(grid, hf)`` pair where ``hf`` is 3-dimensional: degree
    first, then the 2-parameter surface. Each subplot is ``size`` x ``size``
    inches; extra keyword arguments go to `plot_surface`.
    """
    grid, hf = HF
    assert (
        hf.ndim == 3
    ), f"Found hf.shape = {hf.shape}, expected ndim = 3 : degree, 2-parameter surface."
    num_degrees = hf.shape[0]
    fig, axes = plt.subplots(
        nrows=1, ncols=num_degrees, figsize=(num_degrees * size, size)
    )
    # With a single column, plt.subplots returns a bare Axes, not an array.
    axes_list = [axes] if num_degrees == 1 else axes
    for ax, surface in zip(axes_list, hf):
        plot_surface(grid=grid, hf=surface, fig=fig, ax=ax, **plt_args)
    plt.tight_layout()
185
+
186
+
187
def _rectangle(x, y, color, alpha):
    """Build a matplotlib patch for the rectangle {z | x ≤ z ≤ y}.

    ``x`` is the lower-left corner and ``y`` the upper-right one; negative
    side lengths are clamped to 0. ``color`` and ``alpha`` style the patch.
    """
    from matplotlib.patches import Rectangle as RectanglePatch

    width = max(y[0] - x[0], 0)
    height = max(y[1] - x[1], 0)
    return RectanglePatch(x, width, height, color=color, alpha=alpha)
196
+
197
+
198
+ def _d_inf(a, b):
199
+ if type(a) != np.ndarray or type(b) != np.ndarray:
200
+ a = np.array(a)
201
+ b = np.array(b)
202
+ return np.min(np.abs(b - a))
203
+
204
+
205
def plot2d_PyModule(
    corners,
    box=[],
    *,
    dimension=-1,
    separated=False,
    min_persistence=0,
    alpha=1,
    verbose=False,
    save=False,
    dpi=200,
    shapely=True,
    xlabel=None,
    ylabel=None,
    cmap=None,
):
    """Plot the summands of a 2-parameter interval module.

    Parameters
    ----------
    corners : sequence
        One entry per summand: ``corners[i][0]`` are its birth corners and
        ``corners[i][1]`` its death corners (2D points). Never modified.
    box : pair of 2D points
        ``[lower_corner, upper_corner]`` of the plotting window; death corners
        are clipped to ``box[1]``. Must be provided despite the default.
    dimension : int
        If non-negative, the homological degree shown in the title.
    separated : bool
        If true, each non-trivial summand is drawn in its own figure.
    min_persistence : float
        A summand is drawn only if one of its (birth, death) rectangles has
        both sides longer than this value.
    alpha : float
        Transparency of the drawn summands.
    verbose, save, dpi :
        Unused; kept for backward compatibility.
    shapely : bool
        Merge each summand's rectangles into polygons with shapely. If shapely
        cannot be imported, raw matplotlib rectangles are drawn instead, with
        a warning.
    xlabel, ylabel : str or None
        Axis labels.
    cmap : str or None
        Name of a matplotlib colormap (default "Spectral"); summand ``i`` is
        colored ``cmap(i / n_summands)``.
    """
    import matplotlib

    if shapely:
        # Only attempt the import (and warn) when shapely was requested.
        try:
            from shapely.geometry import Polygon as _Polygon
            from shapely.geometry import box as _rectangle_box
            from shapely import union_all
        except (ImportError, OSError):
            from warnings import warn

            shapely = False
            warn(
                "Shapely not installed. Fallbacking to matplotlib. The plots may be inacurate."
            )
    cmap = (
        matplotlib.colormaps["Spectral"] if cmap is None else matplotlib.colormaps[cmap]
    )
    # Braces keep multi-digit degrees (e.g. 10) inside the mathtext subscript.
    title = rf"$H_{{{dimension}}}$ $2$-persistence"
    if not separated:
        ax = plt.gca()
        ax.set(xlim=[box[0][0], box[1][0]], ylim=[box[0][1], box[1][1]])
    n_summands = len(corners)
    for i in range(n_summands):
        color = cmap(i / n_summands)
        trivial_summand = True
        list_of_rect = []
        for birth in corners[i][0]:
            for death in corners[i][1]:
                # Clip the death corner to the box without writing into
                # `corners` (it used to be modified in place).
                death_x = min(death[0], box[1][0])
                death_y = min(death[1], box[1][1])
                if death_y > birth[1] and death_x > birth[0]:
                    clipped_death = (death_x, death_y)
                    if trivial_summand and _d_inf(birth, clipped_death) > min_persistence:
                        trivial_summand = False
                    if shapely:
                        list_of_rect.append(
                            _rectangle_box(birth[0], birth[1], death_x, death_y)
                        )
                    else:
                        list_of_rect.append(
                            _rectangle(birth, clipped_death, color, alpha)
                        )
        if trivial_summand:
            continue
        if separated:
            _, ax = plt.subplots()
            ax.set(xlim=[box[0][0], box[1][0]], ylim=[box[0][1], box[1][1]])
        if shapely:
            summand_shape = union_all(list_of_rect)
            # The union is either one Polygon or a collection of polygons.
            polygons = (
                [summand_shape]
                if isinstance(summand_shape, _Polygon)
                else summand_shape.geoms
            )
            for polygon in polygons:
                xs, ys = polygon.exterior.xy
                ax.fill(xs, ys, alpha=alpha, fc=color, ec="None")
        else:
            for rectangle in list_of_rect:
                ax.add_patch(rectangle)
        if separated:
            if xlabel:
                plt.xlabel(xlabel)
            if ylabel:
                plt.ylabel(ylabel)
            if dimension >= 0:
                plt.title(title)
    if not separated:
        if xlabel is not None:
            plt.xlabel(xlabel)
        if ylabel is not None:
            plt.ylabel(ylabel)
        if dimension >= 0:
            plt.title(title)
    return