unike 3.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unike/__init__.py +5 -0
- unike/config/HPOTrainer.py +305 -0
- unike/config/Tester.py +385 -0
- unike/config/Trainer.py +519 -0
- unike/config/TrainerAccelerator.py +39 -0
- unike/config/__init__.py +37 -0
- unike/data/BernSampler.py +168 -0
- unike/data/CompGCNSampler.py +140 -0
- unike/data/CompGCNTestSampler.py +84 -0
- unike/data/KGEDataLoader.py +315 -0
- unike/data/KGReader.py +138 -0
- unike/data/RGCNSampler.py +261 -0
- unike/data/RGCNTestSampler.py +208 -0
- unike/data/RevSampler.py +78 -0
- unike/data/TestSampler.py +189 -0
- unike/data/TradSampler.py +122 -0
- unike/data/TradTestSampler.py +87 -0
- unike/data/UniSampler.py +145 -0
- unike/data/__init__.py +47 -0
- unike/module/BaseModule.py +130 -0
- unike/module/__init__.py +20 -0
- unike/module/loss/CompGCNLoss.py +96 -0
- unike/module/loss/Loss.py +26 -0
- unike/module/loss/MarginLoss.py +148 -0
- unike/module/loss/RGCNLoss.py +117 -0
- unike/module/loss/SigmoidLoss.py +145 -0
- unike/module/loss/SoftplusLoss.py +145 -0
- unike/module/loss/__init__.py +35 -0
- unike/module/model/Analogy.py +237 -0
- unike/module/model/CompGCN.py +562 -0
- unike/module/model/ComplEx.py +235 -0
- unike/module/model/DistMult.py +276 -0
- unike/module/model/HolE.py +308 -0
- unike/module/model/Model.py +107 -0
- unike/module/model/RESCAL.py +309 -0
- unike/module/model/RGCN.py +304 -0
- unike/module/model/RotatE.py +303 -0
- unike/module/model/SimplE.py +237 -0
- unike/module/model/TransD.py +458 -0
- unike/module/model/TransE.py +290 -0
- unike/module/model/TransH.py +322 -0
- unike/module/model/TransR.py +402 -0
- unike/module/model/__init__.py +60 -0
- unike/module/strategy/CompGCNSampling.py +140 -0
- unike/module/strategy/NegativeSampling.py +138 -0
- unike/module/strategy/RGCNSampling.py +134 -0
- unike/module/strategy/Strategy.py +26 -0
- unike/module/strategy/__init__.py +29 -0
- unike/utils/EarlyStopping.py +94 -0
- unike/utils/Timer.py +74 -0
- unike/utils/WandbLogger.py +46 -0
- unike/utils/__init__.py +26 -0
- unike/utils/tools.py +118 -0
- unike/version.py +1 -0
- unike-3.0.1.dist-info/METADATA +101 -0
- unike-3.0.1.dist-info/RECORD +59 -0
- unike-3.0.1.dist-info/WHEEL +4 -0
- unike-3.0.1.dist-info/entry_points.txt +2 -0
- unike-3.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,290 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/module/model/TransE.py
|
4
|
+
#
|
5
|
+
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 9, 2024
|
7
|
+
#
|
8
|
+
# 该头文件定义了 TransE.
|
9
|
+
|
10
|
+
"""
|
11
|
+
TransE - 第一个平移模型,简单而且高效。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import torch
|
15
|
+
import typing
|
16
|
+
import torch.nn as nn
|
17
|
+
import torch.nn.functional as F
|
18
|
+
from .Model import Model
|
19
|
+
from typing_extensions import override
|
20
|
+
|
21
|
+
class TransE(Model):

    r"""
    ``TransE`` :cite:`TransE` was proposed in 2013. It is the first
    translational model and opened up the translation-based line of research.
    Thanks to its simplicity and efficiency it is still a common baseline, and
    on some datasets it performs better than more complex models.

    The scoring function is:

    .. math::

        \parallel h + r - t \parallel_{L_1/L_2}

    Lower scores are better for positive triples. For more details see
    :ref:`TransE <transe>`.

    Example::

        from unike.data import KGEDataLoader, BernSampler, TradTestSampler
        from unike.module.model import TransE
        from unike.module.loss import MarginLoss
        from unike.module.strategy import NegativeSampling
        from unike.config import Trainer, Tester

        # dataloader for training
        dataloader = KGEDataLoader(
            in_path = "../../benchmarks/FB15K/",
            batch_size = 8192,
            neg_ent = 25,
            test = True,
            test_batch_size = 256,
            num_workers = 16,
            train_sampler = BernSampler,
            test_sampler = TradTestSampler
        )

        # define the model
        transe = TransE(
            ent_tol = dataloader.get_ent_tol(),
            rel_tol = dataloader.get_rel_tol(),
            dim = 50,
            p_norm = 1,
            norm_flag = True)

        # define the loss function
        model = NegativeSampling(
            model = transe,
            loss = MarginLoss(margin = 1.0),
            regul_rate = 0.01
        )

        # test the model
        tester = Tester(model = transe, data_loader = dataloader, use_gpu = True, device = 'cuda:1')

        # train the model
        trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(),
            epochs = 1000, lr = 0.01, use_gpu = True, device = 'cuda:1',
            tester = tester, test = True, valid_interval = 100,
            log_interval = 100, save_interval = 100,
            save_path = '../../checkpoint/transe.pth', delta = 0.01)
        trainer.run()
    """

    def __init__(
        self,
        ent_tol: int,
        rel_tol: int,
        dim: int = 100,
        p_norm: int = 1,
        norm_flag: bool = True,
        margin: float | None = None):

        """Create a TransE model.

        :param ent_tol: number of entities
        :type ent_tol: int
        :param rel_tol: number of relations
        :type rel_tol: int
        :param dim: dimension of the entity and relation embeddings
        :type dim: int
        :param p_norm: order of the distance used by the scoring function;
            following the original paper this is 1 or 2
        :type p_norm: int
        :param norm_flag: whether to L2-normalize the last dimension of the
            entity and relation embeddings with
            :py:func:`torch.nn.functional.normalize` before scoring
        :type norm_flag: bool
        :param margin: required when training with the ``RotatE`` :cite:`RotatE`
            loss :py:class:`unike.module.loss.SigmoidLoss`. Scores are then
            reported as ``margin - distance``, which turns TransE's
            "lower is better" score for positive triples into "higher is
            better". See :ref:`RotatE <rotate>` for details.
        :type margin: float | None
        """

        super().__init__(ent_tol, rel_tol)

        #: dimension of the entity and relation embeddings
        self.dim: int = dim
        #: order of the distance used by the scoring function (1 or 2 in the paper)
        self.p_norm: int = p_norm
        #: whether to L2-normalize the last dimension of the embeddings with
        #: :py:func:`torch.nn.functional.normalize` before scoring
        self.norm_flag: bool = norm_flag

        #: entity embedding table, one row per entity
        self.ent_embeddings: torch.nn.Embedding = nn.Embedding(self.ent_tol, self.dim)
        #: relation embedding table, one row per relation
        self.rel_embeddings: torch.nn.Embedding = nn.Embedding(self.rel_tol, self.dim)

        if margin is not None:
            #: fixed (non-trainable) margin used to report scores as
            #: ``margin - distance`` (needed by :py:class:`unike.module.loss.SigmoidLoss`).
            #: Stored as a Parameter so it follows the module across devices and
            #: is saved in the state dict, but gradients are disabled.
            self.margin: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor([margin]))
            self.margin.requires_grad = False
            self.margin_flag: bool = True
        else:
            self.margin_flag: bool = False

        nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
        nn.init.xavier_uniform_(self.rel_embeddings.weight.data)

    @override
    def forward(
        self,
        triples: torch.Tensor,
        negs: typing.Optional[torch.Tensor] = None,
        mode: str = 'single') -> torch.Tensor:

        """Compute TransE scores for a batch (training-time entry point).

        Overrides :py:meth:`torch.nn.Module.forward`.

        :param triples: the positive triples
        :type triples: torch.Tensor
        :param negs: negative samples, forwarded unchanged to
            :py:meth:`tri2emb`; ``None`` to score ``triples`` themselves
        :type negs: torch.Tensor | None
        :param mode: sampling mode, forwarded unchanged to :py:meth:`tri2emb`
        :type mode: str
        :returns: the raw distance (lower is better) or, when ``margin`` was
            given, ``margin - distance`` (higher is better)
        :rtype: torch.Tensor
        """

        head_emb, relation_emb, tail_emb = self.tri2emb(triples, negs, mode)
        score = self._calc(head_emb, relation_emb, tail_emb)
        if self.margin_flag:
            return self.margin - score
        else:
            return score

    def _calc(
        self,
        h: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor) -> torch.Tensor:

        """Evaluate the TransE scoring function ``||h + r - t||_p``.

        :param h: head entity embeddings
        :type h: torch.Tensor
        :param r: relation embeddings
        :type r: torch.Tensor
        :param t: tail entity embeddings
        :type t: torch.Tensor
        :returns: the p-norm distance taken over the last dimension
        :rtype: torch.Tensor
        """

        # L2-normalize the last dimension of every embedding
        if self.norm_flag:
            h = F.normalize(h, 2, -1)
            r = F.normalize(r, 2, -1)
            t = F.normalize(t, 2, -1)

        score = (h + r) - t

        # reduce the translation residual to a distance with the chosen p-norm
        score = torch.norm(score, self.p_norm, -1)
        return score

    @override
    def predict(
        self,
        data: dict[str, typing.Union[torch.Tensor,str]],
        mode: str) -> torch.Tensor:

        """Inference for TransE.

        :param data: batch dictionary; only ``data["positive_sample"]`` is read
        :type data: dict[str, typing.Union[torch.Tensor,str]]
        :param mode: ``'head_predict'`` or ``'tail_predict'``
        :type mode: str
        :returns: scores where higher is better: ``margin - distance`` when a
            margin was configured, otherwise ``-distance``
        :rtype: torch.Tensor
        """

        triples = data["positive_sample"]
        head_emb, relation_emb, tail_emb = self.tri2emb(triples, mode=mode)
        score = self._calc(head_emb, relation_emb, tail_emb)

        if self.margin_flag:
            score = self.margin - score
            return score
        else:
            # negate the distance so that, as in the margin branch, larger is better
            return -score

    def regularization(
        self,
        data: dict[str, typing.Union[torch.Tensor, str]]) -> torch.Tensor:

        """L2 regularization (a.k.a. weight decay) term used by the loss.

        :param data: batch dictionary; reads ``"positive_sample"``,
            ``"negative_sample"`` and ``"mode"``
        :type data: dict[str, typing.Union[torch.Tensor,str]]
        :returns: mean of the positive and negative regularization terms, each
            being the average of the mean squared h, r and t embeddings
        :rtype: torch.Tensor
        """

        pos_sample = data["positive_sample"]
        neg_sample = data["negative_sample"]
        mode = data["mode"]
        pos_head_emb, pos_relation_emb, pos_tail_emb = self.tri2emb(pos_sample)
        if mode == "bern":
            # in "bern" mode the negatives are passed to tri2emb as triples on their own
            neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(neg_sample)
        else:
            neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(pos_sample, neg_sample, mode)

        pos_regul = (torch.mean(pos_head_emb ** 2) +
                     torch.mean(pos_relation_emb ** 2) +
                     torch.mean(pos_tail_emb ** 2)) / 3

        neg_regul = (torch.mean(neg_head_emb ** 2) +
                     torch.mean(neg_relation_emb ** 2) +
                     torch.mean(neg_tail_emb ** 2)) / 3

        regul = (pos_regul + neg_regul) / 2

        return regul
|
249
|
+
|
250
|
+
def get_transe_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyper-parameter optimization config for :py:class:`TransE`.

    A new dictionary is built on every call, so callers may mutate the result
    freely. The keys, in order, are:

    * ``'model'`` -> ``{'value': 'TransE'}``
    * ``'dim'`` -> ``{'values': [50, 100, 200]}``
    * ``'p_norm'`` -> ``{'values': [1, 2]}``
    * ``'norm_flag'`` -> ``{'value': True}``

    NOTE(review): presumably ``'value'`` marks a fixed setting and ``'values'``
    a list of candidates to search (sweep-style); confirm against the HPO trainer.

    :returns: the default HPO config for :py:class:`TransE`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    search_space: dict[str, dict[str, typing.Any]] = {}
    search_space['model'] = {'value': 'TransE'}
    search_space['dim'] = {'values': [50, 100, 200]}
    search_space['p_norm'] = {'values': [1, 2]}
    search_space['norm_flag'] = {'value': True}
    return search_space
|
@@ -0,0 +1,322 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/module/model/TransH.py
|
4
|
+
#
|
5
|
+
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 25, 2024
|
7
|
+
#
|
8
|
+
# 该头文件定义了 TransH.
|
9
|
+
|
10
|
+
"""
|
11
|
+
TransH - 是第二个平移模型,将关系建模为超平面上的平移操作。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import torch
|
15
|
+
import typing
|
16
|
+
import torch.nn as nn
|
17
|
+
import torch.nn.functional as F
|
18
|
+
from .Model import Model
|
19
|
+
from typing_extensions import override
|
20
|
+
|
21
|
+
class TransH(Model):

    r"""
    ``TransH`` :cite:`TransH` was proposed in 2014. It is the second
    translational model and models each relation as a translation on a
    relation-specific hyperplane.

    The scoring function is:

    .. math::

        \Vert (h-r_w^T hr_w)+r_d-(t-r_w^T tr_w)\Vert_{L_1/L_2}

    Lower scores are better for positive triples. For more details see
    :ref:`TransH <transh>`.

    Example::

        from unike.data import KGEDataLoader, BernSampler, TradTestSampler
        from unike.module.model import TransH
        from unike.module.loss import MarginLoss
        from unike.module.strategy import NegativeSampling
        from unike.config import Trainer, Tester

        # dataloader for training
        dataloader = KGEDataLoader(
            in_path = "../../benchmarks/FB15K237/",
            batch_size = 4096,
            neg_ent = 25,
            test = True,
            test_batch_size = 30,
            num_workers = 16,
            train_sampler = BernSampler,
            test_sampler = TradTestSampler
        )

        # define the model
        transh = TransH(
            ent_tol = dataloader.get_ent_tol(),
            rel_tol = dataloader.get_rel_tol(),
            dim = 200,
            p_norm = 1,
            norm_flag = True)

        # define the loss function
        model = NegativeSampling(
            model = transh,
            loss = MarginLoss(margin = 4.0),
            # regul_rate = 0.01
        )

        # test the model
        tester = Tester(model = transh, data_loader = dataloader, use_gpu = True, device = 'cuda:1')

        # train the model
        trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(),
            epochs = 1000, lr = 0.5, use_gpu = True, device = 'cuda:1',
            tester = tester, test = True, valid_interval = 100,
            log_interval = 100, save_interval = 100, save_path = '../../checkpoint/transh.pth',
            delta = 0.01)
        trainer.run()
    """

    def __init__(
        self,
        ent_tol: int,
        rel_tol: int,
        dim: int = 100,
        p_norm: int = 1,
        norm_flag: bool = True,
        margin: float | None = None):

        """Create a TransH model.

        :param ent_tol: number of entities
        :type ent_tol: int
        :param rel_tol: number of relations
        :type rel_tol: int
        :param dim: dimension of the entity embeddings, relation embeddings and
            hyperplane normal vectors
        :type dim: int
        :param p_norm: order of the distance used by the scoring function;
            following the original paper this is 1 or 2
        :type p_norm: int
        :param norm_flag: whether to L2-normalize the last dimension of the
            (projected) entity and relation embeddings with
            :py:func:`torch.nn.functional.normalize` before scoring
        :type norm_flag: bool
        :param margin: required when training with the ``RotatE`` :cite:`RotatE`
            loss :py:class:`unike.module.loss.SigmoidLoss`. Scores are then
            reported as ``margin - distance``, which turns TransH's
            "lower is better" score for positive triples into "higher is
            better". See :ref:`RotatE <rotate>` for details.
        :type margin: float | None
        """

        super().__init__(ent_tol, rel_tol)

        #: dimension of the entity/relation embeddings and the normal vectors
        self.dim: int = dim
        #: order of the distance used by the scoring function (1 or 2 in the paper)
        self.p_norm: int = p_norm
        #: whether to L2-normalize the last dimension of the embeddings with
        #: :py:func:`torch.nn.functional.normalize` before scoring
        self.norm_flag: bool = norm_flag

        #: entity embedding table, one row per entity
        self.ent_embeddings: torch.nn.Embedding = nn.Embedding(self.ent_tol, self.dim)
        #: relation (translation) embedding table, one row per relation
        self.rel_embeddings: torch.nn.Embedding = nn.Embedding(self.rel_tol, self.dim)
        #: hyperplane normal vectors, one row per relation
        self.norm_vector: torch.nn.Embedding = nn.Embedding(self.rel_tol, self.dim)

        if margin is not None:
            #: fixed (non-trainable) margin used to report scores as
            #: ``margin - distance`` (needed by :py:class:`unike.module.loss.SigmoidLoss`).
            #: Stored as a Parameter so it follows the module across devices and
            #: is saved in the state dict, but gradients are disabled.
            self.margin: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor([margin]))
            self.margin.requires_grad = False
            self.margin_flag: bool = True
        else:
            self.margin_flag: bool = False

        nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
        nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
        nn.init.xavier_uniform_(self.norm_vector.weight.data)

    @override
    def forward(
        self,
        triples: torch.Tensor,
        negs: typing.Optional[torch.Tensor] = None,
        mode: str = 'single') -> torch.Tensor:

        """Compute TransH scores for a batch (training-time entry point).

        Overrides :py:meth:`torch.nn.Module.forward`.

        :param triples: the positive triples; column 1 is used to look up the
            relation's hyperplane normal
        :type triples: torch.Tensor
        :param negs: negative samples, forwarded unchanged to
            :py:meth:`tri2emb`; ``None`` to score ``triples`` themselves
        :type negs: torch.Tensor | None
        :param mode: sampling mode, forwarded unchanged to :py:meth:`tri2emb`
        :type mode: str
        :returns: the raw distance (lower is better) or, when ``margin`` was
            given, ``margin - distance`` (higher is better)
        :rtype: torch.Tensor
        """

        head_emb, relation_emb, tail_emb = self.tri2emb(triples, negs, mode)
        # extra axis at dim 1 so the normals broadcast against the entity
        # embeddings (assumes tri2emb returns (batch, k, dim) tensors — TODO confirm)
        norm_vector = self.norm_vector(triples[:, 1]).unsqueeze(dim=1)
        head_emb = self._transfer(head_emb, norm_vector)
        tail_emb = self._transfer(tail_emb, norm_vector)
        score = self._calc(head_emb, relation_emb, tail_emb)

        if self.margin_flag:
            return self.margin - score
        else:
            return score

    def _transfer(
        self,
        e: torch.Tensor,
        norm: torch.Tensor) -> torch.Tensor:

        """Project head or tail entity vectors onto the relation hyperplane.

        The normal is L2-normalized to ``n`` first, then the component of ``e``
        along it is removed: ``e - (e . n) n``.

        :param e: head or tail entity vectors
        :type e: torch.Tensor
        :param norm: hyperplane normal vectors (need not be unit length)
        :type norm: torch.Tensor
        :returns: the projected entity vectors
        :rtype: torch.Tensor
        """

        norm = F.normalize(norm, p = 2, dim = -1)
        return e - torch.sum(e * norm, -1, True) * norm

    def _calc(
        self,
        h: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor) -> torch.Tensor:

        """Evaluate the TransH distance ``||h + r - t||_p`` on projected entities.

        :param h: projected head entity embeddings
        :type h: torch.Tensor
        :param r: relation embeddings
        :type r: torch.Tensor
        :param t: projected tail entity embeddings
        :type t: torch.Tensor
        :returns: the p-norm distance taken over the last dimension
        :rtype: torch.Tensor
        """

        # L2-normalize the last dimension of every embedding
        if self.norm_flag:
            h = F.normalize(h, 2, -1)
            r = F.normalize(r, 2, -1)
            t = F.normalize(t, 2, -1)

        score = (h + r) - t

        # reduce the translation residual to a distance with the chosen p-norm
        score = torch.norm(score, self.p_norm, -1)
        return score

    @override
    def predict(
        self,
        data: dict[str, typing.Union[torch.Tensor,str]],
        mode: str) -> torch.Tensor:

        """Inference for TransH.

        :param data: batch dictionary; only ``data["positive_sample"]`` is read
        :type data: dict[str, typing.Union[torch.Tensor,str]]
        :param mode: ``'head_predict'`` or ``'tail_predict'``
        :type mode: str
        :returns: scores where higher is better: ``margin - distance`` when a
            margin was configured, otherwise ``-distance``
        :rtype: torch.Tensor
        """

        triples = data["positive_sample"]
        head_emb, relation_emb, tail_emb = self.tri2emb(triples, mode=mode)
        norm_vector = self.norm_vector(triples[:, 1]).unsqueeze(dim=1)
        head_emb = self._transfer(head_emb, norm_vector)
        tail_emb = self._transfer(tail_emb, norm_vector)
        score = self._calc(head_emb, relation_emb, tail_emb)

        if self.margin_flag:
            score = self.margin - score
            return score
        else:
            # negate the distance so that, as in the margin branch, larger is better
            return -score

    def regularization(
        self,
        data: dict[str, typing.Union[torch.Tensor, str]]) -> torch.Tensor:

        """L2 regularization (a.k.a. weight decay) term used by the loss.

        :param data: batch dictionary; reads ``"positive_sample"``,
            ``"negative_sample"`` and ``"mode"``
        :type data: dict[str, typing.Union[torch.Tensor, str]]
        :returns: mean of the positive and negative regularization terms, each
            being the average of the mean squared h, r, t embeddings and
            hyperplane normals
        :rtype: torch.Tensor
        """

        pos_sample = data["positive_sample"]
        neg_sample = data["negative_sample"]
        mode = data["mode"]
        pos_head_emb, pos_relation_emb, pos_tail_emb = self.tri2emb(pos_sample)
        pos_norm_vector = self.norm_vector(pos_sample[:, 1]).unsqueeze(dim=1)
        if mode == "bern":
            # "bern" negatives are scored as triples on their own, so their
            # relation column drives both the relation embeddings and the normals
            neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(neg_sample)
            neg_norm_vector = self.norm_vector(neg_sample[:, 1]).unsqueeze(dim=1)
        else:
            # other modes pass pos_sample to tri2emb alongside neg_sample,
            # so the relations (and hence the normals) come from pos_sample
            neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(pos_sample, neg_sample, mode)
            neg_norm_vector = self.norm_vector(pos_sample[:, 1]).unsqueeze(dim=1)

        pos_regul = (torch.mean(pos_head_emb ** 2) +
                     torch.mean(pos_relation_emb ** 2) +
                     torch.mean(pos_tail_emb ** 2) +
                     torch.mean(pos_norm_vector ** 2)) / 4

        neg_regul = (torch.mean(neg_head_emb ** 2) +
                     torch.mean(neg_relation_emb ** 2) +
                     torch.mean(neg_tail_emb ** 2) +
                     torch.mean(neg_norm_vector ** 2)) / 4

        regul = (pos_regul + neg_regul) / 2

        return regul
|
281
|
+
|
282
|
+
def get_transh_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyper-parameter optimization config for :py:class:`TransH`.

    Every call builds a brand-new dictionary (including the inner dicts and
    lists), so the result can be modified without affecting later calls.
    Keys, in insertion order:

    * ``'model'`` -> ``{'value': 'TransH'}``
    * ``'dim'`` -> ``{'values': [50, 100, 200]}``
    * ``'p_norm'`` -> ``{'values': [1, 2]}``
    * ``'norm_flag'`` -> ``{'value': True}``

    NOTE(review): presumably ``'value'`` marks a fixed setting and ``'values'``
    a list of candidates to search (sweep-style); confirm against the HPO trainer.

    :returns: the default HPO config for :py:class:`TransH`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    entries = (
        ('model', {'value': 'TransH'}),
        ('dim', {'values': [50, 100, 200]}),
        ('p_norm', {'values': [1, 2]}),
        ('norm_flag', {'value': True}),
    )
    return dict(entries)
|