@epfml/discojs 2.0.0 → 2.1.2-p20240506085037.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (334) hide show
  1. package/dist/aggregator/base.d.ts +180 -0
  2. package/dist/aggregator/base.js +236 -0
  3. package/dist/aggregator/get.d.ts +16 -0
  4. package/dist/aggregator/get.js +31 -0
  5. package/dist/aggregator/index.d.ts +7 -0
  6. package/dist/aggregator/index.js +4 -0
  7. package/dist/aggregator/mean.d.ts +23 -0
  8. package/dist/aggregator/mean.js +69 -0
  9. package/dist/aggregator/secure.d.ts +27 -0
  10. package/dist/aggregator/secure.js +91 -0
  11. package/dist/async_informant.d.ts +15 -0
  12. package/dist/async_informant.js +42 -0
  13. package/dist/client/base.d.ts +76 -0
  14. package/dist/client/base.js +88 -0
  15. package/dist/client/decentralized/base.d.ts +32 -0
  16. package/dist/client/decentralized/base.js +192 -0
  17. package/dist/client/decentralized/index.d.ts +2 -0
  18. package/dist/client/decentralized/index.js +2 -0
  19. package/dist/client/decentralized/messages.d.ts +28 -0
  20. package/dist/client/decentralized/messages.js +44 -0
  21. package/dist/client/decentralized/peer.d.ts +40 -0
  22. package/dist/client/decentralized/peer.js +189 -0
  23. package/dist/client/decentralized/peer_pool.d.ts +12 -0
  24. package/dist/client/decentralized/peer_pool.js +44 -0
  25. package/dist/client/event_connection.d.ts +34 -0
  26. package/dist/client/event_connection.js +105 -0
  27. package/dist/client/federated/base.d.ts +54 -0
  28. package/dist/client/federated/base.js +151 -0
  29. package/dist/client/federated/index.d.ts +2 -0
  30. package/dist/client/federated/index.js +2 -0
  31. package/dist/client/federated/messages.d.ts +30 -0
  32. package/dist/client/federated/messages.js +24 -0
  33. package/dist/client/index.d.ts +8 -0
  34. package/dist/client/index.js +8 -0
  35. package/dist/client/local.d.ts +3 -0
  36. package/dist/client/local.js +3 -0
  37. package/dist/client/messages.d.ts +30 -0
  38. package/dist/client/messages.js +26 -0
  39. package/dist/client/types.d.ts +2 -0
  40. package/dist/client/types.js +4 -0
  41. package/dist/client/utils.d.ts +2 -0
  42. package/dist/client/utils.js +7 -0
  43. package/dist/dataset/data/data.d.ts +48 -0
  44. package/dist/dataset/data/data.js +72 -0
  45. package/dist/dataset/data/data_split.d.ts +8 -0
  46. package/dist/dataset/data/data_split.js +1 -0
  47. package/dist/dataset/data/image_data.d.ts +11 -0
  48. package/dist/dataset/data/image_data.js +38 -0
  49. package/dist/dataset/data/index.d.ts +6 -0
  50. package/dist/dataset/data/index.js +5 -0
  51. package/dist/dataset/data/preprocessing/base.d.ts +16 -0
  52. package/dist/dataset/data/preprocessing/base.js +1 -0
  53. package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +13 -0
  54. package/dist/dataset/data/preprocessing/image_preprocessing.js +40 -0
  55. package/dist/dataset/data/preprocessing/index.d.ts +4 -0
  56. package/dist/dataset/data/preprocessing/index.js +3 -0
  57. package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +13 -0
  58. package/dist/dataset/data/preprocessing/tabular_preprocessing.js +45 -0
  59. package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +13 -0
  60. package/dist/dataset/data/preprocessing/text_preprocessing.js +85 -0
  61. package/dist/dataset/data/tabular_data.d.ts +11 -0
  62. package/dist/dataset/data/tabular_data.js +25 -0
  63. package/dist/dataset/data/text_data.d.ts +11 -0
  64. package/dist/dataset/data/text_data.js +14 -0
  65. package/dist/{core/dataset → dataset}/data_loader/data_loader.d.ts +3 -5
  66. package/dist/dataset/data_loader/data_loader.js +2 -0
  67. package/dist/dataset/data_loader/image_loader.d.ts +20 -3
  68. package/dist/dataset/data_loader/image_loader.js +98 -23
  69. package/dist/dataset/data_loader/index.d.ts +5 -2
  70. package/dist/dataset/data_loader/index.js +4 -7
  71. package/dist/dataset/data_loader/tabular_loader.d.ts +34 -3
  72. package/dist/dataset/data_loader/tabular_loader.js +75 -15
  73. package/dist/dataset/data_loader/text_loader.d.ts +14 -0
  74. package/dist/dataset/data_loader/text_loader.js +25 -0
  75. package/dist/dataset/dataset.d.ts +5 -0
  76. package/dist/dataset/dataset.js +1 -0
  77. package/dist/dataset/dataset_builder.d.ts +60 -0
  78. package/dist/dataset/dataset_builder.js +142 -0
  79. package/dist/dataset/index.d.ts +5 -0
  80. package/dist/dataset/index.js +3 -0
  81. package/dist/default_tasks/cifar10/index.d.ts +2 -0
  82. package/dist/default_tasks/cifar10/index.js +60 -0
  83. package/dist/default_tasks/cifar10/model.d.ts +434 -0
  84. package/dist/default_tasks/cifar10/model.js +2385 -0
  85. package/dist/default_tasks/geotags/index.d.ts +2 -0
  86. package/dist/default_tasks/geotags/index.js +65 -0
  87. package/dist/default_tasks/geotags/model.d.ts +593 -0
  88. package/dist/default_tasks/geotags/model.js +4715 -0
  89. package/dist/default_tasks/index.d.ts +8 -0
  90. package/dist/default_tasks/index.js +8 -0
  91. package/dist/default_tasks/lus_covid.d.ts +2 -0
  92. package/dist/default_tasks/lus_covid.js +89 -0
  93. package/dist/default_tasks/mnist.d.ts +2 -0
  94. package/dist/default_tasks/mnist.js +61 -0
  95. package/dist/default_tasks/simple_face/index.d.ts +2 -0
  96. package/dist/default_tasks/simple_face/index.js +48 -0
  97. package/dist/default_tasks/simple_face/model.d.ts +513 -0
  98. package/dist/default_tasks/simple_face/model.js +4301 -0
  99. package/dist/default_tasks/skin_mnist.d.ts +2 -0
  100. package/dist/default_tasks/skin_mnist.js +80 -0
  101. package/dist/default_tasks/titanic.d.ts +2 -0
  102. package/dist/default_tasks/titanic.js +88 -0
  103. package/dist/default_tasks/wikitext.d.ts +2 -0
  104. package/dist/default_tasks/wikitext.js +38 -0
  105. package/dist/index.d.ts +18 -2
  106. package/dist/index.js +18 -6
  107. package/dist/{core/informant → informant}/graph_informant.d.ts +1 -1
  108. package/dist/informant/graph_informant.js +20 -0
  109. package/dist/informant/index.d.ts +1 -0
  110. package/dist/informant/index.js +1 -0
  111. package/dist/{core/logging → logging}/console_logger.d.ts +2 -2
  112. package/dist/logging/console_logger.js +22 -0
  113. package/dist/logging/index.d.ts +2 -0
  114. package/dist/logging/index.js +1 -0
  115. package/dist/{core/logging → logging}/logger.d.ts +3 -3
  116. package/dist/logging/logger.js +1 -0
  117. package/dist/memory/base.d.ts +119 -0
  118. package/dist/memory/base.js +9 -0
  119. package/dist/memory/empty.d.ts +20 -0
  120. package/dist/memory/empty.js +43 -0
  121. package/dist/memory/index.d.ts +3 -1
  122. package/dist/memory/index.js +3 -5
  123. package/dist/memory/model_type.d.ts +9 -0
  124. package/dist/memory/model_type.js +10 -0
  125. package/dist/{core/privacy.d.ts → privacy.d.ts} +1 -1
  126. package/dist/{core/privacy.js → privacy.js} +11 -16
  127. package/dist/serialization/index.d.ts +2 -0
  128. package/dist/serialization/index.js +2 -0
  129. package/dist/serialization/model.d.ts +5 -0
  130. package/dist/serialization/model.js +67 -0
  131. package/dist/{core/serialization → serialization}/weights.d.ts +2 -2
  132. package/dist/serialization/weights.js +37 -0
  133. package/dist/task/data_example.js +14 -0
  134. package/dist/task/digest.d.ts +5 -0
  135. package/dist/task/digest.js +14 -0
  136. package/dist/{core/task → task}/display_information.d.ts +5 -3
  137. package/dist/task/display_information.js +46 -0
  138. package/dist/task/index.d.ts +7 -0
  139. package/dist/task/index.js +5 -0
  140. package/dist/task/label_type.d.ts +9 -0
  141. package/dist/task/label_type.js +28 -0
  142. package/dist/task/summary.js +13 -0
  143. package/dist/task/task.d.ts +12 -0
  144. package/dist/task/task.js +22 -0
  145. package/dist/task/task_handler.d.ts +5 -0
  146. package/dist/task/task_handler.js +20 -0
  147. package/dist/task/task_provider.d.ts +5 -0
  148. package/dist/task/task_provider.js +1 -0
  149. package/dist/{core/task → task}/training_information.d.ts +9 -10
  150. package/dist/task/training_information.js +88 -0
  151. package/dist/training/disco.d.ts +40 -0
  152. package/dist/training/disco.js +107 -0
  153. package/dist/training/index.d.ts +2 -0
  154. package/dist/training/index.js +1 -0
  155. package/dist/training/trainer/distributed_trainer.d.ts +20 -0
  156. package/dist/training/trainer/distributed_trainer.js +36 -0
  157. package/dist/training/trainer/local_trainer.d.ts +12 -0
  158. package/dist/training/trainer/local_trainer.js +19 -0
  159. package/dist/training/trainer/trainer.d.ts +33 -0
  160. package/dist/training/trainer/trainer.js +52 -0
  161. package/dist/{core/training → training}/trainer/trainer_builder.d.ts +5 -7
  162. package/dist/training/trainer/trainer_builder.js +43 -0
  163. package/dist/types.d.ts +8 -0
  164. package/dist/types.js +1 -0
  165. package/dist/utils/event_emitter.d.ts +40 -0
  166. package/dist/utils/event_emitter.js +57 -0
  167. package/dist/validation/index.d.ts +1 -0
  168. package/dist/validation/index.js +1 -0
  169. package/dist/validation/validator.d.ts +28 -0
  170. package/dist/validation/validator.js +132 -0
  171. package/dist/weights/aggregation.d.ts +21 -0
  172. package/dist/weights/aggregation.js +44 -0
  173. package/dist/weights/index.d.ts +2 -0
  174. package/dist/weights/index.js +2 -0
  175. package/dist/weights/weights_container.d.ts +68 -0
  176. package/dist/weights/weights_container.js +96 -0
  177. package/package.json +25 -16
  178. package/README.md +0 -53
  179. package/dist/core/async_buffer.d.ts +0 -41
  180. package/dist/core/async_buffer.js +0 -97
  181. package/dist/core/async_informant.d.ts +0 -20
  182. package/dist/core/async_informant.js +0 -69
  183. package/dist/core/client/base.d.ts +0 -33
  184. package/dist/core/client/base.js +0 -35
  185. package/dist/core/client/decentralized/base.d.ts +0 -32
  186. package/dist/core/client/decentralized/base.js +0 -212
  187. package/dist/core/client/decentralized/clear_text.d.ts +0 -14
  188. package/dist/core/client/decentralized/clear_text.js +0 -96
  189. package/dist/core/client/decentralized/index.d.ts +0 -4
  190. package/dist/core/client/decentralized/index.js +0 -9
  191. package/dist/core/client/decentralized/messages.d.ts +0 -41
  192. package/dist/core/client/decentralized/messages.js +0 -54
  193. package/dist/core/client/decentralized/peer.d.ts +0 -26
  194. package/dist/core/client/decentralized/peer.js +0 -210
  195. package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
  196. package/dist/core/client/decentralized/peer_pool.js +0 -92
  197. package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
  198. package/dist/core/client/decentralized/sec_agg.js +0 -190
  199. package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
  200. package/dist/core/client/decentralized/secret_shares.js +0 -39
  201. package/dist/core/client/decentralized/types.d.ts +0 -2
  202. package/dist/core/client/decentralized/types.js +0 -7
  203. package/dist/core/client/event_connection.d.ts +0 -37
  204. package/dist/core/client/event_connection.js +0 -158
  205. package/dist/core/client/federated/client.d.ts +0 -37
  206. package/dist/core/client/federated/client.js +0 -273
  207. package/dist/core/client/federated/index.d.ts +0 -2
  208. package/dist/core/client/federated/index.js +0 -7
  209. package/dist/core/client/federated/messages.d.ts +0 -38
  210. package/dist/core/client/federated/messages.js +0 -25
  211. package/dist/core/client/index.d.ts +0 -5
  212. package/dist/core/client/index.js +0 -11
  213. package/dist/core/client/local.d.ts +0 -8
  214. package/dist/core/client/local.js +0 -36
  215. package/dist/core/client/messages.d.ts +0 -28
  216. package/dist/core/client/messages.js +0 -33
  217. package/dist/core/client/utils.d.ts +0 -2
  218. package/dist/core/client/utils.js +0 -19
  219. package/dist/core/dataset/data/data.d.ts +0 -11
  220. package/dist/core/dataset/data/data.js +0 -20
  221. package/dist/core/dataset/data/data_split.d.ts +0 -5
  222. package/dist/core/dataset/data/data_split.js +0 -2
  223. package/dist/core/dataset/data/image_data.d.ts +0 -8
  224. package/dist/core/dataset/data/image_data.js +0 -64
  225. package/dist/core/dataset/data/index.d.ts +0 -5
  226. package/dist/core/dataset/data/index.js +0 -11
  227. package/dist/core/dataset/data/preprocessing.d.ts +0 -13
  228. package/dist/core/dataset/data/preprocessing.js +0 -33
  229. package/dist/core/dataset/data/tabular_data.d.ts +0 -8
  230. package/dist/core/dataset/data/tabular_data.js +0 -40
  231. package/dist/core/dataset/data_loader/data_loader.js +0 -10
  232. package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
  233. package/dist/core/dataset/data_loader/image_loader.js +0 -141
  234. package/dist/core/dataset/data_loader/index.d.ts +0 -3
  235. package/dist/core/dataset/data_loader/index.js +0 -9
  236. package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
  237. package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
  238. package/dist/core/dataset/dataset.d.ts +0 -2
  239. package/dist/core/dataset/dataset.js +0 -2
  240. package/dist/core/dataset/dataset_builder.d.ts +0 -18
  241. package/dist/core/dataset/dataset_builder.js +0 -96
  242. package/dist/core/dataset/index.d.ts +0 -4
  243. package/dist/core/dataset/index.js +0 -14
  244. package/dist/core/index.d.ts +0 -18
  245. package/dist/core/index.js +0 -41
  246. package/dist/core/informant/graph_informant.js +0 -23
  247. package/dist/core/informant/index.d.ts +0 -3
  248. package/dist/core/informant/index.js +0 -9
  249. package/dist/core/informant/training_informant/base.d.ts +0 -31
  250. package/dist/core/informant/training_informant/base.js +0 -83
  251. package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
  252. package/dist/core/informant/training_informant/decentralized.js +0 -22
  253. package/dist/core/informant/training_informant/federated.d.ts +0 -14
  254. package/dist/core/informant/training_informant/federated.js +0 -32
  255. package/dist/core/informant/training_informant/index.d.ts +0 -4
  256. package/dist/core/informant/training_informant/index.js +0 -11
  257. package/dist/core/informant/training_informant/local.d.ts +0 -6
  258. package/dist/core/informant/training_informant/local.js +0 -20
  259. package/dist/core/logging/console_logger.js +0 -33
  260. package/dist/core/logging/index.d.ts +0 -3
  261. package/dist/core/logging/index.js +0 -9
  262. package/dist/core/logging/logger.js +0 -9
  263. package/dist/core/logging/trainer_logger.d.ts +0 -24
  264. package/dist/core/logging/trainer_logger.js +0 -59
  265. package/dist/core/memory/base.d.ts +0 -22
  266. package/dist/core/memory/base.js +0 -9
  267. package/dist/core/memory/empty.d.ts +0 -14
  268. package/dist/core/memory/empty.js +0 -75
  269. package/dist/core/memory/index.d.ts +0 -3
  270. package/dist/core/memory/index.js +0 -9
  271. package/dist/core/memory/model_type.d.ts +0 -4
  272. package/dist/core/memory/model_type.js +0 -9
  273. package/dist/core/serialization/index.d.ts +0 -2
  274. package/dist/core/serialization/index.js +0 -6
  275. package/dist/core/serialization/model.d.ts +0 -5
  276. package/dist/core/serialization/model.js +0 -55
  277. package/dist/core/serialization/weights.js +0 -64
  278. package/dist/core/task/data_example.js +0 -24
  279. package/dist/core/task/display_information.js +0 -49
  280. package/dist/core/task/index.d.ts +0 -3
  281. package/dist/core/task/index.js +0 -8
  282. package/dist/core/task/model_compile_data.d.ts +0 -6
  283. package/dist/core/task/model_compile_data.js +0 -22
  284. package/dist/core/task/summary.js +0 -19
  285. package/dist/core/task/task.d.ts +0 -10
  286. package/dist/core/task/task.js +0 -31
  287. package/dist/core/task/training_information.js +0 -66
  288. package/dist/core/tasks/cifar10.d.ts +0 -3
  289. package/dist/core/tasks/cifar10.js +0 -65
  290. package/dist/core/tasks/geotags.d.ts +0 -3
  291. package/dist/core/tasks/geotags.js +0 -67
  292. package/dist/core/tasks/index.d.ts +0 -6
  293. package/dist/core/tasks/index.js +0 -10
  294. package/dist/core/tasks/lus_covid.d.ts +0 -3
  295. package/dist/core/tasks/lus_covid.js +0 -87
  296. package/dist/core/tasks/mnist.d.ts +0 -3
  297. package/dist/core/tasks/mnist.js +0 -60
  298. package/dist/core/tasks/simple_face.d.ts +0 -2
  299. package/dist/core/tasks/simple_face.js +0 -41
  300. package/dist/core/tasks/titanic.d.ts +0 -3
  301. package/dist/core/tasks/titanic.js +0 -88
  302. package/dist/core/training/disco.d.ts +0 -23
  303. package/dist/core/training/disco.js +0 -130
  304. package/dist/core/training/index.d.ts +0 -2
  305. package/dist/core/training/index.js +0 -7
  306. package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
  307. package/dist/core/training/trainer/distributed_trainer.js +0 -65
  308. package/dist/core/training/trainer/local_trainer.d.ts +0 -11
  309. package/dist/core/training/trainer/local_trainer.js +0 -34
  310. package/dist/core/training/trainer/round_tracker.d.ts +0 -30
  311. package/dist/core/training/trainer/round_tracker.js +0 -47
  312. package/dist/core/training/trainer/trainer.d.ts +0 -65
  313. package/dist/core/training/trainer/trainer.js +0 -160
  314. package/dist/core/training/trainer/trainer_builder.js +0 -95
  315. package/dist/core/training/training_schemes.d.ts +0 -5
  316. package/dist/core/training/training_schemes.js +0 -10
  317. package/dist/core/types.d.ts +0 -4
  318. package/dist/core/types.js +0 -2
  319. package/dist/core/validation/index.d.ts +0 -1
  320. package/dist/core/validation/index.js +0 -5
  321. package/dist/core/validation/validator.d.ts +0 -17
  322. package/dist/core/validation/validator.js +0 -104
  323. package/dist/core/weights/aggregation.d.ts +0 -8
  324. package/dist/core/weights/aggregation.js +0 -96
  325. package/dist/core/weights/index.d.ts +0 -2
  326. package/dist/core/weights/index.js +0 -7
  327. package/dist/core/weights/weights_container.d.ts +0 -19
  328. package/dist/core/weights/weights_container.js +0 -64
  329. package/dist/imports.d.ts +0 -2
  330. package/dist/imports.js +0 -7
  331. package/dist/memory/memory.d.ts +0 -26
  332. package/dist/memory/memory.js +0 -160
  333. package/dist/{core/task → task}/data_example.d.ts +1 -1
  334. package/dist/{core/task → task}/summary.d.ts +1 -1
