@epfml/discojs 2.1.1 → 2.1.2-p20240506085559.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (334) hide show
  1. package/dist/aggregator/base.d.ts +180 -0
  2. package/dist/aggregator/base.js +236 -0
  3. package/dist/aggregator/get.d.ts +16 -0
  4. package/dist/aggregator/get.js +31 -0
  5. package/dist/aggregator/index.d.ts +7 -0
  6. package/dist/aggregator/index.js +4 -0
  7. package/dist/aggregator/mean.d.ts +23 -0
  8. package/dist/aggregator/mean.js +69 -0
  9. package/dist/aggregator/secure.d.ts +27 -0
  10. package/dist/aggregator/secure.js +91 -0
  11. package/dist/async_informant.d.ts +15 -0
  12. package/dist/async_informant.js +42 -0
  13. package/dist/client/base.d.ts +76 -0
  14. package/dist/client/base.js +88 -0
  15. package/dist/client/decentralized/base.d.ts +32 -0
  16. package/dist/client/decentralized/base.js +192 -0
  17. package/dist/client/decentralized/index.d.ts +2 -0
  18. package/dist/client/decentralized/index.js +2 -0
  19. package/dist/client/decentralized/messages.d.ts +28 -0
  20. package/dist/client/decentralized/messages.js +44 -0
  21. package/dist/client/decentralized/peer.d.ts +40 -0
  22. package/dist/client/decentralized/peer.js +189 -0
  23. package/dist/client/decentralized/peer_pool.d.ts +12 -0
  24. package/dist/client/decentralized/peer_pool.js +44 -0
  25. package/dist/client/event_connection.d.ts +34 -0
  26. package/dist/client/event_connection.js +105 -0
  27. package/dist/client/federated/base.d.ts +54 -0
  28. package/dist/client/federated/base.js +151 -0
  29. package/dist/client/federated/index.d.ts +2 -0
  30. package/dist/client/federated/index.js +2 -0
  31. package/dist/client/federated/messages.d.ts +30 -0
  32. package/dist/client/federated/messages.js +24 -0
  33. package/dist/client/index.d.ts +8 -0
  34. package/dist/client/index.js +8 -0
  35. package/dist/client/local.d.ts +3 -0
  36. package/dist/client/local.js +3 -0
  37. package/dist/client/messages.d.ts +30 -0
  38. package/dist/client/messages.js +26 -0
  39. package/dist/client/types.d.ts +2 -0
  40. package/dist/client/types.js +4 -0
  41. package/dist/client/utils.d.ts +2 -0
  42. package/dist/client/utils.js +7 -0
  43. package/dist/dataset/data/data.d.ts +48 -0
  44. package/dist/dataset/data/data.js +72 -0
  45. package/dist/dataset/data/data_split.d.ts +8 -0
  46. package/dist/dataset/data/data_split.js +1 -0
  47. package/dist/dataset/data/image_data.d.ts +11 -0
  48. package/dist/dataset/data/image_data.js +38 -0
  49. package/dist/dataset/data/index.d.ts +6 -0
  50. package/dist/dataset/data/index.js +5 -0
  51. package/dist/dataset/data/preprocessing/base.d.ts +16 -0
  52. package/dist/dataset/data/preprocessing/base.js +1 -0
  53. package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +13 -0
  54. package/dist/dataset/data/preprocessing/image_preprocessing.js +40 -0
  55. package/dist/dataset/data/preprocessing/index.d.ts +4 -0
  56. package/dist/dataset/data/preprocessing/index.js +3 -0
  57. package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +13 -0
  58. package/dist/dataset/data/preprocessing/tabular_preprocessing.js +45 -0
  59. package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +13 -0
  60. package/dist/dataset/data/preprocessing/text_preprocessing.js +85 -0
  61. package/dist/dataset/data/tabular_data.d.ts +11 -0
  62. package/dist/dataset/data/tabular_data.js +25 -0
  63. package/dist/dataset/data/text_data.d.ts +11 -0
  64. package/dist/dataset/data/text_data.js +14 -0
  65. package/dist/{core/dataset → dataset}/data_loader/data_loader.d.ts +3 -5
  66. package/dist/dataset/data_loader/data_loader.js +2 -0
  67. package/dist/dataset/data_loader/image_loader.d.ts +20 -3
  68. package/dist/dataset/data_loader/image_loader.js +98 -23
  69. package/dist/dataset/data_loader/index.d.ts +5 -2
  70. package/dist/dataset/data_loader/index.js +4 -7
  71. package/dist/dataset/data_loader/tabular_loader.d.ts +34 -3
  72. package/dist/dataset/data_loader/tabular_loader.js +75 -15
  73. package/dist/dataset/data_loader/text_loader.d.ts +14 -0
  74. package/dist/dataset/data_loader/text_loader.js +25 -0
  75. package/dist/dataset/dataset.d.ts +5 -0
  76. package/dist/dataset/dataset.js +1 -0
  77. package/dist/dataset/dataset_builder.d.ts +60 -0
  78. package/dist/dataset/dataset_builder.js +142 -0
  79. package/dist/dataset/index.d.ts +5 -0
  80. package/dist/dataset/index.js +3 -0
  81. package/dist/default_tasks/cifar10/index.d.ts +2 -0
  82. package/dist/{core/default_tasks/cifar10.js → default_tasks/cifar10/index.js} +28 -36
  83. package/dist/default_tasks/cifar10/model.d.ts +434 -0
  84. package/dist/default_tasks/cifar10/model.js +2385 -0
  85. package/dist/default_tasks/geotags/index.d.ts +2 -0
  86. package/dist/default_tasks/geotags/index.js +65 -0
  87. package/dist/default_tasks/geotags/model.d.ts +593 -0
  88. package/dist/default_tasks/geotags/model.js +4715 -0
  89. package/dist/default_tasks/index.d.ts +8 -0
  90. package/dist/default_tasks/index.js +8 -0
  91. package/dist/default_tasks/lus_covid.d.ts +2 -0
  92. package/dist/default_tasks/lus_covid.js +89 -0
  93. package/dist/default_tasks/mnist.d.ts +2 -0
  94. package/dist/{core/default_tasks → default_tasks}/mnist.js +26 -34
  95. package/dist/default_tasks/simple_face/index.d.ts +2 -0
  96. package/dist/{core/default_tasks/simple_face.js → default_tasks/simple_face/index.js} +17 -22
  97. package/dist/default_tasks/simple_face/model.d.ts +513 -0
  98. package/dist/default_tasks/simple_face/model.js +4301 -0
  99. package/dist/default_tasks/skin_mnist.d.ts +2 -0
  100. package/dist/default_tasks/skin_mnist.js +80 -0
  101. package/dist/default_tasks/titanic.d.ts +2 -0
  102. package/dist/{core/default_tasks → default_tasks}/titanic.js +24 -33
  103. package/dist/default_tasks/wikitext.d.ts +2 -0
  104. package/dist/default_tasks/wikitext.js +38 -0
  105. package/dist/index.d.ts +18 -2
  106. package/dist/index.js +18 -6
  107. package/dist/{core/informant → informant}/graph_informant.d.ts +1 -1
  108. package/dist/informant/graph_informant.js +20 -0
  109. package/dist/informant/index.d.ts +1 -0
  110. package/dist/informant/index.js +1 -0
  111. package/dist/{core/logging → logging}/console_logger.d.ts +2 -2
  112. package/dist/logging/console_logger.js +22 -0
  113. package/dist/logging/index.d.ts +2 -0
  114. package/dist/logging/index.js +1 -0
  115. package/dist/{core/logging → logging}/logger.d.ts +3 -3
  116. package/dist/logging/logger.js +1 -0
  117. package/dist/memory/base.d.ts +119 -0
  118. package/dist/memory/base.js +9 -0
  119. package/dist/memory/empty.d.ts +20 -0
  120. package/dist/memory/empty.js +43 -0
  121. package/dist/memory/index.d.ts +3 -1
  122. package/dist/memory/index.js +3 -5
  123. package/dist/memory/model_type.d.ts +9 -0
  124. package/dist/memory/model_type.js +10 -0
  125. package/dist/{core/privacy.d.ts → privacy.d.ts} +1 -1
  126. package/dist/{core/privacy.js → privacy.js} +11 -16
  127. package/dist/serialization/index.d.ts +2 -0
  128. package/dist/serialization/index.js +2 -0
  129. package/dist/serialization/model.d.ts +5 -0
  130. package/dist/serialization/model.js +67 -0
  131. package/dist/{core/serialization → serialization}/weights.d.ts +2 -2
  132. package/dist/serialization/weights.js +37 -0
  133. package/dist/task/data_example.js +14 -0
  134. package/dist/task/digest.js +14 -0
  135. package/dist/{core/task → task}/display_information.d.ts +5 -3
  136. package/dist/task/display_information.js +46 -0
  137. package/dist/task/index.d.ts +7 -0
  138. package/dist/task/index.js +5 -0
  139. package/dist/task/label_type.d.ts +9 -0
  140. package/dist/task/label_type.js +28 -0
  141. package/dist/task/summary.js +13 -0
  142. package/dist/{core/task → task}/task.d.ts +7 -7
  143. package/dist/task/task.js +22 -0
  144. package/dist/task/task_handler.d.ts +5 -0
  145. package/dist/task/task_handler.js +20 -0
  146. package/dist/task/task_provider.d.ts +5 -0
  147. package/dist/task/task_provider.js +1 -0
  148. package/dist/{core/task → task}/training_information.d.ts +9 -10
  149. package/dist/task/training_information.js +88 -0
  150. package/dist/training/disco.d.ts +40 -0
  151. package/dist/training/disco.js +107 -0
  152. package/dist/training/index.d.ts +2 -0
  153. package/dist/training/index.js +1 -0
  154. package/dist/training/trainer/distributed_trainer.d.ts +20 -0
  155. package/dist/training/trainer/distributed_trainer.js +36 -0
  156. package/dist/training/trainer/local_trainer.d.ts +12 -0
  157. package/dist/training/trainer/local_trainer.js +19 -0
  158. package/dist/training/trainer/trainer.d.ts +33 -0
  159. package/dist/training/trainer/trainer.js +52 -0
  160. package/dist/{core/training → training}/trainer/trainer_builder.d.ts +5 -7
  161. package/dist/training/trainer/trainer_builder.js +43 -0
  162. package/dist/types.d.ts +8 -0
  163. package/dist/types.js +1 -0
  164. package/dist/utils/event_emitter.d.ts +40 -0
  165. package/dist/utils/event_emitter.js +57 -0
  166. package/dist/validation/index.d.ts +1 -0
  167. package/dist/validation/index.js +1 -0
  168. package/dist/validation/validator.d.ts +28 -0
  169. package/dist/validation/validator.js +132 -0
  170. package/dist/weights/aggregation.d.ts +21 -0
  171. package/dist/weights/aggregation.js +44 -0
  172. package/dist/weights/index.d.ts +2 -0
  173. package/dist/weights/index.js +2 -0
  174. package/dist/weights/weights_container.d.ts +68 -0
  175. package/dist/weights/weights_container.js +96 -0
  176. package/package.json +24 -15
  177. package/README.md +0 -53
  178. package/dist/core/async_buffer.d.ts +0 -41
  179. package/dist/core/async_buffer.js +0 -97
  180. package/dist/core/async_informant.d.ts +0 -20
  181. package/dist/core/async_informant.js +0 -69
  182. package/dist/core/client/base.d.ts +0 -33
  183. package/dist/core/client/base.js +0 -35
  184. package/dist/core/client/decentralized/base.d.ts +0 -32
  185. package/dist/core/client/decentralized/base.js +0 -212
  186. package/dist/core/client/decentralized/clear_text.d.ts +0 -14
  187. package/dist/core/client/decentralized/clear_text.js +0 -96
  188. package/dist/core/client/decentralized/index.d.ts +0 -4
  189. package/dist/core/client/decentralized/index.js +0 -9
  190. package/dist/core/client/decentralized/messages.d.ts +0 -41
  191. package/dist/core/client/decentralized/messages.js +0 -54
  192. package/dist/core/client/decentralized/peer.d.ts +0 -26
  193. package/dist/core/client/decentralized/peer.js +0 -210
  194. package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
  195. package/dist/core/client/decentralized/peer_pool.js +0 -92
  196. package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
  197. package/dist/core/client/decentralized/sec_agg.js +0 -190
  198. package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
  199. package/dist/core/client/decentralized/secret_shares.js +0 -39
  200. package/dist/core/client/decentralized/types.d.ts +0 -2
  201. package/dist/core/client/decentralized/types.js +0 -7
  202. package/dist/core/client/event_connection.d.ts +0 -37
  203. package/dist/core/client/event_connection.js +0 -158
  204. package/dist/core/client/federated/client.d.ts +0 -37
  205. package/dist/core/client/federated/client.js +0 -273
  206. package/dist/core/client/federated/index.d.ts +0 -2
  207. package/dist/core/client/federated/index.js +0 -7
  208. package/dist/core/client/federated/messages.d.ts +0 -38
  209. package/dist/core/client/federated/messages.js +0 -25
  210. package/dist/core/client/index.d.ts +0 -5
  211. package/dist/core/client/index.js +0 -11
  212. package/dist/core/client/local.d.ts +0 -8
  213. package/dist/core/client/local.js +0 -36
  214. package/dist/core/client/messages.d.ts +0 -28
  215. package/dist/core/client/messages.js +0 -33
  216. package/dist/core/client/utils.d.ts +0 -2
  217. package/dist/core/client/utils.js +0 -19
  218. package/dist/core/dataset/data/data.d.ts +0 -11
  219. package/dist/core/dataset/data/data.js +0 -20
  220. package/dist/core/dataset/data/data_split.d.ts +0 -5
  221. package/dist/core/dataset/data/data_split.js +0 -2
  222. package/dist/core/dataset/data/image_data.d.ts +0 -8
  223. package/dist/core/dataset/data/image_data.js +0 -64
  224. package/dist/core/dataset/data/index.d.ts +0 -5
  225. package/dist/core/dataset/data/index.js +0 -11
  226. package/dist/core/dataset/data/preprocessing.d.ts +0 -13
  227. package/dist/core/dataset/data/preprocessing.js +0 -33
  228. package/dist/core/dataset/data/tabular_data.d.ts +0 -8
  229. package/dist/core/dataset/data/tabular_data.js +0 -40
  230. package/dist/core/dataset/data_loader/data_loader.js +0 -10
  231. package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
  232. package/dist/core/dataset/data_loader/image_loader.js +0 -141
  233. package/dist/core/dataset/data_loader/index.d.ts +0 -3
  234. package/dist/core/dataset/data_loader/index.js +0 -9
  235. package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
  236. package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
  237. package/dist/core/dataset/dataset.d.ts +0 -2
  238. package/dist/core/dataset/dataset.js +0 -2
  239. package/dist/core/dataset/dataset_builder.d.ts +0 -18
  240. package/dist/core/dataset/dataset_builder.js +0 -96
  241. package/dist/core/dataset/index.d.ts +0 -4
  242. package/dist/core/dataset/index.js +0 -14
  243. package/dist/core/default_tasks/cifar10.d.ts +0 -2
  244. package/dist/core/default_tasks/geotags.d.ts +0 -2
  245. package/dist/core/default_tasks/geotags.js +0 -69
  246. package/dist/core/default_tasks/index.d.ts +0 -6
  247. package/dist/core/default_tasks/index.js +0 -15
  248. package/dist/core/default_tasks/lus_covid.d.ts +0 -2
  249. package/dist/core/default_tasks/lus_covid.js +0 -96
  250. package/dist/core/default_tasks/mnist.d.ts +0 -2
  251. package/dist/core/default_tasks/simple_face.d.ts +0 -2
  252. package/dist/core/default_tasks/titanic.d.ts +0 -2
  253. package/dist/core/index.d.ts +0 -18
  254. package/dist/core/index.js +0 -39
  255. package/dist/core/informant/graph_informant.js +0 -23
  256. package/dist/core/informant/index.d.ts +0 -3
  257. package/dist/core/informant/index.js +0 -9
  258. package/dist/core/informant/training_informant/base.d.ts +0 -31
  259. package/dist/core/informant/training_informant/base.js +0 -83
  260. package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
  261. package/dist/core/informant/training_informant/decentralized.js +0 -22
  262. package/dist/core/informant/training_informant/federated.d.ts +0 -14
  263. package/dist/core/informant/training_informant/federated.js +0 -32
  264. package/dist/core/informant/training_informant/index.d.ts +0 -4
  265. package/dist/core/informant/training_informant/index.js +0 -11
  266. package/dist/core/informant/training_informant/local.d.ts +0 -6
  267. package/dist/core/informant/training_informant/local.js +0 -20
  268. package/dist/core/logging/console_logger.js +0 -33
  269. package/dist/core/logging/index.d.ts +0 -3
  270. package/dist/core/logging/index.js +0 -9
  271. package/dist/core/logging/logger.js +0 -9
  272. package/dist/core/logging/trainer_logger.d.ts +0 -24
  273. package/dist/core/logging/trainer_logger.js +0 -59
  274. package/dist/core/memory/base.d.ts +0 -22
  275. package/dist/core/memory/base.js +0 -9
  276. package/dist/core/memory/empty.d.ts +0 -14
  277. package/dist/core/memory/empty.js +0 -75
  278. package/dist/core/memory/index.d.ts +0 -3
  279. package/dist/core/memory/index.js +0 -9
  280. package/dist/core/memory/model_type.d.ts +0 -4
  281. package/dist/core/memory/model_type.js +0 -9
  282. package/dist/core/serialization/index.d.ts +0 -2
  283. package/dist/core/serialization/index.js +0 -6
  284. package/dist/core/serialization/model.d.ts +0 -5
  285. package/dist/core/serialization/model.js +0 -55
  286. package/dist/core/serialization/weights.js +0 -64
  287. package/dist/core/task/data_example.js +0 -24
  288. package/dist/core/task/digest.js +0 -18
  289. package/dist/core/task/display_information.js +0 -49
  290. package/dist/core/task/index.d.ts +0 -6
  291. package/dist/core/task/index.js +0 -15
  292. package/dist/core/task/model_compile_data.d.ts +0 -6
  293. package/dist/core/task/model_compile_data.js +0 -22
  294. package/dist/core/task/summary.js +0 -19
  295. package/dist/core/task/task.js +0 -35
  296. package/dist/core/task/task_handler.d.ts +0 -5
  297. package/dist/core/task/task_handler.js +0 -53
  298. package/dist/core/task/task_provider.d.ts +0 -6
  299. package/dist/core/task/task_provider.js +0 -13
  300. package/dist/core/task/training_information.js +0 -66
  301. package/dist/core/training/disco.d.ts +0 -23
  302. package/dist/core/training/disco.js +0 -130
  303. package/dist/core/training/index.d.ts +0 -2
  304. package/dist/core/training/index.js +0 -7
  305. package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
  306. package/dist/core/training/trainer/distributed_trainer.js +0 -65
  307. package/dist/core/training/trainer/local_trainer.d.ts +0 -11
  308. package/dist/core/training/trainer/local_trainer.js +0 -34
  309. package/dist/core/training/trainer/round_tracker.d.ts +0 -30
  310. package/dist/core/training/trainer/round_tracker.js +0 -47
  311. package/dist/core/training/trainer/trainer.d.ts +0 -65
  312. package/dist/core/training/trainer/trainer.js +0 -160
  313. package/dist/core/training/trainer/trainer_builder.js +0 -95
  314. package/dist/core/training/training_schemes.d.ts +0 -5
  315. package/dist/core/training/training_schemes.js +0 -10
  316. package/dist/core/types.d.ts +0 -4
  317. package/dist/core/types.js +0 -2
  318. package/dist/core/validation/index.d.ts +0 -1
  319. package/dist/core/validation/index.js +0 -5
  320. package/dist/core/validation/validator.d.ts +0 -17
  321. package/dist/core/validation/validator.js +0 -104
  322. package/dist/core/weights/aggregation.d.ts +0 -7
  323. package/dist/core/weights/aggregation.js +0 -72
  324. package/dist/core/weights/index.d.ts +0 -2
  325. package/dist/core/weights/index.js +0 -7
  326. package/dist/core/weights/weights_container.d.ts +0 -19
  327. package/dist/core/weights/weights_container.js +0 -64
  328. package/dist/imports.d.ts +0 -2
  329. package/dist/imports.js +0 -7
  330. package/dist/memory/memory.d.ts +0 -26
  331. package/dist/memory/memory.js +0 -160
  332. package/dist/{core/task → task}/data_example.d.ts +1 -1
  333. package/dist/{core/task → task}/digest.d.ts +0 -0
  334. package/dist/{core/task → task}/summary.d.ts +1 -1
