multipers-2.0.0-cp312-cp312-macosx_13_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of multipers might be problematic. Click here for more details.
- multipers/.dylibs/libc++.1.0.dylib +0 -0
- multipers/.dylibs/libtbb.12.12.dylib +0 -0
- multipers/.dylibs/libtbbmalloc.2.12.dylib +0 -0
- multipers/__init__.py +11 -0
- multipers/_signed_measure_meta.py +268 -0
- multipers/_slicer_meta.py +171 -0
- multipers/data/MOL2.py +350 -0
- multipers/data/UCR.py +18 -0
- multipers/data/__init__.py +1 -0
- multipers/data/graphs.py +466 -0
- multipers/data/immuno_regions.py +27 -0
- multipers/data/minimal_presentation_to_st_bf.py +0 -0
- multipers/data/pytorch2simplextree.py +91 -0
- multipers/data/shape3d.py +101 -0
- multipers/data/synthetic.py +68 -0
- multipers/distances.py +198 -0
- multipers/euler_characteristic.pyx +132 -0
- multipers/filtration_conversions.pxd +229 -0
- multipers/filtrations.pxd +225 -0
- multipers/function_rips.cpython-312-darwin.so +0 -0
- multipers/function_rips.pyx +105 -0
- multipers/grids.cpython-312-darwin.so +0 -0
- multipers/grids.pyx +281 -0
- multipers/hilbert_function.pyi +46 -0
- multipers/hilbert_function.pyx +153 -0
- multipers/io.cpython-312-darwin.so +0 -0
- multipers/io.pyx +571 -0
- multipers/ml/__init__.py +0 -0
- multipers/ml/accuracies.py +90 -0
- multipers/ml/convolutions.py +532 -0
- multipers/ml/invariants_with_persistable.py +79 -0
- multipers/ml/kernels.py +176 -0
- multipers/ml/mma.py +659 -0
- multipers/ml/one.py +472 -0
- multipers/ml/point_clouds.py +238 -0
- multipers/ml/signed_betti.py +50 -0
- multipers/ml/signed_measures.py +1542 -0
- multipers/ml/sliced_wasserstein.py +461 -0
- multipers/ml/tools.py +113 -0
- multipers/mma_structures.cpython-312-darwin.so +0 -0
- multipers/mma_structures.pxd +127 -0
- multipers/mma_structures.pyx +2433 -0
- multipers/multiparameter_edge_collapse.py +41 -0
- multipers/multiparameter_module_approximation.cpython-312-darwin.so +0 -0
- multipers/multiparameter_module_approximation.pyx +211 -0
- multipers/pickle.py +53 -0
- multipers/plots.py +326 -0
- multipers/point_measure_integration.cpython-312-darwin.so +0 -0
- multipers/point_measure_integration.pyx +139 -0
- multipers/rank_invariant.cpython-312-darwin.so +0 -0
- multipers/rank_invariant.pyx +229 -0
- multipers/simplex_tree_multi.cpython-312-darwin.so +0 -0
- multipers/simplex_tree_multi.pxd +129 -0
- multipers/simplex_tree_multi.pyi +715 -0
- multipers/simplex_tree_multi.pyx +4655 -0
- multipers/slicer.cpython-312-darwin.so +0 -0
- multipers/slicer.pxd +781 -0
- multipers/slicer.pyx +3393 -0
- multipers/tensor.pxd +13 -0
- multipers/test.pyx +44 -0
- multipers/tests/__init__.py +40 -0
- multipers/tests/old_test_rank_invariant.py +91 -0
- multipers/tests/test_diff_helper.py +74 -0
- multipers/tests/test_hilbert_function.py +82 -0
- multipers/tests/test_mma.py +51 -0
- multipers/tests/test_point_clouds.py +59 -0
- multipers/tests/test_python-cpp_conversion.py +82 -0
- multipers/tests/test_signed_betti.py +181 -0
- multipers/tests/test_simplextreemulti.py +98 -0
- multipers/tests/test_slicer.py +63 -0
- multipers/torch/__init__.py +1 -0
- multipers/torch/diff_grids.py +217 -0
- multipers/torch/rips_density.py +257 -0
- multipers-2.0.0.dist-info/LICENSE +21 -0
- multipers-2.0.0.dist-info/METADATA +29 -0
- multipers-2.0.0.dist-info/RECORD +78 -0
- multipers-2.0.0.dist-info/WHEEL +5 -0
- multipers-2.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from multipers.ml.signed_betti import signed_betti, rank_decomposition_by_rectangles
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
# only tests rank functions with 1 and 2 parameters
def test_rank_decomposition():
    """Compare ``rank_decomposition_by_rectangles`` with hand-computed answers.

    Only entries whose first grid point is coordinate-wise <= the second grid
    point are compared; the remaining entries are not checked.
    """
    # 1 parameter: rank of an interval module on a grid with 2 elements.
    # Each entry is annotated with the pair of grid points it corresponds to.
    rank_1d = np.array(
        [
            [1, 1],  # 0,0 # 0,1
            [0, 1],  # 1,0 # 1,1
        ]
    )
    expected_1d = np.array(
        [
            [0, 1],  # 0,0 # 0,1
            [0, 0],  # 1,0 # 1,1
        ]
    )
    decomposition_1d = rank_decomposition_by_rectangles(rank_1d)
    # np.ndindex yields (lo, hi) with the last index varying fastest,
    # i.e. the same order as two nested range() loops.
    for lo, hi in np.ndindex(2, 2):
        if lo <= hi:
            assert decomposition_1d[lo, hi] == expected_1d[lo, hi]

    # 2 parameters: rank of a sum of two rectangles on a grid of 2 elements.
    # Entry [i, j, i_, j_] is annotated with the pair of points (i,j), (i_,j_).
    rank_2d = np.array(
        [
            [
                [
                    [1, 1],  # (0,0), (0,0) # (0,0), (0,1)
                    [1, 1],  # (0,0), (1,0) # (0,0), (1,1)
                ],
                [
                    [0, 1],  # (0,1), (0,0) # (0,1), (0,1)
                    [0, 1],  # (0,1), (1,0) # (0,1), (1,1)
                ],
            ],
            [
                [
                    [0, 0],  # (1,0), (0,0) # (1,0), (0,1)
                    [2, 2],  # (1,0), (1,0) # (1,0), (1,1)
                ],
                [
                    [0, 0],  # (1,1), (0,0) # (1,1), (0,1)
                    [0, 2],  # (1,1), (1,0) # (1,1), (1,1)
                ],
            ],
        ]
    )
    expected_2d = np.array(
        [
            [
                [
                    [0, 0],  # (0,0), (0,0) # (0,0), (0,1)
                    [0, 1],  # (0,0), (1,0) # (0,0), (1,1)
                ],
                [
                    [0, 0],  # (0,1), (0,0) # (0,1), (0,1)
                    [0, 0],  # (0,1), (1,0) # (0,1), (1,1)
                ],
            ],
            [
                [
                    [0, 0],  # (1,0), (0,0) # (1,0), (0,1)
                    [0, 1],  # (1,0), (1,0) # (1,0), (1,1)
                ],
                [
                    [0, 0],  # (1,1), (0,0) # (1,1), (0,1)
                    [0, 0],  # (1,1), (1,0) # (1,1), (1,1)
                ],
            ],
        ]
    )

    decomposition_2d = rank_decomposition_by_rectangles(rank_2d)
    # walks (i, i_, j, j_) like four nested loops, last index fastest
    for i, i_, j, j_ in np.ndindex(2, 2, 2, 2):
        if i <= i_ and j <= j_:
            assert decomposition_2d[i, j, i_, j_] == expected_2d[i, j, i_, j_]
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# Tests Hilbert functions with 1, 2, 3 and 4 parameters (Moebius-inversion
# check) and runs a 6-parameter determinism check for both threshold settings.
# NOTE: the leading underscore keeps pytest from collecting this test.
def _test_signed_betti():
    """Check ``signed_betti`` on random integer Hilbert functions.

    For 1 to 4 parameters, summing the signed Betti numbers ``sb`` over every
    grid point coordinate-wise below ``x`` must give back ``f[x]``. That
    lower-set sum is computed with one ``np.cumsum`` per axis, which replaces
    the equivalent nested Python loops. Those loops cost O(size**2) per case,
    about 9**8 iterations in 4D.

    For 6 parameters, only checks that two calls with the same ``threshold``
    agree.
    """
    np.random.seed(0)
    N = 4

    def _lower_set_sums(sb):
        # out[x] = sum of sb[y] over all grid points y <= x (coordinate-wise)
        out = np.asarray(sb)
        for axis in range(out.ndim):
            out = np.cumsum(out, axis=axis)
        return out

    # (number of parameters, (low, high) bounds for the random axis sizes);
    # random draws happen in the same order as the per-dimension loops did.
    cases = (
        (1, (10, 30)),
        (2, (10, 30)),
        (3, (10, 20)),
        (4, (5, 10)),
    )
    for num_parameters, (low, high) in cases:
        for _trial in range(N):
            shape = tuple(
                np.random.randint(low, high) for _axis in range(num_parameters)
            )
            f = np.random.randint(0, 40, size=shape)
            sb = signed_betti(f)

            assert np.allclose(_lower_set_sums(sb), f)

    # 6 parameters: both threshold settings must run and be deterministic.
    for threshold in [True, False]:
        for _trial in range(N):
            shape = tuple(np.random.randint(5, 10) for _axis in range(6))
            hilbert = np.random.randint(0, 40, size=shape)
            sb = signed_betti(hilbert, threshold=threshold)
            sb_ = signed_betti(hilbert, threshold=threshold)

            assert np.allclose(sb, sb_)
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
import gudhi as gd
|
|
2
|
+
import numpy as np
|
|
3
|
+
from numpy import array
|
|
4
|
+
|
|
5
|
+
import multipers as mp
|
|
6
|
+
from multipers.tests import assert_st_simplices
|
|
7
|
+
|
|
8
|
+
# Module-level switch, set once for every test in this file.
# NOTE(review): judging by the name, this enables the checked ("safe")
# conversion path of multipers.simplex_tree_multi — confirm against that module.
mp.simplex_tree_multi.SAFE_CONVERSION = True
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def test_1():
    """Build a two-vertex, one-edge bifiltration by hand and check its simplices."""
    st = mp.SimplexTreeMulti(num_parameters=2)
    for simplex, grade in (([0], [0, 1]), ([1], [1, 0]), ([0, 1], [1, 1])):
        st.insert(simplex, grade)

    expected = [
        ([0, 1], [1.0, 1.0]),
        ([0], [0.0, 1.0]),
        ([1], [1.0, 0.0]),
    ]
    assert_st_simplices(st, expected)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def test_2():
    """Convert a gudhi Rips simplex tree into a 3-parameter SimplexTreeMulti.

    The gudhi filtration lands on axis 0; the expected values show axes 1 and 2
    filled from ``default_values=[1, 2]``.
    """
    from gudhi.rips_complex import RipsComplex

    points = [[0, 1], [1, 0], [0, 0]]
    gudhi_st = RipsComplex(points=points).create_simplex_tree()
    st_multi = mp.SimplexTreeMulti(gudhi_st, num_parameters=3, default_values=[1, 2])

    diagonal = np.sqrt(2)
    expected = (
        ([0, 1], [diagonal, 1.0, 2.0]),
        ([0, 2], [1.0, 1.0, 2.0]),
        ([0], [0.0, 1.0, 2.0]),
        ([1, 2], [1.0, 1.0, 2.0]),
        ([1], [0.0, 1.0, 2.0]),
        ([2], [0.0, 1.0, 2.0]),
    )
    assert_st_simplices(st_multi, expected)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def test_3():
    """Convert a one-parameter gudhi SimplexTree to 4 parameters for each dtype.

    The gudhi filtration lands on axis 0; the remaining axes are expected to
    hold the dtype's minimum (``-inf`` for floats, ``iinfo.min`` for ints).
    Dtypes missing from the compiled backend raise ``KeyError`` and are skipped.
    """
    gudhi_st = gd.SimplexTree()  # usual gudhi simplextree
    gudhi_st.insert([0, 1], 1)
    gudhi_st.insert([1], 0)

    for dtype in [np.int32, np.int64, np.float32, np.float64]:
        try:
            st_multi = mp.SimplexTreeMulti(gudhi_st, num_parameters=4, dtype=dtype)
        except KeyError:
            import sys

            print(f"type {dtype} not compiled, skipping.", file=sys.stderr)
            continue  # dtype not compiled

        if np.issubdtype(dtype, np.floating):
            minf = -np.inf
        else:
            minf = np.iinfo(dtype).min

        expected = [
            (array([0, 1]), array([1.0, minf, minf, minf])),
            (array([0]), array([1.0, minf, minf, minf])),
            (array([1]), array([0.0, minf, minf, minf])),
        ]
        assert_st_simplices(st_multi, expected)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def test_4():
    """k-critical simplex tree -> k-critical Slicer -> module approximation."""
    st = mp.SimplexTreeMulti(num_parameters=2, kcritical=True, dtype=np.float64)
    # every face of the triangle gets the grades (0,1) and (1,0); the triangle
    # itself is then removed and re-inserted with five other grades
    st.insert([0, 1, 2], [0, 1])
    st.insert([0, 1, 2], [1, 0])
    st.remove_maximal_simplex([0, 1, 2])
    for grade in ([1, 2], [2, 1], [1.5, 1.5], [2.5, 0.5], [0.5, 2.5]):
        st.insert([0, 1, 2], grade)

    s = mp.Slicer(st, is_kcritical=True)

    # expected layout: 6 faces x the two grades (0,1), (1,0), then the five
    # grades of the re-inserted triangle
    expected_filtrations = np.concatenate(
        [
            np.tile(array([[0.0, 1.0], [1.0, 0.0]]), (6, 1)),
            array([[1.0, 2.0], [2.0, 1.0], [1.5, 1.5], [2.5, 0.5], [0.5, 2.5]]),
        ]
    )
    assert np.array_equal(
        s.get_filtrations_values(), expected_filtrations
    ), "Invalid conversion from kcritical st to kcritical slicer."

    module = mp.module_approximation(s, box=[[0, 0], [3, 3]])
    first_summand = module.get_module_of_degree(1)[0]
    death_curve = np.asarray(first_summand.get_death_list())
    expected_deaths = array(
        [
            [2.0, 1.5],
            [2.5, 1.0],
            [np.inf, 0.5],
            [1.5, 2.0],
            [1.0, 2.5],
            [0.5, np.inf],
        ]
    )
    assert np.array_equal(death_curve, expected_deaths)
