grasp-tool 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1 @@
1
+ """Preprocessing utilities (register/partition/portrait/etc.)."""
@@ -0,0 +1,66 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+
5
+
6
def rotate_nodes(node_matrix, angle):
    """Rotate the "x"/"y" coordinates about the origin by ``angle`` degrees.

    A positive angle rotates counter-clockwise (standard 2-D rotation
    matrix).

    Args:
        node_matrix: Table indexed by column name that holds numeric "x" and
            "y" entries (e.g. a pandas DataFrame or a NumPy structured
            array). It is modified in place.
        angle: Rotation angle in degrees.

    Returns:
        The same ``node_matrix`` object, with rotated "x" and "y".
    """
    theta = np.radians(angle)
    cos_theta, sin_theta = np.cos(theta), np.sin(theta)
    x, y = node_matrix["x"], node_matrix["y"]
    # Compute BOTH rotated columns before writing either one back. ``x`` may
    # be a view into ``node_matrix`` (structured arrays, in-place column
    # writes); assigning the new "x" first would then feed already-rotated
    # values into the "y" formula.
    rotated_x = x * cos_theta - y * sin_theta
    rotated_y = x * sin_theta + y * cos_theta
    node_matrix["x"] = rotated_x
    node_matrix["y"] = rotated_y
    return node_matrix
13
+
14
+
15
def dropout_nodes(adj_matrix, node_matrix, dropout_ratio):
    """Randomly turn a fraction of the real nodes into virtual nodes.

    Real nodes are the rows of ``node_matrix`` whose "is_virtual" value is 0.
    ``int(num_real * dropout_ratio)`` of them are drawn without replacement
    using NumPy's global random state. Each drawn node is flagged as virtual
    and its row and column in ``adj_matrix`` are zeroed.

    Args:
        adj_matrix: pandas object addressed positionally through ``.iloc``.
        node_matrix: pandas DataFrame with an "is_virtual" column. Its index
            labels are used as positions into ``adj_matrix``.
        dropout_ratio: Fraction of real nodes to drop.

    Returns:
        Tuple ``(adj_matrix, node_matrix)``; both inputs are modified in
        place and returned.
    """
    is_real = (node_matrix["is_virtual"] == 0).to_numpy()
    candidates = node_matrix.index[is_real]
    n_to_drop = int(len(candidates) * dropout_ratio)
    dropped = np.random.choice(candidates, size=n_to_drop, replace=False)

    # Flag the chosen nodes as virtual.
    node_matrix.loc[dropped, "is_virtual"] = 1

    # Disconnect each dropped node: clear its row, then its column.
    for node_id in dropped:
        adj_matrix.iloc[node_id, :] = 0
        adj_matrix.iloc[:, node_id] = 0

    return adj_matrix, node_matrix
