multipers 2.4.0b1__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184) hide show
  1. multipers/.dylibs/libboost_timer.dylib +0 -0
  2. multipers/.dylibs/libc++.1.0.dylib +0 -0
  3. multipers/.dylibs/libtbb.12.17.dylib +0 -0
  4. multipers/__init__.py +33 -0
  5. multipers/_signed_measure_meta.py +426 -0
  6. multipers/_slicer_meta.py +231 -0
  7. multipers/array_api/__init__.py +62 -0
  8. multipers/array_api/numpy.py +124 -0
  9. multipers/array_api/torch.py +133 -0
  10. multipers/data/MOL2.py +458 -0
  11. multipers/data/UCR.py +18 -0
  12. multipers/data/__init__.py +1 -0
  13. multipers/data/graphs.py +466 -0
  14. multipers/data/immuno_regions.py +27 -0
  15. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  16. multipers/data/pytorch2simplextree.py +91 -0
  17. multipers/data/shape3d.py +101 -0
  18. multipers/data/synthetic.py +113 -0
  19. multipers/distances.py +202 -0
  20. multipers/filtration_conversions.pxd +736 -0
  21. multipers/filtration_conversions.pxd.tp +226 -0
  22. multipers/filtrations/__init__.py +21 -0
  23. multipers/filtrations/density.py +529 -0
  24. multipers/filtrations/filtrations.py +480 -0
  25. multipers/filtrations.pxd +534 -0
  26. multipers/filtrations.pxd.tp +332 -0
  27. multipers/function_rips.cpython-312-darwin.so +0 -0
  28. multipers/function_rips.pyx +104 -0
  29. multipers/grids.cpython-312-darwin.so +0 -0
  30. multipers/grids.pyx +538 -0
  31. multipers/gudhi/Persistence_slices_interface.h +213 -0
  32. multipers/gudhi/Simplex_tree_interface.h +274 -0
  33. multipers/gudhi/Simplex_tree_multi_interface.h +648 -0
  34. multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -0
  35. multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -0
  36. multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -0
  37. multipers/gudhi/gudhi/Debug_utils.h +52 -0
  38. multipers/gudhi/gudhi/Degree_rips_bifiltration.h +2307 -0
  39. multipers/gudhi/gudhi/Dynamic_multi_parameter_filtration.h +2524 -0
  40. multipers/gudhi/gudhi/Fields/Multi_field.h +453 -0
  41. multipers/gudhi/gudhi/Fields/Multi_field_operators.h +460 -0
  42. multipers/gudhi/gudhi/Fields/Multi_field_shared.h +444 -0
  43. multipers/gudhi/gudhi/Fields/Multi_field_small.h +584 -0
  44. multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +490 -0
  45. multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +580 -0
  46. multipers/gudhi/gudhi/Fields/Z2_field.h +391 -0
  47. multipers/gudhi/gudhi/Fields/Z2_field_operators.h +389 -0
  48. multipers/gudhi/gudhi/Fields/Zp_field.h +493 -0
  49. multipers/gudhi/gudhi/Fields/Zp_field_operators.h +384 -0
  50. multipers/gudhi/gudhi/Fields/Zp_field_shared.h +492 -0
  51. multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -0
  52. multipers/gudhi/gudhi/Matrix.h +2200 -0
  53. multipers/gudhi/gudhi/Multi_filtration/Multi_parameter_generator.h +1712 -0
  54. multipers/gudhi/gudhi/Multi_filtration/multi_filtration_conversions.h +237 -0
  55. multipers/gudhi/gudhi/Multi_filtration/multi_filtration_utils.h +225 -0
  56. multipers/gudhi/gudhi/Multi_parameter_filtered_complex.h +485 -0
  57. multipers/gudhi/gudhi/Multi_parameter_filtration.h +2643 -0
  58. multipers/gudhi/gudhi/Multi_persistence/Box.h +233 -0
  59. multipers/gudhi/gudhi/Multi_persistence/Line.h +309 -0
  60. multipers/gudhi/gudhi/Multi_persistence/Multi_parameter_filtered_complex_pcoh_interface.h +268 -0
  61. multipers/gudhi/gudhi/Multi_persistence/Persistence_interface_cohomology.h +159 -0
  62. multipers/gudhi/gudhi/Multi_persistence/Persistence_interface_matrix.h +463 -0
  63. multipers/gudhi/gudhi/Multi_persistence/Point.h +853 -0
  64. multipers/gudhi/gudhi/Off_reader.h +173 -0
  65. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +834 -0
  66. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +838 -0
  67. multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +833 -0
  68. multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1367 -0
  69. multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1157 -0
  70. multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +869 -0
  71. multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +905 -0
  72. multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +122 -0
  73. multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +260 -0
  74. multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +288 -0
  75. multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +170 -0
  76. multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +247 -0
  77. multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +571 -0
  78. multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +182 -0
  79. multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +130 -0
  80. multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +235 -0
  81. multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +312 -0
  82. multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1092 -0
  83. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +923 -0
  84. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +914 -0
  85. multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +930 -0
  86. multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1071 -0
  87. multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +203 -0
  88. multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +886 -0
  89. multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +984 -0
  90. multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1213 -0
  91. multipers/gudhi/gudhi/Persistence_matrix/index_mapper.h +58 -0
  92. multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +227 -0
  93. multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +200 -0
  94. multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +166 -0
  95. multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +319 -0
  96. multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +562 -0
  97. multipers/gudhi/gudhi/Persistence_on_a_line.h +152 -0
  98. multipers/gudhi/gudhi/Persistence_on_rectangle.h +617 -0
  99. multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -0
  100. multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -0
  101. multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -0
  102. multipers/gudhi/gudhi/Persistent_cohomology.h +769 -0
  103. multipers/gudhi/gudhi/Points_off_io.h +171 -0
  104. multipers/gudhi/gudhi/Projective_cover_kernel.h +379 -0
  105. multipers/gudhi/gudhi/Simple_object_pool.h +69 -0
  106. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +559 -0
  107. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -0
  108. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +121 -0
  109. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -0
  110. multipers/gudhi/gudhi/Simplex_tree/filtration_value_utils.h +155 -0
  111. multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -0
  112. multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -0
  113. multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +60 -0
  114. multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +105 -0
  115. multipers/gudhi/gudhi/Simplex_tree.h +3170 -0
  116. multipers/gudhi/gudhi/Slicer.h +848 -0
  117. multipers/gudhi/gudhi/Thread_safe_slicer.h +393 -0
  118. multipers/gudhi/gudhi/distance_functions.h +62 -0
  119. multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -0
  120. multipers/gudhi/gudhi/multi_simplex_tree_helpers.h +147 -0
  121. multipers/gudhi/gudhi/persistence_interval.h +263 -0
  122. multipers/gudhi/gudhi/persistence_matrix_options.h +188 -0
  123. multipers/gudhi/gudhi/reader_utils.h +367 -0
  124. multipers/gudhi/gudhi/simple_mdspan.h +484 -0
  125. multipers/gudhi/gudhi/slicer_helpers.h +779 -0
  126. multipers/gudhi/tmp_h0_pers/mma_interface_h0.h +223 -0
  127. multipers/gudhi/tmp_h0_pers/naive_merge_tree.h +536 -0
  128. multipers/io.cpython-312-darwin.so +0 -0
  129. multipers/io.pyx +472 -0
  130. multipers/ml/__init__.py +0 -0
  131. multipers/ml/accuracies.py +90 -0
  132. multipers/ml/invariants_with_persistable.py +79 -0
  133. multipers/ml/kernels.py +176 -0
  134. multipers/ml/mma.py +713 -0
  135. multipers/ml/one.py +472 -0
  136. multipers/ml/point_clouds.py +352 -0
  137. multipers/ml/signed_measures.py +1667 -0
  138. multipers/ml/sliced_wasserstein.py +461 -0
  139. multipers/ml/tools.py +113 -0
  140. multipers/mma_structures.cpython-312-darwin.so +0 -0
  141. multipers/mma_structures.pxd +134 -0
  142. multipers/mma_structures.pyx +1483 -0
  143. multipers/mma_structures.pyx.tp +1126 -0
  144. multipers/multi_parameter_rank_invariant/diff_helpers.h +85 -0
  145. multipers/multi_parameter_rank_invariant/euler_characteristic.h +95 -0
  146. multipers/multi_parameter_rank_invariant/function_rips.h +317 -0
  147. multipers/multi_parameter_rank_invariant/hilbert_function.h +761 -0
  148. multipers/multi_parameter_rank_invariant/persistence_slices.h +149 -0
  149. multipers/multi_parameter_rank_invariant/rank_invariant.h +350 -0
  150. multipers/multiparameter_edge_collapse.py +41 -0
  151. multipers/multiparameter_module_approximation/approximation.h +2541 -0
  152. multipers/multiparameter_module_approximation/debug.h +107 -0
  153. multipers/multiparameter_module_approximation/format_python-cpp.h +292 -0
  154. multipers/multiparameter_module_approximation/utilities.h +428 -0
  155. multipers/multiparameter_module_approximation.cpython-312-darwin.so +0 -0
  156. multipers/multiparameter_module_approximation.pyx +286 -0
  157. multipers/ops.cpython-312-darwin.so +0 -0
  158. multipers/ops.pyx +231 -0
  159. multipers/pickle.py +89 -0
  160. multipers/plots.py +550 -0
  161. multipers/point_measure.cpython-312-darwin.so +0 -0
  162. multipers/point_measure.pyx +409 -0
  163. multipers/simplex_tree_multi.cpython-312-darwin.so +0 -0
  164. multipers/simplex_tree_multi.pxd +136 -0
  165. multipers/simplex_tree_multi.pyx +11719 -0
  166. multipers/simplex_tree_multi.pyx.tp +2102 -0
  167. multipers/slicer.cpython-312-darwin.so +0 -0
  168. multipers/slicer.pxd +2097 -0
  169. multipers/slicer.pxd.tp +263 -0
  170. multipers/slicer.pyx +13042 -0
  171. multipers/slicer.pyx.tp +1259 -0
  172. multipers/tensor/tensor.h +672 -0
  173. multipers/tensor.pxd +13 -0
  174. multipers/test.pyx +44 -0
  175. multipers/tests/__init__.py +70 -0
  176. multipers/torch/__init__.py +1 -0
  177. multipers/torch/diff_grids.py +240 -0
  178. multipers/torch/rips_density.py +310 -0
  179. multipers/vector_interface.pxd +46 -0
  180. multipers-2.4.0b1.dist-info/METADATA +131 -0
  181. multipers-2.4.0b1.dist-info/RECORD +184 -0
  182. multipers-2.4.0b1.dist-info/WHEEL +6 -0
  183. multipers-2.4.0b1.dist-info/licenses/LICENSE +21 -0
  184. multipers-2.4.0b1.dist-info/top_level.txt +1 -0
