unike 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. unike/__init__.py +5 -0
  2. unike/config/HPOTrainer.py +305 -0
  3. unike/config/Tester.py +385 -0
  4. unike/config/Trainer.py +519 -0
  5. unike/config/TrainerAccelerator.py +39 -0
  6. unike/config/__init__.py +37 -0
  7. unike/data/BernSampler.py +168 -0
  8. unike/data/CompGCNSampler.py +140 -0
  9. unike/data/CompGCNTestSampler.py +84 -0
  10. unike/data/KGEDataLoader.py +315 -0
  11. unike/data/KGReader.py +138 -0
  12. unike/data/RGCNSampler.py +261 -0
  13. unike/data/RGCNTestSampler.py +208 -0
  14. unike/data/RevSampler.py +78 -0
  15. unike/data/TestSampler.py +189 -0
  16. unike/data/TradSampler.py +122 -0
  17. unike/data/TradTestSampler.py +87 -0
  18. unike/data/UniSampler.py +145 -0
  19. unike/data/__init__.py +47 -0
  20. unike/module/BaseModule.py +130 -0
  21. unike/module/__init__.py +20 -0
  22. unike/module/loss/CompGCNLoss.py +96 -0
  23. unike/module/loss/Loss.py +26 -0
  24. unike/module/loss/MarginLoss.py +148 -0
  25. unike/module/loss/RGCNLoss.py +117 -0
  26. unike/module/loss/SigmoidLoss.py +145 -0
  27. unike/module/loss/SoftplusLoss.py +145 -0
  28. unike/module/loss/__init__.py +35 -0
  29. unike/module/model/Analogy.py +237 -0
  30. unike/module/model/CompGCN.py +562 -0
  31. unike/module/model/ComplEx.py +235 -0
  32. unike/module/model/DistMult.py +276 -0
  33. unike/module/model/HolE.py +308 -0
  34. unike/module/model/Model.py +107 -0
  35. unike/module/model/RESCAL.py +309 -0
  36. unike/module/model/RGCN.py +304 -0
  37. unike/module/model/RotatE.py +303 -0
  38. unike/module/model/SimplE.py +237 -0
  39. unike/module/model/TransD.py +458 -0
  40. unike/module/model/TransE.py +290 -0
  41. unike/module/model/TransH.py +322 -0
  42. unike/module/model/TransR.py +402 -0
  43. unike/module/model/__init__.py +60 -0
  44. unike/module/strategy/CompGCNSampling.py +140 -0
  45. unike/module/strategy/NegativeSampling.py +138 -0
  46. unike/module/strategy/RGCNSampling.py +134 -0
  47. unike/module/strategy/Strategy.py +26 -0
  48. unike/module/strategy/__init__.py +29 -0
  49. unike/utils/EarlyStopping.py +94 -0
  50. unike/utils/Timer.py +74 -0
  51. unike/utils/WandbLogger.py +46 -0
  52. unike/utils/__init__.py +26 -0
  53. unike/utils/tools.py +118 -0
  54. unike/version.py +1 -0
  55. unike-3.0.1.dist-info/METADATA +101 -0
  56. unike-3.0.1.dist-info/RECORD +59 -0
  57. unike-3.0.1.dist-info/WHEEL +4 -0
  58. unike-3.0.1.dist-info/entry_points.txt +2 -0
  59. unike-3.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,304 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/module/model/RGCN.py
