@epfml/discojs 2.0.0 → 2.1.2-p20240506085037.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (334)
  1. package/dist/aggregator/base.d.ts +180 -0
  2. package/dist/aggregator/base.js +236 -0
  3. package/dist/aggregator/get.d.ts +16 -0
  4. package/dist/aggregator/get.js +31 -0
  5. package/dist/aggregator/index.d.ts +7 -0
  6. package/dist/aggregator/index.js +4 -0
  7. package/dist/aggregator/mean.d.ts +23 -0
  8. package/dist/aggregator/mean.js +69 -0
  9. package/dist/aggregator/secure.d.ts +27 -0
  10. package/dist/aggregator/secure.js +91 -0
  11. package/dist/async_informant.d.ts +15 -0
  12. package/dist/async_informant.js +42 -0
  13. package/dist/client/base.d.ts +76 -0
  14. package/dist/client/base.js +88 -0
  15. package/dist/client/decentralized/base.d.ts +32 -0
  16. package/dist/client/decentralized/base.js +192 -0
  17. package/dist/client/decentralized/index.d.ts +2 -0
  18. package/dist/client/decentralized/index.js +2 -0
  19. package/dist/client/decentralized/messages.d.ts +28 -0
  20. package/dist/client/decentralized/messages.js +44 -0
  21. package/dist/client/decentralized/peer.d.ts +40 -0
  22. package/dist/client/decentralized/peer.js +189 -0
  23. package/dist/client/decentralized/peer_pool.d.ts +12 -0
  24. package/dist/client/decentralized/peer_pool.js +44 -0
  25. package/dist/client/event_connection.d.ts +34 -0
  26. package/dist/client/event_connection.js +105 -0
  27. package/dist/client/federated/base.d.ts +54 -0
  28. package/dist/client/federated/base.js +151 -0
  29. package/dist/client/federated/index.d.ts +2 -0
  30. package/dist/client/federated/index.js +2 -0
  31. package/dist/client/federated/messages.d.ts +30 -0
  32. package/dist/client/federated/messages.js +24 -0
  33. package/dist/client/index.d.ts +8 -0
  34. package/dist/client/index.js +8 -0
  35. package/dist/client/local.d.ts +3 -0
  36. package/dist/client/local.js +3 -0
  37. package/dist/client/messages.d.ts +30 -0
  38. package/dist/client/messages.js +26 -0
  39. package/dist/client/types.d.ts +2 -0
  40. package/dist/client/types.js +4 -0
  41. package/dist/client/utils.d.ts +2 -0
  42. package/dist/client/utils.js +7 -0
  43. package/dist/dataset/data/data.d.ts +48 -0
  44. package/dist/dataset/data/data.js +72 -0
  45. package/dist/dataset/data/data_split.d.ts +8 -0
  46. package/dist/dataset/data/data_split.js +1 -0
  47. package/dist/dataset/data/image_data.d.ts +11 -0
  48. package/dist/dataset/data/image_data.js +38 -0
  49. package/dist/dataset/data/index.d.ts +6 -0
  50. package/dist/dataset/data/index.js +5 -0
  51. package/dist/dataset/data/preprocessing/base.d.ts +16 -0
  52. package/dist/dataset/data/preprocessing/base.js +1 -0
  53. package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +13 -0
  54. package/dist/dataset/data/preprocessing/image_preprocessing.js +40 -0
  55. package/dist/dataset/data/preprocessing/index.d.ts +4 -0
  56. package/dist/dataset/data/preprocessing/index.js +3 -0
  57. package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +13 -0
  58. package/dist/dataset/data/preprocessing/tabular_preprocessing.js +45 -0
  59. package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +13 -0
  60. package/dist/dataset/data/preprocessing/text_preprocessing.js +85 -0
  61. package/dist/dataset/data/tabular_data.d.ts +11 -0
  62. package/dist/dataset/data/tabular_data.js +25 -0
  63. package/dist/dataset/data/text_data.d.ts +11 -0
  64. package/dist/dataset/data/text_data.js +14 -0
  65. package/dist/{core/dataset → dataset}/data_loader/data_loader.d.ts +3 -5
  66. package/dist/dataset/data_loader/data_loader.js +2 -0
  67. package/dist/dataset/data_loader/image_loader.d.ts +20 -3
  68. package/dist/dataset/data_loader/image_loader.js +98 -23
  69. package/dist/dataset/data_loader/index.d.ts +5 -2
  70. package/dist/dataset/data_loader/index.js +4 -7
  71. package/dist/dataset/data_loader/tabular_loader.d.ts +34 -3
  72. package/dist/dataset/data_loader/tabular_loader.js +75 -15
  73. package/dist/dataset/data_loader/text_loader.d.ts +14 -0
  74. package/dist/dataset/data_loader/text_loader.js +25 -0
  75. package/dist/dataset/dataset.d.ts +5 -0
  76. package/dist/dataset/dataset.js +1 -0
  77. package/dist/dataset/dataset_builder.d.ts +60 -0
  78. package/dist/dataset/dataset_builder.js +142 -0
  79. package/dist/dataset/index.d.ts +5 -0
  80. package/dist/dataset/index.js +3 -0
  81. package/dist/default_tasks/cifar10/index.d.ts +2 -0
  82. package/dist/default_tasks/cifar10/index.js +60 -0
  83. package/dist/default_tasks/cifar10/model.d.ts +434 -0
  84. package/dist/default_tasks/cifar10/model.js +2385 -0
  85. package/dist/default_tasks/geotags/index.d.ts +2 -0
  86. package/dist/default_tasks/geotags/index.js +65 -0
  87. package/dist/default_tasks/geotags/model.d.ts +593 -0
  88. package/dist/default_tasks/geotags/model.js +4715 -0
  89. package/dist/default_tasks/index.d.ts +8 -0
  90. package/dist/default_tasks/index.js +8 -0
  91. package/dist/default_tasks/lus_covid.d.ts +2 -0
  92. package/dist/default_tasks/lus_covid.js +89 -0
  93. package/dist/default_tasks/mnist.d.ts +2 -0
  94. package/dist/default_tasks/mnist.js +61 -0
  95. package/dist/default_tasks/simple_face/index.d.ts +2 -0
  96. package/dist/default_tasks/simple_face/index.js +48 -0
  97. package/dist/default_tasks/simple_face/model.d.ts +513 -0
  98. package/dist/default_tasks/simple_face/model.js +4301 -0
  99. package/dist/default_tasks/skin_mnist.d.ts +2 -0
  100. package/dist/default_tasks/skin_mnist.js +80 -0
  101. package/dist/default_tasks/titanic.d.ts +2 -0
  102. package/dist/default_tasks/titanic.js +88 -0
  103. package/dist/default_tasks/wikitext.d.ts +2 -0
  104. package/dist/default_tasks/wikitext.js +38 -0
  105. package/dist/index.d.ts +18 -2
  106. package/dist/index.js +18 -6
  107. package/dist/{core/informant → informant}/graph_informant.d.ts +1 -1
  108. package/dist/informant/graph_informant.js +20 -0
  109. package/dist/informant/index.d.ts +1 -0
  110. package/dist/informant/index.js +1 -0
  111. package/dist/{core/logging → logging}/console_logger.d.ts +2 -2
  112. package/dist/logging/console_logger.js +22 -0
  113. package/dist/logging/index.d.ts +2 -0
  114. package/dist/logging/index.js +1 -0
  115. package/dist/{core/logging → logging}/logger.d.ts +3 -3
  116. package/dist/logging/logger.js +1 -0
  117. package/dist/memory/base.d.ts +119 -0
  118. package/dist/memory/base.js +9 -0
  119. package/dist/memory/empty.d.ts +20 -0
  120. package/dist/memory/empty.js +43 -0
  121. package/dist/memory/index.d.ts +3 -1
  122. package/dist/memory/index.js +3 -5
  123. package/dist/memory/model_type.d.ts +9 -0
  124. package/dist/memory/model_type.js +10 -0
  125. package/dist/{core/privacy.d.ts → privacy.d.ts} +1 -1
  126. package/dist/{core/privacy.js → privacy.js} +11 -16
  127. package/dist/serialization/index.d.ts +2 -0
  128. package/dist/serialization/index.js +2 -0
  129. package/dist/serialization/model.d.ts +5 -0
  130. package/dist/serialization/model.js +67 -0
  131. package/dist/{core/serialization → serialization}/weights.d.ts +2 -2
  132. package/dist/serialization/weights.js +37 -0
  133. package/dist/task/data_example.js +14 -0
  134. package/dist/task/digest.d.ts +5 -0
  135. package/dist/task/digest.js +14 -0
  136. package/dist/{core/task → task}/display_information.d.ts +5 -3
  137. package/dist/task/display_information.js +46 -0
  138. package/dist/task/index.d.ts +7 -0
  139. package/dist/task/index.js +5 -0
  140. package/dist/task/label_type.d.ts +9 -0
  141. package/dist/task/label_type.js +28 -0
  142. package/dist/task/summary.js +13 -0
  143. package/dist/task/task.d.ts +12 -0
  144. package/dist/task/task.js +22 -0
  145. package/dist/task/task_handler.d.ts +5 -0
  146. package/dist/task/task_handler.js +20 -0
  147. package/dist/task/task_provider.d.ts +5 -0
  148. package/dist/task/task_provider.js +1 -0
  149. package/dist/{core/task → task}/training_information.d.ts +9 -10
  150. package/dist/task/training_information.js +88 -0
  151. package/dist/training/disco.d.ts +40 -0
  152. package/dist/training/disco.js +107 -0
  153. package/dist/training/index.d.ts +2 -0
  154. package/dist/training/index.js +1 -0
  155. package/dist/training/trainer/distributed_trainer.d.ts +20 -0
  156. package/dist/training/trainer/distributed_trainer.js +36 -0
  157. package/dist/training/trainer/local_trainer.d.ts +12 -0
  158. package/dist/training/trainer/local_trainer.js +19 -0
  159. package/dist/training/trainer/trainer.d.ts +33 -0
  160. package/dist/training/trainer/trainer.js +52 -0
  161. package/dist/{core/training → training}/trainer/trainer_builder.d.ts +5 -7
  162. package/dist/training/trainer/trainer_builder.js +43 -0
  163. package/dist/types.d.ts +8 -0
  164. package/dist/types.js +1 -0
  165. package/dist/utils/event_emitter.d.ts +40 -0
  166. package/dist/utils/event_emitter.js +57 -0
  167. package/dist/validation/index.d.ts +1 -0
  168. package/dist/validation/index.js +1 -0
  169. package/dist/validation/validator.d.ts +28 -0
  170. package/dist/validation/validator.js +132 -0
  171. package/dist/weights/aggregation.d.ts +21 -0
  172. package/dist/weights/aggregation.js +44 -0
  173. package/dist/weights/index.d.ts +2 -0
  174. package/dist/weights/index.js +2 -0
  175. package/dist/weights/weights_container.d.ts +68 -0
  176. package/dist/weights/weights_container.js +96 -0
  177. package/package.json +25 -16
  178. package/README.md +0 -53
  179. package/dist/core/async_buffer.d.ts +0 -41
  180. package/dist/core/async_buffer.js +0 -97
  181. package/dist/core/async_informant.d.ts +0 -20
  182. package/dist/core/async_informant.js +0 -69
  183. package/dist/core/client/base.d.ts +0 -33
  184. package/dist/core/client/base.js +0 -35
  185. package/dist/core/client/decentralized/base.d.ts +0 -32
  186. package/dist/core/client/decentralized/base.js +0 -212
  187. package/dist/core/client/decentralized/clear_text.d.ts +0 -14
  188. package/dist/core/client/decentralized/clear_text.js +0 -96
  189. package/dist/core/client/decentralized/index.d.ts +0 -4
  190. package/dist/core/client/decentralized/index.js +0 -9
  191. package/dist/core/client/decentralized/messages.d.ts +0 -41
  192. package/dist/core/client/decentralized/messages.js +0 -54
  193. package/dist/core/client/decentralized/peer.d.ts +0 -26
  194. package/dist/core/client/decentralized/peer.js +0 -210
  195. package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
  196. package/dist/core/client/decentralized/peer_pool.js +0 -92
  197. package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
  198. package/dist/core/client/decentralized/sec_agg.js +0 -190
  199. package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
  200. package/dist/core/client/decentralized/secret_shares.js +0 -39
  201. package/dist/core/client/decentralized/types.d.ts +0 -2
  202. package/dist/core/client/decentralized/types.js +0 -7
  203. package/dist/core/client/event_connection.d.ts +0 -37
  204. package/dist/core/client/event_connection.js +0 -158
  205. package/dist/core/client/federated/client.d.ts +0 -37
  206. package/dist/core/client/federated/client.js +0 -273
  207. package/dist/core/client/federated/index.d.ts +0 -2
  208. package/dist/core/client/federated/index.js +0 -7
  209. package/dist/core/client/federated/messages.d.ts +0 -38
  210. package/dist/core/client/federated/messages.js +0 -25
  211. package/dist/core/client/index.d.ts +0 -5
  212. package/dist/core/client/index.js +0 -11
  213. package/dist/core/client/local.d.ts +0 -8
  214. package/dist/core/client/local.js +0 -36
  215. package/dist/core/client/messages.d.ts +0 -28
  216. package/dist/core/client/messages.js +0 -33
  217. package/dist/core/client/utils.d.ts +0 -2
  218. package/dist/core/client/utils.js +0 -19
  219. package/dist/core/dataset/data/data.d.ts +0 -11
  220. package/dist/core/dataset/data/data.js +0 -20
  221. package/dist/core/dataset/data/data_split.d.ts +0 -5
  222. package/dist/core/dataset/data/data_split.js +0 -2
  223. package/dist/core/dataset/data/image_data.d.ts +0 -8
  224. package/dist/core/dataset/data/image_data.js +0 -64
  225. package/dist/core/dataset/data/index.d.ts +0 -5
  226. package/dist/core/dataset/data/index.js +0 -11
  227. package/dist/core/dataset/data/preprocessing.d.ts +0 -13
  228. package/dist/core/dataset/data/preprocessing.js +0 -33
  229. package/dist/core/dataset/data/tabular_data.d.ts +0 -8
  230. package/dist/core/dataset/data/tabular_data.js +0 -40
  231. package/dist/core/dataset/data_loader/data_loader.js +0 -10
  232. package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
  233. package/dist/core/dataset/data_loader/image_loader.js +0 -141
  234. package/dist/core/dataset/data_loader/index.d.ts +0 -3
  235. package/dist/core/dataset/data_loader/index.js +0 -9
  236. package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
  237. package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
  238. package/dist/core/dataset/dataset.d.ts +0 -2
  239. package/dist/core/dataset/dataset.js +0 -2
  240. package/dist/core/dataset/dataset_builder.d.ts +0 -18
  241. package/dist/core/dataset/dataset_builder.js +0 -96
  242. package/dist/core/dataset/index.d.ts +0 -4
  243. package/dist/core/dataset/index.js +0 -14
  244. package/dist/core/index.d.ts +0 -18
  245. package/dist/core/index.js +0 -41
  246. package/dist/core/informant/graph_informant.js +0 -23
  247. package/dist/core/informant/index.d.ts +0 -3
  248. package/dist/core/informant/index.js +0 -9
  249. package/dist/core/informant/training_informant/base.d.ts +0 -31
  250. package/dist/core/informant/training_informant/base.js +0 -83
  251. package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
  252. package/dist/core/informant/training_informant/decentralized.js +0 -22
  253. package/dist/core/informant/training_informant/federated.d.ts +0 -14
  254. package/dist/core/informant/training_informant/federated.js +0 -32
  255. package/dist/core/informant/training_informant/index.d.ts +0 -4
  256. package/dist/core/informant/training_informant/index.js +0 -11
  257. package/dist/core/informant/training_informant/local.d.ts +0 -6
  258. package/dist/core/informant/training_informant/local.js +0 -20
  259. package/dist/core/logging/console_logger.js +0 -33
  260. package/dist/core/logging/index.d.ts +0 -3
  261. package/dist/core/logging/index.js +0 -9
  262. package/dist/core/logging/logger.js +0 -9
  263. package/dist/core/logging/trainer_logger.d.ts +0 -24
  264. package/dist/core/logging/trainer_logger.js +0 -59
  265. package/dist/core/memory/base.d.ts +0 -22
  266. package/dist/core/memory/base.js +0 -9
  267. package/dist/core/memory/empty.d.ts +0 -14
  268. package/dist/core/memory/empty.js +0 -75
  269. package/dist/core/memory/index.d.ts +0 -3
  270. package/dist/core/memory/index.js +0 -9
  271. package/dist/core/memory/model_type.d.ts +0 -4
  272. package/dist/core/memory/model_type.js +0 -9
  273. package/dist/core/serialization/index.d.ts +0 -2
  274. package/dist/core/serialization/index.js +0 -6
  275. package/dist/core/serialization/model.d.ts +0 -5
  276. package/dist/core/serialization/model.js +0 -55
  277. package/dist/core/serialization/weights.js +0 -64
  278. package/dist/core/task/data_example.js +0 -24
  279. package/dist/core/task/display_information.js +0 -49
  280. package/dist/core/task/index.d.ts +0 -3
  281. package/dist/core/task/index.js +0 -8
  282. package/dist/core/task/model_compile_data.d.ts +0 -6
  283. package/dist/core/task/model_compile_data.js +0 -22
  284. package/dist/core/task/summary.js +0 -19
  285. package/dist/core/task/task.d.ts +0 -10
  286. package/dist/core/task/task.js +0 -31
  287. package/dist/core/task/training_information.js +0 -66
  288. package/dist/core/tasks/cifar10.d.ts +0 -3
  289. package/dist/core/tasks/cifar10.js +0 -65
  290. package/dist/core/tasks/geotags.d.ts +0 -3
  291. package/dist/core/tasks/geotags.js +0 -67
  292. package/dist/core/tasks/index.d.ts +0 -6
  293. package/dist/core/tasks/index.js +0 -10
  294. package/dist/core/tasks/lus_covid.d.ts +0 -3
  295. package/dist/core/tasks/lus_covid.js +0 -87
  296. package/dist/core/tasks/mnist.d.ts +0 -3
  297. package/dist/core/tasks/mnist.js +0 -60
  298. package/dist/core/tasks/simple_face.d.ts +0 -2
  299. package/dist/core/tasks/simple_face.js +0 -41
  300. package/dist/core/tasks/titanic.d.ts +0 -3
  301. package/dist/core/tasks/titanic.js +0 -88
  302. package/dist/core/training/disco.d.ts +0 -23
  303. package/dist/core/training/disco.js +0 -130
  304. package/dist/core/training/index.d.ts +0 -2
  305. package/dist/core/training/index.js +0 -7
  306. package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
  307. package/dist/core/training/trainer/distributed_trainer.js +0 -65
  308. package/dist/core/training/trainer/local_trainer.d.ts +0 -11
  309. package/dist/core/training/trainer/local_trainer.js +0 -34
  310. package/dist/core/training/trainer/round_tracker.d.ts +0 -30
  311. package/dist/core/training/trainer/round_tracker.js +0 -47
  312. package/dist/core/training/trainer/trainer.d.ts +0 -65
  313. package/dist/core/training/trainer/trainer.js +0 -160
  314. package/dist/core/training/trainer/trainer_builder.js +0 -95
  315. package/dist/core/training/training_schemes.d.ts +0 -5
  316. package/dist/core/training/training_schemes.js +0 -10
  317. package/dist/core/types.d.ts +0 -4
  318. package/dist/core/types.js +0 -2
  319. package/dist/core/validation/index.d.ts +0 -1
  320. package/dist/core/validation/index.js +0 -5
  321. package/dist/core/validation/validator.d.ts +0 -17
  322. package/dist/core/validation/validator.js +0 -104
  323. package/dist/core/weights/aggregation.d.ts +0 -8
  324. package/dist/core/weights/aggregation.js +0 -96
  325. package/dist/core/weights/index.d.ts +0 -2
  326. package/dist/core/weights/index.js +0 -7
  327. package/dist/core/weights/weights_container.d.ts +0 -19
  328. package/dist/core/weights/weights_container.js +0 -64
  329. package/dist/imports.d.ts +0 -2
  330. package/dist/imports.js +0 -7
  331. package/dist/memory/memory.d.ts +0 -26
  332. package/dist/memory/memory.js +0 -160
  333. package/dist/{core/task → task}/data_example.d.ts +1 -1
  334. package/dist/{core/task → task}/summary.d.ts +1 -1
