multipers 2.4.0b1__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184) hide show
  1. multipers/.dylibs/libboost_timer.dylib +0 -0
  2. multipers/.dylibs/libc++.1.0.dylib +0 -0
  3. multipers/.dylibs/libtbb.12.17.dylib +0 -0
  4. multipers/__init__.py +33 -0
  5. multipers/_signed_measure_meta.py +426 -0
  6. multipers/_slicer_meta.py +231 -0
  7. multipers/array_api/__init__.py +62 -0
  8. multipers/array_api/numpy.py +124 -0
  9. multipers/array_api/torch.py +133 -0
  10. multipers/data/MOL2.py +458 -0
  11. multipers/data/UCR.py +18 -0
  12. multipers/data/__init__.py +1 -0
  13. multipers/data/graphs.py +466 -0
  14. multipers/data/immuno_regions.py +27 -0
  15. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  16. multipers/data/pytorch2simplextree.py +91 -0
  17. multipers/data/shape3d.py +101 -0
  18. multipers/data/synthetic.py +113 -0
  19. multipers/distances.py +202 -0
  20. multipers/filtration_conversions.pxd +736 -0
  21. multipers/filtration_conversions.pxd.tp +226 -0
  22. multipers/filtrations/__init__.py +21 -0
  23. multipers/filtrations/density.py +529 -0
  24. multipers/filtrations/filtrations.py +480 -0
  25. multipers/filtrations.pxd +534 -0
  26. multipers/filtrations.pxd.tp +332 -0
  27. multipers/function_rips.cpython-312-darwin.so +0 -0
  28. multipers/function_rips.pyx +104 -0
  29. multipers/grids.cpython-312-darwin.so +0 -0
  30. multipers/grids.pyx +538 -0
  31. multipers/gudhi/Persistence_slices_interface.h +213 -0
  32. multipers/gudhi/Simplex_tree_interface.h +274 -0
  33. multipers/gudhi/Simplex_tree_multi_interface.h +648 -0
  34. multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -0
  35. multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -0
  36. multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -0
  37. multipers/gudhi/gudhi/Debug_utils.h +52 -0
  38. multipers/gudhi/gudhi/Degree_rips_bifiltration.h +2307 -0
  39. multipers/gudhi/gudhi/Dynamic_multi_parameter_filtration.h +2524 -0
  40. multipers/gudhi/gudhi/Fields/Multi_field.h +453 -0
  41. multipers/gudhi/gudhi/Fields/Multi_field_operators.h +460 -0
  42. multipers/gudhi/gudhi/Fields/Multi_field_shared.h +444 -0
  43. multipers/gudhi/gudhi/Fields/Multi_field_small.h +584 -0
  44. multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +490 -0
  45. multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +580 -0
  46. multipers/gudhi/gudhi/Fields/Z2_field.h +391 -0
  47. multipers/gudhi/gudhi/Fields/Z2_field_operators.h +389 -0
  48. multipers/gudhi/gudhi/Fields/Zp_field.h +493 -0
  49. multipers/gudhi/gudhi/Fields/Zp_field_operators.h +384 -0
  50. multipers/gudhi/gudhi/Fields/Zp_field_shared.h +492 -0
  51. multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -0
  52. multipers/gudhi/gudhi/Matrix.h +2200 -0
  53. multipers/gudhi/gudhi/Multi_filtration/Multi_parameter_generator.h +1712 -0
  54. multipers/gudhi/gudhi/Multi_filtration/multi_filtration_conversions.h +237 -0
  55. multipers/gudhi/gudhi/Multi_filtration/multi_filtration_utils.h +225 -0
  56. multipers/gudhi/gudhi/Multi_parameter_filtered_complex.h +485 -0
  57. multipers/gudhi/gudhi/Multi_parameter_filtration.h +2643 -0
  58. multipers/gudhi/gudhi/Multi_persistence/Box.h +233 -0
  59. multipers/gudhi/gudhi/Multi_persistence/Line.h +309 -0
  60. multipers/gudhi/gudhi/Multi_persistence/Multi_parameter_filtered_complex_pcoh_interface.h +268 -0
  61. multipers/gudhi/gudhi/Multi_persistence/Persistence_interface_cohomology.h +159 -0
  62. multipers/gudhi/gudhi/Multi_persistence/Persistence_interface_matrix.h +463 -0
  63. multipers/gudhi/gudhi/Multi_persistence/Point.h +853 -0
  64. multipers/gudhi/gudhi/Off_reader.h +173 -0
  65. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +834 -0
  66. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +838 -0
  67. multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +833 -0
  68. multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1367 -0
  69. multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1157 -0
  70. multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +869 -0
  71. multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +905 -0
  72. multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +122 -0
  73. multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +260 -0
  74. multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +288 -0
  75. multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +170 -0
  76. multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +247 -0
  77. multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +571 -0
  78. multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +182 -0
  79. multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +130 -0
  80. multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +235 -0
  81. multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +312 -0
  82. multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1092 -0
  83. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +923 -0
  84. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +914 -0
  85. multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +930 -0
  86. multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1071 -0
  87. multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +203 -0
  88. multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +886 -0
  89. multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +984 -0
  90. multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1213 -0
  91. multipers/gudhi/gudhi/Persistence_matrix/index_mapper.h +58 -0
  92. multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +227 -0
  93. multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +200 -0
  94. multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +166 -0
  95. multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +319 -0
  96. multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +562 -0
  97. multipers/gudhi/gudhi/Persistence_on_a_line.h +152 -0
  98. multipers/gudhi/gudhi/Persistence_on_rectangle.h +617 -0
  99. multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -0
  100. multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -0
  101. multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -0
  102. multipers/gudhi/gudhi/Persistent_cohomology.h +769 -0
  103. multipers/gudhi/gudhi/Points_off_io.h +171 -0
  104. multipers/gudhi/gudhi/Projective_cover_kernel.h +379 -0
  105. multipers/gudhi/gudhi/Simple_object_pool.h +69 -0
  106. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +559 -0
  107. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -0
  108. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +121 -0
  109. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -0
  110. multipers/gudhi/gudhi/Simplex_tree/filtration_value_utils.h +155 -0
  111. multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -0
  112. multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -0
  113. multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +60 -0
  114. multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +105 -0
  115. multipers/gudhi/gudhi/Simplex_tree.h +3170 -0
  116. multipers/gudhi/gudhi/Slicer.h +848 -0
  117. multipers/gudhi/gudhi/Thread_safe_slicer.h +393 -0
  118. multipers/gudhi/gudhi/distance_functions.h +62 -0
  119. multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -0
  120. multipers/gudhi/gudhi/multi_simplex_tree_helpers.h +147 -0
  121. multipers/gudhi/gudhi/persistence_interval.h +263 -0
  122. multipers/gudhi/gudhi/persistence_matrix_options.h +188 -0
  123. multipers/gudhi/gudhi/reader_utils.h +367 -0
  124. multipers/gudhi/gudhi/simple_mdspan.h +484 -0
  125. multipers/gudhi/gudhi/slicer_helpers.h +779 -0
  126. multipers/gudhi/tmp_h0_pers/mma_interface_h0.h +223 -0
  127. multipers/gudhi/tmp_h0_pers/naive_merge_tree.h +536 -0
  128. multipers/io.cpython-312-darwin.so +0 -0
  129. multipers/io.pyx +472 -0
  130. multipers/ml/__init__.py +0 -0
  131. multipers/ml/accuracies.py +90 -0
  132. multipers/ml/invariants_with_persistable.py +79 -0
  133. multipers/ml/kernels.py +176 -0
  134. multipers/ml/mma.py +713 -0
  135. multipers/ml/one.py +472 -0
  136. multipers/ml/point_clouds.py +352 -0
  137. multipers/ml/signed_measures.py +1667 -0
  138. multipers/ml/sliced_wasserstein.py +461 -0
  139. multipers/ml/tools.py +113 -0
  140. multipers/mma_structures.cpython-312-darwin.so +0 -0
  141. multipers/mma_structures.pxd +134 -0
  142. multipers/mma_structures.pyx +1483 -0
  143. multipers/mma_structures.pyx.tp +1126 -0
  144. multipers/multi_parameter_rank_invariant/diff_helpers.h +85 -0
  145. multipers/multi_parameter_rank_invariant/euler_characteristic.h +95 -0
  146. multipers/multi_parameter_rank_invariant/function_rips.h +317 -0
  147. multipers/multi_parameter_rank_invariant/hilbert_function.h +761 -0
  148. multipers/multi_parameter_rank_invariant/persistence_slices.h +149 -0
  149. multipers/multi_parameter_rank_invariant/rank_invariant.h +350 -0
  150. multipers/multiparameter_edge_collapse.py +41 -0
  151. multipers/multiparameter_module_approximation/approximation.h +2541 -0
  152. multipers/multiparameter_module_approximation/debug.h +107 -0
  153. multipers/multiparameter_module_approximation/format_python-cpp.h +292 -0
  154. multipers/multiparameter_module_approximation/utilities.h +428 -0
  155. multipers/multiparameter_module_approximation.cpython-312-darwin.so +0 -0
  156. multipers/multiparameter_module_approximation.pyx +286 -0
  157. multipers/ops.cpython-312-darwin.so +0 -0
  158. multipers/ops.pyx +231 -0
  159. multipers/pickle.py +89 -0
  160. multipers/plots.py +550 -0
  161. multipers/point_measure.cpython-312-darwin.so +0 -0
  162. multipers/point_measure.pyx +409 -0
  163. multipers/simplex_tree_multi.cpython-312-darwin.so +0 -0
  164. multipers/simplex_tree_multi.pxd +136 -0
  165. multipers/simplex_tree_multi.pyx +11719 -0
  166. multipers/simplex_tree_multi.pyx.tp +2102 -0
  167. multipers/slicer.cpython-312-darwin.so +0 -0
  168. multipers/slicer.pxd +2097 -0
  169. multipers/slicer.pxd.tp +263 -0
  170. multipers/slicer.pyx +13042 -0
  171. multipers/slicer.pyx.tp +1259 -0
  172. multipers/tensor/tensor.h +672 -0
  173. multipers/tensor.pxd +13 -0
  174. multipers/test.pyx +44 -0
  175. multipers/tests/__init__.py +70 -0
  176. multipers/torch/__init__.py +1 -0
  177. multipers/torch/diff_grids.py +240 -0
  178. multipers/torch/rips_density.py +310 -0
  179. multipers/vector_interface.pxd +46 -0
  180. multipers-2.4.0b1.dist-info/METADATA +131 -0
  181. multipers-2.4.0b1.dist-info/RECORD +184 -0
  182. multipers-2.4.0b1.dist-info/WHEEL +6 -0
  183. multipers-2.4.0b1.dist-info/licenses/LICENSE +21 -0
  184. multipers-2.4.0b1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1667 @@
