@epfml/discojs 2.0.0 → 2.1.2-p20240506085037.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (334) hide show
  1. package/dist/aggregator/base.d.ts +180 -0
  2. package/dist/aggregator/base.js +236 -0
  3. package/dist/aggregator/get.d.ts +16 -0
  4. package/dist/aggregator/get.js +31 -0
  5. package/dist/aggregator/index.d.ts +7 -0
  6. package/dist/aggregator/index.js +4 -0
  7. package/dist/aggregator/mean.d.ts +23 -0
  8. package/dist/aggregator/mean.js +69 -0
  9. package/dist/aggregator/secure.d.ts +27 -0
  10. package/dist/aggregator/secure.js +91 -0
  11. package/dist/async_informant.d.ts +15 -0
  12. package/dist/async_informant.js +42 -0
  13. package/dist/client/base.d.ts +76 -0
  14. package/dist/client/base.js +88 -0
  15. package/dist/client/decentralized/base.d.ts +32 -0
  16. package/dist/client/decentralized/base.js +192 -0
  17. package/dist/client/decentralized/index.d.ts +2 -0
  18. package/dist/client/decentralized/index.js +2 -0
  19. package/dist/client/decentralized/messages.d.ts +28 -0
  20. package/dist/client/decentralized/messages.js +44 -0
  21. package/dist/client/decentralized/peer.d.ts +40 -0
  22. package/dist/client/decentralized/peer.js +189 -0
  23. package/dist/client/decentralized/peer_pool.d.ts +12 -0
  24. package/dist/client/decentralized/peer_pool.js +44 -0
  25. package/dist/client/event_connection.d.ts +34 -0
  26. package/dist/client/event_connection.js +105 -0
  27. package/dist/client/federated/base.d.ts +54 -0
  28. package/dist/client/federated/base.js +151 -0
  29. package/dist/client/federated/index.d.ts +2 -0
  30. package/dist/client/federated/index.js +2 -0
  31. package/dist/client/federated/messages.d.ts +30 -0
  32. package/dist/client/federated/messages.js +24 -0
  33. package/dist/client/index.d.ts +8 -0
  34. package/dist/client/index.js +8 -0
  35. package/dist/client/local.d.ts +3 -0
  36. package/dist/client/local.js +3 -0
  37. package/dist/client/messages.d.ts +30 -0
  38. package/dist/client/messages.js +26 -0
  39. package/dist/client/types.d.ts +2 -0
  40. package/dist/client/types.js +4 -0
  41. package/dist/client/utils.d.ts +2 -0
  42. package/dist/client/utils.js +7 -0
  43. package/dist/dataset/data/data.d.ts +48 -0
  44. package/dist/dataset/data/data.js +72 -0
  45. package/dist/dataset/data/data_split.d.ts +8 -0
  46. package/dist/dataset/data/data_split.js +1 -0
  47. package/dist/dataset/data/image_data.d.ts +11 -0
  48. package/dist/dataset/data/image_data.js +38 -0
  49. package/dist/dataset/data/index.d.ts +6 -0
  50. package/dist/dataset/data/index.js +5 -0
  51. package/dist/dataset/data/preprocessing/base.d.ts +16 -0
  52. package/dist/dataset/data/preprocessing/base.js +1 -0
  53. package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +13 -0
  54. package/dist/dataset/data/preprocessing/image_preprocessing.js +40 -0
  55. package/dist/dataset/data/preprocessing/index.d.ts +4 -0
  56. package/dist/dataset/data/preprocessing/index.js +3 -0
  57. package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +13 -0
  58. package/dist/dataset/data/preprocessing/tabular_preprocessing.js +45 -0
  59. package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +13 -0
  60. package/dist/dataset/data/preprocessing/text_preprocessing.js +85 -0
  61. package/dist/dataset/data/tabular_data.d.ts +11 -0
  62. package/dist/dataset/data/tabular_data.js +25 -0
  63. package/dist/dataset/data/text_data.d.ts +11 -0
  64. package/dist/dataset/data/text_data.js +14 -0
  65. package/dist/{core/dataset → dataset}/data_loader/data_loader.d.ts +3 -5
  66. package/dist/dataset/data_loader/data_loader.js +2 -0
  67. package/dist/dataset/data_loader/image_loader.d.ts +20 -3
  68. package/dist/dataset/data_loader/image_loader.js +98 -23
  69. package/dist/dataset/data_loader/index.d.ts +5 -2
  70. package/dist/dataset/data_loader/index.js +4 -7
  71. package/dist/dataset/data_loader/tabular_loader.d.ts +34 -3
  72. package/dist/dataset/data_loader/tabular_loader.js +75 -15
  73. package/dist/dataset/data_loader/text_loader.d.ts +14 -0
  74. package/dist/dataset/data_loader/text_loader.js +25 -0
  75. package/dist/dataset/dataset.d.ts +5 -0
  76. package/dist/dataset/dataset.js +1 -0
  77. package/dist/dataset/dataset_builder.d.ts +60 -0
  78. package/dist/dataset/dataset_builder.js +142 -0
  79. package/dist/dataset/index.d.ts +5 -0
  80. package/dist/dataset/index.js +3 -0
  81. package/dist/default_tasks/cifar10/index.d.ts +2 -0
  82. package/dist/default_tasks/cifar10/index.js +60 -0
  83. package/dist/default_tasks/cifar10/model.d.ts +434 -0
  84. package/dist/default_tasks/cifar10/model.js +2385 -0
  85. package/dist/default_tasks/geotags/index.d.ts +2 -0
  86. package/dist/default_tasks/geotags/index.js +65 -0
  87. package/dist/default_tasks/geotags/model.d.ts +593 -0
  88. package/dist/default_tasks/geotags/model.js +4715 -0
  89. package/dist/default_tasks/index.d.ts +8 -0
  90. package/dist/default_tasks/index.js +8 -0
  91. package/dist/default_tasks/lus_covid.d.ts +2 -0
  92. package/dist/default_tasks/lus_covid.js +89 -0
  93. package/dist/default_tasks/mnist.d.ts +2 -0
  94. package/dist/default_tasks/mnist.js +61 -0
  95. package/dist/default_tasks/simple_face/index.d.ts +2 -0
  96. package/dist/default_tasks/simple_face/index.js +48 -0
  97. package/dist/default_tasks/simple_face/model.d.ts +513 -0
  98. package/dist/default_tasks/simple_face/model.js +4301 -0
  99. package/dist/default_tasks/skin_mnist.d.ts +2 -0
  100. package/dist/default_tasks/skin_mnist.js +80 -0
  101. package/dist/default_tasks/titanic.d.ts +2 -0
  102. package/dist/default_tasks/titanic.js +88 -0
  103. package/dist/default_tasks/wikitext.d.ts +2 -0
  104. package/dist/default_tasks/wikitext.js +38 -0
  105. package/dist/index.d.ts +18 -2
  106. package/dist/index.js +18 -6
  107. package/dist/{core/informant → informant}/graph_informant.d.ts +1 -1
  108. package/dist/informant/graph_informant.js +20 -0
  109. package/dist/informant/index.d.ts +1 -0
  110. package/dist/informant/index.js +1 -0
  111. package/dist/{core/logging → logging}/console_logger.d.ts +2 -2
  112. package/dist/logging/console_logger.js +22 -0
  113. package/dist/logging/index.d.ts +2 -0
  114. package/dist/logging/index.js +1 -0
  115. package/dist/{core/logging → logging}/logger.d.ts +3 -3
  116. package/dist/logging/logger.js +1 -0
  117. package/dist/memory/base.d.ts +119 -0
  118. package/dist/memory/base.js +9 -0
  119. package/dist/memory/empty.d.ts +20 -0
  120. package/dist/memory/empty.js +43 -0
  121. package/dist/memory/index.d.ts +3 -1
  122. package/dist/memory/index.js +3 -5
  123. package/dist/memory/model_type.d.ts +9 -0
  124. package/dist/memory/model_type.js +10 -0
  125. package/dist/{core/privacy.d.ts → privacy.d.ts} +1 -1
  126. package/dist/{core/privacy.js → privacy.js} +11 -16
  127. package/dist/serialization/index.d.ts +2 -0
  128. package/dist/serialization/index.js +2 -0
  129. package/dist/serialization/model.d.ts +5 -0
  130. package/dist/serialization/model.js +67 -0
  131. package/dist/{core/serialization → serialization}/weights.d.ts +2 -2
  132. package/dist/serialization/weights.js +37 -0
  133. package/dist/task/data_example.js +14 -0
  134. package/dist/task/digest.d.ts +5 -0
  135. package/dist/task/digest.js +14 -0
  136. package/dist/{core/task → task}/display_information.d.ts +5 -3
  137. package/dist/task/display_information.js +46 -0
  138. package/dist/task/index.d.ts +7 -0
  139. package/dist/task/index.js +5 -0
  140. package/dist/task/label_type.d.ts +9 -0
  141. package/dist/task/label_type.js +28 -0
  142. package/dist/task/summary.js +13 -0
  143. package/dist/task/task.d.ts +12 -0
  144. package/dist/task/task.js +22 -0
  145. package/dist/task/task_handler.d.ts +5 -0
  146. package/dist/task/task_handler.js +20 -0
  147. package/dist/task/task_provider.d.ts +5 -0
  148. package/dist/task/task_provider.js +1 -0
  149. package/dist/{core/task → task}/training_information.d.ts +9 -10
  150. package/dist/task/training_information.js +88 -0
  151. package/dist/training/disco.d.ts +40 -0
  152. package/dist/training/disco.js +107 -0
  153. package/dist/training/index.d.ts +2 -0
  154. package/dist/training/index.js +1 -0
  155. package/dist/training/trainer/distributed_trainer.d.ts +20 -0
  156. package/dist/training/trainer/distributed_trainer.js +36 -0
  157. package/dist/training/trainer/local_trainer.d.ts +12 -0
  158. package/dist/training/trainer/local_trainer.js +19 -0
  159. package/dist/training/trainer/trainer.d.ts +33 -0
  160. package/dist/training/trainer/trainer.js +52 -0
  161. package/dist/{core/training → training}/trainer/trainer_builder.d.ts +5 -7
  162. package/dist/training/trainer/trainer_builder.js +43 -0
  163. package/dist/types.d.ts +8 -0
  164. package/dist/types.js +1 -0
  165. package/dist/utils/event_emitter.d.ts +40 -0
  166. package/dist/utils/event_emitter.js +57 -0
  167. package/dist/validation/index.d.ts +1 -0
  168. package/dist/validation/index.js +1 -0
  169. package/dist/validation/validator.d.ts +28 -0
  170. package/dist/validation/validator.js +132 -0
  171. package/dist/weights/aggregation.d.ts +21 -0
  172. package/dist/weights/aggregation.js +44 -0
  173. package/dist/weights/index.d.ts +2 -0
  174. package/dist/weights/index.js +2 -0
  175. package/dist/weights/weights_container.d.ts +68 -0
  176. package/dist/weights/weights_container.js +96 -0
  177. package/package.json +25 -16
  178. package/README.md +0 -53
  179. package/dist/core/async_buffer.d.ts +0 -41
  180. package/dist/core/async_buffer.js +0 -97
  181. package/dist/core/async_informant.d.ts +0 -20
  182. package/dist/core/async_informant.js +0 -69
  183. package/dist/core/client/base.d.ts +0 -33
  184. package/dist/core/client/base.js +0 -35
  185. package/dist/core/client/decentralized/base.d.ts +0 -32
  186. package/dist/core/client/decentralized/base.js +0 -212
  187. package/dist/core/client/decentralized/clear_text.d.ts +0 -14
  188. package/dist/core/client/decentralized/clear_text.js +0 -96
  189. package/dist/core/client/decentralized/index.d.ts +0 -4
  190. package/dist/core/client/decentralized/index.js +0 -9
  191. package/dist/core/client/decentralized/messages.d.ts +0 -41
  192. package/dist/core/client/decentralized/messages.js +0 -54
  193. package/dist/core/client/decentralized/peer.d.ts +0 -26
  194. package/dist/core/client/decentralized/peer.js +0 -210
  195. package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
  196. package/dist/core/client/decentralized/peer_pool.js +0 -92
  197. package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
  198. package/dist/core/client/decentralized/sec_agg.js +0 -190
  199. package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
  200. package/dist/core/client/decentralized/secret_shares.js +0 -39
  201. package/dist/core/client/decentralized/types.d.ts +0 -2
  202. package/dist/core/client/decentralized/types.js +0 -7
  203. package/dist/core/client/event_connection.d.ts +0 -37
  204. package/dist/core/client/event_connection.js +0 -158
  205. package/dist/core/client/federated/client.d.ts +0 -37
  206. package/dist/core/client/federated/client.js +0 -273
  207. package/dist/core/client/federated/index.d.ts +0 -2
  208. package/dist/core/client/federated/index.js +0 -7
  209. package/dist/core/client/federated/messages.d.ts +0 -38
  210. package/dist/core/client/federated/messages.js +0 -25
  211. package/dist/core/client/index.d.ts +0 -5
  212. package/dist/core/client/index.js +0 -11
  213. package/dist/core/client/local.d.ts +0 -8
  214. package/dist/core/client/local.js +0 -36
  215. package/dist/core/client/messages.d.ts +0 -28
  216. package/dist/core/client/messages.js +0 -33
  217. package/dist/core/client/utils.d.ts +0 -2
  218. package/dist/core/client/utils.js +0 -19
  219. package/dist/core/dataset/data/data.d.ts +0 -11
  220. package/dist/core/dataset/data/data.js +0 -20
  221. package/dist/core/dataset/data/data_split.d.ts +0 -5
  222. package/dist/core/dataset/data/data_split.js +0 -2
  223. package/dist/core/dataset/data/image_data.d.ts +0 -8
  224. package/dist/core/dataset/data/image_data.js +0 -64
  225. package/dist/core/dataset/data/index.d.ts +0 -5
  226. package/dist/core/dataset/data/index.js +0 -11
  227. package/dist/core/dataset/data/preprocessing.d.ts +0 -13
  228. package/dist/core/dataset/data/preprocessing.js +0 -33
  229. package/dist/core/dataset/data/tabular_data.d.ts +0 -8
  230. package/dist/core/dataset/data/tabular_data.js +0 -40
  231. package/dist/core/dataset/data_loader/data_loader.js +0 -10
  232. package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
  233. package/dist/core/dataset/data_loader/image_loader.js +0 -141
  234. package/dist/core/dataset/data_loader/index.d.ts +0 -3
  235. package/dist/core/dataset/data_loader/index.js +0 -9
  236. package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
  237. package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
  238. package/dist/core/dataset/dataset.d.ts +0 -2
  239. package/dist/core/dataset/dataset.js +0 -2
  240. package/dist/core/dataset/dataset_builder.d.ts +0 -18
  241. package/dist/core/dataset/dataset_builder.js +0 -96
  242. package/dist/core/dataset/index.d.ts +0 -4
  243. package/dist/core/dataset/index.js +0 -14
  244. package/dist/core/index.d.ts +0 -18
  245. package/dist/core/index.js +0 -41
  246. package/dist/core/informant/graph_informant.js +0 -23
  247. package/dist/core/informant/index.d.ts +0 -3
  248. package/dist/core/informant/index.js +0 -9
  249. package/dist/core/informant/training_informant/base.d.ts +0 -31
  250. package/dist/core/informant/training_informant/base.js +0 -83
  251. package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
  252. package/dist/core/informant/training_informant/decentralized.js +0 -22
  253. package/dist/core/informant/training_informant/federated.d.ts +0 -14
  254. package/dist/core/informant/training_informant/federated.js +0 -32
  255. package/dist/core/informant/training_informant/index.d.ts +0 -4
  256. package/dist/core/informant/training_informant/index.js +0 -11
  257. package/dist/core/informant/training_informant/local.d.ts +0 -6
  258. package/dist/core/informant/training_informant/local.js +0 -20
  259. package/dist/core/logging/console_logger.js +0 -33
  260. package/dist/core/logging/index.d.ts +0 -3
  261. package/dist/core/logging/index.js +0 -9
  262. package/dist/core/logging/logger.js +0 -9
  263. package/dist/core/logging/trainer_logger.d.ts +0 -24
  264. package/dist/core/logging/trainer_logger.js +0 -59
  265. package/dist/core/memory/base.d.ts +0 -22
  266. package/dist/core/memory/base.js +0 -9
  267. package/dist/core/memory/empty.d.ts +0 -14
  268. package/dist/core/memory/empty.js +0 -75
  269. package/dist/core/memory/index.d.ts +0 -3
  270. package/dist/core/memory/index.js +0 -9
  271. package/dist/core/memory/model_type.d.ts +0 -4
  272. package/dist/core/memory/model_type.js +0 -9
  273. package/dist/core/serialization/index.d.ts +0 -2
  274. package/dist/core/serialization/index.js +0 -6
  275. package/dist/core/serialization/model.d.ts +0 -5
  276. package/dist/core/serialization/model.js +0 -55
  277. package/dist/core/serialization/weights.js +0 -64
  278. package/dist/core/task/data_example.js +0 -24
  279. package/dist/core/task/display_information.js +0 -49
  280. package/dist/core/task/index.d.ts +0 -3
  281. package/dist/core/task/index.js +0 -8
  282. package/dist/core/task/model_compile_data.d.ts +0 -6
  283. package/dist/core/task/model_compile_data.js +0 -22
  284. package/dist/core/task/summary.js +0 -19
  285. package/dist/core/task/task.d.ts +0 -10
  286. package/dist/core/task/task.js +0 -31
  287. package/dist/core/task/training_information.js +0 -66
  288. package/dist/core/tasks/cifar10.d.ts +0 -3
  289. package/dist/core/tasks/cifar10.js +0 -65
  290. package/dist/core/tasks/geotags.d.ts +0 -3
  291. package/dist/core/tasks/geotags.js +0 -67
  292. package/dist/core/tasks/index.d.ts +0 -6
  293. package/dist/core/tasks/index.js +0 -10
  294. package/dist/core/tasks/lus_covid.d.ts +0 -3
  295. package/dist/core/tasks/lus_covid.js +0 -87
  296. package/dist/core/tasks/mnist.d.ts +0 -3
  297. package/dist/core/tasks/mnist.js +0 -60
  298. package/dist/core/tasks/simple_face.d.ts +0 -2
  299. package/dist/core/tasks/simple_face.js +0 -41
  300. package/dist/core/tasks/titanic.d.ts +0 -3
  301. package/dist/core/tasks/titanic.js +0 -88
  302. package/dist/core/training/disco.d.ts +0 -23
  303. package/dist/core/training/disco.js +0 -130
  304. package/dist/core/training/index.d.ts +0 -2
  305. package/dist/core/training/index.js +0 -7
  306. package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
  307. package/dist/core/training/trainer/distributed_trainer.js +0 -65
  308. package/dist/core/training/trainer/local_trainer.d.ts +0 -11
  309. package/dist/core/training/trainer/local_trainer.js +0 -34
  310. package/dist/core/training/trainer/round_tracker.d.ts +0 -30
  311. package/dist/core/training/trainer/round_tracker.js +0 -47
  312. package/dist/core/training/trainer/trainer.d.ts +0 -65
  313. package/dist/core/training/trainer/trainer.js +0 -160
  314. package/dist/core/training/trainer/trainer_builder.js +0 -95
  315. package/dist/core/training/training_schemes.d.ts +0 -5
  316. package/dist/core/training/training_schemes.js +0 -10
  317. package/dist/core/types.d.ts +0 -4
  318. package/dist/core/types.js +0 -2
  319. package/dist/core/validation/index.d.ts +0 -1
  320. package/dist/core/validation/index.js +0 -5
  321. package/dist/core/validation/validator.d.ts +0 -17
  322. package/dist/core/validation/validator.js +0 -104
  323. package/dist/core/weights/aggregation.d.ts +0 -8
  324. package/dist/core/weights/aggregation.js +0 -96
  325. package/dist/core/weights/index.d.ts +0 -2
  326. package/dist/core/weights/index.js +0 -7
  327. package/dist/core/weights/weights_container.d.ts +0 -19
  328. package/dist/core/weights/weights_container.js +0 -64
  329. package/dist/imports.d.ts +0 -2
  330. package/dist/imports.js +0 -7
  331. package/dist/memory/memory.d.ts +0 -26
  332. package/dist/memory/memory.js +0 -160
  333. package/dist/{core/task → task}/data_example.d.ts +1 -1
  334. package/dist/{core/task → task}/summary.d.ts +1 -1
