unike 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. unike/__init__.py +5 -0
  2. unike/config/HPOTrainer.py +305 -0
  3. unike/config/Tester.py +385 -0
  4. unike/config/Trainer.py +519 -0
  5. unike/config/TrainerAccelerator.py +39 -0
  6. unike/config/__init__.py +37 -0
  7. unike/data/BernSampler.py +168 -0
  8. unike/data/CompGCNSampler.py +140 -0
  9. unike/data/CompGCNTestSampler.py +84 -0
  10. unike/data/KGEDataLoader.py +315 -0
  11. unike/data/KGReader.py +138 -0
  12. unike/data/RGCNSampler.py +261 -0
  13. unike/data/RGCNTestSampler.py +208 -0
  14. unike/data/RevSampler.py +78 -0
  15. unike/data/TestSampler.py +189 -0
  16. unike/data/TradSampler.py +122 -0
  17. unike/data/TradTestSampler.py +87 -0
  18. unike/data/UniSampler.py +145 -0
  19. unike/data/__init__.py +47 -0
  20. unike/module/BaseModule.py +130 -0
  21. unike/module/__init__.py +20 -0
  22. unike/module/loss/CompGCNLoss.py +96 -0
  23. unike/module/loss/Loss.py +26 -0
  24. unike/module/loss/MarginLoss.py +148 -0
  25. unike/module/loss/RGCNLoss.py +117 -0
  26. unike/module/loss/SigmoidLoss.py +145 -0
  27. unike/module/loss/SoftplusLoss.py +145 -0
  28. unike/module/loss/__init__.py +35 -0
  29. unike/module/model/Analogy.py +237 -0
  30. unike/module/model/CompGCN.py +562 -0
  31. unike/module/model/ComplEx.py +235 -0
  32. unike/module/model/DistMult.py +276 -0
  33. unike/module/model/HolE.py +308 -0
  34. unike/module/model/Model.py +107 -0
  35. unike/module/model/RESCAL.py +309 -0
  36. unike/module/model/RGCN.py +304 -0
  37. unike/module/model/RotatE.py +303 -0
  38. unike/module/model/SimplE.py +237 -0
  39. unike/module/model/TransD.py +458 -0
  40. unike/module/model/TransE.py +290 -0
  41. unike/module/model/TransH.py +322 -0
  42. unike/module/model/TransR.py +402 -0
  43. unike/module/model/__init__.py +60 -0
  44. unike/module/strategy/CompGCNSampling.py +140 -0
  45. unike/module/strategy/NegativeSampling.py +138 -0
  46. unike/module/strategy/RGCNSampling.py +134 -0
  47. unike/module/strategy/Strategy.py +26 -0
  48. unike/module/strategy/__init__.py +29 -0
  49. unike/utils/EarlyStopping.py +94 -0
  50. unike/utils/Timer.py +74 -0
  51. unike/utils/WandbLogger.py +46 -0
  52. unike/utils/__init__.py +26 -0
  53. unike/utils/tools.py +118 -0
  54. unike/version.py +1 -0
  55. unike-3.0.1.dist-info/METADATA +101 -0
  56. unike-3.0.1.dist-info/RECORD +59 -0
  57. unike-3.0.1.dist-info/WHEEL +4 -0
  58. unike-3.0.1.dist-info/entry_points.txt +2 -0
  59. unike-3.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,519 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/config/Trainer.py
4
+ #
5
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 9, 2024
7
+ #
8
+ # This script defines the base class for the training loop.
9
+
10
+ """
11
+ Trainer - 训练循环类。
12
+ """
13
+
14
+ import os
15
+ import dgl
16
+ import wandb
17
+ import typing
18
+ import torch
19
+ from .Tester import Tester
20
+ import torch.optim as optim
21
+ from ..utils.Timer import Timer
22
+ from ..module.model import Model
23
+ from torch.utils.data import DataLoader
24
+ from ..utils.EarlyStopping import EarlyStopping
25
+ from ..module.strategy import Strategy
26
+ from accelerate import Accelerator
27
+ import logging
28
+
29
# NOTE: configuring the root logger at import time is a module-level side
# effect; it is kept unchanged because importing code may rely on it.
logging.basicConfig(
    level=logging.DEBUG,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(levelname)s:%(module)s:%(asctime)s:%(message)s',
)

