unike 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. unike/__init__.py +5 -0
  2. unike/config/HPOTrainer.py +305 -0
  3. unike/config/Tester.py +385 -0
  4. unike/config/Trainer.py +519 -0
  5. unike/config/TrainerAccelerator.py +39 -0
  6. unike/config/__init__.py +37 -0
  7. unike/data/BernSampler.py +168 -0
  8. unike/data/CompGCNSampler.py +140 -0
  9. unike/data/CompGCNTestSampler.py +84 -0
  10. unike/data/KGEDataLoader.py +315 -0
  11. unike/data/KGReader.py +138 -0
  12. unike/data/RGCNSampler.py +261 -0
  13. unike/data/RGCNTestSampler.py +208 -0
  14. unike/data/RevSampler.py +78 -0
  15. unike/data/TestSampler.py +189 -0
  16. unike/data/TradSampler.py +122 -0
  17. unike/data/TradTestSampler.py +87 -0
  18. unike/data/UniSampler.py +145 -0
  19. unike/data/__init__.py +47 -0
  20. unike/module/BaseModule.py +130 -0
  21. unike/module/__init__.py +20 -0
  22. unike/module/loss/CompGCNLoss.py +96 -0
  23. unike/module/loss/Loss.py +26 -0
  24. unike/module/loss/MarginLoss.py +148 -0
  25. unike/module/loss/RGCNLoss.py +117 -0
  26. unike/module/loss/SigmoidLoss.py +145 -0
  27. unike/module/loss/SoftplusLoss.py +145 -0
  28. unike/module/loss/__init__.py +35 -0
  29. unike/module/model/Analogy.py +237 -0
  30. unike/module/model/CompGCN.py +562 -0
  31. unike/module/model/ComplEx.py +235 -0
  32. unike/module/model/DistMult.py +276 -0
  33. unike/module/model/HolE.py +308 -0
  34. unike/module/model/Model.py +107 -0
  35. unike/module/model/RESCAL.py +309 -0
  36. unike/module/model/RGCN.py +304 -0
  37. unike/module/model/RotatE.py +303 -0
  38. unike/module/model/SimplE.py +237 -0
  39. unike/module/model/TransD.py +458 -0
  40. unike/module/model/TransE.py +290 -0
  41. unike/module/model/TransH.py +322 -0
  42. unike/module/model/TransR.py +402 -0
  43. unike/module/model/__init__.py +60 -0
  44. unike/module/strategy/CompGCNSampling.py +140 -0
  45. unike/module/strategy/NegativeSampling.py +138 -0
  46. unike/module/strategy/RGCNSampling.py +134 -0
  47. unike/module/strategy/Strategy.py +26 -0
  48. unike/module/strategy/__init__.py +29 -0
  49. unike/utils/EarlyStopping.py +94 -0
  50. unike/utils/Timer.py +74 -0
  51. unike/utils/WandbLogger.py +46 -0
  52. unike/utils/__init__.py +26 -0
  53. unike/utils/tools.py +118 -0
  54. unike/version.py +1 -0
  55. unike-3.0.1.dist-info/METADATA +101 -0
  56. unike-3.0.1.dist-info/RECORD +59 -0
  57. unike-3.0.1.dist-info/WHEEL +4 -0
  58. unike-3.0.1.dist-info/entry_points.txt +2 -0
  59. unike-3.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,189 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/data/TestSampler.py
