unike 3.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unike/__init__.py +5 -0
- unike/config/HPOTrainer.py +305 -0
- unike/config/Tester.py +385 -0
- unike/config/Trainer.py +519 -0
- unike/config/TrainerAccelerator.py +39 -0
- unike/config/__init__.py +37 -0
- unike/data/BernSampler.py +168 -0
- unike/data/CompGCNSampler.py +140 -0
- unike/data/CompGCNTestSampler.py +84 -0
- unike/data/KGEDataLoader.py +315 -0
- unike/data/KGReader.py +138 -0
- unike/data/RGCNSampler.py +261 -0
- unike/data/RGCNTestSampler.py +208 -0
- unike/data/RevSampler.py +78 -0
- unike/data/TestSampler.py +189 -0
- unike/data/TradSampler.py +122 -0
- unike/data/TradTestSampler.py +87 -0
- unike/data/UniSampler.py +145 -0
- unike/data/__init__.py +47 -0
- unike/module/BaseModule.py +130 -0
- unike/module/__init__.py +20 -0
- unike/module/loss/CompGCNLoss.py +96 -0
- unike/module/loss/Loss.py +26 -0
- unike/module/loss/MarginLoss.py +148 -0
- unike/module/loss/RGCNLoss.py +117 -0
- unike/module/loss/SigmoidLoss.py +145 -0
- unike/module/loss/SoftplusLoss.py +145 -0
- unike/module/loss/__init__.py +35 -0
- unike/module/model/Analogy.py +237 -0
- unike/module/model/CompGCN.py +562 -0
- unike/module/model/ComplEx.py +235 -0
- unike/module/model/DistMult.py +276 -0
- unike/module/model/HolE.py +308 -0
- unike/module/model/Model.py +107 -0
- unike/module/model/RESCAL.py +309 -0
- unike/module/model/RGCN.py +304 -0
- unike/module/model/RotatE.py +303 -0
- unike/module/model/SimplE.py +237 -0
- unike/module/model/TransD.py +458 -0
- unike/module/model/TransE.py +290 -0
- unike/module/model/TransH.py +322 -0
- unike/module/model/TransR.py +402 -0
- unike/module/model/__init__.py +60 -0
- unike/module/strategy/CompGCNSampling.py +140 -0
- unike/module/strategy/NegativeSampling.py +138 -0
- unike/module/strategy/RGCNSampling.py +134 -0
- unike/module/strategy/Strategy.py +26 -0
- unike/module/strategy/__init__.py +29 -0
- unike/utils/EarlyStopping.py +94 -0
- unike/utils/Timer.py +74 -0
- unike/utils/WandbLogger.py +46 -0
- unike/utils/__init__.py +26 -0
- unike/utils/tools.py +118 -0
- unike/version.py +1 -0
- unike-3.0.1.dist-info/METADATA +101 -0
- unike-3.0.1.dist-info/RECORD +59 -0
- unike-3.0.1.dist-info/WHEEL +4 -0
- unike-3.0.1.dist-info/entry_points.txt +2 -0
- unike-3.0.1.dist-info/licenses/LICENSE +21 -0
unike/config/HPOTrainer.py
ADDED
@@ -0,0 +1,305 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/config/HPOTrainer.py
|
4
|
+
#
|
5
|
+
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 2, 2024
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 6, 2024
|
7
|
+
#
|
8
|
+
# 该脚本定义了超参数优化训练循环函数.
|
9
|
+
|
10
|
+
"""
|
11
|
+
hpo_train - 超参数优化训练循环函数。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import wandb
|
15
|
+
import typing
|
16
|
+
from ..utils import import_class
|
17
|
+
from ..module.model import TransE
|
18
|
+
from ..module.loss import MarginLoss
|
19
|
+
from ..module.strategy import NegativeSampling
|
20
|
+
from ..config import Trainer, Tester
|
21
|
+
from ..data import KGEDataLoader
|
22
|
+
import logging
|
23
|
+
|
24
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): this call runs when the module is imported and targets the
# process-wide root logger (DEBUG level, custom format) - confirm that is
# intended for a library module.
logging.basicConfig(format='%(levelname)s:%(module)s:%(asctime)s:%(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S', level=logging.DEBUG)
|
27
|
+
|
28
|
+
def set_hpo_config(
|
29
|
+
method: str = 'bayes',
|
30
|
+
sweep_name: str = 'unike_hpo',
|
31
|
+
metric_name: str = 'val/hits@10',
|
32
|
+
metric_goal: str = 'maximize',
|
33
|
+
data_loader_config: dict[str, dict[str, typing.Any]] = {},
|
34
|
+
kge_config: dict[str, dict[str, typing.Any]] = {},
|
35
|
+
loss_config: dict[str, dict[str, typing.Any]] = {},
|
36
|
+
strategy_config: dict[str, dict[str, typing.Any]] = {},
|
37
|
+
tester_config: dict[str, dict[str, typing.Any]] = {},
|
38
|
+
trainer_config: dict[str, dict[str, typing.Any]] = {}) -> dict[str, dict[str, typing.Any]]:
|
39
|
+
|
40
|
+
"""设置超参数优化范围。
|
41
|
+
|
42
|
+
:param method: 超参数优化的方法,``grid`` 或 ``random`` 或 ``bayes``
|
43
|
+
:type param: str
|
44
|
+
:param sweep_name: 超参数优化 sweep 的名字
|
45
|
+
:type sweep_name: str
|
46
|
+
:param metric_name: 超参数优化的指标名字
|
47
|
+
:type metric_name: str
|
48
|
+
:param metric_goal: 超参数优化的指标目标,``maximize`` 或 ``minimize``
|
49
|
+
:type metric_goal: str
|
50
|
+
:param data_loader_config: :py:class:`unike.data.KGEDataLoader` 的超参数优化配置
|
51
|
+
:type data_loader_config: dict
|
52
|
+
:param kge_config: :py:class:`unike.module.model.Model` 的超参数优化配置
|
53
|
+
:type kge_config: dict
|
54
|
+
:param loss_config: :py:class:`unike.module.loss.Loss` 的超参数优化配置
|
55
|
+
:type loss_config: dict
|
56
|
+
:param strategy_config: :py:class:`unike.module.strategy.Strategy` 的超参数优化配置
|
57
|
+
:type strategy_config: dict
|
58
|
+
:param tester_config: :py:class:`unike.config.Tester` 的超参数优化配置
|
59
|
+
:type tester_config: dict
|
60
|
+
:param trainer_config: :py:class:`unike.config.Trainer` 的超参数优化配置
|
61
|
+
:type trainer_config: dict
|
62
|
+
:returns: 超参数优化范围
|
63
|
+
:rtype: dict
|
64
|
+
"""
|
65
|
+
|
66
|
+
sweep_config: dict[str, str] = {
|
67
|
+
'method': method,
|
68
|
+
'name': sweep_name
|
69
|
+
}
|
70
|
+
|
71
|
+
metric: dict[str, str] = {
|
72
|
+
'name': metric_name,
|
73
|
+
'goal': metric_goal
|
74
|
+
}
|
75
|
+
|
76
|
+
parameters_dict: dict[str, dict[str, typing.Any]] | None = {}
|
77
|
+
parameters_dict.update(data_loader_config)
|
78
|
+
parameters_dict.update(kge_config)
|
79
|
+
parameters_dict.update(loss_config)
|
80
|
+
parameters_dict.update(strategy_config)
|
81
|
+
parameters_dict.update(tester_config)
|
82
|
+
parameters_dict.update(trainer_config)
|
83
|
+
|
84
|
+
sweep_config['metric'] = metric
|
85
|
+
sweep_config['parameters'] = parameters_dict
|
86
|
+
|
87
|
+
return sweep_config
|
88
|
+
|
89
|
+
def set_hpo_hits(
    new_hits: list[int] | None = None):

    """Set the Hits@N cut-offs reported during hyper-parameter optimization.

    The value is written to the class attribute ``Tester.hits``, so it applies
    to every ``Tester`` that has not overridden ``hits`` on the instance.

    :param new_hits: list of N values for Hits@N; when omitted, [1, 3, 10] is
        used, i.e. Hits@1, Hits@3 and Hits@10 are reported
    :type new_hits: list[int]
    """

    if new_hits is None:
        new_hits = [1, 3, 10]

    previous = Tester.hits
    # Store a copy: the class attribute must not alias the caller's list (or a
    # shared default), otherwise a later in-place change would silently alter
    # the reported metrics.
    Tester.hits = list(new_hits)
    logger.info(f"Hits@N 指标由 {previous} 变为 {Tester.hits}")
|
101
|
+
|
102
|
+
def start_hpo_train(
    config: dict[str, dict[str, typing.Any]] | None = None,
    project: str = "pybind11-ke-sweeps",
    count: int = 2):

    """Start a hyper-parameter optimization sweep.

    Logs in to wandb, registers ``config`` as a new sweep in ``project`` and
    then runs an agent that calls :py:func:`hpo_train` for each trial.

    :param config: wandb sweep configuration
    :type config: dict
    :param project: wandb project name
    :type project: str
    :param count: number of trials the agent runs
    :type count: int
    """

    wandb.login()

    # The id returned for the registered sweep is what the agent attaches to.
    registered_sweep = wandb.sweep(config, project=project)
    wandb.agent(registered_sweep, function=hpo_train, count=count)
|
122
|
+
|
123
|
+
def hpo_train(config: dict[str, typing.Any] | None = None):

    """Run one hyper-parameter optimization trial.

    Opens a wandb run, reads the trial's hyper-parameters from
    ``wandb.config`` and uses them to build, in order, the data loader, the
    KGE model, the loss, the training strategy, the tester and the trainer.
    Finally the trainer is run with validation and wandb logging enabled.

    :param config: wandb run configuration (e.g. hyper-parameters) forwarded
        to ``wandb.init``
    :type config: dict[str, typing.Any] | None
    :raises ValueError: if ``config.model``, ``config.loss`` or
        ``config.strategy`` names a class this function cannot construct
    """

    with wandb.init(config=config):

        # From here on ``config`` is the run's wandb configuration object.
        config = wandb.config

        # dataloader for training
        dataloader_class: type[KGEDataLoader] = import_class(f"unike.data.{config.dataloader}")
        dataloader = dataloader_class(
            in_path=config.in_path,
            ent_file=config.ent_file,
            rel_file=config.rel_file,
            train_file=config.train_file,
            valid_file=config.valid_file,
            test_file=config.test_file,
            batch_size=config.batch_size,
            neg_ent=config.neg_ent,
            test=True,
            test_batch_size=config.test_batch_size,
            type_constrain=config.type_constrain,
            num_workers=config.num_workers,
            train_sampler=import_class(f"unike.data.{config.train_sampler}"),
            test_sampler=import_class(f"unike.data.{config.test_sampler}")
        )

        # define the model; each model family takes its own constructor arguments
        model_class = import_class(f"unike.module.model.{config.model}")
        if config.model in ["TransE", "TransH"]:
            kge_model = model_class(
                ent_tol=dataloader.get_ent_tol(),
                rel_tol=dataloader.get_rel_tol(),
                dim=config.dim,
                p_norm=config.p_norm,
                norm_flag=config.norm_flag
            )
        elif config.model == "TransR":
            # TransR is initialised from a TransE model that is first trained
            # for a single epoch with its own loss margin, learning rate and
            # optimizer settings (the ``*_e`` hyper-parameters).
            transe = TransE(
                ent_tol=dataloader.get_ent_tol(),
                rel_tol=dataloader.get_rel_tol(),
                dim=config.dim,
                p_norm=config.p_norm,
                norm_flag=config.norm_flag
            )
            kge_model = model_class(
                ent_tol=dataloader.get_ent_tol(),
                rel_tol=dataloader.get_rel_tol(),
                dim_e=config.dim,
                dim_r=config.dim,
                p_norm=config.p_norm,
                norm_flag=config.norm_flag,
                rand_init=config.rand_init
            )
            model_e = NegativeSampling(
                model=transe,
                loss=MarginLoss(margin=config.margin_e)
            )
            trainer_e = Trainer(
                model=model_e,
                data_loader=dataloader.train_dataloader(),
                epochs=1,
                lr=config.lr_e,
                opt_method=config.opt_method_e,
                use_gpu=config.use_gpu,
                device=config.device
            )
            trainer_e.run()
            # Hand the pre-trained TransE parameters over to TransR; they are
            # also written to a JSON file in the current working directory.
            parameters = transe.get_parameters()
            transe.save_parameters("./transr_transe.json")
            kge_model.set_parameters(parameters)
        elif config.model == "TransD":
            kge_model = model_class(
                ent_tol=dataloader.get_ent_tol(),
                rel_tol=dataloader.get_rel_tol(),
                dim_e=config.dim_e,
                dim_r=config.dim_r,
                p_norm=config.p_norm,
                norm_flag=config.norm_flag
            )
        elif config.model == "RotatE":
            kge_model = model_class(
                ent_tol=dataloader.get_ent_tol(),
                rel_tol=dataloader.get_rel_tol(),
                dim=config.dim,
                margin=config.margin,
                epsilon=config.epsilon
            )
        elif config.model in ["RESCAL", "DistMult", "HolE", "ComplEx", "Analogy", "SimplE"]:
            kge_model = model_class(
                ent_tol=dataloader.get_ent_tol(),
                rel_tol=dataloader.get_rel_tol(),
                dim=config.dim
            )
        elif config.model == "RGCN":
            kge_model = model_class(
                ent_tol=dataloader.get_ent_tol(),
                rel_tol=dataloader.get_rel_tol(),
                dim=config.dim,
                num_layers=config.num_layers
            )
        elif config.model == "CompGCN":
            kge_model = model_class(
                ent_tol=dataloader.get_ent_tol(),
                rel_tol=dataloader.get_rel_tol(),
                dim=config.dim,
                opn=config.opn,
                fet_drop=config.fet_drop,
                hid_drop=config.hid_drop,
                margin=config.margin,
                decoder_model=config.decoder_model
            )
        else:
            # Fail here with a clear message instead of hitting an unbound
            # ``kge_model`` further down.
            raise ValueError(f"Unsupported model for hyper-parameter optimization: {config.model!r}")

        # define the loss function
        loss_class = import_class(f"unike.module.loss.{config.loss}")
        if config.loss == 'MarginLoss':
            loss = loss_class(
                adv_temperature=config.adv_temperature,
                margin=config.margin
            )
        elif config.loss in ['SigmoidLoss', 'SoftplusLoss']:
            loss = loss_class(adv_temperature=config.adv_temperature)
        elif config.loss == 'RGCNLoss':
            loss = loss_class(
                model=kge_model,
                regularization=config.regularization
            )
        elif config.loss == 'CompGCNLoss':
            loss = loss_class(model=kge_model)
        else:
            raise ValueError(f"Unsupported loss for hyper-parameter optimization: {config.loss!r}")

        # define the strategy that couples the model with the loss
        strategy_class = import_class(f"unike.module.strategy.{config.strategy}")
        if config.strategy == 'NegativeSampling':
            model = strategy_class(
                model=kge_model,
                loss=loss,
                regul_rate=config.regul_rate,
                l3_regul_rate=config.l3_regul_rate
            )
        elif config.strategy == 'RGCNSampling':
            model = strategy_class(
                model=kge_model,
                loss=loss
            )
        elif config.strategy == 'CompGCNSampling':
            model = strategy_class(
                model=kge_model,
                loss=loss,
                smoothing=config.smoothing,
                ent_tol=dataloader.train_sampler.ent_tol
            )
        else:
            raise ValueError(f"Unsupported strategy for hyper-parameter optimization: {config.strategy!r}")

        # test the model
        tester_class: type[Tester] = import_class(f"unike.config.{config.tester}")
        tester = tester_class(
            model=kge_model,
            data_loader=dataloader,
            prediction=config.prediction,
            use_tqdm=config.use_tqdm,
            use_gpu=config.use_gpu,
            device=config.device
        )

        # train the model
        trainer_class: type[Trainer] = import_class(f"unike.config.{config.trainer}")
        trainer = trainer_class(
            model=model,
            data_loader=dataloader.train_dataloader(),
            epochs=config.epochs,
            lr=config.lr,
            opt_method=config.opt_method,
            use_gpu=config.use_gpu,
            device=config.device,
            tester=tester,
            test=True,
            valid_interval=config.valid_interval,
            log_interval=config.log_interval,
            save_path=config.save_path,
            use_early_stopping=config.use_early_stopping,
            metric=config.metric,
            patience=config.patience,
            delta=config.delta,
            use_wandb=True
        )
        trainer.run()
|
unike/config/Tester.py
ADDED
@@ -0,0 +1,385 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/config/Tester.py
|
4
|
+
#
|
5
|
+
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 6, 2024
|
7
|
+
#
|
8
|
+
# 该脚本定义了验证模型类.
|
9
|
+
|
10
|
+
"""
|
11
|
+
Tester - 验证模型类,内部使用 ``tqdm`` 实现进度条。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import dgl
|
15
|
+
import torch
|
16
|
+
import typing
|
17
|
+
import collections
|
18
|
+
import numpy as np
|
19
|
+
from tqdm import tqdm
|
20
|
+
from ..module.model import Model
|
21
|
+
from ..data import KGEDataLoader
|
22
|
+
import logging
|
23
|
+
|
24
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): this call runs when the module is imported and targets the
# process-wide root logger (DEBUG level, custom format) - confirm that is
# intended for a library module.
logging.basicConfig(format='%(levelname)s:%(module)s:%(asctime)s:%(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S', level=logging.DEBUG)
|
27
|
+
|
28
|
+
class Tester(object):

    """
    Evaluates a KGE model on link prediction.

    Example::

        from unike.data import KGEDataLoader, BernSampler, TradTestSampler
        from unike.module.model import TransE
        from unike.module.loss import MarginLoss
        from unike.module.strategy import NegativeSampling
        from unike.config import Trainer, Tester

        # dataloader for training
        dataloader = KGEDataLoader(
            in_path = "../../benchmarks/FB15K/",
            batch_size = 8192,
            neg_ent = 25,
            test = True,
            test_batch_size = 256,
            num_workers = 16,
            train_sampler = BernSampler,
            test_sampler = TradTestSampler
        )

        # define the model
        transe = TransE(
            ent_tol = dataloader.train_sampler.ent_tol,
            rel_tol = dataloader.train_sampler.rel_tol,
            dim = 50,
            p_norm = 1,
            norm_flag = True)

        # define the loss function
        model = NegativeSampling(
            model = transe,
            loss = MarginLoss(margin = 1.0),
            regul_rate = 0.01
        )

        # test the model
        tester = Tester(model = transe, data_loader = dataloader, use_gpu = True, device = 'cuda:1')

        # train the model
        trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(),
            epochs = 1000, lr = 0.01, use_gpu = True, device = 'cuda:1',
            tester = tester, test = True, valid_interval = 100,
            log_interval = 100, save_interval = 100,
            save_path = '../../checkpoint/transe.pth', delta = 0.01)
        trainer.run()
    """

    #: N values of the reported Hits@N metrics; the default [1, 3, 10] reports Hits@1, Hits@3 and Hits@10
    hits: list[int] = [1, 3, 10]

    def __init__(
        self,
        model: Model = None,
        data_loader: KGEDataLoader = None,
        sampling_mode: str = 'link_test',
        prediction: str = "all",
        use_tqdm: bool = True,
        use_gpu: bool = True,
        device: str = "cuda:0",
        only_test: bool = False):

        """Create a Tester.

        :param model: KGE model
        :type model: unike.module.model.Model
        :param data_loader: :py:class:`unike.data.KGEDataLoader`
        :type data_loader: unike.data.KGEDataLoader
        :param sampling_mode: which split to evaluate: **'link_test'** or **'link_valid'**
        :type sampling_mode: str
        :param prediction: link prediction mode: **'all'**, **'head'** or **'tail'**
        :type prediction: str
        :param use_tqdm: whether to show a progress bar
        :type use_tqdm: bool
        :param use_gpu: whether to use a gpu
        :type use_gpu: bool
        :param device: which device to use
        :type device: str
        :param only_test: whether an already trained model is being evaluated
        :type only_test: bool
        """

        #: KGE model, i.e. :py:class:`unike.module.model.Model`
        self.model: Model = model
        #: :py:class:`unike.data.KGEDataLoader`
        self.data_loader: KGEDataLoader = data_loader
        #: Which split is evaluated: **'link_test'** or **'link_valid'**
        self.sampling_mode: str = sampling_mode
        #: Link prediction mode: **'all'**, **'head'** or **'tail'**
        self.prediction: str = prediction
        #: Whether to show a progress bar
        self.use_tqdm: bool = use_tqdm
        #: Whether to use a gpu
        self.use_gpu: bool = use_gpu
        #: :py:class:`torch.device` built from ``device``
        self.device: torch.device = torch.device(device)
        #: Data loader for the validation split
        self.val_dataloader: torch.utils.data.DataLoader = self.data_loader.val_dataloader()
        #: Data loader for the test split
        self.test_dataloader: torch.utils.data.DataLoader = self.data_loader.test_dataloader()

        # The model is only moved here in the evaluate-only case.
        if self.use_gpu and only_test:
            self.model.cuda(device = self.device)

    def set_hits(
        self,
        new_hits: list[int] | None = None):

        """Set the Hits@N cut-offs reported by this tester.

        :param new_hits: list of N values for Hits@N; when omitted, [1, 3, 10]
            is used, i.e. Hits@1, Hits@3 and Hits@10 are reported
        :type new_hits: list[int]
        """

        if new_hits is None:
            new_hits = [1, 3, 10]

        previous = self.hits
        # Keep a private copy so the instance never aliases the caller's list
        # or a shared default.
        self.hits = list(new_hits)

        logger.info(f"Hits@N 指标由 {previous} 变为 {self.hits}")

    def to_var(
        self,
        x: torch.Tensor) -> torch.Tensor:

        """Return ``x`` placed according to :py:attr:`use_gpu`.

        :param x: data
        :type x: torch.Tensor
        :returns: ``x`` moved to :py:attr:`device` when :py:attr:`use_gpu` is set, otherwise ``x`` itself
        :rtype: torch.Tensor
        """

        if self.use_gpu:
            return x.to(self.device)
        else:
            return x

    def run_link_prediction(self) -> dict[str, float]:

        """Run link prediction on the split selected by :py:attr:`sampling_mode`.

        Ranks are accumulated over all batches and averaged at the end; every
        reported value is rounded to three decimals. When a batch contains the
        key ``head_label_type``, a second set of metrics computed from the
        type labels is reported under the same names with a ``_type`` suffix.

        :returns: the metrics ``mr``, ``mrr`` and ``hits@N`` for every N in :py:attr:`hits`
        :rtype: dict[str, float]
        :raises ValueError: if :py:attr:`sampling_mode` is neither ``'link_valid'`` nor ``'link_test'``
        """

        if self.sampling_mode == "link_valid":
            loader = self.val_dataloader
        elif self.sampling_mode == "link_test":
            loader = self.test_dataloader
        else:
            # Without this guard the loop below would fail on an unbound name.
            raise ValueError(
                f"Unknown sampling_mode {self.sampling_mode!r}; expected 'link_valid' or 'link_test'")
        training_range = tqdm(loader) if self.use_tqdm else loader

        self.model.eval()
        results = collections.defaultdict(float)
        results_type = collections.defaultdict(float)
        with torch.no_grad():
            for data in training_range:
                data = {key : self.to_var(value) for key, value in data.items()}
                if "head_label_type" in data.keys():
                    ranks, ranks_type = link_predict(data, self.model, prediction=self.prediction)
                    results_type["count_type"] += torch.numel(ranks_type)
                    results_type["mr_type"] += torch.sum(ranks_type).item()
                    results_type["mrr_type"] += torch.sum(1.0 / ranks_type).item()
                    for k in self.hits:
                        results_type['hits@{}_type'.format(k)] += torch.numel(ranks_type[ranks_type <= k])
                else:
                    ranks = link_predict(data, self.model, prediction=self.prediction)
                results["count"] += torch.numel(ranks)
                results["mr"] += torch.sum(ranks).item()
                results["mrr"] += torch.sum(1.0 / ranks).item()
                for k in self.hits:
                    results['hits@{}'.format(k)] += torch.numel(ranks[ranks <= k])

        # Turn the accumulated sums into per-rank averages; the counters
        # themselves are not part of the returned metrics.
        count = results["count"]
        results = {key : np.around(value / count, decimals=3).item() for key, value in results.items() if key != "count"}
        if "count_type" in results_type.keys():
            count_type = results_type["count_type"]
            results_type = {key : np.around(value / count_type, decimals=3).item() for key, value in results_type.items() if key != "count_type"}
            for key, value in results_type.items():
                results[key] = value
        return results

    def set_sampling_mode(self, sampling_mode: str):

        """Set :py:attr:`sampling_mode`.

        :param sampling_mode: split to evaluate: **'link_test'** selects the test split, **'link_valid'** the validation split
        :type sampling_mode: str
        """

        self.sampling_mode = sampling_mode
