ONTraC 0.0.4b4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. ONTraC/__init__.py +0 -0
  2. ONTraC/__pycache__/__init__.cpython-311.pyc +0 -0
  3. ONTraC/__pycache__/__init__.cpython-312.pyc +0 -0
  4. ONTraC/bin/GP.py +92 -0
  5. ONTraC/bin/NTScore.py +46 -0
  6. ONTraC/bin/ONTraC.py +109 -0
  7. ONTraC/bin/__init__.py +0 -0
  8. ONTraC/bin/createDataSet.py +40 -0
  9. ONTraC/data.py +102 -0
  10. ONTraC/log.py +41 -0
  11. ONTraC/model/__init__.py +1 -0
  12. ONTraC/model/_model.py +152 -0
  13. ONTraC/model/dmon_exp_pool.py +168 -0
  14. ONTraC/model/norm_dense_gcn_conv.py +89 -0
  15. ONTraC/optparser/_GP.py +63 -0
  16. ONTraC/optparser/_IO.py +104 -0
  17. ONTraC/optparser/_NT.py +49 -0
  18. ONTraC/optparser/_ONTraC.py +81 -0
  19. ONTraC/optparser/__init__.py +4 -0
  20. ONTraC/optparser/_create_dataset.py +88 -0
  21. ONTraC/optparser/_train.py +235 -0
  22. ONTraC/run/processes.py +212 -0
  23. ONTraC/train/__init__.py +1 -0
  24. ONTraC/train/_batch_train.py +254 -0
  25. ONTraC/train/inspect_funcs.py +180 -0
  26. ONTraC/train/loss_funs.py +178 -0
  27. ONTraC/utils/NTScore.py +120 -0
  28. ONTraC/utils/__init__.py +1 -0
  29. ONTraC/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  30. ONTraC/utils/__pycache__/__init__.cpython-312.pyc +0 -0
  31. ONTraC/utils/__pycache__/_utils.cpython-311.pyc +0 -0
  32. ONTraC/utils/__pycache__/_utils.cpython-312.pyc +0 -0
  33. ONTraC/utils/_utils.py +85 -0
  34. ONTraC/utils/decorators.py +90 -0
  35. ONTraC/utils/niche_net_constr.py +176 -0
  36. ONTraC/version.py +1 -0
  37. ONTraC-0.0.4b4.dist-info/LICENSE +21 -0
  38. ONTraC-0.0.4b4.dist-info/METADATA +166 -0
  39. ONTraC-0.0.4b4.dist-info/RECORD +42 -0
  40. ONTraC-0.0.4b4.dist-info/WHEEL +5 -0
  41. ONTraC-0.0.4b4.dist-info/entry_points.txt +5 -0
  42. ONTraC-0.0.4b4.dist-info/top_level.txt +1 -0
ONTraC/__init__.py ADDED
File without changes
ONTraC/bin/GP.py ADDED
@@ -0,0 +1,92 @@
1
+ #!/usr/bin/env python
2
+
3
+ import random
4
+ import sys
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from ONTraC.model import GraphPooling
10
+ from ONTraC.optparser import opt_GP_validate, prepare_GP_optparser
11
+ from ONTraC.run.processes import *
12
+ from ONTraC.train import GPBatchTrain, SubBatchTrainProtocol
13
+ from ONTraC.train.inspect_funcs import loss_record
14
+ from ONTraC.utils import device_validate
15
+
16
+ # ------------------------------------
17
+ # Classes
18
+ # ------------------------------------
19
+
20
+
21
+ # ------------------------------------
22
+ # Functions
23
+ # ------------------------------------
24
def get_inspect_funcs() -> Optional[list[Callable]]:
    """
    Build the list of inspection callbacks handed to the trainer.

    Takes no arguments; the GP entry point currently registers only
    ``loss_record``.
    :return: list of inspect functions (always a list here; the Optional
        return type is kept for interface compatibility)
    """
    return [loss_record]
