multipers 2.0.0__cp312-cp312-macosx_13_0_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic. Click here for more details.

Files changed (78) hide show
  1. multipers/.dylibs/libc++.1.0.dylib +0 -0
  2. multipers/.dylibs/libtbb.12.12.dylib +0 -0
  3. multipers/.dylibs/libtbbmalloc.2.12.dylib +0 -0
  4. multipers/__init__.py +11 -0
  5. multipers/_signed_measure_meta.py +268 -0
  6. multipers/_slicer_meta.py +171 -0
  7. multipers/data/MOL2.py +350 -0
  8. multipers/data/UCR.py +18 -0
  9. multipers/data/__init__.py +1 -0
  10. multipers/data/graphs.py +466 -0
  11. multipers/data/immuno_regions.py +27 -0
  12. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  13. multipers/data/pytorch2simplextree.py +91 -0
  14. multipers/data/shape3d.py +101 -0
  15. multipers/data/synthetic.py +68 -0
  16. multipers/distances.py +198 -0
  17. multipers/euler_characteristic.pyx +132 -0
  18. multipers/filtration_conversions.pxd +229 -0
  19. multipers/filtrations.pxd +225 -0
  20. multipers/function_rips.cpython-312-darwin.so +0 -0
  21. multipers/function_rips.pyx +105 -0
  22. multipers/grids.cpython-312-darwin.so +0 -0
  23. multipers/grids.pyx +281 -0
  24. multipers/hilbert_function.pyi +46 -0
  25. multipers/hilbert_function.pyx +153 -0
  26. multipers/io.cpython-312-darwin.so +0 -0
  27. multipers/io.pyx +571 -0
  28. multipers/ml/__init__.py +0 -0
  29. multipers/ml/accuracies.py +90 -0
  30. multipers/ml/convolutions.py +532 -0
  31. multipers/ml/invariants_with_persistable.py +79 -0
  32. multipers/ml/kernels.py +176 -0
  33. multipers/ml/mma.py +659 -0
  34. multipers/ml/one.py +472 -0
  35. multipers/ml/point_clouds.py +238 -0
  36. multipers/ml/signed_betti.py +50 -0
  37. multipers/ml/signed_measures.py +1542 -0
  38. multipers/ml/sliced_wasserstein.py +461 -0
  39. multipers/ml/tools.py +113 -0
  40. multipers/mma_structures.cpython-312-darwin.so +0 -0
  41. multipers/mma_structures.pxd +127 -0
  42. multipers/mma_structures.pyx +2433 -0
  43. multipers/multiparameter_edge_collapse.py +41 -0
  44. multipers/multiparameter_module_approximation.cpython-312-darwin.so +0 -0
  45. multipers/multiparameter_module_approximation.pyx +211 -0
  46. multipers/pickle.py +53 -0
  47. multipers/plots.py +326 -0
  48. multipers/point_measure_integration.cpython-312-darwin.so +0 -0
  49. multipers/point_measure_integration.pyx +139 -0
  50. multipers/rank_invariant.cpython-312-darwin.so +0 -0
  51. multipers/rank_invariant.pyx +229 -0
  52. multipers/simplex_tree_multi.cpython-312-darwin.so +0 -0
  53. multipers/simplex_tree_multi.pxd +129 -0
  54. multipers/simplex_tree_multi.pyi +715 -0
  55. multipers/simplex_tree_multi.pyx +4655 -0
  56. multipers/slicer.cpython-312-darwin.so +0 -0
  57. multipers/slicer.pxd +781 -0
  58. multipers/slicer.pyx +3393 -0
  59. multipers/tensor.pxd +13 -0
  60. multipers/test.pyx +44 -0
  61. multipers/tests/__init__.py +40 -0
  62. multipers/tests/old_test_rank_invariant.py +91 -0
  63. multipers/tests/test_diff_helper.py +74 -0
  64. multipers/tests/test_hilbert_function.py +82 -0
  65. multipers/tests/test_mma.py +51 -0
  66. multipers/tests/test_point_clouds.py +59 -0
  67. multipers/tests/test_python-cpp_conversion.py +82 -0
  68. multipers/tests/test_signed_betti.py +181 -0
  69. multipers/tests/test_simplextreemulti.py +98 -0
  70. multipers/tests/test_slicer.py +63 -0
  71. multipers/torch/__init__.py +1 -0
  72. multipers/torch/diff_grids.py +217 -0
  73. multipers/torch/rips_density.py +257 -0
  74. multipers-2.0.0.dist-info/LICENSE +21 -0
  75. multipers-2.0.0.dist-info/METADATA +29 -0
  76. multipers-2.0.0.dist-info/RECORD +78 -0
  77. multipers-2.0.0.dist-info/WHEEL +5 -0
  78. multipers-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,41 @@
