grasp-tool 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1748 @@
1
+ import os
2
+ import pandas as pd
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch_geometric.data import Data
6
+ from tqdm import tqdm
7
+ import torch.nn as nn
8
+ from torch_geometric.nn import GCNConv
9
+ import numpy as np
10
+ import random
11
+ from sklearn.cluster import KMeans
12
+ import warnings
13
+ from . import embedding as emb
14
+ from concurrent.futures import ThreadPoolExecutor
15
+ from joblib import Parallel, delayed
16
+ from multiprocessing import Pool, cpu_count
17
+ import multiprocessing
18
+ from multiprocessing import Pool, cpu_count
19
+
20
+
21
def process_cell_gene(
    cell, gene, dataset, path, n_sectors, m_rings, k_neighbor, base_path
):
    """Build the original and augmented graphs for one (cell, gene) pair.

    The pair is only processed when `{base_path}/{gene}_distances_filtered.csv`
    exists and contains a row for this gene/cell with location == "other".

    Returns:
        (graphs, aug_graphs): lists with at most one torch_geometric Data
        each; both may be empty / partial when input files are missing.
    """
    graphs = []
    aug_graphs = []

    raw_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}"
    aug_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}_aug"

    df_file = f"{base_path}/{gene}_distances_filtered.csv"
    if not os.path.exists(df_file):
        return graphs, aug_graphs

    df = pd.read_csv(df_file)
    filtered_df = df[
        (df["gene"] == gene) & (df["cell"] == cell) & (df["location"] == "other")
    ]
    if filtered_df.empty:
        return graphs, aug_graphs

    # Original graph; missing node/adjacency files abort the whole pair.
    graph = _load_count_graph(raw_path, gene, cell)
    if graph is None:
        return graphs, aug_graphs
    graphs.append(graph)

    # Augmented graph; if its files are missing, keep the original only
    # (matches the historical behavior).
    aug_graph = _load_count_graph(aug_path, gene, cell)
    if aug_graph is None:
        return graphs, aug_graphs
    aug_graphs.append(aug_graph)

    return graphs, aug_graphs


def _load_count_graph(folder, gene, cell):
    """Load one graph from `{folder}/{gene}_node_matrix.csv` + adj matrix.

    Node features are a 12-dim nonlinear embedding of the raw `count`
    column plus a one-hot of `nuclear_position` (unknown labels -> 4).
    Returns a torch_geometric Data, or None when either file is missing.
    """
    nodes_file = f"{folder}/{gene}_node_matrix.csv"
    adj_file = f"{folder}/{gene}_adj_matrix.csv"
    if not os.path.exists(nodes_file) or not os.path.exists(adj_file):
        return None

    node_features = pd.read_csv(nodes_file, usecols=[1, 2, 3, 4, 5])
    node_features["nuclear_position"] = (
        node_features["nuclear_position"]
        .map({"inside": 0, "outside": 1, "boundary": 2, "edge": 3})
        .fillna(4)
        .astype(int)
    )
    count_features = (
        node_features["count"]
        .apply(lambda x: emb.nonlinear_transform_embedding(x, dim=12))
        .tolist()
    )
    count_features = pd.DataFrame(
        count_features, columns=[f"dim_{i}" for i in range(12)]
    )
    position_one_hot = pd.get_dummies(
        node_features["nuclear_position"], prefix="pos"
    ).astype(int)
    features = pd.concat([count_features, position_one_hot], axis=1)
    features_tensor = torch.tensor(features.values, dtype=torch.float)

    adj_matrix = pd.read_csv(adj_file)
    # np.array(...) avoids torch's warning about a tuple of index arrays.
    edge_index = torch.tensor(np.array(adj_matrix.values.nonzero()), dtype=torch.long)
    return Data(x=features_tensor, edge_index=edge_index, cell=cell, gene=gene)
119
+
120
+
121
def generate_graph_data_parallel(
    dataset,
    cell_list,
    gene_list,
    path,
    base_path,
    n_sectors,
    m_rings,
    k_neighbor,
    n_jobs=4,
):
    """Run process_cell_gene over every (cell, gene) combination in parallel.

    Returns the concatenated (graphs, aug_graphs) from all workers.
    """
    pairs = [(c, g) for c in cell_list for g in gene_list]
    progress = tqdm(pairs, desc="Processing cells and genes", total=len(pairs))
    per_pair = Parallel(n_jobs=n_jobs)(
        delayed(process_cell_gene)(
            c, g, dataset, path, n_sectors, m_rings, k_neighbor, base_path
        )
        for c, g in progress
    )

    graphs, aug_graphs = [], []
    for pair_graphs, pair_aug_graphs in per_pair:
        graphs += pair_graphs
        aug_graphs += pair_aug_graphs

    return graphs, aug_graphs
154
+
155
+
156
+ ## simulated_data1 / simulated_data2
157
## simulated_data1 / simulated_data2
def process_cell_gene_nofiltered(
    cell, gene, dataset, path, n_sectors, m_rings, k_neighbor
):
    """Build the original and augmented graphs for one (cell, gene) pair
    without the distance-based pre-filter used by process_cell_gene.

    For the "merfish_u2os" / "seqfish_fibroblast" datasets, pairs listed in
    the optional low/no-points CSVs under $GRASP_FILTER_ROOT are skipped;
    when the env var is unset or the files are missing, no filtering occurs.

    Returns:
        (graphs, aug_graphs): lists with at most one Data each; both may be
        empty / partial when input files are missing or the pair is filtered.
    """
    graphs = []
    aug_graphs = []
    raw_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}"
    aug_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}_aug"

    if dataset in ("merfish_u2os", "seqfish_fibroblast"):
        # Optional filtering lists (historically hard-coded).
        # If GRASP_FILTER_ROOT is unset or files are missing, skip filtering.
        filter_root = os.environ.get("GRASP_FILTER_ROOT")
        if filter_root:
            list1_path = os.path.join(filter_root, dataset, "all_low_points_list.csv")
            list2_path = os.path.join(filter_root, dataset, "all_no_points_list.csv")

            if os.path.exists(list1_path) and os.path.exists(list2_path):
                list1 = pd.read_csv(list1_path)
                list2 = pd.read_csv(list2_path)

                # Check whether (cell, gene) is in the filter lists.
                is_in_list1 = ((list1["cell"] == cell) & (list1["gene"] == gene)).any()
                is_in_list2 = ((list2["cell"] == cell) & (list2["gene"] == gene)).any()
                if is_in_list1 or is_in_list2:
                    print(f"Skipping {gene} in {cell} (found in filter list).")
                    return graphs, aug_graphs

    # Original graph; missing node/adjacency files abort the whole pair.
    graph = _load_ratio_graph(raw_path, gene, cell)
    if graph is None:
        return graphs, aug_graphs
    graphs.append(graph)

    # Augmented graph; if its files are missing, keep the original only.
    aug_graph = _load_ratio_graph(aug_path, gene, cell)
    if aug_graph is None:
        return graphs, aug_graphs
    aug_graphs.append(aug_graph)

    return graphs, aug_graphs


def _load_ratio_graph(folder, gene, cell):
    """Load one graph from `{folder}/{gene}_node_matrix.csv` + adj matrix.

    Unlike _load_count_graph-style loaders, the 12-dim embedding is computed
    from the per-node count ratio (count / total count) rather than the raw
    count. Returns a Data object, or None when either file is missing.
    """
    nodes_file = f"{folder}/{gene}_node_matrix.csv"
    adj_file = f"{folder}/{gene}_adj_matrix.csv"
    if not os.path.exists(nodes_file) or not os.path.exists(adj_file):
        return None

    node_features = pd.read_csv(nodes_file, usecols=[1, 2, 3, 4, 5])
    total_count = node_features["count"].sum()
    node_features["count_ratio"] = node_features["count"] / total_count
    node_features["nuclear_position"] = (
        node_features["nuclear_position"]
        .map({"inside": 0, "outside": 1, "boundary": 2, "edge": 3})
        .fillna(4)
        .astype(int)
    )
    count_features = (
        node_features["count_ratio"]
        .apply(lambda x: emb.nonlinear_transform_embedding(x, dim=12))
        .tolist()
    )
    count_features = pd.DataFrame(
        count_features, columns=[f"dim_{i}" for i in range(12)]
    )
    position_one_hot = pd.get_dummies(
        node_features["nuclear_position"], prefix="pos"
    ).astype(int)
    features = pd.concat([count_features, position_one_hot], axis=1)
    features_tensor = torch.tensor(features.values, dtype=torch.float)

    adj_matrix = pd.read_csv(adj_file)
    # np.array(...) avoids torch's warning about a tuple of index arrays.
    edge_index = torch.tensor(np.array(adj_matrix.values.nonzero()), dtype=torch.long)
    return Data(x=features_tensor, edge_index=edge_index, cell=cell, gene=gene)
