multipers 2.2.3__cp310-cp310-win_amd64.whl → 2.3.0__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic. Click here for more details.

Files changed (182) hide show
  1. multipers/__init__.py +33 -31
  2. multipers/_signed_measure_meta.py +430 -430
  3. multipers/_slicer_meta.py +211 -212
  4. multipers/data/MOL2.py +458 -458
  5. multipers/data/UCR.py +18 -18
  6. multipers/data/graphs.py +466 -466
  7. multipers/data/immuno_regions.py +27 -27
  8. multipers/data/pytorch2simplextree.py +90 -90
  9. multipers/data/shape3d.py +101 -101
  10. multipers/data/synthetic.py +113 -111
  11. multipers/distances.py +198 -198
  12. multipers/filtration_conversions.pxd.tp +84 -84
  13. multipers/filtrations/__init__.py +18 -0
  14. multipers/filtrations/filtrations.py +289 -0
  15. multipers/filtrations.pxd +224 -224
  16. multipers/function_rips.cp310-win_amd64.pyd +0 -0
  17. multipers/function_rips.pyx +105 -105
  18. multipers/grids.cp310-win_amd64.pyd +0 -0
  19. multipers/grids.pyx +350 -350
  20. multipers/gudhi/Persistence_slices_interface.h +132 -132
  21. multipers/gudhi/Simplex_tree_interface.h +239 -245
  22. multipers/gudhi/Simplex_tree_multi_interface.h +516 -561
  23. multipers/gudhi/cubical_to_boundary.h +59 -59
  24. multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -450
  25. multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -1070
  26. multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -579
  27. multipers/gudhi/gudhi/Debug_utils.h +45 -45
  28. multipers/gudhi/gudhi/Fields/Multi_field.h +484 -484
  29. multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -455
  30. multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -450
  31. multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -531
  32. multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -507
  33. multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -531
  34. multipers/gudhi/gudhi/Fields/Z2_field.h +355 -355
  35. multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -376
  36. multipers/gudhi/gudhi/Fields/Zp_field.h +420 -420
  37. multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -400
  38. multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -418
  39. multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -337
  40. multipers/gudhi/gudhi/Matrix.h +2107 -2107
  41. multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -1038
  42. multipers/gudhi/gudhi/Multi_persistence/Box.h +171 -171
  43. multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -282
  44. multipers/gudhi/gudhi/Off_reader.h +173 -173
  45. multipers/gudhi/gudhi/One_critical_filtration.h +1432 -1431
  46. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -769
  47. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -686
  48. multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -842
  49. multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -1350
  50. multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -1105
  51. multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -859
  52. multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -910
  53. multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -139
  54. multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -230
  55. multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -211
  56. multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -60
  57. multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -60
  58. multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -136
  59. multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -190
  60. multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -616
  61. multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -150
  62. multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -106
  63. multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -219
  64. multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -327
  65. multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -1140
  66. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -934
  67. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -934
  68. multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -980
  69. multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -1092
  70. multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -192
  71. multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -921
  72. multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -1093
  73. multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -1012
  74. multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -1244
  75. multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -186
  76. multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -164
  77. multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -156
  78. multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -376
  79. multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -540
  80. multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -118
  81. multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -173
  82. multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -128
  83. multipers/gudhi/gudhi/Persistent_cohomology.h +745 -745
  84. multipers/gudhi/gudhi/Points_off_io.h +171 -171
  85. multipers/gudhi/gudhi/Simple_object_pool.h +69 -69
  86. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -463
  87. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -83
  88. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -106
  89. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -277
  90. multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -62
  91. multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -27
  92. multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -62
  93. multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -157
  94. multipers/gudhi/gudhi/Simplex_tree.h +2794 -2794
  95. multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -163
  96. multipers/gudhi/gudhi/distance_functions.h +62 -62
  97. multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -104
  98. multipers/gudhi/gudhi/persistence_interval.h +253 -253
  99. multipers/gudhi/gudhi/persistence_matrix_options.h +170 -170
  100. multipers/gudhi/gudhi/reader_utils.h +367 -367
  101. multipers/gudhi/mma_interface_coh.h +256 -255
  102. multipers/gudhi/mma_interface_h0.h +223 -231
  103. multipers/gudhi/mma_interface_matrix.h +284 -282
  104. multipers/gudhi/naive_merge_tree.h +536 -575
  105. multipers/gudhi/scc_io.h +310 -289
  106. multipers/gudhi/truc.h +890 -888
  107. multipers/io.cp310-win_amd64.pyd +0 -0
  108. multipers/io.pyx +711 -711
  109. multipers/ml/accuracies.py +90 -90
  110. multipers/ml/convolutions.py +520 -520
  111. multipers/ml/invariants_with_persistable.py +79 -79
  112. multipers/ml/kernels.py +176 -176
  113. multipers/ml/mma.py +713 -714
  114. multipers/ml/one.py +472 -472
  115. multipers/ml/point_clouds.py +352 -346
  116. multipers/ml/signed_measures.py +1589 -1589
  117. multipers/ml/sliced_wasserstein.py +461 -461
  118. multipers/ml/tools.py +113 -113
  119. multipers/mma_structures.cp310-win_amd64.pyd +0 -0
  120. multipers/mma_structures.pxd +127 -127
  121. multipers/mma_structures.pyx +4 -4
  122. multipers/mma_structures.pyx.tp +1085 -1085
  123. multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -93
  124. multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -97
  125. multipers/multi_parameter_rank_invariant/function_rips.h +322 -322
  126. multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -769
  127. multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -148
  128. multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -369
  129. multipers/multiparameter_edge_collapse.py +41 -41
  130. multipers/multiparameter_module_approximation/approximation.h +2296 -2295
  131. multipers/multiparameter_module_approximation/combinatory.h +129 -129
  132. multipers/multiparameter_module_approximation/debug.h +107 -107
  133. multipers/multiparameter_module_approximation/format_python-cpp.h +286 -286
  134. multipers/multiparameter_module_approximation/heap_column.h +238 -238
  135. multipers/multiparameter_module_approximation/images.h +79 -79
  136. multipers/multiparameter_module_approximation/list_column.h +174 -174
  137. multipers/multiparameter_module_approximation/list_column_2.h +232 -232
  138. multipers/multiparameter_module_approximation/ru_matrix.h +347 -347
  139. multipers/multiparameter_module_approximation/set_column.h +135 -135
  140. multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -36
  141. multipers/multiparameter_module_approximation/unordered_set_column.h +166 -166
  142. multipers/multiparameter_module_approximation/utilities.h +403 -419
  143. multipers/multiparameter_module_approximation/vector_column.h +223 -223
  144. multipers/multiparameter_module_approximation/vector_matrix.h +331 -331
  145. multipers/multiparameter_module_approximation/vineyards.h +464 -464
  146. multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -649
  147. multipers/multiparameter_module_approximation.cp310-win_amd64.pyd +0 -0
  148. multipers/multiparameter_module_approximation.pyx +216 -217
  149. multipers/pickle.py +90 -53
  150. multipers/plots.py +342 -334
  151. multipers/point_measure.cp310-win_amd64.pyd +0 -0
  152. multipers/point_measure.pyx +322 -320
  153. multipers/simplex_tree_multi.cp310-win_amd64.pyd +0 -0
  154. multipers/simplex_tree_multi.pxd +133 -133
  155. multipers/simplex_tree_multi.pyx +18 -15
  156. multipers/simplex_tree_multi.pyx.tp +1939 -1935
  157. multipers/slicer.cp310-win_amd64.pyd +0 -0
  158. multipers/slicer.pxd +81 -20
  159. multipers/slicer.pxd.tp +215 -214
  160. multipers/slicer.pyx +1091 -308
  161. multipers/slicer.pyx.tp +924 -914
  162. multipers/tensor/tensor.h +672 -672
  163. multipers/tensor.pxd +13 -13
  164. multipers/test.pyx +44 -44
  165. multipers/tests/__init__.py +57 -57
  166. multipers/torch/diff_grids.py +217 -217
  167. multipers/torch/rips_density.py +310 -304
  168. {multipers-2.2.3.dist-info → multipers-2.3.0.dist-info}/LICENSE +21 -21
  169. {multipers-2.2.3.dist-info → multipers-2.3.0.dist-info}/METADATA +21 -11
  170. multipers-2.3.0.dist-info/RECORD +182 -0
  171. multipers/tests/test_diff_helper.py +0 -73
  172. multipers/tests/test_hilbert_function.py +0 -82
  173. multipers/tests/test_mma.py +0 -83
  174. multipers/tests/test_point_clouds.py +0 -49
  175. multipers/tests/test_python-cpp_conversion.py +0 -82
  176. multipers/tests/test_signed_betti.py +0 -181
  177. multipers/tests/test_signed_measure.py +0 -89
  178. multipers/tests/test_simplextreemulti.py +0 -221
  179. multipers/tests/test_slicer.py +0 -221
  180. multipers-2.2.3.dist-info/RECORD +0 -189
  181. {multipers-2.2.3.dist-info → multipers-2.3.0.dist-info}/WHEEL +0 -0
  182. {multipers-2.2.3.dist-info → multipers-2.3.0.dist-info}/top_level.txt +0 -0
