glam4cm 0.1.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- glam4cm/__init__.py +2 -1
- glam4cm/data_loading/data.py +90 -146
- glam4cm/data_loading/encoding.py +17 -6
- glam4cm/data_loading/graph_dataset.py +192 -57
- glam4cm/data_loading/metadata.py +1 -1
- glam4cm/data_loading/models_dataset.py +42 -18
- glam4cm/downstream_tasks/bert_edge_classification.py +49 -22
- glam4cm/downstream_tasks/bert_graph_classification.py +44 -14
- glam4cm/downstream_tasks/bert_graph_classification_comp.py +47 -24
- glam4cm/downstream_tasks/bert_link_prediction.py +46 -26
- glam4cm/downstream_tasks/bert_node_classification.py +127 -89
- glam4cm/downstream_tasks/cm_gpt_node_classification.py +61 -15
- glam4cm/downstream_tasks/common_args.py +32 -4
- glam4cm/downstream_tasks/gnn_edge_classification.py +24 -7
- glam4cm/downstream_tasks/gnn_graph_cls.py +19 -6
- glam4cm/downstream_tasks/gnn_link_prediction.py +25 -13
- glam4cm/downstream_tasks/gnn_node_classification.py +19 -7
- glam4cm/downstream_tasks/utils.py +16 -2
- glam4cm/embeddings/bert.py +1 -1
- glam4cm/embeddings/common.py +7 -4
- glam4cm/encoding/encoders.py +1 -1
- glam4cm/lang2graph/archimate.py +0 -5
- glam4cm/lang2graph/common.py +99 -41
- glam4cm/lang2graph/ecore.py +1 -2
- glam4cm/lang2graph/ontouml.py +8 -7
- glam4cm/models/gnn_layers.py +20 -6
- glam4cm/models/hf.py +2 -2
- glam4cm/run.py +12 -7
- glam4cm/run_conf_v2.py +405 -0
- glam4cm/run_configs.py +70 -106
- glam4cm/run_confs.py +41 -0
- glam4cm/settings.py +15 -2
- glam4cm/tokenization/special_tokens.py +23 -1
- glam4cm/tokenization/utils.py +23 -4
- glam4cm/trainers/cm_gpt_trainer.py +1 -1
- glam4cm/trainers/gnn_edge_classifier.py +12 -1
- glam4cm/trainers/gnn_graph_classifier.py +12 -5
- glam4cm/trainers/gnn_link_predictor.py +18 -3
- glam4cm/trainers/gnn_link_predictor_v2.py +146 -0
- glam4cm/trainers/gnn_trainer.py +8 -0
- glam4cm/trainers/metrics.py +1 -1
- glam4cm/utils.py +265 -2
- {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info}/METADATA +3 -2
- glam4cm-1.0.0.dist-info/RECORD +75 -0
- {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info}/WHEEL +1 -1
- glam4cm-0.1.1.dist-info/RECORD +0 -72
- {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info}/entry_points.txt +0 -0
- {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info/licenses}/LICENSE +0 -0
- {glam4cm-0.1.1.dist-info → glam4cm-1.0.0.dist-info}/top_level.txt +0 -0
glam4cm/__init__.py
CHANGED
glam4cm/data_loading/data.py
CHANGED
@@ -19,12 +19,12 @@ from glam4cm.lang2graph.common import (
|
|
19
19
|
|
20
20
|
from scipy.sparse import csr_matrix
|
21
21
|
|
22
|
-
from glam4cm.settings import
|
22
|
+
from glam4cm.settings import DUMMY_GRAPH_CLS_TASK, EDGE_CLS_TASK, LINK_PRED_TASK
|
23
23
|
from glam4cm.tokenization.special_tokens import *
|
24
24
|
from torch_geometric.transforms import RandomLinkSplit
|
25
25
|
import torch
|
26
|
-
from torch_geometric.data import Data
|
27
|
-
from typing import List,
|
26
|
+
from torch_geometric.data import Data
|
27
|
+
from typing import List, Union
|
28
28
|
from glam4cm.tokenization.utils import doc_tokenizer
|
29
29
|
|
30
30
|
|
@@ -98,6 +98,7 @@ class TorchGraph:
|
|
98
98
|
self,
|
99
99
|
graph: Union[EcoreNxG, ArchiMateNxG],
|
100
100
|
metadata: Union[EcoreMetaData, ArchimateMetaData],
|
101
|
+
task_type: str,
|
101
102
|
distance = 0,
|
102
103
|
test_ratio=0.2,
|
103
104
|
use_edge_types=False,
|
@@ -108,9 +109,12 @@ class TorchGraph:
|
|
108
109
|
no_labels=False,
|
109
110
|
node_cls_label=None,
|
110
111
|
edge_cls_label='type',
|
112
|
+
|
113
|
+
node_topk: List[Union[str, int]]=None,
|
111
114
|
fp='test_graph.pkl'
|
112
115
|
):
|
113
116
|
|
117
|
+
self.task_type = task_type
|
114
118
|
self.fp = fp
|
115
119
|
self.graph = graph
|
116
120
|
self.metadata = metadata
|
@@ -126,6 +130,8 @@ class TorchGraph:
|
|
126
130
|
self.node_cls_label = node_cls_label
|
127
131
|
self.edge_cls_label = edge_cls_label
|
128
132
|
|
133
|
+
self.node_topk = node_topk
|
134
|
+
|
129
135
|
self.distance = distance
|
130
136
|
self.test_ratio = test_ratio
|
131
137
|
self.data = NumpyData()
|
@@ -148,6 +154,10 @@ class TorchGraph:
|
|
148
154
|
)
|
149
155
|
|
150
156
|
edge_texts = self.get_graph_edge_strs()
|
157
|
+
|
158
|
+
# print(f"Number of edges: {len(edge_texts)}")
|
159
|
+
# print("Edge strings: ", edge_texts[:50])
|
160
|
+
|
151
161
|
return node_texts, edge_texts
|
152
162
|
|
153
163
|
|
@@ -162,17 +172,20 @@ class TorchGraph:
|
|
162
172
|
|
163
173
|
def generate_embeddings():
|
164
174
|
if randomize_ne or embedder is None:
|
165
|
-
print("Randomizing node embeddings")
|
175
|
+
# print("Randomizing node embeddings")
|
166
176
|
self.data.x = np.random.randn(self.graph.number_of_nodes(), random_embed_dim)
|
167
177
|
else:
|
168
178
|
self.data.x = embedder.embed(list(self.node_texts.values()))
|
169
179
|
|
170
180
|
if randomize_ee or embedder is None:
|
171
|
-
print("Randomizing edge embeddings")
|
181
|
+
# print("Randomizing edge embeddings")
|
172
182
|
self.data.edge_attr = np.random.randn(self.graph.number_of_edges(), random_embed_dim)
|
173
183
|
else:
|
174
|
-
|
184
|
+
edge_texts = list(self.edge_texts.values())
|
185
|
+
self.data.edge_attr = embedder.embed(edge_texts) \
|
186
|
+
if len(edge_texts) > 0 else np.empty((self.graph.number_of_edges(), random_embed_dim))
|
175
187
|
|
188
|
+
|
176
189
|
if os.path.exists(f"{self.fp}") and not reload:
|
177
190
|
with open(f"{self.fp}", 'rb') as f:
|
178
191
|
obj: Union[TorchEdgeGraph, TorchNodeGraph] = pickle.load(f)
|
@@ -199,6 +212,7 @@ class TorchGraph:
|
|
199
212
|
distance = self.distance
|
200
213
|
|
201
214
|
subgraph = create_graph_from_edge_index(self.graph, edge_index)
|
215
|
+
|
202
216
|
return get_node_texts(
|
203
217
|
subgraph,
|
204
218
|
distance,
|
@@ -230,6 +244,7 @@ class TorchGraph:
|
|
230
244
|
self.graph.numbered_graph,
|
231
245
|
(u, v),
|
232
246
|
d=self.distance,
|
247
|
+
task_type=self.task_type,
|
233
248
|
metadata=self.metadata,
|
234
249
|
use_node_attributes=self.use_attributes,
|
235
250
|
use_node_types=self.use_node_types,
|
@@ -249,14 +264,13 @@ class TorchGraph:
|
|
249
264
|
def validate_data(self):
|
250
265
|
assert self.data.num_nodes == self.graph.number_of_nodes()
|
251
266
|
|
267
|
+
|
252
268
|
def set_graph_label(self):
|
253
269
|
if self.metadata.graph_label is not None and not hasattr(self.graph, self.metadata.graph_label): #Graph has a label
|
254
270
|
text = doc_tokenizer("\n".join(list(self.node_texts.values())))
|
255
|
-
# print("Text:", text)
|
256
|
-
# print("-" * 100)
|
257
271
|
setattr(self.graph, self.metadata.graph_label, text)
|
258
272
|
|
259
|
-
|
273
|
+
|
260
274
|
@property
|
261
275
|
def name(self):
|
262
276
|
return '.'.join(self.graph.graph_id.replace('/', '_').split('.')[:-1])
|
@@ -268,24 +282,28 @@ class TorchEdgeGraph(TorchGraph):
|
|
268
282
|
self,
|
269
283
|
graph: Union[EcoreNxG, ArchiMateNxG],
|
270
284
|
metadata: Union[EcoreMetaData, ArchimateMetaData],
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
285
|
+
task_type: str,
|
286
|
+
distance: int = 1,
|
287
|
+
test_ratio: float =0.2,
|
288
|
+
add_negative_train_samples: bool =False,
|
289
|
+
neg_samples_ratio: int =1,
|
290
|
+
use_edge_types: bool =False,
|
291
|
+
use_node_types: bool =False,
|
292
|
+
use_edge_label: bool =False,
|
293
|
+
use_attributes: bool =False,
|
294
|
+
use_special_tokens: bool =False,
|
295
|
+
node_cls_label: str =None,
|
296
|
+
edge_cls_label: str ='type',
|
297
|
+
no_labels: bool =False,
|
298
|
+
|
299
|
+
node_topk: List[Union[str, int]]=None,
|
283
300
|
fp: str = 'test_graph.pkl'
|
284
301
|
):
|
285
302
|
|
286
303
|
super().__init__(
|
287
304
|
graph=graph,
|
288
305
|
metadata=metadata,
|
306
|
+
task_type=task_type,
|
289
307
|
distance=distance,
|
290
308
|
test_ratio=test_ratio,
|
291
309
|
use_node_types=use_node_types,
|
@@ -296,6 +314,7 @@ class TorchEdgeGraph(TorchGraph):
|
|
296
314
|
no_labels=no_labels,
|
297
315
|
node_cls_label=node_cls_label,
|
298
316
|
edge_cls_label=edge_cls_label,
|
317
|
+
node_topk=node_topk,
|
299
318
|
fp=fp
|
300
319
|
)
|
301
320
|
self.add_negative_train_samples = add_negative_train_samples
|
@@ -304,9 +323,7 @@ class TorchEdgeGraph(TorchGraph):
|
|
304
323
|
self.validate_data()
|
305
324
|
self.set_graph_label()
|
306
325
|
|
307
|
-
|
308
326
|
|
309
|
-
|
310
327
|
def get_pyg_data(self):
|
311
328
|
|
312
329
|
d = GraphData()
|
@@ -344,7 +361,8 @@ class TorchEdgeGraph(TorchGraph):
|
|
344
361
|
setattr(d, 'test_pos_edge_label', test_data.pos_edge_label)
|
345
362
|
|
346
363
|
|
347
|
-
if
|
364
|
+
if self.add_negative_train_samples:
|
365
|
+
assert hasattr(train_data, 'neg_edge_label_index')
|
348
366
|
assert not any([self.graph.numbered_graph.has_edge(*edge) for edge in train_data.neg_edge_label_index.t().tolist()])
|
349
367
|
assert not any([self.graph.numbered_graph.has_edge(*edge) for edge in test_data.neg_edge_label_index.t().tolist()])
|
350
368
|
setattr(d, 'train_neg_edge_label_index', train_data.neg_edge_label_index)
|
@@ -372,6 +390,9 @@ class TorchEdgeGraph(TorchGraph):
|
|
372
390
|
node_texts, edge_texts = self.get_node_edge_strings(
|
373
391
|
edge_index=edge_index.numpy(),
|
374
392
|
)
|
393
|
+
|
394
|
+
# print("Node texts: ", list(node_texts.values())[:5])
|
395
|
+
# print("Edge texts: ", list(edge_texts.values())[:5])
|
375
396
|
|
376
397
|
setattr(d, 'num_nodes', self.graph.number_of_nodes())
|
377
398
|
setattr(d, 'num_edges', self.graph.number_of_edges())
|
@@ -384,7 +405,7 @@ class TorchEdgeGraph(TorchGraph):
|
|
384
405
|
train_pos_edge_index = self.data.edge_index
|
385
406
|
test_pos_edge_index = self.data.test_pos_edge_label_index
|
386
407
|
|
387
|
-
if task_type ==
|
408
|
+
if task_type == LINK_PRED_TASK:
|
388
409
|
train_neg_edge_index = self.data.train_neg_edge_label_index
|
389
410
|
test_neg_edge_index = self.data.test_neg_edge_label_index
|
390
411
|
else:
|
@@ -412,9 +433,12 @@ class TorchEdgeGraph(TorchGraph):
|
|
412
433
|
|
413
434
|
edge_strs = list(edge_strs.values())
|
414
435
|
data[f'{edge_index_label}_edges'] = edge_strs
|
436
|
+
|
437
|
+
# print(f"Number of {edge_index_label} edges: {len(edge_strs)}")
|
438
|
+
# print("Edge strings: ", edge_strs[:50])
|
415
439
|
|
416
440
|
|
417
|
-
if task_type ==
|
441
|
+
if task_type == EDGE_CLS_TASK and not only_texts:
|
418
442
|
train_mask = self.data.train_edge_mask
|
419
443
|
test_mask = self.data.test_edge_mask
|
420
444
|
train_classes, test_classes = getattr(self.data, f'edge_{label}')[train_mask], getattr(self.data, f'edge_{label}')[test_mask]
|
@@ -431,22 +455,28 @@ class TorchNodeGraph(TorchGraph):
|
|
431
455
|
self,
|
432
456
|
graph: Union[EcoreNxG, ArchiMateNxG],
|
433
457
|
metadata: dict,
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
458
|
+
task_type: str,
|
459
|
+
|
460
|
+
distance: int = 1,
|
461
|
+
test_ratio: float =0.2,
|
462
|
+
use_node_types: bool =False,
|
463
|
+
use_edge_types: bool =False,
|
464
|
+
use_edge_label: bool =False,
|
465
|
+
use_attributes: bool =False,
|
466
|
+
use_special_tokens: bool =False,
|
467
|
+
no_labels: bool =False,
|
468
|
+
node_cls_label: str =None,
|
469
|
+
edge_cls_label: str ='type',
|
470
|
+
|
471
|
+
node_topk: List[Union[str, int]]=None,
|
472
|
+
|
473
|
+
fp='test_graph.pkl',
|
445
474
|
):
|
446
475
|
|
447
476
|
super().__init__(
|
448
477
|
graph,
|
449
478
|
metadata=metadata,
|
479
|
+
task_type=task_type,
|
450
480
|
distance=distance,
|
451
481
|
test_ratio=test_ratio,
|
452
482
|
use_node_types=use_node_types,
|
@@ -457,6 +487,8 @@ class TorchNodeGraph(TorchGraph):
|
|
457
487
|
no_labels=no_labels,
|
458
488
|
node_cls_label=node_cls_label,
|
459
489
|
edge_cls_label=edge_cls_label,
|
490
|
+
|
491
|
+
node_topk=node_topk,
|
460
492
|
fp=fp
|
461
493
|
)
|
462
494
|
|
@@ -468,15 +500,28 @@ class TorchNodeGraph(TorchGraph):
|
|
468
500
|
|
469
501
|
def get_pyg_data(self):
|
470
502
|
d = GraphData()
|
471
|
-
|
472
|
-
list(self.graph.numbered_graph.nodes)
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
503
|
+
if self.task_type == DUMMY_GRAPH_CLS_TASK:
|
504
|
+
train_nodes = list(self.graph.numbered_graph.nodes)
|
505
|
+
test_nodes = list()
|
506
|
+
else:
|
507
|
+
train_nodes, test_nodes = train_test_split(
|
508
|
+
list(self.graph.numbered_graph.nodes),
|
509
|
+
test_size=self.test_ratio,
|
510
|
+
shuffle=True,
|
511
|
+
random_state=42
|
512
|
+
)
|
477
513
|
|
514
|
+
def get_node_label(node):
|
515
|
+
if self.node_cls_label in self.graph.numbered_graph.nodes[node]\
|
516
|
+
and self.graph.numbered_graph.nodes[node][self.node_cls_label] is not None:
|
517
|
+
return self.graph.numbered_graph.nodes[node][self.node_cls_label]
|
518
|
+
return None
|
519
|
+
|
478
520
|
nx.set_node_attributes(self.graph.numbered_graph, {node: False for node in train_nodes}, 'masked')
|
479
|
-
nx.set_node_attributes(self.graph.numbered_graph, {
|
521
|
+
nx.set_node_attributes(self.graph.numbered_graph, {
|
522
|
+
node: get_node_label(node) in self.node_topk
|
523
|
+
for node in test_nodes
|
524
|
+
}, 'masked')
|
480
525
|
|
481
526
|
train_idx = torch.tensor(train_nodes, dtype=torch.long)
|
482
527
|
test_idx = torch.tensor(test_nodes, dtype=torch.long)
|
@@ -527,105 +572,4 @@ def validate_edges(graph: Union[TorchEdgeGraph, TorchNodeGraph]):
|
|
527
572
|
|
528
573
|
if train_neg_edge_index is not None and test_neg_edge_index is not None:
|
529
574
|
assert len(set((a, b) for a, b in train_neg_edge_index.T.tolist()).intersection(set((a, b) for a, b in test_neg_edge_index.T.tolist()))) == 0
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
class LinkPredictionCollater:
|
535
|
-
def __init__(
|
536
|
-
self,
|
537
|
-
follow_batch: Optional[List[str]] = None,
|
538
|
-
exclude_keys: Optional[List[str]] = None
|
539
|
-
):
|
540
|
-
self.follow_batch = follow_batch
|
541
|
-
self.exclude_keys = exclude_keys
|
542
|
-
|
543
|
-
def __call__(self, batch: List[Data]):
|
544
|
-
# Initialize lists to collect batched properties
|
545
|
-
x = []
|
546
|
-
edge_index = []
|
547
|
-
edge_attr = []
|
548
|
-
y = []
|
549
|
-
overall_edge_index = []
|
550
|
-
edge_classes = []
|
551
|
-
train_edge_mask = []
|
552
|
-
test_edge_mask = []
|
553
|
-
train_pos_edge_label_index = []
|
554
|
-
train_pos_edge_label = []
|
555
|
-
train_neg_edge_label_index = []
|
556
|
-
train_neg_edge_label = []
|
557
|
-
test_pos_edge_label_index = []
|
558
|
-
test_pos_edge_label = []
|
559
|
-
test_neg_edge_label_index = []
|
560
|
-
test_neg_edge_label = []
|
561
|
-
|
562
|
-
# Offsets for edge indices
|
563
|
-
node_offset = 0
|
564
|
-
edge_offset = 0
|
565
|
-
|
566
|
-
for data in batch:
|
567
|
-
x.append(data.x)
|
568
|
-
edge_index.append(data.edge_index + node_offset)
|
569
|
-
edge_attr.append(data.edge_attr)
|
570
|
-
y.append(data.y)
|
571
|
-
overall_edge_index.append(data.overall_edge_index + edge_offset)
|
572
|
-
edge_classes.append(data.edge_classes)
|
573
|
-
|
574
|
-
train_edge_mask.append(data.train_edge_mask)
|
575
|
-
test_edge_mask.append(data.test_edge_mask)
|
576
|
-
|
577
|
-
train_pos_edge_label_index.append(data.train_pos_edge_label_index + node_offset)
|
578
|
-
train_pos_edge_label.append(data.train_pos_edge_label)
|
579
|
-
train_neg_edge_label_index.append(data.train_neg_edge_label_index + node_offset)
|
580
|
-
train_neg_edge_label.append(data.train_neg_edge_label)
|
581
|
-
|
582
|
-
test_pos_edge_label_index.append(data.test_pos_edge_label_index + node_offset)
|
583
|
-
test_pos_edge_label.append(data.test_pos_edge_label)
|
584
|
-
test_neg_edge_label_index.append(data.test_neg_edge_label_index + node_offset)
|
585
|
-
test_neg_edge_label.append(data.test_neg_edge_label)
|
586
|
-
|
587
|
-
node_offset += data.num_nodes
|
588
|
-
edge_offset += data.edge_attr.size(0)
|
589
|
-
|
590
|
-
return GraphData(
|
591
|
-
x=torch.cat(x, dim=0),
|
592
|
-
edge_index=torch.cat(edge_index, dim=1),
|
593
|
-
edge_attr=torch.cat(edge_attr, dim=0),
|
594
|
-
y=torch.tensor(y),
|
595
|
-
overall_edge_index=torch.cat(overall_edge_index, dim=1),
|
596
|
-
edge_classes=torch.cat(edge_classes),
|
597
|
-
train_edge_mask=torch.cat(train_edge_mask),
|
598
|
-
test_edge_mask=torch.cat(test_edge_mask),
|
599
|
-
train_pos_edge_label_index=torch.cat(train_pos_edge_label_index, dim=1),
|
600
|
-
train_pos_edge_label=torch.cat(train_pos_edge_label),
|
601
|
-
train_neg_edge_label_index=torch.cat(train_neg_edge_label_index, dim=1),
|
602
|
-
train_neg_edge_label=torch.cat(train_neg_edge_label),
|
603
|
-
test_pos_edge_label_index=torch.cat(test_pos_edge_label_index, dim=1),
|
604
|
-
test_pos_edge_label=torch.cat(test_pos_edge_label),
|
605
|
-
test_neg_edge_label_index=torch.cat(test_neg_edge_label_index, dim=1),
|
606
|
-
test_neg_edge_label=torch.cat(test_neg_edge_label),
|
607
|
-
num_nodes=node_offset
|
608
|
-
)
|
609
|
-
|
610
|
-
|
611
|
-
class LinkPredictionDataLoader(torch.utils.data.DataLoader):
|
612
|
-
def __init__(
|
613
|
-
self,
|
614
|
-
dataset: Union[Dataset, Sequence[Data]],
|
615
|
-
batch_size: int = 1,
|
616
|
-
shuffle: bool = False,
|
617
|
-
collate_fn=None,
|
618
|
-
follow_batch: Optional[List[str]] = None,
|
619
|
-
exclude_keys: Optional[List[str]] = None,
|
620
|
-
**kwargs,
|
621
|
-
):
|
622
|
-
if collate_fn is None:
|
623
|
-
collate_fn = LinkPredictionCollater(follow_batch, exclude_keys)
|
624
|
-
|
625
|
-
super().__init__(
|
626
|
-
dataset,
|
627
|
-
batch_size,
|
628
|
-
shuffle,
|
629
|
-
collate_fn=collate_fn,
|
630
|
-
**kwargs,
|
631
|
-
)
|
575
|
+
|
glam4cm/data_loading/encoding.py
CHANGED
@@ -1,24 +1,34 @@
|
|
1
|
+
from typing import List, Union
|
1
2
|
from torch.utils.data import Dataset
|
2
3
|
import torch
|
4
|
+
from transformers import AutoTokenizer
|
5
|
+
|
6
|
+
def get_max_length(tokenizer):
|
7
|
+
tokenizer_name = tokenizer.name_or_path.lower()
|
8
|
+
if 'modernbert' in tokenizer_name:
|
9
|
+
return 8000
|
10
|
+
return 512
|
3
11
|
|
4
12
|
# Create your dataset
|
5
13
|
class EncodingDataset(Dataset):
|
6
14
|
def __init__(
|
7
15
|
self,
|
8
|
-
tokenizer,
|
9
|
-
texts,
|
10
|
-
labels=None,
|
16
|
+
tokenizer: AutoTokenizer,
|
17
|
+
texts: List[str],
|
18
|
+
labels:List[Union[str, int]]=None,
|
11
19
|
max_length=512,
|
12
20
|
remove_duplicates=False
|
13
21
|
):
|
22
|
+
|
23
|
+
max_length = get_max_length(tokenizer)
|
14
24
|
|
15
25
|
if remove_duplicates:
|
16
|
-
|
26
|
+
print(f'Dataset with {len(texts)} samples before removing duplicates')
|
17
27
|
texts_to_id = {text: i for i, text in enumerate(texts)}
|
18
28
|
texts = list(texts_to_id.keys())
|
19
29
|
labels = [labels[i] for i in texts_to_id.values()] if labels else None
|
20
30
|
|
21
|
-
# print(f'
|
31
|
+
# print(f'Encoding started with {len(texts)} samples')
|
22
32
|
|
23
33
|
self.inputs = tokenizer(
|
24
34
|
texts,
|
@@ -31,7 +41,8 @@ class EncodingDataset(Dataset):
|
|
31
41
|
if labels is not None:
|
32
42
|
self.inputs['labels'] = torch.tensor(labels, dtype=torch.long) if labels is not None else None
|
33
43
|
|
34
|
-
print("
|
44
|
+
# print("Embedding shape: ", self.inputs['input_ids'].shape)
|
45
|
+
# print("Encoding Dataset created with {} samples".format(len(self.inputs['input_ids'])))
|
35
46
|
# print("\n".join([f"Label: {l}, Text: {i}" for i, l in zip(texts, labels)]))
|
36
47
|
# import code; code.interact(local=locals())
|
37
48
|
|