grasp-tool 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_tool/__init__.py +17 -0
- grasp_tool/__main__.py +6 -0
- grasp_tool/cli/__init__.py +1 -0
- grasp_tool/cli/main.py +793 -0
- grasp_tool/cli/train_moco.py +778 -0
- grasp_tool/gnn/__init__.py +1 -0
- grasp_tool/gnn/embedding.py +165 -0
- grasp_tool/gnn/gat_moco_final.py +990 -0
- grasp_tool/gnn/graphloader.py +1748 -0
- grasp_tool/gnn/plot_refined.py +1556 -0
- grasp_tool/preprocessing/__init__.py +1 -0
- grasp_tool/preprocessing/augumentation.py +66 -0
- grasp_tool/preprocessing/cellplot.py +475 -0
- grasp_tool/preprocessing/filter.py +171 -0
- grasp_tool/preprocessing/network.py +79 -0
- grasp_tool/preprocessing/partition.py +654 -0
- grasp_tool/preprocessing/portrait.py +1862 -0
- grasp_tool/preprocessing/register.py +1021 -0
- grasp_tool-0.1.0.dist-info/METADATA +511 -0
- grasp_tool-0.1.0.dist-info/RECORD +22 -0
- grasp_tool-0.1.0.dist-info/WHEEL +4 -0
- grasp_tool-0.1.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,1748 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
from torch_geometric.data import Data
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
from torch_geometric.nn import GCNConv
|
|
9
|
+
import numpy as np
|
|
10
|
+
import random
|
|
11
|
+
from sklearn.cluster import KMeans
|
|
12
|
+
import warnings
|
|
13
|
+
from . import embedding as emb
|
|
14
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
15
|
+
from joblib import Parallel, delayed
|
|
16
|
+
from multiprocessing import Pool, cpu_count
|
|
17
|
+
import multiprocessing
|
|
18
|
+
from multiprocessing import Pool, cpu_count
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _build_count_position_graph(nodes_file, adj_file, cell, gene):
    """Assemble one PyG ``Data`` graph from a node-matrix / adjacency-matrix CSV pair.

    Returns ``None`` when either CSV is missing. Node features are a 12-dim
    nonlinear embedding of the raw ``count`` column concatenated with a one-hot
    encoding of ``nuclear_position`` (inside/outside/boundary/edge, unknown -> 4).
    Edges come from the nonzero entries of the adjacency CSV.

    NOTE(review): assumes columns 1-5 of the node matrix include ``count`` and
    ``nuclear_position`` — the usecols indices come from the upstream writer;
    TODO confirm against the partition/portrait preprocessing output.
    """
    if not (os.path.exists(nodes_file) and os.path.exists(adj_file)):
        return None

    features = pd.read_csv(nodes_file, usecols=[1, 2, 3, 4, 5])
    features["nuclear_position"] = (
        features["nuclear_position"]
        .map({"inside": 0, "outside": 1, "boundary": 2, "edge": 3})
        .fillna(4)
        .astype(int)
    )
    # 12-dim nonlinear embedding of the raw molecule count per node.
    embedded = pd.DataFrame(
        features["count"]
        .apply(lambda c: emb.nonlinear_transform_embedding(c, dim=12))
        .tolist(),
        columns=[f"dim_{i}" for i in range(12)],
    )
    one_hot = pd.get_dummies(features["nuclear_position"], prefix="pos").astype(int)
    x = torch.tensor(pd.concat([embedded, one_hot], axis=1).values, dtype=torch.float)

    adjacency = pd.read_csv(adj_file)
    edge_index = torch.tensor(np.array(adjacency.values.nonzero()), dtype=torch.long)
    return Data(x=x, edge_index=edge_index, cell=cell, gene=gene)


def process_cell_gene(
    cell, gene, dataset, path, n_sectors, m_rings, k_neighbor, base_path
):
    """Load the original and augmented graphs for one (cell, gene) pair.

    The pair is processed only when ``{base_path}/{gene}_distances_filtered.csv``
    exists and contains a row matching this gene/cell with ``location == "other"``;
    otherwise two empty lists are returned. Missing node/adjacency CSVs likewise
    cause an early return (possibly with only the original graph collected).

    Returns:
        tuple[list[Data], list[Data]]: ``(graphs, aug_graphs)`` — each list holds
        at most one graph for this pair.
    """
    graphs = []
    aug_graphs = []

    raw_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}"
    aug_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}_aug"

    df_file = f"{base_path}/{gene}_distances_filtered.csv"
    if not os.path.exists(df_file):
        return graphs, aug_graphs

    distances = pd.read_csv(df_file)
    keep = (
        (distances["gene"] == gene)
        & (distances["cell"] == cell)
        & (distances["location"] == "other")
    )
    if distances[keep].empty:
        return graphs, aug_graphs

    # Original graph.
    graph = _build_count_position_graph(
        f"{raw_path}/{gene}_node_matrix.csv",
        f"{raw_path}/{gene}_adj_matrix.csv",
        cell,
        gene,
    )
    if graph is None:
        return graphs, aug_graphs
    graphs.append(graph)

    # Augmented graph; if its files are absent the original is still returned.
    aug_graph = _build_count_position_graph(
        f"{aug_path}/{gene}_node_matrix.csv",
        f"{aug_path}/{gene}_adj_matrix.csv",
        cell,
        gene,
    )
    if aug_graph is None:
        return graphs, aug_graphs
    aug_graphs.append(aug_graph)

    return graphs, aug_graphs
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def generate_graph_data_parallel(
    dataset,
    cell_list,
    gene_list,
    path,
    base_path,
    n_sectors,
    m_rings,
    k_neighbor,
    n_jobs=4,
):
    """Build (graph, augmented-graph) lists for every cell/gene pair in parallel.

    Each pair is dispatched to :func:`process_cell_gene` through joblib's
    ``Parallel``; a tqdm bar tracks submission progress. The per-pair result
    lists are flattened into two output lists.

    Returns:
        tuple[list, list]: ``(graphs, aug_graphs)``.
    """
    pairs = [(c, g) for c in cell_list for g in gene_list]

    per_pair = Parallel(n_jobs=n_jobs)(
        delayed(process_cell_gene)(
            c, g, dataset, path, n_sectors, m_rings, k_neighbor, base_path
        )
        for c, g in tqdm(
            pairs,
            desc="Processing cells and genes",
            total=len(pairs),
        )
    )

    # Flatten the per-pair (graphs, aug_graphs) tuples.
    graphs, aug_graphs = [], []
    for g_list, ag_list in per_pair:
        graphs += g_list
        aug_graphs += ag_list

    return graphs, aug_graphs
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
## simulated_data1 / simulated_data2
|
|
157
|
+
def _build_ratio_graph(nodes_file, adj_file, cell, gene):
    """Assemble one PyG ``Data`` graph using count *ratios* as the count feature.

    Returns ``None`` when either CSV is missing. Each node's ``count`` is first
    normalized by the graph's total count, then passed through the 12-dim
    nonlinear embedding; ``nuclear_position`` is one-hot encoded
    (inside/outside/boundary/edge, unknown -> 4).

    NOTE(review): assumes columns 1-5 of the node matrix include ``count`` and
    ``nuclear_position`` — TODO confirm against the preprocessing writer.
    """
    if not (os.path.exists(nodes_file) and os.path.exists(adj_file)):
        return None

    features = pd.read_csv(nodes_file, usecols=[1, 2, 3, 4, 5])
    # Per-node share of the graph's total molecule count.
    features["count_ratio"] = features["count"] / features["count"].sum()
    features["nuclear_position"] = (
        features["nuclear_position"]
        .map({"inside": 0, "outside": 1, "boundary": 2, "edge": 3})
        .fillna(4)
        .astype(int)
    )
    embedded = pd.DataFrame(
        features["count_ratio"]
        .apply(lambda c: emb.nonlinear_transform_embedding(c, dim=12))
        .tolist(),
        columns=[f"dim_{i}" for i in range(12)],
    )
    one_hot = pd.get_dummies(features["nuclear_position"], prefix="pos").astype(int)
    x = torch.tensor(pd.concat([embedded, one_hot], axis=1).values, dtype=torch.float)

    edge_index = torch.tensor(
        np.array(pd.read_csv(adj_file).values.nonzero()), dtype=torch.long
    )
    return Data(x=x, edge_index=edge_index, cell=cell, gene=gene)


