ONTraC 0.0.4b4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ONTraC/__init__.py +0 -0
- ONTraC/__pycache__/__init__.cpython-311.pyc +0 -0
- ONTraC/__pycache__/__init__.cpython-312.pyc +0 -0
- ONTraC/bin/GP.py +92 -0
- ONTraC/bin/NTScore.py +46 -0
- ONTraC/bin/ONTraC.py +109 -0
- ONTraC/bin/__init__.py +0 -0
- ONTraC/bin/createDataSet.py +40 -0
- ONTraC/data.py +102 -0
- ONTraC/log.py +41 -0
- ONTraC/model/__init__.py +1 -0
- ONTraC/model/_model.py +152 -0
- ONTraC/model/dmon_exp_pool.py +168 -0
- ONTraC/model/norm_dense_gcn_conv.py +89 -0
- ONTraC/optparser/_GP.py +63 -0
- ONTraC/optparser/_IO.py +104 -0
- ONTraC/optparser/_NT.py +49 -0
- ONTraC/optparser/_ONTraC.py +81 -0
- ONTraC/optparser/__init__.py +4 -0
- ONTraC/optparser/_create_dataset.py +88 -0
- ONTraC/optparser/_train.py +235 -0
- ONTraC/run/processes.py +212 -0
- ONTraC/train/__init__.py +1 -0
- ONTraC/train/_batch_train.py +254 -0
- ONTraC/train/inspect_funcs.py +180 -0
- ONTraC/train/loss_funs.py +178 -0
- ONTraC/utils/NTScore.py +120 -0
- ONTraC/utils/__init__.py +1 -0
- ONTraC/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- ONTraC/utils/__pycache__/__init__.cpython-312.pyc +0 -0
- ONTraC/utils/__pycache__/_utils.cpython-311.pyc +0 -0
- ONTraC/utils/__pycache__/_utils.cpython-312.pyc +0 -0
- ONTraC/utils/_utils.py +85 -0
- ONTraC/utils/decorators.py +90 -0
- ONTraC/utils/niche_net_constr.py +176 -0
- ONTraC/version.py +1 -0
- ONTraC-0.0.4b4.dist-info/LICENSE +21 -0
- ONTraC-0.0.4b4.dist-info/METADATA +166 -0
- ONTraC-0.0.4b4.dist-info/RECORD +42 -0
- ONTraC-0.0.4b4.dist-info/WHEEL +5 -0
- ONTraC-0.0.4b4.dist-info/entry_points.txt +5 -0
- ONTraC-0.0.4b4.dist-info/top_level.txt +1 -0
ONTraC/__init__.py
ADDED
|
File without changes
|
|
Binary file
|
|
Binary file
|
ONTraC/bin/GP.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
|
|
3
|
+
import random
|
|
4
|
+
import sys
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from ONTraC.model import GraphPooling
|
|
10
|
+
from ONTraC.optparser import opt_GP_validate, prepare_GP_optparser
|
|
11
|
+
from ONTraC.run.processes import *
|
|
12
|
+
from ONTraC.train import GPBatchTrain, SubBatchTrainProtocol
|
|
13
|
+
from ONTraC.train.inspect_funcs import loss_record
|
|
14
|
+
from ONTraC.utils import device_validate
|
|
15
|
+
|
|
16
|
+
# ------------------------------------
|
|
17
|
+
# Classes
|
|
18
|
+
# ------------------------------------
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# ------------------------------------
|
|
22
|
+
# Functions
|
|
23
|
+
# ------------------------------------
|
|
24
|
+
def get_inspect_funcs() -> Optional[list[Callable]]:
    """
    Build the list of inspection callbacks handed to the training loop.

    The function takes no arguments. It currently returns a single callback,
    ``loss_record`` (imported from ``ONTraC.train.inspect_funcs``).

    :return: list of inspect functions. Never ``None`` at present; the
        ``Optional`` return type is kept for interface compatibility.
    """
    return [loss_record]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# ------------------------------------
