deeptan-network 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
deeptan/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ """
2
+ DeepTAN: A novel graph-based multi-task framework designed to infer large-scale multi-omics trait-associated networks (TANs) and reconstruct phenotype-specific omics states.
3
+ """
4
+
5
# Package metadata exposed as module-level dunder attributes of `deeptan`.
+ __package_name__ = "deeptan"  # import name of the package (as in `import deeptan...`)
6
+ __distribution_name__ = "deeptan-network"  # distribution name, matching the wheel name "deeptan-network 0.1.0"
7
+ __version__ = "0.1.0"  # NOTE(review): kept as a plain string assignment in case build tooling reads it textually — confirm before restyling
8
+ __author__ = "Ying Wang, Zeming Meng"
9
+ __credits__ = "Northwest A&F University"
@@ -0,0 +1,5 @@
1
+ r"""
2
+ CLI for DeepTAN.
3
+ """
4
+
5
+ from deeptan.cli import deeptan_fit, deeptan_litdata, deeptan_perturb, deeptan_pkl2h5, deeptan_predict, hello, print_sc_h5
@@ -0,0 +1,77 @@
1
+ import argparse
2
+ import ast
3
+
4
+ import deeptan.constants as const
5
+ from deeptan.graph.model import DeepTANTune
6
+
7
+
8
def _parse_devices(devices):
    """Convert the ``--devices`` CLI string into the value stored in the config.

    A string that (after stripping surrounding whitespace) starts with ``[`` is
    parsed as a Python list literal, e.g. ``"[0, 1]"`` -> ``[0, 1]``. Any other
    string (``"auto"``, ``"2"``, ...) is returned unchanged.

    Raises:
        ValueError: if the string starts with ``[`` but does not end with ``]``.
            Malformed list literals also propagate the ValueError/SyntaxError
            raised by :func:`ast.literal_eval`.
    """
    text = devices.strip()
    if not text.startswith("["):
        return devices
    if not text.endswith("]"):
        raise ValueError("Devices argument must be a valid list or a single device string.")
    # literal_eval only evaluates Python literals, never arbitrary expressions.
    return ast.literal_eval(text)