## simulated_data1 / simulated_data2
def process_cell_gene_nofiltered(
    cell, gene, dataset, path, n_sectors, m_rings, k_neighbor
):
    """Load original and augmented graphs for one (cell, gene) pair, no distance filter.

    For the ``merfish_u2os`` / ``seqfish_fibroblast`` datasets, an optional
    filter step skips pairs listed in ``all_low_points_list.csv`` or
    ``all_no_points_list.csv`` under ``$GRASP_FILTER_ROOT/<dataset>/``; when the
    env var is unset or the files are absent, no filtering happens.

    Returns:
        tuple[list[Data], list[Data]]: ``(graphs, aug_graphs)`` — each list holds
        at most one graph; missing CSVs cause an early return.
    """
    graphs = []
    aug_graphs = []

    raw_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}"
    aug_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}_aug"

    if dataset in ("merfish_u2os", "seqfish_fibroblast"):
        # Optional filtering lists (historically hard-coded); only applied when
        # GRASP_FILTER_ROOT is set AND both list files exist.
        filter_root = os.environ.get("GRASP_FILTER_ROOT")
        if filter_root:
            low_list_path = os.path.join(filter_root, dataset, "all_low_points_list.csv")
            none_list_path = os.path.join(filter_root, dataset, "all_no_points_list.csv")

            if os.path.exists(low_list_path) and os.path.exists(none_list_path):
                low_list = pd.read_csv(low_list_path)
                none_list = pd.read_csv(none_list_path)

                # Skip the pair if it appears in either filter list.
                flagged = (
                    (low_list["cell"] == cell) & (low_list["gene"] == gene)
                ).any() or (
                    (none_list["cell"] == cell) & (none_list["gene"] == gene)
                ).any()
                if flagged:
                    print(f"Skipping {gene} in {cell} (found in filter list).")
                    return graphs, aug_graphs

    # Original graph.
    graph = _build_ratio_graph(
        f"{raw_path}/{gene}_node_matrix.csv",
        f"{raw_path}/{gene}_adj_matrix.csv",
        cell,
        gene,
    )
    if graph is None:
        return graphs, aug_graphs
    graphs.append(graph)

    # Augmented graph; if its files are absent the original is still returned.
    aug_graph = _build_ratio_graph(
        f"{aug_path}/{gene}_node_matrix.csv",
        f"{aug_path}/{gene}_adj_matrix.csv",
        cell,
        gene,
    )
    if aug_graph is None:
        return graphs, aug_graphs
    aug_graphs.append(aug_graph)

    return graphs, aug_graphs
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def generate_graph_data_parallel_nofiltered(
    dataset, cell_list, gene_list, path, n_sectors, m_rings, k_neighbor, n_jobs=4
):
    """Parallel driver for :func:`process_cell_gene_nofiltered`.

    Dispatches every (cell, gene) pair to ``process_cell_gene_nofiltered``
    via joblib's ``Parallel`` (tqdm shows submission progress) and flattens
    the per-pair result lists.

    Returns:
        tuple[list, list]: ``(graphs, aug_graphs)``.

    NOTE(review): the original body carried ~90 lines of unreachable code
    after the return statement — a stale copy of a target-weight loop that
    referenced an undefined ``df``. It has been deleted; since it could
    never execute, behavior is unchanged.
    """
    all_cells_genes = [(cell, gene) for cell in cell_list for gene in gene_list]

    # Run in parallel with progress.
    results = Parallel(n_jobs=n_jobs)(
        delayed(process_cell_gene_nofiltered)(
            cell, gene, dataset, path, n_sectors, m_rings, k_neighbor
        )
        for cell, gene in tqdm(
            all_cells_genes,
            desc="Processing cells and genes",
            total=len(all_cells_genes),
        )
    )

    # Merge the per-pair (graphs, aug_graphs) tuples.
    graphs = []
    aug_graphs = []
    for g, ag in results:
        graphs.extend(g)
        aug_graphs.extend(ag)
    return graphs, aug_graphs
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def generate_graph_data_target_weight(
    dataset,
    df,
    path,
    n_sectors,
    m_rings,
    k_neighbor,
    knn_k_for_similarity_graph=None,
    knn_k_ratio=0.25,
    knn_k_min=3,
):
    """Build feature-similarity KNN graphs for every (cell, gene) row of ``df``.

    For each row, the node matrix CSV is read, virtual nodes are dropped, and
    node features are assembled from an 11-dim embedding of the per-node count
    ratio, a one-hot ``is_edge`` (edge_0/edge_1), and a one-hot
    ``nuclear_position`` (inside/outside/boundary). Edges connect each node to
    its top-k most cosine-similar neighbours, then the adjacency is
    symmetrized. The same pipeline runs on the ``_aug`` variant.

    k selection: an explicit positive int ``knn_k_for_similarity_graph`` is
    respected but capped at ``num_nodes - 1``; otherwise k adapts as
    ``max(knn_k_min, int(num_nodes * knn_k_ratio))`` with the same cap.

    Fix vs. original: both node-matrix reads are now wrapped in
    ``try/except FileNotFoundError`` so a missing CSV skips the row instead of
    crashing the whole run — consistent with
    ``generate_graph_data_target_weight2`` and the skip-on-missing convention
    of the other loaders in this module.

    Returns:
        tuple[list[Data], list[Data]]: ``(graphs, aug_graphs)``. Lists may be
        unpaired when an augmented file is missing or has no real nodes (the
        original graph is kept in that case).
    """
    graphs = []
    aug_graphs = []

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing Graphs"):
        cell = row["cell"]
        gene = row["gene"]

        raw_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}"
        aug_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}_aug"

        # Original graph
        nodes_file = f"{raw_path}/{gene}_node_matrix.csv"

        try:
            node_data_df = pd.read_csv(
                nodes_file, usecols=["count", "is_edge", "nuclear_position", "is_virtual"]
            )
        except FileNotFoundError:
            # Missing node matrix: skip this (cell, gene) pair.
            continue

        # --- Filter out virtual nodes ---
        if "is_virtual" in node_data_df.columns:
            node_data_df = node_data_df[node_data_df["is_virtual"] == 0].reset_index(
                drop=True
            )
        # --- End filtering ---

        if node_data_df.empty:  # No real nodes left after removing virtual ones.
            continue

        total_count = node_data_df["count"].sum()
        # total_count == 0 should not happen once only real counted nodes
        # remain, but guard against division by zero anyway.
        if total_count == 0:
            node_data_df["count_ratio"] = 0
        else:
            node_data_df["count_ratio"] = node_data_df["count"] / total_count

        count_features_embedded = (
            node_data_df["count_ratio"]
            .apply(lambda x: emb.nonlinear_transform_embedding(x, dim=11))
            .tolist()
        )
        count_features_final_df = pd.DataFrame(
            count_features_embedded,
            columns=[f"count_dim_{i}" for i in range(11)],
            index=node_data_df.index,  # use the filtered df's index
        )

        is_edge_one_hot_df = pd.get_dummies(
            node_data_df["is_edge"], prefix="edge", dtype=int
        )
        # Ensure consistent columns for is_edge (edge_0, edge_1).
        for val in [0, 1]:
            col_name = f"edge_{val}"
            if col_name not in is_edge_one_hot_df.columns:
                is_edge_one_hot_df[col_name] = 0
        is_edge_one_hot_df = is_edge_one_hot_df[
            is_edge_one_hot_df.columns.intersection(["edge_0", "edge_1"])
        ]  # keep only the expected columns
        if "edge_0" not in is_edge_one_hot_df.columns:
            is_edge_one_hot_df["edge_0"] = 0
        if "edge_1" not in is_edge_one_hot_df.columns:
            is_edge_one_hot_df["edge_1"] = 0
        is_edge_one_hot_df = is_edge_one_hot_df[["edge_0", "edge_1"]]

        nuclear_position_one_hot_df = pd.get_dummies(
            node_data_df["nuclear_position"], prefix="pos", dtype=int
        )
        expected_pos_categories = ["inside", "outside", "boundary"]
        for cat in expected_pos_categories:
            col_name = f"pos_{cat}"
            if col_name not in nuclear_position_one_hot_df.columns:
                nuclear_position_one_hot_df[col_name] = 0
        # Ensure correct order and presence of all expected columns.
        nuclear_position_one_hot_df = nuclear_position_one_hot_df.reindex(
            columns=[f"pos_{cat}" for cat in expected_pos_categories], fill_value=0
        )

        final_node_features_df = pd.concat(
            [count_features_final_df, is_edge_one_hot_df, nuclear_position_one_hot_df],
            axis=1,
        )

        # Virtual nodes were removed above, so no feature zeroing is needed.

        node_features_tensor = torch.tensor(
            final_node_features_df.values, dtype=torch.float
        )

        num_nodes = node_features_tensor.shape[0]
        # Default: no edges (covers num_nodes <= 1).
        edge_index = torch.empty(
            (2, 0), dtype=torch.long, device=node_features_tensor.device
        )

        if num_nodes > 1:
            # Cosine similarity via L2-normalized dot products.
            norm_node_features = F.normalize(node_features_tensor, p=2, dim=1)
            similarity_matrix = torch.matmul(norm_node_features, norm_node_features.t())

            # -inf on the diagonal keeps self-loops out of the top-k.
            similarity_matrix.fill_diagonal_(float("-inf"))

            # --- Dynamically determine k (see docstring) ---
            if (
                isinstance(knn_k_for_similarity_graph, int)
                and knn_k_for_similarity_graph > 0
            ):
                current_knn_k = min(knn_k_for_similarity_graph, num_nodes - 1)
            else:
                adaptive_k = max(knn_k_min, int(num_nodes * knn_k_ratio))
                current_knn_k = min(adaptive_k, num_nodes - 1)

            if current_knn_k > 0:
                # Top-k most similar neighbours per node.
                _, top_k_indices = torch.topk(
                    similarity_matrix, k=current_knn_k, dim=1
                )

                # Boolean adjacency from the top-k indices.
                adj = torch.zeros_like(similarity_matrix, dtype=torch.bool)
                row_indices = (
                    torch.arange(num_nodes, device=node_features_tensor.device)
                    .unsqueeze(1)
                    .expand(-1, current_knn_k)
                )
                adj[row_indices, top_k_indices] = True

                # Symmetrize the adjacency matrix.
                adj = adj | adj.t()

                edge_index = adj.nonzero(as_tuple=False).t().contiguous()

        graph = Data(
            x=node_features_tensor, edge_index=edge_index, cell=cell, gene=gene
        )
        graphs.append(graph)

        # Augmented graph
        nodes_file_aug = f"{aug_path}/{gene}_node_matrix.csv"
        try:
            aug_node_data_df = pd.read_csv(
                nodes_file_aug,
                usecols=["count", "is_edge", "nuclear_position", "is_virtual"],
            )
        except FileNotFoundError:
            # The original graph stays in `graphs`; lists may end up unpaired
            # (same caveat as the empty-aug case below).
            continue

        # --- Filter out virtual nodes for augmented graph ---
        if "is_virtual" in aug_node_data_df.columns:
            aug_node_data_df = aug_node_data_df[
                aug_node_data_df["is_virtual"] == 0
            ].reset_index(drop=True)
        # --- End filtering ---

        if aug_node_data_df.empty:
            # The original graph was already appended; we skip only the
            # augmented one, so the two lists may be unpaired. If strict
            # pairing is required, the original would need removal here.
            continue

        aug_total_count = aug_node_data_df["count"].sum()
        if aug_total_count == 0:
            aug_node_data_df["count_ratio"] = 0
        else:
            aug_node_data_df["count_ratio"] = (
                aug_node_data_df["count"] / aug_total_count
            )

        aug_count_features_embedded = (
            aug_node_data_df["count_ratio"]
            .apply(lambda x: emb.nonlinear_transform_embedding(x, dim=11))
            .tolist()
        )
        aug_count_features_final_df = pd.DataFrame(
            aug_count_features_embedded,
            columns=[f"count_dim_{i}" for i in range(11)],
            index=aug_node_data_df.index,  # use the filtered df's index
        )

        aug_is_edge_one_hot_df = pd.get_dummies(
            aug_node_data_df["is_edge"], prefix="edge", dtype=int
        )
        for val in [0, 1]:
            col_name = f"edge_{val}"
            if col_name not in aug_is_edge_one_hot_df.columns:
                aug_is_edge_one_hot_df[col_name] = 0
        aug_is_edge_one_hot_df = aug_is_edge_one_hot_df[
            aug_is_edge_one_hot_df.columns.intersection(["edge_0", "edge_1"])
        ]
        if "edge_0" not in aug_is_edge_one_hot_df.columns:
            aug_is_edge_one_hot_df["edge_0"] = 0
        if "edge_1" not in aug_is_edge_one_hot_df.columns:
            aug_is_edge_one_hot_df["edge_1"] = 0
        aug_is_edge_one_hot_df = aug_is_edge_one_hot_df[["edge_0", "edge_1"]]

        aug_nuclear_position_one_hot_df = pd.get_dummies(
            aug_node_data_df["nuclear_position"], prefix="pos", dtype=int
        )
        for cat in expected_pos_categories:
            col_name = f"pos_{cat}"
            if col_name not in aug_nuclear_position_one_hot_df.columns:
                aug_nuclear_position_one_hot_df[col_name] = 0
        aug_nuclear_position_one_hot_df = aug_nuclear_position_one_hot_df.reindex(
            columns=[f"pos_{cat}" for cat in expected_pos_categories], fill_value=0
        )

        aug_final_node_features_df = pd.concat(
            [
                aug_count_features_final_df,
                aug_is_edge_one_hot_df,
                aug_nuclear_position_one_hot_df,
            ],
            axis=1,
        )

        aug_node_features_tensor = torch.tensor(
            aug_final_node_features_df.values, dtype=torch.float
        )

        num_aug_nodes = aug_node_features_tensor.shape[0]
        aug_edge_index = torch.empty(
            (2, 0), dtype=torch.long, device=aug_node_features_tensor.device
        )

        if num_aug_nodes > 1:
            norm_aug_node_features = F.normalize(aug_node_features_tensor, p=2, dim=1)
            aug_similarity_matrix = torch.matmul(
                norm_aug_node_features, norm_aug_node_features.t()
            )
            aug_similarity_matrix.fill_diagonal_(float("-inf"))

            # Same adaptive-k logic for augmented graphs.
            if (
                isinstance(knn_k_for_similarity_graph, int)
                and knn_k_for_similarity_graph > 0
            ):
                current_knn_k_aug = min(knn_k_for_similarity_graph, num_aug_nodes - 1)
            else:
                adaptive_k_aug = max(knn_k_min, int(num_aug_nodes * knn_k_ratio))
                current_knn_k_aug = min(adaptive_k_aug, num_aug_nodes - 1)

            if current_knn_k_aug > 0:
                _, aug_top_k_indices = torch.topk(
                    aug_similarity_matrix, k=current_knn_k_aug, dim=1
                )
                aug_adj = torch.zeros_like(aug_similarity_matrix, dtype=torch.bool)
                aug_row_indices = (
                    torch.arange(num_aug_nodes, device=aug_node_features_tensor.device)
                    .unsqueeze(1)
                    .expand(-1, current_knn_k_aug)
                )
                aug_adj[aug_row_indices, aug_top_k_indices] = True
                aug_adj = aug_adj | aug_adj.t()

                aug_edge_index = aug_adj.nonzero(as_tuple=False).t().contiguous()

        aug_graph = Data(
            x=aug_node_features_tensor, edge_index=aug_edge_index, cell=cell, gene=gene
        )
        aug_graphs.append(aug_graph)
    return graphs, aug_graphs
