grasp-tool 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_tool/__init__.py +17 -0
- grasp_tool/__main__.py +6 -0
- grasp_tool/cli/__init__.py +1 -0
- grasp_tool/cli/main.py +793 -0
- grasp_tool/cli/train_moco.py +778 -0
- grasp_tool/gnn/__init__.py +1 -0
- grasp_tool/gnn/embedding.py +165 -0
- grasp_tool/gnn/gat_moco_final.py +990 -0
- grasp_tool/gnn/graphloader.py +1748 -0
- grasp_tool/gnn/plot_refined.py +1556 -0
- grasp_tool/preprocessing/__init__.py +1 -0
- grasp_tool/preprocessing/augumentation.py +66 -0
- grasp_tool/preprocessing/cellplot.py +475 -0
- grasp_tool/preprocessing/filter.py +171 -0
- grasp_tool/preprocessing/network.py +79 -0
- grasp_tool/preprocessing/partition.py +654 -0
- grasp_tool/preprocessing/portrait.py +1862 -0
- grasp_tool/preprocessing/register.py +1021 -0
- grasp_tool-0.1.0.dist-info/METADATA +511 -0
- grasp_tool-0.1.0.dist-info/RECORD +22 -0
- grasp_tool-0.1.0.dist-info/WHEEL +4 -0
- grasp_tool-0.1.0.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,1862 @@
|
|
|
1
|
+
"""Network Portrait + Jensen-Shannon distance for transcript graphs.
|
|
2
|
+
|
|
3
|
+
This module computes pairwise similarity between per-cell/per-gene transcript
|
|
4
|
+
graphs using a network-portrait representation and Jensen-Shannon divergence.
|
|
5
|
+
|
|
6
|
+
Input
|
|
7
|
+
A PKL that contains a DataFrame with transcript coordinates (typically the
|
|
8
|
+
registered output with `df_registered`). The DataFrame is expected to contain:
|
|
9
|
+
- cell, gene, x_c_s, y_c_s
|
|
10
|
+
|
|
11
|
+
Output
|
|
12
|
+
A CSV of JS distances that can be used as an optional signal for selecting
|
|
13
|
+
positive samples during training.
|
|
14
|
+
"""
|
|
15
|
+
## python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl/simulated_data1_data_dict.pkl --use_same_r --log_file /lustre/home/1910305118/data/GCN_CL/0_code/logs_simulated_data1/js_scipy_auto.log --visualize_top_n 0 --auto_params
|
|
16
|
+
## python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl/simulated_data1_data_dict.pkl --use_same_r --visualize_top_n 10 --log_file /lustre/home/1910305118/data/GCN_CL/0_code/logs_simulated_data1/js2.log
|
|
17
|
+
## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/seqfish_fibroblast_data_dict.pkl --use_same_r --log_file /lustre/home/1910305118/data/GCN_CL/0_code/logs_seqfish_fibroblast/js_scipy.log --visualize_top_n 0 --filter_pkl_file /lustre/home/1910305118/data/GCN_CL/5_graph_data/seqfish_fibroblast_cell171_gene2734_graph143789.pkl 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/logs_seqfish_fibroblast/js_scipy_nohup.log &
|
|
18
|
+
## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/seqfish_cortex_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/logs_seqfish_fibroblast/js_scipy_cortex.log --visualize_top_n 0 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/logs_seqfish_fibroblast/js_scipy_cortex_nohup.log &
|
|
19
|
+
## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/seqfish_cortex_data_dict_new.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/logs_seqfish_fibroblast/js_scipy_cortex2.log --visualize_top_n 0 --filter_pkl_file /lustre/home/1910305118/data/GCN_CL/5_graph_data/seqfish_cortex_new_cell708_gene93_graph708.pkl 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/logs_seqfish_fibroblast/js_scipy_cortex_nohup2.log &
|
|
20
|
+
## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/seqfish_cortex_data_dict_new1.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/logs_seqfish_fibroblast/js_scipy_cortex_new1.log --output_dir /lustre/home/1910305118/data/GCN_CL/1_input/seqfish_cortex_sub19_portrait --visualize_top_n 0 --filter_pkl_file /lustre/home/1910305118/data/GCN_CL/5_graph_data/seqfish_cortex_new1_cell2405_gene92_graph2405.pkl 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/logs_seqfish_fibroblast/js_scipy_cortex_new1.log &
|
|
21
|
+
## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/merscope_liver_data2_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/logs_merfish_liver/js_scipy_merfish_data3.log --output_dir /lustre/home/1910305118/data/GCN_CL/1_input/merfish_liver_data3_portrait --visualize_top_n 0 --filter_pkl_file /lustre/home/1910305118/data/GCN_CL/5_graph_data/merscope_liver_data3_cell281_gene139_graph13488.pkl 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/logs_merfish_liver/js_scipy_merfish_data3.log &
|
|
22
|
+
## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/merscope_liver_data2_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/logs_merfish_liver/js_scipy_merfish_data4.log --output_dir /lustre/home/1910305118/data/GCN_CL/1_input/merfish_liver_data4_portrait --visualize_top_n 0 --filter_pkl_file /lustre/home/1910305118/data/GCN_CL/5_graph_data/merscope_liver_data4_cell919_gene176_graph43931.pkl 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/logs_merfish_liver/js_scipy_merfish_data4.log &
|
|
23
|
+
## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/merscope_liver_data2_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_data4_central.log --output_dir /lustre/home/1910305118/data/GCN_CL/1_input/merfish_liver_data4_central_portrait --visualize_top_n 0 --filter_pkl_file /lustre/home/1910305118/data/GCN_CL/5_graph_data/merscope_liver_data4_central_cell182_gene106_graph9770.pkl 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_data4_central.log &
|
|
24
|
+
## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/seqfish_cortex_data_dict_sub19.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/logs_seqfish_fibroblast/js_scipy_cortex_sub19.log --output_dir /lustre/home/1910305118/data/GCN_CL/1_input/seqfish_cortex_sub19_portrait --visualize_top_n 0 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/logs_seqfish_fibroblast/js_scipy_cortex_sub19.log &
|
|
25
|
+
## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/seqfish_cortex_Astrocytes_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_seqfish_fibroblast/js_scipy_cortex_Astrocytes.log --filter_pkl_file seqfish_cortex_Astrocytes_cell528_gene22_graph528.pkl --output_dir /lustre/home/1910305118/data/GCN_CL/1_input/seqfish_cortex_Astrocytes_portrait --visualize_top_n 0 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_seqfish_fibroblast/js_scipy_cortex_Astrocytes.log &
|
|
26
|
+
## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/seqfish_cortex_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_seqfish_fibroblast/js_scipy_seqfish_plus_all.log --output_dir /lustre/home/1910305118/data/GCN_CL/1_input/seqfish_plus_all_portrait --visualize_top_n 0 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_seqfish_fibroblast/js_scipy_seqfish_plus_all.log &
|
|
27
|
+
## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/merscope_liver_data2_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_data4_protal.log --output_dir /lustre/home/1910305118/data/GCN_CL/1_input/merfish_liver_data4_protal_portrait --visualize_top_n 0 --filter_pkl_file /lustre/home/1910305118/data/GCN_CL/5_graph_data/merscope_liver_data4_protal_cell454_gene124_graph21222.pkl 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_data4_protal.log &
|
|
28
|
+
|
|
29
|
+
## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/merscope_liver_data2_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_data4_central_bigger.log --output_dir /lustre/home/1910305118/data/GCN_CL/1_input/merfish_liver_data4_central_bigger_portrait --visualize_top_n 0 --filter_pkl_file /lustre/home/1910305118/data/GCN_CL/5_graph_data/merscope_liver_data_central_cell870_gene124_graph45025.pkl 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_data4_central_bigger.log &
|
|
30
|
+
## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/merscope_liver_data2_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_data4_protal_bigger.log --output_dir /lustre/home/1910305118/data/GCN_CL/1_input/merfish_liver_data4_protal_bigger_portrait --visualize_top_n 0 --filter_pkl_file /lustre/home/1910305118/data/GCN_CL/5_graph_data/merscope_liver_data_protal_cell1713_gene143_graph79975.pkl 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_data4_protal_bigger.log &
|
|
31
|
+
|
|
32
|
+
## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/merscope_intestine_Enterocyte_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_merfish_intestine/js_scipy_merfish_Enterocyte.log --output_dir /lustre/home/1910305118/data/GCN_CL/1_input/merfish_intestine_Enterocyte_portrait --visualize_top_n 0 --filter_pkl_file /lustre/home/1910305118/data/GCN_CL/5_graph_data/merscope_intestine_Enterocyte_cell419_gene17_graph905.pkl 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_merfish_intestine/js_scipy_merfish_Enterocyte.log &
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
## nohup python portrait.py --pkl_file /lustre/home/1910305118/data/GCN_CL/1_input/pkl_data/simulated1_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_simulated1/js_scipy_simulated1.log --output_dir /lustre/home/1910305118/data/GCN_CL/6_analysis/js_portrait/simulated1_portrait --visualize_top_n 0 --filter_pkl_file /lustre/home/1910305118/data/GCN_CL/5.1_graph_data/simulated1_cell10_gene80_graph800_weight.pkl 2>&1 > /lustre/home/1910305118/data/GCN_CL/0_code/0.5_logs/logs_simulated1/js_scipy_simulated1.log &
|
|
36
|
+
|
|
37
|
+
## nohup python portrait.py --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/seqfish_fibroblast_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_seqfish_fibroblast/js_scipy.log --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/seqfish_fibroblast_portrait --visualize_top_n 0 --filter_pkl_file /home/lixiangyu/hyy/GRASP/5.1_graph_data/seqfish_fibroblast_cell171_gene59_graph8068.pkl 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_seqfish_fibroblast/js_scipy_nohup.log &
|
|
38
|
+
|
|
39
|
+
## nohup python portrait.py --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/simulated3_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_simulated/js_scipy_simulated3.log --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/simulated3_portrait --visualize_top_n 0 --filter_pkl_file /home/lixiangyu/hyy/GRASP/5.1_graph_data/simulated3_original_cell50_gene400_graph15000.pkl 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_simulated/js_scipy_simulated3.log &
|
|
40
|
+
|
|
41
|
+
## nohup python portrait.py --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/merscope_liver_data_region1_portal_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_region1_portal.log --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/merfish_liver_region1_portal --visualize_top_n 0 --filter_pkl_file /home/lixiangyu/hyy/GRASP/5_graph_data/merscope_liver_data_region1_portal_cell1708_gene143_graph79172.pkl 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_region1_portal.log &
|
|
42
|
+
|
|
43
|
+
## nohup python portrait.py --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/merscope_liver_data_region1_central_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_region1_central.log --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/merfish_liver_region1_central --visualize_top_n 0 --filter_pkl_file /home/lixiangyu/hyy/GRASP/5_graph_data/merscope_liver_data_region1_central_cell1711_gene136_graph86074.pkl 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_region1_central.log &
|
|
44
|
+
|
|
45
|
+
## nohup python portrait.py --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/merscope_liver_data_region2_central_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_region2_central.log --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/merfish_liver_region2_central --visualize_top_n 0 --filter_pkl_file /home/lixiangyu/hyy/GRASP/5_graph_data/merscope_liver_data_region2_central_cell936_gene126_graph44910.pkl 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_region2_central_nohup.log &
|
|
46
|
+
|
|
47
|
+
## nohup python portrait.py --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/merscope_liver_data_region2_portal_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_region2_portal.log --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/merfish_liver_region2_portal --visualize_top_n 0 --filter_pkl_file /home/lixiangyu/hyy/GRASP/5_graph_data/merscope_liver_data_region2_portal_cell747_gene124_graph33523.pkl 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_liver/js_scipy_merfish_region2_portal_nohup.log &
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
## nohup python portrait.py --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/merfish_intestine_Enterocyte_resegment_new_data_dict.pkl --use_same_r --max_count 30 --auto_params --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_intestine/js_scipy_enterocyte_resegment_new.log --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/merfish_intestine_enterocyte_resegment_new_portrait --visualize_top_n 0 --filter_pkl_file /home/lixiangyu/hyy/GRASP/5_graph_data/merfish_intestine_Enterocyte_resegment_new_cell688_gene58_graph4331.pkl 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_intestine/js_scipy_enterocyte_resegment_new_nohup.log &
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# # ---------- Group 1 ----------
|
|
54
|
+
# nohup python portrait.py \
|
|
55
|
+
# --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/merfish_u2os_data_dict.pkl \
|
|
56
|
+
# --use_same_r --max_count 30 --auto_params \
|
|
57
|
+
# --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_u2os/js_group1.log \
|
|
58
|
+
# --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/merfish_u2os_group1_portrait \
|
|
59
|
+
# --visualize_top_n 0 \
|
|
60
|
+
# --filter_pkl_file /home/lixiangyu/hyy/GRASP/5_graph_data/merfish_u2os_cell634_gene25_graph1000.pkl \
|
|
61
|
+
# 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_u2os/js_group1_nohup.log &
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# # ---------- Group 2 ----------
|
|
65
|
+
# nohup python portrait.py \
|
|
66
|
+
# --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/merfish_u2os_data_dict.pkl \
|
|
67
|
+
# --use_same_r --max_count 30 --auto_params \
|
|
68
|
+
# --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_u2os/js_group2.log \
|
|
69
|
+
# --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/merfish_u2os_group2_portrait \
|
|
70
|
+
# --visualize_top_n 0 \
|
|
71
|
+
# --filter_pkl_file /home/lixiangyu/hyy/GRASP/5_graph_data/merfish_u2os_cell621_gene25_graph1000.pkl \
|
|
72
|
+
# 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_u2os/js_group2_nohup.log &
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# # ---------- Group 3 ----------
|
|
76
|
+
# nohup python portrait.py \
|
|
77
|
+
# --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/merfish_u2os_data_dict.pkl \
|
|
78
|
+
# --use_same_r --max_count 30 --auto_params \
|
|
79
|
+
# --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_u2os/js_group3.log \
|
|
80
|
+
# --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/merfish_u2os_group3_portrait \
|
|
81
|
+
# --visualize_top_n 0 \
|
|
82
|
+
# --filter_pkl_file /home/lixiangyu/hyy/GRASP/5_graph_data/merfish_u2os_cell629_gene25_graph1000.pkl \
|
|
83
|
+
# 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_u2os/js_group3_nohup.log &
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
# # ---------- Group 4 ----------
|
|
87
|
+
# nohup python portrait.py \
|
|
88
|
+
# --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/merfish_u2os_data_dict.pkl \
|
|
89
|
+
# --use_same_r --max_count 30 --auto_params \
|
|
90
|
+
# --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_u2os/js_group4.log \
|
|
91
|
+
# --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/merfish_u2os_group4_portrait \
|
|
92
|
+
# --visualize_top_n 0 \
|
|
93
|
+
# --filter_pkl_file /home/lixiangyu/hyy/GRASP/5_graph_data/merfish_u2os_cell621_gene9_graph947.pkl \
|
|
94
|
+
# 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_u2os/js_group4_nohup.log &
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
# # ---------- Group 5 ----------
|
|
98
|
+
# nohup python portrait.py \
|
|
99
|
+
# --pkl_file /home/lixiangyu/hyy/GRASP/1_input/pkl_data/merfish_u2os_data_dict.pkl \
|
|
100
|
+
# --use_same_r --max_count 30 --auto_params \
|
|
101
|
+
# --log_file /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_u2os/js_group5.log \
|
|
102
|
+
# --output_dir /home/lixiangyu/hyy/GRASP/6_analysis/js_portrait/merfish_u2os_group5_portrait \
|
|
103
|
+
# --visualize_top_n 0 \
|
|
104
|
+
# --filter_pkl_file /home/lixiangyu/hyy/GRASP/5_graph_data/merfish_u2os_cell989_gene25_graph23242.pkl \
|
|
105
|
+
# 2>&1 > /home/lixiangyu/hyy/GRASP/0_code/0.5_logs/logs_merfish_u2os/js_group5_nohup.log &
|
|
106
|
+
|
|
107
|
+
## Legacy note: derived from utils_code/portrait.py
|
|
108
|
+
import pandas as pd
|
|
109
|
+
import numpy as np
|
|
110
|
+
import networkx as nx
|
|
111
|
+
|
|
112
|
+
from scipy.spatial import distance_matrix
|
|
113
|
+
import time
|
|
114
|
+
from datetime import datetime
|
|
115
|
+
import os
|
|
116
|
+
import logging
|
|
117
|
+
import argparse
|
|
118
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
119
|
+
import warnings
|
|
120
|
+
import matplotlib.pyplot as plt
|
|
121
|
+
import seaborn as sns
|
|
122
|
+
from tqdm.auto import tqdm
|
|
123
|
+
import pickle
|
|
124
|
+
from scipy.sparse import coo_matrix, csr_matrix
|
|
125
|
+
from scipy.sparse.csgraph import shortest_path
|
|
126
|
+
from scipy.spatial import KDTree
|
|
127
|
+
import random
|
|
128
|
+
import sys
|
|
129
|
+
|
|
130
|
+
# Matplotlib defaults (keep output deterministic and readable).
# NOTE(review): assumes the Arial font is available on the host; Matplotlib
# falls back to its default family with a warning otherwise — confirm on the
# target cluster.
plt.rcParams["font.family"] = ["Arial"]
plt.rcParams["font.sans-serif"] = ["Arial"]
plt.rcParams["axes.unicode_minus"] = False