1
+ from collections.abc import Iterable, Sequence
2
+ from itertools import product
3
+ from typing import Optional, Union
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from joblib import Parallel, delayed
8
+ from sklearn.base import BaseEstimator, TransformerMixin
9
+ from tqdm import tqdm
10
+
11
+ import multipers as mp
12
+ from multipers.array_api import api_from_tensor, api_from_tensors
13
+ from multipers.filtrations.density import available_kernels, convolution_signed_measures
14
+ from multipers.grids import compute_grid, todense
15
+ from multipers.point_measure import rank_decomposition_by_rectangles, signed_betti
16
+
17
+
18
def batch_signed_measure_convolutions(
    signed_measures,  # array of shape (num_data, num_pts, D+1), last coord = weights
    x,  # array of shape (num_x, D) or (num_data, num_x, D)
    bandwidth,  # either float or matrix if multivariate kernel
    kernel: available_kernels,
    api=None,
):
    """
    Convolve a batch of padded ("unragged") signed measures with a kernel.

    Input
    -----
    - signed_measures: unragged, of shape (num_data, num_pts, D+1)
      where the last coordinate holds the weights (0 for dummy points).
      A single measure of shape (num_pts, D+1) is also accepted and is
      treated as a batch of size 1.
    - x : the points at which the convolution is evaluated, (num_x, D)
      (or (num_data, num_x, D)).
    - bandwidth : the bandwidths or covariance matrix inverse or ... of the kernel
    - kernel : "gaussian", "multivariate_gaussian", "exponential", or Callable (x_i, y_i, bandwidth)->float
    - api : array backend; inferred from ``signed_measures`` and ``x`` when None.

    Output
    ------
    For every measure of the batch, the weighted sum over its points of the
    kernel evaluated at each point of ``x``. Presumably of shape
    (num_data, num_x) -- the exact shape depends on the LazyTensor reduction
    of the chosen kernel; TODO confirm.
    """
    from multipers.filtrations.density import _kernel

    if api is None:
        api = api_from_tensors(signed_measures, x)
    if signed_measures.ndim == 2:
        # Promote a single measure to a batch of one.
        signed_measures = signed_measures[None, :, :]
    sms = signed_measures[..., :-1]  # point coordinates
    weights = signed_measures[..., -1]  # signed weights (0 for padding)
    # Insert singleton axes so that points of the measures and evaluation
    # points broadcast against each other inside the lazy kernel.
    _sms = api.LazyTensor(api.ascontiguous(sms[..., None, :]))
    _x = api.ascontiguous(x[..., None, :, :])

    sms_kernel = _kernel(kernel)(_sms, _x, bandwidth)
    # Weight each kernel evaluation and reduce over the points of each measure.
    out = (sms_kernel * api.ascontiguous(weights[..., None, None])).sum(
        signed_measures.ndim - 2
    )
    assert out.shape[-1] == 1, "Pykeops bug fixed, TODO : refix this "
    out = out[..., 0]  ## pykeops bug + ensures its a tensor
    return out
58
+
59
+
60
def sm2deep(signed_measure, api=None):
    """Stack a sparse signed measure into a single tensor.

    ``signed_measure`` is a pair ``(positions, signs)``. The result has the
    same number of rows as ``positions``, one extra column, and the dtype of
    ``positions``. The first columns hold the positions and the last column
    holds the signs (converted with ``api.astensor``).
    """
    if api is None:
        api = api_from_tensor(signed_measure[0])
    positions, signs = signed_measure
    shape = list(positions.shape)
    shape[1] = shape[1] + 1  # room for the weight column
    stacked = api.empty(shape, dtype=positions.dtype)
    stacked[:, -1] = api.astensor(signs)
    stacked[:, :-1] = positions
    return stacked
71
+
72
+
73
def deep_unrag(sms, api=None):
    """Zero-pad a list of sparse signed measures into one dense tensor.

    Each signed measure ``(positions, weights)`` is stacked by ``sm2deep``
    into an ``(n_i, D + 1)`` tensor whose last column holds the weights. The
    stacked measures are then copied into a zero tensor of shape
    ``(len(sms), max_i n_i, D + 1)``. Padding rows are all zeros, i.e.
    dummy points with zero weight.

    Parameters
    ----------
    sms : sequence of (positions, weights) pairs; ``D`` and the dtype are
        read from the first positions array.
    api : array backend; inferred from the first positions array when None.

    Returns
    -------
    The padded tensor, or an empty 1-D tensor when ``sms`` is empty.
    """
    num_sm = len(sms)
    if num_sm == 0:
        # Check emptiness before touching sms[0]: there is nothing to infer a
        # backend from, so fall back to numpy when no api was provided.
        return np.array([]) if api is None else api.tensor([])
    if api is None:
        api = api_from_tensor(sms[0][0])
    first = sms[0][0]
    num_parameters = first.shape[1]
    dtype = first.dtype
    deep_sms = tuple(sm2deep(sm, api=api) for sm in sms)
    max_num_pts = max(sm[0].shape[0] for sm in sms)
    unragged_sms = api.zeros((num_sm, max_num_pts, num_parameters + 1), dtype=dtype)

    for data, deep_sm in enumerate(deep_sms):
        num_pts, width = deep_sm.shape
        unragged_sms[data, :num_pts, :width] = deep_sm
    return unragged_sms
91
+
92
+
93
def sm_convolution(
    sms,
    grid,
    bandwidth,
    kernel: available_kernels = "gaussian",
    plot: bool = False,
    **plt_kwargs,
):
    """Convolve a list of sparse signed measures and evaluate them on a grid.

    The measures are padded with ``deep_unrag`` and convolved with
    ``batch_signed_measure_convolutions`` at the points of the dense version
    of ``grid``. The result is reshaped to
    ``(num_measures, len(grid[0]), len(grid[1]), ...)``. When ``plot`` is
    True, the surfaces are drawn with ``multipers.plots.plot_surfaces``;
    extra keyword arguments are forwarded to it.
    """
    dense_grid = todense(grid)
    api = api_from_tensors(sms[0][0], dense_grid)
    padded = deep_unrag(sms, api=api)
    flat_convolutions = batch_signed_measure_convolutions(
        padded, dense_grid, bandwidth, kernel, api=api
    )
    convolutions = flat_convolutions.reshape(
        padded.shape[0], *(len(axis_values) for axis_values in grid)
    )
    if plot:
        from multipers.plots import plot_surfaces

        plot_surfaces((grid, convolutions), **plt_kwargs)
    return convolutions
