multipers-1.0-cp311-cp311-manylinux_2_34_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic. See the registry's advisory page for more details.

Files changed (56)
  1. multipers/__init__.py +4 -0
  2. multipers/_old_rank_invariant.pyx +328 -0
  3. multipers/_signed_measure_meta.py +72 -0
  4. multipers/data/MOL2.py +350 -0
  5. multipers/data/UCR.py +18 -0
  6. multipers/data/__init__.py +1 -0
  7. multipers/data/graphs.py +272 -0
  8. multipers/data/immuno_regions.py +27 -0
  9. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  10. multipers/data/pytorch2simplextree.py +91 -0
  11. multipers/data/shape3d.py +101 -0
  12. multipers/data/synthetic.py +68 -0
  13. multipers/distances.py +100 -0
  14. multipers/euler_characteristic.cpython-311-x86_64-linux-gnu.so +0 -0
  15. multipers/euler_characteristic.pyx +132 -0
  16. multipers/function_rips.cpython-311-x86_64-linux-gnu.so +0 -0
  17. multipers/function_rips.pyx +101 -0
  18. multipers/hilbert_function.cpython-311-x86_64-linux-gnu.so +0 -0
  19. multipers/hilbert_function.pyi +46 -0
  20. multipers/hilbert_function.pyx +145 -0
  21. multipers/ml/__init__.py +0 -0
  22. multipers/ml/accuracies.py +61 -0
  23. multipers/ml/convolutions.py +384 -0
  24. multipers/ml/invariants_with_persistable.py +79 -0
  25. multipers/ml/kernels.py +128 -0
  26. multipers/ml/mma.py +422 -0
  27. multipers/ml/one.py +472 -0
  28. multipers/ml/point_clouds.py +191 -0
  29. multipers/ml/signed_betti.py +50 -0
  30. multipers/ml/signed_measures.py +1046 -0
  31. multipers/ml/sliced_wasserstein.py +313 -0
  32. multipers/ml/tools.py +99 -0
  33. multipers/multiparameter_edge_collapse.py +29 -0
  34. multipers/multiparameter_module_approximation.cpython-311-x86_64-linux-gnu.so +0 -0
  35. multipers/multiparameter_module_approximation.pxd +147 -0
  36. multipers/multiparameter_module_approximation.pyi +439 -0
  37. multipers/multiparameter_module_approximation.pyx +931 -0
  38. multipers/pickle.py +53 -0
  39. multipers/plots.py +207 -0
  40. multipers/point_measure_integration.cpython-311-x86_64-linux-gnu.so +0 -0
  41. multipers/point_measure_integration.pyx +59 -0
  42. multipers/rank_invariant.cpython-311-x86_64-linux-gnu.so +0 -0
  43. multipers/rank_invariant.pyx +154 -0
  44. multipers/simplex_tree_multi.cpython-311-x86_64-linux-gnu.so +0 -0
  45. multipers/simplex_tree_multi.pxd +121 -0
  46. multipers/simplex_tree_multi.pyi +715 -0
  47. multipers/simplex_tree_multi.pyx +1284 -0
  48. multipers/tensor.pxd +13 -0
  49. multipers/test.pyx +44 -0
  50. multipers-1.0.dist-info/LICENSE +21 -0
  51. multipers-1.0.dist-info/METADATA +9 -0
  52. multipers-1.0.dist-info/RECORD +56 -0
  53. multipers-1.0.dist-info/WHEEL +5 -0
  54. multipers-1.0.dist-info/top_level.txt +1 -0
  55. multipers.libs/libtbb-5d1cde94.so.12.10 +0 -0
  56. multipers.libs/libtbbmalloc-5e0a3d4c.so.2.10 +0 -0
