glam4cm-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. glam4cm/__init__.py +9 -0
  2. glam4cm/data_loading/__init__.py +0 -0
  3. glam4cm/data_loading/data.py +631 -0
  4. glam4cm/data_loading/encoding.py +76 -0
  5. glam4cm/data_loading/graph_dataset.py +940 -0
  6. glam4cm/data_loading/metadata.py +84 -0
  7. glam4cm/data_loading/models_dataset.py +361 -0
  8. glam4cm/data_loading/utils.py +20 -0
  9. glam4cm/downstream_tasks/__init__.py +0 -0
  10. glam4cm/downstream_tasks/bert_edge_classification.py +144 -0
  11. glam4cm/downstream_tasks/bert_graph_classification.py +137 -0
  12. glam4cm/downstream_tasks/bert_graph_classification_comp.py +156 -0
  13. glam4cm/downstream_tasks/bert_link_prediction.py +145 -0
  14. glam4cm/downstream_tasks/bert_node_classification.py +164 -0
  15. glam4cm/downstream_tasks/cm_gpt_edge_classification.py +73 -0
  16. glam4cm/downstream_tasks/cm_gpt_node_classification.py +76 -0
  17. glam4cm/downstream_tasks/cm_gpt_pretraining.py +64 -0
  18. glam4cm/downstream_tasks/common_args.py +160 -0
  19. glam4cm/downstream_tasks/create_dataset.py +51 -0
  20. glam4cm/downstream_tasks/gnn_edge_classification.py +106 -0
  21. glam4cm/downstream_tasks/gnn_graph_cls.py +101 -0
  22. glam4cm/downstream_tasks/gnn_link_prediction.py +109 -0
  23. glam4cm/downstream_tasks/gnn_node_classification.py +103 -0
  24. glam4cm/downstream_tasks/tf_idf_text_classification.py +22 -0
  25. glam4cm/downstream_tasks/utils.py +35 -0
  26. glam4cm/downstream_tasks/word2vec_text_classification.py +108 -0
  27. glam4cm/embeddings/__init__.py +0 -0
  28. glam4cm/embeddings/bert.py +72 -0
  29. glam4cm/embeddings/common.py +43 -0
  30. glam4cm/embeddings/fasttext.py +0 -0
  31. glam4cm/embeddings/tfidf.py +25 -0
  32. glam4cm/embeddings/w2v.py +41 -0
  33. glam4cm/encoding/__init__.py +0 -0
  34. glam4cm/encoding/common.py +0 -0
  35. glam4cm/encoding/encoders.py +100 -0
  36. glam4cm/graph2str/__init__.py +0 -0
  37. glam4cm/graph2str/common.py +34 -0
  38. glam4cm/graph2str/constants.py +15 -0
  39. glam4cm/graph2str/ontouml.py +141 -0
  40. glam4cm/graph2str/uml.py +0 -0
  41. glam4cm/lang2graph/__init__.py +0 -0
  42. glam4cm/lang2graph/archimate.py +31 -0
  43. glam4cm/lang2graph/bpmn.py +0 -0
  44. glam4cm/lang2graph/common.py +416 -0
  45. glam4cm/lang2graph/ecore.py +221 -0
  46. glam4cm/lang2graph/ontouml.py +169 -0
  47. glam4cm/lang2graph/utils.py +80 -0
  48. glam4cm/models/cmgpt.py +352 -0
  49. glam4cm/models/gnn_layers.py +273 -0
  50. glam4cm/models/hf.py +10 -0
  51. glam4cm/run.py +99 -0
  52. glam4cm/run_configs.py +126 -0
  53. glam4cm/settings.py +54 -0
  54. glam4cm/tokenization/__init__.py +0 -0
  55. glam4cm/tokenization/special_tokens.py +4 -0
  56. glam4cm/tokenization/utils.py +37 -0
  57. glam4cm/trainers/__init__.py +0 -0
  58. glam4cm/trainers/bert_classifier.py +105 -0
  59. glam4cm/trainers/cm_gpt_trainer.py +153 -0
  60. glam4cm/trainers/gnn_edge_classifier.py +126 -0
  61. glam4cm/trainers/gnn_graph_classifier.py +123 -0
  62. glam4cm/trainers/gnn_link_predictor.py +144 -0
  63. glam4cm/trainers/gnn_node_classifier.py +135 -0
  64. glam4cm/trainers/gnn_trainer.py +129 -0
  65. glam4cm/trainers/metrics.py +55 -0
  66. glam4cm/utils.py +194 -0
  67. glam4cm-0.1.0.dist-info/LICENSE +21 -0
  68. glam4cm-0.1.0.dist-info/METADATA +86 -0
  69. glam4cm-0.1.0.dist-info/RECORD +72 -0
  70. glam4cm-0.1.0.dist-info/WHEEL +5 -0
  71. glam4cm-0.1.0.dist-info/entry_points.txt +2 -0
  72. glam4cm-0.1.0.dist-info/top_level.txt +1 -0
glam4cm/data_loading/metadata.py
@@ -0,0 +1,84 @@
+ class GraphMetadata:
+     """Maps the abstract roles 'label', 'cls' and 'attributes' to concrete
+     attribute names. Subclasses must define self.node, self.edge and
+     self.graph before the properties below are accessed."""
+ 
+     def __init__(self, model_type):
+         self.type = model_type
+ 
+     @property
+     def node_label(self):
+         return self.node.get('label', None)
+ 
+     @property
+     def node_cls(self):
+         return self.node.get('cls', None)
+ 
+     @property
+     def node_attributes(self):
+         return self.node.get('attributes', None)
+ 
+     @property
+     def edge_label(self):
+         return self.edge.get('label', None)
+ 
+     @property
+     def edge_cls(self):
+         return self.edge.get('cls', None)
+ 
+     @property
+     def graph_cls(self):
+         return self.graph.get('cls', None)
+ 
+     @property
+     def graph_label(self):
+         return self.graph.get('label', None)
+ 
+ 
+ class EcoreMetaData(GraphMetadata):
+     def __init__(self):
+         super().__init__('ecore')
+         self.node = {
+             "label": "name",
+             "cls": "abstract",
+             "attributes": "attributes"
+         }
+         self.edge = {
+             "label": "name",
+             "cls": "type"
+         }
+         self.graph = {
+             "label": "text",
+             "cls": "label"
+         }
+ 
+ 
+ class ArchimateMetaData(GraphMetadata):
+     def __init__(self):
+         super().__init__('archimate')
+         self.node = {
+             "label": "name",
+             "cls": ["type", "layer"],
+         }
+         self.edge = {
+             "cls": "type"
+         }
+         self.graph = {
+             "label": "text",
+         }
+ 
+ 
+ class OntoUMLMetaData(GraphMetadata):
+     def __init__(self):
+         super().__init__('ontouml')
+         self.node = {
+             "label": "name",
+             "cls": ["stereotype"],
+             "attributes": "properties"
+         }
+         self.edge = {
+             "cls": "stereotype"
+         }
+         self.graph = {
+             "label": "text",
+         }
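The metadata classes above store only the names of the attributes that play each role; consumers resolve those names against the raw node and edge dicts of a graph. A minimal sketch of that lookup (the sample node dict is hypothetical, not part of the package):

from glam4cm.data_loading.metadata import EcoreMetaData

metadata = EcoreMetaData()
node = {"name": "Order", "abstract": False, "attributes": ["id", "date"]}  # illustrative

label_key = metadata.node_label  # "name": the node attribute holding the label
cls_key = metadata.node_cls      # "abstract": the node attribute holding the class
print(node[label_key], node[cls_key])  # Order False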
glam4cm/data_loading/models_dataset.py
@@ -0,0 +1,361 @@
+ from typing import List
+ import pandas as pd
+ from tqdm.auto import tqdm
+ import pickle
+ from random import shuffle
+ from sklearn.model_selection import StratifiedKFold
+ import json
+ import os
+ import numpy as np
+ 
+ from glam4cm.data_loading.encoding import EncodingDataset
+ from glam4cm.lang2graph.archimate import ArchiMateNxG
+ from glam4cm.lang2graph.ecore import EcoreNxG
+ from glam4cm.lang2graph.common import LangGraph
+ from glam4cm.lang2graph.ontouml import OntoUMLNxG
+ from glam4cm.settings import (
+     datasets_dir,
+     seed,
+     logger,
+ )
+ 
+ 
+ class ModelDataset:
+     def __init__(
+         self,
+         dataset_name: str,
+         dataset_dir=datasets_dir,
+         save_dir='datasets/pickles',
+         min_edges: int = -1,
+         min_enr: float = -1,
+         timeout=-1,
+         preprocess_graph_text: callable = None
+     ):
+         self.name = dataset_name
+         self.dataset_dir = dataset_dir
+         self.save_dir = save_dir
+         os.makedirs(save_dir, exist_ok=True)
+ 
+         self.min_edges = min_edges
+         self.min_enr = min_enr
+         self.timeout = timeout
+         self.preprocess_graph_text = preprocess_graph_text
+ 
+         self.graphs: List[LangGraph] = []
+ 
+     def get_train_test_split(self, train_size=0.8):
+         n = len(self.graphs)
+         train_size = int(n * train_size)
+         idx = list(range(n))
+         shuffle(idx)
+         train_idx = idx[:train_size]
+         test_idx = idx[train_size:]
+         return train_idx, test_idx
+ 
+     def k_fold_split(self, k=10):
+         # The labels passed to split() are dummy zeros, so this is
+         # effectively an unstratified shuffled k-fold over the graphs.
+         kfold = StratifiedKFold(n_splits=k, shuffle=True, random_state=seed)
+         n = len(self.graphs)
+         for train_idx, test_idx in kfold.split(np.zeros(n), np.zeros(n)):
+             yield train_idx, test_idx
+ 
+     @property
+     def data(self):
+         X, y = [], []
+         for g in self.graphs:
+             X.append(g.text)
+             y.append(g.label)
+ 
+         if self.preprocess_graph_text:
+             X = [self.preprocess_graph_text(x) for x in X]
+         return X, y
+ 
+     def __get_lm_data(self, train_idx, test_idx, tokenizer, remove_duplicates=False):
+         X, y = self.data
+         y_enc = {label: i for i, label in enumerate(set(y))}
+         y = [y_enc[label] for label in y]
+         X_train, y_train = [X[i] for i in train_idx], [y[i] for i in train_idx]
+         X_test, y_test = [X[i] for i in test_idx], [y[i] for i in test_idx]
+         train_dataset = EncodingDataset(tokenizer, X_train, y_train, remove_duplicates=remove_duplicates)
+         test_dataset = EncodingDataset(tokenizer, X_test, y_test, remove_duplicates=remove_duplicates)
+         num_classes = len(set(y))
+         return {
+             'train': train_dataset,
+             'test': test_dataset,
+             'num_classes': num_classes
+         }
+ 
+     def get_graph_classification_data(self, tokenizer, remove_duplicates=False):
+         train_idx, test_idx = self.get_train_test_split()
+         return self.__get_lm_data(train_idx, test_idx, tokenizer, remove_duplicates=remove_duplicates)
+ 
+     def get_graph_classification_data_kfold(self, tokenizer, k=10, remove_duplicates=False):
+         for train_idx, test_idx in self.k_fold_split(k=k):
+             yield self.__get_lm_data(train_idx, test_idx, tokenizer, remove_duplicates=remove_duplicates)
+ 
+     def __repr__(self):
+         return f'Dataset({self.name}, graphs={len(self.graphs)})'
+ 
+     def __getitem__(self, key) -> LangGraph:
+         return self.graphs[key]
+ 
+     def __iter__(self):
+         return iter(self.graphs)
+ 
+     def __len__(self):
+         return len(self.graphs)
+ 
+     def save(self):
+         print(f'Saving {self.name} to pickle')
+         with open(os.path.join(self.save_dir, f'{self.name}.pkl'), 'wb') as f:
+             pickle.dump(self.graphs, f)
+         print(f'Saved {self.name} to pickle')
+ 
+     def filter_graphs(self):
+         # Keep only graphs that satisfy the min_edges / min_enr thresholds.
+         graphs = list()
+         for graph in self.graphs:
+             addable = True
+             if self.min_edges > 0 and graph.number_of_edges() < self.min_edges:
+                 addable = False
+             if self.min_enr > 0 and graph.enr < self.min_enr:
+                 addable = False
+ 
+             if addable:
+                 graphs.append(graph)
+ 
+         self.graphs = graphs
+ 
+     def load(self):
+         print(f'Loading {self.name} from pickle')
+         with open(os.path.join(self.save_dir, f'{self.name}.pkl'), 'rb') as f:
+             self.graphs = pickle.load(f)
+ 
+         self.filter_graphs()
+         print(f'Loaded {self.name} with {len(self.graphs)} graphs')
+ 
+     @property
+     def summary(self):
+         num_graphs = len(self.graphs)
+         num_edges = sum(g.number_of_edges() for g in self.graphs)
+         num_nodes = sum(g.number_of_nodes() for g in self.graphs)
+         average_nodes = num_nodes / num_graphs
+         average_edges = num_edges / num_graphs
+         average_n2e_ratio = np.mean([g.number_of_nodes() / g.number_of_edges() for g in self.graphs])
+         return {
+             'num_graphs': num_graphs,
+             'num_edges': num_edges,
+             'num_nodes': num_nodes,
+             'average_nodes': f"{average_nodes:.2f}",
+             'average_edges': f"{average_edges:.2f}",
+             'average_n2e_ratio': f"{average_n2e_ratio:.2f}"
+         }
+ 
+ 
+ class EcoreDataset(ModelDataset):
+     def __init__(
+         self,
+         dataset_name: str,
+         dataset_dir=datasets_dir,
+         save_dir='datasets/pickles',
+         reload=False,
+         remove_duplicates=False,
+         min_edges: int = -1,
+         min_enr: float = -1,
+         preprocess_graph_text: callable = None
+     ):
+         super().__init__(
+             dataset_name,
+             dataset_dir=dataset_dir,
+             save_dir=save_dir,
+             min_edges=min_edges,
+             min_enr=min_enr,
+             preprocess_graph_text=preprocess_graph_text
+         )
+         os.makedirs(save_dir, exist_ok=True)
+ 
+         dataset_exists = os.path.exists(os.path.join(save_dir, f'{dataset_name}.pkl'))
+         if reload or not dataset_exists:
+             self.graphs: List[EcoreNxG] = []
+             data_path = os.path.join(dataset_dir, dataset_name)
+             for file in os.listdir(data_path):
+                 if file.endswith('.jsonl') and file.startswith('ecore'):
+                     with open(os.path.join(data_path, file)) as f:
+                         json_objects = json.load(f)
+                     for g in tqdm(json_objects, desc=f'Loading {dataset_name.title()}'):
+                         # Duplicates are filtered at load time via the
+                         # precomputed 'is_duplicated' flag.
+                         if remove_duplicates and g['is_duplicated']:
+                             continue
+                         nxg = EcoreNxG(g)
+                         self.graphs.append(nxg)
+ 
+             print(f'Loaded {self.name} with {len(self.graphs)} graphs in total')
+             print("Filtering...")
+             self.save()
+             self.filter_graphs()
+         else:
+             self.load()
+ 
+         logger.info(f'Loaded {self.name} with {len(self.graphs)} graphs')
+         print(f'Loaded {self.name} with {len(self.graphs)} graphs')
+ 
+     def dedup(self) -> List[EcoreNxG]:
+         logger.info(f'Deduplicating {self.name}')
+         return [g for g in self.graphs if not g.is_duplicated]
+ 
+     def __repr__(self):
+         return f"EcoreDataset({self.name}, graphs={len(self.graphs)})"
+ 
+ 
+ class ArchiMateDataset(ModelDataset):
+     def __init__(
+         self,
+         dataset_name: str,
+         dataset_dir=datasets_dir,
+         save_dir='datasets/pickles',
+         reload=False,
+         remove_duplicates=False,
+         min_edges: int = -1,
+         min_enr: float = -1,
+         timeout=-1,
+         language=None,
+         preprocess_graph_text: callable = None
+     ):
+         super().__init__(
+             dataset_name,
+             dataset_dir=dataset_dir,
+             save_dir=save_dir,
+             min_edges=min_edges,
+             min_enr=min_enr,
+             timeout=timeout,
+             preprocess_graph_text=preprocess_graph_text
+         )
+         os.makedirs(save_dir, exist_ok=True)
+ 
+         dataset_exists = os.path.exists(os.path.join(save_dir, f'{dataset_name}.pkl'))
+         if reload or not dataset_exists:
+             self.graphs: List[ArchiMateNxG] = []
+             data_path = os.path.join(dataset_dir, dataset_name, 'processed-models')
+             if language:
+                 # Restrict to models whose IDs appear in the per-language metadata CSV.
+                 df = pd.read_csv(os.path.join(dataset_dir, dataset_name, f'{language}-metadata.csv'))
+                 model_dirs = df['ID'].to_list()
+             else:
+                 model_dirs = os.listdir(data_path)
+ 
+             for model_dir in tqdm(model_dirs, desc=f'Loading {dataset_name.title()}'):
+                 model_dir = os.path.join(data_path, model_dir)
+                 if os.path.isdir(model_dir):
+                     model_file = os.path.join(model_dir, 'model.json')
+                     if os.path.exists(model_file):
+                         with open(model_file) as f:
+                             model = json.load(f)
+                         nxg = ArchiMateNxG(
+                             model,
+                             path=model_file,
+                             timeout=timeout
+                         )
+                         if nxg.number_of_edges() < 1:
+                             continue
+                         self.graphs.append(nxg)
+ 
+             self.filter_graphs()
+             self.save()
+         else:
+             self.load()
+ 
+         if remove_duplicates:
+             self.graphs = self.dedup()
+ 
+         print(f'Loaded {self.name} with {len(self.graphs)} graphs')
+ 
+     def dedup(self) -> List[ArchiMateNxG]:
+         # Two graphs count as duplicates if their edge lists (with
+         # attributes) serialize to the same string.
+         return list({str(g.edges(data=True)): g for g in self.graphs}.values())
+ 
+     def __repr__(self):
+         return f"ArchiMateDataset({self.name}, graphs={len(self.graphs)})"
+ 
+ 
+ class OntoUMLDataset(ModelDataset):
+     def __init__(
+         self,
+         dataset_name: str,
+         dataset_dir=datasets_dir,
+         save_dir='datasets/pickles',
+         reload=False,
+         remove_duplicates=False,
+         min_edges: int = -1,
+         min_enr: float = -1,
+         timeout=-1,
+         preprocess_graph_text: callable = None
+     ):
+         super().__init__(
+             dataset_name,
+             dataset_dir=dataset_dir,
+             save_dir=save_dir,
+             min_edges=min_edges,
+             min_enr=min_enr,
+             timeout=timeout,
+             preprocess_graph_text=preprocess_graph_text
+         )
+         os.makedirs(save_dir, exist_ok=True)
+ 
+         dataset_exists = os.path.exists(os.path.join(save_dir, f'{dataset_name}.pkl'))
+         if reload or not dataset_exists:
+             self.graphs: List[OntoUMLNxG] = []
+             data_path = os.path.join(dataset_dir, dataset_name, 'models')
+             model_dirs = os.listdir(data_path)
+ 
+             for model_dir in tqdm(model_dirs, desc=f'Loading {dataset_name.title()}'):
+                 model_dir = os.path.join(data_path, model_dir)
+                 if os.path.isdir(model_dir):
+                     model_file = os.path.join(model_dir, 'ontology.json')
+                     if os.path.exists(model_file):
+                         with open(model_file, encoding='iso-8859-1') as f:
+                             model = json.load(f)
+                         try:
+                             nxg = OntoUMLNxG(model)
+                             if nxg.number_of_edges() < 1:
+                                 continue
+                             self.graphs.append(nxg)
+                         except Exception as e:
+                             print(f"Error in {model_file}: {e}")
+ 
+             self.filter_graphs()
+             self.save()
+         else:
+             self.load()
+ 
+         if remove_duplicates:
+             self.graphs = self.dedup()
+ 
+         print(f'Loaded {self.name} with {len(self.graphs)} graphs')
+ 
+     def dedup(self) -> List[OntoUMLNxG]:
+         return list({str(g.edges(data=True)): g for g in self.graphs}.values())
+ 
+     def __repr__(self):
+         return f"OntoUMLDataset({self.name}, graphs={len(self.graphs)})"
+ 
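Each dataset class parses raw models into NetworkX-based graphs, caches the result as a pickle under save_dir, and reloads from that pickle on later runs. A usage sketch, assuming a directory of Ecore .jsonl files exists under datasets_dir ('modelset' and the tokenizer checkpoint are placeholder names):

from transformers import AutoTokenizer
from glam4cm.data_loading.models_dataset import EcoreDataset

# First run parses datasets_dir/modelset; later runs load datasets/pickles/modelset.pkl.
dataset = EcoreDataset('modelset', min_edges=5, remove_duplicates=True)
print(dataset.summary)

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
splits = dataset.get_graph_classification_data(tokenizer)
print(splits['num_classes'])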
glam4cm/data_loading/utils.py
@@ -0,0 +1,20 @@
+ import numpy as np
+ 
+ 
+ def oversample_dataset(dataset, oversampling_ratio=0.7):
+     """
+     Oversample the classes that occur less frequently in the dataset.
+     The occurrence of each class is counted, and each class is oversampled
+     by `oversampling_ratio` (default 70%) of the difference between the
+     count of the most common class and its own count.
+     """
+     class_occurrences = dataset[:]['labels'].numpy()
+     unique_classes, counts = np.unique(class_occurrences, return_counts=True)
+     max_count = counts.max()
+     indices_with_oversamples = []
+     for class_idx, count in zip(unique_classes, counts):
+         class_indices = np.where(class_occurrences == class_idx)[0]
+         indices_with_oversamples.extend(class_indices)
+         oversample_count = int(oversampling_ratio * (max_count - count))
+         indices_with_oversamples.extend(np.random.choice(class_indices, oversample_count))
+ 
+     return indices_with_oversamples
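For intuition: with labels [0]*10 + [1]*4 and the default ratio of 0.7, class 1 gets int(0.7 * (10 - 4)) = 4 extra randomly drawn indices, so the returned list holds 18 entries. A self-contained sketch; the ToyDataset stand-in below mimics the dataset[:]['labels'] access, whereas the real callers pass an EncodingDataset:

import torch
from glam4cm.data_loading.utils import oversample_dataset

class ToyDataset:
    # Hypothetical stand-in exposing dataset[:]['labels'] as a torch tensor.
    def __init__(self, labels):
        self.labels = torch.tensor(labels)

    def __getitem__(self, idx):
        return {'labels': self.labels[idx]}

toy = ToyDataset([0] * 10 + [1] * 4)
indices = oversample_dataset(toy)
print(len(indices))  # 18 = 14 originals + 4 oversamples of class 1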
File without changes: glam4cm/downstream_tasks/__init__.py
glam4cm/downstream_tasks/bert_edge_classification.py
@@ -0,0 +1,144 @@
+ import os
+ from transformers import TrainingArguments, Trainer
+ from glam4cm.data_loading.graph_dataset import GraphEdgeDataset
+ from glam4cm.data_loading.utils import oversample_dataset
+ from glam4cm.settings import LP_TASK_EDGE_CLS
+ from glam4cm.downstream_tasks.common_args import get_bert_args_parser, get_common_args_parser, get_config_params
+ from glam4cm.models.hf import get_model
+ from glam4cm.downstream_tasks.utils import get_models_dataset
+ 
+ from sklearn.metrics import (
+     accuracy_score,
+     f1_score,
+     precision_score,
+     recall_score,
+     balanced_accuracy_score
+ )
+ 
+ from glam4cm.tokenization.utils import get_tokenizer
+ from glam4cm.utils import merge_argument_parsers, set_seed
+ 
+ 
+ def compute_metrics(pred):
+     labels = pred.label_ids
+     preds = pred.predictions.argmax(-1)
+     accuracy = accuracy_score(labels, preds)
+     f1_macro = f1_score(labels, preds, average='macro')
+     precision = precision_score(labels, preds, average='macro')
+     recall = recall_score(labels, preds, average='macro')
+     balanced_acc = balanced_accuracy_score(labels, preds)
+ 
+     return {
+         'balanced_accuracy': balanced_acc,
+         'accuracy': accuracy,
+         'f1_macro': f1_macro,
+         'precision': precision,
+         'recall': recall
+     }
+ 
+ 
+ def get_parser():
+     common_parser = get_common_args_parser()
+     bert_parser = get_bert_args_parser()
+     parser = merge_argument_parsers(common_parser, bert_parser)
+ 
+     parser.add_argument('--oversampling_ratio', type=float, default=-1)
+ 
+     return parser
+ 
+ 
+ def run(args):
+     set_seed(args.seed)
+ 
+     config_params = dict(
+         min_enr=args.min_enr,
+         min_edges=args.min_edges,
+         remove_duplicates=args.remove_duplicates,
+         language=args.language,
+         reload=args.reload
+     )
+     dataset_name = args.dataset
+ 
+     dataset = get_models_dataset(dataset_name, **config_params)
+     print("Loaded dataset")
+ 
+     graph_data_params = get_config_params(args)
+     graph_data_params = {**graph_data_params, 'task_type': LP_TASK_EDGE_CLS}
+ 
+     print("Loading graph dataset")
+     graph_dataset = GraphEdgeDataset(dataset, **graph_data_params)
+     print("Loaded graph dataset")
+ 
+     assert hasattr(graph_dataset, f'num_edges_{args.edge_cls_label}'), \
+         f"Dataset does not have num_edges_{args.edge_cls_label} attribute"
+     num_labels = getattr(graph_dataset, f"num_edges_{args.edge_cls_label}")
+ 
+     model_name = args.model_name
+     tokenizer = get_tokenizer(model_name, args.use_special_tokens)
+ 
+     print("Getting Edge Classification data")
+     bert_dataset = graph_dataset.get_link_prediction_lm_data(tokenizer=tokenizer)
+ 
+     if args.oversampling_ratio != -1:
+         ind_w_oversamples = oversample_dataset(
+             bert_dataset['train'],
+             oversampling_ratio=args.oversampling_ratio
+         )
+         bert_dataset['train'].inputs = bert_dataset['train'][ind_w_oversamples]
+ 
+     print("Training model")
+     print(f'Number of labels: {num_labels}')
+ 
+     model = get_model(args.ckpt if args.ckpt else model_name, num_labels, len(tokenizer))
+ 
+     if args.freeze_pretrained_weights:
+         # Train only the classification head; keep the encoder frozen.
+         for param in model.base_model.parameters():
+             param.requires_grad = False
+ 
+     output_dir = os.path.join(
+         'results',
+         dataset_name,
+         'edge_cls',
+         f'{args.edge_cls_label}',
+         f"{graph_dataset.config_hash}",
+     )
+ 
+     logs_dir = os.path.join(
+         'logs',
+         dataset_name,
+         'edge_cls',
+         f'{args.edge_cls_label}',
+         f"{graph_dataset.config_hash}",
+     )
+ 
+     training_args = TrainingArguments(
+         output_dir=output_dir,
+         num_train_epochs=args.num_epochs,
+         per_device_train_batch_size=args.train_batch_size,
+         per_device_eval_batch_size=args.eval_batch_size,
+         weight_decay=0.01,
+         logging_dir=logs_dir,
+         logging_steps=args.num_log_steps,
+         eval_strategy='steps',
+         eval_steps=args.num_eval_steps,
+         save_steps=args.num_save_steps,
+         save_total_limit=2,
+         load_best_model_at_end=True,
+         fp16=True,
+     )
+ 
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=bert_dataset['train'],
+         eval_dataset=bert_dataset['test'],
+         compute_metrics=compute_metrics
+     )
+ 
+     trainer.train()
+     print(trainer.evaluate())
+     trainer.save_model()
+ 
+ 
+ if __name__ == '__main__':
+     args = get_parser().parse_args()
+     run(args)
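Since compute_metrics reads only label_ids and predictions from its argument, it can be sanity-checked without running a Trainer; the SimpleNamespace below stands in for the EvalPrediction object transformers would pass:

import numpy as np
from types import SimpleNamespace
from glam4cm.downstream_tasks.bert_edge_classification import compute_metrics

fake_pred = SimpleNamespace(
    label_ids=np.array([0, 1, 1, 0]),
    predictions=np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.7, 0.3]]),
)
print(compute_metrics(fake_pred))
# argmax yields preds [0, 1, 0, 0] against labels [0, 1, 1, 0], so accuracy = 0.75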