multipers/plots.py ADDED
@@ -0,0 +1,550 @@
1
+ from typing import Optional
2
+
3
+ import matplotlib.colors as mcolors
4
+ import matplotlib
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from matplotlib.colors import ListedColormap
8
+ from numpy.typing import ArrayLike
9
+
10
+ from multipers.array_api import to_numpy
11
+
12
# Discrete palette, ordered from dark navy to light blue, that seeds the
# module's default surface colormap.
_custom_colors = [
    "#03045e",
    "#0077b6",
    "#00b4d8",
    "#90e0ef",
]
# Listed (4-color) version of the palette.
_cmap_ = ListedColormap(_custom_colors)
# Continuous colormap with 256 levels interpolated between the palette
# colors; plot_surface falls back to it when no cmap is given and the
# surface is not discrete.
_cmap = mcolors.LinearSegmentedColormap.from_list(
    "continuous_cmap", _cmap_.colors, N=256
)
22
+
23
+
24
def _plot_rectangle(rectangle: np.ndarray, weight, **plt_kwargs):
    """
    Draw one 4-coordinate point as a segment on the current axes.

    ``rectangle`` is read as ``(x0, y0, x1, y1)``: the segment joins
    ``(x0, y0)`` to ``(x1, y1)``. It is drawn in blue when ``weight`` is
    strictly positive and in red otherwise. Extra keyword arguments are
    forwarded to ``plt.plot``.
    """
    corners = np.asarray(rectangle)
    if weight > 0:
        segment_color = "blue"
    else:
        segment_color = "red"
    plt.plot(corners[[0, 2]], corners[[1, 3]], c=segment_color, **plt_kwargs)
