deeptan-network 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deeptan/__init__.py +9 -0
- deeptan/cli/__init__.py +5 -0
- deeptan/cli/deeptan_fit.py +77 -0
- deeptan/cli/deeptan_litdata.py +113 -0
- deeptan/cli/deeptan_perturb.py +31 -0
- deeptan/cli/deeptan_pkl2h5.py +41 -0
- deeptan/cli/deeptan_predict.py +36 -0
- deeptan/cli/hello.py +6 -0
- deeptan/cli/print_sc_h5.py +32 -0
- deeptan/constants/__init__.py +9 -0
- deeptan/constants/art.py +9 -0
- deeptan/constants/default.py +75 -0
- deeptan/constants/dict_key.py +98 -0
- deeptan/constants/fname.py +7 -0
- deeptan/constants/hparam_candidates.py +8 -0
- deeptan/graph/__init__.py +5 -0
- deeptan/graph/model.py +1000 -0
- deeptan/graph/modules.py +720 -0
- deeptan/graph/recon.py +246 -0
- deeptan/utils/.ipynb_checkpoints/data-checkpoint.py +1042 -0
- deeptan/utils/.ipynb_checkpoints/metrics-checkpoint.py +867 -0
- deeptan/utils/.ipynb_checkpoints/peaks-checkpoint.py +210 -0
- deeptan/utils/.ipynb_checkpoints/uni-checkpoint.py +343 -0
- deeptan/utils/__init__.py +5 -0
- deeptan/utils/data.py +1042 -0
- deeptan/utils/metrics.py +863 -0
- deeptan/utils/peaks.py +210 -0
- deeptan/utils/uni.py +354 -0
- deeptan_network-0.1.0.dist-info/METADATA +1306 -0
- deeptan_network-0.1.0.dist-info/RECORD +33 -0
- deeptan_network-0.1.0.dist-info/WHEEL +4 -0
- deeptan_network-0.1.0.dist-info/entry_points.txt +8 -0
- deeptan_network-0.1.0.dist-info/licenses/LICENSE +674 -0
deeptan/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DeepTAN: A novel graph-based multi-task framework designed to infer large-scale multi-omics trait-associated networks (TANs) and reconstruct phenotype-specific omics states.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
# Import name of the package.
__package_name__ = "deeptan"
# Name under which the distribution is published (differs from the import name).
__distribution_name__ = "deeptan-network"
# Release version of the package.
__version__ = "0.1.0"
# Package authors.
__author__ = "Ying Wang, Zeming Meng"
# Institution credited for the work.
__credits__ = "Northwest A&F University"
|
deeptan/cli/__init__.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import ast
|
|
3
|
+
|
|
4
|
+
import deeptan.constants as const
|
|
5
|
+
from deeptan.graph.model import DeepTANTune
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def deeptan_fit_tune():
    """Command-line entry point for fitting or auto-tuning a DeepTAN model.

    Parses the CLI options, overlays them on ``const.default.model_config``
    and hands the merged configuration to ``DeepTANTune``. With
    ``--auto_tune`` the trainer's ``optimize`` method is run; otherwise
    ``_train_on_args`` is called.

    Raises:
        ValueError: If ``--devices`` starts with ``[`` but does not end
            with ``]``.
    """
    cli = argparse.ArgumentParser(description="DeepTAN fitting and tuning pipeline.")
    cli.add_argument("--auto_tune", "--atune", action="store_true", help="Whether to perform hyperparameter tuning")

    cli.add_argument("--em", type=str, default="", help="Existing model checkpoint path for loading")
    cli.add_argument("--focus", type=str, default="None", help="Focus on a specific task, choose from 'None', 'recon', 'label', 'recon_and_freeze', 'label_and_freeze'")
    cli.add_argument("--no_guide_gat", "--nog", action="store_true", help="Whether to disable edge weights of guidance graphs on graph attentions")

    cli.add_argument("--litdata", "--data", type=str, required=False, help="Path to litdata directory")
    cli.add_argument("--bs", type=int, default=const.default.bs, help="Batch size for training")
    cli.add_argument("--lr", type=float, default=const.default.lr, help="Learning rate")
    cli.add_argument("--log_dir", "--logdir", type=str, default=".tmp_logs_tune", help="Directory for logging")
    cli.add_argument("--input_node_emb_dim", "--indim", type=int, default=1, help="Input node embedding dimension")
    cli.add_argument("--is_regression", "--ir", action="store_true", help="Whether the task is regression")
    cli.add_argument("--acc_grad_batch", "--agb", type=int, default=const.default.accumulate_grad_batches, help="Accumulate gradients over multiple batches")
    cli.add_argument("--chunk_size", "--ck", type=int, default=const.default.chunk_size, help="A proper chunk size can balance memory usage and speed")
    cli.add_argument("--accelerator", "--ac", type=str, default=const.default.accelerator, help="cpu, gpu, tpu, hpu, mps, auto")
    cli.add_argument("--devices", "--dev", type=str, default=const.default.devices, help="Devices to use")
    cli.add_argument("--ntrials", "--nt", type=int, default=const.default.n_trials, help="Number of trials for hyperparameter tuning")
    cli.add_argument("--njobs", "--nj", type=int, default=const.default.n_jobs, help="The number of parallel jobs for Optuna. If this argument is set to -1, the number is set to CPU count.")
    opts = cli.parse_args()

    # A value written like a Python list (e.g. "[0, 1]") is parsed into a
    # real list; anything else is forwarded as the raw string.
    device_spec = opts.devices
    if device_spec.startswith("["):
        if not device_spec.endswith("]"):
            raise ValueError("Devices argument must be a valid list or a single device string.")
        device_spec = ast.literal_eval(device_spec)

    # Later keys win, so the CLI values override the packaged defaults while
    # the defaults' key order is kept.
    _config = {
        **const.default.model_config,
        "es": const.default.es,
        "max_ep": const.default.max_epoch,
        "min_ep": const.default.min_epoch,
        "log_dir": opts.log_dir,
        "litdata": opts.litdata,
        "bs": opts.bs,
        "lr": opts.lr,
        "chunk_size": opts.chunk_size,
        "is_regression": opts.is_regression,
        "accelerator": opts.accelerator,
        "devices": device_spec,
        "input_node_emb_dim": opts.input_node_emb_dim,
        "acc_grad_batch": opts.acc_grad_batch,
        "guide_gat": not opts.no_guide_gat,
    }
    print(f"\n🔧Configuration⚙️\n{_config}\n")

    # Checkpoint paths shorter than three characters count as "no checkpoint".
    ckpt = opts.em if len(opts.em) >= 3 else None
    # The literal string "None" is the CLI spelling for "no task focus".
    focus = opts.focus if opts.focus != "None" else None

    trainer = DeepTANTune(_config, ckpt, focus)
    if not opts.auto_tune:
        trainer._train_on_args()
    else:
        trainer.optimize(n_trials=opts.ntrials, n_jobs=opts.njobs)
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
import pickle
|
|
5
|
+
import shutil
|
|
6
|
+
|
|
7
|
+
import litdata
|
|
8
|
+
import numpy as np
|
|
9
|
+
import polars as pl
|
|
10
|
+
|
|
11
|
+
import deeptan.constants as const
|
|
12
|
+
from deeptan.utils.data import DeepTANDataModule, read_nmic_npz
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def deeptan_litdata():
    """Command-line entry point that converts DeepTAN inputs to litdata format.

    Builds a ``DeepTANDataModule`` from the train/validation/test files,
    writes the node names and label dimension next to the data (as JSON and
    as pickle), copies the label file when one was given, and then runs
    ``litdata.optimize`` for the validation and test splits and, unless
    ``--onlytest`` is set, for the training split as well.
    """
    cli = argparse.ArgumentParser(description="Optimize data for Deeptan model")
    cli.add_argument("--labels", type=str, default="", help="Path to label data in .parquet format")
    cli.add_argument("--bs", type=int, default=const.default.bs, help="Batch size for training")
    cli.add_argument("--trn_npz", type=str, required=True, help="Path to training data in .npz format (generated by mi2graph)")
    cli.add_argument("--val_parquet", type=str, required=True, help="Path to validation data in .parquet format")
    cli.add_argument("--tst_parquet", type=str, required=True, help="Path to test data in .parquet format")
    cli.add_argument("--output_dir", type=str, default=".tmp_data_optimized", help="Directory for logging")
    cli.add_argument("--thre_mi", type=float, default=const.default.threshold_nmic, help="Threshold for edge attribute")
    cli.add_argument("--in_feat", type=str, default="", help="Path to a .csv with header, containing a list of features to specify. If None, all features are used")
    cli.add_argument("--in_obs", type=str, default="", help="")
    cli.add_argument("--onlytest", action="store_true", help="Only run the test phase")
    cli.add_argument("--n_workers", type=int, default=const.default.n_threads, help="Number of workers for data loading")
    opts = cli.parse_args()

    # Optional path arguments shorter than two characters mean "not provided".
    labels = opts.labels if len(opts.labels) >= 2 else None
    specify_features = opts.in_feat if len(opts.in_feat) >= 2 else None
    specify_trn_obs = opts.in_obs if len(opts.in_obs) >= 2 else None

    datamodule = DeepTANDataModule(
        {
            const.dkey.abbr_train: opts.trn_npz,
            const.dkey.abbr_val: opts.val_parquet,
            const.dkey.abbr_test: opts.tst_parquet,
        },
        labels,
        batch_size=opts.bs,
        edge_attr_threshold=opts.thre_mi,
        specify_features=specify_features,
    )
    datamodule.setup()

    # Node names and label dimension are stored alongside the optimized data,
    # once as JSON and once as pickle.
    extras = {
        "dict_node_names": datamodule.dict_node_names,
        "output_g_label_dim": datamodule.label_dim,
    }
    out_root = opts.output_dir
    os.makedirs(out_root, exist_ok=True)
    with open(os.path.join(out_root, const.fname.litdata_others2save_json), "w") as fh_json:
        json.dump(extras, fh_json)
    with open(os.path.join(out_root, const.fname.litdata_others2save_pkl), "wb") as fh_pkl:
        pickle.dump(extras, fh_pkl)

    if labels is not None:
        shutil.copy(labels, os.path.join(out_root, const.fname.label_class_onehot))

    # Select the training samples: everything, or only the observations whose
    # names also appear in the "obs_names" column of the --in_obs parquet.
    if specify_trn_obs is None:
        trn_indices = list(range(datamodule.train.len()))
    else:
        _, _, _, _, obs_names, _ = read_nmic_npz(opts.trn_npz)
        wanted_names = pl.read_parquet(specify_trn_obs)["obs_names"].to_list()
        shared_names = list(set(obs_names) & set(wanted_names))
        # Positions in obs_names of the shared names; used as litdata inputs.
        trn_indices = np.where(np.isin(obs_names, shared_names))[0].tolist()

    def _optimize_split(dataset, indices, split_abbr):
        # Write one split into its own sub-directory of the output directory.
        litdata.optimize(
            fn=dataset.get,
            inputs=indices,
            output_dir=os.path.join(out_root, split_abbr),
            chunk_bytes=const.default.lit_chunk_bytes,
            compression=const.default.lit_compression,
            num_workers=min(opts.n_workers, const.default.n_threads),
        )

    if not opts.onlytest:
        _optimize_split(datamodule.train, trn_indices, const.dkey.abbr_train)
    _optimize_split(datamodule.val, list(range(datamodule.val.len())), const.dkey.abbr_val)
    _optimize_split(datamodule.test, list(range(datamodule.test.len())), const.dkey.abbr_test)
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from deeptan.graph.recon import predict_perturbation
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def deeptan_perturb():
    """Command-line entry point for DeepTAN perturbation prediction.

    Parses the checkpoint path, litdata directory and output path, makes
    sure the output's parent directory exists, and calls
    ``predict_perturbation`` with five perturbations.
    """
    parser = argparse.ArgumentParser(description="DeepTAN perturbation script.")
    parser.add_argument("--em", type=str, required=True, help="Existing model checkpoint path.")
    parser.add_argument("--litdata", "--data", type=str, required=True, help="Path to litdata directory")
    parser.add_argument("--output", "--out", type=str, required=True, help="Path to output file")
    parser.add_argument("--maplocation", "--maploc", type=str, default=None, help="Map location for model loading")
    parser.add_argument("--overwrite", action="store_true", help="Overwrite existing output directory")
    args = parser.parse_args()

    model_path = args.em
    litdata_dir = args.litdata

    # os.path.dirname() yields "" for a bare file name and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when the output
    # path actually has a directory component.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Predicting with model {model_path} on data {litdata_dir}")
    predict_perturbation(
        model_ckpt_path=model_path,
        litdata_dir=litdata_dir,
        output_path=args.output,
        n_perturbations=5,
        map_location=args.maplocation,
        overwrite_files=args.overwrite,
    )
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from deeptan.utils.uni import convert_pickle_to_h5
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def _swap_pkl_suffix(path, append_if_missing):
    """Return *path* with a trailing ``.pkl`` replaced by ``.h5``.

    Only the suffix is touched, so ``.pkl`` inside directory names is left
    alone. A path without the suffix gets ``.h5`` appended when
    *append_if_missing* is true and is returned unchanged otherwise.
    """
    if path.endswith(".pkl"):
        return path[: -len(".pkl")] + ".h5"
    return f"{path}.h5" if append_if_missing else path