|
219
|
+
|
220
|
+
def link_predict(
    batch: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]],
    model: Model,
    prediction: str = "all") -> typing.Union[torch.Tensor, tuple[torch.Tensor, ...]]:

    """
    Run link prediction on one batch.

    With ``prediction='all'`` the tail ranks and the head ranks are
    concatenated (tail ranks first). When the batch carries type labels
    (``head_label_type`` / ``tail_label_type``) a pair of rank tensors is
    returned: the ranks from the regular labels and the ranks from the type
    labels. Otherwise a single tensor is returned. All ranks are cast to float.

    :param batch: test data
    :type batch: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
    :param model: KGE model
    :type model: unike.module.model.Model
    :param prediction: **'all'**, **'head'** or **'tail'**
    :type prediction: str
    :returns: ranks of the correct triples
    :rtype: torch.Tensor or tuple[torch.Tensor, ...]
    :raises ValueError: if ``prediction`` is not one of the three supported modes
    """

    # Reject unknown modes up front instead of silently returning None.
    if prediction not in ("all", "head", "tail"):
        raise ValueError(
            f"Unknown prediction mode {prediction!r}; expected 'all', 'head' or 'tail'")

    if prediction == "all":
        tail_ranks = tail_predict(batch, model)
        head_ranks = head_predict(batch, model)
        # NOTE(review): only the head key is checked here; this presumably
        # relies on head and tail type labels being present together - confirm.
        if "head_label_type" in batch.keys():
            return (
                torch.cat([tail_ranks[0], head_ranks[0]]).float(),
                torch.cat([tail_ranks[1], head_ranks[1]]).float(),
            )
        return torch.cat([tail_ranks, head_ranks]).float()

    if prediction == "head":
        if "head_label_type" in batch.keys():
            ranks, ranks_type = head_predict(batch, model)
            return ranks.float(), ranks_type.float()
        return head_predict(batch, model).float()

    # prediction == "tail"
    if "tail_label_type" in batch.keys():
        ranks, ranks_type = tail_predict(batch, model)
        return ranks.float(), ranks_type.float()
    return tail_predict(batch, model).float()
