multipers 2.3.3b6__cp312-cp312-manylinux_2_39_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic; see the package registry page for more details.

Files changed (182)
  1. multipers/__init__.py +33 -0
  2. multipers/_signed_measure_meta.py +453 -0
  3. multipers/_slicer_meta.py +211 -0
  4. multipers/array_api/__init__.py +45 -0
  5. multipers/array_api/numpy.py +41 -0
  6. multipers/array_api/torch.py +58 -0
  7. multipers/data/MOL2.py +458 -0
  8. multipers/data/UCR.py +18 -0
  9. multipers/data/__init__.py +1 -0
  10. multipers/data/graphs.py +466 -0
  11. multipers/data/immuno_regions.py +27 -0
  12. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  13. multipers/data/pytorch2simplextree.py +91 -0
  14. multipers/data/shape3d.py +101 -0
  15. multipers/data/synthetic.py +113 -0
  16. multipers/distances.py +202 -0
  17. multipers/filtration_conversions.pxd +229 -0
  18. multipers/filtration_conversions.pxd.tp +84 -0
  19. multipers/filtrations/__init__.py +18 -0
  20. multipers/filtrations/density.py +574 -0
  21. multipers/filtrations/filtrations.py +361 -0
  22. multipers/filtrations.pxd +224 -0
  23. multipers/function_rips.cpython-312-x86_64-linux-gnu.so +0 -0
  24. multipers/function_rips.pyx +105 -0
  25. multipers/grids.cpython-312-x86_64-linux-gnu.so +0 -0
  26. multipers/grids.pyx +433 -0
  27. multipers/gudhi/Persistence_slices_interface.h +132 -0
  28. multipers/gudhi/Simplex_tree_interface.h +239 -0
  29. multipers/gudhi/Simplex_tree_multi_interface.h +551 -0
  30. multipers/gudhi/cubical_to_boundary.h +59 -0
  31. multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -0
  32. multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -0
  33. multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -0
  34. multipers/gudhi/gudhi/Debug_utils.h +45 -0
  35. multipers/gudhi/gudhi/Fields/Multi_field.h +484 -0
  36. multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -0
  37. multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -0
  38. multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -0
  39. multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -0
  40. multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -0
  41. multipers/gudhi/gudhi/Fields/Z2_field.h +355 -0
  42. multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -0
  43. multipers/gudhi/gudhi/Fields/Zp_field.h +420 -0
  44. multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -0
  45. multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -0
  46. multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -0
  47. multipers/gudhi/gudhi/Matrix.h +2107 -0
  48. multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -0
  49. multipers/gudhi/gudhi/Multi_persistence/Box.h +174 -0
  50. multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -0
  51. multipers/gudhi/gudhi/Off_reader.h +173 -0
  52. multipers/gudhi/gudhi/One_critical_filtration.h +1441 -0
  53. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -0
  54. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -0
  55. multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -0
  56. multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -0
  57. multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -0
  58. multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -0
  59. multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -0
  60. multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -0
  61. multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -0
  62. multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -0
  63. multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -0
  64. multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -0
  65. multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -0
  66. multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -0
  67. multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -0
  68. multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -0
  69. multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -0
  70. multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -0
  71. multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -0
  72. multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -0
  73. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -0
  74. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -0
  75. multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -0
  76. multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -0
  77. multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -0
  78. multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -0
  79. multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -0
  80. multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -0
  81. multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -0
  82. multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -0
  83. multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -0
  84. multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -0
  85. multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -0
  86. multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -0
  87. multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -0
  88. multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -0
  89. multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -0
  90. multipers/gudhi/gudhi/Persistent_cohomology.h +745 -0
  91. multipers/gudhi/gudhi/Points_off_io.h +171 -0
  92. multipers/gudhi/gudhi/Simple_object_pool.h +69 -0
  93. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -0
  94. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -0
  95. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -0
  96. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -0
  97. multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -0
  98. multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -0
  99. multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -0
  100. multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -0
  101. multipers/gudhi/gudhi/Simplex_tree.h +2794 -0
  102. multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -0
  103. multipers/gudhi/gudhi/distance_functions.h +62 -0
  104. multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -0
  105. multipers/gudhi/gudhi/persistence_interval.h +253 -0
  106. multipers/gudhi/gudhi/persistence_matrix_options.h +170 -0
  107. multipers/gudhi/gudhi/reader_utils.h +367 -0
  108. multipers/gudhi/mma_interface_coh.h +256 -0
  109. multipers/gudhi/mma_interface_h0.h +223 -0
  110. multipers/gudhi/mma_interface_matrix.h +293 -0
  111. multipers/gudhi/naive_merge_tree.h +536 -0
  112. multipers/gudhi/scc_io.h +310 -0
  113. multipers/gudhi/truc.h +1403 -0
  114. multipers/io.cpython-312-x86_64-linux-gnu.so +0 -0
  115. multipers/io.pyx +644 -0
  116. multipers/ml/__init__.py +0 -0
  117. multipers/ml/accuracies.py +90 -0
  118. multipers/ml/invariants_with_persistable.py +79 -0
  119. multipers/ml/kernels.py +176 -0
  120. multipers/ml/mma.py +713 -0
  121. multipers/ml/one.py +472 -0
  122. multipers/ml/point_clouds.py +352 -0
  123. multipers/ml/signed_measures.py +1589 -0
  124. multipers/ml/sliced_wasserstein.py +461 -0
  125. multipers/ml/tools.py +113 -0
  126. multipers/mma_structures.cpython-312-x86_64-linux-gnu.so +0 -0
  127. multipers/mma_structures.pxd +128 -0
  128. multipers/mma_structures.pyx +2786 -0
  129. multipers/mma_structures.pyx.tp +1094 -0
  130. multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -0
  131. multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -0
  132. multipers/multi_parameter_rank_invariant/function_rips.h +322 -0
  133. multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -0
  134. multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -0
  135. multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -0
  136. multipers/multiparameter_edge_collapse.py +41 -0
  137. multipers/multiparameter_module_approximation/approximation.h +2330 -0
  138. multipers/multiparameter_module_approximation/combinatory.h +129 -0
  139. multipers/multiparameter_module_approximation/debug.h +107 -0
  140. multipers/multiparameter_module_approximation/euler_curves.h +0 -0
  141. multipers/multiparameter_module_approximation/format_python-cpp.h +286 -0
  142. multipers/multiparameter_module_approximation/heap_column.h +238 -0
  143. multipers/multiparameter_module_approximation/images.h +79 -0
  144. multipers/multiparameter_module_approximation/list_column.h +174 -0
  145. multipers/multiparameter_module_approximation/list_column_2.h +232 -0
  146. multipers/multiparameter_module_approximation/ru_matrix.h +347 -0
  147. multipers/multiparameter_module_approximation/set_column.h +135 -0
  148. multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -0
  149. multipers/multiparameter_module_approximation/unordered_set_column.h +166 -0
  150. multipers/multiparameter_module_approximation/utilities.h +403 -0
  151. multipers/multiparameter_module_approximation/vector_column.h +223 -0
  152. multipers/multiparameter_module_approximation/vector_matrix.h +331 -0
  153. multipers/multiparameter_module_approximation/vineyards.h +464 -0
  154. multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -0
  155. multipers/multiparameter_module_approximation.cpython-312-x86_64-linux-gnu.so +0 -0
  156. multipers/multiparameter_module_approximation.pyx +235 -0
  157. multipers/pickle.py +90 -0
  158. multipers/plots.py +456 -0
  159. multipers/point_measure.cpython-312-x86_64-linux-gnu.so +0 -0
  160. multipers/point_measure.pyx +395 -0
  161. multipers/simplex_tree_multi.cpython-312-x86_64-linux-gnu.so +0 -0
  162. multipers/simplex_tree_multi.pxd +134 -0
  163. multipers/simplex_tree_multi.pyx +10840 -0
  164. multipers/simplex_tree_multi.pyx.tp +2009 -0
  165. multipers/slicer.cpython-312-x86_64-linux-gnu.so +0 -0
  166. multipers/slicer.pxd +3034 -0
  167. multipers/slicer.pxd.tp +234 -0
  168. multipers/slicer.pyx +20481 -0
  169. multipers/slicer.pyx.tp +1088 -0
  170. multipers/tensor/tensor.h +672 -0
  171. multipers/tensor.pxd +13 -0
  172. multipers/test.pyx +44 -0
  173. multipers/tests/__init__.py +62 -0
  174. multipers/torch/__init__.py +1 -0
  175. multipers/torch/diff_grids.py +240 -0
  176. multipers/torch/rips_density.py +310 -0
  177. multipers-2.3.3b6.dist-info/METADATA +128 -0
  178. multipers-2.3.3b6.dist-info/RECORD +182 -0
  179. multipers-2.3.3b6.dist-info/WHEEL +5 -0
  180. multipers-2.3.3b6.dist-info/licenses/LICENSE +21 -0
  181. multipers-2.3.3b6.dist-info/top_level.txt +1 -0
  182. multipers.libs/libtbb-ca48af5c.so.12.16 +0 -0
