multipers 2.3.3b6__cp313-cp313-win_amd64.whl → 2.3.4__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic. Click here for more details.

@@ -279,10 +279,6 @@ def signed_measure(
279
279
  ignore_inf=ignore_infinite_filtration_values,
280
280
  )
281
281
  fix_mass_default = False
282
-
283
- if "hook" in invariant:
284
- from multipers.point_measure import rectangle_to_hook_minimal_signed_barcode
285
- sms = [rectangle_to_hook_minimal_signed_barcode(pts,w) for pts,w in sms]
286
282
  if verbose:
287
283
  print("Done.")
288
284
  elif filtered_complex_.is_minpres:
@@ -345,9 +341,6 @@ def signed_measure(
345
341
  expand_collapse=expand_collapse,
346
342
  )
347
343
  fix_mass_default = False
348
- if "hook" in invariant:
349
- from multipers.point_measure import rectangle_to_hook_minimal_signed_barcode
350
- sms = [rectangle_to_hook_minimal_signed_barcode(pts,w) for pts,w in sms]
351
344
  if verbose:
352
345
  print("Done.")
353
346
  elif len(degrees) == 1 and degrees[0] is None:
@@ -420,6 +413,10 @@ def signed_measure(
420
413
  sms = zero_out_sms(sms, mass_default=mass_default)
421
414
  if verbose:
422
415
  print("Done.")
416
+
417
+ if invariant == "hook":
418
+ from multipers.point_measure import rectangle_to_hook_minimal_signed_barcode
419
+ sms = [rectangle_to_hook_minimal_signed_barcode(pts,w) for pts,w in sms]
423
420
  if plot:
424
421
  plot_signed_measures(sms)
425
422
  return sms
@@ -1,7 +1,15 @@
1
1
  import multipers.array_api.numpy as npapi
2
2
 
3
3
 
4
- def api_from_tensor(x, *, verbose: bool = False):
4
+ def api_from_tensor(x, *, verbose: bool = False, strict=False):
5
+ if strict:
6
+ if npapi.is_tensor(x):
7
+ return npapi
8
+ import multipers.array_api.torch as torchapi
9
+
10
+ if torchapi.is_tensor(x):
11
+ return torchapi
12
+ raise ValueError(f"Unsupported (strict) type {type(x)=}")
5
13
  if npapi.is_promotable(x):
6
14
  if verbose:
7
15
  print("using numpy backend")
@@ -43,3 +51,12 @@ def api_from_tensors(*args):
43
51
  def to_numpy(x):
44
52
  api = api_from_tensor(x)
45
53
  return api.asnumpy(x)
54
+
55
+
56
+ def check_keops():
57
+ import os
58
+
59
+ if os.name == "nt":
60
+ # see https://github.com/getkeops/keops/pull/421
61
+ return False
62
+ return npapi.check_keops()
@@ -19,10 +19,71 @@ max = _np.max
19
19
  repeat_interleave = _np.repeat
20
20
  cdist = cdist # type: ignore[no-redef]
21
21
  unique = _np.unique
22
+ inf = _np.inf
23
+ searchsorted = _np.searchsorted
24
+ LazyTensor = None
25
+
26
+ # Test keops
27
+ _is_keops_available = None
28
+
29
+
30
+ def check_keops():
31
+ global _is_keops_available, LazyTensor
32
+ if _is_keops_available is not None:
33
+ return _is_keops_available
34
+ try:
35
+ if _is_keops_available is not None:
36
+ return _is_keops_available
37
+ import pykeops.numpy as pknp
38
+ from pykeops.numpy import LazyTensor as LT
39
+
40
+ formula = "SqNorm2(x - y)"
41
+ var = ["x = Vi(3)", "y = Vj(3)"]
42
+ expected_res = _np.array([63.0, 90.0])
43
+ x = _np.arange(1, 10).reshape(-1, 3).astype("float32")
44
+ y = _np.arange(3, 9).reshape(-1, 3).astype("float32")
45
+
46
+ my_conv = pknp.Genred(formula, var)
47
+ _is_keops_available = _np.allclose(my_conv(x, y).flatten(), expected_res)
48
+ LazyTensor = LT
49
+ except:
50
+ from warnings import warn
51
+
52
+ warn("Could not initialize keops (numpy). using workarounds")
53
+ _is_keops_available = False
54
+
55
+ return _is_keops_available
56
+
57
+
58
+ def from_numpy(x):
59
+ return _np.asarray(x)
60
+
61
+
62
+ def ascontiguous(x):
63
+ return _np.ascontiguousarray(x)
64
+
65
+
66
+ def sort(x, axis=-1):
67
+ return _np.sort(x, axis=axis)
68
+
69
+
70
+ def device(x): # type: ignore[no-unused-arg]
71
+ return None
72
+
73
+
74
+ # type: ignore[no-unused-arg]
75
+ def linspace(low, high, r, device=None, dtype=None):
76
+ return _np.linspace(low, high, r, dtype=dtype)
77
+
78
+
79
+ def cartesian_product(*arrays, dtype=None):
80
+ mesh = _np.meshgrid(*arrays, indexing="ij")
81
+ coordinates = _np.stack(mesh, axis=-1).reshape(-1, len(arrays)).astype(dtype)
82
+ return coordinates
22
83
 
23
84
 
24
85
  def quantile_closest(x, q, axis=None):
25
- return _np.quantile(x, q, axis=axis, interpolation="closest_observation")
86
+ return _np.quantile(x, q, axis=axis, method="closest_observation")
26
87
 
27
88
 
28
89
  def minvalues(x: _np.ndarray, **kwargs):
@@ -33,6 +94,10 @@ def maxvalues(x: _np.ndarray, **kwargs):
33
94
  return _np.max(x, **kwargs)
34
95
 
35
96
 
97
+ def is_tensor(x):
98
+ return isinstance(x, _np.ndarray)
99
+
100
+
36
101
  def is_promotable(x):
37
102
  return isinstance(x, _np.ndarray | list | tuple)
38
103
 
@@ -15,6 +15,61 @@ zeros = _t.zeros
15
15
  min = _t.min
16
16
  max = _t.max
17
17
  repeat_interleave = _t.repeat_interleave
18
+ linspace = _t.linspace
19
+ cartesian_product = _t.cartesian_prod
20
+ inf = _t.inf
21
+ searchsorted = _t.searchsorted
22
+ LazyTensor = None
23
+
24
+
25
+ _is_keops_available = None
26
+
27
+
28
+ def check_keops():
29
+ global _is_keops_available, LazyTensor
30
+ if _is_keops_available is not None:
31
+ return _is_keops_available
32
+ try:
33
+ import pykeops.torch as pknp
34
+ from pykeops.torch import LazyTensor as LT
35
+
36
+ formula = "SqNorm2(x - y)"
37
+ var = ["x = Vi(3)", "y = Vj(3)"]
38
+ expected_res = _t.tensor([63.0, 90.0])
39
+ x = _t.arange(1, 10, dtype=_t.float32).view(-1, 3)
40
+ y = _t.arange(3, 9, dtype=_t.float32).view(-1, 3)
41
+
42
+ my_conv = pknp.Genred(formula, var)
43
+ _is_keops_available = _t.allclose(
44
+ my_conv(x, y).view(-1), expected_res.type(_t.float32)
45
+ )
46
+ LazyTensor = LT
47
+
48
+ except:
49
+ from warnings import warn
50
+
51
+ warn("Could not initialize keops (torch). using workarounds")
52
+
53
+ _is_keops_available = False
54
+
55
+ return _is_keops_available
56
+ check_keops()
57
+
58
+
59
+ def from_numpy(x):
60
+ return _t.from_numpy(x)
61
+
62
+
63
+ def ascontiguous(x):
64
+ return _t.as_tensor(x).contiguous()
65
+
66
+
67
+ def device(x):
68
+ return x.device
69
+
70
+
71
+ def sort(x, axis=-1):
72
+ return _t.sort(x, dim=axis).values
18
73
 
19
74
 
20
75
  # in our context, this allows to get a correct gradient.
@@ -28,10 +83,10 @@ def unique(x, assume_sorted=False, _mean=True):
28
83
  _, c = _t.unique(x, sorted=True, return_counts=True)
29
84
  if _mean:
30
85
  x = _t.segment_reduce(data=x, reduce="mean", lengths=c, unsafe=True, axis=0)
31
- return x
32
-
33
- c = _np.concatenate([[0], _np.cumsum(c[:-1])])
34
- return x[c]
86
+ else:
87
+ c = _np.concatenate([[0], _np.cumsum(c[:-1])])
88
+ x = x[c]
89
+ return x
35
90
 
36
91
 
37
92
  def quantile_closest(x, q, axis=None):
@@ -50,6 +105,10 @@ def asnumpy(x):
50
105
  return x.detach().numpy()
51
106
 
52
107
 
108
+ def is_tensor(x):
109
+ return isinstance(x, _t.Tensor)
110
+
111
+
53
112
  def is_promotable(x):
54
113
  return isinstance(x, _t.Tensor)
55
114
 
@@ -1,9 +1,10 @@
1
1
  from collections.abc import Callable, Iterable
2
2
  from typing import Any, Literal, Union
3
+
3
4
  import numpy as np
4
5
 
6
+ from multipers.array_api import api_from_tensor, api_from_tensors
5
7
 
6
- from multipers.array_api import api_from_tensor
7
8
  global available_kernels
8
9
  available_kernels = Union[
9
10
  Literal[
@@ -176,23 +177,24 @@ def _pts_convolution_pykeops(
176
177
  Pykeops convolution
177
178
  """
178
179
  if isinstance(pts, np.ndarray):
179
- _asarray_weights = lambda x : np.asarray(x, dtype=pts.dtype)
180
+ _asarray_weights = lambda x: np.asarray(x, dtype=pts.dtype)
180
181
  _asarray_grid = _asarray_weights
181
182
  else:
182
183
  import torch
183
- _asarray_weights = lambda x : torch.from_numpy(x).type(pts.dtype)
184
- _asarray_grid = lambda x : x.type(pts.dtype)
184
+
185
+ _asarray_weights = lambda x: torch.from_numpy(x).type(pts.dtype)
186
+ _asarray_grid = lambda x: x.type(pts.dtype)
185
187
  kde = KDE(kernel=kernel, bandwidth=bandwidth, **more_kde_args)
186
- return kde.fit(
187
- pts, sample_weights=_asarray_weights(pts_weights)
188
- ).score_samples(_asarray_grid(grid_iterator))
188
+ return kde.fit(pts, sample_weights=_asarray_weights(pts_weights)).score_samples(
189
+ _asarray_grid(grid_iterator)
190
+ )
189
191
 
190
192
 
191
193
  def gaussian_kernel(x_i, y_j, bandwidth):
192
194
  D = x_i.shape[-1]
193
195
  exponent = -(((x_i - y_j) / bandwidth) ** 2).sum(dim=-1) / 2
194
196
  # float is necessary for some reason (pykeops fails)
195
- kernel = (exponent).exp() / float((bandwidth*np.sqrt(2 * np.pi))**D)
197
+ kernel = (exponent).exp() / float((bandwidth * np.sqrt(2 * np.pi)) ** D)
196
198
  return kernel
197
199
 
198
200
 
@@ -359,49 +361,6 @@ class KDE:
359
361
  )
360
362
 
361
363
 
362
- def batch_signed_measure_convolutions(
363
- signed_measures, # array of shape (num_data,num_pts,D)
364
- x, # array of shape (num_x, D) or (num_data, num_x, D)
365
- bandwidth, # either float or matrix if multivariate kernel
366
- kernel: available_kernels,
367
- ):
368
- """
369
- Input
370
- -----
371
- - signed_measures: unragged, of shape (num_data, num_pts, D+1)
372
- where last coord is weights, (0 for dummy points)
373
- - x : the points to convolve (num_x,D)
374
- - bandwidth : the bandwidths or covariance matrix inverse or ... of the kernel
375
- - kernel : "gaussian", "multivariate_gaussian", "exponential", or Callable (x_i, y_i, bandwidth)->float
376
-
377
- Output
378
- ------
379
- Array of shape (num_convolutions, (num_axis), num_data,
380
- Array of shape (num_convolutions, (num_axis), num_data, max_x_size)
381
- """
382
- if signed_measures.ndim == 2:
383
- signed_measures = signed_measures[None, :, :]
384
- sms = signed_measures[..., :-1]
385
- weights = signed_measures[..., -1]
386
- if isinstance(signed_measures, np.ndarray):
387
- from pykeops.numpy import LazyTensor
388
- else:
389
- import torch
390
-
391
- assert isinstance(signed_measures, torch.Tensor)
392
- from pykeops.torch import LazyTensor
393
-
394
- _sms = LazyTensor(sms[..., None, :].contiguous())
395
- _x = x[..., None, :, :].contiguous()
396
-
397
- sms_kernel = _kernel(kernel)(_sms, _x, bandwidth)
398
- out = (sms_kernel * weights[..., None, None].contiguous()).sum(
399
- signed_measures.ndim - 2
400
- )
401
- assert out.shape[-1] == 1, "Pykeops bug fixed, TODO : refix this "
402
- out = out[..., 0] ## pykeops bug + ensures its a tensor
403
- # assert out.shape == (x.shape[0], x.shape[1]), f"{x.shape=}, {out.shape=}"
404
- return out
405
364
 
406
365
 
407
366
  class DTM:
@@ -532,7 +491,7 @@ class KNNmean:
532
491
 
533
492
  # Symbolic distance matrix:
534
493
  if self.metric == "euclidean":
535
- D_ij = ((X_i - X_j) ** 2).sum(-1) ** (1/2)
494
+ D_ij = ((X_i - X_j) ** 2).sum(-1) ** (1 / 2)
536
495
  elif self.metric == "manhattan":
537
496
  D_ij = (X_i - X_j).abs().sum(-1)
538
497
  elif self.metric == "angular":
@@ -76,8 +76,8 @@ def RipsLowerstar(
76
76
  st.fill_lowerstar(api.asnumpy(function[:, i]), parameter=1 + i)
77
77
  if api.has_grad(D) or api.has_grad(function):
78
78
  from multipers.grids import compute_grid
79
-
80
- grid = compute_grid([D.ravel(), *[f for f in function.T]])
79
+ filtration_values = [D.ravel(), *[f for f in function.T]]
80
+ grid = compute_grid(filtration_values)
81
81
  st = st.grid_squeeze(grid)
82
82
  return st
83
83
 
Binary file
Binary file
multipers/grids.pyx CHANGED
@@ -11,6 +11,7 @@ from typing import Iterable,Literal,Optional
11
11
  from itertools import product
12
12
  from multipers.array_api import api_from_tensor, api_from_tensors
13
13
  from multipers.array_api import numpy as npapi
14
+ from multipers.array_api import check_keops
14
15
 
15
16
  available_strategies = ["regular","regular_closest", "regular_left", "partition", "quantile", "precomputed"]
16
17
  Lstrategies = Literal["regular","regular_closest", "regular_left", "partition", "quantile", "precomputed"]
@@ -123,8 +124,7 @@ def compute_grid(
123
124
  except TypeError:
124
125
  pass
125
126
 
126
- if api is npapi:
127
- return _compute_grid_numpy(
127
+ grid = _compute_grid_numpy(
128
128
  initial_grid,
129
129
  resolution=resolution,
130
130
  strategy = strategy,
@@ -132,9 +132,9 @@ def compute_grid(
132
132
  _q_factor=_q_factor,
133
133
  drop_quantiles=drop_quantiles,
134
134
  dense = dense,
135
- )
136
- from multipers.torch.diff_grids import get_grid
137
- grid = get_grid(strategy)(initial_grid,resolution)
135
+ )
136
+ # from multipers.torch.diff_grids import get_grid
137
+ # grid = get_grid(strategy)(initial_grid,resolution)
138
138
  if dense:
139
139
  grid = todense(grid)
140
140
  return grid
@@ -168,41 +168,41 @@ def _compute_grid_numpy(
168
168
  Iterable[array[float, ndim=1]] : the 1d-grid for each parameter.
169
169
  """
170
170
  num_parameters = len(filtrations_values)
171
+ api = api_from_tensors(*filtrations_values)
171
172
  try:
172
173
  a,b=drop_quantiles
173
174
  except:
174
175
  a,b=drop_quantiles,drop_quantiles
175
176
 
176
177
  if a != 0 or b != 0:
177
- boxes = np.asarray([np.quantile(filtration, [a, b], axis=1, method='closest_observation') for filtration in filtrations_values])
178
- min_filtration, max_filtration = np.min(boxes, axis=(0,1)), np.max(boxes, axis=(0,1)) # box, birth/death, filtration
178
+ boxes = api.astensor([api.quantile_closest(filtration, [a, b], axis=1) for filtration in filtrations_values])
179
+ min_filtration, max_filtration = api.minvalues(boxes, axis=(0,1)), api.maxvalues(boxes, axis=(0,1)) # box, birth/death, filtration
179
180
  filtrations_values = [
180
181
  filtration[(m<filtration) * (filtration <M)]
181
182
  for filtration, m,M in zip(filtrations_values, min_filtration, max_filtration)
182
183
  ]
183
184
 
184
- to_unique = lambda f : np.unique(f) if isinstance(f,np.ndarray) else f.unique()
185
185
  ## match doesn't work with cython BUG
186
186
  if strategy == "exact":
187
- F=tuple(to_unique(f) for f in filtrations_values)
187
+ F=tuple(api.unique(f) for f in filtrations_values)
188
188
  elif strategy == "quantile":
189
- F = tuple(to_unique(f) for f in filtrations_values)
189
+ F = tuple(api.unique(f) for f in filtrations_values)
190
190
  max_resolution = [min(len(f),r) for f,r in zip(F,resolution)]
191
- F = tuple( np.quantile(f, q=np.linspace(0,1,num=int(r*_q_factor)), axis=0, method='closest_observation') for f,r in zip(F, resolution) )
191
+ F = tuple( api.quantile_closest(f, q=api.linspace(0,1,int(r*_q_factor)), axis=0) for f,r in zip(F, resolution) )
192
192
  if unique:
193
- F = tuple(to_unique(f) for f in F)
193
+ F = tuple(api.unique(f) for f in F)
194
194
  if np.all(np.asarray(max_resolution) > np.asarray([len(f) for f in F])):
195
195
  return _compute_grid_numpy(filtrations_values=filtrations_values, resolution=resolution, strategy="quantile",_q_factor=1.5*_q_factor)
196
196
  elif strategy == "regular":
197
- F = tuple(np.linspace(np.min(f),np.max(f),num=r, dtype=np.asarray(f).dtype) for f,r in zip(filtrations_values, resolution))
197
+ F = tuple(_todo_regular(f,r,api) for f,r in zip(filtrations_values, resolution))
198
198
  elif strategy == "regular_closest":
199
- F = tuple(_todo_regular_closest(f,r, unique) for f,r in zip(filtrations_values, resolution))
199
+ F = tuple(_todo_regular_closest(f,r, unique,api) for f,r in zip(filtrations_values, resolution))
200
200
  elif strategy == "regular_left":
201
- F = tuple(_todo_regular_left(f,r, unique) for f,r in zip(filtrations_values, resolution))
202
- elif strategy == "torch_regular_closest":
203
- F = tuple(_torch_regular_closest(f,r, unique) for f,r in zip(filtrations_values, resolution))
201
+ F = tuple(_todo_regular_left(f,r, unique,api) for f,r in zip(filtrations_values, resolution))
202
+ # elif strategy == "torch_regular_closest":
203
+ # F = tuple(_torch_regular_closest(f,r, unique) for f,r in zip(filtrations_values, resolution))
204
204
  elif strategy == "partition":
205
- F = tuple(_todo_partition(f,r, unique) for f,r in zip(filtrations_values, resolution))
205
+ F = tuple(_todo_partition(f,r, unique, api) for f,r in zip(filtrations_values, resolution))
206
206
  elif strategy == "precomputed":
207
207
  F=filtrations_values
208
208
  else:
@@ -214,43 +214,85 @@ def _compute_grid_numpy(
214
214
  def todense(grid, bool product_order=False):
215
215
  if len(grid) == 0:
216
216
  return np.empty(0)
217
- if not isinstance(grid[0], np.ndarray):
218
- import torch
219
- assert isinstance(grid[0], torch.Tensor)
220
- from multipers.torch.diff_grids import todense
221
- return todense(grid)
222
- dtype = grid[0].dtype
223
- if product_order:
224
- return np.fromiter(product(*grid), dtype=np.dtype((dtype, len(grid))), count=np.prod([len(f) for f in grid]))
225
- mesh = np.meshgrid(*grid)
226
- coordinates = np.concatenate(tuple(stuff.ravel()[:,None] for stuff in mesh), axis=1, dtype=dtype)
227
- return coordinates
228
-
229
-
230
-
231
- ## TODO : optimize. Pykeops ?
232
- def _todo_regular_closest(some_float[:] f, int r, bool unique):
217
+ api = api_from_tensors(*grid)
218
+ # if product_order:
219
+ # if not api.backend ==np:
220
+ # raise NotImplementedError("only numpy here.")
221
+ # return np.fromiter(product(*grid), dtype=np.dtype((dtype, len(grid))), count=np.prod([len(f) for f in grid]))
222
+ return api.cartesian_product(*grid)
223
+ # if not isinstance(grid[0], np.ndarray):
224
+ # import torch
225
+ # assert isinstance(grid[0], torch.Tensor)
226
+ # from multipers.torch.diff_grids import todense
227
+ # return todense(grid)
228
+ # dtype = grid[0].dtype
229
+ # if product_order:
230
+ # return np.fromiter(product(*grid), dtype=np.dtype((dtype, len(grid))), count=np.prod([len(f) for f in grid]))
231
+ # mesh = np.meshgrid(*grid)
232
+ # coordinates = np.stack(mesh, axis=-1).reshape(-1, len(grid)).astype(dtype)
233
+ # return coordinates
234
+
235
+
236
+
237
+ def _todo_regular(f, int r, api):
238
+ if api.has_grad(f):
239
+ from warnings import warn
240
+ warn("`strategy=regular` is not differentiable. Removing grad.")
241
+ with api.no_grad():
242
+ return api.linspace(api.min(f), api.max(f), r)
243
+
244
+ def _project_on_1d_grid(f,grid, bool unique, api):
245
+ # api=api_from_tensors(f,grid)
246
+ if f.ndim != 1:
247
+ raise ValueError(f"Got ndim!=1. {f=}")
248
+ f = api.unique(f)
249
+ with api.no_grad():
250
+ _f = api.LazyTensor(f[:, None, None])
251
+ _f_reg = api.LazyTensor(grid[None, :, None])
252
+ indices = (_f - _f_reg).abs().argmin(0).ravel()
253
+ f = api.cat([f, api.tensor([api.inf], dtype=f.dtype)])
254
+ f_proj = f[indices]
255
+ if unique:
256
+ f_proj = api.unique(f_proj)
257
+ return f_proj
258
+
259
+ def _todo_regular_closest_keops(f, int r, bool unique, api):
260
+ f = api.astensor(f)
261
+ with api.no_grad():
262
+ f_regular = api.linspace(api.min(f), api.max(f), r, device = api.device(f),dtype=f.dtype)
263
+ return _project_on_1d_grid(f,f_regular,unique,api)
264
+
265
+ def _todo_regular_closest_old(some_float[:] f, int r, bool unique, api=None):
233
266
  f_array = np.asarray(f)
234
267
  f_regular = np.linspace(np.min(f), np.max(f),num=r, dtype=f_array.dtype)
235
- f_regular_closest = np.asarray([f[<int64_t>np.argmin(np.abs(f_array-f_regular[i]))] for i in range(r)])
268
+ f_regular_closest = np.asarray([f[<int64_t>np.argmin(np.abs(f_array-f_regular[i]))] for i in range(r)], dtype=f_array.dtype)
236
269
  if unique: f_regular_closest = np.unique(f_regular_closest)
237
270
  return f_regular_closest
238
271
 
239
- def _todo_regular_left(some_float[:] f, int r, bool unique):
272
+ def _todo_regular_left(f, int r, bool unique,api):
273
+ sorted_f = api.sort(f)
274
+ with api.no_grad():
275
+ f_regular = api.linspace(sorted_f[0],sorted_f[-1],r, dtype=sorted_f.dtype, device=api.device(sorted_f))
276
+ idx=api.searchsorted(sorted_f,f_regular)
277
+ f_regular_closest = sorted_f[idx]
278
+ if unique: f_regular_closest = api.unique(f_regular_closest)
279
+ return f_regular_closest
280
+
281
+ def _todo_regular_left_old(some_float[:] f, int r, bool unique):
240
282
  sorted_f = np.sort(f)
241
283
  f_regular = np.linspace(sorted_f[0],sorted_f[-1],num=r, dtype=sorted_f.dtype)
242
284
  f_regular_closest = sorted_f[np.searchsorted(sorted_f,f_regular)]
243
285
  if unique: f_regular_closest = np.unique(f_regular_closest)
244
286
  return f_regular_closest
245
287
 
246
- def _torch_regular_closest(f, int r, bool unique=True):
247
- import torch
248
- f_regular = torch.linspace(f.min(),f.max(), r, dtype=f.dtype)
249
- f_regular_closest =torch.tensor([f[(f-x).abs().argmin()] for x in f_regular])
250
- if unique: f_regular_closest = f_regular_closest.unique()
251
- return f_regular_closest
288
+ def _todo_partition(x, int resolution, bool unique, api):
289
+ if api.has_grad(x):
290
+ from warnings import warn
291
+ warn("`strategy=partition` is not differentiable. Removing grad.")
292
+ out = _todo_partition_(api.asnumpy(x), resolution, unique)
293
+ return api.from_numpy(out)
252
294
 
253
- def _todo_partition(some_float[:] data,int resolution, bool unique):
295
+ def _todo_partition_(some_float[:] data,int resolution, bool unique):
254
296
  if data.shape[0] < resolution: resolution=data.shape[0]
255
297
  k = data.shape[0] // resolution
256
298
  partitions = np.partition(data, k)
@@ -259,6 +301,12 @@ def _todo_partition(some_float[:] data,int resolution, bool unique):
259
301
  return f
260
302
 
261
303
 
304
+ if check_keops():
305
+ _todo_regular_closest = _todo_regular_closest_keops
306
+ else:
307
+ _todo_regular_closest = _todo_regular_closest_old
308
+
309
+
262
310
  def compute_bounding_box(stuff, inflate = 0.):
263
311
  r"""
264
312
  Returns a array of shape (2, num_parameters)
@@ -805,7 +805,7 @@ class Multi_critical_filtration {
805
805
  res.add_generator(nf);
806
806
  }
807
807
  }
808
- swap(f1, res);
808
+ std::swap(f1, res);
809
809
 
810
810
  return f1 != res;
811
811
  }
Binary file