multipers 2.0.0__cp311-cp311-macosx_13_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic. Click here for more details.

Files changed (78) hide show
  1. multipers/.dylibs/libc++.1.0.dylib +0 -0
  2. multipers/.dylibs/libtbb.12.12.dylib +0 -0
  3. multipers/.dylibs/libtbbmalloc.2.12.dylib +0 -0
  4. multipers/__init__.py +11 -0
  5. multipers/_signed_measure_meta.py +268 -0
  6. multipers/_slicer_meta.py +171 -0
  7. multipers/data/MOL2.py +350 -0
  8. multipers/data/UCR.py +18 -0
  9. multipers/data/__init__.py +1 -0
  10. multipers/data/graphs.py +466 -0
  11. multipers/data/immuno_regions.py +27 -0
  12. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  13. multipers/data/pytorch2simplextree.py +91 -0
  14. multipers/data/shape3d.py +101 -0
  15. multipers/data/synthetic.py +68 -0
  16. multipers/distances.py +198 -0
  17. multipers/euler_characteristic.pyx +132 -0
  18. multipers/filtration_conversions.pxd +229 -0
  19. multipers/filtrations.pxd +225 -0
  20. multipers/function_rips.cpython-311-darwin.so +0 -0
  21. multipers/function_rips.pyx +105 -0
  22. multipers/grids.cpython-311-darwin.so +0 -0
  23. multipers/grids.pyx +281 -0
  24. multipers/hilbert_function.pyi +46 -0
  25. multipers/hilbert_function.pyx +153 -0
  26. multipers/io.cpython-311-darwin.so +0 -0
  27. multipers/io.pyx +571 -0
  28. multipers/ml/__init__.py +0 -0
  29. multipers/ml/accuracies.py +90 -0
  30. multipers/ml/convolutions.py +532 -0
  31. multipers/ml/invariants_with_persistable.py +79 -0
  32. multipers/ml/kernels.py +176 -0
  33. multipers/ml/mma.py +659 -0
  34. multipers/ml/one.py +472 -0
  35. multipers/ml/point_clouds.py +238 -0
  36. multipers/ml/signed_betti.py +50 -0
  37. multipers/ml/signed_measures.py +1542 -0
  38. multipers/ml/sliced_wasserstein.py +461 -0
  39. multipers/ml/tools.py +113 -0
  40. multipers/mma_structures.cpython-311-darwin.so +0 -0
  41. multipers/mma_structures.pxd +127 -0
  42. multipers/mma_structures.pyx +2433 -0
  43. multipers/multiparameter_edge_collapse.py +41 -0
  44. multipers/multiparameter_module_approximation.cpython-311-darwin.so +0 -0
  45. multipers/multiparameter_module_approximation.pyx +211 -0
  46. multipers/pickle.py +53 -0
  47. multipers/plots.py +326 -0
  48. multipers/point_measure_integration.cpython-311-darwin.so +0 -0
  49. multipers/point_measure_integration.pyx +139 -0
  50. multipers/rank_invariant.cpython-311-darwin.so +0 -0
  51. multipers/rank_invariant.pyx +229 -0
  52. multipers/simplex_tree_multi.cpython-311-darwin.so +0 -0
  53. multipers/simplex_tree_multi.pxd +129 -0
  54. multipers/simplex_tree_multi.pyi +715 -0
  55. multipers/simplex_tree_multi.pyx +4655 -0
  56. multipers/slicer.cpython-311-darwin.so +0 -0
  57. multipers/slicer.pxd +781 -0
  58. multipers/slicer.pyx +3393 -0
  59. multipers/tensor.pxd +13 -0
  60. multipers/test.pyx +44 -0
  61. multipers/tests/__init__.py +40 -0
  62. multipers/tests/old_test_rank_invariant.py +91 -0
  63. multipers/tests/test_diff_helper.py +74 -0
  64. multipers/tests/test_hilbert_function.py +82 -0
  65. multipers/tests/test_mma.py +51 -0
  66. multipers/tests/test_point_clouds.py +59 -0
  67. multipers/tests/test_python-cpp_conversion.py +82 -0
  68. multipers/tests/test_signed_betti.py +181 -0
  69. multipers/tests/test_simplextreemulti.py +98 -0
  70. multipers/tests/test_slicer.py +63 -0
  71. multipers/torch/__init__.py +1 -0
  72. multipers/torch/diff_grids.py +217 -0
  73. multipers/torch/rips_density.py +257 -0
  74. multipers-2.0.0.dist-info/LICENSE +21 -0
  75. multipers-2.0.0.dist-info/METADATA +29 -0
  76. multipers-2.0.0.dist-info/RECORD +78 -0
  77. multipers-2.0.0.dist-info/WHEEL +5 -0
  78. multipers-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1542 @@
1
+ from itertools import product
2
+ from typing import Callable, Iterable, Optional
3
+
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ from joblib import Parallel, delayed
7
+ from sklearn.base import BaseEstimator, TransformerMixin
8
+ from tqdm import tqdm
9
+
10
+ import multipers as mp
11
+ from multipers.grids import compute_grid as reduce_grid
12
+ from multipers.ml.convolutions import convolution_signed_measures
13
+
14
+
15
+ class SimplexTree2SignedMeasure(BaseEstimator, TransformerMixin):
16
+ """
17
+ Input
18
+ -----
19
+ Iterable[SimplexTreeMulti]
20
+
21
+ Output
22
+ ------
23
+ Iterable[ list[signed_measure for degree] ]
24
+
25
+ signed measure is either
26
+ - (points : (n x num_parameters) array, weights : (n) int array ) if sparse,
27
+ - else an integer matrix.
28
+
29
+ Parameters
30
+ ----------
31
+ - degrees : list of degrees to compute. None correspond to the euler characteristic
32
+ - filtration grid : the grid on which to compute.
33
+ If None, the fit will infer it from
34
+ - fit_fraction : the fraction of data to consider for the fit, seed is controlled by the seed parameter
35
+ - resolution : the resolution of this grid
36
+ - filtration_quantile : filtrations values quantile to ignore
37
+ - grid_strategy:str : 'regular' or 'quantile' or 'exact'
38
+ - normalize filtration : if sparse, will normalize all filtrations.
39
+ - expand : expands the simplextree to compute correctly the degree, for
40
+ flag complexes
41
+ - invariant : the topological invariant to produce the signed measure.
42
+ Choices are "hilbert" or "euler". Will add rank invariant later.
43
+ - num_collapse : Either an int or "full". Collapse the complex before
44
+ doing computation.
45
+ - _möbius_inversion : if False, will not do the mobius inversion. output
46
+ has to be a matrix then.
47
+ - enforce_null_mass : Returns a zero mass measure, by thresholding the
48
+ module if True.
49
+ """
50
+
51
def __init__(
    self,
    # homological degrees + None for euler
    degrees: list[int | None] = [],
    rank_degrees: list[int] = [],  # same for rank invariant
    # filtration values to consider. Format : [ filtration values of Fi for Fi:filtration values of parameter i]
    filtration_grid: Iterable[np.ndarray] | None = None,
    progress=False,  # tqdm
    num_collapses: int | str = 0,  # edge collapses before computing
    n_jobs=None,
    # when filtration grid is not given, the resolution of the filtration grid to infer
    resolution: Iterable[int] | int | None = None,
    plot: bool = False,
    filtration_quantile: float = 0.0,  # quantile for inferring filtration grid
    # whether or not to do the möbius inversion (not recommended to touch)
    _möbius_inversion: bool = True,
    expand=True,  # expand the simplextree before computing the homology
    normalize_filtrations: bool = False,
    grid_strategy: str = "exact",
    seed: int = 0,  # if fit_fraction is not 1, the seed sampling
    fit_fraction=1,  # the fraction of the data on which to fit
    out_resolution: Iterable[int] | int | None = None,
    # Can be significantly faster for some grid strategies, but can drop statistical performance
    individual_grid: Optional[bool] = None,
    enforce_null_mass: bool = False,
    flatten=True,
    backend="multipers",
):
    """Record the hyper-parameters; no computation happens before ``fit``.

    Every argument is stored unchanged on the instance under the same name.
    The remaining private attributes are placeholders that ``fit`` fills in.
    """
    super().__init__()

    # -- which invariants to compute --------------------------------------
    self.degrees = degrees
    self.rank_degrees = rank_degrees
    self._möbius_inversion = _möbius_inversion
    self.backend = backend

    # -- filtration grid controls -----------------------------------------
    self.filtration_grid = filtration_grid
    self.resolution = resolution
    self.out_resolution = out_resolution
    self.filtration_quantile = filtration_quantile
    self.grid_strategy = grid_strategy
    self.individual_grid = individual_grid
    # Will only work for non sparse output. (discrete matrices cannot be "rescaled")
    self.normalize_filtrations = normalize_filtrations
    self.enforce_null_mass = enforce_null_mass

    # -- preprocessing of the simplex trees -------------------------------
    self.num_collapses = num_collapses
    self.expand = expand

    # -- fit sampling -----------------------------------------------------
    self.seed = seed
    self.fit_fraction = fit_fraction

    # -- execution / output ----------------------------------------------
    self.n_jobs = n_jobs
    self.progress = progress
    self.plot = plot
    self.flatten = flatten

    # -- state populated by ``fit`` ---------------------------------------
    self.num_parameter = None
    self.num_parameters: int = 0
    self._is_input_delayed = None
    self._reconversion_grid = None
    # will only refit the grid if filtration_grid has never been given.
    self._refit_grid = None
    self._transform_st = None
    self._to_simplex_tree: Callable  # declared only; assigned in ``fit``
    self._default_mass_location = None
