multipers-2.3.3b6-cp312-cp312-manylinux_2_39_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of multipers might be problematic.

Files changed (182)
  1. multipers/__init__.py +33 -0
  2. multipers/_signed_measure_meta.py +453 -0
  3. multipers/_slicer_meta.py +211 -0
  4. multipers/array_api/__init__.py +45 -0
  5. multipers/array_api/numpy.py +41 -0
  6. multipers/array_api/torch.py +58 -0
  7. multipers/data/MOL2.py +458 -0
  8. multipers/data/UCR.py +18 -0
  9. multipers/data/__init__.py +1 -0
  10. multipers/data/graphs.py +466 -0
  11. multipers/data/immuno_regions.py +27 -0
  12. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  13. multipers/data/pytorch2simplextree.py +91 -0
  14. multipers/data/shape3d.py +101 -0
  15. multipers/data/synthetic.py +113 -0
  16. multipers/distances.py +202 -0
  17. multipers/filtration_conversions.pxd +229 -0
  18. multipers/filtration_conversions.pxd.tp +84 -0
  19. multipers/filtrations/__init__.py +18 -0
  20. multipers/filtrations/density.py +574 -0
  21. multipers/filtrations/filtrations.py +361 -0
  22. multipers/filtrations.pxd +224 -0
  23. multipers/function_rips.cpython-312-x86_64-linux-gnu.so +0 -0
  24. multipers/function_rips.pyx +105 -0
  25. multipers/grids.cpython-312-x86_64-linux-gnu.so +0 -0
  26. multipers/grids.pyx +433 -0
  27. multipers/gudhi/Persistence_slices_interface.h +132 -0
  28. multipers/gudhi/Simplex_tree_interface.h +239 -0
  29. multipers/gudhi/Simplex_tree_multi_interface.h +551 -0
  30. multipers/gudhi/cubical_to_boundary.h +59 -0
  31. multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -0
  32. multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -0
  33. multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -0
  34. multipers/gudhi/gudhi/Debug_utils.h +45 -0
  35. multipers/gudhi/gudhi/Fields/Multi_field.h +484 -0
  36. multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -0
  37. multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -0
  38. multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -0
  39. multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -0
  40. multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -0
  41. multipers/gudhi/gudhi/Fields/Z2_field.h +355 -0
  42. multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -0
  43. multipers/gudhi/gudhi/Fields/Zp_field.h +420 -0
  44. multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -0
  45. multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -0
  46. multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -0
  47. multipers/gudhi/gudhi/Matrix.h +2107 -0
  48. multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -0
  49. multipers/gudhi/gudhi/Multi_persistence/Box.h +174 -0
  50. multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -0
  51. multipers/gudhi/gudhi/Off_reader.h +173 -0
  52. multipers/gudhi/gudhi/One_critical_filtration.h +1441 -0
  53. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -0
  54. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -0
  55. multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -0
  56. multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -0
  57. multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -0
  58. multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -0
  59. multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -0
  60. multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -0
  61. multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -0
  62. multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -0
  63. multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -0
  64. multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -0
  65. multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -0
  66. multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -0
  67. multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -0
  68. multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -0
  69. multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -0
  70. multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -0
  71. multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -0
  72. multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -0
  73. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -0
  74. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -0
  75. multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -0
  76. multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -0
  77. multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -0
  78. multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -0
  79. multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -0
  80. multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -0
  81. multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -0
  82. multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -0
  83. multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -0
  84. multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -0
  85. multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -0
  86. multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -0
  87. multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -0
  88. multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -0
  89. multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -0
  90. multipers/gudhi/gudhi/Persistent_cohomology.h +745 -0
  91. multipers/gudhi/gudhi/Points_off_io.h +171 -0
  92. multipers/gudhi/gudhi/Simple_object_pool.h +69 -0
  93. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -0
  94. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -0
  95. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -0
  96. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -0
  97. multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -0
  98. multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -0
  99. multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -0
  100. multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -0
  101. multipers/gudhi/gudhi/Simplex_tree.h +2794 -0
  102. multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -0
  103. multipers/gudhi/gudhi/distance_functions.h +62 -0
  104. multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -0
  105. multipers/gudhi/gudhi/persistence_interval.h +253 -0
  106. multipers/gudhi/gudhi/persistence_matrix_options.h +170 -0
  107. multipers/gudhi/gudhi/reader_utils.h +367 -0
  108. multipers/gudhi/mma_interface_coh.h +256 -0
  109. multipers/gudhi/mma_interface_h0.h +223 -0
  110. multipers/gudhi/mma_interface_matrix.h +293 -0
  111. multipers/gudhi/naive_merge_tree.h +536 -0
  112. multipers/gudhi/scc_io.h +310 -0
  113. multipers/gudhi/truc.h +1403 -0
  114. multipers/io.cpython-312-x86_64-linux-gnu.so +0 -0
  115. multipers/io.pyx +644 -0
  116. multipers/ml/__init__.py +0 -0
  117. multipers/ml/accuracies.py +90 -0
  118. multipers/ml/invariants_with_persistable.py +79 -0
  119. multipers/ml/kernels.py +176 -0
  120. multipers/ml/mma.py +713 -0
  121. multipers/ml/one.py +472 -0
  122. multipers/ml/point_clouds.py +352 -0
  123. multipers/ml/signed_measures.py +1589 -0
  124. multipers/ml/sliced_wasserstein.py +461 -0
  125. multipers/ml/tools.py +113 -0
  126. multipers/mma_structures.cpython-312-x86_64-linux-gnu.so +0 -0
  127. multipers/mma_structures.pxd +128 -0
  128. multipers/mma_structures.pyx +2786 -0
  129. multipers/mma_structures.pyx.tp +1094 -0
  130. multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -0
  131. multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -0
  132. multipers/multi_parameter_rank_invariant/function_rips.h +322 -0
  133. multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -0
  134. multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -0
  135. multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -0
  136. multipers/multiparameter_edge_collapse.py +41 -0
  137. multipers/multiparameter_module_approximation/approximation.h +2330 -0
  138. multipers/multiparameter_module_approximation/combinatory.h +129 -0
  139. multipers/multiparameter_module_approximation/debug.h +107 -0
  140. multipers/multiparameter_module_approximation/euler_curves.h +0 -0
  141. multipers/multiparameter_module_approximation/format_python-cpp.h +286 -0
  142. multipers/multiparameter_module_approximation/heap_column.h +238 -0
  143. multipers/multiparameter_module_approximation/images.h +79 -0
  144. multipers/multiparameter_module_approximation/list_column.h +174 -0
  145. multipers/multiparameter_module_approximation/list_column_2.h +232 -0
  146. multipers/multiparameter_module_approximation/ru_matrix.h +347 -0
  147. multipers/multiparameter_module_approximation/set_column.h +135 -0
  148. multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -0
  149. multipers/multiparameter_module_approximation/unordered_set_column.h +166 -0
  150. multipers/multiparameter_module_approximation/utilities.h +403 -0
  151. multipers/multiparameter_module_approximation/vector_column.h +223 -0
  152. multipers/multiparameter_module_approximation/vector_matrix.h +331 -0
  153. multipers/multiparameter_module_approximation/vineyards.h +464 -0
  154. multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -0
  155. multipers/multiparameter_module_approximation.cpython-312-x86_64-linux-gnu.so +0 -0
  156. multipers/multiparameter_module_approximation.pyx +235 -0
  157. multipers/pickle.py +90 -0
  158. multipers/plots.py +456 -0
  159. multipers/point_measure.cpython-312-x86_64-linux-gnu.so +0 -0
  160. multipers/point_measure.pyx +395 -0
  161. multipers/simplex_tree_multi.cpython-312-x86_64-linux-gnu.so +0 -0
  162. multipers/simplex_tree_multi.pxd +134 -0
  163. multipers/simplex_tree_multi.pyx +10840 -0
  164. multipers/simplex_tree_multi.pyx.tp +2009 -0
  165. multipers/slicer.cpython-312-x86_64-linux-gnu.so +0 -0
  166. multipers/slicer.pxd +3034 -0
  167. multipers/slicer.pxd.tp +234 -0
  168. multipers/slicer.pyx +20481 -0
  169. multipers/slicer.pyx.tp +1088 -0
  170. multipers/tensor/tensor.h +672 -0
  171. multipers/tensor.pxd +13 -0
  172. multipers/test.pyx +44 -0
  173. multipers/tests/__init__.py +62 -0
  174. multipers/torch/__init__.py +1 -0
  175. multipers/torch/diff_grids.py +240 -0
  176. multipers/torch/rips_density.py +310 -0
  177. multipers-2.3.3b6.dist-info/METADATA +128 -0
  178. multipers-2.3.3b6.dist-info/RECORD +182 -0
  179. multipers-2.3.3b6.dist-info/WHEEL +5 -0
  180. multipers-2.3.3b6.dist-info/licenses/LICENSE +21 -0
  181. multipers-2.3.3b6.dist-info/top_level.txt +1 -0
  182. multipers.libs/libtbb-ca48af5c.so.12.16 +0 -0
multipers/ml/signed_measures.py
@@ -0,0 +1,1589 @@
1
+ from collections.abc import Iterable, Sequence
2
+ from itertools import product
3
+ from typing import Optional, Union
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from joblib import Parallel, delayed
8
+ from sklearn.base import BaseEstimator, TransformerMixin
9
+ from tqdm import tqdm
10
+
11
+ import multipers as mp
12
+ from multipers.array_api import api_from_tensor
13
+ from multipers.filtrations.density import available_kernels, convolution_signed_measures
14
+ from multipers.grids import compute_grid
15
+ from multipers.point_measure import rank_decomposition_by_rectangles, signed_betti
16
+
17
+
18
+ class FilteredComplex2SignedMeasure(BaseEstimator, TransformerMixin):
19
+ """
20
+ Input
21
+ -----
22
+ Iterable[SimplexTreeMulti]
23
+
24
+ Output
25
+ ------
26
+ Iterable[ list[signed_measure for degree] ]
27
+
28
+ signed measure is either
29
+ - (points : (n x num_parameters) array, weights : (n) int array ) if sparse,
30
+ - else an integer matrix.
31
+
32
+ Parameters
33
+ ----------
34
+ - degrees : list of degrees to compute. None corresponds to the Euler characteristic
35
+ - filtration_grid : the grid on which to compute.
36
+ If None, it is inferred from the data during fit.
37
+ - fit_fraction : the fraction of the data used to fit the grid; sampling is controlled by the seed parameter
38
+ - resolution : the resolution of this grid
39
+ - filtration_quantile : the quantile of filtration values to ignore when inferring the grid
40
+ - grid_strategy : str, one of 'regular', 'quantile', 'exact'
41
+ - normalize_filtrations : if sparse, normalizes all filtrations.
42
+ - expand : expands the simplex tree so that the homology of flag complexes
43
+ is computed at the correct degrees
44
+ - invariant : the topological invariant used to produce the signed measure.
45
+ Choices are "hilbert" or "euler"; the rank invariant is available through rank_degrees.
46
+ - num_collapses : either an int or "full". Collapses the complex before
47
+ the computation.
48
+ - _möbius_inversion : if False, the Möbius inversion is not applied; the output
49
+ is then a matrix.
50
+ - enforce_null_mass : if True, returns a zero-mass measure by thresholding the
51
+ module.
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ # homological degrees + None for euler
57
+ degrees: list[int | None] = [],
58
+ rank_degrees: list[int] = [], # same for rank invariant
59
+ filtration_grid: (
60
+ Sequence[Sequence[np.ndarray]]
61
+ # filtration values to consider. Format: [filtration values of parameter i, for each parameter i]
62
+ | None
63
+ ) = None,
64
+ progress=False, # tqdm
65
+ num_collapses: int | str = 0, # edge collapses before computing
66
+ n_jobs=None,
67
+ resolution: (
68
+ Iterable[int] | int | None
69
+ ) = None, # when filtration grid is not given, the resolution of the filtration grid to infer
70
+ # sparse=True, # sparse output # DEPRECATED: use SignedMeasureFormatter instead
71
+ plot: bool = False,
72
+ filtration_quantile: float = 0.0, # quantile for inferring filtration grid
73
+ # whether or not to do the Möbius inversion (not recommended to change)
74
+ # _möbius_inversion: bool = True,
75
+ expand=False, # expand the simplextree before computing the homology
76
+ normalize_filtrations: bool = False,
77
+ # exact_computation:bool=False, # compute the exact signed measure.
78
+ grid_strategy: str = "exact",
79
+ seed: int = 0, # if fit_fraction is not 1, the seed used for subsampling
80
+ fit_fraction=1, # the fraction of the data on which to fit
81
+ out_resolution: Iterable[int] | int | None = None,
82
+ individual_grid: Optional[
83
+ bool
84
+ ] = None, # Can be significantly faster for some grid strategies, but can drop statistical performance
85
+ enforce_null_mass: bool = False,
86
+ flatten=True,
87
+ backend: Optional[str] = None,
88
+ ):
89
+ super().__init__()
90
+ self.degrees = degrees
91
+ self.rank_degrees = rank_degrees
92
+ self.filtration_grid = filtration_grid
93
+ self.progress = progress
94
+ self.num_collapses = num_collapses
95
+ self.n_jobs = n_jobs
96
+ self.resolution = resolution
97
+ self.plot = plot
98
+ self.backend = backend
99
+ # self.sparse=sparse # TODO : deprecate
100
+ self.filtration_quantile = filtration_quantile
101
+ # Will only work for non sparse output. (discrete matrices cannot be "rescaled")
102
+ self.normalize_filtrations = normalize_filtrations
103
+ self.grid_strategy = grid_strategy
104
+ # self._möbius_inversion = _möbius_inversion
105
+ self._reconversion_grid = None
106
+ self.expand = expand
107
+ # will only refit the grid if filtration_grid has never been given.
108
+ self._refit_grid = None
109
+ self.seed = seed
110
+ self.fit_fraction = fit_fraction
111
+ self._transform_st = None
112
+ self.out_resolution = out_resolution
113
+ self.individual_grid = individual_grid
114
+ self.enforce_null_mass = enforce_null_mass
115
+ self._default_mass_location = None
116
+ self.flatten = flatten
117
+ self.num_parameters: int = 0
118
+
119
+ return
120
+
121
+ @staticmethod
122
+ def _is_filtered_complex(input):
123
+ return mp.simplex_tree_multi.is_simplextree_multi(input) or mp.slicer.is_slicer(
124
+ input, allow_minpres=True
125
+ )
126
+
127
+ def _input_checks(self, X):
128
+ assert len(X) > 0, "No filtered complex found. Cannot fit."
129
+ assert self._is_filtered_complex(
130
+ X[0][0]
131
+ ), f"X[0] is not a known filtered complex, {X[0]=}, nor X[0][0]."
132
+ self._num_axis = len(X[0])
133
+ first = X[0][0]
134
+ assert (
135
+ not mp.slicer.is_slicer(first) or not self.expand
136
+ ), "Cannot expand slicers."
137
+ assert not mp.slicer.is_slicer(first) or not (
138
+ isinstance(first, (tuple, list)) and first[0].is_minpres
139
+ ), "Multi-degree minpres are not supported yet as an input. This can still be computed by providing a backend."
140
+
141
+ def _infer_filtration(self, X):
142
+ self.num_parameters = X[0][0].num_parameters
143
+ indices = np.random.default_rng(self.seed).choice(
144
+ len(X), min(int(self.fit_fraction * len(X)) + 1, len(X)), replace=False
145
+ )
146
+ ## ax, num_x
147
+ filtrations = tuple(
148
+ tuple(
149
+ compute_grid(x, strategy="exact")
150
+ for x in (X[idx][ax] for idx in indices)
151
+ )
152
+ for ax in range(self._num_axis)
153
+ )
154
+ num_parameters = len(filtrations[0][0])
155
+ assert (
156
+ num_parameters == self.num_parameters
157
+ ), f"Internal error, got {num_parameters=} and {self.num_parameters=}"
158
+
159
+ filtrations_values = [
160
+ [
161
+ np.unique(np.concatenate([x[i] for x in filtrations[ax]]))
162
+ for i in range(num_parameters)
163
+ ]
164
+ for ax in range(self._num_axis)
165
+ ]
166
+ ## ax, param, gridsize
167
+ filtration_grid = tuple(
168
+ compute_grid(
169
+ filtrations_values[ax],
170
+ resolution=self.resolution,
171
+ strategy=self.grid_strategy,
172
+ )
173
+ for ax in range(self._num_axis)
174
+ ) # TODO :use more parameters
175
+ self.filtration_grid = filtration_grid
176
+ return filtration_grid
177
+
178
+ def _params_check(self):
179
+ assert (
180
+ self.resolution is not None
181
+ or self.filtration_grid is not None
182
+ or self.grid_strategy == "exact"
183
+ or self.individual_grid
184
+ ), "For non exact filtrations, a resolution has to be specified."
185
+
186
+ def fit(self, X, y=None): # Todo : infer filtration grid ? quantiles ?
187
+ self._params_check()
188
+ self._input_checks(X)
189
+
190
+ if isinstance(self.resolution, int):
191
+ self.resolution = [self.resolution] * X[0][0].num_parameters
192
+
193
+ self.individual_grid = (
194
+ self.individual_grid
195
+ if self.individual_grid is not None
196
+ else self.grid_strategy
197
+ in ["regular_closest", "exact", "quantile", "partition"]
198
+ )
199
+
200
+ if (
201
+ not self.enforce_null_mass
202
+ and self.individual_grid
203
+ or self.filtration_grid is not None
204
+ ):
205
+ self._refit_grid = False
206
+ else:
207
+ self._refit_grid = True
208
+
209
+ if self._refit_grid:
210
+ self._infer_filtration(X=X)
211
+ if self.out_resolution is None:
212
+ self.out_resolution = self.resolution
213
+ # elif isinstance(self.out_resolution, int):
214
+ # self.out_resolution = [self.out_resolution] * self.num_parameters
215
+ if self.normalize_filtrations and not self.individual_grid:
216
+ # self._reconversion_grid = [np.linspace(0,1, num=len(f), dtype=float) for f in self.filtration_grid] ## This will not work for non-regular grids...
217
+ self._reconversion_grid = [
218
+ [(f - np.min(f)) / np.std(f) for f in F] for F in self.filtration_grid
219
+ ] # not the best, but better than some weird magic
220
+ # elif not self.sparse: # It actually renormalizes the filtration !!
221
+ # self._reconversion_grid = [np.linspace(0,r, num=r, dtype=int) for r in self.out_resolution]
222
+ else:
223
+ self._reconversion_grid = self.filtration_grid
224
+ ## ax, num_param
225
+ self._default_mass_location = (
226
+ np.asarray([[g[-1] for g in F] for F in self.filtration_grid])
227
+ if self.enforce_null_mass
228
+ else None
229
+ )
230
+ return self
231
+
232
+ def transform1(
233
+ self,
234
+ simplextree,
235
+ ax,
236
+ # _reconversion_grid,
237
+ thread_id: str = "",
238
+ ):
239
+ # st = mp.SimplexTreeMulti(st, num_parameters=st.num_parameters) # COPY
240
+ if self.individual_grid:
241
+ filtration_grid = compute_grid(
242
+ simplextree, strategy=self.grid_strategy, resolution=self.resolution
243
+ )
244
+ mass_default = (
245
+ self._default_mass_location[ax] if self.enforce_null_mass else None
246
+ )
247
+ if self.enforce_null_mass:
248
+ filtration_grid = [
249
+ np.concatenate([f, [d]], axis=0)
250
+ for f, d in zip(filtration_grid, mass_default)
251
+ ]
252
+ _reconversion_grid = filtration_grid
253
+ else:
254
+ filtration_grid = self.filtration_grid[ax]
255
+ _reconversion_grid = self._reconversion_grid[ax]
256
+ mass_default = (
257
+ self._default_mass_location[ax] if self.enforce_null_mass else None
258
+ )
259
+
260
+ st = simplextree.grid_squeeze(filtration_grid=filtration_grid)
261
+ if st.num_parameters == 2 and mp.simplex_tree_multi.is_simplextree_multi(st):
262
+ st.collapse_edges(num=self.num_collapses, max_dimension=1)
263
+ int_degrees = np.asarray([d for d in self.degrees if d is not None], dtype=int)
264
+ # Euler characteristic first, as prune_above_dimension below would truncate the complex
265
+ if self.expand and None in self.degrees:
266
+ st.expansion(st.num_vertices)
267
+ signed_measures_euler = (
268
+ mp.signed_measure(
269
+ st,
270
+ degrees=[None],
271
+ plot=self.plot,
272
+ mass_default=mass_default,
273
+ invariant="euler",
274
+ # thread_id=thread_id,
275
+ backend=self.backend,
276
+ grid=_reconversion_grid,
277
+ )[0]
278
+ if None in self.degrees
279
+ else []
280
+ )
281
+
282
+ if self.expand and len(int_degrees) > 0:
283
+ st.expansion(np.max(int_degrees) + 1)
284
+ if len(int_degrees) > 0:
285
+ st.prune_above_dimension(
286
+ np.max(np.concatenate([int_degrees, self.rank_degrees])) + 1
287
+ ) # no need to compute homology beyond this
288
+ signed_measures_pers = (
289
+ mp.signed_measure(
290
+ st,
291
+ degrees=int_degrees,
292
+ mass_default=mass_default,
293
+ plot=self.plot,
294
+ invariant="hilbert",
295
+ thread_id=thread_id,
296
+ backend=self.backend,
297
+ grid=_reconversion_grid,
298
+ )
299
+ if len(int_degrees) > 0
300
+ else []
301
+ )
302
+ if self.plot:
303
+ plt.show()
304
+ if self.expand and len(self.rank_degrees) > 0:
305
+ st.expansion(np.max(self.rank_degrees) + 1)
306
+ if len(self.rank_degrees) > 0:
307
+ st.prune_above_dimension(
308
+ np.max(self.rank_degrees) + 1
309
+ ) # no need to compute homology beyond this
310
+ signed_measures_rank = (
311
+ mp.signed_measure(
312
+ st,
313
+ degrees=self.rank_degrees,
314
+ mass_default=mass_default,
315
+ plot=self.plot,
316
+ invariant="rank",
317
+ thread_id=thread_id,
318
+ backend=self.backend,
319
+ grid=_reconversion_grid,
320
+ )
321
+ if len(self.rank_degrees) > 0
322
+ else []
323
+ )
324
+ if self.plot:
325
+ plt.show()
326
+
327
+ count = 0
328
+ signed_measures = []
329
+ for d in self.degrees:
330
+ if d is None:
331
+ signed_measures.append(signed_measures_euler)
332
+ else:
333
+ signed_measures.append(signed_measures_pers[count])
334
+ count += 1
335
+ signed_measures += signed_measures_rank
336
+ return signed_measures
337
+
338
+ def transform(self, X):
339
+ ## X of shape (num_x, num_axis, filtered_complex)
340
+ assert (
341
+ self.filtration_grid is not None and self._reconversion_grid is not None
342
+ ) or self.individual_grid, "Fit first"
343
+
344
+ def todo_x(x):
345
+ return tuple(self.transform1(x_axis, j) for j, x_axis in enumerate(x))
346
+
347
+ ## out shape num_x, num_axis, degree, sm
348
+ out = tuple(
349
+ Parallel(n_jobs=self.n_jobs, backend="threading")(
350
+ delayed(todo_x)(x) for x in X
351
+ )
352
+ )
353
+ # out = Parallel(n_jobs=self.n_jobs, backend="threading")(
354
+ # delayed(self.transform1)(to_st, thread_id=str(thread_id))
355
+ # for thread_id, to_st in tqdm(
356
+ # enumerate(X),
357
+ # disable=not self.progress,
358
+ # desc="Computing signed measure decompositions",
359
+ # )
360
+ # )
361
+ return out
362
+
363
+
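# A minimal usage sketch of the transformer above (illustrative helper, not
# part of the module API). It assumes that SimplexTreeMulti exposes the
# gudhi-like insert(simplex, filtration) API; the input shape is
# (data) x (axis) x (filtered complex) and the output shape is
# (data) x (axis) x (degree) x (signed measure).
def _filtered_complex_to_sm_example():
    st = mp.SimplexTreeMulti(num_parameters=2)
    st.insert([0], [0.0, 1.0])
    st.insert([1], [1.0, 0.0])
    st.insert([0, 1], [1.0, 1.0])
    transformer = FilteredComplex2SignedMeasure(degrees=[0], grid_strategy="exact")
    sms = transformer.fit_transform([[st]])  # one datum, one axis
    points, weights = sms[0][0][0]           # datum 0, axis 0, degree 0
    return points, weights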
364
+ class SimplexTree2SignedMeasure(FilteredComplex2SignedMeasure):
365
+ def __init__(
366
+ self,
367
+ # homological degrees + None for euler
368
+ degrees: list[int | None] = [],
369
+ rank_degrees: list[int] = [], # same for rank invariant
370
+ filtration_grid: (
371
+ Sequence[Sequence[np.ndarray]]
372
+ # filtration values to consider. Format: [filtration values of parameter i, for each parameter i]
373
+ | None
374
+ ) = None,
375
+ progress=False, # tqdm
376
+ num_collapses: int | str = 0, # edge collapses before computing
377
+ n_jobs=None,
378
+ resolution: (
379
+ Iterable[int] | int | None
380
+ ) = None, # when filtration grid is not given, the resolution of the filtration grid to infer
381
+ # sparse=True, # sparse output # DEPRECATED: use SignedMeasureFormatter instead
382
+ plot: bool = False,
383
+ filtration_quantile: float = 0.0, # quantile for inferring filtration grid
384
+ # whether or not to do the Möbius inversion (not recommended to change)
385
+ # _möbius_inversion: bool = True,
386
+ expand=False, # expand the simplextree before computing the homology
387
+ normalize_filtrations: bool = False,
388
+ # exact_computation:bool=False, # compute the exact signed measure.
389
+ grid_strategy: str = "exact",
390
+ seed: int = 0, # if fit_fraction is not 1, the seed used for subsampling
391
+ fit_fraction=1, # the fraction of the data on which to fit
392
+ out_resolution: Iterable[int] | int | None = None,
393
+ individual_grid: Optional[
394
+ bool
395
+ ] = None, # Can be significantly faster for some grid strategies, but can drop statistical performance
396
+ enforce_null_mass: bool = False,
397
+ flatten=True,
398
+ backend: Optional[str] = None,
399
+ ):
400
+ stuff = locals()
401
+ stuff.pop("self")
402
+ keys = list(stuff.keys())
403
+ for key in keys:
404
+ if key.startswith("__"):
405
+ stuff.pop(key)
406
+ super().__init__(**stuff)
407
+ from warnings import warn
408
+
409
+ warn("This class is deprecated, use FilteredComplex2SignedMeasure instead.")
410
+
411
+
412
+ # class SimplexTrees2SignedMeasures(SimplexTree2SignedMeasure):
413
+ # """
414
+ # Input
415
+ # -----
416
+ #
417
+ # (data) x (axis, e.g. different bandwidths for simplextrees) x (simplextree)
418
+ #
419
+ # Output
420
+ # ------
421
+ # (data) x (axis) x (degree) x (signed measure)
422
+ # """
423
+ #
424
+ # def __init__(self, **kwargs):
425
+ # super().__init__(**kwargs)
426
+ # self._num_st_per_data = None
427
+ # # self._super_model=SimplexTree2SignedMeasure(**kwargs)
428
+ # self._filtration_grids = None
429
+ # return
430
+ #
431
+ # def fit(self, X, y=None):
432
+ # if len(X) == 0:
433
+ # return self
434
+ # try:
435
+ # self._num_st_per_data = len(X[0])
436
+ # except:
437
+ # raise Exception(
438
+ # "Shape has to be (num_data, num_axis), dtype=SimplexTreeMulti"
439
+ # )
440
+ # self._filtration_grids = []
441
+ # for axis in range(self._num_st_per_data):
442
+ # self._filtration_grids.append(
443
+ # super().fit([x[axis] for x in X]).filtration_grid
444
+ # )
445
+ # # self._super_fits.append(truc)
446
+ # # self._super_fits_params = [super().fit([x[axis] for x in X]).get_params() for axis in range(self._num_st_per_data)]
447
+ # return self
448
+ #
449
+ # def transform(self, X):
450
+ # if self.normalize_filtrations:
451
+ # _reconversion_grids = [
452
+ # [np.linspace(0, 1, num=len(f), dtype=float) for f in F]
453
+ # for F in self._filtration_grids
454
+ # ]
455
+ # else:
456
+ # _reconversion_grids = self._filtration_grids
457
+ #
458
+ # def todo(x):
459
+ # # return [SimplexTree2SignedMeasure().set_params(**transformer_params).transform1(x[axis]) for axis,transformer_params in enumerate(self._super_fits_params)]
460
+ # out = [
461
+ # self.transform1(
462
+ # x[axis],
463
+ # filtration_grid=filtration_grid,
464
+ # _reconversion_grid=_reconversion_grid,
465
+ # )
466
+ # for axis, filtration_grid, _reconversion_grid in zip(
467
+ # range(self._num_st_per_data),
468
+ # self._filtration_grids,
469
+ # _reconversion_grids,
470
+ # )
471
+ # ]
472
+ # return out
473
+ #
474
+ # return Parallel(n_jobs=self.n_jobs, backend="threading")(
475
+ # delayed(todo)(x)
476
+ # for x in tqdm(
477
+ # X,
478
+ # disable=not self.progress,
479
+ # desc="Computing Signed Measures from simplextrees.",
480
+ # )
481
+ # )
482
+
483
+
484
+ def rescale_sparse_signed_measure(
485
+ signed_measure, filtration_weights, normalize_scales=None
486
+ ):
487
+ # from copy import deepcopy
488
+ #
489
+ # out = deepcopy(signed_measure)
490
+
491
+ if filtration_weights is None and normalize_scales is None:
492
+ return signed_measure
493
+
494
+ # if normalize_scales is None:
495
+ # out = tuple(
496
+ # (
497
+ # _cat(
498
+ # tuple(
499
+ # signed_measure[degree][0][:, parameter]
500
+ # * filtration_weights[parameter]
501
+ # for parameter in range(num_parameters)
502
+ # ),
503
+ # axis=1,
504
+ # ),
505
+ # signed_measure[degree][1],
506
+ # )
507
+ # for degree in range(len(signed_measure))
508
+ # )
509
+ # for degree in range(len(signed_measure)): # degree
510
+ # for parameter in range(len(filtration_weights)):
511
+ # signed_measure[degree][0][:, parameter] *= filtration_weights[parameter]
512
+ # # TODO Broadcast w.r.t. the parameter
513
+ # out = tuple(
514
+ # _cat(
515
+ # tuple(
516
+ # signed_measure[degree][0][:, [parameter]]
517
+ # * filtration_weights[parameter]
518
+ # / (
519
+ # normalize_scales[degree][parameter]
520
+ # if normalize_scales is not None
521
+ # else 1
522
+ # )
523
+ # for parameter in range(num_parameters)
524
+ # ),
525
+ # axis=1,
526
+ # )
527
+ # for degree in range(len(signed_measure))
528
+ # )
529
+ out = tuple(
530
+ (
531
+ signed_measure[degree][0]
532
+ * (1 if filtration_weights is None else filtration_weights.reshape(1, -1))
533
+ / (
534
+ normalize_scales[degree].reshape(1, -1)
535
+ if normalize_scales is not None
536
+ else 1
537
+ ),
538
+ signed_measure[degree][1],
539
+ )
540
+ for degree in range(len(signed_measure))
541
+ )
542
+ # for degree in range(len(out)):
543
+ # for parameter in range(len(filtration_weights)):
544
+ # out[degree][0][:, parameter] *= (
545
+ # filtration_weights[parameter] / normalize_scales[degree][parameter]
546
+ # )
547
+ return out
548
+
549
+
550
+ def sm2deep(signed_measure):
551
+ dirac_positions, dirac_signs = signed_measure
552
+ dtype = dirac_positions.dtype
553
+ new_shape = list(dirac_positions.shape)
554
+ new_shape[1] += 1
555
+ if isinstance(dirac_positions, np.ndarray):
556
+ c = np.empty(new_shape, dtype=dtype)
557
+ c[:, :-1] = dirac_positions
558
+ c[:, -1] = dirac_signs
559
+
560
+ else:
561
+ import torch
562
+
563
+ c = torch.empty(new_shape, dtype=dtype)
564
+ c[:, :-1] = dirac_positions
565
+ if isinstance(dirac_signs, np.ndarray):
566
+ dirac_signs = torch.from_numpy(dirac_signs)
567
+ c[:, -1] = dirac_signs
568
+ return c
569
+
570
+
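# A small illustration of sm2deep (illustrative helper, not part of the
# module API): a sparse signed measure is a (points, weights) pair, and
# sm2deep appends the weights as an extra column, which is the "deep format"
# consumed by SignedMeasureFormatter below.
def _sm2deep_example():
    pts = np.array([[0.0, 1.0], [2.0, 3.0]])
    wts = np.array([1.0, -1.0])
    deep = sm2deep((pts, wts))  # shape (2, 3); the last column holds the signs
    return deep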
571
+ class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
572
+ """
573
+ Input
574
+ -----
575
+
576
+ (data) x (degree) x (signed measure) or (data) x (axis) x (degree) x (signed measure)
577
+
578
+ Iterable[list[signed_measure_matrix of degree]] or Iterable[previous].
579
+
580
+ The second form allows several signed-measure inputs per datum, e.g., measures coming from a Rips + density filtration computed with different bandwidths.
581
+ It is controlled by the axis parameter.
582
+
583
+ Output
584
+ ------
585
+
586
+ Iterable[list[(reweighted)_sparse_signed_measure of degree]]
587
+
588
+ or (deep format)
589
+
590
+ Tensor of shape (num_axis*num_degrees, data, max_num_pts, num_parameters)
591
+ """
592
+
593
+ def __init__(
594
+ self,
595
+ filtrations_weights: Optional[Iterable[float]] = None,
596
+ normalize=False,
597
+ plot: bool = False,
598
+ unsparse: bool = False,
599
+ axis: int = -1,
600
+ resolution: int | Iterable[int] = 50,
601
+ flatten: bool = False,
602
+ deep_format: bool = False,
603
+ unrag: bool = True,
604
+ n_jobs: int = 1,
605
+ verbose: bool = False,
606
+ integrate: bool = False,
607
+ grid_strategy="regular",
608
+ ):
609
+ super().__init__()
610
+ self.filtrations_weights = filtrations_weights
611
+ self.num_parameters: int = 0
612
+ self.plot = plot
613
+ self.unsparse = unsparse
614
+ self.n_jobs = n_jobs
615
+ self.axis = axis
616
+ self._num_axis = 0
617
+ self.resolution = resolution
618
+ self._filtrations_bounds = None
619
+ self.flatten = flatten
620
+ self.normalize = normalize
621
+ self._normalization_factors = None
622
+ self.deep_format = deep_format
623
+ self.unrag = unrag
624
+ assert (
625
+ sum((deep_format, unsparse, integrate)) <= 1
626
+ ), "Only one post-processing at a time."
627
+ self.verbose = verbose
628
+ self._num_degrees = 0
629
+ self.integrate = integrate
630
+ self.grid_strategy = grid_strategy
631
+ self._infered_grids = None
632
+ self._axis_iterator = None
633
+ self._backend = None
634
+ return
635
+
636
+ def _get_filtration_bounds(self, X, axis):
637
+ stuff = [
638
+ self._backend.cat(
639
+ [sm[axis][degree][0] for sm in X],
640
+ axis=0,
641
+ )
642
+ for degree in range(self._num_degrees)
643
+ ]
644
+ sizes_ = np.array([len(x) == 0 for x in stuff])
645
+ assert np.all(~sizes_), f"Degree axis {np.where(sizes_)} is/are trivial !"
646
+
647
+ filtrations_bounds = self._backend.asnumpy(
648
+ self._backend.stack(
649
+ [
650
+ self._backend.stack(
651
+ [
652
+ self._backend.minvalues(f, axis=0),
653
+ self._backend.maxvalues(f, axis=0),
654
+ ]
655
+ )
656
+ for f in stuff
657
+ ]
658
+ )
659
+ ) ## don't want to rescale gradient of normalization
660
+ normalization_factors = (
661
+ filtrations_bounds[:, 1] - filtrations_bounds[:, 0]
662
+ if self.normalize
663
+ else None
664
+ )
665
+ # print("Normalization factors : ",self._normalization_factors)
666
+ if (normalization_factors == 0).any():
667
+ indices = normalization_factors == 0
668
+ # warn(f"Constant filtration encountered, at degree, parameter {indices} and axis {self.axis}.")
669
+ normalization_factors[indices] = 1
670
+ return filtrations_bounds, normalization_factors
671
+
672
+ def _plot_signed_measures(self, sms: Iterable[np.ndarray], size=4):
673
+ from multipers.plots import plot_signed_measure
674
+
675
+ num_degrees = len(sms[0])
676
+ num_imgs = len(sms)
677
+ fig, axes = plt.subplots(
678
+ ncols=num_degrees,
679
+ nrows=num_imgs,
680
+ figsize=(size * num_degrees, size * num_imgs),
681
+ )
682
+ axes = np.asarray(axes).reshape(num_imgs, num_degrees)
683
+ # assert axes.ndim==2, "Internal error"
684
+ for i, sm in enumerate(sms):
685
+ for j, sm_of_degree in enumerate(sm):
686
+ plot_signed_measure(sm_of_degree, ax=axes[i, j])
687
+
688
+ @staticmethod
689
+ def _check_sm(sm) -> bool:
690
+ return (
691
+ isinstance(sm, tuple)
692
+ and hasattr(sm[0], "ndim")
693
+ and sm[0].ndim == 2
694
+ and len(sm) == 2
695
+ )
696
+
697
+ def _check_axis(self, X):
698
+ # axes should be (num_data, num_axis, num_degrees, (signed_measure))
699
+ if len(X) == 0:
700
+ return
701
+ if len(X[0]) == 0:
702
+ return
703
+ if self._check_sm(X[0][0]):
704
+ self._has_axis = False
705
+ self._num_axis = 1
706
+ self._axis_iterator = [slice(None)]
707
+ return
708
+ assert self._check_sm( ## vaguely checks that its a signed measure
709
+ _sm := X[0][0][0]
710
+ ), f"Cannot take this input. # data, axis, degrees, sm.\n Got {_sm} of type {type(_sm)}"
711
+
712
+ self._has_axis = True
713
+ self._num_axis = len(X[0])
714
+ self._axis_iterator = range(self._num_axis) if self.axis == -1 else [self.axis]
715
+
716
+ def _check_backend(self, X):
717
+ if self._has_axis:
718
+ # data, axis, degrees, (pts, weights)
719
+ first_sm = X[0][0][0][0]
720
+ else:
721
+ first_sm = X[0][0][0]
722
+ self._backend = api_from_tensor(first_sm, verbose=self.verbose)
723
+
724
+ def _check_measures(self, X):
725
+ if self._has_axis:
726
+ first_sm = X[0][0]
727
+ else:
728
+ first_sm = X[0]
729
+ self._num_degrees = len(first_sm)
730
+ self.num_parameters = first_sm[0][0].shape[1]
731
+
732
+ def _check_resolution(self):
733
+ assert self.num_parameters > 0, "Num parameters hasn't been initialized."
734
+ if isinstance(self.resolution, int):
735
+ self.resolution = [self.resolution] * self.num_parameters
736
+ self.resolution = np.asarray(self.resolution, dtype=int)
737
+ assert (
738
+ self.resolution.shape[0] == self.num_parameters
739
+ ), "Resolution doesn't have a proper size."
740
+
741
+ def _check_weights(self):
742
+ if self.filtrations_weights is None:
743
+ return
744
+ assert (
745
+ self.filtrations_weights.shape[0] == self.num_parameters
746
+ ), "Filtration weights don't have a proper size"
747
+
748
+ def _infer_grids(self, X):
749
+ # Computes normalization factors
750
+ if self.normalize:
751
+ # if self._has_axis and self.axis == -1:
752
+ self._filtrations_bounds = []
753
+ self._normalization_factors = []
754
+ for ax in self._axis_iterator:
755
+ (
756
+ filtration_bounds,
757
+ normalization_factors,
758
+ ) = self._get_filtration_bounds(X, axis=ax)
759
+ self._filtrations_bounds.append(filtration_bounds)
760
+ self._normalization_factors.append(normalization_factors)
761
+ self._filtrations_bounds = self._backend.astensor(self._filtrations_bounds)
762
+ self._normalization_factors = self._backend.astensor(self._normalization_factors)
763
+ # else:
764
+ # (
765
+ # self._filtrations_bounds,
766
+ # self._normalization_factors,
767
+ # ) = self._get_filtration_bounds(
768
+ # X, axis=self._axis_iterator[0]
769
+ # ) ## axis = slice(None)
770
+ elif self.integrate or self.unsparse or self.deep_format:
771
+ filtration_values = [
772
+ np.concatenate(
773
+ [
774
+ (
775
+ stuff
776
+ if isinstance(stuff := x[ax][degree][0], np.ndarray)
777
+ else stuff.detach().numpy()
778
+ )
779
+ for x in X
780
+ for degree in range(self._num_degrees)
781
+ ]
782
+ )
783
+ for ax in self._axis_iterator
784
+ ]
785
+ # axis, filtration_values
786
+ filtration_values = [
787
+ self._backend.astensor(compute_grid(
788
+ f_ax.T, resolution=self.resolution, strategy=self.grid_strategy
789
+ ))
790
+ for f_ax in filtration_values
791
+ ]
792
+ self._infered_grids = filtration_values
793
+
794
+ def _print_stats(self, X):
795
+ print("------------SignedMeasureFormatter------------")
796
+ print("---- Parameters")
797
+ print(f"Number of axis : {self._num_axis}")
798
+ print(f"Number of degrees : {self._num_degrees}")
799
+ print(f"Filtration bounds : \n{self._filtrations_bounds}")
800
+ print(f"Normalization factor : \n{self._normalization_factors}")
801
+ if self._infered_grids is not None:
802
+ print(
803
+ f"Filtration grid shape : \n \
804
+ {tuple(tuple(len(f) for f in F) for F in self._infered_grids)}"
805
+ )
806
+ print("---- SM stats")
807
+ print("In axis :", self._num_axis)
808
+ sizes = [
809
+ [[len(xd[1]) for xd in x[ax]] for x in X] for ax in self._axis_iterator
810
+ ]
811
+ print(f"Size means (axis) x (degree): {np.mean(sizes, axis=(1))}")
812
+ print(f"Size std : {np.std(sizes, axis=(1))}")
813
+ print("----------------------------------------------")
814
+
815
+ def fit(self, X, y=None):
816
+ # Gets a grid. This will be the max in each coord+1
817
+ if (
818
+ len(X) == 0
819
+ or len(X[0]) == 0
820
+ or (self.axis is not None and len(X[0][0][0]) == 0)
821
+ ):
822
+ return self
823
+
824
+ self._check_axis(X)
825
+ self._check_backend(X)
826
+ self._check_measures(X)
827
+ self._check_resolution()
828
+ self._check_weights()
829
+ # if not sparse : not recommended.
830
+
831
+ self._infer_grids(X)
832
+ if self.verbose:
833
+ self._print_stats(X)
834
+ return self
835
+
836
+ def unsparse_signed_measure(self, sparse_signed_measure):
837
+ filtrations = self._infered_grids # ax, filtration
838
+ out = []
839
+ for filtrations_of_ax, ax in zip(filtrations, self._axis_iterator, strict=True):
840
+ sparse_signed_measure_of_ax = sparse_signed_measure[ax]
841
+ measure_of_ax = []
842
+ for pts, weights in sparse_signed_measure_of_ax: # over degree
843
+ signed_measure, _ = np.histogramdd(
844
+ pts, bins=filtrations_of_ax, weights=weights
845
+ )
846
+ if self.flatten:
847
+ signed_measure = signed_measure.flatten()
848
+ measure_of_ax.append(signed_measure)
849
+ out.append(np.asarray(measure_of_ax))
850
+
851
+ if self.flatten:
852
+ return np.concatenate(out).flatten()
853
+ elif self.axis == -1:
854
+ return np.asarray(out)
855
+ else:
856
+ return np.asarray(out)[0]
857
+
858
+ @staticmethod
859
+ def _integrate_measure(sm, filtrations):
860
+ from multipers.point_measure import integrate_measure
861
+
862
+ return integrate_measure(sm[0], sm[1], filtrations)
863
+
864
+ def _rescale_measures(self, X):
865
+ def rescale_from_sparse(sparse_signed_measure):
866
+ if self.axis == -1 and self._has_axis:
867
+ return tuple(
868
+ rescale_sparse_signed_measure(
869
+ sparse_signed_measure[ax],
870
+ filtration_weights=self.filtrations_weights,
871
+ normalize_scales=n,
872
+ )
873
+ for ax, n in zip(
874
+ self._axis_iterator, self._normalization_factors, strict=True
875
+ )
876
+ )
877
+ return rescale_sparse_signed_measure( ## axis iterator is of size 1 here
878
+ sparse_signed_measure,
879
+ filtration_weights=self.filtrations_weights,
880
+ normalize_scales=self._normalization_factors[0],
881
+ )
882
+
883
+ out = tuple(rescale_from_sparse(x) for x in X)
884
+ return out
885
+
886
+ def transform(self, X):
887
+ if not self._has_axis or self.axis == -1:
888
+ out = X
889
+ else:
890
+ out = tuple(x[self.axis] for x in X)
891
+ # same format for everyone
892
+
893
+ if self._normalization_factors is not None:
894
+ out = self._rescale_measures(out)
895
+
896
+ if self.plot:
897
+ # assert ax != -1, "Not implemented"
898
+ self._plot_signed_measures(out)
899
+ if self.integrate:
900
+ filtrations = self._infered_grids
901
+ # if self.axis != -1:
902
+ ax = 0 # if self.axis is None else self.axis # TODO deal with axis -1
903
+
904
+ assert ax != -1, "Not implemented. Can only integrate with axis"
905
+ # try:
906
+ out = np.asarray(
907
+ [
908
+ [
909
+ self._integrate_measure(x[degree], filtrations=filtrations[ax])
910
+ for degree in range(self._num_degrees)
911
+ ]
912
+ for x in out
913
+ ]
914
+ )
915
+ # except:
916
+ # print(self.axis, ax, filtrations)
917
+ if self.flatten:
918
+ out = out.reshape((len(X), -1))
919
+ # else:
920
+ # out = [[[self._integrate_measure(x[axis][degree],filtrations=filtrations[degree].T) for degree in range(self._num_degrees)] for axis in range(self._num_axis)] for x in out]
921
+ elif self.unsparse:
922
+ out = [self.unsparse_signed_measure(x) for x in out]
923
+ elif self.deep_format:
924
+ num_degrees = self._num_degrees
925
+ out = tuple(
926
+ tuple(sm2deep(sm[axis][degree]) for sm in out)
927
+ for degree in range(num_degrees)
928
+ for axis in self._axis_iterator
929
+ )
930
+ if self.unrag:
931
+ max_num_pts = np.max(
932
+ [sm.shape[0] for sm_of_axis in out for sm in sm_of_axis]
933
+ )
934
+ num_axis_degree = len(out)
935
+ num_data = len(out[0])
936
+ assert num_axis_degree == num_degrees * (
937
+ self._num_axis if self._has_axis else 1
938
+ ), f"Bad axis/degree count. Got {num_axis_degree} (Internal error)"
939
+ num_parameters = out[0][0].shape[1]
940
+ dtype = out[0][0].dtype
941
+ unragged_tensor = self._backend.zeros(
942
+ (
943
+ num_axis_degree,
944
+ num_data,
945
+ max_num_pts,
946
+ num_parameters,
947
+ ),
948
+ dtype=dtype,
949
+ )
950
+ for ax in range(num_axis_degree):
951
+ for data in range(num_data):
952
+ sm = out[ax][data]
953
+ a, b = sm.shape
954
+ unragged_tensor[ax, data, :a, :b] = sm
955
+ out = unragged_tensor
956
+ return out
957
+
958
+
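# A usage sketch of the formatter above (illustrative helper, not part of the
# module API), on synthetic sparse signed measures in the
# (data) x (degree) x (signed measure) layout (no axis dimension);
# normalize=True rescales each filtration parameter to a unit range over the
# fitted data.
def _formatter_example():
    sm_a = (np.array([[0.0, 0.0], [1.0, 2.0]]), np.array([1, -1]))
    sm_b = (np.array([[0.5, 1.0], [2.0, 3.0]]), np.array([1, -1]))
    X = [[sm_a], [sm_b]]  # two data points, one degree each
    formatter = SignedMeasureFormatter(normalize=True)
    return formatter.fit_transform(X)  # same sparse layout, rescaled points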
959
+ class SignedMeasure2Convolution(BaseEstimator, TransformerMixin):
960
+ """
961
+ Discrete convolution of a signed measure
962
+
963
+ Input
964
+ -----
965
+
966
+ (data) x (degree) x (signed measure)
967
+
968
+ Parameters
969
+ ----------
970
+ - filtration_grid : Iterable[array], for each parameter, the filtration values on which to evaluate the convolution
971
+ - resolution : int or (num_parameters,) : if filtration_grid is not given, a grid is inferred with this resolution
972
+ - grid_strategy : the strategy used to generate the grid. Available ones are regular, quantile, exact
973
+ - kernel : kernel used to convolve the measures
974
+ - bandwidth : kernel bandwidth; a negative value is interpreted as a fraction of the grid diameter (sparse input) or of the input resolution (dense input)
975
+ - flatten : if True, the output images are flattened
976
+ - progress : progress bar if True
977
+ - backend : sklearn, pykeops or numba.
978
+ - plot : Creates a plot Figure.
979
+
980
+ Output
981
+ ------
982
+
983
+ (data) x (concatenation of imgs of degree)
984
+ """
985
+
986
+ def __init__(
987
+ self,
988
+ filtration_grid: Iterable[np.ndarray] = None,
989
+ kernel: available_kernels = "gaussian",
990
+ bandwidth: float | Iterable[float] = 1.0,
991
+ flatten: bool = False,
992
+ n_jobs: int = 1,
993
+ resolution: int | None = None,
994
+ grid_strategy: str = "regular",
995
+ progress: bool = False,
996
+ backend: str = "pykeops",
997
+ plot: bool = False,
998
+ log_density: bool = False,
999
+ **kde_kwargs,
1000
+ # **kwargs ## DANGEROUS
1001
+ ):
1002
+ super().__init__()
1003
+ self.kernel: available_kernels = kernel
1004
+ self.bandwidth = bandwidth
1005
+ # self.more_kde_kwargs=kwargs
1006
+ self.filtration_grid = filtration_grid
1007
+ self.flatten = flatten
1008
+ self.progress = progress
1009
+ self.n_jobs = n_jobs
1010
+ self.resolution = resolution
1011
+ self.grid_strategy = grid_strategy
1012
+ self._is_input_sparse = None
1013
+ self._refit = filtration_grid is None
1014
+ self._input_resolution = None
1015
+ self._bandwidths = None
1016
+ self.diameter = None
1017
+ self.backend = backend
1018
+ self.plot = plot
1019
+ self.log_density = log_density
1020
+ self.kde_kwargs = kde_kwargs
1021
+ self._api = None
1022
+ return
1023
+
1024
+ def fit(self, X, y=None):
1025
+ # Infers if the input is sparse given X
1026
+ if len(X) == 0:
1027
+ return self
1028
+ if isinstance(X[0][0], tuple):
1029
+ self._is_input_sparse = True
1030
+
1031
+ self._api = api_from_tensor(X[0][0][0], verbose=self.progress)
1032
+ else:
1033
+ self._is_input_sparse = False
1034
+
1035
+ self._api = api_from_tensor(X, verbose=self.progress)
1036
+ # print(f"IMG output is set to {'sparse' if self.sparse else 'matrix'}")
1037
+ if not self._is_input_sparse:
1038
+ self._input_resolution = X[0][0].shape
1039
+ try:
1040
+ b = float(self.bandwidth)
1041
+ self._bandwidths = [
1042
+ b if b > 0 else -b * s for s in self._input_resolution
1043
+ ]
1044
+ except (TypeError, ValueError):
1045
+ self._bandwidths = [
1046
+ b if b > 0 else -b * s
1047
+ for s, b in zip(self._input_resolution, self.bandwidth)
1048
+ ]
1049
+ return self # in that case, signed measures are matrices, and the grid is already given
1050
+
1051
+ if self.filtration_grid is None and self.resolution is None:
1052
+ raise Exception(
1053
+ "Cannot infer filtration grid. Provide either a filtration grid or a resolution."
1054
+ )
1055
+ # If not sparse : a grid has to be defined
1056
+ if self._refit:
1057
+ # print("Fitting a grid...", end="")
1058
+ pts = self._api.cat(
1059
+ [sm[0] for signed_measures in X for sm in signed_measures]
1060
+ ).T
1061
+ self.filtration_grid = compute_grid(
1062
+ pts,
1063
+ strategy=self.grid_strategy,
1064
+ resolution=self.resolution,
1065
+ )
1066
+ # print('Done.')
1067
+ if self.filtration_grid is not None:
1068
+ self.diameter = self._api.norm(
1069
+ self._api.astensor([f[-1] - f[0] for f in self.filtration_grid])
1070
+ )
1071
+ if self.progress:
1072
+ print(f"Computed a diameter of {self.diameter}")
1073
+ return self
1074
+
1075
+ def _sm2smi(self, signed_measures):
1076
+ # print(self._input_resolution, self.bandwidths, _bandwidths)
1077
+ from scipy.ndimage import gaussian_filter
1078
+
1079
+ return np.concatenate(
1080
+ [
1081
+ gaussian_filter(
1082
+ input=signed_measure,
1083
+ sigma=self._bandwidths,
1084
+ mode="constant",
1085
+ cval=0,
1086
+ )
1087
+ for signed_measure in signed_measures
1088
+ ],
1089
+ axis=0,
1090
+ )
1091
+
1092
+ def _transform_from_sparse(self, X):
1093
+ bandwidth = (
1094
+ self.bandwidth if self.bandwidth > 0 else -self.bandwidth * self.diameter
1095
+ )
1096
+ # COMPILE KEOPS FIRST
1097
+ dummyx = [X[0]]
1098
+ dummyf = [f[:2] for f in self.filtration_grid]
1099
+ convolution_signed_measures(
1100
+ dummyx,
1101
+ filtrations=dummyf,
1102
+ bandwidth=bandwidth,
1103
+ flatten=self.flatten,
1104
+ n_jobs=1,
1105
+ kernel=self.kernel,
1106
+ backend=self.backend,
1107
+ )
1108
+
1109
+ return convolution_signed_measures(
1110
+ X,
1111
+ filtrations=self.filtration_grid,
1112
+ bandwidth=bandwidth,
1113
+ flatten=self.flatten,
1114
+ n_jobs=self.n_jobs,
1115
+ kernel=self.kernel,
1116
+ backend=self.backend,
1117
+ **self.kde_kwargs,
1118
+ )
1119
+
1120
+ def _plot_imgs(self, imgs: Iterable[np.ndarray], size=4):
1121
+ from multipers.plots import plot_surface
1122
+
1123
+ imgs = self._api.asnumpy(imgs)
1124
+ num_degrees = imgs[0].shape[0]
1125
+ num_imgs = len(imgs)
1126
+ fig, axes = plt.subplots(
1127
+ ncols=num_degrees,
1128
+ nrows=num_imgs,
1129
+ figsize=(size * num_degrees, size * num_imgs),
1130
+ )
1131
+ axes = np.asarray(axes).reshape(num_imgs, num_degrees)
1132
+ # assert axes.ndim==2, "Internal error"
1133
+ for i, img in enumerate(imgs):
1134
+ for j, img_of_degree in enumerate(img):
1135
+ plot_surface(
1136
+ [self._api.asnumpy(f) for f in self.filtration_grid],
1137
+ img_of_degree,
1138
+ ax=axes[i, j],
1139
+ cmap="Spectral",
1140
+ )
1141
+
1142
+ def transform(self, X):
1143
+ if self._is_input_sparse is None:
1144
+ raise Exception("Fit first")
1145
+ if self._is_input_sparse:
1146
+ out = self._transform_from_sparse(X)
1147
+ else:
1148
+ todo = SignedMeasure2Convolution._sm2smi
1149
+ out = Parallel(n_jobs=self.n_jobs, backend="threading")(
1150
+ delayed(todo)(self, signed_measures)
1151
+ for signed_measures in tqdm(
1152
+ X, desc="Computing images", disable=not self.progress
1153
+ )
1154
+ )
1155
+ out = self._api.cat([x[None] for x in out])
1156
+ if self.plot and not self.flatten:
1157
+ if self.progress:
1158
+ print("Plotting convolutions...", end="")
1159
+ self._plot_imgs(out)
1160
+ if self.progress:
1161
+ print("Done !")
1162
+ if self.flatten and not self._is_input_sparse:
1163
+ out = self._api.cat([x.ravel()[None] for x in out])
1164
+ return out
1165
+
1166
+
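# A standalone numpy/scipy sketch (illustrative helper, not part of the
# module API) of the discrete convolution pipeline: bin a sparse signed
# measure on a grid (as SignedMeasureFormatter does with np.histogramdd) and
# blur the resulting matrix with a Gaussian kernel (as the dense branch above
# does with scipy.ndimage.gaussian_filter).
def _discrete_convolution_example():
    from scipy.ndimage import gaussian_filter

    points = np.array([[0.1, 0.2], [0.7, 0.8]])
    weights = np.array([1.0, -1.0])
    edges = [np.linspace(0, 1, 11), np.linspace(0, 1, 11)]  # 10 x 10 bins
    img, _ = np.histogramdd(points, bins=edges, weights=weights)
    return gaussian_filter(img, sigma=1.0, mode="constant", cval=0.0)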
1167
+ class SignedMeasure2SlicedWassersteinDistance(BaseEstimator, TransformerMixin):
1168
+ """
1169
+ Transformer from signed measure to distance matrix.
1170
+
1171
+ Input
1172
+ -----
1173
+
1174
+ (data) x (degree) x (signed measure)
1175
+
1176
+ Format
1177
+ ------
1178
+ - a signed measure : a tuple of arrays, (point positions) : npts x (num_parameters) and (weights) : npts
1179
+ - each data is a list of signed measure (for e.g. multiple degrees)
1180
+
1181
+ Output
1182
+ ------
1183
+ - (degree) x (distance matrix)
1184
+ """
1185
+
1186
+ def __init__(
1187
+ self,
1188
+ n_jobs=None,
1189
+ num_directions: int = 10,
1190
+ _sliced: bool = True,
1191
+ epsilon=-1,
1192
+ ground_norm=1,
1193
+ progress=False,
1194
+ grid_reconversion=None,
1195
+ scales=None,
1196
+ ):
1197
+ super().__init__()
1198
+ self.n_jobs = n_jobs
1199
+ self._SWD_list = None
1200
+ self._sliced = _sliced
1201
+ self.epsilon = epsilon
1202
+ self.ground_norm = ground_norm
1203
+ self.num_directions = num_directions
1204
+ self.progress = progress
1205
+ self.grid_reconversion = grid_reconversion
1206
+ self.scales = scales
1207
+ return
1208
+
1209
+ def fit(self, X, y=None):
1210
+ from multipers.ml.sliced_wasserstein import (
1211
+ SlicedWassersteinDistance,
1212
+ WassersteinDistance,
1213
+ )
1214
+
1215
+ # _DISTANCE = lambda : SlicedWassersteinDistance(num_directions=self.num_directions) if self._sliced else WassersteinDistance(epsilon=self.epsilon, ground_norm=self.ground_norm) # WARNING if _sliced is false, this distance is not CNSD
1216
+ if len(X) == 0:
1217
+ return self
1218
+ num_degrees = len(X[0])
1219
+ self._SWD_list = [
1220
+ (
1221
+ SlicedWassersteinDistance(
1222
+ num_directions=self.num_directions,
1223
+ n_jobs=self.n_jobs,
1224
+ scales=self.scales,
1225
+ )
1226
+ if self._sliced
1227
+ else WassersteinDistance(
1228
+ epsilon=self.epsilon,
1229
+ ground_norm=self.ground_norm,
1230
+ n_jobs=self.n_jobs,
1231
+ )
1232
+ )
1233
+ for _ in range(num_degrees)
1234
+ ]
1235
+ for degree, swd in enumerate(self._SWD_list):
1236
+ signed_measures_of_degree = [x[degree] for x in X]
1237
+ swd.fit(signed_measures_of_degree)
1238
+ return self
1239
+
1240
+ def transform(self, X):
1241
+ assert self._SWD_list is not None, "Fit first"
1242
+ # out = []
1243
+ # for degree, swd in tqdm(enumerate(self._SWD_list), desc="Computing distance matrices", total=len(self._SWD_list), disable= not self.progress):
1244
+ with tqdm(
1245
+ enumerate(self._SWD_list),
1246
+ desc="Computing distance matrices",
1247
+ total=len(self._SWD_list),
1248
+ disable=not self.progress,
1249
+ ) as SWD_it:
1250
+ # signed_measures_of_degree = [x[degree] for x in X]
1251
+ # out.append(swd.transform(signed_measures_of_degree))
1252
+ def todo(swd, X_of_degree):
1253
+ return swd.transform(X_of_degree)
1254
+
1255
+ out = Parallel(n_jobs=self.n_jobs, prefer="threads")(
1256
+ delayed(todo)(swd, [x[degree] for x in X]) for degree, swd in SWD_it
1257
+ )
1258
+ return np.asarray(out)
1259
+
1260
+ def predict(self, X):
1261
+ return self.transform(X)
1262
+
1263
+
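# A usage sketch (illustrative helper, not part of the module API): with the
# (data) x (degree) x (signed measure) layout, the transformer above returns
# one n x n distance matrix per degree; it relies on
# multipers.ml.sliced_wasserstein being available.
def _swd_example(signed_measures):
    swd = SignedMeasure2SlicedWassersteinDistance(num_directions=10)
    return swd.fit(signed_measures).transform(signed_measures)  # (degree, n, n)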
1264
+ class SignedMeasures2SlicedWassersteinDistances(BaseEstimator, TransformerMixin):
1265
+ """
1266
+ Transformer from signed measure to distance matrix.
1267
+ Input
1268
+ -----
1269
+ (data) x opt (axis) x (degree) x (signed measure)
1270
+
1271
+ Format
1272
+ ------
1273
+ - a signed measure : a tuple of arrays, (point positions) : npts x (num_parameters) and (weights) : npts
1274
+ - each data is a list of signed measure (for e.g. multiple degrees)
1275
+
1276
+ Output
1277
+ ------
1278
+ - (axis) x (degree) x (distance matrix)
1279
+ """
1280
+
1281
+ def __init__(
1282
+ self,
1283
+ progress=False,
1284
+ n_jobs: int = 1,
1285
+ scales: Iterable[Iterable[float]] | None = None,
1286
+ **kwargs,
1287
+ ): # same init
1288
+ self._init_child = SignedMeasure2SlicedWassersteinDistance(
1289
+ progress=False, scales=None, n_jobs=-1, **kwargs
1290
+ )
1291
+ self._axe_iterator = None
1292
+ self._childs_to_fit = None
1293
+ self.scales = scales
1294
+ self.progress = progress
1295
+ self.n_jobs = n_jobs
1296
+ return
1297
+
1298
+ def fit(self, X, y=None):
1299
+ from sklearn.base import clone
1300
+
1301
+ if len(X) == 0:
1302
+ return self
1303
+ if isinstance(X[0][0], tuple): # Meaning that there are no axes
1304
+ self._axe_iterator = [slice(None)]
1305
+ else:
1306
+ self._axe_iterator = range(len(X[0]))
1307
+ if self.scales is None:
1308
+ self.scales = [None]
1309
+ else:
1310
+ self.scales = np.asarray(self.scales)
1311
+ if self.scales.ndim == 1:
1312
+ self.scales = np.asarray([self.scales])
1313
+ assert (
1314
+ self.scales[0] is None or self.scales.ndim == 2
1315
+ ), "Scales have to be either None or a list of scales !"
1316
+ self._childs_to_fit = [
1317
+ clone(self._init_child).set_params(scales=scales).fit([x[axis] for x in X])
1318
+ for axis, scales in product(self._axe_iterator, self.scales)
1319
+ ]
1320
+ print("New axes : ", list(product(self._axe_iterator, self.scales)))
1321
+ return self
1322
+
1323
+ def transform(self, X):
1324
+ return Parallel(n_jobs=self.n_jobs, prefer="processes")(
1325
+ delayed(self._childs_to_fit[child_id].transform)([x[axis] for x in X])
1326
+ for child_id, (axis, _) in tqdm(
1327
+ enumerate(product(self._axe_iterator, self.scales)),
1328
+ desc=f"Computing distances matrices of axis, and scales",
1329
+ disable=not self.progress,
1330
+ total=len(self._childs_to_fit),
1331
+ )
1332
+ )
1333
+ # [
1334
+ # child.transform([x[axis // len(self.scales)] for x in X])
1335
+ # for axis, child in tqdm(enumerate(self._childs_to_fit),
1336
+ # desc=f"Computing distances of axis", disable=not self.progress, total=len(self._childs_to_fit)
1337
+ # )
1338
+ # ]
1339
+
1340
+
1341
+ class SimplexTree2RectangleDecomposition(BaseEstimator, TransformerMixin):
1342
+ """
1343
+ Transformer. 2 parameter SimplexTrees to their respective rectangle decomposition.
1344
+ """
1345
+
1346
+ def __init__(
1347
+ self,
1348
+ filtration_grid: np.ndarray,
1349
+ degrees: Iterable[int],
1350
+ plot=False,
1351
+ reconvert_grid=True,
1352
+ num_collapses: int = 0,
1353
+ ):
1354
+ super().__init__()
1355
+ self.filtration_grid = filtration_grid
1356
+ self.degrees = degrees
1357
+ self.plot = plot
1358
+ self.reconvert_grid = reconvert_grid
1359
+ self.num_collapses = num_collapses
1360
+ return
1361
+
1362
+ def fit(self, X, y=None):
1363
+ """
1364
+ TODO : infer grid from multiple simplextrees
1365
+ """
1366
+ return self
1367
+
1368
+ def transform(self, X: Iterable[mp.simplex_tree_multi.SimplexTreeMulti_type]):
1369
+ rectangle_decompositions = [
1370
+ [
1371
+ _st2ranktensor(
1372
+ simplextree,
1373
+ filtration_grid=self.filtration_grid,
1374
+ degree=degree,
1375
+ plot=self.plot,
1376
+ reconvert_grid=self.reconvert_grid,
1377
+ num_collapse=self.num_collapses,
1378
+ )
1379
+ for degree in self.degrees
1380
+ ]
1381
+ for simplextree in X
1382
+ ]
1383
+ # TODO : return iterator ?
1384
+ return rectangle_decompositions
1385
+
1386
+
1387
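+# Illustrative usage sketch (an assumption, not part of the packaged source); `simplextrees`
+# stands for an iterable of 2-parameter SimplexTreeMulti objects and the grid is hypothetical.
+#
+#     grid = [np.linspace(0, 1, 50), np.linspace(0, 1, 50)]
+#     decomposer = SimplexTree2RectangleDecomposition(filtration_grid=grid, degrees=[0, 1])
+#     decompositions = decomposer.fit_transform(simplextrees)
+#     # decompositions[i][j]: rectangle decomposition of the j-th degree of the i-th simplextree
+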
+def _st2ranktensor(
+    st: mp.simplex_tree_multi.SimplexTreeMulti_type,
+    filtration_grid: np.ndarray,
+    degree: int,
+    plot: bool,
+    reconvert_grid: bool,
+    num_collapse: int | str = 0,
+):
+    """
+    TODO
+    """
+    # Copy (the squeeze changes the filtration values)
+    # stcpy = mp.SimplexTreeMulti(st)
+    # turns the simplextree into a coordinate simplex tree
+    stcpy = st.grid_squeeze(filtration_grid=filtration_grid, coordinate_values=True)
+    # stcpy.collapse_edges(num=100, strong = True, ignore_warning=True)
+    if num_collapse == "full":
+        stcpy.collapse_edges(full=True, ignore_warning=True, max_dimension=degree + 1)
+    elif isinstance(num_collapse, int):
+        stcpy.collapse_edges(
+            num=num_collapse, ignore_warning=True, max_dimension=degree + 1
+        )
+    else:
+        raise TypeError(
+            f"Invalid num_collapse={num_collapse}. Expected 'full' or an integer."
+        )
+    # computes the rank invariant tensor
+    rank_tensor = mp.rank_invariant2d(
+        stcpy, degree=degree, grid_shape=[len(f) for f in filtration_grid]
+    )
+    # refactor this tensor into the rectangle decomposition of the signed betti
+    grid_conversion = filtration_grid if reconvert_grid else None
+    rank_decomposition = rank_decomposition_by_rectangles(
+        rank_tensor,
+        threshold=True,
+    )
+    rectangle_decomposition = tensor_möbius_inversion(
+        tensor=rank_decomposition,
+        grid_conversion=grid_conversion,
+        plot=plot,
+        num_parameters=st.num_parameters,
+    )
+    return rectangle_decomposition
+
+
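+# The helper above chains: grid_squeeze (coordinate filtration) -> edge collapses ->
+# mp.rank_invariant2d -> rank_decomposition_by_rectangles -> tensor_möbius_inversion, turning a
+# 2-parameter simplextree into a signed rectangle measure. A direct call could look like the
+# hypothetical sketch below (`st` and `grid` are assumptions):
+#
+#     rectangles, weights = _st2ranktensor(
+#         st,
+#         filtration_grid=grid,
+#         degree=1,
+#         plot=False,
+#         reconvert_grid=True,
+#         num_collapse="full",
+#     )
+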
+class DegreeRips2SignedMeasure(BaseEstimator, TransformerMixin):
+    def __init__(
+        self,
+        degrees: Iterable[int],
+        min_rips_value: float,
+        max_rips_value,
+        max_normalized_degree: float,
+        min_normalized_degree: float,
+        grid_granularity: int,
+        progress: bool = False,
+        n_jobs=1,
+        sparse: bool = False,
+        _möbius_inversion=True,
+        fit_fraction=1,
+    ) -> None:
+        super().__init__()
+        self.min_rips_value = min_rips_value
+        self.max_rips_value = max_rips_value
+        self.min_normalized_degree = min_normalized_degree
+        self.max_normalized_degree = max_normalized_degree
+        self._max_rips_value = None
+        self.grid_granularity = grid_granularity
+        self.progress = progress
+        self.n_jobs = n_jobs
+        self.degrees = degrees
+        self.sparse = sparse
+        self._möbius_inversion = _möbius_inversion
+        self.fit_fraction = fit_fraction
+        return
+
+    def fit(self, X: np.ndarray | list, y=None):
+        if self.max_rips_value < 0:
+            print("Estimating scale...", flush=True, end="")
+            indices = np.random.choice(
+                len(X), min(len(X), int(self.fit_fraction * len(X)) + 1), replace=False
+            )
+            diameters = np.max(
+                [distance_matrix(x, x).max() for x in (X[i] for i in indices)]
+            )
+            print(f"Done. {diameters}", flush=True)
+        self._max_rips_value = (
+            -self.max_rips_value * diameters
+            if self.max_rips_value < 0
+            else self.max_rips_value
+        )
+        return self
+
+    def _transform1(self, data: np.ndarray):
+        _distance_matrix = distance_matrix(data, data)
+        signed_measures = []
+        (
+            rips_values,
+            normalized_degree_values,
+            hilbert_functions,
+            minimal_presentations,
+        ) = hf_degree_rips(
+            _distance_matrix,
+            min_rips_value=self.min_rips_value,
+            max_rips_value=self._max_rips_value,
+            min_normalized_degree=self.min_normalized_degree,
+            max_normalized_degree=self.max_normalized_degree,
+            grid_granularity=self.grid_granularity,
+            max_homological_dimension=np.max(self.degrees),
+        )
+        for degree in self.degrees:
+            hilbert_function = hilbert_functions[degree]
+            signed_measure = (
+                signed_betti(hilbert_function, threshold=True)
+                if self._möbius_inversion
+                else hilbert_function
+            )
+            if self.sparse:
+                signed_measure = tensor_möbius_inversion(
+                    tensor=signed_measure,
+                    num_parameters=2,
+                    grid_conversion=[rips_values, normalized_degree_values],
+                )
+            if not self._möbius_inversion:
+                signed_measure = signed_measure.flatten()
+            signed_measures.append(signed_measure)
+        return signed_measures
+
+    def transform(self, X):
+        return Parallel(n_jobs=self.n_jobs)(
+            delayed(self._transform1)(data)
+            for data in tqdm(X, desc=f"Computing DegreeRips of degrees {self.degrees}")
+        )
+
+
+def tensor_möbius_inversion(
+    tensor,
+    grid_conversion: Iterable[np.ndarray] | None = None,
+    plot: bool = False,
+    raw: bool = False,
+    num_parameters: int | None = None,
+):
+    from torch import Tensor
+
+    betti_sparse = Tensor(tensor.copy()).to_sparse()  # Copy necessary in some cases :(
+    num_indices, num_pts = betti_sparse.indices().shape
+    num_parameters = num_indices if num_parameters is None else num_parameters
+    if num_indices == num_parameters:  # either hilbert or rank invariant
+        rank_invariant = False
+    elif 2 * num_parameters == num_indices:
+        rank_invariant = True
+    else:
+        raise TypeError(
+            f"Unsupported betti shape: {num_indices} has to be either "
+            f"{num_parameters} or {2 * num_parameters}."
+        )
+    points_filtration = np.asarray(betti_sparse.indices().T, dtype=int)
+    weights = np.asarray(betti_sparse.values(), dtype=int)
+
+    if grid_conversion is not None:
+        coords = np.empty(shape=(num_pts, num_indices), dtype=float)
+        for i in range(num_indices):
+            coords[:, i] = grid_conversion[i % num_parameters][points_filtration[:, i]]
+    else:
+        coords = points_filtration
+    if (not rank_invariant) and plot:
+        plt.figure()
+        color_weights = np.empty(weights.shape)
+        color_weights[weights > 0] = np.log10(weights[weights > 0]) + 2
+        color_weights[weights < 0] = -np.log10(-weights[weights < 0]) - 2
+        plt.scatter(
+            points_filtration[:, 0],
+            points_filtration[:, 1],
+            c=color_weights,
+            cmap="coolwarm",
+        )
+    if (not rank_invariant) or raw:
+        return coords, weights
+
+    def _is_valid_rectangle(rectangle: np.ndarray):
+        # keep rectangles whose birth is componentwise below their death
+        birth = rectangle[:num_parameters]
+        death = rectangle[num_parameters:]
+        return np.all(birth <= death)  # and not np.array_equal(birth, death)
+
+    correct_indices = np.array([_is_valid_rectangle(rectangle) for rectangle in coords])
+    if len(correct_indices) == 0:
+        return np.empty((0, num_indices)), np.empty((0,))
+    signed_measure = np.asarray(coords[correct_indices])
+    weights = weights[correct_indices]
+    if plot:
+        # plot only the rank decomposition for the moment
+        assert signed_measure.shape[1] == 4
+
+        def _plot_rectangle(rectangle: np.ndarray, weight: float):
+            x_axis = rectangle[[0, 2]]
+            y_axis = rectangle[[1, 3]]
+            color = "blue" if weight > 0 else "red"
+            plt.plot(x_axis, y_axis, c=color)
+
+        for rectangle, weight in zip(signed_measure, weights):
+            _plot_rectangle(rectangle=rectangle, weight=weight)
+    return signed_measure, weights
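+
+
+# Worked example (an assumption, not part of the packaged source) for the sparsification above:
+# with tensor = np.array([[1, 0], [0, 1]]) and grid_conversion = [np.array([0.0, 1.0])] * 2, the
+# nonzero entries sit at indices (0, 0) and (1, 1), so the function returns the points
+# [[0.0, 0.0], [1.0, 1.0]] with weights [1, 1]. When the tensor has twice as many index
+# dimensions as num_parameters, entries are instead read as rectangles (birth, death) and only
+# those with birth <= death are kept.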