multipers-1.1.3-cp312-cp312-macosx_11_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of multipers might be problematic.

Files changed (63)
  1. multipers/.dylibs/libtbb.12.12.dylib +0 -0
  2. multipers/.dylibs/libtbbmalloc.2.12.dylib +0 -0
  3. multipers/__init__.py +5 -0
  4. multipers/_old_rank_invariant.pyx +328 -0
  5. multipers/_signed_measure_meta.py +193 -0
  6. multipers/data/MOL2.py +350 -0
  7. multipers/data/UCR.py +18 -0
  8. multipers/data/__init__.py +1 -0
  9. multipers/data/graphs.py +466 -0
  10. multipers/data/immuno_regions.py +27 -0
  11. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  12. multipers/data/pytorch2simplextree.py +91 -0
  13. multipers/data/shape3d.py +101 -0
  14. multipers/data/synthetic.py +68 -0
  15. multipers/distances.py +172 -0
  16. multipers/euler_characteristic.cpython-312-darwin.so +0 -0
  17. multipers/euler_characteristic.pyx +137 -0
  18. multipers/function_rips.cpython-312-darwin.so +0 -0
  19. multipers/function_rips.pyx +102 -0
  20. multipers/hilbert_function.cpython-312-darwin.so +0 -0
  21. multipers/hilbert_function.pyi +46 -0
  22. multipers/hilbert_function.pyx +151 -0
  23. multipers/io.cpython-312-darwin.so +0 -0
  24. multipers/io.pyx +176 -0
  25. multipers/ml/__init__.py +0 -0
  26. multipers/ml/accuracies.py +61 -0
  27. multipers/ml/convolutions.py +510 -0
  28. multipers/ml/invariants_with_persistable.py +79 -0
  29. multipers/ml/kernels.py +128 -0
  30. multipers/ml/mma.py +657 -0
  31. multipers/ml/one.py +472 -0
  32. multipers/ml/point_clouds.py +191 -0
  33. multipers/ml/signed_betti.py +50 -0
  34. multipers/ml/signed_measures.py +1479 -0
  35. multipers/ml/sliced_wasserstein.py +313 -0
  36. multipers/ml/tools.py +116 -0
  37. multipers/mma_structures.cpython-312-darwin.so +0 -0
  38. multipers/mma_structures.pxd +155 -0
  39. multipers/mma_structures.pyx +651 -0
  40. multipers/multiparameter_edge_collapse.py +29 -0
  41. multipers/multiparameter_module_approximation.cpython-312-darwin.so +0 -0
  42. multipers/multiparameter_module_approximation.pyi +439 -0
  43. multipers/multiparameter_module_approximation.pyx +311 -0
  44. multipers/pickle.py +53 -0
  45. multipers/plots.py +292 -0
  46. multipers/point_measure_integration.cpython-312-darwin.so +0 -0
  47. multipers/point_measure_integration.pyx +59 -0
  48. multipers/rank_invariant.cpython-312-darwin.so +0 -0
  49. multipers/rank_invariant.pyx +154 -0
  50. multipers/simplex_tree_multi.cpython-312-darwin.so +0 -0
  51. multipers/simplex_tree_multi.pxd +121 -0
  52. multipers/simplex_tree_multi.pyi +715 -0
  53. multipers/simplex_tree_multi.pyx +1417 -0
  54. multipers/slicer.cpython-312-darwin.so +0 -0
  55. multipers/slicer.pxd +94 -0
  56. multipers/slicer.pyx +276 -0
  57. multipers/tensor.pxd +13 -0
  58. multipers/test.pyx +44 -0
  59. multipers-1.1.3.dist-info/LICENSE +21 -0
  60. multipers-1.1.3.dist-info/METADATA +22 -0
  61. multipers-1.1.3.dist-info/RECORD +63 -0
  62. multipers-1.1.3.dist-info/WHEEL +5 -0
  63. multipers-1.1.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1479 @@
+ from typing import Iterable, Optional, Callable
+
+ from itertools import product
+ import matplotlib.pyplot as plt
+ from multipers.ml.convolutions import convolution_signed_measures
+ import numpy as np
+ from joblib import Parallel, delayed
+ from sklearn.base import BaseEstimator, TransformerMixin
+ from tqdm import tqdm
+
+ import multipers as mp
+ from multipers.simplex_tree_multi import SimplexTreeMulti
+
+ # Helpers used further below (DegreeRips2SignedMeasure, _st2ranktensor) that were
+ # not imported in this file; the import locations below are assumed.
+ from scipy.spatial import distance_matrix
+ from multipers.ml.signed_betti import signed_betti, rank_decomposition_by_rectangles
+ from multipers.ml.invariants_with_persistable import hf_degree_rips
+
+ reduce_grid = SimplexTreeMulti._reduce_grid
+
+
+ class SimplexTree2SignedMeasure(BaseEstimator, TransformerMixin):
+     """
+     Input
+     -----
+     Iterable[SimplexTreeMulti]
+
+     Output
+     ------
+     Iterable[ list[signed_measure for degree] ]
+
+     A signed measure is either
+     - (points : (n x num_parameters) array, weights : (n) int array) if sparse,
+     - an integer matrix otherwise.
+
+     Parameters
+     ----------
+     - degrees : list of degrees to compute. None corresponds to the Euler characteristic.
+     - filtration_grid : the grid on which to compute. If None, the fit will infer it from the data.
+     - fit_fraction : the fraction of data to consider for the fit; the sampling is controlled by the seed parameter.
+     - resolution : the resolution of this grid.
+     - filtration_quantile : quantile of the filtration values to ignore.
+     - grid_strategy : 'regular', 'quantile' or 'exact'.
+     - normalize_filtrations : if sparse, will normalize all filtrations.
+     - expand : expands the simplextree to compute the degree correctly, for flag complexes.
+     - invariant : the topological invariant to produce the signed measure.
+       Choices are "hilbert" or "euler". Will add rank invariant later.
+     - num_collapses : either an int or "full". Collapses the complex before doing the computation.
+     - _möbius_inversion : if False, will not do the Möbius inversion; the output is then a matrix.
+     - enforce_null_mass : if True, returns a zero-mass measure by thresholding the module.
+     """
+
+     def __init__(
+         self,
+         # homological degrees + None for euler
+         degrees: list[int | None] = [],
+         rank_degrees: list[int] = [],  # same for rank invariant
+         # filtration values to consider. Format : [filtration values of parameter i, for each i]
+         filtration_grid: Iterable[np.ndarray] | None = None,
+         progress=False,  # tqdm
+         num_collapses: int | str = 0,  # edge collapses before computing
+         n_jobs=None,
+         # when filtration grid is not given, the resolution of the filtration grid to infer
+         resolution: Iterable[int] | int | None = None,
+         # sparse=True, # sparse output # DEPRECATED: see SignedMeasureFormatter
+         plot: bool = False,
+         filtration_quantile: float = 0.0,  # quantile for inferring filtration grid
+         # whether or not to do the möbius inversion (not recommended to touch)
+         _möbius_inversion: bool = True,
+         expand=True,  # expand the simplextree before computing the homology
+         normalize_filtrations: bool = False,
+         # exact_computation:bool=False, # compute the exact signed measure.
+         grid_strategy: str = "exact",
+         seed: int = 0,  # seed for the subsampling when fit_fraction is not 1
+         fit_fraction=1,  # the fraction of the data on which to fit
+         out_resolution: Iterable[int] | int | None = None,
+         # Can be significantly faster for some grid strategies, but can drop statistical performance
+         individual_grid: Optional[bool] = None,
+         enforce_null_mass: bool = False,
+         flatten=True,
+         backend="multipers",
+     ):
+         super().__init__()
+         self.degrees = degrees
+         self.rank_degrees = rank_degrees
+         self.filtration_grid = filtration_grid
+         self.progress = progress
+         self.num_collapses = num_collapses
+         self.n_jobs = n_jobs
+         self.resolution = resolution
+         self.plot = plot
+         self.backend = backend
+         # self.sparse=sparse # TODO : deprecate
+         self.filtration_quantile = filtration_quantile
+         # Will only work for non sparse output. (discrete matrices cannot be "rescaled")
+         self.normalize_filtrations = normalize_filtrations
+         self.grid_strategy = grid_strategy
+         self.num_parameter = None
+         self._is_input_delayed = None
+         self._möbius_inversion = _möbius_inversion
+         self._reconversion_grid = None
+         self.expand = expand
+         # will only refit the grid if filtration_grid has never been given.
+         self._refit_grid = None
+         self.seed = seed
+         self.fit_fraction = fit_fraction
+         self._transform_st = None
+         self._to_simplex_tree: Callable
+         self.out_resolution = out_resolution
+         self.individual_grid = individual_grid
+         self.enforce_null_mass = enforce_null_mass
+         self._default_mass_location = None
+         self.flatten = flatten
+         self.num_parameters: int = 0
+         return
+
+     def _infer_filtration(self, X):
+         indices = np.random.choice(
+             len(X), min(int(self.fit_fraction * len(X)) + 1, len(X)), replace=False
+         )
+
+         def get_st_filtration(x) -> np.ndarray:
+             return self._to_simplex_tree(x).get_filtration_grid(grid_strategy="exact")
+
+         filtrations: list = Parallel(n_jobs=self.n_jobs, backend="threading")(
+             delayed(get_st_filtration)(x) for x in (X[idx] for idx in indices)
+         )
+         num_parameters = len(filtrations[0])
+         filtrations_values = [
+             np.unique(np.concatenate([x[i] for x in filtrations]))
+             for i in range(num_parameters)
+         ]
+         filtration_grid = reduce_grid(
+             filtrations_values, resolutions=self.resolution, strategy=self.grid_strategy
+         )  # TODO :use more parameters
+         self.filtration_grid = filtration_grid
+         return filtration_grid
+
+     def fit(self, X, y=None):  # Todo : infer filtration grid ? quantiles ?
+         # assert not self.normalize_filtrations or self.sparse, "Not able to normalize a matrix without losing information."
+         assert (
+             self.resolution is not None
+             or self.filtration_grid is not None
+             or self.grid_strategy == "exact"
+             or self.individual_grid
+         ), "For non exact filtrations, a resolution has to be specified."
+         assert (
+             self._möbius_inversion or not self.individual_grid
+         ), "The grid has to be aligned when not using mobius inversion; disable individual_grid or enable mobius inversion."
+         # assert self.invariant != "_" or self._möbius_inversion
+         self._is_input_delayed = not isinstance(X[0], mp.SimplexTreeMulti)
+         if self._is_input_delayed:
+             from multipers.ml.tools import get_simplex_tree_from_delayed
+
+             self._to_simplex_tree = get_simplex_tree_from_delayed
+         else:
+             self._to_simplex_tree = lambda x: x
+         if isinstance(self.resolution, int) or self.resolution == np.inf:
+             self.resolution = [self.resolution] * self._to_simplex_tree(
+                 X[0]
+             ).num_parameters
+         self.num_parameter = self._to_simplex_tree(X[0]).num_parameters
+         self.individual_grid = (
+             self.individual_grid
+             if self.individual_grid is not None
+             else self.grid_strategy
+             in ["regular_closest", "exact", "quantile", "partition"]
+         )
+
+         if (
+             not self.enforce_null_mass
+             and self.individual_grid
+             or self.filtration_grid is not None
+         ):
+             self._refit_grid = False
+         else:
+             self._refit_grid = True
+
+         if self._refit_grid:
+             self._infer_filtration(X=X)
+         if self.out_resolution is None:
+             self.out_resolution = self.resolution
+         elif isinstance(self.out_resolution, int):
+             self.out_resolution = [self.out_resolution] * self.num_parameters
+         if self.normalize_filtrations and not self.individual_grid:
+             # self._reconversion_grid = [np.linspace(0,1, num=len(f), dtype=float) for f in self.filtration_grid] ## This will not work for non-regular grids...
+             self._reconversion_grid = [
+                 f / np.std(f) for f in self.filtration_grid
+             ]  # not the best, but better than some weird magic
+         # elif not self.sparse: # It actually renormalizes the filtration !!
+         #     self._reconversion_grid = [np.linspace(0,r, num=r, dtype=int) for r in self.out_resolution]
+         else:
+             self._reconversion_grid = self.filtration_grid
+         self._default_mass_location = (
+             [g[-1] for g in self._reconversion_grid] if self.enforce_null_mass else None
+         )
+         return self
+
+     def transform1(
+         self,
+         simplextree,
+         filtration_grid=None,
+         _reconversion_grid=None,
+         thread_id: str = "",
+     ):
+         if filtration_grid is None:
+             filtration_grid = self.filtration_grid
+         if _reconversion_grid is None:
+             _reconversion_grid = self._reconversion_grid
+         st = self._to_simplex_tree(simplextree)
+         st = mp.SimplexTreeMulti(st, num_parameters=st.num_parameters)  # COPY
+         if self.individual_grid:
+             filtration_grid = st.get_filtration_grid(
+                 grid_strategy=self.grid_strategy, resolution=self.resolution
+             )
+             if self.enforce_null_mass:
+                 filtration_grid = [
+                     np.concatenate([f, [d]], axis=0)
+                     for f, d in zip(filtration_grid, self._default_mass_location)
+                 ]
+         st.grid_squeeze(filtration_grid=filtration_grid, coordinate_values=True)
+         if st.num_parameters == 2:
+             if self.num_collapses == "full":
+                 st.collapse_edges(full=True, max_dimension=1)
+             elif isinstance(self.num_collapses, int):
+                 st.collapse_edges(num=self.num_collapses, max_dimension=1)
+             else:
+                 raise Exception(
+                     "Bad edge collapse type. Either 'full' or an int.")
+         int_degrees = np.asarray([d for d in self.degrees if d is not None])
+         if self._möbius_inversion:
+             # EULER first, as there is a prune-above-dimension below
+             if self.expand and None in self.degrees:
+                 st.expansion(st.num_vertices)
+             signed_measures_euler = (
+                 mp.signed_measure(
+                     simplextree=st,
+                     degrees=[None],
+                     plot=self.plot,
+                     mass_default=self._default_mass_location,
+                     invariant="euler",
+                     thread_id=thread_id,
+                 )[0]
+                 if None in self.degrees
+                 else []
+             )
+
+             if self.expand and len(int_degrees) > 0:
+                 st.expansion(np.max(int_degrees) + 1)
+             if len(int_degrees) > 0:
+                 st.prune_above_dimension(
+                     np.max(np.concatenate([int_degrees, self.rank_degrees])) + 1
+                 )  # no need to compute homology beyond this
+             signed_measures_pers = (
+                 mp.signed_measure(
+                     simplextree=st,
+                     degrees=int_degrees,
+                     mass_default=self._default_mass_location,
+                     plot=self.plot,
+                     invariant="hilbert",
+                     thread_id=thread_id,
+                     backend=self.backend,
+                 )
+                 if len(int_degrees) > 0
+                 else []
+             )
+             if self.plot:
+                 plt.show()
+             if self.expand and len(self.rank_degrees) > 0:
+                 st.expansion(np.max(self.rank_degrees) + 1)
+             if len(self.rank_degrees) > 0:
+                 st.prune_above_dimension(
+                     np.max(self.rank_degrees) + 1
+                 )  # no need to compute homology beyond this
+             signed_measures_rank = (
+                 mp.signed_measure(
+                     simplextree=st,
+                     degrees=self.rank_degrees,
+                     mass_default=self._default_mass_location,
+                     plot=self.plot,
+                     invariant="rank",
+                     thread_id=thread_id,
+                 )
+                 if len(self.rank_degrees) > 0
+                 else []
+             )
+             if self.plot:
+                 plt.show()
+
+         else:
+             from multipers.euler_characteristic import euler_surface
+             from multipers.hilbert_function import hilbert_surface
+             from multipers.rank_invariant import rank_invariant
+
+             if self.expand and None in self.degrees:
+                 st.expansion(st.num_vertices)
+             signed_measures_euler = (
+                 euler_surface(
+                     simplextree=st,
+                     plot=self.plot,
+                 )[1][None]
+                 if None in self.degrees
+                 else []
+             )
+
+             if self.expand and len(int_degrees) > 0:
+                 st.expansion(np.max(int_degrees) + 1)
+             if len(int_degrees) > 0:
+                 st.prune_above_dimension(
+                     np.max(np.concatenate([int_degrees, self.rank_degrees])) + 1
+                 )
+                 # no need to compute homology beyond this
+             signed_measures_pers = (
+                 hilbert_surface(
+                     simplextree=st,
+                     degrees=int_degrees,
+                     plot=self.plot,
+                 )[1]
+                 if len(int_degrees) > 0
+                 else []
+             )
+             if self.plot:
+                 plt.show()
+
+             if self.expand and len(self.rank_degrees) > 0:
+                 st.expansion(np.max(self.rank_degrees) + 1)
+             if len(self.rank_degrees) > 0:
+                 st.prune_above_dimension(
+                     np.max(self.rank_degrees) + 1
+                 )  # no need to compute homology beyond this
+             signed_measures_rank = (
+                 rank_invariant(
+                     simplextree=st,
+                     degrees=self.rank_degrees,
+                     plot=self.plot,
+                 )
+                 if len(self.rank_degrees) > 0
+                 else []
+             )
+
+         count = 0
+         signed_measures = []
+         for d in self.degrees:
+             if d is None:
+                 signed_measures.append(signed_measures_euler)
+             else:
+                 signed_measures.append(signed_measures_pers[count])
+                 count += 1
+         signed_measures += signed_measures_rank
+         if not self._möbius_inversion and self.flatten:
+             signed_measures = np.asarray(signed_measures).flatten()
+         return signed_measures
+
+     def transform(self, X):
+         assert self.filtration_grid is not None or self.individual_grid, "Fit first"
+         prefer = "loky" if self._is_input_delayed else "threading"
+         out = Parallel(n_jobs=self.n_jobs, backend=prefer)(
+             delayed(self.transform1)(to_st, thread_id=str(thread_id))
+             for thread_id, to_st in tqdm(
+                 enumerate(X),
+                 disable=not self.progress,
+                 desc="Computing signed measure decompositions",
+             )
+         )
+         return out
+
+
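A minimal usage sketch for the transformer above (not part of the packaged file): it assumes `simplextrees` is a list of 2-parameter mp.SimplexTreeMulti objects built elsewhere.

    # Hilbert signed measures in degrees 0 and 1, plus the Euler signed measure (None)
    sm_transformer = SimplexTree2SignedMeasure(
        degrees=[0, 1, None],
        grid_strategy="exact",
        progress=True,
    )
    signed_measures = sm_transformer.fit_transform(simplextrees)
    # signed_measures[i][j] is a (points, weights) pair for datum i and degree j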
+ class SimplexTrees2SignedMeasures(SimplexTree2SignedMeasure):
+     """
+     Input
+     -----
+
+     (data) x (axis, e.g. different bandwidths for simplextrees) x (simplextree)
+
+     Output
+     ------
+     (data) x (axis) x (degree) x (signed measure)
+     """
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         self._num_st_per_data = None
+         # self._super_model=SimplexTree2SignedMeasure(**kwargs)
+         self._filtration_grids = None
+         return
+
+     def fit(self, X, y=None):
+         if len(X) == 0:
+             return self
+         try:
+             self._num_st_per_data = len(X[0])
+         except:
+             raise Exception(
+                 "Shape has to be (num_data, num_axis), dtype=SimplexTreeMulti"
+             )
+         self._filtration_grids = []
+         for axis in range(self._num_st_per_data):
+             self._filtration_grids.append(
+                 super().fit([x[axis] for x in X]).filtration_grid
+             )
+         # self._super_fits.append(truc)
+         # self._super_fits_params = [super().fit([x[axis] for x in X]).get_params() for axis in range(self._num_st_per_data)]
+         return self
+
+     def transform(self, X):
+         if self.normalize_filtrations:
+             _reconversion_grids = [
+                 [np.linspace(0, 1, num=len(f), dtype=float) for f in F]
+                 for F in self._filtration_grids
+             ]
+         else:
+             _reconversion_grids = self._filtration_grids
+
+         def todo(x):
+             # return [SimplexTree2SignedMeasure().set_params(**transformer_params).transform1(x[axis]) for axis,transformer_params in enumerate(self._super_fits_params)]
+             out = [
+                 self.transform1(
+                     x[axis],
+                     filtration_grid=filtration_grid,
+                     _reconversion_grid=_reconversion_grid,
+                 )
+                 for axis, filtration_grid, _reconversion_grid in zip(
+                     range(self._num_st_per_data),
+                     self._filtration_grids,
+                     _reconversion_grids,
+                 )
+             ]
+             return out
+
+         return Parallel(n_jobs=self.n_jobs, backend="threading")(
+             delayed(todo)(x)
+             for x in tqdm(
+                 X,
+                 disable=not self.progress,
+                 desc="Computing Signed Measures from simplextrees.",
+             )
+         )
+
+
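A hedged sketch of the per-axis variant: each datum is assumed to be a list of simplex trees (e.g. one per density bandwidth); `data` is a hypothetical input of that shape.

    # data[i] = [st_for_bandwidth1, st_for_bandwidth2]
    sms_per_axis = SimplexTrees2SignedMeasures(degrees=[1]).fit_transform(data)
    # output shape: (data) x (axis) x (degree) x (signed measure)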
+ def rescale_sparse_signed_measure(
+     signed_measure, filtration_weights, normalize_scales=None
+ ):
+     from copy import deepcopy
+
+     out = deepcopy(signed_measure)
+     if normalize_scales is None:
+         for degree in range(len(out)):  # degree
+             for parameter in range(len(filtration_weights)):
+                 out[degree][0][:, parameter] *= filtration_weights[parameter]
+                 # TODO Broadcast w.r.t. the parameter
+     else:
+         for degree in range(len(out)):
+             for parameter in range(len(filtration_weights)):
+                 out[degree][0][:, parameter] *= (
+                     filtration_weights[parameter] /
+                     normalize_scales[degree][parameter]
+                 )
+     return out
+
+
+ class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
+     """
+     Input
+     -----
+
+     (data) x (degree) x (signed measure) or (data) x (axis) x (degree) x (signed measure)
+
+     Iterable[list[signed_measure_matrix of degree]] or Iterable[previous].
+
+     The second format is meant to allow multiple choices of signed measure input,
+     e.g., measures coming from a Rips + Density bifiltration with different bandwidths.
+     It is controlled by the axis parameter.
+
+     Output
+     ------
+
+     Iterable[list[(reweighted)_sparse_signed_measure of degree]]
+     """
+
+     def __init__(
+         self,
+         filtrations_weights: Optional[Iterable[float]] = None,
+         normalize=False,
+         plot: bool = False,
+         unsparse: bool = False,
+         axis: int = -1,
+         resolution: int | Iterable[int] = 50,
+         flatten: bool = False,
+         deep_format: bool = False,
+         unrag: bool = True,
+         n_jobs: int = 1,
+         verbose: bool = False,
+         integrate: bool = False,
+         grid_strategy="regular",
+     ):
+         super().__init__()
+         self.filtrations_weights = filtrations_weights
+         self.num_parameters: int = 0
+         self.plot = plot
+         self.unsparse = unsparse
+         self.n_jobs = n_jobs
+         self.axis = axis
+         self._num_axis = 0
+         self.resolution = resolution
+         self._filtrations_bounds = None
+         self.flatten = flatten
+         self.normalize = normalize
+         self._normalization_factors = None
+         self.deep_format = deep_format
+         self.unrag = unrag
+         assert not self.deep_format or not self.unsparse
+         assert not normalize or (
+             not unsparse and not deep_format and not integrate)
+         self.verbose = verbose
+         self._num_degrees = 0
+         self.integrate = integrate
+         self.grid_strategy = grid_strategy
+         self._infered_grids = None
+         return
+
+     def _get_filtration_bounds(self, X, axis):
+         num_degrees = len(X[0][axis])
+         stuff = [
+             np.concatenate([sm[axis][degree][0] for sm in X], axis=0)
+             for degree in range(num_degrees)
+         ]
+         sizes_ = np.array([len(x) == 0 for x in stuff])
+         assert np.all(
+             1 - sizes_), f"Degree axis {np.where(sizes_)} is/are trivial !"
+         filtrations_bounds = np.asarray(
+             [[f.min(axis=0), f.max(axis=0)] for f in stuff])
+         normalization_factors = (
+             filtrations_bounds[:, 1] - filtrations_bounds[:, 0]
+             if self.normalize
+             else None
+         )
+         # print("Normalization factors : ",self._normalization_factors)
+         if np.any(normalization_factors == 0):
+             indices = np.where(normalization_factors == 0)
+             # warn(f"Constant filtration encountered, at degree, parameter {indices} and axis {self.axis}.")
+             normalization_factors[indices] = 1
+         return filtrations_bounds, normalization_factors
+
+     def _plot_signed_measures(self, sms: Iterable[np.ndarray], size=4):
+         from multipers.plots import plot_signed_measure
+
+         num_degrees = len(sms[0])
+         num_imgs = len(sms)
+         fig, axes = plt.subplots(
+             ncols=num_degrees,
+             nrows=num_imgs,
+             figsize=(size * num_degrees, size * num_imgs),
+         )
+         axes = np.asarray(axes).reshape(num_imgs, num_degrees)
+         # assert axes.ndim==2, "Internal error"
+         for i, sm in enumerate(sms):
+             for j, sm_of_degree in enumerate(sm):
+                 plot_signed_measure(sm_of_degree, ax=axes[i, j])
+
+     def _check_axis(self, X):
+         # axes should be (num_data, num_axis, num_degrees, (signed_measure))
+         if len(X) == 0:
+             return
+         if len(X[0]) == 0:
+             return
+         if isinstance(X[0][0], tuple):
+             self._has_axis = False
+             self._num_axis = 1
+             return
+         assert isinstance(X[0][0][0], tuple), "Cannot take this input."
+         self._has_axis = True
+         self._num_axis = len(X[0])
+
+     def _check_measures(self, X):
+         if self._has_axis:
+             first_sm = X[0][0]
+         else:
+             first_sm = X[0]
+         self._num_degrees = len(first_sm)
+         self.num_parameters = first_sm[0][0].shape[1]
+
+     def _check_resolution(self):
+         assert self.num_parameters > 0, "Num parameters hasn't been initialized."
+         if isinstance(self.resolution, int):
+             self.resolution = [self.resolution] * self.num_parameters
+         self.resolution = np.asarray(self.resolution, dtype=int)
+         assert (
+             self.resolution.shape[0] == self.num_parameters
+         ), "Resolution doesn't have a proper size."
+
+     def _check_weights(self):
+         if self.filtrations_weights is None:
+             self.filtrations_weights = np.array([1] * self.num_parameters)
+         self.filtrations_weights = np.asarray(self.filtrations_weights)
+         assert (
+             self.filtrations_weights.shape[0] == self.num_parameters
+         ), "Filtration weights don't have a proper size"
+
+     def _infer_grids(self, X):
+         # Computes normalization factors
+         if self.normalize:
+             axis = slice(None) if not self._has_axis else self.axis
+             if self._has_axis and axis == -1:
+                 self._filtrations_bounds = []
+                 self._normalization_factors = []
+                 for ax in range(self._num_axis):
+                     (
+                         filtration_bounds,
+                         normalization_factors,
+                     ) = self._get_filtration_bounds(X, axis=ax)
+                     self._filtrations_bounds.append(filtration_bounds)
+                     self._normalization_factors.append(normalization_factors)
+             else:
+                 (
+                     self._filtrations_bounds,
+                     self._normalization_factors,
+                 ) = self._get_filtration_bounds(X, axis=axis)
+         elif self.integrate or self.unsparse:
+             axis = (
+                 [slice(None)]
+                 if not self._has_axis
+                 else range(self._num_axis)
+                 if self.axis == -1
+                 else [self.axis]
+             )
+             filtration_values = [
+                 np.concatenate(
+                     [x[ax][degree][0]
+                      for x in X for degree in range(self._num_degrees)]
+                 )
+                 for ax in axis
+             ]
+             # axis, filtration_values
+             filtration_values = [
+                 reduce_grid(
+                     f_ax.T, resolutions=self.resolution, strategy=self.grid_strategy
+                 )
+                 for f_ax in filtration_values
+             ]
+             self._infered_grids = filtration_values
+
+     def _print_stats(self, X):
+         print("------------SignedMeasureFormatter------------")
+         print("---- Parameters")
+         print(f"Number of axis : {self._num_axis}")
+         print(f"Number of degrees : {self._num_degrees}")
+         print(f"Filtration bounds : \n{self._filtrations_bounds}")
+         print(f"Normalization factor : \n{self._normalization_factors}")
+         if self._infered_grids is not None:
+             print(
+                 f"Filtration grid shape : \n \
+                 {tuple(tuple(len(f) for f in F) for F in self._infered_grids)}"
+             )
+         print("---- SM stats")
+         print("In axis :", self._num_axis)
+         if self.axis == -1 and self._has_axis:
+             axis = range(self._num_axis)
+         elif self._has_axis:
+             axis = [self.axis]
+         else:
+             axis = [slice(None)]
+         sizes = [[[len(xd[1]) for xd in x[ax]] for x in X] for ax in axis]
+         print(f"Size means (axis) x (degree): {np.mean(sizes, axis=(1))}")
+         print(f"Size std : {np.std(sizes, axis=(1))}")
+         print("----------------------------------------------")
+
+     def fit(self, X, y=None):
+         assert not self.normalize or (
+             not self.unsparse and not self.deep_format and not self.integrate
+         )
+         # Gets a grid. This will be the max in each coord+1
+         if (
+             len(X) == 0
+             or len(X[0]) == 0
+             or (self.axis is not None and len(X[0][0][0]) == 0)
+         ):
+             return self
+
+         self._check_axis(X)
+         self._check_measures(X)
+         self._check_resolution()
+         self._check_weights()
+         # if not sparse : not recommended.
+
+         self._infer_grids(X)
+         if self.verbose:
+             self._print_stats(X)
+         return self
+
+     def unsparse_signed_measure(self, sparse_signed_measure):
+         filtrations = self._infered_grids  # ax, filtration
+         out = []
+         axis = range(self._num_axis) if self.axis == -1 else [slice(None)]
+         for filtrations_of_ax, ax in zip(filtrations, axis):
+             sparse_signed_measure_of_ax = sparse_signed_measure[ax]
+             measure_of_ax = []
+             for pts, weights in sparse_signed_measure_of_ax:  # over degree
+                 signed_measure, _ = np.histogramdd(
+                     pts, bins=filtrations_of_ax, weights=weights
+                 )
+                 if self.flatten:
+                     signed_measure = signed_measure.flatten()
+                 measure_of_ax.append(signed_measure)
+             out.append(np.asarray(measure_of_ax))
+
+         if self.flatten:
+             out = np.concatenate(out).flatten()
+         if self.axis == -1:
+             return np.asarray(out)
+         else:
+             return np.asarray(out)[0]
+
+     @staticmethod
+     def deep_format_measure(signed_measure):
+         dirac_positions, dirac_signs = signed_measure
+         dtype = dirac_positions.dtype
+         new_shape = list(dirac_positions.shape)
+         new_shape[1] += 1
+         if isinstance(dirac_positions, np.ndarray):
+             c = np.empty(new_shape, dtype=dtype)
+             c[:, :-1] = dirac_positions
+             c[:, -1] = dirac_signs
+
+         else:
+             import torch
+
+             c = torch.empty(new_shape, dtype=dtype)
+             c[:, :-1] = dirac_positions
+             c[:, -1] = torch.from_numpy(dirac_signs)
+         return c
+
+     @staticmethod
+     def _integrate_measure(sm, filtrations):
+         from multipers.point_measure_integration import integrate_measure
+
+         return integrate_measure(sm[0], sm[1], filtrations)
+
+     def _rescale_measures(self, X):
+         def rescale_from_sparse(sparse_signed_measure):
+             if self.axis == -1 and self._has_axis:
+                 return [
+                     rescale_sparse_signed_measure(
+                         sparse_signed_measure[ax],
+                         filtration_weights=self.filtrations_weights,
+                         normalize_scales=n,
+                     )
+                     for ax, n in enumerate(self._normalization_factors)
+                 ]
+             return rescale_sparse_signed_measure(
+                 sparse_signed_measure,
+                 filtration_weights=self.filtrations_weights,
+                 normalize_scales=self._normalization_factors,
+             )
+
+         out = tuple(rescale_from_sparse(x) for x in X)
+         return out
+
+     def transform(self, X):
+         if not self._has_axis or self.axis == -1:
+             out = X
+         else:
+             out = tuple(x[self.axis] for x in X)
+         # same format for everyone
+
+         if self._normalization_factors is not None:
+             out = self._rescale_measures(out)
+
+         if self.plot:
+             # assert ax != -1, "Not implemented"
+             self._plot_signed_measures(out)
+         if self.integrate:
+             filtrations = self._infered_grids
+             # if self.axis != -1:
+             ax = 0  # if self.axis is None else self.axis # TODO deal with axis -1
+
+             assert ax != -1, "Not implemented"
+             # try:
+             out = np.asarray(
+                 [
+                     [
+                         self._integrate_measure(
+                             x[degree], filtrations=filtrations[ax])
+                         for degree in range(self._num_degrees)
+                     ]
+                     for x in out
+                 ]
+             )
+             # except:
+             #     print(self.axis, ax, filtrations)
+             if self.flatten:
+                 out = out.reshape((len(X), -1))
+             # else:
+             #     out = [[[self._integrate_measure(x[axis][degree],filtrations=filtrations[degree].T) for degree in range(self._num_degrees)] for axis in range(self._num_axis)] for x in out]
+         elif self.unsparse:
+             out = [self.unsparse_signed_measure(x) for x in out]
+         elif self.deep_format:
+             num_degrees = self._num_degrees
+             if not self._has_axis:
+                 axes = [slice(None)]
+             elif self.axis == -1:
+                 axes = range(self._num_axis)
+             else:
+                 axes = [self.axis]
+             out = [
+                 [self.deep_format_measure(sm[axis][degree]) for sm in out]
+                 for degree in range(num_degrees)
+                 for axis in axes
+             ]
+             if self.unrag:
+                 max_num_pts = np.max(
+                     [sm.shape[0] for sm_of_axis in out for sm in sm_of_axis]
+                 )
+                 num_axis_degree = len(out)
+                 num_data = len(out[0])
+                 assert num_axis_degree == num_degrees * (
+                     self._num_axis if self._has_axis else 1
+                 ), f"Bad axis/degree count. Got {num_axis_degree} (Internal error)"
+                 num_parameters = out[0][0].shape[1]
+                 dtype = out[0][0].dtype
+                 if isinstance(out[0][0], np.ndarray):
+                     from numpy import zeros
+                 else:
+                     from torch import zeros
+                 unragged_tensor = zeros(
+                     (
+                         num_axis_degree,
+                         num_data,
+                         max_num_pts,
+                         num_parameters,
+                     ),
+                     dtype=dtype,
+                 )
+                 for ax in range(num_axis_degree):
+                     for data in range(num_data):
+                         sm = out[ax][data]
+                         a, b = sm.shape
+                         unragged_tensor[ax, data, :a, :b] = sm
+                 out = unragged_tensor
+         return out
+
+
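A sketch of how the formatter might be chained after SimplexTree2SignedMeasure; `signed_measures` refers to the (data) x (degree) x (signed measure) output assumed above.

    # Rescale each filtration axis to a comparable range
    formatter = SignedMeasureFormatter(normalize=True)
    rescaled = formatter.fit_transform(signed_measures)

    # Or integrate each measure on an inferred 50x50 grid and flatten it into one
    # feature vector per datum
    vectorizer = SignedMeasureFormatter(integrate=True, resolution=50, flatten=True)
    features = vectorizer.fit_transform(signed_measures)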
+ class SignedMeasure2Convolution(BaseEstimator, TransformerMixin):
+     """
+     Discrete convolution of a signed measure.
+
+     Input
+     -----
+
+     (data) x (degree) x (signed measure)
+
+     Parameters
+     ----------
+     - filtration_grid : Iterable[array]. For each filtration, the filtration values on which to evaluate the grid.
+     - resolution : int or (num_parameter). If the filtration grid is not given, a grid is inferred at this resolution.
+     - grid_strategy : the strategy used to generate the grid. Available ones are regular, quantile, exact.
+     - kernel : kernel used to convolve the images.
+     - flatten : if True, the output images are flattened.
+     - progress : progress bar if True.
+     - backend : sklearn, pykeops or numba.
+     - plot : creates a plot Figure.
+
+     Output
+     ------
+
+     (data) x (concatenation of imgs of degree)
+     """
+
+     def __init__(
+         self,
+         filtration_grid: Iterable[np.ndarray] = None,
+         kernel="gaussian",
+         bandwidth: float | Iterable[float] = 1.0,
+         flatten: bool = False,
+         n_jobs: int = 1,
+         resolution: int | None = None,
+         grid_strategy: str = "regular",
+         progress: bool = False,
+         backend: str = "pykeops",
+         plot: bool = False,
+         # **kwargs ## DANGEROUS
+     ):
+         super().__init__()
+         self.kernel = kernel
+         self.bandwidth = bandwidth
+         # self.more_kde_kwargs=kwargs
+         self.filtration_grid = filtration_grid
+         self.flatten = flatten
+         self.progress = progress
+         self.n_jobs = n_jobs
+         self.resolution = resolution
+         self.grid_strategy = grid_strategy
+         self._is_input_sparse = None
+         self._refit = filtration_grid is None
+         self._input_resolution = None
+         self._bandwidths = None
+         self.diameter = None
+         self.backend = backend
+         self.plot = plot
+         return
+
+     def fit(self, X, y=None):
+         # Infers if the input is sparse given X
+         if len(X) == 0:
+             return self
+         if isinstance(X[0][0], tuple):
+             self._is_input_sparse = True
+         else:
+             self._is_input_sparse = False
+         # print(f"IMG output is set to {'sparse' if self.sparse else 'matrix'}")
+         if not self._is_input_sparse:
+             self._input_resolution = X[0][0].shape
+             try:
+                 b = float(self.bandwidth)
+                 self._bandwidths = [
+                     b if b > 0 else -b * s for s in self._input_resolution
+                 ]
+             except:
+                 self._bandwidths = [
+                     b if b > 0 else -b * s
+                     for s, b in zip(self._input_resolution, self.bandwidth)
+                 ]
+             return self  # in that case, signed measures are matrices, and the grid is already given
+
+         if self.filtration_grid is None and self.resolution is None:
+             raise Exception(
+                 "Cannot infer filtration grid. Provide either a filtration grid or a resolution."
+             )
+         # If not sparse : a grid has to be defined
+         if self._refit:
+             # print("Fitting a grid...", end="")
+             pts = np.concatenate(
+                 [sm[0] for signed_measures in X for sm in signed_measures]
+             ).T
+             self.filtration_grid = reduce_grid(
+                 filtrations_values=pts,
+                 strategy=self.grid_strategy,
+                 resolutions=self.resolution,
+             )
+             # print('Done.')
+         if self.filtration_grid is not None:
+             self.diameter = np.linalg.norm(
+                 [f.max() - f.min() for f in self.filtration_grid]
+             )
+             if self.progress:
+                 print(f"Computed a diameter of {self.diameter}")
+         return self
+
+     def _sparsify(self, sm):
+         return tensor_möbius_inversion(tensor=sm, grid_conversion=self.filtration_grid)
+
+     def _sm2smi(self, signed_measures: Iterable[np.ndarray]):
+         # print(self._input_resolution, self.bandwidths, _bandwidths)
+         from scipy.ndimage import gaussian_filter
+
+         return np.concatenate(
+             [
+                 gaussian_filter(
+                     input=signed_measure,
+                     sigma=self._bandwidths,
+                     mode="constant",
+                     cval=0,
+                 )
+                 for signed_measure in signed_measures
+             ],
+             axis=0,
+         )
+
+     # def _sm2smi_sparse(self, signed_measures:Iterable[tuple[np.ndarray]]):
+     #     return np.concatenate([
+     #         _pts_convolution_sparse(
+     #             pts = signed_measure_pts, pts_weights = signed_measure_weights,
+     #             filtration_grid = self.filtration_grid,
+     #             kernel=self.kernel,
+     #             bandwidth=self.bandwidths,
+     #             **self.more_kde_kwargs
+     #         )
+     #         for signed_measure_pts, signed_measure_weights in signed_measures], axis=0)
+     def _transform_from_sparse(self, X):
+         bandwidth = (
+             self.bandwidth if self.bandwidth > 0 else -self.bandwidth * self.diameter
+         )
+         # COMPILE KEOPS FIRST
+         dummyx = [X[0]]
+         dummyf = [f[:2] for f in self.filtration_grid]
+         convolution_signed_measures(
+             dummyx,
+             filtrations=dummyf,
+             bandwidth=bandwidth,
+             flatten=self.flatten,
+             n_jobs=1,
+             kernel=self.kernel,
+             backend=self.backend,
+         )
+
+         return convolution_signed_measures(
+             X,
+             filtrations=self.filtration_grid,
+             bandwidth=bandwidth,
+             flatten=self.flatten,
+             n_jobs=self.n_jobs,
+             kernel=self.kernel,
+             backend=self.backend,
+         )
+
+     def _plot_imgs(self, imgs: Iterable[np.ndarray], size=4):
+         from multipers.plots import plot_surface
+
+         num_degrees = imgs[0].shape[0]
+         num_imgs = len(imgs)
+         fig, axes = plt.subplots(
+             ncols=num_degrees,
+             nrows=num_imgs,
+             figsize=(size * num_degrees, size * num_imgs),
+         )
+         axes = np.asarray(axes).reshape(num_imgs, num_degrees)
+         # assert axes.ndim==2, "Internal error"
+         for i, img in enumerate(imgs):
+             for j, img_of_degree in enumerate(img):
+                 plot_surface(
+                     self.filtration_grid, img_of_degree, ax=axes[i, j], cmap="Spectral"
+                 )
+
+     def transform(self, X):
+         if self._is_input_sparse is None:
+             raise Exception("Fit first")
+         if self._is_input_sparse:
+             out = self._transform_from_sparse(X)
+         else:
+             todo = SignedMeasure2Convolution._sm2smi
+             out = Parallel(n_jobs=self.n_jobs, backend="threading")(
+                 delayed(todo)(self, signed_measures)
+                 for signed_measures in tqdm(
+                     X, desc="Computing images", disable=not self.progress
+                 )
+             )
+         if self.plot and not self.flatten:
+             if self.progress:
+                 print("Plotting convolutions...", end="")
+             self._plot_imgs(out)
+             if self.progress:
+                 print("Done !")
+         if self.flatten and not self._is_input_sparse:
+             out = [x.flatten() for x in out]
+         return np.asarray(out)
+
+
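A hedged sketch of the convolution vectorization; the bandwidth and resolution values are illustrative, and the default pykeops backend is assumed to be installed.

    img_transformer = SignedMeasure2Convolution(
        bandwidth=0.1,           # kernel bandwidth, in filtration units
        resolution=50,           # resolution of the inferred grid, per parameter
        grid_strategy="regular",
        flatten=True,            # one flat feature vector per datum
    )
    imgs = img_transformer.fit_transform(signed_measures)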
+ class SignedMeasure2SlicedWassersteinDistance(BaseEstimator, TransformerMixin):
+     """
+     Transformer from signed measures to distance matrices.
+
+     Input
+     -----
+
+     (data) x (degree) x (signed measure)
+
+     Format
+     ------
+     - a signed measure : tuple of arrays, (point positions) : npts x (num_parameters) and (weights) : npts
+     - each datum is a list of signed measures (e.g., one per degree)
+
+     Output
+     ------
+     - (degree) x (distance matrix)
+     """
+
+     def __init__(
+         self,
+         n_jobs=None,
+         num_directions: int = 10,
+         _sliced: bool = True,
+         epsilon=-1,
+         ground_norm=1,
+         progress=False,
+         grid_reconversion=None,
+         scales=None,
+     ):
+         super().__init__()
+         self.n_jobs = n_jobs
+         self._SWD_list = None
+         self._sliced = _sliced
+         self.epsilon = epsilon
+         self.ground_norm = ground_norm
+         self.num_directions = num_directions
+         self.progress = progress
+         self.grid_reconversion = grid_reconversion
+         self.scales = scales
+         return
+
+     def fit(self, X, y=None):
+         from multipers.ml.sliced_wasserstein import (
+             SlicedWassersteinDistance,
+             WassersteinDistance,
+         )
+
+         # _DISTANCE = lambda : SlicedWassersteinDistance(num_directions=self.num_directions) if self._sliced else WassersteinDistance(epsilon=self.epsilon, ground_norm=self.ground_norm) # WARNING if _sliced is false, this distance is not CNSD
+         if len(X) == 0:
+             return self
+         num_degrees = len(X[0])
+         self._SWD_list = [
+             SlicedWassersteinDistance(
+                 num_directions=self.num_directions,
+                 n_jobs=self.n_jobs,
+                 scales=self.scales,
+             )
+             if self._sliced
+             else WassersteinDistance(
+                 epsilon=self.epsilon, ground_norm=self.ground_norm, n_jobs=self.n_jobs
+             )
+             for _ in range(num_degrees)
+         ]
+         for degree, swd in enumerate(self._SWD_list):
+             signed_measures_of_degree = [x[degree] for x in X]
+             swd.fit(signed_measures_of_degree)
+         return self
+
+     def transform(self, X):
+         assert self._SWD_list is not None, "Fit first"
+         # out = []
+         # for degree, swd in tqdm(enumerate(self._SWD_list), desc="Computing distance matrices", total=len(self._SWD_list), disable= not self.progress):
+         with tqdm(
+             enumerate(self._SWD_list),
+             desc="Computing distance matrices",
+             total=len(self._SWD_list),
+             disable=not self.progress,
+         ) as SWD_it:
+             # signed_measures_of_degree = [x[degree] for x in X]
+             # out.append(swd.transform(signed_measures_of_degree))
+             def todo(swd, X_of_degree):
+                 return swd.transform(X_of_degree)
+
+             out = Parallel(n_jobs=self.n_jobs, prefer="threads")(
+                 delayed(todo)(swd, [x[degree] for x in X]) for degree, swd in SWD_it
+             )
+         return np.asarray(out)
+
+     def predict(self, X):
+         return self.transform(X)
+
+
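Sketch of the distance-matrix route, e.g. to feed a precomputed-kernel classifier; again `signed_measures` is the assumed sparse input from above.

    swd = SignedMeasure2SlicedWassersteinDistance(num_directions=10, n_jobs=-1)
    distance_matrices = swd.fit_transform(signed_measures)  # (degree) x (n x n)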
+ class SignedMeasures2SlicedWassersteinDistances(BaseEstimator, TransformerMixin):
+     """
+     Transformer from signed measures to distance matrices.
+
+     Input
+     -----
+     (data) x opt (axis) x (degree) x (signed measure)
+
+     Format
+     ------
+     - a signed measure : tuple of arrays, (point positions) : npts x (num_parameters) and (weights) : npts
+     - each datum is a list of signed measures (e.g., one per degree)
+
+     Output
+     ------
+     - (axis) x (degree) x (distance matrix)
+     """
+
+     def __init__(
+         self,
+         progress=False,
+         n_jobs: int = 1,
+         scales: Iterable[Iterable[float]] | None = None,
+         **kwargs,
+     ):  # same init
+         self._init_child = SignedMeasure2SlicedWassersteinDistance(
+             progress=False, scales=None, n_jobs=-1, **kwargs
+         )
+         self._axe_iterator = None
+         self._childs_to_fit = None
+         self.scales = scales
+         self.progress = progress
+         self.n_jobs = n_jobs
+         return
+
+     def fit(self, X, y=None):
+         from sklearn.base import clone
+
+         if len(X) == 0:
+             return self
+         if isinstance(X[0][0], tuple):  # Meaning that there are no axes
+             self._axe_iterator = [slice(None)]
+         else:
+             self._axe_iterator = range(len(X[0]))
+         if self.scales is None:
+             self.scales = [None]
+         else:
+             self.scales = np.asarray(self.scales)
+             if self.scales.ndim == 1:
+                 self.scales = np.asarray([self.scales])
+         assert (
+             self.scales[0] is None or self.scales.ndim == 2
+         ), "Scales have to be either None or a list of scales !"
+         self._childs_to_fit = [
+             clone(self._init_child).set_params(
+                 scales=scales).fit([x[axis] for x in X])
+             for axis, scales in product(self._axe_iterator, self.scales)
+         ]
+         print("New axes : ", list(product(self._axe_iterator, self.scales)))
+         return self
+
+     def transform(self, X):
+         return Parallel(n_jobs=self.n_jobs, prefer="processes")(
+             delayed(self._childs_to_fit[child_id].transform)(
+                 [x[axis] for x in X])
+             for child_id, (axis, _) in tqdm(
+                 enumerate(product(self._axe_iterator, self.scales)),
+                 desc="Computing distance matrices over axes and scales",
+                 disable=not self.progress,
+                 total=len(self._childs_to_fit),
+             )
+         )
+         # [
+         #     child.transform([x[axis // len(self.scales)] for x in X])
+         #     for axis, child in tqdm(enumerate(self._childs_to_fit),
+         #         desc=f"Computing distances of axis", disable=not self.progress, total=len(self._childs_to_fit)
+         #     )
+         # ]
+
+
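The multi-axis wrapper follows the same pattern; a sketch assuming `measures_with_axes` carries an explicit axis dimension, with extra keyword arguments forwarded to the per-axis child transformer.

    swds = SignedMeasures2SlicedWassersteinDistances(num_directions=10, n_jobs=2)
    distances = swds.fit(measures_with_axes).transform(measures_with_axes)
    # one (degree) x (n x n) block per (axis, scale) pair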
+ class SimplexTree2RectangleDecomposition(BaseEstimator, TransformerMixin):
+     """
+     Transformer: maps 2-parameter SimplexTrees to their respective rectangle decompositions.
+     """
+
+     def __init__(
+         self,
+         filtration_grid: np.ndarray,
+         degrees: Iterable[int],
+         plot=False,
+         reconvert_grid=True,
+         num_collapses: int = 0,
+     ):
+         super().__init__()
+         self.filtration_grid = filtration_grid
+         self.degrees = degrees
+         self.plot = plot
+         self.reconvert_grid = reconvert_grid
+         self.num_collapses = num_collapses
+         return
+
+     def fit(self, X, y=None):
+         """
+         TODO : infer grid from multiple simplextrees
+         """
+         return self
+
+     def transform(self, X: Iterable[mp.SimplexTreeMulti]):
+         rectangle_decompositions = [
+             [
+                 _st2ranktensor(
+                     simplextree,
+                     filtration_grid=self.filtration_grid,
+                     degree=degree,
+                     plot=self.plot,
+                     reconvert_grid=self.reconvert_grid,
+                     num_collapse=self.num_collapses,
+                 )
+                 for degree in self.degrees
+             ]
+             for simplextree in X
+         ]
+         # TODO : return iterator ?
+         return rectangle_decompositions
+
+
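A sketch for the rectangle-decomposition transformer; `grid` is an assumed precomputed filtration grid (one array of filtration values per parameter) and `simplextrees` is as above.

    rect = SimplexTree2RectangleDecomposition(filtration_grid=grid, degrees=[0, 1])
    decompositions = rect.fit_transform(simplextrees)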
+ def _st2ranktensor(
+     st: mp.SimplexTreeMulti,
+     filtration_grid: np.ndarray,
+     degree: int,
+     plot: bool,
+     reconvert_grid: bool,
+     num_collapse: int | str = 0,
+ ):
+     """
+     TODO
+     """
+     # Copy (the squeeze changes the filtration values)
+     stcpy = mp.SimplexTreeMulti(st)
+     # turns the simplextree into a coordinate simplex tree
+     stcpy.grid_squeeze(filtration_grid=filtration_grid, coordinate_values=True)
+     # stcpy.collapse_edges(num=100, strong = True, ignore_warning=True)
+     if num_collapse == "full":
+         stcpy.collapse_edges(full=True, ignore_warning=True,
+                              max_dimension=degree + 1)
+     elif isinstance(num_collapse, int):
+         stcpy.collapse_edges(
+             num=num_collapse, ignore_warning=True, max_dimension=degree + 1
+         )
+     else:
+         raise TypeError(
+             f"Invalid num_collapse={num_collapse} type. Either 'full' or an integer."
+         )
+     # computes the rank invariant tensor
+     rank_tensor = mp.rank_invariant2d(
+         stcpy, degree=degree, grid_shape=[len(f) for f in filtration_grid]
+     )
+     # refactor this tensor into the rectangle decomposition of the signed betti
+     grid_conversion = filtration_grid if reconvert_grid else None
+     rank_decomposition = rank_decomposition_by_rectangles(
+         rank_tensor,
+         threshold=True,
+     )
+     rectangle_decomposition = tensor_möbius_inversion(
+         tensor=rank_decomposition,
+         grid_conversion=grid_conversion,
+         plot=plot,
+         num_parameters=st.num_parameters,
+     )
+     return rectangle_decomposition
+
+
+ class DegreeRips2SignedMeasure(BaseEstimator, TransformerMixin):
+     def __init__(
+         self,
+         degrees: Iterable[int],
+         min_rips_value: float,
+         max_rips_value,
+         max_normalized_degree: float,
+         min_normalized_degree: float,
+         grid_granularity: int,
+         progress: bool = False,
+         n_jobs=1,
+         sparse: bool = False,
+         _möbius_inversion=True,
+         fit_fraction=1,
+     ) -> None:
+         super().__init__()
+         self.min_rips_value = min_rips_value
+         self.max_rips_value = max_rips_value
+         self.min_normalized_degree = min_normalized_degree
+         self.max_normalized_degree = max_normalized_degree
+         self._max_rips_value = None
+         self.grid_granularity = grid_granularity
+         self.progress = progress
+         self.n_jobs = n_jobs
+         self.degrees = degrees
+         self.sparse = sparse
+         self._möbius_inversion = _möbius_inversion
+         self.fit_fraction = fit_fraction
+         return
+
+     def fit(self, X: np.ndarray | list, y=None):
+         if self.max_rips_value < 0:
+             print("Estimating scale...", flush=True, end="")
+             indices = np.random.choice(
+                 len(X), min(len(X), int(self.fit_fraction * len(X)) + 1), replace=False
+             )
+             diameters = np.max(
+                 [distance_matrix(x, x).max() for x in (X[i] for i in indices)]
+             )
+             print(f"Done. {diameters}", flush=True)
+         self._max_rips_value = (
+             -self.max_rips_value * diameters
+             if self.max_rips_value < 0
+             else self.max_rips_value
+         )
+         return self
+
+     def _transform1(self, data: np.ndarray):
+         _distance_matrix = distance_matrix(data, data)
+         signed_measures = []
+         (
+             rips_values,
+             normalized_degree_values,
+             hilbert_functions,
+             minimal_presentations,
+         ) = hf_degree_rips(
+             _distance_matrix,
+             min_rips_value=self.min_rips_value,
+             max_rips_value=self._max_rips_value,
+             min_normalized_degree=self.min_normalized_degree,
+             max_normalized_degree=self.max_normalized_degree,
+             grid_granularity=self.grid_granularity,
+             max_homological_dimension=np.max(self.degrees),
+         )
+         for degree in self.degrees:
+             hilbert_function = hilbert_functions[degree]
+             signed_measure = (
+                 signed_betti(hilbert_function, threshold=True)
+                 if self._möbius_inversion
+                 else hilbert_function
+             )
+             if self.sparse:
+                 signed_measure = tensor_möbius_inversion(
+                     tensor=signed_measure,
+                     num_parameters=2,
+                     grid_conversion=[rips_values, normalized_degree_values],
+                 )
+             if not self._möbius_inversion:
+                 signed_measure = signed_measure.flatten()
+             signed_measures.append(signed_measure)
+         return signed_measures
+
+     def transform(self, X):
+         return Parallel(n_jobs=self.n_jobs)(
+             delayed(self._transform1)(data)
+             for data in tqdm(X, desc=f"Computing DegreeRips, of degrees {self.degrees}")
+         )
+
+
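A sketch for the degree-Rips transformer on raw point clouds; the parameter values are illustrative, and a negative max_rips_value is interpreted in fit above as a fraction of the estimated diameter.

    dr = DegreeRips2SignedMeasure(
        degrees=[0, 1],
        min_rips_value=0.0,
        max_rips_value=-0.5,          # half of the estimated diameter
        min_normalized_degree=0.0,
        max_normalized_degree=0.3,
        grid_granularity=30,
        sparse=True,
    )
    point_clouds = [np.random.rand(100, 2) for _ in range(5)]
    dr_measures = dr.fit(point_clouds).transform(point_clouds)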
+ def tensor_möbius_inversion(
+     tensor,
+     grid_conversion: Iterable[np.ndarray] | None = None,
+     plot: bool = False,
+     raw: bool = False,
+     num_parameters: int | None = None,
+ ):
+     from torch import Tensor
+
+     betti_sparse = Tensor(tensor.copy()).to_sparse()  # Copy necessary in some cases :(
+     num_indices, num_pts = betti_sparse.indices().shape
+     num_parameters = num_indices if num_parameters is None else num_parameters
+     if num_indices == num_parameters:  # either hilbert or rank invariant
+         rank_invariant = False
+     elif 2 * num_parameters == num_indices:
+         rank_invariant = True
+     else:
+         raise TypeError(
+             f"Unsupported betti shape. {num_indices} "
+             f"has to be either {num_parameters} or {2*num_parameters}."
+         )
+     points_filtration = np.asarray(betti_sparse.indices().T, dtype=int)
+     weights = np.asarray(betti_sparse.values(), dtype=int)
+
+     if grid_conversion is not None:
+         coords = np.empty(shape=(num_pts, num_indices), dtype=float)
+         for i in range(num_indices):
+             coords[:, i] = grid_conversion[i % num_parameters][points_filtration[:, i]]
+     else:
+         coords = points_filtration
+     if (not rank_invariant) and plot:
+         plt.figure()
+         color_weights = np.empty(weights.shape)
+         color_weights[weights > 0] = np.log10(weights[weights > 0]) + 2
+         color_weights[weights < 0] = -np.log10(-weights[weights < 0]) - 2
+         plt.scatter(
+             points_filtration[:, 0],
+             points_filtration[:, 1],
+             c=color_weights,
+             cmap="coolwarm",
+         )
+     if (not rank_invariant) or raw:
+         return coords, weights
+
+     def _is_trivial(rectangle: np.ndarray):
+         birth = rectangle[:num_parameters]
+         death = rectangle[num_parameters:]
+         return np.all(birth <= death)  # and not np.array_equal(birth,death)
+
+     correct_indices = np.array([_is_trivial(rectangle) for rectangle in coords])
+     if len(correct_indices) == 0:
+         return np.empty((0, num_indices)), np.empty((0))
+     signed_measure = np.asarray(coords[correct_indices])
+     weights = weights[correct_indices]
+     if plot:
+         # plot only the rank decompo for the moment
+         assert signed_measure.shape[1] == 4
+
+         def _plot_rectangle(rectangle: np.ndarray, weight: float):
+             x_axis = rectangle[[0, 2]]
+             y_axis = rectangle[[1, 3]]
+             color = "blue" if weight > 0 else "red"
+             plt.plot(x_axis, y_axis, c=color)
+
+         for rectangle, weight in zip(signed_measure, weights):
+             _plot_rectangle(rectangle=rectangle, weight=weight)
+     return signed_measure, weights
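Finally, a small self-contained sketch of what tensor_möbius_inversion returns on a dense array of signed-Betti values (torch is assumed to be installed, since the function relies on it):

    import numpy as np

    dense = np.zeros((4, 4), dtype=int)
    dense[1, 1] = 1    # Dirac mass of weight +1 at grid index (1, 1)
    dense[2, 3] = -1   # Dirac mass of weight -1 at grid index (2, 3)
    points, weights = tensor_möbius_inversion(dense)
    # points  -> [[1, 1], [2, 3]]  (grid coordinates, or filtration values if grid_conversion is given)
    # weights -> [1, -1]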