@epfml/discojs 2.0.0 → 2.1.2-p20240506085037.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (334) hide show
  1. package/dist/aggregator/base.d.ts +180 -0
  2. package/dist/aggregator/base.js +236 -0
  3. package/dist/aggregator/get.d.ts +16 -0
  4. package/dist/aggregator/get.js +31 -0
  5. package/dist/aggregator/index.d.ts +7 -0
  6. package/dist/aggregator/index.js +4 -0
  7. package/dist/aggregator/mean.d.ts +23 -0
  8. package/dist/aggregator/mean.js +69 -0
  9. package/dist/aggregator/secure.d.ts +27 -0
  10. package/dist/aggregator/secure.js +91 -0
  11. package/dist/async_informant.d.ts +15 -0
  12. package/dist/async_informant.js +42 -0
  13. package/dist/client/base.d.ts +76 -0
  14. package/dist/client/base.js +88 -0
  15. package/dist/client/decentralized/base.d.ts +32 -0
  16. package/dist/client/decentralized/base.js +192 -0
  17. package/dist/client/decentralized/index.d.ts +2 -0
  18. package/dist/client/decentralized/index.js +2 -0
  19. package/dist/client/decentralized/messages.d.ts +28 -0
  20. package/dist/client/decentralized/messages.js +44 -0
  21. package/dist/client/decentralized/peer.d.ts +40 -0
  22. package/dist/client/decentralized/peer.js +189 -0
  23. package/dist/client/decentralized/peer_pool.d.ts +12 -0
  24. package/dist/client/decentralized/peer_pool.js +44 -0
  25. package/dist/client/event_connection.d.ts +34 -0
  26. package/dist/client/event_connection.js +105 -0
  27. package/dist/client/federated/base.d.ts +54 -0
  28. package/dist/client/federated/base.js +151 -0
  29. package/dist/client/federated/index.d.ts +2 -0
  30. package/dist/client/federated/index.js +2 -0
  31. package/dist/client/federated/messages.d.ts +30 -0
  32. package/dist/client/federated/messages.js +24 -0
  33. package/dist/client/index.d.ts +8 -0
  34. package/dist/client/index.js +8 -0
  35. package/dist/client/local.d.ts +3 -0
  36. package/dist/client/local.js +3 -0
  37. package/dist/client/messages.d.ts +30 -0
  38. package/dist/client/messages.js +26 -0
  39. package/dist/client/types.d.ts +2 -0
  40. package/dist/client/types.js +4 -0
  41. package/dist/client/utils.d.ts +2 -0
  42. package/dist/client/utils.js +7 -0
  43. package/dist/dataset/data/data.d.ts +48 -0
  44. package/dist/dataset/data/data.js +72 -0
  45. package/dist/dataset/data/data_split.d.ts +8 -0
  46. package/dist/dataset/data/data_split.js +1 -0
  47. package/dist/dataset/data/image_data.d.ts +11 -0
  48. package/dist/dataset/data/image_data.js +38 -0
  49. package/dist/dataset/data/index.d.ts +6 -0
  50. package/dist/dataset/data/index.js +5 -0
  51. package/dist/dataset/data/preprocessing/base.d.ts +16 -0
  52. package/dist/dataset/data/preprocessing/base.js +1 -0
  53. package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +13 -0
  54. package/dist/dataset/data/preprocessing/image_preprocessing.js +40 -0
  55. package/dist/dataset/data/preprocessing/index.d.ts +4 -0
  56. package/dist/dataset/data/preprocessing/index.js +3 -0
  57. package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +13 -0
  58. package/dist/dataset/data/preprocessing/tabular_preprocessing.js +45 -0
  59. package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +13 -0
  60. package/dist/dataset/data/preprocessing/text_preprocessing.js +85 -0
  61. package/dist/dataset/data/tabular_data.d.ts +11 -0
  62. package/dist/dataset/data/tabular_data.js +25 -0
  63. package/dist/dataset/data/text_data.d.ts +11 -0
  64. package/dist/dataset/data/text_data.js +14 -0
  65. package/dist/{core/dataset → dataset}/data_loader/data_loader.d.ts +3 -5
  66. package/dist/dataset/data_loader/data_loader.js +2 -0
  67. package/dist/dataset/data_loader/image_loader.d.ts +20 -3
  68. package/dist/dataset/data_loader/image_loader.js +98 -23
  69. package/dist/dataset/data_loader/index.d.ts +5 -2
  70. package/dist/dataset/data_loader/index.js +4 -7
  71. package/dist/dataset/data_loader/tabular_loader.d.ts +34 -3
  72. package/dist/dataset/data_loader/tabular_loader.js +75 -15
  73. package/dist/dataset/data_loader/text_loader.d.ts +14 -0
  74. package/dist/dataset/data_loader/text_loader.js +25 -0
  75. package/dist/dataset/dataset.d.ts +5 -0
  76. package/dist/dataset/dataset.js +1 -0
  77. package/dist/dataset/dataset_builder.d.ts +60 -0
  78. package/dist/dataset/dataset_builder.js +142 -0
  79. package/dist/dataset/index.d.ts +5 -0
  80. package/dist/dataset/index.js +3 -0
  81. package/dist/default_tasks/cifar10/index.d.ts +2 -0
  82. package/dist/default_tasks/cifar10/index.js +60 -0
  83. package/dist/default_tasks/cifar10/model.d.ts +434 -0
  84. package/dist/default_tasks/cifar10/model.js +2385 -0
  85. package/dist/default_tasks/geotags/index.d.ts +2 -0
  86. package/dist/default_tasks/geotags/index.js +65 -0
  87. package/dist/default_tasks/geotags/model.d.ts +593 -0
  88. package/dist/default_tasks/geotags/model.js +4715 -0
  89. package/dist/default_tasks/index.d.ts +8 -0
  90. package/dist/default_tasks/index.js +8 -0
  91. package/dist/default_tasks/lus_covid.d.ts +2 -0
  92. package/dist/default_tasks/lus_covid.js +89 -0
  93. package/dist/default_tasks/mnist.d.ts +2 -0
  94. package/dist/default_tasks/mnist.js +61 -0
  95. package/dist/default_tasks/simple_face/index.d.ts +2 -0
  96. package/dist/default_tasks/simple_face/index.js +48 -0
  97. package/dist/default_tasks/simple_face/model.d.ts +513 -0
  98. package/dist/default_tasks/simple_face/model.js +4301 -0
  99. package/dist/default_tasks/skin_mnist.d.ts +2 -0
  100. package/dist/default_tasks/skin_mnist.js +80 -0
  101. package/dist/default_tasks/titanic.d.ts +2 -0
  102. package/dist/default_tasks/titanic.js +88 -0
  103. package/dist/default_tasks/wikitext.d.ts +2 -0
  104. package/dist/default_tasks/wikitext.js +38 -0
  105. package/dist/index.d.ts +18 -2
  106. package/dist/index.js +18 -6
  107. package/dist/{core/informant → informant}/graph_informant.d.ts +1 -1
  108. package/dist/informant/graph_informant.js +20 -0
  109. package/dist/informant/index.d.ts +1 -0
  110. package/dist/informant/index.js +1 -0
  111. package/dist/{core/logging → logging}/console_logger.d.ts +2 -2
  112. package/dist/logging/console_logger.js +22 -0
  113. package/dist/logging/index.d.ts +2 -0
  114. package/dist/logging/index.js +1 -0
  115. package/dist/{core/logging → logging}/logger.d.ts +3 -3
  116. package/dist/logging/logger.js +1 -0
  117. package/dist/memory/base.d.ts +119 -0
  118. package/dist/memory/base.js +9 -0
  119. package/dist/memory/empty.d.ts +20 -0
  120. package/dist/memory/empty.js +43 -0
  121. package/dist/memory/index.d.ts +3 -1
  122. package/dist/memory/index.js +3 -5
  123. package/dist/memory/model_type.d.ts +9 -0
  124. package/dist/memory/model_type.js +10 -0
  125. package/dist/{core/privacy.d.ts → privacy.d.ts} +1 -1
  126. package/dist/{core/privacy.js → privacy.js} +11 -16
  127. package/dist/serialization/index.d.ts +2 -0
  128. package/dist/serialization/index.js +2 -0
  129. package/dist/serialization/model.d.ts +5 -0
  130. package/dist/serialization/model.js +67 -0
  131. package/dist/{core/serialization → serialization}/weights.d.ts +2 -2
  132. package/dist/serialization/weights.js +37 -0
  133. package/dist/task/data_example.js +14 -0
  134. package/dist/task/digest.d.ts +5 -0
  135. package/dist/task/digest.js +14 -0
  136. package/dist/{core/task → task}/display_information.d.ts +5 -3
  137. package/dist/task/display_information.js +46 -0
  138. package/dist/task/index.d.ts +7 -0
  139. package/dist/task/index.js +5 -0
  140. package/dist/task/label_type.d.ts +9 -0
  141. package/dist/task/label_type.js +28 -0
  142. package/dist/task/summary.js +13 -0
  143. package/dist/task/task.d.ts +12 -0
  144. package/dist/task/task.js +22 -0
  145. package/dist/task/task_handler.d.ts +5 -0
  146. package/dist/task/task_handler.js +20 -0
  147. package/dist/task/task_provider.d.ts +5 -0
  148. package/dist/task/task_provider.js +1 -0
  149. package/dist/{core/task → task}/training_information.d.ts +9 -10
  150. package/dist/task/training_information.js +88 -0
  151. package/dist/training/disco.d.ts +40 -0
  152. package/dist/training/disco.js +107 -0
  153. package/dist/training/index.d.ts +2 -0
  154. package/dist/training/index.js +1 -0
  155. package/dist/training/trainer/distributed_trainer.d.ts +20 -0
  156. package/dist/training/trainer/distributed_trainer.js +36 -0
  157. package/dist/training/trainer/local_trainer.d.ts +12 -0
  158. package/dist/training/trainer/local_trainer.js +19 -0
  159. package/dist/training/trainer/trainer.d.ts +33 -0
  160. package/dist/training/trainer/trainer.js +52 -0
  161. package/dist/{core/training → training}/trainer/trainer_builder.d.ts +5 -7
  162. package/dist/training/trainer/trainer_builder.js +43 -0
  163. package/dist/types.d.ts +8 -0
  164. package/dist/types.js +1 -0
  165. package/dist/utils/event_emitter.d.ts +40 -0
  166. package/dist/utils/event_emitter.js +57 -0
  167. package/dist/validation/index.d.ts +1 -0
  168. package/dist/validation/index.js +1 -0
  169. package/dist/validation/validator.d.ts +28 -0
  170. package/dist/validation/validator.js +132 -0
  171. package/dist/weights/aggregation.d.ts +21 -0
  172. package/dist/weights/aggregation.js +44 -0
  173. package/dist/weights/index.d.ts +2 -0
  174. package/dist/weights/index.js +2 -0
  175. package/dist/weights/weights_container.d.ts +68 -0
  176. package/dist/weights/weights_container.js +96 -0
  177. package/package.json +25 -16
  178. package/README.md +0 -53
  179. package/dist/core/async_buffer.d.ts +0 -41
  180. package/dist/core/async_buffer.js +0 -97
  181. package/dist/core/async_informant.d.ts +0 -20
  182. package/dist/core/async_informant.js +0 -69
  183. package/dist/core/client/base.d.ts +0 -33
  184. package/dist/core/client/base.js +0 -35
  185. package/dist/core/client/decentralized/base.d.ts +0 -32
  186. package/dist/core/client/decentralized/base.js +0 -212
  187. package/dist/core/client/decentralized/clear_text.d.ts +0 -14
  188. package/dist/core/client/decentralized/clear_text.js +0 -96
  189. package/dist/core/client/decentralized/index.d.ts +0 -4
  190. package/dist/core/client/decentralized/index.js +0 -9
  191. package/dist/core/client/decentralized/messages.d.ts +0 -41
  192. package/dist/core/client/decentralized/messages.js +0 -54
  193. package/dist/core/client/decentralized/peer.d.ts +0 -26
  194. package/dist/core/client/decentralized/peer.js +0 -210
  195. package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
  196. package/dist/core/client/decentralized/peer_pool.js +0 -92
  197. package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
  198. package/dist/core/client/decentralized/sec_agg.js +0 -190
  199. package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
  200. package/dist/core/client/decentralized/secret_shares.js +0 -39
  201. package/dist/core/client/decentralized/types.d.ts +0 -2
  202. package/dist/core/client/decentralized/types.js +0 -7
  203. package/dist/core/client/event_connection.d.ts +0 -37
  204. package/dist/core/client/event_connection.js +0 -158
  205. package/dist/core/client/federated/client.d.ts +0 -37
  206. package/dist/core/client/federated/client.js +0 -273
  207. package/dist/core/client/federated/index.d.ts +0 -2
  208. package/dist/core/client/federated/index.js +0 -7
  209. package/dist/core/client/federated/messages.d.ts +0 -38
  210. package/dist/core/client/federated/messages.js +0 -25
  211. package/dist/core/client/index.d.ts +0 -5
  212. package/dist/core/client/index.js +0 -11
  213. package/dist/core/client/local.d.ts +0 -8
  214. package/dist/core/client/local.js +0 -36
  215. package/dist/core/client/messages.d.ts +0 -28
  216. package/dist/core/client/messages.js +0 -33
  217. package/dist/core/client/utils.d.ts +0 -2
  218. package/dist/core/client/utils.js +0 -19
  219. package/dist/core/dataset/data/data.d.ts +0 -11
  220. package/dist/core/dataset/data/data.js +0 -20
  221. package/dist/core/dataset/data/data_split.d.ts +0 -5
  222. package/dist/core/dataset/data/data_split.js +0 -2
  223. package/dist/core/dataset/data/image_data.d.ts +0 -8
  224. package/dist/core/dataset/data/image_data.js +0 -64
  225. package/dist/core/dataset/data/index.d.ts +0 -5
  226. package/dist/core/dataset/data/index.js +0 -11
  227. package/dist/core/dataset/data/preprocessing.d.ts +0 -13
  228. package/dist/core/dataset/data/preprocessing.js +0 -33
  229. package/dist/core/dataset/data/tabular_data.d.ts +0 -8
  230. package/dist/core/dataset/data/tabular_data.js +0 -40
  231. package/dist/core/dataset/data_loader/data_loader.js +0 -10
  232. package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
  233. package/dist/core/dataset/data_loader/image_loader.js +0 -141
  234. package/dist/core/dataset/data_loader/index.d.ts +0 -3
  235. package/dist/core/dataset/data_loader/index.js +0 -9
  236. package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
  237. package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
  238. package/dist/core/dataset/dataset.d.ts +0 -2
  239. package/dist/core/dataset/dataset.js +0 -2
  240. package/dist/core/dataset/dataset_builder.d.ts +0 -18
  241. package/dist/core/dataset/dataset_builder.js +0 -96
  242. package/dist/core/dataset/index.d.ts +0 -4
  243. package/dist/core/dataset/index.js +0 -14
  244. package/dist/core/index.d.ts +0 -18
  245. package/dist/core/index.js +0 -41
  246. package/dist/core/informant/graph_informant.js +0 -23
  247. package/dist/core/informant/index.d.ts +0 -3
  248. package/dist/core/informant/index.js +0 -9
  249. package/dist/core/informant/training_informant/base.d.ts +0 -31
  250. package/dist/core/informant/training_informant/base.js +0 -83
  251. package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
  252. package/dist/core/informant/training_informant/decentralized.js +0 -22
  253. package/dist/core/informant/training_informant/federated.d.ts +0 -14
  254. package/dist/core/informant/training_informant/federated.js +0 -32
  255. package/dist/core/informant/training_informant/index.d.ts +0 -4
  256. package/dist/core/informant/training_informant/index.js +0 -11
  257. package/dist/core/informant/training_informant/local.d.ts +0 -6
  258. package/dist/core/informant/training_informant/local.js +0 -20
  259. package/dist/core/logging/console_logger.js +0 -33
  260. package/dist/core/logging/index.d.ts +0 -3
  261. package/dist/core/logging/index.js +0 -9
  262. package/dist/core/logging/logger.js +0 -9
  263. package/dist/core/logging/trainer_logger.d.ts +0 -24
  264. package/dist/core/logging/trainer_logger.js +0 -59
  265. package/dist/core/memory/base.d.ts +0 -22
  266. package/dist/core/memory/base.js +0 -9
  267. package/dist/core/memory/empty.d.ts +0 -14
  268. package/dist/core/memory/empty.js +0 -75
  269. package/dist/core/memory/index.d.ts +0 -3
  270. package/dist/core/memory/index.js +0 -9
  271. package/dist/core/memory/model_type.d.ts +0 -4
  272. package/dist/core/memory/model_type.js +0 -9
  273. package/dist/core/serialization/index.d.ts +0 -2
  274. package/dist/core/serialization/index.js +0 -6
  275. package/dist/core/serialization/model.d.ts +0 -5
  276. package/dist/core/serialization/model.js +0 -55
  277. package/dist/core/serialization/weights.js +0 -64
  278. package/dist/core/task/data_example.js +0 -24
  279. package/dist/core/task/display_information.js +0 -49
  280. package/dist/core/task/index.d.ts +0 -3
  281. package/dist/core/task/index.js +0 -8
  282. package/dist/core/task/model_compile_data.d.ts +0 -6
  283. package/dist/core/task/model_compile_data.js +0 -22
  284. package/dist/core/task/summary.js +0 -19
  285. package/dist/core/task/task.d.ts +0 -10
  286. package/dist/core/task/task.js +0 -31
  287. package/dist/core/task/training_information.js +0 -66
  288. package/dist/core/tasks/cifar10.d.ts +0 -3
  289. package/dist/core/tasks/cifar10.js +0 -65
  290. package/dist/core/tasks/geotags.d.ts +0 -3
  291. package/dist/core/tasks/geotags.js +0 -67
  292. package/dist/core/tasks/index.d.ts +0 -6
  293. package/dist/core/tasks/index.js +0 -10
  294. package/dist/core/tasks/lus_covid.d.ts +0 -3
  295. package/dist/core/tasks/lus_covid.js +0 -87
  296. package/dist/core/tasks/mnist.d.ts +0 -3
  297. package/dist/core/tasks/mnist.js +0 -60
  298. package/dist/core/tasks/simple_face.d.ts +0 -2
  299. package/dist/core/tasks/simple_face.js +0 -41
  300. package/dist/core/tasks/titanic.d.ts +0 -3
  301. package/dist/core/tasks/titanic.js +0 -88
  302. package/dist/core/training/disco.d.ts +0 -23
  303. package/dist/core/training/disco.js +0 -130
  304. package/dist/core/training/index.d.ts +0 -2
  305. package/dist/core/training/index.js +0 -7
  306. package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
  307. package/dist/core/training/trainer/distributed_trainer.js +0 -65
  308. package/dist/core/training/trainer/local_trainer.d.ts +0 -11
  309. package/dist/core/training/trainer/local_trainer.js +0 -34
  310. package/dist/core/training/trainer/round_tracker.d.ts +0 -30
  311. package/dist/core/training/trainer/round_tracker.js +0 -47
  312. package/dist/core/training/trainer/trainer.d.ts +0 -65
  313. package/dist/core/training/trainer/trainer.js +0 -160
  314. package/dist/core/training/trainer/trainer_builder.js +0 -95
  315. package/dist/core/training/training_schemes.d.ts +0 -5
  316. package/dist/core/training/training_schemes.js +0 -10
  317. package/dist/core/types.d.ts +0 -4
  318. package/dist/core/types.js +0 -2
  319. package/dist/core/validation/index.d.ts +0 -1
  320. package/dist/core/validation/index.js +0 -5
  321. package/dist/core/validation/validator.d.ts +0 -17
  322. package/dist/core/validation/validator.js +0 -104
  323. package/dist/core/weights/aggregation.d.ts +0 -8
  324. package/dist/core/weights/aggregation.js +0 -96
  325. package/dist/core/weights/index.d.ts +0 -2
  326. package/dist/core/weights/index.js +0 -7
  327. package/dist/core/weights/weights_container.d.ts +0 -19
  328. package/dist/core/weights/weights_container.js +0 -64
  329. package/dist/imports.d.ts +0 -2
  330. package/dist/imports.js +0 -7
  331. package/dist/memory/memory.d.ts +0 -26
  332. package/dist/memory/memory.js +0 -160
  333. package/dist/{core/task → task}/data_example.d.ts +1 -1
  334. package/dist/{core/task → task}/summary.d.ts +1 -1