|
|
660
|
+
|
|
661
|
+
|
|
662
|
+
def generate_graph_data_target_weight2(
    dataset, df, path, n_sectors, m_rings, k_neighbor
):
    """Load aligned (original, augmented) graphs for each (cell, gene) in ``df``.

    Node features are 16-dimensional per node: a 12-dim sinusoidal embedding of
    the node's count ratio (``emb.nonlinear_transform_embedding``) concatenated
    with a 4-dim one-hot of ``nuclear_position``. Feature rows of virtual nodes
    (``is_virtual == 1``) are zeroed out.

    BUG FIX: the previous version computed the feature tensors but left the
    adjacency loading and ``Data`` construction commented out, so it always
    returned two empty lists. Graph construction is restored here, following
    the commented-out code and the sibling loaders in this module.

    Args:
        dataset: unused here; kept for interface compatibility with siblings.
        df: dataframe with at least ``cell`` and ``gene`` columns.
        path: root directory containing per-cell partition folders.
        n_sectors, m_rings, k_neighbor: partition parameters used to build the
            per-cell folder names.

    Returns:
        (graphs, aug_graphs): two aligned lists of torch_geometric ``Data``
        objects; a pair is appended only when both variants load successfully.
    """
    graphs = []
    aug_graphs = []

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing Graphs"):
        cell = row["cell"]
        gene = row["gene"]

        raw_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}"
        aug_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}_aug"

        graph = _load_weight2_graph(raw_path, gene, cell)
        if graph is None:
            continue
        aug_graph = _load_weight2_graph(aug_path, gene, cell)
        if aug_graph is None:
            # Keep the two lists aligned: drop the pair when either side fails.
            continue

        graphs.append(graph)
        aug_graphs.append(aug_graph)

    return graphs, aug_graphs


