multipers 2.3.3b6__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic. Click here for more details.

Files changed (183) hide show
  1. multipers/.dylibs/libc++.1.0.dylib +0 -0
  2. multipers/.dylibs/libtbb.12.16.dylib +0 -0
  3. multipers/__init__.py +33 -0
  4. multipers/_signed_measure_meta.py +453 -0
  5. multipers/_slicer_meta.py +211 -0
  6. multipers/array_api/__init__.py +45 -0
  7. multipers/array_api/numpy.py +41 -0
  8. multipers/array_api/torch.py +58 -0
  9. multipers/data/MOL2.py +458 -0
  10. multipers/data/UCR.py +18 -0
  11. multipers/data/__init__.py +1 -0
  12. multipers/data/graphs.py +466 -0
  13. multipers/data/immuno_regions.py +27 -0
  14. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  15. multipers/data/pytorch2simplextree.py +91 -0
  16. multipers/data/shape3d.py +101 -0
  17. multipers/data/synthetic.py +113 -0
  18. multipers/distances.py +202 -0
  19. multipers/filtration_conversions.pxd +229 -0
  20. multipers/filtration_conversions.pxd.tp +84 -0
  21. multipers/filtrations/__init__.py +18 -0
  22. multipers/filtrations/density.py +574 -0
  23. multipers/filtrations/filtrations.py +361 -0
  24. multipers/filtrations.pxd +224 -0
  25. multipers/function_rips.cpython-310-darwin.so +0 -0
  26. multipers/function_rips.pyx +105 -0
  27. multipers/grids.cpython-310-darwin.so +0 -0
  28. multipers/grids.pyx +433 -0
  29. multipers/gudhi/Persistence_slices_interface.h +132 -0
  30. multipers/gudhi/Simplex_tree_interface.h +239 -0
  31. multipers/gudhi/Simplex_tree_multi_interface.h +551 -0
  32. multipers/gudhi/cubical_to_boundary.h +59 -0
  33. multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -0
  34. multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -0
  35. multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -0
  36. multipers/gudhi/gudhi/Debug_utils.h +45 -0
  37. multipers/gudhi/gudhi/Fields/Multi_field.h +484 -0
  38. multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -0
  39. multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -0
  40. multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -0
  41. multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -0
  42. multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -0
  43. multipers/gudhi/gudhi/Fields/Z2_field.h +355 -0
  44. multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -0
  45. multipers/gudhi/gudhi/Fields/Zp_field.h +420 -0
  46. multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -0
  47. multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -0
  48. multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -0
  49. multipers/gudhi/gudhi/Matrix.h +2107 -0
  50. multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -0
  51. multipers/gudhi/gudhi/Multi_persistence/Box.h +174 -0
  52. multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -0
  53. multipers/gudhi/gudhi/Off_reader.h +173 -0
  54. multipers/gudhi/gudhi/One_critical_filtration.h +1441 -0
  55. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -0
  56. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -0
  57. multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -0
  58. multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -0
  59. multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -0
  60. multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -0
  61. multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -0
  62. multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -0
  63. multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -0
  64. multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -0
  65. multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -0
  66. multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -0
  67. multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -0
  68. multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -0
  69. multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -0
  70. multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -0
  71. multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -0
  72. multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -0
  73. multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -0
  74. multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -0
  75. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -0
  76. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -0
  77. multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -0
  78. multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -0
  79. multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -0
  80. multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -0
  81. multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -0
  82. multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -0
  83. multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -0
  84. multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -0
  85. multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -0
  86. multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -0
  87. multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -0
  88. multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -0
  89. multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -0
  90. multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -0
  91. multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -0
  92. multipers/gudhi/gudhi/Persistent_cohomology.h +745 -0
  93. multipers/gudhi/gudhi/Points_off_io.h +171 -0
  94. multipers/gudhi/gudhi/Simple_object_pool.h +69 -0
  95. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -0
  96. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -0
  97. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -0
  98. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -0
  99. multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -0
  100. multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -0
  101. multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -0
  102. multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -0
  103. multipers/gudhi/gudhi/Simplex_tree.h +2794 -0
  104. multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -0
  105. multipers/gudhi/gudhi/distance_functions.h +62 -0
  106. multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -0
  107. multipers/gudhi/gudhi/persistence_interval.h +253 -0
  108. multipers/gudhi/gudhi/persistence_matrix_options.h +170 -0
  109. multipers/gudhi/gudhi/reader_utils.h +367 -0
  110. multipers/gudhi/mma_interface_coh.h +256 -0
  111. multipers/gudhi/mma_interface_h0.h +223 -0
  112. multipers/gudhi/mma_interface_matrix.h +293 -0
  113. multipers/gudhi/naive_merge_tree.h +536 -0
  114. multipers/gudhi/scc_io.h +310 -0
  115. multipers/gudhi/truc.h +1403 -0
  116. multipers/io.cpython-310-darwin.so +0 -0
  117. multipers/io.pyx +644 -0
  118. multipers/ml/__init__.py +0 -0
  119. multipers/ml/accuracies.py +90 -0
  120. multipers/ml/invariants_with_persistable.py +79 -0
  121. multipers/ml/kernels.py +176 -0
  122. multipers/ml/mma.py +713 -0
  123. multipers/ml/one.py +472 -0
  124. multipers/ml/point_clouds.py +352 -0
  125. multipers/ml/signed_measures.py +1589 -0
  126. multipers/ml/sliced_wasserstein.py +461 -0
  127. multipers/ml/tools.py +113 -0
  128. multipers/mma_structures.cpython-310-darwin.so +0 -0
  129. multipers/mma_structures.pxd +128 -0
  130. multipers/mma_structures.pyx +2786 -0
  131. multipers/mma_structures.pyx.tp +1094 -0
  132. multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -0
  133. multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -0
  134. multipers/multi_parameter_rank_invariant/function_rips.h +322 -0
  135. multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -0
  136. multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -0
  137. multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -0
  138. multipers/multiparameter_edge_collapse.py +41 -0
  139. multipers/multiparameter_module_approximation/approximation.h +2330 -0
  140. multipers/multiparameter_module_approximation/combinatory.h +129 -0
  141. multipers/multiparameter_module_approximation/debug.h +107 -0
  142. multipers/multiparameter_module_approximation/euler_curves.h +0 -0
  143. multipers/multiparameter_module_approximation/format_python-cpp.h +286 -0
  144. multipers/multiparameter_module_approximation/heap_column.h +238 -0
  145. multipers/multiparameter_module_approximation/images.h +79 -0
  146. multipers/multiparameter_module_approximation/list_column.h +174 -0
  147. multipers/multiparameter_module_approximation/list_column_2.h +232 -0
  148. multipers/multiparameter_module_approximation/ru_matrix.h +347 -0
  149. multipers/multiparameter_module_approximation/set_column.h +135 -0
  150. multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -0
  151. multipers/multiparameter_module_approximation/unordered_set_column.h +166 -0
  152. multipers/multiparameter_module_approximation/utilities.h +403 -0
  153. multipers/multiparameter_module_approximation/vector_column.h +223 -0
  154. multipers/multiparameter_module_approximation/vector_matrix.h +331 -0
  155. multipers/multiparameter_module_approximation/vineyards.h +464 -0
  156. multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -0
  157. multipers/multiparameter_module_approximation.cpython-310-darwin.so +0 -0
  158. multipers/multiparameter_module_approximation.pyx +235 -0
  159. multipers/pickle.py +90 -0
  160. multipers/plots.py +456 -0
  161. multipers/point_measure.cpython-310-darwin.so +0 -0
  162. multipers/point_measure.pyx +395 -0
  163. multipers/simplex_tree_multi.cpython-310-darwin.so +0 -0
  164. multipers/simplex_tree_multi.pxd +134 -0
  165. multipers/simplex_tree_multi.pyx +10840 -0
  166. multipers/simplex_tree_multi.pyx.tp +2009 -0
  167. multipers/slicer.cpython-310-darwin.so +0 -0
  168. multipers/slicer.pxd +3034 -0
  169. multipers/slicer.pxd.tp +234 -0
  170. multipers/slicer.pyx +20481 -0
  171. multipers/slicer.pyx.tp +1088 -0
  172. multipers/tensor/tensor.h +672 -0
  173. multipers/tensor.pxd +13 -0
  174. multipers/test.pyx +44 -0
  175. multipers/tests/__init__.py +62 -0
  176. multipers/torch/__init__.py +1 -0
  177. multipers/torch/diff_grids.py +240 -0
  178. multipers/torch/rips_density.py +310 -0
  179. multipers-2.3.3b6.dist-info/METADATA +128 -0
  180. multipers-2.3.3b6.dist-info/RECORD +183 -0
  181. multipers-2.3.3b6.dist-info/WHEEL +6 -0
  182. multipers-2.3.3b6.dist-info/licenses/LICENSE +21 -0
  183. multipers-2.3.3b6.dist-info/top_level.txt +1 -0
