multipers 2.3.3b5__cp312-cp312-win_amd64.whl → 2.3.3b7__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic. Click here for more details.

@@ -9,12 +9,107 @@ from sklearn.base import BaseEstimator, TransformerMixin
9
9
  from tqdm import tqdm
10
10
 
11
11
  import multipers as mp
12
- from multipers.array_api import api_from_tensor
12
+ from multipers.array_api import api_from_tensor, api_from_tensors
13
13
  from multipers.filtrations.density import available_kernels, convolution_signed_measures
14
- from multipers.grids import compute_grid
14
+ from multipers.grids import compute_grid, todense
15
15
  from multipers.point_measure import rank_decomposition_by_rectangles, signed_betti
16
16
 
17
17
 
18
def batch_signed_measure_convolutions(
    signed_measures,  # array of shape (num_data, num_pts, D+1), or (num_pts, D+1)
    x,  # array of shape (num_x, D) or (num_data, num_x, D)
    bandwidth,  # either float or matrix if multivariate kernel
    kernel: available_kernels,
    api=None,
):
    """
    Evaluate kernel convolutions of a batch of signed measures at points ``x``.

    Input
    -----
    - signed_measures: unragged, of shape (num_data, num_pts, D+1), where the
      last coordinate holds the weights (0 for dummy/padding points). A 2D
      input of shape (num_pts, D+1) is treated as a batch of one measure.
    - x : the points at which to evaluate the convolutions, of shape (num_x, D)
      or (num_data, num_x, D).
    - bandwidth : the bandwidth, covariance matrix inverse, ... of the kernel.
    - kernel : "gaussian", "multivariate_gaussian", "exponential", or
      Callable (x_i, y_i, bandwidth) -> float.
    - api : array backend to use; inferred from ``signed_measures`` and ``x``
      when None.

    Output
    ------
    For each measure and each point of ``x``, the weighted sum over the
    measure's points of the kernel values. Presumably an array of shape
    (num_data, num_x) -- NOTE(review): depends on ``_kernel`` reducing the
    coordinate axis; confirm.
    """
    from multipers.filtrations.density import _kernel

    if api is None:
        api = api_from_tensors(signed_measures, x)
    if signed_measures.ndim == 2:
        # Promote a single measure to a batch of size 1.
        signed_measures = signed_measures[None, :, :]
    sms = signed_measures[..., :-1]  # Dirac positions
    weights = signed_measures[..., -1]  # signs / weights (0 for padding)
    # Insert broadcast axes so every measure point (axis -3) meets every x (axis -2).
    _sms = api.LazyTensor(api.ascontiguous(sms[..., None, :]))
    _x = api.ascontiguous(x[..., None, :, :])

    sms_kernel = _kernel(kernel)(_sms, _x, bandwidth)
    # Weight each point's kernel values, then reduce over the points axis (axis 1
    # after the batch promotion above).
    out = (sms_kernel * api.ascontiguous(weights[..., None, None])).sum(
        signed_measures.ndim - 2
    )
    assert out.shape[-1] == 1, "Pykeops bug fixed, TODO : refix this "
    out = out[..., 0]  ## pykeops bug + ensures its a tensor
    return out
58
+
59
+
60
def sm2deep(signed_measure, api=None):
    """
    Pack a signed measure ``(positions, signs)`` into a single array.

    The result has the rows of ``positions`` plus one extra trailing column:
    ``result[:, :-1]`` holds the Dirac positions and ``result[:, -1]`` holds
    the signs, converted with ``api.astensor``. The result's dtype is the
    dtype of ``positions``.

    Parameters
    ----------
    signed_measure : pair (positions, signs)
        ``positions`` is at least 2D (num_pts, D, ...); ``signs`` has one
        entry per row of ``positions``.
    api : array backend, optional
        Inferred from ``positions`` when None.
    """
    backend = api if api is not None else api_from_tensor(signed_measure[0])
    positions, signs = signed_measure
    # Same shape as positions, with one more column for the signs.
    packed_shape = [positions.shape[0], positions.shape[1] + 1, *positions.shape[2:]]
    packed = backend.empty(packed_shape, dtype=positions.dtype)
    packed[:, :-1] = positions
    packed[:, -1] = backend.astensor(signs)
    return packed