def deeptan_fit_tune():
    """Command-line entry point that builds a DeepTAN training config and runs it.

    Parses CLI arguments, overlays them on a copy of
    ``const.default.model_config`` and hands the result to :class:`DeepTANTune`.
    With ``--auto_tune`` the trainer's ``optimize`` method is called with
    ``--ntrials`` trials and ``--njobs`` parallel jobs (Optuna, per the help
    text). Otherwise a single run is started via ``_train_on_args``.

    ``--em`` values shorter than 3 characters (e.g. the default ``""``) mean
    "no checkpoint". ``--focus None`` means "no focus".
    """
    parser = argparse.ArgumentParser(description="DeepTAN fitting and tuning pipeline.")
    parser.add_argument("--auto_tune", "--atune", action="store_true", help="Whether to perform hyperparameter tuning")

    parser.add_argument("--em", type=str, default="", help="Existing model checkpoint path for loading")
    parser.add_argument(
        "--focus",
        type=str,
        default="None",
        # Validate the documented values up front instead of failing deep inside training.
        choices=["None", "recon", "label", "recon_and_freeze", "label_and_freeze"],
        help="Focus on a specific task, choose from 'None', 'recon', 'label', 'recon_and_freeze', 'label_and_freeze'",
    )
    parser.add_argument("--no_guide_gat", "--nog", action="store_true", help="Whether to disable edge weights of guidance graphs on graph attentions")

    parser.add_argument("--litdata", "--data", type=str, required=False, help="Path to litdata directory")
    parser.add_argument("--bs", type=int, default=const.default.bs, help="Batch size for training")
    parser.add_argument("--lr", type=float, default=const.default.lr, help="Learning rate")
    parser.add_argument("--log_dir", "--logdir", type=str, default=".tmp_logs_tune", help="Directory for logging")
    parser.add_argument("--input_node_emb_dim", "--indim", type=int, default=1, help="Input node embedding dimension")
    parser.add_argument("--is_regression", "--ir", action="store_true", help="Whether the task is regression")
    parser.add_argument("--acc_grad_batch", "--agb", type=int, default=const.default.accumulate_grad_batches, help="Accumulate gradients over multiple batches")
    parser.add_argument("--chunk_size", "--ck", type=int, default=const.default.chunk_size, help="A proper chunk size can balance memory usage and speed")
    parser.add_argument("--accelerator", "--ac", type=str, default=const.default.accelerator, help="cpu, gpu, tpu, hpu, mps, auto")
    parser.add_argument("--devices", "--dev", type=str, default=const.default.devices, help="Devices to use")
    parser.add_argument("--ntrials", "--nt", type=int, default=const.default.n_trials, help="Number of trials for hyperparameter tuning")
    parser.add_argument("--njobs", "--nj", type=int, default=const.default.n_jobs, help="The number of parallel jobs for Optuna. If this argument is set to -1, the number is set to CPU count.")
    args = parser.parse_args()

    guide_gat = not args.no_guide_gat
    devices = _parse_devices(args.devices)

    # Shallow copy: the defaults dict itself is left untouched by update().
    config = const.default.model_config.copy()
    config.update(
        {
            "es": const.default.es,
            "max_ep": const.default.max_epoch,
            "min_ep": const.default.min_epoch,
            "log_dir": args.log_dir,
            "litdata": args.litdata,
            "bs": args.bs,
            "lr": args.lr,
            "chunk_size": args.chunk_size,
            "is_regression": args.is_regression,
            "accelerator": args.accelerator,
            "devices": devices,
            "input_node_emb_dim": args.input_node_emb_dim,
            "acc_grad_batch": args.acc_grad_batch,
            "guide_gat": guide_gat,
        }
    )
    print(f"\n🔧Configuration⚙️\n{config}\n")

    # Very short --em values (such as the default "") are treated as "no checkpoint".
    ckpt = args.em if len(args.em) >= 3 else None
    focus = None if args.focus == "None" else args.focus

    trainer = DeepTANTune(config, ckpt, focus)
    if args.auto_tune:
        trainer.optimize(n_trials=args.ntrials, n_jobs=args.njobs)
    else:
        trainer._train_on_args()
@@ -0,0 +1,113 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import pickle
5
+ import shutil
6
+
7
+ import litdata
8
+ import numpy as np
9
+ import polars as pl
10
+
11
+ import deeptan.constants as const
12
+ from deeptan.utils.data import DeepTANDataModule, read_nmic_npz
13
+
14
+
15
def _optimize_split(fn, indices, output_dir, n_workers):
    """Serialize one data split with :func:`litdata.optimize` using the package defaults.

    ``fn`` is applied to each element of ``indices``. Chunk size and compression
    come from ``const.default``, and the worker count is capped at
    ``const.default.n_threads``.
    """
    litdata.optimize(
        fn=fn,
        inputs=indices,
        output_dir=output_dir,
        chunk_bytes=const.default.lit_chunk_bytes,
        compression=const.default.lit_compression,
        num_workers=min(n_workers, const.default.n_threads),
    )


def _select_train_indices(trn_npz, obs_parquet):
    """Return positions of the training observations listed in ``obs_parquet``.

    Positions refer to the observation order stored in ``trn_npz`` (as returned
    by :func:`read_nmic_npz`) and come back in ascending order. Names in the
    parquet's ``obs_names`` column that are absent from the .npz are ignored.

    NOTE(review): the positions are later used as indices into
    ``datamodule.train``, which assumes that split keeps the .npz order — confirm.

    Raises:
        ValueError: if none of the requested observations are present, since
            there would be nothing to optimize.
    """
    _, _, _, _, obs_names, _ = read_nmic_npz(trn_npz)
    wanted = pl.read_parquet(obs_parquet)["obs_names"].to_list()
    # np.isin already restricts to names present in both collections, so no
    # explicit set intersection is needed.
    indices = np.where(np.isin(obs_names, wanted))[0].tolist()
    if not indices:
        raise ValueError(f"None of the observations listed in {obs_parquet} were found in {trn_npz}.")
    return indices


