multipers 2.2.3__cp312-cp312-win_amd64.whl → 2.3.1__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic. Click here for more details.

Files changed (182) hide show
  1. multipers/__init__.py +33 -31
  2. multipers/_signed_measure_meta.py +430 -430
  3. multipers/_slicer_meta.py +211 -212
  4. multipers/data/MOL2.py +458 -458
  5. multipers/data/UCR.py +18 -18
  6. multipers/data/graphs.py +466 -466
  7. multipers/data/immuno_regions.py +27 -27
  8. multipers/data/pytorch2simplextree.py +90 -90
  9. multipers/data/shape3d.py +101 -101
  10. multipers/data/synthetic.py +113 -111
  11. multipers/distances.py +198 -198
  12. multipers/filtration_conversions.pxd.tp +84 -84
  13. multipers/filtrations/__init__.py +18 -0
  14. multipers/{ml/convolutions.py → filtrations/density.py} +563 -520
  15. multipers/filtrations/filtrations.py +289 -0
  16. multipers/filtrations.pxd +224 -224
  17. multipers/function_rips.cp312-win_amd64.pyd +0 -0
  18. multipers/function_rips.pyx +105 -105
  19. multipers/grids.cp312-win_amd64.pyd +0 -0
  20. multipers/grids.pyx +350 -350
  21. multipers/gudhi/Persistence_slices_interface.h +132 -132
  22. multipers/gudhi/Simplex_tree_interface.h +239 -245
  23. multipers/gudhi/Simplex_tree_multi_interface.h +516 -561
  24. multipers/gudhi/cubical_to_boundary.h +59 -59
  25. multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -450
  26. multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -1070
  27. multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -579
  28. multipers/gudhi/gudhi/Debug_utils.h +45 -45
  29. multipers/gudhi/gudhi/Fields/Multi_field.h +484 -484
  30. multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -455
  31. multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -450
  32. multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -531
  33. multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -507
  34. multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -531
  35. multipers/gudhi/gudhi/Fields/Z2_field.h +355 -355
  36. multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -376
  37. multipers/gudhi/gudhi/Fields/Zp_field.h +420 -420
  38. multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -400
  39. multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -418
  40. multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -337
  41. multipers/gudhi/gudhi/Matrix.h +2107 -2107
  42. multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -1038
  43. multipers/gudhi/gudhi/Multi_persistence/Box.h +171 -171
  44. multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -282
  45. multipers/gudhi/gudhi/Off_reader.h +173 -173
  46. multipers/gudhi/gudhi/One_critical_filtration.h +1433 -1431
  47. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -769
  48. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -686
  49. multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -842
  50. multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -1350
  51. multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -1105
  52. multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -859
  53. multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -910
  54. multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -139
  55. multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -230
  56. multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -211
  57. multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -60
  58. multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -60
  59. multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -136
  60. multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -190
  61. multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -616
  62. multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -150
  63. multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -106
  64. multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -219
  65. multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -327
  66. multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -1140
  67. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -934
  68. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -934
  69. multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -980
  70. multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -1092
  71. multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -192
  72. multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -921
  73. multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -1093
  74. multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -1012
  75. multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -1244
  76. multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -186
  77. multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -164
  78. multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -156
  79. multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -376
  80. multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -540
  81. multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -118
  82. multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -173
  83. multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -128
  84. multipers/gudhi/gudhi/Persistent_cohomology.h +745 -745
  85. multipers/gudhi/gudhi/Points_off_io.h +171 -171
  86. multipers/gudhi/gudhi/Simple_object_pool.h +69 -69
  87. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -463
  88. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -83
  89. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -106
  90. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -277
  91. multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -62
  92. multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -27
  93. multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -62
  94. multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -157
  95. multipers/gudhi/gudhi/Simplex_tree.h +2794 -2794
  96. multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -163
  97. multipers/gudhi/gudhi/distance_functions.h +62 -62
  98. multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -104
  99. multipers/gudhi/gudhi/persistence_interval.h +253 -253
  100. multipers/gudhi/gudhi/persistence_matrix_options.h +170 -170
  101. multipers/gudhi/gudhi/reader_utils.h +367 -367
  102. multipers/gudhi/mma_interface_coh.h +256 -255
  103. multipers/gudhi/mma_interface_h0.h +223 -231
  104. multipers/gudhi/mma_interface_matrix.h +291 -282
  105. multipers/gudhi/naive_merge_tree.h +536 -575
  106. multipers/gudhi/scc_io.h +310 -289
  107. multipers/gudhi/truc.h +957 -888
  108. multipers/io.cp312-win_amd64.pyd +0 -0
  109. multipers/io.pyx +714 -711
  110. multipers/ml/accuracies.py +90 -90
  111. multipers/ml/invariants_with_persistable.py +79 -79
  112. multipers/ml/kernels.py +176 -176
  113. multipers/ml/mma.py +713 -714
  114. multipers/ml/one.py +472 -472
  115. multipers/ml/point_clouds.py +352 -346
  116. multipers/ml/signed_measures.py +1589 -1589
  117. multipers/ml/sliced_wasserstein.py +461 -461
  118. multipers/ml/tools.py +113 -113
  119. multipers/mma_structures.cp312-win_amd64.pyd +0 -0
  120. multipers/mma_structures.pxd +127 -127
  121. multipers/mma_structures.pyx +4 -8
  122. multipers/mma_structures.pyx.tp +1083 -1085
  123. multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -93
  124. multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -97
  125. multipers/multi_parameter_rank_invariant/function_rips.h +322 -322
  126. multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -769
  127. multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -148
  128. multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -369
  129. multipers/multiparameter_edge_collapse.py +41 -41
  130. multipers/multiparameter_module_approximation/approximation.h +2298 -2295
  131. multipers/multiparameter_module_approximation/combinatory.h +129 -129
  132. multipers/multiparameter_module_approximation/debug.h +107 -107
  133. multipers/multiparameter_module_approximation/format_python-cpp.h +286 -286
  134. multipers/multiparameter_module_approximation/heap_column.h +238 -238
  135. multipers/multiparameter_module_approximation/images.h +79 -79
  136. multipers/multiparameter_module_approximation/list_column.h +174 -174
  137. multipers/multiparameter_module_approximation/list_column_2.h +232 -232
  138. multipers/multiparameter_module_approximation/ru_matrix.h +347 -347
  139. multipers/multiparameter_module_approximation/set_column.h +135 -135
  140. multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -36
  141. multipers/multiparameter_module_approximation/unordered_set_column.h +166 -166
  142. multipers/multiparameter_module_approximation/utilities.h +403 -419
  143. multipers/multiparameter_module_approximation/vector_column.h +223 -223
  144. multipers/multiparameter_module_approximation/vector_matrix.h +331 -331
  145. multipers/multiparameter_module_approximation/vineyards.h +464 -464
  146. multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -649
  147. multipers/multiparameter_module_approximation.cp312-win_amd64.pyd +0 -0
  148. multipers/multiparameter_module_approximation.pyx +218 -217
  149. multipers/pickle.py +90 -53
  150. multipers/plots.py +342 -334
  151. multipers/point_measure.cp312-win_amd64.pyd +0 -0
  152. multipers/point_measure.pyx +322 -320
  153. multipers/simplex_tree_multi.cp312-win_amd64.pyd +0 -0
  154. multipers/simplex_tree_multi.pxd +133 -133
  155. multipers/simplex_tree_multi.pyx +115 -48
  156. multipers/simplex_tree_multi.pyx.tp +1947 -1935
  157. multipers/slicer.cp312-win_amd64.pyd +0 -0
  158. multipers/slicer.pxd +281 -100
  159. multipers/slicer.pxd.tp +218 -214
  160. multipers/slicer.pyx +1570 -507
  161. multipers/slicer.pyx.tp +931 -914
  162. multipers/tensor/tensor.h +672 -672
  163. multipers/tensor.pxd +13 -13
  164. multipers/test.pyx +44 -44
  165. multipers/tests/__init__.py +57 -57
  166. multipers/torch/diff_grids.py +217 -217
  167. multipers/torch/rips_density.py +310 -304
  168. {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/LICENSE +21 -21
  169. {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/METADATA +21 -11
  170. multipers-2.3.1.dist-info/RECORD +182 -0
  171. {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/WHEEL +1 -1
  172. multipers/tests/test_diff_helper.py +0 -73
  173. multipers/tests/test_hilbert_function.py +0 -82
  174. multipers/tests/test_mma.py +0 -83
  175. multipers/tests/test_point_clouds.py +0 -49
  176. multipers/tests/test_python-cpp_conversion.py +0 -82
  177. multipers/tests/test_signed_betti.py +0 -181
  178. multipers/tests/test_signed_measure.py +0 -89
  179. multipers/tests/test_simplextreemulti.py +0 -221
  180. multipers/tests/test_slicer.py +0 -221
  181. multipers-2.2.3.dist-info/RECORD +0 -189
  182. {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/top_level.txt +0 -0
@@ -1,461 +1,461 @@
1
- # This code was written by Mathieu Carrière.
2
-
3
- import numpy as np
4
- from sklearn.base import BaseEstimator, TransformerMixin
5
- from sklearn.metrics import pairwise_distances, pairwise_kernels
6
- from joblib import Parallel, delayed
7
-
8
-
9
- def _pairwise(fallback, skipdiag, X, Y, metric, n_jobs):
10
- if Y is not None:
11
- return fallback(X, Y, metric=metric, n_jobs=n_jobs)
12
- triu = np.triu_indices(len(X), k=skipdiag)
13
- tril = (triu[1], triu[0])
14
- par = Parallel(n_jobs=n_jobs, prefer="threads")
15
- d = par(delayed(metric)([triu[0][i]], [triu[1][i]])
16
- for i in range(len(triu[0])))
17
- m = np.empty((len(X), len(X)))
18
- m[triu] = d
19
- m[tril] = d
20
- if skipdiag:
21
- np.fill_diagonal(m, 0)
22
- return m
23
-
24
-
25
- def _sklearn_wrapper(metric, X, Y, **kwargs):
26
- """
27
- This function is a wrapper for any metric between two signed measures that takes two numpy arrays of shapes (nxD) and (mxD) as arguments.
28
- """
29
- if Y is None:
30
-
31
- def flat_metric(a, b):
32
- return metric(X[int(a[0])], X[int(b[0])], **kwargs)
33
- else:
34
-
35
- def flat_metric(a, b):
36
- return metric(X[int(a[0])], Y[int(b[0])], **kwargs)
37
-
38
- return flat_metric
39
-
40
-
41
- def _compute_signed_measure_parts(X):
42
- """
43
- This is a function for separating the positive and negative points of a list of signed measures. This function can be used as a preprocessing step in order to speed up the running time for computing all pairwise (sliced) Wasserstein distances on a list of signed measures.
44
-
45
- Parameters:
46
- X (list of n tuples): list of signed measures.
47
-
48
- Returns:
49
- list of n pairs of numpy arrays of shape (num x dimension): list of positive and negative signed measures.
50
- """
51
- XX = []
52
- for C, M in X:
53
- pos_idxs = np.argwhere(M > 0).ravel()
54
- neg_idxs = np.setdiff1d(np.arange(len(M)), pos_idxs)
55
- XX.append(
56
- [
57
- np.repeat(C[pos_idxs], M[pos_idxs], axis=0),
58
- np.repeat(C[neg_idxs], -M[neg_idxs], axis=0),
59
- ]
60
- )
61
- return XX
62
-
63
-
64
- def _compute_signed_measure_projections(X, num_directions, scales):
65
- """
66
- This is a function for projecting the points of a list of signed measures onto a fixed number of lines sampled uniformly. This function can be used as a preprocessing step in order to speed up the running time for computing all pairwise sliced Wasserstein distances on a list of signed measures.
67
-
68
- Parameters:
69
- X (list of n tuples): list of signed measures.
70
- num_directions (int): number of lines evenly sampled from [-pi/2,pi/2] in order to approximate and speed up the distance computation.
71
- scales (array of shape D): scales associated to the dimensions.
72
-
73
- Returns:
74
- list of n pairs of numpy arrays of shape (num x num_directions): list of positive and negative projected signed measures.
75
- """
76
- dimension = X[0][0].shape[1]
77
- np.random.seed(42)
78
- thetas = np.random.normal(0, 1, [num_directions, dimension])
79
- lines = (thetas / np.linalg.norm(thetas, axis=1)[:, None]).T
80
- weights = (
81
- np.linalg.norm(np.multiply(scales[:, None], lines), axis=0)
82
- if scales is not None
83
- else np.ones(num_directions)
84
- )
85
- XX = []
86
- for C, M in X:
87
- pos_idxs = np.argwhere(M > 0).ravel()
88
- neg_idxs = np.setdiff1d(np.arange(len(M)), pos_idxs)
89
- XX.append(
90
- [
91
- np.matmul(np.repeat(C[pos_idxs], M[pos_idxs], axis=0), lines),
92
- np.matmul(np.repeat(C[neg_idxs], -M[neg_idxs], axis=0), lines),
93
- weights,
94
- ]
95
- )
96
- return XX
97
-
98
-
99
- def pairwise_signed_measure_distances(
100
- X, Y=None, metric="sliced_wasserstein", n_jobs=None, **kwargs
101
- ):
102
- """
103
- This function computes the distance matrix between two lists of signed measures given as numpy arrays of shape (nxD).
104
-
105
- Parameters:
106
- X (list of n tuples): first list of signed measures.
107
- Y (list of m tuples): second list of signed measures (optional). If None, pairwise distances are computed from the first list only.
108
- metric: distance to use. It can be either a string ("sliced_wasserstein", "wasserstein") or a function taking two tuples as inputs. If it is a function, make sure that it is symmetric and that it outputs 0 if called on the same two tuples.
109
- n_jobs (int): number of jobs to use for the computation. This uses joblib.Parallel(prefer="threads"), so metrics that do not release the GIL may not scale unless run inside a `joblib.parallel_backend <https://joblib.readthedocs.io/en/latest/parallel.html#joblib.parallel_backend>`_ block.
110
- **kwargs: optional keyword parameters. Any further parameters are passed directly to the distance function. See the docs of the various distance classes in this module.
111
-
112
- Returns:
113
- numpy array of shape (nxm): distance matrix
114
- """
115
- XX = np.reshape(np.arange(len(X)), [-1, 1])
116
- YY = None if Y is None or Y is X else np.reshape(
117
- np.arange(len(Y)), [-1, 1])
118
- if metric == "sliced_wasserstein":
119
- Xproj = _compute_signed_measure_projections(X, **kwargs)
120
- Yproj = None if Y is None else _compute_signed_measure_projections(
121
- Y, **kwargs)
122
- return _pairwise(
123
- pairwise_distances,
124
- True,
125
- XX,
126
- YY,
127
- metric=_sklearn_wrapper(
128
- _sliced_wasserstein_distance_on_projections, Xproj, Yproj
129
- ),
130
- n_jobs=n_jobs,
131
- )
132
- elif metric == "wasserstein":
133
- Xproj = _compute_signed_measure_parts(X)
134
- Yproj = None if Y is None else _compute_signed_measure_parts(Y)
135
- return _pairwise(
136
- pairwise_distances,
137
- True,
138
- XX,
139
- YY,
140
- metric=_sklearn_wrapper(
141
- _wasserstein_distance_on_parts(**kwargs), Xproj, Yproj
142
- ),
143
- n_jobs=n_jobs,
144
- )
145
- else:
146
- return _pairwise(
147
- pairwise_distances,
148
- True,
149
- XX,
150
- YY,
151
- metric=_sklearn_wrapper(metric, X, Y, **kwargs),
152
- n_jobs=n_jobs,
153
- )
154
-
155
-
156
- def _wasserstein_distance_on_parts(ground_norm=1, epsilon=1.0):
157
- """
158
- This is a function for computing the Wasserstein distance between two signed measures that have already been separated into their positive and negative parts.
159
-
160
- Parameters:
161
- meas1: pair of (n x dimension) numpy.arrays containing the points of the positive and negative parts of the first measure.
162
- meas2: pair of (m x dimension) numpy.arrays containing the points of the positive and negative parts of the second measure.
163
-
164
- Returns:
165
- float: the sliced Wasserstein distance between the projected signed measures.
166
- """
167
-
168
- def metric(meas1, meas2):
169
- meas1_plus, meas1_minus = meas1[0], meas1[1]
170
- meas2_plus, meas2_minus = meas2[0], meas2[1]
171
- num_pts = len(meas1_plus) + len(meas2_minus)
172
- meas_t1 = np.vstack([meas1_plus, meas2_minus])
173
- meas_t2 = np.vstack([meas2_plus, meas1_minus])
174
- import ot
175
-
176
- if epsilon > 0:
177
- wass = ot.sinkhorn2(
178
- 1 / num_pts * np.ones(num_pts),
179
- 1 / num_pts * np.ones(num_pts),
180
- pairwise_distances(
181
- meas_t1, meas_t2, metric="minkowski", p=ground_norm),
182
- epsilon,
183
- )
184
- return wass[0]
185
- else:
186
- wass = ot.lp.emd2(
187
- [],
188
- [],
189
- np.ascontiguousarray(
190
- pairwise_distances(
191
- meas_t1, meas_t2, metric="minkowski", p=ground_norm
192
- ),
193
- dtype=np.float64,
194
- ),
195
- )
196
- return wass
197
-
198
- return metric
199
-
200
-
201
- def _sliced_wasserstein_distance_on_projections(meas1, meas2, scales=None):
202
- """
203
- This is a function for computing the sliced Wasserstein distance between two signed measures that have already been projected onto some lines. It simply amounts to comparing the sorted projections with the 1-norm, and averaging over the lines. See http://proceedings.mlr.press/v70/carriere17a.html for more details.
204
-
205
- Parameters:
206
- meas1: pair of (n x number_of_lines) numpy.arrays containing the projected points of the positive and negative parts of the first measure.
207
- meas2: pair of (m x number_of_lines) numpy.arrays containing the projected points of the positive and negative parts of the second measure.
208
- scales (array of shape D): scales associated to the dimensions.
209
-
210
- Returns:
211
- float: the sliced Wasserstein distance between the projected signed measures.
212
- """
213
- # assert np.array_equal( meas1[2], meas2[2] )
214
- weights = meas1[2]
215
- meas1_plus, meas1_minus = meas1[0], meas1[1]
216
- meas2_plus, meas2_minus = meas2[0], meas2[1]
217
- A = np.sort(np.vstack([meas1_plus, meas2_minus]), axis=0)
218
- B = np.sort(np.vstack([meas2_plus, meas1_minus]), axis=0)
219
- L1 = np.sum(np.abs(A - B), axis=0)
220
- return np.mean(np.multiply(L1, weights))
221
-
222
-
223
- def _sliced_wasserstein_distance(meas1, meas2, num_directions, scales=None):
224
- """
225
- This is a function for computing the sliced Wasserstein distance from two signed measures. The Sliced Wasserstein distance is computed by projecting the signed measures onto lines, comparing the projections with the 1-norm, and finally averaging over the lines. See http://proceedings.mlr.press/v70/carriere17a.html for more details.
226
-
227
- Parameters:
228
- meas1: ((n x D), (n)) tuple with numpy.array encoding the (finite points of the) first measure and their multiplicities. Must not contain essential points (i.e. with infinite coordinate).
229
- meas2: ((m x D), (m)) tuple encoding the second measure.
230
- num_directions (int): number of lines evenly sampled from [-pi/2,pi/2] in order to approximate and speed up the distance computation.
231
- scales (array of shape D): scales associated to the dimensions.
232
-
233
- Returns:
234
- float: the sliced Wasserstein distance between signed measures.
235
- """
236
- C1, M1 = meas1[0], meas1[1]
237
- C2, M2 = meas2[0], meas2[1]
238
- dimension = C1.shape[1]
239
- C1_plus_idxs, C2_plus_idxs = (
240
- np.argwhere(M1 > 0).ravel(),
241
- np.argwhere(M2 > 0).ravel(),
242
- )
243
- C1_minus_idxs, C2_minus_idxs = (
244
- np.setdiff1d(np.arange(len(M1)), C1_plus_idxs),
245
- np.setdiff1d(np.arange(len(M2)), C2_plus_idxs),
246
- )
247
- np.random.seed(42)
248
- thetas = np.random.normal(0, 1, [num_directions, dimension])
249
- lines = (thetas / np.linalg.norm(thetas, axis=1)[:, None]).T
250
- weights = (
251
- np.linalg.norm(np.multiply(scales[:, None], lines), axis=0)
252
- if scales is not None
253
- else np.ones(num_directions)
254
- )
255
- approx1 = np.matmul(
256
- np.vstack(
257
- [
258
- np.repeat(C1[C1_plus_idxs], M1[C1_plus_idxs], axis=0),
259
- np.repeat(C2[C2_minus_idxs], -M2[C2_minus_idxs], axis=0),
260
- ]
261
- ),
262
- lines,
263
- )
264
- approx2 = np.matmul(
265
- np.vstack(
266
- [
267
- np.repeat(C2[C2_plus_idxs], M2[C2_plus_idxs], axis=0),
268
- np.repeat(C1[C1_minus_idxs], -M1[C1_minus_idxs], axis=0),
269
- ]
270
- ),
271
- lines,
272
- )
273
- A = np.sort(approx1, axis=0)
274
- B = np.sort(approx2, axis=0)
275
- L1 = np.sum(np.abs(A - B), axis=0)
276
- return np.mean(np.multiply(L1, weights))
277
-
278
-
279
- def _wasserstein_distance(meas1, meas2, epsilon, ground_norm):
280
- """
281
- This is a function for computing the Wasserstein distance from two signed measures.
282
-
283
- Parameters:
284
- meas1: ((n x D), (n)) tuple with numpy.array encoding the (finite points of the) first measure and their multiplicities. Must not contain essential points (i.e. with infinite coordinate).
285
- meas2: ((m x D), (m)) tuple encoding the second measure.
286
- epsilon (float): entropy regularization parameter.
287
- ground_norm (int): norm to use for ground metric cost.
288
-
289
- Returns:
290
- float: the Wasserstein distance between signed measures.
291
- """
292
- C1, M1 = meas1[0], meas1[1]
293
- C2, M2 = meas2[0], meas2[1]
294
- C1_plus_idxs, C2_plus_idxs = (
295
- np.argwhere(M1 > 0).ravel(),
296
- np.argwhere(M2 > 0).ravel(),
297
- )
298
- C1_minus_idxs, C2_minus_idxs = (
299
- np.setdiff1d(np.arange(len(M1)), C1_plus_idxs),
300
- np.setdiff1d(np.arange(len(M2)), C2_plus_idxs),
301
- )
302
- approx1 = np.vstack(
303
- [
304
- np.repeat(C1[C1_plus_idxs], M1[C1_plus_idxs], axis=0),
305
- np.repeat(C2[C2_minus_idxs], -M2[C2_minus_idxs], axis=0),
306
- ]
307
- )
308
- approx2 = np.vstack(
309
- [
310
- np.repeat(C2[C2_plus_idxs], M2[C2_plus_idxs], axis=0),
311
- np.repeat(C1[C1_minus_idxs], -M1[C1_minus_idxs], axis=0),
312
- ]
313
- )
314
- num_pts = len(approx1)
315
- import ot
316
-
317
- if epsilon > 0:
318
- wass = ot.sinkhorn2(
319
- 1 / num_pts * np.ones(num_pts),
320
- 1 / num_pts * np.ones(num_pts),
321
- pairwise_distances(
322
- approx1, approx2, metric="minkowski", p=ground_norm),
323
- epsilon,
324
- )
325
- return wass[0]
326
- else:
327
- wass = ot.lp.emd2(
328
- 1 / num_pts * np.ones(num_pts),
329
- 1 / num_pts * np.ones(num_pts),
330
- pairwise_distances(
331
- approx1, approx2, metric="minkowski", p=ground_norm),
332
- )
333
- return wass
334
-
335
-
336
- class SlicedWassersteinDistance(BaseEstimator, TransformerMixin):
337
- """
338
- This is a class for computing the sliced Wasserstein distance matrix from a list of signed measures. The Sliced Wasserstein distance is computed by projecting the signed measures onto lines, comparing the projections with the 1-norm, and finally integrating over all possible lines. See http://proceedings.mlr.press/v70/carriere17a.html for more details.
339
- """
340
-
341
- def __init__(self, num_directions=10, scales=None, n_jobs=None):
342
- """
343
- Constructor for the SlicedWassersteinDistance class.
344
-
345
- Parameters:
346
- num_directions (int): number of lines evenly sampled in order to approximate and speed up the distance computation (default 10).
347
- scales (array of shape D): scales associated to the dimensions.
348
- n_jobs (int): number of jobs to use for the computation. See :func:`pairwise_signed_measure_distances` for details.
349
- """
350
- self.num_directions = num_directions
351
- self.scales = scales
352
- self.n_jobs = n_jobs
353
-
354
- def fit(self, X, y=None):
355
- """
356
- Fit the SlicedWassersteinDistance class on a list of signed measures: signed measures are projected onto the different lines. The measures themselves are then stored in numpy arrays, called **measures_**.
357
-
358
- Parameters:
359
- X (list of tuples): input signed measures.
360
- y (n x 1 array): signed measure labels (unused).
361
- """
362
- self.measures_ = X
363
- return self
364
-
365
- def transform(self, X):
366
- """
367
- Compute all sliced Wasserstein distances between the signed measures that were stored after calling the fit() method, and a given list of (possibly different) signed measures.
368
-
369
- Parameters:
370
- X (list of tuples): input signed measures.
371
-
372
- Returns:
373
- numpy array of shape (number of measures in **measures**) x (number of measures in X): matrix of pairwise sliced Wasserstein distances.
374
- """
375
- return pairwise_signed_measure_distances(
376
- X,
377
- self.measures_,
378
- metric="sliced_wasserstein",
379
- num_directions=self.num_directions,
380
- scales=self.scales,
381
- n_jobs=self.n_jobs,
382
- )
383
-
384
- def __call__(self, meas1, meas2):
385
- """
386
- Apply SlicedWassersteinDistance on a single pair of signed measures and outputs the result.
387
-
388
- Parameters:
389
- meas1: ((n x D), (n)) tuple with numpy.array encoding the (finite points of the) first measure and their multiplicities. Must not contain essential points (i.e. with infinite coordinate).
390
- meas2: ((m x D), (m)) tuple encoding the second measure.
391
-
392
- Returns:
393
- float: sliced Wasserstein distance.
394
- """
395
- return _sliced_wasserstein_distance(
396
- meas1, meas2, num_directions=self.num_directions, scales=self.scales
397
- )
398
-
399
-
400
- class WassersteinDistance(BaseEstimator, TransformerMixin):
401
- """
402
- This is a class for computing the Wasserstein distance matrix from a list of signed measures.
403
- """
404
-
405
- def __init__(self, epsilon=1.0, ground_norm=1, n_jobs=None):
406
- """
407
- Constructor for the WassersteinDistance class.
408
-
409
- Parameters:
410
- epsilon (float): entropy regularization parameter.
411
- ground_norm (int): norm to use for ground metric cost.
412
- n_jobs (int): number of jobs to use for the computation. See :func:`pairwise_signed_measure_distances` for details.
413
- """
414
- self.epsilon = epsilon
415
- self.ground_norm = ground_norm
416
- self.n_jobs = n_jobs
417
-
418
- def fit(self, X, y=None):
419
- """
420
- Fit the WassersteinDistance class on a list of signed measures. The measures themselves are then stored in numpy arrays, called **measures_**.
421
-
422
- Parameters:
423
- X (list of tuples): input signed measures.
424
- y (n x 1 array): signed measure labels (unused).
425
- """
426
- self.measures_ = X
427
- return self
428
-
429
- def transform(self, X):
430
- """
431
- Compute all Wasserstein distances between the signed measures that were stored after calling the fit() method, and a given list of (possibly different) signed measures.
432
-
433
- Parameters:
434
- X (list of tuples): input signed measures.
435
-
436
- Returns:
437
- numpy array of shape (number of measures in **measures**) x (number of measures in X): matrix of pairwise Wasserstein distances.
438
- """
439
- return pairwise_signed_measure_distances(
440
- X,
441
- self.measures_,
442
- metric="wasserstein",
443
- epsilon=self.epsilon,
444
- ground_norm=self.ground_norm,
445
- n_jobs=self.n_jobs,
446
- )
447
-
448
- def __call__(self, meas1, meas2):
449
- """
450
- Apply WassersteinDistance on a single pair of signed measures and outputs the result.
451
-
452
- Parameters:
453
- meas1: ((n x D), (n)) tuple with numpy.array encoding the (finite points of the) first measure and their multiplicities. Must not contain essential points (i.e. with infinite coordinate).
454
- meas2: ((m x D), (m)) tuple encoding the second measure.
455
-
456
- Returns:
457
- float: Wasserstein distance.
458
- """
459
- return _wasserstein_distance(
460
- meas1, meas2, epsilon=self.epsilon, ground_norm=self.ground_norm
461
- )
1
+ # This code was written by Mathieu Carrière.
2
+
3
+ import numpy as np
4
+ from sklearn.base import BaseEstimator, TransformerMixin
5
+ from sklearn.metrics import pairwise_distances, pairwise_kernels
6
+ from joblib import Parallel, delayed
7
+
8
+
9
+ def _pairwise(fallback, skipdiag, X, Y, metric, n_jobs):
10
+ if Y is not None:
11
+ return fallback(X, Y, metric=metric, n_jobs=n_jobs)
12
+ triu = np.triu_indices(len(X), k=skipdiag)
13
+ tril = (triu[1], triu[0])
14
+ par = Parallel(n_jobs=n_jobs, prefer="threads")
15
+ d = par(delayed(metric)([triu[0][i]], [triu[1][i]])
16
+ for i in range(len(triu[0])))
17
+ m = np.empty((len(X), len(X)))
18
+ m[triu] = d
19
+ m[tril] = d
20
+ if skipdiag:
21
+ np.fill_diagonal(m, 0)
22
+ return m
23
+
24
+
25
+ def _sklearn_wrapper(metric, X, Y, **kwargs):
26
+ """
27
+ This function is a wrapper for any metric between two signed measures that takes two numpy arrays of shapes (nxD) and (mxD) as arguments.
28
+ """
29
+ if Y is None:
30
+
31
+ def flat_metric(a, b):
32
+ return metric(X[int(a[0])], X[int(b[0])], **kwargs)
33
+ else:
34
+
35
+ def flat_metric(a, b):
36
+ return metric(X[int(a[0])], Y[int(b[0])], **kwargs)
37
+
38
+ return flat_metric
39
+
40
+
41
+ def _compute_signed_measure_parts(X):
42
+ """
43
+ This is a function for separating the positive and negative points of a list of signed measures. This function can be used as a preprocessing step in order to speed up the running time for computing all pairwise (sliced) Wasserstein distances on a list of signed measures.
44
+
45
+ Parameters:
46
+ X (list of n tuples): list of signed measures.
47
+
48
+ Returns:
49
+ list of n pairs of numpy arrays of shape (num x dimension): list of positive and negative signed measures.
50
+ """
51
+ XX = []
52
+ for C, M in X:
53
+ pos_idxs = np.argwhere(M > 0).ravel()
54
+ neg_idxs = np.setdiff1d(np.arange(len(M)), pos_idxs)
55
+ XX.append(
56
+ [
57
+ np.repeat(C[pos_idxs], M[pos_idxs], axis=0),
58
+ np.repeat(C[neg_idxs], -M[neg_idxs], axis=0),
59
+ ]
60
+ )
61
+ return XX
62
+
63
+
64
+ def _compute_signed_measure_projections(X, num_directions, scales):
65
+ """
66
+ This is a function for projecting the points of a list of signed measures onto a fixed number of lines sampled uniformly. This function can be used as a preprocessing step in order to speed up the running time for computing all pairwise sliced Wasserstein distances on a list of signed measures.
67
+
68
+ Parameters:
69
+ X (list of n tuples): list of signed measures.
70
+ num_directions (int): number of lines evenly sampled from [-pi/2,pi/2] in order to approximate and speed up the distance computation.
71
+ scales (array of shape D): scales associated to the dimensions.
72
+
73
+ Returns:
74
+ list of n pairs of numpy arrays of shape (num x num_directions): list of positive and negative projected signed measures.
75
+ """
76
+ dimension = X[0][0].shape[1]
77
+ np.random.seed(42)
78
+ thetas = np.random.normal(0, 1, [num_directions, dimension])
79
+ lines = (thetas / np.linalg.norm(thetas, axis=1)[:, None]).T
80
+ weights = (
81
+ np.linalg.norm(np.multiply(scales[:, None], lines), axis=0)
82
+ if scales is not None
83
+ else np.ones(num_directions)
84
+ )
85
+ XX = []
86
+ for C, M in X:
87
+ pos_idxs = np.argwhere(M > 0).ravel()
88
+ neg_idxs = np.setdiff1d(np.arange(len(M)), pos_idxs)
89
+ XX.append(
90
+ [
91
+ np.matmul(np.repeat(C[pos_idxs], M[pos_idxs], axis=0), lines),
92
+ np.matmul(np.repeat(C[neg_idxs], -M[neg_idxs], axis=0), lines),
93
+ weights,
94
+ ]
95
+ )
96
+ return XX
97
+
98
+
99
def pairwise_signed_measure_distances(
    X, Y=None, metric="sliced_wasserstein", n_jobs=None, **kwargs
):
    """
    Compute the distance matrix between two lists of signed measures, each measure being a pair (points, multiplicities).

    Parameters:
        X (list of n tuples): first list of signed measures.
        Y (list of m tuples): second list of signed measures (optional). If None, or if it is the very same object as X, pairwise distances are computed from the first list only and symmetry is exploited.
        metric: distance to use. It can be either a string ("sliced_wasserstein", "wasserstein") or a function taking two tuples as inputs. If it is a function, make sure that it is symmetric and that it outputs 0 if called on the same two tuples.
        n_jobs (int): number of jobs to use for the computation. This uses joblib.Parallel(prefer="threads"), so metrics that do not release the GIL may not scale unless run inside a `joblib.parallel_backend <https://joblib.readthedocs.io/en/latest/parallel.html#joblib.parallel_backend>`_ block.
        **kwargs: optional keyword parameters. For "sliced_wasserstein" they are forwarded to the projection step (num_directions, scales); for "wasserstein" to the distance factory (ground_norm, epsilon); for a callable metric they are passed to it on every call.

    Returns:
        numpy array of shape (nxm): distance matrix
    """
    if Y is X:
        # Same object on both sides: use the symmetric path and skip preprocessing it twice.
        Y = None
    # The pairwise machinery works on index columns; the wrapped metric maps indices back to measures.
    XX = np.reshape(np.arange(len(X)), [-1, 1])
    YY = None if Y is None else np.reshape(np.arange(len(Y)), [-1, 1])
    if metric == "sliced_wasserstein":
        Xdata = _compute_signed_measure_projections(X, **kwargs)
        Ydata = (
            None if Y is None else _compute_signed_measure_projections(Y, **kwargs)
        )
        flat_metric = _sklearn_wrapper(
            _sliced_wasserstein_distance_on_projections, Xdata, Ydata
        )
    elif metric == "wasserstein":
        Xdata = _compute_signed_measure_parts(X)
        Ydata = None if Y is None else _compute_signed_measure_parts(Y)
        flat_metric = _sklearn_wrapper(
            _wasserstein_distance_on_parts(**kwargs), Xdata, Ydata
        )
    else:
        flat_metric = _sklearn_wrapper(metric, X, Y, **kwargs)
    return _pairwise(
        pairwise_distances,
        True,
        XX,
        YY,
        metric=flat_metric,
        n_jobs=n_jobs,
    )
154
+
155
+
156
+ def _wasserstein_distance_on_parts(ground_norm=1, epsilon=1.0):
157
+ """
158
+ This is a function for computing the Wasserstein distance between two signed measures that have already been separated into their positive and negative parts.
159
+
160
+ Parameters:
161
+ meas1: pair of (n x dimension) numpy.arrays containing the points of the positive and negative parts of the first measure.
162
+ meas2: pair of (m x dimension) numpy.arrays containing the points of the positive and negative parts of the second measure.
163
+
164
+ Returns:
165
+ float: the sliced Wasserstein distance between the projected signed measures.
166
+ """
167
+
168
+ def metric(meas1, meas2):
169
+ meas1_plus, meas1_minus = meas1[0], meas1[1]
170
+ meas2_plus, meas2_minus = meas2[0], meas2[1]
171
+ num_pts = len(meas1_plus) + len(meas2_minus)
172
+ meas_t1 = np.vstack([meas1_plus, meas2_minus])
173
+ meas_t2 = np.vstack([meas2_plus, meas1_minus])
174
+ import ot
175
+
176
+ if epsilon > 0:
177
+ wass = ot.sinkhorn2(
178
+ 1 / num_pts * np.ones(num_pts),
179
+ 1 / num_pts * np.ones(num_pts),
180
+ pairwise_distances(
181
+ meas_t1, meas_t2, metric="minkowski", p=ground_norm),
182
+ epsilon,
183
+ )
184
+ return wass[0]
185
+ else:
186
+ wass = ot.lp.emd2(
187
+ [],
188
+ [],
189
+ np.ascontiguousarray(
190
+ pairwise_distances(
191
+ meas_t1, meas_t2, metric="minkowski", p=ground_norm
192
+ ),
193
+ dtype=np.float64,
194
+ ),
195
+ )
196
+ return wass
197
+
198
+ return metric
199
+
200
+
201
+ def _sliced_wasserstein_distance_on_projections(meas1, meas2, scales=None):
202
+ """
203
+ This is a function for computing the sliced Wasserstein distance between two signed measures that have already been projected onto some lines. It simply amounts to comparing the sorted projections with the 1-norm, and averaging over the lines. See http://proceedings.mlr.press/v70/carriere17a.html for more details.
204
+
205
+ Parameters:
206
+ meas1: pair of (n x number_of_lines) numpy.arrays containing the projected points of the positive and negative parts of the first measure.
207
+ meas2: pair of (m x number_of_lines) numpy.arrays containing the projected points of the positive and negative parts of the second measure.
208
+ scales (array of shape D): scales associated to the dimensions.
209
+
210
+ Returns:
211
+ float: the sliced Wasserstein distance between the projected signed measures.
212
+ """
213
+ # assert np.array_equal( meas1[2], meas2[2] )
214
+ weights = meas1[2]
215
+ meas1_plus, meas1_minus = meas1[0], meas1[1]
216
+ meas2_plus, meas2_minus = meas2[0], meas2[1]
217
+ A = np.sort(np.vstack([meas1_plus, meas2_minus]), axis=0)
218
+ B = np.sort(np.vstack([meas2_plus, meas1_minus]), axis=0)
219
+ L1 = np.sum(np.abs(A - B), axis=0)
220
+ return np.mean(np.multiply(L1, weights))
221
+
222
+
223
+ def _sliced_wasserstein_distance(meas1, meas2, num_directions, scales=None):
224
+ """
225
+ This is a function for computing the sliced Wasserstein distance from two signed measures. The Sliced Wasserstein distance is computed by projecting the signed measures onto lines, comparing the projections with the 1-norm, and finally averaging over the lines. See http://proceedings.mlr.press/v70/carriere17a.html for more details.
226
+
227
+ Parameters:
228
+ meas1: ((n x D), (n)) tuple with numpy.array encoding the (finite points of the) first measure and their multiplicities. Must not contain essential points (i.e. with infinite coordinate).
229
+ meas2: ((m x D), (m)) tuple encoding the second measure.
230
+ num_directions (int): number of lines evenly sampled from [-pi/2,pi/2] in order to approximate and speed up the distance computation.
231
+ scales (array of shape D): scales associated to the dimensions.
232
+
233
+ Returns:
234
+ float: the sliced Wasserstein distance between signed measures.
235
+ """
236
+ C1, M1 = meas1[0], meas1[1]
237
+ C2, M2 = meas2[0], meas2[1]
238
+ dimension = C1.shape[1]
239
+ C1_plus_idxs, C2_plus_idxs = (
240
+ np.argwhere(M1 > 0).ravel(),
241
+ np.argwhere(M2 > 0).ravel(),
242
+ )
243
+ C1_minus_idxs, C2_minus_idxs = (
244
+ np.setdiff1d(np.arange(len(M1)), C1_plus_idxs),
245
+ np.setdiff1d(np.arange(len(M2)), C2_plus_idxs),
246
+ )
247
+ np.random.seed(42)
248
+ thetas = np.random.normal(0, 1, [num_directions, dimension])
249
+ lines = (thetas / np.linalg.norm(thetas, axis=1)[:, None]).T
250
+ weights = (
251
+ np.linalg.norm(np.multiply(scales[:, None], lines), axis=0)
252
+ if scales is not None
253
+ else np.ones(num_directions)
254
+ )
255
+ approx1 = np.matmul(
256
+ np.vstack(
257
+ [
258
+ np.repeat(C1[C1_plus_idxs], M1[C1_plus_idxs], axis=0),
259
+ np.repeat(C2[C2_minus_idxs], -M2[C2_minus_idxs], axis=0),
260
+ ]
261
+ ),
262
+ lines,
263
+ )
264
+ approx2 = np.matmul(
265
+ np.vstack(
266
+ [
267
+ np.repeat(C2[C2_plus_idxs], M2[C2_plus_idxs], axis=0),
268
+ np.repeat(C1[C1_minus_idxs], -M1[C1_minus_idxs], axis=0),
269
+ ]
270
+ ),
271
+ lines,
272
+ )
273
+ A = np.sort(approx1, axis=0)
274
+ B = np.sort(approx2, axis=0)
275
+ L1 = np.sum(np.abs(A - B), axis=0)
276
+ return np.mean(np.multiply(L1, weights))
277
+
278
+
279
def _wasserstein_distance(meas1, meas2, epsilon, ground_norm):
    """
    Compute the Wasserstein distance between two signed measures.

    Parameters:
        meas1: ((n x D), (n)) tuple with numpy.array encoding the (finite points of the) first measure and their multiplicities. Must not contain essential points (i.e. with infinite coordinate). Multiplicities must have an integer dtype.
        meas2: ((m x D), (m)) tuple encoding the second measure.
        epsilon (float): entropy regularization parameter. If positive, the Sinkhorn approximation is used; otherwise the exact transport cost is computed.
        ground_norm (int): norm to use for ground metric cost.

    Returns:
        float: the Wasserstein distance between signed measures.
    """
    C1, M1 = np.asarray(meas1[0]), np.asarray(meas1[1])
    C2, M2 = np.asarray(meas2[0]), np.asarray(meas2[1])
    pos1, pos2 = M1 > 0, M2 > 0
    neg1, neg2 = ~pos1, ~pos2
    # Compare (mu1+ + mu2-) against (mu2+ + mu1-), both being positive measures.
    approx1 = np.vstack(
        [
            np.repeat(C1[pos1], M1[pos1], axis=0),
            np.repeat(C2[neg2], -M2[neg2], axis=0),
        ]
    )
    approx2 = np.vstack(
        [
            np.repeat(C2[pos2], M2[pos2], axis=0),
            np.repeat(C1[neg1], -M1[neg1], axis=0),
        ]
    )
    # Each marginal is uniform over its own point cloud.
    n1, n2 = len(approx1), len(approx2)
    weights1 = np.full(n1, 1 / n1)
    weights2 = np.full(n2, 1 / n2)
    import ot

    cost = np.ascontiguousarray(
        pairwise_distances(approx1, approx2, metric="minkowski", p=ground_norm),
        dtype=np.float64,
    )
    if epsilon > 0:
        wass = ot.sinkhorn2(weights1, weights2, cost, epsilon)
        # Depending on the POT version the loss is a scalar or a length-1 array.
        return float(np.ravel(wass)[0])
    return ot.lp.emd2(weights1, weights2, cost)
334
+
335
+
336
class SlicedWassersteinDistance(BaseEstimator, TransformerMixin):
    """
    Scikit-learn style transformer computing sliced Wasserstein distance matrices between signed measures. The sliced Wasserstein distance is obtained by projecting the signed measures onto lines, comparing the projections with the 1-norm, and averaging over the lines. See http://proceedings.mlr.press/v70/carriere17a.html for more details.
    """

    def __init__(self, num_directions=10, scales=None, n_jobs=None):
        """
        Constructor for the SlicedWassersteinDistance class.

        Parameters:
            num_directions (int): number of lines used to approximate and speed up the distance computation (default 10).
            scales (array of shape D): scales associated to the dimensions.
            n_jobs (int): number of jobs to use for the computation. See :func:`pairwise_signed_measure_distances` for details.
        """
        self.n_jobs = n_jobs
        self.scales = scales
        self.num_directions = num_directions

    def fit(self, X, y=None):
        """
        Store the given signed measures, unchanged, in the attribute **measures_**. No projection is performed at this stage.

        Parameters:
            X (list of tuples): input signed measures.
            y (n x 1 array): signed measure labels (unused).

        Returns:
            the estimator itself.
        """
        self.measures_ = X
        return self

    def transform(self, X):
        """
        Compute all sliced Wasserstein distances between a given list of signed measures and the signed measures stored by the fit() method.

        Parameters:
            X (list of tuples): input signed measures.

        Returns:
            numpy array of shape (number of measures in X) x (number of measures in **measures_**): matrix of pairwise sliced Wasserstein distances.
        """
        options = {
            "num_directions": self.num_directions,
            "scales": self.scales,
        }
        return pairwise_signed_measure_distances(
            X,
            self.measures_,
            metric="sliced_wasserstein",
            n_jobs=self.n_jobs,
            **options,
        )

    def __call__(self, meas1, meas2):
        """
        Apply SlicedWassersteinDistance on a single pair of signed measures and output the result.

        Parameters:
            meas1: ((n x D), (n)) tuple with numpy.array encoding the (finite points of the) first measure and their multiplicities. Must not contain essential points (i.e. with infinite coordinate).
            meas2: ((m x D), (m)) tuple encoding the second measure.

        Returns:
            float: sliced Wasserstein distance.
        """
        directions, scales = self.num_directions, self.scales
        return _sliced_wasserstein_distance(
            meas1, meas2, num_directions=directions, scales=scales
        )
398
+
399
+
400
class WassersteinDistance(BaseEstimator, TransformerMixin):
    """
    Scikit-learn style transformer computing Wasserstein distance matrices between signed measures.
    """

    def __init__(self, epsilon=1.0, ground_norm=1, n_jobs=None):
        """
        Constructor for the WassersteinDistance class.

        Parameters:
            epsilon (float): entropy regularization parameter.
            ground_norm (int): norm to use for ground metric cost.
            n_jobs (int): number of jobs to use for the computation. See :func:`pairwise_signed_measure_distances` for details.
        """
        self.n_jobs = n_jobs
        self.ground_norm = ground_norm
        self.epsilon = epsilon

    def fit(self, X, y=None):
        """
        Store the given signed measures, unchanged, in the attribute **measures_**.

        Parameters:
            X (list of tuples): input signed measures.
            y (n x 1 array): signed measure labels (unused).

        Returns:
            the estimator itself.
        """
        self.measures_ = X
        return self

    def transform(self, X):
        """
        Compute all Wasserstein distances between a given list of signed measures and the signed measures stored by the fit() method.

        Parameters:
            X (list of tuples): input signed measures.

        Returns:
            numpy array of shape (number of measures in X) x (number of measures in **measures_**): matrix of pairwise Wasserstein distances.
        """
        options = {
            "epsilon": self.epsilon,
            "ground_norm": self.ground_norm,
        }
        return pairwise_signed_measure_distances(
            X,
            self.measures_,
            metric="wasserstein",
            n_jobs=self.n_jobs,
            **options,
        )

    def __call__(self, meas1, meas2):
        """
        Apply WassersteinDistance on a single pair of signed measures and output the result.

        Parameters:
            meas1: ((n x D), (n)) tuple with numpy.array encoding the (finite points of the) first measure and their multiplicities. Must not contain essential points (i.e. with infinite coordinate).
            meas2: ((m x D), (m)) tuple encoding the second measure.

        Returns:
            float: Wasserstein distance.
        """
        eps, norm = self.epsilon, self.ground_norm
        return _wasserstein_distance(meas1, meas2, epsilon=eps, ground_norm=norm)