multipers-2.3.1-cp313-cp313-macosx_13_0_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic; see the registry page for more details.

Files changed (180)
  1. multipers/.dylibs/libc++.1.0.dylib +0 -0
  2. multipers/.dylibs/libtbb.12.14.dylib +0 -0
  3. multipers/__init__.py +33 -0
  4. multipers/_signed_measure_meta.py +430 -0
  5. multipers/_slicer_meta.py +211 -0
  6. multipers/data/MOL2.py +458 -0
  7. multipers/data/UCR.py +18 -0
  8. multipers/data/__init__.py +1 -0
  9. multipers/data/graphs.py +466 -0
  10. multipers/data/immuno_regions.py +27 -0
  11. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  12. multipers/data/pytorch2simplextree.py +91 -0
  13. multipers/data/shape3d.py +101 -0
  14. multipers/data/synthetic.py +113 -0
  15. multipers/distances.py +198 -0
  16. multipers/filtration_conversions.pxd +229 -0
  17. multipers/filtration_conversions.pxd.tp +84 -0
  18. multipers/filtrations/__init__.py +18 -0
  19. multipers/filtrations/density.py +563 -0
  20. multipers/filtrations/filtrations.py +289 -0
  21. multipers/filtrations.pxd +224 -0
  22. multipers/function_rips.cpython-313-darwin.so +0 -0
  23. multipers/function_rips.pyx +105 -0
  24. multipers/grids.cpython-313-darwin.so +0 -0
  25. multipers/grids.pyx +350 -0
  26. multipers/gudhi/Persistence_slices_interface.h +132 -0
  27. multipers/gudhi/Simplex_tree_interface.h +239 -0
  28. multipers/gudhi/Simplex_tree_multi_interface.h +516 -0
  29. multipers/gudhi/cubical_to_boundary.h +59 -0
  30. multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -0
  31. multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -0
  32. multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -0
  33. multipers/gudhi/gudhi/Debug_utils.h +45 -0
  34. multipers/gudhi/gudhi/Fields/Multi_field.h +484 -0
  35. multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -0
  36. multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -0
  37. multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -0
  38. multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -0
  39. multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -0
  40. multipers/gudhi/gudhi/Fields/Z2_field.h +355 -0
  41. multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -0
  42. multipers/gudhi/gudhi/Fields/Zp_field.h +420 -0
  43. multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -0
  44. multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -0
  45. multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -0
  46. multipers/gudhi/gudhi/Matrix.h +2107 -0
  47. multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -0
  48. multipers/gudhi/gudhi/Multi_persistence/Box.h +171 -0
  49. multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -0
  50. multipers/gudhi/gudhi/Off_reader.h +173 -0
  51. multipers/gudhi/gudhi/One_critical_filtration.h +1433 -0
  52. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -0
  53. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -0
  54. multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -0
  55. multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -0
  56. multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -0
  57. multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -0
  58. multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -0
  59. multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -0
  60. multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -0
  61. multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -0
  62. multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -0
  63. multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -0
  64. multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -0
  65. multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -0
  66. multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -0
  67. multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -0
  68. multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -0
  69. multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -0
  70. multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -0
  71. multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -0
  72. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -0
  73. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -0
  74. multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -0
  75. multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -0
  76. multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -0
  77. multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -0
  78. multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -0
  79. multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -0
  80. multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -0
  81. multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -0
  82. multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -0
  83. multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -0
  84. multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -0
  85. multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -0
  86. multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -0
  87. multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -0
  88. multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -0
  89. multipers/gudhi/gudhi/Persistent_cohomology.h +745 -0
  90. multipers/gudhi/gudhi/Points_off_io.h +171 -0
  91. multipers/gudhi/gudhi/Simple_object_pool.h +69 -0
  92. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -0
  93. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -0
  94. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -0
  95. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -0
  96. multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -0
  97. multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -0
  98. multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -0
  99. multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -0
  100. multipers/gudhi/gudhi/Simplex_tree.h +2794 -0
  101. multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -0
  102. multipers/gudhi/gudhi/distance_functions.h +62 -0
  103. multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -0
  104. multipers/gudhi/gudhi/persistence_interval.h +253 -0
  105. multipers/gudhi/gudhi/persistence_matrix_options.h +170 -0
  106. multipers/gudhi/gudhi/reader_utils.h +367 -0
  107. multipers/gudhi/mma_interface_coh.h +256 -0
  108. multipers/gudhi/mma_interface_h0.h +223 -0
  109. multipers/gudhi/mma_interface_matrix.h +291 -0
  110. multipers/gudhi/naive_merge_tree.h +536 -0
  111. multipers/gudhi/scc_io.h +310 -0
  112. multipers/gudhi/truc.h +957 -0
  113. multipers/io.cpython-313-darwin.so +0 -0
  114. multipers/io.pyx +714 -0
  115. multipers/ml/__init__.py +0 -0
  116. multipers/ml/accuracies.py +90 -0
  117. multipers/ml/invariants_with_persistable.py +79 -0
  118. multipers/ml/kernels.py +176 -0
  119. multipers/ml/mma.py +713 -0
  120. multipers/ml/one.py +472 -0
  121. multipers/ml/point_clouds.py +352 -0
  122. multipers/ml/signed_measures.py +1589 -0
  123. multipers/ml/sliced_wasserstein.py +461 -0
  124. multipers/ml/tools.py +113 -0
  125. multipers/mma_structures.cpython-313-darwin.so +0 -0
  126. multipers/mma_structures.pxd +127 -0
  127. multipers/mma_structures.pyx +2742 -0
  128. multipers/mma_structures.pyx.tp +1083 -0
  129. multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -0
  130. multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -0
  131. multipers/multi_parameter_rank_invariant/function_rips.h +322 -0
  132. multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -0
  133. multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -0
  134. multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -0
  135. multipers/multiparameter_edge_collapse.py +41 -0
  136. multipers/multiparameter_module_approximation/approximation.h +2298 -0
  137. multipers/multiparameter_module_approximation/combinatory.h +129 -0
  138. multipers/multiparameter_module_approximation/debug.h +107 -0
  139. multipers/multiparameter_module_approximation/euler_curves.h +0 -0
  140. multipers/multiparameter_module_approximation/format_python-cpp.h +286 -0
  141. multipers/multiparameter_module_approximation/heap_column.h +238 -0
  142. multipers/multiparameter_module_approximation/images.h +79 -0
  143. multipers/multiparameter_module_approximation/list_column.h +174 -0
  144. multipers/multiparameter_module_approximation/list_column_2.h +232 -0
  145. multipers/multiparameter_module_approximation/ru_matrix.h +347 -0
  146. multipers/multiparameter_module_approximation/set_column.h +135 -0
  147. multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -0
  148. multipers/multiparameter_module_approximation/unordered_set_column.h +166 -0
  149. multipers/multiparameter_module_approximation/utilities.h +403 -0
  150. multipers/multiparameter_module_approximation/vector_column.h +223 -0
  151. multipers/multiparameter_module_approximation/vector_matrix.h +331 -0
  152. multipers/multiparameter_module_approximation/vineyards.h +464 -0
  153. multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -0
  154. multipers/multiparameter_module_approximation.cpython-313-darwin.so +0 -0
  155. multipers/multiparameter_module_approximation.pyx +218 -0
  156. multipers/pickle.py +90 -0
  157. multipers/plots.py +342 -0
  158. multipers/point_measure.cpython-313-darwin.so +0 -0
  159. multipers/point_measure.pyx +322 -0
  160. multipers/simplex_tree_multi.cpython-313-darwin.so +0 -0
  161. multipers/simplex_tree_multi.pxd +133 -0
  162. multipers/simplex_tree_multi.pyx +10402 -0
  163. multipers/simplex_tree_multi.pyx.tp +1947 -0
  164. multipers/slicer.cpython-313-darwin.so +0 -0
  165. multipers/slicer.pxd +2552 -0
  166. multipers/slicer.pxd.tp +218 -0
  167. multipers/slicer.pyx +16530 -0
  168. multipers/slicer.pyx.tp +931 -0
  169. multipers/tensor/tensor.h +672 -0
  170. multipers/tensor.pxd +13 -0
  171. multipers/test.pyx +44 -0
  172. multipers/tests/__init__.py +57 -0
  173. multipers/torch/__init__.py +1 -0
  174. multipers/torch/diff_grids.py +217 -0
  175. multipers/torch/rips_density.py +310 -0
  176. multipers-2.3.1.dist-info/LICENSE +21 -0
  177. multipers-2.3.1.dist-info/METADATA +144 -0
  178. multipers-2.3.1.dist-info/RECORD +180 -0
  179. multipers-2.3.1.dist-info/WHEEL +6 -0
  180. multipers-2.3.1.dist-info/top_level.txt +1 -0