4
+ #
5
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 29, 2024
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 25, 2024
7
+ #
8
+ # 测试数据采样器基类.
9
+
10
+ """
11
+ TestSampler - 测试数据采样器基类。
12
+ """
13
+
14
+ import os
15
+ import torch
16
+ import typing
17
+ from .TradSampler import TradSampler
18
+ from .RGCNSampler import RGCNSampler
19
+ from .CompGCNSampler import CompGCNSampler
20
+ from collections import defaultdict as ddict
21
+ from ..utils import construct_type_constrain
22
+
23
class TestSampler(object):

    """Base class for test-data samplers.

    Loads the validation and test triples, builds the set of all known true
    triples (train + valid + test) and, when ``type_constrain`` is enabled,
    the per-relation head/tail entity candidates from ``type_constrain.txt``.
    Subclasses must override :py:meth:`sampling`.
    """

    def __init__(
        self,
        sampler: typing.Union[TradSampler, RGCNSampler, CompGCNSampler],
        valid_file: str = "valid2id.txt",
        test_file: str = "test2id.txt",
        type_constrain: bool = True):

        """Create a TestSampler.

        :param sampler: the training-data sampler; its ``ent_tol``, ``in_path``,
            ``train_file`` and ``train_triples`` attributes are read.
        :type sampler: typing.Union[TradSampler, RGCNSampler, CompGCNSampler]
        :param valid_file: validation file name (inside ``sampler.in_path``)
        :type valid_file: str
        :param test_file: test file name (inside ``sampler.in_path``)
        :type test_file: str
        :param type_constrain: whether to also report results restricted by
            ``type_constrain.txt``
        :type type_constrain: bool
        """

        #: training-data sampler
        self.sampler: typing.Union[TradSampler, RGCNSampler, CompGCNSampler] = sampler
        #: number of entities
        self.ent_tol: int = sampler.ent_tol
        #: validation file name (valid2id.txt)
        self.valid_file: str = valid_file
        #: test file name (test2id.txt)
        self.test_file: str = test_file

        #: number of validation triples, as declared on the file's first line
        self.valid_tol: int = 0
        #: number of test triples, as declared on the file's first line
        self.test_tol: int = 0

        #: validation triples as (head, relation, tail)
        self.valid_triples: list[tuple[int, int, int]] = []
        #: test triples as (head, relation, tail)
        self.test_triples: list[tuple[int, int, int]] = []
        #: every known true triple (train + valid + test)
        self.all_true_triples: set[tuple[int, int, int]] = set()

        self.get_valid_test_triples_id()

        #: (h, r) -> set of true tails; values become LongTensors in get_hr2t_rt2h_from_all
        self.hr2t_all: ddict[set] = ddict(set)
        #: (r, t) -> set of true heads; values become LongTensors in get_hr2t_rt2h_from_all
        self.rt2h_all: ddict[set] = ddict(set)

        #: whether to report results restricted by type_constrain.txt
        self.type_constrain: bool = type_constrain

        #: relation -> candidate head entities (LongTensor after loading)
        self.rel_heads: ddict[set] = ddict(set)
        #: relation -> candidate tail entities (LongTensor after loading)
        self.rel_tails: ddict[set] = ddict(set)

        if self.type_constrain:
            construct_type_constrain(
                in_path = self.sampler.in_path,
                train_file = self.sampler.train_file,
                valid_file = self.valid_file,
                test_file = self.test_file
            )
            self.get_type_constrain_id()

    def _read_triples(self, file_name: str) -> tuple[int, list[tuple[int, int, int]]]:

        """Read one triple file from the dataset directory.

        The first line holds the declared triple count; every following
        non-blank line is ``head tail relation``.

        :param file_name: file name relative to ``sampler.in_path``
        :returns: the declared count and the triples in (head, relation, tail) order
        """

        triples = []
        with open(os.path.join(self.sampler.in_path, file_name)) as f:
            declared = int(f.readline())
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank (e.g. trailing) lines instead of failing to unpack.
                    continue
                h, t, r = fields
                triples.append((int(h), int(r), int(t)))
        return declared, triples

    def get_valid_test_triples_id(self):

        """Read :py:attr:`valid_file` and :py:attr:`test_file` and build :py:attr:`all_true_triples`."""

        self.valid_tol, valid = self._read_triples(self.valid_file)
        self.valid_triples.extend(valid)

        self.test_tol, test = self._read_triples(self.test_file)
        self.test_triples.extend(test)

        self.all_true_triples = set(
            self.sampler.train_triples + self.valid_triples + self.test_triples
        )

    def get_type_constrain_id(self):

        """Read ``type_constrain.txt`` into :py:attr:`rel_heads` / :py:attr:`rel_tails`.

        After the first line (the relation count), lines come in pairs per
        relation: ``rel<TAB>count<TAB>ent...`` for heads, then the same for tails.
        Each entity set is finally converted to a LongTensor so it can be used
        as an index.
        """

        with open(os.path.join(self.sampler.in_path, "type_constrain.txt")) as f:
            f.readline()  # relation count; not needed here
            first_line = True
            for line in f:
                rel_types = line.strip().split("\t")
                if not rel_types[0]:
                    # Skip blank lines so the head/tail alternation stays in sync.
                    continue
                target = self.rel_heads if first_line else self.rel_tails
                # Indexing creates the key even when the entity list is empty,
                # so later lookups yield an (empty) tensor rather than a set.
                target[int(rel_types[0])].update(int(entity) for entity in rel_types[2:])
                first_line = not first_line

        for mapping in (self.rel_heads, self.rel_tails):
            for rel, entities in mapping.items():
                # dtype=long keeps empty sets usable as tensor indices.
                mapping[rel] = torch.tensor(list(entities), dtype=torch.long)

    def get_hr2t_rt2h_from_all(self):

        """Build :py:attr:`hr2t_all` and :py:attr:`rt2h_all` from :py:attr:`all_true_triples`.

        Values end up as LongTensors of entity ids.
        """

        for h, r, t in self.all_true_triples:
            self.hr2t_all[(h, r)].add(t)
            self.rt2h_all[(r, t)].add(h)
        for mapping in (self.hr2t_all, self.rt2h_all):
            for key, entities in mapping.items():
                mapping[key] = torch.tensor(list(entities), dtype=torch.long)

    def sampling(
        self,
        data: list[tuple[int, int, int]]) -> dict[str, torch.Tensor]:

        """Sampling function; subclasses must override it.

        :param data: the true triples to test
        :type data: list[tuple[int, int, int]]
        :returns: test data
        :rtype: dict[str, torch.Tensor]
        :raises NotImplementedError: always, in this base class
        """

        raise NotImplementedError

    def get_valid(self) -> list[tuple[int, int, int]]:

        """Return the validation triples.

        :returns: :py:attr:`valid_triples`
        :rtype: list[tuple[int, int, int]]
        """

        return self.valid_triples

    def get_test(self) -> list[tuple[int, int, int]]:

        """Return the test triples.

        :returns: :py:attr:`test_triples`
        :rtype: list[tuple[int, int, int]]
        """

        return self.test_triples

    def get_all_true_triples(self) -> set[tuple[int, int, int]]:

        """Return every known true triple of the knowledge graph.

        :returns: :py:attr:`all_true_triples`
        :rtype: set[tuple[int, int, int]]
        """

        return self.all_true_triples
