multipers-2.3.3b6-cp311-cp311-macosx_10_13_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of multipers has been flagged as potentially problematic.

Files changed (183)
  1. multipers/.dylibs/libc++.1.0.dylib +0 -0
  2. multipers/.dylibs/libtbb.12.16.dylib +0 -0
  3. multipers/__init__.py +33 -0
  4. multipers/_signed_measure_meta.py +453 -0
  5. multipers/_slicer_meta.py +211 -0
  6. multipers/array_api/__init__.py +45 -0
  7. multipers/array_api/numpy.py +41 -0
  8. multipers/array_api/torch.py +58 -0
  9. multipers/data/MOL2.py +458 -0
  10. multipers/data/UCR.py +18 -0
  11. multipers/data/__init__.py +1 -0
  12. multipers/data/graphs.py +466 -0
  13. multipers/data/immuno_regions.py +27 -0
  14. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  15. multipers/data/pytorch2simplextree.py +91 -0
  16. multipers/data/shape3d.py +101 -0
  17. multipers/data/synthetic.py +113 -0
  18. multipers/distances.py +202 -0
  19. multipers/filtration_conversions.pxd +229 -0
  20. multipers/filtration_conversions.pxd.tp +84 -0
  21. multipers/filtrations/__init__.py +18 -0
  22. multipers/filtrations/density.py +574 -0
  23. multipers/filtrations/filtrations.py +361 -0
  24. multipers/filtrations.pxd +224 -0
  25. multipers/function_rips.cpython-311-darwin.so +0 -0
  26. multipers/function_rips.pyx +105 -0
  27. multipers/grids.cpython-311-darwin.so +0 -0
  28. multipers/grids.pyx +433 -0
  29. multipers/gudhi/Persistence_slices_interface.h +132 -0
  30. multipers/gudhi/Simplex_tree_interface.h +239 -0
  31. multipers/gudhi/Simplex_tree_multi_interface.h +551 -0
  32. multipers/gudhi/cubical_to_boundary.h +59 -0
  33. multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -0
  34. multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -0
  35. multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -0
  36. multipers/gudhi/gudhi/Debug_utils.h +45 -0
  37. multipers/gudhi/gudhi/Fields/Multi_field.h +484 -0
  38. multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -0
  39. multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -0
  40. multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -0
  41. multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -0
  42. multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -0
  43. multipers/gudhi/gudhi/Fields/Z2_field.h +355 -0
  44. multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -0
  45. multipers/gudhi/gudhi/Fields/Zp_field.h +420 -0
  46. multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -0
  47. multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -0
  48. multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -0
  49. multipers/gudhi/gudhi/Matrix.h +2107 -0
  50. multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -0
  51. multipers/gudhi/gudhi/Multi_persistence/Box.h +174 -0
  52. multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -0
  53. multipers/gudhi/gudhi/Off_reader.h +173 -0
  54. multipers/gudhi/gudhi/One_critical_filtration.h +1441 -0
  55. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -0
  56. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -0
  57. multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -0
  58. multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -0
  59. multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -0
  60. multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -0
  61. multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -0
  62. multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -0
  63. multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -0
  64. multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -0
  65. multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -0
  66. multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -0
  67. multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -0
  68. multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -0
  69. multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -0
  70. multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -0
  71. multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -0
  72. multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -0
  73. multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -0
  74. multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -0
  75. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -0
  76. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -0
  77. multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -0
  78. multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -0
  79. multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -0
  80. multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -0
  81. multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -0
  82. multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -0
  83. multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -0
  84. multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -0
  85. multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -0
  86. multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -0
  87. multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -0
  88. multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -0
  89. multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -0
  90. multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -0
  91. multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -0
  92. multipers/gudhi/gudhi/Persistent_cohomology.h +745 -0
  93. multipers/gudhi/gudhi/Points_off_io.h +171 -0
  94. multipers/gudhi/gudhi/Simple_object_pool.h +69 -0
  95. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -0
  96. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -0
  97. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -0
  98. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -0
  99. multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -0
  100. multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -0
  101. multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -0
  102. multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -0
  103. multipers/gudhi/gudhi/Simplex_tree.h +2794 -0
  104. multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -0
  105. multipers/gudhi/gudhi/distance_functions.h +62 -0
  106. multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -0
  107. multipers/gudhi/gudhi/persistence_interval.h +253 -0
  108. multipers/gudhi/gudhi/persistence_matrix_options.h +170 -0
  109. multipers/gudhi/gudhi/reader_utils.h +367 -0
  110. multipers/gudhi/mma_interface_coh.h +256 -0
  111. multipers/gudhi/mma_interface_h0.h +223 -0
  112. multipers/gudhi/mma_interface_matrix.h +293 -0
  113. multipers/gudhi/naive_merge_tree.h +536 -0
  114. multipers/gudhi/scc_io.h +310 -0
  115. multipers/gudhi/truc.h +1403 -0
  116. multipers/io.cpython-311-darwin.so +0 -0
  117. multipers/io.pyx +644 -0
  118. multipers/ml/__init__.py +0 -0
  119. multipers/ml/accuracies.py +90 -0
  120. multipers/ml/invariants_with_persistable.py +79 -0
  121. multipers/ml/kernels.py +176 -0
  122. multipers/ml/mma.py +713 -0
  123. multipers/ml/one.py +472 -0
  124. multipers/ml/point_clouds.py +352 -0
  125. multipers/ml/signed_measures.py +1589 -0
  126. multipers/ml/sliced_wasserstein.py +461 -0
  127. multipers/ml/tools.py +113 -0
  128. multipers/mma_structures.cpython-311-darwin.so +0 -0
  129. multipers/mma_structures.pxd +128 -0
  130. multipers/mma_structures.pyx +2786 -0
  131. multipers/mma_structures.pyx.tp +1094 -0
  132. multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -0
  133. multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -0
  134. multipers/multi_parameter_rank_invariant/function_rips.h +322 -0
  135. multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -0
  136. multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -0
  137. multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -0
  138. multipers/multiparameter_edge_collapse.py +41 -0
  139. multipers/multiparameter_module_approximation/approximation.h +2330 -0
  140. multipers/multiparameter_module_approximation/combinatory.h +129 -0
  141. multipers/multiparameter_module_approximation/debug.h +107 -0
  142. multipers/multiparameter_module_approximation/euler_curves.h +0 -0
  143. multipers/multiparameter_module_approximation/format_python-cpp.h +286 -0
  144. multipers/multiparameter_module_approximation/heap_column.h +238 -0
  145. multipers/multiparameter_module_approximation/images.h +79 -0
  146. multipers/multiparameter_module_approximation/list_column.h +174 -0
  147. multipers/multiparameter_module_approximation/list_column_2.h +232 -0
  148. multipers/multiparameter_module_approximation/ru_matrix.h +347 -0
  149. multipers/multiparameter_module_approximation/set_column.h +135 -0
  150. multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -0
  151. multipers/multiparameter_module_approximation/unordered_set_column.h +166 -0
  152. multipers/multiparameter_module_approximation/utilities.h +403 -0
  153. multipers/multiparameter_module_approximation/vector_column.h +223 -0
  154. multipers/multiparameter_module_approximation/vector_matrix.h +331 -0
  155. multipers/multiparameter_module_approximation/vineyards.h +464 -0
  156. multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -0
  157. multipers/multiparameter_module_approximation.cpython-311-darwin.so +0 -0
  158. multipers/multiparameter_module_approximation.pyx +235 -0
  159. multipers/pickle.py +90 -0
  160. multipers/plots.py +456 -0
  161. multipers/point_measure.cpython-311-darwin.so +0 -0
  162. multipers/point_measure.pyx +395 -0
  163. multipers/simplex_tree_multi.cpython-311-darwin.so +0 -0
  164. multipers/simplex_tree_multi.pxd +134 -0
  165. multipers/simplex_tree_multi.pyx +10840 -0
  166. multipers/simplex_tree_multi.pyx.tp +2009 -0
  167. multipers/slicer.cpython-311-darwin.so +0 -0
  168. multipers/slicer.pxd +3034 -0
  169. multipers/slicer.pxd.tp +234 -0
  170. multipers/slicer.pyx +20481 -0
  171. multipers/slicer.pyx.tp +1088 -0
  172. multipers/tensor/tensor.h +672 -0
  173. multipers/tensor.pxd +13 -0
  174. multipers/test.pyx +44 -0
  175. multipers/tests/__init__.py +62 -0
  176. multipers/torch/__init__.py +1 -0
  177. multipers/torch/diff_grids.py +240 -0
  178. multipers/torch/rips_density.py +310 -0
  179. multipers-2.3.3b6.dist-info/METADATA +128 -0
  180. multipers-2.3.3b6.dist-info/RECORD +183 -0
  181. multipers-2.3.3b6.dist-info/WHEEL +6 -0
  182. multipers-2.3.3b6.dist-info/licenses/LICENSE +21 -0
  183. multipers-2.3.3b6.dist-info/top_level.txt +1 -0
multipers/array_api/__init__.py ADDED
@@ -0,0 +1,45 @@
+ import multipers.array_api.numpy as npapi
+
+
+ def api_from_tensor(x, *, verbose: bool = False):
+     if npapi.is_promotable(x):
+         if verbose:
+             print("using numpy backend")
+         return npapi
+     import multipers.array_api.torch as torchapi
+
+     if torchapi.is_promotable(x):
+         if verbose:
+             print("using torch backend")
+         return torchapi
+     raise ValueError(f"Unsupported type {type(x)=}")
+
+
+ def api_from_tensors(*args):
+     assert len(args) > 0, "no tensor given"
+     import multipers.array_api.numpy as npapi
+
+     is_numpy = True
+     for x in args:
+         if not npapi.is_promotable(x):
+             is_numpy = False
+             break
+     if is_numpy:
+         return npapi
+
+     # only torch for now
+     import multipers.array_api.torch as torchapi
+
+     is_torch = True
+     for x in args:
+         if not torchapi.is_promotable(x):
+             is_torch = False
+             break
+     if is_torch:
+         return torchapi
+     raise ValueError(f"Incompatible types got {[type(x) for x in args]=}.")
+
+
+ def to_numpy(x):
+     api = api_from_tensor(x)
+     return api.asnumpy(x)
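The helpers above give a single dispatch point over the NumPy and torch backends defined in the next two files. A minimal usage sketch of my own (not from the package docs), assuming both numpy and torch are installed:

import numpy as np
import torch
from multipers.array_api import api_from_tensor, api_from_tensors, to_numpy

# ndarrays, lists and tuples resolve to the numpy backend ...
api = api_from_tensor(np.random.rand(10, 2), verbose=True)  # prints "using numpy backend"

# ... while torch tensors resolve to the torch backend.
x = torch.rand(10, 2, requires_grad=True)
api = api_from_tensor(x, verbose=True)  # prints "using torch backend"

# api_from_tensors requires every argument to be handled by a single backend,
# otherwise it raises a ValueError.
api = api_from_tensors(x, torch.ones(3))

# to_numpy converts back to an ndarray regardless of the input backend.
y = to_numpy(x)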
multipers/array_api/numpy.py ADDED
@@ -0,0 +1,41 @@
+ from contextlib import nullcontext
+
+ import numpy as _np
+ from scipy.spatial.distance import cdist
+
+ backend = _np
+ cat = _np.concatenate
+ norm = _np.linalg.norm
+ astensor = _np.asarray
+ asnumpy = _np.asarray
+ tensor = _np.array
+ stack = _np.stack
+ empty = _np.empty
+ where = _np.where
+ no_grad = nullcontext
+ zeros = _np.zeros
+ min = _np.min
+ max = _np.max
+ repeat_interleave = _np.repeat
+ cdist = cdist  # type: ignore[no-redef]
+ unique = _np.unique
+
+
+ def quantile_closest(x, q, axis=None):
+     return _np.quantile(x, q, axis=axis, interpolation="closest_observation")
+
+
+ def minvalues(x: _np.ndarray, **kwargs):
+     return _np.min(x, **kwargs)
+
+
+ def maxvalues(x: _np.ndarray, **kwargs):
+     return _np.max(x, **kwargs)
+
+
+ def is_promotable(x):
+     return isinstance(x, _np.ndarray | list | tuple)
+
+
+ def has_grad(_):
+     return False
multipers/array_api/torch.py ADDED
@@ -0,0 +1,58 @@
+ import numpy as _np
+ import torch as _t
+
+ backend = _t
+ cat = _t.cat
+ norm = _t.norm
+ astensor = _t.as_tensor
+ tensor = _t.tensor
+ stack = _t.stack
+ empty = _t.empty
+ where = _t.where
+ no_grad = _t.no_grad
+ cdist = _t.cdist
+ zeros = _t.zeros
+ min = _t.min
+ max = _t.max
+ repeat_interleave = _t.repeat_interleave
+
+
+ # in our context, this allows to get a correct gradient.
+ def unique(x, assume_sorted=False, _mean=True):
+     if not x.requires_grad:
+         return x.unique(sorted=assume_sorted)
+     if x.ndim != 1:
+         raise ValueError(f"Got ndim!=1. {x=}")
+     if not assume_sorted:
+         x = x.sort().values
+     _, c = _t.unique(x, sorted=True, return_counts=True)
+     if _mean:
+         x = _t.segment_reduce(data=x, reduce="mean", lengths=c, unsafe=True, axis=0)
+         return x
+
+     c = _np.concatenate([[0], _np.cumsum(c[:-1])])
+     return x[c]
+
+
+ def quantile_closest(x, q, axis=None):
+     return _t.quantile(x, q, dim=axis, interpolation="nearest")
+
+
+ def minvalues(x: _t.Tensor, **kwargs):
+     return _t.min(x, **kwargs).values
+
+
+ def maxvalues(x: _t.Tensor, **kwargs):
+     return _t.max(x, **kwargs).values
+
+
+ def asnumpy(x):
+     return x.detach().numpy()
+
+
+ def is_promotable(x):
+     return isinstance(x, _t.Tensor)
+
+
+ def has_grad(x):
+     return x.requires_grad
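The torch backend's unique is the one non-trivial piece: plain torch.unique does not propagate gradients, so when the input requires grad the duplicates are merged by sorting and segment-averaging, which keeps the result attached to the autograd graph. A short sketch of my own (not from the package tests) of how it is meant to be used:

import torch
import multipers.array_api.torch as torchapi

x = torch.tensor([0.0, 0.0, 1.0, 2.0, 2.0], requires_grad=True)

# With requires_grad set, duplicates are averaged via segment_reduce, so the
# de-duplicated values still backpropagate to x through the sort and the mean.
u = torchapi.unique(x)  # tensor([0., 1., 2.]) with a grad_fn
u.sum().backward()
print(x.grad)  # duplicates share their group's gradient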
multipers/data/MOL2.py ADDED
@@ -0,0 +1,458 @@
+ import os
+ from os import listdir
+ from os.path import expanduser
+ from typing import Iterable
+
+ import matplotlib.pyplot as plt
+ import MDAnalysis as mda
+ import numpy as np
+ import pandas as pd
+ from joblib import Parallel, delayed
+ from MDAnalysis.topology.guessers import guess_masses
+ from sklearn.base import BaseEstimator, TransformerMixin
+ from sklearn.preprocessing import LabelEncoder
+
+ # from numba import njit
+ from tqdm import tqdm
+
+ import multipers as mp
+
+ DATASET_PATH = expanduser("~/Datasets/")
+ JC_path = DATASET_PATH + "Cleves-Jain/"
+ DUDE_path = DATASET_PATH + "DUD-E/"
+
+
+ # pathes = get_data_path()
+ # imgs = apply_pipeline(pathes=pathes, pipeline=pipeline_img)
+ # distances_to_letter, ytest = img_distances(imgs)
+
+
+ def _get_mols_in_path(folder):
+     with open(folder + "/TargetList", "r") as f:
+         train_data = [folder + "/" + mol.strip() for mol in f.readlines()]
+     criterion = (
+         lambda dataset: dataset.endswith(".mol2")
+         and not dataset.startswith("final")
+         and dataset not in train_data
+     )
+     test_data = [
+         folder + "/" + dataset
+         for dataset in listdir(folder)
+         if criterion(folder + "/" + dataset)
+     ]
+     return train_data, test_data
+
+
+ def get_data_path_JC(type="dict"):
+     if type == "dict":
+         out = {}
+     elif type == "list":
+         out = []
+     else:
+         raise TypeError(f"Type {out} not supported")
+     for stuff in listdir(JC_path):
+         if stuff.startswith("target_"):
+             current_letter = stuff[-1]
+             to_add = _get_mols_in_path(JC_path + stuff)
+             if type == "dict":
+                 out[current_letter] = to_add
+             elif type == "list":
+                 out.append(to_add)
+     decoy_folder = JC_path + "RognanRing850/"
+     to_add = [
+         decoy_folder + mol for mol in listdir(decoy_folder) if mol.endswith(".mol2")
+     ]
+     if type == "dict":
+         out["decoy"] = to_add
+     elif type == "list":
+         out.append(to_add)
+     return out
+
+
+ def get_all_JC_path():
+     out = []
+     for stuff in listdir(JC_path):
+         if stuff.startswith("target_"):
+             train_data, test_data = _get_mols_in_path(JC_path + stuff)
+             out += train_data
+             out += test_data
+     decoy_folder = JC_path + "RognanRing850/"
+     out += [
+         decoy_folder + mol for mol in listdir(decoy_folder) if mol.endswith(".mol2")
+     ]
+     return out
+
+
+ def split_multimol(
+     path: str,
+     mol_name: str,
+     out_folder_name: str = "splitted",
+     enforce_charges: bool = False,
+ ):
+     with open(path + mol_name, "r") as f:
+         lines = f.readlines()
+     splitted_mols = []
+     index = 0
+     for i, line in enumerate(lines):
+         is_last = i == len(lines) - 1
+         if line.strip() == "@<TRIPOS>MOLECULE" or is_last:
+             if i != index:
+                 molecule = "".join(lines[index : i + is_last])
+                 if enforce_charges:
+                     # print(f"Replaced molecule {i}")
+                     molecule = molecule.replace("NO_CHARGES", "USER_CHARGES")
+                 # print(molecule)
+                 # return
+                 index = i
+                 splitted_mols.append(molecule)
+     if not os.path.exists(path + out_folder_name):
+         os.mkdir(path + out_folder_name)
+     for i, mol in enumerate(splitted_mols):
+         with open(path + out_folder_name + f"/{i}.mol2", "w") as f:
+             f.write(mol)
+     return [path + out_folder_name + f"/{i}.mol2" for i in range(len(splitted_mols))]
+
+
+ # @njit(parallel=True)
+ def apply_pipeline(pathes: dict, pipeline):
+     img_dict = {}
+     for key, value in tqdm(pathes.items(), desc="Applying pipeline"):
+         if len(key) == 1:
+             train_paths, test_paths = value
+             train_imgs = pipeline.transform(train_paths)
+             test_imgs = pipeline.transform(test_paths)
+             img_dict[key] = (train_imgs, test_imgs)
+         else:
+             assert key == "decoy"
+             img_dict[key] = pipeline.transform(value)
+     return img_dict
+
+
+ from sklearn.metrics import pairwise_distances
+
+
+ def img_distances(img_dict: dict):
+     distances_to_anchors = []
+     ytest = []
+     decoy_list = img_dict["decoy"]
+     for letter, imgs in img_dict.items():
+         if len(letter) != 1:
+             continue  # decoy
+         xtrain, xtest = imgs
+         assert len(xtest) > 0
+         train_data, test_data = xtrain, np.concatenate([xtest, decoy_list])
+         D = pairwise_distances(train_data, test_data)
+         distances_to_anchors.append(D)
+         letter_ytest = np.array(
+             [letter] * len(xtest) + ["0"] * len(decoy_list), dtype="<U1"
+         )
+         ytest.append(letter_ytest)
+     return distances_to_anchors, ytest
+
+
+ def get_EF_vector_from_distances(distances, ytest, alpha=0.05):
+     EF = []
+     for distance_to_anchors, letter_ytest in zip(distances, ytest):
+         indices = np.argsort(distance_to_anchors, axis=1)
+         n = indices.shape[1]
+         n_max = int(alpha * n)
+         good_indices = (
+             letter_ytest[indices[:, :n_max]] == letter_ytest[0]
+         )  ## assumes that ytest[:,0] are the good letters
+         EF_letter = good_indices.sum(axis=1) / (letter_ytest == letter_ytest[0]).sum()
+         EF_letter /= alpha
+         EF.append(EF_letter.mean())
+     return np.mean(EF)
+
+
+ def EF_from_distance_matrix(
+     distances: np.ndarray, labels: list | np.ndarray, alpha: float, anchors_in_test=True
+ ):
+     """
+     Computes the Enrichment Factor from a distance matrix, and its labels.
+     - First axis of the distance matrix is the anchors on which to compute the EF
+     - Second axis is the test. For convenience, anchors can be put in test, if the flag anchors_in_test is set to true.
+     - labels is a table of bools, representing the the labels of the test axis of the distance matrix.
+     - alpha : the EF alpha parameter.
+     """
+     n = len(labels)
+     n_max = int(alpha * n)
+     indices = np.argsort(distances, axis=1)
+     EF_ = [
+         ((labels[idx[:n_max]]).sum() - anchors_in_test)
+         / (labels.sum() - anchors_in_test)
+         for idx in indices
+     ]
+     return np.mean(EF_) / alpha
+
+
+ def EF_AUC(distances: np.ndarray, labels: np.ndarray, anchors_in_test=0):
+     if distances.ndim == 1:
+         distances = distances[None, :]
+     assert distances.ndim == 2
+     indices = np.argsort(distances, axis=1)
+     out = []
+     for i in range(1, distances.size):
+         proportion_of_good_indices = (
+             labels[indices[:, :i]].sum(axis=1).mean() - anchors_in_test
+         ) / min(i, labels.sum() - anchors_in_test)
+         out.append(proportion_of_good_indices)
+     # print(out)
+     return np.mean(out)
+
+
+ def theorical_max_EF(distances, labels, alpha):
+     n = len(labels)
+     n_max = int(alpha * n)
+     num_true_labels = np.sum(
+         labels == labels[0]
+     )  ## if labels are not True / False, assumes that the first one is a good one
+     return min(n_max, num_true_labels) / alpha
+
+
+ def theorical_max_EF_from_distances(list_of_distances, list_of_labels, alpha):
+     return np.mean(
+         [
+             theorical_max_EF(distances, labels, alpha)
+             for distances, labels in zip(list_of_distances, list_of_labels)
+         ]
+     )
+
+
+ def plot_EF_from_distances(
+     alphas=[0.01, 0.02, 0.05, 0.1], EF=EF_from_distance_matrix, plot: bool = True
+ ):
+     y = np.round([EF(alpha=alpha) for alpha in alphas], decimals=2)
+     if plot:
+         _alphas = np.linspace(0.01, 1.0, 100)
+         plt.figure()
+         plt.plot(_alphas, [EF(alpha=alpha) for alpha in _alphas])
+         plt.scatter(alphas, y, c="r")
+         plt.title("Enrichment Factor")
+         plt.xlabel(r"$\alpha$" + f" = {alphas}")
+         plt.ylabel(r"$\mathrm{EF}_\alpha$" + f" = {y}")
+     return y
+
+
+ def lines2bonds(
+     mol: mda.Universe, bond_types=["ar", "am", 3, 2, 1, 0], molecule_format=None
+ ):
+     extension = (
+         mol.filename.split(".")[-1].lower()
+         if molecule_format is None
+         else molecule_format
+     )
+     match extension:
+         case "mol2":
+             out = lines2bonds_MOL2(mol)["bond_type"]
+         case "pdb":
+             out = lines2bonds_PDB(mol)
+         case _:
+             raise Exception("Invalid, or not supported molecule format.")
+     return LabelEncoder().fit(bond_types).transform(out)
+
+
+ def lines2bonds_MOL2(mol: mda.Universe):
+     _lines = open(mol.filename, "r").readlines()
+     out = []
+     index = 0
+     while index < len(_lines) and _lines[index].strip() != "@<TRIPOS>BOND":
+         index += 1
+     index += 1
+     while index < len(_lines) and _lines[index].strip()[0] != "@":
+         line = _lines[index].strip().split(" ")
+         for j, truc in enumerate(line):
+             line[j] = truc.strip()
+         # try:
+         out.append([stuff for stuff in line if len(stuff) > 0])
+         # except:
+         #     print_lin
+         index += 1
+     out = pd.DataFrame(out, columns=["bond_id", "atom1", "atom2", "bond_type"])
+     out.set_index(["bond_id"], inplace=True)
+     return out
+
+
+ def lines2bonds_PDB(mol: mda.Universe):
+     raise Exception("Not yet implemented.")
+     return
+
+
+ def _mol2graphst(
+     path: str | mda.Universe, filtrations: Iterable[str], molecule_format=None
+ ):
+     molecule = path if isinstance(path, mda.Universe) else mda.Universe(path)
+
+     num_filtrations = len(filtrations)
+     nodes = molecule.atoms.indices.reshape(1, -1)
+     edges = molecule.bonds.dump_contents().T
+     num_vertices = nodes.shape[1]
+     num_edges = edges.shape[1]
+
+     st = mp.SimplexTreeMulti(num_parameters=num_filtrations)
+
+     ## Edges filtration
+     # edges = np.array(bonds_df[["atom1", "atom2"]]).T
+     edges_filtration = np.zeros((num_edges, num_filtrations), dtype=np.float32) - np.inf
+     for i, filtration in enumerate(filtrations):
+         match filtration:
+             case "bond_length":
+                 bond_lengths = molecule.bonds.bonds()
+                 edges_filtration[:, i] = bond_lengths
+             case "bond_type":
+                 bond_types = lines2bonds(mol=molecule, molecule_format=molecule_format)
+                 edges_filtration[:, i] = bond_types
+             case _:
+                 pass
+
+     ## Nodes filtration
+     nodes_filtrations = np.zeros(
+         (num_vertices, num_filtrations), dtype=np.float32
+     ) + np.min(
+         edges_filtration, axis=0
+     )  # better than - np.inf
+     st.insert_batch(nodes, nodes_filtrations)
+
+     st.insert_batch(edges, edges_filtration)
+     for i, filtration in enumerate(filtrations):
+         match filtration:
+             case "charge":
+                 charges = molecule.atoms.charges
+                 st.fill_lowerstar(charges, parameter=i)
+             case "atomic_mass":
+                 masses = molecule.atoms.masses
+                 null_indices = masses == 0
+                 if np.any(null_indices):  # guess if necessary
+                     masses[null_indices] = guess_masses(molecule.atoms.types)[
+                         null_indices
+                     ]
+                 st.fill_lowerstar(-masses, parameter=i)
+             case _:
+                 pass
+     st.make_filtration_non_decreasing()  # Necessary ?
+     return st
+
+
+ def _mol2ripsst(
+     path: str,
+     filtrations: Iterable[str],
+     threshold=np.inf,
+     bond_types: list = ["ar", "am", 3, 2, 1, 0],
+ ):
+     import gudhi as gd
+
+     assert "bond_length" == filtrations[0], "Bond length has to be first for rips."
+     molecule = path if isinstance(path, mda.Universe) else mda.Universe(path)
+     num_parameters = len(filtrations)
+     st_rips = gd.RipsComplex(
+         points=molecule.atoms.positions, max_edge_length=threshold
+     ).create_simplex_tree()
+     st = mp.SimplexTreeMulti(
+         st_rips,
+         num_parameters=num_parameters,
+         default_values=[
+             bond_types.index(0) if f == "bond_type" else -np.inf
+             for f in filtrations[1:]
+         ],  # the 0 index is the label of 'no bond' in bond_types
+     )
+
+     ## Edges filtration
+     mol_bonds = molecule.bonds.indices.T
+     edges_filtration = (
+         np.zeros((mol_bonds.shape[1], num_parameters), dtype=np.float32) - np.inf
+     )
+     for i, filtration in enumerate(filtrations):
+         match filtration:
+             case "bond_type":
+                 edges_filtration[:, i] = lines2bonds(
+                     mol=molecule, bond_types=bond_types
+                 )
+             case "atomic_mass":
+                 continue
+             case "charge":
+                 continue
+             case "bond_length":
+                 edges_filtration[:, i] = [st_rips.filtration(s) for s in mol_bonds.T]
+             case _:
+                 raise Exception(
+                     f"Invalid filtration {filtration}. Available ones : bond_type, atomic_mass, charge, bond_length."
+                 )
+     st.assign_batch_filtration(mol_bonds, edges_filtration, propagate=False)
+     min_filtration = edges_filtration.min(axis=0)
+     st.assign_batch_filtration(
+         np.asarray([list(range(st.num_vertices))], dtype=int),
+         np.asarray([min_filtration] * st.num_vertices, dtype=np.float32),
+         propagate=False,
+     )
+     ## Nodes filtration
+     for i, filtration in enumerate(filtrations):
+         match filtration:
+             case "charge":
+                 charges = molecule.atoms.charges
+                 st.fill_lowerstar(charges, parameter=i)
+             case "atomic_mass":
+                 masses = molecule.atoms.masses
+                 null_indices = masses == 0
+                 if np.any(null_indices):  # guess if necessary
+                     masses[null_indices] = guess_masses(molecule.atoms.types)[
+                         null_indices
+                     ]
+                 # print(masses)
+                 st.fill_lowerstar(-masses, parameter=i)
+             case _:
+                 pass
+     st.make_filtration_non_decreasing()  # Necessary ?
+     return st
+
+
+ class Molecule2SimplexTree(BaseEstimator, TransformerMixin):
+     """
+     Transforms a list of MDA-compatible files into a list of mulitparameter simplextrees
+
+     Input
+     -----
+     X: Iterable[path_to_files:str]
+
+     Output
+     ------
+     Iterable[multipers.SimplexTreeMulti]
+
+     Parameters
+     ----------
+     - filtrations : list of filtration names. Available ones : 'charge', 'atomic_mass', 'bond_length', 'bond_type'. Others are ignored.
+     - graph : bool. If true, will use the graph given by the molecule, otherwise, a Rips Complex Based on the distance. '
+       In that case bond_length is ignored (it's the 1rst parameter).
+     """
+
+     def __init__(
+         self,
+         delayed: bool = False,
+         filtrations: Iterable[str] = [],
+         graph: bool = True,
+         n_jobs: int = 1,
+     ) -> None:
+         super().__init__()
+         self.delayed = delayed
+         self.n_jobs = n_jobs
+         self.filtrations = filtrations
+         self.graph = graph
+         self._molecule_format = None
+         return
+
+     def fit(self, X: Iterable[str], y=None):
+         if len(X) == 0:
+             return self
+         test_mol = mda.Universe(X[0])
+         self._molecule_format = test_mol.filename.split(".")[-1].lower()
+         return self
+
+     def transform(self, X: Iterable[str]):
+         _to_simplextree = _mol2graphst if self.graph else _mol2ripsst
+         to_simplex_tree = lambda path_to_mol2_file: [
+             _to_simplextree(path=path_to_mol2_file, filtrations=self.filtrations)
+         ]
+         if self.delayed:
+             return [delayed(to_simplex_tree)(path) for path in X]
+         return Parallel(n_jobs=self.n_jobs, prefer="threads")(
+             delayed(to_simplex_tree)(path) for path in X
+         )
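MOL2.py wires MDAnalysis molecule readers into multipers: Molecule2SimplexTree turns .mol2 files into multi-parameter simplex trees, and the EF helpers score a retrieval experiment from a distance matrix. A hypothetical end-to-end sketch of my own (the file names and the random distance matrix are placeholders; it assumes MDAnalysis and its dependencies are installed):

import numpy as np
from multipers.data.MOL2 import Molecule2SimplexTree, EF_from_distance_matrix

# Placeholder paths; with the Cleves-Jain layout they would come from get_data_path_JC().
paths = ["ligand_0.mol2", "ligand_1.mol2"]

m2st = Molecule2SimplexTree(filtrations=["bond_length", "charge"], graph=True, n_jobs=1)
simplex_trees = m2st.fit_transform(paths)  # list of one-element lists of SimplexTreeMulti

# Enrichment factor at level alpha, from an (anchors x test) distance matrix
# and boolean labels marking the actives among the test items.
distances = np.random.rand(2, 100)  # placeholder distances
labels = np.zeros(100, dtype=bool)
labels[:10] = True
ef = EF_from_distance_matrix(distances, labels, alpha=0.1, anchors_in_test=False)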
multipers/data/UCR.py ADDED
@@ -0,0 +1,18 @@
+ import numpy as np
+ from os.path import expanduser
+ import pandas as pd
+ from sklearn.preprocessing import LabelEncoder
+
+ def get(dataset:str="UCR/Coffee", test:bool=False, DATASET_PATH:str=expanduser("~/Datasets/"), dim=3,delay=1,skip=1):
+     from gudhi.point_cloud.timedelay import TimeDelayEmbedding
+     dataset_path = DATASET_PATH + dataset + "/" + dataset[4:]
+     dataset_path += "_TEST.tsv" if test else "_TRAIN.tsv"
+     data = np.array(pd.read_csv(dataset_path, delimiter='\t', header=None, index_col=None))
+     Y = LabelEncoder().fit_transform(data[:,0])
+     data = data[:,1:]
+     tde = TimeDelayEmbedding(dim=dim, delay=delay, skip=skip).transform(data)
+     return tde, Y
+ def get_train(*args, **kwargs):
+     return get(*args, **kwargs, test=False)
+ def get_test(*args, **kwargs):
+     return get(*args, **kwargs, test=True)
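UCR.py reads a UCR time-series dataset and converts each series into a point cloud with gudhi's time-delay embedding. A usage sketch of my own, assuming the UCR archive is unpacked under ~/Datasets/UCR/ (the default DATASET_PATH):

from multipers.data import UCR

# Each row of the .tsv is a label followed by the time series; every series
# becomes a point cloud via Takens' time-delay embedding.
point_clouds, labels = UCR.get_train(dataset="UCR/Coffee", dim=3, delay=1, skip=1)
test_clouds, test_labels = UCR.get_test(dataset="UCR/Coffee")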
multipers/data/__init__.py ADDED
@@ -0,0 +1 @@
+ from .synthetic import *