unike 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. unike/__init__.py +5 -0
  2. unike/config/HPOTrainer.py +305 -0
  3. unike/config/Tester.py +385 -0
  4. unike/config/Trainer.py +519 -0
  5. unike/config/TrainerAccelerator.py +39 -0
  6. unike/config/__init__.py +37 -0
  7. unike/data/BernSampler.py +168 -0
  8. unike/data/CompGCNSampler.py +140 -0
  9. unike/data/CompGCNTestSampler.py +84 -0
  10. unike/data/KGEDataLoader.py +315 -0
  11. unike/data/KGReader.py +138 -0
  12. unike/data/RGCNSampler.py +261 -0
  13. unike/data/RGCNTestSampler.py +208 -0
  14. unike/data/RevSampler.py +78 -0
  15. unike/data/TestSampler.py +189 -0
  16. unike/data/TradSampler.py +122 -0
  17. unike/data/TradTestSampler.py +87 -0
  18. unike/data/UniSampler.py +145 -0
  19. unike/data/__init__.py +47 -0
  20. unike/module/BaseModule.py +130 -0
  21. unike/module/__init__.py +20 -0
  22. unike/module/loss/CompGCNLoss.py +96 -0
  23. unike/module/loss/Loss.py +26 -0
  24. unike/module/loss/MarginLoss.py +148 -0
  25. unike/module/loss/RGCNLoss.py +117 -0
  26. unike/module/loss/SigmoidLoss.py +145 -0
  27. unike/module/loss/SoftplusLoss.py +145 -0
  28. unike/module/loss/__init__.py +35 -0
  29. unike/module/model/Analogy.py +237 -0
  30. unike/module/model/CompGCN.py +562 -0
  31. unike/module/model/ComplEx.py +235 -0
  32. unike/module/model/DistMult.py +276 -0
  33. unike/module/model/HolE.py +308 -0
  34. unike/module/model/Model.py +107 -0
  35. unike/module/model/RESCAL.py +309 -0
  36. unike/module/model/RGCN.py +304 -0
  37. unike/module/model/RotatE.py +303 -0
  38. unike/module/model/SimplE.py +237 -0
  39. unike/module/model/TransD.py +458 -0
  40. unike/module/model/TransE.py +290 -0
  41. unike/module/model/TransH.py +322 -0
  42. unike/module/model/TransR.py +402 -0
  43. unike/module/model/__init__.py +60 -0
  44. unike/module/strategy/CompGCNSampling.py +140 -0
  45. unike/module/strategy/NegativeSampling.py +138 -0
  46. unike/module/strategy/RGCNSampling.py +134 -0
  47. unike/module/strategy/Strategy.py +26 -0
  48. unike/module/strategy/__init__.py +29 -0
  49. unike/utils/EarlyStopping.py +94 -0
  50. unike/utils/Timer.py +74 -0
  51. unike/utils/WandbLogger.py +46 -0
  52. unike/utils/__init__.py +26 -0
  53. unike/utils/tools.py +118 -0
  54. unike/version.py +1 -0
  55. unike-3.0.1.dist-info/METADATA +101 -0
  56. unike-3.0.1.dist-info/RECORD +59 -0
  57. unike-3.0.1.dist-info/WHEEL +4 -0
  58. unike-3.0.1.dist-info/entry_points.txt +2 -0
  59. unike-3.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,290 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/module/model/TransE.py