1
+ from tqdm import tqdm
2
+
3
+
4
def _collapse_edge_list(
    edges,
    num: int = 0,
    full: bool = False,
    strong: bool = False,
    progress: bool = False,
):
    """
    Simplify a 1-critical, 2-parameter, 1-dimensional filtered simplicial
    complex, given as an edge list, with filtration-domination's edge collapser.

    Parameters
    ----------
    edges:
        Edge list, passed as-is to filtration_domination's removal functions.
    num: int
        Maximum number of collapse passes (overridden to 100 when ``full``).
    full: bool
        Run up to 100 passes. When strong collapses stop removing edges,
        switch to (non-strong) filtration-domination collapses instead of
        stopping.
    strong: bool
        Start with strongly-filtration-dominated edge removal.
    progress: bool
        Display a tqdm progress bar.

    Returns
    -------
    The collapsed edge list.
    """
    from filtration_domination import (
        remove_strongly_filtration_dominated,
        remove_filtration_dominated,
    )

    previous_size = len(edges)
    num_passes = 100 if full else num
    with tqdm(
        range(num_passes),
        total=num_passes,
        desc="Removing edges",
        disable=not progress,
    ) as passes:
        for _ in passes:
            # `strong` may be switched off mid-loop, so pick the collapser per pass.
            collapse = (
                remove_strongly_filtration_dominated
                if strong
                else remove_filtration_dominated
            )
            edges = collapse(edges)  # nogil ?
            current_size = len(edges)
            if current_size < previous_size:
                previous_size = current_size
                continue
            # No edge was removed by this pass: further passes of the same kind
            # would be useless. In full strong mode, fall back to regular
            # collapses; otherwise stop.
            if not (full and strong):
                break
            strong = False
            previous_size = current_size
    return edges
@@ -0,0 +1,211 @@
1
+ """!
2
+ @package mma
3
+ @brief Files containing the C++ cythonized functions.
4
+ @author David Loiseaux
5
+ @copyright Copyright (c) 2022 Inria.
6
+ """
7
+
8
+ # distutils: language = c++
9
+
10
+ ###########################################################################
11
+ ## PYTHON LIBRARIES
12
+ import gudhi as gd
13
+ import numpy as np
14
+ from typing import List
15
+ import pickle as pk
16
+
17
+ ###########################################################################
18
+ ## CPP CLASSES
19
+ from libc.stdint cimport intptr_t
20
+ from libc.stdint cimport uintptr_t
21
+
22
+ ###########################################################################
23
+ ## CYTHON TYPES
24
+ from libcpp.vector cimport vector
25
+ from libcpp.utility cimport pair
26
+ #from libcpp.list cimport list as clist
27
+ from libcpp cimport bool
28
+ from libcpp cimport int
29
+ from typing import Iterable,Optional, Literal
30
+ from cython.operator import dereference
31
+ #########################################################################
32
+ ## Multipersistence Module Approximation Classes
33
+ from multipers.mma_structures cimport *
34
+ from multipers.filtrations cimport *
35
+ from multipers.filtration_conversions cimport *
36
+ cimport numpy as cnp
37
+
38
+
39
+ #########################################################################
40
+ ## Small hack for typing
41
+ from multipers.simplex_tree_multi import is_simplextree_multi, SimplexTreeMulti_type
42
+ from multipers.slicer import Slicer_type, is_slicer
43
+ from multipers.mma_structures import *
44
+ from typing import Union
45
+ import multipers
46
+ import multipers.io as mio
47
+ from multipers.slicer cimport _multiparameter_module_approximation_f32, _multiparameter_module_approximation_f64
48
+
49
+
50
def module_approximation_from_slicer(
    slicer:Slicer_type,
    box:Optional[np.ndarray]=None,
    max_error=-1,
    bool complete=True,
    bool threshold=False,
    bool verbose=False,
    list[float] direction = [],
    ):
    """Computes an interval module approximation directly from a slicer.

    Parameters
    ----------
    slicer: Slicer-like
        Must be vineyard-capable (``slicer.is_vine``) and have a ``dtype`` of
        ``np.float32`` or ``np.float64``.
    box: (Optional) array-like
        Rectangle on which to compute the approximation; converted to a C++
        ``Box`` of the slicer's precision. Defaults to ``slicer.compute_box()``.
    max_error: float
        Forwarded to the C++ approximation routine (``-1`` by default).
    complete, threshold, verbose: bool
        Forwarded to the C++ approximation routine.
    direction: list of floats
        Line direction, converted to a C++ vector. Empty by default. It is not
        modified here.

    Returns
    -------
    PyModule_f32 or PyModule_f64
        Matching the slicer's dtype.

    Raises
    ------
    ValueError
        If the slicer cannot do vineyards, or is not float32/float64.
    """

    cdef Module[float] mod_f32
    cdef Module[double] mod_f64
    cdef intptr_t ptr
    if not slicer.is_vine:
        raise ValueError(f"Slicer must be able to do vineyards. Got {slicer}")
    # Each branch creates the Python wrapper matching the slicer precision and
    # runs the C++ approximation with the same precision.
    if slicer.dtype == np.float32:
        approx_mod = PyModule_f32()
        if box is None:
            box = slicer.compute_box()
        mod_f32 = _multiparameter_module_approximation_f32(slicer,_py21c_f32(direction), max_error,Box[float](box),threshold, complete, verbose)
        # Address of the local C++ module. NOTE(review): presumably
        # _set_from_ptr copies it into approx_mod, since mod_f32 goes out of
        # scope on return; confirm.
        ptr = <intptr_t>(&mod_f32)
    elif slicer.dtype == np.float64:
        approx_mod = PyModule_f64()
        if box is None:
            box = slicer.compute_box()
        mod_f64 = _multiparameter_module_approximation_f64(slicer,_py21c_f64(direction), max_error,Box[double](box),threshold, complete, verbose)
        # Same pointer hand-off as in the float32 branch (see note above).
        ptr = <intptr_t>(&mod_f64)
    else:
        raise ValueError(f"Slicer must be float-like. Got {slicer.dtype}.")

    approx_mod._set_from_ptr(ptr)

    return approx_mod
