unike 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. unike/__init__.py +5 -0
  2. unike/config/HPOTrainer.py +305 -0
  3. unike/config/Tester.py +385 -0
  4. unike/config/Trainer.py +519 -0
  5. unike/config/TrainerAccelerator.py +39 -0
  6. unike/config/__init__.py +37 -0
  7. unike/data/BernSampler.py +168 -0
  8. unike/data/CompGCNSampler.py +140 -0
  9. unike/data/CompGCNTestSampler.py +84 -0
  10. unike/data/KGEDataLoader.py +315 -0
  11. unike/data/KGReader.py +138 -0
  12. unike/data/RGCNSampler.py +261 -0
  13. unike/data/RGCNTestSampler.py +208 -0
  14. unike/data/RevSampler.py +78 -0
  15. unike/data/TestSampler.py +189 -0
  16. unike/data/TradSampler.py +122 -0
  17. unike/data/TradTestSampler.py +87 -0
  18. unike/data/UniSampler.py +145 -0
  19. unike/data/__init__.py +47 -0
  20. unike/module/BaseModule.py +130 -0
  21. unike/module/__init__.py +20 -0
  22. unike/module/loss/CompGCNLoss.py +96 -0
  23. unike/module/loss/Loss.py +26 -0
  24. unike/module/loss/MarginLoss.py +148 -0
  25. unike/module/loss/RGCNLoss.py +117 -0
  26. unike/module/loss/SigmoidLoss.py +145 -0
  27. unike/module/loss/SoftplusLoss.py +145 -0
  28. unike/module/loss/__init__.py +35 -0
  29. unike/module/model/Analogy.py +237 -0
  30. unike/module/model/CompGCN.py +562 -0
  31. unike/module/model/ComplEx.py +235 -0
  32. unike/module/model/DistMult.py +276 -0
  33. unike/module/model/HolE.py +308 -0
  34. unike/module/model/Model.py +107 -0
  35. unike/module/model/RESCAL.py +309 -0
  36. unike/module/model/RGCN.py +304 -0
  37. unike/module/model/RotatE.py +303 -0
  38. unike/module/model/SimplE.py +237 -0
  39. unike/module/model/TransD.py +458 -0
  40. unike/module/model/TransE.py +290 -0
  41. unike/module/model/TransH.py +322 -0
  42. unike/module/model/TransR.py +402 -0
  43. unike/module/model/__init__.py +60 -0
  44. unike/module/strategy/CompGCNSampling.py +140 -0
  45. unike/module/strategy/NegativeSampling.py +138 -0
  46. unike/module/strategy/RGCNSampling.py +134 -0
  47. unike/module/strategy/Strategy.py +26 -0
  48. unike/module/strategy/__init__.py +29 -0
  49. unike/utils/EarlyStopping.py +94 -0
  50. unike/utils/Timer.py +74 -0
  51. unike/utils/WandbLogger.py +46 -0
  52. unike/utils/__init__.py +26 -0
  53. unike/utils/tools.py +118 -0
  54. unike/version.py +1 -0
  55. unike-3.0.1.dist-info/METADATA +101 -0
  56. unike-3.0.1.dist-info/RECORD +59 -0
  57. unike-3.0.1.dist-info/WHEEL +4 -0
  58. unike-3.0.1.dist-info/entry_points.txt +2 -0
  59. unike-3.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,562 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/module/model/CompGCN.py