@@ -0,0 +1,91 @@
1
+ import { Map, List, Range } from 'immutable';
2
+ import * as tf from '@tensorflow/tfjs';
3
+ import { AggregationStep, Base as Aggregator } from './base.js';
4
+ import { aggregation } from '../index.js';
5
/**
 * Aggregator implementing secure multi-party computation for decentralized learning.
 * An aggregation consists of two communication rounds:
 * - first, nodes communicate their secret shares to each other;
 * - then, they sum their received shares and communicate the result.
 * Finally, nodes are able to average the received partial sums to establish the aggregation result.
 */
export class SecureAggregator extends Aggregator {
    /** Shares are drawn uniformly from (-maxShareValue, maxShareValue). */
    maxShareValue;
    /**
     * @param model model whose weights receive the final aggregation result; may be undefined
     * @param maxShareValue bound of the uniform distribution the random shares are drawn from
     */
    constructor(model, maxShareValue = 100) {
        // NOTE(review): the extra arguments presumably are (roundCutoff = 0, communicationRounds = 2),
        // matching the two communication rounds handled in aggregate() — confirm against ./base.js
        super(model, 0, 2);
        this.maxShareValue = maxShareValue;
    }
    /**
     * Aggregates the contributions of the current communication round and emits the result.
     * - communication round 0: sums the received shares, yielding this node's partial sum;
     * - communication round 1: averages the received partial sums and, if a model is set,
     *   stores the result as its weights.
     * @throws Error if the communication round is neither 0 nor 1
     */
    aggregate() {
        this.log(AggregationStep.AGGREGATE);
        if (this.communicationRound === 0) {
            // Sum the received shares
            const result = aggregation.sum(this.contributions.get(0)?.values());
            this.emit(result);
        }
        else if (this.communicationRound === 1) {
            // Average the received partial sums
            const result = aggregation.avg(this.contributions.get(1)?.values());
            if (this.model !== undefined) {
                this.model.weights = result;
            }
            this.emit(result);
        }
        else {
            throw new Error('communication round is out of bounds');
        }
    }
    /**
     * Records a node's contribution for the given communication round. Aggregation is
     * triggered as soon as every node has contributed to the current communication round.
     * @returns true if the contribution was accepted, false if the sender is not a known node
     *          or the round is outside the cutoff
     */
    add(nodeId, contribution, round, communicationRound) {
        if (this.nodes.has(nodeId) && this.isWithinRoundCutoff(round)) {
            this.log(this.contributions.hasIn([communicationRound, nodeId]) ? AggregationStep.UPDATE : AggregationStep.ADD, nodeId);
            this.contributions = this.contributions.setIn([communicationRound, nodeId], contribution);
            this.informant?.update();
            if (this.isFull()) {
                this.aggregate();
            }
            return true;
        }
        return false;
    }
    /** Whether every node has contributed to the current communication round. */
    isFull() {
        const contribs = this.contributions.get(this.communicationRound);
        if (contribs === undefined) {
            return false;
        }
        return contribs.size === this.nodes.size;
    }
    /**
     * Builds the payload to send to each node.
     * - communication round 0: one distinct additive share of `weights` per node;
     * - otherwise: the same partial sum `weights` for every node.
     */
    makePayloads(weights) {
        if (this.communicationRound === 0) {
            const shares = this.generateAllShares(weights);
            // Arbitrarily assign our shares to the available nodes
            return Map(List(this.nodes).zip(shares));
        }
        else {
            // Send our partial sum to every node in this.nodes
            return this.nodes.toMap().map(() => weights);
        }
    }
    /**
     * Generate N additive shares that aggregate to the secret weights array, where N is the number of nodes.
     * @throws Error if there are no nodes
     */
    generateAllShares(secret) {
        if (this.nodes.size === 0) {
            throw new Error('too few participants to generate shares');
        }
        // Generate N-1 random shares
        const shares = Range(0, this.nodes.size - 1)
            .map(() => this.generateRandomShare(secret))
            .toList();
        // The last share completes the sum.
        // NOTE(review): with a single node `shares` is empty here; confirm aggregation.sum accepts an empty input.
        return shares.push(secret.sub(aggregation.sum(shares)));
    }
    /**
     * Generates one share in the same shape as the secret that is populated with values randomly chosen from
     * a uniform distribution between (-maxShareValue, maxShareValue).
     *
     * Every tensor gets its own freshly drawn seed. tf.randomUniform is deterministic for a given
     * seed, so reusing one seed would give all same-shaped tensors of the share an identical mask.
     * The completing share (secret - sum of masks) would then reveal the difference between those
     * tensors of the secret.
     */
    generateRandomShare(secret) {
        // 47 bits stays well below 2^53, so the BigInt -> number conversion is exact.
        const MAX_SEED_BITS = 47;
        return secret.map((t) => {
            const random = crypto.getRandomValues(new BigInt64Array(1))[0];
            const seed = Number(BigInt.asUintN(MAX_SEED_BITS, random));
            return tf.randomUniform(t.shape, -this.maxShareValue, this.maxShareValue, 'float32', seed);
        });
    }
}
@@ -0,0 +1,15 @@
1
+ import type { AggregatorBase } from './aggregator/index.js';
2
/**
 * Samples an aggregator's round and size to expose participation statistics.
 */