111
+
112
+
113
class FilteredComplex2SignedMeasure(BaseEstimator, TransformerMixin):
    """
    Turn multi-filtered complexes into signed measures.

    Input
    -----
    A sequence of shape (num_data, num_axis) of filtered complexes:
    ``X[i][ax]`` is a SimplexTreeMulti or a slicer.

    Output
    ------
    ``transform`` returns a tuple (one entry per data point) of tuples (one
    entry per axis) of lists of signed measures. Each list holds one signed
    measure per entry of ``degrees``, in order. These are followed by the
    entries returned for the rank invariant when ``rank_degrees`` is
    non-empty.

    A signed measure is either
    - (points : (n x num_parameters) array, weights : (n) int array ) if sparse,
    - else an integer matrix.

    Parameters
    ----------
    - degrees : homological degrees of the Hilbert signed measure to compute.
      A ``None`` entry requests the Euler characteristic.
    - rank_degrees : homological degrees for the rank invariant.
    - filtration_grid : per-axis grid, ``filtration_grid[ax][i]`` being the
      filtration values of parameter i. If None, it is inferred in ``fit``
      (unless individual grids are used).
    - progress : display a tqdm progress bar in ``transform``.
    - num_collapses : forwarded as ``num`` to ``collapse_edges`` for
      2-parameter simplextrees.
    - n_jobs : number of joblib threads used in ``transform``.
    - resolution : resolution (int, or one per parameter) of the inferred grid.
    - plot : plot the computed signed measures.
    - filtration_quantile : stored; not read by this class.
    - expand : expand the simplextrees before computing (for flag complexes).
      Not supported for slicers.
    - normalize_filtrations : with a shared grid, express the output
      coordinates in the rescaled grid ``(f - min(f)) / std(f)``.
    - grid_strategy : strategy passed to ``compute_grid``, e.g. 'regular',
      'quantile' or 'exact'.
    - seed : seed of the subsample used to infer the grid.
    - fit_fraction : the fraction of data used to infer the grid.
    - out_resolution : defaults to ``resolution`` in ``fit``; not otherwise
      read by this class.
    - individual_grid : compute one grid per complex in ``transform``. Defaults
      to True for the 'regular_closest', 'exact', 'quantile' and 'partition'
      strategies.
    - enforce_null_mass : returns a zero mass measure, by thresholding the
      module at the last grid point if True.
    - flatten : stored; not read by this class.
    - backend : forwarded to ``mp.signed_measure``.
    """

    def __init__(
        self,
        # homological degrees + None for euler
        # (list defaults are kept for sklearn get_params; they are never mutated)
        degrees: list[int | None] = [],
        rank_degrees: list[int] = [],  # same for rank invariant
        filtration_grid: (
            Sequence[Sequence[np.ndarray]]
            # filtration values to consider. Format : [ filtration values of Fi for Fi:filtration values of parameter i]
            | None
        ) = None,
        progress=False,  # tqdm
        num_collapses: int | str = 0,  # edge collapses before computing
        n_jobs=None,
        resolution: (
            Iterable[int] | int | None
        ) = None,  # when filtration grid is not given, the resolution of the filtration grid to infer
        plot: bool = False,
        filtration_quantile: float = 0.0,  # quantile for inferring filtration grid
        expand=False,  # expand the simplextree before computing the homology
        normalize_filtrations: bool = False,
        grid_strategy: str = "exact",
        seed: int = 0,  # if fit_fraction is not 1, the seed sampling
        fit_fraction=1,  # the fraction of the data on which to fit
        out_resolution: Iterable[int] | int | None = None,
        individual_grid: Optional[
            bool
        ] = None,  # Can be significantly faster for some grid strategies, but can drop statistical performance
        enforce_null_mass: bool = False,
        flatten=True,
        backend: Optional[str] = None,
    ):
        super().__init__()
        self.degrees = degrees
        self.rank_degrees = rank_degrees
        self.filtration_grid = filtration_grid
        self.progress = progress
        self.num_collapses = num_collapses
        self.n_jobs = n_jobs
        self.resolution = resolution
        self.plot = plot
        self.backend = backend
        self.filtration_quantile = filtration_quantile
        # Will only work for non sparse output. (discrete matrices cannot be "rescaled")
        self.normalize_filtrations = normalize_filtrations
        self.grid_strategy = grid_strategy
        self._reconversion_grid = None
        self.expand = expand
        # will only refit the grid if filtration_grid has never been given.
        self._refit_grid = None
        self.seed = seed
        self.fit_fraction = fit_fraction
        self._transform_st = None
        self.out_resolution = out_resolution
        self.individual_grid = individual_grid
        self.enforce_null_mass = enforce_null_mass
        self._default_mass_location = None
        self.flatten = flatten
        self.num_parameters: int = 0

        return

    @staticmethod
    def _is_filtered_complex(input):
        """Return whether ``input`` is a SimplexTreeMulti or a (minpres) slicer."""
        return mp.simplex_tree_multi.is_simplextree_multi(input) or mp.slicer.is_slicer(
            input, allow_minpres=True
        )

    def _input_checks(self, X):
        """Validate the (num_data, num_axis) input and record ``_num_axis``."""
        assert len(X) > 0, "No filtered complex found. Cannot fit."
        assert self._is_filtered_complex(
            X[0][0]
        ), f"X[0] is not a known filtered complex, {X[0]=}, nor X[0][0]."
        self._num_axis = len(X[0])
        first = X[0][0]
        assert (
            not mp.slicer.is_slicer(first) or not self.expand
        ), "Cannot expand slicers."
        assert not mp.slicer.is_slicer(first) or not (
            isinstance(first, (tuple, list)) and first[0].is_minpres
        ), "Multi-degree minpres are not supported yet as an input. This can still be computed by providing a backend."

    def _infer_filtration(self, X):
        """Infer one shared grid per axis from a (seeded) subsample of ``X``.

        The exact filtration values of the sampled complexes are merged per
        axis and per parameter, then passed to ``compute_grid`` with
        ``self.resolution`` and ``self.grid_strategy``. The result is stored
        in ``self.filtration_grid`` and returned.
        """
        self.num_parameters = X[0][0].num_parameters
        num_samples = min(int(self.fit_fraction * len(X)) + 1, len(X))
        # Use the documented `seed` so the fit subsample is reproducible.
        rng = np.random.default_rng(self.seed)
        indices = rng.choice(len(X), num_samples, replace=False)
        ## ax, num_x
        filtrations = tuple(
            tuple(
                compute_grid(x, strategy="exact")
                for x in (X[idx][ax] for idx in indices)
            )
            for ax in range(self._num_axis)
        )
        num_parameters = len(filtrations[0][0])
        assert (
            num_parameters == self.num_parameters
        ), f"Internal error, got {num_parameters=} and {self.num_parameters=}"

        filtrations_values = [
            [
                np.unique(np.concatenate([x[i] for x in filtrations[ax]]))
                for i in range(num_parameters)
            ]
            for ax in range(self._num_axis)
        ]
        ## ax, param, gridsize
        filtration_grid = tuple(
            compute_grid(
                filtrations_values[ax],
                resolution=self.resolution,
                strategy=self.grid_strategy,
            )
            for ax in range(self._num_axis)
        )  # TODO :use more parameters
        self.filtration_grid = filtration_grid
        return filtration_grid

    def _params_check(self):
        """Ensure a non-exact grid strategy comes with a resolution or a grid."""
        assert (
            self.resolution is not None
            or self.filtration_grid is not None
            or self.grid_strategy == "exact"
            or self.individual_grid
        ), "For non exact filtrations, a resolution has to be specified."

    def fit(self, X, y=None):
        """Prepare the grid(s) used by ``transform``.

        Infers a shared grid per axis unless a grid was given or individual
        grids are used (without ``enforce_null_mass``). Also computes the
        reconversion grid and, if ``enforce_null_mass``, the default mass
        location (last value of each grid). Returns ``self``.
        """
        self._params_check()
        self._input_checks(X)

        if isinstance(self.resolution, int):
            # `self.num_parameters` is still 0 before the first grid inference;
            # read it from the data so the resolution is expanded correctly.
            self.num_parameters = X[0][0].num_parameters
            self.resolution = [self.resolution] * self.num_parameters

        if self.individual_grid is None:
            self.individual_grid = self.grid_strategy in [
                "regular_closest",
                "exact",
                "quantile",
                "partition",
            ]

        if (
            not self.enforce_null_mass and self.individual_grid
        ) or self.filtration_grid is not None:
            self._refit_grid = False
        else:
            self._refit_grid = True

        if self._refit_grid:
            self._infer_filtration(X=X)
        if self.out_resolution is None:
            self.out_resolution = self.resolution
        if self.normalize_filtrations and not self.individual_grid:
            self._reconversion_grid = [
                [(f - np.min(f)) / np.std(f) for f in F] for F in self.filtration_grid
            ]  # not the best, but better than some weird magic
        else:
            self._reconversion_grid = self.filtration_grid
        ## ax, num_param
        self._default_mass_location = (
            np.asarray([[g[-1] for g in F] for F in self.filtration_grid])
            if self.enforce_null_mass
            else None
        )
        return self

    def transform1(
        self,
        simplextree,
        ax,
        thread_id: str = "",
    ):
        """Compute the signed measures of one filtered complex of axis ``ax``.

        The complex is squeezed on its grid (individual or shared), optionally
        collapsed / expanded, then the Euler, Hilbert and rank signed measures
        are computed with ``mp.signed_measure``. Returns the list of signed
        measures ordered as ``self.degrees``, followed by the rank ones.
        """
        if self.individual_grid:
            filtration_grid = compute_grid(
                simplextree, strategy=self.grid_strategy, resolution=self.resolution
            )
            mass_default = (
                self._default_mass_location[ax] if self.enforce_null_mass else None
            )
            if self.enforce_null_mass:
                # Append the default mass location to each parameter's grid.
                filtration_grid = [
                    np.concatenate([f, [d]], axis=0)
                    for f, d in zip(filtration_grid, mass_default)
                ]
            _reconversion_grid = filtration_grid
        else:
            filtration_grid = self.filtration_grid[ax]
            _reconversion_grid = self._reconversion_grid[ax]
            mass_default = (
                self._default_mass_location[ax] if self.enforce_null_mass else None
            )

        st = simplextree.grid_squeeze(filtration_grid=filtration_grid)
        if st.num_parameters == 2 and mp.simplex_tree_multi.is_simplextree_multi(st):
            st.collapse_edges(num=self.num_collapses, max_dimension=1)
        int_degrees = np.asarray([d for d in self.degrees if d is not None], dtype=int)
        rank_degrees = np.asarray(self.rank_degrees, dtype=int)
        # EULER. First as there is prune above dimension below
        if self.expand and None in self.degrees:
            st.expansion(st.num_vertices)
        signed_measures_euler = (
            mp.signed_measure(
                st,
                degrees=[None],
                plot=self.plot,
                mass_default=mass_default,
                invariant="euler",
                backend=self.backend,
                grid=_reconversion_grid,
            )[0]
            if None in self.degrees
            else []
        )

        # Dimensions are cast to Python ints: concatenating with an empty
        # rank_degrees would otherwise yield floats.
        if self.expand and len(int_degrees) > 0:
            st.expansion(int(np.max(int_degrees)) + 1)
        if len(int_degrees) > 0:
            st.prune_above_dimension(
                int(np.max(np.concatenate([int_degrees, rank_degrees]))) + 1
            )  # no need to compute homology beyond this
        signed_measures_pers = (
            mp.signed_measure(
                st,
                degrees=int_degrees,
                mass_default=mass_default,
                plot=self.plot,
                invariant="hilbert",
                thread_id=thread_id,
                backend=self.backend,
                grid=_reconversion_grid,
            )
            if len(int_degrees) > 0
            else []
        )
        if self.plot:
            plt.show()
        if self.expand and len(self.rank_degrees) > 0:
            st.expansion(int(np.max(rank_degrees)) + 1)
        if len(self.rank_degrees) > 0:
            st.prune_above_dimension(
                int(np.max(rank_degrees)) + 1
            )  # no need to compute homology beyond this
        signed_measures_rank = (
            mp.signed_measure(
                st,
                degrees=self.rank_degrees,
                mass_default=mass_default,
                plot=self.plot,
                invariant="rank",
                thread_id=thread_id,
                backend=self.backend,
                grid=_reconversion_grid,
            )
            if len(self.rank_degrees) > 0
            else []
        )
        if self.plot:
            plt.show()

        # Reassemble in the order requested by `self.degrees`.
        count = 0
        signed_measures = []
        for d in self.degrees:
            if d is None:
                signed_measures.append(signed_measures_euler)
            else:
                signed_measures.append(signed_measures_pers[count])
                count += 1
        signed_measures += signed_measures_rank
        return signed_measures

    def transform(self, X):
        """Compute the signed measures of every complex of ``X``.

        ``X`` has shape (num_data, num_axis). Returns a tuple of shape
        (num_data, num_axis) of lists of signed measures (see ``transform1``).
        Work is distributed over ``n_jobs`` joblib threads.
        """
        assert (
            self.filtration_grid is not None and self._reconversion_grid is not None
        ) or self.individual_grid, "Fit first"

        def todo_x(x):
            return tuple(self.transform1(x_axis, j) for j, x_axis in enumerate(x))

        ## out shape num_x, num_axis, degree, sm
        out = tuple(
            Parallel(n_jobs=self.n_jobs, backend="threading")(
                delayed(todo_x)(x)
                for x in tqdm(
                    X,
                    disable=not self.progress,
                    desc="Computing signed measure decompositions",
                )
            )
        )
        return out