@@ -0,0 +1,107 @@
1
+ import { client as clients, EmptyMemory, ConsoleLogger } from '../index.js';
2
+ import { MeanAggregator } from '../aggregator/mean.js';
3
+ import { TrainerBuilder } from './trainer/trainer_builder.js';
4
+ /**
5
+ * Top-level class handling distributed training from a client's perspective. It is meant to be
6
+ * a convenient object providing a reduced yet complete API that wraps model training,
7
+ * communication with nodes, logs and model memory.
8
+ */
9
+ export class Disco {
10
+ task;
11
+ logger;
12
+ memory;
13
+ client;
14
+ trainer;
15
+ constructor(task, options) {
16
+ if (options.scheme === undefined) {
17
+ options.scheme = task.trainingInformation.scheme;
18
+ }
19
+ if (options.aggregator === undefined) {
20
+ options.aggregator = new MeanAggregator();
21
+ }
22
+ if (options.client === undefined) {
23
+ if (options.url === undefined) {
24
+ throw new Error('could not determine client from given parameters');
25
+ }
26
+ if (typeof options.url === 'string') {
27
+ options.url = new URL(options.url);
28
+ }
29
+ switch (options.scheme) {
30
+ case 'federated':
31
+ options.client = new clients.federated.FederatedClient(options.url, task, options.aggregator);
32
+ break;
33
+ case 'decentralized':
34
+ options.client = new clients.decentralized.DecentralizedClient(options.url, task, options.aggregator);
35
+ break;
36
+ case 'local':
37
+ options.client = new clients.Local(options.url, task, options.aggregator);
38
+ break;
39
+ default: {
40
+ const _ = options.scheme;
41
+ throw new Error('should never happen');
42
+ }
43
+ }
44
+ }
45
+ if (options.logger === undefined) {
46
+ options.logger = new ConsoleLogger();
47
+ }
48
+ if (options.memory === undefined) {
49
+ options.memory = new EmptyMemory();
50
+ }
51
+ if (options.client.task !== task) {
52
+ throw new Error('client not setup for given task');
53
+ }
54
+ this.task = task;
55
+ this.client = options.client;
56
+ this.memory = options.memory;
57
+ this.logger = options.logger;
58
+ const trainerBuilder = new TrainerBuilder(this.memory, this.task);
59
+ this.trainer = trainerBuilder.build(this.client, options.scheme !== 'local');
60
+ }
61
+ /**
62
+ * Starts a training instance for the Disco object's task on the provided data tuple.
63
+ * @param dataTuple The data tuple
64
+ */
65
+ // TODO RoundLogs should contain number of participants but Trainer doesn't need client
66
+ async *fit(dataTuple) {
67
+ this.logger.success("Training started.");
68
+ const trainData = dataTuple.train.preprocess().batch();
69
+ const validationData = dataTuple.validation?.preprocess().batch() ?? trainData;
70
+ await this.client.connect();
71
+ const trainer = await this.trainer;
72
+ for await (const roundLogs of trainer.fitModel(trainData.dataset, validationData.dataset)) {
73
+ let msg = `Round: ${roundLogs.round}\n`;
74
+ for (const epochLogs of roundLogs.epochs.values()) {
75
+ msg += ` Epoch: ${epochLogs.epoch}\n`;
76
+ msg += ` Training loss: ${epochLogs.training.loss}\n`;
77
+ if (epochLogs.training.accuracy !== undefined) {
78
+ msg += ` Training accuracy: ${epochLogs.training.accuracy}\n`;
79
+ }
80
+ if (epochLogs.validation !== undefined) {
81
+ msg += ` Validation loss: ${epochLogs.validation.loss}\n`;
82
+ msg += ` Validation accuracy: ${epochLogs.validation.accuracy}\n`;
83
+ }
84
+ }
85
+ this.logger.success(msg);
86
+ yield {
87
+ ...roundLogs,
88
+ participants: this.client.nodes.size + 1 // add ourself
89
+ };
90
+ }
91
+ this.logger.success("Training finished.");
92
+ }
93
+ /**
94
+ * Stops the ongoing training instance without disconnecting the client.
95
+ */
96
+ async pause() {
97
+ const trainer = await this.trainer;
98
+ await trainer.stopTraining();
99
+ }
100
+ /**
101
+ * Completely stops the ongoing training instance.
102
+ */
103
+ async close() {
104
+ await this.pause();
105
+ await this.client.disconnect();
106
+ }
107
+ }
@@ -0,0 +1,2 @@
1
+ export { Disco } from './disco.js';
2
+ export { RoundLogs } from './trainer/trainer.js';
@@ -0,0 +1 @@
1
+ export { Disco } from './disco.js';
@@ -0,0 +1,20 @@
1
+ import type { Model, Memory, Task, client as clients } from "../../index.js";
2
+ import { Trainer } from "./trainer.js";
3
+ /**
4
+ * Class whose role is to train a model in a distributed way with a given dataset.
5
+ */
6
+ export declare class DistributedTrainer extends Trainer {
7
+ private readonly task;
8
+ private readonly memory;
9
+ private readonly client;
10
+ private readonly aggregator;
11
+ /**
12
+ * DistributedTrainer constructor, accepts same arguments as Trainer and in additional also a client who takes care of communicating weights.
13
+ */
14
+ constructor(task: Task, memory: Memory, model: Model, client: clients.Client);
15
+ onRoundBegin(round: number): Promise<void>;
16
+ /**
17
+ * Callback called every time a round is over
18
+ */
19
+ onRoundEnd(round: number): Promise<void>;
20
+ }
@@ -0,0 +1,36 @@
1
+ import { Trainer } from "./trainer.js";
2
+ /**
3
+ * Class whose role is to train a model in a distributed way with a given dataset.
4
+ */
5
+ export class DistributedTrainer extends Trainer {
6
+ task;
7
+ memory;
8
+ client;
9
+ aggregator;
10
+ /**
11
+ * DistributedTrainer constructor, accepts same arguments as Trainer and in additional also a client who takes care of communicating weights.
12
+ */
13
+ constructor(task, memory, model, client) {
14
+ super(task, model);
15
+ this.task = task;
16
+ this.memory = memory;
17
+ this.client = client;
18
+ this.aggregator = this.client.aggregator;
19
+ this.aggregator.setModel(model);
20
+ }
21
+ async onRoundBegin(round) {
22
+ await this.client.onRoundBeginCommunication(this.model.weights, round);
23
+ }
24
+ /**
25
+ * Callback called every time a round is over
26
+ */
27
+ async onRoundEnd(round) {
28
+ await this.client.onRoundEndCommunication(this.model.weights, round);
29
+ if (this.aggregator.model !== undefined) {
30
+ // The aggregator's own aggregation is async. The trainer updates its model to match the aggregator's
31
+ // after it has completed a round of training.
32
+ this.model.weights = this.aggregator.model.weights;
33
+ }
34
+ await this.memory.updateWorkingModel({ taskID: this.task.id, name: this.task.trainingInformation.modelID }, this.model);
35
+ }
36
+ }
@@ -0,0 +1,12 @@
1
+ import type { Memory, Model, Task } from "../../index.js";
2
+ import { Trainer } from "./trainer.js";
3
+ /** Class whose role is to locally (alone) train a model on a given dataset,
4
+ * without any collaborators.
5
+ */
6
+ export declare class LocalTrainer extends Trainer {
7
+ private readonly task;
8
+ private readonly memory;
9
+ constructor(task: Task, memory: Memory, model: Model);
10
+ onRoundBegin(): Promise<void>;
11
+ onRoundEnd(): Promise<void>;
12
+ }
@@ -0,0 +1,19 @@
1
+ import { Trainer } from "./trainer.js";
2
+ /** Class whose role is to locally (alone) train a model on a given dataset,
3
+ * without any collaborators.
4
+ */
5
+ export class LocalTrainer extends Trainer {
6
+ task;
7
+ memory;
8
+ constructor(task, memory, model) {
9
+ super(task, model);
10
+ this.task = task;
11
+ this.memory = memory;
12
+ }
13
+ async onRoundBegin() {
14
+ return await Promise.resolve();
15
+ }
16
+ async onRoundEnd() {
17
+ await this.memory.updateWorkingModel({ taskID: this.task.id, name: this.task.trainingInformation.modelID }, this.model);
18
+ }
19
+ }
@@ -0,0 +1,33 @@
1
+ import type tf from "@tensorflow/tfjs";
2
+ import { List } from "immutable";
3
+ import type { Model, Task } from "../../index.js";
4
+ import { EpochLogs } from "../../models/model.js";
5
+ export interface RoundLogs {
6
+ round: number;
7
+ epochs: List<EpochLogs>;
8
+ }
9
+ /** Abstract class whose role is to train a model with a given dataset. This can be either done
10
+ * locally (alone) or in a distributed way with collaborators.
11
+ *
12
+ * 1. Call `fitModel(dataset)` to start training.
13
+ * 2. which will then call onRoundEnd once the round has ended.
14
+ *
15
+ * The onRoundEnd needs to be implemented to specify what actions to do when the round has ended, such as a communication step with collaborators.
16
+ */
17
+ export declare abstract class Trainer {
18
+ #private;
19
+ readonly model: Model;
20
+ private training?;
21
+ constructor(task: Task, model: Model);
22
+ protected abstract onRoundBegin(round: number): Promise<void>;
23
+ protected abstract onRoundEnd(round: number): Promise<void>;
24
+ /**
25
+ * Request stop training to be used from the Disco instance or any class that is taking care of the trainer.
26
+ */
27
+ stopTraining(): Promise<void>;
28
+ /**
29
+ * Start training the model with the given dataset
30
+ * @param dataset
31
+ */
32
+ fitModel(dataset: tf.data.Dataset<tf.TensorContainer>, valDataset: tf.data.Dataset<tf.TensorContainer>): AsyncGenerator<RoundLogs>;
33
+ }
@@ -0,0 +1,52 @@
1
+ import { List } from "immutable";
2
+ /** Abstract class whose role is to train a model with a given dataset. This can be either done
3
+ * locally (alone) or in a distributed way with collaborators.
4
+ *
5
+ * 1. Call `fitModel(dataset)` to start training.
6
+ * 2. which will then call onRoundEnd once the round has ended.
7
+ *
8
+ * The onRoundEnd needs to be implemented to specify what actions to do when the round has ended, such as a communication step with collaborators.
9
+ */
10
+ export class Trainer {
11
+ model;
12
+ #roundDuration;
13
+ #epochs;
14
+ training;
15
+ constructor(task, model) {
16
+ this.model = model;
17
+ this.#roundDuration = task.trainingInformation.roundDuration;
18
+ this.#epochs = task.trainingInformation.epochs;
19
+ }
20
+ /**
21
+ * Request stop training to be used from the Disco instance or any class that is taking care of the trainer.
22
+ */
23
+ async stopTraining() {
24
+ await this.training?.return();
25
+ }
26
+ /**
27
+ * Start training the model with the given dataset
28
+ * @param dataset
29
+ */
30
+ async *fitModel(dataset, valDataset) {
31
+ if (this.training !== undefined) {
32
+ throw new Error("training already running, cancel it before launching a new one");
33
+ }
34
+ await this.onRoundBegin(0);
35
+ this.training = this.model.train(dataset, valDataset, this.#epochs);
36
+ for await (const logs of this.training) {
37
+ // for now, round (sharing on network) == epoch (full pass over local data)
38
+ yield {
39
+ round: logs.epoch,
40
+ epochs: List.of(logs),
41
+ };
42
+ if (logs.epoch % this.#roundDuration === 0) {
43
+ const round = Math.trunc(logs.epoch / this.#roundDuration);
44
+ await this.onRoundEnd(round);
45
+ await this.onRoundBegin(round);
46
+ }
47
+ }
48
+ const round = Math.trunc(this.#epochs / this.#roundDuration);
49
+ await this.onRoundEnd(round);
50
+ this.training = undefined;
51
+ }
52
+ }
@@ -1,13 +1,12 @@
1
- import { Client, Task, TrainingInformant, Memory } from '../..';
2
- import { Trainer } from './trainer';
1
+ import type { client as clients, Task, Memory } from '../../index.js';
2
+ import type { Trainer } from './trainer.js';
3
3
  /**
4
4
  * A class that helps build the Trainer and auxiliary classes.
5
5
  */
6
6
  export declare class TrainerBuilder {
7
7
  private readonly memory;
8
8
  private readonly task;
9
- private readonly trainingInformant;
10
- constructor(memory: Memory, task: Task, trainingInformant: TrainingInformant);
9
+ constructor(memory: Memory, task: Task);
11
10
  /**
12
11
  * Builds a trainer object.
13
12
  *
@@ -15,11 +14,10 @@ export declare class TrainerBuilder {
15
14
  * @param distributed whether to build a distributed or local trainer
16
15
  * @returns
17
16
  */
18
- build(client: Client, distributed?: boolean): Promise<Trainer>;
17
+ build(client: clients.Client, distributed?: boolean): Promise<Trainer>;
19
18
  /**
20
- * If a model exists in memory, laod it, otherwise load model from server
19
+ * If a model exists in memory, load it, otherwise load model from server
21
20
  * @returns
22
21
  */
23
22
  private getModel;
24
- private updateModelInformation;
25
23
  }
@@ -0,0 +1,43 @@
1
+ import { ModelType } from '../../index.js';
2
+ import { DistributedTrainer } from './distributed_trainer.js';
3
+ import { LocalTrainer } from './local_trainer.js';
4
+ /**
5
+ * A class that helps build the Trainer and auxiliary classes.
6
+ */
7
+ export class TrainerBuilder {
8
+ memory;
9
+ task;
10
+ constructor(memory, task) {
11
+ this.memory = memory;
12
+ this.task = task;
13
+ }
14
+ /**
15
+ * Builds a trainer object.
16
+ *
17
+ * @param client client to share weights with (either distributed or federated)
18
+ * @param distributed whether to build a distributed or local trainer
19
+ * @returns
20
+ */
21
+ async build(client, distributed = false) {
22
+ const model = await this.getModel(client);
23
+ if (distributed) {
24
+ return new DistributedTrainer(this.task, this.memory, model, client);
25
+ }
26
+ else {
27
+ return new LocalTrainer(this.task, this.memory, model);
28
+ }
29
+ }
30
+ /**
31
+ * If a model exists in memory, load it, otherwise load model from server
32
+ * @returns
33
+ */
34
+ async getModel(client) {
35
+ const modelID = this.task.trainingInformation?.modelID;
36
+ if (modelID === undefined) {
37
+ throw new TypeError('model ID is undefined');
38
+ }
39
+ const info = { type: ModelType.WORKING, taskID: this.task.id, name: modelID };
40
+ const model = await (await this.memory.contains(info) ? this.memory.getModel(info) : client.getLatestModel());
41
+ return model;
42
+ }
43
+ }
@@ -0,0 +1,8 @@
1
+ import type { Map } from 'immutable';
2
+ import type { WeightsContainer } from './index.js';
3
+ import type { NodeID } from './client/index.js';
4
/** String alias naming a path. */
export type Path = string;
/** String alias naming a metadata key. */
export type MetadataKey = string;
/** String alias naming a metadata value. */
export type MetadataValue = string;
/** A single number, or numbers nested in arrays up to five levels deep. */
export type Features =
    | number
    | number[]
    | number[][]
    | number[][][]
    | number[][][][]
    | number[][][][][];
/** Immutable map associating each node ID with a weights container. */
export type Contributions = Map<NodeID, WeightsContainer>;
package/dist/types.js ADDED
@@ -0,0 +1 @@
1
+ export {};
@@ -0,0 +1,40 @@
1
/** Handler invoked with the value emitted for an event. */
type Listener<T> = (_: T) => void;
/**
 * Dispatches emitted values to the handlers registered for each event name.
 *
 * @typeParam I object/mapping from event name to emitted value type
 */
export declare class EventEmitter<I extends Record<string, unknown>> {
    private listeners;
    /**
     * @param initialListeners handlers keyed by event name; each one is registered
     *   exactly as if `on` had been called on the new instance
     */
    constructor(initialListeners?: {
        [E in keyof I]?: Listener<I[E]>;
    });
    /**
     * Adds a handler that is called on every emission of the event.
     *
     * @param event name of the event to subscribe to
     * @param listener handler receiving each emitted value
     */
    on<E extends keyof I>(event: E, listener: Listener<I[E]>): void;
    /**
     * Adds a handler that is called only for the next emission of the event.
     *
     * @param event name of the event to subscribe to
     * @param listener handler receiving the next emitted value
     */
    once<E extends keyof I>(event: E, listener: Listener<I[E]>): void;
    /**
     * Calls every handler registered for the event with the given value.
     *
     * @param event name of the event whose handlers are called
     * @param value argument passed to each handler
     */
    emit<E extends keyof I>(event: E, value: I[E]): void;
}
/** `EventEmitter` accepting any event name with values of unknown type */
export declare class Sink extends EventEmitter<Record<string, unknown>> {
}
40
+ export {};
@@ -0,0 +1,57 @@
1
+ // inspired by https://danilafe.com/blog/typescript_typesafe_events/
2
+ import { List } from 'immutable';
3
+ /**
4
+ * Call handlers on given events
5
+ *
6
+ * @typeParam I object/mapping from event name to emitted value type
7
+ */
8
+ export class EventEmitter {
9
+ listeners = {};
10
+ /**
11
+ * @param initialListeners object/mapping of event name to listener, as if using `on` on created instance
12
+ */
13
+ constructor(initialListeners = {}) {
14
+ for (const event in initialListeners) {
15
+ const listener = initialListeners[event];
16
+ if (listener !== undefined) {
17
+ this.on(event, listener);
18
+ }
19
+ }
20
+ }
21
+ /**
22
+ * Register listener to call on event
23
+ *
24
+ * @param event event name to listen to
25
+ * @param listener handler to call
26
+ */
27
+ on(event, listener) {
28
+ const eventListeners = this.listeners[event] ?? List();
29
+ this.listeners[event] = eventListeners.push([false, listener]);
30
+ }
31
+ /**
32
+ * Register listener to call once on next event
33
+ *
34
+ * @param event event name to listen to
35
+ * @param listener handler to call next time
36
+ */
37
+ once(event, listener) {
38
+ const eventListeners = this.listeners[event] ?? List();
39
+ this.listeners[event] = eventListeners.push([true, listener]);
40
+ }
41
+ /**
42
+ * Send value to registered listeners of event name
43
+ *
44
+ * @param event send to listeners of event name
45
+ * @param value what to call listeners with
46
+ */
47
+ emit(event, value) {
48
+ const eventListeners = this.listeners[event] ?? List();
49
+ this.listeners[event] = eventListeners.filterNot(([once]) => once);
50
+ eventListeners.forEach(([_, listener]) => {
51
+ listener(value);
52
+ });
53
+ }
54
+ }
55
+ /** `EventEmitter` for all events */
56
+ export class Sink extends EventEmitter {
57
+ }
@@ -0,0 +1 @@
1
+ export { Validator } from './validator.js';
@@ -0,0 +1 @@
1
+ export { Validator } from './validator.js';
@@ -0,0 +1,28 @@
1
+ import { List } from 'immutable';
2
+ import type { data, Model, Task, Logger, client as clients, Memory, ModelSource, Features } from '../index.js';
3
/**
 * Runs a task's model over datasets, either to score it against labels
 * (`assess`) or to obtain predictions only (`predict`).
 */
export declare class Validator {
    readonly task: Task;
    readonly logger: Logger;
    private readonly memory;
    private readonly source?;
    private readonly client?;
    private readonly graphInformant;
    private size;
    private _confusionMatrix;
    /**
     * @param task task providing the training information (batch size)
     * @param logger receives the summary messages written by `assess`
     * @param memory model storage queried when `source` is given
     * @param source description of the stored model to load, if any
     * @param client fallback provider of the latest model, if any
     * @throws Error when neither `source` nor `client` is given
     */
    constructor(task: Task, logger: Logger, memory: Memory, source?: ModelSource | undefined, client?: clients.Client | undefined);
    private getLabel;
    /**
     * Predicts every batch of a labelled dataset and compares the result with the labels.
     *
     * @param data labelled dataset to evaluate
     * @param useConfusionMatrix whether to also compute the confusion matrix
     * @returns one entry per sample: its label, its prediction and its features
     */
    assess(data: data.Data, useConfusionMatrix?: boolean): Promise<Array<{
        groundTruth: number;
        pred: number;
        features: Features;
    }>>;
    /**
     * Predicts every batch of a dataset.
     *
     * @param data dataset to predict
     * @returns one entry per sample: its features and its prediction
     */
    predict(data: data.Data): Promise<Array<{
        features: Features;
        pred: number;
    }>>;
    /**
     * Loads the model from memory when `source` is stored there,
     * otherwise requests the latest model from the client.
     */
    getModel(): Promise<Model>;
    /** Data exposed by the internal graph informant. */
    get accuracyData(): List<number>;
    /** Accuracy exposed by the internal graph informant. */
    get accuracy(): number;
    /** Number of samples visited by `assess` so far. */
    get visitedSamples(): number;
    /** Confusion matrix of the last `assess` call that requested one. */
    get confusionMatrix(): number[][] | undefined;
}
@@ -0,0 +1,132 @@
1
+ import { List } from 'immutable';
2
+ import * as tf from '@tensorflow/tfjs';
3
+ import { GraphInformant } from '../index.js';
4
/**
 * Runs a task's model over datasets, either to score it against labels
 * (`assess`) or to obtain predictions only (`predict`).
 */
export class Validator {
    task;
    logger;
    memory;
    source;
    client;
    graphInformant = new GraphInformant();
    // Total number of samples visited by `assess`, accumulated across calls.
    size = 0;
    _confusionMatrix;
    /**
     * @param task task providing the training information (batch size)
     * @param logger receives the summary messages written by `assess`
     * @param memory model storage queried when `source` is given
     * @param source description of the stored model to load, if any
     * @param client fallback provider of the latest model, if any
     * @throws Error when neither `source` nor `client` is given
     */
    constructor(task, logger, memory, source, client) {
        this.task = task;
        this.logger = logger;
        this.memory = memory;
        this.source = source;
        this.client = client;
        if (source === undefined && client === undefined) {
            throw new Error('To initialize a Validator, either or both a source and client need to be specified');
        }
    }
    /**
     * Reduces a batch of model outputs or labels to one class index per row.
     *
     * @param ys tensor whose second dimension is 1 (thresholded at 0.5) or 2 (argmax)
     * @returns typed array holding one label per row
     * @throws Error when the second dimension is neither 1 nor 2
     */
    async getLabel(ys) {
        switch (ys.shape[1]) {
            case 1: {
                const threshold = tf.scalar(0.5);
                const aboveThreshold = ys.greaterEqual(threshold);
                try {
                    return await aboveThreshold.data();
                }
                finally {
                    // Intermediate tensors are not garbage collected; release them explicitly.
                    tf.dispose([threshold, aboveThreshold]);
                }
            }
            case 2: {
                const classIndices = ys.argMax(1);
                try {
                    return await classIndices.data();
                }
                finally {
                    classIndices.dispose();
                }
            }
            default:
                throw new Error(`unable to reduce tensor of shape: ${ys.shape.toString()}`);
        }
    }
    /**
     * Predicts every batch of a labelled dataset and compares the result with the labels.
     * The running accuracy is reported to the graph informant after each batch.
     *
     * @param data labelled dataset whose batches expose `xs` and `ys`
     * @param useConfusionMatrix whether to also compute the confusion matrix
     * @returns one entry per sample: its label, its prediction and its features
     * @throws TypeError when the batch size is undefined or features are not arrays
     * @throws Error when a batch lacks `xs`/`ys` or the confusion matrix cannot be computed
     */
    async assess(data, useConfusionMatrix = false) {
        const batchSize = this.task.trainingInformation?.batchSize;
        if (batchSize === undefined) {
            throw new TypeError('Batch size is undefined');
        }
        const model = await this.getModel();
        const features = [];
        const groundTruth = [];
        let hits = 0;
        // Samples seen during THIS call. `this.size` keeps growing across calls,
        // so dividing the per-call `hits` by it would under-report the accuracy
        // from the second call onwards.
        let visited = 0;
        // Get model predictions per batch and flatten the result
        // Also build the features and ground truth arrays
        const predictions = (await data.preprocess().dataset.batch(batchSize)
            .mapAsync(async (e) => {
            if (typeof e !== 'object' || e === null || !('xs' in e) || !('ys' in e)) {
                throw new Error('Input data is missing a feature or the label');
            }
            const xs = e.xs;
            const ys = await this.getLabel(e.ys);
            // NOTE(review): the tensor returned by model.predict is not disposed here
            // because its ownership is not visible from this file; confirm and release if owned.
            const pred = await this.getLabel(await model.predict(xs));
            const currentFeatures = await xs.array();
            if (!Array.isArray(currentFeatures)) {
                throw new TypeError('Data format is incorrect');
            }
            // Append in place rather than re-copying the whole array for every batch.
            features.push(...currentFeatures);
            groundTruth.push(...Array.from(ys));
            this.size += xs.shape[0];
            visited += xs.shape[0];
            hits += List(pred).zip(List(ys)).filter(([p, y]) => p === y).size;
            const currentAccuracy = hits / visited;
            this.graphInformant.updateAccuracy(currentAccuracy);
            return Array.from(pred);
        }).toArray()).flat();
        this.logger.success(`Obtained validation accuracy of ${this.accuracy}`);
        this.logger.success(`Visited ${this.visitedSamples} samples`);
        if (useConfusionMatrix) {
            try {
                // numClasses must exceed every label; start from 1 so that the
                // matrix is at least 2x2 (the smallest case getLabel can produce).
                const largestLabel = groundTruth.concat(predictions)
                    .reduce((max, label) => Math.max(max, label), 1);
                // tidy releases the label tensors and the matrix once converted to an array
                this._confusionMatrix = tf.tidy(() => tf.math.confusionMatrix(tf.tensor1d(groundTruth, 'int32'), tf.tensor1d(predictions, 'int32'), largestLabel + 1).arraySync());
            }
            catch (e) {
                console.error(e instanceof Error ? e.message : e);
                throw new Error('Failed to compute the confusion matrix');
            }
        }
        return List(groundTruth)
            .zip(List(predictions), List(features))
            .map(([gt, p, f]) => ({ groundTruth: gt, pred: p, features: f }))
            .toArray();
    }
    /**
     * Predicts every batch of a dataset.
     *
     * @param data dataset whose batches are feature tensors
     * @returns one entry per sample: its features and its prediction
     * @throws TypeError when the batch size is undefined or features are not arrays
     */
    async predict(data) {
        const batchSize = this.task.trainingInformation?.batchSize;
        if (batchSize === undefined) {
            throw new TypeError('Batch size is undefined');
        }
        const model = await this.getModel();
        const features = [];
        // Get model prediction per batch and flatten the result
        // Also incrementally build the features array
        const predictions = (await data.preprocess().dataset.batch(batchSize)
            .mapAsync(async (e) => {
            const xs = e;
            const currentFeatures = await xs.array();
            if (!Array.isArray(currentFeatures)) {
                throw new TypeError('Data format is incorrect');
            }
            // Append in place rather than re-copying the whole array for every batch.
            features.push(...currentFeatures);
            const pred = await this.getLabel(await model.predict(xs));
            return Array.from(pred);
        }).toArray()).flat();
        return List(features).zip(List(predictions))
            .map(([f, p]) => ({ features: f, pred: p }))
            .toArray();
    }
    /**
     * Loads the model from memory when `source` is stored there,
     * otherwise requests the latest model from the client.
     *
     * @throws Error when no model could be obtained from either
     */
    async getModel() {
        if (this.source !== undefined && await this.memory.contains(this.source)) {
            return await this.memory.getModel(this.source);
        }
        if (this.client !== undefined) {
            return await this.client.getLatestModel();
        }
        throw new Error('Could not load the model');
    }
    /** Data exposed by the internal graph informant. */
    get accuracyData() {
        return this.graphInformant.data();
    }
    /** Accuracy exposed by the internal graph informant. */
    get accuracy() {
        return this.graphInformant.accuracy();
    }
    /** Number of samples visited by `assess` so far, across all calls. */
    get visitedSamples() {
        return this.size;
    }
    /** Confusion matrix of the last `assess` call that requested one. */
    get confusionMatrix() {
        return this._confusionMatrix;
    }
}
@@ -0,0 +1,21 @@
1
+ import type { TensorLike } from './weights_container.js';
2
+ import { WeightsContainer } from './weights_container.js';
3
/** Any iterable of tensor-like values, accepted in place of a `WeightsContainer`. */
type WeightsLike = Iterable<TensorLike>;
/**
 * Adds up, entry by entry, all the weights of the given iterable.
 * @param weights The weights to add together
 * @returns The entry-wise sum
 */
export declare function sum(weights: Iterable<WeightsLike | WeightsContainer>): WeightsContainer;
/**
 * Subtracts, entry by entry, the successive weights of the given iterable.
 * The result depends on the iteration order, as the operation is not commutative.
 * @param weights The weights to subtract successively
 * @returns The entry-wise difference
 */
export declare function diff(weights: Iterable<WeightsLike | WeightsContainer>): WeightsContainer;
/**
 * Takes, entry by entry, the average of all the weights of the given iterable.
 * @param weights The weights to average
 * @returns The entry-wise average
 */
export declare function avg(weights: Iterable<WeightsLike | WeightsContainer>): WeightsContainer;
21
+ export {};