multipers 2.3.0__cp310-cp310-win_amd64.whl → 2.3.2__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic; see the registry's advisory for this release for more details.

Files changed (54):
  1. multipers/_signed_measure_meta.py +71 -65
  2. multipers/array_api/__init__.py +39 -0
  3. multipers/array_api/numpy.py +34 -0
  4. multipers/array_api/torch.py +35 -0
  5. multipers/distances.py +6 -2
  6. multipers/{ml/convolutions.py → filtrations/density.py} +67 -13
  7. multipers/filtrations/filtrations.py +76 -17
  8. multipers/function_rips.cp310-win_amd64.pyd +0 -0
  9. multipers/grids.cp310-win_amd64.pyd +0 -0
  10. multipers/grids.pyx +144 -61
  11. multipers/gudhi/Simplex_tree_multi_interface.h +36 -1
  12. multipers/gudhi/gudhi/Multi_persistence/Box.h +3 -0
  13. multipers/gudhi/gudhi/One_critical_filtration.h +18 -9
  14. multipers/gudhi/mma_interface_h0.h +1 -1
  15. multipers/gudhi/mma_interface_matrix.h +10 -1
  16. multipers/gudhi/naive_merge_tree.h +1 -1
  17. multipers/gudhi/truc.h +555 -42
  18. multipers/io.cp310-win_amd64.pyd +0 -0
  19. multipers/io.pyx +26 -93
  20. multipers/ml/mma.py +4 -4
  21. multipers/ml/point_clouds.py +2 -2
  22. multipers/ml/signed_measures.py +63 -65
  23. multipers/mma_structures.cp310-win_amd64.pyd +0 -0
  24. multipers/mma_structures.pxd +2 -1
  25. multipers/mma_structures.pyx +56 -16
  26. multipers/mma_structures.pyx.tp +14 -5
  27. multipers/multiparameter_module_approximation/approximation.h +48 -14
  28. multipers/multiparameter_module_approximation.cp310-win_amd64.pyd +0 -0
  29. multipers/multiparameter_module_approximation.pyx +27 -8
  30. multipers/plots.py +2 -1
  31. multipers/point_measure.cp310-win_amd64.pyd +0 -0
  32. multipers/point_measure.pyx +6 -2
  33. multipers/simplex_tree_multi.cp310-win_amd64.pyd +0 -0
  34. multipers/simplex_tree_multi.pxd +1 -0
  35. multipers/simplex_tree_multi.pyx +632 -146
  36. multipers/simplex_tree_multi.pyx.tp +92 -24
  37. multipers/slicer.cp310-win_amd64.pyd +0 -0
  38. multipers/slicer.pxd +779 -177
  39. multipers/slicer.pxd.tp +24 -5
  40. multipers/slicer.pyx +5657 -1427
  41. multipers/slicer.pyx.tp +211 -48
  42. multipers/tbb12.dll +0 -0
  43. multipers/tbbbind_2_5.dll +0 -0
  44. multipers/tbbmalloc.dll +0 -0
  45. multipers/tbbmalloc_proxy.dll +0 -0
  46. multipers/tensor/tensor.h +1 -1
  47. multipers/tests/__init__.py +9 -4
  48. multipers/torch/diff_grids.py +30 -7
  49. multipers/torch/rips_density.py +1 -1
  50. {multipers-2.3.0.dist-info → multipers-2.3.2.dist-info}/METADATA +4 -25
  51. {multipers-2.3.0.dist-info → multipers-2.3.2.dist-info}/RECORD +54 -51
  52. {multipers-2.3.0.dist-info → multipers-2.3.2.dist-info}/WHEEL +1 -1
  53. {multipers-2.3.0.dist-info → multipers-2.3.2.dist-info/licenses}/LICENSE +0 -0
  54. {multipers-2.3.0.dist-info → multipers-2.3.2.dist-info}/top_level.txt +0 -0
Binary file
multipers/io.pyx CHANGED
@@ -15,10 +15,10 @@ cimport cython
15
15
  # from multipers.filtration_conversions cimport *
