unike-3.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. unike/__init__.py +5 -0
  2. unike/config/HPOTrainer.py +305 -0
  3. unike/config/Tester.py +385 -0
  4. unike/config/Trainer.py +519 -0
  5. unike/config/TrainerAccelerator.py +39 -0
  6. unike/config/__init__.py +37 -0
  7. unike/data/BernSampler.py +168 -0
  8. unike/data/CompGCNSampler.py +140 -0
  9. unike/data/CompGCNTestSampler.py +84 -0
  10. unike/data/KGEDataLoader.py +315 -0
  11. unike/data/KGReader.py +138 -0
  12. unike/data/RGCNSampler.py +261 -0
  13. unike/data/RGCNTestSampler.py +208 -0
  14. unike/data/RevSampler.py +78 -0
  15. unike/data/TestSampler.py +189 -0
  16. unike/data/TradSampler.py +122 -0
  17. unike/data/TradTestSampler.py +87 -0
  18. unike/data/UniSampler.py +145 -0
  19. unike/data/__init__.py +47 -0
  20. unike/module/BaseModule.py +130 -0
  21. unike/module/__init__.py +20 -0
  22. unike/module/loss/CompGCNLoss.py +96 -0
  23. unike/module/loss/Loss.py +26 -0
  24. unike/module/loss/MarginLoss.py +148 -0
  25. unike/module/loss/RGCNLoss.py +117 -0
  26. unike/module/loss/SigmoidLoss.py +145 -0
  27. unike/module/loss/SoftplusLoss.py +145 -0
  28. unike/module/loss/__init__.py +35 -0
  29. unike/module/model/Analogy.py +237 -0
  30. unike/module/model/CompGCN.py +562 -0
  31. unike/module/model/ComplEx.py +235 -0
  32. unike/module/model/DistMult.py +276 -0
  33. unike/module/model/HolE.py +308 -0
  34. unike/module/model/Model.py +107 -0
  35. unike/module/model/RESCAL.py +309 -0
  36. unike/module/model/RGCN.py +304 -0
  37. unike/module/model/RotatE.py +303 -0
  38. unike/module/model/SimplE.py +237 -0
  39. unike/module/model/TransD.py +458 -0
  40. unike/module/model/TransE.py +290 -0
  41. unike/module/model/TransH.py +322 -0
  42. unike/module/model/TransR.py +402 -0
  43. unike/module/model/__init__.py +60 -0
  44. unike/module/strategy/CompGCNSampling.py +140 -0
  45. unike/module/strategy/NegativeSampling.py +138 -0
  46. unike/module/strategy/RGCNSampling.py +134 -0
  47. unike/module/strategy/Strategy.py +26 -0
  48. unike/module/strategy/__init__.py +29 -0
  49. unike/utils/EarlyStopping.py +94 -0
  50. unike/utils/Timer.py +74 -0
  51. unike/utils/WandbLogger.py +46 -0
  52. unike/utils/__init__.py +26 -0
  53. unike/utils/tools.py +118 -0
  54. unike/version.py +1 -0
  55. unike-3.0.1.dist-info/METADATA +101 -0
  56. unike-3.0.1.dist-info/RECORD +59 -0
  57. unike-3.0.1.dist-info/WHEEL +4 -0
  58. unike-3.0.1.dist-info/entry_points.txt +2 -0
  59. unike-3.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,168 @@
+ # coding:utf-8
+ #
+ # unike/data/BernSampler.py
+ #
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 30, 2024
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 25, 2024
+ #
+ # This script defines the BernSampler class.
+
+ """
+ BernSampler - training-set data sampler for translation-based and semantic matching models.
+ """
+
+ import torch
+ import typing
+ import random
+ import collections
+ import numpy as np
+ from .TradSampler import TradSampler
+ from typing_extensions import override
+
+ class BernSampler(TradSampler):
+
+     """
+     Bernoulli ("bern") training-set sampler for translation-based and semantic matching models; see :ref:`TransH <transh>` for more details.
+     """
+
+     def __init__(
+         self,
+         in_path: str = "./",
+         ent_file: str = "entity2id.txt",
+         rel_file: str = "relation2id.txt",
+         train_file: str = "train2id.txt",
+         batch_size: int | None = None,
+         neg_ent: int = 1):
+
+         """Create a BernSampler object.
+
+         :param in_path: dataset directory
+         :type in_path: str
+         :param ent_file: entity2id.txt
+         :type ent_file: str
+         :param rel_file: relation2id.txt
+         :type rel_file: str
+         :param train_file: train2id.txt
+         :type train_file: str
+         :param batch_size: batch size; unused by this sampler, kept only as a placeholder.
+         :type batch_size: int | None
+         :param neg_ent: number of negative triples built per positive triple by corrupting an entity
+         :type neg_ent: int
+         """
+
+         super().__init__(
+             in_path=in_path,
+             ent_file=ent_file,
+             rel_file=rel_file,
+             train_file=train_file,
+             batch_size=batch_size,
+             neg_ent=neg_ent
+         )
+
+         self.tph, self.hpt = self.get_tph_hpt()
+
+     def get_tph_hpt(self) -> tuple[collections.defaultdict[float], collections.defaultdict[float]]:
+
+         """Compute tph (average number of tails per head) and hpt (average number of heads per tail).
+
+         :returns: tph and hpt
+         :rtype: tuple[collections.defaultdict[float], collections.defaultdict[float]]
+         """
+
+         h_of_r = collections.defaultdict(set)
+         t_of_r = collections.defaultdict(set)
+         freq_rel = collections.defaultdict(float)
+         tph = collections.defaultdict(float)
+         hpt = collections.defaultdict(float)
+         for h, r, t in self.train_triples:
+             freq_rel[r] += 1.0
+             h_of_r[r].add(h)
+             t_of_r[r].add(t)
+         for r in h_of_r:
+             tph[r] = freq_rel[r] / len(h_of_r[r])
+             hpt[r] = freq_rel[r] / len(t_of_r[r])
+         return tph, hpt
+
+     @override
+     def sampling(
+         self,
+         pos_triples: list[tuple[int, int, int]]) -> dict[str, typing.Union[str, torch.Tensor]]:
+
+         """Bernoulli ("bern") sampling function over the training set for translation-based and semantic matching models.
+
+         :param pos_triples: positive triples from the knowledge graph
+         :type pos_triples: list[tuple[int, int, int]]
+         :returns: training data for translation-based and semantic matching models
+         :rtype: dict[str, typing.Union[str, torch.Tensor]]
+         """
+
+         batch_data = {}
+         neg_ent_sample = []
+
+         batch_data['mode'] = 'bern'
+         for h, r, t in pos_triples:
+             neg_ent = self.__normal_batch(h, r, t, self.neg_ent)
+             neg_ent_sample += neg_ent
+
+         batch_data["positive_sample"] = torch.LongTensor(np.array(pos_triples))
+         batch_data["negative_sample"] = torch.LongTensor(np.array(neg_ent_sample))
+
+         return batch_data
+
+     def __normal_batch(
+         self,
+         h: int,
+         r: int,
+         t: int,
+         neg_size: int) -> list[tuple[int, int, int]]:
+
+         """Bernoulli negative sampling for a single positive triple.
+
+         :param h: head entity
+         :type h: int
+         :param r: relation
+         :type r: int
+         :param t: tail entity
+         :type t: int
+         :param neg_size: number of negative triples
+         :type neg_size: int
+         :returns: list of negative triples
+         :rtype: list[tuple[int, int, int]]
+         """
+
+         neg_size_h = 0
+         neg_size_t = 0
+         prob = self.hpt[r] / (self.hpt[r] + self.tph[r])
+         for _ in range(neg_size):
+             if random.random() < prob:
+                 neg_size_t += 1
+             else:
+                 neg_size_h += 1
+
+         res = []
+
+         neg_list_h = []
+         neg_cur_size = 0
+         while neg_cur_size < neg_size_h:
+             neg_tmp_h = self.corrupt_head(t, r, num_max=(neg_size_h - neg_cur_size) * 2)
+             neg_list_h.append(neg_tmp_h)
+             neg_cur_size += len(neg_tmp_h)
+         if neg_list_h != []:
+             neg_list_h = np.concatenate(neg_list_h)
+
+         for hh in neg_list_h[:neg_size_h]:
+             res.append((hh, r, t))
+
+         neg_list_t = []
+         neg_cur_size = 0
+         while neg_cur_size < neg_size_t:
+             neg_tmp_t = self.corrupt_tail(h, r, num_max=(neg_size_t - neg_cur_size) * 2)
+             neg_list_t.append(neg_tmp_t)
+             neg_cur_size += len(neg_tmp_t)
+         if neg_list_t != []:
+             neg_list_t = np.concatenate(neg_list_t)
+
+         for tt in neg_list_t[:neg_size_t]:
+             res.append((h, r, tt))
+
+         return res
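
The bern rule above decides, for each positive triple (h, r, t), whether to corrupt the head or the tail: the tail is replaced with probability hpt[r] / (hpt[r] + tph[r]) and the head otherwise, so the "many" side of a 1-to-N or N-to-1 relation is corrupted less often and fewer false negatives are produced. A minimal standalone sketch of that decision, using made-up statistics for a single relation (illustrative values only, not from the package)::

    import random

    # Illustrative values only: on average 3.0 tails per head and 1.2 heads per tail.
    tph_r, hpt_r = 3.0, 1.2

    prob_corrupt_tail = hpt_r / (hpt_r + tph_r)  # ~= 0.286

    def corrupt_side() -> str:
        # Pick which side of (h, r, t) to replace, following the bern rule.
        return "tail" if random.random() < prob_corrupt_tail else "head"

    print(corrupt_side())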
@@ -0,0 +1,140 @@
+ # coding:utf-8
+ #
+ # unike/data/CompGCNSampler.py
+ #
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 19, 2023
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 19, 2023
+ #
+ # This script defines the CompGCNSampler class.
+
+ """
+ CompGCNSampler - data sampler for CompGCN.
+ """
+
+ import dgl
+ import torch
+ import typing
+ import numpy as np
+ from .RGCNSampler import RGCNSampler
+ from typing_extensions import override
+
+ class CompGCNSampler(RGCNSampler):
+
+     """Training data sampler for ``CompGCN`` :cite:`CompGCN`.
+
+     Example::
+
+         from unike.data import RGCNSampler, CompGCNSampler
+         from torch.utils.data import DataLoader
+
+         #: training data sampler
+         train_sampler: typing.Union[typing.Type[RGCNSampler], typing.Type[CompGCNSampler]] = train_sampler(
+             in_path=in_path,
+             ent_file=ent_file,
+             rel_file=rel_file,
+             train_file=train_file,
+             batch_size=batch_size,
+             neg_ent=neg_ent
+         )
+
+         #: training triples
+         data_train: list[tuple[int, int, int]] = train_sampler.get_train()
+
+         train_dataloader = DataLoader(
+             data_train,
+             shuffle=True,
+             batch_size=batch_size,
+             num_workers=num_workers,
+             pin_memory=True,
+             drop_last=True,
+             collate_fn=train_sampler.sampling,
+         )
+     """
+
+     def __init__(
+         self,
+         in_path: str = "./",
+         ent_file: str = "entity2id.txt",
+         rel_file: str = "relation2id.txt",
+         train_file: str = "train2id.txt",
+         batch_size: int | None = None,
+         neg_ent: int = 1):
+
+         """Create a CompGCNSampler object.
+
+         :param in_path: dataset directory
+         :type in_path: str
+         :param ent_file: entity2id.txt
+         :type ent_file: str
+         :param rel_file: relation2id.txt
+         :type rel_file: str
+         :param train_file: train2id.txt
+         :type train_file: str
+         :param batch_size: batch size
+         :type batch_size: int | None
+         :param neg_ent: has no effect for CompGCN.
+         :type neg_ent: int
+         """
+
+         super().__init__(
+             in_path=in_path,
+             ent_file=ent_file,
+             rel_file=rel_file,
+             train_file=train_file,
+             batch_size=batch_size,
+             neg_ent=neg_ent
+         )
+
+         super().get_hr_train()
+
+         self.graph, self.relation, self.norm = \
+             self.build_graph(self.ent_tol, np.array(self.t_triples).transpose(), -0.5)
+
+     @override
+     def sampling(
+         self,
+         pos_hr_t: list[tuple[tuple[int, int], list[int]]]) -> dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]:
+
+         """Sampling function for ``CompGCN`` :cite:`CompGCN`.
+
+         :param pos_hr_t: positive (h, r) pairs from the knowledge graph, each with its list of observed tail entities
+         :type pos_hr_t: list[tuple[tuple[int, int], list[int]]]
+         :returns: training data for ``CompGCN`` :cite:`CompGCN`
+         :rtype: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
+         """
+
+         batch_data = {}
+
+         self.label = torch.zeros(self.batch_size, self.ent_tol)
+         self.triples = torch.LongTensor([hr for hr, _ in pos_hr_t])
+         for id, hr_sample in enumerate([t for _, t in pos_hr_t]):
+             self.label[id][hr_sample] = 1
+
+         batch_data['sample'] = self.triples
+         batch_data['label'] = self.label
+         batch_data['graph'] = self.graph
+         batch_data['relation'] = self.relation
+         batch_data['norm'] = self.norm
+
+         return batch_data
+
+     @override
+     def node_norm_to_edge_norm(
+         self,
+         graph: dgl.DGLGraph,
+         node_norm: torch.Tensor) -> torch.Tensor:
+
+         """Compute a normalization coefficient for each edge from the normalization coefficients of its source and destination nodes.
+
+         :param graph: the subgraph
+         :type graph: dgl.DGLGraph
+         :param node_norm: node normalization coefficients
+         :type node_norm: torch.Tensor
+         :returns: edge normalization coefficients
+         :rtype: torch.Tensor
+         """
+
+         graph.ndata['norm'] = node_norm
+         graph.apply_edges(lambda edges: {'norm': edges.dst['norm'] * edges.src['norm']})
+         norm = graph.edata.pop('norm').squeeze()
+         return norm
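
The sampling method above prepares CompGCN's 1-vs-all training signal: each (h, r) query receives a label row over all entities, with 1 at every tail observed for that pair in the training set. A minimal standalone sketch of that label construction with toy sizes (independent of the package's data readers)::

    import torch

    ent_tol = 5                            # toy entity count
    pos_hr_t = [((0, 1), [2, 4]),          # (h, r) pair -> observed tails
                ((3, 0), [1])]

    sample = torch.LongTensor([hr for hr, _ in pos_hr_t])
    label = torch.zeros(len(pos_hr_t), ent_tol)
    for i, tails in enumerate(t for _, t in pos_hr_t):
        label[i][tails] = 1                # multi-hot row per (h, r) query

    print(sample.shape, label.shape)       # torch.Size([2, 2]) torch.Size([2, 5])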
@@ -0,0 +1,84 @@
+ # coding:utf-8
+ #
+ # unike/data/CompGCNTestSampler.py
+ #
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 19, 2023
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 19, 2023
+ #
+ # This script defines the CompGCNTestSampler class.
+
+ """
+ CompGCNTestSampler - test data sampler for CompGCN.
+ """
+
+ from .CompGCNSampler import CompGCNSampler
+ from .RGCNTestSampler import RGCNTestSampler
+
+ class CompGCNTestSampler(RGCNTestSampler):
+
+     """Test data sampler for ``CompGCN`` :cite:`CompGCN`.
+
+     Example::
+
+         from unike.data import RGCNTestSampler, CompGCNTestSampler
+         from torch.utils.data import DataLoader
+
+         #: test data sampler
+         test_sampler: typing.Union[typing.Type[RGCNTestSampler], typing.Type[CompGCNTestSampler]] = test_sampler(
+             sampler=train_sampler,
+             valid_file=valid_file,
+             test_file=test_file,
+         )
+
+         #: validation triples
+         data_val: list[tuple[int, int, int]] = test_sampler.get_valid()
+         #: test triples
+         data_test: list[tuple[int, int, int]] = test_sampler.get_test()
+
+         val_dataloader = DataLoader(
+             data_val,
+             shuffle=False,
+             batch_size=test_batch_size,
+             num_workers=num_workers,
+             pin_memory=True,
+             collate_fn=test_sampler.sampling,
+         )
+
+         test_dataloader = DataLoader(
+             data_test,
+             shuffle=False,
+             batch_size=test_batch_size,
+             num_workers=num_workers,
+             pin_memory=True,
+             collate_fn=test_sampler.sampling,
+         )
+     """
+
+     def __init__(
+         self,
+         sampler: CompGCNSampler,
+         valid_file: str = "valid2id.txt",
+         test_file: str = "test2id.txt",
+         type_constrain: bool = True):
+
+         """Create a CompGCNTestSampler object.
+
+         :param sampler: the training data sampler.
+         :type sampler: CompGCNSampler
+         :param valid_file: valid2id.txt
+         :type valid_file: str
+         :param test_file: test2id.txt
+         :type test_file: str
+         :param type_constrain: whether to also report test results under the type_constrain.txt restriction
+         :type type_constrain: bool
+         """
+
+         super().__init__(
+             sampler=sampler,
+             valid_file=valid_file,
+             test_file=test_file,
+             type_constrain=type_constrain
+         )
+
+         #: exponent
+         self.power: float = -0.5
@@ -0,0 +1,315 @@
+ # coding:utf-8
+ #
+ # unike/data/KGEDataLoader.py
+ #
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2024
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Apr 27, 2024
+ #
+ # Reads data for KGE models.
+
+ """
+ KGEDataLoader - dataset loading class for KGE models.
+ """
+
+ import typing
+ from .UniSampler import UniSampler
+ from .BernSampler import BernSampler
+ from .RGCNSampler import RGCNSampler
+ from .CompGCNSampler import CompGCNSampler
+ from .TestSampler import TestSampler
+ from .TradTestSampler import TradTestSampler
+ from torch.utils.data import DataLoader
+
+ class KGEDataLoader:
+
+     """Data loader for KGE models.
+
+     Example::
+
+         from unike.data import KGEDataLoader, BernSampler, TradTestSampler
+
+         dataloader = KGEDataLoader(
+             in_path = "../../benchmarks/FB15K/",
+             batch_size = 8192,
+             neg_ent = 25,
+             test = True,
+             test_batch_size = 256,
+             num_workers = 16,
+             train_sampler = BernSampler,
+             test_sampler = TradTestSampler
+         )
+     """
+
+     def __init__(
+         self,
+         in_path: str = "./",
+         ent_file: str = "entity2id.txt",
+         rel_file: str = "relation2id.txt",
+         train_file: str = "train2id.txt",
+         valid_file: str = "valid2id.txt",
+         test_file: str = "test2id.txt",
+         batch_size: int | None = None,
+         neg_ent: int = 1,
+         test: bool = False,
+         test_batch_size: int | None = None,
+         type_constrain: bool = True,
+         num_workers: int | None = None,
+         train_sampler: typing.Union[typing.Type[UniSampler], typing.Type[BernSampler], typing.Type[RGCNSampler], typing.Type[CompGCNSampler]] = BernSampler,
+         test_sampler: typing.Type[TestSampler] = TradTestSampler):
+
+         """Create a KGEDataLoader object.
+
+         :param in_path: dataset directory
+         :type in_path: str
+         :param ent_file: entity2id.txt
+         :type ent_file: str
+         :param rel_file: relation2id.txt
+         :type rel_file: str
+         :param train_file: train2id.txt
+         :type train_file: str
+         :param valid_file: valid2id.txt
+         :type valid_file: str
+         :param test_file: test2id.txt
+         :type test_file: str
+         :param batch_size: batch size
+         :type batch_size: int | None
+         :param neg_ent: number of negative triples built per positive triple by corrupting an entity; has no effect for CompGCN.
+         :type neg_ent: int
+         :param test: whether to also load the validation and test sets
+         :type test: bool
+         :param test_batch_size: test batch size
+         :type test_batch_size: int | None
+         :param type_constrain: whether to also report test results under the type_constrain.txt restriction
+         :type type_constrain: bool
+         :param num_workers: number of worker processes used for data loading
+         :type num_workers: int
+         :param train_sampler: training data sampler
+         :type train_sampler: typing.Union[typing.Type[UniSampler], typing.Type[BernSampler], typing.Type[RGCNSampler], typing.Type[CompGCNSampler]]
+         :param test_sampler: test data sampler
+         :type test_sampler: typing.Type[TestSampler]
+         """
+
+         #: dataset directory
+         self.in_path: str = in_path
+         #: entity2id.txt
+         self.ent_file: str = ent_file
+         #: relation2id.txt
+         self.rel_file: str = rel_file
+         #: train2id.txt
+         self.train_file: str = train_file
+         #: valid2id.txt
+         self.valid_file: str = valid_file
+         #: test2id.txt
+         self.test_file: str = test_file
+         #: batch size
+         self.batch_size: int = batch_size
+         #: number of negative triples built per positive triple by corrupting an entity; has no effect for CompGCN.
+         self.neg_ent: int = neg_ent
+         #: whether to also load the validation and test sets
+         self.test: bool = test
+         #: test batch size
+         self.test_batch_size: int = test_batch_size
+         #: whether to also report test results under the type_constrain.txt restriction
+         self.type_constrain: bool = type_constrain
+         #: number of worker processes used for data loading
+         self.num_workers: int = num_workers
+
+         #: training data sampler
+         self.train_sampler: typing.Union[UniSampler, BernSampler, RGCNSampler, CompGCNSampler] = train_sampler(
+             in_path=self.in_path,
+             ent_file=self.ent_file,
+             rel_file=self.rel_file,
+             train_file=self.train_file,
+             batch_size=self.batch_size,
+             neg_ent=self.neg_ent
+         )
+
+         #: training triples
+         self.data_train: list[tuple[int, int, int]] = self.train_sampler.get_train()
+
+         if self.test:
+             #: test data sampler
+             self.test_sampler: TestSampler = test_sampler(
+                 sampler=self.train_sampler,
+                 valid_file=self.valid_file,
+                 test_file=self.test_file,
+                 type_constrain=type_constrain
+             )
+
+             #: validation triples
+             self.data_val: list[tuple[int, int, int]] = self.test_sampler.get_valid()
+             #: test triples
+             self.data_test: list[tuple[int, int, int]] = self.test_sampler.get_test()
+
+     def get_ent_tol(self) -> int:
+
+         """Return the number of entities.
+
+         :returns: number of entities
+         :rtype: int
+         """
+
+         return self.train_sampler.ent_tol
+
+     def get_rel_tol(self) -> int:
+
+         """Return the number of relations.
+
+         :returns: number of relations
+         :rtype: int
+         """
+
+         return self.train_sampler.rel_tol
+
+     def train_dataloader(self) -> DataLoader:
+
+         """Return the training data loader.
+
+         :returns: training data loader
+         :rtype: torch.utils.data.DataLoader
+         """
+
+         return DataLoader(
+             self.data_train,
+             shuffle=True,
+             batch_size=self.batch_size,
+             num_workers=self.num_workers,
+             pin_memory=True,
+             drop_last=True,
+             collate_fn=self.train_sampler.sampling,
+         )
+
+     def val_dataloader(self) -> DataLoader:
+
+         """Return the validation data loader.
+
+         :returns: validation data loader
+         :rtype: torch.utils.data.DataLoader
+         """
+
+         return DataLoader(
+             self.data_val,
+             shuffle=False,
+             batch_size=self.test_batch_size,
+             num_workers=self.num_workers,
+             pin_memory=True,
+             collate_fn=self.test_sampler.sampling,
+         )
+
+     def test_dataloader(self) -> DataLoader:
+
+         """Return the test data loader.
+
+         :returns: test data loader
+         :rtype: torch.utils.data.DataLoader
+         """
+
+         return DataLoader(
+             self.data_test,
+             shuffle=False,
+             batch_size=self.test_batch_size,
+             num_workers=self.num_workers,
+             pin_memory=True,
+             collate_fn=self.test_sampler.sampling,
+         )
+
+ def get_kge_data_loader_hpo_config() -> dict[str, dict[str, typing.Any]]:
+
+     """Return the default hyperparameter optimization configuration for :py:class:`KGEDataLoader`.
+
+     The default configuration is::
+
+         parameters_dict = {
+             'dataloader': {
+                 'value': 'KGEDataLoader'
+             },
+             'in_path': {
+                 'value': './'
+             },
+             'ent_file': {
+                 'value': 'entity2id.txt'
+             },
+             'rel_file': {
+                 'value': 'relation2id.txt'
+             },
+             'train_file': {
+                 'value': 'train2id.txt'
+             },
+             'valid_file': {
+                 'value': 'valid2id.txt'
+             },
+             'test_file': {
+                 'value': 'test2id.txt'
+             },
+             'batch_size': {
+                 'values': [512, 1024, 2048, 4096]
+             },
+             'neg_ent': {
+                 'values': [1, 4, 16, 64]
+             },
+             'test_batch_size': {
+                 'value': 30
+             },
+             'type_constrain': {
+                 'value': True
+             },
+             'num_workers': {
+                 'value': 16
+             },
+             'train_sampler': {
+                 'value': 'BernSampler'
+             },
+             'test_sampler': {
+                 'value': 'TradTestSampler'
+             }
+         }
+
+     :returns: the default hyperparameter optimization configuration for :py:class:`KGEDataLoader`
+     :rtype: dict[str, dict[str, typing.Any]]
+     """
+
+     parameters_dict = {
+         'dataloader': {
+             'value': 'KGEDataLoader'
+         },
+         'in_path': {
+             'value': './'
+         },
+         'ent_file': {
+             'value': 'entity2id.txt'
+         },
+         'rel_file': {
+             'value': 'relation2id.txt'
+         },
+         'train_file': {
+             'value': 'train2id.txt'
+         },
+         'valid_file': {
+             'value': 'valid2id.txt'
+         },
+         'test_file': {
+             'value': 'test2id.txt'
+         },
+         'batch_size': {
+             'values': [512, 1024, 2048, 4096]
+         },
+         'neg_ent': {
+             'values': [1, 4, 16, 64]
+         },
+         'test_batch_size': {
+             'value': 30
+         },
+         'type_constrain': {
+             'value': True
+         },
+         'num_workers': {
+             'value': 16
+         },
+         'train_sampler': {
+             'value': 'BernSampler'
+         },
+         'test_sampler': {
+             'value': 'TradTestSampler'
+         }
+     }
+
+     return parameters_dict
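
Putting the loader to work: a minimal sketch of the intended call sequence, mirroring the class example above (the FB15K path and the hyperparameter values are illustrative, not required by the package)::

    from unike.data import KGEDataLoader, BernSampler, TradTestSampler

    dataloader = KGEDataLoader(
        in_path="../../benchmarks/FB15K/",
        batch_size=8192,
        neg_ent=25,
        test=True,
        test_batch_size=256,
        num_workers=16,
        train_sampler=BernSampler,
        test_sampler=TradTestSampler,
    )

    print(dataloader.get_ent_tol(), dataloader.get_rel_tol())

    for batch in dataloader.train_dataloader():
        # With BernSampler, each batch carries 'positive_sample', 'negative_sample' and 'mode' == 'bern'.
        print(batch["positive_sample"].shape, batch["negative_sample"].shape)
        break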