264
+
265
+
266
def generate_graph_data_parallel_nofiltered(
    dataset, cell_list, gene_list, path, n_sectors, m_rings, k_neighbor, n_jobs=4
):
    """Run process_cell_gene_nofiltered over every (cell, gene) combination
    in parallel and return the concatenated (graphs, aug_graphs)."""
    pairs = [(c, g) for c in cell_list for g in gene_list]
    progress = tqdm(pairs, desc="Processing cells and genes", total=len(pairs))
    per_pair = Parallel(n_jobs=n_jobs)(
        delayed(process_cell_gene_nofiltered)(
            c, g, dataset, path, n_sectors, m_rings, k_neighbor
        )
        for c, g in progress
    )

    graphs, aug_graphs = [], []
    for pair_graphs, pair_aug_graphs in per_pair:
        graphs += pair_graphs
        aug_graphs += pair_aug_graphs
    return graphs, aug_graphs
290
+
291
def generate_graph_data_target(dataset, df, path, n_sectors, m_rings, k_neighbor):
    """Build (graph, augmented-graph) pairs for each (cell, gene) row of df,
    using count-ratio node features and the precomputed adjacency CSVs.

    NOTE(review): in the released package this body sat at module level with
    a bare ``return graphs, aug_graphs`` at the end — a SyntaxError — and
    referenced undefined names (``df``, ``path``, …); the original ``def``
    line was evidently lost. The signature here is reconstructed from the
    names the body uses and from the sibling
    ``generate_graph_data_target_weight``; confirm against upstream.
    ``dataset`` is accepted for signature parity but not used.

    Returns:
        (graphs, aug_graphs): one Data per df row in each list. Missing
        files raise (no existence checks, matching the recovered code).
    """
    graphs = []
    aug_graphs = []
    # df = df[df['groundtruth_yyzh'].isin(['Nuclear', 'Nuclear_edge', 'Cytoplasmic', 'Cell_edge', 'Random'])].reset_index(drop=True)

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing Graphs"):
        cell = row["cell"]
        gene = row["gene"]

        raw_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}"
        aug_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}_aug"

        # Original graph
        nodes_file = f"{raw_path}/{gene}_node_matrix.csv"
        adj_file = f"{raw_path}/{gene}_adj_matrix.csv"

        node_features = pd.read_csv(nodes_file, usecols=[1, 2, 3, 4, 5])
        total_count = node_features["count"].sum()
        node_features["count_ratio"] = node_features["count"] / total_count

        node_features["nuclear_position"] = (
            node_features["nuclear_position"]
            .map({"inside": 0, "outside": 1, "boundary": 2, "edge": 3})
            .fillna(4)
            .astype(int)
        )
        count_features = (
            node_features["count_ratio"]
            .apply(lambda x: emb.nonlinear_transform_embedding(x, dim=12))
            .tolist()
        )
        count_features = pd.DataFrame(
            count_features, columns=[f"dim_{i}" for i in range(12)]
        )
        position_one_hot = pd.get_dummies(
            node_features["nuclear_position"], prefix="pos"
        ).astype(int)
        node_features = pd.concat([count_features, position_one_hot], axis=1)
        node_features_tensor = torch.tensor(node_features.values, dtype=torch.float)
        adj_matrix = pd.read_csv(adj_file)
        edge_index = torch.tensor(
            np.array(adj_matrix.values.nonzero()), dtype=torch.long
        )
        graph = Data(
            x=node_features_tensor, edge_index=edge_index, cell=cell, gene=gene
        )
        graphs.append(graph)

        # Augmented graph
        nodes_file = f"{aug_path}/{gene}_node_matrix.csv"
        adj_file = f"{aug_path}/{gene}_adj_matrix.csv"

        aug_node_features = pd.read_csv(nodes_file, usecols=[1, 2, 3, 4, 5])
        total_count = aug_node_features["count"].sum()
        aug_node_features["count_ratio"] = aug_node_features["count"] / total_count
        aug_node_features["nuclear_position"] = (
            aug_node_features["nuclear_position"]
            .map({"inside": 0, "outside": 1, "boundary": 2, "edge": 3})
            .fillna(4)
            .astype(int)
        )
        aug_count_features = (
            aug_node_features["count_ratio"]
            .apply(lambda x: emb.nonlinear_transform_embedding(x, dim=12))
            .tolist()
        )
        aug_count_features = pd.DataFrame(
            aug_count_features, columns=[f"dim_{i}" for i in range(12)]
        )
        aug_position_one_hot = pd.get_dummies(
            aug_node_features["nuclear_position"], prefix="pos"
        ).astype(int)
        aug_node_features = pd.concat(
            [aug_count_features, aug_position_one_hot], axis=1
        )
        aug_node_features_tensor = torch.tensor(
            aug_node_features.values, dtype=torch.float
        )
        aug_adj_matrix = pd.read_csv(adj_file)
        aug_edge_index = torch.tensor(
            np.array(aug_adj_matrix.values.nonzero()), dtype=torch.long
        )
        aug_graph = Data(
            x=aug_node_features_tensor, edge_index=aug_edge_index, cell=cell, gene=gene
        )
        aug_graphs.append(aug_graph)
    return graphs, aug_graphs