83
+
84
def module_approximation(
    input:Union[SimplexTreeMulti_type,Slicer_type, tuple],
    box:Optional[np.ndarray]=None,
    float max_error=-1,
    int nlines=500,
    slicer_backend:Literal["matrix","clement","graph"]="matrix",
    minpres:Optional[Literal["mpfree"]]=None,
    degree:Optional[int]=None,
    bool complete=True,
    bool threshold=False,
    bool verbose=False,
    bool ignore_warning=False,
    id="",
    list[float] direction = [],
    list[int] swap_box_coords = [],
    *,
    int n_jobs = 1,
    ):
    """Computes an interval module approximation of a multiparameter filtration.

    Parameters
    ----------
    input: SimplexTreeMulti, Slicer-like, or a tuple/list of those.
        Holds the multifiltered complex.
        If a tuple or a list is given, each element is approximated with the
        same parameters and the resulting modules are merged into the first
        one. An empty tuple/list returns an empty ``PyModule``.
    box : (Optional) pair of list of floats
        Defines a rectangle on which to compute the approximation.
        Format : [x,y], This defines a rectangle on which we draw the lines,
        uniformly drawn (with a max_error step).
        The first line is `x`.
        Defaults to the filtration bounds of ``input``. The given box is never
        modified in place.
        **Warning**: For custom boxes, and directions, you **must** ensure
        that the first line captures a generic barcode.
    max_error: positive float
        Trade-off between approximation and computational complexity.
        Upper bound of the module approximation, in bottleneck distance,
        for interval-decomposable modules.
        If non-positive, it is derived from ``nlines``.
    nlines: int = 500
        Alternative to max_error (only used when ``max_error <= 0``);
        specifies the number of persistence computations used for the approximation.
    slicer_backend: Either "matrix","clement", or "graph".
        If a simplextree is given, it is first converted to this structure,
        with different choices of backends.
    minpres: (Optional) "mpfree" only for the moment.
        If given, and the input is a simplextree,
        computes a minimal presentation before starting the computation.
        A degree has to be given.
    degree: int Only required when minpres is given.
        Homological degree of the minimal presentation.
    complete: bool
        Forwarded to ``module_approximation_from_slicer``.
    threshold: bool
        When true, intersects the module support with the box,
        i.e. no more infinite summands.
    verbose: bool
        Prints C++ infos.
    ignore_warning : bool
        Unless set to true, a ValueError is raised when the total volume of
        the box faces (see ``prod`` in the code) is at least 10_000.
        Useful to prevent a segmentation fault due to "infinite" recursion.
    id: str
        Suffix appended to ``mio.input_path`` to name the temporary file
        written when ``minpres`` is given.
    direction: float[:] = []
        If given, the line are drawn with this angle.
        **Warning**: You must ensure that the first line, defined by box,
        captures a generic barcode.
    swap_box_coords: list of int
        Parameters whose lower and upper box bounds are swapped.
    n_jobs: int = 1 (keyword-only)
        Number of joblib threads used when ``input`` is a tuple or a list.

    Returns
    -------
    PyModule
        An interval decomposable module approximation of the module defined by the
        homology of this multi-filtration.

    Raises
    ------
    ValueError
        If the box is too large and ``ignore_warning`` is False, or if
        ``minpres`` is given without ``degree``.
    """
    if isinstance(input, (tuple, list)):
        if len(input) == 0:
            return PyModule()
        # Same positional arguments for every element; n_jobs keeps its
        # default (1) in the recursive calls.
        shared_args = (
            box, max_error, nlines, slicer_backend, minpres, degree,
            complete, threshold, verbose, ignore_warning, id,
            direction, swap_box_coords,
        )
        if n_jobs <= 1:
            modules = tuple(module_approximation(slicer, *shared_args) for slicer in input)
        else:
            from joblib import Parallel, delayed
            modules = tuple(Parallel(n_jobs=n_jobs, prefer="threads")(
                delayed(module_approximation)(slicer, *shared_args)
                for slicer in input
            ))
        mod = modules[0]
        for m in modules[1:]:
            mod.merge(m)
        return mod
    if box is None:
        if is_simplextree_multi(input):
            box = input.filtration_bounds()
        else:
            box = input.compute_box()
    # Work on a copy. Swapping coordinates in place must not modify the
    # caller's array. With np.asarray, the tuple/list branch above would also
    # hand the same array to every recursive call, re-swapping it each time.
    box = np.array(box)
    for i in swap_box_coords:
        box[0,i], box[1,i] = box[1,i], box[0,i]
    num_parameters = box.shape[1]
    assert len(direction) == 0 or len(direction) == num_parameters, f"Invalid line direction, has to be 0 or {num_parameters=}"

    # Sum, over the parameters, of the (num_parameters-1)-dimensional volume
    # of the box face orthogonal to that parameter.
    widths = np.abs(box[1] - box[0])
    prod = sum(widths[:i].prod() * widths[i+1:].prod() for i in range(0,num_parameters))

    if max_error <= 0:
        max_error = (prod/nlines)**(1/(num_parameters-1))

    # NOTE(review): this threshold applies to the face volume `prod`, not to
    # the estimated line count (~ prod / max_error**(num_parameters-1)).
    # Confirm this is the intended heuristic.
    if not ignore_warning and prod >= 10_000:
        raise ValueError(f"""
        Warning : the number of lines (around {np.round(prod)}) may be too high.
        Try to increase the precision parameter, or set `ignore_warning=True` to compute this module."""
        )
    if is_simplextree_multi(input):
        blocks = input._to_scc()
        # Read `is_kcritical` whether it is exposed as a property or as a
        # method, so that the check and the Slicer conversion below agree.
        is_kcritical = input.is_kcritical
        if callable(is_kcritical):
            is_kcritical = is_kcritical()
        if minpres is not None:
            if degree is None:
                raise ValueError("`degree` has to be given when `minpres` is set.")
            assert not is_kcritical, "scc (and therefore mpfree, multi_chunk, 2pac, ...) format doesn't handle multi-critical filtrations."
            mio.scc2disk(blocks, mio.input_path+id)
            blocks = mio.reduce_complex(mio.input_path+id, dimension=input.dimension-degree, backend=minpres)
        input = multipers.Slicer(blocks,backend=slicer_backend, dtype = input.dtype, is_kcritical = is_kcritical)
    assert is_slicer(input), "First argument must be a simplextree or a slicer !"
    return module_approximation_from_slicer(
        slicer=input,
        box=box,
        max_error=max_error,
        complete=complete,
        threshold=threshold,
        verbose=verbose,
        direction=direction,
    )
