unike 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. unike/__init__.py +5 -0
  2. unike/config/HPOTrainer.py +305 -0
  3. unike/config/Tester.py +385 -0
  4. unike/config/Trainer.py +519 -0
  5. unike/config/TrainerAccelerator.py +39 -0
  6. unike/config/__init__.py +37 -0
  7. unike/data/BernSampler.py +168 -0
  8. unike/data/CompGCNSampler.py +140 -0
  9. unike/data/CompGCNTestSampler.py +84 -0
  10. unike/data/KGEDataLoader.py +315 -0
  11. unike/data/KGReader.py +138 -0
  12. unike/data/RGCNSampler.py +261 -0
  13. unike/data/RGCNTestSampler.py +208 -0
  14. unike/data/RevSampler.py +78 -0
  15. unike/data/TestSampler.py +189 -0
  16. unike/data/TradSampler.py +122 -0
  17. unike/data/TradTestSampler.py +87 -0
  18. unike/data/UniSampler.py +145 -0
  19. unike/data/__init__.py +47 -0
  20. unike/module/BaseModule.py +130 -0
  21. unike/module/__init__.py +20 -0
  22. unike/module/loss/CompGCNLoss.py +96 -0
  23. unike/module/loss/Loss.py +26 -0
  24. unike/module/loss/MarginLoss.py +148 -0
  25. unike/module/loss/RGCNLoss.py +117 -0
  26. unike/module/loss/SigmoidLoss.py +145 -0
  27. unike/module/loss/SoftplusLoss.py +145 -0
  28. unike/module/loss/__init__.py +35 -0
  29. unike/module/model/Analogy.py +237 -0
  30. unike/module/model/CompGCN.py +562 -0
  31. unike/module/model/ComplEx.py +235 -0
  32. unike/module/model/DistMult.py +276 -0
  33. unike/module/model/HolE.py +308 -0
  34. unike/module/model/Model.py +107 -0
  35. unike/module/model/RESCAL.py +309 -0
  36. unike/module/model/RGCN.py +304 -0
  37. unike/module/model/RotatE.py +303 -0
  38. unike/module/model/SimplE.py +237 -0
  39. unike/module/model/TransD.py +458 -0
  40. unike/module/model/TransE.py +290 -0
  41. unike/module/model/TransH.py +322 -0
  42. unike/module/model/TransR.py +402 -0
  43. unike/module/model/__init__.py +60 -0
  44. unike/module/strategy/CompGCNSampling.py +140 -0
  45. unike/module/strategy/NegativeSampling.py +138 -0
  46. unike/module/strategy/RGCNSampling.py +134 -0
  47. unike/module/strategy/Strategy.py +26 -0
  48. unike/module/strategy/__init__.py +29 -0
  49. unike/utils/EarlyStopping.py +94 -0
  50. unike/utils/Timer.py +74 -0
  51. unike/utils/WandbLogger.py +46 -0
  52. unike/utils/__init__.py +26 -0
  53. unike/utils/tools.py +118 -0
  54. unike/version.py +1 -0
  55. unike-3.0.1.dist-info/METADATA +101 -0
  56. unike-3.0.1.dist-info/RECORD +59 -0
  57. unike-3.0.1.dist-info/WHEEL +4 -0
  58. unike-3.0.1.dist-info/entry_points.txt +2 -0
  59. unike-3.0.1.dist-info/licenses/LICENSE +21 -0
unike/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ from .version import __version__
unike/config/HPOTrainer.py ADDED
@@ -0,0 +1,305 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/config/HPOTrainer.py
4
+ #
5
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 2, 2024
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 6, 2024
7
+ #
8
+ # 该脚本定义了超参数优化训练循环函数.
9
+
10
+ """
11
+ hpo_train - 超参数优化训练循环函数。
12
+ """
13
+
14
+ import wandb
15
+ import typing
16
+ from ..utils import import_class
17
+ from ..module.model import TransE
18
+ from ..module.loss import MarginLoss
19
+ from ..module.strategy import NegativeSampling
20
+ from ..config import Trainer, Tester
21
+ from ..data import KGEDataLoader
22
+ import logging
23
+
24
# Module-level logger for this file.
logger = logging.getLogger(__name__)

