glam4cm 0.1.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. glam4cm/__init__.py +2 -1
  2. glam4cm/data_loading/data.py +90 -146
  3. glam4cm/data_loading/encoding.py +17 -6
  4. glam4cm/data_loading/graph_dataset.py +192 -57
  5. glam4cm/data_loading/metadata.py +1 -1
  6. glam4cm/data_loading/models_dataset.py +42 -18
  7. glam4cm/downstream_tasks/bert_edge_classification.py +49 -22
  8. glam4cm/downstream_tasks/bert_graph_classification.py +44 -14
  9. glam4cm/downstream_tasks/bert_graph_classification_comp.py +47 -24
  10. glam4cm/downstream_tasks/bert_link_prediction.py +46 -26
  11. glam4cm/downstream_tasks/bert_node_classification.py +127 -89
  12. glam4cm/downstream_tasks/cm_gpt_node_classification.py +61 -15
  13. glam4cm/downstream_tasks/common_args.py +32 -4
  14. glam4cm/downstream_tasks/gnn_edge_classification.py +24 -7
  15. glam4cm/downstream_tasks/gnn_graph_cls.py +19 -6
  16. glam4cm/downstream_tasks/gnn_link_prediction.py +25 -13
  17. glam4cm/downstream_tasks/gnn_node_classification.py +19 -7
  18. glam4cm/downstream_tasks/utils.py +16 -2
  19. glam4cm/embeddings/bert.py +1 -1
  20. glam4cm/embeddings/common.py +7 -4
  21. glam4cm/encoding/encoders.py +1 -1
  22. glam4cm/lang2graph/archimate.py +0 -5
  23. glam4cm/lang2graph/common.py +99 -41
  24. glam4cm/lang2graph/ecore.py +1 -2
  25. glam4cm/lang2graph/ontouml.py +8 -7
  26. glam4cm/models/gnn_layers.py +20 -6
  27. glam4cm/models/hf.py +2 -2
  28. glam4cm/run.py +12 -7
  29. glam4cm/run_conf_v2.py +405 -0
  30. glam4cm/run_configs.py +70 -106
  31. glam4cm/run_confs.py +41 -0
  32. glam4cm/settings.py +15 -2
  33. glam4cm/tokenization/special_tokens.py +23 -1
  34. glam4cm/tokenization/utils.py +23 -4
  35. glam4cm/trainers/cm_gpt_trainer.py +1 -1
  36. glam4cm/trainers/gnn_edge_classifier.py +12 -1
  37. glam4cm/trainers/gnn_graph_classifier.py +12 -5
  38. glam4cm/trainers/gnn_link_predictor.py +18 -3
  39. glam4cm/trainers/gnn_link_predictor_v2.py +146 -0
  40. glam4cm/trainers/gnn_trainer.py +8 -0
  41. glam4cm/trainers/metrics.py +1 -1
  42. glam4cm/utils.py +265 -2
  43. {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info}/METADATA +3 -2
  44. glam4cm-1.0.0.dist-info/RECORD +75 -0
  45. {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info}/WHEEL +1 -1
  46. glam4cm-0.1.1.dist-info/RECORD +0 -72
  47. {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info}/entry_points.txt +0 -0
  48. {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info/licenses}/LICENSE +0 -0
  49. {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info}/top_level.txt +0 -0
glam4cm/downstream_tasks/bert_node_classification.py

@@ -1,12 +1,19 @@
+ import numpy as np
  from glam4cm.models.hf import get_model
- from glam4cm.downstream_tasks.common_args import get_bert_args_parser, get_common_args_parser, get_config_params
+ from glam4cm.downstream_tasks.common_args import (
+     get_bert_args_parser,
+     get_common_args_parser,
+     get_config_params,
+     get_config_str
+ )
  import os
  from transformers import TrainingArguments, Trainer
  from glam4cm.data_loading.graph_dataset import GraphNodeDataset
  from glam4cm.data_loading.utils import oversample_dataset
- from glam4cm.downstream_tasks.utils import get_models_dataset
+ from glam4cm.downstream_tasks.utils import get_logging_steps, get_models_dataset
+ from glam4cm.settings import NODE_CLS_TASK, results_dir
  from glam4cm.tokenization.special_tokens import *
-
+ from sklearn.model_selection import StratifiedKFold

  from sklearn.metrics import (
      accuracy_score,
@@ -16,7 +23,7 @@ from sklearn.metrics import (
  )

  from glam4cm.tokenization.utils import get_tokenizer
- from glam4cm.utils import merge_argument_parsers, set_seed
+ from glam4cm.utils import merge_argument_parsers, set_encoded_labels, set_seed



@@ -54,11 +61,24 @@ def get_parser():
      return parser


-
  def run(args):
      set_seed(args.seed)
+     dataset_name = args.dataset
+     print("Training model")
+     output_dir = os.path.join(
+         results_dir,
+         dataset_name,
+         f'LM_{NODE_CLS_TASK}',
+         f'{args.node_cls_label}',
+         get_config_str(args)
+     )
+
+     # if os.path.exists(output_dir):
+     #     print(f"Output directory {output_dir} already exists. Exiting.")
+     #     exit(0)

      config_params = dict(
+         include_dummies = args.include_dummies,
          min_enr = args.min_enr,
          min_edges = args.min_edges,
          remove_duplicates = args.remove_duplicates,
@@ -71,92 +91,110 @@ def run(args):

      print("Loaded dataset")

-     graph_data_params = get_config_params(args)
+     graph_data_params = {**get_config_params(args), 'task_type': NODE_CLS_TASK}
      print("Loading graph dataset")
-     graph_dataset = GraphNodeDataset(dataset, **graph_data_params)
-     print("Loaded graph dataset")
-
-     assert hasattr(graph_dataset, f'num_nodes_{args.node_cls_label}'), f"Dataset does not have node_{args.node_cls_label} attribute"
-     num_labels = getattr(graph_dataset, f"num_nodes_{args.node_cls_label}")
-
-     model_name = args.model_name
-     tokenizer = get_tokenizer(model_name, use_special_tokens=args.use_special_tokens)
-
-     print("Getting node classification data")
-     bert_dataset = graph_dataset.get_node_classification_lm_data(
-         label=args.node_cls_label,
-         tokenizer=tokenizer,
-         distance=distance,
-     )
-
-     # exit(0)
-
-     if args.oversampling_ratio != -1:
-         ind_w_oversamples = oversample_dataset(bert_dataset['train'])
-         bert_dataset['train'].inputs = bert_dataset['train'][ind_w_oversamples]
-
-     model = get_model(
-         args.ckpt if args.ckpt else model_name,
-         num_labels=2,
-         len_tokenizer=len(tokenizer)
-     )
-
-     if args.freeze_pretrained_weights:
-         for param in model.base_model.parameters():
-             param.requires_grad = False
-

-     print("Training model")
-     output_dir = os.path.join(
-         'results',
-         dataset_name,
-         'node_cls',
-         f'{args.node_cls_label}',
-         f"{graph_dataset.config_hash}",
-     )
-
-     logs_dir = os.path.join(
-         'logs',
-         dataset_name,
-         'node_cls',
-         f'{args.node_cls_label}',
-         f"{graph_dataset.config_hash}",
-     )
-
-     print("Output Dir: ", output_dir)
-     print("Logs Dir: ", logs_dir)
-     print("Len Train Dataset: ", len(bert_dataset['train']))
-     print("Len Test Dataset: ", len(bert_dataset['test']))
-
-     training_args = TrainingArguments(
-         output_dir=output_dir,
-         num_train_epochs=args.num_epochs,
-         per_device_train_batch_size=args.train_batch_size,
-         per_device_eval_batch_size=args.eval_batch_size,
-         weight_decay=0.01,
-         logging_dir=logs_dir,
-         logging_steps=args.num_log_steps,
-         eval_strategy='steps',
-         eval_steps=args.num_eval_steps,
-         save_steps=args.num_save_steps,
-         save_total_limit=2,
-         load_best_model_at_end=True,
-         fp16=True,
-     )
-
-     trainer = Trainer(
-         model=model,
-         args=training_args,
-         train_dataset=bert_dataset['train'],
-         eval_dataset=bert_dataset['test'],
-         compute_metrics=compute_metrics
-     )
-
-     trainer.train()
-     results = trainer.evaluate()
-     print(results)
-
-     trainer.save_model()
+     k = int(1 / args.test_ratio)
+
+     for i in range(k):
+         set_seed(np.random.randint(0, 1000))
+         graph_dataset = GraphNodeDataset(dataset, **graph_data_params)
+         print("Loaded graph dataset")
+
+         assert hasattr(graph_dataset, f'num_nodes_{args.node_cls_label}'), f"Dataset does not have node_{args.node_cls_label} attribute"
+         num_labels = getattr(graph_dataset, f"num_nodes_{args.node_cls_label}")
+
+         model_name = args.model_name
+         tokenizer = get_tokenizer(model_name, use_special_tokens=args.use_special_tokens)
+
+         print("Getting node classification data")
+         bert_dataset = graph_dataset.get_node_classification_lm_data(
+             label=args.node_cls_label,
+             tokenizer=tokenizer,
+             distance=distance,
+         )
+
+         # exit(0)
+
+         if args.oversampling_ratio != -1:
+             ind_w_oversamples = oversample_dataset(bert_dataset['train'])
+             bert_dataset['train'].inputs = bert_dataset['train'][ind_w_oversamples]
+
+
+         model = get_model(
+             args.ckpt if args.ckpt else model_name,
+             num_labels=num_labels,
+             len_tokenizer=len(tokenizer),
+             trust_remote_code=args.trust_remote_code
+         )
+
+         if args.freeze_pretrained_weights:
+             for param in model.base_model.parameters():
+                 param.requires_grad = False
+
+
+         logs_dir = os.path.join(
+             'logs',
+             dataset_name,
+             f'BERT_{NODE_CLS_TASK}',
+             f'{args.node_cls_label}',
+             f"{graph_dataset.config_hash}_{i}",
+         )
+
+         print("Output Dir: ", output_dir)
+         print("Logs Dir: ", logs_dir)
+         print("Len Train Dataset: ", len(bert_dataset['train']))
+         print("Len Test Dataset: ", len(bert_dataset['test']))
+
+         train_dataset = bert_dataset['train']
+         test_dataset = bert_dataset['test']
+         set_encoded_labels(train_dataset, test_dataset)
+
+
+         print("Num epochs: ", args.num_epochs)
+
+         logging_steps = get_logging_steps(
+             len(train_dataset),
+             args.num_epochs,
+             args.train_batch_size
+         )
+
+         training_args = TrainingArguments(
+             output_dir=output_dir,
+             num_train_epochs=args.num_epochs,
+             per_device_train_batch_size=args.train_batch_size,
+             per_device_eval_batch_size=args.eval_batch_size,
+             weight_decay=0.01,
+             logging_dir=logs_dir,
+             logging_steps=logging_steps,
+             eval_strategy='steps',
+             eval_steps=logging_steps,
+             # save_steps=args.num_save_steps,
+             # save_total_limit=2,
+             # load_best_model_at_end=True,
+             fp16=True,
+             save_strategy="no"
+         )
+
+         trainer = Trainer(
+             model=model,
+             args=training_args,
+             train_dataset=bert_dataset['train'],
+             eval_dataset=bert_dataset['test'],
+             compute_metrics=compute_metrics
+         )
+
+         trainer.train()
+         results = trainer.evaluate()
+
+         # with open(os.path.join(output_dir, 'results.txt'), 'a') as f:
+         #     f.write(str(results))
+         #     f.write('\n')
+
+         print(results)
+
+         trainer.save_model()
+         break


  if __name__ == '__main__':
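
The main change in bert_node_classification.py replaces the single train/evaluate pass with a repeated-holdout loop: k = int(1 / test_ratio) rounds, each reseeding the RNG and rebuilding the dataset before training. Note that StratifiedKFold is imported but unused in the hunks shown, and the loop breaks after the first round. A minimal sketch of the pattern, with build_split and train_once as hypothetical stand-ins for the GraphNodeDataset construction and the Trainer run:

    import numpy as np

    def repeated_holdout(test_ratio, build_split, train_once):
        # k independent rounds, e.g. test_ratio=0.2 -> k=5
        k = int(1 / test_ratio)
        results = []
        for i in range(k):
            seed = np.random.randint(0, 1000)       # fresh seed per round
            train_ds, test_ds = build_split(seed)   # resample the split
            results.append(train_once(train_ds, test_ds, round_id=i))
            break  # as in the shipped code: only the first round runs
        return results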
glam4cm/downstream_tasks/cm_gpt_node_classification.py

@@ -1,6 +1,8 @@
  import os
  from glam4cm.downstream_tasks.common_args import (
-     get_common_args_parser,
+     get_common_args_parser,
+     get_config_params,
+     get_config_str,
      get_gpt_args_parser
  )

@@ -9,7 +11,8 @@ from glam4cm.models.cmgpt import CMGPT, CMGPTClassifier
  from glam4cm.downstream_tasks.utils import get_models_dataset
  from glam4cm.tokenization.utils import get_tokenizer
  from glam4cm.trainers.cm_gpt_trainer import CMGPTTrainer
- from glam4cm.utils import merge_argument_parsers, set_seed
+ from glam4cm.utils import merge_argument_parsers, set_encoded_labels, set_seed
+ from glam4cm.settings import NODE_CLS_TASK, results_dir


  def get_parser():
@@ -25,29 +28,68 @@ def get_parser():
  def run(args):
      set_seed(args.seed)

-     tokenizer = get_tokenizer('bert-base-cased', special_tokens=True)
+     tokenizer = get_tokenizer('bert-base-cased', use_special_tokens=args.use_special_tokens)

-     models_dataset_params = dict(
-         language='en',
+     set_seed(args.seed)
+     dataset_name = args.dataset
+     print("Training model")
+     output_dir = os.path.join(
+         results_dir,
+         dataset_name,
+         f'LM_{NODE_CLS_TASK}',
+         f'{args.node_cls_label}',
+         get_config_str(args)
      )

-     graph_params = dict(
-         use_special_tokens=args.use_special_tokens,
-         distance=args.distance,
-         reload = args.reload
+     # if os.path.exists(output_dir):
+     #     print(f"Output directory {output_dir} already exists. Exiting.")
+     #     exit(0)
+
+     config_params = dict(
+         include_dummies = args.include_dummies,
+         min_enr = args.min_enr,
+         min_edges = args.min_edges,
+         remove_duplicates = args.remove_duplicates,
+         reload=args.reload,
+         language = args.language
      )
+     dataset_name = args.dataset
+     distance = args.distance
+     dataset = get_models_dataset(dataset_name, **config_params)
+
+     print("Loaded dataset")

-     models_dataset = get_models_dataset(args.dataset, **models_dataset_params)
-     graph_dataset = GraphNodeDataset(models_dataset, **graph_params)
+     graph_data_params = {**get_config_params(args), 'task_type': NODE_CLS_TASK}
+     print("Loading graph dataset")
+     graph_dataset = GraphNodeDataset(dataset, **graph_data_params)

      assert hasattr(graph_dataset, f'num_nodes_{args.node_cls_label}'), f"Dataset does not have node labels for {args.node_cls_label}"

      node_label_dataset = graph_dataset.get_node_classification_lm_data(
          args.node_cls_label,
          tokenizer=tokenizer,
-         distance=1,
+         distance=args.distance,
+     )
+
+     set_encoded_labels(node_label_dataset['train'], node_label_dataset['test'])
+
+     print("Training model")
+     output_dir = os.path.join(
+         results_dir,
+         args.dataset,
+         f'LM_{NODE_CLS_TASK}',
+         f'{args.node_cls_label}',
+     )
+
+     logs_dir = os.path.join(
+         'logs',
+         args.dataset,
+         f'CMGPT_{NODE_CLS_TASK}',
+         f'{args.node_cls_label}',
+         f"{graph_dataset.config_hash}",
      )

+
      if args.pretr and os.path.exists(args.pretr):
          print(f"Loading pretrained model from {args.pretr}")
          cmgpt = CMGPT.from_pretrained(f"{args.pretr}")
@@ -60,7 +102,8 @@ def run(args):
          n_layer=args.n_layer,
          n_head=args.n_head,
      )
-
+     print(f"Train dataset size: {len(node_label_dataset['train'])}")
+     print(f"Test dataset size: {len(node_label_dataset['test'])}")
      cmgpt_classifier = CMGPTClassifier(cmgpt, num_classes=getattr(graph_dataset, f"num_nodes_{args.node_cls_label}"))

      trainer = CMGPTTrainer(
@@ -68,9 +111,12 @@ def run(args):
          train_dataset=node_label_dataset['train'],
          test_dataset=node_label_dataset['test'],
          batch_size=args.batch_size,
-         num_epochs=args.num_epochs
+         num_epochs=args.num_epochs,
+         log_dir=logs_dir,
+         results_dir=output_dir,
      )

      trainer.train()

-     trainer.save_model()
+     trainer.save_model()
+
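
cm_gpt_node_classification.py now mirrors the BERT script: it derives its result and log paths from settings constants plus the classification label, then hands them to CMGPTTrainer. A sketch of the resulting layout, assuming results_dir = 'results' and NODE_CLS_TASK = 'node_cls' (the real values live in glam4cm.settings; the dataset name, label, and hash below are hypothetical):

    import os

    results_dir, NODE_CLS_TASK = 'results', 'node_cls'              # assumed values
    dataset, label, config_hash = 'modelset', 'abstract', 'a1b2c3'  # hypothetical

    output_dir = os.path.join(results_dir, dataset, f'LM_{NODE_CLS_TASK}', label)
    logs_dir = os.path.join('logs', dataset, f'CMGPT_{NODE_CLS_TASK}', label, config_hash)

    print(output_dir)  # results/modelset/LM_node_cls/abstract
    print(logs_dir)    # logs/modelset/CMGPT_node_cls/abstract/a1b2c3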
glam4cm/downstream_tasks/common_args.py

@@ -1,13 +1,35 @@
  from argparse import ArgumentParser
  from glam4cm.settings import (
+     MODERN_BERT,
      BERT_MODEL,
      WORD2VEC_MODEL,
      TFIDF_MODEL
  )

+ def get_config_str(args):
+     config_str = ""
+     if args.use_attributes:
+         config_str += "_attrs"
+     if args.use_edge_label:
+         config_str += "_el"
+     if args.use_edge_types:
+         config_str += "_et"
+     if args.use_node_types:
+         config_str += "_nt"
+     if args.use_special_tokens:
+         config_str += "_st"
+     if args.no_labels:
+         config_str += "_nolb"
+     config_str += f"_{args.node_cls_label}" if args.node_cls_label else ""
+     config_str += f"_{args.edge_cls_label}" if args.edge_cls_label else ""
+     config_str += f"_{args.distance}" if args.distance else ""
+
+     return config_str
+

  def get_config_params(args):
      common_params = dict(
+
          distance=args.distance,
          reload=args.reload,
          test_ratio=args.test_ratio,
@@ -17,6 +39,8 @@ def get_config_params(args):
          use_edge_types=args.use_edge_types,
          use_edge_label=args.use_edge_label,
          no_labels=args.no_labels,
+
+         node_topk=args.node_topk,

          use_special_tokens=args.use_special_tokens,

@@ -58,6 +82,7 @@ def get_common_args_parser():
          ]
      )
      parser.add_argument('--remove_duplicates', action='store_true')
+     parser.add_argument('--include_dummies', action='store_true')
      parser.add_argument('--reload', action='store_true')
      parser.add_argument('--min_enr', type=float, default=-1.0)
      parser.add_argument('--min_edges', type=int, default=-1)
@@ -72,6 +97,8 @@ def get_common_args_parser():

      parser.add_argument('--node_cls_label', type=str, default=None)
      parser.add_argument('--edge_cls_label', type=str, default=None)
+
+     parser.add_argument('--node_topk', type=int, default=-1)


      parser.add_argument('--limit', type=int, default=-1)
@@ -84,9 +111,11 @@ def get_common_args_parser():
      parser.add_argument(
          '--embed_model_name',
          type=str,
-         default='bert-base-uncased',
-         choices=[BERT_MODEL, WORD2VEC_MODEL, TFIDF_MODEL]
+         default=MODERN_BERT,
+         choices=[MODERN_BERT, BERT_MODEL, WORD2VEC_MODEL, TFIDF_MODEL]
      )
+
+     parser.add_argument('--trust_remote_code', action='store_true')
      parser.add_argument('--max_length', type=int, default=512)
      parser.add_argument('--ckpt', type=str, default=None)

@@ -134,7 +163,7 @@ def get_bert_args_parser():
      parser = ArgumentParser()

      parser.add_argument('--freeze_pretrained_weights', action='store_true')
-     parser.add_argument('--model_name', type=str, default='bert-base-uncased')
+     parser.add_argument('--model_name', type=str, default='answerdotai/ModernBERT-base')

      parser.add_argument('--warmup_steps', type=int, default=200)
      parser.add_argument('--num_log_steps', type=int, default=200)
@@ -148,7 +177,6 @@ def get_bert_args_parser():
  def get_gpt_args_parser():
      parser = ArgumentParser()
      parser.add_argument('--model_name', type=str, default='gpt2')
-     parser.add_argument('--use_special_tokens', action='store_true')

      parser.add_argument('--warmup_steps', type=int, default=200)
      parser.add_argument('--blocks', type=int, default=6)
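
The new get_config_str helper turns the active flags into a filesystem-safe suffix that the task scripts append to their output paths. A quick worked example with a hypothetical flag combination (every field mirrors a parser argument defined above):

    from types import SimpleNamespace
    from glam4cm.downstream_tasks.common_args import get_config_str

    # Hypothetical flag combination for illustration.
    args = SimpleNamespace(
        use_attributes=True, use_edge_label=False, use_edge_types=True,
        use_node_types=False, use_special_tokens=True, no_labels=False,
        node_cls_label='abstract', edge_cls_label=None, distance=1,
    )
    # Conditionals fire in order: _attrs, _et, _st, then the label and distance.
    assert get_config_str(args) == "_attrs_et_st_abstract_1"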
glam4cm/downstream_tasks/gnn_edge_classification.py

@@ -1,12 +1,16 @@
  import os
  from glam4cm.data_loading.graph_dataset import GraphEdgeDataset
  from glam4cm.models.gnn_layers import GNNConv, EdgeClassifer
- from glam4cm.settings import LP_TASK_EDGE_CLS
+ from glam4cm.settings import EDGE_CLS_TASK, results_dir
  from glam4cm.downstream_tasks.utils import get_models_dataset
  from glam4cm.tokenization.special_tokens import *
  from glam4cm.trainers.gnn_edge_classifier import GNNEdgeClassificationTrainer as Trainer
- from glam4cm.utils import set_seed, merge_argument_parsers
- from glam4cm.downstream_tasks.common_args import get_common_args_parser, get_config_params, get_gnn_args_parser
+ from glam4cm.utils import set_seed, merge_argument_parsers, set_torch_encoding_labels
+ from glam4cm.downstream_tasks.common_args import (
+     get_common_args_parser,
+     get_config_params,
+     get_gnn_args_parser
+ )


  def get_parser():
@@ -21,6 +25,7 @@ def run(args):
      set_seed(args.seed)

      config_params = dict(
+         include_dummies = args.include_dummies,
          min_enr = args.min_enr,
          min_edges = args.min_edges,
          remove_duplicates = args.remove_duplicates,
@@ -32,13 +37,21 @@ def run(args):
      dataset = get_models_dataset(dataset_name, **config_params)

      graph_data_params = get_config_params(args)
-     graph_data_params = {**graph_data_params, 'task_type': LP_TASK_EDGE_CLS}
+     graph_data_params = {**graph_data_params, 'task_type': EDGE_CLS_TASK}
+     print("Using model: ", graph_data_params['embed_model_name'])
+     if args.ckpt:
+         print("Using checkpoint: ", args.ckpt)
+
+     # if args.use_embeddings:
+     #     graph_data_params['embed_model_name'] = os.path.join(results_dir, dataset_name, f'{args.edge_cls_label}')

      print("Loading graph dataset")
      graph_dataset = GraphEdgeDataset(dataset, **graph_data_params)
      print("Loaded graph dataset")

      graph_torch_data = graph_dataset.get_torch_dataset()
+     exclude_labels = getattr(graph_dataset, f"edge_exclude_{args.edge_cls_label}")
+     set_torch_encoding_labels(graph_torch_data, f"edge_{args.edge_cls_label}", exclude_labels)

      input_dim = graph_torch_data[0].x.shape[1]

@@ -59,12 +72,16 @@ def run(args):

      edge_dim = graph_dataset[0].data.edge_attr.shape[1] if args.use_edge_attrs else None

+     ue = "" if not args.use_edge_attrs else "_ue"
+
      logs_dir = os.path.join(
          "logs",
          dataset_name,
-         "gnn_edge_cls",
+         f"GNN_{EDGE_CLS_TASK}",
+         f"{args.edge_cls_label}{ue}",
          f"{graph_dataset.config_hash}",
      )
+

      gnn_conv_model = GNNConv(
          model_name=model_name,
@@ -99,8 +116,8 @@ def run(args):
          num_epochs=args.num_epochs,
          batch_size=args.batch_size,
          use_edge_attrs=args.use_edge_attrs,
-         logs_dir=logs_dir
+         logs_dir=logs_dir,
      )

      print("Training GNN Edge Classification model")
-     trainer.run()
+     trainer.run()
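
gnn_edge_classification.py now re-encodes the torch edge labels through set_torch_encoding_labels, passing the classes listed in edge_exclude_<label>. The real implementation lives in glam4cm.utils; the sketch below only illustrates the general technique (mapping labels to contiguous integer ids while flagging excluded classes) under that assumption, with encode_labels_excluding as a hypothetical name:

    import torch

    def encode_labels_excluding(data_list, attr, exclude):
        # Collect the label vocabulary, skipping excluded classes.
        classes = sorted({l for d in data_list for l in getattr(d, attr) if l not in exclude})
        to_id = {c: i for i, c in enumerate(classes)}
        for d in data_list:
            # -1 marks excluded labels so downstream losses can mask them out.
            ids = [to_id.get(l, -1) for l in getattr(d, attr)]
            setattr(d, attr, torch.tensor(ids, dtype=torch.long))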
glam4cm/downstream_tasks/gnn_graph_cls.py

@@ -1,6 +1,7 @@
  import os
  from glam4cm.data_loading.graph_dataset import GraphNodeDataset
  from glam4cm.models.gnn_layers import GNNConv, GraphClassifer
+ from glam4cm.settings import GRAPH_CLS_TASK, DUMMY_GRAPH_CLS_TASK, results_dir
  from glam4cm.trainers.gnn_graph_classifier import GNNGraphClassificationTrainer as Trainer
  from glam4cm.downstream_tasks.common_args import get_common_args_parser, get_config_params, get_gnn_args_parser
  from glam4cm.utils import merge_argument_parsers, set_seed
@@ -21,6 +22,7 @@ def run(args):
      set_seed(args.seed)

      config_params = dict(
+         include_dummies = args.include_dummies,
          min_enr = args.min_enr,
          min_edges = args.min_edges,
          remove_duplicates = args.remove_duplicates,
@@ -30,8 +32,10 @@ def run(args):
      dataset_name = args.dataset

      dataset = get_models_dataset(dataset_name, **config_params)
-
-     graph_data_params = get_config_params(args)
+
+     graph_data_params = {**get_config_params(args), 'task_type': GRAPH_CLS_TASK if not args.include_dummies else DUMMY_GRAPH_CLS_TASK}
+     # if args.use_embeddings:
+     #     graph_data_params['ckpt'] = os.path.join(results_dir, dataset_name, f'{args.cls_label}')

      print("Loading graph dataset")
      graph_dataset = GraphNodeDataset(dataset, **graph_data_params)
@@ -40,8 +44,16 @@ def run(args):
      cls_label = f"num_graph_{args.cls_label}"
      assert hasattr(graph_dataset, cls_label), f"Dataset does not have attribute {cls_label}"
      num_classes = getattr(graph_dataset, cls_label)
-
      print(f"Number of classes: {num_classes}")
+
+     if args.include_dummies:
+         import numpy as np
+         dummy_class = int(graph_dataset.graph_label_map_label.transform(['dummy'])[0])
+         for g, l in zip(graph_dataset, [int(g.data.graph_label[0]) == dummy_class for g in graph_dataset]):
+             setattr(g.data, f"graph_{args.cls_label}", np.array([int(l)]))
+         num_classes = 2
+
+
      model_name = args.gnn_conv_model
      hidden_dim = args.hidden_dim
      output_dim = args.output_dim
@@ -53,14 +65,15 @@ def run(args):
      aggregation = args.aggregation

      input_dim = graph_dataset[0].data.x.shape[1]
-
+     ue = "" if not args.use_edge_attrs else "_ue"
      logs_dir = os.path.join(
          "logs",
          dataset_name,
-         "gnn_graph_cls",
+         f"GNN_{GRAPH_CLS_TASK}{ue}",
          f"{graph_dataset.config_hash}",
      )

+     fold_id = 0
      for datasets in graph_dataset.get_kfold_gnn_graph_classification_data():

          edge_dim = graph_dataset[0].data.edge_attr.shape[1] if args.num_heads else None
@@ -94,7 +107,7 @@ def run(args):
              num_epochs=args.num_epochs,
              batch_size=args.batch_size,
              use_edge_attrs=args.use_edge_attrs,
-             logs_dir=logs_dir
+             logs_dir=logs_dir + f"_{fold_id}",
          )

          trainer.run()
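
Finally, gnn_graph_cls.py gains an --include_dummies mode that collapses graph classification into a binary dummy-vs-real task. The relabeling in the diff amounts to the sketch below (the label array and encoded dummy class id are hypothetical); note also that fold_id is initialized but never incremented in the hunks shown, so every fold logs to the "_0" suffix:

    import numpy as np

    def binarize_dummy_labels(labels, dummy_class):
        # Class 1 for dummy graphs, class 0 for everything else -> num_classes == 2.
        return np.array([int(l == dummy_class) for l in labels])

    labels = np.array([0, 3, 3, 1, 3])        # hypothetical encoded graph labels
    print(binarize_dummy_labels(labels, 3))   # [0 1 1 0 1]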