multipers/test.pyx ADDED
@@ -0,0 +1,44 @@
1
+ # cimport multipers.tensor as mt
2
+ from libc.stdint cimport intptr_t, uint16_t
3
+ from libcpp.vector cimport vector
4
+ from libcpp cimport bool, int, float
5
+ from libcpp.utility cimport pair
6
+ from typing import Optional,Iterable,Callable
7
+
8
+
9
# Scalar type alias for tensor values.
# NOTE(review): not referenced anywhere else in this file; numpy_to_tensor
# uses an undefined `dtype` instead -- confirm which one was intended.
ctypedef float value_type
# NOTE(review): `index_type` is used by numpy_to_tensor below, but its
# definition is commented out here.
# ctypedef uint16_t index_type
11
+
12
+ import numpy as np
13
+ # cimport numpy as cnp
14
+ # cnp.import_array()
15
+
16
+ # cdef extern from "multi_parameter_rank_invariant/rank_invariant.h" namespace "Gudhi::rank_invariant":
17
+ # void get_hilbert_surface(const intptr_t, mt.static_tensor_view, const vector[index_type], const vector[index_type], index_type, index_type, const vector[index_type], bool, bool) except + nogil
18
+
19
+
20
+ from multipers.simplex_tree_multi import SimplexTreeMulti
21
+
22
+
23
def numpy_to_tensor(array:np.ndarray):
    """Wrap ``array`` in an ``mt.static_tensor_view`` and return ``get_resolution()``.

    NOTE(review): as written this appears not to compile -- ``index_type``
    is commented out at the top of the file, ``dtype`` is never defined, and
    ``mt`` only appears in the commented-out ``cimport multipers.tensor as mt``.
    """
    # Shape of the input, converted to a C++ vector.
    cdef vector[index_type] shape = array.shape
    # C-contiguous buffer so that a raw pointer to the first element is valid.
    cdef dtype[::1] contigus_array_view = np.ascontiguousarray(array)
    cdef dtype* dtype_ptr = &contigus_array_view[0]
    cdef mt.static_tensor_view tensor
    with nogil:
        tensor = mt.static_tensor_view(dtype_ptr, shape)
    return tensor.get_resolution()