multipers/test.pyx ADDED
@@ -0,0 +1,44 @@
1
+ # cimport multipers.tensor as mt
2
+ from libc.stdint cimport intptr_t, uint16_t
3
+ from libcpp.vector cimport vector
4
+ from libcpp cimport bool, int, float
5
+ from libcpp.utility cimport pair
6
+ from typing import Optional,Iterable,Callable
7
+
8
+
9
+ ctypedef float value_type
10
+ # ctypedef uint16_t index_type
11
+
12
+ import numpy as np
13
+ # cimport numpy as cnp
14
+ # cnp.import_array()
15
+
16
+ # cdef extern from "multi_parameter_rank_invariant/rank_invariant.h" namespace "Gudhi::rank_invariant":
17
+ # void get_hilbert_surface(const intptr_t, mt.static_tensor_view, const vector[index_type], const vector[index_type], index_type, index_type, const vector[index_type], bool, bool) except + nogil
18
+
19
+
20
+ from multipers.simplex_tree_multi import SimplexTreeMulti
21
+
22
+
23
def numpy_to_tensor(array:np.ndarray):
    """
    Wrap a NumPy array into an ``mt.static_tensor_view`` and return the
    view's ``get_resolution()``.

    NOTE(review): in this module as shown, ``dtype``, ``index_type`` and ``mt``
    are never defined. The ``cimport multipers.tensor as mt`` and
    ``ctypedef uint16_t index_type`` lines above are commented out, and no
    ``dtype`` type is declared. So this function will not compile as-is;
    presumably it is scratch/experimental code -- confirm before relying on it.
    """
    # Copy the Python shape tuple into a C++ vector of indices.
    cdef vector[index_type] shape = array.shape
    # ``[::1]`` declares a one-dimensional, C-contiguous memoryview, so the raw
    # pointer taken below addresses the data contiguously.
    # NOTE(review): a 1-D memoryview presumably rejects multi-dimensional
    # arrays unless they are flattened first -- confirm the intended input.
    cdef dtype[::1] contigus_array_view = np.ascontiguousarray(array)
    # Indexing element 0 means an empty array cannot be wrapped here.
    cdef dtype* dtype_ptr = &contigus_array_view[0]
    cdef mt.static_tensor_view tensor
    # Build the view without holding the GIL.
    with nogil:
        tensor = mt.static_tensor_view(dtype_ptr, shape)
    return tensor.get_resolution()
