multipers 2.4.0b1__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- multipers/.dylibs/libboost_timer.dylib +0 -0
- multipers/.dylibs/libc++.1.0.dylib +0 -0
- multipers/.dylibs/libtbb.12.17.dylib +0 -0
- multipers/__init__.py +33 -0
- multipers/_signed_measure_meta.py +426 -0
- multipers/_slicer_meta.py +231 -0
- multipers/array_api/__init__.py +62 -0
- multipers/array_api/numpy.py +124 -0
- multipers/array_api/torch.py +133 -0
- multipers/data/MOL2.py +458 -0
- multipers/data/UCR.py +18 -0
- multipers/data/__init__.py +1 -0
- multipers/data/graphs.py +466 -0
- multipers/data/immuno_regions.py +27 -0
- multipers/data/minimal_presentation_to_st_bf.py +0 -0
- multipers/data/pytorch2simplextree.py +91 -0
- multipers/data/shape3d.py +101 -0
- multipers/data/synthetic.py +113 -0
- multipers/distances.py +202 -0
- multipers/filtration_conversions.pxd +736 -0
- multipers/filtration_conversions.pxd.tp +226 -0
- multipers/filtrations/__init__.py +21 -0
- multipers/filtrations/density.py +529 -0
- multipers/filtrations/filtrations.py +480 -0
- multipers/filtrations.pxd +534 -0
- multipers/filtrations.pxd.tp +332 -0
- multipers/function_rips.cpython-312-darwin.so +0 -0
- multipers/function_rips.pyx +104 -0
- multipers/grids.cpython-312-darwin.so +0 -0
- multipers/grids.pyx +538 -0
- multipers/gudhi/Persistence_slices_interface.h +213 -0
- multipers/gudhi/Simplex_tree_interface.h +274 -0
- multipers/gudhi/Simplex_tree_multi_interface.h +648 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -0
- multipers/gudhi/gudhi/Debug_utils.h +52 -0
- multipers/gudhi/gudhi/Degree_rips_bifiltration.h +2307 -0
- multipers/gudhi/gudhi/Dynamic_multi_parameter_filtration.h +2524 -0
- multipers/gudhi/gudhi/Fields/Multi_field.h +453 -0
- multipers/gudhi/gudhi/Fields/Multi_field_operators.h +460 -0
- multipers/gudhi/gudhi/Fields/Multi_field_shared.h +444 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small.h +584 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +490 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +580 -0
- multipers/gudhi/gudhi/Fields/Z2_field.h +391 -0
- multipers/gudhi/gudhi/Fields/Z2_field_operators.h +389 -0
- multipers/gudhi/gudhi/Fields/Zp_field.h +493 -0
- multipers/gudhi/gudhi/Fields/Zp_field_operators.h +384 -0
- multipers/gudhi/gudhi/Fields/Zp_field_shared.h +492 -0
- multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -0
- multipers/gudhi/gudhi/Matrix.h +2200 -0
- multipers/gudhi/gudhi/Multi_filtration/Multi_parameter_generator.h +1712 -0
- multipers/gudhi/gudhi/Multi_filtration/multi_filtration_conversions.h +237 -0
- multipers/gudhi/gudhi/Multi_filtration/multi_filtration_utils.h +225 -0
- multipers/gudhi/gudhi/Multi_parameter_filtered_complex.h +485 -0
- multipers/gudhi/gudhi/Multi_parameter_filtration.h +2643 -0
- multipers/gudhi/gudhi/Multi_persistence/Box.h +233 -0
- multipers/gudhi/gudhi/Multi_persistence/Line.h +309 -0
- multipers/gudhi/gudhi/Multi_persistence/Multi_parameter_filtered_complex_pcoh_interface.h +268 -0
- multipers/gudhi/gudhi/Multi_persistence/Persistence_interface_cohomology.h +159 -0
- multipers/gudhi/gudhi/Multi_persistence/Persistence_interface_matrix.h +463 -0
- multipers/gudhi/gudhi/Multi_persistence/Point.h +853 -0
- multipers/gudhi/gudhi/Off_reader.h +173 -0
- multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +834 -0
- multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +838 -0
- multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +833 -0
- multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1367 -0
- multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1157 -0
- multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +869 -0
- multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +905 -0
- multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +122 -0
- multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +260 -0
- multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +288 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +170 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +247 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +571 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +182 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +130 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +235 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +312 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1092 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +923 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +914 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +930 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1071 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +203 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +886 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +984 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1213 -0
- multipers/gudhi/gudhi/Persistence_matrix/index_mapper.h +58 -0
- multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +227 -0
- multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +200 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +166 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +319 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +562 -0
- multipers/gudhi/gudhi/Persistence_on_a_line.h +152 -0
- multipers/gudhi/gudhi/Persistence_on_rectangle.h +617 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -0
- multipers/gudhi/gudhi/Persistent_cohomology.h +769 -0
- multipers/gudhi/gudhi/Points_off_io.h +171 -0
- multipers/gudhi/gudhi/Projective_cover_kernel.h +379 -0
- multipers/gudhi/gudhi/Simple_object_pool.h +69 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +559 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +121 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -0
- multipers/gudhi/gudhi/Simplex_tree/filtration_value_utils.h +155 -0
- multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -0
- multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -0
- multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +60 -0
- multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +105 -0
- multipers/gudhi/gudhi/Simplex_tree.h +3170 -0
- multipers/gudhi/gudhi/Slicer.h +848 -0
- multipers/gudhi/gudhi/Thread_safe_slicer.h +393 -0
- multipers/gudhi/gudhi/distance_functions.h +62 -0
- multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -0
- multipers/gudhi/gudhi/multi_simplex_tree_helpers.h +147 -0
- multipers/gudhi/gudhi/persistence_interval.h +263 -0
- multipers/gudhi/gudhi/persistence_matrix_options.h +188 -0
- multipers/gudhi/gudhi/reader_utils.h +367 -0
- multipers/gudhi/gudhi/simple_mdspan.h +484 -0
- multipers/gudhi/gudhi/slicer_helpers.h +779 -0
- multipers/gudhi/tmp_h0_pers/mma_interface_h0.h +223 -0
- multipers/gudhi/tmp_h0_pers/naive_merge_tree.h +536 -0
- multipers/io.cpython-312-darwin.so +0 -0
- multipers/io.pyx +472 -0
- multipers/ml/__init__.py +0 -0
- multipers/ml/accuracies.py +90 -0
- multipers/ml/invariants_with_persistable.py +79 -0
- multipers/ml/kernels.py +176 -0
- multipers/ml/mma.py +713 -0
- multipers/ml/one.py +472 -0
- multipers/ml/point_clouds.py +352 -0
- multipers/ml/signed_measures.py +1667 -0
- multipers/ml/sliced_wasserstein.py +461 -0
- multipers/ml/tools.py +113 -0
- multipers/mma_structures.cpython-312-darwin.so +0 -0
- multipers/mma_structures.pxd +134 -0
- multipers/mma_structures.pyx +1483 -0
- multipers/mma_structures.pyx.tp +1126 -0
- multipers/multi_parameter_rank_invariant/diff_helpers.h +85 -0
- multipers/multi_parameter_rank_invariant/euler_characteristic.h +95 -0
- multipers/multi_parameter_rank_invariant/function_rips.h +317 -0
- multipers/multi_parameter_rank_invariant/hilbert_function.h +761 -0
- multipers/multi_parameter_rank_invariant/persistence_slices.h +149 -0
- multipers/multi_parameter_rank_invariant/rank_invariant.h +350 -0
- multipers/multiparameter_edge_collapse.py +41 -0
- multipers/multiparameter_module_approximation/approximation.h +2541 -0
- multipers/multiparameter_module_approximation/debug.h +107 -0
- multipers/multiparameter_module_approximation/format_python-cpp.h +292 -0
- multipers/multiparameter_module_approximation/utilities.h +428 -0
- multipers/multiparameter_module_approximation.cpython-312-darwin.so +0 -0
- multipers/multiparameter_module_approximation.pyx +286 -0
- multipers/ops.cpython-312-darwin.so +0 -0
- multipers/ops.pyx +231 -0
- multipers/pickle.py +89 -0
- multipers/plots.py +550 -0
- multipers/point_measure.cpython-312-darwin.so +0 -0
- multipers/point_measure.pyx +409 -0
- multipers/simplex_tree_multi.cpython-312-darwin.so +0 -0
- multipers/simplex_tree_multi.pxd +136 -0
- multipers/simplex_tree_multi.pyx +11719 -0
- multipers/simplex_tree_multi.pyx.tp +2102 -0
- multipers/slicer.cpython-312-darwin.so +0 -0
- multipers/slicer.pxd +2097 -0
- multipers/slicer.pxd.tp +263 -0
- multipers/slicer.pyx +13042 -0
- multipers/slicer.pyx.tp +1259 -0
- multipers/tensor/tensor.h +672 -0
- multipers/tensor.pxd +13 -0
- multipers/test.pyx +44 -0
- multipers/tests/__init__.py +70 -0
- multipers/torch/__init__.py +1 -0
- multipers/torch/diff_grids.py +240 -0
- multipers/torch/rips_density.py +310 -0
- multipers/vector_interface.pxd +46 -0
- multipers-2.4.0b1.dist-info/METADATA +131 -0
- multipers-2.4.0b1.dist-info/RECORD +184 -0
- multipers-2.4.0b1.dist-info/WHEEL +6 -0
- multipers-2.4.0b1.dist-info/licenses/LICENSE +21 -0
- multipers-2.4.0b1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def noisy_annulus(
|
|
5
|
+
n1: int = 1000,
|
|
6
|
+
n2: int = 200,
|
|
7
|
+
r1: float = 1,
|
|
8
|
+
r2: float = 2,
|
|
9
|
+
dim: int = 2,
|
|
10
|
+
center: np.ndarray | list | None = None,
|
|
11
|
+
**kwargs,
|
|
12
|
+
) -> np.ndarray:
|
|
13
|
+
"""Generates a noisy annulus dataset.
|
|
14
|
+
|
|
15
|
+
Parameters
|
|
16
|
+
----------
|
|
17
|
+
r1 : float.
|
|
18
|
+
Lower radius of the annulus.
|
|
19
|
+
r2 : float.
|
|
20
|
+
Upper radius of the annulus.
|
|
21
|
+
n1 : int
|
|
22
|
+
Number of points in the annulus.
|
|
23
|
+
n2 : int
|
|
24
|
+
Number of points in the square.
|
|
25
|
+
dim : int
|
|
26
|
+
Dimension of the annulus.
|
|
27
|
+
center: list or array
|
|
28
|
+
center of the annulus.
|
|
29
|
+
|
|
30
|
+
Returns
|
|
31
|
+
-------
|
|
32
|
+
numpy array
|
|
33
|
+
Dataset. size : (n1+n2) x dim
|
|
34
|
+
|
|
35
|
+
"""
|
|
36
|
+
theta = np.random.normal(size=(n1, dim))
|
|
37
|
+
theta /= np.linalg.norm(theta, axis=1)[:, None]
|
|
38
|
+
rs = np.sqrt(np.random.uniform(low=r1**2, high=r2**2, size=n1))
|
|
39
|
+
annulus = rs[:, None] * theta
|
|
40
|
+
if center is not None:
|
|
41
|
+
annulus += np.array(center)
|
|
42
|
+
diffuse_noise = np.random.uniform(size=(n2, dim), low=-1.1 * r2, high=1.1 * r2)
|
|
43
|
+
if center is not None:
|
|
44
|
+
diffuse_noise += np.array(center)
|
|
45
|
+
return np.vstack([annulus, diffuse_noise])
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def three_annulus(num_pts: int = 500, num_outliers: int = 500):
    """Sample three noisy annuli in the plane on top of uniform outliers.

    Parameters
    ----------
    num_pts : int
        Total number of annulus points, split as evenly as possible across
        the three annuli (any remainder goes to the second and third).
    num_outliers : int
        Number of outliers drawn uniformly in the square ``[-2, 2] ** 2``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(num_outliers + num_pts, 2)``: the outliers first,
        then the points of each annulus in turn.
    """
    base, extra = divmod(num_pts, 3)
    sizes = (base, base + (extra > 0), base + (extra > 1))
    # (inner radius, outer radius, center) of each annulus, in sampling order.
    specs = (
        (0.6, 0.9, [1, -0.2]),
        (0.4, 0.55, [-1.2, -1]),
        (0.3, 0.4, [-0.7, 1.1]),
    )
    # Outliers are drawn before the annuli; the random stream order matters
    # for reproducibility under a fixed seed.
    pieces = [np.random.uniform(low=-2, high=2, size=(num_outliers, 2))]
    for size, (inner, outer, middle) in zip(sizes, specs):
        pieces.append(
            noisy_annulus(r1=inner, r2=outer, n1=size, n2=0, center=middle)
        )
    return np.vstack(pieces)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def orbit(n: int = 1000, r: float = 1.0, x0=None):
    """Iterate the map used for the ORBIT5K dataset on the unit square.

    Starting from ``(x, y)``, each step applies
    ``x <- (x + r * y * (1 - y)) % 1`` and then
    ``y <- (y + r * x * (1 - x)) % 1`` (using the updated ``x``).

    Parameters
    ----------
    n : int
        Number of points in the orbit, including the starting point.
        The starting point is always returned, so at least one point is
        produced.
    r : float
        Parameter of the map.
    x0 : sequence of two floats, optional
        Starting point. If omitted, or if it does not have length 2, a
        starting point is drawn uniformly in ``[0, 1)^2``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(max(n, 1), 2)``.
    """
    # ``None`` replaces the former mutable ``[]`` default; an empty or
    # otherwise non-pair ``x0`` still falls back to a random start.
    if x0 is None or len(x0) != 2:
        x, y = np.random.uniform(size=2)
    else:
        x, y = x0
    points = [[x, y]]
    for _ in range(n - 1):
        x = (x + r * y * (1 - y)) % 1
        y = (y + r * x * (1 - x)) % 1  # uses the freshly updated x
        points.append([x, y])
    return np.asarray(points, dtype=float)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def get_orbit5k(num_pts=1000, num_data=5000):
    """Generate an ORBIT5K-style dataset of orbits of the map in :func:`orbit`.

    Parameters
    ----------
    num_pts : int
        Number of points per orbit.
    num_data : int
        Number of orbits; the parameter ``r`` of each orbit is drawn uniformly
        (with replacement) from ``[2.5, 3.5, 4, 4.1, 4.3]``.

    Returns
    -------
    X : list of numpy.ndarray
        ``num_data`` orbits, each of shape ``(num_pts, 2)`` for
        ``num_pts >= 1``.
    labels : numpy.ndarray
        Integer class of each orbit: the index of its ``r`` in the list above.
    """
    rs = [2.5, 3.5, 4, 4.1, 4.3]  # must stay sorted for searchsorted below
    params = np.random.choice(rs, size=num_data, replace=True)
    X = [orbit(n=num_pts, r=r) for r in params]
    # Encode against the fixed parameter list rather than fitting an encoder
    # on the sampled values: otherwise the codes shift whenever some parameter
    # is never drawn, so the same label would mean different r across calls.
    labels = np.searchsorted(rs, params)
    return X, labels
