unike 3.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unike/__init__.py +5 -0
- unike/config/HPOTrainer.py +305 -0
- unike/config/Tester.py +385 -0
- unike/config/Trainer.py +519 -0
- unike/config/TrainerAccelerator.py +39 -0
- unike/config/__init__.py +37 -0
- unike/data/BernSampler.py +168 -0
- unike/data/CompGCNSampler.py +140 -0
- unike/data/CompGCNTestSampler.py +84 -0
- unike/data/KGEDataLoader.py +315 -0
- unike/data/KGReader.py +138 -0
- unike/data/RGCNSampler.py +261 -0
- unike/data/RGCNTestSampler.py +208 -0
- unike/data/RevSampler.py +78 -0
- unike/data/TestSampler.py +189 -0
- unike/data/TradSampler.py +122 -0
- unike/data/TradTestSampler.py +87 -0
- unike/data/UniSampler.py +145 -0
- unike/data/__init__.py +47 -0
- unike/module/BaseModule.py +130 -0
- unike/module/__init__.py +20 -0
- unike/module/loss/CompGCNLoss.py +96 -0
- unike/module/loss/Loss.py +26 -0
- unike/module/loss/MarginLoss.py +148 -0
- unike/module/loss/RGCNLoss.py +117 -0
- unike/module/loss/SigmoidLoss.py +145 -0
- unike/module/loss/SoftplusLoss.py +145 -0
- unike/module/loss/__init__.py +35 -0
- unike/module/model/Analogy.py +237 -0
- unike/module/model/CompGCN.py +562 -0
- unike/module/model/ComplEx.py +235 -0
- unike/module/model/DistMult.py +276 -0
- unike/module/model/HolE.py +308 -0
- unike/module/model/Model.py +107 -0
- unike/module/model/RESCAL.py +309 -0
- unike/module/model/RGCN.py +304 -0
- unike/module/model/RotatE.py +303 -0
- unike/module/model/SimplE.py +237 -0
- unike/module/model/TransD.py +458 -0
- unike/module/model/TransE.py +290 -0
- unike/module/model/TransH.py +322 -0
- unike/module/model/TransR.py +402 -0
- unike/module/model/__init__.py +60 -0
- unike/module/strategy/CompGCNSampling.py +140 -0
- unike/module/strategy/NegativeSampling.py +138 -0
- unike/module/strategy/RGCNSampling.py +134 -0
- unike/module/strategy/Strategy.py +26 -0
- unike/module/strategy/__init__.py +29 -0
- unike/utils/EarlyStopping.py +94 -0
- unike/utils/Timer.py +74 -0
- unike/utils/WandbLogger.py +46 -0
- unike/utils/__init__.py +26 -0
- unike/utils/tools.py +118 -0
- unike/version.py +1 -0
- unike-3.0.1.dist-info/METADATA +101 -0
- unike-3.0.1.dist-info/RECORD +59 -0
- unike-3.0.1.dist-info/WHEEL +4 -0
- unike-3.0.1.dist-info/entry_points.txt +2 -0
- unike-3.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,189 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/data/TestSampler.py
|
4
|
+
#
|
5
|
+
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 29, 2024
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 25, 2024
|
7
|
+
#
|
8
|
+
# 测试数据采样器基类.
|
9
|
+
|
10
|
+
"""
|
11
|
+
TestSampler - 测试数据采样器基类。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import os
|
15
|
+
import torch
|
16
|
+
import typing
|
17
|
+
from .TradSampler import TradSampler
|
18
|
+
from .RGCNSampler import RGCNSampler
|
19
|
+
from .CompGCNSampler import CompGCNSampler
|
20
|
+
from collections import defaultdict as ddict
|
21
|
+
from ..utils import construct_type_constrain
|
22
|
+
|
23
|
+
class TestSampler(object):

    """Base class of the samplers that build evaluation (validation / test) data.

    On construction the object reads the validation and test triples from the
    dataset directory of the wrapped training sampler, merges them with the
    training triples into :py:attr:`all_true_triples` and, when
    ``type_constrain`` is true, loads the per-relation entity lists stored in
    ``type_constrain.txt``.
    """

    def __init__(
        self,
        # Quoted so the annotation is resolved lazily and does not need the
        # three sampler classes at definition time.
        sampler: "typing.Union[TradSampler, RGCNSampler, CompGCNSampler]",
        valid_file: str = "valid2id.txt",
        test_file: str = "test2id.txt",
        type_constrain: bool = True):

        """Create a TestSampler.

        :param sampler: Training-data sampler; its ``in_path``, ``ent_tol``,
            ``train_file`` and ``train_triples`` attributes are read here.
        :type sampler: typing.Union[TradSampler, RGCNSampler, CompGCNSampler]
        :param valid_file: Name of the validation file, e.g. ``valid2id.txt``.
        :type valid_file: str
        :param test_file: Name of the test file, e.g. ``test2id.txt``.
        :type test_file: str
        :param type_constrain: Whether results restricted by
            ``type_constrain.txt`` should be reported as well.
        :type type_constrain: bool
        """

        #: Training-data sampler.
        self.sampler: typing.Union[TradSampler, RGCNSampler, CompGCNSampler] = sampler
        #: Number of entities.
        self.ent_tol: int = sampler.ent_tol
        #: Name of the validation file (valid2id.txt).
        self.valid_file: str = valid_file
        #: Name of the test file (test2id.txt).
        self.test_file: str = test_file

        #: Number of validation triples, as declared in the file header.
        self.valid_tol: int = 0
        #: Number of test triples, as declared in the file header.
        self.test_tol: int = 0

        #: Validation triples, stored as (h, r, t).
        self.valid_triples: list[tuple[int, int, int]] = []
        #: Test triples, stored as (h, r, t).
        self.test_triples: list[tuple[int, int, int]] = []
        #: Every triple of the knowledge graph (train + valid + test).
        self.all_true_triples: set[tuple[int, int, int]] = set()

        self.get_valid_test_triples_id()

        #: Tails known for every (h, r) pair; filled (and turned into
        #: tensors) by :py:meth:`get_hr2t_rt2h_from_all`.
        self.hr2t_all: ddict[set] = ddict(set)
        #: Heads known for every (r, t) pair; filled (and turned into
        #: tensors) by :py:meth:`get_hr2t_rt2h_from_all`.
        self.rt2h_all: ddict[set] = ddict(set)

        #: Whether results restricted by type_constrain.txt are reported.
        self.type_constrain: bool = type_constrain

        if self.type_constrain:
            # NOTE(review): presumably this helper (re)generates
            # type_constrain.txt inside in_path -- confirm in unike.utils.
            construct_type_constrain(
                in_path=self.sampler.in_path,
                train_file=self.sampler.train_file,
                valid_file=self.valid_file,
                test_file=self.test_file
            )
            #: Head entities listed for every relation.
            self.rel_heads: ddict[set] = ddict(set)
            #: Tail entities listed for every relation.
            self.rel_tails: ddict[set] = ddict(set)
            self.get_type_constrain_id()

    def _read_triples(
        self,
        file_name: str) -> tuple[int, list[tuple[int, int, int]]]:

        """Read one triple file located in the training sampler's ``in_path``.

        The first line holds the declared number of triples; every following
        non-blank line holds ``h t r`` separated by whitespace.

        :param file_name: File name relative to ``self.sampler.in_path``.
        :type file_name: str
        :returns: The declared count and the triples reordered to (h, r, t).
        :rtype: tuple[int, list[tuple[int, int, int]]]
        """

        triples = []
        path = os.path.join(self.sampler.in_path, file_name)
        with open(path, encoding="utf-8") as f:
            total = int(f.readline())
            for line in f:
                fields = line.split()
                if not fields:
                    # A blank (typically trailing) line carries no triple.
                    continue
                h, t, r = fields
                triples.append((int(h), int(r), int(t)))
        return total, triples

    def get_valid_test_triples_id(self):

        """Read :py:attr:`valid_file` and :py:attr:`test_file`.

        Fills :py:attr:`valid_tol`, :py:attr:`valid_triples`,
        :py:attr:`test_tol`, :py:attr:`test_triples` and rebuilds
        :py:attr:`all_true_triples` from the training, validation and test
        triples.
        """

        self.valid_tol, valid = self._read_triples(self.valid_file)
        self.valid_triples.extend(valid)

        self.test_tol, test = self._read_triples(self.test_file)
        self.test_triples.extend(test)

        self.all_true_triples = set(
            self.sampler.train_triples + self.valid_triples + self.test_triples
        )

    def get_type_constrain_id(self):

        """Read ``type_constrain.txt`` into :py:attr:`rel_heads` and :py:attr:`rel_tails`.

        After the header line, every line is tab separated: field 0 is the
        relation id, field 1 is skipped and the remaining fields are entity
        ids.  Lines alternate between head entities and tail entities,
        starting with heads.  Once the file is read, every collected set is
        converted to a :py:class:`torch.Tensor`.
        """

        path = os.path.join(self.sampler.in_path, "type_constrain.txt")
        with open(path, encoding="utf-8") as f:
            # Header line; only checked to be an integer, its value is unused.
            int(f.readline())
            for line_no, line in enumerate(f):
                rel_types = line.strip().split("\t")
                entities = rel_types[2:]
                if not entities:
                    continue
                # Even lines (0, 2, ...) list heads, odd lines list tails.
                target = self.rel_heads if line_no % 2 == 0 else self.rel_tails
                target[int(rel_types[0])].update(int(entity) for entity in entities)

        for rel in self.rel_heads:
            self.rel_heads[rel] = torch.tensor(list(self.rel_heads[rel]))
        for rel in self.rel_tails:
            self.rel_tails[rel] = torch.tensor(list(self.rel_tails[rel]))

    def get_hr2t_rt2h_from_all(self):

        """Build :py:attr:`hr2t_all` and :py:attr:`rt2h_all` from :py:attr:`all_true_triples`.

        The entity sets are collected first and then replaced by tensors.
        """

        for h, r, t in self.all_true_triples:
            self.hr2t_all[(h, r)].add(t)
            self.rt2h_all[(r, t)].add(h)
        for h, r in self.hr2t_all:
            self.hr2t_all[(h, r)] = torch.tensor(list(self.hr2t_all[(h, r)]))
        for r, t in self.rt2h_all:
            self.rt2h_all[(r, t)] = torch.tensor(list(self.rt2h_all[(r, t)]))

    def sampling(
        self,
        data: list[tuple[int, int, int]]) -> dict[str, torch.Tensor]:

        """Sampling function; subclasses must override it.

        :param data: Correct triples to evaluate.
        :type data: list[tuple[int, int, int]]
        :returns: Evaluation data.
        :rtype: dict[str, torch.Tensor]
        :raises NotImplementedError: Always, in this base class.
        """

        raise NotImplementedError

    def get_valid(self) -> list[tuple[int, int, int]]:

        """Return the validation triples.

        :returns: :py:attr:`valid_triples`
        :rtype: list[tuple[int, int, int]]
        """

        return self.valid_triples

    def get_test(self) -> list[tuple[int, int, int]]:

        """Return the test triples.

        :returns: :py:attr:`test_triples`
        :rtype: list[tuple[int, int, int]]
        """

        return self.test_triples

    def get_all_true_triples(self) -> set[tuple[int, int, int]]:

        """Return every triple of the knowledge graph.

        :returns: :py:attr:`all_true_triples`
        :rtype: set[tuple[int, int, int]]
        """

        return self.all_true_triples
|
@@ -0,0 +1,122 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/data/TradSampler.py
|
4
|
+
#
|
5
|
+
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 28, 2024
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 25, 2024
|
7
|
+
#
|
8
|
+
# 为 KGReader 增加构建负三元组的函数,用于平移模型和语义匹配模型.
|
9
|
+
|
10
|
+
"""
|
11
|
+
TradSampler - 为 KGReader 增加构建负三元组的函数,用于平移模型和语义匹配模型。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import torch
|
15
|
+
import typing
|
16
|
+
import numpy as np
|
17
|
+
from .KGReader import KGReader
|
18
|
+
|
19
|
+
class TradSampler(KGReader):

    """Base class of the training samplers used by translation and semantic-matching models.

    It extends :py:class:`KGReader` with helpers that draw corrupted head or
    tail entities for building negative triples.
    """

    def __init__(
        self,
        in_path: str = "./",
        ent_file: str = "entity2id.txt",
        rel_file: str = "relation2id.txt",
        train_file: str = "train2id.txt",
        batch_size: int | None = None,
        neg_ent: int = 1):

        """Create a TradSampler.

        :param in_path: Dataset directory.
        :type in_path: str
        :param ent_file: Entity file, e.g. ``entity2id.txt``.
        :type ent_file: str
        :param rel_file: Relation file, e.g. ``relation2id.txt``.
        :type rel_file: str
        :param train_file: Training file, e.g. ``train2id.txt``.
        :type train_file: str
        :param batch_size: Not used by this sampler; kept as a placeholder.
        :type batch_size: int | None
        :param neg_ent: Number of negative triples built per positive triple
            by replacing an entity.
        :type neg_ent: int
        """

        super().__init__(
            in_path=in_path,
            ent_file=ent_file,
            rel_file=rel_file,
            train_file=train_file
        )

        #: Batch size (stored only; may be ``None``).
        self.batch_size: int | None = batch_size
        #: Number of negative triples per positive triple (head + tail replacement).
        self.neg_ent: int = neg_ent

        # NOTE(review): not defined in this class; presumably inherited from
        # KGReader and responsible for filling hr2t_train / rt2h_train.
        self.get_hr2t_rt2h_from_train()

    def sampling(
        self,
        pos_triples: list[tuple[int, int, int]]) -> dict[str, typing.Union[str, torch.Tensor]]:

        """Training-data sampling function; subclasses must override it.

        :param pos_triples: Correct triples of the knowledge graph.
        :type pos_triples: list[tuple[int, int, int]]
        :returns: Training data for translation and semantic-matching models.
        :rtype: dict[str, typing.Union[str, torch.Tensor]]
        :raises NotImplementedError: Always, in this base class.
        """

        raise NotImplementedError

    def _draw_excluding(self, known, num_max: int) -> np.ndarray:

        """Draw ``num_max`` random entity ids and drop those contained in ``known``.

        The draws come from ``torch.randint`` and may repeat, so the membership
        test must not assume unique inputs.  The result can therefore hold
        fewer than ``num_max`` ids (possibly none) and may contain duplicates.
        """

        candidates = torch.randint(low=0, high=self.ent_tol, size=(num_max,)).numpy()
        keep = np.isin(candidates, known, invert=True)
        return candidates[keep]

    def corrupt_head(
        self,
        t: int,
        r: int,
        num_max: int = 1) -> np.ndarray:

        """Draw replacement head entities for building negative triples.

        :param t: Tail entity.
        :type t: int
        :param r: Relation.
        :type r: int
        :param num_max: Number of random candidates drawn in this call.
        :type num_max: int
        :returns: Candidates that are not in ``rt2h_train[(r, t)]``; at most
            ``num_max`` of them.
        :rtype: numpy.ndarray
        """

        return self._draw_excluding(self.rt2h_train[(r, t)], num_max)

    def corrupt_tail(
        self,
        h: int,
        r: int,
        num_max: int = 1) -> np.ndarray:

        """Draw replacement tail entities for building negative triples.

        :param h: Head entity.
        :type h: int
        :param r: Relation.
        :type r: int
        :param num_max: Number of random candidates drawn in this call.
        :type num_max: int
        :returns: Candidates that are not in ``hr2t_train[(h, r)]``; at most
            ``num_max`` of them.
        :rtype: numpy.ndarray
        """

        return self._draw_excluding(self.hr2t_train[(h, r)], num_max)