31
+
32
+ # def hilbert2d(simplextree:SimplexTreeMulti, grid_shape:np.ndarray|list, vector[index_type] degrees, bool mobius_inversion):
33
+ # # assert simplextree.num_parameters == 2
34
+ # cdef intptr_t ptr = simplextree.thisptr
35
+ # cdef vector[index_type] c_grid_shape = grid_shape
36
+ # cdef dtype[::1] container = np.zeros(grid_shape, dtype=np.float32).flatten()
37
+ # cdef dtype* container_ptr = &container[0]
38
+ # cdef mt.static_tensor_view c_container = mt.static_tensor_view(container_ptr, c_grid_shape)
39
+ # cdef index_type i = 0
40
+ # cdef index_type j = 1
41
+ # cdef vector[index_type] fixed_values = [[],[]]
42
+ # # get_hilbert_surface(ptr, c_container, c_grid_shape, degrees,i,j,fixed_values, False, False)
43
+ # return container.reshape(grid_shape)
44
+
@@ -0,0 +1,57 @@
1
+ import numpy as np
2
+
3
+
4
+ def assert_st_simplices(st, dump):
5
+ """
6
+ Checks that the simplextree has the same
7
+ filtration as the dump.
8
+ """
9
+
10
+ assert np.all(
11
+ [
12
+ np.isclose(a, b).all()
13
+ for x, y in zip(st.get_simplices(), dump, strict=True)
14
+ for a, b in zip(x, y, strict=True)
15
+ ]
16
+ )
17
+
18
+
19
def sort_sm(sms):
    """
    Sort each signed measure by the first coordinate of its points.

    Parameters
    ----------
    sms : iterable of signed measures; ``sm[0]`` holds the points (a 2-d
        array, one point per row) and ``sm[1]`` the matching weights.

    Returns
    -------
    tuple of ``(points, weights)`` pairs, each reordered by increasing
    ``points[:, 0]`` (ties keep their original order).
    """
    sorted_sms = []
    for sm in sms:
        points, weights = sm[0], sm[1]
        # Each measure needs its own permutation: a single argsort over the
        # list of all measures yields a 2-d index array (wrong output shape)
        # and fails outright when the measures have different sizes.
        order = np.argsort(points[:, 0], kind="stable")
        sorted_sms.append((points[order], weights[order]))
    return tuple(sorted_sms)
22
+
23
+
24
+ def assert_sm_pair(sm1, sm2, exact=True, max_error=1e-3, reg=0.1):
25
+ if not exact:
26
+ from multipers.distances import sm_distance
27
+
28
+ d = sm_distance(sm1, sm2, reg=0.1)
29
+ assert d < max_error, f"Failed comparison:\n{sm1}\n{sm2},\n with distance {d}."
30
+ return
31
+ assert np.all(
32
+ [
33
+ np.isclose(a, b).all()
34
+ for x, y in zip(sm1, sm2, strict=True)
35
+ for a, b in zip(x, y, strict=True)
36
+ ]
37
+ ), f"Failed comparison:\n-----------------\n{sm1}\n-----------------\n{sm2}"
38
+
39
+
40
def assert_sm(*args, exact=True, max_error=1e-5, reg=0.1):
    """
    Assert that every signed measure in ``args`` agrees with the next one.

    Consecutive pairs are checked with ``assert_sm_pair``, forwarding
    ``exact``, ``max_error`` and ``reg`` unchanged.
    """
    for current, following in zip(args, args[1:]):
        assert_sm_pair(current, following, exact=exact, max_error=max_error, reg=reg)
44
+
45
+
46
def random_st(npts=100, num_parameters=2, max_dim=2):
    """
    Build a random multi-parameter simplex tree for tests.

    ``npts`` points are drawn with ``multipers.data.noisy_annulus`` (split
    into ``npts // 2`` and ``npts - npts // 2``, passing ``dim=max_dim``).
    A gudhi alpha complex is built on them and converted to a
    ``SimplexTreeMulti`` with ``num_parameters`` parameters. Each parameter
    is then filled with a lower-star filtration of uniform random vertex
    values drawn from numpy's global RNG.
    """
    import gudhi as gd

    import multipers as mp
    from multipers.data import noisy_annulus

    first_half = npts // 2
    points = noisy_annulus(first_half, npts - first_half, dim=max_dim)
    alpha_tree = gd.AlphaComplex(points=points).create_simplex_tree()
    st = mp.SimplexTreeMulti(alpha_tree, num_parameters=num_parameters)
    for parameter in range(num_parameters):
        st.fill_lowerstar(np.random.uniform(size=npts), parameter)
    return st