_LOG_FORMAT = '%(levelname)s:%(module)s:%(asctime)s:%(message)s'
_LOG_DATEFMT = '%Y-%m-%d %H:%M:%S'

# NOTE(review): this configures the *root* logger at DEBUG level as an import-time
# side effect, which also affects applications importing this package; consider
# moving it to an application entry point.
logging.basicConfig(
    level=logging.DEBUG,
    format=_LOG_FORMAT,
    datefmt=_LOG_DATEFMT,
)
27
+
28
+ def set_hpo_config(
29
+ method: str = 'bayes',
30
+ sweep_name: str = 'unike_hpo',
31
+ metric_name: str = 'val/hits@10',
32
+ metric_goal: str = 'maximize',
33
+ data_loader_config: dict[str, dict[str, typing.Any]] = {},
34
+ kge_config: dict[str, dict[str, typing.Any]] = {},
35
+ loss_config: dict[str, dict[str, typing.Any]] = {},
36
+ strategy_config: dict[str, dict[str, typing.Any]] = {},
37
+ tester_config: dict[str, dict[str, typing.Any]] = {},
38
+ trainer_config: dict[str, dict[str, typing.Any]] = {}) -> dict[str, dict[str, typing.Any]]:
39
+
40
+ """设置超参数优化范围。
41
+
42
+ :param method: 超参数优化的方法,``grid`` 或 ``random`` 或 ``bayes``
43
+ :type param: str
44
+ :param sweep_name: 超参数优化 sweep 的名字
45
+ :type sweep_name: str
46
+ :param metric_name: 超参数优化的指标名字
47
+ :type metric_name: str
48
+ :param metric_goal: 超参数优化的指标目标,``maximize`` 或 ``minimize``
49
+ :type metric_goal: str
50
+ :param data_loader_config: :py:class:`unike.data.KGEDataLoader` 的超参数优化配置
51
+ :type data_loader_config: dict
52
+ :param kge_config: :py:class:`unike.module.model.Model` 的超参数优化配置
53
+ :type kge_config: dict
54
+ :param loss_config: :py:class:`unike.module.loss.Loss` 的超参数优化配置
55
+ :type loss_config: dict
56
+ :param strategy_config: :py:class:`unike.module.strategy.Strategy` 的超参数优化配置
57
+ :type strategy_config: dict
58
+ :param tester_config: :py:class:`unike.config.Tester` 的超参数优化配置
59
+ :type tester_config: dict
60
+ :param trainer_config: :py:class:`unike.config.Trainer` 的超参数优化配置
61
+ :type trainer_config: dict
62
+ :returns: 超参数优化范围
63
+ :rtype: dict
64
+ """
65
+
66
+ sweep_config: dict[str, str] = {
67
+ 'method': method,
68
+ 'name': sweep_name
69
+ }
70
+
71
+ metric: dict[str, str] = {
72
+ 'name': metric_name,
73
+ 'goal': metric_goal
74
+ }
75
+
76
+ parameters_dict: dict[str, dict[str, typing.Any]] | None = {}
77
+ parameters_dict.update(data_loader_config)
78
+ parameters_dict.update(kge_config)
79
+ parameters_dict.update(loss_config)
80
+ parameters_dict.update(strategy_config)
81
+ parameters_dict.update(tester_config)
82
+ parameters_dict.update(trainer_config)
83
+
84
+ sweep_config['metric'] = metric
85
+ sweep_config['parameters'] = parameters_dict
86
+
87
+ return sweep_config
88
+
89
def set_hpo_hits(
    new_hits: list[int] | None = None):

    """Set the Hits@N cut-offs reported by :py:class:`unike.config.Tester`.

    This rebinds the class attribute ``Tester.hits``, so it affects every
    Tester instance that has not set its own ``hits`` (``Tester.set_hits``
    assigns an instance attribute).

    :param new_hits: Hits@N cut-offs to report; ``None`` (the default) means
        [1, 3, 10], i.e. Hits@1, Hits@3 and Hits@10
    :type new_hits: list[int] | None
    """

    if new_hits is None:
        new_hits = [1, 3, 10]
    tmp = Tester.hits
    # Store a copy so later mutation of the caller's list cannot silently change the metrics.
    Tester.hits = list(new_hits)
    logger.info(f"Hits@N 指标由 {tmp} 变为 {Tester.hits}")