31
+
32
+ # def hilbert2d(simplextree:SimplexTreeMulti, grid_shape:np.ndarray|list, vector[index_type] degrees, bool mobius_inversion):
33
+ # # assert simplextree.num_parameters == 2
34
+ # cdef intptr_t ptr = simplextree.thisptr
35
+ # cdef vector[index_type] c_grid_shape = grid_shape
36
+ # cdef dtype[::1] container = np.zeros(grid_shape, dtype=np.float32).flatten()
37
+ # cdef dtype* container_ptr = &container[0]
38
+ # cdef mt.static_tensor_view c_container = mt.static_tensor_view(container_ptr, c_grid_shape)
39
+ # cdef index_type i = 0
40
+ # cdef index_type j = 1
41
+ # cdef vector[index_type] fixed_values = [[],[]]
42
+ # # get_hilbert_surface(ptr, c_container, c_grid_shape, degrees,i,j,fixed_values, False, False)
43
+ # return container.reshape(grid_shape)
44
+
@@ -0,0 +1,62 @@
1
+ import numpy as np
2
+
3
+
4
+ def assert_st_simplices(st, dump):
5
+ """
6
+ Checks that the simplextree has the same
7
+ filtration as the dump.
8
+ """
9
+
10
+ assert np.all(
11
+ [
12
+ np.isclose(a, b).all()
13
+ for x, y in zip(st.get_simplices(), dump, strict=True)
14
+ for a, b in zip(x, y, strict=True)
15
+ ]
16
+ )
17
+
18
+
19
def sort_sm(sms):
    """
    Canonically order the atoms of each signed measure in ``sms``.

    Each element of ``sms`` is indexable as ``(points, weights)``, with
    ``points`` of shape ``(num_points, num_coordinates)`` and ``weights`` of
    length ``num_points``. Points are sorted lexicographically: the first
    coordinate is the primary key and later coordinates break ties. Weights
    are permuted the same way, so two measures with the same atoms compare
    equal element-wise whatever order they were produced in.

    Every measure is sorted independently, so measures may have different
    numbers of points. (Previously a single 2-D argsort over all measures was
    used, which mis-indexed each measure and failed on ragged inputs.)

    Returns a tuple of ``(sorted_points, sorted_weights)`` pairs.
    """
    sorted_sms = []
    for sm in sms:
        pts, weights = sm[0], sm[1]
        keys = np.asarray(pts)
        if keys.ndim == 1:
            # One coordinate per point: treat it as a single sort column.
            keys = keys[:, None]
        if keys.shape[0] == 0 or keys.shape[1] == 0:
            order = np.arange(keys.shape[0])
        else:
            # np.lexsort uses the *last* key as the primary one, hence the reversal.
            order = np.lexsort(keys.T[::-1])
        sorted_sms.append((pts[order], weights[order]))
    return tuple(sorted_sms)
22
+
23
+
24
+ def assert_sm_pair(sm1, sm2, exact=True, max_error=1e-3, reg=0.1, threshold=None):
25
+ if not exact:
26
+ from multipers.distances import sm_distance
27
+ if threshold is not None:
28
+ _inf_value_fix = threshold
29
+ sm1[0][sm1[0] >threshold] = _inf_value_fix
30
+ sm2[0][sm2[0] >threshold] = _inf_value_fix
31
+
32
+ d = sm_distance(sm1, sm2, reg=reg)
33
+ assert d < max_error, f"Failed comparison:\n{sm1}\n{sm2},\n with distance {d}."
34
+ return
35
+ assert np.all(
36
+ [
37
+ np.isclose(a, b).all()
38
+ for x, y in zip(sm1, sm2, strict=True)
39
+ for a, b in zip(x, y, strict=True)
40
+ ]
41
+ ), f"Failed comparison:\n-----------------\n{sm1}\n-----------------\n{sm2}"
42
+
43
+
44
def assert_sm(*args, exact=True, max_error=1e-5, reg=0.1, threshold=None):
    """
    Assert that all signed measures in ``args`` agree.

    Each consecutive pair is checked with ``assert_sm_pair`` using the same
    keyword options. On failure, the raised ``AssertionError`` names the
    offending pair of positions, replacing the former debug ``print(i)``.
    With fewer than two measures there is nothing to compare.
    """
    sms = tuple(args)
    for i, (left, right) in enumerate(zip(sms, sms[1:])):
        try:
            assert_sm_pair(
                left,
                right,
                exact=exact,
                max_error=max_error,
                reg=reg,
                threshold=threshold,
            )
        except AssertionError as err:
            raise AssertionError(
                f"Signed measures #{i} and #{i + 1} differ.\n{err}"
            ) from err
49
+
50
+
51
def random_st(npts=100, num_parameters=2, max_dim=2):
    """
    Build a random multi-parameter simplex tree for tests.

    Points come from ``multipers.data.noisy_annulus``, called with
    ``npts // 2`` and ``npts - npts // 2`` points and ``dim=max_dim``. A gudhi
    alpha complex is built on them and converted to a ``SimplexTreeMulti``
    with ``num_parameters`` parameters. Every parameter is then filled with
    ``fill_lowerstar`` from ``npts`` i.i.d. uniform values (presumably one per
    vertex).
    """
    import gudhi as gd

    import multipers as mp
    from multipers.data import noisy_annulus

    first_half = npts // 2
    points = noisy_annulus(first_half, npts - first_half, dim=max_dim)
    alpha_tree = gd.AlphaComplex(points=points).create_simplex_tree()
    multi_tree = mp.SimplexTreeMulti(alpha_tree, num_parameters=num_parameters)
    for parameter in range(num_parameters):
        multi_tree.fill_lowerstar(np.random.uniform(size=npts), parameter)
    return multi_tree
