multipers 2.3.3__cp312-cp312-manylinux_2_39_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic. Click here for more details.

Files changed (182) hide show
  1. multipers/__init__.py +33 -0
  2. multipers/_signed_measure_meta.py +450 -0
  3. multipers/_slicer_meta.py +211 -0
  4. multipers/array_api/__init__.py +62 -0
  5. multipers/array_api/numpy.py +104 -0
  6. multipers/array_api/torch.py +117 -0
  7. multipers/data/MOL2.py +458 -0
  8. multipers/data/UCR.py +18 -0
  9. multipers/data/__init__.py +1 -0
  10. multipers/data/graphs.py +466 -0
  11. multipers/data/immuno_regions.py +27 -0
  12. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  13. multipers/data/pytorch2simplextree.py +91 -0
  14. multipers/data/shape3d.py +101 -0
  15. multipers/data/synthetic.py +113 -0
  16. multipers/distances.py +202 -0
  17. multipers/filtration_conversions.pxd +229 -0
  18. multipers/filtration_conversions.pxd.tp +84 -0
  19. multipers/filtrations/__init__.py +18 -0
  20. multipers/filtrations/density.py +533 -0
  21. multipers/filtrations/filtrations.py +361 -0
  22. multipers/filtrations.pxd +224 -0
  23. multipers/function_rips.cpython-312-x86_64-linux-gnu.so +0 -0
  24. multipers/function_rips.pyx +105 -0
  25. multipers/grids.cpython-312-x86_64-linux-gnu.so +0 -0
  26. multipers/grids.pyx +481 -0
  27. multipers/gudhi/Persistence_slices_interface.h +132 -0
  28. multipers/gudhi/Simplex_tree_interface.h +239 -0
  29. multipers/gudhi/Simplex_tree_multi_interface.h +551 -0
  30. multipers/gudhi/cubical_to_boundary.h +59 -0
  31. multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -0
  32. multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -0
  33. multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -0
  34. multipers/gudhi/gudhi/Debug_utils.h +45 -0
  35. multipers/gudhi/gudhi/Fields/Multi_field.h +484 -0
  36. multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -0
  37. multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -0
  38. multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -0
  39. multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -0
  40. multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -0
  41. multipers/gudhi/gudhi/Fields/Z2_field.h +355 -0
  42. multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -0
  43. multipers/gudhi/gudhi/Fields/Zp_field.h +420 -0
  44. multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -0
  45. multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -0
  46. multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -0
  47. multipers/gudhi/gudhi/Matrix.h +2107 -0
  48. multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -0
  49. multipers/gudhi/gudhi/Multi_persistence/Box.h +174 -0
  50. multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -0
  51. multipers/gudhi/gudhi/Off_reader.h +173 -0
  52. multipers/gudhi/gudhi/One_critical_filtration.h +1441 -0
  53. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -0
  54. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -0
  55. multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -0
  56. multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -0
  57. multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -0
  58. multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -0
  59. multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -0
  60. multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -0
  61. multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -0
  62. multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -0
  63. multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -0
  64. multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -0
  65. multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -0
  66. multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -0
  67. multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -0
  68. multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -0
  69. multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -0
  70. multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -0
  71. multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -0
  72. multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -0
  73. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -0
  74. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -0
  75. multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -0
  76. multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -0
  77. multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -0
  78. multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -0
  79. multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -0
  80. multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -0
  81. multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -0
  82. multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -0
  83. multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -0
  84. multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -0
  85. multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -0
  86. multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -0
  87. multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -0
  88. multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -0
  89. multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -0
  90. multipers/gudhi/gudhi/Persistent_cohomology.h +745 -0
  91. multipers/gudhi/gudhi/Points_off_io.h +171 -0
  92. multipers/gudhi/gudhi/Simple_object_pool.h +69 -0
  93. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -0
  94. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -0
  95. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -0
  96. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -0
  97. multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -0
  98. multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -0
  99. multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -0
  100. multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -0
  101. multipers/gudhi/gudhi/Simplex_tree.h +2794 -0
  102. multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -0
  103. multipers/gudhi/gudhi/distance_functions.h +62 -0
  104. multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -0
  105. multipers/gudhi/gudhi/persistence_interval.h +253 -0
  106. multipers/gudhi/gudhi/persistence_matrix_options.h +170 -0
  107. multipers/gudhi/gudhi/reader_utils.h +367 -0
  108. multipers/gudhi/mma_interface_coh.h +256 -0
  109. multipers/gudhi/mma_interface_h0.h +223 -0
  110. multipers/gudhi/mma_interface_matrix.h +293 -0
  111. multipers/gudhi/naive_merge_tree.h +536 -0
  112. multipers/gudhi/scc_io.h +310 -0
  113. multipers/gudhi/truc.h +1403 -0
  114. multipers/io.cpython-312-x86_64-linux-gnu.so +0 -0
  115. multipers/io.pyx +644 -0
  116. multipers/ml/__init__.py +0 -0
  117. multipers/ml/accuracies.py +90 -0
  118. multipers/ml/invariants_with_persistable.py +79 -0
  119. multipers/ml/kernels.py +176 -0
  120. multipers/ml/mma.py +713 -0
  121. multipers/ml/one.py +472 -0
  122. multipers/ml/point_clouds.py +352 -0
  123. multipers/ml/signed_measures.py +1667 -0
  124. multipers/ml/sliced_wasserstein.py +461 -0
  125. multipers/ml/tools.py +113 -0
  126. multipers/mma_structures.cpython-312-x86_64-linux-gnu.so +0 -0
  127. multipers/mma_structures.pxd +128 -0
  128. multipers/mma_structures.pyx +2786 -0
  129. multipers/mma_structures.pyx.tp +1094 -0
  130. multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -0
  131. multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -0
  132. multipers/multi_parameter_rank_invariant/function_rips.h +322 -0
  133. multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -0
  134. multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -0
  135. multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -0
  136. multipers/multiparameter_edge_collapse.py +41 -0
  137. multipers/multiparameter_module_approximation/approximation.h +2330 -0
  138. multipers/multiparameter_module_approximation/combinatory.h +129 -0
  139. multipers/multiparameter_module_approximation/debug.h +107 -0
  140. multipers/multiparameter_module_approximation/euler_curves.h +0 -0
  141. multipers/multiparameter_module_approximation/format_python-cpp.h +286 -0
  142. multipers/multiparameter_module_approximation/heap_column.h +238 -0
  143. multipers/multiparameter_module_approximation/images.h +79 -0
  144. multipers/multiparameter_module_approximation/list_column.h +174 -0
  145. multipers/multiparameter_module_approximation/list_column_2.h +232 -0
  146. multipers/multiparameter_module_approximation/ru_matrix.h +347 -0
  147. multipers/multiparameter_module_approximation/set_column.h +135 -0
  148. multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -0
  149. multipers/multiparameter_module_approximation/unordered_set_column.h +166 -0
  150. multipers/multiparameter_module_approximation/utilities.h +403 -0
  151. multipers/multiparameter_module_approximation/vector_column.h +223 -0
  152. multipers/multiparameter_module_approximation/vector_matrix.h +331 -0
  153. multipers/multiparameter_module_approximation/vineyards.h +464 -0
  154. multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -0
  155. multipers/multiparameter_module_approximation.cpython-312-x86_64-linux-gnu.so +0 -0
  156. multipers/multiparameter_module_approximation.pyx +235 -0
  157. multipers/pickle.py +90 -0
  158. multipers/plots.py +470 -0
  159. multipers/point_measure.cpython-312-x86_64-linux-gnu.so +0 -0
  160. multipers/point_measure.pyx +395 -0
  161. multipers/simplex_tree_multi.cpython-312-x86_64-linux-gnu.so +0 -0
  162. multipers/simplex_tree_multi.pxd +134 -0
  163. multipers/simplex_tree_multi.pyx +10980 -0
  164. multipers/simplex_tree_multi.pyx.tp +2007 -0
  165. multipers/slicer.cpython-312-x86_64-linux-gnu.so +0 -0
  166. multipers/slicer.pxd +3034 -0
  167. multipers/slicer.pxd.tp +234 -0
  168. multipers/slicer.pyx +20481 -0
  169. multipers/slicer.pyx.tp +1088 -0
  170. multipers/tensor/tensor.h +672 -0
  171. multipers/tensor.pxd +13 -0
  172. multipers/test.pyx +44 -0
  173. multipers/tests/__init__.py +62 -0
  174. multipers/torch/__init__.py +1 -0
  175. multipers/torch/diff_grids.py +240 -0
  176. multipers/torch/rips_density.py +310 -0
  177. multipers-2.3.3.dist-info/METADATA +128 -0
  178. multipers-2.3.3.dist-info/RECORD +182 -0
  179. multipers-2.3.3.dist-info/WHEEL +5 -0
  180. multipers-2.3.3.dist-info/licenses/LICENSE +21 -0
  181. multipers-2.3.3.dist-info/top_level.txt +1 -0
  182. multipers.libs/libtbb-ca48af5c.so.12.16 +0 -0