30
+
31
+
32
def _plot_signed_measure_2(
    pts, weights, temp_alpha=0.7, threshold=(np.inf, np.inf), **plt_kwargs
):
    """
    Scatter a signed point measure on the current axes.

    Points are clipped from above by ``threshold``. Negative weights are
    rescaled into [-2, -1) and positive weights into (1, 2], each by the
    largest magnitude of its own sign, so that both signs use the full
    range of a diverging bordeaux / white / blue colormap normalized on
    [-2, 2]. Two empty scatters are added only to produce legend entries.
    Extra keyword arguments are forwarded to every ``plt.scatter`` call.
    """
    import matplotlib.colors

    upper_bound = np.asarray(threshold)[None, :]
    pts = np.clip(pts, a_min=-np.inf, a_max=upper_bound)
    weights = np.asarray(weights)
    color_weights = np.array(weights, dtype=float)

    is_negative = weights < 0
    if np.any(is_negative):
        largest_negative_mass = np.max(-weights[is_negative])
        color_weights[is_negative] = (
            color_weights[is_negative] / largest_negative_mass - 1
        )

    is_positive = weights > 0
    if np.any(is_positive):
        largest_positive_mass = np.max(weights[is_positive])
        color_weights[is_positive] = (
            color_weights[is_positive] / largest_positive_mass + 1
        )

    # End colors are opaque; the intermediate stops share the same RGB but
    # use `temp_alpha` as their alpha channel.
    bordeaux = np.array([0.70567316, 0.01555616, 0.15023281, 1])
    light_bordeaux = np.array([0.70567316, 0.01555616, 0.15023281, temp_alpha])
    bleu = np.array([0.2298057, 0.29871797, 0.75368315, 1])
    light_bleu = np.array([0.2298057, 0.29871797, 0.75368315, temp_alpha])
    color_stops = [bordeaux, light_bordeaux, "white", light_bleu, bleu]
    diverging_cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
        "", color_stops
    )

    plt.scatter(
        pts[:, 0],
        pts[:, 1],
        c=color_weights,
        cmap=diverging_cmap,
        norm=plt.Normalize(-2, 2),
        **plt_kwargs,
    )
    for legend_color, legend_label in (
        (bleu, "positive mass"),
        (bordeaux, "negative mass"),
    ):
        plt.scatter([], [], color=legend_color, label=legend_label, **plt_kwargs)
    plt.legend()