207
+
208
+
209
+
210
+
211
+
multipers/pickle.py ADDED
@@ -0,0 +1,53 @@
1
+ import numpy as np
2
+
3
def save_with_axis(path: str, signed_measures):
    """Save signed measures indexed as ``signed_measures[data][axis][degree]``.

    Each measure ``(points, weights)`` is stored in the ``.npz`` archive at
    ``path`` as a single array whose last column holds the weights. The key is
    ``"{data}_{axis}_{degree}"``.
    """
    arrays = {}
    for data_idx, per_axis in enumerate(signed_measures):
        for axis_idx, per_degree in enumerate(per_axis):
            for degree_idx, measure in enumerate(per_degree):
                points, weights = measure[0], measure[1]
                key = f"{data_idx}_{axis_idx}_{degree_idx}"
                arrays[key] = np.c_[points, weights[:, np.newaxis]]
    np.savez(path, **arrays)
7
+
8
def save_without_axis(path: str, signed_measures):
    """Save signed measures indexed as ``signed_measures[data][degree]``.

    Each measure ``(points, weights)`` is stored in the ``.npz`` archive at
    ``path`` as a single array whose last column holds the weights. The key is
    ``"{data}_{degree}"``.
    """
    arrays = {}
    for data_idx, per_degree in enumerate(signed_measures):
        for degree_idx, measure in enumerate(per_degree):
            points, weights = measure[0], measure[1]
            arrays[f"{data_idx}_{degree_idx}"] = np.c_[points, weights[:, np.newaxis]]
    np.savez(path, **arrays)
12
+
13
def get_sm_with_axis(sms, idx, axis, degree):
    """Return the ``(points, weights)`` measure stored under ``"{idx}_{axis}_{degree}"``.

    The stored array holds the points in all columns but the last, which holds
    the weights.
    """
    stacked = sms[f"{idx}_{axis}_{degree}"]
    points = stacked[:, :-1]
    weights = stacked[:, -1]
    return points, weights
16
def get_sm_without_axis(sms, idx, degree):
    """Return the ``(points, weights)`` measure stored under ``"{idx}_{degree}"``.

    The stored array holds the points in all columns but the last, which holds
    the weights.
    """
    stacked = sms[f"{idx}_{degree}"]
    points = stacked[:, :-1]
    weights = stacked[:, -1]
    return points, weights