multipers/ml/signed_measures.py (new file, +1046 -0)
@@ -0,0 +1,1046 @@
1
+
2
+ from typing import Iterable, Optional
3
+
4
+ from itertools import product
5
+ import matplotlib.pyplot as plt
6
+ from multipers.ml.convolutions import convolution_signed_measures
7
+ import numpy as np
8
+ from joblib import Parallel, delayed
9
+ from sklearn.base import BaseEstimator, TransformerMixin
10
+ from tqdm import tqdm
11
+
12
+ import multipers as mp
13
+ from multipers.simplex_tree_multi import SimplexTreeMulti
14
+
15
# Module-level shortcut to ``SimplexTreeMulti._reduce_grid``. The estimators
# below call it (with ``resolutions=`` and ``strategy=`` keyword arguments) to
# build a filtration grid from per-parameter filtration values.
reduce_grid = SimplexTreeMulti._reduce_grid
16
+
17
+
18
class SimplexTree2SignedMeasure(BaseEstimator, TransformerMixin):
    """
    Turn multiparameter simplex trees into signed measures (Euler
    characteristic, Hilbert function and/or rank invariant decompositions).

    Input
    -----
    Iterable[SimplexTreeMulti], or delayed objects that
    ``multipers.ml.tools.get_simplex_tree_from_delayed`` turns into one.

    Output
    ------
    Iterable[ list[signed_measure for degree] ]

    signed measure is either
     - (points : (n x num_parameters) array, weights : (n) int array ) if sparse,
     - else an integer matrix.

    Parameters
    ----------
     - degrees : list of homological degrees to compute; ``None`` stands for
       the Euler characteristic.
     - rank_degrees : homological degrees for which the rank invariant is
       computed; these measures are appended after the ones of ``degrees``.
     - filtration_grid : the grid on which to compute. If None, ``fit`` infers
       it from the data (unless every tree gets its own grid, see
       ``individual_grid``).
     - progress : show a tqdm progress bar in ``transform``.
     - num_collapses : either an int or "full". Edge collapses performed
       before the computation (2-parameter trees only).
     - n_jobs : number of joblib workers.
     - resolution : resolution of the inferred grid (int or one per
       parameter). Defaults to ``np.inf`` for the 'exact' strategy and to 100
       otherwise.
     - plot : plot the signed measures while computing them.
     - filtration_quantile : filtration values quantile to ignore.
     - _möbius_inversion : if False, do not perform the Möbius inversion; the
       output then has to be a matrix.
     - expand : expand the simplex tree (flag complexes) so that the requested
       degrees are computed correctly.
     - normalize_filtrations : divide the fitted reconversion grid by its
       standard deviation.
     - grid_strategy : 'regular', 'quantile', 'exact', ...
     - seed : seed of the subsampling used to infer the grid when
       ``fit_fraction`` < 1.
     - fit_fraction : fraction of the data used to infer the grid.
     - out_resolution : output resolution (defaults to ``resolution``).
     - individual_grid : compute one grid per simplex tree. Can be
       significantly faster for some grid strategies, but can drop
       statistical performance. Defaults to True for the 'regular_closest',
       'exact', 'quantile' and 'partition' strategies.
     - enforce_null_mass : returns a zero-mass measure by thresholding the
       module.
     - flatten : when the Möbius inversion is disabled, flatten the output.
    """

    def __init__(self,
        degrees:list[int|None]|None=[],  # homological degrees + None for euler
        rank_degrees:list[int]=[],  # same for rank invariant
        filtration_grid:Iterable[np.ndarray]|None=None,  # [filtration values of parameter i, for each parameter i]
        progress=False,  # tqdm
        num_collapses:int|str=0,  # edge collapses before computing
        n_jobs=None,
        resolution:Iterable[int]|int|None=None,  # resolution of the grid to infer when no grid is given
        plot:bool=False,
        filtration_quantile:float=0.,  # quantile for inferring filtration grid
        _möbius_inversion:bool=True,  # whether or not to do the möbius inversion (not recommended to touch)
        expand=True,  # expand the simplextree before computing the homology
        normalize_filtrations:bool=False,
        grid_strategy:str='exact',
        seed:int=0,  # seed of the subsampling when fit_fraction is not 1
        fit_fraction = 1,  # the fraction of the data on which to fit
        out_resolution:Iterable[int]|int|None=None,
        individual_grid:Optional[bool]=None,  # faster for some grid strategies, can drop statistical performance
        enforce_null_mass:bool=False,
        flatten=True,
        ):
        super().__init__()
        self.degrees = degrees
        self.rank_degrees = rank_degrees
        self.filtration_grid = filtration_grid
        self.progress = progress
        self.num_collapses = num_collapses
        self.n_jobs = n_jobs
        # Keep a user-provided resolution; only fill in a default when None.
        if resolution is None:
            resolution = np.inf if grid_strategy == 'exact' else 100
        self.resolution = resolution
        self.plot = plot
        self.filtration_quantile = filtration_quantile
        self.normalize_filtrations = normalize_filtrations
        self.grid_strategy = grid_strategy
        self.seed = seed
        self.fit_fraction = fit_fraction
        self.out_resolution = out_resolution
        self.individual_grid = individual_grid
        self.enforce_null_mass = enforce_null_mass
        self.flatten = flatten
        self.expand = expand
        self._möbius_inversion = _möbius_inversion
        # State populated by ``fit``.
        self.num_parameter = None
        self._is_input_delayed = None
        self._reconversion_grid = None
        self._refit_grid = None  # only refit the grid if filtration_grid was never given
        self._transform_st = None
        self._to_simplex_tree = None
        self._default_mass_location = None

    def _infer_filtration(self, X):
        """Infer a common filtration grid from a subsample of ``X``.

        ``min(int(fit_fraction * len(X)) + 1, len(X))`` inputs are drawn
        without replacement using ``seed``. Their exact filtration values are
        merged per parameter and reduced with ``reduce_grid``, using
        ``self.resolution`` and ``self.grid_strategy``. The grid is stored in
        ``self.filtration_grid`` and returned.
        """
        rng = np.random.default_rng(self.seed)
        num_samples = min(int(self.fit_fraction * len(X)) + 1, len(X))
        indices = rng.choice(len(X), num_samples, replace=False)

        def get_st_filtration(x):
            return self._to_simplex_tree(x).get_filtration_grid(grid_strategy="exact")

        filtrations = Parallel(n_jobs=self.n_jobs, backend="threading")(
            delayed(get_st_filtration)(X[idx]) for idx in indices
        )
        num_parameters = len(filtrations[0])
        filtrations_values = [
            np.unique(np.concatenate([x[i] for x in filtrations]))
            for i in range(num_parameters)
        ]
        filtration_grid = reduce_grid(filtrations_values, resolutions=self.resolution, strategy=self.grid_strategy)  # TODO : use more parameters
        self.filtration_grid = filtration_grid
        return filtration_grid

    def fit(self, X, y=None):
        """Prepare the conversion of ``X`` and, when needed, infer a common grid.

        Steps:
         - Detect whether the inputs are simplex trees or delayed objects.
         - Broadcast a scalar resolution to one value per parameter.
         - Decide whether each tree gets its own grid.
         - Infer the common grid if required.
         - Compute the reconversion grid and the default mass location.
        """
        assert self.resolution is not None or self.filtration_grid is not None or self.grid_strategy == "exact" or self.individual_grid, 'For non exact filtrations, a resolution has to be specified.'
        assert self._möbius_inversion or not self.individual_grid, "The grid has to be aligned when not using mobius inversion; disable individual_grid or enable mobius inversion."
        self._is_input_delayed = not isinstance(X[0], mp.SimplexTreeMulti)
        if self._is_input_delayed:
            from multipers.ml.tools import get_simplex_tree_from_delayed
            self._to_simplex_tree = get_simplex_tree_from_delayed
        else:
            self._to_simplex_tree = lambda x: x
        # Broadcast any scalar resolution (int, float, np.inf, numpy scalar).
        if self.resolution is not None and np.ndim(self.resolution) == 0:
            self.resolution = [self.resolution] * self._to_simplex_tree(X[0]).num_parameters
        self.num_parameter = len(self.filtration_grid) if self.resolution is None else len(self.resolution)

        if self.individual_grid is None:
            self.individual_grid = self.grid_strategy in ["regular_closest", "exact", "quantile", "partition"]

        # A common grid is needed unless one was given, or unless every tree
        # uses its own grid and no common null-mass location is required.
        if not self.enforce_null_mass and self.individual_grid or self.filtration_grid is not None:
            self._refit_grid = False
        else:
            self._refit_grid = True

        if self._refit_grid:
            self._infer_filtration(X=X)
        if self.out_resolution is None:
            self.out_resolution = self.resolution
        elif isinstance(self.out_resolution, int):
            self.out_resolution = [self.out_resolution] * len(self.resolution)

        if self.normalize_filtrations and not self.individual_grid:
            # Not the best normalization, but it also works for non-regular grids.
            self._reconversion_grid = [f / np.std(f) for f in self.filtration_grid]
        else:
            self._reconversion_grid = self.filtration_grid
        self._default_mass_location = [g[-1] for g in self._reconversion_grid] if self.enforce_null_mass else None
        return self

    def transform1(self, simplextree, filtration_grid=None, _reconversion_grid=None):
        """Compute the signed measures of a single simplex tree.

        Parameters
        ----------
        simplextree : a SimplexTreeMulti, or a delayed object (see ``fit``).
        filtration_grid : grid to squeeze the tree on; defaults to the fitted grid.
        _reconversion_grid : defaults to the fitted reconversion grid.
            NOTE(review): this value is not used by the computation below.

        Returns
        -------
        A list with one signed measure per entry of ``degrees`` (in order),
        followed by the rank-invariant measures. When the Möbius inversion is
        disabled and ``flatten`` is set, a flat array is returned instead.

        Raises
        ------
        Exception if ``num_collapses`` is neither "full" nor an int
        (2-parameter trees only).
        """
        if filtration_grid is None:
            filtration_grid = self.filtration_grid
        if _reconversion_grid is None:
            _reconversion_grid = self._reconversion_grid
        st = self._to_simplex_tree(simplextree)
        # Work on a copy: the tree is squeezed, collapsed, expanded and pruned below.
        st = mp.SimplexTreeMulti(st, num_parameters=st.num_parameters)
        if self.individual_grid:
            filtration_grid = st.get_filtration_grid(grid_strategy=self.grid_strategy, resolution=self.resolution)
            if self.enforce_null_mass:
                # Append the shared default mass location to this tree's grid.
                filtration_grid = [
                    np.concatenate([f, [d]], axis=0)
                    for f, d in zip(filtration_grid, self._default_mass_location)
                ]
        st.grid_squeeze(filtration_grid=filtration_grid, coordinate_values=True)
        if st.num_parameters == 2:
            if self.num_collapses == "full":
                st.collapse_edges(full=True, max_dimension=1)
            elif isinstance(self.num_collapses, int):
                st.collapse_edges(num=self.num_collapses, max_dimension=1)
            else:
                raise Exception("Bad edge collapse type. either 'full' or an int.")

        int_degrees = np.asarray([d for d in self.degrees if d is not None])
        with_euler = None in self.degrees
        if self._möbius_inversion:
            # Euler first, as the tree is pruned above some dimension below.
            if self.expand and with_euler:
                st.expansion(st.num_vertices)
            signed_measures_euler = mp.signed_measure(
                simplextree=st, degrees=[None],
                plot=self.plot,
                mass_default=self._default_mass_location,
                invariant="euler",
            )[0] if with_euler else []

            if self.expand and len(int_degrees) > 0:
                st.expansion(np.max(int_degrees) + 1)
            if len(int_degrees) > 0:
                # No need to compute homology beyond this dimension.
                st.prune_above_dimension(np.max(np.concatenate([int_degrees, self.rank_degrees])) + 1)
            signed_measures_pers = mp.signed_measure(
                simplextree=st, degrees=int_degrees,
                mass_default=self._default_mass_location,
                plot=self.plot,
                invariant="hilbert",
            ) if len(int_degrees) > 0 else []
            if self.plot:
                plt.show()

            if self.expand and len(self.rank_degrees) > 0:
                st.expansion(np.max(self.rank_degrees) + 1)
            if len(self.rank_degrees) > 0:
                st.prune_above_dimension(np.max(self.rank_degrees) + 1)
            signed_measures_rank = mp.signed_measure(
                simplextree=st, degrees=self.rank_degrees,
                mass_default=self._default_mass_location,
                plot=self.plot,
                invariant="rank",
            ) if len(self.rank_degrees) > 0 else []
            if self.plot:
                plt.show()
        else:
            from multipers.euler_characteristic import euler_surface
            from multipers.hilbert_function import hilbert_surface
            from multipers.rank_invariant import rank_invariant

            if self.expand and with_euler:
                st.expansion(st.num_vertices)
            signed_measures_euler = euler_surface(
                simplextree=st,
                plot=self.plot,
            )[1][None] if with_euler else []

            if self.expand and len(int_degrees) > 0:
                st.expansion(np.max(int_degrees) + 1)
            if len(int_degrees) > 0:
                # No need to compute homology beyond this dimension.
                st.prune_above_dimension(np.max(np.concatenate([int_degrees, self.rank_degrees])) + 1)
            signed_measures_pers = hilbert_surface(
                simplextree=st, degrees=int_degrees,
                plot=self.plot,
            )[1] if len(int_degrees) > 0 else []
            if self.plot:
                plt.show()

            if self.expand and len(self.rank_degrees) > 0:
                st.expansion(np.max(self.rank_degrees) + 1)
            if len(self.rank_degrees) > 0:
                st.prune_above_dimension(np.max(self.rank_degrees) + 1)
            signed_measures_rank = rank_invariant(
                simplextree=st, degrees=self.rank_degrees,
                plot=self.plot,
            ) if len(self.rank_degrees) > 0 else []

        # Re-interleave the Euler measure with the homological ones, in the
        # order given by ``degrees``, then append the rank measures.
        count = 0
        signed_measures = []
        for d in self.degrees:
            if d is None:
                signed_measures.append(signed_measures_euler)
            else:
                signed_measures.append(signed_measures_pers[count])
                count += 1
        signed_measures += signed_measures_rank
        if not self._möbius_inversion and self.flatten:
            signed_measures = np.asarray(signed_measures).flatten()
        return signed_measures

    def transform(self, X):
        """Apply ``transform1`` to every element of ``X`` with joblib."""
        assert self.filtration_grid is not None or self.individual_grid, "Fit first"
        # Process-based backend for delayed inputs, threads for simplex trees.
        prefer = "loky" if self._is_input_delayed else "threading"
        return Parallel(n_jobs=self.n_jobs, backend=prefer)(
            delayed(self.transform1)(to_st)
            for to_st in tqdm(X, disable=not self.progress, desc=f"Computing signed measure decompositions")
        )