# Logger used by everything defined in this module.
logger = logging.getLogger(__name__)
32
+
33
class Trainer(object):

    """
    Training loop for knowledge graph embedding (KGE) models.

    Example::

        from unike.data import KGEDataLoader, BernSampler, TradTestSampler
        from unike.module.model import TransE
        from unike.module.loss import MarginLoss
        from unike.module.strategy import NegativeSampling
        from unike.config import Trainer, Tester

        # dataloader for training
        dataloader = KGEDataLoader(
            in_path = "../../benchmarks/FB15K/",
            batch_size = 8192,
            neg_ent = 25,
            test = True,
            test_batch_size = 256,
            num_workers = 16,
            train_sampler = BernSampler,
            test_sampler = TradTestSampler
        )

        # define the model
        transe = TransE(
            ent_tol = dataloader.get_ent_tol(),
            rel_tol = dataloader.get_rel_tol(),
            dim = 50,
            p_norm = 1,
            norm_flag = True)

        # define the loss function
        model = NegativeSampling(
            model = transe,
            loss = MarginLoss(margin = 1.0),
            regul_rate = 0.01
        )

        # test the model
        tester = Tester(model = transe, data_loader = dataloader, use_gpu = True, device = 'cuda:1')

        # train the model
        trainer = Trainer(model = model, data_loader = dataloader.train_dataloader(),
            epochs = 1000, lr = 0.01, use_gpu = True, device = 'cuda:1',
            tester = tester, test = True, valid_interval = 100,
            log_interval = 100, save_interval = 100,
            save_path = '../../checkpoint/transe.pth', delta = 0.01)
        trainer.run()
    """

    def __init__(
        self,
        model: Strategy = None,
        data_loader: DataLoader = None,
        epochs: int = 1000,
        lr: float = 0.5,
        opt_method: str = "Adam",
        accelerator: Accelerator = None,
        use_gpu: bool = True,
        device: str = "cuda:0",
        tester: Tester | None = None,
        test: bool = False,
        valid_interval: int | None = None,
        log_interval: int | None = None,
        save_interval: int | None = None,
        save_path: str = None,
        use_early_stopping: bool = True,
        metric: str = 'hits@10',
        patience: int = 2,
        delta: float = 0,
        use_wandb: bool = False):

        """Create a Trainer.

        :param model: training strategy that wraps the KGE model
        :type model: :py:class:`unike.module.strategy.Strategy`
        :param data_loader: loader that yields the training batches
        :type data_loader: torch.utils.data.DataLoader
        :param epochs: number of training epochs
        :type epochs: int
        :param lr: learning rate
        :type lr: float
        :param opt_method: optimizer name: **'Adam'** or **'adam'**, **'Adagrad'** or **'adagrad'**, **'SGD'** or **'sgd'**
        :type opt_method: str
        :param accelerator: last element of the list returned by :py:meth:`unike.config.accelerator_prepare`
        :type accelerator: object
        :param use_gpu: whether to use a GPU
        :type use_gpu: bool
        :param device: which GPU to use
        :type device: str
        :param tester: evaluator used for validation and testing
        :type tester: :py:class:`unike.config.Tester`
        :param test: whether to evaluate on the test set after training; requires :py:attr:`tester`
        :type test: bool
        :param valid_interval: evaluate on the validation set every this many epochs; requires :py:attr:`tester`
        :type valid_interval: int
        :param log_interval: write a log line every this many epochs
        :type log_interval: int
        :param save_interval: save a checkpoint every this many epochs
        :type save_interval: int
        :param save_path: path the model is saved to
        :type save_path: str
        :param use_early_stopping: whether to enable early stopping; requires :py:attr:`tester` and :py:attr:`save_path`
        :type use_early_stopping: bool
        :param metric: validation metric used for early stopping, one of **'mr'**, **'mrr'**, **'hits@N'**, **'mr_type'**, **'mrr_type'**, **'hits@N_type'**. Default: **'hits@10'**
        :type metric: str
        :param patience: forwarded to :py:attr:`unike.utils.EarlyStopping.patience`. Default: 2
        :type patience: int
        :param delta: forwarded to :py:attr:`unike.utils.EarlyStopping.delta`. Default: 0
        :type delta: float
        :param use_wandb: whether to log through wandb
        :type use_wandb: bool
        """

        # ---- model and data ------------------------------------------------
        #: Training strategy wrapping the KGE model.
        self.model: Strategy = model
        #: The :py:class:`torch.utils.data.DataLoader` given to :py:meth:`__init__`.
        self.data_loader: torch.utils.data.DataLoader = data_loader
        #: Distributed-training handle; falsy means single-process training.
        self.accelerator = accelerator

        # ---- optimisation --------------------------------------------------
        #: Number of training epochs.
        self.epochs: int = epochs
        #: Learning rate.
        self.lr: float = lr
        #: Optimizer name as supplied by the caller.
        self.opt_method: str = opt_method
        #: Optimizer instance; stays ``None`` until :py:meth:`configure_optimizers` runs.
        self.optimizer: torch.optim.SGD | torch.optim.Adagrad | torch.optim.Adam | None = None
        #: Learning-rate scheduler; stays ``None`` until :py:meth:`configure_optimizers` runs.
        self.scheduler: torch.optim.lr_scheduler.MultiStepLR | None = None

        # ---- device --------------------------------------------------------
        #: Whether to use a GPU.
        self.use_gpu: bool = use_gpu
        #: :py:class:`torch.device` built from ``device``, or the string ``"cpu"`` when no GPU is used.
        self.device: typing.Union[torch.device, str] = torch.device(device) if use_gpu else "cpu"

        # ---- evaluation ----------------------------------------------------
        #: Evaluator used for validation and testing.
        self.tester: Tester | None = tester
        #: Whether to evaluate on the test set after training.
        self.test: bool = test
        #: Validate every this many epochs.
        self.valid_interval: int | None = valid_interval

        # ---- logging and checkpoints ---------------------------------------
        #: Log every this many epochs.
        self.log_interval: int | None = log_interval
        #: Save a checkpoint every this many epochs.
        self.save_interval: int | None = save_interval
        #: Path the model is saved to.
        self.save_path: str | None = save_path
        #: Whether to log through wandb.
        self.use_wandb: bool = use_wandb

        # ---- early stopping ------------------------------------------------
        #: Whether early stopping is requested.
        self.use_early_stopping: bool = use_early_stopping
        #: Validation metric watched by early stopping.
        self.metric: str = metric
        #: ``patience`` value handed to the early-stopping helper.
        self.patience: int = patience
        #: ``delta`` value handed to the early-stopping helper.
        self.delta: float = delta
        #: Early-stopping helper; created in :py:meth:`run` when its prerequisites are met.
        self.early_stopping: EarlyStopping = None