|
|
35
|
+
# Main Function
|
|
36
|
+
# ------------------------------------
|
|
37
|
+
def main() -> None:
    """
    Entry point of the graph-pooling (GP) command.

    Steps, in order:

    - Load and validate the GP options.
    - Load the dataset.
    - Seed the ``random``, ``numpy`` and ``torch`` generators with ``options.seed``.
    - Train a ``GraphPooling`` model, then evaluate it.
    - Predict into ``options.GNN_dir``.
    - Compute the NT score when both predicted arrays are available.

    :return: None
    """
    model_name = 'GraphPooling'

    # ----- prepare -----
    options = load_parameters(opt_validate_func=opt_GP_validate,
                              prepare_optparser_func=prepare_GP_optparser)
    device: torch.device = device_validate(device_name=options.device)
    dataset, sample_loader = load_data(options=options)

    # A single user-supplied seed drives all three random number generators.
    seed = options.seed
    random.seed(a=seed)
    torch.manual_seed(seed=seed)
    np.random.seed(seed=seed)

    # ----- train -----
    batch_train: SubBatchTrainProtocol = train(nn_model=GraphPooling,
                                               options=options,
                                               BatchTrain=GPBatchTrain,
                                               device=device,
                                               dataset=dataset,
                                               sample_loader=sample_loader,
                                               inspect_funcs=get_inspect_funcs(),
                                               model_name=model_name)

    # ----- evaluate -----
    evaluate(batch_train=batch_train, model_name=model_name)

    # ----- predict -----
    s_array, out_adj_array = predict(output_dir=options.GNN_dir,
                                     batch_train=batch_train,
                                     dataset=dataset,
                                     model_name=model_name)

    # ----- Pseudotime -----
    # Skip scoring when predict() handed back None for either array.
    if s_array is None or out_adj_array is None:
        return
    NTScore(options=options,
            dataset=dataset,
            consolidate_s_array=s_array,
            consolidate_out_adj_array=out_adj_array)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# ------------------------------------
|
|
85
|
+
# Program running
|
|
86
|
+
# ------------------------------------
|
|
87
|
+
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C is not treated as a failure: say goodbye on stderr and exit with status 0.
        sys.stderr.write("User interrupts me! ;-) See you ^.^!\n")
        raise SystemExit(0)
|
ONTraC/bin/NTScore.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from ONTraC.data import load_dataset
|
|
8
|
+
from ONTraC.optparser import opt_NT_validate, prepare_NT_optparser
|
|
9
|
+
from ONTraC.run.processes import *
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# ------------------------------------
|
|
13
|
+
# Main Function
|
|
14
|
+
# ------------------------------------
|
|
15
|
+
def main() -> None:
    """
    Entry point of the NT score command.

    Steps, in order:

    - Validate the NT options.
    - Load the preprocessed dataset.
    - Read the consolidated arrays ``consolidate_s.csv.gz`` and
      ``consolidate_out_adj.csv.gz`` from ``options.GNN_dir``.
    - Compute the NT score from those arrays.

    :return: None
    """
    # ----- prepare -----
    options = load_parameters(opt_validate_func=opt_NT_validate,
                              prepare_optparser_func=prepare_NT_optparser)
    # Only the dataset is needed; the collated batch is discarded.
    dataset, _unused_batch = load_dataset(options=options)
    consolidate_s_array = np.loadtxt(fname=f'{options.GNN_dir}/consolidate_s.csv.gz', delimiter=',')
    consolidate_out_adj_array = np.loadtxt(fname=f'{options.GNN_dir}/consolidate_out_adj.csv.gz', delimiter=',')

    # ----- Pseudotime -----
    if consolidate_s_array is None or consolidate_out_adj_array is None:
        return
    NTScore(options=options,
            dataset=dataset,
            consolidate_s_array=consolidate_s_array,
            consolidate_out_adj_array=consolidate_out_adj_array)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# ------------------------------------
|
|
39
|
+
# Program running
|
|
40
|
+
# ------------------------------------
|
|
41
|
+
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C is not treated as a failure: say goodbye on stderr and exit with status 0.
        sys.stderr.write("User interrupts me! ;-) See you ^.^!\n")
        raise SystemExit(0)