101
+
102
+ def start_hpo_train(
103
+ config: dict[str, dict[str, typing.Any]] | None = None,
104
+ project: str = "pybind11-ke-sweeps",
105
+ count: int = 2):
106
+
107
+ """开启超参数优化。
108
+
109
+ :param config: wandb 的超参数优化配置。
110
+ :type config: dict
111
+ :param project: 项目名
112
+ :type param: str
113
+ :param count: 进行几次尝试。
114
+ :type count: int
115
+ """
116
+
117
+ wandb.login()
118
+
119
+ sweep_id = wandb.sweep(config, project=project)
120
+
121
+ wandb.agent(sweep_id, hpo_train, count=count)
122
+
123
def _hpo_build_dataloader(config) -> KGEDataLoader:

    """Create the data loader class named by ``config.dataloader`` from ``config``."""

    dataloader_class: type[KGEDataLoader] = import_class(f"unike.data.{config.dataloader}")
    return dataloader_class(
        in_path = config.in_path,
        ent_file = config.ent_file,
        rel_file = config.rel_file,
        train_file = config.train_file,
        valid_file = config.valid_file,
        test_file = config.test_file,
        batch_size = config.batch_size,
        neg_ent = config.neg_ent,
        test = True,
        test_batch_size = config.test_batch_size,
        type_constrain = config.type_constrain,
        num_workers = config.num_workers,
        train_sampler = import_class(f"unike.data.{config.train_sampler}"),
        test_sampler = import_class(f"unike.data.{config.test_sampler}")
    )


def _hpo_build_kge_model(config, dataloader):

    """Instantiate the KGE model named by ``config.model``.

    For ``TransR``, a TransE model is first trained for one epoch; its
    parameters are written to ``./transr_transe.json`` and copied into the
    TransR model.

    :raises ValueError: if ``config.model`` is not a supported model name
    """

    model_class = import_class(f"unike.module.model.{config.model}")
    ent_tol = dataloader.get_ent_tol()
    rel_tol = dataloader.get_rel_tol()

    if config.model in ["TransE", "TransH"]:
        return model_class(
            ent_tol = ent_tol,
            rel_tol = rel_tol,
            dim = config.dim,
            p_norm = config.p_norm,
            norm_flag = config.norm_flag
        )
    if config.model == "TransR":
        # Construction order (TransE, then TransR) is kept as-is so random initialisation is unchanged.
        transe = TransE(
            ent_tol = ent_tol,
            rel_tol = rel_tol,
            dim = config.dim,
            p_norm = config.p_norm,
            norm_flag = config.norm_flag
        )
        kge_model = model_class(
            ent_tol = ent_tol,
            rel_tol = rel_tol,
            dim_e = config.dim,
            dim_r = config.dim,
            p_norm = config.p_norm,
            norm_flag = config.norm_flag,
            rand_init = config.rand_init)
        # Pre-train TransE for a single epoch; its parameters initialise TransR.
        model_e = NegativeSampling(
            model = transe,
            loss = MarginLoss(margin = config.margin_e)
        )
        trainer_e = Trainer(
            model = model_e,
            data_loader = dataloader.train_dataloader(),
            epochs = 1,
            lr = config.lr_e,
            opt_method = config.opt_method_e,
            use_gpu = config.use_gpu,
            device = config.device
        )
        trainer_e.run()
        parameters = transe.get_parameters()
        transe.save_parameters("./transr_transe.json")
        kge_model.set_parameters(parameters)
        return kge_model
    if config.model == "TransD":
        return model_class(
            ent_tol = ent_tol,
            rel_tol = rel_tol,
            dim_e = config.dim_e,
            dim_r = config.dim_r,
            p_norm = config.p_norm,
            norm_flag = config.norm_flag)
    if config.model == "RotatE":
        return model_class(
            ent_tol = ent_tol,
            rel_tol = rel_tol,
            dim = config.dim,
            margin = config.margin,
            epsilon = config.epsilon)
    if config.model in ["RESCAL", "DistMult", "HolE", "ComplEx", "Analogy", "SimplE"]:
        return model_class(
            ent_tol = ent_tol,
            rel_tol = rel_tol,
            dim = config.dim)
    if config.model == "RGCN":
        return model_class(
            ent_tol = ent_tol,
            rel_tol = rel_tol,
            dim = config.dim,
            num_layers = config.num_layers)
    if config.model == "CompGCN":
        return model_class(
            ent_tol = ent_tol,
            rel_tol = rel_tol,
            dim = config.dim,
            opn = config.opn,
            fet_drop = config.fet_drop,
            hid_drop = config.hid_drop,
            margin = config.margin,
            decoder_model = config.decoder_model)
    raise ValueError(f"Unsupported model for hyper-parameter optimisation: {config.model!r}")


