glam4cm 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. glam4cm/__init__.py +2 -1
  2. glam4cm/data_loading/data.py +90 -146
  3. glam4cm/data_loading/encoding.py +17 -6
  4. glam4cm/data_loading/graph_dataset.py +192 -57
  5. glam4cm/data_loading/metadata.py +1 -1
  6. glam4cm/data_loading/models_dataset.py +42 -18
  7. glam4cm/downstream_tasks/bert_edge_classification.py +49 -22
  8. glam4cm/downstream_tasks/bert_graph_classification.py +44 -14
  9. glam4cm/downstream_tasks/bert_graph_classification_comp.py +47 -24
  10. glam4cm/downstream_tasks/bert_link_prediction.py +46 -26
  11. glam4cm/downstream_tasks/bert_node_classification.py +127 -89
  12. glam4cm/downstream_tasks/cm_gpt_node_classification.py +61 -15
  13. glam4cm/downstream_tasks/common_args.py +32 -4
  14. glam4cm/downstream_tasks/gnn_edge_classification.py +24 -7
  15. glam4cm/downstream_tasks/gnn_graph_cls.py +19 -6
  16. glam4cm/downstream_tasks/gnn_link_prediction.py +25 -13
  17. glam4cm/downstream_tasks/gnn_node_classification.py +19 -7
  18. glam4cm/downstream_tasks/utils.py +16 -2
  19. glam4cm/embeddings/bert.py +1 -1
  20. glam4cm/embeddings/common.py +7 -4
  21. glam4cm/encoding/encoders.py +1 -1
  22. glam4cm/lang2graph/archimate.py +0 -5
  23. glam4cm/lang2graph/common.py +99 -41
  24. glam4cm/lang2graph/ecore.py +1 -2
  25. glam4cm/lang2graph/ontouml.py +8 -7
  26. glam4cm/models/gnn_layers.py +20 -6
  27. glam4cm/models/hf.py +2 -2
  28. glam4cm/run.py +13 -9
  29. glam4cm/run_conf_v2.py +405 -0
  30. glam4cm/run_configs.py +70 -106
  31. glam4cm/run_confs.py +41 -0
  32. glam4cm/settings.py +15 -2
  33. glam4cm/tokenization/special_tokens.py +23 -1
  34. glam4cm/tokenization/utils.py +23 -4
  35. glam4cm/trainers/cm_gpt_trainer.py +1 -1
  36. glam4cm/trainers/gnn_edge_classifier.py +12 -1
  37. glam4cm/trainers/gnn_graph_classifier.py +12 -5
  38. glam4cm/trainers/gnn_link_predictor.py +18 -3
  39. glam4cm/trainers/gnn_link_predictor_v2.py +146 -0
  40. glam4cm/trainers/gnn_trainer.py +8 -0
  41. glam4cm/trainers/metrics.py +1 -1
  42. glam4cm/utils.py +265 -2
  43. {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/METADATA +3 -2
  44. glam4cm-1.0.0.dist-info/RECORD +75 -0
  45. {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/WHEEL +1 -1
  46. glam4cm-0.1.0.dist-info/RECORD +0 -72
  47. {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/entry_points.txt +0 -0
  48. {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info/licenses}/LICENSE +0 -0
  49. {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/top_level.txt +0 -0
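
A change that recurs across the task scripts below is that fixed logging_steps / eval_steps / save_steps values are replaced by a single value computed from the training-set size via get_logging_steps, newly imported from glam4cm.downstream_tasks.utils (that module's changes are listed above but not shown in this excerpt). A minimal sketch of what such a helper might compute, for orientation only; the internal logic and the number of log points are assumptions, not the actual implementation:

def get_logging_steps(dataset_size, num_epochs, batch_size, total_log_points=100):
    # Total optimizer steps over the whole run (assumed rounding behaviour).
    total_steps = max(1, (dataset_size // batch_size) * num_epochs)
    # Spread a fixed number of logging/eval points evenly across training.
    return max(1, total_steps // total_log_points)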

glam4cm/downstream_tasks/bert_edge_classification.py

@@ -2,10 +2,15 @@ import os
 from transformers import TrainingArguments, Trainer
 from glam4cm.data_loading.graph_dataset import GraphEdgeDataset
 from glam4cm.data_loading.utils import oversample_dataset
-from glam4cm.settings import LP_TASK_EDGE_CLS
-from glam4cm.downstream_tasks.common_args import get_bert_args_parser, get_common_args_parser, get_config_params
+from glam4cm.settings import EDGE_CLS_TASK, results_dir
+from glam4cm.downstream_tasks.common_args import (
+    get_bert_args_parser,
+    get_common_args_parser,
+    get_config_params,
+    get_config_str
+)
 from glam4cm.models.hf import get_model
-from glam4cm.downstream_tasks.utils import get_models_dataset
+from glam4cm.downstream_tasks.utils import get_logging_steps, get_models_dataset


 from sklearn.metrics import (
@@ -16,7 +21,7 @@ from sklearn.metrics import (
 )

 from glam4cm.tokenization.utils import get_tokenizer
-from glam4cm.utils import merge_argument_parsers, set_seed
+from glam4cm.utils import merge_argument_parsers, set_encoded_labels, set_seed


 def compute_metrics(pred):
@@ -49,27 +54,39 @@ def get_parser():

 def run(args):
     set_seed(args.seed)
+    dataset_name = args.dataset
+    output_dir = os.path.join(
+        results_dir,
+        dataset_name,
+        f"LM_{EDGE_CLS_TASK}",
+        f'{args.edge_cls_label}',
+        get_config_str(args)
+    )
+    if os.path.exists(output_dir):
+        print(f"Output directory {output_dir} already exists. Exiting.")
+        exit(0)

     config_params = dict(
+        include_dummies = args.include_dummies,
         min_enr = args.min_enr,
         min_edges = args.min_edges,
         remove_duplicates = args.remove_duplicates,
         language = args.language,
         reload=args.reload
     )
-    dataset_name = args.dataset
+

     print("Loaded dataset")
     dataset = get_models_dataset(dataset_name, **config_params)

     graph_data_params = get_config_params(args)
-    graph_data_params = {**graph_data_params, 'task_type': LP_TASK_EDGE_CLS}
+    graph_data_params = {**graph_data_params, 'task_type': EDGE_CLS_TASK}

     print("Loading graph dataset")
     graph_dataset = GraphEdgeDataset(dataset, **graph_data_params)
     print("Loaded graph dataset")

-    assert hasattr(graph_dataset, f'num_edges_{args.edge_cls_label}'), f"Dataset does not have node_{args.edge_cls_label} attribute"
+    assert hasattr(graph_dataset, f'num_edges_{args.edge_cls_label}'), f"Dataset does not have edge_{args.edge_cls_label} attribute"
     num_labels = getattr(graph_dataset, f"num_edges_{args.edge_cls_label}")


@@ -78,6 +95,11 @@ def run(args):

     print("Getting Edge Classification data")
     bert_dataset = graph_dataset.get_link_prediction_lm_data(tokenizer=tokenizer)
+
+    train_dataset = bert_dataset['train']
+    test_dataset = bert_dataset['test']
+    set_encoded_labels(train_dataset, test_dataset)
+

     # exit(0)

@@ -88,28 +110,32 @@ def run(args):
     print("Training model")
     print(f'Number of labels: {num_labels}')

-    model = get_model(args.ckpt if args.ckpt else model_name, num_labels, len(tokenizer))
-
+    model = get_model(
+        args.ckpt if args.ckpt else model_name,
+        num_labels,
+        len(tokenizer),
+        trust_remote_code=args.trust_remote_code
+    )
+
     if args.freeze_pretrained_weights:
         for param in model.base_model.parameters():
             param.requires_grad = False

-    output_dir = os.path.join(
-        'results',
-        dataset_name,
-        'edge_cls',
-        f'{args.edge_cls_label}',
-        f"{graph_dataset.config_hash}",
-    )

     logs_dir = os.path.join(
         'logs',
         dataset_name,
-        'edge_cls',
+        f"LM_{EDGE_CLS_TASK}",
         f'{args.edge_cls_label}',
         f"{graph_dataset.config_hash}",
     )

+    logging_steps = get_logging_steps(
+        len(train_dataset),
+        args.num_epochs,
+        args.train_batch_size
+    )
+
     training_args = TrainingArguments(
         output_dir=output_dir,
         num_train_epochs=args.num_epochs,
@@ -117,13 +143,14 @@ def run(args):
         per_device_eval_batch_size=args.eval_batch_size,
         weight_decay=0.01,
         logging_dir=logs_dir,
-        logging_steps=args.num_log_steps,
+        logging_steps=logging_steps,
         eval_strategy='steps',
-        eval_steps=args.num_eval_steps,
-        save_steps=args.num_save_steps,
-        save_total_limit=2,
-        load_best_model_at_end=True,
+        eval_steps=logging_steps,
+        # save_steps=args.num_save_steps,
+        # save_total_limit=2,
+        # load_best_model_at_end=True,
         fp16=True,
+        save_strategy="no"
     )

     trainer = Trainer(

glam4cm/downstream_tasks/bert_graph_classification.py

@@ -12,10 +12,16 @@ from transformers import (

 from glam4cm.data_loading.graph_dataset import GraphNodeDataset
 from glam4cm.models.hf import get_model
-from glam4cm.downstream_tasks.common_args import get_bert_args_parser, get_common_args_parser, get_config_params
-from glam4cm.downstream_tasks.utils import get_models_dataset
+from glam4cm.downstream_tasks.common_args import (
+    get_bert_args_parser,
+    get_common_args_parser,
+    get_config_params,
+    get_config_str
+)
+from glam4cm.downstream_tasks.utils import get_logging_steps, get_models_dataset
+from glam4cm.settings import GRAPH_CLS_TASK, results_dir
 from glam4cm.tokenization.utils import get_tokenizer
-from glam4cm.utils import merge_argument_parsers, set_seed
+from glam4cm.utils import merge_argument_parsers, set_encoded_labels, set_seed



@@ -51,6 +57,7 @@ def run(args):
     set_seed(args.seed)

     config_params = dict(
+        include_dummies = args.include_dummies,
         min_enr = args.min_enr,
         min_edges = args.min_edges,
         remove_duplicates = args.remove_duplicates,
@@ -60,7 +67,7 @@ def run(args):
     dataset_name = args.dataset

     dataset = get_models_dataset(dataset_name, **config_params)
-    graph_data_params = get_config_params(args)
+    graph_data_params = {**get_config_params(args), 'task_type': GRAPH_CLS_TASK}
     print("Loading graph dataset")
     graph_dataset = GraphNodeDataset(dataset, **graph_data_params)
     print("Loaded graph dataset")
@@ -76,30 +83,50 @@ def run(args):
         train_dataset = classification_dataset['train']
         test_dataset = classification_dataset['test']
         num_labels = classification_dataset['num_classes']
+
+        set_encoded_labels(train_dataset, test_dataset)

         print(len(train_dataset), len(test_dataset), num_labels)

         print("Training model")
         output_dir = os.path.join(
-            'results',
+            results_dir,
             dataset_name,
-            f'graph_cls_',
-            f"{graph_dataset.config_hash}",
+            f"LM_{GRAPH_CLS_TASK}",
+            f'{args.cls_label}',
+            get_config_str(args)
         )
+        # if os.path.exists(output_dir):
+        #     print(f"Output directory {output_dir} already exists. Exiting.")
+        #     exit(0)

         logs_dir = os.path.join(
             'logs',
             dataset_name,
-            f'graph_cls_',
-            f"{graph_dataset.config_hash}"
+            f"LM_{GRAPH_CLS_TASK}",
+            f'{args.cls_label}',
+            f"{graph_dataset.config_hash}_{fold_id}",
+
         )

-        model = get_model(args.ckpt if args.ckpt else model_name, num_labels, len(tokenizer))
+        model = get_model(
+            args.ckpt if args.ckpt else model_name,
+            num_labels,
+            len(tokenizer),
+            trust_remote_code=args.trust_remote_code
+        )

         if args.freeze_pretrained_weights:
             for param in model.base_model.parameters():
                 param.requires_grad = False

+
+        logging_steps = get_logging_steps(
+            len(train_dataset),
+            args.num_epochs,
+            args.train_batch_size
+        )
+        #
         # Training arguments
         training_args = TrainingArguments(
             output_dir=output_dir,
@@ -111,12 +138,13 @@ def run(args):
             weight_decay=0.01,
             learning_rate=5e-5,
             logging_dir=logs_dir,
-            logging_steps=args.num_log_steps,
-            eval_steps=args.num_eval_steps,
-            save_steps=args.num_save_steps,
+            logging_steps=logging_steps,
+            eval_steps=logging_steps,
+            save_steps=logging_steps,
             save_total_limit=2,
             load_best_model_at_end=True,
-            fp16=True
+            fp16=True,
+            save_strategy="steps"
         )

         # Trainer
@@ -133,5 +161,7 @@ def run(args):
         results = trainer.evaluate()
         print(results)

+        trainer.save_model()
+
         fold_id += 1
         break

glam4cm/downstream_tasks/bert_graph_classification_comp.py

@@ -1,3 +1,4 @@
+from collections import Counter
 import os
 import json
 from argparse import ArgumentParser
@@ -20,8 +21,7 @@ from transformers import (

 from glam4cm.data_loading.encoding import EncodingDataset
 from glam4cm.models.hf import get_model
-
-
+from glam4cm.settings import results_dir

 def compute_metrics(pred):
     labels = pred.label_ids
@@ -49,13 +49,17 @@ def get_parser():
     parser.add_argument('--ckpt', type=str, default=None)
     parser.add_argument('--max_length', type=int, default=512)
     parser.add_argument('--k', type=int, default=10)
+    parser.add_argument('--limit', type=int, default=-1)
+    parser.add_argument('--trust_remote_code', action='store_true')
+    parser.add_argument('--include_dummies', action='store_true')
+    parser.add_argument('--task_type', type=str, default='graph_cls')

     parser.add_argument('--num_epochs', type=int, default=10)

     parser.add_argument('--warmup_steps', type=int, default=500)
-    parser.add_argument('--num_log_steps', type=int, default=50)
-    parser.add_argument('--num_eval_steps', type=int, default=50)
-    parser.add_argument('--num_save_steps', type=int, default=50)
+    parser.add_argument('--num_log_steps', type=int, default=500)
+    parser.add_argument('--num_eval_steps', type=int, default=500)
+    parser.add_argument('--num_save_steps', type=int, default=500)
     parser.add_argument('--train_batch_size', type=int, default=2)
     parser.add_argument('--eval_batch_size', type=int, default=128)
     parser.add_argument('--lr', type=float, default=1e-5)
@@ -66,15 +70,27 @@ def get_parser():
 def run(args):
     dataset_name = args.dataset_name
     model_name = args.model_name
-
-
+    include_dummies = args.include_dummies
+
+    file_name = 'ecore.jsonl' if include_dummies and dataset_name == 'modelset' else 'ecore-with-dummy.jsonl'
     texts = [
         (g['txt'], g['labels'])
-        for file_name in os.listdir(f'datasets/{dataset_name}')
         for g in json.load(open(f'datasets/{dataset_name}/{file_name}'))
-        if 'ecore' in file_name and file_name.endswith('.jsonl')
+        if g['labels'] not in ['dummy', 'unknown']
+    ]
+    allowed_labels = [label for label, _ in dict(Counter([t[1] for t in texts]).most_common(48)).items()]
+    texts = [
+        (t, l) for t, l in texts
+        if l in allowed_labels
     ]
+    if args.task_type == 'dd':
+        print("Task type: DD")
+        texts = [(t, 0 if l in 'dummy' else 1) for t, l in texts]
+
     shuffle(texts)
+    limit = args.limit if args.limit > 0 else len(texts)
+    texts = texts[:limit]
+
     labels = [y for _, y in texts]
     y_map = {label: i for i, label in enumerate(set(y for y in labels))}
     y = [y_map[y] for y in labels]
@@ -84,13 +100,14 @@ def run(args):

     num_labels = len(y_map)

-    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=args.trust_remote_code)
     k = args.k
     kfold = StratifiedKFold(n_splits=k, shuffle=True, random_state=args.seed)

     i = 0
     for train_idx, test_idx in kfold.split(np.zeros(n), np.zeros(n)):
-
+        # if i == 0:
+        #     continue
         print(f'Fold {i+1}/{k}')

         train_texts = [texts[i] for i in train_idx]
@@ -101,27 +118,34 @@ def run(args):

         print(f'Train: {len(train_texts)}, Test: {len(test_texts)}', num_labels)

-        train_dataset = EncodingDataset(tokenizer, train_texts, train_y)
-        test_dataset = EncodingDataset(tokenizer, test_texts, test_y)
+        train_dataset = EncodingDataset(tokenizer, train_texts, train_y, max_length=args.max_length)
+        test_dataset = EncodingDataset(tokenizer, test_texts, test_y, max_length=args.max_length)
+        # import code; code.interact(local=locals())

-        model = get_model(args.ckpt if args.ckpt else model_name, num_labels, len(tokenizer))
+        model = get_model(
+            args.ckpt if args.ckpt else model_name,
+            num_labels,
+            len(tokenizer),
+            trust_remote_code=args.trust_remote_code
+        )

         print("Training model")
         output_dir = os.path.join(
-            'results',
+            results_dir,
             dataset_name,
-            'graph_cls_comp',
+            f'graph_cls_comp_{"dummy" if include_dummies else ""}{i+1}',
         )

         logs_dir = os.path.join(
             'logs',
-            dataset_name,
-            'graph_cls_comp',
+            f"{dataset_name}_{args.model_name if args.ckpt is None else args.ckpt.split('/')[-1]}",
+            f'graph_cls_comp_{"dummy" if include_dummies else ""}{i+1}',
         )

         print("Running epochs: ", args.num_epochs)

         # Training arguments
+        print("Batch size: ", args.train_batch_size)
         training_args = TrainingArguments(
             output_dir=output_dir,
             num_train_epochs=args.num_epochs,
@@ -131,11 +155,10 @@ def run(args):
             warmup_steps=500,
             weight_decay=0.01,
             logging_dir=logs_dir,
-            logging_steps=10,
-            eval_steps=10,
-            save_total_limit=2,
-            load_best_model_at_end=True,
-            fp16=True
+            logging_steps=args.num_log_steps,
+            eval_steps=args.num_eval_steps,
+            fp16=True,
+            save_strategy="no"
         )

         # Trainer
@@ -153,4 +176,4 @@ def run(args):
         print(results)

         i += 1
-
+        # break

glam4cm/downstream_tasks/bert_link_prediction.py

@@ -3,9 +3,14 @@ import os
 from transformers import TrainingArguments, Trainer
 from glam4cm.data_loading.graph_dataset import GraphEdgeDataset
 from glam4cm.models.hf import get_model
-from glam4cm.settings import LP_TASK_LINK_PRED
-from glam4cm.downstream_tasks.common_args import get_bert_args_parser, get_common_args_parser, get_config_params
-from glam4cm.downstream_tasks.utils import get_models_dataset
+from glam4cm.settings import LINK_PRED_TASK, results_dir
+from glam4cm.downstream_tasks.common_args import (
+    get_bert_args_parser,
+    get_common_args_parser,
+    get_config_params,
+    get_config_str
+)
+from glam4cm.downstream_tasks.utils import get_logging_steps, get_models_dataset
 from glam4cm.tokenization.special_tokens import *


@@ -17,7 +22,7 @@ from sklearn.metrics import (
 )

 from glam4cm.tokenization.utils import get_tokenizer
-from glam4cm.utils import merge_argument_parsers, set_seed
+from glam4cm.utils import merge_argument_parsers, set_encoded_labels, set_seed


 def compute_metrics(pred):
@@ -56,6 +61,7 @@ def run(args):
     set_seed(args.seed)

     config_params = dict(
+        include_dummies = args.include_dummies,
         min_enr = args.min_enr,
         min_edges = args.min_edges,
         remove_duplicates = args.remove_duplicates,
@@ -69,17 +75,19 @@ def run(args):


     graph_data_params = get_config_params(args)
-    graph_data_params = {**graph_data_params, 'task': LP_TASK_LINK_PRED}
+    graph_data_params = {
+        **graph_data_params,
+        'add_negative_train_samples': True,
+        'neg_sampling_ratio': args.neg_sampling_ratio,
+        'task_type': LINK_PRED_TASK
+    }
+    print(graph_data_params)

     print("Loading graph dataset")
     graph_dataset = GraphEdgeDataset(
         dataset,
-        dict(
-            **graph_data_params,
-            add_negative_train_samples=args.add_negative_train_samples,
-            neg_sampling_ratio=args.neg_sampling_ratio,
-            task=LP_TASK_LINK_PRED
-        ))
+        **graph_data_params
+    )
     print("Loaded graph dataset")


@@ -89,33 +97,44 @@ def run(args):


     print("Getting link prediction data")
-    bert_dataset = graph_dataset.get_link_prediction_lm_data(
-        tokenizer=tokenizer,
-        task_type=LP_TASK_LINK_PRED
-    )
+    bert_dataset = graph_dataset.get_link_prediction_lm_data(tokenizer=tokenizer)
+    train_dataset = bert_dataset['train']
+    test_dataset = bert_dataset['test']
+    set_encoded_labels(train_dataset, test_dataset)
+

     print("Training model")
-    model = get_model(args.ckpt if args.ckpt else model_name, num_labels=2, len_tokenizer=len(tokenizer))
+    model = get_model(args.ckpt if args.ckpt else model_name, num_labels=2, len_tokenizer=len(tokenizer), trust_remote_code=args.trust_remote_code)

     if args.freeze_pretrained_weights:
         for param in model.base_model.parameters():
             param.requires_grad = False

-
+
     output_dir = os.path.join(
-        'results',
+        results_dir,
         dataset_name,
-        'lp',
-        f"{graph_dataset.config_hash}",
+        f"LM_{LINK_PRED_TASK}",
+        get_config_str(args)
     )
+    if os.path.exists(output_dir):
+        print(f"Output directory {output_dir} already exists. Exiting.")
+        exit(0)

     logs_dir = os.path.join(
         'logs',
         dataset_name,
-        'lp',
+        f"LM_{LINK_PRED_TASK}",
         f"{graph_dataset.config_hash}",
     )

+    logging_steps = get_logging_steps(
+        len(train_dataset),
+        args.num_epochs,
+        args.train_batch_size
+    )
+
+
     training_args = TrainingArguments(
         output_dir=output_dir,
         num_train_epochs=args.num_epochs,
@@ -123,13 +142,14 @@ def run(args):
         per_device_eval_batch_size=args.eval_batch_size,
         weight_decay=0.01,
         logging_dir=logs_dir,
-        logging_steps=200,
+        logging_steps=logging_steps,
         eval_strategy='steps',
-        eval_steps=200,
-        save_steps=200,
-        save_total_limit=2,
-        load_best_model_at_end=True,
+        eval_steps=logging_steps,
+        # save_steps=200,
+        # save_total_limit=2,
+        # load_best_model_at_end=True,
         fp16=True,
+        save_strategy="no"
     )

     trainer = Trainer(