@@ -0,0 +1 @@
1
+ from .rips_density import *
@@ -0,0 +1,217 @@
1
+ from typing import Literal
2
+
3
+ import numpy as np
4
+ import torch
5
+ from pykeops.torch import LazyTensor
6
+
7
+
8
def get_grid(strategy: Literal["exact", "regular_closest", "regular_left", "quantile"]):
    """
    Return the grid builder associated with ``strategy``.

    The returned callable has signature
    `(num_pts, num_parameter), int --> Iterable[1d array]`:
    given per-parameter filtration values and a resolution, it builds a
    torch-differentiable grid.

    Raises
    ------
    ValueError if ``strategy`` is not one of the four supported names.
    """
    if strategy == "exact":
        return _exact_grid
    if strategy == "regular_closest":
        return _regular_closest_grid
    if strategy == "regular_left":
        return _regular_left_grid
    if strategy == "quantile":
        return _quantile_grid
    raise ValueError(
        f"""
Unimplemented strategy {strategy}.
Available ones : exact, regular_closest, regular_left, quantile.
"""
    )
31
+
32
+
33
def todense(grid: list[torch.Tensor]):
    """Return every point of ``grid`` (Cartesian product of its axes), one per row."""
    axes = tuple(grid)
    return torch.cartesian_prod(*axes)
35
+
36
+
37
def _exact_grid(filtration_values, r=None):
    """
    Exact grid: each axis is ``_unique_any`` of that parameter's values.

    ``r`` is accepted for signature compatibility with the other grid
    builders and is ignored.
    """
    return tuple(map(_unique_any, filtration_values))
40
+
41
+
42
def _regular_closest_grid(filtration_values, r: int):
    """Apply ``_regular_closest`` with resolution ``r`` to every parameter's values."""
    axes = []
    for values in filtration_values:
        axes.append(_regular_closest(values, r))
    return tuple(axes)
45
+
46
+
47
def _regular_left_grid(filtration_values, r: int):
    """Apply ``_regular_left`` with resolution ``r`` to every parameter's values."""
    axes = []
    for values in filtration_values:
        axes.append(_regular_left(values, r))
    return tuple(axes)
50
+
51
+
52
def _quantile_grid(filtration_values, r: int):
    """
    Quantile grid: per parameter, the ``r`` evenly spaced quantiles (from 0
    to 1) of its values, deduplicated with ``_unique_any``.

    Parameters
    ----------
    filtration_values : iterable of 1d floating-point tensors, one per parameter.
    r : number of quantile levels.
    """
    axes = []
    for values in filtration_values:
        # torch.quantile requires ``q`` to have the input's dtype and device;
        # a default (CPU, float32) linspace fails for float64 or GPU inputs.
        levels = torch.linspace(0, 1, r, dtype=values.dtype, device=values.device)
        axes.append(_unique_any(torch.quantile(values, q=levels)))
    return tuple(axes)
56
+
57
+
58
+ def _unique_any(x, assume_sorted=False, remove_inf: bool = True):
59
+ if not assume_sorted:
60
+ x, _ = x.sort()
61
+ if remove_inf and x[-1] == torch.inf:
62
+ x = x[:-1]
63
+ with torch.no_grad():
64
+ y = x.unique()
65
+ idx = torch.searchsorted(x, y)
66
+ x = torch.cat([x, torch.tensor([torch.inf])])
67
+ return x[idx]
68
+
69
+
70
def _regular_left(f, r: int, unique: bool = True):
    """
    Snap ``r`` regularly spaced values onto the values of ``f``.

    ``f`` is first reduced with ``_unique_any``. Then ``r`` evenly spaced
    targets are taken between its first and last values. Each target is
    replaced by the first value of ``f`` that is greater than or equal to it
    (``torch.searchsorted``, left side). The result is gathered from ``f``,
    so it stays differentiable w.r.t. ``f``.

    Parameters
    ----------
    f : 1d floating-point tensor of filtration values.
    r : number of regularly spaced targets.
    unique : deduplicate the snapped values with ``_unique_any``.
    """
    f = _unique_any(f)
    if f.numel() == 0:
        return f
    with torch.no_grad():
        # Use f's dtype: a lower-precision linspace could round its endpoint
        # above f[-1], making searchsorted return len(f).
        f_regular = torch.linspace(
            f[0].item(), f[-1].item(), r, dtype=f.dtype, device=f.device
        )
        idx = torch.searchsorted(f, f_regular)
    # +inf sentinel for an index equal to len(f); created on f's device and
    # dtype so that torch.cat also works for GPU tensors.
    f = torch.cat([f, f.new_full((1,), torch.inf)])
    if unique:
        return _unique_any(f[idx])
    return f[idx]