def _hpo_build_loss(config, kge_model):

    """Instantiate the loss named by ``config.loss``.

    :raises ValueError: if ``config.loss`` is not a supported loss name
    """

    loss_class = import_class(f"unike.module.loss.{config.loss}")
    if config.loss == 'MarginLoss':
        return loss_class(
            adv_temperature = config.adv_temperature,
            margin = config.margin
        )
    if config.loss in ['SigmoidLoss', 'SoftplusLoss']:
        return loss_class(adv_temperature = config.adv_temperature)
    if config.loss == 'RGCNLoss':
        return loss_class(
            model = kge_model,
            regularization = config.regularization
        )
    if config.loss == 'CompGCNLoss':
        return loss_class(model = kge_model)
    raise ValueError(f"Unsupported loss for hyper-parameter optimisation: {config.loss!r}")


def _hpo_build_strategy(config, kge_model, loss, dataloader):

    """Wrap ``kge_model`` and ``loss`` in the training strategy named by ``config.strategy``.

    :raises ValueError: if ``config.strategy`` is not a supported strategy name
    """

    strategy_class = import_class(f"unike.module.strategy.{config.strategy}")
    if config.strategy == 'NegativeSampling':
        return strategy_class(
            model = kge_model,
            loss = loss,
            regul_rate = config.regul_rate,
            l3_regul_rate = config.l3_regul_rate
        )
    if config.strategy == 'RGCNSampling':
        return strategy_class(
            model = kge_model,
            loss = loss
        )
    if config.strategy == 'CompGCNSampling':
        return strategy_class(
            model = kge_model,
            loss = loss,
            smoothing = config.smoothing,
            ent_tol = dataloader.train_sampler.ent_tol
        )
    raise ValueError(f"Unsupported strategy for hyper-parameter optimisation: {config.strategy!r}")


def hpo_train(config: dict[str, typing.Any] | None = None):

    """Run one hyper-parameter-optimisation trial.

    Used as the function run by ``wandb.agent`` in :py:func:`start_hpo_train`.
    Inside a ``wandb.init`` run it builds, from ``wandb.config``, the data
    loader, KGE model, loss, training strategy, tester and trainer. It then
    trains with validation and wandb logging enabled.

    :param config: initial wandb run configuration (hyper-parameters)
    :type config: dict[str, typing.Any] | None
    :raises ValueError: if ``model``, ``loss`` or ``strategy`` names an unsupported class
    """

    with wandb.init(config = config):

        # From here on every hyper-parameter is read from ``wandb.config``.
        config = wandb.config

        # Build in the original order: data -> model (incl. TransR pre-training) -> loss -> strategy.
        dataloader = _hpo_build_dataloader(config)
        kge_model = _hpo_build_kge_model(config, dataloader)
        loss = _hpo_build_loss(config, kge_model)
        model = _hpo_build_strategy(config, kge_model, loss, dataloader)

        # test the model
        tester_class: type[Tester] = import_class(f"unike.config.{config.tester}")
        tester = tester_class(
            model = kge_model,
            data_loader = dataloader,
            prediction = config.prediction,
            use_tqdm = config.use_tqdm,
            use_gpu = config.use_gpu,
            device = config.device
        )

        # train the model (validating with ``tester`` and logging to wandb)
        trainer_class: type[Trainer] = import_class(f"unike.config.{config.trainer}")
        trainer = trainer_class(
            model = model,
            data_loader = dataloader.train_dataloader(),
            epochs = config.epochs,
            lr = config.lr,
            opt_method = config.opt_method,
            use_gpu = config.use_gpu,
            device = config.device,
            tester = tester,
            test = True,
            valid_interval = config.valid_interval,
            log_interval = config.log_interval,
            save_path = config.save_path,
            use_early_stopping = config.use_early_stopping,
            metric = config.metric,
            patience = config.patience,
            delta = config.delta,
            use_wandb = True
        )
        trainer.run()