def deeptan_litdata():
    """Command-line entry point that converts DeepTAN train/val/test data into litdata format.

    Loads the three splits through :class:`DeepTANDataModule`. It then writes
    ``dict_node_names`` and ``output_g_label_dim`` to ``--output_dir`` as both
    JSON and pickle, and copies the ``--labels`` parquet there if given. Finally
    each split is serialized with :func:`litdata.optimize` into
    ``<output_dir>/<split abbreviation>``. ``--onlytest`` skips the training
    split; ``--in_obs`` restricts the training split to the listed observations.
    """
    parser = argparse.ArgumentParser(description="Optimize data for Deeptan model")
    parser.add_argument("--labels", type=str, default="", help="Path to label data in .parquet format")
    parser.add_argument("--bs", type=int, default=const.default.bs, help="Batch size for training")
    parser.add_argument("--trn_npz", type=str, required=True, help="Path to training data in .npz format (generated by mi2graph)")
    parser.add_argument("--val_parquet", type=str, required=True, help="Path to validation data in .parquet format")
    parser.add_argument("--tst_parquet", type=str, required=True, help="Path to test data in .parquet format")
    parser.add_argument("--output_dir", type=str, default=".tmp_data_optimized", help="Output directory for the optimized litdata splits and metadata files")
    parser.add_argument("--thre_mi", type=float, default=const.default.threshold_nmic, help="Threshold for edge attribute")
    parser.add_argument("--in_feat", type=str, default="", help="Path to a .csv with header, containing a list of features to specify. If None, all features are used")
    parser.add_argument("--in_obs", type=str, default="", help="Path to a .parquet file with an 'obs_names' column listing the training observations to keep. If empty, all training observations are used")
    parser.add_argument("--onlytest", action="store_true", help="Skip optimizing the training split (validation and test splits are still written)")
    parser.add_argument("--n_workers", type=int, default=const.default.n_threads, help="Number of workers for litdata optimization (capped at the default thread count)")
    args = parser.parse_args()

    # Path arguments shorter than 2 characters (e.g. the default "") mean "not provided".
    labels = args.labels if len(args.labels) >= 2 else None
    specify_features = args.in_feat if len(args.in_feat) >= 2 else None
    specify_trn_obs = args.in_obs if len(args.in_obs) >= 2 else None

    files_fit = {
        const.dkey.abbr_train: args.trn_npz,
        const.dkey.abbr_val: args.val_parquet,
        const.dkey.abbr_test: args.tst_parquet,
    }
    datamodule = DeepTANDataModule(
        files_fit,
        labels,
        batch_size=args.bs,
        edge_attr_threshold=args.thre_mi,
        specify_features=specify_features,
    )
    datamodule.setup()

    # Save node names and the label dimension alongside the optimized data.
    os.makedirs(args.output_dir, exist_ok=True)
    others2save = {
        "dict_node_names": datamodule.dict_node_names,
        "output_g_label_dim": datamodule.label_dim,
    }
    with open(os.path.join(args.output_dir, const.fname.litdata_others2save_json), "w") as f:
        json.dump(others2save, f)
    with open(os.path.join(args.output_dir, const.fname.litdata_others2save_pkl), "wb") as f:
        pickle.dump(others2save, f)

    if labels is not None:
        shutil.copy(labels, os.path.join(args.output_dir, const.fname.label_class_onehot))

    if not args.onlytest:
        if specify_trn_obs is not None:
            trn_indices = _select_train_indices(args.trn_npz, specify_trn_obs)
        else:
            trn_indices = list(range(datamodule.train.len()))
        _optimize_split(
            datamodule.train.get,
            trn_indices,
            os.path.join(args.output_dir, const.dkey.abbr_train),
            args.n_workers,
        )
    _optimize_split(
        datamodule.val.get,
        list(range(datamodule.val.len())),
        os.path.join(args.output_dir, const.dkey.abbr_val),
        args.n_workers,
    )
    _optimize_split(
        datamodule.test.get,
        list(range(datamodule.test.len())),
        os.path.join(args.output_dir, const.dkey.abbr_test),
        args.n_workers,
    )