32
+
33
+
34
+ # ------------------------------------
35
+ # Main Function
36
+ # ------------------------------------
37
def main() -> None:
    """
    Entry point of the GP command.

    Loads options and data, seeds the RNGs, trains and evaluates the
    GraphPooling model, predicts the consolidated arrays and, when both are
    available, computes NT scores from them.
    :return: None
    """
    model_label = 'GraphPooling'

    # ----- prepare -----
    options = load_parameters(opt_validate_func=opt_GP_validate, prepare_optparser_func=prepare_GP_optparser)
    device: torch.device = device_validate(device_name=options.device)
    dataset, sample_loader = load_data(options=options)

    # seed python, torch and numpy RNGs with the same user-supplied value
    seed = options.seed
    random.seed(a=seed)
    torch.manual_seed(seed=seed)
    np.random.seed(seed=seed)

    # ----- train -----
    batch_train: SubBatchTrainProtocol = train(nn_model=GraphPooling,
                                               options=options,
                                               BatchTrain=GPBatchTrain,
                                               device=device,
                                               dataset=dataset,
                                               sample_loader=sample_loader,
                                               inspect_funcs=get_inspect_funcs(),
                                               model_name=model_label)

    # ----- evaluate -----
    evaluate(batch_train=batch_train, model_name=model_label)

    # ----- predict -----
    s_array, out_adj_array = predict(output_dir=options.GNN_dir,
                                     batch_train=batch_train,
                                     dataset=dataset,
                                     model_name=model_label)

    # ----- Pseudotime -----
    # predict() may hand back None for either array; NT scoring needs both.
    if s_array is None or out_adj_array is None:
        return
    NTScore(options=options,
            dataset=dataset,
            consolidate_s_array=s_array,
            consolidate_out_adj_array=out_adj_array)
82
+
83
+
84
+ # ------------------------------------
85
+ # Program running
86
+ # ------------------------------------
87
if __name__ == '__main__':
    # Run the command; on Ctrl+C leave a short note on stderr and exit with status 0.
    try:
        main()
    except KeyboardInterrupt:
        sys.stderr.write("User interrupts me! ;-) See you ^.^!\n")
        raise SystemExit(0)
ONTraC/bin/NTScore.py ADDED
@@ -0,0 +1,46 @@
1
+ #!/usr/bin/env python
2
+
3
+ import sys
4
+
5
+ import numpy as np
6
+
7
+ from ONTraC.data import load_dataset
8
+ from ONTraC.optparser import opt_NT_validate, prepare_NT_optparser
9
+ from ONTraC.run.processes import *
10
+
11
+
12
+ # ------------------------------------
13
+ # Main Function
14
+ # ------------------------------------
15
def main() -> None:
    """
    Entry point of the NTScore command.

    Reloads the consolidated arrays written to ``options.GNN_dir`` by an
    earlier graph-pooling run and computes NT scores from them.
    :return: None
    """

    # ----- prepare -----
    # --- load parameters ---
    options = load_parameters(opt_validate_func=opt_NT_validate, prepare_optparser_func=prepare_NT_optparser)
    # --- load data ---
    dataset, _ = load_dataset(options=options)
    # load consolidated s_array and out_adj_array.
    # ndmin=2 keeps both arrays 2-D even when a file has a single row or column
    # (np.loadtxt would otherwise squeeze those to 1-D / 0-D).
    consolidate_s_array = np.loadtxt(fname=f'{options.GNN_dir}/consolidate_s.csv.gz', delimiter=',', ndmin=2)
    consolidate_out_adj_array = np.loadtxt(fname=f'{options.GNN_dir}/consolidate_out_adj.csv.gz',
                                           delimiter=',',
                                           ndmin=2)

    # ----- Pseudotime -----
    # np.loadtxt either returns an array or raises, so no None check is needed.
    NTScore(options=options,
            dataset=dataset,
            consolidate_s_array=consolidate_s_array,
            consolidate_out_adj_array=consolidate_out_adj_array)
