unike 3.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unike/__init__.py +5 -0
- unike/config/HPOTrainer.py +305 -0
- unike/config/Tester.py +385 -0
- unike/config/Trainer.py +519 -0
- unike/config/TrainerAccelerator.py +39 -0
- unike/config/__init__.py +37 -0
- unike/data/BernSampler.py +168 -0
- unike/data/CompGCNSampler.py +140 -0
- unike/data/CompGCNTestSampler.py +84 -0
- unike/data/KGEDataLoader.py +315 -0
- unike/data/KGReader.py +138 -0
- unike/data/RGCNSampler.py +261 -0
- unike/data/RGCNTestSampler.py +208 -0
- unike/data/RevSampler.py +78 -0
- unike/data/TestSampler.py +189 -0
- unike/data/TradSampler.py +122 -0
- unike/data/TradTestSampler.py +87 -0
- unike/data/UniSampler.py +145 -0
- unike/data/__init__.py +47 -0
- unike/module/BaseModule.py +130 -0
- unike/module/__init__.py +20 -0
- unike/module/loss/CompGCNLoss.py +96 -0
- unike/module/loss/Loss.py +26 -0
- unike/module/loss/MarginLoss.py +148 -0
- unike/module/loss/RGCNLoss.py +117 -0
- unike/module/loss/SigmoidLoss.py +145 -0
- unike/module/loss/SoftplusLoss.py +145 -0
- unike/module/loss/__init__.py +35 -0
- unike/module/model/Analogy.py +237 -0
- unike/module/model/CompGCN.py +562 -0
- unike/module/model/ComplEx.py +235 -0
- unike/module/model/DistMult.py +276 -0
- unike/module/model/HolE.py +308 -0
- unike/module/model/Model.py +107 -0
- unike/module/model/RESCAL.py +309 -0
- unike/module/model/RGCN.py +304 -0
- unike/module/model/RotatE.py +303 -0
- unike/module/model/SimplE.py +237 -0
- unike/module/model/TransD.py +458 -0
- unike/module/model/TransE.py +290 -0
- unike/module/model/TransH.py +322 -0
- unike/module/model/TransR.py +402 -0
- unike/module/model/__init__.py +60 -0
- unike/module/strategy/CompGCNSampling.py +140 -0
- unike/module/strategy/NegativeSampling.py +138 -0
- unike/module/strategy/RGCNSampling.py +134 -0
- unike/module/strategy/Strategy.py +26 -0
- unike/module/strategy/__init__.py +29 -0
- unike/utils/EarlyStopping.py +94 -0
- unike/utils/Timer.py +74 -0
- unike/utils/WandbLogger.py +46 -0
- unike/utils/__init__.py +26 -0
- unike/utils/tools.py +118 -0
- unike/version.py +1 -0
- unike-3.0.1.dist-info/METADATA +101 -0
- unike-3.0.1.dist-info/RECORD +59 -0
- unike-3.0.1.dist-info/WHEEL +4 -0
- unike-3.0.1.dist-info/entry_points.txt +2 -0
- unike-3.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,562 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/module/model/CompGCN.py
|
4
|
+
#
|
5
|
+
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 19, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 26, 2024
|
7
|
+
#
|
8
|
+
# 该脚本定义了 CompGCN 类.
|
9
|
+
|
10
|
+
"""
|
11
|
+
CompGCN - 这是一种在图卷积网络中整合多关系信息的新框架,它利用知识图谱嵌入技术中的各种组合操作,将实体和关系共同嵌入到图中。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import dgl
|
15
|
+
import torch
|
16
|
+
import typing
|
17
|
+
from torch import nn
|
18
|
+
import dgl.function as fn
|
19
|
+
from .Model import Model
|
20
|
+
import torch.nn.functional as F
|
21
|
+
from typing_extensions import override
|
22
|
+
|
23
|
+
class CompGCN(Model):

    """
    ``CompGCN`` :cite:`CompGCN`, published in ``2020``, is a framework that
    integrates multi-relational information into graph convolutional networks.
    It jointly embeds entities and relations in the graph using the composition
    operators of knowledge graph embedding techniques.

    For positive triples, a larger score is better. For more details see
    :ref:`CompGCN <compgcn>`.

    Example::

        from unike.module.model import CompGCN
        from unike.module.loss import CompGCNLoss
        from unike.module.strategy import CompGCNSampling
        from unike.config import Trainer, GraphTester

        # define the model
        compgcn = CompGCN(
            ent_tol = dataloader.get_ent_tol(),
            rel_tol = dataloader.get_rel_tol(),
            dim = 100
        )

        # define the loss function
        model = CompGCNSampling(
            model = compgcn,
            loss = CompGCNLoss(model = compgcn),
            ent_tol = dataloader.get_ent_tol()
        )

        # test the model
        tester = GraphTester(model = compgcn, data_loader = dataloader, use_gpu = True, device = 'cuda:0', prediction = "tail")

        # train the model
        trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(),
            epochs = 2000, lr = 0.0001, use_gpu = True, device = 'cuda:0',
            tester = tester, test = True, valid_interval = 50, log_interval = 50,
            save_interval = 50, save_path = '../../checkpoint/compgcn.pth'
        )
        trainer.run()
    """

    def __init__(
        self,
        ent_tol: int,
        rel_tol: int,
        dim: int,
        opn: str = 'mult',
        fet_drop: float = 0.2,
        hid_drop: float = 0.3,
        margin: float = 40.0,
        decoder_model: str = 'ConvE'):

        """Create a CompGCN object.

        :param ent_tol: number of entities
        :type ent_tol: int
        :param rel_tol: number of relations
        :type rel_tol: int
        :param dim: dimension of the initial entity and relation embeddings;
            the graph layer outputs ``2 * dim`` features. With the 'ConvE'
            decoder, ``dim`` must be a multiple of 5 (``4 * dim`` is reshaped
            into rows of 20) and at least 35 (so the 7x7 kernel fits).
        :type dim: int
        :param opn: composition operator: 'mult', 'sub' or 'corr'
        :type opn: str
        :param fet_drop: dropout on the convolution features ('ConvE' decoder only)
        :type fet_drop: float
        :param hid_drop: dropout on the hidden layer ('ConvE' decoder only)
        :type hid_drop: float
        :param margin: gamma of the 'TransE' decoder
        :type margin: float
        :param decoder_model: scoring function used as decoder: 'ConvE',
            'DistMult' or 'TransE' (compared case-insensitively)
        :type decoder_model: str
        """

        super(CompGCN, self).__init__(ent_tol, rel_tol)

        #: Dimension of the initial entity and relation embeddings
        self.dim: int = dim
        #: Composition operator: 'mult', 'sub' or 'corr'
        self.opn: str = opn
        #: Decoder scoring function: 'ConvE', 'DistMult' or 'TransE' (case-insensitive)
        self.decoder_model: str = decoder_model

        #------------------------------CompGCN--------------------------------------------------------------------
        #: Entity embeddings, one row per entity
        self.ent_emb: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor(self.ent_tol, self.dim))
        #: Relation embeddings, one row per relation
        self.rel_emb: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor(self.rel_tol, self.dim))

        nn.init.xavier_normal_(self.ent_emb, gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_normal_(self.rel_emb, gain=nn.init.calculate_gain('relu'))

        #: CompGCNCov graph layer mapping dim -> 2 * dim features.
        # The original call passed the *string* 'False', which is truthy, so the
        # layer has always created a bias. Pass True explicitly to keep that
        # behaviour (and existing checkpoints' state_dict keys) while removing
        # the misleading literal. The bias is added right before a BatchNorm
        # inside the layer, which subtracts the mean and so cancels it.
        self.GraphCov: CompGCNCov = CompGCNCov(self.dim, self.dim * 2, torch.tanh, bias = True, drop_rate = 0.1, opn = self.opn)
        #: Dropout applied to the entity output of :py:attr:`GraphCov`
        self.drop: torch.nn.Dropout = nn.Dropout(0.3)
        #: Per-entity bias added to the final ConvE / DistMult scores
        self.bias: torch.nn.parameter.Parameter = nn.Parameter(torch.zeros(self.ent_tol))
        #-----------------------------ConvE-----------------------------------------------------------------------
        #: 'ConvE' decoder: BatchNorm over the stacked head/relation input
        self.bn0: torch.nn.BatchNorm2d = torch.nn.BatchNorm2d(1)
        #: 'ConvE' decoder: convolution layer (200 filters, 7x7, stride 1, no padding)
        self.conv1: torch.nn.Conv2d = torch.nn.Conv2d(1, 200, (7, 7), 1, 0, bias=False)
        #: 'ConvE' decoder: BatchNorm over the convolution features
        self.bn1: torch.nn.BatchNorm2d = torch.nn.BatchNorm2d(200)
        #: 'ConvE' decoder: channel-wise dropout over the convolution features
        self.fet_drop: torch.nn.Dropout2d = torch.nn.Dropout2d(fet_drop)
        # The input "image" is (4 * dim // 20) x 20; a 7x7 valid convolution
        # shrinks each side by 6, and there are 200 output channels.
        flat_sz_h = 4 * self.dim // 20 - 7 + 1
        flat_sz_w = 20 - 7 + 1
        flat_sz = flat_sz_h * flat_sz_w * 200
        #: 'ConvE' decoder: hidden layer projecting conv features back to 2 * dim
        self.fc: torch.nn.Linear = torch.nn.Linear(flat_sz, self.dim*2)
        #: 'ConvE' decoder: dropout on the hidden layer
        self.hid_drop: torch.nn.Dropout = torch.nn.Dropout(hid_drop)
        #: 'ConvE' decoder: BatchNorm on the hidden layer
        self.bn2: torch.nn.BatchNorm1d = torch.nn.BatchNorm1d(self.dim*2)
        #-----------------------------TransE-----------------------------------------------------------------------
        #: 'TransE' decoder margin (gamma), kept as a non-trainable Parameter
        self.margin: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor([margin]))
        self.margin.requires_grad = False

    def _score(
        self,
        graph: dgl.DGLGraph,
        relation: torch.Tensor,
        norm: torch.Tensor,
        triples: torch.Tensor) -> torch.Tensor:

        """Encode the graph with :py:attr:`GraphCov`, then score every
        (head, relation) query in ``triples`` against all entities with the
        configured decoder. Shared by :py:meth:`forward` and :py:meth:`predict`.

        :raises ValueError: if :py:attr:`decoder_model` is not one of
            'ConvE', 'DistMult' or 'TransE'
        """

        # Column 0 holds head entity ids, column 1 relation ids.
        head, rela = triples[:, 0], triples[:, 1]
        x, r = self.GraphCov(graph, self.ent_emb, self.rel_emb, relation, norm)
        x = self.drop(x)
        head_emb = torch.index_select(x, 0, head)
        rela_emb = torch.index_select(r, 0, rela)

        decoder = self.decoder_model.lower()
        if decoder == 'conve':
            return self.conve(head_emb, rela_emb, x)
        if decoder == 'distmult':
            return self.distmult(head_emb, rela_emb, x)
        if decoder == 'transe':
            return self.transe(head_emb, rela_emb, x)
        raise ValueError("please choose decoder (TransE/DistMult/ConvE)")

    @override
    def forward(
        self,
        graph: dgl.DGLGraph,
        relation: torch.Tensor,
        norm: torch.Tensor,
        triples: torch.Tensor) -> torch.Tensor:

        """
        Define the computation performed at every call.
        Subclasses of :py:class:`torch.nn.Module` must override :py:meth:`torch.nn.Module.forward`.

        :param graph: subgraph
        :type graph: dgl.DGLGraph
        :param relation: relation id of each edge in the subgraph
        :type relation: torch.Tensor
        :param norm: normalization coefficient of each edge
        :type norm: torch.Tensor
        :param triples: triples; column 0 is the head id, column 1 the relation id
        :type triples: torch.Tensor
        :returns: sigmoid scores of each (head, relation) query against every entity
        :rtype: torch.Tensor
        :raises ValueError: if the decoder name is not recognized
        """

        return self._score(graph, relation, norm, triples)

    def conve(
        self,
        sub_emb: torch.Tensor,
        rel_emb: torch.Tensor,
        all_ent: torch.Tensor) -> torch.Tensor:

        """Score queries with ConvE as the decoder.

        :param sub_emb: head entity embeddings (``2 * dim`` features per row)
        :type sub_emb: torch.Tensor
        :param rel_emb: relation embeddings (``2 * dim`` features per row)
        :type rel_emb: torch.Tensor
        :param all_ent: embeddings of all entities
        :type all_ent: torch.Tensor
        :returns: sigmoid scores, one row per query and one column per entity
        :rtype: torch.Tensor"""

        stack_input = self.concat(sub_emb, rel_emb)
        x = self.bn0(stack_input)
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.fet_drop(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        x = self.hid_drop(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = torch.mm(x, all_ent.transpose(1, 0))
        x += self.bias.expand_as(x)
        score = torch.sigmoid(x)
        return score

    def concat(
        self,
        ent_embed: torch.Tensor,
        rel_embed: torch.Tensor) -> torch.Tensor:

        """Stack head and relation embeddings into the 2-D input of the ConvE decoder.

        :param ent_embed: head entity embeddings (``2 * dim`` features per row)
        :type ent_embed: torch.Tensor
        :param rel_embed: relation embeddings (``2 * dim`` features per row)
        :type rel_embed: torch.Tensor
        :returns: tensor of shape ``(batch, 1, 4 * dim // 20, 20)``
        :rtype: torch.Tensor
        :raises ValueError: if ``4 * dim`` is not divisible by 20
        """

        # Each sample contributes 4 * dim values that are laid out as rows of
        # 20. Without exact divisibility, reshape(-1, ...) either fails
        # obscurely or silently regroups values across samples in the batch.
        if (4 * self.dim) % 20 != 0:
            raise ValueError(
                f"ConvE decoder requires 4 * dim to be divisible by 20 "
                f"(dim a multiple of 5), got dim={self.dim}")
        ent_embed = ent_embed.view(-1, 1, self.dim*2)
        rel_embed = rel_embed.view(-1, 1, self.dim*2)
        stack_input = torch.cat([ent_embed, rel_embed], 1)
        stack_input = stack_input.reshape(-1, 1, 4*self.dim//20, 20)
        return stack_input

    def distmult(
        self,
        head_emb: torch.Tensor,
        rela_emb: torch.Tensor,
        all_ent: torch.Tensor) -> torch.Tensor:

        """Score queries with DistMult as the decoder.

        :param head_emb: head entity embeddings
        :type head_emb: torch.Tensor
        :param rela_emb: relation embeddings
        :type rela_emb: torch.Tensor
        :param all_ent: embeddings of all entities
        :type all_ent: torch.Tensor
        :returns: sigmoid scores, one row per query and one column per entity
        :rtype: torch.Tensor"""

        obj_emb = head_emb * rela_emb
        x = torch.mm(obj_emb, all_ent.transpose(1, 0))
        x += self.bias.expand_as(x)
        score = torch.sigmoid(x)
        return score

    def transe(
        self,
        head_emb: torch.Tensor,
        rela_emb: torch.Tensor,
        all_ent: torch.Tensor) -> torch.Tensor:

        """Score queries with TransE as the decoder: ``sigmoid(margin - ||h + r - t||_1)``.

        :param head_emb: head entity embeddings
        :type head_emb: torch.Tensor
        :param rela_emb: relation embeddings
        :type rela_emb: torch.Tensor
        :param all_ent: embeddings of all entities
        :type all_ent: torch.Tensor
        :returns: sigmoid scores, one row per query and one column per entity
        :rtype: torch.Tensor"""

        obj_emb = head_emb + rela_emb
        # Broadcast (batch, 1, d) against (num_ent, d) to get all L1 distances.
        x = self.margin - torch.norm(obj_emb.unsqueeze(1) - all_ent, p=1, dim=2)
        score = torch.sigmoid(x)
        return score

    @override
    def predict(
        self,
        data: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]],
        mode: str) -> torch.Tensor:

        """Inference method of CompGCN.

        :param data: batch with keys 'positive_sample' (triples), 'graph',
            'rela' (edge relation ids) and 'norm' (edge normalization)
        :type data: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
        :param mode: unused by CompGCN; present only so all models share the
            same inference signature
        :type mode: str
        :returns: sigmoid scores of each (head, relation) query against every entity
        :rtype: torch.Tensor
        :raises ValueError: if the decoder name is not recognized
        """

        return self._score(data['graph'], data['rela'], data['norm'], data['positive_sample'])
|
319
|
+
|
320
|
+
class CompGCNCov(nn.Module):

    """``CompGCN`` :cite:`CompGCN` graph convolution layer.

    Updates entity embeddings by message passing over a DGL graph. Each
    message composes the source entity embedding with the relation embedding
    of its edge. Relation embeddings are projected to the output dimension.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        act: typing.Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        bias: bool = True,
        drop_rate: float = 0.,
        opn: str = 'corr'):

        """Create a CompGCNCov layer.

        :param in_channels: input feature dimension
        :type in_channels: int
        :param out_channels: output feature dimension
        :type out_channels: int
        :param act: activation applied to the updated entity embeddings
        :type act: typing.Callable[[torch.Tensor], torch.Tensor]
        :param bias: whether to add a learnable bias to the entity output
        :type bias: bool
        :param drop_rate: dropout rate applied to the aggregated neighbour messages
        :type drop_rate: float
        :param opn: composition operator: 'mult', 'sub' or 'corr'
        :type opn: str
        """

        super(CompGCNCov, self).__init__()

        #: Input feature dimension
        self.in_channels: int = in_channels
        #: Output feature dimension
        self.out_channels: int = out_channels
        # Optional matrix applied to rel_emb in forward; this class never sets
        # it, so forward uses rel_emb unchanged unless a caller assigns one.
        self.rel_wt = None
        #: Relation embeddings used by the current forward pass (set in forward)
        self.rel: torch.nn.parameter.Parameter = None
        #: Composition operator: 'mult', 'sub' or 'corr'
        self.opn: str = opn
        #: Weight matrix for messages on edges in the first half of the edge list
        #: (original relations)
        self.in_w: torch.nn.parameter.Parameter = self.get_param([in_channels, out_channels])
        #: Weight matrix for messages on edges in the second half of the edge list
        #: (inverse relations)
        self.out_w: torch.nn.parameter.Parameter = self.get_param([in_channels, out_channels])
        #: Dropout applied to the aggregated neighbour messages
        self.drop: torch.nn.Dropout = nn.Dropout(drop_rate)
        #: Embedding of the self-loop relation
        self.loop_rel: torch.nn.parameter.Parameter = self.get_param([1, in_channels])
        #: Weight matrix for the self-loop term
        self.loop_w: torch.nn.parameter.Parameter = self.get_param([in_channels, out_channels])
        #: Optional bias on the entity output (None when ``bias`` is falsy)
        self.bias: torch.nn.Parameter = nn.Parameter(torch.zeros(out_channels)) if bias else None
        #: BatchNorm on the entity output
        self.bn: torch.nn.BatchNorm1d = torch.nn.BatchNorm1d(out_channels)
        #: Activation function
        self.act: typing.Callable[[torch.Tensor], torch.Tensor] = act
        #: Projection of relation embeddings to the output dimension
        self.w_rel: torch.nn.parameter.Parameter = self.get_param([in_channels, out_channels])

    def get_param(
        self,
        shape: list[int]) -> torch.nn.parameter.Parameter:

        """Create a weight matrix with Xavier-normal initialization (ReLU gain).

        :param shape: shape of the weight matrix
        :type shape: list[int]
        :returns: the initialized weight matrix
        :rtype: torch.nn.parameter.Parameter
        """

        param = nn.Parameter(torch.Tensor(*shape))
        nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu'))
        return param

    def forward(
        self,
        graph: "dgl.DGLGraph",
        ent_emb: torch.Tensor,
        rel_emb: torch.Tensor,
        edge_type: torch.Tensor,
        edge_norm: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:

        """
        Define the computation performed at every call.
        Subclasses of :py:class:`torch.nn.Module` must override :py:meth:`torch.nn.Module.forward`.

        :param graph: subgraph
        :type graph: dgl.DGLGraph
        :param ent_emb: entity embeddings, one row per graph node
        :type ent_emb: torch.Tensor
        :param rel_emb: relation embeddings, indexed by ``edge_type``
        :type rel_emb: torch.Tensor
        :param edge_type: relation id of each edge
        :type edge_type: torch.Tensor
        :param edge_norm: normalization coefficient of each edge
        :type edge_norm: torch.Tensor
        :returns: updated entity embeddings and projected relation embeddings
        :rtype: tuple[torch.Tensor, torch.Tensor]
        """

        # Work on a local view so node/edge features do not leak into the caller's graph.
        graph = graph.local_var()
        graph.ndata['h'] = ent_emb
        graph.edata['type'] = edge_type
        graph.edata['norm'] = edge_norm
        # message_func reads self.rel, so it must be set before update_all.
        if self.rel_wt is None:
            self.rel = rel_emb
        else:
            self.rel = torch.mm(self.rel_wt, rel_emb)
        # Sum the per-edge messages into 'h', then reduce_func (used as the
        # apply-node function) applies dropout and scales by 1/3.
        graph.update_all(self.message_func, fn.sum(msg='msg', out='h'), self.reduce_func)
        # Add the self-loop term, also scaled by 1/3.
        ent_emb = graph.ndata.pop('h') + torch.mm(self.comp(ent_emb, self.loop_rel), self.loop_w) / 3
        if self.bias is not None:
            ent_emb = ent_emb + self.bias
        ent_emb = self.bn(ent_emb)

        return self.act(ent_emb), torch.matmul(self.rel, self.w_rel)

    def message_func(self, edges: "dgl.udf.EdgeBatch") -> dict[str, torch.Tensor]:

        """
        Message function.

        Composes each source embedding with its edge's relation embedding.
        Edges in the first half of the batch are transformed with
        :py:attr:`in_w` and those in the second half with :py:attr:`out_w`.
        Each message is scaled by its edge's ``norm``.
        """

        edge_type = edges.data['type']
        half = edge_type.shape[0] // 2
        edge_data = self.comp(edges.src['h'], self.rel[edge_type])
        # NOTE(review): this split assumes the edge list stores all original-
        # relation edges first and their inverse edges second — confirm with
        # the sampler that builds the graph.
        msg = torch.cat([torch.matmul(edge_data[:half, :], self.in_w),
                         torch.matmul(edge_data[half:, :], self.out_w)])
        msg = msg * edges.data['norm'].reshape(-1, 1)
        return {'msg': msg}

    @staticmethod
    def _ccorr(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:

        """Circular correlation along the last dimension, computed via FFT:
        ``out[..., k] = sum_i a[..., i] * b[..., (i + k) % n]``."""

        # Correlation theorem: corr(a, b) = irfft(conj(rfft(a)) * rfft(b)).
        # Passing n restores the original length (needed for odd n).
        return torch.fft.irfft(torch.conj(torch.fft.rfft(a)) * torch.fft.rfft(b), n=a.shape[-1])

    def comp(
        self,
        h: torch.Tensor,
        r: torch.Tensor) -> torch.Tensor:

        """Composition operation: 'mult', 'sub' or 'corr'.

        :param h: head entity embeddings
        :type h: torch.Tensor
        :param r: relation embeddings (broadcast against ``h``)
        :type r: torch.Tensor
        :returns: composed edge data
        :rtype: torch.Tensor
        :raises KeyError: if :py:attr:`opn` is not a known operator
        """

        if self.opn == 'mult':
            return h * r
        if self.opn == 'sub':
            return h - r
        if self.opn == 'corr':
            return self._ccorr(h, r.expand_as(h))
        raise KeyError(f'composition operator {self.opn} not recognized.')

    def reduce_func(self, nodes: "dgl.udf.NodeBatch") -> dict[str, torch.Tensor]:

        """Node-apply function (passed to ``update_all`` after the built-in sum).

        Applies dropout to the summed messages and divides by 3. The self-loop
        term in :py:meth:`forward` is also divided by 3, presumably to average
        the original, inverse and self-loop contributions.
        """

        return {'h': self.drop(nodes.data['h']) / 3}
|
503
|
+
|
504
|
+
def get_compgcn_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyperparameter-optimization config for :py:class:`CompGCN`.

    The default configuration is::

        parameters_dict = {
            'model': {
                'value': 'CompGCN'
            },
            'dim': {
                'values': [100, 150, 200]
            },
            'opn': {
                'value': 'mult'
            },
            'fet_drop': {
                'value': 0.2
            },
            'hid_drop': {
                'value': 0.3
            },
            'margin': {
                'value': 40.0
            },
            'decoder_model': {
                'value': 'ConvE'
            }
        }

    A fresh dictionary is built on every call, so callers may mutate it.

    :returns: the default hyperparameter-optimization config of :py:class:`CompGCN`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    # Only 'dim' is searched over; every other setting is pinned to one value.
    pinned_settings = (
        ('opn', 'mult'),
        ('fet_drop', 0.2),
        ('hid_drop', 0.3),
        ('margin', 40.0),
        ('decoder_model', 'ConvE'),
    )

    config: dict[str, dict[str, typing.Any]] = {
        'model': {'value': 'CompGCN'},
        'dim': {'values': [100, 150, 200]},
    }
    config.update((setting, {'value': choice}) for setting, choice in pinned_settings)
    return config
|