unike 3.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unike/__init__.py +5 -0
- unike/config/HPOTrainer.py +305 -0
- unike/config/Tester.py +385 -0
- unike/config/Trainer.py +519 -0
- unike/config/TrainerAccelerator.py +39 -0
- unike/config/__init__.py +37 -0
- unike/data/BernSampler.py +168 -0
- unike/data/CompGCNSampler.py +140 -0
- unike/data/CompGCNTestSampler.py +84 -0
- unike/data/KGEDataLoader.py +315 -0
- unike/data/KGReader.py +138 -0
- unike/data/RGCNSampler.py +261 -0
- unike/data/RGCNTestSampler.py +208 -0
- unike/data/RevSampler.py +78 -0
- unike/data/TestSampler.py +189 -0
- unike/data/TradSampler.py +122 -0
- unike/data/TradTestSampler.py +87 -0
- unike/data/UniSampler.py +145 -0
- unike/data/__init__.py +47 -0
- unike/module/BaseModule.py +130 -0
- unike/module/__init__.py +20 -0
- unike/module/loss/CompGCNLoss.py +96 -0
- unike/module/loss/Loss.py +26 -0
- unike/module/loss/MarginLoss.py +148 -0
- unike/module/loss/RGCNLoss.py +117 -0
- unike/module/loss/SigmoidLoss.py +145 -0
- unike/module/loss/SoftplusLoss.py +145 -0
- unike/module/loss/__init__.py +35 -0
- unike/module/model/Analogy.py +237 -0
- unike/module/model/CompGCN.py +562 -0
- unike/module/model/ComplEx.py +235 -0
- unike/module/model/DistMult.py +276 -0
- unike/module/model/HolE.py +308 -0
- unike/module/model/Model.py +107 -0
- unike/module/model/RESCAL.py +309 -0
- unike/module/model/RGCN.py +304 -0
- unike/module/model/RotatE.py +303 -0
- unike/module/model/SimplE.py +237 -0
- unike/module/model/TransD.py +458 -0
- unike/module/model/TransE.py +290 -0
- unike/module/model/TransH.py +322 -0
- unike/module/model/TransR.py +402 -0
- unike/module/model/__init__.py +60 -0
- unike/module/strategy/CompGCNSampling.py +140 -0
- unike/module/strategy/NegativeSampling.py +138 -0
- unike/module/strategy/RGCNSampling.py +134 -0
- unike/module/strategy/Strategy.py +26 -0
- unike/module/strategy/__init__.py +29 -0
- unike/utils/EarlyStopping.py +94 -0
- unike/utils/Timer.py +74 -0
- unike/utils/WandbLogger.py +46 -0
- unike/utils/__init__.py +26 -0
- unike/utils/tools.py +118 -0
- unike/version.py +1 -0
- unike-3.0.1.dist-info/METADATA +101 -0
- unike-3.0.1.dist-info/RECORD +59 -0
- unike-3.0.1.dist-info/WHEEL +4 -0
- unike-3.0.1.dist-info/entry_points.txt +2 -0
- unike-3.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,138 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/module/strategy/NegativeSampling.py
|
4
|
+
#
|
5
|
+
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 9, 2024
|
7
|
+
#
|
8
|
+
# 该脚本定义了平移模型和语义匹配模型的训练策略.
|
9
|
+
|
10
|
+
"""
|
11
|
+
NegativeSampling - 训练策略类,包含损失函数。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import torch
|
15
|
+
import typing
|
16
|
+
from ..loss import Loss
|
17
|
+
from ..model import Model
|
18
|
+
from .Strategy import Strategy
|
19
|
+
|
20
|
+
class NegativeSampling(Strategy):

    """
    Training strategy that wraps a KGE model together with its loss function,
    used for translational and semantic-matching models.

    Example::

        from unike.module.model import TransE
        from unike.module.loss import MarginLoss
        from unike.module.strategy import NegativeSampling

        # define the model
        transe = TransE(
            ent_tol = dataloader.get_ent_tol(),
            rel_tol = dataloader.get_rel_tol(),
            dim = 50,
            p_norm = 1,
            norm_flag = True
        )

        # define the loss function
        model = NegativeSampling(
            model = transe,
            loss = MarginLoss(margin = 1.0),
            regul_rate = 0.01
        )
    """

    def __init__(
        self,
        model: Model = None,
        loss: Loss = None,
        regul_rate: float = 0.0,
        l3_regul_rate: float = 0.0):

        """Create a NegativeSampling object.

        :param model: the KGE model
        :type model: :py:class:`unike.module.model.Model`
        :param loss: the loss function
        :type loss: :py:class:`unike.module.loss.Loss`
        :param regul_rate: weight of ``model.regularization(data)`` added to the loss; 0 disables it
        :type regul_rate: float
        :param l3_regul_rate: weight of ``model.l3_regularization()`` added to the loss; 0 disables it
        :type l3_regul_rate: float
        """

        super(NegativeSampling, self).__init__()
        #: the KGE model, i.e. :py:class:`unike.module.model.Model`
        self.model: Model = model
        #: the loss function, i.e. :py:class:`unike.module.loss.Loss`
        self.loss: Loss = loss
        #: weight of the model's ``regularization(data)`` term
        self.regul_rate: float = regul_rate
        #: weight of the model's ``l3_regularization()`` term
        self.l3_regul_rate: float = l3_regul_rate

    def forward(self, data: dict[str, typing.Union[torch.Tensor, str]]) -> torch.Tensor:

        """Compute the final loss value. Defines the computation performed at every call.
        Subclasses of :py:class:`torch.nn.Module` must override :py:meth:`torch.nn.Module.forward`.

        Reads the keys ``"positive_sample"``, ``"negative_sample"`` and ``"mode"`` from ``data``.

        :param data: one training batch
        :type data: dict[str, typing.Union[torch.Tensor, str]]
        :returns: the loss value
        :rtype: torch.Tensor
        """

        pos_sample = data["positive_sample"]
        neg_sample = data["negative_sample"]
        mode = data["mode"]
        pos_score = self.model(pos_sample)
        if mode == "bern":
            # Negatives are scored on their own, then reshaped to one row per positive sample.
            neg_score = self.model(neg_sample)
            neg_score = neg_score.view(pos_score.shape[0], -1)
        else:
            # The model receives the positives, the negatives and the mode together.
            # Presumably it builds the corrupted triples itself (TODO confirm in Model.forward).
            neg_score = self.model(pos_sample, neg_sample, mode)
        loss_res = self.loss(pos_score, neg_score)
        # Out-of-place additions: an in-place ``+=`` would mutate the tensor returned
        # by the loss, which autograd rejects when that tensor was saved for backward.
        if self.regul_rate != 0:
            loss_res = loss_res + self.regul_rate * self.model.regularization(data)
        if self.l3_regul_rate != 0:
            loss_res = loss_res + self.l3_regul_rate * self.model.l3_regularization()
        return loss_res
|
103
|
+
|
104
|
+
def get_negative_sampling_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyper-parameter-optimisation config for :py:class:`NegativeSampling`.

    The default configuration is::

        parameters_dict = {
            'strategy': {
                'value': 'NegativeSampling'
            },
            'regul_rate': {
                'value': 0.0
            },
            'l3_regul_rate': {
                'value': 0.0
            }
        }

    :returns: default HPO config of :py:class:`NegativeSampling`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    fixed_values: dict[str, typing.Any] = {
        'strategy': 'NegativeSampling',
        'regul_rate': 0.0,
        'l3_regul_rate': 0.0,
    }
    # Every entry is pinned to a single value (no search range), so wrap each in {'value': ...}.
    return {key: {'value': fixed} for key, fixed in fixed_values.items()}
|
@@ -0,0 +1,134 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/module/strategy/RGCNSampling.py
|
4
|
+
#
|
5
|
+
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 18, 2023
|
7
|
+
#
|
8
|
+
# 该脚本定义了 R-GCN 模型的训练策略.
|
9
|
+
|
10
|
+
"""
|
11
|
+
RGCNSampling - R-GCN 模型的训练策略类,包含损失函数。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import dgl
|
15
|
+
import torch
|
16
|
+
import typing
|
17
|
+
from ..loss import Loss
|
18
|
+
from ..model import Model
|
19
|
+
from .Strategy import Strategy
|
20
|
+
|
21
|
+
class RGCNSampling(Strategy):

    """
    Training strategy that wraps a model and its loss function for ``R-GCN`` :cite:`R-GCN`.

    Example::

        from unike.data import GraphDataLoader
        from unike.module.model import RGCN
        from unike.module.loss import RGCNLoss
        from unike.module.strategy import RGCNSampling
        from unike.config import Trainer, GraphTester

        dataloader = GraphDataLoader(
            in_path = "../../benchmarks/FB15K237/",
            batch_size = 60000,
            neg_ent = 10,
            test = True,
            test_batch_size = 100,
            num_workers = 16
        )

        # define the model
        rgcn = RGCN(
            ent_tol = dataloader.train_sampler.ent_tol,
            rel_tol = dataloader.train_sampler.rel_tol,
            dim = 500,
            num_layers = 2
        )

        # define the loss function
        model = RGCNSampling(
            model = rgcn,
            loss = RGCNLoss(model = rgcn, regularization = 1e-5)
        )

        # test the model
        tester = GraphTester(model = rgcn, data_loader = dataloader, use_gpu = True, device = 'cuda:0')

        # train the model
        trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(),
            epochs = 10000, lr = 0.0001, use_gpu = True, device = 'cuda:0',
            tester = tester, test = True, valid_interval = 500, log_interval = 500,
            save_interval = 500, save_path = '../../checkpoint/rgcn.pth'
        )
        trainer.run()
    """

    def __init__(
        self,
        model: Model = None,
        loss: Loss = None):

        """Create an RGCNSampling object.

        :param model: the R-GCN model
        :type model: :py:class:`unike.module.model.RGCN`
        :param loss: the loss function
        :type loss: :py:class:`unike.module.loss.Loss`
        """

        super().__init__()

        #: the R-GCN model, i.e. :py:class:`unike.module.model.RGCN`
        self.model: Model = model
        #: the loss function, i.e. :py:class:`unike.module.loss.Loss`
        self.loss: Loss = loss

    def forward(
        self,
        data: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]) -> torch.Tensor:

        """Compute the final loss value. Defines the computation performed at every call.
        Subclasses of :py:class:`torch.nn.Module` must override :py:meth:`torch.nn.Module.forward`.

        The graph, node/edge tensors and triples are passed to the model; its
        scores are compared against ``data["label"]`` by the loss.

        :param data: one training batch
        :type data: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
        :returns: the loss value
        :rtype: torch.Tensor
        """

        # Fetch every field up front, in the same order as before, so a missing key fails early.
        graph, entity, relation, norm, triples, label = (
            data[field]
            for field in ("graph", "entity", "relation", "norm", "triples", "label")
        )
        return self.loss(self.model(graph, entity, relation, norm, triples), label)
