gsg 0.6.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,4 @@
1
+ from . import preprocess as pp
2
+ from . import train
3
+ from . import utils
4
+ from . import models as model
@@ -0,0 +1,46 @@
1
+ from .edcoder import PreModel
2
+
3
+ #####################################################################################################################################
4
+ # Adapted from: #
5
+ # @inproceedings{hou2022graphmae, #
6
+ # title={GraphMAE: Self-Supervised Masked Graph Autoencoders}, #
7
+ # author={Hou, Zhenyu and Liu, Xiao and Cen, Yukuo and Dong, Yuxiao and Yang, Hongxia and Wang, Chunjie and Tang, Jie}, #
8
+ # booktitle={Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining}, #
9
+ # pages={594--604}, #
10
+ # year={2022} #
11
+ # } #
12
+ #####################################################################################################################################
13
+
14
def build_model(args):
    """Build a ``PreModel`` from a parsed-argument namespace.

    The encoder type is always ``"gin"``.  The decoder type is ``"mlp"`` when
    ``args.imputation`` is truthy and ``"gin"`` otherwise.

    Attributes read from ``args``: ``num_hidden``, ``num_layers``,
    ``in_drop``, ``norm``, ``imputation``, ``mask_rate``, ``replace_rate``,
    ``activation``, ``alpha_l`` and ``num_features``.

    Returns:
        The constructed ``PreModel`` instance.
    """
    # Collected in one mapping so the constructor call stays a single line.
    settings = dict(
        num_hidden=args.num_hidden,
        num_layers=args.num_layers,
        feat_drop=args.in_drop,
        norm=args.norm,
        encoder_type="gin",
        decoder_type="mlp" if args.imputation else "gin",
        mask_rate=args.mask_rate,
        replace_rate=args.replace_rate,
        activation=args.activation,
        alpha_l=args.alpha_l,
        in_dim=args.num_features,
    )
    return PreModel(**settings)
@@ -0,0 +1,197 @@
1
+ from typing import Optional
2
+ from itertools import chain
3
+ from functools import partial
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .gin import GIN
9
+ from .loss_func import sce_loss
10
+ from .utils import drop_edge
11
+
12
+
13
+ def setup_module(m_type, enc_dec, in_dim, num_hidden, out_dim, num_layers, dropout, activation, residual, norm) -> nn.Module:
14
+ if m_type == "gin":
15
+ mod = GIN(
16
+ in_dim=in_dim,
17
+ num_hidden=num_hidden,
18
+ out_dim=out_dim,
19
+ num_layers=num_layers,
20
+ dropout=dropout,
21
+ activation=activation,
22
+ residual=residual,
23
+ norm=norm,
24
+ encoding=(enc_dec == "encoding"),
25
+ )
26
+
27
+ elif m_type == "mlp":
28
+ # * just for decoder
29
+ mod = nn.Sequential(
30
+ nn.Linear(in_dim, num_hidden),
31
+ nn.PReLU(),
32
+ nn.Dropout(0.2),
33
+ nn.Linear(num_hidden, out_dim)
34
+ )
35
+ elif m_type == "linear":
36
+ mod = nn.Linear(in_dim, out_dim)
37
+ else:
38
+ raise NotImplementedError
39
+
40
+ return mod
41
+
42
+
43
class PreModel(nn.Module):
    """Masked graph autoencoder (adapted from GraphMAE).

    A random subset of nodes has its input features masked, the encoder
    embeds the corrupted graph, and the decoder reconstructs the original
    features of the masked nodes.  The training loss is ``sce_loss`` with
    exponent ``alpha_l``.
    """

    def __init__(
        self,
        in_dim: int,
        num_hidden: int,
        num_layers: int,
        activation: str,
        feat_drop: float,
        norm: Optional[str],
        mask_rate: float = 0.3,
        encoder_type: str = "gin",
        decoder_type: str = "gin",
        loss_fn: str = "sce",
        drop_edge_rate: float = 0.0,
        replace_rate: float = 0.1,
        alpha_l: float = 2,
        residual: bool = False,
        concat_hidden: bool = False,
    ):
        """Build encoder, decoder, mask token and loss.

        Args:
            in_dim: Input feature size (also the reconstruction size).
            num_hidden: Hidden/embedding size.
            num_layers: Number of encoder layers (the decoder has one).
            activation: Activation name passed to ``setup_module``.
            feat_drop: Dropout probability passed to ``setup_module``.
            norm: Normalisation name passed to ``setup_module``.
            mask_rate: Fraction of nodes masked per forward pass.
            encoder_type: Module type of the encoder.
            decoder_type: Module type of the decoder.
            loss_fn: Accepted for interface compatibility; the loss is
                always ``sce_loss``.
            drop_edge_rate: Fraction of edges dropped before encoding.
            replace_rate: Fraction of masked nodes whose features are
                replaced by another node's features instead of the token.
            alpha_l: Exponent of the scaled cosine error.
            residual: Residual flag passed to ``setup_module``.
            concat_hidden: If true, the decoder input is built from the
                concatenation of all encoder layer outputs.
        """
        super().__init__()
        self._mask_rate = mask_rate

        self._encoder_type = encoder_type
        self._decoder_type = decoder_type
        self._drop_edge_rate = drop_edge_rate
        self._output_hidden_size = num_hidden
        self._concat_hidden = concat_hidden

        self._replace_rate = replace_rate
        self._mask_token_rate = 1 - self._replace_rate

        enc_num_hidden = num_hidden
        dec_in_dim = num_hidden
        dec_num_hidden = num_hidden

        # Encoder: in_dim -> num_hidden.
        self.encoder = setup_module(
            m_type=encoder_type,
            enc_dec="encoding",
            in_dim=in_dim,
            num_hidden=enc_num_hidden,
            out_dim=enc_num_hidden,
            num_layers=num_layers,
            activation=activation,
            dropout=feat_drop,
            residual=residual,
            norm=norm,
        )

        # Decoder for attribute prediction: num_hidden -> in_dim.
        self.decoder = setup_module(
            m_type=decoder_type,
            enc_dec="decoding",
            in_dim=dec_in_dim,
            num_hidden=dec_num_hidden,
            out_dim=in_dim,
            num_layers=1,
            activation=activation,
            dropout=feat_drop,
            residual=residual,
            norm=norm,
        )

        # Learnable token added to the (zeroed) features of masked nodes.
        self.enc_mask_token = nn.Parameter(torch.zeros(1, in_dim))
        bridge_in_dim = dec_in_dim * num_layers if concat_hidden else dec_in_dim
        self.encoder_to_decoder = nn.Linear(bridge_in_dim, dec_in_dim, bias=False)

        # Loss function.
        self.criterion = partial(sce_loss, alpha=alpha_l)

    @property
    def output_hidden_dim(self):
        """Size of the embedding produced by the encoder."""
        return self._output_hidden_size

    def encoding_mask_noise(self, g, x, mask_rate=0.3):
        """Corrupt the node features of a random subset of nodes.

        Args:
            g: Graph object providing ``num_nodes()`` and ``clone()``.
            x: Node feature tensor indexed by node along dim 0.
            mask_rate: Fraction of nodes to mask.

        Returns:
            ``(use_g, out_x, (mask_nodes, keep_nodes))`` where ``use_g`` is a
            clone of ``g``, ``out_x`` the corrupted copy of ``x`` and the two
            index tensors partition a random permutation of the nodes.
        """
        num_nodes = g.num_nodes()
        perm = torch.randperm(num_nodes, device=x.device)

        # Random masking: the first part of the permutation is masked.
        num_mask_nodes = int(mask_rate * num_nodes)
        mask_nodes = perm[:num_mask_nodes]
        keep_nodes = perm[num_mask_nodes:]

        out_x = x.clone()
        if self._replace_rate > 0:
            num_noise_nodes = int(self._replace_rate * num_mask_nodes)
            num_token_nodes = int(self._mask_token_rate * num_mask_nodes)
            perm_mask = torch.randperm(num_mask_nodes, device=x.device)
            token_nodes = mask_nodes[perm_mask[:num_token_nodes]]
            noise_to_be_chosen = torch.randperm(num_nodes, device=x.device)[:num_noise_nodes]

            out_x[token_nodes] = 0.0
            if num_noise_nodes > 0:
                # Take the tail with an explicit start index: a ``[-k:]``
                # slice with k == 0 would select *every* masked node and
                # then fail to match the empty replacement tensor.
                noise_nodes = mask_nodes[perm_mask[num_mask_nodes - num_noise_nodes:]]
                out_x[noise_nodes] = x[noise_to_be_chosen]
        else:
            token_nodes = mask_nodes
            out_x[mask_nodes] = 0.0

        out_x[token_nodes] += self.enc_mask_token
        use_g = g.clone()

        return use_g, out_x, (mask_nodes, keep_nodes)

    def forward(self, g, x):
        """Return the reconstruction loss and a ``{"loss": float}`` dict."""
        # ---- attribute reconstruction ----
        loss = self.mask_attr_prediction(g, x)
        loss_item = {"loss": loss.item()}
        return loss, loss_item

    def mask_attr_prediction(self, g, x, test=False):
        """Mask, encode, decode and score the reconstruction.

        Args:
            g: Input graph.
            x: Node feature tensor.
            test: If true, return the full reconstruction as a NumPy array
                instead of the loss.

        Returns:
            The scalar loss over masked nodes, or the reconstruction array
            when ``test`` is true.
        """
        pre_use_g, use_x, (mask_nodes, keep_nodes) = self.encoding_mask_noise(g, x, self._mask_rate)

        if self._drop_edge_rate > 0:
            use_g, masked_edges = drop_edge(pre_use_g, self._drop_edge_rate, return_edges=True)
        else:
            use_g = pre_use_g

        enc_rep, all_hidden = self.encoder(use_g, use_x, return_hidden=True)
        if self._concat_hidden:
            enc_rep = torch.cat(all_hidden, dim=1)

        # ---- attribute reconstruction ----
        rep = self.encoder_to_decoder(enc_rep)
        graph_free_decoder = self._decoder_type in ("mlp", "linear")
        if graph_free_decoder:
            recon = self.decoder(rep)
        else:
            # Re-mask so a graph decoder must rebuild masked nodes from
            # their neighbours only.
            rep[mask_nodes] = 0
            recon = self.decoder(pre_use_g, rep)

        x_init = x[mask_nodes]
        x_rec = recon[mask_nodes]

        if not test:
            return self.criterion(x_rec, x_init)
        return recon.detach().cpu().numpy()

    def embed(self, g, x):
        """Return the encoder output for graph ``g`` and features ``x``."""
        return self.encoder(g, x)

    @property
    def enc_params(self):
        """Iterator over the encoder parameters."""
        return self.encoder.parameters()

    @property
    def dec_params(self):
        """Iterator over the bridge-layer and decoder parameters."""
        return chain(*[self.encoder_to_decoder.parameters(), self.decoder.parameters()])