multipers/plots.py ADDED
@@ -0,0 +1,470 @@
1
+ from typing import Optional
2
+ from warnings import warn
3
+
4
+ import matplotlib.colors as mcolors
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from matplotlib.colors import ListedColormap
8
+ from numpy.typing import ArrayLike
9
+
10
+ from multipers.array_api import to_numpy
11
+
12
# Four-step blue palette, ordered from darkest to lightest.
_custom_colors = ["#03045e", "#0077b6", "#00b4d8", "#90e0ef"]
# Discrete colormap holding exactly the palette entries.
_cmap_ = ListedColormap(_custom_colors)
# 256-level interpolation of the palette; used by ``plot_surface`` when no
# colormap is supplied for a continuous surface.
_cmap = mcolors.LinearSegmentedColormap.from_list(
    "continuous_cmap",
    _cmap_.colors,
    N=256,
)
22
+
23
+
24
def _plot_rectangle(rectangle: np.ndarray, weight, **plt_kwargs):
    """
    Draws one signed rectangle on the current axes as a straight segment.

    ``rectangle`` is read as four values ``(r0, r1, r2, r3)``; the segment
    joins ``(r0, r1)`` to ``(r2, r3)``. It is drawn in blue when ``weight`` is
    strictly positive and in red otherwise. Extra keyword arguments are
    forwarded to ``plt.plot``.
    """
    corners = np.asarray(rectangle)
    xs, ys = corners[[0, 2]], corners[[1, 3]]
    if weight > 0:
        segment_color = "blue"
    else:
        segment_color = "red"
    plt.plot(xs, ys, c=segment_color, **plt_kwargs)
