unike-3.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unike/__init__.py +5 -0
- unike/config/HPOTrainer.py +305 -0
- unike/config/Tester.py +385 -0
- unike/config/Trainer.py +519 -0
- unike/config/TrainerAccelerator.py +39 -0
- unike/config/__init__.py +37 -0
- unike/data/BernSampler.py +168 -0
- unike/data/CompGCNSampler.py +140 -0
- unike/data/CompGCNTestSampler.py +84 -0
- unike/data/KGEDataLoader.py +315 -0
- unike/data/KGReader.py +138 -0
- unike/data/RGCNSampler.py +261 -0
- unike/data/RGCNTestSampler.py +208 -0
- unike/data/RevSampler.py +78 -0
- unike/data/TestSampler.py +189 -0
- unike/data/TradSampler.py +122 -0
- unike/data/TradTestSampler.py +87 -0
- unike/data/UniSampler.py +145 -0
- unike/data/__init__.py +47 -0
- unike/module/BaseModule.py +130 -0
- unike/module/__init__.py +20 -0
- unike/module/loss/CompGCNLoss.py +96 -0
- unike/module/loss/Loss.py +26 -0
- unike/module/loss/MarginLoss.py +148 -0
- unike/module/loss/RGCNLoss.py +117 -0
- unike/module/loss/SigmoidLoss.py +145 -0
- unike/module/loss/SoftplusLoss.py +145 -0
- unike/module/loss/__init__.py +35 -0
- unike/module/model/Analogy.py +237 -0
- unike/module/model/CompGCN.py +562 -0
- unike/module/model/ComplEx.py +235 -0
- unike/module/model/DistMult.py +276 -0
- unike/module/model/HolE.py +308 -0
- unike/module/model/Model.py +107 -0
- unike/module/model/RESCAL.py +309 -0
- unike/module/model/RGCN.py +304 -0
- unike/module/model/RotatE.py +303 -0
- unike/module/model/SimplE.py +237 -0
- unike/module/model/TransD.py +458 -0
- unike/module/model/TransE.py +290 -0
- unike/module/model/TransH.py +322 -0
- unike/module/model/TransR.py +402 -0
- unike/module/model/__init__.py +60 -0
- unike/module/strategy/CompGCNSampling.py +140 -0
- unike/module/strategy/NegativeSampling.py +138 -0
- unike/module/strategy/RGCNSampling.py +134 -0
- unike/module/strategy/Strategy.py +26 -0
- unike/module/strategy/__init__.py +29 -0
- unike/utils/EarlyStopping.py +94 -0
- unike/utils/Timer.py +74 -0
- unike/utils/WandbLogger.py +46 -0
- unike/utils/__init__.py +26 -0
- unike/utils/tools.py +118 -0
- unike/version.py +1 -0
- unike-3.0.1.dist-info/METADATA +101 -0
- unike-3.0.1.dist-info/RECORD +59 -0
- unike-3.0.1.dist-info/WHEEL +4 -0
- unike-3.0.1.dist-info/entry_points.txt +2 -0
- unike-3.0.1.dist-info/licenses/LICENSE +21 -0
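
The data samplers added in this release (KGReader, RevSampler, RGCNSampler, RGCNTestSampler, ...) all read an OpenKE-style dataset directory: every file begins with a count line, entity2id.txt and relation2id.txt hold tab-separated name/ID pairs, and the triple files hold one whitespace-separated "head tail relation" ID line per triple. A minimal sketch of building such a toy directory is shown below; the benchmarks/toy path and the entity/relation names are made up for illustration only.

import pathlib

# Hypothetical toy dataset in the layout that KGReader and the test samplers expect.
toy = pathlib.Path("./benchmarks/toy")
toy.mkdir(parents=True, exist_ok=True)
# entity2id.txt / relation2id.txt: first line is the count, then "<name>\t<id>".
(toy / "entity2id.txt").write_text("3\nalice\t0\nbob\t1\ncarol\t2\n")
(toy / "relation2id.txt").write_text("1\nknows\t0\n")
# train2id.txt / valid2id.txt / test2id.txt: first line is the count, then "<head> <tail> <relation>".
(toy / "train2id.txt").write_text("2\n0 1 0\n1 2 0\n")
(toy / "valid2id.txt").write_text("1\n0 2 0\n")
(toy / "test2id.txt").write_text("1\n2 0 0\n")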
unike/data/KGReader.py
ADDED
@@ -0,0 +1,138 @@
# coding:utf-8
#
# unike/data/KGReader.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 17, 2024
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 25, 2024
#
# Reads a knowledge graph from files.

"""
KGReader - reads a knowledge graph from files.
"""

import os
import numpy as np
import collections

class KGReader:

    """
    Reads a knowledge graph from files.
    """

    def __init__(
        self,
        in_path: str = "./",
        ent_file: str = "entity2id.txt",
        rel_file: str = "relation2id.txt",
        train_file: str = "train2id.txt"):

        """Create a KGReader object.

        :param in_path: dataset directory
        :type in_path: str
        :param ent_file: entity2id.txt
        :type ent_file: str
        :param rel_file: relation2id.txt
        :type rel_file: str
        :param train_file: train2id.txt
        :type train_file: str
        """

        #: dataset directory
        self.in_path: str = in_path
        #: entity2id.txt
        self.ent_file: str = ent_file
        #: relation2id.txt
        self.rel_file: str = rel_file
        #: train2id.txt
        self.train_file: str = train_file

        #: number of entities
        self.ent_tol: int = 0
        #: number of relations
        self.rel_tol: int = 0
        #: number of training triples
        self.train_tol: int = 0

        #: entity -> ID
        self.ent2id: dict = {}
        #: relation -> ID
        self.rel2id: dict = {}
        #: ID -> entity
        self.id2ent: dict = {}
        #: ID -> relation
        self.id2rel: dict = {}

        #: training triples
        self.train_triples: list[tuple[int, int, int]] = []

        #: for every h-r pair in the training set, the set of corresponding t
        self.hr2t_train: collections.defaultdict[set] = collections.defaultdict(set)
        #: for every r-t pair in the training set, the set of corresponding h
        self.rt2h_train: collections.defaultdict[set] = collections.defaultdict(set)

        self.get_id()
        self.get_train_triples_id()

    def get_id(self):

        """Read the :py:attr:`ent_file` and :py:attr:`rel_file` files."""

        with open(os.path.join(self.in_path, self.ent_file)) as f:
            self.ent_tol = int(f.readline())
            for line in f:
                entity, eid = line.strip().split("\t")
                self.ent2id[entity] = int(eid)
                self.id2ent[int(eid)] = entity

        with open(os.path.join(self.in_path, self.rel_file)) as f:
            self.rel_tol = int(f.readline())
            for line in f:
                relation, rid = line.strip().split("\t")
                self.rel2id[relation] = int(rid)
                self.id2rel[int(rid)] = relation

    def get_train_triples_id(self):

        """Read the :py:attr:`train_file` file."""

        with open(os.path.join(self.in_path, self.train_file)) as f:
            self.train_tol = int(f.readline())
            for line in f:
                h, t, r = line.strip().split()
                self.train_triples.append((int(h), int(r), int(t)))

    def get_hr2t_rt2h_from_train(self):

        """Build :py:attr:`hr2t_train` and :py:attr:`rt2h_train`."""

        for h, r, t in self.train_triples:
            self.hr2t_train[(h, r)].add(t)
            self.rt2h_train[(r, t)].add(h)
        for h, r in self.hr2t_train:
            self.hr2t_train[(h, r)] = np.array(list(self.hr2t_train[(h, r)]))
        for r, t in self.rt2h_train:
            self.rt2h_train[(r, t)] = np.array(list(self.rt2h_train[(r, t)]))

    def get_hr_train(self):

        """Used for ``CompGCN`` :cite:`CompGCN` training, since the composition operation of ``CompGCN`` :cite:`CompGCN` only needs head entities and relations.

        See :ref:`CompGCN <compgcn>` for more details.
        """

        self.t_triples = self.train_triples
        self.train_triples = [(hr, list(t)) for (hr, t) in self.hr2t_train.items()]

    def get_train(self) -> list[tuple[int, int, int]]:

        """
        Return the training triples.

        :returns: :py:attr:`train_triples`
        :rtype: list[tuple[int, int, int]]
        """

        return self.train_triples
unike/data/RGCNSampler.py
ADDED
@@ -0,0 +1,261 @@
# coding:utf-8
#
# unike/data/RGCNSampler.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2024
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 21, 2024
#
# Training data sampler for R-GCN.

"""
RGCNSampler - training data sampler for R-GCN.
"""

import dgl
import torch
import typing
import warnings
import numpy as np
from .RevSampler import RevSampler

warnings.filterwarnings("ignore")

class RGCNSampler(RevSampler):

    """Training data sampler for ``R-GCN`` :cite:`R-GCN`.

    Example::

        from unike.data import RGCNSampler, CompGCNSampler
        from torch.utils.data import DataLoader

        #: training data sampler
        train_sampler: typing.Union[typing.Type[RGCNSampler], typing.Type[CompGCNSampler]] = train_sampler(
            in_path=in_path,
            ent_file=ent_file,
            rel_file=rel_file,
            train_file=train_file,
            batch_size=batch_size,
            neg_ent=neg_ent
        )

        #: training triples
        data_train: list[tuple[int, int, int]] = train_sampler.get_train()

        train_dataloader = DataLoader(
            data_train,
            shuffle=True,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            collate_fn=train_sampler.sampling,
        )
    """

    def __init__(
        self,
        in_path: str = "./",
        ent_file: str = "entity2id.txt",
        rel_file: str = "relation2id.txt",
        train_file: str = "train2id.txt",
        batch_size: int | None = None,
        neg_ent: int = 1):

        """Create an RGCNSampler object.

        :param in_path: dataset directory
        :type in_path: str
        :param ent_file: entity2id.txt
        :type ent_file: str
        :param rel_file: relation2id.txt
        :type rel_file: str
        :param train_file: train2id.txt
        :type train_file: str
        :param batch_size: batch size
        :type batch_size: int | None
        :param neg_ent: number of negative triples built for each positive triple by replacing an entity (head + tail)
        :type neg_ent: int
        """

        super().__init__(
            in_path=in_path,
            ent_file=ent_file,
            rel_file=rel_file,
            train_file=train_file
        )

        #: batch size
        self.batch_size: int = batch_size
        #: number of negative triples built for each positive triple by replacing an entity (head + tail)
        self.neg_ent: int = neg_ent

        self.entity = None
        self.triples = None
        self.label = None
        self.graph = None
        self.relation = None
        self.norm = None

    def sampling(
        self,
        pos_triples: list[tuple[int, int, int]]) -> dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]:

        """Sampling function for ``R-GCN`` :cite:`R-GCN`.

        :param pos_triples: correct triples from the knowledge graph
        :type pos_triples: list[tuple[int, int, int]]
        :returns: training data for ``R-GCN`` :cite:`R-GCN`
        :rtype: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
        """

        batch_data = {}

        pos_triples = np.array(pos_triples)
        pos_triples, self.entity = self.sampling_positive(pos_triples)
        head_triples = self.sampling_negative('head', pos_triples)
        tail_triples = self.sampling_negative('tail', pos_triples)
        self.triples = np.concatenate((pos_triples, head_triples, tail_triples))
        batch_data['entity'] = self.entity
        batch_data['triples'] = torch.from_numpy(self.triples)

        self.label = torch.zeros((len(self.triples), 1))
        self.label[0 : self.batch_size] = 1
        batch_data['label'] = self.label

        split_size = int(self.batch_size * 0.5)
        graph_split_ids = np.random.choice(
            self.batch_size,
            size=split_size,
            replace=False
        )
        head, rela, tail = pos_triples.transpose()
        head = torch.tensor(head[graph_split_ids], dtype=torch.long).contiguous()
        rela = torch.tensor(rela[graph_split_ids], dtype=torch.long).contiguous()
        tail = torch.tensor(tail[graph_split_ids], dtype=torch.long).contiguous()
        self.graph, self.relation, self.norm = self.build_graph(len(self.entity), (head, rela, tail), -1)
        batch_data['graph'] = self.graph
        batch_data['relation'] = self.relation
        batch_data['norm'] = self.norm

        return batch_data

    def sampling_positive(
        self,
        positive_triples: list[tuple[int, int, int]]) -> tuple[np.ndarray, torch.Tensor]:

        """Resample a subset of triples for building the subgraph and relabel the entity IDs.

        :param positive_triples: correct triples from the knowledge graph
        :type positive_triples: list[tuple[int, int, int]]
        :returns: the triple subset and the original entity IDs
        :rtype: tuple[numpy.ndarray, torch.Tensor]
        """

        edges = np.random.choice(
            np.arange(len(positive_triples)),
            size=self.batch_size,
            replace=False
        )
        edges = positive_triples[edges]
        head, rela, tail = np.array(edges).transpose()
        entity, index = np.unique((head, tail), return_inverse=True)
        head, tail = np.reshape(index, (2, -1))

        return np.stack((head, rela, tail)).transpose(), \
            torch.from_numpy(entity).view(-1, 1).long()

    def sampling_negative(
        self,
        mode: str,
        pos_triples: list[tuple[int, int, int]]) -> np.ndarray:

        """Sample negative triples.

        :param mode: 'head' or 'tail'
        :type mode: str
        :param pos_triples: correct triples from the knowledge graph
        :type pos_triples: list[tuple[int, int, int]]
        :returns: negative triples
        :rtype: numpy.ndarray
        """

        neg_random = np.random.choice(
            len(self.entity),
            size=self.neg_ent * len(pos_triples)
        )
        neg_samples = np.tile(pos_triples, (self.neg_ent, 1))
        if mode == 'head':
            neg_samples[:, 0] = neg_random
        elif mode == 'tail':
            neg_samples[:, 2] = neg_random
        return neg_samples

    def build_graph(
        self,
        num_ent: int,
        triples: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        power: int = -1) -> tuple[dgl.DGLGraph, torch.Tensor, torch.Tensor]:

        """Build the subgraph.

        :param num_ent: number of nodes in the subgraph
        :type num_ent: int
        :param triples: subset of correct triples from the knowledge graph
        :type triples: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        :param power: exponent
        :type power: int
        :returns: the subgraph, the relations, and the normalization coefficients of the edges
        :rtype: tuple[dgl.DGLGraph, torch.Tensor, torch.Tensor]
        """

        head, rela, tail = triples[0], triples[1], triples[2]
        graph = dgl.graph(([], []))
        graph.add_nodes(num_ent)
        graph.add_edges(head, tail)
        node_norm = self.comp_deg_norm(graph, power)
        edge_norm = self.node_norm_to_edge_norm(graph, node_norm)
        rela = torch.tensor(rela)
        return graph, rela, edge_norm

    def comp_deg_norm(
        self,
        graph: dgl.DGLGraph,
        power: int = -1) -> torch.Tensor:

        """Compute the normalization coefficient of each target node from its in-degree.

        :param graph: subgraph
        :type graph: dgl.DGLGraph
        :param power: exponent
        :type power: int
        :returns: normalization coefficients of the nodes
        :rtype: torch.Tensor
        """

        graph = graph.local_var()
        in_deg = graph.in_degrees(range(graph.number_of_nodes())).float().numpy()
        norm = in_deg.__pow__(power)
        norm[np.isinf(norm)] = 0
        return torch.from_numpy(norm)

    def node_norm_to_edge_norm(
        self,
        graph: dgl.DGLGraph,
        node_norm: torch.Tensor) -> torch.Tensor:

        """Compute the normalization coefficient of each edge from the degree of its target node.

        :param graph: subgraph
        :type graph: dgl.DGLGraph
        :param node_norm: normalization coefficients of the nodes
        :type node_norm: torch.Tensor
        :returns: normalization coefficients of the edges
        :rtype: torch.Tensor
        """

        graph = graph.local_var()
        # convert to edge norm
        graph.ndata['norm'] = node_norm.view(-1, 1)
        graph.apply_edges(lambda edges: {'norm': edges.dst['norm']})
        return graph.edata['norm']
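
The entity relabelling inside sampling_positive is easy to miss: np.unique(..., return_inverse=True) compresses the global entity IDs of the sampled edges to a contiguous 0..N-1 range for the DGL subgraph, while the returned entity tensor keeps the original IDs for the embedding lookup. A standalone illustration with made-up IDs (not tied to any real dataset):

import numpy as np

# Two sampled edges whose heads/tails use global entity IDs 7, 12 and 40.
head = np.array([7, 12])
rela = np.array([0, 3])
tail = np.array([12, 40])

entity, index = np.unique((head, tail), return_inverse=True)
head, tail = np.reshape(index, (2, -1))

print(entity)                                    # [ 7 12 40] -- original IDs, kept for embedding lookup
print(np.stack((head, rela, tail)).transpose())  # [[0 0 1] [1 3 2]] -- triples over local node IDs 0..2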
unike/data/RGCNTestSampler.py
ADDED
@@ -0,0 +1,208 @@
# coding:utf-8
#
# unike/data/RGCNTestSampler.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2024
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 29, 2024
#
# Test data sampler for R-GCN.

"""
RGCNTestSampler - test data sampler for R-GCN.
"""

import os
import dgl
import torch
import typing
import numpy as np
from .TestSampler import TestSampler
from .RGCNSampler import RGCNSampler
from .CompGCNSampler import CompGCNSampler
from typing_extensions import override

class RGCNTestSampler(TestSampler):

    """Test data sampler for ``R-GCN`` :cite:`R-GCN`.

    Example::

        from unike.data import RGCNTestSampler, CompGCNTestSampler
        from torch.utils.data import DataLoader

        #: test data sampler
        test_sampler: typing.Union[typing.Type[RGCNTestSampler], typing.Type[CompGCNTestSampler]] = test_sampler(
            sampler=train_sampler,
            valid_file=valid_file,
            test_file=test_file,
        )

        #: validation triples
        data_val: list[tuple[int, int, int]] = test_sampler.get_valid()
        #: test triples
        data_test: list[tuple[int, int, int]] = test_sampler.get_test()

        val_dataloader = DataLoader(
            data_val,
            shuffle=False,
            batch_size=test_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            collate_fn=test_sampler.sampling,
        )

        test_dataloader = DataLoader(
            data_test,
            shuffle=False,
            batch_size=test_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            collate_fn=test_sampler.sampling,
        )
    """

    def __init__(
        self,
        sampler: typing.Union[RGCNSampler, CompGCNSampler],
        valid_file: str = "valid2id.txt",
        test_file: str = "test2id.txt",
        type_constrain: bool = True):

        """Create an RGCNTestSampler object.

        :param sampler: training data sampler.
        :type sampler: typing.Union[RGCNSampler, CompGCNSampler]
        :param valid_file: valid2id.txt
        :type valid_file: str
        :param test_file: test2id.txt
        :type test_file: str
        :param type_constrain: whether to also report test results restricted by type_constrain.txt
        :type type_constrain: bool
        """

        super().__init__(
            sampler=sampler,
            valid_file=valid_file,
            test_file=test_file,
            type_constrain=type_constrain
        )

        #: training triples
        self.triples: list[tuple[int, int, int]] = self.sampler.t_triples if isinstance(self.sampler, CompGCNSampler) else self.sampler.train_triples
        #: exponent
        self.power: float = -1

        self.add_valid_test_reverse_triples()
        self.get_hr2t_rt2h_from_all()

    @override
    def get_valid_test_triples_id(self):

        """Read the :py:attr:`valid_file` and :py:attr:`test_file` files."""

        with open(os.path.join(self.sampler.in_path, self.valid_file)) as f:
            self.valid_tol = int(f.readline())
            for line in f:
                h, t, r = line.strip().split()
                self.valid_triples.append((int(h), int(r), int(t)))

        with open(os.path.join(self.sampler.in_path, self.test_file)) as f:
            self.test_tol = int(f.readline())
            for line in f:
                h, t, r = line.strip().split()
                self.test_triples.append((int(h), int(r), int(t)))

    def add_valid_test_reverse_triples(self):

        """For every triple (h, r, t), generate the reverse triple (t, r`, h) with r` = r + rel_tol."""

        tol = int(self.sampler.rel_tol / 2)

        with open(os.path.join(self.sampler.in_path, self.valid_file)) as f:
            f.readline()
            for line in f:
                h, t, r = line.strip().split()
                self.valid_triples.append(
                    (int(t), int(r) + tol, int(h))
                )

        with open(os.path.join(self.sampler.in_path, self.test_file)) as f:
            f.readline()
            for line in f:
                h, t, r = line.strip().split()
                self.test_triples.append(
                    (int(t), int(r) + tol, int(h))
                )

        self.all_true_triples = set(
            self.triples + self.valid_triples + self.test_triples
        )

    @override
    def get_type_constrain_id(self):

        """Read the type_constrain.txt file."""

        tol = int(self.sampler.rel_tol / 2)

        with open(os.path.join(self.sampler.in_path, "type_constrain.txt")) as f:
            rel_tol = int(f.readline())
            first_line = True
            for line in f:
                rel_types = line.strip().split("\t")
                for entity in rel_types[2:]:
                    if first_line:
                        self.rel_heads[int(rel_types[0])].add(int(entity))
                        self.rel_tails[int(rel_types[0]) + tol].add(int(entity))
                    else:
                        self.rel_tails[int(rel_types[0])].add(int(entity))
                        self.rel_heads[int(rel_types[0]) + tol].add(int(entity))
                first_line = not first_line

        for rel in self.rel_heads:
            self.rel_heads[rel] = torch.tensor(list(self.rel_heads[rel]))
        for rel in self.rel_tails:
            self.rel_tails[rel] = torch.tensor(list(self.rel_tails[rel]))

    @override
    def sampling(
        self,
        data: list[tuple[int, int, int]]) -> dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]:

        """Test data sampling function for ``R-GCN`` :cite:`R-GCN`.

        :param data: correct triples to evaluate
        :type data: list[tuple[int, int, int]]
        :returns: test data for ``R-GCN`` :cite:`R-GCN`
        :rtype: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
        """

        batch_data = {}
        head_label = torch.zeros(len(data), self.ent_tol)
        tail_label = torch.zeros(len(data), self.ent_tol)
        for idx, triple in enumerate(data):
            head, rel, tail = triple
            head_label[idx][self.rt2h_all[(rel, tail)]] = 1.0
            tail_label[idx][self.hr2t_all[(head, rel)]] = 1.0
        if self.type_constrain:
            head_label_type = torch.ones(len(data), self.ent_tol)
            tail_label_type = torch.ones(len(data), self.ent_tol)
            for idx, triple in enumerate(data):
                head, rel, tail = triple
                head_label_type[idx][self.rel_heads[rel]] = 0.0
                tail_label_type[idx][self.rel_tails[rel]] = 0.0
                head_label_type[idx][self.rt2h_all[(rel, tail)]] = 1.0
                tail_label_type[idx][self.hr2t_all[(head, rel)]] = 1.0
            batch_data["head_label_type"] = head_label_type
            batch_data["tail_label_type"] = tail_label_type
        batch_data["positive_sample"] = torch.tensor(data)
        batch_data["head_label"] = head_label
        batch_data["tail_label"] = tail_label

        graph, rela, norm = self.sampler.build_graph(self.ent_tol, np.array(self.triples).transpose(), self.power)
        batch_data["graph"] = graph
        batch_data["rela"] = rela
        batch_data["norm"] = norm
        batch_data["entity"] = torch.arange(0, self.ent_tol, dtype=torch.long).view(-1, 1)

        return batch_data
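
head_label and tail_label are multi-hot masks over all entities: for each evaluated triple they mark every entity known to complete it anywhere in train/valid/test (via rt2h_all / hr2t_all). The sketch below shows one common way such masks are used for filtered ranking; it illustrates the idea only and is not necessarily how this package's Tester consumes them:

import torch

# 5 candidate tails; entity 3 is the gold tail, entity 4 is another known true tail for the same (h, r).
scores = torch.tensor([0.1, 0.9, 0.3, 0.8, 0.95])  # model scores for every candidate tail
tail_label = torch.tensor([0., 0., 0., 1., 1.])    # mask built from hr2t_all[(h, r)]
gold = 3

filtered = scores - tail_label * 10000.0  # push every known true tail to the bottom ...
filtered[gold] = scores[gold]             # ... except the gold one currently being ranked
rank = int((filtered > filtered[gold]).sum()) + 1
print(rank)                               # 2 -- entity 1 still outranks the gold tail, entity 4 is filtered out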
unike/data/RevSampler.py
ADDED
@@ -0,0 +1,78 @@
# coding:utf-8
#
# unike/data/RevSampler.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 15, 2024
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2024
#
# Adds reverse relations to KGReader, used by the graph neural network models.

"""
RevSampler - adds reverse relations to KGReader, used by the graph neural network models.
"""

import os
from .KGReader import KGReader

class RevSampler(KGReader):

    """Adds reverse relations to the training set.

    For every triple (h, r, t), generates the reverse triple (t, r`, h) with r` = r + rel_tol.
    """

    def __init__(
        self,
        in_path: str = "./",
        ent_file: str = "entity2id.txt",
        rel_file: str = "relation2id.txt",
        train_file: str = "train2id.txt"):

        """Create a RevSampler object.

        :param in_path: dataset directory
        :type in_path: str
        :param ent_file: entity2id.txt
        :type ent_file: str
        :param rel_file: relation2id.txt
        :type rel_file: str
        :param train_file: train2id.txt
        :type train_file: str
        """

        super().__init__(
            in_path=in_path,
            ent_file=ent_file,
            rel_file=rel_file,
            train_file=train_file
        )

        self.add_reverse_relation()
        self.add_train_reverse_triples()
        self.get_hr2t_rt2h_from_train()

    def add_reverse_relation(self):

        """Add reverse relations: r` = r + rel_tol"""

        with open(os.path.join(self.in_path, self.rel_file)) as f:
            f.readline()
            for line in f:
                relation, rid = line.strip().split("\t")
                self.rel2id[relation + "_reverse"] = int(rid) + self.rel_tol
                self.id2rel[int(rid) + self.rel_tol] = relation + "_reverse"
            self.rel_tol = len(self.rel2id)

    def add_train_reverse_triples(self):

        """For every triple (h, r, t), generate the reverse triple (t, r`, h) with r` = r + rel_tol."""

        tol = int(self.rel_tol / 2)

        with open(os.path.join(self.in_path, self.train_file)) as f:
            f.readline()
            for line in f:
                h, t, r = line.strip().split()
                self.train_triples.append(
                    (int(t), int(r) + tol, int(h))
                )
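
Combining RevSampler with the toy directory sketched earlier makes the reverse-relation convention concrete: the relation vocabulary doubles and every training triple gains a mirrored copy. The values below assume that toy data and are illustrative only:

from unike.data.RevSampler import RevSampler

sampler = RevSampler(in_path="./benchmarks/toy")
print(sampler.rel_tol)        # 2 -- "knows" plus "knows_reverse"
print(sampler.rel2id)         # {'knows': 0, 'knows_reverse': 1}
print(sampler.train_triples)  # [(0, 0, 1), (1, 0, 2), (1, 1, 0), (2, 1, 1)]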