71
+
72
+
73
def deep_unrag(sms, api=None):
    """
    Stack a sequence of ragged signed measures into one zero-padded array.

    Each signed measure ``(positions, signs)`` is packed with ``sm2deep``
    (signs in the last column) and copied into a zero-initialized array of
    shape (len(sms), max_num_pts, D+1). Padding rows are all zeros, so their
    weight (last column) is 0.

    Parameters
    ----------
    sms : sequence of (positions, signs) pairs
        ``positions`` is 2D; D is read from ``sms[0][0].shape[1]``.
    api : array backend, optional
        Inferred from the first measure's positions when None.

    Returns
    -------
    Array of shape (len(sms), max_num_pts, D+1) with the dtype of the first
    measure's positions. For an empty ``sms``, an empty 1D array:
    ``api.tensor([])`` when ``api`` is given, else a NumPy array (there is no
    tensor to infer a backend from).
    """
    num_sm = len(sms)
    # Check emptiness BEFORE touching sms[0]: inferring the backend from the
    # first measure would otherwise raise IndexError on an empty input.
    if num_sm == 0:
        if api is None:
            return np.asarray([])
        return api.tensor([])
    if api is None:
        api = api_from_tensor(sms[0][0])
    first = sms[0][0]
    num_parameters = first.shape[1]
    dtype = first.dtype
    deep_sms = [sm2deep(sm, api=api) for sm in sms]
    # Builtin max keeps the shape entries plain Python ints for any backend.
    max_num_pts = max(sm[0].shape[0] for sm in sms)
    unragged_sms = api.zeros((num_sm, max_num_pts, num_parameters + 1), dtype=dtype)

    for data, sm in enumerate(deep_sms):
        num_pts, width = sm.shape
        unragged_sms[data, :num_pts, :width] = sm
    return unragged_sms
91
+
92
+
93
def sm_convolution(
    sms,
    grid,
    bandwidth,
    kernel: available_kernels = "gaussian",
    plot: bool = False,
    **plt_kwargs,
):
    """
    Convolve a list of signed measures with a kernel, evaluated on a grid.

    Parameters
    ----------
    sms : sequence of (positions, signs) pairs, one signed measure per datum.
    grid : per-axis coordinates; the evaluation points are ``todense(grid)``.
    bandwidth, kernel : forwarded to ``batch_signed_measure_convolutions``.
    plot : when True, draw the result with ``multipers.plots.plot_surfaces``;
        any extra keyword arguments are passed through to it.

    Returns
    -------
    The convolutions reshaped to (num_measures, len(grid[0]), len(grid[1]), ...),
    where num_measures is the leading size of the padded batch.
    """
    points = todense(grid)
    backend = api_from_tensors(sms[0][0], points)
    padded = deep_unrag(sms, api=backend)
    flat = batch_signed_measure_convolutions(
        padded, points, bandwidth, kernel, api=backend
    )
    grid_shape = [len(axis) for axis in grid]
    convs = flat.reshape(padded.shape[0], *grid_shape)
    if not plot:
        return convs
    from multipers.plots import plot_surfaces

    plot_surfaces((grid, convs), **plt_kwargs)
    return convs
111
+
112
+
18
113
  class FilteredComplex2SignedMeasure(BaseEstimator, TransformerMixin):