# Logging: configure the root handler once at import time; all module code
# logs through the "js_distance" logger below.
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s][%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("js_distance")

# Warnings: silence runtime/user warnings (e.g. from NumPy divisions and
# plotting backends) so long batch logs stay readable.
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def find_r_for_isolated_threshold(
    df, threshold=0.05, r_min=0.01, r_max=0.6, step=0.03, verbose=False, dists=None
):
    """Find a connection radius keeping the isolated-node ratio under a threshold.

    A node is "isolated" at radius r when its nearest neighbor is farther than
    r away. The radius is chosen as the (1 - threshold) percentile of the
    nearest-neighbor distances (so at most a `threshold` fraction of nodes
    stay isolated), then clamped to [r_min, r_max].

    Args:
        df: DataFrame with transcript coordinates in columns `x_c_s` / `y_c_s`.
        threshold: Maximum allowed fraction of isolated nodes.
        r_min: Lower clamp for the returned radius.
        r_max: Upper clamp for the returned radius.
        step: Unused; kept for backward compatibility with the old grid search.
        verbose: If True, log the chosen radius and the realized isolated ratio.
        dists: Unused; kept for backward compatibility. Nearest-neighbor
            distances are always recomputed from a KDTree. (The original code
            silently shadowed this parameter with the KDTree query result.)

    Returns:
        float: The selected connection radius.
    """
    positions = df[["x_c_s", "y_c_s"]].values
    n_points = len(df)

    # Edge case: a single point (or empty frame) has no neighbor distances;
    # fall back to a small default radius.
    if n_points <= 1:
        return min(r_min * 5, r_max * 0.1)

    # Nearest-neighbor distances via KDTree. Query with k=2 because the
    # closest hit for each query point is the point itself (distance 0).
    tree = KDTree(positions)
    knn_dists, _ = tree.query(positions, k=2)
    nn_dists = knn_dists[:, 1]

    # Percentile such that at most `threshold` of the nodes remain isolated.
    r = np.percentile(nn_dists, 100 * (1 - threshold))

    # Clamp to the configured range.
    r = max(r_min, min(r, r_max))

    if verbose:
        logger.debug(
            f"KDTree-based r={r:.4f} (percentile={100 * (1 - threshold):.1f} of NN distances)"
        )
        # Sanity-check the realized isolated ratio at the chosen radius.
        isolated_count = np.sum(nn_dists > r)
        logger.debug(f"isolated_ratio={isolated_count / n_points:.4f} ({isolated_count}/{n_points})")

    return r
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def build_weighted_graph(df, r, dists=None):
    """Build a weighted transcript graph.

    Nodes are transcripts (labeled 0..len(df)-1, with a `pos` attribute for
    later plotting) and edges connect every pair within Euclidean distance
    `r`. Edge weights are the Euclidean distances themselves.

    Args:
        df: DataFrame with coordinates in columns `x_c_s` / `y_c_s`.
        r: Connection radius (inclusive).
        dists: Optional precomputed pairwise distance matrix; computed with
            scipy.spatial.distance_matrix when omitted.

    Returns:
        networkx.Graph: The radius graph.
    """
    G = nx.Graph()
    positions = df[["x_c_s", "y_c_s"]].values
    n = len(df)

    # Nodes keep their coordinates for plot_graph().
    for i, pos in enumerate(positions):
        G.add_node(i, pos=pos)

    # Compute pairwise distances if not provided.
    if dists is None:
        dists = distance_matrix(positions, positions)

    # Vectorized edge selection over the upper triangle (i < j) instead of a
    # Python double loop: the O(n^2) scan stays inside NumPy.
    ii, jj = np.triu_indices(n, k=1)
    pair_d = dists[ii, jj]
    within = pair_d <= r
    G.add_weighted_edges_from(zip(ii[within], jj[within], pair_d[within]))

    return G
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def get_network_portrait(G, bin_size=0.01, use_vectorized=True):
    """Compute a weighted network portrait for graph G.

    The portrait maps (path-length bin l, source-node degree k) to the number
    of ordered node pairs (i, j), i != j, whose shortest-path distance falls
    into bin l and whose source node i has degree k. Unreachable pairs
    (different connected components) are excluded by both implementations.

    Args:
        G: Weighted graph with integer node labels 0..n-1 and edge attribute
            "weight" (missing weights default to 1.0).
        bin_size: Width of the shortest-path-distance bins.
        use_vectorized: Use the SciPy sparse all-pairs shortest path (fast,
            but O(n^2) dense memory) instead of the NetworkX Dijkstra loop.

    Returns:
        tuple: (portrait dict {(bin, degree): count}, number of nodes).
    """
    n_nodes = len(G)
    if n_nodes <= 1:
        return {}, n_nodes

    if use_vectorized:
        # Build a symmetric sparse adjacency matrix from the edge list.
        rows, cols, weights = [], [], []
        for u, v, data in G.edges(data=True):
            w = data.get("weight", 1.0)
            rows.append(u)
            cols.append(v)
            weights.append(w)
            rows.append(v)
            cols.append(u)
            weights.append(w)

        A = csr_matrix((weights, (rows, cols)), shape=(n_nodes, n_nodes))

        # All-pairs shortest paths: dense n x n matrix, inf for unreachable.
        dist_mat = shortest_path(A, directed=False, unweighted=False, method="auto")

        # Degrees ordered by node id so they align with the matrix rows.
        degs = np.array([d for _, d in sorted(G.degree(), key=lambda x: x[0])])

        # Keep ordered pairs with i != j.
        i_idx, j_idx = np.nonzero(~np.eye(n_nodes, dtype=bool))
        pair_dists = dist_mat[i_idx, j_idx]

        # BUGFIX: drop unreachable pairs. Their distance is inf, and
        # np.floor(inf / bin_size).astype(int) yields a garbage bin index;
        # the Dijkstra loop below never counts unreachable pairs either.
        finite = np.isfinite(pair_dists)
        pair_dists = pair_dists[finite]
        src_degs = degs[i_idx[finite]]

        # Bin the distances and count (bin, degree) pairs in one pass.
        bins = np.floor(pair_dists / bin_size).astype(int)
        combined = np.column_stack([bins, src_degs])
        unique_ck, counts = np.unique(combined, axis=0, return_counts=True)

        portrait = {
            (int(bin_l), int(deg)): int(cnt)
            for (bin_l, deg), cnt in zip(unique_ck, counts)
        }
    else:
        # Loop implementation: slower, but uses less memory. Dijkstra only
        # yields reachable targets, so unreachable pairs are skipped here.
        length_dict = dict(nx.all_pairs_dijkstra_path_length(G, weight="weight"))
        degrees = dict(G.degree())

        portrait = {}
        for i in length_dict:
            for j in length_dict[i]:
                if i == j:
                    continue
                dist = length_dict[i][j]
                bin_l = int(dist // bin_size)
                deg = degrees[i]
                key = (bin_l, deg)
                portrait[key] = portrait.get(key, 0) + 1

    return portrait, n_nodes
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def compute_weighted_distribution(portrait, N):
    """Turn a portrait count dict into a degree-weighted probability dict.

    Each (path-length bin l, degree k) entry is weighted by k * count and
    normalized by the total number of ordered node pairs, N * N.
    """
    total_pairs = N * N
    return {
        (l, k): (k * count) / total_pairs
        for (l, k), count in portrait.items()
    }
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def js_divergence(P, Q):
    """Jensen-Shannon divergence (base 2) between two sparse distributions.

    P and Q are dicts mapping portrait keys to probabilities; keys missing
    from one distribution are treated as probability zero.
    """
    support = set(P) | set(Q)
    p = np.array([P.get(key, 0.0) for key in support])
    q = np.array([Q.get(key, 0.0) for key in support])
    m = 0.5 * (p + q)

    def _kl(a, b):
        # Only terms with a > 0 contribute (0 * log 0 := 0); when a > 0 the
        # mixture b = (a + other) / 2 is also > 0, so the b-mask is a guard.
        nz = (a > 0) & (b > 0)
        return np.sum(a[nz] * np.log2(a[nz] / b[nz]))

    return 0.5 * _kl(p, m) + 0.5 * _kl(q, m)
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def plot_portrait(portrait, title="Network Portrait", save_path=None):
    """Render a network portrait as a 2D heatmap (degree k vs path bin l)."""
    if not portrait:
        logger.warning("Empty portrait; skip plotting")
        return

    records = [{"l": l, "k": k, "value": v} for (l, k), v in portrait.items()]
    grid = pd.DataFrame(records).pivot(index="l", columns="k", values="value").fillna(0)

    plt.figure(figsize=(8, 6))
    sns.heatmap(grid, cmap="viridis", annot=False)
    plt.title(title, fontsize=14, fontweight="bold")
    plt.xlabel("Degree k", fontsize=12)
    plt.ylabel("Path length bin l", fontsize=12)
    plt.xticks(fontsize=10)
    plt.yticks(fontsize=10)
    plt.tight_layout()

    # Save when a path is given; otherwise display interactively.
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        logger.info(f"Portrait saved to: {save_path}")
    else:
        plt.show()

    plt.close()
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def plot_graph(G, title="Graph structure", save_path=None):
    """Draw the graph layout using the `pos` attribute stored on each node."""
    node_positions = nx.get_node_attributes(G, "pos")

    plt.figure(figsize=(8, 8))
    draw_style = dict(
        node_size=30,
        node_color="skyblue",
        edge_color="gray",
        width=0.5,
        with_labels=False,
    )
    nx.draw(G, node_positions, **draw_style)
    plt.title(title, fontsize=14, fontweight="bold")
    plt.axis("equal")
    plt.tight_layout()

    # Save when a path is given; otherwise display interactively.
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        logger.info(f"Graph saved to: {save_path}")
    else:
        plt.show()

    plt.close()
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def compare_graphs(
    df1, df2, r=None, bin_size=0.01, show_plots=True, save_dir=None, use_vectorized=True
):
    """Compare two transcript graphs via network portraits and JS divergence.

    Builds a radius graph for each DataFrame (auto-selecting the radius when
    `r` is None), derives each graph's network portrait and degree-weighted
    distribution, and returns their Jensen-Shannon divergence. Optionally
    shows and/or saves structure and portrait plots.
    """
    # Auto-select the radius as the larger of the two per-graph choices.
    if r is None:
        r = max(
            find_r_for_isolated_threshold(df1, threshold=0.05),
            find_r_for_isolated_threshold(df2, threshold=0.05),
        )
        logger.info(f"Auto-selected connection radius r = {r:.2f}")

    # Build both graphs and their portraits.
    G1 = build_weighted_graph(df1, r)
    G2 = build_weighted_graph(df2, r)
    B1, N1 = get_network_portrait(G1, bin_size, use_vectorized)
    B2, N2 = get_network_portrait(G2, bin_size, use_vectorized)

    # Degree-weighted distributions, then the JS divergence between them.
    js = js_divergence(
        compute_weighted_distribution(B1, N1),
        compute_weighted_distribution(B2, N2),
    )
    logger.info(f"Jensen-Shannon divergence between two graphs: {js:.4f}")

    # Optional visualization: interactive first, then saved copies.
    if show_plots or save_dir:
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        if show_plots:
            plot_graph(G1, title="Graph 1 structure")
            plot_graph(G2, title="Graph 2 structure")
            plot_portrait(B1, title="Graph 1 network portrait")
            plot_portrait(B2, title="Graph 2 network portrait")

        if save_dir:
            plot_graph(
                G1,
                title="Graph 1 structure",
                save_path=f"{save_dir}/graph1_structure.png",
            )
            plot_graph(
                G2,
                title="Graph 2 structure",
                save_path=f"{save_dir}/graph2_structure.png",
            )
            plot_portrait(
                B1,
                title="Graph 1 network portrait",
                save_path=f"{save_dir}/graph1_portrait.png",
            )
            plot_portrait(
                B2,
                title="Graph 2 network portrait",
                save_path=f"{save_dir}/graph2_portrait.png",
            )

    return js
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
def find_gene_optimal_r(
    gene,
    df,
    cell_list,
    threshold=0.05,
    r_min=0.01,
    r_max=0.6,
    r_step=0.03,
    dist_dict=None,
):
    """
    Find an optimal connection radius r for a gene (the max r across its cells).

    Args:
        gene: Gene ID.
        df: DataFrame containing transcripts (columns: cell, gene, x_c_s, y_c_s).
        cell_list: List of cell IDs.
        threshold: Isolated-node ratio threshold used by the per-cell r search.
        r_min, r_max, r_step: Search parameters for r.
        dist_dict: Optional precomputed distance matrices {(cell, gene): dist_matrix}.

    Returns:
        tuple: (optimal r for the gene, dict {cell: dist_matrix} for reuse).
            When no cell yields a valid r, a conservative default r and an
            empty cache are returned instead.
    """
    logger.info(f"Computing optimal r for gene {gene}")

    # All transcripts for this gene.
    gene_df = df[df["gene"] == gene]

    cell_r_values = {}
    cell_dist_matrices = {}  # store distance matrix per cell for reuse downstream
    transcript_counts = {}  # transcript count per cell (for summary stats)

    # Compute r per cell.
    for cell in cell_list:
        cell_df = gene_df[gene_df["cell"] == cell]
        transcript_count = len(cell_df)
        transcript_counts[cell] = transcript_count

        # Skip empty cells.
        if transcript_count == 0:
            continue

        # If there is only one transcript, use a small default radius to
        # avoid connecting to far-away points.
        if transcript_count == 1:
            cell_r_values[cell] = min(r_min * 5, r_max * 0.1)
            # Single-point distance matrix.
            cell_dist_matrices[cell] = np.array([[0.0]])
            continue

        try:
            # Prefer a precomputed distance matrix; compute if missing.
            dists = dist_dict.get((cell, gene), None) if dist_dict else None
            if dists is None:
                positions = cell_df[["x_c_s", "y_c_s"]].values
                dists = distance_matrix(positions, positions)

            # Find optimal r for this cell. Keyword arguments are used to stay
            # consistent with the call in precompute_portraits_for_gene and to
            # avoid depending on positional parameter order.
            r = find_r_for_isolated_threshold(
                cell_df,
                threshold=threshold,
                r_min=r_min,
                r_max=r_max,
                step=r_step,
                verbose=False,
                dists=dists,
            )
            cell_r_values[cell] = r
            cell_dist_matrices[cell] = dists  # cache for reuse
        except Exception as e:
            logger.error(f"Failed to compute r for cell={cell}, gene={gene}: {e}")

    # Summary stats.
    total_cells = len(cell_list)
    valid_cells = len(cell_r_values)
    single_transcript_cells = sum(
        1 for count in transcript_counts.values() if count == 1
    )
    multi_transcript_cells = sum(1 for count in transcript_counts.values() if count > 1)
    zero_transcript_cells = sum(1 for count in transcript_counts.values() if count == 0)

    # If no r values are valid, return a safer default.
    if not cell_r_values:
        # Use a smaller default to avoid overly dense graphs.
        default_r = min(r_max * 0.3, r_min * 10)
        logger.warning(
            f"Gene {gene} has no valid r; using adjusted default {default_r:.4f} (original default: {r_max})"
        )
        logger.warning(
            f"Gene {gene} transcript stats: total_cells={total_cells}, zero={zero_transcript_cells}, one={single_transcript_cells}, multi={multi_transcript_cells}"
        )
        return default_r, {}

    # Return max r across cells and the distance-matrix cache.
    max_r = max(cell_r_values.values())
    min_r = min(cell_r_values.values())
    avg_r = np.mean(list(cell_r_values.values()))

    logger.info(
        f"Gene {gene} optimal r: {max_r:.2f} (from {valid_cells} cells, range: {min_r:.2f}-{max_r:.2f}, mean: {avg_r:.2f})"
    )
    logger.info(
        f"Gene {gene} transcript distribution: zero={zero_transcript_cells}, one={single_transcript_cells}, multi={multi_transcript_cells}"
    )

    return max_r, cell_dist_matrices
|
|
537
|
+
|
|
538
|
+
|
|
539
|
+
def precompute_portraits_for_gene(
    gene,
    df,
    cell_list,
    threshold=0.05,
    bin_size=0.01,
    r_min=0.01,
    r_max=0.6,
    r_step=0.03,
    use_same_r=True,
    use_vectorized=True,
    dist_dict=None,
):
    """
    Precompute network portraits for every cell of one gene.

    Args:
        gene: Gene identifier.
        df: Transcript DataFrame.
        cell_list: List of cell IDs.
        threshold: Isolated-node ratio threshold for the r search.
        bin_size: Path length bin size.
        r_min, r_max, r_step: Radius search parameters.
        use_same_r: If True, use one shared r (the gene-level max) for all cells.
        use_vectorized: Whether to use vectorized portrait computation.
        dist_dict: Optional precomputed distance matrices {(cell, gene): dist_matrix}.

    Returns:
        Dict: {(cell, gene): (weighted_distribution, node_count, r_value)}
    """
    portraits_by_key = {}

    logger.info(f"Start precomputing network portraits for gene {gene}")
    t0 = time.time()

    # Restrict to this gene's transcripts; bail out early if there are none.
    gene_df = df[df["gene"] == gene]
    if len(gene_df) == 0:
        logger.warning(f"Gene {gene} has no transcript records")
        return portraits_by_key

    # Shared-r mode: determine the gene-level r up front; this also yields a
    # per-cell distance-matrix cache we can reuse below.
    gene_r = None
    cached_dists = {}
    if use_same_r:
        gene_r, cached_dists = find_gene_optimal_r(
            gene, df, cell_list, threshold, r_min, r_max, r_step, dist_dict
        )

    progress = tqdm(
        cell_list,
        desc=f"Processing cells for gene {gene}",
        leave=False,
        disable=not sys.stdout.isatty(),
    )
    for cell in progress:
        cell_df = gene_df[gene_df["cell"] == cell]
        n_transcripts = len(cell_df)

        # Nothing to do for cells without transcripts of this gene.
        if n_transcripts == 0:
            logger.debug(f"Cell {cell}, gene {gene} has no transcripts; skipping")
            continue

        # A single transcript yields a degenerate one-node graph.
        if n_transcripts == 1:
            logger.debug(
                f"Cell {cell}, gene {gene} has 1 transcript; using a single-node portrait"
            )
            try:
                # One node, degree 0, at path length 0 -> portrait {(0, 0): 1}.
                weighted_dist = compute_weighted_distribution({(0, 0): 1}, 1)
                # Pick a sensible r for the record.
                r = gene_r if use_same_r else min(r_min * 5, r_max * 0.1)
                portraits_by_key[(cell, gene)] = (weighted_dist, 1, r)
            except Exception as e:
                logger.error(
                    f"Failed single-transcript handling cell={cell}, gene={gene}: {e}"
                )
            continue

        # Two or more transcripts: build the graph and compute its portrait.
        try:
            # Distance matrix: per-gene cache first, then the global cache,
            # then compute from coordinates.
            dists = cached_dists.get(cell)
            if dists is None and dist_dict:
                dists = dist_dict.get((cell, gene), None)
            if dists is None:
                coords = cell_df[["x_c_s", "y_c_s"]].values
                dists = distance_matrix(coords, coords)

            # Connection radius: shared gene-level value, or searched per cell.
            if use_same_r:
                r = gene_r
            else:
                r = find_r_for_isolated_threshold(
                    cell_df,
                    threshold=threshold,
                    r_min=r_min,
                    r_max=r_max,
                    step=r_step,
                    dists=dists,
                )

            # Build the graph (reusing the distance matrix), then derive the
            # portrait and its weighted distribution.
            G = build_weighted_graph(cell_df, r, dists=dists)
            portrait, N = get_network_portrait(G, bin_size, use_vectorized)
            portraits_by_key[(cell, gene)] = (
                compute_weighted_distribution(portrait, N),
                N,
                r,
            )
        except Exception as e:
            logger.error(f"Failed processing cell={cell}, gene={gene}: {e}")

    elapsed = time.time() - t0
    logger.info(
        f"Finished precomputing network portraits for gene {gene}: {len(portraits_by_key)} distributions, elapsed {elapsed:.2f}s"
    )

    return portraits_by_key
|
|
671
|
+
|
|
672
|
+
|
|
673
|
+
def find_js_distances_for_gene(
    gene, df, cell_list, portraits, bin_size=0.01, max_count=None, transcript_window=30
):
    """
    Compute JS divergence between cell pairs within one gene.

    Args:
        gene: Gene identifier.
        df: DataFrame containing transcripts (unused here; kept for API parity).
        cell_list: List of cells.
        portraits: Precomputed portraits {(cell, gene): (weighted_distribution, node_count, r_value)}.
        bin_size: Path length bin size (unused here; kept for API parity).
        max_count: Max comparisons per target cell (None = unlimited).
        transcript_window: Max allowed transcript-count difference for candidates.

    Returns:
        List of tuples:
        (target_cell, gene, cell, gene, num_transcripts, js_distance,
         transcript_diff, target_r, other_r)
    """
    logger.info(f"Computing JS divergences for gene {gene}")
    t_start = time.time()

    # Gather this gene's cells that have a precomputed portrait, along with
    # their node counts and r values.
    valid_cells = []
    counts = {}
    radii = {}
    for cell in cell_list:
        entry = portraits.get((cell, gene))
        if entry is not None:
            _, n_nodes, r_val = entry
            valid_cells.append(cell)
            counts[cell] = n_nodes
            radii[cell] = r_val

    logger.info(f"Gene {gene}: {len(valid_cells)} valid cells")

    results = []
    done = 0

    for target_cell in valid_cells:
        target_count = counts[target_cell]
        target_r = radii[target_cell]

        # Candidates: other cells whose transcript counts are close enough,
        # ordered by increasing count difference (stable sort keeps ties in
        # original cell order).
        candidates = sorted(
            (
                (other, abs(counts[other] - target_count))
                for other in valid_cells
                if other != target_cell
                and abs(counts[other] - target_count) <= transcript_window
            ),
            key=lambda pair: pair[1],
        )

        # Cap the number of comparisons per target cell.
        if max_count is not None and len(candidates) > max_count:
            candidates = candidates[:max_count]

        for other, diff in candidates:
            try:
                target_key = (target_cell, gene)
                other_key = (other, gene)

                if target_key in portraits and other_key in portraits:
                    target_dist = portraits[target_key][0]
                    other_dist, _, other_r = portraits[other_key]

                    results.append(
                        (
                            target_cell,
                            gene,
                            other,
                            gene,
                            counts[other],
                            js_divergence(target_dist, other_dist),
                            diff,
                            target_r,
                            other_r,
                        )
                    )

                    done += 1
                    if done % 100 == 0:
                        logger.debug(
                            f"Gene {gene}: computed {done} JS distances"
                        )
            except Exception as e:
                logger.error(f"Failed JS divergence for {target_cell}-{other}: {e}")

    elapsed = time.time() - t_start
    logger.info(
        f"Finished JS divergences for gene {gene}: {len(results)} results, elapsed {elapsed:.2f}s"
    )

    return results
|
|
772
|
+
|
|
773
|
+
|
|
774
|
+
def calculate_js_distances(
    pkl_file: str,
    output_dir: str = None,
    max_count: int = None,
    transcript_window: int = 30,
    bin_size: float = 0.01,
    threshold: float = 0.05,
    r_min: float = 0.01,
    r_max: float = 0.6,
    r_step: float = 0.03,
    num_threads: int = None,
    use_same_r: bool = True,
    visualize_top_n: int = 5,
    use_vectorized: bool = True,
    filter_pkl_file: str = None,
    auto_params: bool = False,
    n_bins: int = 50,
    min_percentile: float = 1.0,
    max_percentile: float = 99.0,
) -> pd.DataFrame:
    """
    Compute Jensen-Shannon divergence between transcript graphs, with optional auto r selection.

    Args:
        pkl_file: Input PKL path containing df_registered.
        output_dir: Output directory; if None, infer a default from the PKL path.
        max_count: Max comparisons per target cell.
        transcript_window: Candidate window on transcript count difference.
        bin_size: Path length bin size; use 'auto' to infer.
        threshold: Isolated-node ratio threshold.
        r_min: Minimum r for search; use 'auto' to infer.
        r_max: Maximum r for search; use 'auto' to infer.
        r_step: Step size for r search.
        num_threads: Thread pool size (defaults to cpu_count - 2, clamped to [1, 20]).
        use_same_r: Use a single r per gene (max across cells).
        visualize_top_n: Visualize top-N most similar pairs (0 disables).
        use_vectorized: Use vectorized portrait computation.
        filter_pkl_file: Optional PKL to filter (cell, gene) pairs.
        auto_params: Auto-set r_min/r_max/bin_size based on distances.
        n_bins: Bin count for auto bin_size.
        min_percentile: Percentile used to infer r_min.
        max_percentile: Percentile used to infer r_max.

    Returns:
        pd.DataFrame: JS divergence results (empty DataFrame if nothing was computed).

    Raises:
        ValueError: If the PKL lacks df_registered or required columns.
    """
    # Default thread count: leave two cores free, cap at 20, and never go
    # below 1 — cpu_count() may be <= 2 and ThreadPoolExecutor rejects
    # max_workers <= 0.
    if num_threads is None:
        import multiprocessing

        num_threads = max(1, min(multiprocessing.cpu_count() - 2, 20))

    # Auto parameter flags: triggered either by the "auto" sentinel value or
    # by the global auto_params switch.
    auto_r_min = r_min == "auto" or auto_params
    auto_r_max = r_max == "auto" or auto_params
    auto_bin_size = bin_size == "auto" or auto_params

    # Use None placeholders for auto-calculated params.
    if auto_r_min:
        r_min = None
    else:
        r_min = float(r_min)

    if auto_r_max:
        r_max = None
    else:
        r_max = float(r_max)

    if auto_bin_size:
        bin_size = None
    else:
        bin_size = float(bin_size)

    logger.info(f"Computing JS divergence with {num_threads} threads")
    if auto_params or auto_r_min or auto_r_max or auto_bin_size:
        logger.info("Auto-calculating r_min, r_max, and bin_size")
    else:
        logger.info(
            f"r search params: threshold={threshold}, r_min={r_min}, r_max={r_max}, r_step={r_step}"
        )
        logger.info(f"Path length bin_size: {bin_size}")

    logger.info(
        f"{'Use a single r per gene' if use_same_r else 'Use per (cell, gene) r'}"
    )
    logger.info(
        f"{'Use vectorized portrait computation' if use_vectorized else 'Use loop-based portrait computation'}"
    )

    start_time = time.time()

    # Load data.
    with open(pkl_file, "rb") as f:
        data_dict = pickle.load(f)

    # Validate required key.
    if "df_registered" not in data_dict:
        raise ValueError(f"PKL file {pkl_file} does not contain df_registered")

    df = data_dict["df_registered"]
    logger.info(f"Loaded {len(df)} transcript records")

    # Validate required columns.
    required_cols = ["cell", "gene", "x_c_s", "y_c_s"]
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"df_registered is missing required columns: {missing_cols}")

    # Optionally filter (cell, gene) pairs using an external PKL.
    if filter_pkl_file and os.path.exists(filter_pkl_file):
        logger.info(f"Filtering (cell, gene) pairs using PKL file: {filter_pkl_file}")
        try:
            with open(filter_pkl_file, "rb") as f:
                filter_data = pickle.load(f)

            # The filter file must provide parallel cell_labels / gene_labels.
            if "cell_labels" in filter_data and "gene_labels" in filter_data:
                cell_labels = filter_data["cell_labels"]
                gene_labels = filter_data["gene_labels"]

                # Require equal lengths to form exact (cell, gene) pairs.
                if len(cell_labels) == len(gene_labels):
                    cell_gene_pairs = set(zip(cell_labels, gene_labels))
                    logger.info(
                        f"Extracted {len(cell_gene_pairs)} unique (cell, gene) pairs from filter file"
                    )

                    # Add a temporary (cell, gene) key, filter, then drop it.
                    df["cell_gene_pair"] = list(zip(df["cell"], df["gene"]))
                    original_len = len(df)
                    df = df[df["cell_gene_pair"].isin(cell_gene_pairs)]
                    df = df.drop(columns=["cell_gene_pair"])

                    logger.info(
                        f"Filtered records: before={original_len}, after={len(df)}"
                    )

                    if len(df) == 0:
                        logger.warning(
                            "No records remain after filtering; check whether (cell, gene) pairs match"
                        )
                        return pd.DataFrame()
                else:
                    logger.warning(
                        f"Filter PKL has mismatched lengths: cell_labels={len(cell_labels)} vs gene_labels={len(gene_labels)}; cannot form exact pairs"
                    )
                    logger.info("Proceeding with all cells and genes")
            else:
                logger.warning(
                    f"Filter PKL file {filter_pkl_file} does not contain cell_labels or gene_labels"
                )
        except Exception as e:
            # Best-effort filtering: fall back to the full dataset on failure.
            logger.error(f"Failed to read filter PKL file: {e}")
            logger.info("Proceeding with all cells and genes")

    # Unique cells and genes.
    cell_list = sorted(df["cell"].unique())
    gene_list = sorted(df["gene"].unique())

    logger.info(f"Dataset contains {len(cell_list)} cells and {len(gene_list)} genes")

    # Resolve output directory when the caller did not supply one.
    if output_dir is None:
        # Try to infer the dataset name from the PKL filename.
        filename = os.path.basename(pkl_file)
        # Drop common suffixes, e.g. "_data_dict.pkl".
        dataset = filename.split("_data_dict")[0]
        if dataset == filename:  # No "_data_dict" suffix found.
            # Fall back to stripping the extension.
            dataset = os.path.splitext(filename)[0]

        logger.info(f"Derived dataset name from filename (unknown): {dataset}")

        # Determine data directory: search the path for a "GRASP" project root.
        pkl_abs_path = os.path.abspath(pkl_file)
        grasp_dir = None
        path_parts = pkl_abs_path.split(os.sep)
        for i, part in enumerate(path_parts):
            if part == "GRASP":
                grasp_dir = os.sep.join(path_parts[: i + 1])
                break

        if grasp_dir:
            # Under GRASP, search for a dataset directory whose name contains
            # the dataset identifier (e.g. data1_simulated1).
            data_dir = None
            for item in os.listdir(grasp_dir):
                item_path = os.path.join(grasp_dir, item)
                if os.path.isdir(item_path) and dataset in item:
                    data_dir = item_path
                    break

            if data_dir:
                output_dir = os.path.join(data_dir, "step2_js")
                logger.info(f"Using data directory: {data_dir}")
            else:
                # If no matching data directory, use the PKL directory.
                pkl_dir = os.path.dirname(pkl_abs_path)
                output_dir = os.path.join(pkl_dir, "step2_js")
                logger.info(
                    f"No matching data directory found; using PKL directory: {pkl_dir}"
                )
        else:
            # If no GRASP directory found, use the PKL directory.
            pkl_dir = os.path.dirname(pkl_abs_path)
            output_dir = os.path.join(pkl_dir, "step2_js")
            logger.info(
                f"GRASP directory not found in path; using PKL directory: {pkl_dir}"
            )

    os.makedirs(output_dir, exist_ok=True)

    # Create visualization directory.
    vis_dir = f"{output_dir}/visualization"
    os.makedirs(vis_dir, exist_ok=True)

    # Analyze transcript distribution (optional diagnostics; failure is non-fatal).
    logger.info("Analyzing transcript distribution (optional diagnostics)...")
    try:
        analyze_transcript_distribution(df, output_dir)
    except Exception as e:
        logger.warning(f"Transcript distribution analysis failed: {e}")

    # Stage 0: precompute all distance matrices globally.
    logger.info("Stage 0: precomputing all distance matrices")
    dist_dict = {}  # Global distance matrix dict {(cell, gene): dist_matrix}

    def calc_dist_matrix(c_df):
        # Pairwise Euclidean distances between transcript coordinates.
        # Defined once here (not per loop iteration) so only one closure
        # object is created.
        positions = c_df[["x_c_s", "y_c_s"]].values
        return distance_matrix(positions, positions)

    # Precompute all distance matrices using a thread pool.
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = {}

        # Submit all computation tasks.
        for gene in gene_list:
            gene_df = df[df["gene"] == gene]

            for cell in cell_list:
                cell_df = gene_df[gene_df["cell"] == cell]
                # Skip cells with insufficient transcripts.
                if len(cell_df) <= 1:
                    continue

                futures[(cell, gene)] = executor.submit(calc_dist_matrix, cell_df)

        # Collect results.
        for (cell, gene), future in tqdm(
            futures.items(),
            desc="Precomputing distance matrices",
            disable=not sys.stdout.isatty(),
        ):
            try:
                dist_dict[(cell, gene)] = future.result()
            except Exception as e:
                logger.error(
                    f"Failed to compute distance matrix for cell={cell}, gene={gene}: {e}"
                )

    logger.info(f"Distance matrix precompute done; {len(dist_dict)} (cell,gene) pairs")

    # Auto-select parameters based on precomputed distance matrices.
    if auto_r_min or auto_r_max or auto_bin_size:
        logger.info("Auto-selecting parameters from precomputed distance matrices")
        all_dists = []

        # Collect all non-zero distances (no sampling).
        for dmat in tqdm(
            dist_dict.values(),
            desc="Collecting distance samples",
            disable=not sys.stdout.isatty(),
        ):
            # Upper triangle (exclude self distances).
            triu_indices = np.triu_indices_from(dmat, k=1)
            dists = dmat[triu_indices]
            # Exclude zero and infinite distances.
            valid_dists = dists[(dists > 0) & (np.isfinite(dists))]
            if len(valid_dists) > 0:
                all_dists.append(valid_dists)

        if all_dists:
            all_dists = np.concatenate(all_dists)
            logger.info(f"Collected {len(all_dists)} valid distance values")

            # Percentile-based parameter estimation.
            if auto_r_min:
                r_min = float(np.percentile(all_dists, min_percentile))
                r_min = round(r_min, 2)  # keep two decimals
                logger.info(
                    f"Auto-set r_min = {r_min:.2f} ({min_percentile}% percentile)"
                )

            if auto_r_max:
                r_max = float(np.percentile(all_dists, max_percentile))
                r_max = round(r_max, 2)  # keep two decimals
                logger.info(
                    f"Auto-set r_max = {r_max:.2f} ({max_percentile}% percentile)"
                )

            if auto_bin_size:
                # Set bin_size = (r_max - r_min) / n_bins.
                bin_size = float((r_max - r_min) / n_bins)
                bin_size = round(bin_size, 2)  # keep two decimals
                bin_size = max(0.01, bin_size)  # ensure it is not too small
                logger.info(
                    f"Auto-set bin_size = {bin_size:.2f} (1/{n_bins} of distance range)"
                )
        else:
            logger.warning("Could not collect valid distance samples; using defaults")
            if auto_r_min:
                r_min = 0.01
            if auto_r_max:
                r_max = 0.6
            if auto_bin_size:
                bin_size = 0.01

    # Ensure parameters have reasonable defaults.
    if r_min is None:
        r_min = 0.01
    if r_max is None:
        r_max = 0.6
    if bin_size is None:
        bin_size = 0.01

    logger.info(
        f"Final params: r_min={r_min:.2f}, r_max={r_max:.2f}, bin_size={bin_size:.2f}"
    )

    # Stage 1: precompute all network portraits (including auto-selecting r).
    logger.info("Stage 1: precomputing network portraits")
    portraits = {}

    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Submit gene-level precompute tasks.
        futures = {
            executor.submit(
                precompute_portraits_for_gene,
                gene,
                df,
                cell_list,
                threshold,
                bin_size,
                r_min,
                r_max,
                r_step,
                use_same_r,
                use_vectorized,
                dist_dict,  # pass the global distance matrix dict
            ): gene
            for gene in gene_list
        }

        # Collect results.
        for future in tqdm(
            as_completed(futures),
            total=len(futures),
            desc="Precomputing network portraits",
            disable=not sys.stdout.isatty(),
        ):
            gene = futures[future]
            try:
                gene_portraits = future.result()
                portraits.update(gene_portraits)
                logger.info(
                    f"Gene {gene} precompute done; {len(gene_portraits)} distributions"
                )
            except Exception as e:
                logger.error(f"Gene {gene} precompute failed: {e}")

    logger.info(f"Network portrait precompute done; {len(portraits)} (cell,gene) pairs")

    # Stage 2: compute JS divergence.
    logger.info("Stage 2: computing JS divergence")
    all_distances = []

    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Submit gene-level JS divergence tasks.
        futures = {
            executor.submit(
                find_js_distances_for_gene,
                gene,
                df,
                cell_list,
                portraits,
                bin_size,
                max_count,
                transcript_window,
            ): gene
            for gene in gene_list
        }

        # Collect results.
        for future in tqdm(
            as_completed(futures),
            total=len(futures),
            desc="Computing JS divergence",
            disable=not sys.stdout.isatty(),
        ):
            gene = futures[future]
            try:
                gene_distances = future.result()
                all_distances.extend(gene_distances)
                logger.info(
                    f"Gene {gene} JS divergence done; {len(gene_distances)} results"
                )
            except Exception as e:
                logger.error(f"Gene {gene} JS divergence failed: {e}")

    # Clean up distance matrices to free memory.
    dist_dict.clear()

    # Save results to a DataFrame.
    if all_distances:
        distances_df = pd.DataFrame(
            all_distances,
            columns=[
                "target_cell",
                "target_gene",
                "cell",
                "gene",
                "num_transcripts",
                "js_distance",
                "transcript_diff",
                "target_r",
                "other_r",
            ],
        )

        # Save results.
        output_path = f"{output_dir}/js_distances_bin{bin_size:.4f}_count{max_count}_threshold{threshold}.csv"
        distances_df.to_csv(output_path, index=False)

        logger.info(f"\nDone. Results saved to: {output_path}")
        logger.info(f"Computed {len(distances_df)} JS divergence records")

        # Visualize top-N most similar pairs (smallest JS divergence).
        if visualize_top_n > 0:
            logger.info(
                f"Visualizing top {visualize_top_n} most similar pairs by JS divergence"
            )
            visualize_most_similar_pairs(
                df, distances_df, portraits, visualize_top_n, vis_dir, use_vectorized
            )

        # Total elapsed time.
        total_time = time.time() - start_time
        logger.info(f"Total time: {total_time:.2f}s")

        return distances_df
    else:
        logger.warning("WARNING: no JS divergences were computed")
        return pd.DataFrame()
|
|
1235
|
+
|
|
1236
|
+
|
|
1237
|
+
def visualize_most_similar_pairs(
    df, distances_df, portraits, top_n=5, output_dir=None, use_vectorized=True
):
    """
    Render the top-N cell pairs with the smallest JS divergence.

    For each pair, the transcript graph, its network portrait, and the raw
    transcript scatter are drawn for both members. Plots are written into a
    per-pair subdirectory when ``output_dir`` is given, otherwise shown
    interactively.

    Args:
        df: Transcript DataFrame.
        distances_df: JS divergence results DataFrame.
        portraits: Precomputed portrait dict.
        top_n: Number of pairs to visualize.
        output_dir: Output directory.
        use_vectorized: Whether to use vectorized portrait computation.
    """
    # Rank pairs from most to least similar (ascending JS divergence).
    ranked = distances_df.sort_values("js_distance").reset_index(drop=True)

    # Ensure output directory exists.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    for rank in range(1, min(top_n, len(ranked)) + 1):
        record = ranked.iloc[rank - 1]

        target_cell = record["target_cell"]
        target_gene = record["target_gene"]
        other_cell = record["cell"]
        other_gene = record["gene"]
        js_dist = record["js_distance"]
        target_r = record["target_r"]
        other_r = record["other_r"]

        logger.info(
            f"Rank {rank} most similar pair: {target_cell}:{target_gene} - {other_cell}:{other_gene}, JS={js_dist:.4f}"
        )

        # Extract the transcripts belonging to each member of the pair.
        target_df = df[(df["cell"] == target_cell) & (df["gene"] == target_gene)]
        other_df = df[(df["cell"] == other_cell) & (df["gene"] == other_gene)]

        # A graph needs at least two transcripts per member to be meaningful.
        if min(len(target_df), len(other_df)) < 2:
            logger.warning(
                f"Pair {target_cell}:{target_gene} - {other_cell}:{other_gene} has too few transcripts; skipping visualization"
            )
            continue

        # Rebuild graphs and their portraits for plotting.
        target_graph = build_weighted_graph(target_df, target_r)
        other_graph = build_weighted_graph(other_df, other_r)

        target_portrait, _ = get_network_portrait(
            target_graph, bin_size=0.01, use_vectorized=use_vectorized
        )
        other_portrait, _ = get_network_portrait(
            other_graph, bin_size=0.01, use_vectorized=use_vectorized
        )

        # Create the per-pair output directory when saving.
        pair_dir = None
        if output_dir:
            pair_dir = f"{output_dir}/pair_{rank}_js{js_dist:.4f}"
            os.makedirs(pair_dir, exist_ok=True)

        pair_prefix = f"Rank {rank} JS={js_dist:.4f}: "

        # Both members of the pair are plotted identically; iterate over them
        # instead of spelling each call out twice per plot type.
        sides = (
            ("cell1", target_cell, target_gene, target_df, target_graph,
             target_portrait, target_r),
            ("cell2", other_cell, other_gene, other_df, other_graph,
             other_portrait, other_r),
        )

        if pair_dir:
            # Save graph structures.
            for tag, cell, gene, _sub, graph, _portrait, radius in sides:
                plot_graph(
                    graph,
                    title=f"{pair_prefix}{cell}:{gene} (r={radius:.2f})",
                    save_path=f"{pair_dir}/{tag}_graph.png",
                )
            # Save portraits.
            for tag, cell, gene, _sub, _graph, portrait, _radius in sides:
                plot_portrait(
                    portrait,
                    title=f"{pair_prefix}{cell}:{gene} Network portrait",
                    save_path=f"{pair_dir}/{tag}_portrait.png",
                )
            # Save transcript scatter plots.
            for tag, cell, gene, sub, _graph, _portrait, radius in sides:
                plot_transcripts(
                    sub,
                    radius,
                    title=f"{pair_prefix}{cell}:{gene} Transcripts",
                    save_path=f"{pair_dir}/{tag}_transcripts.png",
                )
            logger.info(f"Saved pair visualization to: {pair_dir}")
        else:
            # Show graph structures.
            for _tag, cell, gene, _sub, graph, _portrait, radius in sides:
                plot_graph(
                    graph, title=f"{pair_prefix}{cell}:{gene} (r={radius:.2f})"
                )
            # Show portraits.
            for _tag, cell, gene, _sub, _graph, portrait, _radius in sides:
                plot_portrait(
                    portrait, title=f"{pair_prefix}{cell}:{gene} Network portrait"
                )
            # Show transcript scatter plots.
            for _tag, cell, gene, sub, _graph, _portrait, radius in sides:
                plot_transcripts(
                    sub, radius, title=f"{pair_prefix}{cell}:{gene} Transcripts"
                )
def plot_transcripts(df, r, title="Transcript distribution", save_path=None):
    """
    Plot transcript spatial distribution.

    Each transcript is drawn as a point, with a dashed circle of radius ``r``
    around it indicating the graph connection range.

    Args:
        df: Transcript DataFrame with ``x_c_s``/``y_c_s`` coordinate columns.
        r: Connection radius.
        title: Plot title.
        save_path: Optional output file path; when None the figure is shown
            interactively instead of saved.
    """
    plt.figure(figsize=(10, 8))

    # Plot transcript locations.
    plt.scatter(df["x_c_s"], df["y_c_s"], alpha=0.6, s=10)

    # Add a radius circle for each transcript. Iterating with zip over the
    # two coordinate columns avoids the per-row Series construction cost of
    # DataFrame.iterrows, and the Axes handle is fetched once instead of
    # once per circle.
    ax = plt.gca()
    for x, y in zip(df["x_c_s"], df["y_c_s"]):
        ax.add_patch(
            plt.Circle(
                (x, y),
                r,
                fill=False,
                color="gray",
                alpha=0.2,
                linestyle="--",
            )
        )

    plt.xlabel("X coordinate")
    plt.ylabel("Y coordinate")
    plt.title(title)
    plt.axis("equal")
    plt.grid(True, alpha=0.3)

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        plt.close()
    else:
        plt.show()
def analyze_transcript_distribution(df, output_dir=None):
    """
    Analyze and report transcript distribution.

    Computes per-gene, per-cell, and per-(cell, gene)-pair transcript count
    statistics, logs a human-readable report, and optionally writes the
    statistics (JSON) plus distribution plots to ``output_dir``.

    Args:
        df: Transcript DataFrame with at least ``cell`` and ``gene`` columns.
        output_dir: Output directory. If None, only logs are emitted.

    Returns:
        Dict: Summary statistics.
    """

    def _count_stats(counts, prefix, include_range=True):
        # Summarize a per-group count Series as plain-Python (JSON-safe) stats.
        if counts.empty:
            stats = {
                f"transcript_per_{prefix}_mean": 0.0,
                f"transcript_per_{prefix}_median": 0.0,
                f"transcript_per_{prefix}_std": None,
            }
            if include_range:
                stats[f"transcript_per_{prefix}_min"] = 0
                stats[f"transcript_per_{prefix}_max"] = 0
            return stats
        std = counts.std()  # NaN when there is a single group; stored as None
        stats = {
            f"transcript_per_{prefix}_mean": float(counts.mean()),
            f"transcript_per_{prefix}_median": float(counts.median()),
            f"transcript_per_{prefix}_std": float(std) if not np.isnan(std) else None,
        }
        if include_range:
            stats[f"transcript_per_{prefix}_min"] = int(counts.min())
            stats[f"transcript_per_{prefix}_max"] = int(counts.max())
        return stats

    logger.info("Analyzing transcript distribution...")

    # Basic stats.
    total_transcripts = len(df)
    unique_genes = df["gene"].nunique()
    unique_cells = df["cell"].nunique()

    # Transcript counts per gene, per cell, and per (cell, gene) pair.
    gene_transcript_counts = df.groupby("gene").size()
    cell_transcript_counts = df.groupby("cell").size()
    cell_gene_transcript_counts = df.groupby(["cell", "gene"]).size()

    gene_stats = {
        "total_genes": int(unique_genes),
        **_count_stats(gene_transcript_counts, "gene"),
    }
    cell_stats = {
        "total_cells": int(unique_cells),
        **_count_stats(cell_transcript_counts, "cell"),
    }
    pair_stats = {
        "total_cell_gene_pairs": int(len(cell_gene_transcript_counts)),
        **_count_stats(cell_gene_transcript_counts, "pair", include_range=False),
        "single_transcript_pairs": int((cell_gene_transcript_counts == 1).sum()),
        "multi_transcript_pairs": int((cell_gene_transcript_counts > 1).sum()),
    }

    # Count genes where >80% of their (cell, gene) pairs hold exactly one
    # transcript. The vectorized groupby replaces the original
    # O(genes x pairs) Python loop over index.get_level_values.
    genes_with_mostly_single_transcripts = 0
    if unique_genes > 0 and not cell_gene_transcript_counts.empty:
        single_ratio_per_gene = (
            (cell_gene_transcript_counts == 1).groupby(level="gene").mean()
        )
        genes_with_mostly_single_transcripts = int(
            (single_ratio_per_gene > 0.8).sum()
        )

    problem_stats = {
        "genes_with_mostly_single_transcripts": int(
            genes_with_mostly_single_transcripts
        ),
        "problematic_gene_ratio": float(
            genes_with_mostly_single_transcripts / unique_genes
        )
        if unique_genes > 0
        else 0.0,
        "single_transcript_pair_ratio": float(
            pair_stats["single_transcript_pairs"] / pair_stats["total_cell_gene_pairs"]
        )
        if pair_stats["total_cell_gene_pairs"] > 0
        else 0.0,
    }

    # Summary.
    stats_summary = {
        "total_transcripts": int(total_transcripts),
        "gene_stats": gene_stats,
        "cell_stats": cell_stats,
        "pair_stats": pair_stats,
        "problem_stats": problem_stats,
    }

    # Report.
    logger.info("=" * 60)
    logger.info("Transcript distribution report")
    logger.info("=" * 60)
    logger.info(f"Total transcripts: {total_transcripts:,}")
    logger.info(f"Unique genes: {unique_genes:,}")
    logger.info(f"Unique cells: {unique_cells:,}")
    logger.info("")

    logger.info("Gene-level stats:")
    logger.info(
        f"  Mean transcripts per gene: {gene_stats['transcript_per_gene_mean']:.1f}"
    )
    logger.info(f"  Median: {gene_stats['transcript_per_gene_median']:.1f}")
    logger.info(
        f"  Range: {gene_stats['transcript_per_gene_min']}-{gene_stats['transcript_per_gene_max']}"
    )
    logger.info("")

    logger.info("Cell-level stats:")
    logger.info(
        f"  Mean transcripts per cell: {cell_stats['transcript_per_cell_mean']:.1f}"
    )
    logger.info(f"  Median: {cell_stats['transcript_per_cell_median']:.1f}")
    logger.info(
        f"  Range: {cell_stats['transcript_per_cell_min']}-{cell_stats['transcript_per_cell_max']}"
    )
    logger.info("")

    logger.info("(Cell, gene) pair-level stats:")
    logger.info(f"  Total (cell, gene) pairs: {pair_stats['total_cell_gene_pairs']:,}")
    # Guard the percentage computations: an empty input DataFrame yields zero
    # pairs, which previously raised ZeroDivisionError in the log lines below.
    total_pairs = pair_stats["total_cell_gene_pairs"]
    single_pct = (
        pair_stats["single_transcript_pairs"] / total_pairs * 100
        if total_pairs
        else 0.0
    )
    multi_pct = (
        pair_stats["multi_transcript_pairs"] / total_pairs * 100
        if total_pairs
        else 0.0
    )
    logger.info(
        f"  Single-transcript pairs: {pair_stats['single_transcript_pairs']:,} ({single_pct:.1f}%)"
    )
    logger.info(
        f"  Multi-transcript pairs: {pair_stats['multi_transcript_pairs']:,} ({multi_pct:.1f}%)"
    )
    logger.info(
        f"  Mean transcripts per pair: {pair_stats['transcript_per_pair_mean']:.1f}"
    )
    logger.info("")

    logger.info("Potential issues:")
    logger.info(
        f"  Genes mostly single-transcript: {problem_stats['genes_with_mostly_single_transcripts']:,} ({problem_stats['problematic_gene_ratio'] * 100:.1f}%)"
    )
    logger.info(
        f"  Single-transcript pair ratio: {problem_stats['single_transcript_pair_ratio'] * 100:.1f}%"
    )

    if problem_stats["problematic_gene_ratio"] > 0.5:
        logger.warning(
            "WARNING: >50% of genes are mostly single-transcript; consider checking data quality or adjusting parameters"
        )
    elif problem_stats["single_transcript_pair_ratio"] > 0.7:
        logger.warning(
            "WARNING: >70% of (cell, gene) pairs have a single transcript; this may affect network portrait quality"
        )
    else:
        logger.info("Data quality looks OK for network portrait analysis")

    logger.info("=" * 60)

    # If output_dir is provided, write stats and plots.
    if output_dir:
        import json

        os.makedirs(output_dir, exist_ok=True)

        # Save stats.
        stats_file = os.path.join(output_dir, "transcript_distribution_stats.json")
        with open(stats_file, "w", encoding="utf-8") as f:
            json.dump(stats_summary, f, indent=2, ensure_ascii=False)
        logger.info(f"Saved detailed stats to: {stats_file}")

        # Save distribution plots (log-scaled histograms plus a pie chart).
        plt.figure(figsize=(12, 8))
        plt.subplot(2, 2, 1)
        plt.hist(gene_transcript_counts, bins=50, alpha=0.7, edgecolor="black")
        plt.xlabel("Transcripts per gene")
        plt.ylabel("Number of genes")
        plt.title("Transcript count per gene")
        plt.yscale("log")

        plt.subplot(2, 2, 2)
        plt.hist(cell_transcript_counts, bins=50, alpha=0.7, edgecolor="black")
        plt.xlabel("Transcripts per cell")
        plt.ylabel("Number of cells")
        plt.title("Transcript count per cell")
        plt.yscale("log")

        plt.subplot(2, 2, 3)
        plt.hist(cell_gene_transcript_counts, bins=30, alpha=0.7, edgecolor="black")
        plt.xlabel("Transcripts per (cell, gene) pair")
        plt.ylabel("Number of pairs")
        plt.title("Transcript count per (cell, gene) pair")
        plt.yscale("log")

        plt.subplot(2, 2, 4)
        # Single vs multi transcript pairs.
        labels = ["Single-transcript pairs", "Multi-transcript pairs"]
        sizes = [
            pair_stats["single_transcript_pairs"],
            pair_stats["multi_transcript_pairs"],
        ]
        plt.pie(sizes, labels=labels, autopct="%1.1f%%", startangle=90)
        plt.title("Single vs multi-transcript pairs")

        plt.tight_layout()
        dist_plot_file = os.path.join(output_dir, "transcript_distribution_plots.png")
        plt.savefig(dist_plot_file, dpi=300, bbox_inches="tight")
        plt.close()
        logger.info(f"Saved distribution plots to: {dist_plot_file}")

    return stats_summary
if __name__ == "__main__":
|
|
1666
|
+
# CLI argument parsing.
|
|
1667
|
+
parser = argparse.ArgumentParser(
|
|
1668
|
+
description="Compute similarity between transcript graphs using JS divergence"
|
|
1669
|
+
)
|
|
1670
|
+
parser.add_argument(
|
|
1671
|
+
"--pkl_file",
|
|
1672
|
+
type=str,
|
|
1673
|
+
required=True,
|
|
1674
|
+
help="Path to a PKL containing df_registered",
|
|
1675
|
+
)
|
|
1676
|
+
parser.add_argument(
|
|
1677
|
+
"--output_dir", type=str, default=None, help="Output directory (optional)"
|
|
1678
|
+
)
|
|
1679
|
+
parser.add_argument(
|
|
1680
|
+
"--max_count", type=int, default=10, help="Max comparisons per target cell"
|
|
1681
|
+
)
|
|
1682
|
+
parser.add_argument(
|
|
1683
|
+
"--transcript_window",
|
|
1684
|
+
type=int,
|
|
1685
|
+
default=30,
|
|
1686
|
+
help="Transcript count difference window for candidate filtering",
|
|
1687
|
+
)
|
|
1688
|
+
parser.add_argument(
|
|
1689
|
+
"--bin_size",
|
|
1690
|
+
type=str,
|
|
1691
|
+
default="0.01",
|
|
1692
|
+
help='Path length bin size for JS divergence; set to "auto" to estimate',
|
|
1693
|
+
)
|
|
1694
|
+
parser.add_argument(
|
|
1695
|
+
"--threshold",
|
|
1696
|
+
type=float,
|
|
1697
|
+
default=0.05,
|
|
1698
|
+
help="Isolated node ratio threshold for selecting r",
|
|
1699
|
+
)
|
|
1700
|
+
parser.add_argument(
|
|
1701
|
+
"--r_min",
|
|
1702
|
+
type=str,
|
|
1703
|
+
default="0.01",
|
|
1704
|
+
help='Minimum r for search; set to "auto" to estimate',
|
|
1705
|
+
)
|
|
1706
|
+
parser.add_argument(
|
|
1707
|
+
"--r_max",
|
|
1708
|
+
type=str,
|
|
1709
|
+
default="0.6",
|
|
1710
|
+
help='Maximum r for search; set to "auto" to estimate',
|
|
1711
|
+
)
|
|
1712
|
+
parser.add_argument(
|
|
1713
|
+
"--r_step", type=float, default=0.03, help="Step size for r search"
|
|
1714
|
+
)
|
|
1715
|
+
parser.add_argument(
|
|
1716
|
+
"--num_threads",
|
|
1717
|
+
type=int,
|
|
1718
|
+
default=None,
|
|
1719
|
+
help="Number of threads (default: CPU count)",
|
|
1720
|
+
)
|
|
1721
|
+
parser.add_argument(
|
|
1722
|
+
"--use_same_r",
|
|
1723
|
+
action="store_true",
|
|
1724
|
+
help="Use a shared r for all cells within a gene (gene-level r)",
|
|
1725
|
+
)
|
|
1726
|
+
parser.add_argument(
|
|
1727
|
+
"--visualize_top_n",
|
|
1728
|
+
type=int,
|
|
1729
|
+
default=5,
|
|
1730
|
+
help="Visualize top-N most similar pairs by JS divergence (0 disables)",
|
|
1731
|
+
)
|
|
1732
|
+
parser.add_argument(
|
|
1733
|
+
"--log_level",
|
|
1734
|
+
type=str,
|
|
1735
|
+
default="INFO",
|
|
1736
|
+
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
|
1737
|
+
help="Log level",
|
|
1738
|
+
)
|
|
1739
|
+
parser.add_argument(
|
|
1740
|
+
"--log_file",
|
|
1741
|
+
type=str,
|
|
1742
|
+
default="js_distance_transcriptome.log",
|
|
1743
|
+
help="Log filename",
|
|
1744
|
+
)
|
|
1745
|
+
parser.add_argument(
|
|
1746
|
+
"--no_vectorized",
|
|
1747
|
+
action="store_true",
|
|
1748
|
+
help="Disable vectorized portrait computation (slower, lower memory)",
|
|
1749
|
+
)
|
|
1750
|
+
parser.add_argument(
|
|
1751
|
+
"--filter_pkl_file",
|
|
1752
|
+
type=str,
|
|
1753
|
+
default=None,
|
|
1754
|
+
help="Optional filter PKL path (contains cell_labels/gene_labels)",
|
|
1755
|
+
)
|
|
1756
|
+
parser.add_argument(
|
|
1757
|
+
"--auto_params", action="store_true", help="Auto-set r_min, r_max, and bin_size"
|
|
1758
|
+
)
|
|
1759
|
+
parser.add_argument(
|
|
1760
|
+
"--n_bins",
|
|
1761
|
+
type=int,
|
|
1762
|
+
default=50,
|
|
1763
|
+
help="Number of bins when auto-setting bin_size",
|
|
1764
|
+
)
|
|
1765
|
+
parser.add_argument(
|
|
1766
|
+
"--min_percentile",
|
|
1767
|
+
type=float,
|
|
1768
|
+
default=1.0,
|
|
1769
|
+
help="Percentile used when auto-setting r_min",
|
|
1770
|
+
)
|
|
1771
|
+
parser.add_argument(
|
|
1772
|
+
"--max_percentile",
|
|
1773
|
+
type=float,
|
|
1774
|
+
default=99.0,
|
|
1775
|
+
help="Percentile used when auto-setting r_max",
|
|
1776
|
+
)
|
|
1777
|
+
|
|
1778
|
+
args = parser.parse_args()
|
|
1779
|
+
|
|
1780
|
+
# Set log level.
|
|
1781
|
+
logger.setLevel(getattr(logging, args.log_level))
|
|
1782
|
+
|
|
1783
|
+
# Add file handler.
|
|
1784
|
+
file_handler = logging.FileHandler(args.log_file)
|
|
1785
|
+
file_handler.setFormatter(
|
|
1786
|
+
logging.Formatter(
|
|
1787
|
+
"[%(asctime)s][%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
|
|
1788
|
+
)
|
|
1789
|
+
)
|
|
1790
|
+
logger.addHandler(file_handler)
|
|
1791
|
+
|
|
1792
|
+
# Track program runtime.
|
|
1793
|
+
program_start_time = time.time()
|
|
1794
|
+
logger.info("=" * 80)
|
|
1795
|
+
logger.info("Program started")
|
|
1796
|
+
logger.info(f"CLI args: {vars(args)}")
|
|
1797
|
+
|
|
1798
|
+
try:
|
|
1799
|
+
# Run main function.
|
|
1800
|
+
distances_df = calculate_js_distances(
|
|
1801
|
+
pkl_file=args.pkl_file,
|
|
1802
|
+
output_dir=args.output_dir,
|
|
1803
|
+
max_count=args.max_count,
|
|
1804
|
+
transcript_window=args.transcript_window,
|
|
1805
|
+
bin_size=args.bin_size,
|
|
1806
|
+
threshold=args.threshold,
|
|
1807
|
+
r_min=args.r_min,
|
|
1808
|
+
r_max=args.r_max,
|
|
1809
|
+
r_step=args.r_step,
|
|
1810
|
+
num_threads=args.num_threads,
|
|
1811
|
+
use_same_r=args.use_same_r,
|
|
1812
|
+
visualize_top_n=args.visualize_top_n,
|
|
1813
|
+
use_vectorized=(not args.no_vectorized),
|
|
1814
|
+
filter_pkl_file=args.filter_pkl_file,
|
|
1815
|
+
auto_params=args.auto_params,
|
|
1816
|
+
n_bins=args.n_bins,
|
|
1817
|
+
min_percentile=args.min_percentile,
|
|
1818
|
+
max_percentile=args.max_percentile,
|
|
1819
|
+
)
|
|
1820
|
+
|
|
1821
|
+
# Print result stats.
|
|
1822
|
+
if not distances_df.empty:
|
|
1823
|
+
logger.info("\nResult stats:")
|
|
1824
|
+
logger.info(f"Total distances: {len(distances_df)}")
|
|
1825
|
+
logger.info(
|
|
1826
|
+
f"JS range: [{distances_df['js_distance'].min():.4f}, {distances_df['js_distance'].max():.4f}]"
|
|
1827
|
+
)
|
|
1828
|
+
logger.info(f"Mean JS: {distances_df['js_distance'].mean():.4f}")
|
|
1829
|
+
logger.info(f"Median JS: {distances_df['js_distance'].median():.4f}")
|
|
1830
|
+
logger.info(
|
|
1831
|
+
f"r range: [{distances_df['target_r'].min():.2f}, {distances_df['target_r'].max():.2f}]"
|
|
1832
|
+
)
|
|
1833
|
+
logger.info(f"Mean r: {distances_df['target_r'].mean():.2f}")
|
|
1834
|
+
|
|
1835
|
+
# Preview first few rows.
|
|
1836
|
+
logger.info("\nResult preview:")
|
|
1837
|
+
logger.info(distances_df.head().to_string())
|
|
1838
|
+
|
|
1839
|
+
except Exception as e:
|
|
1840
|
+
logger.error(f"Program failed: {e}")
|
|
1841
|
+
import traceback
|
|
1842
|
+
|
|
1843
|
+
logger.error(traceback.format_exc())
|
|
1844
|
+
raise
|
|
1845
|
+
|
|
1846
|
+
# Compute and report total runtime.
|
|
1847
|
+
program_end_time = time.time()
|
|
1848
|
+
total_program_time = program_end_time - program_start_time
|
|
1849
|
+
hours, remainder = divmod(total_program_time, 3600)
|
|
1850
|
+
minutes, seconds = divmod(remainder, 60)
|
|
1851
|
+
|
|
1852
|
+
logger.info("=" * 80)
|
|
1853
|
+
logger.info(f"Program total runtime: {int(hours)}h {int(minutes)}m {seconds:.2f}s")
|
|
1854
|
+
logger.info(
|
|
1855
|
+
f"Start time: {datetime.fromtimestamp(program_start_time).strftime('%Y-%m-%d %H:%M:%S')}"
|
|
1856
|
+
)
|
|
1857
|
+
logger.info(
|
|
1858
|
+
f"End time: {datetime.fromtimestamp(program_end_time).strftime('%Y-%m-%d %H:%M:%S')}"
|
|
1859
|
+
)
|
|
1860
|
+
logger.info("=" * 80)
|
|
1861
|
+
|
|
1862
|
+
print("=" * 50)
|