unike 3.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unike/__init__.py +5 -0
- unike/config/HPOTrainer.py +305 -0
- unike/config/Tester.py +385 -0
- unike/config/Trainer.py +519 -0
- unike/config/TrainerAccelerator.py +39 -0
- unike/config/__init__.py +37 -0
- unike/data/BernSampler.py +168 -0
- unike/data/CompGCNSampler.py +140 -0
- unike/data/CompGCNTestSampler.py +84 -0
- unike/data/KGEDataLoader.py +315 -0
- unike/data/KGReader.py +138 -0
- unike/data/RGCNSampler.py +261 -0
- unike/data/RGCNTestSampler.py +208 -0
- unike/data/RevSampler.py +78 -0
- unike/data/TestSampler.py +189 -0
- unike/data/TradSampler.py +122 -0
- unike/data/TradTestSampler.py +87 -0
- unike/data/UniSampler.py +145 -0
- unike/data/__init__.py +47 -0
- unike/module/BaseModule.py +130 -0
- unike/module/__init__.py +20 -0
- unike/module/loss/CompGCNLoss.py +96 -0
- unike/module/loss/Loss.py +26 -0
- unike/module/loss/MarginLoss.py +148 -0
- unike/module/loss/RGCNLoss.py +117 -0
- unike/module/loss/SigmoidLoss.py +145 -0
- unike/module/loss/SoftplusLoss.py +145 -0
- unike/module/loss/__init__.py +35 -0
- unike/module/model/Analogy.py +237 -0
- unike/module/model/CompGCN.py +562 -0
- unike/module/model/ComplEx.py +235 -0
- unike/module/model/DistMult.py +276 -0
- unike/module/model/HolE.py +308 -0
- unike/module/model/Model.py +107 -0
- unike/module/model/RESCAL.py +309 -0
- unike/module/model/RGCN.py +304 -0
- unike/module/model/RotatE.py +303 -0
- unike/module/model/SimplE.py +237 -0
- unike/module/model/TransD.py +458 -0
- unike/module/model/TransE.py +290 -0
- unike/module/model/TransH.py +322 -0
- unike/module/model/TransR.py +402 -0
- unike/module/model/__init__.py +60 -0
- unike/module/strategy/CompGCNSampling.py +140 -0
- unike/module/strategy/NegativeSampling.py +138 -0
- unike/module/strategy/RGCNSampling.py +134 -0
- unike/module/strategy/Strategy.py +26 -0
- unike/module/strategy/__init__.py +29 -0
- unike/utils/EarlyStopping.py +94 -0
- unike/utils/Timer.py +74 -0
- unike/utils/WandbLogger.py +46 -0
- unike/utils/__init__.py +26 -0
- unike/utils/tools.py +118 -0
- unike/version.py +1 -0
- unike-3.0.1.dist-info/METADATA +101 -0
- unike-3.0.1.dist-info/RECORD +59 -0
- unike-3.0.1.dist-info/WHEEL +4 -0
- unike-3.0.1.dist-info/entry_points.txt +2 -0
- unike-3.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,130 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/module/BaseModule.py
|
4
|
+
#
|
5
|
+
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 4, 2023
|
7
|
+
#
|
8
|
+
# 该头文件定义了 BaseModule.
|
9
|
+
|
10
|
+
"""BaseModule - 所有模块的基类"""
|
11
|
+
|
12
|
+
import os
|
13
|
+
import json
|
14
|
+
import torch
|
15
|
+
import torch.nn as nn
|
16
|
+
import numpy as np
|
17
|
+
from typing import Any
|
18
|
+
|
19
|
+
class BaseModule(nn.Module):
|
20
|
+
|
21
|
+
"""继承自 :py:class:`torch.nn.Module`,并且封装了一些常用功能,如加载和保存模型。"""
|
22
|
+
|
23
|
+
def __init__(self):
|
24
|
+
|
25
|
+
"""创建 BaseModule 对象。"""
|
26
|
+
|
27
|
+
super(BaseModule, self).__init__()
|
28
|
+
|
29
|
+
#: 常数 0
|
30
|
+
self.zero_const: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor([0]))
|
31
|
+
self.zero_const.requires_grad = False
|
32
|
+
|
33
|
+
#: 常数 pi
|
34
|
+
self.pi_const: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor([3.14159265358979323846]))
|
35
|
+
self.pi_const.requires_grad = False
|
36
|
+
|
37
|
+
def load_checkpoint(self, path: str):
|
38
|
+
|
39
|
+
"""加载模型权重。
|
40
|
+
|
41
|
+
:param path: 模型保存的路径
|
42
|
+
:type path: str
|
43
|
+
"""
|
44
|
+
|
45
|
+
self.load_state_dict(torch.load(os.path.join(path)))
|
46
|
+
self.eval()
|
47
|
+
|
48
|
+
def save_checkpoint(self, path: str):
|
49
|
+
|
50
|
+
"""保存模型权重。
|
51
|
+
|
52
|
+
:param path: 模型保存的路径
|
53
|
+
:type path: str
|
54
|
+
"""
|
55
|
+
|
56
|
+
if not os.path.exists(os.path.split(path)[0]):
|
57
|
+
os.makedirs(os.path.split(path)[0], exist_ok=True)
|
58
|
+
torch.save(self.state_dict(), path)
|
59
|
+
|
60
|
+
def get_parameters(
|
61
|
+
self,
|
62
|
+
mode: str = "numpy",
|
63
|
+
param_dict: dict[str, Any] | None = None
|
64
|
+
) -> dict[str, np.ndarray] | dict[str, list] | dict[str, torch.Tensor]:
|
65
|
+
|
66
|
+
"""获得模型权重。
|
67
|
+
|
68
|
+
:param mode: 模型保存的格式,可以选择 ``numpy`` 、 ``list`` 和 ``Tensor`` 。
|
69
|
+
:type path: str
|
70
|
+
:param param_dict: 可以选择从哪里获得模型权重。
|
71
|
+
:type param_dict: dict[str, typing.Any] | None
|
72
|
+
:returns: 模型权重字典。
|
73
|
+
:rtype: dict[str, numpy.ndarray] | dict[str, list] | dict[str, torch.Tensor]
|
74
|
+
"""
|
75
|
+
|
76
|
+
all_param_dict = self.state_dict()
|
77
|
+
if param_dict == None:
|
78
|
+
param_dict = all_param_dict.keys()
|
79
|
+
res = {}
|
80
|
+
for param in param_dict:
|
81
|
+
if mode == "numpy":
|
82
|
+
res[param] = all_param_dict[param].cpu().numpy()
|
83
|
+
elif mode == "list":
|
84
|
+
res[param] = all_param_dict[param].cpu().numpy().tolist()
|
85
|
+
else:
|
86
|
+
res[param] = all_param_dict[param]
|
87
|
+
return res
|
88
|
+
|
89
|
+
def set_parameters(self, parameters: dict[str, Any]):
|
90
|
+
|
91
|
+
"""加载模型权重。
|
92
|
+
|
93
|
+
:param parameters: 模型权重字典。
|
94
|
+
:type parameters: dict[str, typing.Any]
|
95
|
+
"""
|
96
|
+
|
97
|
+
for i in parameters:
|
98
|
+
parameters[i] = torch.Tensor(parameters[i])
|
99
|
+
self.load_state_dict(parameters, strict = False)
|
100
|
+
self.eval()
|
101
|
+
|
102
|
+
def load_parameters(self, path: str):
|
103
|
+
|
104
|
+
"""加载模型权重。
|
105
|
+
|
106
|
+
:param path: 模型保存的路径
|
107
|
+
:type path: str
|
108
|
+
"""
|
109
|
+
|
110
|
+
f = open(path, "r")
|
111
|
+
parameters = json.loads(f.read())
|
112
|
+
f.close()
|
113
|
+
for i in parameters:
|
114
|
+
parameters[i] = torch.Tensor(parameters[i])
|
115
|
+
self.load_state_dict(parameters, strict = False)
|
116
|
+
self.eval()
|
117
|
+
|
118
|
+
def save_parameters(self, path: str):
|
119
|
+
|
120
|
+
"""用 json 格式保存模型权重。
|
121
|
+
|
122
|
+
:param path: 模型保存的路径
|
123
|
+
:type path: str
|
124
|
+
"""
|
125
|
+
|
126
|
+
if not os.path.exists(os.path.split(path)[0]):
|
127
|
+
os.makedirs(os.path.split(path)[0], exist_ok=True)
|
128
|
+
f = open(path, "w")
|
129
|
+
f.write(json.dumps(self.get_parameters("list")))
|
130
|
+
f.close()
|
unike/module/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/module/__init__.py
|
4
|
+
#
|
5
|
+
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 28, 2023
|
7
|
+
#
|
8
|
+
# 该头文件定义了 module 接口.
|
9
|
+
|
10
|
+
"""模块部分,包含模型和损失函数。"""
|
11
|
+
|
12
|
+
from __future__ import absolute_import
|
13
|
+
from __future__ import division
|
14
|
+
from __future__ import print_function
|
15
|
+
|
16
|
+
from .BaseModule import BaseModule
|
17
|
+
|
18
|
+
# Public names of the package. The list has to be bound to ``__all__``;
# binding it to ``__init__`` has no effect on ``from ... import *``.
__all__ = [
    'BaseModule'
]
|
@@ -0,0 +1,96 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/module/loss/CompGCNLoss.py
|
4
|
+
#
|
5
|
+
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 19, 2024
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 23, 2024
|
7
|
+
#
|
8
|
+
# 该脚本定义了 CompGCNLoss 类.
|
9
|
+
|
10
|
+
"""
|
11
|
+
CompGCNLoss - 损失函数类,CompGCN 原论文中应用这种损失函数完成模型学习。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import torch
|
15
|
+
from .Loss import Loss
|
16
|
+
from typing import Any
|
17
|
+
from ..model import CompGCN
|
18
|
+
|
19
|
+
class CompGCNLoss(Loss):
    """Loss used in the original ``CompGCN`` :cite:`CompGCN` paper to train the model.

    .. Note:: In :py:meth:`forward`, positive samples are expected to score higher than negative samples.

    Example::

        from unike.module.loss import CompGCNLoss
        from unike.module.strategy import CompGCNSampling

        # define the loss function
        model = CompGCNSampling(
            model = compgcn,
            loss = CompGCNLoss(model = compgcn),
            ent_tol = dataloader.get_ent_tol()
        )
    """

    def __init__(self, model: CompGCN):
        """Create a CompGCNLoss object.

        :param model: The model this loss is attached to.
        :type model: CompGCN
        """
        super().__init__()

        #: The model; stored here but not read by :py:meth:`forward`.
        self.model: CompGCN = model
        #: The criterion applied in :py:meth:`forward`.
        self.loss: torch.nn.BCELoss = torch.nn.BCELoss()

    def forward(self, pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        """Compute the loss by applying :py:class:`torch.nn.BCELoss` to the scores.

        :param pred: Scores produced by the model.
            NOTE(review): presumably already probabilities in [0, 1] - confirm against the model.
        :type pred: torch.Tensor
        :param label: Target labels.
        :type label: torch.Tensor
        :returns: The loss value.
        :rtype: torch.Tensor
        """
        return self.loss(pred, label)
|
73
|
+
|
74
|
+
def get_compgcn_loss_hpo_config() -> dict[str, dict[str, Any]]:
    """Return the default hyper-parameter optimization config of :py:class:`CompGCNLoss`.

    The default config is::

        parameters_dict = {
            'loss': {
                'value': 'CompGCNLoss'
            }
        }

    :returns: A newly built default config on every call.
    :rtype: dict[str, dict[str, typing.Any]]
    """
    loss_entry = {'value': 'CompGCNLoss'}
    return {'loss': loss_entry}
|
@@ -0,0 +1,26 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/module/loss/Loss.py
|
4
|
+
#
|
5
|
+
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 4, 2023
|
7
|
+
#
|
8
|
+
# 该脚本定义了损失函数的基类.
|
9
|
+
|
10
|
+
"""
|
11
|
+
Loss - 该脚本定义了损失函数的基类。
|
12
|
+
"""
|
13
|
+
|
14
|
+
from ..BaseModule import BaseModule
|
15
|
+
|
16
|
+
class Loss(BaseModule):
    """Base class of the loss functions.

    Inherits from :py:class:`unike.module.BaseModule` and adds no attribute
    of its own.
    """

    def __init__(self):
        """Create a Loss object."""
        super().__init__()
|
@@ -0,0 +1,148 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/module/loss/MarginLoss.py
|
4
|
+
#
|
5
|
+
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 9, 2024
|
7
|
+
#
|
8
|
+
# 该脚本定义了 margin-based ranking criterion 损失函数.
|
9
|
+
|
10
|
+
"""
|
11
|
+
MarginLoss - 损失函数类,TransE 原论文中应用这种损失函数完成模型学习。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import torch
|
15
|
+
import numpy as np
|
16
|
+
from typing import Any
|
17
|
+
import torch.nn as nn
|
18
|
+
import torch.nn.functional as F
|
19
|
+
from .Loss import Loss
|
20
|
+
|
21
|
+
class MarginLoss(Loss):
    """Margin-based ranking loss used in the original ``TransE`` :cite:`TransE` paper.

    .. Note:: In :py:meth:`forward`, positive samples are expected to score lower than negative samples.

    Example::

        from unike.module.model import TransE
        from unike.module.loss import MarginLoss
        from unike.module.strategy import NegativeSampling

        # define the model
        transe = TransE(
            ent_tol = dataloader.get_ent_tol(),
            rel_tol = dataloader.get_rel_tol(),
            dim = 50,
            p_norm = 1,
            norm_flag = True
        )

        # define the loss function
        model = NegativeSampling(
            model = transe,
            loss = MarginLoss(margin = 1.0),
            regul_rate = 0.01
        )
    """

    def __init__(
        self,
        adv_temperature: float | None = None,
        margin: float = 6.0):
        """Create a MarginLoss object.

        :param adv_temperature: Temperature of the self-adversarial negative
            sampling proposed by RotatE; ``None`` disables that sampling.
        :type adv_temperature: float | None
        :param margin: The margin (gamma).
        :type margin: float
        """
        super(MarginLoss, self).__init__()

        #: The margin (gamma), registered as a parameter with gradients disabled.
        self.margin: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor([margin]))
        self.margin.requires_grad = False

        #: Whether the self-adversarial negative sampling of RotatE is enabled.
        # Identity check: ``!= None`` would go through the argument's ``__ne__``.
        self.adv_flag: bool = adv_temperature is not None
        if self.adv_flag:
            #: Temperature of the self-adversarial negative sampling;
            #: the attribute only exists when the sampling is enabled.
            self.adv_temperature: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor([adv_temperature]))
            self.adv_temperature.requires_grad = False

    def get_weights(
        self,
        n_score: torch.Tensor) -> torch.Tensor:
        """Compute the self-adversarial sampling weights of the negative samples.

        The weights are a softmax over the last dimension of
        ``-n_score * adv_temperature`` and are detached from the graph.
        Only usable when ``adv_temperature`` was given to the constructor.

        :param n_score: Scores of the negative samples.
        :type n_score: torch.Tensor
        :returns: Weights of the negative samples.
        :rtype: torch.Tensor
        """
        return F.softmax(-n_score * self.adv_temperature, dim = -1).detach()

    def forward(
        self,
        p_score: torch.Tensor,
        n_score: torch.Tensor) -> torch.Tensor:
        """Compute the margin-based ranking loss.

        ``max(p - n, -margin) + margin`` equals ``max(p - n + margin, 0)``,
        the usual hinge form. With self-adversarial sampling the hinge terms
        are weighted by :py:meth:`get_weights` and summed over the last
        dimension before averaging; otherwise they are averaged directly.

        :param p_score: Scores of the positive samples.
        :type p_score: torch.Tensor
        :param n_score: Scores of the negative samples.
        :type n_score: torch.Tensor
        :returns: The loss value.
        :rtype: torch.Tensor
        """
        hinge = torch.max(p_score - n_score, -self.margin)
        if self.adv_flag:
            # The softmax weights sum to 1 over the last dimension, so adding
            # the margin once after the weighted sum is equivalent.
            return (self.get_weights(n_score) * hinge).sum(dim = -1).mean() + self.margin
        return hinge.mean() + self.margin