def _load_weight2_graph(base_path, gene, cell):
    """Build one ``Data`` graph from ``{base_path}/{gene}_*.csv`` files.

    Returns ``None`` when a file is missing, required columns are absent, or
    the node matrix is empty — callers skip the whole (cell, gene) pair.
    """
    nodes_file = f"{base_path}/{gene}_node_matrix.csv"
    adj_file = f"{base_path}/{gene}_adj_matrix.csv"

    try:
        # Read the full CSV and validate columns afterwards; a usecols= read
        # would raise an uncaught ValueError when a column is missing.
        node_features_df = pd.read_csv(nodes_file)
    except FileNotFoundError:
        return None
    if (
        node_features_df.empty
        or "count" not in node_features_df.columns
        or "nuclear_position" not in node_features_df.columns
    ):
        return None
    if "is_virtual" not in node_features_df.columns:
        node_features_df["is_virtual"] = 0

    # Count-ratio embedding (12 dims). Guard the all-zero-count case so the
    # division does not produce NaN/inf.
    total_count = node_features_df["count"].sum()
    node_features_df["count_ratio"] = (
        0 if total_count == 0 else node_features_df["count"] / total_count
    )
    count_features_list = (
        node_features_df["count_ratio"]
        .apply(lambda x: emb.nonlinear_transform_embedding(x, dim=12))
        .tolist()
    )
    count_features_df = pd.DataFrame(
        count_features_list,
        columns=[f"dim_{i}" for i in range(12)],
        index=node_features_df.index,
    )

    # Nuclear-position one-hot (4 dims). fillna(4) keeps unmapped labels from
    # crashing astype(); the extra pos_4 column is dropped by the reindex.
    node_features_df["nuclear_position_mapped"] = (
        node_features_df["nuclear_position"]
        .map({"inside": 0, "outside": 1, "boundary": 2, "edge": 3})
        .fillna(4)
        .astype(np.int64)
    )
    position_one_hot_df = pd.get_dummies(
        node_features_df["nuclear_position_mapped"], prefix="pos", dtype=int
    )
    expected_pos_cols = [f"pos_{i}" for i in range(4)]
    position_one_hot_df = position_one_hot_df.reindex(
        columns=expected_pos_cols, fill_value=0
    )

    final_df = pd.concat([count_features_df, position_one_hot_df], axis=1).astype(
        float
    )
    # Virtual nodes carry no information: zero their feature rows.
    final_df.loc[node_features_df["is_virtual"] == 1, :] = 0.0
    x = torch.tensor(final_df.values, dtype=torch.float)

    try:
        adj_matrix = pd.read_csv(adj_file)
    except FileNotFoundError:
        return None
    if adj_matrix.empty or x.shape[0] == 0:
        edge_index = torch.empty((2, 0), dtype=torch.long)
    else:
        edge_index = torch.tensor(
            np.array(adj_matrix.values.nonzero()), dtype=torch.long
        )

    return Data(x=x, edge_index=edge_index, cell=cell, gene=gene)
|
|
805
|
+
|
|
806
|
+
|
|
807
|
+
# def process_graph_row_safe(args):
|
|
808
|
+
# row, path, n_sectors, m_rings, k_neighbor = args
|
|
809
|
+
# cell = row['cell']
|
|
810
|
+
# gene = row['gene']
|
|
811
|
+
# graph_data = []
|
|
812
|
+
|
|
813
|
+
# for is_aug in [False, True]:
|
|
814
|
+
# base_path = f"{path}/{cell}"
|
|
815
|
+
# if is_aug:
|
|
816
|
+
# base_path += "_aug"
|
|
817
|
+
|
|
818
|
+
# nodes_file = f'{base_path}/{gene}_node_matrix.csv'
|
|
819
|
+
# adj_file = f'{base_path}/{gene}_adj_matrix.csv'
|
|
820
|
+
|
|
821
|
+
# try:
|
|
822
|
+
# node_df = pd.read_csv(nodes_file, usecols=[1, 2, 3, 4, 5])
|
|
823
|
+
# total_count = node_df['count'].sum()
|
|
824
|
+
# if total_count == 0:
|
|
825
|
+
# return None
|
|
826
|
+
|
|
827
|
+
# count_ratio = (node_df['count'] / total_count).to_numpy()
|
|
828
|
+
# count_embed_np = np.array([emb.nonlinear_transform_embedding(x, dim=12) for x in count_ratio])
|
|
829
|
+
# pos_int = node_df['nuclear_position'].map(
|
|
830
|
+
# {'inside': 0, 'outside': 1, 'boundary': 2, 'edge': 3}
|
|
831
|
+
# ).fillna(4).astype(int).to_numpy()
|
|
832
|
+
# pos_one_hot = np.eye(4)[pos_int]
|
|
833
|
+
# node_features_np = np.hstack([count_embed_np, pos_one_hot])
|
|
834
|
+
|
|
835
|
+
# adj_matrix = pd.read_csv(adj_file).values
|
|
836
|
+
# edge_index_np = np.array(np.nonzero(adj_matrix))
|
|
837
|
+
|
|
838
|
+
# graph_data.append({
|
|
839
|
+
# 'x': node_features_np,
|
|
840
|
+
# 'edge_index': edge_index_np,
|
|
841
|
+
# 'cell': cell,
|
|
842
|
+
# 'gene': gene
|
|
843
|
+
# })
|
|
844
|
+
|
|
845
|
+
# except Exception as e:
|
|
846
|
+
# print(f"Error processing {nodes_file}: {e}")
|
|
847
|
+
# return None
|
|
848
|
+
|
|
849
|
+
# return graph_data[0], graph_data[1]
|
|
850
|
+
|
|
851
|
+
# def generate_graph_data_target_parallel(dataset, df, path, n_sectors, m_rings, k_neighbor):
|
|
852
|
+
# args_list = [(row, path, n_sectors, m_rings, k_neighbor) for _, row in df.iterrows()]
|
|
853
|
+
|
|
854
|
+
# multiprocessing.set_start_method("spawn", force=True)
|
|
855
|
+
|
|
856
|
+
# with Pool(processes=4) as pool:
|
|
857
|
+
# results = list(tqdm(pool.imap_unordered(process_graph_row_safe, args_list, chunksize=10), total=len(df), desc="Parallel Processing Graphs"))
|
|
858
|
+
|
|
859
|
+
# original_graphs, augmented_graphs = [], []
|
|
860
|
+
|
|
861
|
+
# for result in results:
|
|
862
|
+
# if result is None:
|
|
863
|
+
# continue
|
|
864
|
+
# for idx, (graph_list, target) in enumerate(zip([original_graphs, augmented_graphs], result)):
|
|
865
|
+
# data = Data(
|
|
866
|
+
# x=torch.tensor(target['x'], dtype=torch.float),
|
|
867
|
+
# edge_index=torch.tensor(target['edge_index'], dtype=torch.long),
|
|
868
|
+
# cell=target['cell'],
|
|
869
|
+
# gene=target['gene']
|
|
870
|
+
# )
|
|
871
|
+
# graph_list.append(data)
|
|
872
|
+
|
|
873
|
+
# return original_graphs, augmented_graphs
|
|
874
|
+
|
|
875
|
+
import pandas as pd
|
|
876
|
+
import numpy as np
|
|
877
|
+
import torch
|
|
878
|
+
from torch_geometric.data import Data
|
|
879
|
+
from tqdm.auto import tqdm
|
|
880
|
+
import multiprocessing
|
|
881
|
+
from multiprocessing import Pool
|
|
882
|
+
import functools
|
|
883
|
+
|
|
884
|
+
|
|
885
|
+
def process_single_gene_pair(args):
    """Process one (cell, gene) pair inside a worker process.

    Loads the original (suffix "") and augmented (suffix "_aug") node/adjacency
    CSVs for the pair and turns each into a plain-numpy payload dict
    (``{"x", "edge_index", "cell", "gene"}``) — kept picklable on purpose so it
    crosses the multiprocessing boundary cheaply; tensor/Data construction
    happens in the parent process.

    Args:
        args: tuple ``(row, path, n_sectors, m_rings, k_neighbor)`` where
            ``row`` is a dataframe row with ``cell`` and ``gene`` fields.
            NOTE(review): n_sectors/m_rings/k_neighbor are unpacked but unused
            here — the folder layout is just ``{path}/{cell}{suffix}``.

    Returns:
        ``(orig_payload, aug_payload)`` when BOTH variants load successfully,
        otherwise ``None`` (any failure on either side drops the whole pair).
    """
    row, path, n_sectors, m_rings, k_neighbor = args
    cell = row["cell"]
    gene = row["gene"]

    results = {}

    for suffix in ["", "_aug"]:
        is_aug = suffix == "_aug"
        base_path = f"{path}/{cell}{suffix}"
        nodes_file = f"{base_path}/{gene}_node_matrix.csv"
        adj_file = f"{base_path}/{gene}_adj_matrix.csv"

        try:
            # Read node features; a missing file aborts the entire pair.
            try:
                node_df = pd.read_csv(nodes_file)
            except FileNotFoundError:
                return None

            required_cols = ["count", "nuclear_position"]
            if not all(col in node_df.columns for col in required_cols):
                return None

            # Feature computation (count ratio embedding + position one-hot).
            # Zero total count would make the ratio undefined, so bail out.
            total_count = node_df["count"].sum()
            if total_count == 0:
                return None

            count_ratio = node_df["count"] / total_count
            # 12-dim sinusoidal-style embedding per node (see emb module).
            count_embed_list = [
                emb.nonlinear_transform_embedding(x, dim=12) for x in count_ratio
            ]
            count_embed_np = np.array(count_embed_list)

            # Position one-hot; unmapped labels fall into bucket 4, whose
            # dummy column is dropped by the reindex below.
            pos_map = {"inside": 0, "outside": 1, "boundary": 2, "edge": 3}
            node_df["pos_mapped"] = (
                node_df["nuclear_position"].map(pos_map).fillna(4).astype(int)
            )

            # NOTE: enforce dtype=int to avoid mixed bool/int columns which become
            # dtype=object when converting to numpy.
            pos_dummies = pd.get_dummies(node_df["pos_mapped"], prefix="pos", dtype=int)
            expected_cols = [f"pos_{i}" for i in range(4)]
            # reindex guarantees exactly pos_0..pos_3 in a fixed order.
            pos_dummies = pos_dummies.reindex(columns=expected_cols, fill_value=0)
            pos_features_np = pos_dummies.to_numpy()

            # Merge features: [12 count dims | 4 position dims] per node.
            node_features_np = np.hstack([count_embed_np, pos_features_np])

            # 4. Adjacency handling (check empty matrices)
            try:
                adj_df = pd.read_csv(adj_file)
                if adj_df.empty:
                    # Empty (2, 0) edge index; converted to torch.long upstream.
                    edge_index_np = np.empty((2, 0))
                else:
                    edge_index_np = np.array(adj_df.values.nonzero())
            except FileNotFoundError:
                return None

            # Store in a temporary dict keyed by graph variant.
            key = "aug" if is_aug else "orig"
            results[key] = {
                "x": node_features_np,
                "edge_index": edge_index_np,
                "cell": cell,
                "gene": gene,
            }

        except Exception as e:
            # Any unexpected failure drops the pair silently; keep output
            # minimal in parallel runs (enable the print for debugging).
            # print(f"Error in {cell}-{gene}: {e}")
            return None

    # Only return when both original and augmented graphs succeed.
    if "orig" in results and "aug" in results:
        return results["orig"], results["aug"]
    return None
