grasp-tool 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_tool/__init__.py +17 -0
- grasp_tool/__main__.py +6 -0
- grasp_tool/cli/__init__.py +1 -0
- grasp_tool/cli/main.py +793 -0
- grasp_tool/cli/train_moco.py +778 -0
- grasp_tool/gnn/__init__.py +1 -0
- grasp_tool/gnn/embedding.py +165 -0
- grasp_tool/gnn/gat_moco_final.py +990 -0
- grasp_tool/gnn/graphloader.py +1748 -0
- grasp_tool/gnn/plot_refined.py +1556 -0
- grasp_tool/preprocessing/__init__.py +1 -0
- grasp_tool/preprocessing/augumentation.py +66 -0
- grasp_tool/preprocessing/cellplot.py +475 -0
- grasp_tool/preprocessing/filter.py +171 -0
- grasp_tool/preprocessing/network.py +79 -0
- grasp_tool/preprocessing/partition.py +654 -0
- grasp_tool/preprocessing/portrait.py +1862 -0
- grasp_tool/preprocessing/register.py +1021 -0
- grasp_tool-0.1.0.dist-info/METADATA +511 -0
- grasp_tool-0.1.0.dist-info/RECORD +22 -0
- grasp_tool-0.1.0.dist-info/WHEEL +4 -0
- grasp_tool-0.1.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
|
+
|
|
4
|
+
import pandas as pd
|
|
5
|
+
import os
|
|
6
|
+
import seaborn as sns
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _lazy_import_ot():
    """Return the POT module (`ot`), importing it at call time.

    The import lives inside the function so that importing this module does
    not itself require POT to be installed.

    Returns:
        The imported `ot` module.

    Raises:
        ModuleNotFoundError: If `ot` cannot be imported. The message carries
            install instructions and the original error is chained as the
            cause.
    """
    try:
        import ot  # type: ignore
    except ModuleNotFoundError as exc:
        install_hint = (
            "This feature requires POT (Python Optimal Transport, import name: `ot`).\n"
            "Install: pip install POT\n"
            "Or (extras): pip install grasp-tool[ot]"
        )
        raise ModuleNotFoundError(install_hint) from exc
    return ot
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def generate_random_points(radius, num_points):
    """Draw `num_points` random points inside a disc centred on the origin.

    Angles are drawn uniformly from [0, 2*pi) and radial distances are
    `radius * sqrt(u)` with `u` uniform on [0, 1); the square root keeps the
    point density uniform over the disc's area.

    Args:
        radius: Radius of the disc.
        num_points: Number of points to draw.

    Returns:
        A tuple `(points_df, points)`: a DataFrame with columns "x" and "y",
        and the same coordinates as an array of shape `(num_points, 2)`.
    """
    origin_x, origin_y = 0, 0

    # Draw order matters for seeded reproducibility: angles first, then radii.
    theta = np.random.uniform(0, 2 * np.pi, num_points)
    rho = radius * np.sqrt(np.random.uniform(0, 1, num_points))

    xs = origin_x + rho * np.cos(theta)
    ys = origin_y + rho * np.sin(theta)

    coords = np.vstack([xs, ys]).T
    frame = pd.DataFrame({"x": xs, "y": ys})
    return frame, coords
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def simulate_expected_distance(radius, num_points, num_simulations=10):
    """Estimate by simulation the mean transport cost between two random clouds.

    Each trial draws two independent point sets with
    `generate_random_points(radius, num_points)`, gives every point the mass
    `1 / num_points`, and evaluates `ot.emd2` on the Euclidean cost matrix
    between the two sets.

    Args:
        radius: Radius passed to `generate_random_points`.
        num_points: Number of points in each of the two sets.
        num_simulations: Number of independent trials to average.

    Returns:
        `np.mean` of the per-trial `ot.emd2` values.

    Raises:
        ModuleNotFoundError: If POT (`ot`) is not installed.
    """
    ot = _lazy_import_ot()

    def _single_trial():
        # Two independent draws per trial, in the same order every time.
        _, first_cloud = generate_random_points(radius, num_points)
        _, second_cloud = generate_random_points(radius, num_points)
        source_mass = np.ones(num_points) / num_points
        target_mass = np.ones(num_points) / num_points
        ground_cost = ot.dist(first_cloud, second_cloud, metric="euclidean")
        return ot.emd2(source_mass, target_mass, ground_cost)

    trial_costs = [_single_trial() for _ in range(num_simulations)]
    return np.mean(trial_costs)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def theoretical_expected_distance(radius, num_points, alpha=0.5):
    """Return the power-law reference distance `radius * num_points ** (-alpha)`.

    Args:
        radius: Length scale multiplying the result.
        num_points: Base of the power law.
        alpha: Decay exponent; the default 0.5 gives `radius / sqrt(num_points)`.

    Returns:
        The value of `radius * num_points ** (-alpha)`.
    """
    decay = num_points ** (-alpha)
    return radius * decay
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def calculate_adjusted_wasserstein_distance(
    original_points, target_points, num_points, expected_distance, epsilon=0.1
):
    """Compute the transport cost between two point sets and its offset from a baseline.

    Both sets are weighted uniformly with mass `1 / num_points` per point, and
    `ot.emd2` is evaluated on their Euclidean cost matrix.

    Args:
        original_points: First point set, passed to `ot.dist`.
        target_points: Second point set, passed to `ot.dist`.
        num_points: Length of the uniform weight vectors built for both sets.
        expected_distance: Baseline value subtracted from the transport cost.
        epsilon: Accepted for interface compatibility; not used by this
            function.

    Returns:
        A tuple `(wasserstein_distance, adjusted_distance)` where the second
        item is `abs(wasserstein_distance - expected_distance)`.

    Raises:
        ModuleNotFoundError: If POT (`ot`) is not installed.
    """
    ot = _lazy_import_ot()

    source_mass = np.ones(num_points) / num_points
    target_mass = np.ones(num_points) / num_points
    ground_cost = ot.dist(original_points, target_points, metric="euclidean")

    emd_cost = ot.emd2(source_mass, target_mass, ground_cost)
    return emd_cost, abs(emd_cost - expected_distance)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def plot_points(dataset, data1, data2, gene, cell, radius, wasserstein_distance, path):
    """Save a two-panel scatter plot of original vs. target points for one cell/gene.

    The left panel plots `data1["x_c_s"]` against `data1["y_c_s"]`, the right
    panel plots `data2["x"]` against `data2["y"]`; each panel also shows a
    circle of the given radius centred on the origin. The image is written to
    `<path>/<gene>/<cell>.png`, creating the gene directory when needed.

    Args:
        dataset: Unused; kept for interface compatibility.
        data1: Table with columns "x_c_s" and "y_c_s" (original points).
        data2: Table with columns "x" and "y" (target points).
        gene: Gene name; used in the legend label, title and output directory.
        cell: Cell name; used in the title and output file name.
        radius: Radius of the boundary circle drawn in both panels.
        wasserstein_distance: Number printed in the figure title after
            "Mean_adjusted_distances".
        path: Base output directory.
    """
    fig, axes = plt.subplots(1, 2, figsize=(6, 4))
    try:
        panels = (
            (axes[0], data1["x_c_s"], data1["y_c_s"], "blue", f"Gene: {gene}", "Original Points"),
            (axes[1], data2["x"], data2["y"], "green", "Target Points", "Target Points"),
        )
        for ax, xs, ys, color, label, title in panels:
            ax.scatter(xs, ys, s=3, alpha=0.6, color=color, label=label)
            boundary = plt.Circle(
                (0, 0), radius, color="black", fill=False, label="Cell Boundary"
            )
            ax.add_patch(boundary)
            ax.set_title(title, fontsize=10)
            ax.set_aspect("equal")
            ax.axis("off")

        num = len(data1)
        fig.suptitle(
            f"Cell: {cell} - Gene: {gene}\nMean_adjusted_distances: {wasserstein_distance:.4f} - num: {num}",
            fontsize=12,
        )

        save_dir = f"{path}/{gene}"
        # exist_ok avoids the check-then-create race when several workers
        # plot cells of the same gene concurrently.
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, f"{cell}.png")
        # Save this figure explicitly rather than pyplot's implicit current figure.
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def statics_plot(dataset, gene, path):
    """Plot the distance distributions stored in `<path>/<gene>_distances.csv`.

    Overlays histograms (with KDE) of the columns
    "mean_wasserstein_distances", "expected_distance" and
    "mean_adjusted_distances", marks the mean and median of
    "mean_adjusted_distances" with dashed vertical lines, and saves the figure
    to `<path>/<gene>_plot.png`. If the CSV file does not exist, a message is
    printed and nothing is plotted.

    Args:
        dataset: Unused; kept for interface compatibility.
        gene: Gene name; selects the CSV file and names the output image.
        path: Directory holding the CSV file and receiving the plot.
    """
    csv_path = f"{path}/{gene}_distances.csv"
    if not os.path.exists(csv_path):
        print(f"1. Skipping {gene} (file not found).")
        return

    df = pd.read_csv(csv_path, index_col=0)
    adjusted = df["mean_adjusted_distances"]
    mean_value = adjusted.mean()
    median_value = adjusted.median()

    fig = plt.figure(figsize=(4, 3))
    try:
        histograms = (
            ("mean_wasserstein_distances", "skyblue"),
            ("expected_distance", "royalblue"),
            ("mean_adjusted_distances", "pink"),
        )
        for column, color in histograms:
            sns.histplot(
                df[column],
                kde=True,
                color=color,
                edgecolor="grey",
                bins=20,
                label=column,
            )

        # Mean/median markers refer to the adjusted distances.
        plt.axvline(
            mean_value, color="red", linestyle="--", label=f"Mean: {mean_value:.3f}"
        )
        plt.axvline(
            median_value,
            color="green",
            linestyle="--",
            label=f"Median: {median_value:.3f}",
        )
        plt.title(f"Distribution of Wasserstein Distances {gene}", fontsize=10)
        plt.xlabel("Wasserstein Distance", fontsize=8)
        plt.ylabel("Frequency", fontsize=8)
        plt.xticks(fontsize=8)
        plt.yticks(fontsize=8)
        plt.legend(prop={"size": 8})
        plt.grid(False)

        save_path = os.path.join(path, f"{gene}_plot.png")
        plt.savefig(save_path, bbox_inches="tight", pad_inches=0)
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
"""Misc network/analysis helpers.
|
|
2
|
+
|
|
3
|
+
This module is not part of the main GRASP training pipeline. Keep imports minimal
|
|
4
|
+
so that `import grasp_tool` and other modules do not pull heavy optional deps.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
|
|
11
|
+
import pandas as pd
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def read_data(
    dataset,
    target_cells,
    df_registered,
    n_sectors,
    m_rings,
    k_neighbor,
    if_same,
    partition_root=None,
):
    """Locate per-cell graph files on disk.

    Historical code had hard-coded absolute paths. To keep this usable across
    machines, the root directory must be provided.

    For every cell, the directory
    `<partition_root>/4_<dataset>_partition[_same]/<cell>` is walked. Inside
    directories whose path contains `<n_sectors>_<m_rings>_k<k_neighbor>`,
    each file that ends with "distance_matrix.csv" and whose name contains one
    of the genes in `df_registered["gene"]` yields one result: the file path
    with the trailing "_distance_matrix.csv" removed.

    Args:
        dataset: Dataset name embedded in the partition directory name.
        target_cells: Iterable of cell directory names to search.
        df_registered: Table with a "gene" column; its unique values are the
            genes to look for.
        n_sectors: First component of the configuration tag.
        m_rings: Second component of the configuration tag.
        k_neighbor: Third component of the configuration tag.
        if_same: "yes" selects the `..._partition_same` directory; any other
            value selects `..._partition`.
        partition_root: Base directory containing partition outputs.
            If None, will use env var `GRASP_PARTITION_ROOT`.

    Returns:
        List of path prefixes (file paths without the
        "_distance_matrix.csv" suffix). Cells whose directory is missing are
        reported with a printed message and skipped.

    Raises:
        ValueError: If no partition root is given and the environment
            variable is unset or empty.
    """
    if partition_root is None:
        partition_root = os.environ.get("GRASP_PARTITION_ROOT")
    if not partition_root:
        raise ValueError("partition_root is required (or set env GRASP_PARTITION_ROOT)")

    target_genes = df_registered["gene"].unique().tolist()

    # Everything below is invariant across cells, so compute it once.
    if if_same == "yes":
        partition_dir = f"4_{dataset}_partition_same"
    else:
        partition_dir = f"4_{dataset}_partition"
    config_tag = f"{n_sectors}_{m_rings}_k{k_neighbor}"
    suffix = "_distance_matrix.csv"

    matching_paths = []
    for cell in target_cells:
        cell_dir = os.path.join(partition_root, partition_dir, cell)
        if not os.path.exists(cell_dir):
            print(f"Directory {cell_dir} does not exist.")
            continue

        for root, _dirs, files in os.walk(cell_dir):
            # NOTE(review): this is a substring test on the whole directory
            # path, so e.g. a tag ending in "k3" also matches "...k30" —
            # confirm against the partition layout.
            if config_tag not in root:
                continue
            for file in files:
                # Cheap suffix test first; the gene scan is O(number of genes).
                if not file.endswith("distance_matrix.csv"):
                    continue
                # NOTE(review): gene names are matched as substrings of the
                # file name, so a gene that is a prefix of another also matches.
                if not any(gene in file for gene in target_genes):
                    continue
                # Strip only the trailing suffix of the file name; replacing
                # on the full path would also alter directory names that
                # happen to contain the same text.
                stem = file[: -len(suffix)] if file.endswith(suffix) else file
                matching_paths.append(os.path.join(root, stem))
    return matching_paths
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def load_graph_data(paths):
    """Load node features, adjacency matrix and node weights for each graph.

    Every entry of `paths` is a path prefix laid out as
    `.../<cell>/<...>/<gene>`; the files `<prefix>_nodes.csv` and
    `<prefix>_adjacency_matrix.csv` are read for each one.

    Args:
        paths: Iterable of path prefixes.

    Returns:
        Dict mapping `"<cell>_<gene>"` to a tuple
        `(features, adj_matrix, weights)`: `features` is the array of the
        node columns "x", "y", "is_virtual"; `adj_matrix` is the adjacency
        CSV as read by `pd.read_csv`; `weights` is the array of the node
        column "count".

    Raises:
        IndexError: If a prefix has fewer than three path components.
    """
    graphs = {}
    for path in paths:
        # Normalise first so OS-native separators (as produced by
        # os.path.join) and redundant separators do not shift the
        # <cell>/<...>/<gene> positions.
        parts = os.path.normpath(path).split(os.sep)
        cell_name, gene_name = parts[-3], parts[-1]

        nodes = pd.read_csv(f"{path}_nodes.csv")
        adj_matrix = pd.read_csv(f"{path}_adjacency_matrix.csv")

        graphs[f"{cell_name}_{gene_name}"] = (
            nodes[["x", "y", "is_virtual"]].to_numpy(),
            adj_matrix,
            nodes["count"].to_numpy(),
        )
    return graphs
|