multipers 2.2.3__cp312-cp312-win_amd64.whl → 2.3.0__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic. Click here for more details.

Files changed (182)
  1. multipers/__init__.py +33 -31
  2. multipers/_signed_measure_meta.py +430 -430
  3. multipers/_slicer_meta.py +211 -212
  4. multipers/data/MOL2.py +458 -458
  5. multipers/data/UCR.py +18 -18
  6. multipers/data/graphs.py +466 -466
  7. multipers/data/immuno_regions.py +27 -27
  8. multipers/data/pytorch2simplextree.py +90 -90
  9. multipers/data/shape3d.py +101 -101
  10. multipers/data/synthetic.py +113 -111
  11. multipers/distances.py +198 -198
  12. multipers/filtration_conversions.pxd.tp +84 -84
  13. multipers/filtrations/__init__.py +18 -0
  14. multipers/filtrations/filtrations.py +289 -0
  15. multipers/filtrations.pxd +224 -224
  16. multipers/function_rips.cp312-win_amd64.pyd +0 -0
  17. multipers/function_rips.pyx +105 -105
  18. multipers/grids.cp312-win_amd64.pyd +0 -0
  19. multipers/grids.pyx +350 -350
  20. multipers/gudhi/Persistence_slices_interface.h +132 -132
  21. multipers/gudhi/Simplex_tree_interface.h +239 -245
  22. multipers/gudhi/Simplex_tree_multi_interface.h +516 -561
  23. multipers/gudhi/cubical_to_boundary.h +59 -59
  24. multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -450
  25. multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -1070
  26. multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -579
  27. multipers/gudhi/gudhi/Debug_utils.h +45 -45
  28. multipers/gudhi/gudhi/Fields/Multi_field.h +484 -484
  29. multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -455
  30. multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -450
  31. multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -531
  32. multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -507
  33. multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -531
  34. multipers/gudhi/gudhi/Fields/Z2_field.h +355 -355
  35. multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -376
  36. multipers/gudhi/gudhi/Fields/Zp_field.h +420 -420
  37. multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -400
  38. multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -418
  39. multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -337
  40. multipers/gudhi/gudhi/Matrix.h +2107 -2107
  41. multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -1038
  42. multipers/gudhi/gudhi/Multi_persistence/Box.h +171 -171
  43. multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -282
  44. multipers/gudhi/gudhi/Off_reader.h +173 -173
  45. multipers/gudhi/gudhi/One_critical_filtration.h +1432 -1431
  46. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -769
  47. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -686
  48. multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -842
  49. multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -1350
  50. multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -1105
  51. multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -859
  52. multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -910
  53. multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -139
  54. multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -230
  55. multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -211
  56. multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -60
  57. multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -60
  58. multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -136
  59. multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -190
  60. multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -616
  61. multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -150
  62. multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -106
  63. multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -219
  64. multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -327
  65. multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -1140
  66. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -934
  67. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -934
  68. multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -980
  69. multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -1092
  70. multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -192
  71. multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -921
  72. multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -1093
  73. multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -1012
  74. multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -1244
  75. multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -186
  76. multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -164
  77. multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -156
  78. multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -376
  79. multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -540
  80. multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -118
  81. multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -173
  82. multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -128
  83. multipers/gudhi/gudhi/Persistent_cohomology.h +745 -745
  84. multipers/gudhi/gudhi/Points_off_io.h +171 -171
  85. multipers/gudhi/gudhi/Simple_object_pool.h +69 -69
  86. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -463
  87. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -83
  88. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -106
  89. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -277
  90. multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -62
  91. multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -27
  92. multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -62
  93. multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -157
  94. multipers/gudhi/gudhi/Simplex_tree.h +2794 -2794
  95. multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -163
  96. multipers/gudhi/gudhi/distance_functions.h +62 -62
  97. multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -104
  98. multipers/gudhi/gudhi/persistence_interval.h +253 -253
  99. multipers/gudhi/gudhi/persistence_matrix_options.h +170 -170
  100. multipers/gudhi/gudhi/reader_utils.h +367 -367
  101. multipers/gudhi/mma_interface_coh.h +256 -255
  102. multipers/gudhi/mma_interface_h0.h +223 -231
  103. multipers/gudhi/mma_interface_matrix.h +284 -282
  104. multipers/gudhi/naive_merge_tree.h +536 -575
  105. multipers/gudhi/scc_io.h +310 -289
  106. multipers/gudhi/truc.h +890 -888
  107. multipers/io.cp312-win_amd64.pyd +0 -0
  108. multipers/io.pyx +711 -711
  109. multipers/ml/accuracies.py +90 -90
  110. multipers/ml/convolutions.py +520 -520
  111. multipers/ml/invariants_with_persistable.py +79 -79
  112. multipers/ml/kernels.py +176 -176
  113. multipers/ml/mma.py +713 -714
  114. multipers/ml/one.py +472 -472
  115. multipers/ml/point_clouds.py +352 -346
  116. multipers/ml/signed_measures.py +1589 -1589
  117. multipers/ml/sliced_wasserstein.py +461 -461
  118. multipers/ml/tools.py +113 -113
  119. multipers/mma_structures.cp312-win_amd64.pyd +0 -0
  120. multipers/mma_structures.pxd +127 -127
  121. multipers/mma_structures.pyx +4 -4
  122. multipers/mma_structures.pyx.tp +1085 -1085
  123. multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -93
  124. multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -97
  125. multipers/multi_parameter_rank_invariant/function_rips.h +322 -322
  126. multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -769
  127. multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -148
  128. multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -369
  129. multipers/multiparameter_edge_collapse.py +41 -41
  130. multipers/multiparameter_module_approximation/approximation.h +2296 -2295
  131. multipers/multiparameter_module_approximation/combinatory.h +129 -129
  132. multipers/multiparameter_module_approximation/debug.h +107 -107
  133. multipers/multiparameter_module_approximation/format_python-cpp.h +286 -286
  134. multipers/multiparameter_module_approximation/heap_column.h +238 -238
  135. multipers/multiparameter_module_approximation/images.h +79 -79
  136. multipers/multiparameter_module_approximation/list_column.h +174 -174
  137. multipers/multiparameter_module_approximation/list_column_2.h +232 -232
  138. multipers/multiparameter_module_approximation/ru_matrix.h +347 -347
  139. multipers/multiparameter_module_approximation/set_column.h +135 -135
  140. multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -36
  141. multipers/multiparameter_module_approximation/unordered_set_column.h +166 -166
  142. multipers/multiparameter_module_approximation/utilities.h +403 -419
  143. multipers/multiparameter_module_approximation/vector_column.h +223 -223
  144. multipers/multiparameter_module_approximation/vector_matrix.h +331 -331
  145. multipers/multiparameter_module_approximation/vineyards.h +464 -464
  146. multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -649
  147. multipers/multiparameter_module_approximation.cp312-win_amd64.pyd +0 -0
  148. multipers/multiparameter_module_approximation.pyx +216 -217
  149. multipers/pickle.py +90 -53
  150. multipers/plots.py +342 -334
  151. multipers/point_measure.cp312-win_amd64.pyd +0 -0
  152. multipers/point_measure.pyx +322 -320
  153. multipers/simplex_tree_multi.cp312-win_amd64.pyd +0 -0
  154. multipers/simplex_tree_multi.pxd +133 -133
  155. multipers/simplex_tree_multi.pyx +18 -15
  156. multipers/simplex_tree_multi.pyx.tp +1939 -1935
  157. multipers/slicer.cp312-win_amd64.pyd +0 -0
  158. multipers/slicer.pxd +81 -20
  159. multipers/slicer.pxd.tp +215 -214
  160. multipers/slicer.pyx +1091 -308
  161. multipers/slicer.pyx.tp +924 -914
  162. multipers/tensor/tensor.h +672 -672
  163. multipers/tensor.pxd +13 -13
  164. multipers/test.pyx +44 -44
  165. multipers/tests/__init__.py +57 -57
  166. multipers/torch/diff_grids.py +217 -217
  167. multipers/torch/rips_density.py +310 -304
  168. {multipers-2.2.3.dist-info → multipers-2.3.0.dist-info}/LICENSE +21 -21
  169. {multipers-2.2.3.dist-info → multipers-2.3.0.dist-info}/METADATA +21 -11
  170. multipers-2.3.0.dist-info/RECORD +182 -0
  171. multipers/tests/test_diff_helper.py +0 -73
  172. multipers/tests/test_hilbert_function.py +0 -82
  173. multipers/tests/test_mma.py +0 -83
  174. multipers/tests/test_point_clouds.py +0 -49
  175. multipers/tests/test_python-cpp_conversion.py +0 -82
  176. multipers/tests/test_signed_betti.py +0 -181
  177. multipers/tests/test_signed_measure.py +0 -89
  178. multipers/tests/test_simplextreemulti.py +0 -221
  179. multipers/tests/test_slicer.py +0 -221
  180. multipers-2.2.3.dist-info/RECORD +0 -189
  181. {multipers-2.2.3.dist-info → multipers-2.3.0.dist-info}/WHEEL +0 -0
  182. {multipers-2.2.3.dist-info → multipers-2.3.0.dist-info}/top_level.txt +0 -0
