unike 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59):
  1. unike/__init__.py +5 -0
  2. unike/config/HPOTrainer.py +305 -0
  3. unike/config/Tester.py +385 -0
  4. unike/config/Trainer.py +519 -0
  5. unike/config/TrainerAccelerator.py +39 -0
  6. unike/config/__init__.py +37 -0
  7. unike/data/BernSampler.py +168 -0
  8. unike/data/CompGCNSampler.py +140 -0
  9. unike/data/CompGCNTestSampler.py +84 -0
  10. unike/data/KGEDataLoader.py +315 -0
  11. unike/data/KGReader.py +138 -0
  12. unike/data/RGCNSampler.py +261 -0
  13. unike/data/RGCNTestSampler.py +208 -0
  14. unike/data/RevSampler.py +78 -0
  15. unike/data/TestSampler.py +189 -0
  16. unike/data/TradSampler.py +122 -0
  17. unike/data/TradTestSampler.py +87 -0
  18. unike/data/UniSampler.py +145 -0
  19. unike/data/__init__.py +47 -0
  20. unike/module/BaseModule.py +130 -0
  21. unike/module/__init__.py +20 -0
  22. unike/module/loss/CompGCNLoss.py +96 -0
  23. unike/module/loss/Loss.py +26 -0
  24. unike/module/loss/MarginLoss.py +148 -0
  25. unike/module/loss/RGCNLoss.py +117 -0
  26. unike/module/loss/SigmoidLoss.py +145 -0
  27. unike/module/loss/SoftplusLoss.py +145 -0
  28. unike/module/loss/__init__.py +35 -0
  29. unike/module/model/Analogy.py +237 -0
  30. unike/module/model/CompGCN.py +562 -0
  31. unike/module/model/ComplEx.py +235 -0
  32. unike/module/model/DistMult.py +276 -0
  33. unike/module/model/HolE.py +308 -0
  34. unike/module/model/Model.py +107 -0
  35. unike/module/model/RESCAL.py +309 -0
  36. unike/module/model/RGCN.py +304 -0
  37. unike/module/model/RotatE.py +303 -0
  38. unike/module/model/SimplE.py +237 -0
  39. unike/module/model/TransD.py +458 -0
  40. unike/module/model/TransE.py +290 -0
  41. unike/module/model/TransH.py +322 -0
  42. unike/module/model/TransR.py +402 -0
  43. unike/module/model/__init__.py +60 -0
  44. unike/module/strategy/CompGCNSampling.py +140 -0
  45. unike/module/strategy/NegativeSampling.py +138 -0
  46. unike/module/strategy/RGCNSampling.py +134 -0
  47. unike/module/strategy/Strategy.py +26 -0
  48. unike/module/strategy/__init__.py +29 -0
  49. unike/utils/EarlyStopping.py +94 -0
  50. unike/utils/Timer.py +74 -0
  51. unike/utils/WandbLogger.py +46 -0
  52. unike/utils/__init__.py +26 -0
  53. unike/utils/tools.py +118 -0
  54. unike/version.py +1 -0
  55. unike-3.0.1.dist-info/METADATA +101 -0
  56. unike-3.0.1.dist-info/RECORD +59 -0
  57. unike-3.0.1.dist-info/WHEEL +4 -0
  58. unike-3.0.1.dist-info/entry_points.txt +2 -0
  59. unike-3.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,402 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/module/model/TransR.py