30
+
31
+
32
def _plot_signed_measure_2(
    pts, weights, temp_alpha=0.7, threshold=(np.inf, np.inf), **plt_kwargs
):
    """
    Scatter-plots a signed point measure with two coordinates per point.

    Points are clipped from above by ``threshold``. Negative weights are
    rescaled into [-2, -1) and positive weights into (1, 2], each relative to
    the largest magnitude of its sign; zero weights stay at 0. The values are
    then coloured with a red / white / blue colormap normalised on [-2, 2],
    and a "positive mass" / "negative mass" legend is added.
    ``temp_alpha`` is the alpha channel of the two intermediate colours.
    """
    import matplotlib.colors

    pts = np.clip(pts, a_min=-np.inf, a_max=np.asarray(threshold)[None, :])
    weights = np.asarray(weights)
    color_weights = np.array(weights, dtype=float)

    negative = weights < 0
    positive = weights > 0
    if np.any(negative):
        largest_negative = np.max(-weights[negative])
        color_weights[negative] = color_weights[negative] / largest_negative - 1
    if np.any(positive):
        largest_positive = np.max(weights[positive])
        color_weights[positive] = color_weights[positive] / largest_positive + 1

    red_rgb = [0.70567316, 0.01555616, 0.15023281]
    blue_rgb = [0.2298057, 0.29871797, 0.75368315]
    bordeaux = np.array([*red_rgb, 1])
    light_bordeaux = np.array([*red_rgb, temp_alpha])
    bleu = np.array([*blue_rgb, 1])
    light_bleu = np.array([*blue_rgb, temp_alpha])

    gradient = [bordeaux, light_bordeaux, "white", light_bleu, bleu]
    cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", gradient)
    norm = plt.Normalize(-2, 2)
    plt.scatter(
        pts[:, 0], pts[:, 1], c=color_weights, cmap=cmap, norm=norm, **plt_kwargs
    )
    # Empty scatters only exist to create the two legend entries.
    for legend_color, legend_label in (
        (bleu, "positive mass"),
        (bordeaux, "negative mass"),
    ):
        plt.scatter([], [], color=legend_color, label=legend_label, **plt_kwargs)
    plt.legend()
