unike 3.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unike/__init__.py +5 -0
- unike/config/HPOTrainer.py +305 -0
- unike/config/Tester.py +385 -0
- unike/config/Trainer.py +519 -0
- unike/config/TrainerAccelerator.py +39 -0
- unike/config/__init__.py +37 -0
- unike/data/BernSampler.py +168 -0
- unike/data/CompGCNSampler.py +140 -0
- unike/data/CompGCNTestSampler.py +84 -0
- unike/data/KGEDataLoader.py +315 -0
- unike/data/KGReader.py +138 -0
- unike/data/RGCNSampler.py +261 -0
- unike/data/RGCNTestSampler.py +208 -0
- unike/data/RevSampler.py +78 -0
- unike/data/TestSampler.py +189 -0
- unike/data/TradSampler.py +122 -0
- unike/data/TradTestSampler.py +87 -0
- unike/data/UniSampler.py +145 -0
- unike/data/__init__.py +47 -0
- unike/module/BaseModule.py +130 -0
- unike/module/__init__.py +20 -0
- unike/module/loss/CompGCNLoss.py +96 -0
- unike/module/loss/Loss.py +26 -0
- unike/module/loss/MarginLoss.py +148 -0
- unike/module/loss/RGCNLoss.py +117 -0
- unike/module/loss/SigmoidLoss.py +145 -0
- unike/module/loss/SoftplusLoss.py +145 -0
- unike/module/loss/__init__.py +35 -0
- unike/module/model/Analogy.py +237 -0
- unike/module/model/CompGCN.py +562 -0
- unike/module/model/ComplEx.py +235 -0
- unike/module/model/DistMult.py +276 -0
- unike/module/model/HolE.py +308 -0
- unike/module/model/Model.py +107 -0
- unike/module/model/RESCAL.py +309 -0
- unike/module/model/RGCN.py +304 -0
- unike/module/model/RotatE.py +303 -0
- unike/module/model/SimplE.py +237 -0
- unike/module/model/TransD.py +458 -0
- unike/module/model/TransE.py +290 -0
- unike/module/model/TransH.py +322 -0
- unike/module/model/TransR.py +402 -0
- unike/module/model/__init__.py +60 -0
- unike/module/strategy/CompGCNSampling.py +140 -0
- unike/module/strategy/NegativeSampling.py +138 -0
- unike/module/strategy/RGCNSampling.py +134 -0
- unike/module/strategy/Strategy.py +26 -0
- unike/module/strategy/__init__.py +29 -0
- unike/utils/EarlyStopping.py +94 -0
- unike/utils/Timer.py +74 -0
- unike/utils/WandbLogger.py +46 -0
- unike/utils/__init__.py +26 -0
- unike/utils/tools.py +118 -0
- unike/version.py +1 -0
- unike-3.0.1.dist-info/METADATA +101 -0
- unike-3.0.1.dist-info/RECORD +59 -0
- unike-3.0.1.dist-info/WHEEL +4 -0
- unike-3.0.1.dist-info/entry_points.txt +2 -0
- unike-3.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,402 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/module/model/TransR.py
|
4
|
+
#
|
5
|
+
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 25, 2024
|
7
|
+
#
|
8
|
+
# 该头文件定义了 TransR.
|
9
|
+
|
10
|
+
"""
|
11
|
+
TransR - 是一个为实体和关系嵌入向量分别构建了独立的向量空间,将实体向量投影到特定的关系向量空间进行平移操作的模型。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import torch
|
15
|
+
import typing
|
16
|
+
import numpy as np
|
17
|
+
import torch.nn as nn
|
18
|
+
import torch.nn.functional as F
|
19
|
+
from .Model import Model
|
20
|
+
from typing_extensions import override
|
21
|
+
|
22
|
+
class TransR(Model):

    r"""
    ``TransR`` :cite:`TransR` (proposed in 2015) builds separate vector spaces for entity and
    relation embeddings: entity vectors are projected into a relation-specific space by a
    relation matrix :math:`M_r` and then translated by the relation vector.

    The scoring function is:

    .. math::

        \Vert hM_r+r-tM_r \Vert_{L_1/L_2}

    Lower scores are better for positive triples. For more details see :ref:`TransR <transr>`.

    Example::

        from unike.data import KGEDataLoader, BernSampler, TradTestSampler
        from unike.module.model import TransE, TransR
        from unike.module.loss import MarginLoss
        from unike.module.strategy import NegativeSampling
        from unike.config import Trainer, Tester

        # dataloader for training
        dataloader = KGEDataLoader(
            in_path = "../../benchmarks/FB15K237/",
            batch_size = 2048,
            neg_ent = 25,
            test = True,
            test_batch_size = 10,
            num_workers = 16,
            train_sampler = BernSampler,
            test_sampler = TradTestSampler
        )

        # define the transe
        transe = TransE(
            ent_tol = dataloader.get_ent_tol(),
            rel_tol = dataloader.get_rel_tol(),
            dim = 100,
            p_norm = 1,
            norm_flag = True)

        transr = TransR(
            ent_tol = dataloader.get_ent_tol(),
            rel_tol = dataloader.get_rel_tol(),
            dim_e = 100,
            dim_r = 100,
            p_norm = 1,
            norm_flag = True,
            rand_init = False)

        model_e = NegativeSampling(
            model = transe,
            loss = MarginLoss(margin = 5.0)
        )

        model_r = NegativeSampling(
            model = transr,
            loss = MarginLoss(margin = 4.0)
        )

        # pretrain transe
        trainer = Trainer(model = model_e, data_loader = dataloader.train_dataloader(),
            epochs = 1, lr = 0.5, opt_method = "sgd", use_gpu = True, device = 'cuda:0')
        trainer.run()
        parameters = transe.get_parameters()
        transe.save_parameters("../../checkpoint/transr_transe.json")

        # tester for transr
        tester = Tester(model = transr, data_loader = dataloader, use_gpu = True, device = 'cuda:0')

        # train transr, starting from the pretrained transe embeddings
        transr.set_parameters(parameters)
        trainer = Trainer(model = model_r, data_loader = dataloader.train_dataloader(),
            epochs = 1000, lr = 1.0, opt_method = "sgd", use_gpu = True, device = 'cuda:0',
            tester = tester, test = True, valid_interval = 100,
            log_interval = 100, save_interval = 100, save_path = '../../checkpoint/transr.pth')
        trainer.run()

        # test the model
        transr.load_checkpoint('../../checkpoint/transr.pth')
        tester.set_sampling_mode("link_test")
        tester.run_link_prediction()
    """

    def __init__(
        self,
        ent_tol: int,
        rel_tol: int,
        dim_e: int = 100,
        dim_r: int = 100,
        p_norm: int = 1,
        norm_flag: bool = True,
        rand_init: bool = False,
        margin: typing.Optional[float] = None):

        """Create a TransR model.

        :param ent_tol: number of entities
        :type ent_tol: int
        :param rel_tol: number of relations
        :type rel_tol: int
        :param dim_e: dimension of the entity embeddings
        :type dim_e: int
        :param dim_r: dimension of the relation embeddings (and of the space entities are projected into)
        :type dim_r: int
        :param p_norm: order of the norm used by the scoring function; following the original paper, 1 or 2.
        :type p_norm: int
        :param norm_flag: whether to L2-normalise the last dimension of the (projected) entity and
            relation embeddings with :py:func:`torch.nn.functional.normalize` before scoring.
        :type norm_flag: bool
        :param rand_init: if ``True`` the relation matrices are Xavier-initialised; otherwise every
            relation matrix starts as the (possibly rectangular) identity.
        :type rand_init: bool
        :param margin: required when training with the ``RotatE`` :cite:`RotatE` loss
            :py:class:`unike.module.loss.SigmoidLoss`. Scores become ``margin - distance``, turning
            TransR's "lower is better" score for positive triples into "higher is better".
            See :ref:`RotatE <rotate>` for details.
        :type margin: float | None
        """

        super().__init__(ent_tol, rel_tol)

        #: Dimension of the entity embeddings.
        self.dim_e: int = dim_e
        #: Dimension of the relation embeddings.
        self.dim_r: int = dim_r
        #: Order of the norm used by the scoring function (1 or 2 in the original paper).
        self.p_norm: int = p_norm
        #: Whether to L2-normalise the last dimension of the embeddings with
        #: :py:func:`torch.nn.functional.normalize` before scoring.
        self.norm_flag: bool = norm_flag
        #: Whether the relation matrices are randomly (Xavier) initialised.
        self.rand_init: bool = rand_init

        #: Entity embeddings: one ``dim_e``-dimensional row per entity.
        self.ent_embeddings: torch.nn.Embedding = nn.Embedding(self.ent_tol, self.dim_e)
        #: Relation embeddings: one ``dim_r``-dimensional row per relation.
        self.rel_embeddings: torch.nn.Embedding = nn.Embedding(self.rel_tol, self.dim_r)

        if margin is not None:
            #: Fixed (non-trainable) margin used with :py:class:`unike.module.loss.SigmoidLoss`
            #: to turn TransR's distance into a "higher is better" score.
            self.margin: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor([margin]))
            self.margin.requires_grad = False
            self.margin_flag: bool = True
        else:
            self.margin_flag: bool = False

        nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
        nn.init.xavier_uniform_(self.rel_embeddings.weight.data)

        #: Relation-specific projection matrices stored flattened: row ``r`` holds the
        #: ``dim_e x dim_r`` matrix :math:`M_r` (reshaped in :py:meth:`_transfer`).
        self.transfer_matrix: torch.nn.Embedding = nn.Embedding(self.rel_tol, self.dim_e * self.dim_r)

        if not self.rand_init:
            # Every M_r starts as the rectangular identity (ones at [i, i] for i < min(dim_e, dim_r)).
            # One broadcast copy replaces the former per-relation Python loop.
            identity = torch.eye(self.dim_e, self.dim_r).view(-1)
            self.transfer_matrix.weight.data.copy_(identity.expand(self.rel_tol, -1))
        else:
            nn.init.xavier_uniform_(self.transfer_matrix.weight.data)

    @override
    def forward(
        self,
        triples: torch.Tensor,
        negs: typing.Optional[torch.Tensor] = None,
        mode: str = 'single') -> torch.Tensor:

        """
        Define the computation performed at every call.
        Subclasses of :py:class:`torch.nn.Module` must override :py:meth:`torch.nn.Module.forward`.

        :param triples: positive triples; column 1 holds the relation ids
        :type triples: torch.Tensor
        :param negs: negative samples, passed through to ``tri2emb``
        :type negs: torch.Tensor | None
        :param mode: sampling mode, passed through to ``tri2emb``
        :type mode: str
        :returns: the triple scores: the distance (lower is better), or ``margin - distance``
            when a margin was given
        :rtype: torch.Tensor
        """

        head_emb, relation_emb, tail_emb = self.tri2emb(triples, negs, mode)
        score = self._project_and_score(head_emb, relation_emb, tail_emb, triples[:, 1])
        if self.margin_flag:
            return self.margin - score
        return score

    def _project_and_score(
        self,
        head_emb: torch.Tensor,
        relation_emb: torch.Tensor,
        tail_emb: torch.Tensor,
        rel_ids: torch.Tensor) -> torch.Tensor:

        """Project head/tail embeddings with the matrices of ``rel_ids`` and return the raw distance.

        Shared by :py:meth:`forward` and :py:meth:`predict`.

        :param head_emb: head entity embeddings
        :type head_emb: torch.Tensor
        :param relation_emb: relation embeddings
        :type relation_emb: torch.Tensor
        :param tail_emb: tail entity embeddings
        :type tail_emb: torch.Tensor
        :param rel_ids: relation ids selecting the projection matrices
        :type rel_ids: torch.Tensor
        :returns: :py:meth:`_calc` of the projected embeddings
        :rtype: torch.Tensor
        """

        rel_transfer = self.transfer_matrix(rel_ids)
        head_emb = self._transfer(head_emb, rel_transfer)
        tail_emb = self._transfer(tail_emb, rel_transfer)
        return self._calc(head_emb, relation_emb, tail_emb)

    def _transfer(
        self,
        e: torch.Tensor,
        r_transfer: torch.Tensor) -> torch.Tensor:

        """
        Project head or tail entity vectors into the relation-specific space.

        NOTE(review): assumes ``e`` is 3-D, (batch, n, dim_e), as produced by ``tri2emb``.
        The inserted axis on ``r_transfer`` then broadcasts over ``n``. Confirm against ``Model.tri2emb``.

        :param e: head or tail entity vectors
        :type e: torch.Tensor
        :param r_transfer: flattened relation matrices, one row of ``dim_e * dim_r`` values per batch item
        :type r_transfer: torch.Tensor
        :returns: the projected entity vectors (last dimension ``dim_r``)
        :rtype: torch.Tensor
        """

        r_transfer = r_transfer.view(-1, self.dim_e, self.dim_r)
        r_transfer = r_transfer.unsqueeze(dim=1)
        # Treat each entity vector as a 1 x dim_e row so that matmul computes e @ M_r.
        e = e.unsqueeze(dim=-2)
        e = torch.matmul(e, r_transfer)
        return e.squeeze(dim=-2)

    def _calc(
        self,
        h: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor) -> torch.Tensor:

        """Compute the TransR scoring function ``||h + r - t||_p`` over the last dimension.

        :param h: (projected) head entity vectors
        :type h: torch.Tensor
        :param r: relation vectors
        :type r: torch.Tensor
        :param t: (projected) tail entity vectors
        :type t: torch.Tensor
        :returns: the triple scores
        :rtype: torch.Tensor
        """

        # Normalise the last dimension of every embedding.
        if self.norm_flag:
            h = F.normalize(h, 2, -1)
            r = F.normalize(r, 2, -1)
            t = F.normalize(t, 2, -1)

        score = (h + r) - t

        # Reduce with the p-norm distance.
        score = torch.norm(score, self.p_norm, -1)
        return score

    @override
    def predict(
        self,
        data: dict[str, typing.Union[torch.Tensor,str]],
        mode: str) -> torch.Tensor:

        """Inference method of TransR.

        :param data: batch data; ``data["positive_sample"]`` holds the triples to score
        :type data: dict[str, typing.Union[torch.Tensor,str]]
        :param mode: 'head_predict' or 'tail_predict'
        :type mode: str
        :returns: the triple scores, oriented so that higher is better
        :rtype: torch.Tensor
        """

        triples = data["positive_sample"]
        head_emb, relation_emb, tail_emb = self.tri2emb(triples, mode=mode)
        score = self._project_and_score(head_emb, relation_emb, tail_emb, triples[:, 1])

        if self.margin_flag:
            return self.margin - score
        # Negate the distance so that a larger value means a better triple.
        return -score

    def regularization(
        self,
        data: dict[str, typing.Union[torch.Tensor, str]]) -> torch.Tensor:

        """L2 regularisation (weight decay) term used by the loss function.

        Averages the mean squared values of the head, relation and tail embeddings and of the
        relation matrices over the positive and negative samples.

        :param data: batch data with keys ``"positive_sample"``, ``"negative_sample"`` and ``"mode"``
        :type data: dict[str, typing.Union[torch.Tensor, str]]
        :returns: the regularisation loss of the model parameters
        :rtype: torch.Tensor
        """

        pos_sample = data["positive_sample"]
        neg_sample = data["negative_sample"]
        mode = data["mode"]

        pos_head_emb, pos_relation_emb, pos_tail_emb = self.tri2emb(pos_sample)
        pos_rel_transfer = self.transfer_matrix(pos_sample[:, 1])

        if mode == "bern":
            # Negatives are complete triples here; take the matrices from their own relation
            # column, consistent with the embeddings looked up from ``neg_sample``.
            neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(neg_sample)
            neg_rel_transfer = self.transfer_matrix(neg_sample[:, 1])
        else:
            # Relation matrices come from the positive triples' relations in this mode.
            neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(pos_sample, neg_sample, mode)
            neg_rel_transfer = pos_rel_transfer

        pos_regul = (torch.mean(pos_head_emb ** 2) +
                     torch.mean(pos_relation_emb ** 2) +
                     torch.mean(pos_tail_emb ** 2) +
                     torch.mean(pos_rel_transfer ** 2)) / 4

        neg_regul = (torch.mean(neg_head_emb ** 2) +
                     torch.mean(neg_relation_emb ** 2) +
                     torch.mean(neg_tail_emb ** 2) +
                     torch.mean(neg_rel_transfer ** 2)) / 4

        regul = (pos_regul + neg_regul) / 2

        return regul
|
328
|
+
|
329
|
+
def get_transr_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyper-parameter optimisation (HPO) configuration of :py:class:`TransR`.

    When tuning ``TransR`` :cite:`TransR`, a ``TransE`` :cite:`TransE` model is trained first
    (for 1 epoch), and the TransR entity and relation embeddings are initialised from its result.
    **margin_e**, **lr_e** and **opt_method_e** are the training hyper-parameters of that
    ``TransE`` :cite:`TransE` model. For more details see :ref:`TransR <transr>`.

    Default configuration::

        parameters_dict = {
            'model': {
                'value': 'TransR'
            },
            'dim': {
                'values': [50, 100]
            },
            'p_norm': {
                'values': [1, 2]
            },
            'norm_flag': {
                'value': True
            },
            'rand_init': {
                'value': False
            },
            'margin_e': {
                'values': [1.0, 3.0, 6.0]
            },
            'lr_e': {
                'distribution': 'uniform',
                'min': 1e-5,
                'max': 1.0
            },
            'opt_method_e': {
                'values': ['adam', 'adagrad', 'sgd']
            },
        }

    :returns: the default HPO configuration of :py:class:`TransR`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    # Search space of the TransR model itself.
    transr_space = {
        'model': {'value': 'TransR'},
        'dim': {'values': [50, 100]},
        'p_norm': {'values': [1, 2]},
        'norm_flag': {'value': True},
        'rand_init': {'value': False},
    }
    # Search space of the TransE pre-training run (the "_e" suffix).
    transe_pretrain_space = {
        'margin_e': {'values': [1.0, 3.0, 6.0]},
        'lr_e': {'distribution': 'uniform', 'min': 1e-5, 'max': 1.0},
        'opt_method_e': {'values': ['adam', 'adagrad', 'sgd']},
    }
    return {**transr_space, **transe_pretrain_space}
|
@@ -0,0 +1,60 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/module/model/__init__.py
|
4
|
+
#
|
5
|
+
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 28, 2023
|
7
|
+
#
|
8
|
+
# 该头文件定义了 model 接口.
|
9
|
+
|
10
|
+
"""KGE 模型部分。"""
|
11
|
+
|
12
|
+
from __future__ import absolute_import
|
13
|
+
from __future__ import division
|
14
|
+
from __future__ import print_function
|
15
|
+
|
16
|
+
from .Model import Model
|
17
|
+
from .TransE import TransE, get_transe_hpo_config
|
18
|
+
from .TransH import TransH, get_transh_hpo_config
|
19
|
+
from .TransR import TransR, get_transr_hpo_config
|
20
|
+
from .TransD import TransD, get_transd_hpo_config
|
21
|
+
from .RotatE import RotatE, get_rotate_hpo_config
|
22
|
+
from .RESCAL import RESCAL, get_rescal_hpo_config
|
23
|
+
from .DistMult import DistMult, get_distmult_hpo_config
|
24
|
+
from .HolE import HolE, get_hole_hpo_config
|
25
|
+
from .ComplEx import ComplEx, get_complex_hpo_config
|
26
|
+
from .Analogy import Analogy, get_analogy_hpo_config
|
27
|
+
from .SimplE import SimplE, get_simple_hpo_config
|
28
|
+
from .RGCN import RGCN, get_rgcn_hpo_config
|
29
|
+
from .CompGCN import CompGCN, CompGCNCov, get_compgcn_hpo_config
|
30
|
+
|
31
|
+
# Public API of ``unike.module.model``: the ``Model`` base class, then each KGE model
# followed by its ``get_*_hpo_config`` factory (CompGCN also exports its ``CompGCNCov`` layer).
__all__ = [
    'Model',
    'TransE', 'get_transe_hpo_config',
    'TransH', 'get_transh_hpo_config',
    'TransR', 'get_transr_hpo_config',
    'TransD', 'get_transd_hpo_config',
    'RotatE', 'get_rotate_hpo_config',
    'RESCAL', 'get_rescal_hpo_config',
    'DistMult', 'get_distmult_hpo_config',
    'HolE', 'get_hole_hpo_config',
    'ComplEx', 'get_complex_hpo_config',
    'Analogy', 'get_analogy_hpo_config',
    'SimplE', 'get_simple_hpo_config',
    'RGCN', 'get_rgcn_hpo_config',
    'CompGCN', 'CompGCNCov', 'get_compgcn_hpo_config'
]
|
@@ -0,0 +1,140 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/module/strategy/CompGCNSampling.py
|
4
|
+
#
|
5
|
+
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 20, 2023
|
7
|
+
#
|
8
|
+
# 该脚本定义了 CompGCN 模型的训练策略.
|
9
|
+
|
10
|
+
"""
|
11
|
+
CompGCNSampling - 训练策略类,包含损失函数。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import dgl
|
15
|
+
import torch
|
16
|
+
import typing
|
17
|
+
from ..loss import Loss
|
18
|
+
from ..model import CompGCN
|
19
|
+
from .Strategy import Strategy
|
20
|
+
|
21
|
+
class CompGCNSampling(Strategy):

    """
    Wrap a ``CompGCN`` :cite:`CompGCN` model and its loss function in one module so the
    pair can be handed to the trainer together.

    Example (``dataloader`` is assumed to be created beforehand)::

        from unike.module.model import CompGCN
        from unike.module.loss import CompGCNLoss
        from unike.module.strategy import CompGCNSampling
        from unike.config import Trainer, GraphTester

        # define the model
        compgcn = CompGCN(
            ent_tol = dataloader.train_sampler.ent_tol,
            rel_tol = dataloader.train_sampler.rel_tol,
            dim = 100
        )

        # define the loss function
        model = CompGCNSampling(
            model = compgcn,
            loss = CompGCNLoss(model = compgcn),
            ent_tol = dataloader.train_sampler.ent_tol
        )

        # test the model
        tester = GraphTester(model = compgcn, data_loader = dataloader, use_gpu = True, device = 'cuda:0', prediction = "tail")

        # train the model
        trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(),
            epochs = 2000, lr = 0.0001, use_gpu = True, device = 'cuda:0',
            tester = tester, test = True, valid_interval = 50, log_interval = 50,
            save_interval = 50, save_path = '../../checkpoint/compgcn.pth'
        )
        trainer.run()
    """

    def __init__(
        self,
        model: CompGCN = None,
        loss: Loss = None,
        smoothing: float = 0.1,
        ent_tol: int = None):

        """Create a CompGCNSampling strategy.

        :param model: the CompGCN model
        :type model: :py:class:`unike.module.model.CompGCN`
        :param loss: the loss function
        :type loss: :py:class:`unike.module.loss.Loss`
        :param smoothing: label-smoothing factor; targets are scaled by ``1 - smoothing``
            before ``1 / ent_tol`` is added (see :py:meth:`forward`)
        :type smoothing: float
        :param ent_tol: number of entities; must be set before :py:meth:`forward` is called
        :type ent_tol: int
        """

        super().__init__()

        #: The CompGCN model, i.e. :py:class:`unike.module.model.CompGCN`.
        self.model: CompGCN = model
        #: The loss function, i.e. :py:class:`unike.module.loss.Loss`.
        self.loss: Loss = loss
        #: Label-smoothing factor.
        self.smoothing: float = smoothing
        #: Number of entities.
        self.ent_tol: int = ent_tol

    def forward(
        self,
        data: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]) -> torch.Tensor:

        """Compute the final loss value (the per-call computation).
        Subclasses of :py:class:`torch.nn.Module` must override :py:meth:`torch.nn.Module.forward`.

        :param data: batch data with keys ``"graph"``, ``"relation"``, ``"norm"``, ``"sample"`` and ``"label"``
        :type data: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
        :returns: the loss value
        :rtype: torch.Tensor
        """

        predictions = self.model(data["graph"], data['relation'], data['norm'], data["sample"])
        # Smoothed targets: scale the labels by (1 - smoothing), then add a uniform 1 / ent_tol.
        targets = (1.0 - self.smoothing) * data["label"] + (1.0 / self.ent_tol)
        return self.loss(predictions, targets)
|
111
|
+
|
112
|
+
def get_compgcn_sampling_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyper-parameter optimisation (HPO) configuration of :py:class:`CompGCNSampling`.

    Default configuration::

        parameters_dict = {
            'strategy': {
                'value': 'CompGCNSampling'
            },
            'smoothing': {
                'value': 0.1
            }
        }

    :returns: the default HPO configuration of :py:class:`CompGCNSampling`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    # A fresh dict is built on every call, so callers may mutate the result freely.
    return dict(
        strategy={'value': 'CompGCNSampling'},
        smoothing={'value': 0.1},
    )
|