glam4cm 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- glam4cm/__init__.py +9 -0
- glam4cm/data_loading/__init__.py +0 -0
- glam4cm/data_loading/data.py +631 -0
- glam4cm/data_loading/encoding.py +76 -0
- glam4cm/data_loading/graph_dataset.py +940 -0
- glam4cm/data_loading/metadata.py +84 -0
- glam4cm/data_loading/models_dataset.py +361 -0
- glam4cm/data_loading/utils.py +20 -0
- glam4cm/downstream_tasks/__init__.py +0 -0
- glam4cm/downstream_tasks/bert_edge_classification.py +144 -0
- glam4cm/downstream_tasks/bert_graph_classification.py +137 -0
- glam4cm/downstream_tasks/bert_graph_classification_comp.py +156 -0
- glam4cm/downstream_tasks/bert_link_prediction.py +145 -0
- glam4cm/downstream_tasks/bert_node_classification.py +164 -0
- glam4cm/downstream_tasks/cm_gpt_edge_classification.py +73 -0
- glam4cm/downstream_tasks/cm_gpt_node_classification.py +76 -0
- glam4cm/downstream_tasks/cm_gpt_pretraining.py +64 -0
- glam4cm/downstream_tasks/common_args.py +160 -0
- glam4cm/downstream_tasks/create_dataset.py +51 -0
- glam4cm/downstream_tasks/gnn_edge_classification.py +106 -0
- glam4cm/downstream_tasks/gnn_graph_cls.py +101 -0
- glam4cm/downstream_tasks/gnn_link_prediction.py +109 -0
- glam4cm/downstream_tasks/gnn_node_classification.py +103 -0
- glam4cm/downstream_tasks/tf_idf_text_classification.py +22 -0
- glam4cm/downstream_tasks/utils.py +35 -0
- glam4cm/downstream_tasks/word2vec_text_classification.py +108 -0
- glam4cm/embeddings/__init__.py +0 -0
- glam4cm/embeddings/bert.py +72 -0
- glam4cm/embeddings/common.py +43 -0
- glam4cm/embeddings/fasttext.py +0 -0
- glam4cm/embeddings/tfidf.py +25 -0
- glam4cm/embeddings/w2v.py +41 -0
- glam4cm/encoding/__init__.py +0 -0
- glam4cm/encoding/common.py +0 -0
- glam4cm/encoding/encoders.py +100 -0
- glam4cm/graph2str/__init__.py +0 -0
- glam4cm/graph2str/common.py +34 -0
- glam4cm/graph2str/constants.py +15 -0
- glam4cm/graph2str/ontouml.py +141 -0
- glam4cm/graph2str/uml.py +0 -0
- glam4cm/lang2graph/__init__.py +0 -0
- glam4cm/lang2graph/archimate.py +31 -0
- glam4cm/lang2graph/bpmn.py +0 -0
- glam4cm/lang2graph/common.py +416 -0
- glam4cm/lang2graph/ecore.py +221 -0
- glam4cm/lang2graph/ontouml.py +169 -0
- glam4cm/lang2graph/utils.py +80 -0
- glam4cm/models/cmgpt.py +352 -0
- glam4cm/models/gnn_layers.py +273 -0
- glam4cm/models/hf.py +10 -0
- glam4cm/run.py +99 -0
- glam4cm/run_configs.py +126 -0
- glam4cm/settings.py +54 -0
- glam4cm/tokenization/__init__.py +0 -0
- glam4cm/tokenization/special_tokens.py +4 -0
- glam4cm/tokenization/utils.py +37 -0
- glam4cm/trainers/__init__.py +0 -0
- glam4cm/trainers/bert_classifier.py +105 -0
- glam4cm/trainers/cm_gpt_trainer.py +153 -0
- glam4cm/trainers/gnn_edge_classifier.py +126 -0
- glam4cm/trainers/gnn_graph_classifier.py +123 -0
- glam4cm/trainers/gnn_link_predictor.py +144 -0
- glam4cm/trainers/gnn_node_classifier.py +135 -0
- glam4cm/trainers/gnn_trainer.py +129 -0
- glam4cm/trainers/metrics.py +55 -0
- glam4cm/utils.py +194 -0
- glam4cm-0.1.0.dist-info/LICENSE +21 -0
- glam4cm-0.1.0.dist-info/METADATA +86 -0
- glam4cm-0.1.0.dist-info/RECORD +72 -0
- glam4cm-0.1.0.dist-info/WHEEL +5 -0
- glam4cm-0.1.0.dist-info/entry_points.txt +2 -0
- glam4cm-0.1.0.dist-info/top_level.txt +1 -0
glam4cm/data_loading/metadata.py
@@ -0,0 +1,84 @@

class GraphMetadata:
    # Subclasses populate self.node, self.edge and self.graph with mappings
    # from generic keys ('label', 'cls', 'attributes') to the attribute
    # names used by the concrete modeling language.
    def __init__(self, model_type):
        self.type = model_type

    @property
    def node_label(self):
        return self.node.get('label', None)

    @property
    def node_cls(self):
        return self.node.get('cls', None)

    @property
    def node_attributes(self):
        return self.node.get('attributes', None)

    @property
    def edge_label(self):
        return self.edge.get('label', None)

    @property
    def edge_cls(self):
        return self.edge.get('cls', None)

    @property
    def graph_cls(self):
        return self.graph.get('cls', None)

    @property
    def graph_label(self):
        return self.graph.get('label', None)


class EcoreMetaData(GraphMetadata):
    def __init__(self):
        super().__init__('ecore')
        self.node = {
            "label": "name",
            "cls": "abstract",
            "attributes": "attributes"
        }
        self.edge = {
            "label": "name",
            "cls": "type"
        }
        self.graph = {
            "label": "text",
            "cls": "label"
        }


class ArchimateMetaData(GraphMetadata):
    def __init__(self):
        super().__init__('archimate')
        self.node = {
            "label": "name",
            "cls": ["type", "layer"],
        }
        self.edge = {
            "cls": "type"
        }

        self.graph = {
            "label": "text",
        }


class OntoUMLMetaData(GraphMetadata):
    def __init__(self):
        super().__init__('ontouml')
        self.node = {
            "label": "name",
            "cls": ["stereotype"],
            "attributes": "properties"
        }
        self.edge = {
            "cls": "stereotype"
        }

        self.graph = {
            "label": "text",
        }
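
The three subclasses differ only in which model attribute backs each generic key, so downstream code can stay format-agnostic. A minimal usage sketch (the toy node dict is illustrative, not taken from the package):

from glam4cm.data_loading.metadata import EcoreMetaData, ArchimateMetaData

meta = EcoreMetaData()
print(meta.type)        # 'ecore'
print(meta.node_label)  # 'name'     -> nodes are labeled by their 'name' attribute
print(meta.node_cls)    # 'abstract' -> the node-classification target for Ecore
print(meta.graph_cls)   # 'label'

# A dataset builder can resolve a node's label without knowing the format:
node = {'name': 'Order', 'abstract': False}  # toy node dict
print(node.get(meta.node_label))             # 'Order'

# ArchiMate exposes two candidate node-classification targets:
print(ArchimateMetaData().node_cls)          # ['type', 'layer']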
glam4cm/data_loading/models_dataset.py
@@ -0,0 +1,361 @@

from typing import Callable, List
import pandas as pd
from tqdm.auto import tqdm
import pickle
from random import shuffle
from sklearn.model_selection import StratifiedKFold
import json
import os
from glam4cm.data_loading.encoding import EncodingDataset
from glam4cm.lang2graph.archimate import ArchiMateNxG
from glam4cm.lang2graph.ecore import EcoreNxG
from glam4cm.lang2graph.common import LangGraph
from glam4cm.lang2graph.ontouml import OntoUMLNxG
from glam4cm.settings import datasets_dir, logger, seed
import numpy as np


class ModelDataset:
    def __init__(
        self,
        dataset_name: str,
        dataset_dir=datasets_dir,
        save_dir='datasets/pickles',
        min_edges: int = -1,
        min_enr: float = -1,
        timeout=-1,
        preprocess_graph_text: Callable = None
    ):
        self.name = dataset_name
        self.dataset_dir = dataset_dir
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)

        self.min_edges = min_edges
        self.min_enr = min_enr
        self.timeout = timeout
        self.preprocess_graph_text = preprocess_graph_text

        self.graphs: List[LangGraph] = []

    def get_train_test_split(self, train_size=0.8):
        n = len(self.graphs)
        n_train = int(n * train_size)
        idx = list(range(n))
        shuffle(idx)
        return idx[:n_train], idx[n_train:]

    def k_fold_split(self, k=10):
        # All labels are zero here, so the stratified split degenerates to a
        # plain shuffled k-fold over the graphs.
        kfold = StratifiedKFold(n_splits=k, shuffle=True, random_state=seed)
        n = len(self.graphs)
        for train_idx, test_idx in kfold.split(np.zeros(n), np.zeros(n)):
            yield train_idx, test_idx

    @property
    def data(self):
        X, y = [], []
        for g in self.graphs:
            X.append(g.text)
            y.append(g.label)

        if self.preprocess_graph_text:
            X = [self.preprocess_graph_text(x) for x in X]
        return X, y

    def __get_lm_data(self, train_idx, test_idx, tokenizer, remove_duplicates=False):
        X, y = self.data
        # sorted() gives a deterministic label-to-index encoding across runs
        y_enc = {label: i for i, label in enumerate(sorted(set(y)))}
        y = [y_enc[label] for label in y]
        X_train, y_train = [X[i] for i in train_idx], [y[i] for i in train_idx]
        X_test, y_test = [X[i] for i in test_idx], [y[i] for i in test_idx]
        train_dataset = EncodingDataset(tokenizer, X_train, y_train, remove_duplicates=remove_duplicates)
        test_dataset = EncodingDataset(tokenizer, X_test, y_test, remove_duplicates=remove_duplicates)
        num_classes = len(set(y))
        return {
            'train': train_dataset,
            'test': test_dataset,
            'num_classes': num_classes
        }

    def get_graph_classification_data(self, tokenizer, remove_duplicates=False):
        train_idx, test_idx = self.get_train_test_split()
        return self.__get_lm_data(train_idx, test_idx, tokenizer, remove_duplicates=remove_duplicates)

    def get_graph_classification_data_kfold(self, tokenizer, k=10, remove_duplicates=False):
        for train_idx, test_idx in self.k_fold_split(k=k):
            yield self.__get_lm_data(train_idx, test_idx, tokenizer, remove_duplicates=remove_duplicates)

    def __repr__(self):
        return f'Dataset({self.name}, graphs={len(self.graphs)})'

    def __getitem__(self, key) -> LangGraph:
        return self.graphs[key]

    def __iter__(self):
        return iter(self.graphs)

    def __len__(self):
        return len(self.graphs)

    def save(self):
        print(f'Saving {self.name} to pickle')
        with open(os.path.join(self.save_dir, f'{self.name}.pkl'), 'wb') as f:
            pickle.dump(self.graphs, f)
        print(f'Saved {self.name} to pickle')

    def filter_graphs(self):
        # Drop graphs below the configured edge count or edge-to-node ratio.
        graphs = list()
        for graph in self.graphs:
            addable = True
            if self.min_edges > 0 and graph.number_of_edges() < self.min_edges:
                addable = False
            if self.min_enr > 0 and graph.enr < self.min_enr:
                addable = False

            if addable:
                graphs.append(graph)

        self.graphs = graphs

    def load(self):
        print(f'Loading {self.name} from pickle')
        with open(os.path.join(self.save_dir, f'{self.name}.pkl'), 'rb') as f:
            self.graphs = pickle.load(f)

        self.filter_graphs()
        print(f'Loaded {self.name} with {len(self.graphs)} graphs')

    @property
    def summary(self):
        num_graphs = len(self.graphs)
        num_edges = sum(g.number_of_edges() for g in self.graphs)
        num_nodes = sum(g.number_of_nodes() for g in self.graphs)
        average_nodes = num_nodes / num_graphs
        average_edges = num_edges / num_graphs
        average_n2e_ratio = np.mean([g.number_of_nodes() / g.number_of_edges() for g in self.graphs])
        return {
            'num_graphs': num_graphs,
            'num_edges': num_edges,
            'num_nodes': num_nodes,
            'average_nodes': f"{average_nodes:.2f}",
            'average_edges': f"{average_edges:.2f}",
            'average_n2e_ratio': f"{average_n2e_ratio:.2f}"
        }


class EcoreDataset(ModelDataset):
    def __init__(
        self,
        dataset_name: str,
        dataset_dir=datasets_dir,
        save_dir='datasets/pickles',
        reload=False,
        remove_duplicates=False,
        min_edges: int = -1,
        min_enr: float = -1,
        preprocess_graph_text: Callable = None
    ):
        super().__init__(
            dataset_name,
            dataset_dir=dataset_dir,
            save_dir=save_dir,
            min_edges=min_edges,
            min_enr=min_enr,
            preprocess_graph_text=preprocess_graph_text
        )
        os.makedirs(save_dir, exist_ok=True)

        dataset_exists = os.path.exists(os.path.join(save_dir, f'{dataset_name}.pkl'))
        if reload or not dataset_exists:
            self.graphs: List[EcoreNxG] = []
            data_path = os.path.join(dataset_dir, dataset_name)
            for file in os.listdir(data_path):
                if file.endswith('.jsonl') and file.startswith('ecore'):
                    with open(os.path.join(data_path, file)) as f:
                        json_objects = json.load(f)
                    for g in tqdm(json_objects, desc=f'Loading {dataset_name.title()}'):
                        if remove_duplicates and g['is_duplicated']:
                            continue
                        self.graphs.append(EcoreNxG(g))

            print(f'Loaded {self.name} with {len(self.graphs)} graphs in total')
            print("Filtering...")
            self.save()  # persist the unfiltered graphs; load() re-applies the filters
            self.filter_graphs()
        else:
            self.load()

        # if remove_duplicates:
        #     self.dedup()

        logger.info(f'Loaded {self.name} with {len(self.graphs)} graphs')
        print(f'Loaded {self.name} with {len(self.graphs)} graphs')

    def dedup(self) -> List[EcoreNxG]:
        logger.info(f'Deduplicating {self.name}')
        return [g for g in self.graphs if not g.is_duplicated]

    def __repr__(self):
        return f"EcoreDataset({self.name}, graphs={len(self.graphs)})"


class ArchiMateDataset(ModelDataset):
    def __init__(
        self,
        dataset_name: str,
        dataset_dir=datasets_dir,
        save_dir='datasets/pickles',
        reload=False,
        remove_duplicates=False,
        min_edges: int = -1,
        min_enr: float = -1,
        timeout=-1,
        language=None,
        preprocess_graph_text: Callable = None
    ):
        super().__init__(
            dataset_name,
            dataset_dir=dataset_dir,
            save_dir=save_dir,
            min_edges=min_edges,
            min_enr=min_enr,
            timeout=timeout,
            preprocess_graph_text=preprocess_graph_text
        )
        os.makedirs(save_dir, exist_ok=True)

        dataset_exists = os.path.exists(os.path.join(save_dir, f'{dataset_name}.pkl'))
        if reload or not dataset_exists:
            self.graphs: List[ArchiMateNxG] = []
            data_path = os.path.join(dataset_dir, dataset_name, 'processed-models')
            if language:
                df = pd.read_csv(os.path.join(dataset_dir, dataset_name, f'{language}-metadata.csv'))
                model_dirs = df['ID'].to_list()
            else:
                model_dirs = os.listdir(data_path)

            for model_dir in tqdm(model_dirs, desc=f'Loading {dataset_name.title()}'):
                model_dir = os.path.join(data_path, model_dir)
                if os.path.isdir(model_dir):
                    model_file = os.path.join(model_dir, 'model.json')
                    if os.path.exists(model_file):
                        with open(model_file) as f:
                            model = json.load(f)
                        # parsing errors are deliberately not swallowed here
                        nxg = ArchiMateNxG(
                            model,
                            path=model_file,
                            timeout=timeout
                        )
                        if nxg.number_of_edges() < 1:
                            continue
                        self.graphs.append(nxg)

            self.filter_graphs()
            self.save()
        else:
            self.load()

        if remove_duplicates:
            self.graphs = self.dedup()

        print(f'Loaded {self.name} with {len(self.graphs)} graphs')

    def dedup(self) -> List[ArchiMateNxG]:
        # keep one graph per unique edge-set signature
        return list({str(g.edges(data=True)): g for g in self.graphs}.values())

    def __repr__(self):
        return f"ArchiMateDataset({self.name}, graphs={len(self.graphs)})"


class OntoUMLDataset(ModelDataset):
    def __init__(
        self,
        dataset_name: str,
        dataset_dir=datasets_dir,
        save_dir='datasets/pickles',
        reload=False,
        remove_duplicates=False,
        min_edges: int = -1,
        min_enr: float = -1,
        timeout=-1,
        preprocess_graph_text: Callable = None
    ):
        super().__init__(
            dataset_name,
            dataset_dir=dataset_dir,
            save_dir=save_dir,
            min_edges=min_edges,
            min_enr=min_enr,
            timeout=timeout,
            preprocess_graph_text=preprocess_graph_text
        )
        os.makedirs(save_dir, exist_ok=True)

        dataset_exists = os.path.exists(os.path.join(save_dir, f'{dataset_name}.pkl'))
        if reload or not dataset_exists:
            self.graphs: List[OntoUMLNxG] = []
            data_path = os.path.join(dataset_dir, dataset_name, 'models')
            model_dirs = os.listdir(data_path)

            for model_dir in tqdm(model_dirs, desc=f'Loading {dataset_name.title()}'):
                model_dir = os.path.join(data_path, model_dir)
                if os.path.isdir(model_dir):
                    model_file = os.path.join(model_dir, 'ontology.json')
                    if os.path.exists(model_file):
                        with open(model_file, encoding='iso-8859-1') as f:
                            model = json.load(f)
                        try:
                            nxg = OntoUMLNxG(model)
                            if nxg.number_of_edges() < 1:
                                continue
                            self.graphs.append(nxg)
                        except Exception as e:
                            print(f"Error in {model_file}: {e}")

            self.filter_graphs()
            self.save()
        else:
            self.load()

        if remove_duplicates:
            self.graphs = self.dedup()

        print(f'Loaded {self.name} with {len(self.graphs)} graphs')

    def dedup(self) -> List[OntoUMLNxG]:
        # keep one graph per unique edge-set signature
        return list({str(g.edges(data=True)): g for g in self.graphs}.values())

    def __repr__(self):
        return f"OntoUMLDataset({self.name}, graphs={len(self.graphs)})"
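
For orientation, a hedged usage sketch of the dataset classes above; the dataset name 'modelset' and the on-disk layout under datasets_dir are assumptions, and the returned 'train'/'test' entries are EncodingDataset instances:

from transformers import AutoTokenizer
from glam4cm.data_loading.models_dataset import EcoreDataset

# Assumes datasets_dir/modelset/ contains the ecore*.jsonl dumps; the
# dataset name is illustrative, not guaranteed to ship with the package.
dataset = EcoreDataset('modelset', min_edges=10, remove_duplicates=True)
print(dataset.summary)  # graph/node/edge counts and averages

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
for fold in dataset.get_graph_classification_data_kfold(tokenizer, k=5):
    train_ds, test_ds = fold['train'], fold['test']
    print(fold['num_classes'])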
glam4cm/data_loading/utils.py
@@ -0,0 +1,20 @@

import numpy as np


def oversample_dataset(dataset, oversampling_ratio=0.7):
    """
    Oversample the classes that occur less frequently in the dataset.
    The occurrence of each class is counted, and each class is oversampled
    by `oversampling_ratio` (70% by default) of the difference between the
    most common class's count and its own.
    """
    class_occurences = dataset[:]['labels'].numpy()
    unique_classes, counts = np.unique(class_occurences, return_counts=True)
    max_count = counts.max()
    indices_with_oversamples = []
    for class_idx, count in zip(unique_classes, counts):
        class_indices = np.where(class_occurences == class_idx)[0]
        indices_with_oversamples.extend(class_indices)
        oversample_count = int(oversampling_ratio * (max_count - count))
        # np.random.choice samples with replacement, so small classes can repeat
        indices_with_oversamples.extend(np.random.choice(class_indices, oversample_count))

    return indices_with_oversamples
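
To make the docstring concrete, a toy calculation of the oversampling count (the labels are made up for illustration):

import numpy as np

labels = np.array([0] * 100 + [1] * 20)  # class 0: 100 samples, class 1: 20
_, counts = np.unique(labels, return_counts=True)
# class 1 gains int(0.7 * (100 - 20)) = 56 duplicated indices,
# growing from 20 to 76 samples; the majority class gains none.
extra = int(0.7 * (counts.max() - counts.min()))
print(extra)  # 56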
File without changes
glam4cm/downstream_tasks/bert_edge_classification.py
@@ -0,0 +1,144 @@

import os
from transformers import TrainingArguments, Trainer
from glam4cm.data_loading.graph_dataset import GraphEdgeDataset
from glam4cm.data_loading.utils import oversample_dataset
from glam4cm.settings import LP_TASK_EDGE_CLS
from glam4cm.downstream_tasks.common_args import get_bert_args_parser, get_common_args_parser, get_config_params
from glam4cm.models.hf import get_model
from glam4cm.downstream_tasks.utils import get_models_dataset

from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    balanced_accuracy_score
)

from glam4cm.tokenization.utils import get_tokenizer
from glam4cm.utils import merge_argument_parsers, set_seed


def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    f1_macro = f1_score(labels, preds, average='macro')
    accuracy = accuracy_score(labels, preds)
    precision = precision_score(labels, preds, average='macro')
    recall = recall_score(labels, preds, average='macro')
    balanced_acc = balanced_accuracy_score(labels, preds)

    return {
        'balanced_accuracy': balanced_acc,
        'accuracy': accuracy,
        'f1_macro': f1_macro,
        'precision': precision,
        'recall': recall
    }


def get_parser():
    common_parser = get_common_args_parser()
    bert_parser = get_bert_args_parser()
    parser = merge_argument_parsers(common_parser, bert_parser)

    parser.add_argument('--oversampling_ratio', type=float, default=-1)

    return parser


def run(args):
    set_seed(args.seed)

    config_params = dict(
        min_enr=args.min_enr,
        min_edges=args.min_edges,
        remove_duplicates=args.remove_duplicates,
        language=args.language,
        reload=args.reload
    )
    dataset_name = args.dataset

    dataset = get_models_dataset(dataset_name, **config_params)
    print("Loaded dataset")

    graph_data_params = get_config_params(args)
    graph_data_params = {**graph_data_params, 'task_type': LP_TASK_EDGE_CLS}

    print("Loading graph dataset")
    graph_dataset = GraphEdgeDataset(dataset, **graph_data_params)
    print("Loaded graph dataset")

    assert hasattr(graph_dataset, f'num_edges_{args.edge_cls_label}'), \
        f"Dataset does not have num_edges_{args.edge_cls_label} attribute"
    num_labels = getattr(graph_dataset, f"num_edges_{args.edge_cls_label}")

    model_name = args.model_name
    tokenizer = get_tokenizer(model_name, args.use_special_tokens)

    print("Getting Edge Classification data")
    bert_dataset = graph_dataset.get_link_prediction_lm_data(tokenizer=tokenizer)

    if args.oversampling_ratio != -1:
        # re-index the tokenized training inputs with the oversampled index list
        ind_w_oversamples = oversample_dataset(bert_dataset['train'])
        bert_dataset['train'].inputs = bert_dataset['train'][ind_w_oversamples]

    print("Training model")
    print(f'Number of labels: {num_labels}')

    model = get_model(args.ckpt if args.ckpt else model_name, num_labels, len(tokenizer))

    if args.freeze_pretrained_weights:
        for param in model.base_model.parameters():
            param.requires_grad = False

    output_dir = os.path.join(
        'results',
        dataset_name,
        'edge_cls',
        f'{args.edge_cls_label}',
        f"{graph_dataset.config_hash}",
    )

    logs_dir = os.path.join(
        'logs',
        dataset_name,
        'edge_cls',
        f'{args.edge_cls_label}',
        f"{graph_dataset.config_hash}",
    )

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        weight_decay=0.01,
        logging_dir=logs_dir,
        logging_steps=args.num_log_steps,
        eval_strategy='steps',
        eval_steps=args.num_eval_steps,
        save_steps=args.num_save_steps,
        save_total_limit=2,
        load_best_model_at_end=True,
        fp16=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=bert_dataset['train'],
        eval_dataset=bert_dataset['test'],
        compute_metrics=compute_metrics
    )

    trainer.train()
    print(trainer.evaluate())
    trainer.save_model()


if __name__ == '__main__':
    args = get_parser().parse_args()
    run(args)
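
A hedged invocation sketch: --oversampling_ratio is defined in this file, while the remaining flag names are assumed to mirror the attributes run() reads (args.dataset, args.model_name, args.edge_cls_label, ...), which are registered by the shared parsers in downstream_tasks/common_args.py:

# Illustrative only; the exact flag spellings live in common_args.py.
args = get_parser().parse_args([
    '--dataset', 'modelset',            # assumed dataset name
    '--model_name', 'bert-base-uncased',
    '--edge_cls_label', 'type',
    '--oversampling_ratio', '0.7',
])
run(args)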