@@ -0,0 +1,122 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/data/TradSampler.py
4
+ #
5
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 28, 2024
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 25, 2024
7
+ #
8
+ # 为 KGReader 增加构建负三元组的函数,用于平移模型和语义匹配模型.
9
+
10
+ """
11
+ TradSampler - 为 KGReader 增加构建负三元组的函数,用于平移模型和语义匹配模型。
12
+ """
13
+
14
+ import torch
15
+ import typing
16
+ import numpy as np
17
+ from .KGReader import KGReader
18
+
19
class TradSampler(KGReader):

    """Base class for samplers of translation and semantic-matching models.

    Extends :py:class:`KGReader` with helpers that corrupt the head or the
    tail of a positive triple to build negative triples.
    """

    def __init__(
        self,
        in_path: str = "./",
        ent_file: str = "entity2id.txt",
        rel_file: str = "relation2id.txt",
        train_file: str = "train2id.txt",
        batch_size: int | None = None,
        neg_ent: int = 1):

        """Create a TradSampler.

        :param in_path: dataset directory
        :type in_path: str
        :param ent_file: entity2id.txt
        :type ent_file: str
        :param rel_file: relation2id.txt
        :type rel_file: str
        :param train_file: train2id.txt
        :type train_file: str
        :param batch_size: has no effect in this sampler; placeholder only
        :type batch_size: int | None
        :param neg_ent: number of negative triples built per positive triple
            by replacing an entity
        :type neg_ent: int
        """

        super().__init__(
            in_path=in_path,
            ent_file=ent_file,
            rel_file=rel_file,
            train_file=train_file
        )

        #: batch size (placeholder)
        self.batch_size: int | None = batch_size
        #: negatives per positive triple, replacing an entity (head or tail)
        self.neg_ent: int = neg_ent

        self.get_hr2t_rt2h_from_train()

    def sampling(
        self,
        pos_triples: list[tuple[int, int, int]]) -> dict[str, typing.Union[str, torch.Tensor]]:

        """Training-data sampling function; subclasses must override it.

        :param pos_triples: true triples of the knowledge graph
        :type pos_triples: list[tuple[int, int, int]]
        :returns: training data for translation / semantic-matching models
        :rtype: dict[str, typing.Union[str, torch.Tensor]]
        :raises NotImplementedError: always, in this base class
        """

        raise NotImplementedError

    def _sample_excluding(self, known, num_max: int) -> np.ndarray:

        """Draw ``num_max`` entity ids uniformly in ``[0, ent_tol)`` and drop those in ``known``.

        :param known: entity ids that must not be returned (array-like)
        :param num_max: number of candidates drawn before filtering
        :returns: the surviving candidates (possibly fewer than ``num_max``)
        """

        candidates = torch.randint(low=0, high=self.ent_tol, size=(num_max,)).numpy()
        # Candidates are drawn with replacement and may repeat, so assume_unique
        # must stay False: NumPy documents its fast path as giving incorrect
        # results when inputs are not unique. np.isin also replaces the
        # deprecated np.in1d.
        mask = np.isin(candidates, known, invert=True)
        return candidates[mask]

    def corrupt_head(
        self,
        t: int,
        r: int,
        num_max: int = 1) -> np.ndarray:

        """Build negative triples by replacing the head entity.

        Candidates that are true heads for ``(r, t)`` in the training data
        (``rt2h_train``, presumably built by ``get_hr2t_rt2h_from_train``) are removed.

        :param t: tail entity
        :type t: int
        :param r: relation
        :type r: int
        :param num_max: number of candidates drawn in one go
        :type num_max: int
        :returns: head entities of the negative triples
        :rtype: numpy.ndarray
        """

        return self._sample_excluding(self.rt2h_train[(r, t)], num_max)

    def corrupt_tail(
        self,
        h: int,
        r: int,
        num_max: int = 1) -> np.ndarray:

        """Build negative triples by replacing the tail entity.

        Candidates that are true tails for ``(h, r)`` in the training data
        (``hr2t_train``, presumably built by ``get_hr2t_rt2h_from_train``) are removed.

        :param h: head entity
        :type h: int
        :param r: relation
        :type r: int
        :param num_max: number of candidates drawn in one go
        :type num_max: int
        :returns: tail entities of the negative triples
        :rtype: numpy.ndarray
        """

        return self._sample_excluding(self.hr2t_train[(h, r)], num_max)