265
+
266
+
267
+
268
+
269
+
270
class SimplexTrees2SignedMeasures(SimplexTree2SignedMeasure):
    """
    Same as ``SimplexTree2SignedMeasure``, with an extra "axis" level
    (e.g. one simplex tree per bandwidth for each data point).

    Input
    -----

    (data) x (axis, e.g. different bandwidths for simplextrees) x (simplextree)

    Output
    ------
    (data) x (axis) x (degree) x (signed measure)
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._num_st_per_data = None
        self._filtration_grids = None  # one fitted grid per axis

    def fit(self, X, y=None):
        """Fit one grid per axis.

        Raises
        ------
        Exception if ``X`` is not indexable as (num_data, num_axis).
        """
        if len(X) == 0:
            return self
        try:
            self._num_st_per_data = len(X[0])
        except (TypeError, KeyError, IndexError) as err:
            raise Exception("Shape has to be (num_data, num_axis), dtype=SimplexTreeMulti") from err
        # ``super().fit`` stores the grid it infers in ``self.filtration_grid``
        # and skips inference whenever that attribute is set. Restore the
        # user-supplied value (possibly None) before each axis, otherwise every
        # axis would reuse the grid inferred for the first one.
        initial_grid = self.filtration_grid
        self._filtration_grids = []
        for axis in range(self._num_st_per_data):
            self.filtration_grid = initial_grid
            super().fit([x[axis] for x in X])
            self._filtration_grids.append(self.filtration_grid)
        return self

    def transform(self, X):
        """Compute, for each data point, the signed measures of every axis."""
        if self.normalize_filtrations:
            _reconversion_grids = [
                None if F is None else [np.linspace(0, 1, num=len(f), dtype=float) for f in F]
                for F in self._filtration_grids
            ]
        else:
            _reconversion_grids = self._filtration_grids

        def todo(x):
            return [
                self.transform1(x[axis], filtration_grid=filtration_grid, _reconversion_grid=_reconversion_grid)
                for axis, filtration_grid, _reconversion_grid in zip(
                    range(self._num_st_per_data), self._filtration_grids, _reconversion_grids
                )
            ]

        return Parallel(n_jobs=self.n_jobs, backend="threading")(
            delayed(todo)(x)
            for x in tqdm(X, disable=not self.progress, desc="Computing Signed Measures from simplextrees.")
        )
318
+
319
+
320
def rescale_sparse_signed_measure(signed_measure, filtration_weights, normalize_scales=None):
    """Return a rescaled deep copy of a sparse signed measure.

    ``signed_measure`` is a sequence (one entry per degree) whose entries have
    the point array at index 0. Column ``p`` of each point array is multiplied
    by ``filtration_weights[p]``. When ``normalize_scales`` is given, it is also
    divided by ``normalize_scales[degree][p]``. The input is left untouched.
    """
    from copy import deepcopy

    rescaled = deepcopy(signed_measure)
    num_parameters = len(filtration_weights)
    for degree_index, measure in enumerate(rescaled):
        points = measure[0]
        for parameter_index in range(num_parameters):
            factor = filtration_weights[parameter_index]
            if normalize_scales is not None:
                factor = factor / normalize_scales[degree_index][parameter_index]
            points[:, parameter_index] *= factor  # TODO broadcast w.r.t. the parameter
    return rescaled
333
+
334
+
335
+
336
class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
    """
    Rescale, normalize and reformat signed measures.

    Input
    -----

    (data) x (degree) x (signed measure) or (data) x (axis) x (degree) x (signed measure)

    Iterable[list[signed_measure_matrix of degree]] or Iterable[previous].

    The second is meant to use multiple choices for signed measure input. An
    example of usage: they come from a Rips + Density with different
    bandwidths. It is controlled by the ``axis`` parameter: ``None`` (no axis
    level), an int (keep only that axis) or ``-1`` (keep every axis).

    Parameters
    ----------
     - filtrations_weights : one multiplicative weight per parameter, applied
       to the points of sparse measures (defaults to all ones).
     - normalize : divide each parameter by its range (sparse input only).
     - num_parameters : overwritten by ``fit``.
     - plot : plot the sparse signed measures in ``transform``.
     - unsparse : turn sparse measures into histograms on an inferred grid.
     - axis : see above.
     - resolution : resolution of the inferred grid (int or one per parameter).
     - flatten : flatten the outputs.
     - deep_format : output arrays whose rows are a point followed by its
       weight, grouped per (degree, axis).
     - unrag : with ``deep_format``, zero-pad into a single dense tensor.
     - n_jobs : threads used for the rescaling step.
     - verbose : print statistics in ``fit``.
     - integrate : integrate the sparse measures on the inferred grid.
     - grid_strategy : strategy used to infer that grid.

    Output
    ------

    Iterable[list[(reweighted)_sparse_signed_measure of degree]]
    """

    def __init__(self,
        filtrations_weights:Iterable[float]=None,
        normalize=False,
        num_parameters:int|None=None,
        plot:bool=False,
        unsparse:bool=False,
        axis:int=None,
        resolution:int|Iterable[int]=50,
        flatten:bool=False,
        deep_format:bool=False,
        unrag:bool=True,
        n_jobs:int=1,
        verbose:bool=False,
        integrate:bool=False,
        grid_strategy='regular',
        ):
        super().__init__()
        self.filtrations_weights = filtrations_weights
        self.num_parameters = num_parameters
        self.plot = plot
        self.unsparse = unsparse
        self.n_jobs = n_jobs
        self.axis = axis
        self._num_axis = None
        self._is_input_sparse = None
        self.resolution = resolution
        self._filtrations_bounds = None
        self.flatten = flatten
        self.normalize = normalize
        self._normalization_factors = None
        self.deep_format = deep_format
        self.unrag = unrag
        assert not self.deep_format or not self.unsparse
        assert not normalize or (not unsparse and not deep_format and not integrate)
        self.verbose = verbose
        self._num_degrees = None
        self.integrate = integrate
        self.grid_strategy = grid_strategy
        self._infered_grids = None

    def _get_filtration_bounds(self, X, axis):
        """Per-degree [min, max] of the points of ``X`` at ``axis``.

        Returns ``(bounds, factors)``. ``bounds`` has shape
        (degree, 2, num_parameters). ``factors`` holds ``max - min`` (with
        zeros replaced by 1) when ``normalize`` is set, else None.
        """
        num_degrees = len(X[0][axis])
        stuff = [
            np.concatenate([sm[axis][degree][0] for sm in X], axis=0)
            for degree in range(num_degrees)
        ]
        sizes_ = np.array([len(x) == 0 for x in stuff])
        assert np.all(1 - sizes_), f"Degree axis {np.where(sizes_)} is/are trivial !"
        filtrations_bounds = np.asarray([[f.min(axis=0), f.max(axis=0)] for f in stuff])
        normalization_factors = filtrations_bounds[:, 1] - filtrations_bounds[:, 0] if self.normalize else None
        if normalization_factors is not None and np.any(normalization_factors == 0):
            # Constant filtration for some (degree, parameter): do not rescale it.
            normalization_factors[normalization_factors == 0] = 1
        return filtrations_bounds, normalization_factors

    def _plot_signed_measures(self, sms:Iterable[np.ndarray], size=4):
        """Plot a (data) x (degree) grid of signed measures."""
        from multipers.plots import plot_signed_measure
        num_degrees = len(sms[0])
        num_imgs = len(sms)
        fig, axes = plt.subplots(ncols=num_degrees, nrows=num_imgs, figsize=(size * num_degrees, size * num_imgs))
        axes = np.asarray(axes).reshape(num_imgs, num_degrees)
        for i, sm in enumerate(sms):
            for j, sm_of_degree in enumerate(sm):
                plot_signed_measure(sm_of_degree, ax=axes[i, j])

    def fit(self, X, y=None):
        """Infer the input layout and fit the normalization factors or grids.

        Sets:
         - the input sparsity, number of parameters, degrees and axes;
         - default weights;
         - when ``normalize``, the per-degree bounds and normalization factors;
         - when ``integrate``/``unsparse``, one grid per axis.
        """
        assert not self.normalize or (not self.unsparse and not self.deep_format and not self.integrate)
        if len(X) == 0 or len(X[0]) == 0 or (self.axis is not None and len(X[0][0][0]) == 0):
            return self

        self._is_input_sparse = (isinstance(X[0][0], tuple) and self.axis is None) or (isinstance(X[0][0][0], tuple) and self.axis is not None)
        if self.axis is None:
            # (data) x (degree) x (signed measure)
            self.num_parameters = X[0][0][0].shape[1] if self._is_input_sparse else X[0][0].ndim
            self._num_degrees = len(X[0])
        else:
            # (data) x (axis) x (degree) x (signed measure)
            self.num_parameters = X[0][0][0][0].shape[1] if self._is_input_sparse else X[0][0][0].ndim
            self._num_degrees = len(X[0][0])
        if self.filtrations_weights is None:
            self.filtrations_weights = np.array([1] * self.num_parameters)
        assert self._is_input_sparse or (not self.deep_format)

        # Broadcast a scalar resolution to every parameter.
        try:
            float(self.resolution)
        except (TypeError, ValueError):
            pass
        else:
            self.resolution = [self.resolution] * self.num_parameters
        assert len(self.filtrations_weights) == self.num_parameters == len(self.resolution), f"Number of parameter is not consistent. Inferred : {self.num_parameters}, Filtration weigths : {len(self.filtrations_weights)}, Resolutions : {len(self.resolution)}."
        # Matrices cannot be rescaled.
        assert np.all(1 == np.asarray(self.filtrations_weights)) or self._is_input_sparse, f"Use sparse signed measure to rescale. Recieved weights {self.filtrations_weights}"
        self._num_axis = None if self.axis is None else len(X[0])

        if self._is_input_sparse and self.normalize:
            axis = slice(None) if self.axis is None else self.axis
            if axis == -1:
                self._filtrations_bounds = []
                self._normalization_factors = []
                for ax in range(self._num_axis):
                    filtration_bounds, normalization_factors = self._get_filtration_bounds(X, axis=ax)
                    self._filtrations_bounds.append(filtration_bounds)
                    self._normalization_factors.append(normalization_factors)
            else:
                self._filtrations_bounds, self._normalization_factors = self._get_filtration_bounds(X, axis=axis)
        elif self._is_input_sparse and (self.integrate or self.unsparse):
            axis = [slice(None)] if self.axis is None else range(self._num_axis) if self.axis == -1 else [self.axis]
            # One point cloud per axis, gathering every sample and degree.
            filtration_values = [
                np.concatenate([x[ax][degree][0] for x in X for degree in range(self._num_degrees)])
                for ax in axis
            ]
            self._infered_grids = [
                reduce_grid(f_ax.T, resolutions=self.resolution, strategy=self.grid_strategy)
                for f_ax in filtration_values
            ]

        if self.verbose:
            print("------------SignedMeasureFormatter------------")
            print("---- Parameters")
            print(f"Sparse input : {self._is_input_sparse}")
            print(f"Number of axis : {self._num_axis}")
            print(f"Number of degrees : {self._num_degrees}")
            print(f"Filtration bounds : \n{self._filtrations_bounds}")
            print(f"Normalization factor : \n{self._normalization_factors}")
            if self._infered_grids is not None:
                print(f"Filtration grid shape : {tuple(tuple(len(f) for f in F) for F in self._infered_grids)}")
            print("---- SM stats")
            print(f"In axis : {1 if self.axis is None else len(X[0])}")
            if self.axis == -1:
                axis = range(len(X[0]))
            else:
                axis = [slice(None)] if self.axis is None else [self.axis]
            sizes = [[[len(xd[1]) for xd in x[ax]] for x in X] for ax in axis]
            print(f"Size means (axis) x (degree): {np.mean(sizes, axis=(1))}")
            print(f"Size std : {np.std(sizes, axis=(1))}")
            print("----------------------------------------------")
        return self

    def unsparse_signed_measure(self, sparse_signed_measure):
        """Turn one sample's sparse measures into histograms on the inferred grid(s).

        Each (points, weights) pair becomes a weighted ``np.histogramdd`` over
        the grid fitted for its axis. With ``flatten``, a single 1-D vector is
        returned; otherwise an array per axis (axis == -1) or for the only axis.
        """
        filtrations = self._infered_grids  # one grid per axis
        axes = range(self._num_axis) if self.axis == -1 else [slice(None)]
        out = []
        for filtrations_of_ax, ax in zip(filtrations, axes):
            measure_of_ax = []
            for pts, weights in sparse_signed_measure[ax]:  # over degrees
                signed_measure, _ = np.histogramdd(pts, bins=filtrations_of_ax, weights=weights)
                if self.flatten:
                    signed_measure = signed_measure.flatten()
                measure_of_ax.append(signed_measure)
            out.append(np.asarray(measure_of_ax))

        if self.flatten:
            # Return the whole flat vector; indexing it by axis would only
            # keep its first scalar.
            return np.concatenate(out).flatten()
        if self.axis == -1:
            return np.asarray(out)
        return np.asarray(out)[0]

    @staticmethod
    def deep_format_measure(signed_measure):
        """Stack (points, weights) into a float32 array of rows [point..., weight]."""
        from numpy import empty, float32
        dirac_positions, dirac_signs = signed_measure
        new_shape = list(dirac_positions.shape)
        new_shape[1] += 1
        c = empty(new_shape, dtype=float32)
        c[:, :-1] = dirac_positions
        c[:, -1] = dirac_signs
        return c

    @staticmethod
    def _integrate_measure(sm, filtrations):
        """Integrate the sparse measure ``sm = (points, weights)`` on ``filtrations``."""
        from multipers.point_measure_integration import integrate_measure
        return integrate_measure(sm[0], sm[1], filtrations)

    def transform(self, X):
        """Select the axis, rescale, then optionally plot/integrate/unsparse/deep-format."""
        def rescale_from_not_sparse(signed_measure:Iterable[np.ndarray]):
            if not self.flatten:
                return signed_measure
            return np.concatenate([np.asarray(sm).ravel() for sm in signed_measure])

        def rescale_from_sparse(sparse_signed_measure):
            if self.axis == -1:
                scales = self._normalization_factors
                if scales is None:
                    scales = [None] * len(sparse_signed_measure)
                return [
                    rescale_sparse_signed_measure(sparse_signed_measure[ax], filtration_weights=self.filtrations_weights, normalize_scales=n)
                    for ax, n in enumerate(scales)
                ]
            return rescale_sparse_signed_measure(sparse_signed_measure, filtration_weights=self.filtrations_weights, normalize_scales=self._normalization_factors)

        if self._is_input_sparse:
            todo_rescale = rescale_from_sparse
            # Sparse measures are rescaled when normalizing or when some
            # filtration weight differs from 1.
            needs_rescale = self._normalization_factors is not None or not np.all(np.asarray(self.filtrations_weights) == 1)
        else:
            todo_rescale = rescale_from_not_sparse
            # Matrices cannot be rescaled; the only transformation is flattening.
            needs_rescale = self.flatten

        if self.axis is None or self.axis == -1:
            out = X
        else:
            out = tuple(x[self.axis] for x in X)
        if needs_rescale:
            if self.n_jobs > 1:
                out = Parallel(n_jobs=self.n_jobs, backend="threading")(delayed(todo_rescale)(x) for x in out)
            else:
                out = tuple(todo_rescale(x) for x in out)

        if self._is_input_sparse:
            if self.plot:
                self._plot_signed_measures(out)
            if self.integrate:
                filtrations = self._infered_grids
                # NOTE(review): only the first inferred grid is used; axis == -1 is not supported here.
                ax = 0
                out = np.asarray([
                    [self._integrate_measure(x[degree], filtrations=filtrations[ax]) for degree in range(self._num_degrees)]
                    for x in out
                ])
                if self.flatten:
                    out = out.reshape((len(X), -1))
            elif self.unsparse:
                out = [self.unsparse_signed_measure(x) for x in out]
            elif self.deep_format:
                if self.axis == -1:
                    num_degrees = len(out[0][0])
                    axes = range(self._num_axis)
                else:
                    # axis is None, or a fixed axis already selected above:
                    # each sample is (degree) x (signed measure).
                    num_degrees = len(out[0])
                    axes = [slice(None)]
                out = [
                    [self.deep_format_measure(sm[axis][degree]) for sm in out]
                    for degree in range(num_degrees)
                    for axis in axes
                ]
                if self.unrag:
                    # Zero-pad every measure to the largest number of points.
                    max_num_pts = np.max([sm.shape[0] for sm_of_axis in out for sm in sm_of_axis])
                    num_axis = len(out)
                    num_data = len(out[0])
                    num_columns = out[0][0].shape[1]  # num_parameters + 1 (weight column)
                    unragged_tensor = np.zeros(shape=(num_axis, num_data, max_num_pts, num_columns), dtype=np.float32)
                    for ax, measures in enumerate(out):
                        for data, sm in enumerate(measures):
                            a, b = sm.shape
                            unragged_tensor[ax, data, :a, :b] = sm
                    out = unragged_tensor
        return out
610
+
611
+
612
+
613
+
614
+
615
+
616
+
617
class SignedMeasure2Convolution(BaseEstimator,TransformerMixin):
    """
    Discrete convolution of a signed measure

    Input
    -----

    (data) x (degree) x (signed measure)

    Parameters
    ----------
    - filtration_grid : Iterable[array] For each filtration, the filtration values on which to evaluate the grid
    - kernel : kernel used to convolve sparse signed measures.
    - bandwidth : float or (num_parameters) floats. A negative value is relative:
      to the image size along each axis (dense input), or to the grid diameter (sparse input).
    - flatten : flatten the images if True
    - n_jobs : number of parallel jobs.
    - resolution : int or (num_parameter) : If filtration grid is not given, will infer a grid, with this resolution
    - grid_strategy : the strategy to generate the grid. Available ones are regular, quantile, exact
    - progress : progress bar if True
    - backend : sklearn, pykeops or numba.
    - plot : Creates a plot Figure.

    Output
    ------

    (data) x (concatenation of imgs of degree)
    """
    def __init__(self,
            filtration_grid:Iterable[np.ndarray]=None,
            kernel="gaussian",
            bandwidth:float|Iterable[float]=1.,
            flatten:bool=False, n_jobs:int=1,
            resolution:int|None=None,
            grid_strategy:str="regular",
            progress:bool=False,
            backend:str='pykeops',
            plot:bool=False,
            # **kwargs ## DANGEROUS
            ):
        super().__init__()
        self.kernel=kernel
        self.bandwidth=bandwidth
        self.filtration_grid=filtration_grid
        self.flatten=flatten
        self.progress=progress
        self.n_jobs = n_jobs
        self.resolution = resolution
        self.grid_strategy = grid_strategy
        self._is_input_sparse = None  # set by fit
        self._refit = filtration_grid is None  # only infer a grid when none was provided
        self._input_resolution=None
        self._bandwidths=None
        self.diameter=None
        self.backend=backend
        self.plot=plot

    def fit(self, X, y=None):
        """
        Infers whether the input is sparse and prepares the convolution parameters.

        - Dense input (arrays): resolves one gaussian bandwidth per image axis.
        - Sparse input ((points, weights) tuples): infers a filtration grid from the
          points when none was given, then computes the grid diameter.

        Raises
        ------
        Exception if the input is sparse and neither a filtration grid nor a resolution is set.
        """
        if len(X) == 0: return self
        # Sparse signed measures are (points, weights) tuples; dense ones are arrays.
        self._is_input_sparse = isinstance(X[0][0], tuple)
        if not self._is_input_sparse:
            # Dense signed measures are already images on a grid: only the
            # bandwidths have to be resolved (negative => relative to the axis size).
            self._input_resolution = X[0][0].shape
            try:
                b = float(self.bandwidth)
            except (TypeError, ValueError):  # one bandwidth per axis
                self._bandwidths = [b if b > 0 else -b * s for s,b in zip(self._input_resolution, self.bandwidth)]
            else:
                self._bandwidths = [b if b > 0 else -b * s for s in self._input_resolution]
            return self

        ## Sparse input: a grid has to be defined
        if self.filtration_grid is None and self.resolution is None:
            raise Exception("Cannot infer filtration grid. Provide either a filtration grid or a resolution.")
        if self._refit:
            pts = np.concatenate([
                sm[0] for signed_measures in X for sm in signed_measures
            ]).T
            self.filtration_grid = reduce_grid(filtrations_values=pts, strategy=self.grid_strategy, resolutions=self.resolution)
        if self.filtration_grid is not None:
            self.diameter=np.linalg.norm([f.max() - f.min() for f in self.filtration_grid])
            if self.progress: print(f"Computed a diameter of {self.diameter}")
        return self

    def _sparsify(self,sm):
        """Converts a dense signed measure tensor into (points, weights) on ``self.filtration_grid``."""
        # ``tensor_möbius_inversion`` takes the dense tensor as ``tensor=`` (there is no ``input`` parameter).
        return tensor_möbius_inversion(tensor=sm,grid_conversion=self.filtration_grid)

    def _sm2smi(self, signed_measures:Iterable[np.ndarray]):
        """Gaussian-filters each dense signed measure of one data point and concatenates the results along axis 0."""
        from scipy.ndimage import gaussian_filter
        return np.concatenate([
            gaussian_filter(input=signed_measure, sigma=self._bandwidths,mode="constant", cval=0)
            for signed_measure in signed_measures], axis=0)

    def _transform_from_sparse(self,X):
        """Convolves sparse signed measures on ``self.filtration_grid`` via ``convolution_signed_measures``."""
        # Negative bandwidths are relative to the grid diameter computed in fit.
        bandwidth = self.bandwidth if self.bandwidth > 0 else -self.bandwidth * self.diameter
        ## COMPILE KEOPS FIRST: a single-job run on a tiny grid (2 values per axis),
        ## presumably so compilation happens once before the parallel run below.
        dummyx = [X[0]]
        dummyf = [f[:2] for f in self.filtration_grid]
        convolution_signed_measures(dummyx, filtrations=dummyf, bandwidth=bandwidth, flatten=self.flatten, n_jobs=1, kernel=self.kernel, backend=self.backend)

        return convolution_signed_measures(X, filtrations=self.filtration_grid, bandwidth=bandwidth, flatten=self.flatten, n_jobs=self.n_jobs, kernel=self.kernel, backend=self.backend)

    def _plot_imgs(self, imgs:Iterable[np.ndarray], size=4):
        """Plots one surface per (data point, degree) on a num_imgs x num_degrees subplot grid."""
        from multipers.plots import plot_surface
        num_degrees = imgs[0].shape[0]
        num_imgs = len(imgs)
        fig, axes = plt.subplots(ncols=num_degrees, nrows=num_imgs, figsize=(size*num_degrees,size*num_imgs))
        # subplots squeezes singleton dimensions; restore a 2D axes array.
        axes = np.asarray(axes).reshape(num_imgs,num_degrees)
        for i, img in enumerate(imgs):
            for j, img_of_degree in enumerate(img):
                plot_surface(self.filtration_grid, img_of_degree, ax=axes[i,j], cmap='Spectral')

    def transform(self,X):
        """
        Convolves the signed measures of every data point.

        Raises
        ------
        Exception if ``fit`` was not called (or was called on empty data).
        """
        if self._is_input_sparse is None: raise Exception("Fit first")
        if self._is_input_sparse:
            out = self._transform_from_sparse(X)
        else:
            todo = SignedMeasure2Convolution._sm2smi
            out = Parallel(n_jobs=self.n_jobs, backend="threading")(delayed(todo)(self, signed_measures) for signed_measures in tqdm(X, desc="Computing images", disable = not self.progress))
        if self.plot and not self.flatten:
            if self.progress: print("Plotting convolutions...", end="")
            self._plot_imgs(out)
            if self.progress: print("Done !")
        # Sparse outputs are flattened by convolution_signed_measures itself.
        if self.flatten and not self._is_input_sparse: out = [x.flatten() for x in out]
        return np.asarray(out)
755
+
756
+
757
+
758
class SignedMeasure2SlicedWassersteinDistance(BaseEstimator,TransformerMixin):
    """
    Transformer from signed measures to (sliced) Wasserstein distance matrices, one per degree.

    Input
    -----

    (data) x (degree) x (signed measure)

    Format
    ------
    - a signed measure : tuple of arrays, point positions of shape (npts, num_parameters) and weights of shape (npts,)
    - each data is a list of signed measures (e.g. one per homological degree)

    Output
    ------
    - (degree) x (distance matrix)
    """
    def __init__(self, n_jobs=None, num_directions:int=10, _sliced:bool=True, epsilon=-1, ground_norm=1, progress = False, grid_reconversion=None, scales=None):
        super().__init__()
        self.n_jobs = n_jobs
        self._SWD_list = None  # one fitted distance per degree, set by fit
        self._sliced = _sliced
        self.epsilon = epsilon
        self.ground_norm = ground_norm
        self.num_directions = num_directions
        self.progress = progress
        self.grid_reconversion = grid_reconversion
        self.scales = scales

    def fit(self, X, y=None):
        """Fits one distance estimator per degree on the signed measures of that degree."""
        from multipers.ml.sliced_wasserstein import (SlicedWassersteinDistance,
                                                     WassersteinDistance)
        if len(X) == 0:
            return self

        def _new_distance():
            if self._sliced:
                return SlicedWassersteinDistance(num_directions=self.num_directions, n_jobs=self.n_jobs, scales=self.scales)
            # WARNING: the non-sliced Wasserstein distance is not CNSD.
            return WassersteinDistance(epsilon=self.epsilon, ground_norm=self.ground_norm, n_jobs=self.n_jobs)

        self._SWD_list = [_new_distance() for _ in range(len(X[0]))]
        for degree, distance in enumerate(self._SWD_list):
            distance.fit([x[degree] for x in X])
        return self

    def transform(self,X):
        """Returns an array of distance matrices, one per degree, computed in threads."""
        assert self._SWD_list is not None, "Fit first"

        def _distances_of_degree(distance, measures_of_degree):
            return distance.transform(measures_of_degree)

        progress_bar = tqdm(enumerate(self._SWD_list), desc="Computing distance matrices", total=len(self._SWD_list), disable= not self.progress)
        with progress_bar as indexed_distances:
            jobs = (
                delayed(_distances_of_degree)(distance, [x[degree] for x in X])
                for degree, distance in indexed_distances
            )
            matrices = Parallel(n_jobs=self.n_jobs, prefer="threads")(jobs)
        return np.asarray(matrices)

    def predict(self, X):
        """Alias of ``transform``."""
        return self.transform(X)
818
+
819
+
820
class SignedMeasures2SlicedWassersteinDistances(BaseEstimator,TransformerMixin):
    """
    Transformer from signed measures to distance matrices, for every axis and every set of scales.

    Input
    -----
    (data) x opt (axis) x (degree) x (signed measure)

    Format
    ------
    - a signed measure : tuple of arrays, point positions of shape (npts, num_parameters) and weights of shape (npts,)
    - each data is a list of signed measures (e.g. one per homological degree)

    Output
    ------
    - (axis) x (degree) x (distance matrix)
    """
    def __init__(self, progress=False, n_jobs:int=1, scales:Iterable[Iterable[float]]|None = None, **kwargs): # same init
        # Template child, cloned once per (axis, scales) pair in fit.
        self._init_child = SignedMeasure2SlicedWassersteinDistance(progress=False, scales=None,n_jobs=-1, **kwargs)
        self._axe_iterator=None
        self._childs_to_fit=None
        self._scales=None  # scales resolved by fit; ``self.scales`` is left untouched
        self.scales = scales
        self.progress = progress
        self.n_jobs=n_jobs
        return

    def fit(self, X, y=None):
        """Fits one child per (axis, scales) pair."""
        from sklearn.base import clone
        if len(X) == 0: return self
        if isinstance(X[0][0],tuple): # Meaning that there are no axes
            self._axe_iterator = [slice(None)]
        else:
            self._axe_iterator = range(len(X[0]))
        # Resolve into a local value instead of overwriting the ``scales``
        # hyper-parameter: otherwise a second fit (or a clone of a fitted
        # estimator) would see ``[None]``/arrays instead of the user's value.
        if self.scales is None:
            scales = [None]
        else:
            scales = np.asarray(self.scales)
            if scales.ndim == 1:
                scales = np.asarray([scales])
        assert scales[0] is None or scales.ndim==2, "Scales have to be either None or a list of scales !"
        self._scales = scales
        self._childs_to_fit = [
            clone(self._init_child).set_params(scales=axis_scales).fit(
                [x[axis] for x in X])
            for axis, axis_scales in product(self._axe_iterator, self._scales)
        ]
        if self.progress: print("New axes : ", list(product(self._axe_iterator, self._scales)))
        return self

    def transform(self,X):
        """Returns the list of child outputs, in the (axis, scales) order used by fit."""
        return Parallel(n_jobs=self.n_jobs, prefer="processes")(
            delayed(self._childs_to_fit[child_id].transform)([x[axis] for x in X])
            for child_id, (axis, _) in tqdm(enumerate(product(self._axe_iterator, self._scales)),
                desc="Computing distances matrices of axis, and scales", disable=not self.progress, total=len(self._childs_to_fit)
            )
        )
879
+
880
+
881
class SimplexTree2RectangleDecomposition(BaseEstimator,TransformerMixin):
    """
    Transformer mapping 2-parameter SimplexTreeMulti objects to their rectangle decompositions.

    For every input simplextree, the output holds one ``_st2ranktensor`` result per
    requested degree (in the order of ``degrees``).
    """
    def __init__(self, filtration_grid:np.ndarray, degrees:Iterable[int], plot=False, reconvert_grid=True, num_collapses:int=0):
        super().__init__()
        self.filtration_grid = filtration_grid
        self.degrees = degrees
        self.plot = plot
        self.reconvert_grid = reconvert_grid
        self.num_collapses = num_collapses

    def fit(self, X, y=None):
        """
        No-op; the filtration grid is given at construction.

        TODO : infer grid from multiple simplextrees
        """
        return self

    def transform(self,X:Iterable[mp.SimplexTreeMulti]):
        """Returns a list (one entry per simplextree) of lists (one entry per degree)."""
        decompositions = []
        for st in X:
            per_degree = []
            for d in self.degrees:
                per_degree.append(_st2ranktensor(
                    st, filtration_grid=self.filtration_grid,
                    degree=d,
                    plot=self.plot,
                    reconvert_grid=self.reconvert_grid,
                    num_collapse=self.num_collapses,
                ))
            decompositions.append(per_degree)
        ## TODO : return iterator ?
        return decompositions
911
+
912
+
913
+
914
def _st2ranktensor(st:mp.SimplexTreeMulti, filtration_grid:np.ndarray, degree:int, plot:bool, reconvert_grid:bool, num_collapse:int|str=0):
    """
    Computes the rectangle decomposition of the rank invariant of ``st`` in one degree.

    Steps: squeeze a copy of ``st`` onto ``filtration_grid`` (coordinate values), collapse
    edges, compute ``mp.rank_invariant2d``, decompose it with
    ``rank_decomposition_by_rectangles`` and convert the result with ``tensor_möbius_inversion``.

    Parameters
    ----------
    - st : the simplextree (left unmodified).
    - filtration_grid : per-parameter filtration values.
    - degree : homological degree.
    - plot : forwarded to ``tensor_möbius_inversion``.
    - reconvert_grid : if True, output coordinates are filtration values instead of grid indices.
    - num_collapse : "full", or an integer number of edge collapses.

    Raises
    ------
    TypeError if ``num_collapse`` is neither "full" nor an int.
    """
    # Work on a copy: grid_squeeze changes the filtration values.
    squeezed = mp.SimplexTreeMulti(st)
    squeezed.grid_squeeze(filtration_grid=filtration_grid, coordinate_values=True)
    max_dim = degree + 1
    if num_collapse == "full":
        squeezed.collapse_edges(full=True, ignore_warning=True, max_dimension=max_dim)
    elif isinstance(num_collapse, int):
        squeezed.collapse_edges(num=num_collapse, ignore_warning=True, max_dimension=max_dim)
    else:
        raise TypeError(f"Invalid num_collapse={num_collapse} type. Either full, or an integer.")
    grid_shape = [len(f) for f in filtration_grid]
    rank_tensor = mp.rank_invariant2d(squeezed, degree=degree, grid_shape=grid_shape)
    rank_decomposition = rank_decomposition_by_rectangles(rank_tensor, threshold=True)
    return tensor_möbius_inversion(
        tensor=rank_decomposition,
        grid_conversion=filtration_grid if reconvert_grid else None,
        plot=plot,
        num_parameters=st.num_parameters,
    )
940
+
941
class DegreeRips2SignedMeasure(BaseEstimator, TransformerMixin):
    """
    Transformer from point clouds to signed measures of their degree-Rips bifiltration.

    Input
    -----
    (data) x (point cloud) ; each point cloud is passed to ``distance_matrix(x, x)``.

    Parameters
    ----------
    - degrees : homological degrees to compute.
    - min_rips_value, max_rips_value : range of the Rips parameter. A negative
      ``max_rips_value`` is relative to the largest diameter found on a random
      ``fit_fraction`` of the training point clouds.
    - min_normalized_degree, max_normalized_degree : range of the normalized degree parameter.
    - grid_granularity : forwarded to ``hf_degree_rips``.
    - progress : if True, prints scale estimation and shows a progress bar.
    - n_jobs : number of parallel jobs in ``transform``.
    - sparse : if True, each signed measure is returned as a (points, weights) tuple.
    - _möbius_inversion : if True, applies ``signed_betti`` to the Hilbert function;
      otherwise the (flattened, when dense) Hilbert function is returned.
    - fit_fraction : fraction of the data used to estimate the diameter.

    Output
    ------
    (data) x (degree) x (signed measure)
    """
    def __init__(self, degrees:Iterable[int], min_rips_value:float,
            max_rips_value,max_normalized_degree:float, min_normalized_degree:float,
            grid_granularity:int, progress:bool=False, n_jobs=1, sparse:bool=False,
            _möbius_inversion=True,
            fit_fraction=1,
            ) -> None:
        super().__init__()
        self.min_rips_value = min_rips_value
        self.max_rips_value = max_rips_value
        self.min_normalized_degree = min_normalized_degree
        self.max_normalized_degree = max_normalized_degree
        self._max_rips_value = None  # resolved (absolute) max rips value, set by fit
        self.grid_granularity = grid_granularity
        self.progress=progress
        self.n_jobs = n_jobs
        self.degrees = degrees
        self.sparse=sparse
        self._möbius_inversion = _möbius_inversion
        self.fit_fraction=fit_fraction
        return

    def fit(self, X:np.ndarray|list, y=None):
        """Resolves ``max_rips_value``; a negative value is scaled by an estimated diameter."""
        if self.max_rips_value < 0:
            if self.progress: print("Estimating scale...", flush=True, end="")
            num_samples = min(len(X), int(self.fit_fraction*len(X))+1)
            indices = np.random.choice(len(X), num_samples, replace=False)
            diameters = np.max([distance_matrix(x,x).max() for x in (X[i] for i in indices)])
            if self.progress: print(f"Done. {diameters}", flush=True)
        self._max_rips_value = - self.max_rips_value * diameters if self.max_rips_value < 0 else self.max_rips_value
        return self

    def _transform1(self, data:np.ndarray):
        """Computes the signed measures (one per degree) of a single point cloud."""
        _distance_matrix = distance_matrix(data, data)
        signed_measures = []
        rips_values, normalized_degree_values, hilbert_functions, minimal_presentations = hf_degree_rips(
            _distance_matrix,
            min_rips_value = self.min_rips_value,
            max_rips_value = self._max_rips_value,
            min_normalized_degree = self.min_normalized_degree,
            max_normalized_degree = self.max_normalized_degree,
            grid_granularity = self.grid_granularity,
            max_homological_dimension = np.max(self.degrees),
        )
        for degree in self.degrees:
            hilbert_function = hilbert_functions[degree]
            signed_measure = signed_betti(hilbert_function, threshold=True) if self._möbius_inversion else hilbert_function
            if self.sparse:
                signed_measure = tensor_möbius_inversion(
                    tensor=signed_measure,num_parameters=2,
                    grid_conversion=[rips_values, normalized_degree_values]
                )
            elif not self._möbius_inversion:
                # Only dense arrays can be flattened; a sparse (points, weights)
                # tuple has no ``flatten`` and used to crash here.
                signed_measure = signed_measure.flatten()
            signed_measures.append(signed_measure)
        return signed_measures

    def transform(self,X):
        """Computes the signed measures of every point cloud, in parallel."""
        return Parallel(n_jobs=self.n_jobs)(delayed(self._transform1)(data)
            for data in tqdm(X, desc=f"Computing DegreeRips, of degrees {self.degrees}", disable=not self.progress))
997
+
998
+
999
+
1000
+
1001
+ def tensor_möbius_inversion(tensor, grid_conversion:Iterable[np.ndarray]|None = None, plot:bool=False, raw:bool=False, num_parameters:int|None=None):
1002
+ from torch import Tensor
1003
+ betti_sparse = Tensor(tensor.copy()).to_sparse() # Copy necessary in some cases :(
1004
+ num_indices, num_pts = betti_sparse.indices().shape
1005
+ num_parameters = num_indices if num_parameters is None else num_parameters
1006
+ if num_indices == num_parameters: # either hilbert or rank invariant
1007
+ rank_invariant = False
1008
+ elif 2*num_parameters == num_indices:
1009
+ rank_invariant = True
1010
+ else:
1011
+ raise TypeError(f"Unsupported betti shape. {num_indices} has to be either {num_parameters} or {2*num_parameters}.")
1012
+ points_filtration = np.asarray(betti_sparse.indices().T, dtype=int)
1013
+ weights = np.asarray(betti_sparse.values(), dtype=int)
1014
+
1015
+ if grid_conversion is not None:
1016
+ coords = np.empty(shape=(num_pts,num_indices), dtype=float)
1017
+ for i in range(num_indices):
1018
+ coords[:,i] = grid_conversion[i%num_parameters][points_filtration[:,i]]
1019
+ else:
1020
+ coords = points_filtration
1021
+ if (not rank_invariant) and plot:
1022
+ plt.figure()
1023
+ color_weights = np.empty(weights.shape)
1024
+ color_weights[weights>0] = np.log10(weights[weights>0])+2
1025
+ color_weights[weights<0] = -np.log10(-weights[weights<0])-2
1026
+ plt.scatter(points_filtration[:,0],points_filtration[:,1], c=color_weights, cmap="coolwarm")
1027
+ if (not rank_invariant) or raw: return coords, weights
1028
+ def _is_trivial(rectangle:np.ndarray):
1029
+ birth=rectangle[:num_parameters]
1030
+ death=rectangle[num_parameters:]
1031
+ return np.all(birth<=death) # and not np.array_equal(birth,death)
1032
+ correct_indices = np.array([_is_trivial(rectangle) for rectangle in coords])
1033
+ if len(correct_indices) == 0: return np.empty((0, num_indices)), np.empty((0))
1034
+ signed_measure = np.asarray(coords[correct_indices])
1035
+ weights = weights[correct_indices]
1036
+ if plot:
1037
+ assert signed_measure.shape[1] == 4 # plot only the rank decompo for the moment
1038
+ def _plot_rectangle(rectangle:np.ndarray, weight:float):
1039
+ x_axis=rectangle[[0,2]]
1040
+ y_axis=rectangle[[1,3]]
1041
+ color = "blue" if weight > 0 else "red"
1042
+ plt.plot(x_axis, y_axis, c=color)
1043
+ for rectangle, weight in zip(signed_measure, weights):
1044
+ _plot_rectangle(rectangle=rectangle, weight=weight)
1045
+ return signed_measure, weights
1046
+