unike-3.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unike/__init__.py +5 -0
- unike/config/HPOTrainer.py +305 -0
- unike/config/Tester.py +385 -0
- unike/config/Trainer.py +519 -0
- unike/config/TrainerAccelerator.py +39 -0
- unike/config/__init__.py +37 -0
- unike/data/BernSampler.py +168 -0
- unike/data/CompGCNSampler.py +140 -0
- unike/data/CompGCNTestSampler.py +84 -0
- unike/data/KGEDataLoader.py +315 -0
- unike/data/KGReader.py +138 -0
- unike/data/RGCNSampler.py +261 -0
- unike/data/RGCNTestSampler.py +208 -0
- unike/data/RevSampler.py +78 -0
- unike/data/TestSampler.py +189 -0
- unike/data/TradSampler.py +122 -0
- unike/data/TradTestSampler.py +87 -0
- unike/data/UniSampler.py +145 -0
- unike/data/__init__.py +47 -0
- unike/module/BaseModule.py +130 -0
- unike/module/__init__.py +20 -0
- unike/module/loss/CompGCNLoss.py +96 -0
- unike/module/loss/Loss.py +26 -0
- unike/module/loss/MarginLoss.py +148 -0
- unike/module/loss/RGCNLoss.py +117 -0
- unike/module/loss/SigmoidLoss.py +145 -0
- unike/module/loss/SoftplusLoss.py +145 -0
- unike/module/loss/__init__.py +35 -0
- unike/module/model/Analogy.py +237 -0
- unike/module/model/CompGCN.py +562 -0
- unike/module/model/ComplEx.py +235 -0
- unike/module/model/DistMult.py +276 -0
- unike/module/model/HolE.py +308 -0
- unike/module/model/Model.py +107 -0
- unike/module/model/RESCAL.py +309 -0
- unike/module/model/RGCN.py +304 -0
- unike/module/model/RotatE.py +303 -0
- unike/module/model/SimplE.py +237 -0
- unike/module/model/TransD.py +458 -0
- unike/module/model/TransE.py +290 -0
- unike/module/model/TransH.py +322 -0
- unike/module/model/TransR.py +402 -0
- unike/module/model/__init__.py +60 -0
- unike/module/strategy/CompGCNSampling.py +140 -0
- unike/module/strategy/NegativeSampling.py +138 -0
- unike/module/strategy/RGCNSampling.py +134 -0
- unike/module/strategy/Strategy.py +26 -0
- unike/module/strategy/__init__.py +29 -0
- unike/utils/EarlyStopping.py +94 -0
- unike/utils/Timer.py +74 -0
- unike/utils/WandbLogger.py +46 -0
- unike/utils/__init__.py +26 -0
- unike/utils/tools.py +118 -0
- unike/version.py +1 -0
- unike-3.0.1.dist-info/METADATA +101 -0
- unike-3.0.1.dist-info/RECORD +59 -0
- unike-3.0.1.dist-info/WHEEL +4 -0
- unike-3.0.1.dist-info/entry_points.txt +2 -0
- unike-3.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,237 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/module/model/SimplE.py
|
4
|
+
#
|
5
|
+
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 7, 2023
|
7
|
+
#
|
8
|
+
# 该头文件定义了 SimplE.
|
9
|
+
|
10
|
+
"""
|
11
|
+
SimplE - 简单的双线性模型,能够为头实体和尾实体学习不同的嵌入向量。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import math
|
15
|
+
import torch
|
16
|
+
import typing
|
17
|
+
import numpy as np
|
18
|
+
import torch.nn as nn
|
19
|
+
from .Model import Model
|
20
|
+
from typing_extensions import override
|
21
|
+
|
22
|
+
class SimplE(Model):

	r"""
	``SimplE`` :cite:`SimplE`, proposed in 2018, is a simple bilinear model that
	learns distinct embedding vectors for the head role and the tail role of
	each entity.

	The scoring function is:

	.. math::

	    1/2(<\mathbf{h}_{i}, \mathbf{v}_r, \mathbf{t}_{j}> + <\mathbf{h}_{j}, \mathbf{v}_{r^{-1}}, \mathbf{t}_{i}>)

	where :math:`< \mathbf{a}, \mathbf{b}, \mathbf{c} >` is the element-wise
	multi-linear dot product.

	Higher scores are better for positive triples and lower scores are better
	for negative triples; see :ref:`SimplE <simple>` for more details.

	Example::

		from unike.config import Trainer, Tester
		from unike.module.model import SimplE
		from unike.module.loss import SoftplusLoss
		from unike.module.strategy import NegativeSampling

		# define the model
		simple = SimplE(
		    ent_tol = train_dataloader.get_ent_tol(),
		    rel_tol = train_dataloader.get_rel_tol(),
		    dim = config.dim
		)

		# define the loss function
		model = NegativeSampling(
		    model = simple,
		    loss = SoftplusLoss(),
		    batch_size = train_dataloader.get_batch_size(),
		    regul_rate = config.regul_rate
		)

		# dataloader for test
		test_dataloader = TestDataLoader(in_path = config.in_path)

		# test the model
		tester = Tester(model = simple, data_loader = test_dataloader, use_gpu = config.use_gpu, device = config.device)

		# train the model
		trainer = Trainer(model = model, data_loader = train_dataloader, epochs = config.epochs,
			lr = config.lr, opt_method = config.opt_method, use_gpu = config.use_gpu, device = config.device,
			tester = tester, test = config.test, valid_interval = config.valid_interval,
			log_interval = config.log_interval, save_interval = config.save_interval,
			save_path = config.save_path, use_wandb = True)
		trainer.run()
	"""

	def __init__(
		self,
		ent_tol: int,
		rel_tol: int,
		dim: int = 100):

		"""Create a SimplE object.

		:param ent_tol: number of entities
		:type ent_tol: int
		:param rel_tol: number of relations
		:type rel_tol: int
		:param dim: dimension of the entity and relation embedding vectors
		:type dim: int
		"""

		super(SimplE, self).__init__(ent_tol, rel_tol)

		#: dimension of the entity and relation embedding vectors
		self.dim: int = dim

		# Each row stores two ``dim``-sized vectors side by side
		# (head-role and tail-role for entities, forward and inverse for
		# relations), hence the ``self.dim * 2`` embedding width.
		#: entity embeddings, one row per entity
		self.ent_embeddings: torch.nn.Embedding = nn.Embedding(self.ent_tol, self.dim * 2)
		#: relation embeddings, one row per relation
		self.rel_embeddings: torch.nn.Embedding = nn.Embedding(self.rel_tol, self.dim * 2)

		nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
		nn.init.xavier_uniform_(self.rel_embeddings.weight.data)

	@override
	def forward(
		self,
		triples: torch.Tensor,
		negs: torch.Tensor | None = None,
		mode: str = 'single') -> torch.Tensor:

		"""
		Define the computation performed at every call.
		Subclasses of :py:class:`torch.nn.Module` must override
		:py:meth:`torch.nn.Module.forward`.

		:param triples: positive triples
		:type triples: torch.Tensor
		:param negs: negative entities
		:type negs: torch.Tensor | None
		:param mode: sampling mode
		:type mode: str
		:returns: scores of the triples
		:rtype: torch.Tensor
		"""

		head_emb, relation_emb, tail_emb = self.tri2emb(triples, negs, mode)
		score = self._calc(head_emb, relation_emb, tail_emb)
		return score

	def _calc(
		self,
		h: torch.Tensor,
		r: torch.Tensor,
		t: torch.Tensor) -> torch.Tensor:

		"""Compute the SimplE scoring function.

		:param h: head-entity vectors.
		:type h: torch.Tensor
		:param r: relation vectors.
		:type r: torch.Tensor
		:param t: tail-entity vectors.
		:type t: torch.Tensor
		:returns: scores of the triples
		:rtype: torch.Tensor
		"""

		# Split each concatenated embedding into its two role halves:
		# head-role / tail-role for entities, forward / inverse for relations.
		hh_embs, th_embs = torch.chunk(h, 2, dim=-1)
		r_embs, r_inv_embs = torch.chunk(r, 2, dim=-1)
		ht_embs, tt_embs = torch.chunk(t, 2, dim=-1)

		scores1 = torch.sum(hh_embs * r_embs * tt_embs, -1)
		scores2 = torch.sum(ht_embs * r_inv_embs * th_embs, -1)

		# Without clipping, we run into NaN problems.
		# Follows the original authors' implementation.
		return torch.clamp((scores1 + scores2) / 2, -20, 20)

	@override
	def predict(
		self,
		data: dict[str, typing.Union[torch.Tensor,str]],
		mode: str) -> torch.Tensor:

		"""Inference method of SimplE.

		:param data: batch data.
		:type data: dict[str, typing.Union[torch.Tensor,str]]
		:param mode: 'head_predict' or 'tail_predict'
		:type mode: str
		:returns: scores of the triples
		:rtype: torch.Tensor
		"""

		triples = data["positive_sample"]
		head_emb, relation_emb, tail_emb = self.tri2emb(triples, mode=mode)
		score = self._calc(head_emb, relation_emb, tail_emb)
		return score

	def regularization(
		self,
		data: dict[str, typing.Union[torch.Tensor, str]]) -> torch.Tensor:

		"""L2 regularization (weight decay), used inside the loss function.

		:param data: batch data.
		:type data: dict[str, typing.Union[torch.Tensor, str]]
		:returns: regularization loss of the model parameters
		:rtype: torch.Tensor
		"""

		pos_sample = data["positive_sample"]
		neg_sample = data["negative_sample"]
		mode = data["mode"]
		pos_head_emb, pos_relation_emb, pos_tail_emb = self.tri2emb(pos_sample)
		# Under Bernoulli sampling the negatives are complete triples;
		# otherwise they are candidate entities paired with the positives.
		if mode == "bern":
			neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(neg_sample)
		else:
			neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(pos_sample, neg_sample, mode)

		pos_regul = (torch.mean(pos_head_emb ** 2) +
					 torch.mean(pos_relation_emb ** 2) +
					 torch.mean(pos_tail_emb ** 2)) / 3

		neg_regul = (torch.mean(neg_head_emb ** 2) +
					 torch.mean(neg_relation_emb ** 2) +
					 torch.mean(neg_tail_emb ** 2)) / 3

		regul = (pos_regul + neg_regul) / 2

		return regul
|
208
|
+
|
209
|
+
def get_simple_hpo_config() -> dict[str, dict[str, typing.Any]]:

	"""Return the default hyper-parameter optimization configuration of :py:class:`SimplE`.

	The default configuration is::

	    parameters_dict = {
	        'model': {
	            'value': 'SimplE'
	        },
	        'dim': {
	            'values': [50, 100, 200]
	        }
	    }

	:returns: the default hyper-parameter optimization configuration of :py:class:`SimplE`
	:rtype: dict[str, dict[str, typing.Any]]
	"""

	return {
		'model': {'value': 'SimplE'},
		'dim': {'values': [50, 100, 200]},
	}
|
@@ -0,0 +1,458 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/module/model/TransD.py
|
4
|
+
#
|
5
|
+
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 25, 2023
|
7
|
+
#
|
8
|
+
# 该头文件定义了 TransD.
|
9
|
+
|
10
|
+
"""
|
11
|
+
TransD - 自动生成映射矩阵,简单而且高效,是对 TransR 的改进。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import torch
|
15
|
+
import typing
|
16
|
+
import torch.nn as nn
|
17
|
+
import torch.nn.functional as F
|
18
|
+
from .Model import Model
|
19
|
+
from typing_extensions import override
|
20
|
+
|
21
|
+
class TransD(Model):

	r"""
	``TransD`` :cite:`TransD`, proposed in 2015, constructs the projection
	matrices dynamically; it is simple, efficient, and improves on TransR.

	The scoring function is:

	.. math::

	    \parallel (\mathbf{r}_p \mathbf{h}_p^T + \mathbf{I})\mathbf{h} + \mathbf{r} - (\mathbf{r}_p \mathbf{t}_p^T + \mathbf{I})\mathbf{t} \parallel_{L_1/L_2}

	Lower scores are better for positive triples; see :ref:`TransD <transd>`
	for more details.

	Example::

		from unike.utils import WandbLogger
		from unike.data import KGEDataLoader, BernSampler, TradTestSampler
		from unike.module.model import TransD
		from unike.module.loss import MarginLoss
		from unike.module.strategy import NegativeSampling
		from unike.config import Trainer, Tester

		wandb_logger = WandbLogger(
			project="pybind11-ke",
			name="TransD-FB15K237",
			config=dict(
				in_path = '../../benchmarks/FB15K237/',
				batch_size = 2048,
				neg_ent = 25,
				test = True,
				test_batch_size = 10,
				num_workers = 16,
				dim_e = 200,
				dim_r = 200,
				p_norm = 1,
				norm_flag = True,
				margin = 4.0,
				use_gpu = True,
				device = 'cuda:1',
				epochs = 1000,
				lr = 1.0,
				opt_method = 'sgd',
				valid_interval = 100,
				log_interval = 100,
				save_interval = 100,
				save_path = '../../checkpoint/transd.pth'
			)
		)

		config = wandb_logger.config

		# dataloader for training
		dataloader = KGEDataLoader(
			in_path = config.in_path,
			batch_size = config.batch_size,
			neg_ent = config.neg_ent,
			test = config.test,
			test_batch_size = config.test_batch_size,
			num_workers = config.num_workers,
			train_sampler = BernSampler,
			test_sampler = TradTestSampler
		)

		# define the model
		transd = TransD(
			ent_tol = dataloader.get_ent_tol(),
			rel_tol = dataloader.get_rel_tol(),
			dim_e = config.dim_e,
			dim_r = config.dim_r,
			p_norm = config.p_norm,
			norm_flag = config.norm_flag)

		# define the loss function
		model = NegativeSampling(
			model = transd,
			loss = MarginLoss(margin = config.margin)
		)

		# test the model
		tester = Tester(model = transd, data_loader = dataloader, use_gpu = config.use_gpu, device = config.device)

		# train the model
		trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(), epochs = config.epochs,
			lr = config.lr, opt_method = config.opt_method, use_gpu = config.use_gpu, device = config.device,
			tester = tester, test = config.test, valid_interval = config.valid_interval,
			log_interval = config.log_interval, save_interval = config.save_interval,
			save_path = config.save_path, use_wandb = True)
		trainer.run()

		# close your wandb run
		wandb_logger.finish()
	"""

	def __init__(
		self,
		ent_tol: int,
		rel_tol: int,
		dim_e: int = 100,
		dim_r: int = 100,
		p_norm: int = 1,
		norm_flag: bool = True,
		margin: float | None = None):

		"""Create a TransD object.

		:param ent_tol: number of entities
		:type ent_tol: int
		:param rel_tol: number of relations
		:type rel_tol: int
		:param dim_e: dimension of entity embeddings and entity projection vectors
		:type dim_e: int
		:param dim_r: dimension of relation embeddings and relation projection vectors
		:type dim_r: int
		:param p_norm: distance norm of the scoring function; per the original paper, 1 or 2.
		:type p_norm: int
		:param norm_flag: whether to apply an L2-norm with :py:func:`torch.nn.functional.normalize` to the last dimension of the entity and relation embeddings.
		:type norm_flag: bool
		:param margin: required when using the ``RotatE`` :cite:`RotatE` loss function :py:class:`unike.module.loss.SigmoidLoss`; it converts the ``TransE`` :cite:`TransE` scores of positive triples from lower-is-better to higher-is-better. See :ref:`RotatE <rotate>` for more details.
		:type margin: float
		"""

		super(TransD, self).__init__(ent_tol, rel_tol)

		#: dimension of entity embeddings and entity projection vectors
		self.dim_e: int = dim_e
		#: dimension of relation embeddings and relation projection vectors
		self.dim_r: int = dim_r
		#: distance norm of the scoring function; per the original paper, 1 or 2.
		self.p_norm: int = p_norm
		#: whether to apply an L2-norm with :py:func:`torch.nn.functional.normalize`
		#: to the last dimension of the entity and relation embeddings.
		self.norm_flag: bool = norm_flag

		#: entity embeddings, one row per entity
		self.ent_embeddings: torch.nn.Embedding = nn.Embedding(self.ent_tol, self.dim_e)
		#: relation embeddings, one row per relation
		self.rel_embeddings: torch.nn.Embedding = nn.Embedding(self.rel_tol, self.dim_r)
		#: entity projection vectors, one row per entity
		self.ent_transfer: torch.nn.Embedding = nn.Embedding(self.ent_tol, self.dim_e)
		#: relation projection vectors, one row per relation
		self.rel_transfer: torch.nn.Embedding = nn.Embedding(self.rel_tol, self.dim_r)

		# `is not None` rather than `!= None`: identity test is the correct
		# (and idiomatic) check for the sentinel.
		if margin is not None:
			#: required when using the ``RotatE`` :cite:`RotatE` loss function
			#: :py:class:`unike.module.loss.SigmoidLoss`; converts positive-triple
			#: scores from lower-is-better to higher-is-better.
			self.margin: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor([margin]))
			self.margin.requires_grad = False
			self.margin_flag: bool = True
		else:
			self.margin_flag: bool = False

		nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
		nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
		nn.init.xavier_uniform_(self.ent_transfer.weight.data)
		nn.init.xavier_uniform_(self.rel_transfer.weight.data)

	@override
	def forward(
		self,
		triples: torch.Tensor,
		negs: torch.Tensor | None = None,
		mode: str = 'single') -> torch.Tensor:

		"""
		Define the computation performed at every call.
		Subclasses of :py:class:`torch.nn.Module` must override
		:py:meth:`torch.nn.Module.forward`.

		:param triples: positive triples
		:type triples: torch.Tensor
		:param negs: negative entities
		:type negs: torch.Tensor | None
		:param mode: sampling mode
		:type mode: str
		:returns: scores of the triples
		:rtype: torch.Tensor
		"""

		h, r, t = self.tri2emb(triples, negs, mode)
		h_transfer, r_transfer, t_transfer = self.tri2transfer(triples, negs, mode)
		h = self._transfer(h, h_transfer, r_transfer)
		t = self._transfer(t, t_transfer, r_transfer)
		score = self._calc(h, r, t)
		if self.margin_flag:
			return self.margin - score
		else:
			return score

	def tri2transfer(
		self,
		triples: torch.Tensor,
		negs: torch.Tensor | None = None,
		mode: str = 'single') -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

		"""
		Return the projection vectors corresponding to the triples.

		:param triples: positive triples
		:type triples: torch.Tensor
		:param negs: negative entities
		:type negs: torch.Tensor | None
		:param mode: sampling mode
		:type mode: str
		:returns: projection vectors of heads, relations and tails
		:rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
		"""

		if mode == "single":
			head_emb = self.ent_transfer(triples[:, 0]).unsqueeze(1)
			relation_emb = self.rel_transfer(triples[:, 1]).unsqueeze(1)
			tail_emb = self.ent_transfer(triples[:, 2]).unsqueeze(1)

		elif mode == "head-batch" or mode == "head_predict":
			# No explicit negatives: score against every entity.
			if negs is None:
				head_emb = self.ent_transfer.weight.data.unsqueeze(0)
			else:
				head_emb = self.ent_transfer(negs)

			relation_emb = self.rel_transfer(triples[:, 1]).unsqueeze(1)
			tail_emb = self.ent_transfer(triples[:, 2]).unsqueeze(1)

		elif mode == "tail-batch" or mode == "tail_predict":
			head_emb = self.ent_transfer(triples[:, 0]).unsqueeze(1)
			relation_emb = self.rel_transfer(triples[:, 1]).unsqueeze(1)

			# No explicit negatives: score against every entity.
			if negs is None:
				tail_emb = self.ent_transfer.weight.data.unsqueeze(0)
			else:
				tail_emb = self.ent_transfer(negs)

		return head_emb, relation_emb, tail_emb

	def _transfer(
		self,
		e: torch.Tensor,
		e_transfer: torch.Tensor,
		r_transfer: torch.Tensor) -> torch.Tensor:

		"""
		Project a head- or tail-entity vector into the relation vector space.

		:param e: head- or tail-entity vector.
		:type e: torch.Tensor
		:param e_transfer: projection vector of the head or tail entity
		:type e_transfer: torch.Tensor
		:param r_transfer: projection vector of the relation
		:type r_transfer: torch.Tensor
		:returns: the projected entity vector
		:rtype: torch.Tensor
		"""

		# (r_p h_p^T + I) h  ==  resize(h) + (h · h_p) r_p, computed without
		# materializing the projection matrix; the result is L2-normalized.
		return F.normalize(
			self._resize(e, len(e.size())-1, r_transfer.size()[-1]) + torch.sum(e * e_transfer, -1, True) * r_transfer,
			p = 2,
			dim = -1
		)

	def _resize(
		self,
		tensor: torch.Tensor,
		axis: int,
		size: int) -> torch.Tensor:

		"""Multiply an entity vector by the identity mapping and return the
		resized result.

		Uses :py:func:`torch.narrow` to shrink the vector and
		:py:func:`torch.nn.functional.pad` to pad it.

		:param tensor: entity vector.
		:type tensor: torch.Tensor
		:param axis: axis along which to resize.
		:type axis: int
		:param size: target size along ``axis``, usually the relation dimension.
		:type size: int
		:returns: the resized vector
		:rtype: torch.Tensor
		"""

		shape = tensor.size()
		osize = shape[axis]
		if osize == size:
			return tensor
		if osize > size:
			return torch.narrow(tensor, axis, 0, size)
		# F.pad expects pairs ordered from the LAST dimension backwards,
		# hence the prepending as we walk the shape front to back.
		paddings = []
		for i in range(len(shape)):
			if i == axis:
				paddings = [0, size - osize] + paddings
			else:
				paddings = [0, 0] + paddings
		return F.pad(tensor, paddings, mode = "constant", value = 0)

	def _calc(
		self,
		h: torch.Tensor,
		r: torch.Tensor,
		t: torch.Tensor) -> torch.Tensor:

		"""Compute the TransD scoring function.

		:param h: head-entity vectors.
		:type h: torch.Tensor
		:param r: relation vectors.
		:type r: torch.Tensor
		:param t: tail-entity vectors.
		:type t: torch.Tensor
		:returns: scores of the triples
		:rtype: torch.Tensor
		"""

		# Normalize the last dimension of the embeddings.
		if self.norm_flag:
			h = F.normalize(h, 2, -1)
			r = F.normalize(r, 2, -1)
			t = F.normalize(t, 2, -1)

		score = (h + r) - t

		# Score with the distance norm.
		score = torch.norm(score, self.p_norm, -1)
		return score

	@override
	def predict(
		self,
		data: dict[str, typing.Union[torch.Tensor,str]],
		mode: str) -> torch.Tensor:

		"""Inference method of TransD.

		:param data: batch data.
		:type data: dict[str, typing.Union[torch.Tensor,str]]
		:param mode: 'head_predict' or 'tail_predict'
		:type mode: str
		:returns: scores of the triples
		:rtype: torch.Tensor
		"""

		triples = data["positive_sample"]
		h, r, t = self.tri2emb(triples, mode=mode)
		h_transfer, r_transfer, t_transfer = self.tri2transfer(triples, mode=mode)
		h = self._transfer(h, h_transfer, r_transfer)
		t = self._transfer(t, t_transfer, r_transfer)
		score = self._calc(h, r, t)

		# Distance scores are lower-is-better; negate (or subtract from the
		# margin) so callers can always rank higher-is-better.
		if self.margin_flag:
			score = self.margin - score
			return score
		else:
			return -score

	def regularization(
		self,
		data: dict[str, typing.Union[torch.Tensor, str]]) -> torch.Tensor:

		"""L2 regularization (weight decay), used inside the loss function.

		:param data: batch data.
		:type data: dict[str, typing.Union[torch.Tensor, str]]
		:returns: regularization loss of the model parameters
		:rtype: torch.Tensor
		"""

		pos_sample = data["positive_sample"]
		neg_sample = data["negative_sample"]
		mode = data["mode"]
		pos_h, pos_r, pos_t = self.tri2emb(pos_sample)
		pos_h_transfer, pos_r_transfer, pos_t_transfer = self.tri2transfer(pos_sample)
		# Under Bernoulli sampling the negatives are complete triples;
		# otherwise they are candidate entities paired with the positives.
		if mode == "bern":
			neg_h, neg_r, neg_t = self.tri2emb(neg_sample)
			neg_h_transfer, neg_r_transfer, neg_t_transfer = self.tri2transfer(neg_sample)
		else:
			neg_h, neg_r, neg_t = self.tri2emb(pos_sample, neg_sample, mode)
			neg_h_transfer, neg_r_transfer, neg_t_transfer = self.tri2transfer(pos_sample, neg_sample, mode)

		pos_regul = (torch.mean(pos_h ** 2) +
					 torch.mean(pos_r ** 2) +
					 torch.mean(pos_t ** 2) +
					 torch.mean(pos_h_transfer ** 2) +
					 torch.mean(pos_r_transfer ** 2) +
					 torch.mean(pos_t_transfer ** 2)) / 6

		neg_regul = (torch.mean(neg_h ** 2) +
					 torch.mean(neg_r ** 2) +
					 torch.mean(neg_t ** 2) +
					 torch.mean(neg_h_transfer ** 2) +
					 torch.mean(neg_r_transfer ** 2) +
					 torch.mean(neg_t_transfer ** 2)) / 6

		regul = (pos_regul + neg_regul) / 2

		return regul
|
411
|
+
|
412
|
+
def get_transd_hpo_config() -> dict[str, dict[str, typing.Any]]:

	"""Return the default hyper-parameter optimization configuration of :py:class:`TransD`.

	The default configuration is::

	    parameters_dict = {
	        'model': {
	            'value': 'TransD'
	        },
	        'dim_e': {
	            'values': [50, 100, 200]
	        },
	        'dim_r': {
	            'values': [50, 100, 200]
	        },
	        'p_norm': {
	            'values': [1, 2]
	        },
	        'norm_flag': {
	            'value': True
	        }
	    }

	:returns: the default hyper-parameter optimization configuration of :py:class:`TransD`
	:rtype: dict[str, dict[str, typing.Any]]
	"""

	embedding_dims = [50, 100, 200]
	return {
		'model': {'value': 'TransD'},
		'dim_e': {'values': list(embedding_dims)},
		'dim_r': {'values': list(embedding_dims)},
		'p_norm': {'values': [1, 2]},
		'norm_flag': {'value': True},
	}
|