grasp-tool 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,654 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ from sklearn.neighbors import KernelDensity
5
+ import seaborn as sns
6
+ import pickle
7
+ from shapely.geometry import Polygon, Point
8
+ import warnings
9
+ from scipy.spatial import distance_matrix
10
+ from tqdm import tqdm
11
+ import os
12
+ from itertools import combinations
13
+ from sklearn.metrics.pairwise import cosine_similarity
14
+ from sklearn.decomposition import PCA
15
+ from sklearn.manifold import TSNE
16
+ from sklearn.cluster import KMeans
17
+ from sklearn.metrics import silhouette_score
18
+ from sklearn.neighbors import NearestNeighbors
19
+ from sklearn.cluster import SpectralClustering
20
+ import multiprocessing as mp
21
+ import networkx as nx
22
+
23
+ # import ot
24
+ from matplotlib.patches import PathPatch
25
+ from matplotlib.path import Path
26
+
27
+ warnings.filterwarnings("ignore")
28
+
29
+
30
def classify_center_points_with_edge(
    center_points, nuclear_boundary_df_registered, is_edge, epsilon=0.1
):
    """Label each region centre relative to the nuclear outline.

    Points flagged in ``is_edge`` are labelled ``"edge"`` without any geometry
    test. Every other point is ``"inside"`` when the outline polygon contains
    it, ``"boundary"`` when it touches the outline or lies within ``epsilon``
    of it, and ``"outside"`` otherwise.

    Args:
        center_points: sequence of ``(x, y)`` points.
        nuclear_boundary_df_registered: outline vertices in the ``x_c_s`` and
            ``y_c_s`` columns.
        is_edge: per-point flags, indexed in step with ``center_points``.
        epsilon: distance tolerance for the ``"boundary"`` label.

    Returns:
        list of labels, one per point.
    """
    outline = Polygon(
        list(
            zip(
                nuclear_boundary_df_registered["x_c_s"],
                nuclear_boundary_df_registered["y_c_s"],
            )
        )
    )

    def _label(position, on_edge):
        if on_edge:
            return "edge"
        geom = Point(position)
        if outline.contains(geom):
            return "inside"
        # touches() is checked first; the distance test only runs otherwise.
        if outline.touches(geom) or outline.boundary.distance(geom) <= epsilon:
            return "boundary"
        return "outside"

    return [_label(position, is_edge[n]) for n, position in enumerate(center_points)]
57
+
58
+
59
def save_node_data_to_csv_old(
    center_points, is_virtual, plot_dir, gene, node_counts, k, nuclear_positions
):
    """Write node, distance and symmetric k-NN adjacency matrices for one gene.

    Three CSV files are written to ``plot_dir`` (created if missing):

    * ``{gene}_node_matrix.csv``: one row per node with ``node_id``, ``x``,
      ``y``, ``is_virtual`` (0/1), ``count`` and ``nuclear_position``.
    * ``{gene}_dis_matrix.csv``: pairwise Euclidean distances; every pair that
      involves a virtual node is set to the sentinel ``1e6`` (diagonal is 0).
    * ``{gene}_adj_matrix.csv``: 0/1 matrix linking each non-virtual node to
      its ``k`` nearest non-virtual neighbours (the ``k + 1`` closest nodes,
      minus itself and virtual ones), then symmetrised.

    Args:
        center_points: sequence of ``(x, y)`` node coordinates.
        is_virtual: per-node flags; virtual nodes receive no adjacency edges.
        plot_dir: output directory.
        gene: prefix for the output file names.
        node_counts: per-node values for the ``count`` column.
        k: number of nearest neighbours per node.
        nuclear_positions: per-node values for the ``nuclear_position`` column.
    """
    if plot_dir:
        # to_csv cannot create directories; the plotting helpers create it too.
        os.makedirs(plot_dir, exist_ok=True)

    node_df = pd.DataFrame(
        [
            {
                "node_id": idx,
                "x": x,
                "y": y,
                "is_virtual": 1 if is_virtual[idx] else 0,
                "count": node_counts[idx],
                "nuclear_position": nuclear_positions[idx],
            }
            for idx, (x, y) in enumerate(center_points)
        ]
    )
    node_df.to_csv(os.path.join(plot_dir, f"{gene}_node_matrix.csv"), index=False)

    points = np.asarray(center_points, dtype=float).reshape(-1, 2)
    num_nodes = len(points)
    virtual = np.asarray(is_virtual, dtype=bool)[:num_nodes]

    # Vectorised pairwise distances (replaces an O(n^2) Python double loop).
    distances = np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)
    # A large finite sentinel (not inf) keeps the CSV numeric and makes pairs
    # involving virtual nodes sort last in the neighbour search below.
    distances[virtual[:, None] | virtual[None, :]] = 1e6
    np.fill_diagonal(distances, 0)
    pd.DataFrame(distances).to_csv(
        os.path.join(plot_dir, f"{gene}_dis_matrix.csv"), index=False
    )

    adjacency = np.zeros((num_nodes, num_nodes), dtype=int)
    for i in np.flatnonzero(~virtual):
        nearest = np.argsort(distances[i])[: k + 1]
        adjacency[i, nearest[~virtual[nearest]]] = 1
    np.fill_diagonal(adjacency, 0)
    # Make adjacency symmetric.
    adjacency = np.maximum(adjacency, adjacency.T)
    pd.DataFrame(adjacency).to_csv(
        os.path.join(plot_dir, f"{gene}_adj_matrix.csv"), index=False
    )