@@ -0,0 +1,87 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/data/TradTestSampler.py
4
+ #
5
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2024
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 29, 2024
7
+ #
8
+ # 平移模型和语义匹配模型的测试数据采样器.
9
+
10
+ """
11
+ TradTestSampler - 平移模型和语义匹配模型的测试数据采样器。
12
+ """
13
+
14
+ import torch
15
+ from .TradSampler import TradSampler
16
+ from .TestSampler import TestSampler
17
+ from typing_extensions import override
18
+
19
class TradTestSampler(TestSampler):

    """Test-data sampler for translation and semantic-matching models."""

    def __init__(
        self,
        sampler: TradSampler,
        valid_file: str = "valid2id.txt",
        test_file: str = "test2id.txt",
        type_constrain: bool = True):

        """Create a TradTestSampler.

        :param sampler: the training-data sampler
        :type sampler: TradSampler
        :param valid_file: valid2id.txt
        :type valid_file: str
        :param test_file: test2id.txt
        :type test_file: str
        :param type_constrain: whether to also report results restricted by
            ``type_constrain.txt``
        :type type_constrain: bool
        """

        super().__init__(
            sampler=sampler,
            valid_file=valid_file,
            test_file=test_file,
            type_constrain=type_constrain
        )

        self.get_hr2t_rt2h_from_all()

    @override
    def sampling(
        self,
        data: list[tuple[int, int, int]]) -> dict[str, torch.Tensor]:

        """Build the label tensors for a batch of test triples.

        :param data: the true (head, relation, tail) triples to test
        :type data: list[tuple[int, int, int]]
        :returns: a dict with
            ``positive_sample`` (the input triples as a LongTensor),
            ``head_label`` / ``tail_label`` (``len(data) x ent_tol``, 1.0 at every
            known true head / tail), and, when :py:attr:`type_constrain` is set,
            ``head_label_type`` / ``tail_label_type`` (1.0 everywhere except the
            relation's candidate entities from type_constrain.txt, which are 0.0
            unless they are known true heads / tails)
        :rtype: dict[str, torch.Tensor]
        """

        batch_size = len(data)
        head_label = torch.zeros(batch_size, self.ent_tol)
        tail_label = torch.zeros(batch_size, self.ent_tol)
        if self.type_constrain:
            head_label_type = torch.ones(batch_size, self.ent_tol)
            tail_label_type = torch.ones(batch_size, self.ent_tol)

        # One pass fills both the plain and the type-constrained labels.
        for idx, (head, rel, tail) in enumerate(data):
            true_heads = self.rt2h_all[(rel, tail)]
            true_tails = self.hr2t_all[(head, rel)]
            head_label[idx, true_heads] = 1.0
            tail_label[idx, true_tails] = 1.0
            if self.type_constrain:
                head_label_type[idx, self.rel_heads[rel]] = 0.0
                tail_label_type[idx, self.rel_tails[rel]] = 0.0
                # Known true entities must stay marked even if type-valid.
                head_label_type[idx, true_heads] = 1.0
                tail_label_type[idx, true_tails] = 1.0

        batch_data = {}
        if self.type_constrain:
            batch_data["head_label_type"] = head_label_type
            batch_data["tail_label_type"] = tail_label_type
        # Explicit long dtype so an empty batch does not yield a float tensor.
        batch_data["positive_sample"] = torch.tensor(data, dtype=torch.long)
        batch_data["head_label"] = head_label
        batch_data["tail_label"] = tail_label
        return batch_data