@@ -0,0 +1,90 @@
1
+ import pandas as pd
2
+ from warnings import warn
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ from os.path import exists
6
+
7
+
8
def accuracy_to_csv(
    X,
    Y,
    cl,
    k: float = 10,
    dataset: str = "",
    shuffle=True,
    verbose: bool = True,
    **more_columns,
):
    """Score a classifier and append the summary to ``result_{dataset}.csv``.

    If ``k > 1``, a stratified ``int(k)``-fold cross-validation is run
    (``sklearn.model_selection.StratifiedKFold``). Otherwise ``k`` is passed
    as ``test_size`` to ``sklearn.model_selection.train_test_split`` and a
    single split is scored.

    One row ``[dataset, k, mean, std, *more_columns.values()]`` (mean/std
    rounded to 3 decimals) is appended to ``result_{dataset}.csv`` in the
    current working directory. In the file name, ``/`` is replaced by ``_``
    and ``.off`` is removed. The file is created if missing.

    Parameters
    ----------
    X, Y :
        Samples and labels. In the k-fold case they are indexed element by
        element (``X[j]``, ``Y[j]``).
    cl :
        Estimator with ``fit`` and ``score``. If it exposes ``best_params_``,
        that attribute is printed.
    k :
        Number of folds (``> 1``) or test size (``0 < k <= 1``).
    dataset :
        Name used in messages, in the ``dataset`` column and in the file name.
    shuffle :
        Forwarded to the splitter.
    verbose :
        Print per-fold / final accuracy details.
    **more_columns :
        Extra ``column=value`` pairs appended to the row. Keys that clash with
        the base columns (``dataset``, ``cv``, ``mean``, ``std``) are skipped
        with a warning.

    Raises
    ------
    AssertionError
        If ``k <= 0``.
    """
    assert k > 0, "k is either the number of kfold > 1 or the test size > 0."
    if k > 1:
        k = int(k)
        accuracies = _kfold_accuracies(X, Y, cl, k, dataset, shuffle, verbose)
    elif k > 0:
        accuracies = _split_accuracy(X, Y, cl, k, dataset, shuffle, verbose)
    _append_result_row(dataset, k, accuracies, more_columns)