@@ -0,0 +1 @@
1
+ from .rips_density import *
@@ -0,0 +1,240 @@
1
+ from typing import Literal
2
+
3
+ import numpy as np
4
+ import torch
5
+ from pykeops.torch import LazyTensor
6
+
7
+
8
+ def get_grid(strategy: Literal["exact", "regular_closest", "regular_left", "quantile"]):
9
+ """
10
+ Given a strategy, returns a function of signature
11
+ `(num_pts, num_parameter), int --> Iterable[1d array]`
12
+ that generates a torch-differentiable grid from a set of points,
13
+ and a resolution.
14
+ """
15
+ match strategy:
16
+ case "exact":
17
+ return _exact_grid
18
+ case "regular":
19
+ return _regular_grid
20
+ case "regular_closest":
21
+ return _regular_closest_grid
22
+ case "regular_left":
23
+ return _regular_left_grid
24
+ case "quantile":
25
+ return _quantile_grid
26
+ case _:
27
+ raise ValueError(
28
+ f"""
29
+ Unimplemented strategy {strategy}.
30
+ Available ones : exact, regular_closest, regular_left, quantile.
31
+ """
32
+ )
33
+
34
+
35
def todense(grid: list[torch.Tensor]):
    """
    Materialize a product grid as an explicit list of points.

    Given one 1d tensor of coordinates per parameter, returns every combination
    of coordinates (their Cartesian product). With two or more axes there is
    one point per row; with a single axis ``torch.cartesian_prod`` returns that
    axis unchanged.
    """
    axes = [*grid]
    return torch.cartesian_prod(*axes)
37
+
38
+
39
def _exact_grid(filtration_values, r=None):
    """
    Exact grid: the differentiable distinct values of each parameter.

    ``r`` exists only for signature compatibility with the other grid
    builders and must be None.
    """
    assert r is None
    return tuple(map(_unique_any, filtration_values))
43
+
44
+
45
def _regular_closest_grid(filtration_values, res):
    """Apply ``_regular_closest`` per parameter, with one resolution per parameter."""
    return tuple(map(_regular_closest, filtration_values, res))
48
+
49
def _regular_grid(filtration_values, res):
    """Apply ``_regular`` per parameter, with one resolution per parameter."""
    axes = []
    for values, steps in zip(filtration_values, res):
        axes.append(_regular(values, steps))
    return tuple(axes)
52
+
53
+ def _regular(x, r:int):
54
+ if x.ndim != 1:
55
+ raise ValueError(f"Got ndim!=1. {x=}")
56
+ return torch.linspace(start=torch.min(x), end=torch.max(x), steps=r, dtype=x.dtype)
57
+
58
def _regular_left_grid(filtration_values, res):
    """Apply ``_regular_left`` per parameter, with one resolution per parameter."""
    pairs = zip(filtration_values, res)
    return tuple(_regular_left(values, steps) for values, steps in pairs)
61
+
62
+
63
def _quantile_grid(filtration_values, res):
    """Apply ``_quantile`` per parameter, with one resolution per parameter."""
    return tuple(map(_quantile, filtration_values, res))
66
def _quantile(x, r):
    """
    Return the distinct values of ``x`` at ``r`` evenly spaced quantile levels.

    The levels run from 0 to 1 inclusive. The results are deduplicated by
    ``_unique_any``, so fewer than ``r`` values may be returned.
    ``torch.quantile`` supports autograd, so the result stays differentiable
    with respect to ``x``.

    Raises
    ------
    ValueError
        If ``x`` is not 1-dimensional.
    """
    if x.ndim != 1:
        raise ValueError(f"Got ndim!=1. {x=}")
    # The levels must live on x's device; a CPU ``q`` fails for GPU inputs.
    qs = torch.linspace(0, 1, r, dtype=x.dtype, device=x.device)
    return _unique_any(torch.quantile(x, q=qs))
71
+
72
+
73
+
74
+
75
+ def _unique_any(x, assume_sorted=False, remove_inf: bool = True):
76
+ if x.ndim != 1:
77
+ raise ValueError(f"Got ndim!=1. {x=}")
78
+ if not assume_sorted:
79
+ x, _ = x.sort()
80
+ if remove_inf and x[-1] == torch.inf:
81
+ x = x[:-1]
82
+ with torch.no_grad():
83
+ y = x.unique()
84
+ idx = torch.searchsorted(x, y)
85
+ x = torch.cat([x, torch.tensor([torch.inf])])
86
+ return x[idx]
87
+
88
+
89
def _regular_left(f, r: int, unique: bool = True):
    """
    Snap a regular grid onto the values of ``f``.

    ``r`` evenly spaced points are placed between the first and last entries
    of ``_unique_any(f)``. Each point is replaced by the first of those values
    that is greater than or equal to it (``torch.searchsorted``, left side).
    Values are taken from ``f`` by indexing, so the result is differentiable
    with respect to ``f``.

    Parameters
    ----------
    f :
        1d tensor of values.
    r :
        Number of regular grid points.
    unique :
        Merge repeated snapped values (default True).

    Raises
    ------
    ValueError
        If ``f`` is not 1-dimensional.
    """
    if f.ndim != 1:
        raise ValueError(f"Got ndim!=1. {f=}")
    f = _unique_any(f)
    if f.numel() == 0:
        # Nothing to snap onto (empty input, or only infinite values).
        return f
    with torch.no_grad():
        # Use f's dtype (as _regular_closest does) so searchsorted compares like
        # with like and float64 values are not rounded through float32.
        f_regular = torch.linspace(
            f[0].item(), f[-1].item(), r, dtype=f.dtype, device=f.device
        )
        idx = torch.searchsorted(f, f_regular)
    # +inf sentinel for an index past the end, created on f's device.
    f = torch.cat([f, torch.full((1,), torch.inf, dtype=f.dtype, device=f.device)])
    if unique:
        return _unique_any(f[idx])
    return f[idx]
