unike 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. unike/__init__.py +5 -0
  2. unike/config/HPOTrainer.py +305 -0
  3. unike/config/Tester.py +385 -0
  4. unike/config/Trainer.py +519 -0
  5. unike/config/TrainerAccelerator.py +39 -0
  6. unike/config/__init__.py +37 -0
  7. unike/data/BernSampler.py +168 -0
  8. unike/data/CompGCNSampler.py +140 -0
  9. unike/data/CompGCNTestSampler.py +84 -0
  10. unike/data/KGEDataLoader.py +315 -0
  11. unike/data/KGReader.py +138 -0
  12. unike/data/RGCNSampler.py +261 -0
  13. unike/data/RGCNTestSampler.py +208 -0
  14. unike/data/RevSampler.py +78 -0
  15. unike/data/TestSampler.py +189 -0
  16. unike/data/TradSampler.py +122 -0
  17. unike/data/TradTestSampler.py +87 -0
  18. unike/data/UniSampler.py +145 -0
  19. unike/data/__init__.py +47 -0
  20. unike/module/BaseModule.py +130 -0
  21. unike/module/__init__.py +20 -0
  22. unike/module/loss/CompGCNLoss.py +96 -0
  23. unike/module/loss/Loss.py +26 -0
  24. unike/module/loss/MarginLoss.py +148 -0
  25. unike/module/loss/RGCNLoss.py +117 -0
  26. unike/module/loss/SigmoidLoss.py +145 -0
  27. unike/module/loss/SoftplusLoss.py +145 -0
  28. unike/module/loss/__init__.py +35 -0
  29. unike/module/model/Analogy.py +237 -0
  30. unike/module/model/CompGCN.py +562 -0
  31. unike/module/model/ComplEx.py +235 -0
  32. unike/module/model/DistMult.py +276 -0
  33. unike/module/model/HolE.py +308 -0
  34. unike/module/model/Model.py +107 -0
  35. unike/module/model/RESCAL.py +309 -0
  36. unike/module/model/RGCN.py +304 -0
  37. unike/module/model/RotatE.py +303 -0
  38. unike/module/model/SimplE.py +237 -0
  39. unike/module/model/TransD.py +458 -0
  40. unike/module/model/TransE.py +290 -0
  41. unike/module/model/TransH.py +322 -0
  42. unike/module/model/TransR.py +402 -0
  43. unike/module/model/__init__.py +60 -0
  44. unike/module/strategy/CompGCNSampling.py +140 -0
  45. unike/module/strategy/NegativeSampling.py +138 -0
  46. unike/module/strategy/RGCNSampling.py +134 -0
  47. unike/module/strategy/Strategy.py +26 -0
  48. unike/module/strategy/__init__.py +29 -0
  49. unike/utils/EarlyStopping.py +94 -0
  50. unike/utils/Timer.py +74 -0
  51. unike/utils/WandbLogger.py +46 -0
  52. unike/utils/__init__.py +26 -0
  53. unike/utils/tools.py +118 -0
  54. unike/version.py +1 -0
  55. unike-3.0.1.dist-info/METADATA +101 -0
  56. unike-3.0.1.dist-info/RECORD +59 -0
  57. unike-3.0.1.dist-info/WHEEL +4 -0
  58. unike-3.0.1.dist-info/entry_points.txt +2 -0
  59. unike-3.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,235 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/module/model/ComplEx.py