export declare class AsyncInformant<T> {
    /** Aggregator being observed. */
    private readonly aggregator;
    private _round;
    private _currentNumberOfParticipants;
    private _totalNumberOfParticipants;
    private _averageNumberOfParticipants;
    /** @param aggregator aggregator whose round and size are sampled */
    constructor(aggregator: AggregatorBase<T>);
    /** Takes a new sample from the aggregator. */
    public update(): void;
    /** Latest round seen on the aggregator. */
    public get round(): number;
    /** Participants recorded for the latest sampled round. */
    public get currentNumberOfParticipants(): number;
    /** Accumulated participant count. */
    public get totalNumberOfParticipants(): number;
    /** Average participant count. */
    public get averageNumberOfParticipants(): number;
    /** Snapshot of all statistics at once. */
    public getAllStatistics(): {
        round: number;
        currentNumberOfParticipants: number;
        totalNumberOfParticipants: number;
        averageNumberOfParticipants: number;
    };
}
@@ -0,0 +1,42 @@
1
/**
 * Collects participation statistics from an aggregator. It records the latest round, the
 * number of participants in it, and a running total and average over recorded samples.
 */
export class AsyncInformant {
    aggregator;
    _round = 0;
    _currentNumberOfParticipants = 0;
    _totalNumberOfParticipants = 0;
    _averageNumberOfParticipants = 0;
    /** Number of samples folded into the total; denominator of the average. */
    _numberOfSamples = 0;
    /** @param aggregator object exposing `round` and `size`, sampled on every update */
    constructor(aggregator) {
        this.aggregator = aggregator;
    }
    /**
     * Samples the aggregator's round and size.
     *
     * A sample is recorded while the informant is still at round 0, or when the aggregator has
     * moved past the last seen round. Otherwise only the round is refreshed.
     * NOTE(review): while at round 0 every call records a sample (and adds to the total);
     * confirm this is intended.
     */
    update() {
        if (this.round === 0 || this.round < this.aggregator.round) {
            this._round = this.aggregator.round;
            this._currentNumberOfParticipants = this.aggregator.size;
            // Include the current sample before averaging. Divide by the number of samples rather
            // than the round index, which is 0 on the first round and would give 0/0 = NaN.
            this._totalNumberOfParticipants += this._currentNumberOfParticipants;
            this._numberOfSamples += 1;
            this._averageNumberOfParticipants = this._totalNumberOfParticipants / this._numberOfSamples;
        }
        else {
            this._round = this.aggregator.round;
        }
    }
    // Getter functions
    get round() {
        return this._round;
    }
    get currentNumberOfParticipants() {
        return this._currentNumberOfParticipants;
    }
    get totalNumberOfParticipants() {
        return this._totalNumberOfParticipants;
    }
    get averageNumberOfParticipants() {
        return this._averageNumberOfParticipants;
    }
    /** Snapshot of all statistics at once. */
    getAllStatistics() {
        return {
            round: this.round,
            currentNumberOfParticipants: this.currentNumberOfParticipants,
            totalNumberOfParticipants: this.totalNumberOfParticipants,
            averageNumberOfParticipants: this.averageNumberOfParticipants
        };
    }
}
@@ -0,0 +1,76 @@
1
+ import type { Set } from 'immutable';
2
+ import type { Model, Task, WeightsContainer } from '../index.js';
3
+ import type { NodeID } from './types.js';
4
+ import type { EventConnection } from './event_connection.js';
5
+ import type { Aggregator } from '../aggregator/index.js';
6
/**
 * Abstract base of every Disco client. A client belongs to one task and handles
 * communication with other nodes, be it peers or a network server.
 */