26
+
27
+
28
def plot_graph(
    adj_matrix_before,
    node_matrix_before,
    adj_matrix_after,
    node_matrix_after,
    title,
    save_path,
):
    """Draw a graph before and after dropout side by side and save a PNG.

    For each panel, every adjacency entry equal to 1 is drawn as a gray
    segment between the two nodes, virtual nodes ("is_virtual" == 1) are
    drawn in light gray and real nodes ("is_virtual" == 0) in red.

    Args:
        adj_matrix_before: Adjacency table of the original graph (pandas
            object; read positionally).
        node_matrix_before: Node table of the original graph with "x", "y"
            and "is_virtual" columns. Rows are looked up by label, using the
            adjacency position as the label.
        adj_matrix_after: Adjacency table after dropout.
        node_matrix_after: Node table after dropout.
        title: Panel title prefix; also used in the output file name.
        save_path: Directory the image is written to, as
            ``{save_path}/{title}_comparison.png``.

    Returns:
        None.
    """
    fig, axes = plt.subplots(1, 2, figsize=(6, 3), sharex=True, sharey=True)
    panels = [
        (adj_matrix_before, node_matrix_before, f"{title} (Original)"),
        (adj_matrix_after, node_matrix_after, f"{title} (After Dropout)"),
    ]
    try:
        for ax, (adj_matrix, node_matrix, sub_title) in zip(axes, panels):
            ax.set_title(sub_title)

            # Find every edge in one vectorised pass instead of n*n scalar
            # .iloc reads. Restricted to the leading n-by-n square, which is
            # the region the nested index loops used to scan.
            n_nodes = len(adj_matrix)
            adjacency = adj_matrix.to_numpy()[:n_nodes, :n_nodes]
            for i, j in np.argwhere(adjacency == 1):
                x_coords = [node_matrix.loc[i, "x"], node_matrix.loc[j, "x"]]
                y_coords = [node_matrix.loc[i, "y"], node_matrix.loc[j, "y"]]
                ax.plot(x_coords, y_coords, "gray", alpha=0.6, linewidth=0.5)

            virtual_nodes = node_matrix[node_matrix["is_virtual"] == 1]
            real_nodes = node_matrix[node_matrix["is_virtual"] == 0]
            ax.scatter(
                virtual_nodes["x"],
                virtual_nodes["y"],
                color="lightgray",
                label="Virtual Nodes",
                s=5,
            )
            ax.scatter(
                real_nodes["x"], real_nodes["y"], color="red", label="Real Nodes", s=5
            )
            ax.set_xlabel("x")
            ax.set_ylabel("y")

        fig.tight_layout()
        fig.savefig(f"{save_path}/{title}_comparison.png")
    finally:
        # Release the figure even if plotting or saving fails; pyplot keeps
        # every unclosed figure alive, so repeated calls would leak memory.
        plt.close(fig)
@@ -0,0 +1,475 @@
1
+ import math
2
+ import os
3
+ import matplotlib.pyplot as plt
4
+ from shapely.wkt import loads
5
+ from shapely.geometry import Polygon
6
+ import pandas as pd
7
+ from tqdm import tqdm
8
+ import seaborn as sns
9
+
10
+
11
def plot_raw_cell(dataset, cell_boundary, nuclear_boundary, path):
    """Save one outline image per cell (cell boundary plus nucleus if known).

    Images are written to ``{path}/1_{dataset}_raw_cell_plot/{cell}.png``.

    Args:
        dataset: Dataset name, used in the output directory name.
        cell_boundary: Mapping from cell id to an object with "x" and "y"
            entries describing the cell outline.
        nuclear_boundary: Container of nucleus outlines keyed by cell id, in
            the same "x"/"y" form. Cells missing from it are drawn without a
            nucleus.
        path: Base output directory.

    Returns:
        None.
    """
    save_dir = f"{path}/1_{dataset}_raw_cell_plot"
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_dir, exist_ok=True)

    for cell, boundary in cell_boundary.items():
        fig = plt.figure(figsize=(4, 4))
        try:
            plt.plot(
                boundary["x"],
                boundary["y"],
                label="Cell Boundary",
                color="black",
            )
            if cell in nuclear_boundary:
                plt.plot(
                    nuclear_boundary[cell]["x"],
                    nuclear_boundary[cell]["y"],
                    label="Nucleus Boundary",
                    color="red",
                )
            plt.title(f"Cell: {cell}")
            plt.savefig(os.path.join(save_dir, f"{cell}.png"))
        finally:
            # Always release the figure, including when plotting/saving fails.
            plt.close(fig)

    print(f"All cell images have been saved to {save_dir}")