4
+ #
5
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 19, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 26, 2024
7
+ #
8
+ # 该脚本定义了 CompGCN 类.
9
+
10
+ """
11
+ CompGCN - 这是一种在图卷积网络中整合多关系信息的新框架,它利用知识图谱嵌入技术中的各种组合操作,将实体和关系共同嵌入到图中。
12
+ """
13
+
14
+ import dgl
15
+ import torch
16
+ import typing
17
+ from torch import nn
18
+ import dgl.function as fn
19
+ from .Model import Model
20
+ import torch.nn.functional as F
21
+ from typing_extensions import override
22
+
23
class CompGCN(Model):

    """
    ``CompGCN`` :cite:`CompGCN` was published in ``2020``. It is a framework for
    incorporating multi-relational information into graph convolutional networks:
    entities and relations are embedded jointly in the graph using the composition
    operators of knowledge graph embedding techniques.

    Higher scores indicate more plausible (positive) triples. For more details see
    :ref:`CompGCN <compgcn>`.

    Example::

        from unike.module.model import CompGCN
        from unike.module.loss import CompGCNLoss
        from unike.module.strategy import CompGCNSampling
        from unike.config import Trainer, GraphTester

        # define the model
        compgcn = CompGCN(
            ent_tol = dataloader.get_ent_tol(),
            rel_tol = dataloader.get_rel_tol(),
            dim = 100
        )

        # define the loss function
        model = CompGCNSampling(
            model = compgcn,
            loss = CompGCNLoss(model = compgcn),
            ent_tol = dataloader.get_ent_tol()
        )

        # test the model
        tester = GraphTester(model = compgcn, data_loader = dataloader, use_gpu = True, device = 'cuda:0', prediction = "tail")

        # train the model
        trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(),
            epochs = 2000, lr = 0.0001, use_gpu = True, device = 'cuda:0',
            tester = tester, test = True, valid_interval = 50, log_interval = 50,
            save_interval = 50, save_path = '../../checkpoint/compgcn.pth'
        )
        trainer.run()
    """

    def __init__(
        self,
        ent_tol: int,
        rel_tol: int,
        dim: int,
        opn: str = 'mult',
        fet_drop: float = 0.2,
        hid_drop: float = 0.3,
        margin: float = 40.0,
        decoder_model: str = 'ConvE'):

        """Create a CompGCN object.

        :param ent_tol: number of entities
        :type ent_tol: int
        :param rel_tol: number of relations
        :type rel_tol: int
        :param dim: dimension of the initial entity and relation embeddings; the
            graph convolution layer outputs ``2 * dim`` features. The 'ConvE'
            decoder additionally requires ``dim`` to be a multiple of 5 and at
            least 35.
        :type dim: int
        :param opn: composition operator: 'mult', 'sub' or 'corr'
        :type opn: str
        :param fet_drop: 'ConvE' decoder only: dropout on the convolutional features
        :type fet_drop: float
        :param hid_drop: 'ConvE' decoder only: dropout on the hidden layer
        :type hid_drop: float
        :param margin: 'TransE' decoder only: the margin gamma
        :type margin: float
        :param decoder_model: scoring function used as the decoder: 'ConvE',
            'DistMult' or 'TransE' (case-insensitive; validated when scoring)
        :type decoder_model: str
        """

        super().__init__(ent_tol, rel_tol)

        #: dimension of the initial entity and relation embeddings
        self.dim: int = dim
        #: composition operator: 'mult', 'sub', 'corr'
        self.opn: str = opn
        #: scoring function used as the decoder: 'ConvE', 'DistMult', 'TransE'
        self.decoder_model: str = decoder_model

        # NOTE: the creation order below is deliberate and must not change: it fixes
        # the order of random initialisation and of entries in the state dict.

        #------------------------------CompGCN--------------------------------------------------------------------
        #: entity embeddings, one row per entity
        self.ent_emb: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor(self.ent_tol, self.dim))
        #: relation embeddings, one row per relation
        self.rel_emb: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor(self.rel_tol, self.dim))

        nn.init.xavier_normal_(self.ent_emb, gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_normal_(self.rel_emb, gain=nn.init.calculate_gain('relu'))

        # NOTE: ``bias = 'False'`` is a non-empty string and therefore truthy, so
        # CompGCNCov *does* create a bias parameter. It is kept as-is so existing
        # checkpoints (which contain ``GraphCov.bias``) keep loading. Inside
        # CompGCNCov that bias is zero-initialised and immediately followed by
        # BatchNorm, which re-centres every feature, so its effect is small.
        #: the CompGCN graph convolution layer (``dim`` -> ``2 * dim`` features)
        self.GraphCov: CompGCNCov = CompGCNCov(self.dim, self.dim * 2, torch.tanh, bias = 'False', drop_rate = 0.1, opn = self.opn)
        #: dropout applied to the entity output of :py:attr:`GraphCov`
        self.drop: torch.nn.Dropout = nn.Dropout(0.3)
        #: per-entity bias added to the final scores ('ConvE' and 'DistMult' decoders)
        self.bias: torch.nn.parameter.Parameter = nn.Parameter(torch.zeros(self.ent_tol))
        #-----------------------------ConvE-----------------------------------------------------------------------
        #: 'ConvE' decoder: BatchNorm over the stacked head/relation "image"
        self.bn0: torch.nn.BatchNorm2d = torch.nn.BatchNorm2d(1)
        #: 'ConvE' decoder: 7x7 convolution producing 200 feature maps
        self.conv1: torch.nn.Conv2d = torch.nn.Conv2d(1, 200, (7, 7), 1, 0, bias=False)
        #: 'ConvE' decoder: BatchNorm over the convolutional features
        self.bn1: torch.nn.BatchNorm2d = torch.nn.BatchNorm2d(200)
        #: 'ConvE' decoder: dropout over the convolutional features
        self.fet_drop: torch.nn.Dropout2d = torch.nn.Dropout2d(fet_drop)
        # The stacked ConvE input is a (4 * dim // 20) x 20 image (see concat); a
        # 7x7 kernel with stride 1 and no padding shrinks each side by 6. The height
        # is clamped at 0 so that small ``dim`` values (for which ConvE is unusable
        # anyway) still allow the 'DistMult' / 'TransE' decoders to be constructed.
        flat_sz_h = max(4 * self.dim // 20 - 7 + 1, 0)
        flat_sz_w = 20 - 7 + 1
        flat_sz = flat_sz_h * flat_sz_w * 200
        #: 'ConvE' decoder: hidden (fully connected) layer projecting back to ``2 * dim``
        self.fc: torch.nn.Linear = torch.nn.Linear(flat_sz, self.dim*2)
        #: 'ConvE' decoder: dropout on the hidden layer
        self.hid_drop: torch.nn.Dropout = torch.nn.Dropout(hid_drop)
        #: 'ConvE' decoder: BatchNorm on the hidden layer
        self.bn2: torch.nn.BatchNorm1d = torch.nn.BatchNorm1d(self.dim*2)
        #-----------------------------TransE-----------------------------------------------------------------------
        #: 'TransE' decoder: fixed (non-trainable) margin gamma
        self.margin: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor([margin]), requires_grad=False)

    @override
    def forward(
        self,
        graph: dgl.DGLGraph,
        relation: torch.Tensor,
        norm: torch.Tensor,
        triples: torch.Tensor) -> torch.Tensor:

        """
        Define the computation performed at every call.
        Subclasses of :py:class:`torch.nn.Module` must override :py:meth:`torch.nn.Module.forward`.

        Runs one CompGCN layer over ``graph`` and scores every (head, relation)
        pair of ``triples`` against all entities with the configured decoder.

        :param graph: the (sub)graph; it is expected to have one node per entity,
            because :py:attr:`ent_emb` is used as its node features
        :type graph: dgl.DGLGraph
        :param relation: relation id of every edge in ``graph``
        :type relation: torch.Tensor
        :param norm: normalisation coefficient of every edge in ``graph``
        :type norm: torch.Tensor
        :param triples: 2-D tensor of triples; column 0 is the head id and column 1 the relation id
        :type triples: torch.Tensor
        :returns: sigmoid scores of shape ``(len(triples), number of entities)``
        :rtype: torch.Tensor
        :raises ValueError: if :py:attr:`decoder_model` is not 'ConvE', 'DistMult' or 'TransE'
        """

        head, rela = triples[:, 0], triples[:, 1]
        # One graph convolution updates both the entity and the relation embeddings.
        x, r = self.GraphCov(graph, self.ent_emb, self.rel_emb, relation, norm)
        x = self.drop(x)
        head_emb = torch.index_select(x, 0, head)
        rela_emb = torch.index_select(r, 0, rela)
        return self._decode(head_emb, rela_emb, x)

    def _decode(
        self,
        head_emb: torch.Tensor,
        rela_emb: torch.Tensor,
        all_ent: torch.Tensor) -> torch.Tensor:

        """Dispatch to the decoder selected by :py:attr:`decoder_model` (case-insensitive).

        :param head_emb: head entity embeddings, one row per triple
        :type head_emb: torch.Tensor
        :param rela_emb: relation embeddings, one row per triple
        :type rela_emb: torch.Tensor
        :param all_ent: embeddings of all entities
        :type all_ent: torch.Tensor
        :returns: scores of every triple against every entity
        :rtype: torch.Tensor
        :raises ValueError: for an unknown decoder name
        """

        decoder = self.decoder_model.lower()
        if decoder == 'conve':
            return self.conve(head_emb, rela_emb, all_ent)
        if decoder == 'distmult':
            return self.distmult(head_emb, rela_emb, all_ent)
        if decoder == 'transe':
            return self.transe(head_emb, rela_emb, all_ent)
        raise ValueError("please choose decoder (TransE/DistMult/ConvE)")

    def conve(
        self,
        sub_emb: torch.Tensor,
        rel_emb: torch.Tensor,
        all_ent: torch.Tensor) -> torch.Tensor:

        """Compute triple scores with ConvE as the decoder.

        :param sub_emb: head entity embeddings, one row per triple
        :type sub_emb: torch.Tensor
        :param rel_emb: relation embeddings, one row per triple
        :type rel_emb: torch.Tensor
        :param all_ent: embeddings of all entities
        :type all_ent: torch.Tensor
        :returns: sigmoid scores of shape ``(len(sub_emb), len(all_ent))``
        :rtype: torch.Tensor
        :raises ValueError: if :py:attr:`dim` is incompatible with the ConvE layout (see :py:meth:`concat`)"""

        stack_input = self.concat(sub_emb, rel_emb)
        x = self.bn0(stack_input)
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.fet_drop(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        x = self.hid_drop(x)
        x = self.bn2(x)
        x = F.relu(x)
        # Score the projected query against every entity at once.
        x = torch.mm(x, all_ent.transpose(1, 0))
        # Requires len(all_ent) == ent_tol, the size of the per-entity bias.
        x += self.bias.expand_as(x)
        score = torch.sigmoid(x)
        return score

    def concat(
        self,
        ent_embed: torch.Tensor,
        rel_embed: torch.Tensor) -> torch.Tensor:

        """Stack head and relation embeddings into the 2-D input image of the ConvE decoder.

        Each of the two inputs has ``2 * dim`` features per row, so every example
        contributes ``4 * dim`` values, laid out as a ``(4 * dim // 20) x 20`` image.

        :param ent_embed: head entity embeddings
        :type ent_embed: torch.Tensor
        :param rel_embed: relation embeddings
        :type rel_embed: torch.Tensor
        :returns: ConvE input features of shape ``(batch, 1, 4 * dim // 20, 20)``
        :rtype: torch.Tensor
        :raises ValueError: if ``dim`` is not a multiple of 5 or is smaller than 35
        """

        height = 4 * self.dim // 20
        # Without these checks the reshape below either fails or silently mixes
        # values from different examples (4 * dim not divisible by 20), and the
        # 7x7 convolution cannot run on an image shorter than 7 rows.
        if height * 20 != 4 * self.dim or height < 7:
            raise ValueError(
                f"the ConvE decoder requires dim to be a multiple of 5 and at least 35, got dim={self.dim}")
        ent_embed = ent_embed.view(-1, 1, self.dim*2)
        rel_embed = rel_embed.view(-1, 1, self.dim*2)
        stack_input = torch.cat([ent_embed, rel_embed], 1)
        return stack_input.reshape(-1, 1, height, 20)

    def distmult(
        self,
        head_emb: torch.Tensor,
        rela_emb: torch.Tensor,
        all_ent: torch.Tensor) -> torch.Tensor:

        """Compute triple scores with DistMult as the decoder.

        :param head_emb: head entity embeddings, one row per triple
        :type head_emb: torch.Tensor
        :param rela_emb: relation embeddings, one row per triple
        :type rela_emb: torch.Tensor
        :param all_ent: embeddings of all entities
        :type all_ent: torch.Tensor
        :returns: sigmoid scores of shape ``(len(head_emb), len(all_ent))``
        :rtype: torch.Tensor"""

        obj_emb = head_emb * rela_emb
        x = torch.mm(obj_emb, all_ent.transpose(1, 0))
        # Requires len(all_ent) == ent_tol, the size of the per-entity bias.
        x += self.bias.expand_as(x)
        score = torch.sigmoid(x)
        return score

    def transe(
        self,
        head_emb: torch.Tensor,
        rela_emb: torch.Tensor,
        all_ent: torch.Tensor) -> torch.Tensor:

        """Compute triple scores with TransE as the decoder: ``sigmoid(margin - ||h + r - t||_1)``.

        Broadcasting builds a ``(len(head_emb), len(all_ent), features)`` intermediate,
        so memory grows with batch size times the number of entities.

        :param head_emb: head entity embeddings, one row per triple
        :type head_emb: torch.Tensor
        :param rela_emb: relation embeddings, one row per triple
        :type rela_emb: torch.Tensor
        :param all_ent: embeddings of all entities
        :type all_ent: torch.Tensor
        :returns: sigmoid scores of shape ``(len(head_emb), len(all_ent))``
        :rtype: torch.Tensor"""

        obj_emb = head_emb + rela_emb
        x = self.margin - torch.norm(obj_emb.unsqueeze(1) - all_ent, p=1, dim=2)
        score = torch.sigmoid(x)
        return score

    @override
    def predict(
        self,
        data: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]],
        mode: str) -> torch.Tensor:

        """Inference method of CompGCN; identical to :py:meth:`forward`.

        :param data: batch with keys ``'positive_sample'`` (triples), ``'graph'``,
            ``'rela'`` (edge relation ids) and ``'norm'`` (edge normalisation)
        :type data: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
        :param mode: unused by CompGCN; kept so all models share the same inference signature
        :type mode: str
        :returns: triple scores
        :rtype: torch.Tensor
        """

        triples = data['positive_sample']
        graph = data['graph']
        relation = data['rela']
        norm = data['norm']
        return self.forward(graph, relation, norm, triples)