multipers/data/MOL2.py CHANGED
@@ -1,458 +1,458 @@
1
- import os
2
- from os import listdir
3
- from os.path import expanduser
4
- from typing import Iterable
5
-
6
- import matplotlib.pyplot as plt
7
- import MDAnalysis as mda
8
- import numpy as np
9
- import pandas as pd
10
- from joblib import Parallel, delayed
11
- from MDAnalysis.topology.guessers import guess_masses
12
- from sklearn.base import BaseEstimator, TransformerMixin
13
- from sklearn.preprocessing import LabelEncoder
14
-
15
- # from numba import njit
16
- from tqdm import tqdm
17
-
18
- import multipers as mp
19
-
20
- DATASET_PATH = expanduser("~/Datasets/")
21
- JC_path = DATASET_PATH + "Cleves-Jain/"
22
- DUDE_path = DATASET_PATH + "DUD-E/"
23
-
24
-
25
- # pathes = get_data_path()
26
- # imgs = apply_pipeline(pathes=pathes, pipeline=pipeline_img)
27
- # distances_to_letter, ytest = img_distances(imgs)
28
-
29
-
30
- def _get_mols_in_path(folder):
31
- with open(folder + "/TargetList", "r") as f:
32
- train_data = [folder + "/" + mol.strip() for mol in f.readlines()]
33
- criterion = (
34
- lambda dataset: dataset.endswith(".mol2")
35
- and not dataset.startswith("final")
36
- and dataset not in train_data
37
- )
38
- test_data = [
39
- folder + "/" + dataset
40
- for dataset in listdir(folder)
41
- if criterion(folder + "/" + dataset)
42
- ]
43
- return train_data, test_data
44
-
45
-
46
- def get_data_path_JC(type="dict"):
47
- if type == "dict":
48
- out = {}
49
- elif type == "list":
50
- out = []
51
- else:
52
- raise TypeError(f"Type {out} not supported")
53
- for stuff in listdir(JC_path):
54
- if stuff.startswith("target_"):
55
- current_letter = stuff[-1]
56
- to_add = _get_mols_in_path(JC_path + stuff)
57
- if type == "dict":
58
- out[current_letter] = to_add
59
- elif type == "list":
60
- out.append(to_add)
61
- decoy_folder = JC_path + "RognanRing850/"
62
- to_add = [
63
- decoy_folder + mol for mol in listdir(decoy_folder) if mol.endswith(".mol2")
64
- ]
65
- if type == "dict":
66
- out["decoy"] = to_add
67
- elif type == "list":
68
- out.append(to_add)
69
- return out
70
-
71
-
72
- def get_all_JC_path():
73
- out = []
74
- for stuff in listdir(JC_path):
75
- if stuff.startswith("target_"):
76
- train_data, test_data = _get_mols_in_path(JC_path + stuff)
77
- out += train_data
78
- out += test_data
79
- decoy_folder = JC_path + "RognanRing850/"
80
- out += [
81
- decoy_folder + mol for mol in listdir(decoy_folder) if mol.endswith(".mol2")
82
- ]
83
- return out
84
-
85
-
86
- def split_multimol(
87
- path: str,
88
- mol_name: str,
89
- out_folder_name: str = "splitted",
90
- enforce_charges: bool = False,
91
- ):
92
- with open(path + mol_name, "r") as f:
93
- lines = f.readlines()
94
- splitted_mols = []
95
- index = 0
96
- for i, line in enumerate(lines):
97
- is_last = i == len(lines) - 1
98
- if line.strip() == "@<TRIPOS>MOLECULE" or is_last:
99
- if i != index:
100
- molecule = "".join(lines[index : i + is_last])
101
- if enforce_charges:
102
- # print(f"Replaced molecule {i}")
103
- molecule = molecule.replace("NO_CHARGES", "USER_CHARGES")
104
- # print(molecule)
105
- # return
106
- index = i
107
- splitted_mols.append(molecule)
108
- if not os.path.exists(path + out_folder_name):
109
- os.mkdir(path + out_folder_name)
110
- for i, mol in enumerate(splitted_mols):
111
- with open(path + out_folder_name + f"/{i}.mol2", "w") as f:
112
- f.write(mol)
113
- return [path + out_folder_name + f"/{i}.mol2" for i in range(len(splitted_mols))]
114
-
115
-
116
- # @njit(parallel=True)
117
- def apply_pipeline(pathes: dict, pipeline):
118
- img_dict = {}
119
- for key, value in tqdm(pathes.items(), desc="Applying pipeline"):
120
- if len(key) == 1:
121
- train_paths, test_paths = value
122
- train_imgs = pipeline.transform(train_paths)
123
- test_imgs = pipeline.transform(test_paths)
124
- img_dict[key] = (train_imgs, test_imgs)
125
- else:
126
- assert key == "decoy"
127
- img_dict[key] = pipeline.transform(value)
128
- return img_dict
129
-
130
-
131
- from sklearn.metrics import pairwise_distances
132
-
133
-
134
- def img_distances(img_dict: dict):
135
- distances_to_anchors = []
136
- ytest = []
137
- decoy_list = img_dict["decoy"]
138
- for letter, imgs in img_dict.items():
139
- if len(letter) != 1:
140
- continue # decoy
141
- xtrain, xtest = imgs
142
- assert len(xtest) > 0
143
- train_data, test_data = xtrain, np.concatenate([xtest, decoy_list])
144
- D = pairwise_distances(train_data, test_data)
145
- distances_to_anchors.append(D)
146
- letter_ytest = np.array(
147
- [letter] * len(xtest) + ["0"] * len(decoy_list), dtype="<U1"
148
- )
149
- ytest.append(letter_ytest)
150
- return distances_to_anchors, ytest
151
-
152
-
153
- def get_EF_vector_from_distances(distances, ytest, alpha=0.05):
154
- EF = []
155
- for distance_to_anchors, letter_ytest in zip(distances, ytest):
156
- indices = np.argsort(distance_to_anchors, axis=1)
157
- n = indices.shape[1]
158
- n_max = int(alpha * n)
159
- good_indices = (
160
- letter_ytest[indices[:, :n_max]] == letter_ytest[0]
161
- ) ## assumes that ytest[:,0] are the good letters
162
- EF_letter = good_indices.sum(axis=1) / (letter_ytest == letter_ytest[0]).sum()
163
- EF_letter /= alpha
164
- EF.append(EF_letter.mean())
165
- return np.mean(EF)
166
-
167
-
168
- def EF_from_distance_matrix(
169
- distances: np.ndarray, labels: list | np.ndarray, alpha: float, anchors_in_test=True
170
- ):
171
- """
172
- Computes the Enrichment Factor from a distance matrix, and its labels.
173
- - First axis of the distance matrix is the anchors on which to compute the EF
174
- - Second axis is the test. For convenience, anchors can be put in test, if the flag anchors_in_test is set to true.
175
- - labels is a table of bools, representing the the labels of the test axis of the distance matrix.
176
- - alpha : the EF alpha parameter.
177
- """
178
- n = len(labels)
179
- n_max = int(alpha * n)
180
- indices = np.argsort(distances, axis=1)
181
- EF_ = [
182
- ((labels[idx[:n_max]]).sum() - anchors_in_test)
183
- / (labels.sum() - anchors_in_test)
184
- for idx in indices
185
- ]
186
- return np.mean(EF_) / alpha
187
-
188
-
189
- def EF_AUC(distances: np.ndarray, labels: np.ndarray, anchors_in_test=0):
190
- if distances.ndim == 1:
191
- distances = distances[None, :]
192
- assert distances.ndim == 2
193
- indices = np.argsort(distances, axis=1)
194
- out = []
195
- for i in range(1, distances.size):
196
- proportion_of_good_indices = (
197
- labels[indices[:, :i]].sum(axis=1).mean() - anchors_in_test
198
- ) / min(i, labels.sum() - anchors_in_test)
199
- out.append(proportion_of_good_indices)
200
- # print(out)
201
- return np.mean(out)
202
-
203
-
204
- def theorical_max_EF(distances, labels, alpha):
205
- n = len(labels)
206
- n_max = int(alpha * n)
207
- num_true_labels = np.sum(
208
- labels == labels[0]
209
- ) ## if labels are not True / False, assumes that the first one is a good one
210
- return min(n_max, num_true_labels) / alpha
211
-
212
-
213
- def theorical_max_EF_from_distances(list_of_distances, list_of_labels, alpha):
214
- return np.mean(
215
- [
216
- theorical_max_EF(distances, labels, alpha)
217
- for distances, labels in zip(list_of_distances, list_of_labels)
218
- ]
219
- )
220
-
221
-
222
- def plot_EF_from_distances(
223
- alphas=[0.01, 0.02, 0.05, 0.1], EF=EF_from_distance_matrix, plot: bool = True
224
- ):
225
- y = np.round([EF(alpha=alpha) for alpha in alphas], decimals=2)
226
- if plot:
227
- _alphas = np.linspace(0.01, 1.0, 100)
228
- plt.figure()
229
- plt.plot(_alphas, [EF(alpha=alpha) for alpha in _alphas])
230
- plt.scatter(alphas, y, c="r")
231
- plt.title("Enrichment Factor")
232
- plt.xlabel(r"$\alpha$" + f" = {alphas}")
233
- plt.ylabel(r"$\mathrm{EF}_\alpha$" + f" = {y}")
234
- return y
235
-
236
-
237
- def lines2bonds(
238
- mol: mda.Universe, bond_types=["ar", "am", 3, 2, 1, 0], molecule_format=None
239
- ):
240
- extension = (
241
- mol.filename.split(".")[-1].lower()
242
- if molecule_format is None
243
- else molecule_format
244
- )
245
- match extension:
246
- case "mol2":
247
- out = lines2bonds_MOL2(mol)["bond_type"]
248
- case "pdb":
249
- out = lines2bonds_PDB(mol)
250
- case _:
251
- raise Exception("Invalid, or not supported molecule format.")
252
- return LabelEncoder().fit(bond_types).transform(out)
253
-
254
-
255
- def lines2bonds_MOL2(mol: mda.Universe):
256
- _lines = open(mol.filename, "r").readlines()
257
- out = []
258
- index = 0
259
- while index < len(_lines) and _lines[index].strip() != "@<TRIPOS>BOND":
260
- index += 1
261
- index += 1
262
- while index < len(_lines) and _lines[index].strip()[0] != "@":
263
- line = _lines[index].strip().split(" ")
264
- for j, truc in enumerate(line):
265
- line[j] = truc.strip()
266
- # try:
267
- out.append([stuff for stuff in line if len(stuff) > 0])
268
- # except:
269
- # print_lin
270
- index += 1
271
- out = pd.DataFrame(out, columns=["bond_id", "atom1", "atom2", "bond_type"])
272
- out.set_index(["bond_id"], inplace=True)
273
- return out
274
-
275
-
276
- def lines2bonds_PDB(mol: mda.Universe):
277
- raise Exception("Not yet implemented.")
278
- return
279
-
280
-
281
- def _mol2graphst(
282
- path: str | mda.Universe, filtrations: Iterable[str], molecule_format=None
283
- ):
284
- molecule = path if isinstance(path, mda.Universe) else mda.Universe(path)
285
-
286
- num_filtrations = len(filtrations)
287
- nodes = molecule.atoms.indices.reshape(1, -1)
288
- edges = molecule.bonds.dump_contents().T
289
- num_vertices = nodes.shape[1]
290
- num_edges = edges.shape[1]
291
-
292
- st = mp.SimplexTreeMulti(num_parameters=num_filtrations)
293
-
294
- ## Edges filtration
295
- # edges = np.array(bonds_df[["atom1", "atom2"]]).T
296
- edges_filtration = np.zeros((num_edges, num_filtrations), dtype=np.float32) - np.inf
297
- for i, filtration in enumerate(filtrations):
298
- match filtration:
299
- case "bond_length":
300
- bond_lengths = molecule.bonds.bonds()
301
- edges_filtration[:, i] = bond_lengths
302
- case "bond_type":
303
- bond_types = lines2bonds(mol=molecule, molecule_format=molecule_format)
304
- edges_filtration[:, i] = bond_types
305
- case _:
306
- pass
307
-
308
- ## Nodes filtration
309
- nodes_filtrations = np.zeros(
310
- (num_vertices, num_filtrations), dtype=np.float32
311
- ) + np.min(
312
- edges_filtration, axis=0
313
- ) # better than - np.inf
314
- st.insert_batch(nodes, nodes_filtrations)
315
-
316
- st.insert_batch(edges, edges_filtration)
317
- for i, filtration in enumerate(filtrations):
318
- match filtration:
319
- case "charge":
320
- charges = molecule.atoms.charges
321
- st.fill_lowerstar(charges, parameter=i)
322
- case "atomic_mass":
323
- masses = molecule.atoms.masses
324
- null_indices = masses == 0
325
- if np.any(null_indices): # guess if necessary
326
- masses[null_indices] = guess_masses(molecule.atoms.types)[
327
- null_indices
328
- ]
329
- st.fill_lowerstar(-masses, parameter=i)
330
- case _:
331
- pass
332
- st.make_filtration_non_decreasing() # Necessary ?
333
- return st
334
-
335
-
336
- def _mol2ripsst(
337
- path: str,
338
- filtrations: Iterable[str],
339
- threshold=np.inf,
340
- bond_types: list = ["ar", "am", 3, 2, 1, 0],
341
- ):
342
- import gudhi as gd
343
-
344
- assert "bond_length" == filtrations[0], "Bond length has to be first for rips."
345
- molecule = path if isinstance(path, mda.Universe) else mda.Universe(path)
346
- num_parameters = len(filtrations)
347
- st_rips = gd.RipsComplex(
348
- points=molecule.atoms.positions, max_edge_length=threshold
349
- ).create_simplex_tree()
350
- st = mp.SimplexTreeMulti(
351
- st_rips,
352
- num_parameters=num_parameters,
353
- default_values=[
354
- bond_types.index(0) if f == "bond_type" else -np.inf
355
- for f in filtrations[1:]
356
- ], # the 0 index is the label of 'no bond' in bond_types
357
- )
358
-
359
- ## Edges filtration
360
- mol_bonds = molecule.bonds.indices.T
361
- edges_filtration = (
362
- np.zeros((mol_bonds.shape[1], num_parameters), dtype=np.float32) - np.inf
363
- )
364
- for i, filtration in enumerate(filtrations):
365
- match filtration:
366
- case "bond_type":
367
- edges_filtration[:, i] = lines2bonds(
368
- mol=molecule, bond_types=bond_types
369
- )
370
- case "atomic_mass":
371
- continue
372
- case "charge":
373
- continue
374
- case "bond_length":
375
- edges_filtration[:, i] = [st_rips.filtration(s) for s in mol_bonds.T]
376
- case _:
377
- raise Exception(
378
- f"Invalid filtration {filtration}. Available ones : bond_type, atomic_mass, charge, bond_length."
379
- )
380
- st.assign_batch_filtration(mol_bonds, edges_filtration, propagate=False)
381
- min_filtration = edges_filtration.min(axis=0)
382
- st.assign_batch_filtration(
383
- np.asarray([list(range(st.num_vertices))], dtype=int),
384
- np.asarray([min_filtration] * st.num_vertices, dtype=np.float32),
385
- propagate=False,
386
- )
387
- ## Nodes filtration
388
- for i, filtration in enumerate(filtrations):
389
- match filtration:
390
- case "charge":
391
- charges = molecule.atoms.charges
392
- st.fill_lowerstar(charges, parameter=i)
393
- case "atomic_mass":
394
- masses = molecule.atoms.masses
395
- null_indices = masses == 0
396
- if np.any(null_indices): # guess if necessary
397
- masses[null_indices] = guess_masses(molecule.atoms.types)[
398
- null_indices
399
- ]
400
- # print(masses)
401
- st.fill_lowerstar(-masses, parameter=i)
402
- case _:
403
- pass
404
- st.make_filtration_non_decreasing() # Necessary ?
405
- return st
406
-
407
-
408
- class Molecule2SimplexTree(BaseEstimator, TransformerMixin):
409
- """
410
- Transforms a list of MDA-compatible files into a list of mulitparameter simplextrees
411
-
412
- Input
413
- -----
414
- X: Iterable[path_to_files:str]
415
-
416
- Output
417
- ------
418
- Iterable[multipers.SimplexTreeMulti]
419
-
420
- Parameters
421
- ----------
422
- - filtrations : list of filtration names. Available ones : 'charge', 'atomic_mass', 'bond_length', 'bond_type'. Others are ignored.
423
- - graph : bool. If true, will use the graph given by the molecule, otherwise, a Rips Complex Based on the distance. '
424
- In that case bond_length is ignored (it's the 1rst parameter).
425
- """
426
-
427
- def __init__(
428
- self,
429
- delayed: bool = False,
430
- filtrations: Iterable[str] = [],
431
- graph: bool = True,
432
- n_jobs: int = 1,
433
- ) -> None:
434
- super().__init__()
435
- self.delayed = delayed
436
- self.n_jobs = n_jobs
437
- self.filtrations = filtrations
438
- self.graph = graph
439
- self._molecule_format = None
440
- return
441
-
442
- def fit(self, X: Iterable[str], y=None):
443
- if len(X) == 0:
444
- return self
445
- test_mol = mda.Universe(X[0])
446
- self._molecule_format = test_mol.filename.split(".")[-1].lower()
447
- return self
448
-
449
- def transform(self, X: Iterable[str]):
450
- _to_simplextree = _mol2graphst if self.graph else _mol2ripsst
451
- to_simplex_tree = lambda path_to_mol2_file: [
452
- _to_simplextree(path=path_to_mol2_file, filtrations=self.filtrations)
453
- ]
454
- if self.delayed:
455
- return [delayed(to_simplex_tree)(path) for path in X]
456
- return Parallel(n_jobs=self.n_jobs, prefer="threads")(
457
- delayed(to_simplex_tree)(path) for path in X
458
- )
1
+ import os
2
+ from os import listdir
3
+ from os.path import expanduser
4
+ from typing import Iterable
5
+
6
+ import matplotlib.pyplot as plt
7
+ import MDAnalysis as mda
8
+ import numpy as np
9
+ import pandas as pd
10
+ from joblib import Parallel, delayed
11
+ from MDAnalysis.topology.guessers import guess_masses
12
+ from sklearn.base import BaseEstimator, TransformerMixin
13
+ from sklearn.preprocessing import LabelEncoder
14
+
15
+ # from numba import njit
16
+ from tqdm import tqdm
17
+
18
+ import multipers as mp
19
+
20
+ DATASET_PATH = expanduser("~/Datasets/")
21
+ JC_path = DATASET_PATH + "Cleves-Jain/"
22
+ DUDE_path = DATASET_PATH + "DUD-E/"
23
+
24
+
25
+ # pathes = get_data_path()
26
+ # imgs = apply_pipeline(pathes=pathes, pipeline=pipeline_img)
27
+ # distances_to_letter, ytest = img_distances(imgs)
28
+
29
+
30
+ def _get_mols_in_path(folder):
31
+ with open(folder + "/TargetList", "r") as f:
32
+ train_data = [folder + "/" + mol.strip() for mol in f.readlines()]
33
+ criterion = (
34
+ lambda dataset: dataset.endswith(".mol2")
35
+ and not dataset.startswith("final")
36
+ and dataset not in train_data
37
+ )
38
+ test_data = [
39
+ folder + "/" + dataset
40
+ for dataset in listdir(folder)
41
+ if criterion(folder + "/" + dataset)
42
+ ]
43
+ return train_data, test_data
44
+
45
+
46
+ def get_data_path_JC(type="dict"):
47
+ if type == "dict":
48
+ out = {}
49
+ elif type == "list":
50
+ out = []
51
+ else:
52
+ raise TypeError(f"Type {out} not supported")
53
+ for stuff in listdir(JC_path):
54
+ if stuff.startswith("target_"):
55
+ current_letter = stuff[-1]
56
+ to_add = _get_mols_in_path(JC_path + stuff)
57
+ if type == "dict":
58
+ out[current_letter] = to_add
59
+ elif type == "list":
60
+ out.append(to_add)
61
+ decoy_folder = JC_path + "RognanRing850/"
62
+ to_add = [
63
+ decoy_folder + mol for mol in listdir(decoy_folder) if mol.endswith(".mol2")
64
+ ]
65
+ if type == "dict":
66
+ out["decoy"] = to_add
67
+ elif type == "list":
68
+ out.append(to_add)
69
+ return out
70
+
71
+
72
+ def get_all_JC_path():
73
+ out = []
74
+ for stuff in listdir(JC_path):
75
+ if stuff.startswith("target_"):
76
+ train_data, test_data = _get_mols_in_path(JC_path + stuff)
77
+ out += train_data
78
+ out += test_data
79
+ decoy_folder = JC_path + "RognanRing850/"
80
+ out += [
81
+ decoy_folder + mol for mol in listdir(decoy_folder) if mol.endswith(".mol2")
82
+ ]
83
+ return out
84
+
85
+
86
+ def split_multimol(
87
+ path: str,
88
+ mol_name: str,
89
+ out_folder_name: str = "splitted",
90
+ enforce_charges: bool = False,
91
+ ):
92
+ with open(path + mol_name, "r") as f:
93
+ lines = f.readlines()
94
+ splitted_mols = []
95
+ index = 0
96
+ for i, line in enumerate(lines):
97
+ is_last = i == len(lines) - 1
98
+ if line.strip() == "@<TRIPOS>MOLECULE" or is_last:
99
+ if i != index:
100
+ molecule = "".join(lines[index : i + is_last])
101
+ if enforce_charges:
102
+ # print(f"Replaced molecule {i}")
103
+ molecule = molecule.replace("NO_CHARGES", "USER_CHARGES")
104
+ # print(molecule)
105
+ # return
106
+ index = i
107
+ splitted_mols.append(molecule)
108
+ if not os.path.exists(path + out_folder_name):
109
+ os.mkdir(path + out_folder_name)
110
+ for i, mol in enumerate(splitted_mols):
111
+ with open(path + out_folder_name + f"/{i}.mol2", "w") as f:
112
+ f.write(mol)
113
+ return [path + out_folder_name + f"/{i}.mol2" for i in range(len(splitted_mols))]
114
+
115
+
116
+ # @njit(parallel=True)
117
+ def apply_pipeline(pathes: dict, pipeline):
118
+ img_dict = {}
119
+ for key, value in tqdm(pathes.items(), desc="Applying pipeline"):
120
+ if len(key) == 1:
121
+ train_paths, test_paths = value
122
+ train_imgs = pipeline.transform(train_paths)
123
+ test_imgs = pipeline.transform(test_paths)
124
+ img_dict[key] = (train_imgs, test_imgs)
125
+ else:
126
+ assert key == "decoy"
127
+ img_dict[key] = pipeline.transform(value)
128
+ return img_dict
129
+
130
+
131
+ from sklearn.metrics import pairwise_distances
132
+
133
+
134
+ def img_distances(img_dict: dict):
135
+ distances_to_anchors = []
136
+ ytest = []
137
+ decoy_list = img_dict["decoy"]
138
+ for letter, imgs in img_dict.items():
139
+ if len(letter) != 1:
140
+ continue # decoy
141
+ xtrain, xtest = imgs
142
+ assert len(xtest) > 0
143
+ train_data, test_data = xtrain, np.concatenate([xtest, decoy_list])
144
+ D = pairwise_distances(train_data, test_data)
145
+ distances_to_anchors.append(D)
146
+ letter_ytest = np.array(
147
+ [letter] * len(xtest) + ["0"] * len(decoy_list), dtype="<U1"
148
+ )
149
+ ytest.append(letter_ytest)
150
+ return distances_to_anchors, ytest
151
+
152
+
153
def get_EF_vector_from_distances(distances, ytest, alpha=0.05):
    """
    Averages the Enrichment Factor over several (distance matrix, labels) pairs.

    Parameters
    ----------
    - distances : iterable of 2D arrays, anchors on the first axis and test
      samples on the second.
    - ytest : iterable of 1D label arrays, one per distance matrix. The first
      label of each array is taken as the "good" label.
    - alpha : fraction of the test samples kept as nearest neighbours.

    Returns
    -------
    The mean, over all pairs, of the per-anchor mean of
    ``(good labels among the int(alpha * n) nearest) / (number of good labels) / alpha``.
    """
    per_class_EF = []
    for dist_matrix, labels in zip(distances, ytest):
        ranking = np.argsort(dist_matrix, axis=1)
        cutoff = int(alpha * ranking.shape[1])
        target = labels[0]  # assumes the first test label is a good one
        hits = (labels[ranking[:, :cutoff]] == target).sum(axis=1)
        num_targets = (labels == target).sum()
        per_class_EF.append((hits / num_targets / alpha).mean())
    return np.mean(per_class_EF)