def _convert_unless_present(input_pkl, output_h5, force, done_message):
    """Convert one pickle file unless the target exists and *force* is false."""
    if os.path.exists(output_h5) and not force:
        print(f"File already exists: {output_h5}. Skipping.")
        return
    convert_pickle_to_h5(input_pkl, output_h5)
    print(done_message)


def convert_pkl_to_h5():
    """Command-line entry point converting pickle files to HDF5.

    ``--input`` is either a single pickle file or a directory that is walked
    recursively for ``*.pkl`` files. Without ``--output`` each result is
    written next to its input with an ``.h5`` suffix. Existing targets are
    skipped unless ``--force`` is given.
    """
    parser = argparse.ArgumentParser(description="Convert a pickle file to HDF5 format.")
    parser.add_argument("-i", "--input", required=True, help="Path to a pickle file or a directory containing pickle files.")
    parser.add_argument("-o", "--output", required=False, help="Path to the output HDF5 file or directory. If not provided, output will be in the same location as the input file with .h5 extension.")
    parser.add_argument("-f", "--force", action="store_true", help="Overwrite existing HDF5 files without prompting.")
    args = parser.parse_args()

    if os.path.isdir(args.input):
        if args.output is not None:
            # All results are collected in this directory, so it must exist.
            os.makedirs(args.output, exist_ok=True)
        for root, _, files in os.walk(args.input):
            for file in files:
                if not file.endswith(".pkl"):
                    continue
                input_pkl = os.path.join(root, file)
                if args.output is None:
                    output_h5 = _swap_pkl_suffix(input_pkl, True)
                else:
                    output_h5 = os.path.join(args.output, _swap_pkl_suffix(file, True))
                _convert_unless_present(input_pkl, output_h5, args.force, f"Converted {input_pkl} to {output_h5}.")
        return

    if args.output is None:
        # Appending ".h5" for inputs without a ".pkl" suffix keeps the output
        # from ever being the input file itself.
        output_h5 = _swap_pkl_suffix(args.input, True)
    elif os.path.isdir(args.output):
        output_h5 = os.path.join(args.output, _swap_pkl_suffix(os.path.basename(args.input), True))
    else:
        # An explicit output file name is respected; only a trailing ".pkl"
        # is corrected to ".h5".
        output_h5 = _swap_pkl_suffix(args.output, False)

    _convert_unless_present(args.input, output_h5, args.force, f"Conversion complete: {output_h5}.")
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from deeptan.graph.recon import predict
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def deeptan_predict():
    """Command-line entry point for DeepTAN prediction.

    Parses the checkpoint path, litdata directory and output path, makes
    sure the output's parent directory exists, and calls ``predict`` with
    ``save_h5=True``. The output path gets an ``.h5`` suffix when it lacks
    one. If that file already exists and ``--overwrite`` is not set, a
    message is printed and nothing is predicted.
    """
    parser = argparse.ArgumentParser(description="DeepTAN prediction script.")
    parser.add_argument("--em", type=str, required=True, help="Existing model checkpoint path.")
    parser.add_argument("--litdata", "--data", type=str, required=True, help="Path to litdata directory")
    parser.add_argument("--output", "--out", type=str, required=True, help="Path to output file")
    parser.add_argument("--bs", "--batch_size", type=int, default=8, help="Batch size for prediction")
    parser.add_argument("--maplocation", "--maploc", type=str, default=None, help="Map location for model loading")
    parser.add_argument("--overwrite", action="store_true", help="Overwrite existing output directory")
    args = parser.parse_args()

    model_path = args.em
    litdata_dir = args.litdata

    # os.path.dirname() yields "" for a bare file name and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when the output
    # path actually has a directory component.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    output_path = args.output if args.output.endswith(".h5") else f"{args.output}.h5"

    if os.path.exists(output_path) and not args.overwrite:
        print(f"Results already exist at {output_path}")
        return

    print(f"Predicting with model {model_path} on data {litdata_dir}")
    predict(
        model_ckpt_path=model_path,
        litdata_dir=litdata_dir,
        output_path=output_path,
        # argparse derives the attribute name from the first long option,
        # "--maplocation", so the value lives in args.maplocation.
        map_location=args.maplocation,
        batch_size=args.bs,
        save_h5=True,
    )
