multipers 2.3.3b5__cp313-cp313-win_amd64.whl → 2.3.3b7__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic. Click here for more details.

Files changed (33) hide show
  1. multipers/_signed_measure_meta.py +4 -7
  2. multipers/array_api/__init__.py +18 -1
  3. multipers/array_api/numpy.py +68 -0
  4. multipers/array_api/torch.py +80 -0
  5. multipers/filtrations/density.py +11 -52
  6. multipers/filtrations/filtrations.py +19 -6
  7. multipers/function_rips.cp313-win_amd64.pyd +0 -0
  8. multipers/grids.cp313-win_amd64.pyd +0 -0
  9. multipers/grids.pyx +73 -32
  10. multipers/io.cp313-win_amd64.pyd +0 -0
  11. multipers/ml/signed_measures.py +105 -27
  12. multipers/mma_structures.cp313-win_amd64.pyd +0 -0
  13. multipers/mma_structures.pyx +2 -2
  14. multipers/mma_structures.pyx.tp +1 -1
  15. multipers/multiparameter_module_approximation.cp313-win_amd64.pyd +0 -0
  16. multipers/plots.py +12 -6
  17. multipers/point_measure.cp313-win_amd64.pyd +0 -0
  18. multipers/simplex_tree_multi.cp313-win_amd64.pyd +0 -0
  19. multipers/simplex_tree_multi.pyx +24 -8
  20. multipers/simplex_tree_multi.pyx.tp +3 -1
  21. multipers/slicer.cp313-win_amd64.pyd +0 -0
  22. multipers/slicer.pxd +20 -20
  23. multipers/slicer.pyx +53 -52
  24. multipers/slicer.pyx.tp +2 -1
  25. multipers/tbb12.dll +0 -0
  26. multipers/tbbbind_2_5.dll +0 -0
  27. multipers/tbbmalloc.dll +0 -0
  28. multipers/tbbmalloc_proxy.dll +0 -0
  29. {multipers-2.3.3b5.dist-info → multipers-2.3.3b7.dist-info}/METADATA +1 -1
  30. {multipers-2.3.3b5.dist-info → multipers-2.3.3b7.dist-info}/RECORD +33 -33
  31. {multipers-2.3.3b5.dist-info → multipers-2.3.3b7.dist-info}/WHEEL +0 -0
  32. {multipers-2.3.3b5.dist-info → multipers-2.3.3b7.dist-info}/licenses/LICENSE +0 -0
  33. {multipers-2.3.3b5.dist-info → multipers-2.3.3b7.dist-info}/top_level.txt +0 -0
@@ -279,10 +279,6 @@ def signed_measure(
279
279
  ignore_inf=ignore_infinite_filtration_values,
280
280
  )
281
281
  fix_mass_default = False
282
-
283
- if "hook" in invariant:
284
- from multipers.point_measure import rectangle_to_hook_minimal_signed_barcode
285
- sms = [rectangle_to_hook_minimal_signed_barcode(pts,w) for pts,w in sms]
286
282
  if verbose:
287
283
  print("Done.")
288
284
  elif filtered_complex_.is_minpres:
@@ -345,9 +341,6 @@ def signed_measure(
345
341
  expand_collapse=expand_collapse,
346
342
  )
347
343
  fix_mass_default = False
348
- if "hook" in invariant:
349
- from multipers.point_measure import rectangle_to_hook_minimal_signed_barcode
350
- sms = [rectangle_to_hook_minimal_signed_barcode(pts,w) for pts,w in sms]
351
344
  if verbose:
352
345
  print("Done.")
353
346
  elif len(degrees) == 1 and degrees[0] is None:
@@ -420,6 +413,10 @@ def signed_measure(
420
413
  sms = zero_out_sms(sms, mass_default=mass_default)
421
414
  if verbose:
422
415
  print("Done.")
416
+
417
+ if invariant == "hook":
418
+ from multipers.point_measure import rectangle_to_hook_minimal_signed_barcode
419
+ sms = [rectangle_to_hook_minimal_signed_barcode(pts,w) for pts,w in sms]
423
420
  if plot:
424
421
  plot_signed_measures(sms)
425
422
  return sms
@@ -1,7 +1,15 @@
1
1
  import multipers.array_api.numpy as npapi
2
2
 
3
3
 
4
- def api_from_tensor(x, *, verbose: bool = False):
4
+ def api_from_tensor(x, *, verbose: bool = False, strict=False):
5
+ if strict:
6
+ if npapi.is_tensor(x):
7
+ return npapi
8
+ import multipers.array_api.torch as torchapi
9
+
10
+ if torchapi.is_tensor(x):
11
+ return torchapi
12
+ raise ValueError(f"Unsupported (strict) type {type(x)=}")
5
13
  if npapi.is_promotable(x):
6
14
  if verbose:
7
15
  print("using numpy backend")
@@ -43,3 +51,12 @@ def api_from_tensors(*args):
43
51
  def to_numpy(x):
44
52
  api = api_from_tensor(x)
45
53
  return api.asnumpy(x)
54
+
55
+
56
+ def check_keops():
57
+ import os
58
+
59
+ if os.name == "nt":
60
+ # see https://github.com/getkeops/keops/pull/421
61
+ return False
62
+ return npapi.check_keops()
@@ -18,6 +18,70 @@ min = _np.min
18
18
  max = _np.max
19
19
  repeat_interleave = _np.repeat
20
20
  cdist = cdist # type: ignore[no-redef]
21
+ unique = _np.unique
22
+ inf = _np.inf
23
+ searchsorted = _np.searchsorted
24
+ LazyTensor = None
25
+
26
+ # Test keops
27
+ _is_keops_available = None
28
+
29
+
30
+ def check_keops():
31
+ global _is_keops_available, LazyTensor
32
+ if _is_keops_available is not None:
33
+ return _is_keops_available
34
+ import pykeops.numpy as pknp
35
+ from pykeops.numpy import LazyTensor as LT
36
+
37
+ formula = "SqNorm2(x - y)"
38
+ var = ["x = Vi(3)", "y = Vj(3)"]
39
+ expected_res = _np.array([63.0, 90.0])
40
+ x = _np.arange(1, 10).reshape(-1, 3).astype("float32")
41
+ y = _np.arange(3, 9).reshape(-1, 3).astype("float32")
42
+
43
+ my_conv = pknp.Genred(formula, var)
44
+ try:
45
+ _is_keops_available = _np.allclose(my_conv(x, y).flatten(), expected_res)
46
+ LazyTensor = LT
47
+ except:
48
+ from warnings import warn
49
+
50
+ warn("Could not initialize keops (numpy). using workarounds")
51
+ _is_keops_available = False
52
+
53
+ return _is_keops_available
54
+
55
+
56
+ def from_numpy(x):
57
+ return _np.asarray(x)
58
+
59
+
60
+ def ascontiguous(x):
61
+ return _np.ascontiguousarray(x)
62
+
63
+
64
+ def sort(x, axis=-1):
65
+ return _np.sort(x, axis=axis)
66
+
67
+
68
+ def device(x): # type: ignore[no-unused-arg]
69
+ return None
70
+
71
+
72
+ # type: ignore[no-unused-arg]
73
+ def linspace(low, high, r, device=None, dtype=None):
74
+ return _np.linspace(low, high, r, dtype=dtype)
75
+
76
+
77
+ def cartesian_product(*arrays, dtype=None):
78
+ mesh = _np.meshgrid(*arrays, indexing="ij")
79
+ coordinates = _np.stack(mesh, axis=-1).reshape(-1, len(arrays)).astype(dtype)
80
+ return coordinates
81
+
82
+
83
+ def quantile_closest(x, q, axis=None):
84
+ return _np.quantile(x, q, axis=axis, method="closest_observation")
21
85
 
22
86
 
23
87
  def minvalues(x: _np.ndarray, **kwargs):
@@ -28,6 +92,10 @@ def maxvalues(x: _np.ndarray, **kwargs):
28
92
  return _np.max(x, **kwargs)
29
93
 
30
94
 
95
+ def is_tensor(x):
96
+ return isinstance(x, _np.ndarray)
97
+
98
+
31
99
  def is_promotable(x):
32
100
  return isinstance(x, _np.ndarray | list | tuple)
33
101
 
@@ -1,3 +1,4 @@
1
+ import numpy as _np
1
2
  import torch as _t
2
3
 
3
4
  backend = _t
@@ -14,6 +15,81 @@ zeros = _t.zeros
14
15
  min = _t.min
15
16
  max = _t.max
16
17
  repeat_interleave = _t.repeat_interleave
18
+ linspace = _t.linspace
19
+ cartesian_product = _t.cartesian_prod
20
+ inf = _t.inf
21
+ searchsorted = _t.searchsorted
22
+ LazyTensor = None
23
+
24
+
25
+ _is_keops_available = None
26
+
27
+
28
+ def check_keops():
29
+ global _is_keops_available, LazyTensor
30
+ if _is_keops_available is not None:
31
+ return _is_keops_available
32
+ try:
33
+ import pykeops.torch as pknp
34
+ from pykeops.torch import LazyTensor as LT
35
+
36
+ formula = "SqNorm2(x - y)"
37
+ var = ["x = Vi(3)", "y = Vj(3)"]
38
+ expected_res = _t.tensor([63.0, 90.0])
39
+ x = _t.arange(1, 10, dtype=_t.float32).view(-1, 3)
40
+ y = _t.arange(3, 9, dtype=_t.float32).view(-1, 3)
41
+
42
+ my_conv = pknp.Genred(formula, var)
43
+ _is_keops_available = _t.allclose(
44
+ my_conv(x, y).view(-1), _t.tensor(expected_res).type(_t.float32)
45
+ )
46
+ LazyTensor = LT
47
+
48
+ except:
49
+ from warnings import warn
50
+
51
+ warn("Could not initialize keops (torch). using workarounds")
52
+
53
+ _is_keops_available = False
54
+
55
+ return _is_keops_available
56
+
57
+
58
+ def from_numpy(x):
59
+ return _t.from_numpy(x)
60
+
61
+
62
+ def ascontiguous(x):
63
+ return _t.as_tensor(x).contiguous()
64
+
65
+
66
+ def device(x):
67
+ return x.device
68
+
69
+
70
+ def sort(x, axis=-1):
71
+ return _t.sort(x, dim=axis).values
72
+
73
+
74
+ # in our context, this allows to get a correct gradient.
75
+ def unique(x, assume_sorted=False, _mean=True):
76
+ if not x.requires_grad:
77
+ return x.unique(sorted=assume_sorted)
78
+ if x.ndim != 1:
79
+ raise ValueError(f"Got ndim!=1. {x=}")
80
+ if not assume_sorted:
81
+ x = x.sort().values
82
+ _, c = _t.unique(x, sorted=True, return_counts=True)
83
+ if _mean:
84
+ x = _t.segment_reduce(data=x, reduce="mean", lengths=c, unsafe=True, axis=0)
85
+ else:
86
+ c = _np.concatenate([[0], _np.cumsum(c[:-1])])
87
+ x = x[c]
88
+ return x
89
+
90
+
91
+ def quantile_closest(x, q, axis=None):
92
+ return _t.quantile(x, q, dim=axis, interpolation="nearest")
17
93
 
18
94
 
19
95
  def minvalues(x: _t.Tensor, **kwargs):
@@ -28,6 +104,10 @@ def asnumpy(x):
28
104
  return x.detach().numpy()
29
105
 
30
106
 
107
+ def is_tensor(x):
108
+ return isinstance(x, _t.Tensor)
109
+
110
+
31
111
  def is_promotable(x):
32
112
  return isinstance(x, _t.Tensor)
33
113
 
@@ -1,9 +1,10 @@
1
1
  from collections.abc import Callable, Iterable
2
2
  from typing import Any, Literal, Union
3
+
3
4
  import numpy as np
4
5
 
6
+ from multipers.array_api import api_from_tensor, api_from_tensors
5
7
 
6
- from multipers.array_api import api_from_tensor
7
8
  global available_kernels
8
9
  available_kernels = Union[
9
10
  Literal[
@@ -176,23 +177,24 @@ def _pts_convolution_pykeops(
176
177
  Pykeops convolution
177
178
  """
178
179
  if isinstance(pts, np.ndarray):
179
- _asarray_weights = lambda x : np.asarray(x, dtype=pts.dtype)
180
+ _asarray_weights = lambda x: np.asarray(x, dtype=pts.dtype)
180
181
  _asarray_grid = _asarray_weights
181
182
  else:
182
183
  import torch
183
- _asarray_weights = lambda x : torch.from_numpy(x).type(pts.dtype)
184
- _asarray_grid = lambda x : x.type(pts.dtype)
184
+
185
+ _asarray_weights = lambda x: torch.from_numpy(x).type(pts.dtype)
186
+ _asarray_grid = lambda x: x.type(pts.dtype)
185
187
  kde = KDE(kernel=kernel, bandwidth=bandwidth, **more_kde_args)
186
- return kde.fit(
187
- pts, sample_weights=_asarray_weights(pts_weights)
188
- ).score_samples(_asarray_grid(grid_iterator))
188
+ return kde.fit(pts, sample_weights=_asarray_weights(pts_weights)).score_samples(
189
+ _asarray_grid(grid_iterator)
190
+ )
189
191
 
190
192
 
191
193
  def gaussian_kernel(x_i, y_j, bandwidth):
192
194
  D = x_i.shape[-1]
193
195
  exponent = -(((x_i - y_j) / bandwidth) ** 2).sum(dim=-1) / 2
194
196
  # float is necessary for some reason (pykeops fails)
195
- kernel = (exponent).exp() / float((bandwidth*np.sqrt(2 * np.pi))**D)
197
+ kernel = (exponent).exp() / float((bandwidth * np.sqrt(2 * np.pi)) ** D)
196
198
  return kernel
197
199
 
198
200
 
@@ -359,49 +361,6 @@ class KDE:
359
361
  )
360
362
 
361
363
 
362
- def batch_signed_measure_convolutions(
363
- signed_measures, # array of shape (num_data,num_pts,D)
364
- x, # array of shape (num_x, D) or (num_data, num_x, D)
365
- bandwidth, # either float or matrix if multivariate kernel
366
- kernel: available_kernels,
367
- ):
368
- """
369
- Input
370
- -----
371
- - signed_measures: unragged, of shape (num_data, num_pts, D+1)
372
- where last coord is weights, (0 for dummy points)
373
- - x : the points to convolve (num_x,D)
374
- - bandwidth : the bandwidths or covariance matrix inverse or ... of the kernel
375
- - kernel : "gaussian", "multivariate_gaussian", "exponential", or Callable (x_i, y_i, bandwidth)->float
376
-
377
- Output
378
- ------
379
- Array of shape (num_convolutions, (num_axis), num_data,
380
- Array of shape (num_convolutions, (num_axis), num_data, max_x_size)
381
- """
382
- if signed_measures.ndim == 2:
383
- signed_measures = signed_measures[None, :, :]
384
- sms = signed_measures[..., :-1]
385
- weights = signed_measures[..., -1]
386
- if isinstance(signed_measures, np.ndarray):
387
- from pykeops.numpy import LazyTensor
388
- else:
389
- import torch
390
-
391
- assert isinstance(signed_measures, torch.Tensor)
392
- from pykeops.torch import LazyTensor
393
-
394
- _sms = LazyTensor(sms[..., None, :].contiguous())
395
- _x = x[..., None, :, :].contiguous()
396
-
397
- sms_kernel = _kernel(kernel)(_sms, _x, bandwidth)
398
- out = (sms_kernel * weights[..., None, None].contiguous()).sum(
399
- signed_measures.ndim - 2
400
- )
401
- assert out.shape[-1] == 1, "Pykeops bug fixed, TODO : refix this "
402
- out = out[..., 0] ## pykeops bug + ensures its a tensor
403
- # assert out.shape == (x.shape[0], x.shape[1]), f"{x.shape=}, {out.shape=}"
404
- return out
405
364
 
406
365
 
407
366
  class DTM:
@@ -532,7 +491,7 @@ class KNNmean:
532
491
 
533
492
  # Symbolic distance matrix:
534
493
  if self.metric == "euclidean":
535
- D_ij = ((X_i - X_j) ** 2).sum(-1) ** (1/2)
494
+ D_ij = ((X_i - X_j) ** 2).sum(-1) ** (1 / 2)
536
495
  elif self.metric == "manhattan":
537
496
  D_ij = (X_i - X_j).abs().sum(-1)
538
497
  elif self.metric == "angular":
@@ -17,7 +17,6 @@ try:
17
17
 
18
18
  from multipers.filtrations.density import KDE
19
19
  except ImportError:
20
-
21
20
  from sklearn.neighbors import KernelDensity
22
21
 
23
22
  warn("pykeops not found. Falling back to sklearn.")
@@ -67,7 +66,9 @@ def RipsLowerstar(
67
66
  function = function[:, None]
68
67
  if function.ndim != 2:
69
68
  raise ValueError(
70
- f"`function.ndim` should be 0 or 1 . Got {function.ndim=}.{function=}"
69
+ f"""
70
+ `function.ndim` should be 0 or 1 . Got {function.ndim=}.{function=}
71
+ """
71
72
  )
72
73
  num_parameters = function.shape[1] + 1
73
74
  st = SimplexTreeMulti(st, num_parameters=num_parameters)
@@ -154,6 +155,9 @@ def DelaunayLowerstar(
154
155
  verbose=verbose,
155
156
  clear=clear,
156
157
  )
158
+ if reduce_degree >= 0:
159
+ # Force resolution to avoid confusion with hilbert.
160
+ slicer = slicer.minpres(degree=reduce_degree, force=True)
157
161
  if flagify:
158
162
  from multipers.slicer import to_simplextree
159
163
 
@@ -192,7 +196,7 @@ def DelaunayCodensity(
192
196
  ), "Density estimation is either via kernels or dtm."
193
197
  if bandwidth is not None:
194
198
  kde = KDE(bandwidth=bandwidth, kernel=kernel, return_log=return_log)
195
- f = kde.fit(points).score_samples(points)
199
+ f = -kde.fit(points).score_samples(points)
196
200
  elif dtm_mass is not None:
197
201
  f = DTM(masses=[dtm_mass]).fit(points).score_samples(points)[0]
198
202
  else:
@@ -287,11 +291,17 @@ def CoreDelaunay(
287
291
  "safe",
288
292
  "exact",
289
293
  "fast",
290
- ], f"The parameter precision must be one of ['safe', 'exact', 'fast'], got {precision}."
294
+ ], f"""
295
+ The parameter precision must be one of ['safe', 'exact', 'fast'],
296
+ got {precision}.
297
+ """
291
298
 
292
299
  if verbose:
293
300
  print(
294
- f"Computing the Delaunay Core Bifiltration of {len(points)} points in dimension {points.shape[1]} with parameters:"
301
+ f"""Computing the Delaunay Core Bifiltration
302
+ of {len(points)} points in dimension {points.shape[1]}
303
+ with parameters:
304
+ """
295
305
  )
296
306
  print(f"\tbeta = {beta}")
297
307
  print(f"\tks = {ks}")
@@ -333,7 +343,10 @@ def CoreDelaunay(
333
343
  num_simplices = len(vertex_array)
334
344
  if verbose:
335
345
  print(
336
- f"Inserting {num_simplices} simplices of dimension {dim} ({num_simplices * len(ks)} birth values)..."
346
+ f"""
347
+ Inserting {num_simplices} simplices of dimension {dim}
348
+ ({num_simplices * len(ks)} birth values)...
349
+ """
337
350
  )
338
351
  max_knn_distances = np.max(knn_distances[vertex_array], axis=1)
339
352
  critical_radii = np.maximum(alphas[:, None], beta * max_knn_distances)
Binary file
Binary file
multipers/grids.pyx CHANGED
@@ -11,6 +11,7 @@ from typing import Iterable,Literal,Optional
11
11
  from itertools import product
12
12
  from multipers.array_api import api_from_tensor, api_from_tensors
13
13
  from multipers.array_api import numpy as npapi
14
+ from multipers.array_api import check_keops
14
15
 
15
16
  available_strategies = ["regular","regular_closest", "regular_left", "partition", "quantile", "precomputed"]
16
17
  Lstrategies = Literal["regular","regular_closest", "regular_left", "partition", "quantile", "precomputed"]
@@ -168,39 +169,39 @@ def _compute_grid_numpy(
168
169
  Iterable[array[float, ndim=1]] : the 1d-grid for each parameter.
169
170
  """
170
171
  num_parameters = len(filtrations_values)
172
+ api = api_from_tensors(filtrations_values)
171
173
  try:
172
174
  a,b=drop_quantiles
173
175
  except:
174
176
  a,b=drop_quantiles,drop_quantiles
175
177
 
176
178
  if a != 0 or b != 0:
177
- boxes = np.asarray([np.quantile(filtration, [a, b], axis=1, method='closest_observation') for filtration in filtrations_values])
178
- min_filtration, max_filtration = np.min(boxes, axis=(0,1)), np.max(boxes, axis=(0,1)) # box, birth/death, filtration
179
+ boxes = api.astensor([api.quantile_closest(filtration, [a, b], axis=1) for filtration in filtrations_values])
180
+ min_filtration, max_filtration = api.minvalues(boxes, axis=(0,1)), api.maxvalues(boxes, axis=(0,1)) # box, birth/death, filtration
179
181
  filtrations_values = [
180
182
  filtration[(m<filtration) * (filtration <M)]
181
183
  for filtration, m,M in zip(filtrations_values, min_filtration, max_filtration)
182
184
  ]
183
185
 
184
- to_unique = lambda f : np.unique(f) if isinstance(f,np.ndarray) else f.unique()
185
186
  ## match doesn't work with cython BUG
186
187
  if strategy == "exact":
187
- F=tuple(to_unique(f) for f in filtrations_values)
188
+ F=tuple(api.unique(f) for f in filtrations_values)
188
189
  elif strategy == "quantile":
189
- F = tuple(to_unique(f) for f in filtrations_values)
190
+ F = tuple(api.unique(f) for f in filtrations_values)
190
191
  max_resolution = [min(len(f),r) for f,r in zip(F,resolution)]
191
- F = tuple( np.quantile(f, q=np.linspace(0,1,num=int(r*_q_factor)), axis=0, method='closest_observation') for f,r in zip(F, resolution) )
192
+ F = tuple( api.quantile_closest(f, q=np.linspace(0,1,num=int(r*_q_factor)), axis=0) for f,r in zip(F, resolution) )
192
193
  if unique:
193
- F = tuple(to_unique(f) for f in F)
194
+ F = tuple(api.unique(f) for f in F)
194
195
  if np.all(np.asarray(max_resolution) > np.asarray([len(f) for f in F])):
195
196
  return _compute_grid_numpy(filtrations_values=filtrations_values, resolution=resolution, strategy="quantile",_q_factor=1.5*_q_factor)
196
197
  elif strategy == "regular":
197
- F = tuple(np.linspace(np.min(f),np.max(f),num=r, dtype=np.asarray(f).dtype) for f,r in zip(filtrations_values, resolution))
198
+ F = tuple(_todo_regular(f,r,api) for f,r in zip(filtrations_values, resolution))
198
199
  elif strategy == "regular_closest":
199
- F = tuple(_todo_regular_closest(f,r, unique) for f,r in zip(filtrations_values, resolution))
200
+ F = tuple(_todo_regular_closest(f,r, unique,api) for f,r in zip(filtrations_values, resolution))
200
201
  elif strategy == "regular_left":
201
- F = tuple(_todo_regular_left(f,r, unique) for f,r in zip(filtrations_values, resolution))
202
- elif strategy == "torch_regular_closest":
203
- F = tuple(_torch_regular_closest(f,r, unique) for f,r in zip(filtrations_values, resolution))
202
+ F = tuple(_todo_regular_left(f,r, unique,api) for f,r in zip(filtrations_values, resolution))
203
+ # elif strategy == "torch_regular_closest":
204
+ # F = tuple(_torch_regular_closest(f,r, unique) for f,r in zip(filtrations_values, resolution))
204
205
  elif strategy == "partition":
205
206
  F = tuple(_todo_partition(f,r, unique) for f,r in zip(filtrations_values, resolution))
206
207
  elif strategy == "precomputed":
@@ -214,41 +215,75 @@ def _compute_grid_numpy(
214
215
  def todense(grid, bool product_order=False):
215
216
  if len(grid) == 0:
216
217
  return np.empty(0)
217
- if not isinstance(grid[0], np.ndarray):
218
- import torch
219
- assert isinstance(grid[0], torch.Tensor)
220
- from multipers.torch.diff_grids import todense
221
- return todense(grid)
222
- dtype = grid[0].dtype
223
- if product_order:
224
- return np.fromiter(product(*grid), dtype=np.dtype((dtype, len(grid))), count=np.prod([len(f) for f in grid]))
225
- mesh = np.meshgrid(*grid)
226
- coordinates = np.concatenate(tuple(stuff.ravel()[:,None] for stuff in mesh), axis=1, dtype=dtype)
227
- return coordinates
218
+ api = api_from_tensors(grid)
219
+ # if product_order:
220
+ # if not api.backend ==np:
221
+ # raise NotImplementedError("only numpy here.")
222
+ # return np.fromiter(product(*grid), dtype=np.dtype((dtype, len(grid))), count=np.prod([len(f) for f in grid]))
223
+ return api.cartesian_product(*grid)
224
+ # if not isinstance(grid[0], np.ndarray):
225
+ # import torch
226
+ # assert isinstance(grid[0], torch.Tensor)
227
+ # from multipers.torch.diff_grids import todense
228
+ # return todense(grid)
229
+ # dtype = grid[0].dtype
230
+ # if product_order:
231
+ # return np.fromiter(product(*grid), dtype=np.dtype((dtype, len(grid))), count=np.prod([len(f) for f in grid]))
232
+ # mesh = np.meshgrid(*grid)
233
+ # coordinates = np.stack(mesh, axis=-1).reshape(-1, len(grid)).astype(dtype)
234
+ # return coordinates
228
235
 
229
236
 
230
237
 
231
238
  ## TODO : optimize. Pykeops ?
232
- def _todo_regular_closest(some_float[:] f, int r, bool unique):
239
+ def _todo_regular(f, int r, api):
240
+ with api.no_grad():
241
+ return api.linspace(api.min(f), api.max(f), r)
242
+
243
+ def _project_on_1d_grid(f,grid, bool unique, api):
244
+ # api=api_from_tensors(f,grid)
245
+ if f.ndim != 1:
246
+ raise ValueError(f"Got ndim!=1. {f=}")
247
+ f = api.unique(f)
248
+ with api.no_grad():
249
+ _f = api.LazyTensor(f[:, None, None])
250
+ _f_reg = api.LazyTensor(grid[None, :, None])
251
+ indices = (_f - _f_reg).abs().argmin(0).ravel()
252
+ f = api.cat([f, api.tensor([api.inf], dtype=f.dtype)])
253
+ f_proj = f[indices]
254
+ if unique:
255
+ f_proj = api.unique(f_proj)
256
+ return f_proj
257
+
258
+ def _todo_regular_closest_keops(f, int r, bool unique, api):
259
+ f = api.astensor(f)
260
+ with api.no_grad():
261
+ f_regular = api.linspace(api.min(f), api.max(f), r, device = api.device(f),dtype=f.dtype)
262
+ return _project_on_1d_grid(f,f_regular,unique,api)
263
+
264
+ def _todo_regular_closest_old(some_float[:] f, int r, bool unique, api=None):
233
265
  f_array = np.asarray(f)
234
266
  f_regular = np.linspace(np.min(f), np.max(f),num=r, dtype=f_array.dtype)
235
- f_regular_closest = np.asarray([f[<int64_t>np.argmin(np.abs(f_array-f_regular[i]))] for i in range(r)])
267
+ f_regular_closest = np.asarray([f[<int64_t>np.argmin(np.abs(f_array-f_regular[i]))] for i in range(r)], dtype=f_array.dtype)
236
268
  if unique: f_regular_closest = np.unique(f_regular_closest)
237
269
  return f_regular_closest
238
270
 
239
- def _todo_regular_left(some_float[:] f, int r, bool unique):
271
+ def _todo_regular_left(f, int r, bool unique,api):
272
+ sorted_f = api.sort(f)
273
+ with api.no_grad():
274
+ f_regular = api.linspace(sorted_f[0],sorted_f[-1],r, dtype=sorted_f.dtype, device=api.device(sorted_f))
275
+ idx=api.searchsorted(sorted_f,f_regular)
276
+ f_regular_closest = sorted_f[idx]
277
+ if unique: f_regular_closest = api.unique(f_regular_closest)
278
+ return f_regular_closest
279
+
280
+ def _todo_regular_left_old(some_float[:] f, int r, bool unique):
240
281
  sorted_f = np.sort(f)
241
282
  f_regular = np.linspace(sorted_f[0],sorted_f[-1],num=r, dtype=sorted_f.dtype)
242
283
  f_regular_closest = sorted_f[np.searchsorted(sorted_f,f_regular)]
243
284
  if unique: f_regular_closest = np.unique(f_regular_closest)
244
285
  return f_regular_closest
245
286
 
246
- def _torch_regular_closest(f, int r, bool unique=True):
247
- import torch
248
- f_regular = torch.linspace(f.min(),f.max(), r, dtype=f.dtype)
249
- f_regular_closest =torch.tensor([f[(f-x).abs().argmin()] for x in f_regular])
250
- if unique: f_regular_closest = f_regular_closest.unique()
251
- return f_regular_closest
252
287
 
253
288
  def _todo_partition(some_float[:] data,int resolution, bool unique):
254
289
  if data.shape[0] < resolution: resolution=data.shape[0]
@@ -259,6 +294,12 @@ def _todo_partition(some_float[:] data,int resolution, bool unique):
259
294
  return f
260
295
 
261
296
 
297
+ if check_keops():
298
+ _todo_regular_closest = _todo_regular_closest_keops
299
+ else:
300
+ _todo_regular_closest = _todo_regular_closest_old
301
+
302
+
262
303
  def compute_bounding_box(stuff, inflate = 0.):
263
304
  r"""
264
305
  Returns a array of shape (2, num_parameters)
Binary file