36
+
37
+
38
+ # ------------------------------------
39
+ # Program running
40
+ # ------------------------------------
41
if __name__ == '__main__':
    # Run the command; on Ctrl+C leave a short note on stderr and exit with status 0.
    try:
        main()
    except KeyboardInterrupt:
        sys.stderr.write("User interrupts me! ;-) See you ^.^!\n")
        raise SystemExit(0)
ONTraC/bin/ONTraC.py ADDED
@@ -0,0 +1,109 @@
1
+ #!/usr/bin/env python
2
+
3
+ import random
4
+ import sys
5
+ from typing import Optional
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ from ONTraC.log import *
11
+ from ONTraC.model import GraphPooling
12
+ from ONTraC.optparser import opt_ontrac_validate, prepare_ontrac_optparser
13
+ from ONTraC.run.processes import *
14
+ from ONTraC.train import GPBatchTrain, SubBatchTrainProtocol
15
+ from ONTraC.train.inspect_funcs import loss_record
16
+ from ONTraC.utils import device_validate
17
+ from ONTraC.utils.niche_net_constr import (construct_niche_network,
18
+ gen_samples_yaml,
19
+ load_original_data)
20
+
21
+
22
+ # ------------------------------------
23
+ # Functions
24
+ # ------------------------------------
25
def get_inspect_funcs() -> Optional[list[Callable]]:
    """
    Build the list of inspection callbacks handed to the trainer.

    Takes no arguments; the ONTraC pipeline currently registers only
    ``loss_record``.
    :return: list of inspect functions (always a list here; the Optional
        return type is kept for interface compatibility)
    """
    return [loss_record]
33
+
34
+
35
+ # ------------------------------------
36
+ # Main Function
37
+ # ------------------------------------
38
def main() -> None:
    """
    Entry point of the full ONTraC pipeline.

    Steps, in order: build the niche network from the original input data
    and write samples.yaml; train, evaluate and predict with GraphPooling;
    write the niche-cluster output; compute NT scores.
    """
    model_label = 'GraphPooling'

    # prepare options
    options = load_parameters(opt_validate_func=opt_ontrac_validate, prepare_optparser_func=prepare_ontrac_optparser)

    # ----- Niche Network Construct -----
    original_df = load_original_data(options=options)
    # edges for every sample, then the samples.yaml index (re-read further below)
    construct_niche_network(options=options, ori_data_df=original_df)
    gen_samples_yaml(options=options, ori_data_df=original_df)

    # ----- Graph Pooling -----
    device: torch.device = device_validate(device_name=options.device)
    dataset, sample_loader = load_data(options=options)

    # seed python, torch and numpy RNGs with the same user-supplied value
    seed = options.seed
    random.seed(a=seed)
    torch.manual_seed(seed=seed)
    np.random.seed(seed=seed)

    batch_train: SubBatchTrainProtocol = train(nn_model=GraphPooling,
                                               options=options,
                                               BatchTrain=GPBatchTrain,
                                               device=device,
                                               dataset=dataset,
                                               sample_loader=sample_loader,
                                               inspect_funcs=get_inspect_funcs(),
                                               model_name=model_label)
    evaluate(batch_train=batch_train, model_name=model_label)
    s_array, out_adj_array = predict(output_dir=options.GNN_dir,
                                     batch_train=batch_train,
                                     dataset=dataset,
                                     model_name=model_label)

    # niche cluster output only needs the assignment array
    if s_array is not None:
        samples_params = read_yaml_file(f'{options.preprocessing_dir}/samples.yaml')
        graph_pooling_output(ori_data_df=original_df,
                             dataset=dataset,
                             rel_params=get_rel_params(options=options, params=samples_params),
                             consolidate_s_array=s_array,
                             output_dir=options.GNN_dir)

    # ----- NT score -----
    if s_array is not None and out_adj_array is not None:
        NTScore(options=options,
                dataset=dataset,
                consolidate_s_array=s_array,
                consolidate_out_adj_array=out_adj_array)