|
|
965
|
+
|
|
966
|
+
|
|
967
|
+
def generate_graph_data_target_parallel(
    dataset, df, path, n_sectors, m_rings, k_neighbor, processes=8
):
    """Fan the rows of ``df`` out to a worker pool and collect graph pairs.

    Each row is handed to ``process_single_gene_pair`` in a worker process;
    workers return plain numpy payload dicts, which are converted to
    torch_geometric ``Data`` objects here in the parent process.

    Args:
        dataset: unused; kept for interface compatibility with siblings.
        df: dataframe with ``cell`` and ``gene`` columns.
        path: root directory containing per-cell folders.
        n_sectors, m_rings, k_neighbor: forwarded to the worker.
        processes: size of the multiprocessing pool.

    Returns:
        (original_graphs, augmented_graphs): two aligned lists of ``Data``.
    """
    # One task tuple per dataframe row.
    tasks = [(r, path, n_sectors, m_rings, k_neighbor) for _, r in df.iterrows()]

    # Force the "spawn" start method; swallow the RuntimeError raised when
    # the start method has already been fixed for this interpreter.
    try:
        multiprocessing.set_start_method("spawn", force=True)
    except RuntimeError:
        pass

    def _to_data(payload):
        # Tensor/Data construction stays in the parent process (safer).
        return Data(
            x=torch.tensor(payload["x"], dtype=torch.float),
            edge_index=torch.tensor(payload["edge_index"], dtype=torch.long),
            cell=payload["cell"],
            gene=payload["gene"],
        )

    original_graphs, augmented_graphs = [], []

    print(f"Starting parallel processing with {processes} cores...")

    with Pool(processes=processes) as pool:
        # imap_unordered streams results back as soon as any worker finishes,
        # without preserving task order.
        pending = pool.imap_unordered(process_single_gene_pair, tasks, chunksize=10)
        for item in tqdm(pending, total=len(tasks), desc="Generating Target Graphs"):
            if item is None:
                continue
            orig_payload, aug_payload = item
            original_graphs.append(_to_data(orig_payload))
            augmented_graphs.append(_to_data(aug_payload))

    return original_graphs, augmented_graphs
|
|
1013
|
+
|
|
1014
|
+
|
|
1015
|
+
## AD / MERFISH mouse brain datasets (``df`` lists the (cell, gene) pairs to load)
|
|
1016
|
+
def generate_graph_data_target_noposition(
    dataset, df, path, n_sectors, m_rings, k_neighbor
):
    """Load aligned (original, augmented) graphs without position features.

    Node features are only the 16-dim count-ratio embedding (no nuclear
    position one-hot), hence "noposition".

    Improvements over the previous version:
    - a missing node/adjacency CSV now skips the pair instead of raising an
      uncaught ``FileNotFoundError`` that aborted the whole run (all sibling
      loaders in this module skip missing files);
    - an all-zero ``count`` column is skipped instead of dividing by zero and
      producing NaN/inf features;
    - a pair is appended only when BOTH variants load, keeping the two
      returned lists aligned.

    Args:
        dataset: unused; kept for interface compatibility with siblings.
        df: dataframe with ``cell`` and ``gene`` columns.
        path: root directory containing per-cell partition folders.
        n_sectors, m_rings, k_neighbor: partition parameters used to build
            the per-cell folder names.

    Returns:
        (graphs, aug_graphs): two aligned lists of torch_geometric ``Data``.
    """
    graphs = []
    aug_graphs = []

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing Graphs"):
        cell = row["cell"]
        gene = row["gene"]

        raw_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}"
        aug_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}_aug"

        # Original graph
        graph = _load_noposition_graph(raw_path, gene, cell)
        if graph is None:
            continue
        # Augmented graph; skip the pair when it fails, to keep lists aligned.
        aug_graph = _load_noposition_graph(aug_path, gene, cell)
        if aug_graph is None:
            continue

        graphs.append(graph)
        aug_graphs.append(aug_graph)

    return graphs, aug_graphs


def _load_noposition_graph(base_path, gene, cell):
    """Build one position-free ``Data`` graph, or ``None`` on missing/bad input."""
    nodes_file = f"{base_path}/{gene}_node_matrix.csv"
    adj_file = f"{base_path}/{gene}_adj_matrix.csv"

    try:
        # NOTE(review): usecols=[1, 2, 3, 4] selects columns by position —
        # assumes the CSV layout puts 'count' among them; confirm with writer.
        node_features = pd.read_csv(nodes_file, usecols=[1, 2, 3, 4])
        adj_matrix = pd.read_csv(adj_file)
    except FileNotFoundError:
        return None

    total_count = node_features["count"].sum()
    if total_count == 0:
        # All-zero counts: the ratio is undefined, skip this graph.
        return None
    node_features["count_ratio"] = node_features["count"] / total_count

    # 16-dim count-ratio embedding is the entire node feature vector.
    count_features = (
        node_features["count_ratio"]
        .apply(lambda x: emb.nonlinear_transform_embedding(x, dim=16))
        .tolist()
    )
    count_features = pd.DataFrame(
        count_features, columns=[f"dim_{i}" for i in range(16)]
    )
    x = torch.tensor(count_features.values, dtype=torch.float)

    edge_index = torch.tensor(
        np.array(adj_matrix.values.nonzero()), dtype=torch.long
    )
    return Data(x=x, edge_index=edge_index, cell=cell, gene=gene)