72
+
73
+
74
def _plot_signed_measure_4(
    pts,
    weights,
    x_smoothing: float = 1,
    area_alpha: bool = True,
    threshold=(np.inf, np.inf),
    alpha=None,
    **plt_kwargs,  # forwarded to _plot_rectangle
):
    """
    Plot a signed measure whose points have 4 coordinates ``(x0, y0, x1, y1)``.

    Each point is clipped from above by ``threshold`` (applied to both
    corners) and drawn through ``_plot_rectangle`` with its right
    x-coordinate divided by ``x_smoothing``. Points with
    ``x1 < x_smoothing * x0`` are skipped.

    Parameters
    ----------
    pts, weights:
        Points of the measure and their weights, iterated in lockstep.
    x_smoothing:
        Divisor applied to the right x-coordinate before drawing.
    area_alpha:
        If True and ``alpha`` is None, the opacity of each rectangle is its
        area divided by the largest drawn area.
    threshold:
        Upper bound ``(x_max, y_max)`` used to clip the coordinates.
    alpha:
        Fixed opacity; overrides ``area_alpha`` when not None.
    **plt_kwargs:
        Forwarded to ``_plot_rectangle``.
    """
    pts = np.clip(pts, a_min=-np.inf, a_max=np.array((*threshold, *threshold))[None, :])

    # Single pass: keep only the rectangles that have not been reduced to the
    # empty set, together with their (smoothed) corners and area.
    drawn = []
    for rectangle, weight in zip(pts, weights):
        if not (rectangle[2] >= x_smoothing * rectangle[0]):
            continue
        right = rectangle[2] / x_smoothing
        area = (right - rectangle[0]) * (rectangle[3] - rectangle[1])
        drawn.append(([rectangle[0], rectangle[1], right, rectangle[3]], weight, area))

    largest_area = max([0, *(area for _, _, area in drawn)])

    for corners, weight, area in drawn:
        if alpha is not None:
            rectangle_alpha = alpha
        elif not area_alpha or not (largest_area > 0):
            # Either area scaling is disabled, or every rectangle is
            # degenerate: dividing by a zero maximum would produce NaN, which
            # matplotlib rejects as an alpha value.
            rectangle_alpha = 1
        else:
            # Clamp so that negative areas (y1 < y0) cannot yield an alpha
            # outside matplotlib's accepted [0, 1] range.
            rectangle_alpha = min(max(area / largest_area, 0.0), 1.0)
        _plot_rectangle(
            rectangle=corners,
            weight=weight,
            alpha=rectangle_alpha,
            **plt_kwargs,
        )
128
+
129
+
130
def plot_signed_measure(signed_measure, threshold=None, ax=None, **plt_kwargs):
    """
    Plot a single signed measure ``(pts, weights)`` on ``ax``.

    ``ax`` (or the current axes when None) becomes the current axes. Points
    with 2 coordinates are scattered; points with 4 coordinates are drawn as
    segments. When ``threshold`` is None it defaults to the coordinate-wise
    maximum of the finite point coordinates and of the current upper axis
    limits, or to ``(inf, inf)`` for an empty measure. Extra keyword
    arguments are forwarded to the underlying plotting helper.
    """
    if ax is not None:
        plt.sca(ax)
    else:
        ax = plt.gca()

    pts, weights = signed_measure
    pts, weights = to_numpy(pts), to_numpy(weights)
    num_pts = pts.shape[0]
    num_parameters = pts.shape[1]

    if threshold is None and num_pts == 0:
        threshold = (np.inf, np.inf)
    elif threshold is None:
        # For 4-coordinate points, both corners contribute to the bound.
        if num_parameters == 4:
            stacked_corners = np.concatenate([pts[:, :2], pts[:, 2:]], axis=0)
        else:
            stacked_corners = pts
        finite_max = np.max(np.ma.masked_invalid(stacked_corners), axis=0)
        axis_upper_limits = [plt.gca().get_xlim()[1], plt.gca().get_ylim()[1]]
        threshold = np.max([finite_max, axis_upper_limits], axis=0)

    assert num_parameters in (2, 4)
    plotter = _plot_signed_measure_2 if num_parameters == 2 else _plot_signed_measure_4
    plotter(pts=pts, weights=weights, threshold=threshold, **plt_kwargs)
