grasp-tool 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1862 @@
1
+ """Network Portrait + Jensen-Shannon distance for transcript graphs.
2
+
3
+ This module computes pairwise similarity between per-cell/per-gene transcript
4
+ graphs using a network-portrait representation and Jensen-Shannon divergence.
5
+
6
+ Input
7
+ A PKL that contains a DataFrame with transcript coordinates (typically the
8
+ registered output with `df_registered`). The DataFrame is expected to contain:
9
+ - cell, gene, x_c_s, y_c_s
10
+
11
+ Output
12
+ A CSV of JS distances that can be used as an optional signal for selecting
13
+ positive samples during training.
14
+ """
15
+ ## python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl/simulated_data1_data_dict.pkl --use_same_r --log_file /lustre/home/1910305118/data/GCN_CL/0_code/logs_simulated_data1/js_scipy_auto.log --visualize_top_n 0 --auto_params
16
+ ## python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl/simulated_data1_data_dict.pkl --use_same_r --visualize_top_n 10 --log_file /lustre/home/1910305118/data/GCN_CL/0_code/logs_simulated_data1/js2.log
17
+ ## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/seqfish_fibroblast_data_dict.pkl --use_same_r --log_file /lustre/home/1910305118/data/GCN_CL/0_code/logs_seqfish_fibroblast/js_scipy.log --visualize_top_n 0 --filter_pkl_file /lustre/home/1910305118/data/GCN_CL/5_graph_data/seqfish_fibroblast_cell171_gene2734_graph143789.pkl 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/logs_seqfish_fibroblast/js_scipy_nohup.log &
18
+ ## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/seqfish_cortex_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/logs_seqfish_fibroblast/js_scipy_cortex.log --visualize_top_n 0 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/logs_seqfish_fibroblast/js_scipy_cortex_nohup.log &
19
+ ## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/seqfish_cortex_data_dict_new.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/logs_seqfish_fibroblast/js_scipy_cortex2.log --visualize_top_n 0 --filter_pkl_file /lustre/home/1910305118/data/GCN_CL/5_graph_data/seqfish_cortex_new_cell708_gene93_graph708.pkl 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/logs_seqfish_fibroblast/js_scipy_cortex_nohup2.log &
20
+ ## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/seqfish_cortex_data_dict_new1.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/logs_seqfish_fibroblast/js_scipy_cortex_new1.log --output_dir /lustre/home/1910305118/data/GCN_CL/1_input/seqfish_cortex_sub19_portrait --visualize_top_n 0 --filter_pkl_file /lustre/home/1910305118/data/GCN_CL/5_graph_data/seqfish_cortex_new1_cell2405_gene92_graph2405.pkl 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/logs_seqfish_fibroblast/js_scipy_cortex_new1.log &
21
+ ## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/merscope_liver_data2_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/logs_merfish_liver/js_scipy_merfish_data3.log --output_dir /lustre/home/1910305118/data/GCN_CL/1_input/merfish_liver_data3_portrait --visualize_top_n 0 --filter_pkl_file /lustre/home/1910305118/data/GCN_CL/5_graph_data/merscope_liver_data3_cell281_gene139_graph13488.pkl 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/logs_merfish_liver/js_scipy_merfish_data3.log &
22
+ ## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/merscope_liver_data2_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/logs_merfish_liver/js_scipy_merfish_data4.log --output_dir /lustre/home/1910305118/data/GCN_CL/1_input/merfish_liver_data4_portrait --visualize_top_n 0 --filter_pkl_file /lustre/home/1910305118/data/GCN_CL/5_graph_data/merscope_liver_data4_cell919_gene176_graph43931.pkl 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/logs_merfish_liver/js_scipy_merfish_data4.log &
23
+ ## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/merscope_liver_data2_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_data4_central.log --output_dir /lustre/home/1910305118/data/GCN_CL/1_input/merfish_liver_data4_central_portrait --visualize_top_n 0 --filter_pkl_file /lustre/home/1910305118/data/GCN_CL/5_graph_data/merscope_liver_data4_central_cell182_gene106_graph9770.pkl 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_data4_central.log &
24
+ ## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/seqfish_cortex_data_dict_sub19.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/logs_seqfish_fibroblast/js_scipy_cortex_sub19.log --output_dir /lustre/home/1910305118/data/GCN_CL/1_input/seqfish_cortex_sub19_portrait --visualize_top_n 0 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/logs_seqfish_fibroblast/js_scipy_cortex_sub19.log &
25
+ ## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/seqfish_cortex_Astrocytes_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_seqfish_fibroblast/js_scipy_cortex_Astrocytes.log --filter_pkl_file seqfish_cortex_Astrocytes_cell528_gene22_graph528.pkl --output_dir /lustre/home/1910305118/data/GCN_CL/1_input/seqfish_cortex_Astrocytes_portrait --visualize_top_n 0 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_seqfish_fibroblast/js_scipy_cortex_Astrocytes.log &
26
+ ## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/seqfish_cortex_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_seqfish_fibroblast/js_scipy_seqfish_plus_all.log --output_dir /lustre/home/1910305118/data/GCN_CL/1_input/seqfish_plus_all_portrait --visualize_top_n 0 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_seqfish_fibroblast/js_scipy_seqfish_plus_all.log &
27
+ ## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/merscope_liver_data2_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_data4_protal.log --output_dir /lustre/home/1910305118/data/GCN_CL/1_input/merfish_liver_data4_protal_portrait --visualize_top_n 0 --filter_pkl_file /lustre/home/1910305118/data/GCN_CL/5_graph_data/merscope_liver_data4_protal_cell454_gene124_graph21222.pkl 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_data4_protal.log &
28
+
29
+ ## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/merscope_liver_data2_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_data4_central_bigger.log --output_dir /lustre/home/1910305118/data/GCN_CL/1_input/merfish_liver_data4_central_bigger_portrait --visualize_top_n 0 --filter_pkl_file /lustre/home/1910305118/data/GCN_CL/5_graph_data/merscope_liver_data_central_cell870_gene124_graph45025.pkl 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_data4_central_bigger.log &
30
+ ## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/merscope_liver_data2_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_data4_protal_bigger.log --output_dir /lustre/home/1910305118/data/GCN_CL/1_input/merfish_liver_data4_protal_bigger_portrait --visualize_top_n 0 --filter_pkl_file /lustre/home/1910305118/data/GCN_CL/5_graph_data/merscope_liver_data_protal_cell1713_gene143_graph79975.pkl 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_data4_protal_bigger.log &
31
+
32
+ ## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/merscope_intestine_Enterocyte_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_merfish_intestine/js_scipy_merfish_Enterocyte.log --output_dir /lustre/home/1910305118/data/GCN_CL/1_input/merfish_intestine_Enterocyte_portrait --visualize_top_n 0 --filter_pkl_file /lustre/home/1910305118/data/GCN_CL/5_graph_data/merscope_intestine_Enterocyte_cell419_gene17_graph905.pkl 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_merfish_intestine/js_scipy_merfish_Enterocyte.log &
33
+
34
+
35
+ ## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/simulated1_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_simulated1/js_scipy_simulated1.log --output_dir /lustre/home/1910305118/data/GCN_CL/6_analysis/js_portrait/simulated1_portrait --visualize_top_n 0 --filter_pkl_file /lustre/home/1910305118/data/GCN_CL/5.1_graph_data/simulated1_cell10_gene80_graph800_weight.pkl 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_simulated1/js_scipy_simulated1.log &
36
+
37
+ ## nohup python portrait.py --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/seqfish_fibroblast_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_seqfish_fibroblast/js_scipy.log --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/seqfish_fibroblast_portrait --visualize_top_n 0 --filter_pkl_file /home/lixiangyu/hyy/GRASP/5.1_graph_data/seqfish_fibroblast_cell171_gene59_graph8068.pkl 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_seqfish_fibroblast/js_scipy_nohup.log &
38
+
39
+ ## nohup python portrait.py --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/simulated3_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_simulated/js_scipy_simulated3.log --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/simulated3_portrait --visualize_top_n 0 --filter_pkl_file /home/lixiangyu/hyy/GRASP/5.1_graph_data/simulated3_original_cell50_gene400_graph15000.pkl 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_simulated/js_scipy_simulated3.log &
40
+
41
+ ## nohup python portrait.py --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/merscope_liver_data_region1_portal_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_region1_portal.log --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/merfish_liver_region1_portal --visualize_top_n 0 --filter_pkl_file /home/lixiangyu/hyy/GRASP/5_graph_data/merscope_liver_data_region1_portal_cell1708_gene143_graph79172.pkl 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_region1_portal.log &
42
+
43
+ ## nohup python portrait.py --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/merscope_liver_data_region1_central_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_region1_central.log --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/merfish_liver_region1_central --visualize_top_n 0 --filter_pkl_file /home/lixiangyu/hyy/GRASP/5_graph_data/merscope_liver_data_region1_central_cell1711_gene136_graph86074.pkl 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_region1_central.log &
44
+
45
+ ## nohup python portrait.py --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/merscope_liver_data_region2_central_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_region2_central.log --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/merfish_liver_region2_central --visualize_top_n 0 --filter_pkl_file /home/lixiangyu/hyy/GRASP/5_graph_data/merscope_liver_data_region2_central_cell936_gene126_graph44910.pkl 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_region2_central_nohup.log &
46
+
47
+ ## nohup python portrait.py --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/merscope_liver_data_region2_portal_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_region2_portal.log --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/merfish_liver_region2_portal --visualize_top_n 0 --filter_pkl_file /home/lixiangyu/hyy/GRASP/5_graph_data/merscope_liver_data_region2_portal_cell747_gene124_graph33523.pkl 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_region2_portal_nohup.log &
48
+
49
+
50
+ ## nohup python portrait.py --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/merfish_intestine_Enterocyte_resegment_new_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_intestine/js_scipy_enterocyte_resegment_new.log --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/merfish_intestine_enterocyte_resegment_new_portrait --visualize_top_n 0 --filter_pkl_file /home/lixiangyu/hyy/GRASP/5_graph_data/merfish_intestine_Enterocyte_resegment_new_cell688_gene58_graph4331.pkl 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_intestine/js_scipy_enterocyte_resegment_new_nohup.log &
51
+
52
+
53
+ # # ---------- Group 1 ----------
54
+ # nohup python portrait.py \
55
+ # --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/merfish_u2os_data_dict.pkl \
56
+ # --use_same_r --max_count 30 --auto_params \
57
+ # --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_u2os/js_group1.log \
58
+ # --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/merfish_u2os_group1_portrait \
59
+ # --visualize_top_n 0 \
60
+ # --filter_pkl_file /home/lixiangyu/hyy/GRASP/5_graph_data/merfish_u2os_cell634_gene25_graph1000.pkl \
61
+ # 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_u2os/js_group1_nohup.log &
62
+
63
+
64
+ # # ---------- Group 2 ----------
65
+ # nohup python portrait.py \
66
+ # --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/merfish_u2os_data_dict.pkl \
67
+ # --use_same_r --max_count 30 --auto_params \
68
+ # --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_u2os/js_group2.log \
69
+ # --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/merfish_u2os_group2_portrait \
70
+ # --visualize_top_n 0 \
71
+ # --filter_pkl_file /home/lixiangyu/hyy/GRASP/5_graph_data/merfish_u2os_cell621_gene25_graph1000.pkl \
72
+ # 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_u2os/js_group2_nohup.log &
73
+
74
+
75
+ # # ---------- Group 3 ----------
76
+ # nohup python portrait.py \
77
+ # --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/merfish_u2os_data_dict.pkl \
78
+ # --use_same_r --max_count 30 --auto_params \
79
+ # --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_u2os/js_group3.log \
80
+ # --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/merfish_u2os_group3_portrait \
81
+ # --visualize_top_n 0 \
82
+ # --filter_pkl_file /home/lixiangyu/hyy/GRASP/5_graph_data/merfish_u2os_cell629_gene25_graph1000.pkl \
83
+ # 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_u2os/js_group3_nohup.log &
84
+
85
+
86
+ # # ---------- Group 4 ----------
87
+ # nohup python portrait.py \
88
+ # --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/merfish_u2os_data_dict.pkl \
89
+ # --use_same_r --max_count 30 --auto_params \
90
+ # --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_u2os/js_group4.log \
91
+ # --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/merfish_u2os_group4_portrait \
92
+ # --visualize_top_n 0 \
93
+ # --filter_pkl_file /home/lixiangyu/hyy/GRASP/5_graph_data/merfish_u2os_cell621_gene9_graph947.pkl \
94
+ # 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_u2os/js_group4_nohup.log &
95
+
96
+
97
+ # # ---------- Group 5 ----------
98
+ # nohup python portrait.py \
99
+ # --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/merfish_u2os_data_dict.pkl \
100
+ # --use_same_r --max_count 30 --auto_params \
101
+ # --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_u2os/js_group5.log \
102
+ # --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/merfish_u2os_group5_portrait \
103
+ # --visualize_top_n 0 \
104
+ # --filter_pkl_file /home/lixiangyu/hyy/GRASP/5_graph_data/merfish_u2os_cell989_gene25_graph23242.pkl \
105
+ # 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_u2os/js_group5_nohup.log &
106
+
107
+ ## Legacy note: derived from utils_code/portrait.py
108
+ import pandas as pd
109
+ import numpy as np
110
+ import networkx as nx
111
+
112
+ from scipy.spatial import distance_matrix
113
+ import time
114
+ from datetime import datetime
115
+ import os
116
+ import logging
117
+ import argparse
118
+ from concurrent.futures import ThreadPoolExecutor, as_completed
119
+ import warnings
120
+ import matplotlib.pyplot as plt
121
+ import seaborn as sns
122
+ from tqdm.auto import tqdm
123
+ import pickle
124
+ from scipy.sparse import coo_matrix, csr_matrix
125
+ from scipy.sparse.csgraph import shortest_path
126
+ from scipy.spatial import KDTree
127
+ import random
128
+ import sys
129
+
130
# Matplotlib defaults (keep output deterministic and readable).
# Arial is forced for both the generic family and the sans-serif fallback so
# figures render identically across machines; `axes.unicode_minus=False`
# avoids missing-glyph boxes for the Unicode minus sign in some fonts.
plt.rcParams["font.family"] = ["Arial"]
plt.rcParams["font.sans-serif"] = ["Arial"]
plt.rcParams["axes.unicode_minus"] = False