112
+
113
+
114
# Count points per sector/ring (shared centroid locations)
def count_points_in_areas_same(df, n_sectors, m_rings, r):
    """Bin points into a polar grid whose node positions are the fixed bin centres.

    Adds ``theta`` (``arctan2(y_c_s, x_c_s)``, in ``[-pi, pi]``) and ``radius``
    columns to ``df`` in place. Rings split ``[0, r]`` into ``m_rings`` equal
    annuli and sectors split ``[-pi, pi]`` into ``n_sectors`` equal wedges;
    points farther than ``r`` from the origin are not counted.

    Args:
        df: DataFrame with ``x_c_s`` and ``y_c_s`` columns (modified in place).
        n_sectors: number of angular sectors.
        m_rings: number of radial rings.
        r: outer radius of the grid.

    Returns:
        count_matrix: ``(m_rings, n_sectors)`` float array of counts.
        center_points: ``(x, y)`` geometric centre of each bin, ring-major.
        point_counts: per-bin counts in the same order.
        is_virtual: per-bin flag, ``True`` when the bin is empty.
        is_edge: per-bin flag, ``True`` for bins in the two outermost rings.
    """
    df["theta"] = np.arctan2(df["y_c_s"], df["x_c_s"])
    df["radius"] = np.sqrt(df["x_c_s"] ** 2 + df["y_c_s"] ** 2)
    theta = df["theta"].to_numpy()
    radius = df["radius"].to_numpy()

    count_matrix = np.zeros((m_rings, n_sectors))
    theta_edges = np.linspace(-np.pi, np.pi, n_sectors + 1)
    radius_edges = np.linspace(0, r, m_rings + 1)
    center_points = []
    point_counts = []
    is_virtual = []
    is_edge = []
    for i in range(m_rings):
        inner, outer = radius_edges[i], radius_edges[i + 1]
        # Rings are (inner, outer]; ring 0 is also closed at the origin so a
        # point exactly at (0, 0) is counted instead of silently dropped.
        above_inner = (radius >= inner) if i == 0 else (radius > inner)
        in_ring = above_inner & (radius <= outer)  # hoisted out of the sector loop
        radius_mid = (inner + outer) / 2
        ring_is_edge = bool(i >= m_rings - 2)  # two outermost rings
        for j in range(n_sectors):
            lower, upper = theta_edges[j], theta_edges[j + 1]
            # Sectors are [lower, upper); the last one is closed at +pi because
            # arctan2 returns exactly pi on the negative x-axis.
            below_upper = (theta <= upper) if j == n_sectors - 1 else (theta < upper)
            count = int(np.count_nonzero(in_ring & (theta >= lower) & below_upper))
            count_matrix[i, j] = count
            point_counts.append(count)
            theta_mid = (lower + upper) / 2
            center_points.append(
                (radius_mid * np.cos(theta_mid), radius_mid * np.sin(theta_mid))
            )
            is_virtual.append(count == 0)
            is_edge.append(ring_is_edge)
    return count_matrix, center_points, point_counts, is_virtual, is_edge
