unike 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. unike/__init__.py +5 -0
  2. unike/config/HPOTrainer.py +305 -0
  3. unike/config/Tester.py +385 -0
  4. unike/config/Trainer.py +519 -0
  5. unike/config/TrainerAccelerator.py +39 -0
  6. unike/config/__init__.py +37 -0
  7. unike/data/BernSampler.py +168 -0
  8. unike/data/CompGCNSampler.py +140 -0
  9. unike/data/CompGCNTestSampler.py +84 -0
  10. unike/data/KGEDataLoader.py +315 -0
  11. unike/data/KGReader.py +138 -0
  12. unike/data/RGCNSampler.py +261 -0
  13. unike/data/RGCNTestSampler.py +208 -0
  14. unike/data/RevSampler.py +78 -0
  15. unike/data/TestSampler.py +189 -0
  16. unike/data/TradSampler.py +122 -0
  17. unike/data/TradTestSampler.py +87 -0
  18. unike/data/UniSampler.py +145 -0
  19. unike/data/__init__.py +47 -0
  20. unike/module/BaseModule.py +130 -0
  21. unike/module/__init__.py +20 -0
  22. unike/module/loss/CompGCNLoss.py +96 -0
  23. unike/module/loss/Loss.py +26 -0
  24. unike/module/loss/MarginLoss.py +148 -0
  25. unike/module/loss/RGCNLoss.py +117 -0
  26. unike/module/loss/SigmoidLoss.py +145 -0
  27. unike/module/loss/SoftplusLoss.py +145 -0
  28. unike/module/loss/__init__.py +35 -0
  29. unike/module/model/Analogy.py +237 -0
  30. unike/module/model/CompGCN.py +562 -0
  31. unike/module/model/ComplEx.py +235 -0
  32. unike/module/model/DistMult.py +276 -0
  33. unike/module/model/HolE.py +308 -0
  34. unike/module/model/Model.py +107 -0
  35. unike/module/model/RESCAL.py +309 -0
  36. unike/module/model/RGCN.py +304 -0
  37. unike/module/model/RotatE.py +303 -0
  38. unike/module/model/SimplE.py +237 -0
  39. unike/module/model/TransD.py +458 -0
  40. unike/module/model/TransE.py +290 -0
  41. unike/module/model/TransH.py +322 -0
  42. unike/module/model/TransR.py +402 -0
  43. unike/module/model/__init__.py +60 -0
  44. unike/module/strategy/CompGCNSampling.py +140 -0
  45. unike/module/strategy/NegativeSampling.py +138 -0
  46. unike/module/strategy/RGCNSampling.py +134 -0
  47. unike/module/strategy/Strategy.py +26 -0
  48. unike/module/strategy/__init__.py +29 -0
  49. unike/utils/EarlyStopping.py +94 -0
  50. unike/utils/Timer.py +74 -0
  51. unike/utils/WandbLogger.py +46 -0
  52. unike/utils/__init__.py +26 -0
  53. unike/utils/tools.py +118 -0
  54. unike/version.py +1 -0
  55. unike-3.0.1.dist-info/METADATA +101 -0
  56. unike-3.0.1.dist-info/RECORD +59 -0
  57. unike-3.0.1.dist-info/WHEEL +4 -0
  58. unike-3.0.1.dist-info/entry_points.txt +2 -0
  59. unike-3.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,138 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/module/strategy/NegativeSampling.py