19
114
  """
20
115
  Input
@@ -547,27 +642,6 @@ def rescale_sparse_signed_measure(
547
642
  return out
548
643
 
549
644
 
550
- def sm2deep(signed_measure):
551
- dirac_positions, dirac_signs = signed_measure
552
- dtype = dirac_positions.dtype
553
- new_shape = list(dirac_positions.shape)
554
- new_shape[1] += 1
555
- if isinstance(dirac_positions, np.ndarray):
556
- c = np.empty(new_shape, dtype=dtype)
557
- c[:, :-1] = dirac_positions
558
- c[:, -1] = dirac_signs
559
-
560
- else:
561
- import torch
562
-
563
- c = torch.empty(new_shape, dtype=dtype)
564
- c[:, :-1] = dirac_positions
565
- if isinstance(dirac_signs, np.ndarray):
566
- dirac_signs = torch.from_numpy(dirac_signs)
567
- c[:, -1] = dirac_signs
568
- return c
569
-
570
-
571
645
  class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
572
646
  """
573
647
  Input
@@ -759,7 +833,9 @@ class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
759
833
  self._filtrations_bounds.append(filtration_bounds)
760
834
  self._normalization_factors.append(normalization_factors)
761
835
  self._filtrations_bounds = self._backend.astensor(self._filtrations_bounds)
762
- self._normalization_factors = self._backend.astensor(self._normalization_factors)
836
+ self._normalization_factors = self._backend.astensor(
837
+ self._normalization_factors
838
+ )
763
839
  # else:
764
840
  # (
765
841
  # self._filtrations_bounds,
@@ -784,9 +860,11 @@ class SignedMeasureFormatter(BaseEstimator, TransformerMixin):
784
860
  ]
785
861
  # axis, filtration_values
786
862
  filtration_values = [
787
- self._backend.astensor(compute_grid(
788
- f_ax.T, resolution=self.resolution, strategy=self.grid_strategy
789
- ))
863
+ self._backend.astensor(
864
+ compute_grid(
865
+ f_ax.T, resolution=self.resolution, strategy=self.grid_strategy
866
+ )
867
+ )
790
868
  for f_ax in filtration_values
791
869
  ]
792
870
  self._infered_grids = filtration_values
@@ -751,7 +751,7 @@ cdef class PyModule_f64:
751
751
  axs = [plt.gca()]
752
752
  for image, degree, i in zip(image_vector, degrees, range(num_degrees)):
753
753
  ax = axs[i]
754
- temp = multipers.plots.plot_surface(grid, image.T, ax=ax)
754
+ temp = multipers.plots.plot_surface(grid, image, ax=ax)
755
755
  plt.colorbar(temp, ax = ax)
756
756
  if degree < 0 :
757
757
  ax.set_title(rf"$H_{i}$ $2$-persistence image")
@@ -1712,7 +1712,7 @@ cdef class PyModule_f32:
1712
1712
  axs = [plt.gca()]
1713
1713
  for image, degree, i in zip(image_vector, degrees, range(num_degrees)):
1714
1714
  ax = axs[i]
1715
- temp = multipers.plots.plot_surface(grid, image.T, ax=ax)
1715
+ temp = multipers.plots.plot_surface(grid, image, ax=ax)
1716
1716
  plt.colorbar(temp, ax = ax)
1717
1717
  if degree < 0 :
1718
1718
  ax.set_title(rf"$H_{i}$ $2$-persistence image")
@@ -773,7 +773,7 @@ cdef class PyModule_{{SHORT}}:
773
773
  axs = [plt.gca()]
774
774
  for image, degree, i in zip(image_vector, degrees, range(num_degrees)):
775
775
  ax = axs[i]
776
- temp = multipers.plots.plot_surface(grid, image.T, ax=ax)
776
+ temp = multipers.plots.plot_surface(grid, image, ax=ax)
777
777
  plt.colorbar(temp, ax = ax)
778
778
  if degree < 0 :
779
779
  ax.set_title(rf"$H_{i}$ $2$-persistence image")