457
+
458
+
459
class SimplexTree2SignedMeasure(FilteredComplex2SignedMeasure):
    """Deprecated alias of ``FilteredComplex2SignedMeasure``.

    Takes exactly the same parameters, forwards them unchanged to the parent
    constructor, then emits a deprecation warning.
    """

    def __init__(
        self,
        # homological degrees + None for euler
        degrees: list[int | None] = [],
        rank_degrees: list[int] = [],  # same for rank invariant
        filtration_grid: (
            Sequence[Sequence[np.ndarray]]
            # filtration values to consider. Format : [ filtration values of Fi for Fi:filtration values of parameter i]
            | None
        ) = None,
        progress=False,  # tqdm
        num_collapses: int | str = 0,  # edge collapses before computing
        n_jobs=None,
        resolution: (
            Iterable[int] | int | None
        ) = None,  # when filtration grid is not given, the resolution of the filtration grid to infer
        plot: bool = False,
        filtration_quantile: float = 0.0,  # quantile for inferring filtration grid
        expand=False,  # expand the simplextree before computing the homology
        normalize_filtrations: bool = False,
        grid_strategy: str = "exact",
        seed: int = 0,  # if fit_fraction is not 1, the seed sampling
        fit_fraction=1,  # the fraction of the data on which to fit
        out_resolution: Iterable[int] | int | None = None,
        individual_grid: Optional[
            bool
        ] = None,  # Can be significantly faster for some grid strategies, but can drop statistical performance
        enforce_null_mass: bool = False,
        flatten=True,
        backend: Optional[str] = None,
    ):
        # Forward every parameter explicitly instead of introspecting locals().
        super().__init__(
            degrees=degrees,
            rank_degrees=rank_degrees,
            filtration_grid=filtration_grid,
            progress=progress,
            num_collapses=num_collapses,
            n_jobs=n_jobs,
            resolution=resolution,
            plot=plot,
            filtration_quantile=filtration_quantile,
            expand=expand,
            normalize_filtrations=normalize_filtrations,
            grid_strategy=grid_strategy,
            seed=seed,
            fit_fraction=fit_fraction,
            out_resolution=out_resolution,
            individual_grid=individual_grid,
            enforce_null_mass=enforce_null_mass,
            flatten=flatten,
            backend=backend,
        )
        from warnings import warn

        warn("This class is deprecated, use FilteredComplex2SignedMeasure instead.")
505
+
506
+
507
+ # class SimplexTrees2SignedMeasures(SimplexTree2SignedMeasure):
508
+ # """
509
+ # Input
510
+ # -----
511
+ #
512
+ # (data) x (axis, e.g. different bandwidths for simplextrees) x (simplextree)
513
+ #
514
+ # Output
515
+ # ------
516
+ # (data) x (axis) x (degree) x (signed measure)
517
+ # """
518
+ #
519
+ # def __init__(self, **kwargs):
520
+ # super().__init__(**kwargs)
521
+ # self._num_st_per_data = None
522
+ # # self._super_model=SimplexTree2SignedMeasure(**kwargs)
523
+ # self._filtration_grids = None
524
+ # return
525
+ #
526
+ # def fit(self, X, y=None):
527
+ # if len(X) == 0:
528
+ # return self
529
+ # try:
530
+ # self._num_st_per_data = len(X[0])
531
+ # except:
532
+ # raise Exception(
533
+ # "Shape has to be (num_data, num_axis), dtype=SimplexTreeMulti"
534
+ # )
535
+ # self._filtration_grids = []
536
+ # for axis in range(self._num_st_per_data):
537
+ # self._filtration_grids.append(
538
+ # super().fit([x[axis] for x in X]).filtration_grid
539
+ # )
540
+ # # self._super_fits.append(truc)
541
+ # # self._super_fits_params = [super().fit([x[axis] for x in X]).get_params() for axis in range(self._num_st_per_data)]
542
+ # return self
543
+ #
544
+ # def transform(self, X):
545
+ # if self.normalize_filtrations:
546
+ # _reconversion_grids = [
547
+ # [np.linspace(0, 1, num=len(f), dtype=float) for f in F]
548
+ # for F in self._filtration_grids
549
+ # ]
550
+ # else:
551
+ # _reconversion_grids = self._filtration_grids
552
+ #
553
+ # def todo(x):
554
+ # # return [SimplexTree2SignedMeasure().set_params(**transformer_params).transform1(x[axis]) for axis,transformer_params in enumerate(self._super_fits_params)]
555
+ # out = [
556
+ # self.transform1(
557
+ # x[axis],
558
+ # filtration_grid=filtration_grid,
559
+ # _reconversion_grid=_reconversion_grid,
560
+ # )
561
+ # for axis, filtration_grid, _reconversion_grid in zip(
562
+ # range(self._num_st_per_data),
563
+ # self._filtration_grids,
564
+ # _reconversion_grids,
565
+ # )
566
+ # ]
567
+ # return out
568
+ #
569
+ # return Parallel(n_jobs=self.n_jobs, backend="threading")(
570
+ # delayed(todo)(x)
571
+ # for x in tqdm(
572
+ # X,
573
+ # disable=not self.progress,
574
+ # desc="Computing Signed Measures from simplextrees.",
575
+ # )
576
+ # )
577
+
578
+
579
def rescale_sparse_signed_measure(
    signed_measure, filtration_weights, normalize_scales=None
):
    """
    Rescale the point coordinates of a sparse signed measure, degree by degree.

    Parameters
    ----------
    signed_measure : sequence of ``(points, weights)`` pairs, one per degree.
        ``points`` must support broadcasting arithmetic (numpy / torch).
    filtration_weights : array-like with a ``reshape`` method, or None
        Column-wise multiplicative factor applied to ``points``.
    normalize_scales : sequence or None
        ``normalize_scales[degree]`` (with a ``reshape`` method) divides the
        points of that degree, column-wise.

    Returns
    -------
    The very same object when both factors are None. Otherwise a new tuple of
    ``(points * weights / scales, weights_of_measure)`` pairs. The measure
    weights are passed through untouched.
    """
    if filtration_weights is None and normalize_scales is None:
        return signed_measure

    rescaled = []
    for degree, measure in enumerate(signed_measure):
        factor = (
            1 if filtration_weights is None else filtration_weights.reshape(1, -1)
        )
        divisor = (
            1
            if normalize_scales is None
            else normalize_scales[degree].reshape(1, -1)
        )
        # Keep "multiply, then divide" so that dtype promotion is unchanged
        # (e.g. integer points always end up as floats).
        rescaled.append((measure[0] * factor / divisor, measure[1]))
    return tuple(rescaled)