4
+ #
5
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 9, 2024
7
+ #
8
+ # 该脚本定义了平移模型和语义匹配模型的训练策略.
9
+
10
+ """
11
+ NegativeSampling - 训练策略类,包含损失函数。
12
+ """
13
+
14
+ import torch
15
+ import typing
16
+ from ..loss import Loss
17
+ from ..model import Model
18
+ from .Strategy import Strategy
19
+
20
class NegativeSampling(Strategy):

    """
    Training strategy that bundles a KGE model together with its loss
    function so that a single ``forward`` call yields the training loss.

    Example::

        from unike.module.model import TransE
        from unike.module.loss import MarginLoss
        from unike.module.strategy import NegativeSampling

        # define the model
        transe = TransE(
            ent_tol = dataloader.get_ent_tol(),
            rel_tol = dataloader.get_rel_tol(),
            dim = 50,
            p_norm = 1,
            norm_flag = True
        )

        # define the loss function
        model = NegativeSampling(
            model = transe,
            loss = MarginLoss(margin = 1.0),
            regul_rate = 0.01
        )
    """

    def __init__(
        self,
        model: Model = None,
        loss: Loss = None,
        regul_rate: float = 0.0,
        l3_regul_rate: float = 0.0):

        """Create a NegativeSampling object.

        :param model: KGE model
        :type model: :py:class:`unike.module.model.Model`
        :param loss: loss function
        :type loss: :py:class:`unike.module.loss.Loss`
        :param regul_rate: coefficient applied to ``model.regularization(data)``
        :type regul_rate: float
        :param l3_regul_rate: coefficient applied to ``model.l3_regularization()``
        :type l3_regul_rate: float
        """

        super().__init__()

        #: coefficient of the regularization term
        self.regul_rate: float = regul_rate
        #: coefficient of the L3 regularization term
        self.l3_regul_rate: float = l3_regul_rate
        #: KGE model, i.e. :py:class:`unike.module.model.Model`
        self.model: Model = model
        #: loss function, i.e. :py:class:`unike.module.loss.Loss`
        self.loss: Loss = loss

    def forward(self, data: dict[str, typing.Union[torch.Tensor, str]]) -> torch.Tensor:

        """Compute the final loss for one batch.

        Reads ``positive_sample``, ``negative_sample`` and ``mode`` from
        *data*, scores both sample sets with the model, feeds the scores to
        the loss function and adds the enabled regularization terms.

        :param data: batch data
        :type data: dict[str, typing.Union[torch.Tensor, str]]
        :returns: loss value
        :rtype: torch.Tensor
        """

        positive = data["positive_sample"]
        negative = data["negative_sample"]
        mode = data["mode"]

        pos_score = self.model(positive)
        if mode != "bern":
            # Any other mode scores the negatives against the positive batch.
            neg_score = self.model(positive, negative, mode)
        else:
            # "bern" mode: score the negatives on their own, then reshape so
            # that there is one row per positive score.
            neg_score = self.model(negative).view(pos_score.shape[0], -1)

        total = self.loss(pos_score, neg_score)
        if self.regul_rate != 0:
            total += self.regul_rate * self.model.regularization(data)
        if self.l3_regul_rate != 0:
            total += self.l3_regul_rate * self.model.l3_regularization()
        return total
103
+
104
def get_negative_sampling_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyper-parameter optimization configuration of
    :py:class:`NegativeSampling`.

    The default configuration is::

        parameters_dict = {
            'strategy': {
                'value': 'NegativeSampling'
            },
            'regul_rate': {
                'value': 0.0
            },
            'l3_regul_rate': {
                'value': 0.0
            }
        }

    :returns: default hyper-parameter optimization configuration of :py:class:`NegativeSampling`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    # A new dict is built on every call so callers may mutate the result.
    return {
        'strategy': {'value': 'NegativeSampling'},
        'regul_rate': {'value': 0.0},
        'l3_regul_rate': {'value': 0.0},
    }
@@ -0,0 +1,134 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/module/strategy/RGCNSampling.py
4
+ #
5
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 16, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 18, 2023
7
+ #
8
+ # 该脚本定义了 R-GCN 模型的训练策略.
9
+
10
+ """
11
+ RGCNSampling - training strategy class for R-GCN, bundling the model with its loss function.
12
+ """
13
+
14
+ import dgl
15
+ import torch
16
+ import typing
17
+ from ..loss import Loss
18
+ from ..model import Model
19
+ from .Strategy import Strategy
20
+
21
class RGCNSampling(Strategy):

    """
    Training strategy that bundles a model together with its loss function,
    used for ``R-GCN`` :cite:`R-GCN`.

    Example::

        from unike.data import GraphDataLoader
        from unike.module.model import RGCN
        from unike.module.loss import RGCNLoss
        from unike.module.strategy import RGCNSampling
        from unike.config import Trainer, GraphTester

        dataloader = GraphDataLoader(
            in_path = "../../benchmarks/FB15K237/",
            batch_size = 60000,
            neg_ent = 10,
            test = True,
            test_batch_size = 100,
            num_workers = 16
        )

        # define the model
        rgcn = RGCN(
            ent_tol = dataloader.train_sampler.ent_tol,
            rel_tol = dataloader.train_sampler.rel_tol,
            dim = 500,
            num_layers = 2
        )

        # define the loss function
        model = RGCNSampling(
            model = rgcn,
            loss = RGCNLoss(model = rgcn, regularization = 1e-5)
        )

        # test the model
        tester = GraphTester(model = rgcn, data_loader = dataloader, use_gpu = True, device = 'cuda:0')

        # train the model
        trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(),
            epochs = 10000, lr = 0.0001, use_gpu = True, device = 'cuda:0',
            tester = tester, test = True, valid_interval = 500, log_interval = 500,
            save_interval = 500, save_path = '../../checkpoint/rgcn.pth'
        )
        trainer.run()
    """

    def __init__(
        self,
        model: Model = None,
        loss: Loss = None):

        """Create an RGCNSampling object.

        :param model: R-GCN model
        :type model: :py:class:`unike.module.model.RGCN`
        :param loss: loss function
        :type loss: :py:class:`unike.module.loss.Loss`
        """

        super().__init__()

        #: R-GCN model, i.e. :py:class:`unike.module.model.RGCN`
        self.model: Model = model
        #: loss function, i.e. :py:class:`unike.module.loss.Loss`
        self.loss: Loss = loss

    def forward(
        self,
        data: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]) -> torch.Tensor:

        """Compute the final loss for one batch.

        Reads ``graph``, ``entity``, ``relation``, ``norm``, ``triples`` and
        ``label`` from *data*, scores the batch with the model and passes the
        score and the label to the loss function.

        :param data: batch data
        :type data: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
        :returns: loss value
        :rtype: torch.Tensor
        """

        # Look up every field first so a missing key fails before the model runs.
        graph, entity, relation, norm, triples, label = (
            data[key]
            for key in ("graph", "entity", "relation", "norm", "triples", "label")
        )
        score = self.model(graph, entity, relation, norm, triples)
        return self.loss(score, label)