@@ -0,0 +1,202 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import dgl.function as fn
5
+ from dgl.utils import expand_as_pair
6
+
7
+ from .utils import create_activation, create_norm
8
+
9
+
10
class GIN(nn.Module):
    """Stack of ``GINConv`` layers with a replaceable output head.

    Every layer applies a two-layer ``MLP``.  All layers except the last
    wrap that MLP in ``ApplyNodeFunc``; the last layer does so only when
    ``encoding`` is true and ``norm`` is truthy.  The ``aggr`` argument is
    accepted but not used by this implementation.
    """

    def __init__(self,
                 in_dim,
                 num_hidden,
                 out_dim,
                 num_layers,
                 dropout,
                 activation,
                 residual,
                 norm,
                 encoding=False,
                 learn_eps=False,
                 aggr="sum",
                 ):
        super().__init__()
        self.out_dim = out_dim
        self.num_layers = num_layers
        self.layers = nn.ModuleList()
        self.activation = activation
        self.dropout = dropout

        if encoding:
            # The result is not used; the call is kept so an unknown
            # activation name is rejected at this point.
            create_activation(activation)
        final_residual = encoding and residual
        wrap_final = bool(norm) if encoding else False

        def make_mlp(d_in, d_out):
            return MLP(2, d_in, num_hidden, d_out, activation=activation, norm=norm)

        def wrap(mlp):
            return ApplyNodeFunc(mlp, norm=norm, activation=activation)

        def make_conv(d_in, d_out, func, use_residual):
            return GINConv(d_in, d_out, func, init_eps=0, learn_eps=learn_eps, residual=use_residual)

        if num_layers == 1:
            last_func = make_mlp(in_dim, out_dim)
            if wrap_final:
                last_func = wrap(last_func)
            self.layers.append(make_conv(in_dim, out_dim, last_func, final_residual))
        else:
            # Input projection.
            self.layers.append(
                make_conv(in_dim, num_hidden, wrap(make_mlp(in_dim, num_hidden)), residual)
            )
            # Hidden layers.
            for _ in range(1, num_layers - 1):
                self.layers.append(
                    make_conv(num_hidden, num_hidden, wrap(make_mlp(num_hidden, num_hidden)), residual)
                )
            # Output projection.
            last_func = make_mlp(num_hidden, out_dim)
            if wrap_final:
                last_func = wrap(last_func)
            self.layers.append(make_conv(num_hidden, out_dim, last_func, final_residual))

        self.head = nn.Identity()

    def forward(self, g, inputs, return_hidden=False):
        """Run all layers, applying dropout before each one.

        Returns ``head(h)``, or ``(head(h), hidden_list)`` when
        ``return_hidden`` is true, where ``hidden_list`` holds the output of
        every layer.
        """
        hidden_list = []
        h = inputs
        for depth in range(self.num_layers):
            dropped = F.dropout(h, p=self.dropout, training=self.training)
            h = self.layers[depth](g, dropped)
            hidden_list.append(h)

        out = self.head(h)
        if return_hidden:
            return out, hidden_list
        return out

    def reset_classifier(self, num_classes):
        """Replace the head with a linear layer producing ``num_classes`` outputs."""
        self.head = nn.Linear(self.out_dim, num_classes)
