unike-3.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. unike/__init__.py +5 -0
  2. unike/config/HPOTrainer.py +305 -0
  3. unike/config/Tester.py +385 -0
  4. unike/config/Trainer.py +519 -0
  5. unike/config/TrainerAccelerator.py +39 -0
  6. unike/config/__init__.py +37 -0
  7. unike/data/BernSampler.py +168 -0
  8. unike/data/CompGCNSampler.py +140 -0
  9. unike/data/CompGCNTestSampler.py +84 -0
  10. unike/data/KGEDataLoader.py +315 -0
  11. unike/data/KGReader.py +138 -0
  12. unike/data/RGCNSampler.py +261 -0
  13. unike/data/RGCNTestSampler.py +208 -0
  14. unike/data/RevSampler.py +78 -0
  15. unike/data/TestSampler.py +189 -0
  16. unike/data/TradSampler.py +122 -0
  17. unike/data/TradTestSampler.py +87 -0
  18. unike/data/UniSampler.py +145 -0
  19. unike/data/__init__.py +47 -0
  20. unike/module/BaseModule.py +130 -0
  21. unike/module/__init__.py +20 -0
  22. unike/module/loss/CompGCNLoss.py +96 -0
  23. unike/module/loss/Loss.py +26 -0
  24. unike/module/loss/MarginLoss.py +148 -0
  25. unike/module/loss/RGCNLoss.py +117 -0
  26. unike/module/loss/SigmoidLoss.py +145 -0
  27. unike/module/loss/SoftplusLoss.py +145 -0
  28. unike/module/loss/__init__.py +35 -0
  29. unike/module/model/Analogy.py +237 -0
  30. unike/module/model/CompGCN.py +562 -0
  31. unike/module/model/ComplEx.py +235 -0
  32. unike/module/model/DistMult.py +276 -0
  33. unike/module/model/HolE.py +308 -0
  34. unike/module/model/Model.py +107 -0
  35. unike/module/model/RESCAL.py +309 -0
  36. unike/module/model/RGCN.py +304 -0
  37. unike/module/model/RotatE.py +303 -0
  38. unike/module/model/SimplE.py +237 -0
  39. unike/module/model/TransD.py +458 -0
  40. unike/module/model/TransE.py +290 -0
  41. unike/module/model/TransH.py +322 -0
  42. unike/module/model/TransR.py +402 -0
  43. unike/module/model/__init__.py +60 -0
  44. unike/module/strategy/CompGCNSampling.py +140 -0
  45. unike/module/strategy/NegativeSampling.py +138 -0
  46. unike/module/strategy/RGCNSampling.py +134 -0
  47. unike/module/strategy/Strategy.py +26 -0
  48. unike/module/strategy/__init__.py +29 -0
  49. unike/utils/EarlyStopping.py +94 -0
  50. unike/utils/Timer.py +74 -0
  51. unike/utils/WandbLogger.py +46 -0
  52. unike/utils/__init__.py +26 -0
  53. unike/utils/tools.py +118 -0
  54. unike/version.py +1 -0
  55. unike-3.0.1.dist-info/METADATA +101 -0
  56. unike-3.0.1.dist-info/RECORD +59 -0
  57. unike-3.0.1.dist-info/WHEEL +4 -0
  58. unike-3.0.1.dist-info/entry_points.txt +2 -0
  59. unike-3.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,308 @@
+ # coding:utf-8
+ #
+ # unike/module/model/HolE.py
+ #
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 31, 2024
+ #
+ # This module defines HolE.
+
+ """
+ HolE - knowledge graph embedding via circular correlation; a compressed version of RESCAL, and therefore easy to apply to large knowledge graphs.
+ """
+
+ import torch
+ import typing
+ import torch.nn as nn
+ from .Model import Model
+ from typing_extensions import override
+
+ class HolE(Model):
+
+     """
+     ``HolE`` :cite:`HolE`, proposed in 2016, embeds knowledge graphs with circular correlation. It is a compressed version of RESCAL and therefore scales easily to large knowledge graphs.
+
+     The scoring function is:
+
+     .. math::
+
+         \mathbf{r}^T (\mathcal{F}^{-1}(\overline{\mathcal{F}(\mathbf{h})} \odot \mathcal{F}(\mathbf{t})))
+
+     where :math:`\mathcal{F}(\cdot)` and :math:`\mathcal{F}^{-1}(\cdot)` denote the fast Fourier transform and its inverse, :math:`\overline{\mathbf{x}}` denotes the complex conjugate, and :math:`\odot` denotes the Hadamard product.
+
+     Higher scores are better for positive triples and lower scores are better for negative triples; see :ref:`HolE <hole>` for more details.
+
+     Example::
+
+         from unike.utils import WandbLogger
+         from unike.data import KGEDataLoader, BernSampler, TradTestSampler
+         from unike.module.model import HolE
+         from unike.module.loss import SoftplusLoss
+         from unike.module.strategy import NegativeSampling
+         from unike.config import Trainer, Tester
+
+         wandb_logger = WandbLogger(
+             project="pybind11-ke",
+             name="HolE-WN18RR",
+             config=dict(
+                 in_path = '../../benchmarks/WN18RR/',
+                 batch_size = 8192,
+                 neg_ent = 25,
+                 test = True,
+                 test_batch_size = 256,
+                 num_workers = 16,
+                 dim = 100,
+                 regul_rate = 1.0,
+                 use_gpu = True,
+                 device = 'cuda:0',
+                 epochs = 1000,
+                 lr = 0.5,
+                 opt_method = 'adagrad',
+                 valid_interval = 100,
+                 log_interval = 100,
+                 save_interval = 100,
+                 save_path = '../../checkpoint/hole.pth'
+             )
+         )
+
+         config = wandb_logger.config
+
+         # dataloader for training
+         dataloader = KGEDataLoader(
+             in_path = config.in_path,
+             batch_size = config.batch_size,
+             neg_ent = config.neg_ent,
+             test = config.test,
+             test_batch_size = config.test_batch_size,
+             num_workers = config.num_workers,
+             train_sampler = BernSampler,
+             test_sampler = TradTestSampler
+         )
+
+         # define the model
+         hole = HolE(
+             ent_tol = dataloader.get_ent_tol(),
+             rel_tol = dataloader.get_rel_tol(),
+             dim = config.dim
+         )
+
+         # define the loss function
+         model = NegativeSampling(
+             model = hole,
+             loss = SoftplusLoss(),
+             regul_rate = config.regul_rate
+         )
+
+         # test the model
+         tester = Tester(model = hole, data_loader = dataloader, use_gpu = config.use_gpu, device = config.device)
+
+         # train the model
+         trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(), epochs = config.epochs,
+             lr = config.lr, opt_method = config.opt_method, use_gpu = config.use_gpu, device = config.device,
+             tester = tester, test = config.test, valid_interval = config.valid_interval,
+             log_interval = config.log_interval, save_interval = config.save_interval,
+             save_path = config.save_path, use_wandb = True)
+         trainer.run()
+
+         # close your wandb run
+         wandb_logger.finish()
+     """
+
+     def __init__(
+         self,
+         ent_tol: int,
+         rel_tol: int,
+         dim: int = 100):
+
+         """Create a HolE object.
+
+         :param ent_tol: number of entities
+         :type ent_tol: int
+         :param rel_tol: number of relations
+         :type rel_tol: int
+         :param dim: dimension of the entity and relation embedding vectors
+         :type dim: int
+         """
+
+         super(HolE, self).__init__(ent_tol, rel_tol)
+
+         #: dimension of the entity and relation embedding vectors
+         self.dim: int = dim
+
+         #: entity embeddings, created from the number of entities
+         self.ent_embeddings: torch.nn.Embedding = nn.Embedding(self.ent_tol, self.dim)
+         #: relation embeddings, created from the number of relations
+         self.rel_embeddings: torch.nn.Embedding = nn.Embedding(self.rel_tol, self.dim)
+
+         nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
+         nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
+
+     @override
+     def forward(
+         self,
+         triples: torch.Tensor,
+         negs: torch.Tensor = None,
+         mode: str = 'single') -> torch.Tensor:
+
+         """
+         Defines the computation performed at every call.
+         Subclasses of :py:class:`torch.nn.Module` must override :py:meth:`torch.nn.Module.forward`.
+
+         :param triples: positive triples
+         :type triples: torch.Tensor
+         :param negs: negative-sample entities
+         :type negs: torch.Tensor
+         :param mode: mode
+         :type mode: str
+         :returns: scores of the triples
+         :rtype: torch.Tensor
+         """
+
+         head_emb, relation_emb, tail_emb = self.tri2emb(triples, negs, mode)
+         score = self._calc(head_emb, relation_emb, tail_emb)
+         return score
+
+     def _calc(
+         self,
+         h: torch.Tensor,
+         r: torch.Tensor,
+         t: torch.Tensor) -> torch.Tensor:
+
+         """Compute the HolE scoring function.
+
+         :param h: head entity embeddings.
+         :type h: torch.Tensor
+         :param r: relation embeddings.
+         :type r: torch.Tensor
+         :param t: tail entity embeddings.
+         :type t: torch.Tensor
+         :returns: scores of the triples
+         :rtype: torch.Tensor
+         """
+
+         score = self._ccorr(h, t) * r
+         score = torch.sum(score, -1)
+         return score
+
+     def _ccorr(
+         self,
+         a: torch.Tensor,
+         b: torch.Tensor) -> torch.Tensor:
+
+         """Compute the circular correlation :math:`\mathcal{F}^{-1}(\overline{\mathcal{F}(\mathbf{h})} \odot \mathcal{F}(\mathbf{t}))`.
+
+         :py:func:`torch.fft.rfft` computes the real-to-complex discrete Fourier transform, :py:func:`torch.fft.irfft` is its inverse,
+         and :py:func:`torch.conj` computes the complex conjugate.
+
+         :param a: head entity embeddings.
+         :type a: torch.Tensor
+         :param b: tail entity embeddings.
+         :type b: torch.Tensor
+         :returns: result of the circular correlation.
+         :rtype: torch.Tensor
+         """
+
+         # Fourier transforms
+         a_fft = torch.fft.rfft(a, dim=-1)
+         b_fft = torch.fft.rfft(b, dim=-1)
+
+         # complex conjugate
+         a_fft = torch.conj(a_fft)
+
+         # Hadamard product
+         p_fft = a_fft * b_fft
+
+         # inverse Fourier transform
+         return torch.fft.irfft(p_fft, n=a.shape[-1], dim=-1)
+
+     @override
+     def predict(
+         self,
+         data: dict[str, typing.Union[torch.Tensor,str]],
+         mode) -> torch.Tensor:
+
+         """Inference method of HolE.
+
+         :param data: data.
+         :type data: dict[str, typing.Union[torch.Tensor,str]]
+         :param mode: mode
+         :type mode: str
+         :returns: scores of the triples
+         :rtype: torch.Tensor
+         """
+
+         triples = data["positive_sample"]
+         head_emb, relation_emb, tail_emb = self.tri2emb(triples, mode=mode)
+         score = self._calc(head_emb, relation_emb, tail_emb)
+         return score
+
+     def regularization(
+         self,
+         data: dict[str, typing.Union[torch.Tensor, str]]) -> torch.Tensor:
+
+         """L2 regularization (also known as weight decay), used in the loss function.
+
+         :param data: data.
+         :type data: dict[str, typing.Union[torch.Tensor, str]]
+         :returns: regularization loss of the model parameters
+         :rtype: torch.Tensor
+         """
+
+         pos_sample = data["positive_sample"]
+         neg_sample = data["negative_sample"]
+         mode = data["mode"]
+         pos_head_emb, pos_relation_emb, pos_tail_emb = self.tri2emb(pos_sample)
+         if mode == "bern":
+             neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(neg_sample)
+         else:
+             neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(pos_sample, neg_sample, mode)
+
+         pos_regul = (torch.mean(pos_head_emb ** 2) +
+                      torch.mean(pos_relation_emb ** 2) +
+                      torch.mean(pos_tail_emb ** 2)) / 3
+
+         neg_regul = (torch.mean(neg_head_emb ** 2) +
+                      torch.mean(neg_relation_emb ** 2) +
+                      torch.mean(neg_tail_emb ** 2)) / 3
+
+         regul = (pos_regul + neg_regul) / 2
+
+         return regul
+
+     def l3_regularization(self) -> torch.Tensor:
+
+         """L3 regularization, used in the loss function.
+
+         :returns: regularization loss of the model parameters
+         :rtype: torch.Tensor
+         """
+
+         return (self.ent_embeddings.weight.norm(p = 3)**3 + self.rel_embeddings.weight.norm(p = 3)**3)
+
+ def get_hole_hpo_config() -> dict[str, dict[str, typing.Any]]:
+
+     """Return the default hyperparameter optimization configuration of :py:class:`HolE`.
+
+     The default configuration is::
+
+         parameters_dict = {
+             'model': {
+                 'value': 'HolE'
+             },
+             'dim': {
+                 'values': [50, 100, 200]
+             }
+         }
+
+     :returns: default hyperparameter optimization configuration of :py:class:`HolE`
+     :rtype: dict[str, dict[str, typing.Any]]
+     """
+
+     parameters_dict = {
+         'model': {
+             'value': 'HolE'
+         },
+         'dim': {
+             'values': [50, 100, 200]
+         }
+     }
+
+     return parameters_dict
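
The `_ccorr` helper above is the core of HolE: it computes circular correlation through real FFTs instead of an explicit O(d^2) sum. The following standalone sketch is not part of the package; the helper names `ccorr_fft` and `ccorr_naive` and the toy sizes are ours. It checks the FFT formulation against the direct definition and then assembles the HolE score sum(ccorr(h, t) * r), mirroring `HolE._calc`::

    import torch

    def ccorr_fft(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # FFT-based circular correlation, mirroring HolE._ccorr:
        # F^{-1}( conj(F(a)) * F(b) )
        a_fft = torch.conj(torch.fft.rfft(a, dim=-1))
        b_fft = torch.fft.rfft(b, dim=-1)
        return torch.fft.irfft(a_fft * b_fft, n=a.shape[-1], dim=-1)

    def ccorr_naive(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Direct definition: [a ccorr b]_k = sum_i a_i * b_{(i + k) mod d}
        d = a.shape[-1]
        return torch.stack(
            [(a * torch.roll(b, shifts=-k, dims=-1)).sum(-1) for k in range(d)],
            dim=-1)

    torch.manual_seed(0)
    h, r, t = torch.randn(3, 4, 100)      # toy head/relation/tail embeddings, batch 4, dim 100
    assert torch.allclose(ccorr_fft(h, t), ccorr_naive(h, t), atol=1e-4)

    # HolE score, as in HolE._calc: <r, ccorr(h, t)>
    score = torch.sum(ccorr_fft(h, t) * r, dim=-1)
    print(score.shape)                    # torch.Size([4])

The FFT route is what makes HolE a "compressed RESCAL": the score costs O(d log d) per triple instead of the O(d^2) matrix product used by RESCAL below.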
@@ -0,0 +1,107 @@
+ # coding:utf-8
+ #
+ # unike/module/model/Model.py
+ #
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 25, 2023
+ #
+ # This module defines Model.
+
+ """Model - the base class of all KGE models."""
+
+ import torch
+ from ..BaseModule import BaseModule
+
+ class Model(BaseModule):
+
+     """
+     Inherits from :py:class:`unike.module.BaseModule` and only adds two attributes: :py:attr:`ent_tol` and :py:attr:`rel_tol`.
+     """
+
+     def __init__(
+         self,
+         ent_tol: int,
+         rel_tol: int):
+
+         """Create a Model object.
+
+         :param ent_tol: number of entities
+         :type ent_tol: int
+         :param rel_tol: number of relations
+         :type rel_tol: int
+         """
+
+         super(Model, self).__init__()
+
+         #: number of entities
+         self.ent_tol: int = ent_tol
+         #: number of relations
+         self.rel_tol: int = rel_tol
+
+     def forward(self) -> torch.Tensor:
+
+         """
+         Defines the computation performed at every call. Not implemented here; subclasses must override this method, otherwise :py:class:`NotImplementedError` is raised.
+
+         Subclasses of :py:class:`torch.nn.Module` must override :py:meth:`torch.nn.Module.forward`.
+
+         :returns: scores of the triples
+         :rtype: torch.Tensor
+         """
+
+         raise NotImplementedError
+
+     def predict(self) -> torch.Tensor:
+
+         """
+         Inference method of a KGE model. Not implemented here; subclasses must override this method, otherwise :py:class:`NotImplementedError` is raised.
+
+         :returns: scores of the triples
+         :rtype: torch.Tensor
+         """
+
+         raise NotImplementedError
+
+     def tri2emb(
+         self,
+         triples: torch.Tensor,
+         negs: torch.Tensor = None,
+         mode: str = 'single') -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+         """
+         Return the embedding vectors corresponding to the triples.
+
+         :param triples: positive triples
+         :type triples: torch.Tensor
+         :param negs: negative-sample entities
+         :type negs: torch.Tensor
+         :param mode: mode
+         :type mode: str
+         :returns: embedding vectors of the head entities, relations and tail entities
+         :rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+         """
+
+         if mode == "single":
+             head_emb = self.ent_embeddings(triples[:, 0]).unsqueeze(1)
+             relation_emb = self.rel_embeddings(triples[:, 1]).unsqueeze(1)
+             tail_emb = self.ent_embeddings(triples[:, 2]).unsqueeze(1)
+
+         elif mode == "head-batch" or mode == "head_predict":
+             if negs is None:
+                 head_emb = self.ent_embeddings.weight.data.unsqueeze(0)
+             else:
+                 head_emb = self.ent_embeddings(negs)
+
+             relation_emb = self.rel_embeddings(triples[:, 1]).unsqueeze(1)
+             tail_emb = self.ent_embeddings(triples[:, 2]).unsqueeze(1)
+
+         elif mode == "tail-batch" or mode == "tail_predict":
+             head_emb = self.ent_embeddings(triples[:, 0]).unsqueeze(1)
+             relation_emb = self.rel_embeddings(triples[:, 1]).unsqueeze(1)
+
+             if negs is None:
+                 tail_emb = self.ent_embeddings.weight.data.unsqueeze(0)
+             else:
+                 tail_emb = self.ent_embeddings(negs)
+
+         return head_emb, relation_emb, tail_emb
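
`tri2emb` is the shape contract the concrete models in this diff build on: in "single" mode every factor has shape (batch, 1, dim), while in the "*_predict" modes with `negs=None` the full entity table is exposed as (1, num_ent, dim), so a model's `_calc` broadcasts to one score per candidate entity. The sketch below is a self-contained illustration with toy sizes and our own variable names; the DistMult-style product is used only to keep the example short and is not how HolE or RESCAL score triples::

    import torch
    import torch.nn as nn

    # Toy setup mirroring Model.tri2emb: 5 entities, 2 relations, dim 4.
    num_ent, num_rel, dim = 5, 2, 4
    ent_embeddings = nn.Embedding(num_ent, dim)
    rel_embeddings = nn.Embedding(num_rel, dim)

    triples = torch.tensor([[0, 1, 2],
                            [3, 0, 4]])       # (batch, 3) = (h, r, t) indices

    # "single" mode: each factor gets a singleton middle axis -> (batch, 1, dim)
    head = ent_embeddings(triples[:, 0]).unsqueeze(1)
    rel  = rel_embeddings(triples[:, 1]).unsqueeze(1)
    tail = ent_embeddings(triples[:, 2]).unsqueeze(1)
    print(head.shape, rel.shape, tail.shape)  # torch.Size([2, 1, 4]) each

    # "tail_predict" mode with negs=None: score the batch against every entity.
    all_tails = ent_embeddings.weight.data.unsqueeze(0)   # (1, num_ent, dim)
    scores = torch.sum(head * rel * all_tails, dim=-1)    # broadcasts to (batch, num_ent)
    print(scores.shape)                                    # torch.Size([2, 5])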
@@ -0,0 +1,309 @@
+ # coding:utf-8
+ #
+ # unike/module/model/RESCAL.py
+ #
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 31, 2023
+ #
+ # This module defines RESCAL.
+
+ """
+ RESCAL - a tensor factorization model.
+ """
+
+ import torch
+ import typing
+ import torch.nn as nn
+ from .Model import Model
+ from typing_extensions import override
+
+ class RESCAL(Model):
+
+     """
+     ``RESCAL`` :cite:`RESCAL`, proposed in 2011, is the foundation of many tensor factorization models and is comparatively complex.
+
+     The scoring function is:
+
+     .. math::
+
+         -\mathbf{h}^T \mathbf{M}_r \mathbf{t}
+
+     Lower scores are better for positive triples; see :ref:`RESCAL <rescal>` for more details.
+
+     Example::
+
+         from unike.utils import WandbLogger
+         from unike.data import KGEDataLoader, BernSampler, TradTestSampler
+         from unike.module.model import RESCAL
+         from unike.module.loss import MarginLoss
+         from unike.module.strategy import NegativeSampling
+         from unike.config import Trainer, Tester
+
+         wandb_logger = WandbLogger(
+             project="pybind11-ke",
+             name="RESCAL-FB15K237",
+             config=dict(
+                 in_path = '../../benchmarks/FB15K237/',
+                 batch_size = 2048,
+                 neg_ent = 25,
+                 test = True,
+                 test_batch_size = 10,
+                 num_workers = 16,
+                 dim = 50,
+                 margin = 1.0,
+                 use_gpu = True,
+                 device = 'cuda:0',
+                 epochs = 1000,
+                 lr = 0.1,
+                 opt_method = 'adagrad',
+                 valid_interval = 100,
+                 log_interval = 100,
+                 save_interval = 100,
+                 save_path = '../../checkpoint/rescal.pth'
+             )
+         )
+
+         config = wandb_logger.config
+
+         # dataloader for training
+         dataloader = KGEDataLoader(
+             in_path = config.in_path,
+             batch_size = config.batch_size,
+             neg_ent = config.neg_ent,
+             test = config.test,
+             test_batch_size = config.test_batch_size,
+             num_workers = config.num_workers,
+             train_sampler = BernSampler,
+             test_sampler = TradTestSampler
+         )
+
+         # define the model
+         rescal = RESCAL(
+             ent_tol = dataloader.get_ent_tol(),
+             rel_tol = dataloader.get_rel_tol(),
+             dim = config.dim
+         )
+
+         # define the loss function
+         model = NegativeSampling(
+             model = rescal,
+             loss = MarginLoss(margin = config.margin)
+         )
+
+         # test the model
+         tester = Tester(model = rescal, data_loader = dataloader, use_gpu = config.use_gpu, device = config.device)
+
+         # train the model
+         trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(), epochs = config.epochs,
+             lr = config.lr, opt_method = config.opt_method, use_gpu = config.use_gpu, device = config.device,
+             tester = tester, test = config.test, valid_interval = config.valid_interval,
+             log_interval = config.log_interval, save_interval = config.save_interval,
+             save_path = config.save_path, use_wandb = True)
+         trainer.run()
+     """
+
+     def __init__(
+         self,
+         ent_tol: int,
+         rel_tol: int,
+         dim: int = 100):
+
+         """Create a RESCAL object.
+
+         :param ent_tol: number of entities
+         :type ent_tol: int
+         :param rel_tol: number of relations
+         :type rel_tol: int
+         :param dim: dimension of the entity embeddings and relation matrices
+         :type dim: int
+         """
+
+         super(RESCAL, self).__init__(ent_tol, rel_tol)
+
+         #: dimension of the entity embeddings and relation matrices
+         self.dim: int = dim
+         #: entity embeddings, created from the number of entities
+         self.ent_embeddings: torch.nn.Embedding = nn.Embedding(self.ent_tol, self.dim)
+         #: relation matrices, created from the number of relations
+         self.rel_matrices: torch.nn.Embedding = nn.Embedding(self.rel_tol, self.dim * self.dim)
+
+         nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
+         nn.init.xavier_uniform_(self.rel_matrices.weight.data)
+
+     @override
+     def forward(
+         self,
+         triples: torch.Tensor,
+         negs: torch.Tensor = None,
+         mode: str = 'single') -> torch.Tensor:
+
+         """
+         Defines the computation performed at every call.
+         Subclasses of :py:class:`torch.nn.Module` must override :py:meth:`torch.nn.Module.forward`.
+
+         :param triples: positive triples
+         :type triples: torch.Tensor
+         :param negs: negative-sample entities
+         :type negs: torch.Tensor
+         :param mode: mode
+         :type mode: str
+         :returns: scores of the triples
+         :rtype: torch.Tensor
+         """
+
+         head_emb, tail_emb = self.tri2emb(triples, negs, mode)
+         rel_matric = self.rel_matrices(triples[:, 1])
+         score = self._calc(head_emb, rel_matric, tail_emb)
+         return score
+
+     @override
+     def tri2emb(
+         self,
+         triples: torch.Tensor,
+         negs: torch.Tensor = None,
+         mode: str = 'single') -> tuple[torch.Tensor, torch.Tensor]:
+
+         """
+         Return the embedding vectors corresponding to the triples.
+
+         :param triples: positive triples
+         :type triples: torch.Tensor
+         :param negs: negative-sample entities
+         :type negs: torch.Tensor
+         :param mode: mode
+         :type mode: str
+         :returns: embedding vectors of the head and tail entities
+         :rtype: tuple[torch.Tensor, torch.Tensor]
+         """
+
+         if mode == "single":
+             head_emb = self.ent_embeddings(triples[:, 0]).unsqueeze(1)
+             tail_emb = self.ent_embeddings(triples[:, 2]).unsqueeze(1)
+
+         elif mode == "head-batch" or mode == "head_predict":
+             if negs is None:
+                 head_emb = self.ent_embeddings.weight.data.unsqueeze(0)
+             else:
+                 head_emb = self.ent_embeddings(negs)
+
+             tail_emb = self.ent_embeddings(triples[:, 2]).unsqueeze(1)
+
+         elif mode == "tail-batch" or mode == "tail_predict":
+             head_emb = self.ent_embeddings(triples[:, 0]).unsqueeze(1)
+
+             if negs is None:
+                 tail_emb = self.ent_embeddings.weight.data.unsqueeze(0)
+             else:
+                 tail_emb = self.ent_embeddings(negs)
+
+         return head_emb, tail_emb
+
+     def _calc(
+         self,
+         h: torch.Tensor,
+         r: torch.Tensor,
+         t: torch.Tensor) -> torch.Tensor:
+
+         """Compute the RESCAL scoring function.
+
+         :param h: head entity embeddings.
+         :type h: torch.Tensor
+         :param r: relation matrices.
+         :type r: torch.Tensor
+         :param t: tail entity embeddings.
+         :type t: torch.Tensor
+         :returns: scores of the triples
+         :rtype: torch.Tensor
+         """
+
+         r = r.view(-1, self.dim, self.dim)
+         r = r.unsqueeze(dim=1)
+         h = h.unsqueeze(dim=-2)
+         hr = torch.matmul(h, r)
+         hr = hr.squeeze(dim=-2)
+         return -torch.sum(hr * t, -1)
+
+     @override
+     def predict(
+         self,
+         data: dict[str, typing.Union[torch.Tensor,str]],
+         mode) -> torch.Tensor:
+
+         """Inference method of RESCAL.
+
+         :param data: data.
+         :type data: dict[str, typing.Union[torch.Tensor,str]]
+         :param mode: mode
+         :type mode: str
+         :returns: scores of the triples
+         :rtype: torch.Tensor
+         """
+
+         triples = data["positive_sample"]
+         head_emb, tail_emb = self.tri2emb(triples, mode=mode)
+         rel_matric = self.rel_matrices(triples[:, 1])
+         score = self._calc(head_emb, rel_matric, tail_emb)
+         return -score
+
+     def regularization(
+         self,
+         data: dict[str, typing.Union[torch.Tensor, str]]) -> torch.Tensor:
+
+         """L2 regularization (also known as weight decay), used in the loss function.
+
+         :param data: data.
+         :type data: dict[str, typing.Union[torch.Tensor, str]]
+         :returns: regularization loss of the model parameters
+         :rtype: torch.Tensor
+         """
+
+         pos_sample = data["positive_sample"]
+         neg_sample = data["negative_sample"]
+         mode = data["mode"]
+         pos_head_emb, pos_tail_emb = self.tri2emb(pos_sample)
+         pos_rel_transfer = self.rel_matrices(pos_sample[:, 1])
+         if mode == "bern":
+             neg_head_emb, neg_tail_emb = self.tri2emb(neg_sample)
+         else:
+             neg_head_emb, neg_tail_emb = self.tri2emb(pos_sample, neg_sample, mode)
+         neg_rel_transfer = self.rel_matrices(pos_sample[:, 1])
+
+         pos_regul = (torch.mean(pos_head_emb ** 2) +
+                      torch.mean(pos_tail_emb ** 2) +
+                      torch.mean(pos_rel_transfer ** 2)) / 3
+
+         neg_regul = (torch.mean(neg_head_emb ** 2) +
+                      torch.mean(neg_tail_emb ** 2) +
+                      torch.mean(neg_rel_transfer ** 2)) / 3
+
+         regul = (pos_regul + neg_regul) / 2
+
+         return regul
+
+ def get_rescal_hpo_config() -> dict[str, dict[str, typing.Any]]:
+
+     """Return the default hyperparameter optimization configuration of :py:class:`RESCAL`.
+
+     The default configuration is::
+
+         parameters_dict = {
+             'model': {
+                 'value': 'RESCAL'
+             },
+             'dim': {
+                 'values': [50, 100, 200]
+             }
+         }
+
+     :returns: default hyperparameter optimization configuration of :py:class:`RESCAL`
+     :rtype: dict[str, dict[str, typing.Any]]
+     """
+
+     parameters_dict = {
+         'model': {
+             'value': 'RESCAL'
+         },
+         'dim': {
+             'values': [50, 100, 200]
+         }
+     }
+
+     return parameters_dict
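
RESCAL._calc above scores a triple as -h^T M_r t, reshaping the flattened relation embedding from `rel_matrices` into a dim x dim matrix before the contraction. The standalone sketch below is not part of the package; the toy shapes and variable names are ours. It mirrors that reshaping on the "single"-mode shapes returned by `tri2emb` and cross-checks the result against the textbook matrix product::

    import torch

    batch, dim = 2, 3
    h = torch.randn(batch, 1, dim)          # head embeddings, as returned by tri2emb in "single" mode
    t = torch.randn(batch, 1, dim)          # tail embeddings
    M = torch.randn(batch, dim * dim)       # flattened relation matrices, as from rel_matrices(triples[:, 1])

    # Mirror RESCAL._calc: reshape to (batch, 1, dim, dim) and contract h^T M_r t.
    r = M.view(-1, dim, dim).unsqueeze(1)                 # (batch, 1, dim, dim)
    hr = torch.matmul(h.unsqueeze(-2), r).squeeze(-2)     # (batch, 1, dim)
    score = -torch.sum(hr * t, dim=-1)                    # (batch, 1); lower is better for positive triples

    # Cross-check against the textbook form -h^T M_r t for the first triple.
    manual = -(h[0, 0] @ M[0].view(dim, dim) @ t[0, 0])
    assert torch.allclose(score[0, 0], manual, atol=1e-5)

Because each relation carries a full dim x dim matrix, RESCAL's per-triple cost and parameter count grow quadratically in dim, which is the overhead HolE's circular correlation avoids.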