382
+
383
+
384
def generate_graph_data_target_weight(
    dataset,
    df,
    path,
    n_sectors,
    m_rings,
    k_neighbor,
    knn_k_for_similarity_graph=None,
    knn_k_ratio=0.25,
    knn_k_min=3,
):
    """Build (graph, augmented-graph) pairs whose edges come from feature
    similarity rather than the precomputed adjacency CSVs.

    For each (cell, gene) row of *df*, node features are an 11-dim nonlinear
    embedding of the per-node count ratio, a 2-way one-hot of `is_edge`, and
    a 3-way one-hot of `nuclear_position` (inside/outside/boundary). Rows
    flagged `is_virtual == 1` are dropped before feature computation. Edges
    connect each node to its top-k most cosine-similar nodes (symmetrized).

    Args:
        dataset: unused here; kept for signature parity with siblings.
        df: DataFrame with at least `cell` and `gene` columns.
        path, n_sectors, m_rings, k_neighbor: locate the node-matrix CSVs.
        knn_k_for_similarity_graph: explicit k for the similarity kNN graph;
            when None/non-positive, k adapts to the node count.
        knn_k_ratio, knn_k_min: parameters of the adaptive-k rule
            k = max(knn_k_min, int(num_nodes * knn_k_ratio)).

    Returns:
        (graphs, aug_graphs). A row whose raw file has no real nodes is
        skipped entirely; one whose augmented file has no real nodes keeps
        only its original graph, so the lists are not guaranteed pairwise
        aligned. Missing node files raise (no existence checks here).
    """
    graphs = []
    aug_graphs = []
    # df = df[df['groundtruth_yyzh'].isin(['Nuclear', 'Nuclear_edge', 'Cytoplasmic', 'Cell_edge', 'Random'])].reset_index(drop=True)

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing Graphs"):
        cell = row["cell"]
        gene = row["gene"]

        raw_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}"
        aug_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}_aug"

        # Original graph
        nodes_file = f"{raw_path}/{gene}_node_matrix.csv"

        node_data_df = pd.read_csv(
            nodes_file, usecols=["count", "is_edge", "nuclear_position", "is_virtual"]
        )

        # --- Filter out virtual nodes ---
        if "is_virtual" in node_data_df.columns:
            node_data_df = node_data_df[node_data_df["is_virtual"] == 0].reset_index(
                drop=True
            )
        # --- End filtering ---

        if node_data_df.empty:  # If no real nodes left
            continue

        total_count = node_data_df["count"].sum()
        if (
            total_count == 0
        ):  # Should not happen if only real nodes with counts are kept, but good for safety
            node_data_df["count_ratio"] = 0
        else:
            node_data_df["count_ratio"] = node_data_df["count"] / total_count

        count_features_embedded = (
            node_data_df["count_ratio"]
            .apply(lambda x: emb.nonlinear_transform_embedding(x, dim=11))
            .tolist()
        )
        count_features_final_df = pd.DataFrame(
            count_features_embedded,
            columns=[f"count_dim_{i}" for i in range(11)],
            index=node_data_df.index,
        )  # Use filtered df index

        is_edge_one_hot_df = pd.get_dummies(
            node_data_df["is_edge"], prefix="edge", dtype=int
        )
        # Ensure consistent columns for is_edge (edge_0, edge_1)
        for val in [0, 1]:
            col_name = f"edge_{val}"
            if col_name not in is_edge_one_hot_df.columns:
                is_edge_one_hot_df[col_name] = 0
        # NOTE: the intersection + re-add below is redundant with the loop
        # above (both columns are guaranteed present by now) but harmless.
        is_edge_one_hot_df = is_edge_one_hot_df[
            is_edge_one_hot_df.columns.intersection(["edge_0", "edge_1"])
        ]  # Ensure only existing columns selected
        if "edge_0" not in is_edge_one_hot_df.columns:
            is_edge_one_hot_df["edge_0"] = 0
        if "edge_1" not in is_edge_one_hot_df.columns:
            is_edge_one_hot_df["edge_1"] = 0
        is_edge_one_hot_df = is_edge_one_hot_df[["edge_0", "edge_1"]]

        nuclear_position_one_hot_df = pd.get_dummies(
            node_data_df["nuclear_position"], prefix="pos", dtype=int
        )
        expected_pos_categories = ["inside", "outside", "boundary"]
        for cat in expected_pos_categories:
            col_name = f"pos_{cat}"
            if col_name not in nuclear_position_one_hot_df.columns:
                nuclear_position_one_hot_df[col_name] = 0
        # Ensure correct order and presence of all expected columns
        nuclear_position_one_hot_df = nuclear_position_one_hot_df.reindex(
            columns=[f"pos_{cat}" for cat in expected_pos_categories], fill_value=0
        )

        final_node_features_df = pd.concat(
            [count_features_final_df, is_edge_one_hot_df, nuclear_position_one_hot_df],
            axis=1,
        )

        # No need to set virtual node features to 0 as they are already removed

        node_features_tensor = torch.tensor(
            final_node_features_df.values, dtype=torch.float
        )

        num_nodes = node_features_tensor.shape[0]
        # Default: no edges (covers the single-node and k == 0 cases).
        edge_index = torch.empty(
            (2, 0), dtype=torch.long, device=node_features_tensor.device
        )

        if num_nodes > 1:
            # Cosine similarity via L2-normalized dot products.
            norm_node_features = F.normalize(node_features_tensor, p=2, dim=1)
            similarity_matrix = torch.matmul(norm_node_features, norm_node_features.t())

            # Set diagonal to a very small value to prevent self-loops from being top-k
            similarity_matrix.fill_diagonal_(float("-inf"))

            # --- Dynamically determine k ---
            # 1) If the caller explicitly passes a positive integer (knn_k_for_similarity_graph),
            #    we will respect it but still cap it with (num_nodes-1).
            # 2) If the arg is None or a non-positive value, we compute k as
            #       k = max(1, int(num_nodes * knn_k_ratio))
            #    and again cap it with (num_nodes-1).
            if (
                isinstance(knn_k_for_similarity_graph, int)
                and knn_k_for_similarity_graph > 0
            ):
                current_knn_k = min(knn_k_for_similarity_graph, num_nodes - 1)
            else:
                adaptive_k = max(knn_k_min, int(num_nodes * knn_k_ratio))
                current_knn_k = min(adaptive_k, num_nodes - 1)

            if current_knn_k > 0:
                # Get top-k similar nodes for each node (values are unused).
                top_k_vals, top_k_indices = torch.topk(
                    similarity_matrix, k=current_knn_k, dim=1
                )

                # Create adjacency matrix from top-k indices
                adj = torch.zeros_like(similarity_matrix, dtype=torch.bool)
                row_indices = (
                    torch.arange(num_nodes, device=node_features_tensor.device)
                    .unsqueeze(1)
                    .expand(-1, current_knn_k)
                )
                adj[row_indices, top_k_indices] = True

                # Symmetrize the adjacency matrix
                adj = adj | adj.t()

                # No need to remove virtual node edges as virtual nodes themselves are removed
                edge_index = adj.nonzero(as_tuple=False).t().contiguous()

        graph = Data(
            x=node_features_tensor, edge_index=edge_index, cell=cell, gene=gene
        )
        graphs.append(graph)

        # Augmented graph (same pipeline as above, on the *_aug node file).
        nodes_file_aug = f"{aug_path}/{gene}_node_matrix.csv"
        aug_node_data_df = pd.read_csv(
            nodes_file_aug,
            usecols=["count", "is_edge", "nuclear_position", "is_virtual"],
        )

        # --- Filter out virtual nodes for augmented graph ---
        if "is_virtual" in aug_node_data_df.columns:
            aug_node_data_df = aug_node_data_df[
                aug_node_data_df["is_virtual"] == 0
            ].reset_index(drop=True)
        # --- End filtering ---

        if aug_node_data_df.empty:  # If no real nodes left in augmented graph
            # If original graph was added, but augmented has no real nodes, we might still want the original.
            # Current behavior: skip adding an augmented graph, original graph is already in the list.
            # If you require pairs, you might need to remove the original graph too or handle this case differently.
            continue

        aug_total_count = aug_node_data_df["count"].sum()
        if aug_total_count == 0:
            aug_node_data_df["count_ratio"] = 0
        else:
            aug_node_data_df["count_ratio"] = (
                aug_node_data_df["count"] / aug_total_count
            )

        aug_count_features_embedded = (
            aug_node_data_df["count_ratio"]
            .apply(lambda x: emb.nonlinear_transform_embedding(x, dim=11))
            .tolist()
        )
        aug_count_features_final_df = pd.DataFrame(
            aug_count_features_embedded,
            columns=[f"count_dim_{i}" for i in range(11)],
            index=aug_node_data_df.index,
        )  # Use filtered df index

        aug_is_edge_one_hot_df = pd.get_dummies(
            aug_node_data_df["is_edge"], prefix="edge", dtype=int
        )
        for val in [0, 1]:
            col_name = f"edge_{val}"
            if col_name not in aug_is_edge_one_hot_df.columns:
                aug_is_edge_one_hot_df[col_name] = 0
        aug_is_edge_one_hot_df = aug_is_edge_one_hot_df[
            aug_is_edge_one_hot_df.columns.intersection(["edge_0", "edge_1"])
        ]
        if "edge_0" not in aug_is_edge_one_hot_df.columns:
            aug_is_edge_one_hot_df["edge_0"] = 0
        if "edge_1" not in aug_is_edge_one_hot_df.columns:
            aug_is_edge_one_hot_df["edge_1"] = 0
        aug_is_edge_one_hot_df = aug_is_edge_one_hot_df[["edge_0", "edge_1"]]

        aug_nuclear_position_one_hot_df = pd.get_dummies(
            aug_node_data_df["nuclear_position"], prefix="pos", dtype=int
        )
        # Reuses expected_pos_categories defined in the original-graph section.
        for cat in expected_pos_categories:
            col_name = f"pos_{cat}"
            if col_name not in aug_nuclear_position_one_hot_df.columns:
                aug_nuclear_position_one_hot_df[col_name] = 0
        aug_nuclear_position_one_hot_df = aug_nuclear_position_one_hot_df.reindex(
            columns=[f"pos_{cat}" for cat in expected_pos_categories], fill_value=0
        )

        aug_final_node_features_df = pd.concat(
            [
                aug_count_features_final_df,
                aug_is_edge_one_hot_df,
                aug_nuclear_position_one_hot_df,
            ],
            axis=1,
        )

        # No need to set virtual node features to 0 as they are already removed

        aug_node_features_tensor = torch.tensor(
            aug_final_node_features_df.values, dtype=torch.float
        )

        num_aug_nodes = aug_node_features_tensor.shape[0]
        aug_edge_index = torch.empty(
            (2, 0), dtype=torch.long, device=aug_node_features_tensor.device
        )

        if num_aug_nodes > 1:
            norm_aug_node_features = F.normalize(aug_node_features_tensor, p=2, dim=1)
            aug_similarity_matrix = torch.matmul(
                norm_aug_node_features, norm_aug_node_features.t()
            )
            aug_similarity_matrix.fill_diagonal_(float("-inf"))

            # Same adaptive-k logic for augmented graphs
            if (
                isinstance(knn_k_for_similarity_graph, int)
                and knn_k_for_similarity_graph > 0
            ):
                current_knn_k_aug = min(knn_k_for_similarity_graph, num_aug_nodes - 1)
            else:
                adaptive_k_aug = max(knn_k_min, int(num_aug_nodes * knn_k_ratio))
                current_knn_k_aug = min(adaptive_k_aug, num_aug_nodes - 1)

            if current_knn_k_aug > 0:
                aug_top_k_vals, aug_top_k_indices = torch.topk(
                    aug_similarity_matrix, k=current_knn_k_aug, dim=1
                )
                aug_adj = torch.zeros_like(aug_similarity_matrix, dtype=torch.bool)
                aug_row_indices = (
                    torch.arange(num_aug_nodes, device=aug_node_features_tensor.device)
                    .unsqueeze(1)
                    .expand(-1, current_knn_k_aug)
                )
                aug_adj[aug_row_indices, aug_top_k_indices] = True
                aug_adj = aug_adj | aug_adj.t()

                # No need to remove virtual node edges as virtual nodes themselves are removed
                aug_edge_index = aug_adj.nonzero(as_tuple=False).t().contiguous()

        aug_graph = Data(
            x=aug_node_features_tensor, edge_index=aug_edge_index, cell=cell, gene=gene
        )
        aug_graphs.append(aug_graph)
    return graphs, aug_graphs