111
+
112
def get_rgcn_sampling_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyper-parameter optimization configuration of
    :py:class:`RGCNSampling`.

    The default configuration is::

        parameters_dict = {
            'strategy': {
                'value': 'RGCNSampling'
            }
        }

    :returns: default hyper-parameter optimization configuration of :py:class:`RGCNSampling`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    # A new dict is built on every call so callers may mutate the result.
    return {'strategy': {'value': 'RGCNSampling'}}
@@ -0,0 +1,26 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/module/strategy/Strategy.py
4
+ #
5
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 4, 2023
7
+ #
8
+ # This script defines the base class of the training strategies.
9
+
10
+ """
11
+ Strategy - 该脚本定义了训练策略的基类。
12
+ """
13
+
14
+ from ..BaseModule import BaseModule
15
+
16
class Strategy(BaseModule):

    """
    Base class of the training strategies.

    Inherits from :py:class:`unike.module.BaseModule` and adds no extra
    attributes.
    """

    def __init__(self):

        """Create a Strategy object."""

        super().__init__()
@@ -0,0 +1,29 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/module/strategy/__init__.py
4
+ #
5
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 28, 2023
7
+ #
8
+ # 该头文件定义了 strategy 接口.
9
+
10
+ """训练策略部分。"""
11
+
12
+ from __future__ import absolute_import
13
+ from __future__ import division
14
+ from __future__ import print_function
15
+
16
+ from .Strategy import Strategy
17
+ from .NegativeSampling import NegativeSampling, get_negative_sampling_hpo_config
18
+ from .RGCNSampling import RGCNSampling, get_rgcn_sampling_hpo_config
19
+ from .CompGCNSampling import CompGCNSampling, get_compgcn_sampling_hpo_config
20
+
21
+ __all__ = [
22
+ 'Strategy',
23
+ 'NegativeSampling',
24
+ 'get_negative_sampling_hpo_config',
25
+ 'RGCNSampling',
26
+ 'get_rgcn_sampling_hpo_config',
27
+ 'CompGCNSampling',
28
+ 'get_compgcn_sampling_hpo_config'
29
+ ]
@@ -0,0 +1,94 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/utils/EarlyStopping.py
4
+ #
5
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 5, 2024
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 6, 2024
7
+ #
8
+ # 该脚本定义了 EarlyStopping 类.
9
+
10
+ """
11
+ EarlyStopping - 使用早停止避免过拟合。
12
+ """
13
+
14
+ import os
15
+ import numpy as np
16
+ from ..module.model import Model
17
+ import logging
18
+
19
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig at import time configures the root logger for
# every importer of this module; kept as-is so existing behavior is unchanged.
logging.basicConfig(format='%(levelname)s:%(module)s:%(asctime)s:%(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S', level=logging.DEBUG)