def _print_best_params(cl):
    """Print ``cl.best_params_`` if the estimator has that attribute."""
    # Only a missing attribute is expected here; any other error should surface
    # (the original bare ``except`` also hid KeyboardInterrupt and real bugs).
    try:
        best_params = cl.best_params_
    except AttributeError:
        return
    print("Best classification parameters : ", best_params)


def _kfold_accuracies(X, Y, cl, k, dataset, shuffle, verbose):
    """Return the array of per-fold scores of a stratified ``k``-fold CV."""
    from sklearn.model_selection import StratifiedKFold

    folds = StratifiedKFold(k, shuffle=shuffle).split(X, Y)
    accuracies = np.zeros(k)
    for fold, (train_idx, test_idx) in enumerate(
        tqdm(folds, total=k, desc="Computing kfold")
    ):
        # Element-wise indexing: X and Y only need integer __getitem__ here.
        xtrain = [X[j] for j in train_idx]
        ytrain = [Y[j] for j in train_idx]
        cl.fit(xtrain, ytrain)
        xtest = [X[j] for j in test_idx]
        ytest = [Y[j] for j in test_idx]
        accuracies[fold] = cl.score(xtest, ytest)
        if verbose:
            print(f"step {fold+1}, {dataset} : {accuracies[fold]}", flush=True)
            _print_best_params(cl)

    print(
        f"""Accuracy {dataset} : {np.mean(accuracies).round(decimals=3)}±{np.std(accuracies).round(decimals=3)}"""
    )
    return accuracies