79
+
80
+
81
def _regular_closest(f, r: int, unique: bool = True):
    """
    Snap ``r`` regularly spaced values onto their closest value of ``f``.

    ``f`` is first reduced with ``_unique_any``. Then ``r`` evenly spaced
    targets are taken between its first and last values. Each target is
    replaced by the nearest entry of ``f``, found with a pykeops
    ``LazyTensor`` argmin. The result is gathered from ``f``, so it stays
    differentiable w.r.t. ``f``.

    Parameters
    ----------
    f : 1d floating-point tensor of filtration values.
    r : number of regularly spaced targets.
    unique : deduplicate the snapped values with ``_unique_any``.
    """
    f = _unique_any(f)
    if f.numel() == 0:
        return f
    with torch.no_grad():
        f_reg = torch.linspace(
            f[0].item(), f[-1].item(), steps=r, dtype=f.dtype, device=f.device
        )
        _f = LazyTensor(f[:, None, None])
        _f_reg = LazyTensor(f_reg[None, :, None])
        # For every target (axis 1): index of the closest value of f (axis 0).
        indices = (_f - _f_reg).abs().argmin(0).ravel()
    # Sentinel kept for safety; created on f's device/dtype because torch.cat
    # rejects tensors living on different devices.
    f = torch.cat([f, f.new_full((1,), torch.inf)])
    f_regular_closest = f[indices]
    if unique:
        f_regular_closest = _unique_any(f_regular_closest)
    return f_regular_closest
95
+
96
+
97
def evaluate_in_grid(pts, grid):
    """Replace integer grid coordinates by the corresponding grid values.

    Input
    -----
    - pts: (num_points, num_parameters) array of indices; column ``p`` is
      used to index ``grid[p]``.
    - grid: indexable collection of 1-d tensors, one per parameter.

    Returns
    -------
    - tensor with the same shape as ``pts``, holding the values read from ``grid``.
    """
    columns = []
    for parameter, coordinates in enumerate(pts.T):
        values = grid[parameter][coordinates]
        columns.append(values[:, None])
    return torch.cat(columns, dim=1)
119
+
120
+
121
def evaluate_mod_in_grid(mod, grid, box=None):
    """Push an MMA module onto ``grid``, e.g. to make it differentiable.

    Input
    -----
    - mod: PyModule (must provide ``to_flat_idx``)
    - grid: indexable collection of 1d tensors, one per parameter
    - box: optional ``(lower, upper)`` pair of per-parameter bounds. When
      given, each grid axis is clamped to ``[lower[i], upper[i]]``,
      deduplicated with ``_unique_any``, and framed by ``lower[i]`` and
      ``upper[i]``.

    Output
    ------
    Tuple of ``(births, deaths)`` tensor pairs. The evaluated birth and death
    coordinates are split with the sizes returned by ``mod.to_flat_idx``.
    NOTE(review): whether one pair corresponds to a summand or to a degree
    depends on ``to_flat_idx``; confirm there.
    """
    if box is not None:
        lower, upper = box[0], box[1]
        framed_axes = []
        for i in range(len(grid)):
            inner = _unique_any(
                grid[i].clamp(min=lower[i], max=upper[i]), assume_sorted=True
            )
            framed_axes.append(torch.cat([lower[[i]], inner, upper[[i]]]))
        grid = tuple(framed_axes)
    (birth_sizes, death_sizes), births, deaths = mod.to_flat_idx(grid)
    birth_values = evaluate_in_grid(births, grid)
    death_values = evaluate_in_grid(deaths, grid)
    birth_chunks = birth_values.split_with_sizes(birth_sizes.tolist())
    death_chunks = death_values.split_with_sizes(death_sizes.tolist())
    return tuple(zip(birth_chunks, death_chunks))
158
+
159
+
160
def evaluate_mod_in_grid__old(mod, grid, box=None):
    """Given an MMA module, pushes it into the specified grid.
    Useful for e.g., make it differentiable.

    Older variant (``__old`` suffix): births and deaths are snapped to the
    closest grid values with a pykeops nearest-neighbour search.

    Input
    -----
    - mod: PyModule
    - grid: Iterable of 1d array, for num_parameters
    - box: optional (2, num_parameters) array of lower/upper bounds; defaults
      to the first/last value of each grid axis.
    Output
    ------
    torch-compatible module in the format:
    (num_degrees) x (num_interval of degree) x ((num_birth, num_parameter), (num_death, num_parameters))

    """
    from pykeops.numpy import LazyTensor

    with torch.no_grad():
        if box is None:
            # box = mod.get_box()
            box = np.asarray([[g[0] for g in grid], [g[-1] for g in grid]])
        # NOTE(review): presumably a sequence of (births, deaths) arrays, one
        # per summand -- confirm against PyModule.dump.
        S = mod.dump()[1]

        def get_idx_parameter(A, G, p):
            # For each row of A: index of the value of grid axis G[p] closest
            # to A[:, p] (argmin over the grid axis).
            g = G[p].numpy() if isinstance(G[p], torch.Tensor) else np.asarray(G[p])
            la = LazyTensor(np.asarray(A, dtype=g.dtype)[None, :, [p]])
            lg = LazyTensor(g[:, None, None])
            return (la - lg).abs().argmin(0)

        # All births, clipped to the box, then snapped to integer grid indices.
        Bdump = np.concatenate([s[0] for s in S], axis=0).clip(box[[0]], box[[1]])
        B = np.concatenate(
            [get_idx_parameter(Bdump, grid, p) for p in range(mod.num_parameters)],
            axis=1,
            dtype=np.int64,
        )
        # Same for deaths (note: cast to float32 here, unlike births).
        Ddump = np.concatenate([s[1] for s in S], axis=0, dtype=np.float32).clip(
            box[[0]], box[[1]]
        )
        D = np.concatenate(
            [get_idx_parameter(Ddump, grid, p) for p in range(mod.num_parameters)],
            axis=1,
            dtype=np.int64,
        )

    # Gather the grid values at the snapped indices.
    BB = evaluate_in_grid(B, grid)
    DD = evaluate_in_grid(D, grid)

    # Split the flat births/deaths back per entry of S.
    b_idx = tuple((len(s[0]) for s in S))
    d_idx = tuple((len(s[1]) for s in S))
    BBB = BB.split_with_sizes(b_idx)
    DDD = DD.split_with_sizes(d_idx)

    # Group the entries by degree, using mod.degree_splits() as boundaries.
    splits = np.concatenate([[0], mod.degree_splits(), [len(BBB)]])
    splits = torch.from_numpy(splits)
    out = [
        list(zip(BBB[splits[i] : splits[i + 1]], DDD[splits[i] : splits[i + 1]]))
        for i in range(len(splits) - 1)
    ] ## For some reasons this kills the gradient ???? pytorch bug
    # NOTE(review): cause of the lost gradient reported above is unresolved.
    return out
