unike 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. unike/__init__.py +5 -0
  2. unike/config/HPOTrainer.py +305 -0
  3. unike/config/Tester.py +385 -0
  4. unike/config/Trainer.py +519 -0
  5. unike/config/TrainerAccelerator.py +39 -0
  6. unike/config/__init__.py +37 -0
  7. unike/data/BernSampler.py +168 -0
  8. unike/data/CompGCNSampler.py +140 -0
  9. unike/data/CompGCNTestSampler.py +84 -0
  10. unike/data/KGEDataLoader.py +315 -0
  11. unike/data/KGReader.py +138 -0
  12. unike/data/RGCNSampler.py +261 -0
  13. unike/data/RGCNTestSampler.py +208 -0
  14. unike/data/RevSampler.py +78 -0
  15. unike/data/TestSampler.py +189 -0
  16. unike/data/TradSampler.py +122 -0
  17. unike/data/TradTestSampler.py +87 -0
  18. unike/data/UniSampler.py +145 -0
  19. unike/data/__init__.py +47 -0
  20. unike/module/BaseModule.py +130 -0
  21. unike/module/__init__.py +20 -0
  22. unike/module/loss/CompGCNLoss.py +96 -0
  23. unike/module/loss/Loss.py +26 -0
  24. unike/module/loss/MarginLoss.py +148 -0
  25. unike/module/loss/RGCNLoss.py +117 -0
  26. unike/module/loss/SigmoidLoss.py +145 -0
  27. unike/module/loss/SoftplusLoss.py +145 -0
  28. unike/module/loss/__init__.py +35 -0
  29. unike/module/model/Analogy.py +237 -0
  30. unike/module/model/CompGCN.py +562 -0
  31. unike/module/model/ComplEx.py +235 -0
  32. unike/module/model/DistMult.py +276 -0
  33. unike/module/model/HolE.py +308 -0
  34. unike/module/model/Model.py +107 -0
  35. unike/module/model/RESCAL.py +309 -0
  36. unike/module/model/RGCN.py +304 -0
  37. unike/module/model/RotatE.py +303 -0
  38. unike/module/model/SimplE.py +237 -0
  39. unike/module/model/TransD.py +458 -0
  40. unike/module/model/TransE.py +290 -0
  41. unike/module/model/TransH.py +322 -0
  42. unike/module/model/TransR.py +402 -0
  43. unike/module/model/__init__.py +60 -0
  44. unike/module/strategy/CompGCNSampling.py +140 -0
  45. unike/module/strategy/NegativeSampling.py +138 -0
  46. unike/module/strategy/RGCNSampling.py +134 -0
  47. unike/module/strategy/Strategy.py +26 -0
  48. unike/module/strategy/__init__.py +29 -0
  49. unike/utils/EarlyStopping.py +94 -0
  50. unike/utils/Timer.py +74 -0
  51. unike/utils/WandbLogger.py +46 -0
  52. unike/utils/__init__.py +26 -0
  53. unike/utils/tools.py +118 -0
  54. unike/version.py +1 -0
  55. unike-3.0.1.dist-info/METADATA +101 -0
  56. unike-3.0.1.dist-info/RECORD +59 -0
  57. unike-3.0.1.dist-info/WHEEL +4 -0
  58. unike-3.0.1.dist-info/entry_points.txt +2 -0
  59. unike-3.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,130 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/module/BaseModule.py