100
+
101
+
102
def _regular_closest(f, r: int, unique: bool = True):
    """
    Replace each point of a regular grid by its closest value in ``f``.

    ``r`` evenly spaced points are placed between the first and last entries
    of ``_unique_any(f)``. For each point, the closest of those values is found
    with a KeOps ``LazyTensor`` argmin. Values are taken from ``f`` by
    indexing, so the result is differentiable with respect to ``f``.

    Parameters
    ----------
    f :
        1d tensor of values.
    r :
        Number of regular grid points.
    unique :
        Merge repeated selected values (default True).

    Raises
    ------
    ValueError
        If ``f`` is not 1-dimensional.
    """
    if f.ndim != 1:
        raise ValueError(f"Got ndim!=1. {f=}")
    f = _unique_any(f)
    if f.numel() == 0:
        # Nothing to select from (empty input, or only infinite values).
        return f
    with torch.no_grad():
        f_reg = torch.linspace(
            f[0].item(), f[-1].item(), steps=r, dtype=f.dtype, device=f.device
        )
        _f = LazyTensor(f[:, None, None])
        _f_reg = LazyTensor(f_reg[None, :, None])
        # Reduce over the values axis: for each regular point, the index of the
        # closest value of f.
        indices = (_f - _f_reg).abs().argmin(0).ravel()
    # +inf sentinel, created on f's device (a CPU tensor broke GPU inputs).
    f = torch.cat([f, torch.full((1,), torch.inf, dtype=f.dtype, device=f.device)])
    f_regular_closest = f[indices]
    if unique:
        f_regular_closest = _unique_any(f_regular_closest)
    return f_regular_closest
118
+
119
+
120
def evaluate_in_grid(pts, grid):
    """
    Map integer grid coordinates to the corresponding grid values.

    Input
    -----
    - pts: ``(num_points, num_parameters)`` array of indices. Column ``p``
      indexes into ``grid[p]``.
    - grid: indexable sequence of 1d tensors, one per parameter.

    Returns
    -------
    - tensor of shape ``(num_points, num_parameters)`` whose entry ``[i, p]``
      is ``grid[p][pts[i, p]]``. It is obtained by indexing ``grid``, so
      gradients flow back to the grid values.
    """
    columns = []
    for parameter, indices in enumerate(pts.T):
        column = grid[parameter][indices]
        columns.append(column[:, None])
    return torch.cat(columns, dim=1)
142
+
143
+
144
def evaluate_mod_in_grid(mod, grid, box=None):
    """
    Push an MMA module onto ``grid``, e.g. to make it differentiable with
    respect to the grid values.

    Input
    -----
    - mod: PyModule; must provide ``to_flat_idx(grid)``.
    - grid: indexable sequence of 1d tensors, one per parameter.
    - box: optional pair ``(lower_corner, upper_corner)``. When given, each
      grid axis ``i`` is clamped to ``[box[0][i], box[1][i]]``, deduplicated,
      and framed by those two bounds.

    Output
    ------
    A tuple of ``(births, deaths)`` pairs, one per block of the split sizes
    returned by ``mod.to_flat_idx`` (presumably one per summand -- confirm).
    ``births`` is a ``(num_births, num_parameters)`` tensor and ``deaths`` a
    ``(num_deaths, num_parameters)`` tensor, both obtained by indexing the grid.
    """
    if box is not None:
        lower, upper = box[0], box[1]
        clamped_axes = []
        for i in range(len(grid)):
            inner = _unique_any(
                grid[i].clamp(min=lower[i], max=upper[i]), assume_sorted=True
            )
            # NOTE(review): if some grid values fall outside the box, ``inner``
            # already contains the bound itself, so the bounds end up repeated
            # in the axis -- confirm this is harmless for mod.to_flat_idx.
            clamped_axes.append(torch.cat([lower[[i]], inner, upper[[i]]]))
        grid = tuple(clamped_axes)
    (birth_sizes, death_sizes), flat_births, flat_deaths = mod.to_flat_idx(grid)
    birth_coords = evaluate_in_grid(flat_births, grid)
    death_coords = evaluate_in_grid(flat_deaths, grid)
    births_per_block = birth_coords.split_with_sizes(birth_sizes.tolist())
    deaths_per_block = death_coords.split_with_sizes(death_sizes.tolist())
    return tuple(zip(births_per_block, deaths_per_block))