19
+
20
+
21
def load_without_axis(sms):
    """Rebuild signed measures indexed as ``[data][degree]`` from an archive.

    ``sms`` is a mapping, e.g. the ``NpzFile`` returned by ``np.load``, with
    keys ``"{data}_{degree}"``. Each value stores the points with the weights
    in its last column. The number of data and degrees is inferred from the
    largest indices found in the keys.
    """
    index_pairs = np.array(
        [list(map(int, key.split("_"))) for key in sms.keys()], dtype=int
    )
    num_data, num_degrees = index_pairs.max(axis=0) + 1
    reconstructed = []
    for idx in range(num_data):
        per_degree = []
        for degree in range(num_degrees):
            stacked = sms[f"{idx}_{degree}"]
            per_degree.append((stacked[:, :-1], stacked[:, -1]))
        reconstructed.append(per_degree)
    return reconstructed
27
+
28
def load_with_axis(sms):
    """Rebuild signed measures indexed as ``[data][axis][degree]`` from an archive.

    ``sms`` is a mapping, e.g. the ``NpzFile`` returned by ``np.load``, with
    keys ``"{data}_{axis}_{degree}"``. Each value stores the points with the
    weights in its last column. The three sizes are inferred from the largest
    indices found in the keys.
    """
    index_triples = np.array(
        [list(map(int, key.split("_"))) for key in sms.keys()], dtype=int
    )
    num_data, num_axis, num_degrees = index_triples.max(axis=0) + 1
    reconstructed = []
    for idx in range(num_data):
        per_axis = []
        for axis in range(num_axis):
            per_degree = []
            for degree in range(num_degrees):
                stacked = sms[f"{idx}_{axis}_{degree}"]
                per_degree.append((stacked[:, :-1], stacked[:, -1]))
            per_axis.append(per_degree)
        reconstructed.append(per_axis)
    return reconstructed
33
+
34
def save(path: str, signed_measures):
    """Save signed measures to the ``.npz`` archive at ``path``.

    If ``signed_measures[0][0]`` is a tuple, the input is treated as
    ``[data][degree] -> (points, weights)``. Otherwise it is treated as
    ``[data][axis][degree] -> (points, weights)``.
    """
    has_axis_level = not isinstance(signed_measures[0][0], tuple)
    writer = save_with_axis if has_axis_level else save_without_axis
    writer(path=path, signed_measures=signed_measures)
39
+
40
+ def load(path:str):
41
+ sms = np.load(path)
42
+ item=None
43
+ for i in sms.keys():
44
+ item=i
45
+ break
46
+ n = len(item.split('_'))
47
+ match n:
48
+ case 2:
49
+ return load_without_axis(sms)
50
+ case 3:
51
+ return load_with_axis(sms)
52
+ case _:
53
+ raise Exception("Invalid Signed Measure !")
multipers/plots.py ADDED
@@ -0,0 +1,326 @@
1
+ from typing import Optional
2
+
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+
6
+
7
def _plot_rectangle(rectangle: np.ndarray, weight, **plt_kwargs):
    """Draw the segment from ``(r[0], r[1])`` to ``(r[2], r[3])`` with ``plt.plot``.

    The segment is blue when ``weight > 0`` and red otherwise. Extra keyword
    arguments are forwarded to ``plt.plot``.
    """
    birth_x, birth_y, death_x, death_y = np.asarray(rectangle)[:4]
    segment_color = "red"
    if weight > 0:
        segment_color = "blue"
    plt.plot([birth_x, death_x], [birth_y, death_y], c=segment_color, **plt_kwargs)
13
+
14
+
15
def _plot_signed_measure_2(
    pts, weights, temp_alpha=0.7, threshold=(np.inf, np.inf), **plt_kwargs
):
    """Scatter-plot a 2-parameter signed point measure.

    Points are clipped from above by ``threshold``. Negative masses are mapped
    into ``[-2, -1)`` and positive masses into ``(1, 2]``, each sign normalised
    by its largest absolute weight. They are then coloured with a
    bordeaux-white-blue colormap whose intermediate stops use ``temp_alpha``.
    ``plt_kwargs`` are forwarded to every ``plt.scatter`` call, including the
    empty ones that only create the legend entries.
    """
    import matplotlib.colors

    clipped_pts = np.clip(pts, a_min=-np.inf, a_max=np.asarray(threshold)[None, :])
    weights = np.asarray(weights)
    color_values = np.array(weights, dtype=float)
    for sign_mask, offset in ((weights < 0, -1), (weights > 0, 1)):
        if np.any(sign_mask):
            color_values[sign_mask] /= np.max(np.abs(weights[sign_mask]))
            color_values[sign_mask] += offset

    bordeaux = np.array([0.70567316, 0.01555616, 0.15023281, 1])
    bleu = np.array([0.2298057, 0.29871797, 0.75368315, 1])
    light_bordeaux = np.array([*bordeaux[:3], temp_alpha])
    light_bleu = np.array([*bleu[:3], temp_alpha])
    cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
        "", [bordeaux, light_bordeaux, "white", light_bleu, bleu]
    )
    plt.scatter(
        clipped_pts[:, 0],
        clipped_pts[:, 1],
        c=color_values,
        cmap=cmap,
        norm=plt.Normalize(-2, 2),
        **plt_kwargs,
    )
    # Empty scatters only provide the legend entries.
    for legend_color, legend_label in (
        (bleu, "positive mass"),
        (bordeaux, "negative mass"),
    ):
        plt.scatter([], [], color=legend_color, label=legend_label, **plt_kwargs)
    plt.legend()