117
+
118
def _infer_filtration(self, X):
    """Infer a common filtration grid from (a sample of) the inputs.

    A subset of ``X`` is drawn without replacement; its size is
    ``int(fit_fraction * len(X)) + 1`` capped at ``len(X)``. The draw is
    seeded with ``self.seed`` so that repeated fits pick the same subset.
    For each sampled element the "exact" filtration grid of its simplex tree
    is collected, the per-parameter values are merged and de-duplicated, and
    the result is reduced with ``reduce_grid`` using ``self.resolution`` and
    ``self.grid_strategy``.

    The inferred grid is stored in ``self.filtration_grid`` and returned.
    """
    num_samples = min(int(self.fit_fraction * len(X)) + 1, len(X))
    # Use a local generator driven by the ``seed`` hyper-parameter instead of
    # the global numpy random state, so the sampling is reproducible.
    rng = np.random.default_rng(self.seed)
    indices = rng.choice(len(X), num_samples, replace=False)

    def get_st_filtration(x) -> np.ndarray:
        return self._to_simplex_tree(x).get_filtration_grid(grid_strategy="exact")

    filtrations: list = Parallel(n_jobs=self.n_jobs, backend="threading")(
        delayed(get_st_filtration)(X[idx]) for idx in indices
    )
    num_parameters = len(filtrations[0])
    # One sorted array of distinct filtration values per parameter.
    filtrations_values = [
        np.unique(np.concatenate([x[i] for x in filtrations]))
        for i in range(num_parameters)
    ]
    filtration_grid = reduce_grid(
        filtrations_values, resolution=self.resolution, strategy=self.grid_strategy
    )  # TODO :use more parameters
    self.filtration_grid = filtration_grid
    return filtration_grid
139
+
140
def fit(self, X, y=None):  # Todo : infer filtration grid ? quantiles ?
    """Prepare the transformer from the training inputs ``X``.

    ``X`` must be indexable; ``X[0]`` decides whether the inputs are
    simplex trees or "delayed" objects that are converted with
    ``multipers.ml.tools.get_simplex_tree_from_delayed``.

    The method
    - records the number of filtration parameters of ``X[0]``,
    - broadcasts a scalar ``resolution`` / ``out_resolution`` to one value
      per parameter,
    - resolves ``individual_grid`` when it was left to ``None``,
    - infers a shared filtration grid when one is needed and none was given,
    - sets ``_reconversion_grid`` and ``_default_mass_location``.

    Raises ``AssertionError`` when no resolution is available for a
    non-exact strategy, or when ``individual_grid`` is combined with a
    disabled Möbius inversion. Returns ``self``.
    """
    assert (
        self.resolution is not None
        or self.filtration_grid is not None
        or self.grid_strategy == "exact"
        or self.individual_grid
    ), "For non exact filtrations, a resolution has to be specified."
    assert (
        self._möbius_inversion or not self.individual_grid
    ), "The grid has to be aligned when not using mobius inversion; disable individual_grid or enable mobius inversion."

    self._is_input_delayed = not mp.simplex_tree_multi.is_simplextree_multi(X[0])
    if self._is_input_delayed:
        from multipers.ml.tools import get_simplex_tree_from_delayed

        self._to_simplex_tree = get_simplex_tree_from_delayed
    else:
        self._to_simplex_tree = lambda x: x

    # Convert the first input once; both spellings of the attribute are kept
    # in sync. Previously only ``num_parameter`` was set, so expanding an
    # integer ``out_resolution`` below multiplied by 0 and gave an empty list.
    num_parameters = self._to_simplex_tree(X[0]).num_parameters
    self.num_parameter = num_parameters
    self.num_parameters = num_parameters

    if isinstance(self.resolution, int) or self.resolution == np.inf:
        self.resolution = [self.resolution] * num_parameters

    if self.individual_grid is None:
        self.individual_grid = self.grid_strategy in [
            "regular_closest",
            "exact",
            "quantile",
            "partition",
        ]

    # A shared grid is only inferred when none was supplied and per-input
    # grids cannot be used (or a null mass has to be enforced).
    self._refit_grid = not (
        (not self.enforce_null_mass and self.individual_grid)
        or self.filtration_grid is not None
    )
    if self._refit_grid:
        self._infer_filtration(X=X)

    if self.out_resolution is None:
        self.out_resolution = self.resolution
    elif isinstance(self.out_resolution, int):
        self.out_resolution = [self.out_resolution] * self.num_parameters

    if self.normalize_filtrations and not self.individual_grid:
        # not the best, but better than some weird magic
        self._reconversion_grid = [f / np.std(f) for f in self.filtration_grid]
    else:
        self._reconversion_grid = self.filtration_grid

    # The default mass is placed at the last value of each grid axis.
    self._default_mass_location = (
        [g[-1] for g in self._reconversion_grid] if self.enforce_null_mass else None
    )
    return self