72
+
73
+
74
def _plot_signed_measure_4(
    pts,
    weights,
    x_smoothing: float = 1,
    area_alpha: bool = True,
    threshold=(np.inf, np.inf),
    alpha=None,
    **plt_kwargs,
):
    """
    Plots a signed measure whose points have four coordinates (rectangles).

    Each row of ``pts`` is read as ``(r0, r1, r2, r3)`` and clipped from above
    by ``threshold`` (applied to both coordinate pairs). A rectangle is kept
    only if ``r2 >= x_smoothing * r0``; it is then drawn by ``_plot_rectangle``
    with its third coordinate divided by ``x_smoothing``.

    Opacity:
      * if ``alpha`` is given, it is used for every rectangle;
      * else if ``area_alpha`` is true, the opacity is the rectangle's area
        divided by the largest area among the kept rectangles, clamped to
        [0, 1];
      * otherwise the opacity is 1.

    When no kept rectangle has a positive area the ratio is undefined, so
    the rectangles are drawn fully opaque instead of dividing by zero.

    Remaining keyword arguments are forwarded to ``_plot_rectangle``.
    """
    pts = np.clip(pts, a_min=-np.inf,
                  a_max=np.array((*threshold, *threshold))[None, :])

    # Collect the rectangles that have not been reduced to the empty set,
    # together with their (smoothed) area.
    visible = []
    for rectangle, weight in zip(pts, weights):
        if rectangle[2] < x_smoothing * rectangle[0]:
            continue
        right = rectangle[2] / x_smoothing
        area = (right - rectangle[0]) * (rectangle[3] - rectangle[1])
        corners = [rectangle[0], rectangle[1], right, rectangle[3]]
        visible.append((corners, weight, area))

    # Largest area, never below 0 (an empty list also gives 0).
    max_area = max([0] + [area for _, _, area in visible])

    for corners, weight, area in visible:
        if alpha is not None:
            rectangle_alpha = alpha
        elif area_alpha and max_area > 0:
            # Clamp: an inverted rectangle has a negative area, and
            # matplotlib only accepts alpha values within [0, 1].
            rectangle_alpha = float(np.clip(area / max_area, 0.0, 1.0))
        else:
            rectangle_alpha = 1
        _plot_rectangle(
            rectangle=corners,
            weight=weight,
            alpha=rectangle_alpha,
            **plt_kwargs,
        )
129
+
130
+
131
def plot_signed_measure(signed_measure, threshold=None, ax=None, **plt_kwargs):
    """
    Plots one signed measure, given as a ``(points, weights)`` pair.

    The drawing happens on ``ax``, which becomes the current axes; when ``ax``
    is None the current axes are used. Points with 2 columns are drawn as a
    coloured scatter, points with 4 columns as rectangles.

    When ``threshold`` is None it defaults to ``(inf, inf)`` for an empty
    measure. Otherwise it is the coordinate-wise maximum of the finite point
    coordinates (both corners for 4-column points) and the upper x / y limits
    of the current axes. Extra keyword arguments go to the underlying plotter.
    """
    if ax is not None:
        plt.sca(ax)
    else:
        ax = plt.gca()

    pts, weights = signed_measure
    pts, weights = to_numpy(pts), to_numpy(weights)
    num_pts = pts.shape[0]
    num_parameters = pts.shape[1]

    if threshold is None and num_pts == 0:
        threshold = (np.inf, np.inf)
    elif threshold is None:
        if num_parameters == 4:
            # Stack both corners so the maximum runs over all of them.
            corner_pts = np.concatenate([pts[:, :2], pts[:, 2:]], axis=0)
        else:
            corner_pts = pts
        data_max = np.max(np.ma.masked_invalid(corner_pts), axis=0)
        current_axes = plt.gca()
        axes_max = [current_axes.get_xlim()[1], current_axes.get_ylim()[1]]
        threshold = np.max([data_max, axes_max], axis=0)

    assert num_parameters in (2, 4)
    plotter = _plot_signed_measure_2 if num_parameters == 2 else _plot_signed_measure_4
    plotter(pts=pts, weights=weights, threshold=threshold, **plt_kwargs)
