glam4cm-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. glam4cm/__init__.py +9 -0
  2. glam4cm/data_loading/__init__.py +0 -0
  3. glam4cm/data_loading/data.py +631 -0
  4. glam4cm/data_loading/encoding.py +76 -0
  5. glam4cm/data_loading/graph_dataset.py +940 -0
  6. glam4cm/data_loading/metadata.py +84 -0
  7. glam4cm/data_loading/models_dataset.py +361 -0
  8. glam4cm/data_loading/utils.py +20 -0
  9. glam4cm/downstream_tasks/__init__.py +0 -0
  10. glam4cm/downstream_tasks/bert_edge_classification.py +144 -0
  11. glam4cm/downstream_tasks/bert_graph_classification.py +137 -0
  12. glam4cm/downstream_tasks/bert_graph_classification_comp.py +156 -0
  13. glam4cm/downstream_tasks/bert_link_prediction.py +145 -0
  14. glam4cm/downstream_tasks/bert_node_classification.py +164 -0
  15. glam4cm/downstream_tasks/cm_gpt_edge_classification.py +73 -0
  16. glam4cm/downstream_tasks/cm_gpt_node_classification.py +76 -0
  17. glam4cm/downstream_tasks/cm_gpt_pretraining.py +64 -0
  18. glam4cm/downstream_tasks/common_args.py +160 -0
  19. glam4cm/downstream_tasks/create_dataset.py +51 -0
  20. glam4cm/downstream_tasks/gnn_edge_classification.py +106 -0
  21. glam4cm/downstream_tasks/gnn_graph_cls.py +101 -0
  22. glam4cm/downstream_tasks/gnn_link_prediction.py +109 -0
  23. glam4cm/downstream_tasks/gnn_node_classification.py +103 -0
  24. glam4cm/downstream_tasks/tf_idf_text_classification.py +22 -0
  25. glam4cm/downstream_tasks/utils.py +35 -0
  26. glam4cm/downstream_tasks/word2vec_text_classification.py +108 -0
  27. glam4cm/embeddings/__init__.py +0 -0
  28. glam4cm/embeddings/bert.py +72 -0
  29. glam4cm/embeddings/common.py +43 -0
  30. glam4cm/embeddings/fasttext.py +0 -0
  31. glam4cm/embeddings/tfidf.py +25 -0
  32. glam4cm/embeddings/w2v.py +41 -0
  33. glam4cm/encoding/__init__.py +0 -0
  34. glam4cm/encoding/common.py +0 -0
  35. glam4cm/encoding/encoders.py +100 -0
  36. glam4cm/graph2str/__init__.py +0 -0
  37. glam4cm/graph2str/common.py +34 -0
  38. glam4cm/graph2str/constants.py +15 -0
  39. glam4cm/graph2str/ontouml.py +141 -0
  40. glam4cm/graph2str/uml.py +0 -0
  41. glam4cm/lang2graph/__init__.py +0 -0
  42. glam4cm/lang2graph/archimate.py +31 -0
  43. glam4cm/lang2graph/bpmn.py +0 -0
  44. glam4cm/lang2graph/common.py +416 -0
  45. glam4cm/lang2graph/ecore.py +221 -0
  46. glam4cm/lang2graph/ontouml.py +169 -0
  47. glam4cm/lang2graph/utils.py +80 -0
  48. glam4cm/models/cmgpt.py +352 -0
  49. glam4cm/models/gnn_layers.py +273 -0
  50. glam4cm/models/hf.py +10 -0
  51. glam4cm/run.py +99 -0
  52. glam4cm/run_configs.py +126 -0
  53. glam4cm/settings.py +54 -0
  54. glam4cm/tokenization/__init__.py +0 -0
  55. glam4cm/tokenization/special_tokens.py +4 -0
  56. glam4cm/tokenization/utils.py +37 -0
  57. glam4cm/trainers/__init__.py +0 -0
  58. glam4cm/trainers/bert_classifier.py +105 -0
  59. glam4cm/trainers/cm_gpt_trainer.py +153 -0
  60. glam4cm/trainers/gnn_edge_classifier.py +126 -0
  61. glam4cm/trainers/gnn_graph_classifier.py +123 -0
  62. glam4cm/trainers/gnn_link_predictor.py +144 -0
  63. glam4cm/trainers/gnn_node_classifier.py +135 -0
  64. glam4cm/trainers/gnn_trainer.py +129 -0
  65. glam4cm/trainers/metrics.py +55 -0
  66. glam4cm/utils.py +194 -0
  67. glam4cm-0.1.0.dist-info/LICENSE +21 -0
  68. glam4cm-0.1.0.dist-info/METADATA +86 -0
  69. glam4cm-0.1.0.dist-info/RECORD +72 -0
  70. glam4cm-0.1.0.dist-info/WHEEL +5 -0
  71. glam4cm-0.1.0.dist-info/entry_points.txt +2 -0
  72. glam4cm-0.1.0.dist-info/top_level.txt +1 -0
glam4cm/models/cmgpt.py ADDED
@@ -0,0 +1,352 @@
+ import torch
+ import torch.nn as nn
+ from glam4cm.settings import device
+
+
+ def weights_init(model):
+     """
+     Initialize the weights of the model.
+     xavier_uniform is used for linear layers and embeddings,
+     zeros is used for biases.
+     xavier_uniform draws the weights from a uniform distribution scaled to keep
+     the activation variance stable, which helps avoid exploding gradients.
+     """
+
+     if isinstance(model, nn.Linear):
+         nn.init.xavier_uniform_(model.weight.data)
+         if model.bias is not None:
+             nn.init.zeros_(model.bias.data)
+     elif isinstance(model, nn.Embedding):
+         nn.init.xavier_uniform_(model.weight.data)
+     elif isinstance(model, nn.LayerNorm):
+         nn.init.ones_(model.weight.data)
+         nn.init.zeros_(model.bias.data)
+
+
+ class Head(nn.Module):
+     """ one head of self-attention """
+
+     def __init__(self, embed_dim, head_size, dropout=0.1):
+         super().__init__()
+         self.key = nn.Linear(embed_dim, head_size, bias=False)
+         self.query = nn.Linear(embed_dim, head_size, bias=False)
+         self.value = nn.Linear(embed_dim, head_size, bias=False)
+         # causal-mask buffer; not used in forward, where masking comes from attention_mask
+         self.register_buffer('tril', torch.tril(torch.ones(head_size, head_size)))
+         self.softmax = nn.Softmax(dim=-1)
+
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x, attention_mask):
+         """
+         x: [batch_size, seq_len, embed_dim]
+         attention_mask: [batch_size, seq_len]
+
+         This method computes the attention scores between each pair of tokens in the sequence
+         """
+         _, _, C = x.shape
+         k = self.key(x)
+         q = self.query(x)
+
+         # Compute attention scores ("affinities") only where the mask is non-zero
+         wei = q @ k.transpose(-2, -1) * C**-0.5
+         wei = wei.masked_fill((attention_mask.unsqueeze(1) == 0), float('-inf'))
+         wei = self.softmax(wei)
+         wei = self.dropout(wei)
+
+         # Perform the weighted aggregation of the values
+         v = self.value(x)
+         out = wei @ v
+         return out
+
+
+ class MultiHeadAttention(nn.Module):
+     """
+     multiple heads of self-attention in parallel
+     This class first splits the embedding dimension into multiple heads
+     Then, each head computes the attention scores between each pair of tokens in the sequence
+     Finally, the outputs of all the heads are concatenated and projected back to the original embedding dimension
+     """
+
+     def __init__(self, embed_dim, num_heads, dropout=0.1):
+         super().__init__()
+         head_size = embed_dim // num_heads
+         self.heads = nn.ModuleList([Head(embed_dim, head_size) for _ in range(num_heads)])
+         self.proj = nn.Linear(embed_dim, embed_dim)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x, attn_mask):
+         """
+         x: [batch_size, seq_len, embed_dim]
+         """
+         out = torch.cat([h(x, attn_mask) for h in self.heads], dim=-1)
+         out = self.dropout(self.proj(out))
+         return out
+
+
+ class FeedFoward(nn.Module):
+     """
+     a two-layer MLP: a linear expansion, a ReLU non-linearity, and a linear projection with dropout
+     """
+
+     def __init__(self, input_dim, embed_dim=None, num_classes=None, dropout=0.1):
+         super().__init__()
+
+         if num_classes is None:
+             num_classes = input_dim if embed_dim is None else embed_dim
+
+         if embed_dim is None:
+             embed_dim = input_dim
+
+         self.net = nn.Sequential(
+             nn.Linear(input_dim, 4 * embed_dim),
+             nn.ReLU(),
+             nn.Linear(4 * embed_dim, num_classes),
+             nn.Dropout(dropout),
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+
+ class Block(nn.Module):
+     """ Transformer block: communication followed by computation """
+
+     def __init__(self, embed_dim, n_head):
+         # embed_dim: embedding dimension, n_head: the number of heads we'd like
+         super().__init__()
+         self.sa = MultiHeadAttention(embed_dim, n_head)
+         self.ffwd = FeedFoward(embed_dim)
+         self.ln1 = nn.LayerNorm(embed_dim)
+         self.ln2 = nn.LayerNorm(embed_dim)
+
+     def forward(self, x, attn_mask):
+         # pre-norm residual connections around attention and the feed-forward layer
+         x = x + self.sa(self.ln1(x), attn_mask)
+         x = x + self.ffwd(self.ln2(x))
+         return x
+
+
+ class CMGPT(nn.Module):
+     """
+     UML-GPT model
+
+     vocab_size: the size of the vocabulary
+     embed_dim: the embedding dimension
+     block_size: the maximum sequence length
+     n_layer: the number of transformer blocks
+     n_head: the number of heads in each transformer block
+     load_pretrained_from: the path to the pretrained model
+
+     This class uses the string representation of the node as the input.
+     The string representation is tokenized using the tokenizer.
+     The tokenized sequence is then passed through the transformer blocks.
+     Finally, the logits for the next token are computed using a linear layer.
+     """
+     def __init__(
+         self,
+         vocab_size,
+         embed_dim,
+         block_size,
+         n_layer,
+         n_head,
+         load_pretrained_from=None
+     ):
+         super().__init__()
+         # each token directly reads off the logits for the next token from a lookup table
+         self.token_embedding_table = nn.Embedding(vocab_size, embed_dim)
+         self.position_embedding_table = nn.Embedding(block_size, embed_dim)
+         self.blocks = nn.Sequential(*[Block(embed_dim, n_head) for _ in range(n_layer)])
+         self.ln_f = nn.LayerNorm(embed_dim)  # final layer norm
+         self.lm_head = nn.Linear(embed_dim, vocab_size)
+
+         if load_pretrained_from is not None:
+             # the submodules above must exist before a saved state dict can be loaded into them
+             self.load_state_dict(torch.load(load_pretrained_from))
+         else:
+             self.apply(weights_init)
+
+     def forward(self, x, attention_mask, labels=None):
+         """
+         x: [batch_size, seq_len]
+         attention_mask: [batch_size, seq_len]
+
+         This method computes the logits for the next token
+         """
+         embeddings = self.get_embedding(x, attention_mask)
+         logits = self.lm_head(embeddings)
+         if labels is not None:
+             loss = self.get_loss(logits, labels)
+             return logits, loss
+         return logits
+
+     def get_loss(self, logits, labels, ignore_index=-100):
+         """
+         logits: [batch_size, seq_len, vocab_size]
+         labels: [batch_size, seq_len]
+
+         This method computes the loss for the next-token prediction task.
+         This is achieved by shifting the labels by one position and computing the cross entropy loss.
+         """
+         block_size = self.position_embedding_table.weight.shape[0]
+         labels = labels[..., :block_size]
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             loss_fct = nn.CrossEntropyLoss(ignore_index=ignore_index)
+             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+         return loss
+
+     def get_embedding(self, x, attention_mask):
+         """
+         x: [batch_size, seq_len]
+         attention_mask: [batch_size, seq_len]
+         """
+         block_size = self.position_embedding_table.weight.shape[0]
+         vocab_size = self.token_embedding_table.weight.shape[0]
+
+         # truncate the inputs to the maximum sequence length
+         x = x[..., :block_size]
+         attention_mask = attention_mask[..., :block_size]
+
+         assert x.shape[-1] <= block_size, f"Sequence length {x.shape[-1]} is greater than block size {block_size}"
+         assert torch.min(x) >= 0, f"Min token id {torch.min(x)} is negative"
+         assert torch.max(x) < vocab_size, f"Max token id {torch.max(x)} is not smaller than vocab size {vocab_size}"
+
+         token_embeddings = self.token_embedding_table(x)
+
+         position_ids = torch.arange(x.size(1), dtype=torch.long, device=x.device)
+         position_ids = position_ids.unsqueeze(0).expand_as(x)
+
+         assert torch.max(position_ids) < block_size, f"Max position id {torch.max(position_ids)} is not smaller than block size {block_size}"
+         position_embeddings = self.position_embedding_table(position_ids)
+
+         embeddings = token_embeddings + position_embeddings
+
+         # each block applies masked self-attention followed by the feed-forward layer
+         for block in self.blocks:
+             embeddings = block(embeddings, attention_mask)
+
+         embeddings = self.ln_f(embeddings)
+         return embeddings
+
+     def get_model_size(self):
+         return sum(p.numel() for p in self.parameters() if p.requires_grad)
+
+     def __repr__(self):
+         return super().__repr__() + f'\nNumber of parameters: {self.get_model_size() / 1000000:.3f}M'
+
+     @property
+     def __name__(self):
+         return 'CMGPT'
+
+     @property
+     def name_or_path(self):
+         return 'CMGPT'
+
+     @staticmethod
+     def from_pretrained(state_dict_pth):
+         # recover the architecture hyperparameters from the shapes and names in the state dict
+         state_dict = torch.load(state_dict_pth, map_location=device)
+         vocab_size, embed_dim = [s.shape for _, s in state_dict.items() if 'token_embedding_table' in _][0]
+         num_heads = max([int(name.split('.sa.heads.')[1].split('.')[0]) for name, s in state_dict.items() if '.sa.heads.' in name]) + 1
+         block_size = [s.shape[0] for _, s in state_dict.items() if 'position_embedding_table' in _][0]
+         num_layers = max([int(name.split('blocks.')[1].split('.')[0]) for name, s in state_dict.items() if 'blocks.' in name]) + 1
+         model = CMGPT(vocab_size, embed_dim, block_size, num_layers, num_heads)
+         model.load_state_dict(state_dict)
+         return model
+
+
+ class CMGPTClassifier(nn.Module):
+     """
+     UML-GPT model for classification
+
+     model: the UML-GPT model
+     num_classes: the number of classes
+     """
+     def __init__(
+         self,
+         model: CMGPT,
+         num_classes: int
+     ):
+         super().__init__()
+
+         self.model = model
+         _, embed_dim = self.model.lm_head.weight.data.shape
+         self.classifier = FeedFoward(input_dim=embed_dim, num_classes=num_classes)
+         self.apply(weights_init)
+
+     def forward(self, x, attention_mask, labels=None, pool=None):
+         # x: [batch_size, seq_len]
+         # attention_mask: [batch_size, seq_len]
+         lm_logits = self.model.get_embedding(x, attention_mask)
+         if pool:
+             # Pool the embeddings across the sequence dimension
+             lm_logits = torch.mean(lm_logits, dim=1)
+         else:
+             # Use the embedding at the last position
+             lm_logits = lm_logits[:, -1, :]
+
+         logits = self.classifier(lm_logits)
+
+         if labels is not None:
+             loss = self.get_loss(logits, labels)
+             return logits, loss
+         return logits
+
+     def get_loss(self, logits, labels):
+         logits = logits.to(device)
+         labels = labels.to(device)
+
+         if len(labels.shape) == 1:
+             # single-label classification
+             loss_fct = torch.nn.CrossEntropyLoss()
+             loss = loss_fct(logits, labels)
+         else:
+             # multi-label classification
+             loss_fct = torch.nn.BCEWithLogitsLoss()
+             loss = loss_fct(logits.float(), labels.float())
+         return loss
+
+     def get_embedding(self, x, attention_mask):
+         return self.model.get_embedding(x, attention_mask)
+
+     def get_model_size(self):
+         return sum(p.numel() for p in self.parameters() if p.requires_grad)
+
+     def __repr__(self):
+         return super().__repr__() + f'\nNumber of parameters: {self.get_model_size()/1000000:.3f}M'
+
+     @staticmethod
+     def from_pretrained(state_dict_path, num_classes, init_classifier=True):
+         if init_classifier:
+             print("Initializing classifier from pretrained model with num classes: ", num_classes)
+             model = CMGPTClassifier(CMGPT.from_pretrained(state_dict_path), num_classes)
+         else:
+             # rebuild the full classifier (backbone + head) from a saved classifier state dict
+             state_dict = torch.load(state_dict_path, map_location=device)
+             vocab_size, embed_dim = [s.shape for _, s in state_dict.items() if 'token_embedding_table' in _][0]
+             num_heads = max([int(name.split('.sa.heads.')[1].split('.')[0]) for name, s in state_dict.items() if '.sa.heads.' in name]) + 1
+             block_size = [s.shape[0] for _, s in state_dict.items() if 'position_embedding_table' in _][0]
+             num_layers = max([int(name.split('blocks.')[1].split('.')[0]) for name, s in state_dict.items() if 'blocks.' in name]) + 1
+             num_classes = state_dict['classifier.net.2.weight'].shape[0]
+             uml_gpt = CMGPT(vocab_size, embed_dim, block_size, num_layers, num_heads)
+
+             model = CMGPTClassifier(uml_gpt, num_classes)
+             model.load_state_dict(state_dict)
+
+         return model
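
For orientation, here is a minimal usage sketch of the CMGPT backbone and the CMGPTClassifier head defined above. The hyperparameters, batch shapes, and token ids are illustrative assumptions rather than values shipped with the package; in the package itself, token ids come from its tokenization utilities and training goes through the trainer classes listed above.

    import torch
    from glam4cm.models.cmgpt import CMGPT, CMGPTClassifier

    # toy configuration (assumed values, not package defaults)
    vocab_size, embed_dim, block_size, n_layer, n_head = 1000, 128, 64, 2, 4
    model = CMGPT(vocab_size, embed_dim, block_size, n_layer, n_head)

    # a batch of 2 sequences of 16 token ids with a full attention mask
    x = torch.randint(0, vocab_size, (2, 16))
    attention_mask = torch.ones(2, 16, dtype=torch.long)

    # pretraining-style forward pass: passing the inputs as labels makes
    # get_loss shift them by one position, so position t predicts token t+1
    logits, loss = model(x, attention_mask, labels=x)
    print(logits.shape)                              # torch.Size([2, 16, 1000])

    # classification head on top of the backbone; pool=True mean-pools the
    # sequence embeddings, otherwise only the last position is used
    clf = CMGPTClassifier(model, num_classes=5)
    print(clf(x, attention_mask, pool=True).shape)   # torch.Size([2, 5])
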
glam4cm/models/gnn_layers.py ADDED
@@ -0,0 +1,273 @@
+ import torch
+ from torch.nn import functional as F
+ from torch_geometric.nn import aggr
+ from torch_geometric.nn import (
+     global_add_pool,
+     global_max_pool,
+     global_mean_pool,
+ )
+ import torch_geometric
+ import torch.nn as nn
+
+
+ aggregation_methods = {
+     'mean': aggr.MeanAggregation(),
+     'sum': aggr.SumAggregation(),
+     'max': aggr.MaxAggregation(),
+     'mul': aggr.MulAggregation(),
+ }
+
+ supported_conv_models = {
+     'GCNConv': False,  ## True or False if the model requires num_heads
+     'GraphConv': False,
+     'GATConv': True,
+     'SAGEConv': False,
+     'GINConv': False,
+     'GATv2Conv': True,
+ }
+
+ global_pooling_methods = {
+     'sum': global_add_pool,
+     'mean': global_mean_pool,
+     'max': global_max_pool,
+ }
+
+
+ class GNNConv(torch.nn.Module):
+     """
+     A general GNN model created using the PyTorch Geometric library
+     model_name: the name of the GNN convolution layer
+     input_dim: the input dimension
+     hidden_dim: the hidden dimension
+     out_dim: the output dimension
+
+     num_layers: the number of GNN layers
+     num_heads: the number of heads in the GNN layer
+     residual: whether to use residual connections
+     l_norm: whether to use layer normalization
+     dropout: the dropout probability
+     """
+     def __init__(
+         self,
+         model_name,
+         input_dim,
+         hidden_dim,
+         out_dim=None,
+         num_layers=2,
+         num_heads=None,
+         residual=False,
+         l_norm=False,
+         dropout=0.1,
+         aggregation='mean',
+         edge_dim=None
+     ):
+         super(GNNConv, self).__init__()
+
+         assert model_name in supported_conv_models, f"Model {model_name} not supported. Choose from {supported_conv_models.keys()}"
+         heads_supported = supported_conv_models[model_name]
+         if heads_supported and num_heads is None:
+             raise ValueError(f"Model {model_name} requires num_heads to be set to an integer")
+
+         if not heads_supported and num_heads is not None:
+             num_heads = None
+
+         assert aggregation in aggregation_methods, f"Aggregation method {aggregation} not supported. Choose from {aggregation_methods.keys()}"
+         aggregation = aggregation_methods[aggregation]
+
+         self.input_dim = input_dim
+         self.embed_dim = hidden_dim
+         self.out_dim = out_dim if out_dim is not None else hidden_dim
+         self.num_layers = num_layers
+         self.num_heads = num_heads
+         self.aggregation = aggregation
+         self.edge_dim = edge_dim
+
+         gnn_model = getattr(torch_geometric.nn, model_name)
+         self.conv_layers = nn.ModuleList()
+
+         for i in range(num_layers):
+             if num_heads is None:
+                 conv = gnn_model(
+                     input_dim,
+                     hidden_dim if i != num_layers - 1 else self.out_dim,
+                     aggr=aggregation
+                 )
+             else:
+                 # attention-based layers concatenate their heads, so inputs grow by num_heads after the first layer
+                 conv = gnn_model(
+                     input_dim if i == 0 else num_heads*input_dim,
+                     hidden_dim if i != num_layers - 1 else self.out_dim,
+                     heads=num_heads,
+                     aggr=aggregation,
+                     edge_dim=edge_dim
+                 )
+             self.conv_layers.append(conv)
+             input_dim = hidden_dim
+
+         self.activation = nn.ReLU()
+         self.layer_norm = nn.LayerNorm(hidden_dim if num_heads is None else num_heads*hidden_dim) if l_norm else None
+         self.residual = residual
+         self.dropout = nn.Dropout(dropout) if dropout > 0 else None
+
+     def forward(self, in_feat, edge_index, edge_attr=None):
+
+         def activate(h):
+             # activation, optional layer norm, and optional dropout after each conv layer
+             h = self.activation(h)
+
+             if self.layer_norm is not None:
+                 h = self.layer_norm(h)
+
+             if self.dropout is not None:
+                 h = self.dropout(h)
+             return h
+
+         h = in_feat
+         h = self.conv_layers[0](h, edge_index, edge_attr) if isinstance(edge_attr, torch.Tensor) else self.conv_layers[0](h, edge_index)
+         h = activate(h)
+
+         for conv in self.conv_layers[1:-1]:
+             nh = conv(h, edge_index, edge_attr) if isinstance(edge_attr, torch.Tensor) else conv(h, edge_index)
+             h = nh if not self.residual else nh + h
+             h = activate(h)
+
+         # note: the final layer is applied without edge_attr
+         h = self.conv_layers[-1](h, edge_index)
+         h = activate(h)
+         return h
+
+
+ class EdgeClassifer(nn.Module):
+
+     """
+     An MLP classifier for edges (also usable as a link predictor)
+
+     input_dim: the node embedding dimension
+     num_classes: the number of classes
+     num_layers: the number of layers in the MLP
+
+     This class concatenates the node embeddings of the two endpoints of an edge
+     (optionally together with the edge features).
+     The concatenated embeddings are then passed through an MLP.
+     """
+
+     def __init__(
+         self,
+         input_dim,
+         hidden_dim,
+         num_classes,
+         num_layers=2,
+         dropout=0.3,
+         edge_dim=None,
+         bias=False
+     ):
+         super().__init__()
+         self.layers = nn.ModuleList()
+         self.input_dim = input_dim
+         self.embed_dim = hidden_dim
+         self.num_layers = num_layers
+         self.num_classes = num_classes
+
+         in_feats = input_dim * 2
+         if edge_dim is not None:
+             in_feats += edge_dim
+
+         for _ in range(num_layers):
+             self.layers.append(nn.Linear(in_feats, hidden_dim, bias=bias))
+             self.layers.append(nn.ReLU())
+             self.layers.append(nn.Dropout(dropout))
+             in_feats = hidden_dim
+
+         self.layers.append(nn.Linear(hidden_dim, num_classes, bias=bias))
+
+     def forward(self, x, edge_index, edge_attr=None):
+         # concatenate source and target node embeddings (and edge features, if given)
+         h = torch.cat([x[edge_index[0]], x[edge_index[1]]], dim=-1)
+         if edge_attr is not None:
+             h = torch.cat([h, edge_attr], dim=-1)
+
+         for layer in self.layers:
+             h = layer(h)
+
+         return h
+
+
+ class NodeClassifier(nn.Module):
+
+     """
+     An MLP classifier for nodes
+
+     input_dim: the node embedding dimension
+     num_classes: the number of classes
+     num_layers: the number of layers in the MLP
+
+     Node embeddings are passed through an MLP to produce per-node class logits.
+     """
+
+     def __init__(
+         self,
+         input_dim,
+         hidden_dim,
+         num_classes,
+         num_layers=2,
+         dropout=0.3,
+         bias=True
+     ):
+         super().__init__()
+         self.layers = nn.ModuleList()
+         self.embed_dim = hidden_dim
+         self.num_layers = num_layers
+         self.num_classes = num_classes
+
+         for _ in range(num_layers - 1):
+             self.layers.append(nn.Linear(input_dim, hidden_dim, bias=bias))
+             self.layers.append(nn.ReLU())
+             self.layers.append(nn.Dropout(dropout))
+             input_dim = hidden_dim
+
+         self.layers.append(nn.Linear(hidden_dim, num_classes, bias=bias))
+
+     def forward(self, x):
+         h = x
+         for layer in self.layers:
+             h = layer(h)
+
+         return h
+
+
+ class GraphClassifer(nn.Module):
+
+     """
+     An MLP classifier for graphs
+
+     input_dim: the node embedding dimension
+     num_classes: the number of classes
+     global_pool: the global pooling method used to aggregate node embeddings into a graph embedding
+
+     Node embeddings are pooled into a single graph embedding,
+     which is then passed through a linear layer to produce per-graph logits.
+     """
+
+     def __init__(
+         self,
+         input_dim,
+         num_classes,
+         global_pool='mean',
+         bias=False
+     ):
+         super().__init__()
+         self.layers = nn.ModuleList()
+         self.input_dim = input_dim
+         self.num_classes = num_classes
+
+         self.layers.append(nn.Linear(input_dim, num_classes, bias=bias))
+         self.global_pool = global_pooling_methods[global_pool]
+
+     def forward(self, x, batch):
+         # pool node embeddings per graph using the batch assignment vector
+         h = self.global_pool(x, batch)
+         for layer in self.layers:
+             h = layer(h)
+
+         return h
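
As a quick illustration of how these pieces fit together, the sketch below encodes a toy graph with GNNConv and feeds the node embeddings to the node and edge heads. The graph, feature sizes, and class counts are made-up assumptions for demonstration, and torch_geometric must be installed.

    import torch
    from glam4cm.models.gnn_layers import GNNConv, NodeClassifier, EdgeClassifer

    # toy graph: 4 nodes with 16-dim features, 3 directed edges
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])

    # two SAGEConv layers producing 32-dim node embeddings
    encoder = GNNConv('SAGEConv', input_dim=16, hidden_dim=32, num_layers=2)
    h = encoder(x, edge_index)                   # [4, 32]

    # per-node logits over 5 classes
    node_clf = NodeClassifier(input_dim=32, hidden_dim=64, num_classes=5)
    print(node_clf(h).shape)                     # torch.Size([4, 5])

    # per-edge logits: endpoint embeddings are concatenated inside the head
    edge_clf = EdgeClassifer(input_dim=32, hidden_dim=64, num_classes=3)
    print(edge_clf(h, edge_index).shape)         # torch.Size([3, 3])
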
glam4cm/models/hf.py ADDED
@@ -0,0 +1,10 @@
+ from transformers import AutoModelForSequenceClassification
+
+ def get_model(model_name, num_labels, len_tokenizer=None) -> AutoModelForSequenceClassification:
+     model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
+     if len_tokenizer:
+         # grow the embedding matrix to cover any tokens added to the tokenizer
+         model.resize_token_embeddings(len_tokenizer)
+         assert model.config.vocab_size == len_tokenizer, \
+             f"Tokenizer size {len_tokenizer} does not match model vocab size {model.config.vocab_size}"
+
+     return model
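
A small usage sketch for get_model. The checkpoint name and the added special token are assumptions for illustration; the point is that when extra tokens are added to the tokenizer, passing len(tokenizer) resizes the model's embedding matrix to match.

    from transformers import AutoTokenizer
    from glam4cm.models.hf import get_model

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    tokenizer.add_special_tokens({'additional_special_tokens': ['<node>']})

    # sequence classifier whose embeddings cover the enlarged vocabulary
    model = get_model('bert-base-uncased', num_labels=4, len_tokenizer=len(tokenizer))

    inputs = tokenizer('A <node> in a model graph', return_tensors='pt')
    print(model(**inputs).logits.shape)          # torch.Size([1, 4])
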