162
+
163
+
164
def plot_signed_measures(signed_measures, threshold=None, size=4, alpha=None):
    """
    Plot several signed measures side by side, one axes per measure.

    With at most one measure the current axes are reused; otherwise a new
    single-row figure is created with one ``size`` x ``size`` panel per
    measure. ``threshold`` and ``alpha`` are passed to
    ``plot_signed_measure`` for every panel.
    """
    n_panels = len(signed_measures)
    if n_panels > 1:
        _, axes = plt.subplots(
            nrows=1, ncols=n_panels, figsize=(n_panels * size, size)
        )
    else:
        axes = [plt.gca()]

    for panel, measure in zip(axes, signed_measures):
        plot_signed_measure(
            signed_measure=measure, ax=panel, threshold=threshold, alpha=alpha
        )
    plt.tight_layout()
177
+
178
+
179
def plot_surface(
    grid,
    hf,
    fig=None,
    ax=None,
    cmap: Optional[str] = None,
    discrete_surface: bool = False,
    has_negative_values: bool = False,
    contour: bool = True,
    **plt_args,
):
    """
    Plot a 2-parameter surface ``hf`` over the coordinates given by ``grid``.

    Parameters
    ----------
    grid:
        Pair of 1-d coordinate arrays (x then y); each is converted with
        ``to_numpy``.
    hf:
        2-d array of values, or a 3-d array whose first axis has length 1.
        It is plotted transposed, i.e. its first axis follows ``grid[0]``.
    fig, ax:
        Figure / axes to draw on; default to the current ones. ``ax``
        becomes the current axes.
    cmap:
        Colormap, or name of a registered colormap. Defaults to ``gray_r``
        for discrete surfaces and to the module's blue colormap otherwise.
    discrete_surface:
        Draw integer-valued cells with a boundary norm and attach a
        colorbar to ``fig``.
    has_negative_values:
        For discrete surfaces, use the bounds -5..5 instead of 0..10.
    contour:
        For non-discrete surfaces, use ``contourf`` (with ``levels``
        defaulting to 50) instead of ``pcolormesh``.
    **plt_args:
        Forwarded to the matplotlib drawing call.

    Returns
    -------
    The object returned by ``pcolormesh`` or ``contourf``.
    """
    import matplotlib

    grid = [to_numpy(g) for g in grid]
    hf = to_numpy(hf)
    if ax is None:
        ax = plt.gca()
    else:
        plt.sca(ax)
    if hf.ndim == 3 and hf.shape[0] == 1:
        hf = hf[0]
    assert hf.ndim == 2, "Can only plot a 2d surface"
    fig = plt.gcf() if fig is None else fig
    if cmap is None:
        cmap = matplotlib.colormaps["gray_r"] if discrete_surface else _cmap

    if discrete_surface or not contour:
        # shading="flat" needs one more coordinate than cells along each
        # axis: append a point 10% of the grid's extent past its last value.
        grid = [np.concatenate([g, [g[-1] * 1.1 - 0.1 * g[0]]]) for g in grid]

    if discrete_surface:
        if isinstance(cmap, str):
            # BoundaryNorm needs the number of colors (`cmap.N`), which a
            # colormap *name* does not have: resolve it to a Colormap first.
            cmap = matplotlib.colormaps[cmap]
        if has_negative_values:
            bounds = np.arange(-5, 6, 1, dtype=int)
        else:
            bounds = np.arange(0, 11, 1, dtype=int)
        norm = matplotlib.colors.BoundaryNorm(bounds, cmap.N, extend="max")
        im = ax.pcolormesh(
            grid[0], grid[1], hf.T, cmap=cmap, norm=norm, shading="flat", **plt_args
        )
        cbar = fig.colorbar(
            matplotlib.cm.ScalarMappable(cmap=cmap, norm=norm),
            spacing="proportional",
            ax=ax,
        )
        cbar.set_ticks(ticks=bounds, labels=bounds)
        return im

    if contour:
        levels = plt_args.pop("levels", 50)
        im = ax.contourf(grid[0], grid[1], hf.T, cmap=cmap, levels=levels, **plt_args)
    else:
        im = ax.pcolormesh(
            grid[0], grid[1], hf.T, cmap=cmap, shading="flat", **plt_args
        )
    return im
235
+
236
+
237
def plot_surfaces(HF, size=4, **plt_args):
    """
    Plot one surface per degree, side by side in a new figure.

    ``HF`` is a pair ``(grid, hf)`` where ``hf`` has 3 axes, the first one
    indexing the degree. Each slice ``hf[i]`` is drawn with ``plot_surface``
    on its own ``size`` x ``size`` panel; extra keyword arguments are
    forwarded to ``plot_surface``.
    """
    grid, hf = HF
    assert hf.ndim == 3, (
        f"Found hf.shape = {hf.shape}, expected ndim = 3 : degree, 2-parameter surface."
    )
    n_panels = hf.shape[0]
    # squeeze=False always yields a 2-d array of axes, so the single-panel
    # case needs no special handling.
    figure, axes_grid = plt.subplots(
        nrows=1, ncols=n_panels, figsize=(n_panels * size, size), squeeze=False
    )
    for panel, surface in zip(axes_grid[0], hf):
        plot_surface(grid=grid, hf=surface, fig=figure, ax=panel, **plt_args)
    plt.tight_layout()