|
ONTraC/bin/ONTraC.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
|
|
3
|
+
import random
|
|
4
|
+
import sys
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
from ONTraC.log import *
|
|
11
|
+
from ONTraC.model import GraphPooling
|
|
12
|
+
from ONTraC.optparser import opt_ontrac_validate, prepare_ontrac_optparser
|
|
13
|
+
from ONTraC.run.processes import *
|
|
14
|
+
from ONTraC.train import GPBatchTrain, SubBatchTrainProtocol
|
|
15
|
+
from ONTraC.train.inspect_funcs import loss_record
|
|
16
|
+
from ONTraC.utils import device_validate
|
|
17
|
+
from ONTraC.utils.niche_net_constr import (construct_niche_network,
|
|
18
|
+
gen_samples_yaml,
|
|
19
|
+
load_original_data)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# ------------------------------------
|
|
23
|
+
# Functions
|
|
24
|
+
# ------------------------------------
|
|
25
|
+
def get_inspect_funcs() -> Optional[list[Callable]]:
    """
    Build the list of inspection callbacks handed to the training loop.

    The function takes no arguments. It currently returns a single callback,
    ``loss_record`` (imported from ``ONTraC.train.inspect_funcs``).

    :return: list of inspect functions. Never ``None`` at present; the
        ``Optional`` return type is kept for interface compatibility.
    """
    return [loss_record]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# ------------------------------------
|
|
36
|
+
# Main Function
|
|
37
|
+
# ------------------------------------
|
|
38
|
+
def main() -> None:
    """
    Entry point of the end-to-end ONTraC command.

    Input data files information should be stored in a YAML file.

    Three stages run in order:

    1. Niche network construction: load the original data, define the edges
       for each sample and save ``samples.yaml``.
    2. Graph pooling: seed the RNGs, then train, evaluate and predict with a
       ``GraphPooling`` model, and write the niche-cluster outputs.
    3. NT score computation from the consolidated predictions.

    :return: None
    """
    options = load_parameters(opt_validate_func=opt_ontrac_validate,
                              prepare_optparser_func=prepare_ontrac_optparser)
    model_name = 'GraphPooling'

    # ----- Niche Network Construct -----
    ori_data_df = load_original_data(options=options)
    construct_niche_network(options=options, ori_data_df=ori_data_df)  # edges for each sample
    gen_samples_yaml(options=options, ori_data_df=ori_data_df)  # save samples.yaml

    # ----- Graph Pooling -----
    device: torch.device = device_validate(device_name=options.device)
    dataset, sample_loader = load_data(options=options)

    # A single user-supplied seed drives all three random number generators.
    seed = options.seed
    random.seed(a=seed)
    torch.manual_seed(seed=seed)
    np.random.seed(seed=seed)

    batch_train: SubBatchTrainProtocol = train(nn_model=GraphPooling,
                                               options=options,
                                               BatchTrain=GPBatchTrain,
                                               device=device,
                                               dataset=dataset,
                                               sample_loader=sample_loader,
                                               inspect_funcs=get_inspect_funcs(),
                                               model_name=model_name)
    evaluate(batch_train=batch_train, model_name=model_name)
    s_array, out_adj_array = predict(output_dir=options.GNN_dir,
                                     batch_train=batch_train,
                                     dataset=dataset,
                                     model_name=model_name)

    # niche cluster outputs need only the assignment array
    if s_array is not None:
        samples_params = read_yaml_file(f'{options.preprocessing_dir}/samples.yaml')
        graph_pooling_output(ori_data_df=ori_data_df,
                             dataset=dataset,
                             rel_params=get_rel_params(options=options, params=samples_params),
                             consolidate_s_array=s_array,
                             output_dir=options.GNN_dir)

    # ----- NT score -----
    if s_array is None or out_adj_array is None:
        return
    NTScore(options=options,
            dataset=dataset,
            consolidate_s_array=s_array,
            consolidate_out_adj_array=out_adj_array)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