148
+
149
+
150
# Count points per sector/ring (centroid varies per sector)
def count_points_in_areas(df, n_sectors, m_rings, r):
    """Bin points into a polar grid; each non-empty bin's node is its points' mean.

    Adds ``theta`` (``arctan2(y_c_s, x_c_s)``, in ``[-pi, pi]``) and ``radius``
    columns to ``df`` in place. Rings split ``[0, r]`` into ``m_rings`` equal
    annuli and sectors split ``[-pi, pi]`` into ``n_sectors`` equal wedges;
    points farther than ``r`` from the origin are not counted.

    Args:
        df: DataFrame with ``x_c_s`` and ``y_c_s`` columns (modified in place).
        n_sectors: number of angular sectors.
        m_rings: number of radial rings.
        r: outer radius of the grid.

    Returns:
        count_matrix: ``(m_rings, n_sectors)`` float array of counts.
        center_points: one ``(x, y)`` per bin, ring-major: the mean of the
            bin's points, or the bin's geometric centre when it is empty.
        point_counts: per-bin counts in the same order.
        is_virtual: per-bin flag, ``True`` when the bin is empty.
    """
    df["theta"] = np.arctan2(df["y_c_s"], df["x_c_s"])
    df["radius"] = np.sqrt(df["x_c_s"] ** 2 + df["y_c_s"] ** 2)
    theta = df["theta"].to_numpy()
    radius = df["radius"].to_numpy()

    count_matrix = np.zeros((m_rings, n_sectors))
    theta_edges = np.linspace(-np.pi, np.pi, n_sectors + 1)
    radius_edges = np.linspace(0, r, m_rings + 1)
    center_points = []
    point_counts = []
    is_virtual = []
    for i in range(m_rings):
        inner, outer = radius_edges[i], radius_edges[i + 1]
        # Rings are (inner, outer]; ring 0 is also closed at the origin so a
        # point exactly at (0, 0) is counted instead of silently dropped.
        above_inner = (radius >= inner) if i == 0 else (radius > inner)
        in_ring = above_inner & (radius <= outer)  # hoisted out of the sector loop
        radius_mid = (inner + outer) / 2
        for j in range(n_sectors):
            lower, upper = theta_edges[j], theta_edges[j + 1]
            # Sectors are [lower, upper); the last one is closed at +pi because
            # arctan2 returns exactly pi on the negative x-axis.
            below_upper = (theta <= upper) if j == n_sectors - 1 else (theta < upper)
            in_bin = in_ring & (theta >= lower) & below_upper
            count = int(np.count_nonzero(in_bin))
            count_matrix[i, j] = count
            point_counts.append(count)
            if count > 0:
                center = (
                    df.loc[in_bin, "x_c_s"].mean(),
                    df.loc[in_bin, "y_c_s"].mean(),
                )
                is_virtual.append(False)
            else:
                theta_mid = (lower + upper) / 2
                center = (radius_mid * np.cos(theta_mid), radius_mid * np.sin(theta_mid))
                is_virtual.append(True)
            center_points.append(center)
    return count_matrix, center_points, point_counts, is_virtual
186
+
187
+
188
def build_graph_k_nearest(center_points, k):
    """Return directed k-nearest-neighbour edges as pairs of point coordinates.

    Each point contributes one edge ``(p_i, p_j)`` to each of its ``k`` nearest
    other points, so the list may hold both ``(a, b)`` and ``(b, a)``. The
    endpoints are rows of ``np.array(center_points)``.

    If there are fewer than ``k + 1`` points, every point is linked to all the
    others. An empty list is returned when there are fewer than two points or
    when ``k < 1``.

    Args:
        center_points: sequence of coordinates.
        k: number of neighbours per point.

    Returns:
        list of ``(ndarray, ndarray)`` endpoint pairs.
    """
    points = np.array(center_points)
    # sklearn requires n_neighbors < n_samples when querying the fitted set.
    n_neighbors = min(k, len(points) - 1)
    if n_neighbors < 1:
        return []
    nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm="ball_tree").fit(points)
    # Calling kneighbors() without X leaves each point out of its own
    # neighbour list. Dropping column 0 instead would be wrong for duplicate
    # coordinates, because the point itself is then not guaranteed to come first.
    _, indices = nbrs.kneighbors()
    return [(points[i], points[j]) for i, row in enumerate(indices) for j in row]