4
+ #
5
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 9, 2024
7
+ #
8
+ # 该头文件定义了 TransE.
9
+
10
+ """
11
+ TransE - 第一个平移模型,简单而且高效。
12
+ """
13
+
14
+ import torch
15
+ import typing
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from .Model import Model
19
+ from typing_extensions import override
20
+
21
class TransE(Model):

    # Raw docstring: the LaTeX below contains backslash sequences such as
    # ``\parallel`` that are invalid escapes in a normal string literal
    # (DeprecationWarning, SyntaxWarning on Python 3.12+). The resulting
    # ``__doc__`` text is unchanged.
    r"""
    ``TransE`` :cite:`TransE` was proposed in 2013. It is the first
    translational model and opened the translational line of research.
    Because it is simple and efficient it is still a common baseline, and on
    some datasets it outperforms more complex models.

    The scoring function is:

    .. math::

        \parallel h + r - t \parallel_{L_1/L_2}

    Lower scores are better for positive triples. For more details see
    :ref:`TransE <transe>`.

    Example::

        from unike.data import KGEDataLoader, BernSampler, TradTestSampler
        from unike.module.model import TransE
        from unike.module.loss import MarginLoss
        from unike.module.strategy import NegativeSampling
        from unike.config import Trainer, Tester

        # dataloader for training
        dataloader = KGEDataLoader(
            in_path = "../../benchmarks/FB15K/",
            batch_size = 8192,
            neg_ent = 25,
            test = True,
            test_batch_size = 256,
            num_workers = 16,
            train_sampler = BernSampler,
            test_sampler = TradTestSampler
        )

        # define the model
        transe = TransE(
            ent_tol = dataloader.get_ent_tol(),
            rel_tol = dataloader.get_rel_tol(),
            dim = 50,
            p_norm = 1,
            norm_flag = True)

        # define the loss function
        model = NegativeSampling(
            model = transe,
            loss = MarginLoss(margin = 1.0),
            regul_rate = 0.01
        )

        # test the model
        tester = Tester(model = transe, data_loader = dataloader, use_gpu = True, device = 'cuda:1')

        # train the model
        trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(),
            epochs = 1000, lr = 0.01, use_gpu = True, device = 'cuda:1',
            tester = tester, test = True, valid_interval = 100,
            log_interval = 100, save_interval = 100,
            save_path = '../../checkpoint/transe.pth', delta = 0.01)
        trainer.run()
    """

    def __init__(
        self,
        ent_tol: int,
        rel_tol: int,
        dim: int = 100,
        p_norm: int = 1,
        norm_flag: bool = True,
        margin: float | None = None):

        """Create a TransE model.

        :param ent_tol: number of entities
        :type ent_tol: int
        :param rel_tol: number of relations
        :type rel_tol: int
        :param dim: dimension of the entity and relation embeddings
        :type dim: int
        :param p_norm: order of the norm used as the distance in the scoring
            function; following the original paper this is 1 or 2.
        :type p_norm: int
        :param norm_flag: whether to L2-normalize the last dimension of the
            entity and relation embeddings with
            :py:func:`torch.nn.functional.normalize` before scoring.
        :type norm_flag: bool
        :param margin: needed when training with the ``RotatE`` :cite:`RotatE`
            loss :py:class:`unike.module.loss.SigmoidLoss`. When given, scores
            are returned as ``margin - distance``, so that higher is better for
            positive triples instead of lower. See :ref:`RotatE <rotate>`.
        :type margin: float | None
        """

        super(TransE, self).__init__(ent_tol, rel_tol)

        #: Dimension of the entity and relation embeddings.
        self.dim: int = dim
        #: Order of the norm used as the distance in the scoring function
        #: (1 or 2 in the original paper).
        self.p_norm: int = p_norm
        #: Whether to L2-normalize the last dimension of the entity and
        #: relation embeddings with :py:func:`torch.nn.functional.normalize`.
        self.norm_flag: bool = norm_flag

        #: Entity embedding table, one row per entity.
        self.ent_embeddings: torch.nn.Embedding = nn.Embedding(self.ent_tol, self.dim)
        #: Relation embedding table, one row per relation.
        self.rel_embeddings: torch.nn.Embedding = nn.Embedding(self.rel_tol, self.dim)

        #: True when a margin was supplied; scores are then ``margin - distance``.
        self.margin_flag: bool = margin is not None
        if self.margin_flag:
            #: Margin used with :py:class:`unike.module.loss.SigmoidLoss` to
            #: turn "lower is better" distances into "higher is better" scores.
            #: Registered as a parameter but frozen (``requires_grad = False``).
            self.margin: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor([margin]))
            self.margin.requires_grad = False

        nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
        nn.init.xavier_uniform_(self.rel_embeddings.weight.data)

    @override
    def forward(
        self,
        triples: torch.Tensor,
        negs: torch.Tensor | None = None,
        mode: str = 'single') -> torch.Tensor:

        """Score a batch of triples (invoked through ``module(...)``).

        :param triples: positive triples
        :type triples: torch.Tensor
        :param negs: negative samples, or ``None``; how they are combined with
            ``triples`` is decided by ``tri2emb`` according to ``mode``
        :type negs: torch.Tensor | None
        :param mode: sampling mode forwarded to ``tri2emb``
        :type mode: str
        :returns: the distance ``||h + r - t||_p``, or ``margin - distance``
            when a margin was supplied
        :rtype: torch.Tensor
        """

        head_emb, relation_emb, tail_emb = self.tri2emb(triples, negs, mode)
        score = self._calc(head_emb, relation_emb, tail_emb)
        if self.margin_flag:
            return self.margin - score
        else:
            return score

    def _calc(
        self,
        h: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor) -> torch.Tensor:

        """Compute the TransE distance ``||h + r - t||_p``.

        :param h: head entity embeddings
        :type h: torch.Tensor
        :param r: relation embeddings
        :type r: torch.Tensor
        :param t: tail entity embeddings
        :type t: torch.Tensor
        :returns: p-norm of ``h + r - t`` over the last dimension
        :rtype: torch.Tensor
        """

        # Optionally L2-normalize the last dimension of every embedding.
        if self.norm_flag:
            h = F.normalize(h, 2, -1)
            r = F.normalize(r, 2, -1)
            t = F.normalize(t, 2, -1)

        score = (h + r) - t

        # Reduce the last dimension with the configured p-norm.
        score = torch.norm(score, self.p_norm, -1)
        return score

    @override
    def predict(
        self,
        data: dict[str, typing.Union[torch.Tensor,str]],
        mode: str) -> torch.Tensor:

        """Inference-time scoring for TransE.

        :param data: batch dictionary; only ``data["positive_sample"]`` is read
        :type data: dict[str, typing.Union[torch.Tensor,str]]
        :param mode: ``'head_predict'`` or ``'tail_predict'``
        :type mode: str
        :returns: scores where higher is better: ``margin - distance`` when a
            margin was supplied, otherwise ``-distance``
        :rtype: torch.Tensor
        """

        triples = data["positive_sample"]
        head_emb, relation_emb, tail_emb = self.tri2emb(triples, mode=mode)
        score = self._calc(head_emb, relation_emb, tail_emb)

        if self.margin_flag:
            score = self.margin - score
            return score
        else:
            # Negate so that, as in the margin case, higher means better.
            return -score

    def regularization(
        self,
        data: dict[str, typing.Union[torch.Tensor, str]]) -> torch.Tensor:

        """L2 regularization term (a.k.a. weight decay) used by the loss.

        Averages the mean squared values of the head, relation and tail
        embeddings of the positive and negative samples.

        :param data: batch dictionary with ``"positive_sample"``,
            ``"negative_sample"`` and ``"mode"`` entries
        :type data: dict[str, typing.Union[torch.Tensor,str]]
        :returns: regularization loss of the looked-up embeddings
        :rtype: torch.Tensor
        """

        pos_sample = data["positive_sample"]
        neg_sample = data["negative_sample"]
        mode = data["mode"]
        pos_head_emb, pos_relation_emb, pos_tail_emb = self.tri2emb(pos_sample)
        if mode == "bern":
            # "bern" negatives are passed to tri2emb as full triples.
            neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(neg_sample)
        else:
            # Otherwise negatives go through tri2emb as ``negs`` next to the
            # positive triples.
            neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(pos_sample, neg_sample, mode)

        pos_regul = (torch.mean(pos_head_emb ** 2) +
                     torch.mean(pos_relation_emb ** 2) +
                     torch.mean(pos_tail_emb ** 2)) / 3

        neg_regul = (torch.mean(neg_head_emb ** 2) +
                     torch.mean(neg_relation_emb ** 2) +
                     torch.mean(neg_tail_emb ** 2)) / 3

        regul = (pos_regul + neg_regul) / 2

        return regul