319
+
320
class CompGCNCov(nn.Module):

    """``CompGCN`` :cite:`CompGCN` graph convolution layer.

    Each call composes source-entity embeddings with edge-relation embeddings
    (:py:meth:`comp`), transforms them with separate weights for the first and
    the second half of the edges, sums the messages per node, adds a self-loop
    term, and returns the updated entity embeddings together with the relation
    embeddings projected into the output space.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        act: typing.Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
        bias: bool = True,
        drop_rate: float = 0.,
        opn: str = 'corr'):

        """Create a CompGCNCov layer.

        :param in_channels: input feature dimension
        :type in_channels: int
        :param out_channels: output feature dimension
        :type out_channels: int
        :param act: activation applied to the output entity features
        :type act: typing.Callable[[torch.Tensor], torch.Tensor]
        :param bias: whether to add a bias to the entity features (any truthy value creates one)
        :type bias: bool
        :param drop_rate: dropout rate applied to the aggregated neighbour messages
        :type drop_rate: float
        :param opn: composition operator: 'mult', 'sub' or 'corr'
        :type opn: str
        """

        super().__init__()

        # NOTE: keep the creation order below; it fixes the order of random
        # initialisation and of entries in the state dict.

        #: input feature dimension
        self.in_channels: int = in_channels
        #: output feature dimension
        self.out_channels: int = out_channels
        # Optional relation transform used by forward(); nothing in this class
        # assigns it, so relation embeddings are used exactly as passed in.
        self.rel_wt = None
        #: relation embeddings of the most recent forward() call, read by
        #: message_func. Always a plain tensor, never a registered parameter.
        self.rel: typing.Optional[torch.Tensor] = None
        #: composition operator: 'mult', 'sub', 'corr'
        self.opn: str = opn
        #: weight matrix for messages along the first half of the edges (original relations)
        self.in_w: torch.nn.parameter.Parameter = self.get_param([in_channels, out_channels])
        #: weight matrix for messages along the second half of the edges (inverse relations)
        self.out_w: torch.nn.parameter.Parameter = self.get_param([in_channels, out_channels])
        #: dropout applied to the aggregated neighbour messages (see reduce_func)
        self.drop: torch.nn.Dropout = nn.Dropout(drop_rate)
        #: embedding of the self-loop relation
        self.loop_rel: torch.nn.parameter.Parameter = self.get_param([1, in_channels])
        #: weight matrix for the self-loop message
        self.loop_w: torch.nn.parameter.Parameter = self.get_param([in_channels, out_channels])
        #: bias added to the entity features (None when ``bias`` is falsy)
        self.bias: typing.Optional[torch.nn.Parameter] = nn.Parameter(torch.zeros(out_channels)) if bias else None
        #: BatchNorm over the output entity features
        self.bn: torch.nn.BatchNorm1d = torch.nn.BatchNorm1d(out_channels)
        #: activation applied to the output entity features
        self.act: typing.Callable[[torch.Tensor], torch.Tensor] = act
        #: projection of the relation embeddings into the output space
        self.w_rel: torch.nn.parameter.Parameter = self.get_param([in_channels, out_channels])

    def get_param(
        self,
        shape: list[int]) -> torch.nn.parameter.Parameter:

        """Create a weight matrix initialised with Xavier-normal (ReLU gain).

        :param shape: shape of the weight matrix
        :type shape: list[int]
        :returns: the weight matrix
        :rtype: torch.nn.parameter.Parameter
        """

        param = nn.Parameter(torch.Tensor(*shape))
        nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu'))
        return param

    def forward(
        self,
        graph: dgl.DGLGraph,
        ent_emb: torch.Tensor,
        rel_emb: torch.Tensor,
        edge_type: torch.Tensor,
        edge_norm: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:

        """
        Define the computation performed at every call.
        Subclasses of :py:class:`torch.nn.Module` must override :py:meth:`torch.nn.Module.forward`.

        :param graph: the (sub)graph; ``ent_emb`` is used as its node features
        :type graph: dgl.DGLGraph
        :param ent_emb: entity embeddings, one row per node
        :type ent_emb: torch.Tensor
        :param rel_emb: relation embeddings, indexed by ``edge_type``
        :type rel_emb: torch.Tensor
        :param edge_type: relation id of every edge
        :type edge_type: torch.Tensor
        :param edge_norm: normalisation coefficient of every edge
        :type edge_norm: torch.Tensor
        :returns: the updated entity embeddings and the projected relation embeddings
        :rtype: tuple[torch.Tensor, torch.Tensor]
        """

        # Work on a local copy so the node/edge features set here do not leak
        # into the caller's graph.
        graph = graph.local_var()
        graph.ndata['h'] = ent_emb
        graph.edata['type'] = edge_type
        graph.edata['norm'] = edge_norm
        if self.rel_wt is None:
            # Assigning an nn.Parameter to a Module attribute registers it as a
            # parameter of this module (nn.Module.__setattr__). The caller's
            # relation embedding would then show up a second time as "<prefix>rel"
            # in state_dict, but only after the first forward pass, which breaks
            # strict checkpoint loading. Store a plain-tensor alias instead;
            # gradients still flow back to rel_emb.
            self.rel = rel_emb.as_subclass(torch.Tensor) if isinstance(rel_emb, nn.Parameter) else rel_emb
        else:
            self.rel = torch.mm(self.rel_wt, rel_emb)
        graph.update_all(self.message_func, fn.sum(msg='msg', out='h'), self.reduce_func)
        # Self-loop term. reduce_func divides the neighbour aggregate by 3 and this
        # term is divided by 3 here, so original, inverse and self-loop messages
        # are each weighted by 1/3.
        ent_emb = graph.ndata.pop('h') + torch.mm(self.comp(ent_emb, self.loop_rel), self.loop_w) / 3
        if self.bias is not None:
            ent_emb = ent_emb + self.bias
        ent_emb = self.bn(ent_emb)

        return self.act(ent_emb), torch.matmul(self.rel, self.w_rel)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):

        """Drop the stray ``<prefix>rel`` entry written by earlier versions.

        Before the fix in :py:meth:`forward`, the caller's relation embedding was
        registered as a parameter named ``rel`` after the first forward pass, so
        old checkpoints contain that key. It is discarded here so strict loading
        does not report it as unexpected.
        """

        state_dict.pop(prefix + 'rel', None)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def message_func(self, edges: dgl.udf.EdgeBatch):

        """
        Message function: compose each source embedding with its edge's relation
        embedding, then transform the first half of the edges with :py:attr:`in_w`
        and the second half with :py:attr:`out_w`, scaled by the edge norm.

        NOTE(review): this assumes the edges are ordered as [original edges;
        inverse edges] in two equal halves. That ordering is established by
        whoever builds the graph, not here. Confirm it against the sampler.
        """

        edge_type = edges.data['type']
        edge_num = edge_type.shape[0]
        edge_data = self.comp(edges.src['h'], self.rel[edge_type])
        msg = torch.cat([torch.matmul(edge_data[:edge_num // 2, :], self.in_w),
                         torch.matmul(edge_data[edge_num // 2:, :], self.out_w)])
        msg = msg * edges.data['norm'].reshape(-1, 1)
        return {'msg': msg}

    @staticmethod
    def _ccorr(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:

        """Circular correlation of ``a`` and ``b`` along the last dimension, computed via FFT.

        Evaluates ``irfft(conj(rfft(a)) * rfft(b), n=a.shape[-1])`` with PyTorch's
        native complex arithmetic; ``torch.conj`` does not modify its input.
        """

        return torch.fft.irfft(torch.conj(torch.fft.rfft(a)) * torch.fft.rfft(b), n=a.shape[-1])

    def comp(
        self,
        h: torch.Tensor,
        r: torch.Tensor) -> torch.Tensor:

        """Composition operation: 'mult', 'sub' or 'corr'.

        :param h: head entity embeddings
        :type h: torch.Tensor
        :param r: relation embeddings
        :type r: torch.Tensor
        :returns: the composed edge data
        :rtype: torch.Tensor
        :raises KeyError: for an unknown composition operator
        """

        if self.opn == 'mult':
            return h * r
        if self.opn == 'sub':
            return h - r
        if self.opn == 'corr':
            return self._ccorr(h, r.expand_as(h))
        raise KeyError(f'composition operator {self.opn} not recognized.')

    def reduce_func(self, nodes: dgl.udf.NodeBatch):

        """Node apply function: dropout on the summed messages, then scale by 1/3."""

        return {'h': self.drop(nodes.data['h']) / 3}
503
+
504
def get_compgcn_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyper-parameter optimisation config of :py:class:`CompGCN`.

    ``dim`` lists several candidate ``values``; every other entry fixes a single
    ``value``. The default config is::

        parameters_dict = {
            'model': {
                'value': 'CompGCN'
            },
            'dim': {
                'values': [100, 150, 200]
            },
            'opn': {
                'value': 'mult'
            },
            'fet_drop': {
                'value': 0.2
            },
            'hid_drop': {
                'value': 0.3
            },
            'margin': {
                'value': 40.0
            },
            'decoder_model': {
                'value': 'ConvE'
            }
        }

    A fresh dictionary is built on every call, so callers may modify the result.

    :returns: the default hyper-parameter optimisation config of :py:class:`CompGCN`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    search_space: dict[str, dict[str, typing.Any]] = {}
    search_space['model'] = {'value': 'CompGCN'}
    search_space['dim'] = {'values': [100, 150, 200]}
    search_space['opn'] = {'value': 'mult'}
    search_space['fet_drop'] = {'value': 0.2}
    search_space['hid_drop'] = {'value': 0.3}
    search_space['margin'] = {'value': 40.0}
    search_space['decoder_model'] = {'value': 'ConvE'}
    return search_space