glam4cm 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49):
  1. glam4cm/__init__.py +2 -1
  2. glam4cm/data_loading/data.py +90 -146
  3. glam4cm/data_loading/encoding.py +17 -6
  4. glam4cm/data_loading/graph_dataset.py +192 -57
  5. glam4cm/data_loading/metadata.py +1 -1
  6. glam4cm/data_loading/models_dataset.py +42 -18
  7. glam4cm/downstream_tasks/bert_edge_classification.py +49 -22
  8. glam4cm/downstream_tasks/bert_graph_classification.py +44 -14
  9. glam4cm/downstream_tasks/bert_graph_classification_comp.py +47 -24
  10. glam4cm/downstream_tasks/bert_link_prediction.py +46 -26
  11. glam4cm/downstream_tasks/bert_node_classification.py +127 -89
  12. glam4cm/downstream_tasks/cm_gpt_node_classification.py +61 -15
  13. glam4cm/downstream_tasks/common_args.py +32 -4
  14. glam4cm/downstream_tasks/gnn_edge_classification.py +24 -7
  15. glam4cm/downstream_tasks/gnn_graph_cls.py +19 -6
  16. glam4cm/downstream_tasks/gnn_link_prediction.py +25 -13
  17. glam4cm/downstream_tasks/gnn_node_classification.py +19 -7
  18. glam4cm/downstream_tasks/utils.py +16 -2
  19. glam4cm/embeddings/bert.py +1 -1
  20. glam4cm/embeddings/common.py +7 -4
  21. glam4cm/encoding/encoders.py +1 -1
  22. glam4cm/lang2graph/archimate.py +0 -5
  23. glam4cm/lang2graph/common.py +99 -41
  24. glam4cm/lang2graph/ecore.py +1 -2
  25. glam4cm/lang2graph/ontouml.py +8 -7
  26. glam4cm/models/gnn_layers.py +20 -6
  27. glam4cm/models/hf.py +2 -2
  28. glam4cm/run.py +13 -9
  29. glam4cm/run_conf_v2.py +405 -0
  30. glam4cm/run_configs.py +70 -106
  31. glam4cm/run_confs.py +41 -0
  32. glam4cm/settings.py +15 -2
  33. glam4cm/tokenization/special_tokens.py +23 -1
  34. glam4cm/tokenization/utils.py +23 -4
  35. glam4cm/trainers/cm_gpt_trainer.py +1 -1
  36. glam4cm/trainers/gnn_edge_classifier.py +12 -1
  37. glam4cm/trainers/gnn_graph_classifier.py +12 -5
  38. glam4cm/trainers/gnn_link_predictor.py +18 -3
  39. glam4cm/trainers/gnn_link_predictor_v2.py +146 -0
  40. glam4cm/trainers/gnn_trainer.py +8 -0
  41. glam4cm/trainers/metrics.py +1 -1
  42. glam4cm/utils.py +265 -2
  43. {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/METADATA +3 -2
  44. glam4cm-1.0.0.dist-info/RECORD +75 -0
  45. {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/WHEEL +1 -1
  46. glam4cm-0.1.0.dist-info/RECORD +0 -72
  47. {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/entry_points.txt +0 -0
  48. {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info/licenses}/LICENSE +0 -0
  49. {glam4cm-0.1.0.dist-info → glam4cm-1.0.0.dist-info}/top_level.txt +0 -0
glam4cm/run_conf_v2.py ADDED
@@ -0,0 +1,405 @@
1
+ import argparse
2
+ import os
3
+ import pandas as pd
4
+ import subprocess
5
+ from tqdm.auto import tqdm
6
+ from glam4cm.settings import (
7
+ GRAPH_CLS_TASK,
8
+ NODE_CLS_TASK,
9
+ LINK_PRED_TASK,
10
+ EDGE_CLS_TASK,
11
+ results_dir
12
+ )
13
+
14
def get_args():
    """Parse the command-line options of the sweep runner.

    Returns:
        argparse.Namespace with ``tasks`` (comma-separated string or None),
        ``start``/``end`` (ints, -1 by default), the boolean switches
        ``reload``/``run_lm``/``run_gnn``, ``min_distance``/``max_distance``
        (ints, defaults 0 and 3) and ``distances`` (string or None).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--tasks', type=str)

    # Slice bounds; -1 is the "not given" marker.
    for bound in ('--start', '--end'):
        cli.add_argument(bound, type=int, default=-1)

    # Plain on/off switches.
    for switch in ('--reload', '--run_lm', '--run_gnn'):
        cli.add_argument(switch, action='store_true')

    cli.add_argument('--min_distance', type=int, default=0)
    cli.add_argument('--max_distance', type=int, default=3)
    cli.add_argument('--distances', type=str, default=None)

    return cli.parse_args()
28
+
29
+ args = get_args()
30
+
31
+
32
# Per-dataset settings: node/edge label names to sweep over and extra
# command-line parameters appended to every command for that dataset.
dataset_confs = {
    name: {
        "node_cls_label": node_labels,
        "edge_cls_label": "type",
        "extra_params": {"num_epochs": 50, **extra},
    }
    for name, node_labels, extra in (
        ('eamodelset', ["type", "layer"], {}),
        ('ecore_555', ["abstract"], {}),
        ('modelset', ["abstract"], {}),
        ('ontouml', ["stereotype"], {'node_topk': 20}),
    )
}

# Keyed by the LM task id; each entry carries the LM batch size and the
# task id used for the paired GNN run.
task_configs = {
    lm_task_id: {
        "bert_config": {"train_batch_size": batch_size},
        "gnn_config": {"task_id": gnn_task_id},
    }
    for lm_task_id, batch_size, gnn_task_id in (
        (2, 2, 6),
        (3, 32, 7),
        (4, 64, 8),
        (5, 64, 9),
        (11, 1024, 9),
    )
}

# Dataset feature switches; the empty entry stands for "no extra switch".
dataset_updates = [
    "",
    "use_attributes",
    "use_node_types",
    "use_edge_label",
    "use_edge_types",
]

# Parameters passed to every GNN command.
gnn_conf = dict(lr=1e-3)

# GNN feature switches; the empty entry stands for "no extra switch".
gnn_updates = [
    "",
    "use_embeddings",
    "use_edge_attrs",
]

# GNN convolution layers to sweep over, with their model-specific parameters.
gnn_models = [
    dict(name="SAGEConv", params={}),
    dict(name="GATv2Conv", params={"num_heads": 4}),
]

# When true, GNN commands are generated alongside every LM command.
gnn_train = True
139
+
140
+
141
def cmd_to_dict(command_line):
    """Turn a whitespace-separated option string into a dictionary.

    ``--name=value`` tokens map ``name`` to the string ``value``; bare
    ``--name`` tokens map ``name`` to ``True``.

    Args:
        command_line: Option string such as ``"--task_id=7 --use_embeddings"``.

    Returns:
        dict mapping option names (without the leading ``--``) to their
        string value or ``True``. Later duplicates overwrite earlier ones.
    """
    options = {}
    for token in command_line.split():
        # partition() splits on the FIRST '=' only, so values that contain
        # '=' themselves (e.g. checkpoint paths) are kept whole.
        name, separator, value = token.partition('=')
        # Strip only the leading dashes; dashes inside the name are part of it.
        if name.startswith('--'):
            name = name[2:]
        options[name] = value if separator else True
    return options
146
+
147
+
148
def get_config_str(command_line):
    """Build the configuration suffix encoded by an option string.

    The suffix concatenates, in a fixed order, one short tag per enabled
    feature switch followed by ``_<value>`` for the node label, edge label
    and distance options when they are present.

    Args:
        command_line: Option string understood by ``cmd_to_dict``.

    Returns:
        The suffix string; empty when none of the known options are present.
    """
    options = cmd_to_dict(command_line)

    # (option name, tag) pairs; the order here fixes the order in the suffix.
    switch_tags = (
        ('use_attributes', "_attrs"),
        ('use_edge_label', "_el"),
        ('use_edge_types', "_et"),
        ('use_node_types', "_nt"),
        ('use_special_tokens', "_st"),
        ('no_labels', "_nolb"),
    )
    valued_options = ("node_cls_label", "edge_cls_label", "distance")

    pieces = [tag for option, tag in switch_tags if option in options]
    pieces.extend(f"_{options[option]}" for option in valued_options if option in options)
    return "".join(pieces)
171
+
172
+
173
+
174
def get_embed_model_name(command_line):
    """Return the results directory of the LM run matching a GNN command.

    The path is ``results_dir/<dataset>/<label>/<config suffix>``, where the
    label depends on the GNN task id found in ``command_line`` and the suffix
    comes from ``get_config_str``.

    Args:
        command_line: GNN option string; must contain ``--task_id`` and
            ``--dataset``, plus ``--node_cls_label`` (task 7) or
            ``--edge_cls_label`` (task 9).

    Returns:
        The path as a string. A message is printed when it does not exist,
        but the path is returned regardless.

    Raises:
        ValueError: If the task id is not one of 6, 7, 8 or 9.
        KeyError: If a required option is missing from ``command_line``.
    """
    options = cmd_to_dict(command_line)
    task_id = int(options['task_id'])

    if task_id == 6:
        label = f'LM_{GRAPH_CLS_TASK}/label'
    elif task_id == 7:
        label = f"LM_{NODE_CLS_TASK}/{options['node_cls_label']}"
    elif task_id == 8:
        label = f"LM_{LINK_PRED_TASK}"
    elif task_id == 9:
        label = f"LM_{EDGE_CLS_TASK}/{options['edge_cls_label']}"
    else:
        # Previously this fell through and failed later with an unbound
        # local variable; fail early with an explicit message instead.
        raise ValueError(
            f"Unsupported GNN task_id {task_id}; expected one of 6, 7, 8, 9"
        )

    model_name = os.path.join(
        results_dir,
        options['dataset'],
        label,
        get_config_str(command_line),
    )
    if not os.path.exists(model_name):
        print(model_name, os.path.exists(model_name), " does not exist")
    return model_name
196
+
197
+
198
def _run_and_log(script_command, df, log_file):
    """Run one ``glam_test.py`` command and append its outcome to the CSV log."""
    result = subprocess.run(f'python glam_test.py {script_command}', shell=True)

    # The child's stderr is not captured (it streams to the console), so the
    # exit code is the only failure information available here.
    if result.returncode == 0:
        status = 'success'
        print(f"✅ finished running command: {script_command}")
    else:
        status = f'❌ exit code {result.returncode}'
        print(f"❌ failed with exit code {result.returncode}: {script_command}")

    df.loc[len(df)] = [script_command, status]
    # Rewritten after every command so progress survives an interrupted sweep.
    df.to_csv(log_file, index=False)


def execute_configs(run_configs, tasks_str: str):
    """Run the LM and/or GNN commands of every configuration not yet logged.

    Configurations whose LM command already appears in the CSV log are
    skipped. The remaining ones are sliced by ``args.start``/``args.end``;
    LM commands run when ``args.run_lm`` is set and GNN commands when
    ``args.run_gnn`` is set. Every executed command and its status are
    appended to the log.

    Args:
        run_configs: List of dicts with an ``'lm'`` command string and an
            optional ``'gnn'`` list of command strings.
        tasks_str: Task-id string used in the log file name.
    """
    import json

    log_file = f"logs/run_configs_tasks_{tasks_str}_cmgpt.csv"
    # The log directory must exist before the first to_csv() call.
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    if os.path.exists(log_file):
        df = pd.read_csv(log_file)
    else:
        df = pd.DataFrame(columns=['Config', 'Status'])

    # LM commands are logged in the form that was actually executed (with
    # "train_batch_size" rewritten to "batch_size"), so check both spellings.
    logged = set(df['Config'].values)
    remaining_configs = {
        c['lm']: c.get('gnn', [])
        for c in run_configs
        if c['lm'] not in logged
        and c['lm'].replace("train_batch_size", "batch_size") not in logged
    }

    start = 0 if args.start == -1 else args.start
    end = len(remaining_configs) if args.end == -1 else args.end
    lm_script_commands = list(remaining_configs)[start:end]

    for lm_script_command in lm_script_commands:
        # GNN runs that use embeddings need the checkpoint of the matching LM run.
        remaining_configs[lm_script_command] = [
            gnn_script_command + ' --ckpt=' + get_embed_model_name(gnn_script_command)
            if 'use_embeddings' in gnn_script_command
            else gnn_script_command
            for gnn_script_command in remaining_configs[lm_script_command]
        ]
    print("\n".join([r for r in remaining_configs]))
    print("Total number of configurations: ", len(run_configs))
    print(f"Total number of remaining configurations: {len(remaining_configs)}")
    print("Total number of configurations to run: ", len(remaining_configs) + sum([len(v) for v in remaining_configs.values()]))
    print(json.dumps(remaining_configs, indent=2), len(remaining_configs))

    for lm_script_command in tqdm(lm_script_commands, desc=f'Running tasks: {start}-{end-1}'):
        if args.run_lm:
            # Keep the loop variable untouched: it is the dictionary key used
            # below to look up the GNN commands.
            lm_run_command = lm_script_command.replace("train_batch_size", "batch_size")
            print(f'Running LM --> {lm_run_command}')
            _run_and_log(lm_run_command, df, log_file)

        if args.run_gnn:
            for gnn_script_command in tqdm(remaining_configs[lm_script_command], desc='Running GNN'):
                print(f'Running GNN --> {gnn_script_command}')
                _run_and_log(gnn_script_command, df, log_file)
248
+
249
+
250
def get_run_configs(tasks):
    """Generate the LM command and paired GNN commands for every sweep point.

    For each task id, distance, dataset and node/edge label pair, one LM
    command string is built. When ``gnn_train`` is true, a list of GNN
    command strings (one per GNN model and per prefix of ``gnn_updates``)
    is attached to it.

    Args:
        tasks: Iterable of LM task ids; each must be a key of ``task_configs``.

    Returns:
        List of dicts ``{'lm': <command>, 'gnn': [<command>, ...]}``; the
        ``'gnn'`` key is present only when ``gnn_train`` is true.
    """

    def listify(value):
        # Label settings may be a single name or a list of names.
        return value if isinstance(value, list) else [value]

    def flagify(names):
        # Empty names become '' so the joined command keeps its exact spacing.
        return [f'--{name}' if name else '' for name in names]

    collected = list()
    for task_id in tasks:
        task_conf = task_configs[task_id]
        lm_head = [f'--task_id={task_id}']
        lm_head += [f'--{key}={val}' for key, val in task_conf['bert_config'].items()]
        lm_head += ['--reload'] if args.reload else []

        # An explicit --distances list wins over the min/max range.
        if args.distances:
            distances = [int(token) for token in args.distances.split(',')]
        else:
            distances = list(range(args.min_distance, args.max_distance + 1))

        for distance in distances:
            distance_part = [f'--distance={distance}']

            # Only the last prefix of dataset_updates (all switches) is used.
            for upto in range(len(dataset_updates))[-1:]:
                for dataset, dataset_conf in dataset_confs.items():
                    skip_for_task_2 = task_id == 2 and dataset not in ['ecore_555', 'modelset']
                    skip_for_tasks_4_5 = task_id in [4, 5] and dataset in ['ontouml']
                    if skip_for_task_2 or skip_for_tasks_4_5:
                        continue

                    dataset_part = [f'--dataset={dataset}']
                    dataset_part += [f'--{key}={val}' for key, val in dataset_conf['extra_params'].items()]
                    dataset_part += ['--min_edges=10']

                    node_labels = listify(dataset_conf['node_cls_label'])
                    if 'edge_cls_label' in dataset_conf:
                        edge_labels = listify(dataset_conf['edge_cls_label'])
                    else:
                        edge_labels = []

                    for node_label in node_labels:
                        for edge_label in edge_labels:
                            labels_part = [
                                f'--node_cls_label={node_label}',
                                f'--edge_cls_label={edge_label}',
                            ]

                            feature_part = flagify(dataset_updates[:upto + 1])
                            # Switches that are removed for specific datasets.
                            dropped_switches = {
                                'ontouml': ["--use_edge_label"],
                                'eamodelset': ["--use_edge_label", "--use_attributes"],
                            }.get(dataset, [])
                            for switch in dropped_switches:
                                if switch in feature_part:
                                    feature_part.remove(switch)

                            shared_tail = dataset_part + labels_part + feature_part + distance_part
                            entry = {'lm': " ".join(lm_head + shared_tail)}
                            collected.append(entry)

                            if not gnn_train:
                                continue

                            gnn_commands = list()
                            for gnn_model in gnn_models:
                                for gnn_upto in range(len(gnn_updates)):
                                    gnn_head = [
                                        f'--{key}={val}' if key else ''
                                        for key, val in task_conf['gnn_config'].items()
                                    ]
                                    gnn_head += ['--reload'] if args.reload else []

                                    model_part = [f'--gnn_conv_model={gnn_model["name"]}']
                                    model_part += [f'--{key}={val}' for key, val in gnn_model['params'].items()]
                                    model_part += [f'--{key}={val}' for key, val in gnn_conf.items()]

                                    command = " ".join(
                                        gnn_head
                                        + flagify(gnn_updates[:gnn_upto + 1])
                                        + model_part
                                        + shared_tail
                                    )
                                    # Override the LM batch size (if present) and epoch count for GNN runs.
                                    lm_batch_flag = f"--train_batch_size={task_conf['bert_config']['train_batch_size']}"
                                    command = command.replace(lm_batch_flag, "--train_batch_size=8")
                                    epochs_flag = f"--num_epochs={dataset_conf['extra_params']['num_epochs']}"
                                    command = command.replace(epochs_flag, "--num_epochs=200")
                                    gnn_commands.append(command)

                            entry['gnn'] = gnn_commands

    return collected
333
+
334
+
335
def get_remaining_configs(tasks_str, run_configs):
    """Load (or initialise) the run log for a set of tasks.

    Creates the ``logs`` directory when needed and reads
    ``logs/run_configs_tasks_<tasks_str>.csv`` if it exists; otherwise an
    empty log with ``Config`` and ``Status`` columns is created in memory.

    The former log-pruning logic that followed the return statement could
    never execute and has been removed; the configurations are passed
    through unchanged, exactly as before.

    Args:
        tasks_str: Task-id string used in the log file name.
        run_configs: Configurations to hand back to the caller.

    Returns:
        dict with ``'df'`` (the log DataFrame), ``'configs'`` (``run_configs``
        as given) and ``'log_file'`` (path of the CSV log).
    """
    os.makedirs('logs', exist_ok=True)
    log_file = f"logs/run_configs_tasks_{tasks_str}.csv"
    if os.path.exists(log_file):
        df = pd.read_csv(log_file)
    else:
        df = pd.DataFrame(columns=['Config', 'Status'])

    return {
        'df': df,
        'configs': run_configs,
        'log_file': log_file,
    }
394
+
395
+
396
def main():
    """Generate the run configurations for the requested tasks and execute them.

    Raises:
        SystemExit: If ``--tasks`` was not given on the command line.
    """
    if not args.tasks:
        # Without this guard the split below fails with an AttributeError on None.
        raise SystemExit(
            "No tasks given: pass --tasks as a comma-separated list of task ids, e.g. --tasks=2,3"
        )
    tasks = [int(i) for i in args.tasks.split(',')]

    run_configs = get_run_configs(tasks)
    # Execute the configurations; execute_configs records each outcome in the CSV log.
    execute_configs(run_configs, tasks_str="_".join(str(i) for i in tasks))


if __name__ == '__main__':
    main()
glam4cm/run_configs.py CHANGED
@@ -1,126 +1,90 @@
1
+ import argparse
2
+ import itertools
1
3
  import subprocess
2
-
3
4
  from tqdm.auto import tqdm
4
5
 
5
6
 
6
- tasks = {
7
- 0: 'Create Dataset',
8
-
9
- 1: 'BERT Graph Classification Comparison',
10
- 2: 'BERT Graph Classification',
11
- 3: 'BERT Node Classification',
12
- 4: 'BERT Link Prediction',
13
- 5: 'BERT Edge Classification',
14
-
15
-
16
- 6: 'GNN Graph Classification',
17
- 7: 'GNN Node Classification',
18
- 8: 'GNN Edge Classification',
19
- 9: 'GNN Link Prediction',
20
- }
21
7
 
22
8
  all_tasks = {
23
9
  1: [
24
10
  '--dataset=ecore_555 --num_epochs=5 --train_batch_size=2',
25
11
  '--dataset=modelset --num_epochs=10 --train_batch_size=2',
26
- ],
12
+ ],
13
+
14
+ 2: [
15
+ '--min_edges=10 --train_batch_size=2 --num_epochs=5',
16
+ ],
27
17
 
28
- 2: [
29
- '--dataset=ecore_555 --num_epochs=5 --min_edges=10 --train_batch_size=2',
30
- '--dataset=ecore_555 --num_epochs=5 --use_attributes --min_edges=10 --train_batch_size=2',
31
- '--dataset=ecore_555 --num_epochs=5 --use_edge_types --min_edges=10 --train_batch_size=2',
32
- '--dataset=ecore_555 --num_epochs=5 --use_attributes --use_edge_types --min_edges=10 --train_batch_size=2',
33
- '--dataset=modelset --num_epochs=10 --min_edges=10 --train_batch_size=2',
34
- '--dataset=modelset --num_epochs=10 --use_attributes --min_edges=10 --train_batch_size=2',
35
- '--dataset=modelset --num_epochs=10 --use_edge_types --min_edges=10 --train_batch_size=2',
36
- '--dataset=modelset --num_epochs=10 --use_attributes --use_edge_types --min_edges=10 --train_batch_size=2',
37
- ],
18
+ 3: [
19
+ '--min_edges=10 --train_batch_size=64 --distance=1 --num_epochs=10',
20
+ ],
38
21
 
39
- 3: [
40
- '--dataset=ecore_555 --num_epochs=5 --cls_label=abstract --min_edges=10 --train_batch_size=32',
41
- '--dataset=ecore_555 --num_epochs=5 --use_attributes --cls_label=abstract --train_batch_size=32 --min_edges=10',
42
- '--dataset=ecore_555 --num_epochs=5 --use_edge_types --cls_label=abstract --train_batch_size=32 --min_edges=10',
43
- '--dataset=ecore_555 --num_epochs=5 --use_attributes --use_edge_types --cls_label=abstract --train_batch_size=32 --min_edges=10',
44
- '--dataset=modelset --num_epochs=10 --cls_label=abstract --train_batch_size=32 --min_edges=10',
45
- '--dataset=modelset --num_epochs=10 --use_attributes --cls_label=abstract --train_batch_size=32 --min_edges=10',
46
- '--dataset=modelset --num_epochs=10 --use_edge_types --cls_label=abstract --train_batch_size=32 --min_edges=10',
47
- '--dataset=modelset --num_epochs=10 --use_attributes --use_edge_types --cls_label=abstract --train_batch_size=32 --min_edges=10',
48
-
49
- '--dataset=mar-ecore-github --num_epochs=10 --use_attributes --use_edge_types --cls_label=abstract --train_batch_size=32 --min_edges=10',
50
-
51
- '--dataset=eamodelset --num_epochs=15 --cls_label=type --train_batch_size=32 --min_edges=10',
52
- '--dataset=eamodelset --num_epochs=15 --use_edge_types --cls_label=type --train_batch_size=32 --min_edges=10',
53
- '--dataset=eamodelset --num_epochs=15 --cls_label=layer --train_batch_size=32 --min_edges=10',
54
- '--dataset=eamodelset --num_epochs=15 --use_edge_types --cls_label=layer --train_batch_size=32 --min_edges=10',
55
- ],
22
+ 4: [
23
+ "--min_edges=10 --distance=1 --train_batch_size=64 --num_epochs=5",
24
+ ],
56
25
 
57
- 4: [
58
- '--dataset=ecore_555 --num_epochs=3 --train_batch_size=32 --min_edges=10',
59
- '--dataset=ecore_555 --num_epochs=3 --use_attributes --train_batch_size=32 --min_edges=10',
60
- '--dataset=modelset --num_epochs=5 --train_batch_size=32 --min_edges=10 --reload',
61
- '--dataset=modelset --num_epochs=5 --use_attributes --train_batch_size=32 --min_edges=10 --reload',
62
-
63
- '--dataset=mar-ecore-github --num_epochs=5 --use_attributes --train_batch_size=32 --min_edges=10 --reload',
64
- '--dataset=eamodelset --num_epochs=5 --train_batch_size=32 --min_edges=10 --reload',
65
- ],
66
-
67
26
  5: [
68
- '--dataset=ecore_555 --num_epochs=5 --train_batch_size=32 --min_edges=10 --reload',
69
- '--dataset=ecore_555 --num_epochs=5 --use_attributes --train_batch_size=32 --min_edges=10 --reload',
70
- '--dataset=modelset --num_epochs=10 --train_batch_size=32 --min_edges=10 --reload',
71
- '--dataset=modelset --num_epochs=10 --use_attributes --train_batch_size=32 --min_edges=10 --reload',
72
- '--dataset=mar-ecore-github --num_epochs=10 --use_attributes --train_batch_size=32 --min_edges=10 --reload',
73
- '--dataset=eamodelset --num_epochs=15 --train_batch_size=32 --min_edges=10 --reload',
27
+ "--min_edges=10 --distance=1 --train_batch_size=64 --num_epochs=5",
74
28
  ],
75
- 6: [
76
- '--dataset=ecore_555 --num_epochs=200 --batch_size=32 --min_edges=10 --reload',
77
- '--dataset=ecore_555 --num_epochs=200 --batch_size=32 --min_edges=10 --use_embeddings --ckpt=results/ecore_555/graph_cls_/10_att_0_nt_0/checkpoint-225 --reload',
78
- '--dataset=ecore_555 --num_epochs=200 --batch_size=32 --min_edges=10 --use_embeddings --use_attributes --ckpt=results/ecore_555/graph_cls_/10_att_1_nt_0/checkpoint-225 --reload',
79
- '--dataset=ecore_555 --num_epochs=200 --batch_size=32 --min_edges=10 --use_embeddings --use_edge_types --ckpt=results/ecore_555/graph_cls_/10_att_0_nt_1/checkpoint-225 --reload',
80
- '--dataset=ecore_555 --num_epochs=200 --batch_size=32 --min_edges=10 --use_embeddings --use_attributes --use_edge_types --ckpt=results/ecore_555/graph_cls_/10_att_1_nt_1/checkpoint-225 --reload',
81
-
82
- '--dataset=modelset --num_epochs=200 --batch_size=32 --min_edges=10 --reload',
83
- '--dataset=modelset --num_epochs=200 --batch_size=32 --min_edges=10 --use_embeddings --ckpt=results/modelset/graph_cls_/10_att_0_nt_0/checkpoint-2540 --reload',
84
- '--dataset=modelset --num_epochs=200 --batch_size=32 --min_edges=10 --use_embeddings --use_attributes --ckpt=results/modelset/graph_cls_/10_att_1_nt_0/checkpoint-2540 --reload',
85
- '--dataset=modelset --num_epochs=200 --batch_size=32 --min_edges=10 --use_embeddings --use_edge_types --ckpt=results/modelset/graph_cls_/10_att_0_nt_1/checkpoint-2540 --reload',
86
- '--dataset=modelset --num_epochs=200 --batch_size=32 --min_edges=10 --use_embeddings --use_attributes --use_edge_types --ckpt=results/modelset/graph_cls_/10_att_1_nt_1/checkpoint-2540 --reload',
87
- ],
88
- 7: [
89
- '--dataset=ecore_555 --num_epochs=200 --batch_size=32 --min_edges=10 --cls_label=abstract --reload',
90
- '--dataset=ecore_555 --num_epochs=200 --batch_size=32 --min_edges=10 --cls_label=abstract --use_embeddings --ckpt=results/ecore_555/node_cls/abstract/abstract_10_att_0_nt_0/checkpoint-540 --reload',
91
- '--dataset=ecore_555 --num_epochs=200 --batch_size=32 --min_edges=10 --cls_label=abstract --use_embeddings --use_attributes --ckpt=results/ecore_555/node_cls/abstract/abstract_10_att_1_nt_0/checkpoint-540 --reload',
92
- '--dataset=ecore_555 --num_epochs=200 --batch_size=32 --min_edges=10 --cls_label=abstract --use_embeddings --use_edge_types --ckpt=results/ecore_555/node_cls/abstract/abstract_10_att_0_nt_1/checkpoint-540 --reload',
93
- '--dataset=ecore_555 --num_epochs=200 --batch_size=32 --min_edges=10 --cls_label=abstract --use_embeddings --use_attributes --use_edge_types --ckpt=results/ecore_555/node_cls/abstract/abstract_10_att_1_nt_1/checkpoint-540 --reload',
94
-
95
-
96
- '--dataset=modelset --num_epochs=200 --batch_size=32 --min_edges=10 --cls_label=abstract --reload',
97
- '--dataset=modelset --num_epochs=200 --batch_size=32 --min_edges=10 --cls_label=abstract --use_embeddings --ckpt=results/modelset/node_cls/abstract/abstract_10_att_0_nt_0/checkpoint-6870 --reload',
98
- '--dataset=modelset --num_epochs=200 --batch_size=32 --min_edges=10 --cls_label=abstract --use_embeddings --use_attributes --ckpt=results/modelset/node_cls/abstract/abstract_10_att_1_nt_0/checkpoint-6870 --reload',
99
- '--dataset=modelset --num_epochs=200 --batch_size=32 --min_edges=10 --cls_label=abstract --use_embeddings --use_edge_types --ckpt=results/modelset/node_cls/abstract/abstract_10_att_0_nt_1/checkpoint-6870 --reload',
100
- '--dataset=modelset --num_epochs=200 --batch_size=32 --min_edges=10 --cls_label=abstract --use_embeddings --use_attributes --use_edge_types --ckpt=results/modelset/node_cls/abstract/abstract_10_att_1_nt_1/checkpoint-6870 --reload',
101
-
102
- '--dataset=mar-ecore-github --num_epochs=200 --batch_size=32 --min_edges=10 --cls_label=abstract --reload',
103
- '--dataset=mar-ecore-github --num_epochs=200 --batch_size=32 --min_edges=10 --cls_label=abstract --use_embeddings --ckpt=results/mar-ecore-github/node_cls/abstract/abstract_10_att_0_nt_0/checkpoint-19400 --reload',
104
- '--dataset=mar-ecore-github --num_epochs=200 --batch_size=32 --min_edges=10 --cls_label=abstract --use_embeddings --use_attributes --ckpt=results/mar-ecore-github/node_cls/abstract/abstract_10_att_1_nt_0/checkpoint-19400 --reload',
105
- '--dataset=mar-ecore-github --num_epochs=200 --batch_size=32 --min_edges=10 --cls_label=abstract --use_embeddings --use_edge_types --ckpt=results/mar-ecore-github/node_cls/abstract/abstract_10_att_0_nt_1/checkpoint-19400 --reload',
106
- '--dataset=mar-ecore-github --num_epochs=200 --batch_size=32 --min_edges=10 --cls_label=abstract --use_embeddings --use_attributes --use_edge_types --ckpt=results/mar-ecore-github/node_cls/abstract/abstract_10_att_1_nt_1/checkpoint-19400 --reload',
29
+ }
107
30
 
108
- '--dataset=eamodelset --num_epochs=200 --batch_size=32 --min_edges=10 --cls_label=type --reload',
109
- '--dataset=eamodelset --num_epochs=200 --batch_size=32 --min_edges=10 --cls_label=type --use_embeddings --ckpt=results/eamodelset/node_cls/layer/layer_10_att_0_nt_0/checkpoint-9570 --reload',
110
- '--dataset=eamodelset --num_epochs=200 --batch_size=32 --min_edges=10 --cls_label=type --use_embeddings --use_edge_types --ckpt=results/eamodelset/node_cls/layer/layer_10_att_0_nt_1/checkpoint-9570 --reload',
31
+ dataset_confs = {
32
+ 'ecore_555': {
33
+ "node_cls_label": ["abstract"],
34
+ "edge_cls_label": "type",
35
+ },
36
+ 'modelset': {
37
+ "node_cls_label": ["abstract"],
38
+ "edge_cls_label": "type",
39
+ },
40
+ # 'mar-ecore-github': {
41
+ # "node_cls_label": ["abstract"],
42
+ # "edge_cls_label": "type",
43
+ # },
44
+ 'eamodelset': {
45
+ "node_cls_label": ["type", "layer"],
46
+ "edge_cls_label": "type",
47
+ },
48
+ }
111
49
 
112
- '--dataset=eamodelset --num_epochs=200 --batch_size=32 --min_edges=10 --cls_label=type --reload',
113
- '--dataset=eamodelset --num_epochs=200 --batch_size=32 --min_edges=10 --cls_label=type --use_embeddings --ckpt=results/eamodelset/node_cls/type/type_10_att_0_nt_0/checkpoint-9570 --reload',
114
- '--dataset=eamodelset --num_epochs=200 --batch_size=32 --min_edges=10 --cls_label=type --use_embeddings --use_edge_types --ckpt=results/eamodelset/node_cls/type/type_10_att_0_nt_1/checkpoint-9570 --reload',
115
- ]
50
+ param_configs = {
51
+ 'use_attributes': [0, 1],
52
+ 'use_node_types': [0, 1],
53
+ 'use_edge_types': [0, 1],
54
+ 'use_edge_label': [0, 1],
116
55
  }
117
56
 
118
- allowed_tasks = [7]
57
def get_args():
    """Parse the command line; the only option is ``--tasks`` (a string)."""
    cli = argparse.ArgumentParser()
    cli.add_argument('--tasks', type=str)
    return cli.parse_args()
62
+
119
63
 
120
- for script_id in tqdm(allowed_tasks, desc='Running tasks'):
121
- task = tasks[script_id]
122
- for script in tqdm(all_tasks[script_id], desc=f'Running scripts for {task}'):
123
- script += f' --task={script_id} '
124
- print(f'Running {script}')
64
# Task ids requested on the command line (comma-separated).
tasks = [int(i) for i in get_args().tasks.split(',')]
run_configs = list()

for task_id in all_tasks:
    if task_id not in tasks:
        continue

    for task_str in all_tasks[task_id]:
        # Prefix the task id once per template. Rebinding the loop variable
        # inside the dataset loop used to prepend another "--task_id=..."
        # for every dataset after the first.
        base_task_str = f'--task_id={task_id} ' + task_str

        for dataset, dataset_conf in dataset_confs.items():
            if task_id == 2 and dataset not in ['ecore_555', 'modelset']:
                continue

            # Label settings may be a single name or a list of names.
            node_cls_labels = dataset_conf['node_cls_label'] if isinstance(dataset_conf['node_cls_label'], list) else [dataset_conf['node_cls_label']]
            edge_cls_labels = dataset_conf['edge_cls_label'] if isinstance(dataset_conf['edge_cls_label'], list) else [dataset_conf['edge_cls_label']]

            for node_cls_label, edge_cls_label in itertools.product(node_cls_labels, edge_cls_labels):
                # One command per on/off combination of the feature switches.
                for params in itertools.product(*param_configs.values()):
                    config = dict(zip(param_configs.keys(), params))
                    config_task_str = base_task_str + f' --dataset={dataset} --node_cls_label={node_cls_label} --edge_cls_label={edge_cls_label} ' + ' '.join([f'--{k}' if v else '' for k, v in config.items()])
                    run_configs.append(config_task_str)
86
+
125
87
 
126
- subprocess.run(f'python run.py {script}', shell=True)
88
+ for script_command in tqdm(run_configs, desc='Running tasks'):
89
+ print(f'Running {script_command}')
90
+ subprocess.run(f'python src/glam4cm/run.py {script_command}', shell=True)
glam4cm/run_confs.py ADDED
@@ -0,0 +1,41 @@
1
+ import argparse
2
+ from tqdm.auto import tqdm
3
+ import subprocess
4
+
5
+
6
def get_args():
    """Parse the command line; the only option is ``--confs_file`` (a string)."""
    cli = argparse.ArgumentParser()
    cli.add_argument('--confs_file', type=str, help='File containing configurations to run')
    return cli.parse_args()
11
+
12
def run_tasks(configs):
    """Run ``glam_test.py`` once per configuration string, in order.

    A failing run is reported and the remaining configurations still run.

    Args:
        configs: Iterable of command-line option strings.
    """
    for config in tqdm(configs, desc="Running tasks"):
        print(f"Running config: {config}")
        result = subprocess.run(f'python glam_test.py {config}', shell=True)
        if result.returncode != 0:
            # The child's stderr is not captured (it streams to the console),
            # so the exit code is what can be reported here.
            print(f"Error running config {config}: exit code {result.returncode}")
        else:
            print(f"Config {config} completed successfully.")
    print("All tasks completed.")
21
+
22
def main():
    """Read configurations from ``--confs_file`` (one per line) and run them.

    Blank lines are ignored. Problems with the file itself are reported
    with a message instead of a traceback.
    """
    args = get_args()
    if not args.confs_file:
        print("No configuration file specified. Exiting.")
        return

    # Keep the try body limited to the file access it is meant to guard.
    try:
        with open(args.confs_file, 'r') as f:
            # Skip blank lines so they are not launched as empty runs.
            configs = [line.strip() for line in f if line.strip()]
    except FileNotFoundError:
        print(f"Configuration file {args.confs_file} not found. Exiting.")
        return
    except OSError as e:
        print(f"An error occurred: {e}. Exiting.")
        return

    if not configs:
        print("Configuration file is empty. Exiting.")
        return

    print(f"Found {len(configs)} configurations to run.")
    run_tasks(configs)


if __name__ == "__main__":
    main()