def _split_accuracy(X, Y, cl, test_size, dataset, shuffle, verbose):
    """Return the score of ``cl`` on a single train/test split."""
    from sklearn.model_selection import train_test_split

    print("Computing accuracy, with train test split", flush=True)
    xtrain, xtest, ytrain, ytest = train_test_split(
        X, Y, shuffle=shuffle, test_size=test_size
    )
    print("Fitting...", end="", flush=True)
    cl.fit(xtrain, ytrain)
    print("Computing score...", end="", flush=True)
    accuracy = cl.score(xtest, ytest)
    _print_best_params(cl)
    print("Done.")
    if verbose:
        print(f"Accuracy {dataset} : {accuracy} ")
    return accuracy


def _append_result_row(dataset, k, accuracies, more_columns):
    """Append one summary row to ``result_{dataset}.csv``, creating it if needed."""
    file_path: str = f"result_{dataset}.csv".replace("/", "_").replace(".off", "")
    columns: list[str] = ["dataset", "cv", "mean", "std"]
    more_names = []
    more_values = []
    for key, value in more_columns.items():
        if key in columns:
            warn(f"Duplicate key {key} ! with value {value}")
            continue
        more_names.append(key)
        more_values.append(value)
    new_line: pd.DataFrame = pd.DataFrame(
        [
            [
                dataset,
                k,
                np.mean(accuracies).round(decimals=3),
                np.std(accuracies).round(decimals=3),
            ]
            + more_values
        ],
        columns=columns + more_names,
    )
    print(new_line)
    if exists(file_path):
        df: pd.DataFrame = pd.concat([pd.read_csv(file_path), new_line])
    else:
        # Recent pandas deprecates concatenating an empty frame. The new row
        # already carries every column, so it is the whole table.
        df = new_line
    df.to_csv(file_path, index=False)
@@ -0,0 +1,79 @@
1
+ import persistable
2
+
3
+
4
+ # requires installing ripser (pip install ripser) as well as persistable from the higher-homology branch,
5
+ # which can be done as follows:
6
+ # pip install git+https://github.com/LuisScoccola/persistable.git@higher-homology
7
+ # NOTE: only accepts as input a distance matrix
8
def hf_degree_rips(
    distance_matrix,
    min_rips_value,
    max_rips_value,
    max_normalized_degree,
    min_normalized_degree,
    grid_granularity,
    max_homological_dimension,
    subsample_size=None,
):
    """Hilbert functions of the degree-Rips construction on a distance matrix.

    Thin wrapper around the private ``persistable.Persistable._hilbert_function``.
    The input is always treated as a precomputed distance matrix
    (``metric="precomputed"``).

    Parameters
    ----------
    distance_matrix :
        Precomputed pairwise distances.
    min_rips_value, max_rips_value, max_normalized_degree, min_normalized_degree, grid_granularity :
        Forwarded positionally, in this order, to ``_hilbert_function``.
        The semantics are persistable's; presumably these are the grid bounds
        for the Rips and normalized-degree parameters (verify upstream).
    max_homological_dimension :
        Forwarded as ``homological_dimension``.
    subsample_size :
        If not None, forwarded as ``subsample`` to ``persistable.Persistable``.

    Returns
    -------
    tuple
        ``(rips_values, normalized_degree_values, hilbert_functions,
        minimal_hilbert_decompositions)`` as produced by persistable.
    """
    # Use `is None`: `== None` on an array-valued argument is elementwise and
    # cannot be used as a condition. `subsample` is only passed when requested,
    # so persistable keeps its own default otherwise.
    subsample_kwargs = {} if subsample_size is None else {"subsample": subsample_size}
    p = persistable.Persistable(
        distance_matrix, metric="precomputed", **subsample_kwargs
    )

    (
        rips_values,
        normalized_degree_values,
        hilbert_functions,
        minimal_hilbert_decompositions,
    ) = p._hilbert_function(
        min_rips_value,
        max_rips_value,
        max_normalized_degree,
        min_normalized_degree,
        grid_granularity,
        homological_dimension=max_homological_dimension,
    )

    return (
        rips_values,
        normalized_degree_values,
        hilbert_functions,
        minimal_hilbert_decompositions,
    )
