multipers 2.3.3b6__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of multipers might be problematic. Click here for more details.
- multipers/.dylibs/libc++.1.0.dylib +0 -0
- multipers/.dylibs/libtbb.12.16.dylib +0 -0
- multipers/__init__.py +33 -0
- multipers/_signed_measure_meta.py +453 -0
- multipers/_slicer_meta.py +211 -0
- multipers/array_api/__init__.py +45 -0
- multipers/array_api/numpy.py +41 -0
- multipers/array_api/torch.py +58 -0
- multipers/data/MOL2.py +458 -0
- multipers/data/UCR.py +18 -0
- multipers/data/__init__.py +1 -0
- multipers/data/graphs.py +466 -0
- multipers/data/immuno_regions.py +27 -0
- multipers/data/minimal_presentation_to_st_bf.py +0 -0
- multipers/data/pytorch2simplextree.py +91 -0
- multipers/data/shape3d.py +101 -0
- multipers/data/synthetic.py +113 -0
- multipers/distances.py +202 -0
- multipers/filtration_conversions.pxd +229 -0
- multipers/filtration_conversions.pxd.tp +84 -0
- multipers/filtrations/__init__.py +18 -0
- multipers/filtrations/density.py +574 -0
- multipers/filtrations/filtrations.py +361 -0
- multipers/filtrations.pxd +224 -0
- multipers/function_rips.cpython-311-darwin.so +0 -0
- multipers/function_rips.pyx +105 -0
- multipers/grids.cpython-311-darwin.so +0 -0
- multipers/grids.pyx +433 -0
- multipers/gudhi/Persistence_slices_interface.h +132 -0
- multipers/gudhi/Simplex_tree_interface.h +239 -0
- multipers/gudhi/Simplex_tree_multi_interface.h +551 -0
- multipers/gudhi/cubical_to_boundary.h +59 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -0
- multipers/gudhi/gudhi/Debug_utils.h +45 -0
- multipers/gudhi/gudhi/Fields/Multi_field.h +484 -0
- multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -0
- multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -0
- multipers/gudhi/gudhi/Fields/Z2_field.h +355 -0
- multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -0
- multipers/gudhi/gudhi/Fields/Zp_field.h +420 -0
- multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -0
- multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -0
- multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -0
- multipers/gudhi/gudhi/Matrix.h +2107 -0
- multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -0
- multipers/gudhi/gudhi/Multi_persistence/Box.h +174 -0
- multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -0
- multipers/gudhi/gudhi/Off_reader.h +173 -0
- multipers/gudhi/gudhi/One_critical_filtration.h +1441 -0
- multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -0
- multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -0
- multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -0
- multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -0
- multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -0
- multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -0
- multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -0
- multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -0
- multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -0
- multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -0
- multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -0
- multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -0
- multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -0
- multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -0
- multipers/gudhi/gudhi/Persistent_cohomology.h +745 -0
- multipers/gudhi/gudhi/Points_off_io.h +171 -0
- multipers/gudhi/gudhi/Simple_object_pool.h +69 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -0
- multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -0
- multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -0
- multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -0
- multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -0
- multipers/gudhi/gudhi/Simplex_tree.h +2794 -0
- multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -0
- multipers/gudhi/gudhi/distance_functions.h +62 -0
- multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -0
- multipers/gudhi/gudhi/persistence_interval.h +253 -0
- multipers/gudhi/gudhi/persistence_matrix_options.h +170 -0
- multipers/gudhi/gudhi/reader_utils.h +367 -0
- multipers/gudhi/mma_interface_coh.h +256 -0
- multipers/gudhi/mma_interface_h0.h +223 -0
- multipers/gudhi/mma_interface_matrix.h +293 -0
- multipers/gudhi/naive_merge_tree.h +536 -0
- multipers/gudhi/scc_io.h +310 -0
- multipers/gudhi/truc.h +1403 -0
- multipers/io.cpython-311-darwin.so +0 -0
- multipers/io.pyx +644 -0
- multipers/ml/__init__.py +0 -0
- multipers/ml/accuracies.py +90 -0
- multipers/ml/invariants_with_persistable.py +79 -0
- multipers/ml/kernels.py +176 -0
- multipers/ml/mma.py +713 -0
- multipers/ml/one.py +472 -0
- multipers/ml/point_clouds.py +352 -0
- multipers/ml/signed_measures.py +1589 -0
- multipers/ml/sliced_wasserstein.py +461 -0
- multipers/ml/tools.py +113 -0
- multipers/mma_structures.cpython-311-darwin.so +0 -0
- multipers/mma_structures.pxd +128 -0
- multipers/mma_structures.pyx +2786 -0
- multipers/mma_structures.pyx.tp +1094 -0
- multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -0
- multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -0
- multipers/multi_parameter_rank_invariant/function_rips.h +322 -0
- multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -0
- multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -0
- multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -0
- multipers/multiparameter_edge_collapse.py +41 -0
- multipers/multiparameter_module_approximation/approximation.h +2330 -0
- multipers/multiparameter_module_approximation/combinatory.h +129 -0
- multipers/multiparameter_module_approximation/debug.h +107 -0
- multipers/multiparameter_module_approximation/euler_curves.h +0 -0
- multipers/multiparameter_module_approximation/format_python-cpp.h +286 -0
- multipers/multiparameter_module_approximation/heap_column.h +238 -0
- multipers/multiparameter_module_approximation/images.h +79 -0
- multipers/multiparameter_module_approximation/list_column.h +174 -0
- multipers/multiparameter_module_approximation/list_column_2.h +232 -0
- multipers/multiparameter_module_approximation/ru_matrix.h +347 -0
- multipers/multiparameter_module_approximation/set_column.h +135 -0
- multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -0
- multipers/multiparameter_module_approximation/unordered_set_column.h +166 -0
- multipers/multiparameter_module_approximation/utilities.h +403 -0
- multipers/multiparameter_module_approximation/vector_column.h +223 -0
- multipers/multiparameter_module_approximation/vector_matrix.h +331 -0
- multipers/multiparameter_module_approximation/vineyards.h +464 -0
- multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -0
- multipers/multiparameter_module_approximation.cpython-311-darwin.so +0 -0
- multipers/multiparameter_module_approximation.pyx +235 -0
- multipers/pickle.py +90 -0
- multipers/plots.py +456 -0
- multipers/point_measure.cpython-311-darwin.so +0 -0
- multipers/point_measure.pyx +395 -0
- multipers/simplex_tree_multi.cpython-311-darwin.so +0 -0
- multipers/simplex_tree_multi.pxd +134 -0
- multipers/simplex_tree_multi.pyx +10840 -0
- multipers/simplex_tree_multi.pyx.tp +2009 -0
- multipers/slicer.cpython-311-darwin.so +0 -0
- multipers/slicer.pxd +3034 -0
- multipers/slicer.pxd.tp +234 -0
- multipers/slicer.pyx +20481 -0
- multipers/slicer.pyx.tp +1088 -0
- multipers/tensor/tensor.h +672 -0
- multipers/tensor.pxd +13 -0
- multipers/test.pyx +44 -0
- multipers/tests/__init__.py +62 -0
- multipers/torch/__init__.py +1 -0
- multipers/torch/diff_grids.py +240 -0
- multipers/torch/rips_density.py +310 -0
- multipers-2.3.3b6.dist-info/METADATA +128 -0
- multipers-2.3.3b6.dist-info/RECORD +183 -0
- multipers-2.3.3b6.dist-info/WHEEL +6 -0
- multipers-2.3.3b6.dist-info/licenses/LICENSE +21 -0
- multipers-2.3.3b6.dist-info/top_level.txt +1 -0
multipers/ml/mma.py
ADDED
|
@@ -0,0 +1,713 @@
|
|
|
1
|
+
from typing import Callable, Iterable, List, Optional, Union
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from joblib import Parallel, delayed
|
|
5
|
+
from sklearn.base import BaseEstimator, TransformerMixin
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
|
|
8
|
+
import multipers as mp
|
|
9
|
+
import multipers.simplex_tree_multi
|
|
10
|
+
import multipers.slicer
|
|
11
|
+
from multipers.grids import compute_grid
|
|
12
|
+
from multipers.mma_structures import PyBox_f64, PyModule_type
|
|
13
|
+
|
|
14
|
+
# Type of a single filtered complex accepted by FilteredComplex2MMA:
# either a multiparameter slicer or a multiparameter simplex tree.
_FilteredComplexType = Union[(mp.slicer.Slicer_type, mp.simplex_tree_multi.SimplexTreeMulti_type)]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class FilteredComplex2MMA(BaseEstimator, TransformerMixin):
    """
    Turns a list of lists of simplextrees or slicers into MMA approximations.

    ``X`` is a sequence of data points; each data point is a sequence (one
    entry per "axis") of filtered complexes: multiparameter simplex trees,
    slicers, or minimal presentations given as tuples/lists of slicers.
    ``transform`` returns, for each data point, a tuple holding one module
    (output of ``mp.module_approximation``) per axis.

    Parameters
    ----------
    n_jobs : int
        Number of joblib jobs (threading backend) used in ``transform``.
    expand_dim : int, optional
        If set, ``expansion(expand_dim)`` is called on each input complex
        before computing. Not allowed for slicers.
    prune_degrees_above : int, optional
        If set, ``prune_above_dimension`` is called on every input complex in
        ``transform``. Its return value is discarded, so this presumably
        modifies the caller's objects in place.
    progress : bool
        Display a tqdm progress bar in ``transform``.
    minpres_degrees : iterable of int, optional
        If set, a minimal presentation in these degrees is computed before
        the module approximation.
    plot : bool
        Plot every computed module.
    **persistence_kwargs
        Forwarded to ``mp.module_approximation``.
    """

    def __init__(
        self,
        n_jobs: int = -1,
        expand_dim: Optional[int] = None,
        prune_degrees_above: Optional[int] = None,
        progress=False,
        minpres_degrees: Optional[Iterable[int]] = None,
        plot: bool = False,
        **persistence_kwargs,
    ) -> None:
        super().__init__()
        self.persistence_args = persistence_kwargs
        self.n_jobs = n_jobs
        self._num_axis = None
        self.prune_degrees_above = prune_degrees_above
        self.progress = progress
        self.expand_dim = expand_dim
        self._boxes = None
        self._is_minpres = None
        self.minpres_degrees = minpres_degrees
        self.plot = plot

    @staticmethod
    def _is_filtered_complex(input):
        """Return True if ``input`` is a simplextree or a slicer (minimal presentations allowed)."""
        return mp.simplex_tree_multi.is_simplextree_multi(input) or mp.slicer.is_slicer(
            input, allow_minpres=True
        )

    def _input_checks(self, X):
        """Validate ``X``; record the number of axes and whether inputs are minimal presentations."""
        assert len(X) > 0, "No filtered complex found. Cannot fit."
        first = X[0][0]
        assert self._is_filtered_complex(
            first
        ), f"X[0] is not a known filtered complex, {X[0]=}, nor X[0][0]."
        self._num_axis = len(X[0])
        assert (
            not mp.slicer.is_slicer(first) or self.expand_dim is None
        ), "Cannot expand slicers."
        # A minimal presentation is given as a tuple/list of slicers.
        self._is_minpres = mp.slicer.is_slicer(first) and isinstance(
            first, (tuple, list)
        )
        assert not (
            self._is_minpres and self.minpres_degrees is not None
        ), "Input is already a minpres. Cannot reduce again."

    def _infer_bounding_box(self, X):
        """
        Compute, for each axis, the box enclosing the filtration bounds of
        every complex of ``X`` on that axis.

        Stores in ``self._boxes`` a list (one entry per axis) of arrays
        ``[lower_corner, upper_corner]`` of shape ``(2, num_parameters)``.
        """
        assert self._num_axis is not None, "Fit first"
        if self._is_minpres:
            # one entry per slicer of each minimal presentation
            per_axis = [
                [s.filtration_bounds() for x in X for s in x[axis]]
                for axis in range(self._num_axis)
            ]
        else:
            per_axis = [
                [x[axis].filtration_bounds() for x in X]
                for axis in range(self._num_axis)
            ]
        # shape: (num_axis, num_complexes, 2 (min/max), num_parameters)
        filtration_values = np.asarray(per_axis)
        lower = filtration_values[:, :, 0, :].min(axis=1)  # (num_axis, num_parameters)
        upper = filtration_values[:, :, 1, :].max(axis=1)  # (num_axis, num_parameters)
        self._boxes = [
            np.array([m_of_axis, M_of_axis]) for m_of_axis, M_of_axis in zip(lower, upper)
        ]

    def fit(self, X, y=None):
        """Check ``X`` and infer one bounding box per axis. Empty ``X`` is a no-op."""
        if len(X) == 0:
            return self
        self._input_checks(X)
        self._infer_bounding_box(X)
        return self

    def transform(self, X):
        """
        Compute the module approximations of ``X``.

        Returns a list with, for each data point of ``X``, a tuple holding
        one module per axis, computed in the box inferred during ``fit``.
        """
        if self.prune_degrees_above is not None:
            for x in X:
                for x_ in x:
                    # a minimal presentation is a collection of slicers
                    complexes = x_ if self._is_minpres else (x_,)
                    for s in complexes:
                        s.prune_above_dimension(self.prune_degrees_above)

        def todo1(x, box):
            if self.expand_dim is not None:
                x.expansion(self.expand_dim)
            if self.minpres_degrees is not None:
                x = mp.slicer.minimal_presentation(
                    mp.Slicer(x), degrees=self.minpres_degrees, vineyard=True
                )
            mod = mp.module_approximation(
                x, box=box, verbose=False, **self.persistence_args
            )
            if self.plot:
                mod.plot()
            return mod

        def todo(sts: Iterable[_FilteredComplexType]):
            return tuple(todo1(st, box) for st, box in zip(sts, self._boxes))

        return Parallel(n_jobs=self.n_jobs, backend="threading")(
            delayed(todo)(x)
            for x in tqdm(X, desc="Computing modules", disable=not self.progress)
        )
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class SimplexTree2MMA(FilteredComplex2MMA):
    """
    Deprecated alias of ``FilteredComplex2MMA``; warns on construction.

    Extra keyword arguments are forwarded to ``FilteredComplex2MMA`` exactly
    as if they had been passed to it directly.
    """

    def __init__(
        self,
        n_jobs: int = -1,
        expand_dim: Optional[int] = None,
        prune_degrees_above: Optional[int] = None,
        progress=False,
        minpres_degrees: Optional[Iterable[int]] = None,
        **persistence_kwargs,
    ):
        # Forward explicitly: the ``locals()`` trick used previously passed the
        # extra keyword arguments as a single nested ``persistence_kwargs=...``
        # keyword instead of unpacking them.
        super().__init__(
            n_jobs=n_jobs,
            expand_dim=expand_dim,
            prune_degrees_above=prune_degrees_above,
            progress=progress,
            minpres_degrees=minpres_degrees,
            **persistence_kwargs,
        )
        from warnings import warn

        # stacklevel=2 points the warning at the caller's construction site.
        warn(
            "This class is deprecated, use FilteredComplex2MMA instead.",
            stacklevel=2,
        )
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class MMAFormatter(BaseEstimator, TransformerMixin):
    """
    Formats lists of MMA modules.

    Optionally unpickles the input, rescales the filtration values of each
    (axis, degree) into a standard box, selects an axis, and pickles the
    output.

    Parameters
    ----------
    degrees : list of int, optional
        Homological degrees to keep. Defaults to ``0, ..., max_degree``, with
        ``max_degree`` the largest one found during ``fit``.
    axis : int, optional
        When each data point is a sequence of modules, the axis to keep.
        ``None`` becomes ``-1`` (keep every axis) during ``fit`` if axes are
        found.
    verbose : bool
        Print the statistics inferred during ``fit``.
    normalize : bool
        Rescale filtration values using bounds inferred during ``fit``.
    weights : array-like, optional
        Per-parameter size of the target box (defaults to 1).
    quantiles : (float, float), optional
        ``(qm, qM)``: use the ``qm`` and ``1 - qM`` quantiles of the
        per-module bounds instead of their min / max.
    dump : bool
        Make ``transform`` return pickled (``bytes``) objects.
    from_dump : bool
        Stored but not read by this class: inputs whose elements are
        ``bytes`` are unpickled automatically.
    """

    def __init__(
        self,
        degrees: Optional[list[int]] = None,
        axis=None,
        verbose: bool = False,
        normalize: bool = False,
        weights=None,
        quantiles=None,
        dump=False,
        from_dump=False,
    ):
        self._module_bounds = None
        self.verbose = verbose
        self.axis = axis
        self._axis = []
        self._has_axis = None
        self._num_axis = 0
        self.degrees = degrees
        self._degrees = None
        self.normalize = normalize
        self._num_parameters = None
        self.weights = weights
        self.quantiles = quantiles
        self.dump = dump
        self.from_dump = from_dump

    @staticmethod
    def _maybe_from_dump(X_in):
        """Unpickle the elements of ``X_in`` if they are ``bytes``; otherwise return ``X_in``."""
        if len(X_in) == 0 or not isinstance(X_in[0], bytes):
            return X_in
        import pickle

        # SECURITY: pickle.loads can execute arbitrary code. Only feed data
        # produced by a trusted ``MMAFormatter(dump=True)``.
        return [pickle.loads(mods) for mods in X_in]

    @staticmethod
    def _get_module_bound(x, degree):
        """
        Per-parameter lower and upper filtration values of module ``x`` in
        degree ``degree``.

        Returns ``(m, M)``, each of length ``num_parameters``. Both are NaN
        (and a message is printed) when that degree has no filtration value.
        """
        filtration_values = x.get_module_of_degree(degree).get_filtration_values(
            unique=True
        )
        # f[0] / f[-1] are used as min / max: assumes the unique values are
        # sorted -- TODO confirm.
        out = np.array([[f[0], f[-1]] for f in filtration_values if len(f) > 0]).T
        if len(out) != 2:
            print(f"Missing degree {degree} here !")
            nans = [np.nan for _ in range(x.num_parameters)]
            return nans, nans
        m, M = out
        return m, M

    @staticmethod
    def _infer_axis(X):
        """Return True if each element of ``X`` is a sequence of modules rather than a module."""
        has_axis = not isinstance(X[0], PyModule_type)
        assert not has_axis or isinstance(X[0][0], PyModule_type)
        return has_axis

    @staticmethod
    def _infer_num_parameters(X, ax=slice(None)):
        """Number of parameters of the first module of ``X`` on axis ``ax``."""
        return X[0][ax].num_parameters

    @staticmethod
    def _infer_bounds(X, degrees=None, axis=None, quantiles=None):
        """
        Compute bounds of filtration values of a list of modules.

        ``axis`` defaults to ``[slice(None)]`` (the whole data point).

        Output Format
        -------------
        (m, M), each of shape (num_axis, num_degrees, num_parameters)
        """
        if axis is None:
            axis = [slice(None)]
        if degrees is None:
            degrees = np.arange(X[0][axis[0]].max_degree + 1)
        # shape: (num_data, num_axis, num_degrees, 2 (min/max), num_parameters)
        bounds = np.array(
            [
                [
                    [MMAFormatter._get_module_bound(x[ax], degree) for degree in degrees]
                    for ax in axis
                ]
                for x in X
            ]
        )
        lower = bounds[:, :, :, 0, :]
        upper = bounds[:, :, :, 1, :]
        if quantiles is not None:
            qm, qM = quantiles
            m = np.nanquantile(lower, q=qm, axis=0)
            M = np.nanquantile(upper, q=1 - qM, axis=0)
        else:
            m = np.nanmin(lower, axis=0)
            M = np.nanmax(upper, axis=0)
        return (m, M)

    @staticmethod
    def _infer_grid(
        X: List[PyModule_type], strategy: str, resolution: int, degrees=None
    ):
        """
        Given a list of PyModules, compute a multiparameter discrete grid
        with the given strategy, from the filtration values of the modules'
        summands (restricted to ``degrees`` if given).

        A strategy containing ``"_mean"`` computes one grid per module with
        the prefix strategy and averages them.
        """
        num_parameters = X[0].num_parameters
        if degrees is None:
            # Format: ((filtration values of parameter) for parameter) per module
            filtration_values = tuple(
                mod.get_filtration_values(unique=True) for mod in X
            )
        else:
            filtration_values = tuple(
                mod.get_module_of_degrees(degrees).get_filtration_values(unique=True)
                for mod in X
            )

        if "_mean" in strategy:
            substrategy = strategy.split("_")[0]
            processed_filtration_values = [
                compute_grid(f, resolution, substrategy, unique=False)
                for f in filtration_values
            ]
            return np.mean(processed_filtration_values, axis=0)

        merged = [
            np.unique(np.concatenate([f[parameter] for f in filtration_values], axis=0))
            for parameter in range(num_parameters)
        ]
        return compute_grid(merged, resolution, strategy, unique=True)

    def _infer_degrees(self, X):
        """Set ``self._degrees`` from ``self.degrees``, or ``0..max(max_degree)`` over the used axes."""
        if self.degrees is not None:
            self._degrees = self.degrees
            return
        max_degrees = [x[ax].max_degree for ax in self._axis for x in X] + [0]
        self._degrees = np.arange(np.max(max_degrees) + 1)

    def fit(self, X_in, y=None):
        """Infer axes, degrees, number of parameters and normalization factors."""
        X = self._maybe_from_dump(X_in)
        if len(X) == 0:
            return self
        self._has_axis = self._infer_axis(X)
        if self.axis is None and self._has_axis:
            self.axis = -1
        if self.axis is not None and not self._has_axis:
            raise ValueError(
                f"MMAFormatter didn't find an axis, but requested axis {self.axis}"
            )
        if self._has_axis:
            self._num_axis = len(X[0])
        if self.axis is None:
            self._axis = [slice(None)]
        elif self.axis == -1:
            self._axis = range(self._num_axis)
        else:
            self._axis = [self.axis]
        self._infer_degrees(X)
        self._num_parameters = self._infer_num_parameters(X, ax=self._axis[0])
        if self.verbose:
            # printed after inference so that the number of parameters is known
            print("-----------MMAFormatter-----------")
            print("---- Infered stats")
            print(f"Found axis : {self._has_axis}, num : {self._num_axis}")
            print(f"Number of parameters : {self._num_parameters}")

        if self.normalize:
            self._module_bounds = self._infer_bounds(
                X, self._degrees, self._axis, self.quantiles
            )
        else:
            # One row per *used* axis (matching the indexing in ``transform``),
            # so that ``weights`` also apply when the input has no axis.
            m = np.zeros((len(self._axis), len(self._degrees), self._num_parameters))
            self._module_bounds = (m, m + 1)
        assert self._num_parameters == self._module_bounds[0].shape[-1]
        if self.verbose:
            print("---- Bounds (only computed if normalize):")
            if self._has_axis and self._num_axis > 1:
                print("(axis) x (degree) x (parameter)")
            else:
                print("(degree) x (parameter)")
            m, M = self._module_bounds
            print("-- Lower bound : ", m.shape)
            print(m)
            print("-- Upper bound :", M.shape)
            print(M)

        w = 1 if self.weights is None else np.asarray(self.weights)
        m, M = self._module_bounds
        normalizer = M - m
        zero_normalizer = normalizer == 0
        if np.any(zero_normalizer):
            from warnings import warn

            warn(f"Encountered empty bounds. Please fix me. \n M-m = {normalizer}")
            normalizer[zero_normalizer] = 1
        self._normalization_factors = w / normalizer
        if self.verbose:
            print("-- Normalization factors:", self._normalization_factors.shape)
            print(self._normalization_factors)

        if self.verbose:
            print("---- Module size :")
            for ax in self._axis:
                print(f"- Axis {ax}")
                for degree in self._degrees:
                    sizes = [len(x[ax].get_module_of_degree(degree)) for x in X]
                    mean_size = np.mean(sizes).round(decimals=2)
                    std_size = np.std(sizes).round(decimals=2)
                    print(f" - Degree {degree} size {mean_size}±{std_size}")
            print("----------------------------------")
        return self

    @staticmethod
    def copy_transform(mod, degrees, translation, rescale_factors, new_box):
        """
        Restrict ``mod`` to ``degrees`` (via ``get_module_of_degrees``), then
        translate degree ``degrees[j]`` by ``translation[j]``, rescale it by
        ``rescale_factors[j]``, and set the box to ``new_box``.
        """
        copy = mod.get_module_of_degrees(degrees)  # only keeps these degrees
        for j, degree in enumerate(degrees):
            copy.translate(translation[j], degree=degree)
            copy.rescale(rescale_factors[j], degree=degree)
        copy.set_box(new_box)
        return copy

    def transform(self, X_in):
        """
        Format ``X_in`` using the statistics inferred during ``fit``.

        If some normalization factor differs from 1, each data point becomes
        a list (one entry per used axis) of rescaled modules. Otherwise the
        requested axis is selected (when there is one and it is not ``-1``).
        With ``dump=True`` every resulting element is pickled.
        """
        X = self._maybe_from_dump(X_in)
        if np.any(self._normalization_factors != 1):
            if self.verbose:
                print("Normalizing...", end="")
            w = (
                [1] * self._num_parameters
                if self.weights is None
                else np.asarray(self.weights)
            )
            standard_box = np.array([[0] * self._num_parameters, w])
            X = [
                [
                    self.copy_transform(
                        mod=x[ax],
                        degrees=self._degrees,
                        translation=-self._module_bounds[0][i],
                        rescale_factors=self._normalization_factors[i],
                        new_box=standard_box,
                    )
                    for i, ax in enumerate(self._axis)
                ]
                for x in X
            ]
            if self.verbose:
                print("Done.")
        elif self.axis is not None and self.axis != -1:
            # ``axis is None`` means the input has no axis: nothing to select.
            X = [x[self.axis] for x in X]
        if self.dump:
            import pickle

            X = [pickle.dumps(mods) for mods in X]
        return X
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
class MMA2IMG(BaseEstimator, TransformerMixin):
    """
    Turns MMA modules into images.

    Each module's ``representation`` is evaluated on a grid inferred during
    ``fit`` from the modules' filtration values. When data points hold
    several modules (axes), one grid is built per axis.

    Parameters
    ----------
    degrees : list
        Homological degrees to represent.
    bandwidth, power, normalize, kernel, signed
        Forwarded to ``representation`` (``power`` is passed as ``p``).
    resolution : list or int
        Grid resolution given to the grid strategy.
    plot : bool
        Stored; not used by ``transform``.
    box
        Forwarded to ``representation``.
    n_jobs : int
        joblib threads; also forwarded to ``representation``.
    flatten : bool
        Concatenate the images of a data point into one 1D vector.
    progress : bool
        Show a tqdm progress bar.
    grid_strategy : str
        Grid strategy. A strategy containing ``"_mean"`` averages per-module
        grids.
    """

    def __init__(
        self,
        degrees: list,
        bandwidth: float = 0.1,
        power: float = 1,
        normalize: bool = False,
        resolution: list | int = 50,
        plot: bool = False,
        box=None,
        n_jobs=-1,
        flatten=False,
        progress=False,
        grid_strategy="regular",
        kernel="linear",
        signed: bool = False,
    ):
        self.bandwidth = bandwidth
        self.degrees = degrees
        self.resolution = resolution
        self.box = box
        self.plot = plot
        self._box = None
        self.normalize = normalize
        self.power = power
        self._has_axis = None
        self._num_parameters = None
        self.n_jobs = n_jobs
        self.flatten = flatten
        self.progress = progress
        self.grid_strategy = grid_strategy
        self._num_axis = None
        self._coords_to_compute = None
        self._new_resolutions = None
        self.kernel = kernel
        self.signed = signed

    def fit(self, X, y=None):
        """Infer the evaluation grid(s) and their resolutions. Empty ``X`` is a no-op."""
        # TODO infer box
        # TODO rescale module
        if len(X) == 0:
            return self
        self._has_axis = MMAFormatter._infer_axis(X)
        if self._has_axis:
            self._num_axis = len(X[0])
        # NOTE(review): ``_box`` is not read by ``transform``, which forwards
        # ``self.box``; confirm whether this default box should be used.
        self._box = [[0, 0], [1, 1]] if self.box is None else self.box
        if self._has_axis:
            modules_per_axis = (
                tuple(x[axis] for x in X) for axis in range(self._num_axis)
            )
            grids = tuple(
                MMAFormatter._infer_grid(
                    X_axis, self.grid_strategy, self.resolution, degrees=self.degrees
                )
                for X_axis in modules_per_axis
            )
            # not the same resolutions on each axis, so cannot be put in an array
            self._coords_to_compute = grids
            self._new_resolutions = np.asarray(
                [tuple(len(g) for g in grid) for grid in grids]
            )
        else:
            coords = MMAFormatter._infer_grid(
                X, self.grid_strategy, self.resolution, degrees=self.degrees
            )
            self._coords_to_compute = coords
            self._new_resolutions = np.array([len(g) for g in coords])
        return self

    def transform(self, X):
        """
        Compute the images of ``X``.

        With ``flatten``, each data point gives one 1D array. Otherwise it
        gives a tuple (one entry per axis, a single entry without axes) of
        arrays of shape ``(len(degrees), *grid_resolution)``.
        """
        img_args = {
            "bandwidth": self.bandwidth,
            "p": self.power,
            "normalize": self.normalize,
            "box": self.box,
            "degrees": self.degrees,
            # num_jobs is better for parallel over modules.
            "n_jobs": self.n_jobs,
            "kernel": self.kernel,
            "signed": self.signed,
            "flatten": True,  # custom coordinates
        }
        if self._has_axis:

            def todo2(mods):
                return tuple(
                    mod.representation(grid=c, **img_args)
                    for mod, c in zip(mods, self._coords_to_compute)
                )

            resolutions = self._new_resolutions
        else:

            def todo2(mod):
                # leading axis added so the output looks like the has_axis case
                return mod.representation(grid=self._coords_to_compute, **img_args)[
                    None, :
                ]

            # A single grid: ``_new_resolutions`` is a 1D array of per-parameter
            # sizes. Wrap it so it pairs with the single image (zipping the bare
            # array would yield scalars and break ``*r`` below).
            resolutions = [self._new_resolutions]

        if self.flatten:

            def todo(mods):
                return np.concatenate(todo2(mods), axis=1).flatten()

        else:

            def todo(mods):
                return tuple(
                    img.reshape(len(img_args["degrees"]), *r)
                    for img, r in zip(todo2(mods), resolutions)
                )

        # resolution depends on the axis (see _infer_grid)
        return Parallel(n_jobs=self.n_jobs, backend="threading")(
            delayed(todo)(x)
            for x in tqdm(X, desc="Computing images", disable=not self.progress)
        )