4
+ #
5
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 6, 2023
7
+ #
8
+ # This module defines the ComplEx model.
9
+
10
+ """
11
+ ComplEx - 第一个真正意义上复数域模型,简单而且高效。
12
+ """
13
+
14
+ import torch
15
+ import typing
16
+ import torch.nn as nn
17
+ from .Model import Model
18
+ from typing_extensions import override
19
+
20
class ComplEx(Model):

    r"""
    ``ComplEx`` :cite:`ComplEx` was proposed in 2016. It is the first genuinely
    complex-valued embedding model, simple and efficient, and is the complex
    counterpart of :py:class:`unike.module.model.DistMult`.

    The scoring function is:

    .. math::

        <\operatorname{Re}(h),\operatorname{Re}(r),\operatorname{Re}(t)>
        +<\operatorname{Re}(h),\operatorname{Im}(r),\operatorname{Im}(t)>
        +<\operatorname{Im}(h),\operatorname{Re}(r),\operatorname{Im}(t)>
        -<\operatorname{Im}(h),\operatorname{Im}(r),\operatorname{Re}(t)>

    :math:`h, r, t \in \mathbb{C}^n` are complex vectors and
    :math:`< \mathbf{a}, \mathbf{b}, \mathbf{c} >=\sum_{i=1}^{n}a_ib_ic_i` is the
    element-wise multi-linear dot product.

    The score of a positive triple should be as large as possible and the score
    of a negative triple as small as possible. See :ref:`ComplEx <complex>` for
    more details.

    Example::

        from unike.config import Trainer, Tester
        from unike.module.model import ComplEx
        from unike.module.loss import SoftplusLoss
        from unike.module.strategy import NegativeSampling

        # define the model
        complEx = ComplEx(
            ent_tol = train_dataloader.get_ent_tol(),
            rel_tol = train_dataloader.get_rel_tol(),
            dim = config.dim
        )

        # define the loss function
        model = NegativeSampling(
            model = complEx,
            loss = SoftplusLoss(),
            batch_size = train_dataloader.get_batch_size(),
            regul_rate = config.regul_rate
        )

        # test the model
        tester = Tester(model = complEx, data_loader = test_dataloader, use_gpu = config.use_gpu, device = config.device)

        # train the model
        trainer = Trainer(model = model, data_loader = train_dataloader, epochs = config.epochs,
            lr = config.lr, opt_method = config.opt_method, use_gpu = config.use_gpu, device = config.device,
            tester = tester, test = config.test, valid_interval = config.valid_interval,
            log_interval = config.log_interval, save_interval = config.save_interval,
            save_path = config.save_path, use_wandb = True)
        trainer.run()
    """

    def __init__(
        self,
        ent_tol: int,
        rel_tol: int,
        dim: int = 100):

        """Create a ComplEx model.

        :param ent_tol: Number of entities.
        :type ent_tol: int
        :param rel_tol: Number of relations.
        :type rel_tol: int
        :param dim: Dimension of the entity and relation embeddings. Each
            embedding row stores ``dim * 2`` values so that it can hold a real
            half and an imaginary half.
        :type dim: int
        """

        super().__init__(ent_tol, rel_tol)

        #: Dimension of the entity and relation embeddings.
        self.dim: int = dim
        # NOTE(review): self.ent_tol / self.rel_tol are read back from the
        # instance, so they are presumably set by Model.__init__ - confirm.
        #: Entity embedding table with ``ent_tol`` rows of width ``dim * 2``.
        self.ent_embeddings: torch.nn.Embedding = nn.Embedding(self.ent_tol, self.dim * 2)
        #: Relation embedding table with ``rel_tol`` rows of width ``dim * 2``.
        self.rel_embeddings: torch.nn.Embedding = nn.Embedding(self.rel_tol, self.dim * 2)

        # Replace the default initialization with Xavier-uniform values.
        nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
        nn.init.xavier_uniform_(self.rel_embeddings.weight.data)

    @override
    def forward(
        self,
        triples: torch.Tensor,
        negs: typing.Optional[torch.Tensor] = None,
        mode: str = 'single') -> torch.Tensor:

        """
        Define the computation performed at every call.
        Subclasses of :py:class:`torch.nn.Module` must override
        :py:meth:`torch.nn.Module.forward`.

        :param triples: The correct (positive) triples.
        :type triples: torch.Tensor
        :param negs: Negative samples; forwarded unchanged to the inherited
            ``tri2emb``.
        :type negs: typing.Optional[torch.Tensor]
        :param mode: Mode; forwarded unchanged to the inherited ``tri2emb``.
        :type mode: str
        :returns: Scores of the triples.
        :rtype: torch.Tensor
        """

        head_emb, relation_emb, tail_emb = self.tri2emb(triples, negs, mode)
        return self._calc(head_emb, relation_emb, tail_emb)

    def _calc(
        self,
        h: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor) -> torch.Tensor:

        """Compute the ComplEx scoring function.

        Each input is split into two equal chunks along its last dimension;
        the first chunk is used as the real part and the second as the
        imaginary part.

        :param h: Head entity embeddings.
        :type h: torch.Tensor
        :param r: Relation embeddings.
        :type r: torch.Tensor
        :param t: Tail entity embeddings.
        :type t: torch.Tensor
        :returns: Scores of the triples (summed over the last dimension).
        :rtype: torch.Tensor
        """

        re_head, im_head = torch.chunk(h, 2, dim=-1)
        re_relation, im_relation = torch.chunk(r, 2, dim=-1)
        re_tail, im_tail = torch.chunk(t, 2, dim=-1)

        # The four terms of the scoring function given in the class docstring.
        return torch.sum(
            re_head * re_tail * re_relation
            + im_head * im_tail * re_relation
            + re_head * im_tail * im_relation
            - im_head * re_tail * im_relation,
            -1
        )

    @override
    def predict(
        self,
        data: dict[str, typing.Union[torch.Tensor,str]],
        mode: str) -> torch.Tensor:

        """Inference method of ComplEx.

        :param data: Batch data; only ``data["positive_sample"]`` is read here.
        :type data: dict[str, typing.Union[torch.Tensor,str]]
        :param mode: Mode; forwarded to the inherited ``tri2emb`` as the
            ``mode`` keyword argument.
        :type mode: str
        :returns: Scores of the triples.
        :rtype: torch.Tensor
        """

        triples = data["positive_sample"]
        head_emb, relation_emb, tail_emb = self.tri2emb(triples, mode=mode)
        return self._calc(head_emb, relation_emb, tail_emb)

    def regularization(
        self,
        data: dict[str, typing.Union[torch.Tensor, str]]) -> torch.Tensor:

        """L2 regularization (also known as weight decay), used in the loss function.

        Computes the mean squared value of the head, relation and tail
        embeddings for the positive and for the negative samples, averages the
        three terms of each group, then averages the two groups.

        :param data: Batch data; the keys ``"positive_sample"``,
            ``"negative_sample"`` and ``"mode"`` are read.
        :type data: dict[str, typing.Union[torch.Tensor, str]]
        :returns: Regularization loss of the model parameters.
        :rtype: torch.Tensor
        """

        pos_sample = data["positive_sample"]
        neg_sample = data["negative_sample"]
        mode = data["mode"]
        pos_head_emb, pos_relation_emb, pos_tail_emb = self.tri2emb(pos_sample)
        if mode == "bern":
            # In "bern" mode the negative sample is passed to tri2emb on its own,
            # exactly like the positive sample.
            neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(neg_sample)
        else:
            # Otherwise the negatives are looked up relative to the positive
            # triples, with the mode telling tri2emb how to combine them.
            neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(pos_sample, neg_sample, mode)

        pos_regul = (torch.mean(pos_head_emb ** 2) +
                     torch.mean(pos_relation_emb ** 2) +
                     torch.mean(pos_tail_emb ** 2)) / 3

        neg_regul = (torch.mean(neg_head_emb ** 2) +
                     torch.mean(neg_relation_emb ** 2) +
                     torch.mean(neg_tail_emb ** 2)) / 3

        regul = (pos_regul + neg_regul) / 2

        return regul