multipers/ml/mma.py CHANGED
@@ -1,714 +1,713 @@
1
- from typing import Callable, Iterable, List, Optional, Union
2
-
3
- import numpy as np
4
- from joblib import Parallel, delayed
5
- from sklearn.base import BaseEstimator, TransformerMixin
6
- from tqdm import tqdm
7
-
8
- import multipers as mp
9
- import multipers.simplex_tree_multi
10
- import multipers.slicer
11
- from multipers.grids import compute_grid as reduce_grid
12
- from multipers.ml.tools import filtration_grid_to_coordinates
13
- from multipers.mma_structures import PyBox_f64, PyModule_type
14
-
15
# Type alias used in annotations: the union of the slicer type and the
# multi-parameter simplex tree type exposed by multipers.
_FilteredComplexType = Union[mp.slicer.Slicer_type, mp.simplex_tree_multi.SimplexTreeMulti_type]
18
-
19
-
20
class FilteredComplex2MMA(BaseEstimator, TransformerMixin):
    """
    Turns a list of list of simplextrees or slicers to MMA approximations.

    ``X[i]`` is the sequence of filtered complexes of the i-th data point, one
    entry per axis (``len(X[0])`` axes). When the input is detected as
    minimal presentations, each ``X[i][axis]`` is itself iterated over.

    Parameters
    ----------
    n_jobs : int
        Number of joblib workers (threading backend) used by `transform`.
    expand_dim : int, optional
        If not None, ``expansion(expand_dim)`` is called on every complex
        before approximating. Rejected by `fit` when the input is a slicer.
    prune_degrees_above : int, optional
        If not None, ``prune_above_dimension`` is called on every input
        complex by `transform`.
    progress : bool
        Show a tqdm progress bar during `transform`.
    minpres_degrees : iterable of int, optional
        If not None, every complex is replaced by
        ``mp.slicer.minimal_presentation(mp.Slicer(x), degrees=..., vineyard=True)``
        before approximating.
    plot : bool
        Call ``plot()`` on every computed module.
    **persistence_kwargs
        Forwarded as-is to ``mp.module_approximation``.
    """

    def __init__(
        self,
        n_jobs: int = -1,
        expand_dim: Optional[int] = None,
        prune_degrees_above: Optional[int] = None,
        progress=False,
        minpres_degrees: Optional[Iterable[int]] = None,
        plot: bool = False,
        **persistence_kwargs,
    ) -> None:
        super().__init__()
        self.persistence_args = persistence_kwargs
        self.n_jobs = n_jobs
        self._num_axis = None
        self.prune_degrees_above = prune_degrees_above
        self.progress = progress
        self.expand_dim = expand_dim
        self._boxes = None
        self._is_minpres = None
        self.minpres_degrees = minpres_degrees
        self.plot = plot

    @staticmethod
    def _is_filtered_complex(input):
        """Return True if `input` is a SimplexTreeMulti or a slicer (minpres allowed)."""
        return mp.simplex_tree_multi.is_simplextree_multi(input) or mp.slicer.is_slicer(
            input, allow_minpres=True
        )

    def _input_checks(self, X):
        """Validate `X` from its first element and record `_num_axis` / `_is_minpres`."""
        assert len(X) > 0, "No filtered complex found. Cannot fit."
        assert self._is_filtered_complex(
            X[0][0]
        ), f"X[0] is not a known filtered complex, {X[0]=}, nor X[0][0]."
        self._num_axis = len(X[0])
        first = X[0][0]
        assert (
            not mp.slicer.is_slicer(first) or self.expand_dim is None
        ), "Cannot expand slicers."
        # A tuple of classes is the portable form for runtime checks;
        # isinstance(..., typing.Union[...]) raises TypeError before Python 3.10.
        self._is_minpres = mp.slicer.is_slicer(first) and isinstance(
            first, (tuple, list)
        )
        assert not (
            self._is_minpres and self.minpres_degrees is not None
        ), "Input is already a minpres. Cannot reduce again."

    def _infer_bounding_box(self, X):
        """Store in `_boxes` one ``[lower, upper]`` array per axis, over all of `X`."""
        assert self._num_axis is not None, "Fit first"
        axes = range(self._num_axis)
        if self._is_minpres:
            per_axis = [
                [s.filtration_bounds() for x in X for s in x[axis]] for axis in axes
            ]
        else:
            per_axis = [[x[axis].filtration_bounds() for x in X] for axis in axes]
        # Indexed as (axis, data, min/max, parameter).
        filtration_values = np.asarray(per_axis)
        # Reduce over the data dimension -> shape (axis, parameter).
        lower = filtration_values[:, :, 0, :].min(axis=1)
        upper = filtration_values[:, :, 1, :].max(axis=1)
        self._boxes = [
            np.array([m_of_axis, M_of_axis]) for m_of_axis, M_of_axis in zip(lower, upper)
        ]

    def fit(self, X, y=None):
        """Check the input and infer one bounding box per axis. No-op on empty `X`."""
        if len(X) == 0:
            return self
        self._input_checks(X)
        self._infer_bounding_box(X)
        return self

    def transform(self, X):
        """
        Compute, for every element of `X`, a tuple of module approximations
        (one per axis), each using the box inferred for that axis by `fit`.

        Returns the list produced by ``joblib.Parallel``.
        """
        if self.prune_degrees_above is not None:
            # Pruning is applied to the objects passed in `X` themselves.
            for x in X:
                for x_ in x:
                    if self._is_minpres:
                        for s_ in x_:
                            s_.prune_above_dimension(self.prune_degrees_above)
                    else:
                        x_.prune_above_dimension(self.prune_degrees_above)

        def todo1(x, box):
            if self.expand_dim is not None:
                x.expansion(self.expand_dim)
            if self.minpres_degrees is not None:
                x = mp.slicer.minimal_presentation(
                    mp.Slicer(x), degrees=self.minpres_degrees, vineyard=True
                )
            mod = mp.module_approximation(
                x, box=box, verbose=False, **self.persistence_args
            )
            if self.plot:
                mod.plot()
            return mod

        def todo(sts: Iterable[_FilteredComplexType]):
            # `_boxes` is only set by a successful, non-empty `fit`.
            return tuple(todo1(st, box) for st, box in zip(sts, self._boxes))

        return Parallel(n_jobs=self.n_jobs, backend="threading")(
            delayed(todo)(x)
            for x in tqdm(X, desc="Computing modules", disable=not self.progress)
        )