197
+
198
+
199
# Visualize partition, centers, and edges
def plot_cell_partition(
    cell,
    df,
    center_points,
    point_counts,
    edges,
    r,
    gene,
    is_virtual,
    n_sectors,
    m_rings,
    plot_dir,
    nuclear_boundary_df_registered,
):
    """Save a plot of the polar partition, raw points, region centres and edges.

    The plot shows:

    * the ring and sector grid;
    * the points in ``df``;
    * virtual centres as small light-grey dots;
    * non-virtual centres as red dots sized ``0.2 * count``;
    * each edge, dashed light grey if either end is virtual, solid green otherwise;
    * the nuclear outline.

    The figure is written to ``{plot_dir}/{gene}_partition_plot.png``.

    Args:
        cell: cell identifier used in the title.
        df: points with ``x_c_s`` and ``y_c_s`` columns.
        center_points: ``(x, y)`` per region.
        point_counts: count per region (sets the marker size).
        edges: pairs of endpoint coordinates, each endpoint equal to one of
            ``center_points`` (e.g. the output of ``build_graph_k_nearest``).
        r: outer radius of the partition; also the axis half-width.
        gene: gene name used in the title and file name.
        is_virtual: per-region flag for empty regions.
        n_sectors: number of angular sectors.
        m_rings: number of radial rings.
        plot_dir: output directory (created if missing).
        nuclear_boundary_df_registered: outline vertices in the ``x_c_s`` and
            ``y_c_s`` columns.

    Raises:
        KeyError: if an edge endpoint is not one of ``center_points``.
    """
    centers = np.asarray(center_points, dtype=float).reshape(-1, 2)
    virtual = np.asarray(is_virtual, dtype=bool)
    actual = ~virtual
    sizes = np.asarray(point_counts) * 0.2

    # Map each coordinate to its node index once instead of scanning all centres
    # for every edge; setdefault keeps the first index for duplicate coordinates.
    index_of = {}
    for idx, (x, y) in enumerate(centers):
        index_of.setdefault((float(x), float(y)), idx)

    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.set_aspect("equal")
        ax.axis("off")
        for rad in np.linspace(0, r, m_rings + 1):
            ax.add_artist(
                plt.Circle(
                    (0, 0), rad, color="grey", fill=False, linestyle="--", linewidth=0.5
                )
            )
        # Spokes use the same [-pi, pi] sector edges as count_points_in_areas*;
        # spokes starting at 0 are half a sector off when n_sectors is odd.
        for angle in np.linspace(-np.pi, np.pi, n_sectors + 1):
            ax.plot(
                [0, r * np.cos(angle)],
                [0, r * np.sin(angle)],
                color="grey",
                linestyle="--",
                linewidth=0.5,
            )
        ax.scatter(df["x_c_s"], df["y_c_s"], s=1, color="blue", label="Gene Points")

        for (x1, y1), (x2, y2) in edges:
            start = index_of[(float(x1), float(y1))]
            end = index_of[(float(x2), float(y2))]
            if virtual[start] or virtual[end]:
                line_style, color = "dashed", "gainsboro"
            else:
                line_style, color = "solid", "green"
            ax.plot([x1, x2], [y1, y2], color=color, linestyle=line_style, linewidth=0.3)

        ax.scatter(centers[virtual, 0], centers[virtual, 1], color="gainsboro", s=2)
        ax.scatter(centers[actual, 0], centers[actual, 1], color="red", s=sizes[actual])
        ax.plot(
            np.asarray(nuclear_boundary_df_registered["x_c_s"]),
            np.asarray(nuclear_boundary_df_registered["y_c_s"]),
            color="blue",
            linewidth=1,
        )
        for side in ("top", "right", "left", "bottom"):
            ax.spines[side].set_visible(False)

        ax.set_title(f"Cell {cell} - Gene {gene}")
        ax.set_xlim(-r, r)
        ax.set_ylim(-r, r)
        if plot_dir:
            os.makedirs(plot_dir, exist_ok=True)
        fig.savefig(
            os.path.join(plot_dir, f"{gene}_partition_plot.png"),
            bbox_inches="tight",
            pad_inches=0,
            dpi=300,
        )
    finally:
        # Always release the figure, even if plotting fails part-way.
        plt.close(fig)