unike/config/Tester.py ADDED
@@ -0,0 +1,385 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/config/Tester.py
4
+ #
5
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 6, 2024
7
+ #
8
+ # 该脚本定义了验证模型类.
9
+
10
+ """
11
+ Tester - 验证模型类,内部使用 ``tqdm`` 实现进度条。
12
+ """
13
+
14
+ import dgl
15
+ import torch
16
+ import typing
17
+ import collections
18
+ import numpy as np
19
+ from tqdm import tqdm
20
+ from ..module.model import Model
21
+ from ..data import KGEDataLoader
22
+ import logging
23
+
24
# Module-level logger for this file.
logger = logging.getLogger(__name__)

_LOG_FORMAT = '%(levelname)s:%(module)s:%(asctime)s:%(message)s'
_LOG_DATEFMT = '%Y-%m-%d %H:%M:%S'

# NOTE(review): this configures the *root* logger at DEBUG level as an import-time
# side effect, which also affects applications importing this package; consider
# moving it to an application entry point.
logging.basicConfig(
    level=logging.DEBUG,
    format=_LOG_FORMAT,
    datefmt=_LOG_DATEFMT,
)
27
+
28
class Tester(object):

    """
    Evaluates KGE models with link prediction (MR, MRR and Hits@N).

    Example::

        from unike.data import KGEDataLoader, BernSampler, TradTestSampler
        from unike.module.model import TransE
        from unike.module.loss import MarginLoss
        from unike.module.strategy import NegativeSampling
        from unike.config import Trainer, Tester

        # dataloader for training
        dataloader = KGEDataLoader(
            in_path = "../../benchmarks/FB15K/",
            batch_size = 8192,
            neg_ent = 25,
            test = True,
            test_batch_size = 256,
            num_workers = 16,
            train_sampler = BernSampler,
            test_sampler = TradTestSampler
        )

        # define the model
        transe = TransE(
            ent_tol = dataloader.train_sampler.ent_tol,
            rel_tol = dataloader.train_sampler.rel_tol,
            dim = 50,
            p_norm = 1,
            norm_flag = True)

        # define the loss function
        model = NegativeSampling(
            model = transe,
            loss = MarginLoss(margin = 1.0),
            regul_rate = 0.01
        )

        # test the model
        tester = Tester(model = transe, data_loader = dataloader, use_gpu = True, device = 'cuda:1')

        # train the model
        trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(),
            epochs = 1000, lr = 0.01, use_gpu = True, device = 'cuda:1',
            tester = tester, test = True, valid_interval = 100,
            log_interval = 100, save_interval = 100,
            save_path = '../../checkpoint/transe.pth', delta = 0.01)
        trainer.run()
    """

    #: Hits@N cut-offs to report. Defaults to [1, 3, 10], i.e. Hits@1, Hits@3 and Hits@10.
    hits: list[int] = [1, 3, 10]

    def __init__(
        self,
        model: Model = None,
        data_loader: KGEDataLoader = None,
        sampling_mode: str = 'link_test',
        prediction: str = "all",
        use_tqdm: bool = True,
        use_gpu: bool = True,
        device: str = "cuda:0",
        only_test: bool = False):

        """Create a Tester.

        :param model: KGE model to evaluate
        :type model: unike.module.model.Model
        :param data_loader: source of the validation and test dataloaders
        :type data_loader: unike.data.KGEDataLoader
        :param sampling_mode: split evaluated by :py:meth:`run_link_prediction`:
            **'link_test'** or **'link_valid'**
        :type sampling_mode: str
        :param prediction: link prediction mode: **'all'**, **'head'** or **'tail'**
        :type prediction: str
        :param use_tqdm: whether to show a progress bar
        :type use_tqdm: bool
        :param use_gpu: whether to move batches (and, with ``only_test``, the model) to ``device``
        :type use_gpu: bool
        :param device: torch device string, e.g. ``"cuda:0"``
        :type device: str
        :param only_test: set when evaluating an already-trained model; the model
            is then moved to ``device`` here if ``use_gpu`` is set
        :type only_test: bool
        """

        #: KGE model (:py:class:`unike.module.model.Model`)
        self.model: Model = model
        #: :py:class:`unike.data.KGEDataLoader`
        self.data_loader: KGEDataLoader = data_loader
        #: split to evaluate: **'link_test'** or **'link_valid'**
        self.sampling_mode: str = sampling_mode
        #: link prediction mode: **'all'**, **'head'** or **'tail'**
        self.prediction: str = prediction
        #: whether to show a progress bar
        self.use_tqdm: bool = use_tqdm
        #: whether to use the GPU
        self.use_gpu: bool = use_gpu
        #: :py:class:`torch.device` built from ``device``
        self.device: torch.device = torch.device(device)
        #: validation dataloader
        self.val_dataloader: torch.utils.data.DataLoader = self.data_loader.val_dataloader()
        #: test dataloader
        self.test_dataloader: torch.utils.data.DataLoader = self.data_loader.test_dataloader()

        # The model is only moved here for standalone evaluation (only_test);
        # otherwise its device placement is left to the caller.
        if self.use_gpu and only_test:
            self.model.cuda(device = self.device)

    def set_hits(
        self,
        new_hits: list[int] | None = None):

        """Set the Hits@N cut-offs for this Tester only.

        This assigns an instance attribute, so the class-level default
        :py:attr:`hits` is left untouched.

        :param new_hits: Hits@N cut-offs to report; ``None`` (the default) means [1, 3, 10]
        :type new_hits: list[int] | None
        """

        if new_hits is None:
            new_hits = [1, 3, 10]
        tmp = self.hits
        # Copy so later mutation of the caller's list does not change the metrics.
        self.hits = list(new_hits)

        logger.info(f"Hits@N 指标由 {tmp} 变为 {self.hits}")

    def to_var(
        self,
        x: torch.Tensor) -> torch.Tensor:

        """Move ``x`` to :py:attr:`device` when :py:attr:`use_gpu` is set; otherwise return it unchanged.

        :param x: data to move
        :type x: torch.Tensor
        :returns: ``x`` on the target device
        :rtype: torch.Tensor
        """

        if self.use_gpu:
            return x.to(self.device)
        else:
            return x

    def _accumulate_ranks(
        self,
        stats: dict[str, float],
        ranks: torch.Tensor,
        suffix: str = "") -> None:

        """Add the running sums for one batch of ``ranks`` into ``stats``.

        Updates ``count``, ``mr``, ``mrr`` and ``hits@N`` (for each N in
        :py:attr:`hits`), each key followed by ``suffix``.
        """

        stats[f"count{suffix}"] += torch.numel(ranks)
        stats[f"mr{suffix}"] += torch.sum(ranks).item()
        stats[f"mrr{suffix}"] += torch.sum(1.0 / ranks).item()
        for k in self.hits:
            stats[f"hits@{k}{suffix}"] += torch.numel(ranks[ranks <= k])

    def run_link_prediction(self) -> dict[str, float]:

        """Run link prediction on the split selected by :py:attr:`sampling_mode`.

        :returns: MR, MRR and Hits@N averaged over all ranks and rounded to 3
            decimals, under the keys ``mr``, ``mrr`` and ``hits@N``. When the
            batches also carry ``head_label_type`` labels, the same metrics
            computed with those labels are added with a ``_type`` suffix.
        :rtype: dict[str, float]
        :raises ValueError: if :py:attr:`sampling_mode` is neither 'link_valid' nor 'link_test'
        """

        if self.sampling_mode == "link_valid":
            dataloader = self.val_dataloader
        elif self.sampling_mode == "link_test":
            dataloader = self.test_dataloader
        else:
            # Any other value would leave nothing to iterate over.
            raise ValueError(
                f"sampling_mode must be 'link_valid' or 'link_test', got {self.sampling_mode!r}")
        training_range = tqdm(dataloader) if self.use_tqdm else dataloader

        self.model.eval()
        results = collections.defaultdict(float)
        results_type = collections.defaultdict(float)
        with torch.no_grad():
            for data in training_range:
                data = {key : self.to_var(value) for key, value in data.items()}
                if "head_label_type" in data:
                    ranks, ranks_type = link_predict(data, self.model, prediction=self.prediction)
                    self._accumulate_ranks(results_type, ranks_type, "_type")
                else:
                    ranks = link_predict(data, self.model, prediction=self.prediction)
                # The plain ranks are reported in both cases.
                self._accumulate_ranks(results, ranks)

        count = results["count"]
        metrics = {key : np.around(value / count, decimals=3).item()
                   for key, value in results.items() if key != "count"}
        if "count_type" in results_type:
            count_type = results_type["count_type"]
            metrics.update({key : np.around(value / count_type, decimals=3).item()
                            for key, value in results_type.items() if key != "count_type"})
        return metrics

    def set_sampling_mode(self, sampling_mode: str):

        """Set :py:attr:`sampling_mode`.

        :param sampling_mode: **'link_test'** evaluates the test split and
            **'link_valid'** the validation split
        :type sampling_mode: str
        """

        self.sampling_mode = sampling_mode