export declare abstract class Base {
    /** URL of the network server this client connects to. */
    readonly url: URL;
    /** Task this client is bound to. */
    readonly task: Task;
    /** Aggregator used by this client. */
    readonly aggregator: Aggregator;
    /** Identifier assigned to this node by the network server, once connected. */
    protected _ownId?: NodeID;
    /** Connection to the network server, once connected. */
    protected _server?: EventConnection;
    /** Result produced by the aggregator after aggregation. */
    protected aggregationResult?: Promise<WeightsContainer>;
    /**
     * @param url URL of the network server to connect to
     * @param task task this client is bound to
     * @param aggregator aggregator used by this client
     */
    constructor(url: URL, task: Task, aggregator: Aggregator);
    /** Connects the client to its network server. */
    public connect(): Promise<void>;
    /** Disconnects the client from its network server. */
    public disconnect(): Promise<void>;
    /**
     * Downloads the most recent model the network server holds for this client's task.
     * @returns the decoded model
     */
    public getLatestModel(): Promise<Model>;
    /**
     * Hook run when a training round starts.
     * @param _weights latest local weight updates
     * @param _round index of the current training round
     */
    public onRoundBeginCommunication(_weights: WeightsContainer, _round: number): Promise<void>;
    /**
     * Hook run when a training round ends.
     * @param _weights latest local weight updates
     * @param _round index of the current training round
     */
    public onRoundEndCommunication(_weights: WeightsContainer, _round: number): Promise<void>;
    /** Nodes currently known to the aggregator. */
    public get nodes(): Set<NodeID>;
    /**
     * This node's identifier.
     * @throws Error if the client is not connected
     */
    public get ownId(): NodeID;
    /**
     * Connection to the network server.
     * @throws Error if the client is not connected
     */
    public get server(): EventConnection;
}
@@ -0,0 +1,88 @@
1
+ import axios from 'axios';
2
+ import { serialization } from '../index.js';
3
/**
 * Abstract base of every Disco client. A client belongs to one task and handles
 * communication with other nodes, be it peers or a network server.
 */