33
+
34
+
35
+
36
def hf_h0_degree_rips(
    point_cloud,
    min_rips_value,
    max_rips_value,
    max_normalized_degree,
    min_normalized_degree,
    grid_granularity,
):
    """Degree-Rips Hilbert function of a point cloud, first homology entry only.

    Builds ``persistable.Persistable(point_cloud, n_neighbors="all")`` and
    calls its private ``_hilbert_function`` with the grid arguments forwarded
    positionally. Only element ``[0]`` of the returned Hilbert functions and
    minimal decompositions is kept. Judging by the function name this is
    presumably homological degree 0; confirm against persistable.

    Returns
    -------
    tuple
        ``(rips_values, normalized_degree_values, hilbert_functions[0],
        minimal_hilbert_decompositions[0])``.
    """
    persistable_object = persistable.Persistable(point_cloud, n_neighbors="all")

    (
        rips_values,
        normalized_degree_values,
        hilbert_functions,
        minimal_hilbert_decompositions,
    ) = persistable_object._hilbert_function(
        min_rips_value,
        max_rips_value,
        max_normalized_degree,
        min_normalized_degree,
        grid_granularity,
    )

    first_hilbert_function = hilbert_functions[0]
    first_decomposition = minimal_hilbert_decompositions[0]
    return (
        rips_values,
        normalized_degree_values,
        first_hilbert_function,
        first_decomposition,
    )
55
+
56
+
57
def ri_h0_degree_rips(
    point_cloud,
    min_rips_value,
    max_rips_value,
    max_normalized_degree,
    min_normalized_degree,
    grid_granularity,
):
    """Degree-Rips rank invariant of a point cloud via persistable.

    Builds ``persistable.Persistable(point_cloud, n_neighbors="all")`` and
    calls its private ``_rank_invariant`` with the grid arguments forwarded
    positionally. That call must return five values; the last two are
    discarded.

    Returns
    -------
    tuple
        ``(rips_values, normalized_degree_values, rank_invariant)``.
    """
    persistable_object = persistable.Persistable(point_cloud, n_neighbors="all")

    raw_result = persistable_object._rank_invariant(
        min_rips_value,
        max_rips_value,
        max_normalized_degree,
        min_normalized_degree,
        grid_granularity,
    )
    rips_values, normalized_degree_values, rank_invariant, _, _ = raw_result

    return rips_values, normalized_degree_values, rank_invariant
76
+
77
+
78
+
79
+
@@ -0,0 +1,176 @@
1
+ from sklearn.base import BaseEstimator, TransformerMixin, clone
2
+ import numpy as np
3
+ from typing import Iterable
4
+
5
+
6
+ # To do k folds with a distance matrix, we need to slice it into list of distances.
7
+ # k-fold usually shuffles the lists, so we need to add an identifier (the row index) to each entry.
8
+ #
9
class DistanceMatrix2DistanceList(BaseEstimator, TransformerMixin):
    """Prefix every row of a 2-D distance matrix with its row index.

    Row ``i`` of the output is ``[i, *X[i]]``. The leading index lets rows be
    shuffled (e.g. by a k-fold splitter) while remembering which point each
    row describes.
    """

    def __init__(self) -> None:
        super().__init__()

    def fit(self, X, y=None):
        """Do nothing; this method exists for scikit-learn API compatibility."""
        return self

    def transform(self, X):
        """Return ``X`` with a leading index column.

        Raises ``AssertionError`` if ``X`` is not 2-dimensional.
        """
        matrix = np.asarray(X)
        assert matrix.ndim == 2  # a single matrix is expected
        indexed_rows = []
        for row_index, row_distances in enumerate(matrix):
            indexed_rows.append([row_index, *row_distances])
        return np.asarray(indexed_rows)