219
+
220
def link_predict(
    batch: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]],
    model: Model,
    prediction: str = "all") -> tuple[torch.Tensor, ...]:

    """
    Run link prediction on one batch and return the ranks as float tensors.

    :param batch: test data
    :type batch: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
    :param model: KGE model
    :type model: unike.module.model.Model
    :param prediction: **'all'** (tail ranks followed by head ranks), **'head'** or **'tail'**
    :type prediction: str
    :returns: ranks of the true triples, or ``(ranks, ranks_type)`` when the
        batch carries the matching ``*_label_type`` labels
    :rtype: tuple[torch.Tensor, ...]
    :raises ValueError: if ``prediction`` is not 'all', 'head' or 'tail'
    """

    if prediction == "all":
        tail_ranks = tail_predict(batch, model)
        head_ranks = head_predict(batch, model)
        if "head_label_type" in batch.keys():
            return (torch.cat([tail_ranks[0], head_ranks[0]]).float(),
                    torch.cat([tail_ranks[1], head_ranks[1]]).float())
        return torch.cat([tail_ranks, head_ranks]).float()

    if prediction in ("head", "tail"):
        predict = head_predict if prediction == "head" else tail_predict
        if f"{prediction}_label_type" in batch.keys():
            ranks, ranks_type = predict(batch, model)
            return ranks.float(), ranks_type.float()
        return predict(batch, model).float()

    # Previously an unknown mode silently returned None and failed later in the caller.
    raise ValueError(f"prediction must be 'all', 'head' or 'tail', got {prediction!r}")