@@ -0,0 +1,31 @@
1
+ import argparse
2
+ import os
3
+
4
+ from deeptan.graph.recon import predict_perturbation
5
+
6
+
7
def deeptan_perturb():
    """Command-line entry point that runs perturbation prediction with a trained DeepTAN model.

    Forwards the checkpoint (``--em``), the litdata directory (``--litdata``),
    the output path (``--output``), the map location and the overwrite flag to
    :func:`deeptan.graph.recon.predict_perturbation`. The parent directory of
    ``--output`` is created first when the path has one.
    """
    parser = argparse.ArgumentParser(description="DeepTAN perturbation script.")
    parser.add_argument("--em", type=str, required=True, help="Existing model checkpoint path.")
    parser.add_argument("--litdata", "--data", type=str, required=True, help="Path to litdata directory")
    parser.add_argument("--output", "--out", type=str, required=True, help="Path to output file")
    parser.add_argument("--maplocation", "--maploc", type=str, default=None, help="Map location for model loading")
    parser.add_argument("--overwrite", action="store_true", help="Overwrite existing output directory")
    # Previously hard-coded to 5; exposed so callers can change it (default unchanged).
    parser.add_argument("--n_perturbations", type=int, default=5, help="Value passed as n_perturbations to predict_perturbation (default: 5)")
    args = parser.parse_args()

    model_path = args.em
    litdata_dir = args.litdata

    # Create the output directory. A bare file name has no directory part, and
    # os.makedirs("") raises FileNotFoundError, so only create a non-empty one.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Predicting with model {model_path} on data {litdata_dir}")
    predict_perturbation(
        model_ckpt_path=model_path,
        litdata_dir=litdata_dir,
        output_path=args.output,
        n_perturbations=args.n_perturbations,
        map_location=args.maplocation,
        overwrite_files=args.overwrite,
    )
@@ -0,0 +1,41 @@
1
+ import argparse
2
+ import os
3
+
4
+ from deeptan.utils.uni import convert_pickle_to_h5
5
+
6
+
7
def _h5_path(pkl_path):
    """Return ``pkl_path`` with its final extension replaced by ``.h5``.

    Only the last path component's suffix is touched, so directory names that
    contain ".pkl" are preserved (``"a.pkl_dir/x.pkl"`` -> ``"a.pkl_dir/x.h5"``),
    and a non-.pkl input never maps onto itself (``"x.pickle"`` -> ``"x.h5"``).
    """
    return os.path.splitext(pkl_path)[0] + ".h5"


def _convert_if_needed(input_pkl, output_h5, force):
    """Convert ``input_pkl`` into ``output_h5`` unless the target already exists.

    Returns True if :func:`convert_pickle_to_h5` was called. Returns False (after
    printing a skip message) if ``output_h5`` exists and ``force`` is off.

    Raises:
        ValueError: if ``output_h5`` is the input file itself, which would
            otherwise be overwritten.
    """
    if os.path.abspath(output_h5) == os.path.abspath(input_pkl):
        raise ValueError(f"Output path would overwrite the input file: {input_pkl}")
    if os.path.exists(output_h5) and not force:
        print(f"File already exists: {output_h5}. Skipping.")
        return False
    convert_pickle_to_h5(input_pkl, output_h5)
    return True