199
+
200
def transform1(
    self,
    simplextree,
    filtration_grid=None,
    _reconversion_grid=None,
    thread_id: str = "",
):
    """Compute the signed measures of a single input.

    Parameters
    ----------
    simplextree : one input, converted with ``self._to_simplex_tree``.
    filtration_grid : grid to squeeze the filtration on; defaults to
        ``self.filtration_grid``. Replaced by a per-input grid when
        ``self.individual_grid`` is set.
    _reconversion_grid : defaults to ``self._reconversion_grid``. It is
        resolved here but not used further in this method.
    thread_id : forwarded to ``mp.signed_measure``.

    Returns
    -------
    A list with one entry per element of ``self.degrees`` (the Euler measure
    for ``None``, the Hilbert measure otherwise, in the order of
    ``self.degrees``), followed by the rank measures for
    ``self.rank_degrees``.

    Raises
    ------
    Exception : for 2-parameter inputs when ``num_collapses`` is neither
        ``"full"`` nor an int.
    ValueError : when ``_möbius_inversion`` is False (deprecated path).
    """
    if filtration_grid is None:
        filtration_grid = self.filtration_grid
    if _reconversion_grid is None:
        _reconversion_grid = self._reconversion_grid
    st = self._to_simplex_tree(simplextree)
    # st = mp.SimplexTreeMulti(st, num_parameters=st.num_parameters) # COPY
    if self.individual_grid:
        # Per-input grid; overrides whatever grid was passed in.
        filtration_grid = st.get_filtration_grid(
            grid_strategy=self.grid_strategy, resolution=self.resolution
        )
        if self.enforce_null_mass:
            # Append the default mass location to every axis of the grid.
            filtration_grid = [
                np.concatenate([f, [d]], axis=0)
                for f, d in zip(filtration_grid, self._default_mass_location)
            ]
    st = st.grid_squeeze(filtration_grid=filtration_grid, coordinate_values=True)
    if st.num_parameters == 2:
        # Edge collapses are only attempted for 2-parameter filtrations.
        if self.num_collapses == "full":
            st.collapse_edges(full=True, max_dimension=1)
        elif isinstance(self.num_collapses, int):
            st.collapse_edges(num=self.num_collapses, max_dimension=1)
        else:
            raise Exception("Bad edge collapse type. either 'full' or an int.")
    # Homological degrees only; ``None`` (Euler characteristic) is handled apart.
    int_degrees = np.asarray([d for d in self.degrees if d is not None])
    if self._möbius_inversion:
        # EULER. First as there is prune above dimension below
        if self.expand and None in self.degrees:
            st.expansion(st.num_vertices)
        signed_measures_euler = (
            mp.signed_measure(
                simplextree=st,
                degrees=[None],
                plot=self.plot,
                mass_default=self._default_mass_location,
                invariant="euler",
                thread_id=thread_id,
            )[0]
            if None in self.degrees
            else []
        )

        # HILBERT: expand up to the largest requested degree, then prune.
        if self.expand and len(int_degrees) > 0:
            st.expansion(np.max(int_degrees) + 1)
        if len(int_degrees) > 0:
            st.prune_above_dimension(
                np.max(np.concatenate([int_degrees, self.rank_degrees])) + 1
            )  # no need to compute homology beyond this
        signed_measures_pers = (
            mp.signed_measure(
                simplextree=st,
                degrees=int_degrees,
                mass_default=self._default_mass_location,
                plot=self.plot,
                invariant="hilbert",
                thread_id=thread_id,
                backend=self.backend,
            )
            if len(int_degrees) > 0
            else []
        )
        if self.plot:
            plt.show()
        # RANK: same expand / prune pattern for the rank degrees.
        if self.expand and len(self.rank_degrees) > 0:
            st.expansion(np.max(self.rank_degrees) + 1)
        if len(self.rank_degrees) > 0:
            st.prune_above_dimension(
                np.max(self.rank_degrees) + 1
            )  # no need to compute homology beyond this
        signed_measures_rank = (
            mp.signed_measure(
                simplextree=st,
                degrees=self.rank_degrees,
                mass_default=self._default_mass_location,
                plot=self.plot,
                invariant="rank",
                thread_id=thread_id,
            )
            if len(self.rank_degrees) > 0
            else []
        )
        if self.plot:
            plt.show()

    else:
        # The non-inverted path (euler_surface / hilbert_surface /
        # rank_invariant) used to live here as commented-out code; it is
        # deprecated and now always raises.
        raise ValueError("This is deprecated")
    # Re-interleave Euler and Hilbert measures in the order of ``self.degrees``.
    count = 0
    signed_measures = []
    for d in self.degrees:
        if d is None:
            signed_measures.append(signed_measures_euler)
        else:
            signed_measures.append(signed_measures_pers[count])
            count += 1
    signed_measures += signed_measures_rank
    # NOTE(review): cannot be reached with ``_möbius_inversion`` False, since
    # that case raises above.
    if not self._möbius_inversion and self.flatten:
        signed_measures = np.asarray(signed_measures).flatten()
    return signed_measures
354
+
355
def transform(self, X):
    """Apply ``transform1`` to every element of ``X`` in parallel.

    Delayed inputs are dispatched with the "loky" joblib backend, simplex
    trees with "threading". Each job receives its position in ``X`` (as a
    string) as ``thread_id``. Returns the list produced by joblib.
    """
    assert self.filtration_grid is not None or self.individual_grid, "Fit first"
    backend_name = "threading"
    if self._is_input_delayed:
        backend_name = "loky"
    indexed_inputs = tqdm(
        enumerate(X),
        disable=not self.progress,
        desc="Computing signed measure decompositions",
    )
    jobs = (
        delayed(self.transform1)(item, thread_id=str(position))
        for position, item in indexed_inputs
    )
    return Parallel(n_jobs=self.n_jobs, backend=backend_name)(jobs)
367
+
368
+
369
class SimplexTrees2SignedMeasures(SimplexTree2SignedMeasure):
    """
    Input
    -----

    (data) x (axis, e.g. different bandwidths for simplextrees) x (simplextree)

    Output
    ------
    (data) x (axis) x (degree) x (signed measure)
    """

    def __init__(self, **kwargs):
        """Forward every keyword argument to ``SimplexTree2SignedMeasure``."""
        super().__init__(**kwargs)
        self._num_st_per_data = None  # number of axes, set in ``fit``
        self._filtration_grids = None  # one grid (or None) per axis, set in ``fit``

    def fit(self, X, y=None):
        """Fit the parent transformer once per axis and keep each axis' grid.

        Returns ``self`` unchanged when ``X`` is empty. Raises ``Exception``
        when the first element of ``X`` has no length.
        """
        if len(X) == 0:
            return self
        try:
            self._num_st_per_data = len(X[0])
        except Exception as err:
            raise Exception(
                "Shape has to be (num_data, num_axis), dtype=SimplexTreeMulti"
            ) from err

        # The parent ``fit`` skips grid inference as soon as
        # ``self.filtration_grid`` is not None. Restore the grid the user
        # supplied (possibly None) before each axis, otherwise the grid
        # inferred for axis 0 would silently be reused for every other axis.
        user_filtration_grid = self.filtration_grid
        self._filtration_grids = []
        for axis in range(self._num_st_per_data):
            self.filtration_grid = user_filtration_grid
            super().fit([x[axis] for x in X])
            self._filtration_grids.append(self.filtration_grid)
        return self

    def transform(self, X):
        """Compute, for every datum, the signed measures of each axis.

        Returns a list with one entry per datum; each entry is the list of
        ``transform1`` outputs over the axes.
        """
        if self.normalize_filtrations:
            # A per-axis grid can be None (individual grids); keep None then.
            _reconversion_grids = [
                None
                if F is None
                else [np.linspace(0, 1, num=len(f), dtype=float) for f in F]
                for F in self._filtration_grids
            ]
        else:
            _reconversion_grids = self._filtration_grids

        def todo(x):
            return [
                self.transform1(
                    x[axis],
                    filtration_grid=filtration_grid,
                    _reconversion_grid=_reconversion_grid,
                )
                for axis, filtration_grid, _reconversion_grid in zip(
                    range(self._num_st_per_data),
                    self._filtration_grids,
                    _reconversion_grids,
                )
            ]

        return Parallel(n_jobs=self.n_jobs, backend="threading")(
            delayed(todo)(x)
            for x in tqdm(
                X,
                disable=not self.progress,
                desc="Computing Signed Measures from simplextrees.",
            )
        )
439
+
440
+
441
def rescale_sparse_signed_measure(
    signed_measure, filtration_weights, normalize_scales=None
):
    """Rescale the point coordinates of a sparse signed measure.

    ``signed_measure`` is indexed by degree; each entry holds the points at
    index 0 and the weights at index 1. For every degree the points are
    multiplied by ``filtration_weights`` (reshaped to a row) and divided by
    ``normalize_scales[degree]`` (reshaped to a row); a missing factor is
    replaced by 1. The weights are passed through untouched.

    Returns the input object itself when both factors are ``None``,
    otherwise a tuple of ``(rescaled_points, weights)`` pairs.
    """
    nothing_to_do = filtration_weights is None and normalize_scales is None
    if nothing_to_do:
        return signed_measure

    def _rescaled_points(degree):
        points = signed_measure[degree][0]
        if filtration_weights is None:
            weighted = points * 1
        else:
            weighted = points * filtration_weights.reshape(1, -1)
        if normalize_scales is None:
            return weighted / 1
        return weighted / normalize_scales[degree].reshape(1, -1)

    return tuple(
        (_rescaled_points(degree), signed_measure[degree][1])
        for degree in range(len(signed_measure))
    )