@@ -0,0 +1,180 @@
1
+ import { Map, Set } from 'immutable';
2
+ import type { client, Model, AsyncInformant } from '../index.js';
3
+ export declare enum AggregationStep {
4
+ ADD = 0,
5
+ UPDATE = 1,
6
+ AGGREGATE = 2
7
+ }
8
+ /**
9
+ * Main, abstract, aggregator class whose role is to buffer contributions and to produce
10
+ * a result based off their aggregation, whenever some defined condition is met.
11
+ */
12
+ export declare abstract class Base<T> {
13
+ /**
14
+ * The Model whose weights are updated on aggregation.
15
+ */
16
+ protected _model?: Model | undefined;
17
+ /**
18
+ * The round cut-off for contributions.
19
+ */
20
+ protected readonly roundCutoff: number;
21
+ /**
22
+ * The number of communication rounds occurring during any given aggregation round.
23
+ */
24
+ readonly communicationRounds: number;
25
+ /**
26
+ * Contains the ids of all active nodes, i.e. members of the aggregation group at
27
+ * a given round. It is a subset of all the nodes available in the network.
28
+ */
29
+ protected _nodes: Set<client.NodeID>;
30
+ /**
31
+ * Contains the contributions received from active nodes, accessible by node id.
32
+ * It defines the effective aggregation group, which is possibly a subset
33
+ * of all active nodes, depending on the aggregation scheme.
34
+ */
35
+ protected contributions: Map<number, Map<client.NodeID, T>>;
36
+ /**
37
+ * Emits the aggregation event whenever an aggregation step is performed.
38
+ * Triggers the resolve of the result promise and the preparation for the
39
+ * next aggregation round.
40
+ */
41
+ private readonly eventEmitter;
42
+ protected informant?: AsyncInformant<T>;
43
+ /**
44
+ * The result promise which, on resolve, will contain the current aggregation result.
45
+ * This promise should be fetched by any object making use of an aggregator, in order
46
+ * to await upon aggregation.
47
+ */
48
+ protected result: Promise<T>;
49
+ /**
50
+ * The current aggregation round, used for assessing whether a node contribution is recent enough
51
+ * or not.
52
+ */
53
+ protected _round: number;
54
+ /**
55
+ * The current communication round. A single aggregation round is made of possibly multiple
56
+ * communication rounds. This makes the aggregator free to perform intermediate aggregation
57
+ * steps based off communication with its nodes. Overall, this allows for more complex
58
+ * aggregation schemes requiring an exchange of information between nodes before aggregating.
59
+ */
60
+ protected _communicationRound: number;
61
+ constructor(
62
+ /**
63
+ * The Model whose weights are updated on aggregation.
64
+ */
65
+ _model?: Model | undefined,
66
+ /**
67
+ * The round cut-off for contributions.
68
+ */
69
+ roundCutoff?: number,
70
+ /**
71
+ * The number of communication rounds occurring during any given aggregation round.
72
+ */
73
+ communicationRounds?: number);
74
+ /**
75
+ * Adds a node's contribution to the aggregator for the given aggregation and communication rounds.
76
+ * The contribution will be aggregated during the next aggregation step.
77
+ * @param nodeId The node's id
78
+ * @param contribution The node's contribution
79
+ * @param round For which aggregation round the contribution was made
80
+ * @param communicationRound For which communication round the contribution was made
81
+ */
82
+ abstract add(nodeId: client.NodeID, contribution: T, round: number, communicationRound?: number): boolean;
83
+ /**
84
+ * Performs an aggregation step over the received node contributions.
85
+ * Must store the aggregation's result in the aggregator's result promise.
86
+ */
87
+ abstract aggregate(): void;
88
+ registerObserver(informant: AsyncInformant<T>): void;
89
+ /**
90
+ * Returns whether the given round is recent enough, dependent on the
91
+ * aggregator's round cutoff.
92
+ * @param round The round
93
+ * @returns True if the round is recent enough, false otherwise
94
+ */
95
+ isWithinRoundCutoff(round: number): boolean;
96
+ /**
97
+ * Logs useful messages during the various aggregation steps.
98
+ * @param step The aggregation step
99
+ * @param from The node which triggered the logging message
100
+ */
101
+ log(step: AggregationStep, from?: client.NodeID): void;
102
+ /**
103
+ * Sets the aggregator's TF.js model.
104
+ * @param model The new TF.js model
105
+ */
106
+ setModel(model: Model): void;
107
+ /**
108
+ * Adds a node's id to the set of active nodes. A node represents an active neighbor
109
+ * peer/client within the network, whom we are communicating with during this aggregation
110
+ * round.
111
+ * @param nodeId The node to be added
112
+ */
113
+ registerNode(nodeId: client.NodeID): boolean;
114
+ /**
115
+ * Overwrites the current set of active nodes with the given one. A node represents
116
+ * an active neighbor peer/client within the network, whom we are communicating with
117
+ * during this aggregation round.
118
+ * @param nodeIds The new set of nodes
119
+ */
120
+ setNodes(nodeIds: Set<client.NodeID>): void;
121
+ /**
122
+ * Empties the current set of "nodes". Usually called at the end of an aggregation round,
123
+ * if the set of nodes is meant to change or to be actualized.
124
+ */
125
+ resetNodes(): void;
126
+ /**
127
+ * Sets the aggregator's round number. To be used whenever the aggregator is out of sync
128
+ * with the network's round.
129
+ * @param round The new round
130
+ */
131
+ setRound(round: number): void;
132
+ /**
133
+ * Emits the event containing the aggregation result, which allows the result
134
+ * promise to resolve and for the next aggregation round to take place.
135
+ * @param aggregated The aggregation result
136
+ */
137
+ protected emit(aggregated: T): void;
138
+ /**
139
+ * Updates the aggregator's state to proceed to the next communication round.
140
+ * If all communication rounds were performed, proceeds to the next aggregation round
141
+ * and empties the collection of stored contributions.
142
+ */
143
+ nextRound(): void;
144
+ private makeResult;
145
+ /**
146
+ * Aggregation steps are performed asynchronously, yet can be awaited upon when required.
147
+ * This function gives access to the current aggregation result's promise, which will
148
+ * eventually resolve and contain the result of the very next aggregation step, at the
149
+ * time of the function call.
150
+ * @returns The promise containing the aggregation result
151
+ */
152
+ receiveResult(): Promise<T>;
153
+ /**
154
+ * Constructs the payloads sent to other nodes as contribution.
155
+ * @param base Object from which the payload is computed
156
+ */
157
+ abstract makePayloads(base: T): Map<client.NodeID, T>;
158
+ abstract isFull(): boolean;
159
+ /**
160
+ * The set of node ids, representing our neighbors within the network.
161
+ */
162
+ get nodes(): Set<client.NodeID>;
163
+ /**
164
+ * The aggregation round.
165
+ */
166
+ get round(): number;
167
+ /**
168
+ * The aggregator's current size, defined by its number of contributions. The size is bounded by
169
+ * the amount of all active nodes times the number of communication rounds.
170
+ */
171
+ get size(): number;
172
+ /**
173
+ * The aggregator's current model.
174
+ */
175
+ get model(): Model | undefined;
176
+ /**
177
+ * The current communication round.
178
+ */
179
+ get communicationRound(): number;
180
+ }
@@ -0,0 +1,236 @@
1
+ import { Map, Set } from 'immutable';
2
+ import { EventEmitter } from '../utils/event_emitter.js';
3
+ export var AggregationStep;
4
+ (function (AggregationStep) {
5
+ AggregationStep[AggregationStep["ADD"] = 0] = "ADD";
6
+ AggregationStep[AggregationStep["UPDATE"] = 1] = "UPDATE";
7
+ AggregationStep[AggregationStep["AGGREGATE"] = 2] = "AGGREGATE";
8
+ })(AggregationStep || (AggregationStep = {}));
9
+ /**
10
+ * Main, abstract, aggregator class whose role is to buffer contributions and to produce
11
+ * a result based off their aggregation, whenever some defined condition is met.
12
+ */
13
+ export class Base {
14
+ _model;
15
+ roundCutoff;
16
+ communicationRounds;
17
+ /**
18
+ * Contains the ids of all active nodes, i.e. members of the aggregation group at
19
+ * a given round. It is a subset of all the nodes available in the network.
20
+ */
21
+ _nodes;
22
+ /**
23
+ * Contains the contributions received from active nodes, accessible by node id.
24
+ * It defines the effective aggregation group, which is possibly a subset
25
+ * of all active nodes, depending on the aggregation scheme.
26
+ */
27
+ contributions;
28
+ /**
29
+ * Emits the aggregation event whenever an aggregation step is performed.
30
+ * Triggers the resolve of the result promise and the preparation for the
31
+ * next aggregation round.
32
+ */
33
+ eventEmitter = new EventEmitter();
34
+ informant;
35
+ /**
36
+ * The result promise which, on resolve, will contain the current aggregation result.
37
+ * This promise should be fetched by any object making use of an aggregator, in order
38
+ * to await upon aggregation.
39
+ */
40
+ result;
41
+ /**
42
+ * The current aggregation round, used for assessing whether a node contribution is recent enough
43
+ * or not.
44
+ */
45
+ _round = 0;
46
+ /**
47
+ * The current communication round. A single aggregation round is made of possibly multiple
48
+ * communication rounds. This makes the aggregator free to perform intermediate aggregation
49
+ * steps based off communication with its nodes. Overall, this allows for more complex
50
+ * aggregation schemes requiring an exchange of information between nodes before aggregating.
51
+ */
52
+ _communicationRound = 0;
53
+ constructor(
54
+ /**
55
+ * The Model whose weights are updated on aggregation.
56
+ */
57
+ _model,
58
+ /**
59
+ * The round cut-off for contributions.
60
+ */
61
+ roundCutoff = 0,
62
+ /**
63
+ * The number of communication rounds occurring during any given aggregation round.
64
+ */
65
+ communicationRounds = 1) {
66
+ this._model = _model;
67
+ this.roundCutoff = roundCutoff;
68
+ this.communicationRounds = communicationRounds;
69
+ this.contributions = Map();
70
+ this._nodes = Set();
71
+ // Make the initial result promise
72
+ this.result = this.makeResult();
73
+ // On every aggregation, update the object's state to match the current aggregation
74
+ // and communication rounds.
75
+ this.eventEmitter.on('aggregation', () => {
76
+ this.nextRound();
77
+ });
78
+ }
79
+ registerObserver(informant) {
80
+ this.informant = informant;
81
+ }
82
+ /**
83
+ * Returns whether the given round is recent enough, dependent on the
84
+ * aggregator's round cutoff.
85
+ * @param round The round
86
+ * @returns True if the round is recent enough, false otherwise
87
+ */
88
+ isWithinRoundCutoff(round) {
89
+ return this.round - round <= this.roundCutoff;
90
+ }
91
+ /**
92
+ * Logs useful messages during the various aggregation steps.
93
+ * @param step The aggregation step
94
+ * @param from The node which triggered the logging message
95
+ */
96
+ log(step, from) {
97
+ switch (step) {
98
+ case AggregationStep.ADD:
99
+ console.log(`> Adding contribution from node ${from ?? '"unknown"'} for round (${this.communicationRound}, ${this.round})`);
100
+ break;
101
+ case AggregationStep.UPDATE:
102
+ if (from === undefined) {
103
+ return;
104
+ }
105
+ console.log(`> Updating contribution from node ${from} for round (${this.communicationRound}, ${this.round})`);
106
+ break;
107
+ case AggregationStep.AGGREGATE:
108
+ console.log('*'.repeat(80));
109
+ console.log(`Buffer is full. Aggregating weights for round (${this.communicationRound}, ${this.round})\n`);
110
+ break;
111
+ default: {
112
+ const _ = step;
113
+ throw new Error('should never happen');
114
+ }
115
+ }
116
+ }
117
+ /**
118
+ * Sets the aggregator's TF.js model.
119
+ * @param model The new TF.js model
120
+ */
121
+ setModel(model) {
122
+ this._model = model;
123
+ }
124
+ /**
125
+ * Adds a node's id to the set of active nodes. A node represents an active neighbor
126
+ * peer/client within the network, whom we are communicating with during this aggregation
127
+ * round.
128
+ * @param nodeId The node to be added
129
+ */
130
+ registerNode(nodeId) {
131
+ if (!this.nodes.has(nodeId)) {
132
+ this._nodes = this._nodes.add(nodeId);
133
+ return true;
134
+ }
135
+ return false;
136
+ }
137
+ /**
138
+ * Overwrites the current set of active nodes with the given one. A node represents
139
+ * an active neighbor peer/client within the network, whom we are communicating with
140
+ * during this aggregation round.
141
+ * @param nodeIds The new set of nodes
142
+ */
143
+ setNodes(nodeIds) {
144
+ this._nodes = nodeIds;
145
+ }
146
+ /**
147
+ * Empties the current set of "nodes". Usually called at the end of an aggregation round,
148
+ * if the set of nodes is meant to change or to be actualized.
149
+ */
150
+ resetNodes() {
151
+ this._nodes = Set();
152
+ }
153
+ /**
154
+ * Sets the aggregator's round number. To be used whenever the aggregator is out of sync
155
+ * with the network's round.
156
+ * @param round The new round
157
+ */
158
+ setRound(round) {
159
+ if (round > this.round) {
160
+ this._round = round;
161
+ }
162
+ }
163
+ /**
164
+ * Emits the event containing the aggregation result, which allows the result
165
+ * promise to resolve and for the next aggregation round to take place.
166
+ * @param aggregated The aggregation result
167
+ */
168
+ emit(aggregated) {
169
+ this.eventEmitter.emit('aggregation', aggregated);
170
+ }
171
+ /**
172
+ * Updates the aggregator's state to proceed to the next communication round.
173
+ * If all communication rounds were performed, proceeds to the next aggregation round
174
+ * and empties the collection of stored contributions.
175
+ */
176
+ nextRound() {
177
+ if (++this._communicationRound === this.communicationRounds) {
178
+ this._communicationRound = 0;
179
+ this._round++;
180
+ this.contributions = Map();
181
+ }
182
+ this.result = this.makeResult();
183
+ this.informant?.update();
184
+ }
185
+ async makeResult() {
186
+ return await new Promise((resolve) => {
187
+ this.eventEmitter.once('aggregation', (w) => {
188
+ resolve(w);
189
+ });
190
+ });
191
+ }
192
+ /**
193
+ * Aggregation steps are performed asynchronously, yet can be awaited upon when required.
194
+ * This function gives access to the current aggregation result's promise, which will
195
+ * eventually resolve and contain the result of the very next aggregation step, at the
196
+ * time of the function call.
197
+ * @returns The promise containing the aggregation result
198
+ */
199
+ async receiveResult() {
200
+ return await this.result;
201
+ }
202
+ /**
203
+ * The set of node ids, representing our neighbors within the network.
204
+ */
205
+ get nodes() {
206
+ return this._nodes;
207
+ }
208
+ /**
209
+ * The aggregation round.
210
+ */
211
+ get round() {
212
+ return this._round;
213
+ }
214
+ /**
215
+ * The aggregator's current size, defined by its number of contributions. The size is bounded by
216
+ * the amount of all active nodes times the number of communication rounds.
217
+ */
218
+ get size() {
219
+ return this.contributions
220
+ .valueSeq()
221
+ .map((m) => m.size)
222
+ .reduce((totalSize, size) => totalSize + size) ?? 0;
223
+ }
224
+ /**
225
+ * The aggregator's current model.
226
+ */
227
+ get model() {
228
+ return this._model;
229
+ }
230
+ /**
231
+ * The current communication round.
232
+ */
233
+ get communicationRound() {
234
+ return this._communicationRound;
235
+ }
236
+ }
@@ -0,0 +1,16 @@
1
+ import type { Task } from '../index.js';
2
+ import { aggregator } from '../index.js';
3
+ /**
4
+ * Enumeration of the available types of aggregator.
5
+ */
6
+ export declare enum AggregatorChoice {
7
+ MEAN = 0, // getAggregator builds a MeanAggregator for this choice
8
+ SECURE = 1, // getAggregator builds a SecureAggregator, and only when the task scheme is 'decentralized'
9
+ BANDIT = 2 // declared, but getAggregator throws 'not implemented' for it
10
+ }
11
+ /**
12
+ * Provides the aggregator object adequate to the given task.
13
+ * @param task The task
14
+ * @returns The aggregator
15
+ */
16
+ export declare function getAggregator(task: Task): aggregator.Aggregator; // throws for BANDIT ('not implemented') and for SECURE when the task scheme is not 'decentralized'; any other value yields a MeanAggregator
@@ -0,0 +1,31 @@
1
+ import { aggregator } from '../index.js';
2
+ /**
3
+ * Enumeration of the available types of aggregator.
4
+ */
5
+ export var AggregatorChoice; // populated by the immediately-invoked function below
6
+ (function (AggregatorChoice) {
7
+ AggregatorChoice[AggregatorChoice["MEAN"] = 0] = "MEAN"; // each line stores both name -> number and number -> name
8
+ AggregatorChoice[AggregatorChoice["SECURE"] = 1] = "SECURE";
9
+ AggregatorChoice[AggregatorChoice["BANDIT"] = 2] = "BANDIT";
10
+ })(AggregatorChoice || (AggregatorChoice = {})); // creates the object on first evaluation, reuses it if already set
11
+ /**
12
+ * Provides the aggregator object adequate to the given task.
13
+ * @param task The task
14
+ * @returns The aggregator
15
+ */
16
+ export function getAggregator(task) {
17
+ const error = new Error('not implemented');
18
+ switch (task.trainingInformation.aggregator) {
19
+ case AggregatorChoice.MEAN:
20
+ return new aggregator.MeanAggregator();
21
+ case AggregatorChoice.BANDIT:
22
+ throw error;
23
+ case AggregatorChoice.SECURE:
24
+ if (task.trainingInformation.scheme !== 'decentralized') {
25
+ throw new Error('secure aggregation is currently supported for decentralized only');
26
+ }
27
+ return new aggregator.SecureAggregator();
28
+ default:
29
+ return new aggregator.MeanAggregator();
30
+ }
31
+ }
@@ -0,0 +1,7 @@
1
+ import type { WeightsContainer } from '../weights/index.js';
2
+ import type { Base } from './base.js';
3
+ export { Base as AggregatorBase, AggregationStep } from './base.js';
4
+ export { MeanAggregator } from './mean.js';
5
+ export { SecureAggregator } from './secure.js';
6
+ export { getAggregator, AggregatorChoice } from './get.js';
7
+ export type Aggregator = Base<WeightsContainer>; // Base instantiated with WeightsContainer, i.e. the type that both MeanAggregator and SecureAggregator are declared to extend
@@ -0,0 +1,4 @@
1
+ export { Base as AggregatorBase, AggregationStep } from './base.js';
2
+ export { MeanAggregator } from './mean.js';
3
+ export { SecureAggregator } from './secure.js';
4
+ export { getAggregator, AggregatorChoice } from './get.js';
@@ -0,0 +1,23 @@
1
+ import type { Map } from 'immutable';
2
+ import { Base as Aggregator } from './base.js';
3
+ import type { Model, WeightsContainer, client } from '../index.js';
4
+ /**
5
+ * Mean aggregator whose aggregation step consists in computing the mean of the received weights.
6
+ */
7
+ export declare class MeanAggregator extends Aggregator<WeightsContainer> {
8
+ /**
9
+ * The threshold t to fulfill to trigger an aggregation step. It can either be:
10
+ * - relative: 0 < t <= 1, thus requiring t * |nodes| contributions
11
+ * - absolute: t > 1, thus requiring t contributions
12
+ */
13
+ readonly threshold: number;
14
+ constructor(model?: Model, roundCutoff?: number, threshold?: number); // defaults in mean.js: roundCutoff = 0, threshold = 1; throws on a non-positive threshold or on a non-integer threshold above 1
15
+ /**
16
+ * Checks whether the contributions buffer is full, according to the set threshold.
17
+ * @returns Whether the contributions buffer is full
18
+ */
19
+ isFull(): boolean;
20
+ add(nodeId: client.NodeID, contribution: WeightsContainer, round: number): boolean; // true when the contribution was stored (known node and round within the cutoff), false otherwise; triggers aggregate() once isFull() holds
21
+ aggregate(): void; // averages the stored contributions, writes the result to the model's weights when a model is set, then emits it
22
+ makePayloads(weights: WeightsContainer): Map<client.NodeID, WeightsContainer>; // maps every known node to the same given weights
23
+ }
@@ -0,0 +1,69 @@
1
+ import { AggregationStep, Base as Aggregator } from './base.js';
2
+ import { aggregation } from '../index.js';
3
/**
 * Mean aggregator whose aggregation step consists in computing the mean of the received weights.
 *
 * Contributions are stored under communication round 0, keyed by node id. An aggregation
 * step is triggered from `add` as soon as `isFull` reports that the threshold is met.
 */