|
deeptan/cli/hello.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import sys
|
|
3
|
+
|
|
4
|
+
import scanpy as sc
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def print_sc_h5():
    """Command-line entry point that prints the contents of an H5AD/H5 file.

    A ``.h5ad`` file is read with ``sc.read_h5ad`` and a ``.h5`` file with
    ``sc.read_10x_h5``; any other extension prints an error and exits with
    status 1. The loaded object, every ``obs`` column and every ``var``
    column are printed. When an ``Orig.ident`` column is present in ``obs``,
    a UMAP plot coloured by it is drawn.
    """
    parser = argparse.ArgumentParser(description="Read an H5AD or H5 file and print its contents.")
    # The path is mandatory: without it there is nothing to read, and the
    # extension checks below would fail on None.
    parser.add_argument("-p", "--path", type=str, required=True, help="Path to the H5AD or H5 file.")
    args = parser.parse_args()

    if args.path.endswith(".h5ad"):
        adata = sc.read_h5ad(args.path)
    elif args.path.endswith(".h5"):
        adata = sc.read_10x_h5(args.path)
    else:
        print("Unsupported file format. Please provide a .h5ad or .h5 file.")
        sys.exit(1)

    print(adata)

    print("\nobs keys: ")
    print(adata.obs.keys())
    for _key in adata.obs.keys():
        print(f"\n{_key}:")
        print(adata.obs[_key])

    print("\nvar keys: ")
    for _key in adata.var_keys():
        print(f"\n{_key}:")
        print(adata.var[_key])

    if "Celltype" in adata.obs.keys():
        print("\nCelltype:")
        print(type(adata.obs["Celltype"]))

    if "Orig.ident" in adata.obs.keys():
        # NOTE(review): presumably this needs a UMAP embedding already stored
        # in the file -- confirm before relying on it for raw 10x input.
        sc.pl.umap(adata, color="Orig.ident", title="UMAP by Experiment")
