grasp-tool 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_tool/__init__.py +17 -0
- grasp_tool/__main__.py +6 -0
- grasp_tool/cli/__init__.py +1 -0
- grasp_tool/cli/main.py +793 -0
- grasp_tool/cli/train_moco.py +778 -0
- grasp_tool/gnn/__init__.py +1 -0
- grasp_tool/gnn/embedding.py +165 -0
- grasp_tool/gnn/gat_moco_final.py +990 -0
- grasp_tool/gnn/graphloader.py +1748 -0
- grasp_tool/gnn/plot_refined.py +1556 -0
- grasp_tool/preprocessing/__init__.py +1 -0
- grasp_tool/preprocessing/augumentation.py +66 -0
- grasp_tool/preprocessing/cellplot.py +475 -0
- grasp_tool/preprocessing/filter.py +171 -0
- grasp_tool/preprocessing/network.py +79 -0
- grasp_tool/preprocessing/partition.py +654 -0
- grasp_tool/preprocessing/portrait.py +1862 -0
- grasp_tool/preprocessing/register.py +1021 -0
- grasp_tool-0.1.0.dist-info/METADATA +511 -0
- grasp_tool-0.1.0.dist-info/RECORD +22 -0
- grasp_tool-0.1.0.dist-info/WHEEL +4 -0
- grasp_tool-0.1.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Preprocessing utilities (register/partition/portrait/etc.)."""
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def rotate_nodes(node_matrix, angle):
    """Rotate node coordinates about the origin by ``angle`` degrees.

    The rotation applied is ``x' = x*cos - y*sin`` and ``y' = x*sin + y*cos``
    (counter-clockwise for a positive angle in a standard x/y frame).

    Args:
        node_matrix: Table with ``"x"`` and ``"y"`` entries (the rest of this
            module treats it as a pandas DataFrame). It is modified in place.
        angle: Rotation angle in degrees.

    Returns:
        The same ``node_matrix`` object, with rotated ``"x"``/``"y"`` values.
    """
    theta = np.radians(angle)
    cos_theta, sin_theta = np.cos(theta), np.sin(theta)
    # Snapshot both coordinates before overwriting anything: "y" must be
    # computed from the ORIGINAL "x". Holding a live reference to the column
    # would make the result depend on whether the container writes the new
    # "x" in place.
    x = np.array(node_matrix["x"])
    y = np.array(node_matrix["y"])
    node_matrix["x"] = x * cos_theta - y * sin_theta
    node_matrix["y"] = x * sin_theta + y * cos_theta
    return node_matrix
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def dropout_nodes(adj_matrix, node_matrix, dropout_ratio):
    """Randomly turn a fraction of real nodes into virtual ones.

    A node counts as "real" when its ``"is_virtual"`` value is 0. The chosen
    nodes get ``is_virtual = 1`` and their rows and columns in the adjacency
    matrix are set to 0. Both inputs are modified in place.

    Args:
        adj_matrix: Square pandas DataFrame whose row/column order matches the
            row order of ``node_matrix``.
        node_matrix: pandas DataFrame with an ``"is_virtual"`` column.
        dropout_ratio: Fraction of real nodes to drop;
            ``int(num_real * dropout_ratio)`` nodes are selected.

    Returns:
        Tuple ``(adj_matrix, node_matrix)`` (the same, mutated, objects).

    Raises:
        ValueError: Propagated from ``np.random.choice`` when more nodes are
            requested than there are real nodes, or the count is negative.
    """
    # Work with row POSITIONS throughout. Selecting by index label and then
    # addressing the adjacency matrix positionally only agrees for a default
    # RangeIndex; with any other index it hits the wrong rows or raises.
    is_virtual_col = node_matrix.columns.get_loc("is_virtual")
    real_positions = np.flatnonzero((node_matrix["is_virtual"] == 0).to_numpy())
    num_drop = int(len(real_positions) * dropout_ratio)
    # Global NumPy RNG, as before, so seeded runs keep selecting the same rows.
    drop_positions = np.random.choice(real_positions, size=num_drop, replace=False)
    if num_drop:
        node_matrix.iloc[drop_positions, is_virtual_col] = 1
        adj_matrix.iloc[drop_positions, :] = 0
        adj_matrix.iloc[:, drop_positions] = 0
    return adj_matrix, node_matrix
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def plot_graph(
    adj_matrix_before,
    node_matrix_before,
    adj_matrix_after,
    node_matrix_after,
    title,
    save_path,
):
    """Save a side-by-side picture of a graph before and after dropout.

    Each panel draws a line for every adjacency entry equal to 1, then
    scatters virtual nodes (``is_virtual == 1``) in light gray and real nodes
    (``is_virtual == 0``) in red. The figure is written to
    ``{save_path}/{title}_comparison.png`` and then closed.

    Args:
        adj_matrix_before: Adjacency DataFrame of the original graph.
        node_matrix_before: Node DataFrame (``"x"``, ``"y"``, ``"is_virtual"``)
            of the original graph, rows in the same order as the adjacency.
        adj_matrix_after: Adjacency DataFrame after dropout.
        node_matrix_after: Node DataFrame after dropout.
        title: Title prefix for both panels; also part of the file name.
        save_path: Directory the PNG is written into.
    """
    fig, axes = plt.subplots(1, 2, figsize=(6, 3), sharex=True, sharey=True)
    panels = (
        (axes[0], adj_matrix_before, node_matrix_before, f"{title} (Original)"),
        (axes[1], adj_matrix_after, node_matrix_after, f"{title} (After Dropout)"),
    )
    try:
        for ax, adj_matrix, node_matrix, sub_title in panels:
            ax.set_title(sub_title)

            # Positional coordinate arrays: adjacency entries are addressed by
            # position, so node coordinates must be too (a label lookup only
            # matches for a default RangeIndex).
            xs = node_matrix["x"].to_numpy()
            ys = node_matrix["y"].to_numpy()
            n_nodes = len(adj_matrix)
            adjacency = adj_matrix.to_numpy()[:n_nodes, :n_nodes]
            # Let NumPy find the edges instead of an O(n^2) Python scan. Both
            # (i, j) and (j, i) are still drawn, keeping the rendered line
            # opacity of symmetric graphs unchanged.
            for i, j in np.argwhere(adjacency == 1):
                ax.plot(
                    [xs[i], xs[j]],
                    [ys[i], ys[j]],
                    "gray",
                    alpha=0.6,
                    linewidth=0.5,
                )

            virtual_nodes = node_matrix[node_matrix["is_virtual"] == 1]
            real_nodes = node_matrix[node_matrix["is_virtual"] == 0]
            ax.scatter(
                virtual_nodes["x"],
                virtual_nodes["y"],
                color="lightgray",
                label="Virtual Nodes",
                s=5,
            )
            ax.scatter(
                real_nodes["x"], real_nodes["y"], color="red", label="Real Nodes", s=5
            )
            ax.set_xlabel("x")
            ax.set_ylabel("y")

        fig.tight_layout()
        fig.savefig(f"{save_path}/{title}_comparison.png")
    finally:
        # Release the figure even if drawing/saving fails; otherwise every
        # call leaves one more open figure in pyplot's registry.
        plt.close(fig)
|
|
@@ -0,0 +1,475 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import os
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
from shapely.wkt import loads
|
|
5
|
+
from shapely.geometry import Polygon
|
|
6
|
+
import pandas as pd
|
|
7
|
+
from tqdm import tqdm
|
|
8
|
+
import seaborn as sns
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def plot_raw_cell(dataset, cell_boundary, nuclear_boundary, path):
    """Save one PNG per cell showing its boundary and, if present, its nucleus.

    Images are written to ``{path}/1_{dataset}_raw_cell_plot/{cell}.png``.

    Args:
        dataset: Dataset name, used in the output directory name.
        cell_boundary: Mapping of cell id to an object indexable by ``"x"``
            and ``"y"`` (boundary coordinates).
        nuclear_boundary: Mapping of cell id to nucleus boundary coordinates,
            indexable the same way; cells missing from it are drawn without a
            nucleus.
        path: Base output directory.
    """
    save_dir = f"{path}/1_{dataset}_raw_cell_plot"
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_dir, exist_ok=True)

    for cell, boundary in cell_boundary.items():
        fig, ax = plt.subplots(figsize=(4, 4))
        try:
            ax.plot(
                boundary["x"],
                boundary["y"],
                label="Cell Boundary",
                color="black",
            )
            if cell in nuclear_boundary:
                nucleus = nuclear_boundary[cell]
                ax.plot(
                    nucleus["x"],
                    nucleus["y"],
                    label="Nucleus Boundary",
                    color="red",
                )
            ax.set_title(f"Cell: {cell}")
            fig.savefig(os.path.join(save_dir, f"{cell}.png"))
        finally:
            # Always release the figure, even when saving fails.
            plt.close(fig)

    print(f"All cell images have been saved to {save_dir}")
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def plot_raw_gene_distribution(dataset, cell_boundary, nuclear_boundary, df, path):
    """Save a scatter plot of every gene's points inside every cell.

    For each cell in ``cell_boundary`` and each gene found for that cell in
    ``df``, the cell boundary (black), the nucleus boundary if available
    (red) and the gene's points (blue) are drawn and saved as PNG (300 dpi),
    PDF and SVG under ``{path}/{dataset}/raw_gene/cell_{cell}/``.

    Args:
        dataset: Dataset name; selects the title format and the output folder.
        cell_boundary: Mapping of cell id to boundary coordinates indexable by
            ``"x"`` and ``"y"``.
        nuclear_boundary: Mapping of cell id to nucleus boundary coordinates;
            cells missing from it are drawn without a nucleus.
        df: DataFrame with ``"cell"``, ``"gene"``, ``"x"`` and ``"y"`` columns.
        path: Base output directory.
    """
    simulated_datasets = ("simulated1", "simulated2", "simulated3")
    merscope_datasets = (
        "merscope_liver_data2",
        "merscope_liver_data3",
        "merscope_liver_data4",
    )
    output_root = f"{path}/{dataset}/raw_gene"

    for cell, boundary in tqdm(
        cell_boundary.items(), total=len(cell_boundary), desc="Processing cells"
    ):
        cell_data = df[df["cell"] == cell]
        save_dir = f"{output_root}/cell_{cell}"
        os.makedirs(save_dir, exist_ok=True)
        nucleus = nuclear_boundary[cell] if cell in nuclear_boundary else None

        for gene in cell_data["gene"].unique():
            gene_data = cell_data[cell_data["gene"] == gene]

            if dataset in simulated_datasets:
                plot_title = f"{cell} - {gene}"
            elif dataset in merscope_datasets:
                plot_title = f"Gene: {gene}"
            else:
                plot_title = f"Cell: {cell} - Gene: {gene}"

            fig, ax = plt.subplots(figsize=(3, 3))
            try:
                ax.plot(
                    boundary["x"],
                    boundary["y"],
                    label="Cell Boundary",
                    color="black",
                )
                if nucleus is not None:
                    ax.plot(
                        nucleus["x"],
                        nucleus["y"],
                        label="Nucleus Boundary",
                        color="red",
                    )
                ax.scatter(
                    gene_data["x"],
                    gene_data["y"],
                    label=f"Gene: {gene}",
                    s=3,
                    alpha=0.5,
                    color="blue",
                )
                ax.set_title(plot_title)
                ax.axis("off")
                fig.savefig(
                    f"{save_dir}/{gene}.png", format="png", dpi=300, bbox_inches="tight"
                )
                fig.savefig(f"{save_dir}/{gene}.pdf", format="pdf", bbox_inches="tight")
                fig.savefig(f"{save_dir}/{gene}.svg", format="svg", bbox_inches="tight")
            finally:
                # Release the figure even if one of the saves fails.
                plt.close(fig)

    # Report the common parent folder: the per-cell directory variable is
    # unbound when there are no cells and otherwise names only the last cell.
    print(f"All cell images have been saved to {output_root}")
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def plot_raw_gene_distribution_without_nuclear(
    dataset, cell_boundary, df_registered, path
):
    """Save a scatter plot of every gene's points inside every cell outline.

    For each cell in ``cell_boundary`` and each gene found for that cell in
    ``df_registered``, the cell outline and the gene's ``x``/``y`` points are
    drawn and saved as PNG (300 dpi), PDF and SVG under
    ``{path}/{dataset}/raw_gene/{cell_name}/``.

    Args:
        dataset: Dataset name, used in the output path.
        cell_boundary: Mapping of cell name to a DataFrame with ``"x"`` and
            ``"y"`` columns describing the cell outline.
        df_registered: DataFrame with ``"cell"``, ``"gene"``, ``"x"`` and
            ``"y"`` columns.
        path: Base output directory.
    """
    for cell_name, cell_data in tqdm(
        cell_boundary.items(), desc="Processing cells", leave=True
    ):
        fig_path = f"{path}/{dataset}/raw_gene/{cell_name}"
        os.makedirs(fig_path, exist_ok=True)

        sub_df_registered = df_registered[df_registered["cell"] == cell_name]
        gene_list = sub_df_registered["gene"].unique()
        if len(gene_list) == 0:
            # Nothing to draw; also skips building a polygon that no plot uses.
            continue

        # The outline is the same for every gene of this cell: build it once.
        # Passing a plain coordinate array keeps the Polygon input explicit.
        cell_polygon = Polygon(cell_data[["x", "y"]].to_numpy())
        outline_x, outline_y = cell_polygon.exterior.xy

        for gene in tqdm(gene_list, desc=f"Plotting for {cell_name}", leave=False):
            gene_data = sub_df_registered[sub_df_registered["gene"] == gene]
            fig, ax = plt.subplots(figsize=(3, 3))
            try:
                ax.plot(outline_x, outline_y, linestyle="-", color="black", linewidth=1)
                ax.scatter(gene_data["x"], gene_data["y"], s=3, color="cornflowerblue")
                ax.set_title(f"{cell_name} - {gene}")
                ax.axis("off")
                fig.tight_layout()
                fig.savefig(
                    f"{fig_path}/{gene}.png", format="png", dpi=300, bbox_inches="tight"
                )
                fig.savefig(f"{fig_path}/{gene}.pdf", format="pdf", bbox_inches="tight")
                fig.savefig(f"{fig_path}/{gene}.svg", format="svg", bbox_inches="tight")
            finally:
                # Release the figure even if one of the saves fails.
                plt.close(fig)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def _nuclear_boundary_xy(nb_data):
    """Return ``(x, y)`` from a dict or DataFrame boundary, or None if unusable."""
    if isinstance(nb_data, pd.DataFrame):
        if nb_data.empty or "x" not in nb_data.columns or "y" not in nb_data.columns:
            return None
        return nb_data["x"], nb_data["y"]
    if isinstance(nb_data, dict):
        xs, ys = nb_data.get("x"), nb_data.get("y")
        for values in (xs, ys):
            if not hasattr(values, "__len__") or len(values) == 0:
                return None
        return xs, ys
    return None


def plot_gene_galleries_from_df(
    dataset_name,
    df_to_plot,
    cell_boundary_dict,
    nuclear_boundary_dict,
    output_base_path,
    plots_per_gallery=48,
    cols_per_gallery=6,
):
    """Create per-gene gallery figures (grid of cells) from a DataFrame.

    For every gene in ``df_to_plot`` the cells containing it are split into
    batches of ``plots_per_gallery``. Each batch becomes one PNG (200 dpi)
    named ``{gene}_gallery_{n}.png`` under
    ``{output_base_path}/{dataset_name}/{gene}/``, with one subplot per cell
    showing the cell boundary, the nuclear boundary when available, and the
    gene's points.

    Args:
        dataset_name: Dataset name, used as a sub-folder of the output path.
        df_to_plot: DataFrame with ``"gene"``, ``"cell"``, ``"x"`` and ``"y"``
            columns. If any is missing an error is printed and the function
            returns without plotting.
        cell_boundary_dict: Mapping of cell id to boundary coordinates
            indexable by ``"x"`` and ``"y"``.
        nuclear_boundary_dict: Mapping of cell id to nuclear boundary data,
            either a dict or a DataFrame with ``"x"``/``"y"``.
        output_base_path: Base output directory (created if missing).
        plots_per_gallery: Maximum number of cells per gallery figure.
        cols_per_gallery: Number of subplot columns per gallery figure.

    Returns:
        None.
    """
    if not os.path.exists(output_base_path):
        os.makedirs(output_base_path, exist_ok=True)
        print(f"Created output directory: {output_base_path}")

    if not all(col in df_to_plot.columns for col in ("gene", "cell", "x", "y")):
        print("ERROR: df_to_plot must contain columns: gene, cell, x, y")
        return

    unique_genes = df_to_plot["gene"].unique()
    print(f"Found {len(unique_genes)} genes")

    for gene_name in unique_genes:
        print(f"\nProcessing gene: {gene_name}...")
        gene_specific_df = df_to_plot[df_to_plot["gene"] == gene_name]

        # Cells that contain this gene
        cells_with_this_gene = gene_specific_df["cell"].unique()
        if len(cells_with_this_gene) == 0:
            print(f"No points found for gene {gene_name}; skip")
            continue

        gene_output_folder = os.path.join(output_base_path, dataset_name, gene_name)
        os.makedirs(gene_output_folder, exist_ok=True)

        # Build gallery figures for this gene
        for start in range(0, len(cells_with_this_gene), plots_per_gallery):
            batch_cell_ids = cells_with_this_gene[start : start + plots_per_gallery]
            current_batch_size = len(batch_cell_ids)
            rows_this_gallery = math.ceil(current_batch_size / cols_per_gallery)

            # squeeze=False always yields a 2D array of axes, so no special
            # cases are needed for single-row / single-column grids.
            fig, axes = plt.subplots(
                rows_this_gallery,
                cols_per_gallery,
                figsize=(cols_per_gallery * 3.5, rows_this_gallery * 3.5),
                squeeze=False,
            )
            flat_axes = axes.ravel()  # row-major: index k -> (k // cols, k % cols)

            try:
                for ax, cell_id in zip(flat_axes, batch_cell_ids):
                    # 1) Cell boundary
                    if cell_id in cell_boundary_dict:
                        cb = cell_boundary_dict[cell_id]
                        ax.plot(cb["x"], cb["y"], color="black", linewidth=0.8)
                    else:
                        # Axes-fraction coordinates keep the notice centred in
                        # the panel regardless of the data range.
                        ax.text(
                            0.5,
                            0.5,
                            "Missing cell boundary",
                            ha="center",
                            va="center",
                            fontsize=8,
                            color="red",
                            transform=ax.transAxes,
                        )

                    # 2) Nuclear boundary (optional)
                    if cell_id in nuclear_boundary_dict:
                        nuclear_xy = _nuclear_boundary_xy(nuclear_boundary_dict[cell_id])
                        if nuclear_xy is not None:
                            ax.plot(
                                nuclear_xy[0],
                                nuclear_xy[1],
                                color="dimgray",
                                linestyle="--",
                                linewidth=0.7,
                            )

                    # 3) Points
                    points_in_cell_gene = gene_specific_df[
                        gene_specific_df["cell"] == cell_id
                    ]
                    ax.scatter(
                        points_in_cell_gene["x"],
                        points_in_cell_gene["y"],
                        s=5,
                        alpha=0.7,
                        color="blue",
                    )

                    ax.set_title(f"Cell: {cell_id}", fontsize=7)
                    ax.axis("off")
                    ax.set_aspect("equal", adjustable="box")

                # Remove unused subplots
                for unused_ax in flat_axes[current_batch_size:]:
                    fig.delaxes(unused_ax)

                fig.tight_layout(pad=0.5)
                gallery_num = (start // plots_per_gallery) + 1
                save_path = os.path.join(
                    gene_output_folder, f"{gene_name}_gallery_{gallery_num}.png"
                )

                # Saving is best-effort: report the failure and continue with
                # the next gallery.
                try:
                    fig.savefig(save_path, dpi=200)
                    print(f"Saved: {save_path}")
                except Exception as e:
                    print(f"Failed to save {save_path}: {e}")
            finally:
                # Release the figure even if plotting itself raised.
                plt.close(fig)

    print("\nDone")
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
# Registered gene scatter per cell (with nuclear boundary)
|
|
286
|
+
def plot_register_gene_distribution(
    dataset, df_registered, path, nuclear_boundary_df_registered
):
    """Save registered per-(cell, gene) scatter plots with the nucleus outline.

    Each plot shows a unit circle at the origin as the cell boundary, the
    gene's registered points (``x_c_s``/``y_c_s``) and, when available for
    the cell, the registered nuclear boundary. Files are written as PNG
    (300 dpi), PDF and SVG to
    ``{path}/{dataset}/registered_gene/{cell}/{cell}_{gene}.*``.

    Args:
        dataset: Dataset name; selects the title format and the output folder.
        df_registered: DataFrame with ``"cell"``, ``"gene"``, ``"x_c_s"`` and
            ``"y_c_s"`` columns.
        path: Base output directory.
        nuclear_boundary_df_registered: Optional DataFrame with ``"cell"``,
            ``"x_c_s"`` and ``"y_c_s"`` columns. ``None``, or a table missing
            any of these columns, means no nuclear boundary is drawn.
    """
    simulated_datasets = ("simulated1", "simulated2", "simulated3")

    # Decide once whether nuclear boundaries can be looked up at all, rather
    # than hiding every possible lookup error behind a blanket except.
    nuclear_columns = getattr(nuclear_boundary_df_registered, "columns", ())
    has_nuclear = all(
        col in nuclear_columns for col in ("cell", "x_c_s", "y_c_s")
    )

    for cell in tqdm(df_registered["cell"].unique(), desc="Plotting per cell"):
        save_dir = f"{path}/{dataset}/registered_gene/{cell}"
        os.makedirs(save_dir, exist_ok=True)

        cell_gene_data = df_registered[df_registered["cell"] == cell]
        if cell_gene_data.empty:
            continue

        # Nuclear boundary (optional). Some datasets may not have nucleus
        # boundaries for every cell.
        boundary_xy = None
        if has_nuclear:
            candidate = nuclear_boundary_df_registered[
                nuclear_boundary_df_registered["cell"] == cell
            ]
            if not candidate.empty:
                boundary_xy = (
                    candidate["x_c_s"].to_numpy(),
                    candidate["y_c_s"].to_numpy(),
                )

        genes = cell_gene_data["gene"].unique()
        for gene in tqdm(genes, desc=f"Plotting for cell {cell}", leave=False):
            gene_data = cell_gene_data[cell_gene_data["gene"] == gene]
            if dataset in simulated_datasets:
                plot_title = f"{cell} - {gene}"
            else:
                plot_title = f"Cell: {cell} - Gene: {gene}"

            fig, ax = plt.subplots(figsize=(4, 4))
            try:
                # Unit circle at the origin as the cell boundary.
                ax.add_patch(
                    plt.Circle(
                        (0, 0),
                        1,
                        color="gray",
                        fill=False,
                        label="Cell Boundary",
                        linewidth=1,
                    )
                )
                ax.scatter(
                    gene_data["x_c_s"],
                    gene_data["y_c_s"],
                    label=f"Gene: {gene}",
                    s=2,
                    color="cornflowerblue",
                )
                # Remove spines
                for side in ("top", "right", "bottom", "left"):
                    ax.spines[side].set_visible(False)

                if boundary_xy is not None:
                    ax.plot(
                        boundary_xy[0], boundary_xy[1], color="darkgray", linewidth=1
                    )

                # Hide axes
                ax.axis("off")
                ax.set_title(plot_title)

                fig.savefig(
                    f"{save_dir}/{cell}_{gene}.png",
                    format="png",
                    dpi=300,
                    bbox_inches="tight",
                )
                fig.savefig(
                    f"{save_dir}/{cell}_{gene}.pdf", format="pdf", bbox_inches="tight"
                )
                fig.savefig(
                    f"{save_dir}/{cell}_{gene}.svg", format="svg", bbox_inches="tight"
                )
            finally:
                # Release the figure even if one of the saves fails.
                plt.close(fig)
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
# Registered gene scatter per cell (without nuclear boundary)
|
|
382
|
+
def plot_register_gene_distribution_without_nuclear(
    dataset, df_registered, cell_radii, path
):
    """Save registered per-(cell, gene) scatter plots inside a unit circle.

    For every cell in ``df_registered`` and every gene of that cell, the
    registered points (``x_c_s``/``y_c_s``) are drawn inside a radius-1 circle
    centred at the origin and saved as PNG (300 dpi), PDF and SVG to
    ``{path}/{dataset}/registered_gene/{cell}/{cell}_{gene}.*``.

    Args:
        dataset: Dataset name, used in the output path.
        df_registered: DataFrame with ``"cell"``, ``"gene"``, ``"x_c_s"`` and
            ``"y_c_s"`` columns.
        cell_radii: Accepted for interface compatibility; not read here.
        path: Base output directory.
    """
    # (format, extra savefig keyword arguments) for each output file.
    outputs = (("png", {"dpi": 300}), ("pdf", {}), ("svg", {}))

    for cell in tqdm(df_registered["cell"].unique(), desc="Plotting per cell"):
        save_dir = f"{path}/{dataset}/registered_gene/{cell}"
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        rows_for_cell = df_registered[df_registered["cell"] == cell]
        if rows_for_cell.empty:
            continue

        gene_names = rows_for_cell["gene"].unique()
        for gene in tqdm(gene_names, desc=f"Plotting for cell {cell}", leave=False):
            figure = plt.figure(figsize=(4, 4))
            axis = figure.gca()

            # Cell boundary: unit circle at the origin.
            boundary = plt.Circle(
                (0, 0),
                1,
                color="gray",
                fill=False,
                label="Cell Boundary",
                linewidth=1,
            )
            axis.add_patch(boundary)

            rows_for_gene = rows_for_cell[rows_for_cell["gene"] == gene]
            axis.scatter(
                rows_for_gene["x_c_s"],
                rows_for_gene["y_c_s"],
                label=f"Gene: {gene}",
                s=2,
                color="cornflowerblue",
            )

            # Hide the frame and the axes.
            for side in ("top", "right", "bottom", "left"):
                axis.spines[side].set_visible(False)
            axis.axis("off")
            axis.set_title(f"Cell: {cell} - Gene: {gene}")

            for fmt, extra in outputs:
                figure.savefig(
                    f"{save_dir}/{cell}_{gene}.{fmt}",
                    format=fmt,
                    bbox_inches="tight",
                    **extra,
                )
            plt.close()
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
def plot_each_batch(dataset, adata, batch, path):
    """Draw all cell and nucleus outlines of one batch and save the figure.

    Cells are selected with ``adata.obs["batch"] == batch``. Each cell's
    ``cell_shape`` and ``nucleus_shape`` (parsed as WKT) is drawn, and the
    cell's obs index is printed at the cell centroid. The figure is saved as
    PNG (300 dpi), PDF and SVG under ``{path}/{dataset}/each_batch/`` and
    then closed.

    Args:
        dataset: Dataset name, used in the output path.
        adata: AnnData-like object exposing ``obs`` with ``"batch"``,
            ``"cell_shape"`` and ``"nucleus_shape"`` columns.
        batch: Batch value to select and to use in the title / file name.
        path: Base output directory.
    """
    save_dir = f"{path}/{dataset}/each_batch"
    os.makedirs(save_dir, exist_ok=True)

    adata_sub = adata[adata.obs["batch"] == batch]
    # NOTE: the previous version also built a filtered points DataFrame and
    # two value_counts() results that were never used; they are omitted here.

    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        for cell_id, shape_wkt in adata_sub.obs["cell_shape"].items():
            if pd.isna(shape_wkt):
                continue  # no outline recorded for this cell
            polygon = loads(shape_wkt)
            x, y = polygon.exterior.xy
            ax.plot(x, y, linestyle="-", color="grey", linewidth=1)
            centroid = polygon.centroid
            ax.text(
                centroid.x,
                centroid.y,
                str(cell_id),
                fontsize=10,
                ha="center",
                color="darkblue",
            )

        for shape_wkt in adata_sub.obs["nucleus_shape"]:
            if pd.isna(shape_wkt):
                continue  # not every cell necessarily has a nucleus outline
            polygon = loads(shape_wkt)
            x, y = polygon.exterior.xy
            ax.plot(x, y, linestyle="-", color="darkgray", linewidth=1)

        ax.set_title(f"Batch {batch}")
        ax.axis("off")
        fig.savefig(
            f"{save_dir}/batch{batch}_plot.png",
            format="png",
            dpi=300,
            bbox_inches="tight",
        )
        fig.savefig(
            f"{save_dir}/batch{batch}_plot.pdf", format="pdf", bbox_inches="tight"
        )
        fig.savefig(
            f"{save_dir}/batch{batch}_plot.svg", format="svg", bbox_inches="tight"
        )
    finally:
        # The figure used to stay open after saving; close it so repeated
        # calls do not accumulate open figures.
        plt.close(fig)
|