def convert_pkl_to_h5():
    """Command-line entry point converting pickle files to HDF5.

    ``-i`` may be a single pickle file or a directory, which is walked
    recursively for ``*.pkl`` files. Without ``-o`` each output is written next
    to its input with the extension replaced by ``.h5``. With ``-o`` and a
    directory input, all outputs go flat into that directory (created if
    missing). Same-named files from different sub-directories therefore map to
    the same output path. Existing outputs are skipped unless ``-f`` is given.
    """
    parser = argparse.ArgumentParser(description="Convert a pickle file to HDF5 format.")
    parser.add_argument("-i", "--input", required=True, help="Path to a pickle file or a directory containing pickle files.")
    parser.add_argument("-o", "--output", required=False, help="Path to the output HDF5 file or directory. If not provided, output will be in the same location as the input file with .h5 extension.")
    parser.add_argument("-f", "--force", action="store_true", help="Overwrite existing HDF5 files without prompting.")
    args = parser.parse_args()

    if os.path.isdir(args.input):
        if args.output is not None:
            # Writing into a non-existent directory would fail on the first file.
            os.makedirs(args.output, exist_ok=True)
        for root, _, files in os.walk(args.input):
            for file in files:
                if not file.endswith(".pkl"):
                    continue
                input_pkl = os.path.join(root, file)
                if args.output is None:
                    output_h5 = _h5_path(input_pkl)
                else:
                    output_h5 = os.path.join(args.output, _h5_path(file))
                if _convert_if_needed(input_pkl, output_h5, args.force):
                    print(f"Converted {input_pkl} to {output_h5}.")
        return

    if not os.path.isfile(args.input):
        parser.error(f"Input path does not exist: {args.input}")

    if args.output is None:
        output_h5 = _h5_path(args.input)
    elif os.path.isdir(args.output):
        output_h5 = os.path.join(args.output, _h5_path(os.path.basename(args.input)))
    elif args.output.endswith(".pkl"):
        output_h5 = _h5_path(args.output)
    else:
        output_h5 = args.output

    if _convert_if_needed(args.input, output_h5, args.force):
        print(f"Conversion complete: {output_h5}.")
@@ -0,0 +1,36 @@
1
+ import argparse
2
+ import os
3
+
4
+ from deeptan.graph.recon import predict
5
+
6
+
7
def deeptan_predict():
    """Command-line entry point that runs DeepTAN prediction and saves results as HDF5.

    ``.h5`` is appended to ``--output`` when missing, and the parent directory is
    created if the path has one. If the result file already exists and
    ``--overwrite`` is not given, nothing is recomputed. Otherwise
    :func:`deeptan.graph.recon.predict` is called with ``save_h5=True``.
    """
    parser = argparse.ArgumentParser(description="DeepTAN prediction script.")
    parser.add_argument("--em", type=str, required=True, help="Existing model checkpoint path.")
    parser.add_argument("--litdata", "--data", type=str, required=True, help="Path to litdata directory")
    parser.add_argument("--output", "--out", type=str, required=True, help="Path to output file")
    parser.add_argument("--bs", "--batch_size", type=int, default=8, help="Batch size for prediction")
    parser.add_argument("--maplocation", "--maploc", type=str, default=None, help="Map location for model loading")
    parser.add_argument("--overwrite", action="store_true", help="Overwrite existing output directory")
    args = parser.parse_args()

    model_path = args.em
    litdata_dir = args.litdata
    output_path = args.output if args.output.endswith(".h5") else f"{args.output}.h5"

    # Create the output directory. os.makedirs("") raises FileNotFoundError, so
    # skip it when the output is a bare file name in the current directory.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    if os.path.exists(output_path) and not args.overwrite:
        print(f"Results already exist at {output_path}")
        return

    print(f"Predicting with model {model_path} on data {litdata_dir}")
    predict(
        model_ckpt_path=model_path,
        litdata_dir=litdata_dir,
        output_path=output_path,
        # argparse stores "--maplocation" under the dest "maplocation".
        map_location=args.maplocation,
        batch_size=args.bs,
        save_h5=True,
    )