157
-
158
-
159
class SimplexTree2MMA(FilteredComplex2MMA):
    """Deprecated name of `FilteredComplex2MMA`; emits a warning on construction."""

    def __init__(
        self,
        n_jobs: int = -1,
        expand_dim: Optional[int] = None,
        prune_degrees_above: Optional[int] = None,
        progress=False,
        minpres_degrees: Optional[Iterable[int]] = None,
        **persistence_kwargs,
    ):
        # Forward every argument explicitly. Passing ``**locals()`` would hand
        # the collected ``persistence_kwargs`` dict to the parent as a single
        # keyword named "persistence_kwargs" instead of unpacking it.
        super().__init__(
            n_jobs=n_jobs,
            expand_dim=expand_dim,
            prune_degrees_above=prune_degrees_above,
            progress=progress,
            minpres_degrees=minpres_degrees,
            **persistence_kwargs,
        )
        from warnings import warn

        warn("This class is deprecated, use FilteredComplex2MMA instead.")
179
-
180
-
181
class MMAFormatter(BaseEstimator, TransformerMixin):
    """
    Selects degrees / axes of MMA modules and optionally rescales them.

    Parameters
    ----------
    degrees : list of int, optional
        Degrees to keep. If None, ``0..max_degree`` is inferred in `fit`.
    axis : int, optional
        Axis to select. If None and the input has axes, it is set to -1
        (all axes) by `fit`.
    verbose : bool
        Print the inferred statistics.
    normalize : bool
        If True, per-(axis, degree, parameter) bounds are inferred from the
        data; otherwise the bounds are 0 and 1.
    weights : array-like, optional
        Multiplies the normalization factors; also used as the upper corner
        of the box set on normalized modules.
    quantiles : pair ``(qm, qM)``, optional
        If given (and `normalize`), bounds are the ``qm`` and ``1 - qM``
        nan-quantiles instead of the nan-min / nan-max.
    dump : bool
        Pickle each output element (non-normalizing path of `transform` only).
    from_dump : bool
        Stored only; inputs are recognised as dumps by being ``bytes``.
    """

    def __init__(
        self,
        degrees: Optional[list[int]] = None,
        axis=None,
        verbose: bool = False,
        normalize: bool = False,
        weights=None,
        quantiles=None,
        dump=False,
        from_dump=False,
    ):
        self._module_bounds = None
        self.verbose = verbose
        self.axis = axis
        self._axis = []
        self._has_axis = None
        self._num_axis = 0
        self.degrees = degrees
        self._degrees = None
        self.normalize = normalize
        self._num_parameters = None
        self.weights = weights
        self.quantiles = quantiles
        self.dump = dump
        self.from_dump = from_dump

    @staticmethod
    def _maybe_from_dump(X_in):
        """Unpickle `X_in` element-wise if its first element is ``bytes``; else return it."""
        if len(X_in) == 0:
            return X_in
        import pickle

        if isinstance(X_in[0], bytes):
            # SECURITY: pickle.loads can execute arbitrary code. Only pass
            # dumps coming from a trusted source.
            return [pickle.loads(mods) for mods in X_in]
        return X_in

    @staticmethod
    def _get_module_bound(x, degree):
        """
        Output format : (2,num_parameters)

        First and last unique filtration value, per parameter, of the
        degree-`degree` part of module `x`. NaN everywhere when unavailable.
        """
        filtration_values = x.get_module_of_degree(degree).get_filtration_values(
            unique=True
        )
        out = np.array([[f[0], f[-1]] for f in filtration_values if len(f) > 0]).T
        if len(out) != 2:
            print(f"Missing degree {degree} here !")
            m = M = [np.nan for _ in range(x.num_parameters)]
        else:
            m, M = out
        return m, M

    @staticmethod
    def _infer_axis(X):
        """Return True when ``X[0]`` is a container of modules rather than a module."""
        has_axis = not isinstance(X[0], PyModule_type)
        assert not has_axis or isinstance(X[0][0], PyModule_type)
        return has_axis

    @staticmethod
    def _infer_num_parameters(X, ax=slice(None)):
        """Number of parameters of the module ``X[0][ax]``."""
        return X[0][ax].num_parameters

    @staticmethod
    def _infer_bounds(X, degrees=None, axis=(slice(None),), quantiles=None):
        """
        Compute bounds of filtration values of a list of modules.

        Output Format
        -------------
        m,M of shape : (num_axis,num_degrees,num_parameters)
        """
        if degrees is None:
            degrees = np.arange(X[0][axis[0]].max_degree + 1)
        # (num_pts, num_axis, num_degrees, min/max, num_parameters)
        bounds = np.array(
            [
                [
                    [MMAFormatter._get_module_bound(x[ax], degree) for degree in degrees]
                    for ax in axis
                ]
                for x in X
            ]
        )
        lower = bounds[:, :, :, 0, :]
        upper = bounds[:, :, :, 1, :]
        # Reduce over the data points, ignoring NaNs left by missing degrees.
        if quantiles is not None:
            qm, qM = quantiles
            m = np.nanquantile(lower, q=qm, axis=0)
            M = np.nanquantile(upper, q=1 - qM, axis=0)
        else:
            m = np.nanmin(lower, axis=0)
            M = np.nanmax(upper, axis=0)
        return (m, M)

    @staticmethod
    def _infer_grid(
        X: List[PyModule_type], strategy: str, resolution: int, degrees=None
    ):
        """
        Given a list of PyModules, computes a multiparameter discrete grid,
        with a given strategy,
        from the filtration values of the summands of the modules.
        """
        num_parameters = X[0].num_parameters
        if degrees is None:
            # Format here : ((filtration values of parameter) for parameter)
            filtration_values = tuple(
                mod.get_filtration_values(unique=True) for mod in X
            )
        else:
            filtration_values = tuple(
                mod.get_module_of_degrees(degrees).get_filtration_values(unique=True)
                for mod in X
            )

        if "_mean" in strategy:
            # One grid per module with the sub-strategy, then averaged.
            substrategy = strategy.split("_")[0]
            per_module_grids = [
                reduce_grid(f, resolution, substrategy, unique=False)
                for f in filtration_values
            ]
            return np.mean(per_module_grids, axis=0)

        # Pool the values of all modules, per parameter, then reduce once.
        pooled = [
            np.unique(np.concatenate([f[parameter] for f in filtration_values], axis=0))
            for parameter in range(num_parameters)
        ]
        return reduce_grid(pooled, resolution, strategy, unique=True)

    def _infer_degrees(self, X):
        """Set `_degrees` from `degrees`, or to ``0..max`` of the modules' ``max_degree``."""
        if self.degrees is None:
            # The trailing 0 keeps np.max defined when nothing is iterated.
            max_degrees = [x[ax].max_degree for ax in self._axis for x in X] + [0]
            self._degrees = np.arange(np.max(max_degrees) + 1)
        else:
            self._degrees = self.degrees

    def fit(self, X_in, y=None):
        """
        Infer axes, degrees, number of parameters, bounds and normalization
        factors from `X_in` (modules, or pickled modules). No-op on empty input.
        """
        X = self._maybe_from_dump(X_in)
        if len(X) == 0:
            return self
        self._has_axis = self._infer_axis(X)
        if self.axis is None and self._has_axis:
            self.axis = -1
        if self.axis is not None and not (self._has_axis):
            raise Exception(f"SMF didn't find an axis, but requested axis {self.axis}")
        if self._has_axis:
            self._num_axis = len(X[0])
        if self.axis is None:
            self._axis = [slice(None)]
        elif self.axis == -1:
            self._axis = range(self._num_axis)
        else:
            self._axis = [self.axis]
        self._infer_degrees(X)
        # Inferred before the verbose report so the report shows the real value.
        self._num_parameters = self._infer_num_parameters(X, ax=self._axis[0])
        if self.verbose:
            print("-----------MMAFormatter-----------")
            print("---- Infered stats")
            print(f"Found axis : {self._has_axis}, num : {self._num_axis}")
            print(f"Number of parameters : {self._num_parameters}")

        if self.normalize:
            self._module_bounds = self._infer_bounds(
                X, self._degrees, self._axis, self.quantiles
            )
        else:
            m = np.zeros((self._num_axis, len(self._degrees), self._num_parameters))
            M = m + 1
            self._module_bounds = (m, M)
        assert self._num_parameters == self._module_bounds[0].shape[-1]
        m, M = self._module_bounds
        if self.verbose:
            print("---- Bounds (only computed if normalize):")
            if self._has_axis and self._num_axis > 1:
                print("(axis) x (degree) x (parameter)")
            else:
                print("(degree) x (parameter)")
            print("-- Lower bound : ", m.shape)
            print(m)
            print("-- Upper bound :", M.shape)
            print(M)

        w = 1 if self.weights is None else np.asarray(self.weights)
        normalizer = M - m
        zero_normalizer = normalizer == 0
        if np.any(zero_normalizer):
            from warnings import warn

            warn(f"Encountered empty bounds. Please fix me. \n M-m = {normalizer}")
            # Avoid dividing by zero: degenerate ranges are left unscaled.
            normalizer[zero_normalizer] = 1
        self._normalization_factors = w / normalizer
        if self.verbose:
            print("-- Normalization factors:", self._normalization_factors.shape)
            print(self._normalization_factors)

            print("---- Module size :")
            for ax in self._axis:
                print(f"- Axis {ax}")
                for degree in self._degrees:
                    sizes = [len(x[ax].get_module_of_degree(degree)) for x in X]
                    mean_size = np.mean(sizes).round(decimals=2)
                    std_size = np.std(sizes).round(decimals=2)
                    print(f" - Degree {degree} size {mean_size}±{std_size}")
            print("----------------------------------")
        return self

    @staticmethod
    def copy_transform(mod, degrees, translation, rescale_factors, new_box):
        """
        Return ``mod.get_module_of_degrees(degrees)`` after translating then
        rescaling each degree (``translation[j]`` / ``rescale_factors[j]`` for
        ``degrees[j]``) and setting its box to `new_box`.
        """
        copy = mod.get_module_of_degrees(degrees)
        for j, degree in enumerate(degrees):
            copy.translate(translation[j], degree=degree)
            copy.rescale(rescale_factors[j], degree=degree)
        copy.set_box(new_box)
        return copy

    def transform(self, X_in):
        """
        If any normalization factor differs from 1, return, for each element,
        the list over axes of translated / rescaled modules. Otherwise return
        the input restricted to `axis` (unless it is -1), pickled if `dump`.
        """
        X = self._maybe_from_dump(X_in)
        if np.any(self._normalization_factors != 1):
            if self.verbose:
                print("Normalizing...", end="")
            w = (
                [1] * self._num_parameters
                if self.weights is None
                else np.asarray(self.weights)
            )
            standard_box = PyBox_f64([0] * self._num_parameters, w)
            lower_bounds = self._module_bounds[0]
            X_copy = [
                [
                    self.copy_transform(
                        mod=x[ax],
                        degrees=self._degrees,
                        translation=-lower_bounds[i],
                        rescale_factors=self._normalization_factors[i],
                        new_box=standard_box,
                    )
                    for i, ax in enumerate(self._axis)
                ]
                for x in X
            ]
            if self.verbose:
                print("Done.")
            # NOTE(review): `dump` is not applied on this path - confirm intended.
            return X_copy
        if self.axis != -1:
            X = [x[self.axis] for x in X]
        if self.dump:
            import pickle

            X = [pickle.dumps(mods) for mods in X]
        return X