251
+
252
+
253
def _rectangle(x, y, color, alpha):
    """
    Build the rectangle patch {z | x ≤ z ≤ y} with the given color and alpha.

    ``x`` is the lower-left corner and ``y`` the upper-right one; a negative
    extent along an axis is replaced by 0.
    """
    from matplotlib.patches import Rectangle as RectanglePatch

    width = max(y[0] - x[0], 0)
    height = max(y[1] - x[1], 0)
    return RectanglePatch(x, width, height, color=color, alpha=alpha)
262
+
263
+
264
+ def _d_inf(a, b):
265
+ a = np.asarray(a)
266
+ b = np.asarray(b)
267
+ return np.min(np.abs(b - a))
268
+
269
+
270
# Cached availability of the optional `shapely` dependency: None until
# plot2d_PyModule first tries to import it, then True or False.
HAS_SHAPELY = None
271
+
272
+
273
def plot2d_PyModule(
    corners,
    box,
    *,
    dimension=-1,
    separated=False,
    min_persistence=0,
    alpha=None,
    verbose=False,
    save=False,
    dpi=200,
    xlabel=None,
    ylabel=None,
    cmap=None,
    outline_width=0.2,
    outline_threshold=np.inf,
    interleavings=None,
    backend=None,
):
    """
    Plot the summands of a 2-parameter module as unions of rectangles.

    Parameters
    ----------
    corners:
        One ``(births, deaths)`` pair per summand. Each entry is an array of
        2-d points (a single point is also accepted). A summand is the union
        of the rectangles ``[birth, death]`` over all birth/death pairs,
        clipped to ``box``.
    box:
        ``[[x_min, y_min], [x_max, y_max]]``; also used as axis limits.
    dimension:
        If non-negative, written in the plot title.
    separated:
        If True, every drawn summand gets its own new figure; otherwise all
        summands are drawn on the current axes.
    min_persistence:
        Summands whose interleaving is not strictly larger are skipped.
    alpha:
        Fill opacity. Defaults to 0.8 with the shapely backend and to 1 with
        the matplotlib backend (whose rectangles overlap).
    verbose, save, dpi:
        Accepted for backward compatibility; not used by this function.
    xlabel, ylabel:
        Axis labels, set only when truthy.
    cmap:
        Name of a registered colormap; defaults to ``"Spectral"``.
    outline_width:
        Line width of the summand outlines.
    outline_threshold:
        Summands whose interleaving is strictly larger get a black outline.
    interleavings:
        Optional per-summand interleaving values. When None, the value of a
        summand is the largest, over its rectangles, of the smaller side
        length.
    backend:
        ``"shapely"`` draws the exterior of the exact union of each
        summand's rectangles; any other value draws the individual
        rectangles with matplotlib. Defaults to ``"shapely"`` when shapely
        can be imported. Requesting ``"shapely"`` without shapely installed
        falls back to matplotlib.
    """
    global HAS_SHAPELY
    if HAS_SHAPELY is None:
        # Probe once; the outcome is cached in the module-level flag.
        try:
            import shapely
            from shapely import union_all

            HAS_SHAPELY = True
        except ImportError:
            HAS_SHAPELY = False
            from warnings import warn

            warn(
                "Shapely is not installed. MMA plots may be imprecise.",
                ImportWarning,
            )

    if backend is None:
        backend = "shapely" if HAS_SHAPELY else "matplotlib"
    # Honor an explicit backend request, but never select shapely when it is
    # not importable.
    use_shapely = bool(HAS_SHAPELY) and backend == "shapely"
    if use_shapely:
        import shapely
        from shapely import union_all
    if alpha is None:
        alpha = 0.8 if use_shapely else 1

    cmap_instance = matplotlib.colormaps["Spectral" if cmap is None else cmap]

    def _decorate():
        # Labels and title of the current axes.
        if xlabel:
            plt.xlabel(xlabel)
        if ylabel:
            plt.ylabel(ylabel)
        if dimension >= 0:
            plt.title(f"$\\mathrm{{H}}_{dimension}$ 2-persistence")

    box = np.asarray(box)
    if not separated:
        ax = plt.gca()
        ax.set(xlim=[box[0][0], box[1][0]], ylim=[box[0][1], box[1][1]])

    n_summands = len(corners)

    for i in range(n_summands):
        summand_interleaving = 0 if interleavings is None else interleavings[i]

        births = np.asarray(corners[i][0])
        deaths = np.asarray(corners[i][1])

        if births.size == 0 or deaths.size == 0:
            continue

        if births.ndim == 1:
            births = births[None, :]
        if deaths.ndim == 1:
            deaths = deaths[None, :]
        if births.ndim != 2 or deaths.ndim != 2:
            raise ValueError(
                f"Invalid corners format. Got {births.shape=}, {deaths.shape=}"
            )

        # All (birth, death) pairs of the summand, flattened to two aligned
        # lists of 2-d points.
        births_grid, deaths_grid = np.broadcast_arrays(
            births[:, None, :], deaths[None, :, :]
        )
        births_flat = np.maximum(births_grid.reshape(-1, 2), box[0])
        deaths_flat = np.minimum(deaths_grid.reshape(-1, 2), box[1])

        # Keep only the rectangles with non-empty interior inside the box.
        is_valid = np.all(deaths_flat > births_flat, axis=1)
        if not np.any(is_valid):
            continue

        valid_births = births_flat[is_valid]
        valid_deaths = deaths_flat[is_valid]

        if interleavings is None:
            d_infs = np.min(valid_deaths - valid_births, axis=1)
            current_max_interleaving = np.max(d_infs) if d_infs.size > 0 else 0
            summand_interleaving = max(summand_interleaving, current_max_interleaving)

        if summand_interleaving <= min_persistence:
            continue

        # --- Plotting ---
        color = cmap_instance(i / n_summands)
        outline_summand = (
            "black" if (summand_interleaving > outline_threshold) else None
        )

        if separated:
            fig, ax = plt.subplots()
            ax.set(xlim=[box[0][0], box[1][0]], ylim=[box[0][1], box[1][1]])

        if use_shapely:
            rects = shapely.box(
                valid_births[:, 0],
                valid_births[:, 1],
                valid_deaths[:, 0],
                valid_deaths[:, 1],
            )
            summand_shape = union_all(rects)

            # A union may be a single polygon or a collection of polygons.
            geoms = getattr(summand_shape, "geoms", [summand_shape])
            for geom in geoms:
                if geom.is_empty:
                    continue
                xs, ys = geom.exterior.xy
                ax.fill(
                    xs,
                    ys,
                    alpha=alpha,
                    fc=color,
                    ec=outline_summand,
                    lw=outline_width,
                    ls="-",
                )
        else:
            from matplotlib.collections import PolyCollection

            # Vertices, shape (N, 4, 2):
            # (x0, y0), (x0, y1), (x1, y1), (x1, y0)
            verts = np.stack(
                [
                    np.stack([valid_births[:, 0], valid_births[:, 1]], axis=1),
                    np.stack([valid_births[:, 0], valid_deaths[:, 1]], axis=1),
                    np.stack([valid_deaths[:, 0], valid_deaths[:, 1]], axis=1),
                    np.stack([valid_deaths[:, 0], valid_births[:, 1]], axis=1),
                ],
                axis=1,
            )

            pc = PolyCollection(
                verts,
                facecolors=color,
                edgecolors=outline_summand,
                alpha=alpha,
                linewidths=outline_width,
            )
            ax.add_collection(pc)

        if separated:
            _decorate()

    if not separated:
        _decorate()

    return