660
+
661
+
662
def generate_graph_data_target_weight2(
    dataset, df, path, n_sectors, m_rings, k_neighbor
):
    """Build (graph, augmented-graph) pairs for each (cell, gene) row of df.

    Node features are a 12-dim nonlinear embedding of the per-node count
    ratio plus a 4-way one-hot of `nuclear_position`; rows flagged
    `is_virtual == 1` keep their position in the graph but have all features
    zeroed. Edges come from the precomputed adjacency CSVs.

    Fix(review): in the released code the Data-construction/append lines
    were commented out, so this function did all the feature work and then
    always returned two empty lists; they are restored here following the
    commented-out intent (matching the sibling generators). Rows are only
    appended when both the original and augmented files load, so the two
    returned lists stay pairwise aligned. Unmapped `nuclear_position`
    values are fillna(4)'d (all-zero one-hot) instead of crashing the int
    cast — consistent with the other loaders in this module.

    Args:
        dataset: unused; kept for signature parity with siblings.
        df: DataFrame with at least `cell` and `gene` columns.
        path, n_sectors, m_rings, k_neighbor: locate the per-cell CSVs.

    Returns:
        (graphs, aug_graphs): aligned lists of torch_geometric Data.
    """
    graphs = []
    aug_graphs = []

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing Graphs"):
        cell = row["cell"]
        gene = row["gene"]

        raw_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}"
        aug_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}_aug"

        # Original graph
        nodes_file = f"{raw_path}/{gene}_node_matrix.csv"
        adj_file = f"{raw_path}/{gene}_adj_matrix.csv"

        try:
            node_features_df = pd.read_csv(
                nodes_file,
                usecols=["count", "is_edge", "nuclear_position", "is_virtual"],
            )
        except FileNotFoundError:
            continue
        if (
            node_features_df.empty
            or "count" not in node_features_df.columns
            or "nuclear_position" not in node_features_df.columns
        ):
            continue
        if "is_virtual" not in node_features_df.columns:
            node_features_df["is_virtual"] = 0

        node_features_tensor = _weight2_node_features(node_features_df)

        try:
            adj_matrix = pd.read_csv(adj_file)
        except FileNotFoundError:
            continue
        edge_index = torch.tensor(
            np.array(adj_matrix.values.nonzero()), dtype=torch.long
        )

        # Augmented graph (read without usecols — the aug file's extra
        # columns are ignored by the feature helper).
        nodes_file_aug = f"{aug_path}/{gene}_node_matrix.csv"
        adj_file_aug = f"{aug_path}/{gene}_adj_matrix.csv"

        try:
            aug_node_features_df = pd.read_csv(nodes_file_aug)
        except FileNotFoundError:
            continue
        if (
            aug_node_features_df.empty
            or "count" not in aug_node_features_df.columns
            or "nuclear_position" not in aug_node_features_df.columns
        ):
            continue
        if "is_virtual" not in aug_node_features_df.columns:
            aug_node_features_df["is_virtual"] = 0

        aug_node_features_tensor = _weight2_node_features(aug_node_features_df)

        try:
            aug_adj_matrix = pd.read_csv(adj_file_aug)
        except FileNotFoundError:
            continue
        aug_edge_index = torch.tensor(
            np.array(aug_adj_matrix.values.nonzero()), dtype=torch.long
        )

        # Append both graphs only once the whole pair is known to be valid.
        graphs.append(
            Data(x=node_features_tensor, edge_index=edge_index, cell=cell, gene=gene)
        )
        aug_graphs.append(
            Data(
                x=aug_node_features_tensor,
                edge_index=aug_edge_index,
                cell=cell,
                gene=gene,
            )
        )

    return graphs, aug_graphs


def _weight2_node_features(node_features_df):
    """Feature pipeline shared by generate_graph_data_target_weight2.

    Builds a 12-dim count-ratio embedding plus a pos_0..pos_3 one-hot of
    `nuclear_position`, zeroes rows flagged `is_virtual == 1`, and returns
    a float tensor (one row per input row, order preserved).
    """
    expected_pos_cols = [f"pos_{i}" for i in range(4)]

    calc = node_features_df.copy()
    total_count = calc["count"].sum()
    calc["count_ratio"] = 0 if total_count == 0 else calc["count"] / total_count

    # Assumes emb.nonlinear_transform_embedding is available
    count_features_final_df = pd.DataFrame(
        calc["count_ratio"]
        .apply(lambda x: emb.nonlinear_transform_embedding(x, dim=12))
        .tolist(),
        columns=[f"dim_{i}" for i in range(12)],
        index=calc.index,
    )

    # fillna(4) keeps unmapped categories from crashing the int cast; the
    # resulting pos_4 dummy is dropped below, leaving those rows all-zero.
    calc["nuclear_position_mapped"] = (
        calc["nuclear_position"]
        .map({"inside": 0, "outside": 1, "boundary": 2, "edge": 3})
        .fillna(4)
        .astype(np.int64)
    )
    position_one_hot_df = pd.get_dummies(
        calc["nuclear_position_mapped"], prefix="pos", dtype=int
    )
    for col in expected_pos_cols:
        if col not in position_one_hot_df.columns:
            position_one_hot_df[col] = 0
    position_one_hot_df = position_one_hot_df[expected_pos_cols]

    final_df = pd.concat(
        [count_features_final_df, position_one_hot_df], axis=1
    ).astype(float)
    # Virtual nodes stay in the graph but carry no signal.
    final_df.loc[node_features_df["is_virtual"] == 1, :] = 0.0
    return torch.tensor(final_df.values, dtype=torch.float)