507
- # return [todo(x) for x in X]
508
-
509
-
510
class MMA2IMG(BaseEstimator, TransformerMixin):
    """
    Turns MMA modules into images through their ``representation`` method,
    evaluated on a grid inferred from the fitted modules.

    `bandwidth`, `power` (as ``p``), `normalize`, `box`, `degrees`, `n_jobs`,
    `kernel` and `signed` are forwarded to ``representation``; `resolution`
    and `grid_strategy` are forwarded to ``MMAFormatter._infer_grid``.
    `flatten` selects a single flat vector per element instead of one
    ``(num_degrees, *resolution)`` array per axis. `progress` enables the
    tqdm bar. `plot` is stored only.
    """

    def __init__(
        self,
        degrees: list,
        bandwidth: float = 0.1,
        power: float = 1,
        normalize: bool = False,
        resolution: list | int = 50,
        plot: bool = False,
        box=None,
        n_jobs=-1,
        flatten=False,
        progress=False,
        grid_strategy="regular",
        kernel="linear",
        signed: bool = False,
    ):
        self.bandwidth = bandwidth
        self.degrees = degrees
        self.resolution = resolution
        self.box = box
        self.plot = plot
        self._box = None
        self.normalize = normalize
        self.power = power
        self._has_axis = None
        self._num_parameters = None
        self.n_jobs = n_jobs
        self.flatten = flatten
        self.progress = progress
        self.grid_strategy = grid_strategy
        self._num_axis = None
        self._coords_to_compute = None
        self._new_resolutions = None
        self.kernel = kernel
        self.signed = signed

    def fit(self, X, y=None):
        """Infer the evaluation grid (one per axis when the input has axes)."""
        # TODO infer box
        # TODO rescale module
        self._has_axis = MMAFormatter._infer_axis(X)
        if self._has_axis:
            self._num_axis = len(X[0])
        self._box = [[0, 0], [1, 1]] if self.box is None else self.box
        if self._has_axis:
            per_axis_modules = (
                tuple(x[axis] for x in X) for axis in range(self._num_axis)
            )
            grids = tuple(
                MMAFormatter._infer_grid(
                    X_axis, self.grid_strategy, self.resolution, degrees=self.degrees
                )
                for X_axis in per_axis_modules
            )
            # Not the same resolutions per axis, so cannot be put in an array.
            self._coords_to_compute = grids
            self._new_resolutions = np.asarray(
                [tuple(len(g) for g in G) for G in grids]
            )
        else:
            coords = MMAFormatter._infer_grid(
                X, self.grid_strategy, self.resolution, degrees=self.degrees
            )
            self._coords_to_compute = coords
            self._new_resolutions = np.array([len(g) for g in coords])
        return self

    def transform(self, X):
        """
        Compute the images of every element of `X` on the fitted grid(s).

        Returns the list produced by ``joblib.Parallel``: flat arrays when
        `flatten`, else tuples of ``(num_degrees, *resolution)`` arrays.
        """
        img_args = {
            "bandwidth": self.bandwidth,
            "p": self.power,
            "normalize": self.normalize,
            # NOTE(review): passes `self.box`, not the `_box` set in `fit` -
            # confirm which one is intended.
            "box": self.box,
            "degrees": self.degrees,
            # num_jobs is better for parallel over modules.
            "n_jobs": self.n_jobs,
            "kernel": self.kernel,
            "signed": self.signed,
            "flatten": True,  # custom coordinates
        }
        if self._has_axis:
            resolutions = self._new_resolutions

            def images_of(mods):
                return tuple(
                    mod.representation(grid=c, **img_args)
                    for mod, c in zip(mods, self._coords_to_compute)
                )

        else:
            # Without axes `_new_resolutions` is a flat vector. Wrap it so the
            # single image is paired with the whole vector; iterating the
            # vector itself would try to unpack a scalar in the reshape below.
            resolutions = [self._new_resolutions]

            def images_of(mod):
                # Leading axis of size 1: same layout as the has-axis case.
                return mod.representation(
                    grid=self._coords_to_compute, **img_args
                )[None, :]

        if self.flatten:

            def todo(mods):
                return np.concatenate(images_of(mods), axis=1).flatten()

        else:

            def todo(mods):
                return tuple(
                    img.reshape(len(img_args["degrees"]), *r)
                    for img, r in zip(images_of(mods), resolutions)
                )

        return Parallel(n_jobs=self.n_jobs, backend="threading")(
            delayed(todo)(x)
            for x in tqdm(X, desc="Computing images", disable=not self.progress)
        )  # res depends on ax (infer_grid)