export class Base {
    url;
    task;
    aggregator;
    /** Identifier assigned to this node by the network server (unset until connected). */
    _ownId;
    /** Connection to the network server (unset until connected). */
    _server;
    /** Result produced by the aggregator after aggregation. */
    aggregationResult;
    /**
     * @param url URL of the network server to connect to
     * @param task task this client is bound to
     * @param aggregator aggregator used by this client
     */
    constructor(url, task, aggregator) {
        this.url = url;
        this.task = task;
        this.aggregator = aggregator;
    }
    /** Connects to the network server; the base client has nothing to connect. */
    async connect() { }
    /** Disconnects from the network server; the base client has nothing to disconnect. */
    async disconnect() { }
    /**
     * Downloads `tasks/<task id>/model.json` relative to the server URL and decodes it.
     * @returns the decoded model
     */
    async getLatestModel() {
        // Work on a copy so the client's own URL is never mutated.
        const endpoint = new URL('', this.url.href);
        if (!endpoint.pathname.endsWith('/')) {
            endpoint.pathname = `${endpoint.pathname}/`;
        }
        endpoint.pathname = `${endpoint.pathname}tasks/${this.task.id}/model.json`;
        const { data } = await axios.get(endpoint.href, { responseType: 'arraybuffer' });
        return await serialization.model.decode(new Uint8Array(data));
    }
    /**
     * Hook run when a training round starts; no-op in the base client.
     * @param _weights latest local weight updates
     * @param _round index of the current training round
     */
    async onRoundBeginCommunication(_weights, _round) { }
    /**
     * Hook run when a training round ends; no-op in the base client.
     * @param _weights latest local weight updates
     * @param _round index of the current training round
     */
    async onRoundEndCommunication(_weights, _round) { }
    /** Nodes currently known to the aggregator. */
    get nodes() {
        return this.aggregator.nodes;
    }
    /** This node's identifier; throws while the client is not connected. */
    get ownId() {
        const id = this._ownId;
        if (id === undefined) {
            throw new Error('the node is not connected');
        }
        return id;
    }
    /** Connection to the network server; throws while the client is not connected. */
    get server() {
        const connection = this._server;
        if (connection === undefined) {
            throw new Error('server undefined, not connected');
        }
        return connection;
    }
}
@@ -0,0 +1,32 @@
1
+ import { type WeightsContainer } from '../../index.js';
2
+ import { Client } from '../index.js';
3
+ import { type PeerConnection } from '../event_connection.js';
4
+ import * as messages from './messages.js';
5
/**
 * Decentralized client in a network of peers. The network server only coordinates the peers;
 * payloads are exchanged directly between peers. The server link uses regular WebSockets,
 * while peer-to-peer links use WebRTC for Node.js.
 */
