unike-3.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. unike/__init__.py +5 -0
  2. unike/config/HPOTrainer.py +305 -0
  3. unike/config/Tester.py +385 -0
  4. unike/config/Trainer.py +519 -0
  5. unike/config/TrainerAccelerator.py +39 -0
  6. unike/config/__init__.py +37 -0
  7. unike/data/BernSampler.py +168 -0
  8. unike/data/CompGCNSampler.py +140 -0
  9. unike/data/CompGCNTestSampler.py +84 -0
  10. unike/data/KGEDataLoader.py +315 -0
  11. unike/data/KGReader.py +138 -0
  12. unike/data/RGCNSampler.py +261 -0
  13. unike/data/RGCNTestSampler.py +208 -0
  14. unike/data/RevSampler.py +78 -0
  15. unike/data/TestSampler.py +189 -0
  16. unike/data/TradSampler.py +122 -0
  17. unike/data/TradTestSampler.py +87 -0
  18. unike/data/UniSampler.py +145 -0
  19. unike/data/__init__.py +47 -0
  20. unike/module/BaseModule.py +130 -0
  21. unike/module/__init__.py +20 -0
  22. unike/module/loss/CompGCNLoss.py +96 -0
  23. unike/module/loss/Loss.py +26 -0
  24. unike/module/loss/MarginLoss.py +148 -0
  25. unike/module/loss/RGCNLoss.py +117 -0
  26. unike/module/loss/SigmoidLoss.py +145 -0
  27. unike/module/loss/SoftplusLoss.py +145 -0
  28. unike/module/loss/__init__.py +35 -0
  29. unike/module/model/Analogy.py +237 -0
  30. unike/module/model/CompGCN.py +562 -0
  31. unike/module/model/ComplEx.py +235 -0
  32. unike/module/model/DistMult.py +276 -0
  33. unike/module/model/HolE.py +308 -0
  34. unike/module/model/Model.py +107 -0
  35. unike/module/model/RESCAL.py +309 -0
  36. unike/module/model/RGCN.py +304 -0
  37. unike/module/model/RotatE.py +303 -0
  38. unike/module/model/SimplE.py +237 -0
  39. unike/module/model/TransD.py +458 -0
  40. unike/module/model/TransE.py +290 -0
  41. unike/module/model/TransH.py +322 -0
  42. unike/module/model/TransR.py +402 -0
  43. unike/module/model/__init__.py +60 -0
  44. unike/module/strategy/CompGCNSampling.py +140 -0
  45. unike/module/strategy/NegativeSampling.py +138 -0
  46. unike/module/strategy/RGCNSampling.py +134 -0
  47. unike/module/strategy/Strategy.py +26 -0
  48. unike/module/strategy/__init__.py +29 -0
  49. unike/utils/EarlyStopping.py +94 -0
  50. unike/utils/Timer.py +74 -0
  51. unike/utils/WandbLogger.py +46 -0
  52. unike/utils/__init__.py +26 -0
  53. unike/utils/tools.py +118 -0
  54. unike/version.py +1 -0
  55. unike-3.0.1.dist-info/METADATA +101 -0
  56. unike-3.0.1.dist-info/RECORD +59 -0
  57. unike-3.0.1.dist-info/WHEEL +4 -0
  58. unike-3.0.1.dist-info/entry_points.txt +2 -0
  59. unike-3.0.1.dist-info/licenses/LICENSE +21 -0
unike/module/model/SimplE.py
@@ -0,0 +1,237 @@
+ # coding:utf-8
+ #
+ # unike/module/model/SimplE.py
+ #
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 7, 2023
+ #
+ # This module defines SimplE.
+
+ """
+ SimplE - a simple bilinear model that learns separate embedding vectors for entities in the head and tail positions.
+ """
+
+ import math
+ import torch
+ import typing
+ import numpy as np
+ import torch.nn as nn
+ from .Model import Model
+ from typing_extensions import override
+
+ class SimplE(Model):
+
+     """
+     ``SimplE`` :cite:`SimplE`, proposed in 2018, is a simple bilinear model that learns separate embedding vectors for entities in the head and tail positions.
+
+     The scoring function is:
+
+     .. math::
+
+         1/2(<\mathbf{h}_{i}, \mathbf{v}_r, \mathbf{t}_{j}> + <\mathbf{h}_{j}, \mathbf{v}_{r^{-1}}, \mathbf{t}_{i}>)
+
+     where :math:`< \mathbf{a}, \mathbf{b}, \mathbf{c} >` is the element-wise multi-linear dot product.
+
+     Higher scores are better for positive triples and lower scores are better for negative triples; see :ref:`SimplE <simple>` for more details.
+
+     Example::
+
+         from unike.config import Trainer, Tester
+         from unike.module.model import SimplE
+         from unike.module.loss import SoftplusLoss
+         from unike.module.strategy import NegativeSampling
+
+         # define the model
+         simple = SimplE(
+             ent_tol = train_dataloader.get_ent_tol(),
+             rel_tol = train_dataloader.get_rel_tol(),
+             dim = config.dim
+         )
+
+         # define the loss function
+         model = NegativeSampling(
+             model = simple,
+             loss = SoftplusLoss(),
+             batch_size = train_dataloader.get_batch_size(),
+             regul_rate = config.regul_rate
+         )
+
+         # dataloader for test
+         test_dataloader = TestDataLoader(in_path = config.in_path)
+
+         # test the model
+         tester = Tester(model = simple, data_loader = test_dataloader, use_gpu = config.use_gpu, device = config.device)
+
+         # train the model
+         trainer = Trainer(model = model, data_loader = train_dataloader, epochs = config.epochs,
+             lr = config.lr, opt_method = config.opt_method, use_gpu = config.use_gpu, device = config.device,
+             tester = tester, test = config.test, valid_interval = config.valid_interval,
+             log_interval = config.log_interval, save_interval = config.save_interval,
+             save_path = config.save_path, use_wandb = True)
+         trainer.run()
+     """
+
+     def __init__(
+         self,
+         ent_tol: int,
+         rel_tol: int,
+         dim: int = 100):
+
+         """Create a SimplE object.
+
+         :param ent_tol: number of entities
+         :type ent_tol: int
+         :param rel_tol: number of relations
+         :type rel_tol: int
+         :param dim: dimension of the entity and relation embedding vectors
+         :type dim: int
+         """
+
+         super(SimplE, self).__init__(ent_tol, rel_tol)
+
+         #: dimension of the entity and relation embedding vectors
+         self.dim: int = dim
+
+         #: entity embeddings, created according to the number of entities
+         self.ent_embeddings: torch.nn.Embedding = nn.Embedding(self.ent_tol, self.dim * 2)
+         #: relation embeddings, created according to the number of relations
+         self.rel_embeddings: torch.nn.Embedding = nn.Embedding(self.rel_tol, self.dim * 2)
+
+         nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
+         nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
+
+     @override
+     def forward(
+         self,
+         triples: torch.Tensor,
+         negs: torch.Tensor = None,
+         mode: str = 'single') -> torch.Tensor:
+
+         """
+         Defines the computation performed at every call.
+         Subclasses of :py:class:`torch.nn.Module` must override :py:meth:`torch.nn.Module.forward`.
+
+         :param triples: positive triples
+         :type triples: torch.Tensor
+         :param negs: negative sample entities
+         :type negs: torch.Tensor
+         :param mode: scoring mode
+         :type mode: str
+         :returns: scores of the triples
+         :rtype: torch.Tensor
+         """
+
+         head_emb, relation_emb, tail_emb = self.tri2emb(triples, negs, mode)
+         score = self._calc(head_emb, relation_emb, tail_emb)
+         return score
+
+     def _calc(
+         self,
+         h: torch.Tensor,
+         r: torch.Tensor,
+         t: torch.Tensor) -> torch.Tensor:
+
+         """Compute the SimplE scoring function.
+
+         :param h: head entity embeddings
+         :type h: torch.Tensor
+         :param r: relation embeddings
+         :type r: torch.Tensor
+         :param t: tail entity embeddings
+         :type t: torch.Tensor
+         :returns: scores of the triples
+         :rtype: torch.Tensor
+         """
+
+         hh_embs, th_embs = torch.chunk(h, 2, dim=-1)
+         r_embs, r_inv_embs = torch.chunk(r, 2, dim=-1)
+         ht_embs, tt_embs = torch.chunk(t, 2, dim=-1)
+
+         scores1 = torch.sum(hh_embs * r_embs * tt_embs, -1)
+         scores2 = torch.sum(ht_embs * r_inv_embs * th_embs, -1)
+
+         # Without clipping, we run into NaN problems.
+         # Based on the original authors' implementation.
+         return torch.clamp((scores1 + scores2) / 2, -20, 20)
+
+     @override
+     def predict(
+         self,
+         data: dict[str, typing.Union[torch.Tensor, str]],
+         mode) -> torch.Tensor:
+
+         """Inference method of SimplE.
+
+         :param data: the data
+         :type data: dict[str, typing.Union[torch.Tensor, str]]
+         :returns: scores of the triples
+         :rtype: torch.Tensor
+         """
+
+         triples = data["positive_sample"]
+         head_emb, relation_emb, tail_emb = self.tri2emb(triples, mode=mode)
+         score = self._calc(head_emb, relation_emb, tail_emb)
+         return score
+
+     def regularization(
+         self,
+         data: dict[str, typing.Union[torch.Tensor, str]]) -> torch.Tensor:
+
+         """L2 regularization (also known as weight decay), used in the loss function.
+
+         :param data: the data
+         :type data: dict[str, typing.Union[torch.Tensor, str]]
+         :returns: regularization loss of the model parameters
+         :rtype: torch.Tensor
+         """
+
+         pos_sample = data["positive_sample"]
+         neg_sample = data["negative_sample"]
+         mode = data["mode"]
+         pos_head_emb, pos_relation_emb, pos_tail_emb = self.tri2emb(pos_sample)
+         if mode == "bern":
+             neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(neg_sample)
+         else:
+             neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(pos_sample, neg_sample, mode)
+
+         pos_regul = (torch.mean(pos_head_emb ** 2) +
+                      torch.mean(pos_relation_emb ** 2) +
+                      torch.mean(pos_tail_emb ** 2)) / 3
+
+         neg_regul = (torch.mean(neg_head_emb ** 2) +
+                      torch.mean(neg_relation_emb ** 2) +
+                      torch.mean(neg_tail_emb ** 2)) / 3
+
+         regul = (pos_regul + neg_regul) / 2
+
+         return regul
+
+ def get_simple_hpo_config() -> dict[str, dict[str, typing.Any]]:
+
+     """Return the default hyperparameter optimization configuration for :py:class:`SimplE`.
+
+     The default configuration is::
+
+         parameters_dict = {
+             'model': {
+                 'value': 'SimplE'
+             },
+             'dim': {
+                 'values': [50, 100, 200]
+             }
+         }
+
+     :returns: the default hyperparameter optimization configuration for :py:class:`SimplE`
+     :rtype: dict[str, dict[str, typing.Any]]
+     """
+
+     parameters_dict = {
+         'model': {
+             'value': 'SimplE'
+         },
+         'dim': {
+             'values': [50, 100, 200]
+         }
+     }
+
+     return parameters_dict
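
The _calc method above is a direct implementation of the SimplE scoring function from the class docstring: each entity embedding of width 2 * dim is split into a head-role half and a tail-role half, each relation embedding into a forward half and an inverse half, and the score is the clamped average of the two multi-linear dot products. The following is a minimal, self-contained sketch of that computation on random tensors; it uses plain PyTorch only, and the tensor names and sizes are illustrative rather than taken from the package:

    import torch

    dim = 4
    # toy batch of 3 triples; each vector packs two halves of size dim
    h = torch.randn(3, 2 * dim)  # head entity embeddings
    r = torch.randn(3, 2 * dim)  # relation embeddings
    t = torch.randn(3, 2 * dim)  # tail entity embeddings

    hh_embs, th_embs = torch.chunk(h, 2, dim=-1)    # head entity in head / tail role
    r_embs, r_inv_embs = torch.chunk(r, 2, dim=-1)  # forward / inverse relation
    ht_embs, tt_embs = torch.chunk(t, 2, dim=-1)    # tail entity in head / tail role

    scores1 = torch.sum(hh_embs * r_embs * tt_embs, -1)      # <h_i, v_r, t_j>
    scores2 = torch.sum(ht_embs * r_inv_embs * th_embs, -1)  # <h_j, v_{r^-1}, t_i>

    # average both directions and clamp, as _calc does to avoid NaNs
    score = torch.clamp((scores1 + scores2) / 2, -20, 20)
    print(score.shape)  # torch.Size([3])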
unike/module/model/TransD.py
@@ -0,0 +1,458 @@
+ # coding:utf-8
+ #
+ # unike/module/model/TransD.py
+ #
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 25, 2023
+ #
+ # This module defines TransD.
+
+ """
+ TransD - generates the projection matrices automatically; simple and efficient, an improvement over TransR.
+ """
+
+ import torch
+ import typing
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from .Model import Model
+ from typing_extensions import override
+
+ class TransD(Model):
+
+     """
+     ``TransD`` :cite:`TransD`, proposed in 2015, generates the projection matrices automatically; it is simple, efficient, and an improvement over TransR.
+
+     The scoring function is:
+
+     .. math::
+
+         \parallel (\mathbf{r}_p \mathbf{h}_p^T + \mathbf{I})\mathbf{h} + \mathbf{r} - (\mathbf{r}_p \mathbf{t}_p^T + \mathbf{I})\mathbf{t} \parallel_{L_1/L_2}
+
+     Lower scores are better for positive triples; see :ref:`TransD <transd>` for more details.
+
+     Example::
+
+         from unike.utils import WandbLogger
+         from unike.data import KGEDataLoader, BernSampler, TradTestSampler
+         from unike.module.model import TransD
+         from unike.module.loss import MarginLoss
+         from unike.module.strategy import NegativeSampling
+         from unike.config import Trainer, Tester
+
+         wandb_logger = WandbLogger(
+             project="pybind11-ke",
+             name="TransD-FB15K237",
+             config=dict(
+                 in_path = '../../benchmarks/FB15K237/',
+                 batch_size = 2048,
+                 neg_ent = 25,
+                 test = True,
+                 test_batch_size = 10,
+                 num_workers = 16,
+                 dim_e = 200,
+                 dim_r = 200,
+                 p_norm = 1,
+                 norm_flag = True,
+                 margin = 4.0,
+                 use_gpu = True,
+                 device = 'cuda:1',
+                 epochs = 1000,
+                 lr = 1.0,
+                 opt_method = 'sgd',
+                 valid_interval = 100,
+                 log_interval = 100,
+                 save_interval = 100,
+                 save_path = '../../checkpoint/transd.pth'
+             )
+         )
+
+         config = wandb_logger.config
+
+         # dataloader for training
+         dataloader = KGEDataLoader(
+             in_path = config.in_path,
+             batch_size = config.batch_size,
+             neg_ent = config.neg_ent,
+             test = config.test,
+             test_batch_size = config.test_batch_size,
+             num_workers = config.num_workers,
+             train_sampler = BernSampler,
+             test_sampler = TradTestSampler
+         )
+
+         # define the model
+         transd = TransD(
+             ent_tol = dataloader.get_ent_tol(),
+             rel_tol = dataloader.get_rel_tol(),
+             dim_e = config.dim_e,
+             dim_r = config.dim_r,
+             p_norm = config.p_norm,
+             norm_flag = config.norm_flag)
+
+         # define the loss function
+         model = NegativeSampling(
+             model = transd,
+             loss = MarginLoss(margin = config.margin)
+         )
+
+         # test the model
+         tester = Tester(model = transd, data_loader = dataloader, use_gpu = config.use_gpu, device = config.device)
+
+         # train the model
+         trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(), epochs = config.epochs,
+             lr = config.lr, opt_method = config.opt_method, use_gpu = config.use_gpu, device = config.device,
+             tester = tester, test = config.test, valid_interval = config.valid_interval,
+             log_interval = config.log_interval, save_interval = config.save_interval,
+             save_path = config.save_path, use_wandb = True)
+         trainer.run()
+
+         # close your wandb run
+         wandb_logger.finish()
+     """
+
+     def __init__(
+         self,
+         ent_tol: int,
+         rel_tol: int,
+         dim_e: int = 100,
+         dim_r: int = 100,
+         p_norm: int = 1,
+         norm_flag: bool = True,
+         margin: float | None = None):
+
+         """Create a TransD object.
+
+         :param ent_tol: number of entities
+         :type ent_tol: int
+         :param rel_tol: number of relations
+         :type rel_tol: int
+         :param dim_e: dimension of the entity embeddings and entity projection vectors
+         :type dim_e: int
+         :param dim_r: dimension of the relation embeddings and relation projection vectors
+         :type dim_r: int
+         :param p_norm: distance function of the scoring function; following the original paper, it can be 1 or 2.
+         :type p_norm: int
+         :param norm_flag: whether to apply an L2-norm to the last dimension of the entity and relation embeddings with :py:func:`torch.nn.functional.normalize`.
+         :type norm_flag: bool
+         :param margin: required when using the loss function of ``RotatE`` :cite:`RotatE`, :py:class:`unike.module.loss.SigmoidLoss`; it turns the positive-triple scores of ``TransE`` :cite:`TransE` from lower-is-better into higher-is-better. See :ref:`RotatE <rotate>` for more details.
+         :type margin: float
+         """
+
+         super(TransD, self).__init__(ent_tol, rel_tol)
+
+         #: dimension of the entity embeddings and entity projection vectors
+         self.dim_e: int = dim_e
+         #: dimension of the relation embeddings and relation projection vectors
+         self.dim_r: int = dim_r
+         #: distance function of the scoring function; following the original paper, it can be 1 or 2.
+         self.p_norm: int = p_norm
+         #: whether to apply an L2-norm to the last dimension of the entity and
+         #: relation embeddings with :py:func:`torch.nn.functional.normalize`.
+         self.norm_flag: bool = norm_flag
+
+         #: entity embeddings, created according to the number of entities
+         self.ent_embeddings: torch.nn.Embedding = nn.Embedding(self.ent_tol, self.dim_e)
+         #: relation embeddings, created according to the number of relations
+         self.rel_embeddings: torch.nn.Embedding = nn.Embedding(self.rel_tol, self.dim_r)
+         #: entity projection vectors, created according to the number of entities
+         self.ent_transfer: torch.nn.Embedding = nn.Embedding(self.ent_tol, self.dim_e)
+         #: relation projection vectors, created according to the number of relations
+         self.rel_transfer: torch.nn.Embedding = nn.Embedding(self.rel_tol, self.dim_r)
+
+         if margin != None:
+             #: required when using the loss function of ``RotatE`` :cite:`RotatE`, :py:class:`unike.module.loss.SigmoidLoss`; it turns the positive-triple scores of ``TransE`` :cite:`TransE` from lower-is-better into higher-is-better. See :ref:`RotatE <rotate>` for more details.
+             self.margin: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor([margin]))
+             self.margin.requires_grad = False
+             self.margin_flag: bool = True
+         else:
+             self.margin_flag: bool = False
+
+         nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
+         nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
+         nn.init.xavier_uniform_(self.ent_transfer.weight.data)
+         nn.init.xavier_uniform_(self.rel_transfer.weight.data)
+
+     @override
+     def forward(
+         self,
+         triples: torch.Tensor,
+         negs: torch.Tensor = None,
+         mode: str = 'single') -> torch.Tensor:
+
+         """
+         Defines the computation performed at every call.
+         Subclasses of :py:class:`torch.nn.Module` must override :py:meth:`torch.nn.Module.forward`.
+
+         :param triples: positive triples
+         :type triples: torch.Tensor
+         :param negs: negative sample entities
+         :type negs: torch.Tensor
+         :param mode: scoring mode
+         :type mode: str
+         :returns: scores of the triples
+         :rtype: torch.Tensor
+         """
+
+         h, r, t = self.tri2emb(triples, negs, mode)
+         h_transfer, r_transfer, t_transfer = self.tri2transfer(triples, negs, mode)
+         h = self._transfer(h, h_transfer, r_transfer)
+         t = self._transfer(t, t_transfer, r_transfer)
+         score = self._calc(h, r, t)
+         if self.margin_flag:
+             return self.margin - score
+         else:
+             return score
+
+     def tri2transfer(
+         self,
+         triples: torch.Tensor,
+         negs: torch.Tensor = None,
+         mode: str = 'single') -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+         """
+         Return the projection (transfer) vectors corresponding to the triples.
+
+         :param triples: positive triples
+         :type triples: torch.Tensor
+         :param negs: negative sample entities
+         :type negs: torch.Tensor
+         :param mode: scoring mode
+         :type mode: str
+         :returns: projection vectors of the head entities, relations and tail entities
+         :rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+         """
+
+         if mode == "single":
+             head_emb = self.ent_transfer(triples[:, 0]).unsqueeze(1)
+             relation_emb = self.rel_transfer(triples[:, 1]).unsqueeze(1)
+             tail_emb = self.ent_transfer(triples[:, 2]).unsqueeze(1)
+
+         elif mode == "head-batch" or mode == "head_predict":
+             if negs is None:
+                 head_emb = self.ent_transfer.weight.data.unsqueeze(0)
+             else:
+                 head_emb = self.ent_transfer(negs)
+
+             relation_emb = self.rel_transfer(triples[:, 1]).unsqueeze(1)
+             tail_emb = self.ent_transfer(triples[:, 2]).unsqueeze(1)
+
+         elif mode == "tail-batch" or mode == "tail_predict":
+             head_emb = self.ent_transfer(triples[:, 0]).unsqueeze(1)
+             relation_emb = self.rel_transfer(triples[:, 1]).unsqueeze(1)
+
+             if negs is None:
+                 tail_emb = self.ent_transfer.weight.data.unsqueeze(0)
+             else:
+                 tail_emb = self.ent_transfer(negs)
+
+         return head_emb, relation_emb, tail_emb
+
+     def _transfer(
+         self,
+         e: torch.Tensor,
+         e_transfer: torch.Tensor,
+         r_transfer: torch.Tensor) -> torch.Tensor:
+
+         """
+         Project the head or tail entity vectors into the relation vector space.
+
+         :param e: head or tail entity vectors
+         :type e: torch.Tensor
+         :param e_transfer: projection vectors of the head or tail entities
+         :type e_transfer: torch.Tensor
+         :param r_transfer: projection vectors of the relations
+         :type r_transfer: torch.Tensor
+         :returns: the projected entity vectors
+         :rtype: torch.Tensor
+         """
+
+         return F.normalize(
+             self._resize(e, len(e.size())-1, r_transfer.size()[-1]) + torch.sum(e * e_transfer, -1, True) * r_transfer,
+             p = 2,
+             dim = -1
+         )
+
+     def _resize(
+         self,
+         tensor: torch.Tensor,
+         axis: int,
+         size: int) -> torch.Tensor:
+
+         """Compute the product of the entity vector with the identity matrix and return the resulting vector.
+
+         The source code uses :py:func:`torch.narrow` to truncate the vector
+         and :py:func:`torch.nn.functional.pad` to pad it.
+
+         :param tensor: entity vectors
+         :type tensor: torch.Tensor
+         :param axis: the axis along which the multiplication is performed
+         :type axis: int
+         :param size: size of the result along ``axis``, usually the dimension of the relation vectors
+         :type size: int
+         :returns: the resulting vector of the multiplication
+         :rtype: torch.Tensor
+         """
+
+         shape = tensor.size()
+         osize = shape[axis]
+         if osize == size:
+             return tensor
+         if (osize > size):
+             return torch.narrow(tensor, axis, 0, size)
+         paddings = []
+         for i in range(len(shape)):
+             if i == axis:
+                 paddings = [0, size - osize] + paddings
+             else:
+                 paddings = [0, 0] + paddings
+         return F.pad(tensor, paddings, mode = "constant", value = 0)
+
+     def _calc(
+         self,
+         h: torch.Tensor,
+         r: torch.Tensor,
+         t: torch.Tensor) -> torch.Tensor:
+
+         """Compute the TransD scoring function.
+
+         :param h: head entity embeddings
+         :type h: torch.Tensor
+         :param r: relation embeddings
+         :type r: torch.Tensor
+         :param t: tail entity embeddings
+         :type t: torch.Tensor
+         :returns: scores of the triples
+         :rtype: torch.Tensor
+         """
+
+         # normalize the last dimension of the embeddings
+         if self.norm_flag:
+             h = F.normalize(h, 2, -1)
+             r = F.normalize(r, 2, -1)
+             t = F.normalize(t, 2, -1)
+
+         score = (h + r) - t
+
+         # compute the score with the distance function
+         score = torch.norm(score, self.p_norm, -1)
+         return score
+
+     @override
+     def predict(
+         self,
+         data: dict[str, typing.Union[torch.Tensor, str]],
+         mode: str) -> torch.Tensor:
+
+         """Inference method of TransD.
+
+         :param data: the data
+         :type data: dict[str, typing.Union[torch.Tensor, str]]
+         :param mode: 'head_predict' or 'tail_predict'
+         :type mode: str
+         :returns: scores of the triples
+         :rtype: torch.Tensor
+         """
+
+         triples = data["positive_sample"]
+         h, r, t = self.tri2emb(triples, mode=mode)
+         h_transfer, r_transfer, t_transfer = self.tri2transfer(triples, mode=mode)
+         h = self._transfer(h, h_transfer, r_transfer)
+         t = self._transfer(t, t_transfer, r_transfer)
+         score = self._calc(h, r, t)
+
+         if self.margin_flag:
+             score = self.margin - score
+             return score
+         else:
+             return -score
+
+     def regularization(
+         self,
+         data: dict[str, typing.Union[torch.Tensor, str]]) -> torch.Tensor:
+
+         """L2 regularization (also known as weight decay), used in the loss function.
+
+         :param data: the data
+         :type data: dict[str, typing.Union[torch.Tensor, str]]
+         :returns: regularization loss of the model parameters
+         :rtype: torch.Tensor
+         """
+
+         pos_sample = data["positive_sample"]
+         neg_sample = data["negative_sample"]
+         mode = data["mode"]
+         pos_h, pos_r, pos_t = self.tri2emb(pos_sample)
+         pos_h_transfer, pos_r_transfer, pos_t_transfer = self.tri2transfer(pos_sample)
+         if mode == "bern":
+             neg_h, neg_r, neg_t = self.tri2emb(neg_sample)
+             neg_h_transfer, neg_r_transfer, neg_t_transfer = self.tri2transfer(neg_sample)
+         else:
+             neg_h, neg_r, neg_t = self.tri2emb(pos_sample, neg_sample, mode)
+             neg_h_transfer, neg_r_transfer, neg_t_transfer = self.tri2transfer(pos_sample, neg_sample, mode)
+
+         pos_regul = (torch.mean(pos_h ** 2) +
+                      torch.mean(pos_r ** 2) +
+                      torch.mean(pos_t ** 2) +
+                      torch.mean(pos_h_transfer ** 2) +
+                      torch.mean(pos_r_transfer ** 2) +
+                      torch.mean(pos_t_transfer ** 2)) / 6
+
+         neg_regul = (torch.mean(neg_h ** 2) +
+                      torch.mean(neg_r ** 2) +
+                      torch.mean(neg_t ** 2) +
+                      torch.mean(neg_h_transfer ** 2) +
+                      torch.mean(neg_r_transfer ** 2) +
+                      torch.mean(neg_t_transfer ** 2)) / 6
+
+         regul = (pos_regul + neg_regul) / 2
+
+         return regul
+
+ def get_transd_hpo_config() -> dict[str, dict[str, typing.Any]]:
+
+     """Return the default hyperparameter optimization configuration for :py:class:`TransD`.
+
+     The default configuration is::
+
+         parameters_dict = {
+             'model': {
+                 'value': 'TransD'
+             },
+             'dim_e': {
+                 'values': [50, 100, 200]
+             },
+             'dim_r': {
+                 'values': [50, 100, 200]
+             },
+             'p_norm': {
+                 'values': [1, 2]
+             },
+             'norm_flag': {
+                 'value': True
+             }
+         }
+
+     :returns: the default hyperparameter optimization configuration for :py:class:`TransD`
+     :rtype: dict[str, dict[str, typing.Any]]
+     """
+
+     parameters_dict = {
+         'model': {
+             'value': 'TransD'
+         },
+         'dim_e': {
+             'values': [50, 100, 200]
+         },
+         'dim_r': {
+             'values': [50, 100, 200]
+         },
+         'p_norm': {
+             'values': [1, 2]
+         },
+         'norm_flag': {
+             'value': True
+         }
+     }
+
+     return parameters_dict
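
A note on the TransD projection above: _transfer never materializes the projection matrix (r_p h_p^T + I). The identity part is handled by _resize, which truncates or zero-pads the entity vector to the relation dimension, and the rank-one part reduces to the scalar product (h . h_p) times r_p, followed by L2 normalization. The following is a minimal, self-contained sketch of this trick for a single entity vector using plain PyTorch; the variable names and dimensions are illustrative, not taken from the package:

    import torch
    import torch.nn.functional as F

    dim_e, dim_r = 6, 4
    e = torch.randn(dim_e)    # entity embedding h
    e_p = torch.randn(dim_e)  # entity projection vector h_p
    r_p = torch.randn(dim_r)  # relation projection vector r_p

    # identity part of (r_p h_p^T + I) h: resize h to dim_r
    # (truncate when dim_e > dim_r, zero-pad on the right when dim_e < dim_r)
    if dim_e >= dim_r:
        e_resized = e[:dim_r]
    else:
        e_resized = F.pad(e, [0, dim_r - dim_e])

    # rank-one part: (h . h_p) * r_p, then L2-normalize as _transfer does
    projected = F.normalize(e_resized + torch.sum(e * e_p) * r_p, p=2, dim=-1)

    # equivalent dense formulation that materializes the projection matrix
    M = torch.outer(r_p, e_p) + torch.eye(dim_r, dim_e)
    dense = F.normalize(M @ e, p=2, dim=-1)
    print(torch.allclose(projected, dense, atol=1e-6))  # True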