multipers 2.2.3__cp310-cp310-win_amd64.whl → 2.3.1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic. Click here for more details.

Files changed (182) hide show
  1. multipers/__init__.py +33 -31
  2. multipers/_signed_measure_meta.py +430 -430
  3. multipers/_slicer_meta.py +211 -212
  4. multipers/data/MOL2.py +458 -458
  5. multipers/data/UCR.py +18 -18
  6. multipers/data/graphs.py +466 -466
  7. multipers/data/immuno_regions.py +27 -27
  8. multipers/data/pytorch2simplextree.py +90 -90
  9. multipers/data/shape3d.py +101 -101
  10. multipers/data/synthetic.py +113 -111
  11. multipers/distances.py +198 -198
  12. multipers/filtration_conversions.pxd.tp +84 -84
  13. multipers/filtrations/__init__.py +18 -0
  14. multipers/{ml/convolutions.py → filtrations/density.py} +563 -520
  15. multipers/filtrations/filtrations.py +289 -0
  16. multipers/filtrations.pxd +224 -224
  17. multipers/function_rips.cp310-win_amd64.pyd +0 -0
  18. multipers/function_rips.pyx +105 -105
  19. multipers/grids.cp310-win_amd64.pyd +0 -0
  20. multipers/grids.pyx +350 -350
  21. multipers/gudhi/Persistence_slices_interface.h +132 -132
  22. multipers/gudhi/Simplex_tree_interface.h +239 -245
  23. multipers/gudhi/Simplex_tree_multi_interface.h +516 -561
  24. multipers/gudhi/cubical_to_boundary.h +59 -59
  25. multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -450
  26. multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -1070
  27. multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -579
  28. multipers/gudhi/gudhi/Debug_utils.h +45 -45
  29. multipers/gudhi/gudhi/Fields/Multi_field.h +484 -484
  30. multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -455
  31. multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -450
  32. multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -531
  33. multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -507
  34. multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -531
  35. multipers/gudhi/gudhi/Fields/Z2_field.h +355 -355
  36. multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -376
  37. multipers/gudhi/gudhi/Fields/Zp_field.h +420 -420
  38. multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -400
  39. multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -418
  40. multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -337
  41. multipers/gudhi/gudhi/Matrix.h +2107 -2107
  42. multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -1038
  43. multipers/gudhi/gudhi/Multi_persistence/Box.h +171 -171
  44. multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -282
  45. multipers/gudhi/gudhi/Off_reader.h +173 -173
  46. multipers/gudhi/gudhi/One_critical_filtration.h +1433 -1431
  47. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -769
  48. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -686
  49. multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -842
  50. multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -1350
  51. multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -1105
  52. multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -859
  53. multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -910
  54. multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -139
  55. multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -230
  56. multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -211
  57. multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -60
  58. multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -60
  59. multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -136
  60. multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -190
  61. multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -616
  62. multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -150
  63. multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -106
  64. multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -219
  65. multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -327
  66. multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -1140
  67. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -934
  68. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -934
  69. multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -980
  70. multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -1092
  71. multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -192
  72. multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -921
  73. multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -1093
  74. multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -1012
  75. multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -1244
  76. multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -186
  77. multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -164
  78. multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -156
  79. multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -376
  80. multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -540
  81. multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -118
  82. multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -173
  83. multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -128
  84. multipers/gudhi/gudhi/Persistent_cohomology.h +745 -745
  85. multipers/gudhi/gudhi/Points_off_io.h +171 -171
  86. multipers/gudhi/gudhi/Simple_object_pool.h +69 -69
  87. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -463
  88. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -83
  89. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -106
  90. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -277
  91. multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -62
  92. multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -27
  93. multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -62
  94. multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -157
  95. multipers/gudhi/gudhi/Simplex_tree.h +2794 -2794
  96. multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -163
  97. multipers/gudhi/gudhi/distance_functions.h +62 -62
  98. multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -104
  99. multipers/gudhi/gudhi/persistence_interval.h +253 -253
  100. multipers/gudhi/gudhi/persistence_matrix_options.h +170 -170
  101. multipers/gudhi/gudhi/reader_utils.h +367 -367
  102. multipers/gudhi/mma_interface_coh.h +256 -255
  103. multipers/gudhi/mma_interface_h0.h +223 -231
  104. multipers/gudhi/mma_interface_matrix.h +291 -282
  105. multipers/gudhi/naive_merge_tree.h +536 -575
  106. multipers/gudhi/scc_io.h +310 -289
  107. multipers/gudhi/truc.h +957 -888
  108. multipers/io.cp310-win_amd64.pyd +0 -0
  109. multipers/io.pyx +714 -711
  110. multipers/ml/accuracies.py +90 -90
  111. multipers/ml/invariants_with_persistable.py +79 -79
  112. multipers/ml/kernels.py +176 -176
  113. multipers/ml/mma.py +713 -714
  114. multipers/ml/one.py +472 -472
  115. multipers/ml/point_clouds.py +352 -346
  116. multipers/ml/signed_measures.py +1589 -1589
  117. multipers/ml/sliced_wasserstein.py +461 -461
  118. multipers/ml/tools.py +113 -113
  119. multipers/mma_structures.cp310-win_amd64.pyd +0 -0
  120. multipers/mma_structures.pxd +127 -127
  121. multipers/mma_structures.pyx +4 -8
  122. multipers/mma_structures.pyx.tp +1083 -1085
  123. multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -93
  124. multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -97
  125. multipers/multi_parameter_rank_invariant/function_rips.h +322 -322
  126. multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -769
  127. multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -148
  128. multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -369
  129. multipers/multiparameter_edge_collapse.py +41 -41
  130. multipers/multiparameter_module_approximation/approximation.h +2298 -2295
  131. multipers/multiparameter_module_approximation/combinatory.h +129 -129
  132. multipers/multiparameter_module_approximation/debug.h +107 -107
  133. multipers/multiparameter_module_approximation/format_python-cpp.h +286 -286
  134. multipers/multiparameter_module_approximation/heap_column.h +238 -238
  135. multipers/multiparameter_module_approximation/images.h +79 -79
  136. multipers/multiparameter_module_approximation/list_column.h +174 -174
  137. multipers/multiparameter_module_approximation/list_column_2.h +232 -232
  138. multipers/multiparameter_module_approximation/ru_matrix.h +347 -347
  139. multipers/multiparameter_module_approximation/set_column.h +135 -135
  140. multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -36
  141. multipers/multiparameter_module_approximation/unordered_set_column.h +166 -166
  142. multipers/multiparameter_module_approximation/utilities.h +403 -419
  143. multipers/multiparameter_module_approximation/vector_column.h +223 -223
  144. multipers/multiparameter_module_approximation/vector_matrix.h +331 -331
  145. multipers/multiparameter_module_approximation/vineyards.h +464 -464
  146. multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -649
  147. multipers/multiparameter_module_approximation.cp310-win_amd64.pyd +0 -0
  148. multipers/multiparameter_module_approximation.pyx +218 -217
  149. multipers/pickle.py +90 -53
  150. multipers/plots.py +342 -334
  151. multipers/point_measure.cp310-win_amd64.pyd +0 -0
  152. multipers/point_measure.pyx +322 -320
  153. multipers/simplex_tree_multi.cp310-win_amd64.pyd +0 -0
  154. multipers/simplex_tree_multi.pxd +133 -133
  155. multipers/simplex_tree_multi.pyx +115 -48
  156. multipers/simplex_tree_multi.pyx.tp +1947 -1935
  157. multipers/slicer.cp310-win_amd64.pyd +0 -0
  158. multipers/slicer.pxd +301 -120
  159. multipers/slicer.pxd.tp +218 -214
  160. multipers/slicer.pyx +1570 -507
  161. multipers/slicer.pyx.tp +931 -914
  162. multipers/tensor/tensor.h +672 -672
  163. multipers/tensor.pxd +13 -13
  164. multipers/test.pyx +44 -44
  165. multipers/tests/__init__.py +57 -57
  166. multipers/torch/diff_grids.py +217 -217
  167. multipers/torch/rips_density.py +310 -304
  168. {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/LICENSE +21 -21
  169. {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/METADATA +21 -11
  170. multipers-2.3.1.dist-info/RECORD +182 -0
  171. {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/WHEEL +1 -1
  172. multipers/tests/test_diff_helper.py +0 -73
  173. multipers/tests/test_hilbert_function.py +0 -82
  174. multipers/tests/test_mma.py +0 -83
  175. multipers/tests/test_point_clouds.py +0 -49
  176. multipers/tests/test_python-cpp_conversion.py +0 -82
  177. multipers/tests/test_signed_betti.py +0 -181
  178. multipers/tests/test_signed_measure.py +0 -89
  179. multipers/tests/test_simplextreemulti.py +0 -221
  180. multipers/tests/test_slicer.py +0 -221
  181. multipers-2.2.3.dist-info/RECORD +0 -189
  182. {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/top_level.txt +0 -0
@@ -1,1589 +1,1589 @@
1
- from collections.abc import Callable, Iterable, Sequence
2
- from itertools import product
3
- from typing import Optional, Union
4
-
5
- import matplotlib.pyplot as plt
6
- import numpy as np
7
- from joblib import Parallel, delayed
8
- from sklearn.base import BaseEstimator, TransformerMixin
9
- from tqdm import tqdm
10
-
11
- import multipers as mp
12
- from multipers.grids import compute_grid as reduce_grid
13
- from multipers.ml.convolutions import available_kernels, convolution_signed_measures
14
- from multipers.point_measure import signed_betti, rank_decomposition_by_rectangles
15
-
16
-
17
- class FilteredComplex2SignedMeasure(BaseEstimator, TransformerMixin):
18
- """
19
- Input
20
- -----
21
- Iterable[SimplexTreeMulti]
22
-
23
- Output
24
- ------
25
- Iterable[ list[signed_measure for degree] ]
26
-
27
- signed measure is either
28
- - (points : (n x num_parameters) array, weights : (n) int array ) if sparse,
29
- - else an integer matrix.
30
-
31
- Parameters
32
- ----------
33
- - degrees : list of degrees to compute. None correspond to the euler characteristic
34
- - filtration grid : the grid on which to compute.
35
- If None, the fit will infer it from
36
- - fit_fraction : the fraction of data to consider for the fit, seed is controlled by the seed parameter
37
- - resolution : the resolution of this grid
38
- - filtration_quantile : filtrations values quantile to ignore
39
- - grid_strategy:str : 'regular' or 'quantile' or 'exact'
40
- - normalize filtration : if sparse, will normalize all filtrations.
41
- - expand : expands the simplextree to compute correctly the degree, for
42
- flag complexes
43
- - invariant : the topological invariant to produce the signed measure.
44
- Choices are "hilbert" or "euler". Will add rank invariant later.
45
- - num_collapse : Either an int or "full". Collapse the complex before
46
- doing computation.
47
- - _möbius_inversion : if False, will not do the mobius inversion. output
48
- has to be a matrix then.
49
- - enforce_null_mass : Returns a zero mass measure, by thresholding the
50
- module if True.
51
- """
52
-
53
- def __init__(
54
- self,
55
- # homological degrees + None for euler
56
- degrees: list[int | None] = [],
57
- rank_degrees: list[int] = [], # same for rank invariant
58
- filtration_grid: (
59
- Sequence[Sequence[np.ndarray]]
60
- # filtration values to consider. Format : [ filtration values of Fi for Fi:filtration values of parameter i]
61
- | None
62
- ) = None,
63
- progress=False, # tqdm
64
- num_collapses: int | str = 0, # edge collapses before computing
65
- n_jobs=None,
66
- resolution: (
67
- Iterable[int] | int | None
68
- ) = None, # when filtration grid is not given, the resolution of the filtration grid to infer
69
- # sparse=True, # sparse output # DEPRECATED TO Ssigned measure formatter
70
- plot: bool = False,
71
- filtration_quantile: float = 0.0, # quantile for inferring filtration grid
72
- # wether or not to do the möbius inversion (not recommended to touch)
73
- # _möbius_inversion: bool = True,
74
- expand=False, # expand the simplextree befoe computing the homology
75
- normalize_filtrations: bool = False,
76
- # exact_computation:bool=False, # compute the exact signed measure.
77
- grid_strategy: str = "exact",
78
- seed: int = 0, # if fit_fraction is not 1, the seed sampling
79
- fit_fraction=1, # the fraction of the data on which to fit
80
- out_resolution: Iterable[int] | int | None = None,
81
- individual_grid: Optional[
82
- bool
83
- ] = None, # Can be significantly faster for some grid strategies, but can drop statistical performance
84
- enforce_null_mass: bool = False,
85
- flatten=True,
86
- backend: Optional[str] = None,
87
- ):
88
- super().__init__()
89
- self.degrees = degrees
90
- self.rank_degrees = rank_degrees
91
- self.filtration_grid = filtration_grid
92
- self.progress = progress
93
- self.num_collapses = num_collapses
94
- self.n_jobs = n_jobs
95
- self.resolution = resolution
96
- self.plot = plot
97
- self.backend = backend
98
- # self.sparse=sparse # TODO : deprecate
99
- self.filtration_quantile = filtration_quantile
100
- # Will only work for non sparse output. (discrete matrices cannot be "rescaled")
101
- self.normalize_filtrations = normalize_filtrations
102
- self.grid_strategy = grid_strategy
103
- # self._möbius_inversion = _möbius_inversion
104
- self._reconversion_grid = None
105
- self.expand = expand
106
- # will only refit the grid if filtration_grid has never been given.
107
- self._refit_grid = None
108
- self.seed = seed
109
- self.fit_fraction = fit_fraction
110
- self._transform_st = None
111
- self.out_resolution = out_resolution
112
- self.individual_grid = individual_grid
113
- self.enforce_null_mass = enforce_null_mass
114
- self._default_mass_location = None
115
- self.flatten = flatten
116
- self.num_parameters: int = 0
117
-
118
- return
119
-
120
- @staticmethod
121
- def _is_filtered_complex(input):
122
- return mp.simplex_tree_multi.is_simplextree_multi(input) or mp.slicer.is_slicer(
123
- input, allow_minpres=True
124
- )
125
-
126
- def _input_checks(self, X):
127
- assert len(X) > 0, "No filtered complex found. Cannot fit."
128
- assert self._is_filtered_complex(
129
- X[0][0]
130
- ), f"X[0] is not a known filtered complex, {X[0]=}, nor X[0][0]."
131
- self._num_axis = len(X[0])
132
- first = X[0][0]
133
- assert (
134
- not mp.slicer.is_slicer(first) or not self.expand
135
- ), "Cannot expand slicers."
136
- assert not mp.slicer.is_slicer(first) or not (
137
- isinstance(first, Union[tuple, list]) and first[0].is_minpres
138
- ), "Multi-degree minpres are not supported yet as an input. This can still be computed by providing a backend."
139
-
140
- def _infer_filtration(self, X):
141
- self.num_parameters = X[0][0].num_parameters
142
- indices = np.random.choice(
143
- len(X), min(int(self.fit_fraction * len(X)) + 1, len(X)), replace=False
144
- )
145
- ## ax, num_x
146
- filtrations = tuple(
147
- tuple(
148
- reduce_grid(x, strategy="exact")
149
- for x in (X[idx][ax] for idx in indices)
150
- )
151
- for ax in range(self._num_axis)
152
- )
153
- num_parameters = len(filtrations[0][0])
154
- assert (
155
- num_parameters == self.num_parameters
156
- ), f"Internal error, got {num_parameters=} and {self.num_parameters=}"
157
-
158
- filtrations_values = [
159
- [
160
- np.unique(np.concatenate([x[i] for x in filtrations[ax]]))
161
- for i in range(num_parameters)
162
- ]
163
- for ax in range(self._num_axis)
164
- ]
165
- ## ax, param, gridsize
166
- filtration_grid = tuple(
167
- reduce_grid(
168
- filtrations_values[ax],
169
- resolution=self.resolution,
170
- strategy=self.grid_strategy,
171
- )
172
- for ax in range(self._num_axis)
173
- ) # TODO :use more parameters
174
- self.filtration_grid = filtration_grid
175
- return filtration_grid
176
-
177
- def _params_check(self):
178
- assert (
179
- self.resolution is not None
180
- or self.filtration_grid is not None
181
- or self.grid_strategy == "exact"
182
- or self.individual_grid
183
- ), "For non exact filtrations, a resolution has to be specified."
184
-
185
- def fit(self, X, y=None): # Todo : infer filtration grid ? quantiles ?
186
- self._params_check()
187
- self._input_checks(X)
188
-
189
- if isinstance(self.resolution, int):
190
- self.resolution = [self.resolution] * self.num_parameters
191
-
192
- self.individual_grid = (
193
- self.individual_grid
194
- if self.individual_grid is not None
195
- else self.grid_strategy
196
- in ["regular_closest", "exact", "quantile", "partition"]
197
- )
198
-
199
- if (
200
- not self.enforce_null_mass
201
- and self.individual_grid
202
- or self.filtration_grid is not None
203
- ):
204
- self._refit_grid = False
205
- else:
206
- self._refit_grid = True
207
-
208
- if self._refit_grid:
209
- self._infer_filtration(X=X)
210
- if self.out_resolution is None:
211
- self.out_resolution = self.resolution
212
- # elif isinstance(self.out_resolution, int):
213
- # self.out_resolution = [self.out_resolution] * self.num_parameters
214
- if self.normalize_filtrations and not self.individual_grid:
215
- # self._reconversion_grid = [np.linspace(0,1, num=len(f), dtype=float) for f in self.filtration_grid] ## This will not work for non-regular grids...
216
- self._reconversion_grid = [
217
- [(f - np.min(f)) / np.std(f) for f in F] for F in self.filtration_grid
218
- ] # not the best, but better than some weird magic
219
- # elif not self.sparse: # It actually renormalizes the filtration !!
220
- # self._reconversion_grid = [np.linspace(0,r, num=r, dtype=int) for r in self.out_resolution]
221
- else:
222
- self._reconversion_grid = self.filtration_grid
223
- ## ax, num_param
224
- self._default_mass_location = (
225
- np.asarray([[g[-1] for g in F] for F in self.filtration_grid])
226
- if self.enforce_null_mass
227
- else None
228
- )
229
- return self
230
-
231
- def transform1(
232
- self,
233
- simplextree,
234
- ax,
235
- # _reconversion_grid,
236
- thread_id: str = "",
237
- ):
238
- # st = mp.SimplexTreeMulti(st, num_parameters=st.num_parameters) # COPY
239
- if self.individual_grid:
240
- filtration_grid = reduce_grid(
241
- simplextree, strategy=self.grid_strategy, resolution=self.resolution
242
- )
243
- mass_default = (
244
- self._default_mass_location[ax] if self.enforce_null_mass else None
245
- )
246
- if self.enforce_null_mass:
247
- filtration_grid = [
248
- np.concatenate([f, [d]], axis=0)
249
- for f, d in zip(filtration_grid, mass_default)
250
- ]
251
- _reconversion_grid = filtration_grid
252
- else:
253
- filtration_grid = self.filtration_grid[ax]
254
- _reconversion_grid = self._reconversion_grid[ax]
255
- mass_default = (
256
- self._default_mass_location[ax] if self.enforce_null_mass else None
257
- )
258
-
259
- st = simplextree.grid_squeeze(filtration_grid=filtration_grid)
260
- if st.num_parameters == 2 and mp.simplex_tree_multi.is_simplextree_multi(st):
261
- st.collapse_edges(num=self.num_collapses, max_dimension=1)
262
- int_degrees = np.asarray([d for d in self.degrees if d is not None], dtype=int)
263
- # EULER. First as there is prune above dimension below
264
- if self.expand and None in self.degrees:
265
- st.expansion(st.num_vertices)
266
- signed_measures_euler = (
267
- mp.signed_measure(
268
- st,
269
- degrees=[None],
270
- plot=self.plot,
271
- mass_default=mass_default,
272
- invariant="euler",
273
- # thread_id=thread_id,
274
- backend=self.backend,
275
- grid=_reconversion_grid,
276
- )[0]
277
- if None in self.degrees
278
- else []
279
- )
280
-
281
- if self.expand and len(int_degrees) > 0:
282
- st.expansion(np.max(int_degrees) + 1)
283
- if len(int_degrees) > 0:
284
- st.prune_above_dimension(
285
- np.max(np.concatenate([int_degrees, self.rank_degrees])) + 1
286
- ) # no need to compute homology beyond this
287
- signed_measures_pers = (
288
- mp.signed_measure(
289
- st,
290
- degrees=int_degrees,
291
- mass_default=mass_default,
292
- plot=self.plot,
293
- invariant="hilbert",
294
- thread_id=thread_id,
295
- backend=self.backend,
296
- grid=_reconversion_grid,
297
- )
298
- if len(int_degrees) > 0
299
- else []
300
- )
301
- if self.plot:
302
- plt.show()
303
- if self.expand and len(self.rank_degrees) > 0:
304
- st.expansion(np.max(self.rank_degrees) + 1)
305
- if len(self.rank_degrees) > 0:
306
- st.prune_above_dimension(
307
- np.max(self.rank_degrees) + 1
308
- ) # no need to compute homology beyond this
309
- signed_measures_rank = (
310
- mp.signed_measure(
311
- st,
312
- degrees=self.rank_degrees,
313
- mass_default=mass_default,
314
- plot=self.plot,
315
- invariant="rank",
316
- thread_id=thread_id,
317
- backend=self.backend,
318
- grid=_reconversion_grid,
319
- )
320
- if len(self.rank_degrees) > 0
321
- else []
322
- )
323
- if self.plot:
324
- plt.show()
325
-
326
- count = 0
327
- signed_measures = []
328
- for d in self.degrees:
329
- if d is None:
330
- signed_measures.append(signed_measures_euler)
331
- else:
332
- signed_measures.append(signed_measures_pers[count])
333
- count += 1
334
- signed_measures += signed_measures_rank
335
- return signed_measures
336
-
337
- def transform(self, X):
338
- ## X of shape (num_x, num_axis, filtered_complex
339
- assert (
340
- self.filtration_grid is not None and self._reconversion_grid is not None
341
- ) or self.individual_grid, "Fit first"
342
-
343
- def todo_x(x):
344
- return tuple(self.transform1(x_axis, j) for j, x_axis in enumerate(x))
345
-
346
- ## out shape num_x, num_axis, degree, sm
347
- out = tuple(
348
- Parallel(n_jobs=self.n_jobs, backend="threading")(
349
- delayed(todo_x)(x) for x in X
350
- )
351
- )
352
- # out = Parallel(n_jobs=self.n_jobs, backend="threading")(
353
- # delayed(self.transform1)(to_st, thread_id=str(thread_id))
354
- # for thread_id, to_st in tqdm(
355
- # enumerate(X),
356
- # disable=not self.progress,
357
- # desc="Computing signed measure decompositions",
358
- # )
359
- # )
360
- return out
361
-
362
-
363
- class SimplexTree2SignedMeasure(FilteredComplex2SignedMeasure):
364
- def __init__(
365
- self,
366
- # homological degrees + None for euler
367
- degrees: list[int | None] = [],
368
- rank_degrees: list[int] = [], # same for rank invariant
369
- filtration_grid: (
370
- Sequence[Sequence[np.ndarray]]
371
- # filtration values to consider. Format : [ filtration values of Fi for Fi:filtration values of parameter i]
372
- | None
373
- ) = None,
374
- progress=False, # tqdm
375
- num_collapses: int | str = 0, # edge collapses before computing
376
- n_jobs=None,
377
- resolution: (
378
- Iterable[int] | int | None
379
- ) = None, # when filtration grid is not given, the resolution of the filtration grid to infer
380
- # sparse=True, # sparse output # DEPRECATED TO Ssigned measure formatter
381
- plot: bool = False,
382
- filtration_quantile: float = 0.0, # quantile for inferring filtration grid
383
- # wether or not to do the möbius inversion (not recommended to touch)
384
- # _möbius_inversion: bool = True,
385
- expand=False, # expand the simplextree befoe computing the homology
386
- normalize_filtrations: bool = False,
387
- # exact_computation:bool=False, # compute the exact signed measure.
388
- grid_strategy: str = "exact",
389
- seed: int = 0, # if fit_fraction is not 1, the seed sampling
390
- fit_fraction=1, # the fraction of the data on which to fit
391
- out_resolution: Iterable[int] | int | None = None,
392
- individual_grid: Optional[
393
- bool
394
- ] = None, # Can be significantly faster for some grid strategies, but can drop statistical performance
395
- enforce_null_mass: bool = False,
396
- flatten=True,
397
- backend: Optional[str] = None,
398
- ):
399
- stuff = locals()
400
- stuff.pop("self")
401
- keys = list(stuff.keys())
402
- for key in keys:
403
- if key.startswith("__"):
404
- stuff.pop(key)
405
- super().__init__(**stuff)
406
- from warnings import warn
407
-
408
- warn("This class is deprecated, use FilteredComplex2SignedMeasure instead.")
409
-
410
-
411
- # class SimplexTrees2SignedMeasures(SimplexTree2SignedMeasure):
412
- # """
413
- # Input
414
- # -----
415
- #
416
- # (data) x (axis, e.g. different bandwidths for simplextrees) x (simplextree)
417
- #
418
- # Output
419
- # ------
420
- # (data) x (axis) x (degree) x (signed measure)
421
- # """
422
- #
423
- # def __init__(self, **kwargs):
424
- # super().__init__(**kwargs)
425
- # self._num_st_per_data = None
426
- # # self._super_model=SimplexTree2SignedMeasure(**kwargs)
427
- # self._filtration_grids = None
428
- # return
429
- #
430
- # def fit(self, X, y=None):
431
- # if len(X) == 0:
432
- # return self
433
- # try:
434
- # self._num_st_per_data = len(X[0])
435
- # except:
436
- # raise Exception(
437
- # "Shape has to be (num_data, num_axis), dtype=SimplexTreeMulti"
438
- # )
439
- # self._filtration_grids = []
440
- # for axis in range(self._num_st_per_data):
441
- # self._filtration_grids.append(
442
- # super().fit([x[axis] for x in X]).filtration_grid
443
- # )
444
- # # self._super_fits.append(truc)
445
- # # self._super_fits_params = [super().fit([x[axis] for x in X]).get_params() for axis in range(self._num_st_per_data)]
446
- # return self
447
- #
448
- # def transform(self, X):
449
- # if self.normalize_filtrations:
450
- # _reconversion_grids = [
451
- # [np.linspace(0, 1, num=len(f), dtype=float) for f in F]
452
- # for F in self._filtration_grids
453
- # ]
454
- # else:
455
- # _reconversion_grids = self._filtration_grids
456
- #
457
- # def todo(x):
458
- # # return [SimplexTree2SignedMeasure().set_params(**transformer_params).transform1(x[axis]) for axis,transformer_params in enumerate(self._super_fits_params)]
459
- # out = [
460
- # self.transform1(
461
- # x[axis],
462
- # filtration_grid=filtration_grid,
463
- # _reconversion_grid=_reconversion_grid,
464
- # )
465
- # for axis, filtration_grid, _reconversion_grid in zip(
466
- # range(self._num_st_per_data),
467
- # self._filtration_grids,
468
- # _reconversion_grids,
469
- # )
470
- # ]
471
- # return out
472
- #
473
- # return Parallel(n_jobs=self.n_jobs, backend="threading")(
474
- # delayed(todo)(x)
475
- # for x in tqdm(
476
- # X,
477
- # disable=not self.progress,
478
- # desc="Computing Signed Measures from simplextrees.",
479
- # )
480
- # )
481
-
482
-
483
- def rescale_sparse_signed_measure(
484
- signed_measure, filtration_weights, normalize_scales=None
485
- ):
486
- # from copy import deepcopy
487
- #
488
- # out = deepcopy(signed_measure)
489
-
490
- if filtration_weights is None and normalize_scales is None:
491
- return signed_measure
492
-
493
- # if normalize_scales is None:
494
- # out = tuple(
495
- # (
496
- # _cat(
497
- # tuple(
498
- # signed_measure[degree][0][:, parameter]
499
- # * filtration_weights[parameter]
500
- # for parameter in range(num_parameters)
501
- # ),
502
- # axis=1,
503
- # ),
504
- # signed_measure[degree][1],
505
- # )
506
- # for degree in range(len(signed_measure))
507
- # )
508
- # for degree in range(len(signed_measure)): # degree
509
- # for parameter in range(len(filtration_weights)):
510
- # signed_measure[degree][0][:, parameter] *= filtration_weights[parameter]
511
- # # TODO Broadcast w.r.t. the parameter
512
- # out = tuple(
513
- # _cat(
514
- # tuple(
515
- # signed_measure[degree][0][:, [parameter]]
516
- # * filtration_weights[parameter]
517
- # / (
518
- # normalize_scales[degree][parameter]
519
- # if normalize_scales is not None
520
- # else 1
521
- # )
522
- # for parameter in range(num_parameters)
523
- # ),
524
- # axis=1,
525
- # )
526
- # for degree in range(len(signed_measure))
527
- # )
528
- out = tuple(
529
- (
530
- signed_measure[degree][0]
531
- * (1 if filtration_weights is None else filtration_weights.reshape(1, -1))
532
- / (
533
- normalize_scales[degree].reshape(1, -1)
534
- if normalize_scales is not None
535
- else 1
536
- ),
537
- signed_measure[degree][1],
538
- )
539
- for degree in range(len(signed_measure))
540
- )
541
- # for degree in range(len(out)):
542
- # for parameter in range(len(filtration_weights)):
543
- # out[degree][0][:, parameter] *= (
544
- # filtration_weights[parameter] / normalize_scales[degree][parameter]
545
- # )
546
- return out
547
-
548
-
549
- class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
550
- """
551
- Input
552
- -----
553
-
554
- (data) x (degree) x (signed measure) or (data) x (axis) x (degree) x (signed measure)
555
-
556
- Iterable[list[signed_measure_matrix of degree]] or Iterable[previous].
557
-
558
- The second is meant to use multiple choices for signed measure input. An example of usage : they come from a Rips + Density with different bandwidth.
559
- It is controlled by the axis parameter.
560
-
561
- Output
562
- ------
563
-
564
- Iterable[list[(reweighted)_sparse_signed_measure of degree]]
565
-
566
- or (deep format)
567
-
568
- Tensor of shape (num_axis*num_degrees, data, max_num_pts, num_parameters)
569
- """
570
-
571
- def __init__(
572
- self,
573
- filtrations_weights: Optional[Iterable[float]] = None,
574
- normalize=False,
575
- plot: bool = False,
576
- unsparse: bool = False,
577
- axis: int = -1,
578
- resolution: int | Iterable[int] = 50,
579
- flatten: bool = False,
580
- deep_format: bool = False,
581
- unrag: bool = True,
582
- n_jobs: int = 1,
583
- verbose: bool = False,
584
- integrate: bool = False,
585
- grid_strategy="regular",
586
- ):
587
- super().__init__()
588
- self.filtrations_weights = filtrations_weights
589
- self.num_parameters: int = 0
590
- self.plot = plot
591
- self.unsparse = unsparse
592
- self.n_jobs = n_jobs
593
- self.axis = axis
594
- self._num_axis = 0
595
- self.resolution = resolution
596
- self._filtrations_bounds = None
597
- self.flatten = flatten
598
- self.normalize = normalize
599
- self._normalization_factors = None
600
- self.deep_format = deep_format
601
- self.unrag = unrag
602
- assert (
603
- not self.deep_format or not self.unsparse or not self.integrate
604
- ), "One post processing at the time."
605
- self.verbose = verbose
606
- self._num_degrees = 0
607
- self.integrate = integrate
608
- self.grid_strategy = grid_strategy
609
- self._infered_grids = None
610
- self._axis_iterator = None
611
- self._backend = None
612
- return
613
-
614
- def _get_filtration_bounds(self, X, axis):
615
- if self._backend == "numpy":
616
- _cat = np.concatenate
617
-
618
- else:
619
- ## torch is globally imported
620
- _cat = torch.cat
621
- stuff = [
622
- _cat(
623
- [sm[axis][degree][0] for sm in X],
624
- axis=0,
625
- )
626
- for degree in range(self._num_degrees)
627
- ]
628
- sizes_ = np.array([len(x) == 0 for x in stuff])
629
- assert np.all(~sizes_), f"Degree axis {np.where(sizes_)} is/are trivial !"
630
- if self._backend == "numpy":
631
- filtrations_bounds = np.array(
632
- [([f.min(axis=0), f.max(axis=0)]) for f in stuff]
633
- )
634
- else:
635
- filtrations_bounds = torch.stack(
636
- [
637
- torch.stack([f.min(axis=0).values, f.max(axis=0).values])
638
- for f in stuff
639
- ]
640
- ).detach() ## don't want to rescale gradient of normalization
641
- normalization_factors = (
642
- filtrations_bounds[:, 1] - filtrations_bounds[:, 0]
643
- if self.normalize
644
- else None
645
- )
646
- # print("Normalization factors : ",self._normalization_factors)
647
- if (normalization_factors == 0).any():
648
- indices = normalization_factors == 0
649
- # warn(f"Constant filtration encountered, at degree, parameter {indices} and axis {self.axis}.")
650
- normalization_factors[indices] = 1
651
- return filtrations_bounds, normalization_factors
652
-
653
- def _plot_signed_measures(self, sms: Iterable[np.ndarray], size=4):
654
- from multipers.plots import plot_signed_measure
655
-
656
- num_degrees = len(sms[0])
657
- num_imgs = len(sms)
658
- fig, axes = plt.subplots(
659
- ncols=num_degrees,
660
- nrows=num_imgs,
661
- figsize=(size * num_degrees, size * num_imgs),
662
- )
663
- axes = np.asarray(axes).reshape(num_imgs, num_degrees)
664
- # assert axes.ndim==2, "Internal error"
665
- for i, sm in enumerate(sms):
666
- for j, sm_of_degree in enumerate(sm):
667
- plot_signed_measure(sm_of_degree, ax=axes[i, j])
668
-
669
- @staticmethod
670
- def _check_sm(sm) -> bool:
671
- return (
672
- isinstance(sm, tuple)
673
- and hasattr(sm[0], "ndim")
674
- and sm[0].ndim == 2
675
- and len(sm) == 2
676
- )
677
-
678
- def _check_axis(self, X):
679
- # axes should be (num_data, num_axis, num_degrees, (signed_measure))
680
- if len(X) == 0:
681
- return
682
- if len(X[0]) == 0:
683
- return
684
- if self._check_sm(X[0][0]):
685
- self._has_axis = False
686
- self._num_axis = 1
687
- self._axis_iterator = [slice(None)]
688
- return
689
- assert self._check_sm( ## vaguely checks that its a signed measure
690
- _sm := X[0][0][0]
691
- ), f"Cannot take this input. # data, axis, degrees, sm.\n Got {_sm} of type {type(_sm)}"
692
-
693
- self._has_axis = True
694
- self._num_axis = len(X[0])
695
- self._axis_iterator = range(self._num_axis) if self.axis == -1 else [self.axis]
696
-
697
- def _check_backend(self, X):
698
- if self._has_axis:
699
- # data, axis, degrees, (pts, weights)
700
- first_sm = X[0][0][0][0]
701
- else:
702
- first_sm = X[0][0][0]
703
- if isinstance(first_sm, np.ndarray):
704
- self._backend = "numpy"
705
- else:
706
- global torch
707
- import torch
708
-
709
- assert isinstance(first_sm, torch.Tensor)
710
- self._backend = "pytorch"
711
-
712
- def _check_measures(self, X):
713
- if self._has_axis:
714
- first_sm = X[0][0]
715
- else:
716
- first_sm = X[0]
717
- self._num_degrees = len(first_sm)
718
- self.num_parameters = first_sm[0][0].shape[1]
719
-
720
- def _check_resolution(self):
721
- assert self.num_parameters > 0, "Num parameters hasn't been initialized."
722
- if isinstance(self.resolution, int):
723
- self.resolution = [self.resolution] * self.num_parameters
724
- self.resolution = np.asarray(self.resolution, dtype=int)
725
- assert (
726
- self.resolution.shape[0] == self.num_parameters
727
- ), "Resolution doesn't have a proper size."
728
-
729
- def _check_weights(self):
730
- if self.filtrations_weights is None:
731
- return
732
- assert (
733
- self.filtrations_weights.shape[0] == self.num_parameters
734
- ), "Filtration weights don't have a proper size"
735
-
736
- def _infer_grids(self, X):
737
- # Computes normalization factors
738
- if self.normalize:
739
- # if self._has_axis and self.axis == -1:
740
- self._filtrations_bounds = []
741
- self._normalization_factors = []
742
- for ax in self._axis_iterator:
743
- (
744
- filtration_bounds,
745
- normalization_factors,
746
- ) = self._get_filtration_bounds(X, axis=ax)
747
- self._filtrations_bounds.append(filtration_bounds)
748
- self._normalization_factors.append(normalization_factors)
749
- # else:
750
- # (
751
- # self._filtrations_bounds,
752
- # self._normalization_factors,
753
- # ) = self._get_filtration_bounds(
754
- # X, axis=self._axis_iterator[0]
755
- # ) ## axis = slice(None)
756
- elif self.integrate or self.unsparse or self.deep_format:
757
- filtration_values = [
758
- np.concatenate(
759
- [
760
- (
761
- stuff
762
- if isinstance(stuff := x[ax][degree][0], np.ndarray)
763
- else stuff.detach().numpy()
764
- )
765
- for x in X
766
- for degree in range(self._num_degrees)
767
- ]
768
- )
769
- for ax in self._axis_iterator
770
- ]
771
- # axis, filtration_values
772
- filtration_values = [
773
- reduce_grid(
774
- f_ax.T, resolution=self.resolution, strategy=self.grid_strategy
775
- )
776
- for f_ax in filtration_values
777
- ]
778
- self._infered_grids = filtration_values
779
-
780
- def _print_stats(self, X):
781
- print("------------SignedMeasureFormatter------------")
782
- print("---- Parameters")
783
- print(f"Number of axis : {self._num_axis}")
784
- print(f"Number of degrees : {self._num_degrees}")
785
- print(f"Filtration bounds : \n{self._filtrations_bounds}")
786
- print(f"Normalization factor : \n{self._normalization_factors}")
787
- if self._infered_grids is not None:
788
- print(
789
- f"Filtration grid shape : \n \
790
- {tuple(tuple(len(f) for f in F) for F in self._infered_grids)}"
791
- )
792
- print("---- SM stats")
793
- print("In axis :", self._num_axis)
794
- sizes = [
795
- [[len(xd[1]) for xd in x[ax]] for x in X] for ax in self._axis_iterator
796
- ]
797
- print(f"Size means (axis) x (degree): {np.mean(sizes, axis=(1))}")
798
- print(f"Size std : {np.std(sizes, axis=(1))}")
799
- print("----------------------------------------------")
800
-
801
- def fit(self, X, y=None):
802
- # Gets a grid. This will be the max in each coord+1
803
- if (
804
- len(X) == 0
805
- or len(X[0]) == 0
806
- or (self.axis is not None and len(X[0][0][0]) == 0)
807
- ):
808
- return self
809
-
810
- self._check_axis(X)
811
- self._check_backend(X)
812
- self._check_measures(X)
813
- self._check_resolution()
814
- self._check_weights()
815
- # if not sparse : not recommended.
816
-
817
- self._infer_grids(X)
818
- if self.verbose:
819
- self._print_stats(X)
820
- return self
821
-
822
- def unsparse_signed_measure(self, sparse_signed_measure):
823
- filtrations = self._infered_grids # ax, filtration
824
- out = []
825
- for filtrations_of_ax, ax in zip(filtrations, self._axis_iterator, strict=True):
826
- sparse_signed_measure_of_ax = sparse_signed_measure[ax]
827
- measure_of_ax = []
828
- for pts, weights in sparse_signed_measure_of_ax: # over degree
829
- signed_measure, _ = np.histogramdd(
830
- pts, bins=filtrations_of_ax, weights=weights
831
- )
832
- if self.flatten:
833
- signed_measure = signed_measure.flatten()
834
- measure_of_ax.append(signed_measure)
835
- out.append(np.asarray(measure_of_ax))
836
-
837
- if self.flatten:
838
- out = np.concatenate(out).flatten()
839
- if self.axis == -1:
840
- return np.asarray(out)
841
- else:
842
- return np.asarray(out)[0]
843
-
844
- @staticmethod
845
- def deep_format_measure(signed_measure):
846
- dirac_positions, dirac_signs = signed_measure
847
- dtype = dirac_positions.dtype
848
- new_shape = list(dirac_positions.shape)
849
- new_shape[1] += 1
850
- if isinstance(dirac_positions, np.ndarray):
851
- c = np.empty(new_shape, dtype=dtype)
852
- c[:, :-1] = dirac_positions
853
- c[:, -1] = dirac_signs
854
-
855
- else:
856
- import torch
857
-
858
- c = torch.empty(new_shape, dtype=dtype)
859
- c[:, :-1] = dirac_positions
860
- c[:, -1] = dirac_signs
861
- return c
862
-
863
- @staticmethod
864
- def _integrate_measure(sm, filtrations):
865
- from multipers.point_measure import integrate_measure
866
-
867
- return integrate_measure(sm[0], sm[1], filtrations)
868
-
869
- def _rescale_measures(self, X):
870
- def rescale_from_sparse(sparse_signed_measure):
871
- if self.axis == -1 and self._has_axis:
872
- return tuple(
873
- rescale_sparse_signed_measure(
874
- sparse_signed_measure[ax],
875
- filtration_weights=self.filtrations_weights,
876
- normalize_scales=n,
877
- )
878
- for ax, n in zip(
879
- self._axis_iterator, self._normalization_factors, strict=True
880
- )
881
- )
882
- return rescale_sparse_signed_measure( ## axis iterator is of size 1 here
883
- sparse_signed_measure,
884
- filtration_weights=self.filtrations_weights,
885
- normalize_scales=self._normalization_factors[0],
886
- )
887
-
888
- out = tuple(rescale_from_sparse(x) for x in X)
889
- return out
890
-
891
- def transform(self, X):
892
- if not self._has_axis or self.axis == -1:
893
- out = X
894
- else:
895
- out = tuple(x[self.axis] for x in X)
896
- # same format for everyone
897
-
898
- if self._normalization_factors is not None:
899
- out = self._rescale_measures(out)
900
-
901
- if self.plot:
902
- # assert ax != -1, "Not implemented"
903
- self._plot_signed_measures(out)
904
- if self.integrate:
905
- filtrations = self._infered_grids
906
- # if self.axis != -1:
907
- ax = 0 # if self.axis is None else self.axis # TODO deal with axis -1
908
-
909
- assert ax != -1, "Not implemented. Can only integrate with axis"
910
- # try:
911
- out = np.asarray(
912
- [
913
- [
914
- self._integrate_measure(x[degree], filtrations=filtrations[ax])
915
- for degree in range(self._num_degrees)
916
- ]
917
- for x in out
918
- ]
919
- )
920
- # except:
921
- # print(self.axis, ax, filtrations)
922
- if self.flatten:
923
- out = out.reshape((len(X), -1))
924
- # else:
925
- # out = [[[self._integrate_measure(x[axis][degree],filtrations=filtrations[degree].T) for degree in range(self._num_degrees)] for axis in range(self._num_axis)] for x in out]
926
- elif self.unsparse:
927
- out = [self.unsparse_signed_measure(x) for x in out]
928
- elif self.deep_format:
929
- num_degrees = self._num_degrees
930
- out = tuple(
931
- tuple(self.deep_format_measure(sm[axis][degree]) for sm in out)
932
- for degree in range(num_degrees)
933
- for axis in self._axis_iterator
934
- )
935
- if self.unrag:
936
- max_num_pts = np.max(
937
- [sm.shape[0] for sm_of_axis in out for sm in sm_of_axis]
938
- )
939
- num_axis_degree = len(out)
940
- num_data = len(out[0])
941
- assert num_axis_degree == num_degrees * (
942
- self._num_axis if self._has_axis else 1
943
- ), f"Bad axis/degree count. Got {num_axis_degree} (Internal error)"
944
- num_parameters = out[0][0].shape[1]
945
- dtype = out[0][0].dtype
946
- if isinstance(out[0][0], np.ndarray):
947
- from numpy import zeros
948
- else:
949
- from torch import zeros
950
- unragged_tensor = zeros(
951
- (
952
- num_axis_degree,
953
- num_data,
954
- max_num_pts,
955
- num_parameters,
956
- ),
957
- dtype=dtype,
958
- )
959
- for ax in range(num_axis_degree):
960
- for data in range(num_data):
961
- sm = out[ax][data]
962
- a, b = sm.shape
963
- unragged_tensor[ax, data, :a, :b] = sm
964
- out = unragged_tensor
965
- return out
966
-
967
-
968
- class SignedMeasure2Convolution(BaseEstimator, TransformerMixin):
969
- """
970
- Discrete convolution of a signed measure
971
-
972
- Input
973
- -----
974
-
975
- (data) x (degree) x (signed measure)
976
-
977
- Parameters
978
- ----------
979
- - filtration_grid : Iterable[array] For each filtration, the filtration values on which to evaluate the grid
980
- - resolution : int or (num_parameters) : If filtration grid is not given, will infer a grid, with this resolution
981
- - grid_strategy : the strategy to generate the grid. Available ones are regular, quantile, exact
982
- - flatten : if true, the output will be flattened
983
- - kernel : kernel to used to convolve the images.
984
- - flatten : flatten the images if True
985
- - progress : progress bar if True
986
- - backend : sklearn, pykeops or numba.
987
- - plot : Creates a plot Figure.
988
-
989
- Output
990
- ------
991
-
992
- (data) x (concatenation of imgs of degree)
993
- """
994
-
995
- def __init__(
996
- self,
997
- filtration_grid: Iterable[np.ndarray] = None,
998
- kernel: available_kernels = "gaussian",
999
- bandwidth: float | Iterable[float] = 1.0,
1000
- flatten: bool = False,
1001
- n_jobs: int = 1,
1002
- resolution: int | None = None,
1003
- grid_strategy: str = "regular",
1004
- progress: bool = False,
1005
- backend: str = "pykeops",
1006
- plot: bool = False,
1007
- log_density: bool = False,
1008
- **kde_kwargs,
1009
- # **kwargs ## DANGEROUS
1010
- ):
1011
- super().__init__()
1012
- self.kernel: available_kernels = kernel
1013
- self.bandwidth = bandwidth
1014
- # self.more_kde_kwargs=kwargs
1015
- self.filtration_grid = filtration_grid
1016
- self.flatten = flatten
1017
- self.progress = progress
1018
- self.n_jobs = n_jobs
1019
- self.resolution = resolution
1020
- self.grid_strategy = grid_strategy
1021
- self._is_input_sparse = None
1022
- self._refit = filtration_grid is None
1023
- self._input_resolution = None
1024
- self._bandwidths = None
1025
- self.diameter = None
1026
- self.backend = backend
1027
- self.plot = plot
1028
- self.log_density = log_density
1029
- self.kde_kwargs = kde_kwargs
1030
- return
1031
-
1032
- def fit(self, X, y=None):
1033
- # Infers if the input is sparse given X
1034
- if len(X) == 0:
1035
- return self
1036
- if isinstance(X[0][0], tuple):
1037
- self._is_input_sparse = True
1038
- else:
1039
- self._is_input_sparse = False
1040
- # print(f"IMG output is set to {'sparse' if self.sparse else 'matrix'}")
1041
- if not self._is_input_sparse:
1042
- self._input_resolution = X[0][0].shape
1043
- try:
1044
- float(self.bandwidth)
1045
- b = float(self.bandwidth)
1046
- self._bandwidths = [
1047
- b if b > 0 else -b * s for s in self._input_resolution
1048
- ]
1049
- except:
1050
- self._bandwidths = [
1051
- b if b > 0 else -b * s
1052
- for s, b in zip(self._input_resolution, self.bandwidth)
1053
- ]
1054
- return self # in that case, singed measures are matrices, and the grid is already given
1055
-
1056
- if self.filtration_grid is None and self.resolution is None:
1057
- raise Exception(
1058
- "Cannot infer filtration grid. Provide either a filtration grid or a resolution."
1059
- )
1060
- # If not sparse : a grid has to be defined
1061
- if self._refit:
1062
- # print("Fitting a grid...", end="")
1063
- pts = np.concatenate(
1064
- [sm[0] for signed_measures in X for sm in signed_measures]
1065
- ).T
1066
- self.filtration_grid = reduce_grid(
1067
- pts,
1068
- strategy=self.grid_strategy,
1069
- resolution=self.resolution,
1070
- )
1071
- # print('Done.')
1072
- if self.filtration_grid is not None:
1073
- self.diameter = np.linalg.norm(
1074
- [f.max() - f.min() for f in self.filtration_grid]
1075
- )
1076
- if self.progress:
1077
- print(f"Computed a diameter of {self.diameter}")
1078
- return self
1079
-
1080
- def _sm2smi(self, signed_measures: Iterable[np.ndarray]):
1081
- # print(self._input_resolution, self.bandwidths, _bandwidths)
1082
- from scipy.ndimage import gaussian_filter
1083
-
1084
- return np.concatenate(
1085
- [
1086
- gaussian_filter(
1087
- input=signed_measure,
1088
- sigma=self._bandwidths,
1089
- mode="constant",
1090
- cval=0,
1091
- )
1092
- for signed_measure in signed_measures
1093
- ],
1094
- axis=0,
1095
- )
1096
-
1097
- def _transform_from_sparse(self, X):
1098
- bandwidth = (
1099
- self.bandwidth if self.bandwidth > 0 else -self.bandwidth * self.diameter
1100
- )
1101
- # COMPILE KEOPS FIRST
1102
- dummyx = [X[0]]
1103
- dummyf = [f[:2] for f in self.filtration_grid]
1104
- convolution_signed_measures(
1105
- dummyx,
1106
- filtrations=dummyf,
1107
- bandwidth=bandwidth,
1108
- flatten=self.flatten,
1109
- n_jobs=1,
1110
- kernel=self.kernel,
1111
- backend=self.backend,
1112
- )
1113
-
1114
- return convolution_signed_measures(
1115
- X,
1116
- filtrations=self.filtration_grid,
1117
- bandwidth=bandwidth,
1118
- flatten=self.flatten,
1119
- n_jobs=self.n_jobs,
1120
- kernel=self.kernel,
1121
- backend=self.backend,
1122
- **self.kde_kwargs,
1123
- )
1124
-
1125
- def _plot_imgs(self, imgs: Iterable[np.ndarray], size=4):
1126
- from multipers.plots import plot_surface
1127
-
1128
- num_degrees = imgs[0].shape[0]
1129
- num_imgs = len(imgs)
1130
- fig, axes = plt.subplots(
1131
- ncols=num_degrees,
1132
- nrows=num_imgs,
1133
- figsize=(size * num_degrees, size * num_imgs),
1134
- )
1135
- axes = np.asarray(axes).reshape(num_imgs, num_degrees)
1136
- # assert axes.ndim==2, "Internal error"
1137
- for i, img in enumerate(imgs):
1138
- for j, img_of_degree in enumerate(img):
1139
- plot_surface(
1140
- self.filtration_grid, img_of_degree, ax=axes[i, j], cmap="Spectral"
1141
- )
1142
-
1143
- def transform(self, X):
1144
- if self._is_input_sparse is None:
1145
- raise Exception("Fit first")
1146
- if self._is_input_sparse:
1147
- out = self._transform_from_sparse(X)
1148
- else:
1149
- todo = SignedMeasure2Convolution._sm2smi
1150
- out = Parallel(n_jobs=self.n_jobs, backend="threading")(
1151
- delayed(todo)(self, signed_measures)
1152
- for signed_measures in tqdm(
1153
- X, desc="Computing images", disable=not self.progress
1154
- )
1155
- )
1156
- if self.plot and not self.flatten:
1157
- if self.progress:
1158
- print("Plotting convolutions...", end="")
1159
- self._plot_imgs(out)
1160
- if self.progress:
1161
- print("Done !")
1162
- if self.flatten and not self._is_input_sparse:
1163
- out = [x.flatten() for x in out]
1164
- return np.asarray(out)
1165
-
1166
-
1167
- class SignedMeasure2SlicedWassersteinDistance(BaseEstimator, TransformerMixin):
1168
- """
1169
- Transformer from signed measure to distance matrix.
1170
-
1171
- Input
1172
- -----
1173
-
1174
- (data) x (degree) x (signed measure)
1175
-
1176
- Format
1177
- ------
1178
- - a signed measure : tuple of array. (point position) : npts x (num_paramters) and weigths : npts
1179
- - each data is a list of signed measure (for e.g. multiple degrees)
1180
-
1181
- Output
1182
- ------
1183
- - (degree) x (distance matrix)
1184
- """
1185
-
1186
- def __init__(
1187
- self,
1188
- n_jobs=None,
1189
- num_directions: int = 10,
1190
- _sliced: bool = True,
1191
- epsilon=-1,
1192
- ground_norm=1,
1193
- progress=False,
1194
- grid_reconversion=None,
1195
- scales=None,
1196
- ):
1197
- super().__init__()
1198
- self.n_jobs = n_jobs
1199
- self._SWD_list = None
1200
- self._sliced = _sliced
1201
- self.epsilon = epsilon
1202
- self.ground_norm = ground_norm
1203
- self.num_directions = num_directions
1204
- self.progress = progress
1205
- self.grid_reconversion = grid_reconversion
1206
- self.scales = scales
1207
- return
1208
-
1209
- def fit(self, X, y=None):
1210
- from multipers.ml.sliced_wasserstein import (
1211
- SlicedWassersteinDistance,
1212
- WassersteinDistance,
1213
- )
1214
-
1215
- # _DISTANCE = lambda : SlicedWassersteinDistance(num_directions=self.num_directions) if self._sliced else WassersteinDistance(epsilon=self.epsilon, ground_norm=self.ground_norm) # WARNING if _sliced is false, this distance is not CNSD
1216
- if len(X) == 0:
1217
- return self
1218
- num_degrees = len(X[0])
1219
- self._SWD_list = [
1220
- (
1221
- SlicedWassersteinDistance(
1222
- num_directions=self.num_directions,
1223
- n_jobs=self.n_jobs,
1224
- scales=self.scales,
1225
- )
1226
- if self._sliced
1227
- else WassersteinDistance(
1228
- epsilon=self.epsilon,
1229
- ground_norm=self.ground_norm,
1230
- n_jobs=self.n_jobs,
1231
- )
1232
- )
1233
- for _ in range(num_degrees)
1234
- ]
1235
- for degree, swd in enumerate(self._SWD_list):
1236
- signed_measures_of_degree = [x[degree] for x in X]
1237
- swd.fit(signed_measures_of_degree)
1238
- return self
1239
-
1240
- def transform(self, X):
1241
- assert self._SWD_list is not None, "Fit first"
1242
- # out = []
1243
- # for degree, swd in tqdm(enumerate(self._SWD_list), desc="Computing distance matrices", total=len(self._SWD_list), disable= not self.progress):
1244
- with tqdm(
1245
- enumerate(self._SWD_list),
1246
- desc="Computing distance matrices",
1247
- total=len(self._SWD_list),
1248
- disable=not self.progress,
1249
- ) as SWD_it:
1250
- # signed_measures_of_degree = [x[degree] for x in X]
1251
- # out.append(swd.transform(signed_measures_of_degree))
1252
- def todo(swd, X_of_degree):
1253
- return swd.transform(X_of_degree)
1254
-
1255
- out = Parallel(n_jobs=self.n_jobs, prefer="threads")(
1256
- delayed(todo)(swd, [x[degree] for x in X]) for degree, swd in SWD_it
1257
- )
1258
- return np.asarray(out)
1259
-
1260
- def predict(self, X):
1261
- return self.transform(X)
1262
-
1263
-
1264
- class SignedMeasures2SlicedWassersteinDistances(BaseEstimator, TransformerMixin):
1265
- """
1266
- Transformer from signed measure to distance matrix.
1267
- Input
1268
- -----
1269
- (data) x opt (axis) x (degree) x (signed measure)
1270
-
1271
- Format
1272
- ------
1273
- - a signed measure : tuple of array. (point position) : npts x (num_paramters) and weigths : npts
1274
- - each data is a list of signed measure (for e.g. multiple degrees)
1275
-
1276
- Output
1277
- ------
1278
- - (axis) x (degree) x (distance matrix)
1279
- """
1280
-
1281
- def __init__(
1282
- self,
1283
- progress=False,
1284
- n_jobs: int = 1,
1285
- scales: Iterable[Iterable[float]] | None = None,
1286
- **kwargs,
1287
- ): # same init
1288
- self._init_child = SignedMeasure2SlicedWassersteinDistance(
1289
- progress=False, scales=None, n_jobs=-1, **kwargs
1290
- )
1291
- self._axe_iterator = None
1292
- self._childs_to_fit = None
1293
- self.scales = scales
1294
- self.progress = progress
1295
- self.n_jobs = n_jobs
1296
- return
1297
-
1298
- def fit(self, X, y=None):
1299
- from sklearn.base import clone
1300
-
1301
- if len(X) == 0:
1302
- return self
1303
- if isinstance(X[0][0], tuple): # Meaning that there are no axes
1304
- self._axe_iterator = [slice(None)]
1305
- else:
1306
- self._axe_iterator = range(len(X[0]))
1307
- if self.scales is None:
1308
- self.scales = [None]
1309
- else:
1310
- self.scales = np.asarray(self.scales)
1311
- if self.scales.ndim == 1:
1312
- self.scales = np.asarray([self.scales])
1313
- assert (
1314
- self.scales[0] is None or self.scales.ndim == 2
1315
- ), "Scales have to be either None or a list of scales !"
1316
- self._childs_to_fit = [
1317
- clone(self._init_child).set_params(scales=scales).fit([x[axis] for x in X])
1318
- for axis, scales in product(self._axe_iterator, self.scales)
1319
- ]
1320
- print("New axes : ", list(product(self._axe_iterator, self.scales)))
1321
- return self
1322
-
1323
- def transform(self, X):
1324
- return Parallel(n_jobs=self.n_jobs, prefer="processes")(
1325
- delayed(self._childs_to_fit[child_id].transform)([x[axis] for x in X])
1326
- for child_id, (axis, _) in tqdm(
1327
- enumerate(product(self._axe_iterator, self.scales)),
1328
- desc=f"Computing distances matrices of axis, and scales",
1329
- disable=not self.progress,
1330
- total=len(self._childs_to_fit),
1331
- )
1332
- )
1333
- # [
1334
- # child.transform([x[axis // len(self.scales)] for x in X])
1335
- # for axis, child in tqdm(enumerate(self._childs_to_fit),
1336
- # desc=f"Computing distances of axis", disable=not self.progress, total=len(self._childs_to_fit)
1337
- # )
1338
- # ]
1339
-
1340
-
1341
- class SimplexTree2RectangleDecomposition(BaseEstimator, TransformerMixin):
1342
- """
1343
- Transformer. 2 parameter SimplexTrees to their respective rectangle decomposition.
1344
- """
1345
-
1346
- def __init__(
1347
- self,
1348
- filtration_grid: np.ndarray,
1349
- degrees: Iterable[int],
1350
- plot=False,
1351
- reconvert_grid=True,
1352
- num_collapses: int = 0,
1353
- ):
1354
- super().__init__()
1355
- self.filtration_grid = filtration_grid
1356
- self.degrees = degrees
1357
- self.plot = plot
1358
- self.reconvert_grid = reconvert_grid
1359
- self.num_collapses = num_collapses
1360
- return
1361
-
1362
- def fit(self, X, y=None):
1363
- """
1364
- TODO : infer grid from multiple simplextrees
1365
- """
1366
- return self
1367
-
1368
- def transform(self, X: Iterable[mp.simplex_tree_multi.SimplexTreeMulti_type]):
1369
- rectangle_decompositions = [
1370
- [
1371
- _st2ranktensor(
1372
- simplextree,
1373
- filtration_grid=self.filtration_grid,
1374
- degree=degree,
1375
- plot=self.plot,
1376
- reconvert_grid=self.reconvert_grid,
1377
- num_collapse=self.num_collapses,
1378
- )
1379
- for degree in self.degrees
1380
- ]
1381
- for simplextree in X
1382
- ]
1383
- # TODO : return iterator ?
1384
- return rectangle_decompositions
1385
-
1386
-
1387
- def _st2ranktensor(
1388
- st: mp.simplex_tree_multi.SimplexTreeMulti_type,
1389
- filtration_grid: np.ndarray,
1390
- degree: int,
1391
- plot: bool,
1392
- reconvert_grid: bool,
1393
- num_collapse: int | str = 0,
1394
- ):
1395
- """
1396
- TODO
1397
- """
1398
- # Copy (the squeeze change the filtration values)
1399
- # stcpy = mp.SimplexTreeMulti(st)
1400
- # turns the simplextree into a coordinate simplex tree
1401
- stcpy = st.grid_squeeze(filtration_grid=filtration_grid, coordinate_values=True)
1402
- # stcpy.collapse_edges(num=100, strong = True, ignore_warning=True)
1403
- if num_collapse == "full":
1404
- stcpy.collapse_edges(full=True, ignore_warning=True, max_dimension=degree + 1)
1405
- elif isinstance(num_collapse, int):
1406
- stcpy.collapse_edges(
1407
- num=num_collapse, ignore_warning=True, max_dimension=degree + 1
1408
- )
1409
- else:
1410
- raise TypeError(
1411
- f"Invalid num_collapse=\
1412
- {num_collapse} type. Either full, or an integer."
1413
- )
1414
- # computes the rank invariant tensor
1415
- rank_tensor = mp.rank_invariant2d(
1416
- stcpy, degree=degree, grid_shape=[len(f) for f in filtration_grid]
1417
- )
1418
- # refactor this tensor into the rectangle decomposition of the signed betti
1419
- grid_conversion = filtration_grid if reconvert_grid else None
1420
- rank_decomposition = rank_decomposition_by_rectangles(
1421
- rank_tensor,
1422
- threshold=True,
1423
- )
1424
- rectangle_decomposition = tensor_möbius_inversion(
1425
- tensor=rank_decomposition,
1426
- grid_conversion=grid_conversion,
1427
- plot=plot,
1428
- num_parameters=st.num_parameters,
1429
- )
1430
- return rectangle_decomposition
1431
-
1432
-
1433
- class DegreeRips2SignedMeasure(BaseEstimator, TransformerMixin):
1434
- def __init__(
1435
- self,
1436
- degrees: Iterable[int],
1437
- min_rips_value: float,
1438
- max_rips_value,
1439
- max_normalized_degree: float,
1440
- min_normalized_degree: float,
1441
- grid_granularity: int,
1442
- progress: bool = False,
1443
- n_jobs=1,
1444
- sparse: bool = False,
1445
- _möbius_inversion=True,
1446
- fit_fraction=1,
1447
- ) -> None:
1448
- super().__init__()
1449
- self.min_rips_value = min_rips_value
1450
- self.max_rips_value = max_rips_value
1451
- self.min_normalized_degree = min_normalized_degree
1452
- self.max_normalized_degree = max_normalized_degree
1453
- self._max_rips_value = None
1454
- self.grid_granularity = grid_granularity
1455
- self.progress = progress
1456
- self.n_jobs = n_jobs
1457
- self.degrees = degrees
1458
- self.sparse = sparse
1459
- self._möbius_inversion = _möbius_inversion
1460
- self.fit_fraction = fit_fraction
1461
- return
1462
-
1463
- def fit(self, X: np.ndarray | list, y=None):
1464
- if self.max_rips_value < 0:
1465
- print("Estimating scale...", flush=True, end="")
1466
- indices = np.random.choice(
1467
- len(X), min(len(X), int(self.fit_fraction * len(X)) + 1), replace=False
1468
- )
1469
- diameters = np.max(
1470
- [distance_matrix(x, x).max() for x in (X[i] for i in indices)]
1471
- )
1472
- print(f"Done. {diameters}", flush=True)
1473
- self._max_rips_value = (
1474
- -self.max_rips_value * diameters
1475
- if self.max_rips_value < 0
1476
- else self.max_rips_value
1477
- )
1478
- return self
1479
-
1480
- def _transform1(self, data: np.ndarray):
1481
- _distance_matrix = distance_matrix(data, data)
1482
- signed_measures = []
1483
- (
1484
- rips_values,
1485
- normalized_degree_values,
1486
- hilbert_functions,
1487
- minimal_presentations,
1488
- ) = hf_degree_rips(
1489
- _distance_matrix,
1490
- min_rips_value=self.min_rips_value,
1491
- max_rips_value=self._max_rips_value,
1492
- min_normalized_degree=self.min_normalized_degree,
1493
- max_normalized_degree=self.max_normalized_degree,
1494
- grid_granularity=self.grid_granularity,
1495
- max_homological_dimension=np.max(self.degrees),
1496
- )
1497
- for degree in self.degrees:
1498
- hilbert_function = hilbert_functions[degree]
1499
- signed_measure = (
1500
- signed_betti(hilbert_function, threshold=True)
1501
- if self._möbius_inversion
1502
- else hilbert_function
1503
- )
1504
- if self.sparse:
1505
- signed_measure = tensor_möbius_inversion(
1506
- tensor=signed_measure,
1507
- num_parameters=2,
1508
- grid_conversion=[rips_values, normalized_degree_values],
1509
- )
1510
- if not self._möbius_inversion:
1511
- signed_measure = signed_measure.flatten()
1512
- signed_measures.append(signed_measure)
1513
- return signed_measures
1514
-
1515
- def transform(self, X):
1516
- return Parallel(n_jobs=self.n_jobs)(
1517
- delayed(self._transform1)(data)
1518
- for data in tqdm(X, desc=f"Computing DegreeRips, of degrees {self.degrees}")
1519
- )
1520
-
1521
-
1522
- def tensor_möbius_inversion(
1523
- tensor,
1524
- grid_conversion: Iterable[np.ndarray] | None = None,
1525
- plot: bool = False,
1526
- raw: bool = False,
1527
- num_parameters: int | None = None,
1528
- ):
1529
- from torch import Tensor
1530
-
1531
- betti_sparse = Tensor(tensor.copy()).to_sparse() # Copy necessary in some cases :(
1532
- num_indices, num_pts = betti_sparse.indices().shape
1533
- num_parameters = num_indices if num_parameters is None else num_parameters
1534
- if num_indices == num_parameters: # either hilbert or rank invariant
1535
- rank_invariant = False
1536
- elif 2 * num_parameters == num_indices:
1537
- rank_invariant = True
1538
- else:
1539
- raise TypeError(
1540
- f"Unsupported betti shape. {num_indices}\
1541
- has to be either {num_parameters} or \
1542
- {2*num_parameters}."
1543
- )
1544
- points_filtration = np.asarray(betti_sparse.indices().T, dtype=int)
1545
- weights = np.asarray(betti_sparse.values(), dtype=int)
1546
-
1547
- if grid_conversion is not None:
1548
- coords = np.empty(shape=(num_pts, num_indices), dtype=float)
1549
- for i in range(num_indices):
1550
- coords[:, i] = grid_conversion[i % num_parameters][points_filtration[:, i]]
1551
- else:
1552
- coords = points_filtration
1553
- if (not rank_invariant) and plot:
1554
- plt.figure()
1555
- color_weights = np.empty(weights.shape)
1556
- color_weights[weights > 0] = np.log10(weights[weights > 0]) + 2
1557
- color_weights[weights < 0] = -np.log10(-weights[weights < 0]) - 2
1558
- plt.scatter(
1559
- points_filtration[:, 0],
1560
- points_filtration[:, 1],
1561
- c=color_weights,
1562
- cmap="coolwarm",
1563
- )
1564
- if (not rank_invariant) or raw:
1565
- return coords, weights
1566
-
1567
- def _is_trivial(rectangle: np.ndarray):
1568
- birth = rectangle[:num_parameters]
1569
- death = rectangle[num_parameters:]
1570
- return np.all(birth <= death) # and not np.array_equal(birth,death)
1571
-
1572
- correct_indices = np.array([_is_trivial(rectangle) for rectangle in coords])
1573
- if len(correct_indices) == 0:
1574
- return np.empty((0, num_indices)), np.empty((0))
1575
- signed_measure = np.asarray(coords[correct_indices])
1576
- weights = weights[correct_indices]
1577
- if plot:
1578
- # plot only the rank decompo for the moment
1579
- assert signed_measure.shape[1] == 4
1580
-
1581
- def _plot_rectangle(rectangle: np.ndarray, weight: float):
1582
- x_axis = rectangle[[0, 2]]
1583
- y_axis = rectangle[[1, 3]]
1584
- color = "blue" if weight > 0 else "red"
1585
- plt.plot(x_axis, y_axis, c=color)
1586
-
1587
- for rectangle, weight in zip(signed_measure, weights):
1588
- _plot_rectangle(rectangle=rectangle, weight=weight)
1589
- return signed_measure, weights
1
+ from collections.abc import Iterable, Sequence
2
+ from itertools import product
3
+ from typing import Optional, Union
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from joblib import Parallel, delayed
8
+ from sklearn.base import BaseEstimator, TransformerMixin
9
+ from tqdm import tqdm
10
+
11
+ import multipers as mp
12
+ from multipers.grids import compute_grid as reduce_grid
13
+ from multipers.filtrations.density import available_kernels, convolution_signed_measures
14
+ from multipers.point_measure import rank_decomposition_by_rectangles, signed_betti
15
+
16
+
17
class FilteredComplex2SignedMeasure(BaseEstimator, TransformerMixin):
    """
    Compute signed measures from multi-parameter filtered complexes.

    Input
    -----
    ``X[i][ax]`` is a filtered complex (a SimplexTreeMulti or a slicer):
    first index is the data point, second index the axis.

    Output
    ------
    tuple over data of tuples over axes of lists of signed measures: one
    entry per element of ``degrees`` (in order), followed by the entries
    computed for ``rank_degrees``.

    Parameters
    ----------
    - degrees : list of homological degrees to compute. ``None`` stands for
      the Euler characteristic.
    - rank_degrees : degrees for which the rank invariant is computed.
    - filtration_grid : the grid on which to compute, indexed as
      ``filtration_grid[axis][parameter]``. If None, ``fit`` may infer it.
    - progress : stored on the estimator; not read by the methods below.
    - num_collapses : an int or "full"; number of edge collapses applied to
      2-parameter simplex trees before computing.
    - n_jobs : number of joblib (threading) workers used by ``transform``.
    - resolution : resolution of the grid to infer when ``filtration_grid``
      is not given. An int is broadcast to every parameter at fit time.
    - plot : forwarded to ``mp.signed_measure``; also triggers ``plt.show()``.
    - filtration_quantile : stored on the estimator; not read by the methods
      below.
    - expand : expand the simplex tree before computing (flag complexes).
      Not available for slicers.
    - normalize_filtrations : if True (and a shared grid is used), output
      coordinates are expressed as ``(f - min(f)) / std(f)``.
    - grid_strategy : strategy name forwarded to the grid computation.
    - seed : seed of the sub-sampling used when inferring the grid.
    - fit_fraction : fraction of the data used to infer the grid.
    - out_resolution : defaults to ``resolution`` at fit time.
    - individual_grid : compute one grid per complex instead of a shared
      one. If None, it is enabled for the strategies "regular_closest",
      "exact", "quantile" and "partition".
    - enforce_null_mass : pass the last value of each grid axis as
      ``mass_default`` to ``mp.signed_measure``.
    - flatten : stored on the estimator; not read by the methods below.
    - backend : forwarded to ``mp.signed_measure``.
    """

    def __init__(
        self,
        # homological degrees + None for euler
        # NOTE: the list defaults are kept for interface compatibility;
        # they are only read, never mutated, by this class.
        degrees: list[int | None] = [],
        rank_degrees: list[int] = [],  # same for rank invariant
        filtration_grid: (
            Sequence[Sequence[np.ndarray]]
            # filtration values to consider. Format : [ filtration values of Fi for Fi:filtration values of parameter i]
            | None
        ) = None,
        progress=False,  # tqdm
        num_collapses: int | str = 0,  # edge collapses before computing
        n_jobs=None,
        resolution: (
            Iterable[int] | int | None
        ) = None,  # when filtration grid is not given, the resolution of the filtration grid to infer
        plot: bool = False,
        filtration_quantile: float = 0.0,  # quantile for inferring filtration grid
        expand=False,  # expand the simplextree before computing the homology
        normalize_filtrations: bool = False,
        grid_strategy: str = "exact",
        seed: int = 0,  # if fit_fraction is not 1, the seed sampling
        fit_fraction=1,  # the fraction of the data on which to fit
        out_resolution: Iterable[int] | int | None = None,
        individual_grid: Optional[
            bool
        ] = None,  # Can be significantly faster for some grid strategies, but can drop statistical performance
        enforce_null_mass: bool = False,
        flatten=True,
        backend: Optional[str] = None,
    ):
        super().__init__()
        self.degrees = degrees
        self.rank_degrees = rank_degrees
        self.filtration_grid = filtration_grid
        self.progress = progress
        self.num_collapses = num_collapses
        self.n_jobs = n_jobs
        self.resolution = resolution
        self.plot = plot
        self.backend = backend
        self.filtration_quantile = filtration_quantile
        # Will only work for non sparse output. (discrete matrices cannot be "rescaled")
        self.normalize_filtrations = normalize_filtrations
        self.grid_strategy = grid_strategy
        self._reconversion_grid = None
        self.expand = expand
        # will only refit the grid if filtration_grid has never been given.
        self._refit_grid = None
        self.seed = seed
        self.fit_fraction = fit_fraction
        self._transform_st = None
        self.out_resolution = out_resolution
        self.individual_grid = individual_grid
        self.enforce_null_mass = enforce_null_mass
        self._default_mass_location = None
        self.flatten = flatten
        self.num_parameters: int = 0

    @staticmethod
    def _is_filtered_complex(input):
        """Return True if ``input`` is a SimplexTreeMulti or a slicer."""
        return mp.simplex_tree_multi.is_simplextree_multi(input) or mp.slicer.is_slicer(
            input, allow_minpres=True
        )

    def _input_checks(self, X):
        """Validate ``X`` and record the number of axes in ``self._num_axis``."""
        assert len(X) > 0, "No filtered complex found. Cannot fit."
        assert self._is_filtered_complex(
            X[0][0]
        ), f"X[0] is not a known filtered complex, {X[0]=}, nor X[0][0]."
        self._num_axis = len(X[0])
        first = X[0][0]
        assert (
            not mp.slicer.is_slicer(first) or not self.expand
        ), "Cannot expand slicers."
        assert not mp.slicer.is_slicer(first) or not (
            isinstance(first, (tuple, list)) and first[0].is_minpres
        ), "Multi-degree minpres are not supported yet as an input. This can still be computed by providing a backend."

    def _infer_filtration(self, X):
        """Infer one shared grid per axis from a random subset of ``X``.

        The subset has ``int(fit_fraction * len(X)) + 1`` elements (capped at
        ``len(X)``) and is drawn without replacement from a generator seeded
        with ``self.seed``. The result is stored in ``self.filtration_grid``
        and returned.
        """
        self.num_parameters = X[0][0].num_parameters
        sample_size = min(int(self.fit_fraction * len(X)) + 1, len(X))
        # Seeded generator so that `seed` actually controls the sampling.
        rng = np.random.default_rng(self.seed)
        indices = rng.choice(len(X), sample_size, replace=False)
        ## ax, num_x
        filtrations = tuple(
            tuple(reduce_grid(X[idx][ax], strategy="exact") for idx in indices)
            for ax in range(self._num_axis)
        )
        num_parameters = len(filtrations[0][0])
        assert (
            num_parameters == self.num_parameters
        ), f"Internal error, got {num_parameters=} and {self.num_parameters=}"

        # Merge, per axis and per parameter, the values seen in the sample.
        filtrations_values = [
            [
                np.unique(np.concatenate([x[i] for x in filtrations[ax]]))
                for i in range(num_parameters)
            ]
            for ax in range(self._num_axis)
        ]
        ## ax, param, gridsize
        filtration_grid = tuple(
            reduce_grid(
                filtrations_values[ax],
                resolution=self.resolution,
                strategy=self.grid_strategy,
            )
            for ax in range(self._num_axis)
        )  # TODO :use more parameters
        self.filtration_grid = filtration_grid
        return filtration_grid

    def _params_check(self):
        """Check that enough information is available to build a grid."""
        assert (
            self.resolution is not None
            or self.filtration_grid is not None
            or self.grid_strategy == "exact"
            or self.individual_grid
        ), "For non exact filtrations, a resolution has to be specified."

    def fit(self, X, y=None):  # Todo : infer filtration grid ? quantiles ?
        """Validate the input and prepare the grids used by ``transform``.

        Returns ``self``.
        """
        self._params_check()
        self._input_checks(X)

        # The number of parameters has to be known *before* broadcasting an
        # integer resolution; it used to still be 0 at this point, which
        # turned the resolution into an empty list.
        self.num_parameters = X[0][0].num_parameters
        if isinstance(self.resolution, (int, np.integer)):
            self.resolution = [int(self.resolution)] * self.num_parameters

        if self.individual_grid is None:
            self.individual_grid = self.grid_strategy in [
                "regular_closest",
                "exact",
                "quantile",
                "partition",
            ]

        # No shared grid has to be inferred when each complex gets its own
        # grid (and no null mass is enforced), or when a grid was provided.
        grid_available = (
            not self.enforce_null_mass and self.individual_grid
        ) or self.filtration_grid is not None
        self._refit_grid = not grid_available

        if self._refit_grid:
            self._infer_filtration(X=X)
        if self.out_resolution is None:
            self.out_resolution = self.resolution
        if self.normalize_filtrations and not self.individual_grid:
            # not the best, but better than some weird magic
            self._reconversion_grid = [
                [(f - np.min(f)) / np.std(f) for f in F] for F in self.filtration_grid
            ]
        else:
            self._reconversion_grid = self.filtration_grid
        ## ax, num_param
        self._default_mass_location = (
            np.asarray([[g[-1] for g in F] for F in self.filtration_grid])
            if self.enforce_null_mass
            else None
        )
        return self

    def transform1(
        self,
        simplextree,
        ax,
        thread_id: str = "",
    ):
        """Compute the signed measures of one filtered complex on axis ``ax``.

        Returns a list with one signed measure per entry of ``self.degrees``
        followed by the rank signed measures of ``self.rank_degrees``.
        """
        mass_default = (
            self._default_mass_location[ax] if self.enforce_null_mass else None
        )
        if self.individual_grid:
            filtration_grid = reduce_grid(
                simplextree, strategy=self.grid_strategy, resolution=self.resolution
            )
            if self.enforce_null_mass:
                # Append the default mass location to each grid axis.
                filtration_grid = [
                    np.concatenate([f, [d]], axis=0)
                    for f, d in zip(filtration_grid, mass_default)
                ]
            _reconversion_grid = filtration_grid
        else:
            filtration_grid = self.filtration_grid[ax]
            _reconversion_grid = self._reconversion_grid[ax]

        st = simplextree.grid_squeeze(filtration_grid=filtration_grid)
        if st.num_parameters == 2 and mp.simplex_tree_multi.is_simplextree_multi(st):
            st.collapse_edges(num=self.num_collapses, max_dimension=1)
        int_degrees = np.asarray([d for d in self.degrees if d is not None], dtype=int)
        has_euler = None in self.degrees
        has_int_degrees = len(int_degrees) > 0
        has_rank_degrees = len(self.rank_degrees) > 0

        # EULER. First as there is prune above dimension below
        if self.expand and has_euler:
            st.expansion(st.num_vertices)
        if has_euler:
            signed_measures_euler = mp.signed_measure(
                st,
                degrees=[None],
                plot=self.plot,
                mass_default=mass_default,
                invariant="euler",
                backend=self.backend,
                grid=_reconversion_grid,
            )[0]
        else:
            signed_measures_euler = []

        # HILBERT
        if self.expand and has_int_degrees:
            st.expansion(np.max(int_degrees) + 1)
        if has_int_degrees:
            # Cast to int: concatenating with an empty `rank_degrees` list
            # silently promotes the degrees to float.
            top_degree = int(np.max(np.concatenate([int_degrees, self.rank_degrees])))
            # no need to compute homology beyond this
            st.prune_above_dimension(top_degree + 1)
            signed_measures_pers = mp.signed_measure(
                st,
                degrees=int_degrees,
                mass_default=mass_default,
                plot=self.plot,
                invariant="hilbert",
                thread_id=thread_id,
                backend=self.backend,
                grid=_reconversion_grid,
            )
        else:
            signed_measures_pers = []
        if self.plot:
            plt.show()

        # RANK
        if self.expand and has_rank_degrees:
            st.expansion(np.max(self.rank_degrees) + 1)
        if has_rank_degrees:
            # no need to compute homology beyond this
            st.prune_above_dimension(int(np.max(self.rank_degrees)) + 1)
            signed_measures_rank = mp.signed_measure(
                st,
                degrees=self.rank_degrees,
                mass_default=mass_default,
                plot=self.plot,
                invariant="rank",
                thread_id=thread_id,
                backend=self.backend,
                grid=_reconversion_grid,
            )
        else:
            signed_measures_rank = []
        if self.plot:
            plt.show()

        # Re-interleave euler / hilbert results following `self.degrees`.
        pers_iterator = iter(signed_measures_pers)
        signed_measures = [
            signed_measures_euler if d is None else next(pers_iterator)
            for d in self.degrees
        ]
        signed_measures += signed_measures_rank
        return signed_measures

    def transform(self, X):
        """Compute signed measures for every (data, axis) pair of ``X``.

        ``X`` is indexed as (num_x, num_axis); the output is a tuple indexed
        as (num_x, num_axis, degree).
        """
        assert (
            self.filtration_grid is not None and self._reconversion_grid is not None
        ) or self.individual_grid, "Fit first"

        def todo_x(x):
            return tuple(self.transform1(x_axis, j) for j, x_axis in enumerate(x))

        ## out shape num_x, num_axis, degree, sm
        out = tuple(
            Parallel(n_jobs=self.n_jobs, backend="threading")(
                delayed(todo_x)(x) for x in X
            )
        )
        return out
361
+
362
+
363
class SimplexTree2SignedMeasure(FilteredComplex2SignedMeasure):
    """Deprecated alias of ``FilteredComplex2SignedMeasure``.

    It takes exactly the same constructor arguments, forwards them to the
    parent class, and emits a warning on construction.
    """

    def __init__(
        self,
        # homological degrees + None for euler
        degrees: list[int | None] = [],
        rank_degrees: list[int] = [],  # same for rank invariant
        filtration_grid: (
            Sequence[Sequence[np.ndarray]]
            # filtration values to consider. Format : [ filtration values of Fi for Fi:filtration values of parameter i]
            | None
        ) = None,
        progress=False,  # tqdm
        num_collapses: int | str = 0,  # edge collapses before computing
        n_jobs=None,
        resolution: (
            Iterable[int] | int | None
        ) = None,  # when filtration grid is not given, the resolution of the filtration grid to infer
        plot: bool = False,
        filtration_quantile: float = 0.0,  # quantile for inferring filtration grid
        expand=False,  # expand the simplextree before computing the homology
        normalize_filtrations: bool = False,
        grid_strategy: str = "exact",
        seed: int = 0,  # if fit_fraction is not 1, the seed sampling
        fit_fraction=1,  # the fraction of the data on which to fit
        out_resolution: Iterable[int] | int | None = None,
        individual_grid: Optional[
            bool
        ] = None,  # Can be significantly faster for some grid strategies, but can drop statistical performance
        enforce_null_mass: bool = False,
        flatten=True,
        backend: Optional[str] = None,
    ):
        # Snapshot the arguments before any other local name is created.
        snapshot = dict(locals())
        # Drop `self` and dunder entries (e.g. the `__class__` cell that the
        # zero-argument `super()` below adds to the local namespace).
        forwarded = {
            name: value
            for name, value in snapshot.items()
            if name != "self" and not name.startswith("__")
        }
        super().__init__(**forwarded)
        from warnings import warn

        warn("This class is deprecated, use FilteredComplex2SignedMeasure instead.")
409
+
410
+
411
+ # class SimplexTrees2SignedMeasures(SimplexTree2SignedMeasure):
412
+ # """
413
+ # Input
414
+ # -----
415
+ #
416
+ # (data) x (axis, e.g. different bandwidths for simplextrees) x (simplextree)
417
+ #
418
+ # Output
419
+ # ------
420
+ # (data) x (axis) x (degree) x (signed measure)
421
+ # """
422
+ #
423
+ # def __init__(self, **kwargs):
424
+ # super().__init__(**kwargs)
425
+ # self._num_st_per_data = None
426
+ # # self._super_model=SimplexTree2SignedMeasure(**kwargs)
427
+ # self._filtration_grids = None
428
+ # return
429
+ #
430
+ # def fit(self, X, y=None):
431
+ # if len(X) == 0:
432
+ # return self
433
+ # try:
434
+ # self._num_st_per_data = len(X[0])
435
+ # except:
436
+ # raise Exception(
437
+ # "Shape has to be (num_data, num_axis), dtype=SimplexTreeMulti"
438
+ # )
439
+ # self._filtration_grids = []
440
+ # for axis in range(self._num_st_per_data):
441
+ # self._filtration_grids.append(
442
+ # super().fit([x[axis] for x in X]).filtration_grid
443
+ # )
444
+ # # self._super_fits.append(truc)
445
+ # # self._super_fits_params = [super().fit([x[axis] for x in X]).get_params() for axis in range(self._num_st_per_data)]
446
+ # return self
447
+ #
448
+ # def transform(self, X):
449
+ # if self.normalize_filtrations:
450
+ # _reconversion_grids = [
451
+ # [np.linspace(0, 1, num=len(f), dtype=float) for f in F]
452
+ # for F in self._filtration_grids
453
+ # ]
454
+ # else:
455
+ # _reconversion_grids = self._filtration_grids
456
+ #
457
+ # def todo(x):
458
+ # # return [SimplexTree2SignedMeasure().set_params(**transformer_params).transform1(x[axis]) for axis,transformer_params in enumerate(self._super_fits_params)]
459
+ # out = [
460
+ # self.transform1(
461
+ # x[axis],
462
+ # filtration_grid=filtration_grid,
463
+ # _reconversion_grid=_reconversion_grid,
464
+ # )
465
+ # for axis, filtration_grid, _reconversion_grid in zip(
466
+ # range(self._num_st_per_data),
467
+ # self._filtration_grids,
468
+ # _reconversion_grids,
469
+ # )
470
+ # ]
471
+ # return out
472
+ #
473
+ # return Parallel(n_jobs=self.n_jobs, backend="threading")(
474
+ # delayed(todo)(x)
475
+ # for x in tqdm(
476
+ # X,
477
+ # disable=not self.progress,
478
+ # desc="Computing Signed Measures from simplextrees.",
479
+ # )
480
+ # )
481
+
482
+
483
def rescale_sparse_signed_measure(
    signed_measure, filtration_weights, normalize_scales=None
):
    """Rescale the point coordinates of sparse signed measures.

    Parameters
    ----------
    signed_measure : sequence over degrees of ``(points, weights)`` pairs.
    filtration_weights : per-parameter multiplicative factors, or None.
        Anything exposing ``reshape`` is used as is; other sequences (list,
        tuple) are first converted with ``np.asarray``.
    normalize_scales : optional; ``normalize_scales[degree]`` holds the
        per-parameter divisors of that degree and must expose ``reshape``.

    Returns
    -------
    The input object itself when both ``filtration_weights`` and
    ``normalize_scales`` are None; otherwise a new tuple of
    ``(points * weights / scales, weights)`` pairs. The point weights are
    passed through unchanged and the inputs are not modified in place.
    """
    if filtration_weights is None and normalize_scales is None:
        return signed_measure

    if filtration_weights is None:
        weight_row = 1
    else:
        if not hasattr(filtration_weights, "reshape"):
            # Plain sequences have no `reshape`; this used to crash.
            filtration_weights = np.asarray(filtration_weights)
        # Row vector, so that it broadcasts over the points.
        weight_row = filtration_weights.reshape(1, -1)

    def _scale_row(degree):
        if normalize_scales is None:
            return 1
        return normalize_scales[degree].reshape(1, -1)

    return tuple(
        (sm[0] * weight_row / _scale_row(degree), sm[1])
        for degree, sm in enumerate(signed_measure)
    )
547
+
548
+
549
+ class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
550
+ """
551
+ Input
552
+ -----
553
+
554
+ (data) x (degree) x (signed measure) or (data) x (axis) x (degree) x (signed measure)
555
+
556
+ Iterable[list[signed_measure_matrix of degree]] or Iterable[previous].
557
+
558
+ The second is meant to use multiple choices for signed measure input. An example of usage : they come from a Rips + Density with different bandwidth.
559
+ It is controlled by the axis parameter.
560
+
561
+ Output
562
+ ------
563
+
564
+ Iterable[list[(reweighted)_sparse_signed_measure of degree]]
565
+
566
+ or (deep format)
567
+
568
+ Tensor of shape (num_axis*num_degrees, data, max_num_pts, num_parameters)
569
+ """
570
+
571
+ def __init__(
572
+ self,
573
+ filtrations_weights: Optional[Iterable[float]] = None,
574
+ normalize=False,
575
+ plot: bool = False,
576
+ unsparse: bool = False,
577
+ axis: int = -1,
578
+ resolution: int | Iterable[int] = 50,
579
+ flatten: bool = False,
580
+ deep_format: bool = False,
581
+ unrag: bool = True,
582
+ n_jobs: int = 1,
583
+ verbose: bool = False,
584
+ integrate: bool = False,
585
+ grid_strategy="regular",
586
+ ):
587
+ super().__init__()
588
+ self.filtrations_weights = filtrations_weights
589
+ self.num_parameters: int = 0
590
+ self.plot = plot
591
+ self.unsparse = unsparse
592
+ self.n_jobs = n_jobs
593
+ self.axis = axis
594
+ self._num_axis = 0
595
+ self.resolution = resolution
596
+ self._filtrations_bounds = None
597
+ self.flatten = flatten
598
+ self.normalize = normalize
599
+ self._normalization_factors = None
600
+ self.deep_format = deep_format
601
+ self.unrag = unrag
602
+ assert (
603
+ not self.deep_format or not self.unsparse or not self.integrate
604
+ ), "One post processing at the time."
605
+ self.verbose = verbose
606
+ self._num_degrees = 0
607
+ self.integrate = integrate
608
+ self.grid_strategy = grid_strategy
609
+ self._infered_grids = None
610
+ self._axis_iterator = None
611
+ self._backend = None
612
+ return
613
+
614
+ def _get_filtration_bounds(self, X, axis):
615
+ if self._backend == "numpy":
616
+ _cat = np.concatenate
617
+
618
+ else:
619
+ ## torch is globally imported
620
+ _cat = torch.cat
621
+ stuff = [
622
+ _cat(
623
+ [sm[axis][degree][0] for sm in X],
624
+ axis=0,
625
+ )
626
+ for degree in range(self._num_degrees)
627
+ ]
628
+ sizes_ = np.array([len(x) == 0 for x in stuff])
629
+ assert np.all(~sizes_), f"Degree axis {np.where(sizes_)} is/are trivial !"
630
+ if self._backend == "numpy":
631
+ filtrations_bounds = np.array(
632
+ [([f.min(axis=0), f.max(axis=0)]) for f in stuff]
633
+ )
634
+ else:
635
+ filtrations_bounds = torch.stack(
636
+ [
637
+ torch.stack([f.min(axis=0).values, f.max(axis=0).values])
638
+ for f in stuff
639
+ ]
640
+ ).detach() ## don't want to rescale gradient of normalization
641
+ normalization_factors = (
642
+ filtrations_bounds[:, 1] - filtrations_bounds[:, 0]
643
+ if self.normalize
644
+ else None
645
+ )
646
+ # print("Normalization factors : ",self._normalization_factors)
647
+ if (normalization_factors == 0).any():
648
+ indices = normalization_factors == 0
649
+ # warn(f"Constant filtration encountered, at degree, parameter {indices} and axis {self.axis}.")
650
+ normalization_factors[indices] = 1
651
+ return filtrations_bounds, normalization_factors
652
+
653
def _plot_signed_measures(self, sms: Iterable[np.ndarray], size=4):
    """Draw a grid of plots: one row per element of ``sms``, one column per degree.

    The number of columns is taken from the first element of ``sms``;
    ``size`` is the side length given to each cell of the figure.
    """
    from multipers.plots import plot_signed_measure

    n_cols = len(sms[0])
    n_rows = len(sms)
    _, axes = plt.subplots(
        ncols=n_cols,
        nrows=n_rows,
        figsize=(size * n_cols, size * n_rows),
    )
    # Force a 2D array of axes, whatever `plt.subplots` returned.
    axes = np.asarray(axes).reshape(n_rows, n_cols)
    for row, measures in enumerate(sms):
        for col, measure in enumerate(measures):
            plot_signed_measure(measure, ax=axes[row, col])
668
+
669
+ @staticmethod
670
+ def _check_sm(sm) -> bool:
671
+ return (
672
+ isinstance(sm, tuple)
673
+ and hasattr(sm[0], "ndim")
674
+ and sm[0].ndim == 2
675
+ and len(sm) == 2
676
+ )
677
+
678
+ def _check_axis(self, X):
679
+ # axes should be (num_data, num_axis, num_degrees, (signed_measure))
680
+ if len(X) == 0:
681
+ return
682
+ if len(X[0]) == 0:
683
+ return
684
+ if self._check_sm(X[0][0]):
685
+ self._has_axis = False
686
+ self._num_axis = 1
687
+ self._axis_iterator = [slice(None)]
688
+ return
689
+ assert self._check_sm( ## vaguely checks that its a signed measure
690
+ _sm := X[0][0][0]
691
+ ), f"Cannot take this input. # data, axis, degrees, sm.\n Got {_sm} of type {type(_sm)}"
692
+
693
+ self._has_axis = True
694
+ self._num_axis = len(X[0])
695
+ self._axis_iterator = range(self._num_axis) if self.axis == -1 else [self.axis]
696
+
697
+ def _check_backend(self, X):
698
+ if self._has_axis:
699
+ # data, axis, degrees, (pts, weights)
700
+ first_sm = X[0][0][0][0]
701
+ else:
702
+ first_sm = X[0][0][0]
703
+ if isinstance(first_sm, np.ndarray):
704
+ self._backend = "numpy"
705
+ else:
706
+ global torch
707
+ import torch
708
+
709
+ assert isinstance(first_sm, torch.Tensor)
710
+ self._backend = "pytorch"
711
+
712
+ def _check_measures(self, X):
713
+ if self._has_axis:
714
+ first_sm = X[0][0]
715
+ else:
716
+ first_sm = X[0]
717
+ self._num_degrees = len(first_sm)
718
+ self.num_parameters = first_sm[0][0].shape[1]
719
+
720
+ def _check_resolution(self):
721
+ assert self.num_parameters > 0, "Num parameters hasn't been initialized."
722
+ if isinstance(self.resolution, int):
723
+ self.resolution = [self.resolution] * self.num_parameters
724
+ self.resolution = np.asarray(self.resolution, dtype=int)
725
+ assert (
726
+ self.resolution.shape[0] == self.num_parameters
727
+ ), "Resolution doesn't have a proper size."
728
+
729
+ def _check_weights(self):
730
+ if self.filtrations_weights is None:
731
+ return
732
+ assert (
733
+ self.filtrations_weights.shape[0] == self.num_parameters
734
+ ), "Filtration weights don't have a proper size"
735
+
736
+ def _infer_grids(self, X):
737
+ # Computes normalization factors
738
+ if self.normalize:
739
+ # if self._has_axis and self.axis == -1:
740
+ self._filtrations_bounds = []
741
+ self._normalization_factors = []
742
+ for ax in self._axis_iterator:
743
+ (
744
+ filtration_bounds,
745
+ normalization_factors,
746
+ ) = self._get_filtration_bounds(X, axis=ax)
747
+ self._filtrations_bounds.append(filtration_bounds)
748
+ self._normalization_factors.append(normalization_factors)
749
+ # else:
750
+ # (
751
+ # self._filtrations_bounds,
752
+ # self._normalization_factors,
753
+ # ) = self._get_filtration_bounds(
754
+ # X, axis=self._axis_iterator[0]
755
+ # ) ## axis = slice(None)
756
+ elif self.integrate or self.unsparse or self.deep_format:
757
+ filtration_values = [
758
+ np.concatenate(
759
+ [
760
+ (
761
+ stuff
762
+ if isinstance(stuff := x[ax][degree][0], np.ndarray)
763
+ else stuff.detach().numpy()
764
+ )
765
+ for x in X
766
+ for degree in range(self._num_degrees)
767
+ ]
768
+ )
769
+ for ax in self._axis_iterator
770
+ ]
771
+ # axis, filtration_values
772
+ filtration_values = [
773
+ reduce_grid(
774
+ f_ax.T, resolution=self.resolution, strategy=self.grid_strategy
775
+ )
776
+ for f_ax in filtration_values
777
+ ]
778
+ self._infered_grids = filtration_values
779
+
780
def _print_stats(self, X):
    """Print the fitted parameters and size statistics of the measures in ``X``."""
    print("------------SignedMeasureFormatter------------")
    print("---- Parameters")
    print(f"Number of axis : {self._num_axis}")
    print(f"Number of degrees : {self._num_degrees}")
    print(f"Filtration bounds : \n{self._filtrations_bounds}")
    print(f"Normalization factor : \n{self._normalization_factors}")
    grids = self._infered_grids
    if grids is not None:
        print(
            f"Filtration grid shape : \n \
            {tuple(tuple(len(f) for f in F) for F in grids)}"
        )
    print("---- SM stats")
    print("In axis :", self._num_axis)
    # (axis) x (data) x (degree): number of weights of each signed measure
    sizes = []
    for ax in self._axis_iterator:
        sizes.append([[len(sm[1]) for sm in x[ax]] for x in X])
    print(f"Size means (axis) x (degree): {np.mean(sizes, axis=(1))}")
    print(f"Size std : {np.std(sizes, axis=(1))}")
    print("----------------------------------------------")
800
+
801
+ def fit(self, X, y=None):
802
+ # Gets a grid. This will be the max in each coord+1
803
+ if (
804
+ len(X) == 0
805
+ or len(X[0]) == 0
806
+ or (self.axis is not None and len(X[0][0][0]) == 0)
807
+ ):
808
+ return self
809
+
810
+ self._check_axis(X)
811
+ self._check_backend(X)
812
+ self._check_measures(X)
813
+ self._check_resolution()
814
+ self._check_weights()
815
+ # if not sparse : not recommended.
816
+
817
+ self._infer_grids(X)
818
+ if self.verbose:
819
+ self._print_stats(X)
820
+ return self
821
+
822
+ def unsparse_signed_measure(self, sparse_signed_measure):
823
+ filtrations = self._infered_grids # ax, filtration
824
+ out = []
825
+ for filtrations_of_ax, ax in zip(filtrations, self._axis_iterator, strict=True):
826
+ sparse_signed_measure_of_ax = sparse_signed_measure[ax]
827
+ measure_of_ax = []
828
+ for pts, weights in sparse_signed_measure_of_ax: # over degree
829
+ signed_measure, _ = np.histogramdd(
830
+ pts, bins=filtrations_of_ax, weights=weights
831
+ )
832
+ if self.flatten:
833
+ signed_measure = signed_measure.flatten()
834
+ measure_of_ax.append(signed_measure)
835
+ out.append(np.asarray(measure_of_ax))
836
+
837
+ if self.flatten:
838
+ out = np.concatenate(out).flatten()
839
+ if self.axis == -1:
840
+ return np.asarray(out)
841
+ else:
842
+ return np.asarray(out)[0]
843
+
844
+ @staticmethod
845
+ def deep_format_measure(signed_measure):
846
+ dirac_positions, dirac_signs = signed_measure
847
+ dtype = dirac_positions.dtype
848
+ new_shape = list(dirac_positions.shape)
849
+ new_shape[1] += 1
850
+ if isinstance(dirac_positions, np.ndarray):
851
+ c = np.empty(new_shape, dtype=dtype)
852
+ c[:, :-1] = dirac_positions
853
+ c[:, -1] = dirac_signs
854
+
855
+ else:
856
+ import torch
857
+
858
+ c = torch.empty(new_shape, dtype=dtype)
859
+ c[:, :-1] = dirac_positions
860
+ c[:, -1] = dirac_signs
861
+ return c
862
+
863
@staticmethod
def _integrate_measure(sm, filtrations):
    """Forward one ``(points, weights)`` pair to ``multipers.point_measure.integrate_measure``.

    ``filtrations`` is passed through as the third argument.
    """
    from multipers.point_measure import integrate_measure

    points, weights = sm[0], sm[1]
    return integrate_measure(points, weights, filtrations)
868
+
869
def _rescale_measures(self, X):
    """Rescale every signed measure of ``X`` with the fitted factors.

    Each element is passed to ``rescale_sparse_signed_measure`` together
    with ``self.filtrations_weights``. When every axis is kept
    (``self.axis == -1`` on input that has an axis level), each axis is
    rescaled with its own normalisation factors; otherwise the factors of
    the single fitted axis are used.

    Returns a tuple with one rescaled entry per element of ``X``.
    """
    rescaled = []
    for sparse_signed_measure in X:
        if self.axis == -1 and self._has_axis:
            per_axis = [
                rescale_sparse_signed_measure(
                    sparse_signed_measure[ax],
                    filtration_weights=self.filtrations_weights,
                    normalize_scales=factors,
                )
                for ax, factors in zip(
                    self._axis_iterator, self._normalization_factors, strict=True
                )
            ]
            rescaled.append(tuple(per_axis))
        else:
            ## axis iterator is of size 1 here
            rescaled.append(
                rescale_sparse_signed_measure(
                    sparse_signed_measure,
                    filtration_weights=self.filtrations_weights,
                    normalize_scales=self._normalization_factors[0],
                )
            )
    return tuple(rescaled)
890
+
891
def transform(self, X):
    """
    Format the signed measures of ``X`` according to the estimator options.

    Steps, in order:
      1. select ``self.axis`` from each data point (unless there is no axis
         or ``self.axis == -1``), so every case shares the same layout;
      2. rescale when normalization factors are available;
      3. plot when ``self.plot`` is set;
      4. produce the output with the first enabled mode among
         ``integrate``, ``unsparse`` and ``deep_format``; with none enabled
         the (possibly rescaled) measures are returned unchanged.

    In ``deep_format`` mode with ``self.unrag`` set, the ragged measures are
    zero-padded into one tensor of shape
    ``(num_axis * num_degrees, num_data, max_num_pts, num_columns)``.
    """
    # same format for everyone
    if self._has_axis and self.axis != -1:
        out = tuple(x[self.axis] for x in X)
    else:
        out = X

    if self._normalization_factors is not None:
        out = self._rescale_measures(out)

    if self.plot:
        self._plot_signed_measures(out)

    if self.integrate:
        filtrations = self._infered_grids
        ax = 0  # TODO deal with axis -1
        assert ax != -1, "Not implemented. Can only integrate with axis"
        integrated = []
        for x in out:
            integrated.append(
                [
                    self._integrate_measure(x[degree], filtrations=filtrations[ax])
                    for degree in range(self._num_degrees)
                ]
            )
        out = np.asarray(integrated)
        if self.flatten:
            out = out.reshape((len(X), -1))
        return out

    if self.unsparse:
        return [self.unsparse_signed_measure(x) for x in out]

    if not self.deep_format:
        return out

    num_degrees = self._num_degrees
    # One entry per (degree, axis) pair, each holding one array per data point.
    formatted = []
    for degree in range(num_degrees):
        for axis in self._axis_iterator:
            formatted.append(
                tuple(self.deep_format_measure(sm[axis][degree]) for sm in out)
            )
    out = tuple(formatted)

    if not self.unrag:
        return out

    max_num_pts = np.max([sm.shape[0] for sm_of_axis in out for sm in sm_of_axis])
    num_axis_degree = len(out)
    num_data = len(out[0])
    assert num_axis_degree == num_degrees * (
        self._num_axis if self._has_axis else 1
    ), f"Bad axis/degree count. Got {num_axis_degree} (Internal error)"
    template = out[0][0]
    num_parameters = template.shape[1]
    dtype = template.dtype
    # Allocate with the same library as the formatted measures.
    if isinstance(template, np.ndarray):
        from numpy import zeros
    else:
        from torch import zeros
    unragged_tensor = zeros(
        (num_axis_degree, num_data, max_num_pts, num_parameters),
        dtype=dtype,
    )
    for ax in range(num_axis_degree):
        measures_of_ax = out[ax]
        for data in range(num_data):
            sm = measures_of_ax[data]
            num_rows, num_cols = sm.shape
            unragged_tensor[ax, data, :num_rows, :num_cols] = sm
    return unragged_tensor
966
+
967
+
968
class SignedMeasure2Convolution(BaseEstimator, TransformerMixin):
    """
    Discrete convolution of a signed measure

    Input
    -----

    (data) x (degree) x (signed measure)

    Parameters
    ----------
     - filtration_grid : Iterable[array] For each filtration, the filtration values on which to evaluate the grid
     - resolution : int or (num_parameters) : If filtration grid is not given, will infer a grid, with this resolution
     - grid_strategy : the strategy to generate the grid. Available ones are regular, quantile, exact
     - kernel : kernel used to convolve the images.
     - bandwidth : kernel bandwidth. A non-positive value ``b`` is interpreted
       relatively, as ``-b`` times the grid diameter (sparse input) or times
       the input resolution of each axis (dense input).
     - flatten : flatten the images if True
     - progress : progress bar if True
     - backend : sklearn, pykeops or numba.
     - plot : Creates a plot Figure.
     - kde_kwargs : extra keyword arguments forwarded to the sparse convolution.

    Output
    ------

    (data) x (concatenation of imgs of degree)
    """

    def __init__(
        self,
        filtration_grid: Iterable[np.ndarray] = None,
        kernel: available_kernels = "gaussian",
        bandwidth: float | Iterable[float] = 1.0,
        flatten: bool = False,
        n_jobs: int = 1,
        resolution: int | None = None,
        grid_strategy: str = "regular",
        progress: bool = False,
        backend: str = "pykeops",
        plot: bool = False,
        log_density: bool = False,
        **kde_kwargs,
    ):
        super().__init__()
        self.kernel: available_kernels = kernel
        self.bandwidth = bandwidth
        self.filtration_grid = filtration_grid
        self.flatten = flatten
        self.progress = progress
        self.n_jobs = n_jobs
        self.resolution = resolution
        self.grid_strategy = grid_strategy
        self._is_input_sparse = None
        # A grid is only inferred in `fit` when none was provided.
        self._refit = filtration_grid is None
        self._input_resolution = None
        self._bandwidths = None
        self.diameter = None
        self.backend = backend
        self.plot = plot
        self.log_density = log_density
        self.kde_kwargs = kde_kwargs

    def fit(self, X, y=None):
        """
        Infer the input format and, for sparse input, the filtration grid.

        Dense input (arrays): stores the input resolution and one bandwidth
        per axis. Sparse input (tuples): requires a filtration grid or a
        resolution, infers the grid if needed and computes its diameter.

        Raises
        ------
        Exception : sparse input with neither `filtration_grid` nor `resolution`.
        """
        if len(X) == 0:
            return self
        # Sparse signed measures are tuples; anything else is treated as dense.
        self._is_input_sparse = isinstance(X[0][0], tuple)

        if not self._is_input_sparse:
            self._input_resolution = X[0][0].shape
            try:
                scalar_bandwidth = float(self.bandwidth)
            except (TypeError, ValueError):
                # Not a scalar: one bandwidth per axis is expected.
                self._bandwidths = [
                    b if b > 0 else -b * s
                    for s, b in zip(self._input_resolution, self.bandwidth)
                ]
            else:
                self._bandwidths = [
                    scalar_bandwidth if scalar_bandwidth > 0 else -scalar_bandwidth * s
                    for s in self._input_resolution
                ]
            # In that case, signed measures are matrices, and the grid is already given.
            return self

        if self.filtration_grid is None and self.resolution is None:
            raise Exception(
                "Cannot infer filtration grid. Provide either a filtration grid or a resolution."
            )
        if self._refit:
            # Gather the support points of every measure to infer the grid.
            pts = np.concatenate(
                [sm[0] for signed_measures in X for sm in signed_measures]
            ).T
            self.filtration_grid = reduce_grid(
                pts,
                strategy=self.grid_strategy,
                resolution=self.resolution,
            )
        if self.filtration_grid is not None:
            self.diameter = np.linalg.norm(
                [f.max() - f.min() for f in self.filtration_grid]
            )
            if self.progress:
                print(f"Computed a diameter of {self.diameter}")
        return self

    def _sm2smi(self, signed_measures: Iterable[np.ndarray]):
        """Gaussian-filter each dense signed measure and concatenate the results along axis 0."""
        from scipy.ndimage import gaussian_filter

        return np.concatenate(
            [
                gaussian_filter(
                    input=signed_measure,
                    sigma=self._bandwidths,
                    mode="constant",
                    cval=0,
                )
                for signed_measure in signed_measures
            ],
            axis=0,
        )

    def _transform_from_sparse(self, X):
        """Convolve sparse signed measures on `self.filtration_grid`."""
        # NOTE(review): this comparison assumes a scalar bandwidth in the sparse case.
        bandwidth = (
            self.bandwidth if self.bandwidth > 0 else -self.bandwidth * self.diameter
        )
        # COMPILE KEOPS FIRST: a tiny single-job call on the first data point,
        # issued before the real (possibly parallel) computation.
        dummyx = [X[0]]
        dummyf = [f[:2] for f in self.filtration_grid]
        convolution_signed_measures(
            dummyx,
            filtrations=dummyf,
            bandwidth=bandwidth,
            flatten=self.flatten,
            n_jobs=1,
            kernel=self.kernel,
            backend=self.backend,
        )

        return convolution_signed_measures(
            X,
            filtrations=self.filtration_grid,
            bandwidth=bandwidth,
            flatten=self.flatten,
            n_jobs=self.n_jobs,
            kernel=self.kernel,
            backend=self.backend,
            **self.kde_kwargs,
        )

    def _plot_imgs(self, imgs: Iterable[np.ndarray], size=4):
        """Plot every image on a (data) x (degree) grid of subplots."""
        from multipers.plots import plot_surface

        num_degrees = imgs[0].shape[0]
        num_imgs = len(imgs)
        fig, axes = plt.subplots(
            ncols=num_degrees,
            nrows=num_imgs,
            figsize=(size * num_degrees, size * num_imgs),
        )
        # `plt.subplots` squeezes singleton dimensions; restore a 2D layout.
        axes = np.asarray(axes).reshape(num_imgs, num_degrees)
        for i, img in enumerate(imgs):
            for j, img_of_degree in enumerate(img):
                plot_surface(
                    self.filtration_grid, img_of_degree, ax=axes[i, j], cmap="Spectral"
                )

    def transform(self, X):
        """
        Convolve the signed measures of `X`; returns a numpy array.

        Raises
        ------
        Exception : when the estimator has not been fitted.
        """
        if self._is_input_sparse is None:
            raise Exception("Fit first")
        if self._is_input_sparse:
            out = self._transform_from_sparse(X)
        else:
            todo = SignedMeasure2Convolution._sm2smi
            out = Parallel(n_jobs=self.n_jobs, backend="threading")(
                delayed(todo)(self, signed_measures)
                for signed_measures in tqdm(
                    X, desc="Computing images", disable=not self.progress
                )
            )
        if self.plot and not self.flatten:
            if self.progress:
                print("Plotting convolutions...", end="")
            self._plot_imgs(out)
            if self.progress:
                print("Done !")
        # The sparse path already handles `flatten` inside the convolution call.
        if self.flatten and not self._is_input_sparse:
            out = [x.flatten() for x in out]
        return np.asarray(out)
1165
+
1166
+
1167
class SignedMeasure2SlicedWassersteinDistance(BaseEstimator, TransformerMixin):
    """
    Transformer turning signed measures into distance matrices.

    Input
    -----

    (data) x (degree) x (signed measure)

    Format
    ------
    - a signed measure : tuple of arrays. (point positions) : npts x (num_parameters), and weights : npts
    - each data is a list of signed measures (e.g. one per degree)

    Output
    ------
    - (degree) x (distance matrix)
    """

    def __init__(
        self,
        n_jobs=None,
        num_directions: int = 10,
        _sliced: bool = True,
        epsilon=-1,
        ground_norm=1,
        progress=False,
        grid_reconversion=None,
        scales=None,
    ):
        super().__init__()
        # One fitted distance estimator per degree; filled in by `fit`.
        self._SWD_list = None
        self.n_jobs = n_jobs
        self.num_directions = num_directions
        self._sliced = _sliced
        self.epsilon = epsilon
        self.ground_norm = ground_norm
        self.progress = progress
        self.grid_reconversion = grid_reconversion
        self.scales = scales

    def fit(self, X, y=None):
        """Create and fit one distance estimator per degree of the input."""
        from multipers.ml.sliced_wasserstein import (
            SlicedWassersteinDistance,
            WassersteinDistance,
        )

        if len(X) == 0:
            return self

        def _new_distance():
            if self._sliced:
                return SlicedWassersteinDistance(
                    num_directions=self.num_directions,
                    n_jobs=self.n_jobs,
                    scales=self.scales,
                )
            # WARNING if _sliced is false, this distance is not CNSD
            return WassersteinDistance(
                epsilon=self.epsilon,
                ground_norm=self.ground_norm,
                n_jobs=self.n_jobs,
            )

        num_degrees = len(X[0])
        self._SWD_list = [_new_distance() for _ in range(num_degrees)]
        for degree, estimator in enumerate(self._SWD_list):
            estimator.fit([x[degree] for x in X])
        return self

    def transform(self, X):
        """Return the array of per-degree distance matrices computed from `X`."""
        assert self._SWD_list is not None, "Fit first"
        progress_bar = tqdm(
            enumerate(self._SWD_list),
            desc="Computing distance matrices",
            total=len(self._SWD_list),
            disable=not self.progress,
        )
        with progress_bar as SWD_it:
            jobs = (
                delayed(estimator.transform)([x[degree] for x in X])
                for degree, estimator in SWD_it
            )
            matrices = Parallel(n_jobs=self.n_jobs, prefer="threads")(jobs)
        return np.asarray(matrices)

    def predict(self, X):
        """Alias of `transform`."""
        return self.transform(X)
1262
+
1263
+
1264
class SignedMeasures2SlicedWassersteinDistances(BaseEstimator, TransformerMixin):
    """
    Transformer from signed measure to distance matrix.

    Input
    -----
    (data) x opt (axis) x (degree) x (signed measure)

    Format
    ------
    - a signed measure : tuple of arrays. (point positions) : npts x (num_parameters), and weights : npts
    - each data is a list of signed measures (e.g. one per degree)

    Output
    ------
    - (axis) x (degree) x (distance matrix)
    """

    def __init__(
        self,
        progress=False,
        n_jobs: int = 1,
        scales: Iterable[Iterable[float]] | None = None,
        **kwargs,
    ):  # same init
        # Template estimator, cloned once per (axis, scales) pair in `fit`.
        self._init_child = SignedMeasure2SlicedWassersteinDistance(
            progress=False, scales=None, n_jobs=-1, **kwargs
        )
        self._axe_iterator = None
        self._childs_to_fit = None
        self.scales = scales
        self.progress = progress
        self.n_jobs = n_jobs

    def fit(self, X, y=None):
        """
        Fit one child estimator per (axis, scales) pair.

        `self.scales` is normalized in place: `None` becomes `[None]` and a
        single list of scales becomes a list containing that one list.
        """
        from sklearn.base import clone

        if len(X) == 0:
            return self
        if isinstance(X[0][0], tuple):  # Meaning that there are no axes
            self._axe_iterator = [slice(None)]
        else:
            self._axe_iterator = range(len(X[0]))
        if self.scales is None:
            self.scales = [None]
        else:
            self.scales = np.asarray(self.scales)
            if self.scales.ndim == 1:
                self.scales = np.asarray([self.scales])
        assert (
            self.scales[0] is None or self.scales.ndim == 2
        ), "Scales have to be either None or a list of scales !"
        axes_and_scales = list(product(self._axe_iterator, self.scales))
        self._childs_to_fit = [
            clone(self._init_child).set_params(scales=scales).fit([x[axis] for x in X])
            for axis, scales in axes_and_scales
        ]
        # Only report the resulting axes when the user asked for progress output.
        if self.progress:
            print("New axes : ", axes_and_scales)
        return self

    def transform(self, X):
        """Return the list of outputs of every fitted child, in (axis, scales) order."""
        return Parallel(n_jobs=self.n_jobs, prefer="processes")(
            delayed(self._childs_to_fit[child_id].transform)([x[axis] for x in X])
            for child_id, (axis, _) in tqdm(
                enumerate(product(self._axe_iterator, self.scales)),
                desc="Computing distances matrices of axis, and scales",
                disable=not self.progress,
                total=len(self._childs_to_fit),
            )
        )
1339
+
1340
+
1341
class SimplexTree2RectangleDecomposition(BaseEstimator, TransformerMixin):
    """
    Transformer mapping 2-parameter SimplexTrees to their rectangle decompositions.
    """

    def __init__(
        self,
        filtration_grid: np.ndarray,
        degrees: Iterable[int],
        plot=False,
        reconvert_grid=True,
        num_collapses: int = 0,
    ):
        super().__init__()
        self.degrees = degrees
        self.filtration_grid = filtration_grid
        self.num_collapses = num_collapses
        self.plot = plot
        self.reconvert_grid = reconvert_grid

    def fit(self, X, y=None):
        """
        Nothing to fit.

        TODO : infer grid from multiple simplextrees
        """
        return self

    def transform(self, X: Iterable[mp.simplex_tree_multi.SimplexTreeMulti_type]):
        """Return, for each simplextree of `X`, the list of its decompositions (one per degree)."""
        # TODO : return iterator ?
        rectangle_decompositions = []
        for simplextree in X:
            decompositions_of_tree = []
            for degree in self.degrees:
                decompositions_of_tree.append(
                    _st2ranktensor(
                        simplextree,
                        filtration_grid=self.filtration_grid,
                        degree=degree,
                        plot=self.plot,
                        reconvert_grid=self.reconvert_grid,
                        num_collapse=self.num_collapses,
                    )
                )
            rectangle_decompositions.append(decompositions_of_tree)
        return rectangle_decompositions
1385
+
1386
+
1387
+ def _st2ranktensor(
1388
+ st: mp.simplex_tree_multi.SimplexTreeMulti_type,
1389
+ filtration_grid: np.ndarray,
1390
+ degree: int,
1391
+ plot: bool,
1392
+ reconvert_grid: bool,
1393
+ num_collapse: int | str = 0,
1394
+ ):
1395
+ """
1396
+ TODO
1397
+ """
1398
+ # Copy (the squeeze change the filtration values)
1399
+ # stcpy = mp.SimplexTreeMulti(st)
1400
+ # turns the simplextree into a coordinate simplex tree
1401
+ stcpy = st.grid_squeeze(filtration_grid=filtration_grid, coordinate_values=True)
1402
+ # stcpy.collapse_edges(num=100, strong = True, ignore_warning=True)
1403
+ if num_collapse == "full":
1404
+ stcpy.collapse_edges(full=True, ignore_warning=True, max_dimension=degree + 1)
1405
+ elif isinstance(num_collapse, int):
1406
+ stcpy.collapse_edges(
1407
+ num=num_collapse, ignore_warning=True, max_dimension=degree + 1
1408
+ )
1409
+ else:
1410
+ raise TypeError(
1411
+ f"Invalid num_collapse=\
1412
+ {num_collapse} type. Either full, or an integer."
1413
+ )
1414
+ # computes the rank invariant tensor
1415
+ rank_tensor = mp.rank_invariant2d(
1416
+ stcpy, degree=degree, grid_shape=[len(f) for f in filtration_grid]
1417
+ )
1418
+ # refactor this tensor into the rectangle decomposition of the signed betti
1419
+ grid_conversion = filtration_grid if reconvert_grid else None
1420
+ rank_decomposition = rank_decomposition_by_rectangles(
1421
+ rank_tensor,
1422
+ threshold=True,
1423
+ )
1424
+ rectangle_decomposition = tensor_möbius_inversion(
1425
+ tensor=rank_decomposition,
1426
+ grid_conversion=grid_conversion,
1427
+ plot=plot,
1428
+ num_parameters=st.num_parameters,
1429
+ )
1430
+ return rectangle_decomposition
1431
+
1432
+
1433
class DegreeRips2SignedMeasure(BaseEstimator, TransformerMixin):
    """
    Transformer from point clouds to the signed measures of their degree-Rips
    Hilbert functions, one per requested degree.

    A negative `max_rips_value` is interpreted relatively: `fit` replaces it
    by `-max_rips_value` times the largest diameter found in a random sample
    of the data (the sample size is controlled by `fit_fraction`).
    Progress messages and the progress bar are only shown when `progress` is True.
    """

    def __init__(
        self,
        degrees: Iterable[int],
        min_rips_value: float,
        max_rips_value,
        max_normalized_degree: float,
        min_normalized_degree: float,
        grid_granularity: int,
        progress: bool = False,
        n_jobs=1,
        sparse: bool = False,
        _möbius_inversion=True,
        fit_fraction=1,
    ) -> None:
        super().__init__()
        self.min_rips_value = min_rips_value
        self.max_rips_value = max_rips_value
        self.min_normalized_degree = min_normalized_degree
        self.max_normalized_degree = max_normalized_degree
        # Effective (absolute) maximal rips value, resolved in `fit`.
        self._max_rips_value = None
        self.grid_granularity = grid_granularity
        self.progress = progress
        self.n_jobs = n_jobs
        self.degrees = degrees
        self.sparse = sparse
        self._möbius_inversion = _möbius_inversion
        self.fit_fraction = fit_fraction

    def fit(self, X: np.ndarray | list, y=None):
        """Resolve the effective maximal rips value, estimating the data scale if needed."""
        if self.max_rips_value < 0:
            if self.progress:
                print("Estimating scale...", flush=True, end="")
            indices = np.random.choice(
                len(X), min(len(X), int(self.fit_fraction * len(X)) + 1), replace=False
            )
            diameters = np.max(
                [distance_matrix(x, x).max() for x in (X[i] for i in indices)]
            )
            if self.progress:
                print(f"Done. {diameters}", flush=True)
            self._max_rips_value = -self.max_rips_value * diameters
        else:
            self._max_rips_value = self.max_rips_value
        return self

    def _transform1(self, data: np.ndarray):
        """Return the list (one entry per degree) of signed measures of one point cloud."""
        _distance_matrix = distance_matrix(data, data)
        signed_measures = []
        (
            rips_values,
            normalized_degree_values,
            hilbert_functions,
            minimal_presentations,
        ) = hf_degree_rips(
            _distance_matrix,
            min_rips_value=self.min_rips_value,
            max_rips_value=self._max_rips_value,
            min_normalized_degree=self.min_normalized_degree,
            max_normalized_degree=self.max_normalized_degree,
            grid_granularity=self.grid_granularity,
            max_homological_dimension=np.max(self.degrees),
        )
        for degree in self.degrees:
            hilbert_function = hilbert_functions[degree]
            signed_measure = (
                signed_betti(hilbert_function, threshold=True)
                if self._möbius_inversion
                else hilbert_function
            )
            if self.sparse:
                signed_measure = tensor_möbius_inversion(
                    tensor=signed_measure,
                    num_parameters=2,
                    grid_conversion=[rips_values, normalized_degree_values],
                )
            # NOTE(review): with sparse=True the value above is a tuple, so this
            # branch presumably fails when sparse and not _möbius_inversion — confirm.
            if not self._möbius_inversion:
                signed_measure = signed_measure.flatten()
            signed_measures.append(signed_measure)
        return signed_measures

    def transform(self, X):
        """Compute the signed measures of every point cloud of `X`, possibly in parallel."""
        return Parallel(n_jobs=self.n_jobs)(
            delayed(self._transform1)(data)
            for data in tqdm(
                X,
                desc=f"Computing DegreeRips, of degrees {self.degrees}",
                disable=not self.progress,
            )
        )
1520
+
1521
+
1522
+ def tensor_möbius_inversion(
1523
+ tensor,
1524
+ grid_conversion: Iterable[np.ndarray] | None = None,
1525
+ plot: bool = False,
1526
+ raw: bool = False,
1527
+ num_parameters: int | None = None,
1528
+ ):
1529
+ from torch import Tensor
1530
+
1531
+ betti_sparse = Tensor(tensor.copy()).to_sparse() # Copy necessary in some cases :(
1532
+ num_indices, num_pts = betti_sparse.indices().shape
1533
+ num_parameters = num_indices if num_parameters is None else num_parameters
1534
+ if num_indices == num_parameters: # either hilbert or rank invariant
1535
+ rank_invariant = False
1536
+ elif 2 * num_parameters == num_indices:
1537
+ rank_invariant = True
1538
+ else:
1539
+ raise TypeError(
1540
+ f"Unsupported betti shape. {num_indices}\
1541
+ has to be either {num_parameters} or \
1542
+ {2*num_parameters}."
1543
+ )
1544
+ points_filtration = np.asarray(betti_sparse.indices().T, dtype=int)
1545
+ weights = np.asarray(betti_sparse.values(), dtype=int)
1546
+
1547
+ if grid_conversion is not None:
1548
+ coords = np.empty(shape=(num_pts, num_indices), dtype=float)
1549
+ for i in range(num_indices):
1550
+ coords[:, i] = grid_conversion[i % num_parameters][points_filtration[:, i]]
1551
+ else:
1552
+ coords = points_filtration
1553
+ if (not rank_invariant) and plot:
1554
+ plt.figure()
1555
+ color_weights = np.empty(weights.shape)
1556
+ color_weights[weights > 0] = np.log10(weights[weights > 0]) + 2
1557
+ color_weights[weights < 0] = -np.log10(-weights[weights < 0]) - 2
1558
+ plt.scatter(
1559
+ points_filtration[:, 0],
1560
+ points_filtration[:, 1],
1561
+ c=color_weights,
1562
+ cmap="coolwarm",
1563
+ )
1564
+ if (not rank_invariant) or raw:
1565
+ return coords, weights
1566
+
1567
+ def _is_trivial(rectangle: np.ndarray):
1568
+ birth = rectangle[:num_parameters]
1569
+ death = rectangle[num_parameters:]
1570
+ return np.all(birth <= death) # and not np.array_equal(birth,death)
1571
+
1572
+ correct_indices = np.array([_is_trivial(rectangle) for rectangle in coords])
1573
+ if len(correct_indices) == 0:
1574
+ return np.empty((0, num_indices)), np.empty((0))
1575
+ signed_measure = np.asarray(coords[correct_indices])
1576
+ weights = weights[correct_indices]
1577
+ if plot:
1578
+ # plot only the rank decompo for the moment
1579
+ assert signed_measure.shape[1] == 4
1580
+
1581
+ def _plot_rectangle(rectangle: np.ndarray, weight: float):
1582
+ x_axis = rectangle[[0, 2]]
1583
+ y_axis = rectangle[[1, 3]]
1584
+ color = "blue" if weight > 0 else "red"
1585
+ plt.plot(x_axis, y_axis, c=color)
1586
+
1587
+ for rectangle, weight in zip(signed_measure, weights):
1588
+ _plot_rectangle(rectangle=rectangle, weight=weight)
1589
+ return signed_measure, weights