export declare class Base extends Client {
    /** Peers to communicate with during the current training round. */
    private pool?;
    private connections?;
    /** Tells the server this client is ready for the next training round, then waits for the round's peers. */
    private waitForPeers;
    /** Sends a message to the given peer. */
    protected sendMessagetoPeer(peer: PeerConnection, msg: messages.PeerMessage): void;
    /** Opens the WebSocket to the server and handles the messages this client receives from it. */
    private connectServer;
    public connect(): Promise<void>;
    public disconnect(): Promise<void>;
    public onRoundBeginCommunication(_: WeightsContainer, round: number): Promise<void>;
    public onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<void>;
    private receivePayloads;
}
@@ -0,0 +1,192 @@
1
+ import { Map, Set } from 'immutable';
2
+ import { serialization } from '../../index.js';
3
+ import { Client } from '../index.js';
4
+ import { type } from '../messages.js';
5
+ import { timeout } from '../utils.js';
6
+ import { WebSocketServer, waitMessage, waitMessageWithTimeout } from '../event_connection.js';
7
+ import { PeerPool } from './peer_pool.js';
8
+ import * as messages from './messages.js';
9
/**
 * Represents a decentralized client in a network of peers. Peers coordinate each other with the
 * help of the network's server, yet only exchange payloads between each other. Communication
 * with the server is based off regular WebSockets, whereas peer-to-peer communication uses
 * WebRTC.
 */