# Logging: one module-level logger ("js_distance") used by every function
# below; INFO level with timestamped records.
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s][%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("js_distance")

# Warnings: numeric RuntimeWarnings (e.g. from log2 on edge cases) and
# library UserWarnings are suppressed globally to keep long batch logs clean.
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning)
146
+
147
+
148
def find_r_for_isolated_threshold(
    df, threshold=0.05, r_min=0.01, r_max=0.6, step=0.03, verbose=False, dists=None
):
    """Find a connection radius that keeps the isolated-node ratio under a threshold.

    The radius is the (1 - threshold) percentile of nearest-neighbor distances,
    clamped to [r_min, r_max], so that at most a `threshold` fraction of points
    have no neighbor within the radius.

    Args:
        df: DataFrame with `x_c_s` / `y_c_s` coordinate columns.
        threshold: Maximum allowed fraction of isolated nodes.
        r_min, r_max: Clamp range for the returned radius.
        step: Unused; kept for backward compatibility with the old grid search.
        verbose: Log debug diagnostics about the chosen radius.
        dists: Optional precomputed N x N pairwise distance matrix. Previously
            this argument was accepted but silently ignored; it is now reused
            to avoid rebuilding a KDTree when callers (e.g.
            `find_gene_optimal_r`) already have the matrix.

    Returns:
        float: Connection radius in [r_min, r_max].
    """
    positions = df[["x_c_s", "y_c_s"]].values
    N = len(df)

    # Edge cases: a single point (or empty frame) gets a small default radius.
    if N <= 1:
        return min(r_min * 5, r_max * 0.1)

    if dists is not None and np.shape(dists) == (N, N):
        # BUGFIX: use the caller-supplied matrix instead of ignoring it.
        # The diagonal is 0, so the second-smallest entry per row is the
        # nearest-neighbor distance.
        nn_dists = np.partition(dists, 1, axis=1)[:, 1]
    else:
        # Compute nearest-neighbor distances via KDTree.
        # k=2 because the first returned neighbor is the point itself.
        tree = KDTree(positions)
        knn_dists, _ = tree.query(positions, k=2)
        nn_dists = knn_dists[:, 1]

    # Percentile: at most `threshold` fraction can be isolated.
    r = np.percentile(nn_dists, 100 * (1 - threshold))

    # Clamp to the configured range.
    r = max(r_min, min(r, r_max))

    if verbose:
        logger.debug(
            f"KDTree-based r={r:.4f} (percentile={100 * (1 - threshold):.1f} of NN distances)"
        )
        # Sanity-check the actual isolated ratio.
        isolated_count = np.sum(nn_dists > r)
        logger.debug(f"isolated_ratio={isolated_count / N:.4f} ({isolated_count}/{N})")

    return r
186
+
187
+
188
def build_weighted_graph(df, r, dists=None):
    """Build a weighted graph from transcript coordinates.

    Nodes are transcript indices (0..N-1) carrying a `pos` attribute; an edge
    connects every pair of transcripts within radius `r`, weighted by the
    Euclidean distance between them.

    Args:
        df: DataFrame with `x_c_s` / `y_c_s` coordinate columns.
        r: Connection radius; pairs at distance <= r become edges.
        dists: Optional precomputed N x N distance matrix (computed if None).

    Returns:
        networkx.Graph: Weighted undirected graph.
    """
    G = nx.Graph()
    positions = df[["x_c_s", "y_c_s"]].values

    # Add nodes with their coordinates attached.
    for i, pos in enumerate(positions):
        G.add_node(i, pos=pos)

    # Compute pairwise distances if not provided.
    if dists is None:
        dists = distance_matrix(positions, positions)

    # Vectorized edge selection over the upper triangle (i < j) instead of a
    # Python double loop: the O(N^2) scan stays in NumPy.
    i_idx, j_idx = np.triu_indices(len(df), k=1)
    upper = dists[i_idx, j_idx]
    mask = upper <= r
    # add_weighted_edges_from stores the third tuple element as the `weight`
    # attribute, matching the previous G.add_edge(i, j, weight=dist) calls.
    G.add_weighted_edges_from(
        zip(i_idx[mask].tolist(), j_idx[mask].tolist(), upper[mask].tolist())
    )

    return G