55
+
56
+
57
def _plot_signed_measure_4(
    pts,
    weights,
    x_smoothing: float = 1,
    area_alpha: bool = True,
    threshold=(np.inf, np.inf),
    **plt_kwargs,  # currently ignored
):
    """Plot a signed measure whose points have 4 coordinates ``(x0, y0, x1, y1)``.

    Points are clipped from above by ``threshold`` (applied to both corners).
    Each point with ``x1 > x_smoothing * x0`` is drawn with ``_plot_rectangle``
    from ``(x0, y0)`` to ``(x1 / x_smoothing, y1)`` and coloured by the sign of
    its weight. The other points are reduced to the empty set and skipped.

    Parameters
    ----------
    x_smoothing: float
        Divisor applied to ``x1`` before drawing.
    area_alpha: bool
        If True, the alpha channel is proportional to the (smoothed) rectangle
        area, normalised by the largest drawn area. When that largest area is
        not positive and finite, alpha falls back to 1. Otherwise alpha is 1.
    threshold:
        Upper clipping bound for the two coordinates of each corner.
    """
    pts = np.clip(pts, a_min=-np.inf, a_max=np.array((*threshold, *threshold))[None, :])

    def _smoothed_area(rectangle):
        return (rectangle[2] / x_smoothing - rectangle[0]) * (rectangle[3] - rectangle[1])

    # Rectangles with rectangle[2] <= x_smoothing * rectangle[0] are reduced to
    # the empty set by the smoothing and are not drawn.
    drawn = [
        (rectangle, weight)
        for rectangle, weight in zip(pts, weights)
        if rectangle[2] > x_smoothing * rectangle[0]
    ]
    max_area = max((_smoothed_area(rectangle) for rectangle, _ in drawn), default=0)
    # Without a positive, finite reference area the ratio below would be
    # nan, negative or undefined, and matplotlib rejects such alpha values.
    use_area_alpha = bool(area_alpha and np.isfinite(max_area) and max_area > 0)
    for rectangle, weight in drawn:
        alpha = 1
        if use_area_alpha:
            # make the alpha channel proportional to the rectangle's area
            alpha = float(np.clip(_smoothed_area(rectangle) / max_area, 0, 1))
        _plot_rectangle(
            rectangle=[
                rectangle[0],
                rectangle[1],
                rectangle[2] / x_smoothing,
                rectangle[3],
            ],
            weight=weight,
            alpha=alpha,
        )
104
+
105
+
106
def _to_numpy_array(values):
    """Convert ``values`` (ndarray, torch tensor, or array-like) to a NumPy array.

    Torch tensors are detached and moved to the CPU first, because
    ``np.asarray`` fails on tensors that require grad or live on a GPU.
    """
    import sys

    if isinstance(values, np.ndarray):
        return values
    # Only look for torch if it is already imported: an object cannot be a
    # torch tensor otherwise, and torch stays an optional dependency.
    torch = sys.modules.get("torch")
    if torch is not None and isinstance(values, torch.Tensor):
        return values.detach().cpu().numpy()
    return np.asarray(values)


def plot_signed_measure(signed_measure, threshold=None, ax=None, **plt_kwargs):
    """Plot a signed measure ``(points, weights)`` on ``ax`` (default: current axes).

    Points with 2 coordinates are scatter-plotted. Points with 4 coordinates
    are drawn as rectangles. Points and weights may be NumPy arrays,
    array-likes or torch tensors.

    Parameters
    ----------
    threshold: optional
        Upper clipping bound for the points. Defaults to the coordinate-wise
        maximum of the finite coordinates; with 4 coordinates, both corners are
        pooled. Defaults to ``(inf, inf)`` when there are no points.
    ax: matplotlib axes, optional
        Axes to draw on; it is made the current axes.
    **plt_kwargs:
        Forwarded to the 2- or 4-coordinate plotting helper.
    """
    if ax is None:
        ax = plt.gca()
    else:
        plt.sca(ax)
    pts, weights = signed_measure
    # Convert before any NumPy call: np.asarray on a grad-requiring tensor raises.
    pts = _to_numpy_array(pts)
    weights = _to_numpy_array(weights)
    num_pts = pts.shape[0]
    num_parameters = pts.shape[1]
    if threshold is None:
        if num_pts == 0:
            threshold = (np.inf, np.inf)
        else:
            if num_parameters == 4:
                # Pool birth and death corners to get a 2d threshold.
                pts_ = np.concatenate([pts[:, :2], pts[:, 2:]], axis=0)
            else:
                pts_ = pts
            threshold = np.max(np.ma.masked_invalid(pts_), axis=0)

    assert num_parameters in (2, 4)
    if num_parameters == 2:
        _plot_signed_measure_2(
            pts=pts, weights=weights, threshold=threshold, **plt_kwargs
        )
    else:
        _plot_signed_measure_4(
            pts=pts, weights=weights, threshold=threshold, **plt_kwargs
        )
