multipers-1.0-cp311-cp311-manylinux_2_34_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of multipers might be problematic.

Files changed (56)
  1. multipers/__init__.py +4 -0
  2. multipers/_old_rank_invariant.pyx +328 -0
  3. multipers/_signed_measure_meta.py +72 -0
  4. multipers/data/MOL2.py +350 -0
  5. multipers/data/UCR.py +18 -0
  6. multipers/data/__init__.py +1 -0
  7. multipers/data/graphs.py +272 -0
  8. multipers/data/immuno_regions.py +27 -0
  9. multipers/data/minimal_presentation_to_st_bf.py +0 -0
  10. multipers/data/pytorch2simplextree.py +91 -0
  11. multipers/data/shape3d.py +101 -0
  12. multipers/data/synthetic.py +68 -0
  13. multipers/distances.py +100 -0
  14. multipers/euler_characteristic.cpython-311-x86_64-linux-gnu.so +0 -0
  15. multipers/euler_characteristic.pyx +132 -0
  16. multipers/function_rips.cpython-311-x86_64-linux-gnu.so +0 -0
  17. multipers/function_rips.pyx +101 -0
  18. multipers/hilbert_function.cpython-311-x86_64-linux-gnu.so +0 -0
  19. multipers/hilbert_function.pyi +46 -0
  20. multipers/hilbert_function.pyx +145 -0
  21. multipers/ml/__init__.py +0 -0
  22. multipers/ml/accuracies.py +61 -0
  23. multipers/ml/convolutions.py +384 -0
  24. multipers/ml/invariants_with_persistable.py +79 -0
  25. multipers/ml/kernels.py +128 -0
  26. multipers/ml/mma.py +422 -0
  27. multipers/ml/one.py +472 -0
  28. multipers/ml/point_clouds.py +191 -0
  29. multipers/ml/signed_betti.py +50 -0
  30. multipers/ml/signed_measures.py +1046 -0
  31. multipers/ml/sliced_wasserstein.py +313 -0
  32. multipers/ml/tools.py +99 -0
  33. multipers/multiparameter_edge_collapse.py +29 -0
  34. multipers/multiparameter_module_approximation.cpython-311-x86_64-linux-gnu.so +0 -0
  35. multipers/multiparameter_module_approximation.pxd +147 -0
  36. multipers/multiparameter_module_approximation.pyi +439 -0
  37. multipers/multiparameter_module_approximation.pyx +931 -0
  38. multipers/pickle.py +53 -0
  39. multipers/plots.py +207 -0
  40. multipers/point_measure_integration.cpython-311-x86_64-linux-gnu.so +0 -0
  41. multipers/point_measure_integration.pyx +59 -0
  42. multipers/rank_invariant.cpython-311-x86_64-linux-gnu.so +0 -0
  43. multipers/rank_invariant.pyx +154 -0
  44. multipers/simplex_tree_multi.cpython-311-x86_64-linux-gnu.so +0 -0
  45. multipers/simplex_tree_multi.pxd +121 -0
  46. multipers/simplex_tree_multi.pyi +715 -0
  47. multipers/simplex_tree_multi.pyx +1284 -0
  48. multipers/tensor.pxd +13 -0
  49. multipers/test.pyx +44 -0
  50. multipers-1.0.dist-info/LICENSE +21 -0
  51. multipers-1.0.dist-info/METADATA +9 -0
  52. multipers-1.0.dist-info/RECORD +56 -0
  53. multipers-1.0.dist-info/WHEEL +5 -0
  54. multipers-1.0.dist-info/top_level.txt +1 -0
  55. multipers.libs/libtbb-5d1cde94.so.12.10 +0 -0
  56. multipers.libs/libtbbmalloc-5e0a3d4c.so.2.10 +0 -0
multipers/data/MOL2.py ADDED
@@ -0,0 +1,350 @@
+ import numpy as np
+ from os.path import expanduser
+ import pandas as pd
+ from sklearn.preprocessing import LabelEncoder
+ from os import listdir
+ import os
+ import MDAnalysis as mda
+ import matplotlib.pyplot as plt
+ from MDAnalysis.topology.guessers import guess_masses
+ import multipers as mp
+ # from numba import njit
+ from tqdm import tqdm
+ from typing import Iterable
+ from joblib import Parallel, delayed
+ from sklearn.base import BaseEstimator, TransformerMixin
+
+
+ DATASET_PATH = expanduser("~/Datasets/")
+ JC_path = DATASET_PATH + "Cleves-Jain/"
+ DUDE_path = DATASET_PATH + "DUD-E/"
+
+
+ # pathes = get_data_path()
+ # imgs = apply_pipeline(pathes=pathes, pipeline=pipeline_img)
+ # distances_to_letter, ytest = img_distances(imgs)
+
+
+ def _get_mols_in_path(folder):
+     with open(folder + "/TargetList", "r") as f:
+         train_data = [folder + "/" + mol.strip() for mol in f.readlines()]
+     # keep .mol2 files that are neither "final*" files nor in the train set
+     criterion = lambda dataset: dataset.endswith(".mol2") and not os.path.basename(dataset).startswith("final") and dataset not in train_data
+     test_data = [folder + "/" + dataset for dataset in listdir(folder) if criterion(folder + "/" + dataset)]
+     return train_data, test_data
+
+ def get_data_path_JC(type="dict"):
+     if type == "dict": out = {}
+     elif type == "list": out = []
+     else: raise TypeError(f"Type {type} not supported")
+     for stuff in listdir(JC_path):
+         if stuff.startswith("target_"):
+             current_letter = stuff[-1]
+             to_add = _get_mols_in_path(JC_path + stuff)
+             if type == "dict": out[current_letter] = to_add
+             elif type == "list": out.append(to_add)
+     decoy_folder = JC_path + "RognanRing850/"
+     to_add = [decoy_folder + mol for mol in listdir(decoy_folder) if mol.endswith(".mol2")]
+     if type == "dict": out["decoy"] = to_add
+     elif type == "list": out.append(to_add)
+     return out
+
+ def get_all_JC_path():
+     out = []
+     for stuff in listdir(JC_path):
+         if stuff.startswith("target_"):
+             train_data, test_data = _get_mols_in_path(JC_path + stuff)
+             out += train_data
+             out += test_data
+     decoy_folder = JC_path + "RognanRing850/"
+     out += [decoy_folder + mol for mol in listdir(decoy_folder) if mol.endswith(".mol2")]
+     return out
+
+
+ def split_multimol(path: str, mol_name: str, out_folder_name: str = "splitted", enforce_charges: bool = False):
+     with open(path + mol_name, "r") as f:
+         lines = f.readlines()
+     splitted_mols = []
+     index = 0
+     for i, line in enumerate(lines):
+         is_last = i == len(lines) - 1
+         if line.strip() == "@<TRIPOS>MOLECULE" or is_last:
+             if i != index:
+                 molecule = "".join(lines[index:i + is_last])
+                 if enforce_charges:
+                     molecule = molecule.replace("NO_CHARGES", "USER_CHARGES")
+                 index = i
+                 splitted_mols.append(molecule)
+     if not os.path.exists(path + out_folder_name):
+         os.mkdir(path + out_folder_name)
+     for i, mol in enumerate(splitted_mols):
+         with open(path + out_folder_name + f"/{i}.mol2", "w") as f:
+             f.write(mol)
+     return [path + out_folder_name + f"/{i}.mol2" for i in range(len(splitted_mols))]
+
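As an aside, a hypothetical call to split_multimol, which cuts a multi-molecule MOL2 file on its @<TRIPOS>MOLECULE records and writes one file per molecule (the file name below is made up for illustration):

paths = split_multimol(JC_path + "RognanRing850/", "ring850.mol2")  # hypothetical multi-molecule file
# paths == [".../RognanRing850/splitted/0.mol2", ".../RognanRing850/splitted/1.mol2", ...]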
+ # @njit(parallel=True)
+ def apply_pipeline(pathes: dict, pipeline):
+     img_dict = {}
+     for key, value in tqdm(pathes.items(), desc="Applying pipeline"):
+         if len(key) == 1:
+             train_paths, test_paths = value
+             train_imgs = pipeline.transform(train_paths)
+             test_imgs = pipeline.transform(test_paths)
+             img_dict[key] = (train_imgs, test_imgs)
+         else:
+             assert key == "decoy"
+             img_dict[key] = pipeline.transform(value)
+     return img_dict
+
+ from sklearn.metrics import pairwise_distances
+
+ def img_distances(img_dict: dict):
+     distances_to_anchors = []
+     ytest = []
+     decoy_list = img_dict["decoy"]
+     for letter, imgs in img_dict.items():
+         if len(letter) != 1: continue  # decoy
+         xtrain, xtest = imgs
+         assert len(xtest) > 0
+         train_data, test_data = xtrain, np.concatenate([xtest, decoy_list])
+         D = pairwise_distances(train_data, test_data)
+         distances_to_anchors.append(D)
+         letter_ytest = np.array([letter] * len(xtest) + ['0'] * len(decoy_list), dtype="<U1")
+         ytest.append(letter_ytest)
+     return distances_to_anchors, ytest
+
+ def get_EF_vector_from_distances(distances, ytest, alpha=0.05):
+     EF = []
+     for distance_to_anchors, letter_ytest in zip(distances, ytest):
+         indices = np.argsort(distance_to_anchors, axis=1)
+         n = indices.shape[1]
+         n_max = int(alpha * n)
+         good_indices = (letter_ytest[indices[:, :n_max]] == letter_ytest[0])  # assumes that letter_ytest[0] is an active's label
+         EF_letter = good_indices.sum(axis=1) / (letter_ytest == letter_ytest[0]).sum()
+         EF_letter /= alpha
+         EF.append(EF_letter.mean())
+     return np.mean(EF)
+
+ def EF_from_distance_matrix(distances: np.ndarray, labels: list | np.ndarray, alpha: float, anchors_in_test=True):
+     """
+     Computes the Enrichment Factor from a distance matrix and its labels.
+     - The first axis of the distance matrix indexes the anchors on which to compute the EF.
+     - The second axis indexes the test set. For convenience, the anchors may be included in the test set; in that case, set the anchors_in_test flag to True.
+     - labels is an array of bools representing the labels of the test axis of the distance matrix.
+     - alpha: the EF alpha parameter.
+     """
+     n = len(labels)
+     n_max = int(alpha * n)
+     indices = np.argsort(distances, axis=1)
+     EF_ = [((labels[idx[:n_max]]).sum() - anchors_in_test) / (labels.sum() - anchors_in_test) for idx in indices]
+     return np.mean(EF_) / alpha
+
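As a quick sanity check of the convention above, a minimal, hypothetical example of EF_from_distance_matrix on synthetic data (all values below are made up for illustration):

import numpy as np

# 2 anchors, 10 test items; the first 3 test items are the actives.
rng = np.random.default_rng(0)
labels = np.array([True] * 3 + [False] * 7)
distances = rng.random((2, 10))
distances[:, :3] /= 10  # place the actives artificially close to both anchors

# With alpha=0.3, the top int(0.3*10)=3 neighbours of each anchor are inspected;
# they are all actives here, so the score reaches its maximum 1/alpha ≈ 3.33.
print(EF_from_distance_matrix(distances, labels, alpha=0.3, anchors_in_test=False))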
+ def EF_AUC(distances: np.ndarray, labels: np.ndarray, anchors_in_test=0):
+     if distances.ndim == 1:
+         distances = distances[None, :]
+     assert distances.ndim == 2
+     indices = np.argsort(distances, axis=1)
+     out = []
+     for i in range(1, distances.size):
+         proportion_of_good_indices = (labels[indices[:, :i]].sum(axis=1).mean() - anchors_in_test) / min(i, labels.sum() - anchors_in_test)
+         out.append(proportion_of_good_indices)
+     return np.mean(out)
+
+
+ def theorical_max_EF(distances, labels, alpha):
+     n = len(labels)
+     n_max = int(alpha * n)
+     num_true_labels = np.sum(labels == labels[0])  # if labels are not True/False, assumes that the first one is a good one
+     return min(n_max, num_true_labels) / alpha
+
+
+ def theorical_max_EF_from_distances(list_of_distances, list_of_labels, alpha):
+     return np.mean([theorical_max_EF(distances, labels, alpha) for distances, labels in zip(list_of_distances, list_of_labels)])
+
+ def plot_EF_from_distances(alphas=[0.01, 0.02, 0.05, 0.1], EF=EF_from_distance_matrix, plot: bool = True):
+     y = np.round([EF(alpha=alpha) for alpha in alphas], decimals=2)
+     if plot:
+         _alphas = np.linspace(0.01, 1., 100)
+         plt.figure()
+         plt.plot(_alphas, [EF(alpha=alpha) for alpha in _alphas])
+         plt.scatter(alphas, y, c='r')
+         plt.title("Enrichment Factor")
+         plt.xlabel(r"$\alpha$" + f" = {alphas}")
+         plt.ylabel(r"$\mathrm{EF}_\alpha$" + f" = {y}")
+     return y
+
+
+ def lines2bonds(mol: mda.Universe, bond_types=['ar', 'am', 3, 2, 1, 0], molecule_format=None):
+     extension = mol.filename.split('.')[-1].lower() if molecule_format is None else molecule_format
+     match extension:
+         case 'mol2':
+             out = lines2bonds_MOL2(mol)['bond_type']
+         case 'pdb':
+             out = lines2bonds_PDB(mol)
+         case _:
+             raise Exception('Invalid or unsupported molecule format.')
+     return LabelEncoder().fit(bond_types).transform(out)
+
+
+ def lines2bonds_MOL2(mol: mda.Universe):
+     with open(mol.filename, "r") as f:
+         _lines = f.readlines()
+     out = []
+     index = 0
+     # skip to the bond section of the MOL2 file
+     while index < len(_lines) and _lines[index].strip() != "@<TRIPOS>BOND":
+         index += 1
+     index += 1
+     # read bond records until the next section header
+     while index < len(_lines) and _lines[index].strip()[0] != "@":
+         line = _lines[index].strip().split(" ")
+         for j, truc in enumerate(line):
+             line[j] = truc.strip()
+         out.append([stuff for stuff in line if len(stuff) > 0])
+         index += 1
+     out = pd.DataFrame(out, columns=["bond_id", "atom1", "atom2", "bond_type"])
+     out.set_index(["bond_id"], inplace=True)
+     return out
+
+
+ def lines2bonds_PDB(mol: mda.Universe):
+     raise NotImplementedError('Not yet implemented.')
+
+
+ def _mol2graphst(path: str | mda.Universe, filtrations: Iterable[str], molecule_format=None):
+     molecule = path if isinstance(path, mda.Universe) else mda.Universe(path)
+
+     num_filtrations = len(filtrations)
+     nodes = molecule.atoms.indices.reshape(1, -1)
+     edges = molecule.bonds.dump_contents().T
+     num_vertices = nodes.shape[1]
+     num_edges = edges.shape[1]
+
+     st = mp.SimplexTreeMulti(num_parameters=num_filtrations)
+
+     ## Edges filtration
+     edges_filtration = np.zeros((num_edges, num_filtrations), dtype=np.float32) - np.inf
+     for i, filtration in enumerate(filtrations):
+         match filtration:
+             case "bond_length":
+                 bond_lengths = molecule.bonds.bonds()
+                 edges_filtration[:, i] = bond_lengths
+             case "bond_type":
+                 bond_types = lines2bonds(mol=molecule, molecule_format=molecule_format)
+                 edges_filtration[:, i] = bond_types
+             case _:
+                 pass
+
+     ## Nodes filtration
+     nodes_filtrations = np.zeros((num_vertices, num_filtrations), dtype=np.float32) + np.min(edges_filtration, axis=0)  # better than -np.inf
+     st.insert_batch(nodes, nodes_filtrations)
+
+     st.insert_batch(edges, edges_filtration)
+     for i, filtration in enumerate(filtrations):
+         match filtration:
+             case "charge":
+                 charges = molecule.atoms.charges
+                 st.fill_lowerstar(charges, parameter=i)
+             case "atomic_mass":
+                 masses = molecule.atoms.masses
+                 null_indices = masses == 0
+                 if np.any(null_indices):  # guess if necessary
+                     masses[null_indices] = guess_masses(molecule.atoms.types)[null_indices]
+                 st.fill_lowerstar(-masses, parameter=i)
+             case _:
+                 pass
+     st.make_filtration_non_decreasing()  # Necessary?
+     return st
+
+
+ def _mol2ripsst(path: str, filtrations: Iterable[str], threshold=np.inf, bond_types: list = ['ar', 'am', 3, 2, 1, 0]):
+     import gudhi as gd
+     assert filtrations[0] == 'bond_length', "Bond length has to be first for Rips."
+     molecule = path if isinstance(path, mda.Universe) else mda.Universe(path)
+     num_parameters = len(filtrations)
+     st_rips = gd.RipsComplex(points=molecule.atoms.positions, max_edge_length=threshold).create_simplex_tree()
+     st = mp.SimplexTreeMulti(st_rips, num_parameters=num_parameters,
+         default_values=[bond_types.index(0) if f == "bond_type" else -np.inf for f in filtrations[1:]])  # the 0 index is the label of 'no bond' in bond_types
+
+     ## Edges filtration
+     mol_bonds = molecule.bonds.indices.T
+     edges_filtration = np.zeros((mol_bonds.shape[1], num_parameters), dtype=np.float32) - np.inf
+     for i, filtration in enumerate(filtrations):
+         match filtration:
+             case "bond_type":
+                 edges_filtration[:, i] = lines2bonds(mol=molecule, bond_types=bond_types)
+             case "atomic_mass":
+                 continue
+             case "charge":
+                 continue
+             case 'bond_length':
+                 edges_filtration[:, i] = [st_rips.filtration(s) for s in mol_bonds.T]
+             case _:
+                 raise Exception(f"Invalid filtration {filtration}. Available ones: bond_type, atomic_mass, charge, bond_length.")
+     st.assign_batch_filtration(mol_bonds, edges_filtration, propagate=False)
+     min_filtration = edges_filtration.min(axis=0)
+     st.assign_batch_filtration(np.asarray([list(range(st.num_vertices))], dtype=int), np.asarray([min_filtration] * st.num_vertices, dtype=np.float32), propagate=False)
+     ## Nodes filtration
+     for i, filtration in enumerate(filtrations):
+         match filtration:
+             case "charge":
+                 charges = molecule.atoms.charges
+                 st.fill_lowerstar(charges, parameter=i)
+             case "atomic_mass":
+                 masses = molecule.atoms.masses
+                 null_indices = masses == 0
+                 if np.any(null_indices):  # guess if necessary
+                     masses[null_indices] = guess_masses(molecule.atoms.types)[null_indices]
+                 st.fill_lowerstar(-masses, parameter=i)
+             case _:
+                 pass
+     st.make_filtration_non_decreasing()  # Necessary?
+     return st
+
+
+ class Molecule2SimplexTree(BaseEstimator, TransformerMixin):
+     """
+     Transforms a list of MDA-compatible files into a list of multiparameter simplex trees.
+
+     Input
+     -----
+     X: Iterable[path_to_files:str]
+
+     Output
+     ------
+     Iterable[multipers.SimplexTreeMulti]
+
+     Parameters
+     ----------
+     - filtrations : list of filtration names. Available ones: 'charge', 'atomic_mass', 'bond_length', 'bond_type'. Others are ignored.
+     - graph : bool. If True, uses the graph given by the molecule; otherwise, a Rips complex based on the distance.
+       In that case, the 'bond_length' filtration is ignored: the Rips distance takes the first parameter slot.
+     """
+     def __init__(self,
+             delayed: bool = False,
+             filtrations: Iterable[str] = [],
+             graph: bool = True,
+             n_jobs: int = 1) -> None:
+         super().__init__()
+         self.delayed = delayed
+         self.n_jobs = n_jobs
+         self.filtrations = filtrations
+         self.graph = graph
+         self._molecule_format = None
+         return
+
+     def fit(self, X: Iterable[str], y=None):
+         if len(X) == 0: return self
+         test_mol = mda.Universe(X[0])
+         self._molecule_format = test_mol.filename.split('.')[-1].lower()
+         return self
+
+     def transform(self, X: Iterable[str]):
+         _to_simplextree = _mol2graphst if self.graph else _mol2ripsst
+         to_simplex_tree = lambda path_to_mol2_file: _to_simplextree(path=path_to_mol2_file, filtrations=self.filtrations)
+         if self.delayed:
+             return [delayed(to_simplex_tree)(path) for path in X]
+         return Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(to_simplex_tree)(path) for path in X)
+
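A minimal usage sketch of the Molecule2SimplexTree transformer defined above, under the assumption that an MDAnalysis-readable .mol2 file exists at the (hypothetical) path below:

m2st = Molecule2SimplexTree(filtrations=["bond_length", "bond_type"], graph=True)
paths = [JC_path + "target_a/0.mol2"]  # hypothetical file
simplex_trees = m2st.fit(paths).transform(paths)  # one SimplexTreeMulti per input file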
multipers/data/UCR.py ADDED
@@ -0,0 +1,18 @@
+ import numpy as np
+ from os.path import expanduser
+ import pandas as pd
+ from sklearn.preprocessing import LabelEncoder
+
+ def get(dataset: str = "UCR/Coffee", test: bool = False, DATASET_PATH: str = expanduser("~/Datasets/"), dim=3, delay=1, skip=1):
+     from gudhi.point_cloud.timedelay import TimeDelayEmbedding
+     dataset_path = DATASET_PATH + dataset + "/" + dataset[4:]
+     dataset_path += "_TEST.tsv" if test else "_TRAIN.tsv"
+     data = np.array(pd.read_csv(dataset_path, delimiter='\t', header=None, index_col=None))
+     Y = LabelEncoder().fit_transform(data[:, 0])
+     data = data[:, 1:]
+     tde = TimeDelayEmbedding(dim=dim, delay=delay, skip=skip).transform(data)
+     return tde, Y
+
+ def get_train(*args, **kwargs):
+     return get(*args, **kwargs, test=False)
+
+ def get_test(*args, **kwargs):
+     return get(*args, **kwargs, test=True)
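For context, get above vectorizes each UCR time series with Gudhi's time-delay embedding. A hypothetical call, assuming the UCR archive is unpacked under ~/Datasets/UCR/:

point_clouds, labels = get(dataset="UCR/Coffee", dim=3, delay=1, skip=1)
# point_clouds[i] is the delay embedding of series i: one 3-dimensional point per window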
multipers/data/__init__.py ADDED
@@ -0,0 +1 @@
+ from .synthetic import *
multipers/data/graphs.py ADDED
@@ -0,0 +1,272 @@
+ import numpy as np
+ from os.path import expanduser, exists
+ import networkx as nx
+ from warnings import warn
+ import pickle
+ from joblib import Parallel, delayed
+ from tqdm import tqdm
+ from sklearn.preprocessing import LabelEncoder
+ from scipy.spatial import distance_matrix
+ from sklearn.base import BaseEstimator, TransformerMixin, clone
+ import multipers as mp
+ from typing import Iterable
+
+ DATASET_PATH = expanduser("~/Datasets/")
+
+
+ def get(dataset: str, filtration: str):
+     graphs, labels = get_graphs(dataset)
+     try:
+         for g in graphs:
+             for node in g.nodes:
+                 g.nodes[node][filtration]
+     except KeyError:
+         print(f"Filtration {filtration} not computed, trying to compute it ...", flush=True)
+         compute_filtration(dataset, filtration)
+     return get_graphs(dataset)
+
+
+ def get_from_file_old(dataset: str, label="lb"):
+     from os import walk
+     from scipy.io import loadmat
+     path = DATASET_PATH + dataset + "/mat/"
+     labels: list[int] = []
+     gs: list[nx.Graph] = []
+     for root, dir, files in walk(path):
+         for file in files:
+             file_ppties = file.split("_")
+             gid = file_ppties[5]
+             i = 0
+             while i + 1 < len(file_ppties) and file_ppties[i] != label:
+                 i += 1
+             if i + 1 >= len(file_ppties):
+                 warn(f"Cannot find label {label} on file {file}.")
+             else:
+                 labels += [file_ppties[i + 1]]
+                 adj_mat = np.array(loadmat(path + file)['A'], dtype=np.float32)
+                 gs.append(nx.Graph(adj_mat))
+     return gs, labels
+
+
+ def get_from_file(dataset: str):
+     path = DATASET_PATH + f"{dataset}/{dataset[7:]}."
+     try:
+         graphs_ids = np.loadtxt(path + "graph_idx")
+     except OSError:
+         return get_from_file_old(dataset=dataset)
+     labels: list[int] = LabelEncoder().fit_transform(np.loadtxt(path + "graph_labels"))
+     edges = np.loadtxt(path + "edges", delimiter=',', dtype=int) - 1
+     has_intrinsic_filtration = exists(path + "node_attrs")
+     graphs: list[nx.Graph] = []
+     if has_intrinsic_filtration:
+         F = np.loadtxt(path + "node_attrs", delimiter=',')
+     for graph_id in tqdm(np.unique(graphs_ids), desc="Reading graphs from file"):
+         nodes, = np.where(graphs_ids == graph_id)
+         def graph_has_edge(u: int, v: int) -> bool:
+             if u in nodes or v in nodes:
+                 assert u in nodes and v in nodes, f"Nodes {u} and {v} are not in the same graph"
+                 return True
+             return False
+         graph_edges = [(u, v) for u, v in edges if graph_has_edge(u, v)]
+         g = nx.Graph(graph_edges)
+         if has_intrinsic_filtration:
+             node_attrs = {node: F[node] for node in nodes}
+             nx.set_node_attributes(g, node_attrs, "intrinsic")
+         graphs.append(g)
+     return graphs, labels
+
+
+ def get_graphs(dataset: str, N: int | str = "") -> tuple[list[nx.Graph], list[int]]:
+     graphs_path = f"{DATASET_PATH}{dataset}/graphs{N}.pkl"
+     labels_path = f"{DATASET_PATH}{dataset}/labels{N}.pkl"
+     if not exists(graphs_path) or not exists(labels_path):
+         if dataset.startswith("3dshapes/"):
+             return get_from_file_old(dataset)
+         graphs, labels = get_from_file(dataset)
+         print("Saving graphs at:", graphs_path)
+         set_graphs(graphs=graphs, labels=labels, dataset=dataset)
+     else:
+         graphs = pickle.load(open(graphs_path, "rb"))
+         labels = pickle.load(open(labels_path, "rb"))
+     return graphs, LabelEncoder().fit_transform(labels)
+
+
+ def set_graphs(graphs: list[nx.Graph], labels: list, dataset: str, N: int | str = ""):  # saves graphs (and filtration values) into a file
+     graphs_path = f"{DATASET_PATH}{dataset}/graphs{N}.pkl"
+     labels_path = f"{DATASET_PATH}{dataset}/labels{N}.pkl"
+     pickle.dump(graphs, open(graphs_path, "wb"))
+     pickle.dump(labels, open(labels_path, "wb"))
+     return
+
+ def reset_graphs(dataset: str, N=None):  # resets filtration values on graphs
+     graphs, labels = get_from_file(dataset)
+     set_graphs(graphs, labels, dataset)
+     return
+
+
+ def compute_ricci(graphs: list[nx.Graph], alpha=0.5, progress=1):
+     from GraphRicciCurvature.OllivierRicci import OllivierRicci
+     def ricci(graph, alpha=alpha):
+         return OllivierRicci(graph, alpha=alpha).compute_ricci_curvature()
+     graphs = [ricci(g) for g in tqdm(graphs, disable=not progress, desc="Computing ricci")]
+     def push_back_node(graph):
+         # push the edge curvature back onto nodes: each node gets the min curvature of its incident edges
+         node_filtrations = {
+             node: -1 if len(graph[node]) == 0 else np.min([graph[node][node2]['ricciCurvature'] for node2 in graph[node]])
+             for node in graph.nodes
+         }
+         nx.set_node_attributes(graph, node_filtrations, "ricciCurvature")
+         return graph
+     graphs = [push_back_node(g) for g in graphs]
+     return graphs
+
+ def compute_cc(graphs: list[nx.Graph], progress=1):
+     def _cc(g):
+         cc = nx.closeness_centrality(g)
+         nx.set_node_attributes(g, cc, "cc")
+         edges_cc = {(u, v): max(cc[u], cc[v]) for u, v in g.edges}
+         nx.set_edge_attributes(g, edges_cc, "cc")
+         return g
+     graphs = Parallel(n_jobs=1, prefer="threads")(delayed(_cc)(g) for g in tqdm(graphs, disable=not progress, desc="Computing cc"))
+     return graphs
+
+ def compute_degree(graphs: list[nx.Graph], progress=1):
+     def _degree(g):
+         degrees = {i: 1.1 if degree == 0 else 1 / degree for i, degree in g.degree}
+         nx.set_node_attributes(g, degrees, "degree")
+         edges_dg = {(u, v): max(degrees[u], degrees[v]) for u, v in g.edges}
+         nx.set_edge_attributes(g, edges_dg, "degree")
+         return g
+     graphs = Parallel(n_jobs=1, prefer="threads")(delayed(_degree)(g) for g in tqdm(graphs, disable=not progress, desc="Computing degree"))
+     return graphs
+
+ def compute_fiedler(graphs: list[nx.Graph], progress=1):  # TODO: make it compatible with non-connected graphs
+     def _fiedler(g):
+         connected_graphs = [nx.subgraph(g, nodes) for nodes in nx.connected_components(g)]
+         fiedler_vectors = [nx.fiedler_vector(h)**2 if h.number_of_nodes() > 2 else np.zeros(h.number_of_nodes()) for h in connected_graphs]  # the order of nx.fiedler_vector corresponds to nx.laplacian -> g.nodes
+         fiedler_dict = {
+             node: fiedler_vector[node_index]
+             for h, fiedler_vector in zip(connected_graphs, fiedler_vectors)
+             for node_index, node in enumerate(list(h.nodes))
+         }
+         nx.set_node_attributes(g, fiedler_dict, "fiedler")
+         edges_fiedler = {(u, v): max(fiedler_dict[u], fiedler_dict[v]) for u, v in g.edges}
+         nx.set_edge_attributes(g, edges_fiedler, "fiedler")
+         return g
+     graphs = Parallel(n_jobs=1, prefer="threads")(delayed(_fiedler)(g) for g in tqdm(graphs, disable=not progress, desc="Computing fiedler"))
+     return graphs
+
+ def compute_hks(graphs: list[nx.Graph], t: float, progress=1):
+     def _hks(g: nx.Graph):
+         # eigendecomposition of the normalized graph Laplacian; order is given by g.nodes order
+         w, vps = np.linalg.eig(nx.laplacianmatrix.normalized_laplacian_matrix(g, nodelist=g.nodes()).toarray())
+         w = w.view(dtype=float)
+         vps = vps.view(dtype=float)
+         node_hks = {node: np.sum(np.exp(-t * w) * np.square(vps[node_index, :])) for node_index, node in enumerate(g.nodes)}
+         nx.set_node_attributes(g, node_hks, f"hks_{t}")
+         edges_hks = {(u, v): max(node_hks[u], node_hks[v]) for u, v in g.edges}
+         nx.set_edge_attributes(g, edges_hks, f"hks_{t}")
+         return g
+     graphs = Parallel(n_jobs=1, prefer="threads")(delayed(_hks)(g) for g in tqdm(graphs, disable=not progress, desc=f"Computing hks_{t}"))
+     return graphs
+
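For reference, _hks above evaluates the heat kernel signature of each node: writing (λ_j, φ_j) for the eigenpairs of the normalized graph Laplacian, HKS_t(v) = Σ_j exp(-t·λ_j)·φ_j(v)². Each edge then receives the maximum of its two endpoint values, which keeps the sublevel-set filtration consistent.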
+ def compute_geodesic(graphs: list[nx.Graph], progress=1):
+     def _f(g: nx.Graph):
+         try:
+             nodes_intrinsic = {i: n["intrinsic"] for i, n in g.nodes.data()}
+         except KeyError:
+             warn("This graph doesn't have an intrinsic filtration, will use 0 instead ...")
+             nodes_intrinsic = {i: 0 for i, n in g.nodes.data()}
+         node_geodesic = {i: 0 for i in g.nodes}
+         nx.set_node_attributes(g, node_geodesic, "geodesic")
+         edges_geodesic = {(u, v): np.linalg.norm(nodes_intrinsic[u] - nodes_intrinsic[v]) for u, v in g.edges}
+         nx.set_edge_attributes(g, edges_geodesic, "geodesic")
+         return g
+     graphs = Parallel(n_jobs=1, prefer="threads")(delayed(_f)(g) for g in tqdm(graphs, disable=not progress, desc="Computing geodesic distances on graphs"))
+     return graphs
+
+ def compute_filtration(dataset: str, filtration: str, **kwargs):
+     if filtration == "ALL":
+         reset_graphs(dataset)  # not necessary
+         graphs, labels = get_graphs(dataset, **kwargs)
+         graphs = compute_geodesic(graphs)
+         graphs = compute_cc(graphs)
+         graphs = compute_degree(graphs)
+         graphs = compute_ricci(graphs)
+         graphs = compute_fiedler(graphs)
+         graphs = compute_hks(graphs, 10)
+         set_graphs(graphs=graphs, labels=labels, dataset=dataset)
+         return
+     graphs, labels = get_graphs(dataset, **kwargs)
+     if filtration == "dijkstra":
+         return
+     elif filtration == "cc":
+         graphs = compute_cc(graphs)
+     elif filtration == "degree":
+         graphs = compute_degree(graphs)
+     elif filtration == "ricciCurvature":
+         graphs = compute_ricci(graphs)
+     elif filtration == "fiedler":
+         graphs = compute_fiedler(graphs)
+     elif filtration == "geodesic":
+         graphs = compute_geodesic(graphs)
+     elif filtration.startswith('hks_'):
+         t = int(filtration[4:])  # don't want to deal with floats; they put dots in the title...
+         graphs = compute_hks(graphs=graphs, t=t)
+     else:
+         warn(f"Filtration {filtration} not implemented!")
+         return
+     set_graphs(graphs=graphs, labels=labels, dataset=dataset)
+     return
+
+
+ class Graph2SimplexTrees(BaseEstimator, TransformerMixin):
+     """
+     Transforms a list of networkx graphs into a list of multiparameter simplex trees.
+
+     Usual Filtrations
+     -----------------
+     - "cc" closeness centrality
+     - "geodesic" if the graph provides data to compute it, e.g., BZR, COX2, PROTEINS
+     - "degree"
+     - "ricciCurvature" the Ricci curvature
+     - "fiedler" the square of the Fiedler vector
+
+     A usage sketch follows the class definition below.
+     """
+     def __init__(self, filtrations: Iterable[str] = [], delayed=False, num_collapses=100, progress: bool = False):
+         super().__init__()
+         self.filtrations = filtrations  # filtrations to look up in the graph's node/edge attributes
+         self.delayed = delayed
+         self.num_collapses = num_collapses
+         self.progress = progress
+
+     def fit(self, X, y=None):
+         return self
+
+     def transform(self, X: list[nx.Graph]):
+         def todo(graph, filtrations=self.filtrations) -> mp.SimplexTreeMulti:
+             st = mp.SimplexTreeMulti(num_parameters=len(filtrations))
+             nodes = np.asarray(graph.nodes, dtype=int).reshape(1, -1)
+             nodes_filtrations = np.asarray([[graph.nodes[node][filtration] for filtration in filtrations] for node in graph.nodes], dtype=np.float32)
+             st.insert_batch(nodes, nodes_filtrations)
+             edges = np.asarray(graph.edges, dtype=int).T
+             edges_filtrations = np.asarray([[graph[u][v][filtration] for filtration in filtrations] for u, v in graph.edges], dtype=np.float32)
+             st.insert_batch(edges, edges_filtrations)
+             if st.num_parameters == 2: st.collapse_edges(num=self.num_collapses)  # TODO: wait for a filtration-domination update
+             # st.make_filtration_non_decreasing()  ## Ricci is not safe ...
+             return [st]  # same output shape for each pipeline; some have a supplementary axis
+         return [delayed(todo)(graph) for graph in X] if self.delayed else Parallel(n_jobs=-1, prefer="threads")(delayed(todo)(graph) for graph in tqdm(X, desc="Computing simplextrees from graphs", disable=not self.progress))
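A minimal usage sketch of Graph2SimplexTrees; the dataset name is hypothetical, and the requested filtrations must already be stored as node and edge attributes (e.g. via the compute_* helpers above):

gs, labels = get("graphs/BZR", filtration="cc")  # hypothetical dataset under ~/Datasets/
gs = compute_degree(gs)  # adds the "degree" attribute in place
sts = Graph2SimplexTrees(filtrations=["cc", "degree"]).fit_transform(gs)
# sts[i] is a singleton list holding the SimplexTreeMulti of graph i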
multipers/data/immuno_regions.py ADDED
@@ -0,0 +1,27 @@
+ import numpy as np
+ from pandas import read_csv
+ from os.path import expanduser
+ from os import walk
+ from sklearn.preprocessing import LabelEncoder
+
+
+ def get(DATASET_PATH=expanduser("~/Datasets/")):
+     DATASET_PATH += "1.5mmRegions/"
+     X, labels = [], []
+     for label in ["FoxP3", "CD8", "CD68"]:
+         for root, dirs, files in walk(DATASET_PATH + label + "/"):
+             for name in files:
+                 X.append(np.array(read_csv(DATASET_PATH + label + "/" + name)) / 1500)  # rescaled
+                 labels.append(label)
+     return X, LabelEncoder().fit_transform(np.array(labels))
+
+ def get_immuno(i=1, DATASET_PATH=expanduser("~/Datasets/")):
+     immu_dataset = read_csv(DATASET_PATH + f"LargeHypoxicRegion{i}.csv")
+     X = np.array(immu_dataset['x'])
+     X /= np.max(X)
+     Y = np.array(immu_dataset['y'])
+     Y /= np.max(Y)
+     labels = LabelEncoder().fit_transform(immu_dataset['Celltype'])
+     return np.asarray([X, Y]).T, labels
multipers/data/minimal_presentation_to_st_bf.py
File without changes