99
+
100
+
101
+ # ------------------------------------
102
+ # Program running
103
+ # ------------------------------------
104
if __name__ == '__main__':
    # Run the pipeline; on Ctrl+C leave a short note on stderr and exit with status 0.
    try:
        main()
    except KeyboardInterrupt:
        sys.stderr.write("User interrupts me! ;-) See you ^.^!\n")
        raise SystemExit(0)
ONTraC/bin/__init__.py ADDED
File without changes
ONTraC/bin/createDataSet.py ADDED
@@ -0,0 +1,40 @@
1
+ #!/usr/bin/env python
2
+
3
+ import sys
4
+
5
+ from ONTraC.log import *
6
+ from ONTraC.optparser import (opt_create_ds_validate, prepare_create_ds_optparser)
7
+ from ONTraC.utils.niche_net_constr import load_original_data, construct_niche_network, gen_samples_yaml
8
+
9
+
10
+ # ------------------------------------
11
+ # Main Function
12
+ # ------------------------------------
13
def main() -> None:
    """
    Entry point of the createDataSet command.

    Reads the original dataset named by ``options.dataset``, builds the niche
    network (edges) for each sample and writes samples.yaml.
    """
    options = opt_create_ds_validate(prepare_create_ds_optparser())

    # NOTE(review): ONTraC.py calls load_original_data(options=...) without
    # data_file; confirm which signature niche_net_constr actually exposes.
    original_df = load_original_data(options=options, data_file=options.dataset)

    construct_niche_network(options=options, ori_data_df=original_df)  # edges per sample
    gen_samples_yaml(options=options, ori_data_df=original_df)  # samples.yaml
30
+
31
+
32
+ # ------------------------------------
33
+ # Program running
34
+ # ------------------------------------
35
if __name__ == '__main__':
    # Run the command; on Ctrl+C leave a short note on stderr and exit with status 0.
    try:
        main()
    except KeyboardInterrupt:
        sys.stderr.write("User interrupts me! ;-) See you ^.^!\n")
        raise SystemExit(0)
ONTraC/data.py ADDED
@@ -0,0 +1,102 @@
1
+ from optparse import Values
2
+ from typing import Dict, List, Tuple
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
+ import torch_geometric.transforms as T
8
+ from torch_geometric.data import Data, InMemoryDataset
9
+ from torch_geometric.loader import DenseDataLoader
10
+
11
+ from .log import *
12
+ from .utils import count_lines, device_validate, get_rel_params, read_yaml_file
13
+
14
+
15
+ # ------------------------------------
16
+ # Classes
17
+ # ------------------------------------
18
class SpatailOmicsDataset(InMemoryDataset):
    """
    In-memory PyG dataset holding one graph per entry of ``params['Data']``.

    Each sample entry provides a ``Name`` plus paths to:
      * ``Features``    -- comma-separated node feature matrix (read as float32)
      * ``EdgeIndex``   -- comma-separated edge list, read as int64 and transposed,
                           presumably one ``src,dst`` pair per row
      * ``Coordinates`` -- CSV with a header containing ``x`` and ``y`` columns
    The collated tensors are cached as ``data.pt`` (see ``processed_file_names``).
    Per PyG's Dataset semantics ``process`` is skipped when that file already
    exists, so stale caches must be removed after the inputs change.
    """

    def __init__(self, root, params: Dict, transform=None, pre_transform=None):
        # params must be set before the base __init__, which may call process().
        self.params = params
        super().__init__(root, transform, pre_transform)
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which may
        # refuse to unpickle PyG Data objects; confirm the supported torch version.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Inputs are read directly from the paths listed in params, so no raw files are tracked.
        return []

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        pass

    def process(self):
        samples = self.params['Data']
        data_list = []
        for index, sample in enumerate(samples):
            info(f'Processing sample {index + 1} of {len(samples)}')
            # ndmin=2 keeps single-row files 2-D (a sample with one node or one edge);
            # otherwise loadtxt squeezes them to 1-D and the transpose below no
            # longer produces a (2, E) edge_index.
            features = np.loadtxt(sample['Features'], dtype=np.float32, delimiter=',', ndmin=2)
            edges = np.loadtxt(sample['EdgeIndex'], dtype=np.int64, delimiter=',', ndmin=2)
            data = Data(
                x=torch.from_numpy(features),
                edge_index=torch.from_numpy(edges).t().contiguous(),
                # TODO: support 3D coordinates
                pos=torch.from_numpy(pd.read_csv(sample['Coordinates'])[['x', 'y']].values),
                name=sample['Name'])
            data_list.append(data)
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
52
+
53
+
54
+ # ------------------------------------
55
+ # Misc functions
56
+ # ------------------------------------
57
def max_nodes(samples: List[Dict[str, str]]) -> int:
    """
    Get the maximum number of nodes over all samples of a dataset.

    A sample's node count is ``count_lines`` of its ``Coordinates`` file.
    NOTE(review): if count_lines includes the CSV header row this over-counts by one.
    :param samples: List[Dict[str, str]], list of samples; each needs a ``Coordinates`` path
    :return: int, maximum number of nodes (0 for an empty list)
    """
    return max((count_lines(sample['Coordinates']) for sample in samples), default=0)