4
+ #
5
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 4, 2023
7
+ #
8
+ # 该头文件定义了 BaseModule.
9
+
10
+ """BaseModule - 所有模块的基类"""
11
+
12
+ import os
13
+ import json
14
+ import torch
15
+ import torch.nn as nn
16
+ import numpy as np
17
+ from typing import Any
18
+
19
+ class BaseModule(nn.Module):
20
+
21
+ """继承自 :py:class:`torch.nn.Module`,并且封装了一些常用功能,如加载和保存模型。"""
22
+
23
+ def __init__(self):
24
+
25
+ """创建 BaseModule 对象。"""
26
+
27
+ super(BaseModule, self).__init__()
28
+
29
+ #: 常数 0
30
+ self.zero_const: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor([0]))
31
+ self.zero_const.requires_grad = False
32
+
33
+ #: 常数 pi
34
+ self.pi_const: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor([3.14159265358979323846]))
35
+ self.pi_const.requires_grad = False
36
+
37
+ def load_checkpoint(self, path: str):
38
+
39
+ """加载模型权重。
40
+
41
+ :param path: 模型保存的路径
42
+ :type path: str
43
+ """
44
+
45
+ self.load_state_dict(torch.load(os.path.join(path)))
46
+ self.eval()
47
+
48
+ def save_checkpoint(self, path: str):
49
+
50
+ """保存模型权重。
51
+
52
+ :param path: 模型保存的路径
53
+ :type path: str
54
+ """
55
+
56
+ if not os.path.exists(os.path.split(path)[0]):
57
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
58
+ torch.save(self.state_dict(), path)
59
+
60
+ def get_parameters(
61
+ self,
62
+ mode: str = "numpy",
63
+ param_dict: dict[str, Any] | None = None
64
+ ) -> dict[str, np.ndarray] | dict[str, list] | dict[str, torch.Tensor]:
65
+
66
+ """获得模型权重。
67
+
68
+ :param mode: 模型保存的格式,可以选择 ``numpy`` 、 ``list`` 和 ``Tensor`` 。
69
+ :type path: str
70
+ :param param_dict: 可以选择从哪里获得模型权重。
71
+ :type param_dict: dict[str, typing.Any] | None
72
+ :returns: 模型权重字典。
73
+ :rtype: dict[str, numpy.ndarray] | dict[str, list] | dict[str, torch.Tensor]
74
+ """
75
+
76
+ all_param_dict = self.state_dict()
77
+ if param_dict == None:
78
+ param_dict = all_param_dict.keys()
79
+ res = {}
80
+ for param in param_dict:
81
+ if mode == "numpy":
82
+ res[param] = all_param_dict[param].cpu().numpy()
83
+ elif mode == "list":
84
+ res[param] = all_param_dict[param].cpu().numpy().tolist()
85
+ else:
86
+ res[param] = all_param_dict[param]
87
+ return res
88
+
89
+ def set_parameters(self, parameters: dict[str, Any]):
90
+
91
+ """加载模型权重。
92
+
93
+ :param parameters: 模型权重字典。
94
+ :type parameters: dict[str, typing.Any]
95
+ """
96
+
97
+ for i in parameters:
98
+ parameters[i] = torch.Tensor(parameters[i])
99
+ self.load_state_dict(parameters, strict = False)
100
+ self.eval()
101
+
102
+ def load_parameters(self, path: str):
103
+
104
+ """加载模型权重。
105
+
106
+ :param path: 模型保存的路径
107
+ :type path: str
108
+ """
109
+
110
+ f = open(path, "r")
111
+ parameters = json.loads(f.read())
112
+ f.close()
113
+ for i in parameters:
114
+ parameters[i] = torch.Tensor(parameters[i])
115
+ self.load_state_dict(parameters, strict = False)
116
+ self.eval()
117
+
118
+ def save_parameters(self, path: str):
119
+
120
+ """用 json 格式保存模型权重。
121
+
122
+ :param path: 模型保存的路径
123
+ :type path: str
124
+ """
125
+
126
+ if not os.path.exists(os.path.split(path)[0]):
127
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
128
+ f = open(path, "w")
129
+ f.write(json.dumps(self.get_parameters("list")))
130
+ f.close()
@@ -0,0 +1,20 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/module/__init__.py
4
+ #
5
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 28, 2023
7
+ #
8
+ # 该头文件定义了 module 接口.
9
+
10
+ """模块部分,包含模型和损失函数。"""
11
+
12
+ from __future__ import absolute_import
13
+ from __future__ import division
14
+ from __future__ import print_function
15
+
16
+ from .BaseModule import BaseModule
17
+
18
#: Public names exported by ``from unike.module import *``.
__all__ = [
    'BaseModule'
]
@@ -0,0 +1,96 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/module/loss/CompGCNLoss.py
4
+ #
5
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 19, 2024
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 23, 2024
7
+ #
8
+ # 该脚本定义了 CompGCNLoss 类.
9
+
10
+ """
11
+ CompGCNLoss - 损失函数类,CompGCN 原论文中应用这种损失函数完成模型学习。
12
+ """
13
+
14
+ import torch
15
+ from .Loss import Loss
16
+ from typing import Any
17
+ from ..model import CompGCN
18
+
19
class CompGCNLoss(Loss):

    """Loss function used to train ``CompGCN`` :cite:`CompGCN` in the original paper.

    .. Note:: In :py:meth:`forward`, scores of positive samples should be higher than scores of negative samples.

    Example::

        from unike.module.loss import CompGCNLoss
        from unike.module.strategy import CompGCNSampling

        # define the loss function
        model = CompGCNSampling(
            model = compgcn,
            loss = CompGCNLoss(model = compgcn),
            ent_tol = dataloader.get_ent_tol()
        )
    """

    def __init__(
        self,
        model: CompGCN):

        """Create a CompGCNLoss object.

        :param model: the model
        :type model: CompGCN
        """

        super(CompGCNLoss, self).__init__()

        #: the model (kept as an attribute; :py:meth:`forward` does not use it)
        self.model: CompGCN = model
        #: binary cross-entropy loss; :py:class:`torch.nn.BCELoss` expects
        #: ``pred`` to already be probabilities in [0, 1]
        self.loss: torch.nn.BCELoss = torch.nn.BCELoss()

    def forward(
        self,
        pred: torch.Tensor,
        label: torch.Tensor) -> torch.Tensor:

        """Compute the CompGCNLoss value. This defines the computation performed at every call.
        Subclasses of :py:class:`torch.nn.Module` must override :py:meth:`torch.nn.Module.forward`.

        :param pred: model scores, passed directly to :py:class:`torch.nn.BCELoss`
        :type pred: torch.Tensor
        :param label: target labels (BCELoss requires the same shape as ``pred``)
        :type label: torch.Tensor
        :returns: loss value
        :rtype: torch.Tensor
        """

        return self.loss(pred, label)