201
+
202
def configure_optimizers(self):

    """Create :py:attr:`optimizer` and :py:attr:`scheduler`.

    Override this method to customise the optimizer setup.

    The optimizer is selected by :py:attr:`opt_method`, matched
    case-insensitively against ``adam``, ``adagrad`` and ``sgd`` (SGD uses
    a momentum of 0.9). When :py:attr:`accelerator` is set, the optimizer is
    passed through ``accelerator.prepare``. The scheduler multiplies the
    learning rate by 0.1 at one third and at two thirds of
    :py:attr:`epochs`.

    :raises ValueError: if :py:attr:`opt_method` is not a supported optimizer name
    """

    method = str(self.opt_method).lower()

    if method == "adam":
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.lr,
        )
    elif method == "adagrad":
        self.optimizer = optim.Adagrad(
            self.model.parameters(),
            lr=self.lr,
        )
    elif method == "sgd":
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=self.lr,
            momentum=0.9,
        )
    else:
        # Fail here with a clear message instead of handing ``None`` to the
        # scheduler below.
        raise ValueError(
            f"Unsupported opt_method {self.opt_method!r}; "
            "expected 'adam', 'adagrad' or 'sgd'."
        )

    if self.accelerator:
        self.optimizer = self.accelerator.prepare(self.optimizer)

    third = int(self.epochs / 3)
    # With fewer than 3 epochs both milestones are 0, which would decay the
    # learning rate before the first epoch; drop non-positive milestones.
    milestones = [mark for mark in (third, third * 2) if mark > 0]
    self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
        self.optimizer, milestones=milestones,
        gamma=0.1
    )