213
+
214
+
215
def get_network_portrait(G, bin_size=0.01, use_vectorized=True):
    """Compute a network portrait for a weighted graph.

    For every ordered pair of distinct, mutually reachable nodes (i, j), the
    portrait counts the pair (shortest-path-length bin, degree of i).

    Args:
        G: Graph with nodes labeled 0..n-1 (as produced by
            `build_weighted_graph`) and optional `weight` edge attributes
            (missing weights default to 1.0).
        bin_size: Width of the shortest-path-length bins.
        use_vectorized: Use SciPy sparse APSP (faster, more memory) instead of
            a per-node Dijkstra loop.

    Returns:
        tuple: ({(bin, degree): count}, node_count).
    """
    n_nodes = len(G)
    if n_nodes <= 1:
        # Nothing to bin for an empty or single-node graph.
        return {}, n_nodes

    if use_vectorized:
        # Vectorized APSP via SciPy sparse shortest_path. Build a symmetric
        # sparse adjacency matrix from the edge list.
        rows, cols, weights = [], [], []
        for u, v, data in G.edges(data=True):
            w = data.get("weight", 1.0)
            rows.append(u)
            cols.append(v)
            weights.append(w)
            rows.append(v)
            cols.append(u)
            weights.append(w)

        A = csr_matrix((weights, (rows, cols)), shape=(n_nodes, n_nodes))
        dist_mat = shortest_path(A, directed=False, unweighted=False, method="auto")

        # Degrees ordered by node id so they can be indexed positionally.
        degs = np.array([d for _, d in sorted(G.degree(), key=lambda x: x[0])])

        # Exclude self-pairs; keep all i != j pairs.
        i_idx, j_idx = np.nonzero(~np.eye(n_nodes, dtype=bool))
        dists = dist_mat[i_idx, j_idx]
        src_degs = degs[i_idx]

        # BUGFIX: unreachable pairs have infinite distance, and casting inf
        # to int is undefined (it used to produce garbage bin keys). The
        # Dijkstra branch below never counts unreachable pairs, so drop them
        # here to keep the two implementations consistent.
        finite = np.isfinite(dists)
        dists = dists[finite]
        src_degs = src_degs[finite]

        # Bin distances.
        bins = np.floor(dists / bin_size).astype(int)

        # Vectorized counting of (bin, degree) pairs.
        combined = np.column_stack([bins, src_degs])
        unique_ck, counts = np.unique(combined, axis=0, return_counts=True)

        # Convert back to a plain dict with Python ints.
        portrait = {
            (int(bin_l), int(deg)): int(cnt)
            for (bin_l, deg), cnt in zip(unique_ck, counts)
        }
    else:
        # Loop implementation: slower, but uses less memory. Dijkstra yields
        # only reachable targets, so unreachable pairs are skipped naturally.
        length_dict = dict(nx.all_pairs_dijkstra_path_length(G, weight="weight"))
        # Node degrees.
        degrees = dict(G.degree())

        portrait = {}

        for i in length_dict:
            for j in length_dict[i]:
                if i == j:
                    continue
                dist = length_dict[i][j]
                bin_l = int(dist // bin_size)
                deg = degrees[i]
                key = (bin_l, deg)
                portrait[key] = portrait.get(key, 0) + 1

    return portrait, n_nodes
280
+
281
+
282
def compute_weighted_distribution(portrait, N):
    """Convert a portrait count dict into a weighted probability distribution.

    Each (path-length-bin, degree) entry is weighted by its degree and
    normalized by the total number of ordered node pairs (N * N).

    Args:
        portrait: {(bin, degree): count} mapping from `get_network_portrait`.
        N: Node count of the originating graph.

    Returns:
        dict: {(bin, degree): degree * count / N^2}.
    """
    normalizer = N * N
    return {
        (bin_l, deg): (deg * count) / normalizer
        for (bin_l, deg), count in portrait.items()
    }
292
+
293
+
294
def js_divergence(P, Q):
    """Compute the Jensen-Shannon divergence (base 2) between two distributions.

    P and Q are dicts mapping keys to non-negative weights; a key missing from
    one dict is treated as zero mass there.
    """
    support = set(P) | set(Q)
    p = np.array([P.get(key, 0.0) for key in support])
    q = np.array([Q.get(key, 0.0) for key in support])
    m = (p + q) / 2.0

    def _kl_term(a, b):
        # Only entries with positive mass on both sides contribute to the sum.
        valid = (a > 0) & (b > 0)
        return np.sum(a[valid] * np.log2(a[valid] / b[valid]))

    return 0.5 * _kl_term(p, m) + 0.5 * _kl_term(q, m)
306
+
307
+
308
def plot_portrait(portrait, title="Network Portrait", save_path=None):
    """Render a network portrait as a 2D heatmap (degree x path-length bin).

    Writes the figure to `save_path` when given, otherwise shows it
    interactively; the figure is closed in both cases.
    """
    if not portrait:
        logger.warning("Empty portrait; skip plotting")
        return

    # Long-format records, then pivot into an l-by-k matrix with zero fill.
    records = [
        {"l": bin_l, "k": deg, "value": count}
        for (bin_l, deg), count in portrait.items()
    ]
    heat = pd.DataFrame(records).pivot(index="l", columns="k", values="value")
    heat = heat.fillna(0)

    plt.figure(figsize=(8, 6))
    sns.heatmap(heat, cmap="viridis", annot=False)
    plt.title(title, fontsize=14, fontweight="bold")
    plt.xlabel("Degree k", fontsize=12)
    plt.ylabel("Path length bin l", fontsize=12)
    plt.xticks(fontsize=10)
    plt.yticks(fontsize=10)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        logger.info(f"Portrait saved to: {save_path}")
    else:
        plt.show()

    plt.close()
334
+
335
+
336
def plot_graph(G, title="Graph structure", save_path=None):
    """Draw the graph using the `pos` coordinates stored on its nodes.

    Writes the figure to `save_path` when given, otherwise shows it
    interactively; the figure is closed in both cases.
    """
    node_positions = nx.get_node_attributes(G, "pos")

    plt.figure(figsize=(8, 8))
    draw_style = dict(
        node_size=30,
        node_color="skyblue",
        edge_color="gray",
        width=0.5,
        with_labels=False,
    )
    nx.draw(G, node_positions, **draw_style)
    plt.title(title, fontsize=14, fontweight="bold")
    # Equal aspect so spatial distances are not distorted.
    plt.axis("equal")
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        logger.info(f"Graph saved to: {save_path}")
    else:
        plt.show()

    plt.close()
361
+
362
+
363
def compare_graphs(
    df1, df2, r=None, bin_size=0.01, show_plots=True, save_dir=None, use_vectorized=True
):
    """Compare two transcript graphs via network portraits and JS divergence.

    Builds one graph per DataFrame with a shared connection radius, converts
    each graph's portrait into a weighted distribution, and returns the
    Jensen-Shannon divergence between them. Optionally shows and/or saves
    diagnostic plots for both graphs and portraits.
    """
    # Auto-select a shared radius when none is given: take the larger of the
    # two per-graph radii so neither graph is over-fragmented.
    if r is None:
        r = max(
            find_r_for_isolated_threshold(df1, threshold=0.05),
            find_r_for_isolated_threshold(df2, threshold=0.05),
        )
        logger.info(f"Auto-selected connection radius r = {r:.2f}")

    # Build both graphs with the same radius.
    G1 = build_weighted_graph(df1, r)
    G2 = build_weighted_graph(df2, r)

    # Portraits and their weighted distributions.
    B1, N1 = get_network_portrait(G1, bin_size, use_vectorized)
    B2, N2 = get_network_portrait(G2, bin_size, use_vectorized)
    P = compute_weighted_distribution(B1, N1)
    Q = compute_weighted_distribution(B2, N2)

    # Jensen-Shannon divergence between the two distributions.
    js = js_divergence(P, Q)
    logger.info(f"Jensen-Shannon divergence between two graphs: {js:.4f}")

    # Visualization: the same four figures are rendered interactively and/or
    # saved to disk, driven by one spec table each.
    if show_plots or save_dir:
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        if show_plots:
            for renderer, obj, fig_title in (
                (plot_graph, G1, "Graph 1 structure"),
                (plot_graph, G2, "Graph 2 structure"),
                (plot_portrait, B1, "Graph 1 network portrait"),
                (plot_portrait, B2, "Graph 2 network portrait"),
            ):
                renderer(obj, title=fig_title)

        if save_dir:
            for renderer, obj, fig_title, filename in (
                (plot_graph, G1, "Graph 1 structure", "graph1_structure.png"),
                (plot_graph, G2, "Graph 2 structure", "graph2_structure.png"),
                (plot_portrait, B1, "Graph 1 network portrait", "graph1_portrait.png"),
                (plot_portrait, B2, "Graph 2 network portrait", "graph2_portrait.png"),
            ):
                renderer(obj, title=fig_title, save_path=f"{save_dir}/{filename}")

    return js
425
+
426
+
427
def find_gene_optimal_r(
    gene,
    df,
    cell_list,
    threshold=0.05,
    r_min=0.01,
    r_max=0.6,
    r_step=0.03,
    dist_dict=None,
):
    """
    Find an optimal r value for a gene (use the max r across cells).

    For each cell, a per-cell radius is derived from the cell's transcripts;
    the gene-level radius is the maximum over cells so that every cell's
    isolated-node constraint is satisfied. Distance matrices computed along
    the way are returned so callers can reuse them when building graphs.

    Args:
        gene: Gene ID.
        df: DataFrame containing transcripts (columns: cell, gene, x_c_s, y_c_s).
        cell_list: List of cell IDs.
        threshold: Isolated-node ratio threshold passed to
            `find_r_for_isolated_threshold`.
        r_min, r_max, r_step: Search parameters for r.
        dist_dict: Optional precomputed distance matrices {(cell, gene): dist_matrix}.

    Returns:
        tuple: (optimal r for gene, dict {cell: dist_matrix}). If no cell
        yields a valid r, a conservative default radius and an empty dict
        are returned instead.
    """
    logger.info(f"Computing optimal r for gene {gene}")

    # All transcripts for this gene.
    gene_df = df[df["gene"] == gene]

    cell_r_values = {}
    cell_dist_matrices = {}  # store distance matrix per cell (reused by callers)
    transcript_counts = {}  # transcript count per cell (for the summary stats)

    # Compute r per cell.
    for cell in cell_list:
        cell_df = gene_df[gene_df["cell"] == cell]
        transcript_count = len(cell_df)
        transcript_counts[cell] = transcript_count

        # If there is only one transcript, use a small default radius.
        if transcript_count == 1:
            # Avoid connecting to far-away points.
            r_single = min(r_min * 5, r_max * 0.1)  # min(5*r_min, 0.1*r_max)
            cell_r_values[cell] = r_single
            # Single-point distance matrix (1x1 zero).
            cell_dist_matrices[cell] = np.array([[0.0]])
            continue

        # Skip empty cells (counted in the stats above, but no r computed).
        if transcript_count == 0:
            continue

        try:
            # Prefer a precomputed distance matrix when the caller supplied one.
            dists = dist_dict.get((cell, gene), None) if dist_dict else None

            # Compute if missing.
            if dists is None:
                positions = cell_df[["x_c_s", "y_c_s"]].values
                dists = distance_matrix(positions, positions)

            # Find optimal r for this cell.
            r = find_r_for_isolated_threshold(
                cell_df,
                threshold,
                r_min,
                r_max,
                r_step,
                verbose=False,
                dists=dists,
            )
            cell_r_values[cell] = r
            cell_dist_matrices[cell] = dists  # cache for reuse
        except Exception as e:
            # A single failing cell should not abort the whole gene; log and
            # continue with the remaining cells.
            logger.error(f"Failed to compute r for cell={cell}, gene={gene}: {e}")

    # Summary stats for logging.
    total_cells = len(cell_list)
    valid_cells = len(cell_r_values)
    single_transcript_cells = sum(
        1 for count in transcript_counts.values() if count == 1
    )
    multi_transcript_cells = sum(1 for count in transcript_counts.values() if count > 1)
    zero_transcript_cells = sum(1 for count in transcript_counts.values() if count == 0)

    # If no r values are valid, return a safer default.
    if not cell_r_values:
        # Use a smaller default to avoid overly dense graphs.
        default_r = min(r_max * 0.3, r_min * 10)
        logger.warning(
            f"Gene {gene} has no valid r; using adjusted default {default_r:.4f} (original default: {r_max})"
        )
        logger.warning(
            f"Gene {gene} transcript stats: total_cells={total_cells}, zero={zero_transcript_cells}, one={single_transcript_cells}, multi={multi_transcript_cells}"
        )
        return default_r, {}

    # Return max r across cells (satisfies every cell) and the matrix cache.
    max_r = max(cell_r_values.values())
    min_r = min(cell_r_values.values())
    avg_r = np.mean(list(cell_r_values.values()))

    logger.info(
        f"Gene {gene} optimal r: {max_r:.2f} (from {valid_cells} cells, range: {min_r:.2f}-{max_r:.2f}, mean: {avg_r:.2f})"
    )
    logger.info(
        f"Gene {gene} transcript distribution: zero={zero_transcript_cells}, one={single_transcript_cells}, multi={multi_transcript_cells}"
    )

    return max_r, cell_dist_matrices
537
+
538
+
539
def precompute_portraits_for_gene(
    gene,
    df,
    cell_list,
    threshold=0.05,
    bin_size=0.01,
    r_min=0.01,
    r_max=0.6,
    r_step=0.03,
    use_same_r=True,
    use_vectorized=True,
    dist_dict=None,
):
    """
    Precompute network portraits for all cells of a gene.

    For each cell that has transcripts of ``gene``, build the radius-r graph,
    compute its network portrait, and store the weighted path-length/degree
    distribution. Cells with zero transcripts (or that raise during
    processing) are simply absent from the returned dict; single-transcript
    cells get a hand-built one-node portrait.

    Args:
        gene: Gene identifier.
        df: Transcript DataFrame (must contain "gene", "cell", "x_c_s", "y_c_s").
        cell_list: List of cell IDs.
        threshold: Isolated node ratio threshold.
        bin_size: Path length bin size.
        r_min, r_max, r_step: Radius search parameters.
        use_same_r: Whether to use a shared r for all cells within a gene.
        use_vectorized: Whether to use vectorized portrait computation.
        dist_dict: Optional precomputed distance matrices {(cell, gene): dist_matrix}.
            Read-only here; safe to share across gene-level worker threads.

    Returns:
        Dict: {(cell, gene): (weighted_distribution, node_count, r_value)}
    """
    distributions = {}

    logger.info(f"Start precomputing network portraits for gene {gene}")
    start_time = time.time()

    # Filter transcripts for this gene.
    gene_df = df[df["gene"] == gene]

    # Ensure there is data.
    if len(gene_df) == 0:
        logger.warning(f"Gene {gene} has no transcript records")
        return distributions

    # If using a shared r, compute gene-level r first (and reuse distance matrices).
    # NOTE: when use_same_r is False, gene_r stays None and a per-cell r is
    # searched below instead.
    gene_r = None
    cell_dist_matrices = {}
    if use_same_r:
        gene_r, cell_dist_matrices = find_gene_optimal_r(
            gene, df, cell_list, threshold, r_min, r_max, r_step, dist_dict
        )

    # Process each cell.
    for cell in tqdm(
        cell_list,
        desc=f"Processing cells for gene {gene}",
        leave=False,
        disable=not sys.stdout.isatty(),  # silence progress bars when piped to a log
    ):
        # Filter transcripts for this cell.
        cell_df = gene_df[gene_df["cell"] == cell]
        transcript_count = len(cell_df)

        # Skip cells with no transcripts.
        if transcript_count == 0:
            logger.debug(f"Cell {cell}, gene {gene} has no transcripts; skipping")
            continue

        # Handle the single-transcript case.
        if transcript_count == 1:
            logger.debug(
                f"Cell {cell}, gene {gene} has 1 transcript; using a single-node portrait"
            )
            try:
                # Special handling for a single-node graph.
                # The portrait contains one node with degree=0 at path length=0.
                # In (l, k) bins: l=0 (self distance), k=0 (degree=0).
                single_portrait = {(0, 0): 1}  # One entry: (path_length=0, degree=0)
                weighted_dist = compute_weighted_distribution(single_portrait, 1)

                # Use a reasonable r. A one-node graph has no edges, so the
                # exact value only matters for bookkeeping in the output.
                r = gene_r if use_same_r else min(r_min * 5, r_max * 0.1)

                distributions[(cell, gene)] = (weighted_dist, 1, r)
                continue
            except Exception as e:
                logger.error(
                    f"Failed single-transcript handling cell={cell}, gene={gene}: {e}"
                )
                continue

        # Multi-transcript case (original logic).
        try:
            # Prefer cell_dist_matrices, then fall back to global dist_dict.
            dists = cell_dist_matrices.get(cell, None)
            if dists is None and dist_dict:
                dists = dist_dict.get((cell, gene), None)

            # Compute distance matrix if missing.
            if dists is None:
                positions = cell_df[["x_c_s", "y_c_s"]].values
                dists = distance_matrix(positions, positions)

            # Use gene-level r, or find an r for this (cell, gene) pair.
            r = gene_r
            if not use_same_r:
                r = find_r_for_isolated_threshold(
                    cell_df,
                    threshold=threshold,
                    r_min=r_min,
                    r_max=r_max,
                    step=r_step,
                    dists=dists,
                )

            # Build graph (reuse the distance matrix).
            G = build_weighted_graph(cell_df, r, dists=dists)

            # Compute portrait and weighted distribution.
            portrait, N = get_network_portrait(G, bin_size, use_vectorized)
            weighted_dist = compute_weighted_distribution(portrait, N)

            distributions[(cell, gene)] = (weighted_dist, N, r)

        except Exception as e:
            # Best-effort per cell: a failure for one cell must not abort the
            # whole gene, so log and move on.
            logger.error(f"Failed processing cell={cell}, gene={gene}: {e}")

    elapsed = time.time() - start_time
    logger.info(
        f"Finished precomputing network portraits for gene {gene}: {len(distributions)} distributions, elapsed {elapsed:.2f}s"
    )

    return distributions
671
+
672
+
673
def find_js_distances_for_gene(
    gene, df, cell_list, portraits, bin_size=0.01, max_count=None, transcript_window=30
):
    """
    Compute JS divergence for cell pairs within a gene.

    For every cell with a precomputed portrait, candidate partners are the
    other portrait-bearing cells whose transcript count differs by at most
    ``transcript_window``, closest counts first, capped at ``max_count``.

    Args:
        gene: Gene identifier.
        df: DataFrame containing transcripts. NOTE(review): currently unused;
            kept for API compatibility with the thread-pool submission site.
        cell_list: List of cells.
        portraits: Precomputed portraits {(cell, gene): (weighted_distribution, node_count, r_value)}.
        bin_size: Path length bin size. NOTE(review): currently unused; the
            distributions in ``portraits`` were already binned upstream.
        max_count: Max comparisons per target cell (None = unlimited).
        transcript_window: Candidate window on transcript count difference.

    Returns:
        List of tuples:
        (target_cell, gene, cell, gene, num_transcripts, js_distance,
         transcript_diff, target_r, other_r)
    """
    logger.info(f"Computing JS divergences for gene {gene}")
    start_time = time.time()

    # Collect valid cells and transcript counts for this gene.
    valid_cells = []
    transcript_counts = {}
    r_values = {}

    for cell in cell_list:
        key = (cell, gene)
        if key in portraits:
            valid_cells.append(cell)
            _, N, r = portraits[key]
            transcript_counts[cell] = N
            r_values[cell] = r

    logger.info(f"Gene {gene}: {len(valid_cells)} valid cells")

    # Compute JS distances.
    all_distances = []
    processed_count = 0

    for target_cell in valid_cells:
        target_transcript_count = transcript_counts[target_cell]
        target_r = r_values[target_cell]
        # Hoisted out of the candidate loop: the target portrait is constant
        # per target cell (and guaranteed present, since valid_cells was
        # built from the portraits dict).
        target_dist, _, _ = portraits[(target_cell, gene)]

        # Candidate cells with similar transcript counts.
        candidates = []
        for cell in valid_cells:
            if cell == target_cell:
                continue
            transcript_diff = abs(transcript_counts[cell] - target_transcript_count)
            if transcript_diff <= transcript_window:
                candidates.append((cell, transcript_diff))

        # Sort by transcript count difference (closest counts first).
        candidates.sort(key=lambda x: x[1])

        # Limit comparisons.
        if max_count is not None and len(candidates) > max_count:
            candidates = candidates[:max_count]

        # Compute JS divergence.
        for cell, transcript_diff in candidates:
            try:
                other_key = (cell, gene)
                # Candidates come from valid_cells, so this should always
                # hold; kept as a cheap guard against concurrent mutation.
                if other_key not in portraits:
                    continue

                other_dist, _, other_r = portraits[other_key]

                js_distance = js_divergence(target_dist, other_dist)

                all_distances.append(
                    (
                        target_cell,
                        gene,
                        cell,
                        gene,
                        transcript_counts[cell],
                        js_distance,
                        transcript_diff,
                        target_r,
                        other_r,
                    )
                )

                processed_count += 1
                if processed_count % 100 == 0:
                    logger.debug(
                        f"Gene {gene}: computed {processed_count} JS distances"
                    )
            except Exception as e:
                # Best-effort per pair: log and continue with the next one.
                logger.error(f"Failed JS divergence for {target_cell}-{cell}: {e}")

    elapsed = time.time() - start_time
    logger.info(
        f"Finished JS divergences for gene {gene}: {len(all_distances)} results, elapsed {elapsed:.2f}s"
    )

    return all_distances
772
+
773
+
774
def calculate_js_distances(
    pkl_file: str,
    output_dir: str = None,
    max_count: int = None,
    transcript_window: int = 30,
    bin_size: float = 0.01,
    threshold: float = 0.05,
    r_min: float = 0.01,
    r_max: float = 0.6,
    r_step: float = 0.03,
    num_threads: int = None,
    use_same_r: bool = True,
    visualize_top_n: int = 5,
    use_vectorized: bool = True,
    filter_pkl_file: str = None,
    auto_params: bool = False,
    n_bins: int = 50,
    min_percentile: float = 1.0,
    max_percentile: float = 99.0,
) -> pd.DataFrame:
    """
    Compute Jensen-Shannon divergence between transcript graphs, with optional auto r selection.

    Pipeline (each stage fans out over genes/cells with a thread pool):
      Stage 0: precompute pairwise distance matrices for every (cell, gene).
      Auto-params: optionally derive r_min/r_max/bin_size from percentiles
        of the pooled pairwise-distance distribution.
      Stage 1: precompute network portraits per gene.
      Stage 2: compute pairwise JS divergences per gene, save them as a CSV,
        and optionally visualize the most similar pairs.

    Args:
        pkl_file: Input PKL path containing df_registered.
        output_dir: Output directory; if None, infer a default.
        max_count: Max comparisons per target cell.
        transcript_window: Candidate window on transcript count difference.
        bin_size: Path length bin size; use 'auto' to infer.
        threshold: Isolated-node ratio threshold.
        r_min: Minimum r for search; use 'auto' to infer.
        r_max: Maximum r for search; use 'auto' to infer.
        r_step: Step size for r search.
        num_threads: Thread pool size.
        use_same_r: Use a single r per gene (max across cells).
        visualize_top_n: Visualize top-N most similar pairs.
        use_vectorized: Use vectorized portrait computation.
        filter_pkl_file: Optional PKL to filter (cell, gene) pairs.
        auto_params: Auto-set r_min/r_max/bin_size based on distances.
        n_bins: Bin count for auto bin_size.
        min_percentile: Percentile used to infer r_min.
        max_percentile: Percentile used to infer r_max.

    Returns:
        pd.DataFrame: JS divergence results (empty DataFrame when filtering
        removes all records or no divergences were computed).

    Raises:
        ValueError: If the PKL lacks df_registered or required columns.
    """
    # Default thread count: leave two cores free, cap at 20.
    if num_threads is None:
        import multiprocessing

        num_threads = min(multiprocessing.cpu_count() - 2, 20)

    # Auto parameter flags: a parameter is auto either individually (the
    # literal string 'auto') or globally via auto_params.
    auto_r_min = r_min == "auto" or auto_params
    auto_r_max = r_max == "auto" or auto_params
    auto_bin_size = bin_size == "auto" or auto_params

    # Use None placeholders for auto-calculated params.
    if auto_r_min:
        r_min = None
    else:
        r_min = float(r_min)

    if auto_r_max:
        r_max = None
    else:
        r_max = float(r_max)

    if auto_bin_size:
        bin_size = None
    else:
        bin_size = float(bin_size)

    logger.info(f"Computing JS divergence with {num_threads} threads")
    if auto_params or auto_r_min or auto_r_max or auto_bin_size:
        logger.info("Auto-calculating r_min, r_max, and bin_size")
    else:
        logger.info(
            f"r search params: threshold={threshold}, r_min={r_min}, r_max={r_max}, r_step={r_step}"
        )
        logger.info(f"Path length bin_size: {bin_size}")

    logger.info(
        f"{'Use a single r per gene' if use_same_r else 'Use per (cell, gene) r'}"
    )
    logger.info(
        f"{'Use vectorized portrait computation' if use_vectorized else 'Use loop-based portrait computation'}"
    )

    start_time = time.time()

    # Load data.
    # NOTE: pickle.load is only safe for trusted files; this pipeline assumes
    # the PKL was produced by an earlier trusted step.
    with open(pkl_file, "rb") as f:
        data_dict = pickle.load(f)

    # Validate required key.
    if "df_registered" not in data_dict:
        raise ValueError(f"PKL file {pkl_file} does not contain df_registered")

    df = data_dict["df_registered"]
    logger.info(f"Loaded {len(df)} transcript records")

    # Validate required columns.
    required_cols = ["cell", "gene", "x_c_s", "y_c_s"]
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"df_registered is missing required columns: {missing_cols}")

    # Optionally filter (cell, gene) pairs. Filtering is best-effort: on any
    # problem with the filter file we log and fall through to the full data.
    if filter_pkl_file and os.path.exists(filter_pkl_file):
        logger.info(f"Filtering (cell, gene) pairs using PKL file: {filter_pkl_file}")
        try:
            with open(filter_pkl_file, "rb") as f:
                filter_data = pickle.load(f)

            # Extract cell_labels and gene_labels.
            if "cell_labels" in filter_data and "gene_labels" in filter_data:
                # Require equal lengths to form (cell, gene) pairs.
                cell_labels = filter_data["cell_labels"]
                gene_labels = filter_data["gene_labels"]

                if len(cell_labels) == len(gene_labels):
                    # Build set of (cell, gene) pairs.
                    cell_gene_pairs = set(zip(cell_labels, gene_labels))
                    logger.info(
                        f"Extracted {len(cell_gene_pairs)} unique (cell, gene) pairs from filter file"
                    )

                    # Add temp (cell, gene) key for filtering.
                    df["cell_gene_pair"] = list(zip(df["cell"], df["gene"]))

                    # Filter df_registered to keep only pairs from the filter file.
                    original_len = len(df)
                    df = df[df["cell_gene_pair"].isin(cell_gene_pairs)]

                    # Drop temp column.
                    df = df.drop(columns=["cell_gene_pair"])

                    logger.info(
                        f"Filtered records: before={original_len}, after={len(df)}"
                    )

                    if len(df) == 0:
                        logger.warning(
                            "No records remain after filtering; check whether (cell, gene) pairs match"
                        )
                        return pd.DataFrame()
                else:
                    logger.warning(
                        f"Filter PKL has mismatched lengths: cell_labels={len(cell_labels)} vs gene_labels={len(gene_labels)}; cannot form exact pairs"
                    )
                    logger.info("Proceeding with all cells and genes")
            else:
                logger.warning(
                    f"Filter PKL file {filter_pkl_file} does not contain cell_labels or gene_labels"
                )
        except Exception as e:
            logger.error(f"Failed to read filter PKL file: {e}")
            logger.info("Proceeding with all cells and genes")

    # Unique cells and genes.
    cell_list = sorted(df["cell"].unique())
    gene_list = sorted(df["gene"].unique())

    logger.info(f"Dataset contains {len(cell_list)} cells and {len(gene_list)} genes")

    # Resolve output directory.
    if output_dir is None:
        # Try to infer the dataset name from the PKL filename.
        filename = os.path.basename(pkl_file)
        # Drop common suffixes, e.g. "_data_dict.pkl".
        dataset = filename.split("_data_dict")[0]
        if dataset == filename:  # No "_data_dict" suffix found.
            # Fall back to stripping the extension.
            dataset = os.path.splitext(filename)[0]

        logger.info(f"Derived dataset name from filename (unknown): {dataset}")

        # Determine data directory: search upward for a "GRASP" directory.
        pkl_abs_path = os.path.abspath(pkl_file)
        # Find the "GRASP" directory as project root.
        grasp_dir = None
        path_parts = pkl_abs_path.split(os.sep)
        for i, part in enumerate(path_parts):
            if part == "GRASP":
                grasp_dir = os.sep.join(path_parts[: i + 1])
                break

        if grasp_dir:
            # Under GRASP, search for a dataset directory (e.g. data1_simulated1).
            # Naming rule: directory name contains the dataset identifier.
            data_dir = None
            for item in os.listdir(grasp_dir):
                item_path = os.path.join(grasp_dir, item)
                if os.path.isdir(item_path) and dataset in item:
                    data_dir = item_path
                    break

            if data_dir:
                output_dir = os.path.join(data_dir, "step2_js")
                logger.info(f"Using data directory: {data_dir}")
            else:
                # If no matching data directory, use the PKL directory.
                pkl_dir = os.path.dirname(pkl_abs_path)
                output_dir = os.path.join(pkl_dir, "step2_js")
                logger.info(
                    f"No matching data directory found; using PKL directory: {pkl_dir}"
                )
        else:
            # If no GRASP directory found, use the PKL directory.
            pkl_dir = os.path.dirname(pkl_abs_path)
            output_dir = os.path.join(pkl_dir, "step2_js")
            logger.info(
                f"GRASP directory not found in path; using PKL directory: {pkl_dir}"
            )

    os.makedirs(output_dir, exist_ok=True)

    # Create visualization directory.
    vis_dir = f"{output_dir}/visualization"
    os.makedirs(vis_dir, exist_ok=True)

    # Analyze transcript distribution (optional diagnostics).
    logger.info("Analyzing transcript distribution (optional diagnostics)...")
    try:
        analyze_transcript_distribution(df, output_dir)
    except Exception as e:
        logger.warning(f"Transcript distribution analysis failed: {e}")

    # Stage 0: precompute all distance matrices globally.
    logger.info("Stage 0: precomputing all distance matrices")
    dist_dict = {}  # Global distance matrix dict {(cell, gene): dist_matrix}

    # Precompute all distance matrices using a thread pool.
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = {}

        # Submit all computation tasks.
        for gene in gene_list:
            gene_df = df[df["gene"] == gene]

            for cell in cell_list:
                cell_df = gene_df[gene_df["cell"] == cell]
                # Skip cells with insufficient transcripts.
                if len(cell_df) <= 1:
                    continue

                # Define a local function to compute distance matrices.
                # (cell_df is passed as an explicit argument to submit(), so
                # there is no late-binding issue despite the in-loop def.)
                def calc_dist_matrix(c_df):
                    positions = c_df[["x_c_s", "y_c_s"]].values
                    return distance_matrix(positions, positions)

                # Submit task.
                futures[(cell, gene)] = executor.submit(calc_dist_matrix, cell_df)

        # Collect results.
        for (cell, gene), future in tqdm(
            futures.items(),
            desc="Precomputing distance matrices",
            disable=not sys.stdout.isatty(),
        ):
            try:
                dist_dict[(cell, gene)] = future.result()
            except Exception as e:
                logger.error(
                    f"Failed to compute distance matrix for cell={cell}, gene={gene}: {e}"
                )

    logger.info(f"Distance matrix precompute done; {len(dist_dict)} (cell,gene) pairs")

    # Auto-select parameters based on precomputed distance matrices.
    if auto_r_min or auto_r_max or auto_bin_size:
        logger.info("Auto-selecting parameters from precomputed distance matrices")
        all_dists = []

        # Collect all non-zero distances (no sampling).
        for key, dmat in tqdm(
            dist_dict.items(),
            desc="Collecting distance samples",
            disable=not sys.stdout.isatty(),
        ):
            # Upper triangle (exclude self distances).
            triu_indices = np.triu_indices_from(dmat, k=1)
            dists = dmat[triu_indices]
            # Exclude zero and infinite distances.
            valid_dists = dists[(dists > 0) & (np.isfinite(dists))]
            if len(valid_dists) > 0:
                all_dists.append(valid_dists)

        if all_dists:
            all_dists = np.concatenate(all_dists)
            logger.info(f"Collected {len(all_dists)} valid distance values")

            # Percentile-based parameter estimation.
            if auto_r_min:
                r_min = float(np.percentile(all_dists, min_percentile))
                r_min = round(r_min, 2)  # keep two decimals
                logger.info(
                    f"Auto-set r_min = {r_min:.2f} ({min_percentile}% percentile)"
                )

            if auto_r_max:
                r_max = float(np.percentile(all_dists, max_percentile))
                r_max = round(r_max, 2)  # keep two decimals
                logger.info(
                    f"Auto-set r_max = {r_max:.2f} ({max_percentile}% percentile)"
                )

            if auto_bin_size:
                # Set bin_size = (r_max - r_min) / n_bins.
                bin_size = float((r_max - r_min) / n_bins)
                bin_size = round(bin_size, 2)  # keep two decimals
                bin_size = max(0.01, bin_size)  # ensure it is not too small
                logger.info(
                    f"Auto-set bin_size = {bin_size:.2f} (1/{n_bins} of distance range)"
                )
        else:
            logger.warning("Could not collect valid distance samples; using defaults")
            if auto_r_min:
                r_min = 0.01
            if auto_r_max:
                r_max = 0.6
            if auto_bin_size:
                bin_size = 0.01

    # Ensure parameters have reasonable defaults.
    if r_min is None:
        r_min = 0.01
    if r_max is None:
        r_max = 0.6
    if bin_size is None:
        bin_size = 0.01

    logger.info(
        f"Final params: r_min={r_min:.2f}, r_max={r_max:.2f}, bin_size={bin_size:.2f}"
    )

    # Stage 1: precompute all network portraits (including auto-selecting r).
    logger.info("Stage 1: precomputing network portraits")
    portraits = {}

    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Submit gene-level precompute tasks.
        futures = {
            executor.submit(
                precompute_portraits_for_gene,
                gene,
                df,
                cell_list,
                threshold,
                bin_size,
                r_min,
                r_max,
                r_step,
                use_same_r,
                use_vectorized,
                dist_dict,  # pass the global distance matrix dict
            ): gene
            for gene in gene_list
        }

        # Collect results.
        for future in tqdm(
            as_completed(futures),
            total=len(futures),
            desc="Precomputing network portraits",
            disable=not sys.stdout.isatty(),
        ):
            gene = futures[future]
            try:
                gene_portraits = future.result()
                portraits.update(gene_portraits)
                logger.info(
                    f"Gene {gene} precompute done; {len(gene_portraits)} distributions"
                )
            except Exception as e:
                logger.error(f"Gene {gene} precompute failed: {e}")

    logger.info(f"Network portrait precompute done; {len(portraits)} (cell,gene) pairs")

    # Stage 2: compute JS divergence.
    logger.info("Stage 2: computing JS divergence")
    all_distances = []

    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Submit gene-level JS divergence tasks. Portraits are only read at
        # this point, so sharing the dict across worker threads is fine.
        futures = {
            executor.submit(
                find_js_distances_for_gene,
                gene,
                df,
                cell_list,
                portraits,
                bin_size,
                max_count,
                transcript_window,
            ): gene
            for gene in gene_list
        }

        # Collect results.
        for future in tqdm(
            as_completed(futures),
            total=len(futures),
            desc="Computing JS divergence",
            disable=not sys.stdout.isatty(),
        ):
            gene = futures[future]
            try:
                gene_distances = future.result()
                all_distances.extend(gene_distances)
                logger.info(
                    f"Gene {gene} JS divergence done; {len(gene_distances)} results"
                )
            except Exception as e:
                logger.error(f"Gene {gene} JS divergence failed: {e}")

    # Clean up distance matrices to free memory.
    dist_dict.clear()

    # Save results to a DataFrame.
    if all_distances:
        distances_df = pd.DataFrame(
            all_distances,
            columns=[
                "target_cell",
                "target_gene",
                "cell",
                "gene",
                "num_transcripts",
                "js_distance",
                "transcript_diff",
                "target_r",
                "other_r",
            ],
        )

        # Save results.
        # NOTE(review): when max_count is None the filename contains
        # "countNone" — confirm downstream consumers tolerate this.
        output_path = f"{output_dir}/js_distances_bin{bin_size:.4f}_count{max_count}_threshold{threshold}.csv"
        distances_df.to_csv(output_path, index=False)

        logger.info(f"\nDone. Results saved to: {output_path}")
        logger.info(f"Computed {len(distances_df)} JS divergence records")

        # Visualize top-N most similar pairs (smallest JS divergence).
        if visualize_top_n > 0:
            logger.info(
                f"Visualizing top {visualize_top_n} most similar pairs by JS divergence"
            )
            visualize_most_similar_pairs(
                df, distances_df, portraits, visualize_top_n, vis_dir, use_vectorized
            )

        # Total elapsed time.
        total_time = time.time() - start_time
        logger.info(f"Total time: {total_time:.2f}s")

        return distances_df
    else:
        logger.warning("WARNING: no JS divergences were computed")
        return pd.DataFrame()
