multipers 2.2.3__cp310-cp310-win_amd64.whl → 2.3.1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of multipers might be problematic. Click here for more details.
- multipers/__init__.py +33 -31
- multipers/_signed_measure_meta.py +430 -430
- multipers/_slicer_meta.py +211 -212
- multipers/data/MOL2.py +458 -458
- multipers/data/UCR.py +18 -18
- multipers/data/graphs.py +466 -466
- multipers/data/immuno_regions.py +27 -27
- multipers/data/pytorch2simplextree.py +90 -90
- multipers/data/shape3d.py +101 -101
- multipers/data/synthetic.py +113 -111
- multipers/distances.py +198 -198
- multipers/filtration_conversions.pxd.tp +84 -84
- multipers/filtrations/__init__.py +18 -0
- multipers/{ml/convolutions.py → filtrations/density.py} +563 -520
- multipers/filtrations/filtrations.py +289 -0
- multipers/filtrations.pxd +224 -224
- multipers/function_rips.cp310-win_amd64.pyd +0 -0
- multipers/function_rips.pyx +105 -105
- multipers/grids.cp310-win_amd64.pyd +0 -0
- multipers/grids.pyx +350 -350
- multipers/gudhi/Persistence_slices_interface.h +132 -132
- multipers/gudhi/Simplex_tree_interface.h +239 -245
- multipers/gudhi/Simplex_tree_multi_interface.h +516 -561
- multipers/gudhi/cubical_to_boundary.h +59 -59
- multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -450
- multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -1070
- multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -579
- multipers/gudhi/gudhi/Debug_utils.h +45 -45
- multipers/gudhi/gudhi/Fields/Multi_field.h +484 -484
- multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -455
- multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -450
- multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -531
- multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -507
- multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -531
- multipers/gudhi/gudhi/Fields/Z2_field.h +355 -355
- multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -376
- multipers/gudhi/gudhi/Fields/Zp_field.h +420 -420
- multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -400
- multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -418
- multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -337
- multipers/gudhi/gudhi/Matrix.h +2107 -2107
- multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -1038
- multipers/gudhi/gudhi/Multi_persistence/Box.h +171 -171
- multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -282
- multipers/gudhi/gudhi/Off_reader.h +173 -173
- multipers/gudhi/gudhi/One_critical_filtration.h +1433 -1431
- multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -769
- multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -686
- multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -842
- multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -1350
- multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -1105
- multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -859
- multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -910
- multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -139
- multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -230
- multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -211
- multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -60
- multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -60
- multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -136
- multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -190
- multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -616
- multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -150
- multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -106
- multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -219
- multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -327
- multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -1140
- multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -934
- multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -934
- multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -980
- multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -1092
- multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -192
- multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -921
- multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -1093
- multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -1012
- multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -1244
- multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -186
- multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -164
- multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -156
- multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -376
- multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -540
- multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -118
- multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -173
- multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -128
- multipers/gudhi/gudhi/Persistent_cohomology.h +745 -745
- multipers/gudhi/gudhi/Points_off_io.h +171 -171
- multipers/gudhi/gudhi/Simple_object_pool.h +69 -69
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -463
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -83
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -106
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -277
- multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -62
- multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -27
- multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -62
- multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -157
- multipers/gudhi/gudhi/Simplex_tree.h +2794 -2794
- multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -163
- multipers/gudhi/gudhi/distance_functions.h +62 -62
- multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -104
- multipers/gudhi/gudhi/persistence_interval.h +253 -253
- multipers/gudhi/gudhi/persistence_matrix_options.h +170 -170
- multipers/gudhi/gudhi/reader_utils.h +367 -367
- multipers/gudhi/mma_interface_coh.h +256 -255
- multipers/gudhi/mma_interface_h0.h +223 -231
- multipers/gudhi/mma_interface_matrix.h +291 -282
- multipers/gudhi/naive_merge_tree.h +536 -575
- multipers/gudhi/scc_io.h +310 -289
- multipers/gudhi/truc.h +957 -888
- multipers/io.cp310-win_amd64.pyd +0 -0
- multipers/io.pyx +714 -711
- multipers/ml/accuracies.py +90 -90
- multipers/ml/invariants_with_persistable.py +79 -79
- multipers/ml/kernels.py +176 -176
- multipers/ml/mma.py +713 -714
- multipers/ml/one.py +472 -472
- multipers/ml/point_clouds.py +352 -346
- multipers/ml/signed_measures.py +1589 -1589
- multipers/ml/sliced_wasserstein.py +461 -461
- multipers/ml/tools.py +113 -113
- multipers/mma_structures.cp310-win_amd64.pyd +0 -0
- multipers/mma_structures.pxd +127 -127
- multipers/mma_structures.pyx +4 -8
- multipers/mma_structures.pyx.tp +1083 -1085
- multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -93
- multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -97
- multipers/multi_parameter_rank_invariant/function_rips.h +322 -322
- multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -769
- multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -148
- multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -369
- multipers/multiparameter_edge_collapse.py +41 -41
- multipers/multiparameter_module_approximation/approximation.h +2298 -2295
- multipers/multiparameter_module_approximation/combinatory.h +129 -129
- multipers/multiparameter_module_approximation/debug.h +107 -107
- multipers/multiparameter_module_approximation/format_python-cpp.h +286 -286
- multipers/multiparameter_module_approximation/heap_column.h +238 -238
- multipers/multiparameter_module_approximation/images.h +79 -79
- multipers/multiparameter_module_approximation/list_column.h +174 -174
- multipers/multiparameter_module_approximation/list_column_2.h +232 -232
- multipers/multiparameter_module_approximation/ru_matrix.h +347 -347
- multipers/multiparameter_module_approximation/set_column.h +135 -135
- multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -36
- multipers/multiparameter_module_approximation/unordered_set_column.h +166 -166
- multipers/multiparameter_module_approximation/utilities.h +403 -419
- multipers/multiparameter_module_approximation/vector_column.h +223 -223
- multipers/multiparameter_module_approximation/vector_matrix.h +331 -331
- multipers/multiparameter_module_approximation/vineyards.h +464 -464
- multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -649
- multipers/multiparameter_module_approximation.cp310-win_amd64.pyd +0 -0
- multipers/multiparameter_module_approximation.pyx +218 -217
- multipers/pickle.py +90 -53
- multipers/plots.py +342 -334
- multipers/point_measure.cp310-win_amd64.pyd +0 -0
- multipers/point_measure.pyx +322 -320
- multipers/simplex_tree_multi.cp310-win_amd64.pyd +0 -0
- multipers/simplex_tree_multi.pxd +133 -133
- multipers/simplex_tree_multi.pyx +115 -48
- multipers/simplex_tree_multi.pyx.tp +1947 -1935
- multipers/slicer.cp310-win_amd64.pyd +0 -0
- multipers/slicer.pxd +301 -120
- multipers/slicer.pxd.tp +218 -214
- multipers/slicer.pyx +1570 -507
- multipers/slicer.pyx.tp +931 -914
- multipers/tensor/tensor.h +672 -672
- multipers/tensor.pxd +13 -13
- multipers/test.pyx +44 -44
- multipers/tests/__init__.py +57 -57
- multipers/torch/diff_grids.py +217 -217
- multipers/torch/rips_density.py +310 -304
- {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/LICENSE +21 -21
- {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/METADATA +21 -11
- multipers-2.3.1.dist-info/RECORD +182 -0
- {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/WHEEL +1 -1
- multipers/tests/test_diff_helper.py +0 -73
- multipers/tests/test_hilbert_function.py +0 -82
- multipers/tests/test_mma.py +0 -83
- multipers/tests/test_point_clouds.py +0 -49
- multipers/tests/test_python-cpp_conversion.py +0 -82
- multipers/tests/test_signed_betti.py +0 -181
- multipers/tests/test_signed_measure.py +0 -89
- multipers/tests/test_simplextreemulti.py +0 -221
- multipers/tests/test_slicer.py +0 -221
- multipers-2.2.3.dist-info/RECORD +0 -189
- {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/top_level.txt +0 -0
multipers/tensor.pxd
CHANGED
|
@@ -1,13 +1,13 @@
|
|
|
1
|
-
from libc.stdint cimport uint16_t
|
|
2
|
-
from libcpp.vector cimport vector
|
|
3
|
-
from libcpp cimport bool, float
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
ctypedef float dtype
|
|
7
|
-
ctypedef uint16_t index_type
|
|
8
|
-
|
|
9
|
-
cdef extern from "tensor/tensor.h" namespace "tensor":
|
|
10
|
-
cdef cppclass static_tensor_view[float, uint16_t]:
|
|
11
|
-
static_tensor_view() except + nogil
|
|
12
|
-
static_tensor_view(dtype*,const vector[index_type]&) except + nogil
|
|
13
|
-
const vector[index_type]& get_resolution()
|
|
1
|
+
from libc.stdint cimport uint16_t
|
|
2
|
+
from libcpp.vector cimport vector
|
|
3
|
+
from libcpp cimport bool, float
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
ctypedef float dtype
|
|
7
|
+
ctypedef uint16_t index_type
|
|
8
|
+
|
|
9
|
+
cdef extern from "tensor/tensor.h" namespace "tensor":
|
|
10
|
+
cdef cppclass static_tensor_view[float, uint16_t]:
|
|
11
|
+
static_tensor_view() except + nogil
|
|
12
|
+
static_tensor_view(dtype*,const vector[index_type]&) except + nogil
|
|
13
|
+
const vector[index_type]& get_resolution()
|
multipers/test.pyx
CHANGED
|
@@ -1,44 +1,44 @@
|
|
|
1
|
-
# cimport multipers.tensor as mt
|
|
2
|
-
from libc.stdint cimport intptr_t, uint16_t
|
|
3
|
-
from libcpp.vector cimport vector
|
|
4
|
-
from libcpp cimport bool, int, float
|
|
5
|
-
from libcpp.utility cimport pair
|
|
6
|
-
from typing import Optional,Iterable,Callable
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
ctypedef float value_type
|
|
10
|
-
# ctypedef uint16_t index_type
|
|
11
|
-
|
|
12
|
-
import numpy as np
|
|
13
|
-
# cimport numpy as cnp
|
|
14
|
-
# cnp.import_array()
|
|
15
|
-
|
|
16
|
-
# cdef extern from "multi_parameter_rank_invariant/rank_invariant.h" namespace "Gudhi::rank_invariant":
|
|
17
|
-
# void get_hilbert_surface(const intptr_t, mt.static_tensor_view, const vector[index_type], const vector[index_type], index_type, index_type, const vector[index_type], bool, bool) except + nogil
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
from multipers.simplex_tree_multi import SimplexTreeMulti
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def numpy_to_tensor(array:np.ndarray):
|
|
24
|
-
cdef vector[index_type] shape = array.shape
|
|
25
|
-
cdef dtype[::1] contigus_array_view = np.ascontiguousarray(array)
|
|
26
|
-
cdef dtype* dtype_ptr = &contigus_array_view[0]
|
|
27
|
-
cdef mt.static_tensor_view tensor
|
|
28
|
-
with nogil:
|
|
29
|
-
tensor = mt.static_tensor_view(dtype_ptr, shape)
|
|
30
|
-
return tensor.get_resolution()
|
|
31
|
-
|
|
32
|
-
# def hilbert2d(simplextree:SimplexTreeMulti, grid_shape:np.ndarray|list, vector[index_type] degrees, bool mobius_inversion):
|
|
33
|
-
# # assert simplextree.num_parameters == 2
|
|
34
|
-
# cdef intptr_t ptr = simplextree.thisptr
|
|
35
|
-
# cdef vector[index_type] c_grid_shape = grid_shape
|
|
36
|
-
# cdef dtype[::1] container = np.zeros(grid_shape, dtype=np.float32).flatten()
|
|
37
|
-
# cdef dtype* container_ptr = &container[0]
|
|
38
|
-
# cdef mt.static_tensor_view c_container = mt.static_tensor_view(container_ptr, c_grid_shape)
|
|
39
|
-
# cdef index_type i = 0
|
|
40
|
-
# cdef index_type j = 1
|
|
41
|
-
# cdef vector[index_type] fixed_values = [[],[]]
|
|
42
|
-
# # get_hilbert_surface(ptr, c_container, c_grid_shape, degrees,i,j,fixed_values, False, False)
|
|
43
|
-
# return container.reshape(grid_shape)
|
|
44
|
-
|
|
1
|
+
# cimport multipers.tensor as mt
|
|
2
|
+
from libc.stdint cimport intptr_t, uint16_t
|
|
3
|
+
from libcpp.vector cimport vector
|
|
4
|
+
from libcpp cimport bool, int, float
|
|
5
|
+
from libcpp.utility cimport pair
|
|
6
|
+
from typing import Optional,Iterable,Callable
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
ctypedef float value_type
|
|
10
|
+
# ctypedef uint16_t index_type
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
# cimport numpy as cnp
|
|
14
|
+
# cnp.import_array()
|
|
15
|
+
|
|
16
|
+
# cdef extern from "multi_parameter_rank_invariant/rank_invariant.h" namespace "Gudhi::rank_invariant":
|
|
17
|
+
# void get_hilbert_surface(const intptr_t, mt.static_tensor_view, const vector[index_type], const vector[index_type], index_type, index_type, const vector[index_type], bool, bool) except + nogil
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
from multipers.simplex_tree_multi import SimplexTreeMulti
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def numpy_to_tensor(array:np.ndarray):
|
|
24
|
+
cdef vector[index_type] shape = array.shape
|
|
25
|
+
cdef dtype[::1] contigus_array_view = np.ascontiguousarray(array)
|
|
26
|
+
cdef dtype* dtype_ptr = &contigus_array_view[0]
|
|
27
|
+
cdef mt.static_tensor_view tensor
|
|
28
|
+
with nogil:
|
|
29
|
+
tensor = mt.static_tensor_view(dtype_ptr, shape)
|
|
30
|
+
return tensor.get_resolution()
|
|
31
|
+
|
|
32
|
+
# def hilbert2d(simplextree:SimplexTreeMulti, grid_shape:np.ndarray|list, vector[index_type] degrees, bool mobius_inversion):
|
|
33
|
+
# # assert simplextree.num_parameters == 2
|
|
34
|
+
# cdef intptr_t ptr = simplextree.thisptr
|
|
35
|
+
# cdef vector[index_type] c_grid_shape = grid_shape
|
|
36
|
+
# cdef dtype[::1] container = np.zeros(grid_shape, dtype=np.float32).flatten()
|
|
37
|
+
# cdef dtype* container_ptr = &container[0]
|
|
38
|
+
# cdef mt.static_tensor_view c_container = mt.static_tensor_view(container_ptr, c_grid_shape)
|
|
39
|
+
# cdef index_type i = 0
|
|
40
|
+
# cdef index_type j = 1
|
|
41
|
+
# cdef vector[index_type] fixed_values = [[],[]]
|
|
42
|
+
# # get_hilbert_surface(ptr, c_container, c_grid_shape, degrees,i,j,fixed_values, False, False)
|
|
43
|
+
# return container.reshape(grid_shape)
|
|
44
|
+
|
multipers/tests/__init__.py
CHANGED
|
@@ -1,57 +1,57 @@
|
|
|
1
|
-
import numpy as np
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
def assert_st_simplices(st, dump):
|
|
5
|
-
"""
|
|
6
|
-
Checks that the simplextree has the same
|
|
7
|
-
filtration as the dump.
|
|
8
|
-
"""
|
|
9
|
-
|
|
10
|
-
assert np.all(
|
|
11
|
-
[
|
|
12
|
-
np.isclose(a, b).all()
|
|
13
|
-
for x, y in zip(st.get_simplices(), dump, strict=True)
|
|
14
|
-
for a, b in zip(x, y, strict=True)
|
|
15
|
-
]
|
|
16
|
-
)
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
def sort_sm(sms):
|
|
20
|
-
idx = np.argsort([sm[0][:, 0] for sm in sms])
|
|
21
|
-
return tuple((sm[0][idx], sm[1][idx]) for sm in sms)
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
def assert_sm_pair(sm1, sm2, exact=True, max_error=1e-3, reg=0.1):
|
|
25
|
-
if not exact:
|
|
26
|
-
from multipers.distances import sm_distance
|
|
27
|
-
|
|
28
|
-
d = sm_distance(sm1, sm2, reg=0.1)
|
|
29
|
-
assert d < max_error, f"Failed comparison:\n{sm1}\n{sm2},\n with distance {d}."
|
|
30
|
-
return
|
|
31
|
-
assert np.all(
|
|
32
|
-
[
|
|
33
|
-
np.isclose(a, b).all()
|
|
34
|
-
for x, y in zip(sm1, sm2, strict=True)
|
|
35
|
-
for a, b in zip(x, y, strict=True)
|
|
36
|
-
]
|
|
37
|
-
), f"Failed comparison:\n-----------------\n{sm1}\n-----------------\n{sm2}"
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
def assert_sm(*args, exact=True, max_error=1e-5, reg=0.1):
|
|
41
|
-
sms = tuple(args)
|
|
42
|
-
for i in range(len(sms) - 1):
|
|
43
|
-
assert_sm_pair(sms[i], sms[i + 1], exact=exact, max_error=max_error, reg=reg)
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
def random_st(npts=100, num_parameters=2, max_dim=2):
|
|
47
|
-
import gudhi as gd
|
|
48
|
-
|
|
49
|
-
import multipers as mp
|
|
50
|
-
from multipers.data import noisy_annulus
|
|
51
|
-
|
|
52
|
-
x = noisy_annulus(npts // 2, npts - npts // 2, dim=max_dim)
|
|
53
|
-
st = gd.AlphaComplex(points=x).create_simplex_tree()
|
|
54
|
-
st = mp.SimplexTreeMulti(st, num_parameters=num_parameters)
|
|
55
|
-
for p in range(num_parameters):
|
|
56
|
-
st.fill_lowerstar(np.random.uniform(size=npts), p)
|
|
57
|
-
return st
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def assert_st_simplices(st, dump):
|
|
5
|
+
"""
|
|
6
|
+
Checks that the simplextree has the same
|
|
7
|
+
filtration as the dump.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
assert np.all(
|
|
11
|
+
[
|
|
12
|
+
np.isclose(a, b).all()
|
|
13
|
+
for x, y in zip(st.get_simplices(), dump, strict=True)
|
|
14
|
+
for a, b in zip(x, y, strict=True)
|
|
15
|
+
]
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def sort_sm(sms):
|
|
20
|
+
idx = np.argsort([sm[0][:, 0] for sm in sms])
|
|
21
|
+
return tuple((sm[0][idx], sm[1][idx]) for sm in sms)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def assert_sm_pair(sm1, sm2, exact=True, max_error=1e-3, reg=0.1):
|
|
25
|
+
if not exact:
|
|
26
|
+
from multipers.distances import sm_distance
|
|
27
|
+
|
|
28
|
+
d = sm_distance(sm1, sm2, reg=0.1)
|
|
29
|
+
assert d < max_error, f"Failed comparison:\n{sm1}\n{sm2},\n with distance {d}."
|
|
30
|
+
return
|
|
31
|
+
assert np.all(
|
|
32
|
+
[
|
|
33
|
+
np.isclose(a, b).all()
|
|
34
|
+
for x, y in zip(sm1, sm2, strict=True)
|
|
35
|
+
for a, b in zip(x, y, strict=True)
|
|
36
|
+
]
|
|
37
|
+
), f"Failed comparison:\n-----------------\n{sm1}\n-----------------\n{sm2}"
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def assert_sm(*args, exact=True, max_error=1e-5, reg=0.1):
|
|
41
|
+
sms = tuple(args)
|
|
42
|
+
for i in range(len(sms) - 1):
|
|
43
|
+
assert_sm_pair(sms[i], sms[i + 1], exact=exact, max_error=max_error, reg=reg)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def random_st(npts=100, num_parameters=2, max_dim=2):
|
|
47
|
+
import gudhi as gd
|
|
48
|
+
|
|
49
|
+
import multipers as mp
|
|
50
|
+
from multipers.data import noisy_annulus
|
|
51
|
+
|
|
52
|
+
x = noisy_annulus(npts // 2, npts - npts // 2, dim=max_dim)
|
|
53
|
+
st = gd.AlphaComplex(points=x).create_simplex_tree()
|
|
54
|
+
st = mp.SimplexTreeMulti(st, num_parameters=num_parameters)
|
|
55
|
+
for p in range(num_parameters):
|
|
56
|
+
st.fill_lowerstar(np.random.uniform(size=npts), p)
|
|
57
|
+
return st
|
multipers/torch/diff_grids.py
CHANGED
|
@@ -1,217 +1,217 @@
|
|
|
1
|
-
from typing import Literal
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
import torch
|
|
5
|
-
from pykeops.torch import LazyTensor
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
def get_grid(strategy: Literal["exact", "regular_closest", "regular_left", "quantile"]):
|
|
9
|
-
"""
|
|
10
|
-
Given a strategy, returns a function of signature
|
|
11
|
-
`(num_pts, num_parameter), int --> Iterable[1d array]`
|
|
12
|
-
that generates a torch-differentiable grid from a set of points,
|
|
13
|
-
and a resolution.
|
|
14
|
-
"""
|
|
15
|
-
match strategy:
|
|
16
|
-
case "exact":
|
|
17
|
-
return _exact_grid
|
|
18
|
-
case "regular_closest":
|
|
19
|
-
return _regular_closest_grid
|
|
20
|
-
case "regular_left":
|
|
21
|
-
return _regular_left_grid
|
|
22
|
-
case "quantile":
|
|
23
|
-
return _quantile_grid
|
|
24
|
-
case _:
|
|
25
|
-
raise ValueError(
|
|
26
|
-
f"""
|
|
27
|
-
Unimplemented strategy {strategy}.
|
|
28
|
-
Available ones : exact, regular_closest, regular_left, quantile.
|
|
29
|
-
"""
|
|
30
|
-
)
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
def todense(grid: list[torch.Tensor]):
|
|
34
|
-
return torch.cartesian_prod(*grid)
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
def _exact_grid(filtration_values, r=None):
|
|
38
|
-
grid = tuple(_unique_any(f) for f in filtration_values)
|
|
39
|
-
return grid
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
def _regular_closest_grid(filtration_values, r: int):
|
|
43
|
-
grid = tuple(_regular_closest(f, r) for f in filtration_values)
|
|
44
|
-
return grid
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
def _regular_left_grid(filtration_values, r: int):
|
|
48
|
-
grid = tuple(_regular_left(f, r) for f in filtration_values)
|
|
49
|
-
return grid
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
def _quantile_grid(filtration_values, r: int):
|
|
53
|
-
qs = torch.linspace(0, 1, r)
|
|
54
|
-
grid = tuple(_unique_any(torch.quantile(f, q=qs)) for f in filtration_values)
|
|
55
|
-
return grid
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
def _unique_any(x, assume_sorted=False, remove_inf: bool = True):
|
|
59
|
-
if not assume_sorted:
|
|
60
|
-
x, _ = x.sort()
|
|
61
|
-
if remove_inf and x[-1] == torch.inf:
|
|
62
|
-
x = x[:-1]
|
|
63
|
-
with torch.no_grad():
|
|
64
|
-
y = x.unique()
|
|
65
|
-
idx = torch.searchsorted(x, y)
|
|
66
|
-
x = torch.cat([x, torch.tensor([torch.inf])])
|
|
67
|
-
return x[idx]
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
def _regular_left(f, r: int, unique: bool = True):
|
|
71
|
-
f = _unique_any(f)
|
|
72
|
-
with torch.no_grad():
|
|
73
|
-
f_regular = torch.linspace(f[0].item(), f[-1].item(), r, device=f.device)
|
|
74
|
-
idx = torch.searchsorted(f, f_regular)
|
|
75
|
-
f = torch.cat([f, torch.tensor([torch.inf])])
|
|
76
|
-
if unique:
|
|
77
|
-
return _unique_any(f[idx])
|
|
78
|
-
return f[idx]
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
def _regular_closest(f, r: int, unique: bool = True):
|
|
82
|
-
f = _unique_any(f)
|
|
83
|
-
with torch.no_grad():
|
|
84
|
-
f_reg = torch.linspace(
|
|
85
|
-
f[0].item(), f[-1].item(), steps=r, dtype=f.dtype, device=f.device
|
|
86
|
-
)
|
|
87
|
-
_f = LazyTensor(f[:, None, None])
|
|
88
|
-
_f_reg = LazyTensor(f_reg[None, :, None])
|
|
89
|
-
indices = (_f - _f_reg).abs().argmin(0).ravel()
|
|
90
|
-
f = torch.cat([f, torch.tensor([torch.inf])])
|
|
91
|
-
f_regular_closest = f[indices]
|
|
92
|
-
if unique:
|
|
93
|
-
f_regular_closest = _unique_any(f_regular_closest)
|
|
94
|
-
return f_regular_closest
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
def evaluate_in_grid(pts, grid):
|
|
98
|
-
"""Evaluates points (assumed to be coordinates) in this grid.
|
|
99
|
-
Input
|
|
100
|
-
-----
|
|
101
|
-
- pts: (num_points, num_parameters) array
|
|
102
|
-
- grid: Iterable of 1-d array, for each parameter
|
|
103
|
-
|
|
104
|
-
Returns
|
|
105
|
-
-------
|
|
106
|
-
- array of shape like points of dtype like grid.
|
|
107
|
-
"""
|
|
108
|
-
# grid = [torch.cat([g, torch.tensor([torch.inf])]) for g in grid]
|
|
109
|
-
# new_pts = torch.empty(pts.shape, dtype=grid[0].dtype, device=grid[0].device)
|
|
110
|
-
# for parameter, pt_of_parameter in enumerate(pts.T):
|
|
111
|
-
# new_pts[:, parameter] = grid[parameter][pt_of_parameter]
|
|
112
|
-
return torch.cat(
|
|
113
|
-
[
|
|
114
|
-
grid[parameter][pt_of_parameter][:, None]
|
|
115
|
-
for parameter, pt_of_parameter in enumerate(pts.T)
|
|
116
|
-
],
|
|
117
|
-
dim=1,
|
|
118
|
-
)
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
def evaluate_mod_in_grid(mod, grid, box=None):
|
|
122
|
-
"""Given an MMA module, pushes it into the specified grid.
|
|
123
|
-
Useful for e.g., make it differentiable.
|
|
124
|
-
|
|
125
|
-
Input
|
|
126
|
-
-----
|
|
127
|
-
- mod: PyModule
|
|
128
|
-
- grid: Iterable of 1d array, for num_parameters
|
|
129
|
-
Ouput
|
|
130
|
-
-----
|
|
131
|
-
torch-compatible module in the format:
|
|
132
|
-
(num_degrees) x (num_interval of degree) x ((num_birth, num_parameter), (num_death, num_parameters))
|
|
133
|
-
|
|
134
|
-
"""
|
|
135
|
-
if box is not None:
|
|
136
|
-
grid = tuple(
|
|
137
|
-
torch.cat(
|
|
138
|
-
[
|
|
139
|
-
box[0][[i]],
|
|
140
|
-
_unique_any(
|
|
141
|
-
grid[i].clamp(min=box[0][i], max=box[1][i]), assume_sorted=True
|
|
142
|
-
),
|
|
143
|
-
box[1][[i]],
|
|
144
|
-
]
|
|
145
|
-
)
|
|
146
|
-
for i in range(len(grid))
|
|
147
|
-
)
|
|
148
|
-
(birth_sizes, death_sizes), births, deaths = mod.to_flat_idx(grid)
|
|
149
|
-
births = evaluate_in_grid(births, grid)
|
|
150
|
-
deaths = evaluate_in_grid(deaths, grid)
|
|
151
|
-
diff_mod = tuple(
|
|
152
|
-
zip(
|
|
153
|
-
births.split_with_sizes(birth_sizes.tolist()),
|
|
154
|
-
deaths.split_with_sizes(death_sizes.tolist()),
|
|
155
|
-
)
|
|
156
|
-
)
|
|
157
|
-
return diff_mod
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
def evaluate_mod_in_grid__old(mod, grid, box=None):
|
|
161
|
-
"""Given an MMA module, pushes it into the specified grid.
|
|
162
|
-
Useful for e.g., make it differentiable.
|
|
163
|
-
|
|
164
|
-
Input
|
|
165
|
-
-----
|
|
166
|
-
- mod: PyModule
|
|
167
|
-
- grid: Iterable of 1d array, for num_parameters
|
|
168
|
-
Ouput
|
|
169
|
-
-----
|
|
170
|
-
torch-compatible module in the format:
|
|
171
|
-
(num_degrees) x (num_interval of degree) x ((num_birth, num_parameter), (num_death, num_parameters))
|
|
172
|
-
|
|
173
|
-
"""
|
|
174
|
-
from pykeops.numpy import LazyTensor
|
|
175
|
-
|
|
176
|
-
with torch.no_grad():
|
|
177
|
-
if box is None:
|
|
178
|
-
# box = mod.get_box()
|
|
179
|
-
box = np.asarray([[g[0] for g in grid], [g[-1] for g in grid]])
|
|
180
|
-
S = mod.dump()[1]
|
|
181
|
-
|
|
182
|
-
def get_idx_parameter(A, G, p):
|
|
183
|
-
g = G[p].numpy() if isinstance(G[p], torch.Tensor) else np.asarray(G[p])
|
|
184
|
-
la = LazyTensor(np.asarray(A, dtype=g.dtype)[None, :, [p]])
|
|
185
|
-
lg = LazyTensor(g[:, None, None])
|
|
186
|
-
return (la - lg).abs().argmin(0)
|
|
187
|
-
|
|
188
|
-
Bdump = np.concatenate([s[0] for s in S], axis=0).clip(box[[0]], box[[1]])
|
|
189
|
-
B = np.concatenate(
|
|
190
|
-
[get_idx_parameter(Bdump, grid, p) for p in range(mod.num_parameters)],
|
|
191
|
-
axis=1,
|
|
192
|
-
dtype=np.int64,
|
|
193
|
-
)
|
|
194
|
-
Ddump = np.concatenate([s[1] for s in S], axis=0, dtype=np.float32).clip(
|
|
195
|
-
box[[0]], box[[1]]
|
|
196
|
-
)
|
|
197
|
-
D = np.concatenate(
|
|
198
|
-
[get_idx_parameter(Ddump, grid, p) for p in range(mod.num_parameters)],
|
|
199
|
-
axis=1,
|
|
200
|
-
dtype=np.int64,
|
|
201
|
-
)
|
|
202
|
-
|
|
203
|
-
BB = evaluate_in_grid(B, grid)
|
|
204
|
-
DD = evaluate_in_grid(D, grid)
|
|
205
|
-
|
|
206
|
-
b_idx = tuple((len(s[0]) for s in S))
|
|
207
|
-
d_idx = tuple((len(s[1]) for s in S))
|
|
208
|
-
BBB = BB.split_with_sizes(b_idx)
|
|
209
|
-
DDD = DD.split_with_sizes(d_idx)
|
|
210
|
-
|
|
211
|
-
splits = np.concatenate([[0], mod.degree_splits(), [len(BBB)]])
|
|
212
|
-
splits = torch.from_numpy(splits)
|
|
213
|
-
out = [
|
|
214
|
-
list(zip(BBB[splits[i] : splits[i + 1]], DDD[splits[i] : splits[i + 1]]))
|
|
215
|
-
for i in range(len(splits) - 1)
|
|
216
|
-
] ## For some reasons this kills the gradient ???? pytorch bug
|
|
217
|
-
return out
|
|
1
|
+
from typing import Literal
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
from pykeops.torch import LazyTensor
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def get_grid(strategy: Literal["exact", "regular_closest", "regular_left", "quantile"]):
|
|
9
|
+
"""
|
|
10
|
+
Given a strategy, returns a function of signature
|
|
11
|
+
`(num_pts, num_parameter), int --> Iterable[1d array]`
|
|
12
|
+
that generates a torch-differentiable grid from a set of points,
|
|
13
|
+
and a resolution.
|
|
14
|
+
"""
|
|
15
|
+
match strategy:
|
|
16
|
+
case "exact":
|
|
17
|
+
return _exact_grid
|
|
18
|
+
case "regular_closest":
|
|
19
|
+
return _regular_closest_grid
|
|
20
|
+
case "regular_left":
|
|
21
|
+
return _regular_left_grid
|
|
22
|
+
case "quantile":
|
|
23
|
+
return _quantile_grid
|
|
24
|
+
case _:
|
|
25
|
+
raise ValueError(
|
|
26
|
+
f"""
|
|
27
|
+
Unimplemented strategy {strategy}.
|
|
28
|
+
Available ones : exact, regular_closest, regular_left, quantile.
|
|
29
|
+
"""
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def todense(grid: list[torch.Tensor]):
|
|
34
|
+
return torch.cartesian_prod(*grid)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _exact_grid(filtration_values, r=None):
|
|
38
|
+
grid = tuple(_unique_any(f) for f in filtration_values)
|
|
39
|
+
return grid
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _regular_closest_grid(filtration_values, r: int):
|
|
43
|
+
grid = tuple(_regular_closest(f, r) for f in filtration_values)
|
|
44
|
+
return grid
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _regular_left_grid(filtration_values, r: int):
|
|
48
|
+
grid = tuple(_regular_left(f, r) for f in filtration_values)
|
|
49
|
+
return grid
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _quantile_grid(filtration_values, r: int):
|
|
53
|
+
qs = torch.linspace(0, 1, r)
|
|
54
|
+
grid = tuple(_unique_any(torch.quantile(f, q=qs)) for f in filtration_values)
|
|
55
|
+
return grid
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _unique_any(x, assume_sorted=False, remove_inf: bool = True):
|
|
59
|
+
if not assume_sorted:
|
|
60
|
+
x, _ = x.sort()
|
|
61
|
+
if remove_inf and x[-1] == torch.inf:
|
|
62
|
+
x = x[:-1]
|
|
63
|
+
with torch.no_grad():
|
|
64
|
+
y = x.unique()
|
|
65
|
+
idx = torch.searchsorted(x, y)
|
|
66
|
+
x = torch.cat([x, torch.tensor([torch.inf])])
|
|
67
|
+
return x[idx]
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _regular_left(f, r: int, unique: bool = True):
|
|
71
|
+
f = _unique_any(f)
|
|
72
|
+
with torch.no_grad():
|
|
73
|
+
f_regular = torch.linspace(f[0].item(), f[-1].item(), r, device=f.device)
|
|
74
|
+
idx = torch.searchsorted(f, f_regular)
|
|
75
|
+
f = torch.cat([f, torch.tensor([torch.inf])])
|
|
76
|
+
if unique:
|
|
77
|
+
return _unique_any(f[idx])
|
|
78
|
+
return f[idx]
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _regular_closest(f, r: int, unique: bool = True):
    """Snap a regular subdivision of [min f, max f] onto the closest values of f.

    Input
    -----
     - f: 1-d tensor of filtration values
     - r: number of regular subdivision points
     - unique: if True, the result is deduplicated with `_unique_any`

    Returns
    -------
     - 1-d tensor of grid values taken from `f`.
    """
    f = _unique_any(f)
    with torch.no_grad():
        f_reg = torch.linspace(
            f[0].item(), f[-1].item(), steps=r, dtype=f.dtype, device=f.device
        )
        # Values of f along axis 0, regular values along axis 1; the argmin
        # over axis 0 presumably yields, for each regular value, the index of
        # the nearest value of f (KeOps reduction; confirm against pykeops).
        _f = LazyTensor(f[:, None, None])
        _f_reg = LazyTensor(f_reg[None, :, None])
        indices = (_f - _f_reg).abs().argmin(0).ravel()
    # Keep the sentinel on f's device so torch.cat also works off-CPU.
    f = torch.cat([f, torch.tensor([torch.inf], device=f.device)])
    f_regular_closest = f[indices]
    if unique:
        f_regular_closest = _unique_any(f_regular_closest)
    return f_regular_closest
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def evaluate_in_grid(pts, grid):
    """Evaluates points (assumed to be coordinates) in this grid.

    Input
    -----
     - pts: (num_points, num_parameters) array of indices into the grid
     - grid: Iterable of 1-d array, for each parameter

    Returns
    -------
     - array of shape like points of dtype like grid.
    """
    columns = []
    for parameter, indices in enumerate(pts.T):
        # Look up this parameter's grid values and keep a trailing axis so
        # the per-parameter columns can be concatenated side by side.
        columns.append(grid[parameter][indices][:, None])
    return torch.cat(columns, dim=1)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def evaluate_mod_in_grid(mod, grid, box=None):
    """Given an MMA module, pushes it into the specified grid.
    Useful for e.g., make it differentiable.

    Input
    -----
     - mod: PyModule
     - grid: Iterable of 1d array, for num_parameters
     - box: optional (lower, upper) pair indexed per parameter; when given,
       each grid axis is clamped to it and the two bounds are added at the
       ends of the axis

    Output
    ------
    tuple of (births, deaths) pairs of grid-valued tensors, split according
    to the sizes returned by `mod.to_flat_idx`.
    """
    if box is not None:
        lower, upper = box[0], box[1]
        clipped = []
        for i in range(len(grid)):
            inner = _unique_any(
                grid[i].clamp(min=lower[i], max=upper[i]), assume_sorted=True
            )
            # Surround the clamped axis with the box bounds themselves.
            clipped.append(torch.cat([lower[[i]], inner, upper[[i]]]))
        grid = tuple(clipped)

    # to_flat_idx returns grid indices; they are turned into grid values by
    # indexing the grid tensors in evaluate_in_grid.
    (birth_sizes, death_sizes), births, deaths = mod.to_flat_idx(grid)
    birth_values = evaluate_in_grid(births, grid)
    death_values = evaluate_in_grid(deaths, grid)
    birth_chunks = birth_values.split_with_sizes(birth_sizes.tolist())
    death_chunks = death_values.split_with_sizes(death_sizes.tolist())
    return tuple(zip(birth_chunks, death_chunks))
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def evaluate_mod_in_grid__old(mod, grid, box=None):
    """Given an MMA module, pushes it into the specified grid.
    Useful for e.g., make it differentiable.

    Older variant of `evaluate_mod_in_grid`: the grid indices are computed
    here with KeOps from `mod.dump()` instead of `mod.to_flat_idx`.

    Input
    -----
     - mod: PyModule
     - grid: Iterable of 1d array, for num_parameters
     - box: optional (2, num_parameters) array of lower/upper bounds;
       defaults to the first and last value of each grid axis

    Output
    ------
    torch-compatible module in the format:
    (num_degrees) x (num_interval of degree) x ((num_birth, num_parameter), (num_death, num_parameters))
    """
    from pykeops.numpy import LazyTensor

    # NOTE(review): the extent of this no_grad block was inferred (index
    # computation inside, grid evaluation outside) -- confirm upstream.
    with torch.no_grad():
        if box is None:
            box = np.asarray([[g[0] for g in grid], [g[-1] for g in grid]])
        summands = mod.dump()[1]

        def closest_index(points, axes, p):
            # Index, along parameter p, of the grid value closest to each
            # point (argmin of the absolute difference over the grid axis).
            axis = axes[p]
            axis = axis.numpy() if isinstance(axis, torch.Tensor) else np.asarray(axis)
            lazy_points = LazyTensor(np.asarray(points, dtype=axis.dtype)[None, :, [p]])
            lazy_axis = LazyTensor(axis[:, None, None])
            return (lazy_points - lazy_axis).abs().argmin(0)

        lower, upper = box[[0]], box[[1]]
        parameters = range(mod.num_parameters)

        birth_values = np.concatenate([s[0] for s in summands], axis=0).clip(
            lower, upper
        )
        birth_idx = np.concatenate(
            [closest_index(birth_values, grid, p) for p in parameters],
            axis=1,
            dtype=np.int64,
        )
        death_values = np.concatenate(
            [s[1] for s in summands], axis=0, dtype=np.float32
        ).clip(lower, upper)
        death_idx = np.concatenate(
            [closest_index(death_values, grid, p) for p in parameters],
            axis=1,
            dtype=np.int64,
        )

    # Turn the indices back into grid values by indexing the grid tensors.
    births = evaluate_in_grid(birth_idx, grid)
    deaths = evaluate_in_grid(death_idx, grid)

    # One chunk per summand, sized by its number of births / deaths.
    birth_chunks = births.split_with_sizes(tuple(len(s[0]) for s in summands))
    death_chunks = deaths.split_with_sizes(tuple(len(s[1]) for s in summands))

    bounds = np.concatenate([[0], mod.degree_splits(), [len(birth_chunks)]])
    bounds = torch.from_numpy(bounds)
    # Original author's note: for some reason this grouping kills the
    # gradient (suspected PyTorch bug).
    out = [
        list(zip(birth_chunks[start:stop], death_chunks[start:stop]))
        for start, stop in zip(bounds[:-1], bounds[1:])
    ]
    return out
|