4
+ #
5
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2024
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 22, 2024
7
+ #
8
+ # 该头文件定义了 R-GCN.
9
+
10
+ """
11
+ R-GCN - 第一个图神经网络模型。
12
+ """
13
+
14
+ import dgl
15
+ import torch
16
+ import typing
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from .Model import Model
20
+ from dgl.nn.pytorch import RelGraphConv
21
+ from typing_extensions import override
22
+
23
class RGCN(Model):

    """
    ``R-GCN`` :cite:`R-GCN`, proposed in 2017, is the first graph neural
    network model for knowledge graph embedding.

    The score of a positive triple should be as large as possible. See
    :ref:`R-GCN <rgcn>` for more details.

    Example::

        from unike.data import GraphDataLoader
        from unike.module.model import RGCN
        from unike.module.loss import RGCNLoss
        from unike.module.strategy import RGCNSampling
        from unike.config import Trainer, GraphTester

        dataloader = GraphDataLoader(
            in_path = "../../benchmarks/FB15K237/",
            batch_size = 60000,
            neg_ent = 10,
            test = True,
            test_batch_size = 100,
            num_workers = 16
        )

        # define the model
        rgcn = RGCN(
            ent_tol = dataloader.get_ent_tol(),
            rel_tol = dataloader.get_rel_tol(),
            dim = 500,
            num_layers = 2
        )

        # define the loss function
        model = RGCNSampling(
            model = rgcn,
            loss = RGCNLoss(model = rgcn, regularization = 1e-5)
        )

        # test the model
        tester = GraphTester(model = rgcn, data_loader = dataloader, use_gpu = True, device = 'cuda:0')

        # train the model
        trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(),
            epochs = 10000, lr = 0.0001, use_gpu = True, device = 'cuda:0',
            tester = tester, test = True, valid_interval = 500, log_interval = 500,
            save_interval = 500, save_path = '../../checkpoint/rgcn.pth'
        )
        trainer.run()
    """

    def __init__(
        self,
        ent_tol: int,
        rel_tol: int,
        dim: int,
        num_layers: int):

        """Create an RGCN object.

        :param ent_tol: number of entities
        :type ent_tol: int
        :param rel_tol: number of relations
        :type rel_tol: int
        :param dim: dimension of entity and relation embedding vectors
        :type dim: int
        :param num_layers: number of graph neural network layers
        :type num_layers: int
        """

        super(RGCN, self).__init__(ent_tol, rel_tol)

        #: dimension of entity and relation embedding vectors
        self.dim: int = dim
        #: number of graph neural network layers
        self.num_layers: int = num_layers

        #: entity embeddings, created from the number of entities
        self.ent_emb: torch.nn.Embedding = None
        #: relation embeddings, created from the number of relations
        self.rel_emb: torch.nn.parameter.Parameter = None
        #: the stack of R-GCN graph convolution layers
        self.RGCN: torch.nn.ModuleList = None
        #: output of the GNN stack, cached for the loss computation
        #: NOTE(review): despite the annotation this holds a torch.Tensor
        #: after forward/predict — confirm against RGCNLoss usage.
        self.Loss_emb: torch.nn.Embedding = None

        self.build_model()

    def build_model(self):

        """Build the model: embeddings and the R-GCN layer stack."""

        self.ent_emb = nn.Embedding(self.ent_tol, self.dim)

        self.rel_emb = nn.Parameter(torch.Tensor(self.rel_tol, self.dim))

        nn.init.xavier_uniform_(self.rel_emb, gain=nn.init.calculate_gain('relu'))

        self.RGCN = nn.ModuleList()
        for idx in range(self.num_layers):
            RGCN_idx = self.build_hidden_layer(idx)
            self.RGCN.append(RGCN_idx)

    def build_hidden_layer(
        self,
        idx: int) -> dgl.nn.pytorch.conv.RelGraphConv:

        """Return the ``idx``-th graph neural network layer.

        :param idx: layer index
        :type idx: int
        :returns: graph neural network layer
        :rtype: dgl.nn.pytorch.conv.RelGraphConv
        """

        # ReLU on every layer except the last one
        act = F.relu if idx < self.num_layers - 1 else None
        return RelGraphConv(self.dim, self.dim, self.rel_tol, "bdd",
                    num_bases=100, activation=act, self_loop=True, dropout=0.2)

    def _encode(
        self,
        graph: dgl.DGLGraph,
        ent: torch.Tensor,
        rel: torch.Tensor,
        norm: torch.Tensor) -> torch.Tensor:

        """Run the GNN stack over the subgraph and cache its output.

        Shared by :py:meth:`forward` and :py:meth:`predict`.

        :param graph: subgraph
        :type graph: dgl.DGLGraph
        :param ent: entities of the subgraph
        :type ent: torch.Tensor
        :param rel: relations of the subgraph
        :type rel: torch.Tensor
        :param norm: normalization coefficients of the relations
        :type norm: torch.Tensor
        :returns: updated entity embeddings
        :rtype: torch.Tensor
        """

        embedding = self.ent_emb(ent.squeeze())
        for layer in self.RGCN:
            embedding = layer(graph, embedding, rel, norm)
        # cached so the loss can regularize the GNN output
        self.Loss_emb = embedding
        return embedding

    @override
    def forward(
        self,
        graph: dgl.DGLGraph,
        ent: torch.Tensor,
        rel: torch.Tensor,
        norm: torch.Tensor,
        triples: torch.Tensor,
        mode: str = 'single') -> torch.Tensor:

        """
        Defines the computation performed at every call.
        Subclasses of :py:class:`torch.nn.Module` must override
        :py:meth:`torch.nn.Module.forward`.

        :param graph: subgraph
        :type graph: dgl.DGLGraph
        :param ent: entities of the subgraph
        :type ent: torch.Tensor
        :param rel: relations of the subgraph
        :type rel: torch.Tensor
        :param norm: normalization coefficients of the relations
        :type norm: torch.Tensor
        :param triples: triples
        :type triples: torch.Tensor
        :param mode: mode
        :type mode: str
        :returns: scores of the triples
        :rtype: torch.Tensor
        """

        embedding = self._encode(graph, ent, rel, norm)
        head_emb, rela_emb, tail_emb = self.tri2emb(embedding, triples, mode)
        score = self.distmult_score_func(head_emb, rela_emb, tail_emb, mode)

        return score

    def tri2emb(
        self,
        embedding: torch.Tensor,
        triples: torch.Tensor,
        mode: str = "single") -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

        """
        Obtain the embedding vectors of head entities, relations and tail
        entities for the given triples.

        :param embedding: entity embeddings updated by the GNN
        :type embedding: torch.Tensor
        :param triples: training triples
        :type triples: torch.Tensor
        :param mode: mode
        :type mode: str
        :returns: head-entity, relation and tail-entity embedding vectors
        :rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        """

        rela_emb = self.rel_emb[triples[:, 1]].unsqueeze(1)  # [bs, 1, dim]
        head_emb = embedding[triples[:, 0]].unsqueeze(1)     # [bs, 1, dim]
        tail_emb = embedding[triples[:, 2]].unsqueeze(1)     # [bs, 1, dim]

        # when predicting heads (or tails), score against every entity
        if mode == "head-batch" or mode == "head_predict":
            head_emb = embedding.unsqueeze(0)  # [1, num_ent, dim]

        elif mode == "tail-batch" or mode == "tail_predict":
            tail_emb = embedding.unsqueeze(0)  # [1, num_ent, dim]

        return head_emb, rela_emb, tail_emb

    def distmult_score_func(
        self,
        head_emb: torch.Tensor,
        relation_emb: torch.Tensor,
        tail_emb: torch.Tensor,
        mode: str) -> torch.Tensor:

        """
        Compute the DistMult scoring function.

        :param head_emb: head-entity embedding vectors
        :type head_emb: torch.Tensor
        :param relation_emb: relation embedding vectors
        :type relation_emb: torch.Tensor
        :param tail_emb: tail-entity embedding vectors
        :type tail_emb: torch.Tensor
        :param mode: mode; controls the multiplication order so the
            broadcasted product stays cheap in head-batch mode
        :type mode: str
        :returns: scores of the triples
        :rtype: torch.Tensor
        """

        if mode == 'head-batch':
            score = head_emb * (relation_emb * tail_emb)
        else:
            score = (head_emb * relation_emb) * tail_emb

        score = score.sum(dim = -1)
        return score

    @override
    def predict(
        self,
        data: dict[str, torch.Tensor],
        mode: str) -> torch.Tensor:

        """Inference method of R-GCN.

        :param data: data
        :type data: dict[str, torch.Tensor]
        :param mode: mode
        :type mode: str
        :returns: scores of the triples
        :rtype: torch.Tensor
        """

        triples = data['positive_sample']
        graph = data['graph']
        ent = data['entity']
        rel = data['rela']
        norm = data['norm']

        embedding = self._encode(graph, ent, rel, norm)
        head_emb, rela_emb, tail_emb = self.tri2emb(embedding, triples, mode)
        score = self.distmult_score_func(head_emb, rela_emb, tail_emb, mode)

        return score