163
+
164
+
165
def plot_signed_measures(signed_measures, threshold=None, size=4):
    """
    Plots several signed measures side by side, one per axes.

    With at most one measure the current axes are reused; otherwise a single
    row of ``len(signed_measures)`` subplots is created, each of side ``size``.
    ``threshold`` is passed unchanged to ``plot_signed_measure``.
    """
    num_degrees = len(signed_measures)
    if num_degrees > 1:
        _, axes = plt.subplots(
            nrows=1, ncols=num_degrees, figsize=(num_degrees * size, size)
        )
    else:
        axes = [plt.gca()]
    for axis, measure in zip(axes, signed_measures):
        plot_signed_measure(signed_measure=measure, ax=axis, threshold=threshold)
    plt.tight_layout()
177
+
178
+
179
def plot_surface(
    grid,
    hf,
    fig=None,
    ax=None,
    cmap: Optional[str] = None,
    discrete_surface: bool = False,
    has_negative_values: bool = False,
    contour: bool = True,
    **plt_args,
):
    """
    Plots a 2-parameter surface ``hf`` over the coordinates in ``grid``.

    Parameters:
      grid: pair of 1d coordinate arrays (x then y).
      hf: 2d array, or a 3d array whose first axis has length 1. It is
        transposed before plotting.
      fig: figure receiving the colorbar in the discrete case; defaults to the
        current figure.
      ax: axes to draw on (made current); defaults to the current axes.
      cmap: colormap, or the name of a registered matplotlib colormap. Defaults
        to "gray_r" for discrete surfaces and to the module's blue colormap
        otherwise.
      discrete_surface: draw integer-valued cells with a ``BoundaryNorm`` and a
        colorbar, with bounds 0..10, or -5..5 if ``has_negative_values``.
      contour: for non-discrete surfaces, use ``contourf`` (``levels`` defaults
        to 50) instead of ``pcolormesh``.
      plt_args: forwarded to the matplotlib drawing call.

    Returns the artist created by ``pcolormesh`` / ``contourf``.
    """
    import matplotlib

    if ax is None:
        ax = plt.gca()
    else:
        plt.sca(ax)
    if hf.ndim == 3 and hf.shape[0] == 1:
        hf = hf[0]
    assert hf.ndim == 2, "Can only plot a 2d surface"
    fig = plt.gcf() if fig is None else fig
    if cmap is None:
        cmap = matplotlib.colormaps["gray_r"] if discrete_surface else _cmap

    if discrete_surface or not contour:
        # shading="flat" needs one more grid edge than cells: extrapolate each
        # axis by 10% of its range past its last coordinate.
        grid = [np.concatenate([g, [g[-1] * 1.1 - 0.1 * g[0]]]) for g in grid]

    if discrete_surface:
        if isinstance(cmap, str):
            # BoundaryNorm needs the number of colours, which a bare colormap
            # name does not carry: resolve it to a Colormap object first.
            cmap = matplotlib.colormaps[cmap]
        if has_negative_values:
            bounds = np.arange(-5, 6, 1, dtype=int)
        else:
            bounds = np.arange(0, 11, 1, dtype=int)
        norm = matplotlib.colors.BoundaryNorm(bounds, cmap.N, extend="max")
        im = ax.pcolormesh(grid[0], grid[1], hf.T, cmap=cmap,
                           norm=norm, shading="flat", **plt_args)
        cbar = fig.colorbar(
            matplotlib.cm.ScalarMappable(cmap=cmap, norm=norm),
            spacing="proportional",
            ax=ax,
        )
        cbar.set_ticks(ticks=bounds, labels=bounds)
        return im

    if contour:
        levels = plt_args.pop("levels", 50)
        im = ax.contourf(grid[0], grid[1], hf.T,
                         cmap=cmap, levels=levels, **plt_args)
    else:
        im = ax.pcolormesh(grid[0], grid[1], hf.T,
                           cmap=cmap, shading="flat", **plt_args)
    return im