export class MeanAggregator extends Aggregator {
    /**
     * The threshold t to fulfill to trigger an aggregation step. It can either be:
     * - relative: 0 < t <= 1, thus requiring t * |nodes| contributions
     * - absolute: t > 1, thus requiring t contributions
     */
    threshold;
    /**
     * @param model Optional model whose weights are overwritten with each aggregation result
     * @param roundCutoff Forwarded to the base aggregator
     * @param threshold Relative (0 < t <= 1) or absolute (integer t > 1) amount of contributions
     * required to trigger an aggregation; defaults to 1, i.e. 100% of the nodes
     * @throws If the threshold is not a positive number, or is above 1 without being an integer
     */
    constructor(model, roundCutoff = 0, threshold = 1) {
        // The third argument is fixed to 1 (presumably the number of communication
        // rounds per aggregation round; see the base class constructor).
        super(model, roundCutoff, 1);
        // Written as a negation so that NaN is rejected as well: a NaN threshold
        // would otherwise be accepted and the buffer could never be considered full.
        if (!(threshold > 0)) {
            throw new Error('threshold must be positive');
        }
        // Thresholds greater than 1 are absolute contribution counts, hence integers.
        if (threshold > 1 && !Number.isInteger(threshold)) {
            throw new Error('absolute thresholds must be integers');
        }
        this.threshold = threshold;
    }
    /**
     * Checks whether the contributions buffer is full, according to the set threshold.
     * @returns Whether the contributions buffer is full
     */
    isFull() {
        const contribs = this.contributions.get(this.communicationRound);
        if (contribs === undefined) {
            return false;
        }
        // Both kinds of threshold count the contributions of the current round.
        // The size of the outer `contributions` map only counts communication
        // rounds, so it must not be compared against an absolute threshold.
        const required = this.threshold <= 1
            ? this.threshold * this.nodes.size
            : this.threshold;
        return contribs.size >= required;
    }
    /**
     * Stores the contribution of the given node and aggregates once the buffer is full.
     * A later contribution from the same node replaces the earlier one.
     * @param nodeId The contributing node
     * @param contribution The node's weights
     * @param round The aggregation round the contribution was made for
     * @returns Whether the contribution was accepted, i.e. the node is known and
     * the round is within the round cutoff
     */
    add(nodeId, contribution, round) {
        if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round)) {
            return false;
        }
        const step = this.contributions.hasIn([0, nodeId])
            ? AggregationStep.UPDATE
            : AggregationStep.ADD;
        this.log(step, nodeId);
        this.contributions = this.contributions.setIn([0, nodeId], contribution);
        this.informant?.update();
        if (this.isFull()) {
            this.aggregate();
        }
        return true;
    }
    /**
     * Averages the stored contributions, writes the average to the model (when one
     * is set) and emits it as the aggregation result.
     * @throws If no contribution has been stored yet
     */
    aggregate() {
        const contribs = this.contributions.get(0);
        if (contribs === undefined) {
            // Fail with an explicit message instead of handing undefined to the averaging helper.
            throw new Error('aggregating without any contribution');
        }
        this.log(AggregationStep.AGGREGATE);
        const result = aggregation.avg(contribs.values());
        if (this.model !== undefined) {
            this.model.weights = result;
        }
        this.emit(result);
    }
    /**
     * Builds the payloads to send out: every known node receives the same weights.
     * @param weights Our local weights
     * @returns The payload for each node id
     */
    makePayloads(weights) {
        // Communicate our local weights to every other node, be it a peer or a server
        return this.nodes.toMap().map(() => weights);
    }
}
@@ -0,0 +1,27 @@
1
+ import { Map, List } from 'immutable';
2
+ import { Base as Aggregator } from './base.js';
3
+ import type { Model, WeightsContainer, client } from '../index.js';
4
+ /**
5
+ * Aggregator implementing secure multi-party computation for decentralized learning.
6
+ * An aggregation consists of two communication rounds:
7
+ * - first, nodes communicate their secret shares to each other;
8
+ * - then, they sum their received shares and communicate the result.
9
+ * Finally, nodes are able to average the received partial sums to establish the aggregation result.
10
+ */
11
+ export declare class SecureAggregator extends Aggregator<WeightsContainer> {
12
+ private readonly maxShareValue; // bound on the random share values, per the generateRandomShare description below
13
+ constructor(model?: Model, maxShareValue?: number); // NOTE(review): both arguments are optional; their defaults live in secure.js — confirm there
14
+ aggregate(): void;
15
+ add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, communicationRound: number): boolean; // unlike MeanAggregator.add, also takes the communication round the contribution belongs to
16
+ isFull(): boolean;
17
+ makePayloads(weights: WeightsContainer): Map<client.NodeID, WeightsContainer>; // NOTE(review): presumably secret shares in the first communication round and partial sums in the second, per the class description — confirm in secure.js
18
+ /**
19
+ * Generate N additive shares that aggregate to the secret weights array, where N is the number of peers.
20
+ */
21
+ generateAllShares(secret: WeightsContainer): List<WeightsContainer>;
22
+ /**
23
+ * Generates one share in the same shape as the secret that is populated with values randomly chosen from
24
+ * a uniform distribution between (-maxShareValue, maxShareValue).
25
+ */
26
+ generateRandomShare(secret: WeightsContainer): WeightsContainer;
27
+ }