67
+
68
+
69
def load_dataset(options: Values) -> Tuple[SpatailOmicsDataset, Data]:
    """
    Load the dataset described by ``<options.preprocessing_dir>/samples.yaml``.

    :param options: Values, parsed options; ``preprocessing_dir`` is read
    :return: (dataset, batch) where batch contains every sample in one dense
        batch, moved to the device returned by ``device_validate()``
    """
    # NOTE(review): device_validate() is called without options.device, so any
    # user-selected device is ignored here — confirm this is intended.
    target_device = device_validate()
    samples_params = read_yaml_file(f'{options.preprocessing_dir}/samples.yaml')
    dataset = create_torch_dataset(options, get_rel_params(options, samples_params))
    # batch_size == len(dataset): the first (and only) batch holds all samples
    full_batch = next(iter(DenseDataLoader(dataset, batch_size=len(dataset))))
    return dataset, full_batch.to(target_device)
77
+
78
+
79
+ # ------------------------------------
80
+ # Flow control functions
81
+ # ------------------------------------
82
def create_torch_dataset(options: Values, params: Dict) -> SpatailOmicsDataset:
    """
    Create the torch dataset for all samples.

    Each graph is converted to dense form with ``T.ToDense``, padded to the
    largest sample's node count rounded up to the next multiple of 100.
    :param options: Values, parsed options; ``preprocessing_dir`` is the dataset root
    :param params: Dict, sample parameters with a ``Data`` list of sample entries
    :return: SpatailOmicsDataset, the created dataset
    """

    # Step 1: maximum number of nodes, rounded up to the nearest multiple of 100
    m_nodes = max_nodes(params['Data'])
    m_nodes = int(np.ceil(m_nodes / 100.0)) * 100
    info(f'Maximum number of nodes: {m_nodes}')

    # Step 2: create torch dataset; ToDense turns edge_index into a padded adjacency matrix
    dataset = SpatailOmicsDataset(root=options.preprocessing_dir, params=params, transform=T.ToDense(m_nodes))
    return dataset
ONTraC/log.py ADDED
@@ -0,0 +1,41 @@
1
+ import sys
2
+ import time
3
+
4
+
5
def get_current_time() -> str:
    """Return the local wall-clock time as an ``HH:MM:SS`` string."""
    # time.strftime falls back to time.localtime() when no time tuple is passed.
    return time.strftime('%H:%M:%S')