805
+
806
+
807
+ # def process_graph_row_safe(args):
808
+ # row, path, n_sectors, m_rings, k_neighbor = args
809
+ # cell = row['cell']
810
+ # gene = row['gene']
811
+ # graph_data = []
812
+
813
+ # for is_aug in [False, True]:
814
+ # base_path = f"{path}/{cell}"
815
+ # if is_aug:
816
+ # base_path += "_aug"
817
+
818
+ # nodes_file = f'{base_path}/{gene}_node_matrix.csv'
819
+ # adj_file = f'{base_path}/{gene}_adj_matrix.csv'
820
+
821
+ # try:
822
+ # node_df = pd.read_csv(nodes_file, usecols=[1, 2, 3, 4, 5])
823
+ # total_count = node_df['count'].sum()
824
+ # if total_count == 0:
825
+ # return None
826
+
827
+ # count_ratio = (node_df['count'] / total_count).to_numpy()
828
+ # count_embed_np = np.array([emb.nonlinear_transform_embedding(x, dim=12) for x in count_ratio])
829
+ # pos_int = node_df['nuclear_position'].map(
830
+ # {'inside': 0, 'outside': 1, 'boundary': 2, 'edge': 3}
831
+ # ).fillna(4).astype(int).to_numpy()
832
+ # pos_one_hot = np.eye(4)[pos_int]
833
+ # node_features_np = np.hstack([count_embed_np, pos_one_hot])
834
+
835
+ # adj_matrix = pd.read_csv(adj_file).values
836
+ # edge_index_np = np.array(np.nonzero(adj_matrix))
837
+
838
+ # graph_data.append({
839
+ # 'x': node_features_np,
840
+ # 'edge_index': edge_index_np,
841
+ # 'cell': cell,
842
+ # 'gene': gene
843
+ # })
844
+
845
+ # except Exception as e:
846
+ # print(f"Error processing {nodes_file}: {e}")
847
+ # return None
848
+
849
+ # return graph_data[0], graph_data[1]
850
+
851
+ # def generate_graph_data_target_parallel(dataset, df, path, n_sectors, m_rings, k_neighbor):
852
+ # args_list = [(row, path, n_sectors, m_rings, k_neighbor) for _, row in df.iterrows()]
853
+
854
+ # multiprocessing.set_start_method("spawn", force=True)
855
+
856
+ # with Pool(processes=4) as pool:
857
+ # results = list(tqdm(pool.imap_unordered(process_graph_row_safe, args_list, chunksize=10), total=len(df), desc="Parallel Processing Graphs"))
858
+
859
+ # original_graphs, augmented_graphs = [], []
860
+
861
+ # for result in results:
862
+ # if result is None:
863
+ # continue
864
+ # for idx, (graph_list, target) in enumerate(zip([original_graphs, augmented_graphs], result)):
865
+ # data = Data(
866
+ # x=torch.tensor(target['x'], dtype=torch.float),
867
+ # edge_index=torch.tensor(target['edge_index'], dtype=torch.long),
868
+ # cell=target['cell'],
869
+ # gene=target['gene']
870
+ # )
871
+ # graph_list.append(data)
872
+
873
+ # return original_graphs, augmented_graphs
874
+
875
+ import pandas as pd
876
+ import numpy as np
877
+ import torch
878
+ from torch_geometric.data import Data
879
+ from tqdm.auto import tqdm
880
+ import multiprocessing
881
+ from multiprocessing import Pool
882
+ import functools
883
+
884
+
885
def process_single_gene_pair(args):
    """Process one (cell, gene) pair inside a worker process.

    Builds node features (12-dim nonlinear count-ratio embedding from
    ``emb.nonlinear_transform_embedding`` + 4-dim one-hot nuclear position)
    and an edge index for BOTH the original directory ``{path}/{cell}`` and
    the augmented directory ``{path}/{cell}_aug``.

    Args:
        args: tuple ``(row, path, n_sectors, m_rings, k_neighbor)`` where
            ``row`` is a DataFrame row exposing 'cell' and 'gene'.
            NOTE: n_sectors/m_rings/k_neighbor are unpacked but not used in
            the file paths here — kept for signature parity with callers.

    Returns:
        ``(orig_dict, aug_dict)`` — each a plain dict with keys
        'x' (np.ndarray node features), 'edge_index' (2xE np.ndarray),
        'cell', 'gene' — or ``None`` if EITHER graph fails for any reason
        (missing file, missing columns, zero total count, or any exception).
        Plain dicts (not Data objects) are returned so they pickle cheaply
        across the multiprocessing boundary.
    """
    row, path, n_sectors, m_rings, k_neighbor = args
    cell = row["cell"]
    gene = row["gene"]

    results = {}

    # First pass loads the original graph, second the augmented one; a
    # failure in the FIRST pass returns None before the aug dir is touched.
    for suffix in ["", "_aug"]:
        is_aug = suffix == "_aug"
        base_path = f"{path}/{cell}{suffix}"
        nodes_file = f"{base_path}/{gene}_node_matrix.csv"
        adj_file = f"{base_path}/{gene}_adj_matrix.csv"

        try:
            # Read node features.
            try:
                node_df = pd.read_csv(nodes_file)
            except FileNotFoundError:
                return None

            required_cols = ["count", "nuclear_position"]
            if not all(col in node_df.columns for col in required_cols):
                return None

            # Feature computation (count ratio embedding + position one-hot).
            total_count = node_df["count"].sum()
            if total_count == 0:
                # Avoid division by zero; an all-zero graph is useless anyway.
                return None

            count_ratio = node_df["count"] / total_count
            count_embed_list = [
                emb.nonlinear_transform_embedding(x, dim=12) for x in count_ratio
            ]
            count_embed_np = np.array(count_embed_list)

            # Position one-hot. Unknown labels map to code 4 and therefore
            # produce an all-zero row after reindexing to pos_0..pos_3.
            pos_map = {"inside": 0, "outside": 1, "boundary": 2, "edge": 3}
            node_df["pos_mapped"] = (
                node_df["nuclear_position"].map(pos_map).fillna(4).astype(int)
            )

            # NOTE: enforce dtype=int to avoid mixed bool/int columns which become
            # dtype=object when converting to numpy.
            pos_dummies = pd.get_dummies(node_df["pos_mapped"], prefix="pos", dtype=int)
            expected_cols = [f"pos_{i}" for i in range(4)]
            # Reindex guarantees a fixed 4-column layout across all graphs.
            pos_dummies = pos_dummies.reindex(columns=expected_cols, fill_value=0)
            pos_features_np = pos_dummies.to_numpy()

            # Merge features: columns are [12 count dims | 4 position dims].
            node_features_np = np.hstack([count_embed_np, pos_features_np])

            # 4. Adjacency handling (check empty matrices)
            try:
                adj_df = pd.read_csv(adj_file)
                if adj_df.empty:
                    # Graph with nodes but no edges -> empty (2, 0) edge index.
                    edge_index_np = np.empty((2, 0))
                else:
                    edge_index_np = np.array(adj_df.values.nonzero())
            except FileNotFoundError:
                return None

            # Store in a temporary dict
            key = "aug" if is_aug else "orig"
            results[key] = {
                "x": node_features_np,
                "edge_index": edge_index_np,
                "cell": cell,
                "gene": gene,
            }

        except Exception as e:
            # Log errors for debugging; keep output minimal in parallel runs
            # print(f"Error in {cell}-{gene}: {e}")
            return None

    # Only return when both original and augmented graphs succeed
    if "orig" in results and "aug" in results:
        return results["orig"], results["aug"]
    return None
965
+
966
+
967
def generate_graph_data_target_parallel(
    dataset, df, path, n_sectors, m_rings, k_neighbor, processes=8
):
    """Build original/augmented graph pairs for every row of ``df`` using a
    multiprocessing pool.

    Each (cell, gene) row is handed to ``process_single_gene_pair`` in a
    worker process; workers return plain numpy payloads, which are converted
    to ``torch_geometric.data.Data`` objects here in the parent process.
    Rows for which the worker returned ``None`` are silently dropped.

    Returns:
        (original_graphs, augmented_graphs): two parallel lists of ``Data``.
    """
    worker_args = [
        (row, path, n_sectors, m_rings, k_neighbor) for _, row in df.iterrows()
    ]

    # Force the spawn start method; some platforms disallow changing it twice.
    try:
        multiprocessing.set_start_method("spawn", force=True)
    except RuntimeError:
        pass

    def _to_data(payload):
        # Tensor conversion happens in the parent process (safer for torch).
        return Data(
            x=torch.tensor(payload["x"], dtype=torch.float),
            edge_index=torch.tensor(payload["edge_index"], dtype=torch.long),
            cell=payload["cell"],
            gene=payload["gene"],
        )

    original_graphs, augmented_graphs = [], []

    print(f"Starting parallel processing with {processes} cores...")

    with Pool(processes=processes) as pool:
        # imap_unordered is efficient because it does not preserve task order.
        pending = pool.imap_unordered(
            process_single_gene_pair, worker_args, chunksize=10
        )
        for outcome in tqdm(
            pending, total=len(worker_args), desc="Generating Target Graphs"
        ):
            if outcome is None:
                continue
            orig_payload, aug_payload = outcome
            original_graphs.append(_to_data(orig_payload))
            augmented_graphs.append(_to_data(aug_payload))

    return original_graphs, augmented_graphs
1013
+
1014
+
1015
+ ## AD/merfish mouse brain (df is the dataframe to load)
1016
def generate_graph_data_target_noposition(
    dataset, df, path, n_sectors, m_rings, k_neighbor
):
    """Load original and augmented graphs for each (cell, gene) row of ``df``
    using ONLY 16-dim count-ratio embeddings as node features (no
    nuclear-position one-hot).

    File errors are not caught here — a missing node/adjacency CSV raises.

    Returns:
        (graphs, aug_graphs): two parallel lists of ``Data`` objects.
    """

    def _load_one(graph_dir, cell, gene):
        # Read the node table (columns 1-4) and embed each node's share of
        # the total count into a 16-dim vector.
        node_table = pd.read_csv(
            f"{graph_dir}/{gene}_node_matrix.csv", usecols=[1, 2, 3, 4]
        )
        node_table["count_ratio"] = node_table["count"] / node_table["count"].sum()
        embedded = pd.DataFrame(
            node_table["count_ratio"]
            .apply(lambda v: emb.nonlinear_transform_embedding(v, dim=16))
            .tolist(),
            columns=[f"dim_{i}" for i in range(16)],
        )
        features = torch.tensor(embedded.values, dtype=torch.float)
        adjacency = pd.read_csv(f"{graph_dir}/{gene}_adj_matrix.csv")
        edges = torch.tensor(np.array(adjacency.values.nonzero()), dtype=torch.long)
        return Data(x=features, edge_index=edges, cell=cell, gene=gene)

    graphs, aug_graphs = [], []
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing Graphs"):
        cell, gene = row["cell"], row["gene"]
        base_dir = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}"
        # Original graph, then its augmented counterpart from the *_aug dir.
        graphs.append(_load_one(base_dir, cell, gene))
        aug_graphs.append(_load_one(f"{base_dir}_aug", cell, gene))
    return graphs, aug_graphs