|
|
1079
|
+
|
|
1080
|
+
|
|
1081
|
+
def process_cell_gene_noposition(
    cell, gene, dataset, path, n_sectors, m_rings, k_neighbor, base_path
):
    """Load the (original, augmented) position-free graph pair for one cell/gene.

    The gene must pass the distance filter in
    ``{base_path}/{gene}_distances_filter_new.csv`` (a row with this gene,
    this cell, and ``location == "other"``); otherwise the pair is skipped.

    Bug fixes over the previous version:
    - the embedding was built with ``dim=16`` but the DataFrame was given only
      8 column names (``range(8)``), raising ``ValueError`` on every graph
      that passed the filters — the column count now matches the embedding;
    - the computed embedding was then discarded and the raw CSV columns were
      fed to ``torch.tensor`` — the embedded features are now used as ``x``,
      matching ``generate_graph_data_target_noposition``;
    - the augmented branch embedded raw counts with ``dim=8`` — it now uses
      the same count-ratio / ``dim=16`` scheme as the original branch;
    - an all-zero ``count`` column is skipped instead of dividing by zero.

    Returns:
        (graphs, aug_graphs): each a list of 0 or 1 torch_geometric ``Data``
        objects, so callers can ``extend`` their accumulators.
    """
    graphs = []
    aug_graphs = []

    raw_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}"
    aug_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}_aug"

    df_file = f"{base_path}/{gene}_distances_filter_new.csv"
    if not os.path.exists(df_file):
        print(f"1. Skipping {gene} in {cell} (file not found).")
        return graphs, aug_graphs

    df = pd.read_csv(df_file)
    filtered_df = df[
        (df["gene"] == gene) & (df["cell"] == cell) & (df["location"] == "other")
    ]
    if filtered_df.empty:
        return graphs, aug_graphs

    # --- Original graph ---
    nodes_file = f"{raw_path}/{gene}_node_matrix.csv"
    adj_file = f"{raw_path}/{gene}_adj_matrix.csv"
    if not os.path.exists(nodes_file) or not os.path.exists(adj_file):
        return graphs, aug_graphs

    node_features = pd.read_csv(nodes_file, usecols=[1, 2, 3, 4])
    if len(node_features) <= 5:
        # Too few points to form a meaningful graph.
        return graphs, aug_graphs

    total_count = node_features["count"].sum()
    if total_count == 0:
        # Ratio undefined for all-zero counts.
        return graphs, aug_graphs
    node_features["count_ratio"] = node_features["count"] / total_count
    count_features = (
        node_features["count_ratio"]
        .apply(lambda x: emb.nonlinear_transform_embedding(x, dim=16))
        .tolist()
    )
    # Column count matches the embedding width (16).
    count_features = pd.DataFrame(
        count_features, columns=[f"dim_{i}" for i in range(16)]
    )
    node_features_tensor = torch.tensor(count_features.values, dtype=torch.float)

    adj_matrix = pd.read_csv(adj_file)
    edge_index = torch.tensor(np.array(adj_matrix.values.nonzero()), dtype=torch.long)
    graph = Data(x=node_features_tensor, edge_index=edge_index, cell=cell, gene=gene)
    graphs.append(graph)

    # --- Augmented graph ---
    nodes_file = f"{aug_path}/{gene}_node_matrix.csv"
    adj_file = f"{aug_path}/{gene}_adj_matrix.csv"
    if not os.path.exists(nodes_file) or not os.path.exists(adj_file):
        return graphs, aug_graphs

    aug_node_features = pd.read_csv(nodes_file, usecols=[1, 2, 3, 4])
    aug_total_count = aug_node_features["count"].sum()
    if aug_total_count == 0:
        return graphs, aug_graphs
    aug_node_features["count_ratio"] = aug_node_features["count"] / aug_total_count
    aug_count_features = (
        aug_node_features["count_ratio"]
        .apply(lambda x: emb.nonlinear_transform_embedding(x, dim=16))
        .tolist()
    )
    aug_count_features = pd.DataFrame(
        aug_count_features, columns=[f"dim_{i}" for i in range(16)]
    )
    aug_node_features_tensor = torch.tensor(
        aug_count_features.values, dtype=torch.float
    )

    aug_adj_matrix = pd.read_csv(adj_file)
    aug_edge_index = torch.tensor(
        np.array(aug_adj_matrix.values.nonzero()), dtype=torch.long
    )
    aug_graph = Data(
        x=aug_node_features_tensor, edge_index=aug_edge_index, cell=cell, gene=gene
    )
    aug_graphs.append(aug_graph)
    return graphs, aug_graphs
|
|
1158
|
+
|
|
1159
|
+
|
|
1160
|
+
def generate_graph_data_parallel_noposition(
    dataset,
    cell_list,
    gene_list,
    path,
    base_path,
    n_sectors,
    m_rings,
    k_neighbor,
    n_jobs=4,
):
    """Run ``process_cell_gene_noposition`` over the cell x gene grid in parallel.

    Args:
        dataset: forwarded to the worker.
        cell_list, gene_list: every combination of the two is one task.
        path, base_path: data roots forwarded to the worker.
        n_sectors, m_rings, k_neighbor: partition parameters forwarded to the
            worker.
        n_jobs: number of joblib workers.

    Returns:
        (graphs, aug_graphs): flattened lists merged from every worker result.
    """
    # Full cartesian product of cells and genes — one task per pair.
    pairs = [(c, g) for c in cell_list for g in gene_list]

    # joblib dispatches the tasks; tqdm tracks dispatch progress.
    per_pair_results = Parallel(n_jobs=n_jobs)(
        delayed(process_cell_gene_noposition)(
            c, g, dataset, path, n_sectors, m_rings, k_neighbor, base_path
        )
        for c, g in tqdm(
            pairs,
            desc="Processing cells and genes",
            total=len(pairs),
        )
    )

    # Flatten the per-pair (graphs, aug_graphs) tuples into two lists.
    graphs = []
    aug_graphs = []
    for pair_graphs, pair_aug_graphs in per_pair_results:
        graphs.extend(pair_graphs)
        aug_graphs.extend(pair_aug_graphs)

    return graphs, aug_graphs