def _write_stamped(stream, message: str) -> None:
    """Write ``message`` to ``stream`` as ``<HH:MM:SS> --- <message>`` and flush."""
    stream.write(f'{get_current_time()} --- {message}\n')
    stream.flush()


def write_direct_message(message: str):
    """Write a timestamped line to standard output."""
    _write_stamped(sys.stdout, message)


def debug(message: str):
    """Emit ``message`` tagged ``DEBUG`` on stdout (no level filtering is applied)."""
    write_direct_message(f'DEBUG: {message}')


def info(message: str):
    """Emit ``message`` tagged ``INFO`` on stdout."""
    write_direct_message(f'INFO: {message}')


def write_direct_message_err(message: str):
    """Write a timestamped line to standard error."""
    _write_stamped(sys.stderr, message)


def warning(message: str):
    """Emit ``message`` tagged ``WARNING`` on stderr."""
    write_direct_message_err(f'WARNING: {message}')


def error(message: str):
    """Emit ``message`` tagged ``ERROR`` on stderr."""
    write_direct_message_err(f'ERROR: {message}')


def critical(message: str):
    """Emit ``message`` tagged ``CRITICAL`` on stderr."""
    write_direct_message_err(f'CRITICAL: {message}')


__all__ = ['debug', 'info', 'warning', 'error', 'critical']
ONTraC/model/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from ._model import *
ONTraC/model/_model.py ADDED
@@ -0,0 +1,152 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+ from ..log import *
7
+ from .dmon_exp_pool import DMoNPooling
8
+ from .norm_dense_gcn_conv import NormDenseGCNConv
9
+
10
+
11
class NodePooling(torch.nn.Module):
    """
    NodePooling: thin wrapper around ``DMoNPooling`` that assigns the input
    nodes to ``k`` clusters.
    """

    def __init__(self, input_feats, k: int, dropout: float = 0, exponent: float = 1, *args, **kwargs) -> None:
        """
        :param input_feats: number of input node features (channels of the pooling layer)
        :param k: number of clusters
        :param dropout: stored on ``self.dropout`` but NOT forwarded to the pooling layer
        :param exponent: forwarded to ``DMoNPooling``
        """
        super().__init__(*args, **kwargs)
        self.dropout = dropout
        self.exponent = exponent
        # NOTE(review): dropout is hard-coded to 0 here, so the ``dropout`` argument has
        # no effect on the model — confirm whether ``dropout=self.dropout`` was intended.
        self.pool = DMoNPooling(channels=input_feats, k=k, dropout=0, exponent=self.exponent)
        self.k = k

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise the parameters of the underlying pooling layer."""
        self.pool.reset_parameters()

    def forward(self,
                x: Tensor,
                adj: Tensor,
                mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        r"""
        forward function
        Args:
            x (torch.Tensor): Node feature tensor
                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
                each graph, and feature dimension :math:`F`.
            adj (torch.Tensor): Adjacency tensor
                :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
            mask (torch.Tensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)
        Returns:
            Tuple of six tensors ``(s, out, out_adj, spectral_loss, ortho_loss,
            cluster_loss)`` as produced by the pooling layer.
        """
        s, out, out_adj, spectral_loss, ortho_loss, cluster_loss = self.pool(x=x, adj=adj, mask=mask)
        return s, out, out_adj, spectral_loss, ortho_loss, cluster_loss

    def predict(self, x: Tensor, adj: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Run the pooling layer and keep only ``(s, out, out_adj)``, discarding the losses.
        Arguments are the same as for :meth:`forward`.
        """
        s, out, out_adj, *_ = self.pool(x=x, adj=adj, mask=mask)
        return s, out, out_adj