286
+
287
+
288
def build_graph_with_networkx(center_points, edges, is_virtual):
    """Build an undirected networkx graph over the region centres.

    Nodes are the integer indices of ``center_points`` with ``pos`` (the
    ``(x, y)`` tuple) and ``is_virtual`` attributes. Each edge is given as a
    pair of endpoint coordinates (as produced by ``build_graph_k_nearest``) and
    is attached to the indexed nodes at those coordinates. Previously such
    edges created separate coordinate-keyed nodes and never touched the
    indexed ones.

    Args:
        center_points: sequence of ``(x, y)`` node coordinates.
        edges: iterable of ``(p, q)`` coordinate pairs.
        is_virtual: per-node flags stored as the ``is_virtual`` attribute.

    Returns:
        networkx.Graph
    """
    G = nx.Graph()
    index_of = {}
    for idx, (x, y) in enumerate(center_points):
        G.add_node(idx, pos=(x, y), is_virtual=is_virtual[idx])
        # The first index wins for duplicate coordinates.
        index_of.setdefault((float(x), float(y)), idx)

    def _node(point):
        # An endpoint that matches no centre keeps the old behaviour of becoming
        # a coordinate-tuple node rather than raising.
        return index_of.get(tuple(float(v) for v in point), tuple(point))

    G.add_edges_from((_node(a), _node(b)) for a, b in edges)
    return G
295
+
296
+
297
def save_node_data_to_csv_nonposition(
    center_points, is_virtual, plot_dir, gene, node_counts, k
):
    """Write node, distance and directed k-NN adjacency matrices for one gene.

    Three CSV files are written to ``plot_dir`` (created if missing):

    * ``{gene}_node_matrix.csv``: one row per node with ``node_id``, ``x``,
      ``y``, ``is_virtual`` (0/1) and ``count``.
    * ``{gene}_dis_matrix.csv``: pairwise Euclidean distances; every pair that
      involves a virtual node is set to the sentinel ``1e6`` (diagonal is 0).
    * ``{gene}_adj_matrix.csv``: 0/1 matrix where row ``i`` marks the ``k``
      nearest non-virtual neighbours of non-virtual node ``i`` (the ``k + 1``
      closest nodes, minus itself and virtual ones). It is not symmetrised.

    Args:
        center_points: sequence of ``(x, y)`` node coordinates.
        is_virtual: per-node flags; virtual nodes receive no adjacency edges.
        plot_dir: output directory.
        gene: prefix for the output file names.
        node_counts: per-node values for the ``count`` column.
        k: number of nearest neighbours per node.
    """
    if plot_dir:
        # to_csv cannot create directories; the plotting helpers create it too.
        os.makedirs(plot_dir, exist_ok=True)

    node_df = pd.DataFrame(
        [
            {
                "node_id": idx,
                "x": x,
                "y": y,
                "is_virtual": 1 if is_virtual[idx] else 0,
                "count": node_counts[idx],
            }
            for idx, (x, y) in enumerate(center_points)
        ]
    )
    node_df.to_csv(os.path.join(plot_dir, f"{gene}_node_matrix.csv"), index=False)

    points = np.asarray(center_points, dtype=float).reshape(-1, 2)
    num_nodes = len(points)
    virtual = np.asarray(is_virtual, dtype=bool)[:num_nodes]

    # Vectorised pairwise distances (replaces an O(n^2) Python double loop).
    distances = np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)
    # A large finite sentinel (not inf) keeps the CSV numeric and makes pairs
    # involving virtual nodes sort last in the neighbour search below.
    distances[virtual[:, None] | virtual[None, :]] = 1e6
    np.fill_diagonal(distances, 0)
    pd.DataFrame(distances).to_csv(
        os.path.join(plot_dir, f"{gene}_dis_matrix.csv"), index=False
    )

    adjacency = np.zeros((num_nodes, num_nodes), dtype=int)
    for i in np.flatnonzero(~virtual):
        nearest = np.argsort(distances[i])[: k + 1]
        adjacency[i, nearest[~virtual[nearest]]] = 1
    np.fill_diagonal(adjacency, 0)
    pd.DataFrame(adjacency).to_csv(
        os.path.join(plot_dir, f"{gene}_adj_matrix.csv"), index=False
    )