166
+
167
+
168
+ def EF_from_distance_matrix(
169
+ distances: np.ndarray, labels: list | np.ndarray, alpha: float, anchors_in_test=True
170
+ ):
171
+ """
172
+ Computes the Enrichment Factor from a distance matrix, and its labels.
173
+ - First axis of the distance matrix is the anchors on which to compute the EF
174
+ - Second axis is the test. For convenience, anchors can be put in test, if the flag anchors_in_test is set to true.
175
+ - labels is a table of bools, representing the the labels of the test axis of the distance matrix.
176
+ - alpha : the EF alpha parameter.
177
+ """
178
+ n = len(labels)
179
+ n_max = int(alpha * n)
180
+ indices = np.argsort(distances, axis=1)
181
+ EF_ = [
182
+ ((labels[idx[:n_max]]).sum() - anchors_in_test)
183
+ / (labels.sum() - anchors_in_test)
184
+ for idx in indices
185
+ ]
186
+ return np.mean(EF_) / alpha
187
+
188
+
189
def EF_AUC(distances: np.ndarray, labels: np.ndarray, anchors_in_test=0):
    """
    Averages the proportion of retrieved positives over every cutoff.

    Parameters
    ----------
    - distances : 1D array (a single anchor) or 2D array with anchors on the
      first axis and test samples on the second.
    - labels : boolean labels of the test axis.
    - anchors_in_test : amount subtracted from both the number of retrieved
      positives and the total number of positives.

    Returns
    -------
    The mean, over the cutoffs ``i = 1 .. n_test - 1``, of
    ``(mean over anchors of positives among the i nearest - anchors_in_test)
    / min(i, positives - anchors_in_test)``.

    The cutoffs are bounded by the number of test samples. The bound used to be
    ``distances.size``, which with several anchors added cutoffs larger than
    the test set; those always score 1.0 and inflated the average.
    """
    distances = np.asarray(distances)
    labels = np.asarray(labels)
    if distances.ndim == 1:
        distances = distances[None, :]
    assert distances.ndim == 2
    ranking = np.argsort(distances, axis=1)
    num_positives = labels.sum() - anchors_in_test
    out = []
    for i in range(1, distances.shape[1]):
        retrieved = labels[ranking[:, :i]].sum(axis=1).mean() - anchors_in_test
        out.append(retrieved / min(i, num_positives))
    return np.mean(out)