4
+ #
5
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 25, 2024
7
+ #
8
+ # 该头文件定义了 TransR.
9
+
10
+ """
11
+ TransR - 是一个为实体和关系嵌入向量分别构建了独立的向量空间,将实体向量投影到特定的关系向量空间进行平移操作的模型。
12
+ """
13
+
14
+ import torch
15
+ import typing
16
+ import numpy as np
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from .Model import Model
20
+ from typing_extensions import override
21
+
22
class TransR(Model):

    r"""
    ``TransR`` :cite:`TransR` was proposed in 2015. It builds separate vector
    spaces for entity embeddings and relation embeddings, projects entity vectors
    into the relation-specific space and performs the translation there.

    The scoring function is:

    .. math::

        \Vert hM_r+r-tM_r \Vert_{L_1/L_2}

    Lower scores are better for positive triples. See :ref:`TransR <transr>`
    for more details.

    Example::

        from unike.data import KGEDataLoader, BernSampler, TradTestSampler
        from unike.module.model import TransE, TransR
        from unike.module.loss import MarginLoss
        from unike.module.strategy import NegativeSampling
        from unike.config import Trainer, Tester

        # dataloader for training
        dataloader = KGEDataLoader(
            in_path = "../../benchmarks/FB15K237/",
            batch_size = 2048,
            neg_ent = 25,
            test = True,
            test_batch_size = 10,
            num_workers = 16,
            train_sampler = BernSampler,
            test_sampler = TradTestSampler
        )

        # define the transe
        transe = TransE(
            ent_tol = dataloader.get_ent_tol(),
            rel_tol = dataloader.get_rel_tol(),
            dim = 100,
            p_norm = 1,
            norm_flag = True)

        transr = TransR(
            ent_tol = dataloader.get_ent_tol(),
            rel_tol = dataloader.get_rel_tol(),
            dim_e = 100,
            dim_r = 100,
            p_norm = 1,
            norm_flag = True,
            rand_init = False)

        model_e = NegativeSampling(
            model = transe,
            loss = MarginLoss(margin = 5.0)
        )

        model_r = NegativeSampling(
            model = transr,
            loss = MarginLoss(margin = 4.0)
        )

        # pretrain transe
        trainer = Trainer(model = model_e, data_loader = dataloader.train_dataloader(),
            epochs = 1, lr = 0.5, opt_method = "sgd", use_gpu = True, device = 'cuda:0')
        trainer.run()
        parameters = transe.get_parameters()
        transe.save_parameters("../../checkpoint/transr_transe.json")

        # test the transr
        tester = Tester(model = transr, data_loader = dataloader, use_gpu = True, device = 'cuda:0')

        # train transr
        transr.set_parameters(parameters)
        trainer = Trainer(model = model_r, data_loader = dataloader.train_dataloader(),
            epochs = 1000, lr = 1.0, opt_method = "sgd", use_gpu = True, device = 'cuda:0',
            tester = tester, test = True, valid_interval = 100,
            log_interval = 100, save_interval = 100, save_path = '../../checkpoint/transr.pth')
        trainer.run()

        # test the model
        transr.load_checkpoint('../../checkpoint/transr.pth')
        tester.set_sampling_mode("link_test")
        tester.run_link_prediction()
    """

    def __init__(
        self,
        ent_tol: int,
        rel_tol: int,
        dim_e: int = 100,
        dim_r: int = 100,
        p_norm: int = 1,
        norm_flag: bool = True,
        rand_init: bool = False,
        margin: float | None = None):

        """Create a TransR object.

        :param ent_tol: number of entities
        :type ent_tol: int
        :param rel_tol: number of relations
        :type rel_tol: int
        :param dim_e: dimension of the entity embeddings
        :type dim_e: int
        :param dim_r: dimension of the relation embeddings
        :type dim_r: int
        :param p_norm: order of the distance used by the scoring function;
            following the original paper this is 1 or 2.
        :type p_norm: int
        :param norm_flag: whether to L2-normalize the last dimension of the
            (projected) entity and relation vectors with
            :py:func:`torch.nn.functional.normalize` before scoring.
        :type norm_flag: bool
        :param rand_init: if ``False``, every relation matrix starts as the
            (rectangular) identity; if ``True``, it is Xavier-uniform initialized.
        :type rand_init: bool
        :param margin: required when training with the ``RotatE`` :cite:`RotatE`
            loss :py:class:`unike.module.loss.SigmoidLoss`. The returned score
            becomes ``margin - distance``, which turns the TransR score of
            positive triples from "lower is better" into "higher is better".
            See :ref:`RotatE <rotate>` for more details.
        :type margin: float | None
        """

        super(TransR, self).__init__(ent_tol, rel_tol)

        #: dimension of the entity embeddings
        self.dim_e: int = dim_e
        #: dimension of the relation embeddings
        self.dim_r: int = dim_r
        #: order of the distance used by the scoring function (1 or 2 in the paper)
        self.p_norm: int = p_norm
        #: whether to L2-normalize the last dimension of the vectors with
        #: :py:func:`torch.nn.functional.normalize` before scoring
        self.norm_flag: bool = norm_flag
        #: whether the relation matrices are randomly initialized
        self.rand_init: bool = rand_init

        #: entity embeddings, one row per entity
        self.ent_embeddings: torch.nn.Embedding = nn.Embedding(self.ent_tol, self.dim_e)
        #: relation embeddings, one row per relation
        self.rel_embeddings: torch.nn.Embedding = nn.Embedding(self.rel_tol, self.dim_r)

        if margin is not None:
            #: fixed (non-trainable) margin used to turn the distance into
            #: ``margin - distance`` for :py:class:`unike.module.loss.SigmoidLoss`
            self.margin: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor([margin]))
            self.margin.requires_grad = False
            self.margin_flag: bool = True
        else:
            self.margin_flag: bool = False

        nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
        nn.init.xavier_uniform_(self.rel_embeddings.weight.data)

        #: per-relation projection matrices, each stored flat as dim_e * dim_r values
        self.transfer_matrix: torch.nn.Embedding = nn.Embedding(self.rel_tol, self.dim_e * self.dim_r)

        if not self.rand_init:
            # torch.eye handles the rectangular case (ones on the leading diagonal);
            # the flattened identity is broadcast into every relation's row at once.
            identity = torch.eye(self.dim_e, self.dim_r).reshape(-1)
            self.transfer_matrix.weight.data[:] = identity
        else:
            nn.init.xavier_uniform_(self.transfer_matrix.weight.data)

    @override
    def forward(
        self,
        triples: torch.Tensor,
        negs: torch.Tensor = None,
        mode: str = 'single') -> torch.Tensor:

        """Score a batch of triples; executed on every call.

        Subclasses of :py:class:`torch.nn.Module` must override
        :py:meth:`torch.nn.Module.forward`.

        :param triples: triples to score; column 1 holds the relation ids used
            to look up the projection matrices.
        :type triples: torch.Tensor
        :param negs: negative samples, forwarded to ``tri2emb`` together with ``mode``.
        :type negs: torch.Tensor | None
        :param mode: sampling mode, forwarded to ``tri2emb``.
        :type mode: str
        :returns: triple scores (``margin - distance`` when a margin was given,
            otherwise the raw distance).
        :rtype: torch.Tensor
        """

        head_emb, relation_emb, tail_emb = self.tri2emb(triples, negs, mode)
        rel_transfer = self.transfer_matrix(triples[:, 1])
        head_emb = self._transfer(head_emb, rel_transfer)
        tail_emb = self._transfer(tail_emb, rel_transfer)
        score = self._calc(head_emb, relation_emb, tail_emb)
        if self.margin_flag:
            return self.margin - score
        else:
            return score

    def _transfer(
        self,
        e: torch.Tensor,
        r_transfer: torch.Tensor) -> torch.Tensor:

        """Project head or tail entity vectors into the relation-specific space.

        Computes ``e @ M_r``, where ``M_r`` is the ``(dim_e, dim_r)`` matrix
        stored flat in ``r_transfer``.

        :param e: head or tail entity vectors; the last dimension is ``dim_e``.
            NOTE(review): assumed shape ``(batch, n, dim_e)`` so that one matrix
            per triple broadcasts over ``n`` entities; confirm against ``tri2emb``.
        :type e: torch.Tensor
        :param r_transfer: flattened relation matrices, ``dim_e * dim_r`` values
            per triple.
        :type r_transfer: torch.Tensor
        :returns: the projected entity vectors, with last dimension ``dim_r``.
        :rtype: torch.Tensor
        """

        # -> (batch, 1, dim_e, dim_r): the inserted axis lets each triple's
        # matrix broadcast over all of that triple's entity vectors.
        r_transfer = r_transfer.view(-1, self.dim_e, self.dim_r).unsqueeze(dim=1)
        # Treat each entity vector as a row vector (..., 1, dim_e) for matmul.
        e = e.unsqueeze(dim=-2)
        return torch.matmul(e, r_transfer).squeeze(dim=-2)

    def _calc(
        self,
        h: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor) -> torch.Tensor:

        """Compute the TransR distance ``||h + r - t||_p``.

        :param h: projected head entity vectors.
        :type h: torch.Tensor
        :param r: relation vectors.
        :type r: torch.Tensor
        :param t: projected tail entity vectors.
        :type t: torch.Tensor
        :returns: the distance of each triple (lower is better).
        :rtype: torch.Tensor
        """

        # L2-normalize the last dimension of every vector
        if self.norm_flag:
            h = F.normalize(h, 2, -1)
            r = F.normalize(r, 2, -1)
            t = F.normalize(t, 2, -1)

        score = (h + r) - t

        # p-norm over the embedding dimension gives the distance
        score = torch.norm(score, self.p_norm, -1)
        return score

    @override
    def predict(
        self,
        data: dict[str, typing.Union[torch.Tensor,str]],
        mode: str) -> torch.Tensor:

        """Inference method of TransR.

        :param data: batch data; ``data["positive_sample"]`` holds the triples.
        :type data: dict[str, typing.Union[torch.Tensor,str]]
        :param mode: ``'head_predict'`` or ``'tail_predict'``
        :type mode: str
        :returns: triple scores where higher is better (``margin - distance``
            when a margin was given, otherwise ``-distance``).
        :rtype: torch.Tensor
        """

        triples = data["positive_sample"]
        head_emb, relation_emb, tail_emb = self.tri2emb(triples, mode=mode)
        rel_transfer = self.transfer_matrix(triples[:, 1])
        head_emb = self._transfer(head_emb, rel_transfer)
        tail_emb = self._transfer(tail_emb, rel_transfer)
        score = self._calc(head_emb, relation_emb, tail_emb)

        if self.margin_flag:
            return self.margin - score
        else:
            return -score

    def regularization(
        self,
        data: dict[str, typing.Union[torch.Tensor, str]]) -> torch.Tensor:

        """L2 regularization (weight decay) term used by the loss function.

        Averages the mean squared entries of the head, relation and tail
        embeddings and of the relation matrices, separately for the positive
        and negative samples, then averages the two.

        :param data: batch data with the keys ``"positive_sample"``,
            ``"negative_sample"`` and ``"mode"``.
        :type data: dict[str, typing.Union[torch.Tensor, str]]
        :returns: regularization loss of the model parameters.
        :rtype: torch.Tensor
        """

        pos_sample = data["positive_sample"]
        neg_sample = data["negative_sample"]
        mode = data["mode"]
        pos_head_emb, pos_relation_emb, pos_tail_emb = self.tri2emb(pos_sample)
        pos_rel_transfer = self.transfer_matrix(pos_sample[:, 1])
        if mode == "bern":
            # In "bern" mode the negatives are full triples, so take their
            # relation matrices from the same rows their embeddings came from.
            neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(neg_sample)
            neg_rel_transfer = self.transfer_matrix(neg_sample[:, 1])
        else:
            # Otherwise the negatives are expanded from the positive triples,
            # whose relations are reused.
            neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(pos_sample, neg_sample, mode)
            neg_rel_transfer = self.transfer_matrix(pos_sample[:, 1])

        pos_regul = (torch.mean(pos_head_emb ** 2) +
                     torch.mean(pos_relation_emb ** 2) +
                     torch.mean(pos_tail_emb ** 2) +
                     torch.mean(pos_rel_transfer ** 2)) / 4

        neg_regul = (torch.mean(neg_head_emb ** 2) +
                     torch.mean(neg_relation_emb ** 2) +
                     torch.mean(neg_tail_emb ** 2) +
                     torch.mean(neg_rel_transfer ** 2)) / 4

        regul = (pos_regul + neg_regul) / 2

        return regul