20
+
21
+
22
class DistanceList2DistanceMatrix(BaseEstimator, TransformerMixin):
    """Recover a distance matrix from index-prefixed distance rows.

    Each input row has the form ``[i, d(i, 0), d(i, 1), ...]``, i.e. the
    output format of ``DistanceMatrix2DistanceList``. ``transform`` returns
    the distances between the rows' points only: column ``j + 1`` is kept
    for every index ``j`` that appears in the first column, in row order.
    """

    def __init__(self) -> None:
        super().__init__()

    def fit(self, X, y=None):
        """Do nothing; this method exists for scikit-learn API compatibility."""
        return self

    def transform(self, X):
        """Return ``X[:, indices + 1]``, where ``indices`` is ``X[:, 0]`` cast to int."""
        # Accept array-likes (e.g. nested lists) like the sibling transformers;
        # 2-D slicing below needs an ndarray.
        X = np.asarray(X)
        # Shift by 1: column 0 stores the point index, distances start at 1.
        index_list = np.asarray(X[:, 0], dtype=int) + 1
        return X[:, index_list]
34
+
35
+
36
class DistanceMatrices2DistancesList(BaseEstimator, TransformerMixin):
    """Convert stacks of distance matrices into index-prefixed distance lists.

    The accepted input rank is fixed by ``fit``:

    * 3-D ``(degree, D1, D2)`` -> output ``(D1, degree, 1 + D2)``
    * 4-D ``(axis, degree, D1, D2)`` -> output ``(D1, axis, degree, 1 + D2)``

    Each matrix is passed through ``DistanceMatrix2DistanceList``, so the
    last output axis starts with the point index. The point axis is then
    moved to the front so that k-fold splitting acts on points.
    """

    def __init__(self) -> None:
        super().__init__()
        # True for 4-D input, False for 3-D; None until fit is called.
        self._axes = None

    def fit(self, X, y=None):
        """Record whether the input carries an extra leading ``axis`` dimension."""
        data = np.asarray(X)
        self._axes = data.ndim == 4
        assert (
            self._axes or data.ndim == 3
        ), " Bad input shape. Input is either (degree) x (distance matrix) or (axis) x (degree) x (distance matrix) "
        return self

    def transform(self, X):
        """Index-prefix every matrix and move the point axis first."""
        data = np.asarray(X)
        expected_ndim_ok = (data.ndim == 4) if self._axes else (data.ndim == 3)
        assert expected_ndim_ok, f"X shape ({data.shape}) is not valid"

        def to_distance_list(matrix):
            return DistanceMatrix2DistanceList().fit_transform(matrix)

        if not self._axes:
            stacked = np.array([to_distance_list(matrix) for matrix in data])
            # (degree, D1, 1 + D2) -> (D1, degree, 1 + D2)
            return np.moveaxis(stacked, [1, 0, 2], [0, 1, 2])

        stacked = np.asarray(
            [
                [to_distance_list(matrix) for matrix in per_axis]
                for per_axis in data
            ]
        )
        # (axis, degree, D1, 1 + D2) -> (D1, axis, degree, 1 + D2)
        return np.moveaxis(stacked, [2, 0, 1, 3], [0, 1, 2, 3])

    def predict(self, X):
        """Alias of ``transform``."""
        return self.transform(X)