|
@@ -0,0 +1,87 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/data/TradTestSampler.py
|
4
|
+
#
|
5
|
+
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2024
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 29, 2024
|
7
|
+
#
|
8
|
+
# 平移模型和语义匹配模型的测试数据采样器.
|
9
|
+
|
10
|
+
"""
|
11
|
+
TradTestSampler - 平移模型和语义匹配模型的测试数据采样器。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import torch
|
15
|
+
from .TradSampler import TradSampler
|
16
|
+
from .TestSampler import TestSampler
|
17
|
+
from typing_extensions import override
|
18
|
+
|
19
|
+
class TradTestSampler(TestSampler):

    """Evaluation-data sampler for translation and semantic-matching models."""

    def __init__(
        self,
        sampler: TradSampler,
        valid_file: str = "valid2id.txt",
        test_file: str = "test2id.txt",
        type_constrain: bool = True):

        """Create a TradTestSampler.

        :param sampler: Training-data sampler.
        :type sampler: TradSampler
        :param valid_file: Validation file, e.g. ``valid2id.txt``.
        :type valid_file: str
        :param test_file: Test file, e.g. ``test2id.txt``.
        :type test_file: str
        :param type_constrain: Whether results restricted by
            ``type_constrain.txt`` should be reported as well.
        :type type_constrain: bool
        """

        super().__init__(
            sampler=sampler,
            valid_file=valid_file,
            test_file=test_file,
            type_constrain=type_constrain
        )

        # Build the (h, r) -> tails and (r, t) -> heads lookups used by sampling().
        self.get_hr2t_rt2h_from_all()

    @override
    def sampling(
        self,
        data: list[tuple[int, int, int]]) -> dict[str, torch.Tensor]:

        """Build the label tensors for a batch of evaluation triples.

        ``head_label`` / ``tail_label`` have one row per triple and one column
        per entity; a column is 1.0 when that entity is a known head for
        (rel, tail), respectively a known tail for (head, rel), and 0.0
        otherwise.  When type constraints are enabled, ``head_label_type`` /
        ``tail_label_type`` start at 1.0, are set to 0.0 for the entities
        listed for the relation in ``rel_heads`` / ``rel_tails`` and are then
        set back to 1.0 for the known heads / tails.

        :param data: Correct triples to evaluate, as (h, r, t).
        :type data: list[tuple[int, int, int]]
        :returns: Evaluation data.
        :rtype: dict[str, torch.Tensor]
        """

        rows = len(data)
        result = {}

        head_label = torch.zeros(rows, self.ent_tol)
        tail_label = torch.zeros(rows, self.ent_tol)
        for row, (h, r, t) in enumerate(data):
            head_label[row][self.rt2h_all[(r, t)]] = 1.0
            tail_label[row][self.hr2t_all[(h, r)]] = 1.0

        if self.type_constrain:
            head_type = torch.ones(rows, self.ent_tol)
            tail_type = torch.ones(rows, self.ent_tol)
            for row, (h, r, t) in enumerate(data):
                # Order matters: clear the listed entities first, then
                # re-mark the known true answers.
                head_type[row][self.rel_heads[r]] = 0.0
                tail_type[row][self.rel_tails[r]] = 0.0
                head_type[row][self.rt2h_all[(r, t)]] = 1.0
                tail_type[row][self.hr2t_all[(h, r)]] = 1.0
            result["head_label_type"] = head_type
            result["tail_label_type"] = tail_type

        result["positive_sample"] = torch.tensor(data)
        result["head_label"] = head_label
        result["tail_label"] = tail_label
        return result
|
unike/data/UniSampler.py
ADDED
@@ -0,0 +1,145 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/data/UniSampler.py
|
4
|
+
#
|
5
|
+
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2024
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 29, 2024
|
7
|
+
#
|
8
|
+
# 平移模型和语义匹配模型的训练集数据采样器.
|
9
|
+
|
10
|
+
"""
|
11
|
+
UniSampler - 平移模型和语义匹配模型的训练集数据采样器。
|
12
|
+
"""
|
13
|
+
|
14
|
+
import torch
|
15
|
+
import typing
|
16
|
+
import warnings
|
17
|
+
import numpy as np
|
18
|
+
from .TradSampler import TradSampler
|
19
|
+
from typing_extensions import override
|
20
|
+
|
21
|
+
# Installs an "ignore" filter with no category or module restriction, so every
# warning is suppressed from the moment this module is imported.
# NOTE(review): this is a process-wide side effect of the import -- confirm it
# is intended rather than scoped to the calls that actually warn.
warnings.filterwarnings("ignore")
|
22
|
+
|
23
|
+
class UniSampler(TradSampler):

    """Plain training sampler (uniform distribution) for translation and semantic-matching models."""

    def __init__(
        self,
        in_path: str = "./",
        ent_file: str = "entity2id.txt",
        rel_file: str = "relation2id.txt",
        train_file: str = "train2id.txt",
        batch_size: int | None = None,
        neg_ent: int = 1):

        """Create a UniSampler.

        :param in_path: Dataset directory.
        :type in_path: str
        :param ent_file: Entity file, e.g. ``entity2id.txt``.
        :type ent_file: str
        :param rel_file: Relation file, e.g. ``relation2id.txt``.
        :type rel_file: str
        :param train_file: Training file, e.g. ``train2id.txt``.
        :type train_file: str
        :param batch_size: Not used by this sampler; kept as a placeholder.
        :type batch_size: int | None
        :param neg_ent: Number of negative triples built per positive triple
            by replacing an entity.
        :type neg_ent: int
        """

        super().__init__(
            in_path=in_path,
            ent_file=ent_file,
            rel_file=rel_file,
            train_file=train_file,
            batch_size=batch_size,
            neg_ent=neg_ent
        )

        #: Flipped between 0 and 1 on every call of :py:meth:`sampling`.
        self.cross_sampling_flag = 0

    @override
    def sampling(
        self,
        pos_triples: list[tuple[int, int, int]]) -> dict[str, typing.Union[str, torch.Tensor]]:

        """Build one training batch with uniformly drawn negative entities.

        Consecutive calls alternate between corrupting tails and corrupting
        heads; the first call on a new object produces a ``"tail-batch"``.

        :param pos_triples: Correct triples of the knowledge graph, as (h, r, t).
        :type pos_triples: list[tuple[int, int, int]]
        :returns: ``mode`` (``"head-batch"`` or ``"tail-batch"``),
            ``positive_sample`` and ``negative_sample``.
        :rtype: dict[str, typing.Union[str, torch.Tensor]]
        """

        batch_data = {}
        self.cross_sampling_flag = 1 - self.cross_sampling_flag
        if self.cross_sampling_flag == 0:
            batch_data['mode'] = "head-batch"
            neg_ent_sample = [
                self.head_batch(t, r, self.neg_ent) for _, r, t in pos_triples
            ]
        else:
            batch_data['mode'] = "tail-batch"
            neg_ent_sample = [
                self.tail_batch(h, r, self.neg_ent) for h, r, _ in pos_triples
            ]

        batch_data["positive_sample"] = torch.LongTensor(np.array(pos_triples))
        batch_data['negative_sample'] = torch.LongTensor(np.array(neg_ent_sample))
        return batch_data

    def _collect_negatives(
        self,
        corrupt: typing.Callable[..., np.ndarray],
        entity: int,
        r: int,
        neg_size: int | None) -> np.ndarray:

        """Call ``corrupt`` repeatedly until ``neg_size`` entities are collected.

        ``None`` falls back to :py:attr:`neg_ent`; a non-positive size yields
        an empty array instead of failing inside ``np.concatenate``.
        """

        if neg_size is None:
            neg_size = self.neg_ent
        if neg_size <= 0:
            return np.empty(0, dtype=np.int64)

        chunks = []
        collected = 0
        # Ask for twice the missing amount because part of every draw is
        # filtered out as already-known entities.
        # NOTE(review): this never terminates if every entity is already a
        # known answer for the pair -- confirm that cannot happen in practice.
        while collected < neg_size:
            chunk = corrupt(entity, r, num_max=(neg_size - collected) * 2)
            chunks.append(chunk)
            collected += len(chunk)
        return np.concatenate(chunks)[:neg_size]

    def head_batch(
        self,
        t: int,
        r: int,
        neg_size: int | None = None) -> np.ndarray:

        """Build negative triples by replacing the head entity.

        :param t: Tail entity.
        :type t: int
        :param r: Relation.
        :type r: int
        :param neg_size: Number of negative heads wanted; ``None`` means
            :py:attr:`neg_ent`.
        :type neg_size: int | None
        :returns: Head entities of the negative triples.
        :rtype: numpy.ndarray
        """

        return self._collect_negatives(self.corrupt_head, t, r, neg_size)

    def tail_batch(
        self,
        h: int,
        r: int,
        neg_size: int | None = None) -> np.ndarray:

        """Build negative triples by replacing the tail entity.

        :param h: Head entity.
        :type h: int
        :param r: Relation.
        :type r: int
        :param neg_size: Number of negative tails wanted; ``None`` means
            :py:attr:`neg_ent`.
        :type neg_size: int | None
        :returns: Tail entities of the negative triples.
        :rtype: numpy.ndarray
        """

        return self._collect_negatives(self.corrupt_tail, h, r, neg_size)
|
unike/data/__init__.py
ADDED
@@ -0,0 +1,47 @@
|
|
1
|
+
# coding:utf-8
|
2
|
+
#
|
3
|
+
# unike/data/__init__.py
|
4
|
+
#
|
5
|
+
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
|
6
|
+
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 29, 2023
|
7
|
+
#
|
8
|
+
# 该头文件定义了 data 接口.
|
9
|
+
|
10
|
+
"""数据采样部分,包含为训练和验证模型定义的数据采样器。"""
|
11
|
+
|
12
|
+
from __future__ import absolute_import
|
13
|
+
from __future__ import division
|
14
|
+
from __future__ import print_function
|
15
|
+
|
16
|
+
from .KGReader import KGReader
|
17
|
+
|
18
|
+
from .TradSampler import TradSampler
|
19
|
+
from .UniSampler import UniSampler
|
20
|
+
from .BernSampler import BernSampler
|
21
|
+
|
22
|
+
from .RevSampler import RevSampler
|
23
|
+
from .RGCNSampler import RGCNSampler
|
24
|
+
from .CompGCNSampler import CompGCNSampler
|
25
|
+
|
26
|
+
from .TestSampler import TestSampler
|
27
|
+
from .TradTestSampler import TradTestSampler
|
28
|
+
from .RGCNTestSampler import RGCNTestSampler
|
29
|
+
from .CompGCNTestSampler import CompGCNTestSampler
|
30
|
+
|
31
|
+
from .KGEDataLoader import KGEDataLoader, get_kge_data_loader_hpo_config
|
32
|
+
|
33
|
+
# Names exported by ``from unike.data import *``.
__all__ = [
    # reader
    "KGReader",
    # training samplers
    "TradSampler",
    "UniSampler",
    "BernSampler",
    "RevSampler",
    "RGCNSampler",
    "CompGCNSampler",
    # evaluation samplers
    "TestSampler",
    "TradTestSampler",
    "RGCNTestSampler",
    "CompGCNTestSampler",
    # data loader
    "KGEDataLoader",
    "get_kge_data_loader_hpo_config",
]
|