multipers 2.3.1__cp310-cp310-win_amd64.whl → 2.3.2b1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic. Click here for more details.

Files changed (49)
  1. multipers/_signed_measure_meta.py +71 -65
  2. multipers/array_api/__init__.py +39 -0
  3. multipers/array_api/numpy.py +34 -0
  4. multipers/array_api/torch.py +35 -0
  5. multipers/distances.py +6 -2
  6. multipers/filtrations/density.py +23 -12
  7. multipers/filtrations/filtrations.py +74 -15
  8. multipers/function_rips.cp310-win_amd64.pyd +0 -0
  9. multipers/grids.cp310-win_amd64.pyd +0 -0
  10. multipers/grids.pyx +144 -61
  11. multipers/gudhi/Simplex_tree_multi_interface.h +35 -0
  12. multipers/gudhi/gudhi/Multi_persistence/Box.h +3 -0
  13. multipers/gudhi/gudhi/One_critical_filtration.h +17 -9
  14. multipers/gudhi/mma_interface_matrix.h +5 -3
  15. multipers/gudhi/truc.h +488 -42
  16. multipers/io.cp310-win_amd64.pyd +0 -0
  17. multipers/io.pyx +16 -86
  18. multipers/ml/mma.py +3 -3
  19. multipers/ml/signed_measures.py +60 -62
  20. multipers/mma_structures.cp310-win_amd64.pyd +0 -0
  21. multipers/mma_structures.pxd +2 -1
  22. multipers/mma_structures.pyx +56 -12
  23. multipers/mma_structures.pyx.tp +14 -3
  24. multipers/multiparameter_module_approximation/approximation.h +45 -13
  25. multipers/multiparameter_module_approximation.cp310-win_amd64.pyd +0 -0
  26. multipers/multiparameter_module_approximation.pyx +22 -6
  27. multipers/plots.py +1 -0
  28. multipers/point_measure.cp310-win_amd64.pyd +0 -0
  29. multipers/point_measure.pyx +6 -2
  30. multipers/simplex_tree_multi.cp310-win_amd64.pyd +0 -0
  31. multipers/simplex_tree_multi.pxd +1 -0
  32. multipers/simplex_tree_multi.pyx +487 -109
  33. multipers/simplex_tree_multi.pyx.tp +67 -18
  34. multipers/slicer.cp310-win_amd64.pyd +0 -0
  35. multipers/slicer.pxd +699 -217
  36. multipers/slicer.pxd.tp +22 -6
  37. multipers/slicer.pyx +5311 -1364
  38. multipers/slicer.pyx.tp +199 -46
  39. multipers/tbb12.dll +0 -0
  40. multipers/tbbbind_2_5.dll +0 -0
  41. multipers/tbbmalloc.dll +0 -0
  42. multipers/tbbmalloc_proxy.dll +0 -0
  43. multipers/tests/__init__.py +9 -4
  44. multipers/torch/diff_grids.py +30 -7
  45. {multipers-2.3.1.dist-info → multipers-2.3.2b1.dist-info}/METADATA +4 -25
  46. {multipers-2.3.1.dist-info → multipers-2.3.2b1.dist-info}/RECORD +49 -46
  47. {multipers-2.3.1.dist-info → multipers-2.3.2b1.dist-info}/WHEEL +1 -1
  48. {multipers-2.3.1.dist-info → multipers-2.3.2b1.dist-info/licenses}/LICENSE +0 -0
  49. {multipers-2.3.1.dist-info → multipers-2.3.2b1.dist-info}/top_level.txt +0 -0
Binary file
multipers/io.pyx CHANGED
@@ -18,7 +18,7 @@ cimport cython
18
18
  current_doc_url = "https://davidlapous.github.io/multipers/"
19
19
  doc_soft_urls = {
20
20
  "mpfree":"https://bitbucket.org/mkerber/mpfree/",
21
- "multi_chunk":"",
21
+ "multi_chunk":"https://bitbucket.org/mkerber/multi_chunk/",
22
22
  "function_delaunay":"https://bitbucket.org/mkerber/function_delaunay/",
23
23
  "2pac":"https://gitlab.com/flenzen/2pac",
24
24
  }
@@ -29,7 +29,7 @@ git clone {doc_soft_urls["mpfree"]}
29
29
  cd mpfree
30
30
  cmake . --fresh
31
31
  make
32
- sudo cp mpfree /usr/bin/
32
+ cp mpfree $CONDA_PREFIX/bin/
33
33
  cd ..
34
34
  rm -rf mpfree