multipers/plots.py CHANGED
@@ -15,9 +15,9 @@ _custom_colors = [
15
15
  "#00b4d8",
16
16
  "#90e0ef",
17
17
  ]
18
- _cmap = ListedColormap(_custom_colors)
19
- _continuous_cmap = mcolors.LinearSegmentedColormap.from_list(
20
- "continuous_cmap", _cmap.colors, N=256
18
+ _cmap_ = ListedColormap(_custom_colors)
19
+ _cmap = mcolors.LinearSegmentedColormap.from_list(
20
+ "continuous_cmap", _cmap_.colors, N=256
21
21
  )
22
22
 
23
23
 
@@ -180,8 +180,9 @@ def plot_surface(
180
180
  fig=None,
181
181
  ax=None,
182
182
  cmap: Optional[str] = None,
183
- discrete_surface=False,
184
- has_negative_values=False,
183
+ discrete_surface: bool = False,
184
+ has_negative_values: bool = False,
185
+ contour: bool = True,
185
186
  **plt_args,
186
187
  ):
187
188
  import matplotlib
@@ -213,7 +214,12 @@ def plot_surface(
213
214
  )
214
215
  cbar.set_ticks(ticks=bounds, labels=bounds)
215
216
  return im
216
- im = ax.contourf(grid[0], grid[1], hf.T, cmap=cmap, **plt_args)
217
+
218
+ if contour:
219
+ levels = plt_args.pop("levels", 50)
220
+ im = ax.contourf(grid[0], grid[1], hf.T, cmap=cmap, levels=levels, **plt_args)
221
+ else:
222
+ im = ax.pcolormesh(grid[0], grid[1], hf.T, cmap=cmap, **plt_args)
217
223
  return im
218
224
 
219
225
 
Binary file
@@ -883,6 +883,8 @@ cdef class SimplexTreeMulti_KFi32:
883
883
  bool coordinate_values=True,
884
884
  bool force=False,
885
885
  str strategy:_available_strategies = "exact",
886
+ resolution:Optional[int|list[int]] = None,
887
+ bool coordinates = False,
886
888
  grid_strategy=None,
887
889
  bool inplace=False,
888
890
  **filtration_grid_kwargs
@@ -910,7 +912,7 @@ cdef class SimplexTreeMulti_KFi32:
910
912
 
911
913
  #TODO : multi-critical
912
914
  if filtration_grid is None:
913
- filtration_grid = self.get_filtration_grid(grid_strategy=strategy, **filtration_grid_kwargs)
915
+ filtration_grid = self.get_filtration_grid(grid_strategy=strategy, resolution=resolution, **filtration_grid_kwargs)
914
916
  else:
915
917
  filtration_grid = sanitize_grid(filtration_grid)
916
918
  if len(filtration_grid) != self.num_parameters:
@@ -2322,6 +2324,8 @@ cdef class SimplexTreeMulti_Fi32:
2322
2324
  bool coordinate_values=True,
2323
2325
  bool force=False,
2324
2326
  str strategy:_available_strategies = "exact",
2327
+ resolution:Optional[int|list[int]] = None,
2328
+ bool coordinates = False,
2325
2329
  grid_strategy=None,
2326
2330
  bool inplace=False,
2327
2331
  **filtration_grid_kwargs
@@ -2349,7 +2353,7 @@ cdef class SimplexTreeMulti_Fi32:
2349
2353
 
2350
2354
  #TODO : multi-critical
2351
2355
  if filtration_grid is None:
2352
- filtration_grid = self.get_filtration_grid(grid_strategy=strategy, **filtration_grid_kwargs)
2356
+ filtration_grid = self.get_filtration_grid(grid_strategy=strategy, resolution=resolution, **filtration_grid_kwargs)
2353
2357
  else:
2354
2358
  filtration_grid = sanitize_grid(filtration_grid)
2355
2359
  if len(filtration_grid) != self.num_parameters:
@@ -3471,6 +3475,8 @@ cdef class SimplexTreeMulti_KFi64:
3471
3475
  bool coordinate_values=True,
3472
3476
  bool force=False,
3473
3477
  str strategy:_available_strategies = "exact",
3478
+ resolution:Optional[int|list[int]] = None,
3479
+ bool coordinates = False,
3474
3480
  grid_strategy=None,
3475
3481
  bool inplace=False,
3476
3482
  **filtration_grid_kwargs
@@ -3498,7 +3504,7 @@ cdef class SimplexTreeMulti_KFi64:
3498
3504
 
3499
3505
  #TODO : multi-critical
3500
3506
  if filtration_grid is None:
3501
- filtration_grid = self.get_filtration_grid(grid_strategy=strategy, **filtration_grid_kwargs)
3507
+ filtration_grid = self.get_filtration_grid(grid_strategy=strategy, resolution=resolution, **filtration_grid_kwargs)
3502
3508
  else:
3503
3509
  filtration_grid = sanitize_grid(filtration_grid)
3504
3510
  if len(filtration_grid) != self.num_parameters:
@@ -4910,6 +4916,8 @@ cdef class SimplexTreeMulti_Fi64:
4910
4916
  bool coordinate_values=True,
4911
4917
  bool force=False,
4912
4918
  str strategy:_available_strategies = "exact",
4919
+ resolution:Optional[int|list[int]] = None,
4920
+ bool coordinates = False,
4913
4921
  grid_strategy=None,
4914
4922
  bool inplace=False,
4915
4923
  **filtration_grid_kwargs
@@ -4937,7 +4945,7 @@ cdef class SimplexTreeMulti_Fi64:
4937
4945
 
4938
4946
  #TODO : multi-critical
4939
4947
  if filtration_grid is None:
4940
- filtration_grid = self.get_filtration_grid(grid_strategy=strategy, **filtration_grid_kwargs)
4948
+ filtration_grid = self.get_filtration_grid(grid_strategy=strategy, resolution=resolution, **filtration_grid_kwargs)
4941
4949
  else:
4942
4950
  filtration_grid = sanitize_grid(filtration_grid)
4943
4951
  if len(filtration_grid) != self.num_parameters:
@@ -6059,6 +6067,8 @@ cdef class SimplexTreeMulti_KFf32:
6059
6067
  bool coordinate_values=True,
6060
6068
  bool force=False,
6061
6069
  str strategy:_available_strategies = "exact",
6070
+ resolution:Optional[int|list[int]] = None,
6071
+ bool coordinates = False,
6062
6072
  grid_strategy=None,
6063
6073
  bool inplace=False,
6064
6074
  **filtration_grid_kwargs
@@ -6086,7 +6096,7 @@ cdef class SimplexTreeMulti_KFf32:
6086
6096
 
6087
6097
  #TODO : multi-critical
6088
6098
  if filtration_grid is None:
6089
- filtration_grid = self.get_filtration_grid(grid_strategy=strategy, **filtration_grid_kwargs)
6099
+ filtration_grid = self.get_filtration_grid(grid_strategy=strategy, resolution=resolution, **filtration_grid_kwargs)
6090
6100
  else:
6091
6101
  filtration_grid = sanitize_grid(filtration_grid)
6092
6102
  if len(filtration_grid) != self.num_parameters:
@@ -7498,6 +7508,8 @@ cdef class SimplexTreeMulti_Ff32:
7498
7508
  bool coordinate_values=True,
7499
7509
  bool force=False,
7500
7510
  str strategy:_available_strategies = "exact",
7511
+ resolution:Optional[int|list[int]] = None,
7512
+ bool coordinates = False,
7501
7513
  grid_strategy=None,
7502
7514
  bool inplace=False,
7503
7515
  **filtration_grid_kwargs
@@ -7525,7 +7537,7 @@ cdef class SimplexTreeMulti_Ff32:
7525
7537
 
7526
7538
  #TODO : multi-critical
7527
7539
  if filtration_grid is None:
7528
- filtration_grid = self.get_filtration_grid(grid_strategy=strategy, **filtration_grid_kwargs)
7540
+ filtration_grid = self.get_filtration_grid(grid_strategy=strategy, resolution=resolution, **filtration_grid_kwargs)
7529
7541
  else:
7530
7542
  filtration_grid = sanitize_grid(filtration_grid)
7531
7543
  if len(filtration_grid) != self.num_parameters:
@@ -8647,6 +8659,8 @@ cdef class SimplexTreeMulti_KFf64:
8647
8659
  bool coordinate_values=True,
8648
8660
  bool force=False,
8649
8661
  str strategy:_available_strategies = "exact",
8662
+ resolution:Optional[int|list[int]] = None,
8663
+ bool coordinates = False,
8650
8664
  grid_strategy=None,
8651
8665
  bool inplace=False,
8652
8666
  **filtration_grid_kwargs
@@ -8674,7 +8688,7 @@ cdef class SimplexTreeMulti_KFf64:
8674
8688
 
8675
8689
  #TODO : multi-critical
8676
8690
  if filtration_grid is None:
8677
- filtration_grid = self.get_filtration_grid(grid_strategy=strategy, **filtration_grid_kwargs)
8691
+ filtration_grid = self.get_filtration_grid(grid_strategy=strategy, resolution=resolution, **filtration_grid_kwargs)
8678
8692
  else:
8679
8693
  filtration_grid = sanitize_grid(filtration_grid)
8680
8694
  if len(filtration_grid) != self.num_parameters:
@@ -10086,6 +10100,8 @@ cdef class SimplexTreeMulti_Ff64:
10086
10100
  bool coordinate_values=True,
10087
10101
  bool force=False,
10088
10102
  str strategy:_available_strategies = "exact",
10103
+ resolution:Optional[int|list[int]] = None,
10104
+ bool coordinates = False,
10089
10105
  grid_strategy=None,
10090
10106
  bool inplace=False,
10091
10107
  **filtration_grid_kwargs
@@ -10113,7 +10129,7 @@ cdef class SimplexTreeMulti_Ff64:
10113
10129
 
10114
10130
  #TODO : multi-critical
10115
10131
  if filtration_grid is None:
10116
- filtration_grid = self.get_filtration_grid(grid_strategy=strategy, **filtration_grid_kwargs)
10132
+ filtration_grid = self.get_filtration_grid(grid_strategy=strategy, resolution=resolution, **filtration_grid_kwargs)
10117
10133
  else:
10118
10134
  filtration_grid = sanitize_grid(filtration_grid)
10119
10135
  if len(filtration_grid) != self.num_parameters:
@@ -1330,6 +1330,8 @@ cdef class SimplexTreeMulti_{{FSHORT}}:
1330
1330
  bool coordinate_values=True,
1331
1331
  bool force=False,
1332
1332
  str strategy:_available_strategies = "exact",
1333
+ resolution:Optional[int|list[int]] = None,
1334
+ bool coordinates = False,
1333
1335
  grid_strategy=None,
1334
1336
  bool inplace=False,
1335
1337
  **filtration_grid_kwargs
@@ -1357,7 +1359,7 @@ cdef class SimplexTreeMulti_{{FSHORT}}:
1357
1359
 
1358
1360
  #TODO : multi-critical
1359
1361
  if filtration_grid is None:
1360
- filtration_grid = self.get_filtration_grid(grid_strategy=strategy, **filtration_grid_kwargs)
1362
+ filtration_grid = self.get_filtration_grid(grid_strategy=strategy, resolution=resolution, **filtration_grid_kwargs)
1361
1363
  else:
1362
1364
  filtration_grid = sanitize_grid(filtration_grid)
1363
1365
  if len(filtration_grid) != self.num_parameters:
Binary file