84
+
85
+
86
class GINConv(nn.Module):
    """Graph Isomorphism Network convolution.

    Computes ``apply_func((1 + eps) * h_dst + reduce(h_src over in-edges))``
    and optionally adds a residual connection from the destination features.
    """

    def __init__(self,
                 in_dim,
                 out_dim,
                 apply_func,
                 aggregator_type="sum",
                 init_eps=0,
                 learn_eps=False,
                 residual=False,
                 ):
        """
        Args:
            in_dim: Input feature size.
            out_dim: Output feature size.
            apply_func: Callable applied to the aggregated features, or
                ``None`` to skip it.
            aggregator_type: ``"sum"``, ``"max"`` or ``"mean"``.
            init_eps: Initial value of ``eps``.
            learn_eps: If true, ``eps`` is a trainable parameter; otherwise
                it is a buffer.
            residual: If true, add a residual connection (a bias-free linear
                projection when ``in_dim != out_dim``, identity otherwise).

        Raises:
            KeyError: If ``aggregator_type`` is not recognised.
        """
        super().__init__()
        self._in_feats = in_dim
        self._out_feats = out_dim
        self.apply_func = apply_func

        self._aggregator_type = aggregator_type
        if aggregator_type == 'sum':
            self._reducer = fn.sum
        elif aggregator_type == 'max':
            self._reducer = fn.max
        elif aggregator_type == 'mean':
            self._reducer = fn.mean
        else:
            raise KeyError('Aggregator type {} not recognized.'.format(aggregator_type))

        if learn_eps:
            self.eps = torch.nn.Parameter(torch.FloatTensor([init_eps]))
        else:
            self.register_buffer('eps', torch.FloatTensor([init_eps]))

        if residual:
            if self._in_feats != self._out_feats:
                self.res_fc = nn.Linear(
                    self._in_feats, self._out_feats, bias=False)
                print("! Linear Residual !")
            else:
                print("Identity Residual ")
                self.res_fc = nn.Identity()
        else:
            self.register_buffer('res_fc', None)

    def forward(self, graph, feat):
        """Aggregate neighbour features on ``graph`` and transform them.

        Args:
            graph: DGL graph.
            feat: Node features, or a (source, destination) pair as accepted
                by ``dgl.utils.expand_as_pair``.

        Returns:
            The updated destination-node features.
        """
        with graph.local_scope():
            # ``copy_src`` was removed from newer DGL releases in favour of
            # ``copy_u``; prefer the new name and fall back for old installs.
            copy_source = getattr(fn, 'copy_u', None) or fn.copy_src
            aggregate_fn = copy_source('h', 'm')

            feat_src, feat_dst = expand_as_pair(feat, graph)
            graph.srcdata['h'] = feat_src
            graph.update_all(aggregate_fn, self._reducer('m', 'neigh'))
            rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
            if self.apply_func is not None:
                rst = self.apply_func(rst)

            if self.res_fc is not None:
                rst = rst + self.res_fc(feat_dst)

            return rst
142
+
143
+
144
class ApplyNodeFunc(nn.Module):
    """Apply an MLP to node features, then a normalisation and an activation."""

    def __init__(self, mlp, norm="batchnorm", activation="relu"):
        super().__init__()
        self.mlp = mlp
        # ``create_norm`` returns a layer factory, or None for "no norm".
        norm_factory = create_norm(norm)
        self.norm = nn.Identity() if norm_factory is None else norm_factory(self.mlp.output_dim)
        self.act = create_activation(activation)

    def forward(self, h):
        """Return ``act(norm(mlp(h)))``."""
        return self.act(self.norm(self.mlp(h)))
161
+
162
+
163
class MLP(nn.Module):
    """Multi-layer perceptron with a linear (un-activated) output layer."""

    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, activation="relu", norm="batchnorm"):
        """
        Args:
            num_layers: Number of linear layers; ``1`` yields a plain linear
                model.
            input_dim: Input feature size.
            hidden_dim: Hidden feature size.
            output_dim: Output feature size.
            activation: Activation name passed to ``create_activation``.
            norm: Normalisation name passed to ``create_norm``.  When
                ``create_norm`` returns ``None`` no normalisation is applied.

        Raises:
            ValueError: If ``num_layers`` is smaller than 1.
        """
        super().__init__()
        self.linear_or_not = True  # default is linear model
        self.num_layers = num_layers
        self.output_dim = output_dim

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")
        if num_layers == 1:
            # Linear model.
            self.linear = nn.Linear(input_dim, output_dim)
            return

        # Multi-layer model.
        self.linear_or_not = False
        self.linears = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()
        self.activations = torch.nn.ModuleList()

        self.linears.append(nn.Linear(input_dim, hidden_dim))
        for _ in range(num_layers - 2):
            self.linears.append(nn.Linear(hidden_dim, hidden_dim))
        self.linears.append(nn.Linear(hidden_dim, output_dim))

        # ``create_norm`` returns None for "no normalisation"; calling that
        # result directly would raise TypeError, so substitute an identity.
        norm_factory = create_norm(norm)
        for _ in range(num_layers - 1):
            if norm_factory is None:
                self.norms.append(nn.Identity())
            else:
                self.norms.append(norm_factory(hidden_dim))
            self.activations.append(create_activation(activation))

    def forward(self, x):
        """Apply the network to ``x``."""
        if self.linear_or_not:
            return self.linear(x)

        h = x
        for i in range(self.num_layers - 1):
            h = self.norms[i](self.linears[i](h))
            h = self.activations[i](h)
        return self.linears[-1](h)
@@ -0,0 +1,11 @@
1
+ import torch.nn.functional as F
2
+
3
+
4
def sce_loss(x, y, alpha=3):
    """Scaled cosine error between two batches of vectors.

    Both inputs are L2-normalised along the last dimension; the loss is the
    mean over all rows of ``(1 - cosine_similarity) ** alpha``.
    """
    x_unit = F.normalize(x, p=2, dim=-1)
    y_unit = F.normalize(y, p=2, dim=-1)

    cosine = (x_unit * y_unit).sum(dim=-1)
    per_row = (1 - cosine).pow(alpha)
    return per_row.mean()
@@ -0,0 +1,177 @@
1
+ import random
2
+ import logging
3
+
4
+ import dgl
5
+ import torch
6
+ import numpy as np
7
+ import torch.nn as nn
8
+ from tqdm import tqdm
9
+ from functools import partial
10
+ from torch import optim as optim
11
+
12
+
13
+
14
+ logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
15
+
16
def accuracy(y_pred, y_true):
    """Return the fraction of rows whose arg-max prediction equals the label.

    Args:
        y_pred: 2-D tensor of per-class scores, one row per sample.
        y_true: Tensor of integer labels; it is flattened, so ``(N,)`` and
            ``(N, 1)`` are both accepted.

    Returns:
        Accuracy as a Python float.
    """
    # Flatten instead of ``squeeze()``: squeezing a single label yields a
    # 0-d tensor on which ``len()`` fails.
    labels = y_true.reshape(-1).long()
    preds = y_pred.max(1)[1].type_as(labels)
    correct = preds.eq(labels).double().sum().item()
    return correct / len(labels)
22
+
23
+
24
def set_random_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Also requests deterministic cuDNN algorithms.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # The flag is spelled ``deterministic``; a misspelled name is silently
    # accepted as a new attribute and has no effect.
    torch.backends.cudnn.deterministic = True