231
+
232
def train_one_step(
    self,
    data: dict[str, typing.Union[str, dgl.DGLGraph, torch.Tensor]]) -> float:

    """Run one optimisation step on a single batch from :py:attr:`data_loader`.

    Without an accelerator, every entry of ``data`` except the one stored
    under the ``'mode'`` key is passed through :py:meth:`to_var` first.

    :param data: one training batch
    :type data: dict[str, typing.Union[dgl.DGLGraph, torch.Tensor]]
    :returns: loss value of this step
    :rtype: float
    """

    distributed = bool(self.accelerator)

    self.optimizer.zero_grad()

    if not distributed:
        moved = {}
        for key, value in data.items():
            # 'mode' is left as-is; everything else goes through to_var.
            moved[key] = value if key == 'mode' else self.to_var(value)
        data = moved

    loss = self.model(data)

    # With an accelerator, the backward pass is delegated to it.
    if distributed:
        self.accelerator.backward(loss)
    else:
        loss.backward()

    self.optimizer.step()
    return loss.item()
255
+
256
def run(self):

    """
    Run the training loop.

    Steps, in order: move :py:attr:`model` to the GPU when :py:attr:`use_gpu`
    is set and no accelerator is used; create the early-stopping helper when
    early stopping, a tester and a save path are all configured; build the
    optimizer via :py:meth:`configure_optimizers`; then iterate
    :py:attr:`data_loader` for :py:attr:`epochs` epochs, training each batch
    with :py:meth:`train_one_step`. Validation, checkpointing, logging and
    the final save/test are driven by the corresponding interval and path
    attributes.
    """

    if self.use_gpu and not self.accelerator:
        self.model.cuda(device = self.device)

    if self.use_early_stopping and self.tester and self.save_path:
        # The early-stopping helper is given the directory of save_path.
        self.early_stopping = EarlyStopping(
            save_path = os.path.dirname(self.save_path),
            patience = self.patience,
            delta = self.delta)

    self.configure_optimizers()

    logger.info(f"[{self.get_device()}] Initialization completed, start model training.")

    if self.use_wandb:
        wandb.watch(self.get_model(), log_freq=100)

    timer = Timer()

    for epoch in range(self.epochs):

        finished = epoch + 1

        self.get_model().train()

        # Summed (not averaged) loss over all batches of this epoch.
        res = sum((self.train_one_step(batch) for batch in self.data_loader), 0.0)
        timer.stop()
        self.scheduler.step()

        if self.is_local_main_process():

            if self.valid_interval and self.tester and finished % self.valid_interval == 0:
                logger.info(f"[{self.get_device()}] Epoch {epoch+1} | The model starts evaluation on the validation set.")
                self.print_test("link_valid", epoch)

            if self.early_stopping and self.early_stopping.early_stop:
                logger.info(f"[{self.get_device()}] Send an early stopping signal")
                if not self.accelerator:
                    break
                # Distributed run: raise the trigger that is checked below.
                self.accelerator.set_trigger()

            if self.save_interval and self.save_path and finished % self.save_interval == 0:
                stem, extension = os.path.splitext(self.save_path)
                path = stem + "-" + str(finished) + extension
                self.get_model().save_checkpoint(path)
                logger.info(f"[{self.get_device()}] Epoch {epoch+1} | Training checkpoint saved at {path}")

        if self.accelerator and self.accelerator.check_trigger():
            logger.info(f"[{self.get_device()}] Early stopping")
            break

        if self.log_interval and finished % self.log_interval == 0:
            if self.is_local_main_process() and self.use_wandb:
                wandb.log({"train/train_loss" : res, "train/epoch" : epoch + 1})
            logger.info(f"[{self.get_device()}] Epoch [{epoch+1:>4d}/{self.epochs:>4d}] | Batchsize: {self.data_loader.batch_size} | loss: {res:>9f} | {timer.avg():.5f} seconds/epoch")

    logger.info(f"[{self.get_device()}] The model training is completed, taking a total of {timer.sum():.5f} seconds.")

    if not self.is_local_main_process():
        return

    if self.use_wandb:
        wandb.log({"duration" : timer.sum()})

    if self.save_path:
        self.get_model().save_checkpoint(self.save_path)
        logger.info(f"[{self.get_device()}] Model saved at {self.save_path}.")

    if self.test and self.tester:
        logger.info(f"[{self.get_device()}] The model starts evaluating in the test set.")
        self.print_test("link_test")