202
+
203
+
204
def theorical_max_EF(distances, labels, alpha):
    """
    Largest Enrichment Factor reachable for the given labels and alpha.

    Parameters
    ----------
    - distances : unused, kept so the signature matches the other EF helpers.
    - labels : labels of the test samples (list or array). If they are not
      True / False, the first one is assumed to be a good one.
    - alpha : the EF alpha parameter.

    Returns
    -------
    ``min(n_max, num_true_labels) / num_true_labels / alpha`` where
    ``n_max = int(alpha * len(labels))``. This is the EF formula used by
    ``get_EF_vector_from_distances`` evaluated when every retrieved sample is
    a good one, so the result never exceeds ``1 / alpha``. The division by the
    number of good labels was previously missing.
    """
    # A plain list compared with `==` yields a single bool, not a mask.
    labels = np.asarray(labels)
    n_max = int(alpha * len(labels))
    num_true_labels = np.sum(labels == labels[0])
    return min(n_max, num_true_labels) / num_true_labels / alpha
211
+
212
+
213
def theorical_max_EF_from_distances(list_of_distances, list_of_labels, alpha):
    """
    Mean of ``theorical_max_EF`` over paired distance matrices and labels.

    - list_of_distances, list_of_labels : iterables walked in lockstep.
    - alpha : the EF alpha parameter, forwarded unchanged.
    """
    per_class_max = []
    for distances, labels in zip(list_of_distances, list_of_labels):
        per_class_max.append(theorical_max_EF(distances, labels, alpha))
    return np.mean(per_class_max)