328
+
329
def get_transr_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyper-parameter optimization config for :py:class:`TransR`.

    When ``TransR`` :cite:`TransR` is tuned, a ``TransE`` :cite:`TransE`
    model is trained first (for 1 epoch), and TransR's entity and relation
    embeddings are initialized from the TransE result. **margin_e**, **lr_e**
    and **opt_method_e** are the training hyper-parameters of that ``TransE``
    :cite:`TransE` run. See :ref:`TransR <transr>` for more details.

    The default configuration is::

        parameters_dict = {
            'model': {
                'value': 'TransR'
            },
            'dim': {
                'values': [50, 100]
            },
            'p_norm': {
                'values': [1, 2]
            },
            'norm_flag': {
                'value': True
            },
            'rand_init': {
                'value': False
            },
            'margin_e': {
                'values': [1.0, 3.0, 6.0]
            },
            'lr_e': {
                'distribution': 'uniform',
                'min': 1e-5,
                'max': 1.0
            },
            'opt_method_e': {
                'values': ['adam', 'adagrad', 'sgd']
            },
        }

    :returns: the default hyper-parameter optimization config of :py:class:`TransR`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    # Search space of the TransR model itself.
    transr_space = dict(
        model={'value': 'TransR'},
        dim={'values': [50, 100]},
        p_norm={'values': [1, 2]},
        norm_flag={'value': True},
        rand_init={'value': False},
    )

    # Search space of the TransE warm-up run whose embeddings seed TransR.
    transe_warmup_space = dict(
        margin_e={'values': [1.0, 3.0, 6.0]},
        lr_e={'distribution': 'uniform', 'min': 1e-5, 'max': 1.0},
        opt_method_e={'values': ['adam', 'adagrad', 'sgd']},
    )

    # A fresh dict per call, with keys in the same order as the listing above.
    return {**transr_space, **transe_warmup_space}