|
|
1194
|
+
|
|
1195
|
+
|
|
1196
|
+
def generate_graph_data_target(dataset, df, path, n_sectors, m_rings, k_neighbor):
    """Load aligned (original, augmented) graphs for each (cell, gene) in ``df``.

    Folder layout here is flat: originals live in ``{path}/{cell}`` and
    augmentations in ``{path}/{cell}_aug`` (no sector/ring suffix — n_sectors,
    m_rings and k_neighbor are accepted but not used in the paths).

    Node features are 16-dim: a 12-dim count-ratio embedding plus a 4-dim
    one-hot of ``nuclear_position`` (pos_0..pos_3).

    Error-handling style: inner try/excepts print a specific message, set a
    flag and/or re-raise; the outer handlers catch the re-raised exception and
    decide whether the pair is dropped. A pair is appended only when BOTH
    variants were built.

    Returns:
        (graphs, aug_graphs): two aligned lists of torch_geometric ``Data``.
    """
    graphs = []
    aug_graphs = []

    for _, row in tqdm(
        df.iterrows(),
        total=len(df),
        desc="Processing Graphs generate_graph_data_target",
    ):
        cell = row["cell"]
        gene = row["gene"]

        # --- Process Original Graph ---
        raw_path_base = f"{path}/{cell}"
        nodes_file_orig = f"{raw_path_base}/{gene}_node_matrix.csv"
        adj_file_orig = f"{raw_path_base}/{gene}_adj_matrix.csv"
        skip_current_pair = False
        original_graph_data = None

        try:
            # Read only necessary columns by name.
            try:
                # Try reading with is_virtual first.
                node_features_orig_df_raw = pd.read_csv(
                    nodes_file_orig, usecols=["count", "nuclear_position", "is_virtual"]
                )
            except ValueError:  # If is_virtual column doesn't exist
                try:
                    node_features_orig_df_raw = pd.read_csv(
                        nodes_file_orig, usecols=["count", "nuclear_position"]
                    )
                except FileNotFoundError:
                    print(
                        f"Original node_matrix file not found: {nodes_file_orig}. Skipping pair."
                    )
                    skip_current_pair = True
                    raise  # Re-raise to be caught by outer try-except
                except ValueError:  # If count or nuclear_position are missing
                    print(
                        f"Original node_matrix {nodes_file_orig} missing 'count' or 'nuclear_position'. Skipping pair."
                    )
                    skip_current_pair = True
                    raise
            # NOTE(review): a FileNotFoundError on the FIRST read above is not
            # caught here — it falls through to the outer handler with
            # skip_current_pair still False, so the augmented branch still
            # runs (its result is later discarded because original_graph_data
            # stays None). Confirm this is intended.

            if node_features_orig_df_raw.empty:
                print(
                    f"Original node_matrix {nodes_file_orig} is empty. Skipping pair."
                )
                skip_current_pair = True
                raise FileNotFoundError  # Treat as if file not found for outer catch

            node_features_orig_df = node_features_orig_df_raw.copy()

            # Count-ratio embedding; guard zero total to avoid NaN.
            total_count_orig = node_features_orig_df["count"].sum()
            node_features_orig_df["count_ratio"] = (
                0.0
                if total_count_orig == 0
                else node_features_orig_df["count"] / total_count_orig
            )

            count_features_orig_list = (
                node_features_orig_df["count_ratio"]
                .apply(lambda x: emb.nonlinear_transform_embedding(x, dim=12))
                .tolist()
            )
            count_features_orig_embedded_df = pd.DataFrame(
                count_features_orig_list,
                columns=[f"dim_{i}" for i in range(12)],
                index=node_features_orig_df.index,
            )

            # Position one-hot; unmapped labels fall into bucket 4 and are
            # dropped by the pos_0..pos_3 column selection below.
            node_features_orig_df["nuclear_position_mapped"] = (
                node_features_orig_df["nuclear_position"]
                .map({"inside": 0, "outside": 1, "boundary": 2, "edge": 3})
                .fillna(4)
                .astype(int)
            )
            position_orig_one_hot_df = pd.get_dummies(
                node_features_orig_df["nuclear_position_mapped"],
                prefix="pos",
                dtype=int,
            )

            # Ensure all graphs have the same position feature dimensionality (pos_0..pos_3).
            expected_pos_cols = ["pos_0", "pos_1", "pos_2", "pos_3"]
            for col in expected_pos_cols:
                if col not in position_orig_one_hot_df.columns:
                    position_orig_one_hot_df[col] = 0
            position_orig_one_hot_df = position_orig_one_hot_df[expected_pos_cols]

            final_node_features_orig_df = pd.concat(
                [count_features_orig_embedded_df, position_orig_one_hot_df], axis=1
            )
            node_features_orig_tensor = torch.tensor(
                final_node_features_orig_df.values, dtype=torch.float
            )

            try:
                adj_matrix_orig_df = pd.read_csv(adj_file_orig)
            except FileNotFoundError:
                print(
                    f"Original adj_matrix file not found: {adj_file_orig}. Skipping pair."
                )
                skip_current_pair = True
                raise

            # Empty adjacency (or an empty graph) yields a (2, 0) edge index.
            if adj_matrix_orig_df.empty and node_features_orig_tensor.shape[0] > 0:
                print(
                    f"Warning: Original adj_matrix {adj_file_orig} is empty for a graph with {node_features_orig_tensor.shape[0]} nodes."
                )
                edge_index_orig = torch.empty((2, 0), dtype=torch.long)
            elif node_features_orig_tensor.shape[0] == 0:
                edge_index_orig = torch.empty((2, 0), dtype=torch.long)
            else:
                edge_index_orig = torch.tensor(
                    np.array(adj_matrix_orig_df.values.nonzero()), dtype=torch.long
                )

            original_graph_data = Data(
                x=node_features_orig_tensor,
                edge_index=edge_index_orig,
                cell=cell,
                gene=gene,
            )

        except (
            FileNotFoundError,
            ValueError,
            KeyError,
        ) as e_orig:  # Catch specific errors for original graph processing
            # This block now effectively handles and reports on the re-raised errors from inner try-excepts.
            # The specific print statements inside the inner blocks provide more detail.
            pass  # Continue to next pair in the main loop
        except Exception as e_general_orig:
            print(
                f"Unexpected error processing original graph for {cell}-{gene}: {e_general_orig}. Skipping pair."
            )
            skip_current_pair = True
            pass

        if skip_current_pair:
            continue  # Move to the next cell-gene pair

        # --- Process Augmented Graph ---
        aug_path_base = f"{path}/{cell}_aug"
        nodes_file_aug = f"{aug_path_base}/{gene}_node_matrix.csv"
        adj_file_aug = f"{aug_path_base}/{gene}_adj_matrix.csv"
        augmented_graph_data = None

        try:
            # Same read strategy as the original branch: prefer is_virtual,
            # fall back to the two mandatory columns.
            try:
                node_features_aug_df_raw = pd.read_csv(
                    nodes_file_aug, usecols=["count", "nuclear_position", "is_virtual"]
                )
            except ValueError:
                try:
                    node_features_aug_df_raw = pd.read_csv(
                        nodes_file_aug, usecols=["count", "nuclear_position"]
                    )
                except FileNotFoundError:
                    print(
                        f"Augmented node_matrix file not found: {nodes_file_aug}. Original graph will be kept if processed."
                    )
                    raise  # Re-raise to be caught by outer try-except for augmented graph
                except ValueError:
                    print(
                        f"Augmented node_matrix {nodes_file_aug} missing 'count' or 'nuclear_position'. Original graph will be kept."
                    )
                    raise

            if node_features_aug_df_raw.empty:
                print(
                    f"Augmented node_matrix {nodes_file_aug} is empty. Original graph will be kept."
                )
                raise FileNotFoundError

            node_features_aug_df = node_features_aug_df_raw.copy()

            total_count_aug = node_features_aug_df["count"].sum()
            node_features_aug_df["count_ratio"] = (
                0.0
                if total_count_aug == 0
                else node_features_aug_df["count"] / total_count_aug
            )

            count_features_aug_list = (
                node_features_aug_df["count_ratio"]
                .apply(lambda x: emb.nonlinear_transform_embedding(x, dim=12))
                .tolist()
            )
            count_features_aug_embedded_df = pd.DataFrame(
                count_features_aug_list,
                columns=[f"dim_{i}" for i in range(12)],
                index=node_features_aug_df.index,
            )

            node_features_aug_df["nuclear_position_mapped"] = (
                node_features_aug_df["nuclear_position"]
                .map({"inside": 0, "outside": 1, "boundary": 2, "edge": 3})
                .fillna(4)
                .astype(int)
            )
            position_aug_one_hot_df = pd.get_dummies(
                node_features_aug_df["nuclear_position_mapped"], prefix="pos", dtype=int
            )

            # Ensure augmented graphs also have the same position feature dimensionality (pos_0..pos_3).
            expected_pos_cols = ["pos_0", "pos_1", "pos_2", "pos_3"]
            for col in expected_pos_cols:
                if col not in position_aug_one_hot_df.columns:
                    position_aug_one_hot_df[col] = 0
            position_aug_one_hot_df = position_aug_one_hot_df[expected_pos_cols]

            final_node_features_aug_df = pd.concat(
                [count_features_aug_embedded_df, position_aug_one_hot_df], axis=1
            )
            node_features_aug_tensor = torch.tensor(
                final_node_features_aug_df.values, dtype=torch.float
            )

            try:
                adj_matrix_aug_df = pd.read_csv(adj_file_aug)
            except FileNotFoundError:
                print(
                    f"Augmented adj_matrix file not found: {adj_file_aug}. Original graph will be kept."
                )
                raise

            if adj_matrix_aug_df.empty and node_features_aug_tensor.shape[0] > 0:
                print(
                    f"Warning: Augmented adj_matrix {adj_file_aug} is empty for a graph with {node_features_aug_tensor.shape[0]} nodes."
                )
                edge_index_aug = torch.empty((2, 0), dtype=torch.long)
            elif node_features_aug_tensor.shape[0] == 0:
                edge_index_aug = torch.empty((2, 0), dtype=torch.long)
            else:
                edge_index_aug = torch.tensor(
                    np.array(adj_matrix_aug_df.values.nonzero()), dtype=torch.long
                )

            augmented_graph_data = Data(
                x=node_features_aug_tensor,
                edge_index=edge_index_aug,
                cell=cell,
                gene=gene,
            )

        except (FileNotFoundError, ValueError, KeyError) as e_aug:
            print(
                f"Failed to process augmented graph for {cell}-{gene}: {e_aug}. Original graph will be kept if available."
            )
            pass  # Keep original graph if it was successfully processed
        except Exception as e_general_aug:
            print(
                f"Unexpected error processing augmented graph for {cell}-{gene}: {e_general_aug}. Original graph will be kept if available."
            )
            pass

        # Add to lists if both original and augmented were processed successfully.
        # NOTE(review): this relies on the truthiness of Data objects — confirm
        # that a non-empty torch_geometric Data always evaluates truthy here.
        if original_graph_data and augmented_graph_data:
            graphs.append(original_graph_data)
            aug_graphs.append(augmented_graph_data)
        elif original_graph_data:  # Only original was successful
            print(
                f"Only original graph processed for {cell}-{gene}. Augmented failed or was missing."
            )
            # Decide if you want to add originals even if augmented is missing.
            # For now, we are only adding pairs. To change this, uncomment the lines below:
            # graphs.append(original_graph_data)
            # aug_graphs.append(original_graph_data)  # Or some placeholder for aug_graph
            pass

    return graphs, aug_graphs