232
+
233
+
234
def plot_surfaces(HF, size=4, **plt_args):
    """
    Plots a stack of 2-parameter surfaces in a single row of subplots.

    ``HF`` is a ``(grid, hf)`` pair where ``hf`` is 3-dimensional; each slice
    along its first axis is drawn with ``plot_surface`` on its own axes of
    side ``size``. Extra keyword arguments are forwarded to ``plot_surface``.
    """
    grid, hf = HF
    assert (
        hf.ndim == 3
    ), f"Found hf.shape = {hf.shape}, expected ndim = 3 : degree, 2-parameter surface."
    num_degrees = hf.shape[0]
    fig, axes = plt.subplots(
        nrows=1, ncols=num_degrees, figsize=(num_degrees * size, size)
    )
    # With a single column, plt.subplots returns a bare Axes, not a sequence.
    axes_row = [axes] if num_degrees == 1 else axes
    for axis, surface in zip(axes_row, hf):
        plot_surface(grid=grid, hf=surface, fig=fig, ax=axis, **plt_args)
    plt.tight_layout()
248
+
249
+
250
def _rectangle(x, y, color, alpha):
    """
    Builds the matplotlib patch of the rectangle {z | x ≤ z ≤ y}.

    ``x`` is the anchor corner; the width and height are the coordinate
    differences ``y - x``, floored at 0. ``color`` and ``alpha`` are passed to
    the patch.
    """
    from matplotlib.patches import Rectangle as RectanglePatch

    width = max(y[0] - x[0], 0)
    height = max(y[1] - x[1], 0)
    return RectanglePatch(x, width, height, color=color, alpha=alpha)
259
+
260
+
261
+ def _d_inf(a, b):
262
+ a = np.asarray(a)
263
+ b = np.asarray(b)
264
+ return np.min(np.abs(b - a))
265
+
266
+
267
def plot2d_PyModule(
    corners,
    box,
    *,
    dimension=-1,
    separated=False,
    min_persistence=0,
    alpha=None,
    verbose=False,
    save=False,
    dpi=200,
    shapely=True,
    xlabel=None,
    ylabel=None,
    cmap=None,
):
    """
    Plots the summands of a 2-parameter module as unions of rectangles.

    Parameters:
      corners: sequence of summands; ``corners[i][0]`` is iterated as the
        births and ``corners[i][1]`` as the deaths of summand ``i``. A birth or
        death of length 1 is duplicated into a 2d point.
      box: pair ``(lower, upper)`` of 2d points; sets the axes limits and clips
        births from below and deaths from above.
      dimension: if >= 0, used in the plot title.
      separated: if true, each non-trivial summand gets its own new figure;
        otherwise everything is drawn on the current axes.
      min_persistence: a summand is drawn only if at least one of its
        (birth, death) rectangles satisfies ``_d_inf(birth, death) >
        min_persistence``.
      alpha: opacity; defaults to 0.8 with shapely and 1 without.
      verbose, save, dpi: accepted but not used in this function.
      shapely: whether to merge rectangles with shapely (if importable).
      xlabel, ylabel: optional axis labels.
      cmap: name of a registered matplotlib colormap; defaults to "Spectral".

    Returns None.
    """
    import matplotlib

    # shapely is optional: without it, rectangles are drawn one by one as
    # matplotlib patches instead of being merged into a single shape.
    try:
        from shapely import union_all
        from shapely.geometry import Polygon as _Polygon
        from shapely.geometry import box as _rectangle_box

        shapely = True and shapely
    except ImportError:
        shapely = False
        warn(
            "Shapely not installed. Fallbacking to matplotlib. The plots may be inacurate."
        )
    if alpha is None:
        alpha = 0.8 if shapely else 1
    if not shapely and alpha != 1:
        warn("Opacity without shapely will lead to incorect plots.")
    cmap = (
        matplotlib.colormaps["Spectral"] if cmap is None else matplotlib.colormaps[cmap]
    )
    box = list(box)
    if not (separated):
        # fig, ax = plt.subplots()
        ax = plt.gca()
        ax.set(xlim=[box[0][0], box[1][0]], ylim=[box[0][1], box[1][1]])
    n_summands = len(corners)
    for i in range(n_summands):
        # A summand stays "trivial" until one rectangle exceeds min_persistence.
        trivial_summand = True
        list_of_rect = []
        for birth in corners[i][0]:
            if len(birth) == 1:
                birth = np.asarray([birth[0]] * 2)
            birth = np.asarray(birth).clip(min=box[0])
            for death in corners[i][1]:
                if len(death) == 1:
                    death = np.asarray([death[0]] * 2)
                death = np.asarray(death).clip(max=box[1])
                # Keep only rectangles with strictly positive width and height.
                if death[1] > birth[1] and death[0] > birth[0]:
                    if trivial_summand and _d_inf(birth, death) > min_persistence:
                        trivial_summand = False
                    if shapely:
                        list_of_rect.append(
                            _rectangle_box(
                                birth[0], birth[1], death[0], death[1])
                        )
                    else:
                        list_of_rect.append(
                            _rectangle(birth, death, cmap(
                                i / n_summands), alpha)
                        )
        if not (trivial_summand):
            if separated:
                fig, ax = plt.subplots()
                ax.set(xlim=[box[0][0], box[1][0]],
                       ylim=[box[0][1], box[1][1]])
            if shapely:
                # Merge the rectangles so overlaps are not drawn twice.
                summand_shape = union_all(list_of_rect)
                if type(summand_shape) is _Polygon:
                    xs, ys = summand_shape.exterior.xy
                    ax.fill(xs, ys, alpha=alpha, fc=cmap(
                        i / n_summands), ec="None")
                else:
                    # Any other result is iterated through its .geoms, one
                    # filled exterior per part.
                    for polygon in summand_shape.geoms:
                        xs, ys = polygon.exterior.xy
                        ax.fill(xs, ys, alpha=alpha, fc=cmap(
                            i / n_summands), ec="None")
            else:
                for rectangle in list_of_rect:
                    ax.add_patch(rectangle)
            if separated:
                # Labels and title are set on the figure just created.
                if xlabel:
                    plt.xlabel(xlabel)
                if ylabel:
                    plt.ylabel(ylabel)
                if dimension >= 0:
                    plt.title(rf"$H_{dimension}$ $2$-persistence")
    if not (separated):
        if xlabel is not None:
            plt.xlabel(xlabel)
        if ylabel is not None:
            plt.ylabel(ylabel)
        if dimension >= 0:
            plt.title(rf"$H_{dimension}$ $2$-persistence")
    return
