multipers-2.3.3b6-cp313-cp313-macosx_10_13_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of multipers has been flagged as potentially problematic.
- multipers/.dylibs/libc++.1.0.dylib +0 -0
- multipers/.dylibs/libtbb.12.16.dylib +0 -0
- multipers/__init__.py +33 -0
- multipers/_signed_measure_meta.py +453 -0
- multipers/_slicer_meta.py +211 -0
- multipers/array_api/__init__.py +45 -0
- multipers/array_api/numpy.py +41 -0
- multipers/array_api/torch.py +58 -0
- multipers/data/MOL2.py +458 -0
- multipers/data/UCR.py +18 -0
- multipers/data/__init__.py +1 -0
- multipers/data/graphs.py +466 -0
- multipers/data/immuno_regions.py +27 -0
- multipers/data/minimal_presentation_to_st_bf.py +0 -0
- multipers/data/pytorch2simplextree.py +91 -0
- multipers/data/shape3d.py +101 -0
- multipers/data/synthetic.py +113 -0
- multipers/distances.py +202 -0
- multipers/filtration_conversions.pxd +229 -0
- multipers/filtration_conversions.pxd.tp +84 -0
- multipers/filtrations/__init__.py +18 -0
- multipers/filtrations/density.py +574 -0
- multipers/filtrations/filtrations.py +361 -0
- multipers/filtrations.pxd +224 -0
- multipers/function_rips.cpython-313-darwin.so +0 -0
- multipers/function_rips.pyx +105 -0
- multipers/grids.cpython-313-darwin.so +0 -0
- multipers/grids.pyx +433 -0
- multipers/gudhi/Persistence_slices_interface.h +132 -0
- multipers/gudhi/Simplex_tree_interface.h +239 -0
- multipers/gudhi/Simplex_tree_multi_interface.h +551 -0
- multipers/gudhi/cubical_to_boundary.h +59 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -0
- multipers/gudhi/gudhi/Debug_utils.h +45 -0
- multipers/gudhi/gudhi/Fields/Multi_field.h +484 -0
- multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -0
- multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -0
- multipers/gudhi/gudhi/Fields/Z2_field.h +355 -0
- multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -0
- multipers/gudhi/gudhi/Fields/Zp_field.h +420 -0
- multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -0
- multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -0
- multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -0
- multipers/gudhi/gudhi/Matrix.h +2107 -0
- multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -0
- multipers/gudhi/gudhi/Multi_persistence/Box.h +174 -0
- multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -0
- multipers/gudhi/gudhi/Off_reader.h +173 -0
- multipers/gudhi/gudhi/One_critical_filtration.h +1441 -0
- multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -0
- multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -0
- multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -0
- multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -0
- multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -0
- multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -0
- multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -0
- multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -0
- multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -0
- multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -0
- multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -0
- multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -0
- multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -0
- multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -0
- multipers/gudhi/gudhi/Persistent_cohomology.h +745 -0
- multipers/gudhi/gudhi/Points_off_io.h +171 -0
- multipers/gudhi/gudhi/Simple_object_pool.h +69 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -0
- multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -0
- multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -0
- multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -0
- multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -0
- multipers/gudhi/gudhi/Simplex_tree.h +2794 -0
- multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -0
- multipers/gudhi/gudhi/distance_functions.h +62 -0
- multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -0
- multipers/gudhi/gudhi/persistence_interval.h +253 -0
- multipers/gudhi/gudhi/persistence_matrix_options.h +170 -0
- multipers/gudhi/gudhi/reader_utils.h +367 -0
- multipers/gudhi/mma_interface_coh.h +256 -0
- multipers/gudhi/mma_interface_h0.h +223 -0
- multipers/gudhi/mma_interface_matrix.h +293 -0
- multipers/gudhi/naive_merge_tree.h +536 -0
- multipers/gudhi/scc_io.h +310 -0
- multipers/gudhi/truc.h +1403 -0
- multipers/io.cpython-313-darwin.so +0 -0
- multipers/io.pyx +644 -0
- multipers/ml/__init__.py +0 -0
- multipers/ml/accuracies.py +90 -0
- multipers/ml/invariants_with_persistable.py +79 -0
- multipers/ml/kernels.py +176 -0
- multipers/ml/mma.py +713 -0
- multipers/ml/one.py +472 -0
- multipers/ml/point_clouds.py +352 -0
- multipers/ml/signed_measures.py +1589 -0
- multipers/ml/sliced_wasserstein.py +461 -0
- multipers/ml/tools.py +113 -0
- multipers/mma_structures.cpython-313-darwin.so +0 -0
- multipers/mma_structures.pxd +128 -0
- multipers/mma_structures.pyx +2786 -0
- multipers/mma_structures.pyx.tp +1094 -0
- multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -0
- multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -0
- multipers/multi_parameter_rank_invariant/function_rips.h +322 -0
- multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -0
- multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -0
- multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -0
- multipers/multiparameter_edge_collapse.py +41 -0
- multipers/multiparameter_module_approximation/approximation.h +2330 -0
- multipers/multiparameter_module_approximation/combinatory.h +129 -0
- multipers/multiparameter_module_approximation/debug.h +107 -0
- multipers/multiparameter_module_approximation/euler_curves.h +0 -0
- multipers/multiparameter_module_approximation/format_python-cpp.h +286 -0
- multipers/multiparameter_module_approximation/heap_column.h +238 -0
- multipers/multiparameter_module_approximation/images.h +79 -0
- multipers/multiparameter_module_approximation/list_column.h +174 -0
- multipers/multiparameter_module_approximation/list_column_2.h +232 -0
- multipers/multiparameter_module_approximation/ru_matrix.h +347 -0
- multipers/multiparameter_module_approximation/set_column.h +135 -0
- multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -0
- multipers/multiparameter_module_approximation/unordered_set_column.h +166 -0
- multipers/multiparameter_module_approximation/utilities.h +403 -0
- multipers/multiparameter_module_approximation/vector_column.h +223 -0
- multipers/multiparameter_module_approximation/vector_matrix.h +331 -0
- multipers/multiparameter_module_approximation/vineyards.h +464 -0
- multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -0
- multipers/multiparameter_module_approximation.cpython-313-darwin.so +0 -0
- multipers/multiparameter_module_approximation.pyx +235 -0
- multipers/pickle.py +90 -0
- multipers/plots.py +456 -0
- multipers/point_measure.cpython-313-darwin.so +0 -0
- multipers/point_measure.pyx +395 -0
- multipers/simplex_tree_multi.cpython-313-darwin.so +0 -0
- multipers/simplex_tree_multi.pxd +134 -0
- multipers/simplex_tree_multi.pyx +10840 -0
- multipers/simplex_tree_multi.pyx.tp +2009 -0
- multipers/slicer.cpython-313-darwin.so +0 -0
- multipers/slicer.pxd +3034 -0
- multipers/slicer.pxd.tp +234 -0
- multipers/slicer.pyx +20481 -0
- multipers/slicer.pyx.tp +1088 -0
- multipers/tensor/tensor.h +672 -0
- multipers/tensor.pxd +13 -0
- multipers/test.pyx +44 -0
- multipers/tests/__init__.py +62 -0
- multipers/torch/__init__.py +1 -0
- multipers/torch/diff_grids.py +240 -0
- multipers/torch/rips_density.py +310 -0
- multipers-2.3.3b6.dist-info/METADATA +128 -0
- multipers-2.3.3b6.dist-info/RECORD +183 -0
- multipers-2.3.3b6.dist-info/WHEEL +6 -0
- multipers-2.3.3b6.dist-info/licenses/LICENSE +21 -0
- multipers-2.3.3b6.dist-info/top_level.txt +1 -0
multipers/ml/signed_measures.py
@@ -0,0 +1,1589 @@
from collections.abc import Iterable, Sequence
from itertools import product
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
from joblib import Parallel, delayed
from sklearn.base import BaseEstimator, TransformerMixin
from tqdm import tqdm

import multipers as mp
from multipers.array_api import api_from_tensor
from multipers.filtrations.density import available_kernels, convolution_signed_measures
from multipers.grids import compute_grid
from multipers.point_measure import rank_decomposition_by_rectangles, signed_betti


class FilteredComplex2SignedMeasure(BaseEstimator, TransformerMixin):
    """
    Input
    -----
    Iterable[SimplexTreeMulti]

    Output
    ------
    Iterable[ list[signed_measure for degree] ]

    signed measure is either
     - (points : (n x num_parameters) array, weights : (n) int array ) if sparse,
     - else an integer matrix.

    Parameters
    ----------
     - degrees : list of degrees to compute. None corresponds to the euler characteristic
     - filtration grid : the grid on which to compute.
       If None, the fit will infer it from the data.
     - fit_fraction : the fraction of data to consider for the fit, seed is controlled by the seed parameter
     - resolution : the resolution of this grid
     - filtration_quantile : filtration values quantile to ignore
     - grid_strategy : str : 'regular' or 'quantile' or 'exact'
     - normalize filtration : if sparse, will normalize all filtrations.
     - expand : expands the simplextree to compute correctly the degree, for
       flag complexes
     - invariant : the topological invariant to produce the signed measure.
       Choices are "hilbert" or "euler". Will add rank invariant later.
     - num_collapse : Either an int or "full". Collapse the complex before
       doing computation.
     - _möbius_inversion : if False, will not do the Möbius inversion. Output
       has to be a matrix then.
     - enforce_null_mass : Returns a zero mass measure, by thresholding the
       module if True.
    """

    def __init__(
        self,
        # homological degrees + None for euler
        degrees: list[int | None] = [],
        rank_degrees: list[int] = [],  # same for rank invariant
        filtration_grid: (
            Sequence[Sequence[np.ndarray]]
            # filtration values to consider. Format : [ filtration values of Fi for Fi:filtration values of parameter i]
            | None
        ) = None,
        progress=False,  # tqdm
        num_collapses: int | str = 0,  # edge collapses before computing
        n_jobs=None,
        resolution: (
            Iterable[int] | int | None
        ) = None,  # when filtration grid is not given, the resolution of the filtration grid to infer
        # sparse=True,  # sparse output # DEPRECATED TO Signed measure formatter
        plot: bool = False,
        filtration_quantile: float = 0.0,  # quantile for inferring filtration grid
        # whether or not to do the möbius inversion (not recommended to touch)
        # _möbius_inversion: bool = True,
        expand=False,  # expand the simplextree before computing the homology
        normalize_filtrations: bool = False,
        # exact_computation: bool = False,  # compute the exact signed measure.
        grid_strategy: str = "exact",
        seed: int = 0,  # if fit_fraction is not 1, the seed sampling
        fit_fraction=1,  # the fraction of the data on which to fit
        out_resolution: Iterable[int] | int | None = None,
        individual_grid: Optional[
            bool
        ] = None,  # Can be significantly faster for some grid strategies, but can drop statistical performance
        enforce_null_mass: bool = False,
        flatten=True,
        backend: Optional[str] = None,
    ):
        super().__init__()
        self.degrees = degrees
        self.rank_degrees = rank_degrees
        self.filtration_grid = filtration_grid
        self.progress = progress
        self.num_collapses = num_collapses
        self.n_jobs = n_jobs
        self.resolution = resolution
        self.plot = plot
        self.backend = backend
        # self.sparse = sparse  # TODO : deprecate
        self.filtration_quantile = filtration_quantile
        # Will only work for non sparse output. (discrete matrices cannot be "rescaled")
        self.normalize_filtrations = normalize_filtrations
        self.grid_strategy = grid_strategy
        # self._möbius_inversion = _möbius_inversion
        self._reconversion_grid = None
        self.expand = expand
        # will only refit the grid if filtration_grid has never been given.
        self._refit_grid = None
        self.seed = seed
        self.fit_fraction = fit_fraction
        self._transform_st = None
        self.out_resolution = out_resolution
        self.individual_grid = individual_grid
        self.enforce_null_mass = enforce_null_mass
        self._default_mass_location = None
        self.flatten = flatten
        self.num_parameters: int = 0

        return

    @staticmethod
    def _is_filtered_complex(input):
        return mp.simplex_tree_multi.is_simplextree_multi(input) or mp.slicer.is_slicer(
            input, allow_minpres=True
        )

    def _input_checks(self, X):
        assert len(X) > 0, "No filtered complex found. Cannot fit."
        assert self._is_filtered_complex(
            X[0][0]
        ), f"X[0] is not a known filtered complex, {X[0]=}, nor X[0][0]."
        self._num_axis = len(X[0])
        first = X[0][0]
        assert (
            not mp.slicer.is_slicer(first) or not self.expand
        ), "Cannot expand slicers."
        assert not mp.slicer.is_slicer(first) or not (
            isinstance(first, Union[tuple, list]) and first[0].is_minpres
        ), "Multi-degree minpres are not supported yet as an input. This can still be computed by providing a backend."

    def _infer_filtration(self, X):
        self.num_parameters = X[0][0].num_parameters
        indices = np.random.choice(
            len(X), min(int(self.fit_fraction * len(X)) + 1, len(X)), replace=False
        )
        ## ax, num_x
        filtrations = tuple(
            tuple(
                compute_grid(x, strategy="exact")
                for x in (X[idx][ax] for idx in indices)
            )
            for ax in range(self._num_axis)
        )
        num_parameters = len(filtrations[0][0])
        assert (
            num_parameters == self.num_parameters
        ), f"Internal error, got {num_parameters=} and {self.num_parameters=}"

        filtrations_values = [
            [
                np.unique(np.concatenate([x[i] for x in filtrations[ax]]))
                for i in range(num_parameters)
            ]
            for ax in range(self._num_axis)
        ]
        ## ax, param, gridsize
        filtration_grid = tuple(
            compute_grid(
                filtrations_values[ax],
                resolution=self.resolution,
                strategy=self.grid_strategy,
            )
            for ax in range(self._num_axis)
        )  # TODO : use more parameters
        self.filtration_grid = filtration_grid
        return filtration_grid

    def _params_check(self):
        assert (
            self.resolution is not None
            or self.filtration_grid is not None
            or self.grid_strategy == "exact"
            or self.individual_grid
        ), "For non exact filtrations, a resolution has to be specified."

    def fit(self, X, y=None):  # Todo : infer filtration grid ? quantiles ?
        self._params_check()
        self._input_checks(X)

        if isinstance(self.resolution, int):
            self.resolution = [self.resolution] * self.num_parameters

        self.individual_grid = (
            self.individual_grid
            if self.individual_grid is not None
            else self.grid_strategy
            in ["regular_closest", "exact", "quantile", "partition"]
        )

        if (
            not self.enforce_null_mass
            and self.individual_grid
            or self.filtration_grid is not None
        ):
            self._refit_grid = False
        else:
            self._refit_grid = True

        if self._refit_grid:
            self._infer_filtration(X=X)
        if self.out_resolution is None:
            self.out_resolution = self.resolution
        # elif isinstance(self.out_resolution, int):
        #     self.out_resolution = [self.out_resolution] * self.num_parameters
        if self.normalize_filtrations and not self.individual_grid:
            # self._reconversion_grid = [np.linspace(0,1, num=len(f), dtype=float) for f in self.filtration_grid]  ## This will not work for non-regular grids...
            self._reconversion_grid = [
                [(f - np.min(f)) / np.std(f) for f in F] for F in self.filtration_grid
            ]  # not the best, but better than some weird magic
        # elif not self.sparse:  # It actually renormalizes the filtration !!
        #     self._reconversion_grid = [np.linspace(0,r, num=r, dtype=int) for r in self.out_resolution]
        else:
            self._reconversion_grid = self.filtration_grid
        ## ax, num_param
        self._default_mass_location = (
            np.asarray([[g[-1] for g in F] for F in self.filtration_grid])
            if self.enforce_null_mass
            else None
        )
        return self

    def transform1(
        self,
        simplextree,
        ax,
        # _reconversion_grid,
        thread_id: str = "",
    ):
        # st = mp.SimplexTreeMulti(st, num_parameters=st.num_parameters)  # COPY
        if self.individual_grid:
            filtration_grid = compute_grid(
                simplextree, strategy=self.grid_strategy, resolution=self.resolution
            )
            mass_default = (
                self._default_mass_location[ax] if self.enforce_null_mass else None
            )
            if self.enforce_null_mass:
                filtration_grid = [
                    np.concatenate([f, [d]], axis=0)
                    for f, d in zip(filtration_grid, mass_default)
                ]
            _reconversion_grid = filtration_grid
        else:
            filtration_grid = self.filtration_grid[ax]
            _reconversion_grid = self._reconversion_grid[ax]
            mass_default = (
                self._default_mass_location[ax] if self.enforce_null_mass else None
            )

        st = simplextree.grid_squeeze(filtration_grid=filtration_grid)
        if st.num_parameters == 2 and mp.simplex_tree_multi.is_simplextree_multi(st):
            st.collapse_edges(num=self.num_collapses, max_dimension=1)
        int_degrees = np.asarray([d for d in self.degrees if d is not None], dtype=int)
        # EULER. First as there is prune above dimension below
        if self.expand and None in self.degrees:
            st.expansion(st.num_vertices)
        signed_measures_euler = (
            mp.signed_measure(
                st,
                degrees=[None],
                plot=self.plot,
                mass_default=mass_default,
                invariant="euler",
                # thread_id=thread_id,
                backend=self.backend,
                grid=_reconversion_grid,
            )[0]
            if None in self.degrees
            else []
        )

        if self.expand and len(int_degrees) > 0:
            st.expansion(np.max(int_degrees) + 1)
        if len(int_degrees) > 0:
            st.prune_above_dimension(
                np.max(np.concatenate([int_degrees, self.rank_degrees])) + 1
            )  # no need to compute homology beyond this
        signed_measures_pers = (
            mp.signed_measure(
                st,
                degrees=int_degrees,
                mass_default=mass_default,
                plot=self.plot,
                invariant="hilbert",
                thread_id=thread_id,
                backend=self.backend,
                grid=_reconversion_grid,
            )
            if len(int_degrees) > 0
            else []
        )
        if self.plot:
            plt.show()
        if self.expand and len(self.rank_degrees) > 0:
            st.expansion(np.max(self.rank_degrees) + 1)
        if len(self.rank_degrees) > 0:
            st.prune_above_dimension(
                np.max(self.rank_degrees) + 1
            )  # no need to compute homology beyond this
        signed_measures_rank = (
            mp.signed_measure(
                st,
                degrees=self.rank_degrees,
                mass_default=mass_default,
                plot=self.plot,
                invariant="rank",
                thread_id=thread_id,
                backend=self.backend,
                grid=_reconversion_grid,
            )
            if len(self.rank_degrees) > 0
            else []
        )
        if self.plot:
            plt.show()

        count = 0
        signed_measures = []
        for d in self.degrees:
            if d is None:
                signed_measures.append(signed_measures_euler)
            else:
                signed_measures.append(signed_measures_pers[count])
                count += 1
        signed_measures += signed_measures_rank
        return signed_measures

    def transform(self, X):
        ## X of shape (num_x, num_axis, filtered_complex)
        assert (
            self.filtration_grid is not None and self._reconversion_grid is not None
        ) or self.individual_grid, "Fit first"

        def todo_x(x):
            return tuple(self.transform1(x_axis, j) for j, x_axis in enumerate(x))

        ## out shape num_x, num_axis, degree, sm
        out = tuple(
            Parallel(n_jobs=self.n_jobs, backend="threading")(
                delayed(todo_x)(x) for x in X
            )
        )
        # out = Parallel(n_jobs=self.n_jobs, backend="threading")(
        #     delayed(self.transform1)(to_st, thread_id=str(thread_id))
        #     for thread_id, to_st in tqdm(
        #         enumerate(X),
        #         disable=not self.progress,
        #         desc="Computing signed measure decompositions",
        #     )
        # )
        return out

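Note (not part of the packaged file): a minimal usage sketch for the transformer above. It assumes `complexes` is a list of multi-parameter filtered complexes (e.g. SimplexTreeMulti objects) built elsewhere with multipers; the extra list wrapping produces the expected (data) x (axis) input shape, and the parameter values are only illustrative.

    import multipers.ml.signed_measures as mms

    sm_transformer = mms.FilteredComplex2SignedMeasure(
        degrees=[0, 1],           # Hilbert-function signed measures in homology degrees 0 and 1
        grid_strategy="regular",  # coarsen filtration values onto a regular grid
        resolution=50,
        expand=True,              # expand the simplex tree so degree-1 homology is correct for flag complexes
    )
    # input shape: (data) x (axis=1); each output entry is (degree) x (points, weights)
    signed_measures = sm_transformer.fit_transform([[st] for st in complexes])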
class SimplexTree2SignedMeasure(FilteredComplex2SignedMeasure):
    def __init__(
        self,
        # homological degrees + None for euler
        degrees: list[int | None] = [],
        rank_degrees: list[int] = [],  # same for rank invariant
        filtration_grid: (
            Sequence[Sequence[np.ndarray]]
            # filtration values to consider. Format : [ filtration values of Fi for Fi:filtration values of parameter i]
            | None
        ) = None,
        progress=False,  # tqdm
        num_collapses: int | str = 0,  # edge collapses before computing
        n_jobs=None,
        resolution: (
            Iterable[int] | int | None
        ) = None,  # when filtration grid is not given, the resolution of the filtration grid to infer
        # sparse=True,  # sparse output # DEPRECATED TO Signed measure formatter
        plot: bool = False,
        filtration_quantile: float = 0.0,  # quantile for inferring filtration grid
        # whether or not to do the möbius inversion (not recommended to touch)
        # _möbius_inversion: bool = True,
        expand=False,  # expand the simplextree before computing the homology
        normalize_filtrations: bool = False,
        # exact_computation: bool = False,  # compute the exact signed measure.
        grid_strategy: str = "exact",
        seed: int = 0,  # if fit_fraction is not 1, the seed sampling
        fit_fraction=1,  # the fraction of the data on which to fit
        out_resolution: Iterable[int] | int | None = None,
        individual_grid: Optional[
            bool
        ] = None,  # Can be significantly faster for some grid strategies, but can drop statistical performance
        enforce_null_mass: bool = False,
        flatten=True,
        backend: Optional[str] = None,
    ):
        stuff = locals()
        stuff.pop("self")
        keys = list(stuff.keys())
        for key in keys:
            if key.startswith("__"):
                stuff.pop(key)
        super().__init__(**stuff)
        from warnings import warn

        warn("This class is deprecated, use FilteredComplex2SignedMeasure instead.")


# class SimplexTrees2SignedMeasures(SimplexTree2SignedMeasure):
#     """
#     Input
#     -----
#
#     (data) x (axis, e.g. different bandwidths for simplextrees) x (simplextree)
#
#     Output
#     ------
#     (data) x (axis) x (degree) x (signed measure)
#     """
#
#     def __init__(self, **kwargs):
#         super().__init__(**kwargs)
#         self._num_st_per_data = None
#         # self._super_model = SimplexTree2SignedMeasure(**kwargs)
#         self._filtration_grids = None
#         return
#
#     def fit(self, X, y=None):
#         if len(X) == 0:
#             return self
#         try:
#             self._num_st_per_data = len(X[0])
#         except:
#             raise Exception(
#                 "Shape has to be (num_data, num_axis), dtype=SimplexTreeMulti"
#             )
#         self._filtration_grids = []
#         for axis in range(self._num_st_per_data):
#             self._filtration_grids.append(
#                 super().fit([x[axis] for x in X]).filtration_grid
#             )
#             # self._super_fits.append(truc)
#         # self._super_fits_params = [super().fit([x[axis] for x in X]).get_params() for axis in range(self._num_st_per_data)]
#         return self
#
#     def transform(self, X):
#         if self.normalize_filtrations:
#             _reconversion_grids = [
#                 [np.linspace(0, 1, num=len(f), dtype=float) for f in F]
#                 for F in self._filtration_grids
#             ]
#         else:
#             _reconversion_grids = self._filtration_grids
#
#         def todo(x):
#             # return [SimplexTree2SignedMeasure().set_params(**transformer_params).transform1(x[axis]) for axis, transformer_params in enumerate(self._super_fits_params)]
#             out = [
#                 self.transform1(
#                     x[axis],
#                     filtration_grid=filtration_grid,
#                     _reconversion_grid=_reconversion_grid,
#                 )
#                 for axis, filtration_grid, _reconversion_grid in zip(
#                     range(self._num_st_per_data),
#                     self._filtration_grids,
#                     _reconversion_grids,
#                 )
#             ]
#             return out
#
#         return Parallel(n_jobs=self.n_jobs, backend="threading")(
#             delayed(todo)(x)
#             for x in tqdm(
#                 X,
#                 disable=not self.progress,
#                 desc="Computing Signed Measures from simplextrees.",
#             )
#         )

def rescale_sparse_signed_measure(
    signed_measure, filtration_weights, normalize_scales=None
):
    # from copy import deepcopy
    #
    # out = deepcopy(signed_measure)

    if filtration_weights is None and normalize_scales is None:
        return signed_measure

    # if normalize_scales is None:
    #     out = tuple(
    #         (
    #             _cat(
    #                 tuple(
    #                     signed_measure[degree][0][:, parameter]
    #                     * filtration_weights[parameter]
    #                     for parameter in range(num_parameters)
    #                 ),
    #                 axis=1,
    #             ),
    #             signed_measure[degree][1],
    #         )
    #         for degree in range(len(signed_measure))
    #     )
    #     for degree in range(len(signed_measure)):  # degree
    #         for parameter in range(len(filtration_weights)):
    #             signed_measure[degree][0][:, parameter] *= filtration_weights[parameter]
    # # TODO Broadcast w.r.t. the parameter
    #     out = tuple(
    #         _cat(
    #             tuple(
    #                 signed_measure[degree][0][:, [parameter]]
    #                 * filtration_weights[parameter]
    #                 / (
    #                     normalize_scales[degree][parameter]
    #                     if normalize_scales is not None
    #                     else 1
    #                 )
    #                 for parameter in range(num_parameters)
    #             ),
    #             axis=1,
    #         )
    #         for degree in range(len(signed_measure))
    #     )
    out = tuple(
        (
            signed_measure[degree][0]
            * (1 if filtration_weights is None else filtration_weights.reshape(1, -1))
            / (
                normalize_scales[degree].reshape(1, -1)
                if normalize_scales is not None
                else 1
            ),
            signed_measure[degree][1],
        )
        for degree in range(len(signed_measure))
    )
    # for degree in range(len(out)):
    #     for parameter in range(len(filtration_weights)):
    #         out[degree][0][:, parameter] *= (
    #             filtration_weights[parameter] / normalize_scales[degree][parameter]
    #         )
    return out

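Note (not part of the packaged file): a small self-contained example of what rescale_sparse_signed_measure computes, with explicit numpy inputs; the values are arbitrary.

    import numpy as np
    from multipers.ml.signed_measures import rescale_sparse_signed_measure

    pts = np.array([[0.5, 2.0], [1.0, 4.0]])   # dirac positions, shape (n, num_parameters)
    weights = np.array([1, -1])                # dirac signs
    sm = [(pts, weights)]                      # a single homology degree
    scales = [np.array([0.5, 2.0])]            # per-parameter scale, one entry per degree

    rescaled = rescale_sparse_signed_measure(sm, filtration_weights=None, normalize_scales=scales)
    # positions are divided parameter-wise -> [[1.0, 1.0], [2.0, 2.0]]; weights are returned unchanged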
def sm2deep(signed_measure):
    dirac_positions, dirac_signs = signed_measure
    dtype = dirac_positions.dtype
    new_shape = list(dirac_positions.shape)
    new_shape[1] += 1
    if isinstance(dirac_positions, np.ndarray):
        c = np.empty(new_shape, dtype=dtype)
        c[:, :-1] = dirac_positions
        c[:, -1] = dirac_signs

    else:
        import torch

        c = torch.empty(new_shape, dtype=dtype)
        c[:, :-1] = dirac_positions
        if isinstance(dirac_signs, np.ndarray):
            dirac_signs = torch.from_numpy(dirac_signs)
        c[:, -1] = dirac_signs
    return c

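Note (not part of the packaged file): sm2deep simply appends the dirac signs as an extra coordinate, e.g.

    import numpy as np
    from multipers.ml.signed_measures import sm2deep

    pts = np.array([[0.1, 0.2], [0.3, 0.4]])
    signs = np.array([1, -1])
    deep = sm2deep((pts, signs))
    # deep is the (2, 3) array [[0.1, 0.2, 1.0], [0.3, 0.4, -1.0]]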
class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
    """
    Input
    -----

    (data) x (degree) x (signed measure) or (data) x (axis) x (degree) x (signed measure)

    Iterable[list[signed_measure_matrix of degree]] or Iterable[previous].

    The second is meant to use multiple choices for signed measure input. An example of usage : they come from a Rips + Density with different bandwidth.
    It is controlled by the axis parameter.

    Output
    ------

    Iterable[list[(reweighted)_sparse_signed_measure of degree]]

    or (deep format)

    Tensor of shape (num_axis*num_degrees, data, max_num_pts, num_parameters)
    """

    def __init__(
        self,
        filtrations_weights: Optional[Iterable[float]] = None,
        normalize=False,
        plot: bool = False,
        unsparse: bool = False,
        axis: int = -1,
        resolution: int | Iterable[int] = 50,
        flatten: bool = False,
        deep_format: bool = False,
        unrag: bool = True,
        n_jobs: int = 1,
        verbose: bool = False,
        integrate: bool = False,
        grid_strategy="regular",
    ):
        super().__init__()
        self.filtrations_weights = filtrations_weights
        self.num_parameters: int = 0
        self.plot = plot
        self.unsparse = unsparse
        self.n_jobs = n_jobs
        self.axis = axis
        self._num_axis = 0
        self.resolution = resolution
        self._filtrations_bounds = None
        self.flatten = flatten
        self.normalize = normalize
        self._normalization_factors = None
        self.deep_format = deep_format
        self.unrag = unrag
        assert (
            not self.deep_format or not self.unsparse or not self.integrate
        ), "One post processing at the time."
        self.verbose = verbose
        self._num_degrees = 0
        self.integrate = integrate
        self.grid_strategy = grid_strategy
        self._infered_grids = None
        self._axis_iterator = None
        self._backend = None
        return

    def _get_filtration_bounds(self, X, axis):
        stuff = [
            self._backend.cat(
                [sm[axis][degree][0] for sm in X],
                axis=0,
            )
            for degree in range(self._num_degrees)
        ]
        sizes_ = np.array([len(x) == 0 for x in stuff])
        assert np.all(~sizes_), f"Degree axis {np.where(sizes_)} is/are trivial !"

        filtrations_bounds = self._backend.asnumpy(
            self._backend.stack(
                [
                    self._backend.stack(
                        [
                            self._backend.minvalues(f, axis=0),
                            self._backend.maxvalues(f, axis=0),
                        ]
                    )
                    for f in stuff
                ]
            )
        )  ## don't want to rescale gradient of normalization
        normalization_factors = (
            filtrations_bounds[:, 1] - filtrations_bounds[:, 0]
            if self.normalize
            else None
        )
        # print("Normalization factors : ", self._normalization_factors)
        if (normalization_factors == 0).any():
            indices = normalization_factors == 0
            # warn(f"Constant filtration encountered, at degree, parameter {indices} and axis {self.axis}.")
            normalization_factors[indices] = 1
        return filtrations_bounds, normalization_factors

    def _plot_signed_measures(self, sms: Iterable[np.ndarray], size=4):
        from multipers.plots import plot_signed_measure

        num_degrees = len(sms[0])
        num_imgs = len(sms)
        fig, axes = plt.subplots(
            ncols=num_degrees,
            nrows=num_imgs,
            figsize=(size * num_degrees, size * num_imgs),
        )
        axes = np.asarray(axes).reshape(num_imgs, num_degrees)
        # assert axes.ndim == 2, "Internal error"
        for i, sm in enumerate(sms):
            for j, sm_of_degree in enumerate(sm):
                plot_signed_measure(sm_of_degree, ax=axes[i, j])

    @staticmethod
    def _check_sm(sm) -> bool:
        return (
            isinstance(sm, tuple)
            and hasattr(sm[0], "ndim")
            and sm[0].ndim == 2
            and len(sm) == 2
        )

    def _check_axis(self, X):
        # axes should be (num_data, num_axis, num_degrees, (signed_measure))
        if len(X) == 0:
            return
        if len(X[0]) == 0:
            return
        if self._check_sm(X[0][0]):
            self._has_axis = False
            self._num_axis = 1
            self._axis_iterator = [slice(None)]
            return
        assert self._check_sm(  ## vaguely checks that its a signed measure
            _sm := X[0][0][0]
        ), f"Cannot take this input. # data, axis, degrees, sm.\n Got {_sm} of type {type(_sm)}"

        self._has_axis = True
        self._num_axis = len(X[0])
        self._axis_iterator = range(self._num_axis) if self.axis == -1 else [self.axis]

    def _check_backend(self, X):
        if self._has_axis:
            # data, axis, degrees, (pts, weights)
            first_sm = X[0][0][0][0]
        else:
            first_sm = X[0][0][0]
        self._backend = api_from_tensor(first_sm, verbose=self.verbose)

    def _check_measures(self, X):
        if self._has_axis:
            first_sm = X[0][0]
        else:
            first_sm = X[0]
        self._num_degrees = len(first_sm)
        self.num_parameters = first_sm[0][0].shape[1]

    def _check_resolution(self):
        assert self.num_parameters > 0, "Num parameters hasn't been initialized."
        if isinstance(self.resolution, int):
            self.resolution = [self.resolution] * self.num_parameters
        self.resolution = np.asarray(self.resolution, dtype=int)
        assert (
            self.resolution.shape[0] == self.num_parameters
        ), "Resolution doesn't have a proper size."

    def _check_weights(self):
        if self.filtrations_weights is None:
            return
        assert (
            self.filtrations_weights.shape[0] == self.num_parameters
        ), "Filtration weights don't have a proper size"

    def _infer_grids(self, X):
        # Computes normalization factors
        if self.normalize:
            # if self._has_axis and self.axis == -1:
            self._filtrations_bounds = []
            self._normalization_factors = []
            for ax in self._axis_iterator:
                (
                    filtration_bounds,
                    normalization_factors,
                ) = self._get_filtration_bounds(X, axis=ax)
                self._filtrations_bounds.append(filtration_bounds)
                self._normalization_factors.append(normalization_factors)
            self._filtrations_bounds = self._backend.astensor(self._filtrations_bounds)
            self._normalization_factors = self._backend.astensor(self._normalization_factors)
            # else:
            #     (
            #         self._filtrations_bounds,
            #         self._normalization_factors,
            #     ) = self._get_filtration_bounds(
            #         X, axis=self._axis_iterator[0]
            #     )  ## axis = slice(None)
        elif self.integrate or self.unsparse or self.deep_format:
            filtration_values = [
                np.concatenate(
                    [
                        (
                            stuff
                            if isinstance(stuff := x[ax][degree][0], np.ndarray)
                            else stuff.detach().numpy()
                        )
                        for x in X
                        for degree in range(self._num_degrees)
                    ]
                )
                for ax in self._axis_iterator
            ]
            # axis, filtration_values
            filtration_values = [
                self._backend.astensor(compute_grid(
                    f_ax.T, resolution=self.resolution, strategy=self.grid_strategy
                ))
                for f_ax in filtration_values
            ]
            self._infered_grids = filtration_values

    def _print_stats(self, X):
        print("------------SignedMeasureFormatter------------")
        print("---- Parameters")
        print(f"Number of axis : {self._num_axis}")
        print(f"Number of degrees : {self._num_degrees}")
        print(f"Filtration bounds : \n{self._filtrations_bounds}")
        print(f"Normalization factor : \n{self._normalization_factors}")
        if self._infered_grids is not None:
            print(
                f"Filtration grid shape : \n \
                {tuple(tuple(len(f) for f in F) for F in self._infered_grids)}"
            )
        print("---- SM stats")
        print("In axis :", self._num_axis)
        sizes = [
            [[len(xd[1]) for xd in x[ax]] for x in X] for ax in self._axis_iterator
        ]
        print(f"Size means (axis) x (degree): {np.mean(sizes, axis=(1))}")
        print(f"Size std : {np.std(sizes, axis=(1))}")
        print("----------------------------------------------")

    def fit(self, X, y=None):
        # Gets a grid. This will be the max in each coord+1
        if (
            len(X) == 0
            or len(X[0]) == 0
            or (self.axis is not None and len(X[0][0][0]) == 0)
        ):
            return self

        self._check_axis(X)
        self._check_backend(X)
        self._check_measures(X)
        self._check_resolution()
        self._check_weights()
        # if not sparse : not recommended.

        self._infer_grids(X)
        if self.verbose:
            self._print_stats(X)
        return self

    def unsparse_signed_measure(self, sparse_signed_measure):
        filtrations = self._infered_grids  # ax, filtration
        out = []
        for filtrations_of_ax, ax in zip(filtrations, self._axis_iterator, strict=True):
            sparse_signed_measure_of_ax = sparse_signed_measure[ax]
            measure_of_ax = []
            for pts, weights in sparse_signed_measure_of_ax:  # over degree
                signed_measure, _ = np.histogramdd(
                    pts, bins=filtrations_of_ax, weights=weights
                )
                if self.flatten:
                    signed_measure = signed_measure.flatten()
                measure_of_ax.append(signed_measure)
            out.append(np.asarray(measure_of_ax))

        if self.flatten:
            out = np.concatenate(out).flatten()
        elif self.axis == -1:
            return np.asarray(out)
        else:
            return np.asarray(out)[0]

    @staticmethod
    def _integrate_measure(sm, filtrations):
        from multipers.point_measure import integrate_measure

        return integrate_measure(sm[0], sm[1], filtrations)

    def _rescale_measures(self, X):
        def rescale_from_sparse(sparse_signed_measure):
            if self.axis == -1 and self._has_axis:
                return tuple(
                    rescale_sparse_signed_measure(
                        sparse_signed_measure[ax],
                        filtration_weights=self.filtrations_weights,
                        normalize_scales=n,
                    )
                    for ax, n in zip(
                        self._axis_iterator, self._normalization_factors, strict=True
                    )
                )
            return rescale_sparse_signed_measure(  ## axis iterator is of size 1 here
                sparse_signed_measure,
                filtration_weights=self.filtrations_weights,
                normalize_scales=self._normalization_factors[0],
            )

        out = tuple(rescale_from_sparse(x) for x in X)
        return out

    def transform(self, X):
        if not self._has_axis or self.axis == -1:
            out = X
        else:
            out = tuple(x[self.axis] for x in X)
        # same format for everyone

        if self._normalization_factors is not None:
            out = self._rescale_measures(out)

        if self.plot:
            # assert ax != -1, "Not implemented"
            self._plot_signed_measures(out)
        if self.integrate:
            filtrations = self._infered_grids
            # if self.axis != -1:
            ax = 0  # if self.axis is None else self.axis  # TODO deal with axis -1

            assert ax != -1, "Not implemented. Can only integrate with axis"
            # try:
            out = np.asarray(
                [
                    [
                        self._integrate_measure(x[degree], filtrations=filtrations[ax])
                        for degree in range(self._num_degrees)
                    ]
                    for x in out
                ]
            )
            # except:
            #     print(self.axis, ax, filtrations)
            if self.flatten:
                out = out.reshape((len(X), -1))
            # else:
            #     out = [[[self._integrate_measure(x[axis][degree], filtrations=filtrations[degree].T) for degree in range(self._num_degrees)] for axis in range(self._num_axis)] for x in out]
        elif self.unsparse:
            out = [self.unsparse_signed_measure(x) for x in out]
        elif self.deep_format:
            num_degrees = self._num_degrees
            out = tuple(
                tuple(sm2deep(sm[axis][degree]) for sm in out)
                for degree in range(num_degrees)
                for axis in self._axis_iterator
            )
            if self.unrag:
                max_num_pts = np.max(
                    [sm.shape[0] for sm_of_axis in out for sm in sm_of_axis]
                )
                num_axis_degree = len(out)
                num_data = len(out[0])
                assert num_axis_degree == num_degrees * (
                    self._num_axis if self._has_axis else 1
                ), f"Bad axis/degree count. Got {num_axis_degree} (Internal error)"
                num_parameters = out[0][0].shape[1]
                dtype = out[0][0].dtype
                unragged_tensor = self._backend.zeros(
                    (
                        num_axis_degree,
                        num_data,
                        max_num_pts,
                        num_parameters,
                    ),
                    dtype=dtype,
                )
                for ax in range(num_axis_degree):
                    for data in range(num_data):
                        sm = out[ax][data]
                        a, b = sm.shape
                        unragged_tensor[ax, data, :a, :b] = sm
                out = unragged_tensor
        return out

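Note (not part of the packaged file): a sketch of the "deep format" path of SignedMeasureFormatter, assuming `signed_measures` is the (data) x (axis) x (degree) x (signed measure) output of FilteredComplex2SignedMeasure shown earlier.

    from multipers.ml.signed_measures import SignedMeasureFormatter

    formatter = SignedMeasureFormatter(
        normalize=True,    # rescale each filtration parameter by its observed range
        deep_format=True,  # one tensor slice per (axis, degree) pair
        unrag=True,        # zero-pad so all measures stack into a single tensor
    )
    deep_tensor = formatter.fit_transform(signed_measures)
    # deep_tensor shape: (num_axis * num_degrees, num_data, max_num_pts, num_parameters + 1),
    # where the last column of each point is its dirac sign (see sm2deep above)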
class SignedMeasure2Convolution(BaseEstimator, TransformerMixin):
|
|
960
|
+
"""
|
|
961
|
+
Discrete convolution of a signed measure
|
|
962
|
+
|
|
963
|
+
Input
|
|
964
|
+
-----
|
|
965
|
+
|
|
966
|
+
(data) x (degree) x (signed measure)
|
|
967
|
+
|
|
968
|
+
Parameters
|
|
969
|
+
----------
|
|
970
|
+
- filtration_grid : Iterable[array] For each filtration, the filtration values on which to evaluate the grid
|
|
971
|
+
- resolution : int or (num_parameters) : If filtration grid is not given, will infer a grid, with this resolution
|
|
972
|
+
- grid_strategy : the strategy to generate the grid. Available ones are regular, quantile, exact
|
|
973
|
+
- flatten : if true, the output will be flattened
|
|
974
|
+
- kernel : kernel to used to convolve the images.
|
|
975
|
+
- flatten : flatten the images if True
|
|
976
|
+
- progress : progress bar if True
|
|
977
|
+
- backend : sklearn, pykeops or numba.
|
|
978
|
+
- plot : Creates a plot Figure.
|
|
979
|
+
|
|
980
|
+
Output
|
|
981
|
+
------
|
|
982
|
+
|
|
983
|
+
(data) x (concatenation of imgs of degree)
|
|
984
|
+
"""
|
|
985
|
+
|
|
986
|
+
def __init__(
|
|
987
|
+
self,
|
|
988
|
+
filtration_grid: Iterable[np.ndarray] = None,
|
|
989
|
+
kernel: available_kernels = "gaussian",
|
|
990
|
+
bandwidth: float | Iterable[float] = 1.0,
|
|
991
|
+
flatten: bool = False,
|
|
992
|
+
n_jobs: int = 1,
|
|
993
|
+
resolution: int | None = None,
|
|
994
|
+
grid_strategy: str = "regular",
|
|
995
|
+
progress: bool = False,
|
|
996
|
+
backend: str = "pykeops",
|
|
997
|
+
plot: bool = False,
|
|
998
|
+
log_density: bool = False,
|
|
999
|
+
**kde_kwargs,
|
|
1000
|
+
# **kwargs ## DANGEROUS
|
|
1001
|
+
):
|
|
1002
|
+
super().__init__()
|
|
1003
|
+
self.kernel: available_kernels = kernel
|
|
1004
|
+
self.bandwidth = bandwidth
|
|
1005
|
+
# self.more_kde_kwargs=kwargs
|
|
1006
|
+
self.filtration_grid = filtration_grid
|
|
1007
|
+
self.flatten = flatten
|
|
1008
|
+
self.progress = progress
|
|
1009
|
+
self.n_jobs = n_jobs
|
|
1010
|
+
self.resolution = resolution
|
|
1011
|
+
self.grid_strategy = grid_strategy
|
|
1012
|
+
self._is_input_sparse = None
|
|
1013
|
+
self._refit = filtration_grid is None
|
|
1014
|
+
self._input_resolution = None
|
|
1015
|
+
self._bandwidths = None
|
|
1016
|
+
self.diameter = None
|
|
1017
|
+
self.backend = backend
|
|
1018
|
+
self.plot = plot
|
|
1019
|
+
self.log_density = log_density
|
|
1020
|
+
self.kde_kwargs = kde_kwargs
|
|
1021
|
+
self._api = None
|
|
1022
|
+
return
|
|
1023
|
+
|
|
1024
|
+
def fit(self, X, y=None):
|
|
1025
|
+
# Infers if the input is sparse given X
|
|
1026
|
+
if len(X) == 0:
|
|
1027
|
+
return self
|
|
1028
|
+
if isinstance(X[0][0], tuple):
|
|
1029
|
+
self._is_input_sparse = True
|
|
1030
|
+
|
|
1031
|
+
self._api = api_from_tensor(X[0][0][0], verbose=self.progress)
|
|
1032
|
+
else:
|
|
1033
|
+
self._is_input_sparse = False
|
|
1034
|
+
|
|
1035
|
+
self._api = api_from_tensor(X, verbose=self.progress)
|
|
1036
|
+
# print(f"IMG output is set to {'sparse' if self.sparse else 'matrix'}")
|
|
1037
|
+
if not self._is_input_sparse:
|
|
1038
|
+
self._input_resolution = X[0][0].shape
|
|
1039
|
+
try:
|
|
1040
|
+
b = float(self.bandwidth)
|
|
1041
|
+
self._bandwidths = [
|
|
1042
|
+
b if b > 0 else -b * s for s in self._input_resolution
|
|
1043
|
+
]
|
|
1044
|
+
except:
|
|
1045
|
+
self._bandwidths = [
|
|
1046
|
+
b if b > 0 else -b * s
|
|
1047
|
+
for s, b in zip(self._input_resolution, self.bandwidth)
|
|
1048
|
+
]
|
|
1049
|
+
return self # in that case, singed measures are matrices, and the grid is already given
|
|
1050
|
+
|
|
1051
|
+
if self.filtration_grid is None and self.resolution is None:
|
|
1052
|
+
raise Exception(
|
|
1053
|
+
"Cannot infer filtration grid. Provide either a filtration grid or a resolution."
|
|
1054
|
+
)
|
|
1055
|
+
# If not sparse : a grid has to be defined
|
|
1056
|
+
if self._refit:
|
|
1057
|
+
# print("Fitting a grid...", end="")
|
|
1058
|
+
pts = self._api.cat(
|
|
1059
|
+
[sm[0] for signed_measures in X for sm in signed_measures]
|
|
1060
|
+
).T
|
|
1061
|
+
self.filtration_grid = compute_grid(
|
|
1062
|
+
pts,
|
|
1063
|
+
strategy=self.grid_strategy,
|
|
1064
|
+
resolution=self.resolution,
|
|
1065
|
+
)
|
|
1066
|
+
# print('Done.')
|
|
1067
|
+
if self.filtration_grid is not None:
|
|
1068
|
+
self.diameter = self._api.norm(
|
|
1069
|
+
self._api.astensor([f[-1] - f[0] for f in self.filtration_grid])
|
|
1070
|
+
)
|
|
1071
|
+
if self.progress:
|
|
1072
|
+
print(f"Computed a diameter of {self.diameter}")
|
|
1073
|
+
return self
|
|
1074
|
+
|
|
1075
|
+
def _sm2smi(self, signed_measures):
|
|
1076
|
+
# print(self._input_resolution, self.bandwidths, _bandwidths)
|
|
1077
|
+
from scipy.ndimage import gaussian_filter
|
|
1078
|
+
|
|
1079
|
+
return np.concatenate(
|
|
1080
|
+
[
|
|
1081
|
+
gaussian_filter(
|
|
1082
|
+
input=signed_measure,
|
|
1083
|
+
sigma=self._bandwidths,
|
|
1084
|
+
mode="constant",
|
|
1085
|
+
cval=0,
|
|
1086
|
+
)
|
|
1087
|
+
for signed_measure in signed_measures
|
|
1088
|
+
],
|
|
1089
|
+
axis=0,
|
|
1090
|
+
)
|
|
1091
|
+
|
|
1092
|
+
def _transform_from_sparse(self, X):
|
|
1093
|
+
bandwidth = (
|
|
1094
|
+
self.bandwidth if self.bandwidth > 0 else -self.bandwidth * self.diameter
|
|
1095
|
+
)
|
|
1096
|
+
# COMPILE KEOPS FIRST
|
|
1097
|
+
dummyx = [X[0]]
|
|
1098
|
+
dummyf = [f[:2] for f in self.filtration_grid]
|
|
1099
|
+
convolution_signed_measures(
|
|
1100
|
+
dummyx,
|
|
1101
|
+
filtrations=dummyf,
|
|
1102
|
+
bandwidth=bandwidth,
|
|
1103
|
+
flatten=self.flatten,
|
|
1104
|
+
n_jobs=1,
|
|
1105
|
+
kernel=self.kernel,
|
|
1106
|
+
backend=self.backend,
|
|
1107
|
+
)
|
|
1108
|
+
|
|
1109
|
+
return convolution_signed_measures(
|
|
1110
|
+
X,
|
|
1111
|
+
filtrations=self.filtration_grid,
|
|
1112
|
+
bandwidth=bandwidth,
|
|
1113
|
+
flatten=self.flatten,
|
|
1114
|
+
n_jobs=self.n_jobs,
|
|
1115
|
+
kernel=self.kernel,
|
|
1116
|
+
backend=self.backend,
|
|
1117
|
+
**self.kde_kwargs,
|
|
1118
|
+
)
|
|
1119
|
+
|
|
1120
|
+
def _plot_imgs(self, imgs: Iterable[np.ndarray], size=4):
|
|
1121
|
+
from multipers.plots import plot_surface
|
|
1122
|
+
|
|
1123
|
+
imgs = self._api.asnumpy(imgs)
|
|
1124
|
+
num_degrees = imgs[0].shape[0]
|
|
1125
|
+
num_imgs = len(imgs)
|
|
1126
|
+
fig, axes = plt.subplots(
|
|
1127
|
+
ncols=num_degrees,
|
|
1128
|
+
nrows=num_imgs,
|
|
1129
|
+
figsize=(size * num_degrees, size * num_imgs),
|
|
1130
|
+
)
|
|
1131
|
+
axes = np.asarray(axes).reshape(num_imgs, num_degrees)
|
|
1132
|
+
# assert axes.ndim==2, "Internal error"
|
|
1133
|
+
for i, img in enumerate(imgs):
|
|
1134
|
+
for j, img_of_degree in enumerate(img):
|
|
1135
|
+
plot_surface(
|
|
1136
|
+
[self._api.asnumpy(f) for f in self.filtration_grid],
|
|
1137
|
+
img_of_degree,
|
|
1138
|
+
ax=axes[i, j],
|
|
1139
|
+
cmap="Spectral",
|
|
1140
|
+
)
|
|
1141
|
+
|
|
+    def transform(self, X):
+        if self._is_input_sparse is None:
+            raise Exception("Fit first")
+        if self._is_input_sparse:
+            out = self._transform_from_sparse(X)
+        else:
+            todo = SignedMeasure2Convolution._sm2smi
+            out = Parallel(n_jobs=self.n_jobs, backend="threading")(
+                delayed(todo)(self, signed_measures)
+                for signed_measures in tqdm(
+                    X, desc="Computing images", disable=not self.progress
+                )
+            )
+            out = self._api.cat([x[None] for x in out])
+        if self.plot and not self.flatten:
+            if self.progress:
+                print("Plotting convolutions...", end="")
+            self._plot_imgs(out)
+            if self.progress:
+                print("Done !")
+        if self.flatten and not self._is_input_sparse:
+            out = self._api.cat([x.ravel()[None] for x in out])
+        return out
+
+
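For orientation, a minimal usage sketch of the convolution transformer above. The input layout and the constructor arguments `bandwidth` and `flatten` are assumptions inferred from the attributes used in `fit`/`transform`, not a documented signature.

    import numpy as np

    # A toy datum: one signed measure per degree, each a (points, weights) pair
    # with points living in a 2-parameter filtration space.
    def toy_sm(n):
        return (np.random.rand(n, 2), np.random.choice([-1, 1], size=n))

    X = [[toy_sm(30), toy_sm(10)]]  # (data) x (degree) x (signed measure)

    smc = SignedMeasure2Convolution(bandwidth=0.1, flatten=True)  # assumed arguments
    imgs = smc.fit(X).transform(X)  # one (flattened) image per datum and degree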
+class SignedMeasure2SlicedWassersteinDistance(BaseEstimator, TransformerMixin):
+    """
+    Transformer from signed measures to distance matrices (one per degree).
+
+    Input
+    -----
+
+    (data) x (degree) x (signed measure)
+
+    Format
+    ------
+    - a signed measure : a tuple of arrays, point positions of shape (npts, num_parameters) and weights of shape (npts,)
+    - each datum is a list of signed measures (e.g., one per degree)
+
+    Output
+    ------
+    - (degree) x (distance matrix)
+    """
+
+    def __init__(
+        self,
+        n_jobs=None,
+        num_directions: int = 10,
+        _sliced: bool = True,
+        epsilon=-1,
+        ground_norm=1,
+        progress=False,
+        grid_reconversion=None,
+        scales=None,
+    ):
+        super().__init__()
+        self.n_jobs = n_jobs
+        self._SWD_list = None
+        self._sliced = _sliced
+        self.epsilon = epsilon
+        self.ground_norm = ground_norm
+        self.num_directions = num_directions
+        self.progress = progress
+        self.grid_reconversion = grid_reconversion
+        self.scales = scales
+        return
+
+    def fit(self, X, y=None):
+        from multipers.ml.sliced_wasserstein import (
+            SlicedWassersteinDistance,
+            WassersteinDistance,
+        )
+
+        # _DISTANCE = lambda : SlicedWassersteinDistance(num_directions=self.num_directions) if self._sliced else WassersteinDistance(epsilon=self.epsilon, ground_norm=self.ground_norm) # WARNING if _sliced is false, this distance is not CNSD
+        if len(X) == 0:
+            return self
+        num_degrees = len(X[0])
+        # One distance estimator per degree, fitted on the signed measures of that degree.
+        self._SWD_list = [
+            (
+                SlicedWassersteinDistance(
+                    num_directions=self.num_directions,
+                    n_jobs=self.n_jobs,
+                    scales=self.scales,
+                )
+                if self._sliced
+                else WassersteinDistance(
+                    epsilon=self.epsilon,
+                    ground_norm=self.ground_norm,
+                    n_jobs=self.n_jobs,
+                )
+            )
+            for _ in range(num_degrees)
+        ]
+        for degree, swd in enumerate(self._SWD_list):
+            signed_measures_of_degree = [x[degree] for x in X]
+            swd.fit(signed_measures_of_degree)
+        return self
+
+    def transform(self, X):
+        assert self._SWD_list is not None, "Fit first"
+        # out = []
+        # for degree, swd in tqdm(enumerate(self._SWD_list), desc="Computing distance matrices", total=len(self._SWD_list), disable= not self.progress):
+        with tqdm(
+            enumerate(self._SWD_list),
+            desc="Computing distance matrices",
+            total=len(self._SWD_list),
+            disable=not self.progress,
+        ) as SWD_it:
+            # signed_measures_of_degree = [x[degree] for x in X]
+            # out.append(swd.transform(signed_measures_of_degree))
+            def todo(swd, X_of_degree):
+                return swd.transform(X_of_degree)
+
+            out = Parallel(n_jobs=self.n_jobs, prefer="threads")(
+                delayed(todo)(swd, [x[degree] for x in X]) for degree, swd in SWD_it
+            )
+        return np.asarray(out)
+
+    def predict(self, X):
+        return self.transform(X)
+
+
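A minimal sketch of how this transformer is driven (the toy measures and the `num_directions` value are illustrative; shapes follow the docstring above):

    import numpy as np

    def toy_sm(n):  # a signed measure: (points, weights)
        return (np.random.rand(n, 2), np.random.choice([-1, 1], size=n))

    X = [[toy_sm(20), toy_sm(5)], [toy_sm(18), toy_sm(7)]]  # 2 data, degrees 0 and 1

    swd = SignedMeasure2SlicedWassersteinDistance(num_directions=10)
    D = swd.fit(X).transform(X)  # expected shape: (num_degrees, len(X), len(X))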
+class SignedMeasures2SlicedWassersteinDistances(BaseEstimator, TransformerMixin):
+    """
+    Transformer from signed measures to distance matrices (one per axis and degree).
+    Input
+    -----
+    (data) x opt (axis) x (degree) x (signed measure)
+
+    Format
+    ------
+    - a signed measure : a tuple of arrays, point positions of shape (npts, num_parameters) and weights of shape (npts,)
+    - each datum is a list of signed measures (e.g., one per degree)
+
+    Output
+    ------
+    - (axis) x (degree) x (distance matrix)
+    """
+
+    def __init__(
+        self,
+        progress=False,
+        n_jobs: int = 1,
+        scales: Iterable[Iterable[float]] | None = None,
+        **kwargs,
+    ):  # same init as SignedMeasure2SlicedWassersteinDistance
+        self._init_child = SignedMeasure2SlicedWassersteinDistance(
+            progress=False, scales=None, n_jobs=-1, **kwargs
+        )
+        self._axe_iterator = None
+        self._childs_to_fit = None
+        self.scales = scales
+        self.progress = progress
+        self.n_jobs = n_jobs
+        return
+
+    def fit(self, X, y=None):
+        from sklearn.base import clone
+
+        if len(X) == 0:
+            return self
+        if isinstance(X[0][0], tuple):  # Meaning that there are no axes
+            self._axe_iterator = [slice(None)]
+        else:
+            self._axe_iterator = range(len(X[0]))
+        if self.scales is None:
+            self.scales = [None]
+        else:
+            self.scales = np.asarray(self.scales)
+            if self.scales.ndim == 1:
+                self.scales = np.asarray([self.scales])
+        assert (
+            self.scales[0] is None or self.scales.ndim == 2
+        ), "Scales have to be either None or a list of scales !"
+        # One child transformer per (axis, scales) pair.
+        self._childs_to_fit = [
+            clone(self._init_child).set_params(scales=scales).fit([x[axis] for x in X])
+            for axis, scales in product(self._axe_iterator, self.scales)
+        ]
+        print("New axes : ", list(product(self._axe_iterator, self.scales)))
+        return self
+
+    def transform(self, X):
+        return Parallel(n_jobs=self.n_jobs, prefer="processes")(
+            delayed(self._childs_to_fit[child_id].transform)([x[axis] for x in X])
+            for child_id, (axis, _) in tqdm(
+                enumerate(product(self._axe_iterator, self.scales)),
+                desc="Computing distance matrices over axes and scales",
+                disable=not self.progress,
+                total=len(self._childs_to_fit),
+            )
+        )
+        # [
+        #     child.transform([x[axis // len(self.scales)] for x in X])
+        #     for axis, child in tqdm(enumerate(self._childs_to_fit),
+        #         desc=f"Computing distances of axis", disable=not self.progress, total=len(self._childs_to_fit)
+        #     )
+        # ]
+
+
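The (axis, scales) bookkeeping above expands into one fitted child per pair; a small sketch of that product (values made up):

    from itertools import product

    axes = range(2)    # two axes detected in the input
    scales = [None]    # default when no scales are given
    print(list(product(axes, scales)))  # [(0, None), (1, None)]
    # transform() returns one (num_degrees, len(X), len(X)) block per pair, in this order.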
+class SimplexTree2RectangleDecomposition(BaseEstimator, TransformerMixin):
+    """
+    Transformer. Maps 2-parameter SimplexTrees to their respective rectangle decompositions.
+    """
+
+    def __init__(
+        self,
+        filtration_grid: np.ndarray,
+        degrees: Iterable[int],
+        plot=False,
+        reconvert_grid=True,
+        num_collapses: int = 0,
+    ):
+        super().__init__()
+        self.filtration_grid = filtration_grid
+        self.degrees = degrees
+        self.plot = plot
+        self.reconvert_grid = reconvert_grid
+        self.num_collapses = num_collapses
+        return
+
+    def fit(self, X, y=None):
+        """
+        TODO : infer grid from multiple simplextrees
+        """
+        return self
+
+    def transform(self, X: Iterable[mp.simplex_tree_multi.SimplexTreeMulti_type]):
+        rectangle_decompositions = [
+            [
+                _st2ranktensor(
+                    simplextree,
+                    filtration_grid=self.filtration_grid,
+                    degree=degree,
+                    plot=self.plot,
+                    reconvert_grid=self.reconvert_grid,
+                    num_collapse=self.num_collapses,
+                )
+                for degree in self.degrees
+            ]
+            for simplextree in X
+        ]
+        # TODO : return iterator ?
+        return rectangle_decompositions
+
+
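A hedged usage sketch for the transformer above (the grid values are made up, and `sts` stands for 2-parameter simplex trees built elsewhere, e.g. with `mp.SimplexTreeMulti`):

    import numpy as np
    import multipers as mp

    sts = [...]  # 2-parameter simplex trees, assumed to be built elsewhere
    grid = [np.linspace(0, 1, 50), np.linspace(0, 1, 50)]  # one array per parameter

    decomposer = SimplexTree2RectangleDecomposition(filtration_grid=grid, degrees=[0, 1])
    rectangles = decomposer.fit_transform(sts)  # (data) x (degree) x (rectangle decomposition)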
+def _st2ranktensor(
+    st: mp.simplex_tree_multi.SimplexTreeMulti_type,
+    filtration_grid: np.ndarray,
+    degree: int,
+    plot: bool,
+    reconvert_grid: bool,
+    num_collapse: int | str = 0,
+):
+    """
+    TODO
+    """
+    # Copy (the squeeze changes the filtration values)
+    # stcpy = mp.SimplexTreeMulti(st)
+    # turn the simplextree into a coordinate simplextree
+    stcpy = st.grid_squeeze(filtration_grid=filtration_grid, coordinate_values=True)
+    # stcpy.collapse_edges(num=100, strong = True, ignore_warning=True)
+    if num_collapse == "full":
+        stcpy.collapse_edges(full=True, ignore_warning=True, max_dimension=degree + 1)
+    elif isinstance(num_collapse, int):
+        stcpy.collapse_edges(
+            num=num_collapse, ignore_warning=True, max_dimension=degree + 1
+        )
+    else:
+        raise TypeError(
+            f"Invalid num_collapse={num_collapse}. Either 'full' or an integer."
+        )
+    # compute the rank invariant tensor
+    rank_tensor = mp.rank_invariant2d(
+        stcpy, degree=degree, grid_shape=[len(f) for f in filtration_grid]
+    )
+    # refactor this tensor into the rectangle decomposition of the signed Betti numbers
+    grid_conversion = filtration_grid if reconvert_grid else None
+    rank_decomposition = rank_decomposition_by_rectangles(
+        rank_tensor,
+        threshold=True,
+    )
+    rectangle_decomposition = tensor_möbius_inversion(
+        tensor=rank_decomposition,
+        grid_conversion=grid_conversion,
+        plot=plot,
+        num_parameters=st.num_parameters,
+    )
+    return rectangle_decomposition
+
+
+class DegreeRips2SignedMeasure(BaseEstimator, TransformerMixin):
+    def __init__(
+        self,
+        degrees: Iterable[int],
+        min_rips_value: float,
+        max_rips_value,
+        max_normalized_degree: float,
+        min_normalized_degree: float,
+        grid_granularity: int,
+        progress: bool = False,
+        n_jobs=1,
+        sparse: bool = False,
+        _möbius_inversion=True,
+        fit_fraction=1,
+    ) -> None:
+        super().__init__()
+        self.min_rips_value = min_rips_value
+        self.max_rips_value = max_rips_value
+        self.min_normalized_degree = min_normalized_degree
+        self.max_normalized_degree = max_normalized_degree
+        self._max_rips_value = None
+        self.grid_granularity = grid_granularity
+        self.progress = progress
+        self.n_jobs = n_jobs
+        self.degrees = degrees
+        self.sparse = sparse
+        self._möbius_inversion = _möbius_inversion
+        self.fit_fraction = fit_fraction
+        return
+
+    def fit(self, X: np.ndarray | list, y=None):
+        if self.max_rips_value < 0:
+            # A negative max_rips_value is interpreted as a fraction of the largest
+            # diameter observed on a random subsample of the datasets.
+            print("Estimating scale...", flush=True, end="")
+            indices = np.random.choice(
+                len(X), min(len(X), int(self.fit_fraction * len(X)) + 1), replace=False
+            )
+            diameters = np.max(
+                [distance_matrix(x, x).max() for x in (X[i] for i in indices)]
+            )
+            print(f"Done. {diameters}", flush=True)
+        self._max_rips_value = (
+            -self.max_rips_value * diameters
+            if self.max_rips_value < 0
+            else self.max_rips_value
+        )
+        return self
+
+    def _transform1(self, data: np.ndarray):
+        _distance_matrix = distance_matrix(data, data)
+        signed_measures = []
+        (
+            rips_values,
+            normalized_degree_values,
+            hilbert_functions,
+            minimal_presentations,
+        ) = hf_degree_rips(
+            _distance_matrix,
+            min_rips_value=self.min_rips_value,
+            max_rips_value=self._max_rips_value,
+            min_normalized_degree=self.min_normalized_degree,
+            max_normalized_degree=self.max_normalized_degree,
+            grid_granularity=self.grid_granularity,
+            max_homological_dimension=np.max(self.degrees),
+        )
+        for degree in self.degrees:
+            hilbert_function = hilbert_functions[degree]
+            # Möbius inversion of the Hilbert function gives the signed Betti numbers.
+            signed_measure = (
+                signed_betti(hilbert_function, threshold=True)
+                if self._möbius_inversion
+                else hilbert_function
+            )
+            if self.sparse:
+                signed_measure = tensor_möbius_inversion(
+                    tensor=signed_measure,
+                    num_parameters=2,
+                    grid_conversion=[rips_values, normalized_degree_values],
+                )
+            if not self._möbius_inversion:
+                signed_measure = signed_measure.flatten()
+            signed_measures.append(signed_measure)
+        return signed_measures
+
+    def transform(self, X):
+        return Parallel(n_jobs=self.n_jobs)(
+            delayed(self._transform1)(data)
+            for data in tqdm(X, desc=f"Computing DegreeRips of degrees {self.degrees}")
+        )
+
+
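A minimal usage sketch of the degree-Rips transformer above (all parameter values are illustrative, and the point clouds are random):

    import numpy as np

    point_clouds = [np.random.rand(100, 2) for _ in range(3)]
    dr = DegreeRips2SignedMeasure(
        degrees=[0, 1],
        min_rips_value=0.0,
        max_rips_value=-0.5,        # negative: fraction of the estimated diameter
        min_normalized_degree=0.0,
        max_normalized_degree=0.3,
        grid_granularity=50,
        sparse=True,
    )
    sms = dr.fit(point_clouds).transform(point_clouds)  # (data) x (degree) x (signed measure)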
+def tensor_möbius_inversion(
+    tensor,
+    grid_conversion: Iterable[np.ndarray] | None = None,
+    plot: bool = False,
+    raw: bool = False,
+    num_parameters: int | None = None,
+):
+    from torch import Tensor
+
+    betti_sparse = Tensor(tensor.copy()).to_sparse()  # Copy necessary in some cases :(
+    num_indices, num_pts = betti_sparse.indices().shape
+    num_parameters = num_indices if num_parameters is None else num_parameters
+    if num_indices == num_parameters:  # either hilbert or rank invariant
+        rank_invariant = False
+    elif 2 * num_parameters == num_indices:
+        rank_invariant = True
+    else:
+        raise TypeError(
+            f"Unsupported betti shape. {num_indices} has to be either "
+            f"{num_parameters} or {2 * num_parameters}."
+        )
+    points_filtration = np.asarray(betti_sparse.indices().T, dtype=int)
+    weights = np.asarray(betti_sparse.values(), dtype=int)
+
+    if grid_conversion is not None:
+        # Map grid indices back to filtration values.
+        coords = np.empty(shape=(num_pts, num_indices), dtype=float)
+        for i in range(num_indices):
+            coords[:, i] = grid_conversion[i % num_parameters][points_filtration[:, i]]
+    else:
+        coords = points_filtration
+    if (not rank_invariant) and plot:
+        plt.figure()
+        color_weights = np.empty(weights.shape)
+        color_weights[weights > 0] = np.log10(weights[weights > 0]) + 2
+        color_weights[weights < 0] = -np.log10(-weights[weights < 0]) - 2
+        plt.scatter(
+            points_filtration[:, 0],
+            points_filtration[:, 1],
+            c=color_weights,
+            cmap="coolwarm",
+        )
+    if (not rank_invariant) or raw:
+        return coords, weights
+
+    def _is_trivial(rectangle: np.ndarray):
+        birth = rectangle[:num_parameters]
+        death = rectangle[num_parameters:]
+        return np.all(birth <= death)  # and not np.array_equal(birth,death)
+
+    # Keep only rectangles whose birth is below their death.
+    correct_indices = np.array([_is_trivial(rectangle) for rectangle in coords])
+    if len(correct_indices) == 0:
+        return np.empty((0, num_indices)), np.empty((0))
+    signed_measure = np.asarray(coords[correct_indices])
+    weights = weights[correct_indices]
+    if plot:
+        # plot only the rank decomposition for the moment
+        assert signed_measure.shape[1] == 4
+
+        def _plot_rectangle(rectangle: np.ndarray, weight: float):
+            x_axis = rectangle[[0, 2]]
+            y_axis = rectangle[[1, 3]]
+            color = "blue" if weight > 0 else "red"
+            plt.plot(x_axis, y_axis, c=color)
+
+        for rectangle, weight in zip(signed_measure, weights):
+            _plot_rectangle(rectangle=rectangle, weight=weight)
+    return signed_measure, weights
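To make the sparse output format concrete, a tiny worked example for the Hilbert-function branch of `tensor_möbius_inversion` (the 2x2 tensor and grid are made up; the expected values follow from the code above):

    import numpy as np

    tensor = np.array([[1, 0], [0, -1]])                 # +1 at grid index (0, 0), -1 at (1, 1)
    grid = [np.array([0.0, 0.5]), np.array([0.0, 1.0])]  # filtration values per parameter

    pts, weights = tensor_möbius_inversion(tensor, grid_conversion=grid)
    # pts     -> [[0.0, 0.0], [0.5, 1.0]]  (filtration values of the non-zero entries)
    # weights -> [1, -1]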