220
+
221
+
222
def plot_EF_from_distances(
    alphas=[0.01, 0.02, 0.05, 0.1], EF=EF_from_distance_matrix, plot: bool = True
):
    """
    Evaluates an Enrichment Factor function at the given alphas, and
    optionally plots it.

    Parameters
    ----------
    - alphas : alpha values at which the EF is reported.
    - EF : callable invoked as ``EF(alpha=...)`` only, so any other argument
      it needs has to be bound beforehand.
    - plot : if True, draws the EF curve on 100 alphas between 0.01 and 1 in a
      new matplotlib figure and marks the requested alphas in red.

    Returns
    -------
    The EF values at ``alphas``, rounded to 2 decimals.
    """
    y = np.round([EF(alpha=a) for a in alphas], decimals=2)
    if not plot:
        return y
    grid = np.linspace(0.01, 1.0, 100)
    plt.figure()
    curve = [EF(alpha=a) for a in grid]
    plt.plot(grid, curve)
    plt.scatter(alphas, y, c="r")
    plt.title("Enrichment Factor")
    plt.xlabel(r"$\alpha$" + f" = {alphas}")
    plt.ylabel(r"$\mathrm{EF}_\alpha$" + f" = {y}")
    return y
235
+
236
+
237
def lines2bonds(
    mol: mda.Universe, bond_types=["ar", "am", 3, 2, 1, 0], molecule_format=None
):
    """
    Reads the bond types of a molecule file and encodes them as integers.

    Parameters
    ----------
    - mol : molecule whose ``filename`` attribute points to the file to parse.
    - bond_types : the classes the ``LabelEncoder`` is fitted on.
    - molecule_format : "mol2" or "pdb". If None, the lower-cased extension of
      ``mol.filename`` is used.

    Returns
    -------
    The bond types transformed by a ``LabelEncoder`` fitted on ``bond_types``.

    Raises
    ------
    Exception if the format is neither "mol2" nor "pdb".
    """
    if molecule_format is None:
        extension = mol.filename.split(".")[-1].lower()
    else:
        extension = molecule_format

    if extension == "mol2":
        raw_types = lines2bonds_MOL2(mol)["bond_type"]
    elif extension == "pdb":
        raw_types = lines2bonds_PDB(mol)
    else:
        raise Exception("Invalid, or not supported molecule format.")

    encoder = LabelEncoder()
    encoder.fit(bond_types)
    return encoder.transform(raw_types)