367
+
368
+
369
def plot_simplicial_complex(
    st, pts: ArrayLike, x: float, y: float, mma=None, degree=None
):
    """
    Scatters the points, with the simplices in the filtration at coordinates (x,y).
    If an mma module is given, plots it in a second axis.

    Parameters:
      st: iterable of ``(simplex, filtration)`` pairs that also provides
        ``get_skeleton(0)``; ``filtration[0]`` is compared to ``x`` and
        ``filtration[1]`` to ``y``.
      pts: 2d coordinates of the vertices, indexed by vertex id.
      x, y: filtration coordinates; a simplex is drawn when
        ``filtration[0] <= x`` and ``filtration[1] <= y``.
      mma: optional module exposing ``plot`` and ``get_box``; when given, the
        complex is drawn on the left axes and the module on the right, with
        the point ``(x, y)`` marked.
      degree: passed to ``mma.plot``.

    Returns the vertex scatter artist, or None when ``mma`` is given.
    """
    if mma is not None:
        fig, (complex_ax, module_ax) = plt.subplots(ncols=2, figsize=(15, 5))
        plt.sca(complex_ax)
        plot_simplicial_complex(st, pts, x, y)
        plt.sca(module_ax)
        mma.plot(degree=degree)
        box = mma.get_box()
        # The box is flattened as (x_min, y_min, x_max, y_max), which is how
        # the vertical / horizontal guide lines below consume it.
        x_min, y_min, x_max, y_max = box.ravel()
        # NOTE(review): this second call always overlays degree 1, whatever
        # `degree` is -- kept as in the original; confirm it is intended.
        mma.plot(degree=1, min_persistence=0.01)
        plt.vlines(x, y_min, y_max, color="k", linestyle="--")
        plt.hlines(y, x_min, x_max, color="k", linestyle="--")
        plt.scatter([x], [y], c="r", zorder=10)
        # Offset the label by 1% of the box width and height.
        plt.text(
            x + 0.01 * (x_max - x_min),
            y + 0.01 * (y_max - y_min),
            f"({x},{y})",
        )
        return

    from matplotlib.pyplot import get_cmap

    pts = np.asarray(pts)
    values = np.array([-f[1] for s, f in st.get_skeleton(0)])
    qs = np.quantile(values, np.linspace(0, 1, 100))
    viridis = get_cmap("viridis")  # looked up once instead of once per simplex

    def color_idx(d):
        return np.searchsorted(qs, d) / 100

    def color(d):
        return viridis([0, color_idx(d), 1])[1]

    cols_pc = np.asarray([color(v) for v in values])
    ax = plt.gca()
    for s, f in st:  # simplex, filtration
        density = -f[1]
        # Skip vertices and simplices not yet in the filtration at (x, y).
        if len(s) <= 1 or f[0] > x or density < -y:
            continue
        xx = np.array([pts[a, 0] for a in s])
        yy = np.array([pts[a, 1] for a in s])
        if len(s) == 2:  # simplex = edge
            plt.plot(xx, yy, c=color(density), alpha=1,
                     zorder=10 * density, lw=1.5)
        if len(s) == 3:  # simplex = triangle
            ax.fill(xx, yy, c=color(density), alpha=0.3, zorder=0)
    out = plt.scatter(pts[:, 0], pts[:, 1], c=cols_pc, zorder=10, s=10)
    ax.set_aspect(1)
    return out