|
|
1469
|
+
|
|
1470
|
+
|
|
1471
|
+
# Fixed category order for the 3-dim nuclear-position one-hot encoding.
_EXPECTED_POS_CATS = ["inside", "outside", "boundary"]


def _load_graph_weight3(nodes_file, adj_file, cell, gene, label, notes):
    """
    Load one (cell, gene) graph as a PyG ``Data`` object, or return ``None``.

    Node features are 16-dimensional:
      - 12 dims: sinusoidal embedding of the per-node count ratio.
      - 3 dims: one-hot encoded nuclear position ('inside', 'outside', 'boundary').
      - 1 dim: the 'is_edge' flag.
    Virtual nodes are processed normally with their actual features.

    Parameters
    ----------
    nodes_file, adj_file : str
        CSV paths of the node matrix and adjacency matrix.
    cell, gene : str
        Identifiers stored on the returned ``Data`` and used in diagnostics.
    label : str
        'Original' or 'Augmented'; prefixes diagnostic messages.
    notes : dict
        Per-failure-mode message suffixes with keys
        'node_error', 'empty', 'adj_missing', 'unexpected'.

    Returns
    -------
    torch_geometric.data.Data or None
        ``None`` on any load/parse failure (expected failures are skipped
        silently or with the diagnostics above, matching prior behavior).
    """
    try:
        # --- Step 1: Robustly read node data, tolerating missing optional columns.
        # NOTE: a FileNotFoundError from the first read propagates straight to the
        # outer handler (silent skip), matching the original nested-try behavior.
        try:
            node_df_raw = pd.read_csv(
                nodes_file,
                usecols=["count", "nuclear_position", "is_edge", "is_virtual"],
            )
        except ValueError:
            try:
                node_df_raw = pd.read_csv(
                    nodes_file, usecols=["count", "nuclear_position", "is_edge"]
                )
                node_df_raw["is_virtual"] = 0  # assume not virtual if column missing
            except ValueError:
                try:
                    node_df_raw = pd.read_csv(
                        nodes_file, usecols=["count", "nuclear_position"]
                    )
                    node_df_raw["is_edge"] = 0  # assume not edge if column missing
                    node_df_raw["is_virtual"] = 0
                except (FileNotFoundError, ValueError) as e:
                    print(
                        f"{label} node_matrix file error for {nodes_file}: {e}. "
                        f"{notes['node_error']}"
                    )
                    return None

        if node_df_raw.empty:
            print(f"{label} node_matrix {nodes_file} is empty. {notes['empty']}")
            return None

        node_df = node_df_raw.copy()

        # --- Step 2a: Count embedding (12-dim sinusoidal of the count ratio).
        total_count = node_df["count"].sum()
        node_df["count_ratio"] = (
            0.0 if total_count == 0 else node_df["count"] / total_count
        )
        count_rows = (
            node_df["count_ratio"]
            .apply(
                lambda x: emb.get_sinusoidal_embedding_for_continuous_value(
                    x, dim=12
                )
            )
            .tolist()
        )
        count_embedded_df = pd.DataFrame(
            count_rows,
            columns=[f"dim_{i}" for i in range(12)],
            index=node_df.index,
        )

        # --- Step 2b: Nuclear position one-hot (3-dim, fixed category order so the
        # output width is stable even when a category is absent from this file).
        node_df["nuclear_position_cat"] = pd.Categorical(
            node_df["nuclear_position"], categories=_EXPECTED_POS_CATS
        )
        position_one_hot_df = pd.get_dummies(
            node_df["nuclear_position_cat"], prefix="pos", dtype=int
        )

        # --- Step 2c: Combine 12 count dims + 3 position dims + 1 is_edge dim.
        # Virtual nodes keep their original features (no longer set to zero).
        final_df = pd.concat(
            [count_embedded_df, position_one_hot_df, node_df[["is_edge"]]], axis=1
        )
        node_features = torch.tensor(final_df.values, dtype=torch.float)

        # --- Step 3: Adjacency matrix -> edge_index.
        try:
            adj_df = pd.read_csv(adj_file)
        except FileNotFoundError:
            print(
                f"{label} adj_matrix file not found: {adj_file}. "
                f"{notes['adj_missing']}"
            )
            return None

        # Empty adjacency or zero nodes both yield an edge-less graph.
        if adj_df.empty or node_features.shape[0] == 0:
            edge_index = torch.empty((2, 0), dtype=torch.long)
        else:
            edge_index = torch.tensor(
                np.array(adj_df.values.nonzero()), dtype=torch.long
            )

        return Data(x=node_features, edge_index=edge_index, cell=cell, gene=gene)

    except (FileNotFoundError, ValueError, KeyError):
        # Expected load errors: skip this graph quietly (matches prior behavior).
        return None
    except Exception as e:
        print(
            f"Unexpected error processing {label.lower()} graph for {cell}-{gene}: "
            f"{e}. {notes['unexpected']}"
        )
        return None


def generate_graph_data_target_weight3(
    dataset, df, path, n_sectors, m_rings, k_neighbor
):
    """
    Robustly loads original and augmented graph data for a target dataframe.

    This version uses sinusoidal embedding for count features.
    This version creates a 16-dimensional node feature vector:
    - 12 dims for sinusoidal count embedding.
    - 3 dims for one-hot encoded nuclear position ('inside', 'outside', 'boundary').
    - 1 dim for the 'is_edge' flag.
    Virtual nodes are processed normally with their actual features.

    Parameters
    ----------
    dataset : Any
        Unused; kept for signature compatibility with sibling loaders.
    df : pandas.DataFrame
        Must contain 'cell' and 'gene' columns; one graph pair per row.
    path : str
        Root directory containing per-cell graph folders.
    n_sectors, m_rings, k_neighbor : int
        Partitioning parameters used to locate each graph folder.

    Returns
    -------
    tuple[list, list]
        Parallel lists of original and augmented ``Data`` graphs; a pair is
        appended only when BOTH graphs load successfully.
    """
    graphs = []
    aug_graphs = []

    # Diagnostic suffixes differ between the halves: a failed original graph
    # skips the whole pair, while a failed augmented graph only drops the pair
    # after reporting that the original alone was available.
    orig_notes = {
        "node_error": "Skipping pair.",
        "empty": "Skipping pair.",
        "adj_missing": "Skipping pair.",
        "unexpected": "Skipping pair.",
    }
    aug_notes = {
        "node_error": "Original graph will be kept if processed.",
        "empty": "Original graph will be kept.",
        "adj_missing": "Original graph will be kept.",
        "unexpected": "Original graph will be kept if available.",
    }

    for _, row in tqdm(
        df.iterrows(),
        total=len(df),
        desc="Processing Graphs for generate_graph_data_target_weight3",
    ):
        cell = row["cell"]
        gene = row["gene"]

        # --- Process Original Graph ---
        raw_path_base = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}"
        original_graph_data = _load_graph_weight3(
            f"{raw_path_base}/{gene}_node_matrix.csv",
            f"{raw_path_base}/{gene}_adj_matrix.csv",
            cell,
            gene,
            "Original",
            orig_notes,
        )
        if original_graph_data is None:
            continue

        # --- Process Augmented Graph ---
        aug_path_base = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}_aug"
        augmented_graph_data = _load_graph_weight3(
            f"{aug_path_base}/{gene}_node_matrix.csv",
            f"{aug_path_base}/{gene}_adj_matrix.csv",
            cell,
            gene,
            "Augmented",
            aug_notes,
        )

        # Only matched pairs are kept; orphaned originals are reported and dropped.
        if augmented_graph_data is not None:
            graphs.append(original_graph_data)
            aug_graphs.append(augmented_graph_data)
        else:
            print(
                f"Only original graph processed for {cell}-{gene}. Augmented failed or was missing."
            )

    return graphs, aug_graphs
|