632
-
633
-
634
class MMA2Landscape(BaseEstimator, TransformerMixin):
    """
    Turns a list of MMA approximations into Landscapes vectorisations

    For every module and every degree in `degrees`, ``landscapes(ks=ks,
    resolution=resolution, degree=degree, plot=plot)`` is reduced with
    ``phi(..., axis=0)``, flattened, and the per-degree vectors are
    concatenated.
    """

    def __init__(
        self,
        resolution=[100, 100],
        degrees: list[int] | None = [0, 1],
        ks: Iterable[int] = range(5),
        phi: Callable = np.sum,
        box=None,
        plot: bool = False,
        n_jobs=-1,
        filtration_quantile: float = 0.01,
    ) -> None:
        super().__init__()
        # The list defaults are shared between instances; they are only read.
        self.resolution: list[int] = resolution
        self.degrees = degrees
        self.ks = ks
        self.phi = phi  # Has to have a axis=0 !
        self.box = box
        self.plot = plot
        self.n_jobs = n_jobs
        self.filtration_quantile = filtration_quantile

    def fit(self, X, y=None):
        """
        Check that the modules have 2 parameters and, if `box` is None, set it
        to the ``filtration_quantile`` / ``1 - filtration_quantile`` quantiles
        of the modules' bottoms / tops. Always returns ``self``.
        """
        if len(X) <= 0:
            # Return the estimator, not None, so that calls can be chained.
            return self
        assert (
            X[0].num_parameters == 2
        ), f"Number of parameters {X[0].num_parameters} has to be 2."
        if self.box is None:

            def _bottom(mod):
                return mod.get_bottom()

            def _top(mod):
                return mod.get_top()

            bottoms = Parallel(n_jobs=self.n_jobs, backend="threading")(
                delayed(_bottom)(mod) for mod in X
            )
            tops = Parallel(n_jobs=self.n_jobs, backend="threading")(
                delayed(_top)(mod) for mod in X
            )
            m = np.quantile(bottoms, q=self.filtration_quantile, axis=0)
            M = np.quantile(tops, q=1 - self.filtration_quantile, axis=0)
            self.box = [m, M]
        return self

    def transform(self, X) -> list[np.ndarray]:
        """Return one flat landscape vector per module of `X` (``[]`` if empty)."""
        if len(X) <= 0:
            return []

        def todo(mod):
            per_degree = [
                self.phi(
                    mod.landscapes(
                        ks=self.ks,
                        resolution=self.resolution,
                        degree=degree,
                        plot=self.plot,
                    ),
                    axis=0,
                ).flatten()
                for degree in self.degrees
            ]
            return np.concatenate(per_degree).flatten()

        return Parallel(n_jobs=self.n_jobs, backend="threading")(
            delayed(todo)(x) for x in X
        )
1
+ from typing import Callable, Iterable, List, Optional, Union
2
+
3
+ import numpy as np
4
+ from joblib import Parallel, delayed
5
+ from sklearn.base import BaseEstimator, TransformerMixin
6
+ from tqdm import tqdm
7
+
8
+ import multipers as mp
9
+ import multipers.simplex_tree_multi
10
+ import multipers.slicer
11
+ from multipers.grids import compute_grid as reduce_grid
12
+ from multipers.mma_structures import PyBox_f64, PyModule_type
13
+
14
# Type alias used to annotate the filtered complexes handled below:
# either a slicer or a multi-parameter simplex tree.
_FilteredComplexType = Union[
    mp.slicer.Slicer_type, mp.simplex_tree_multi.SimplexTreeMulti_type
]
17
+
18
+
19
class FilteredComplex2MMA(BaseEstimator, TransformerMixin):
    """
    Turns a list of list of simplextrees or slicers to MMA approximations.

    Every sample of ``X`` is a sequence holding one filtered complex per
    axis (or, for minimal-presentation input, one sequence of slicers per
    axis).  ``fit`` infers one bounding box per axis from the filtration
    bounds of the data; ``transform`` computes one module approximation
    per axis and per sample, using that box.

    Parameters
    ----------
    n_jobs : number of joblib workers (threading backend) used in
        ``transform``.
    expand_dim : if not None, ``x.expansion(expand_dim)`` is called on each
        complex before the approximation.  Rejected for slicer input.
    prune_degrees_above : if not None, ``prune_above_dimension`` is called
        on every input complex with this value.
    progress : show a tqdm progress bar during ``transform``.
    minpres_degrees : if not None, each complex is turned into a slicer and
        replaced by its minimal presentation for these degrees.
    plot : call ``mod.plot()`` on each computed module.
    **persistence_kwargs : forwarded to ``mp.module_approximation``.
    """

    def __init__(
        self,
        n_jobs: int = -1,
        expand_dim: Optional[int] = None,
        prune_degrees_above: Optional[int] = None,
        progress=False,
        minpres_degrees: Optional[Iterable[int]] = None,
        plot: bool = False,
        **persistence_kwargs,
    ) -> None:
        super().__init__()
        self.persistence_args = persistence_kwargs
        self.n_jobs = n_jobs
        self._num_axis = None
        self.prune_degrees_above = prune_degrees_above
        self.progress = progress
        self.expand_dim = expand_dim
        self._boxes = None
        self._is_minpres = None
        self.minpres_degrees = minpres_degrees
        self.plot = plot

    @staticmethod
    def _is_filtered_complex(input):
        """Return True if ``input`` is a multi simplextree or a slicer
        (minimal presentations allowed)."""
        return mp.simplex_tree_multi.is_simplextree_multi(input) or mp.slicer.is_slicer(
            input, allow_minpres=True
        )

    def _input_checks(self, X):
        """
        Validate ``X`` from its first sample and record the number of axes
        and whether the input is already a minimal presentation.
        """
        assert len(X) > 0, "No filtered complex found. Cannot fit."
        assert self._is_filtered_complex(
            X[0][0]
        ), f"X[0] is not a known filtered complex, {X[0]=}, nor X[0][0]."
        self._num_axis = len(X[0])
        first = X[0][0]
        assert (
            not mp.slicer.is_slicer(first) or self.expand_dim is None
        ), "Cannot expand slicers."
        # A plain tuple of classes works on every Python version, whereas
        # isinstance(..., Union[tuple, list]) needs Python >= 3.10.
        self._is_minpres = mp.slicer.is_slicer(first) and isinstance(
            first, (tuple, list)
        )
        assert not (
            self._is_minpres and self.minpres_degrees is not None
        ), "Input is already a minpres. Cannot reduce again."

    def _infer_bounding_box(self, X):
        """
        Compute, for every axis, the box spanned by the filtration bounds
        of all samples, and store the result in ``self._boxes`` as a list
        of ``(2, num_parameters)`` arrays (lower corner, upper corner).
        """
        assert self._num_axis is not None, "Fit first"
        if self._is_minpres:
            # One entry per slicer of each sample's minimal presentation.
            bounds_per_axis = [
                [s.filtration_bounds() for x in X for s in x[axis]]
                for axis in range(self._num_axis)
            ]
        else:
            bounds_per_axis = [
                [x[axis].filtration_bounds() for x in X]
                for axis in range(self._num_axis)
            ]
        # Shape : axis, data, min/max, num_parameters
        filtration_values = np.asarray(bounds_per_axis)
        # Reduce over the data axis; results have shape axis, num_parameters
        m = filtration_values[:, :, 0, :].min(axis=1)
        M = filtration_values[:, :, 1, :].max(axis=1)
        self._boxes = [
            np.array([m_of_axis, M_of_axis]) for m_of_axis, M_of_axis in zip(m, M)
        ]

    def fit(self, X, y=None):
        """Validate ``X`` and infer the per-axis bounding boxes.

        Empty input is accepted and leaves the estimator untouched.
        """
        if len(X) == 0:
            return self
        self._input_checks(X)
        self._infer_bounding_box(X)
        return self

    def transform(self, X):
        """
        Compute the module approximations of every sample.

        Returns a list with, per sample, a tuple holding one module per
        axis.  Note that pruning and expansion are applied to the input
        complexes themselves (no copy is made here).
        """
        if self.prune_degrees_above is not None:
            for x in X:
                for x_ in x:
                    if self._is_minpres:
                        for s_ in x_:
                            s_.prune_above_dimension(
                                self.prune_degrees_above
                            )  # we only do for H0 for computational ease
                    else:
                        x_.prune_above_dimension(
                            self.prune_degrees_above
                        )  # we only do for H0 for computational ease

        def todo1(x, box):
            # One complex (one axis of one sample) -> one module.
            if self.expand_dim is not None:
                x.expansion(self.expand_dim)
            if self.minpres_degrees is not None:
                x = mp.slicer.minimal_presentation(
                    mp.Slicer(x), degrees=self.minpres_degrees, vineyard=True
                )
            mod = mp.module_approximation(
                x, box=box, verbose=False, **self.persistence_args
            )
            if self.plot:
                mod.plot()
            return mod

        def todo(sts: Iterable[_FilteredComplexType]):
            # Pair each axis of the sample with the box fitted for that axis.
            return tuple(todo1(st, box) for st, box in zip(sts, self._boxes))

        return Parallel(n_jobs=self.n_jobs, backend="threading")(
            delayed(todo)(x)
            for x in tqdm(X, desc="Computing modules", disable=not self.progress)
        )