448
+
449
+
450
def plot_simplicial_complex(
    st, pts: ArrayLike, x: float, y: float, mma=None, degree=None
):
    """
    Scatters the points, with the simplices in the filtration at coordinates (x,y).
    if an mma module is given, plots it in a second axis

    Parameters
    ----------
    st:
        Iterable of ``(simplex, filtration)`` pairs that also provides
        ``get_skeleton(0)``; ``filtration[0]`` is compared to ``x`` and
        ``filtration[1]`` to ``y``.
    pts:
        2-d coordinates of the vertices, indexed by the vertex ids found in
        the simplices.
    x, y:
        Filtration coordinates at which the complex is drawn.
    mma:
        Optional module exposing ``plot`` and ``get_box``; when given, a new
        two-panel figure is created and nothing is returned.
    degree:
        Passed to ``mma.plot``.

    Returns
    -------
    The scatter of the points, or None when ``mma`` is given.
    """
    if mma is not None:
        _, (complex_ax, module_ax) = plt.subplots(ncols=2, figsize=(15, 5))
        plt.sca(complex_ax)
        plot_simplicial_complex(st, pts, x, y)
        plt.sca(module_ax)
        mma.plot(degree=degree)
        box = mma.get_box()
        x_min, y_min, x_max, y_max = box.ravel()
        # NOTE(review): this second call overlays degree 1 regardless of
        # `degree`; kept as is -- confirm whether it is intended.
        mma.plot(degree=1, min_persistence=0.01)
        plt.vlines(x, y_min, y_max, color="k", linestyle="--")
        plt.hlines(y, x_min, x_max, color="k", linestyle="--")
        plt.scatter([x], [y], c="r", zorder=10)
        # Offset the label by 1% of the box width / height.
        plt.text(
            x + 0.01 * (x_max - x_min),
            y + 0.01 * (y_max - y_min),
            f"({x},{y})",
        )
        return

    pts = np.asarray(pts)
    values = np.array([-f[1] for s, f in st.get_skeleton(0)])
    qs = np.quantile(values, np.linspace(0, 1, 100))

    from matplotlib.pyplot import get_cmap

    viridis = get_cmap("viridis")

    def color_idx(d):
        # Rank of d among the vertex values, rescaled to [0, 1].
        return np.searchsorted(qs, d) / 100

    def color(d):
        return viridis([0, color_idx(d), 1])[1]

    cols_pc = np.asarray([color(v) for v in values])
    ax = plt.gca()
    for s, f in st:  # simplex, filtration
        density = -f[1]
        # Skip vertices and simplices not yet present at (x, y).
        if len(s) <= 1 or f[0] > x or density < -y:
            continue
        if len(s) == 2:  # edge
            xx = np.array([pts[a, 0] for a in s])
            yy = np.array([pts[a, 1] for a in s])
            plt.plot(xx, yy, c=color(density), alpha=1, zorder=10 * density, lw=1.5)
        if len(s) == 3:  # triangle
            xx = np.array([pts[a, 0] for a in s])
            yy = np.array([pts[a, 1] for a in s])
            _c = color(density)
            ax.fill(xx, yy, c=_c, alpha=0.3, zorder=0)
    out = plt.scatter(pts[:, 0], pts[:, 1], c=cols_pc, zorder=10, s=10)
    ax.set_aspect(1)
    return out