|
259
|
+
|
260
|
+
def head_predict(
    batch: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]],
    model: Model) -> tuple[torch.Tensor, ...]:

    """
    Run link prediction for the head entity.

    Column 0 of ``batch["positive_sample"]`` is taken as the id of the correct
    head, ``batch["head_label"]`` as the filter labels, and the scores come
    from ``model.predict(batch, "head_predict")``. If the batch also contains
    ``head_label_type``, a second ranking is computed from those labels and a
    pair ``(ranks, ranks_type)`` is returned; otherwise only ``ranks``.

    :param batch: test data
    :type batch: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
    :param model: KGE model
    :type model: unike.module.model.Model
    :returns: ranks of the correct triples
    :rtype: tuple[torch.Tensor, ...]
    """

    head_ids = batch["positive_sample"][:, 0]
    filter_labels = batch["head_label"]
    scores = model.predict(batch, "head_predict")

    ranks = calc_ranks(head_ids, filter_labels, scores)
    if "head_label_type" not in batch.keys():
        return ranks
    return ranks, calc_ranks(head_ids, batch["head_label_type"], scores)
|
283
|
+
|
284
|
+
def tail_predict(
    batch: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]],
    model: Model) -> tuple[torch.Tensor, ...]:

    """
    Run link prediction for the tail entity.

    Column 2 of ``batch["positive_sample"]`` is taken as the id of the correct
    tail, ``batch["tail_label"]`` as the filter labels, and the scores come
    from ``model.predict(batch, "tail_predict")``. If the batch also contains
    ``tail_label_type``, a second ranking is computed from those labels and a
    pair ``(ranks, ranks_type)`` is returned; otherwise only ``ranks``.

    :param batch: test data
    :type batch: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
    :param model: KGE model
    :type model: unike.module.model.Model
    :returns: ranks of the correct triples
    :rtype: tuple[torch.Tensor, ...]
    """

    tail_ids = batch["positive_sample"][:, 2]
    filter_labels = batch["tail_label"]
    scores = model.predict(batch, "tail_predict")

    ranks = calc_ranks(tail_ids, filter_labels, scores)
    if "tail_label_type" not in batch.keys():
        return ranks
    return ranks, calc_ranks(tail_ids, batch["tail_label_type"], scores)