156
+
157
+
158
class SimplexTree2MMA(FilteredComplex2MMA):
    """
    Deprecated name of :class:`FilteredComplex2MMA`.

    Accepts the same arguments and forwards them to the parent class, then
    emits a deprecation warning.
    """

    def __init__(
        self,
        n_jobs: int = -1,
        expand_dim: Optional[int] = None,
        prune_degrees_above: Optional[int] = None,
        progress=False,
        minpres_degrees: Optional[Iterable[int]] = None,
        **persistence_kwargs,
    ):
        # Forward explicitly.  Going through `locals()` passed the whole
        # `persistence_kwargs` dict as a single keyword argument, so the
        # parent stored {"persistence_kwargs": {...}} instead of the
        # user's keyword arguments.
        super().__init__(
            n_jobs=n_jobs,
            expand_dim=expand_dim,
            prune_degrees_above=prune_degrees_above,
            progress=progress,
            minpres_degrees=minpres_degrees,
            **persistence_kwargs,
        )
        from warnings import warn

        # stacklevel=2 attributes the warning to the code instantiating
        # the class rather than to this constructor.
        warn(
            "This class is deprecated, use FilteredComplex2MMA instead.",
            stacklevel=2,
        )
178
+
179
+
180
class MMAFormatter(BaseEstimator, TransformerMixin):
    """
    Formats a list of MMA modules (or of per-axis sequences of modules):
    selects degrees and axes and, when requested, translates and rescales
    the modules so that their bounds fit a standard box.

    Parameters
    ----------
    degrees : homological degrees to keep; inferred from the data when None.
    axis : axis to keep; ``-1`` keeps them all.  When None and the input
        has axes, ``fit`` sets it to ``-1``.
    verbose : print what is inferred during ``fit``.
    normalize : infer the bounds of the modules and rescale with them.
    weights : per-parameter weights of the normalization.
    quantiles : optional pair ``(qm, qM)``; the bounds are then the ``qm``
        and ``1 - qM`` quantiles instead of the min / max.
    dump : pickle each output of ``transform`` (non-normalizing path).
    from_dump : stored but not read anywhere in this class.
    """

    def __init__(
        self,
        degrees: Optional[list[int]] = None,
        axis=None,
        verbose: bool = False,
        normalize: bool = False,
        weights=None,
        quantiles=None,
        dump=False,
        from_dump=False,
    ):
        self._module_bounds = None
        self.verbose = verbose
        self.axis = axis
        self._axis = []
        self._has_axis = None
        self._num_axis = 0
        self.degrees = degrees
        self._degrees = None
        self.normalize = normalize
        self._num_parameters = None
        self.weights = weights
        self.quantiles = quantiles
        self.dump = dump
        self.from_dump = from_dump

    @staticmethod
    def _maybe_from_dump(X_in):
        """Unpickle ``X_in`` if its first element is ``bytes``; otherwise
        return it unchanged."""
        if len(X_in) == 0:
            return X_in
        import pickle

        if isinstance(X_in[0], bytes):
            # NOTE(review): pickle.loads executes arbitrary code; only feed
            # dumps coming from a trusted source.
            X = [pickle.loads(mods) for mods in X_in]
        else:
            X = X_in
        return X

    @staticmethod
    def _get_module_bound(x, degree):
        """
        Smallest and largest filtration value, per parameter, of the
        summands of degree ``degree`` of the module ``x``.

        Output format : (2,num_parameters)
        NaN-filled when the degree has no filtration value.
        """
        filtration_values = x.get_module_of_degree(degree).get_filtration_values(
            unique=True
        )
        # f[0] / f[-1] are used as the min / max of each parameter
        # (presumably the values come back sorted -- TODO confirm).
        out = np.array([[f[0], f[-1]] for f in filtration_values if len(f) > 0]).T
        if len(out) != 2:
            print(f"Missing degree {degree} here !")
            m = M = [np.nan for _ in range(x.num_parameters)]
        else:
            m, M = out
        return m, M

    @staticmethod
    def _infer_axis(X):
        """Return True when the samples of ``X`` are sequences of modules
        rather than modules."""
        has_axis = not isinstance(X[0], PyModule_type)
        assert not has_axis or isinstance(X[0][0], PyModule_type)
        return has_axis

    @staticmethod
    def _infer_num_parameters(X, ax=slice(None)):
        """Number of parameters of the module found at ``X[0][ax]``."""
        return X[0][ax].num_parameters

    @staticmethod
    def _infer_bounds(X, degrees=None, axis=[slice(None)], quantiles=None):
        """
        Compute bounds of filtration values of a list of modules.

        Output Format
        -------------
        m,M of shape : (num_axis,num_degrees,num_parameters)
        """
        if degrees is None:
            degrees = np.arange(X[0][axis[0]].max_degree + 1)
        # Shape : (num_pts, num_axis, num_degrees, 2, num_parameters)
        bounds = np.array(
            [
                [
                    [
                        MMAFormatter._get_module_bound(x[ax], degree)
                        for degree in degrees
                    ]
                    for ax in axis
                ]
                for x in X
            ]
        )
        lower = bounds[:, :, :, 0, :]
        upper = bounds[:, :, :, 1, :]
        # Reduce over the samples (axis 0), independently for every
        # (axis, degree, parameter); NaNs mark missing degrees and are
        # ignored.
        if quantiles is not None:
            qm, qM = quantiles
            m = np.nanquantile(lower, q=qm, axis=0)
            M = np.nanquantile(upper, q=1 - qM, axis=0)
        else:
            m = np.nanmin(lower, axis=0)
            M = np.nanmax(upper, axis=0)
        return (m, M)

    @staticmethod
    def _infer_grid(
        X: List[PyModule_type], strategy: str, resolution: int, degrees=None
    ):
        """
        Given a list of PyModules, computes a multiparameter discrete grid,
        with a given strategy,
        from the filtration values of the summands of the modules.

        A strategy containing ``"_mean"`` reduces the grid of each module
        with the sub-strategy (the text before the first underscore) and
        averages the grids; any other strategy pools the unique values of
        all modules, per parameter, before reducing once.
        """
        num_parameters = X[0].num_parameters
        if degrees is None:
            # Format here : ((filtration values of parameter) for parameter)
            filtration_values = tuple(
                mod.get_filtration_values(unique=True) for mod in X
            )
        else:
            filtration_values = tuple(
                mod.get_module_of_degrees(degrees).get_filtration_values(unique=True)
                for mod in X
            )

        if "_mean" in strategy:
            substrategy = strategy.split("_")[0]
            processed_filtration_values = [
                reduce_grid(f, resolution, substrategy, unique=False)
                for f in filtration_values
            ]
            reduced_grid = np.mean(processed_filtration_values, axis=0)
        else:
            filtration_values = [
                np.unique(
                    np.concatenate([f[parameter] for f in filtration_values], axis=0)
                )
                for parameter in range(num_parameters)
            ]
            reduced_grid = reduce_grid(
                filtration_values, resolution, strategy, unique=True
            )

        return reduced_grid

    def _infer_degrees(self, X):
        """Set ``self._degrees``: the requested degrees, or every degree up
        to the largest ``max_degree`` found in ``X``."""
        if self.degrees is None:
            # The trailing [0] keeps np.max well defined.
            max_degrees = [x[ax].max_degree for ax in self._axis for x in X] + [0]
            self._degrees = np.arange(np.max(max_degrees) + 1)
        else:
            self._degrees = self.degrees

    def fit(self, X_in, y=None):
        """
        Infer axes, degrees, number of parameters and module bounds from
        the data, and derive the normalization factors.

        Raises
        ------
        ValueError
            If an axis was requested but the input has no axis.
        """
        X = self._maybe_from_dump(X_in)
        if len(X) == 0:
            return self
        self._has_axis = self._infer_axis(X)
        if self.axis is None and self._has_axis:
            self.axis = -1
        if self.axis is not None and not (self._has_axis):
            raise ValueError(
                f"MMAFormatter didn't find an axis, but requested axis {self.axis}"
            )
        if self._has_axis:
            self._num_axis = len(X[0])
        self._axis = (
            [slice(None)]
            if self.axis is None
            else range(self._num_axis) if self.axis == -1 else [self.axis]
        )
        self._infer_degrees(X)
        self._num_parameters = self._infer_num_parameters(X, ax=self._axis[0])
        if self.verbose:
            # Printed once the number of parameters is known; it used to
            # be printed before being inferred.
            print("-----------MMAFormatter-----------")
            print("---- Infered stats")
            print(f"Found axis : {self._has_axis}, num : {self._num_axis}")
            print(f"Number of parameters : {self._num_parameters}")

        if self.normalize:
            self._module_bounds = self._infer_bounds(
                X, self._degrees, self._axis, self.quantiles
            )
        else:
            # Trivial bounds [0, 1] : the normalizer below is then 1.
            m = np.zeros((self._num_axis, len(self._degrees), self._num_parameters))
            M = m + 1
            self._module_bounds = (m, M)
        assert self._num_parameters == self._module_bounds[0].shape[-1]
        if self.verbose:
            print("---- Bounds (only computed if normalize):")
            if self._has_axis and self._num_axis > 1:
                print("(axis) x (degree) x (parameter)")
            else:
                print("(degree) x (parameter)")
            m, M = self._module_bounds
            print("-- Lower bound : ", m.shape)
            print(m)
            print("-- Upper bound :", M.shape)
            print(M)
        w = 1 if self.weights is None else np.asarray(self.weights)
        m, M = self._module_bounds
        normalizer = M - m
        zero_normalizer = normalizer == 0
        if np.any(zero_normalizer):
            from warnings import warn

            warn(f"Encountered empty bounds. Please fix me. \n M-m = {normalizer}")
        # Avoid dividing by zero on degenerate bounds.
        normalizer[zero_normalizer] = 1
        self._normalization_factors = w / normalizer
        if self.verbose:
            print("-- Normalization factors:", self._normalization_factors.shape)
            print(self._normalization_factors)

            print("---- Module size :")
            for ax in self._axis:
                print(f"- Axis {ax}")
                for degree in self._degrees:
                    sizes = [len(x[ax].get_module_of_degree(degree)) for x in X]
                    mean_size = np.mean(sizes).round(decimals=2)
                    std_size = np.std(sizes).round(decimals=2)
                    print(f" - Degree {degree} size {mean_size}±{std_size}")
            print("----------------------------------")
        return self

    @staticmethod
    def copy_transform(mod, degrees, translation, rescale_factors, new_box):
        """
        Return the sub-module of ``mod`` for ``degrees``, translated then
        rescaled degree by degree (entry ``j`` of ``translation`` and
        ``rescale_factors`` applies to ``degrees[j]``), with its box set
        to ``new_box``.
        """
        copy = mod.get_module_of_degrees(
            degrees
        )  # and only returns the specific degrees
        for j, degree in enumerate(degrees):
            copy.translate(translation[j], degree=degree)
            copy.rescale(rescale_factors[j], degree=degree)
        copy.set_box(new_box)
        return copy

    def transform(self, X_in):
        """
        Apply the fitted formatting.

        When any normalization factor differs from 1, returns for each
        sample a list with one translated and rescaled module per selected
        axis.  Otherwise returns the input restricted to the requested
        axis, pickled if ``dump`` is set.
        """
        X = self._maybe_from_dump(X_in)
        if np.any(self._normalization_factors != 1):
            if self.verbose:
                print("Normalizing...", end="")
            w = (
                [1] * self._num_parameters
                if self.weights is None
                else np.asarray(self.weights)
            )
            standard_box = PyBox_f64([0] * self._num_parameters, w)

            X_copy = [
                [
                    self.copy_transform(
                        mod=x[ax],
                        degrees=self._degrees,
                        translation=-self._module_bounds[0][i],
                        rescale_factors=self._normalization_factors[i],
                        new_box=standard_box,
                    )
                    for i, ax in enumerate(self._axis)
                ]
                for x in X
            ]
            if self.verbose:
                print("Done.")
            # NOTE(review): this path returns before the axis selection
            # and the `dump` handling below -- confirm that is intended.
            return X_copy
        # NOTE(review): when the input has no axis, self.axis is None and
        # this indexes each module with None -- confirm this is supported.
        if self.axis != -1:
            X = [x[self.axis] for x in X]
        if self.dump:
            import pickle

            X = [pickle.dumps(mods) for mods in X]
        return X