1235
+
1236
+
1237
def visualize_most_similar_pairs(
    df, distances_df, portraits, top_n=5, output_dir=None, use_vectorized=True
):
    """
    Visualize the top-N most similar cell pairs by JS divergence.

    For each pair, the graphs are rebuilt from the raw transcripts with the r
    values recorded in ``distances_df``, and three figures per cell are
    produced: graph structure, network portrait, and transcript scatter.
    When ``output_dir`` is set, figures are saved into one sub-directory per
    pair; otherwise they are shown interactively (the plot helpers display
    the figure when ``save_path`` is None).

    Args:
        df: Transcript DataFrame.
        distances_df: JS divergence results DataFrame.
        portraits: Precomputed portrait dict. NOTE(review): currently unused —
            portraits are recomputed below; kept for API compatibility.
        top_n: Number of pairs to visualize.
        output_dir: Output directory; None means interactive display.
        use_vectorized: Whether to use vectorized portrait computation.
    """
    # Sort ascending: smallest JS divergence == most similar pairs first.
    sorted_df = distances_df.sort_values("js_distance").reset_index(drop=True)

    # Ensure output directory exists.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Visualize the top-N pairs.
    for i in range(min(top_n, len(sorted_df))):
        row = sorted_df.iloc[i]

        target_cell = row["target_cell"]
        target_gene = row["target_gene"]
        other_cell = row["cell"]
        other_gene = row["gene"]
        js_dist = row["js_distance"]
        target_r = row["target_r"]
        other_r = row["other_r"]

        logger.info(
            f"Rank {i + 1} most similar pair: {target_cell}:{target_gene} - {other_cell}:{other_gene}, JS={js_dist:.4f}"
        )

        # Extract transcript data for both members of the pair.
        target_df = df[(df["cell"] == target_cell) & (df["gene"] == target_gene)]
        other_df = df[(df["cell"] == other_cell) & (df["gene"] == other_gene)]

        # Skip if there are too few transcripts to build a graph.
        if len(target_df) <= 1 or len(other_df) <= 1:
            logger.warning(
                f"Pair {target_cell}:{target_gene} - {other_cell}:{other_gene} has too few transcripts; skipping visualization"
            )
            continue

        # Rebuild graphs with the r values recorded in the results.
        target_graph = build_weighted_graph(target_df, target_r)
        other_graph = build_weighted_graph(other_df, other_r)

        # Recompute network portraits for plotting.
        # NOTE(review): bin_size is fixed at 0.01 here and may differ from
        # the bin_size used upstream — confirm whether it should be plumbed
        # through.
        target_portrait, _ = get_network_portrait(
            target_graph, bin_size=0.01, use_vectorized=use_vectorized
        )
        other_portrait, _ = get_network_portrait(
            other_graph, bin_size=0.01, use_vectorized=use_vectorized
        )

        # Create per-pair output directory (None -> interactive display).
        pair_dir = None
        if output_dir:
            pair_dir = f"{output_dir}/pair_{i + 1}_js{js_dist:.4f}"
            os.makedirs(pair_dir, exist_ok=True)

        def _save_path(filename):
            # Resolve the save path for a figure, or None to trigger
            # interactive display in the plot helpers.
            return f"{pair_dir}/{filename}" if pair_dir else None

        pair_prefix = f"Rank {i + 1} JS={js_dist:.4f}: "

        # Graph structure.
        plot_graph(
            target_graph,
            title=f"{pair_prefix}{target_cell}:{target_gene} (r={target_r:.2f})",
            save_path=_save_path("cell1_graph.png"),
        )
        plot_graph(
            other_graph,
            title=f"{pair_prefix}{other_cell}:{other_gene} (r={other_r:.2f})",
            save_path=_save_path("cell2_graph.png"),
        )

        # Portraits.
        plot_portrait(
            target_portrait,
            title=f"{pair_prefix}{target_cell}:{target_gene} Network portrait",
            save_path=_save_path("cell1_portrait.png"),
        )
        plot_portrait(
            other_portrait,
            title=f"{pair_prefix}{other_cell}:{other_gene} Network portrait",
            save_path=_save_path("cell2_portrait.png"),
        )

        # Transcript scatter plots.
        plot_transcripts(
            target_df,
            target_r,
            title=f"{pair_prefix}{target_cell}:{target_gene} Transcripts",
            save_path=_save_path("cell1_transcripts.png"),
        )
        plot_transcripts(
            other_df,
            other_r,
            title=f"{pair_prefix}{other_cell}:{other_gene} Transcripts",
            save_path=_save_path("cell2_transcripts.png"),
        )

        if pair_dir:
            logger.info(f"Saved pair visualization to: {pair_dir}")