deeptan/cli/hello.py ADDED
@@ -0,0 +1,6 @@
1
+ import deeptan.constants as const
2
+
3
+
4
def hello():
    """Print the DeepTAN ASCII-art banner, then a short greeting, to stdout."""
    banner = const.art.ascii_art
    greeting = "\nHello, DeepTAN!\n"
    # sep="\n" plus print's own trailing newline reproduces two separate print calls.
    print(banner, greeting, sep="\n")
@@ -0,0 +1,32 @@
1
+ import argparse
2
+ import sys
3
+
4
+ import scanpy as sc
5
+
6
+
7
def print_sc_h5():
    """Print a summary of an AnnData file given by ``-p/--path``.

    ``.h5ad`` files are read with :func:`scanpy.read_h5ad` and ``.h5`` files with
    :func:`scanpy.read_10x_h5` (extension matched case-insensitively). Any other
    extension prints an error and exits with status 1.

    The output includes the AnnData object, every ``obs`` and ``var`` column, and
    the type of ``obs["Celltype"]`` if present. If ``obs`` has ``"Orig.ident"``
    and a UMAP embedding is stored in ``obsm``, a UMAP coloured by it is plotted.
    """
    parser = argparse.ArgumentParser(description="Read an H5AD or H5 file and print its contents.")
    # Required: every code path below dereferences the path.
    parser.add_argument("-p", "--path", type=str, required=True, help="Path to the H5AD or H5 file.")
    args = parser.parse_args()

    lowered = args.path.lower()
    if lowered.endswith(".h5ad"):
        adata = sc.read_h5ad(args.path)
    elif lowered.endswith(".h5"):
        adata = sc.read_10x_h5(args.path)
    else:
        print("Unsupported file format. Please provide a .h5ad or .h5 file.")
        sys.exit(1)

    print(adata)
    print("\nobs keys: ")
    print(adata.obs.keys())
    for key in adata.obs.keys():
        print(f"\n{key}:")
        print(adata.obs[key])
    print("\nvar keys: ")
    for key in adata.var_keys():
        print(f"\n{key}:")
        print(adata.var[key])
    if "Celltype" in adata.obs.keys():
        print("\nCelltype:")
        print(type(adata.obs["Celltype"]))
    if "Orig.ident" in adata.obs.keys():
        # sc.pl.umap looks up the embedding under obsm["umap"] or obsm["X_umap"]
        # and raises if neither exists, so check first instead of crashing.
        if "X_umap" in adata.obsm or "umap" in adata.obsm:
            sc.pl.umap(adata, color="Orig.ident", title="UMAP by Experiment")
        else:
            print("\nNo UMAP embedding found in obsm; skipping UMAP plot.")
@@ -0,0 +1,9 @@
1
+ r"""
2
+ Constants.
3
+ """
4
+
5
+ import deeptan.constants.art as art
6
+ import deeptan.constants.default as default
7
+ import deeptan.constants.dict_key as dkey
8
+ import deeptan.constants.fname as fname
9
+ import deeptan.constants.hparam_candidates as hparam_candidates
@@ -0,0 +1,9 @@
1
# ASCII-art banner; printed by the `hello` CLI command via `const.art.ascii_art`.
# Whitespace inside the raw string is part of the banner, so edit with care.
+ ascii_art = r"""
2
+ ____ _________ _ __
3
+ / __ \___ ___ ____/_ __/ | / | / /
4
+ / / / / _ \/ _ \/ __ \/ / / /| | / |/ /
5
+ / /_/ / __/ __/ /_/ / / / ___ |/ /| /
6
+ /_____/\___/\___/ .___/_/ /_/ |_/_/ |_/
7
+ /_/
8
+
9
+ """
@@ -0,0 +1,75 @@
1
+ r"""
2
+ Default values.
3
+ """
4
+
5
+ from multiprocessing import cpu_count
6
+ from os import getenv
7
+
8
+ from numpy import ceil
9
+
10
# --- Training loop ----------------------------------------------------------
# bs / accumulate_grad_batches / lr are the CLI defaults for --bs, --acc_grad_batch
# and --lr. es / min_epoch / max_epoch feed the "es", "min_ep" and "max_ep" config keys.
bs = 8
accumulate_grad_batches = 4
lr = 0.0002
es = 5
min_epoch = 3
max_epoch = 50