80
+
81
+
82
class DistancesLists2DistanceMatrices(BaseEstimator, TransformerMixin):
    """Rebuild (test x train) distance matrices from index-prefixed lists.

    Input layout (as produced by ``DistanceMatrices2DistancesList``):
    ``(D1, degree, 1 + D2)`` or ``(D1, axis, degree, 1 + D2)``, where slot 0
    of the last axis is the point index.

    ``fit`` stores the point indices of its rows as ``train_indices``.
    ``transform`` keeps, for every input row, only the columns of those
    training points and returns ``(degree, D1, n_train)`` or
    ``(axis, degree, D1, n_train)``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.train_indices = None
        # True for 4-D input, False for 3-D; None until fit is called.
        self._axes = None

    def fit(self, X: np.ndarray, y=None):
        """Remember the point indices of the training rows."""
        data = np.asarray(X)
        assert data.ndim in [3, 4]
        self._axes = data.ndim == 4
        # The point index sits in slot 0 of the last axis of every row.
        index_column = data[:, 0, 0, 0] if self._axes else data[:, 0, 0]
        self.train_indices = np.asarray(index_column, dtype=int)
        return self

    def transform(self, X):
        """Select the training-point columns and put the degree axis first."""
        data = np.asarray(X)
        assert data.ndim in [3, 4]
        # +1 because slot 0 of the last axis holds the point index, not a distance.
        train_columns = self.train_indices + 1
        if not self._axes:
            kept = data[:, :, train_columns]
            # (D1, degree, n_train) -> (degree, D1, n_train)
            return np.moveaxis(kept, [0, 1, 2], [1, 0, 2])
        kept = data[:, :, :, train_columns]
        # (D1, axis, degree, n_train) -> (axis, degree, D1, n_train)
        return np.moveaxis(kept, [0, 1, 2, 3], [2, 0, 1, 3])
122
+
123
+
124
class DistanceMatrix2Kernel(BaseEstimator, TransformerMixin):
    """Turn per-degree distance matrices into one weighted exponential kernel.

    Input: ``(degree) x (distance matrix)`` or
    ``(axis) x (degree) x (distance matrix)``. In the second case ``axis``
    must be given (meant for cross validation), unless the leading dimension
    has size 1, in which case slice 0 is used.

    For each degree, ``exp(-D / (2 * sigma**2)) * weight`` is computed. The
    output is the mean over degrees, with the same shape as one distance
    matrix.

    Parameters
    ----------
    sigma :
        Bandwidth, used directly as ``sigma**2``. An iterable ``sigma`` gets
        no special handling here.
    axis :
        Index of the axis slice to use, or None.
    weights :
        One weight per degree, or a scalar applied to every degree.
    """

    def __init__(
        self,
        sigma: float | Iterable[float] = 1,
        axis: int | None = None,
        weights: Iterable[float] | float = 1,
    ) -> None:
        super().__init__()
        self.sigma = sigma
        self.axis = axis
        self.weights = weights
        self._num_degrees = None

    def fit(self, X, y=None):
        """Resolve the axis to use and one weight per degree.

        The resolved values are stored in private attributes. The constructor
        parameters are left untouched (scikit-learn convention), so a later
        ``fit`` on data with a different number of degrees still works.

        Raises ``AssertionError`` on inconsistent shapes or weights.
        """
        X = np.asarray(X)
        if len(X) == 0:
            return self
        assert X.ndim in [3, 4], "Bad input."
        axis = self.axis
        if axis is None:
            assert X.ndim == 3 or X.shape[0] == 1, "Set an axis for data with axis !"
            if X.shape[0] == 1 and X.ndim == 4:
                # A single axis slice: use it implicitly.
                axis = 0
        else:
            assert X.ndim == 4, "Cannot choose axis from data with no axis !"
        self._num_degrees = len(X) if axis is None else len(X[axis])

        # np.ndim(...) == 0 also catches numpy scalars such as np.int64,
        # which are not Python int instances.
        if np.ndim(self.weights) == 0:
            weights = [self.weights] * self._num_degrees
        else:
            weights = self.weights
        assert (
            len(weights) == self._num_degrees
        ), f"Number of weights ({len(weights)}) has to be the same as the number of degrees ({self._num_degrees})"

        self._axis = axis
        self._weights = weights
        return self

    def transform(self, X) -> np.ndarray:
        """Return the degree-averaged weighted kernel of ``X``."""
        # Fall back to the raw parameters if fit was never run (or saw empty input).
        axis = getattr(self, "_axis", self.axis)
        weights = getattr(self, "_weights", self.weights)
        X = np.asarray(X)
        if axis is not None:
            X = X[axis]
        # TODO : pykeops, and full pipeline w/ pykeops
        kernels = np.asarray(
            [
                np.exp(-distance_matrix / (2 * self.sigma**2)) * weight
                for distance_matrix, weight in zip(X, weights)
            ]
        )
        return np.mean(kernels, axis=0)