# ------------------------------------
|
|
102
|
+
# Program running
|
|
103
|
+
# ------------------------------------
|
|
104
|
+
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C is not treated as a failure: say goodbye on stderr and exit with status 0.
        sys.stderr.write("User interrupts me! ;-) See you ^.^!\n")
        raise SystemExit(0)
|
ONTraC/bin/__init__.py
ADDED
|
File without changes
|
ONTraC/bin/createDataSet.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
|
|
5
|
+
from ONTraC.log import *
|
|
6
|
+
from ONTraC.optparser import (opt_create_ds_validate, prepare_create_ds_optparser)
|
|
7
|
+
from ONTraC.utils.niche_net_constr import load_original_data, construct_niche_network, gen_samples_yaml
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# ------------------------------------
|
|
11
|
+
# Main Function
|
|
12
|
+
# ------------------------------------
|
|
13
|
+
def main() -> None:
    """
    Entry point of the dataset-creation command.

    Input data files information should be stored in a YAML file.

    Steps, in order:

    - Validate the options.
    - Load the original data from ``options.dataset``.
    - Define the niche-network edges for each sample.
    - Save ``samples.yaml``.

    :return: None
    """
    # NOTE(review): unlike the other entry points, options are validated directly
    # instead of through ``load_parameters``. Also, ``load_original_data`` gets an
    # explicit ``data_file`` argument that ONTraC/bin/ONTraC.py does not pass.
    # Confirm both call shapes match the current helper signatures.
    parser = prepare_create_ds_optparser()
    options = opt_create_ds_validate(parser)

    ori_data_df = load_original_data(options=options, data_file=options.dataset)
    construct_niche_network(options=options, ori_data_df=ori_data_df)  # edges for each sample
    gen_samples_yaml(options=options, ori_data_df=ori_data_df)  # save samples.yaml
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# ------------------------------------
|
|
33
|
+
# Program running
|
|
34
|
+
# ------------------------------------
|
|
35
|
+
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C is not treated as a failure: say goodbye on stderr and exit with status 0.
        sys.stderr.write("User interrupts me! ;-) See you ^.^!\n")
        raise SystemExit(0)