181
+
182
+
183
def evaluate_mod_in_grid__old(mod, grid, box=None):
    """Legacy variant of ``evaluate_mod_in_grid``: pushes an MMA module into the specified grid.
    Useful e.g. to make it differentiable.

    The birth/death corners of ``mod.dump()`` are clipped to ``box``. Each
    coordinate is then snapped to the index of the closest value of the
    corresponding grid axis (KeOps argmin), and those indices are evaluated in
    ``grid``.

    Input
    -----
    - mod: PyModule
    - grid: Iterable of 1d array, for num_parameters
    - box: optional ``(lower_corner, upper_corner)``; defaults to the first and
      last grid value of every parameter.
    Output
    ------
    torch-compatible module in the format:
    (num_degrees) x (num_interval of degree) x ((num_birth, num_parameter), (num_death, num_parameters))

    """
    # NumPy flavour of KeOps; shadows the module-level torch ``LazyTensor``
    # inside this function.
    from pykeops.numpy import LazyTensor

    # The statements below only compute integer grid indices.
    with torch.no_grad():
        if box is None:
            # box = mod.get_box()
            # Default box: first and last grid value of every parameter.
            box = np.asarray([[g[0] for g in grid], [g[-1] for g in grid]])
        S = mod.dump()[1]

        def get_idx_parameter(A, G, p):
            # For each row of ``A``, the index of the value of ``G[p]`` closest
            # to ``A[:, p]`` (argmin over the grid axis).
            g = G[p].numpy() if isinstance(G[p], torch.Tensor) else np.asarray(G[p])
            la = LazyTensor(np.asarray(A, dtype=g.dtype)[None, :, [p]])
            lg = LazyTensor(g[:, None, None])
            return (la - lg).abs().argmin(0)

        # Birth corners of all summands, stacked and clipped into the box.
        Bdump = np.concatenate([s[0] for s in S], axis=0).clip(box[[0]], box[[1]])
        B = np.concatenate(
            [get_idx_parameter(Bdump, grid, p) for p in range(mod.num_parameters)],
            axis=1,
            dtype=np.int64,
        )
        # Same for death corners.
        # NOTE(review): deaths are cast to float32 while births keep their dtype;
        # confirm this asymmetry is intended.
        Ddump = np.concatenate([s[1] for s in S], axis=0, dtype=np.float32).clip(
            box[[0]], box[[1]]
        )
        D = np.concatenate(
            [get_idx_parameter(Ddump, grid, p) for p in range(mod.num_parameters)],
            axis=1,
            dtype=np.int64,
        )

    # Index -> value lookups: the only step that reads grid values rather
    # than indices.
    BB = evaluate_in_grid(B, grid)
    DD = evaluate_in_grid(D, grid)

    # Split the stacked corners back into one block per dumped summand.
    b_idx = tuple((len(s[0]) for s in S))
    d_idx = tuple((len(s[1]) for s in S))
    BBB = BB.split_with_sizes(b_idx)
    DDD = DD.split_with_sizes(d_idx)

    # Group the summands by degree using the module's degree boundaries.
    splits = np.concatenate([[0], mod.degree_splits(), [len(BBB)]])
    splits = torch.from_numpy(splits)
    out = [
        list(zip(BBB[splits[i] : splits[i + 1]], DDD[splits[i] : splits[i + 1]]))
        for i in range(len(splits) - 1)
    ]  ## For some reasons this kills the gradient ???? pytorch bug
    return out