@@ -1,273 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.Client = void 0;
4
- var tslib_1 = require("tslib");
5
- var uuid_1 = require("uuid");
6
- var __1 = require("../..");
7
- var base_1 = require("../base");
8
- var messages = (0, tslib_1.__importStar)(require("./messages"));
9
- var messages_1 = require("../messages");
10
- var nodeUrl = (0, tslib_1.__importStar)(require("url"));
11
- var event_connection_1 = require("../event_connection");
12
- var utils_1 = require("../utils");
13
- /**
14
- * Class that deals with communication with the centralized server when training
15
- * a specific task in the federated setting.
16
- */
17
- var Client = /** @class */ (function (_super) {
18
- (0, tslib_1.__extends)(Client, _super);
19
- function Client() {
20
- var _this = _super !== null && _super.apply(this, arguments) || this;
21
- _this.clientID = (0, uuid_1.v4)();
22
- _this.round = 0;
23
- return _this;
24
- }
25
- Object.defineProperty(Client.prototype, "server", {
26
- get: function () {
27
- if (this._server === undefined) {
28
- throw new Error('server undefined, not connected');
29
- }
30
- return this._server;
31
- },
32
- enumerable: false,
33
- configurable: true
34
- });
35
- // It opens a new WebSocket connection and listens to new messages over the channel
36
- Client.prototype.connectServer = function (url) {
37
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
38
- var server;
39
- return (0, tslib_1.__generator)(this, function (_a) {
40
- switch (_a.label) {
41
- case 0: return [4 /*yield*/, event_connection_1.WebSocketServer.connect(url, messages.isMessageFederated, messages.isMessageFederated)];
42
- case 1:
43
- server = _a.sent();
44
- return [2 /*return*/, server];
45
- }
46
- });
47
- });
48
- };
49
- /**
50
- * Initialize the connection to the server. TODO: In the case of FeAI,
51
- * should return the current server-side round for the task.
52
- */
53
- Client.prototype.connect = function () {
54
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
55
- var URL, serverURL, _a, msg;
56
- return (0, tslib_1.__generator)(this, function (_b) {
57
- switch (_b.label) {
58
- case 0:
59
- URL = typeof window !== 'undefined' ? window.URL : nodeUrl.URL;
60
- serverURL = new URL('', this.url.href);
61
- switch (this.url.protocol) {
62
- case 'http:':
63
- serverURL.protocol = 'ws:';
64
- break;
65
- case 'https:':
66
- serverURL.protocol = 'wss:';
67
- break;
68
- default:
69
- throw new Error("unknown protocol: " + this.url.protocol);
70
- }
71
- serverURL.pathname += "feai/" + this.task.taskID + "/" + this.clientID;
72
- _a = this;
73
- return [4 /*yield*/, this.connectServer(serverURL)];
74
- case 1:
75
- _a._server = _b.sent();
76
- msg = {
77
- type: messages_1.type.clientConnected
78
- };
79
- this.server.send(msg);
80
- return [4 /*yield*/, (0, event_connection_1.waitMessageWithTimeout)(this.server, messages_1.type.clientConnected, utils_1.MAX_WAIT_PER_ROUND)];
81
- case 2:
82
- _b.sent();
83
- this.connected = true;
84
- return [2 /*return*/];
85
- }
86
- });
87
- });
88
- };
89
- /**
90
- * Disconnection process when user quits the task.
91
- */
92
- Client.prototype.disconnect = function () {
93
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
94
- return (0, tslib_1.__generator)(this, function (_a) {
95
- this.server.disconnect();
96
- this._server = undefined;
97
- this.connected = false;
98
- return [2 /*return*/];
99
- });
100
- });
101
- };
102
- // It sends a message to the server
103
- Client.prototype.sendMessage = function (msg) {
104
- var _a;
105
- (_a = this.server) === null || _a === void 0 ? void 0 : _a.send(msg);
106
- };
107
- // It sends weights to the server
108
- Client.prototype.postWeightsToServer = function (weights) {
109
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
110
- var msg;
111
- var _a;
112
- return (0, tslib_1.__generator)(this, function (_b) {
113
- switch (_b.label) {
114
- case 0:
115
- _a = {
116
- type: messages_1.type.postWeightsToServer
117
- };
118
- return [4 /*yield*/, __1.serialization.weights.encode(weights)];
119
- case 1:
120
- msg = (_a.weights = _b.sent(),
121
- _a.round = this.round,
122
- _a);
123
- this.sendMessage(msg);
124
- return [2 /*return*/];
125
- }
126
- });
127
- });
128
- };
129
- // It retrieves the last server round and weights, but return only the server round
130
- Client.prototype.getLatestServerRound = function () {
131
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
132
- var msg, received;
133
- return (0, tslib_1.__generator)(this, function (_a) {
134
- switch (_a.label) {
135
- case 0:
136
- this.serverRound = undefined;
137
- this.serverWeights = undefined;
138
- msg = {
139
- type: messages_1.type.latestServerRound
140
- };
141
- this.sendMessage(msg);
142
- return [4 /*yield*/, (0, event_connection_1.waitMessageWithTimeout)(this.server, messages_1.type.latestServerRound, utils_1.MAX_WAIT_PER_ROUND)];
143
- case 1:
144
- received = _a.sent();
145
- this.serverRound = received.round;
146
- this.serverWeights = __1.serialization.weights.decode(received.weights);
147
- return [2 /*return*/, this.serverRound];
148
- }
149
- });
150
- });
151
- };
152
- // It retrieves the last server round and weights, but return only the server weights
153
- Client.prototype.pullRoundAndFetchWeights = function () {
154
- var _a;
155
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
156
- return (0, tslib_1.__generator)(this, function (_b) {
157
- switch (_b.label) {
158
- case 0:
159
- // get server round of latest model
160
- return [4 /*yield*/, this.getLatestServerRound()];
161
- case 1:
162
- // get server round of latest model
163
- _b.sent();
164
- if (this.round < ((_a = this.serverRound) !== null && _a !== void 0 ? _a : 0)) {
165
- // Update the local round to match the server's
166
- this.round = this.serverRound;
167
- return [2 /*return*/, this.serverWeights];
168
- }
169
- else {
170
- return [2 /*return*/, undefined];
171
- }
172
- return [2 /*return*/];
173
- }
174
- });
175
- });
176
- };
177
- // It pulls statistics from the server
178
- Client.prototype.pullServerStatistics = function (trainingInformant) {
179
- var _a;
180
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
181
- var msg, received;
182
- return (0, tslib_1.__generator)(this, function (_b) {
183
- switch (_b.label) {
184
- case 0:
185
- this.receivedStatistics = undefined;
186
- msg = {
187
- type: messages_1.type.pullServerStatistics
188
- };
189
- this.sendMessage(msg);
190
- return [4 /*yield*/, (0, event_connection_1.waitMessageWithTimeout)(this.server, messages_1.type.pullServerStatistics, utils_1.MAX_WAIT_PER_ROUND)];
191
- case 1:
192
- received = _b.sent();
193
- this.receivedStatistics = received.statistics;
194
- trainingInformant.update((_a = this.receivedStatistics) !== null && _a !== void 0 ? _a : {});
195
- return [2 /*return*/];
196
- }
197
- });
198
- });
199
- };
200
- // It posts a new metadata value to the server
201
- Client.prototype.postMetadata = function (metadataID, metadata) {
202
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
203
- var msg;
204
- return (0, tslib_1.__generator)(this, function (_a) {
205
- msg = {
206
- type: messages_1.type.postMetadata,
207
- taskId: this.task.taskID,
208
- clientId: this.clientID,
209
- round: this.round,
210
- metadataId: metadataID,
211
- metadata: metadata
212
- };
213
- this.sendMessage(msg);
214
- return [2 /*return*/];
215
- });
216
- });
217
- };
218
- // It gets a metadata map from the server
219
- Client.prototype.getMetadataMap = function (metadataId) {
220
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
221
- var msg, received;
222
- return (0, tslib_1.__generator)(this, function (_a) {
223
- switch (_a.label) {
224
- case 0:
225
- this.metadataMap = undefined;
226
- msg = {
227
- type: messages_1.type.getMetadataMap,
228
- taskId: this.task.taskID,
229
- clientId: this.clientID,
230
- round: this.round,
231
- metadataId: metadataId
232
- };
233
- this.sendMessage(msg);
234
- return [4 /*yield*/, (0, event_connection_1.waitMessageWithTimeout)(this.server, messages_1.type.getMetadataMap, utils_1.MAX_WAIT_PER_ROUND)];
235
- case 1:
236
- received = _a.sent();
237
- if (received.metadataMap !== undefined) {
238
- this.metadataMap = new Map(received.metadataMap);
239
- }
240
- return [2 /*return*/, this.metadataMap];
241
- }
242
- });
243
- });
244
- };
245
- Client.prototype.onRoundEndCommunication = function (updatedWeights, staleWeights, _, trainingInformant) {
246
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
247
- var noisyWeights, serverWeights;
248
- return (0, tslib_1.__generator)(this, function (_a) {
249
- switch (_a.label) {
250
- case 0:
251
- noisyWeights = __1.privacy.addDifferentialPrivacy(updatedWeights, staleWeights, this.task);
252
- return [4 /*yield*/, this.postWeightsToServer(noisyWeights)];
253
- case 1:
254
- _a.sent();
255
- return [4 /*yield*/, this.pullServerStatistics(trainingInformant)];
256
- case 2:
257
- _a.sent();
258
- return [4 /*yield*/, this.pullRoundAndFetchWeights()];
259
- case 3:
260
- serverWeights = _a.sent();
261
- return [2 /*return*/, serverWeights !== null && serverWeights !== void 0 ? serverWeights : staleWeights];
262
- }
263
- });
264
- });
265
- };
266
- Client.prototype.onTrainEndCommunication = function () {
267
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
268
- return [2 /*return*/];
269
- }); });
270
- };
271
- return Client;
272
- }(base_1.Base));
273
- exports.Client = Client;
@@ -1,2 +0,0 @@
1
- export { Client } from './client';
2
- export * as messages from './messages';
@@ -1,7 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.messages = exports.Client = void 0;
4
- var tslib_1 = require("tslib");
5
- var client_1 = require("./client");
6
- Object.defineProperty(exports, "Client", { enumerable: true, get: function () { return client_1.Client; } });
7
- exports.messages = (0, tslib_1.__importStar)(require("./messages"));
@@ -1,38 +0,0 @@
1
- import { MetadataID } from '../..';
2
- import { weights } from '../../serialization';
3
- import { type } from '../messages';
4
- export declare type MessageFederated = postWeightsToServer | latestServerRound | pullServerStatistics | postMetadata | getMetadataMap | messageGeneral;
5
- export interface messageGeneral {
6
- type: type;
7
- }
8
- export interface postWeightsToServer {
9
- type: type.postWeightsToServer;
10
- weights: weights.Encoded;
11
- round: number;
12
- }
13
- export interface latestServerRound {
14
- type: type.latestServerRound;
15
- weights: weights.Encoded;
16
- round: number;
17
- }
18
- export interface pullServerStatistics {
19
- type: type.pullServerStatistics;
20
- statistics: Record<string, number>;
21
- }
22
- export interface postMetadata {
23
- type: type.postMetadata;
24
- clientId: string;
25
- taskId: string;
26
- round: number;
27
- metadataId: string;
28
- metadata: string;
29
- }
30
- export interface getMetadataMap {
31
- type: type.getMetadataMap;
32
- clientId: string;
33
- taskId: string;
34
- round: number;
35
- metadataId: MetadataID;
36
- metadataMap?: Array<[string, string | undefined]>;
37
- }
38
- export declare function isMessageFederated(o: unknown): o is MessageFederated;
@@ -1,25 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.isMessageFederated = void 0;
4
- var messages_1 = require("../messages");
5
- function isMessageFederated(o) {
6
- if (!(0, messages_1.hasMessageType)(o)) {
7
- return false;
8
- }
9
- switch (o.type) {
10
- case messages_1.type.clientConnected:
11
- return true;
12
- case messages_1.type.postWeightsToServer:
13
- return true;
14
- case messages_1.type.latestServerRound:
15
- return true;
16
- case messages_1.type.pullServerStatistics:
17
- return true;
18
- case messages_1.type.postMetadata:
19
- return true;
20
- case messages_1.type.getMetadataMap:
21
- return true;
22
- }
23
- return false;
24
- }
25
- exports.isMessageFederated = isMessageFederated;
@@ -1,5 +0,0 @@
1
- export { Base } from './base';
2
- export * as decentralized from './decentralized';
3
- export * as federated from './federated';
4
- export * as messages from './messages';
5
- export { Local } from './local';
@@ -1,11 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.Local = exports.messages = exports.federated = exports.decentralized = exports.Base = void 0;
4
- var tslib_1 = require("tslib");
5
- var base_1 = require("./base");
6
- Object.defineProperty(exports, "Base", { enumerable: true, get: function () { return base_1.Base; } });
7
- exports.decentralized = (0, tslib_1.__importStar)(require("./decentralized"));
8
- exports.federated = (0, tslib_1.__importStar)(require("./federated"));
9
- exports.messages = (0, tslib_1.__importStar)(require("./messages"));
10
- var local_1 = require("./local");
11
- Object.defineProperty(exports, "Local", { enumerable: true, get: function () { return local_1.Local; } });
@@ -1,8 +0,0 @@
1
- import { WeightsContainer } from '..';
2
- import { Base } from './base';
3
- export declare class Local extends Base {
4
- connect(): Promise<void>;
5
- disconnect(): Promise<void>;
6
- onRoundEndCommunication(_: WeightsContainer): Promise<WeightsContainer>;
7
- onTrainEndCommunication(): Promise<void>;
8
- }
@@ -1,36 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.Local = void 0;
4
- var tslib_1 = require("tslib");
5
- var base_1 = require("./base");
6
- // does pretty much nothing
7
- var Local = /** @class */ (function (_super) {
8
- (0, tslib_1.__extends)(Local, _super);
9
- function Local() {
10
- return _super !== null && _super.apply(this, arguments) || this;
11
- }
12
- Local.prototype.connect = function () {
13
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
14
- return [2 /*return*/];
15
- }); });
16
- };
17
- Local.prototype.disconnect = function () {
18
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
19
- return [2 /*return*/];
20
- }); });
21
- };
22
- Local.prototype.onRoundEndCommunication = function (_) {
23
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
24
- return (0, tslib_1.__generator)(this, function (_a) {
25
- return [2 /*return*/, _];
26
- });
27
- });
28
- };
29
- Local.prototype.onTrainEndCommunication = function () {
30
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
31
- return [2 /*return*/];
32
- }); });
33
- };
34
- return Local;
35
- }(base_1.Base));
36
- exports.Local = Local;
@@ -1,28 +0,0 @@
1
- import * as decentralized from './decentralized/messages';
2
- import * as federated from './federated/messages';
3
- export declare enum type {
4
- clientConnected = 0,
5
- PeerID = 1,
6
- SignalForPeer = 2,
7
- PeerIsReady = 3,
8
- PeersForRound = 4,
9
- Weights = 5,
10
- Shares = 6,
11
- PartialSums = 7,
12
- postWeightsToServer = 8,
13
- postMetadata = 9,
14
- getMetadataMap = 10,
15
- latestServerRound = 11,
16
- pullRoundAndFetchWeights = 12,
17
- pullServerStatistics = 13
18
- }
19
- export interface clientConnected {
20
- type: type.clientConnected;
21
- }
22
- export declare type Message = decentralized.MessageFromServer | decentralized.MessageToServer | decentralized.PeerMessage | federated.MessageFederated;
23
- export declare type NarrowMessage<D> = Extract<Message, {
24
- type: D;
25
- }>;
26
- export declare function hasMessageType(raw: unknown): raw is {
27
- type: type;
28
- } & Record<string, unknown>;
@@ -1,33 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.hasMessageType = exports.type = void 0;
4
- var type;
5
- (function (type) {
6
- type[type["clientConnected"] = 0] = "clientConnected";
7
- // decentralized
8
- type[type["PeerID"] = 1] = "PeerID";
9
- type[type["SignalForPeer"] = 2] = "SignalForPeer";
10
- type[type["PeerIsReady"] = 3] = "PeerIsReady";
11
- type[type["PeersForRound"] = 4] = "PeersForRound";
12
- type[type["Weights"] = 5] = "Weights";
13
- type[type["Shares"] = 6] = "Shares";
14
- type[type["PartialSums"] = 7] = "PartialSums";
15
- // federated
16
- type[type["postWeightsToServer"] = 8] = "postWeightsToServer";
17
- type[type["postMetadata"] = 9] = "postMetadata";
18
- type[type["getMetadataMap"] = 10] = "getMetadataMap";
19
- type[type["latestServerRound"] = 11] = "latestServerRound";
20
- type[type["pullRoundAndFetchWeights"] = 12] = "pullRoundAndFetchWeights";
21
- type[type["pullServerStatistics"] = 13] = "pullServerStatistics";
22
- })(type = exports.type || (exports.type = {}));
23
- function hasMessageType(raw) {
24
- if (typeof raw !== 'object' || raw === null) {
25
- return false;
26
- }
27
- var o = raw;
28
- if (!('type' in o && typeof o.type === 'number' && o.type in type)) {
29
- return false;
30
- }
31
- return true;
32
- }
33
- exports.hasMessageType = hasMessageType;
@@ -1,2 +0,0 @@
1
- export declare const MAX_WAIT_PER_ROUND = 10000;
2
- export declare function timeout(ms: number): Promise<never>;
@@ -1,19 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.timeout = exports.MAX_WAIT_PER_ROUND = void 0;
4
- var tslib_1 = require("tslib");
5
- // Time to wait for the others in milliseconds.
6
- exports.MAX_WAIT_PER_ROUND = 10000;
7
- function timeout(ms) {
8
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
9
- return (0, tslib_1.__generator)(this, function (_a) {
10
- switch (_a.label) {
11
- case 0: return [4 /*yield*/, new Promise(function (resolve, reject) {
12
- setTimeout(function () { return reject(new Error('timeout')); }, ms);
13
- })];
14
- case 1: return [2 /*return*/, _a.sent()];
15
- }
16
- });
17
- });
18
- }
19
- exports.timeout = timeout;
@@ -1,11 +0,0 @@
1
- import { Task } from '../..';
2
- import { Dataset } from '../dataset';
3
- export declare abstract class Data {
4
- readonly dataset: Dataset;
5
- readonly task: Task;
6
- readonly size?: number | undefined;
7
- protected constructor(dataset: Dataset, task: Task, size?: number | undefined);
8
- static init(dataset: Dataset, task: Task, size?: number): Promise<Data>;
9
- abstract batch(): Data;
10
- abstract preprocess(): Data;
11
- }
@@ -1,20 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.Data = void 0;
4
- var tslib_1 = require("tslib");
5
- var Data = /** @class */ (function () {
6
- function Data(dataset, task, size) {
7
- this.dataset = dataset;
8
- this.task = task;
9
- this.size = size;
10
- }
11
- Data.init = function (dataset, task, size) {
12
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
13
- return (0, tslib_1.__generator)(this, function (_a) {
14
- throw new Error('abstract');
15
- });
16
- });
17
- };
18
- return Data;
19
- }());
20
- exports.Data = Data;
@@ -1,5 +0,0 @@
1
- import { Data } from './data';
2
- export interface DataSplit {
3
- train: Data;
4
- validation?: Data;
5
- }
@@ -1,2 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
@@ -1,8 +0,0 @@
1
- import { Task } from '../..';
2
- import { Dataset } from '../dataset';
3
- import { Data } from './data';
4
- export declare class ImageData extends Data {
5
- static init(dataset: Dataset, task: Task, size?: number): Promise<Data>;
6
- batch(): Data;
7
- preprocess(): Data;
8
- }
@@ -1,64 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.ImageData = void 0;
4
- var tslib_1 = require("tslib");
5
- var preprocessing_1 = require("./preprocessing");
6
- var data_1 = require("./data");
7
- var ImageData = /** @class */ (function (_super) {
8
- (0, tslib_1.__extends)(ImageData, _super);
9
- function ImageData() {
10
- return _super !== null && _super.apply(this, arguments) || this;
11
- }
12
- ImageData.init = function (dataset, task, size) {
13
- var _a;
14
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
15
- var sample, shape, e_1;
16
- return (0, tslib_1.__generator)(this, function (_b) {
17
- switch (_b.label) {
18
- case 0:
19
- if (!!((_a = task.trainingInformation.preprocessingFunctions) === null || _a === void 0 ? void 0 : _a.includes(preprocessing_1.ImagePreprocessing.Resize))) return [3 /*break*/, 4];
20
- _b.label = 1;
21
- case 1:
22
- _b.trys.push([1, 3, , 4]);
23
- return [4 /*yield*/, dataset.take(1).toArray()];
24
- case 2:
25
- sample = (_b.sent())[0];
26
- // TODO: We suppose the presence of labels
27
- // TODO: Typing (discojs-node/src/dataset/data_loader/image_loader.spec.ts)
28
- if (!(typeof sample === 'object' && sample !== null)) {
29
- throw new Error();
30
- }
31
- shape = void 0;
32
- if ('xs' in sample && 'ys' in sample) {
33
- shape = sample.xs.shape;
34
- }
35
- else {
36
- shape = sample.shape;
37
- }
38
- if (!(shape[0] === task.trainingInformation.IMAGE_W &&
39
- shape[1] === task.trainingInformation.IMAGE_H)) {
40
- throw new Error();
41
- }
42
- return [3 /*break*/, 4];
43
- case 3:
44
- e_1 = _b.sent();
45
- throw new Error('Data input format is not compatible with the chosen task');
46
- case 4: return [2 /*return*/, new ImageData(dataset, task, size)];
47
- }
48
- });
49
- });
50
- };
51
- ImageData.prototype.batch = function () {
52
- var batchSize = this.task.trainingInformation.batchSize;
53
- var newDataset = batchSize === undefined ? this.dataset : this.dataset.batch(batchSize);
54
- return new ImageData(newDataset, this.task, this.size);
55
- };
56
- ImageData.prototype.preprocess = function () {
57
- var newDataset = this.dataset;
58
- var preprocessImage = (0, preprocessing_1.getPreprocessImage)(this.task);
59
- newDataset = newDataset.map(function (x) { return preprocessImage(x); });
60
- return new ImageData(newDataset, this.task, this.size);
61
- };
62
- return ImageData;
63
- }(data_1.Data));
64
- exports.ImageData = ImageData;
@@ -1,5 +0,0 @@
1
- export { DataSplit } from './data_split';
2
- export { Data } from './data';
3
- export { ImageData } from './image_data';
4
- export { TabularData } from './tabular_data';
5
- export { ImagePreprocessing } from './preprocessing';
@@ -1,11 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.ImagePreprocessing = exports.TabularData = exports.ImageData = exports.Data = void 0;
4
- var data_1 = require("./data");
5
- Object.defineProperty(exports, "Data", { enumerable: true, get: function () { return data_1.Data; } });
6
- var image_data_1 = require("./image_data");
7
- Object.defineProperty(exports, "ImageData", { enumerable: true, get: function () { return image_data_1.ImageData; } });
8
- var tabular_data_1 = require("./tabular_data");
9
- Object.defineProperty(exports, "TabularData", { enumerable: true, get: function () { return tabular_data_1.TabularData; } });
10
- var preprocessing_1 = require("./preprocessing");
11
- Object.defineProperty(exports, "ImagePreprocessing", { enumerable: true, get: function () { return preprocessing_1.ImagePreprocessing; } });
@@ -1,13 +0,0 @@
1
- import { tf, Task } from '../..';
2
- declare type PreprocessImage = (image: tf.TensorContainer) => tf.TensorContainer;
3
- export declare type Preprocessing = ImagePreprocessing;
4
- export interface ImageTensorContainer extends tf.TensorContainerObject {
5
- xs: tf.Tensor3D | tf.Tensor4D;
6
- ys: tf.Tensor1D | number | undefined;
7
- }
8
- export declare enum ImagePreprocessing {
9
- Normalize = "normalize",
10
- Resize = "resize"
11
- }
12
- export declare function getPreprocessImage(task: Task): PreprocessImage;
13
- export {};