unike 3.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unike/__init__.py +5 -0
- unike/config/HPOTrainer.py +305 -0
- unike/config/Tester.py +385 -0
- unike/config/Trainer.py +519 -0
- unike/config/TrainerAccelerator.py +39 -0
- unike/config/__init__.py +37 -0
- unike/data/BernSampler.py +168 -0
- unike/data/CompGCNSampler.py +140 -0
- unike/data/CompGCNTestSampler.py +84 -0
- unike/data/KGEDataLoader.py +315 -0
- unike/data/KGReader.py +138 -0
- unike/data/RGCNSampler.py +261 -0
- unike/data/RGCNTestSampler.py +208 -0
- unike/data/RevSampler.py +78 -0
- unike/data/TestSampler.py +189 -0
- unike/data/TradSampler.py +122 -0
- unike/data/TradTestSampler.py +87 -0
- unike/data/UniSampler.py +145 -0
- unike/data/__init__.py +47 -0
- unike/module/BaseModule.py +130 -0
- unike/module/__init__.py +20 -0
- unike/module/loss/CompGCNLoss.py +96 -0
- unike/module/loss/Loss.py +26 -0
- unike/module/loss/MarginLoss.py +148 -0
- unike/module/loss/RGCNLoss.py +117 -0
- unike/module/loss/SigmoidLoss.py +145 -0
- unike/module/loss/SoftplusLoss.py +145 -0
- unike/module/loss/__init__.py +35 -0
- unike/module/model/Analogy.py +237 -0
- unike/module/model/CompGCN.py +562 -0
- unike/module/model/ComplEx.py +235 -0
- unike/module/model/DistMult.py +276 -0
- unike/module/model/HolE.py +308 -0
- unike/module/model/Model.py +107 -0
- unike/module/model/RESCAL.py +309 -0
- unike/module/model/RGCN.py +304 -0
- unike/module/model/RotatE.py +303 -0
- unike/module/model/SimplE.py +237 -0
- unike/module/model/TransD.py +458 -0
- unike/module/model/TransE.py +290 -0
- unike/module/model/TransH.py +322 -0
- unike/module/model/TransR.py +402 -0
- unike/module/model/__init__.py +60 -0
- unike/module/strategy/CompGCNSampling.py +140 -0
- unike/module/strategy/NegativeSampling.py +138 -0
- unike/module/strategy/RGCNSampling.py +134 -0
- unike/module/strategy/Strategy.py +26 -0
- unike/module/strategy/__init__.py +29 -0
- unike/utils/EarlyStopping.py +94 -0
- unike/utils/Timer.py +74 -0
- unike/utils/WandbLogger.py +46 -0
- unike/utils/__init__.py +26 -0
- unike/utils/tools.py +118 -0
- unike/version.py +1 -0
- unike-3.0.1.dist-info/METADATA +101 -0
- unike-3.0.1.dist-info/RECORD +59 -0
- unike-3.0.1.dist-info/WHEEL +4 -0
- unike-3.0.1.dist-info/entry_points.txt +2 -0
- unike-3.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,304 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/module/model/RGCN.py
|
4
|
+
#
|
5
|
+
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2024
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 22, 2024
|
7
|
+
#
|
8
|
+
# 该头文件定义了 R-GCN.
|
9
|
+
|
10
|
+
"""
|
11
|
+
R-GCN - 第一个图神经网络模型。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import dgl
|
15
|
+
import torch
|
16
|
+
import typing
|
17
|
+
import torch.nn as nn
|
18
|
+
import torch.nn.functional as F
|
19
|
+
from .Model import Model
|
20
|
+
from dgl.nn.pytorch import RelGraphConv
|
21
|
+
from typing_extensions import override
|
22
|
+
|
23
|
+
class RGCN(Model):

    """
    ``R-GCN`` :cite:`R-GCN`, proposed in 2017, was the first graph neural
    network model used for knowledge graph embedding.

    Higher scores indicate more plausible positive triples; for more details
    see :ref:`R-GCN <rgcn>`.

    Example::

        from unike.data import GraphDataLoader
        from unike.module.model import RGCN
        from unike.module.loss import RGCNLoss
        from unike.module.strategy import RGCNSampling
        from unike.config import Trainer, GraphTester

        dataloader = GraphDataLoader(
            in_path = "../../benchmarks/FB15K237/",
            batch_size = 60000,
            neg_ent = 10,
            test = True,
            test_batch_size = 100,
            num_workers = 16
        )

        # define the model
        rgcn = RGCN(
            ent_tol = dataloader.get_ent_tol(),
            rel_tol = dataloader.get_rel_tol(),
            dim = 500,
            num_layers = 2
        )

        # define the loss function
        model = RGCNSampling(
            model = rgcn,
            loss = RGCNLoss(model = rgcn, regularization = 1e-5)
        )

        # test the model
        tester = GraphTester(model = rgcn, data_loader = dataloader, use_gpu = True, device = 'cuda:0')

        # train the model
        trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(),
            epochs = 10000, lr = 0.0001, use_gpu = True, device = 'cuda:0',
            tester = tester, test = True, valid_interval = 500, log_interval = 500,
            save_interval = 500, save_path = '../../checkpoint/rgcn.pth'
        )
        trainer.run()
    """

    def __init__(
        self,
        ent_tol: int,
        rel_tol: int,
        dim: int,
        num_layers: int):

        """Create an RGCN object.

        :param ent_tol: number of entities
        :type ent_tol: int
        :param rel_tol: number of relations
        :type rel_tol: int
        :param dim: dimension of the entity and relation embedding vectors
        :type dim: int
        :param num_layers: number of graph neural network layers
        :type num_layers: int
        """

        super().__init__(ent_tol, rel_tol)

        #: dimension of the entity and relation embedding vectors
        self.dim: int = dim
        #: number of graph neural network layers
        self.num_layers: int = num_layers

        #: entity embeddings, created from the number of entities
        self.ent_emb: torch.nn.Embedding = None
        #: relation embeddings, created from the number of relations
        self.rel_emb: torch.nn.parameter.Parameter = None
        #: the stack of R-GCN graph convolution layers
        self.RGCN: torch.nn.ModuleList = None
        #: output of the last graph convolution layer, cached for the loss
        self.Loss_emb: torch.nn.Embedding = None

        self.build_model()

    def build_model(self):

        """Build the model: embeddings and the stack of R-GCN layers."""

        self.ent_emb = nn.Embedding(self.ent_tol, self.dim)

        self.rel_emb = nn.Parameter(torch.Tensor(self.rel_tol, self.dim))

        nn.init.xavier_uniform_(self.rel_emb, gain=nn.init.calculate_gain('relu'))

        self.RGCN = nn.ModuleList()
        for idx in range(self.num_layers):
            RGCN_idx = self.build_hidden_layer(idx)
            self.RGCN.append(RGCN_idx)

    def build_hidden_layer(
        self,
        idx: int) -> dgl.nn.pytorch.conv.RelGraphConv:

        """Return the ``idx``-th graph neural network layer.

        Every layer except the last uses a ReLU activation.

        :param idx: index of the layer to build
        :type idx: int
        :returns: the graph neural network layer
        :rtype: dgl.nn.pytorch.conv.RelGraphConv
        """

        act = F.relu if idx < self.num_layers - 1 else None
        return RelGraphConv(self.dim, self.dim, self.rel_tol, "bdd",
                    num_bases=100, activation=act, self_loop=True, dropout=0.2)

    def _propagate(
        self,
        graph: dgl.DGLGraph,
        ent: torch.Tensor,
        rel: torch.Tensor,
        norm: torch.Tensor) -> torch.Tensor:

        """Propagate the entity embeddings through every R-GCN layer.

        Shared by :py:meth:`forward` and :py:meth:`predict`; also caches the
        result in :py:attr:`Loss_emb` for the loss function.

        :param graph: subgraph
        :type graph: dgl.DGLGraph
        :param ent: entities of the subgraph
        :type ent: torch.Tensor
        :param rel: relations of the subgraph
        :type rel: torch.Tensor
        :param norm: normalization coefficients of the relations
        :type norm: torch.Tensor
        :returns: the updated entity embeddings
        :rtype: torch.Tensor
        """

        embedding = self.ent_emb(ent.squeeze())
        for layer in self.RGCN:
            embedding = layer(graph, embedding, rel, norm)
        self.Loss_emb = embedding
        return embedding

    @override
    def forward(
        self,
        graph: dgl.DGLGraph,
        ent: torch.Tensor,
        rel: torch.Tensor,
        norm: torch.Tensor,
        triples: torch.Tensor,
        mode: str = 'single') -> torch.Tensor:

        """
        Defines the computation performed at every call.
        Subclasses of :py:class:`torch.nn.Module` must override
        :py:meth:`torch.nn.Module.forward`.

        :param graph: subgraph
        :type graph: dgl.DGLGraph
        :param ent: entities of the subgraph
        :type ent: torch.Tensor
        :param rel: relations of the subgraph
        :type rel: torch.Tensor
        :param norm: normalization coefficients of the relations
        :type norm: torch.Tensor
        :param triples: triples
        :type triples: torch.Tensor
        :param mode: mode
        :type mode: str
        :returns: scores of the triples
        :rtype: torch.Tensor
        """

        embedding = self._propagate(graph, ent, rel, norm)
        head_emb, rela_emb, tail_emb = self.tri2emb(embedding, triples, mode)
        score = self.distmult_score_func(head_emb, rela_emb, tail_emb, mode)

        return score

    def tri2emb(
        self,
        embedding: torch.Tensor,
        triples: torch.Tensor,
        mode: str = "single") -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

        """
        Get the embedding vectors of the head entity, relation and tail
        entity of each triple.

        :param embedding: entity embeddings updated by the graph neural network
        :type embedding: torch.Tensor
        :param triples: training triples
        :type triples: torch.Tensor
        :param mode: mode
        :type mode: str
        :returns: embedding vectors of the head entities, relations and tail entities
        :rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        """

        rela_emb = self.rel_emb[triples[:, 1]].unsqueeze(1)  # [bs, 1, dim]
        head_emb = embedding[triples[:, 0]].unsqueeze(1)  # [bs, 1, dim]
        tail_emb = embedding[triples[:, 2]].unsqueeze(1)  # [bs, 1, dim]

        # In prediction modes, score against every entity at once.
        if mode == "head-batch" or mode == "head_predict":
            head_emb = embedding.unsqueeze(0)  # [1, num_ent, dim]

        elif mode == "tail-batch" or mode == "tail_predict":
            tail_emb = embedding.unsqueeze(0)  # [1, num_ent, dim]

        return head_emb, rela_emb, tail_emb

    def distmult_score_func(
        self,
        head_emb: torch.Tensor,
        relation_emb: torch.Tensor,
        tail_emb: torch.Tensor,
        mode: str) -> torch.Tensor:

        """
        Compute the DistMult scoring function.

        :param head_emb: head entity embedding vectors
        :type head_emb: torch.Tensor
        :param relation_emb: relation embedding vectors
        :type relation_emb: torch.Tensor
        :param tail_emb: tail entity embedding vectors
        :type tail_emb: torch.Tensor
        :param mode: mode; controls the association order of the product
        :type mode: str
        :returns: scores of the triples
        :rtype: torch.Tensor
        """

        if mode == 'head-batch':
            score = head_emb * (relation_emb * tail_emb)
        else:
            score = (head_emb * relation_emb) * tail_emb

        score = score.sum(dim = -1)
        return score

    @override
    def predict(
        self,
        data: dict[str, torch.Tensor],
        mode: str) -> torch.Tensor:

        """Inference method of R-GCN.

        :param data: data
        :type data: dict[str, torch.Tensor]
        :param mode: mode
        :type mode: str
        :returns: scores of the triples
        :rtype: torch.Tensor
        """

        triples = data['positive_sample']
        graph = data['graph']
        ent = data['entity']
        rel = data['rela']
        norm = data['norm']

        embedding = self._propagate(graph, ent, rel, norm)
        head_emb, rela_emb, tail_emb = self.tri2emb(embedding, triples, mode)
        score = self.distmult_score_func(head_emb, rela_emb, tail_emb, mode)

        return score
def get_rgcn_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyper-parameter optimization configuration of :py:class:`RGCN`.

    The default configuration is::

        parameters_dict = {
            'model': {
                'value': 'RGCN'
            },
            'dim': {
                'values': [200, 300, 400]
            },
            'num_layers': {
                'value': 2
            }
        }

    :returns: the default hyper-parameter optimization configuration of :py:class:`RGCN`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    return {
        'model': {'value': 'RGCN'},
        'dim': {'values': [200, 300, 400]},
        'num_layers': {'value': 2},
    }
@@ -0,0 +1,303 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/module/model/RotatE.py
|
4
|
+
#
|
5
|
+
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 11, 2024
|
7
|
+
#
|
8
|
+
# 该头文件定义了 RotatE.
|
9
|
+
|
10
|
+
"""
|
11
|
+
RotatE - 将实体表示成复数向量,关系建模为复数向量空间的旋转。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import torch
|
15
|
+
import typing
|
16
|
+
import torch.nn as nn
|
17
|
+
from .Model import Model
|
18
|
+
from typing_extensions import override
|
19
|
+
|
20
|
+
class RotatE(Model):

    """
    ``RotatE`` :cite:`RotatE`, proposed in 2019, represents entities as
    complex vectors and models relations as rotations in complex vector
    space.

    The scoring function is:

    .. math::

        \gamma - \parallel \mathbf{h} \circ \mathbf{r} - \mathbf{t} \parallel_{L_2}

    :math:`\circ` denotes the Hadamard product. Higher scores indicate more
    plausible positive triples; for more details see :ref:`RotatE <rotate>`.

    Example::

        from unike.data import KGEDataLoader, UniSampler, TradTestSampler
        from unike.module.model import RotatE
        from unike.module.loss import SigmoidLoss
        from unike.module.strategy import NegativeSampling
        from unike.config import Trainer, Tester

        # dataloader for training
        dataloader = KGEDataLoader(
            in_path = '../../benchmarks/WN18RR/',
            batch_size = 2000,
            neg_ent = 64,
            test = True,
            test_batch_size = 10,
            num_workers = 16,
            train_sampler = UniSampler,
            test_sampler = TradTestSampler
        )

        # define the model
        rotate = RotatE(
            ent_tol = dataloader.get_ent_tol(),
            rel_tol = dataloader.get_rel_tol(),
            dim = 1024,
            margin = 6.0,
            epsilon = 2.0,
        )

        # define the loss function
        model = NegativeSampling(
            model = rotate,
            loss = SigmoidLoss(adv_temperature = 2),
            regul_rate = 0.0,
        )

        # test the model
        tester = Tester(model = rotate, data_loader = dataloader, use_gpu = True, device = 'cuda:1')

        # train the model
        trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(), epochs = 6000,
            lr = 2e-5, opt_method = 'adam', use_gpu = True, device = 'cuda:1',
            tester = tester, test = True, valid_interval = 100,
            log_interval = 100, save_interval = 100,
            save_path = '../../checkpoint/rotate.pth', use_wandb = False)
        trainer.run()
    """

    def __init__(
        self,
        ent_tol: int,
        rel_tol: int,
        dim: int = 100,
        margin: float = 6.0,
        epsilon: float = 2.0):

        """Create a RotatE object.

        :param ent_tol: number of entities
        :type ent_tol: int
        :param rel_tol: number of relations
        :type rel_tol: int
        :param dim: dimension of the entity and relation embedding vectors
        :type dim: int
        :param margin: the gamma of the loss function in the original paper.
        :type margin: float
        :param epsilon: fixed to 2.0 in the source code of the original RotatE paper.
        :type epsilon: float
        """

        super().__init__(ent_tol, rel_tol)

        #: fixed to 2.0 in the source code of the original RotatE paper.
        self.epsilon: float = epsilon

        #: The original RotatE implementation sets the entity embedding
        #: dimension to 2 * ``dim`` because each entity embedding is split
        #: into a real part and an imaginary part.
        self.dim_e: int = dim * 2
        #: dimension of the relation embedding vectors, equal to ``dim``.
        self.dim_r: int = dim

        #: entity embeddings, created from the number of entities.
        self.ent_embeddings: torch.nn.Embedding = nn.Embedding(self.ent_tol, self.dim_e)
        #: relation embeddings, created from the number of relations.
        self.rel_embeddings: torch.nn.Embedding = nn.Embedding(self.rel_tol, self.dim_r)

        self.ent_embedding_range = nn.Parameter(
            torch.Tensor([(margin + self.epsilon) / self.dim_e]),
            requires_grad=False
        )

        nn.init.uniform_(
            tensor = self.ent_embeddings.weight.data,
            a=-self.ent_embedding_range.item(),
            b=self.ent_embedding_range.item()
        )

        self.rel_embedding_range = nn.Parameter(
            torch.Tensor([(margin + self.epsilon) / self.dim_r]),
            requires_grad=False
        )

        nn.init.uniform_(
            tensor = self.rel_embeddings.weight.data,
            a=-self.rel_embedding_range.item(),
            b=self.rel_embedding_range.item()
        )

        #: the gamma of the loss function in the original paper.
        self.margin: torch.nn.parameter.Parameter = nn.Parameter(
            torch.Tensor([margin]), requires_grad=False)

    @override
    def forward(
        self,
        triples: torch.Tensor,
        negs: torch.Tensor = None,
        mode: str = 'single') -> torch.Tensor:

        """
        Defines the computation performed at every call.
        Subclasses of :py:class:`torch.nn.Module` must override
        :py:meth:`torch.nn.Module.forward`.

        :param triples: positive triples
        :type triples: torch.Tensor
        :param negs: negative-triple candidates
        :type negs: torch.Tensor
        :param mode: mode
        :type mode: str
        :returns: scores of the triples
        :rtype: torch.Tensor
        """

        head_emb, relation_emb, tail_emb = self.tri2emb(triples, negs, mode)
        score = self.margin - self._calc(head_emb, relation_emb, tail_emb)
        return score

    def _calc(
        self,
        h: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor) -> torch.Tensor:

        """Compute the scoring function of RotatE.

        Entity embeddings are split into real and imaginary parts with
        :py:func:`torch.chunk`. The original paper uses the L1-norm as the
        distance function, whereas the L2-norm is used here.

        :param h: head entity vectors.
        :type h: torch.Tensor
        :param r: relation vectors.
        :type r: torch.Tensor
        :param t: tail entity vectors.
        :type t: torch.Tensor
        :returns: scores of the triples
        :rtype: torch.Tensor
        """

        pi = self.pi_const

        re_head, im_head = torch.chunk(h, 2, dim=-1)
        re_tail, im_tail = torch.chunk(t, 2, dim=-1)

        # Make phases of relations uniformly distributed in [-pi, pi]
        phase_relation = r / (self.rel_embedding_range.item() / pi)

        re_relation = torch.cos(phase_relation)
        im_relation = torch.sin(phase_relation)

        # Complex multiplication (h * r) followed by subtraction of t.
        re_score = re_head * re_relation - im_head * im_relation
        im_score = re_head * im_relation + im_head * re_relation
        re_score = re_score - re_tail
        im_score = im_score - im_tail

        score = torch.stack([re_score, im_score], dim = 0)
        score = score.norm(dim = 0).sum(dim = -1)
        return score

    @override
    def predict(
        self,
        data: dict[str, typing.Union[torch.Tensor,str]],
        mode) -> torch.Tensor:

        """Inference method of RotatE.

        :param data: data
        :type data: dict[str, typing.Union[torch.Tensor,str]]
        :param mode: mode
        :type mode: str
        :returns: scores of the triples
        :rtype: torch.Tensor
        """

        triples = data["positive_sample"]
        head_emb, relation_emb, tail_emb = self.tri2emb(triples, mode=mode)
        score = self.margin - self._calc(head_emb, relation_emb, tail_emb)
        return score

    def regularization(
        self,
        data: dict[str, typing.Union[torch.Tensor, str]]) -> torch.Tensor:

        """L2 regularization function (also known as weight decay), used in
        the loss function.

        :param data: data
        :type data: dict[str, typing.Union[torch.Tensor, str]]
        :returns: regularization loss of the model parameters
        :rtype: torch.Tensor
        """

        pos_sample = data["positive_sample"]
        neg_sample = data["negative_sample"]
        mode = data["mode"]
        pos_head_emb, pos_relation_emb, pos_tail_emb = self.tri2emb(pos_sample)
        # In "bern" mode, the negative samples are full triples themselves.
        if mode == "bern":
            neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(neg_sample)
        else:
            neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(pos_sample, neg_sample, mode)

        pos_regul = (torch.mean(pos_head_emb ** 2) +
                     torch.mean(pos_relation_emb ** 2) +
                     torch.mean(pos_tail_emb ** 2)) / 3

        neg_regul = (torch.mean(neg_head_emb ** 2) +
                     torch.mean(neg_relation_emb ** 2) +
                     torch.mean(neg_tail_emb ** 2)) / 3

        regul = (pos_regul + neg_regul) / 2

        return regul
def get_rotate_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyper-parameter optimization configuration of :py:class:`RotatE`.

    The default configuration is::

        parameters_dict = {
            'model': {
                'value': 'RotatE'
            },
            'dim': {
                'values': [256, 512, 1024]
            },
            'margin': {
                'values': [1.0, 3.0, 6.0]
            },
            'epsilon': {
                'value': 2.0
            }
        }

    :returns: the default hyper-parameter optimization configuration of :py:class:`RotatE`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    return {
        'model': {'value': 'RotatE'},
        'dim': {'values': [256, 512, 1024]},
        'margin': {'values': [1.0, 3.0, 6.0]},
        'epsilon': {'value': 2.0},
    }