@@ -0,0 +1,145 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/data/UniSampler.py
4
+ #
5
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2024
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 29, 2024
7
+ #
8
+ # 平移模型和语义匹配模型的训练集数据采样器.
9
+
10
+ """
11
+ UniSampler - 平移模型和语义匹配模型的训练集数据采样器。
12
+ """
13
+
14
+ import torch
15
+ import typing
16
+ import warnings
17
+ import numpy as np
18
+ from .TradSampler import TradSampler
19
+ from typing_extensions import override
20
+
21
# NOTE(review): this installs an "ignore everything" filter for the whole
# process as soon as the module is imported (it also hides e.g. NumPy
# deprecation warnings raised elsewhere); consider scoping it with
# warnings.catch_warnings() around the code that actually needs it.
warnings.simplefilter("ignore")
22
+
23
class UniSampler(TradSampler):

    """Training-data sampler for translation and semantic-matching models
    that draws negative entities uniformly at random."""

    def __init__(
        self,
        in_path: str = "./",
        ent_file: str = "entity2id.txt",
        rel_file: str = "relation2id.txt",
        train_file: str = "train2id.txt",
        batch_size: int | None = None,
        neg_ent: int = 1):

        """Create a UniSampler.

        :param in_path: dataset directory
        :type in_path: str
        :param ent_file: entity2id.txt
        :type ent_file: str
        :param rel_file: relation2id.txt
        :type rel_file: str
        :param train_file: train2id.txt
        :type train_file: str
        :param batch_size: has no effect in this sampler; placeholder only
        :type batch_size: int | None
        :param neg_ent: number of negative triples built per positive triple
            by replacing an entity
        :type neg_ent: int
        """

        super().__init__(
            in_path=in_path,
            ent_file=ent_file,
            rel_file=rel_file,
            train_file=train_file,
            batch_size = batch_size,
            neg_ent = neg_ent
        )

        # Toggled on every sampling() call to alternate head- and tail-batches.
        self.cross_sampling_flag = 0

    @override
    def sampling(
        self,
        pos_triples: list[tuple[int, int, int]]) -> dict[str, typing.Union[str, torch.Tensor]]:

        """Uniform training-data sampling.

        Successive calls alternate between corrupting tails ("tail-batch")
        and heads ("head-batch"), starting with tails.

        :param pos_triples: true (head, relation, tail) triples
        :type pos_triples: list[tuple[int, int, int]]
        :returns: ``mode``, ``positive_sample`` and ``negative_sample``
            (one row of :py:attr:`neg_ent` entity ids per positive triple)
        :rtype: dict[str, typing.Union[str, torch.Tensor]]
        """

        batch_data = {}
        self.cross_sampling_flag = 1 - self.cross_sampling_flag
        if self.cross_sampling_flag == 0:
            batch_data['mode'] = "head-batch"
            neg_ent_sample = [self.head_batch(t, r, self.neg_ent) for _, r, t in pos_triples]
        else:
            batch_data['mode'] = "tail-batch"
            neg_ent_sample = [self.tail_batch(h, r, self.neg_ent) for h, r, _ in pos_triples]

        batch_data["positive_sample"] = torch.LongTensor(np.array(pos_triples))
        batch_data['negative_sample'] = torch.LongTensor(np.array(neg_ent_sample))
        return batch_data

    def _collect_negatives(self, corrupt, neg_size: int | None) -> np.ndarray:

        """Call ``corrupt(num_max)`` until at least ``neg_size`` negatives are gathered.

        Each round asks for twice the number still missing, because filtering
        out true entities can shrink the draw.

        :param corrupt: callable taking ``num_max`` and returning an array of ids
        :param neg_size: number of negatives wanted; ``None`` means :py:attr:`neg_ent`
        :returns: exactly ``neg_size`` entity ids (empty when ``neg_size <= 0``)
        """

        if neg_size is None:
            neg_size = self.neg_ent
        if neg_size <= 0:
            # np.concatenate([]) would raise; return an empty id array instead.
            return np.empty(0, dtype=np.int64)

        chunks = []
        collected = 0
        while collected < neg_size:
            chunk = corrupt((neg_size - collected) * 2)
            chunks.append(chunk)
            collected += len(chunk)
        return np.concatenate(chunks)[:neg_size]

    def head_batch(
        self,
        t: int,
        r: int,
        neg_size: int | None = None) -> np.ndarray:

        """Build negative triples by replacing the head entity.

        :param t: tail entity
        :type t: int
        :param r: relation
        :type r: int
        :param neg_size: number of negatives; defaults to :py:attr:`neg_ent`
        :type neg_size: int | None
        :returns: head entities of the negative triples
        :rtype: numpy.ndarray
        """

        return self._collect_negatives(
            lambda num_max: self.corrupt_head(t, r, num_max=num_max), neg_size)

    def tail_batch(
        self,
        h: int,
        r: int,
        neg_size: int | None = None) -> np.ndarray:

        """Build negative triples by replacing the tail entity.

        :param h: head entity
        :type h: int
        :param r: relation
        :type r: int
        :param neg_size: number of negatives; defaults to :py:attr:`neg_ent`
        :type neg_size: int | None
        :returns: tail entities of the negative triples
        :rtype: numpy.ndarray
        """

        return self._collect_negatives(
            lambda num_max: self.corrupt_tail(h, r, num_max=num_max), neg_size)