253
+
254
+
255
def lines2bonds_MOL2(mol: "mda.Universe"):
    """
    Parses the ``@<TRIPOS>BOND`` section of a MOL2 file.

    Parameters
    ----------
    - mol : object whose ``filename`` attribute is the path of the MOL2 file.

    Returns
    -------
    pandas.DataFrame indexed by ``bond_id`` with the string columns ``atom1``,
    ``atom2`` and ``bond_type``; empty if the file has no bond section.

    Notes
    -----
    - The file is read inside a ``with`` block so the handle is always closed.
    - Blank lines in or after the bond section are skipped; indexing the first
      character of such a line used to raise an IndexError.
    - Only the first four fields of a bond record are kept, so a trailing
      status-bits field no longer breaks the DataFrame construction.
    """
    columns = ["bond_id", "atom1", "atom2", "bond_type"]
    rows = []
    in_bond_section = False
    with open(mol.filename, "r") as mol2_file:
        for raw_line in mol2_file:
            line = raw_line.strip()
            if not in_bond_section:
                # Skip everything up to (and including) the section header.
                in_bond_section = line == "@<TRIPOS>BOND"
                continue
            if line.startswith("@"):
                break  # next record type: the bond section is over
            if not line:
                continue
            rows.append(line.split()[: len(columns)])
    out = pd.DataFrame(rows, columns=columns)
    out.set_index(["bond_id"], inplace=True)
    return out
