multipers 2.2.3__cp310-cp310-win_amd64.whl → 2.3.1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of multipers might be problematic. Click here for more details.
- multipers/__init__.py +33 -31
- multipers/_signed_measure_meta.py +430 -430
- multipers/_slicer_meta.py +211 -212
- multipers/data/MOL2.py +458 -458
- multipers/data/UCR.py +18 -18
- multipers/data/graphs.py +466 -466
- multipers/data/immuno_regions.py +27 -27
- multipers/data/pytorch2simplextree.py +90 -90
- multipers/data/shape3d.py +101 -101
- multipers/data/synthetic.py +113 -111
- multipers/distances.py +198 -198
- multipers/filtration_conversions.pxd.tp +84 -84
- multipers/filtrations/__init__.py +18 -0
- multipers/{ml/convolutions.py → filtrations/density.py} +563 -520
- multipers/filtrations/filtrations.py +289 -0
- multipers/filtrations.pxd +224 -224
- multipers/function_rips.cp310-win_amd64.pyd +0 -0
- multipers/function_rips.pyx +105 -105
- multipers/grids.cp310-win_amd64.pyd +0 -0
- multipers/grids.pyx +350 -350
- multipers/gudhi/Persistence_slices_interface.h +132 -132
- multipers/gudhi/Simplex_tree_interface.h +239 -245
- multipers/gudhi/Simplex_tree_multi_interface.h +516 -561
- multipers/gudhi/cubical_to_boundary.h +59 -59
- multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -450
- multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -1070
- multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -579
- multipers/gudhi/gudhi/Debug_utils.h +45 -45
- multipers/gudhi/gudhi/Fields/Multi_field.h +484 -484
- multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -455
- multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -450
- multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -531
- multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -507
- multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -531
- multipers/gudhi/gudhi/Fields/Z2_field.h +355 -355
- multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -376
- multipers/gudhi/gudhi/Fields/Zp_field.h +420 -420
- multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -400
- multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -418
- multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -337
- multipers/gudhi/gudhi/Matrix.h +2107 -2107
- multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -1038
- multipers/gudhi/gudhi/Multi_persistence/Box.h +171 -171
- multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -282
- multipers/gudhi/gudhi/Off_reader.h +173 -173
- multipers/gudhi/gudhi/One_critical_filtration.h +1433 -1431
- multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -769
- multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -686
- multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -842
- multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -1350
- multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -1105
- multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -859
- multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -910
- multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -139
- multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -230
- multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -211
- multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -60
- multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -60
- multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -136
- multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -190
- multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -616
- multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -150
- multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -106
- multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -219
- multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -327
- multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -1140
- multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -934
- multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -934
- multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -980
- multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -1092
- multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -192
- multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -921
- multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -1093
- multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -1012
- multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -1244
- multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -186
- multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -164
- multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -156
- multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -376
- multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -540
- multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -118
- multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -173
- multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -128
- multipers/gudhi/gudhi/Persistent_cohomology.h +745 -745
- multipers/gudhi/gudhi/Points_off_io.h +171 -171
- multipers/gudhi/gudhi/Simple_object_pool.h +69 -69
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -463
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -83
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -106
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -277
- multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -62
- multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -27
- multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -62
- multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -157
- multipers/gudhi/gudhi/Simplex_tree.h +2794 -2794
- multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -163
- multipers/gudhi/gudhi/distance_functions.h +62 -62
- multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -104
- multipers/gudhi/gudhi/persistence_interval.h +253 -253
- multipers/gudhi/gudhi/persistence_matrix_options.h +170 -170
- multipers/gudhi/gudhi/reader_utils.h +367 -367
- multipers/gudhi/mma_interface_coh.h +256 -255
- multipers/gudhi/mma_interface_h0.h +223 -231
- multipers/gudhi/mma_interface_matrix.h +291 -282
- multipers/gudhi/naive_merge_tree.h +536 -575
- multipers/gudhi/scc_io.h +310 -289
- multipers/gudhi/truc.h +957 -888
- multipers/io.cp310-win_amd64.pyd +0 -0
- multipers/io.pyx +714 -711
- multipers/ml/accuracies.py +90 -90
- multipers/ml/invariants_with_persistable.py +79 -79
- multipers/ml/kernels.py +176 -176
- multipers/ml/mma.py +713 -714
- multipers/ml/one.py +472 -472
- multipers/ml/point_clouds.py +352 -346
- multipers/ml/signed_measures.py +1589 -1589
- multipers/ml/sliced_wasserstein.py +461 -461
- multipers/ml/tools.py +113 -113
- multipers/mma_structures.cp310-win_amd64.pyd +0 -0
- multipers/mma_structures.pxd +127 -127
- multipers/mma_structures.pyx +4 -8
- multipers/mma_structures.pyx.tp +1083 -1085
- multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -93
- multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -97
- multipers/multi_parameter_rank_invariant/function_rips.h +322 -322
- multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -769
- multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -148
- multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -369
- multipers/multiparameter_edge_collapse.py +41 -41
- multipers/multiparameter_module_approximation/approximation.h +2298 -2295
- multipers/multiparameter_module_approximation/combinatory.h +129 -129
- multipers/multiparameter_module_approximation/debug.h +107 -107
- multipers/multiparameter_module_approximation/format_python-cpp.h +286 -286
- multipers/multiparameter_module_approximation/heap_column.h +238 -238
- multipers/multiparameter_module_approximation/images.h +79 -79
- multipers/multiparameter_module_approximation/list_column.h +174 -174
- multipers/multiparameter_module_approximation/list_column_2.h +232 -232
- multipers/multiparameter_module_approximation/ru_matrix.h +347 -347
- multipers/multiparameter_module_approximation/set_column.h +135 -135
- multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -36
- multipers/multiparameter_module_approximation/unordered_set_column.h +166 -166
- multipers/multiparameter_module_approximation/utilities.h +403 -419
- multipers/multiparameter_module_approximation/vector_column.h +223 -223
- multipers/multiparameter_module_approximation/vector_matrix.h +331 -331
- multipers/multiparameter_module_approximation/vineyards.h +464 -464
- multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -649
- multipers/multiparameter_module_approximation.cp310-win_amd64.pyd +0 -0
- multipers/multiparameter_module_approximation.pyx +218 -217
- multipers/pickle.py +90 -53
- multipers/plots.py +342 -334
- multipers/point_measure.cp310-win_amd64.pyd +0 -0
- multipers/point_measure.pyx +322 -320
- multipers/simplex_tree_multi.cp310-win_amd64.pyd +0 -0
- multipers/simplex_tree_multi.pxd +133 -133
- multipers/simplex_tree_multi.pyx +115 -48
- multipers/simplex_tree_multi.pyx.tp +1947 -1935
- multipers/slicer.cp310-win_amd64.pyd +0 -0
- multipers/slicer.pxd +301 -120
- multipers/slicer.pxd.tp +218 -214
- multipers/slicer.pyx +1570 -507
- multipers/slicer.pyx.tp +931 -914
- multipers/tensor/tensor.h +672 -672
- multipers/tensor.pxd +13 -13
- multipers/test.pyx +44 -44
- multipers/tests/__init__.py +57 -57
- multipers/torch/diff_grids.py +217 -217
- multipers/torch/rips_density.py +310 -304
- {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/LICENSE +21 -21
- {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/METADATA +21 -11
- multipers-2.3.1.dist-info/RECORD +182 -0
- {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/WHEEL +1 -1
- multipers/tests/test_diff_helper.py +0 -73
- multipers/tests/test_hilbert_function.py +0 -82
- multipers/tests/test_mma.py +0 -83
- multipers/tests/test_point_clouds.py +0 -49
- multipers/tests/test_python-cpp_conversion.py +0 -82
- multipers/tests/test_signed_betti.py +0 -181
- multipers/tests/test_signed_measure.py +0 -89
- multipers/tests/test_simplextreemulti.py +0 -221
- multipers/tests/test_slicer.py +0 -221
- multipers-2.2.3.dist-info/RECORD +0 -189
- {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/top_level.txt +0 -0
multipers/torch/rips_density.py
CHANGED
|
@@ -1,304 +1,310 @@
|
|
|
1
|
-
from typing import Callable, Literal, Optional
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
import torch
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
import multipers as mp
|
|
8
|
-
from multipers.
|
|
9
|
-
from multipers.simplex_tree_multi import _available_strategies
|
|
10
|
-
from multipers.torch.diff_grids import get_grid
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
def function_rips_signed_measure_old(
|
|
14
|
-
x,
|
|
15
|
-
theta: Optional[float] = None,
|
|
16
|
-
function: Literal["dtm", "gaussian", "exponential"] | Callable = "dtm",
|
|
17
|
-
threshold: float = np.inf,
|
|
18
|
-
grid_strategy: _available_strategies = "regular_closest",
|
|
19
|
-
resolution: int = 100,
|
|
20
|
-
return_original: bool = False,
|
|
21
|
-
return_st: bool = False,
|
|
22
|
-
safe_conversion: bool = False,
|
|
23
|
-
num_collapses: int = -1,
|
|
24
|
-
expand_collapse: bool = False,
|
|
25
|
-
dtype=torch.float32,
|
|
26
|
-
**sm_kwargs,
|
|
27
|
-
):
|
|
28
|
-
"""
|
|
29
|
-
Computes a torch-differentiable function-rips signed measure.
|
|
30
|
-
|
|
31
|
-
Input
|
|
32
|
-
-----
|
|
33
|
-
- x (num_pts, dim) : The point cloud
|
|
34
|
-
- theta: For density-like functions : the bandwidth
|
|
35
|
-
- threshold : rips threshold
|
|
36
|
-
- function : Either "dtm", "gaussian", or "exponenetial" or Callable.
|
|
37
|
-
Function to compute the second parameter.
|
|
38
|
-
- grid_strategy: grid coarsenning strategy.
|
|
39
|
-
- resolution : when coarsenning, the target resolution,
|
|
40
|
-
- return_original : Also returns the non-differentiable signed measure.
|
|
41
|
-
- safe_conversion : Activate this if you encounter crashes.
|
|
42
|
-
- **kwargs : for the signed measure computation.
|
|
43
|
-
"""
|
|
44
|
-
assert isinstance(x, torch.Tensor)
|
|
45
|
-
if function == "dtm":
|
|
46
|
-
assert theta is not None, "Provide a theta to compute DTM"
|
|
47
|
-
codensity = DTM(masses=[theta]).fit(x).score_samples_diff(x)[0].type(dtype)
|
|
48
|
-
elif function in ["gaussian", "exponential"]:
|
|
49
|
-
assert theta is not None, "Provide a theta to compute density estimation"
|
|
50
|
-
codensity = (
|
|
51
|
-
-KDE(
|
|
52
|
-
bandwidth=theta,
|
|
53
|
-
kernel=function,
|
|
54
|
-
return_log=True,
|
|
55
|
-
)
|
|
56
|
-
.fit(x)
|
|
57
|
-
.score_samples(x)
|
|
58
|
-
.type(dtype)
|
|
59
|
-
)
|
|
60
|
-
else:
|
|
61
|
-
assert callable(function), "Function has to be callable"
|
|
62
|
-
if theta is None:
|
|
63
|
-
codensity = function(x).type(dtype)
|
|
64
|
-
else:
|
|
65
|
-
codensity = function(x, theta=theta).type(dtype)
|
|
66
|
-
|
|
67
|
-
distance_matrix = torch.cdist(x, x).type(dtype)
|
|
68
|
-
if threshold < np.inf:
|
|
69
|
-
distance_matrix[distance_matrix > threshold] = np.inf
|
|
70
|
-
|
|
71
|
-
st = RipsComplex(
|
|
72
|
-
|
|
73
|
-
).create_simplex_tree()
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
#
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
if
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
] =
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
-
|
|
186
|
-
|
|
187
|
-
-
|
|
188
|
-
-
|
|
189
|
-
|
|
190
|
-
-
|
|
191
|
-
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
if
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
assert theta is not None, "Provide a theta to compute
|
|
201
|
-
codensity = (
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
distances
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
#
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
st.
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
s =
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
1
|
+
from typing import Callable, Literal, Optional
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
import gudhi as gd
|
|
6
|
+
|
|
7
|
+
import multipers as mp
|
|
8
|
+
from multipers.filtrations.density import DTM, KDE
|
|
9
|
+
from multipers.simplex_tree_multi import _available_strategies
|
|
10
|
+
from multipers.torch.diff_grids import get_grid
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def function_rips_signed_measure_old(
|
|
14
|
+
x,
|
|
15
|
+
theta: Optional[float] = None,
|
|
16
|
+
function: Literal["dtm", "gaussian", "exponential"] | Callable = "dtm",
|
|
17
|
+
threshold: float = np.inf,
|
|
18
|
+
grid_strategy: _available_strategies = "regular_closest",
|
|
19
|
+
resolution: int = 100,
|
|
20
|
+
return_original: bool = False,
|
|
21
|
+
return_st: bool = False,
|
|
22
|
+
safe_conversion: bool = False,
|
|
23
|
+
num_collapses: int = -1,
|
|
24
|
+
expand_collapse: bool = False,
|
|
25
|
+
dtype=torch.float32,
|
|
26
|
+
**sm_kwargs,
|
|
27
|
+
):
|
|
28
|
+
"""
|
|
29
|
+
Computes a torch-differentiable function-rips signed measure.
|
|
30
|
+
|
|
31
|
+
Input
|
|
32
|
+
-----
|
|
33
|
+
- x (num_pts, dim) : The point cloud
|
|
34
|
+
- theta: For density-like functions : the bandwidth
|
|
35
|
+
- threshold : rips threshold
|
|
36
|
+
- function : Either "dtm", "gaussian", or "exponenetial" or Callable.
|
|
37
|
+
Function to compute the second parameter.
|
|
38
|
+
- grid_strategy: grid coarsenning strategy.
|
|
39
|
+
- resolution : when coarsenning, the target resolution,
|
|
40
|
+
- return_original : Also returns the non-differentiable signed measure.
|
|
41
|
+
- safe_conversion : Activate this if you encounter crashes.
|
|
42
|
+
- **kwargs : for the signed measure computation.
|
|
43
|
+
"""
|
|
44
|
+
assert isinstance(x, torch.Tensor)
|
|
45
|
+
if function == "dtm":
|
|
46
|
+
assert theta is not None, "Provide a theta to compute DTM"
|
|
47
|
+
codensity = DTM(masses=[theta]).fit(x).score_samples_diff(x)[0].type(dtype)
|
|
48
|
+
elif function in ["gaussian", "exponential"]:
|
|
49
|
+
assert theta is not None, "Provide a theta to compute density estimation"
|
|
50
|
+
codensity = (
|
|
51
|
+
-KDE(
|
|
52
|
+
bandwidth=theta,
|
|
53
|
+
kernel=function,
|
|
54
|
+
return_log=True,
|
|
55
|
+
)
|
|
56
|
+
.fit(x)
|
|
57
|
+
.score_samples(x)
|
|
58
|
+
.type(dtype)
|
|
59
|
+
)
|
|
60
|
+
else:
|
|
61
|
+
assert callable(function), "Function has to be callable"
|
|
62
|
+
if theta is None:
|
|
63
|
+
codensity = function(x).type(dtype)
|
|
64
|
+
else:
|
|
65
|
+
codensity = function(x, theta=theta).type(dtype)
|
|
66
|
+
|
|
67
|
+
distance_matrix = torch.cdist(x, x).type(dtype)
|
|
68
|
+
if threshold < np.inf:
|
|
69
|
+
distance_matrix[distance_matrix > threshold] = np.inf
|
|
70
|
+
|
|
71
|
+
# st = RipsComplex(
|
|
72
|
+
# distance_matrix=distance_matrix.detach(), max_edge_length=threshold
|
|
73
|
+
# ).create_simplex_tree()
|
|
74
|
+
st = gd.SimplexTree.create_from_array(
|
|
75
|
+
distance_matrix.detach(), max_filtration=threshold
|
|
76
|
+
)
|
|
77
|
+
# detach makes a new (reference) tensor, without tracking the gradient
|
|
78
|
+
st = mp.SimplexTreeMulti(st, num_parameters=2, safe_conversion=safe_conversion)
|
|
79
|
+
st.fill_lowerstar(
|
|
80
|
+
codensity.detach(), parameter=1
|
|
81
|
+
) # fills the codensity in the second parameter of the simplextree
|
|
82
|
+
|
|
83
|
+
# simplificates the simplextree for computation, the signed measure will be recovered from the copy afterward
|
|
84
|
+
st_copy = st.grid_squeeze(
|
|
85
|
+
grid_strategy=grid_strategy, resolution=resolution, coordinate_values=True
|
|
86
|
+
)
|
|
87
|
+
if sm_kwargs.get("degree", None) is None and sm_kwargs.get("degrees", [None]) == [
|
|
88
|
+
None
|
|
89
|
+
]:
|
|
90
|
+
expansion_degree = st.num_vertices
|
|
91
|
+
else:
|
|
92
|
+
expansion_degree = (
|
|
93
|
+
max(np.max(sm_kwargs.get("degrees", 1)), sm_kwargs.get("degree", 1)) + 1
|
|
94
|
+
)
|
|
95
|
+
st.collapse_edges(num=num_collapses)
|
|
96
|
+
if not expand_collapse:
|
|
97
|
+
st.expansion(expansion_degree) # edge collapse
|
|
98
|
+
sms = mp.signed_measure(st, **sm_kwargs) # computes the signed measure
|
|
99
|
+
del st
|
|
100
|
+
|
|
101
|
+
simplices_list = tuple(
|
|
102
|
+
s for s, _ in st_copy.get_simplices()
|
|
103
|
+
) # not optimal, we may want to do that in cython to get edges and nodes
|
|
104
|
+
sms_diff = []
|
|
105
|
+
for sm, weights in sms:
|
|
106
|
+
indices, not_found_indices = st_copy.pts_to_indices(
|
|
107
|
+
sm, simplices_dimensions=[1, 0]
|
|
108
|
+
)
|
|
109
|
+
if sm_kwargs.get("verbose", False):
|
|
110
|
+
print(
|
|
111
|
+
f"Found {(1-(indices == -1).mean()).round(2)} indices. \
|
|
112
|
+
Out : {(indices == -1).sum()}, {len(not_found_indices)}"
|
|
113
|
+
)
|
|
114
|
+
sm_diff = torch.empty(sm.shape).type(dtype)
|
|
115
|
+
# sim_dim = sm_diff.shape[1] // 2
|
|
116
|
+
|
|
117
|
+
# fills the Rips-filtrations of the signed measure.
|
|
118
|
+
# the loop is for the rank invariant
|
|
119
|
+
for i in range(0, sm_diff.shape[1], 2):
|
|
120
|
+
idxs = indices[:, i]
|
|
121
|
+
if (idxs == -1).all():
|
|
122
|
+
continue
|
|
123
|
+
useful_idxs = idxs != -1
|
|
124
|
+
# Retrieves the differentiable values from the distance_matrix
|
|
125
|
+
if useful_idxs.size > 0:
|
|
126
|
+
edges_filtrations = torch.cat(
|
|
127
|
+
[
|
|
128
|
+
distance_matrix[*simplices_list[idx], None]
|
|
129
|
+
for idx in idxs[useful_idxs]
|
|
130
|
+
]
|
|
131
|
+
)
|
|
132
|
+
# fills theses values into the signed measure
|
|
133
|
+
sm_diff[:, i][useful_idxs] = edges_filtrations
|
|
134
|
+
# same for the other axis
|
|
135
|
+
for i in range(1, sm_diff.shape[1], 2):
|
|
136
|
+
idxs = indices[:, i]
|
|
137
|
+
if (idxs == -1).all():
|
|
138
|
+
continue
|
|
139
|
+
useful_idxs = idxs != -1
|
|
140
|
+
if useful_idxs.size > 0:
|
|
141
|
+
nodes_filtrations = torch.cat(
|
|
142
|
+
[codensity[simplices_list[idx]] for idx in idxs[useful_idxs]]
|
|
143
|
+
)
|
|
144
|
+
sm_diff[:, i][useful_idxs] = nodes_filtrations
|
|
145
|
+
|
|
146
|
+
# fills not-found values as constants
|
|
147
|
+
if len(not_found_indices) > 0:
|
|
148
|
+
not_found_indices = indices == -1
|
|
149
|
+
sm_diff[indices == -1] = torch.from_numpy(sm[indices == -1]).type(dtype)
|
|
150
|
+
|
|
151
|
+
sms_diff.append((sm_diff, torch.from_numpy(weights)))
|
|
152
|
+
flags = [True, return_original, return_st]
|
|
153
|
+
if np.sum(flags) == 1:
|
|
154
|
+
return sms_diff
|
|
155
|
+
return tuple(stuff for stuff, flag in zip([sms_diff, sms, st_copy], flags) if flag)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def function_rips_signed_measure(
    x,
    theta: Optional[float] = None,
    function: Literal["dtm", "gaussian", "exponential"] | Callable = "gaussian",
    threshold: Optional[float] = None,
    grid_strategy: Literal[
        "regular_closest", "exact", "quantile", "regular_left"
    ] = "exact",
    complex: Literal["rips", "delaunay", "weak_delaunay"] = "rips",
    resolution: int = 100,
    safe_conversion: bool = False,
    num_collapses: Optional[int] = None,
    expand_collapse: bool = False,
    dtype=torch.float32,
    plot=False,
    # return_st: bool = False,
    *,
    log_density: bool = True,
    vineyard: bool = False,
    pers_backend=None,
    **sm_kwargs,
):
    """
    Computes a torch-differentiable function-rips signed measure.

    Input
    -----
     - x (num_pts, dim) : The point cloud, as a ``torch.Tensor``.
     - theta: For density-like functions : the bandwidth (KDE) or the mass (DTM).
     - function : Either "dtm", "gaussian", "exponential", a Callable, or a
        1D tensor of length num_pts. Function to compute the second parameter.
        A callable is called as ``function(x)``, or ``function(x, theta=theta)``
        when theta is given; a tensor is used directly (cast to ``dtype``).
     - threshold : rips threshold (``complex="rips"`` only). Defaults to the
        smallest, over all points, of the largest distance from that point.
     - grid_strategy: grid coarsening strategy.
     - complex : "rips", or "delaunay" / "weak_delaunay" (both go through the
        same delaunay code path).
     - resolution : when coarsening, the target resolution.
     - safe_conversion : Activate this if you encounter crashes.
     - num_collapses : forwarded to ``collapse_edges`` (rips only); defaults to -1.
     - expand_collapse : if True, skip the explicit expansion after collapsing.
     - dtype : torch dtype of the filtration values.
     - plot : if True, plots the resulting signed measures.
     - log_density : for "gaussian"/"exponential", passed to KDE as ``return_log``.
     - vineyard, pers_backend : forwarded to ``mp.Slicer``.
     - **kwargs : for the signed measure computation. ``degree`` and/or
        ``degrees`` select the degrees; when neither is given, ``[None]`` is used.

    Output
    ------
    A tuple of ``(points, weights)`` signed measures, concatenated over the
    requested degrees.
    """
    if num_collapses is None:
        num_collapses = -1 if complex == "rips" else None
    assert isinstance(x, torch.Tensor)

    # Tensor check first, so a tensor is never compared against strings.
    if isinstance(function, torch.Tensor):
        assert (
            function.ndim == 1 and function.shape[0] == x.shape[0]
        ), """
        When function is a tensor, it is interpreted as the value of some function over x.
        """
        codensity = function.type(dtype)
    elif function == "dtm":
        assert theta is not None, "Provide a theta to compute DTM"
        codensity = DTM(masses=[theta]).fit(x).score_samples_diff(x)[0].type(dtype)
    elif function in ["gaussian", "exponential"]:
        assert theta is not None, "Provide a theta to compute density estimation"
        codensity = (
            -KDE(
                bandwidth=theta,
                kernel=function,
                return_log=log_density,
            )
            .fit(x)
            .score_samples(x)
            .type(dtype)
        )
    else:
        assert callable(function), "Function has to be callable"
        if theta is None:
            codensity = function(x).type(dtype)
        else:
            codensity = function(x, theta=theta).type(dtype)

    distance_matrix = torch.cdist(x, x).type(dtype)
    distances = distance_matrix.ravel()
    if complex == "rips":
        threshold = (
            distance_matrix.max(dim=1).values.min() if threshold is None else threshold
        )
        distances = distances[distances <= threshold]
    elif complex in ["delaunay", "weak_delaunay"]:
        complex = "delaunay"
        # Out-of-place: `ravel` may return a view, and an in-place division
        # would silently rewrite `distance_matrix` too. The halving presumably
        # matches the delaunay filtration's radius convention — TODO confirm.
        distances = distances / 2
    else:
        raise ValueError(
            f"Unimplemented with complex {complex}. You can use rips or delaunay ftm."
        )

    # coarsening grid, shared by the complex and the signed measure computation
    reduced_grid = get_grid(strategy=grid_strategy)((distances, codensity), resolution)

    # Always remove `degree` from the kwargs: it is passed explicitly to
    # `mp.signed_measure` below, and leaving it in would duplicate the keyword.
    degrees = list(sm_kwargs.pop("degrees", []))
    degree = sm_kwargs.pop("degree", None)
    if degree is not None or not degrees:
        degrees = [degree] + degrees

    if complex == "rips":
        # detach makes a new (reference) tensor, without tracking the gradient
        st = gd.SimplexTree.create_from_array(
            distance_matrix.detach().cpu(), max_filtration=threshold
        )
        st = mp.SimplexTreeMulti(st, num_parameters=2, safe_conversion=safe_conversion)
        # fills the codensity in the second parameter of the simplextree
        st.fill_lowerstar(codensity.detach().cpu(), parameter=1)
        st = st.grid_squeeze(reduced_grid)
        st.filtration_grid = []
        expansion_degree = st.num_vertices if None in degrees else max(degrees) + 1
        st.collapse_edges(num=num_collapses)
        if not expand_collapse:
            st.expansion(expansion_degree)  # re-expands the collapsed flag complex

        s = mp.Slicer(st, vineyard=vineyard, backend=pers_backend)
    elif complex == "delaunay":
        s = mp.slicer.from_function_delaunay(
            x.detach().cpu().numpy(), codensity.detach().cpu().numpy()
        )
        st = mp.slicer.to_simplextree(s)
        st.flagify(2)
        s = mp.Slicer(st, vineyard=vineyard, backend=pers_backend).grid_squeeze(
            reduced_grid
        )

    s.filtration_grid = []  ## To enforce minpres to be reasonable
    if None not in degrees:
        s = s.minpres(degrees=degrees)
    else:
        from joblib import Parallel, delayed

        base = s  # bind explicitly; `s` is rebound to the tuple below
        s = tuple(
            Parallel(n_jobs=-1, backend="threading")(
                delayed(lambda d: base if d is None else base.minpres(degree=d))(d)
                for d in degrees
            )
        )
    ## fix previous hack
    for stuff in s:
        # stuff.filtration_grid = reduced_grid ## not necessary
        stuff.filtration_grid = [[1]] * stuff.num_parameters

    # computes the signed measure, one group per requested degree
    sms = tuple(
        sm
        for slicer_of_degree, degree in zip(s, degrees)
        for sm in mp.signed_measure(
            slicer_of_degree, grid=reduced_grid, degree=degree, **sm_kwargs
        )
    )
    if plot:
        mp.plots.plot_signed_measures(
            tuple((sm.detach().numpy(), w.detach().numpy()) for sm, w in sms)
        )
    return sms
|