unike 3.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unike/__init__.py +5 -0
- unike/config/HPOTrainer.py +305 -0
- unike/config/Tester.py +385 -0
- unike/config/Trainer.py +519 -0
- unike/config/TrainerAccelerator.py +39 -0
- unike/config/__init__.py +37 -0
- unike/data/BernSampler.py +168 -0
- unike/data/CompGCNSampler.py +140 -0
- unike/data/CompGCNTestSampler.py +84 -0
- unike/data/KGEDataLoader.py +315 -0
- unike/data/KGReader.py +138 -0
- unike/data/RGCNSampler.py +261 -0
- unike/data/RGCNTestSampler.py +208 -0
- unike/data/RevSampler.py +78 -0
- unike/data/TestSampler.py +189 -0
- unike/data/TradSampler.py +122 -0
- unike/data/TradTestSampler.py +87 -0
- unike/data/UniSampler.py +145 -0
- unike/data/__init__.py +47 -0
- unike/module/BaseModule.py +130 -0
- unike/module/__init__.py +20 -0
- unike/module/loss/CompGCNLoss.py +96 -0
- unike/module/loss/Loss.py +26 -0
- unike/module/loss/MarginLoss.py +148 -0
- unike/module/loss/RGCNLoss.py +117 -0
- unike/module/loss/SigmoidLoss.py +145 -0
- unike/module/loss/SoftplusLoss.py +145 -0
- unike/module/loss/__init__.py +35 -0
- unike/module/model/Analogy.py +237 -0
- unike/module/model/CompGCN.py +562 -0
- unike/module/model/ComplEx.py +235 -0
- unike/module/model/DistMult.py +276 -0
- unike/module/model/HolE.py +308 -0
- unike/module/model/Model.py +107 -0
- unike/module/model/RESCAL.py +309 -0
- unike/module/model/RGCN.py +304 -0
- unike/module/model/RotatE.py +303 -0
- unike/module/model/SimplE.py +237 -0
- unike/module/model/TransD.py +458 -0
- unike/module/model/TransE.py +290 -0
- unike/module/model/TransH.py +322 -0
- unike/module/model/TransR.py +402 -0
- unike/module/model/__init__.py +60 -0
- unike/module/strategy/CompGCNSampling.py +140 -0
- unike/module/strategy/NegativeSampling.py +138 -0
- unike/module/strategy/RGCNSampling.py +134 -0
- unike/module/strategy/Strategy.py +26 -0
- unike/module/strategy/__init__.py +29 -0
- unike/utils/EarlyStopping.py +94 -0
- unike/utils/Timer.py +74 -0
- unike/utils/WandbLogger.py +46 -0
- unike/utils/__init__.py +26 -0
- unike/utils/tools.py +118 -0
- unike/version.py +1 -0
- unike-3.0.1.dist-info/METADATA +101 -0
- unike-3.0.1.dist-info/RECORD +59 -0
- unike-3.0.1.dist-info/WHEEL +4 -0
- unike-3.0.1.dist-info/entry_points.txt +2 -0
- unike-3.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,168 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/data/BernSampler.py
|
4
|
+
#
|
5
|
+
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 30, 2024
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 25, 2024
|
7
|
+
#
|
8
|
+
# 该脚本定义了 BernSampler 类.
|
9
|
+
|
10
|
+
"""
|
11
|
+
BernSampler - 平移模型和语义匹配模型的训练集数据采样器。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import torch
|
15
|
+
import typing
|
16
|
+
import random
|
17
|
+
import collections
|
18
|
+
import numpy as np
|
19
|
+
from .TradSampler import TradSampler
|
20
|
+
from typing_extensions import override
|
21
|
+
|
22
|
+
class BernSampler(TradSampler):

    """Bernoulli training-set sampler for translation and semantic-matching models.

    For every positive triple the sampler builds ``neg_ent`` negative triples
    by corrupting either the head or the tail. Which side is corrupted is
    drawn per negative from a relation-specific Bernoulli distribution derived
    from the ``tph`` / ``hpt`` statistics of the training set. For more
    details see :ref:`TransH <transh>`.
    """

    def __init__(
        self,
        in_path: str = "./",
        ent_file: str = "entity2id.txt",
        rel_file: str = "relation2id.txt",
        train_file: str = "train2id.txt",
        batch_size: int | None = None,
        neg_ent: int = 1):

        """Create a BernSampler.

        :param in_path: dataset directory
        :type in_path: str
        :param ent_file: entity2id.txt
        :type ent_file: str
        :param rel_file: relation2id.txt
        :type rel_file: str
        :param train_file: train2id.txt
        :type train_file: str
        :param batch_size: has no effect in this sampler; it is only a placeholder.
        :type batch_size: int | None
        :param neg_ent: number of negative triples built for every positive
            triple (by replacing an entity)
        :type neg_ent: int
        """

        super().__init__(
            in_path=in_path,
            ent_file=ent_file,
            rel_file=rel_file,
            train_file=train_file,
            batch_size=batch_size,
            neg_ent=neg_ent
        )

        # Per-relation statistics; computed once because they only depend on
        # the training triples loaded by the parent constructor.
        self.tph, self.hpt = self.get_tph_hpt()

    def get_tph_hpt(self) -> tuple[collections.defaultdict[int, float], collections.defaultdict[int, float]]:

        """Compute ``tph`` and ``hpt`` for every relation in the training set.

        For a relation ``r``, ``tph[r]`` is the number of training triples
        using ``r`` divided by the number of distinct heads of ``r``, and
        ``hpt[r]`` is the same count divided by the number of distinct tails.

        :returns: ``tph`` and ``hpt``, both keyed by relation
        :rtype: tuple[collections.defaultdict[int, float], collections.defaultdict[int, float]]
        """

        heads_of_rel = collections.defaultdict(set)
        tails_of_rel = collections.defaultdict(set)
        freq_rel = collections.defaultdict(float)
        for h, r, t in self.train_triples:
            freq_rel[r] += 1.0
            heads_of_rel[r].add(h)
            tails_of_rel[r].add(t)

        tph = collections.defaultdict(float)
        hpt = collections.defaultdict(float)
        for r, heads in heads_of_rel.items():
            tph[r] = freq_rel[r] / len(heads)
            hpt[r] = freq_rel[r] / len(tails_of_rel[r])
        return tph, hpt

    @override
    def sampling(
        self,
        pos_triples: list[tuple[int, int, int]]) -> dict[str, typing.Union[str, torch.Tensor]]:

        """Build one training batch with Bernoulli negative sampling.

        :param pos_triples: correct (positive) triples of the knowledge graph
        :type pos_triples: list[tuple[int, int, int]]
        :returns: a dict with the keys ``mode`` (always ``'bern'``),
            ``positive_sample`` and ``negative_sample``
        :rtype: dict[str, typing.Union[str, torch.Tensor]]
        """

        batch_data = {}
        batch_data['mode'] = 'bern'

        # Negatives are appended in the order of the positive triples,
        # ``self.neg_ent`` of them per positive.
        neg_ent_sample = []
        for h, r, t in pos_triples:
            neg_ent_sample.extend(self.__normal_batch(h, r, t, self.neg_ent))

        batch_data["positive_sample"] = torch.LongTensor(np.array(pos_triples))
        batch_data["negative_sample"] = torch.LongTensor(np.array(neg_ent_sample))

        return batch_data

    def __normal_batch(
        self,
        h: int,
        r: int,
        t: int,
        neg_size: int) -> list[tuple[int, int, int]]:

        """Build ``neg_size`` negative triples for one positive triple.

        :param h: head entity
        :type h: int
        :param r: relation
        :type r: int
        :param t: tail entity
        :type t: int
        :param neg_size: number of negative triples
        :type neg_size: int
        :returns: the negative triples; the head-corrupted ones ``(h', r, t)``
            come first, followed by the tail-corrupted ones ``(h, r, t')``
        :rtype: list[tuple[int, int, int]]
        """

        # ``.get`` avoids inserting new keys into the defaultdicts.
        tph = self.tph.get(r, 0.0)
        hpt = self.hpt.get(r, 0.0)
        denom = tph + hpt
        # A relation that never occurs in the training triples has
        # tph == hpt == 0; corrupt both sides with equal probability instead
        # of dividing by zero.
        prob = hpt / denom if denom else 0.5

        # One draw per negative: below ``prob`` corrupts the tail, otherwise
        # the head.
        neg_size_t = sum(random.random() < prob for _ in range(neg_size))
        neg_size_h = neg_size - neg_size_t

        res = [
            (hh, r, t)
            for hh in self.__collect_corrupted(self.corrupt_head, t, r, neg_size_h)
        ]
        res += [
            (h, r, tt)
            for tt in self.__collect_corrupted(self.corrupt_tail, h, r, neg_size_t)
        ]

        return res

    def __collect_corrupted(self, corrupt, ent: int, r: int, size: int):

        """Call ``corrupt`` repeatedly until ``size`` replacement entities are available.

        :param corrupt: ``self.corrupt_head`` or ``self.corrupt_tail``
        :param ent: the entity that is kept (tail for ``corrupt_head``,
            head for ``corrupt_tail``)
        :type ent: int
        :param r: relation
        :type r: int
        :param size: number of replacement entities required
        :type size: int
        :returns: exactly ``size`` replacement entities (empty when ``size`` is 0)
        """

        if size <= 0:
            return []

        chunks = []
        collected = 0
        while collected < size:
            # Ask for twice what is still missing; presumably the parent's
            # corrupt_* helpers may return fewer candidates than requested
            # (e.g. after filtering known triples) -- confirm in TradSampler.
            chunk = corrupt(ent, r, num_max=(size - collected) * 2)
            chunks.append(chunk)
            collected += len(chunk)

        return np.concatenate(chunks)[:size]
|
@@ -0,0 +1,140 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/data/CompGCNSampler.py
|
4
|
+
#
|
5
|
+
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 19, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 19, 2023
|
7
|
+
#
|
8
|
+
# 该脚本定义了 CompGCNSampler 类.
|
9
|
+
|
10
|
+
"""
|
11
|
+
CompGCNSampler - CompGCN 的数据采样器。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import dgl
|
15
|
+
import torch
|
16
|
+
import typing
|
17
|
+
import numpy as np
|
18
|
+
from .RGCNSampler import RGCNSampler
|
19
|
+
from typing_extensions import override
|
20
|
+
|
21
|
+
class CompGCNSampler(RGCNSampler):

    """Training data sampler for ``CompGCN`` :cite:`CompGCN`.

    Example::

        from unike.data import RGCNSampler, CompGCNSampler
        from torch.utils.data import DataLoader

        #: training data sampler
        train_sampler: typing.Union[typing.Type[RGCNSampler], typing.Type[CompGCNSampler]] = train_sampler(
            in_path=in_path,
            ent_file=ent_file,
            rel_file=rel_file,
            train_file=train_file,
            batch_size=batch_size,
            neg_ent=neg_ent
        )

        #: training triples
        data_train: list[tuple[int, int, int]] = train_sampler.get_train()

        train_dataloader = DataLoader(
            data_train,
            shuffle=True,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            collate_fn=train_sampler.sampling,
        )
    """

    def __init__(
        self,
        in_path: str = "./",
        ent_file: str = "entity2id.txt",
        rel_file: str = "relation2id.txt",
        train_file: str = "train2id.txt",
        batch_size: int | None = None,
        neg_ent: int = 1):

        """Create a CompGCNSampler.

        :param in_path: dataset directory
        :type in_path: str
        :param ent_file: entity2id.txt
        :type ent_file: str
        :param rel_file: relation2id.txt
        :type rel_file: str
        :param train_file: train2id.txt
        :type train_file: str
        :param batch_size: batch size
        :type batch_size: int | None
        :param neg_ent: has no effect for CompGCN.
        :type neg_ent: int
        """

        super().__init__(
            in_path=in_path,
            ent_file=ent_file,
            rel_file=rel_file,
            train_file=train_file,
            batch_size=batch_size,
            neg_ent=neg_ent
        )

        # NOTE(review): presumably prepares the ((h, r), tails) training
        # samples consumed by ``sampling`` -- defined in RGCNSampler, confirm.
        super().get_hr_train()

        # The whole training graph is built once and shared by every batch.
        # NOTE(review): the third argument is presumably the degree
        # normalisation exponent (cf. ``power`` in CompGCNTestSampler) --
        # confirm against RGCNSampler.build_graph.
        self.graph, self.relation, self.norm = \
            self.build_graph(self.ent_tol, np.array(self.t_triples).transpose(), -0.5)

    @override
    def sampling(
        self,
        pos_hr_t: list[tuple[tuple[int, int], list[int]]]) -> dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]:

        """Sampling (collate) function of ``CompGCN`` :cite:`CompGCN`.

        :param pos_hr_t: batch of ``((head, relation), tails)`` samples, where
            ``tails`` lists the correct tail entities of the pair
        :type pos_hr_t: list[tuple[tuple[int, int], list[int]]]
        :returns: training data of ``CompGCN`` :cite:`CompGCN` with the keys
            ``sample``, ``label``, ``graph``, ``relation`` and ``norm``
        :rtype: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
        """

        batch_data = {}

        # Size the label matrix from the batch actually received instead of
        # ``self.batch_size``: the latter may be None, and a final partial
        # batch (``drop_last=False``) would otherwise get a label matrix
        # whose row count does not match ``sample``.
        self.label = torch.zeros(len(pos_hr_t), self.ent_tol)
        self.triples = torch.LongTensor([hr for hr, _ in pos_hr_t])
        # Multi-hot labels: row i marks every correct tail of the i-th pair.
        for row, tails in enumerate(t for _, t in pos_hr_t):
            self.label[row][tails] = 1

        batch_data['sample'] = self.triples
        batch_data['label'] = self.label
        batch_data['graph'] = self.graph
        batch_data['relation'] = self.relation
        batch_data['norm'] = self.norm

        return batch_data

    @override
    def node_norm_to_edge_norm(
        self,
        graph: dgl.DGLGraph,
        node_norm: torch.Tensor) -> torch.Tensor:

        """Compute a normalisation coefficient for every edge.

        The coefficient of an edge is the product of the node coefficients
        of its source and destination nodes.

        :param graph: the graph whose edges are normalised; ``node_norm`` is
            stored on it as the node feature ``'norm'``
        :type graph: dgl.DGLGraph
        :param node_norm: normalisation coefficients of the nodes
        :type node_norm: torch.Tensor
        :returns: normalisation coefficients of the edges
        :rtype: torch.Tensor
        """

        graph.ndata['norm'] = node_norm
        graph.apply_edges(lambda edges: {'norm': edges.dst['norm'] * edges.src['norm']})
        # The edge feature is only a temporary; remove it from the graph and
        # drop size-1 dimensions from the result.
        norm = graph.edata.pop('norm').squeeze()
        return norm
|
@@ -0,0 +1,84 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/data/CompGCNTestSampler.py
|
4
|
+
#
|
5
|
+
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 19, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 19, 2023
|
7
|
+
#
|
8
|
+
# 该脚本定义了 CompGCNTestSampler 类.
|
9
|
+
|
10
|
+
"""
|
11
|
+
CompGCNTestSampler - CompGCN 的测试数据采样器。
|
12
|
+
"""
|
13
|
+
|
14
|
+
from .CompGCNSampler import CompGCNSampler
|
15
|
+
from .RGCNTestSampler import RGCNTestSampler
|
16
|
+
|
17
|
+
class CompGCNTestSampler(RGCNTestSampler):

    """Evaluation (validation / test) data sampler for ``CompGCN`` :cite:`CompGCN`.

    Example::

        from unike.data import RGCNTestSampler, CompGCNTestSampler
        from torch.utils.data import DataLoader

        #: test data sampler
        test_sampler: typing.Union[typing.Type[RGCNTestSampler], typing.Type[CompGCNTestSampler]] = test_sampler(
            sampler=train_sampler,
            valid_file=valid_file,
            test_file=test_file,
        )

        #: validation triples
        data_val: list[tuple[int, int, int]] = test_sampler.get_valid()
        #: test triples
        data_test: list[tuple[int, int, int]] = test_sampler.get_test()

        val_dataloader = DataLoader(
            data_val,
            shuffle=False,
            batch_size=test_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            collate_fn=test_sampler.sampling,
        )

        test_dataloader = DataLoader(
            data_test,
            shuffle=False,
            batch_size=test_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            collate_fn=test_sampler.sampling,
        )
    """

    def __init__(
        self,
        sampler: CompGCNSampler,
        valid_file: str = "valid2id.txt",
        test_file: str = "test2id.txt",
        type_constrain: bool = True):

        """Create a CompGCNTestSampler.

        :param sampler: the training data sampler
        :type sampler: CompGCNSampler
        :param valid_file: valid2id.txt
        :type valid_file: str
        :param test_file: test2id.txt
        :type test_file: str
        :param type_constrain: whether to also report results restricted by
            type_constrain.txt
        :type type_constrain: bool
        """

        # Everything is forwarded unchanged to the RGCN test sampler.
        parent_kwargs = dict(
            sampler=sampler,
            valid_file=valid_file,
            test_file=test_file,
            type_constrain=type_constrain,
        )
        super().__init__(**parent_kwargs)

        #: exponent
        self.power: float = -0.5
|
@@ -0,0 +1,315 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/data/KGEDataLoader.py
|
4
|
+
#
|
5
|
+
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2024
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Apr 27, 2024
|
7
|
+
#
|
8
|
+
# 为 KGE 模型读取数据.
|
9
|
+
|
10
|
+
"""
|
11
|
+
KGEDataLoader - KGE 模型读取数据集类。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import typing
|
15
|
+
from .UniSampler import UniSampler
|
16
|
+
from .BernSampler import BernSampler
|
17
|
+
from .RGCNSampler import RGCNSampler
|
18
|
+
from .CompGCNSampler import CompGCNSampler
|
19
|
+
from .TestSampler import TestSampler
|
20
|
+
from .TradTestSampler import TradTestSampler
|
21
|
+
from torch.utils.data import DataLoader
|
22
|
+
|
23
|
+
class KGEDataLoader:

    """Data loader for KGE models.

    Builds the training sampler (and, when ``test`` is true, the evaluation
    sampler) and exposes ``torch.utils.data.DataLoader`` objects whose
    ``collate_fn`` is the ``sampling`` method of the matching sampler.

    Example::

        from unike.data import KGEDataLoader, BernSampler, TradTestSampler

        dataloader = KGEDataLoader(
            in_path = "../../benchmarks/FB15K/",
            batch_size = 8192,
            neg_ent = 25,
            test = True,
            test_batch_size = 256,
            num_workers = 16,
            train_sampler = BernSampler,
            test_sampler = TradTestSampler
        )
    """

    def __init__(
        self,
        in_path: str = "./",
        ent_file: str = "entity2id.txt",
        rel_file: str = "relation2id.txt",
        train_file: str = "train2id.txt",
        valid_file: str = "valid2id.txt",
        test_file: str = "test2id.txt",
        batch_size: int | None = None,
        neg_ent: int = 1,
        test: bool = False,
        test_batch_size: int | None = None,
        type_constrain: bool = True,
        num_workers: int | None = None,
        train_sampler: typing.Union[typing.Type[UniSampler], typing.Type[BernSampler], typing.Type[RGCNSampler], typing.Type[CompGCNSampler]] = BernSampler,
        test_sampler: typing.Type[TestSampler] = TradTestSampler):

        """Create a KGEDataLoader.

        :param in_path: dataset directory
        :type in_path: str
        :param ent_file: entity2id.txt
        :type ent_file: str
        :param rel_file: relation2id.txt
        :type rel_file: str
        :param train_file: train2id.txt
        :type train_file: str
        :param valid_file: valid2id.txt
        :type valid_file: str
        :param test_file: test2id.txt
        :type test_file: str
        :param batch_size: batch size
        :type batch_size: int | None
        :param neg_ent: number of negative triples built for every positive
            triple (by replacing an entity); has no effect for CompGCN.
        :type neg_ent: int
        :param test: whether to read the validation and test sets
        :type test: bool
        :param test_batch_size: test batch size
        :type test_batch_size: int | None
        :param type_constrain: whether to also report results restricted by
            type_constrain.txt
        :type type_constrain: bool
        :param num_workers: number of data loading worker processes; ``None``
            is treated as 0 when the loaders are built
        :type num_workers: int | None
        :param train_sampler: training data sampler class
        :type train_sampler: typing.Union[typing.Type[UniSampler], typing.Type[BernSampler], typing.Type[RGCNSampler], typing.Type[CompGCNSampler]]
        :param test_sampler: test data sampler class
        :type test_sampler: typing.Type[TestSampler]
        """

        #: dataset directory
        self.in_path: str = in_path
        #: entity2id.txt
        self.ent_file: str = ent_file
        #: relation2id.txt
        self.rel_file: str = rel_file
        #: train2id.txt
        self.train_file: str = train_file
        #: valid2id.txt
        self.valid_file: str = valid_file
        #: test2id.txt
        self.test_file: str = test_file
        #: batch size
        self.batch_size: int | None = batch_size
        #: number of negative triples per positive triple; no effect for CompGCN
        self.neg_ent: int = neg_ent
        #: whether the validation and test sets are read
        self.test: bool = test
        #: test batch size
        self.test_batch_size: int | None = test_batch_size
        #: whether to also report results restricted by type_constrain.txt
        self.type_constrain: bool = type_constrain
        #: number of data loading worker processes (stored as passed in)
        self.num_workers: int | None = num_workers

        #: training data sampler
        self.train_sampler: typing.Union[UniSampler, BernSampler, RGCNSampler, CompGCNSampler] = train_sampler(
            in_path=self.in_path,
            ent_file=self.ent_file,
            rel_file=self.rel_file,
            train_file=self.train_file,
            batch_size=self.batch_size,
            neg_ent=self.neg_ent
        )

        #: training triples
        self.data_train: list[tuple[int, int, int]] = self.train_sampler.get_train()

        # The evaluation sampler and its data only exist when ``test`` is
        # true; ``val_dataloader`` / ``test_dataloader`` depend on them.
        if self.test:
            #: test data sampler
            self.test_sampler: TestSampler = test_sampler(
                sampler=self.train_sampler,
                valid_file=self.valid_file,
                test_file=self.test_file,
                type_constrain=self.type_constrain
            )

            #: validation triples
            self.data_val: list[tuple[int, int, int]] = self.test_sampler.get_valid()
            #: test triples
            self.data_test: list[tuple[int, int, int]] = self.test_sampler.get_test()

    def get_ent_tol(self) -> int:

        """Return the number of entities.

        :returns: number of entities
        :rtype: int
        """

        return self.train_sampler.ent_tol

    def get_rel_tol(self) -> int:

        """Return the number of relations.

        :returns: number of relations
        :rtype: int
        """

        return self.train_sampler.rel_tol

    def _build_dataloader(
        self,
        data: list,
        batch_size: int | None,
        collate_fn: typing.Callable,
        *,
        training: bool) -> DataLoader:

        """Build a DataLoader with the settings shared by all three loaders.

        :param data: samples to iterate over
        :type data: list
        :param batch_size: number of samples per batch
        :type batch_size: int | None
        :param collate_fn: ``sampling`` method of the matching sampler
        :type collate_fn: typing.Callable
        :param training: training loaders shuffle and drop the last
            incomplete batch; evaluation loaders do neither
        :type training: bool
        :returns: the configured data loader
        :rtype: torch.utils.data.DataLoader
        """

        # DataLoader rejects ``num_workers=None`` (it compares the value with
        # 0), so the constructor default is mapped to DataLoader's own
        # default of 0.
        num_workers = 0 if self.num_workers is None else self.num_workers

        return DataLoader(
            data,
            shuffle=training,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=training,
            collate_fn=collate_fn,
        )

    def train_dataloader(self) -> DataLoader:

        """Return the training data loader.

        :returns: training data loader
        :rtype: torch.utils.data.DataLoader
        """

        return self._build_dataloader(
            self.data_train,
            self.batch_size,
            self.train_sampler.sampling,
            training=True,
        )

    def val_dataloader(self) -> DataLoader:

        """Return the validation data loader.

        Requires the loader to have been created with ``test=True``.

        :returns: validation data loader
        :rtype: torch.utils.data.DataLoader
        """

        return self._build_dataloader(
            self.data_val,
            self.test_batch_size,
            self.test_sampler.sampling,
            training=False,
        )

    def test_dataloader(self) -> DataLoader:

        """Return the test data loader.

        Requires the loader to have been created with ``test=True``.

        :returns: test data loader
        :rtype: torch.utils.data.DataLoader
        """

        return self._build_dataloader(
            self.data_test,
            self.test_batch_size,
            self.test_sampler.sampling,
            training=False,
        )
|
214
|
+
|
215
|
+
def get_kge_data_loader_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyper-parameter optimisation config of :py:class:`KGEDataLoader`.

    Entries with a ``value`` key are fixed; entries with a ``values`` key
    list the candidates to search over. A new dict is built on every call.

    The default configuration is::

        parameters_dict = {
            'dataloader': {
                'value': 'KGEDataLoader'
            },
            'in_path': {
                'value': './'
            },
            'ent_file': {
                'value': 'entity2id.txt'
            },
            'rel_file': {
                'value': 'relation2id.txt'
            },
            'train_file': {
                'value': 'train2id.txt'
            },
            'valid_file': {
                'value': 'valid2id.txt'
            },
            'test_file': {
                'value': 'test2id.txt'
            },
            'batch_size': {
                'values': [512, 1024, 2048, 4096]
            },
            'neg_ent': {
                'values': [1, 4, 16, 64]
            },
            'test_batch_size': {
                'value': 30
            },
            'type_constrain': {
                'value': True
            },
            'num_workers': {
                'value': 16
            },
            'train_sampler': {
                'value': 'BernSampler'
            },
            'test_sampler': {
                'value': 'TradTestSampler'
            }
        }

    :returns: default hyper-parameter optimisation config of :py:class:`KGEDataLoader`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    # (parameter name, 'value' | 'values', payload) in the order the entries
    # appear in the resulting dict.
    spec = (
        ('dataloader', 'value', 'KGEDataLoader'),
        ('in_path', 'value', './'),
        ('ent_file', 'value', 'entity2id.txt'),
        ('rel_file', 'value', 'relation2id.txt'),
        ('train_file', 'value', 'train2id.txt'),
        ('valid_file', 'value', 'valid2id.txt'),
        ('test_file', 'value', 'test2id.txt'),
        ('batch_size', 'values', [512, 1024, 2048, 4096]),
        ('neg_ent', 'values', [1, 4, 16, 64]),
        ('test_batch_size', 'value', 30),
        ('type_constrain', 'value', True),
        ('num_workers', 'value', 16),
        ('train_sampler', 'value', 'BernSampler'),
        ('test_sampler', 'value', 'TradTestSampler'),
    )

    return {name: {kind: payload} for name, kind, payload in spec}
|