@@ -0,0 +1,310 @@
1
+ from typing import Callable, Literal, Optional
2
+
3
+ import numpy as np
4
+ import torch
5
+ import gudhi as gd
6
+
7
+ import multipers as mp
8
+ from multipers.filtrations.density import DTM, KDE
9
+ from multipers.simplex_tree_multi import _available_strategies
10
+ from multipers.torch.diff_grids import get_grid
11
+
12
+
13
def function_rips_signed_measure_old(
    x,
    theta: Optional[float] = None,
    function: Literal["dtm", "gaussian", "exponential"] | Callable = "dtm",
    threshold: float = np.inf,
    grid_strategy: _available_strategies = "regular_closest",
    resolution: int = 100,
    return_original: bool = False,
    return_st: bool = False,
    safe_conversion: bool = False,
    num_collapses: int = -1,
    expand_collapse: bool = False,
    dtype=torch.float32,
    **sm_kwargs,
):
    """
    Computes a torch-differentiable function-rips signed measure
    (older variant, ``_old`` suffix).

    Input
    -----
    - x (num_pts, dim) : The point cloud (a torch tensor).
    - theta: For density-like functions : the bandwidth (mass for "dtm").
    - threshold : rips threshold; larger distances are set to +inf.
    - function : Either "dtm", "gaussian", or "exponential" or Callable.
        Function to compute the second parameter.
    - grid_strategy: grid coarsening strategy.
    - resolution : when coarsening, the target resolution,
    - return_original : Also returns the non-differentiable signed measure.
    - return_st : Also returns the grid-squeezed simplextree copy.
    - safe_conversion : Activate this if you encounter crashes.
    - num_collapses : forwarded to ``collapse_edges``.
    - expand_collapse : if True, no expansion is performed here.
    - dtype : torch dtype of the filtration values.
    - **kwargs : for the signed measure computation.

    Returns
    -------
    The list of differentiable ``(points, weights)`` signed measures alone if
    neither ``return_original`` nor ``return_st`` is set. Otherwise a tuple
    containing it, followed by the original signed measures and/or the
    squeezed simplextree copy, in that order.
    """
    assert isinstance(x, torch.Tensor)
    # Second parameter ("codensity"), computed differentiably from x.
    if function == "dtm":
        assert theta is not None, "Provide a theta to compute DTM"
        codensity = DTM(masses=[theta]).fit(x).score_samples_diff(x)[0].type(dtype)
    elif function in ["gaussian", "exponential"]:
        assert theta is not None, "Provide a theta to compute density estimation"
        codensity = (
            -KDE(
                bandwidth=theta,
                kernel=function,
                return_log=True,
            )
            .fit(x)
            .score_samples(x)
            .type(dtype)
        )
    else:
        assert callable(function), "Function has to be callable"
        if theta is None:
            codensity = function(x).type(dtype)
        else:
            codensity = function(x, theta=theta).type(dtype)

    distance_matrix = torch.cdist(x, x).type(dtype)
    if threshold < np.inf:
        # NOTE(review): in-place write on the cdist output; confirm this does
        # not conflict with autograd for the chosen dtype.
        distance_matrix[distance_matrix > threshold] = np.inf

    # st = RipsComplex(
    #     distance_matrix=distance_matrix.detach(), max_edge_length=threshold
    # ).create_simplex_tree()
    st = gd.SimplexTree.create_from_array(
        distance_matrix.detach(), max_filtration=threshold
    )
    # detach makes a new (reference) tensor, without tracking the gradient
    st = mp.SimplexTreeMulti(st, num_parameters=2, safe_conversion=safe_conversion)
    st.fill_lowerstar(
        codensity.detach(), parameter=1
    )  # fills the codensity in the second parameter of the simplextree

    # simplifies the simplextree for computation, the signed measure will be recovered from the copy afterward
    st_copy = st.grid_squeeze(
        grid_strategy=grid_strategy, resolution=resolution, coordinate_values=True
    )
    # Expansion dimension: all vertices if no degree is requested, otherwise
    # one more than the largest requested degree.
    if sm_kwargs.get("degree", None) is None and sm_kwargs.get("degrees", [None]) == [
        None
    ]:
        expansion_degree = st.num_vertices
    else:
        expansion_degree = (
            max(np.max(sm_kwargs.get("degrees", 1)), sm_kwargs.get("degree", 1)) + 1
        )
    st.collapse_edges(num=num_collapses)
    # NOTE(review): when expand_collapse is True nothing is expanded here.
    if not expand_collapse:
        st.expansion(expansion_degree)  # expansion up to expansion_degree
    sms = mp.signed_measure(st, **sm_kwargs)  # computes the signed measure
    del st

    simplices_list = tuple(
        s for s, _ in st_copy.get_simplices()
    )  # not optimal, we may want to do that in cython to get edges and nodes
    sms_diff = []
    for sm, weights in sms:
        # indices[:, i]: index into simplices_list for coordinate i, or -1
        # when not found. Even columns are read as edges, odd ones as vertices
        # (see the two loops below).
        indices, not_found_indices = st_copy.pts_to_indices(
            sm, simplices_dimensions=[1, 0]
        )
        if sm_kwargs.get("verbose", False):
            print(
                f"Found {(1-(indices == -1).mean()).round(2)} indices. \
Out : {(indices == -1).sum()}, {len(not_found_indices)}"
            )
        sm_diff = torch.empty(sm.shape).type(dtype)
        # sim_dim = sm_diff.shape[1] // 2

        # fills the Rips-filtrations of the signed measure.
        # the loop is for the rank invariant
        for i in range(0, sm_diff.shape[1], 2):
            idxs = indices[:, i]
            if (idxs == -1).all():
                continue
            useful_idxs = idxs != -1
            # Retrieves the differentiable values from the distance_matrix
            # NOTE(review): `.size` is the mask length, not its number of True
            # entries, so this test is always true once past the check above.
            if useful_idxs.size > 0:
                edges_filtrations = torch.cat(
                    [
                        distance_matrix[*simplices_list[idx], None]
                        for idx in idxs[useful_idxs]
                    ]
                )
                # fills these values into the signed measure
                sm_diff[:, i][useful_idxs] = edges_filtrations
        # same for the other axis
        for i in range(1, sm_diff.shape[1], 2):
            idxs = indices[:, i]
            if (idxs == -1).all():
                continue
            useful_idxs = idxs != -1
            if useful_idxs.size > 0:
                nodes_filtrations = torch.cat(
                    [codensity[simplices_list[idx]] for idx in idxs[useful_idxs]]
                )
                sm_diff[:, i][useful_idxs] = nodes_filtrations

        # fills not-found values as constants
        if len(not_found_indices) > 0:
            not_found_indices = indices == -1
            sm_diff[indices == -1] = torch.from_numpy(sm[indices == -1]).type(dtype)

        sms_diff.append((sm_diff, torch.from_numpy(weights)))
    flags = [True, return_original, return_st]
    if np.sum(flags) == 1:
        return sms_diff
    return tuple(stuff for stuff, flag in zip([sms_diff, sms, st_copy], flags) if flag)
