grasp-tool 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,171 @@
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+
4
+ import pandas as pd
5
+ import os
6
+ import seaborn as sns
7
+
8
+
9
def _lazy_import_ot():
    """Return the POT (Python Optimal Transport) module, importing it on demand.

    The import is deferred to call time so that importing this module does not
    require the optional ``ot`` dependency.

    Returns:
        The imported ``ot`` module.

    Raises:
        ModuleNotFoundError: If ``ot`` cannot be imported. The message carries
            installation instructions and the original error is chained as
            the cause.
    """
    try:
        import ot  # type: ignore
    except ModuleNotFoundError as missing:
        install_hint = (
            "This feature requires POT (Python Optimal Transport, import name: `ot`).\n"
            "Install: pip install POT\n"
            "Or (extras): pip install grasp-tool[ot]"
        )
        raise ModuleNotFoundError(install_hint) from missing
    else:
        return ot
22
+
23
+
24
def generate_random_points(radius, num_points, center=(0, 0), rng=None):
    """Sample points uniformly at random inside a disk.

    Radii are drawn as ``radius * sqrt(u)`` with ``u ~ U(0, 1)`` so the points
    are uniform over the disk's area instead of clustering near the centre.

    Args:
        radius: Radius of the disk.
        num_points: Number of points to draw.
        center: ``(x, y)`` coordinates of the disk centre. Defaults to the
            origin, which is the value that used to be hard-coded.
        rng: Optional random source exposing ``uniform(low, high, size)``,
            e.g. a ``numpy.random.Generator``. When ``None`` the global
            ``np.random`` state is used, so ``np.random.seed`` keeps working.

    Returns:
        A tuple ``(points_df, points)`` where ``points_df`` is a DataFrame with
        columns ``"x"`` and ``"y"`` and ``points`` is an array of shape
        ``(num_points, 2)`` holding the same coordinates.
    """
    sampler = np.random if rng is None else rng

    # Draw order (angles, then radii) is part of the seeded behaviour; keep it.
    angles = sampler.uniform(0, 2 * np.pi, num_points)
    radii = radius * np.sqrt(sampler.uniform(0, 1, num_points))

    center_x, center_y = center
    x = center_x + radii * np.cos(angles)
    y = center_y + radii * np.sin(angles)
    points_df = pd.DataFrame({"x": x, "y": y})
    return points_df, np.vstack([x, y]).T
33
+
34
+
35
def simulate_expected_distance(radius, num_points, num_simulations=10):
    """Estimate the mean optimal-transport distance between two random disks.

    Each simulation draws two independent sets of ``num_points`` points inside
    a disk of the given radius (via ``generate_random_points``), gives every
    point equal mass, and computes the transport cost with ``ot.emd2`` on a
    Euclidean cost matrix. The mean over all simulations is returned.

    Args:
        radius: Radius of the disk the points are drawn from.
        num_points: Number of points in each of the two point sets.
        num_simulations: Number of independent repetitions to average over.

    Returns:
        The mean of the simulated distances.

    Raises:
        ValueError: If ``num_points`` or ``num_simulations`` is less than 1.
            Previously ``num_simulations <= 0`` silently produced NaN.
        ModuleNotFoundError: If POT (``ot``) is not installed.
    """
    if num_points < 1:
        raise ValueError(f"num_points must be >= 1, got {num_points}")
    if num_simulations < 1:
        raise ValueError(f"num_simulations must be >= 1, got {num_simulations}")

    ot = _lazy_import_ot()

    # The uniform marginals do not depend on the draw; build them once.
    a = np.ones(num_points) / num_points
    b = np.ones(num_points) / num_points

    distances = []
    for _ in range(num_simulations):
        _, points1 = generate_random_points(radius, num_points)
        _, points2 = generate_random_points(radius, num_points)
        cost_matrix = ot.dist(points1, points2, metric="euclidean")
        distances.append(ot.emd2(a, b, cost_matrix))
    return np.mean(distances)
47
+
48
+
49
def theoretical_expected_distance(radius, num_points, alpha=0.5):
    """Return the power-law distance estimate ``radius * num_points ** -alpha``.

    Args:
        radius: Scale factor applied to the result.
        num_points: Base of the power law.
        alpha: Decay exponent; the result shrinks as ``num_points ** -alpha``.

    Returns:
        ``radius`` multiplied by ``num_points`` raised to ``-alpha``.
    """
    decay = num_points ** (-alpha)
    return radius * decay