# --- Model architecture -----------------------------------------------------
dropout = 0.2
negative_slope = 0.2
node_emb_dim = 128
g_emb_dim = 192
label_pred_hidden_dims = [1024, 256]
fusion_dims_node_emb = [256, 128]
# All attention blocks share the same default head count.
n_heads_pooling = n_heads_node_emb = n_heads_ge_decoder = n_heads_label_pred = 4
n_hop = 1

# --- Memory / chunking ------------------------------------------------------
chunk_size = 256  # default for the fit CLI's --chunk_size
mem_safety_factor = 0.6
operation_overhead = 2.5

# --- Graph thresholds -------------------------------------------------------
threshold_nmic = 0.01  # default edge-attribute threshold (--thre_mi in the litdata CLI)
threshold_subg_overlap = 0.85
threshold_edge_exist = 0.03

# --- Hardware / precision ---------------------------------------------------
matmul_precision = "high"
accelerator = "auto"
devices = "auto"
precision = "32-true"
gradient_clip_val = 1.0

# Thread budget: the NUM_THREADS environment variable wins; otherwise ~90% of
# multiprocessing.cpu_count(), rounded up.
_cpu_share = ceil(cpu_count() * 0.9)
n_threads = int(getenv("NUM_THREADS", _cpu_share))

# --- Run bookkeeping / Optuna -----------------------------------------------
time_format = "%Y%m%d%H%M%S"
time_delay = 11.7
ckpt_fname_format = "best_model"
optuna_db = "sqlite:///optuna.db"
n_jobs = 1  # default for the fit CLI's --njobs
n_trials = 30  # default for the fit CLI's --ntrials
n_workers = 1

# --- litdata ------------------------------------------------------------------
lit_chunk_bytes = "256MB"  # chunk_bytes passed to litdata.optimize
lit_compression = "zstd"  # compression passed to litdata.optimize
lit_max_cache_size = "26GB"

# Base model configuration; the fit CLI copies it and overlays CLI values.
# Note: list values are shared with the module-level lists above (no copy).
model_config = dict(
    guide_gat=True,
    class_weights=None,
    use_focal_loss=True,
    focal_alpha=None,
    node_emb_dim=node_emb_dim,
    fusion_dims_node_emb=fusion_dims_node_emb,
    output_dim_g_emb=g_emb_dim,
    n_hop=n_hop,
    threshold_edge_exist=threshold_edge_exist,
    threshold_subgraph_overlap=threshold_subg_overlap,
    n_heads_node_emb=n_heads_node_emb,
    n_heads_pooling=n_heads_pooling,
    n_heads_ge_decoder=n_heads_ge_decoder,
    n_heads_label_pred=n_heads_label_pred,
    dropout=dropout,
    lr=lr,
    chunk_size=chunk_size,
    n_workers=n_workers,
)
@@ -0,0 +1,98 @@
1
+ r"""
2
+ Dictionary keys and column names.
3
+ """
4
+
5
# Long split names.
title_train = "train"
title_val = "val"
title_test = "test"
title_predict = "pred"

# Short split names. The litdata CLI uses them both as keys of its input-file
# mapping and as sub-directory names for the optimized splits.
abbr_train = "trn"
abbr_val = "val"
abbr_test = "tst"

splits = [abbr_train, abbr_val, abbr_test]