@@ -0,0 +1,60 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/module/model/__init__.py
4
+ #
5
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 28, 2023
7
+ #
8
+ # 该头文件定义了 model 接口.
9
+
10
+ """KGE 模型部分。"""
11
+
12
+ from __future__ import absolute_import
13
+ from __future__ import division
14
+ from __future__ import print_function
15
+
16
+ from .Model import Model
17
+ from .TransE import TransE, get_transe_hpo_config
18
+ from .TransH import TransH, get_transh_hpo_config
19
+ from .TransR import TransR, get_transr_hpo_config
20
+ from .TransD import TransD, get_transd_hpo_config
21
+ from .RotatE import RotatE, get_rotate_hpo_config
22
+ from .RESCAL import RESCAL, get_rescal_hpo_config
23
+ from .DistMult import DistMult, get_distmult_hpo_config
24
+ from .HolE import HolE, get_hole_hpo_config
25
+ from .ComplEx import ComplEx, get_complex_hpo_config
26
+ from .Analogy import Analogy, get_analogy_hpo_config
27
+ from .SimplE import SimplE, get_simple_hpo_config
28
+ from .RGCN import RGCN, get_rgcn_hpo_config
29
+ from .CompGCN import CompGCN, CompGCNCov, get_compgcn_hpo_config
30
+
31
# Public API of unike.module.model: the abstract base class first, then each
# KGE model followed by the factory returning its default HPO config.
__all__ = [
    'Model',
    'TransE', 'get_transe_hpo_config',
    'TransH', 'get_transh_hpo_config',
    'TransR', 'get_transr_hpo_config',
    'TransD', 'get_transd_hpo_config',
    'RotatE', 'get_rotate_hpo_config',
    'RESCAL', 'get_rescal_hpo_config',
    'DistMult', 'get_distmult_hpo_config',
    'HolE', 'get_hole_hpo_config',
    'ComplEx', 'get_complex_hpo_config',
    'Analogy', 'get_analogy_hpo_config',
    'SimplE', 'get_simple_hpo_config',
    'RGCN', 'get_rgcn_hpo_config',
    'CompGCN', 'CompGCNCov', 'get_compgcn_hpo_config',
]
@@ -0,0 +1,140 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/module/strategy/CompGCNSampling.py
4
+ #
5
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 20, 2023
7
+ #
8
+ # 该脚本定义了 CompGCN 模型的训练策略.
9
+
10
+ """
11
+ CompGCNSampling - 训练策略类,包含损失函数。
12
+ """
13
+
14
+ import dgl
15
+ import torch
16
+ import typing
17
+ from ..loss import Loss
18
+ from ..model import CompGCN
19
+ from .Strategy import Strategy
20
+
21
class CompGCNSampling(Strategy):

    """
    Bundles a model and its loss function for convenient training; used for
    ``CompGCN`` :cite:`CompGCN`.

    The example below assumes a ``dataloader`` that was built elsewhere.
    NOTE(review): the ``unike.config`` package of this release appears to ship
    ``Tester`` rather than ``GraphTester``; confirm the tester import used here.

    Example::

        from unike.module.model import CompGCN
        from unike.module.loss import CompGCNLoss
        from unike.module.strategy import CompGCNSampling
        from unike.config import Trainer, GraphTester

        # define the model
        compgcn = CompGCN(
            ent_tol = dataloader.train_sampler.ent_tol,
            rel_tol = dataloader.train_sampler.rel_tol,
            dim = 100
        )

        # define the loss function
        model = CompGCNSampling(
            model = compgcn,
            loss = CompGCNLoss(model = compgcn),
            ent_tol = dataloader.train_sampler.ent_tol
        )

        # test the model
        tester = GraphTester(model = compgcn, data_loader = dataloader, use_gpu = True, device = 'cuda:0', prediction = "tail")

        # train the model
        trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(),
            epochs = 2000, lr = 0.0001, use_gpu = True, device = 'cuda:0',
            tester = tester, test = True, valid_interval = 50, log_interval = 50,
            save_interval = 50, save_path = '../../checkpoint/compgcn.pth'
        )
        trainer.run()
    """

    def __init__(
        self,
        model: CompGCN = None,
        loss: Loss = None,
        smoothing: float = 0.1,
        ent_tol: int = None):

        """Create a CompGCNSampling object.

        :param model: the CompGCN model
        :type model: :py:class:`unike.module.model.CompGCN`
        :param loss: the loss function.
        :type loss: :py:class:`unike.module.loss.Loss`
        :param smoothing: label-smoothing factor applied to the labels in
            :py:meth:`forward`.
        :type smoothing: float
        :param ent_tol: number of entities. If omitted, ``model.ent_tol`` is
            used when the model has that attribute.
        :type ent_tol: int | None
        """

        super(CompGCNSampling, self).__init__()

        #: the CompGCN model, i.e. :py:class:`unike.module.model.CompGCN`
        self.model: CompGCN = model
        #: the loss function, i.e. :py:class:`unike.module.loss.Loss`
        self.loss: Loss = loss
        #: label-smoothing factor
        self.smoothing: float = smoothing
        # Fall back to the model's own entity count so that ent_tol need not be
        # passed twice; previously a missing ent_tol only failed later, in
        # forward, with an opaque `1.0 / None` TypeError.
        if ent_tol is None:
            ent_tol = getattr(model, "ent_tol", None)
        #: number of entities
        self.ent_tol: typing.Optional[int] = ent_tol

    def forward(
        self,
        data: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]) -> torch.Tensor:

        """Compute the final loss value; executed on every call.

        Subclasses of :py:class:`torch.nn.Module` must override
        :py:meth:`torch.nn.Module.forward`.

        :param data: batch data with the keys ``"graph"``, ``"relation"``,
            ``"norm"``, ``"sample"`` and ``"label"``.
        :type data: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
        :returns: the loss value
        :rtype: torch.Tensor
        :raises ValueError: if the number of entities is unknown (no
            ``ent_tol`` was given and the model does not provide one).
        """

        if self.ent_tol is None:
            raise ValueError(
                "CompGCNSampling needs the number of entities (ent_tol) for label smoothing.")

        graph = data["graph"]
        relation = data['relation']
        norm = data['norm']
        sample = data["sample"]
        label = data["label"]
        score = self.model(graph, relation, norm, sample)
        # Label smoothing: scale the targets by (1 - smoothing) and add a uniform
        # 1 / ent_tol term. NOTE(review): the uniform term is 1 / ent_tol rather
        # than smoothing / ent_tol; kept unchanged to preserve training behavior.
        label = (1.0 - self.smoothing) * label + (1.0 / self.ent_tol)
        loss = self.loss(score, label)
        return loss
111
+
112
def get_compgcn_sampling_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyper-parameter optimization config for :py:class:`CompGCNSampling`.

    The default configuration is::

        parameters_dict = {
            'strategy': {
                'value': 'CompGCNSampling'
            },
            'smoothing': {
                'value': 0.1
            }
        }

    :returns: the default hyper-parameter optimization config of :py:class:`CompGCNSampling`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    strategy_name = 'CompGCNSampling'
    default_smoothing = 0.1

    # Both entries are fixed values rather than searched ranges.
    return {
        'strategy': {'value': strategy_name},
        'smoothing': {'value': default_smoothing},
    }