@epfml/discojs 2.1.1 → 2.1.2-p20240506085037.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (334) hide show
  1. package/dist/aggregator/base.d.ts +180 -0
  2. package/dist/aggregator/base.js +236 -0
  3. package/dist/aggregator/get.d.ts +16 -0
  4. package/dist/aggregator/get.js +31 -0
  5. package/dist/aggregator/index.d.ts +7 -0
  6. package/dist/aggregator/index.js +4 -0
  7. package/dist/aggregator/mean.d.ts +23 -0
  8. package/dist/aggregator/mean.js +69 -0
  9. package/dist/aggregator/secure.d.ts +27 -0
  10. package/dist/aggregator/secure.js +91 -0
  11. package/dist/async_informant.d.ts +15 -0
  12. package/dist/async_informant.js +42 -0
  13. package/dist/client/base.d.ts +76 -0
  14. package/dist/client/base.js +88 -0
  15. package/dist/client/decentralized/base.d.ts +32 -0
  16. package/dist/client/decentralized/base.js +192 -0
  17. package/dist/client/decentralized/index.d.ts +2 -0
  18. package/dist/client/decentralized/index.js +2 -0
  19. package/dist/client/decentralized/messages.d.ts +28 -0
  20. package/dist/client/decentralized/messages.js +44 -0
  21. package/dist/client/decentralized/peer.d.ts +40 -0
  22. package/dist/client/decentralized/peer.js +189 -0
  23. package/dist/client/decentralized/peer_pool.d.ts +12 -0
  24. package/dist/client/decentralized/peer_pool.js +44 -0
  25. package/dist/client/event_connection.d.ts +34 -0
  26. package/dist/client/event_connection.js +105 -0
  27. package/dist/client/federated/base.d.ts +54 -0
  28. package/dist/client/federated/base.js +151 -0
  29. package/dist/client/federated/index.d.ts +2 -0
  30. package/dist/client/federated/index.js +2 -0
  31. package/dist/client/federated/messages.d.ts +30 -0
  32. package/dist/client/federated/messages.js +24 -0
  33. package/dist/client/index.d.ts +8 -0
  34. package/dist/client/index.js +8 -0
  35. package/dist/client/local.d.ts +3 -0
  36. package/dist/client/local.js +3 -0
  37. package/dist/client/messages.d.ts +30 -0
  38. package/dist/client/messages.js +26 -0
  39. package/dist/client/types.d.ts +2 -0
  40. package/dist/client/types.js +4 -0
  41. package/dist/client/utils.d.ts +2 -0
  42. package/dist/client/utils.js +7 -0
  43. package/dist/dataset/data/data.d.ts +48 -0
  44. package/dist/dataset/data/data.js +72 -0
  45. package/dist/dataset/data/data_split.d.ts +8 -0
  46. package/dist/dataset/data/data_split.js +1 -0
  47. package/dist/dataset/data/image_data.d.ts +11 -0
  48. package/dist/dataset/data/image_data.js +38 -0
  49. package/dist/dataset/data/index.d.ts +6 -0
  50. package/dist/dataset/data/index.js +5 -0
  51. package/dist/dataset/data/preprocessing/base.d.ts +16 -0
  52. package/dist/dataset/data/preprocessing/base.js +1 -0
  53. package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +13 -0
  54. package/dist/dataset/data/preprocessing/image_preprocessing.js +40 -0
  55. package/dist/dataset/data/preprocessing/index.d.ts +4 -0
  56. package/dist/dataset/data/preprocessing/index.js +3 -0
  57. package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +13 -0
  58. package/dist/dataset/data/preprocessing/tabular_preprocessing.js +45 -0
  59. package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +13 -0
  60. package/dist/dataset/data/preprocessing/text_preprocessing.js +85 -0
  61. package/dist/dataset/data/tabular_data.d.ts +11 -0
  62. package/dist/dataset/data/tabular_data.js +25 -0
  63. package/dist/dataset/data/text_data.d.ts +11 -0
  64. package/dist/dataset/data/text_data.js +14 -0
  65. package/dist/{core/dataset → dataset}/data_loader/data_loader.d.ts +3 -5
  66. package/dist/dataset/data_loader/data_loader.js +2 -0
  67. package/dist/dataset/data_loader/image_loader.d.ts +20 -3
  68. package/dist/dataset/data_loader/image_loader.js +98 -23
  69. package/dist/dataset/data_loader/index.d.ts +5 -2
  70. package/dist/dataset/data_loader/index.js +4 -7
  71. package/dist/dataset/data_loader/tabular_loader.d.ts +34 -3
  72. package/dist/dataset/data_loader/tabular_loader.js +75 -15
  73. package/dist/dataset/data_loader/text_loader.d.ts +14 -0
  74. package/dist/dataset/data_loader/text_loader.js +25 -0
  75. package/dist/dataset/dataset.d.ts +5 -0
  76. package/dist/dataset/dataset.js +1 -0
  77. package/dist/dataset/dataset_builder.d.ts +60 -0
  78. package/dist/dataset/dataset_builder.js +142 -0
  79. package/dist/dataset/index.d.ts +5 -0
  80. package/dist/dataset/index.js +3 -0
  81. package/dist/default_tasks/cifar10/index.d.ts +2 -0
  82. package/dist/{core/default_tasks/cifar10.js → default_tasks/cifar10/index.js} +28 -36
  83. package/dist/default_tasks/cifar10/model.d.ts +434 -0
  84. package/dist/default_tasks/cifar10/model.js +2385 -0
  85. package/dist/default_tasks/geotags/index.d.ts +2 -0
  86. package/dist/default_tasks/geotags/index.js +65 -0
  87. package/dist/default_tasks/geotags/model.d.ts +593 -0
  88. package/dist/default_tasks/geotags/model.js +4715 -0
  89. package/dist/default_tasks/index.d.ts +8 -0
  90. package/dist/default_tasks/index.js +8 -0
  91. package/dist/default_tasks/lus_covid.d.ts +2 -0
  92. package/dist/default_tasks/lus_covid.js +89 -0
  93. package/dist/default_tasks/mnist.d.ts +2 -0
  94. package/dist/{core/default_tasks → default_tasks}/mnist.js +26 -34
  95. package/dist/default_tasks/simple_face/index.d.ts +2 -0
  96. package/dist/{core/default_tasks/simple_face.js → default_tasks/simple_face/index.js} +17 -22
  97. package/dist/default_tasks/simple_face/model.d.ts +513 -0
  98. package/dist/default_tasks/simple_face/model.js +4301 -0
  99. package/dist/default_tasks/skin_mnist.d.ts +2 -0
  100. package/dist/default_tasks/skin_mnist.js +80 -0
  101. package/dist/default_tasks/titanic.d.ts +2 -0
  102. package/dist/{core/default_tasks → default_tasks}/titanic.js +24 -33
  103. package/dist/default_tasks/wikitext.d.ts +2 -0
  104. package/dist/default_tasks/wikitext.js +38 -0
  105. package/dist/index.d.ts +18 -2
  106. package/dist/index.js +18 -6
  107. package/dist/{core/informant → informant}/graph_informant.d.ts +1 -1
  108. package/dist/informant/graph_informant.js +20 -0
  109. package/dist/informant/index.d.ts +1 -0
  110. package/dist/informant/index.js +1 -0
  111. package/dist/{core/logging → logging}/console_logger.d.ts +2 -2
  112. package/dist/logging/console_logger.js +22 -0
  113. package/dist/logging/index.d.ts +2 -0
  114. package/dist/logging/index.js +1 -0
  115. package/dist/{core/logging → logging}/logger.d.ts +3 -3
  116. package/dist/logging/logger.js +1 -0
  117. package/dist/memory/base.d.ts +119 -0
  118. package/dist/memory/base.js +9 -0
  119. package/dist/memory/empty.d.ts +20 -0
  120. package/dist/memory/empty.js +43 -0
  121. package/dist/memory/index.d.ts +3 -1
  122. package/dist/memory/index.js +3 -5
  123. package/dist/memory/model_type.d.ts +9 -0
  124. package/dist/memory/model_type.js +10 -0
  125. package/dist/{core/privacy.d.ts → privacy.d.ts} +1 -1
  126. package/dist/{core/privacy.js → privacy.js} +11 -16
  127. package/dist/serialization/index.d.ts +2 -0
  128. package/dist/serialization/index.js +2 -0
  129. package/dist/serialization/model.d.ts +5 -0
  130. package/dist/serialization/model.js +67 -0
  131. package/dist/{core/serialization → serialization}/weights.d.ts +2 -2
  132. package/dist/serialization/weights.js +37 -0
  133. package/dist/task/data_example.js +14 -0
  134. package/dist/task/digest.js +14 -0
  135. package/dist/{core/task → task}/display_information.d.ts +5 -3
  136. package/dist/task/display_information.js +46 -0
  137. package/dist/task/index.d.ts +7 -0
  138. package/dist/task/index.js +5 -0
  139. package/dist/task/label_type.d.ts +9 -0
  140. package/dist/task/label_type.js +28 -0
  141. package/dist/task/summary.js +13 -0
  142. package/dist/{core/task → task}/task.d.ts +7 -7
  143. package/dist/task/task.js +22 -0
  144. package/dist/task/task_handler.d.ts +5 -0
  145. package/dist/task/task_handler.js +20 -0
  146. package/dist/task/task_provider.d.ts +5 -0
  147. package/dist/task/task_provider.js +1 -0
  148. package/dist/{core/task → task}/training_information.d.ts +9 -10
  149. package/dist/task/training_information.js +88 -0
  150. package/dist/training/disco.d.ts +40 -0
  151. package/dist/training/disco.js +107 -0
  152. package/dist/training/index.d.ts +2 -0
  153. package/dist/training/index.js +1 -0
  154. package/dist/training/trainer/distributed_trainer.d.ts +20 -0
  155. package/dist/training/trainer/distributed_trainer.js +36 -0
  156. package/dist/training/trainer/local_trainer.d.ts +12 -0
  157. package/dist/training/trainer/local_trainer.js +19 -0
  158. package/dist/training/trainer/trainer.d.ts +33 -0
  159. package/dist/training/trainer/trainer.js +52 -0
  160. package/dist/{core/training → training}/trainer/trainer_builder.d.ts +5 -7
  161. package/dist/training/trainer/trainer_builder.js +43 -0
  162. package/dist/types.d.ts +8 -0
  163. package/dist/types.js +1 -0
  164. package/dist/utils/event_emitter.d.ts +40 -0
  165. package/dist/utils/event_emitter.js +57 -0
  166. package/dist/validation/index.d.ts +1 -0
  167. package/dist/validation/index.js +1 -0
  168. package/dist/validation/validator.d.ts +28 -0
  169. package/dist/validation/validator.js +132 -0
  170. package/dist/weights/aggregation.d.ts +21 -0
  171. package/dist/weights/aggregation.js +44 -0
  172. package/dist/weights/index.d.ts +2 -0
  173. package/dist/weights/index.js +2 -0
  174. package/dist/weights/weights_container.d.ts +68 -0
  175. package/dist/weights/weights_container.js +96 -0
  176. package/package.json +24 -15
  177. package/README.md +0 -53
  178. package/dist/core/async_buffer.d.ts +0 -41
  179. package/dist/core/async_buffer.js +0 -97
  180. package/dist/core/async_informant.d.ts +0 -20
  181. package/dist/core/async_informant.js +0 -69
  182. package/dist/core/client/base.d.ts +0 -33
  183. package/dist/core/client/base.js +0 -35
  184. package/dist/core/client/decentralized/base.d.ts +0 -32
  185. package/dist/core/client/decentralized/base.js +0 -212
  186. package/dist/core/client/decentralized/clear_text.d.ts +0 -14
  187. package/dist/core/client/decentralized/clear_text.js +0 -96
  188. package/dist/core/client/decentralized/index.d.ts +0 -4
  189. package/dist/core/client/decentralized/index.js +0 -9
  190. package/dist/core/client/decentralized/messages.d.ts +0 -41
  191. package/dist/core/client/decentralized/messages.js +0 -54
  192. package/dist/core/client/decentralized/peer.d.ts +0 -26
  193. package/dist/core/client/decentralized/peer.js +0 -210
  194. package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
  195. package/dist/core/client/decentralized/peer_pool.js +0 -92
  196. package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
  197. package/dist/core/client/decentralized/sec_agg.js +0 -190
  198. package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
  199. package/dist/core/client/decentralized/secret_shares.js +0 -39
  200. package/dist/core/client/decentralized/types.d.ts +0 -2
  201. package/dist/core/client/decentralized/types.js +0 -7
  202. package/dist/core/client/event_connection.d.ts +0 -37
  203. package/dist/core/client/event_connection.js +0 -158
  204. package/dist/core/client/federated/client.d.ts +0 -37
  205. package/dist/core/client/federated/client.js +0 -273
  206. package/dist/core/client/federated/index.d.ts +0 -2
  207. package/dist/core/client/federated/index.js +0 -7
  208. package/dist/core/client/federated/messages.d.ts +0 -38
  209. package/dist/core/client/federated/messages.js +0 -25
  210. package/dist/core/client/index.d.ts +0 -5
  211. package/dist/core/client/index.js +0 -11
  212. package/dist/core/client/local.d.ts +0 -8
  213. package/dist/core/client/local.js +0 -36
  214. package/dist/core/client/messages.d.ts +0 -28
  215. package/dist/core/client/messages.js +0 -33
  216. package/dist/core/client/utils.d.ts +0 -2
  217. package/dist/core/client/utils.js +0 -19
  218. package/dist/core/dataset/data/data.d.ts +0 -11
  219. package/dist/core/dataset/data/data.js +0 -20
  220. package/dist/core/dataset/data/data_split.d.ts +0 -5
  221. package/dist/core/dataset/data/data_split.js +0 -2
  222. package/dist/core/dataset/data/image_data.d.ts +0 -8
  223. package/dist/core/dataset/data/image_data.js +0 -64
  224. package/dist/core/dataset/data/index.d.ts +0 -5
  225. package/dist/core/dataset/data/index.js +0 -11
  226. package/dist/core/dataset/data/preprocessing.d.ts +0 -13
  227. package/dist/core/dataset/data/preprocessing.js +0 -33
  228. package/dist/core/dataset/data/tabular_data.d.ts +0 -8
  229. package/dist/core/dataset/data/tabular_data.js +0 -40
  230. package/dist/core/dataset/data_loader/data_loader.js +0 -10
  231. package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
  232. package/dist/core/dataset/data_loader/image_loader.js +0 -141
  233. package/dist/core/dataset/data_loader/index.d.ts +0 -3
  234. package/dist/core/dataset/data_loader/index.js +0 -9
  235. package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
  236. package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
  237. package/dist/core/dataset/dataset.d.ts +0 -2
  238. package/dist/core/dataset/dataset.js +0 -2
  239. package/dist/core/dataset/dataset_builder.d.ts +0 -18
  240. package/dist/core/dataset/dataset_builder.js +0 -96
  241. package/dist/core/dataset/index.d.ts +0 -4
  242. package/dist/core/dataset/index.js +0 -14
  243. package/dist/core/default_tasks/cifar10.d.ts +0 -2
  244. package/dist/core/default_tasks/geotags.d.ts +0 -2
  245. package/dist/core/default_tasks/geotags.js +0 -69
  246. package/dist/core/default_tasks/index.d.ts +0 -6
  247. package/dist/core/default_tasks/index.js +0 -15
  248. package/dist/core/default_tasks/lus_covid.d.ts +0 -2
  249. package/dist/core/default_tasks/lus_covid.js +0 -96
  250. package/dist/core/default_tasks/mnist.d.ts +0 -2
  251. package/dist/core/default_tasks/simple_face.d.ts +0 -2
  252. package/dist/core/default_tasks/titanic.d.ts +0 -2
  253. package/dist/core/index.d.ts +0 -18
  254. package/dist/core/index.js +0 -39
  255. package/dist/core/informant/graph_informant.js +0 -23
  256. package/dist/core/informant/index.d.ts +0 -3
  257. package/dist/core/informant/index.js +0 -9
  258. package/dist/core/informant/training_informant/base.d.ts +0 -31
  259. package/dist/core/informant/training_informant/base.js +0 -83
  260. package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
  261. package/dist/core/informant/training_informant/decentralized.js +0 -22
  262. package/dist/core/informant/training_informant/federated.d.ts +0 -14
  263. package/dist/core/informant/training_informant/federated.js +0 -32
  264. package/dist/core/informant/training_informant/index.d.ts +0 -4
  265. package/dist/core/informant/training_informant/index.js +0 -11
  266. package/dist/core/informant/training_informant/local.d.ts +0 -6
  267. package/dist/core/informant/training_informant/local.js +0 -20
  268. package/dist/core/logging/console_logger.js +0 -33
  269. package/dist/core/logging/index.d.ts +0 -3
  270. package/dist/core/logging/index.js +0 -9
  271. package/dist/core/logging/logger.js +0 -9
  272. package/dist/core/logging/trainer_logger.d.ts +0 -24
  273. package/dist/core/logging/trainer_logger.js +0 -59
  274. package/dist/core/memory/base.d.ts +0 -22
  275. package/dist/core/memory/base.js +0 -9
  276. package/dist/core/memory/empty.d.ts +0 -14
  277. package/dist/core/memory/empty.js +0 -75
  278. package/dist/core/memory/index.d.ts +0 -3
  279. package/dist/core/memory/index.js +0 -9
  280. package/dist/core/memory/model_type.d.ts +0 -4
  281. package/dist/core/memory/model_type.js +0 -9
  282. package/dist/core/serialization/index.d.ts +0 -2
  283. package/dist/core/serialization/index.js +0 -6
  284. package/dist/core/serialization/model.d.ts +0 -5
  285. package/dist/core/serialization/model.js +0 -55
  286. package/dist/core/serialization/weights.js +0 -64
  287. package/dist/core/task/data_example.js +0 -24
  288. package/dist/core/task/digest.js +0 -18
  289. package/dist/core/task/display_information.js +0 -49
  290. package/dist/core/task/index.d.ts +0 -6
  291. package/dist/core/task/index.js +0 -15
  292. package/dist/core/task/model_compile_data.d.ts +0 -6
  293. package/dist/core/task/model_compile_data.js +0 -22
  294. package/dist/core/task/summary.js +0 -19
  295. package/dist/core/task/task.js +0 -35
  296. package/dist/core/task/task_handler.d.ts +0 -5
  297. package/dist/core/task/task_handler.js +0 -53
  298. package/dist/core/task/task_provider.d.ts +0 -6
  299. package/dist/core/task/task_provider.js +0 -13
  300. package/dist/core/task/training_information.js +0 -66
  301. package/dist/core/training/disco.d.ts +0 -23
  302. package/dist/core/training/disco.js +0 -130
  303. package/dist/core/training/index.d.ts +0 -2
  304. package/dist/core/training/index.js +0 -7
  305. package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
  306. package/dist/core/training/trainer/distributed_trainer.js +0 -65
  307. package/dist/core/training/trainer/local_trainer.d.ts +0 -11
  308. package/dist/core/training/trainer/local_trainer.js +0 -34
  309. package/dist/core/training/trainer/round_tracker.d.ts +0 -30
  310. package/dist/core/training/trainer/round_tracker.js +0 -47
  311. package/dist/core/training/trainer/trainer.d.ts +0 -65
  312. package/dist/core/training/trainer/trainer.js +0 -160
  313. package/dist/core/training/trainer/trainer_builder.js +0 -95
  314. package/dist/core/training/training_schemes.d.ts +0 -5
  315. package/dist/core/training/training_schemes.js +0 -10
  316. package/dist/core/types.d.ts +0 -4
  317. package/dist/core/types.js +0 -2
  318. package/dist/core/validation/index.d.ts +0 -1
  319. package/dist/core/validation/index.js +0 -5
  320. package/dist/core/validation/validator.d.ts +0 -17
  321. package/dist/core/validation/validator.js +0 -104
  322. package/dist/core/weights/aggregation.d.ts +0 -7
  323. package/dist/core/weights/aggregation.js +0 -72
  324. package/dist/core/weights/index.d.ts +0 -2
  325. package/dist/core/weights/index.js +0 -7
  326. package/dist/core/weights/weights_container.d.ts +0 -19
  327. package/dist/core/weights/weights_container.js +0 -64
  328. package/dist/imports.d.ts +0 -2
  329. package/dist/imports.js +0 -7
  330. package/dist/memory/memory.d.ts +0 -26
  331. package/dist/memory/memory.js +0 -160
  332. package/dist/{core/task → task}/data_example.d.ts +1 -1
  333. package/dist/{core/task → task}/digest.d.ts +0 -0
  334. package/dist/{core/task → task}/summary.d.ts +1 -1