@@ -0,0 +1,310 @@
1
+ from typing import Callable, Literal, Optional
2
+
3
+ import numpy as np
4
+ import torch
5
+ import gudhi as gd
6
+
7
+ import multipers as mp
8
+ from multipers.filtrations.density import DTM, KDE
9
+ from multipers.simplex_tree_multi import _available_strategies
10
+ from multipers.torch.diff_grids import get_grid
11
+
12
+
13
def function_rips_signed_measure_old(
    x,
    theta: Optional[float] = None,
    function: Literal["dtm", "gaussian", "exponential"] | Callable = "dtm",
    threshold: float = np.inf,
    grid_strategy: _available_strategies = "regular_closest",
    resolution: int = 100,
    return_original: bool = False,
    return_st: bool = False,
    safe_conversion: bool = False,
    num_collapses: int = -1,
    expand_collapse: bool = False,
    dtype=torch.float32,
    **sm_kwargs,
):
    """
    Computes a torch-differentiable function-rips signed measure.

    This is the legacy implementation: it recovers differentiable coordinates
    through ``SimplexTreeMulti.pts_to_indices``.

    Input
    -----
    - x (num_pts, dim) : The point cloud (torch tensor).
    - theta: For density-like functions : the bandwidth.
    - threshold : rips threshold; larger distances are set to +inf.
    - function : Either "dtm", "gaussian", or "exponential" or Callable.
      Function to compute the second parameter.
    - grid_strategy: grid coarsening strategy.
    - resolution : when coarsening, the target resolution.
    - return_original : Also returns the non-differentiable signed measure.
    - return_st : Also returns the coarsened simplextree.
    - safe_conversion : Activate this if you encounter crashes.
    - num_collapses, expand_collapse : edge-collapse options.
    - dtype : dtype of the differentiable outputs.
    - **sm_kwargs : for the signed measure computation.

    Returns
    -------
    The list of differentiable signed measures ``(points, weights)``. If
    ``return_original`` and/or ``return_st`` is set, a tuple is returned
    instead: that list, followed by the non-differentiable signed measures
    and/or the coarsened simplextree.
    """
    assert isinstance(x, torch.Tensor)
    if function == "dtm":
        assert theta is not None, "Provide a theta to compute DTM"
        codensity = DTM(masses=[theta]).fit(x).score_samples_diff(x)[0].type(dtype)
    elif function in ["gaussian", "exponential"]:
        assert theta is not None, "Provide a theta to compute density estimation"
        codensity = (
            -KDE(
                bandwidth=theta,
                kernel=function,
                return_log=True,
            )
            .fit(x)
            .score_samples(x)
            .type(dtype)
        )
    else:
        assert callable(function), "Function has to be callable"
        if theta is None:
            codensity = function(x).type(dtype)
        else:
            codensity = function(x, theta=theta).type(dtype)

    distance_matrix = torch.cdist(x, x).type(dtype)
    if threshold < np.inf:
        # Out-of-place. ``.type(dtype)`` returns the cdist output itself when the
        # dtype already matches, and modifying it in place can break autograd's
        # backward pass through cdist.
        distance_matrix = distance_matrix.masked_fill(
            distance_matrix > threshold, torch.inf
        )

    # detach makes a new (reference) tensor, without tracking the gradient
    st = gd.SimplexTree.create_from_array(
        distance_matrix.detach(), max_filtration=threshold
    )
    st = mp.SimplexTreeMulti(st, num_parameters=2, safe_conversion=safe_conversion)
    # Fill the codensity in the second parameter of the simplextree.
    st.fill_lowerstar(codensity.detach(), parameter=1)

    # Simplify the simplextree for computation; the signed measure is
    # recovered from this copy afterwards.
    st_copy = st.grid_squeeze(
        grid_strategy=grid_strategy, resolution=resolution, coordinate_values=True
    )
    degree = sm_kwargs.get("degree", None)
    if degree is None and sm_kwargs.get("degrees", [None]) == [None]:
        expansion_degree = st.num_vertices
    else:
        # An explicit ``degree=None`` counts as absent (default 1) instead of
        # reaching ``max`` and raising TypeError.
        expansion_degree = (
            max(np.max(sm_kwargs.get("degrees", 1)), 1 if degree is None else degree)
            + 1
        )
    st.collapse_edges(num=num_collapses)
    if not expand_collapse:
        st.expansion(expansion_degree)  # edge collapse
    sms = mp.signed_measure(st, **sm_kwargs)  # computes the signed measure
    del st

    # not optimal, we may want to do that in cython to get edges and nodes
    simplices_list = tuple(s for s, _ in st_copy.get_simplices())
    sms_diff = []
    for sm, weights in sms:
        indices, not_found_indices = st_copy.pts_to_indices(
            sm, simplices_dimensions=[1, 0]
        )
        if sm_kwargs.get("verbose", False):
            print(
                f"Found {(1 - (indices == -1).mean()).round(2)} indices. "
                f"Out : {(indices == -1).sum()}, {len(not_found_indices)}"
            )
        sm_diff = torch.empty(sm.shape, dtype=dtype)

        # Fill the Rips filtrations (even columns) of the signed measure; the
        # loop over columns is for the rank invariant.
        for i in range(0, sm_diff.shape[1], 2):
            idxs = indices[:, i]
            if (idxs == -1).all():
                continue
            useful_idxs = idxs != -1
            if useful_idxs.size > 0:
                # Retrieve the differentiable edge lengths from distance_matrix.
                # ``(*edge, None)`` is a parenthesised tuple: the unparenthesised
                # ``[*edge, None]`` subscript is a SyntaxError before Python 3.11.
                edges_filtrations = torch.cat(
                    [
                        distance_matrix[(*simplices_list[idx], None)]
                        for idx in idxs[useful_idxs]
                    ]
                )
                sm_diff[:, i][useful_idxs] = edges_filtrations
        # Same for the codensity axis (odd columns), read from the vertices.
        for i in range(1, sm_diff.shape[1], 2):
            idxs = indices[:, i]
            if (idxs == -1).all():
                continue
            useful_idxs = idxs != -1
            if useful_idxs.size > 0:
                nodes_filtrations = torch.cat(
                    [codensity[simplices_list[idx]] for idx in idxs[useful_idxs]]
                )
                sm_diff[:, i][useful_idxs] = nodes_filtrations

        # Fill not-found values as constants.
        if len(not_found_indices) > 0:
            missing = indices == -1
            sm_diff[missing] = torch.from_numpy(sm[missing]).type(dtype)

        sms_diff.append((sm_diff, torch.from_numpy(weights)))
    flags = [True, return_original, return_st]
    if np.sum(flags) == 1:
        return sms_diff
    return tuple(stuff for stuff, flag in zip([sms_diff, sms, st_copy], flags) if flag)