16
16
  # from multipers.mma_structures cimport boundary_matrix,float,pair,vector,intptr_t
17
17
  # cimport numpy as cnp
18
-
18
+ current_doc_url = "https://davidlapous.github.io/multipers/"
19
19
  doc_soft_urls = {
20
20
  "mpfree":"https://bitbucket.org/mkerber/mpfree/",
21
- "multi_chunk":"",
21
+ "multi_chunk":"https://bitbucket.org/mkerber/multi_chunk/",
22
22
  "function_delaunay":"https://bitbucket.org/mkerber/function_delaunay/",
23
23
  "2pac":"https://gitlab.com/flenzen/2pac",
24
24
  }
@@ -27,7 +27,9 @@ doc_soft_easy_install = {
27
27
  ```sh
28
28
  git clone {doc_soft_urls["mpfree"]}
29
29
  cd mpfree
30
- sudo cp mpfree /usr/bin/
30
+ cmake . --fresh
31
+ make
32
+ cp mpfree $CONDA_PREFIX/bin/
31
33
  cd ..
32
34
  rm -rf mpfree
33
35
  ```
@@ -36,7 +38,9 @@ rm -rf mpfree
36
38
  ```sh
37
39
  git clone {doc_soft_urls["multi_chunk"]}
38
40
  cd multi_chunk
39
- sudo cp multi_chunk /usr/bin/
41
+ cmake . --fresh
42
+ make
43
+ cp multi_chunk $CONDA_PREFIX/bin/
40
44
  cd ..
41
45
  rm -rf multi_chunk
42
46
  ```
@@ -45,7 +49,9 @@ rm -rf multi_chunk
45
49
  ```sh
46
50
  git clone {doc_soft_urls["function_delaunay"]}
47
51
  cd function_delaunay
48
- sudo cp main /usr/bin/function_delaunay
52
+ cmake . --fresh
53
+ make
54
+ cp main $CONDA_PREFIX/bin/function_delaunay
49
55
  cd ..
50
56
  rm -rf function_delaunay
51
57
  ```
@@ -56,7 +62,7 @@ git clone {doc_soft_urls["2pac"]} 2pac
56
62
  cd 2pac && mkdir build && cd build
57
63
  cmake ..
58
64
  make
59
- sudo cp 2pac /usr/bin
65
+ cp 2pac $CONDA_PREFIX/bin
60
66
  ```
61
67
  """,
62
68
  }
@@ -65,7 +71,6 @@ doc_soft_easy_install = defaultdict(lambda:"<Unknown>", doc_soft_easy_install)
65
71
 
66
72
  available_reduce_softs = Literal["mpfree","multi_chunk","2pac"]
67
73
 
68
-
69
74
  def _path_init(soft:str|os.PathLike):
70
75
  a = which(f"./{soft}")
71
76
  b = which(f"{soft}")
@@ -104,6 +109,13 @@ cdef dict[str,str|None] pathes = {
104
109
  input_path:str|os.PathLike = "multipers_input.scc"
105
110
  output_path:str|os.PathLike = "multipers_output.scc"
106
111
 
112
+ def _put_temp_files_to_ram():
113
+ global input_path,output_path
114
+ shm_memory = "/tmp/" # on unix, we can write in RAM instead of disk.
115
+ if os.access(shm_memory, os.W_OK) and not input_path.startswith(shm_memory):
116
+ input_path = shm_memory + input_path
117
+ output_path = shm_memory + output_path
118
+ _put_temp_files_to_ram()
107
119
 
108
120
 
109
121
  ## TODO : optimize with Python.h ?
@@ -200,14 +212,6 @@ def scc_parser__old(path: str):
200
212
  return blocks
201
213
 
202
214
 
203
-
204
- def _put_temp_files_to_ram():
205
- global input_path,output_path
206
- shm_memory = "/tmp/" # on unix, we can write in RAM instead of disk.
207
- if os.access(shm_memory, os.W_OK) and not input_path.startswith(shm_memory):
208
- input_path = shm_memory + input_path
209
- output_path = shm_memory + output_path
210
-
211
215
  def _init_external_softwares(requires=[]):
212
216
  global pathes
213
217
  cdef bool any = False
@@ -222,12 +226,14 @@ def _init_external_softwares(requires=[]):
222
226
  if pathes[soft] is None:
223
227
  global doc_soft_urls
224
228
  raise ValueError(f"""
225
- Did not found {soft}.
229
+ Did not find {soft}.
226
230
  Install it from {doc_soft_urls[soft]}, and put it in your current directory,
227
231
  or in you $PATH.
232
+ Documentation is available here: {current_doc_url}compilation.html#external-libraries
228
233
  For instance:
229
234
  {doc_soft_easy_install[soft]}
230
235
  """)
236
+ _init_external_softwares()
231
237
  def _check_available(soft:str):
232
238
  _init_external_softwares()
233
239
  return pathes.get(soft,None) is not None
@@ -255,8 +261,7 @@ def scc_reduce_from_str(
255
261
  backend: "mpfree", "multi_chunk" or "2pac"
256
262
  """
257
263
  global pathes, input_path, output_path
258
- if pathes[backend] is None:
259
- _init_external_softwares(requires=[backend])
264
+ _init_external_softwares(requires=[backend])
260
265
 
261
266
 
262
267
  resolution_str = "--resolution" if full_resolution else ""
@@ -323,8 +328,7 @@ def scc_reduce_from_str_to_slicer(
323
328
  backend: "mpfree", "multi_chunk" or "2pac"
324
329
  """
325
330
  global pathes, input_path, output_path
326
- if pathes[backend] is None:
327
- _init_external_softwares(requires=[backend])
331
+ _init_external_softwares(requires=[backend])
328
332
 
329
333
 
330
334
  resolution_str = "--resolution" if full_resolution else ""
@@ -443,8 +447,7 @@ def function_delaunay_presentation(
443
447
  id = str(threading.get_native_id())
444
448
  global input_path, output_path, pathes
445
449
  backend = "function_delaunay"
446
- if pathes[backend] is None :
447
- _init_external_softwares(requires=[backend])
450
+ _init_external_softwares(requires=[backend])
448
451
 
449
452
  to_write = np.concatenate([point_cloud, function_values.reshape(-1,1)], axis=1)
450
453
  np.savetxt(input_path+id,to_write,delimiter=' ')
@@ -492,8 +495,7 @@ def function_delaunay_presentation_to_slicer(
492
495
  id = str(threading.get_native_id())
493
496
  global input_path, output_path, pathes
494
497
  backend = "function_delaunay"
495
- if pathes[backend] is None :
496
- _init_external_softwares(requires=[backend])
498
+ _init_external_softwares(requires=[backend])
497
499
 
498
500
  to_write = np.concatenate([point_cloud, function_values.reshape(-1,1)], axis=1)
499
501
  np.savetxt(input_path+id,to_write,delimiter=' ')
@@ -524,75 +526,6 @@ def clear_io(*args):
524
526
 
525
527
 
526
528
 
527
-
528
-
529
- # cdef extern from "multiparameter_module_approximation/format_python-cpp.h" namespace "Gudhi::multiparameter::mma":
530
- # pair[boundary_matrix, vector[One_critical_filtration[double]]] simplextree_to_boundary_filtration(intptr_t)
531
- # vector[pair[ vector[vector[float]],boundary_matrix]] simplextree_to_scc(intptr_t)
532
- # vector[pair[ vector[vector[vector[float]]],boundary_matrix]] function_simplextree_to_scc(intptr_t)
533
- # pair[vector[vector[float]],boundary_matrix ] simplextree_to_ordered_bf(intptr_t)
534
-
535
- # def simplex_tree2boundary_filtrations(simplextree:SimplexTreeMulti | SimplexTree):
536
- # """Computes a (sparse) boundary matrix, with associated filtration. Can be used as an input of approx afterwards.
537
- #
538
- # Parameters
539
- # ----------
540
- # simplextree: Gudhi or mma simplextree
541
- # The simplextree defining the filtration to convert to boundary-filtration.
542
- #
543
- # Returns
544
- # -------
545
- # B:List of lists of ints
546
- # The boundary matrix.
547
- # F: List of 1D filtration
548
- # The filtrations aligned with B; the i-th simplex of this simplextree has boundary B[i] and filtration(s) F[i].
549
- #
550
- # """
551
- # cdef intptr_t cptr
552
- # if isinstance(simplextree, SimplexTreeMulti):
553
- # cptr = simplextree.thisptr
554
- # elif isinstance(simplextree, SimplexTree):
555
- # temp_st = gd.SimplexTreeMulti(simplextree, parameters=1)
556
- # cptr = temp_st.thisptr
557
- # else:
558
- # raise TypeError("Has to be a simplextree")
559
- # cdef pair[boundary_matrix, vector[One_critical_filtration[double]]] cboundary_filtration = simplextree_to_boundary_filtration(cptr)
560
- # boundary = cboundary_filtration.first
561
- # # multi_filtrations = np.array(<vector[vector[float]]>One_critical_filtration.to_python(cboundary_filtration.second))
562
- # cdef cnp.ndarray[double, ndim=2] multi_filtrations = _fmf2numpy_f64(cboundary_filtration.second)
563
- # return boundary, multi_filtrations
564
-
565
- # def simplextree2scc(simplextree:SimplexTreeMulti | SimplexTree, filtration_dtype=np.float32, bool flattened=False):
566
- # """
567
- # Turns a simplextree into a (simplicial) module presentation.
568
- # """
569
- # cdef intptr_t cptr
570
- # cdef bool is_function_st = False
571
- # if isinstance(simplextree, SimplexTreeMulti):
572
- # cptr = simplextree.thisptr
573
- # is_function_st = simplextree._is_function_simplextree
574
- # elif isinstance(simplextree, SimplexTree):
575
- # temp_st = gd.SimplexTreeMulti(simplextree, parameters=1)
576
- # cptr = temp_st.thisptr
577
- # else:
578
- # raise TypeError("Has to be a simplextree")
579
- #
580
- # cdef pair[vector[vector[float]], boundary_matrix] out
581
- # if flattened:
582
- # out = simplextree_to_ordered_bf(cptr)
583
- # return np.asarray(out.first,dtype=filtration_dtype), tuple(out.second)
584
- #
585
- # if is_function_st:
586
- # blocks = function_simplextree_to_scc(cptr)
587
- # else:
588
- # blocks = simplextree_to_scc(cptr)
589
- # # reduces the space in memory
590
- # if is_function_st:
591
- # blocks = [(tuple(f), tuple(b)) for f,b in blocks[::-1]]
592
- # else:
593
- # blocks = [(np.asarray(f,dtype=filtration_dtype), tuple(b)) for f,b in blocks[::-1]] ## presentation is on the other order
594
- # return blocks+[(np.empty(0,dtype=filtration_dtype),[])]
595
-
596
529
  @cython.boundscheck(False)
597
530
  @cython.wraparound(False)
598
531
  def scc2disk(
multipers/ml/mma.py CHANGED
@@ -8,7 +8,7 @@ from tqdm import tqdm
8
8
  import multipers as mp
9
9
  import multipers.simplex_tree_multi
10
10
  import multipers.slicer
11
- from multipers.grids import compute_grid as reduce_grid
11
+ from multipers.grids import compute_grid
12
12
  from multipers.mma_structures import PyBox_f64, PyModule_type
13
13
 
14
14
  _FilteredComplexType = Union[
@@ -353,7 +353,7 @@ class MMAFormatter(BaseEstimator, TransformerMixin):
353
353
  if "_mean" in strategy:
354
354
  substrategy = strategy.split("_")[0]
355
355
  processed_filtration_values = [
356
- reduce_grid(f, resolution, substrategy, unique=False)
356
+ compute_grid(f, resolution, substrategy, unique=False)
357
357
  for f in filtration_values
358
358
  ]
359
359
  reduced_grid = np.mean(processed_filtration_values, axis=0)
@@ -368,7 +368,7 @@ class MMAFormatter(BaseEstimator, TransformerMixin):
368
368
  )
369
369
  for parameter in range(num_parameters)
370
370
  ]
371
- reduced_grid = reduce_grid(
371
+ reduced_grid = compute_grid(
372
372
  filtration_values, resolution, strategy, unique=True
373
373
  )
374
374
 
@@ -478,7 +478,7 @@ class MMAFormatter(BaseEstimator, TransformerMixin):
478
478
  if self.weights is None
479
479
  else np.asarray(self.weights)
480
480
  )
481
- standard_box = PyBox_f64([0] * self._num_parameters, w)
481
+ standard_box = np.array([[0] * self._num_parameters, w])
482
482
 
483
483
  X_copy = [
484
484
  [
multipers/ml/point_clouds.py CHANGED
@@ -10,7 +10,7 @@ from tqdm import tqdm
10
10
 
11
11
  import multipers as mp
12
12
  import multipers.slicer as mps
13
- from multipers.ml.convolutions import DTM, KDE, available_kernels
13
+ from multipers.filtrations.density import DTM, KDE, available_kernels
14
14
 
15
15
 
16
16
  class PointCloud2FilteredComplex(BaseEstimator, TransformerMixin):
@@ -176,7 +176,7 @@ class PointCloud2FilteredComplex(BaseEstimator, TransformerMixin):
176
176
  st = alpha_complex.create_simplex_tree(max_alpha_square=self._threshold**2)
177
177
  vertices = np.array([i for (i,), _ in st.get_skeleton(0)])
178
178
  new_points = np.asarray(
179
- [alpha_complex.get_point(i) for i in vertices]
179
+ [alpha_complex.get_point(int(i)) for i in vertices]
180
180
  ) # Seems to be unsafe for some reason
181
181
  # new_points = x
182
182
  st = mp.simplex_tree_multi.SimplexTreeMulti(
multipers/ml/signed_measures.py CHANGED
@@ -1,4 +1,4 @@
1
- from collections.abc import Callable, Iterable, Sequence
1
+ from collections.abc import Iterable, Sequence
2
2
  from itertools import product
3
3
  from typing import Optional, Union
4
4
 
@@ -9,9 +9,10 @@ from sklearn.base import BaseEstimator, TransformerMixin
9
9
  from tqdm import tqdm
10
10
 
11
11
  import multipers as mp
12
- from multipers.grids import compute_grid as reduce_grid
13
- from multipers.ml.convolutions import available_kernels, convolution_signed_measures
14
- from multipers.point_measure import signed_betti, rank_decomposition_by_rectangles
12
+ from multipers.array_api import api_from_tensor
13
+ from multipers.filtrations.density import available_kernels, convolution_signed_measures
14
+ from multipers.grids import compute_grid
15
+ from multipers.point_measure import rank_decomposition_by_rectangles, signed_betti
15
16
 
16
17
 
17
18
  class FilteredComplex2SignedMeasure(BaseEstimator, TransformerMixin):
@@ -145,7 +146,7 @@ class FilteredComplex2SignedMeasure(BaseEstimator, TransformerMixin):
145
146
  ## ax, num_x
146
147
  filtrations = tuple(
147
148
  tuple(
148
- reduce_grid(x, strategy="exact")
149
+ compute_grid(x, strategy="exact")
149
150
  for x in (X[idx][ax] for idx in indices)
150
151
  )
151
152
  for ax in range(self._num_axis)
@@ -164,7 +165,7 @@ class FilteredComplex2SignedMeasure(BaseEstimator, TransformerMixin):
164
165
  ]
165
166
  ## ax, param, gridsize
166
167
  filtration_grid = tuple(
167
- reduce_grid(
168
+ compute_grid(
168
169
  filtrations_values[ax],
169
170
  resolution=self.resolution,
170
171
  strategy=self.grid_strategy,
@@ -237,7 +238,7 @@ class FilteredComplex2SignedMeasure(BaseEstimator, TransformerMixin):
237
238
  ):
238
239
  # st = mp.SimplexTreeMulti(st, num_parameters=st.num_parameters) # COPY
239
240
  if self.individual_grid:
240
- filtration_grid = reduce_grid(
241
+ filtration_grid = compute_grid(
241
242
  simplextree, strategy=self.grid_strategy, resolution=self.resolution
242
243
  )
243
244
  mass_default = (
@@ -546,6 +547,27 @@ def rescale_sparse_signed_measure(
546
547
  return out
547
548
 
548
549
 
550
+ def sm2deep(signed_measure):
551
+ dirac_positions, dirac_signs = signed_measure
552
+ dtype = dirac_positions.dtype
553
+ new_shape = list(dirac_positions.shape)
554
+ new_shape[1] += 1
555
+ if isinstance(dirac_positions, np.ndarray):
556
+ c = np.empty(new_shape, dtype=dtype)
557
+ c[:, :-1] = dirac_positions
558
+ c[:, -1] = dirac_signs
559
+
560
+ else:
561
+ import torch
562
+
563
+ c = torch.empty(new_shape, dtype=dtype)
564
+ c[:, :-1] = dirac_positions
565
+ if isinstance(dirac_signs, np.ndarray):
566
+ dirac_signs = torch.from_numpy(dirac_signs)
567
+ c[:, -1] = dirac_signs
568
+ return c
569
+
570
+
549
571
  class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
550
572
  """
551
573
  Input
@@ -612,14 +634,8 @@ class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
612
634
  return
613
635
 
614
636
  def _get_filtration_bounds(self, X, axis):
615
- if self._backend == "numpy":
616
- _cat = np.concatenate
617
-
618
- else:
619
- ## torch is globally imported
620
- _cat = torch.cat
621
637
  stuff = [
622
- _cat(
638
+ self._backend.cat(
623
639
  [sm[axis][degree][0] for sm in X],
624
640
  axis=0,
625
641
  )
@@ -627,17 +643,20 @@ class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
627
643
  ]
628
644
  sizes_ = np.array([len(x) == 0 for x in stuff])
629
645
  assert np.all(~sizes_), f"Degree axis {np.where(sizes_)} is/are trivial !"
630
- if self._backend == "numpy":
631
- filtrations_bounds = np.array(
632
- [([f.min(axis=0), f.max(axis=0)]) for f in stuff]
633
- )
634
- else:
635
- filtrations_bounds = torch.stack(
646
+
647
+ filtrations_bounds = self._backend.asnumpy(
648
+ self._backend.stack(
636
649
  [
637
- torch.stack([f.min(axis=0).values, f.max(axis=0).values])
650
+ self._backend.stack(
651
+ [
652
+ self._backend.minvalues(f, axis=0),
653
+ self._backend.maxvalues(f, axis=0),
654
+ ]
655
+ )
638
656
  for f in stuff
639
657
  ]
640
- ).detach() ## don't want to rescale gradient of normalization
658
+ )
659
+ ) ## don't want to rescale gradient of normalization
641
660
  normalization_factors = (
642
661
  filtrations_bounds[:, 1] - filtrations_bounds[:, 0]
643
662
  if self.normalize
@@ -700,14 +719,7 @@ class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
700
719
  first_sm = X[0][0][0][0]
701
720
  else:
702
721
  first_sm = X[0][0][0]
703
- if isinstance(first_sm, np.ndarray):
704
- self._backend = "numpy"
705
- else:
706
- global torch
707
- import torch
708
-
709
- assert isinstance(first_sm, torch.Tensor)
710
- self._backend = "pytorch"
722
+ self._backend = api_from_tensor(first_sm, verbose=self.verbose)
711
723
 
712
724
  def _check_measures(self, X):
713
725
  if self._has_axis:
@@ -770,7 +782,7 @@ class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
770
782
  ]
771
783
  # axis, filtration_values
772
784
  filtration_values = [
773
- reduce_grid(
785
+ compute_grid(
774
786
  f_ax.T, resolution=self.resolution, strategy=self.grid_strategy
775
787
  )
776
788
  for f_ax in filtration_values
@@ -841,25 +853,6 @@ class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
841
853
  else:
842
854
  return np.asarray(out)[0]
843
855
 
844
- @staticmethod
845
- def deep_format_measure(signed_measure):
846
- dirac_positions, dirac_signs = signed_measure
847
- dtype = dirac_positions.dtype
848
- new_shape = list(dirac_positions.shape)
849
- new_shape[1] += 1
850
- if isinstance(dirac_positions, np.ndarray):
851
- c = np.empty(new_shape, dtype=dtype)
852
- c[:, :-1] = dirac_positions
853
- c[:, -1] = dirac_signs
854
-
855
- else:
856
- import torch
857
-
858
- c = torch.empty(new_shape, dtype=dtype)
859
- c[:, :-1] = dirac_positions
860
- c[:, -1] = dirac_signs
861
- return c
862
-
863
856
  @staticmethod
864
857
  def _integrate_measure(sm, filtrations):
865
858
  from multipers.point_measure import integrate_measure
@@ -928,7 +921,7 @@ class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
928
921
  elif self.deep_format:
929
922
  num_degrees = self._num_degrees
930
923
  out = tuple(
931
- tuple(self.deep_format_measure(sm[axis][degree]) for sm in out)
924
+ tuple(sm2deep(sm[axis][degree]) for sm in out)
932
925
  for degree in range(num_degrees)
933
926
  for axis in self._axis_iterator
934
927
  )
@@ -943,11 +936,7 @@ class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
943
936
  ), f"Bad axis/degree count. Got {num_axis_degree} (Internal error)"
944
937
  num_parameters = out[0][0].shape[1]
945
938
  dtype = out[0][0].dtype
946
- if isinstance(out[0][0], np.ndarray):
947
- from numpy import zeros
948
- else:
949
- from torch import zeros
950
- unragged_tensor = zeros(
939
+ unragged_tensor = self._backend.zeros(
951
940
  (
952
941
  num_axis_degree,
953
942
  num_data,
@@ -1027,6 +1016,7 @@ class SignedMeasure2Convolution(BaseEstimator, TransformerMixin):
1027
1016
  self.plot = plot
1028
1017
  self.log_density = log_density
1029
1018
  self.kde_kwargs = kde_kwargs
1019
+ self._api = None
1030
1020
  return
1031
1021
 
1032
1022
  def fit(self, X, y=None):
@@ -1035,13 +1025,16 @@ class SignedMeasure2Convolution(BaseEstimator, TransformerMixin):
1035
1025
  return self
1036
1026
  if isinstance(X[0][0], tuple):
1037
1027
  self._is_input_sparse = True
1028
+
1029
+ self._api = api_from_tensor(X[0][0][0], verbose=self.progress)
1038
1030
  else:
1039
1031
  self._is_input_sparse = False
1032
+
1033
+ self._api = api_from_tensor(X, verbose=self.progress)
1040
1034
  # print(f"IMG output is set to {'sparse' if self.sparse else 'matrix'}")
1041
1035
  if not self._is_input_sparse:
1042
1036
  self._input_resolution = X[0][0].shape
1043
1037
  try:
1044
- float(self.bandwidth)
1045
1038
  b = float(self.bandwidth)
1046
1039
  self._bandwidths = [
1047
1040
  b if b > 0 else -b * s for s in self._input_resolution
@@ -1060,24 +1053,24 @@ class SignedMeasure2Convolution(BaseEstimator, TransformerMixin):
1060
1053
  # If not sparse : a grid has to be defined
1061
1054
  if self._refit:
1062
1055
  # print("Fitting a grid...", end="")
1063
- pts = np.concatenate(
1056
+ pts = self._api.cat(
1064
1057
  [sm[0] for signed_measures in X for sm in signed_measures]
1065
1058
  ).T
1066
- self.filtration_grid = reduce_grid(
1059
+ self.filtration_grid = compute_grid(
1067
1060
  pts,
1068
1061
  strategy=self.grid_strategy,
1069
1062
  resolution=self.resolution,
1070
1063
  )
1071
1064
  # print('Done.')
1072
1065
  if self.filtration_grid is not None:
1073
- self.diameter = np.linalg.norm(
1074
- [f.max() - f.min() for f in self.filtration_grid]
1066
+ self.diameter = self._api.norm(
1067
+ self._api.astensor([f[-1] - f[0] for f in self.filtration_grid])
1075
1068
  )
1076
1069
  if self.progress:
1077
1070
  print(f"Computed a diameter of {self.diameter}")
1078
1071
  return self
1079
1072
 
1080
- def _sm2smi(self, signed_measures: Iterable[np.ndarray]):
1073
+ def _sm2smi(self, signed_measures):
1081
1074
  # print(self._input_resolution, self.bandwidths, _bandwidths)
1082
1075
  from scipy.ndimage import gaussian_filter
1083
1076
 
@@ -1125,6 +1118,7 @@ class SignedMeasure2Convolution(BaseEstimator, TransformerMixin):
1125
1118
  def _plot_imgs(self, imgs: Iterable[np.ndarray], size=4):
1126
1119
  from multipers.plots import plot_surface
1127
1120
 
1121
+ imgs = self._api.asnumpy(imgs)
1128
1122
  num_degrees = imgs[0].shape[0]
1129
1123
  num_imgs = len(imgs)
1130
1124
  fig, axes = plt.subplots(
@@ -1137,7 +1131,10 @@ class SignedMeasure2Convolution(BaseEstimator, TransformerMixin):
1137
1131
  for i, img in enumerate(imgs):
1138
1132
  for j, img_of_degree in enumerate(img):
1139
1133
  plot_surface(
1140
- self.filtration_grid, img_of_degree, ax=axes[i, j], cmap="Spectral"
1134
+ [self._api.asnumpy(f) for f in self.filtration_grid],
1135
+ img_of_degree,
1136
+ ax=axes[i, j],
1137
+ cmap="Spectral",
1141
1138
  )
1142
1139
 
1143
1140
  def transform(self, X):
@@ -1153,6 +1150,7 @@ class SignedMeasure2Convolution(BaseEstimator, TransformerMixin):
1153
1150
  X, desc="Computing images", disable=not self.progress
1154
1151
  )
1155
1152
  )
1153
+ out = self._api.cat([x[None] for x in out])
1156
1154
  if self.plot and not self.flatten:
1157
1155
  if self.progress:
1158
1156
  print("Plotting convolutions...", end="")
@@ -1160,8 +1158,8 @@ class SignedMeasure2Convolution(BaseEstimator, TransformerMixin):
1160
1158
  if self.progress:
1161
1159
  print("Done !")
1162
1160
  if self.flatten and not self._is_input_sparse:
1163
- out = [x.flatten() for x in out]
1164
- return np.asarray(out)
1161
+ out = self._api.cat([x.ravel()[None] for x in out])
1162
+ return out
1165
1163
 
1166
1164
 
1167
1165
  class SignedMeasure2SlicedWassersteinDistance(BaseEstimator, TransformerMixin):
multipers/mma_structures.pxd CHANGED
@@ -30,7 +30,7 @@ cdef extern from "multiparameter_module_approximation/approximation.h" namespace
30
30
  bool contains(const vector[T]&) nogil const
31
31
  Box[T] get_bounds() nogil const
32
32
  void rescale(const vector[T]&) nogil
33
-
33
+ bool operator==(const Summand[T]&) nogil
34
34
 
35
35
 
36
36
 
@@ -86,6 +86,7 @@ cdef extern from "multiparameter_module_approximation/approximation.h" namespace
86
86
  Summand[T]& at(unsigned int) nogil
87
87
  vector[Summand[T]].iterator begin()
88
88
  vector[Summand[T]].iterator end()
89
+ bool operator==(const Module[T]&) nogil
89
90
  void clean(const bool) nogil
90
91
  void fill(const T) nogil
91
92
  # vector[image_type] get_vectorization(const T,const T, unsigned int,unsigned int,const Box&)