glam4cm-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- glam4cm/__init__.py +9 -0
- glam4cm/data_loading/__init__.py +0 -0
- glam4cm/data_loading/data.py +631 -0
- glam4cm/data_loading/encoding.py +76 -0
- glam4cm/data_loading/graph_dataset.py +940 -0
- glam4cm/data_loading/metadata.py +84 -0
- glam4cm/data_loading/models_dataset.py +361 -0
- glam4cm/data_loading/utils.py +20 -0
- glam4cm/downstream_tasks/__init__.py +0 -0
- glam4cm/downstream_tasks/bert_edge_classification.py +144 -0
- glam4cm/downstream_tasks/bert_graph_classification.py +137 -0
- glam4cm/downstream_tasks/bert_graph_classification_comp.py +156 -0
- glam4cm/downstream_tasks/bert_link_prediction.py +145 -0
- glam4cm/downstream_tasks/bert_node_classification.py +164 -0
- glam4cm/downstream_tasks/cm_gpt_edge_classification.py +73 -0
- glam4cm/downstream_tasks/cm_gpt_node_classification.py +76 -0
- glam4cm/downstream_tasks/cm_gpt_pretraining.py +64 -0
- glam4cm/downstream_tasks/common_args.py +160 -0
- glam4cm/downstream_tasks/create_dataset.py +51 -0
- glam4cm/downstream_tasks/gnn_edge_classification.py +106 -0
- glam4cm/downstream_tasks/gnn_graph_cls.py +101 -0
- glam4cm/downstream_tasks/gnn_link_prediction.py +109 -0
- glam4cm/downstream_tasks/gnn_node_classification.py +103 -0
- glam4cm/downstream_tasks/tf_idf_text_classification.py +22 -0
- glam4cm/downstream_tasks/utils.py +35 -0
- glam4cm/downstream_tasks/word2vec_text_classification.py +108 -0
- glam4cm/embeddings/__init__.py +0 -0
- glam4cm/embeddings/bert.py +72 -0
- glam4cm/embeddings/common.py +43 -0
- glam4cm/embeddings/fasttext.py +0 -0
- glam4cm/embeddings/tfidf.py +25 -0
- glam4cm/embeddings/w2v.py +41 -0
- glam4cm/encoding/__init__.py +0 -0
- glam4cm/encoding/common.py +0 -0
- glam4cm/encoding/encoders.py +100 -0
- glam4cm/graph2str/__init__.py +0 -0
- glam4cm/graph2str/common.py +34 -0
- glam4cm/graph2str/constants.py +15 -0
- glam4cm/graph2str/ontouml.py +141 -0
- glam4cm/graph2str/uml.py +0 -0
- glam4cm/lang2graph/__init__.py +0 -0
- glam4cm/lang2graph/archimate.py +31 -0
- glam4cm/lang2graph/bpmn.py +0 -0
- glam4cm/lang2graph/common.py +416 -0
- glam4cm/lang2graph/ecore.py +221 -0
- glam4cm/lang2graph/ontouml.py +169 -0
- glam4cm/lang2graph/utils.py +80 -0
- glam4cm/models/cmgpt.py +352 -0
- glam4cm/models/gnn_layers.py +273 -0
- glam4cm/models/hf.py +10 -0
- glam4cm/run.py +99 -0
- glam4cm/run_configs.py +126 -0
- glam4cm/settings.py +54 -0
- glam4cm/tokenization/__init__.py +0 -0
- glam4cm/tokenization/special_tokens.py +4 -0
- glam4cm/tokenization/utils.py +37 -0
- glam4cm/trainers/__init__.py +0 -0
- glam4cm/trainers/bert_classifier.py +105 -0
- glam4cm/trainers/cm_gpt_trainer.py +153 -0
- glam4cm/trainers/gnn_edge_classifier.py +126 -0
- glam4cm/trainers/gnn_graph_classifier.py +123 -0
- glam4cm/trainers/gnn_link_predictor.py +144 -0
- glam4cm/trainers/gnn_node_classifier.py +135 -0
- glam4cm/trainers/gnn_trainer.py +129 -0
- glam4cm/trainers/metrics.py +55 -0
- glam4cm/utils.py +194 -0
- glam4cm-0.1.0.dist-info/LICENSE +21 -0
- glam4cm-0.1.0.dist-info/METADATA +86 -0
- glam4cm-0.1.0.dist-info/RECORD +72 -0
- glam4cm-0.1.0.dist-info/WHEEL +5 -0
- glam4cm-0.1.0.dist-info/entry_points.txt +2 -0
- glam4cm-0.1.0.dist-info/top_level.txt +1 -0
glam4cm/graph2str/ontouml.py
ADDED
@@ -0,0 +1,141 @@
import networkx as nx
import random
import itertools
from tqdm.auto import tqdm
from constants import *
from common import (
    get_node_neighbours,
    remove_extra_spaces,
    has_neighbours_incl_incoming
)


def get_node_text_triples(g, distance=1, only_name=False):
    node_strings = [get_node_str(g, node, distance) for node in g.nodes]
    node_triples = list()
    for node, node_str in zip(list(g.nodes), node_strings):
        name = g.nodes[node]['name'] if 'name' in g.nodes[node] else " reference "
        node_type = g.nodes[node]['type']
        prompt_str = f"{node_type} {name}: {node_str}" if not only_name else f"{name}"
        node_triples.append(prompt_str)
    return node_triples


def check_stereotype_relevance(g, n):
    return 'use_stereotype' in g.nodes[n] and g.nodes[n]['use_stereotype']


def process_name_and_steroetype(g, n):
    string = g.nodes[n]['name'] if g.nodes[n]['name'] != "Null" else ""
    string += f' {g.nodes[n]["stereotype"]} ' if check_stereotype_relevance(g, n) else ""

    return string


def process_node_for_string(g, n, src=True):
    if g.nodes[n]['type'] == 'Class':
        return [process_name_and_steroetype(g, n)]

    strings = list()
    node_str = process_name_and_steroetype(g, n)
    edges = list(g.in_edges(n)) if src else list(g.out_edges(n))
    for edge in edges:
        v = edge[0] if src else edge[1]
        v_str = f" {process_edge_for_string(g, edge)} {process_name_and_steroetype(g, v)}"
        n_str = v_str + node_str if src else node_str + v_str
        strings.append(n_str)
    return list(set(map(remove_extra_spaces, strings)))


def process_edge_for_string(g, e):
    edge_type_s = e_s[g.edges()[e]['type']]
    return remove_extra_spaces(f" {edge_type_s} ")


def get_triples_from_edges(g, edges=None):
    if edges is None:
        edges = g.edges()
    triples = []
    for edge in edges:
        u, v = edge
        edge_str = process_edge_for_string(g, edge)
        u_strings, v_strings = process_node_for_string(g, u, src=True), process_node_for_string(g, v, src=False)
        for u_str, v_str in itertools.product(u_strings, v_strings):
            pos_triple = u_str + f" {edge_str} " + v_str
            triples.append(remove_extra_spaces(pos_triple))

    return triples


def process_path_string(g, path):
    edges = list(zip(path[:-1], path[1:]))
    triples = get_triples_from_edges(g, edges)

    return remove_extra_spaces(f" {SEP} ".join(triples))


def get_triples_from_node(g, n, distance=1):
    triples = list()
    use_stereotype = g.nodes[n]['use_stereotype'] if 'use_stereotype' in g.nodes[n] else False
    g.nodes[n]['use_stereotype'] = False
    node_neighbours = get_node_neighbours(g, n, distance)
    for neighbour in node_neighbours:
        paths = [p for p in nx.all_simple_paths(g, n, neighbour, cutoff=distance)]
        for path in paths:
            triples.append(process_path_string(g, path))

    g.nodes[n]['use_stereotype'] = use_stereotype
    return triples


def get_node_str(g, n, distance=1):
    node_triples = get_triples_from_node(g, n, distance)
    return remove_extra_spaces(f" | ".join(node_triples))


def create_triples_from_graph_edges(graphs):
    triples = list()
    for g, _ in tqdm(graphs):
        triples += get_triples_from_edges(g)

    return triples


def mask_graph(graph, stereotypes_classes, mask_prob=0.2, use_stereotypes=False, use_rel_stereotypes=False):
    all_stereotype_nodes = [node for node in graph.nodes if 'stereotype' in graph.nodes[node]
                            and graph.nodes[node]['stereotype'] in stereotypes_classes and has_neighbours_incl_incoming(graph, node)
                            and (True if use_rel_stereotypes else graph.nodes[node]['type'] == 'Class')]

    assert all(['stereotype' in graph.nodes[node] for node in all_stereotype_nodes]), "All stereotype nodes should have stereotype property"

    total_masked_nodes = int(len(all_stereotype_nodes) * mask_prob)
    masked_nodes = random.sample(all_stereotype_nodes, total_masked_nodes)
    unmasked_nodes = [node for node in all_stereotype_nodes if node not in masked_nodes]

    for node in masked_nodes:
        graph.nodes[node]['masked'] = True
        graph.nodes[node]['use_stereotype'] = False

    for node in unmasked_nodes:
        graph.nodes[node]['masked'] = False
        graph.nodes[node]['use_stereotype'] = use_stereotypes

    assert all(['masked' in graph.nodes[node] for node in all_stereotype_nodes]), "All stereotype nodes should be masked or unmasked"


def mask_graphs(graphs, stereotypes_classes, mask_prob=0.2, use_stereotypes=False, use_rel_stereotypes=False):
    masked, unmasked, total = 0, 0, 0
    # for graph, f_name in tqdm(graphs, desc='Masking graphs'):
    for graph, _ in graphs:
        mask_graph(graph, stereotypes_classes, mask_prob=mask_prob, use_stereotypes=use_stereotypes, use_rel_stereotypes=use_rel_stereotypes)
        masked += len([node for node in graph.nodes if 'masked' in graph.nodes[node] and graph.nodes[node]['masked']])
        unmasked += len([node for node in graph.nodes if 'masked' in graph.nodes[node] and not graph.nodes[node]['masked']])
        total += len([node for node in graph.nodes if 'masked' in graph.nodes[node]])
    ## % of masked nodes, up to 2 decimal places
    print(f"Masked {round(masked / total * 100, 2)}%")
    print(f"Unmasked {round(unmasked / total * 100, 2)}%")

    print("Total masked nodes:", masked)
    print("Total unmasked nodes:", unmasked)
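For orientation, a minimal usage sketch of the masking helpers above. It assumes OntoUML models are plain networkx DiGraphs whose nodes carry 'name', 'type' and 'stereotype' attributes, and that has_neighbours_incl_incoming (imported from the sibling common module, not shown here) takes (graph, node) and returns a bool. Because the file uses bare `from constants import *` / `from common import ...`, it is presumably run with glam4cm/graph2str on the path. The toy graph and file name are illustrative, not part of the package:

    import networkx as nx
    from ontouml import mask_graphs  # assumes glam4cm/graph2str is on sys.path

    # Two stereotype-bearing Class nodes connected by a generalization edge
    g = nx.DiGraph()
    g.add_node('c1', name='Person', type='Class', stereotype='kind')
    g.add_node('c2', name='Student', type='Class', stereotype='role')
    g.add_edge('c2', 'c1', type='gen')

    # Mask half of the stereotype-bearing Class nodes; the rest keep their stereotype text
    mask_graphs([(g, 'toy.json')], stereotypes_classes=['kind', 'role'],
                mask_prob=0.5, use_stereotypes=True)
    print([(n, g.nodes[n]['masked']) for n in g.nodes])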
glam4cm/graph2str/uml.py
ADDED
File without changes
glam4cm/lang2graph/archimate.py
ADDED
@@ -0,0 +1,31 @@
from glam4cm.lang2graph.common import LangGraph


class ArchiMateNxG(LangGraph):
    def __init__(
        self,
        json_obj: dict,
        path: str,
        timeout=-1
    ):
        super().__init__()
        self.json_obj = json_obj
        self.timeout = timeout
        self.path = path
        self.graph_id = json_obj['identifier'].split('/')[-1]

        self.__create_graph()
        self.set_numbered_labels()

        # self.text = " ".join([
        #     self.nodes[node]['name'] if 'name' in self.nodes[node] else ''
        #     for node in self.nodes
        # ])

    def __create_graph(self):
        for node in self.json_obj['elements']:
            self.add_node(node['id'], **node)
        for edge in self.json_obj['relationships']:
            self.add_edge(edge['sourceId'], edge['targetId'], **edge)
File without changes
glam4cm/lang2graph/common.py
ADDED
@@ -0,0 +1,416 @@
from abc import abstractmethod
from typing import List
import networkx as nx
from uuid import uuid4
import numpy as np
import torch
from glam4cm.data_loading.metadata import GraphMetadata
from glam4cm.tokenization.special_tokens import *
from glam4cm.tokenization.utils import doc_tokenizer
import glam4cm.utils as utils

SEP = ' '
REFERENCE = 'reference'
SUPERTYPE = 'supertype'
CONTAINMENT = 'containment'


class LangGraph(nx.DiGraph):
    def __init__(self):
        super().__init__()
        self.id = uuid4().hex
        self.node_label_to_id = {}
        self.id_to_node_label = {}
        self.edge_label_to_id = {}
        self.id_to_edge_label = {}

    @abstractmethod
    def create_graph(self):
        pass

    def set_numbered_labels(self):
        self.node_label_to_id = {label: i for i, label in enumerate(self.nodes())}
        self.id_to_node_label = {i: label for i, label in enumerate(self.nodes())}

        self.edge_label_to_id = {label: i for i, label in enumerate(self.edges())}
        self.id_to_edge_label = {i: label for i, label in enumerate(self.edges())}

        self.numbered_graph = self.get_numbered_graph()
        self.edge_to_idx = {edge: idx for idx, edge in enumerate(self.numbered_graph.edges())}
        self.idx_to_edge = {idx: edge for idx, edge in enumerate(self.numbered_graph.edges())}

    def get_numbered_graph(self) -> nx.DiGraph:
        nodes = [(self.node_label_to_id[i], data) for i, data in list(self.nodes(data=True))]
        edges = [(self.node_label_to_id[i], self.node_label_to_id[j], data) for i, j, data in list(self.edges(data=True))]
        graph = nx.DiGraph()
        graph.add_nodes_from(nodes)
        graph.add_edges_from(edges)

        return graph

    @property
    def enr(self):
        if self.number_of_nodes() == 0:
            return -1
        return self.number_of_edges() / self.number_of_nodes()

    @property
    def edge_index(self):
        edge_index = torch.tensor(list(self.numbered_graph.edges)).t().contiguous().numpy()
        return edge_index

    @property
    def hash(self):
        return utils.md5_hash(str(sorted(self.edges)))

    def get_edge_id(self, edge):
        return self.edge_label_to_id[edge]

    def get_edge_label(self, edge_id):
        return self.id_to_edge_label[edge_id]

    def get_node_id(self, node):
        return self.node_label_to_id[node]

    def get_node_label(self, node_id):
        return self.id_to_node_label[node_id]


def create_graph_from_edge_index(graph, edge_index: np.ndarray):
    """
    Create a subgraph from G using only the edges specified in edge_index.

    Parameters:
    G (networkx.Graph): The original graph.
    edge_index (numpy.ndarray): A numpy array containing edge indices.

    Returns:
    networkx.Graph: A subgraph of G containing only the edges in edge_index.
    """

    if isinstance(edge_index, torch.Tensor):
        edge_index = edge_index.cpu().numpy()

    # Add nodes and edges from the edge_index to the subgraph
    subgraph = nx.DiGraph()
    subgraph.add_nodes_from(list(graph.numbered_graph.nodes(data=True)))
    subgraph.add_edges_from([(u, v, graph.numbered_graph.edges[u, v]) for u, v in edge_index.T])
    for node, data in subgraph.nodes(data=True):
        data = graph.numbered_graph.nodes[node]
        subgraph.nodes[node].update(data)

    subgraph.node_label_to_id = graph.node_label_to_id
    subgraph.id_to_node_label = graph.id_to_node_label
    subgraph.edge_label_to_id = graph.edge_label_to_id
    subgraph.id_to_edge_label = graph.id_to_edge_label
    try:
        assert subgraph.number_of_edges() == edge_index.shape[1]
    except AssertionError as e:
        print(f"Number of edges mismatch {subgraph.number_of_edges()} != {edge_index.shape[1]}")
        import pickle
        pickle.dump([graph, edge_index], open("subgraph.pkl", "wb"))
        raise e

    return subgraph


def format_path(
    graph: LangGraph,
    path: List,
    metadata: GraphMetadata,
    use_node_attributes=False,
    use_node_types=False,
    use_edge_label=False,
    use_edge_types=False,
    node_cls_label=None,
    edge_cls_label='type',
    use_special_tokens=False,
    no_labels=False,
    preprocessor=doc_tokenizer,
    neg_sample=False
):
    """Format a path into a string representation."""
    def get_node_label(node):
        masked = graph.nodes[node].get('masked')
        node_type = f"{graph.nodes[node].get(f'{node_cls_label}', '')}" if use_node_types and not masked and node_cls_label else ''
        node_type = f"{node_cls_label}: {node_type}" if node_type else ''
        node_label = get_node_name(
            graph.nodes[node],
            metadata.node_label,
            use_node_attributes,
            metadata.node_attributes
        ) if not no_labels else ''
        if preprocessor:
            node_label = preprocessor(node_label)

        node_label = f"{node_type} {node_label}".strip()
        if use_special_tokens:
            node_label = f"{NODE_BEGIN}{node_label}{NODE_END}"

        return node_label.strip()

    def get_edge_label(n1, n2):
        edge_data = graph.get_edge_data(n1, n2)
        masked = edge_data.get('masked')
        edge_label = edge_data.get(metadata.edge_label, '') if use_edge_label and not no_labels else ''
        edge_type = f"{edge_cls_label}:{get_edge_data(edge_data, f'{edge_cls_label}', metadata.type)}" if use_edge_types and not masked and edge_cls_label else ''

        if preprocessor:
            edge_label = preprocessor(edge_label)

        edge_label = f"{edge_type} {edge_label}".strip()

        if use_special_tokens:
            edge_label = f"{EDGE_START}{edge_label}{EDGE_END}"

        return edge_label.strip()

    assert len(path) > 0, "Path must contain at least one node."
    formatted = [get_node_label(path[0])]
    for i in range(1, len(path)):
        n1 = path[i - 1]
        n2 = path[i]

        if not neg_sample:
            formatted.append(get_edge_label(n1, n2))
        formatted.append(get_node_label(n2))

    return " ".join(formatted).strip()


def get_edge_texts(
    graph: LangGraph,
    edge: tuple,
    d: int,
    metadata: GraphMetadata,
    use_node_attributes=False,
    use_node_types=False,
    use_edge_types=False,
    use_edge_label=False,
    node_cls_label=None,
    edge_cls_label='type',
    use_special_tokens=False,
    no_labels=False,
    preprocessor: callable = doc_tokenizer,
    neg_samples=False
):
    n1, n2 = edge
    if not neg_samples:
        masked = graph.edges[n1, n2].get('masked')
        graph.edges[n1, n2]['masked'] = True

    n1_text = get_node_text(
        graph=graph,
        node=n1,
        d=d,
        metadata=metadata,
        use_node_attributes=use_node_attributes,
        use_node_types=use_node_types,
        use_edge_types=use_edge_types,
        use_edge_label=use_edge_label,
        node_cls_label=node_cls_label,
        edge_cls_label=edge_cls_label,
        use_special_tokens=use_special_tokens,
        no_labels=no_labels,
        preprocessor=preprocessor,
        exclude_edges=[edge]
    )
    n2_text = get_node_text(
        graph=graph,
        node=n2,
        d=d,
        metadata=metadata,
        use_node_attributes=use_node_attributes,
        use_node_types=use_node_types,
        use_edge_types=use_edge_types,
        use_edge_label=use_edge_label,
        node_cls_label=node_cls_label,
        edge_cls_label=edge_cls_label,
        use_special_tokens=use_special_tokens,
        no_labels=no_labels,
        preprocessor=preprocessor,
        exclude_edges=[edge]
    )
    if not neg_samples:
        graph.edges[n1, n2]['masked'] = masked or False

    return n1_text + EDGE_START + EDGE_END + n2_text


def get_node_text(
    graph: LangGraph,
    node,
    d: int,
    metadata: GraphMetadata,
    use_node_attributes=False,
    use_node_types=False,
    use_edge_types=False,
    use_edge_label=False,
    node_cls_label=None,
    edge_cls_label='type',
    use_special_tokens=False,
    no_labels=False,
    preprocessor: callable = doc_tokenizer,
    exclude_edges=None
):
    masked = graph.nodes[node].get('masked')
    graph.nodes[node]['masked'] = True
    raw_paths = utils.bfs(graph=graph, start_node=node, d=d, exclude_edges=exclude_edges)
    unique_paths = utils.remove_subsets(list_of_lists=raw_paths)
    text = "\n".join([
        format_path(
            graph=graph,
            path=path,
            metadata=metadata,
            use_node_attributes=use_node_attributes,
            use_node_types=use_node_types,
            use_edge_types=use_edge_types,
            use_edge_label=use_edge_label,
            node_cls_label=node_cls_label,
            edge_cls_label=edge_cls_label,
            use_special_tokens=use_special_tokens,
            no_labels=no_labels,
            preprocessor=preprocessor,
            neg_sample=False
        )
        for path in unique_paths
    ])
    graph.nodes[node]['masked'] = masked or False
    return text


def get_node_texts(
    graph: LangGraph,
    d: int,
    metadata: GraphMetadata,
    use_node_attributes=False,
    use_node_types=False,
    use_edge_types=False,
    use_edge_label=False,
    node_cls_label=None,
    edge_cls_label='type',
    use_special_tokens=False,
    no_labels=False,
    preprocessor: callable = doc_tokenizer
):
    paths_dict = {}
    for node in graph.nodes:
        paths_dict[node] = get_node_text(
            graph=graph,
            node=node,
            d=d,
            metadata=metadata,
            use_node_attributes=use_node_attributes,
            use_node_types=use_node_types,
            use_edge_types=use_edge_types,
            use_edge_label=use_edge_label,
            node_cls_label=node_cls_label,
            edge_cls_label=edge_cls_label,
            use_special_tokens=use_special_tokens,
            no_labels=no_labels,
            preprocessor=preprocessor
        )

    return paths_dict


def get_attribute_labels(node_data, attribute_labels):
    if isinstance(node_data[attribute_labels], list):
        if not node_data[attribute_labels]:
            return ''
        if isinstance(node_data[attribute_labels][0], tuple):
            return ", ".join([f"{k}: {v}" for k, v in node_data[attribute_labels]])
        elif isinstance(node_data[attribute_labels][0], dict):
            return ", ".join([f"{k}: {v}" for d in node_data[attribute_labels] for k, v in d.items()])
        return ", ".join(node_data[attribute_labels])
    if isinstance(node_data[attribute_labels], dict):
        return ", ".join([f"{k}: {v}" for k, v in node_data[attribute_labels].items()])
    return node_data[attribute_labels]


def get_node_name(
    node_data,
    label,
    use_attributes,
    attribute_labels,
):
    if use_attributes and attribute_labels in node_data:
        attributes_str = "(" + get_attribute_labels(node_data, attribute_labels) + ")"
    else:
        attributes_str = ''
    node_label = node_data.get(label, '')
    node_label = '' if node_label.lower() == 'null' else node_label
    return f"{node_label}{attributes_str}".strip()


def get_node_data(
    node_data: dict,
    node_label: str,
    model_type: str,
):
    if model_type == 'archimate':
        return get_archimate_node_data(node_data, node_label)
    elif model_type == 'ecore':
        return get_uml_node_data(node_data, node_label)
    elif model_type == 'ontouml':
        return get_ontouml_node_data(node_data, node_label)
    else:
        raise ValueError(f"Unknown model type: {model_type}")


def get_edge_data(
    edge_data: dict,
    edge_label: str,
    model_type: str,
):
    if model_type == 'archimate':
        return get_archimate_edge_data(edge_data, edge_label)
    elif model_type == 'ecore':
        return get_uml_edge_data(edge_data, edge_label)
    elif model_type == 'ontouml':
        return get_ontouml_edge_data(edge_data, edge_label)
    else:
        raise ValueError(f"Unknown model type: {model_type}")


def get_archimate_node_data(node_data: dict, node_label: str):
    return node_data.get(node_label)


def get_uml_node_data(node_data: dict, node_label: str):
    return node_data.get(node_label, '')


def get_ontouml_node_data(node_data: dict, node_label: str):
    return node_data.get(node_label, '')


def get_archimate_edge_data(edge_data: dict, edge_label: str):
    return edge_data.get(edge_label)


def get_uml_edge_data(edge_data: dict, edge_label: str):
    if edge_label == 'type':
        return get_uml_edge_type(edge_data)
    elif edge_label in edge_data:
        return edge_data[edge_label]
    else:
        raise ValueError(f"Unknown edge label: {edge_label}")


def get_ontouml_edge_data(edge_data: dict, edge_label: str):
    return edge_data.get(edge_label)


def get_uml_edge_type(edge_data):
    edge_type = edge_data.get('type')
    if edge_type == SUPERTYPE:
        return SUPERTYPE
    if edge_type == CONTAINMENT:
        return CONTAINMENT
    return REFERENCE