unike 3.0.1 (unike-3.0.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. unike/__init__.py +5 -0
  2. unike/config/HPOTrainer.py +305 -0
  3. unike/config/Tester.py +385 -0
  4. unike/config/Trainer.py +519 -0
  5. unike/config/TrainerAccelerator.py +39 -0
  6. unike/config/__init__.py +37 -0
  7. unike/data/BernSampler.py +168 -0
  8. unike/data/CompGCNSampler.py +140 -0
  9. unike/data/CompGCNTestSampler.py +84 -0
  10. unike/data/KGEDataLoader.py +315 -0
  11. unike/data/KGReader.py +138 -0
  12. unike/data/RGCNSampler.py +261 -0
  13. unike/data/RGCNTestSampler.py +208 -0
  14. unike/data/RevSampler.py +78 -0
  15. unike/data/TestSampler.py +189 -0
  16. unike/data/TradSampler.py +122 -0
  17. unike/data/TradTestSampler.py +87 -0
  18. unike/data/UniSampler.py +145 -0
  19. unike/data/__init__.py +47 -0
  20. unike/module/BaseModule.py +130 -0
  21. unike/module/__init__.py +20 -0
  22. unike/module/loss/CompGCNLoss.py +96 -0
  23. unike/module/loss/Loss.py +26 -0
  24. unike/module/loss/MarginLoss.py +148 -0
  25. unike/module/loss/RGCNLoss.py +117 -0
  26. unike/module/loss/SigmoidLoss.py +145 -0
  27. unike/module/loss/SoftplusLoss.py +145 -0
  28. unike/module/loss/__init__.py +35 -0
  29. unike/module/model/Analogy.py +237 -0
  30. unike/module/model/CompGCN.py +562 -0
  31. unike/module/model/ComplEx.py +235 -0
  32. unike/module/model/DistMult.py +276 -0
  33. unike/module/model/HolE.py +308 -0
  34. unike/module/model/Model.py +107 -0
  35. unike/module/model/RESCAL.py +309 -0
  36. unike/module/model/RGCN.py +304 -0
  37. unike/module/model/RotatE.py +303 -0
  38. unike/module/model/SimplE.py +237 -0
  39. unike/module/model/TransD.py +458 -0
  40. unike/module/model/TransE.py +290 -0
  41. unike/module/model/TransH.py +322 -0
  42. unike/module/model/TransR.py +402 -0
  43. unike/module/model/__init__.py +60 -0
  44. unike/module/strategy/CompGCNSampling.py +140 -0
  45. unike/module/strategy/NegativeSampling.py +138 -0
  46. unike/module/strategy/RGCNSampling.py +134 -0
  47. unike/module/strategy/Strategy.py +26 -0
  48. unike/module/strategy/__init__.py +29 -0
  49. unike/utils/EarlyStopping.py +94 -0
  50. unike/utils/Timer.py +74 -0
  51. unike/utils/WandbLogger.py +46 -0
  52. unike/utils/__init__.py +26 -0
  53. unike/utils/tools.py +118 -0
  54. unike/version.py +1 -0
  55. unike-3.0.1.dist-info/METADATA +101 -0
  56. unike-3.0.1.dist-info/RECORD +59 -0
  57. unike-3.0.1.dist-info/WHEEL +4 -0
  58. unike-3.0.1.dist-info/entry_points.txt +2 -0
  59. unike-3.0.1.dist-info/licenses/LICENSE +21 -0
unike/data/KGReader.py ADDED
@@ -0,0 +1,138 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/data/KGReader.py
4
+ #
5
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 17, 2024
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 25, 2024
7
+ #
8
+ # 从文件中读取知识图谱.
9
+
10
+ """
11
+ KGReader - 从文件中读取知识图谱。
12
+ """
13
+
14
+ import os
15
+ import numpy as np
16
+ import collections
17
+
18
class KGReader:

    """Read a knowledge graph from files.

    :py:attr:`ent_file` and :py:attr:`rel_file` start with a line holding the
    number of records, followed by one tab-separated ``name<TAB>id`` pair per
    line. :py:attr:`train_file` starts with the number of triples, followed by
    one whitespace-separated ``head tail relation`` triple per line (note the
    on-disk order: the relation comes last). Blank lines are skipped.
    """

    def __init__(
        self,
        in_path: str = "./",
        ent_file: str = "entity2id.txt",
        rel_file: str = "relation2id.txt",
        train_file: str = "train2id.txt"):

        """Create a KGReader and load the entity, relation and training files.

        :param in_path: dataset directory
        :type in_path: str
        :param ent_file: entity file name (entity2id.txt)
        :type ent_file: str
        :param rel_file: relation file name (relation2id.txt)
        :type rel_file: str
        :param train_file: training triple file name (train2id.txt)
        :type train_file: str
        """

        #: dataset directory
        self.in_path: str = in_path
        #: entity file name (entity2id.txt)
        self.ent_file: str = ent_file
        #: relation file name (relation2id.txt)
        self.rel_file: str = rel_file
        #: training triple file name (train2id.txt)
        self.train_file: str = train_file

        #: number of entities, as declared on the first line of ent_file
        self.ent_tol: int = 0
        #: number of relations, as declared on the first line of rel_file
        self.rel_tol: int = 0
        #: number of training triples, as declared on the first line of train_file
        self.train_tol: int = 0

        #: entity name -> ID
        self.ent2id: dict = {}
        #: relation name -> ID
        self.rel2id: dict = {}
        #: ID -> entity name
        self.id2ent: dict = {}
        #: ID -> relation name
        self.id2rel: dict = {}

        #: training triples as (head, relation, tail)
        self.train_triples: list[tuple[int, int, int]] = []

        #: tails of every (head, relation) pair in the training set
        #: (numpy arrays once get_hr2t_rt2h_from_train has run)
        self.hr2t_train: collections.defaultdict = collections.defaultdict(set)
        #: heads of every (relation, tail) pair in the training set
        #: (numpy arrays once get_hr2t_rt2h_from_train has run)
        self.rt2h_train: collections.defaultdict = collections.defaultdict(set)

        self.get_id()
        self.get_train_triples_id()

    def get_id(self):

        """Read :py:attr:`ent_file` and :py:attr:`rel_file`.

        Fills :py:attr:`ent_tol`, :py:attr:`ent2id`, :py:attr:`id2ent`,
        :py:attr:`rel_tol`, :py:attr:`rel2id` and :py:attr:`id2rel`.
        """

        with open(os.path.join(self.in_path, self.ent_file)) as f:
            self.ent_tol = int(f.readline())
            for line in f:
                line = line.strip()
                if not line:
                    continue
                entity, eid = line.split("\t")
                eid = int(eid)
                self.ent2id[entity] = eid
                self.id2ent[eid] = entity

        with open(os.path.join(self.in_path, self.rel_file)) as f:
            self.rel_tol = int(f.readline())
            for line in f:
                line = line.strip()
                if not line:
                    continue
                relation, rid = line.split("\t")
                rid = int(rid)
                self.rel2id[relation] = rid
                self.id2rel[rid] = relation

    def get_train_triples_id(self):

        """Read :py:attr:`train_file` into :py:attr:`train_triples`.

        Lines are stored on disk as ``head tail relation`` and are reordered
        to ``(head, relation, tail)``.
        """

        with open(os.path.join(self.in_path, self.train_file)) as f:
            self.train_tol = int(f.readline())
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                h, t, r = fields
                self.train_triples.append((int(h), int(r), int(t)))

    def get_hr2t_rt2h_from_train(self):

        """Populate :py:attr:`hr2t_train` and :py:attr:`rt2h_train`.

        Every value ends up as a numpy array of the distinct entity IDs seen
        for its key in :py:attr:`train_triples`. Existing entries (sets, or
        arrays from an earlier call) are merged, so calling this more than
        once is safe.
        """

        hr2t = collections.defaultdict(set)
        rt2h = collections.defaultdict(set)
        # Seed with whatever is already stored; arrays iterate as their elements.
        for key, tails in self.hr2t_train.items():
            hr2t[key].update(tails)
        for key, heads in self.rt2h_train.items():
            rt2h[key].update(heads)

        for h, r, t in self.train_triples:
            hr2t[(h, r)].add(t)
            rt2h[(r, t)].add(h)

        # Write back into the existing dict objects so outside references stay valid.
        for key, tails in hr2t.items():
            self.hr2t_train[key] = np.array(list(tails))
        for key, heads in rt2h.items():
            self.rt2h_train[key] = np.array(list(heads))

    def get_hr_train(self):

        """Regroup the training set by (head, relation) for ``CompGCN`` :cite:`CompGCN`.

        ``CompGCN`` :cite:`CompGCN`'s composition operation only needs the head
        entity and the relation. The original triples are kept in
        ``t_triples``, and :py:attr:`train_triples` becomes a list of
        ``((head, relation), [tails])``. This reads :py:attr:`hr2t_train`, so it
        is only meaningful after :py:meth:`get_hr2t_rt2h_from_train` has run.

        For details see :ref:`CompGCN <compgcn>`.
        """

        self.t_triples = self.train_triples
        self.train_triples = [(hr, list(t)) for hr, t in self.hr2t_train.items()]

    def get_train(self) -> list[tuple[int, int, int]]:

        """
        Return the training triples.

        :returns: :py:attr:`train_triples`
        :rtype: list[tuple[int, int, int]]
        """

        return self.train_triples
unike/data/RGCNSampler.py ADDED
@@ -0,0 +1,261 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/data/RGCNSampler.py
4
+ #
5
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2024
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 21, 2024
7
+ #
8
+ # R-GCN 的数据采样器.
9
+
10
+ """
11
+ RGCNSampler - R-GCN 的数据采样器。
12
+ """
13
+
14
+ import dgl
15
+ import torch
16
+ import typing
17
+ import warnings
18
+ import numpy as np
19
+ from .RevSampler import RevSampler
20
+
21
+ warnings.filterwarnings("ignore")
22
+
23
class RGCNSampler(RevSampler):

    """Training data sampler for ``R-GCN`` :cite:`R-GCN`.

    Example::

        from unike.data import RGCNSampler, CompGCNSampler
        from torch.utils.data import DataLoader

        #: training data sampler
        train_sampler: typing.Union[typing.Type[RGCNSampler], typing.Type[CompGCNSampler]] = train_sampler(
            in_path=in_path,
            ent_file=ent_file,
            rel_file=rel_file,
            train_file=train_file,
            batch_size=batch_size,
            neg_ent=neg_ent
        )

        #: training triples
        data_train: list[tuple[int, int, int]] = train_sampler.get_train()

        train_dataloader = DataLoader(
            data_train,
            shuffle=True,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=True,
            collate_fn=train_sampler.sampling,
        )
    """

    def __init__(
        self,
        in_path: str = "./",
        ent_file: str = "entity2id.txt",
        rel_file: str = "relation2id.txt",
        train_file: str = "train2id.txt",
        batch_size: int | None = None,
        neg_ent: int = 1):

        """Create an RGCNSampler.

        :param in_path: dataset directory
        :type in_path: str
        :param ent_file: entity file name (entity2id.txt)
        :type ent_file: str
        :param rel_file: relation file name (relation2id.txt)
        :type rel_file: str
        :param train_file: training triple file name (train2id.txt)
        :type train_file: str
        :param batch_size: number of positive triples drawn per call to
            :py:meth:`sampling`; :py:meth:`sampling` needs an int here
        :type batch_size: int | None
        :param neg_ent: number of negative triples built per positive triple,
            for each of head and tail corruption
        :type neg_ent: int
        """

        super().__init__(
            in_path=in_path,
            ent_file=ent_file,
            rel_file=rel_file,
            train_file=train_file
        )

        #: batch size
        self.batch_size: int = batch_size
        #: number of negatives built per positive triple, for each of head and tail corruption
        self.neg_ent: int = neg_ent

        # Per-batch state, overwritten by every call to sampling().
        self.entity = None
        self.triples = None
        self.label = None
        self.graph = None
        self.relation = None
        self.norm = None

    def sampling(
        self,
        pos_triples: list[tuple[int, int, int]]) -> dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]:

        """Sampling function for ``R-GCN`` :cite:`R-GCN`.

        Builds, for one batch, the positive + negative triples (with entity IDs
        remapped to the batch's local subgraph), their labels, and a
        message-passing graph built from half of the positive edges.

        :param pos_triples: positive triples from the knowledge graph, as
            (head, relation, tail)
        :type pos_triples: list[tuple[int, int, int]]
        :returns: training data with keys ``entity``, ``triples``, ``label``,
            ``graph``, ``relation`` and ``norm``
        :rtype: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
        """

        batch_data = {}

        pos_triples = np.array(pos_triples)
        pos_triples, self.entity = self.sampling_positive(pos_triples)
        head_triples = self.sampling_negative('head', pos_triples)
        tail_triples = self.sampling_negative('tail', pos_triples)
        self.triples = np.concatenate((pos_triples, head_triples, tail_triples))
        batch_data['entity'] = self.entity
        batch_data['triples'] = torch.from_numpy(self.triples)

        # sampling_positive returns exactly batch_size rows and they come first
        # in self.triples, so only those are labelled positive.
        self.label = torch.zeros((len(self.triples), 1))
        self.label[0: self.batch_size] = 1
        batch_data['label'] = self.label

        # The message-passing graph only uses a random half of the positive edges.
        split_size = int(self.batch_size * 0.5)
        graph_split_ids = np.random.choice(
            self.batch_size,
            size=split_size,
            replace=False
        )
        head, rela, tail = pos_triples.transpose()
        head = torch.tensor(head[graph_split_ids], dtype=torch.long).contiguous()
        rela = torch.tensor(rela[graph_split_ids], dtype=torch.long).contiguous()
        tail = torch.tensor(tail[graph_split_ids], dtype=torch.long).contiguous()
        self.graph, self.relation, self.norm = self.build_graph(len(self.entity), (head, rela, tail), -1)
        batch_data['graph'] = self.graph
        batch_data['relation'] = self.relation
        batch_data['norm'] = self.norm

        return batch_data

    def sampling_positive(
        self,
        positive_triples: list[tuple[int, int, int]]) -> tuple[np.ndarray, torch.Tensor]:

        """Resample a subset of triples for the subgraph and relabel its entity IDs.

        :param positive_triples: positive triples; must support numpy fancy
            indexing (``sampling`` passes a numpy array) and contain at least
            :py:attr:`batch_size` rows
        :type positive_triples: list[tuple[int, int, int]]
        :returns: the sampled triples with head/tail remapped to ``0..k-1``,
            and a ``(k, 1)`` long tensor holding the original entity ID of each
            local index
        :rtype: tuple[numpy.ndarray, torch.Tensor]
        """

        edges = np.random.choice(
            np.arange(len(positive_triples)),
            size=self.batch_size,
            replace=False
        )
        edges = positive_triples[edges]
        head, rela, tail = np.array(edges).transpose()
        # np.unique over heads and tails yields the sorted original IDs; the
        # inverse indices are the new, local IDs (reshape covers both numpy
        # 1.x flat and 2.x shaped inverse outputs).
        entity, index = np.unique((head, tail), return_inverse=True)
        head, tail = np.reshape(index, (2, -1))

        return np.stack((head, rela, tail)).transpose(), \
            torch.from_numpy(entity).view(-1, 1).long()

    def sampling_negative(
        self,
        mode: str,
        pos_triples: list[tuple[int, int, int]]) -> np.ndarray:

        """Sample negative triples by corrupting heads or tails.

        Each positive triple is repeated :py:attr:`neg_ent` times and the
        chosen side is replaced with a random local entity index drawn from
        ``range(len(self.entity))``.

        :param mode: ``'head'`` or ``'tail'``
        :type mode: str
        :param pos_triples: positive triples (numpy array with columns head, relation, tail)
        :type pos_triples: list[tuple[int, int, int]]
        :returns: negative triples
        :rtype: numpy.ndarray
        :raises ValueError: if ``mode`` is neither ``'head'`` nor ``'tail'``
        """

        # Previously an unknown mode silently returned unmodified positives.
        if mode not in ('head', 'tail'):
            raise ValueError(f"mode must be 'head' or 'tail', got {mode!r}")

        neg_random = np.random.choice(
            len(self.entity),
            size=self.neg_ent * len(pos_triples)
        )
        neg_samples = np.tile(pos_triples, (self.neg_ent, 1))
        if mode == 'head':
            neg_samples[:, 0] = neg_random
        else:
            neg_samples[:, 2] = neg_random
        return neg_samples

    def build_graph(
        self,
        num_ent: int,
        triples: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        power: int = -1) -> tuple[dgl.DGLGraph, torch.Tensor, torch.Tensor]:

        """Build a subgraph.

        :param num_ent: number of nodes in the subgraph
        :type num_ent: int
        :param triples: heads, relations and tails of the edges (tensors or
            numpy rows)
        :type triples: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        :param power: exponent applied to each destination node's in-degree
        :type power: int
        :returns: the subgraph, a copy of the relation IDs, and per-edge
            normalisation coefficients
        :rtype: tuple[dgl.DGLGraph, torch.Tensor, torch.Tensor]
        """

        head, rela, tail = triples[0], triples[1], triples[2]
        graph = dgl.graph(([], []))
        graph.add_nodes(num_ent)
        graph.add_edges(head, tail)
        node_norm = self.comp_deg_norm(graph, power)
        edge_norm = self.node_norm_to_edge_norm(graph, node_norm)
        # torch.tensor(<Tensor>) warns about copy-construction; clone().detach()
        # produces the same independent copy without the warning.
        if isinstance(rela, torch.Tensor):
            rela = rela.clone().detach()
        else:
            rela = torch.tensor(rela)
        return graph, rela, edge_norm

    def comp_deg_norm(
        self,
        graph: dgl.DGLGraph,
        power: int = -1) -> torch.Tensor:

        """Compute each node's normalisation coefficient from its in-degree.

        :param graph: subgraph
        :type graph: dgl.DGLGraph
        :param power: exponent applied to the in-degree
        :type power: int
        :returns: per-node coefficients ``in_degree ** power``; infinite values
            (zero in-degree with a negative power) are set to 0
        :rtype: torch.Tensor
        """

        graph = graph.local_var()
        in_deg = graph.in_degrees(range(graph.number_of_nodes())).float().numpy()
        # 0 ** negative -> inf (cleared below); keep numpy quiet about it locally.
        with np.errstate(divide='ignore'):
            norm = in_deg ** power
        norm[np.isinf(norm)] = 0
        return torch.from_numpy(norm)

    def node_norm_to_edge_norm(
        self,
        graph: dgl.DGLGraph,
        node_norm: torch.Tensor) -> torch.Tensor:

        """Give every edge the normalisation coefficient of its destination node.

        :param graph: subgraph
        :type graph: dgl.DGLGraph
        :param node_norm: per-node coefficients
        :type node_norm: torch.Tensor
        :returns: per-edge coefficients, shape ``(num_edges, 1)``
        :rtype: torch.Tensor
        """

        graph = graph.local_var()
        # Convert the node norm into an edge norm by copying it from each edge's destination.
        graph.ndata['norm'] = node_norm.view(-1, 1)
        graph.apply_edges(lambda edges: {'norm': edges.dst['norm']})
        return graph.edata['norm']