1378
+
1379
+
1380
def plot_transcripts(df, r, title="Transcript distribution", save_path=None):
    """
    Plot transcript spatial distribution.

    Draws a scatter of the transcript coordinates and a dashed circle of
    radius ``r`` around each transcript to illustrate the connection radius.

    Args:
        df: Transcript DataFrame with "x_c_s" and "y_c_s" columns.
        r: Connection radius (same units as the coordinates).
        title: Plot title.
        save_path: If given, save the figure (dpi=300) and close it;
            otherwise display it interactively.
    """
    plt.figure(figsize=(10, 8))

    # Plot transcript locations.
    plt.scatter(df["x_c_s"], df["y_c_s"], alpha=0.6, s=10)

    # Add a radius circle for each transcript.
    # zip over the two columns instead of df.iterrows(): iterrows builds a
    # Series per row and is markedly slower on large transcript sets.
    ax = plt.gca()
    for x, y in zip(df["x_c_s"], df["y_c_s"]):
        circle = plt.Circle(
            (x, y),
            r,
            fill=False,
            color="gray",
            alpha=0.2,
            linestyle="--",
        )
        ax.add_patch(circle)

    plt.xlabel("X coordinate")
    plt.ylabel("Y coordinate")
    plt.title(title)
    plt.axis("equal")
    plt.grid(True, alpha=0.3)

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        plt.close()
    else:
        plt.show()