export class Base extends Client {
    /**
     * The pool of peers to communicate with during the current training round.
     * Created in `connect` once the server has assigned this node an id.
     */
    pool;
    /**
     * Connections to the peers of the current training round, keyed by peer id.
     * Replaced at the beginning of every round and cleared on disconnect.
     */
    connections;
    /**
     * Announce to the server that this client is ready for the given round, then wait for the
     * round's peer list and open a connection to each listed peer.
     *
     * Failures once the readiness message has been sent (timeout, duplicate or invalid peer
     * list, missing peer pool, connection errors) are logged and the client falls back to a
     * solo round: the aggregator only knows about this node and no connection is returned.
     *
     * @param round the training round about to begin
     * @returns the connections to this round's peers, keyed by peer id (empty on failure)
     * @throws Error if the client is not connected to the server
     */
    async waitForPeers(round) {
        console.info(`[${this.ownId}] is ready for round`, round);
        // Broadcast our readiness. The `server` getter itself throws when not connected,
        // so no separate undefined check is needed (one here could never fire).
        const readyMessage = { type: type.PeerIsReady };
        this.server.send(readyMessage);
        // Wait for peers to be connected before sending any update information
        try {
            const receivedMessage = await waitMessageWithTimeout(this.server, type.PeersForRound);
            if (this.nodes.size > 0) {
                throw new Error('got new peer list from server but was already received for this round');
            }
            const peers = Set(receivedMessage.peers);
            console.info(`[${this.ownId}] received peers for round:`, peers.toJS());
            if (peers.has(this.ownId)) {
                throw new Error('received peer list contains our own id');
            }
            this.aggregator.setNodes(peers.add(this.ownId));
            if (this.pool === undefined) {
                throw new Error('waiting for peers but peer pool is undefined');
            }
            const connections = await this.pool.getPeers(peers, this.server,
            // Init receipt of peers weights
            (conn) => { this.receivePayloads(conn, round); });
            console.info(`[${this.ownId}] received peers for round ${round}:`, connections.keySeq().toJS());
            return connections;
        }
        catch (e) {
            console.error(e);
            // `Set(id)` would iterate over the id string and build a set of its characters;
            // `Set.of(id)` builds the intended singleton containing only this node.
            this.aggregator.setNodes(Set.of(this.ownId));
            return Map();
        }
    }
    /**
     * Send a message directly to a connected peer.
     *
     * @param peer the connection to the peer
     * @param msg the message to send; its `peer` field is only logged here
     */
    sendMessagetoPeer(peer, msg) {
        console.info(`[${this.ownId}] send message to peer`, msg.peer, msg);
        peer.send(msg);
    }
    /**
     * Creation of the WebSocket for the server, connection of client to that WebSocket,
     * deals with message reception from the decentralized client's perspective (messages received by client).
     * Signals relayed by the server for a peer are forwarded to the peer pool.
     *
     * @param url the WebSocket URL of the server
     * @returns the connected server
     */
    async connectServer(url) {
        const server = await WebSocketServer.connect(url, messages.isMessageFromServer, messages.isMessageToServer);
        server.on(type.SignalForPeer, (event) => {
            console.info(`[${this.ownId}] received signal from`, event.peer);
            if (this.pool === undefined) {
                throw new Error('received signal but peer pool is undefined');
            }
            this.pool.signal(event.peer, event.signal);
        });
        return server;
    }
    /**
     * Connect to the server's decentralized endpoint for this task (`deai/<task id>` appended
     * to the client URL's path, over `ws:` or `wss:` matching `http:` or `https:`), announce
     * this client and wait for the id the server assigns to it. The peer pool is then created
     * with that id.
     *
     * @throws Error if the client URL's protocol is neither `http:` nor `https:`, or if an id
     * had already been assigned
     */
    async connect() {
        const serverURL = new URL('', this.url.href);
        switch (this.url.protocol) {
            case 'http:':
                serverURL.protocol = 'ws:';
                break;
            case 'https:':
                serverURL.protocol = 'wss:';
                break;
            default:
                throw new Error(`unknown protocol: ${this.url.protocol}`);
        }
        serverURL.pathname += `deai/${this.task.id}`;
        this._server = await this.connectServer(serverURL);
        const msg = {
            type: type.ClientConnected
        };
        this.server.send(msg);
        const peerIdMsg = await waitMessage(this.server, type.AssignNodeID);
        console.info(`[${peerIdMsg.id}] assigned id generated by server`);
        if (this._ownId !== undefined) {
            throw new Error('received id from server but was already received');
        }
        this._ownId = peerIdMsg.id;
        this.pool = new PeerPool(peerIdMsg.id);
    }
    /**
     * Close the peer pool and the server connection, and forget the id assigned by the server.
     * The current round's peers are removed from the aggregator's nodes.
     * Safe to call when not (or no longer) connected.
     */
    async disconnect() {
        // Disconnect from peers
        await this.pool?.shutdown();
        this.pool = undefined;
        if (this.connections !== undefined) {
            const peers = this.connections.keySeq().toSet();
            this.aggregator.setNodes(this.aggregator.nodes.subtract(peers));
            // Do not keep references to the closed connections around
            this.connections = undefined;
        }
        // Disconnect from server. Read the backing field: the `server` getter throws when not
        // connected, which would defeat the optional chaining.
        await this._server?.disconnect();
        this._server = undefined;
        this._ownId = undefined;
    }
    /**
     * Prepare a training round: refresh the peer connections, since the peer list is
     * maintained by the server and may change between rounds, then store the promise of this
     * round's aggregation result.
     *
     * @param _ unused
     * @param round the training round about to begin
     */
    async onRoundBeginCommunication(_, round) {
        this.connections = await this.waitForPeers(round);
        // Store the promise for the current round's aggregation result.
        this.aggregationResult = this.aggregator.receiveResult();
    }
    /**
     * Exchange this round's weights with the peers and wait for the aggregated result.
     *
     * Performs `aggregator.communicationRounds` communication rounds. In each of them, the
     * aggregator turns the current result into one payload per node: our own payload is added
     * to the aggregator directly, the others are encoded and sent to the matching peers. The
     * value that settles first between the aggregation result and `timeout()` becomes the input
     * of the next communication round. The aggregator's nodes are reset at the end.
     *
     * @param weights this client's weights at the end of the training round
     * @param round the current training round
     * @throws Error if sending the payloads fails
     * @throws TypeError if no aggregation result promise was prepared
     */
    async onRoundEndCommunication(weights, round) {
        let result = weights;
        // A communication round's payload is the aggregation result of the previous communication
        // round; the first one simply sends our training result. This scheme allows the aggregator
        // to define any complex multi-round aggregation mechanism.
        for (let r = 0; r < this.aggregator.communicationRounds; r++) {
            // Generate our payloads for this communication round and send them to all ready connected peers
            if (this.connections !== undefined) {
                const payloads = this.aggregator.makePayloads(result);
                try {
                    // `payloads` is an immutable Map, whose iterator yields [id, value] entries:
                    // handed directly to Promise.all, the mapped promises would never be awaited
                    // and their rejections would go unhandled. Collect the promises themselves.
                    const sends = payloads
                        .map(async (payload, id) => {
                        if (id === this.ownId) {
                            this.aggregator.add(this.ownId, payload, round, r);
                        }
                        else {
                            const connection = this.connections?.get(id);
                            if (connection !== undefined) {
                                const encoded = await serialization.weights.encode(payload);
                                this.sendMessagetoPeer(connection, {
                                    type: type.Payload,
                                    peer: id,
                                    round: r,
                                    payload: encoded
                                });
                            }
                        }
                    })
                        .valueSeq()
                        .toArray();
                    await Promise.all(sends);
                }
                catch (e) {
                    // Keep the underlying failure visible before replacing it
                    console.error(e);
                    throw new Error('error while sending weights');
                }
            }
            if (this.aggregationResult === undefined) {
                throw new TypeError('aggregation result promise is undefined');
            }
            // Wait for aggregation before proceeding to the next communication round.
            // The current result will be used as payload for the eventual next communication round.
            result = await Promise.race([this.aggregationResult, timeout()]);
            // There is at least one communication round remaining
            if (r < this.aggregator.communicationRounds - 1) {
                // Reuse the aggregation result
                this.aggregationResult = this.aggregator.receiveResult();
            }
        }
        // Reset the peers list for the next round
        this.aggregator.resetNodes();
    }
    /**
     * Start listening for the payloads each given peer sends during the given round: one
     * payload per communication round is expected from every peer, and each one is handed to
     * the aggregator. Failures (timeouts, undecodable payloads) are logged and skipped.
     * The per-peer listeners are not awaited: this method returns immediately.
     *
     * @param connections the connections to listen on, keyed by peer id
     * @param round the current training round
     */
    receivePayloads(connections, round) {
        console.info(`[${this.ownId}] Accepting new contributions for round ${round}`);
        connections.forEach(async (connection, peerId) => {
            let receivedPayloads = 0;
            do {
                try {
                    const message = await waitMessageWithTimeout(connection, type.Payload);
                    const decoded = serialization.weights.decode(message.payload);
                    if (!this.aggregator.add(peerId, decoded, round, message.round)) {
                        console.warn(`[${this.ownId}] Failed to add contribution from peer ${peerId}`);
                    }
                }
                catch (e) {
                    console.warn(e instanceof Error ? e.message : e);
                }
            } while (++receivedPayloads < this.aggregator.communicationRounds);
        });
    }
}
@@ -0,0 +1,2 @@
1
+ export { Base as DecentralizedClient } from './base.js';
2
+ export * as messages from './messages.js';
@@ -0,0 +1,2 @@
1
+ export { Base as DecentralizedClient } from './base.js';
2
+ export * as messages from './messages.js';
@@ -0,0 +1,28 @@
1
+ import { weights } from '../../serialization/index.js';
2
+ import { type SignalData } from './peer.js';
3
+ import { type NodeID } from '../types.js';
4
+ import { type, type ClientConnected, type AssignNodeID } from '../messages.js';
5
/** Sent by a client to tell the server it is ready for the next training round. */
export interface PeerIsReady {
    type: type.PeerIsReady;
}
/** Sent by the server with the ids of a client's peers for the upcoming round. */
export interface PeersForRound {
    type: type.PeersForRound;
    peers: NodeID[];
}
/** Connection signal relayed through the server; `peer` identifies the remote peer. */
export interface SignalForPeer {
    type: type.SignalForPeer;
    signal: SignalData;
    peer: NodeID;
}
/** Encoded weights sent directly to a peer for a given communication round. */
export interface Payload {
    type: type.Payload;
    round: number;
    peer: NodeID;
    payload: weights.Encoded;
}
/** Messages the server may send to a decentralized client. */
export type MessageFromServer = AssignNodeID | PeersForRound | SignalForPeer;
/** Messages a decentralized client may send to the server. */
export type MessageToServer = ClientConnected | PeerIsReady | SignalForPeer;
/** Messages exchanged directly between peers. */
export type PeerMessage = Payload;
export declare function isMessageFromServer(raw: unknown): raw is MessageFromServer;
export declare function isMessageToServer(raw: unknown): raw is MessageToServer;
export declare function isPeerMessage(raw: unknown): raw is PeerMessage;
@@ -0,0 +1,44 @@
1
+ import { weights } from '../../serialization/index.js';
2
+ import { isNodeID } from '../types.js';
3
+ import { type, hasMessageType } from '../messages.js';
4
/**
 * Runtime check that `o` is one of the messages the server sends to a decentralized client.
 *
 * @param o the value to check
 * @returns whether `o` is a well-formed `MessageFromServer`
 */