38
+
39
+
40
def plot_raw_gene_distribution(dataset, cell_boundary, nuclear_boundary, df, path):
    """Save one scatter plot per (cell, gene) in raw coordinates.

    For every cell in ``cell_boundary`` and every gene observed in that cell,
    the cell outline, the nucleus outline (when available) and the gene's
    points are drawn and written as PNG, PDF and SVG to
    ``{path}/{dataset}/raw_gene/cell_{cell}/{gene}.*``.

    Args:
        dataset: Dataset name; selects the title format and the output
            sub-directory.
        cell_boundary: Mapping from cell id to an object with "x"/"y"
            outline coordinates.
        nuclear_boundary: Container of nucleus outlines keyed by cell id.
        df: DataFrame with "cell", "gene", "x" and "y" columns.
        path: Base output directory.

    Returns:
        None.
    """
    simulated_sets = ("simulated1", "simulated2", "simulated3")
    merscope_sets = (
        "merscope_liver_data2",
        "merscope_liver_data3",
        "merscope_liver_data4",
    )
    base_dir = f"{path}/{dataset}/raw_gene"

    for cell in tqdm(list(cell_boundary.keys()), desc="Processing cells"):
        cell_data = df[df["cell"] == cell]
        save_dir = f"{base_dir}/cell_{cell}"
        os.makedirs(save_dir, exist_ok=True)

        for gene in cell_data["gene"].unique():
            fig = plt.figure(figsize=(3, 3))
            try:
                plt.plot(
                    cell_boundary[cell]["x"],
                    cell_boundary[cell]["y"],
                    label="Cell Boundary",
                    color="black",
                )
                if cell in nuclear_boundary:
                    plt.plot(
                        nuclear_boundary[cell]["x"],
                        nuclear_boundary[cell]["y"],
                        label="Nucleus Boundary",
                        color="red",
                    )
                gene_data = cell_data[cell_data["gene"] == gene]
                plt.scatter(
                    gene_data["x"],
                    gene_data["y"],
                    label=f"Gene: {gene}",
                    s=3,
                    alpha=0.5,
                    color="blue",
                )
                if dataset in simulated_sets:
                    plt.title(f"{cell} - {gene}")
                elif dataset in merscope_sets:
                    plt.title(f"Gene: {gene}")
                else:
                    plt.title(f"Cell: {cell} - Gene: {gene}")
                plt.axis("off")
                plt.savefig(
                    f"{save_dir}/{gene}.png", format="png", dpi=300, bbox_inches="tight"
                )
                plt.savefig(f"{save_dir}/{gene}.pdf", format="pdf", bbox_inches="tight")
                plt.savefig(f"{save_dir}/{gene}.svg", format="svg", bbox_inches="tight")
            finally:
                # Release the figure even when plotting or saving fails.
                plt.close(fig)

    # Report the common parent directory: the per-cell ``save_dir`` is only
    # bound inside the loop (unbound for empty input) and would otherwise
    # name just the last cell's folder.
    print(f"All cell images have been saved to {base_dir}")
97
+
98
+
99
def plot_raw_gene_distribution_without_nuclear(
    dataset, cell_boundary, df_registered, path
):
    """Save one scatter plot per (cell, gene) with only the cell outline.

    Output files are written as PNG, PDF and SVG to
    ``{path}/{dataset}/raw_gene/{cell_name}/{gene}.*``.

    Args:
        dataset: Dataset name, used in the output path.
        cell_boundary: Mapping from cell name to a table from which the
            ``[["x", "y"]]`` columns are taken to build the outline polygon.
        df_registered: DataFrame with "cell", "gene", "x" and "y" columns.
        path: Base output directory.

    Returns:
        None.
    """
    for cell_name, cell_data in tqdm(
        cell_boundary.items(), desc="Processing cells", leave=True
    ):
        fig_path = f"{path}/{dataset}/raw_gene/{cell_name}"
        os.makedirs(fig_path, exist_ok=True)
        sub_df_registered = df_registered[df_registered["cell"] == cell_name]
        gene_list = sub_df_registered["gene"].unique()
        if len(gene_list) == 0:
            # Nothing to draw for this cell; the directory is still created.
            continue

        # The outline depends only on the cell, so build it once per cell
        # rather than once per gene.
        outline_x, outline_y = Polygon(cell_data[["x", "y"]]).exterior.xy

        for gene in tqdm(gene_list, desc=f"Plotting for {cell_name}", leave=False):
            gene_data = sub_df_registered[sub_df_registered["gene"] == gene]
            fig = plt.figure(figsize=(3, 3))
            try:
                plt.plot(outline_x, outline_y, linestyle="-", color="black", linewidth=1)
                plt.scatter(gene_data["x"], gene_data["y"], s=3, color="cornflowerblue")
                plt.title(f"{cell_name} - {gene}")
                plt.axis("off")
                plt.tight_layout()
                plt.savefig(
                    f"{fig_path}/{gene}.png", format="png", dpi=300, bbox_inches="tight"
                )
                plt.savefig(f"{fig_path}/{gene}.pdf", format="pdf", bbox_inches="tight")
                plt.savefig(f"{fig_path}/{gene}.svg", format="svg", bbox_inches="tight")
            finally:
                # Release the figure even when plotting or saving fails.
                plt.close(fig)