# "<split>/loss" for each short split name: "trn/loss", "val/loss", "tst/loss".
title_trn_loss, title_val_loss, title_tst_loss = (f"{split}/loss" for split in splits)

# Metric tags to pick out (the "tsb" prefix suggests TensorBoard scalars — confirm).
# Test tags list reconstruction, labelling, then loss metrics; validation tags
# list loss, reconstruction, then labelling metrics.
_recon_metrics = [f"recon_{name}" for name in ("MSE", "RMSE", "MAE", "PCC")]
_label_metrics = [
    f"label_{name}"
    for name in (
        "MSE",
        "RMSE",
        "MAE",
        "PCC",
        "F1_weighted",
        "F1_macro",
        "F1_micro",
        "AUROC",
        "Accuracy",
        "Precision",
        "Recall",
    )
]
_loss_metrics = ["loss", "loss_unweighted"]
tsb_keys2pick = [f"{title_test}/{m}" for m in _recon_metrics + _label_metrics + _loss_metrics] + [
    f"{title_val}/{m}" for m in _loss_metrics + _recon_metrics + _label_metrics
]

# Human-readable display names for metric keys (presumably for plots/tables).
title_metric_mapping = {
    "jsd": "1 - JSD",
    "mae": "1 - MAE",
    "mse": "1 - MSE",
    "pcc": "PCC",
    "spearman": "Spearman",
    "weighted_recall": "Recall (weighted)",
    "weighted_precision": "Precision (weighted)",
    "weighted_f1": "F1 Score (weighted)",
    "macro_f1": "F1 Score (macro)",
    "micro_f1": "F1 Score (micro)",
    "auprc": "AUPRC",
    "auroc": "AUROC",
    "accuracy": "ACC",
    "kbet_true_label": "kBET (true labels)",
    "kbet_pred_label": "kBET (predicted labels)",
    "asw_true_label": "ASW (true labels)",
    "asw_pred_label": "ASW (predicted labels)",
    "kbet": "kBET",
    "asw": "ASW",
    "ari": "ARI",
    "nmi": "NMI",
    "ami": "AMI",
    "ari_leiden": "ARI (Leiden)",
    "nmi_leiden": "NMI (Leiden)",
    "ami_leiden": "AMI (Leiden)",
    "method": "Method",
    "train_size": "Train size",
    "n_feat": "Number of features for test",
}

# Display names for training-task variants.
title_task_mapping = {
    "multitask": "Multitask",
    "multitask_noguide": "Multitask (no SGG)",
    "focus_recon": "Focus on reconstruction",
    "focus_label": "Focus on labelling",
    "focus_label_on_focus_recon": "Fine-tuned on reconstruction",
}

# Display names for generic column headers.
title_colnameC2_mapping = {
    "task": "Task",
    "metric": "Metric",
    "value": "Value",
}
@@ -0,0 +1,7 @@
1
+ r"""
2
+ File names.
3
+ """
4
+
5
# Metadata file written by the litdata CLI into its output directory, saved in
# two formats that share one stem.
_others2save_stem = "others2save"
litdata_others2save_pkl = f"{_others2save_stem}.pkl"
litdata_others2save_json = f"{_others2save_stem}.json"

# File name given to the copy of the --labels parquet placed in the litdata output directory.
label_class_onehot = "label_onehot.parquet"
@@ -0,0 +1,8 @@
1
+ r"""
2
+ Hyperparameter candidates.
3
+ """
4
+
5
# Candidate values for hyperparameter search (presumably sampled during tuning —
# the consumer is not visible here).
batch_size = [2**exponent for exponent in range(4, 8)]  # 16, 32, 64, 128
lr = [5e-3, 1e-3, 5e-4, 1e-4, 5e-5]
# Upper bound and step for the dropout search range.
dropout_high, dropout_step = 0.8, 0.2
@@ -0,0 +1,5 @@
1
+ r"""
2
+ The backbone graph module.
3
+ """
4
+
5
+ from deeptan.graph import model, modules, recon