1419
+
1420
+
1421
+ def analyze_transcript_distribution(df, output_dir=None):
1422
+ """
1423
+ Analyze and report transcript distribution.
1424
+
1425
+ Args:
1426
+ df: Transcript DataFrame.
1427
+ output_dir: Output directory. If None, only logs are emitted.
1428
+
1429
+ Returns:
1430
+ Dict: Summary statistics.
1431
+ """
1432
+ logger.info("Analyzing transcript distribution...")
1433
+
1434
+ # Basic stats.
1435
+ total_transcripts = len(df)
1436
+ unique_genes = df["gene"].nunique()
1437
+ unique_cells = df["cell"].nunique()
1438
+
1439
+ # Transcripts per gene.
1440
+ gene_transcript_counts = df.groupby("gene").size()
1441
+
1442
+ if not gene_transcript_counts.empty:
1443
+ gene_stats_values = {
1444
+ "transcript_per_gene_mean": float(gene_transcript_counts.mean()),
1445
+ "transcript_per_gene_median": float(gene_transcript_counts.median()),
1446
+ "transcript_per_gene_std": float(gene_transcript_counts.std())
1447
+ if not np.isnan(gene_transcript_counts.std())
1448
+ else None,
1449
+ "transcript_per_gene_min": int(gene_transcript_counts.min()),
1450
+ "transcript_per_gene_max": int(gene_transcript_counts.max()),
1451
+ }
1452
+ else:
1453
+ gene_stats_values = {
1454
+ "transcript_per_gene_mean": 0.0,
1455
+ "transcript_per_gene_median": 0.0,
1456
+ "transcript_per_gene_std": None,
1457
+ "transcript_per_gene_min": 0,
1458
+ "transcript_per_gene_max": 0,
1459
+ }
1460
+ gene_stats = {"total_genes": int(unique_genes), **gene_stats_values}
1461
+
1462
+ # Transcripts per cell.
1463
+ cell_transcript_counts = df.groupby("cell").size()
1464
+ if not cell_transcript_counts.empty:
1465
+ cell_stats_values = {
1466
+ "transcript_per_cell_mean": float(cell_transcript_counts.mean()),
1467
+ "transcript_per_cell_median": float(cell_transcript_counts.median()),
1468
+ "transcript_per_cell_std": float(cell_transcript_counts.std())
1469
+ if not np.isnan(cell_transcript_counts.std())
1470
+ else None,
1471
+ "transcript_per_cell_min": int(cell_transcript_counts.min()),
1472
+ "transcript_per_cell_max": int(cell_transcript_counts.max()),
1473
+ }
1474
+ else:
1475
+ cell_stats_values = {
1476
+ "transcript_per_cell_mean": 0.0,
1477
+ "transcript_per_cell_median": 0.0,
1478
+ "transcript_per_cell_std": None,
1479
+ "transcript_per_cell_min": 0,
1480
+ "transcript_per_cell_max": 0,
1481
+ }
1482
+ cell_stats = {"total_cells": int(unique_cells), **cell_stats_values}
1483
+
1484
+ # Transcripts per (cell, gene) pair.
1485
+ cell_gene_transcript_counts = df.groupby(["cell", "gene"]).size()
1486
+ if not cell_gene_transcript_counts.empty:
1487
+ pair_stats_values = {
1488
+ "transcript_per_pair_mean": float(cell_gene_transcript_counts.mean()),
1489
+ "transcript_per_pair_median": float(cell_gene_transcript_counts.median()),
1490
+ "transcript_per_pair_std": float(cell_gene_transcript_counts.std())
1491
+ if not np.isnan(cell_gene_transcript_counts.std())
1492
+ else None,
1493
+ }
1494
+ else:
1495
+ pair_stats_values = {
1496
+ "transcript_per_pair_mean": 0.0,
1497
+ "transcript_per_pair_median": 0.0,
1498
+ "transcript_per_pair_std": None,
1499
+ }
1500
+ pair_stats = {
1501
+ "total_cell_gene_pairs": int(
1502
+ len(cell_gene_transcript_counts)
1503
+ ), # len() returns python int
1504
+ **pair_stats_values,
1505
+ "single_transcript_pairs": int((cell_gene_transcript_counts == 1).sum()),
1506
+ "multi_transcript_pairs": int((cell_gene_transcript_counts > 1).sum()),
1507
+ }
1508
+
1509
+ # Count genes that are mostly single-transcript.
1510
+ genes_with_mostly_single_transcripts = 0
1511
+ if (
1512
+ unique_genes > 0 and not cell_gene_transcript_counts.empty
1513
+ ): # Avoid processing if no genes or no pairs
1514
+ for gene in df["gene"].unique():
1515
+ gene_pairs = cell_gene_transcript_counts[
1516
+ cell_gene_transcript_counts.index.get_level_values("gene") == gene
1517
+ ]
1518
+ if not gene_pairs.empty:
1519
+ single_transcript_ratio = (gene_pairs == 1).sum() / len(gene_pairs)
1520
+ if single_transcript_ratio > 0.8: # >80% pairs are single-transcript
1521
+ genes_with_mostly_single_transcripts += 1
1522
+
1523
+ problem_stats = {
1524
+ "genes_with_mostly_single_transcripts": int(
1525
+ genes_with_mostly_single_transcripts
1526
+ ),
1527
+ "problematic_gene_ratio": float(
1528
+ genes_with_mostly_single_transcripts / unique_genes
1529
+ )
1530
+ if unique_genes > 0
1531
+ else 0.0,
1532
+ "single_transcript_pair_ratio": float(
1533
+ pair_stats["single_transcript_pairs"] / pair_stats["total_cell_gene_pairs"]
1534
+ )
1535
+ if pair_stats["total_cell_gene_pairs"] > 0
1536
+ else 0.0,
1537
+ }
1538
+
1539
+ # Summary.
1540
+ stats_summary = {
1541
+ "total_transcripts": int(total_transcripts), # len() returns python int
1542
+ "gene_stats": gene_stats,
1543
+ "cell_stats": cell_stats,
1544
+ "pair_stats": pair_stats,
1545
+ "problem_stats": problem_stats,
1546
+ }
1547
+
1548
+ # Report.
1549
+ logger.info("=" * 60)
1550
+ logger.info("Transcript distribution report")
1551
+ logger.info("=" * 60)
1552
+ logger.info(f"Total transcripts: {total_transcripts:,}")
1553
+ logger.info(f"Unique genes: {unique_genes:,}")
1554
+ logger.info(f"Unique cells: {unique_cells:,}")
1555
+ logger.info("")
1556
+
1557
+ logger.info("Gene-level stats:")
1558
+ logger.info(
1559
+ f" Mean transcripts per gene: {gene_stats['transcript_per_gene_mean']:.1f}"
1560
+ )
1561
+ logger.info(f" Median: {gene_stats['transcript_per_gene_median']:.1f}")
1562
+ logger.info(
1563
+ f" Range: {gene_stats['transcript_per_gene_min']}-{gene_stats['transcript_per_gene_max']}"
1564
+ )
1565
+ logger.info("")
1566
+
1567
+ logger.info("Cell-level stats:")
1568
+ logger.info(
1569
+ f" Mean transcripts per cell: {cell_stats['transcript_per_cell_mean']:.1f}"
1570
+ )
1571
+ logger.info(f" Median: {cell_stats['transcript_per_cell_median']:.1f}")
1572
+ logger.info(
1573
+ f" Range: {cell_stats['transcript_per_cell_min']}-{cell_stats['transcript_per_cell_max']}"
1574
+ )
1575
+ logger.info("")
1576
+
1577
+ logger.info("(Cell, gene) pair-level stats:")
1578
+ logger.info(f" Total (cell, gene) pairs: {pair_stats['total_cell_gene_pairs']:,}")
1579
+ logger.info(
1580
+ f" Single-transcript pairs: {pair_stats['single_transcript_pairs']:,} ({pair_stats['single_transcript_pairs'] / pair_stats['total_cell_gene_pairs'] * 100:.1f}%)"
1581
+ )
1582
+ logger.info(
1583
+ f" Multi-transcript pairs: {pair_stats['multi_transcript_pairs']:,} ({pair_stats['multi_transcript_pairs'] / pair_stats['total_cell_gene_pairs'] * 100:.1f}%)"
1584
+ )
1585
+ logger.info(
1586
+ f" Mean transcripts per pair: {pair_stats['transcript_per_pair_mean']:.1f}"
1587
+ )
1588
+ logger.info("")
1589
+
1590
+ logger.info("Potential issues:")
1591
+ logger.info(
1592
+ f" Genes mostly single-transcript: {problem_stats['genes_with_mostly_single_transcripts']:,} ({problem_stats['problematic_gene_ratio'] * 100:.1f}%)"
1593
+ )
1594
+ logger.info(
1595
+ f" Single-transcript pair ratio: {problem_stats['single_transcript_pair_ratio'] * 100:.1f}%"
1596
+ )
1597
+
1598
+ if problem_stats["problematic_gene_ratio"] > 0.5:
1599
+ logger.warning(
1600
+ "WARNING: >50% of genes are mostly single-transcript; consider checking data quality or adjusting parameters"
1601
+ )
1602
+ elif problem_stats["single_transcript_pair_ratio"] > 0.7:
1603
+ logger.warning(
1604
+ "WARNING: >70% of (cell, gene) pairs have a single transcript; this may affect network portrait quality"
1605
+ )
1606
+ else:
1607
+ logger.info("Data quality looks OK for network portrait analysis")
1608
+
1609
+ logger.info("=" * 60)
1610
+
1611
+ # If output_dir is provided, write stats and plots.
1612
+ if output_dir:
1613
+ import json
1614
+
1615
+ os.makedirs(output_dir, exist_ok=True)
1616
+
1617
+ # Save stats.
1618
+ stats_file = os.path.join(output_dir, "transcript_distribution_stats.json")
1619
+ with open(stats_file, "w", encoding="utf-8") as f:
1620
+ json.dump(stats_summary, f, indent=2, ensure_ascii=False)
1621
+ logger.info(f"Saved detailed stats to: {stats_file}")
1622
+
1623
+ # Save distribution plots.
1624
+ plt.figure(figsize=(12, 8))
1625
+ plt.subplot(2, 2, 1)
1626
+ plt.hist(gene_transcript_counts, bins=50, alpha=0.7, edgecolor="black")
1627
+ plt.xlabel("Transcripts per gene")
1628
+ plt.ylabel("Number of genes")
1629
+ plt.title("Transcript count per gene")
1630
+ plt.yscale("log")
1631
+
1632
+ plt.subplot(2, 2, 2)
1633
+ plt.hist(cell_transcript_counts, bins=50, alpha=0.7, edgecolor="black")
1634
+ plt.xlabel("Transcripts per cell")
1635
+ plt.ylabel("Number of cells")
1636
+ plt.title("Transcript count per cell")
1637
+ plt.yscale("log")
1638
+
1639
+ plt.subplot(2, 2, 3)
1640
+ plt.hist(cell_gene_transcript_counts, bins=30, alpha=0.7, edgecolor="black")
1641
+ plt.xlabel("Transcripts per (cell, gene) pair")
1642
+ plt.ylabel("Number of pairs")
1643
+ plt.title("Transcript count per (cell, gene) pair")
1644
+ plt.yscale("log")
1645
+
1646
+ plt.subplot(2, 2, 4)
1647
+ # Single vs multi transcript pairs.
1648
+ labels = ["Single-transcript pairs", "Multi-transcript pairs"]
1649
+ sizes = [
1650
+ pair_stats["single_transcript_pairs"],
1651
+ pair_stats["multi_transcript_pairs"],
1652
+ ]
1653
+ plt.pie(sizes, labels=labels, autopct="%1.1f%%", startangle=90)
1654
+ plt.title("Single vs multi-transcript pairs")
1655
+
1656
+ plt.tight_layout()
1657
+ dist_plot_file = os.path.join(output_dir, "transcript_distribution_plots.png")
1658
+ plt.savefig(dist_plot_file, dpi=300, bbox_inches="tight")
1659
+ plt.close()
1660
+ logger.info(f"Saved distribution plots to: {dist_plot_file}")
1661
+
1662
+ return stats_summary
1663
+
1664
+
1665
+ if __name__ == "__main__":
1666
+ # CLI argument parsing.
1667
+ parser = argparse.ArgumentParser(
1668
+ description="Compute similarity between transcript graphs using JS divergence"
1669
+ )
1670
+ parser.add_argument(
1671
+ "--pkl_file",
1672
+ type=str,
1673
+ required=True,
1674
+ help="Path to a PKL containing df_registered",
1675
+ )
1676
+ parser.add_argument(
1677
+ "--output_dir", type=str, default=None, help="Output directory (optional)"
1678
+ )
1679
+ parser.add_argument(
1680
+ "--max_count", type=int, default=10, help="Max comparisons per target cell"
1681
+ )
1682
+ parser.add_argument(
1683
+ "--transcript_window",
1684
+ type=int,
1685
+ default=30,
1686
+ help="Transcript count difference window for candidate filtering",
1687
+ )
1688
+ parser.add_argument(
1689
+ "--bin_size",
1690
+ type=str,
1691
+ default="0.01",
1692
+ help='Path length bin size for JS divergence; set to "auto" to estimate',
1693
+ )
1694
+ parser.add_argument(
1695
+ "--threshold",
1696
+ type=float,
1697
+ default=0.05,
1698
+ help="Isolated node ratio threshold for selecting r",
1699
+ )
1700
+ parser.add_argument(
1701
+ "--r_min",
1702
+ type=str,
1703
+ default="0.01",
1704
+ help='Minimum r for search; set to "auto" to estimate',
1705
+ )
1706
+ parser.add_argument(
1707
+ "--r_max",
1708
+ type=str,
1709
+ default="0.6",
1710
+ help='Maximum r for search; set to "auto" to estimate',
1711
+ )
1712
+ parser.add_argument(
1713
+ "--r_step", type=float, default=0.03, help="Step size for r search"
1714
+ )
1715
+ parser.add_argument(
1716
+ "--num_threads",
1717
+ type=int,
1718
+ default=None,
1719
+ help="Number of threads (default: CPU count)",
1720
+ )
1721
+ parser.add_argument(
1722
+ "--use_same_r",
1723
+ action="store_true",
1724
+ help="Use a shared r for all cells within a gene (gene-level r)",
1725
+ )
1726
+ parser.add_argument(
1727
+ "--visualize_top_n",
1728
+ type=int,
1729
+ default=5,
1730
+ help="Visualize top-N most similar pairs by JS divergence (0 disables)",
1731
+ )
1732
+ parser.add_argument(
1733
+ "--log_level",
1734
+ type=str,
1735
+ default="INFO",
1736
+ choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
1737
+ help="Log level",
1738
+ )
1739
+ parser.add_argument(
1740
+ "--log_file",
1741
+ type=str,
1742
+ default="js_distance_transcriptome.log",
1743
+ help="Log filename",
1744
+ )
1745
+ parser.add_argument(
1746
+ "--no_vectorized",
1747
+ action="store_true",
1748
+ help="Disable vectorized portrait computation (slower, lower memory)",
1749
+ )
1750
+ parser.add_argument(
1751
+ "--filter_pkl_file",
1752
+ type=str,
1753
+ default=None,
1754
+ help="Optional filter PKL path (contains cell_labels/gene_labels)",
1755
+ )
1756
+ parser.add_argument(
1757
+ "--auto_params", action="store_true", help="Auto-set r_min, r_max, and bin_size"
1758
+ )
1759
+ parser.add_argument(
1760
+ "--n_bins",
1761
+ type=int,
1762
+ default=50,
1763
+ help="Number of bins when auto-setting bin_size",
1764
+ )
1765
+ parser.add_argument(
1766
+ "--min_percentile",
1767
+ type=float,
1768
+ default=1.0,
1769
+ help="Percentile used when auto-setting r_min",
1770
+ )
1771
+ parser.add_argument(
1772
+ "--max_percentile",
1773
+ type=float,
1774
+ default=99.0,
1775
+ help="Percentile used when auto-setting r_max",
1776
+ )
1777
+
1778
+ args = parser.parse_args()
1779
+
1780
+ # Set log level.
1781
+ logger.setLevel(getattr(logging, args.log_level))
1782
+
1783
+ # Add file handler.
1784
+ file_handler = logging.FileHandler(args.log_file)
1785
+ file_handler.setFormatter(
1786
+ logging.Formatter(
1787
+ "[%(asctime)s][%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
1788
+ )
1789
+ )
1790
+ logger.addHandler(file_handler)
1791
+
1792
+ # Track program runtime.
1793
+ program_start_time = time.time()
1794
+ logger.info("=" * 80)
1795
+ logger.info("Program started")
1796
+ logger.info(f"CLI args: {vars(args)}")
1797
+
1798
+ try:
1799
+ # Run main function.
1800
+ distances_df = calculate_js_distances(
1801
+ pkl_file=args.pkl_file,
1802
+ output_dir=args.output_dir,
1803
+ max_count=args.max_count,
1804
+ transcript_window=args.transcript_window,
1805
+ bin_size=args.bin_size,
1806
+ threshold=args.threshold,
1807
+ r_min=args.r_min,
1808
+ r_max=args.r_max,
1809
+ r_step=args.r_step,
1810
+ num_threads=args.num_threads,
1811
+ use_same_r=args.use_same_r,
1812
+ visualize_top_n=args.visualize_top_n,
1813
+ use_vectorized=(not args.no_vectorized),
1814
+ filter_pkl_file=args.filter_pkl_file,
1815
+ auto_params=args.auto_params,
1816
+ n_bins=args.n_bins,
1817
+ min_percentile=args.min_percentile,
1818
+ max_percentile=args.max_percentile,
1819
+ )
1820
+
1821
+ # Print result stats.
1822
+ if not distances_df.empty:
1823
+ logger.info("\nResult stats:")
1824
+ logger.info(f"Total distances: {len(distances_df)}")
1825
+ logger.info(
1826
+ f"JS range: [{distances_df['js_distance'].min():.4f}, {distances_df['js_distance'].max():.4f}]"
1827
+ )
1828
+ logger.info(f"Mean JS: {distances_df['js_distance'].mean():.4f}")
1829
+ logger.info(f"Median JS: {distances_df['js_distance'].median():.4f}")
1830
+ logger.info(
1831
+ f"r range: [{distances_df['target_r'].min():.2f}, {distances_df['target_r'].max():.2f}]"
1832
+ )
1833
+ logger.info(f"Mean r: {distances_df['target_r'].mean():.2f}")
1834
+
1835
+ # Preview first few rows.
1836
+ logger.info("\nResult preview:")
1837
+ logger.info(distances_df.head().to_string())
1838
+
1839
+ except Exception as e:
1840
+ logger.error(f"Program failed: {e}")
1841
+ import traceback
1842
+
1843
+ logger.error(traceback.format_exc())
1844
+ raise
1845
+
1846
+ # Compute and report total runtime.
1847
+ program_end_time = time.time()
1848
+ total_program_time = program_end_time - program_start_time
1849
+ hours, remainder = divmod(total_program_time, 3600)
1850
+ minutes, seconds = divmod(remainder, 60)
1851
+
1852
+ logger.info("=" * 80)
1853
+ logger.info(f"Program total runtime: {int(hours)}h {int(minutes)}m {seconds:.2f}s")
1854
+ logger.info(
1855
+ f"Start time: {datetime.fromtimestamp(program_start_time).strftime('%Y-%m-%d %H:%M:%S')}"
1856
+ )
1857
+ logger.info(
1858
+ f"End time: {datetime.fromtimestamp(program_end_time).strftime('%Y-%m-%d %H:%M:%S')}"
1859
+ )
1860
+ logger.info("=" * 80)
1861
+
1862
+ print("=" * 50)