269
+
270
def get_rgcn_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyper-parameter optimization configuration of
    :py:class:`RGCN`.

    The default configuration is::

        parameters_dict = {
            'model': {
                'value': 'RGCN'
            },
            'dim': {
                'values': [200, 300, 400]
            },
            'num_layers': {
                'value': 2
            }
        }

    :returns: default hyper-parameter optimization configuration of :py:class:`RGCN`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    return {
        'model': {'value': 'RGCN'},
        'dim': {'values': [200, 300, 400]},
        'num_layers': {'value': 2},
    }
@@ -0,0 +1,303 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/module/model/RotatE.py
4
+ #
5
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 11, 2023
7
+ #
8
+ # 该头文件定义了 RotatE.
9
+
10
+ """
11
+ RotatE - 将实体表示成复数向量,关系建模为复数向量空间的旋转。
12
+ """
13
+
14
+ import torch
15
+ import typing
16
+ import torch.nn as nn
17
+ from .Model import Model
18
+ from typing_extensions import override
19
+
20
class RotatE(Model):

    r"""
    ``RotatE`` :cite:`RotatE`, proposed in 2019, represents entities as
    complex vectors and models relations as rotations in complex vector
    space.

    The scoring function is:

    .. math::

        \gamma - \parallel \mathbf{h} \circ \mathbf{r} - \mathbf{t} \parallel_{L_2}

    where :math:`\circ` denotes the Hadamard product. The score of a positive
    triple should be as large as possible. See :ref:`RotatE <rotate>` for
    more details.

    Example::

        from unike.data import KGEDataLoader, UniSampler, TradTestSampler
        from unike.module.model import RotatE
        from unike.module.loss import SigmoidLoss
        from unike.module.strategy import NegativeSampling
        from unike.config import Trainer, Tester

        # dataloader for training
        dataloader = KGEDataLoader(
            in_path = '../../benchmarks/WN18RR/',
            batch_size = 2000,
            neg_ent = 64,
            test = True,
            test_batch_size = 10,
            num_workers = 16,
            train_sampler = UniSampler,
            test_sampler = TradTestSampler
        )

        # define the model
        rotate = RotatE(
            ent_tol = dataloader.get_ent_tol(),
            rel_tol = dataloader.get_rel_tol(),
            dim = 1024,
            margin = 6.0,
            epsilon = 2.0,
        )

        # define the loss function
        model = NegativeSampling(
            model = rotate,
            loss = SigmoidLoss(adv_temperature = 2),
            regul_rate = 0.0,
        )

        # test the model
        tester = Tester(model = rotate, data_loader = dataloader, use_gpu = True, device = 'cuda:1')

        # train the model
        trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(), epochs = 6000,
            lr = 2e-5, opt_method = 'adam', use_gpu = True, device = 'cuda:1',
            tester = tester, test = True, valid_interval = 100,
            log_interval = 100, save_interval = 100,
            save_path = '../../checkpoint/rotate.pth', use_wandb = False)
        trainer.run()
    """

    def __init__(
        self,
        ent_tol: int,
        rel_tol: int,
        dim: int = 100,
        margin: float = 6.0,
        epsilon: float = 2.0):

        """Create a RotatE object.

        :param ent_tol: number of entities
        :type ent_tol: int
        :param rel_tol: number of relations
        :type rel_tol: int
        :param dim: dimension of entity and relation embedding vectors
        :type dim: int
        :param margin: the gamma of the loss function in the original paper
        :type margin: float
        :param epsilon: fixed to 2.0 in the source code of the original
            RotatE paper
        :type epsilon: float
        """

        super(RotatE, self).__init__(ent_tol, rel_tol)

        #: fixed to 2.0 in the source code of the original RotatE paper
        # (annotation fixed: the default 2.0 is a float, not an int)
        self.epsilon: float = epsilon

        #: the original RotatE implementation doubles the entity embedding
        #: dimension to ``dim * 2``,
        #: because entity embeddings are split into real and imaginary parts.
        self.dim_e: int = dim * 2
        #: dimension of relation embedding vectors, equal to ``dim``
        self.dim_r: int = dim

        #: entity embeddings, created from the number of entities
        self.ent_embeddings: torch.nn.Embedding = nn.Embedding(self.ent_tol, self.dim_e)
        #: relation embeddings, created from the number of relations
        self.rel_embeddings: torch.nn.Embedding = nn.Embedding(self.rel_tol, self.dim_r)

        # initialization range for entity embeddings, kept as a frozen
        # parameter so it is saved with the model
        self.ent_embedding_range = nn.Parameter(
            torch.Tensor([(margin + self.epsilon) / self.dim_e]),
            requires_grad=False
        )

        nn.init.uniform_(
            tensor = self.ent_embeddings.weight.data,
            a=-self.ent_embedding_range.item(),
            b=self.ent_embedding_range.item()
        )

        # initialization range for relation embeddings; also used in _calc
        # to map relation values to phases
        self.rel_embedding_range = nn.Parameter(
            torch.Tensor([(margin + self.epsilon) / self.dim_r]),
            requires_grad=False
        )

        nn.init.uniform_(
            tensor = self.rel_embeddings.weight.data,
            a=-self.rel_embedding_range.item(),
            b=self.rel_embedding_range.item()
        )

        #: the gamma of the loss function in the original paper
        self.margin: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor([margin]))
        self.margin.requires_grad = False

    @override
    def forward(
        self,
        triples: torch.Tensor,
        negs: torch.Tensor = None,
        mode: str = 'single') -> torch.Tensor:

        """
        Defines the computation performed at every call.
        Subclasses of :py:class:`torch.nn.Module` must override
        :py:meth:`torch.nn.Module.forward`.

        :param triples: positive triples
        :type triples: torch.Tensor
        :param negs: negative triple candidates
        :type negs: torch.Tensor
        :param mode: mode
        :type mode: str
        :returns: scores of the triples
        :rtype: torch.Tensor
        """

        head_emb, relation_emb, tail_emb = self.tri2emb(triples, negs, mode)
        score = self.margin - self._calc(head_emb, relation_emb, tail_emb)
        return score

    def _calc(
        self,
        h: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor) -> torch.Tensor:

        """Compute the RotatE scoring function.

        :py:func:`torch.chunk` splits the entity embeddings into real and
        imaginary parts. The original paper uses the L1-norm as the distance
        function, while this implementation uses the L2-norm.

        :param h: head-entity vectors
        :type h: torch.Tensor
        :param r: relation vectors
        :type r: torch.Tensor
        :param t: tail-entity vectors
        :type t: torch.Tensor
        :returns: distance term of the scores (before subtracting from margin)
        :rtype: torch.Tensor
        """

        pi = self.pi_const

        re_head, im_head = torch.chunk(h, 2, dim=-1)
        re_tail, im_tail = torch.chunk(t, 2, dim=-1)

        # make phases of relations uniformly distributed in [-pi, pi]
        phase_relation = r / (self.rel_embedding_range.item() / pi)

        re_relation = torch.cos(phase_relation)
        im_relation = torch.sin(phase_relation)

        # complex multiplication h * r, then subtract t
        re_score = re_head * re_relation - im_head * im_relation
        im_score = re_head * im_relation + im_head * re_relation
        re_score = re_score - re_tail
        im_score = im_score - im_tail

        # L2-norm over (real, imaginary), summed over the embedding dim
        score = torch.stack([re_score, im_score], dim = 0)
        score = score.norm(dim = 0).sum(dim = -1)
        return score

    @override
    def predict(
        self,
        data: dict[str, typing.Union[torch.Tensor,str]],
        mode) -> torch.Tensor:

        """Inference method of RotatE.

        :param data: data
        :type data: dict[str, typing.Union[torch.Tensor,str]]
        :param mode: mode
        :type mode: str
        :returns: scores of the triples
        :rtype: torch.Tensor
        """

        triples = data["positive_sample"]
        head_emb, relation_emb, tail_emb = self.tri2emb(triples, mode=mode)
        score = self.margin - self._calc(head_emb, relation_emb, tail_emb)
        return score

    def regularization(
        self,
        data: dict[str, typing.Union[torch.Tensor, str]]) -> torch.Tensor:

        """L2 regularization (weight decay), used in the loss function.

        :param data: data
        :type data: dict[str, typing.Union[torch.Tensor, str]]
        :returns: regularization loss of the model parameters
        :rtype: torch.Tensor
        """

        pos_sample = data["positive_sample"]
        neg_sample = data["negative_sample"]
        mode = data["mode"]
        pos_head_emb, pos_relation_emb, pos_tail_emb = self.tri2emb(pos_sample)
        if mode == "bern":
            # Bernoulli sampling yields full corrupted triples
            neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(neg_sample)
        else:
            neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(pos_sample, neg_sample, mode)

        pos_regul = (torch.mean(pos_head_emb ** 2) +
                     torch.mean(pos_relation_emb ** 2) +
                     torch.mean(pos_tail_emb ** 2)) / 3

        neg_regul = (torch.mean(neg_head_emb ** 2) +
                     torch.mean(neg_relation_emb ** 2) +
                     torch.mean(neg_tail_emb ** 2)) / 3

        regul = (pos_regul + neg_regul) / 2

        return regul
262
+
263
def get_rotate_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyper-parameter optimization configuration of
    :py:class:`RotatE`.

    The default configuration is::

        parameters_dict = {
            'model': {
                'value': 'RotatE'
            },
            'dim': {
                'values': [256, 512, 1024]
            },
            'margin': {
                'values': [1.0, 3.0, 6.0]
            },
            'epsilon': {
                'value': 2.0
            }
        }

    :returns: default hyper-parameter optimization configuration of :py:class:`RotatE`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    return {
        'model': {'value': 'RotatE'},
        'dim': {'values': [256, 512, 1024]},
        'margin': {'values': [1.0, 3.0, 6.0]},
        'epsilon': {'value': 2.0},
    }