export function isMessageFromServer(o) {
    if (hasMessageType(o)) {
        const kind = o.type;
        if (kind === type.AssignNodeID) {
            return 'id' in o && isNodeID(o.id);
        }
        if (kind === type.SignalForPeer) {
            // TODO check signal content?
            return 'peer' in o && isNodeID(o.peer) && 'signal' in o;
        }
        if (kind === type.PeersForRound) {
            return 'peers' in o && Array.isArray(o.peers) && o.peers.every(isNodeID);
        }
    }
    // Not a message at all, or a message type the server does not send
    return false;
}
19
/**
 * Runtime check that `o` is one of the messages a decentralized client sends to the server.
 *
 * @param o the value to check
 * @returns whether `o` is a well-formed `MessageToServer`
 */
export function isMessageToServer(o) {
    if (!hasMessageType(o)) {
        return false;
    }
    const kind = o.type;
    // These messages carry no data besides their type
    if (kind === type.ClientConnected || kind === type.PeerIsReady) {
        return true;
    }
    if (kind === type.SignalForPeer) {
        // TODO check signal content?
        return 'peer' in o && isNodeID(o.peer) && 'signal' in o;
    }
    return false;
}
34
/**
 * Runtime check that `o` is a message exchanged directly between peers.
 *
 * Only payload messages exist so far. A payload must carry a peer node id, the number of the
 * communication round it belongs to, and an encoded weights payload.
 *
 * @param o the value to check
 * @returns whether `o` is a well-formed `PeerMessage`
 */
export function isPeerMessage(o) {
    if (!hasMessageType(o)) {
        return false;
    }
    switch (o.type) {
        case type.Payload:
            // `round` is part of the Payload type and is read by the receiver to attribute the
            // payload to a communication round, so it has to be validated too.
            return ('peer' in o && isNodeID(o.peer) &&
                'round' in o && typeof o.round === 'number' &&
                'payload' in o && weights.isEncoded(o.payload));
    }
    return false;
}
@@ -0,0 +1,40 @@
1
+ /// <reference types="node" resolution-mode="require"/>
2
+ import type { NodeID } from '../types.js';
3
/**
 * Signalling data exchanged, through the server, to set up a connection between two peers.
 * The variants are discriminated by their `type` field.
 */
export type SignalData =
    | {
        type: 'candidate';
        candidate: RTCIceCandidate;
    }
    | {
        type: 'renegotiate';
        renegotiate: true;
    }
    | {
        type: 'transceiverRequest';
        transceiverRequest: {
            kind: string;
        };
    }
    | {
        // offer/answer-style variants, optionally carrying an `sdp` string
        type: 'answer' | 'offer' | 'pranswer' | 'rollback';
        sdp?: string;
    };
18
/** Listener signature for each event a `Peer` can be subscribed to, keyed by event name. */
interface Events {
    signal: (signal: SignalData) => void;
    connect: () => void;
    data: (data: Buffer) => void;
    close: () => void;
}
24
/**
 * Direct connection to a single remote peer, identified by its node id and set up by
 * exchanging `SignalData` through `signal`.
 */
export declare class Peer {
    readonly id: NodeID;
    constructor(id: NodeID, initiator?: boolean);
    /** Subscribe to one of the events listed in `Events`. */
    on<E extends keyof Events>(event: E, listener: Events[E]): void;
    signal(signal: SignalData): void;
    send(msg: Buffer): void;
    get maxChunkSize(): number;
    destroy(): Promise<void>;
    private readonly peer;
    private bufferSize?;
    private sendCounter;
    private sendQueue;
    private receiving;
    private flush;
    private chunk;
}
40
+ export {};