507
+
508
+
509
class MMA2IMG(BaseEstimator, TransformerMixin):
    """
    Turns MMA modules (or per-axis sequences of modules) into images, by
    evaluating ``module.representation`` on a grid inferred at fit time.

    Parameters
    ----------
    degrees : homological degrees to represent.
    bandwidth, power, normalize, kernel, signed : forwarded to
        ``representation`` (``power`` is passed as ``p``).
    resolution : resolution given to the grid inference.
    plot : stored; not used by ``transform``.
    box : forwarded to ``representation``.
    n_jobs : joblib workers (threading backend); also forwarded to
        ``representation``.
    flatten : return one flat vector per sample instead of a tuple of
        ``(num_degrees, *resolution)`` images.
    progress : show a tqdm progress bar.
    grid_strategy : strategy given to ``MMAFormatter._infer_grid``.
    """

    def __init__(
        self,
        degrees: list,
        bandwidth: float = 0.1,
        power: float = 1,
        normalize: bool = False,
        resolution: list | int = 50,
        plot: bool = False,
        box=None,
        n_jobs=-1,
        flatten=False,
        progress=False,
        grid_strategy="regular",
        kernel="linear",
        signed: bool = False,
    ):
        self.bandwidth = bandwidth
        self.degrees = degrees
        self.resolution = resolution
        self.box = box
        self.plot = plot
        self._box = None
        self.normalize = normalize
        self.power = power
        self._has_axis = None
        self._num_parameters = None
        self.n_jobs = n_jobs
        self.flatten = flatten
        self.progress = progress
        self.grid_strategy = grid_strategy
        self._num_axis = None
        self._coords_to_compute = None
        self._new_resolutions = None
        self.kernel = kernel
        self.signed = signed

    def fit(self, X, y=None):
        """
        Infer the evaluation grid (one per axis when the samples have axes)
        and the resulting image resolutions.
        """
        # TODO infer box
        # TODO rescale module
        self._has_axis = MMAFormatter._infer_axis(X)
        if self._has_axis:
            self._num_axis = len(X[0])
        if self.box is None:
            self._box = [[0, 0], [1, 1]]
        else:
            self._box = self.box
        if self._has_axis:
            # Regroup the modules axis by axis: one grid per axis.
            its = (tuple(x[axis] for x in X) for axis in range(self._num_axis))
            crs = tuple(
                MMAFormatter._infer_grid(
                    X_axis, self.grid_strategy, self.resolution, degrees=self.degrees
                )
                for X_axis in its
            )
            self._coords_to_compute = (
                crs  # not the same resolutions, so cannot be put in an array
            )
            self._new_resolutions = np.asarray([tuple(len(g) for g in G) for G in crs])
        else:
            coords = MMAFormatter._infer_grid(
                X, self.grid_strategy, self.resolution, degrees=self.degrees
            )
            self._coords_to_compute = coords
            self._new_resolutions = np.array([len(g) for g in coords])
        return self

    def transform(self, X):
        """
        Compute the images of every sample.

        Returns a list with, per sample, either one flat vector
        (``flatten=True``) or a tuple of arrays of shape
        ``(len(degrees), *resolution)``, one per axis.
        """
        img_args = {
            "bandwidth": self.bandwidth,
            "p": self.power,
            "normalize": self.normalize,
            # NOTE(review): `fit` stores a default box in self._box, but
            # the user-given self.box (possibly None) is what is forwarded.
            "box": self.box,
            "degrees": self.degrees,
            # num_jobs is better for parallel over modules.
            "n_jobs": self.n_jobs,
            "kernel": self.kernel,
            "signed": self.signed,
            "flatten": True,  # custom coordinates
        }
        if self._has_axis:

            def todo2(mods):
                # One representation per axis, each on its own grid.
                return tuple(
                    mod.representation(grid=c, **img_args)
                    for mod, c in zip(mods, self._coords_to_compute)
                )

        else:

            def todo2(mod):
                return mod.representation(grid=self._coords_to_compute, **img_args)[
                    None, :
                ]  # shape same as has_axis

        # One row of resolutions per axis.  Without axes `fit` stores a
        # 1-D array; iterating it directly yields scalars, which cannot be
        # unpacked in the reshape below.
        resolutions = np.atleast_2d(self._new_resolutions)

        if self.flatten:

            def todo(mods):
                return np.concatenate(todo2(mods), axis=1).flatten()

        else:

            def todo(mods):
                return tuple(
                    img.reshape(len(img_args["degrees"]), *r)
                    for img, r in zip(todo2(mods), resolutions)
                )

        return Parallel(n_jobs=self.n_jobs, backend="threading")(
            delayed(todo)(x)
            for x in tqdm(X, desc="Computing images", disable=not self.progress)
        )  # res depends on ax (infer_grid)