|
113
|
+
|
114
|
+
def get_margin_loss_hpo_config() -> dict[str, dict[str, Any]]:
    """Return the default hyper-parameter optimization config of :py:class:`MarginLoss`.

    The default config is::

        parameters_dict = {
            'loss': {
                'value': 'MarginLoss'
            },
            'adv_temperature': {
                'values': [1.0, 3.0, 6.0]
            },
            'margin': {
                'values': [1.0, 3.0, 6.0]
            }
        }

    :returns: A newly built default config on every call.
    :rtype: dict[str, dict[str, typing.Any]]
    """
    candidates = (1.0, 3.0, 6.0)
    # Each entry gets its own list so the two stay independent objects.
    return {
        'loss': {'value': 'MarginLoss'},
        'adv_temperature': {'values': list(candidates)},
        'margin': {'values': list(candidates)},
    }
|
@@ -0,0 +1,117 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/module/loss/RGCNLoss.py
|
4
|
+
#
|
5
|
+
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 17, 2023
|
7
|
+
#
|
8
|
+
# 该脚本定义了 RGCNLoss 类.
|
9
|
+
|
10
|
+
"""
|
11
|
+
RGCNLoss - 损失函数类,R-GCN 原论文中应用这种损失函数完成模型学习。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import torch
|
15
|
+
from typing import Any
|
16
|
+
from ..model import RGCN
|
17
|
+
import torch.nn.functional as F
|
18
|
+
from .Loss import Loss
|
19
|
+
|
20
|
+
class RGCNLoss(Loss):
    """Loss used in the original ``R-GCN`` :cite:`R-GCN` paper to train the model.

    .. Note:: In :py:meth:`forward`, positive samples are expected to score higher than negative samples.

    Example::

        from unike.module.loss import RGCNLoss
        from unike.module.strategy import RGCNSampling

        # define the loss function
        model = RGCNSampling(
            model = rgcn,
            loss = RGCNLoss(model = rgcn, regularization = 1e-5)
        )
    """

    def __init__(self, model: RGCN, regularization: float):
        """Create an RGCNLoss object.

        :param model: The model whose embeddings are regularized.
        :type model: RGCN
        :param regularization: Weight of the regularization term.
        :type regularization: float
        """
        super().__init__()

        #: The model; :py:meth:`reg_loss` reads its ``Loss_emb`` and ``rel_emb``.
        self.model: RGCN = model
        #: Weight of the regularization term.
        self.regularization: float = regularization

    def reg_loss(self) -> torch.Tensor:
        """Return the regularization part of the loss.

        It is the mean of the squared entries of ``model.Loss_emb`` plus the
        mean of the squared entries of ``model.rel_emb``.
        NOTE(review): ``Loss_emb`` is presumably filled in by the model's
        forward pass - confirm in the RGCN model.

        :returns: The regularization value.
        :rtype: torch.Tensor
        """
        emb_term = torch.mean(self.model.Loss_emb.pow(2))
        rel_term = torch.mean(self.model.rel_emb.pow(2))
        return emb_term + rel_term

    def forward(self, score: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Compute the loss: binary cross-entropy on logits plus the weighted regularization.

        :param score: Scores produced by the model, passed as logits.
        :type score: torch.Tensor
        :param labels: Target labels.
        :type labels: torch.Tensor
        :returns: The loss value.
        :rtype: torch.Tensor
        """
        total = F.binary_cross_entropy_with_logits(score, labels)
        # Added in place, as before, so the result keeps the tensor of the BCE term.
        total += self.regularization * self.reg_loss()
        return total
