multipers 2.3.1__cp310-cp310-win_amd64.whl → 2.3.2__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic. Click here for more details.

Files changed (49) hide show
  1. multipers/_signed_measure_meta.py +71 -65
  2. multipers/array_api/__init__.py +39 -0
  3. multipers/array_api/numpy.py +34 -0
  4. multipers/array_api/torch.py +35 -0
  5. multipers/distances.py +6 -2
  6. multipers/filtrations/density.py +23 -12
  7. multipers/filtrations/filtrations.py +74 -15
  8. multipers/function_rips.cp310-win_amd64.pyd +0 -0
  9. multipers/grids.cp310-win_amd64.pyd +0 -0
  10. multipers/grids.pyx +144 -61
  11. multipers/gudhi/Simplex_tree_multi_interface.h +35 -0
  12. multipers/gudhi/gudhi/Multi_persistence/Box.h +3 -0
  13. multipers/gudhi/gudhi/One_critical_filtration.h +17 -9
  14. multipers/gudhi/mma_interface_matrix.h +5 -3
  15. multipers/gudhi/truc.h +488 -42
  16. multipers/io.cp310-win_amd64.pyd +0 -0
  17. multipers/io.pyx +16 -86
  18. multipers/ml/mma.py +4 -4
  19. multipers/ml/signed_measures.py +60 -62
  20. multipers/mma_structures.cp310-win_amd64.pyd +0 -0
  21. multipers/mma_structures.pxd +2 -1
  22. multipers/mma_structures.pyx +56 -12
  23. multipers/mma_structures.pyx.tp +14 -3
  24. multipers/multiparameter_module_approximation/approximation.h +45 -13
  25. multipers/multiparameter_module_approximation.cp310-win_amd64.pyd +0 -0
  26. multipers/multiparameter_module_approximation.pyx +24 -7
  27. multipers/plots.py +1 -0
  28. multipers/point_measure.cp310-win_amd64.pyd +0 -0
  29. multipers/point_measure.pyx +6 -2
  30. multipers/simplex_tree_multi.cp310-win_amd64.pyd +0 -0
  31. multipers/simplex_tree_multi.pxd +1 -0
  32. multipers/simplex_tree_multi.pyx +535 -113
  33. multipers/simplex_tree_multi.pyx.tp +79 -19
  34. multipers/slicer.cp310-win_amd64.pyd +0 -0
  35. multipers/slicer.pxd +719 -237
  36. multipers/slicer.pxd.tp +22 -6
  37. multipers/slicer.pyx +5315 -1365
  38. multipers/slicer.pyx.tp +202 -46
  39. multipers/tbb12.dll +0 -0
  40. multipers/tbbbind_2_5.dll +0 -0
  41. multipers/tbbmalloc.dll +0 -0
  42. multipers/tbbmalloc_proxy.dll +0 -0
  43. multipers/tests/__init__.py +9 -4
  44. multipers/torch/diff_grids.py +30 -7
  45. {multipers-2.3.1.dist-info → multipers-2.3.2.dist-info}/METADATA +4 -25
  46. {multipers-2.3.1.dist-info → multipers-2.3.2.dist-info}/RECORD +49 -46
  47. {multipers-2.3.1.dist-info → multipers-2.3.2.dist-info}/WHEEL +1 -1
  48. {multipers-2.3.1.dist-info → multipers-2.3.2.dist-info/licenses}/LICENSE +0 -0
  49. {multipers-2.3.1.dist-info → multipers-2.3.2.dist-info}/top_level.txt +0 -0