631
+
632
+
633
class MMA2Landscape(BaseEstimator, TransformerMixin):
    """
    Turns a list of MMA approximations into Landscapes vectorisations

    Parameters
    ----------
    resolution : resolution forwarded to ``mod.landscapes``.
    degrees : homological degrees for which landscapes are computed.
    ks : landscape indices forwarded to ``mod.landscapes``.
    phi : reduction applied to the landscapes of one degree; it is called
        as ``phi(landscapes, axis=0)``, so it has to accept ``axis``.
    box : optional bounding box; when None, ``fit`` infers one from the
        data and stores it in ``self.box``.
    plot : forwarded to ``mod.landscapes``.
    n_jobs : number of joblib workers (threading backend).
    filtration_quantile : quantile used on the modules' bottom/top corners
        when the box is inferred.
    """

    def __init__(
        self,
        resolution=[100, 100],
        degrees: list[int] | None = [0, 1],
        ks: Iterable[int] = range(5),
        phi: Callable = np.sum,
        box=None,
        plot: bool = False,
        n_jobs=-1,
        filtration_quantile: float = 0.01,
    ) -> None:
        super().__init__()
        self.resolution: list[int] = resolution
        self.degrees = degrees
        self.ks = ks
        self.phi = phi  # Has to have a axis=0 !
        self.box = box
        self.plot = plot
        self.n_jobs = n_jobs
        self.filtration_quantile = filtration_quantile

    def fit(self, X, y=None):
        """
        Check that the modules are 2-parameter and infer the box if needed.

        Returns ``self`` in every case, including empty input, so that
        ``fit(X).transform(X)`` and pipelines keep working.
        """
        if len(X) <= 0:
            # The previous bare `return` handed None back to the caller.
            return self
        assert (
            X[0].num_parameters == 2
        ), f"Number of parameters {X[0].num_parameters} has to be 2."
        if self.box is None:

            def _bottom(mod):
                return mod.get_bottom()

            def _top(mod):
                return mod.get_top()

            # Quantiles (rather than min/max) of the per-module corners.
            m = np.quantile(
                Parallel(n_jobs=self.n_jobs, backend="threading")(
                    delayed(_bottom)(mod) for mod in X
                ),
                q=self.filtration_quantile,
                axis=0,
            )
            M = np.quantile(
                Parallel(n_jobs=self.n_jobs, backend="threading")(
                    delayed(_top)(mod) for mod in X
                ),
                q=1 - self.filtration_quantile,
                axis=0,
            )
            # NOTE(review): this box is stored but `transform` does not pass
            # it to `mod.landscapes` -- confirm whether that is intended.
            self.box = [m, M]
        return self

    def transform(self, X) -> list[np.ndarray]:
        """
        Return, for each module, the concatenation over ``self.degrees`` of
        the flattened ``phi``-reduced landscapes.
        """
        if len(X) <= 0:
            return []

        def todo(mod):
            return np.concatenate(
                [
                    self.phi(
                        mod.landscapes(
                            ks=self.ks,
                            resolution=self.resolution,
                            degree=degree,
                            plot=self.plot,
                        ),
                        axis=0,
                    ).flatten()
                    for degree in self.degrees
                ]
            ).flatten()

        return Parallel(n_jobs=self.n_jobs, backend="threading")(
            delayed(todo)(x) for x in X
        )