505
+
506
+
507
+ class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
508
+ """
509
+ Input
510
+ -----
511
+
512
+ (data) x (degree) x (signed measure) or (data) x (axis) x (degree) x (signed measure)
513
+
514
+ Iterable[list[signed_measure_matrix of degree]] or Iterable[previous].
515
+
516
+ The second is meant to use multiple choices for signed measure input. An example of usage : they come from a Rips + Density with different bandwidth.
517
+ It is controlled by the axis parameter.
518
+
519
+ Output
520
+ ------
521
+
522
+ Iterable[list[(reweighted)_sparse_signed_measure of degree]]
523
+
524
+ or (deep format)
525
+
526
+ Tensor of shape (num_axis*num_degrees, data, max_num_pts, num_parameters)
527
+ """
528
+
529
def __init__(
    self,
    filtrations_weights: Optional[Iterable[float]] = None,
    normalize=False,
    plot: bool = False,
    unsparse: bool = False,
    axis: int = -1,
    resolution: int | Iterable[int] = 50,
    flatten: bool = False,
    deep_format: bool = False,
    unrag: bool = True,
    n_jobs: int = 1,
    verbose: bool = False,
    integrate: bool = False,
    grid_strategy="regular",
):
    """Store the formatting options; all inference happens in ``fit``.

    Every argument is stored unchanged under the same attribute name.
    Raises ``AssertionError`` when ``deep_format``, ``unsparse`` and
    ``integrate`` are all enabled together.
    """
    super().__init__()
    self.filtrations_weights = filtrations_weights
    self.normalize = normalize
    self.plot = plot
    self.unsparse = unsparse
    self.axis = axis
    self.resolution = resolution
    self.flatten = flatten
    self.deep_format = deep_format
    self.unrag = unrag
    self.n_jobs = n_jobs
    self.verbose = verbose
    # Assigned before the check below: the assert used to read
    # ``self.integrate`` before it existed, raising AttributeError whenever
    # ``deep_format`` and ``unsparse`` were both True.
    self.integrate = integrate
    self.grid_strategy = grid_strategy
    assert (
        not self.deep_format or not self.unsparse or not self.integrate
    ), "One post processing at the time."

    # State inferred by ``fit``.
    self.num_parameters: int = 0
    self._num_axis = 0
    self._num_degrees = 0
    # Initialised so that ``transform`` can read it even when ``fit``
    # returned early on empty input.
    self._has_axis = False
    self._filtrations_bounds = None
    self._normalization_factors = None
    self._infered_grids = None
    self._axis_iterator = None
    self._backend = None
571
+
572
+ def _get_filtration_bounds(self, X, axis):
573
+ if self._backend == "numpy":
574
+ _cat = np.concatenate
575
+
576
+ else:
577
+ ## torch is globally imported
578
+ _cat = torch.cat
579
+ stuff = [
580
+ _cat(
581
+ [sm[axis][degree][0] for sm in X],
582
+ axis=0,
583
+ )
584
+ for degree in range(self._num_degrees)
585
+ ]
586
+ sizes_ = np.array([len(x) == 0 for x in stuff])
587
+ assert np.all(1 - sizes_), f"Degree axis {np.where(sizes_)} is/are trivial !"
588
+ if self._backend == "numpy":
589
+ filtrations_bounds = np.array(
590
+ [([f.min(axis=0), f.max(axis=0)]) for f in stuff]
591
+ )
592
+ else:
593
+ filtrations_bounds = torch.stack(
594
+ [
595
+ torch.stack([f.min(axis=0).values, f.max(axis=0).values])
596
+ for f in stuff
597
+ ]
598
+ ).detach() ## don't want to rescale gradient of normalization
599
+ normalization_factors = (
600
+ filtrations_bounds[:, 1] - filtrations_bounds[:, 0]
601
+ if self.normalize
602
+ else None
603
+ )
604
+ # print("Normalization factors : ",self._normalization_factors)
605
+ if (normalization_factors == 0).any():
606
+ indices = normalization_factors == 0
607
+ # warn(f"Constant filtration encountered, at degree, parameter {indices} and axis {self.axis}.")
608
+ normalization_factors[indices] = 1
609
+ return filtrations_bounds, normalization_factors
610
+
611
def _plot_signed_measures(self, sms: Iterable[np.ndarray], size=4):
    """Draw every signed measure of ``sms`` on a (datum x degree) grid.

    One row of subplots is created per element of ``sms`` and one column per
    degree of the first element; each subplot is ``size`` x ``size``.
    """
    from multipers.plots import plot_signed_measure

    n_rows, n_cols = len(sms), len(sms[0])
    _, raw_axes = plt.subplots(
        ncols=n_cols,
        nrows=n_rows,
        figsize=(size * n_cols, size * n_rows),
    )
    # ``subplots`` squeezes singleton dimensions; force a 2-D grid.
    axes_grid = np.asarray(raw_axes).reshape(n_rows, n_cols)
    for row, measures in enumerate(sms):
        for col, measure in enumerate(measures):
            plot_signed_measure(measure, ax=axes_grid[row, col])
626
+
627
+ @staticmethod
628
+ def _check_sm(sm) -> bool:
629
+ return (
630
+ isinstance(sm, tuple)
631
+ and hasattr(sm[0], "ndim")
632
+ and sm[0].ndim == 2
633
+ and len(sm) == 2
634
+ )
635
+
636
+ def _check_axis(self, X):
637
+ # axes should be (num_data, num_axis, num_degrees, (signed_measure))
638
+ if len(X) == 0:
639
+ return
640
+ if len(X[0]) == 0:
641
+ return
642
+ if self._check_sm(X[0][0]):
643
+ self._has_axis = False
644
+ self._num_axis = 1
645
+ self._axis_iterator = [slice(None)]
646
+ return
647
+ assert ( ## vaguely checks that its a signed measure
648
+ self._check_sm(_sm := X[0][0][0])
649
+ ), f"Cannot take this input. # data, axis, degrees, sm.\n Got {_sm} of type {type(_sm)}"
650
+
651
+ self._has_axis = True
652
+ self._num_axis = len(X[0])
653
+ self._axis_iterator = range(self._num_axis) if self.axis == -1 else [self.axis]
654
+
655
+ def _check_backend(self, X):
656
+ if self._has_axis:
657
+ # data, axis, degrees, (pts, weights)
658
+ first_sm = X[0][0][0][0]
659
+ else:
660
+ first_sm = X[0][0][0]
661
+ if isinstance(first_sm, np.ndarray):
662
+ self._backend = "numpy"
663
+ else:
664
+ global torch
665
+ import torch
666
+
667
+ assert isinstance(first_sm, torch.Tensor)
668
+ self._backend = "pytorch"
669
+
670
+ def _check_measures(self, X):
671
+ if self._has_axis:
672
+ first_sm = X[0][0]
673
+ else:
674
+ first_sm = X[0]
675
+ self._num_degrees = len(first_sm)
676
+ self.num_parameters = first_sm[0][0].shape[1]
677
+
678
+ def _check_resolution(self):
679
+ assert self.num_parameters > 0, "Num parameters hasn't been initialized."
680
+ if isinstance(self.resolution, int):
681
+ self.resolution = [self.resolution] * self.num_parameters
682
+ self.resolution = np.asarray(self.resolution, dtype=int)
683
+ assert (
684
+ self.resolution.shape[0] == self.num_parameters
685
+ ), "Resolution doesn't have a proper size."
686
+
687
+ def _check_weights(self):
688
+ if self.filtrations_weights is None:
689
+ return
690
+ assert (
691
+ self.filtrations_weights.shape[0] == self.num_parameters
692
+ ), "Filtration weights don't have a proper size"
693
+
694
+ def _infer_grids(self, X):
695
+ # Computes normalization factors
696
+ if self.normalize:
697
+ # if self._has_axis and self.axis == -1:
698
+ self._filtrations_bounds = []
699
+ self._normalization_factors = []
700
+ for ax in self._axis_iterator:
701
+ (
702
+ filtration_bounds,
703
+ normalization_factors,
704
+ ) = self._get_filtration_bounds(X, axis=ax)
705
+ self._filtrations_bounds.append(filtration_bounds)
706
+ self._normalization_factors.append(normalization_factors)
707
+ # else:
708
+ # (
709
+ # self._filtrations_bounds,
710
+ # self._normalization_factors,
711
+ # ) = self._get_filtration_bounds(
712
+ # X, axis=self._axis_iterator[0]
713
+ # ) ## axis = slice(None)
714
+ elif self.integrate or self.unsparse or self.deep_format:
715
+ filtration_values = [
716
+ np.concatenate(
717
+ [
718
+ stuff
719
+ if isinstance(stuff := x[ax][degree][0], np.ndarray)
720
+ else stuff.detach().numpy()
721
+ for x in X
722
+ for degree in range(self._num_degrees)
723
+ ]
724
+ )
725
+ for ax in self._axis_iterator
726
+ ]
727
+ # axis, filtration_values
728
+ filtration_values = [
729
+ reduce_grid(
730
+ f_ax.T, resolution=self.resolution, strategy=self.grid_strategy
731
+ )
732
+ for f_ax in filtration_values
733
+ ]
734
+ self._infered_grids = filtration_values
735
+
736
def _print_stats(self, X):
    """Print a summary of the fitted state and of the measure sizes in ``X``.

    Reports the number of axes and degrees, the stored filtration bounds and
    normalisation factors, the shape of the inferred grids (when present),
    and the mean / standard deviation over the data of the number of weights
    of each signed measure, per axis and degree.
    """
    print("------------SignedMeasureFormatter------------")
    print("---- Parameters")
    print(f"Number of axis : {self._num_axis}")
    print(f"Number of degrees : {self._num_degrees}")
    print(f"Filtration bounds : \n{self._filtrations_bounds}")
    print(f"Normalization factor : \n{self._normalization_factors}")
    if self._infered_grids is not None:
        # NOTE: the backslash continues the string literal, so the leading
        # whitespace of the next source line is part of the printed text.
        print(
            f"Filtration grid shape : \n \
                {tuple(tuple(len(f) for f in F) for F in self._infered_grids)}"
        )
    print("---- SM stats")
    print("In axis :", self._num_axis)
    # sizes[axis][datum][degree] = number of weights of that signed measure.
    sizes = [
        [[len(xd[1]) for xd in x[ax]] for x in X] for ax in self._axis_iterator
    ]
    # axis=1 reduces over the data.
    print(f"Size means (axis) x (degree): {np.mean(sizes, axis=(1))}")
    print(f"Size std : {np.std(sizes, axis=(1))}")
    print("----------------------------------------------")