|
|
631
|
+
|
|
632
|
+
|
|
633
|
+
class MMA2Landscape(BaseEstimator, TransformerMixin):
    """
    Turns a list of MMA approximations into Landscapes vectorisations.

    Parameters
    ----------
    resolution : list of int
        Forwarded to ``mod.landscapes``.
    degrees : list of int
        Homological degrees to vectorize.
    ks : iterable of int
        Landscape indices, forwarded to ``mod.landscapes``.
    phi : callable
        Reduction applied to each degree's landscapes; must accept ``axis=0``.
    box : optional
        If None, set during ``fit`` from quantiles of the modules'
        bottom/top.
    plot : bool
        Forwarded to ``mod.landscapes``.
    n_jobs : int
        joblib threads.
    filtration_quantile : float
        Quantile used to infer ``box`` during ``fit``.
    """

    def __init__(
        self,
        resolution=[100, 100],
        degrees: list[int] | None = [0, 1],
        ks: Iterable[int] = range(5),
        phi: Callable = np.sum,
        box=None,
        plot: bool = False,
        n_jobs=-1,
        filtration_quantile: float = 0.01,
    ) -> None:
        # The list defaults are never mutated by this class.
        super().__init__()
        self.resolution: list[int] = resolution
        self.degrees = degrees
        self.ks = ks
        self.phi = phi  # Has to have a axis=0 !
        self.box = box
        self.plot = plot
        self.n_jobs = n_jobs
        self.filtration_quantile = filtration_quantile

    def fit(self, X, y=None):
        """
        Check that the modules have 2 parameters and infer ``box`` if unset.

        Always returns ``self`` (also on empty input) so that the estimator
        can be chained / used in pipelines.
        """
        if len(X) <= 0:
            return self
        assert (
            X[0].num_parameters == 2
        ), f"Number of parameters {X[0].num_parameters} has to be 2."
        if self.box is None:

            def _bottom(mod):
                return mod.get_bottom()

            def _top(mod):
                return mod.get_top()

            bottoms = Parallel(n_jobs=self.n_jobs, backend="threading")(
                delayed(_bottom)(mod) for mod in X
            )
            tops = Parallel(n_jobs=self.n_jobs, backend="threading")(
                delayed(_top)(mod) for mod in X
            )
            m = np.quantile(bottoms, q=self.filtration_quantile, axis=0)
            M = np.quantile(tops, q=1 - self.filtration_quantile, axis=0)
            # NOTE(review): ``box`` is not passed to ``landscapes`` in
            # ``transform``; confirm whether it should be.
            self.box = [m, M]
        return self

    def transform(self, X) -> list[np.ndarray]:
        """
        Vectorize each module.

        For each degree, ``phi(landscapes, axis=0)`` is flattened; the results
        are concatenated over ``degrees`` into one 1D array per module.
        """
        if len(X) <= 0:
            return []

        def todo(mod):
            per_degree = [
                self.phi(
                    mod.landscapes(
                        ks=self.ks,
                        resolution=self.resolution,
                        degree=degree,
                        plot=self.plot,
                    ),
                    axis=0,
                ).flatten()
                for degree in self.degrees
            ]
            return np.concatenate(per_degree).flatten()

        return Parallel(n_jobs=self.n_jobs, backend="threading")(
            delayed(todo)(x) for x in X
        )
|