|
111
|
+
|
112
|
+
def get_rgcn_sampling_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyper-parameter-optimisation config for :py:class:`RGCNSampling`.

    The default configuration is::

        parameters_dict = {
            'strategy': {
                'value': 'RGCNSampling'
            }
        }

    :returns: default HPO config of :py:class:`RGCNSampling`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    strategy_entry = {'value': 'RGCNSampling'}
    return {'strategy': strategy_entry}
|
@@ -0,0 +1,26 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/module/strategy/Strategy.py
|
4
|
+
#
|
5
|
+
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 4, 2023
|
7
|
+
#
|
8
|
+
# 该脚本定义了训练策略的基类.
|
9
|
+
|
10
|
+
"""
|
11
|
+
Strategy - 该脚本定义了训练策略的基类。
|
12
|
+
"""
|
13
|
+
|
14
|
+
from ..BaseModule import BaseModule
|
15
|
+
|
16
|
+
class Strategy(BaseModule):

    """
    Base class for training strategies.

    Inherits from :py:class:`unike.module.BaseModule` without adding any attributes.
    """

    def __init__(self):

        """Create a Strategy object."""

        super().__init__()
|
@@ -0,0 +1,29 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/module/strategy/__init__.py
|
4
|
+
#
|
5
|
+
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 28, 2023
|
7
|
+
#
|
8
|
+
# 该头文件定义了 strategy 接口.
|
9
|
+
|
10
|
+
"""训练策略部分。"""
|
11
|
+
|
12
|
+
from __future__ import absolute_import
|
13
|
+
from __future__ import division
|
14
|
+
from __future__ import print_function
|
15
|
+
|
16
|
+
from .Strategy import Strategy
|
17
|
+
from .NegativeSampling import NegativeSampling, get_negative_sampling_hpo_config
|
18
|
+
from .RGCNSampling import RGCNSampling, get_rgcn_sampling_hpo_config
|
19
|
+
from .CompGCNSampling import CompGCNSampling, get_compgcn_sampling_hpo_config
|
20
|
+
|
21
|
+
# Public API of the strategy sub-package: the base class, each concrete
# strategy, and the matching default HPO-config factory.
__all__ = [
    'Strategy',
    'NegativeSampling', 'get_negative_sampling_hpo_config',
    'RGCNSampling', 'get_rgcn_sampling_hpo_config',
    'CompGCNSampling', 'get_compgcn_sampling_hpo_config',
]
|
@@ -0,0 +1,94 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/utils/EarlyStopping.py
|
4
|
+
#
|
5
|
+
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 5, 2024
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 6, 2024
|
7
|
+
#
|
8
|
+
# 该脚本定义了 EarlyStopping 类.
|
9
|
+
|
10
|
+
"""
|
11
|
+
EarlyStopping - 使用早停止避免过拟合。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import os
|
15
|
+
import numpy as np
|
16
|
+
from ..module.model import Model
|
17
|
+
import logging
|
18
|
+
|
19
|
+
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig at import time configures the root logger for any
# application importing this module; kept as-is for backward compatibility.
logging.basicConfig(format='%(levelname)s:%(module)s:%(asctime)s:%(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S', level=logging.DEBUG)

class EarlyStopping:

    """
    Stop training early when the validation score (higher is better) has not
    improved for ``patience`` consecutive checks.
    """

    def __init__(
        self,
        save_path: str,
        patience: int = 2,
        verbose: bool = True,
        delta: float = 0):

        """Create an EarlyStopping object.

        :param save_path: directory in which the best model is saved
        :type save_path: str
        :param patience: number of non-improving checks tolerated before stopping. Default: 2
        :type patience: int
        :param verbose: if True, log a message each time the validation score improves. Default: True
        :type verbose: bool
        :param delta: a score must exceed the best score by more than this to count as an improvement. Default: 0
        :type delta: float
        """

        #: full path of the best-model checkpoint, ``<save_path>/best_network.pth``
        self.save_path: str = os.path.join(save_path, 'best_network.pth')
        #: number of non-improving checks tolerated before stopping
        self.patience: int = patience
        #: whether to log a message on each improvement
        self.verbose: bool = verbose
        #: minimum improvement over the best score that counts as progress
        self.delta: float = delta

        #: number of consecutive non-improving checks
        self.counter: int = 0
        #: best score seen so far; -inf so the first check always counts as an improvement.
        # ``np.inf`` (lowercase): the ``np.Inf`` alias was removed in NumPy 2.0.
        self.best_score: float = -np.inf
        #: set to True once ``counter`` reaches ``patience``
        self.early_stop: bool = False

    def __call__(
        self,
        score: float,
        model: "Model"):

        """Record one validation result.

        If ``score`` does not exceed ``best_score + delta``, the counter is
        incremented and ``early_stop`` is set once it reaches ``patience``.
        Otherwise the model is checkpointed and the counter is reset.

        :param score: the latest validation score (higher is better)
        :type score: float
        :param model: model to checkpoint on improvement; must provide ``save_checkpoint(path)``
        :type model: :py:class:`unike.module.model.Model`
        """

        if score <= self.best_score + self.delta:
            self.counter += 1
            logger.info(f'EarlyStopping counter: {self.counter} / {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.save_checkpoint(score, model)
            self.counter = 0

    def save_checkpoint(
        self,
        score: float,
        model: "Model"):

        """Save the model and record ``score`` as the new best score.

        :param score: the improved validation score
        :type score: float
        :param model: model providing ``save_checkpoint(path)``
        :type model: :py:class:`unike.module.model.Model`
        """

        if self.verbose:
            logger.info(f'Validation score improved ({self.best_score:.6f} --> {score:.6f}). Saving model ...')
        model.save_checkpoint(self.save_path)
        self.best_score = score
|
unike/utils/Timer.py
ADDED
@@ -0,0 +1,74 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/utils/Timer.py
|
4
|
+
#
|
5
|
+
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on July 6, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 24, 2023
|
7
|
+
#
|
8
|
+
# 该脚本定义了计时器类.
|
9
|
+
|
10
|
+
"""
|
11
|
+
:py:class:`Timer` - 计时器类。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import time
|
15
|
+
|
16
|
+
class Timer:

    """Record the durations of repeated runs.

    :py:meth:`stop` returns the interval since the previous :py:meth:`stop`
    call, or since the :py:class:`Timer` was created.

    :py:meth:`avg` returns the mean recorded interval; :py:meth:`sum` returns
    the total of all recorded intervals."""

    def __init__(self):

        """Create a Timer and start it immediately."""

        #: recorded intervals, in seconds
        self.times: list[float] = []
        #: most recent ``time.time()`` reading
        self.current: float | None = None
        #: ``time.time()`` reading the next interval is measured from
        self.last: float | None = None

        self.__restart()

    def __restart(self):

        """(Re)start the timer from the current time."""

        self.last = self.current = time.time()

    def stop(self) -> float:

        """Record the interval since the last stop (or creation) and restart from now.

        :returns: the interval just recorded, in seconds
        :rtype: float
        """

        self.current = time.time()
        self.times.append(self.current - self.last)
        self.last = self.current
        return self.times[-1]

    def avg(self) -> float:

        """Return the mean recorded interval.

        :returns: mean interval in seconds, or 0.0 if :py:meth:`stop` has not been called yet
        :rtype: float
        """

        # Guard against ZeroDivisionError before any interval is recorded.
        if not self.times:
            return 0.0
        return sum(self.times) / len(self.times)

    def sum(self) -> float:

        """Return the total of all recorded intervals.

        :returns: total time in seconds
        :rtype: float
        """

        # Inside a method body ``sum`` resolves to the builtin, not this method.
        return sum(self.times)