643
+
644
+
645
class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
    """
    Reformat / post-process sparse signed measures.

    Input
    -----

    (data) x (degree) x (signed measure) or (data) x (axis) x (degree) x (signed measure)

    Iterable[list[signed_measure_matrix of degree]] or Iterable[previous].

    The second is meant to use multiple choices for signed measure input. An
    example of usage: they come from a Rips + Density with different bandwidths.
    It is controlled by the ``axis`` parameter.

    A signed measure is a tuple ``(points, weights)`` where ``points`` is a 2D
    array/tensor of shape (num_points, num_parameters).

    Parameters
    ----------
    filtrations_weights : array of shape (num_parameters,) or None
        Multiplies the point coordinates, parameter-wise. NOTE: rescaling (and
        hence this weighting) is only performed when ``normalize`` is True.
    normalize : bool
        Divide the point coordinates by the (max - min) range of each
        (degree, parameter) computed at fit time (a constant range uses 1).
    plot : bool
        Plot the (rescaled) signed measures during ``transform``.
    unsparse : bool
        Histogram each measure on the grid inferred at fit time.
    axis : int
        -1 keeps every axis; otherwise only this axis is kept.
    resolution : int or Iterable[int]
        Grid resolution (per parameter) passed to ``compute_grid``.
    flatten : bool
        Flatten the outputs of ``integrate`` / ``unsparse``.
    deep_format : bool
        Convert each measure with ``sm2deep``.
    unrag : bool
        With ``deep_format``, zero-pad everything into a single tensor.
    n_jobs : int
        Stored; not used by this class.
    verbose : bool
        Print statistics at fit time (also forwarded to ``api_from_tensor``).
    integrate : bool
        Integrate each measure on the grid inferred at fit time.
    grid_strategy : str
        Strategy passed to ``compute_grid``.

    At most one of ``deep_format``, ``unsparse`` and ``integrate`` may be set.
    NOTE: when ``normalize`` is True no grid is inferred at fit time, so
    ``integrate`` and ``unsparse`` cannot be combined with it.

    Output
    ------

    Iterable[list[(reweighted)_sparse_signed_measure of degree]]

    or (deep format)

    Tensor of shape (num_axis*num_degrees, data, max_num_pts, num_parameters)
    """

    def __init__(
        self,
        filtrations_weights: Optional[Iterable[float]] = None,
        normalize=False,
        plot: bool = False,
        unsparse: bool = False,
        axis: int = -1,
        resolution: int | Iterable[int] = 50,
        flatten: bool = False,
        deep_format: bool = False,
        unrag: bool = True,
        n_jobs: int = 1,
        verbose: bool = False,
        integrate: bool = False,
        grid_strategy="regular",
    ):
        super().__init__()
        # At most one post-processing step. Checked on the arguments rather
        # than on ``self`` so that it never reads an attribute that has not
        # been assigned yet.
        assert (
            sum(bool(flag) for flag in (deep_format, unsparse, integrate)) <= 1
        ), "One post processing at the time."
        self.filtrations_weights = filtrations_weights
        self.num_parameters: int = 0
        self.plot = plot
        self.unsparse = unsparse
        self.n_jobs = n_jobs
        self.axis = axis
        self._num_axis = 0
        self.resolution = resolution
        self._filtrations_bounds = None
        self.flatten = flatten
        self.normalize = normalize
        self._normalization_factors = None
        self.deep_format = deep_format
        self.unrag = unrag
        self.verbose = verbose
        self._num_degrees = 0
        self.integrate = integrate
        self.grid_strategy = grid_strategy
        self._infered_grids = None
        # Set by ``_check_axis`` at fit time. It is defined here so that
        # ``transform`` never hits a missing attribute (e.g. after fitting on
        # empty data, where ``fit`` returns early).
        self._has_axis = False
        self._axis_iterator = None
        self._backend = None

    def _get_filtration_bounds(self, X, axis):
        """
        Per-degree coordinate bounds of the measures of ``axis`` over all data.

        Returns
        -------
        filtrations_bounds : array of shape (num_degrees, 2, num_parameters)
            ``[:, 0]`` holds the minima and ``[:, 1]`` the maxima. It is
            converted with the backend's ``asnumpy``, so no gradient flows
            through the normalization.
        normalization_factors : array of shape (num_degrees, num_parameters) or None
            ``max - min``, with constant coordinates mapped to 1. None when
            ``normalize`` is False.
        """
        stuff = [
            self._backend.cat(
                [sm[axis][degree][0] for sm in X],
                axis=0,
            )
            for degree in range(self._num_degrees)
        ]
        sizes_ = np.array([len(x) == 0 for x in stuff])
        assert np.all(~sizes_), f"Degree axis {np.where(sizes_)} is/are trivial !"

        filtrations_bounds = self._backend.asnumpy(
            self._backend.stack(
                [
                    self._backend.stack(
                        [
                            self._backend.minvalues(f, axis=0),
                            self._backend.maxvalues(f, axis=0),
                        ]
                    )
                    for f in stuff
                ]
            )
        )  ## don't want to rescale gradient of normalization
        normalization_factors = (
            filtrations_bounds[:, 1] - filtrations_bounds[:, 0]
            if self.normalize
            else None
        )
        if normalization_factors is not None:
            # Constant filtration: avoid a division by zero when rescaling.
            normalization_factors[normalization_factors == 0] = 1
        return filtrations_bounds, normalization_factors

    def _plot_signed_measures(self, sms: Iterable[np.ndarray], size=4):
        """Plot a (data) x (degree) grid of signed measures."""
        from multipers.plots import plot_signed_measure

        num_degrees = len(sms[0])
        num_imgs = len(sms)
        fig, axes = plt.subplots(
            ncols=num_degrees,
            nrows=num_imgs,
            figsize=(size * num_degrees, size * num_imgs),
        )
        axes = np.asarray(axes).reshape(num_imgs, num_degrees)
        for i, sm in enumerate(sms):
            for j, sm_of_degree in enumerate(sm):
                plot_signed_measure(sm_of_degree, ax=axes[i, j])

    @staticmethod
    def _check_sm(sm) -> bool:
        """Vaguely check that ``sm`` is a ``(points, weights)`` pair with 2D points."""
        # The length is checked first so that an empty tuple returns False
        # instead of raising an IndexError.
        return (
            isinstance(sm, tuple)
            and len(sm) == 2
            and hasattr(sm[0], "ndim")
            and sm[0].ndim == 2
        )

    def _check_axis(self, X):
        """
        Detect whether the data has an axis level and set the axis attributes.

        The layout with an axis is (num_data, num_axis, num_degrees, (signed_measure)).
        """
        if len(X) == 0:
            return
        if len(X[0]) == 0:
            return
        if self._check_sm(X[0][0]):
            self._has_axis = False
            self._num_axis = 1
            self._axis_iterator = [slice(None)]
            return
        assert self._check_sm(  ## vaguely checks that its a signed measure
            _sm := X[0][0][0]
        ), f"Cannot take this input. # data, axis, degrees, sm.\n Got {_sm} of type {type(_sm)}"

        self._has_axis = True
        self._num_axis = len(X[0])
        self._axis_iterator = range(self._num_axis) if self.axis == -1 else [self.axis]

    def _check_backend(self, X):
        """Pick the array backend from the type of the first point array."""
        if self._has_axis:
            # data, axis, degrees, (pts, weights)
            first_sm = X[0][0][0][0]
        else:
            first_sm = X[0][0][0]
        self._backend = api_from_tensor(first_sm, verbose=self.verbose)

    def _check_measures(self, X):
        """Read the number of degrees and of parameters from the first datum."""
        if self._has_axis:
            first_sm = X[0][0]
        else:
            first_sm = X[0]
        self._num_degrees = len(first_sm)
        self.num_parameters = first_sm[0][0].shape[1]

    def _check_resolution(self):
        """Broadcast an integer resolution to one entry per parameter."""
        assert self.num_parameters > 0, "Num parameters hasn't been initialized."
        if isinstance(self.resolution, int):
            self.resolution = [self.resolution] * self.num_parameters
        self.resolution = np.asarray(self.resolution, dtype=int)
        assert (
            self.resolution.shape[0] == self.num_parameters
        ), "Resolution doesn't have a proper size."

    def _check_weights(self):
        """Check that there is one filtration weight per parameter."""
        if self.filtrations_weights is None:
            return
        assert (
            self.filtrations_weights.shape[0] == self.num_parameters
        ), "Filtration weights don't have a proper size"

    def _infer_grids(self, X):
        """
        Compute the fit-time statistics for each selected axis.

        - ``normalize=True``: filtration bounds and normalization factors.
        - otherwise, if a post-processing step is requested: a grid obtained
          from ``compute_grid`` over the points of every datum and degree.
        """

        def as_numpy(pts):
            return pts if isinstance(pts, np.ndarray) else pts.detach().numpy()

        if self.normalize:
            self._filtrations_bounds = []
            self._normalization_factors = []
            for ax in self._axis_iterator:
                (
                    filtration_bounds,
                    normalization_factors,
                ) = self._get_filtration_bounds(X, axis=ax)
                self._filtrations_bounds.append(filtration_bounds)
                self._normalization_factors.append(normalization_factors)
            self._filtrations_bounds = self._backend.astensor(self._filtrations_bounds)
            self._normalization_factors = self._backend.astensor(
                self._normalization_factors
            )
        elif self.integrate or self.unsparse or self.deep_format:
            # axis, filtration_values
            filtration_values = [
                np.concatenate(
                    [
                        as_numpy(x[ax][degree][0])
                        for x in X
                        for degree in range(self._num_degrees)
                    ]
                )
                for ax in self._axis_iterator
            ]
            self._infered_grids = [
                self._backend.astensor(
                    compute_grid(
                        f_ax.T, resolution=self.resolution, strategy=self.grid_strategy
                    )
                )
                for f_ax in filtration_values
            ]

    def _print_stats(self, X):
        """Print the fitted parameters and signed-measure size statistics."""
        print("------------SignedMeasureFormatter------------")
        print("---- Parameters")
        print(f"Number of axis : {self._num_axis}")
        print(f"Number of degrees : {self._num_degrees}")
        print(f"Filtration bounds : \n{self._filtrations_bounds}")
        print(f"Normalization factor : \n{self._normalization_factors}")
        if self._infered_grids is not None:
            grid_shapes = tuple(
                tuple(len(f) for f in F) for F in self._infered_grids
            )
            print(f"Filtration grid shape : \n {grid_shapes}")
        print("---- SM stats")
        print("In axis :", self._num_axis)
        sizes = [
            [[len(xd[1]) for xd in x[ax]] for x in X] for ax in self._axis_iterator
        ]
        print(f"Size means (axis) x (degree): {np.mean(sizes, axis=(1))}")
        print(f"Size std : {np.std(sizes, axis=(1))}")
        print("----------------------------------------------")

    def fit(self, X, y=None):
        """Detect the input layout and backend, then compute grids / normalization."""
        if (
            len(X) == 0
            or len(X[0]) == 0
            or (self.axis is not None and len(X[0][0][0]) == 0)
        ):
            return self

        self._check_axis(X)
        self._check_backend(X)
        self._check_measures(X)
        self._check_resolution()
        self._check_weights()
        self._infer_grids(X)
        if self.verbose:
            self._print_stats(X)
        return self

    def _measure_axes(self):
        """
        Indices used to read each kept axis from an *axis-selected* datum.

        These are aligned with the fitted grids / normalization factors.
        ``transform`` already drops the axis level when a single axis is
        selected, so only the "every axis" case keeps the axis indices.
        """
        if self._has_axis and self.axis == -1:
            return self._axis_iterator
        return [slice(None)]

    def _unsparse(self, sparse_signed_measure, axes):
        """
        Histogram one datum on the fitted grids.

        ``axes`` indexes the axes of ``sparse_signed_measure`` and must be
        aligned with ``self._infered_grids``.
        """
        out = []
        for filtrations_of_ax, ax in zip(self._infered_grids, axes, strict=True):
            measure_of_ax = []
            for pts, weights in sparse_signed_measure[ax]:  # over degree
                signed_measure, _ = np.histogramdd(
                    pts, bins=filtrations_of_ax, weights=weights
                )
                if self.flatten:
                    signed_measure = signed_measure.flatten()
                measure_of_ax.append(signed_measure)
            out.append(np.asarray(measure_of_ax))

        if self.flatten:
            # One vector for every (axis, degree) histogram of this datum.
            return np.concatenate(out).flatten()
        if self.axis == -1:
            return np.asarray(out)
        return np.asarray(out)[0]

    def unsparse_signed_measure(self, sparse_signed_measure):
        """
        Histogram one *raw* datum (as given to ``fit``) on the fitted grids.

        Its axes are read with the fitted axis iterator.
        """
        return self._unsparse(sparse_signed_measure, self._axis_iterator)

    @staticmethod
    def _integrate_measure(sm, filtrations):
        """Integrate one ``(points, weights)`` measure on ``filtrations``."""
        from multipers.point_measure import integrate_measure

        return integrate_measure(sm[0], sm[1], filtrations)

    def _rescale_measures(self, X):
        """Apply the filtration weights and normalization factors to every datum."""

        def rescale_from_sparse(sparse_signed_measure):
            if self.axis == -1 and self._has_axis:
                return tuple(
                    rescale_sparse_signed_measure(
                        sparse_signed_measure[ax],
                        filtration_weights=self.filtrations_weights,
                        normalize_scales=n,
                    )
                    for ax, n in zip(
                        self._axis_iterator, self._normalization_factors, strict=True
                    )
                )
            return rescale_sparse_signed_measure(  ## axis iterator is of size 1 here
                sparse_signed_measure,
                filtration_weights=self.filtrations_weights,
                normalize_scales=self._normalization_factors[0],
            )

        return tuple(rescale_from_sparse(x) for x in X)

    def transform(self, X):
        """
        Select the axis, rescale, then apply the requested post-processing.

        Raises
        ------
        NotImplementedError
            When ``integrate`` is requested while every axis is kept.
        """
        if not self._has_axis or self.axis == -1:
            out = X
        else:
            out = tuple(x[self.axis] for x in X)
        # From here on, the axes of each datum are given by ``_measure_axes``.

        if self._normalization_factors is not None:
            out = self._rescale_measures(out)

        if self.plot:
            self._plot_signed_measures(out)

        if self.integrate:
            if self._has_axis and self.axis == -1:
                raise NotImplementedError(
                    "Not implemented. Can only integrate a single axis (axis != -1)."
                )
            filtrations = self._infered_grids[0]  # exactly one axis is kept here
            out = np.asarray(
                [
                    [
                        self._integrate_measure(x[degree], filtrations=filtrations)
                        for degree in range(self._num_degrees)
                    ]
                    for x in out
                ]
            )
            if self.flatten:
                out = out.reshape((len(X), -1))
        elif self.unsparse:
            axes = self._measure_axes()
            out = [self._unsparse(x, axes) for x in out]
        elif self.deep_format:
            num_degrees = self._num_degrees
            axes = self._measure_axes()
            # Outer index = degree * len(axes) + axis position.
            out = tuple(
                tuple(sm2deep(sm[axis][degree]) for sm in out)
                for degree in range(num_degrees)
                for axis in axes
            )
            if self.unrag:
                max_num_pts = np.max(
                    [sm.shape[0] for sm_of_axis in out for sm in sm_of_axis]
                )
                num_axis_degree = len(out)
                num_data = len(out[0])
                assert (
                    num_axis_degree == num_degrees * len(axes)
                ), f"Bad axis/degree count. Got {num_axis_degree} (Internal error)"
                num_parameters = out[0][0].shape[1]
                dtype = out[0][0].dtype
                unragged_tensor = self._backend.zeros(
                    (
                        num_axis_degree,
                        num_data,
                        max_num_pts,
                        num_parameters,
                    ),
                    dtype=dtype,
                )
                # Zero-pad every measure up to the largest number of points.
                for ax in range(num_axis_degree):
                    for data in range(num_data):
                        sm = out[ax][data]
                        a, b = sm.shape
                        unragged_tensor[ax, data, :a, :b] = sm
                out = unragged_tensor
        return out