31
+
32
+
33
def get_current_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group."""
    groups = optimizer.state_dict()["param_groups"]
    first_group = groups[0]
    return first_group["lr"]
35
+
36
+
37
def create_activation(name):
    """Return a new activation module for ``name``.

    Supported names are ``"relu"``, ``"gelu"``, ``"prelu"``, ``"elu"`` and
    ``None`` (identity).

    Raises:
        NotImplementedError: For any other name.
    """
    factories = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "prelu": nn.PReLU,
        None: nn.Identity,
        "elu": nn.ELU,
    }
    try:
        factory = factories[name]
    except (KeyError, TypeError):
        raise NotImplementedError(f"{name} is not implemented.") from None
    return factory()
50
+
51
+
52
def create_norm(name):
    """Return a normalisation-layer factory for ``name``.

    Returns:
        ``nn.LayerNorm`` for ``"layernorm"``, ``nn.BatchNorm1d`` for
        ``"batchnorm"``, a ``NormLayer`` factory for ``"graphnorm"``, and
        ``None`` for anything else (meaning "no normalisation").
    """
    if name == "layernorm":
        return nn.LayerNorm
    if name == "batchnorm":
        return nn.BatchNorm1d
    if name == "graphnorm":
        # NormLayer only recognises "graphnorm"; any other spelling makes
        # its constructor raise NotImplementedError.
        return partial(NormLayer, norm_type="graphnorm")
    return None
61
+
62
+
63
def create_optimizer(opt, model, lr, weight_decay, get_num_layer=None, get_layer_scale=None):
    """Create a ``torch.optim`` optimizer for all parameters of ``model``.

    Args:
        opt: Optimizer name (case-insensitive).  Only the part after the
            last underscore is used, so ``"foo_adam"`` selects Adam.
            Supported: ``adam``, ``adamw``, ``adadelta``, ``radam``, ``sgd``
            (SGD uses momentum 0.9).
        model: Module whose parameters are optimised.
        lr: Learning rate.
        weight_decay: Weight-decay coefficient.
        get_num_layer: Unused; kept for interface compatibility.
        get_layer_scale: Unused; kept for interface compatibility.

    Returns:
        The optimizer instance.

    Raises:
        ValueError: If the optimizer name is not supported.
    """
    opt_name = opt.lower().split("_")[-1]

    parameters = model.parameters()
    opt_args = dict(lr=lr, weight_decay=weight_decay)

    if opt_name == "adam":
        return optim.Adam(parameters, **opt_args)
    if opt_name == "adamw":
        return optim.AdamW(parameters, **opt_args)
    if opt_name == "adadelta":
        return optim.Adadelta(parameters, **opt_args)
    if opt_name == "radam":
        return optim.RAdam(parameters, **opt_args)
    if opt_name == "sgd":
        opt_args["momentum"] = 0.9
        return optim.SGD(parameters, **opt_args)

    # An explicit exception instead of ``assert``: assertions are stripped
    # under ``python -O`` and the old expression dropped its message.
    raise ValueError(f"Invalid optimizer: {opt!r}")
86
+
87
+
88
+ # -------------------
89
def pretrain(model, graph, feat, optimizer, max_epoch, device):
    """Train ``model`` on the full graph for ``max_epoch`` steps.

    The graph and features are moved to ``device`` once.  Each epoch calls
    ``model(graph, x)``, which must return ``(loss, loss_dict)``, and takes
    one optimizer step.  Progress is shown with a tqdm bar.

    Returns:
        The trained model (the same object that was passed in).
    """
    logging.info("start training..")
    graph_on_device = graph.to(device)
    x_on_device = feat.to(device)

    progress = tqdm(range(max_epoch))
    for epoch in progress:
        model.train()

        loss, _ = model(graph_on_device, x_on_device)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        progress.set_description(f"# Epoch {epoch}: train_loss: {loss.item():.4f}")
    return model
106
+
107
+
108
def mask_edge(graph, mask_prob):
    """Sample the indices of the edges to keep.

    Every edge is kept independently with probability ``1 - mask_prob``.

    Returns:
        A 1-D index tensor of the kept edges, in ascending order.
    """
    num_edges = graph.num_edges()

    drop_probs = torch.FloatTensor(np.full(num_edges, mask_prob, dtype=np.float64))
    keep_flags = torch.bernoulli(1 - drop_probs)
    return keep_flags.nonzero(as_tuple=True)[0]
115
+
116
+
117
def drop_edge(graph, drop_rate, return_edges=False):
    """Return a copy of ``graph`` with a random subset of edges removed.

    Self-loops are added to the new graph after the edges are dropped.

    Args:
        graph: DGL graph.
        drop_rate: Probability of dropping each edge.  When it is ``<= 0``
            the input graph is returned unchanged and alone, regardless of
            ``return_edges``.
        return_edges: If true (and ``drop_rate > 0``), also return the
            dropped edges.

    Returns:
        The new graph, or ``(new_graph, (dropped_src, dropped_dst))`` when
        ``return_edges`` is true.
    """
    if drop_rate <= 0:
        return graph

    n_node = graph.num_nodes()
    src, dst = graph.edges()

    # ``mask_edge`` yields the *indices* of kept edges.  Turn them into a
    # boolean mask so the complement can be taken with ``~``; applying ``~``
    # to the index tensor itself is a bitwise NOT that produces negative
    # indices, not the dropped edges.
    kept_idx = mask_edge(graph, drop_rate).to(src.device)
    keep_mask = torch.zeros(src.shape[0], dtype=torch.bool, device=src.device)
    keep_mask[kept_idx] = True

    ng = dgl.graph((src[keep_mask], dst[keep_mask]), num_nodes=n_node)
    ng = ng.add_self_loop()

    if return_edges:
        return ng, (src[~keep_mask], dst[~keep_mask])
    return ng
138
+
139
+
140
+
141
class NormLayer(nn.Module):
    """Normalisation layer supporting batch, layer and graph normalisation.

    For ``"graphnorm"`` the features of each graph in a batch are centred by
    a learnable fraction (``mean_scale``) of the per-graph mean, divided by
    the per-graph standard deviation, and then scaled and shifted by
    ``weight`` and ``bias``.
    """

    def __init__(self, hidden_dim, norm_type):
        """
        Args:
            hidden_dim: Feature size.
            norm_type: ``"batchnorm"``, ``"layernorm"`` or ``"graphnorm"``.

        Raises:
            NotImplementedError: For any other ``norm_type``.
        """
        super().__init__()
        if norm_type == "batchnorm":
            self.norm = nn.BatchNorm1d(hidden_dim)
        elif norm_type == "layernorm":
            self.norm = nn.LayerNorm(hidden_dim)
        elif norm_type == "graphnorm":
            # The string marks "use the hand-written graph norm in forward".
            self.norm = norm_type
            self.weight = nn.Parameter(torch.ones(hidden_dim))
            self.bias = nn.Parameter(torch.zeros(hidden_dim))

            self.mean_scale = nn.Parameter(torch.ones(hidden_dim))
        else:
            raise NotImplementedError(f"Unsupported norm type {norm_type!r}.")

    def forward(self, graph, x):
        """Normalise ``x``.

        Args:
            graph: Batched graph; only used for ``"graphnorm"``, where its
                ``batch_num_nodes`` (method or sequence) gives the number of
                nodes of each graph in the batch.
            x: Node features with nodes along dim 0.

        Returns:
            The normalised features, same shape as ``x``.
        """
        tensor = x
        if self.norm is None:
            return tensor
        if not isinstance(self.norm, str):
            return self.norm(tensor)

        # ``batch_num_nodes`` is a method on some graph objects and a plain
        # sequence on others; accept both.
        counts = graph.batch_num_nodes
        if callable(counts):
            counts = counts()
        counts = torch.as_tensor(counts, dtype=torch.long, device=tensor.device)
        batch_size = counts.numel()

        trailing = (1,) * (tensor.dim() - 1)
        counts_b = counts.view((-1,) + trailing)  # broadcastable node counts
        batch_index = torch.arange(batch_size, device=tensor.device).repeat_interleave(counts)
        batch_index = batch_index.view((-1,) + trailing).expand_as(tensor)

        stat_shape = (batch_size,) + tuple(tensor.shape[1:])

        # Per-graph mean, broadcast back to one row per node.
        mean = tensor.new_zeros(stat_shape).scatter_add_(0, batch_index, tensor)
        mean = (mean / counts_b).repeat_interleave(counts, dim=0)
        sub = tensor - mean * self.mean_scale

        # Per-graph standard deviation of the centred features.
        std = tensor.new_zeros(stat_shape).scatter_add_(0, batch_index, sub.pow(2))
        std = (std / counts_b + 1e-6).sqrt().repeat_interleave(counts, dim=0)

        return self.weight * sub / std + self.bias
@@ -0,0 +1,183 @@
1
+ import os
2
+ import warnings
3
+ import itertools
4
+ warnings.filterwarnings("ignore")
5
+
6
+ import dgl
7
+ import torch
8
+ import anndata as ad
9
+ import scanpy as sc
10
+ import pandas as pd
11
+ import numpy as np
12
+ import matplotlib.pyplot as plt
13
+ from tqdm import tqdm
14
+ from scipy import sparse
15
+ from sklearn.cluster import KMeans
16
+ from sklearn.neighbors import BallTree
17
+ from scipy.spatial.distance import pdist, squareform
18
+
19
+ from . import utils
20
+
21
+
22
def read_10X_Visium(path,
                    genome=None,
                    count_file='filtered_feature_bc_matrix.h5',
                    library_id=None,
                    load_images=True,
                    quality='hires',
                    image_path=None):
    """Load a 10x Visium sample and add image-space spot coordinates.

    The data are read with ``sc.read_visium``.  For ``quality == "fulres"``
    the image at ``image_path`` is loaded and stored under
    ``uns["spatial"][library_id]["images"]["fulres"]`` and the raw spatial
    coordinates are used; otherwise the coordinates are multiplied by the
    ``tissue_<quality>_scalef`` scale factor.

    Returns:
        The AnnData object with ``obs["imagecol"]``, ``obs["imagerow"]`` and
        ``uns["spatial"][library_id]["use_quality"]`` set.
    """
    adata = sc.read_visium(
        path,
        genome=genome,
        count_file=count_file,
        library_id=library_id,
        load_images=load_images,
    )
    adata.var_names_make_unique()

    if library_id is None:
        # Fall back to the first library stored in the object.
        library_id = list(adata.uns["spatial"].keys())[0]

    if quality == "fulres":
        image_coor = adata.obsm["spatial"]
        full_image = plt.imread(image_path, 0)
        adata.uns["spatial"][library_id]["images"]["fulres"] = full_image
    else:
        scalefactors = adata.uns["spatial"][library_id]["scalefactors"]
        image_coor = adata.obsm["spatial"] * scalefactors["tissue_" + quality + "_scalef"]

    adata.obs["imagecol"] = image_coor[:, 0]
    adata.obs["imagerow"] = image_coor[:, 1]
    adata.uns["spatial"][library_id]["use_quality"] = quality
    return adata
49
+
50
def read_10X_Visium_with_label(path,
                               genome=None,
                               count_file='filtered_feature_bc_matrix.h5',
                               library_id=None,
                               load_images=True,
                               quality='hires',
                               image_path=None):
    """Load a 10x Visium sample together with optional spot annotations.

    Works like ``read_10X_Visium``; in addition, when ``metadata.tsv``
    exists inside ``path`` it replaces ``adata.obs`` (first column used as
    index) before the image coordinates are added.

    Args:
        path: Sample directory (``str`` or path-like).
        genome, count_file, library_id, load_images: Forwarded to
            ``sc.read_visium``.
        quality: ``"fulres"`` or the name used in the
            ``tissue_<quality>_scalef`` scale factor.
        image_path: Full-resolution image, read only for ``"fulres"``.

    Returns:
        The AnnData object with ``obs["imagecol"]``, ``obs["imagerow"]`` and
        ``uns["spatial"][library_id]["use_quality"]`` set.
    """
    adata = sc.read_visium(path,
                           genome=genome,
                           count_file=count_file,
                           library_id=library_id,
                           load_images=load_images,)
    adata.var_names_make_unique()
    if library_id is None:
        library_id = list(adata.uns["spatial"].keys())[0]
    if quality == "fulres":
        image_coor = adata.obsm["spatial"]
        img = plt.imread(image_path, 0)
        adata.uns["spatial"][library_id]["images"]["fulres"] = img
    else:
        scale = adata.uns["spatial"][library_id]["scalefactors"][
            "tissue_" + quality + "_scalef"]
        image_coor = adata.obsm["spatial"] * scale

    # os.path.join (rather than string concatenation) also accepts
    # pathlib.Path objects and paths with a trailing separator.
    metadata_file = os.path.join(path, "metadata.tsv")
    if os.path.exists(metadata_file):
        adata.obs = pd.read_table(metadata_file, sep="\t", index_col=0)

    adata.obs["imagecol"] = image_coor[:, 0]
    adata.obs["imagerow"] = image_coor[:, 1]
    adata.uns["spatial"][library_id]["use_quality"] = quality
    return adata
79
+
80
+
81
def read_stereo_seq(counts_data_path, position_path):
    """Load Stereo-seq counts and spot positions into an AnnData object.

    Both files are read as tab-separated tables.  The counts table's
    columns and the position table's ``label`` column are prefixed with
    ``"Spot_"`` to form spot names.  ``obsm["spatial"]`` holds the
    ``['y', 'x']`` columns of the position table; ``obs["imagecol"]`` is
    ``x`` and ``obs["imagerow"]`` is ``y``.  QC metrics are computed in
    place.

    NOTE(review): the position table is assigned to ``adata.obs`` as-is, so
    this presumably relies on its rows being in the same order as the counts
    columns -- confirm against the input files.
    """
    def spot_name(label):
        return 'Spot_' + str(label)

    positions = pd.read_csv(os.path.join(position_path), sep='\t')
    counts = pd.read_csv(os.path.join(counts_data_path), sep='\t', index_col=0)

    counts.columns = [spot_name(col) for col in counts.columns]
    positions.index = positions['label'].map(spot_name)

    adata = sc.AnnData(counts.T)
    adata.obs = positions
    adata.var_names_make_unique()

    yx = positions.loc[adata.obs_names, ['y', 'x']]
    adata.obsm["spatial"] = yx.to_numpy()
    sc.pp.calculate_qc_metrics(adata, inplace=True)

    adata.obs['imagecol'] = yx.iloc[:, 1]
    adata.obs['imagerow'] = yx.iloc[:, 0]
    return adata
97
+
98
def read_slide_seq(path,
                   library_id=None,
                   scale=None,
                   quality="hires",
                   spot_diameter_fullres=50,
                   background_color="white",):
    """Load a Slide-seq sample and attach a blank background image.

    Reads ``count_matrix.count`` and ``spatial.idx`` from ``path``.

    Args:
        path: Directory containing the two input files.
        library_id: Key used under ``uns["spatial"]``; defaults to
            ``"Slide-seq"``.
        scale: Factor applied to the ``x``/``y`` coordinates.  When ``None``
            it is chosen so that the largest coordinate maps to 2000.
        quality: Name under which the image and scale factor are stored.
        spot_diameter_fullres: Stored in the scale factors.
        background_color: ``"black"`` for a transparent black RGBA image;
            anything else gives an opaque white one.

    Returns:
        The AnnData object with coordinates in ``obs``/``obsm`` and the
        ``uns["spatial"]`` structure filled in.
    """
    count = pd.read_csv(os.path.join(path, "count_matrix.count"))
    meta = pd.read_csv(os.path.join(path, "spatial.idx"))

    # Use the module's ``anndata`` alias: a bare ``AnnData`` name is never
    # imported in this file.
    adata = ad.AnnData(count.iloc[:, 1:].set_index("gene").T)
    adata.var["ENSEMBL"] = count["ENSEMBL"].values
    adata.obs["index"] = meta["index"].values

    if scale is None:
        max_coor = np.max(meta[["x", "y"]].values)
        scale = 2000 / max_coor
    adata.obs["imagecol"] = meta["x"].values * scale
    adata.obs["imagerow"] = meta["y"].values * scale

    # Square canvas 10% larger than the largest image coordinate.
    max_size = np.max([adata.obs["imagecol"].max(), adata.obs["imagerow"].max()])
    max_size = int(max_size + 0.1 * max_size)

    # Build the RGBA canvas directly with NumPy (no imaging library needed):
    # all-zero for "black", all-255 for white.
    fill_value = 0 if background_color == "black" else 255
    imgarr = np.full((max_size, max_size, 4), fill_value, dtype=np.uint8)

    if library_id is None:
        library_id = "Slide-seq"

    adata.uns["spatial"] = {
        library_id: {
            "images": {quality: imgarr},
            "use_quality": quality,
            "scalefactors": {
                "tissue_" + quality + "_scalef": scale,
                "spot_diameter_fullres": spot_diameter_fullres,
            },
        }
    }
    adata.obsm["spatial"] = meta[["x", "y"]].values
    return adata
137
+
138
+
139
def Graph_10X(adata, args):
    """Build the spatial neighbour graph and its node features.

    Args:
        adata: AnnData object with ``obs["imagerow"]`` and
            ``obs["imagecol"]``.  It is modified in place (gene names are
            upper-cased, genes are filtered, and -- outside the PCA path --
            counts are normalised and log-transformed).
        args: Namespace providing ``graph`` (``"radius"`` or ``"knn"``),
            ``threshold_radius`` or ``num_neighbors`` accordingly,
            ``feature_dim_method`` and ``num_features``.

    Returns:
        ``(adata, graph)`` where ``graph`` is a DGL graph with edge weights
        in ``edata["w"]`` and node features in ``ndata["feat"]``.

    Raises:
        ValueError: If ``args.graph`` is neither ``"radius"`` nor ``"knn"``.
    """
    cell_loc = adata.obs[["imagerow", "imagecol"]].values
    n_cells = cell_loc.shape[0]

    if args.graph == 'radius':
        # Connect distinct spots closer than the threshold, plus self-loops.
        # NOTE: this materialises a dense n_cells x n_cells matrix.
        dist = squareform(pdist(cell_loc, metric="euclidean"))
        adj_matrix = ((dist > 0) & (dist < args.threshold_radius)).astype(np.float32)
        adj_matrix += np.eye(n_cells, dtype=np.float32)
        adj_matrix_crs = sparse.csr_matrix(adj_matrix)
    elif args.graph == 'knn':
        tree = BallTree(cell_loc)
        _, neighbors = tree.query(cell_loc, k=args.num_neighbors)
        # One (head, tail) pair per returned neighbour of every spot.
        head_list = np.repeat(np.arange(n_cells), neighbors.shape[1])
        tail_list = neighbors.ravel()
        weights = np.ones_like(head_list)
        adj_matrix_crs = sparse.coo_matrix(
            (weights, (head_list, tail_list)), shape=(n_cells, n_cells)
        ).tocsr()
    else:
        raise ValueError(
            f"Unknown graph type {args.graph!r}; expected 'radius' or 'knn'."
        )
    graph = dgl.from_scipy(adj_matrix_crs, eweight_name='w')

    adata.var_names = [i.upper() for i in list(adata.var_names)]
    adata.var["genename"] = adata.var.index.astype("str")
    adata.var_names_make_unique()

    if args.feature_dim_method == "PCA":
        sc.pp.filter_genes(adata, min_cells=5)
        adata_X = sc.pp.normalize_total(adata, target_sum=1, exclude_highly_expressed=True, inplace=False)['X']
        adata_X = sc.pp.scale(adata_X)
        adata_X = sc.pp.pca(adata_X, n_comps=args.num_features)
    else:
        sc.pp.filter_genes(adata, min_cells=5)
        sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=args.num_features)
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)
        adata_Vars = adata[:, adata.var['highly_variable']]
        # X may be sparse or already dense depending on how adata was built.
        selected = adata_Vars.X
        adata_X = selected.toarray() if sparse.issparse(selected) else np.asarray(selected)

    graph.ndata["feat"] = torch.tensor(adata_X.copy())
    return adata, graph
183
+
gsg-0.6.0/GSG/train.py ADDED
@@ -0,0 +1,21 @@
1
+ from . import models
2
+ from . import utils
3
+
4
def GSG_train(adata, graph, args):
    """Train the model and store embeddings (and imputed values) in ``adata``.

    Args:
        adata: AnnData object; receives ``obsm["GSG_embedding"]`` and, when
            ``args.imputation`` is truthy, ``obsm["GSG_imputation"]``.
        graph: DGL graph with node features in ``ndata["feat"]``.
        args: Namespace providing ``device`` (negative selects the CPU),
            ``seeds``, ``optimizer``, ``lr``, ``weight_decay``,
            ``load_model``, ``max_epoch`` and ``imputation``, plus the
            attributes read by ``models.build_model``.

    Returns:
        ``(adata, model)``.
    """
    import torch  # local import: this module has no file-level torch import

    device = args.device if args.device >= 0 else "cpu"
    utils.set_random_seed(args.seeds)
    model = models.build_model(args)
    model.to(device)
    optimizer = models.utils.create_optimizer(args.optimizer, model, args.lr, args.weight_decay)
    x = graph.ndata["feat"]
    # NOTE: when ``args.load_model`` is truthy training is skipped, but no
    # weights are loaded here.
    if not args.load_model:
        model = models.utils.pretrain(model, graph, x, optimizer, args.max_epoch, device)
    model.train(False)

    graph_on_device = graph.to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        embedding = model.embed(graph_on_device, x.to(device))
        adata.obsm["GSG_embedding"] = embedding.cpu().detach().numpy()
        if args.imputation:
            latent_embedding = model.encoder_to_decoder(embedding)
            # MLP/linear decoders take features only; graph decoders also
            # need the graph.  Passing the graph to an MLP decoder raises
            # TypeError.
            if getattr(model, "_decoder_type", None) in ("mlp", "linear"):
                imputation_embedding = model.decoder(latent_embedding)
            else:
                imputation_embedding = model.decoder(graph_on_device, latent_embedding)
            adata.obsm["GSG_imputation"] = imputation_embedding.cpu().detach().numpy()
    return adata, model
gsg-0.6.0/GSG/utils.py ADDED
@@ -0,0 +1,32 @@
1
+ import os
2
+ import random
3
+
4
+ import torch
5
+ import numpy as np
6
+ from scipy import sparse
7
+ from sklearn.cluster import KMeans
8
+
9
def mkdir(path):
    """Create ``path`` (including parents) if it does not exist yet.

    A short status message is printed in either case.
    """
    if os.path.exists(path):
        print("--- There is this folder! ---")
        return

    os.makedirs(path)
    print("--- new folder... ---")
    print("--- OK ---")
17
+
18
def KMeans_use(embedding, cluster_number):
    """Cluster ``embedding`` with k-means and return the cluster labels.

    Uses k-means++ initialisation and a fixed ``random_state`` of 0.
    """
    estimator = KMeans(
        n_clusters=cluster_number,
        init="k-means++",
        random_state=0,
    )
    return estimator.fit_predict(embedding)
24
+
25
+
26
def set_random_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Also requests deterministic cuDNN algorithms.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # The flag is spelled ``deterministic``; a misspelled name is silently
    # accepted as a new attribute and has no effect.
    torch.backends.cudnn.deterministic = True
gsg-0.6.0/LICENSE.txt ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 keaml-Guan
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
gsg-0.6.0/PKG-INFO ADDED
@@ -0,0 +1,64 @@
1
+ Metadata-Version: 2.1
2
+ Name: gsg
3
+ Version: 0.6.0
4
+ Summary: GSG: A generative self-supervised graph learning framework for spatial transcriptomics
5
+ Author-email: Chuyao Wang <chuyao25@mails.jlu.edu.cn>
6
+ License: MIT
7
+ Requires-Python: >=3.7
8
+ Description-Content-Type: text/markdown
9
+ License-File: LICENSE.txt
10
+
11
+ # A masked generative graph representation learning framework empowering precise spatial domain identification
12
+ ![GitHub Repo stars](https://img.shields.io/github/stars/keaml-Guan/GSG) &nbsp;&nbsp; ![GitHub watchers](https://img.shields.io/github/watchers/keaml-Guan/GSG) &nbsp;&nbsp; ![GitHub](https://img.shields.io/github/license/keaml-Guan/GSG)
13
+ #
14
+ ![](https://raw.githubusercontent.com/keaml-Guan/GSG/main/figures/Fig1_11_reduce.jpg)
15
+ <br>
16
+
17
+ ## ✨ Overview
18
+
19
+ Recent advances in spatial transcriptomics (ST) have opened new avenues for preserving spatial information while measuring gene expression. Yet, the challenge of seamlessly integrating this data into accurate and transferable representation remains. Here, we introduce a generative self-supervised graph (GSG) learning framework to achieve an effective joint embedding of location and gene expression within ST data. Our approach surpasses existing methods in identifying spatial domains within the human dorsolateral prefrontal cortex. Moreover, it can offer reliable analyses across various techniques, including Stereo-seq, Slide-seq, and seqFISH, irrespective of spatial resolution. Furthermore, GSG addresses dropout defects, enhancing gene expression by smoothing spatial patterns, extracting critical features, reducing batch effects, and enabling the integration of disparate datasets. Additionally, we performed spatial transcriptomic analysis on fetal human hearts, and applied GSG to extract biological insights. These experiments highlight GSG's accuracy in identifying spatial domains, uncovering specific *APCDD1* expression in fetal endocardium, and implicating its role in congenital heart disease. Our results showcase GSG's superiority and underscore its valuable contributions to advancing spatial-omics analysis.
20
+
21
+ ## 🛠️ Installation
22
+
23
+ > [!NOTE]
24
+ > **!!! The recommended operating system is Ubuntu 18.04 LTS.** Some packages may not download correctly on Windows.
25
+
26
+ ### Use a Python virtual environment with conda
27
+
28
+ ```sh
29
+ conda create -n gsg python=3.8
30
+ conda activate gsg
31
+ # Install the cuDNN build that matches your CUDA version. Refer to https://developer.nvidia.com/cudnn-archive
32
+ # conda install cudnn[=version]
33
+ ```
34
+
35
+ ### Install GSG
36
+
37
+ Install GSG and DGL (for GPU) from PyPI:
38
+
39
+ ```sh
40
+ pip install GSG==0.6.0
41
+ pip install dgl-cu110 -f https://data.dgl.ai/wheels/repo.html
42
+ ```
43
+
44
+ Required packages include:
45
+
46
+ ```sh
47
+ torch==1.9.0, cudnn==8.4, numpy==1.22.0, scanpy==1.8.2, anndata==0.8.0, dgl==0.9.0,
48
+ pandas==1.2.4, scipy==1.7.3, scikit-learn==1.0.1, tqdm==4.64.1, matplotlib==3.5.3,
49
+ tensorboardX==2.5.1, pyyaml==6.0.1, plotly==5.21.0, kaleido==0.2.1, igraph==0.9.8
50
+ ```
51
+
52
+ ## 🚀 Quick Start
53
+
54
+ See our model document details from [Docs](https://keaml-guan.github.io/GSG/).
55
+
56
+ We provide the scripts for reproducing the quantitative and visualization results of the paper in [/docs/tutorials/](https://github.com/keaml-Guan/GSG/tree/main/docs/tutorials/).
57
+
58
+ ## 📚 Citation
59
+
60
+ Wang C, Zhang T, Sun H, et al. A masked generative graph representation learning framework empowering precise spatial domain identification[J]. *Bioinformatics*, 2026, 42(6). [https://doi.org/10.1093/bioinformatics/btag333](https://doi.org/10.1093/bioinformatics/btag333).
61
+
62
+ ## 📩 Contact
63
+
64
+ If you have any questions, feel free to contact [chuyao25@mails.jlu.edu.cn](mailto:chuyao25@mails.jlu.edu.cn).
gsg-0.6.0/README.md ADDED
@@ -0,0 +1,54 @@
1
+ # A masked generative graph representation learning framework empowering precise spatial domain identification
2
+ ![GitHub Repo stars](https://img.shields.io/github/stars/keaml-Guan/GSG) &nbsp;&nbsp; ![GitHub watchers](https://img.shields.io/github/watchers/keaml-Guan/GSG) &nbsp;&nbsp; ![GitHub](https://img.shields.io/github/license/keaml-Guan/GSG)
3
+ #
4
+ ![](https://raw.githubusercontent.com/keaml-Guan/GSG/main/figures/Fig1_11_reduce.jpg)
5
+ <br>
6
+
7
+ ## ✨ Overview
8
+
9
+ Recent advances in spatial transcriptomics (ST) have opened new avenues for preserving spatial information while measuring gene expression. Yet, the challenge of seamlessly integrating this data into accurate and transferable representation remains. Here, we introduce a generative self-supervised graph (GSG) learning framework to achieve an effective joint embedding of location and gene expression within ST data. Our approach surpasses existing methods in identifying spatial domains within the human dorsolateral prefrontal cortex. Moreover, it can offer reliable analyses across various techniques, including Stereo-seq, Slide-seq, and seqFISH, irrespective of spatial resolution. Furthermore, GSG addresses dropout defects, enhancing gene expression by smoothing spatial patterns, extracting critical features, reducing batch effects, and enabling the integration of disparate datasets. Additionally, we performed spatial transcriptomic analysis on fetal human hearts, and applied GSG to extract biological insights. These experiments highlight GSG's accuracy in identifying spatial domains, uncovering specific *APCDD1* expression in fetal endocardium, and implicating its role in congenital heart disease. Our results showcase GSG's superiority and underscore its valuable contributions to advancing spatial-omics analysis.
10
+
11
+ ## 🛠️ Installation
12
+
13
+ > [!NOTE]
14
+ > **!!! The recommended operating system is Ubuntu 18.04 LTS.** Some packages may not download correctly on Windows.
15
+
16
+ ### Use a Python virtual environment with conda
17
+
18
+ ```sh
19
+ conda create -n gsg python=3.8
20
+ conda activate gsg
21
+ # Install the cuDNN build that matches your CUDA version. Refer to https://developer.nvidia.com/cudnn-archive
22
+ # conda install cudnn[=version]
23
+ ```
24
+
25
+ ### Install GSG
26
+
27
+ Install GSG and DGL (for GPU) from PyPI:
28
+
29
+ ```sh
30
+ pip install GSG==0.6.0
31
+ pip install dgl-cu110 -f https://data.dgl.ai/wheels/repo.html
32
+ ```
33
+
34
+ Required packages include:
35
+
36
+ ```sh
37
+ torch==1.9.0, cudnn==8.4, numpy==1.22.0, scanpy==1.8.2, anndata==0.8.0, dgl==0.9.0,
38
+ pandas==1.2.4, scipy==1.7.3, scikit-learn==1.0.1, tqdm==4.64.1, matplotlib==3.5.3,
39
+ tensorboardX==2.5.1, pyyaml==6.0.1, plotly==5.21.0, kaleido==0.2.1, igraph==0.9.8
40
+ ```
41
+
42
+ ## 🚀 Quick Start
43
+
44
+ See our model document details from [Docs](https://keaml-guan.github.io/GSG/).
45
+
46
+ We provide the scripts for reproducing the quantitative and visualization results of the paper in [/docs/tutorials/](https://github.com/keaml-Guan/GSG/tree/main/docs/tutorials/).
47
+
48
+ ## 📚 Citation
49
+
50
+ Wang C, Zhang T, Sun H, et al. A masked generative graph representation learning framework empowering precise spatial domain identification[J]. *Bioinformatics*, 2026, 42(6). [https://doi.org/10.1093/bioinformatics/btag333](https://doi.org/10.1093/bioinformatics/btag333).
51
+
52
+ ## 📩 Contact
53
+
54
+ If you have any questions, feel free to contact [chuyao25@mails.jlu.edu.cn](mailto:chuyao25@mails.jlu.edu.cn).
@@ -0,0 +1,64 @@
1
+ Metadata-Version: 2.1
2
+ Name: gsg
3
+ Version: 0.6.0
4
+ Summary: GSG: A generative self-supervised graph learning framework for spatial transcriptomics
5
+ Author-email: Chuyao Wang <chuyao25@mails.jlu.edu.cn>
6
+ License: MIT
7
+ Requires-Python: >=3.7
8
+ Description-Content-Type: text/markdown
9
+ License-File: LICENSE.txt
10
+
11
+ # A masked generative graph representation learning framework empowering precise spatial domain identification
12
+ ![GitHub Repo stars](https://img.shields.io/github/stars/keaml-Guan/GSG) &nbsp;&nbsp; ![GitHub watchers](https://img.shields.io/github/watchers/keaml-Guan/GSG) &nbsp;&nbsp; ![GitHub](https://img.shields.io/github/license/keaml-Guan/GSG)
13
+ #
14
+ ![](https://raw.githubusercontent.com/keaml-Guan/GSG/main/figures/Fig1_11_reduce.jpg)
15
+ <br>
16
+
17
+ ## ✨ Overview
18
+
19
+ Recent advances in spatial transcriptomics (ST) have opened new avenues for preserving spatial information while measuring gene expression. Yet, the challenge of seamlessly integrating this data into accurate and transferable representation remains. Here, we introduce a generative self-supervised graph (GSG) learning framework to achieve an effective joint embedding of location and gene expression within ST data. Our approach surpasses existing methods in identifying spatial domains within the human dorsolateral prefrontal cortex. Moreover, it can offer reliable analyses across various techniques, including Stereo-seq, Slide-seq, and seqFISH, irrespective of spatial resolution. Furthermore, GSG addresses dropout defects, enhancing gene expression by smoothing spatial patterns, extracting critical features, reducing batch effects, and enabling the integration of disparate datasets. Additionally, we performed spatial transcriptomic analysis on fetal human hearts, and applied GSG to extract biological insights. These experiments highlight GSG's accuracy in identifying spatial domains, uncovering specific *APCDD1* expression in fetal endocardium, and implicating its role in congenital heart disease. Our results showcase GSG's superiority and underscore its valuable contributions to advancing spatial-omics analysis.
20
+
21
+ ## 🛠️ Installation
22
+
23
+ > [!NOTE]
24
+ > **!!! The recommended operating system is Ubuntu 18.04 LTS.** Some packages may not download correctly on Windows.
25
+
26
+ ### Use a Python virtual environment with conda
27
+
28
+ ```sh
29
+ conda create -n gsg python=3.8
30
+ conda activate gsg
31
+ # Install the cuDNN build that matches your CUDA version. Refer to https://developer.nvidia.com/cudnn-archive
32
+ # conda install cudnn[=version]
33
+ ```
34
+
35
+ ### Install GSG
36
+
37
+ Install GSG and DGL (for GPU) from PyPI:
38
+
39
+ ```sh
40
+ pip install GSG==0.6.0
41
+ pip install dgl-cu110 -f https://data.dgl.ai/wheels/repo.html
42
+ ```
43
+
44
+ Required packages include:
45
+
46
+ ```sh
47
+ torch==1.9.0, cudnn==8.4, numpy==1.22.0, scanpy==1.8.2, anndata==0.8.0, dgl==0.9.0,
48
+ pandas==1.2.4, scipy==1.7.3, scikit-learn==1.0.1, tqdm==4.64.1, matplotlib==3.5.3,
49
+ tensorboardX==2.5.1, pyyaml==6.0.1, plotly==5.21.0, kaleido==0.2.1, igraph==0.9.8
50
+ ```
51
+
52
+ ## 🚀 Quick Start
53
+
54
+ See our model document details from [Docs](https://keaml-guan.github.io/GSG/).
55
+
56
+ We provide the scripts for reproducing the quantitative and visualization results of the paper in [/docs/tutorials/](https://github.com/keaml-Guan/GSG/tree/main/docs/tutorials/).
57
+
58
+ ## 📚 Citation
59
+
60
+ Wang C, Zhang T, Sun H, et al. A masked generative graph representation learning framework empowering precise spatial domain identification[J]. *Bioinformatics*, 2026, 42(6). [https://doi.org/10.1093/bioinformatics/btag333](https://doi.org/10.1093/bioinformatics/btag333).
61
+
62
+ ## 📩 Contact
63
+
64
+ If you have any questions, feel free to contact [chuyao25@mails.jlu.edu.cn](mailto:chuyao25@mails.jlu.edu.cn).
@@ -0,0 +1,17 @@
1
+ LICENSE.txt
2
+ README.md
3
+ pyproject.toml
4
+ GSG/__init__.py
5
+ GSG/preprocess.py
6
+ GSG/train.py
7
+ GSG/utils.py
8
+ GSG/models/__init__.py
9
+ GSG/models/edcoder.py
10
+ GSG/models/gin.py
11
+ GSG/models/loss_func.py
12
+ GSG/models/utils.py
13
+ gsg.egg-info/PKG-INFO
14
+ gsg.egg-info/SOURCES.txt
15
+ gsg.egg-info/dependency_links.txt
16
+ gsg.egg-info/requires.txt
17
+ gsg.egg-info/top_level.txt
@@ -0,0 +1,10 @@
1
+ numpy==1.21.6
2
+ pandas==1.2.4
3
+ scipy
4
+ scikit-learn
5
+ torch==1.9.0
6
+ dgl==0.9.0
7
+ scanpy==1.8.2
8
+ anndata==0.8.0
9
+ squidpy==1.1.2
10
+ leidenalg<0.11,>=0.8.2
@@ -0,0 +1 @@
1
+ GSG
@@ -0,0 +1,31 @@
1
+ [build-system]
2
+ requires = ["setuptools>=59.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "gsg"
7
+ version = "0.6.0"
8
+ description = "GSG: A generative self-supervised graph learning framework for spatial transcriptomics"
9
+ readme = "README.md"
10
+ requires-python = ">=3.7"
11
+ authors = [
12
+ {name = "Chuyao Wang", email = "chuyao25@mails.jlu.edu.cn"}
13
+ ]
14
+ license = {text = "MIT"}
15
+
16
+ dependencies = [
17
+ "numpy==1.21.6",
18
+ "pandas==1.2.4",
19
+ "scipy",
20
+ "scikit-learn",
21
+ "torch==1.9.0",
22
+ "dgl==0.9.0",
23
+ "scanpy==1.8.2",
24
+ "anndata==0.8.0",
25
+ "squidpy==1.1.2",
26
+ "leidenalg>=0.8.2,<0.11",
27
+ ]
28
+
29
+ [tool.setuptools.packages.find]
30
+ where = ["."]
31
+ include = ["GSG*"]
gsg-0.6.0/setup.cfg ADDED
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+