422
+
423
+
424
def plot_point_cloud(
    pts,
    function,
    x,
    y,
    mma=None,
    degree=None,
    ball_alpha=0.3,
    point_cmap="viridis",
    color_bias=1,
    ball_color=None,
    point_size=20,
):
    """
    Scatters a 2d point cloud, with balls of radius ``x`` around the points
    whose function value is at most ``y``. If an mma module is given, plots it
    in a second axis.

    Parameters:
      pts: 2d point coordinates, one row per point.
      function: one value per point; points are coloured by the quantile rank
        of ``-function`` (scaled by ``color_bias``) in ``point_cmap``.
      x: radius of the balls.
      y: threshold on ``function`` selecting which points get a ball.
      mma: optional module exposing ``plot`` and ``get_box``; when given, the
        cloud is drawn on the left axes and the module on the right, with the
        point ``(x, y)`` marked.
      degree: passed to ``mma.plot``.
      ball_alpha, ball_color: opacity and colour of the balls.
      point_size: marker size of the scattered points.

    Returns None.
    """
    if mma is not None:
        fig, (cloud_ax, module_ax) = plt.subplots(ncols=2, figsize=(15, 5))
        plt.sca(cloud_ax)
        # Forward the styling options so they also apply in the two-panel view.
        plot_point_cloud(
            pts,
            function,
            x,
            y,
            ball_alpha=ball_alpha,
            point_cmap=point_cmap,
            color_bias=color_bias,
            ball_color=ball_color,
            point_size=point_size,
        )
        plt.sca(module_ax)
        mma.plot(degree=degree)
        box = mma.get_box()
        # The box is flattened as (x_min, y_min, x_max, y_max), which is how
        # the vertical / horizontal guide lines below consume it.
        x_min, y_min, x_max, y_max = box.ravel()
        # NOTE(review): this second call always overlays degree 1, whatever
        # `degree` is -- kept as in the original; confirm it is intended.
        mma.plot(degree=1, min_persistence=0.01)
        plt.vlines(x, y_min, y_max, color="k", linestyle="--")
        plt.hlines(y, x_min, x_max, color="k", linestyle="--")
        plt.scatter([x], [y], c="r", zorder=10)
        # Offset the label by 1% of the box width and height.
        plt.text(
            x + 0.01 * (x_max - x_min),
            y + 0.01 * (y_max - y_min),
            f"({x},{y})",
        )
        return

    from matplotlib.collections import PatchCollection
    from matplotlib.pyplot import get_cmap

    # Accept plain sequences as well as arrays.
    pts = np.asarray(pts)
    function = np.asarray(function)

    values = -function
    qs = np.quantile(values, np.linspace(0, 1, 100))
    cmap = get_cmap(point_cmap)  # looked up once instead of once per point

    def color_idx(d):
        return np.searchsorted(qs, d * color_bias) / 100

    def color(d):
        return cmap([0, color_idx(d), 1])[1]

    _colors = np.array([color(v) for v in values])
    ax = plt.gca()
    idx = function <= y
    circles = [plt.Circle(pt, x) for pt in pts[idx]]
    pc = PatchCollection(circles, alpha=ball_alpha, color=ball_color)
    ax.add_collection(pc)
    plt.scatter(*pts.T, c=_colors, s=point_size)
    ax.set_aspect(1)