259
+
260
def head_predict(
    batch: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]],
    model: Model) -> tuple[torch.Tensor, ...]:

    """
    Rank the true head entity of every triple in ``batch``.

    The target entities are column 0 of ``batch["positive_sample"]``; the
    scores come from ``model.predict(batch, "head_predict")``. Entities
    flagged in ``batch["head_label"]`` (other than the target) are demoted by
    :py:func:`calc_ranks`.

    :param batch: test data
    :type batch: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
    :param model: KGE model
    :type model: unike.module.model.Model
    :returns: the ranks, or ``(ranks, ranks_type)`` when ``batch`` also has
        ``"head_label_type"`` (the second computed with that label instead)
    :rtype: tuple[torch.Tensor, ...]
    """

    targets = batch["positive_sample"][:, 0]
    filter_label = batch["head_label"]
    scores = model.predict(batch, "head_predict")

    ranks = calc_ranks(targets, filter_label, scores)
    if "head_label_type" not in batch.keys():
        return ranks
    type_label = batch["head_label_type"]
    return ranks, calc_ranks(targets, type_label, scores)
283
+
284
def tail_predict(
    batch: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]],
    model: Model) -> tuple[torch.Tensor, ...]:

    """
    Rank the true tail entity of every triple in ``batch``.

    The target entities are column 2 of ``batch["positive_sample"]``; the
    scores come from ``model.predict(batch, "tail_predict")``. Entities
    flagged in ``batch["tail_label"]`` (other than the target) are demoted by
    :py:func:`calc_ranks`.

    :param batch: test data
    :type batch: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
    :param model: KGE model
    :type model: unike.module.model.Model
    :returns: the ranks, or ``(ranks, ranks_type)`` when ``batch`` also has
        ``"tail_label_type"`` (the second computed with that label instead)
    :rtype: tuple[torch.Tensor, ...]
    """

    targets = batch["positive_sample"][:, 2]
    filter_label = batch["tail_label"]
    scores = model.predict(batch, "tail_predict")

    ranks = calc_ranks(targets, filter_label, scores)
    if "tail_label_type" not in batch.keys():
        return ranks
    type_label = batch["tail_label_type"]
    return ranks, calc_ranks(targets, type_label, scores)