51
+
52
+
53
def calculate_adjusted_wasserstein_distance(
    original_points, target_points, num_points, expected_distance, epsilon=0.1
):
    """Compute the transport distance between two point sets and its offset.

    Both point sets are given uniform mass ``1 / num_points`` per point and the
    optimal-transport cost is computed with ``ot.emd2`` on a Euclidean cost
    matrix.

    Args:
        original_points: Coordinates of the source points.
        target_points: Coordinates of the target points.
        num_points: Length of both uniform weight vectors.
        expected_distance: Baseline distance to compare against.
        epsilon: Not used by the current implementation.

    Returns:
        A tuple ``(wasserstein_distance, adjusted_distance)`` where
        ``adjusted_distance`` is ``abs(wasserstein_distance - expected_distance)``.

    Raises:
        ModuleNotFoundError: If POT (``ot``) is not installed.
    """
    ot = _lazy_import_ot()

    source_mass = np.ones(num_points) / num_points
    target_mass = np.ones(num_points) / num_points
    pairwise_cost = ot.dist(original_points, target_points, metric="euclidean")

    transport_cost = ot.emd2(source_mass, target_mass, pairwise_cost)
    offset_from_expected = abs(transport_cost - expected_distance)
    return transport_cost, offset_from_expected
68
+
69
+
70
def plot_points(dataset, data1, data2, gene, cell, radius, wasserstein_distance, path):
    """Save a two-panel scatter plot of original vs. target points for a cell.

    The left panel plots ``data1["x_c_s"]`` / ``data1["y_c_s"]`` and the right
    panel plots ``data2["x"]`` / ``data2["y"]``. Each panel also shows a circle
    of the given radius centred on the origin. The figure is written to
    ``<path>/<gene>/<cell>.png``; the directory is created if needed.

    Args:
        dataset: Not used by the current implementation.
        data1: Table with ``"x_c_s"`` and ``"y_c_s"`` columns (original points).
        data2: Table with ``"x"`` and ``"y"`` columns (target points).
        gene: Gene name; used in labels and as the output sub-directory.
        cell: Cell name; used in the title and as the output file stem.
        radius: Radius of the boundary circle drawn on both panels.
        wasserstein_distance: Value shown in the figure title (4 decimals).
        path: Base output directory.
    """
    fig, axes = plt.subplots(1, 2, figsize=(6, 4))
    # Always release the figure, even if plotting or saving fails.
    try:
        panels = (
            (axes[0], data1["x_c_s"], data1["y_c_s"], "blue", f"Gene: {gene}", "Original Points"),
            (axes[1], data2["x"], data2["y"], "green", "Target Points", "Target Points"),
        )
        for ax, xs, ys, color, label, title in panels:
            ax.scatter(xs, ys, s=3, alpha=0.6, color=color, label=label)
            boundary = plt.Circle(
                (0, 0), radius, color="black", fill=False, label="Cell Boundary"
            )
            ax.add_patch(boundary)
            ax.set_title(title, fontsize=10)
            ax.set_aspect("equal")
            ax.axis("off")

        num = len(data1)
        fig.suptitle(
            f"Cell: {cell} - Gene: {gene}\nMean_adjusted_distances: {wasserstein_distance:.4f} - num: {num}",
            fontsize=12,
        )

        save_dir = os.path.join(path, str(gene))
        # exist_ok avoids the check-then-create race when several workers
        # plot the same gene concurrently.
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, f"{cell}.png")
        # Save this figure explicitly rather than whatever pyplot considers current.
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    finally:
        plt.close(fig)
112
+
113
+
114
def statics_plot(dataset, gene, path):
    """Plot the distance distributions stored in ``<path>/<gene>_distances.csv``.

    Overlays histograms (with KDE) of the ``mean_wasserstein_distances``,
    ``expected_distance`` and ``mean_adjusted_distances`` columns, marks the
    mean and median of ``mean_adjusted_distances`` with vertical lines, and
    saves the figure to ``<path>/<gene>_plot.png``. If the CSV does not exist
    a message is printed and nothing is plotted.

    Args:
        dataset: Not used by the current implementation.
        gene: Gene name; selects the CSV and names the output image.
        path: Directory containing the CSV and receiving the image.
    """
    file = f"{path}/{gene}_distances.csv"
    if not os.path.exists(file):
        print(f"1. Skipping {gene} (file not found).")
        return

    df = pd.read_csv(file, index_col=0)
    adjusted = df["mean_adjusted_distances"]
    mean_value = adjusted.mean()
    median_value = adjusted.median()

    histograms = (
        ("mean_wasserstein_distances", "skyblue"),
        ("expected_distance", "royalblue"),
        ("mean_adjusted_distances", "pink"),
    )

    fig = plt.figure(figsize=(4, 3))
    # Always release the figure, even if plotting or saving fails.
    try:
        for column, color in histograms:
            sns.histplot(
                df[column],
                kde=True,
                color=color,
                edgecolor="grey",
                bins=20,
                label=column,
            )

        # Add mean/median markers
        plt.axvline(
            mean_value, color="red", linestyle="--", label=f"Mean: {mean_value:.3f}"
        )
        plt.axvline(
            median_value,
            color="green",
            linestyle="--",
            label=f"Median: {median_value:.3f}",
        )
        plt.title(f"Distribution of Wasserstein Distances {gene}", fontsize=10)
        plt.xlabel("Wasserstein Distance", fontsize=8)
        plt.ylabel("Frequency", fontsize=8)
        plt.xticks(fontsize=8)
        plt.yticks(fontsize=8)
        plt.legend(prop={"size": 8})
        plt.grid(False)

        save_path = os.path.join(path, f"{gene}_plot.png")
        fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    finally:
        plt.close(fig)