1079
+
1080
+
1081
def process_cell_gene_noposition(
    cell, gene, dataset, path, n_sectors, m_rings, k_neighbor, base_path
):
    """Build the original and augmented graphs for one (cell, gene) pair,
    using ONLY 16-dim count-ratio embeddings as node features.

    The pair is skipped (empty lists returned) when the Wasserstein-distance
    filter file is missing, the pair is not flagged 'other' there, required
    graph CSVs are absent, or the graph has <= 5 nodes.

    BUG FIX: the original code embedded counts with dim=16 but built the
    feature DataFrame with only 8 column names (``range(8)``), which raises
    ValueError at runtime; it then discarded the embedded features and used
    the raw CSV values as node features, and the augmented branch embedded
    raw counts (not ratios) at dim=8. Both branches now use the count-ratio
    embedding at dim=16, consistent with
    ``generate_graph_data_target_noposition``.

    Returns:
        (graphs, aug_graphs): lists with at most one ``Data`` object each.
    """
    graphs = []
    aug_graphs = []

    raw_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}"
    aug_path = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}_aug"

    df_file = f"{base_path}/{gene}_distances_filter_new.csv"
    if not os.path.exists(df_file):
        print(f"1. Skipping {gene} in {cell} (file not found).")
        return graphs, aug_graphs

    df = pd.read_csv(df_file)
    filtered_df = df[
        (df["gene"] == gene) & (df["cell"] == cell) & (df["location"] == "other")
    ]
    if filtered_df.empty:
        # Pair not selected by the distance filter.
        return graphs, aug_graphs

    # Original graph
    nodes_file = f"{raw_path}/{gene}_node_matrix.csv"
    adj_file = f"{raw_path}/{gene}_adj_matrix.csv"
    if not os.path.exists(nodes_file) or not os.path.exists(adj_file):
        return graphs, aug_graphs

    node_features = pd.read_csv(nodes_file, usecols=[1, 2, 3, 4])
    if len(node_features) <= 5:
        # Too few points to form a meaningful graph.
        return graphs, aug_graphs

    total_count = node_features["count"].sum()
    node_features["count_ratio"] = node_features["count"] / total_count
    count_features = pd.DataFrame(
        node_features["count_ratio"]
        .apply(lambda x: emb.nonlinear_transform_embedding(x, dim=16))
        .tolist(),
        columns=[f"dim_{i}" for i in range(16)],
    )
    node_features_tensor = torch.tensor(count_features.values, dtype=torch.float)
    adj_matrix = pd.read_csv(adj_file)
    edge_index = torch.tensor(np.array(adj_matrix.values.nonzero()), dtype=torch.long)
    graph = Data(x=node_features_tensor, edge_index=edge_index, cell=cell, gene=gene)
    graphs.append(graph)

    # Augmented graph (returned alongside the original only when present).
    nodes_file = f"{aug_path}/{gene}_node_matrix.csv"
    adj_file = f"{aug_path}/{gene}_adj_matrix.csv"
    if not os.path.exists(nodes_file) or not os.path.exists(adj_file):
        return graphs, aug_graphs

    aug_node_features = pd.read_csv(nodes_file, usecols=[1, 2, 3, 4])
    aug_total_count = aug_node_features["count"].sum()
    aug_node_features["count_ratio"] = aug_node_features["count"] / aug_total_count
    aug_count_features = pd.DataFrame(
        aug_node_features["count_ratio"]
        .apply(lambda x: emb.nonlinear_transform_embedding(x, dim=16))
        .tolist(),
        columns=[f"dim_{i}" for i in range(16)],
    )
    aug_node_features_tensor = torch.tensor(
        aug_count_features.values, dtype=torch.float
    )
    aug_adj_matrix = pd.read_csv(adj_file)
    aug_edge_index = torch.tensor(
        np.array(aug_adj_matrix.values.nonzero()), dtype=torch.long
    )
    aug_graph = Data(
        x=aug_node_features_tensor, edge_index=aug_edge_index, cell=cell, gene=gene
    )
    aug_graphs.append(aug_graph)
    return graphs, aug_graphs
1158
+
1159
+
1160
def generate_graph_data_parallel_noposition(
    dataset,
    cell_list,
    gene_list,
    path,
    base_path,
    n_sectors,
    m_rings,
    k_neighbor,
    n_jobs=4,
):
    """Fan out ``process_cell_gene_noposition`` over every (cell, gene)
    combination with joblib, then flatten the per-pair results into two
    graph lists.

    Returns:
        (graphs, aug_graphs): flattened lists of ``Data`` objects.
    """
    pairs = [(c, g) for c in cell_list for g in gene_list]
    progress = tqdm(
        pairs, desc="Processing cells and genes", total=len(pairs)
    )

    per_pair = Parallel(n_jobs=n_jobs)(
        delayed(process_cell_gene_noposition)(
            c, g, dataset, path, n_sectors, m_rings, k_neighbor, base_path
        )
        for c, g in progress
    )

    # Each worker returns a (graphs, aug_graphs) pair of lists; flatten them.
    graphs, aug_graphs = [], []
    for pair_graphs, pair_aug_graphs in per_pair:
        graphs.extend(pair_graphs)
        aug_graphs.extend(pair_aug_graphs)

    return graphs, aug_graphs