1035
+
1036
+
1037
class SignedMeasure2Convolution(BaseEstimator, TransformerMixin):
    """
    Discrete convolution of a signed measure

    Input
    -----

    (data) x (degree) x (signed measure)

    Signed measures are either sparse ``(points, weights)`` tuples or dense
    arrays (images). This is detected at fit time.

    Parameters
    ----------
    - filtration_grid : Iterable[array] For each filtration, the filtration values on which to evaluate the grid
    - kernel : kernel used to convolve the images.
    - bandwidth : float or Iterable[float]. A negative value is relative: to the
      grid diameter (sparse input), or to the image size along each axis (dense input).
    - flatten : flatten the images if True
    - n_jobs : number of parallel jobs
    - resolution : int or (num_parameters) : If filtration grid is not given, will infer a grid, with this resolution
    - grid_strategy : the strategy to generate the grid. Available ones are regular, quantile, exact
    - progress : progress bar if True
    - backend : sklearn, pykeops or numba.
    - plot : Creates a plot Figure.
    - log_density : stored; not used by this class.
    - **kde_kwargs : forwarded to ``convolution_signed_measures`` (sparse input).

    Output
    ------

    (data) x (concatenation of imgs of degree)
    """

    def __init__(
        self,
        filtration_grid: Iterable[np.ndarray] = None,
        kernel: available_kernels = "gaussian",
        bandwidth: float | Iterable[float] = 1.0,
        flatten: bool = False,
        n_jobs: int = 1,
        resolution: int | None = None,
        grid_strategy: str = "regular",
        progress: bool = False,
        backend: str = "pykeops",
        plot: bool = False,
        log_density: bool = False,
        **kde_kwargs,
    ):
        super().__init__()
        self.kernel: available_kernels = kernel
        self.bandwidth = bandwidth
        self.filtration_grid = filtration_grid
        self.flatten = flatten
        self.progress = progress
        self.n_jobs = n_jobs
        self.resolution = resolution
        self.grid_strategy = grid_strategy
        self._is_input_sparse = None
        # A grid is computed at fit time only if none was given.
        self._refit = filtration_grid is None
        self._input_resolution = None
        self._bandwidths = None
        self.diameter = None
        self.backend = backend
        self.plot = plot
        self.log_density = log_density
        self.kde_kwargs = kde_kwargs
        self._api = None

    def fit(self, X, y=None):
        """
        Detect sparse vs dense input, and prepare the grid or the bandwidths.

        Raises
        ------
        Exception
            For sparse input when neither a grid nor a resolution is given.
        """
        if len(X) == 0:
            return self
        if isinstance(X[0][0], tuple):
            self._is_input_sparse = True
            self._api = api_from_tensor(X[0][0][0], verbose=self.progress)
        else:
            self._is_input_sparse = False
            self._api = api_from_tensor(X, verbose=self.progress)

        if not self._is_input_sparse:
            # Signed measures are already images; the grid is implicit and
            # only per-axis bandwidths (in pixels) are needed.
            self._input_resolution = X[0][0].shape
            try:
                scalar_bandwidth = float(self.bandwidth)
            except (TypeError, ValueError, RuntimeError):
                # Not a scalar: one bandwidth per image axis.
                per_axis_bandwidths = self.bandwidth
            else:
                per_axis_bandwidths = [scalar_bandwidth] * len(self._input_resolution)
            self._bandwidths = [
                b if b > 0 else -b * s
                for s, b in zip(self._input_resolution, per_axis_bandwidths)
            ]
            return self

        if self.filtration_grid is None and self.resolution is None:
            raise Exception(
                "Cannot infer filtration grid. Provide either a filtration grid or a resolution."
            )
        if self._refit:
            pts = self._api.cat(
                [sm[0] for signed_measures in X for sm in signed_measures]
            ).T
            self.filtration_grid = compute_grid(
                pts,
                strategy=self.grid_strategy,
                resolution=self.resolution,
            )
        if self.filtration_grid is not None:
            self.diameter = self._api.norm(
                self._api.astensor([f[-1] - f[0] for f in self.filtration_grid])
            )
            if self.progress:
                print(f"Computed a diameter of {self.diameter}")
        return self

    def _sm2smi(self, signed_measures):
        """Gaussian-blur each dense signed measure and concatenate along axis 0."""
        from scipy.ndimage import gaussian_filter

        return np.concatenate(
            [
                gaussian_filter(
                    input=signed_measure,
                    sigma=self._bandwidths,
                    mode="constant",
                    cval=0,
                )
                for signed_measure in signed_measures
            ],
            axis=0,
        )

    def _transform_from_sparse(self, X):
        """Convolve sparse signed measures on ``self.filtration_grid``."""
        bandwidth = (
            self.bandwidth if self.bandwidth > 0 else -self.bandwidth * self.diameter
        )
        # Warm-up call on a tiny problem (compiles KeOps first) before the
        # real, possibly parallel, computation.
        dummyx = [X[0]]
        dummyf = [f[:2] for f in self.filtration_grid]
        convolution_signed_measures(
            dummyx,
            filtrations=dummyf,
            bandwidth=bandwidth,
            flatten=self.flatten,
            n_jobs=1,
            kernel=self.kernel,
            backend=self.backend,
        )

        return convolution_signed_measures(
            X,
            filtrations=self.filtration_grid,
            bandwidth=bandwidth,
            flatten=self.flatten,
            n_jobs=self.n_jobs,
            kernel=self.kernel,
            backend=self.backend,
            **self.kde_kwargs,
        )

    def _plot_imgs(self, imgs: Iterable[np.ndarray], size=4):
        """Plot a (data) x (degree) grid of convolved images."""
        from multipers.plots import plot_surface

        imgs = self._api.asnumpy(imgs)
        num_degrees = imgs[0].shape[0]
        num_imgs = len(imgs)
        fig, axes = plt.subplots(
            ncols=num_degrees,
            nrows=num_imgs,
            figsize=(size * num_degrees, size * num_imgs),
        )
        axes = np.asarray(axes).reshape(num_imgs, num_degrees)
        for i, img in enumerate(imgs):
            for j, img_of_degree in enumerate(img):
                plot_surface(
                    [self._api.asnumpy(f) for f in self.filtration_grid],
                    img_of_degree,
                    ax=axes[i, j],
                    cmap="Spectral",
                )

    def transform(self, X):
        """Convolve every datum; see the class docstring for the output layout."""
        if self._is_input_sparse is None:
            raise Exception("Fit first")
        if self._is_input_sparse:
            out = self._transform_from_sparse(X)
        else:
            todo = SignedMeasure2Convolution._sm2smi
            out = Parallel(n_jobs=self.n_jobs, backend="threading")(
                delayed(todo)(self, signed_measures)
                for signed_measures in tqdm(
                    X, desc="Computing images", disable=not self.progress
                )
            )
            out = self._api.cat([x[None] for x in out])
        if self.plot and not self.flatten:
            if self.progress:
                print("Plotting convolutions...", end="")
            self._plot_imgs(out)
            if self.progress:
                print("Done !")
        if self.flatten and not self._is_input_sparse:
            out = self._api.cat([x.ravel()[None] for x in out])
        return out