|
ONTraC/data.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
from optparse import Values
|
|
2
|
+
from typing import Dict, List, Tuple
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import pandas as pd
|
|
6
|
+
import torch
|
|
7
|
+
import torch_geometric.transforms as T
|
|
8
|
+
from torch_geometric.data import Data, InMemoryDataset
|
|
9
|
+
from torch_geometric.loader import DenseDataLoader
|
|
10
|
+
|
|
11
|
+
from .log import *
|
|
12
|
+
from .utils import count_lines, device_validate, get_rel_params, read_yaml_file
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# ------------------------------------
|
|
16
|
+
# Classes
|
|
17
|
+
# ------------------------------------
|
|
18
|
+
class SpatailOmicsDataset(InMemoryDataset):
    """
    In-memory PyTorch Geometric dataset built from per-sample files.

    ``params['Data']`` is a list of sample entries. Each entry provides:

    - ``'Features'``: a comma-separated node feature matrix.
    - ``'EdgeIndex'``: a comma-separated edge list. It is transposed into
      ``edge_index``, so presumably one edge per row.
    - ``'Coordinates'``: a CSV with ``x`` and ``y`` columns.
    - ``'Name'``: the sample name.

    The collated samples are cached as ``data.pt`` in the dataset's processed
    directory (``self.processed_paths[0]``).
    """

    def __init__(self, root, params: Dict, transform=None, pre_transform=None):
        # ``params`` is assigned before the base initializer runs. Presumably the
        # base class calls ``process()`` (which reads it) when the cache is missing.
        self.params = params
        super(SpatailOmicsDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Raw inputs are read directly from the paths in ``self.params`` inside
        # ``process()``, so no raw files are declared here.
        return []

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # Nothing to download: all inputs are local files.
        pass

    def process(self):
        """Read every sample listed in ``self.params['Data']``, collate them and save the result."""
        samples = self.params['Data']
        data_list = []
        for index, sample in enumerate(samples):
            info(f'Processing sample {index + 1} of {len(samples)}')
            # ndmin=2 keeps the matrices 2-D for single-row / single-column files
            # (e.g. a sample with exactly one edge or one feature). Without it,
            # np.loadtxt would squeeze them to 1-D and break ``x`` / ``edge_index``.
            features = np.loadtxt(sample['Features'], dtype=np.float32, delimiter=',', ndmin=2)
            edges = np.loadtxt(sample['EdgeIndex'], dtype=np.int64, delimiter=',', ndmin=2)
            data = Data(
                x=torch.from_numpy(features),
                edge_index=torch.from_numpy(edges).t().contiguous(),
                # TODO: support 3D coordinates
                pos=torch.from_numpy(pd.read_csv(sample['Coordinates'])[['x', 'y']].values),
                name=sample['Name'])
            data_list.append(data)
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# ------------------------------------
|
|
55
|
+
# Misc functions
|
|
56
|
+
# ------------------------------------
|
|
57
|
+
def max_nodes(samples: List[Dict[str, str]]) -> int:
    """
    Get the maximum number of nodes over the samples of a dataset.

    A sample's node count is ``count_lines`` of its ``'Coordinates'`` file.

    :param samples: List[Dict[str, str]], list of samples; each must contain a
        ``'Coordinates'`` file path
    :return: int, maximum number of nodes (0 for an empty list)
    """
    # NOTE(review): the Coordinates file is parsed with a header row by pd.read_csv
    # elsewhere. Whether count_lines skips that header is not visible here; the
    # caller rounds the result up to a multiple of 100 anyway.
    largest = 0  # separate name so the local does not shadow this function
    for sample in samples:
        largest = max(largest, count_lines(sample['Coordinates']))
    return largest
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def load_dataset(options: Values) -> Tuple[SpatailOmicsDataset, Data]:
    """
    Load the preprocessed dataset together with one batch holding every sample.

    Steps, in order:

    - Read ``<options.preprocessing_dir>/samples.yaml`` and resolve it with
      ``get_rel_params``.
    - Build the torch dataset.
    - Collate all samples into a single dense batch.
    - Move that batch to the device returned by ``device_validate()``.

    :param options: Values, parsed options (``preprocessing_dir`` is read)
    :return: Tuple of (dataset, batch containing all samples)
    """
    device = device_validate()
    raw_params = read_yaml_file(f'{options.preprocessing_dir}/samples.yaml')
    dataset = create_torch_dataset(options, get_rel_params(options, raw_params))
    # A batch size equal to the dataset length yields exactly one batch.
    full_batch = next(iter(DenseDataLoader(dataset, batch_size=len(dataset))))
    return dataset, full_batch.to(device)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
# ------------------------------------
|
|
80
|
+
# Flow control functions
|
|
81
|
+
# ------------------------------------
|
|
82
|
+
def create_torch_dataset(options: Values, params: Dict) -> SpatailOmicsDataset:
    """
    Create the torch dataset for the given samples.

    Every graph is densified by ``T.ToDense`` to a common node count. That count
    is the largest sample's node count, rounded up to the next multiple of 100.

    :param options: Values, parsed options; ``preprocessing_dir`` is used as the dataset root
    :param params: Dict, input samples (``params['Data']`` is the list of sample entries)
    :return: SpatailOmicsDataset, the created dataset
    """
    # Step 1: get the maximum number of nodes and round it up to the nearest 100
    m_nodes = max_nodes(params['Data'])
    m_nodes = int(np.ceil(m_nodes / 100.0)) * 100
    info(f'Maximum number of nodes: {m_nodes}')

    # Step 2: create the torch dataset; ToDense transforms edge_index into an adjacency matrix
    return SpatailOmicsDataset(root=options.preprocessing_dir,
                               params=params,
                               transform=T.ToDense(m_nodes))
|
ONTraC/log.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import time
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def get_current_time() -> str:
    """Return the current local time formatted as ``HH:MM:SS``."""
    now = time.localtime()
    return time.strftime('%H:%M:%S', now)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def write_direct_message(message: str):
    """
    Write one time-stamped line to standard output and flush it immediately.

    :param message: text written after the ``HH:MM:SS --- `` prefix
    """
    stamped = f'{get_current_time()} --- {message}\n'
    sys.stdout.write(stamped)
    sys.stdout.flush()
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def debug(message: str):
    """Log ``message`` to stdout with a ``DEBUG:`` tag."""
    tagged = f'DEBUG: {message}'
    write_direct_message(tagged)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def info(message: str):
    """Log ``message`` to stdout with an ``INFO:`` tag."""
    tagged = f'INFO: {message}'
    write_direct_message(tagged)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def write_direct_message_err(message: str):
    """
    Write one time-stamped line to standard error and flush it immediately.

    :param message: text written after the ``HH:MM:SS --- `` prefix
    """
    stamped = f'{get_current_time()} --- {message}\n'
    sys.stderr.write(stamped)
    sys.stderr.flush()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def warning(message: str):
    """Log ``message`` to stderr with a ``WARNING:`` tag."""
    tagged = f'WARNING: {message}'
    write_direct_message_err(tagged)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def error(message: str):
    """Log ``message`` to stderr with an ``ERROR:`` tag."""
    tagged = f'ERROR: {message}'
    write_direct_message_err(tagged)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def critical(message: str):
    """Log ``message`` to stderr with a ``CRITICAL:`` tag."""
    tagged = f'CRITICAL: {message}'
    write_direct_message_err(tagged)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# Names exported by ``from ONTraC.log import *``. The level-tagged loggers are
# public; get_current_time and the write_direct_message* helpers are not exported.
__all__ = ['debug', 'info', 'warning', 'error', 'critical']
|
ONTraC/model/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from ._model import *
|
ONTraC/model/_model.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
from typing import Optional, Tuple
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from ..log import *
|
|
7
|
+
from .dmon_exp_pool import DMoNPooling
|
|
8
|
+
from .norm_dense_gcn_conv import NormDenseGCNConv
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class NodePooling(torch.nn.Module):
    """
    NodePooling: a thin wrapper around ``DMoNPooling``.

    :param input_feats: number of input feature channels, passed to DMoNPooling as ``channels``
    :param k: number of pooled nodes/clusters, forwarded to DMoNPooling
    :param dropout: stored as ``self.dropout`` (see the note in ``__init__``)
    :param exponent: forwarded to DMoNPooling
    """

    def __init__(self, input_feats: int, k: int, dropout: float = 0, exponent: float = 1, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.dropout = dropout
        self.exponent = exponent
        # NOTE(review): ``dropout`` is stored, but DMoNPooling is always built with
        # dropout=0, so the argument currently has no effect. Confirm this is intended.
        self.pool = DMoNPooling(channels=input_feats, k=k, dropout=0, exponent=self.exponent)
        self.k = k

        self.reset_parameters()

    def reset_parameters(self) -> None:
        self.pool.reset_parameters()

    def forward(self,
                x: Tensor,
                adj: Tensor,
                mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        r"""
        forward function

        Args:
            x (torch.Tensor): Node feature tensor
                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
                each graph, and feature dimension :math:`F`.
            adj (torch.Tensor): Adjacency tensor
                :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
            mask (torch.Tensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)
        Returns:
            Tuple of six tensors, exactly as produced by DMoNPooling:
            ``(s, out, out_adj, spectral_loss, ortho_loss, cluster_loss)``.
        """
        s, out, out_adj, spectral_loss, ortho_loss, cluster_loss = self.pool(x=x, adj=adj, mask=mask)
        return s, out, out_adj, spectral_loss, ortho_loss, cluster_loss

    def predict(self, x: Tensor, adj: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
        """Run the pooling and return only ``(s, out, out_adj)``, dropping the loss terms."""
        s, out, out_adj, *_ = self.pool(x=x, adj=adj, mask=mask)
        return s, out, out_adj
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class GraphPooling(torch.nn.Module):
    """
    GNN with Node Pooling.

    Two ``NormDenseGCNConv`` layers, each followed by a SELU activation, encode
    the node features. A ``NodePooling`` layer then pools the encoded nodes into
    ``k`` clusters.
    """

    def __init__(self,
                 input_feats: int,
                 hidden_feats: int,
                 k: int,
                 dropout: float = 0,
                 exponent: float = 1,
                 *args,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Submodules are registered in the same order as before, keeping state_dict layout unchanged.
        self.gcn1 = NormDenseGCNConv(input_feats, hidden_feats)
        self.activation1 = torch.nn.SELU()
        self.gcn2 = NormDenseGCNConv(hidden_feats, hidden_feats)
        self.activation2 = torch.nn.SELU()
        self.pool = NodePooling(input_feats=hidden_feats, k=k, dropout=dropout, exponent=exponent)
        self.k = k

        self.reset_parameters()

    def reset_parameters(self) -> None:
        self.gcn1.reset_parameters()
        self.gcn2.reset_parameters()
        self.pool.reset_parameters()

    def forward(self,
                x: Tensor,
                adj: Tensor,
                mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        r"""
        forward function

        Each GCN layer computes

        .. math::
            X' = \mathbf{\hat{L}} X \mathbf{\Theta}, \quad
            \mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}, \quad
            \mathbf{\hat{D}}_{ii} = \sum_{j=1}^N \mathbf{\hat{A}}_{ij}, \quad
            \mathbf{\hat{L}} = \mathbf{\hat{D}}^{-1/2}\mathbf{\hat{A}}\mathbf{\hat{D}}^{-1/2}

        Args:
            x (torch.Tensor): Node feature tensor
                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
                each graph, and feature dimension :math:`F`.
            adj (torch.Tensor): Adjacency tensor
                :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
            mask (torch.Tensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)
        Returns:
            Tuple of six tensors from the pooling layer:
            ``(s, out, out_adj, spectral_loss, ortho_loss, cluster_loss)``.
        """
        x = self.predict_embed(x=x, adj=adj, mask=mask)
        s, out, out_adj, spectral_loss, ortho_loss, cluster_loss = self.pool(x=x, adj=adj, mask=mask)
        return s, out, out_adj, spectral_loss, ortho_loss, cluster_loss

    def evaluate(self,
                 x: Tensor,
                 adj: Tensor,
                 mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Same computation and outputs as ``forward``, called without module hooks."""
        return self.forward(x=x, adj=adj, mask=mask)

    def predict(self, x: Tensor, adj: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
        r"""
        predict function

        Args:
            x (torch.Tensor): Node feature tensor
                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
                each graph, and feature dimension :math:`F`.
            adj (torch.Tensor): Adjacency tensor
                :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
            mask (torch.Tensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)
        Returns:
            s (torch.Tensor): Node assignment matrix
                :math:`\mathbf{S} \in \mathbb{R}^{B \times N \times K}`
            out (torch.Tensor): Output feature matrix
                :math:`\mathbf{X} \in \mathbb{R}^{B \times K \times H}`
            out_adj (torch.Tensor): Output adjacency matrix
                :math:`\mathbf{A} \in \mathbb{R}^{B \times K \times K}`
        """
        x = self.predict_embed(x=x, adj=adj, mask=mask)
        s, out, out_adj, *_ = self.pool(x=x, adj=adj, mask=mask)
        return s, out, out_adj

    def predict_embed(self, x: Tensor, adj: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """
        Encode node features with the two GCN + SELU layers, without pooling.

        :return: node embeddings produced by ``gcn2`` (``hidden_feats`` channels) after SELU
        """
        x = self.activation1(self.gcn1(x=x, adj=adj, mask=mask))
        x = self.activation2(self.gcn2(x=x, adj=adj, mask=mask))
        return x
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
# Models exported by ``from ONTraC.model import *`` (re-exported via model/__init__.py).
__all__ = ['NodePooling', 'GraphPooling']
|