73
+
74
def get_compgcn_loss_hpo_config() -> dict[str, dict[str, Any]]:

    """Return the default hyper-parameter optimization config for :py:class:`CompGCNLoss`.

    The default configuration is::

        parameters_dict = {
            'loss': {
                'value': 'CompGCNLoss'
            }
        }

    A fresh dict is built on every call, so callers may modify the result.

    :returns: default hyper-parameter optimization config of :py:class:`CompGCNLoss`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    loss_entry = {'value': 'CompGCNLoss'}
    return {'loss': loss_entry}
@@ -0,0 +1,26 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/module/loss/Loss.py
4
+ #
5
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 4, 2023
7
+ #
8
+ # 该脚本定义了损失函数的基类.
9
+
10
+ """
11
+ Loss - 该脚本定义了损失函数的基类。
12
+ """
13
+
14
+ from ..BaseModule import BaseModule
15
+
16
class Loss(BaseModule):

    """Base class of all loss functions.

    Inherits from :py:class:`unike.module.BaseModule` and adds no extra attributes.
    """

    def __init__(self):

        """Create a Loss object."""

        super().__init__()
@@ -0,0 +1,148 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/module/loss/MarginLoss.py
4
+ #
5
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 9, 2024
7
+ #
8
+ # 该脚本定义了 margin-based ranking criterion 损失函数.
9
+
10
+ """
11
+ MarginLoss - 损失函数类,TransE 原论文中应用这种损失函数完成模型学习。
12
+ """
13
+
14
+ import torch
15
+ import numpy as np
16
+ from typing import Any
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from .Loss import Loss
20
+
21
class MarginLoss(Loss):

    """Margin-based ranking loss used to train ``TransE`` :cite:`TransE` in the original paper.

    .. Note:: In :py:meth:`forward`, scores of positive samples should be lower than scores of negative samples.

    Example::

        from unike.module.model import TransE
        from unike.module.loss import MarginLoss
        from unike.module.strategy import NegativeSampling

        # define the model
        transe = TransE(
            ent_tol = dataloader.get_ent_tol(),
            rel_tol = dataloader.get_rel_tol(),
            dim = 50,
            p_norm = 1,
            norm_flag = True
        )

        # define the loss function
        model = NegativeSampling(
            model = transe,
            loss = MarginLoss(margin = 1.0),
            regul_rate = 0.01
        )
    """

    def __init__(
        self,
        adv_temperature: float | None = None,
        margin: float = 6.0):

        """Create a MarginLoss object.

        :param adv_temperature: temperature of the self-adversarial negative sampling
            proposed by RotatE. ``None`` disables self-adversarial weighting.
        :type adv_temperature: float | None
        :param margin: gamma, the margin.
        :type margin: float
        """

        super(MarginLoss, self).__init__()

        #: gamma (non-trainable parameter)
        self.margin: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor([margin]))
        self.margin.requires_grad = False

        #: whether the self-adversarial negative sampling proposed by RotatE is enabled
        self.adv_flag: bool = adv_temperature is not None
        if self.adv_flag:
            #: temperature of the self-adversarial negative sampling (non-trainable parameter)
            self.adv_temperature: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor([adv_temperature]))
            self.adv_temperature.requires_grad = False

    def get_weights(
        self,
        n_score: torch.Tensor) -> torch.Tensor:

        """Compute the sampling distribution over negative samples used by self-adversarial negative sampling.

        The negative of ``n_score`` is used, so negatives with a lower score get a
        larger weight. The result is detached, so no gradient flows through the weights.

        :param n_score: scores of the negative samples
        :type n_score: torch.Tensor
        :returns: softmax weights over the last dimension of ``n_score``
        :rtype: torch.Tensor
        """

        return F.softmax(-n_score * self.adv_temperature, dim = -1).detach()

    def forward(
        self,
        p_score: torch.Tensor,
        n_score: torch.Tensor) -> torch.Tensor:

        """Compute the margin-based ranking criterion. This defines the computation performed at every call.
        Subclasses of :py:class:`torch.nn.Module` must override :py:meth:`torch.nn.Module.forward`.

        :param p_score: scores of the positive samples
        :type p_score: torch.Tensor
        :param n_score: scores of the negative samples
        :type n_score: torch.Tensor
        :returns: loss value
        :rtype: torch.Tensor
        """

        # max(p - n, -margin) + margin == max(0, p - n + margin). In the
        # adversarial branch, the softmax weights sum to 1 along dim -1, so
        # adding the margin once after the weighted sum is equivalent.
        if self.adv_flag:
            return (self.get_weights(n_score) * torch.max(p_score - n_score,
                -self.margin)).sum(dim = -1).mean() + self.margin
        return (torch.max(p_score - n_score, -self.margin)).mean() + self.margin
113
+
114
def get_margin_loss_hpo_config() -> dict[str, dict[str, Any]]:

    """Return the default hyper-parameter optimization config for :py:class:`MarginLoss`.

    The default configuration is::

        parameters_dict = {
            'loss': {
                'value': 'MarginLoss'
            },
            'adv_temperature': {
                'values': [1.0, 3.0, 6.0]
            },
            'margin': {
                'values': [1.0, 3.0, 6.0]
            }
        }

    Every call builds fresh (unshared) lists, so callers may modify the result.

    :returns: default hyper-parameter optimization config of :py:class:`MarginLoss`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    grid = (1.0, 3.0, 6.0)
    config = {'loss': {'value': 'MarginLoss'}}
    config['adv_temperature'] = {'values': list(grid)}
    config['margin'] = {'values': list(grid)}
    return config