|
multipers/distances.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import ot
|
|
3
|
+
|
|
4
|
+
from multipers.mma_structures import PyMultiDiagrams_type
|
|
5
|
+
from multipers.multiparameter_module_approximation import PyModule_type
|
|
6
|
+
from multipers.simplex_tree_multi import SimplexTreeMulti_type
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def sm2diff(sm1, sm2, threshold=None):
    """Split the signed measure ``sm1 - sm2`` into two unsigned point clouds.

    Each signed measure is a pair ``(pts, weights)``, where ``pts`` is a
    ``(num_pts, dim)`` array and ``weights`` holds integer multiplicities.
    A point of weight ``w`` is repeated ``|w|`` times. The positive part of
    ``sm1`` and the negative part of ``sm2`` form ``x``; the negative part of
    ``sm1`` and the positive part of ``sm2`` form ``y``.

    Parameters
    ----------
    sm1, sm2 : tuple
        Signed measures ``(pts, weights)``. The backend (numpy or torch) is
        taken from ``sm1[0]``; ``sm2[0]`` is expected to use the same one.
    threshold : float, optional
        If given, coordinates greater than it are clipped to it.

    Returns
    -------
    x, y
        Point clouds in the backend of ``sm1[0]``.

    Raises
    ------
    AssertionError
        If ``sm1[0]`` is neither a numpy array nor a torch tensor.
    """
    pts = sm1[0]
    dtype = pts.dtype
    if isinstance(pts, np.ndarray):

        def backend_concatenate(a, b):
            return np.concatenate([a, b], axis=0, dtype=dtype)

        def backend_tensor(x):
            return np.asarray(x, dtype=int)

    else:
        import torch

        assert isinstance(pts, torch.Tensor), "Invalid backend. Numpy or torch."

        def backend_concatenate(a, b):
            return torch.concatenate([a, b], dim=0)

        def backend_tensor(x):
            # int64 indices: int32 index tensors are rejected by older torch.
            return torch.tensor(x, dtype=torch.long)

    def repeated_indices(weights, sign):
        # Index i appears max(sign * w_i, 0) times (range() of a
        # non-positive count is empty).
        ## TODO: optimize this
        return backend_tensor(
            [i for i, w in enumerate(weights) for _ in range(sign * w)]
        )

    pts1, w1 = sm1
    pts2, w2 = sm2
    pos_indices1 = repeated_indices(w1, 1)
    pos_indices2 = repeated_indices(w2, 1)
    neg_indices1 = repeated_indices(w1, -1)
    neg_indices2 = repeated_indices(w2, -1)
    x = backend_concatenate(pts1[pos_indices1], pts2[neg_indices2])
    y = backend_concatenate(pts1[neg_indices1], pts2[pos_indices2])
    if threshold is not None:
        x[x > threshold] = threshold
        y[y > threshold] = threshold
    return x, y
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def sm_distance(
    sm1: tuple,
    sm2: tuple,
    reg: float = 0,
    reg_m: float = 0,
    numItermax: int = 10000,
    p: float = 1,
    threshold=None,
):
    """
    Computes an optimal-transport (Wasserstein-type) distance between two
    signed measures, each of the form
     - (pts,weights)
    with
     - pts : (num_pts, dim) float array
     - weights : (num_pts,) int array

    Both measures are first turned into two unsigned point clouds with
    ``sm2diff`` (positive part of sm1 + negative part of sm2 versus negative
    part of sm1 + positive part of sm2). The ground cost between points is
    the squared Euclidean distance.

    Regularisation:
     - reg == 0: exact optimal transport (``ot.lp.emd2``); the result is
       multiplied by the number of points of the first cloud, whatever reg_m is.
     - reg != 0 and reg_m == 0: sinkhorn.
     - reg != 0 and reg_m != 0: sinkhorn unbalanced.

    NOTE(review): ``p`` is forwarded to ``ot.dist``, but POT only uses it for
    Minkowski-type metrics, so with ``metric="sqeuclidean"`` it has no effect.
    The sinkhorn branches are not rescaled by the number of points, unlike
    the exact one.

    If both point clouds are empty (both measures carry no mass), returns 0
    in the backend of the inputs.
    """
    x, y = sm2diff(sm1, sm2, threshold=threshold)
    if len(x) == 0 and len(y) == 0:
        # Nothing to transport. POT's uniform-weight default presumably
        # divides by the number of points, so empty inputs cannot be passed.
        return x.sum()
    # Only euclidean + sqeuclidean are implemented in POT for the torch
    # backend for the moment. TODO: check later.
    loss = ot.dist(x, y, metric="sqeuclidean", p=p)
    if isinstance(x, np.ndarray):
        empty_tensor = np.array([])  # empty weights: POT uses uniform weights
    else:
        import torch

        assert isinstance(x, torch.Tensor), "Unimplemented backend."
        empty_tensor = torch.tensor([])  # empty weights: uniform weights

    if reg == 0:
        return ot.lp.emd2(empty_tensor, empty_tensor, M=loss) * len(x)
    if reg_m == 0:
        return ot.sinkhorn2(
            a=empty_tensor, b=empty_tensor, M=loss, reg=reg, numItermax=numItermax
        )
    return ot.sinkhorn_unbalanced2(
        a=empty_tensor,
        b=empty_tensor,
        M=loss,
        reg=reg,
        reg_m=reg_m,
        numItermax=numItermax,
    )
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def estimate_matching(b1: PyMultiDiagrams_type, b2: PyMultiDiagrams_type):
    """Estimate a matching distance from two families of barcodes.

    For every index ``i``, the diagrams ``b1[i]`` and ``b2[i]`` are compared
    with gudhi's bottleneck distance; the maximum over ``i`` is returned.

    Parameters
    ----------
    b1, b2 : PyMultiDiagrams_type
        Indexable families of diagrams with the same length. Each entry must
        provide ``get_points()``. Presumably ``b1[i]`` and ``b2[i]`` are
        computed along the same line (TODO: confirm with callers).

    Returns
    -------
    float
        The maximal bottleneck distance, or 0.0 when both families are empty.

    Raises
    ------
    AssertionError
        If ``b1`` and ``b2`` have different lengths.
    """
    assert len(b1) == len(b2), "b1 and b2 must contain the same number of diagrams"
    from gudhi.bottleneck import bottleneck_distance

    def get_bc(b: PyMultiDiagrams_type, i: int) -> np.ndarray:
        temp = b[i].get_points()
        # Keep index 0 of the last axis of the points array. GUDHI FIX: an
        # empty diagram is passed as a (0, 2) array.
        return np.array(temp)[:, :, 0] if len(temp) > 0 else np.empty((0, 2))

    return max(
        (bottleneck_distance(get_bc(b1, i), get_bc(b2, i)) for i in range(len(b1))),
        default=0.0,
    )
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
# Functions to estimate precision
|
|
122
|
+
def estimate_error(
    st: SimplexTreeMulti_type,
    module: PyModule_type,
    degree: int,
    nlines: int = 100,
    verbose: bool = False,
):
    """
    Given an MMA SimplexTree and PyModule, estimates the bottleneck distance using barcodes given by gudhi.

    ``nlines`` basepoints are drawn uniformly in ``module.get_box()``. For
    each basepoint, the barcode of ``module`` is compared, via the
    bottleneck distance, with the barcode gudhi computes after projecting
    ``st`` on the corresponding line. Only the first coordinate of each bar
    is used.

    Parameters
    ----------
    - st:SimplexTree
        The simplextree representing the n-filtered complex. Used to define the gudhi simplextrees on different lines.
    - module:PyModule
        The module on which to estimate approximation error, w.r.t. the original simplextree st.
    - degree:int
        The homology degree to consider
    - nlines:int
        Number of random basepoints (lines) to sample.
    - verbose:bool
        If True, print the mma barcode timing and show progress bars.

    Returns
    -------
    - float:The estimation of the matching distance, i.e., the maximum of the sampled bottleneck distances (0.0 if nlines == 0).

    """
    from time import perf_counter

    from gudhi.bottleneck import bottleneck_distance
    from tqdm import tqdm

    parameter = 0  # coordinate of the bars that is compared

    def _get_bc_ST(st, basepoint, degree: int):
        """
        Slices an mma simplextree to a gudhi simplextree, and compute its persistence on the diagonal line crossing the given basepoint.
        """
        gst = st.project_on_line(basepoint=basepoint, parameter=parameter)
        gst.compute_persistence()
        return gst.persistence_intervals_in_dimension(degree)

    def clean(dgm):
        # Keep (birth, death) of the chosen coordinate; drop empty bars and
        # bars born at infinity.
        bars = [
            [birth[parameter], death[parameter]]
            for birth, death in dgm
            if len(birth) > 0 and birth[parameter] != np.inf
        ]
        # Always hand gudhi an (n, 2) array, including when there is no bar
        # (np.array([]) would have shape (0,)).
        return np.array(bars) if bars else np.empty((0, 2))

    low, high = module.get_box()
    nfiltration = len(low)
    basepoints = np.random.uniform(low=low, high=high, size=(nlines, nfiltration))

    # Barcodes from the module.
    if verbose:
        print("Computing mma barcodes...", flush=True, end="")
    start = perf_counter()
    bcs_from_mod = module.barcodes(degree=degree, basepoints=basepoints).get_points()
    if verbose:
        print(f"Done. {perf_counter() - start}s.")
    bcs_from_mod = [clean(dgm) for dgm in bcs_from_mod]

    # Barcodes computed by gudhi on the same lines.
    bcs_from_gudhi = [
        _get_bc_ST(st, basepoint=basepoint, degree=degree)
        for basepoint in tqdm(
            basepoints, disable=not verbose, desc="Computing gudhi barcodes"
        )
    ]
    return max(
        (
            bottleneck_distance(a, b)
            for a, b in tqdm(
                zip(bcs_from_mod, bcs_from_gudhi),
                disable=not verbose,
                total=nlines,
                desc="Computing bottleneck distances",
            )
        ),
        default=0.0,
    )
|