unike/data/RGCNTestSampler.py ADDED
@@ -0,0 +1,208 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/data/RGCNTestSampler.py
4
+ #
5
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2024
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 29, 2024
7
+ #
8
+ # R-GCN 的测试数据采样器.
9
+
10
+ """
11
+ RGCNTestSampler - R-GCN 的测试数据采样器。
12
+ """
13
+
14
+ import os
15
+ import dgl
16
+ import torch
17
+ import typing
18
+ import numpy as np
19
+ from .TestSampler import TestSampler
20
+ from .RGCNSampler import RGCNSampler
21
+ from .CompGCNSampler import CompGCNSampler
22
+ from typing_extensions import override
23
+
24
class RGCNTestSampler(TestSampler):

    """Test data sampler for ``R-GCN`` :cite:`R-GCN`.

    Example::

        from unike.data import RGCNTestSampler, CompGCNTestSampler
        from torch.utils.data import DataLoader

        #: test data sampler
        test_sampler: typing.Union[typing.Type[RGCNTestSampler], typing.Type[CompGCNTestSampler]] = test_sampler(
            sampler=train_sampler,
            valid_file=valid_file,
            test_file=test_file,
        )

        #: validation triples
        data_val: list[tuple[int, int, int]] = test_sampler.get_valid()
        #: test triples
        data_test: list[tuple[int, int, int]] = test_sampler.get_test()

        val_dataloader = DataLoader(
            data_val,
            shuffle=False,
            batch_size=test_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            collate_fn=test_sampler.sampling,
        )

        test_dataloader = DataLoader(
            data_test,
            shuffle=False,
            batch_size=test_batch_size,
            num_workers=num_workers,
            pin_memory=True,
            collate_fn=test_sampler.sampling,
        )
    """

    def __init__(
        self,
        sampler: typing.Union[RGCNSampler, CompGCNSampler],
        valid_file: str = "valid2id.txt",
        test_file: str = "test2id.txt",
        type_constrain: bool = True):

        """Create an RGCNTestSampler.

        :param sampler: the training data sampler
        :type sampler: typing.Union[RGCNSampler, CompGCNSampler]
        :param valid_file: validation triple file name (valid2id.txt)
        :type valid_file: str
        :param test_file: test triple file name (test2id.txt)
        :type test_file: str
        :param type_constrain: whether to also report results restricted by type_constrain.txt
        :type type_constrain: bool
        """

        super().__init__(
            sampler=sampler,
            valid_file=valid_file,
            test_file=test_file,
            type_constrain=type_constrain
        )

        #: training triples; for a CompGCNSampler the original triples are read
        #: from ``t_triples`` because its train_triples may have been regrouped
        #: by (head, relation)
        self.triples: list[tuple[int, int, int]] = self.sampler.t_triples if isinstance(self.sampler, CompGCNSampler) else self.sampler.train_triples
        #: exponent applied to node in-degrees when normalising the graph
        self.power: float = -1

        self.add_valid_test_reverse_triples()
        self.get_hr2t_rt2h_from_all()

    def _parse_triples_file(self, file_name: str) -> tuple[int, list[tuple[int, int, int]]]:

        """Parse a triple file from the training sampler's dataset directory.

        The first line holds the declared number of triples; every following
        non-blank line is ``head tail relation`` separated by whitespace.

        :param file_name: file name relative to ``sampler.in_path``
        :type file_name: str
        :returns: the declared count and the triples as ``(head, relation, tail)``
        :rtype: tuple[int, list[tuple[int, int, int]]]
        """

        triples = []
        with open(os.path.join(self.sampler.in_path, file_name)) as f:
            count = int(f.readline())
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                h, t, r = fields
                triples.append((int(h), int(r), int(t)))
        return count, triples

    @override
    def get_valid_test_triples_id(self):

        """Read :py:attr:`valid_file` and :py:attr:`test_file`.

        Sets ``valid_tol``/``test_tol`` from the file headers and appends the
        triples to ``valid_triples``/``test_triples``.
        """

        self.valid_tol, valid_triples = self._parse_triples_file(self.valid_file)
        self.valid_triples.extend(valid_triples)

        self.test_tol, test_triples = self._parse_triples_file(self.test_file)
        self.test_triples.extend(test_triples)

    def add_valid_test_reverse_triples(self):

        """Add a reverse triple (t, r', h) for every validation/test triple (h, r, t).

        ``r' = r + sampler.rel_tol // 2``. ``sampler.rel_tol`` is expected to
        already count the reverse relations (RevSampler doubles it), so half of
        it is the number of original relations. Afterwards
        ``all_true_triples`` is rebuilt from the training, validation and test
        triples.
        """

        tol = self.sampler.rel_tol // 2

        for file_name, target in ((self.valid_file, self.valid_triples),
                                  (self.test_file, self.test_triples)):
            _, triples = self._parse_triples_file(file_name)
            target.extend((t, r + tol, h) for h, r, t in triples)

        self.all_true_triples = set(
            self.triples + self.valid_triples + self.test_triples
        )

    @override
    def get_type_constrain_id(self):

        """Read type_constrain.txt into ``rel_heads`` / ``rel_tails``.

        After the header line, lines alternate per relation: the first lists
        the allowed heads, the second the allowed tails (fields are
        tab-separated: relation, a field that is skipped, then entity IDs).
        A relation's heads are also recorded as the tails of its reverse
        relation ``r + sampler.rel_tol // 2``, and vice versa. Each set is
        finally converted to a tensor.
        """

        tol = self.sampler.rel_tol // 2

        with open(os.path.join(self.sampler.in_path, "type_constrain.txt")) as f:
            f.readline()  # header: number of relations (not needed here)
            is_head_line = True
            for line in f:
                rel_types = line.strip().split("\t")
                entities = [int(entity) for entity in rel_types[2:]]
                if entities:
                    rel = int(rel_types[0])
                    if is_head_line:
                        self.rel_heads[rel].update(entities)
                        self.rel_tails[rel + tol].update(entities)
                    else:
                        self.rel_tails[rel].update(entities)
                        self.rel_heads[rel + tol].update(entities)
                # Parity flips on every line, including lines with no entities.
                is_head_line = not is_head_line

        for rel in self.rel_heads:
            self.rel_heads[rel] = torch.tensor(list(self.rel_heads[rel]))
        for rel in self.rel_tails:
            self.rel_tails[rel] = torch.tensor(list(self.rel_tails[rel]))

    @override
    def sampling(
        self,
        data: list[tuple[int, int, int]]) -> dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]:

        """Test-time sampling function for ``R-GCN`` :cite:`R-GCN`.

        :param data: positive test triples as (head, relation, tail)
        :type data: list[tuple[int, int, int]]
        :returns: test data: label matrices of shape ``(len(data), ent_tol)``,
            the positive triples, and the graph built from all training triples
        :rtype: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
        """

        batch_data = {}
        num_triples = len(data)

        # Mark every known true head / tail for each query.
        head_label = torch.zeros(num_triples, self.ent_tol)
        tail_label = torch.zeros(num_triples, self.ent_tol)
        for idx, (head, rel, tail) in enumerate(data):
            head_label[idx][self.rt2h_all[(rel, tail)]] = 1.0
            tail_label[idx][self.hr2t_all[(head, rel)]] = 1.0

        if self.type_constrain:
            # Start at 1, zero the entities allowed for the relation, then set
            # the known true entities back to 1.
            head_label_type = torch.ones(num_triples, self.ent_tol)
            tail_label_type = torch.ones(num_triples, self.ent_tol)
            for idx, (head, rel, tail) in enumerate(data):
                head_label_type[idx][self.rel_heads[rel]] = 0.0
                tail_label_type[idx][self.rel_tails[rel]] = 0.0
                head_label_type[idx][self.rt2h_all[(rel, tail)]] = 1.0
                tail_label_type[idx][self.hr2t_all[(head, rel)]] = 1.0
            batch_data["head_label_type"] = head_label_type
            batch_data["tail_label_type"] = tail_label_type

        batch_data["positive_sample"] = torch.tensor(data)
        batch_data["head_label"] = head_label
        batch_data["tail_label"] = tail_label

        # NOTE(review): the full training graph does not depend on `data` but is
        # rebuilt on every call; caching it could save substantial time per
        # batch — confirm no consumer mutates the returned graph/tensors first.
        graph, rela, norm = self.sampler.build_graph(self.ent_tol, np.array(self.triples).transpose(), self.power)
        batch_data["graph"] = graph
        batch_data["rela"] = rela
        batch_data["norm"] = norm
        batch_data["entity"] = torch.arange(0, self.ent_tol, dtype=torch.long).view(-1, 1)

        return batch_data
