unike-3.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unike/__init__.py +5 -0
- unike/config/HPOTrainer.py +305 -0
- unike/config/Tester.py +385 -0
- unike/config/Trainer.py +519 -0
- unike/config/TrainerAccelerator.py +39 -0
- unike/config/__init__.py +37 -0
- unike/data/BernSampler.py +168 -0
- unike/data/CompGCNSampler.py +140 -0
- unike/data/CompGCNTestSampler.py +84 -0
- unike/data/KGEDataLoader.py +315 -0
- unike/data/KGReader.py +138 -0
- unike/data/RGCNSampler.py +261 -0
- unike/data/RGCNTestSampler.py +208 -0
- unike/data/RevSampler.py +78 -0
- unike/data/TestSampler.py +189 -0
- unike/data/TradSampler.py +122 -0
- unike/data/TradTestSampler.py +87 -0
- unike/data/UniSampler.py +145 -0
- unike/data/__init__.py +47 -0
- unike/module/BaseModule.py +130 -0
- unike/module/__init__.py +20 -0
- unike/module/loss/CompGCNLoss.py +96 -0
- unike/module/loss/Loss.py +26 -0
- unike/module/loss/MarginLoss.py +148 -0
- unike/module/loss/RGCNLoss.py +117 -0
- unike/module/loss/SigmoidLoss.py +145 -0
- unike/module/loss/SoftplusLoss.py +145 -0
- unike/module/loss/__init__.py +35 -0
- unike/module/model/Analogy.py +237 -0
- unike/module/model/CompGCN.py +562 -0
- unike/module/model/ComplEx.py +235 -0
- unike/module/model/DistMult.py +276 -0
- unike/module/model/HolE.py +308 -0
- unike/module/model/Model.py +107 -0
- unike/module/model/RESCAL.py +309 -0
- unike/module/model/RGCN.py +304 -0
- unike/module/model/RotatE.py +303 -0
- unike/module/model/SimplE.py +237 -0
- unike/module/model/TransD.py +458 -0
- unike/module/model/TransE.py +290 -0
- unike/module/model/TransH.py +322 -0
- unike/module/model/TransR.py +402 -0
- unike/module/model/__init__.py +60 -0
- unike/module/strategy/CompGCNSampling.py +140 -0
- unike/module/strategy/NegativeSampling.py +138 -0
- unike/module/strategy/RGCNSampling.py +134 -0
- unike/module/strategy/Strategy.py +26 -0
- unike/module/strategy/__init__.py +29 -0
- unike/utils/EarlyStopping.py +94 -0
- unike/utils/Timer.py +74 -0
- unike/utils/WandbLogger.py +46 -0
- unike/utils/__init__.py +26 -0
- unike/utils/tools.py +118 -0
- unike/version.py +1 -0
- unike-3.0.1.dist-info/METADATA +101 -0
- unike-3.0.1.dist-info/RECORD +59 -0
- unike-3.0.1.dist-info/WHEEL +4 -0
- unike-3.0.1.dist-info/entry_points.txt +2 -0
- unike-3.0.1.dist-info/licenses/LICENSE +21 -0
unike/config/Trainer.py
ADDED
@@ -0,0 +1,519 @@
# coding:utf-8
#
# unike/config/Trainer.py
#
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 9, 2024
#
# This script defines the base class of the training loop.

"""
Trainer - training loop class.
"""

import os
import dgl
import wandb
import typing
import torch
from .Tester import Tester
import torch.optim as optim
from ..utils.Timer import Timer
from ..module.model import Model
from torch.utils.data import DataLoader
from ..utils.EarlyStopping import EarlyStopping
from ..module.strategy import Strategy
from accelerate import Accelerator
import logging

logger = logging.getLogger(__name__)
logging.basicConfig(format='%(levelname)s:%(module)s:%(asctime)s:%(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S', level=logging.DEBUG)

class Trainer(object):

    """
    Mainly used for training KGE models.

    Example::

        from unike.data import KGEDataLoader, BernSampler, TradTestSampler
        from unike.module.model import TransE
        from unike.module.loss import MarginLoss
        from unike.module.strategy import NegativeSampling
        from unike.config import Trainer, Tester

        # dataloader for training
        dataloader = KGEDataLoader(
            in_path = "../../benchmarks/FB15K/",
            batch_size = 8192,
            neg_ent = 25,
            test = True,
            test_batch_size = 256,
            num_workers = 16,
            train_sampler = BernSampler,
            test_sampler = TradTestSampler
        )

        # define the model
        transe = TransE(
            ent_tol = dataloader.get_ent_tol(),
            rel_tol = dataloader.get_rel_tol(),
            dim = 50,
            p_norm = 1,
            norm_flag = True)

        # define the loss function
        model = NegativeSampling(
            model = transe,
            loss = MarginLoss(margin = 1.0),
            regul_rate = 0.01
        )

        # test the model
        tester = Tester(model = transe, data_loader = dataloader, use_gpu = True, device = 'cuda:1')

        # train the model
        trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(),
            epochs = 1000, lr = 0.01, use_gpu = True, device = 'cuda:1',
            tester = tester, test = True, valid_interval = 100,
            log_interval = 100, save_interval = 100,
            save_path = '../../checkpoint/transe.pth', delta = 0.01)
        trainer.run()
    """

    def __init__(
        self,
        model: Strategy = None,
        data_loader: DataLoader = None,
        epochs: int = 1000,
        lr: float = 0.5,
        opt_method: str = "Adam",
        accelerator: Accelerator = None,
        use_gpu: bool = True,
        device: str = "cuda:0",
        tester: Tester | None = None,
        test: bool = False,
        valid_interval: int | None = None,
        log_interval: int | None = None,
        save_interval: int | None = None,
        save_path: str = None,
        use_early_stopping: bool = True,
        metric: str = 'hits@10',
        patience: int = 2,
        delta: float = 0,
        use_wandb: bool = False):

        """Create a Trainer object.

        :param model: training strategy class wrapping the KGE model
        :type model: :py:class:`unike.module.strategy.Strategy`
        :param data_loader: :py:class:`torch.utils.data.DataLoader`
        :type data_loader: torch.utils.data.DataLoader
        :param epochs: number of training epochs
        :type epochs: int
        :param lr: learning rate
        :type lr: float
        :param opt_method: optimizer: **'Adam'** or **'adam'**, **'Adagrad'** or **'adagrad'**, **'SGD'** or **'sgd'**
        :type opt_method: str
        :param accelerator: the last element of the list returned by :py:meth:`unike.config.accelerator_prepare`
        :type accelerator: object
        :param use_gpu: whether to use a GPU
        :type use_gpu: bool
        :param device: which GPU to use
        :type device: str
        :param tester: validation class used for model evaluation
        :type tester: :py:class:`unike.config.Tester`
        :param test: whether to evaluate the model on the test set; requires :py:attr:`tester` to be non-empty
        :type test: bool
        :param valid_interval: evaluate the model on the validation set every this many epochs; requires :py:attr:`tester` to be non-empty
        :type valid_interval: int
        :param log_interval: log every this many epochs
        :type log_interval: int
        :param save_interval: save the model every this many epochs
        :type save_interval: int
        :param save_path: path where the model is saved
        :type save_path: str
        :param use_early_stopping: whether to enable early stopping; requires :py:attr:`tester` and :py:attr:`save_path` to be non-empty
        :type use_early_stopping: bool
        :param metric: validation metric used for early stopping, one of **'mr'**, **'mrr'**, **'hits@N'**, **'mr_type'**, **'mrr_type'**, **'hits@N_type'**. Default: **'hits@10'**
        :type metric: str
        :param patience: the :py:attr:`unike.utils.EarlyStopping.patience` parameter, how long to wait after the last improvement of the validation score. Default: 2
        :type patience: int
        :param delta: the :py:attr:`unike.utils.EarlyStopping.delta` parameter, minimum change of the monitored quantity to qualify as an improvement. Default: 0
        :type delta: float
        :param use_wandb: whether to enable wandb logging
        :type use_wandb: bool
        """

        #: training strategy class wrapping the KGE model, i.e. :py:class:`unike.module.strategy.Strategy`
        self.model: Strategy = model

        #: the :py:class:`torch.utils.data.DataLoader` passed to :py:meth:`__init__`
        self.data_loader: torch.utils.data.DataLoader = data_loader
        #: epochs
        self.epochs: int = epochs

        #: learning rate
        self.lr: float = lr
        #: name of the optimizer passed in by the user
        self.opt_method: str = opt_method
        #: optimizer built from the ``opt_method`` argument of :py:meth:`__init__`
        self.optimizer: torch.optim.SGD | torch.optim.Adagrad | torch.optim.Adam | None = None
        #: learning rate scheduler
        self.scheduler: torch.optim.lr_scheduler.MultiStepLR | None = None

        #: whether to run distributed parallel training; the last element of the list returned by :py:meth:`unike.config.accelerator_prepare`
        self.accelerator = accelerator

        #: whether to use a GPU
        self.use_gpu: bool = use_gpu
        #: the GPU, a :py:class:`torch.device` object built from ``device``
        self.device: typing.Union[torch.device, str] = torch.device(device) if self.use_gpu else "cpu"

        #: validation class used for model evaluation
        self.tester: Tester | None = tester
        #: whether to evaluate the model on the test set; requires :py:attr:`tester` to be non-empty
        self.test: bool = test
        #: evaluate the model on the validation set every this many epochs; requires :py:attr:`tester` to be non-empty
        self.valid_interval: int | None = valid_interval

        #: log every this many epochs
        self.log_interval: int | None = log_interval
        #: save the model every this many epochs
        self.save_interval: int | None = save_interval
        #: path where the model is saved
        self.save_path: str | None = save_path

        #: whether to enable early stopping; requires :py:attr:`tester` and :py:attr:`save_path` to be non-empty
        self.use_early_stopping: bool = use_early_stopping
        #: validation metric used for early stopping, one of **'mr'**, **'mrr'**, **'hits@N'**, **'mr_type'**, **'mrr_type'**, **'hits@N_type'**. Default: **'hits@10'**
        self.metric: str = metric
        #: the :py:attr:`unike.utils.EarlyStopping.patience` parameter, how long to wait after the last improvement of the validation score. Default: 2
        self.patience: int = patience
        #: the :py:attr:`unike.utils.EarlyStopping.delta` parameter, minimum change of the monitored quantity to qualify as an improvement. Default: 0
        self.delta: float = delta
        #: early stopping object
        self.early_stopping: EarlyStopping = None

        #: whether to enable wandb logging
        self.use_wandb: bool = use_wandb

    def configure_optimizers(self):

        """Optimizer configuration can be customized by overriding this method."""

        if self.opt_method == "Adam" or self.opt_method == "adam":
            self.optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.lr,
            )
        elif self.opt_method == "Adagrad" or self.opt_method == "adagrad":
            self.optimizer = optim.Adagrad(
                self.model.parameters(),
                lr=self.lr,
            )
        elif self.opt_method == "SGD" or self.opt_method == "sgd":
            self.optimizer = optim.SGD(
                self.model.parameters(),
                lr = self.lr,
                momentum=0.9,
            )

        if self.accelerator:
            self.optimizer = self.accelerator.prepare(self.optimizer)

        milestones = int(self.epochs / 3)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[milestones, milestones*2],
            gamma=0.1
        )

    def train_one_step(
        self,
        data: dict[str, typing.Union[str, dgl.DGLGraph, torch.Tensor]]) -> float:

        """Train the model for one step on one batch ``data`` produced by
        :py:attr:`data_loader`.

        :param data: training data
        :type data: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
        :returns: loss value
        :rtype: float
        """

        self.optimizer.zero_grad()
        if not self.accelerator:
            data = {key : self.to_var(value) if key != 'mode' else value for key, value in data.items()}
        loss = self.model(data)
        if not self.accelerator:
            loss.backward()
        else:
            self.accelerator.backward(loss)
        self.optimizer.step()
        return loss.item()

    def run(self):

        """
        The training loop: first move :py:attr:`model` to the GPU according to :py:attr:`use_gpu`,
        then build :py:attr:`optimizer` from :py:attr:`opt_method`, and finally iterate over
        :py:attr:`data_loader` and train with :py:meth:`train_one_step`.
        """

        if not self.accelerator and self.use_gpu:
            self.model.cuda(device = self.device)

        if self.use_early_stopping and self.tester and self.save_path:
            self.early_stopping = EarlyStopping(
                save_path = os.path.split(self.save_path)[0],
                patience = self.patience,
                delta = self.delta)

        self.configure_optimizers()

        logger.info(f"[{self.get_device()}] Initialization completed, start model training.")

        if self.use_wandb:
            if not self.accelerator:
                wandb.watch(self.model.model, log_freq=100)
            else:
                wandb.watch(self.model.module.model, log_freq=100)

        timer = Timer()

        for epoch in range(self.epochs):

            res = 0.0
            if not self.accelerator:
                self.model.model.train()
            else:
                self.model.module.model.train()

            for data in self.data_loader:
                loss = self.train_one_step(data)
                res += loss
            timer.stop()
            self.scheduler.step()

            if self.is_local_main_process():

                if self.valid_interval and self.tester and \
                        (epoch + 1) % self.valid_interval == 0:
                    logger.info(f"[{self.get_device()}] Epoch {epoch+1} | The model starts evaluation on the validation set.")
                    self.print_test("link_valid", epoch)

                    if self.early_stopping and self.early_stopping.early_stop:
                        logger.info(f"[{self.get_device()}] Send an early stopping signal")
                        if self.accelerator:
                            self.accelerator.set_trigger()
                        else:
                            break

                if self.save_interval and self.save_path and (epoch + 1) % self.save_interval == 0:
                    path = os.path.join(os.path.splitext(self.save_path)[0] + "-" + str(epoch+1) + \
                        os.path.splitext(self.save_path)[-1])
                    self.get_model().save_checkpoint(path)
                    logger.info(f"[{self.get_device()}] Epoch {epoch+1} | Training checkpoint saved at {path}")

            if self.accelerator and self.accelerator.check_trigger():
                logger.info(f"[{self.get_device()}] Early stopping")
                break

            if self.log_interval and (epoch + 1) % self.log_interval == 0:
                if self.is_local_main_process() and self.use_wandb:
                    wandb.log({"train/train_loss" : res, "train/epoch" : epoch + 1})
                logger.info(f"[{self.get_device()}] Epoch [{epoch+1:>4d}/{self.epochs:>4d}] | Batchsize: {self.data_loader.batch_size} | loss: {res:>9f} | {timer.avg():.5f} seconds/epoch")

        logger.info(f"[{self.get_device()}] The model training is completed, taking a total of {timer.sum():.5f} seconds.")

        if self.is_local_main_process():

            if self.use_wandb:
                wandb.log({"duration" : timer.sum()})

            if self.save_path:
                self.get_model().save_checkpoint(self.save_path)
                logger.info(f"[{self.get_device()}] Model saved at {self.save_path}.")

            if self.test and self.tester:
                logger.info(f"[{self.get_device()}] The model starts evaluating in the test set.")
                self.print_test("link_test")

    def print_test(
        self,
        sampling_mode: str,
        epoch: int = 0):

        """Run link prediction according to the type of :py:attr:`tester`.

        :param sampling_mode: sampling mode, ``'link_valid'`` or ``'link_test'``
        :type sampling_mode: str
        """

        self.tester.set_sampling_mode(sampling_mode)

        if sampling_mode == "link_test":
            mode = "test"
        elif sampling_mode == "link_valid":
            mode = "val"

        results = self.tester.run_link_prediction()
        for key, value in results.items():
            logger.info(f"{key}: {value}")
        if self.use_wandb:
            log_dict = {f"{mode}/{key}" : value for key, value in results.items()}
            if sampling_mode == "link_valid":
                log_dict.update({
                    "val/epoch": epoch
                })
            wandb.log(log_dict)

        if self.early_stopping is not None and sampling_mode == "link_valid":
            if self.metric in ['mr', 'mr_type']:
                self.early_stopping(-results[self.metric], self.get_model())
            elif self.metric in results.keys():
                self.early_stopping(results[self.metric], self.get_model())
            else:
                raise ValueError("Early stopping metric is not valid.")

    def to_var(
        self,
        x: torch.Tensor) -> torch.Tensor:

        """Move ``x`` to the corresponding device.

        :param x: data
        :type x: torch.Tensor
        :returns: tensor
        :rtype: torch.Tensor
        """

        if self.use_gpu:
            return x.to(self.device)
        else:
            return x

    def get_model(self) -> Model:

        """Return the raw KGE model.

        :returns: the KGE model
        :rtype: :py:class:`unike.module.model.Model`
        """

        if self.accelerator:
            return self.model.module.model
        else:
            return self.model.model

    def get_device(self) -> typing.Union[torch.device, str]:

        """Return the device of the current process.

        :returns: device information
        :rtype: typing.Union[torch.device, str]
        """

        if self.accelerator:
            return self.model.device
        else:
            return self.device

    def is_local_main_process(self) -> bool:

        """Whether the current process is the main process.

        :returns: whether the current process is the main process
        :rtype: bool
        """

        return not self.accelerator or self.accelerator.is_local_main_process

def get_trainer_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyperparameter optimization configuration of :py:class:`Trainer`.

    The default configuration is::

        parameters_dict = {
            'trainer': {
                'value': 'Trainer'
            },
            'epochs': {
                'value': 10000
            },
            'lr': {
                'distribution': 'uniform',
                'min': 1e-5,
                'max': 1.0
            },
            'opt_method': {
                'values': ['adam', 'adagrad', 'sgd']
            },
            'valid_interval': {
                'value': 100
            },
            'log_interval': {
                'value': 100
            },
            'save_path': {
                'value': './model.pth'
            },
            'use_early_stopping': {
                'value': True
            },
            'metric': {
                'value': 'hits@10'
            },
            'patience': {
                'value': 2
            },
            'delta': {
                'value': 0.0001
            },
        }

    :returns: the default hyperparameter optimization configuration of :py:class:`Trainer`
    :rtype: dict[str, dict[str, typing.Any]]
    """

    parameters_dict = {
        'trainer': {
            'value': 'Trainer'
        },
        'epochs': {
            'value': 10000
        },
        'lr': {
            'distribution': 'uniform',
            'min': 1e-5,
            'max': 1.0
        },
        'opt_method': {
            'values': ['adam', 'adagrad', 'sgd']
        },
        'valid_interval': {
            'value': 100
        },
        'log_interval': {
            'value': 100
        },
        'save_path': {
            'value': './model.pth'
        },
        'use_early_stopping': {
            'value': True
        },
        'metric': {
            'value': 'hits@10'
        },
        'patience': {
            'value': 2
        },
        'delta': {
            'value': 0.0001
        },
    }

    return parameters_dict
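`configure_optimizers` above only knows Adam, Adagrad, and SGD, but its docstring invites overriding. Below is a minimal sketch (not part of the package) of such an override; AdamW and the cosine schedule are illustrative choices, not UniKE defaults, while the accelerator hand-off and the `self.optimizer`/`self.scheduler` attributes mirror the base implementation shown above.

# Hypothetical subclass: swap in AdamW + cosine decay for the stock choices.
import torch
from unike.config import Trainer

class AdamWTrainer(Trainer):

    def configure_optimizers(self):
        # replace the Adam/Adagrad/SGD branches with AdamW and weight decay
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.lr,
            weight_decay=1e-5,
        )
        # keep the accelerator hand-off from the base implementation
        if self.accelerator:
            self.optimizer = self.accelerator.prepare(self.optimizer)
        # cosine decay over the whole run instead of the default MultiStepLR;
        # run() still calls self.scheduler.step() once per epoch
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=self.epochs
        )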
unike/config/TrainerAccelerator.py
ADDED
@@ -0,0 +1,39 @@
# coding:utf-8
#
# unike/config/TrainerAccelerator.py
#
# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Apr 12, 2024
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Apr 27, 2024
#
# This script defines the parallel training loop function.

"""
Parallel training implemented with accelerate.
"""

from typing import Any, List
from accelerate import Accelerator

def accelerator_prepare(*args: List[Any]) -> List[Any]:

    """
    Distributed parallelism relies on `accelerate <https://github.com/huggingface/accelerate>`_ ,
    so the objects have to be prepared for distributed training with an Accelerator.

    Example::

        dataloader, model, accelerator = accelerator_prepare(
            dataloader,
            model
        )

    :param args: :py:class:`unike.data.KGEDataLoader` and :py:class:`unike.module.strategy.Strategy`.
    :type args: typing.List[typing.Any]
    :returns: the list of wrapped objects plus the Accelerator() object.
    :rtype: typing.List[typing.Any]
    """

    accelerator = Accelerator()
    result = accelerator.prepare(*args)
    result = list(result)
    result.append(accelerator)
    return result
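Wiring `accelerator_prepare` into `Trainer` follows from the signatures above: `Trainer` takes the returned `Accelerator` as its `accelerator` argument and then routes backward passes and early-stopping triggers through it. A hedged sketch, assuming `dataloader` and `model` are built as in the `Trainer` docstring example and the script is started with accelerate's CLI (`accelerate launch train.py`):

# Sketch only: `dataloader` and `model` are assumed to be a
# unike.data.KGEDataLoader and a unike.module.strategy.Strategy,
# constructed as in the Trainer docstring example.
from unike.config import Trainer, accelerator_prepare

dataloader, model, accelerator = accelerator_prepare(
    dataloader,
    model
)

# pass the Accelerator so Trainer uses accelerator.backward() and
# set_trigger()/check_trigger() across processes
trainer = Trainer(model = model, data_loader = dataloader,
    accelerator = accelerator,
    epochs = 1000, lr = 0.01,
    log_interval = 100, save_path = './model.pth')
trainer.run()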
unike/config/__init__.py
ADDED
@@ -0,0 +1,37 @@
# coding:utf-8
#
# unike/config/__init__.py
#
# git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 3, 2023
#
# This header file defines the config interface.

"""The loop part, containing the training loop and the validation loop."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from .Trainer import Trainer, get_trainer_hpo_config
from .Tester import Tester, get_tester_hpo_config, link_predict, head_predict, tail_predict, calc_ranks

from .HPOTrainer import set_hpo_config, set_hpo_hits, start_hpo_train, hpo_train

from .TrainerAccelerator import accelerator_prepare

__all__ = [
    'Trainer',
    'get_trainer_hpo_config',
    'Tester',
    'get_tester_hpo_config',
    'link_predict',
    'head_predict',
    'tail_predict',
    'calc_ranks',
    'set_hpo_config',
    'set_hpo_hits',
    'start_hpo_train',
    'hpo_train',
    'accelerator_prepare'
]
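Since `__init__.py` re-exports `get_trainer_hpo_config`, the default sweep space can be fetched and tightened from the package root before it is handed to the HPO helpers. A small sketch under the assumption that the returned dict keeps the wandb-sweep-style keys shown in the function's docstring; `log_uniform_values` is a standard wandb sweep distribution, not a UniKE addition.

# Sketch: narrow the default Trainer search space; only keys that appear in
# get_trainer_hpo_config() above are touched.
from unike.config import get_trainer_hpo_config

params = get_trainer_hpo_config()
params['epochs']['value'] = 2000                    # shorter trials
params['lr'] = {                                    # sample lr on a log scale
    'distribution': 'log_uniform_values',
    'min': 1e-5,
    'max': 1e-1,
}
params['opt_method'] = {'values': ['adam', 'sgd']}  # drop adagrad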