756
+
757
def fit(self, X, y=None):
    """Inspect ``X`` and infer everything ``transform`` needs.

    Returns ``self`` immediately when ``X`` or its first element is empty,
    or (with ``axis`` not None) when ``X[0][0][0]`` is empty. Otherwise runs
    the layout, backend, measure, resolution and weight checks in that
    order, infers the grids, optionally prints statistics, and returns
    ``self``.
    """
    # Gets a grid. This will be the max in each coord+1
    if len(X) == 0 or len(X[0]) == 0:
        return self
    if self.axis is not None and len(X[0][0][0]) == 0:
        return self

    # Order matters: each check relies on what the previous ones recorded.
    self._check_axis(X)
    self._check_backend(X)
    self._check_measures(X)
    self._check_resolution()
    self._check_weights()

    self._infer_grids(X)
    if self.verbose:
        self._print_stats(X)
    return self
777
+
778
+ def unsparse_signed_measure(self, sparse_signed_measure):
779
+ filtrations = self._infered_grids # ax, filtration
780
+ out = []
781
+ for filtrations_of_ax, ax in zip(filtrations, self._axis_iterator, strict=True):
782
+ sparse_signed_measure_of_ax = sparse_signed_measure[ax]
783
+ measure_of_ax = []
784
+ for pts, weights in sparse_signed_measure_of_ax: # over degree
785
+ signed_measure, _ = np.histogramdd(
786
+ pts, bins=filtrations_of_ax, weights=weights
787
+ )
788
+ if self.flatten:
789
+ signed_measure = signed_measure.flatten()
790
+ measure_of_ax.append(signed_measure)
791
+ out.append(np.asarray(measure_of_ax))
792
+
793
+ if self.flatten:
794
+ out = np.concatenate(out).flatten()
795
+ if self.axis == -1:
796
+ return np.asarray(out)
797
+ else:
798
+ return np.asarray(out)[0]
799
+
800
+ @staticmethod
801
+ def deep_format_measure(signed_measure):
802
+ dirac_positions, dirac_signs = signed_measure
803
+ dtype = dirac_positions.dtype
804
+ new_shape = list(dirac_positions.shape)
805
+ new_shape[1] += 1
806
+ if isinstance(dirac_positions, np.ndarray):
807
+ c = np.empty(new_shape, dtype=dtype)
808
+ c[:, :-1] = dirac_positions
809
+ c[:, -1] = dirac_signs
810
+
811
+ else:
812
+ import torch
813
+
814
+ c = torch.empty(new_shape, dtype=dtype)
815
+ c[:, :-1] = dirac_positions
816
+ c[:, -1] = dirac_signs
817
+ return c
818
+
819
@staticmethod
def _integrate_measure(sm, filtrations):
    """
    Forward a signed measure to ``integrate_measure``.

    ``sm[0]`` and ``sm[1]`` (presumably the point positions and their
    weights -- confirm against callers) are passed together with
    ``filtrations``; the helper's result is returned unchanged.
    """
    # Imported lazily, at call time.
    from multipers.point_measure_integration import integrate_measure

    first, second = sm[0], sm[1]
    return integrate_measure(first, second, filtrations)
824
+
825
+ def _rescale_measures(self, X):
826
+ def rescale_from_sparse(sparse_signed_measure):
827
+ if self.axis == -1 and self._has_axis:
828
+ return tuple(
829
+ rescale_sparse_signed_measure(
830
+ sparse_signed_measure[ax],
831
+ filtration_weights=self.filtrations_weights,
832
+ normalize_scales=n,
833
+ )
834
+ for ax, n in zip(
835
+ self._axis_iterator, self._normalization_factors, strict=True
836
+ )
837
+ )
838
+ return rescale_sparse_signed_measure( ## axis iterator is of size 1 here
839
+ sparse_signed_measure,
840
+ filtration_weights=self.filtrations_weights,
841
+ normalize_scales=self._normalization_factors[0],
842
+ )
843
+
844
+ out = tuple(rescale_from_sparse(x) for x in X)
845
+ return out
846
+
847
def transform(self, X):
    """
    Format the signed measures of ``X`` according to the estimator's flags.

    Steps, in order:
    1. keep ``X`` as is when there is no axis level or ``axis == -1``,
       otherwise keep only ``x[self.axis]`` of every element;
    2. rescale if normalization factors are set;
    3. plot if ``plot`` is set;
    4. apply exactly one output format, by priority: ``integrate``,
       ``unsparse``, ``deep_format`` (optionally padded into one tensor when
       ``unrag`` is set), or none.
    """
    untouched = not self._has_axis or self.axis == -1
    measures = X if untouched else tuple(x[self.axis] for x in X)
    # same format for everyone

    if self._normalization_factors is not None:
        measures = self._rescale_measures(measures)

    if self.plot:
        self._plot_signed_measures(measures)

    if self.integrate:
        grids = self._infered_grids
        # Only the first inferred grid is used (integration over axis -1 is
        # not implemented).
        integrated = []
        for x in measures:
            integrated.append(
                [
                    self._integrate_measure(x[degree], filtrations=grids[0])
                    for degree in range(self._num_degrees)
                ]
            )
        integrated = np.asarray(integrated)
        if self.flatten:
            integrated = integrated.reshape((len(X), -1))
        return integrated

    if self.unsparse:
        return [self.unsparse_signed_measure(x) for x in measures]

    if not self.deep_format:
        return measures

    num_degrees = self._num_degrees
    # One entry per (degree, axis) pair, each holding one array per datum.
    formatted = tuple(
        tuple(self.deep_format_measure(sm[axis][degree]) for sm in measures)
        for degree in range(num_degrees)
        for axis in self._axis_iterator
    )
    if not self.unrag:
        return formatted

    # Pad every measure with zeros up to the largest number of points so that
    # everything fits in one dense tensor.
    max_num_pts = np.max([sm.shape[0] for group in formatted for sm in group])
    num_axis_degree = len(formatted)
    num_data = len(formatted[0])
    assert num_axis_degree == num_degrees * (
        self._num_axis if self._has_axis else 1
    ), f"Bad axis/degree count. Got {num_axis_degree} (Internal error)"
    template = formatted[0][0]
    num_parameters = template.shape[1]
    if isinstance(template, np.ndarray):
        make_zeros = np.zeros
    else:
        import torch

        make_zeros = torch.zeros
    padded = make_zeros(
        (num_axis_degree, num_data, max_num_pts, num_parameters),
        dtype=template.dtype,
    )
    for group_idx, group in enumerate(formatted):
        for data_idx, sm in enumerate(group):
            num_pts, num_cols = sm.shape
            padded[group_idx, data_idx, :num_pts, :num_cols] = sm
    return padded