unike/data/RevSampler.py ADDED
@@ -0,0 +1,78 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/data/RevSampler.py
4
+ #
5
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 15, 2024
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2024
7
+ #
8
+ # 为 KGReader 增加相反关系,用于图神经网络模型.
9
+
10
+ """
11
+ RevSampler - 为 KGReader 增加相反关系,用于图神经网络模型。
12
+ """
13
+
14
+ import os
15
+ from .KGReader import KGReader
16
+
17
class RevSampler(KGReader):

    """Add reverse relations to the training set.

    For every triple (h, r, t) a reverse triple (t, r', h) is generated with
    r' = r + n, where n is the relation count declared on the first line of
    :py:attr:`rel_file`. After construction :py:attr:`rel_tol` counts both
    the original and the reverse relations.
    """

    def __init__(
        self,
        in_path: str = "./",
        ent_file: str = "entity2id.txt",
        rel_file: str = "relation2id.txt",
        train_file: str = "train2id.txt"):

        """Create a RevSampler.

        :param in_path: dataset directory
        :type in_path: str
        :param ent_file: entity file name (entity2id.txt)
        :type ent_file: str
        :param rel_file: relation file name (relation2id.txt)
        :type rel_file: str
        :param train_file: training triple file name (train2id.txt)
        :type train_file: str
        """

        super().__init__(
            in_path=in_path,
            ent_file=ent_file,
            rel_file=rel_file,
            train_file=train_file
        )

        self.add_reverse_relation()
        self.add_train_reverse_triples()
        self.get_hr2t_rt2h_from_train()

    def add_reverse_relation(self):

        """Register ``<name>_reverse`` with ID ``rid + rel_tol`` for every relation.

        Re-reads :py:attr:`rel_file` (skipping its header and blank lines),
        then sets :py:attr:`rel_tol` to the size of :py:attr:`rel2id`, i.e.
        the original plus reverse relations.
        """

        # rel_tol still holds the declared number of original relations here.
        offset = self.rel_tol
        with open(os.path.join(self.in_path, self.rel_file)) as f:
            f.readline()  # skip the relation-count header
            for line in f:
                line = line.strip()
                if not line:
                    continue
                relation, rid = line.split("\t")
                reverse_id = int(rid) + offset
                self.rel2id[relation + "_reverse"] = reverse_id
                self.id2rel[reverse_id] = relation + "_reverse"
        self.rel_tol = len(self.rel2id)

    def add_train_reverse_triples(self):

        """Append (t, r', h) to :py:attr:`train_triples` for every training triple (h, r, t).

        ``r' = r + rel_tol // 2``: this runs after :py:meth:`add_reverse_relation`
        has doubled :py:attr:`rel_tol`, so half of it is the original relation
        count. The training file is stored as ``head tail relation``.
        """

        tol = self.rel_tol // 2

        with open(os.path.join(self.in_path, self.train_file)) as f:
            f.readline()  # skip the triple-count header
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                h, t, r = fields
                self.train_triples.append(
                    (int(t), int(r) + tol, int(h))
                )