128
+
129
+
130
def _has_plottable_xy(boundary):
    """Return True if ``boundary`` is a dict/DataFrame with non-empty x and y."""
    if isinstance(boundary, pd.DataFrame):
        return (
            not boundary.empty
            and "x" in boundary.columns
            and "y" in boundary.columns
        )
    if isinstance(boundary, dict):
        return all(
            key in boundary
            and hasattr(boundary[key], "__len__")
            and len(boundary[key]) > 0
            for key in ("x", "y")
        )
    return False


def plot_gene_galleries_from_df(
    dataset_name,
    df_to_plot,
    cell_boundary_dict,
    nuclear_boundary_dict,
    output_base_path,
    plots_per_gallery=48,
    cols_per_gallery=6,
):
    """Create per-gene gallery figures (grid of cells) from a DataFrame.

    For each gene, the cells containing it are split into batches of
    ``plots_per_gallery``. Each batch becomes one PNG saved as
    ``{output_base_path}/{dataset_name}/{gene}/{gene}_gallery_{k}.png``, with
    one subplot per cell showing the cell outline, the nucleus outline (when
    present) and the gene's points.

    Args:
        dataset_name: Dataset name, used as an output sub-directory.
        df_to_plot: DataFrame with "gene", "cell", "x" and "y" columns. If
            any of these is missing, an error is printed and nothing is
            plotted.
        cell_boundary_dict: Mapping from cell id to an object with "x"/"y"
            outline coordinates. Cells missing from it get a
            "Missing cell boundary" label instead of an outline.
        nuclear_boundary_dict: Mapping from cell id to a dict or DataFrame
            with "x"/"y" nucleus coordinates. May be None or omit cells.
        output_base_path: Base output directory; created if absent.
        plots_per_gallery: Maximum number of cells per gallery figure.
        cols_per_gallery: Number of subplot columns per gallery figure.

    Returns:
        None.
    """
    if not os.path.exists(output_base_path):
        os.makedirs(output_base_path, exist_ok=True)
        print(f"Created output directory: {output_base_path}")

    if not all(col in df_to_plot.columns for col in ["gene", "cell", "x", "y"]):
        print("ERROR: df_to_plot must contain columns: gene, cell, x, y")
        return

    # The nucleus outlines are optional; treat a missing mapping as empty so
    # the membership test below cannot fail on None.
    if nuclear_boundary_dict is None:
        nuclear_boundary_dict = {}

    unique_genes = df_to_plot["gene"].unique()
    print(f"Found {len(unique_genes)} genes")

    for gene_name in unique_genes:
        print(f"\nProcessing gene: {gene_name}...")
        gene_specific_df = df_to_plot[df_to_plot["gene"] == gene_name]

        # Cells that contain this gene
        cells_with_this_gene = gene_specific_df["cell"].unique()
        if len(cells_with_this_gene) == 0:
            print(f"No points found for gene {gene_name}; skip")
            continue

        gene_output_folder = os.path.join(output_base_path, dataset_name, gene_name)
        os.makedirs(gene_output_folder, exist_ok=True)

        num_cells_for_this_gene = len(cells_with_this_gene)

        # Build gallery figures for this gene
        for i in range(0, num_cells_for_this_gene, plots_per_gallery):
            batch_cell_ids = cells_with_this_gene[i : i + plots_per_gallery]
            current_batch_size = len(batch_cell_ids)
            rows_this_gallery = math.ceil(current_batch_size / cols_per_gallery)

            # squeeze=False always yields a 2-D axes array, so no special
            # cases are needed for single-row or single-column grids.
            fig, axes = plt.subplots(
                rows_this_gallery,
                cols_per_gallery,
                figsize=(
                    cols_per_gallery * 3.5,
                    rows_this_gallery * 3.5,
                ),
                squeeze=False,
            )
            try:
                for plot_idx, cell_id in enumerate(batch_cell_ids):
                    ax_row, ax_col = divmod(plot_idx, cols_per_gallery)
                    ax = axes[ax_row, ax_col]

                    # 1) Cell boundary
                    if cell_id in cell_boundary_dict:
                        cb = cell_boundary_dict[cell_id]
                        ax.plot(cb["x"], cb["y"], color="black", linewidth=0.8)
                    else:
                        ax.text(
                            0.5,
                            0.5,
                            "Missing cell boundary",
                            ha="center",
                            va="center",
                            fontsize=8,
                            color="red",
                        )

                    # 2) Nuclear boundary (optional; dict or DataFrame)
                    if cell_id in nuclear_boundary_dict:
                        nb_data = nuclear_boundary_dict[cell_id]
                        if _has_plottable_xy(nb_data):
                            ax.plot(
                                nb_data["x"],
                                nb_data["y"],
                                color="dimgray",
                                linestyle="--",
                                linewidth=0.7,
                            )

                    # 3) Points
                    points_in_cell_gene = gene_specific_df[
                        gene_specific_df["cell"] == cell_id
                    ]
                    ax.scatter(
                        points_in_cell_gene["x"],
                        points_in_cell_gene["y"],
                        s=5,
                        alpha=0.7,
                        color="blue",
                    )

                    ax.set_title(f"Cell: {cell_id}", fontsize=7)
                    ax.axis("off")
                    ax.set_aspect("equal", adjustable="box")

                # Remove unused subplots
                for k in range(
                    current_batch_size, rows_this_gallery * cols_per_gallery
                ):
                    ax_row, ax_col = divmod(k, cols_per_gallery)
                    fig.delaxes(axes[ax_row, ax_col])

                fig.tight_layout(pad=0.5)
                gallery_num = (i // plots_per_gallery) + 1
                save_path = os.path.join(
                    gene_output_folder, f"{gene_name}_gallery_{gallery_num}.png"
                )

                # Best effort: a failed save is reported and the remaining
                # galleries are still produced.
                try:
                    fig.savefig(save_path, dpi=200)
                    print(f"Saved: {save_path}")
                except Exception as e:
                    print(f"Failed to save {save_path}: {e}")
            finally:
                # Release the figure on every path, including plot errors.
                plt.close(fig)

    print("\nDone")
283
+
284
+
285
+ # Registered gene scatter per cell (with nuclear boundary)
286
def plot_register_gene_distribution(
    dataset, df_registered, path, nuclear_boundary_df_registered
):
    """Save one registered-coordinate scatter plot per (cell, gene).

    Each figure shows a unit circle centred at the origin as the cell
    boundary, the gene's points at ("x_c_s", "y_c_s"), and the nucleus
    outline when one is available for the cell. Files are written as PNG,
    PDF and SVG to
    ``{path}/{dataset}/registered_gene/{cell}/{cell}_{gene}.*``.

    Args:
        dataset: Dataset name; selects the title format and the output
            sub-directory.
        df_registered: DataFrame with "cell", "gene", "x_c_s" and "y_c_s"
            columns.
        path: Base output directory.
        nuclear_boundary_df_registered: DataFrame with "cell", "x_c_s" and
            "y_c_s" columns holding nucleus outlines. If it is not a
            DataFrame, lacks one of those columns, or has no rows for a
            cell, that cell is drawn without a nucleus.

    Returns:
        None.
    """
    simulated_sets = ("simulated1", "simulated2", "simulated3")

    # Decide once whether nucleus outlines are usable at all, instead of
    # wrapping every per-cell lookup in a catch-all ``except``.
    nuclear_table = nuclear_boundary_df_registered
    has_nuclear_table = isinstance(nuclear_table, pd.DataFrame) and all(
        col in nuclear_table.columns for col in ("cell", "x_c_s", "y_c_s")
    )

    for cell in tqdm(df_registered["cell"].unique(), desc="Plotting per cell"):
        save_dir = f"{path}/{dataset}/registered_gene/{cell}"
        os.makedirs(save_dir, exist_ok=True)
        cell_gene_data = df_registered[df_registered["cell"] == cell]
        if cell_gene_data.empty:
            continue

        # Nuclear boundary data (optional). Some datasets may not have
        # nucleus boundaries for every cell.
        nuclear_xy = None
        if has_nuclear_table:
            candidate = nuclear_table[nuclear_table["cell"] == cell]
            if not candidate.empty:
                nuclear_xy = (
                    candidate["x_c_s"].to_numpy(),
                    candidate["y_c_s"].to_numpy(),
                )

        genes = cell_gene_data["gene"].unique()
        for gene in tqdm(genes, desc=f"Plotting for cell {cell}", leave=False):
            fig = plt.figure(figsize=(4, 4))
            try:
                ax = plt.gca()
                radius = 1
                circle = plt.Circle(
                    (0, 0),
                    radius,
                    color="gray",
                    fill=False,
                    label="Cell Boundary",
                    linewidth=1,
                )
                ax.add_patch(circle)
                gene_data = cell_gene_data[cell_gene_data["gene"] == gene]
                plt.scatter(
                    gene_data["x_c_s"],
                    gene_data["y_c_s"],
                    label=f"Gene: {gene}",
                    s=2,
                    color="cornflowerblue",
                )
                # Remove spines
                for side in ("top", "right", "bottom", "left"):
                    ax.spines[side].set_visible(False)

                if nuclear_xy is not None:
                    ax.plot(
                        nuclear_xy[0], nuclear_xy[1], color="darkgray", linewidth=1
                    )
                # Hide axes
                plt.axis("off")
                if dataset in simulated_sets:
                    plt.title(f"{cell} - {gene}")
                else:
                    plt.title(f"Cell: {cell} - Gene: {gene}")
                plt.savefig(
                    f"{save_dir}/{cell}_{gene}.png",
                    format="png",
                    dpi=300,
                    bbox_inches="tight",
                )
                plt.savefig(
                    f"{save_dir}/{cell}_{gene}.pdf", format="pdf", bbox_inches="tight"
                )
                plt.savefig(
                    f"{save_dir}/{cell}_{gene}.svg", format="svg", bbox_inches="tight"
                )
            finally:
                # Release the figure even when plotting or saving fails.
                plt.close(fig)
379
+
380
+
381
+ # Registered gene scatter per cell (without nuclear boundary)
382
def plot_register_gene_distribution_without_nuclear(
    dataset, df_registered, cell_radii, path
):
    """Save one registered-coordinate scatter plot per (cell, gene).

    Each figure contains a unit circle at the origin as the cell boundary and
    the gene's points at ("x_c_s", "y_c_s"); no nucleus is drawn. Files are
    written as PNG, PDF and SVG to
    ``{path}/{dataset}/registered_gene/{cell}/{cell}_{gene}.*``.

    Args:
        dataset: Dataset name, used in the output path.
        df_registered: DataFrame with "cell", "gene", "x_c_s" and "y_c_s"
            columns.
        cell_radii: Not used by this function; kept for the call signature.
        path: Base output directory.

    Returns:
        None.
    """
    spine_sides = ("top", "right", "bottom", "left")

    for cell in tqdm(df_registered["cell"].unique(), desc="Plotting per cell"):
        save_dir = f"{path}/{dataset}/registered_gene/{cell}"
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        rows_for_cell = df_registered[df_registered["cell"] == cell]
        if rows_for_cell.empty:
            continue

        gene_names = rows_for_cell["gene"].unique()
        for gene in tqdm(gene_names, desc=f"Plotting for cell {cell}", leave=False):
            fig = plt.figure(figsize=(4, 4))
            ax = fig.gca()

            # Unit circle standing in for the registered cell boundary.
            outline = plt.Circle(
                (0, 0),
                1,
                color="gray",
                fill=False,
                label="Cell Boundary",
                linewidth=1,
            )
            ax.add_patch(outline)

            rows_for_gene = rows_for_cell[rows_for_cell["gene"] == gene]
            ax.scatter(
                rows_for_gene["x_c_s"],
                rows_for_gene["y_c_s"],
                label=f"Gene: {gene}",
                s=2,
                color="cornflowerblue",
            )

            # Strip all frame decoration.
            for side in spine_sides:
                ax.spines[side].set_visible(False)
            ax.axis("off")
            ax.set_title(f"Cell: {cell} - Gene: {gene}")

            stem = f"{save_dir}/{cell}_{gene}"
            fig.savefig(f"{stem}.png", format="png", dpi=300, bbox_inches="tight")
            fig.savefig(f"{stem}.pdf", format="pdf", bbox_inches="tight")
            fig.savefig(f"{stem}.svg", format="svg", bbox_inches="tight")
            plt.close(fig)
437
+
438
+
439
def plot_each_batch(dataset, adata, batch, path):
    """Plot all cell and nucleus outlines of one batch and save the figure.

    Cells are selected with ``adata.obs["batch"] == batch``. Each cell's
    "cell_shape" WKT string is drawn in grey with the cell's obs index
    printed at the polygon centroid; each "nucleus_shape" WKT string is drawn
    in dark grey. The figure is written as PNG, PDF and SVG to
    ``{path}/{dataset}/each_batch/batch{batch}_plot.*``.

    Args:
        dataset: Dataset name, used in the output path.
        adata: Annotated data object whose ``obs`` table has "batch",
            "cell_shape" and "nucleus_shape" columns (presumably an AnnData;
            only ``obs`` and boolean-mask indexing are used here).
        batch: Batch value to select.
        path: Base output directory.

    Returns:
        None.
    """
    save_dir = f"{path}/{dataset}/each_batch"
    os.makedirs(save_dir, exist_ok=True)
    adata_sub = adata[adata.obs["batch"] == batch]
    # NOTE: the previous version also built and filtered a DataFrame from
    # ``adata_sub.uns["points"]`` but never used the result; that dead work
    # has been removed.

    fig = plt.figure(figsize=(6, 4))
    try:
        for cell_id, shape_wkt in adata_sub.obs["cell_shape"].items():
            polygon = loads(shape_wkt)
            x, y = polygon.exterior.xy
            plt.plot(x, y, linestyle="-", color="grey", linewidth=1)
            centroid = polygon.centroid
            plt.text(
                centroid.x,
                centroid.y,
                str(cell_id),
                fontsize=10,
                ha="center",
                color="darkblue",
            )

        for _, shape_wkt in adata_sub.obs["nucleus_shape"].items():
            # A null entry cannot be parsed as WKT; skip cells that have no
            # nucleus outline instead of aborting the whole plot.
            if pd.isna(shape_wkt):
                continue
            polygon = loads(shape_wkt)
            x, y = polygon.exterior.xy
            plt.plot(x, y, linestyle="-", color="darkgray", linewidth=1)

        plt.title(f"Batch {batch}")
        plt.axis("off")
        plt.savefig(
            f"{save_dir}/batch{batch}_plot.png",
            format="png",
            dpi=300,
            bbox_inches="tight",
        )
        plt.savefig(
            f"{save_dir}/batch{batch}_plot.pdf", format="pdf", bbox_inches="tight"
        )
        plt.savefig(
            f"{save_dir}/batch{batch}_plot.svg", format="svg", bbox_inches="tight"
        )
    finally:
        # The figure was previously never closed; release it on every path.
        plt.close(fig)