206
+
207
def get_complex_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyper-parameter optimization configuration of :py:class:`ComplEx`.

    The default configuration is::

        parameters_dict = {
            'model': {
                'value': 'ComplEx'
            },
            'dim': {
                'values': [50, 100, 200]
            }
        }

    :returns: The default hyper-parameter optimization configuration of :py:class:`ComplEx`.
    :rtype: dict[str, dict[str, typing.Any]]
    """

    # A new dictionary is built on every call, so callers may mutate the result.
    return {
        'model': {'value': 'ComplEx'},
        'dim': {'values': [50, 100, 200]},
    }
@@ -0,0 +1,276 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/module/model/DistMult.py
4
+ #
5
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 6, 2023
7
+ #
8
+ # This module defines the DistMult model.
9
+
10
+ """
11
+ DistMult - 最简单的双线性模型,与 TransE 参数量相同,因此非常容易的应用于大型的知识图谱。
12
+ """
13
+
14
+ import torch
15
+ import typing
16
+ import torch.nn as nn
17
+ from .Model import Model
18
+ from typing_extensions import override
19
+
20
class DistMult(Model):

    r"""
    ``DistMult`` :cite:`DistMult` was proposed in 2015. It is the simplest
    bilinear model and has the same number of parameters as TransE, so it is
    easy to apply to large knowledge graphs.

    The scoring function is:

    .. math::

        \sum_{i=1}^{n}h_ir_it_i

    which is the element-wise multi-linear dot product. The score of a positive
    triple should be as large as possible and the score of a negative triple as
    small as possible. See :ref:`DistMult <distMult>` for more details.

    Example::

        from unike.utils import WandbLogger
        from unike.data import KGEDataLoader, BernSampler, TradTestSampler
        from unike.module.model import DistMult
        from unike.module.loss import SoftplusLoss
        from unike.module.strategy import NegativeSampling
        from unike.config import Trainer, Tester

        wandb_logger = WandbLogger(
            project="pybind11-ke",
            name="DistMult-WN18RR",
            config=dict(
                in_path = '../../benchmarks/WN18RR/',
                batch_size = 4096,
                neg_ent = 25,
                test = True,
                test_batch_size = 10,
                num_workers = 16,
                dim = 200,
                regul_rate = 1.0,
                use_gpu = True,
                device = 'cuda:0',
                epochs = 2000,
                lr = 0.5,
                opt_method = 'adagrad',
                valid_interval = 100,
                log_interval = 100,
                save_interval = 100,
                save_path = '../../checkpoint/distMult.pth'
            )
        )

        config = wandb_logger.config

        # dataloader for training
        dataloader = KGEDataLoader(
            in_path = config.in_path,
            batch_size = config.batch_size,
            neg_ent = config.neg_ent,
            test = config.test,
            test_batch_size = config.test_batch_size,
            num_workers = config.num_workers,
            train_sampler = BernSampler,
            test_sampler = TradTestSampler
        )

        # define the model
        distmult = DistMult(
            ent_tol = dataloader.get_ent_tol(),
            rel_tol = dataloader.get_rel_tol(),
            dim = config.dim
        )

        # define the loss function
        model = NegativeSampling(
            model = distmult,
            loss = SoftplusLoss(),
            regul_rate = config.regul_rate
        )

        # test the model
        tester = Tester(model = distmult, data_loader = dataloader, use_gpu = config.use_gpu, device = config.device)

        # train the model
        trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(), epochs = config.epochs,
            lr = config.lr, opt_method = config.opt_method, use_gpu = config.use_gpu, device = config.device,
            tester = tester, test = config.test, valid_interval = config.valid_interval,
            log_interval = config.log_interval, save_interval = config.save_interval,
            save_path = config.save_path, use_wandb = True)
        trainer.run()

        # close your wandb run
        wandb_logger.finish()
    """

    def __init__(
        self,
        ent_tol: int,
        rel_tol: int,
        dim: int = 100):

        """Create a DistMult model.

        :param ent_tol: Number of entities.
        :type ent_tol: int
        :param rel_tol: Number of relations.
        :type rel_tol: int
        :param dim: Dimension of the entity embeddings and of the relation
            diagonal matrices.
        :type dim: int
        """

        super().__init__(ent_tol, rel_tol)

        #: Dimension of the entity embeddings and relation diagonal matrices.
        self.dim: int = dim
        # NOTE(review): self.ent_tol / self.rel_tol are read back from the
        # instance, so they are presumably set by Model.__init__ - confirm.
        #: Entity embedding table with ``ent_tol`` rows of width ``dim``.
        self.ent_embeddings: torch.nn.Embedding = nn.Embedding(self.ent_tol, self.dim)
        #: Relation table with ``rel_tol`` rows of width ``dim``; each row is
        #: used as the diagonal of the relation matrix.
        self.rel_embeddings: torch.nn.Embedding = nn.Embedding(self.rel_tol, self.dim)

        # Replace the default initialization with Xavier-uniform values.
        nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
        nn.init.xavier_uniform_(self.rel_embeddings.weight.data)

    @override
    def forward(
        self,
        triples: torch.Tensor,
        negs: typing.Optional[torch.Tensor] = None,
        mode: str = 'single') -> torch.Tensor:

        """
        Define the computation performed at every call.
        Subclasses of :py:class:`torch.nn.Module` must override
        :py:meth:`torch.nn.Module.forward`.

        :param triples: The correct (positive) triples.
        :type triples: torch.Tensor
        :param negs: Negative samples; forwarded unchanged to the inherited
            ``tri2emb``.
        :type negs: typing.Optional[torch.Tensor]
        :param mode: Mode; forwarded unchanged to the inherited ``tri2emb``.
        :type mode: str
        :returns: Scores of the triples.
        :rtype: torch.Tensor
        """

        head_emb, relation_emb, tail_emb = self.tri2emb(triples, negs, mode)
        return self._calc(head_emb, relation_emb, tail_emb)

    def _calc(
        self,
        h: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor) -> torch.Tensor:

        """Compute the DistMult scoring function.

        :param h: Head entity embeddings.
        :type h: torch.Tensor
        :param r: Relation diagonal matrices (stored as vectors).
        :type r: torch.Tensor
        :param t: Tail entity embeddings.
        :type t: torch.Tensor
        :returns: Scores of the triples (summed over the last dimension).
        :rtype: torch.Tensor
        """

        # Element-wise product of head, relation and tail ...
        score = (h * r) * t

        # ... reduced over the last dimension to obtain the score.
        score = torch.sum(score, -1)
        return score

    @override
    def predict(
        self,
        data: dict[str, typing.Union[torch.Tensor,str]],
        mode: str) -> torch.Tensor:

        """Inference method of DistMult.

        :param data: Batch data; only ``data["positive_sample"]`` is read here.
        :type data: dict[str, typing.Union[torch.Tensor,str]]
        :param mode: Mode; forwarded to the inherited ``tri2emb`` as the
            ``mode`` keyword argument.
        :type mode: str
        :returns: Scores of the triples.
        :rtype: torch.Tensor
        """

        triples = data["positive_sample"]
        head_emb, relation_emb, tail_emb = self.tri2emb(triples, mode=mode)
        return self._calc(head_emb, relation_emb, tail_emb)

    def regularization(
        self,
        data: dict[str, typing.Union[torch.Tensor, str]]) -> torch.Tensor:

        """L2 regularization (also known as weight decay), used in the loss function.

        Computes the mean squared value of the head, relation and tail
        embeddings for the positive and for the negative samples, averages the
        three terms of each group, then averages the two groups.

        :param data: Batch data; the keys ``"positive_sample"``,
            ``"negative_sample"`` and ``"mode"`` are read.
        :type data: dict[str, typing.Union[torch.Tensor, str]]
        :returns: Regularization loss of the model parameters.
        :rtype: torch.Tensor
        """

        pos_sample = data["positive_sample"]
        neg_sample = data["negative_sample"]
        mode = data["mode"]
        pos_head_emb, pos_relation_emb, pos_tail_emb = self.tri2emb(pos_sample)
        if mode == "bern":
            # In "bern" mode the negative sample is passed to tri2emb on its own,
            # exactly like the positive sample.
            neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(neg_sample)
        else:
            # Otherwise the negatives are looked up relative to the positive
            # triples, with the mode telling tri2emb how to combine them.
            neg_head_emb, neg_relation_emb, neg_tail_emb = self.tri2emb(pos_sample, neg_sample, mode)

        pos_regul = (torch.mean(pos_head_emb ** 2) +
                     torch.mean(pos_relation_emb ** 2) +
                     torch.mean(pos_tail_emb ** 2)) / 3

        neg_regul = (torch.mean(neg_head_emb ** 2) +
                     torch.mean(neg_relation_emb ** 2) +
                     torch.mean(neg_tail_emb ** 2)) / 3

        regul = (pos_regul + neg_regul) / 2

        return regul

    def l3_regularization(self):

        """L3 regularization, used in the loss function.

        Sums the cubed 3-norms of the full entity and relation embedding
        weight matrices.

        :returns: Regularization loss of the model parameters.
        :rtype: torch.Tensor
        """

        return (self.ent_embeddings.weight.norm(p = 3)**3 + self.rel_embeddings.weight.norm(p = 3)**3)
247
+
248
def get_distmult_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyper-parameter optimization configuration of :py:class:`DistMult`.

    The default configuration is::

        parameters_dict = {
            'model': {
                'value': 'DistMult'
            },
            'dim': {
                'values': [50, 100, 200]
            }
        }

    :returns: The default hyper-parameter optimization configuration of :py:class:`DistMult`.
    :rtype: dict[str, dict[str, typing.Any]]
    """

    # A new dictionary is built on every call, so callers may mutate the result.
    return {
        'model': {'value': 'DistMult'},
        'dim': {'values': [50, 100, 200]},
    }