347
+
348
+
349
def save_node_data_to_csv(
    center_points,
    is_virtual,
    is_edge,
    plot_dir,
    gene,
    node_counts,
    k,
    nuclear_positions,
):
    """Write node, distance and directed k-NN adjacency matrices for one gene.

    Three CSV files are written to ``plot_dir`` (created if missing):

    * ``{gene}_node_matrix.csv``: one row per node with ``node_id``, ``x``,
      ``y``, ``is_virtual`` (0/1), ``is_edge`` (0/1), ``count`` and
      ``nuclear_position``.
    * ``{gene}_dis_matrix.csv``: pairwise Euclidean distances; every pair that
      involves a virtual node is set to the sentinel ``1e6`` (diagonal is 0).
    * ``{gene}_adj_matrix.csv``: 0/1 matrix where row ``i`` marks the ``k``
      nearest non-virtual neighbours of non-virtual node ``i`` (the ``k + 1``
      closest nodes, minus itself and virtual ones). It is not symmetrised.

    Args:
        center_points: sequence of ``(x, y)`` node coordinates.
        is_virtual: per-node flags; virtual nodes receive no adjacency edges.
        is_edge: per-node flags for the ``is_edge`` column.
        plot_dir: output directory.
        gene: prefix for the output file names.
        node_counts: per-node values for the ``count`` column.
        k: number of nearest neighbours per node.
        nuclear_positions: per-node values for the ``nuclear_position`` column.
    """
    if plot_dir:
        # to_csv cannot create directories; the plotting helpers create it too.
        os.makedirs(plot_dir, exist_ok=True)

    node_df = pd.DataFrame(
        [
            {
                "node_id": idx,
                "x": x,
                "y": y,
                "is_virtual": 1 if is_virtual[idx] else 0,
                "is_edge": 1 if is_edge[idx] else 0,
                "count": node_counts[idx],
                "nuclear_position": nuclear_positions[idx],
            }
            for idx, (x, y) in enumerate(center_points)
        ]
    )
    node_df.to_csv(os.path.join(plot_dir, f"{gene}_node_matrix.csv"), index=False)

    points = np.asarray(center_points, dtype=float).reshape(-1, 2)
    num_nodes = len(points)
    virtual = np.asarray(is_virtual, dtype=bool)[:num_nodes]

    # Vectorised pairwise distances (replaces an O(n^2) Python double loop).
    distances = np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)
    # A large finite sentinel (not inf) keeps the CSV numeric and makes pairs
    # involving virtual nodes sort last in the neighbour search below.
    distances[virtual[:, None] | virtual[None, :]] = 1e6
    np.fill_diagonal(distances, 0)
    pd.DataFrame(distances).to_csv(
        os.path.join(plot_dir, f"{gene}_dis_matrix.csv"), index=False
    )

    adjacency = np.zeros((num_nodes, num_nodes), dtype=int)
    for i in np.flatnonzero(~virtual):
        nearest = np.argsort(distances[i])[: k + 1]
        adjacency[i, nearest[~virtual[nearest]]] = 1
    np.fill_diagonal(adjacency, 0)
    pd.DataFrame(adjacency).to_csv(
        os.path.join(plot_dir, f"{gene}_adj_matrix.csv"), index=False
    )
408
+
409
+
410
def plot_cell_partition_heatmap_noposition(
    cell, gene, point_counts, n_sectors, m_rings, r, plot_dir
):
    """Save a polar heatmap of per-region counts to ``{gene}_partition_heatmap.png``.

    ``point_counts`` is flattened and read in ring-major order (index
    ``ring * n_sectors + sector``), so either the flat per-region list or an
    ``(m_rings, n_sectors)`` count matrix works. Counts are divided by their
    maximum and coloured with ``YlOrRd``. Missing entries, and every entry of
    an all-zero input, get the lowest colour.

    Args:
        cell: cell identifier used in the title.
        gene: gene name used in the title and file name.
        point_counts: per-region counts.
        n_sectors: number of angular sectors.
        m_rings: number of radial rings.
        r: outer radius of the partition; also the axis half-width.
        plot_dir: output directory (created if missing).
    """
    counts = np.asarray(point_counts, dtype=float).ravel()
    peak = counts.max() if counts.size else 0.0
    # Guard the all-zero case: dividing by 0 yields NaN, which the colormap
    # renders with its transparent "bad" colour instead of the lowest colour.
    normalized = counts / peak if peak > 0 else np.zeros_like(counts)

    theta_edges = np.linspace(0, 2 * np.pi, n_sectors + 1)
    radius_edges = np.linspace(0, r, m_rings + 1)
    codes = [Path.MOVETO, Path.LINETO, Path.LINETO, Path.LINETO, Path.CLOSEPOLY]

    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.set_aspect("equal")
        ax.axis("off")
        for sector_idx in range(n_sectors):
            t0, t1 = theta_edges[sector_idx], theta_edges[sector_idx + 1]
            for ring_idx in range(m_rings):
                r0, r1 = radius_edges[ring_idx], radius_edges[ring_idx + 1]
                index = ring_idx * n_sectors + sector_idx
                value = normalized[index] if index < normalized.size else 0.0
                # Vertices are negated (rotated by pi), so drawn sector j covers
                # angles [-pi + j*step, -pi + (j + 1)*step]. These are the
                # -pi-based theta bins used by count_points_in_areas*.
                vertices = [
                    (-r0 * np.cos(t0), -r0 * np.sin(t0)),
                    (-r1 * np.cos(t0), -r1 * np.sin(t0)),
                    (-r1 * np.cos(t1), -r1 * np.sin(t1)),
                    (-r0 * np.cos(t1), -r0 * np.sin(t1)),
                    (-r0 * np.cos(t0), -r0 * np.sin(t0)),
                ]
                ax.add_patch(
                    PathPatch(
                        Path(vertices, codes),
                        facecolor=plt.cm.YlOrRd(value),
                        edgecolor="grey",
                        lw=0.5,
                    )
                )
        for side in ("top", "right", "left", "bottom"):
            ax.spines[side].set_visible(False)
        ax.set_title(f"Cell {cell} - Gene {gene}")
        ax.set_xlim(-r, r)
        ax.set_ylim(-r, r)
        if plot_dir:
            os.makedirs(plot_dir, exist_ok=True)
        fig.savefig(
            os.path.join(plot_dir, f"{gene}_partition_heatmap.png"),
            bbox_inches="tight",
            pad_inches=0,
            dpi=300,
        )
    finally:
        # Always release the figure, even if plotting fails part-way.
        plt.close(fig)