156
+
157
+
158
def function_rips_signed_measure(
    x,
    theta: Optional[float] = None,
    function: Literal["dtm", "gaussian", "exponential"] | Callable = "gaussian",
    threshold: Optional[float] = None,
    grid_strategy: Literal[
        "regular_closest", "exact", "quantile", "regular_left"
    ] = "exact",
    complex: Literal["rips", "delaunay", "weak_delaunay"] = "rips",
    resolution: int = 100,
    safe_conversion: bool = False,
    num_collapses: Optional[int] = None,
    expand_collapse: bool = False,
    dtype=torch.float32,
    plot=False,
    # return_st: bool = False,
    *,
    log_density: bool = True,
    vineyard: bool = False,
    pers_backend=None,
    **sm_kwargs,
):
    """
    Computes a torch-differentiable function-rips signed measure.

    Input
    -----
    - x (num_pts, dim) : The point cloud (a torch tensor).
    - theta: For density-like functions : the bandwidth (mass for "dtm").
    - function : Either "dtm", "gaussian", "exponential", a callable
      (called as ``function(x)`` or ``function(x, theta=theta)``), or a 1d
      tensor holding the function values on the points of x.
      Function to compute the second parameter.
    - threshold : rips threshold. Defaults to the smallest over all points of
      the largest distance from that point. Only used with ``complex="rips"``.
    - grid_strategy: grid coarsening strategy (see ``get_grid``).
    - complex : "rips", or "delaunay" / "weak_delaunay" (both handled as
      "delaunay").
    - resolution : when coarsening, the target resolution.
    - safe_conversion : Activate this if you encounter crashes.
    - num_collapses : edge collapses for the rips complex (default -1).
    - expand_collapse : if True, the rips complex is not expanded here.
    - dtype : torch dtype of the filtration values.
    - plot : plot the resulting signed measures.
    - log_density : for "gaussian"/"exponential", use the log-density.
    - vineyard, pers_backend : forwarded to ``mp.Slicer``.
    - **sm_kwargs : for the signed measure computation; ``degree`` and/or
      ``degrees`` select the homological degrees.

    Returns
    -------
    Flat tuple of the signed measures returned by ``mp.signed_measure`` for
    each requested degree, concatenated in degree order.
    """
    if num_collapses is None:
        num_collapses = -1 if complex == "rips" else None
    assert isinstance(x, torch.Tensor)
    if isinstance(function, torch.Tensor):
        # Checked first: comparing a tensor to the strategy names is not meaningful.
        assert (
            function.ndim == 1 and function.shape[0] == x.shape[0]
        ), """
When function is a tensor, it is interpreted as the value of some function over x.
"""
        codensity = function.type(dtype)
    elif function == "dtm":
        assert theta is not None, "Provide a theta to compute DTM"
        codensity = DTM(masses=[theta]).fit(x).score_samples_diff(x)[0].type(dtype)
    elif function in ["gaussian", "exponential"]:
        assert theta is not None, "Provide a theta to compute density estimation"
        codensity = (
            -KDE(
                bandwidth=theta,
                kernel=function,
                return_log=log_density,
            )
            .fit(x)
            .score_samples(x)
            .type(dtype)
        )
    else:
        assert callable(function), "Function has to be callable"
        if theta is None:
            codensity = function(x).type(dtype)
        else:
            codensity = function(x, theta=theta).type(dtype)

    distance_matrix = torch.cdist(x, x).type(dtype)
    distances = distance_matrix.ravel()
    if complex == "rips":
        threshold = (
            distance_matrix.max(axis=1).values.min() if threshold is None else threshold
        )
        distances = distances[distances <= threshold]
    elif complex in ["delaunay", "weak_delaunay"]:
        complex = "delaunay"
        # Out-of-place: `distances` may be a view of the cdist output, which
        # autograd needs unmodified for the backward pass.
        distances = distances / 2
    else:
        raise ValueError(
            f"Unimplemented with complex {complex}. You can use rips or delaunay ftm."
        )

    # simplifies the simplextree for computation, the signed measure will be recovered from the copy afterward
    reduced_grid = get_grid(strategy=grid_strategy)((distances, codensity), resolution)

    degrees = sm_kwargs.pop("degrees", None)
    degrees = [] if degrees is None else list(degrees)
    # Always pop "degree" (even when None) so it is not passed twice to
    # mp.signed_measure below.
    single_degree = sm_kwargs.pop("degree", None)
    if single_degree is not None:
        degrees = [single_degree] + degrees

    if complex == "rips":
        # st = RipsComplex(
        #     distance_matrix=distance_matrix.detach(), max_edge_length=threshold
        # ).create_simplex_tree()
        st = gd.SimplexTree.create_from_array(
            distance_matrix.detach().cpu(), max_filtration=threshold
        )
        # detach makes a new (reference) tensor, without tracking the gradient
        st = mp.SimplexTreeMulti(st, num_parameters=2, safe_conversion=safe_conversion)
        st.fill_lowerstar(
            codensity.detach().cpu(), parameter=1
        )  # fills the codensity in the second parameter of the simplextree
        st = st.grid_squeeze(reduced_grid)
        st.filtration_grid = []
        if None in degrees:
            expansion_degree = st.num_vertices
        elif degrees:
            expansion_degree = max(degrees) + 1
        else:
            raise ValueError(
                "Provide a `degree` or `degrees` to compute the signed measure."
            )
        st.collapse_edges(num=num_collapses)
        if not expand_collapse:
            st.expansion(expansion_degree)  # expansion up to expansion_degree

        s = mp.Slicer(st, vineyard=vineyard, backend=pers_backend)
    elif complex == "delaunay":
        s = mp.slicer.from_function_delaunay(
            x.detach().cpu().numpy(), codensity.detach().cpu().numpy()
        )
        st = mp.slicer.to_simplextree(s)
        st.flagify(2)
        s = mp.Slicer(st, vineyard=vineyard, backend=pers_backend).grid_squeeze(
            reduced_grid
        )

    s.filtration_grid = []  ## To enforce minpres to be reasonable
    if None not in degrees:
        s = s.minpres(degrees=degrees)
    else:
        from joblib import Parallel, delayed

        # Bind the slicer explicitly instead of relying on the late-bound `s`.
        slicer = s
        s = tuple(
            Parallel(n_jobs=-1, backend="threading")(
                delayed(lambda d: slicer if d is None else slicer.minpres(degree=d))(d)
                for d in degrees
            )
        )
    ## fix previous hack
    for stuff in s:
        # stuff.filtration_grid = reduced_grid ## not necessary
        stuff.filtration_grid = [[1]] * stuff.num_parameters

    sms = tuple(
        sm
        for slicer_of_degree, degree in zip(s, degrees)
        for sm in mp.signed_measure(
            slicer_of_degree, grid=reduced_grid, degree=degree, **sm_kwargs
        )
    )  # computes the signed measure
    if plot:
        mp.plots.plot_signed_measures(
            tuple((sm.detach().cpu().numpy(), w.detach().cpu().numpy()) for sm, w in sms)
        )
    return sms
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 David Loiseaux
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.