|
@@ -0,0 +1,46 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/utils/WandbLogger.py
|
4
|
+
#
|
5
|
+
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 1, 2024
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 24, 2024
|
7
|
+
#
|
8
|
+
# 该脚本定义了 WandbLogger 类.
|
9
|
+
|
10
|
+
"""
|
11
|
+
WandbLogger - 使用 Weights and Biases 记录实验结果。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import typing
|
15
|
+
import wandb
|
16
|
+
|
17
|
+
class WandbLogger:

    """Log experiment results with `Weights and Biases <https://docs.wandb.ai/>`_."""

    def __init__(self,
        project: str ="pybind11-ke",
        name: str = "transe",
        config: dict[str, typing.Any] | None = None):

        """Log in to wandb and start a run.

        :param project: wandb project name
        :type project: str
        :param name: wandb run name
        :type name: str
        :param config: run configuration, e.g. hyper-parameters
        :type config: dict[str, typing.Any] | None
        """

        run_options = dict(project=project, name=name, config=config)

        wandb.login()
        wandb.init(**run_options)

        #: the run configuration as exposed by ``wandb.config``
        self.config: dict = wandb.config

    def finish(self):

        """Finish the wandb run."""

        wandb.finish()
|
unike/utils/__init__.py
ADDED
@@ -0,0 +1,26 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/utils/__init__.py
|
4
|
+
#
|
5
|
+
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on July 6, 2023
|
6
|
+
#
|
7
|
+
# 该头文件定义了 utils 接口.
|
8
|
+
|
9
|
+
"""工具类。"""
|
10
|
+
|
11
|
+
from __future__ import absolute_import
|
12
|
+
from __future__ import division
|
13
|
+
from __future__ import print_function
|
14
|
+
|
15
|
+
from .Timer import Timer
|
16
|
+
from .WandbLogger import WandbLogger
|
17
|
+
from .tools import import_class, construct_type_constrain
|
18
|
+
from .EarlyStopping import EarlyStopping
|
19
|
+
|
20
|
+
# Public API of the utils sub-package.
__all__ = [
    'Timer', 'WandbLogger',
    'import_class', 'construct_type_constrain',
    'EarlyStopping',
]
|
unike/utils/tools.py
ADDED
@@ -0,0 +1,118 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/utils/tools.py
|
4
|
+
#
|
5
|
+
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 3, 2024
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 25, 2024
|
7
|
+
#
|
8
|
+
# 该脚本定义了工具函数 import_class 和 construct_type_constrain.
|
9
|
+
|
10
|
+
import importlib
|
11
|
+
|
12
|
+
def import_class(module_and_class_name: str) -> type:

    """Import a module and return the class named by a dotted path.

    :param module_and_class_name: module path followed by the class name, e.g. **unike.module.model.TransE**.
    :type module_and_class_name: str
    :returns: the class object
    :rtype: type
    """

    # Split only at the last dot: everything before it is the module path.
    pieces = module_and_class_name.rsplit(".", 1)
    module_name, class_name = pieces
    return getattr(importlib.import_module(module_name), class_name)
|
26
|
+
|
27
|
+
def construct_type_constrain(
    in_path: str = "./",
    train_file: str = "train2id.txt",
    valid_file: str = "valid2id.txt",
    test_file: str = "test2id.txt"
):

    """Build the ``type_constrain.txt`` file inside ``in_path``.

    type_constrain.txt is the type-constraint file. Its first line holds the
    number of relations.

    Each following pair of lines gives one relation's type constraints: the
    head and tail entities observed with that relation across the training,
    validation and test sets.

    Each relation has two lines:

    first line: **rel_id** **heads_num** **head1** **head2** ...

    second line: **rel_id** **tails_num** **tail1** **tail2** ...

    For example, relation 1200 in benchmarks/FB15K has 4 head entities
    (3123, 1034, 58 and 5733) and 4 tail entities (12123, 4388, 11087 and 11088)::

        1200	4	3123	1034	58	5733
        1200	4	12123	4388	11087	11088

    Each input file must start with a line holding the triple count. That
    line is followed by that many ``head tail relation`` lines.

    :param in_path: dataset directory (joined with ``os.path.join``, so a trailing slash is optional)
    :type in_path: str
    :param train_file: train2id.txt
    :type train_file: str
    :param valid_file: valid2id.txt
    :type valid_file: str
    :param test_file: test2id.txt
    :type test_file: str
    """

    import os

    # relation id -> {entity id: 1}. Dicts act as insertion-ordered sets, so the
    # output lists relations and entities in order of first appearance.
    rel_head: dict = {}
    rel_tail: dict = {}

    for file_name in (train_file, valid_file, test_file):
        with open(os.path.join(in_path, file_name), "r", encoding="utf-8") as fin:
            # Header line: number of triples that follow.
            tot = int(fin.readline())
            for _ in range(tot):
                h, t, r = fin.readline().strip().split()
                rel_head.setdefault(r, {})[h] = 1
                rel_tail.setdefault(r, {})[t] = 1

    with open(os.path.join(in_path, "type_constrain.txt"), "w", encoding="utf-8") as fout:
        fout.write("%d\n" % len(rel_head))
        for rel, heads in rel_head.items():
            tails = rel_tail[rel]
            fout.write("%s\t%d" % (rel, len(heads)) + "".join("\t%s" % h for h in heads) + "\n")
            fout.write("%s\t%d" % (rel, len(tails)) + "".join("\t%s" % t for t in tails) + "\n")
|
unike/version.py
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
__version__: str = "3.0.1"  # package version string
|