+
250
def get_transe_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyperparameter-optimization config for :py:class:`TransE`.

    The default configuration is::

        parameters_dict = {
            'model': {
                'value': 'TransE'
            },
            'dim': {
                'values': [50, 100, 200]
            },
            'p_norm': {
                'values': [1, 2]
            },
            'norm_flag': {
                'value': True
            }
        }

    A new dictionary (with new nested dicts and lists) is built on every
    call, so callers may modify the result freely.

    :returns: default hyperparameter-optimization config for :py:class:`TransE`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    search_space: dict[str, dict[str, typing.Any]] = {}
    # Fixed choices use 'value'; searched ranges use 'values'.
    search_space['model'] = dict(value='TransE')
    search_space['dim'] = dict(values=[50, 100, 200])
    search_space['p_norm'] = dict(values=[1, 2])
    search_space['norm_flag'] = dict(value=True)
    return search_space
@@ -0,0 +1,322 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/module/model/TransH.py
4
+ #
5
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 25, 2024
7
+ #
8
+ # 该头文件定义了 TransH.
9
+
10
+ """
11
+ TransH - 是第二个平移模型,将关系建模为超平面上的平移操作。
12
+ """
13
+
14
+ import torch
15
+ import typing
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from .Model import Model
19
+ from typing_extensions import override
20
+
21
class TransH(Model):

    # Raw docstring: the LaTeX below contains ``\Vert``, an invalid escape
    # sequence in a normal string literal (DeprecationWarning, SyntaxWarning
    # on Python 3.12+). The resulting ``__doc__`` text is unchanged.
    r"""
    ``TransH`` :cite:`TransH` was proposed in 2014. It is the second
    translational model and models each relation as a translation on a
    relation-specific hyperplane.

    The scoring function is:

    .. math::

        \Vert (h - r_w^T h \, r_w) + r_d - (t - r_w^T t \, r_w) \Vert_{L_1/L_2}

    where :math:`r_w` is the (L2-normalized) normal vector of the relation's
    hyperplane and :math:`r_d` is the relation's translation vector.

    Lower scores are better for positive triples. For more details see
    :ref:`TransH <transh>`.

    Example::

        from unike.data import KGEDataLoader, BernSampler, TradTestSampler
        from unike.module.model import TransH
        from unike.module.loss import MarginLoss
        from unike.module.strategy import NegativeSampling
        from unike.config import Trainer, Tester

        # dataloader for training
        dataloader = KGEDataLoader(
            in_path = "../../benchmarks/FB15K237/",
            batch_size = 4096,
            neg_ent = 25,
            test = True,
            test_batch_size = 30,
            num_workers = 16,
            train_sampler = BernSampler,
            test_sampler = TradTestSampler
        )

        # define the model
        transh = TransH(
            ent_tol = dataloader.get_ent_tol(),
            rel_tol = dataloader.get_rel_tol(),
            dim = 200,
            p_norm = 1,
            norm_flag = True)

        # define the loss function
        model = NegativeSampling(
            model = transh,
            loss = MarginLoss(margin = 4.0),
            # regul_rate = 0.01
        )

        # test the model
        tester = Tester(model = transh, data_loader = dataloader, use_gpu = True, device = 'cuda:1')

        # train the model
        trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(),
            epochs = 1000, lr = 0.5, use_gpu = True, device = 'cuda:1',
            tester = tester, test = True, valid_interval = 100,
            log_interval = 100, save_interval = 100, save_path = '../../checkpoint/transh.pth',
            delta = 0.01)
        trainer.run()
    """

    def __init__(
        self,
        ent_tol: int,
        rel_tol: int,
        dim: int = 100,
        p_norm: int = 1,
        norm_flag: bool = True,
        margin: float | None = None):

        """Create a TransH model.

        :param ent_tol: number of entities
        :type ent_tol: int
        :param rel_tol: number of relations
        :type rel_tol: int
        :param dim: dimension of the entity embeddings, relation embeddings
            and hyperplane normal vectors
        :type dim: int
        :param p_norm: order of the norm used as the distance in the scoring
            function; following the original paper this is 1 or 2.
        :type p_norm: int
        :param norm_flag: whether to L2-normalize the last dimension of the
            (projected) entity and relation embeddings with
            :py:func:`torch.nn.functional.normalize` before scoring.
        :type norm_flag: bool
        :param margin: needed when training with the ``RotatE`` :cite:`RotatE`
            loss :py:class:`unike.module.loss.SigmoidLoss`. When given, TransH
            scores are returned as ``margin - distance``, so that higher is
            better for positive triples instead of lower. See
            :ref:`RotatE <rotate>`.
        :type margin: float | None
        """

        super(TransH, self).__init__(ent_tol, rel_tol)

        #: Dimension of the entity embeddings, relation embeddings and
        #: hyperplane normal vectors.
        self.dim: int = dim
        #: Order of the norm used as the distance in the scoring function
        #: (1 or 2 in the original paper).
        self.p_norm: int = p_norm
        #: Whether to L2-normalize the last dimension of the entity and
        #: relation embeddings with :py:func:`torch.nn.functional.normalize`.
        self.norm_flag: bool = norm_flag

        #: Entity embedding table, one row per entity.
        self.ent_embeddings: torch.nn.Embedding = nn.Embedding(self.ent_tol, self.dim)
        #: Relation translation vectors, one row per relation.
        self.rel_embeddings: torch.nn.Embedding = nn.Embedding(self.rel_tol, self.dim)
        #: Hyperplane normal vectors, one row per relation (normalized on use).
        self.norm_vector: torch.nn.Embedding = nn.Embedding(self.rel_tol, self.dim)

        #: True when a margin was supplied; scores are then ``margin - distance``.
        self.margin_flag: bool = margin is not None
        if self.margin_flag:
            #: Margin used with :py:class:`unike.module.loss.SigmoidLoss` to
            #: turn TransH's "lower is better" distances into "higher is
            #: better" scores. Registered as a parameter but frozen
            #: (``requires_grad = False``).
            self.margin: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor([margin]))
            self.margin.requires_grad = False

        nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
        nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
        nn.init.xavier_uniform_(self.norm_vector.weight.data)

    @override
    def forward(
        self,
        triples: torch.Tensor,
        negs: torch.Tensor | None = None,
        mode: str = 'single') -> torch.Tensor:

        """Score a batch of triples (invoked through ``module(...)``).

        :param triples: positive triples
        :type triples: torch.Tensor
        :param negs: negative samples, or ``None``; how they are combined with
            ``triples`` is decided by ``tri2emb`` according to ``mode``
        :type negs: torch.Tensor | None
        :param mode: sampling mode forwarded to ``tri2emb``
        :type mode: str
        :returns: the TransH distance, or ``margin - distance`` when a margin
            was supplied
        :rtype: torch.Tensor
        """

        head_emb, relation_emb, tail_emb = self.tri2emb(triples, negs, mode)
        # One normal vector per triple, taken from the relation column.
        # unsqueeze(1) presumably lets it broadcast against the (batch, n, dim)
        # embeddings returned by tri2emb -- TODO confirm against Model.tri2emb.
        norm_vector = self.norm_vector(triples[:, 1]).unsqueeze(dim=1)
        head_emb = self._transfer(head_emb, norm_vector)
        tail_emb = self._transfer(tail_emb, norm_vector)
        score = self._calc(head_emb, relation_emb, tail_emb)

        if self.margin_flag:
            return self.margin - score
        else:
            return score

    def _transfer(
        self,
        e: torch.Tensor,
        norm: torch.Tensor) -> torch.Tensor:

        """Project head or tail entity vectors onto the relation hyperplane.

        Computes ``e - (e . w) w`` where ``w`` is ``norm`` L2-normalized over
        its last dimension.

        :param e: head or tail entity vectors
        :type e: torch.Tensor
        :param norm: (unnormalized) hyperplane normal vectors
        :type norm: torch.Tensor
        :returns: projected entity vectors
        :rtype: torch.Tensor
        """

        norm = F.normalize(norm, p = 2, dim = -1)
        return e - torch.sum(e * norm, -1, True) * norm

    def _calc(
        self,
        h: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor) -> torch.Tensor:

        """Compute the TransH distance ``||h + r - t||_p`` on projected vectors.

        :param h: projected head entity vectors
        :type h: torch.Tensor
        :param r: relation translation vectors
        :type r: torch.Tensor
        :param t: projected tail entity vectors
        :type t: torch.Tensor
        :returns: p-norm of ``h + r - t`` over the last dimension
        :rtype: torch.Tensor
        """

        # Optionally L2-normalize the last dimension of every embedding.
        if self.norm_flag:
            h = F.normalize(h, 2, -1)
            r = F.normalize(r, 2, -1)
            t = F.normalize(t, 2, -1)

        score = (h + r) - t

        # Reduce the last dimension with the configured p-norm.
        score = torch.norm(score, self.p_norm, -1)
        return score

    @override
    def predict(
        self,
        data: dict[str, typing.Union[torch.Tensor,str]],
        mode: str) -> torch.Tensor:

        """Inference-time scoring for TransH.

        :param data: batch dictionary; only ``data["positive_sample"]`` is read
        :type data: dict[str, typing.Union[torch.Tensor,str]]
        :param mode: ``'head_predict'`` or ``'tail_predict'``
        :type mode: str
        :returns: scores where higher is better: ``margin - distance`` when a
            margin was supplied, otherwise ``-distance``
        :rtype: torch.Tensor
        """

        triples = data["positive_sample"]
        head_emb, relation_emb, tail_emb = self.tri2emb(triples, mode=mode)
        norm_vector = self.norm_vector(triples[:, 1]).unsqueeze(dim=1)
        head_emb = self._transfer(head_emb, norm_vector)
        tail_emb = self._transfer(tail_emb, norm_vector)
        score = self._calc(head_emb, relation_emb, tail_emb)

        if self.margin_flag:
            score = self.margin - score
            return score
        else:
            # Negate so that, as in the margin case, higher means better.
            return -score

    def regularization(
        self,
        data: dict[str, typing.Union[torch.Tensor, str]]) -> torch.Tensor:

        """L2 regularization term (a.k.a. weight decay) used by the loss.

        Averages the mean squared values of the head, relation and tail
        embeddings and the hyperplane normal vectors of the positive and
        negative samples.

        :param data: batch dictionary with ``"positive_sample"``,
            ``"negative_sample"`` and ``"mode"`` entries
        :type data: dict[str, typing.Union[torch.Tensor, str]]
        :returns: regularization loss of the looked-up parameters
        :rtype: torch.Tensor
        """

        pos_sample = data["positive_sample"]
        neg_sample = data["negative_sample"]
        mode = data["mode"]
        pos_head_emb, pos_relation_emb, pos_tail_emb = self.tri2emb(pos_sample)
        pos_norm_vector = self.norm_vector(pos_sample[:, 1]).unsqueeze(dim=1)
        if mode == "bern":
            # "bern" negatives are passed to tri2emb as full triples, so their
            # normal vectors come from their own relation column -- the same
            # source tri2emb uses for neg_relation_emb.
            neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(neg_sample)
            neg_norm_vector = self.norm_vector(neg_sample[:, 1]).unsqueeze(dim=1)
        else:
            # Otherwise negatives go through tri2emb as ``negs`` next to the
            # positive triples, so the relations are those of pos_sample.
            neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(pos_sample, neg_sample, mode)
            neg_norm_vector = self.norm_vector(pos_sample[:, 1]).unsqueeze(dim=1)

        pos_regul = (torch.mean(pos_head_emb ** 2) +
                     torch.mean(pos_relation_emb ** 2) +
                     torch.mean(pos_tail_emb ** 2) +
                     torch.mean(pos_norm_vector ** 2)) / 4

        neg_regul = (torch.mean(neg_head_emb ** 2) +
                     torch.mean(neg_relation_emb ** 2) +
                     torch.mean(neg_tail_emb ** 2) +
                     torch.mean(neg_norm_vector ** 2)) / 4

        regul = (pos_regul + neg_regul) / 2

        return regul
281
+
282
def get_transh_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyperparameter-optimization config for :py:class:`TransH`.

    The default configuration is::

        parameters_dict = {
            'model': {
                'value': 'TransH'
            },
            'dim': {
                'values': [50, 100, 200]
            },
            'p_norm': {
                'values': [1, 2]
            },
            'norm_flag': {
                'value': True
            }
        }

    A new dictionary (with new nested dicts and lists) is built on every
    call, so callers may modify the result freely.

    :returns: default hyperparameter-optimization config for :py:class:`TransH`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    search_space: dict[str, dict[str, typing.Any]] = {}
    # Fixed choices use 'value'; searched ranges use 'values'.
    search_space['model'] = dict(value='TransH')
    search_space['dim'] = dict(values=[50, 100, 200])
    search_space['p_norm'] = dict(values=[1, 2])
    search_space['norm_flag'] = dict(value=True)
    return search_space