multipers 2.3.3b6__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of multipers might be problematic. Click here for more details.
- multipers/.dylibs/libc++.1.0.dylib +0 -0
- multipers/.dylibs/libtbb.12.16.dylib +0 -0
- multipers/__init__.py +33 -0
- multipers/_signed_measure_meta.py +453 -0
- multipers/_slicer_meta.py +211 -0
- multipers/array_api/__init__.py +45 -0
- multipers/array_api/numpy.py +41 -0
- multipers/array_api/torch.py +58 -0
- multipers/data/MOL2.py +458 -0
- multipers/data/UCR.py +18 -0
- multipers/data/__init__.py +1 -0
- multipers/data/graphs.py +466 -0
- multipers/data/immuno_regions.py +27 -0
- multipers/data/minimal_presentation_to_st_bf.py +0 -0
- multipers/data/pytorch2simplextree.py +91 -0
- multipers/data/shape3d.py +101 -0
- multipers/data/synthetic.py +113 -0
- multipers/distances.py +202 -0
- multipers/filtration_conversions.pxd +229 -0
- multipers/filtration_conversions.pxd.tp +84 -0
- multipers/filtrations/__init__.py +18 -0
- multipers/filtrations/density.py +574 -0
- multipers/filtrations/filtrations.py +361 -0
- multipers/filtrations.pxd +224 -0
- multipers/function_rips.cpython-313-darwin.so +0 -0
- multipers/function_rips.pyx +105 -0
- multipers/grids.cpython-313-darwin.so +0 -0
- multipers/grids.pyx +433 -0
- multipers/gudhi/Persistence_slices_interface.h +132 -0
- multipers/gudhi/Simplex_tree_interface.h +239 -0
- multipers/gudhi/Simplex_tree_multi_interface.h +551 -0
- multipers/gudhi/cubical_to_boundary.h +59 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -0
- multipers/gudhi/gudhi/Debug_utils.h +45 -0
- multipers/gudhi/gudhi/Fields/Multi_field.h +484 -0
- multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -0
- multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -0
- multipers/gudhi/gudhi/Fields/Z2_field.h +355 -0
- multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -0
- multipers/gudhi/gudhi/Fields/Zp_field.h +420 -0
- multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -0
- multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -0
- multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -0
- multipers/gudhi/gudhi/Matrix.h +2107 -0
- multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -0
- multipers/gudhi/gudhi/Multi_persistence/Box.h +174 -0
- multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -0
- multipers/gudhi/gudhi/Off_reader.h +173 -0
- multipers/gudhi/gudhi/One_critical_filtration.h +1441 -0
- multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -0
- multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -0
- multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -0
- multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -0
- multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -0
- multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -0
- multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -0
- multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -0
- multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -0
- multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -0
- multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -0
- multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -0
- multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -0
- multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -0
- multipers/gudhi/gudhi/Persistent_cohomology.h +745 -0
- multipers/gudhi/gudhi/Points_off_io.h +171 -0
- multipers/gudhi/gudhi/Simple_object_pool.h +69 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -0
- multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -0
- multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -0
- multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -0
- multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -0
- multipers/gudhi/gudhi/Simplex_tree.h +2794 -0
- multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -0
- multipers/gudhi/gudhi/distance_functions.h +62 -0
- multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -0
- multipers/gudhi/gudhi/persistence_interval.h +253 -0
- multipers/gudhi/gudhi/persistence_matrix_options.h +170 -0
- multipers/gudhi/gudhi/reader_utils.h +367 -0
- multipers/gudhi/mma_interface_coh.h +256 -0
- multipers/gudhi/mma_interface_h0.h +223 -0
- multipers/gudhi/mma_interface_matrix.h +293 -0
- multipers/gudhi/naive_merge_tree.h +536 -0
- multipers/gudhi/scc_io.h +310 -0
- multipers/gudhi/truc.h +1403 -0
- multipers/io.cpython-313-darwin.so +0 -0
- multipers/io.pyx +644 -0
- multipers/ml/__init__.py +0 -0
- multipers/ml/accuracies.py +90 -0
- multipers/ml/invariants_with_persistable.py +79 -0
- multipers/ml/kernels.py +176 -0
- multipers/ml/mma.py +713 -0
- multipers/ml/one.py +472 -0
- multipers/ml/point_clouds.py +352 -0
- multipers/ml/signed_measures.py +1589 -0
- multipers/ml/sliced_wasserstein.py +461 -0
- multipers/ml/tools.py +113 -0
- multipers/mma_structures.cpython-313-darwin.so +0 -0
- multipers/mma_structures.pxd +128 -0
- multipers/mma_structures.pyx +2786 -0
- multipers/mma_structures.pyx.tp +1094 -0
- multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -0
- multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -0
- multipers/multi_parameter_rank_invariant/function_rips.h +322 -0
- multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -0
- multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -0
- multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -0
- multipers/multiparameter_edge_collapse.py +41 -0
- multipers/multiparameter_module_approximation/approximation.h +2330 -0
- multipers/multiparameter_module_approximation/combinatory.h +129 -0
- multipers/multiparameter_module_approximation/debug.h +107 -0
- multipers/multiparameter_module_approximation/euler_curves.h +0 -0
- multipers/multiparameter_module_approximation/format_python-cpp.h +286 -0
- multipers/multiparameter_module_approximation/heap_column.h +238 -0
- multipers/multiparameter_module_approximation/images.h +79 -0
- multipers/multiparameter_module_approximation/list_column.h +174 -0
- multipers/multiparameter_module_approximation/list_column_2.h +232 -0
- multipers/multiparameter_module_approximation/ru_matrix.h +347 -0
- multipers/multiparameter_module_approximation/set_column.h +135 -0
- multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -0
- multipers/multiparameter_module_approximation/unordered_set_column.h +166 -0
- multipers/multiparameter_module_approximation/utilities.h +403 -0
- multipers/multiparameter_module_approximation/vector_column.h +223 -0
- multipers/multiparameter_module_approximation/vector_matrix.h +331 -0
- multipers/multiparameter_module_approximation/vineyards.h +464 -0
- multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -0
- multipers/multiparameter_module_approximation.cpython-313-darwin.so +0 -0
- multipers/multiparameter_module_approximation.pyx +235 -0
- multipers/pickle.py +90 -0
- multipers/plots.py +456 -0
- multipers/point_measure.cpython-313-darwin.so +0 -0
- multipers/point_measure.pyx +395 -0
- multipers/simplex_tree_multi.cpython-313-darwin.so +0 -0
- multipers/simplex_tree_multi.pxd +134 -0
- multipers/simplex_tree_multi.pyx +10840 -0
- multipers/simplex_tree_multi.pyx.tp +2009 -0
- multipers/slicer.cpython-313-darwin.so +0 -0
- multipers/slicer.pxd +3034 -0
- multipers/slicer.pxd.tp +234 -0
- multipers/slicer.pyx +20481 -0
- multipers/slicer.pyx.tp +1088 -0
- multipers/tensor/tensor.h +672 -0
- multipers/tensor.pxd +13 -0
- multipers/test.pyx +44 -0
- multipers/tests/__init__.py +62 -0
- multipers/torch/__init__.py +1 -0
- multipers/torch/diff_grids.py +240 -0
- multipers/torch/rips_density.py +310 -0
- multipers-2.3.3b6.dist-info/METADATA +128 -0
- multipers-2.3.3b6.dist-info/RECORD +183 -0
- multipers-2.3.3b6.dist-info/WHEEL +6 -0
- multipers-2.3.3b6.dist-info/licenses/LICENSE +21 -0
- multipers-2.3.3b6.dist-info/top_level.txt +1 -0
multipers/test.pyx
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# cimport multipers.tensor as mt
|
|
2
|
+
from libc.stdint cimport intptr_t, uint16_t
|
|
3
|
+
from libcpp.vector cimport vector
|
|
4
|
+
from libcpp cimport bool, int, float
|
|
5
|
+
from libcpp.utility cimport pair
|
|
6
|
+
from typing import Optional,Iterable,Callable
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
ctypedef float value_type
|
|
10
|
+
# ctypedef uint16_t index_type
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
# cimport numpy as cnp
|
|
14
|
+
# cnp.import_array()
|
|
15
|
+
|
|
16
|
+
# cdef extern from "multi_parameter_rank_invariant/rank_invariant.h" namespace "Gudhi::rank_invariant":
|
|
17
|
+
# void get_hilbert_surface(const intptr_t, mt.static_tensor_view, const vector[index_type], const vector[index_type], index_type, index_type, const vector[index_type], bool, bool) except + nogil
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
from multipers.simplex_tree_multi import SimplexTreeMulti
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def numpy_to_tensor(array:np.ndarray):
    # Wraps `array` in an `mt.static_tensor_view` and returns the value of
    # that view's `get_resolution()`.
    #
    # NOTE(review): `mt`, `index_type` and `dtype` are not declared by any
    # active line visible in this file (`cimport multipers.tensor as mt` and
    # `ctypedef uint16_t index_type` are commented out) -- confirm that this
    # function still compiles.
    cdef vector[index_type] shape = array.shape
    # The typed memoryview requires contiguous data, hence ascontiguousarray.
    cdef dtype[::1] contigus_array_view = np.ascontiguousarray(array)
    # Address of the first element; indexing [0] presumably fails on an
    # empty array -- TODO confirm.
    cdef dtype* dtype_ptr = &contigus_array_view[0]
    cdef mt.static_tensor_view tensor
    with nogil:
        tensor = mt.static_tensor_view(dtype_ptr, shape)
    return tensor.get_resolution()