|
|
98
|
+
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from numpy import array
|
|
3
|
+
|
|
4
|
+
import multipers as mp
|
|
5
|
+
import multipers.slicer as mps
|
|
6
|
+
from multipers.tests import assert_sm
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def test_1():
    """Persistence along the line with basepoint (0, 0) for every slicer type
    that is vine-enabled, has a column type and is not k-critical.

    The complex is two vertices at grades (0, 1) and (1, 0) joined by an edge at
    (1, 1). H1 must be empty. H0 must contain exactly one bar of positive
    length, born at 1 and essentially infinite (death > 45).
    """
    from multipers._slicer_meta import _blocks2boundary_dimension_grades

    st = mp.SimplexTreeMulti(num_parameters=2)
    st.insert([0], [0, 1])
    st.insert([1], [1, 0])
    st.insert([0, 1], [1, 1])
    for S in mps.available_slicers:
        probe = S()  # empty instance, only used to inspect the slicer flavour
        if not probe.is_vine or not probe.col_type or probe.is_kcritical:
            continue

        generator_maps, generator_dimensions, filtration_values = (
            _blocks2boundary_dimension_grades(
                st._to_scc(),
                inplace=False,
            )
        )
        it = S(
            generator_maps, generator_dimensions, filtration_values
        ).persistence_on_line([0, 0])
        assert (
            len(it) == 2
        ), "There are simplices of dim 0 and 1, but no pers ? got {}".format(len(it))
        assert len(it[1]) == 0, "Pers of dim 1 is not empty ? got {}".format(it[1])

        # Zero-length bars carry no information (the edge can merge the two
        # vertices the moment they appear); only positive-length bars are checked.
        nontrivial_bars = [x for x in it[0] if x[1] > x[0]]
        assert (
            len(nontrivial_bars) == 1
        ), "expected exactly one positive-length H0 bar, got {}".format(it[0])
        birth, death = nontrivial_bars[0]
        assert birth == 1 and death > 45, "pers should be [1,inf], got {}".format(
            nontrivial_bars[0]
        )
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def test_2():
    """A simplex tree converted to a Slicer and back must compare equal."""
    st = mp.SimplexTreeMulti(num_parameters=2)
    for simplex, grade in (
        ([0], [0, 1]),
        ([1], [1, 0]),
        ([0, 1], [1, 1]),
        ([0, 1, 2], [2, 2]),
    ):
        st.insert(simplex, grade)

    slicer = mp.Slicer(st, dtype=st.dtype)
    round_trip = mp.slicer.to_simplextree(slicer)
    assert round_trip == st
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def test_3():
    """Degree-0 rank signed measure: simplex tree vs int32 and float64 Slicers.

    All three inputs must produce the same measure: two positive points and one
    negative point (given as rows of birth/death coordinates).
    """
    st = mp.SimplexTreeMulti(num_parameters=2)
    for simplex, grade in (([0], [0, 1]), ([1], [1, 0]), ([0, 1], [1, 1])):
        st.insert(simplex, grade)

    sources = (
        st,
        mp.Slicer(st, dtype=np.int32),
        mp.Slicer(st, dtype=np.float64),
    )
    measures = [
        mp.signed_measure(source, invariant="rank", degree=0) for source in sources
    ]

    expected = [
        (
            array([[0.0, 1.0, 1.0, 1.0], [1.0, 0.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]),
            array([1, 1, -1]),
        )
    ]
    for sm in measures:
        assert_sm(sm, expected)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .rips_density import *
|
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
from pykeops.torch import LazyTensor
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def get_grid(strategy: Literal["exact", "regular_closest", "regular_left", "quantile"]):
|
|
9
|
+
"""
|
|
10
|
+
Given a strategy, returns a function of signature
|
|
11
|
+
`(num_pts, num_parameter), int --> Iterable[1d array]`
|
|
12
|
+
that generates a torch-differentiable grid from a set of points,
|
|
13
|
+
and a resolution.
|
|
14
|
+
"""
|
|
15
|
+
match strategy:
|
|
16
|
+
case "exact":
|
|
17
|
+
return _exact_grid
|
|
18
|
+
case "regular_closest":
|
|
19
|
+
return _regular_closest_grid
|
|
20
|
+
case "regular_left":
|
|
21
|
+
return _regular_left_grid
|
|
22
|
+
case "quantile":
|
|
23
|
+
return _quantile_grid
|
|
24
|
+
case _:
|
|
25
|
+
raise ValueError(
|
|
26
|
+
f"""
|
|
27
|
+
Unimplemented strategy {strategy}.
|
|
28
|
+
Available ones : exact, regular_closest, regular_left, quantile.
|
|
29
|
+
"""
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def todense(grid: list[torch.Tensor]):
|
|
34
|
+
return torch.cartesian_product(*grid)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _exact_grid(filtration_values, r=None):
    """Exact grid: the distinct values of each parameter (via ``_unique_any``).

    ``r`` only exists so that every strategy shares the same call signature;
    it is ignored here.
    """
    return tuple(map(_unique_any, filtration_values))
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _regular_closest_grid(filtration_values, r: int):
    """Per parameter, the values closest to ``r`` evenly spaced targets.

    Each axis comes from ``_regular_closest`` (deduplicated), so it holds at
    most ``r`` values.
    """
    axes = [_regular_closest(values, r) for values in filtration_values]
    return tuple(axes)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _regular_left_grid(filtration_values, r: int):
    """Per parameter, the first values >= each of ``r`` evenly spaced targets.

    Each axis comes from ``_regular_left`` (deduplicated), so it holds at most
    ``r`` values.
    """
    axes = [_regular_left(values, r) for values in filtration_values]
    return tuple(axes)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _quantile_grid(filtration_values, r: int):
    """Quantile grid: ``r`` evenly spaced quantiles (levels 0 to 1) per parameter,
    deduplicated and stripped of ``+inf`` by ``_unique_any``.

    The quantile levels are created with each parameter's dtype and device,
    because ``torch.quantile`` rejects a ``q`` tensor whose dtype or device
    differs from its input's. A float32 CPU ``q`` would therefore fail for
    float64 or GPU values.
    """
    grid = tuple(
        _unique_any(
            torch.quantile(
                f, q=torch.linspace(0, 1, r, dtype=f.dtype, device=f.device)
            )
        )
        for f in filtration_values
    )
    return grid
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _unique_any(x, assume_sorted=False, remove_inf: bool = True):
|
|
59
|
+
if not assume_sorted:
|
|
60
|
+
x, _ = x.sort()
|
|
61
|
+
if remove_inf and x[-1] == torch.inf:
|
|
62
|
+
x = x[:-1]
|
|
63
|
+
with torch.no_grad():
|
|
64
|
+
y = x.unique()
|
|
65
|
+
idx = torch.searchsorted(x, y)
|
|
66
|
+
x = torch.cat([x, torch.tensor([torch.inf])])
|
|
67
|
+
return x[idx]
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _regular_left(f, r: int, unique: bool = True):
    """Select values of ``f`` at ``r`` regularly spaced positions, rounding up.

    ``f`` is first reduced with ``_unique_any`` (sorted, deduplicated,
    ``+inf`` removed). ``r`` targets are spaced evenly between its smallest and
    largest value. Each target is replaced by the first value of ``f`` that is
    >= the target (``torch.searchsorted``, left side). The selection is done by
    indexing, so gradients flow back to ``f``.

    Parameters
    ----------
    f : 1d torch.Tensor
        Values of one parameter.
    r : int
        Number of regularly spaced targets.
    unique : bool
        Deduplicate the selection, since several targets can land on the same
        value.

    Returns
    -------
    1d torch.Tensor
        Empty if no value of ``f`` survives ``_unique_any``.
    """
    f = _unique_any(f)
    if f.numel() == 0:
        # no min/max to span (f[0] would raise IndexError)
        return f
    with torch.no_grad():
        # Targets in f's floating dtype: a default float32 linspace can round
        # f[0] of a float64 ``f`` upwards, so the minimum would be skipped.
        f_regular = torch.linspace(
            f[0].item(),
            f[-1].item(),
            r,
            dtype=f.dtype if f.is_floating_point() else None,
            device=f.device,
        )
        idx = torch.searchsorted(f, f_regular)
    # +inf sentinel for an index equal to len(f); created on f's device so
    # torch.cat also works for GPU tensors
    f = torch.cat([f, torch.tensor([torch.inf], device=f.device)])
    if unique:
        return _unique_any(f[idx])
    return f[idx]
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _regular_closest(f, r: int, unique: bool = True):
    """Select the values of ``f`` closest to ``r`` regularly spaced targets.

    ``f`` is first reduced with ``_unique_any`` (sorted, deduplicated,
    ``+inf`` removed). ``r`` targets are spaced evenly between its smallest and
    largest value. For each target, the nearest value of ``f`` is found with a
    KeOps ``LazyTensor`` argmin. The selection is done by indexing, so
    gradients flow back to ``f``.

    Parameters
    ----------
    f : 1d torch.Tensor
        Values of one parameter.
    r : int
        Number of regularly spaced targets.
    unique : bool
        Deduplicate the selection, since several targets can share their
        closest value.

    Returns
    -------
    1d torch.Tensor
        Empty if no value of ``f`` survives ``_unique_any``.
    """
    f = _unique_any(f)
    if f.numel() == 0:
        # no min/max to span (f[0] would raise IndexError)
        return f
    with torch.no_grad():
        f_reg = torch.linspace(
            f[0].item(), f[-1].item(), steps=r, dtype=f.dtype, device=f.device
        )
        # |f_i - target_j| as a symbolic (len(f), r) matrix; argmin over axis 0
        # gives, for each target j, the index of the closest value of f
        _f = LazyTensor(f[:, None, None])
        _f_reg = LazyTensor(f_reg[None, :, None])
        indices = (_f - _f_reg).abs().argmin(0).ravel()
    # +inf sentinel appended after the data; created on f's device so that
    # torch.cat also works for GPU tensors
    f = torch.cat([f, torch.tensor([torch.inf], device=f.device)])
    f_regular_closest = f[indices]
    if unique:
        f_regular_closest = _unique_any(f_regular_closest)
    return f_regular_closest
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def evaluate_in_grid(pts, grid):
    """Replace integer grid coordinates by the corresponding grid values.

    Input
    -----
    - pts: (num_points, num_parameters) array of integer indices; column ``p``
      indexes ``grid[p]``.
    - grid: sequence of 1-d tensors, one per parameter.

    Returns
    -------
    - tensor of shape like ``pts`` with the dtype of the grid. It is built by
      indexing ``grid``, so gradients flow back to the grid values.
    """
    columns = [grid[p][coords] for p, coords in enumerate(pts.T)]
    return torch.stack(columns, dim=1)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def evaluate_mod_in_grid(mod, grid, box=None):
    """Express an MMA module's birth and death corners as values of ``grid``.

    ``mod.to_flat_idx(grid)`` gives integer grid coordinates of the corners.
    These are read back from ``grid`` with ``evaluate_in_grid``, so the
    returned coordinates are entries of the (possibly differentiable) grid
    tensors. Useful e.g. to make the module differentiable.

    Input
    -----
    - mod: PyModule
    - grid: sequence of 1d tensors, one per parameter
    - box: optional ``(lower_corner, upper_corner)``. When given, each grid
      axis is clamped to ``[lower_corner[i], upper_corner[i]]``, deduplicated
      and framed by the two bounds. Its rows are indexed as ``box[0][[i]]`` and
      concatenated with ``torch.cat``, so they presumably need to be torch
      tensors — TODO confirm against callers.

    Output
    ------
    tuple of ``(births, deaths)`` pairs of ``(k, num_parameters)`` tensors, one
    pair per entry of the sizes returned by ``mod.to_flat_idx`` (presumably one
    per interval summand — TODO confirm).
    """
    if box is not None:
        framed_axes = []
        for axis, axis_values in enumerate(grid):
            clamped = axis_values.clamp(min=box[0][axis], max=box[1][axis])
            framed_axes.append(
                torch.cat(
                    [
                        box[0][[axis]],
                        _unique_any(clamped, assume_sorted=True),
                        box[1][[axis]],
                    ]
                )
            )
        grid = tuple(framed_axes)

    (birth_sizes, death_sizes), birth_idx, death_idx = mod.to_flat_idx(grid)
    birth_values = evaluate_in_grid(birth_idx, grid)
    death_values = evaluate_in_grid(death_idx, grid)
    birth_chunks = birth_values.split_with_sizes(birth_sizes.tolist())
    death_chunks = death_values.split_with_sizes(death_sizes.tolist())
    return tuple(zip(birth_chunks, death_chunks))
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def evaluate_mod_in_grid__old(mod, grid, box=None):
    """Given an MMA module, pushes it into the specified grid.
    Useful e.g. to make it differentiable.

    Legacy implementation. Every birth/death corner from ``mod.dump()`` is
    clipped to ``box`` and snapped, per parameter, to the index of the closest
    grid value (KeOps argmin via ``pykeops.numpy``). The indices are then read
    back from ``grid`` with ``evaluate_in_grid`` and grouped by degree using
    ``mod.degree_splits()``.

    Input
    -----
    - mod: PyModule
    - grid: Iterable of 1d array, for num_parameters
    - box: optional ``[lower_corner, upper_corner]`` used to clip the corners;
      defaults to the first and last value of each grid axis.
    Output
    ------
    torch-compatible module in the format:
    (num_degrees) x (num_interval of degree) x ((num_birth, num_parameter), (num_death, num_parameters))

    """
    # NumPy flavour of LazyTensor; shadows the module-level pykeops.torch one
    # inside this function.
    from pykeops.numpy import LazyTensor

    with torch.no_grad():
        if box is None:
            # box = mod.get_box()
            # default box: first and last value of each grid axis
            box = np.asarray([[g[0] for g in grid], [g[-1] for g in grid]])
        # NOTE(review): assumes mod.dump()[1] is a sequence of (births, deaths)
        # corner arrays, one pair per summand — inferred from s[0]/s[1] below.
        S = mod.dump()[1]

        def get_idx_parameter(A, G, p):
            # For each point of A, index of the value of G[p] closest to its
            # p-th coordinate (argmin over the grid axis 0).
            # NOTE(review): Tensor.numpy() raises for tensors that require grad
            # or live on a GPU; .detach().cpu().numpy() may be needed — confirm.
            g = G[p].numpy() if isinstance(G[p], torch.Tensor) else np.asarray(G[p])
            la = LazyTensor(np.asarray(A, dtype=g.dtype)[None, :, [p]])
            lg = LazyTensor(g[:, None, None])
            return (la - lg).abs().argmin(0)

        Bdump = np.concatenate([s[0] for s in S], axis=0).clip(box[[0]], box[[1]])
        B = np.concatenate(
            [get_idx_parameter(Bdump, grid, p) for p in range(mod.num_parameters)],
            axis=1,
            dtype=np.int64,
        )
        # NOTE(review): deaths are forced to float32 while births keep their
        # dtype; confirm this asymmetry is intended.
        Ddump = np.concatenate([s[1] for s in S], axis=0, dtype=np.float32).clip(
            box[[0]], box[[1]]
        )
        D = np.concatenate(
            [get_idx_parameter(Ddump, grid, p) for p in range(mod.num_parameters)],
            axis=1,
            dtype=np.int64,
        )

    # grid lookups: these index the grid tensors, which is where gradients
    # can enter
    BB = evaluate_in_grid(B, grid)
    DD = evaluate_in_grid(D, grid)

    # one chunk of corners per summand
    b_idx = tuple((len(s[0]) for s in S))
    d_idx = tuple((len(s[1]) for s in S))
    BBB = BB.split_with_sizes(b_idx)
    DDD = DD.split_with_sizes(d_idx)

    # summand boundaries between degrees; presumably degree_splits() returns the
    # start index of each degree after the first — TODO confirm
    splits = np.concatenate([[0], mod.degree_splits(), [len(BBB)]])
    splits = torch.from_numpy(splits)
    out = [
        list(zip(BBB[splits[i] : splits[i + 1]], DDD[splits[i] : splits[i + 1]]))
        for i in range(len(splits) - 1)
    ] ## For some reasons this kills the gradient ???? pytorch bug
    # NOTE(review): list/zip/tuple slicing alone should not drop autograd
    # history; check whether `grid` requires grad and whether BB/DD are
    # computed with grad enabled.
    return out
|