35
35
  ```
@@ -40,7 +40,7 @@ git clone {doc_soft_urls["multi_chunk"]}
40
40
  cd multi_chunk
41
41
  cmake . --fresh
42
42
  make
43
- sudo cp multi_chunk /usr/bin/
43
+ cp multi_chunk $CONDA_PREFIX/bin/
44
44
  cd ..
45
45
  rm -rf multi_chunk
46
46
  ```
@@ -51,7 +51,7 @@ git clone {doc_soft_urls["function_delaunay"]}
51
51
  cd function_delaunay
52
52
  cmake . --fresh
53
53
  make
54
- sudo cp main /usr/bin/function_delaunay
54
+ cp main $CONDA_PREFIX/bin/function_delaunay
55
55
  cd ..
56
56
  rm -rf function_delaunay
57
57
  ```
@@ -62,7 +62,7 @@ git clone {doc_soft_urls["2pac"]} 2pac
62
62
  cd 2pac && mkdir build && cd build
63
63
  cmake ..
64
64
  make
65
- sudo cp 2pac /usr/bin
65
+ cp 2pac $CONDA_PREFIX/bin
66
66
  ```
67
67
  """,
68
68
  }
@@ -109,6 +109,13 @@ cdef dict[str,str|None] pathes = {
109
109
  input_path:str|os.PathLike = "multipers_input.scc"
110
110
  output_path:str|os.PathLike = "multipers_output.scc"
111
111
 
112
+ def _put_temp_files_to_ram():
113
+ global input_path,output_path
114
+ shm_memory = "/tmp/" # on unix, we can write in RAM instead of disk.
115
+ if os.access(shm_memory, os.W_OK) and not input_path.startswith(shm_memory):
116
+ input_path = shm_memory + input_path
117
+ output_path = shm_memory + output_path
118
+ _put_temp_files_to_ram()
112
119
 
113
120
 
114
121
  ## TODO : optimize with Python.h ?
@@ -205,14 +212,6 @@ def scc_parser__old(path: str):
205
212
  return blocks
206
213
 
207
214
 
208
-
209
- def _put_temp_files_to_ram():
210
- global input_path,output_path
211
- shm_memory = "/tmp/" # on unix, we can write in RAM instead of disk.
212
- if os.access(shm_memory, os.W_OK) and not input_path.startswith(shm_memory):
213
- input_path = shm_memory + input_path
214
- output_path = shm_memory + output_path
215
-
216
215
  def _init_external_softwares(requires=[]):
217
216
  global pathes
218
217
  cdef bool any = False
@@ -262,7 +261,7 @@ def scc_reduce_from_str(
262
261
  backend: "mpfree", "multi_chunk" or "2pac"
263
262
  """
264
263
  global pathes, input_path, output_path
265
- assert _check_available(backend), f"Backend {backend} is not available."
264
+ _init_external_softwares(requires=[backend])
266
265
 
267
266
 
268
267
  resolution_str = "--resolution" if full_resolution else ""