502
+
503
+
504
def plot_point_cloud(
    pts,
    function,
    x,
    y,
    mma=None,
    degree=None,
    ball_alpha=0.3,
    point_cmap="viridis",
    color_bias=1,
    ball_color=None,
    point_size=20,
):
    """
    Scatter a point cloud colored by ``-function`` and draw a ball of radius
    ``x`` around every point whose function value is at most ``y``.

    Parameters
    ----------
    pts:
        Array of 2-d points, one per row.
    function:
        One value per point.
    x, y:
        Ball radius and function threshold.
    mma:
        Optional module exposing ``plot`` and ``get_box``; when given, a new
        two-panel figure is created with the point cloud on the left and the
        module, marked at ``(x, y)``, on the right.
    degree:
        Passed to ``mma.plot``.
    ball_alpha, ball_color:
        Opacity and color of the balls.
    point_cmap:
        Name of the colormap used for the points.
    color_bias:
        Multiplier applied to a value before ranking it among the quantiles.
    point_size:
        Marker size of the points.
    """
    if mma is not None:
        _, (cloud_ax, module_ax) = plt.subplots(ncols=2, figsize=(15, 5))
        plt.sca(cloud_ax)
        # Forward the styling options so both layouts look the same.
        plot_point_cloud(
            pts,
            function,
            x,
            y,
            ball_alpha=ball_alpha,
            point_cmap=point_cmap,
            color_bias=color_bias,
            ball_color=ball_color,
            point_size=point_size,
        )
        plt.sca(module_ax)
        mma.plot(degree=degree)
        box = mma.get_box()
        x_min, y_min, x_max, y_max = box.ravel()
        # NOTE(review): this second call overlays degree 1 regardless of
        # `degree`; kept as is -- confirm whether it is intended.
        mma.plot(degree=1, min_persistence=0.01)
        plt.vlines(x, y_min, y_max, color="k", linestyle="--")
        plt.hlines(y, x_min, x_max, color="k", linestyle="--")
        plt.scatter([x], [y], c="r", zorder=10)
        # Offset the label by 1% of the box width / height.
        plt.text(
            x + 0.01 * (x_max - x_min),
            y + 0.01 * (y_max - y_min),
            f"({x},{y})",
        )
        return

    # Accept plain sequences as well as arrays.
    pts = np.asarray(pts)
    function = np.asarray(function)

    values = -function
    qs = np.quantile(values, np.linspace(0, 1, 100))

    from matplotlib.collections import PatchCollection
    from matplotlib.pyplot import get_cmap

    point_colormap = get_cmap(point_cmap)

    def color_idx(d):
        # Rank of the (biased) value among the quantiles, rescaled to [0, 1].
        return np.searchsorted(qs, d * color_bias) / 100

    def color(d):
        return point_colormap([0, color_idx(d), 1])[1]

    _colors = np.array([color(v) for v in values])
    ax = plt.gca()
    idx = function <= y
    circles = [plt.Circle(pt, x) for pt in pts[idx]]
    pc = PatchCollection(circles, alpha=ball_alpha, color=ball_color)
    ax.add_collection(pc)
    plt.scatter(*pts.T, c=_colors, s=point_size)
    ax.set_aspect(1)