@@ -1,160 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.Trainer = void 0;
4
- var tslib_1 = require("tslib");
5
- var round_tracker_1 = require("./round_tracker");
6
- var trainer_logger_1 = require("../../logging/trainer_logger");
7
- /** Abstract class whose role is to train a model with a given dataset. This can be either done
8
- * locally (alone) or in a distributed way with collaborators. The Trainer works as follows:
9
- *
10
- * 1. Call trainModel(dataset) to start training
11
- * 2. Once a batch ends, onBatchEnd is triggered, which will then call onRoundEnd once the round has ended.
12
- *
13
- * The onRoundEnd needs to be implemented to specify what actions to do when the round has ended, such as a communication step with collaborators. To know when
14
- * a round has ended we use the roundTracker object.
15
- */
16
- var Trainer = /** @class */ (function () {
17
- /**
18
- * Constructs the training manager.
19
- * @param task the trained task
20
- * @param trainingInformant the training informant
21
- */
22
- function Trainer(task, trainingInformant, memory, model) {
23
- this.task = task;
24
- this.trainingInformant = trainingInformant;
25
- this.memory = memory;
26
- this.model = model;
27
- this.stopTrainingRequested = false;
28
- this.trainerLogger = new trainer_logger_1.TrainerLogger();
29
- var trainingInformation = task.trainingInformation;
30
- if (trainingInformation === undefined) {
31
- throw new Error('round duration is undefined');
32
- }
33
- this.trainingInformation = trainingInformation;
34
- this.roundTracker = new round_tracker_1.RoundTracker(trainingInformation.roundDuration);
35
- }
36
- /** onBatchEnd callback, when a round ends, we call onRoundEnd (to be implemented for local and distributed instances)
37
- */
38
- Trainer.prototype.onBatchEnd = function (_, logs) {
39
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
40
- return (0, tslib_1.__generator)(this, function (_a) {
41
- switch (_a.label) {
42
- case 0:
43
- if (logs === undefined) {
44
- return [2 /*return*/];
45
- }
46
- this.roundTracker.updateBatch();
47
- this.stopTrainModelIfRequested();
48
- if (!this.roundTracker.roundHasEnded()) return [3 /*break*/, 2];
49
- return [4 /*yield*/, this.onRoundEnd(logs.acc)];
50
- case 1:
51
- _a.sent();
52
- _a.label = 2;
53
- case 2: return [2 /*return*/];
54
- }
55
- });
56
- });
57
- };
58
- /**
59
- * We update the training graph, this needs to be done on epoch end as there is no validation accuracy onBatchEnd.
60
- */
61
- Trainer.prototype.onEpochEnd = function (epoch, logs) {
62
- this.trainerLogger.onEpochEnd(epoch, logs);
63
- if (logs !== undefined && !isNaN(logs.acc) && !isNaN(logs.val_acc)) {
64
- this.trainingInformant.updateTrainingGraph(this.roundDecimals(logs.acc));
65
- this.trainingInformant.updateValidationGraph(this.roundDecimals(logs.val_acc));
66
- }
67
- else {
68
- this.trainerLogger.error('onEpochEnd: NaN value');
69
- }
70
- };
71
- /**
72
- * When the training ends this function will be call
73
- */
74
- Trainer.prototype.onTrainEnd = function (logs) {
75
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
76
- return (0, tslib_1.__generator)(this, function (_a) {
77
- this.trainingInformant.addMessage('Training finished.');
78
- return [2 /*return*/];
79
- });
80
- });
81
- };
82
- /**
83
- * Request stop training to be used from the Disco instance or any class that is taking care of the trainer.
84
- */
85
- Trainer.prototype.stopTraining = function () {
86
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
87
- return (0, tslib_1.__generator)(this, function (_a) {
88
- this.stopTrainingRequested = true;
89
- return [2 /*return*/];
90
- });
91
- });
92
- };
93
- /**
94
- * Start training the model with the given dataset
95
- * @param dataset
96
- */
97
- Trainer.prototype.trainModel = function (dataset, valDataset) {
98
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
99
- var _this = this;
100
- return (0, tslib_1.__generator)(this, function (_a) {
101
- switch (_a.label) {
102
- case 0:
103
- this.resetStopTrainerState();
104
- // Assign callbacks and start training
105
- return [4 /*yield*/, this.model.fitDataset(dataset, {
106
- epochs: this.trainingInformation.epochs,
107
- validationData: valDataset,
108
- callbacks: {
109
- onEpochEnd: function (epoch, logs) { return _this.onEpochEnd(epoch, logs); },
110
- onBatchEnd: function (epoch, logs) { return (0, tslib_1.__awaiter)(_this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
111
- switch (_a.label) {
112
- case 0: return [4 /*yield*/, this.onBatchEnd(epoch, logs)];
113
- case 1: return [2 /*return*/, _a.sent()];
114
- }
115
- }); }); },
116
- onTrainEnd: function (logs) { return (0, tslib_1.__awaiter)(_this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
117
- switch (_a.label) {
118
- case 0: return [4 /*yield*/, this.onTrainEnd(logs)];
119
- case 1: return [2 /*return*/, _a.sent()];
120
- }
121
- }); }); }
122
- }
123
- })];
124
- case 1:
125
- // Assign callbacks and start training
126
- _a.sent();
127
- return [2 /*return*/];
128
- }
129
- });
130
- });
131
- };
132
- /**
133
- * Format accuracy
134
- */
135
- Trainer.prototype.roundDecimals = function (accuracy, decimalsToRound) {
136
- if (decimalsToRound === void 0) { decimalsToRound = 2; }
137
- return +(accuracy * 100).toFixed(decimalsToRound);
138
- };
139
- /**
140
- * reset stop training state
141
- */
142
- Trainer.prototype.resetStopTrainerState = function () {
143
- this.model.stopTraining = false;
144
- this.stopTrainingRequested = false;
145
- };
146
- /**
147
- * If stop training is requested, do so
148
- */
149
- Trainer.prototype.stopTrainModelIfRequested = function () {
150
- if (this.stopTrainingRequested) {
151
- this.model.stopTraining = true;
152
- this.stopTrainingRequested = false;
153
- }
154
- };
155
- Trainer.prototype.getTrainerLog = function () {
156
- return this.trainerLogger.log;
157
- };
158
- return Trainer;
159
- }());
160
- exports.Trainer = Trainer;
@@ -1,95 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.TrainerBuilder = void 0;
4
- var tslib_1 = require("tslib");
5
- var __1 = require("../..");
6
- var distributed_trainer_1 = require("./distributed_trainer");
7
- var local_trainer_1 = require("./local_trainer");
8
- /**
9
- * A class that helps build the Trainer and auxiliary classes.
10
- */
11
- var TrainerBuilder = /** @class */ (function () {
12
- function TrainerBuilder(memory, task, trainingInformant) {
13
- this.memory = memory;
14
- this.task = task;
15
- this.trainingInformant = trainingInformant;
16
- }
17
- /**
18
- * Builds a trainer object.
19
- *
20
- * @param client client to share weights with (either distributed or federated)
21
- * @param distributed whether to build a distributed or local trainer
22
- * @returns
23
- */
24
- TrainerBuilder.prototype.build = function (client, distributed) {
25
- if (distributed === void 0) { distributed = false; }
26
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
27
- var model;
28
- return (0, tslib_1.__generator)(this, function (_a) {
29
- switch (_a.label) {
30
- case 0: return [4 /*yield*/, this.getModel(client)];
31
- case 1:
32
- model = _a.sent();
33
- if (distributed) {
34
- return [2 /*return*/, new distributed_trainer_1.DistributedTrainer(this.task, this.trainingInformant, this.memory, model, model, client)];
35
- }
36
- else {
37
- return [2 /*return*/, new local_trainer_1.LocalTrainer(this.task, this.trainingInformant, this.memory, model)];
38
- }
39
- return [2 /*return*/];
40
- }
41
- });
42
- });
43
- };
44
- /**
45
- * If a model exists in memory, laod it, otherwise load model from server
46
- * @returns
47
- */
48
- TrainerBuilder.prototype.getModel = function (client) {
49
- var _a;
50
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
51
- var modelID, info, model;
52
- return (0, tslib_1.__generator)(this, function (_b) {
53
- switch (_b.label) {
54
- case 0:
55
- modelID = (_a = this.task.trainingInformation) === null || _a === void 0 ? void 0 : _a.modelID;
56
- if (modelID === undefined) {
57
- throw new TypeError('model ID is undefined');
58
- }
59
- info = { type: __1.ModelType.WORKING, taskID: this.task.taskID, name: modelID };
60
- return [4 /*yield*/, this.memory.contains(info)];
61
- case 1: return [4 /*yield*/, ((_b.sent()) ? this.memory.getModel(info) : client.getLatestModel())];
62
- case 2:
63
- model = _b.sent();
64
- return [4 /*yield*/, this.updateModelInformation(model)];
65
- case 3: return [2 /*return*/, _b.sent()];
66
- }
67
- });
68
- });
69
- };
70
- TrainerBuilder.prototype.updateModelInformation = function (model) {
71
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
72
- var info;
73
- return (0, tslib_1.__generator)(this, function (_a) {
74
- // Continue local training from previous epoch checkpoint
75
- if (model.getUserDefinedMetadata() === undefined) {
76
- model.setUserDefinedMetadata({ epoch: 0 });
77
- }
78
- info = this.task.trainingInformation;
79
- if (info === undefined) {
80
- throw new TypeError('training information is undefined');
81
- }
82
- model.compile(info.modelCompileData);
83
- if (info.learningRate !== undefined) {
84
- // TODO: Not the right way to change learningRate and hence we cast to any
85
- // the right way is to construct the optimiser and pass learningRate via
86
- // argument.
87
- model.optimizer.learningRate = info.learningRate;
88
- }
89
- return [2 /*return*/, model];
90
- });
91
- });
92
- };
93
- return TrainerBuilder;
94
- }());
95
- exports.TrainerBuilder = TrainerBuilder;
@@ -1,5 +0,0 @@
1
- export declare enum TrainingSchemes {
2
- LOCAL = "local",
3
- DECENTRALIZED = "decentralized",
4
- FEDERATED = "federated"
5
- }
@@ -1,10 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.TrainingSchemes = void 0;
4
- /* eslint-disable no-unused-vars */
5
- var TrainingSchemes;
6
- (function (TrainingSchemes) {
7
- TrainingSchemes["LOCAL"] = "local";
8
- TrainingSchemes["DECENTRALIZED"] = "decentralized";
9
- TrainingSchemes["FEDERATED"] = "federated";
10
- })(TrainingSchemes = exports.TrainingSchemes || (exports.TrainingSchemes = {}));
@@ -1,4 +0,0 @@
1
- import { tf } from '.';
2
- export declare type Path = string;
3
- export declare type Weights = tf.Tensor[];
4
- export declare type MetadataID = string;
@@ -1,2 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
@@ -1 +0,0 @@
1
- export { Validator } from './validator';
@@ -1,5 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.Validator = void 0;
4
- var validator_1 = require("./validator");
5
- Object.defineProperty(exports, "Validator", { enumerable: true, get: function () { return validator_1.Validator; } });
@@ -1,17 +0,0 @@
1
- import { List } from 'immutable';
2
- import { tf, data, Task, Logger, Client, Memory, ModelSource } from '..';
3
- export declare class Validator {
4
- readonly task: Task;
5
- readonly logger: Logger;
6
- private readonly memory;
7
- private readonly source?;
8
- private readonly client?;
9
- private readonly graphInformant;
10
- private size;
11
- constructor(task: Task, logger: Logger, memory: Memory, source?: ModelSource | undefined, client?: Client | undefined);
12
- assess(data: data.Data): Promise<void>;
13
- getModel(): Promise<tf.LayersModel>;
14
- accuracyData(): List<number>;
15
- accuracy(): number;
16
- visitedSamples(): number;
17
- }
@@ -1,104 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.Validator = void 0;
4
- var tslib_1 = require("tslib");
5
- var immutable_1 = require("immutable");
6
- var __1 = require("..");
7
- var Validator = /** @class */ (function () {
8
- function Validator(task, logger, memory, source, client) {
9
- this.task = task;
10
- this.logger = logger;
11
- this.memory = memory;
12
- this.source = source;
13
- this.client = client;
14
- this.graphInformant = new __1.GraphInformant();
15
- this.size = 0;
16
- if (source === undefined && client === undefined) {
17
- throw new Error('cannot identify model');
18
- }
19
- }
20
- Validator.prototype.assess = function (data) {
21
- var _a, _b, _c;
22
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
23
- var batchSize, labels, classes, model, hits;
24
- var _this = this;
25
- return (0, tslib_1.__generator)(this, function (_d) {
26
- switch (_d.label) {
27
- case 0:
28
- batchSize = (_a = this.task.trainingInformation) === null || _a === void 0 ? void 0 : _a.batchSize;
29
- if (batchSize === undefined) {
30
- throw new TypeError('batch size is undefined');
31
- }
32
- labels = (_b = this.task.trainingInformation) === null || _b === void 0 ? void 0 : _b.LABEL_LIST;
33
- classes = (_c = labels === null || labels === void 0 ? void 0 : labels.length) !== null && _c !== void 0 ? _c : 1;
34
- return [4 /*yield*/, this.getModel()];
35
- case 1:
36
- model = _d.sent();
37
- hits = 0;
38
- return [4 /*yield*/, data.dataset.batch(batchSize).forEachAsync(function (e) {
39
- if (typeof e === 'object' && 'xs' in e && 'ys' in e) {
40
- var xs = e.xs;
41
- var ys = e.ys.dataSync();
42
- var pred = model.predict(xs, { batchSize: batchSize })
43
- .dataSync()
44
- .map(Math.round);
45
- _this.size += xs.shape[0];
46
- hits += (0, immutable_1.List)(pred).zip((0, immutable_1.List)(ys))
47
- .map(function (_a) {
48
- var _b = (0, tslib_1.__read)(_a, 2), p = _b[0], y = _b[1];
49
- return 1 - Math.abs(p - y);
50
- })
51
- .reduce(function (acc, e) { return acc + e; }) / classes;
52
- var currentAccuracy = hits / _this.size;
53
- _this.graphInformant.updateAccuracy(currentAccuracy);
54
- }
55
- else {
56
- throw new TypeError('missing feature/label in dataset');
57
- }
58
- })];
59
- case 2:
60
- _d.sent();
61
- this.logger.success("Obtained validation accuracy of " + this.accuracy());
62
- this.logger.success("Visited " + this.visitedSamples() + " samples");
63
- return [2 /*return*/];
64
- }
65
- });
66
- });
67
- };
68
- Validator.prototype.getModel = function () {
69
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
70
- var _a;
71
- return (0, tslib_1.__generator)(this, function (_b) {
72
- switch (_b.label) {
73
- case 0:
74
- _a = this.source !== undefined;
75
- if (!_a) return [3 /*break*/, 2];
76
- return [4 /*yield*/, this.memory.contains(this.source)];
77
- case 1:
78
- _a = (_b.sent());
79
- _b.label = 2;
80
- case 2:
81
- if (!_a) return [3 /*break*/, 4];
82
- return [4 /*yield*/, this.memory.getModel(this.source)];
83
- case 3: return [2 /*return*/, _b.sent()];
84
- case 4:
85
- if (!(this.client !== undefined)) return [3 /*break*/, 6];
86
- return [4 /*yield*/, this.client.getLatestModel()];
87
- case 5: return [2 /*return*/, _b.sent()];
88
- case 6: throw new Error('cannot identify model');
89
- }
90
- });
91
- });
92
- };
93
- Validator.prototype.accuracyData = function () {
94
- return this.graphInformant.data();
95
- };
96
- Validator.prototype.accuracy = function () {
97
- return this.graphInformant.accuracy();
98
- };
99
- Validator.prototype.visitedSamples = function () {
100
- return this.size;
101
- };
102
- return Validator;
103
- }());
104
- exports.Validator = Validator;
@@ -1,8 +0,0 @@
1
- import { TensorLike, WeightsContainer } from './weights_container';
2
- declare type WeightsLike = Iterable<TensorLike>;
3
- export declare function sum(weights: Iterable<WeightsLike | WeightsContainer>): WeightsContainer;
4
- export declare function diff(weights: Iterable<WeightsLike | WeightsContainer>): WeightsContainer;
5
- export declare function avg(weights: Iterable<WeightsLike | WeightsContainer>): WeightsContainer;
6
- export declare function avgClippingWeights(peersWeights: Iterable<WeightsLike | WeightsContainer>, currentModel: WeightsContainer, tauPercentile: number): WeightsContainer;
7
- export declare function assertWeightsEqual(w1: WeightsContainer, w2: WeightsContainer, epsilon?: number): void;
8
- export {};
@@ -1,96 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.assertWeightsEqual = exports.avgClippingWeights = exports.avg = exports.diff = exports.sum = void 0;
4
- var tslib_1 = require("tslib");
5
- var immutable_1 = require("immutable");
6
- var chai_1 = require("chai");
7
- var __1 = require("..");
8
- var weights_container_1 = require("./weights_container");
9
- function parseWeights(weights) {
10
- var _a;
11
- var r = (0, immutable_1.List)(weights).map(function (w) {
12
- return w instanceof weights_container_1.WeightsContainer ? w : new weights_container_1.WeightsContainer(w);
13
- });
14
- var weightsSize = (_a = r.first()) === null || _a === void 0 ? void 0 : _a.weights.length;
15
- if (weightsSize === undefined) {
16
- throw new Error('no weights to work with');
17
- }
18
- if (r.rest().every(function (w) { return w.weights.length !== weightsSize; })) {
19
- throw new Error('weights dimensions are different for some of the operands');
20
- }
21
- return r;
22
- }
23
- function centerWeights(weights, currentModel) {
24
- return parseWeights(weights).map(function (model) { return model.mapWith(currentModel, __1.tf.sub); });
25
- }
26
- function clipWeights(modelList, normArray, tau) {
27
- return modelList.map(function (weights) { return weights.map(function (w, i) { return __1.tf.prod(w, Math.min(1, tau / (normArray[i]))); }); });
28
- }
29
- function computeQuantile(array, q) {
30
- var sorted = array.sort(function (a, b) { return a - b; });
31
- var pos = (sorted.length - 1) * q;
32
- var base = Math.floor(pos);
33
- var rest = pos - base;
34
- if (sorted[base + 1] !== undefined) {
35
- return sorted[base] + rest * (sorted[base + 1] - sorted[base]);
36
- }
37
- else {
38
- return sorted[base];
39
- }
40
- }
41
- function reduce(weights, fn) {
42
- return parseWeights(weights).reduce(function (acc, ws) {
43
- return new weights_container_1.WeightsContainer(acc.weights.map(function (w, i) {
44
- return fn(w, ws.get(i));
45
- }));
46
- });
47
- }
48
- function sum(weights) {
49
- return reduce(weights, __1.tf.add);
50
- }
51
- exports.sum = sum;
52
- function diff(weights) {
53
- return reduce(weights, __1.tf.sub);
54
- }
55
- exports.diff = diff;
56
- function avg(weights) {
57
- var size = (0, immutable_1.List)(weights).size;
58
- return sum(weights).map(function (ws) { return ws.div(size); });
59
- }
60
- exports.avg = avg;
61
- // See: https://arxiv.org/abs/2012.10333
62
- function avgClippingWeights(peersWeights, currentModel, tauPercentile) {
63
- // Computing the centered peers weights with respect to the previous model aggragation
64
- var centeredPeersWeights = centerWeights(peersWeights, currentModel);
65
- // Computing the Matrix Norm (Frobenius Norm) of the centered peers weights
66
- var normArray = Array.from(centeredPeersWeights.map(function (model) { return model.frobeniusNorm(); }));
67
- // Computing the parameter tau as third percentile with respect to the norm array
68
- var tau = computeQuantile(normArray, tauPercentile);
69
- // Computing the centered clipped peers weights given the norm array and the parameter tau
70
- var centeredMean = clipWeights(centeredPeersWeights, normArray, tau);
71
- // Aggregating all centered clipped peers weights
72
- return avg(centeredMean);
73
- }
74
- exports.avgClippingWeights = avgClippingWeights;
75
- // TODO: implement equal in WeightsContainer
76
- function assertWeightsEqual(w1, w2, epsilon) {
77
- var e_1, _a;
78
- if (epsilon === void 0) { epsilon = 0; }
79
- try {
80
- // Inefficient because we wait for each layer to completely load before we start loading the next layer
81
- // when using tf.Tensor.dataSync() in a for loop. Could be made more efficient by using Promise.all().
82
- // Not worth making more efficient, because this function is only used for testing, where tf.Tensors are small.
83
- for (var _b = (0, tslib_1.__values)(w1.sub(w2).weights), _c = _b.next(); !_c.done; _c = _b.next()) {
84
- var t = _c.value;
85
- chai_1.assert.strictEqual(__1.tf.lessEqual(t.abs(), epsilon).all().dataSync()[0], 1);
86
- }
87
- }
88
- catch (e_1_1) { e_1 = { error: e_1_1 }; }
89
- finally {
90
- try {
91
- if (_c && !_c.done && (_a = _b.return)) _a.call(_b);
92
- }
93
- finally { if (e_1) throw e_1.error; }
94
- }
95
- }
96
- exports.assertWeightsEqual = assertWeightsEqual;
@@ -1,2 +0,0 @@
1
- export { WeightsContainer } from './weights_container';
2
- export * as aggregation from './aggregation';
@@ -1,7 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.aggregation = exports.WeightsContainer = void 0;
4
- var tslib_1 = require("tslib");
5
- var weights_container_1 = require("./weights_container");
6
- Object.defineProperty(exports, "WeightsContainer", { enumerable: true, get: function () { return weights_container_1.WeightsContainer; } });
7
- exports.aggregation = (0, tslib_1.__importStar)(require("./aggregation"));
@@ -1,19 +0,0 @@
1
- import { tf, Weights } from '..';
2
- export declare type TensorLike = tf.Tensor | ArrayLike<number>;
3
- export declare class WeightsContainer {
4
- private readonly _weights;
5
- constructor(weights: Iterable<TensorLike>);
6
- get weights(): Weights;
7
- add(other: WeightsContainer): WeightsContainer;
8
- sub(other: WeightsContainer): WeightsContainer;
9
- mapWith(other: WeightsContainer, fn: (a: tf.Tensor, b: tf.Tensor) => tf.Tensor): WeightsContainer;
10
- map(fn: (t: tf.Tensor, i: number) => tf.Tensor): WeightsContainer;
11
- map(fn: (t: tf.Tensor) => tf.Tensor): WeightsContainer;
12
- reduce(fn: (acc: tf.Tensor, t: tf.Tensor) => tf.Tensor): tf.Tensor;
13
- get(index: number): tf.Tensor | undefined;
14
- frobeniusNorm(): number;
15
- static of(...weights: TensorLike[]): WeightsContainer;
16
- static from(model: tf.LayersModel): WeightsContainer;
17
- static add(a: Iterable<TensorLike>, b: Iterable<TensorLike>): WeightsContainer;
18
- static sub(a: Iterable<TensorLike>, b: Iterable<TensorLike>): WeightsContainer;
19
- }
@@ -1,64 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.WeightsContainer = void 0;
4
- var tslib_1 = require("tslib");
5
- var immutable_1 = require("immutable");
6
- var __1 = require("..");
7
- var WeightsContainer = /** @class */ (function () {
8
- function WeightsContainer(weights) {
9
- this._weights = (0, immutable_1.List)(weights).map(function (w) {
10
- return w instanceof __1.tf.Tensor ? w : __1.tf.tensor(w);
11
- });
12
- }
13
- Object.defineProperty(WeightsContainer.prototype, "weights", {
14
- get: function () {
15
- return this._weights.toArray();
16
- },
17
- enumerable: false,
18
- configurable: true
19
- });
20
- WeightsContainer.prototype.add = function (other) {
21
- return this.mapWith(other, __1.tf.add);
22
- };
23
- WeightsContainer.prototype.sub = function (other) {
24
- return this.mapWith(other, __1.tf.sub);
25
- };
26
- WeightsContainer.prototype.mapWith = function (other, fn) {
27
- return new WeightsContainer(this._weights
28
- .zip(other._weights)
29
- .map(function (_a) {
30
- var _b = (0, tslib_1.__read)(_a, 2), w1 = _b[0], w2 = _b[1];
31
- return fn(w1, w2);
32
- }));
33
- };
34
- WeightsContainer.prototype.map = function (fn) {
35
- return new WeightsContainer(this._weights.map(fn));
36
- };
37
- WeightsContainer.prototype.reduce = function (fn) {
38
- return this._weights.reduce(fn);
39
- };
40
- WeightsContainer.prototype.get = function (index) {
41
- return this._weights.get(index);
42
- };
43
- WeightsContainer.prototype.frobeniusNorm = function () {
44
- return Math.sqrt(this.map(function (w) { return w.square().sum(); }).reduce(function (a, b) { return a.add(b); }).dataSync()[0]);
45
- };
46
- WeightsContainer.of = function () {
47
- var weights = [];
48
- for (var _i = 0; _i < arguments.length; _i++) {
49
- weights[_i] = arguments[_i];
50
- }
51
- return new this(weights);
52
- };
53
- WeightsContainer.from = function (model) {
54
- return new this(model.weights.map(function (w) { return w.read(); }));
55
- };
56
- WeightsContainer.add = function (a, b) {
57
- return new this(a).add(new this(b));
58
- };
59
- WeightsContainer.sub = function (a, b) {
60
- return new this(a).sub(new this(b));
61
- };
62
- return WeightsContainer;
63
- }());
64
- exports.WeightsContainer = WeightsContainer;
package/dist/imports.d.ts DELETED
@@ -1,2 +0,0 @@
1
- export * as data from './dataset/data_loader';
2
- export { IndexedDB } from './memory';
package/dist/imports.js DELETED
@@ -1,7 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.IndexedDB = exports.data = void 0;
4
- var tslib_1 = require("tslib");
5
- exports.data = (0, tslib_1.__importStar)(require("./dataset/data_loader"));
6
- var memory_1 = require("./memory");
7
- Object.defineProperty(exports, "IndexedDB", { enumerable: true, get: function () { return memory_1.IndexedDB; } });
@@ -1,26 +0,0 @@
1
- import { tf, Memory, Path, ModelInfo, ModelSource } from '..';
2
- export declare class IndexedDB extends Memory {
3
- pathFor(source: ModelSource): Path;
4
- infoFor(source: ModelSource): ModelInfo;
5
- getModelMetadata(source: ModelSource): Promise<tf.io.ModelArtifactsInfo | undefined>;
6
- contains(source: ModelSource): Promise<boolean>;
7
- getModel(source: ModelSource): Promise<tf.LayersModel>;
8
- deleteModel(source: ModelSource): Promise<void>;
9
- loadModel(source: ModelSource): Promise<void>;
10
- /**
11
- * Saves the working model to the source.
12
- * @param source the destination
13
- * @param model the model
14
- */
15
- updateWorkingModel(source: ModelSource, model: tf.LayersModel): Promise<void>;
16
- /**
17
- * Creates a saved copy of the working model corresponding to the source.
18
- * @param source the source
19
- */
20
- saveWorkingModel(source: ModelSource): Promise<void>;
21
- /**
22
- * Downloads the model corresponding to the source.
23
- * @param source the source
24
- */
25
- downloadModel(source: ModelSource): Promise<void>;
26
- }