1194
+
1195
+
1196
def generate_graph_data_target(dataset, df, path, n_sectors, m_rings, k_neighbor):
    """Load original and augmented graphs for each (cell, gene) row of ``df``.

    For each row this reads ``{path}/{cell}/{gene}_node_matrix.csv`` (original)
    and ``{path}/{cell}_aug/{gene}_node_matrix.csv`` (augmented) plus the
    matching adjacency CSVs, builds 16-dim node features (12-dim nonlinear
    count-ratio embedding + 4-dim one-hot nuclear position pos_0..pos_3), and
    wraps each graph in a torch_geometric ``Data`` object. A pair is appended
    to the output lists only when BOTH graphs load successfully.

    Note: ``dataset``, ``n_sectors``, ``m_rings`` and ``k_neighbor`` are not
    referenced in this function's body — they are kept for signature parity
    with the other loaders in this module.

    Returns:
        (graphs, aug_graphs): two parallel lists of ``Data`` objects.
    """
    graphs = []
    aug_graphs = []

    for _, row in tqdm(
        df.iterrows(),
        total=len(df),
        desc="Processing Graphs generate_graph_data_target",
    ):
        cell = row["cell"]
        gene = row["gene"]

        # --- Process Original Graph ---
        raw_path_base = f"{path}/{cell}"
        nodes_file_orig = f"{raw_path_base}/{gene}_node_matrix.csv"
        adj_file_orig = f"{raw_path_base}/{gene}_adj_matrix.csv"
        skip_current_pair = False
        original_graph_data = None

        try:
            # Read only necessary columns by name
            try:
                # Try reading with is_virtual first
                node_features_orig_df_raw = pd.read_csv(
                    nodes_file_orig, usecols=["count", "nuclear_position", "is_virtual"]
                )
            except ValueError:  # If is_virtual column doesn't exist
                try:
                    node_features_orig_df_raw = pd.read_csv(
                        nodes_file_orig, usecols=["count", "nuclear_position"]
                    )
                except FileNotFoundError:
                    print(
                        f"Original node_matrix file not found: {nodes_file_orig}. Skipping pair."
                    )
                    skip_current_pair = True
                    raise  # Re-raise to be caught by outer try-except
                except ValueError:  # If count or nuclear_position are missing
                    print(
                        f"Original node_matrix {nodes_file_orig} missing 'count' or 'nuclear_position'. Skipping pair."
                    )
                    skip_current_pair = True
                    raise

            # NOTE(review): if the FIRST read above raises FileNotFoundError
            # directly, it bypasses the ValueError handler and reaches the
            # outer except without setting skip_current_pair, so the augmented
            # branch still runs; the final pairing check drops the row anyway.
            if node_features_orig_df_raw.empty:
                print(
                    f"Original node_matrix {nodes_file_orig} is empty. Skipping pair."
                )
                skip_current_pair = True
                raise FileNotFoundError  # Treat as if file not found for outer catch

            node_features_orig_df = node_features_orig_df_raw.copy()

            # Guard against division by zero when the graph has no counts.
            total_count_orig = node_features_orig_df["count"].sum()
            node_features_orig_df["count_ratio"] = (
                0.0
                if total_count_orig == 0
                else node_features_orig_df["count"] / total_count_orig
            )

            # 12-dim nonlinear embedding of each node's count share.
            count_features_orig_list = (
                node_features_orig_df["count_ratio"]
                .apply(lambda x: emb.nonlinear_transform_embedding(x, dim=12))
                .tolist()
            )
            count_features_orig_embedded_df = pd.DataFrame(
                count_features_orig_list,
                columns=[f"dim_{i}" for i in range(12)],
                index=node_features_orig_df.index,
            )

            # Unknown position labels map to code 4 -> all-zero one-hot row
            # after restricting to pos_0..pos_3 below.
            node_features_orig_df["nuclear_position_mapped"] = (
                node_features_orig_df["nuclear_position"]
                .map({"inside": 0, "outside": 1, "boundary": 2, "edge": 3})
                .fillna(4)
                .astype(int)
            )
            position_orig_one_hot_df = pd.get_dummies(
                node_features_orig_df["nuclear_position_mapped"],
                prefix="pos",
                dtype=int,
            )

            # Ensure all graphs have the same position feature dimensionality (pos_0..pos_3)
            expected_pos_cols = ["pos_0", "pos_1", "pos_2", "pos_3"]
            for col in expected_pos_cols:
                if col not in position_orig_one_hot_df.columns:
                    position_orig_one_hot_df[col] = 0
            position_orig_one_hot_df = position_orig_one_hot_df[expected_pos_cols]

            final_node_features_orig_df = pd.concat(
                [count_features_orig_embedded_df, position_orig_one_hot_df], axis=1
            )
            node_features_orig_tensor = torch.tensor(
                final_node_features_orig_df.values, dtype=torch.float
            )

            try:
                adj_matrix_orig_df = pd.read_csv(adj_file_orig)
            except FileNotFoundError:
                print(
                    f"Original adj_matrix file not found: {adj_file_orig}. Skipping pair."
                )
                skip_current_pair = True
                raise

            # Empty adjacency (or empty graph) -> empty (2, 0) edge index.
            if adj_matrix_orig_df.empty and node_features_orig_tensor.shape[0] > 0:
                print(
                    f"Warning: Original adj_matrix {adj_file_orig} is empty for a graph with {node_features_orig_tensor.shape[0]} nodes."
                )
                edge_index_orig = torch.empty((2, 0), dtype=torch.long)
            elif node_features_orig_tensor.shape[0] == 0:
                edge_index_orig = torch.empty((2, 0), dtype=torch.long)
            else:
                edge_index_orig = torch.tensor(
                    np.array(adj_matrix_orig_df.values.nonzero()), dtype=torch.long
                )

            original_graph_data = Data(
                x=node_features_orig_tensor,
                edge_index=edge_index_orig,
                cell=cell,
                gene=gene,
            )

        except (
            FileNotFoundError,
            ValueError,
            KeyError,
        ) as e_orig:  # Catch specific errors for original graph processing
            # This block now effectively handles and reports on the re-raised errors from inner try-excepts.
            # The specific print statements inside the inner blocks provide more detail.
            pass  # Continue to next pair in the main loop
        except Exception as e_general_orig:
            print(
                f"Unexpected error processing original graph for {cell}-{gene}: {e_general_orig}. Skipping pair."
            )
            skip_current_pair = True
            pass

        if skip_current_pair:
            continue  # Move to the next cell-gene pair

        # --- Process Augmented Graph ---
        aug_path_base = f"{path}/{cell}_aug"
        nodes_file_aug = f"{aug_path_base}/{gene}_node_matrix.csv"
        adj_file_aug = f"{aug_path_base}/{gene}_adj_matrix.csv"
        augmented_graph_data = None

        try:
            # Same column-fallback strategy as for the original graph.
            try:
                node_features_aug_df_raw = pd.read_csv(
                    nodes_file_aug, usecols=["count", "nuclear_position", "is_virtual"]
                )
            except ValueError:
                try:
                    node_features_aug_df_raw = pd.read_csv(
                        nodes_file_aug, usecols=["count", "nuclear_position"]
                    )
                except FileNotFoundError:
                    print(
                        f"Augmented node_matrix file not found: {nodes_file_aug}. Original graph will be kept if processed."
                    )
                    raise  # Re-raise to be caught by outer try-except for augmented graph
                except ValueError:
                    print(
                        f"Augmented node_matrix {nodes_file_aug} missing 'count' or 'nuclear_position'. Original graph will be kept."
                    )
                    raise

            if node_features_aug_df_raw.empty:
                print(
                    f"Augmented node_matrix {nodes_file_aug} is empty. Original graph will be kept."
                )
                raise FileNotFoundError

            node_features_aug_df = node_features_aug_df_raw.copy()

            # Guard against division by zero, mirroring the original branch.
            total_count_aug = node_features_aug_df["count"].sum()
            node_features_aug_df["count_ratio"] = (
                0.0
                if total_count_aug == 0
                else node_features_aug_df["count"] / total_count_aug
            )

            count_features_aug_list = (
                node_features_aug_df["count_ratio"]
                .apply(lambda x: emb.nonlinear_transform_embedding(x, dim=12))
                .tolist()
            )
            count_features_aug_embedded_df = pd.DataFrame(
                count_features_aug_list,
                columns=[f"dim_{i}" for i in range(12)],
                index=node_features_aug_df.index,
            )

            node_features_aug_df["nuclear_position_mapped"] = (
                node_features_aug_df["nuclear_position"]
                .map({"inside": 0, "outside": 1, "boundary": 2, "edge": 3})
                .fillna(4)
                .astype(int)
            )
            position_aug_one_hot_df = pd.get_dummies(
                node_features_aug_df["nuclear_position_mapped"], prefix="pos", dtype=int
            )

            # Ensure augmented graphs also have the same position feature dimensionality (pos_0..pos_3)
            expected_pos_cols = ["pos_0", "pos_1", "pos_2", "pos_3"]
            for col in expected_pos_cols:
                if col not in position_aug_one_hot_df.columns:
                    position_aug_one_hot_df[col] = 0
            position_aug_one_hot_df = position_aug_one_hot_df[expected_pos_cols]

            final_node_features_aug_df = pd.concat(
                [count_features_aug_embedded_df, position_aug_one_hot_df], axis=1
            )
            node_features_aug_tensor = torch.tensor(
                final_node_features_aug_df.values, dtype=torch.float
            )

            try:
                adj_matrix_aug_df = pd.read_csv(adj_file_aug)
            except FileNotFoundError:
                print(
                    f"Augmented adj_matrix file not found: {adj_file_aug}. Original graph will be kept."
                )
                raise

            if adj_matrix_aug_df.empty and node_features_aug_tensor.shape[0] > 0:
                print(
                    f"Warning: Augmented adj_matrix {adj_file_aug} is empty for a graph with {node_features_aug_tensor.shape[0]} nodes."
                )
                edge_index_aug = torch.empty((2, 0), dtype=torch.long)
            elif node_features_aug_tensor.shape[0] == 0:
                edge_index_aug = torch.empty((2, 0), dtype=torch.long)
            else:
                edge_index_aug = torch.tensor(
                    np.array(adj_matrix_aug_df.values.nonzero()), dtype=torch.long
                )

            augmented_graph_data = Data(
                x=node_features_aug_tensor,
                edge_index=edge_index_aug,
                cell=cell,
                gene=gene,
            )

        except (FileNotFoundError, ValueError, KeyError) as e_aug:
            print(
                f"Failed to process augmented graph for {cell}-{gene}: {e_aug}. Original graph will be kept if available."
            )
            pass  # Keep original graph if it was successfully processed
        except Exception as e_general_aug:
            print(
                f"Unexpected error processing augmented graph for {cell}-{gene}: {e_general_aug}. Original graph will be kept if available."
            )
            pass

        # Add to lists if both original and augmented were processed successfully
        if original_graph_data and augmented_graph_data:
            graphs.append(original_graph_data)
            aug_graphs.append(augmented_graph_data)
        elif original_graph_data:  # Only original was successful
            print(
                f"Only original graph processed for {cell}-{gene}. Augmented failed or was missing."
            )
            # Decide if you want to add originals even if augmented is missing.
            # For now, we are only adding pairs. To change this, uncomment the lines below:
            # graphs.append(original_graph_data)
            # aug_graphs.append(original_graph_data) # Or some placeholder for aug_graph
            pass

    return graphs, aug_graphs
