multipers 2.3.3b6__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of multipers might be problematic. Click here for more details.
- multipers/.dylibs/libc++.1.0.dylib +0 -0
- multipers/.dylibs/libtbb.12.16.dylib +0 -0
- multipers/__init__.py +33 -0
- multipers/_signed_measure_meta.py +453 -0
- multipers/_slicer_meta.py +211 -0
- multipers/array_api/__init__.py +45 -0
- multipers/array_api/numpy.py +41 -0
- multipers/array_api/torch.py +58 -0
- multipers/data/MOL2.py +458 -0
- multipers/data/UCR.py +18 -0
- multipers/data/__init__.py +1 -0
- multipers/data/graphs.py +466 -0
- multipers/data/immuno_regions.py +27 -0
- multipers/data/minimal_presentation_to_st_bf.py +0 -0
- multipers/data/pytorch2simplextree.py +91 -0
- multipers/data/shape3d.py +101 -0
- multipers/data/synthetic.py +113 -0
- multipers/distances.py +202 -0
- multipers/filtration_conversions.pxd +229 -0
- multipers/filtration_conversions.pxd.tp +84 -0
- multipers/filtrations/__init__.py +18 -0
- multipers/filtrations/density.py +574 -0
- multipers/filtrations/filtrations.py +361 -0
- multipers/filtrations.pxd +224 -0
- multipers/function_rips.cpython-311-darwin.so +0 -0
- multipers/function_rips.pyx +105 -0
- multipers/grids.cpython-311-darwin.so +0 -0
- multipers/grids.pyx +433 -0
- multipers/gudhi/Persistence_slices_interface.h +132 -0
- multipers/gudhi/Simplex_tree_interface.h +239 -0
- multipers/gudhi/Simplex_tree_multi_interface.h +551 -0
- multipers/gudhi/cubical_to_boundary.h +59 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -0
- multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -0
- multipers/gudhi/gudhi/Debug_utils.h +45 -0
- multipers/gudhi/gudhi/Fields/Multi_field.h +484 -0
- multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -0
- multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -0
- multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -0
- multipers/gudhi/gudhi/Fields/Z2_field.h +355 -0
- multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -0
- multipers/gudhi/gudhi/Fields/Zp_field.h +420 -0
- multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -0
- multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -0
- multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -0
- multipers/gudhi/gudhi/Matrix.h +2107 -0
- multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -0
- multipers/gudhi/gudhi/Multi_persistence/Box.h +174 -0
- multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -0
- multipers/gudhi/gudhi/Off_reader.h +173 -0
- multipers/gudhi/gudhi/One_critical_filtration.h +1441 -0
- multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -0
- multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -0
- multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -0
- multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -0
- multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -0
- multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -0
- multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -0
- multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -0
- multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -0
- multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -0
- multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -0
- multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -0
- multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -0
- multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -0
- multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -0
- multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -0
- multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -0
- multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -0
- multipers/gudhi/gudhi/Persistent_cohomology.h +745 -0
- multipers/gudhi/gudhi/Points_off_io.h +171 -0
- multipers/gudhi/gudhi/Simple_object_pool.h +69 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -0
- multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -0
- multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -0
- multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -0
- multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -0
- multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -0
- multipers/gudhi/gudhi/Simplex_tree.h +2794 -0
- multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -0
- multipers/gudhi/gudhi/distance_functions.h +62 -0
- multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -0
- multipers/gudhi/gudhi/persistence_interval.h +253 -0
- multipers/gudhi/gudhi/persistence_matrix_options.h +170 -0
- multipers/gudhi/gudhi/reader_utils.h +367 -0
- multipers/gudhi/mma_interface_coh.h +256 -0
- multipers/gudhi/mma_interface_h0.h +223 -0
- multipers/gudhi/mma_interface_matrix.h +293 -0
- multipers/gudhi/naive_merge_tree.h +536 -0
- multipers/gudhi/scc_io.h +310 -0
- multipers/gudhi/truc.h +1403 -0
- multipers/io.cpython-311-darwin.so +0 -0
- multipers/io.pyx +644 -0
- multipers/ml/__init__.py +0 -0
- multipers/ml/accuracies.py +90 -0
- multipers/ml/invariants_with_persistable.py +79 -0
- multipers/ml/kernels.py +176 -0
- multipers/ml/mma.py +713 -0
- multipers/ml/one.py +472 -0
- multipers/ml/point_clouds.py +352 -0
- multipers/ml/signed_measures.py +1589 -0
- multipers/ml/sliced_wasserstein.py +461 -0
- multipers/ml/tools.py +113 -0
- multipers/mma_structures.cpython-311-darwin.so +0 -0
- multipers/mma_structures.pxd +128 -0
- multipers/mma_structures.pyx +2786 -0
- multipers/mma_structures.pyx.tp +1094 -0
- multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -0
- multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -0
- multipers/multi_parameter_rank_invariant/function_rips.h +322 -0
- multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -0
- multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -0
- multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -0
- multipers/multiparameter_edge_collapse.py +41 -0
- multipers/multiparameter_module_approximation/approximation.h +2330 -0
- multipers/multiparameter_module_approximation/combinatory.h +129 -0
- multipers/multiparameter_module_approximation/debug.h +107 -0
- multipers/multiparameter_module_approximation/euler_curves.h +0 -0
- multipers/multiparameter_module_approximation/format_python-cpp.h +286 -0
- multipers/multiparameter_module_approximation/heap_column.h +238 -0
- multipers/multiparameter_module_approximation/images.h +79 -0
- multipers/multiparameter_module_approximation/list_column.h +174 -0
- multipers/multiparameter_module_approximation/list_column_2.h +232 -0
- multipers/multiparameter_module_approximation/ru_matrix.h +347 -0
- multipers/multiparameter_module_approximation/set_column.h +135 -0
- multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -0
- multipers/multiparameter_module_approximation/unordered_set_column.h +166 -0
- multipers/multiparameter_module_approximation/utilities.h +403 -0
- multipers/multiparameter_module_approximation/vector_column.h +223 -0
- multipers/multiparameter_module_approximation/vector_matrix.h +331 -0
- multipers/multiparameter_module_approximation/vineyards.h +464 -0
- multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -0
- multipers/multiparameter_module_approximation.cpython-311-darwin.so +0 -0
- multipers/multiparameter_module_approximation.pyx +235 -0
- multipers/pickle.py +90 -0
- multipers/plots.py +456 -0
- multipers/point_measure.cpython-311-darwin.so +0 -0
- multipers/point_measure.pyx +395 -0
- multipers/simplex_tree_multi.cpython-311-darwin.so +0 -0
- multipers/simplex_tree_multi.pxd +134 -0
- multipers/simplex_tree_multi.pyx +10840 -0
- multipers/simplex_tree_multi.pyx.tp +2009 -0
- multipers/slicer.cpython-311-darwin.so +0 -0
- multipers/slicer.pxd +3034 -0
- multipers/slicer.pxd.tp +234 -0
- multipers/slicer.pyx +20481 -0
- multipers/slicer.pyx.tp +1088 -0
- multipers/tensor/tensor.h +672 -0
- multipers/tensor.pxd +13 -0
- multipers/test.pyx +44 -0
- multipers/tests/__init__.py +62 -0
- multipers/torch/__init__.py +1 -0
- multipers/torch/diff_grids.py +240 -0
- multipers/torch/rips_density.py +310 -0
- multipers-2.3.3b6.dist-info/METADATA +128 -0
- multipers-2.3.3b6.dist-info/RECORD +183 -0
- multipers-2.3.3b6.dist-info/WHEEL +6 -0
- multipers-2.3.3b6.dist-info/licenses/LICENSE +21 -0
- multipers-2.3.3b6.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
from warnings import warn
|
|
3
|
+
import numpy as np
|
|
4
|
+
from tqdm import tqdm
|
|
5
|
+
from os.path import exists
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def accuracy_to_csv(
    X,
    Y,
    cl,
    k: float = 10,
    dataset: str = "",
    shuffle=True,
    verbose: bool = True,
    **more_columns,
):
    """Score a classifier and append the result as one row of a CSV file.

    Two evaluation modes are selected by ``k``:

    * ``k > 1``: ``k`` is truncated to an int and a stratified k-fold is run;
      ``cl`` is refit on every fold and one score per fold is recorded.
    * ``0 < k <= 1``: a single train/test split is done with ``test_size=k``
      and a single score is recorded.

    The row ``[dataset, k, mean, std]`` (mean/std of the recorded scores,
    rounded to 3 decimals), extended with ``more_columns``, is appended to
    ``result_{dataset}.csv`` (with ``/`` replaced by ``_`` and ``.off``
    removed from the file name) in the current working directory. The file is
    created if it does not exist.

    Parameters
    ----------
    X, Y
        Samples and labels; both are indexed with integer positions.
    cl
        Object exposing ``fit(x, y)`` and ``score(x, y)``. If it also exposes
        ``best_params_`` after fitting, that value is printed.
    k
        Number of folds (``> 1``) or test-set size (``0 < k <= 1``).
    dataset
        Name used in printed messages, in the CSV row and in the file name.
    shuffle
        Forwarded to the k-fold splitter / ``train_test_split``.
    verbose
        Enables the per-step prints.
    **more_columns
        Extra ``column=value`` pairs added to the CSV row. Keys colliding
        with ``dataset``, ``cv``, ``mean`` or ``std`` are skipped with a
        warning.

    Returns
    -------
    None
    """
    assert k > 0, "k is either the number of kfold > 1 or the test size > 0."
    if k > 1:
        k = int(k)
        from sklearn.model_selection import StratifiedKFold as KFold

        kfold = KFold(k, shuffle=shuffle).split(X, Y)
        accuracies = np.zeros(k)
        for i, (train_idx, test_idx) in enumerate(
            tqdm(kfold, total=k, desc="Computing kfold")
        ):
            xtrain = [X[j] for j in train_idx]
            ytrain = [Y[j] for j in train_idx]
            cl.fit(xtrain, ytrain)
            xtest = [X[j] for j in test_idx]
            ytest = [Y[j] for j in test_idx]
            accuracies[i] = cl.score(xtest, ytest)
            # NOTE(review): the nesting of the best-params report under
            # `verbose` was reconstructed from a flattened listing — confirm.
            if verbose:
                print(f"step {i+1}, {dataset} : {accuracies[i]}", flush=True)
                try:
                    print("Best classification parameters : ", cl.best_params_)
                except AttributeError:
                    # `cl` has no `best_params_`: nothing to report. Only
                    # this case is silenced; Ctrl-C and real errors propagate.
                    pass

        print(
            f"""Accuracy {dataset} : {np.mean(accuracies).round(decimals=3)}±{np.std(accuracies).round(decimals=3)}"""
        )
    elif k > 0:
        from sklearn.model_selection import train_test_split

        print("Computing accuracy, with train test split", flush=True)
        xtrain, xtest, ytrain, ytest = train_test_split(
            X, Y, shuffle=shuffle, test_size=k
        )
        print("Fitting...", end="", flush=True)
        cl.fit(xtrain, ytrain)
        print("Computing score...", end="", flush=True)
        # A single score here; np.mean / np.std below accept a scalar.
        accuracies = cl.score(xtest, ytest)
        try:
            print("Best classification parameters : ", cl.best_params_)
        except AttributeError:
            # `cl` has no `best_params_`: nothing to report.
            pass
        print("Done.")
        if verbose:
            print(f"Accuracy {dataset} : {accuracies} ")

    file_path: str = f"result_{dataset}.csv".replace("/", "_").replace(".off", "")
    columns: list[str] = ["dataset", "cv", "mean", "std"]
    if exists(file_path):
        df: pd.DataFrame = pd.read_csv(file_path)
    else:
        df: pd.DataFrame = pd.DataFrame(columns=columns)

    # Keep only the extra columns that do not clash with the fixed ones.
    more_names = []
    more_values = []
    for key, value in more_columns.items():
        if key not in columns:
            more_names.append(key)
            more_values.append(value)
        else:
            warn(f"Duplicate key {key} ! with value {value}")

    new_line: pd.DataFrame = pd.DataFrame(
        [
            [
                dataset,
                k,
                np.mean(accuracies).round(decimals=3),
                np.std(accuracies).round(decimals=3),
            ]
            + more_values
        ],
        columns=columns + more_names,
    )
    print(new_line)
    df = pd.concat([df, new_line])
    df.to_csv(file_path, index=False)
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
import persistable
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
# requires installing ripser (pip install ripser) as well as persistable from the higher-homology branch,
|
|
5
|
+
# which can be done as follows:
|
|
6
|
+
# pip install git+https://github.com/LuisScoccola/persistable.git@higher-homology
|
|
7
|
+
# NOTE: only accepts as input a distance matrix
|
|
8
|
+
def hf_degree_rips(
    distance_matrix,
    min_rips_value,
    max_rips_value,
    max_normalized_degree,
    min_normalized_degree,
    grid_granularity,
    max_homological_dimension,
    subsample_size=None,
):
    """Run ``persistable``'s Hilbert-function computation on a distance matrix.

    A ``persistable.Persistable`` object is built from ``distance_matrix``
    with ``metric="precomputed"``; when ``subsample_size`` is given it is
    forwarded as ``subsample``. The grid arguments are then passed
    positionally, in the order they appear in this signature (note that
    ``max_normalized_degree`` comes before ``min_normalized_degree``), to the
    private ``_hilbert_function`` method, with
    ``homological_dimension=max_homological_dimension``.

    Returns
    -------
    tuple
        The four values returned by ``_hilbert_function``, unpacked as
        ``(rips_values, normalized_degree_values, hilbert_functions,
        minimal_hilbert_decompositions)``.
    """
    # Identity test: `== None` would dispatch to a user-defined / elementwise
    # `__eq__` (e.g. on a numpy value) instead of checking for "not provided".
    if subsample_size is None:
        p = persistable.Persistable(distance_matrix, metric="precomputed")
    else:
        p = persistable.Persistable(
            distance_matrix, metric="precomputed", subsample=subsample_size
        )

    (
        rips_values,
        normalized_degree_values,
        hilbert_functions,
        minimal_hilbert_decompositions,
    ) = p._hilbert_function(
        min_rips_value,
        max_rips_value,
        max_normalized_degree,
        min_normalized_degree,
        grid_granularity,
        homological_dimension=max_homological_dimension,
    )

    return (
        rips_values,
        normalized_degree_values,
        hilbert_functions,
        minimal_hilbert_decompositions,
    )
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def hf_h0_degree_rips(
    point_cloud,
    min_rips_value,
    max_rips_value,
    max_normalized_degree,
    min_normalized_degree,
    grid_granularity,
):
    """Run ``persistable``'s Hilbert-function computation on a point cloud.

    A ``persistable.Persistable`` object is built from ``point_cloud`` with
    ``n_neighbors="all"`` and its private ``_hilbert_function`` method is
    called with the grid arguments, positionally and in signature order.

    Returns
    -------
    tuple
        ``(rips_values, normalized_degree_values, hf, decomposition)`` where
        the last two are element ``[0]`` of the third and fourth outputs of
        ``_hilbert_function``.
    """
    bifiltration = persistable.Persistable(point_cloud, n_neighbors="all")

    grid_args = (
        min_rips_value,
        max_rips_value,
        max_normalized_degree,
        min_normalized_degree,
        grid_granularity,
    )
    rips_grid, degree_grid, hf_all, decompositions_all = (
        bifiltration._hilbert_function(*grid_args)
    )

    # Only the first entry of each per-dimension output is kept.
    return rips_grid, degree_grid, hf_all[0], decompositions_all[0]
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def ri_h0_degree_rips(
    point_cloud,
    min_rips_value,
    max_rips_value,
    max_normalized_degree,
    min_normalized_degree,
    grid_granularity,
):
    """Run ``persistable``'s rank-invariant computation on a point cloud.

    A ``persistable.Persistable`` object is built from ``point_cloud`` with
    ``n_neighbors="all"`` and its private ``_rank_invariant`` method is
    called with the grid arguments, positionally and in signature order.

    Returns
    -------
    tuple
        The first three of the five values returned by ``_rank_invariant``:
        ``(rips_values, normalized_degree_values, rank_invariant)``.
    """
    bifiltration = persistable.Persistable(point_cloud, n_neighbors="all")

    grid_args = (
        min_rips_value,
        max_rips_value,
        max_normalized_degree,
        min_normalized_degree,
        grid_granularity,
    )
    # `_rank_invariant` yields five values; the last two are not used here.
    rips_grid, degree_grid, ranks, _unused_a, _unused_b = (
        bifiltration._rank_invariant(*grid_args)
    )

    return rips_grid, degree_grid, ranks
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
|
multipers/ml/kernels.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
from sklearn.base import BaseEstimator, TransformerMixin, clone
|
|
2
|
+
import numpy as np
|
|
3
|
+
from typing import Iterable
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# To do k folds with a distance matrix, we need to slice it into list of distances.
|
|
7
|
+
# k-fold usually shuffles the lists, so we need to add an identifier to each entry.
|
|
8
|
+
#
|
|
9
|
+
class DistanceMatrix2DistanceList(BaseEstimator, TransformerMixin):
    """Stateless transformer that tags every row of a matrix with its index.

    For a 2-D input of shape ``(n, m)``, ``transform`` returns an array of
    shape ``(n, m + 1)`` whose first column holds the row number ``0..n-1``
    and whose remaining columns are the original row.
    """

    def __init__(self) -> None:
        super().__init__()

    def fit(self, X, y=None):
        """Nothing is learned; return ``self``."""
        return self

    def transform(self, X):
        """Return ``X`` with the row index prepended to each row."""
        matrix = np.asarray(X)
        assert matrix.ndim == 2  # Its a matrix
        tagged_rows = []
        for row_number, row in enumerate(matrix):
            tagged_rows.append([row_number, *row])
        return np.asarray(tagged_rows)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class DistanceList2DistanceMatrix(BaseEstimator, TransformerMixin):
    """Stateless transformer selecting columns from index-tagged rows.

    The input is a 2-D array whose first column holds an integer index for
    each row. ``transform`` reads those indices and returns, for every row,
    the entries found at columns ``index + 1`` (the ``+ 1`` skips the index
    column itself).
    """

    def __init__(self) -> None:
        super().__init__()

    def fit(self, X, y=None):
        """Nothing is learned; return ``self``."""
        return self

    def transform(self, X):
        """Return the columns of ``X`` addressed by its own first column."""
        # Coerce first so nested lists are accepted like in the sibling
        # transformers; tuple indexing below needs an ndarray.
        X = np.asarray(X)
        index_list = (
            np.asarray(X[:, 0], dtype=int) + 1
        )  # shift of 1, because the first index is for indexing the pts
        return X[:, index_list]  # The distance matrix of the index_list
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class DistanceMatrices2DistancesList(BaseEstimator, TransformerMixin):
    """
    Input (degree) x (distance matrix) or (axis) x (degree) x (distance matrix D)
    Output _ (D1) x opt (axis) x (degree) x (D2, , with indices first)

    Every innermost 2-D matrix is passed through
    ``DistanceMatrix2DistanceList`` (row index prepended to each row), then
    the row axis is moved to the front of the result.
    """

    def __init__(self) -> None:
        super().__init__()
        # True when the fitted data was 4-D, False when 3-D, None before fit.
        self._axes = None

    def fit(self, X, y=None):
        """Record whether ``X`` is 4-D (has a leading axis dimension) or 3-D."""
        ndim = np.asarray(X).ndim
        self._axes = ndim == 4
        assert (
            self._axes or ndim == 3
        ), " Bad input shape. Input is either (degree) x (distance matrix) or (axis) x (degree) x (distance matrix) "

        return self

    def transform(self, X):
        """Tag rows with their index and bring the row axis first."""
        X = np.asarray(X)
        assert (X.ndim == 3 and not self._axes) or (
            X.ndim == 4 and self._axes
        ), f"X shape ({X.shape}) is not valid"

        if not self._axes:
            tagged = np.array(
                [DistanceMatrix2DistanceList().fit_transform(M) for M in X]
            )  # indices are at [:,0,Any_coord]
            # (degree, D1, D2 + 1) -> (D1, degree, D2 + 1)
            return np.moveaxis(tagged, [1, 0, 2], [0, 1, 2])

        per_axis = []
        for matrices_in_axes in X:
            per_axis.append(
                [
                    DistanceMatrix2DistanceList().fit_transform(M)
                    for M in matrices_in_axes
                ]
            )
        tagged = np.asarray(per_axis)
        # (axis, degree, D1, D2 + 1) -> (D1, axis, degree, D2 + 1)
        return np.moveaxis(tagged, [2, 0, 1, 3], [0, 1, 2, 3])

    def predict(self, X):
        """Alias of :meth:`transform`."""
        return self.transform(X)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class DistancesLists2DistanceMatrices(BaseEstimator, TransformerMixin):
    """
    Input (D1) x opt (axis) x (degree) x (D2 with indices first)
    Output opt (axis) x (degree) x (distance matrix (D1,D2))

    ``fit`` memorises the integer index stored at the start of every sample;
    ``transform`` keeps, along the last axis, only the columns addressed by
    those memorised indices and moves the sample axis back behind the
    (axis,) degree dimensions.
    """

    def __init__(self) -> None:
        super().__init__()
        # Integer indices read from the data seen in `fit`; None before fit.
        self.train_indices = None
        # True when the fitted data was 4-D, False when 3-D, None before fit.
        self._axes = None

    def fit(self, X: np.ndarray, y=None):
        """Store the leading index of each sample of ``X``."""
        X = np.asarray(X)
        assert X.ndim in [3, 4]
        self._axes = X.ndim == 4
        # The index sits at the very first position of every sample.
        leading = X[:, 0, 0, 0] if self._axes else X[:, 0, 0]
        self.train_indices = np.asarray(leading, dtype=int)
        return self

    def transform(self, X):
        """Select the fitted columns and move the sample axis inwards."""
        X = np.asarray(X)
        assert X.ndim in [3, 4]
        # First coord of X is test indices by design, train indices have to
        # be selected in the last coord.
        # shift of 1, because the first index is for indexing the pts
        kept_columns = self.train_indices + 1

        if not self._axes:
            selected = X[:, :, kept_columns]  # we only keep the good indices
            # (D1, degree, D2) -> (degree, D1, D2): degree axis back first
            return np.moveaxis(selected, [0, 1, 2], [1, 0, 2])

        selected = X[:, :, :, kept_columns]
        # (D1, axis, degree, D2) -> (axis, degree, D1, D2)
        return np.moveaxis(selected, [0, 1, 2, 3], [2, 0, 1, 3])
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class DistanceMatrix2Kernel(BaseEstimator, TransformerMixin):
    """
    Input : (degree) x (distance matrix) or (axis) x (degree) x (distance matrix) in the second case, axis HAS to be specified (meant for cross validation)
    Output : kernel of the same shape of distance matrix

    Each per-degree distance matrix ``D`` is mapped to
    ``exp(-D / (2 * sigma**2)) * weight`` and the results are averaged over
    the degree dimension.

    Parameters
    ----------
    sigma
        Bandwidth. ``transform`` evaluates ``self.sigma**2`` directly, so the
        value must support ``**`` (a plain Python list would raise).
    axis
        Index of the leading axis to use when the data is 4-D. May be left to
        ``None`` for 3-D data, or for 4-D data whose first dimension is 1.
    weights
        One weight per degree, or a single number used for every degree.
    """

    def __init__(
        self,
        sigma: float | Iterable[float] = 1,
        axis: int | None = None,
        weights: Iterable[float] | float = 1,
    ) -> None:
        super().__init__()
        self.sigma = sigma
        self.axis = axis
        self.weights = weights
        # self._num_axes=None
        # Number of degrees seen in `fit`; None before fit.
        self._num_degrees = None

    def fit(self, X, y=None):
        """Infer the number of degrees and expand scalar weights.

        Empty input leaves the estimator untouched. Otherwise ``X`` must be
        3-D or 4-D; the number of degrees is read from it and a scalar
        ``weights`` is expanded to one value per degree.
        """
        if len(X) == 0:
            return self
        # Coerce so nested lists are accepted, as in the sibling transformers
        # (`.ndim` / `.shape` below need an ndarray).
        X = np.asarray(X)
        assert X.ndim in [3, 4], "Bad input."
        if self.axis is None:
            assert X.ndim == 3 or X.shape[0] == 1, "Set an axis for data with axis !"
            if X.shape[0] == 1 and X.ndim == 4:
                # A single leading axis is unambiguous: select it.
                self.axis = 0
                self._num_degrees = len(X[0])
            else:
                self._num_degrees = len(X)
        else:
            assert X.ndim == 4, "Cannot choose axis from data with no axis !"
            self._num_degrees = len(X[self.axis])
        # NOTE(review): `axis` and `weights` are constructor parameters and
        # are overwritten here, so a refit on data with a different number of
        # degrees trips the assertion below. Kept as-is for compatibility.
        if isinstance(self.weights, (float, int)):
            self.weights = [self.weights] * self._num_degrees
        assert (
            len(self.weights) == self._num_degrees
        ), f"Number of weights ({len(self.weights)}) has to be the same as the number of degrees ({self._num_degrees})"
        return self

    def transform(self, X) -> np.ndarray:
        """Return the weighted mean, over degrees, of the per-degree kernels."""
        # Coerce so that negating each sub-matrix works for nested lists too.
        X = np.asarray(X)
        if self.axis is not None:
            X = X[self.axis]
        # TODO : pykeops, and full pipeline w/ pykeops
        kernels = np.asarray(
            [
                np.exp(-distance_matrix / (2 * self.sigma**2)) * weight
                for distance_matrix, weight in zip(X, self.weights)
            ]
        )
        out = np.mean(kernels, axis=0)

        return out
|