@@ -1,130 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.Disco = void 0;
4
- var tslib_1 = require("tslib");
5
- var __1 = require("..");
6
- var trainer_builder_1 = require("./trainer/trainer_builder");
7
- // Handles the training loop, server communication & provides the user with feedback.
8
- var Disco = /** @class */ (function () {
9
- // client need to be connected
10
- function Disco(task, options) {
11
- if (options.scheme === undefined) {
12
- options.scheme = __1.TrainingSchemes[task.trainingInformation.scheme];
13
- }
14
- if (options.client === undefined) {
15
- if (options.url === undefined) {
16
- throw new Error('could not determine client from given parameters');
17
- }
18
- if (typeof options.url === 'string') {
19
- options.url = new URL(options.url);
20
- }
21
- switch (options.scheme) {
22
- case __1.TrainingSchemes.FEDERATED:
23
- options.client = new __1.client.federated.Client(options.url, task);
24
- break;
25
- case __1.TrainingSchemes.DECENTRALIZED:
26
- options.client = new __1.client.federated.Client(options.url, task);
27
- break;
28
- default:
29
- options.client = new __1.client.Local(options.url, task);
30
- break;
31
- }
32
- }
33
- if (options.informant === undefined) {
34
- switch (options.scheme) {
35
- case __1.TrainingSchemes.FEDERATED:
36
- options.informant = new __1.informant.FederatedInformant(task);
37
- break;
38
- case __1.TrainingSchemes.DECENTRALIZED:
39
- options.informant = new __1.informant.DecentralizedInformant(task);
40
- break;
41
- default:
42
- options.informant = new __1.informant.LocalInformant(task);
43
- break;
44
- }
45
- }
46
- if (options.logger === undefined) {
47
- options.logger = new __1.ConsoleLogger();
48
- }
49
- if (options.memory === undefined) {
50
- options.memory = new __1.EmptyMemory();
51
- }
52
- if (options.client.task !== task) {
53
- throw new Error('client not setup for given task');
54
- }
55
- if (options.informant.task.taskID !== task.taskID) {
56
- throw new Error('informant not setup for given task');
57
- }
58
- this.task = task;
59
- this.client = options.client;
60
- this.memory = options.memory;
61
- this.logger = options.logger;
62
- var trainerBuilder = new trainer_builder_1.TrainerBuilder(this.memory, this.task, options.informant);
63
- this.trainer = trainerBuilder.build(this.client, options.scheme !== __1.TrainingSchemes.LOCAL);
64
- }
65
- Disco.prototype.fit = function (dataTuple) {
66
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
67
- var trainDataset, valDataset;
68
- return (0, tslib_1.__generator)(this, function (_a) {
69
- switch (_a.label) {
70
- case 0:
71
- this.logger.success('Thank you for your contribution. Data preprocessing has started');
72
- trainDataset = dataTuple.train.batch().preprocess();
73
- valDataset = dataTuple.validation !== undefined
74
- ? dataTuple.validation.batch().preprocess()
75
- : trainDataset;
76
- return [4 /*yield*/, this.client.connect()];
77
- case 1:
78
- _a.sent();
79
- return [4 /*yield*/, this.trainer];
80
- case 2: return [4 /*yield*/, (_a.sent()).trainModel(trainDataset.dataset, valDataset.dataset)];
81
- case 3:
82
- _a.sent();
83
- return [2 /*return*/];
84
- }
85
- });
86
- });
87
- };
88
- // Stops the training function. Does not disconnect the client.
89
- Disco.prototype.pause = function () {
90
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
91
- return (0, tslib_1.__generator)(this, function (_a) {
92
- switch (_a.label) {
93
- case 0: return [4 /*yield*/, this.trainer];
94
- case 1: return [4 /*yield*/, (_a.sent()).stopTraining()];
95
- case 2:
96
- _a.sent();
97
- this.logger.success('Training was successfully interrupted.');
98
- return [2 /*return*/];
99
- }
100
- });
101
- });
102
- };
103
- Disco.prototype.close = function () {
104
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
105
- return (0, tslib_1.__generator)(this, function (_a) {
106
- switch (_a.label) {
107
- case 0: return [4 /*yield*/, this.pause()];
108
- case 1:
109
- _a.sent();
110
- return [4 /*yield*/, this.client.disconnect()];
111
- case 2:
112
- _a.sent();
113
- return [2 /*return*/];
114
- }
115
- });
116
- });
117
- };
118
- Disco.prototype.logs = function () {
119
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
120
- return (0, tslib_1.__generator)(this, function (_a) {
121
- switch (_a.label) {
122
- case 0: return [4 /*yield*/, this.trainer];
123
- case 1: return [2 /*return*/, (_a.sent()).getTrainerLog()];
124
- }
125
- });
126
- });
127
- };
128
- return Disco;
129
- }());
130
- exports.Disco = Disco;
@@ -1,2 +0,0 @@
1
- export { Disco } from './disco';
2
- export { TrainingSchemes } from './training_schemes';
@@ -1,7 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.TrainingSchemes = exports.Disco = void 0;
4
- var disco_1 = require("./disco");
5
- Object.defineProperty(exports, "Disco", { enumerable: true, get: function () { return disco_1.Disco; } });
6
- var training_schemes_1 = require("./training_schemes");
7
- Object.defineProperty(exports, "TrainingSchemes", { enumerable: true, get: function () { return training_schemes_1.TrainingSchemes; } });
@@ -1,20 +0,0 @@
1
- import { tf, Client, Memory, Task, TrainingInformant } from '../..';
2
- import { Trainer } from './trainer';
3
- /**
4
- * Class whose role is to train a model in a distributed way with a given dataset.
5
- */
6
- export declare class DistributedTrainer extends Trainer {
7
- private readonly previousRoundModel;
8
- private readonly client;
9
- /** DistributedTrainer constructor, accepts same arguments as Trainer and in additional also a client who takes care of communicating weights.
10
- */
11
- constructor(task: Task, trainingInformant: TrainingInformant, memory: Memory, model: tf.LayersModel, previousRoundModel: tf.LayersModel, client: Client);
12
- /**
13
- * Callback called every time a round is over
14
- */
15
- onRoundEnd(accuracy: number): Promise<void>;
16
- /**
17
- * Callback called once training is over
18
- */
19
- onTrainEnd(): Promise<void>;
20
- }
@@ -1,65 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.DistributedTrainer = void 0;
4
- var tslib_1 = require("tslib");
5
- var __1 = require("../..");
6
- var trainer_1 = require("./trainer");
7
- /**
8
- * Class whose role is to train a model in a distributed way with a given dataset.
9
- */
10
- var DistributedTrainer = /** @class */ (function (_super) {
11
- (0, tslib_1.__extends)(DistributedTrainer, _super);
12
- /** DistributedTrainer constructor, accepts same arguments as Trainer and in additional also a client who takes care of communicating weights.
13
- */
14
- function DistributedTrainer(task, trainingInformant, memory, model, previousRoundModel, client) {
15
- var _this = _super.call(this, task, trainingInformant, memory, model) || this;
16
- _this.previousRoundModel = previousRoundModel;
17
- _this.client = client;
18
- return _this;
19
- }
20
- /**
21
- * Callback called every time a round is over
22
- */
23
- DistributedTrainer.prototype.onRoundEnd = function (accuracy) {
24
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
25
- var currentRoundWeights, previousRoundWeights, aggregatedWeights;
26
- return (0, tslib_1.__generator)(this, function (_a) {
27
- switch (_a.label) {
28
- case 0:
29
- currentRoundWeights = __1.WeightsContainer.from(this.model);
30
- previousRoundWeights = __1.WeightsContainer.from(this.previousRoundModel);
31
- return [4 /*yield*/, this.client.onRoundEndCommunication(currentRoundWeights, previousRoundWeights, this.roundTracker.round, this.trainingInformant)];
32
- case 1:
33
- aggregatedWeights = _a.sent();
34
- this.previousRoundModel.setWeights(currentRoundWeights.weights);
35
- this.model.setWeights(aggregatedWeights.weights);
36
- return [4 /*yield*/, this.memory.updateWorkingModel({ taskID: this.task.taskID, name: this.trainingInformation.modelID }, this.model)];
37
- case 2:
38
- _a.sent();
39
- return [2 /*return*/];
40
- }
41
- });
42
- });
43
- };
44
- // if it is undefined, will training continue? we hope yes
45
- /**
46
- * Callback called once training is over
47
- */
48
- DistributedTrainer.prototype.onTrainEnd = function () {
49
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
50
- return (0, tslib_1.__generator)(this, function (_a) {
51
- switch (_a.label) {
52
- case 0: return [4 /*yield*/, this.client.onTrainEndCommunication(__1.WeightsContainer.from(this.model), this.trainingInformant)];
53
- case 1:
54
- _a.sent();
55
- return [4 /*yield*/, _super.prototype.onTrainEnd.call(this)];
56
- case 2:
57
- _a.sent();
58
- return [2 /*return*/];
59
- }
60
- });
61
- });
62
- };
63
- return DistributedTrainer;
64
- }(trainer_1.Trainer));
65
- exports.DistributedTrainer = DistributedTrainer;
@@ -1,11 +0,0 @@
1
- import { tf } from '../..';
2
- import { Trainer } from './trainer';
3
- /** Class whose role is to locally (alone) train a model on a given dataset, without any collaborators.
4
- */
5
- export declare class LocalTrainer extends Trainer {
6
- /**
7
- * Callback called every time a round is over. For local training, a round is typically an epoch
8
- */
9
- onRoundEnd(accuracy: number): Promise<void>;
10
- protected onEpochEnd(epoch: number, logs?: tf.Logs): void;
11
- }
@@ -1,34 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.LocalTrainer = void 0;
4
- var tslib_1 = require("tslib");
5
- var trainer_1 = require("./trainer");
6
- /** Class whose role is to locally (alone) train a model on a given dataset, without any collaborators.
7
- */
8
- var LocalTrainer = /** @class */ (function (_super) {
9
- (0, tslib_1.__extends)(LocalTrainer, _super);
10
- function LocalTrainer() {
11
- return _super !== null && _super.apply(this, arguments) || this;
12
- }
13
- /**
14
- * Callback called every time a round is over. For local training, a round is typically an epoch
15
- */
16
- LocalTrainer.prototype.onRoundEnd = function (accuracy) {
17
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
18
- return (0, tslib_1.__generator)(this, function (_a) {
19
- switch (_a.label) {
20
- case 0: return [4 /*yield*/, this.memory.updateWorkingModel({ taskID: this.task.taskID, name: this.trainingInformation.modelID }, this.model)];
21
- case 1:
22
- _a.sent();
23
- return [2 /*return*/];
24
- }
25
- });
26
- });
27
- };
28
- LocalTrainer.prototype.onEpochEnd = function (epoch, logs) {
29
- _super.prototype.onEpochEnd.call(this, epoch, logs);
30
- this.trainingInformant.update({ currentRound: epoch });
31
- };
32
- return LocalTrainer;
33
- }(trainer_1.Trainer));
34
- exports.LocalTrainer = LocalTrainer;
@@ -1,30 +0,0 @@
1
- /**
2
- * Class that keeps track of the current batch in order for the trainer to query when round has ended.
3
- *
4
- * @remark
5
- * In distributed training, the client trains locally for a certain amount of epochs before sharing his weights to the server/neighbor, this
6
- * is what we call a round.
7
- *
8
- * The role of the RoundTracker is to keep track of when a roundHasEnded using the current batch number. The batch in the RoundTracker is cumulative whereas
9
- * in the onBatchEnd it is not (it resets to 0 after each epoch).
10
- *
11
- * The roundDuration is the length of a round (in batches).
12
- */
13
- export declare class RoundTracker {
14
- round: number;
15
- batch: number;
16
- roundDuration: number;
17
- constructor(roundDuration: number);
18
- /**
19
- * Update the batch number, to be called inside onBatchEnd. (We do not use batch output of onBatchEnd since it is
20
- * not cumulative).
21
- */
22
- updateBatch(): void;
23
- /**
24
- * Returns true if a local round has ended, false otherwise.
25
- *
26
- * @remark
27
- * Returns true if (batch) mod (batches per round) == 0, false otherwise
28
- */
29
- roundHasEnded(): boolean;
30
- }
@@ -1,47 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.RoundTracker = void 0;
4
- /**
5
- * Class that keeps track of the current batch in order for the trainer to query when round has ended.
6
- *
7
- * @remark
8
- * In distributed training, the client trains locally for a certain amount of epochs before sharing his weights to the server/neighbor, this
9
- * is what we call a round.
10
- *
11
- * The role of the RoundTracker is to keep track of when a roundHasEnded using the current batch number. The batch in the RoundTracker is cumulative whereas
12
- * in the onBatchEnd it is not (it resets to 0 after each epoch).
13
- *
14
- * The roundDuration is the length of a round (in batches).
15
- */
16
- var RoundTracker = /** @class */ (function () {
17
- function RoundTracker(roundDuration) {
18
- this.round = 0;
19
- this.batch = 0;
20
- this.roundDuration = roundDuration;
21
- }
22
- /**
23
- * Update the batch number, to be called inside onBatchEnd. (We do not use batch output of onBatchEnd since it is
24
- * not cumulative).
25
- */
26
- RoundTracker.prototype.updateBatch = function () {
27
- this.batch += 1;
28
- };
29
- /**
30
- * Returns true if a local round has ended, false otherwise.
31
- *
32
- * @remark
33
- * Returns true if (batch) mod (batches per round) == 0, false otherwise
34
- */
35
- RoundTracker.prototype.roundHasEnded = function () {
36
- if (this.batch === 0) {
37
- return false;
38
- }
39
- var roundHasEnded = this.batch % this.roundDuration === 0;
40
- if (roundHasEnded) {
41
- this.round += 1;
42
- }
43
- return roundHasEnded;
44
- };
45
- return RoundTracker;
46
- }());
47
- exports.RoundTracker = RoundTracker;
@@ -1,65 +0,0 @@
1
- import { tf, Memory, Task, TrainingInformant, TrainingInformation } from '../..';
2
- import { RoundTracker } from './round_tracker';
3
- import { TrainerLog } from '../../logging/trainer_logger';
4
- /** Abstract class whose role is to train a model with a given dataset. This can be either done
5
- * locally (alone) or in a distributed way with collaborators. The Trainer works as follows:
6
- *
7
- * 1. Call trainModel(dataset) to start training
8
- * 2. Once a batch ends, onBatchEnd is triggered, which will then call onRoundEnd once the round has ended.
9
- *
10
- * The onRoundEnd needs to be implemented to specify what actions to do when the round has ended, such as a communication step with collaborators. To know when
11
- * a round has ended we use the roundTracker object.
12
- */
13
- export declare abstract class Trainer {
14
- readonly task: Task;
15
- readonly trainingInformant: TrainingInformant;
16
- readonly memory: Memory;
17
- readonly model: tf.LayersModel;
18
- readonly trainingInformation: TrainingInformation;
19
- readonly roundTracker: RoundTracker;
20
- private stopTrainingRequested;
21
- private readonly trainerLogger;
22
- /**
23
- * Constructs the training manager.
24
- * @param task the trained task
25
- * @param trainingInformant the training informant
26
- */
27
- constructor(task: Task, trainingInformant: TrainingInformant, memory: Memory, model: tf.LayersModel);
28
- /**
29
- * Every time a round ends this function will be called
30
- */
31
- protected abstract onRoundEnd(accuracy: number): Promise<void>;
32
- /** onBatchEnd callback, when a round ends, we call onRoundEnd (to be implemented for local and distributed instances)
33
- */
34
- protected onBatchEnd(_: number, logs?: tf.Logs): Promise<void>;
35
- /**
36
- * We update the training graph, this needs to be done on epoch end as there is no validation accuracy onBatchEnd.
37
- */
38
- protected onEpochEnd(epoch: number, logs?: tf.Logs): void;
39
- /**
40
- * When the training ends this function will be call
41
- */
42
- protected onTrainEnd(logs?: tf.Logs): Promise<void>;
43
- /**
44
- * Request stop training to be used from the Disco instance or any class that is taking care of the trainer.
45
- */
46
- stopTraining(): Promise<void>;
47
- /**
48
- * Start training the model with the given dataset
49
- * @param dataset
50
- */
51
- trainModel(dataset: tf.data.Dataset<tf.TensorContainer>, valDataset: tf.data.Dataset<tf.TensorContainer>): Promise<void>;
52
- /**
53
- * Format accuracy
54
- */
55
- protected roundDecimals(accuracy: number, decimalsToRound?: number): number;
56
- /**
57
- * reset stop training state
58
- */
59
- protected resetStopTrainerState(): void;
60
- /**
61
- * If stop training is requested, do so
62
- */
63
- protected stopTrainModelIfRequested(): void;
64
- getTrainerLog(): TrainerLog;
65
- }
@@ -1,160 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.Trainer = void 0;
4
- var tslib_1 = require("tslib");
5
- var round_tracker_1 = require("./round_tracker");
6
- var trainer_logger_1 = require("../../logging/trainer_logger");
7
- /** Abstract class whose role is to train a model with a given dataset. This can be either done
8
- * locally (alone) or in a distributed way with collaborators. The Trainer works as follows:
9
- *
10
- * 1. Call trainModel(dataset) to start training
11
- * 2. Once a batch ends, onBatchEnd is triggered, which will then call onRoundEnd once the round has ended.
12
- *
13
- * The onRoundEnd needs to be implemented to specify what actions to do when the round has ended, such as a communication step with collaborators. To know when
14
- * a round has ended we use the roundTracker object.
15
- */
16
- var Trainer = /** @class */ (function () {
17
- /**
18
- * Constructs the training manager.
19
- * @param task the trained task
20
- * @param trainingInformant the training informant
21
- */
22
- function Trainer(task, trainingInformant, memory, model) {
23
- this.task = task;
24
- this.trainingInformant = trainingInformant;
25
- this.memory = memory;
26
- this.model = model;
27
- this.stopTrainingRequested = false;
28
- this.trainerLogger = new trainer_logger_1.TrainerLogger();
29
- var trainingInformation = task.trainingInformation;
30
- if (trainingInformation === undefined) {
31
- throw new Error('round duration is undefined');
32
- }
33
- this.trainingInformation = trainingInformation;
34
- this.roundTracker = new round_tracker_1.RoundTracker(trainingInformation.roundDuration);
35
- }
36
- /** onBatchEnd callback, when a round ends, we call onRoundEnd (to be implemented for local and distributed instances)
37
- */
38
- Trainer.prototype.onBatchEnd = function (_, logs) {
39
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
40
- return (0, tslib_1.__generator)(this, function (_a) {
41
- switch (_a.label) {
42
- case 0:
43
- if (logs === undefined) {
44
- return [2 /*return*/];
45
- }
46
- this.roundTracker.updateBatch();
47
- this.stopTrainModelIfRequested();
48
- if (!this.roundTracker.roundHasEnded()) return [3 /*break*/, 2];
49
- return [4 /*yield*/, this.onRoundEnd(logs.acc)];
50
- case 1:
51
- _a.sent();
52
- _a.label = 2;
53
- case 2: return [2 /*return*/];
54
- }
55
- });
56
- });
57
- };
58
- /**
59
- * We update the training graph, this needs to be done on epoch end as there is no validation accuracy onBatchEnd.
60
- */
61
- Trainer.prototype.onEpochEnd = function (epoch, logs) {
62
- this.trainerLogger.onEpochEnd(epoch, logs);
63
- if (logs !== undefined && !isNaN(logs.acc) && !isNaN(logs.val_acc)) {
64
- this.trainingInformant.updateTrainingGraph(this.roundDecimals(logs.acc));
65
- this.trainingInformant.updateValidationGraph(this.roundDecimals(logs.val_acc));
66
- }
67
- else {
68
- this.trainerLogger.error('onEpochEnd: NaN value');
69
- }
70
- };
71
- /**
72
- * When the training ends this function will be call
73
- */
74
- Trainer.prototype.onTrainEnd = function (logs) {
75
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
76
- return (0, tslib_1.__generator)(this, function (_a) {
77
- this.trainingInformant.addMessage('Training finished.');
78
- return [2 /*return*/];
79
- });
80
- });
81
- };
82
- /**
83
- * Request stop training to be used from the Disco instance or any class that is taking care of the trainer.
84
- */
85
- Trainer.prototype.stopTraining = function () {
86
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
87
- return (0, tslib_1.__generator)(this, function (_a) {
88
- this.stopTrainingRequested = true;
89
- return [2 /*return*/];
90
- });
91
- });
92
- };
93
- /**
94
- * Start training the model with the given dataset
95
- * @param dataset
96
- */
97
- Trainer.prototype.trainModel = function (dataset, valDataset) {
98
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
99
- var _this = this;
100
- return (0, tslib_1.__generator)(this, function (_a) {
101
- switch (_a.label) {
102
- case 0:
103
- this.resetStopTrainerState();
104
- // Assign callbacks and start training
105
- return [4 /*yield*/, this.model.fitDataset(dataset, {
106
- epochs: this.trainingInformation.epochs,
107
- validationData: valDataset,
108
- callbacks: {
109
- onEpochEnd: function (epoch, logs) { return _this.onEpochEnd(epoch, logs); },
110
- onBatchEnd: function (epoch, logs) { return (0, tslib_1.__awaiter)(_this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
111
- switch (_a.label) {
112
- case 0: return [4 /*yield*/, this.onBatchEnd(epoch, logs)];
113
- case 1: return [2 /*return*/, _a.sent()];
114
- }
115
- }); }); },
116
- onTrainEnd: function (logs) { return (0, tslib_1.__awaiter)(_this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
117
- switch (_a.label) {
118
- case 0: return [4 /*yield*/, this.onTrainEnd(logs)];
119
- case 1: return [2 /*return*/, _a.sent()];
120
- }
121
- }); }); }
122
- }
123
- })];
124
- case 1:
125
- // Assign callbacks and start training
126
- _a.sent();
127
- return [2 /*return*/];
128
- }
129
- });
130
- });
131
- };
132
- /**
133
- * Format accuracy
134
- */
135
- Trainer.prototype.roundDecimals = function (accuracy, decimalsToRound) {
136
- if (decimalsToRound === void 0) { decimalsToRound = 2; }
137
- return +(accuracy * 100).toFixed(decimalsToRound);
138
- };
139
- /**
140
- * reset stop training state
141
- */
142
- Trainer.prototype.resetStopTrainerState = function () {
143
- this.model.stopTraining = false;
144
- this.stopTrainingRequested = false;
145
- };
146
- /**
147
- * If stop training is requested, do so
148
- */
149
- Trainer.prototype.stopTrainModelIfRequested = function () {
150
- if (this.stopTrainingRequested) {
151
- this.model.stopTraining = true;
152
- this.stopTrainingRequested = false;
153
- }
154
- };
155
- Trainer.prototype.getTrainerLog = function () {
156
- return this.trainerLogger.log;
157
- };
158
- return Trainer;
159
- }());
160
- exports.Trainer = Trainer;
@@ -1,95 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.TrainerBuilder = void 0;
4
- var tslib_1 = require("tslib");
5
- var __1 = require("../..");
6
- var distributed_trainer_1 = require("./distributed_trainer");
7
- var local_trainer_1 = require("./local_trainer");
8
- /**
9
- * A class that helps build the Trainer and auxiliary classes.
10
- */
11
- var TrainerBuilder = /** @class */ (function () {
12
- function TrainerBuilder(memory, task, trainingInformant) {
13
- this.memory = memory;
14
- this.task = task;
15
- this.trainingInformant = trainingInformant;
16
- }
17
- /**
18
- * Builds a trainer object.
19
- *
20
- * @param client client to share weights with (either distributed or federated)
21
- * @param distributed whether to build a distributed or local trainer
22
- * @returns
23
- */
24
- TrainerBuilder.prototype.build = function (client, distributed) {
25
- if (distributed === void 0) { distributed = false; }
26
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
27
- var model;
28
- return (0, tslib_1.__generator)(this, function (_a) {
29
- switch (_a.label) {
30
- case 0: return [4 /*yield*/, this.getModel(client)];
31
- case 1:
32
- model = _a.sent();
33
- if (distributed) {
34
- return [2 /*return*/, new distributed_trainer_1.DistributedTrainer(this.task, this.trainingInformant, this.memory, model, model, client)];
35
- }
36
- else {
37
- return [2 /*return*/, new local_trainer_1.LocalTrainer(this.task, this.trainingInformant, this.memory, model)];
38
- }
39
- return [2 /*return*/];
40
- }
41
- });
42
- });
43
- };
44
- /**
45
- * If a model exists in memory, laod it, otherwise load model from server
46
- * @returns
47
- */
48
- TrainerBuilder.prototype.getModel = function (client) {
49
- var _a;
50
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
51
- var modelID, info, model;
52
- return (0, tslib_1.__generator)(this, function (_b) {
53
- switch (_b.label) {
54
- case 0:
55
- modelID = (_a = this.task.trainingInformation) === null || _a === void 0 ? void 0 : _a.modelID;
56
- if (modelID === undefined) {
57
- throw new TypeError('model ID is undefined');
58
- }
59
- info = { type: __1.ModelType.WORKING, taskID: this.task.taskID, name: modelID };
60
- return [4 /*yield*/, this.memory.contains(info)];
61
- case 1: return [4 /*yield*/, ((_b.sent()) ? this.memory.getModel(info) : client.getLatestModel())];
62
- case 2:
63
- model = _b.sent();
64
- return [4 /*yield*/, this.updateModelInformation(model)];
65
- case 3: return [2 /*return*/, _b.sent()];
66
- }
67
- });
68
- });
69
- };
70
- TrainerBuilder.prototype.updateModelInformation = function (model) {
71
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
72
- var info;
73
- return (0, tslib_1.__generator)(this, function (_a) {
74
- // Continue local training from previous epoch checkpoint
75
- if (model.getUserDefinedMetadata() === undefined) {
76
- model.setUserDefinedMetadata({ epoch: 0 });
77
- }
78
- info = this.task.trainingInformation;
79
- if (info === undefined) {
80
- throw new TypeError('training information is undefined');
81
- }
82
- model.compile(info.modelCompileData);
83
- if (info.learningRate !== undefined) {
84
- // TODO: Not the right way to change learningRate and hence we cast to any
85
- // the right way is to construct the optimiser and pass learningRate via
86
- // argument.
87
- model.optimizer.learningRate = info.learningRate;
88
- }
89
- return [2 /*return*/, model];
90
- });
91
- });
92
- };
93
- return TrainerBuilder;
94
- }());
95
- exports.TrainerBuilder = TrainerBuilder;