unike/data/__init__.py ADDED
@@ -0,0 +1,47 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/data/__init__.py
4
+ #
5
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 29, 2023
7
+ #
8
+ # 该头文件定义了 data 接口.
9
+
10
+ """数据采样部分,包含为训练和验证模型定义的数据采样器。"""
11
+
12
+ from __future__ import absolute_import
13
+ from __future__ import division
14
+ from __future__ import print_function
15
+
16
+ from .KGReader import KGReader
17
+
18
+ from .TradSampler import TradSampler
19
+ from .UniSampler import UniSampler
20
+ from .BernSampler import BernSampler
21
+
22
+ from .RevSampler import RevSampler
23
+ from .RGCNSampler import RGCNSampler
24
+ from .CompGCNSampler import CompGCNSampler
25
+
26
+ from .TestSampler import TestSampler
27
+ from .TradTestSampler import TradTestSampler
28
+ from .RGCNTestSampler import RGCNTestSampler
29
+ from .CompGCNTestSampler import CompGCNTestSampler
30
+
31
+ from .KGEDataLoader import KGEDataLoader, get_kge_data_loader_hpo_config
32
+
33
# Public API of the data package; every name is imported at the top of this module.
__all__ = [
    "KGReader",
    "TradSampler",
    "UniSampler",
    "BernSampler",
    "RevSampler",
    "RGCNSampler",
    "CompGCNSampler",
    "TestSampler",
    "TradTestSampler",
    "RGCNTestSampler",
    "CompGCNTestSampler",
    "KGEDataLoader",
    "get_kge_data_loader_hpo_config",
]