143
+
144
+
145
def plot_signed_measures(signed_measures, threshold=None, size=4):
    """Plot each signed measure of ``signed_measures`` on its own subplot, side by side.

    Each subplot is ``size`` x ``size`` inches; ``threshold`` is shared by all.
    """
    num_degrees = len(signed_measures)
    # squeeze=False always yields a 2d grid of axes, even for a single subplot.
    _, axes_grid = plt.subplots(
        nrows=1,
        ncols=num_degrees,
        figsize=(num_degrees * size, size),
        squeeze=False,
    )
    for ax, signed_measure in zip(axes_grid[0], signed_measures):
        plot_signed_measure(signed_measure=signed_measure, ax=ax, threshold=threshold)
    plt.tight_layout()
155
+
156
+
157
def plot_surface(
    grid,
    hf,
    fig=None,
    ax=None,
    cmap: Optional[str] = None,
    discrete_surface=False,
    has_negative_values=False,
    **plt_args,
):
    """Plot a 2-parameter surface ``hf`` over ``grid`` with ``pcolormesh``.

    Parameters
    ----------
    grid:
        ``(xs, ys)`` coordinates passed to ``pcolormesh``.
    hf: ndarray
        2d array; a 3d array with a leading axis of size 1 is also accepted.
    fig, ax: optional
        Figure (for the colorbar) and axes to draw on; default to the current ones.
    cmap: str or Colormap, optional
        Defaults to ``"gray_r"`` for discrete surfaces and ``"plasma"`` otherwise.
    discrete_surface: bool
        Use integer colour bins (0..10, or -5..5 with ``has_negative_values``)
        and add a colorbar.
    **plt_args:
        Forwarded to ``pcolormesh``.

    Returns
    -------
    The ``QuadMesh`` created by ``pcolormesh``, in both branches.
    """
    import matplotlib

    if ax is None:
        ax = plt.gca()
    else:
        plt.sca(ax)
    if hf.ndim == 3 and hf.shape[0] == 1:
        hf = hf[0]
    assert hf.ndim == 2, "Can only plot a 2d surface"
    fig = plt.gcf() if fig is None else fig
    if cmap is None:
        cmap = "gray_r" if discrete_surface else "plasma"
    if isinstance(cmap, str):
        # BoundaryNorm below needs a Colormap object, because it reads ``cmap.N``.
        cmap = matplotlib.colormaps[cmap]
    if not discrete_surface:
        return ax.pcolormesh(grid[0], grid[1], hf.T, cmap=cmap, **plt_args)
    if has_negative_values:
        bounds = np.arange(-5, 6, 1, dtype=int)
    else:
        bounds = np.arange(0, 11, 1, dtype=int)
    norm = matplotlib.colors.BoundaryNorm(bounds, cmap.N, extend="max")
    im = ax.pcolormesh(grid[0], grid[1], hf.T, cmap=cmap, norm=norm, **plt_args)
    cbar = fig.colorbar(
        matplotlib.cm.ScalarMappable(cmap=cmap, norm=norm),
        spacing="proportional",
        ax=ax,
    )
    cbar.set_ticks(ticks=bounds, labels=bounds)
    return im
198
+
199
+
200
def plot_surfaces(HF, size=4, **plt_args):
    """Plot each 2d slice ``hf[k]`` of ``HF = (grid, hf)`` on its own subplot.

    ``hf`` must be 3d. Each slice is drawn with ``plot_surface`` over ``grid``;
    ``plt_args`` are forwarded to it.
    """
    grid, hf = HF
    assert (
        hf.ndim == 3
    ), f"Found hf.shape = {hf.shape}, expected ndim = 3 : degree, 2-parameter surface."
    n_surfaces = hf.shape[0]
    # squeeze=False always yields a 2d grid of axes, even for a single subplot.
    fig, axes_grid = plt.subplots(
        nrows=1,
        ncols=n_surfaces,
        figsize=(n_surfaces * size, size),
        squeeze=False,
    )
    for ax, surface in zip(axes_grid[0], hf):
        plot_surface(grid=grid, hf=surface, fig=fig, ax=ax, **plt_args)
    plt.tight_layout()
214
+
215
+
216
def _rectangle(x, y, color, alpha):
    """
    Build a matplotlib patch for the box ``{z | x <= z <= y}`` with ``color`` and ``alpha``.

    ``x`` is the lower-left corner; negative side lengths are clamped to 0.
    """
    from matplotlib.patches import Rectangle as RectanglePatch

    width = max(y[0] - x[0], 0)
    height = max(y[1] - x[1], 0)
    return RectanglePatch(x, width, height, color=color, alpha=alpha)