|
|
31
|
+
|
|
32
|
+
# def hilbert2d(simplextree:SimplexTreeMulti, grid_shape:np.ndarray|list, vector[index_type] degrees, bool mobius_inversion):
|
|
33
|
+
# # assert simplextree.num_parameters == 2
|
|
34
|
+
# cdef intptr_t ptr = simplextree.thisptr
|
|
35
|
+
# cdef vector[index_type] c_grid_shape = grid_shape
|
|
36
|
+
# cdef dtype[::1] container = np.zeros(grid_shape, dtype=np.float32).flatten()
|
|
37
|
+
# cdef dtype* container_ptr = &container[0]
|
|
38
|
+
# cdef mt.static_tensor_view c_container = mt.static_tensor_view(container_ptr, c_grid_shape)
|
|
39
|
+
# cdef index_type i = 0
|
|
40
|
+
# cdef index_type j = 1
|
|
41
|
+
# cdef vector[index_type] fixed_values = [[],[]]
|
|
42
|
+
# # get_hilbert_surface(ptr, c_container, c_grid_shape, degrees,i,j,fixed_values, False, False)
|
|
43
|
+
# return container.reshape(grid_shape)
|
|
44
|
+
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def assert_st_simplices(st, dump):
|
|
5
|
+
"""
|
|
6
|
+
Checks that the simplextree has the same
|
|
7
|
+
filtration as the dump.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
assert np.all(
|
|
11
|
+
[
|
|
12
|
+
np.isclose(a, b).all()
|
|
13
|
+
for x, y in zip(st.get_simplices(), dump, strict=True)
|
|
14
|
+
for a, b in zip(x, y, strict=True)
|
|
15
|
+
]
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def sort_sm(sms):
    """
    Sort every signed measure by the first coordinate of its points.

    Parameters
    ----------
    sms : iterable of ``(points, weights)`` pairs
        ``points`` is indexed as ``points[:, 0]`` (so it must be 2-d) and
        ``weights`` is indexed with the same permutation.

    Returns
    -------
    tuple of ``(points, weights)`` pairs, each reordered by its own
    permutation; the shapes of ``points`` and ``weights`` are unchanged.
    """
    sorted_sms = []
    for sm in sms:
        points, weights = sm[0], sm[1]
        # The permutation has to be computed per measure: a single argsort
        # over the stacked first coordinates yields a 2-d index (and fails
        # outright when the measures have different numbers of points).
        order = np.argsort(points[:, 0])
        sorted_sms.append((points[order], weights[order]))
    return tuple(sorted_sms)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def assert_sm_pair(sm1, sm2, exact=True, max_error=1e-3, reg=0.1, threshold=None):
|
|
25
|
+
if not exact:
|
|
26
|
+
from multipers.distances import sm_distance
|
|
27
|
+
if threshold is not None:
|
|
28
|
+
_inf_value_fix = threshold
|
|
29
|
+
sm1[0][sm1[0] >threshold] = _inf_value_fix
|
|
30
|
+
sm2[0][sm2[0] >threshold] = _inf_value_fix
|
|
31
|
+
|
|
32
|
+
d = sm_distance(sm1, sm2, reg=reg)
|
|
33
|
+
assert d < max_error, f"Failed comparison:\n{sm1}\n{sm2},\n with distance {d}."
|
|
34
|
+
return
|
|
35
|
+
assert np.all(
|
|
36
|
+
[
|
|
37
|
+
np.isclose(a, b).all()
|
|
38
|
+
for x, y in zip(sm1, sm2, strict=True)
|
|
39
|
+
for a, b in zip(x, y, strict=True)
|
|
40
|
+
]
|
|
41
|
+
), f"Failed comparison:\n-----------------\n{sm1}\n-----------------\n{sm2}"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def assert_sm(*args, exact=True, max_error=1e-5, reg=0.1, threshold=None):
    """
    Assert that every pair of consecutive signed measures in ``args`` agrees.

    Each neighbouring pair ``(args[i], args[i + 1])`` is checked with
    ``assert_sm_pair``; all keyword arguments are forwarded unchanged.
    With fewer than two measures nothing is checked.
    """
    for left, right in zip(args, args[1:]):
        assert_sm_pair(
            left,
            right,
            exact=exact,
            max_error=max_error,
            reg=reg,
            threshold=threshold,
        )
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def random_st(npts=100, num_parameters=2, max_dim=2):
    """
    Build a ``SimplexTreeMulti`` with random lower-star values.

    Points come from ``multipers.data.noisy_annulus`` (called with
    ``npts // 2``, ``npts - npts // 2`` and ``dim=max_dim``); a gudhi alpha
    complex on them is converted to a ``SimplexTreeMulti`` with
    ``num_parameters`` parameters, and every parameter is filled through
    ``fill_lowerstar`` with ``npts`` values drawn by ``np.random.uniform``.
    """
    import gudhi as gd

    import multipers as mp
    from multipers.data import noisy_annulus

    first_count = npts // 2
    points = noisy_annulus(first_count, npts - first_count, dim=max_dim)
    alpha_tree = gd.AlphaComplex(points=points).create_simplex_tree()
    multi_tree = mp.SimplexTreeMulti(alpha_tree, num_parameters=num_parameters)
    for parameter in range(num_parameters):
        vertex_values = np.random.uniform(size=npts)
        multi_tree.fill_lowerstar(vertex_values, parameter)
    return multi_tree
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .rips_density import *
|
|
@@ -0,0 +1,240 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
from pykeops.torch import LazyTensor
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def get_grid(strategy: Literal["exact", "regular_closest", "regular_left", "quantile"]):
|
|
9
|
+
"""
|
|
10
|
+
Given a strategy, returns a function of signature
|
|
11
|
+
`(num_pts, num_parameter), int --> Iterable[1d array]`
|
|
12
|
+
that generates a torch-differentiable grid from a set of points,
|
|
13
|
+
and a resolution.
|
|
14
|
+
"""
|
|
15
|
+
match strategy:
|
|
16
|
+
case "exact":
|
|
17
|
+
return _exact_grid
|
|
18
|
+
case "regular":
|
|
19
|
+
return _regular_grid
|
|
20
|
+
case "regular_closest":
|
|
21
|
+
return _regular_closest_grid
|
|
22
|
+
case "regular_left":
|
|
23
|
+
return _regular_left_grid
|
|
24
|
+
case "quantile":
|
|
25
|
+
return _quantile_grid
|
|
26
|
+
case _:
|
|
27
|
+
raise ValueError(
|
|
28
|
+
f"""
|
|
29
|
+
Unimplemented strategy {strategy}.
|
|
30
|
+
Available ones : exact, regular_closest, regular_left, quantile.
|
|
31
|
+
"""
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def todense(grid: list[torch.Tensor]):
    """Return the cartesian product of the grid axes (``torch.cartesian_prod``)."""
    axes = tuple(grid)
    return torch.cartesian_prod(*axes)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _exact_grid(filtration_values, r=None):
    """
    Build the exact grid: one ``_unique_any`` result per parameter.

    ``r`` exists only for signature compatibility with the other grid
    builders and must be ``None``.
    """
    assert r is None
    axes = []
    for values in filtration_values:
        axes.append(_unique_any(values))
    return tuple(axes)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _regular_closest_grid(filtration_values, res):
    """Apply ``_regular_closest`` to each (filtration values, resolution) pair."""
    axes = []
    for values, resolution in zip(filtration_values, res):
        axes.append(_regular_closest(values, resolution))
    return tuple(axes)
|
|
48
|
+
|
|
49
|
+
def _regular_grid(filtration_values, res):
    """Apply ``_regular`` to each (filtration values, resolution) pair."""
    axes = []
    for values, resolution in zip(filtration_values, res):
        axes.append(_regular(values, resolution))
    return tuple(axes)
|
|
52
|
+
|
|
53
|
+
def _regular(x, r: int):
    """
    Return ``r`` evenly spaced values from ``min(x)`` to ``max(x)``.

    Parameters
    ----------
    x : 1-d tensor of filtration values.
    r : number of grid points (``steps`` of ``torch.linspace``).

    Returns
    -------
    1-d tensor with the dtype and device of ``x``.

    Raises
    ------
    ValueError
        If ``x`` is not 1-dimensional.
    """
    if x.ndim != 1:
        raise ValueError(f"Got ndim!=1. {x=}")
    # device is passed explicitly so the grid lives next to its input, as in
    # the other regular-grid helpers.
    return torch.linspace(
        start=torch.min(x),
        end=torch.max(x),
        steps=r,
        dtype=x.dtype,
        device=x.device,
    )
|
|
57
|
+
|
|
58
|
+
def _regular_left_grid(filtration_values, res):
    """Apply ``_regular_left`` to each (filtration values, resolution) pair."""
    axes = []
    for values, resolution in zip(filtration_values, res):
        axes.append(_regular_left(values, resolution))
    return tuple(axes)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _quantile_grid(filtration_values, res):
    """Apply ``_quantile`` to each (filtration values, resolution) pair."""
    axes = []
    for values, resolution in zip(filtration_values, res):
        axes.append(_quantile(values, resolution))
    return tuple(axes)
|
|
66
|
+
def _quantile(x, r):
    """
    Return the deduplicated quantiles of ``x`` at ``r`` evenly spaced levels.

    Parameters
    ----------
    x : 1-d tensor of filtration values.
    r : number of quantile levels, evenly spaced in ``[0, 1]``.

    Returns
    -------
    The result of ``_unique_any`` applied to ``torch.quantile(x, q=levels)``.

    Raises
    ------
    ValueError
        If ``x`` is not 1-dimensional.
    """
    if x.ndim != 1:
        raise ValueError(f"Got ndim!=1. {x=}")
    # The levels are created on x's device so torch.quantile does not mix
    # devices when x is not on the default one.
    qs = torch.linspace(0, 1, r, dtype=x.dtype, device=x.device)
    return _unique_any(torch.quantile(x, q=qs))
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _unique_any(x, assume_sorted=False, remove_inf: bool = True):
    """
    Return the sorted distinct values of a 1-d tensor, gathered from ``x``.

    The distinct values and their positions are computed under
    ``torch.no_grad()``; the result is then obtained by indexing ``x``
    itself (presumably so that gradients can flow back to ``x`` -- TODO
    confirm).

    Parameters
    ----------
    x : 1-d tensor.
    assume_sorted : skip the sort when ``x`` is already sorted.
    remove_inf : drop the last element of the sorted tensor when it equals
        ``+inf`` (only that single trailing element is removed).

    Returns
    -------
    1-d tensor of distinct values in increasing order; an empty input is
    returned unchanged.

    Raises
    ------
    ValueError
        If ``x`` is not 1-dimensional.
    """
    if x.ndim != 1:
        raise ValueError(f"Got ndim!=1. {x=}")
    if x.numel() == 0:
        # Nothing to deduplicate; x[-1] below would raise on an empty tensor.
        return x
    if not assume_sorted:
        x, _ = x.sort()
    if remove_inf and x[-1] == torch.inf:
        x = x[:-1]
    with torch.no_grad():
        y = x.unique()
        idx = torch.searchsorted(x, y)
    # Sentinel so that an index equal to len(x) stays in bounds; it is
    # created on x's device so the concatenation works off the default device.
    x = torch.cat([x, torch.tensor([torch.inf], device=x.device)])
    return x[idx]
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _regular_left(f, r: int, unique: bool = True):
    """
    Coarsen the values of ``f`` onto a regular grid of ``r`` points.

    The distinct values of ``f`` are computed with ``_unique_any``; a regular
    grid spanning their range is then located among them with
    ``torch.searchsorted`` and the values found at those positions are
    returned, so the output is made of entries of ``f`` (plus an ``inf``
    sentinel for out-of-range positions).

    Parameters
    ----------
    f : 1-d tensor of filtration values.
    r : number of points of the regular grid.
    unique : when True, deduplicate the selected values with ``_unique_any``.

    Raises
    ------
    ValueError
        If ``f`` is not 1-dimensional.
    """
    if f.ndim != 1:
        raise ValueError(f"Got ndim!=1. {f=}")
    f = _unique_any(f)
    with torch.no_grad():
        # dtype must match f: torch.searchsorted rejects boundaries and
        # values of different dtypes (e.g. a float64 f with a float32 grid).
        f_regular = torch.linspace(
            f[0].item(), f[-1].item(), r, dtype=f.dtype, device=f.device
        )
        idx = torch.searchsorted(f, f_regular)
    # Sentinel for indices equal to len(f), created on f's device.
    f = torch.cat([f, torch.tensor([torch.inf], device=f.device)])
    if unique:
        return _unique_any(f[idx])
    return f[idx]
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _regular_closest(f, r: int, unique: bool = True):
    """
    Coarsen the values of ``f`` to those closest to a regular grid.

    The distinct values of ``f`` are computed with ``_unique_any``; for a
    regular grid of ``r`` points spanning their range, the index minimising
    ``|f - grid_point|`` is computed with a pykeops ``LazyTensor`` reduction
    and the corresponding entries of ``f`` are returned.

    Parameters
    ----------
    f : 1-d tensor of filtration values.
    r : number of points of the regular grid.
    unique : when True, deduplicate the selected values with ``_unique_any``.

    Raises
    ------
    ValueError
        If ``f`` is not 1-dimensional.
    """
    if f.ndim != 1:
        raise ValueError(f"Got ndim!=1. {f=}")
    f = _unique_any(f)
    with torch.no_grad():
        f_reg = torch.linspace(
            f[0].item(), f[-1].item(), steps=r, dtype=f.dtype, device=f.device
        )
        _f = LazyTensor(f[:, None, None])
        _f_reg = LazyTensor(f_reg[None, :, None])
        # argmin over axis 0, i.e. over the entries of f, for each grid point.
        indices = (_f - _f_reg).abs().argmin(0).ravel()
    # Sentinel appended on f's device so the concatenation does not mix devices.
    f = torch.cat([f, torch.tensor([torch.inf], device=f.device)])
    f_regular_closest = f[indices]
    if unique:
        f_regular_closest = _unique_any(f_regular_closest)
    return f_regular_closest
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def evaluate_in_grid(pts, grid):
    """Map grid coordinates to the grid values they point at.

    Input
    -----
     - pts: (num_points, num_parameters) array of indices into the grid
     - grid: Iterable of 1-d tensors, one per parameter

    Returns
    -------
     - tensor with one column per parameter; column ``p`` is
       ``grid[p]`` indexed by ``pts[:, p]``.
    """
    columns = []
    for parameter, coordinates in enumerate(pts.T):
        axis_values = grid[parameter][coordinates]
        columns.append(axis_values[:, None])
    return torch.cat(columns, dim=1)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def evaluate_mod_in_grid(mod, grid, box=None):
    """Given an MMA module, pushes it into the specified grid.
    Useful for e.g., make it differentiable.

    Input
    -----
     - mod: PyModule (must provide ``to_flat_idx(grid)``)
     - grid: Iterable of 1d array, for num_parameters
     - box: optional pair of per-parameter lower / upper bounds; when given,
       each grid axis is clamped to the box, deduplicated, and the two
       bounds are added at its ends.

    Output
    ------
    tuple of ``(births, deaths)`` pairs: the birth and death tensors obtained
    from ``mod.to_flat_idx`` are evaluated in the grid and split according to
    the sizes that ``to_flat_idx`` returns.
    """
    if box is not None:
        clipped_axes = []
        for i in range(len(grid)):
            clipped_axes.append(
                torch.cat(
                    [
                        box[0][[i]],
                        _unique_any(
                            grid[i].clamp(min=box[0][i], max=box[1][i]),
                            assume_sorted=True,
                        ),
                        box[1][[i]],
                    ]
                )
            )
        grid = tuple(clipped_axes)

    (birth_sizes, death_sizes), births, deaths = mod.to_flat_idx(grid)
    births = evaluate_in_grid(births, grid)
    deaths = evaluate_in_grid(deaths, grid)

    birth_chunks = births.split_with_sizes(birth_sizes.tolist())
    death_chunks = deaths.split_with_sizes(death_sizes.tolist())
    return tuple(zip(birth_chunks, death_chunks))
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def evaluate_mod_in_grid__old(mod, grid, box=None):
    """Given an MMA module, pushes it into the specified grid.
    Useful for e.g., make it differentiable.

    Input
    -----
     - mod: PyModule
     - grid: Iterable of 1d array, for num_parameters
     - box: optional bounds used to clip births and deaths; defaults to the
       first and last value of every grid axis.
    Output
    ------
    torch-compatible module in the format:
    (num_degrees) x (num_interval of degree) x ((num_birth, num_parameter), (num_death, num_parameters))

    """
    # The numpy flavour of LazyTensor is used here, shadowing any module-level
    # LazyTensor inside this function.
    from pykeops.numpy import LazyTensor

    with torch.no_grad():
        if box is None:
            # box = mod.get_box()
            box = np.asarray([[g[0] for g in grid], [g[-1] for g in grid]])
        S = mod.dump()[1]

        def get_idx_parameter(A, G, p):
            # For coordinate p of every row of A, index of the closest value
            # in grid axis G[p] (argmin of the absolute difference).
            g = G[p].numpy() if isinstance(G[p], torch.Tensor) else np.asarray(G[p])
            la = LazyTensor(np.asarray(A, dtype=g.dtype)[None, :, [p]])
            lg = LazyTensor(g[:, None, None])
            return (la - lg).abs().argmin(0)

        # Births of all summands, clipped to the box, then converted to grid
        # indices one parameter at a time.
        Bdump = np.concatenate([s[0] for s in S], axis=0).clip(box[[0]], box[[1]])
        B = np.concatenate(
            [get_idx_parameter(Bdump, grid, p) for p in range(mod.num_parameters)],
            axis=1,
            dtype=np.int64,
        )
        # Same for the deaths.
        Ddump = np.concatenate([s[1] for s in S], axis=0, dtype=np.float32).clip(
            box[[0]], box[[1]]
        )
        D = np.concatenate(
            [get_idx_parameter(Ddump, grid, p) for p in range(mod.num_parameters)],
            axis=1,
            dtype=np.int64,
        )

    # Index -> grid value lookup.
    BB = evaluate_in_grid(B, grid)
    DD = evaluate_in_grid(D, grid)

    # Number of births / deaths per summand, used to split the flat tensors.
    b_idx = tuple((len(s[0]) for s in S))
    d_idx = tuple((len(s[1]) for s in S))
    BBB = BB.split_with_sizes(b_idx)
    DDD = DD.split_with_sizes(d_idx)

    # Boundaries between homological degrees, with both ends added.
    splits = np.concatenate([[0], mod.degree_splits(), [len(BBB)]])
    splits = torch.from_numpy(splits)
    out = [
        list(zip(BBB[splits[i] : splits[i + 1]], DDD[splits[i] : splits[i + 1]]))
        for i in range(len(splits) - 1)
    ]  ## For some reasons this kills the gradient ???? pytorch bug
    return out
|
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
from typing import Callable, Literal, Optional
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
import gudhi as gd
|
|
6
|
+
|
|
7
|
+
import multipers as mp
|
|
8
|
+
from multipers.filtrations.density import DTM, KDE
|
|
9
|
+
from multipers.simplex_tree_multi import _available_strategies
|
|
10
|
+
from multipers.torch.diff_grids import get_grid
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def function_rips_signed_measure_old(
    x,
    theta: Optional[float] = None,
    function: Literal["dtm", "gaussian", "exponential"] | Callable = "dtm",
    threshold: float = np.inf,
    grid_strategy: _available_strategies = "regular_closest",
    resolution: int = 100,
    return_original: bool = False,
    return_st: bool = False,
    safe_conversion: bool = False,
    num_collapses: int = -1,
    expand_collapse: bool = False,
    dtype=torch.float32,
    **sm_kwargs,
):
    """
    Computes a torch-differentiable function-rips signed measure.

    Input
    -----
     - x (num_pts, dim) : The point cloud (must be a torch.Tensor)
     - theta: For density-like functions : the bandwidth
     - threshold : rips threshold
     - function : Either "dtm", "gaussian", or "exponential" or Callable.
       Function to compute the second parameter.
     - grid_strategy: grid coarsening strategy.
     - resolution : when coarsening, the target resolution,
     - return_original : Also returns the non-differentiable signed measure.
     - return_st : Also returns the grid-squeezed simplex tree.
     - safe_conversion : Activate this if you encounter crashes.
     - num_collapses : forwarded to `collapse_edges(num=...)`.
     - expand_collapse : when False, the tree is expanded after the collapse.
     - dtype : torch dtype of the codensity, distance matrix and output points.
     - **kwargs : for the signed measure computation.

    Output
    ------
    The list of differentiable `(points, weights)` pairs when neither
    `return_original` nor `return_st` is set; otherwise a tuple starting with
    that list, followed by the requested extra outputs.
    """
    assert isinstance(x, torch.Tensor)
    if function == "dtm":
        assert theta is not None, "Provide a theta to compute DTM"
        codensity = DTM(masses=[theta]).fit(x).score_samples_diff(x)[0].type(dtype)
    elif function in ["gaussian", "exponential"]:
        assert theta is not None, "Provide a theta to compute density estimation"
        codensity = (
            -KDE(
                bandwidth=theta,
                kernel=function,
                return_log=True,
            )
            .fit(x)
            .score_samples(x)
            .type(dtype)
        )
    else:
        assert callable(function), "Function has to be callable"
        if theta is None:
            codensity = function(x).type(dtype)
        else:
            codensity = function(x, theta=theta).type(dtype)

    distance_matrix = torch.cdist(x, x).type(dtype)
    if threshold < np.inf:
        distance_matrix[distance_matrix > threshold] = np.inf

    # st = RipsComplex(
    # distance_matrix=distance_matrix.detach(), max_edge_length=threshold
    # ).create_simplex_tree()
    st = gd.SimplexTree.create_from_array(
        distance_matrix.detach(), max_filtration=threshold
    )
    # detach makes a new (reference) tensor, without tracking the gradient
    st = mp.SimplexTreeMulti(st, num_parameters=2, safe_conversion=safe_conversion)
    st.fill_lowerstar(
        codensity.detach(), parameter=1
    )  # fills the codensity in the second parameter of the simplextree

    # simplifies the simplextree for computation, the signed measure will be recovered from the copy afterward
    st_copy = st.grid_squeeze(
        grid_strategy=grid_strategy, resolution=resolution, coordinate_values=True
    )
    # Without any requested degree the expansion goes up to the number of
    # vertices; otherwise up to the largest requested degree + 1.
    if sm_kwargs.get("degree", None) is None and sm_kwargs.get("degrees", [None]) == [
        None
    ]:
        expansion_degree = st.num_vertices
    else:
        expansion_degree = (
            max(np.max(sm_kwargs.get("degrees", 1)), sm_kwargs.get("degree", 1)) + 1
        )
    st.collapse_edges(num=num_collapses)
    if not expand_collapse:
        st.expansion(expansion_degree)  # edge collapse
    sms = mp.signed_measure(st, **sm_kwargs)  # computes the signed measure
    del st

    simplices_list = tuple(
        s for s, _ in st_copy.get_simplices()
    )  # not optimal, we may want to do that in cython to get edges and nodes
    sms_diff = []
    for sm, weights in sms:
        # NOTE(review): -1 in `indices` is treated below as "no simplex
        # found" -- confirm against pts_to_indices.
        indices, not_found_indices = st_copy.pts_to_indices(
            sm, simplices_dimensions=[1, 0]
        )
        if sm_kwargs.get("verbose", False):
            print(
                f"Found {(1-(indices == -1).mean()).round(2)} indices. \
Out : {(indices == -1).sum()}, {len(not_found_indices)}"
            )
        sm_diff = torch.empty(sm.shape).type(dtype)
        # sim_dim = sm_diff.shape[1] // 2

        # fills the Rips-filtrations of the signed measure.
        # the loop is for the rank invariant
        for i in range(0, sm_diff.shape[1], 2):
            idxs = indices[:, i]
            if (idxs == -1).all():
                continue
            useful_idxs = idxs != -1
            # Retrieves the differentiable values from the distance_matrix
            if useful_idxs.size > 0:
                edges_filtrations = torch.cat(
                    [
                        distance_matrix[*simplices_list[idx], None]
                        for idx in idxs[useful_idxs]
                    ]
                )
                # fills these values into the signed measure
                sm_diff[:, i][useful_idxs] = edges_filtrations
        # same for the other axis
        for i in range(1, sm_diff.shape[1], 2):
            idxs = indices[:, i]
            if (idxs == -1).all():
                continue
            useful_idxs = idxs != -1
            if useful_idxs.size > 0:
                nodes_filtrations = torch.cat(
                    [codensity[simplices_list[idx]] for idx in idxs[useful_idxs]]
                )
                sm_diff[:, i][useful_idxs] = nodes_filtrations

        # fills not-found values as constants
        if len(not_found_indices) > 0:
            # NOTE(review): this reassignment is not read again in this loop
            # iteration.
            not_found_indices = indices == -1
            sm_diff[indices == -1] = torch.from_numpy(sm[indices == -1]).type(dtype)

        sms_diff.append((sm_diff, torch.from_numpy(weights)))
    # The differentiable measures are always returned; the other two outputs
    # only when their flag is set.
    flags = [True, return_original, return_st]
    if np.sum(flags) == 1:
        return sms_diff
    return tuple(stuff for stuff, flag in zip([sms_diff, sms, st_copy], flags) if flag)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def function_rips_signed_measure(
|
|
159
|
+
x,
|
|
160
|
+
theta: Optional[float] = None,
|
|
161
|
+
function: Literal["dtm", "gaussian", "exponential"] | Callable = "gaussian",
|
|
162
|
+
threshold: Optional[float] = None,
|
|
163
|
+
grid_strategy: Literal[
|
|
164
|
+
"regular_closest", "exact", "quantile", "regular_left"
|
|
165
|
+
] = "exact",
|
|
166
|
+
complex: Literal["rips", "delaunay", "weak_delaunay"] = "rips",
|
|
167
|
+
resolution: int = 100,
|
|
168
|
+
safe_conversion: bool = False,
|
|
169
|
+
num_collapses: Optional[int] = None,
|
|
170
|
+
expand_collapse: bool = False,
|
|
171
|
+
dtype=torch.float32,
|
|
172
|
+
plot=False,
|
|
173
|
+
# return_st: bool = False,
|
|
174
|
+
*,
|
|
175
|
+
log_density: bool = True,
|
|
176
|
+
vineyard: bool = False,
|
|
177
|
+
pers_backend=None,
|
|
178
|
+
**sm_kwargs,
|
|
179
|
+
):
|
|
180
|
+
"""
|
|
181
|
+
Computes a torch-differentiable function-rips signed measure.
|
|
182
|
+
|
|
183
|
+
Input
|
|
184
|
+
-----
|
|
185
|
+
- x (num_pts, dim) : The point cloud
|
|
186
|
+
- theta: For density-like functions : the bandwidth
|
|
187
|
+
- threshold : rips threshold
|
|
188
|
+
- function : Either "dtm", "gaussian", or "exponenetial" or Callable.
|
|
189
|
+
Function to compute the second parameter.
|
|
190
|
+
- grid_strategy: grid coarsenning strategy.
|
|
191
|
+
- resolution : when coarsenning, the target resolution,
|
|
192
|
+
- return_original : Also returns the non-differentiable signed measure.
|
|
193
|
+
- safe_conversion : Activate this if you encounter crashes.
|
|
194
|
+
- **kwargs : for the signed measure computation.
|
|
195
|
+
"""
|
|
196
|
+
if num_collapses is None:
|
|
197
|
+
num_collapses = -1 if complex == "rips" else None
|
|
198
|
+
assert isinstance(x, torch.Tensor)
|
|
199
|
+
if function == "dtm":
|
|
200
|
+
assert theta is not None, "Provide a theta to compute DTM"
|
|
201
|
+
codensity = DTM(masses=[theta]).fit(x).score_samples_diff(x)[0].type(dtype)
|
|
202
|
+
elif function in ["gaussian", "exponential"]:
|
|
203
|
+
assert theta is not None, "Provide a theta to compute density estimation"
|
|
204
|
+
codensity = (
|
|
205
|
+
-KDE(
|
|
206
|
+
bandwidth=theta,
|
|
207
|
+
kernel=function,
|
|
208
|
+
return_log=log_density,
|
|
209
|
+
)
|
|
210
|
+
.fit(x)
|
|
211
|
+
.score_samples(x)
|
|
212
|
+
.type(dtype)
|
|
213
|
+
)
|
|
214
|
+
elif isinstance(function, torch.Tensor):
|
|
215
|
+
assert (
|
|
216
|
+
function.ndim == 1 and codensity.shape[0] == x.shape[0]
|
|
217
|
+
), """
|
|
218
|
+
When function is a tensor, it is interpreted as the value of some function over x.
|
|
219
|
+
"""
|
|
220
|
+
codensity = function
|
|
221
|
+
else:
|
|
222
|
+
assert callable(function), "Function has to be callable"
|
|
223
|
+
if theta is None:
|
|
224
|
+
codensity = function(x).type(dtype)
|
|
225
|
+
else:
|
|
226
|
+
codensity = function(x, theta=theta).type(dtype)
|
|
227
|
+
|
|
228
|
+
distance_matrix = torch.cdist(x, x).type(dtype)
|
|
229
|
+
distances = distance_matrix.ravel()
|
|
230
|
+
if complex == "rips":
|
|
231
|
+
threshold = (
|
|
232
|
+
distance_matrix.max(axis=1).values.min() if threshold is None else threshold
|
|
233
|
+
)
|
|
234
|
+
distances = distances[distances <= threshold]
|
|
235
|
+
elif complex in ["delaunay", "weak_delaunay"]:
|
|
236
|
+
complex = "delaunay"
|
|
237
|
+
distances /= 2
|
|
238
|
+
else:
|
|
239
|
+
raise ValueError(
|
|
240
|
+
f"Unimplemented with complex {complex}. You can use rips or delaunay ftm."
|
|
241
|
+
)
|
|
242
|
+
|
|
243
|
+
# simplificates the simplextree for computation, the signed measure will be recovered from the copy afterward
|
|
244
|
+
reduced_grid = get_grid(strategy=grid_strategy)((distances, codensity), resolution)
|
|
245
|
+
|
|
246
|
+
degrees = sm_kwargs.pop("degrees", [])
|
|
247
|
+
if sm_kwargs.get("degree", None) is not None:
|
|
248
|
+
degrees = [sm_kwargs.pop("degree", None)] + degrees
|
|
249
|
+
if complex == "rips":
|
|
250
|
+
# st = RipsComplex(
|
|
251
|
+
# distance_matrix=distance_matrix.detach(), max_edge_length=threshold
|
|
252
|
+
# ).create_simplex_tree()
|
|
253
|
+
st = gd.SimplexTree.create_from_array(
|
|
254
|
+
distance_matrix.detach(), max_filtration=threshold
|
|
255
|
+
)
|
|
256
|
+
# detach makes a new (reference) tensor, without tracking the gradient
|
|
257
|
+
st = mp.SimplexTreeMulti(st, num_parameters=2, safe_conversion=safe_conversion)
|
|
258
|
+
st.fill_lowerstar(
|
|
259
|
+
codensity.detach(), parameter=1
|
|
260
|
+
) # fills the codensity in the second parameter of the simplextree
|
|
261
|
+
st = st.grid_squeeze(reduced_grid)
|
|
262
|
+
st.filtration_grid = []
|
|
263
|
+
if None in degrees:
|
|
264
|
+
expansion_degree = st.num_vertices
|
|
265
|
+
else:
|
|
266
|
+
expansion_degree = max(degrees) + 1
|
|
267
|
+
st.collapse_edges(num=num_collapses)
|
|
268
|
+
if not expand_collapse:
|
|
269
|
+
st.expansion(expansion_degree) # edge collapse
|
|
270
|
+
|
|
271
|
+
s = mp.Slicer(st, vineyard=vineyard, backend=pers_backend)
|
|
272
|
+
elif complex == "delaunay":
|
|
273
|
+
s = mp.slicer.from_function_delaunay(
|
|
274
|
+
x.detach().numpy(), codensity.detach().numpy()
|
|
275
|
+
)
|
|
276
|
+
st = mp.slicer.to_simplextree(s)
|
|
277
|
+
st.flagify(2)
|
|
278
|
+
s = mp.Slicer(st, vineyard=vineyard, backend=pers_backend).grid_squeeze(
|
|
279
|
+
reduced_grid
|
|
280
|
+
)
|
|
281
|
+
|
|
282
|
+
s.filtration_grid = [] ## To enforce minpres to be reasonable
|
|
283
|
+
if None not in degrees:
|
|
284
|
+
s = s.minpres(degrees=degrees)
|
|
285
|
+
else:
|
|
286
|
+
from joblib import Parallel, delayed
|
|
287
|
+
|
|
288
|
+
s = tuple(
|
|
289
|
+
Parallel(n_jobs=-1, backend="threading")(
|
|
290
|
+
delayed(lambda d: s if d is None else s.minpres(degree=d))(d)
|
|
291
|
+
for d in degrees
|
|
292
|
+
)
|
|
293
|
+
)
|
|
294
|
+
## fix previous hack
|
|
295
|
+
for stuff in s:
|
|
296
|
+
# stuff.filtration_grid = reduced_grid ## not necessary
|
|
297
|
+
stuff.filtration_grid = [[1]] * stuff.num_parameters
|
|
298
|
+
|
|
299
|
+
sms = tuple(
|
|
300
|
+
sm
|
|
301
|
+
for slicer_of_degree, degree in zip(s, degrees)
|
|
302
|
+
for sm in mp.signed_measure(
|
|
303
|
+
slicer_of_degree, grid=reduced_grid, degree=degree, **sm_kwargs
|
|
304
|
+
)
|
|
305
|
+
) # computes the signed measure
|
|
306
|
+
if plot:
|
|
307
|
+
mp.plots.plot_signed_measures(
|
|
308
|
+
tuple((sm.detach().numpy(), w.detach().numpy()) for sm, w in sms)
|
|
309
|
+
)
|
|
310
|
+
return sms
|