1469
+
1470
+
1471
def generate_graph_data_target_weight3(
    dataset, df, path, n_sectors, m_rings, k_neighbor
):
    """
    Robustly loads original and augmented graph data for a target dataframe.
    This version uses sinusoidal embedding for count features.
    This version creates a 16-dimensional node feature vector:
    - 12 dims for sinusoidal count embedding.
    - 3 dims for one-hot encoded nuclear position ('inside', 'outside', 'boundary').
    - 1 dim for the 'is_edge' flag.
    Virtual nodes are processed normally with their actual features.

    Paths are ``{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}``
    for the original graph and the same with an ``_aug`` suffix for the
    augmented one. ``dataset`` is not referenced in the body (signature
    parity with the other loaders). A (cell, gene) pair is appended to the
    outputs only when BOTH graphs load successfully.

    Returns:
        (graphs, aug_graphs): two parallel lists of ``Data`` objects.
    """
    graphs = []
    aug_graphs = []

    for _, row in tqdm(
        df.iterrows(),
        total=len(df),
        desc="Processing Graphs for generate_graph_data_target_weight3",
    ):
        cell = row["cell"]
        gene = row["gene"]

        # --- Process Original Graph ---
        raw_path_base = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}"
        nodes_file_orig = f"{raw_path_base}/{gene}_node_matrix.csv"
        adj_file_orig = f"{raw_path_base}/{gene}_adj_matrix.csv"
        original_graph_data = None

        try:
            # Step 1: Robustly read node data
            # Fall back through progressively smaller column sets, defaulting
            # any missing flag column to 0.
            try:
                node_features_orig_df_raw = pd.read_csv(
                    nodes_file_orig,
                    usecols=["count", "nuclear_position", "is_edge", "is_virtual"],
                )
            except ValueError:
                try:
                    node_features_orig_df_raw = pd.read_csv(
                        nodes_file_orig,
                        usecols=["count", "nuclear_position", "is_edge"],
                    )
                    node_features_orig_df_raw["is_virtual"] = (
                        0  # Assume not virtual if column is missing
                    )
                except ValueError:
                    try:
                        node_features_orig_df_raw = pd.read_csv(
                            nodes_file_orig, usecols=["count", "nuclear_position"]
                        )
                        node_features_orig_df_raw["is_edge"] = (
                            0  # Assume not edge if column is missing
                        )
                        node_features_orig_df_raw["is_virtual"] = 0
                    except (FileNotFoundError, ValueError) as e:
                        print(
                            f"Original node_matrix file error for {nodes_file_orig}: {e}. Skipping pair."
                        )
                        raise e from None

            if node_features_orig_df_raw.empty:
                print(
                    f"Original node_matrix {nodes_file_orig} is empty. Skipping pair."
                )
                raise FileNotFoundError

            node_features_orig_df = node_features_orig_df_raw.copy()

            # Step 2: Feature Engineering
            # Count embedding (12-dim); guard the zero-total division.
            total_count_orig = node_features_orig_df["count"].sum()
            node_features_orig_df["count_ratio"] = (
                0.0
                if total_count_orig == 0
                else node_features_orig_df["count"] / total_count_orig
            )
            count_features_list = (
                node_features_orig_df["count_ratio"]
                .apply(
                    lambda x: emb.get_sinusoidal_embedding_for_continuous_value(
                        x, dim=12
                    )
                )
                .tolist()
            )
            count_features_embedded_df = pd.DataFrame(
                count_features_list,
                columns=[f"dim_{i}" for i in range(12)],
                index=node_features_orig_df.index,
            )

            # Nuclear position one-hot (3-dim). Fixing the categories makes
            # get_dummies emit all three columns even when a label is absent.
            expected_pos_cats = ["inside", "outside", "boundary"]
            node_features_orig_df["nuclear_position_cat"] = pd.Categorical(
                node_features_orig_df["nuclear_position"], categories=expected_pos_cats
            )
            position_one_hot_df = pd.get_dummies(
                node_features_orig_df["nuclear_position_cat"], prefix="pos", dtype=int
            )

            # is_edge feature (1-dim)
            is_edge_df = node_features_orig_df[["is_edge"]]

            # Combine features: [12 count dims | 3 position dims | is_edge].
            final_node_features_df = pd.concat(
                [count_features_embedded_df, position_one_hot_df, is_edge_df], axis=1
            )

            # Virtual nodes keep their original features (no longer set to zero)
            node_features_orig_tensor = torch.tensor(
                final_node_features_df.values, dtype=torch.float
            )

            # Step 3: Load Adjacency Matrix
            try:
                adj_matrix_orig_df = pd.read_csv(adj_file_orig)
            except FileNotFoundError:
                print(
                    f"Original adj_matrix file not found: {adj_file_orig}. Skipping pair."
                )
                raise

            # Empty adjacency or empty graph -> empty (2, 0) edge index.
            if adj_matrix_orig_df.empty and node_features_orig_tensor.shape[0] > 0:
                edge_index_orig = torch.empty((2, 0), dtype=torch.long)
            elif node_features_orig_tensor.shape[0] == 0:
                edge_index_orig = torch.empty((2, 0), dtype=torch.long)
            else:
                edge_index_orig = torch.tensor(
                    np.array(adj_matrix_orig_df.values.nonzero()), dtype=torch.long
                )

            original_graph_data = Data(
                x=node_features_orig_tensor,
                edge_index=edge_index_orig,
                cell=cell,
                gene=gene,
            )

        except (FileNotFoundError, ValueError, KeyError) as e_orig:
            # Expected failures were already reported by the inner handlers.
            pass
        except Exception as e_general_orig:
            print(
                f"Unexpected error processing original graph for {cell}-{gene}: {e_general_orig}. Skipping pair."
            )
            pass

        if original_graph_data is None:
            continue

        # --- Process Augmented Graph ---
        aug_path_base = f"{path}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}_aug"
        nodes_file_aug = f"{aug_path_base}/{gene}_node_matrix.csv"
        adj_file_aug = f"{aug_path_base}/{gene}_adj_matrix.csv"
        augmented_graph_data = None

        try:
            # Step 1: Robustly read node data
            try:
                node_features_aug_df_raw = pd.read_csv(
                    nodes_file_aug,
                    usecols=["count", "nuclear_position", "is_edge", "is_virtual"],
                )
            except ValueError:
                try:
                    node_features_aug_df_raw = pd.read_csv(
                        nodes_file_aug, usecols=["count", "nuclear_position", "is_edge"]
                    )
                    node_features_aug_df_raw["is_virtual"] = 0
                except ValueError:
                    try:
                        node_features_aug_df_raw = pd.read_csv(
                            nodes_file_aug, usecols=["count", "nuclear_position"]
                        )
                        node_features_aug_df_raw["is_edge"] = 0
                        node_features_aug_df_raw["is_virtual"] = 0
                    except (FileNotFoundError, ValueError) as e:
                        print(
                            f"Augmented node_matrix file error for {nodes_file_aug}: {e}. Original graph will be kept if processed."
                        )
                        raise e from None

            if node_features_aug_df_raw.empty:
                print(
                    f"Augmented node_matrix {nodes_file_aug} is empty. Original graph will be kept."
                )
                raise FileNotFoundError

            node_features_aug_df = node_features_aug_df_raw.copy()

            # Step 2: Feature Engineering for augmented graph
            total_count_aug = node_features_aug_df["count"].sum()
            node_features_aug_df["count_ratio"] = (
                0.0
                if total_count_aug == 0
                else node_features_aug_df["count"] / total_count_aug
            )
            count_features_aug_list = (
                node_features_aug_df["count_ratio"]
                .apply(
                    lambda x: emb.get_sinusoidal_embedding_for_continuous_value(
                        x, dim=12
                    )
                )
                .tolist()
            )
            count_features_aug_embedded_df = pd.DataFrame(
                count_features_aug_list,
                columns=[f"dim_{i}" for i in range(12)],
                index=node_features_aug_df.index,
            )

            # expected_pos_cats is reused from the original-graph branch; that
            # branch always ran to (at least) its definition, otherwise the
            # `continue` above would have skipped this pair.
            node_features_aug_df["nuclear_position_cat"] = pd.Categorical(
                node_features_aug_df["nuclear_position"], categories=expected_pos_cats
            )
            position_aug_one_hot_df = pd.get_dummies(
                node_features_aug_df["nuclear_position_cat"], prefix="pos", dtype=int
            )

            is_edge_aug_df = node_features_aug_df[["is_edge"]]

            final_node_features_aug_df = pd.concat(
                [
                    count_features_aug_embedded_df,
                    position_aug_one_hot_df,
                    is_edge_aug_df,
                ],
                axis=1,
            )

            # Virtual nodes keep their original features (no longer set to zero)
            node_features_aug_tensor = torch.tensor(
                final_node_features_aug_df.values, dtype=torch.float
            )

            # Step 3: Load Adjacency Matrix for augmented graph
            try:
                adj_matrix_aug_df = pd.read_csv(adj_file_aug)
            except FileNotFoundError:
                print(
                    f"Augmented adj_matrix file not found: {adj_file_aug}. Original graph will be kept."
                )
                raise

            if adj_matrix_aug_df.empty and node_features_aug_tensor.shape[0] > 0:
                edge_index_aug = torch.empty((2, 0), dtype=torch.long)
            elif node_features_aug_tensor.shape[0] == 0:
                edge_index_aug = torch.empty((2, 0), dtype=torch.long)
            else:
                edge_index_aug = torch.tensor(
                    np.array(adj_matrix_aug_df.values.nonzero()), dtype=torch.long
                )

            augmented_graph_data = Data(
                x=node_features_aug_tensor,
                edge_index=edge_index_aug,
                cell=cell,
                gene=gene,
            )

        except (FileNotFoundError, ValueError, KeyError) as e_aug:
            # Expected failures were already reported by the inner handlers.
            pass
        except Exception as e_general_aug:
            print(
                f"Unexpected error processing augmented graph for {cell}-{gene}: {e_general_aug}. Original graph will be kept if available."
            )
            pass

        # Add to lists if both original and augmented were processed successfully
        if original_graph_data and augmented_graph_data:
            graphs.append(original_graph_data)
            aug_graphs.append(augmented_graph_data)
        elif original_graph_data:
            print(
                f"Only original graph processed for {cell}-{gene}. Augmented failed or was missing."
            )
            pass

    return graphs, aug_graphs