|
deeptan/constants/art.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
r"""
|
|
2
|
+
Default values.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from multiprocessing import cpu_count
|
|
6
|
+
from os import getenv
|
|
7
|
+
|
|
8
|
+
from numpy import ceil
|
|
9
|
+
|
|
10
|
+
# --- Training defaults -------------------------------------------------------
# bs, accumulate_grad_batches and lr are the defaults of the fitting CLI's
# --bs, --acc_grad_batch and --lr options.
bs = 8
accumulate_grad_batches = 4
lr = 0.0002
# es / min_epoch / max_epoch are copied into the run configuration under the
# keys "es", "min_ep" and "max_ep".
# NOTE(review): "es" is presumably the early-stopping patience -- confirm.
es = 5
min_epoch = 3
max_epoch = 50
dropout = 0.2
negative_slope = 0.2

# --- Model size defaults -----------------------------------------------------
node_emb_dim = 128
g_emb_dim = 192
label_pred_hidden_dims = [1024, 256]
fusion_dims_node_emb = [256, 128]
n_heads_pooling = 4
n_heads_node_emb = 4
n_heads_ge_decoder = 4
n_heads_label_pred = 4
n_hop = 1

# --- Memory / chunking -------------------------------------------------------
# chunk_size is the default of the fitting CLI's --chunk_size option.
chunk_size = 256
mem_safety_factor = 0.6
operation_overhead = 2.5

# --- Thresholds --------------------------------------------------------------
# threshold_nmic is the default of the data CLI's --thre_mi option
# ("Threshold for edge attribute").
threshold_nmic = 0.01
threshold_subg_overlap = 0.85
threshold_edge_exist = 0.03

# --- Runtime settings --------------------------------------------------------
matmul_precision = "high"
# accelerator and devices are the defaults of the fitting CLI's --accelerator
# and --devices options.
accelerator = "auto"
devices = "auto"
precision = "32-true"
gradient_clip_val = 1.0

# Thread count: the NUM_THREADS environment variable when set, otherwise 90%
# of the CPU count rounded up. int() raises ValueError at import time if the
# variable holds a non-integer string.
n_threads = int(getenv("NUM_THREADS", ceil(cpu_count() * 0.9)))

# --- Bookkeeping -------------------------------------------------------------
# strftime-style timestamp pattern.
time_format = "%Y%m%d%H%M%S"
time_delay = 11.7
ckpt_fname_format = "best_model"
# SQLAlchemy-style URL of a local SQLite file.
optuna_db = "sqlite:///optuna.db"
# n_jobs and n_trials are the defaults of the fitting CLI's --njobs and
# --ntrials options.
n_jobs = 1
n_trials = 30
n_workers = 1

# --- litdata settings --------------------------------------------------------
# lit_chunk_bytes and lit_compression are passed to litdata.optimize as
# chunk_bytes and compression.
lit_chunk_bytes = "256MB"
lit_compression = "zstd"
lit_max_cache_size = "26GB"

# Base model configuration. The fitting CLI takes a shallow copy of this dict
# and overlays the command-line values on it, so the list value below is the
# same object as the module-level fusion_dims_node_emb list.
model_config = {
    "guide_gat": True,
    "class_weights": None,
    "use_focal_loss": True,
    "focal_alpha": None,
    "node_emb_dim": node_emb_dim,
    "fusion_dims_node_emb": fusion_dims_node_emb,
    "output_dim_g_emb": g_emb_dim,
    "n_hop": n_hop,
    "threshold_edge_exist": threshold_edge_exist,
    "threshold_subgraph_overlap": threshold_subg_overlap,
    "n_heads_node_emb": n_heads_node_emb,
    "n_heads_pooling": n_heads_pooling,
    "n_heads_ge_decoder": n_heads_ge_decoder,
    "n_heads_label_pred": n_heads_label_pred,
    "dropout": dropout,
    "lr": lr,
    "chunk_size": chunk_size,
    "n_workers": n_workers,
}
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
r"""
|
|
2
|
+
Dictionary keys and column names.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
# Full names of the dataset splits / stages.
title_train = "train"
title_val = "val"
title_test = "test"
title_predict = "pred"