922
+
923
+
924
class SignedMeasure2Convolution(BaseEstimator, TransformerMixin):
    """
    Discrete convolution of a signed measure

    Input
    -----

    (data) x (degree) x (signed measure)

    A signed measure is either sparse, i.e. a tuple (detected in `fit` with
    `isinstance(X[0][0], tuple)`), or a dense array.

    Parameters
    ----------
     - filtration_grid : Iterable[array] For each filtration, the filtration values on which to evaluate the grid
     - kernel : kernel used to convolve the images (sparse input only).
     - bandwidth : float or one float per axis. A non-positive value `b` is
       turned into `-b * scale`, where `scale` is the diameter of the grid
       (sparse input) or the resolution of the axis (dense input).
     - flatten : if true, the output will be flattened
     - n_jobs : number of parallel jobs.
     - resolution : int or (num_parameter) : If filtration grid is not given, will infer a grid, with this resolution
     - grid_strategy : the strategy to generate the grid. Available ones are regular, quantile, exact
     - progress : progress bar if True
     - backend : sklearn, pykeops or numba.
     - plot : Creates a plot Figure.
     - log_density : stored on the estimator; not read by the methods of this class.
     - kde_kwargs : extra keyword arguments forwarded to `convolution_signed_measures`.

    Output
    ------

    (data) x (concatenation of imgs of degree)
    """

    def __init__(
        self,
        filtration_grid: Iterable[np.ndarray] | None = None,
        kernel="gaussian",
        bandwidth: float | Iterable[float] = 1.0,
        flatten: bool = False,
        n_jobs: int = 1,
        resolution: int | None = None,
        grid_strategy: str = "regular",
        progress: bool = False,
        backend: str = "pykeops",
        plot: bool = False,
        log_density: bool = False,
        **kde_kwargs,
    ):
        super().__init__()
        self.kernel = kernel
        self.bandwidth = bandwidth
        self.filtration_grid = filtration_grid
        self.flatten = flatten
        self.progress = progress
        self.n_jobs = n_jobs
        self.resolution = resolution
        self.grid_strategy = grid_strategy
        # State filled in by `fit`.
        self._is_input_sparse = None
        # A grid is only inferred when the user did not provide one.
        self._refit = filtration_grid is None
        self._input_resolution = None
        self._bandwidths = None
        self.diameter = None
        self.backend = backend
        self.plot = plot
        self.log_density = log_density
        self.kde_kwargs = kde_kwargs

    def fit(self, X, y=None):
        """
        Detect the input format and prepare bandwidths / grid accordingly.

        - Dense input: records the resolution of `X[0][0]` and one bandwidth
          per axis; no grid is needed.
        - Sparse input: infers a filtration grid (if none was given) and
          computes its diameter.

        Raises
        ------
        ValueError : sparse input with neither a filtration grid nor a resolution.
        """
        # Infers if the input is sparse given X
        if len(X) == 0:
            return self
        self._is_input_sparse = isinstance(X[0][0], tuple)

        if not self._is_input_sparse:
            # Signed measures are matrices here, and the grid is already given.
            self._input_resolution = X[0][0].shape
            try:
                scalar_bandwidth = float(self.bandwidth)
            except (TypeError, ValueError, RuntimeError):
                # Not convertible to a single float: one bandwidth per axis.
                # (Only the conversion is guarded, so unrelated errors and
                # KeyboardInterrupt are no longer swallowed.)
                bandwidths = self.bandwidth
            else:
                bandwidths = [scalar_bandwidth] * len(self._input_resolution)
            self._bandwidths = [
                b if b > 0 else -b * s
                for s, b in zip(self._input_resolution, bandwidths)
            ]
            return self

        if self.filtration_grid is None and self.resolution is None:
            raise ValueError(
                "Cannot infer filtration grid. Provide either a filtration grid or a resolution."
            )
        if self._refit:
            # Gather the points of every measure, one row per parameter.
            pts = np.concatenate(
                [sm[0] for signed_measures in X for sm in signed_measures]
            ).T
            self.filtration_grid = reduce_grid(
                pts,
                strategy=self.grid_strategy,
                resolution=self.resolution,
            )
        if self.filtration_grid is not None:
            # Norm of the per-parameter extents of the grid.
            self.diameter = np.linalg.norm(
                [f.max() - f.min() for f in self.filtration_grid]
            )
            if self.progress:
                print(f"Computed a diameter of {self.diameter}")
        return self

    def _sparsify(self, sm):
        """Convert a dense tensor to a sparse measure on `self.filtration_grid`."""
        return tensor_möbius_inversion(input=sm, grid_conversion=self.filtration_grid)

    def _sm2smi(self, signed_measures: Iterable[np.ndarray]):
        """
        Gaussian-filter every dense signed measure of one datum with the
        fitted per-axis bandwidths, and concatenate the results along axis 0.
        """
        from scipy.ndimage import gaussian_filter

        return np.concatenate(
            [
                gaussian_filter(
                    input=signed_measure,
                    sigma=self._bandwidths,
                    mode="constant",
                    cval=0,
                )
                for signed_measure in signed_measures
            ],
            axis=0,
        )

    def _transform_from_sparse(self, X):
        """Convolve sparse signed measures on the fitted filtration grid."""
        # A non-positive bandwidth is relative to the diameter of the grid.
        bandwidth = (
            self.bandwidth if self.bandwidth > 0 else -self.bandwidth * self.diameter
        )
        # COMPILE KEOPS FIRST : warm-up call on the first datum and a
        # 2-value grid, run with a single job. Its result is discarded.
        dummyx = [X[0]]
        dummyf = [f[:2] for f in self.filtration_grid]
        convolution_signed_measures(
            dummyx,
            filtrations=dummyf,
            bandwidth=bandwidth,
            flatten=self.flatten,
            n_jobs=1,
            kernel=self.kernel,
            backend=self.backend,
        )

        return convolution_signed_measures(
            X,
            filtrations=self.filtration_grid,
            bandwidth=bandwidth,
            flatten=self.flatten,
            n_jobs=self.n_jobs,
            kernel=self.kernel,
            backend=self.backend,
            **self.kde_kwargs,
        )

    def _plot_imgs(self, imgs: Iterable[np.ndarray], size=4):
        """Plot one surface per (image, degree) pair on a grid of subplots."""
        from multipers.plots import plot_surface

        num_degrees = imgs[0].shape[0]
        num_imgs = len(imgs)
        fig, axes = plt.subplots(
            ncols=num_degrees,
            nrows=num_imgs,
            figsize=(size * num_degrees, size * num_imgs),
        )
        # `plt.subplots` squeezes singleton dimensions; restore a 2-D layout.
        axes = np.asarray(axes).reshape(num_imgs, num_degrees)
        for i, img in enumerate(imgs):
            for j, img_of_degree in enumerate(img):
                plot_surface(
                    self.filtration_grid, img_of_degree, ax=axes[i, j], cmap="Spectral"
                )

    def transform(self, X):
        """
        Convolve every datum of `X`; returns the results as one array.

        Raises
        ------
        Exception : if `fit` has not detected the input format yet.
        """
        if self._is_input_sparse is None:
            raise Exception("Fit first")
        if self._is_input_sparse:
            out = self._transform_from_sparse(X)
        else:
            todo = SignedMeasure2Convolution._sm2smi
            out = Parallel(n_jobs=self.n_jobs, backend="threading")(
                delayed(todo)(self, signed_measures)
                for signed_measures in tqdm(
                    X, desc="Computing images", disable=not self.progress
                )
            )
        if self.plot and not self.flatten:
            if self.progress:
                print("Plotting convolutions...", end="")
            self._plot_imgs(out)
            if self.progress:
                print("Done !")
        # The sparse path already receives `flatten`; only dense outputs are
        # flattened here.
        if self.flatten and not self._is_input_sparse:
            out = [x.flatten() for x in out]
        return np.asarray(out)
