multipers-2.0.0-cp311-cp311-macosx_13_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of multipers might be problematic.

Files changed (78)
  1. multipers/.dylibs/libc++.1.0.dylib +0 -0
  2. multipers/.dylibs/libtbb.12.12.dylib +0 -0
  3. multipers/.dylibs/libtbbmalloc.2.12.dylib +0 -0
  4. multipers/__init__.py +11 -0
  5. multipers/_signed_measure_meta.py +268 -0
  6. multipers/_slicer_meta.py +171 -0
  7. multipers/data/MOL2.py +350 -0
  8. multipers/data/UCR.py +18 -0
  9. multipers/data/__init__.py +1 -0
  10. multipers/data/graphs.py +466 -0
  11. multipers/data/immuno_regions.py +27 -0
  12. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  13. multipers/data/pytorch2simplextree.py +91 -0
  14. multipers/data/shape3d.py +101 -0
  15. multipers/data/synthetic.py +68 -0
  16. multipers/distances.py +198 -0
  17. multipers/euler_characteristic.pyx +132 -0
  18. multipers/filtration_conversions.pxd +229 -0
  19. multipers/filtrations.pxd +225 -0
  20. multipers/function_rips.cpython-311-darwin.so +0 -0
  21. multipers/function_rips.pyx +105 -0
  22. multipers/grids.cpython-311-darwin.so +0 -0
  23. multipers/grids.pyx +281 -0
  24. multipers/hilbert_function.pyi +46 -0
  25. multipers/hilbert_function.pyx +153 -0
  26. multipers/io.cpython-311-darwin.so +0 -0
  27. multipers/io.pyx +571 -0
  28. multipers/ml/__init__.py +0 -0
  29. multipers/ml/accuracies.py +90 -0
  30. multipers/ml/convolutions.py +532 -0
  31. multipers/ml/invariants_with_persistable.py +79 -0
  32. multipers/ml/kernels.py +176 -0
  33. multipers/ml/mma.py +659 -0
  34. multipers/ml/one.py +472 -0
  35. multipers/ml/point_clouds.py +238 -0
  36. multipers/ml/signed_betti.py +50 -0
  37. multipers/ml/signed_measures.py +1542 -0
  38. multipers/ml/sliced_wasserstein.py +461 -0
  39. multipers/ml/tools.py +113 -0
  40. multipers/mma_structures.cpython-311-darwin.so +0 -0
  41. multipers/mma_structures.pxd +127 -0
  42. multipers/mma_structures.pyx +2433 -0
  43. multipers/multiparameter_edge_collapse.py +41 -0
  44. multipers/multiparameter_module_approximation.cpython-311-darwin.so +0 -0
  45. multipers/multiparameter_module_approximation.pyx +211 -0
  46. multipers/pickle.py +53 -0
  47. multipers/plots.py +326 -0
  48. multipers/point_measure_integration.cpython-311-darwin.so +0 -0
  49. multipers/point_measure_integration.pyx +139 -0
  50. multipers/rank_invariant.cpython-311-darwin.so +0 -0
  51. multipers/rank_invariant.pyx +229 -0
  52. multipers/simplex_tree_multi.cpython-311-darwin.so +0 -0
  53. multipers/simplex_tree_multi.pxd +129 -0
  54. multipers/simplex_tree_multi.pyi +715 -0
  55. multipers/simplex_tree_multi.pyx +4655 -0
  56. multipers/slicer.cpython-311-darwin.so +0 -0
  57. multipers/slicer.pxd +781 -0
  58. multipers/slicer.pyx +3393 -0
  59. multipers/tensor.pxd +13 -0
  60. multipers/test.pyx +44 -0
  61. multipers/tests/__init__.py +40 -0
  62. multipers/tests/old_test_rank_invariant.py +91 -0
  63. multipers/tests/test_diff_helper.py +74 -0
  64. multipers/tests/test_hilbert_function.py +82 -0
  65. multipers/tests/test_mma.py +51 -0
  66. multipers/tests/test_point_clouds.py +59 -0
  67. multipers/tests/test_python-cpp_conversion.py +82 -0
  68. multipers/tests/test_signed_betti.py +181 -0
  69. multipers/tests/test_simplextreemulti.py +98 -0
  70. multipers/tests/test_slicer.py +63 -0
  71. multipers/torch/__init__.py +1 -0
  72. multipers/torch/diff_grids.py +217 -0
  73. multipers/torch/rips_density.py +257 -0
  74. multipers-2.0.0.dist-info/LICENSE +21 -0
  75. multipers-2.0.0.dist-info/METADATA +29 -0
  76. multipers-2.0.0.dist-info/RECORD +78 -0
  77. multipers-2.0.0.dist-info/WHEEL +5 -0
  78. multipers-2.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,181 @@ multipers/tests/test_signed_betti.py
+import numpy as np
+from multipers.ml.signed_betti import signed_betti, rank_decomposition_by_rectangles
+
+
+# only tests rank functions with 1 and 2 parameters
+def test_rank_decomposition():
+    # rank of an interval module in 1D on a grid with 2 elements
+    ri = np.array(
+        [
+            [
+                1,  # 0,0
+                1,  # 0,1
+            ],
+            [0, 1],  # 1,0  # 1,1
+        ]
+    )
+    expected_rd = np.array(
+        [
+            [
+                0,  # 0,0
+                1,  # 0,1
+            ],
+            [0, 0],  # 1,0  # 1,1
+        ]
+    )
+    rd = rank_decomposition_by_rectangles(ri)
+    for i in range(2):
+        for i_ in range(i, 2):
+            assert rd[i, i_] == expected_rd[i, i_]
+
+    # rank of a sum of two rectangles in 2D on a grid of 2 elements
+    ri = np.array(
+        [
+            [
+                [
+                    [1, 1],  # (0,0), (0,0)  # (0,0), (0,1)
+                    [1, 1],  # (0,0), (1,0)  # (0,0), (1,1)
+                ],
+                [
+                    [0, 1],  # (0,1), (0,0)  # (0,1), (0,1)
+                    [0, 1],  # (0,1), (1,0)  # (0,1), (1,1)
+                ],
+            ],
+            [
+                [
+                    [0, 0],  # (1,0), (0,0)  # (1,0), (0,1)
+                    [2, 2],  # (1,0), (1,0)  # (1,0), (1,1)
+                ],
+                [
+                    [0, 0],  # (1,1), (0,0)  # (1,1), (0,1)
+                    [0, 2],  # (1,1), (1,0)  # (1,1), (1,1)
+                ],
+            ],
+        ]
+    )
+    expected_rd = np.array(
+        [
+            [
+                [
+                    [0, 0],  # (0,0), (0,0)  # (0,0), (0,1)
+                    [0, 1],  # (0,0), (1,0)  # (0,0), (1,1)
+                ],
+                [
+                    [0, 0],  # (0,1), (0,0)  # (0,1), (0,1)
+                    [0, 0],  # (0,1), (1,0)  # (0,1), (1,1)
+                ],
+            ],
+            [
+                [
+                    [0, 0],  # (1,0), (0,0)  # (1,0), (0,1)
+                    [0, 1],  # (1,0), (1,0)  # (1,0), (1,1)
+                ],
+                [
+                    [0, 0],  # (1,1), (0,0)  # (1,1), (0,1)
+                    [0, 0],  # (1,1), (1,0)  # (1,1), (1,1)
+                ],
+            ],
+        ]
+    )
+
+    rd = rank_decomposition_by_rectangles(ri)
+    for i in range(2):
+        for i_ in range(i, 2):
+            for j in range(2):
+                for j_ in range(j, 2):
+                    assert rd[i, j, i_, j_] == expected_rd[i, j, i_, j_]
+
+
+# only tests Hilbert functions with 1, 2, 3, and 4 parameters
+def _test_signed_betti():
+    np.random.seed(0)
+    N = 4
+
+    # test 1D
+    for _ in range(N):
+        a = np.random.randint(10, 30)
+
+        f = np.random.randint(0, 40, size=(a))
+        sb = signed_betti(f)
+
+        check = np.zeros(f.shape)
+        for i in range(f.shape[0]):
+            for i_ in range(0, i + 1):
+                check[i] += sb[i_]
+
+        assert np.allclose(check, f)
+
+    # test 2D
+    for _ in range(N):
+        a = np.random.randint(10, 30)
+        b = np.random.randint(10, 30)
+
+        f = np.random.randint(0, 40, size=(a, b))
+        sb = signed_betti(f)
+
+        check = np.zeros(f.shape)
+        for i in range(f.shape[0]):
+            for j in range(f.shape[1]):
+                for i_ in range(0, i + 1):
+                    for j_ in range(0, j + 1):
+                        check[i, j] += sb[i_, j_]
+
+        assert np.allclose(check, f)
+
+    # test 3D
+    for _ in range(N):
+        a = np.random.randint(10, 20)
+        b = np.random.randint(10, 20)
+        c = np.random.randint(10, 20)
+
+        f = np.random.randint(0, 40, size=(a, b, c))
+        sb = signed_betti(f)
+
+        check = np.zeros(f.shape)
+        for i in range(f.shape[0]):
+            for j in range(f.shape[1]):
+                for k in range(f.shape[2]):
+                    for i_ in range(0, i + 1):
+                        for j_ in range(0, j + 1):
+                            for k_ in range(0, k + 1):
+                                check[i, j, k] += sb[i_, j_, k_]
+
+        assert np.allclose(check, f)
+
+    # test 4D
+    for _ in range(N):
+        a = np.random.randint(5, 10)
+        b = np.random.randint(5, 10)
+        c = np.random.randint(5, 10)
+        d = np.random.randint(5, 10)
+
+        f = np.random.randint(0, 40, size=(a, b, c, d))
+        sb = signed_betti(f)
+
+        check = np.zeros(f.shape)
+        for i in range(f.shape[0]):
+            for j in range(f.shape[1]):
+                for k in range(f.shape[2]):
+                    for l in range(f.shape[3]):
+                        for i_ in range(0, i + 1):
+                            for j_ in range(0, j + 1):
+                                for k_ in range(0, k + 1):
+                                    for l_ in range(0, l + 1):
+                                        check[i, j, k, l] += sb[i_, j_, k_, l_]
+
+        assert np.allclose(check, f)
+
+    for threshold in [True, False]:
+        for _ in range(N):
+            a = np.random.randint(5, 10)
+            b = np.random.randint(5, 10)
+            c = np.random.randint(5, 10)
+            d = np.random.randint(5, 10)
+            e = np.random.randint(5, 10)
+            f = np.random.randint(5, 10)
+
+            f = np.random.randint(0, 40, size=(a, b, c, d, e, f))
+            sb = signed_betti(f, threshold=threshold)
+            sb_ = signed_betti(f, threshold=threshold)
+
+            assert np.allclose(sb, sb_)
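The nested check loops above verify that signed_betti inverts the cumulative sum of the Hilbert function over lower-left quadrants. As a compact illustration of that identity (not the library's implementation; the helper name signed_betti_nd is made up here), the same inversion can be written with np.diff along each axis:

import numpy as np

def signed_betti_nd(hilbert):
    # Discrete difference along every axis: the inclusion-exclusion inversion
    # of the quadrant sums that the test's check loops accumulate.
    sb = hilbert.astype(np.int64)
    for axis in range(hilbert.ndim):
        sb = np.diff(sb, axis=axis, prepend=0)
    return sb

f = np.random.randint(0, 40, size=(6, 7))
sb = signed_betti_nd(f)
# Cumulative sums along every axis recover the Hilbert function.
assert np.array_equal(sb.cumsum(axis=0).cumsum(axis=1), f)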
@@ -0,0 +1,98 @@ multipers/tests/test_simplextreemulti.py
+import gudhi as gd
+import numpy as np
+from numpy import array
+
+import multipers as mp
+from multipers.tests import assert_st_simplices
+
+mp.simplex_tree_multi.SAFE_CONVERSION = True
+
+
+def test_1():
+    st = mp.SimplexTreeMulti(num_parameters=2)
+    st.insert([0], [0, 1])
+    st.insert([1], [1, 0])
+    st.insert([0, 1], [1, 1])
+    it = [([0, 1], [1.0, 1.0]), ([0], [0.0, 1.0]), ([1], [1.0, 0.0])]
+    assert_st_simplices(st, it)
+
+
+def test_2():
+    from gudhi.rips_complex import RipsComplex
+
+    st2 = RipsComplex(points=[[0, 1], [1, 0], [0, 0]]).create_simplex_tree()
+    st2 = mp.SimplexTreeMulti(
+        st2, num_parameters=3, default_values=[1, 2]
+    )  # the gudhi filtration is placed on axis 0
+
+    it = (
+        ([0, 1], [np.sqrt(2), 1.0, 2.0]),
+        ([0, 2], [1.0, 1.0, 2.0]),
+        ([0], [0.0, 1.0, 2.0]),
+        ([1, 2], [1.0, 1.0, 2.0]),
+        ([1], [0.0, 1.0, 2.0]),
+        ([2], [0.0, 1.0, 2.0]),
+    )
+    assert_st_simplices(st2, it)
+
+
+def test_3():
+    st = gd.SimplexTree()  # usual gudhi simplextree
+    st.insert([0, 1], 1)
+    st.insert([1], 0)
+    # converts the simplextree into a multiparameter simplextree
+    for dtype in [np.int32, np.int64, np.float32, np.float64]:
+        try:
+            st_multi = mp.SimplexTreeMulti(st, num_parameters=4, dtype=dtype)
+        except KeyError:
+            import sys
+
+            print(f"type {dtype} not compiled, skipping.", file=sys.stderr)
+            continue  ## dtype not compiled
+        minf = -np.inf if isinstance(dtype(1), np.floating) else np.iinfo(dtype).min
+        it = [
+            (array([0, 1]), array([1.0, minf, minf, minf])),
+            (array([0]), array([1.0, minf, minf, minf])),
+            (array([1]), array([0.0, minf, minf, minf])),
+        ]
+        assert_st_simplices(st_multi, it)
+
+
+def test_4():
+    st = mp.SimplexTreeMulti(num_parameters=2, kcritical=True, dtype=np.float64)
+    st.insert([0, 1, 2], [0, 1])
+    st.insert([0, 1, 2], [1, 0])
+    st.remove_maximal_simplex([0, 1, 2])
+    st.insert([0, 1, 2], [1, 2])
+    st.insert([0, 1, 2], [2, 1])
+    st.insert([0, 1, 2], [1.5, 1.5])
+    st.insert([0, 1, 2], [2.5, 0.5])
+    st.insert([0, 1, 2], [0.5, 2.5])
+
+    s = mp.Slicer(st, is_kcritical=True)
+
+    assert np.array_equal(s.get_filtrations_values(), array([[0. , 1. ],
+                                                             [1. , 0. ],
+                                                             [0. , 1. ],
+                                                             [1. , 0. ],
+                                                             [0. , 1. ],
+                                                             [1. , 0. ],
+                                                             [0. , 1. ],
+                                                             [1. , 0. ],
+                                                             [0. , 1. ],
+                                                             [1. , 0. ],
+                                                             [0. , 1. ],
+                                                             [1. , 0. ],
+                                                             [1. , 2. ],
+                                                             [2. , 1. ],
+                                                             [1.5, 1.5],
+                                                             [2.5, 0.5],
+                                                             [0.5, 2.5]])), "Invalid conversion from kcritical st to kcritical slicer."
+    death_curve = np.asarray(mp.module_approximation(s, box=[[0, 0], [3, 3]]).get_module_of_degree(1)[0].get_death_list())
+    assert np.array_equal(death_curve, array([[2. , 1.5],
+                                              [2.5, 1. ],
+                                              [np.inf, 0.5],
+                                              [1.5, 2. ],
+                                              [1. , 2.5],
+                                              [0.5, np.inf]]))
+
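test_2 and test_3 above exercise the conversion convention noted in their comments: the single gudhi filtration value is placed on axis 0, and default_values fills the remaining parameters. A minimal sketch of that convention, using only the calls that appear in the tests (the resulting filtration layout is inferred from the expected values above):

import gudhi as gd
import multipers as mp

st = gd.SimplexTree()   # ordinary 1-parameter gudhi simplex tree
st.insert([0, 1], 0.5)

# The gudhi filtration goes to axis 0; default_values fills axes 1 and 2,
# so each simplex ends up with a filtration like [0.5, 1.0, 2.0].
st_multi = mp.SimplexTreeMulti(st, num_parameters=3, default_values=[1.0, 2.0])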
@@ -0,0 +1,63 @@ multipers/tests/test_slicer.py
+import numpy as np
+from numpy import array
+
+import multipers as mp
+import multipers.slicer as mps
+from multipers.tests import assert_sm
+
+
+def test_1():
+    st = mp.SimplexTreeMulti(num_parameters=2)
+    st.insert([0], [0, 1])
+    st.insert([1], [1, 0])
+    st.insert([0, 1], [1, 1])
+    for S in mps.available_slicers:
+        if not S().is_vine or not S().col_type or S().is_kcritical:
+            continue
+        from multipers._slicer_meta import _blocks2boundary_dimension_grades
+
+        generator_maps, generator_dimensions, filtration_values = (
+            _blocks2boundary_dimension_grades(
+                st._to_scc(),
+                inplace=False,
+            )
+        )
+        it = S(
+            generator_maps, generator_dimensions, filtration_values
+        ).persistence_on_line([0, 0])
+        assert (
+            len(it) == 2
+        ), "There are simplices of dim 0 and 1, but no pers ? got {}".format(len(it))
+        assert len(it[1]) == 0, "Pers of dim 1 is not empty ? got {}".format(it[1])
+        for x in it[0]:
+            if np.any(np.asarray(x)):
+                continue
+            assert x[0] == 1 and x[1] > 45, "pers should be [1,inf], got {}".format(x)
+
+
+def test_2():
+    st = mp.SimplexTreeMulti(num_parameters=2)
+    st.insert([0], [0, 1])
+    st.insert([1], [1, 0])
+    st.insert([0, 1], [1, 1])
+    st.insert([0, 1, 2], [2, 2])
+    assert mp.slicer.to_simplextree(mp.Slicer(st, dtype=st.dtype)) == st
+
+
+def test_3():
+    st = mp.SimplexTreeMulti(num_parameters=2)
+    st.insert([0], [0, 1])
+    st.insert([1], [1, 0])
+    st.insert([0, 1], [1, 1])
+    sm = mp.signed_measure(st, invariant="rank", degree=0)
+    sm2 = mp.signed_measure(mp.Slicer(st, dtype=np.int32), invariant="rank", degree=0)
+    sm3 = mp.signed_measure(mp.Slicer(st, dtype=np.float64), invariant="rank", degree=0)
+    it = [
+        (
+            array([[0.0, 1.0, 1.0, 1.0], [1.0, 0.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]),
+            array([1, 1, -1]),
+        )
+    ]
+    assert_sm(sm, it)
+    assert_sm(sm2, it)
+    assert_sm(sm3, it)
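The expected value it in test_3 above also documents the signed-measure layout for the rank invariant with 2 parameters: each support point concatenates a birth and a death (4 coordinates), and the weights are signed integers. A short sketch unpacking that pair, assuming, as the expected value suggests, that the call returns a single (points, weights) pair for the requested degree:

import multipers as mp

st = mp.SimplexTreeMulti(num_parameters=2)
st.insert([0], [0, 1])
st.insert([1], [1, 0])
st.insert([0, 1], [1, 1])

((points, weights),) = mp.signed_measure(st, invariant="rank", degree=0)
# 2 birth coordinates + 2 death coordinates = 4 columns per support point
assert points.shape[1] == 4
# The expected weights in test_3 are (1, 1, -1), so their sum is 1.
assert weights.sum() == 1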
@@ -0,0 +1 @@ multipers/torch/__init__.py
+from .rips_density import *
@@ -0,0 +1,217 @@ multipers/torch/diff_grids.py
+from typing import Literal
+
+import numpy as np
+import torch
+from pykeops.torch import LazyTensor
+
+
+def get_grid(strategy: Literal["exact", "regular_closest", "regular_left", "quantile"]):
+    """
+    Given a strategy, returns a function of signature
+    `(filtration_values, resolution: int) --> Iterable[1d tensor]`
+    that generates a torch-differentiable grid from one 1d tensor of
+    filtration values per parameter and a resolution.
+    """
+    match strategy:
+        case "exact":
+            return _exact_grid
+        case "regular_closest":
+            return _regular_closest_grid
+        case "regular_left":
+            return _regular_left_grid
+        case "quantile":
+            return _quantile_grid
+        case _:
+            raise ValueError(
+                f"""
+                Unimplemented strategy {strategy}.
+                Available ones: exact, regular_closest, regular_left, quantile.
+                """
+            )
+
+
+def todense(grid: list[torch.Tensor]):
+    return torch.cartesian_prod(*grid)
+
+
+def _exact_grid(filtration_values, r=None):
+    grid = tuple(_unique_any(f) for f in filtration_values)
+    return grid
+
+
+def _regular_closest_grid(filtration_values, r: int):
+    grid = tuple(_regular_closest(f, r) for f in filtration_values)
+    return grid
+
+
+def _regular_left_grid(filtration_values, r: int):
+    grid = tuple(_regular_left(f, r) for f in filtration_values)
+    return grid
+
+
+def _quantile_grid(filtration_values, r: int):
+    qs = torch.linspace(0, 1, r)
+    grid = tuple(_unique_any(torch.quantile(f, q=qs)) for f in filtration_values)
+    return grid
+
+
+def _unique_any(x, assume_sorted=False, remove_inf: bool = True):
+    if not assume_sorted:
+        x, _ = x.sort()
+    if remove_inf and x[-1] == torch.inf:
+        x = x[:-1]
+    with torch.no_grad():
+        y = x.unique()
+        idx = torch.searchsorted(x, y)
+    x = torch.cat([x, torch.tensor([torch.inf])])
+    return x[idx]
+
+
+def _regular_left(f, r: int, unique: bool = True):
+    f = _unique_any(f)
+    with torch.no_grad():
+        f_regular = torch.linspace(f[0].item(), f[-1].item(), r, device=f.device)
+        idx = torch.searchsorted(f, f_regular)
+    f = torch.cat([f, torch.tensor([torch.inf])])
+    if unique:
+        return _unique_any(f[idx])
+    return f[idx]
+
+
+def _regular_closest(f, r: int, unique: bool = True):
+    f = _unique_any(f)
+    with torch.no_grad():
+        f_reg = torch.linspace(
+            f[0].item(), f[-1].item(), steps=r, dtype=f.dtype, device=f.device
+        )
+        _f = LazyTensor(f[:, None, None])
+        _f_reg = LazyTensor(f_reg[None, :, None])
+        indices = (_f - _f_reg).abs().argmin(0).ravel()
+    f = torch.cat([f, torch.tensor([torch.inf])])
+    f_regular_closest = f[indices]
+    if unique:
+        f_regular_closest = _unique_any(f_regular_closest)
+    return f_regular_closest
+
+
+def evaluate_in_grid(pts, grid):
+    """Evaluates points (assumed to be coordinates) in this grid.
+    Input
+    -----
+    - pts: (num_points, num_parameters) array
+    - grid: Iterable of 1d arrays, one per parameter
+
+    Returns
+    -------
+    - array shaped like pts, with the dtype of the grid.
+    """
+    # grid = [torch.cat([g, torch.tensor([torch.inf])]) for g in grid]
+    # new_pts = torch.empty(pts.shape, dtype=grid[0].dtype, device=grid[0].device)
+    # for parameter, pt_of_parameter in enumerate(pts.T):
+    #     new_pts[:, parameter] = grid[parameter][pt_of_parameter]
+    return torch.cat(
+        [
+            grid[parameter][pt_of_parameter][:, None]
+            for parameter, pt_of_parameter in enumerate(pts.T)
+        ],
+        dim=1,
+    )
+
+
+def evaluate_mod_in_grid(mod, grid, box=None):
+    """Given an MMA module, pushes it into the specified grid.
+    Useful, e.g., to make it differentiable.
+
+    Input
+    -----
+    - mod: PyModule
+    - grid: Iterable of 1d arrays, one per parameter
+    Output
+    ------
+    torch-compatible module in the format:
+    (num_degrees) x (num_intervals of degree) x ((num_birth, num_parameter), (num_death, num_parameters))
+
+    """
+    if box is not None:
+        grid = tuple(
+            torch.cat(
+                [
+                    box[0][[i]],
+                    _unique_any(
+                        grid[i].clamp(min=box[0][i], max=box[1][i]), assume_sorted=True
+                    ),
+                    box[1][[i]],
+                ]
+            )
+            for i in range(len(grid))
+        )
+    (birth_sizes, death_sizes), births, deaths = mod.to_flat_idx(grid)
+    births = evaluate_in_grid(births, grid)
+    deaths = evaluate_in_grid(deaths, grid)
+    diff_mod = tuple(
+        zip(
+            births.split_with_sizes(birth_sizes.tolist()),
+            deaths.split_with_sizes(death_sizes.tolist()),
+        )
+    )
+    return diff_mod
+
+
+def evaluate_mod_in_grid__old(mod, grid, box=None):
+    """Given an MMA module, pushes it into the specified grid.
+    Useful, e.g., to make it differentiable.
+
+    Input
+    -----
+    - mod: PyModule
+    - grid: Iterable of 1d arrays, one per parameter
+    Output
+    ------
+    torch-compatible module in the format:
+    (num_degrees) x (num_intervals of degree) x ((num_birth, num_parameter), (num_death, num_parameters))
+
+    """
+    from pykeops.numpy import LazyTensor
+
+    with torch.no_grad():
+        if box is None:
+            # box = mod.get_box()
+            box = np.asarray([[g[0] for g in grid], [g[-1] for g in grid]])
+        S = mod.dump()[1]
+
+        def get_idx_parameter(A, G, p):
+            g = G[p].numpy() if isinstance(G[p], torch.Tensor) else np.asarray(G[p])
+            la = LazyTensor(np.asarray(A, dtype=g.dtype)[None, :, [p]])
+            lg = LazyTensor(g[:, None, None])
+            return (la - lg).abs().argmin(0)
+
+        Bdump = np.concatenate([s[0] for s in S], axis=0).clip(box[[0]], box[[1]])
+        B = np.concatenate(
+            [get_idx_parameter(Bdump, grid, p) for p in range(mod.num_parameters)],
+            axis=1,
+            dtype=np.int64,
+        )
+        Ddump = np.concatenate([s[1] for s in S], axis=0, dtype=np.float32).clip(
+            box[[0]], box[[1]]
+        )
+        D = np.concatenate(
+            [get_idx_parameter(Ddump, grid, p) for p in range(mod.num_parameters)],
+            axis=1,
+            dtype=np.int64,
+        )
+
+    BB = evaluate_in_grid(B, grid)
+    DD = evaluate_in_grid(D, grid)
+
+    b_idx = tuple((len(s[0]) for s in S))
+    d_idx = tuple((len(s[1]) for s in S))
+    BBB = BB.split_with_sizes(b_idx)
+    DDD = DD.split_with_sizes(d_idx)
+
+    splits = np.concatenate([[0], mod.degree_splits(), [len(BBB)]])
+    splits = torch.from_numpy(splits)
+    out = [
+        list(zip(BBB[splits[i] : splits[i + 1]], DDD[splits[i] : splits[i + 1]]))
+        for i in range(len(splits) - 1)
+    ]  ## For some reason this kills the gradient ???? pytorch bug
+    return out
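The grid helpers above take one 1d tensor of filtration values per parameter plus a resolution and return a tuple of 1d tensors that stay attached to the inputs' autograd graph, which is the module's stated purpose. A minimal usage sketch (the import path is assumed from the file location multipers/torch/diff_grids.py; pykeops must be installed since the module imports pykeops.torch at import time):

import torch
from multipers.torch.diff_grids import get_grid, todense

# One 1d tensor of filtration values per parameter, with gradients enabled.
f0 = torch.rand(200, requires_grad=True)
f1 = torch.rand(200, requires_grad=True)

grid = get_grid("quantile")((f0, f1), 10)  # tuple of 1d tensors, one per parameter
dense = todense(grid)                      # cartesian product of the grid axes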