156
+
157
+
158
def function_rips_signed_measure(
    x,
    theta: Optional[float] = None,
    function: Literal["dtm", "gaussian", "exponential"] | Callable = "gaussian",
    threshold: Optional[float] = None,
    grid_strategy: Literal[
        "regular_closest", "exact", "quantile", "regular_left", "regular"
    ] = "exact",
    complex: Literal["rips", "delaunay", "weak_delaunay"] = "rips",
    resolution: int = 100,
    safe_conversion: bool = False,
    num_collapses: Optional[int] = None,
    expand_collapse: bool = False,
    dtype=torch.float32,
    plot=False,
    # return_st: bool = False,
    *,
    log_density: bool = True,
    vineyard: bool = False,
    pers_backend=None,
    **sm_kwargs,
):
    """
    Computes a torch-differentiable function-rips signed measure.

    Input
    -----
    - x (num_pts, dim) : The point cloud (torch tensor).
    - theta : For density-like functions : the bandwidth (DTM: the mass).
    - function : Either "dtm", "gaussian", "exponential", a Callable
      (``function(x)``, or ``function(x, theta=theta)`` when theta is given),
      or a 1d tensor holding the function values on ``x``. It defines the
      second parameter.
    - threshold : rips threshold. Defaults to ``min_i max_j d(x_i, x_j)``.
    - grid_strategy : grid coarsening strategy (see ``get_grid``).
    - complex : "rips", or "delaunay" / "weak_delaunay".
    - resolution : target resolution when coarsening. Either an int, used for
      both parameters, or one value per parameter. Ignored by "exact".
    - safe_conversion : Activate this if you encounter crashes.
    - num_collapses, expand_collapse : edge-collapse options (rips only).
    - dtype : dtype of the filtration values.
    - plot : plot the resulting signed measures.
    - log_density : for "gaussian"/"exponential", use the log-density.
    - vineyard, pers_backend : forwarded to ``mp.Slicer``.
    - degree / degrees (in sm_kwargs) : homological degree(s). A ``None``
      degree uses the full (non-minimized) slicer.
    - **sm_kwargs : other options for ``mp.signed_measure``.

    Returns
    -------
    A flat tuple of the signed measures returned by ``mp.signed_measure`` for
    each requested degree, in order (``degree`` first, then ``degrees``).

    Raises
    ------
    ValueError
        If no degree is requested, or if ``complex`` is unknown.
    """
    if num_collapses is None:
        num_collapses = -1 if complex == "rips" else None
    assert isinstance(x, torch.Tensor)

    # ``degree`` is passed explicitly to mp.signed_measure below, so it must
    # always leave sm_kwargs (even when None). Otherwise the call fails with a
    # duplicate keyword argument.
    degrees = list(sm_kwargs.pop("degrees", []))
    if "degree" in sm_kwargs:
        degrees = [sm_kwargs.pop("degree")] + degrees
    if not degrees:
        raise ValueError("Provide a `degree` or `degrees` to compute signed measures.")

    if isinstance(function, torch.Tensor):
        # Checked first: a tensor must not be compared against the strings below.
        assert (
            function.ndim == 1 and function.shape[0] == x.shape[0]
        ), """
        When function is a tensor, it is interpreted as the value of some function over x.
        """
        codensity = function.type(dtype)
    elif function == "dtm":
        assert theta is not None, "Provide a theta to compute DTM"
        codensity = DTM(masses=[theta]).fit(x).score_samples_diff(x)[0].type(dtype)
    elif function in ["gaussian", "exponential"]:
        assert theta is not None, "Provide a theta to compute density estimation"
        codensity = (
            -KDE(
                bandwidth=theta,
                kernel=function,
                return_log=log_density,
            )
            .fit(x)
            .score_samples(x)
            .type(dtype)
        )
    else:
        assert callable(function), "Function has to be callable"
        if theta is None:
            codensity = function(x).type(dtype)
        else:
            codensity = function(x, theta=theta).type(dtype)

    distance_matrix = torch.cdist(x, x).type(dtype)
    distances = distance_matrix.ravel()
    if complex == "rips":
        threshold = (
            distance_matrix.max(axis=1).values.min() if threshold is None else threshold
        )
        distances = distances[distances <= threshold]
    elif complex in ["delaunay", "weak_delaunay"]:
        complex = "delaunay"
        # Out-of-place. ``distances`` is a view of ``distance_matrix`` (itself the
        # cdist output when the dtype already matches), and an in-place division
        # would modify a tensor autograd may need for cdist's backward pass.
        distances = distances / 2
    else:
        raise ValueError(
            f"Unimplemented with complex {complex}. You can use rips or delaunay ftm."
        )

    # The grid helpers expect one resolution per parameter; "exact" takes none.
    if grid_strategy == "exact":
        grid_resolution = None
    elif np.ndim(resolution) == 0:
        grid_resolution = (resolution, resolution)
    else:
        grid_resolution = resolution
    # Simplify the complex for computation; the signed measure is expressed in
    # this (differentiable) grid afterwards.
    reduced_grid = get_grid(strategy=grid_strategy)(
        (distances, codensity), grid_resolution
    )

    if complex == "rips":
        # detach makes a new (reference) tensor, without tracking the gradient
        st = gd.SimplexTree.create_from_array(
            distance_matrix.detach().cpu(), max_filtration=threshold
        )
        st = mp.SimplexTreeMulti(st, num_parameters=2, safe_conversion=safe_conversion)
        # Fill the codensity in the second parameter of the simplextree.
        st.fill_lowerstar(codensity.detach(), parameter=1)
        st = st.grid_squeeze(reduced_grid)
        st.filtration_grid = []
        if None in degrees:
            expansion_degree = st.num_vertices
        else:
            expansion_degree = max(degrees) + 1
        st.collapse_edges(num=num_collapses)
        if not expand_collapse:
            st.expansion(expansion_degree)  # edge collapse

        s = mp.Slicer(st, vineyard=vineyard, backend=pers_backend)
    elif complex == "delaunay":
        s = mp.slicer.from_function_delaunay(
            x.detach().cpu().numpy(), codensity.detach().cpu().numpy()
        )
        st = mp.slicer.to_simplextree(s)
        st.flagify(2)
        s = mp.Slicer(st, vineyard=vineyard, backend=pers_backend).grid_squeeze(
            reduced_grid
        )

    s.filtration_grid = []  ## To enforce minpres to be reasonable
    if None not in degrees:
        s = s.minpres(degrees=degrees)
    else:
        from joblib import Parallel, delayed

        full_slicer = s
        s = tuple(
            Parallel(n_jobs=-1, backend="threading")(
                delayed(
                    lambda d: full_slicer if d is None else full_slicer.minpres(degree=d)
                )(d)
                for d in degrees
            )
        )
    ## fix previous hack
    for stuff in s:
        # stuff.filtration_grid = reduced_grid ## not necessary
        stuff.filtration_grid = [[1]] * stuff.num_parameters

    sms = tuple(
        sm
        for slicer_of_degree, degree in zip(s, degrees)
        for sm in mp.signed_measure(
            slicer_of_degree, grid=reduced_grid, degree=degree, **sm_kwargs
        )
    )  # computes the signed measure
    if plot:
        mp.plots.plot_signed_measures(
            tuple(
                (sm.detach().cpu().numpy(), w.detach().cpu().numpy()) for sm, w in sms
            )
        )
    return sms