1124
+
1125
+
1126
class SignedMeasure2SlicedWassersteinDistance(BaseEstimator, TransformerMixin):
    """
    Transformer from signed measure to distance matrix.

    Input
    -----

    (data) x (degree) x (signed measure)

    Format
    ------
    - a signed measure : tuple of arrays. (point position) : npts x (num_parameters) and weights : npts
    - each data is a list of signed measures (e.g. one per degree)

    Output
    ------
    - (degree) x (distance matrix)
    """

    def __init__(
        self,
        n_jobs=None,
        num_directions: int = 10,
        _sliced: bool = True,
        epsilon=-1,
        ground_norm=1,
        progress=False,
        grid_reconversion=None,
        scales=None,
    ):
        super().__init__()
        # Parameters, stored verbatim.
        self.n_jobs = n_jobs
        self.num_directions = num_directions
        self._sliced = _sliced
        self.epsilon = epsilon
        self.ground_norm = ground_norm
        self.progress = progress
        self.grid_reconversion = grid_reconversion
        self.scales = scales
        # One fitted distance estimator per degree; filled in by `fit`.
        self._SWD_list = None

    def fit(self, X, y=None):
        """Create and fit one distance estimator per degree of `X`."""
        from multipers.ml.sliced_wasserstein import (SlicedWassersteinDistance,
                                                      WassersteinDistance)

        # WARNING if _sliced is false, this distance is not CNSD
        if len(X) == 0:
            return self

        def _new_distance():
            if self._sliced:
                return SlicedWassersteinDistance(
                    num_directions=self.num_directions,
                    n_jobs=self.n_jobs,
                    scales=self.scales,
                )
            return WassersteinDistance(
                epsilon=self.epsilon, ground_norm=self.ground_norm, n_jobs=self.n_jobs
            )

        num_degrees = len(X[0])
        self._SWD_list = [_new_distance() for _ in range(num_degrees)]
        for degree, distance in enumerate(self._SWD_list):
            distance.fit([x[degree] for x in X])
        return self

    def transform(self, X):
        """Return, for each degree, the output of its fitted estimator on `X`."""
        assert self._SWD_list is not None, "Fit first"

        def _distances_of_degree(swd, X_of_degree):
            return swd.transform(X_of_degree)

        progress_bar = tqdm(
            enumerate(self._SWD_list),
            desc="Computing distance matrices",
            total=len(self._SWD_list),
            disable=not self.progress,
        )
        with progress_bar as SWD_it:
            jobs = (
                delayed(_distances_of_degree)(swd, [x[degree] for x in X])
                for degree, swd in SWD_it
            )
            out = Parallel(n_jobs=self.n_jobs, prefer="threads")(jobs)
        return np.asarray(out)

    def predict(self, X):
        """Alias of `transform`."""
        return self.transform(X)
1215
+
1216
+
1217
class SignedMeasures2SlicedWassersteinDistances(BaseEstimator, TransformerMixin):
    """
    Transformer from signed measure to distance matrix.

    Input
    -----
    (data) x opt (axis) x (degree) x (signed measure)

    Format
    ------
    - a signed measure : tuple of arrays. (point position) : npts x (num_parameters) and weights : npts
    - each data is a list of signed measures (e.g. one per degree)

    Output
    ------
    - (axis) x (degree) x (distance matrix)
    """

    def __init__(
        self,
        progress=False,
        n_jobs: int = 1,
        scales: Iterable[Iterable[float]] | None = None,
        **kwargs,
    ):  # same init
        super().__init__()
        # Template estimator, cloned once per (axis, scales) pair in `fit`.
        # Its own `progress`, `scales` and `n_jobs` are fixed here; `scales`
        # is overridden on every clone.
        self._init_child = SignedMeasure2SlicedWassersteinDistance(
            progress=False, scales=None, n_jobs=-1, **kwargs
        )
        self._axe_iterator = None
        self._childs_to_fit = None
        self.scales = scales
        self.progress = progress
        self.n_jobs = n_jobs

    def fit(self, X, y=None):
        """
        Fit one child estimator per (axis, scales) pair.

        `self.scales` is normalised in place to either `[None]` or a 2-D
        array (one row per set of scales).
        """
        from sklearn.base import clone

        if len(X) == 0:
            return self
        if isinstance(X[0][0], tuple):  # Meaning that there are no axes
            self._axe_iterator = [slice(None)]
        else:
            self._axe_iterator = range(len(X[0]))

        scales = self.scales
        if scales is None or (len(scales) == 1 and scales[0] is None):
            # `[None]` is what a previous `fit` stores for "no scales". It
            # must be recognised here, otherwise refitting would wrap it into
            # a 2-D object array and hand `array([None])` to the children.
            self.scales = [None]
        else:
            self.scales = np.asarray(scales)
            if self.scales.ndim == 1:
                self.scales = np.asarray([self.scales])
            assert (
                self.scales.ndim == 2
            ), "Scales have to be either None or a list of scales !"

        self._childs_to_fit = [
            clone(self._init_child).set_params(scales=scales).fit([x[axis] for x in X])
            for axis, scales in product(self._axe_iterator, self.scales)
        ]
        print("New axes : ", list(product(self._axe_iterator, self.scales)))
        return self

    def transform(self, X):
        """
        Run every fitted child on the data of its axis.

        Returns a list with one entry per (axis, scales) pair, in the order
        of `product(axes, scales)`.
        """
        return Parallel(n_jobs=self.n_jobs, prefer="processes")(
            delayed(self._childs_to_fit[child_id].transform)([x[axis] for x in X])
            for child_id, (axis, _) in tqdm(
                enumerate(product(self._axe_iterator, self.scales)),
                desc="Computing distances matrices of axis, and scales",
                disable=not self.progress,
                total=len(self._childs_to_fit),
            )
        )
1292
+
1293
+
1294
class SimplexTree2RectangleDecomposition(BaseEstimator, TransformerMixin):
    """
    Transformer. 2 parameter SimplexTrees to their respective rectangle decomposition.
    """

    def __init__(
        self,
        filtration_grid: np.ndarray,
        degrees: Iterable[int],
        plot=False,
        reconvert_grid=True,
        num_collapses: int = 0,
    ):
        super().__init__()
        # All parameters are stored verbatim and forwarded to
        # `_st2ranktensor` at transform time.
        self.filtration_grid = filtration_grid
        self.degrees = degrees
        self.plot = plot
        self.reconvert_grid = reconvert_grid
        self.num_collapses = num_collapses

    def fit(self, X, y=None):
        """
        Nothing to fit.

        TODO : infer grid from multiple simplextrees
        """
        return self

    def transform(self, X: Iterable[mp.simplex_tree_multi.SimplexTreeMulti_type]):
        """
        Compute `_st2ranktensor` for every simplextree and every degree.

        Returns a list (one entry per simplextree) of lists (one entry per
        degree of `self.degrees`).
        """

        def _decompose(simplextree, degree):
            return _st2ranktensor(
                simplextree,
                filtration_grid=self.filtration_grid,
                degree=degree,
                plot=self.plot,
                reconvert_grid=self.reconvert_grid,
                num_collapse=self.num_collapses,
            )

        # TODO : return iterator ?
        rectangle_decompositions = []
        for simplextree in X:
            per_degree = [_decompose(simplextree, degree) for degree in self.degrees]
            rectangle_decompositions.append(per_degree)
        return rectangle_decompositions