53
+
54
+
55
class GraphPooling(torch.nn.Module):
    """
    GNN with Node Pooling.

    Two ``NormDenseGCNConv`` layers, each followed by SELU, produce node
    embeddings that :class:`NodePooling` then pools into ``k`` clusters.
    """

    def __init__(self,
                 input_feats: int,
                 hidden_feats: int,
                 k: int,
                 dropout: float = 0,
                 exponent: float = 1,
                 *args,
                 **kwargs) -> None:
        """
        :param input_feats: number of input node features
        :param hidden_feats: width of both GCN layers and of the embeddings fed to pooling
        :param k: number of clusters
        :param dropout: passed through to :class:`NodePooling`
        :param exponent: passed through to :class:`NodePooling`
        """
        super().__init__(*args, **kwargs)
        self.gcn1 = NormDenseGCNConv(input_feats, hidden_feats)
        self.activation1 = torch.nn.SELU()
        self.gcn2 = NormDenseGCNConv(hidden_feats, hidden_feats)
        self.activation2 = torch.nn.SELU()
        self.pool = NodePooling(input_feats=hidden_feats, k=k, dropout=dropout, exponent=exponent)
        self.k = k

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise both GCN layers and the pooling layer."""
        self.gcn1.reset_parameters()
        self.gcn2.reset_parameters()
        self.pool.reset_parameters()

    def _embed(self, x: Tensor, adj: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """Shared encoder used by every entry point: GCN -> SELU -> GCN -> SELU."""
        x = self.activation1(self.gcn1(x=x, adj=adj, mask=mask))
        return self.activation2(self.gcn2(x=x, adj=adj, mask=mask))

    def forward(self,
                x: Tensor,
                adj: Tensor,
                mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        r"""
        forward function
        Each GCN layer computes
            X' = \mathbf{\hat{L}}X\mathbf{\Theta}
            \mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}
            \mathbf{\hat{D}}_{ii} = \sum_{j=1}^N \mathbf{\hat{A}}_{ij}
            \mathbf{\hat{L}} = \mathbf{\hat{D}}^{-1/2}\mathbf{\hat{A}}\mathbf{\hat{D}}^{-1/2}
        Args:
            x (torch.Tensor): Node feature tensor
                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
                each graph, and feature dimension :math:`F`.
            adj (torch.Tensor): Adjacency tensor
                :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
            mask (torch.Tensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)
        Returns:
            Tuple of six tensors ``(s, out, out_adj, spectral_loss, ortho_loss,
            cluster_loss)`` as produced by :class:`NodePooling`.
        """
        x = self._embed(x=x, adj=adj, mask=mask)
        s, out, out_adj, spectral_loss, ortho_loss, cluster_loss = self.pool(x=x, adj=adj, mask=mask)
        return s, out, out_adj, spectral_loss, ortho_loss, cluster_loss

    def evaluate(self,
                 x: Tensor,
                 adj: Tensor,
                 mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        """
        Evaluation entry point; performs exactly the same computation as :meth:`forward`
        and returns the same six-tensor tuple.
        """
        return self.forward(x=x, adj=adj, mask=mask)

    def predict(self, x: Tensor, adj: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
        r"""
        predict function
        Args:
            x (torch.Tensor): Node feature tensor
                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
                each graph, and feature dimension :math:`F`.
            adj (torch.Tensor): Adjacency tensor
                :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
            mask (torch.Tensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)
        Returns:
            s (torch.Tensor): Node assignment matrix
                :math:`\mathbf{S} \in \mathbb{R}^{B \times N \times K}`
            out (torch.Tensor): Output feature matrix
                :math:`\mathbf{X} \in \mathbb{R}^{B \times K \times H}`
            out_adj (torch.Tensor): Output adjacency matrix
                :math:`\mathbf{A} \in \mathbb{R}^{B \times K \times K}`
        """
        x = self._embed(x=x, adj=adj, mask=mask)
        s, out, out_adj, *_ = self.pool(x=x, adj=adj, mask=mask)
        return s, out, out_adj

    def predict_embed(self, x: Tensor, adj: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """Return the node embeddings produced by the two GCN layers, before pooling."""
        return self._embed(x=x, adj=adj, mask=mask)


__all__ = ['NodePooling', 'GraphPooling']