@@ -0,0 +1,117 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/module/loss/RGCNLoss.py
4
+ #
5
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 17, 2023
7
+ #
8
+ # 该脚本定义了 RGCNLoss 类.
9
+
10
+ """
11
+ RGCNLoss - 损失函数类,R-GCN 原论文中应用这种损失函数完成模型学习。
12
+ """
13
+
14
+ import torch
15
+ from typing import Any
16
+ from ..model import RGCN
17
+ import torch.nn.functional as F
18
+ from .Loss import Loss
19
+
20
class RGCNLoss(Loss):

    """Loss function used to train ``R-GCN`` :cite:`R-GCN` in the original paper.

    .. Note:: In :py:meth:`forward`, scores of positive samples should be higher than scores of negative samples.

    Example::

        from unike.module.loss import RGCNLoss
        from unike.module.strategy import RGCNSampling

        # define the loss function
        model = RGCNSampling(
            model = rgcn,
            loss = RGCNLoss(model = rgcn, regularization = 1e-5)
        )
    """

    def __init__(
        self,
        model: RGCN,
        regularization: float):

        """Create an RGCNLoss object.

        :param model: the model
        :type model: RGCN
        :param regularization: regularization rate
        :type regularization: float
        """

        super(RGCNLoss, self).__init__()

        #: the model; :py:meth:`reg_loss` reads its ``Loss_emb`` and ``rel_emb`` attributes
        self.model: RGCN = model
        #: regularization rate
        self.regularization: float = regularization

    def reg_loss(self) -> torch.Tensor:

        """Return the regularization term.

        The term is the mean of squares of ``model.Loss_emb`` plus the mean of
        squares of ``model.rel_emb``.

        :returns: regularization loss value
        :rtype: torch.Tensor
        """

        # NOTE(review): Loss_emb presumably holds embeddings cached by the
        # model's forward pass; confirm against RGCN.
        ent_term = self.model.Loss_emb.pow(2).mean()
        rel_term = self.model.rel_emb.pow(2).mean()
        return ent_term + rel_term

    def forward(
        self,
        score: torch.Tensor,
        labels: torch.Tensor) -> torch.Tensor:

        """Compute the RGCNLoss value. This defines the computation performed at every call.
        Subclasses of :py:class:`torch.nn.Module` must override :py:meth:`torch.nn.Module.forward`.

        The loss is binary cross-entropy on raw logits plus ``regularization * reg_loss()``.

        :param score: model scores (logits)
        :type score: torch.Tensor
        :param labels: target labels
        :type labels: torch.Tensor
        :returns: loss value
        :rtype: torch.Tensor
        """

        bce = F.binary_cross_entropy_with_logits(score, labels)
        penalty = self.regularization * self.reg_loss()
        return bce + penalty