1338
+
1339
+
1340
+ def _st2ranktensor(
1341
+ st: mp.simplex_tree_multi.SimplexTreeMulti_type,
1342
+ filtration_grid: np.ndarray,
1343
+ degree: int,
1344
+ plot: bool,
1345
+ reconvert_grid: bool,
1346
+ num_collapse: int | str = 0,
1347
+ ):
1348
+ """
1349
+ TODO
1350
+ """
1351
+ # Copy (the squeeze change the filtration values)
1352
+ # stcpy = mp.SimplexTreeMulti(st)
1353
+ # turns the simplextree into a coordinate simplex tree
1354
+ stcpy = st.grid_squeeze(filtration_grid=filtration_grid, coordinate_values=True)
1355
+ # stcpy.collapse_edges(num=100, strong = True, ignore_warning=True)
1356
+ if num_collapse == "full":
1357
+ stcpy.collapse_edges(full=True, ignore_warning=True, max_dimension=degree + 1)
1358
+ elif isinstance(num_collapse, int):
1359
+ stcpy.collapse_edges(
1360
+ num=num_collapse, ignore_warning=True, max_dimension=degree + 1
1361
+ )
1362
+ else:
1363
+ raise TypeError(
1364
+ f"Invalid num_collapse=\
1365
+ {num_collapse} type. Either full, or an integer."
1366
+ )
1367
+ # computes the rank invariant tensor
1368
+ rank_tensor = mp.rank_invariant2d(
1369
+ stcpy, degree=degree, grid_shape=[len(f) for f in filtration_grid]
1370
+ )
1371
+ # refactor this tensor into the rectangle decomposition of the signed betti
1372
+ grid_conversion = filtration_grid if reconvert_grid else None
1373
+ rank_decomposition = rank_decomposition_by_rectangles(
1374
+ rank_tensor,
1375
+ threshold=True,
1376
+ )
1377
+ rectangle_decomposition = tensor_möbius_inversion(
1378
+ tensor=rank_decomposition,
1379
+ grid_conversion=grid_conversion,
1380
+ plot=plot,
1381
+ num_parameters=st.num_parameters,
1382
+ )
1383
+ return rectangle_decomposition
1384
+
1385
+
1386
class DegreeRips2SignedMeasure(BaseEstimator, TransformerMixin):
    """
    Transformer from point clouds to signed measures computed with
    `hf_degree_rips`.

    Parameters
    ----------
    - degrees : homological degrees to compute.
    - min_rips_value, max_rips_value : bounds forwarded to `hf_degree_rips`.
      A negative `max_rips_value` is relative: it is replaced in `fit` by
      `-max_rips_value * diameter`, where the diameter is the largest pairwise
      distance found in a sample of the training data.
    - min_normalized_degree, max_normalized_degree, grid_granularity :
      forwarded to `hf_degree_rips`.
    - progress : show the progress bar of `transform`.
    - n_jobs : number of parallel jobs of `transform`.
    - sparse : return sparse (points, weights) measures instead of tensors.
    - _möbius_inversion : if false, the hilbert function is used directly
      instead of its signed betti decomposition.
    - fit_fraction : fraction of the data sampled to estimate the diameter.
    """

    def __init__(
        self,
        degrees: Iterable[int],
        min_rips_value: float,
        max_rips_value,
        max_normalized_degree: float,
        min_normalized_degree: float,
        grid_granularity: int,
        progress: bool = False,
        n_jobs=1,
        sparse: bool = False,
        _möbius_inversion=True,
        fit_fraction=1,
    ) -> None:
        super().__init__()
        self.min_rips_value = min_rips_value
        self.max_rips_value = max_rips_value
        self.min_normalized_degree = min_normalized_degree
        self.max_normalized_degree = max_normalized_degree
        # Absolute upper bound actually used; resolved in `fit`.
        self._max_rips_value = None
        self.grid_granularity = grid_granularity
        self.progress = progress
        self.n_jobs = n_jobs
        self.degrees = degrees
        self.sparse = sparse
        self._möbius_inversion = _möbius_inversion
        self.fit_fraction = fit_fraction

    def fit(self, X: np.ndarray | list, y=None):
        """Resolve the absolute maximal rips value from `X` if it is relative."""
        if self.max_rips_value < 0:
            print("Estimating scale...", flush=True, end="")
            indices = np.random.choice(
                len(X), min(len(X), int(self.fit_fraction * len(X)) + 1), replace=False
            )
            diameters = np.max(
                [distance_matrix(x, x).max() for x in (X[i] for i in indices)]
            )
            print(f"Done. {diameters}", flush=True)
            self._max_rips_value = -self.max_rips_value * diameters
        else:
            self._max_rips_value = self.max_rips_value
        return self

    def _transform1(self, data: np.ndarray):
        """Compute the signed measures (one per degree) of one point cloud."""
        _distance_matrix = distance_matrix(data, data)
        signed_measures = []
        (
            rips_values,
            normalized_degree_values,
            hilbert_functions,
            minimal_presentations,
        ) = hf_degree_rips(
            _distance_matrix,
            min_rips_value=self.min_rips_value,
            max_rips_value=self._max_rips_value,
            min_normalized_degree=self.min_normalized_degree,
            max_normalized_degree=self.max_normalized_degree,
            grid_granularity=self.grid_granularity,
            max_homological_dimension=np.max(self.degrees),
        )
        for degree in self.degrees:
            hilbert_function = hilbert_functions[degree]
            signed_measure = (
                signed_betti(hilbert_function, threshold=True)
                if self._möbius_inversion
                else hilbert_function
            )
            if self.sparse:
                signed_measure = tensor_möbius_inversion(
                    tensor=signed_measure,
                    num_parameters=2,
                    grid_conversion=[rips_values, normalized_degree_values],
                )
            elif not self._möbius_inversion:
                # Only dense tensors are flattened: the sparse conversion
                # above returns a (points, weights) tuple, which has no
                # `.flatten()` and used to raise AttributeError here.
                signed_measure = signed_measure.flatten()
            signed_measures.append(signed_measure)
        return signed_measures

    def transform(self, X):
        """Compute the signed measures of every point cloud of `X`, in parallel."""
        return Parallel(n_jobs=self.n_jobs)(
            delayed(self._transform1)(data)
            for data in tqdm(
                X,
                desc=f"Computing DegreeRips, of degrees {self.degrees}",
                # Honour the `progress` flag, like the other transformers of
                # this module; the bar used to be shown unconditionally.
                disable=not self.progress,
            )
        )
1473
+
1474
+
1475
+ def tensor_möbius_inversion(
1476
+ tensor,
1477
+ grid_conversion: Iterable[np.ndarray] | None = None,
1478
+ plot: bool = False,
1479
+ raw: bool = False,
1480
+ num_parameters: int | None = None,
1481
+ ):
1482
+ from torch import Tensor
1483
+
1484
+ betti_sparse = Tensor(tensor.copy()).to_sparse() # Copy necessary in some cases :(
1485
+ num_indices, num_pts = betti_sparse.indices().shape
1486
+ num_parameters = num_indices if num_parameters is None else num_parameters
1487
+ if num_indices == num_parameters: # either hilbert or rank invariant
1488
+ rank_invariant = False
1489
+ elif 2 * num_parameters == num_indices:
1490
+ rank_invariant = True
1491
+ else:
1492
+ raise TypeError(
1493
+ f"Unsupported betti shape. {num_indices}\
1494
+ has to be either {num_parameters} or \
1495
+ {2*num_parameters}."
1496
+ )
1497
+ points_filtration = np.asarray(betti_sparse.indices().T, dtype=int)
1498
+ weights = np.asarray(betti_sparse.values(), dtype=int)
1499
+
1500
+ if grid_conversion is not None:
1501
+ coords = np.empty(shape=(num_pts, num_indices), dtype=float)
1502
+ for i in range(num_indices):
1503
+ coords[:, i] = grid_conversion[i % num_parameters][points_filtration[:, i]]
1504
+ else:
1505
+ coords = points_filtration
1506
+ if (not rank_invariant) and plot:
1507
+ plt.figure()
1508
+ color_weights = np.empty(weights.shape)
1509
+ color_weights[weights > 0] = np.log10(weights[weights > 0]) + 2
1510
+ color_weights[weights < 0] = -np.log10(-weights[weights < 0]) - 2
1511
+ plt.scatter(
1512
+ points_filtration[:, 0],
1513
+ points_filtration[:, 1],
1514
+ c=color_weights,
1515
+ cmap="coolwarm",
1516
+ )
1517
+ if (not rank_invariant) or raw:
1518
+ return coords, weights
1519
+
1520
+ def _is_trivial(rectangle: np.ndarray):
1521
+ birth = rectangle[:num_parameters]
1522
+ death = rectangle[num_parameters:]
1523
+ return np.all(birth <= death) # and not np.array_equal(birth,death)
1524
+
1525
+ correct_indices = np.array([_is_trivial(rectangle) for rectangle in coords])
1526
+ if len(correct_indices) == 0:
1527
+ return np.empty((0, num_indices)), np.empty((0))
1528
+ signed_measure = np.asarray(coords[correct_indices])
1529
+ weights = weights[correct_indices]
1530
+ if plot:
1531
+ # plot only the rank decompo for the moment
1532
+ assert signed_measure.shape[1] == 4
1533
+
1534
+ def _plot_rectangle(rectangle: np.ndarray, weight: float):
1535
+ x_axis = rectangle[[0, 2]]
1536
+ y_axis = rectangle[[1, 3]]
1537
+ color = "blue" if weight > 0 else "red"
1538
+ plt.plot(x_axis, y_axis, c=color)
1539
+
1540
+ for rectangle, weight in zip(signed_measure, weights):
1541
+ _plot_rectangle(rectangle=rectangle, weight=weight)
1542
+ return signed_measure, weights