@@ -31,11 +31,10 @@ def signed_measure(
31
31
  verbose: bool = False,
32
32
  n_jobs: int = -1,
33
33
  expand_collapse: bool = False,
34
- backend: Optional[str] = None,
35
- thread_id: str = "",
34
+ backend: Optional[str] = None, # deprecated
36
35
  grid: Optional[Iterable] = None,
37
36
  coordinate_measure: bool = False,
38
- num_collapses: int = 0,
37
+ num_collapses: int = 0, # TODO : deprecate
39
38
  clean: Optional[bool] = None,
40
39
  vineyard: bool = False,
41
40
  grid_conversion: Optional[Iterable] = None,
@@ -99,7 +98,13 @@ def signed_measure(
99
98
  It is usually faster to use this backend if not in a parallel context.
100
99
  - Rank: Same as Hilbert.
101
100
  """
101
+ if backend is not None:
102
+ raise ValueError("backend is deprecated. reduce the complex before this function.")
103
+ if num_collapses >0:
104
+ raise ValueError("num_collapses is deprecated. reduce the complex before this function.")
102
105
  ## TODO : add timings in verbose
106
+ if len(filtered_complex) == 0:
107
+ return [(np.empty((0,2), dtype=filtered_complex.dtype), np.empty(shape=(0,), dtype=int))]
103
108
  if grid_conversion is not None:
104
109
  grid = tuple(f for f in grid_conversion)
105
110
  raise DeprecationWarning(
@@ -133,7 +138,7 @@ def signed_measure(
133
138
 
134
139
  assert (
135
140
  not plot or filtered_complex.num_parameters == 2
136
- ), "Can only plot 2d measures."
141
+ ), f"Can only plot 2d measures. Got {filtered_complex.num_parameters=}."
137
142
 
138
143
  if grid is None:
139
144
  if not filtered_complex.is_squeezed:
@@ -141,7 +146,7 @@ def signed_measure(
141
146
  filtered_complex, strategy=grid_strategy, **infer_grid_kwargs
142
147
  )
143
148
  else:
144
- grid = tuple(np.asarray(f) for f in filtered_complex.filtration_grid)
149
+ grid = filtered_complex.filtration_grid
145
150
 
146
151
  if mass_default is None:
147
152
  mass_default = mass_default
@@ -186,69 +191,70 @@ def signed_measure(
186
191
  grid
187
192
  ), f"Number of parameter do not coincide. Got (grid) {len(grid)} and (filtered complex) {num_parameters}."
188
193
 
189
- if is_simplextree_multi(filtered_complex_):
190
- if num_collapses != 0:
191
- if verbose:
192
- print("Collapsing edges...", end="")
193
- filtered_complex_.collapse_edges(num_collapses)
194
- if verbose:
195
- print("Done.")
196
- if backend is not None:
197
- filtered_complex_ = mp.Slicer(filtered_complex_, vineyard=vineyard)
194
+ # if is_simplextree_multi(filtered_complex_):
195
+ # # if num_collapses != 0:
196
+ # # if verbose:
197
+ # # print("Collapsing edges...", end="")
198
+ # # filtered_complex_.collapse_edges(num_collapses)
199
+ # # if verbose:
200
+ # # print("Done.")
201
+ # # if backend is not None:
202
+ # # filtered_complex_ = mp.Slicer(filtered_complex_, vineyard=vineyard)
198
203
 
199
204
  fix_mass_default = mass_default is not None
200
205
  if is_slicer(filtered_complex_):
201
206
  if verbose:
202
207
  print("Input is a slicer.")
203
208
  if backend is not None and not filtered_complex_.is_minpres:
204
- from multipers.slicer import minimal_presentation
205
-
206
- assert (
207
- invariant != "euler"
208
- ), "Euler Characteristic cannot be speed up by a backend"
209
- # This returns a list of reduced complexes
210
- if verbose:
211
- print("Reducing complex...", end="")
212
- reduced_complex = minimal_presentation(
213
- filtered_complex_,
214
- degrees=degrees,
215
- backend=backend,
216
- vineyard=vineyard,
217
- verbose=verbose,
218
- )
219
- if verbose:
220
- print("Done.")
221
- if invariant is not None and "rank" in invariant:
222
- if verbose:
223
- print("Computing rank...", end="")
224
- sms = [
225
- _rank_from_slicer(
226
- s,
227
- degrees=[d],
228
- n_jobs=n_jobs,
229
- # grid_shape=tuple(len(g) for g in grid),
230
- zero_pad=fix_mass_default,
231
- ignore_inf=ignore_infinite_filtration_values,
232
- )[0]
233
- for s, d in zip(reduced_complex, degrees)
234
- ]
235
- fix_mass_default = False
236
- if verbose:
237
- print("Done.")
238
- else:
239
- if verbose:
240
- print("Reduced slicer. Retrieving measure from it...", end="")
241
- sms = [
242
- _signed_measure_from_slicer(
243
- s,
244
- shift=(
245
- reduced_complex.minpres_degree % 2 if d is None else d % 2
246
- ),
247
- )[0]
248
- for s, d in zip(reduced_complex, degrees)
249
- ]
250
- if verbose:
251
- print("Done.")
209
+ raise ValueError("giving a backend to this function is deprecated")
210
+ # from multipers.slicer import minimal_presentation
211
+ #
212
+ # assert (
213
+ # invariant != "euler"
214
+ # ), "Euler Characteristic cannot be speed up by a backend"
215
+ # # This returns a list of reduced complexes
216
+ # if verbose:
217
+ # print("Reducing complex...", end="")
218
+ # reduced_complex = minimal_presentation(
219
+ # filtered_complex_,
220
+ # degrees=degrees,
221
+ # backend=backend,
222
+ # vineyard=vineyard,
223
+ # verbose=verbose,
224
+ # )
225
+ # if verbose:
226
+ # print("Done.")
227
+ # if invariant is not None and "rank" in invariant:
228
+ # if verbose:
229
+ # print("Computing rank...", end="")
230
+ # sms = [
231
+ # _rank_from_slicer(
232
+ # s,
233
+ # degrees=[d],
234
+ # n_jobs=n_jobs,
235
+ # # grid_shape=tuple(len(g) for g in grid),
236
+ # zero_pad=fix_mass_default,
237
+ # ignore_inf=ignore_infinite_filtration_values,
238
+ # )[0]
239
+ # for s, d in zip(reduced_complex, degrees)
240
+ # ]
241
+ # fix_mass_default = False
242
+ # if verbose:
243
+ # print("Done.")
244
+ # else:
245
+ # if verbose:
246
+ # print("Reduced slicer. Retrieving measure from it...", end="")
247
+ # sms = [
248
+ # _signed_measure_from_slicer(
249
+ # s,
250
+ # shift=(
251
+ # reduced_complex.minpres_degree & 1 if d is None else d & 1
252
+ # ),
253
+ # )[0]
254
+ # for s, d in zip(reduced_complex, degrees)
255
+ # ]
256
+ # if verbose:
257
+ # print("Done.")
252
258
  else: # No backend
253
259
  if invariant is not None and "rank" in invariant:
254
260
  degrees = np.asarray(degrees, dtype=int)
@@ -272,7 +278,7 @@ def signed_measure(
272
278
  _signed_measure_from_slicer(
273
279
  filtered_complex_,
274
280
  shift=(
275
- filtered_complex_.minpres_degree % 2 if d is None else d % 2
281
+ filtered_complex_.minpres_degree & 1 if d is None else d & 1
276
282
  ),
277
283
  )[0]
278
284
  for d in degrees
@@ -385,7 +391,7 @@ def signed_measure(
385
391
  sms,
386
392
  grid=grid,
387
393
  mass_default=mass_default,
388
- num_parameters=num_parameters,
394
+ # num_parameters=num_parameters,
389
395
  )
390
396
  if verbose:
391
397
  print("Done.")
@@ -408,7 +414,7 @@ def _signed_measure_from_scc(
408
414
  pts = np.concatenate([b[0] for b in minimal_presentation])
409
415
  weights = np.concatenate(
410
416
  [
411
- (1 - 2 * (i % 2)) * np.ones(len(b[0]))
417
+ (1 - 2 * (i & 1)) * np.ones(len(b[0]))
412
418
  for i, b in enumerate(minimal_presentation)
413
419
  ]
414
420
  )
@@ -0,0 +1,39 @@
1
+ def api_from_tensor(x, *, verbose: bool = False):
2
+ import multipers.array_api.numpy as npapi
3
+
4
+ if npapi.is_promotable(x):
5
+ if verbose:
6
+ print("using numpy backend")
7
+ return npapi
8
+ import multipers.array_api.torch as torchapi
9
+
10
+ if torchapi.is_promotable(x):
11
+ if verbose:
12
+ print("using torch backend")
13
+ return torchapi
14
+ raise ValueError(f"Unsupported type {type(x)=}")
15
+
16
+
17
+ def api_from_tensors(*args):
18
+ assert len(args) > 0, "no tensor given"
19
+ import multipers.array_api.numpy as npapi
20
+
21
+ is_numpy = True
22
+ for x in args:
23
+ if not npapi.is_promotable(x):
24
+ is_numpy = False
25
+ break
26
+ if is_numpy:
27
+ return npapi
28
+
29
+ # only torch for now
30
+ import multipers.array_api.torch as torchapi
31
+
32
+ is_torch = True
33
+ for x in args:
34
+ if not torchapi.is_promotable(x):
35
+ is_torch = False
36
+ break
37
+ if is_torch:
38
+ return torchapi
39
+ raise ValueError(f"Incompatible types got {[type(x) for x in args]=}.")
@@ -0,0 +1,34 @@
1
+ from contextlib import nullcontext
2
+
3
+ import numpy as _np
4
+ from scipy.spatial.distance import cdist
5
+
6
+ backend = _np
7
+ cat = _np.concatenate
8
+ norm = _np.linalg.norm
9
+ astensor = _np.asarray
10
+ asnumpy = _np.asarray
11
+ tensor = _np.array
12
+ stack = _np.stack
13
+ empty = _np.empty
14
+ where = _np.where
15
+ no_grad = nullcontext
16
+ zeros = _np.zeros
17
+ min = _np.min
18
+ max = _np.max
19
+
20
+
21
+ def minvalues(x: _np.ndarray, **kwargs):
22
+ return _np.min(x, **kwargs)
23
+
24
+
25
+ def maxvalues(x: _np.ndarray, **kwargs):
26
+ return _np.max(x, **kwargs)
27
+
28
+
29
+ def is_promotable(x):
30
+ return isinstance(x, _np.ndarray | list | tuple)
31
+
32
+
33
+ def has_grad(_):
34
+ return False
@@ -0,0 +1,35 @@
1
+ import torch as _t
2
+
3
+ backend = _t
4
+ cat = _t.cat
5
+ norm = _t.norm
6
+ astensor = _t.as_tensor
7
+ tensor = _t.tensor
8
+ stack = _t.stack
9
+ empty = _t.empty
10
+ where = _t.where
11
+ no_grad = _t.no_grad
12
+ cdist = _t.cdist
13
+ zeros = _t.zeros
14
+ min = _t.min
15
+ max = _t.max
16
+
17
+
18
+ def minvalues(x: _t.Tensor, **kwargs):
19
+ return _t.min(x, **kwargs).values
20
+
21
+
22
+ def maxvalues(x: _t.Tensor, **kwargs):
23
+ return _t.max(x, **kwargs).values
24
+
25
+
26
+ def asnumpy(x):
27
+ return x.detach().numpy()
28
+
29
+
30
+ def is_promotable(x):
31
+ return isinstance(x, _t.Tensor)
32
+
33
+
34
+ def has_grad(x):
35
+ return x.requires_grad
multipers/distances.py CHANGED
@@ -6,7 +6,7 @@ from multipers.multiparameter_module_approximation import PyModule_type
6
6
  from multipers.simplex_tree_multi import SimplexTreeMulti_type
7
7
 
8
8
 
9
- def sm2diff(sm1, sm2):
9
+ def sm2diff(sm1, sm2, threshold=None):
10
10
  pts = sm1[0]
11
11
  dtype = pts.dtype
12
12
  if isinstance(pts, np.ndarray):
@@ -45,6 +45,9 @@ def sm2diff(sm1, sm2):
45
45
  )
46
46
  x = backend_concatenate(pts1[pos_indices1], pts2[neg_indices2])
47
47
  y = backend_concatenate(pts1[neg_indices1], pts2[pos_indices2])
48
+ if threshold is not None:
49
+ x[x>threshold]=threshold
50
+ y[y>threshold]=threshold
48
51
  return x, y
49
52
 
50
53
 
@@ -55,6 +58,7 @@ def sm_distance(
55
58
  reg_m: float = 0,
56
59
  numItermax: int = 10000,
57
60
  p: float = 1,
61
+ threshold=None,
58
62
  ):
59
63
  """
60
64
  Computes the wasserstein distances between two signed measures,
@@ -68,7 +72,7 @@ def sm_distance(
68
72
  - sinkhorn if reg != 0
69
73
  - sinkhorn unbalanced if reg_m != 0
70
74
  """
71
- x, y = sm2diff(sm1, sm2)
75
+ x, y = sm2diff(sm1, sm2, threshold=threshold)
72
76
  loss = ot.dist(
73
77
  x, y, metric="sqeuclidean", p=p
74
78
  ) # only euc + sqeuclidian are implemented in pot for the moment with torch backend # TODO : check later
@@ -1,8 +1,9 @@
1
1
  from collections.abc import Callable, Iterable
2
2
  from typing import Any, Literal, Union
3
-
4
3
  import numpy as np
5
4
 
5
+
6
+ from multipers.array_api import api_from_tensor
6
7
  global available_kernels
7
8
  available_kernels = Union[
8
9
  Literal[
@@ -41,13 +42,14 @@ def convolution_signed_measures(
41
42
  from multipers.grids import todense
42
43
 
43
44
  grid_iterator = todense(filtrations, product_order=True)
45
+ api = api_from_tensor(iterable_of_signed_measures[0][0][0])
44
46
  match backend:
45
47
  case "sklearn":
46
48
 
47
49
  def convolution_signed_measures_on_grid(
48
- signed_measures: Iterable[tuple[np.ndarray, np.ndarray]],
50
+ signed_measures,
49
51
  ):
50
- return np.concatenate(
52
+ return api.cat(
51
53
  [
52
54
  _pts_convolution_sparse_old(
53
55
  pts=pts,
@@ -67,7 +69,7 @@ def convolution_signed_measures(
67
69
  def convolution_signed_measures_on_grid(
68
70
  signed_measures: Iterable[tuple[np.ndarray, np.ndarray]],
69
71
  ) -> np.ndarray:
70
- return np.concatenate(
72
+ return api.cat(
71
73
  [
72
74
  _pts_convolution_pykeops(
73
75
  pts=pts,
@@ -111,7 +113,7 @@ def convolution_signed_measures(
111
113
  if not flatten:
112
114
  out_shape = [-1] + [len(f) for f in filtrations] # Degree
113
115
  convolutions = [x.reshape(out_shape) for x in convolutions]
114
- return np.asarray(convolutions)
116
+ return api.cat([x[None] for x in convolutions])
115
117
 
116
118
 
117
119
  # def _test(r=1000, b=0.5, plot=True, kernel=0):
@@ -173,10 +175,17 @@ def _pts_convolution_pykeops(
173
175
  """
174
176
  Pykeops convolution
175
177
  """
178
+ if isinstance(pts, np.ndarray):
179
+ _asarray_weights = lambda x : np.asarray(x, dtype=pts.dtype)
180
+ _asarray_grid = _asarray_weights
181
+ else:
182
+ import torch
183
+ _asarray_weights = lambda x : torch.from_numpy(x).type(pts.dtype)
184
+ _asarray_grid = lambda x : x.type(pts.dtype)
176
185
  kde = KDE(kernel=kernel, bandwidth=bandwidth, **more_kde_args)
177
186
  return kde.fit(
178
- pts, sample_weights=np.asarray(pts_weights, dtype=pts.dtype)
179
- ).score_samples(np.asarray(grid_iterator, dtype=pts.dtype))
187
+ pts, sample_weights=_asarray_weights(pts_weights)
188
+ ).score_samples(_asarray_grid(grid_iterator))
180
189
 
181
190
 
182
191
  def gaussian_kernel(x_i, y_j, bandwidth):
@@ -291,10 +300,10 @@ class KDE:
291
300
  X.reshape((X.shape[0], 1, X.shape[1]))
292
301
  ) # numpts, 1, dim
293
302
  lazy_y = LazyTensor(
294
- Y.reshape((1, Y.shape[0], Y.shape[1]))
303
+ Y.reshape((1, Y.shape[0], Y.shape[1])).astype(X.dtype)
295
304
  ) # 1, numpts, dim
296
305
  if x_weights is not None:
297
- w = LazyTensor(x_weights[:, None], axis=0)
306
+ w = LazyTensor(np.asarray(x_weights, dtype=X.dtype)[:, None], axis=0)
298
307
  return lazy_x, lazy_y, w
299
308
  return lazy_x, lazy_y, None
300
309
  import torch
@@ -303,9 +312,11 @@ class KDE:
303
312
  from pykeops.torch import LazyTensor
304
313
 
305
314
  lazy_x = LazyTensor(X.view(X.shape[0], 1, X.shape[1]))
306
- lazy_y = LazyTensor(Y.view(1, Y.shape[0], Y.shape[1]))
315
+ lazy_y = LazyTensor(Y.type(X.dtype).view(1, Y.shape[0], Y.shape[1]))
307
316
  if x_weights is not None:
308
- w = LazyTensor(x_weights[:, None], axis=0)
317
+ if isinstance(x_weights, np.ndarray):
318
+ x_weights = torch.from_numpy(x_weights)
319
+ w = LazyTensor(x_weights[:, None].type(X.dtype), axis=0)
309
320
  return lazy_x, lazy_y, w
310
321
  return lazy_x, lazy_y, None
311
322
  raise Exception("Bad tensor type.")
@@ -521,7 +532,7 @@ class KNNmean:
521
532
 
522
533
  # Symbolic distance matrix:
523
534
  if self.metric == "euclidean":
524
- D_ij = ((X_i - X_j) ** 2).sum(-1)
535
+ D_ij = ((X_i - X_j) ** 2).sum(-1) ** (1/2)
525
536
  elif self.metric == "manhattan":
526
537
  D_ij = (X_i - X_j).abs().sum(-1)
527
538
  elif self.metric == "angular":
@@ -1,13 +1,15 @@
1
1
  from collections.abc import Sequence
2
2
  from typing import Optional
3
+ from warnings import warn
3
4
 
4
5
  import gudhi as gd
5
6
  import numpy as np
6
7
  from numpy.typing import ArrayLike
7
8
  from scipy.spatial import KDTree
8
- from scipy.spatial.distance import cdist
9
9
 
10
+ from multipers.array_api import api_from_tensor, api_from_tensors
10
11
  from multipers.filtrations.density import DTM, available_kernels
12
+ from multipers.grids import compute_grid
11
13
  from multipers.simplex_tree_multi import SimplexTreeMulti, SimplexTreeMulti_type
12
14
 
13
15
  try:
@@ -15,8 +17,9 @@ try:
15
17
 
16
18
  from multipers.filtrations.density import KDE
17
19
  except ImportError:
20
+
18
21
  from sklearn.neighbors import KernelDensity
19
- from warnings import warn
22
+
20
23
  warn("pykeops not found. Falling back to sklearn.")
21
24
 
22
25
  def KDE(bandwidth, kernel, return_log):
@@ -28,8 +31,8 @@ def RipsLowerstar(
28
31
  *,
29
32
  points: Optional[ArrayLike] = None,
30
33
  distance_matrix: Optional[ArrayLike] = None,
31
- function=None,
32
- threshold_radius=None,
34
+ function: Optional[ArrayLike] = None,
35
+ threshold_radius: Optional[float] = None,
33
36
  ):
34
37
  """
35
38
  Computes the Rips complex, with the usual rips filtration as a first parameter,
@@ -44,22 +47,37 @@ def RipsLowerstar(
44
47
  points is not None or distance_matrix is not None
45
48
  ), "`points` or `distance_matrix` has to be given."
46
49
  if distance_matrix is None:
47
- distance_matrix = cdist(points, points) # this may be slow...
50
+ api = api_from_tensor(points)
51
+ points = api.astensor(points)
52
+ D = api.cdist(points, points) # this may be slow...
53
+ else:
54
+ api = api_from_tensor(distance_matrix)
55
+ D = api.astensor(distance_matrix)
56
+
48
57
  if threshold_radius is None:
49
- threshold_radius = np.min(np.max(distance_matrix, axis=1))
58
+ threshold_radius = api.min(api.maxvalues(D, axis=1))
50
59
  st = gd.SimplexTree.create_from_array(
51
- distance_matrix, max_filtration=threshold_radius
60
+ api.asnumpy(D), max_filtration=threshold_radius
52
61
  )
53
62
  if function is None:
54
63
  return SimplexTreeMulti(st, num_parameters=1)
55
64
 
56
- function = np.asarray(function)
65
+ function = api.astensor(function)
57
66
  if function.ndim == 1:
58
67
  function = function[:, None]
68
+ if function.ndim != 2:
69
+ raise ValueError(
70
+ f"`function.ndim` should be 0 or 1 . Got {function.ndim=}.{function=}"
71
+ )
59
72
  num_parameters = function.shape[1] + 1
60
73
  st = SimplexTreeMulti(st, num_parameters=num_parameters)
61
74
  for i in range(function.shape[1]):
62
- st.fill_lowerstar(function[:, i], parameter=1 + i)
75
+ st.fill_lowerstar(api.asnumpy(function[:, i]), parameter=1 + i)
76
+ if api.has_grad(D) or api.has_grad(function):
77
+ from multipers.grids import compute_grid
78
+
79
+ grid = compute_grid([D.ravel(), *[f for f in function.T]])
80
+ st = st.grid_squeeze(grid)
63
81
  return st
64
82
 
65
83
 
@@ -99,6 +117,7 @@ def DelaunayLowerstar(
99
117
  dtype=np.float64,
100
118
  verbose: bool = False,
101
119
  clear: bool = True,
120
+ flagify: bool = False,
102
121
  ):
103
122
  """
104
123
  Computes the Function Delaunay bifiltration. Similar to RipsLowerstar, but most suited for low-dimensional euclidean data.
@@ -110,23 +129,44 @@ def DelaunayLowerstar(
110
129
  - threshold_radius: max edge length of the rips. Defaults at min(max(distance_matrix, axis=1)).
111
130
  """
112
131
  from multipers.slicer import from_function_delaunay
132
+
133
+ if flagify and reduce_degree >= 0:
134
+ raise ValueError(
135
+ "Got {reduce_degree=} and {flagify=}. Cannot flagify with reduce degree."
136
+ )
113
137
  assert distance_matrix is None, "Delaunay cannot be built from distance matrices"
114
138
  if threshold_radius is not None:
115
139
  raise NotImplementedError("Delaunay with threshold not implemented yet.")
116
- points = np.asarray(points)
117
- function = np.asarray(function).squeeze()
140
+ api = api_from_tensors(points, function)
141
+ if not flagify and (api.has_grad(points) or api.has_grad(function)):
142
+ warn("Cannot keep points gradient unless using `flagify=True`.")
143
+ points = api.astensor(points)
144
+ function = api.astensor(function).squeeze()
118
145
  assert (
119
146
  function.ndim == 1
120
147
  ), "Delaunay Lowerstar is only compatible with 1 additional parameter."
121
- return from_function_delaunay(
122
- points,
123
- function,
148
+ slicer = from_function_delaunay(
149
+ api.asnumpy(points),
150
+ api.asnumpy(function),
124
151
  degree=reduce_degree,
125
152
  vineyard=vineyard,
126
153
  dtype=dtype,
127
154
  verbose=verbose,
128
155
  clear=clear,
129
156
  )
157
+ if flagify:
158
+ from multipers.slicer import to_simplextree
159
+
160
+ slicer = to_simplextree(slicer)
161
+ slicer.flagify(2)
162
+
163
+ if api.has_grad(points) or api.has_grad(function):
164
+ distances = api.cdist(points, points) / 2
165
+ grid = compute_grid([distances.ravel(), function])
166
+ slicer = slicer.grid_squeeze(grid)
167
+ slicer = slicer._clean_filtration_grid()
168
+
169
+ return slicer
130
170
 
131
171
 
132
172
  def DelaunayCodensity(
@@ -142,6 +182,7 @@ def DelaunayCodensity(
142
182
  dtype=np.float64,
143
183
  verbose: bool = False,
144
184
  clear: bool = True,
185
+ flagify: bool = False,
145
186
  ):
146
187
  """
147
188
  TODO
@@ -165,6 +206,7 @@ def DelaunayCodensity(
165
206
  dtype=dtype,
166
207
  verbose=verbose,
167
208
  clear=clear,
209
+ flagify=flagify,
168
210
  )
169
211
 
170
212
 
@@ -178,6 +220,23 @@ def Cubical(image: ArrayLike, **slicer_kwargs):
178
220
  - ** args : specify non-default slicer parameters
179
221
  """
180
222
  from multipers.slicer import from_bitmap
223
+
224
+ api = api_from_tensor(image)
225
+ image = api.astensor(image)
226
+ if api.has_grad(image):
227
+ img2 = image.reshape(-1, image.shape[-1]).T
228
+ grid = compute_grid(img2)
229
+ coord_img = np.empty(image.shape, dtype=np.int32)
230
+ slice_shape = image.shape[:-1]
231
+ for i in range(image.shape[-1]):
232
+ coord_img[..., i] = np.searchsorted(
233
+ api.asnumpy(grid[i]),
234
+ api.asnumpy(image[..., i]).reshape(-1),
235
+ ).reshape(slice_shape)
236
+ slicer = from_bitmap(coord_img, **slicer_kwargs)
237
+ slicer.filtration_grid = grid
238
+ return slicer
239
+
181
240
  return from_bitmap(image, **slicer_kwargs)
182
241
 
183
242
 
@@ -214,7 +273,7 @@ def CoreDelaunay(
214
273
  ks = np.arange(1, len(points) + 1)
215
274
  else:
216
275
  ks = np.asarray(ks, dtype=int)
217
- ks:np.ndarray
276
+ ks: np.ndarray
218
277
 
219
278
  assert len(ks) > 0, "The parameter ks must contain at least one value."
220
279
  assert np.all(ks > 0), "All values in ks must be positive."
Binary file
Binary file