# Abbreviated split names. The data-preparation CLI uses them as the keys of
# its input-file dict and as the names of the per-split output directories.
abbr_train = "trn"
abbr_val = "val"
abbr_test = "tst"

# The three abbreviated split names in train / validation / test order.
splits = [abbr_train, abbr_val, abbr_test]

# Names of the per-split loss values.
title_trn_loss = "trn/loss"
title_val_loss = "val/loss"
title_tst_loss = "tst/loss"

# Metric names for the test and validation stages.
# NOTE(review): presumably the TensorBoard scalar tags to extract from run
# logs ("tsb") -- confirm against the code that consumes this list.
tsb_keys2pick = [
    "test/recon_MSE",
    "test/recon_RMSE",
    "test/recon_MAE",
    "test/recon_PCC",
    "test/label_MSE",
    "test/label_RMSE",
    "test/label_MAE",
    "test/label_PCC",
    "test/label_F1_weighted",
    "test/label_F1_macro",
    "test/label_F1_micro",
    "test/label_AUROC",
    "test/label_Accuracy",
    "test/label_Precision",
    "test/label_Recall",
    "test/loss",
    "test/loss_unweighted",
    "val/loss",
    "val/loss_unweighted",
    "val/recon_MSE",
    "val/recon_RMSE",
    "val/recon_MAE",
    "val/recon_PCC",
    "val/label_MSE",
    "val/label_RMSE",
    "val/label_MAE",
    "val/label_PCC",
    "val/label_F1_weighted",
    "val/label_F1_macro",
    "val/label_F1_micro",
    "val/label_AUROC",
    "val/label_Accuracy",
    "val/label_Precision",
    "val/label_Recall",
]