1243
+
1244
+
1245
class SignedMeasure2SlicedWassersteinDistance(BaseEstimator, TransformerMixin):
    """
    Transformer from signed measures to (sliced) Wasserstein distance matrices.

    Input
    -----

    (data) x (degree) x (signed measure)

    Format
    ------
    - a signed measure: tuple of arrays, (point positions): npts x (num_parameters)
      and weights: npts
    - each data is a list of signed measures (e.g. one per degree)

    Parameters
    ----------
    n_jobs : parallelism, forwarded to the per-degree distances and to joblib.
    num_directions : number of slicing directions (sliced distance only).
    _sliced : use ``SlicedWassersteinDistance`` if True, else ``WassersteinDistance``.
    epsilon, ground_norm : forwarded to ``WassersteinDistance``.
    progress : show a progress bar in ``transform``.
    grid_reconversion : stored; not used by this class.
    scales : forwarded to ``SlicedWassersteinDistance``.

    Output
    ------
    - (degree) x (distance matrix)
    """

    def __init__(
        self,
        n_jobs=None,
        num_directions: int = 10,
        _sliced: bool = True,
        epsilon=-1,
        ground_norm=1,
        progress=False,
        grid_reconversion=None,
        scales=None,
    ):
        super().__init__()
        self.n_jobs = n_jobs
        self.num_directions = num_directions
        self._sliced = _sliced
        self.epsilon = epsilon
        self.ground_norm = ground_norm
        self.progress = progress
        self.grid_reconversion = grid_reconversion
        self.scales = scales
        self._SWD_list = None

    def fit(self, X, y=None):
        """Fit one distance per degree on the signed measures of that degree."""
        from multipers.ml.sliced_wasserstein import (
            SlicedWassersteinDistance,
            WassersteinDistance,
        )

        if len(X) == 0:
            return self

        def new_distance():
            if self._sliced:
                return SlicedWassersteinDistance(
                    num_directions=self.num_directions,
                    n_jobs=self.n_jobs,
                    scales=self.scales,
                )
            # WARNING: the non-sliced Wasserstein distance is not CNSD.
            return WassersteinDistance(
                epsilon=self.epsilon,
                ground_norm=self.ground_norm,
                n_jobs=self.n_jobs,
            )

        self._SWD_list = [new_distance() for _ in range(len(X[0]))]
        for degree, distance in enumerate(self._SWD_list):
            distance.fit([x[degree] for x in X])
        return self

    def transform(self, X):
        """Return the stacked per-degree distance matrices."""
        assert self._SWD_list is not None, "Fit first"
        with tqdm(
            enumerate(self._SWD_list),
            desc="Computing distance matrices",
            total=len(self._SWD_list),
            disable=not self.progress,
        ) as progress_bar:
            matrices = Parallel(n_jobs=self.n_jobs, prefer="threads")(
                delayed(distance.transform)([x[degree] for x in X])
                for degree, distance in progress_bar
            )
        return np.asarray(matrices)

    def predict(self, X):
        """Alias of ``transform``."""
        return self.transform(X)
1340
+
1341
+
1342
class SignedMeasures2SlicedWassersteinDistances(BaseEstimator, TransformerMixin):
    """
    Transformer from signed measures to distance matrices, per axis and scale.

    Input
    -----
    (data) x opt (axis) x (degree) x (signed measure)

    Format
    ------
    - a signed measure: tuple of arrays, (point positions): npts x (num_parameters)
      and weights: npts
    - each data is a list of signed measures (e.g. one per degree)

    Parameters
    ----------
    progress : show a progress bar in ``transform`` and print the fitted axes.
    n_jobs : number of joblib jobs used in ``transform``.
    scales : None, one list of scales, or a list of lists of scales. One child
        ``SignedMeasure2SlicedWassersteinDistance`` is fitted per (axis, scales).
    **kwargs : forwarded to every child.

    Output
    ------
    - one (degree) x (distance matrix) entry per (axis, scales) pair
    """

    def __init__(
        self,
        progress=False,
        n_jobs: int = 1,
        scales: Iterable[Iterable[float]] | None = None,
        **kwargs,
    ):  # same init
        super().__init__()
        self._init_child = SignedMeasure2SlicedWassersteinDistance(
            progress=False, scales=None, n_jobs=-1, **kwargs
        )
        self._axe_iterator = None
        self._childs_to_fit = None
        # Fitted copy of ``scales``. The constructor parameter itself is never
        # modified, so refitting (or sklearn's ``clone``) keeps working.
        self._scales = None
        self.scales = scales
        self.progress = progress
        self.n_jobs = n_jobs

    def fit(self, X, y=None):
        """Fit one child per (axis, scales) pair."""
        from sklearn.base import clone

        if len(X) == 0:
            return self
        if isinstance(X[0][0], tuple):  # Meaning that there are no axes
            self._axe_iterator = [slice(None)]
        else:
            self._axe_iterator = range(len(X[0]))

        if self.scales is None:
            scales = [None]
        else:
            scales = np.asarray(self.scales)
            if scales.ndim == 1:
                scales = np.asarray([scales])
            assert (
                scales.ndim == 2
            ), "Scales have to be either None or a list of scales !"
        self._scales = scales

        self._childs_to_fit = [
            clone(self._init_child).set_params(scales=s).fit([x[axis] for x in X])
            for axis, s in product(self._axe_iterator, self._scales)
        ]
        if self.progress:
            print("New axes : ", list(product(self._axe_iterator, self._scales)))
        return self

    def transform(self, X):
        """Compute, in parallel, the distance matrices of every fitted child."""
        return Parallel(n_jobs=self.n_jobs, prefer="processes")(
            delayed(self._childs_to_fit[child_id].transform)([x[axis] for x in X])
            for child_id, (axis, _) in tqdm(
                enumerate(product(self._axe_iterator, self._scales)),
                desc="Computing distances matrices of axis, and scales",
                disable=not self.progress,
                total=len(self._childs_to_fit),
            )
        )
1417
+
1418
+
1419
class SimplexTree2RectangleDecomposition(BaseEstimator, TransformerMixin):
    """
    Transformer. 2 parameter SimplexTrees to their respective rectangle decomposition.

    Each simplex tree is turned into a list (one entry per degree in
    ``degrees``) of ``_st2ranktensor`` outputs on ``filtration_grid``.
    """

    def __init__(
        self,
        filtration_grid: np.ndarray,
        degrees: Iterable[int],
        plot=False,
        reconvert_grid=True,
        num_collapses: int = 0,
    ):
        super().__init__()
        self.filtration_grid = filtration_grid
        self.degrees = degrees
        self.plot = plot
        self.reconvert_grid = reconvert_grid
        self.num_collapses = num_collapses

    def fit(self, X, y=None):
        """
        No-op: the grid is given at construction.

        TODO : infer grid from multiple simplextrees
        """
        return self

    def transform(self, X: Iterable[mp.simplex_tree_multi.SimplexTreeMulti_type]):
        """Return, for each simplex tree, its decompositions in every degree."""
        decompositions = []
        for simplextree in X:
            per_degree = []
            for degree in self.degrees:
                per_degree.append(
                    _st2ranktensor(
                        simplextree,
                        filtration_grid=self.filtration_grid,
                        degree=degree,
                        plot=self.plot,
                        reconvert_grid=self.reconvert_grid,
                        num_collapse=self.num_collapses,
                    )
                )
            decompositions.append(per_degree)
        # TODO : return iterator ?
        return decompositions