88
+
89
def get_rgcn_loss_hpo_config() -> dict[str, dict[str, Any]]:

    """Return the default hyper-parameter optimization config for :py:class:`RGCNLoss`.

    The default configuration is::

        parameters_dict = {
            'loss': {
                'value': 'RGCNLoss'
            },
            'regularization': {
                'value': 1e-5
            }
        }

    :returns: default hyper-parameter optimization config of :py:class:`RGCNLoss`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    return dict(
        loss={'value': 'RGCNLoss'},
        regularization={'value': 1e-5},
    )
@@ -0,0 +1,145 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/module/loss/SigmoidLoss.py
4
+ #
5
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 11, 2023
7
+ #
8
+ # 该脚本定义了 regularized logistic loss 损失函数.
9
+
10
+ """
11
+ SigmoidLoss - 损失函数类,RotatE 原论文应用这种损失函数完成模型学习。
12
+ """
13
+
14
+ import torch
15
+ import numpy as np
16
+ import torch.nn as nn
17
+ from typing import Any
18
+ import torch.nn.functional as F
19
+ from .Loss import Loss
20
+
21
class SigmoidLoss(Loss):

    """Regularized logistic loss used to train ``RotatE`` :cite:`RotatE` in the original paper.

    .. Note:: In :py:meth:`forward`, scores of positive samples should be higher than scores of negative samples.

    Example::

        from unike.module.loss import SigmoidLoss
        from unike.module.strategy import NegativeSampling

        # define the loss function
        model = NegativeSampling(
            model = rotate,
            loss = SigmoidLoss(adv_temperature = 2),
            batch_size = train_dataloader.get_batch_size(),
            regul_rate = 0.0
        )
    """

    def __init__(
        self,
        adv_temperature: float | None = None):

        """Create a SigmoidLoss object.

        :param adv_temperature: temperature of the self-adversarial negative sampling
            proposed by RotatE. ``None`` disables self-adversarial weighting.
        :type adv_temperature: float | None
        """

        super(SigmoidLoss, self).__init__()
        #: logistic function, of type :py:class:`torch.nn.LogSigmoid`.
        self.criterion: torch.nn.LogSigmoid = nn.LogSigmoid()

        #: whether the self-adversarial negative sampling proposed by RotatE is enabled
        self.adv_flag: bool = adv_temperature is not None
        if self.adv_flag:
            #: temperature of the self-adversarial negative sampling (non-trainable parameter)
            self.adv_temperature: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor([adv_temperature]))
            self.adv_temperature.requires_grad = False

    def get_weights(
        self,
        n_score: torch.Tensor) -> torch.Tensor:

        """Compute the sampling distribution over negative samples used by self-adversarial negative sampling.

        Negatives with a higher score get a larger weight. The result is
        detached, so no gradient flows through the weights.

        :param n_score: scores of the negative samples
        :type n_score: torch.Tensor
        :returns: softmax weights over the last dimension of ``n_score``
        :rtype: torch.Tensor
        """

        return F.softmax(n_score * self.adv_temperature, dim = -1).detach()

    def forward(
        self,
        p_score: torch.Tensor,
        n_score: torch.Tensor) -> torch.Tensor:

        """Compute the SigmoidLoss value. This defines the computation performed at every call.
        Subclasses of :py:class:`torch.nn.Module` must override :py:meth:`torch.nn.Module.forward`.

        The loss is ``-(mean(log σ(p)) + [weighted] mean(log σ(-n))) / 2``.

        :param p_score: scores of the positive samples
        :type p_score: torch.Tensor
        :param n_score: scores of the negative samples
        :type n_score: torch.Tensor
        :returns: loss value
        :rtype: torch.Tensor
        """

        positive_term = self.criterion(p_score).mean()
        if self.adv_flag:
            negative_term = (self.get_weights(n_score) * self.criterion(-n_score)).sum(dim = -1).mean()
        else:
            negative_term = self.criterion(-n_score).mean()
        return -(positive_term + negative_term) / 2

    def predict(
        self,
        p_score: torch.Tensor,
        n_score: torch.Tensor) -> np.ndarray:

        """Inference helper: return the value of :py:meth:`forward` as a numpy array.

        :param p_score: scores of the positive samples
        :type p_score: torch.Tensor
        :param n_score: scores of the negative samples
        :type n_score: torch.Tensor
        :returns: loss value (a 0-d array, since :py:meth:`forward` returns a scalar)
        :rtype: numpy.ndarray
        """

        score = self.forward(p_score, n_score)
        return score.detach().cpu().numpy()
116
+
117
def get_sigmoid_loss_hpo_config() -> dict[str, dict[str, Any]]:

    """Return the default hyper-parameter optimization config for :py:class:`SigmoidLoss`.

    The default configuration is::

        parameters_dict = {
            'loss': {
                'value': 'SigmoidLoss'
            },
            'adv_temperature': {
                'values': [1.0, 3.0, 6.0]
            }
        }

    :returns: default hyper-parameter optimization config of :py:class:`SigmoidLoss`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    temperatures = [1.0, 3.0, 6.0]
    return {
        'loss': {'value': 'SigmoidLoss'},
        'adv_temperature': {'values': temperatures},
    }