@@ -0,0 +1,79 @@
1
+ """Misc network/analysis helpers.
2
+
3
+ This module is not part of the main GRASP training pipeline. Keep imports minimal
4
+ so that `import grasp_tool` and other modules do not pull heavy optional deps.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import os
10
+
11
+ import pandas as pd
12
+
13
+
14
def read_data(
    dataset,
    target_cells,
    df_registered,
    n_sectors,
    m_rings,
    k_neighbor,
    if_same,
    partition_root=None,
):
    """Locate per-cell graph files on disk.

    Historical code had hard-coded absolute paths. To keep this usable across
    machines, the root directory must be provided.

    For every cell, the directory
    ``<partition_root>/4_<dataset>_partition[_same]/<cell>`` is walked. Inside
    sub-directories whose path contains ``<n_sectors>_<m_rings>_k<k_neighbor>``,
    every file that ends with ``distance_matrix.csv`` and whose name contains
    one of the genes in ``df_registered["gene"]`` is collected.

    Args:
        dataset: Dataset name used to build the partition directory name.
        target_cells: Iterable of cell directory names to search.
        df_registered: Table with a ``"gene"`` column listing genes of interest.
        n_sectors: Partition parameter, part of the directory tag.
        m_rings: Partition parameter, part of the directory tag.
        k_neighbor: Partition parameter, part of the directory tag.
        if_same: ``"yes"`` selects the ``..._partition_same`` directory; any
            other value selects ``..._partition``.
        partition_root: Base directory containing partition outputs.
            If None, will use env var `GRASP_PARTITION_ROOT`.

    Returns:
        List of path prefixes: each matching file path with its trailing
        ``_distance_matrix.csv`` removed. Missing cell directories are
        reported on stdout and skipped.

    Raises:
        ValueError: If no partition root is given and the env var is unset.
    """
    if partition_root is None:
        partition_root = os.environ.get("GRASP_PARTITION_ROOT")
    if not partition_root:
        raise ValueError("partition_root is required (or set env GRASP_PARTITION_ROOT)")

    target_genes = df_registered["gene"].unique().tolist()

    # Everything below is invariant across cells; compute it once.
    if if_same == "yes":
        partition_dir = f"4_{dataset}_partition_same"
    else:
        partition_dir = f"4_{dataset}_partition"
    config_tag = f"{n_sectors}_{m_rings}_k{k_neighbor}"
    suffix = "_distance_matrix.csv"

    matching_paths = []
    for cell in target_cells:
        cell_dir = os.path.join(partition_root, partition_dir, cell)
        if not os.path.exists(cell_dir):
            print(f"Directory {cell_dir} does not exist.")
            continue

        for root, _dirs, files in os.walk(cell_dir):
            if config_tag not in root:
                continue
            for file in files:
                # Cheap suffix test first; the gene scan is O(len(target_genes)).
                if not file.endswith("distance_matrix.csv"):
                    continue
                # NOTE(review): substring match, so gene "ABC" also matches a
                # file for "ABC1" - confirm whether exact matching is intended.
                if not any(gene in file for gene in target_genes):
                    continue
                # Strip the suffix from the file name only. Replacing it
                # anywhere in the full path would corrupt directory names that
                # happen to contain the same text.
                stem = file[: -len(suffix)] if file.endswith(suffix) else file
                matching_paths.append(os.path.join(root, stem))
    return matching_paths
63
+
64
+
65
def load_graph_data(paths):
    """Load node features, adjacency matrix and weights for each graph prefix.

    For every prefix ``p`` the files ``p_nodes.csv`` and
    ``p_adjacency_matrix.csv`` are read with ``pandas.read_csv``.

    Args:
        paths: Iterable of path prefixes (``str`` or ``os.PathLike``).
            Assumes the layout ``.../<cell>/<...>/<gene>``: the cell name is
            the third-from-last path component and the gene name is the last.

    Returns:
        Dict mapping ``"<cell>_<gene>"`` to a tuple
        ``(features, adj_matrix, weights)`` where ``features`` is the
        ``["x", "y", "is_virtual"]`` columns of the nodes file as an array,
        ``adj_matrix`` is the adjacency CSV as a DataFrame, and ``weights`` is
        the ``"count"`` column of the nodes file as an array.
    """
    graphs = {}
    for path in paths:
        path = os.fspath(path)
        # Split on the platform separator as well as "/", so prefixes built
        # with os.path.join / os.walk also work on Windows.
        parts = path.replace(os.sep, "/").split("/")
        cell_name = parts[-3]
        gene_name = parts[-1]
        graph_name = f"{cell_name}_{gene_name}"

        data = pd.read_csv(f"{path}_nodes.csv")
        features = data[["x", "y", "is_virtual"]].to_numpy()
        adj_matrix = pd.read_csv(f"{path}_adjacency_matrix.csv")
        weights = data["count"].to_numpy()
        graphs[graph_name] = (features, adj_matrix, weights)
    return graphs