@@ -329,7 +328,7 @@ def scc_reduce_from_str_to_slicer(
329
328
  backend: "mpfree", "multi_chunk" or "2pac"
330
329
  """
331
330
  global pathes, input_path, output_path
332
- assert _check_available(backend), f"Backend {backend} is not available."
331
+ _init_external_softwares(requires=[backend])
333
332
 
334
333
 
335
334
  resolution_str = "--resolution" if full_resolution else ""
@@ -448,7 +447,7 @@ def function_delaunay_presentation(
448
447
  id = str(threading.get_native_id())
449
448
  global input_path, output_path, pathes
450
449
  backend = "function_delaunay"
451
- assert _check_available(backend), f"Backend {backend} is not available."
450
+ _init_external_softwares(requires=[backend])
452
451
 
453
452
  to_write = np.concatenate([point_cloud, function_values.reshape(-1,1)], axis=1)
454
453
  np.savetxt(input_path+id,to_write,delimiter=' ')
@@ -496,7 +495,7 @@ def function_delaunay_presentation_to_slicer(
496
495
  id = str(threading.get_native_id())
497
496
  global input_path, output_path, pathes
498
497
  backend = "function_delaunay"
499
- assert _check_available(backend), f"Backend {backend} is not available."
498
+ _init_external_softwares(requires=[backend])
500
499
 
501
500
  to_write = np.concatenate([point_cloud, function_values.reshape(-1,1)], axis=1)
502
501
  np.savetxt(input_path+id,to_write,delimiter=' ')
@@ -527,75 +526,6 @@ def clear_io(*args):
527
526
 
528
527
 
529
528
 
530
-
531
-
532
- # cdef extern from "multiparameter_module_approximation/format_python-cpp.h" namespace "Gudhi::multiparameter::mma":
533
- # pair[boundary_matrix, vector[One_critical_filtration[double]]] simplextree_to_boundary_filtration(intptr_t)
534
- # vector[pair[ vector[vector[float]],boundary_matrix]] simplextree_to_scc(intptr_t)
535
- # vector[pair[ vector[vector[vector[float]]],boundary_matrix]] function_simplextree_to_scc(intptr_t)
536
- # pair[vector[vector[float]],boundary_matrix ] simplextree_to_ordered_bf(intptr_t)
537
-
538
- # def simplex_tree2boundary_filtrations(simplextree:SimplexTreeMulti | SimplexTree):
539
- # """Computes a (sparse) boundary matrix, with associated filtration. Can be used as an input of approx afterwards.
540
- #
541
- # Parameters
542
- # ----------
543
- # simplextree: Gudhi or mma simplextree
544
- # The simplextree defining the filtration to convert to boundary-filtration.
545
- #
546
- # Returns
547
- # -------
548
- # B:List of lists of ints
549
- # The boundary matrix.
550
- # F: List of 1D filtration
551
- # The filtrations aligned with B; the i-th simplex of this simplextree has boundary B[i] and filtration(s) F[i].
552
- #
553
- # """
554
- # cdef intptr_t cptr
555
- # if isinstance(simplextree, SimplexTreeMulti):
556
- # cptr = simplextree.thisptr
557
- # elif isinstance(simplextree, SimplexTree):
558
- # temp_st = gd.SimplexTreeMulti(simplextree, parameters=1)
559
- # cptr = temp_st.thisptr
560
- # else:
561
- # raise TypeError("Has to be a simplextree")
562
- # cdef pair[boundary_matrix, vector[One_critical_filtration[double]]] cboundary_filtration = simplextree_to_boundary_filtration(cptr)
563
- # boundary = cboundary_filtration.first
564
- # # multi_filtrations = np.array(<vector[vector[float]]>One_critical_filtration.to_python(cboundary_filtration.second))
565
- # cdef cnp.ndarray[double, ndim=2] multi_filtrations = _fmf2numpy_f64(cboundary_filtration.second)
566
- # return boundary, multi_filtrations
567
-
568
- # def simplextree2scc(simplextree:SimplexTreeMulti | SimplexTree, filtration_dtype=np.float32, bool flattened=False):
569
- # """
570
- # Turns a simplextree into a (simplicial) module presentation.
571
- # """
572
- # cdef intptr_t cptr
573
- # cdef bool is_function_st = False
574
- # if isinstance(simplextree, SimplexTreeMulti):
575
- # cptr = simplextree.thisptr
576
- # is_function_st = simplextree._is_function_simplextree
577
- # elif isinstance(simplextree, SimplexTree):
578
- # temp_st = gd.SimplexTreeMulti(simplextree, parameters=1)
579
- # cptr = temp_st.thisptr
580
- # else:
581
- # raise TypeError("Has to be a simplextree")
582
- #
583
- # cdef pair[vector[vector[float]], boundary_matrix] out
584
- # if flattened:
585
- # out = simplextree_to_ordered_bf(cptr)
586
- # return np.asarray(out.first,dtype=filtration_dtype), tuple(out.second)
587
- #
588
- # if is_function_st:
589
- # blocks = function_simplextree_to_scc(cptr)
590
- # else:
591
- # blocks = simplextree_to_scc(cptr)
592
- # # reduces the space in memory
593
- # if is_function_st:
594
- # blocks = [(tuple(f), tuple(b)) for f,b in blocks[::-1]]
595
- # else:
596
- # blocks = [(np.asarray(f,dtype=filtration_dtype), tuple(b)) for f,b in blocks[::-1]] ## presentation is on the other order
597
- # return blocks+[(np.empty(0,dtype=filtration_dtype),[])]
598
-
599
529
  @cython.boundscheck(False)
600
530
  @cython.wraparound(False)
601
531
  def scc2disk(
multipers/ml/mma.py CHANGED
@@ -8,7 +8,7 @@ from tqdm import tqdm
8
8
  import multipers as mp
9
9
  import multipers.simplex_tree_multi
10
10
  import multipers.slicer
11
- from multipers.grids import compute_grid as reduce_grid
11
+ from multipers.grids import compute_grid
12
12
  from multipers.mma_structures import PyBox_f64, PyModule_type
13
13
 
14
14
  _FilteredComplexType = Union[
@@ -353,7 +353,7 @@ class MMAFormatter(BaseEstimator, TransformerMixin):
353
353
  if "_mean" in strategy:
354
354
  substrategy = strategy.split("_")[0]
355
355
  processed_filtration_values = [
356
- reduce_grid(f, resolution, substrategy, unique=False)
356
+ compute_grid(f, resolution, substrategy, unique=False)
357
357
  for f in filtration_values
358
358
  ]
359
359
  reduced_grid = np.mean(processed_filtration_values, axis=0)
@@ -368,7 +368,7 @@ class MMAFormatter(BaseEstimator, TransformerMixin):
368
368
  )
369
369
  for parameter in range(num_parameters)
370
370
  ]
371
- reduced_grid = reduce_grid(
371
+ reduced_grid = compute_grid(
372
372
  filtration_values, resolution, strategy, unique=True
373
373
  )
374
374
 
@@ -9,8 +9,9 @@ from sklearn.base import BaseEstimator, TransformerMixin
9
9
  from tqdm import tqdm
10
10
 
11
11
  import multipers as mp
12
- from multipers.grids import compute_grid as reduce_grid
12
+ from multipers.array_api import api_from_tensor
13
13
  from multipers.filtrations.density import available_kernels, convolution_signed_measures
14
+ from multipers.grids import compute_grid
14
15
  from multipers.point_measure import rank_decomposition_by_rectangles, signed_betti
15
16
 
16
17
 
@@ -145,7 +146,7 @@ class FilteredComplex2SignedMeasure(BaseEstimator, TransformerMixin):
145
146
  ## ax, num_x
146
147
  filtrations = tuple(
147
148
  tuple(
148
- reduce_grid(x, strategy="exact")
149
+ compute_grid(x, strategy="exact")
149
150
  for x in (X[idx][ax] for idx in indices)
150
151
  )
151
152
  for ax in range(self._num_axis)
@@ -164,7 +165,7 @@ class FilteredComplex2SignedMeasure(BaseEstimator, TransformerMixin):
164
165
  ]
165
166
  ## ax, param, gridsize
166
167
  filtration_grid = tuple(
167
- reduce_grid(
168
+ compute_grid(
168
169
  filtrations_values[ax],
169
170
  resolution=self.resolution,
170
171
  strategy=self.grid_strategy,
@@ -237,7 +238,7 @@ class FilteredComplex2SignedMeasure(BaseEstimator, TransformerMixin):
237
238
  ):
238
239
  # st = mp.SimplexTreeMulti(st, num_parameters=st.num_parameters) # COPY
239
240
  if self.individual_grid:
240
- filtration_grid = reduce_grid(
241
+ filtration_grid = compute_grid(
241
242
  simplextree, strategy=self.grid_strategy, resolution=self.resolution
242
243
  )
243
244
  mass_default = (
@@ -546,6 +547,27 @@ def rescale_sparse_signed_measure(
546
547
  return out
547
548
 
548
549
 
550
+ def sm2deep(signed_measure):
551
+ dirac_positions, dirac_signs = signed_measure
552
+ dtype = dirac_positions.dtype
553
+ new_shape = list(dirac_positions.shape)
554
+ new_shape[1] += 1
555
+ if isinstance(dirac_positions, np.ndarray):
556
+ c = np.empty(new_shape, dtype=dtype)
557
+ c[:, :-1] = dirac_positions
558
+ c[:, -1] = dirac_signs
559
+
560
+ else:
561
+ import torch
562
+
563
+ c = torch.empty(new_shape, dtype=dtype)
564
+ c[:, :-1] = dirac_positions
565
+ if isinstance(dirac_signs, np.ndarray):
566
+ dirac_signs = torch.from_numpy(dirac_signs)
567
+ c[:, -1] = dirac_signs
568
+ return c
569
+
570
+
549
571
  class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
550
572
  """
551
573
  Input
@@ -612,14 +634,8 @@ class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
612
634
  return
613
635
 
614
636
  def _get_filtration_bounds(self, X, axis):
615
- if self._backend == "numpy":
616
- _cat = np.concatenate
617
-
618
- else:
619
- ## torch is globally imported
620
- _cat = torch.cat
621
637
  stuff = [
622
- _cat(
638
+ self._backend.cat(
623
639
  [sm[axis][degree][0] for sm in X],
624
640
  axis=0,
625
641
  )
@@ -627,17 +643,20 @@ class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
627
643
  ]
628
644
  sizes_ = np.array([len(x) == 0 for x in stuff])
629
645
  assert np.all(~sizes_), f"Degree axis {np.where(sizes_)} is/are trivial !"
630
- if self._backend == "numpy":
631
- filtrations_bounds = np.array(
632
- [([f.min(axis=0), f.max(axis=0)]) for f in stuff]
633
- )
634
- else:
635
- filtrations_bounds = torch.stack(
646
+
647
+ filtrations_bounds = self._backend.asnumpy(
648
+ self._backend.stack(
636
649
  [
637
- torch.stack([f.min(axis=0).values, f.max(axis=0).values])
650
+ self._backend.stack(
651
+ [
652
+ self._backend.minvalues(f, axis=0),
653
+ self._backend.maxvalues(f, axis=0),
654
+ ]
655
+ )
638
656
  for f in stuff
639
657
  ]
640
- ).detach() ## don't want to rescale gradient of normalization
658
+ )
659
+ ) ## don't want to rescale gradient of normalization
641
660
  normalization_factors = (
642
661
  filtrations_bounds[:, 1] - filtrations_bounds[:, 0]
643
662
  if self.normalize
@@ -700,14 +719,7 @@ class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
700
719
  first_sm = X[0][0][0][0]
701
720
  else:
702
721
  first_sm = X[0][0][0]
703
- if isinstance(first_sm, np.ndarray):
704
- self._backend = "numpy"
705
- else:
706
- global torch
707
- import torch
708
-
709
- assert isinstance(first_sm, torch.Tensor)
710
- self._backend = "pytorch"
722
+ self._backend = api_from_tensor(first_sm, verbose=self.verbose)
711
723
 
712
724
  def _check_measures(self, X):
713
725
  if self._has_axis:
@@ -770,7 +782,7 @@ class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
770
782
  ]
771
783
  # axis, filtration_values
772
784
  filtration_values = [
773
- reduce_grid(
785
+ compute_grid(
774
786
  f_ax.T, resolution=self.resolution, strategy=self.grid_strategy
775
787
  )
776
788
  for f_ax in filtration_values
@@ -841,25 +853,6 @@ class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
841
853
  else:
842
854
  return np.asarray(out)[0]
843
855
 
844
- @staticmethod
845
- def deep_format_measure(signed_measure):
846
- dirac_positions, dirac_signs = signed_measure
847
- dtype = dirac_positions.dtype
848
- new_shape = list(dirac_positions.shape)
849
- new_shape[1] += 1
850
- if isinstance(dirac_positions, np.ndarray):
851
- c = np.empty(new_shape, dtype=dtype)
852
- c[:, :-1] = dirac_positions
853
- c[:, -1] = dirac_signs
854
-
855
- else:
856
- import torch
857
-
858
- c = torch.empty(new_shape, dtype=dtype)
859
- c[:, :-1] = dirac_positions
860
- c[:, -1] = dirac_signs
861
- return c
862
-
863
856
  @staticmethod
864
857
  def _integrate_measure(sm, filtrations):
865
858
  from multipers.point_measure import integrate_measure
@@ -928,7 +921,7 @@ class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
928
921
  elif self.deep_format:
929
922
  num_degrees = self._num_degrees
930
923
  out = tuple(
931
- tuple(self.deep_format_measure(sm[axis][degree]) for sm in out)
924
+ tuple(sm2deep(sm[axis][degree]) for sm in out)
932
925
  for degree in range(num_degrees)
933
926
  for axis in self._axis_iterator
934
927
  )
@@ -943,11 +936,7 @@ class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
943
936
  ), f"Bad axis/degree count. Got {num_axis_degree} (Internal error)"
944
937
  num_parameters = out[0][0].shape[1]
945
938
  dtype = out[0][0].dtype
946
- if isinstance(out[0][0], np.ndarray):
947
- from numpy import zeros
948
- else:
949
- from torch import zeros
950
- unragged_tensor = zeros(
939
+ unragged_tensor = self._backend.zeros(
951
940
  (
952
941
  num_axis_degree,
953
942
  num_data,
@@ -1027,6 +1016,7 @@ class SignedMeasure2Convolution(BaseEstimator, TransformerMixin):
1027
1016
  self.plot = plot
1028
1017
  self.log_density = log_density
1029
1018
  self.kde_kwargs = kde_kwargs
1019
+ self._api = None
1030
1020
  return
1031
1021
 
1032
1022
  def fit(self, X, y=None):
@@ -1035,13 +1025,16 @@ class SignedMeasure2Convolution(BaseEstimator, TransformerMixin):
1035
1025
  return self
1036
1026
  if isinstance(X[0][0], tuple):
1037
1027
  self._is_input_sparse = True
1028
+
1029
+ self._api = api_from_tensor(X[0][0][0], verbose=self.progress)
1038
1030
  else:
1039
1031
  self._is_input_sparse = False
1032
+
1033
+ self._api = api_from_tensor(X, verbose=self.progress)
1040
1034
  # print(f"IMG output is set to {'sparse' if self.sparse else 'matrix'}")
1041
1035
  if not self._is_input_sparse:
1042
1036
  self._input_resolution = X[0][0].shape
1043
1037
  try:
1044
- float(self.bandwidth)
1045
1038
  b = float(self.bandwidth)
1046
1039
  self._bandwidths = [
1047
1040
  b if b > 0 else -b * s for s in self._input_resolution
@@ -1060,24 +1053,24 @@ class SignedMeasure2Convolution(BaseEstimator, TransformerMixin):
1060
1053
  # If not sparse : a grid has to be defined
1061
1054
  if self._refit:
1062
1055
  # print("Fitting a grid...", end="")
1063
- pts = np.concatenate(
1056
+ pts = self._api.cat(
1064
1057
  [sm[0] for signed_measures in X for sm in signed_measures]
1065
1058
  ).T
1066
- self.filtration_grid = reduce_grid(
1059
+ self.filtration_grid = compute_grid(
1067
1060
  pts,
1068
1061
  strategy=self.grid_strategy,
1069
1062
  resolution=self.resolution,
1070
1063
  )
1071
1064
  # print('Done.')
1072
1065
  if self.filtration_grid is not None:
1073
- self.diameter = np.linalg.norm(
1074
- [f.max() - f.min() for f in self.filtration_grid]
1066
+ self.diameter = self._api.norm(
1067
+ self._api.astensor([f[-1] - f[0] for f in self.filtration_grid])
1075
1068
  )
1076
1069
  if self.progress:
1077
1070
  print(f"Computed a diameter of {self.diameter}")
1078
1071
  return self
1079
1072
 
1080
- def _sm2smi(self, signed_measures: Iterable[np.ndarray]):
1073
+ def _sm2smi(self, signed_measures):
1081
1074
  # print(self._input_resolution, self.bandwidths, _bandwidths)
1082
1075
  from scipy.ndimage import gaussian_filter
1083
1076
 
@@ -1125,6 +1118,7 @@ class SignedMeasure2Convolution(BaseEstimator, TransformerMixin):
1125
1118
  def _plot_imgs(self, imgs: Iterable[np.ndarray], size=4):
1126
1119
  from multipers.plots import plot_surface
1127
1120
 
1121
+ imgs = self._api.asnumpy(imgs)
1128
1122
  num_degrees = imgs[0].shape[0]
1129
1123
  num_imgs = len(imgs)
1130
1124
  fig, axes = plt.subplots(
@@ -1137,7 +1131,10 @@ class SignedMeasure2Convolution(BaseEstimator, TransformerMixin):
1137
1131
  for i, img in enumerate(imgs):
1138
1132
  for j, img_of_degree in enumerate(img):
1139
1133
  plot_surface(
1140
- self.filtration_grid, img_of_degree, ax=axes[i, j], cmap="Spectral"
1134
+ [self._api.asnumpy(f) for f in self.filtration_grid],
1135
+ img_of_degree,
1136
+ ax=axes[i, j],
1137
+ cmap="Spectral",
1141
1138
  )
1142
1139
 
1143
1140
  def transform(self, X):
@@ -1153,6 +1150,7 @@ class SignedMeasure2Convolution(BaseEstimator, TransformerMixin):
1153
1150
  X, desc="Computing images", disable=not self.progress
1154
1151
  )
1155
1152
  )
1153
+ out = self._api.cat([x[None] for x in out])
1156
1154
  if self.plot and not self.flatten:
1157
1155
  if self.progress:
1158
1156
  print("Plotting convolutions...", end="")
@@ -1160,8 +1158,8 @@ class SignedMeasure2Convolution(BaseEstimator, TransformerMixin):
1160
1158
  if self.progress:
1161
1159
  print("Done !")
1162
1160
  if self.flatten and not self._is_input_sparse:
1163
- out = [x.flatten() for x in out]
1164
- return np.asarray(out)
1161
+ out = self._api.cat([x.ravel()[None] for x in out])
1162
+ return out
1165
1163
 
1166
1164
 
1167
1165
  class SignedMeasure2SlicedWassersteinDistance(BaseEstimator, TransformerMixin):
@@ -30,7 +30,7 @@ cdef extern from "multiparameter_module_approximation/approximation.h" namespace
30
30
  bool contains(const vector[T]&) nogil const
31
31
  Box[T] get_bounds() nogil const
32
32
  void rescale(const vector[T]&) nogil
33
-
33
+ bool operator==(const Summand[T]&) nogil
34
34
 
35
35
 
36
36
 
@@ -86,6 +86,7 @@ cdef extern from "multiparameter_module_approximation/approximation.h" namespace
86
86
  Summand[T]& at(unsigned int) nogil
87
87
  vector[Summand[T]].iterator begin()
88
88
  vector[Summand[T]].iterator end()
89
+ bool operator==(const Module[T]&) nogil
89
90
  void clean(const bool) nogil
90
91
  void fill(const T) nogil
91
92
  # vector[image_type] get_vectorization(const T,const T, unsigned int,unsigned int,const Box&)
@@ -97,6 +97,11 @@ cdef class PySummand_f64:
97
97
  return v[0].num_parameters()
98
98
  v = self.sum.get_death_list()
99
99
  return v[0].num_parameters()
100
+ def __eq__(self, PySummand_f64 other):
101
+ return self.sum == other.sum
102
+
103
+
104
+
100
105
 
101
106
  cdef inline get_summand_filtration_values_f64(Summand[double] summand):
102
107
  r"""
@@ -161,6 +166,10 @@ cdef class PyModule_f64:
161
166
 
162
167
  cdef set(self, Module[double] m):
163
168
  self.cmod = m
169
+
170
+ def __eq__(self, PyModule_f64 other):
171
+ return self.cmod == other.cmod
172
+
164
173
  def merge(self, PyModule_f64 other, int dim=-1):
165
174
  r"""
166
175
  Merges two modules into one
@@ -176,7 +185,9 @@ cdef class PyModule_f64:
176
185
  Copy module from a memory pointer. Unsafe.
177
186
  """
178
187
  self.cmod = move(dereference(<Module[double]*>(module_ptr)))
179
- def set_box(self, PyBox_f64 pybox):
188
+ def set_box(self, box):
189
+ assert len(box) == 2, "Box format is [low, hight]"
190
+ pybox = PyBox_f64(box[0], box[1])
180
191
  cdef Box[double] cbox = pybox.box
181
192
  with nogil:
182
193
  self.cmod.set_box(cbox)
@@ -959,8 +970,8 @@ cdef dump_summand_f64(Summand[double]& summand):
959
970
  cdef vector[One_critical_filtration[double]] births = summand.get_birth_list()
960
971
  cdef vector[One_critical_filtration[double]] deaths = summand.get_death_list()
961
972
  return (
962
- _vff21cview_f64(births, copy=True), ## copy as local variables
963
- _vff21cview_f64(deaths, copy=True),
973
+ np.array(_vff21cview_f64(births)),
974
+ np.array(_vff21cview_f64(deaths)),
964
975
  summand.get_dimension(),
965
976
  )
966
977
 
@@ -1047,6 +1058,11 @@ cdef class PySummand_f32:
1047
1058
  return v[0].num_parameters()
1048
1059
  v = self.sum.get_death_list()
1049
1060
  return v[0].num_parameters()
1061
+ def __eq__(self, PySummand_f32 other):
1062
+ return self.sum == other.sum
1063
+
1064
+
1065
+
1050
1066
 
1051
1067
  cdef inline get_summand_filtration_values_f32(Summand[float] summand):
1052
1068
  r"""
@@ -1111,6 +1127,10 @@ cdef class PyModule_f32:
1111
1127
 
1112
1128
  cdef set(self, Module[float] m):
1113
1129
  self.cmod = m
1130
+
1131
+ def __eq__(self, PyModule_f32 other):
1132
+ return self.cmod == other.cmod
1133
+
1114
1134
  def merge(self, PyModule_f32 other, int dim=-1):
1115
1135
  r"""
1116
1136
  Merges two modules into one
@@ -1126,7 +1146,9 @@ cdef class PyModule_f32:
1126
1146
  Copy module from a memory pointer. Unsafe.
1127
1147
  """
1128
1148
  self.cmod = move(dereference(<Module[float]*>(module_ptr)))
1129
- def set_box(self, PyBox_f32 pybox):
1149
+ def set_box(self, box):
1150
+ assert len(box) == 2, "Box format is [low, hight]"
1151
+ pybox = PyBox_f32(box[0], box[1])
1130
1152
  cdef Box[float] cbox = pybox.box
1131
1153
  with nogil:
1132
1154
  self.cmod.set_box(cbox)
@@ -1909,8 +1931,8 @@ cdef dump_summand_f32(Summand[float]& summand):
1909
1931
  cdef vector[One_critical_filtration[float]] births = summand.get_birth_list()
1910
1932
  cdef vector[One_critical_filtration[float]] deaths = summand.get_death_list()
1911
1933
  return (
1912
- _vff21cview_f32(births, copy=True), ## copy as local variables
1913
- _vff21cview_f32(deaths, copy=True),
1934
+ np.array(_vff21cview_f32(births)),
1935
+ np.array(_vff21cview_f32(deaths)),
1914
1936
  summand.get_dimension(),
1915
1937
  )
1916
1938
 
@@ -1997,6 +2019,11 @@ cdef class PySummand_i32:
1997
2019
  return v[0].num_parameters()
1998
2020
  v = self.sum.get_death_list()
1999
2021
  return v[0].num_parameters()
2022
+ def __eq__(self, PySummand_i32 other):
2023
+ return self.sum == other.sum
2024
+
2025
+
2026
+
2000
2027
 
2001
2028
  cdef inline get_summand_filtration_values_i32(Summand[int32_t] summand):
2002
2029
  r"""
@@ -2061,6 +2088,10 @@ cdef class PyModule_i32:
2061
2088
 
2062
2089
  cdef set(self, Module[int32_t] m):
2063
2090
  self.cmod = m
2091
+
2092
+ def __eq__(self, PyModule_i32 other):
2093
+ return self.cmod == other.cmod
2094
+
2064
2095
  def merge(self, PyModule_i32 other, int dim=-1):
2065
2096
  r"""
2066
2097
  Merges two modules into one
@@ -2076,7 +2107,9 @@ cdef class PyModule_i32:
2076
2107
  Copy module from a memory pointer. Unsafe.
2077
2108
  """
2078
2109
  self.cmod = move(dereference(<Module[int32_t]*>(module_ptr)))
2079
- def set_box(self, PyBox_i32 pybox):
2110
+ def set_box(self, box):
2111
+ assert len(box) == 2, "Box format is [low, hight]"
2112
+ pybox = PyBox_i32(box[0], box[1])
2080
2113
  cdef Box[int32_t] cbox = pybox.box
2081
2114
  with nogil:
2082
2115
  self.cmod.set_box(cbox)
@@ -2277,8 +2310,8 @@ cdef dump_summand_i32(Summand[int32_t]& summand):
2277
2310
  cdef vector[One_critical_filtration[int32_t]] births = summand.get_birth_list()
2278
2311
  cdef vector[One_critical_filtration[int32_t]] deaths = summand.get_death_list()
2279
2312
  return (
2280
- _vff21cview_i32(births, copy=True), ## copy as local variables
2281
- _vff21cview_i32(deaths, copy=True),
2313
+ np.array(_vff21cview_i32(births)),
2314
+ np.array(_vff21cview_i32(deaths)),
2282
2315
  summand.get_dimension(),
2283
2316
  )
2284
2317
 
@@ -2365,6 +2398,11 @@ cdef class PySummand_i64:
2365
2398
  return v[0].num_parameters()
2366
2399
  v = self.sum.get_death_list()
2367
2400
  return v[0].num_parameters()
2401
+ def __eq__(self, PySummand_i64 other):
2402
+ return self.sum == other.sum
2403
+
2404
+
2405
+
2368
2406
 
2369
2407
  cdef inline get_summand_filtration_values_i64(Summand[int64_t] summand):
2370
2408
  r"""
@@ -2429,6 +2467,10 @@ cdef class PyModule_i64:
2429
2467
 
2430
2468
  cdef set(self, Module[int64_t] m):
2431
2469
  self.cmod = m
2470
+
2471
+ def __eq__(self, PyModule_i64 other):
2472
+ return self.cmod == other.cmod
2473
+
2432
2474
  def merge(self, PyModule_i64 other, int dim=-1):
2433
2475
  r"""
2434
2476
  Merges two modules into one
@@ -2444,7 +2486,9 @@ cdef class PyModule_i64:
2444
2486
  Copy module from a memory pointer. Unsafe.
2445
2487
  """
2446
2488
  self.cmod = move(dereference(<Module[int64_t]*>(module_ptr)))
2447
- def set_box(self, PyBox_i64 pybox):
2489
+ def set_box(self, box):
2490
+ assert len(box) == 2, "Box format is [low, hight]"
2491
+ pybox = PyBox_i64(box[0], box[1])
2448
2492
  cdef Box[int64_t] cbox = pybox.box
2449
2493
  with nogil:
2450
2494
  self.cmod.set_box(cbox)
@@ -2645,8 +2689,8 @@ cdef dump_summand_i64(Summand[int64_t]& summand):
2645
2689
  cdef vector[One_critical_filtration[int64_t]] births = summand.get_birth_list()
2646
2690
  cdef vector[One_critical_filtration[int64_t]] deaths = summand.get_death_list()
2647
2691
  return (
2648
- _vff21cview_i64(births, copy=True), ## copy as local variables
2649
- _vff21cview_i64(deaths, copy=True),
2692
+ np.array(_vff21cview_i64(births)),
2693
+ np.array(_vff21cview_i64(deaths)),
2650
2694
  summand.get_dimension(),
2651
2695
  )
2652
2696