1463
+
1464
+
1465
def _st2ranktensor(
    st: mp.simplex_tree_multi.SimplexTreeMulti_type,
    filtration_grid: np.ndarray,
    degree: int,
    plot: bool,
    reconvert_grid: bool,
    num_collapse: int | str = 0,
):
    """
    Rectangle decomposition of the rank invariant of ``st`` in ``degree``.

    Parameters
    ----------
    st : multi-parameter simplex tree.
    filtration_grid : per-parameter filtration values used to squeeze ``st``.
    degree : homological degree.
    plot : forwarded to ``tensor_möbius_inversion``.
    reconvert_grid : if True, the output coordinates are converted back using
        ``filtration_grid``; otherwise no grid conversion is applied.
    num_collapse : ``"full"`` or an integer number of edge collapses.

    Returns
    -------
    The output of ``tensor_möbius_inversion`` on the rectangle decomposition.

    Raises
    ------
    TypeError
        If ``num_collapse`` is neither ``"full"`` nor an integer.
    """
    # Turn the simplex tree into a coordinate simplex tree on the grid.
    stcpy = st.grid_squeeze(filtration_grid=filtration_grid, coordinate_values=True)
    if num_collapse == "full":
        stcpy.collapse_edges(full=True, ignore_warning=True, max_dimension=degree + 1)
    elif isinstance(num_collapse, int):
        stcpy.collapse_edges(
            num=num_collapse, ignore_warning=True, max_dimension=degree + 1
        )
    else:
        # Single-line message: a backslash continuation inside the f-string
        # would embed the source indentation in the error text.
        raise TypeError(
            f"Invalid num_collapse={num_collapse} type. Either full, or an integer."
        )
    # computes the rank invariant tensor
    rank_tensor = mp.rank_invariant2d(
        stcpy, degree=degree, grid_shape=[len(f) for f in filtration_grid]
    )
    # refactor this tensor into the rectangle decomposition of the signed betti
    grid_conversion = filtration_grid if reconvert_grid else None
    rank_decomposition = rank_decomposition_by_rectangles(
        rank_tensor,
        threshold=True,
    )
    rectangle_decomposition = tensor_möbius_inversion(
        tensor=rank_decomposition,
        grid_conversion=grid_conversion,
        plot=plot,
        num_parameters=st.num_parameters,
    )
    return rectangle_decomposition
1509
+
1510
+
1511
class DegreeRips2SignedMeasure(BaseEstimator, TransformerMixin):
    """
    Point clouds to degree-Rips Hilbert functions / signed measures.

    For each point cloud, ``hf_degree_rips`` is run on its distance matrix. The
    Hilbert function of each requested degree is then optionally Möbius
    inverted (``signed_betti``) and optionally made sparse
    (``tensor_möbius_inversion``).

    Parameters
    ----------
    degrees : homological degrees to keep.
    min_rips_value, max_rips_value : Rips scale range. A negative
        ``max_rips_value`` is relative: it is multiplied by minus the largest
        diameter found on a random ``fit_fraction`` of the fitted point clouds.
    max_normalized_degree, min_normalized_degree : normalized degree range.
    grid_granularity : forwarded to ``hf_degree_rips``.
    progress : show the progress bar and the scale-estimation messages.
    n_jobs : number of joblib jobs used in ``transform``.
    sparse : convert each output to a sparse measure.
    _möbius_inversion : apply ``signed_betti`` to the Hilbert functions.
    fit_fraction : fraction of the data used to estimate the scale.
    """

    def __init__(
        self,
        degrees: Iterable[int],
        min_rips_value: float,
        max_rips_value,
        max_normalized_degree: float,
        min_normalized_degree: float,
        grid_granularity: int,
        progress: bool = False,
        n_jobs=1,
        sparse: bool = False,
        _möbius_inversion=True,
        fit_fraction=1,
    ) -> None:
        super().__init__()
        self.min_rips_value = min_rips_value
        self.max_rips_value = max_rips_value
        self.min_normalized_degree = min_normalized_degree
        self.max_normalized_degree = max_normalized_degree
        self._max_rips_value = None
        self.grid_granularity = grid_granularity
        self.progress = progress
        self.n_jobs = n_jobs
        self.degrees = degrees
        self.sparse = sparse
        self._möbius_inversion = _möbius_inversion
        self.fit_fraction = fit_fraction

    def fit(self, X: np.ndarray | list, y=None):
        """Resolve the (possibly relative) maximal Rips value."""
        if self.max_rips_value < 0:
            if self.progress:
                print("Estimating scale...", flush=True, end="")
            indices = np.random.choice(
                len(X), min(len(X), int(self.fit_fraction * len(X)) + 1), replace=False
            )
            diameters = np.max(
                [distance_matrix(x, x).max() for x in (X[i] for i in indices)]
            )
            if self.progress:
                print(f"Done. {diameters}", flush=True)
        self._max_rips_value = (
            -self.max_rips_value * diameters
            if self.max_rips_value < 0
            else self.max_rips_value
        )
        return self

    def _transform1(self, data: np.ndarray):
        """Signed measures (one per requested degree) of a single point cloud."""
        _distance_matrix = distance_matrix(data, data)
        signed_measures = []
        (
            rips_values,
            normalized_degree_values,
            hilbert_functions,
            minimal_presentations,
        ) = hf_degree_rips(
            _distance_matrix,
            min_rips_value=self.min_rips_value,
            max_rips_value=self._max_rips_value,
            min_normalized_degree=self.min_normalized_degree,
            max_normalized_degree=self.max_normalized_degree,
            grid_granularity=self.grid_granularity,
            max_homological_dimension=np.max(self.degrees),
        )
        for degree in self.degrees:
            hilbert_function = hilbert_functions[degree]
            signed_measure = (
                signed_betti(hilbert_function, threshold=True)
                if self._möbius_inversion
                else hilbert_function
            )
            if self.sparse:
                signed_measure = tensor_möbius_inversion(
                    tensor=signed_measure,
                    num_parameters=2,
                    grid_conversion=[rips_values, normalized_degree_values],
                )
            if not self._möbius_inversion:
                signed_measure = signed_measure.flatten()
            signed_measures.append(signed_measure)
        return signed_measures

    def transform(self, X):
        """Compute the signed measures of every point cloud, in parallel."""
        return Parallel(n_jobs=self.n_jobs)(
            delayed(self._transform1)(data)
            for data in tqdm(
                X,
                desc=f"Computing DegreeRips, of degrees {self.degrees}",
                disable=not self.progress,
            )
        )
1598
+
1599
+
1600
+ def tensor_möbius_inversion(
1601
+ tensor,
1602
+ grid_conversion: Iterable[np.ndarray] | None = None,
1603
+ plot: bool = False,
1604
+ raw: bool = False,
1605
+ num_parameters: int | None = None,
1606
+ ):
1607
+ from torch import Tensor
1608
+
1609
+ betti_sparse = Tensor(tensor.copy()).to_sparse() # Copy necessary in some cases :(
1610
+ num_indices, num_pts = betti_sparse.indices().shape
1611
+ num_parameters = num_indices if num_parameters is None else num_parameters
1612
+ if num_indices == num_parameters: # either hilbert or rank invariant
1613
+ rank_invariant = False
1614
+ elif 2 * num_parameters == num_indices:
1615
+ rank_invariant = True
1616
+ else:
1617
+ raise TypeError(
1618
+ f"Unsupported betti shape. {num_indices}\
1619
+ has to be either {num_parameters} or \
1620
+ {2*num_parameters}."
1621
+ )
1622
+ points_filtration = np.asarray(betti_sparse.indices().T, dtype=int)
1623
+ weights = np.asarray(betti_sparse.values(), dtype=int)
1624
+
1625
+ if grid_conversion is not None:
1626
+ coords = np.empty(shape=(num_pts, num_indices), dtype=float)
1627
+ for i in range(num_indices):
1628
+ coords[:, i] = grid_conversion[i % num_parameters][points_filtration[:, i]]
1629
+ else:
1630
+ coords = points_filtration
1631
+ if (not rank_invariant) and plot:
1632
+ plt.figure()
1633
+ color_weights = np.empty(weights.shape)
1634
+ color_weights[weights > 0] = np.log10(weights[weights > 0]) + 2
1635
+ color_weights[weights < 0] = -np.log10(-weights[weights < 0]) - 2
1636
+ plt.scatter(
1637
+ points_filtration[:, 0],
1638
+ points_filtration[:, 1],
1639
+ c=color_weights,
1640
+ cmap="coolwarm",
1641
+ )
1642
+ if (not rank_invariant) or raw:
1643
+ return coords, weights
1644
+
1645
+ def _is_trivial(rectangle: np.ndarray):
1646
+ birth = rectangle[:num_parameters]
1647
+ death = rectangle[num_parameters:]
1648
+ return np.all(birth <= death) # and not np.array_equal(birth,death)
1649
+
1650
+ correct_indices = np.array([_is_trivial(rectangle) for rectangle in coords])
1651
+ if len(correct_indices) == 0:
1652
+ return np.empty((0, num_indices)), np.empty((0))
1653
+ signed_measure = np.asarray(coords[correct_indices])
1654
+ weights = weights[correct_indices]
1655
+ if plot:
1656
+ # plot only the rank decompo for the moment
1657
+ assert signed_measure.shape[1] == 4
1658
+
1659
+ def _plot_rectangle(rectangle: np.ndarray, weight: float):
1660
+ x_axis = rectangle[[0, 2]]
1661
+ y_axis = rectangle[[1, 3]]
1662
+ color = "blue" if weight > 0 else "red"
1663
+ plt.plot(x_axis, y_axis, c=color)
1664
+
1665
+ for rectangle, weight in zip(signed_measure, weights):
1666
+ _plot_rectangle(rectangle=rectangle, weight=weight)
1667
+ return signed_measure, weights