342
+
343
def print_test(
    self,
    sampling_mode: str,
    epoch: int = 0):

    """Run link prediction with :py:attr:`tester` and report the results.

    Every metric is written to the module logger. With :py:attr:`use_wandb`
    the metrics are also sent to wandb under the ``test/`` or ``val/``
    prefix; any other ``sampling_mode`` is used as the prefix itself. For
    ``"link_valid"`` the early-stopping helper (if one exists) is updated
    with :py:attr:`metric`; ``'mr'`` and ``'mr_type'`` are negated before
    being handed over.

    :param sampling_mode: evaluation split, ``"link_valid"`` or ``"link_test"``
    :type sampling_mode: str
    :param epoch: value logged to wandb as ``val/epoch`` for validation runs
    :type epoch: int
    :raises ValueError: if :py:attr:`metric` is not among the returned metrics
    """

    self.tester.set_sampling_mode(sampling_mode)

    # Falling back to the raw mode name keeps wandb logging working for
    # modes other than the two known ones (previously an unbound local).
    prefix = {"link_test": "test", "link_valid": "val"}.get(sampling_mode, sampling_mode)

    results = self.tester.run_link_prediction()
    for key, value in results.items():
        logger.info(f"{key}: {value}")

    if self.use_wandb:
        log_dict = {f"{prefix}/{key}" : value for key, value in results.items()}
        if sampling_mode == "link_valid":
            log_dict["val/epoch"] = epoch
        wandb.log(log_dict)

    if self.early_stopping is not None and sampling_mode == "link_valid":
        # Check membership first so that a missing 'mr'/'mr_type' raises the
        # same ValueError as any other missing metric.
        if self.metric not in results:
            raise ValueError("Early stopping metric is not valid.")
        score = results[self.metric]
        if self.metric in ('mr', 'mr_type'):
            score = -score
        self.early_stopping(score, self.get_model())
379
+
380
def to_var(
    self,
    x: torch.Tensor) -> torch.Tensor:

    """Move ``x`` to :py:attr:`device` when :py:attr:`use_gpu` is set.

    :param x: data to move
    :type x: torch.Tensor
    :returns: ``x.to(self.device)`` with GPU enabled, otherwise ``x`` itself
    :rtype: torch.Tensor
    """

    if not self.use_gpu:
        return x
    return x.to(self.device)
396
+
397
def get_model(self) -> Model:

    """Return the underlying KGE model.

    With an accelerator the model is read from ``self.model.module.model``,
    otherwise from ``self.model.model``.

    :returns: the KGE model
    :rtype: :py:class:`unike.module.model.Model`
    """

    strategy = self.model.module if self.accelerator else self.model
    return strategy.model
409
+
410
def get_device(self) -> typing.Union[torch.device, str]:

    """Return the device of the current process.

    With an accelerator this is ``self.model.device``; otherwise it is
    :py:attr:`device`.

    :returns: device information
    :rtype: typing.Union[torch.device, str]
    """

    if not self.accelerator:
        return self.device
    return self.model.device
422
+
423
def is_local_main_process(self) -> bool:

    """Tell whether the current process is the local main process.

    :returns: ``True`` without an accelerator, otherwise the accelerator's
        ``is_local_main_process`` value.
    :rtype: bool
    """

    if not self.accelerator:
        return True
    return self.accelerator.is_local_main_process
