multipers 2.2.3__cp312-cp312-win_amd64.whl → 2.3.1__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic. Click here for more details.

Files changed (182)
  1. multipers/__init__.py +33 -31
  2. multipers/_signed_measure_meta.py +430 -430
  3. multipers/_slicer_meta.py +211 -212
  4. multipers/data/MOL2.py +458 -458
  5. multipers/data/UCR.py +18 -18
  6. multipers/data/graphs.py +466 -466
  7. multipers/data/immuno_regions.py +27 -27
  8. multipers/data/pytorch2simplextree.py +90 -90
  9. multipers/data/shape3d.py +101 -101
  10. multipers/data/synthetic.py +113 -111
  11. multipers/distances.py +198 -198
  12. multipers/filtration_conversions.pxd.tp +84 -84
  13. multipers/filtrations/__init__.py +18 -0
  14. multipers/{ml/convolutions.py → filtrations/density.py} +563 -520
  15. multipers/filtrations/filtrations.py +289 -0
  16. multipers/filtrations.pxd +224 -224
  17. multipers/function_rips.cp312-win_amd64.pyd +0 -0
  18. multipers/function_rips.pyx +105 -105
  19. multipers/grids.cp312-win_amd64.pyd +0 -0
  20. multipers/grids.pyx +350 -350
  21. multipers/gudhi/Persistence_slices_interface.h +132 -132
  22. multipers/gudhi/Simplex_tree_interface.h +239 -245
  23. multipers/gudhi/Simplex_tree_multi_interface.h +516 -561
  24. multipers/gudhi/cubical_to_boundary.h +59 -59
  25. multipers/gudhi/gudhi/Bitmap_cubical_complex.h +450 -450
  26. multipers/gudhi/gudhi/Bitmap_cubical_complex_base.h +1070 -1070
  27. multipers/gudhi/gudhi/Bitmap_cubical_complex_periodic_boundary_conditions_base.h +579 -579
  28. multipers/gudhi/gudhi/Debug_utils.h +45 -45
  29. multipers/gudhi/gudhi/Fields/Multi_field.h +484 -484
  30. multipers/gudhi/gudhi/Fields/Multi_field_operators.h +455 -455
  31. multipers/gudhi/gudhi/Fields/Multi_field_shared.h +450 -450
  32. multipers/gudhi/gudhi/Fields/Multi_field_small.h +531 -531
  33. multipers/gudhi/gudhi/Fields/Multi_field_small_operators.h +507 -507
  34. multipers/gudhi/gudhi/Fields/Multi_field_small_shared.h +531 -531
  35. multipers/gudhi/gudhi/Fields/Z2_field.h +355 -355
  36. multipers/gudhi/gudhi/Fields/Z2_field_operators.h +376 -376
  37. multipers/gudhi/gudhi/Fields/Zp_field.h +420 -420
  38. multipers/gudhi/gudhi/Fields/Zp_field_operators.h +400 -400
  39. multipers/gudhi/gudhi/Fields/Zp_field_shared.h +418 -418
  40. multipers/gudhi/gudhi/Flag_complex_edge_collapser.h +337 -337
  41. multipers/gudhi/gudhi/Matrix.h +2107 -2107
  42. multipers/gudhi/gudhi/Multi_critical_filtration.h +1038 -1038
  43. multipers/gudhi/gudhi/Multi_persistence/Box.h +171 -171
  44. multipers/gudhi/gudhi/Multi_persistence/Line.h +282 -282
  45. multipers/gudhi/gudhi/Off_reader.h +173 -173
  46. multipers/gudhi/gudhi/One_critical_filtration.h +1433 -1431
  47. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix.h +769 -769
  48. multipers/gudhi/gudhi/Persistence_matrix/Base_matrix_with_column_compression.h +686 -686
  49. multipers/gudhi/gudhi/Persistence_matrix/Boundary_matrix.h +842 -842
  50. multipers/gudhi/gudhi/Persistence_matrix/Chain_matrix.h +1350 -1350
  51. multipers/gudhi/gudhi/Persistence_matrix/Id_to_index_overlay.h +1105 -1105
  52. multipers/gudhi/gudhi/Persistence_matrix/Position_to_index_overlay.h +859 -859
  53. multipers/gudhi/gudhi/Persistence_matrix/RU_matrix.h +910 -910
  54. multipers/gudhi/gudhi/Persistence_matrix/allocators/entry_constructors.h +139 -139
  55. multipers/gudhi/gudhi/Persistence_matrix/base_pairing.h +230 -230
  56. multipers/gudhi/gudhi/Persistence_matrix/base_swap.h +211 -211
  57. multipers/gudhi/gudhi/Persistence_matrix/boundary_cell_position_to_id_mapper.h +60 -60
  58. multipers/gudhi/gudhi/Persistence_matrix/boundary_face_position_to_id_mapper.h +60 -60
  59. multipers/gudhi/gudhi/Persistence_matrix/chain_pairing.h +136 -136
  60. multipers/gudhi/gudhi/Persistence_matrix/chain_rep_cycles.h +190 -190
  61. multipers/gudhi/gudhi/Persistence_matrix/chain_vine_swap.h +616 -616
  62. multipers/gudhi/gudhi/Persistence_matrix/columns/chain_column_extra_properties.h +150 -150
  63. multipers/gudhi/gudhi/Persistence_matrix/columns/column_dimension_holder.h +106 -106
  64. multipers/gudhi/gudhi/Persistence_matrix/columns/column_utilities.h +219 -219
  65. multipers/gudhi/gudhi/Persistence_matrix/columns/entry_types.h +327 -327
  66. multipers/gudhi/gudhi/Persistence_matrix/columns/heap_column.h +1140 -1140
  67. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_list_column.h +934 -934
  68. multipers/gudhi/gudhi/Persistence_matrix/columns/intrusive_set_column.h +934 -934
  69. multipers/gudhi/gudhi/Persistence_matrix/columns/list_column.h +980 -980
  70. multipers/gudhi/gudhi/Persistence_matrix/columns/naive_vector_column.h +1092 -1092
  71. multipers/gudhi/gudhi/Persistence_matrix/columns/row_access.h +192 -192
  72. multipers/gudhi/gudhi/Persistence_matrix/columns/set_column.h +921 -921
  73. multipers/gudhi/gudhi/Persistence_matrix/columns/small_vector_column.h +1093 -1093
  74. multipers/gudhi/gudhi/Persistence_matrix/columns/unordered_set_column.h +1012 -1012
  75. multipers/gudhi/gudhi/Persistence_matrix/columns/vector_column.h +1244 -1244
  76. multipers/gudhi/gudhi/Persistence_matrix/matrix_dimension_holders.h +186 -186
  77. multipers/gudhi/gudhi/Persistence_matrix/matrix_row_access.h +164 -164
  78. multipers/gudhi/gudhi/Persistence_matrix/ru_pairing.h +156 -156
  79. multipers/gudhi/gudhi/Persistence_matrix/ru_rep_cycles.h +376 -376
  80. multipers/gudhi/gudhi/Persistence_matrix/ru_vine_swap.h +540 -540
  81. multipers/gudhi/gudhi/Persistent_cohomology/Field_Zp.h +118 -118
  82. multipers/gudhi/gudhi/Persistent_cohomology/Multi_field.h +173 -173
  83. multipers/gudhi/gudhi/Persistent_cohomology/Persistent_cohomology_column.h +128 -128
  84. multipers/gudhi/gudhi/Persistent_cohomology.h +745 -745
  85. multipers/gudhi/gudhi/Points_off_io.h +171 -171
  86. multipers/gudhi/gudhi/Simple_object_pool.h +69 -69
  87. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_iterators.h +463 -463
  88. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_node_explicit_storage.h +83 -83
  89. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_siblings.h +106 -106
  90. multipers/gudhi/gudhi/Simplex_tree/Simplex_tree_star_simplex_iterators.h +277 -277
  91. multipers/gudhi/gudhi/Simplex_tree/hooks_simplex_base.h +62 -62
  92. multipers/gudhi/gudhi/Simplex_tree/indexing_tag.h +27 -27
  93. multipers/gudhi/gudhi/Simplex_tree/serialization_utils.h +62 -62
  94. multipers/gudhi/gudhi/Simplex_tree/simplex_tree_options.h +157 -157
  95. multipers/gudhi/gudhi/Simplex_tree.h +2794 -2794
  96. multipers/gudhi/gudhi/Simplex_tree_multi.h +152 -163
  97. multipers/gudhi/gudhi/distance_functions.h +62 -62
  98. multipers/gudhi/gudhi/graph_simplicial_complex.h +104 -104
  99. multipers/gudhi/gudhi/persistence_interval.h +253 -253
  100. multipers/gudhi/gudhi/persistence_matrix_options.h +170 -170
  101. multipers/gudhi/gudhi/reader_utils.h +367 -367
  102. multipers/gudhi/mma_interface_coh.h +256 -255
  103. multipers/gudhi/mma_interface_h0.h +223 -231
  104. multipers/gudhi/mma_interface_matrix.h +291 -282
  105. multipers/gudhi/naive_merge_tree.h +536 -575
  106. multipers/gudhi/scc_io.h +310 -289
  107. multipers/gudhi/truc.h +957 -888
  108. multipers/io.cp312-win_amd64.pyd +0 -0
  109. multipers/io.pyx +714 -711
  110. multipers/ml/accuracies.py +90 -90
  111. multipers/ml/invariants_with_persistable.py +79 -79
  112. multipers/ml/kernels.py +176 -176
  113. multipers/ml/mma.py +713 -714
  114. multipers/ml/one.py +472 -472
  115. multipers/ml/point_clouds.py +352 -346
  116. multipers/ml/signed_measures.py +1589 -1589
  117. multipers/ml/sliced_wasserstein.py +461 -461
  118. multipers/ml/tools.py +113 -113
  119. multipers/mma_structures.cp312-win_amd64.pyd +0 -0
  120. multipers/mma_structures.pxd +127 -127
  121. multipers/mma_structures.pyx +4 -8
  122. multipers/mma_structures.pyx.tp +1083 -1085
  123. multipers/multi_parameter_rank_invariant/diff_helpers.h +84 -93
  124. multipers/multi_parameter_rank_invariant/euler_characteristic.h +97 -97
  125. multipers/multi_parameter_rank_invariant/function_rips.h +322 -322
  126. multipers/multi_parameter_rank_invariant/hilbert_function.h +769 -769
  127. multipers/multi_parameter_rank_invariant/persistence_slices.h +148 -148
  128. multipers/multi_parameter_rank_invariant/rank_invariant.h +369 -369
  129. multipers/multiparameter_edge_collapse.py +41 -41
  130. multipers/multiparameter_module_approximation/approximation.h +2298 -2295
  131. multipers/multiparameter_module_approximation/combinatory.h +129 -129
  132. multipers/multiparameter_module_approximation/debug.h +107 -107
  133. multipers/multiparameter_module_approximation/format_python-cpp.h +286 -286
  134. multipers/multiparameter_module_approximation/heap_column.h +238 -238
  135. multipers/multiparameter_module_approximation/images.h +79 -79
  136. multipers/multiparameter_module_approximation/list_column.h +174 -174
  137. multipers/multiparameter_module_approximation/list_column_2.h +232 -232
  138. multipers/multiparameter_module_approximation/ru_matrix.h +347 -347
  139. multipers/multiparameter_module_approximation/set_column.h +135 -135
  140. multipers/multiparameter_module_approximation/structure_higher_dim_barcode.h +36 -36
  141. multipers/multiparameter_module_approximation/unordered_set_column.h +166 -166
  142. multipers/multiparameter_module_approximation/utilities.h +403 -419
  143. multipers/multiparameter_module_approximation/vector_column.h +223 -223
  144. multipers/multiparameter_module_approximation/vector_matrix.h +331 -331
  145. multipers/multiparameter_module_approximation/vineyards.h +464 -464
  146. multipers/multiparameter_module_approximation/vineyards_trajectories.h +649 -649
  147. multipers/multiparameter_module_approximation.cp312-win_amd64.pyd +0 -0
  148. multipers/multiparameter_module_approximation.pyx +218 -217
  149. multipers/pickle.py +90 -53
  150. multipers/plots.py +342 -334
  151. multipers/point_measure.cp312-win_amd64.pyd +0 -0
  152. multipers/point_measure.pyx +322 -320
  153. multipers/simplex_tree_multi.cp312-win_amd64.pyd +0 -0
  154. multipers/simplex_tree_multi.pxd +133 -133
  155. multipers/simplex_tree_multi.pyx +115 -48
  156. multipers/simplex_tree_multi.pyx.tp +1947 -1935
  157. multipers/slicer.cp312-win_amd64.pyd +0 -0
  158. multipers/slicer.pxd +281 -100
  159. multipers/slicer.pxd.tp +218 -214
  160. multipers/slicer.pyx +1570 -507
  161. multipers/slicer.pyx.tp +931 -914
  162. multipers/tensor/tensor.h +672 -672
  163. multipers/tensor.pxd +13 -13
  164. multipers/test.pyx +44 -44
  165. multipers/tests/__init__.py +57 -57
  166. multipers/torch/diff_grids.py +217 -217
  167. multipers/torch/rips_density.py +310 -304
  168. {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/LICENSE +21 -21
  169. {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/METADATA +21 -11
  170. multipers-2.3.1.dist-info/RECORD +182 -0
  171. {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/WHEEL +1 -1
  172. multipers/tests/test_diff_helper.py +0 -73
  173. multipers/tests/test_hilbert_function.py +0 -82
  174. multipers/tests/test_mma.py +0 -83
  175. multipers/tests/test_point_clouds.py +0 -49
  176. multipers/tests/test_python-cpp_conversion.py +0 -82
  177. multipers/tests/test_signed_betti.py +0 -181
  178. multipers/tests/test_signed_measure.py +0 -89
  179. multipers/tests/test_simplextreemulti.py +0 -221
  180. multipers/tests/test_slicer.py +0 -221
  181. multipers-2.2.3.dist-info/RECORD +0 -189
  182. {multipers-2.2.3.dist-info → multipers-2.3.1.dist-info}/top_level.txt +0 -0
multipers/data/MOL2.py CHANGED
@@ -1,458 +1,458 @@
1
- import os
2
- from os import listdir
3
- from os.path import expanduser
4
- from typing import Iterable
5
-
6
- import matplotlib.pyplot as plt
7
- import MDAnalysis as mda
8
- import numpy as np
9
- import pandas as pd
10
- from joblib import Parallel, delayed
11
- from MDAnalysis.topology.guessers import guess_masses
12
- from sklearn.base import BaseEstimator, TransformerMixin
13
- from sklearn.preprocessing import LabelEncoder
14
-
15
- # from numba import njit
16
- from tqdm import tqdm
17
-
18
- import multipers as mp
19
-
20
- DATASET_PATH = expanduser("~/Datasets/")
21
- JC_path = DATASET_PATH + "Cleves-Jain/"
22
- DUDE_path = DATASET_PATH + "DUD-E/"
23
-
24
-
25
- # pathes = get_data_path()
26
- # imgs = apply_pipeline(pathes=pathes, pipeline=pipeline_img)
27
- # distances_to_letter, ytest = img_distances(imgs)
28
-
29
-
30
- def _get_mols_in_path(folder):
31
- with open(folder + "/TargetList", "r") as f:
32
- train_data = [folder + "/" + mol.strip() for mol in f.readlines()]
33
- criterion = (
34
- lambda dataset: dataset.endswith(".mol2")
35
- and not dataset.startswith("final")
36
- and dataset not in train_data
37
- )
38
- test_data = [
39
- folder + "/" + dataset
40
- for dataset in listdir(folder)
41
- if criterion(folder + "/" + dataset)
42
- ]
43
- return train_data, test_data
44
-
45
-
46
- def get_data_path_JC(type="dict"):
47
- if type == "dict":
48
- out = {}
49
- elif type == "list":
50
- out = []
51
- else:
52
- raise TypeError(f"Type {out} not supported")
53
- for stuff in listdir(JC_path):
54
- if stuff.startswith("target_"):
55
- current_letter = stuff[-1]
56
- to_add = _get_mols_in_path(JC_path + stuff)
57
- if type == "dict":
58
- out[current_letter] = to_add
59
- elif type == "list":
60
- out.append(to_add)
61
- decoy_folder = JC_path + "RognanRing850/"
62
- to_add = [
63
- decoy_folder + mol for mol in listdir(decoy_folder) if mol.endswith(".mol2")
64
- ]
65
- if type == "dict":
66
- out["decoy"] = to_add
67
- elif type == "list":
68
- out.append(to_add)
69
- return out
70
-
71
-
72
- def get_all_JC_path():
73
- out = []
74
- for stuff in listdir(JC_path):
75
- if stuff.startswith("target_"):
76
- train_data, test_data = _get_mols_in_path(JC_path + stuff)
77
- out += train_data
78
- out += test_data
79
- decoy_folder = JC_path + "RognanRing850/"
80
- out += [
81
- decoy_folder + mol for mol in listdir(decoy_folder) if mol.endswith(".mol2")
82
- ]
83
- return out
84
-
85
-
86
- def split_multimol(
87
- path: str,
88
- mol_name: str,
89
- out_folder_name: str = "splitted",
90
- enforce_charges: bool = False,
91
- ):
92
- with open(path + mol_name, "r") as f:
93
- lines = f.readlines()
94
- splitted_mols = []
95
- index = 0
96
- for i, line in enumerate(lines):
97
- is_last = i == len(lines) - 1
98
- if line.strip() == "@<TRIPOS>MOLECULE" or is_last:
99
- if i != index:
100
- molecule = "".join(lines[index : i + is_last])
101
- if enforce_charges:
102
- # print(f"Replaced molecule {i}")
103
- molecule = molecule.replace("NO_CHARGES", "USER_CHARGES")
104
- # print(molecule)
105
- # return
106
- index = i
107
- splitted_mols.append(molecule)
108
- if not os.path.exists(path + out_folder_name):
109
- os.mkdir(path + out_folder_name)
110
- for i, mol in enumerate(splitted_mols):
111
- with open(path + out_folder_name + f"/{i}.mol2", "w") as f:
112
- f.write(mol)
113
- return [path + out_folder_name + f"/{i}.mol2" for i in range(len(splitted_mols))]
114
-
115
-
116
- # @njit(parallel=True)
117
- def apply_pipeline(pathes: dict, pipeline):
118
- img_dict = {}
119
- for key, value in tqdm(pathes.items(), desc="Applying pipeline"):
120
- if len(key) == 1:
121
- train_paths, test_paths = value
122
- train_imgs = pipeline.transform(train_paths)
123
- test_imgs = pipeline.transform(test_paths)
124
- img_dict[key] = (train_imgs, test_imgs)
125
- else:
126
- assert key == "decoy"
127
- img_dict[key] = pipeline.transform(value)
128
- return img_dict
129
-
130
-
131
- from sklearn.metrics import pairwise_distances
132
-
133
-
134
- def img_distances(img_dict: dict):
135
- distances_to_anchors = []
136
- ytest = []
137
- decoy_list = img_dict["decoy"]
138
- for letter, imgs in img_dict.items():
139
- if len(letter) != 1:
140
- continue # decoy
141
- xtrain, xtest = imgs
142
- assert len(xtest) > 0
143
- train_data, test_data = xtrain, np.concatenate([xtest, decoy_list])
144
- D = pairwise_distances(train_data, test_data)
145
- distances_to_anchors.append(D)
146
- letter_ytest = np.array(
147
- [letter] * len(xtest) + ["0"] * len(decoy_list), dtype="<U1"
148
- )
149
- ytest.append(letter_ytest)
150
- return distances_to_anchors, ytest
151
-
152
-
153
- def get_EF_vector_from_distances(distances, ytest, alpha=0.05):
154
- EF = []
155
- for distance_to_anchors, letter_ytest in zip(distances, ytest):
156
- indices = np.argsort(distance_to_anchors, axis=1)
157
- n = indices.shape[1]
158
- n_max = int(alpha * n)
159
- good_indices = (
160
- letter_ytest[indices[:, :n_max]] == letter_ytest[0]
161
- ) ## assumes that ytest[:,0] are the good letters
162
- EF_letter = good_indices.sum(axis=1) / (letter_ytest == letter_ytest[0]).sum()
163
- EF_letter /= alpha
164
- EF.append(EF_letter.mean())
165
- return np.mean(EF)
166
-
167
-
168
- def EF_from_distance_matrix(
169
- distances: np.ndarray, labels: list | np.ndarray, alpha: float, anchors_in_test=True
170
- ):
171
- """
172
- Computes the Enrichment Factor from a distance matrix, and its labels.
173
- - First axis of the distance matrix is the anchors on which to compute the EF
174
- - Second axis is the test. For convenience, anchors can be put in test, if the flag anchors_in_test is set to true.
175
- - labels is a table of bools, representing the the labels of the test axis of the distance matrix.
176
- - alpha : the EF alpha parameter.
177
- """
178
- n = len(labels)
179
- n_max = int(alpha * n)
180
- indices = np.argsort(distances, axis=1)
181
- EF_ = [
182
- ((labels[idx[:n_max]]).sum() - anchors_in_test)
183
- / (labels.sum() - anchors_in_test)
184
- for idx in indices
185
- ]
186
- return np.mean(EF_) / alpha
187
-
188
-
189
- def EF_AUC(distances: np.ndarray, labels: np.ndarray, anchors_in_test=0):
190
- if distances.ndim == 1:
191
- distances = distances[None, :]
192
- assert distances.ndim == 2
193
- indices = np.argsort(distances, axis=1)
194
- out = []
195
- for i in range(1, distances.size):
196
- proportion_of_good_indices = (
197
- labels[indices[:, :i]].sum(axis=1).mean() - anchors_in_test
198
- ) / min(i, labels.sum() - anchors_in_test)
199
- out.append(proportion_of_good_indices)
200
- # print(out)
201
- return np.mean(out)
202
-
203
-
204
- def theorical_max_EF(distances, labels, alpha):
205
- n = len(labels)
206
- n_max = int(alpha * n)
207
- num_true_labels = np.sum(
208
- labels == labels[0]
209
- ) ## if labels are not True / False, assumes that the first one is a good one
210
- return min(n_max, num_true_labels) / alpha
211
-
212
-
213
- def theorical_max_EF_from_distances(list_of_distances, list_of_labels, alpha):
214
- return np.mean(
215
- [
216
- theorical_max_EF(distances, labels, alpha)
217
- for distances, labels in zip(list_of_distances, list_of_labels)
218
- ]
219
- )
220
-
221
-
222
- def plot_EF_from_distances(
223
- alphas=[0.01, 0.02, 0.05, 0.1], EF=EF_from_distance_matrix, plot: bool = True
224
- ):
225
- y = np.round([EF(alpha=alpha) for alpha in alphas], decimals=2)
226
- if plot:
227
- _alphas = np.linspace(0.01, 1.0, 100)
228
- plt.figure()
229
- plt.plot(_alphas, [EF(alpha=alpha) for alpha in _alphas])
230
- plt.scatter(alphas, y, c="r")
231
- plt.title("Enrichment Factor")
232
- plt.xlabel(r"$\alpha$" + f" = {alphas}")
233
- plt.ylabel(r"$\mathrm{EF}_\alpha$" + f" = {y}")
234
- return y
235
-
236
-
237
- def lines2bonds(
238
- mol: mda.Universe, bond_types=["ar", "am", 3, 2, 1, 0], molecule_format=None
239
- ):
240
- extension = (
241
- mol.filename.split(".")[-1].lower()
242
- if molecule_format is None
243
- else molecule_format
244
- )
245
- match extension:
246
- case "mol2":
247
- out = lines2bonds_MOL2(mol)["bond_type"]
248
- case "pdb":
249
- out = lines2bonds_PDB(mol)
250
- case _:
251
- raise Exception("Invalid, or not supported molecule format.")
252
- return LabelEncoder().fit(bond_types).transform(out)
253
-
254
-
255
- def lines2bonds_MOL2(mol: mda.Universe):
256
- _lines = open(mol.filename, "r").readlines()
257
- out = []
258
- index = 0
259
- while index < len(_lines) and _lines[index].strip() != "@<TRIPOS>BOND":
260
- index += 1
261
- index += 1
262
- while index < len(_lines) and _lines[index].strip()[0] != "@":
263
- line = _lines[index].strip().split(" ")
264
- for j, truc in enumerate(line):
265
- line[j] = truc.strip()
266
- # try:
267
- out.append([stuff for stuff in line if len(stuff) > 0])
268
- # except:
269
- # print_lin
270
- index += 1
271
- out = pd.DataFrame(out, columns=["bond_id", "atom1", "atom2", "bond_type"])
272
- out.set_index(["bond_id"], inplace=True)
273
- return out
274
-
275
-
276
- def lines2bonds_PDB(mol: mda.Universe):
277
- raise Exception("Not yet implemented.")
278
- return
279
-
280
-
281
- def _mol2graphst(
282
- path: str | mda.Universe, filtrations: Iterable[str], molecule_format=None
283
- ):
284
- molecule = path if isinstance(path, mda.Universe) else mda.Universe(path)
285
-
286
- num_filtrations = len(filtrations)
287
- nodes = molecule.atoms.indices.reshape(1, -1)
288
- edges = molecule.bonds.dump_contents().T
289
- num_vertices = nodes.shape[1]
290
- num_edges = edges.shape[1]
291
-
292
- st = mp.SimplexTreeMulti(num_parameters=num_filtrations)
293
-
294
- ## Edges filtration
295
- # edges = np.array(bonds_df[["atom1", "atom2"]]).T
296
- edges_filtration = np.zeros((num_edges, num_filtrations), dtype=np.float32) - np.inf
297
- for i, filtration in enumerate(filtrations):
298
- match filtration:
299
- case "bond_length":
300
- bond_lengths = molecule.bonds.bonds()
301
- edges_filtration[:, i] = bond_lengths
302
- case "bond_type":
303
- bond_types = lines2bonds(mol=molecule, molecule_format=molecule_format)
304
- edges_filtration[:, i] = bond_types
305
- case _:
306
- pass
307
-
308
- ## Nodes filtration
309
- nodes_filtrations = np.zeros(
310
- (num_vertices, num_filtrations), dtype=np.float32
311
- ) + np.min(
312
- edges_filtration, axis=0
313
- ) # better than - np.inf
314
- st.insert_batch(nodes, nodes_filtrations)
315
-
316
- st.insert_batch(edges, edges_filtration)
317
- for i, filtration in enumerate(filtrations):
318
- match filtration:
319
- case "charge":
320
- charges = molecule.atoms.charges
321
- st.fill_lowerstar(charges, parameter=i)
322
- case "atomic_mass":
323
- masses = molecule.atoms.masses
324
- null_indices = masses == 0
325
- if np.any(null_indices): # guess if necessary
326
- masses[null_indices] = guess_masses(molecule.atoms.types)[
327
- null_indices
328
- ]
329
- st.fill_lowerstar(-masses, parameter=i)
330
- case _:
331
- pass
332
- st.make_filtration_non_decreasing() # Necessary ?
333
- return st
334
-
335
-
336
- def _mol2ripsst(
337
- path: str,
338
- filtrations: Iterable[str],
339
- threshold=np.inf,
340
- bond_types: list = ["ar", "am", 3, 2, 1, 0],
341
- ):
342
- import gudhi as gd
343
-
344
- assert "bond_length" == filtrations[0], "Bond length has to be first for rips."
345
- molecule = path if isinstance(path, mda.Universe) else mda.Universe(path)
346
- num_parameters = len(filtrations)
347
- st_rips = gd.RipsComplex(
348
- points=molecule.atoms.positions, max_edge_length=threshold
349
- ).create_simplex_tree()
350
- st = mp.SimplexTreeMulti(
351
- st_rips,
352
- num_parameters=num_parameters,
353
- default_values=[
354
- bond_types.index(0) if f == "bond_type" else -np.inf
355
- for f in filtrations[1:]
356
- ], # the 0 index is the label of 'no bond' in bond_types
357
- )
358
-
359
- ## Edges filtration
360
- mol_bonds = molecule.bonds.indices.T
361
- edges_filtration = (
362
- np.zeros((mol_bonds.shape[1], num_parameters), dtype=np.float32) - np.inf
363
- )
364
- for i, filtration in enumerate(filtrations):
365
- match filtration:
366
- case "bond_type":
367
- edges_filtration[:, i] = lines2bonds(
368
- mol=molecule, bond_types=bond_types
369
- )
370
- case "atomic_mass":
371
- continue
372
- case "charge":
373
- continue
374
- case "bond_length":
375
- edges_filtration[:, i] = [st_rips.filtration(s) for s in mol_bonds.T]
376
- case _:
377
- raise Exception(
378
- f"Invalid filtration {filtration}. Available ones : bond_type, atomic_mass, charge, bond_length."
379
- )
380
- st.assign_batch_filtration(mol_bonds, edges_filtration, propagate=False)
381
- min_filtration = edges_filtration.min(axis=0)
382
- st.assign_batch_filtration(
383
- np.asarray([list(range(st.num_vertices))], dtype=int),
384
- np.asarray([min_filtration] * st.num_vertices, dtype=np.float32),
385
- propagate=False,
386
- )
387
- ## Nodes filtration
388
- for i, filtration in enumerate(filtrations):
389
- match filtration:
390
- case "charge":
391
- charges = molecule.atoms.charges
392
- st.fill_lowerstar(charges, parameter=i)
393
- case "atomic_mass":
394
- masses = molecule.atoms.masses
395
- null_indices = masses == 0
396
- if np.any(null_indices): # guess if necessary
397
- masses[null_indices] = guess_masses(molecule.atoms.types)[
398
- null_indices
399
- ]
400
- # print(masses)
401
- st.fill_lowerstar(-masses, parameter=i)
402
- case _:
403
- pass
404
- st.make_filtration_non_decreasing() # Necessary ?
405
- return st
406
-
407
-
408
- class Molecule2SimplexTree(BaseEstimator, TransformerMixin):
409
- """
410
- Transforms a list of MDA-compatible files into a list of mulitparameter simplextrees
411
-
412
- Input
413
- -----
414
- X: Iterable[path_to_files:str]
415
-
416
- Output
417
- ------
418
- Iterable[multipers.SimplexTreeMulti]
419
-
420
- Parameters
421
- ----------
422
- - filtrations : list of filtration names. Available ones : 'charge', 'atomic_mass', 'bond_length', 'bond_type'. Others are ignored.
423
- - graph : bool. If true, will use the graph given by the molecule, otherwise, a Rips Complex Based on the distance. '
424
- In that case bond_length is ignored (it's the 1rst parameter).
425
- """
426
-
427
- def __init__(
428
- self,
429
- delayed: bool = False,
430
- filtrations: Iterable[str] = [],
431
- graph: bool = True,
432
- n_jobs: int = 1,
433
- ) -> None:
434
- super().__init__()
435
- self.delayed = delayed
436
- self.n_jobs = n_jobs
437
- self.filtrations = filtrations
438
- self.graph = graph
439
- self._molecule_format = None
440
- return
441
-
442
- def fit(self, X: Iterable[str], y=None):
443
- if len(X) == 0:
444
- return self
445
- test_mol = mda.Universe(X[0])
446
- self._molecule_format = test_mol.filename.split(".")[-1].lower()
447
- return self
448
-
449
- def transform(self, X: Iterable[str]):
450
- _to_simplextree = _mol2graphst if self.graph else _mol2ripsst
451
- to_simplex_tree = lambda path_to_mol2_file: [
452
- _to_simplextree(path=path_to_mol2_file, filtrations=self.filtrations)
453
- ]
454
- if self.delayed:
455
- return [delayed(to_simplex_tree)(path) for path in X]
456
- return Parallel(n_jobs=self.n_jobs, prefer="threads")(
457
- delayed(to_simplex_tree)(path) for path in X
458
- )
1
+ import os
2
+ from os import listdir
3
+ from os.path import expanduser
4
+ from typing import Iterable
5
+
6
+ import matplotlib.pyplot as plt
7
+ import MDAnalysis as mda
8
+ import numpy as np
9
+ import pandas as pd
10
+ from joblib import Parallel, delayed
11
+ from MDAnalysis.topology.guessers import guess_masses
12
+ from sklearn.base import BaseEstimator, TransformerMixin
13
+ from sklearn.preprocessing import LabelEncoder
14
+
15
+ # from numba import njit
16
+ from tqdm import tqdm
17
+
18
+ import multipers as mp
19
+
20
+ DATASET_PATH = expanduser("~/Datasets/")
21
+ JC_path = DATASET_PATH + "Cleves-Jain/"
22
+ DUDE_path = DATASET_PATH + "DUD-E/"
23
+
24
+
25
+ # pathes = get_data_path()
26
+ # imgs = apply_pipeline(pathes=pathes, pipeline=pipeline_img)
27
+ # distances_to_letter, ytest = img_distances(imgs)
28
+
29
+
30
+ def _get_mols_in_path(folder):
31
+ with open(folder + "/TargetList", "r") as f:
32
+ train_data = [folder + "/" + mol.strip() for mol in f.readlines()]
33
+ criterion = (
34
+ lambda dataset: dataset.endswith(".mol2")
35
+ and not dataset.startswith("final")
36
+ and dataset not in train_data
37
+ )
38
+ test_data = [
39
+ folder + "/" + dataset
40
+ for dataset in listdir(folder)
41
+ if criterion(folder + "/" + dataset)
42
+ ]
43
+ return train_data, test_data
44
+
45
+
46
+ def get_data_path_JC(type="dict"):
47
+ if type == "dict":
48
+ out = {}
49
+ elif type == "list":
50
+ out = []
51
+ else:
52
+ raise TypeError(f"Type {out} not supported")
53
+ for stuff in listdir(JC_path):
54
+ if stuff.startswith("target_"):
55
+ current_letter = stuff[-1]
56
+ to_add = _get_mols_in_path(JC_path + stuff)
57
+ if type == "dict":
58
+ out[current_letter] = to_add
59
+ elif type == "list":
60
+ out.append(to_add)
61
+ decoy_folder = JC_path + "RognanRing850/"
62
+ to_add = [
63
+ decoy_folder + mol for mol in listdir(decoy_folder) if mol.endswith(".mol2")
64
+ ]
65
+ if type == "dict":
66
+ out["decoy"] = to_add
67
+ elif type == "list":
68
+ out.append(to_add)
69
+ return out
70
+
71
+
72
+ def get_all_JC_path():
73
+ out = []
74
+ for stuff in listdir(JC_path):
75
+ if stuff.startswith("target_"):
76
+ train_data, test_data = _get_mols_in_path(JC_path + stuff)
77
+ out += train_data
78
+ out += test_data
79
+ decoy_folder = JC_path + "RognanRing850/"
80
+ out += [
81
+ decoy_folder + mol for mol in listdir(decoy_folder) if mol.endswith(".mol2")
82
+ ]
83
+ return out
84
+
85
+
86
+ def split_multimol(
87
+ path: str,
88
+ mol_name: str,
89
+ out_folder_name: str = "splitted",
90
+ enforce_charges: bool = False,
91
+ ):
92
+ with open(path + mol_name, "r") as f:
93
+ lines = f.readlines()
94
+ splitted_mols = []
95
+ index = 0
96
+ for i, line in enumerate(lines):
97
+ is_last = i == len(lines) - 1
98
+ if line.strip() == "@<TRIPOS>MOLECULE" or is_last:
99
+ if i != index:
100
+ molecule = "".join(lines[index : i + is_last])
101
+ if enforce_charges:
102
+ # print(f"Replaced molecule {i}")
103
+ molecule = molecule.replace("NO_CHARGES", "USER_CHARGES")
104
+ # print(molecule)
105
+ # return
106
+ index = i
107
+ splitted_mols.append(molecule)
108
+ if not os.path.exists(path + out_folder_name):
109
+ os.mkdir(path + out_folder_name)
110
+ for i, mol in enumerate(splitted_mols):
111
+ with open(path + out_folder_name + f"/{i}.mol2", "w") as f:
112
+ f.write(mol)
113
+ return [path + out_folder_name + f"/{i}.mol2" for i in range(len(splitted_mols))]
114
+
115
+
116
+ # @njit(parallel=True)
117
+ def apply_pipeline(pathes: dict, pipeline):
118
+ img_dict = {}
119
+ for key, value in tqdm(pathes.items(), desc="Applying pipeline"):
120
+ if len(key) == 1:
121
+ train_paths, test_paths = value
122
+ train_imgs = pipeline.transform(train_paths)
123
+ test_imgs = pipeline.transform(test_paths)
124
+ img_dict[key] = (train_imgs, test_imgs)
125
+ else:
126
+ assert key == "decoy"
127
+ img_dict[key] = pipeline.transform(value)
128
+ return img_dict
129
+
130
+
131
+ from sklearn.metrics import pairwise_distances
132
+
133
+
134
+ def img_distances(img_dict: dict):
135
+ distances_to_anchors = []
136
+ ytest = []
137
+ decoy_list = img_dict["decoy"]
138
+ for letter, imgs in img_dict.items():
139
+ if len(letter) != 1:
140
+ continue # decoy
141
+ xtrain, xtest = imgs
142
+ assert len(xtest) > 0
143
+ train_data, test_data = xtrain, np.concatenate([xtest, decoy_list])
144
+ D = pairwise_distances(train_data, test_data)
145
+ distances_to_anchors.append(D)
146
+ letter_ytest = np.array(
147
+ [letter] * len(xtest) + ["0"] * len(decoy_list), dtype="<U1"
148
+ )
149
+ ytest.append(letter_ytest)
150
+ return distances_to_anchors, ytest
151
+
152
+
153
def get_EF_vector_from_distances(distances, ytest, alpha=0.05):
    """Average enrichment factor over several (distance matrix, labels) pairs.

    Parameters
    ----------
    distances : iterable of 2D arrays
        One matrix per group; rows are anchors, columns are test elements.
    ytest : iterable of 1D arrays
        Labels of the columns of the matching matrix.  The first label of
        each array is taken as the "positive" label.
    alpha : float
        Fraction of the test set retained (closest elements first).

    Returns
    -------
    float
        Mean over groups of the mean-over-anchors enrichment factor.
    """
    per_group_ef = []
    for dist_matrix, group_labels in zip(distances, ytest):
        ranking = np.argsort(dist_matrix, axis=1)
        cutoff = int(alpha * ranking.shape[1])
        # Convention: the first test label is a positive one.
        reference = group_labels[0]
        hits = (group_labels[ranking[:, :cutoff]] == reference).sum(axis=1)
        num_positives = (group_labels == reference).sum()
        per_anchor_ef = hits / num_positives / alpha
        per_group_ef.append(per_anchor_ef.mean())
    return np.mean(per_group_ef)
166
+
167
+
168
+ def EF_from_distance_matrix(
169
+ distances: np.ndarray, labels: list | np.ndarray, alpha: float, anchors_in_test=True
170
+ ):
171
+ """
172
+ Computes the Enrichment Factor from a distance matrix, and its labels.
173
+ - First axis of the distance matrix is the anchors on which to compute the EF
174
+ - Second axis is the test. For convenience, anchors can be put in test, if the flag anchors_in_test is set to true.
175
+ - labels is a table of bools, representing the the labels of the test axis of the distance matrix.
176
+ - alpha : the EF alpha parameter.
177
+ """
178
+ n = len(labels)
179
+ n_max = int(alpha * n)
180
+ indices = np.argsort(distances, axis=1)
181
+ EF_ = [
182
+ ((labels[idx[:n_max]]).sum() - anchors_in_test)
183
+ / (labels.sum() - anchors_in_test)
184
+ for idx in indices
185
+ ]
186
+ return np.mean(EF_) / alpha
187
+
188
+
189
def EF_AUC(distances: np.ndarray, labels: np.ndarray, anchors_in_test=0):
    """Mean, over all cut-offs, of the fraction of positives retrieved.

    For every cut-off ``i`` in ``1 .. n_test - 1`` the ``i`` closest test
    elements of each anchor are kept; the number of positives among them is
    averaged over anchors and divided by ``min(i, number of positives)``.
    The returned value is the mean of these ratios.

    Parameters
    ----------
    distances : 1D or 2D array
        Rows are anchors, columns are test elements.  A 1D array is treated
        as a single anchor.
    labels : array of bools
        Labels of the test axis.
    anchors_in_test : int or bool
        Subtracted from both the hit count and the number of positives.

    Returns
    -------
    float
        ``nan`` (as returned by ``np.mean`` of an empty sequence) when there
        are fewer than two test elements.
    """
    labels = np.asarray(labels)
    if distances.ndim == 1:
        distances = distances[None, :]
    assert distances.ndim == 2
    # Cut-offs range over the test axis only.  Using ``distances.size`` adds
    # (n_anchors - 1) * n_test saturated cut-offs that bias the mean.
    num_test = distances.shape[1]
    ranking = np.argsort(distances, axis=1)
    num_positives = labels.sum() - anchors_in_test
    cutoffs = np.arange(1, num_test)
    # cumsum[:, i - 1] is the number of positives among the i closest
    # elements, so every cut-off is obtained in a single pass.
    hits_per_cutoff = np.cumsum(labels[ranking], axis=1).mean(axis=0)
    mean_hits = hits_per_cutoff[: len(cutoffs)] - anchors_in_test
    return np.mean(mean_hits / np.minimum(cutoffs, num_positives))
202
+
203
+
204
def theorical_max_EF(distances, labels, alpha):
    """Largest enrichment factor reachable for ``labels`` at level ``alpha``.

    The best possible ranking puts ``min(n_max, num_true_labels)`` positives
    among the ``n_max = int(alpha * n)`` retained elements.  The bound is
    normalised like the enrichment factor itself (hits / positives / alpha),
    so it never exceeds ``1 / alpha``.

    Parameters
    ----------
    distances
        Unused; kept so the signature mirrors the other EF helpers.
    labels : list or array
        Labels of the test elements.  If labels are not True / False, the
        first one is assumed to be a positive one.
    alpha : float
        Fraction of the test set retained.
    """
    labels = np.asarray(labels)  # a plain list would compare as a whole
    n_max = int(alpha * len(labels))
    # At least 1 for non-empty labels: the first element matches itself.
    num_true_labels = np.sum(labels == labels[0])
    # Divide by the number of positives, as the EF computation does.
    return min(n_max, num_true_labels) / num_true_labels / alpha
211
+
212
+
213
def theorical_max_EF_from_distances(list_of_distances, list_of_labels, alpha):
    """Average ``theorical_max_EF`` over paired distance matrices and labels.

    The two lists are walked in lockstep (``zip``), so extra items in the
    longer one are ignored.
    """
    per_group_max = []
    for distances, labels in zip(list_of_distances, list_of_labels):
        per_group_max.append(theorical_max_EF(distances, labels, alpha))
    return np.mean(per_group_max)
220
+
221
+
222
def plot_EF_from_distances(
    alphas=[0.01, 0.02, 0.05, 0.1], EF=EF_from_distance_matrix, plot: bool = True
):
    """Evaluate an enrichment-factor callable at ``alphas`` and optionally plot it.

    Parameters
    ----------
    alphas : list of float
        Levels at which the EF is reported (red dots on the plot).
    EF : callable
        Called as ``EF(alpha=...)`` only, so every other argument it needs
        has to be bound beforehand.
    plot : bool
        When true, also draws the EF curve on 100 evenly spaced levels in
        ``[0.01, 1]`` in a new matplotlib figure.

    Returns
    -------
    np.ndarray
        The EF values at ``alphas``, rounded to 2 decimals.
    """
    y = np.round([EF(alpha=level) for level in alphas], decimals=2)
    if not plot:
        return y

    grid = np.linspace(0.01, 1.0, 100)
    plt.figure()
    curve = [EF(alpha=level) for level in grid]
    plt.plot(grid, curve)
    plt.scatter(alphas, y, c="r")
    plt.title("Enrichment Factor")
    plt.xlabel(r"$\alpha$" + f" = {alphas}")
    plt.ylabel(r"$\mathrm{EF}_\alpha$" + f" = {y}")
    return y
235
+
236
+
237
def lines2bonds(
    mol: mda.Universe, bond_types=["ar", "am", 3, 2, 1, 0], molecule_format=None
):
    """Return the integer-encoded bond types of ``mol``.

    Parameters
    ----------
    mol : mda.Universe
        Molecule whose ``filename`` points to the file to parse.
    bond_types : list
        Vocabulary on which the ``LabelEncoder`` is fitted.
    molecule_format : str, optional
        ``"mol2"`` or ``"pdb"``.  When None, the lower-cased extension of
        ``mol.filename`` is used instead.

    Raises
    ------
    Exception
        If the format is neither ``"mol2"`` nor ``"pdb"``.
    """
    if molecule_format is None:
        extension = mol.filename.split(".")[-1].lower()
    else:
        extension = molecule_format

    if extension == "mol2":
        raw_types = lines2bonds_MOL2(mol)["bond_type"]
    elif extension == "pdb":
        raw_types = lines2bonds_PDB(mol)
    else:
        raise Exception("Invalid, or not supported molecule format.")

    encoder = LabelEncoder()
    encoder.fit(bond_types)
    return encoder.transform(raw_types)
253
+
254
+
255
def lines2bonds_MOL2(mol: "mda.Universe"):
    """Parse the ``@<TRIPOS>BOND`` section of the MOL2 file behind ``mol``.

    Parameters
    ----------
    mol : mda.Universe
        Only ``mol.filename`` is read.

    Returns
    -------
    pd.DataFrame
        Indexed by ``bond_id``, with string columns ``atom1``, ``atom2`` and
        ``bond_type``; one row per bond record.  Empty if the file has no
        bond section.
    """
    # Context manager so the handle is closed even if reading fails.
    with open(mol.filename, "r") as mol2_file:
        lines = [line.strip() for line in mol2_file]

    try:
        start = lines.index("@<TRIPOS>BOND") + 1
    except ValueError:  # no bond section at all
        start = len(lines)

    records = []
    for line in lines[start:]:
        if line.startswith("@"):  # next record type: end of the bond section
            break
        fields = line.split()  # any run of spaces or tabs separates fields
        if not fields:
            # Blank lines (typically at the end of the file) carry no bond.
            continue
        # A record may end with an optional status_bits field; keep the
        # four mandatory ones.
        records.append(fields[:4])

    out = pd.DataFrame(records, columns=["bond_id", "atom1", "atom2", "bond_type"])
    out.set_index(["bond_id"], inplace=True)
    return out
274
+
275
+
276
def lines2bonds_PDB(mol: "mda.Universe"):
    """Placeholder for bond-type extraction from PDB files.

    Raises
    ------
    NotImplementedError
        Always.  It is an ``Exception`` subclass, so callers that catch
        ``Exception`` behave as before.
    """
    raise NotImplementedError("Not yet implemented.")
279
+
280
+
281
def _mol2graphst(
    path: str | mda.Universe, filtrations: Iterable[str], molecule_format=None
):
    """Build a multiparameter simplex tree from the bond graph of a molecule.

    Parameters
    ----------
    path : str or mda.Universe
        An already loaded universe, or anything ``mda.Universe`` can open.
    filtrations : sequence of str
        One parameter per name.  ``"bond_length"`` and ``"bond_type"`` are
        set on the edges; ``"charge"`` and ``"atomic_mass"`` are set on the
        vertices through ``fill_lowerstar``.  Other names are ignored and
        leave their parameter untouched.
    molecule_format : str, optional
        Forwarded to ``lines2bonds`` for the ``"bond_type"`` filtration.

    Returns
    -------
    mp.SimplexTreeMulti
        Vertices are the atoms, edges the bonds.
    """
    universe = path if isinstance(path, mda.Universe) else mda.Universe(path)

    num_parameters = len(filtrations)
    vertex_array = universe.atoms.indices.reshape(1, -1)
    edge_array = universe.bonds.dump_contents().T
    n_vertices = vertex_array.shape[1]
    n_edges = edge_array.shape[1]

    st = mp.SimplexTreeMulti(num_parameters=num_parameters)

    # Edge filtration values; parameters not handled below stay at -inf.
    edge_values = np.zeros((n_edges, num_parameters), dtype=np.float32) - np.inf
    for column, name in enumerate(filtrations):
        if name == "bond_length":
            edge_values[:, column] = universe.bonds.bonds()
        elif name == "bond_type":
            edge_values[:, column] = lines2bonds(
                mol=universe, molecule_format=molecule_format
            )

    # Vertices enter at the parameter-wise minimum over the edges
    # (better than - np.inf).
    lowest_edge_values = np.min(edge_values, axis=0)
    vertex_values = (
        np.zeros((n_vertices, num_parameters), dtype=np.float32) + lowest_edge_values
    )
    st.insert_batch(vertex_array, vertex_values)
    st.insert_batch(edge_array, edge_values)

    # Vertex-driven parameters overwrite their column with a lower-star filtration.
    for column, name in enumerate(filtrations):
        if name == "charge":
            st.fill_lowerstar(universe.atoms.charges, parameter=column)
        elif name == "atomic_mass":
            masses = universe.atoms.masses
            missing = masses == 0
            if np.any(missing):  # guess if necessary
                masses[missing] = guess_masses(universe.atoms.types)[missing]
            # Negated, so that heavier atoms appear first.
            st.fill_lowerstar(-masses, parameter=column)
    st.make_filtration_non_decreasing()  # Necessary ?
    return st
334
+
335
+
336
def _mol2ripsst(
    path: str,
    filtrations: Iterable[str],
    threshold=np.inf,
    bond_types: list = ["ar", "am", 3, 2, 1, 0],
):
    """Build a multiparameter simplex tree from the Rips complex of a molecule.

    The Rips complex is built on the atom positions; the edges that are
    actual bonds of the molecule then get their own filtration values.

    Parameters
    ----------
    path : str or mda.Universe
        An already loaded universe, or anything ``mda.Universe`` can open.
    filtrations : sequence of str
        Must start with ``"bond_length"``.  Available ones : bond_type,
        atomic_mass, charge, bond_length.
    threshold : float
        ``max_edge_length`` of the Rips complex.
    bond_types : list
        Bond-type vocabulary forwarded to ``lines2bonds``.

    Raises
    ------
    AssertionError
        If the first filtration is not ``"bond_length"``.
    Exception
        On an unknown filtration name.
    """
    import gudhi as gd

    assert "bond_length" == filtrations[0], "Bond length has to be first for rips."
    universe = path if isinstance(path, mda.Universe) else mda.Universe(path)
    num_parameters = len(filtrations)
    rips = gd.RipsComplex(points=universe.atoms.positions, max_edge_length=threshold)
    st_rips = rips.create_simplex_tree()

    # One default per parameter after the first (the Rips scale itself).
    # The 0 index is the label of 'no bond' in bond_types.
    # NOTE(review): this is the *position* of 0 in bond_types, while
    # lines2bonds encodes bond types through a LabelEncoder -- confirm that
    # both encodings agree.
    defaults = []
    for name in filtrations[1:]:
        if name == "bond_type":
            defaults.append(bond_types.index(0))
        else:
            defaults.append(-np.inf)
    st = mp.SimplexTreeMulti(
        st_rips,
        num_parameters=num_parameters,
        default_values=defaults,
    )

    # Edges filtration: only the molecule's bonds are re-assigned.
    bond_array = universe.bonds.indices.T
    edge_values = (
        np.zeros((bond_array.shape[1], num_parameters), dtype=np.float32) - np.inf
    )
    for column, name in enumerate(filtrations):
        if name == "bond_type":
            edge_values[:, column] = lines2bonds(mol=universe, bond_types=bond_types)
        elif name == "atomic_mass" or name == "charge":
            continue  # vertex-driven, handled further down
        elif name == "bond_length":
            edge_values[:, column] = [st_rips.filtration(s) for s in bond_array.T]
        else:
            raise Exception(
                f"Invalid filtration {name}. Available ones : bond_type, atomic_mass, charge, bond_length."
            )
    st.assign_batch_filtration(bond_array, edge_values, propagate=False)

    # Every vertex gets the parameter-wise minimum over the bonds.
    lowest_edge_values = edge_values.min(axis=0)
    vertex_ids = np.asarray([list(range(st.num_vertices))], dtype=int)
    vertex_values = np.asarray(
        [lowest_edge_values] * st.num_vertices, dtype=np.float32
    )
    st.assign_batch_filtration(vertex_ids, vertex_values, propagate=False)

    # Nodes filtration
    for column, name in enumerate(filtrations):
        if name == "charge":
            st.fill_lowerstar(universe.atoms.charges, parameter=column)
        elif name == "atomic_mass":
            masses = universe.atoms.masses
            missing = masses == 0
            if np.any(missing):  # guess if necessary
                masses[missing] = guess_masses(universe.atoms.types)[missing]
            st.fill_lowerstar(-masses, parameter=column)
    st.make_filtration_non_decreasing()  # Necessary ?
    return st
406
+
407
+
408
class Molecule2SimplexTree(BaseEstimator, TransformerMixin):
    """
    Transforms a list of MDA-compatible files into multiparameter simplex trees.

    Input
    -----
    X: Iterable[path_to_files:str]

    Output
    ------
    One item per input path; each item is a one-element list holding the
    multipers.SimplexTreeMulti of that molecule (or, when `delayed` is set,
    the joblib delayed call that builds it).

    Parameters
    ----------
     - delayed : bool. If true, `transform` returns the delayed calls instead of running them.
     - filtrations : list of filtration names. Available ones : 'charge', 'atomic_mass', 'bond_length', 'bond_type'.
     - graph : bool. If true, uses the graph given by the molecule (`_mol2graphst`); otherwise a Rips complex
       built on the atom positions (`_mol2ripsst`), in which case 'bond_length' has to be the first filtration.
     - n_jobs : int. Number of joblib jobs (threads are preferred).
    """

    def __init__(
        self,
        delayed: bool = False,
        filtrations: Iterable[str] = [],
        graph: bool = True,
        n_jobs: int = 1,
    ) -> None:
        super().__init__()
        self.delayed, self.n_jobs = delayed, n_jobs
        self.filtrations, self.graph = filtrations, graph
        # Lower-cased extension of the first fitted file; set by `fit`.
        self._molecule_format = None

    def fit(self, X: Iterable[str], y=None):
        """Record the file extension of the first molecule of ``X`` (if any)."""
        if len(X) > 0:
            first_molecule = mda.Universe(X[0])
            self._molecule_format = first_molecule.filename.rsplit(".", 1)[-1].lower()
        return self

    def transform(self, X: Iterable[str]):
        """Build one simplex tree per path of ``X``, wrapped in a one-element list."""
        builder = _mol2graphst if self.graph else _mol2ripsst

        def to_simplex_tree(path_to_mol2_file):
            return [builder(path=path_to_mol2_file, filtrations=self.filtrations)]

        if self.delayed:
            return [delayed(to_simplex_tree)(path) for path in X]
        runner = Parallel(n_jobs=self.n_jobs, prefer="threads")
        return runner(delayed(to_simplex_tree)(path) for path in X)