274
+
275
+
276
def lines2bonds_PDB(mol: "mda.Universe"):
    """
    Placeholder for reading bond types from a PDB file.

    Raises
    ------
    NotImplementedError, always. It is a subclass of Exception, so callers
    catching Exception keep working.
    """
    raise NotImplementedError("Not yet implemented.")
279
+
280
+
281
def _mol2graphst(
    path: str | mda.Universe, filtrations: Iterable[str], molecule_format=None
):
    """
    Builds a multi-parameter simplex tree on the bond graph of a molecule.

    Parameters
    ----------
    - path : an ``mda.Universe``, or anything ``mda.Universe`` can load.
    - filtrations : filtration names, one per parameter. "bond_length" and
      "bond_type" are set on the edges, "charge" and "atomic_mass" are set
      from the atoms; any other name is ignored.
    - molecule_format : forwarded to ``lines2bonds`` for "bond_type".

    Returns
    -------
    ``mp.SimplexTreeMulti`` with ``len(filtrations)`` parameters.
    """
    universe = path if isinstance(path, mda.Universe) else mda.Universe(path)

    num_parameters = len(filtrations)
    vertices = universe.atoms.indices.reshape(1, -1)
    bonds = universe.bonds.dump_contents().T
    num_vertices, num_bonds = vertices.shape[1], bonds.shape[1]

    simplextree = mp.SimplexTreeMulti(num_parameters=num_parameters)

    ## Edges filtration: columns that are not filled below stay at -inf
    bond_values = np.full((num_bonds, num_parameters), -np.inf, dtype=np.float32)
    for column, name in enumerate(filtrations):
        if name == "bond_length":
            bond_values[:, column] = universe.bonds.bonds()
        elif name == "bond_type":
            bond_values[:, column] = lines2bonds(
                mol=universe, molecule_format=molecule_format
            )

    ## Nodes filtration: every vertex starts at the column-wise minimum of
    ## the edge values (better than - np.inf)
    lowest_bond_values = np.min(bond_values, axis=0)
    vertex_values = (
        np.zeros((num_vertices, num_parameters), dtype=np.float32)
        + lowest_bond_values
    )
    simplextree.insert_batch(vertices, vertex_values)
    simplextree.insert_batch(bonds, bond_values)

    ## Atom-based parameters, filled once all the simplices are inserted
    for column, name in enumerate(filtrations):
        if name == "charge":
            simplextree.fill_lowerstar(universe.atoms.charges, parameter=column)
        elif name == "atomic_mass":
            masses = universe.atoms.masses
            missing = masses == 0
            if np.any(missing):  # guess if necessary
                masses[missing] = guess_masses(universe.atoms.types)[missing]
            # Masses are negated before being used as filtration values.
            simplextree.fill_lowerstar(-masses, parameter=column)
    simplextree.make_filtration_non_decreasing()  # Necessary ?
    return simplextree