483
+
484
+
485
def plot_cell_partition_heatmap(
    cell,
    gene,
    point_counts,
    n_sectors,
    m_rings,
    r,
    plot_dir,
    nuclear_boundary_df_registered,
):
    """Save a polar count heatmap with the nuclear outline to ``{gene}_partition_heatmap.png``.

    ``point_counts`` is flattened and read in ring-major order (index
    ``ring * n_sectors + sector``), so either the flat per-region list or an
    ``(m_rings, n_sectors)`` count matrix works. Counts are divided by their
    maximum and coloured with ``YlOrRd``. Missing entries, and every entry of
    an all-zero input, get the lowest colour. The outline is drawn in blue
    through the vertices in their given order.

    Args:
        cell: cell identifier used in the title.
        gene: gene name used in the title and file name.
        point_counts: per-region counts.
        n_sectors: number of angular sectors.
        m_rings: number of radial rings.
        r: outer radius of the partition; also the axis half-width.
        plot_dir: output directory (created if missing).
        nuclear_boundary_df_registered: outline vertices in the ``x_c_s`` and
            ``y_c_s`` columns.
    """
    counts = np.asarray(point_counts, dtype=float).ravel()
    peak = counts.max() if counts.size else 0.0
    # Guard the all-zero case: dividing by 0 yields NaN, which the colormap
    # renders with its transparent "bad" colour instead of the lowest colour.
    normalized = counts / peak if peak > 0 else np.zeros_like(counts)

    theta_edges = np.linspace(0, 2 * np.pi, n_sectors + 1)
    radius_edges = np.linspace(0, r, m_rings + 1)
    codes = [Path.MOVETO, Path.LINETO, Path.LINETO, Path.LINETO, Path.CLOSEPOLY]

    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.set_aspect("equal")
        ax.axis("off")
        for sector_idx in range(n_sectors):
            t0, t1 = theta_edges[sector_idx], theta_edges[sector_idx + 1]
            for ring_idx in range(m_rings):
                r0, r1 = radius_edges[ring_idx], radius_edges[ring_idx + 1]
                index = ring_idx * n_sectors + sector_idx
                value = normalized[index] if index < normalized.size else 0.0
                # Vertices are negated (rotated by pi), so drawn sector j covers
                # angles [-pi + j*step, -pi + (j + 1)*step]. These are the
                # -pi-based theta bins used by count_points_in_areas*.
                vertices = [
                    (-r0 * np.cos(t0), -r0 * np.sin(t0)),
                    (-r1 * np.cos(t0), -r1 * np.sin(t0)),
                    (-r1 * np.cos(t1), -r1 * np.sin(t1)),
                    (-r0 * np.cos(t1), -r0 * np.sin(t1)),
                    (-r0 * np.cos(t0), -r0 * np.sin(t0)),
                ]
                ax.add_patch(
                    PathPatch(
                        Path(vertices, codes),
                        facecolor=plt.cm.YlOrRd(value),
                        edgecolor="grey",
                        lw=0.5,
                    )
                )

        ax.plot(
            np.asarray(nuclear_boundary_df_registered["x_c_s"]),
            np.asarray(nuclear_boundary_df_registered["y_c_s"]),
            color="blue",
            linewidth=1,
        )
        for side in ("top", "right", "left", "bottom"):
            ax.spines[side].set_visible(False)
        ax.set_title(f"Cell {cell} - Gene {gene}")
        ax.set_xlim(-r, r)
        ax.set_ylim(-r, r)
        if plot_dir:
            os.makedirs(plot_dir, exist_ok=True)
        fig.savefig(
            os.path.join(plot_dir, f"{gene}_partition_heatmap.png"),
            bbox_inches="tight",
            pad_inches=0,
            dpi=300,
        )
    finally:
        # Always release the figure, even if plotting fails part-way.
        plt.close(fig)