|
307
|
+
|
308
|
+
def calc_ranks(
    idx: torch.Tensor,
    label: torch.Tensor,
    pred_score: torch.Tensor) -> torch.Tensor:

    """
    Compute the rank of the correct entity in every row of ``pred_score``.

    Every candidate whose ``label`` entry is non-zero is pushed down to a
    score of -10000000, after which the score of the correct entity (column
    ``idx[i]`` in row ``i``) is restored. The returned rank is 1-based and a
    higher score gives a better (smaller) rank. ``pred_score`` itself is not
    modified.

    :param idx: id of the entity to be predicted, one per row
    :type idx: torch.Tensor
    :param label: labels marking the candidates to filter out
    :type label: torch.Tensor
    :param pred_score: scores of the candidate triples, one row per test triple
    :type pred_score: torch.Tensor
    :returns: ranks of the correct triples
    :rtype: torch.Tensor
    """

    rows = torch.arange(pred_score.size()[0])
    # Remember the true entity's score before masking, since its own label
    # entry may be set as well.
    true_scores = pred_score[rows, idx]

    masked_value = -torch.ones_like(pred_score) * 10000000
    filtered = torch.where(label.bool(), masked_value, pred_score)
    filtered[rows, idx] = true_scores

    # argsort of the descending argsort gives, per column, its 0-based
    # position in the descending ordering of the row.
    descending_order = torch.argsort(filtered, dim=1, descending=True)
    positions = torch.argsort(descending_order, dim=1, descending=False)
    return 1 + positions[rows, idx]
|
338
|
+
|
339
|
+
def get_tester_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyper-parameter optimization config of :py:class:`Tester`.

    Every entry is a fixed ``value`` (nothing is searched over). The default is::

        parameters_dict = {
            'tester': {
                'value': 'Tester'
            },
            'prediction': {
                'value': 'all'
            },
            'use_tqdm': {
                'value': False
            },
            'use_gpu': {
                'value': True
            },
            'device': {
                'value': 'cuda:0'
            },
        }

    :returns: default hyper-parameter optimization config of :py:class:`Tester`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    fixed_settings = (
        ('tester', 'Tester'),
        ('prediction', 'all'),
        ('use_tqdm', False),
        ('use_gpu', True),
        ('device', 'cuda:0'),
    )

    # A new dictionary is built on every call, so callers may edit the result.
    return {name: {'value': setting} for name, setting in fixed_settings}
|