multipers 2.3.0__cp310-cp310-win_amd64.whl → 2.3.2b1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of multipers might be problematic. Click here for more details.
- multipers/_signed_measure_meta.py +71 -65
- multipers/array_api/__init__.py +39 -0
- multipers/array_api/numpy.py +34 -0
- multipers/array_api/torch.py +35 -0
- multipers/distances.py +6 -2
- multipers/{ml/convolutions.py → filtrations/density.py} +67 -13
- multipers/filtrations/filtrations.py +76 -17
- multipers/function_rips.cp310-win_amd64.pyd +0 -0
- multipers/grids.cp310-win_amd64.pyd +0 -0
- multipers/grids.pyx +144 -61
- multipers/gudhi/Simplex_tree_multi_interface.h +36 -1
- multipers/gudhi/gudhi/Multi_persistence/Box.h +3 -0
- multipers/gudhi/gudhi/One_critical_filtration.h +18 -9
- multipers/gudhi/mma_interface_h0.h +1 -1
- multipers/gudhi/mma_interface_matrix.h +10 -1
- multipers/gudhi/naive_merge_tree.h +1 -1
- multipers/gudhi/truc.h +555 -42
- multipers/io.cp310-win_amd64.pyd +0 -0
- multipers/io.pyx +26 -93
- multipers/ml/mma.py +3 -3
- multipers/ml/point_clouds.py +2 -2
- multipers/ml/signed_measures.py +63 -65
- multipers/mma_structures.cp310-win_amd64.pyd +0 -0
- multipers/mma_structures.pxd +2 -1
- multipers/mma_structures.pyx +56 -16
- multipers/mma_structures.pyx.tp +14 -5
- multipers/multiparameter_module_approximation/approximation.h +48 -14
- multipers/multiparameter_module_approximation.cp310-win_amd64.pyd +0 -0
- multipers/multiparameter_module_approximation.pyx +25 -7
- multipers/plots.py +2 -1
- multipers/point_measure.cp310-win_amd64.pyd +0 -0
- multipers/point_measure.pyx +6 -2
- multipers/simplex_tree_multi.cp310-win_amd64.pyd +0 -0
- multipers/simplex_tree_multi.pxd +1 -0
- multipers/simplex_tree_multi.pyx +584 -142
- multipers/simplex_tree_multi.pyx.tp +80 -23
- multipers/slicer.cp310-win_amd64.pyd +0 -0
- multipers/slicer.pxd +799 -197
- multipers/slicer.pxd.tp +24 -5
- multipers/slicer.pyx +5653 -1426
- multipers/slicer.pyx.tp +208 -48
- multipers/tbb12.dll +0 -0
- multipers/tbbbind_2_5.dll +0 -0
- multipers/tbbmalloc.dll +0 -0
- multipers/tbbmalloc_proxy.dll +0 -0
- multipers/tensor/tensor.h +1 -1
- multipers/tests/__init__.py +9 -4
- multipers/torch/diff_grids.py +30 -7
- multipers/torch/rips_density.py +1 -1
- {multipers-2.3.0.dist-info → multipers-2.3.2b1.dist-info}/METADATA +4 -25
- {multipers-2.3.0.dist-info → multipers-2.3.2b1.dist-info}/RECORD +54 -51
- {multipers-2.3.0.dist-info → multipers-2.3.2b1.dist-info}/WHEEL +1 -1
- {multipers-2.3.0.dist-info → multipers-2.3.2b1.dist-info/licenses}/LICENSE +0 -0
- {multipers-2.3.0.dist-info → multipers-2.3.2b1.dist-info}/top_level.txt +0 -0
|
@@ -31,11 +31,10 @@ def signed_measure(
|
|
|
31
31
|
verbose: bool = False,
|
|
32
32
|
n_jobs: int = -1,
|
|
33
33
|
expand_collapse: bool = False,
|
|
34
|
-
backend: Optional[str] = None,
|
|
35
|
-
thread_id: str = "",
|
|
34
|
+
backend: Optional[str] = None, # deprecated
|
|
36
35
|
grid: Optional[Iterable] = None,
|
|
37
36
|
coordinate_measure: bool = False,
|
|
38
|
-
num_collapses: int = 0,
|
|
37
|
+
num_collapses: int = 0, # TODO : deprecate
|
|
39
38
|
clean: Optional[bool] = None,
|
|
40
39
|
vineyard: bool = False,
|
|
41
40
|
grid_conversion: Optional[Iterable] = None,
|
|
@@ -99,7 +98,13 @@ def signed_measure(
|
|
|
99
98
|
It is usually faster to use this backend if not in a parallel context.
|
|
100
99
|
- Rank: Same as Hilbert.
|
|
101
100
|
"""
|
|
101
|
+
if backend is not None:
|
|
102
|
+
raise ValueError("backend is deprecated. reduce the complex before this function.")
|
|
103
|
+
if num_collapses >0:
|
|
104
|
+
raise ValueError("num_collapses is deprecated. reduce the complex before this function.")
|
|
102
105
|
## TODO : add timings in verbose
|
|
106
|
+
if len(filtered_complex) == 0:
|
|
107
|
+
return [(np.empty((0,2), dtype=filtered_complex.dtype), np.empty(shape=(0,), dtype=int))]
|
|
103
108
|
if grid_conversion is not None:
|
|
104
109
|
grid = tuple(f for f in grid_conversion)
|
|
105
110
|
raise DeprecationWarning(
|
|
@@ -133,7 +138,7 @@ def signed_measure(
|
|
|
133
138
|
|
|
134
139
|
assert (
|
|
135
140
|
not plot or filtered_complex.num_parameters == 2
|
|
136
|
-
), "Can only plot 2d measures."
|
|
141
|
+
), f"Can only plot 2d measures. Got {filtered_complex.num_parameters=}."
|
|
137
142
|
|
|
138
143
|
if grid is None:
|
|
139
144
|
if not filtered_complex.is_squeezed:
|
|
@@ -141,7 +146,7 @@ def signed_measure(
|
|
|
141
146
|
filtered_complex, strategy=grid_strategy, **infer_grid_kwargs
|
|
142
147
|
)
|
|
143
148
|
else:
|
|
144
|
-
grid =
|
|
149
|
+
grid = filtered_complex.filtration_grid
|
|
145
150
|
|
|
146
151
|
if mass_default is None:
|
|
147
152
|
mass_default = mass_default
|
|
@@ -186,69 +191,70 @@ def signed_measure(
|
|
|
186
191
|
grid
|
|
187
192
|
), f"Number of parameter do not coincide. Got (grid) {len(grid)} and (filtered complex) {num_parameters}."
|
|
188
193
|
|
|
189
|
-
if is_simplextree_multi(filtered_complex_):
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
194
|
+
# if is_simplextree_multi(filtered_complex_):
|
|
195
|
+
# # if num_collapses != 0:
|
|
196
|
+
# # if verbose:
|
|
197
|
+
# # print("Collapsing edges...", end="")
|
|
198
|
+
# # filtered_complex_.collapse_edges(num_collapses)
|
|
199
|
+
# # if verbose:
|
|
200
|
+
# # print("Done.")
|
|
201
|
+
# # if backend is not None:
|
|
202
|
+
# # filtered_complex_ = mp.Slicer(filtered_complex_, vineyard=vineyard)
|
|
198
203
|
|
|
199
204
|
fix_mass_default = mass_default is not None
|
|
200
205
|
if is_slicer(filtered_complex_):
|
|
201
206
|
if verbose:
|
|
202
207
|
print("Input is a slicer.")
|
|
203
208
|
if backend is not None and not filtered_complex_.is_minpres:
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
209
|
+
raise ValueError("giving a backend to this function is deprecated")
|
|
210
|
+
# from multipers.slicer import minimal_presentation
|
|
211
|
+
#
|
|
212
|
+
# assert (
|
|
213
|
+
# invariant != "euler"
|
|
214
|
+
# ), "Euler Characteristic cannot be speed up by a backend"
|
|
215
|
+
# # This returns a list of reduced complexes
|
|
216
|
+
# if verbose:
|
|
217
|
+
# print("Reducing complex...", end="")
|
|
218
|
+
# reduced_complex = minimal_presentation(
|
|
219
|
+
# filtered_complex_,
|
|
220
|
+
# degrees=degrees,
|
|
221
|
+
# backend=backend,
|
|
222
|
+
# vineyard=vineyard,
|
|
223
|
+
# verbose=verbose,
|
|
224
|
+
# )
|
|
225
|
+
# if verbose:
|
|
226
|
+
# print("Done.")
|
|
227
|
+
# if invariant is not None and "rank" in invariant:
|
|
228
|
+
# if verbose:
|
|
229
|
+
# print("Computing rank...", end="")
|
|
230
|
+
# sms = [
|
|
231
|
+
# _rank_from_slicer(
|
|
232
|
+
# s,
|
|
233
|
+
# degrees=[d],
|
|
234
|
+
# n_jobs=n_jobs,
|
|
235
|
+
# # grid_shape=tuple(len(g) for g in grid),
|
|
236
|
+
# zero_pad=fix_mass_default,
|
|
237
|
+
# ignore_inf=ignore_infinite_filtration_values,
|
|
238
|
+
# )[0]
|
|
239
|
+
# for s, d in zip(reduced_complex, degrees)
|
|
240
|
+
# ]
|
|
241
|
+
# fix_mass_default = False
|
|
242
|
+
# if verbose:
|
|
243
|
+
# print("Done.")
|
|
244
|
+
# else:
|
|
245
|
+
# if verbose:
|
|
246
|
+
# print("Reduced slicer. Retrieving measure from it...", end="")
|
|
247
|
+
# sms = [
|
|
248
|
+
# _signed_measure_from_slicer(
|
|
249
|
+
# s,
|
|
250
|
+
# shift=(
|
|
251
|
+
# reduced_complex.minpres_degree & 1 if d is None else d & 1
|
|
252
|
+
# ),
|
|
253
|
+
# )[0]
|
|
254
|
+
# for s, d in zip(reduced_complex, degrees)
|
|
255
|
+
# ]
|
|
256
|
+
# if verbose:
|
|
257
|
+
# print("Done.")
|
|
252
258
|
else: # No backend
|
|
253
259
|
if invariant is not None and "rank" in invariant:
|
|
254
260
|
degrees = np.asarray(degrees, dtype=int)
|
|
@@ -272,7 +278,7 @@ def signed_measure(
|
|
|
272
278
|
_signed_measure_from_slicer(
|
|
273
279
|
filtered_complex_,
|
|
274
280
|
shift=(
|
|
275
|
-
filtered_complex_.minpres_degree
|
|
281
|
+
filtered_complex_.minpres_degree & 1 if d is None else d & 1
|
|
276
282
|
),
|
|
277
283
|
)[0]
|
|
278
284
|
for d in degrees
|
|
@@ -385,7 +391,7 @@ def signed_measure(
|
|
|
385
391
|
sms,
|
|
386
392
|
grid=grid,
|
|
387
393
|
mass_default=mass_default,
|
|
388
|
-
num_parameters=num_parameters,
|
|
394
|
+
# num_parameters=num_parameters,
|
|
389
395
|
)
|
|
390
396
|
if verbose:
|
|
391
397
|
print("Done.")
|
|
@@ -408,7 +414,7 @@ def _signed_measure_from_scc(
|
|
|
408
414
|
pts = np.concatenate([b[0] for b in minimal_presentation])
|
|
409
415
|
weights = np.concatenate(
|
|
410
416
|
[
|
|
411
|
-
(1 - 2 * (i
|
|
417
|
+
(1 - 2 * (i & 1)) * np.ones(len(b[0]))
|
|
412
418
|
for i, b in enumerate(minimal_presentation)
|
|
413
419
|
]
|
|
414
420
|
)
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
def api_from_tensor(x, *, verbose: bool = False):
|
|
2
|
+
import multipers.array_api.numpy as npapi
|
|
3
|
+
|
|
4
|
+
if npapi.is_promotable(x):
|
|
5
|
+
if verbose:
|
|
6
|
+
print("using numpy backend")
|
|
7
|
+
return npapi
|
|
8
|
+
import multipers.array_api.torch as torchapi
|
|
9
|
+
|
|
10
|
+
if torchapi.is_promotable(x):
|
|
11
|
+
if verbose:
|
|
12
|
+
print("using torch backend")
|
|
13
|
+
return torchapi
|
|
14
|
+
raise ValueError(f"Unsupported type {type(x)=}")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def api_from_tensors(*args):
|
|
18
|
+
assert len(args) > 0, "no tensor given"
|
|
19
|
+
import multipers.array_api.numpy as npapi
|
|
20
|
+
|
|
21
|
+
is_numpy = True
|
|
22
|
+
for x in args:
|
|
23
|
+
if not npapi.is_promotable(x):
|
|
24
|
+
is_numpy = False
|
|
25
|
+
break
|
|
26
|
+
if is_numpy:
|
|
27
|
+
return npapi
|
|
28
|
+
|
|
29
|
+
# only torch for now
|
|
30
|
+
import multipers.array_api.torch as torchapi
|
|
31
|
+
|
|
32
|
+
is_torch = True
|
|
33
|
+
for x in args:
|
|
34
|
+
if not torchapi.is_promotable(x):
|
|
35
|
+
is_torch = False
|
|
36
|
+
break
|
|
37
|
+
if is_torch:
|
|
38
|
+
return torchapi
|
|
39
|
+
raise ValueError(f"Incompatible types got {[type(x) for x in args]=}.")
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from contextlib import nullcontext
|
|
2
|
+
|
|
3
|
+
import numpy as _np
|
|
4
|
+
from scipy.spatial.distance import cdist
|
|
5
|
+
|
|
6
|
+
backend = _np
|
|
7
|
+
cat = _np.concatenate
|
|
8
|
+
norm = _np.linalg.norm
|
|
9
|
+
astensor = _np.asarray
|
|
10
|
+
asnumpy = _np.asarray
|
|
11
|
+
tensor = _np.array
|
|
12
|
+
stack = _np.stack
|
|
13
|
+
empty = _np.empty
|
|
14
|
+
where = _np.where
|
|
15
|
+
no_grad = nullcontext
|
|
16
|
+
zeros = _np.zeros
|
|
17
|
+
min = _np.min
|
|
18
|
+
max = _np.max
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def minvalues(x: _np.ndarray, **kwargs):
|
|
22
|
+
return _np.min(x, **kwargs)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def maxvalues(x: _np.ndarray, **kwargs):
|
|
26
|
+
return _np.max(x, **kwargs)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def is_promotable(x):
|
|
30
|
+
return isinstance(x, _np.ndarray | list | tuple)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def has_grad(_):
|
|
34
|
+
return False
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
import torch as _t
|
|
2
|
+
|
|
3
|
+
backend = _t
|
|
4
|
+
cat = _t.cat
|
|
5
|
+
norm = _t.norm
|
|
6
|
+
astensor = _t.as_tensor
|
|
7
|
+
tensor = _t.tensor
|
|
8
|
+
stack = _t.stack
|
|
9
|
+
empty = _t.empty
|
|
10
|
+
where = _t.where
|
|
11
|
+
no_grad = _t.no_grad
|
|
12
|
+
cdist = _t.cdist
|
|
13
|
+
zeros = _t.zeros
|
|
14
|
+
min = _t.min
|
|
15
|
+
max = _t.max
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def minvalues(x: _t.Tensor, **kwargs):
|
|
19
|
+
return _t.min(x, **kwargs).values
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def maxvalues(x: _t.Tensor, **kwargs):
|
|
23
|
+
return _t.max(x, **kwargs).values
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def asnumpy(x):
|
|
27
|
+
return x.detach().numpy()
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def is_promotable(x):
|
|
31
|
+
return isinstance(x, _t.Tensor)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def has_grad(x):
|
|
35
|
+
return x.requires_grad
|
multipers/distances.py
CHANGED
|
@@ -6,7 +6,7 @@ from multipers.multiparameter_module_approximation import PyModule_type
|
|
|
6
6
|
from multipers.simplex_tree_multi import SimplexTreeMulti_type
|
|
7
7
|
|
|
8
8
|
|
|
9
|
-
def sm2diff(sm1, sm2):
|
|
9
|
+
def sm2diff(sm1, sm2, threshold=None):
|
|
10
10
|
pts = sm1[0]
|
|
11
11
|
dtype = pts.dtype
|
|
12
12
|
if isinstance(pts, np.ndarray):
|
|
@@ -45,6 +45,9 @@ def sm2diff(sm1, sm2):
|
|
|
45
45
|
)
|
|
46
46
|
x = backend_concatenate(pts1[pos_indices1], pts2[neg_indices2])
|
|
47
47
|
y = backend_concatenate(pts1[neg_indices1], pts2[pos_indices2])
|
|
48
|
+
if threshold is not None:
|
|
49
|
+
x[x>threshold]=threshold
|
|
50
|
+
y[y>threshold]=threshold
|
|
48
51
|
return x, y
|
|
49
52
|
|
|
50
53
|
|
|
@@ -55,6 +58,7 @@ def sm_distance(
|
|
|
55
58
|
reg_m: float = 0,
|
|
56
59
|
numItermax: int = 10000,
|
|
57
60
|
p: float = 1,
|
|
61
|
+
threshold=None,
|
|
58
62
|
):
|
|
59
63
|
"""
|
|
60
64
|
Computes the wasserstein distances between two signed measures,
|
|
@@ -68,7 +72,7 @@ def sm_distance(
|
|
|
68
72
|
- sinkhorn if reg != 0
|
|
69
73
|
- sinkhorn unbalanced if reg_m != 0
|
|
70
74
|
"""
|
|
71
|
-
x, y = sm2diff(sm1, sm2)
|
|
75
|
+
x, y = sm2diff(sm1, sm2, threshold=threshold)
|
|
72
76
|
loss = ot.dist(
|
|
73
77
|
x, y, metric="sqeuclidean", p=p
|
|
74
78
|
) # only euc + sqeuclidian are implemented in pot for the moment with torch backend # TODO : check later
|
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
from collections.abc import Callable, Iterable
|
|
2
2
|
from typing import Any, Literal, Union
|
|
3
|
-
|
|
4
3
|
import numpy as np
|
|
5
4
|
|
|
5
|
+
|
|
6
|
+
from multipers.array_api import api_from_tensor
|
|
6
7
|
global available_kernels
|
|
7
8
|
available_kernels = Union[
|
|
8
9
|
Literal[
|
|
@@ -41,13 +42,14 @@ def convolution_signed_measures(
|
|
|
41
42
|
from multipers.grids import todense
|
|
42
43
|
|
|
43
44
|
grid_iterator = todense(filtrations, product_order=True)
|
|
45
|
+
api = api_from_tensor(iterable_of_signed_measures[0][0][0])
|
|
44
46
|
match backend:
|
|
45
47
|
case "sklearn":
|
|
46
48
|
|
|
47
49
|
def convolution_signed_measures_on_grid(
|
|
48
|
-
signed_measures
|
|
50
|
+
signed_measures,
|
|
49
51
|
):
|
|
50
|
-
return
|
|
52
|
+
return api.cat(
|
|
51
53
|
[
|
|
52
54
|
_pts_convolution_sparse_old(
|
|
53
55
|
pts=pts,
|
|
@@ -67,7 +69,7 @@ def convolution_signed_measures(
|
|
|
67
69
|
def convolution_signed_measures_on_grid(
|
|
68
70
|
signed_measures: Iterable[tuple[np.ndarray, np.ndarray]],
|
|
69
71
|
) -> np.ndarray:
|
|
70
|
-
return
|
|
72
|
+
return api.cat(
|
|
71
73
|
[
|
|
72
74
|
_pts_convolution_pykeops(
|
|
73
75
|
pts=pts,
|
|
@@ -111,7 +113,7 @@ def convolution_signed_measures(
|
|
|
111
113
|
if not flatten:
|
|
112
114
|
out_shape = [-1] + [len(f) for f in filtrations] # Degree
|
|
113
115
|
convolutions = [x.reshape(out_shape) for x in convolutions]
|
|
114
|
-
return
|
|
116
|
+
return api.cat([x[None] for x in convolutions])
|
|
115
117
|
|
|
116
118
|
|
|
117
119
|
# def _test(r=1000, b=0.5, plot=True, kernel=0):
|
|
@@ -173,16 +175,24 @@ def _pts_convolution_pykeops(
|
|
|
173
175
|
"""
|
|
174
176
|
Pykeops convolution
|
|
175
177
|
"""
|
|
178
|
+
if isinstance(pts, np.ndarray):
|
|
179
|
+
_asarray_weights = lambda x : np.asarray(x, dtype=pts.dtype)
|
|
180
|
+
_asarray_grid = _asarray_weights
|
|
181
|
+
else:
|
|
182
|
+
import torch
|
|
183
|
+
_asarray_weights = lambda x : torch.from_numpy(x).type(pts.dtype)
|
|
184
|
+
_asarray_grid = lambda x : x.type(pts.dtype)
|
|
176
185
|
kde = KDE(kernel=kernel, bandwidth=bandwidth, **more_kde_args)
|
|
177
186
|
return kde.fit(
|
|
178
|
-
pts, sample_weights=
|
|
179
|
-
).score_samples(
|
|
187
|
+
pts, sample_weights=_asarray_weights(pts_weights)
|
|
188
|
+
).score_samples(_asarray_grid(grid_iterator))
|
|
180
189
|
|
|
181
190
|
|
|
182
191
|
def gaussian_kernel(x_i, y_j, bandwidth):
|
|
192
|
+
D = x_i.shape[-1]
|
|
183
193
|
exponent = -(((x_i - y_j) / bandwidth) ** 2).sum(dim=-1) / 2
|
|
184
194
|
# float is necessary for some reason (pykeops fails)
|
|
185
|
-
kernel = (exponent).exp() / (bandwidth
|
|
195
|
+
kernel = (exponent).exp() / float((bandwidth*np.sqrt(2 * np.pi))**D)
|
|
186
196
|
return kernel
|
|
187
197
|
|
|
188
198
|
|
|
@@ -290,10 +300,10 @@ class KDE:
|
|
|
290
300
|
X.reshape((X.shape[0], 1, X.shape[1]))
|
|
291
301
|
) # numpts, 1, dim
|
|
292
302
|
lazy_y = LazyTensor(
|
|
293
|
-
Y.reshape((1, Y.shape[0], Y.shape[1]))
|
|
303
|
+
Y.reshape((1, Y.shape[0], Y.shape[1])).astype(X.dtype)
|
|
294
304
|
) # 1, numpts, dim
|
|
295
305
|
if x_weights is not None:
|
|
296
|
-
w = LazyTensor(x_weights[:, None], axis=0)
|
|
306
|
+
w = LazyTensor(np.asarray(x_weights, dtype=X.dtype)[:, None], axis=0)
|
|
297
307
|
return lazy_x, lazy_y, w
|
|
298
308
|
return lazy_x, lazy_y, None
|
|
299
309
|
import torch
|
|
@@ -302,9 +312,11 @@ class KDE:
|
|
|
302
312
|
from pykeops.torch import LazyTensor
|
|
303
313
|
|
|
304
314
|
lazy_x = LazyTensor(X.view(X.shape[0], 1, X.shape[1]))
|
|
305
|
-
lazy_y = LazyTensor(Y.view(1, Y.shape[0], Y.shape[1]))
|
|
315
|
+
lazy_y = LazyTensor(Y.type(X.dtype).view(1, Y.shape[0], Y.shape[1]))
|
|
306
316
|
if x_weights is not None:
|
|
307
|
-
|
|
317
|
+
if isinstance(x_weights, np.ndarray):
|
|
318
|
+
x_weights = torch.from_numpy(x_weights)
|
|
319
|
+
w = LazyTensor(x_weights[:, None].type(X.dtype), axis=0)
|
|
308
320
|
return lazy_x, lazy_y, w
|
|
309
321
|
return lazy_x, lazy_y, None
|
|
310
322
|
raise Exception("Bad tensor type.")
|
|
@@ -339,7 +351,7 @@ class KDE:
|
|
|
339
351
|
kernel *= w
|
|
340
352
|
if return_kernel:
|
|
341
353
|
return kernel
|
|
342
|
-
density_estimation = kernel.sum(dim=0).
|
|
354
|
+
density_estimation = kernel.sum(dim=0).squeeze() / kernel.shape[0] # mean
|
|
343
355
|
return (
|
|
344
356
|
self._backend.log(density_estimation)
|
|
345
357
|
if self.return_log
|
|
@@ -497,6 +509,48 @@ class DTM:
|
|
|
497
509
|
return DTMs
|
|
498
510
|
|
|
499
511
|
|
|
512
|
+
## code taken from pykeops doc (https://www.kernel-operations.io/keops/_auto_benchmarks/benchmark_KNN.html)
|
|
513
|
+
class KNNmean:
|
|
514
|
+
def __init__(self, k: int, metric: str = "euclidean"):
|
|
515
|
+
self.k = k
|
|
516
|
+
self.metric = metric
|
|
517
|
+
self._KNN_fun = None
|
|
518
|
+
self._x = None
|
|
519
|
+
|
|
520
|
+
def fit(self, x):
|
|
521
|
+
if isinstance(x, np.ndarray):
|
|
522
|
+
from pykeops.numpy import Vi, Vj
|
|
523
|
+
else:
|
|
524
|
+
import torch
|
|
525
|
+
|
|
526
|
+
assert isinstance(x, torch.Tensor), "Backend has to be numpy or torch"
|
|
527
|
+
from pykeops.torch import Vi, Vj
|
|
528
|
+
|
|
529
|
+
D = x.shape[1]
|
|
530
|
+
X_i = Vi(0, D)
|
|
531
|
+
X_j = Vj(1, D)
|
|
532
|
+
|
|
533
|
+
# Symbolic distance matrix:
|
|
534
|
+
if self.metric == "euclidean":
|
|
535
|
+
D_ij = ((X_i - X_j) ** 2).sum(-1) ** (1/2)
|
|
536
|
+
elif self.metric == "manhattan":
|
|
537
|
+
D_ij = (X_i - X_j).abs().sum(-1)
|
|
538
|
+
elif self.metric == "angular":
|
|
539
|
+
D_ij = -(X_i | X_j)
|
|
540
|
+
elif self.metric == "hyperbolic":
|
|
541
|
+
D_ij = ((X_i - X_j) ** 2).sum(-1) / (X_i[0] * X_j[0])
|
|
542
|
+
else:
|
|
543
|
+
raise NotImplementedError(f"The '{self.metric}' distance is not supported.")
|
|
544
|
+
|
|
545
|
+
self._x = x
|
|
546
|
+
self._KNN_fun = D_ij.Kmin(self.k, dim=1)
|
|
547
|
+
return self
|
|
548
|
+
|
|
549
|
+
def score_samples(self, x):
|
|
550
|
+
assert self._x is not None and self._KNN_fun is not None, "Fit first."
|
|
551
|
+
return self._KNN_fun(x, self._x).sum(axis=1) / self.k
|
|
552
|
+
|
|
553
|
+
|
|
500
554
|
# def _pts_convolution_sparse(pts:np.ndarray, pts_weights:np.ndarray, filtration_grid:Iterable[np.ndarray], kernel="gaussian", bandwidth=0.1, **more_kde_args):
|
|
501
555
|
# """
|
|
502
556
|
# Old version of `convolution_signed_measures`. Scikitlearn's convolution is slower than the code above.
|