225
+
226
+
227
def _d_inf(a, b):
    """Return the smallest coordinate-wise gap ``min_i |b_i - a_i|``.

    Despite its name, this is a minimum over coordinates, not the L-infinity
    distance. ``plot2d_PyModule`` compares it with ``min_persistence``.
    """
    gaps = np.abs(np.asarray(b) - np.asarray(a))
    return np.min(gaps)
232
+
233
+
234
def plot2d_PyModule(
    corners,
    box,
    *,
    dimension=-1,
    separated=False,
    min_persistence=0,
    alpha=1,
    verbose=False,
    save=False,
    dpi=200,
    shapely=True,
    xlabel=None,
    ylabel=None,
    cmap=None,
):
    """Plot the summands of a 2-parameter interval module approximation.

    Parameters
    ----------
    corners:
        Sequence of summands. ``corners[i][0]`` are the birth corners and
        ``corners[i][1]`` the death corners of summand ``i``. A corner of
        length 1 is duplicated into a 2d point. Births are clipped from below
        by ``box[0]`` and deaths from above by ``box[1]``.
    box:
        ``[lower_corner, upper_corner]``; also sets the axes limits.
    dimension: int
        If non-negative, the title is ``H_dimension 2-persistence``.
    separated: bool
        If True, each drawn summand gets its own figure.
    min_persistence: float
        A summand is drawn only if some birth/death pair spans a rectangle
        whose smallest side exceeds this value.
    alpha: float
        Opacity of the summands.
    verbose, save, dpi:
        Unused.
    shapely: bool
        If True and shapely is importable, each summand's rectangles are merged
        with ``shapely.union_all`` and filled as polygons. Otherwise each
        rectangle is added as a matplotlib patch.
    xlabel, ylabel: str, optional
        Axes labels.
    cmap: str, optional
        Name of a matplotlib colormap (default ``"Spectral"``). Summand ``i`` is
        coloured with ``cmap(i / n_summands)``.
    """
    import matplotlib

    if shapely:
        # Only try shapely when requested, so callers who opt out do not get
        # a spurious "not installed" warning.
        try:
            from shapely import union_all
            from shapely.geometry import Polygon as _Polygon
            from shapely.geometry import box as _rectangle_box
        except ImportError:
            from warnings import warn

            shapely = False
            warn(
                "Shapely not installed. Fallbacking to matplotlib. The plots may be inacurate."
            )
    cmap = (
        matplotlib.colormaps["Spectral"] if cmap is None else matplotlib.colormaps[cmap]
    )
    box = list(box)
    if not (separated):
        ax = plt.gca()
        ax.set(xlim=[box[0][0], box[1][0]], ylim=[box[0][1], box[1][1]])
    n_summands = len(corners)
    for i in range(n_summands):
        trivial_summand = True
        list_of_rect = []
        for birth in corners[i][0]:
            if len(birth) == 1:
                birth = np.asarray([birth[0]] * 2)
            birth = np.asarray(birth).clip(min=box[0])
            for death in corners[i][1]:
                if len(death) == 1:
                    death = np.asarray([death[0]] * 2)
                death = np.asarray(death).clip(max=box[1])
                # Only keep birth/death pairs spanning a non-degenerate rectangle.
                if death[1] > birth[1] and death[0] > birth[0]:
                    if trivial_summand and _d_inf(birth, death) > min_persistence:
                        trivial_summand = False
                    if shapely:
                        list_of_rect.append(
                            _rectangle_box(birth[0], birth[1], death[0], death[1])
                        )
                    else:
                        list_of_rect.append(
                            _rectangle(birth, death, cmap(i / n_summands), alpha)
                        )
        if not (trivial_summand):
            if separated:
                fig, ax = plt.subplots()
                ax.set(xlim=[box[0][0], box[1][0]], ylim=[box[0][1], box[1][1]])
            if shapely:
                summand_shape = union_all(list_of_rect)
                if type(summand_shape) is _Polygon:
                    xs, ys = summand_shape.exterior.xy
                    ax.fill(xs, ys, alpha=alpha, fc=cmap(i / n_summands), ec="None")
                else:
                    for polygon in summand_shape.geoms:
                        xs, ys = polygon.exterior.xy
                        ax.fill(xs, ys, alpha=alpha, fc=cmap(i / n_summands), ec="None")
            else:
                for rectangle in list_of_rect:
                    ax.add_patch(rectangle)
            if separated:
                if xlabel:
                    plt.xlabel(xlabel)
                if ylabel:
                    plt.ylabel(ylabel)
                if dimension >= 0:
                    plt.title(rf"$H_{dimension}$ $2$-persistence")
    if not (separated):
        if xlabel is not None:
            plt.xlabel(xlabel)
        if ylabel is not None:
            plt.ylabel(ylabel)
        if dimension >= 0:
            plt.title(rf"$H_{dimension}$ $2$-persistence")
    return