307
+
308
def calc_ranks(
    idx: torch.Tensor,
    label: torch.Tensor,
    pred_score: torch.Tensor) -> torch.Tensor:

    """
    Compute the filtered rank (1 = best) of each row's target entity.

    Row ``i`` of ``pred_score`` holds the scores of all candidate entities
    (higher is better) and ``idx[i]`` is the column of the true entity. Every
    entity marked in ``label`` other than the target is moved below all
    unmarked ones before ranking. The input ``pred_score`` is not modified.

    :param idx: column of the target entity in each row
    :type idx: torch.Tensor
    :param label: mask (same shape as ``pred_score``) of entities to filter out
    :type label: torch.Tensor
    :param pred_score: candidate scores, one row per test triple
    :type pred_score: torch.Tensor
    :returns: 1-based rank of each target
    :rtype: torch.Tensor
    """

    # Keep the row index on the same device as the scores.
    b_range = torch.arange(pred_score.size(0), device=pred_score.device)
    target_pred = pred_score[b_range, idx]
    # -inf (instead of a finite sentinel such as -1e7) guarantees a filtered
    # entity can never outrank the target, whatever the model's score scale.
    # masked_fill returns a new tensor, so the caller's scores stay intact.
    pred_score = pred_score.masked_fill(label.bool(), float("-inf"))
    # The target may itself be marked in ``label``; restore its original score.
    pred_score[b_range, idx] = target_pred

    # argsort of the descending argsort gives each entry's 0-based position.
    ranks = 1 + torch.argsort(
        torch.argsort(pred_score, dim=1, descending=True), dim=1, descending=False
    )[b_range, idx]
    return ranks
338
+
339
def get_tester_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default wandb sweep parameters for :py:class:`Tester`.

    Every entry is a fixed ``'value'`` (nothing is searched over)::

        {
            'tester': {'value': 'Tester'},
            'prediction': {'value': 'all'},
            'use_tqdm': {'value': False},
            'use_gpu': {'value': True},
            'device': {'value': 'cuda:0'},
        }

    A new dict is built on every call, so callers may modify the result.

    :returns: default sweep parameters for :py:class:`Tester`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    fixed_values = (
        ('tester', 'Tester'),
        ('prediction', 'all'),
        ('use_tqdm', False),
        ('use_gpu', True),
        ('device', 'cuda:0'),
    )
    return {name: {'value': value} for name, value in fixed_values}