class EarlyStopping:

    """
    Stop training early when the validation score (larger is better) has not
    improved for ``patience`` consecutive checks.
    """

    def __init__(
        self,
        save_path: str,
        patience: int = 2,
        verbose: bool = True,
        delta: float = 0):

        """Create an EarlyStopping object.

        :param save_path: directory in which the best model is saved
        :type save_path: str
        :param patience: number of non-improving checks tolerated before stopping. Default: 2
        :type patience: int
        :param verbose: if True, log a message for every score improvement. Default: True
        :type verbose: bool
        :param delta: minimum change of the monitored score that counts as an improvement. Default: 0
        :type delta: float
        """

        #: full path of the checkpoint file (``best_network.pth`` inside *save_path*)
        self.save_path: str = os.path.join(save_path, 'best_network.pth')
        #: number of non-improving checks tolerated before stopping
        self.patience: int = patience
        #: whether to log a message for every score improvement
        self.verbose: bool = verbose
        #: minimum change of the monitored score that counts as an improvement
        self.delta: float = delta

        #: number of consecutive checks without improvement
        self.counter: int = 0
        #: best score seen so far; ``np.inf`` is used because the ``np.Inf``
        #: alias was removed in NumPy 2.0
        self.best_score: float = -np.inf
        #: set to True once ``counter`` reaches ``patience``
        self.early_stop: bool = False

    def __call__(
        self,
        score: float,
        model: "Model"):

        """Record one validation result.

        If *score* does not exceed ``best_score + delta`` the counter is
        increased and ``early_stop`` is set once it reaches ``patience``;
        otherwise the model is saved and the counter is reset.

        :param score: validation score (larger is better)
        :type score: float
        :param model: model to save when the score improves
        :type model: :py:class:`unike.module.model.Model`
        """

        if score <= self.best_score + self.delta:
            self.counter += 1
            logger.info(f'EarlyStopping counter: {self.counter} / {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.save_checkpoint(score, model)
            self.counter = 0

    def save_checkpoint(
        self,
        score: float,
        model: "Model"):

        """Save the model and remember *score* as the new best score.

        :param score: the improved validation score
        :type score: float
        :param model: model whose ``save_checkpoint`` is called with ``save_path``
        :type model: :py:class:`unike.module.model.Model`
        """

        if self.verbose:
            logger.info(f'Validation score improved ({self.best_score:.6f} --> {score:.6f}). Saving model ...')
        model.save_checkpoint(self.save_path)
        self.best_score = score
unike/utils/Timer.py ADDED
@@ -0,0 +1,74 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/utils/Timer.py
4
+ #
5
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on July 6, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 24, 2023
7
+ #
8
+ # 该脚本定义了计时器类.
9
+
10
+ """
11
+ :py:class:`Timer` - 计时器类。
12
+ """
13
+
14
+ import time
15
+
16
class Timer:

    """Record the durations of several runs.

    :py:meth:`stop` returns the time elapsed since the previous
    :py:meth:`stop` call, or since the :py:class:`Timer` was created.

    :py:meth:`avg` returns the mean of the recorded durations and
    :py:meth:`sum` returns their total."""

    def __init__(self):

        """Create a Timer object and start timing immediately."""

        #: list of recorded intervals, in seconds
        self.times: list[float] = []
        #: timestamp taken at the most recent start/stop
        self.current: float = None
        #: timestamp the next interval is measured from
        self.last: float = None

        self.__restart()

    def __restart(self):

        """Restart the timer from the current time."""

        self.last = self.current = time.time()

    def stop(self) -> float:

        """Close the current interval, record it, and start the next one.

        :returns: the length of the interval that was just recorded
        :rtype: float
        """

        self.current = time.time()
        self.times.append(self.current - self.last)
        self.last = self.current
        return self.times[-1]

    def avg(self) -> float:

        """Return the mean recorded interval.

        :returns: mean interval, or ``0.0`` when nothing has been recorded yet
        :rtype: float
        """

        # Guard against division by zero when stop() has never been called.
        if not self.times:
            return 0.0
        return sum(self.times) / len(self.times)

    def sum(self) -> float:

        """Return the total of the recorded intervals.

        :returns: total time
        :rtype: float
        """

        # Inside a method, ``sum`` resolves to the builtin, not this method.
        return sum(self.times)
@@ -0,0 +1,46 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/utils/WandbLogger.py
4
+ #
5
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 1, 2024
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 24, 2024
7
+ #
8
+ # 该脚本定义了 WandbLogger 类.
9
+
10
+ """
11
+ WandbLogger - 使用 Weights and Biases 记录实验结果。
12
+ """
13
+
14
+ import typing
15
+ import wandb
16
+
17
class WandbLogger:

    """Record experiment results with `Weights and Biases <https://docs.wandb.ai/>`_."""

    def __init__(self,
        project: str ="pybind11-ke",
        name: str = "transe",
        config: dict[str, typing.Any] | None = None):

        """Create a WandbLogger object: log in to wandb and start a run.

        :param project: wandb project name
        :type project: str
        :param name: wandb run name
        :type name: str
        :param config: run configuration passed to wandb, e.g. hyper-parameters
        :type config: dict[str, typing.Any] | None
        """

        wandb.login()
        wandb.init(
            config=config,
            name=name,
            project=project,
        )

        #: the ``wandb.config`` object of the started run
        self.config: dict = wandb.config

    def finish(self):

        """Finish the wandb run."""

        wandb.finish()
@@ -0,0 +1,26 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/utils/__init__.py
4
+ #
5
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on July 6, 2023
6
+ #
7
+ # 该头文件定义了 utils 接口.
8
+
9
+ """工具类。"""
10
+
11
+ from __future__ import absolute_import
12
+ from __future__ import division
13
+ from __future__ import print_function
14
+
15
+ from .Timer import Timer
16
+ from .WandbLogger import WandbLogger
17
+ from .tools import import_class, construct_type_constrain
18
+ from .EarlyStopping import EarlyStopping
19
+
20
+ __all__ = [
21
+ 'Timer',
22
+ 'WandbLogger',
23
+ 'import_class',
24
+ 'construct_type_constrain',
25
+ 'EarlyStopping',
26
+ ]
unike/utils/tools.py ADDED
@@ -0,0 +1,118 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/utils/tools.py
4
+ #
5
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 3, 2024
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Feb 25, 2024
7
+ #
8
+ # This script defines utility functions: import_class and construct_type_constrain.
9
+
10
+ import importlib
11
+
12
def import_class(module_and_class_name: str) -> type:

    """Import a class from a module given its dotted path.

    :param module_and_class_name: module path and class name, e.g. **unike.module.model.TransE**.
    :type module_and_class_name: str
    :returns: the attribute named by the last path component of the imported module
    :rtype: type
    """

    # Split on the last dot only: everything before it is the module path.
    pieces = module_and_class_name.rsplit(".", 1)
    module_name, class_name = pieces
    return getattr(importlib.import_module(module_name), class_name)
26
+
27
def _collect_type_constrain(path: str, rel_head: dict, rel_tail: dict) -> None:

    """Add the heads and tails seen with each relation in the file at *path*
    to *rel_head* and *rel_tail*."""

    with open(path, "r", encoding="utf-8") as triples:
        # The first line holds the number of triples that follow.
        tot = int(triples.readline())
        for _ in range(tot):
            h, t, r = triples.readline().strip().split()
            # Inner dicts are used as insertion-ordered sets.
            rel_head.setdefault(r, {})[h] = 1
            rel_tail.setdefault(r, {})[t] = 1

def construct_type_constrain(
    in_path: str = "./",
    train_file: str = "train2id.txt",
    valid_file: str = "valid2id.txt",
    test_file: str = "test2id.txt"
    ):

    """Build the ``type_constrain.txt`` file inside *in_path*.

    ``type_constrain.txt`` is the type-constraint file. Its first line is the
    number of relations. The remaining lines give, for every relation, the
    heads and tails that occur with it in the training, validation and test
    sets.

    Every relation takes two tab-separated lines:

    first line: **rel_id** **heads_num** **head1** **head2** ...

    second line: **rel_id** **tails_num** **tail1** **tail2** ...

    For example, the relation with id 1200 in benchmarks/FB15K has 4 heads
    (3123, 1034, 58 and 5733) and 4 tails (12123, 4388, 11087 and 11088):

    1200	4	3123	1034	58	5733

    1200	4	12123	4388	11087	11088

    Each input file starts with a line holding its number of triples; every
    following line is read as ``head tail relation``.

    :param in_path: dataset directory (a trailing path separator is optional)
    :type in_path: str
    :param train_file: train2id.txt
    :type train_file: str
    :param valid_file: valid2id.txt
    :type valid_file: str
    :param test_file: test2id.txt
    :type test_file: str
    :raises ValueError: if a triple line does not contain exactly three fields
    """

    # Local import: this module's top-level imports do not include ``os``.
    import os

    rel_head: dict = {}
    rel_tail: dict = {}

    # os.path.join keeps "./" + name working and also accepts a directory
    # given without a trailing separator.
    for file_name in (train_file, valid_file, test_file):
        _collect_type_constrain(os.path.join(in_path, file_name), rel_head, rel_tail)

    with open(os.path.join(in_path, "type_constrain.txt"), "w", encoding="utf-8") as f:
        f.write("%d\n" % (len(rel_head)))
        for rel, heads in rel_head.items():
            # Heads line first, then tails line, both in first-seen order.
            for members in (heads, rel_tail[rel]):
                f.write("\t".join([rel, str(len(members)), *members]) + "\n")
unike/version.py ADDED
@@ -0,0 +1 @@
1
+ __version__: str = '3.0.1'