432
+
433
def get_trainer_hpo_config() -> dict[str, dict[str, typing.Any]]:

    """Return the default hyper-parameter search configuration for :py:class:`Trainer`.

    The default configuration is::

        parameters_dict = {
            'trainer': {
                'value': 'Trainer'
            },
            'epochs': {
                'value': 10000
            },
            'lr': {
                'distribution': 'uniform',
                'min': 1e-5,
                'max': 1.0
            },
            'opt_method': {
                'values': ['adam', 'adagrad', 'sgd']
            },
            'valid_interval': {
                'value': 100
            },
            'log_interval': {
                'value': 100
            },
            'save_path': {
                'value': './model.pth'
            },
            'use_early_stopping': {
                'value': True
            },
            'metric': {
                'value': 'hits@10'
            },
            'patience': {
                'value': 2
            },
            'delta': {
                'value': 0.0001
            },
        }

    :returns: a freshly built configuration dictionary
    :rtype: dict[str, dict[str, typing.Any]]
    """

    def fixed(setting):
        # Entry for a parameter pinned to a single value.
        return {'value': setting}

    learning_rate = {'distribution': 'uniform', 'min': 1e-5, 'max': 1.0}
    optimizers = {'values': ['adam', 'adagrad', 'sgd']}

    return {
        'trainer': fixed('Trainer'),
        'epochs': fixed(10000),
        'lr': learning_rate,
        'opt_method': optimizers,
        'valid_interval': fixed(100),
        'log_interval': fixed(100),
        'save_path': fixed('./model.pth'),
        'use_early_stopping': fixed(True),
        'metric': fixed('hits@10'),
        'patience': fixed(2),
        'delta': fixed(0.0001),
    }
@@ -0,0 +1,39 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/config/TrainerAccelerator.py
4
+ #
5
+ # created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Apr 12, 2024
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Apr 27, 2024
7
+ #
8
+ # 该脚本定义了并行训练循环函数.
9
+
10
+ """
11
+ 利用 accelerate 实现并行训练。
12
+ """
13
+
14
+ from typing import Any, List
15
+ from accelerate import Accelerator
16
+
17
def accelerator_prepare(*args: Any) -> List[Any]:

    """
    Prepare objects for distributed training with
    `accelerate <https://github.com/huggingface/accelerate>`_ .

    Example::

        dataloader, model, accelerator = accelerator_prepare(
            dataloader,
            model
        )

    :param args: objects to prepare, e.g. a :py:class:`unike.data.KGEDataLoader` and a :py:class:`unike.module.strategy.Strategy`.
    :type args: typing.Any
    :returns: the prepared objects in the order given, followed by the ``Accelerator`` instance.
    :rtype: typing.List[typing.Any]
    """

    accelerator = Accelerator()

    # Nothing to prepare: hand back just the accelerator.
    if not args:
        return [accelerator]

    prepared = accelerator.prepare(*args)

    # ``Accelerator.prepare`` returns the bare object (not a 1-tuple) when it
    # is given a single argument, so it must not be unpacked with ``list``.
    result = [prepared] if len(args) == 1 else list(prepared)
    result.append(accelerator)
    return result
@@ -0,0 +1,37 @@
1
+ # coding:utf-8
2
+ #
3
+ # unike/config/__init__.py
4
+ #
5
+ # git pull from OpenKE-PyTorch by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on May 7, 2023
6
+ # updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on Jan 3, 2023
7
+ #
8
+ # 该头文件定义了 config 接口.
9
+
10
+ """循环部分,包含训练循环和验证循环。"""
11
+
12
+ from __future__ import absolute_import
13
+ from __future__ import division
14
+ from __future__ import print_function
15
+
16
+ from .Trainer import Trainer, get_trainer_hpo_config
17
+ from .Tester import Tester, get_tester_hpo_config, link_predict, head_predict, tail_predict, calc_ranks
18
+
19
+ from .HPOTrainer import set_hpo_config, set_hpo_hits,start_hpo_train, hpo_train
20
+
21
+ from .TrainerAccelerator import accelerator_prepare
22
+
23
# Public names exported by the ``unike.config`` package.
__all__ = [
    # training loop
    "Trainer",
    "get_trainer_hpo_config",
    # evaluation loop and link-prediction helpers
    "Tester",
    "get_tester_hpo_config",
    "link_predict",
    "head_predict",
    "tail_predict",
    "calc_ranks",
    # hyper-parameter optimisation
    "set_hpo_config",
    "set_hpo_hits",
    "start_hpo_train",
    "hpo_train",
    # distributed training
    "accelerator_prepare",
]
+ ]