585
+
586
+
587
def classify_nuclear_position(
    center_points, nuclear_boundary_df_registered, epsilon=0.1
):
    """Label each point by its position relative to the nuclear outline.

    A point is ``"inside"`` when the outline polygon contains it, and
    ``"boundary"`` when it touches the outline or lies within ``epsilon`` of
    it. All other points are ``"outside"``. Unlike
    ``classify_center_points_with_edge``, no ``"edge"`` label is produced.

    Args:
        center_points: sequence of ``(x, y)`` points.
        nuclear_boundary_df_registered: outline vertices in the ``x_c_s`` and
            ``y_c_s`` columns.
        epsilon: distance tolerance for the ``"boundary"`` label.

    Returns:
        list of labels, one per point.
    """
    vertices = list(
        zip(
            nuclear_boundary_df_registered["x_c_s"],
            nuclear_boundary_df_registered["y_c_s"],
        )
    )
    outline = Polygon(vertices)

    labels = []
    for geom in map(Point, center_points):
        if outline.contains(geom):
            labels.append("inside")
        elif outline.touches(geom) or outline.boundary.distance(geom) <= epsilon:
            labels.append("boundary")
        else:
            labels.append("outside")
    return labels
611
+
612
+
613
def plot_partition_nuclear_position(
    center_points, nuclear_boundary_df_registered, classifications, cell, gene, plot_dir
):
    """Save a scatter of region centres coloured by nuclear-position label.

    Labels are coloured as follows:

    * ``"inside"``: green
    * ``"outside"``: red
    * ``"boundary"``: orange
    * ``"edge"``: light grey (this label is produced by
      ``classify_center_points_with_edge`` and previously raised KeyError)
    * any other label: grey

    The nuclear outline is drawn in blue. The axes are fixed to
    ``[-1, 1] x [-1, 1]``. The figure is written to
    ``{plot_dir}/{gene}_partition_nuclear_position.png``.

    Args:
        center_points: sequence of ``(x, y)`` points.
        nuclear_boundary_df_registered: outline vertices in the ``x_c_s`` and
            ``y_c_s`` columns.
        classifications: one label per point.
        cell: cell identifier used in the title.
        gene: gene name used in the title and file name.
        plot_dir: output directory (created if missing).
    """
    colors = {
        "inside": "green",
        "outside": "red",
        "boundary": "orange",
        "edge": "lightgrey",
    }
    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.plot(
            np.asarray(nuclear_boundary_df_registered["x_c_s"]),
            np.asarray(nuclear_boundary_df_registered["y_c_s"]),
            color="blue",
            linewidth=1,
        )
        for point, classification in zip(center_points, classifications):
            ax.scatter(
                *point,
                color=colors.get(classification, "grey"),
                label=f"{classification}",
                s=25,
                edgecolor="grey",
            )
        for side in ("top", "right", "left", "bottom"):
            ax.spines[side].set_visible(False)
        ax.axis("off")
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.set_aspect("equal", "box")
        ax.set_title(f"Cell {cell} - Gene {gene}")
        if plot_dir:
            os.makedirs(plot_dir, exist_ok=True)
        fig.savefig(
            os.path.join(plot_dir, f"{gene}_partition_nuclear_position.png"),
            bbox_inches="tight",
            pad_inches=0,
            dpi=300,
        )
    finally:
        # Always release the figure, even if plotting fails part-way.
        plt.close(fig)