# Maps internal metric / column keys to human-readable display labels.
# The error-type metrics (jsd, mae, mse) are labelled as "1 - <metric>".
title_metric_mapping = {
    "jsd": "1 - JSD",
    "mae": "1 - MAE",
    "mse": "1 - MSE",
    "pcc": "PCC",
    "spearman": "Spearman",
    "weighted_recall": "Recall (weighted)",
    "weighted_precision": "Precision (weighted)",
    "weighted_f1": "F1 Score (weighted)",
    "macro_f1": "F1 Score (macro)",
    "micro_f1": "F1 Score (micro)",
    "auprc": "AUPRC",
    "auroc": "AUROC",
    "accuracy": "ACC",
    "kbet_true_label": "kBET (true labels)",
    "kbet_pred_label": "kBET (predicted labels)",
    "asw_true_label": "ASW (true labels)",
    "asw_pred_label": "ASW (predicted labels)",
    "kbet": "kBET",
    "asw": "ASW",
    "ari": "ARI",
    "nmi": "NMI",
    "ami": "AMI",
    "ari_leiden": "ARI (Leiden)",
    "nmi_leiden": "NMI (Leiden)",
    "ami_leiden": "AMI (Leiden)",
    "method": "Method",
    "train_size": "Train size",
    "n_feat": "Number of features for test",
}
# Maps internal task / training-mode keys to display labels.
title_task_mapping = {
    "multitask": "Multitask",
    "multitask_noguide": "Multitask (no SGG)",
    "focus_recon": "Focus on reconstruction",
    "focus_label": "Focus on labelling",
    "focus_label_on_focus_recon": "Fine-tuned on reconstruction",
}
# Maps lower-case column names to their capitalised display form.
title_colnameC2_mapping = {
    "task": "Task",
    "metric": "Metric",
    "value": "Value",
}
|