334
+
335
+
336
def _mol2ripsst(
    path: str,
    filtrations: Iterable[str],
    threshold=np.inf,
    bond_types: list = ["ar", "am", 3, 2, 1, 0],
):
    """
    Builds a multi-parameter simplex tree on the Rips complex of the atom
    positions of a molecule.

    Parameters
    ----------
    - path : an ``mda.Universe``, or anything ``mda.Universe`` can load.
    - filtrations : filtration names, one per parameter; the first one has to
      be "bond_length". Available ones : bond_type, atomic_mass, charge,
      bond_length.
    - threshold : ``max_edge_length`` of the Rips complex.
    - bond_types : classes used to encode the bond types; the position of
      ``0`` in this list is the default value of the "bond_type" parameter.

    Returns
    -------
    ``mp.SimplexTreeMulti`` with ``len(filtrations)`` parameters.

    Raises
    ------
    - AssertionError if "bond_length" is not the first filtration.
    - Exception on an unknown filtration name.
    """
    import gudhi as gd

    assert "bond_length" == filtrations[0], "Bond length has to be first for rips."
    universe = path if isinstance(path, mda.Universe) else mda.Universe(path)
    num_parameters = len(filtrations)

    rips_tree = gd.RipsComplex(
        points=universe.atoms.positions, max_edge_length=threshold
    ).create_simplex_tree()
    # the 0 index is the label of 'no bond' in bond_types
    defaults = [
        bond_types.index(0) if name == "bond_type" else -np.inf
        for name in filtrations[1:]
    ]
    simplextree = mp.SimplexTreeMulti(
        rips_tree,
        num_parameters=num_parameters,
        default_values=defaults,
    )

    ## Edges filtration, only for the bonds of the molecule
    bonds = universe.bonds.indices.T
    bond_values = np.full(
        (bonds.shape[1], num_parameters), -np.inf, dtype=np.float32
    )
    for column, name in enumerate(filtrations):
        if name == "bond_type":
            bond_values[:, column] = lines2bonds(mol=universe, bond_types=bond_types)
        elif name == "bond_length":
            bond_values[:, column] = [rips_tree.filtration(bond) for bond in bonds.T]
        elif name not in ("atomic_mass", "charge"):
            # atomic_mass and charge are handled on the nodes further down
            raise Exception(
                f"Invalid filtration {name}. Available ones : bond_type, atomic_mass, charge, bond_length."
            )
    simplextree.assign_batch_filtration(bonds, bond_values, propagate=False)

    # Every vertex gets the column-wise minimum of the bond values.
    lowest_bond_values = bond_values.min(axis=0)
    num_vertices = simplextree.num_vertices
    simplextree.assign_batch_filtration(
        np.asarray([list(range(num_vertices))], dtype=int),
        np.asarray([lowest_bond_values] * num_vertices, dtype=np.float32),
        propagate=False,
    )

    ## Nodes filtration
    for column, name in enumerate(filtrations):
        if name == "charge":
            simplextree.fill_lowerstar(universe.atoms.charges, parameter=column)
        elif name == "atomic_mass":
            masses = universe.atoms.masses
            missing = masses == 0
            if np.any(missing):  # guess if necessary
                masses[missing] = guess_masses(universe.atoms.types)[missing]
            # Masses are negated before being used as filtration values.
            simplextree.fill_lowerstar(-masses, parameter=column)
    simplextree.make_filtration_non_decreasing()  # Necessary ?
    return simplextree
406
+
407
+
408
class Molecule2SimplexTree(BaseEstimator, TransformerMixin):
    """
    Transforms a list of MDA-compatible files into a list of multiparameter simplextrees

    Input
    -----
    X: Iterable[path_to_files:str]

    Output
    ------
    One item per input path, each being a one-element list holding a
    multipers.SimplexTreeMulti (or the corresponding joblib delayed calls if
    `delayed` is True).

    Parameters
    ----------
     - delayed : bool. If true, `transform` returns the joblib delayed calls instead of running them.
     - filtrations : list of filtration names. Available ones : 'charge', 'atomic_mass', 'bond_length', 'bond_type'.
     - graph : bool. If true, will use the graph given by the molecule, otherwise, a Rips Complex based on the atom positions.
        In that case 'bond_length' has to be the first filtration.
     - n_jobs : int. Number of joblib workers (threads are preferred).
    """

    def __init__(
        self,
        delayed: bool = False,
        filtrations: Iterable[str] = [],
        graph: bool = True,
        n_jobs: int = 1,
    ) -> None:
        super().__init__()
        self.delayed = delayed
        self.n_jobs = n_jobs
        self.filtrations = filtrations
        self.graph = graph
        # Set by `fit`; `transform` does not read it.
        self._molecule_format = None
        return

    def fit(self, X: Iterable[str], y=None):
        """Records the lower-cased file extension of the first molecule of X, if any."""
        if len(X) > 0:
            first_molecule = mda.Universe(X[0])
            self._molecule_format = first_molecule.filename.split(".")[-1].lower()
        return self

    def transform(self, X: Iterable[str]):
        """Builds one simplex tree per path, each wrapped in a one-element list."""
        builder = _mol2graphst if self.graph else _mol2ripsst

        def to_simplex_tree(path_to_mol2_file):
            return [builder(path=path_to_mol2_file, filtrations=self.filtrations)]

        if self.delayed:
            return [delayed(to_simplex_tree)(path) for path in X]
        runner = Parallel(n_jobs=self.n_jobs, prefer="threads")
        return runner(delayed(to_simplex_tree)(path) for path in X)