|
88
|
+
|
89
|
+
def get_rgcn_loss_hpo_config() -> dict[str, dict[str, Any]]:
    """Return the default hyper-parameter optimization config of :py:class:`RGCNLoss`.

    The default config is::

        parameters_dict = {
            'loss': {
                'value': 'RGCNLoss'
            },
            'regularization': {
                'value': 1e-5
            }
        }

    :returns: A newly built default config on every call.
    :rtype: dict[str, dict[str, typing.Any]]
    """
    loss_entry = {'value': 'RGCNLoss'}
    regularization_entry = {'value': 1e-5}
    return {'loss': loss_entry, 'regularization': regularization_entry}
|
@@ -0,0 +1,145 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/module/loss/SigmoidLoss.py
|
4
|
+
#
|
5
|
+
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 11, 2023
|
7
|
+
#
|
8
|
+
# 该脚本定义了 regularized logistic loss 损失函数.
|
9
|
+
|
10
|
+
"""
|
11
|
+
SigmoidLoss - 损失函数类,RotatE 原论文应用这种损失函数完成模型学习。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import torch
|
15
|
+
import numpy as np
|
16
|
+
import torch.nn as nn
|
17
|
+
from typing import Any
|
18
|
+
import torch.nn.functional as F
|
19
|
+
from .Loss import Loss
|
20
|
+
|
21
|
+
class SigmoidLoss(Loss):
    """Logistic loss used in the original ``RotatE`` :cite:`RotatE` paper to train the model.

    .. Note:: In :py:meth:`forward`, positive samples are expected to score higher than negative samples.

    Example::

        from unike.module.loss import SigmoidLoss
        from unike.module.strategy import NegativeSampling

        # define the loss function
        model = NegativeSampling(
            model = rotate,
            loss = SigmoidLoss(adv_temperature = 2),
            batch_size = train_dataloader.get_batch_size(),
            regul_rate = 0.0
        )
    """

    def __init__(
        self,
        adv_temperature: float | None = None):
        """Create a SigmoidLoss object.

        :param adv_temperature: Temperature of the self-adversarial negative
            sampling proposed by RotatE; ``None`` disables that sampling.
        :type adv_temperature: float | None
        """
        super(SigmoidLoss, self).__init__()

        #: Log-sigmoid function, of type :py:class:`torch.nn.LogSigmoid`.
        self.criterion: torch.nn.LogSigmoid = nn.LogSigmoid()

        #: Whether the self-adversarial negative sampling of RotatE is enabled.
        # Identity check: ``!= None`` would go through the argument's ``__ne__``.
        self.adv_flag: bool = adv_temperature is not None
        if self.adv_flag:
            #: Temperature of the self-adversarial negative sampling;
            #: the attribute only exists when the sampling is enabled.
            self.adv_temperature: torch.nn.parameter.Parameter = nn.Parameter(torch.Tensor([adv_temperature]))
            self.adv_temperature.requires_grad = False

    def get_weights(
        self,
        n_score: torch.Tensor) -> torch.Tensor:
        """Compute the self-adversarial sampling weights of the negative samples.

        The weights are a softmax over the last dimension of
        ``n_score * adv_temperature`` and are detached from the graph.
        Only usable when ``adv_temperature`` was given to the constructor.

        :param n_score: Scores of the negative samples.
        :type n_score: torch.Tensor
        :returns: Weights of the negative samples.
        :rtype: torch.Tensor
        """
        return F.softmax(n_score * self.adv_temperature, dim = -1).detach()

    def forward(
        self,
        p_score: torch.Tensor,
        n_score: torch.Tensor) -> torch.Tensor:
        """Compute the sigmoid loss.

        The result is the negated average of two terms: the mean log-sigmoid
        of the positive scores and the log-sigmoid of the negated negative
        scores. With self-adversarial sampling the negative term is weighted
        by :py:meth:`get_weights` and summed over the last dimension before
        averaging; otherwise it is averaged directly.

        :param p_score: Scores of the positive samples.
        :type p_score: torch.Tensor
        :param n_score: Scores of the negative samples.
        :type n_score: torch.Tensor
        :returns: The loss value.
        :rtype: torch.Tensor
        """
        positive_term = self.criterion(p_score).mean()
        if self.adv_flag:
            negative_term = (self.get_weights(n_score) * self.criterion(-n_score)).sum(dim = -1).mean()
        else:
            negative_term = self.criterion(-n_score).mean()
        return -(positive_term + negative_term) / 2

    def predict(
        self,
        p_score: torch.Tensor,
        n_score: torch.Tensor) -> np.ndarray:
        """Compute the loss and return it as a NumPy array on the CPU.

        :param p_score: Scores of the positive samples.
        :type p_score: torch.Tensor
        :param n_score: Scores of the negative samples.
        :type n_score: torch.Tensor
        :returns: The loss value.
        :rtype: numpy.ndarray
        """
        score = self.forward(p_score, n_score)
        return score.cpu().data.numpy()
|
116
|
+
|
117
|
+
def get_sigmoid_loss_hpo_config() -> dict[str, dict[str, Any]]:
    """Return the default hyper-parameter optimization config of :py:class:`SigmoidLoss`.

    The default config is::

        parameters_dict = {
            'loss': {
                'value': 'SigmoidLoss'
            },
            'adv_temperature': {
                'values': [1.0, 3.0, 6.0]
            }
        }

    :returns: A newly built default config on every call.
    :rtype: dict[str, dict[str, typing.Any]]
    """
    loss_entry = {'value': 'SigmoidLoss'}
    temperature_entry = {'values': [1.0, 3.0, 6.0]}
    return {'loss': loss_entry, 'adv_temperature': temperature_entry}
|