glam4cm 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. glam4cm/__init__.py +9 -0
  2. glam4cm/data_loading/__init__.py +0 -0
  3. glam4cm/data_loading/data.py +631 -0
  4. glam4cm/data_loading/encoding.py +76 -0
  5. glam4cm/data_loading/graph_dataset.py +940 -0
  6. glam4cm/data_loading/metadata.py +84 -0
  7. glam4cm/data_loading/models_dataset.py +361 -0
  8. glam4cm/data_loading/utils.py +20 -0
  9. glam4cm/downstream_tasks/__init__.py +0 -0
  10. glam4cm/downstream_tasks/bert_edge_classification.py +144 -0
  11. glam4cm/downstream_tasks/bert_graph_classification.py +137 -0
  12. glam4cm/downstream_tasks/bert_graph_classification_comp.py +156 -0
  13. glam4cm/downstream_tasks/bert_link_prediction.py +145 -0
  14. glam4cm/downstream_tasks/bert_node_classification.py +164 -0
  15. glam4cm/downstream_tasks/cm_gpt_edge_classification.py +73 -0
  16. glam4cm/downstream_tasks/cm_gpt_node_classification.py +76 -0
  17. glam4cm/downstream_tasks/cm_gpt_pretraining.py +64 -0
  18. glam4cm/downstream_tasks/common_args.py +160 -0
  19. glam4cm/downstream_tasks/create_dataset.py +51 -0
  20. glam4cm/downstream_tasks/gnn_edge_classification.py +106 -0
  21. glam4cm/downstream_tasks/gnn_graph_cls.py +101 -0
  22. glam4cm/downstream_tasks/gnn_link_prediction.py +109 -0
  23. glam4cm/downstream_tasks/gnn_node_classification.py +103 -0
  24. glam4cm/downstream_tasks/tf_idf_text_classification.py +22 -0
  25. glam4cm/downstream_tasks/utils.py +35 -0
  26. glam4cm/downstream_tasks/word2vec_text_classification.py +108 -0
  27. glam4cm/embeddings/__init__.py +0 -0
  28. glam4cm/embeddings/bert.py +72 -0
  29. glam4cm/embeddings/common.py +43 -0
  30. glam4cm/embeddings/fasttext.py +0 -0
  31. glam4cm/embeddings/tfidf.py +25 -0
  32. glam4cm/embeddings/w2v.py +41 -0
  33. glam4cm/encoding/__init__.py +0 -0
  34. glam4cm/encoding/common.py +0 -0
  35. glam4cm/encoding/encoders.py +100 -0
  36. glam4cm/graph2str/__init__.py +0 -0
  37. glam4cm/graph2str/common.py +34 -0
  38. glam4cm/graph2str/constants.py +15 -0
  39. glam4cm/graph2str/ontouml.py +141 -0
  40. glam4cm/graph2str/uml.py +0 -0
  41. glam4cm/lang2graph/__init__.py +0 -0
  42. glam4cm/lang2graph/archimate.py +31 -0
  43. glam4cm/lang2graph/bpmn.py +0 -0
  44. glam4cm/lang2graph/common.py +416 -0
  45. glam4cm/lang2graph/ecore.py +221 -0
  46. glam4cm/lang2graph/ontouml.py +169 -0
  47. glam4cm/lang2graph/utils.py +80 -0
  48. glam4cm/models/cmgpt.py +352 -0
  49. glam4cm/models/gnn_layers.py +273 -0
  50. glam4cm/models/hf.py +10 -0
  51. glam4cm/run.py +99 -0
  52. glam4cm/run_configs.py +126 -0
  53. glam4cm/settings.py +54 -0
  54. glam4cm/tokenization/__init__.py +0 -0
  55. glam4cm/tokenization/special_tokens.py +4 -0
  56. glam4cm/tokenization/utils.py +37 -0
  57. glam4cm/trainers/__init__.py +0 -0
  58. glam4cm/trainers/bert_classifier.py +105 -0
  59. glam4cm/trainers/cm_gpt_trainer.py +153 -0
  60. glam4cm/trainers/gnn_edge_classifier.py +126 -0
  61. glam4cm/trainers/gnn_graph_classifier.py +123 -0
  62. glam4cm/trainers/gnn_link_predictor.py +144 -0
  63. glam4cm/trainers/gnn_node_classifier.py +135 -0
  64. glam4cm/trainers/gnn_trainer.py +129 -0
  65. glam4cm/trainers/metrics.py +55 -0
  66. glam4cm/utils.py +194 -0
  67. glam4cm-0.1.0.dist-info/LICENSE +21 -0
  68. glam4cm-0.1.0.dist-info/METADATA +86 -0
  69. glam4cm-0.1.0.dist-info/RECORD +72 -0
  70. glam4cm-0.1.0.dist-info/WHEEL +5 -0
  71. glam4cm-0.1.0.dist-info/entry_points.txt +2 -0
  72. glam4cm-0.1.0.dist-info/top_level.txt +1 -0
glam4cm/downstream_tasks/bert_graph_classification.py
@@ -0,0 +1,137 @@
+ import os
+ from sklearn.metrics import (
+     accuracy_score,
+     balanced_accuracy_score,
+     f1_score,
+     precision_score,
+     recall_score
+ )
+ from transformers import (
+     Trainer,
+     TrainingArguments
+ )
+
+ from glam4cm.data_loading.graph_dataset import GraphNodeDataset
+ from glam4cm.models.hf import get_model
+ from glam4cm.downstream_tasks.common_args import get_bert_args_parser, get_common_args_parser, get_config_params
+ from glam4cm.downstream_tasks.utils import get_models_dataset
+ from glam4cm.tokenization.utils import get_tokenizer
+ from glam4cm.utils import merge_argument_parsers, set_seed
+
+
+ def compute_metrics(pred):
+     labels = pred.label_ids
+     preds = pred.predictions.argmax(-1)
+     return {
+         'balanced_accuracy': balanced_accuracy_score(labels, preds),
+         'accuracy': accuracy_score(labels, preds),
+         'f1_macro': f1_score(labels, preds, average='macro'),
+         # Report a true precision score; previously this key duplicated accuracy.
+         'precision': precision_score(labels, preds, average='macro'),
+         'recall': recall_score(labels, preds, average='macro')
+     }
+
+
+ def get_parser():
+     common_parser = get_common_args_parser()
+     bert_parser = get_bert_args_parser()
+     parser = merge_argument_parsers(common_parser, bert_parser)
+
+     parser.add_argument('--cls_label', type=str, default='label')
+     parser.add_argument('--remove_duplicate_graphs', action='store_true')
+     return parser
+
+
+ def run(args):
+     set_seed(args.seed)
+
+     config_params = dict(
+         min_enr=args.min_enr,
+         min_edges=args.min_edges,
+         remove_duplicates=args.remove_duplicates,
+         reload=args.reload,
+         language=args.language
+     )
+     dataset_name = args.dataset
+
+     dataset = get_models_dataset(dataset_name, **config_params)
+     graph_data_params = get_config_params(args)
+     print("Loading graph dataset")
+     graph_dataset = GraphNodeDataset(dataset, **graph_data_params)
+     print("Loaded graph dataset")
+
+     model_name = args.model_name
+     tokenizer = get_tokenizer(model_name, args.use_special_tokens)
+
+     fold_id = 0
+     for classification_dataset in graph_dataset.get_kfold_lm_graph_classification_data(
+         tokenizer,
+         remove_duplicates=args.remove_duplicate_graphs
+     ):
+         train_dataset = classification_dataset['train']
+         test_dataset = classification_dataset['test']
+         num_labels = classification_dataset['num_classes']
+
+         print(len(train_dataset), len(test_dataset), num_labels)
+
+         print("Training model")
+         output_dir = os.path.join(
+             'results',
+             dataset_name,
+             'graph_cls_',
+             f"{graph_dataset.config_hash}",
+         )
+
+         logs_dir = os.path.join(
+             'logs',
+             dataset_name,
+             'graph_cls_',
+             f"{graph_dataset.config_hash}"
+         )
+
+         model = get_model(args.ckpt if args.ckpt else model_name, num_labels, len(tokenizer))
+
+         if args.freeze_pretrained_weights:
+             # Freeze the pretrained encoder; only the classification head is trained.
+             for param in model.base_model.parameters():
+                 param.requires_grad = False
+
+         # Training arguments
+         training_args = TrainingArguments(
+             output_dir=output_dir,
+             num_train_epochs=args.num_epochs,
+             eval_strategy="steps",
+             per_device_train_batch_size=args.train_batch_size,
+             per_device_eval_batch_size=args.eval_batch_size,
+             warmup_steps=200,
+             weight_decay=0.01,
+             learning_rate=5e-5,
+             logging_dir=logs_dir,
+             logging_steps=args.num_log_steps,
+             eval_steps=args.num_eval_steps,
+             save_steps=args.num_save_steps,
+             save_total_limit=2,
+             load_best_model_at_end=True,
+             fp16=True
+         )
+
+         # Trainer
+         trainer = Trainer(
+             model=model,
+             args=training_args,
+             train_dataset=train_dataset,
+             eval_dataset=test_dataset,
+             compute_metrics=compute_metrics
+         )
+
+         # Train the model
+         trainer.train()
+         results = trainer.evaluate()
+         print(results)
+
+         fold_id += 1
+         # NOTE: only the first fold is run; remove this break to train on all k folds.
+         break
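The `compute_metrics` callback above is invoked by the Hugging Face `Trainer` with an `EvalPrediction` object carrying `predictions` (logits) and `label_ids`. A minimal sketch of how it behaves on toy data, assuming the `compute_metrics` from this diff is in scope (`SimpleNamespace` is just an illustrative stand-in for `EvalPrediction`):

import numpy as np
from types import SimpleNamespace

# Toy logits for 4 samples over 3 classes, plus gold labels.
pred = SimpleNamespace(
    predictions=np.array([[2.0, 0.1, 0.3],
                          [0.2, 1.5, 0.1],
                          [0.1, 0.2, 3.0],
                          [1.2, 0.8, 0.1]]),
    label_ids=np.array([0, 1, 2, 1])
)

# argmax over the class axis gives [0, 1, 2, 0] against gold [0, 1, 2, 1].
print(compute_metrics(pred))   # accuracy 0.75, balanced_accuracy ~0.83, ...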
glam4cm/downstream_tasks/bert_graph_classification_comp.py
@@ -0,0 +1,156 @@
+ import os
+ import json
+ from argparse import ArgumentParser
+ from random import shuffle
+
+ import numpy as np
+ from sklearn.model_selection import StratifiedKFold
+ from sklearn.metrics import (
+     accuracy_score,
+     balanced_accuracy_score,
+     f1_score,
+     precision_score,
+     recall_score
+ )
+ from transformers import (
+     AutoTokenizer,
+     Trainer,
+     TrainingArguments
+ )
+
+ from glam4cm.data_loading.encoding import EncodingDataset
+ from glam4cm.models.hf import get_model
+
+
+ def compute_metrics(pred):
+     labels = pred.label_ids
+     preds = np.argmax(pred.predictions, axis=1)
+     return {
+         'balanced_accuracy': balanced_accuracy_score(labels, preds),
+         'accuracy': accuracy_score(labels, preds),
+         'f1_macro': f1_score(labels, preds, average='macro'),
+         # Report a true precision score; previously this key duplicated accuracy.
+         'precision': precision_score(labels, preds, average='macro'),
+         'recall': recall_score(labels, preds, average='macro')
+     }
+
+
+ def get_parser():
+     parser = ArgumentParser()
+     parser.add_argument('--seed', type=int, default=42)
+     parser.add_argument('--dataset_name', type=str, default='ecore_555')
+     parser.add_argument('--model_name', type=str, default='bert-base-uncased')
+     parser.add_argument('--ckpt', type=str, default=None)
+     parser.add_argument('--max_length', type=int, default=512)
+     parser.add_argument('--k', type=int, default=10)
+
+     parser.add_argument('--num_epochs', type=int, default=10)
+
+     parser.add_argument('--warmup_steps', type=int, default=500)
+     parser.add_argument('--num_log_steps', type=int, default=50)
+     parser.add_argument('--num_eval_steps', type=int, default=50)
+     parser.add_argument('--num_save_steps', type=int, default=50)
+     parser.add_argument('--train_batch_size', type=int, default=2)
+     parser.add_argument('--eval_batch_size', type=int, default=128)
+     parser.add_argument('--lr', type=float, default=1e-5)
+
+     return parser
+
+
+ def run(args):
+     dataset_name = args.dataset_name
+     model_name = args.model_name
+
+     # Filter file names before parsing, so non-JSONL entries are never opened.
+     texts = [
+         (g['txt'], g['labels'])
+         for file_name in os.listdir(f'datasets/{dataset_name}')
+         if 'ecore' in file_name and file_name.endswith('.jsonl')
+         for g in json.load(open(f'datasets/{dataset_name}/{file_name}'))
+     ]
+     shuffle(texts)
+     labels = [y for _, y in texts]
+     y_map = {label: i for i, label in enumerate(set(labels))}
+     y = [y_map[label] for label in labels]
+     n = len(texts)
+
+     texts = [text for text, _ in texts]
+
+     num_labels = len(y_map)
+
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     k = args.k
+     kfold = StratifiedKFold(n_splits=k, shuffle=True, random_state=args.seed)
+
+     # Stratify on the real labels; splitting on np.zeros(n) would ignore class balance.
+     for fold_id, (train_idx, test_idx) in enumerate(kfold.split(np.zeros(n), y)):
+         print(f'Fold {fold_id + 1}/{k}')
+
+         train_texts = [texts[i] for i in train_idx]
+         test_texts = [texts[i] for i in test_idx]
+         train_y = [y[i] for i in train_idx]
+         test_y = [y[i] for i in test_idx]
+
+         print(f'Train: {len(train_texts)}, Test: {len(test_texts)}', num_labels)
+
+         train_dataset = EncodingDataset(tokenizer, train_texts, train_y)
+         test_dataset = EncodingDataset(tokenizer, test_texts, test_y)
+
+         model = get_model(args.ckpt if args.ckpt else model_name, num_labels, len(tokenizer))
+
+         print("Training model")
+         output_dir = os.path.join(
+             'results',
+             dataset_name,
+             'graph_cls_comp',
+         )
+
+         logs_dir = os.path.join(
+             'logs',
+             dataset_name,
+             'graph_cls_comp',
+         )
+
+         print("Running epochs: ", args.num_epochs)
+
+         # Training arguments; step counts, warmup, and learning rate come from the
+         # parser instead of being hardcoded, and save_steps is aligned with
+         # eval_steps so that load_best_model_at_end works.
+         training_args = TrainingArguments(
+             output_dir=output_dir,
+             num_train_epochs=args.num_epochs,
+             eval_strategy="steps",
+             per_device_train_batch_size=args.train_batch_size,
+             per_device_eval_batch_size=args.eval_batch_size,
+             warmup_steps=args.warmup_steps,
+             weight_decay=0.01,
+             learning_rate=args.lr,
+             logging_dir=logs_dir,
+             logging_steps=args.num_log_steps,
+             eval_steps=args.num_eval_steps,
+             save_steps=args.num_save_steps,
+             save_total_limit=2,
+             load_best_model_at_end=True,
+             fp16=True
+         )
+
+         # Trainer
+         trainer = Trainer(
+             model=model,
+             args=training_args,
+             train_dataset=train_dataset,
+             eval_dataset=test_dataset,
+             compute_metrics=compute_metrics
+         )
+
+         # Train the model
+         trainer.train()
+         results = trainer.evaluate()
+         print(results)
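One detail worth flagging in the fold loop above: `StratifiedKFold.split(X, y)` stratifies on its second argument, so it must receive the real label vector (passing `np.zeros(n)` for both arguments would silently reduce it to an unstratified split). A minimal, self-contained sketch of the pattern:

import numpy as np
from sklearn.model_selection import StratifiedKFold

y = np.array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2])   # imbalanced toy labels
kfold = StratifiedKFold(n_splits=2, shuffle=True, random_state=42)

# Stratification uses the second argument; X can be a placeholder.
for fold, (train_idx, test_idx) in enumerate(kfold.split(np.zeros(len(y)), y)):
    # Each fold preserves the class ratios of y.
    print(fold, np.bincount(y[train_idx]), np.bincount(y[test_idx]))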
glam4cm/downstream_tasks/bert_link_prediction.py
@@ -0,0 +1,145 @@
+ import os
+ from collections import Counter
+
+ from sklearn.metrics import (
+     balanced_accuracy_score,
+     f1_score,
+     precision_score,
+     recall_score
+ )
+ from transformers import TrainingArguments, Trainer
+
+ from glam4cm.data_loading.graph_dataset import GraphEdgeDataset
+ from glam4cm.models.hf import get_model
+ from glam4cm.settings import LP_TASK_LINK_PRED
+ from glam4cm.downstream_tasks.common_args import get_bert_args_parser, get_common_args_parser, get_config_params
+ from glam4cm.downstream_tasks.utils import get_models_dataset
+ from glam4cm.tokenization.special_tokens import *
+ from glam4cm.tokenization.utils import get_tokenizer
+ from glam4cm.utils import merge_argument_parsers, set_seed
+
+
+ def compute_metrics(pred):
+     labels = pred.label_ids
+     preds = pred.predictions.argmax(-1)
+     print(Counter(preds), Counter(labels))
+     acc = (preds == labels).mean()
+     # Pass explicit averaging modes; without them both calls computed the same binary F1.
+     f1_macro = f1_score(labels, preds, average='macro')
+     f1_micro = f1_score(labels, preds, average='micro')
+     precision = precision_score(labels, preds)
+     balanced_accuracy = balanced_accuracy_score(labels, preds)
+     recall = recall_score(labels, preds)
+
+     return {
+         'accuracy': acc,
+         'f1_macro': f1_macro,
+         'f1_micro': f1_micro,
+         'precision': precision,
+         'recall': recall,
+         'balanced_accuracy': balanced_accuracy
+     }
+
+
+ def get_parser():
+     common_parser = get_common_args_parser()
+     bert_parser = get_bert_args_parser()
+     parser = merge_argument_parsers(common_parser, bert_parser)
+     return parser
+
+
+ def run(args):
+     set_seed(args.seed)
+
+     config_params = dict(
+         min_enr=args.min_enr,
+         min_edges=args.min_edges,
+         remove_duplicates=args.remove_duplicates,
+         reload=args.reload,
+         language=args.language
+     )
+     dataset_name = args.dataset
+     dataset = get_models_dataset(dataset_name, **config_params)
+
+     print("Loaded dataset")
+
+     graph_data_params = get_config_params(args)
+
+     print("Loading graph dataset")
+     # Pass the task exactly once and unpack the parameters as keywords, matching
+     # the other scripts; setting 'task' both in graph_data_params and as a
+     # keyword here would raise a duplicate-argument error.
+     graph_dataset = GraphEdgeDataset(
+         dataset,
+         **graph_data_params,
+         add_negative_train_samples=args.add_negative_train_samples,
+         neg_sampling_ratio=args.neg_sampling_ratio,
+         task=LP_TASK_LINK_PRED
+     )
+     print("Loaded graph dataset")
+
+     model_name = args.model_name
+     tokenizer = get_tokenizer(model_name, args.use_special_tokens)
+
+     print("Getting link prediction data")
+     bert_dataset = graph_dataset.get_link_prediction_lm_data(
+         tokenizer=tokenizer,
+         task_type=LP_TASK_LINK_PRED
+     )
+
+     print("Training model")
+     model = get_model(args.ckpt if args.ckpt else model_name, num_labels=2, len_tokenizer=len(tokenizer))
+
+     if args.freeze_pretrained_weights:
+         # Freeze the pretrained encoder; only the classification head is trained.
+         for param in model.base_model.parameters():
+             param.requires_grad = False
+
+     output_dir = os.path.join(
+         'results',
+         dataset_name,
+         'lp',
+         f"{graph_dataset.config_hash}",
+     )
+
+     logs_dir = os.path.join(
+         'logs',
+         dataset_name,
+         'lp',
+         f"{graph_dataset.config_hash}",
+     )
+
+     training_args = TrainingArguments(
+         output_dir=output_dir,
+         num_train_epochs=args.num_epochs,
+         per_device_train_batch_size=args.train_batch_size,
+         per_device_eval_batch_size=args.eval_batch_size,
+         weight_decay=0.01,
+         logging_dir=logs_dir,
+         logging_steps=200,
+         eval_strategy='steps',
+         eval_steps=200,
+         save_steps=200,
+         save_total_limit=2,
+         load_best_model_at_end=True,
+         fp16=True,
+     )
+
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=bert_dataset['train'],
+         eval_dataset=bert_dataset['test'],
+         compute_metrics=compute_metrics
+     )
+
+     trainer.train()
+     print(trainer.evaluate())
+     trainer.save_model()
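The link-prediction script treats each candidate edge as binary classification (`num_labels=2`) and optionally freezes the pretrained encoder so that only the classification head learns. A minimal, self-contained sketch of that freezing pattern with a stock Hugging Face model (the model name here is an example; glam4cm's own `get_model` wrapper is internal):

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=2  # binary: edge exists / does not exist
)

# Freeze the encoder; only the classification head stays trainable.
for param in model.base_model.parameters():
    param.requires_grad = False

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable params: {trainable} / {total}")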
glam4cm/downstream_tasks/bert_node_classification.py
@@ -0,0 +1,164 @@
+ import os
+
+ from sklearn.metrics import (
+     accuracy_score,
+     balanced_accuracy_score,
+     f1_score,
+     precision_score,
+     recall_score
+ )
+ from transformers import TrainingArguments, Trainer
+
+ from glam4cm.data_loading.graph_dataset import GraphNodeDataset
+ from glam4cm.data_loading.utils import oversample_dataset
+ from glam4cm.downstream_tasks.common_args import get_bert_args_parser, get_common_args_parser, get_config_params
+ from glam4cm.downstream_tasks.utils import get_models_dataset
+ from glam4cm.models.hf import get_model
+ from glam4cm.tokenization.special_tokens import *
+ from glam4cm.tokenization.utils import get_tokenizer
+ from glam4cm.utils import merge_argument_parsers, set_seed
+
+
+ def compute_metrics(pred):
+     labels = pred.label_ids
+     preds = pred.predictions.argmax(-1)
+     return {
+         'balanced_accuracy': balanced_accuracy_score(labels, preds),
+         'accuracy': accuracy_score(labels, preds),
+         'f1_macro': f1_score(labels, preds, average='macro'),
+         # Report a true precision score; previously this key duplicated accuracy.
+         'precision': precision_score(labels, preds, average='macro'),
+         'recall': recall_score(labels, preds, average='macro')
+     }
+
+
+ def get_num_labels(dataset):
+     train_labels = dataset['train'][:]['labels'].unique().tolist()
+     test_labels = dataset['test'][:]['labels'].unique().tolist()
+     return len(set(train_labels + test_labels))
+
+
+ def get_parser():
+     common_parser = get_common_args_parser()
+     bert_parser = get_bert_args_parser()
+     parser = merge_argument_parsers(common_parser, bert_parser)
+
+     parser.add_argument('--oversampling_ratio', type=float, default=-1)
+
+     return parser
+
+
+ def run(args):
+     set_seed(args.seed)
+
+     config_params = dict(
+         min_enr=args.min_enr,
+         min_edges=args.min_edges,
+         remove_duplicates=args.remove_duplicates,
+         reload=args.reload,
+         language=args.language
+     )
+     dataset_name = args.dataset
+     distance = args.distance
+     dataset = get_models_dataset(dataset_name, **config_params)
+
+     print("Loaded dataset")
+
+     graph_data_params = get_config_params(args)
+     print("Loading graph dataset")
+     graph_dataset = GraphNodeDataset(dataset, **graph_data_params)
+     print("Loaded graph dataset")
+
+     assert hasattr(graph_dataset, f'num_nodes_{args.node_cls_label}'), f"Dataset does not have node_{args.node_cls_label} attribute"
+     num_labels = getattr(graph_dataset, f"num_nodes_{args.node_cls_label}")
+
+     model_name = args.model_name
+     tokenizer = get_tokenizer(model_name, use_special_tokens=args.use_special_tokens)
+
+     print("Getting node classification data")
+     bert_dataset = graph_dataset.get_node_classification_lm_data(
+         label=args.node_cls_label,
+         tokenizer=tokenizer,
+         distance=distance,
+     )
+
+     if args.oversampling_ratio != -1:
+         ind_w_oversamples = oversample_dataset(bert_dataset['train'])
+         bert_dataset['train'].inputs = bert_dataset['train'][ind_w_oversamples]
+
+     # Use the label count derived from the dataset; it was previously hardcoded to 2.
+     model = get_model(
+         args.ckpt if args.ckpt else model_name,
+         num_labels=num_labels,
+         len_tokenizer=len(tokenizer)
+     )
+
+     if args.freeze_pretrained_weights:
+         # Freeze the pretrained encoder; only the classification head is trained.
+         for param in model.base_model.parameters():
+             param.requires_grad = False
+
+     print("Training model")
+     output_dir = os.path.join(
+         'results',
+         dataset_name,
+         'node_cls',
+         f'{args.node_cls_label}',
+         f"{graph_dataset.config_hash}",
+     )
+
+     logs_dir = os.path.join(
+         'logs',
+         dataset_name,
+         'node_cls',
+         f'{args.node_cls_label}',
+         f"{graph_dataset.config_hash}",
+     )
+
+     print("Output Dir: ", output_dir)
+     print("Logs Dir: ", logs_dir)
+     print("Len Train Dataset: ", len(bert_dataset['train']))
+     print("Len Test Dataset: ", len(bert_dataset['test']))
+
+     training_args = TrainingArguments(
+         output_dir=output_dir,
+         num_train_epochs=args.num_epochs,
+         per_device_train_batch_size=args.train_batch_size,
+         per_device_eval_batch_size=args.eval_batch_size,
+         weight_decay=0.01,
+         logging_dir=logs_dir,
+         logging_steps=args.num_log_steps,
+         eval_strategy='steps',
+         eval_steps=args.num_eval_steps,
+         save_steps=args.num_save_steps,
+         save_total_limit=2,
+         load_best_model_at_end=True,
+         fp16=True,
+     )
+
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=bert_dataset['train'],
+         eval_dataset=bert_dataset['test'],
+         compute_metrics=compute_metrics
+     )
+
+     trainer.train()
+     results = trainer.evaluate()
+     print(results)
+
+     trainer.save_model()
+
+
+ if __name__ == '__main__':
+     # get_parser() returns a parser; parse_args() produces the namespace run() expects.
+     args = get_parser().parse_args()
+     run(args)
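`oversample_dataset` is internal to glam4cm and its implementation is not part of this diff; the sketch below is a hypothetical stand-in showing the general idea of index-level oversampling, which the script then uses to rebuild its training inputs:

import numpy as np
from collections import Counter

def oversample_indices(labels: np.ndarray) -> np.ndarray:
    """Return indices that repeat minority classes up to the majority count."""
    counts = Counter(labels.tolist())
    target = max(counts.values())
    indices = []
    for cls in counts:
        cls_idx = np.flatnonzero(labels == cls)
        # Sample with replacement until the class reaches the majority size.
        indices.append(np.random.choice(cls_idx, size=target, replace=True))
    return np.random.permutation(np.concatenate(indices))

labels = np.array([0] * 90 + [1] * 10)
idx = oversample_indices(labels)
print(Counter(labels[idx].tolist()))   # balanced: {0: 90, 1: 90}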
glam4cm/downstream_tasks/cm_gpt_edge_classification.py
@@ -0,0 +1,73 @@
+ import os
+
+ from glam4cm.downstream_tasks.common_args import (
+     get_common_args_parser,
+     get_gpt_args_parser
+ )
+ from glam4cm.data_loading.graph_dataset import GraphEdgeDataset
+ from glam4cm.models.cmgpt import CMGPT, CMGPTClassifier
+ from glam4cm.downstream_tasks.utils import get_models_dataset
+ from glam4cm.tokenization.utils import get_tokenizer
+ from glam4cm.trainers.cm_gpt_trainer import CMGPTTrainer
+ from glam4cm.utils import merge_argument_parsers, set_seed
+
+
+ def get_parser():
+     common_parser = get_common_args_parser()
+     gpt_parser = get_gpt_args_parser()
+     parser = merge_argument_parsers(common_parser, gpt_parser)
+
+     parser.add_argument('--cls_label', type=str, default='type')
+     parser.add_argument('--pretr', type=str, default=None)
+     return parser
+
+
+ def run(args):
+     set_seed(args.seed)
+
+     # Match the keyword used by the other scripts (use_special_tokens).
+     tokenizer = get_tokenizer('bert-base-cased', use_special_tokens=True)
+
+     models_dataset_params = dict(
+         language='en',
+     )
+
+     graph_params = dict(
+         use_special_tokens=args.use_special_tokens,
+         distance=args.distance,
+         reload=args.reload
+     )
+
+     models_dataset = get_models_dataset(args.dataset, **models_dataset_params)
+     graph_dataset = GraphEdgeDataset(models_dataset, **graph_params)
+
+     assert hasattr(graph_dataset, f'num_nodes_{args.node_cls_label}'), f"Dataset does not have node labels for {args.node_cls_label}"
+
+     node_label_dataset = graph_dataset.get_link_prediction_lm_data(
+         tokenizer=tokenizer,
+         label=args.node_cls_label
+     )
+
+     # Load a pretrained CMGPT checkpoint if one is given, otherwise train from scratch.
+     if args.pretr and os.path.exists(args.pretr):
+         cmgpt = CMGPT.from_pretrained(f"{args.pretr}")
+     else:
+         cmgpt = CMGPT(
+             vocab_size=len(tokenizer),
+             embed_dim=args.embed_dim,
+             block_size=args.block_size,
+             n_layer=args.n_layer,
+             n_head=args.n_head,
+         )
+
+     cmgpt_classifier = CMGPTClassifier(cmgpt, num_classes=getattr(graph_dataset, f"num_nodes_{args.node_cls_label}"))
+
+     trainer = CMGPTTrainer(
+         cmgpt_classifier,
+         train_dataset=node_label_dataset['train'],
+         test_dataset=node_label_dataset['test'],
+         batch_size=args.batch_size,
+         num_epochs=args.num_epochs
+     )
+
+     trainer.train()
+
+     trainer.save_model()
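`CMGPT` and `CMGPTClassifier` live in glam4cm/models/cmgpt.py, which is not expanded in this diff. As a rough mental model only, a classifier of this shape wraps a token-level backbone with a pooled linear head; the PyTorch sketch below is hypothetical (`ToyBackbone` and `ToyClassifier` are illustrative names, not the package's API):

import torch
import torch.nn as nn

class ToyBackbone(nn.Module):
    """Stand-in for CMGPT: embeds tokens and returns hidden states."""
    def __init__(self, vocab_size: int, embed_dim: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.TransformerEncoderLayer(embed_dim, nhead=4, batch_first=True)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.encoder(self.embed(input_ids))    # (batch, seq, embed_dim)

class ToyClassifier(nn.Module):
    """Stand-in for CMGPTClassifier: pools hidden states, applies a linear head."""
    def __init__(self, backbone: ToyBackbone, embed_dim: int, num_classes: int):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden = self.backbone(input_ids)
        return self.head(hidden.mean(dim=1))          # mean-pool over the sequence

logits = ToyClassifier(ToyBackbone(100, 64), 64, 5)(torch.randint(0, 100, (2, 16)))
print(logits.shape)  # torch.Size([2, 5])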