@epfml/discojs 2.1.1 → 2.1.2-p20240506085037.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (334) hide show
  1. package/dist/aggregator/base.d.ts +180 -0
  2. package/dist/aggregator/base.js +236 -0
  3. package/dist/aggregator/get.d.ts +16 -0
  4. package/dist/aggregator/get.js +31 -0
  5. package/dist/aggregator/index.d.ts +7 -0
  6. package/dist/aggregator/index.js +4 -0
  7. package/dist/aggregator/mean.d.ts +23 -0
  8. package/dist/aggregator/mean.js +69 -0
  9. package/dist/aggregator/secure.d.ts +27 -0
  10. package/dist/aggregator/secure.js +91 -0
  11. package/dist/async_informant.d.ts +15 -0
  12. package/dist/async_informant.js +42 -0
  13. package/dist/client/base.d.ts +76 -0
  14. package/dist/client/base.js +88 -0
  15. package/dist/client/decentralized/base.d.ts +32 -0
  16. package/dist/client/decentralized/base.js +192 -0
  17. package/dist/client/decentralized/index.d.ts +2 -0
  18. package/dist/client/decentralized/index.js +2 -0
  19. package/dist/client/decentralized/messages.d.ts +28 -0
  20. package/dist/client/decentralized/messages.js +44 -0
  21. package/dist/client/decentralized/peer.d.ts +40 -0
  22. package/dist/client/decentralized/peer.js +189 -0
  23. package/dist/client/decentralized/peer_pool.d.ts +12 -0
  24. package/dist/client/decentralized/peer_pool.js +44 -0
  25. package/dist/client/event_connection.d.ts +34 -0
  26. package/dist/client/event_connection.js +105 -0
  27. package/dist/client/federated/base.d.ts +54 -0
  28. package/dist/client/federated/base.js +151 -0
  29. package/dist/client/federated/index.d.ts +2 -0
  30. package/dist/client/federated/index.js +2 -0
  31. package/dist/client/federated/messages.d.ts +30 -0
  32. package/dist/client/federated/messages.js +24 -0
  33. package/dist/client/index.d.ts +8 -0
  34. package/dist/client/index.js +8 -0
  35. package/dist/client/local.d.ts +3 -0
  36. package/dist/client/local.js +3 -0
  37. package/dist/client/messages.d.ts +30 -0
  38. package/dist/client/messages.js +26 -0
  39. package/dist/client/types.d.ts +2 -0
  40. package/dist/client/types.js +4 -0
  41. package/dist/client/utils.d.ts +2 -0
  42. package/dist/client/utils.js +7 -0
  43. package/dist/dataset/data/data.d.ts +48 -0
  44. package/dist/dataset/data/data.js +72 -0
  45. package/dist/dataset/data/data_split.d.ts +8 -0
  46. package/dist/dataset/data/data_split.js +1 -0
  47. package/dist/dataset/data/image_data.d.ts +11 -0
  48. package/dist/dataset/data/image_data.js +38 -0
  49. package/dist/dataset/data/index.d.ts +6 -0
  50. package/dist/dataset/data/index.js +5 -0
  51. package/dist/dataset/data/preprocessing/base.d.ts +16 -0
  52. package/dist/dataset/data/preprocessing/base.js +1 -0
  53. package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +13 -0
  54. package/dist/dataset/data/preprocessing/image_preprocessing.js +40 -0
  55. package/dist/dataset/data/preprocessing/index.d.ts +4 -0
  56. package/dist/dataset/data/preprocessing/index.js +3 -0
  57. package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +13 -0
  58. package/dist/dataset/data/preprocessing/tabular_preprocessing.js +45 -0
  59. package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +13 -0
  60. package/dist/dataset/data/preprocessing/text_preprocessing.js +85 -0
  61. package/dist/dataset/data/tabular_data.d.ts +11 -0
  62. package/dist/dataset/data/tabular_data.js +25 -0
  63. package/dist/dataset/data/text_data.d.ts +11 -0
  64. package/dist/dataset/data/text_data.js +14 -0
  65. package/dist/{core/dataset → dataset}/data_loader/data_loader.d.ts +3 -5
  66. package/dist/dataset/data_loader/data_loader.js +2 -0
  67. package/dist/dataset/data_loader/image_loader.d.ts +20 -3
  68. package/dist/dataset/data_loader/image_loader.js +98 -23
  69. package/dist/dataset/data_loader/index.d.ts +5 -2
  70. package/dist/dataset/data_loader/index.js +4 -7
  71. package/dist/dataset/data_loader/tabular_loader.d.ts +34 -3
  72. package/dist/dataset/data_loader/tabular_loader.js +75 -15
  73. package/dist/dataset/data_loader/text_loader.d.ts +14 -0
  74. package/dist/dataset/data_loader/text_loader.js +25 -0
  75. package/dist/dataset/dataset.d.ts +5 -0
  76. package/dist/dataset/dataset.js +1 -0
  77. package/dist/dataset/dataset_builder.d.ts +60 -0
  78. package/dist/dataset/dataset_builder.js +142 -0
  79. package/dist/dataset/index.d.ts +5 -0
  80. package/dist/dataset/index.js +3 -0
  81. package/dist/default_tasks/cifar10/index.d.ts +2 -0
  82. package/dist/{core/default_tasks/cifar10.js → default_tasks/cifar10/index.js} +28 -36
  83. package/dist/default_tasks/cifar10/model.d.ts +434 -0
  84. package/dist/default_tasks/cifar10/model.js +2385 -0
  85. package/dist/default_tasks/geotags/index.d.ts +2 -0
  86. package/dist/default_tasks/geotags/index.js +65 -0
  87. package/dist/default_tasks/geotags/model.d.ts +593 -0
  88. package/dist/default_tasks/geotags/model.js +4715 -0
  89. package/dist/default_tasks/index.d.ts +8 -0
  90. package/dist/default_tasks/index.js +8 -0
  91. package/dist/default_tasks/lus_covid.d.ts +2 -0
  92. package/dist/default_tasks/lus_covid.js +89 -0
  93. package/dist/default_tasks/mnist.d.ts +2 -0
  94. package/dist/{core/default_tasks → default_tasks}/mnist.js +26 -34
  95. package/dist/default_tasks/simple_face/index.d.ts +2 -0
  96. package/dist/{core/default_tasks/simple_face.js → default_tasks/simple_face/index.js} +17 -22
  97. package/dist/default_tasks/simple_face/model.d.ts +513 -0
  98. package/dist/default_tasks/simple_face/model.js +4301 -0
  99. package/dist/default_tasks/skin_mnist.d.ts +2 -0
  100. package/dist/default_tasks/skin_mnist.js +80 -0
  101. package/dist/default_tasks/titanic.d.ts +2 -0
  102. package/dist/{core/default_tasks → default_tasks}/titanic.js +24 -33
  103. package/dist/default_tasks/wikitext.d.ts +2 -0
  104. package/dist/default_tasks/wikitext.js +38 -0
  105. package/dist/index.d.ts +18 -2
  106. package/dist/index.js +18 -6
  107. package/dist/{core/informant → informant}/graph_informant.d.ts +1 -1
  108. package/dist/informant/graph_informant.js +20 -0
  109. package/dist/informant/index.d.ts +1 -0
  110. package/dist/informant/index.js +1 -0
  111. package/dist/{core/logging → logging}/console_logger.d.ts +2 -2
  112. package/dist/logging/console_logger.js +22 -0
  113. package/dist/logging/index.d.ts +2 -0
  114. package/dist/logging/index.js +1 -0
  115. package/dist/{core/logging → logging}/logger.d.ts +3 -3
  116. package/dist/logging/logger.js +1 -0
  117. package/dist/memory/base.d.ts +119 -0
  118. package/dist/memory/base.js +9 -0
  119. package/dist/memory/empty.d.ts +20 -0
  120. package/dist/memory/empty.js +43 -0
  121. package/dist/memory/index.d.ts +3 -1
  122. package/dist/memory/index.js +3 -5
  123. package/dist/memory/model_type.d.ts +9 -0
  124. package/dist/memory/model_type.js +10 -0
  125. package/dist/{core/privacy.d.ts → privacy.d.ts} +1 -1
  126. package/dist/{core/privacy.js → privacy.js} +11 -16
  127. package/dist/serialization/index.d.ts +2 -0
  128. package/dist/serialization/index.js +2 -0
  129. package/dist/serialization/model.d.ts +5 -0
  130. package/dist/serialization/model.js +67 -0
  131. package/dist/{core/serialization → serialization}/weights.d.ts +2 -2
  132. package/dist/serialization/weights.js +37 -0
  133. package/dist/task/data_example.js +14 -0
  134. package/dist/task/digest.js +14 -0
  135. package/dist/{core/task → task}/display_information.d.ts +5 -3
  136. package/dist/task/display_information.js +46 -0
  137. package/dist/task/index.d.ts +7 -0
  138. package/dist/task/index.js +5 -0
  139. package/dist/task/label_type.d.ts +9 -0
  140. package/dist/task/label_type.js +28 -0
  141. package/dist/task/summary.js +13 -0
  142. package/dist/{core/task → task}/task.d.ts +7 -7
  143. package/dist/task/task.js +22 -0
  144. package/dist/task/task_handler.d.ts +5 -0
  145. package/dist/task/task_handler.js +20 -0
  146. package/dist/task/task_provider.d.ts +5 -0
  147. package/dist/task/task_provider.js +1 -0
  148. package/dist/{core/task → task}/training_information.d.ts +9 -10
  149. package/dist/task/training_information.js +88 -0
  150. package/dist/training/disco.d.ts +40 -0
  151. package/dist/training/disco.js +107 -0
  152. package/dist/training/index.d.ts +2 -0
  153. package/dist/training/index.js +1 -0
  154. package/dist/training/trainer/distributed_trainer.d.ts +20 -0
  155. package/dist/training/trainer/distributed_trainer.js +36 -0
  156. package/dist/training/trainer/local_trainer.d.ts +12 -0
  157. package/dist/training/trainer/local_trainer.js +19 -0
  158. package/dist/training/trainer/trainer.d.ts +33 -0
  159. package/dist/training/trainer/trainer.js +52 -0
  160. package/dist/{core/training → training}/trainer/trainer_builder.d.ts +5 -7
  161. package/dist/training/trainer/trainer_builder.js +43 -0
  162. package/dist/types.d.ts +8 -0
  163. package/dist/types.js +1 -0
  164. package/dist/utils/event_emitter.d.ts +40 -0
  165. package/dist/utils/event_emitter.js +57 -0
  166. package/dist/validation/index.d.ts +1 -0
  167. package/dist/validation/index.js +1 -0
  168. package/dist/validation/validator.d.ts +28 -0
  169. package/dist/validation/validator.js +132 -0
  170. package/dist/weights/aggregation.d.ts +21 -0
  171. package/dist/weights/aggregation.js +44 -0
  172. package/dist/weights/index.d.ts +2 -0
  173. package/dist/weights/index.js +2 -0
  174. package/dist/weights/weights_container.d.ts +68 -0
  175. package/dist/weights/weights_container.js +96 -0
  176. package/package.json +24 -15
  177. package/README.md +0 -53
  178. package/dist/core/async_buffer.d.ts +0 -41
  179. package/dist/core/async_buffer.js +0 -97
  180. package/dist/core/async_informant.d.ts +0 -20
  181. package/dist/core/async_informant.js +0 -69
  182. package/dist/core/client/base.d.ts +0 -33
  183. package/dist/core/client/base.js +0 -35
  184. package/dist/core/client/decentralized/base.d.ts +0 -32
  185. package/dist/core/client/decentralized/base.js +0 -212
  186. package/dist/core/client/decentralized/clear_text.d.ts +0 -14
  187. package/dist/core/client/decentralized/clear_text.js +0 -96
  188. package/dist/core/client/decentralized/index.d.ts +0 -4
  189. package/dist/core/client/decentralized/index.js +0 -9
  190. package/dist/core/client/decentralized/messages.d.ts +0 -41
  191. package/dist/core/client/decentralized/messages.js +0 -54
  192. package/dist/core/client/decentralized/peer.d.ts +0 -26
  193. package/dist/core/client/decentralized/peer.js +0 -210
  194. package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
  195. package/dist/core/client/decentralized/peer_pool.js +0 -92
  196. package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
  197. package/dist/core/client/decentralized/sec_agg.js +0 -190
  198. package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
  199. package/dist/core/client/decentralized/secret_shares.js +0 -39
  200. package/dist/core/client/decentralized/types.d.ts +0 -2
  201. package/dist/core/client/decentralized/types.js +0 -7
  202. package/dist/core/client/event_connection.d.ts +0 -37
  203. package/dist/core/client/event_connection.js +0 -158
  204. package/dist/core/client/federated/client.d.ts +0 -37
  205. package/dist/core/client/federated/client.js +0 -273
  206. package/dist/core/client/federated/index.d.ts +0 -2
  207. package/dist/core/client/federated/index.js +0 -7
  208. package/dist/core/client/federated/messages.d.ts +0 -38
  209. package/dist/core/client/federated/messages.js +0 -25
  210. package/dist/core/client/index.d.ts +0 -5
  211. package/dist/core/client/index.js +0 -11
  212. package/dist/core/client/local.d.ts +0 -8
  213. package/dist/core/client/local.js +0 -36
  214. package/dist/core/client/messages.d.ts +0 -28
  215. package/dist/core/client/messages.js +0 -33
  216. package/dist/core/client/utils.d.ts +0 -2
  217. package/dist/core/client/utils.js +0 -19
  218. package/dist/core/dataset/data/data.d.ts +0 -11
  219. package/dist/core/dataset/data/data.js +0 -20
  220. package/dist/core/dataset/data/data_split.d.ts +0 -5
  221. package/dist/core/dataset/data/data_split.js +0 -2
  222. package/dist/core/dataset/data/image_data.d.ts +0 -8
  223. package/dist/core/dataset/data/image_data.js +0 -64
  224. package/dist/core/dataset/data/index.d.ts +0 -5
  225. package/dist/core/dataset/data/index.js +0 -11
  226. package/dist/core/dataset/data/preprocessing.d.ts +0 -13
  227. package/dist/core/dataset/data/preprocessing.js +0 -33
  228. package/dist/core/dataset/data/tabular_data.d.ts +0 -8
  229. package/dist/core/dataset/data/tabular_data.js +0 -40
  230. package/dist/core/dataset/data_loader/data_loader.js +0 -10
  231. package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
  232. package/dist/core/dataset/data_loader/image_loader.js +0 -141
  233. package/dist/core/dataset/data_loader/index.d.ts +0 -3
  234. package/dist/core/dataset/data_loader/index.js +0 -9
  235. package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
  236. package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
  237. package/dist/core/dataset/dataset.d.ts +0 -2
  238. package/dist/core/dataset/dataset.js +0 -2
  239. package/dist/core/dataset/dataset_builder.d.ts +0 -18
  240. package/dist/core/dataset/dataset_builder.js +0 -96
  241. package/dist/core/dataset/index.d.ts +0 -4
  242. package/dist/core/dataset/index.js +0 -14
  243. package/dist/core/default_tasks/cifar10.d.ts +0 -2
  244. package/dist/core/default_tasks/geotags.d.ts +0 -2
  245. package/dist/core/default_tasks/geotags.js +0 -69
  246. package/dist/core/default_tasks/index.d.ts +0 -6
  247. package/dist/core/default_tasks/index.js +0 -15
  248. package/dist/core/default_tasks/lus_covid.d.ts +0 -2
  249. package/dist/core/default_tasks/lus_covid.js +0 -96
  250. package/dist/core/default_tasks/mnist.d.ts +0 -2
  251. package/dist/core/default_tasks/simple_face.d.ts +0 -2
  252. package/dist/core/default_tasks/titanic.d.ts +0 -2
  253. package/dist/core/index.d.ts +0 -18
  254. package/dist/core/index.js +0 -39
  255. package/dist/core/informant/graph_informant.js +0 -23
  256. package/dist/core/informant/index.d.ts +0 -3
  257. package/dist/core/informant/index.js +0 -9
  258. package/dist/core/informant/training_informant/base.d.ts +0 -31
  259. package/dist/core/informant/training_informant/base.js +0 -83
  260. package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
  261. package/dist/core/informant/training_informant/decentralized.js +0 -22
  262. package/dist/core/informant/training_informant/federated.d.ts +0 -14
  263. package/dist/core/informant/training_informant/federated.js +0 -32
  264. package/dist/core/informant/training_informant/index.d.ts +0 -4
  265. package/dist/core/informant/training_informant/index.js +0 -11
  266. package/dist/core/informant/training_informant/local.d.ts +0 -6
  267. package/dist/core/informant/training_informant/local.js +0 -20
  268. package/dist/core/logging/console_logger.js +0 -33
  269. package/dist/core/logging/index.d.ts +0 -3
  270. package/dist/core/logging/index.js +0 -9
  271. package/dist/core/logging/logger.js +0 -9
  272. package/dist/core/logging/trainer_logger.d.ts +0 -24
  273. package/dist/core/logging/trainer_logger.js +0 -59
  274. package/dist/core/memory/base.d.ts +0 -22
  275. package/dist/core/memory/base.js +0 -9
  276. package/dist/core/memory/empty.d.ts +0 -14
  277. package/dist/core/memory/empty.js +0 -75
  278. package/dist/core/memory/index.d.ts +0 -3
  279. package/dist/core/memory/index.js +0 -9
  280. package/dist/core/memory/model_type.d.ts +0 -4
  281. package/dist/core/memory/model_type.js +0 -9
  282. package/dist/core/serialization/index.d.ts +0 -2
  283. package/dist/core/serialization/index.js +0 -6
  284. package/dist/core/serialization/model.d.ts +0 -5
  285. package/dist/core/serialization/model.js +0 -55
  286. package/dist/core/serialization/weights.js +0 -64
  287. package/dist/core/task/data_example.js +0 -24
  288. package/dist/core/task/digest.js +0 -18
  289. package/dist/core/task/display_information.js +0 -49
  290. package/dist/core/task/index.d.ts +0 -6
  291. package/dist/core/task/index.js +0 -15
  292. package/dist/core/task/model_compile_data.d.ts +0 -6
  293. package/dist/core/task/model_compile_data.js +0 -22
  294. package/dist/core/task/summary.js +0 -19
  295. package/dist/core/task/task.js +0 -35
  296. package/dist/core/task/task_handler.d.ts +0 -5
  297. package/dist/core/task/task_handler.js +0 -53
  298. package/dist/core/task/task_provider.d.ts +0 -6
  299. package/dist/core/task/task_provider.js +0 -13
  300. package/dist/core/task/training_information.js +0 -66
  301. package/dist/core/training/disco.d.ts +0 -23
  302. package/dist/core/training/disco.js +0 -130
  303. package/dist/core/training/index.d.ts +0 -2
  304. package/dist/core/training/index.js +0 -7
  305. package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
  306. package/dist/core/training/trainer/distributed_trainer.js +0 -65
  307. package/dist/core/training/trainer/local_trainer.d.ts +0 -11
  308. package/dist/core/training/trainer/local_trainer.js +0 -34
  309. package/dist/core/training/trainer/round_tracker.d.ts +0 -30
  310. package/dist/core/training/trainer/round_tracker.js +0 -47
  311. package/dist/core/training/trainer/trainer.d.ts +0 -65
  312. package/dist/core/training/trainer/trainer.js +0 -160
  313. package/dist/core/training/trainer/trainer_builder.js +0 -95
  314. package/dist/core/training/training_schemes.d.ts +0 -5
  315. package/dist/core/training/training_schemes.js +0 -10
  316. package/dist/core/types.d.ts +0 -4
  317. package/dist/core/types.js +0 -2
  318. package/dist/core/validation/index.d.ts +0 -1
  319. package/dist/core/validation/index.js +0 -5
  320. package/dist/core/validation/validator.d.ts +0 -17
  321. package/dist/core/validation/validator.js +0 -104
  322. package/dist/core/weights/aggregation.d.ts +0 -7
  323. package/dist/core/weights/aggregation.js +0 -72
  324. package/dist/core/weights/index.d.ts +0 -2
  325. package/dist/core/weights/index.js +0 -7
  326. package/dist/core/weights/weights_container.d.ts +0 -19
  327. package/dist/core/weights/weights_container.js +0 -64
  328. package/dist/imports.d.ts +0 -2
  329. package/dist/imports.js +0 -7
  330. package/dist/memory/memory.d.ts +0 -26
  331. package/dist/memory/memory.js +0 -160
  332. package/dist/{core/task → task}/data_example.d.ts +1 -1
  333. package/dist/{core/task → task}/digest.d.ts +0 -0
  334. package/dist/{core/task → task}/summary.d.ts +1 -1
@@ -0,0 +1,2 @@
1
+ import type { TaskProvider } from '../index.js';
2
+ export declare const skinMnist: TaskProvider;
@@ -0,0 +1,80 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { data, models } from '../index.js';
3
+ export const skinMnist = {
4
+ getTask() {
5
+ return {
6
+ id: 'skin_mnist',
7
+ displayInformation: {
8
+ taskTitle: 'Skin disease classification',
9
+ summary: {
10
+ preview: 'Can you determine the skin disease from the dermatoscopic images?',
11
+ overview: 'HAM10000 "Human Against Machine with 10000 training images" dataset is a large collection of multi-source dermatoscopic images of pigmented lesions from Kaggle'
12
+ },
13
+ limitations: 'The training data is limited to small images of size 28x28, similarly to the MNIST dataset.',
14
+ tradeoffs: 'Training success strongly depends on label distribution',
15
+ dataFormatInformation: '',
16
+ dataExampleText: 'Below you find an example',
17
+ dataExampleImage: 'http://walidbn.com/ISIC_0024306.jpg'
18
+ },
19
+ trainingInformation: {
20
+ modelID: 'skin_mnist-model',
21
+ epochs: 50,
22
+ roundDuration: 1,
23
+ validationSplit: 0.1,
24
+ batchSize: 32,
25
+ preprocessingFunctions: [data.ImagePreprocessing.Normalize],
26
+ dataType: 'image',
27
+ IMAGE_H: 28,
28
+ IMAGE_W: 28,
29
+ LABEL_LIST: [
30
+ 'Melanocytic nevi',
31
+ 'Melanoma',
32
+ 'Benign keratosis-like lesions',
33
+ 'Basal cell carcinoma',
34
+ 'Actinic keratoses',
35
+ 'Vascular lesions',
36
+ 'Dermatofibroma'
37
+ ],
38
+ scheme: 'federated',
39
+ noiseScale: undefined,
40
+ clippingRadius: undefined
41
+ }
42
+ };
43
+ },
44
+ getModel() {
45
+ const numClasses = 7;
46
+ const size = 28;
47
+ const model = tf.sequential();
48
+ model.add(tf.layers.conv2d({
49
+ inputShape: [size, size, 3],
50
+ filters: 256,
51
+ kernelSize: 3,
52
+ activation: 'relu'
53
+ }));
54
+ model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));
55
+ model.add(tf.layers.dropout({ rate: 0.3 }));
56
+ model.add(tf.layers.conv2d({
57
+ filters: 128,
58
+ kernelSize: 3,
59
+ activation: 'relu'
60
+ }));
61
+ model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));
62
+ model.add(tf.layers.dropout({ rate: 0.3 }));
63
+ model.add(tf.layers.conv2d({
64
+ filters: 64,
65
+ kernelSize: 3,
66
+ activation: 'relu'
67
+ }));
68
+ model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));
69
+ model.add(tf.layers.dropout({ rate: 0.3 }));
70
+ model.add(tf.layers.flatten());
71
+ model.add(tf.layers.dense({ units: 32 }));
72
+ model.add(tf.layers.dense({ units: numClasses, activation: 'softmax' }));
73
+ model.compile({
74
+ optimizer: tf.train.adam(0.001),
75
+ loss: 'categoricalCrossentropy',
76
+ metrics: ['accuracy']
77
+ });
78
+ return Promise.resolve(new models.TFJS(model));
79
+ }
80
+ };
@@ -0,0 +1,2 @@
1
+ import type { TaskProvider } from '../index.js';
2
+ export declare const titanic: TaskProvider;
@@ -1,12 +1,9 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.titanic = void 0;
4
- var tslib_1 = require("tslib");
5
- var __1 = require("..");
6
- exports.titanic = {
7
- getTask: function () {
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { data, models } from '../index.js';
3
+ export const titanic = {
4
+ getTask() {
8
5
  return {
9
- taskID: 'titanic',
6
+ id: 'titanic',
10
7
  displayInformation: {
11
8
  taskTitle: 'Titanic',
12
9
  summary: {
@@ -50,17 +47,11 @@ exports.titanic = {
50
47
  modelID: 'titanic-model',
51
48
  epochs: 20,
52
49
  roundDuration: 10,
53
- validationSplit: 0,
50
+ validationSplit: 0.2,
54
51
  batchSize: 30,
55
- preprocessingFunctions: [],
56
- modelCompileData: {
57
- optimizer: 'rmsprop',
58
- loss: 'binaryCrossentropy',
59
- metrics: ['accuracy']
60
- },
52
+ preprocessingFunctions: [data.TabularPreprocessing.Sanitize],
61
53
  dataType: 'tabular',
62
54
  inputColumns: [
63
- 'PassengerId',
64
55
  'Age',
65
56
  'SibSp',
66
57
  'Parch',
@@ -70,28 +61,28 @@ exports.titanic = {
70
61
  outputColumns: [
71
62
  'Survived'
72
63
  ],
73
- scheme: 'Federated',
64
+ scheme: 'federated', // secure aggregation not yet implemented for FeAI
74
65
  noiseScale: undefined,
75
66
  clippingRadius: undefined
76
67
  }
77
68
  };
78
69
  },
79
- getModel: function () {
80
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
81
- var model;
82
- return (0, tslib_1.__generator)(this, function (_a) {
83
- model = __1.tf.sequential();
84
- model.add(__1.tf.layers.dense({
85
- inputShape: [6],
86
- units: 124,
87
- activation: 'relu',
88
- kernelInitializer: 'leCunNormal'
89
- }));
90
- model.add(__1.tf.layers.dense({ units: 64, activation: 'relu' }));
91
- model.add(__1.tf.layers.dense({ units: 32, activation: 'relu' }));
92
- model.add(__1.tf.layers.dense({ units: 1, activation: 'sigmoid' }));
93
- return [2 /*return*/, model];
94
- });
70
+ getModel() {
71
+ const model = tf.sequential();
72
+ model.add(tf.layers.dense({
73
+ inputShape: [5],
74
+ units: 124,
75
+ activation: 'relu',
76
+ kernelInitializer: 'leCunNormal'
77
+ }));
78
+ model.add(tf.layers.dense({ units: 64, activation: 'relu' }));
79
+ model.add(tf.layers.dense({ units: 32, activation: 'relu' }));
80
+ model.add(tf.layers.dense({ units: 1, activation: 'sigmoid' }));
81
+ model.compile({
82
+ optimizer: tf.train.sgd(0.001),
83
+ loss: 'binaryCrossentropy',
84
+ metrics: ['accuracy']
95
85
  });
86
+ return Promise.resolve(new models.TFJS(model));
96
87
  }
97
88
  };
@@ -0,0 +1,2 @@
1
+ import type { TaskProvider } from '../index.js';
2
+ export declare const wikitext: TaskProvider;
@@ -0,0 +1,38 @@
1
+ import { data, models } from '../index.js';
2
+ export const wikitext = {
3
+ getTask() {
4
+ return {
5
+ id: 'wikitext-103',
6
+ displayInformation: {
7
+ taskTitle: 'Language modelling on wikitext',
8
+ summary: {
9
+ preview: 'In this challenge, we ask you to do next word prediction on a dataset of Wikipedia articles.',
10
+ overview: 'Wikitext-103-raw is a dataset comprising unprocessed text excerpts from Wikipedia articles, designed for tasks related to natural language processing and language modeling.'
11
+ },
12
+ limitations: 'The dataset may contain noise, inconsistencies, and unstructured content due to its raw nature, potentially posing challenges for certain NLP tasks.',
13
+ tradeoffs: 'The raw format may lack structured annotations and may require additional preprocessing for specific applications.',
14
+ dataFormatInformation: 'The dataset is organized as a large text file, with each line representing a segment of raw text from Wikipedia articles.',
15
+ dataExampleText: 'An example excerpt from the dataset could be: "The history of artificial intelligence dates back to ancient times, with philosophical discussions on the nature of thought and reasoning."'
16
+ },
17
+ trainingInformation: {
18
+ dataType: 'text',
19
+ modelID: 'wikitext-103-raw-model',
20
+ preprocessingFunctions: [data.TextPreprocessing.Tokenize, data.TextPreprocessing.LeftPadding],
21
+ validationSplit: 0.2, // TODO: is this used somewhere? because train, eval and test are already split in dataset
22
+ epochs: 5,
23
+ scheme: 'federated',
24
+ noiseScale: undefined,
25
+ decentralizedSecure: true,
26
+ minimumReadyPeers: 3,
27
+ maxShareValue: 100,
28
+ roundDuration: 10,
29
+ batchSize: 16,
30
+ tokenizer: 'Xenova/gpt2',
31
+ maxSequenceLength: 128
32
+ }
33
+ };
34
+ },
35
+ getModel() {
36
+ return Promise.resolve(new models.GPT());
37
+ }
38
+ };
package/dist/index.d.ts CHANGED
@@ -1,2 +1,18 @@
1
- export * from './core';
2
- export * as browser from './imports';
1
+ export * as data from './dataset/index.js';
2
+ export * as serialization from './serialization/index.js';
3
+ export * as training from './training/index.js';
4
+ export * as privacy from './privacy.js';
5
+ export { GraphInformant } from './informant/index.js';
6
+ export * as client from './client/index.js';
7
+ export * as aggregator from './aggregator/index.js';
8
+ export { WeightsContainer, aggregation } from './weights/index.js';
9
+ export { AsyncInformant } from './async_informant.js';
10
+ export { Logger, ConsoleLogger } from './logging/index.js';
11
+ export { Memory, ModelType, type ModelInfo, type Path, type ModelSource, Empty as EmptyMemory } from './memory/index.js';
12
+ export { Disco, RoundLogs } from './training/index.js';
13
+ export { Validator } from './validation/index.js';
14
+ export { Model, EpochLogs } from './models/index.js';
15
+ export * as models from './models/index.js';
16
+ export * from './task/index.js';
17
+ export * as defaultTasks from './default_tasks/index.js';
18
+ export * from './types.js';
package/dist/index.js CHANGED
@@ -1,6 +1,18 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.browser = void 0;
4
- var tslib_1 = require("tslib");
5
- (0, tslib_1.__exportStar)(require("./core"), exports);
6
- exports.browser = (0, tslib_1.__importStar)(require("./imports"));
1
+ export * as data from './dataset/index.js';
2
+ export * as serialization from './serialization/index.js';
3
+ export * as training from './training/index.js';
4
+ export * as privacy from './privacy.js';
5
+ export { GraphInformant } from './informant/index.js';
6
+ export * as client from './client/index.js';
7
+ export * as aggregator from './aggregator/index.js';
8
+ export { WeightsContainer, aggregation } from './weights/index.js';
9
+ export { AsyncInformant } from './async_informant.js';
10
+ export { ConsoleLogger } from './logging/index.js';
11
+ export { Memory, ModelType, Empty as EmptyMemory } from './memory/index.js';
12
+ export { Disco } from './training/index.js';
13
+ export { Validator } from './validation/index.js';
14
+ export { Model } from './models/index.js';
15
+ export * as models from './models/index.js';
16
+ export * from './task/index.js';
17
+ export * as defaultTasks from './default_tasks/index.js';
18
+ export * from './types.js';
@@ -1,4 +1,4 @@
1
- import { List } from 'immutable';
1
+ import { type List } from 'immutable';
2
2
  export declare class GraphInformant {
3
3
  static readonly NB_EPOCHS_ON_GRAPH = 10;
4
4
  private currentAccuracy;
@@ -0,0 +1,20 @@
1
+ import { Repeat } from 'immutable';
2
+ export class GraphInformant {
3
+ static NB_EPOCHS_ON_GRAPH = 10;
4
+ currentAccuracy;
5
+ accuracyDataSeries;
6
+ constructor() {
7
+ this.currentAccuracy = 0;
8
+ this.accuracyDataSeries = Repeat(0, GraphInformant.NB_EPOCHS_ON_GRAPH).toList();
9
+ }
10
+ updateAccuracy(accuracy) {
11
+ this.accuracyDataSeries = this.accuracyDataSeries.shift().push(accuracy);
12
+ this.currentAccuracy = accuracy;
13
+ }
14
+ data() {
15
+ return this.accuracyDataSeries;
16
+ }
17
+ accuracy() {
18
+ return this.currentAccuracy;
19
+ }
20
+ }
@@ -0,0 +1 @@
1
+ export { GraphInformant } from './graph_informant.js';
@@ -0,0 +1 @@
1
+ export { GraphInformant } from './graph_informant.js';
@@ -1,10 +1,10 @@
1
- import { Logger } from './logger';
1
+ import { Logger } from "./logger.js";
2
2
  /**
3
3
  * Same properties as Toaster but on the console
4
4
  *
5
5
  * @class Logger
6
6
  */
7
- export declare class ConsoleLogger extends Logger {
7
+ export declare class ConsoleLogger implements Logger {
8
8
  /**
9
9
  * Logs success message on the console (in green)
10
10
  * @param {String} message - message to be displayed
@@ -0,0 +1,22 @@
1
+ import chalk from "chalk";
2
+ /**
3
+ * Same properties as Toaster but on the console
4
+ *
5
+ * @class Logger
6
+ */
7
+ export class ConsoleLogger {
8
+ /**
9
+ * Logs success message on the console (in green)
10
+ * @param {String} message - message to be displayed
11
+ */
12
+ success(message) {
13
+ console.log(chalk.green(message));
14
+ }
15
+ /**
16
+ * Logs error message on the console (in red)
17
+ * @param message - message to be displayed
18
+ */
19
+ error(message) {
20
+ console.error(chalk.red(message));
21
+ }
22
+ }
@@ -0,0 +1,2 @@
1
+ export { Logger } from './logger.js';
2
+ export { ConsoleLogger } from './console_logger.js';
@@ -0,0 +1 @@
1
+ export { ConsoleLogger } from './console_logger.js';
@@ -1,12 +1,12 @@
1
- export declare abstract class Logger {
1
+ export interface Logger {
2
2
  /**
3
3
  * Logs success message (in green)
4
4
  * @param message - message to be displayed
5
5
  */
6
- abstract success(message: string): void;
6
+ success(message: string): void;
7
7
  /**
8
8
  * Logs error message (in red)
9
9
  * @param message - message to be displayed
10
10
  */
11
- abstract error(message: string): void;
11
+ error(message: string): void;
12
12
  }
@@ -0,0 +1 @@
1
+ export {};
@@ -0,0 +1,119 @@
1
+ import type { Model, TaskID } from '../index.js';
2
+ import type { ModelType } from './model_type.js';
3
+ /**
4
+ * Model path which uniquely identifies a model in memory.
5
+ */
6
+ export type Path = string;
7
+ /**
8
+ * Model information which uniquely identifies a model in memory.
9
+ */
10
+ export interface ModelInfo {
11
+ /**
12
+ * The model's type: "working" or "saved" model.
13
+ */
14
+ type?: ModelType;
15
+ /**
16
+ * The model's version, to allow for multiple saved models of the same task without
17
+ * causing id conflicts
18
+ */
19
+ version?: number;
20
+ /**
21
+ * The model's corresponding task
22
+ */
23
+ taskID: TaskID;
24
+ /**
25
+ * The model's name
26
+ */
27
+ name: string;
28
+ }
29
+ /**
30
+ * A model source uniquely identifies a model stored in memory.
31
+ */
32
+ export type ModelSource = ModelInfo | Path;
33
+ /**
34
+ * Represents a model memory system, providing functions to fetch, save, delete and update models.
35
+ * Stored models can either be a model currently being trained ("working model") or a regular model
36
+ * saved in memory ("saved model"). There can only be a single working model for a given task.
37
+ */
38
+ export declare abstract class Memory {
39
+ /**
40
+ * Fetches the model identified by the given model source.
41
+ * @param source The model source
42
+ * @returns The model
43
+ */
44
+ abstract getModel(source: ModelSource): Promise<Model>;
45
+ /**
46
+ * Removes the model identified by the given model source from memory.
47
+ * @param source The model source
48
+ * @returns Nothing (the returned promise resolves to void)
49
+ */
50
+ abstract deleteModel(source: ModelSource): Promise<void>;
51
+ /**
52
+ * Replaces the corresponding working model with the saved model identified by the given model source.
53
+ * @param source The model source
54
+ */
55
+ abstract loadModel(source: ModelSource): Promise<void>;
56
+ /**
57
+ * Fetches metadata for the model identified by the given model source.
58
+ * If the model does not exist in memory, returns undefined.
59
+ * @param source The model source
60
+ * @returns The model metadata or undefined
61
+ */
62
+ abstract getModelMetadata(source: ModelSource): Promise<object | undefined>;
63
+ /**
64
+ * Replaces the working model identified by the given source with the newly provided model.
65
+ * @param source The model source
66
+ * @param model The new model
67
+ */
68
+ abstract updateWorkingModel(source: ModelSource, model: Model): Promise<void>;
69
+ /**
70
+ * Creates a saved model copy from the working model identified by the given model source.
71
+ * Returns the saved model's path.
72
+ * @param source The model source
73
+ * @returns The saved model's path
74
+ */
75
+ abstract saveWorkingModel(source: ModelSource): Promise<Path | undefined>;
76
+ /**
77
+ * Saves the newly provided model to the given model source.
78
+ * Returns the saved model's path
79
+ * @param source The model source
80
+ * @param model The new model
81
+ * @returns The saved model's path
82
+ */
83
+ abstract saveModel(source: ModelSource, model: Model): Promise<Path | undefined>;
84
+ /**
85
+ * Moves the model identified by the model source to a file system. This is platform-dependent.
86
+ * @param source The model source
87
+ */
88
+ abstract downloadModel(source: ModelSource): Promise<void>;
89
+ /**
90
+ * Checks whether the model memory contains the model identified by the given source.
91
+ * @param source The model source
92
+ * @returns True if the memory contains the model, false otherwise
93
+ */
94
+ abstract contains(source: ModelSource): Promise<boolean>;
95
+ /**
96
+ * Computes the path in memory corresponding to the given model source, be it a path or model information.
97
+ * This is used to easily switch between model path and information, which are both unique model identifiers
98
+ * with a one-to-one correspondence. Returns undefined instead if no path could be inferred from the given
99
+ * model source.
100
+ * @param source The model source
101
+ * @returns The model path
102
+ */
103
+ abstract pathFor(source: ModelSource): Path | undefined;
104
+ /**
105
+ * Computes the model information corresponding to the given model source, be it a path or model information.
106
+ * This is used to easily switch between model path and information, which are both unique model identifiers
107
+ * with a one-to-one correspondence. Returns undefined instead if no unique model information could be inferred
108
+ * from the given model source.
109
+ * @param source The model source
110
+ * @returns The model information
111
+ */
112
+ abstract infoFor(source: ModelSource): ModelInfo | undefined;
113
+ /**
114
+ * Computes the lowest version a model source can have without conflicting with model versions currently in memory.
115
+ * @param source The model source
116
+ * @returns The duplicated model source
117
+ */
118
+ abstract duplicateSource(source: ModelSource): Promise<ModelSource | undefined>;
119
+ }
@@ -0,0 +1,9 @@
1
+ // only used browser-side
2
+ // TODO: replace IO type
3
+ /**
4
+ * Represents a model memory system, providing functions to fetch, save, delete and update models.
5
+ * Stored models can either be a model currently being trained ("working model") or a regular model
6
+ * saved in memory ("saved model"). There can only be a single working model for a given task.
7
+ */
8
+ export class Memory {
9
+ }
@@ -0,0 +1,20 @@
1
+ import type { Model } from '../index.js';
2
+ import type { ModelInfo, Path } from './base.js';
3
+ import { Memory } from './base.js';
4
+ /**
5
+ * Represents an empty model memory.
6
+ */
7
+ export declare class Empty extends Memory {
8
+ getModelMetadata(): Promise<undefined>;
9
+ contains(): Promise<boolean>;
10
+ getModel(): Promise<Model>;
11
+ loadModel(): Promise<void>;
12
+ updateWorkingModel(): Promise<void>;
13
+ saveWorkingModel(): Promise<undefined>;
14
+ saveModel(): Promise<undefined>;
15
+ deleteModel(): Promise<void>;
16
+ downloadModel(): Promise<void>;
17
+ pathFor(): Path;
18
+ infoFor(): ModelInfo;
19
+ duplicateSource(): Promise<undefined>;
20
+ }
@@ -0,0 +1,43 @@
1
+ import { Memory } from './base.js';
2
+ /**
3
+ * Represents an empty model memory.
4
+ */
5
+ export class Empty extends Memory {
6
+ getModelMetadata() {
7
+ return Promise.resolve(undefined);
8
+ }
9
+ contains() {
10
+ return Promise.resolve(false);
11
+ }
12
+ getModel() {
13
+ return Promise.reject(new Error('empty'));
14
+ }
15
+ loadModel() {
16
+ return Promise.reject(new Error('empty'));
17
+ }
18
+ updateWorkingModel() {
19
+ // nothing to do
20
+ return Promise.resolve();
21
+ }
22
+ saveWorkingModel() {
23
+ return Promise.resolve(undefined);
24
+ }
25
+ saveModel() {
26
+ return Promise.resolve(undefined);
27
+ }
28
+ async deleteModel() {
29
+ // nothing to do
30
+ }
31
+ downloadModel() {
32
+ return Promise.reject(new Error('empty'));
33
+ }
34
+ pathFor() {
35
+ throw new Error('empty');
36
+ }
37
+ infoFor() {
38
+ throw new Error('empty');
39
+ }
40
+ duplicateSource() {
41
+ return Promise.resolve(undefined);
42
+ }
43
+ }
@@ -1 +1,3 @@
1
- export { IndexedDB } from './memory';
1
+ export { Empty } from './empty.js';
2
+ export { Memory, type ModelInfo, type Path, type ModelSource } from './base.js';
3
+ export { ModelType } from './model_type.js';
@@ -1,5 +1,3 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.IndexedDB = void 0;
4
- var memory_1 = require("./memory");
5
- Object.defineProperty(exports, "IndexedDB", { enumerable: true, get: function () { return memory_1.IndexedDB; } });
1
+ export { Empty } from './empty.js';
2
+ export { Memory } from './base.js';
3
+ export { ModelType } from './model_type.js';
@@ -0,0 +1,9 @@
1
+ /**
2
+ * Type of models stored in memory. Stored models can either be a model currently
3
+ * being trained ("working model") or a regular model saved in memory ("saved model").
4
+ * There can only be a single working model for a given task.
5
+ */
6
+ export declare enum ModelType {
7
+ WORKING = "working",
8
+ SAVED = "saved"
9
+ }
@@ -0,0 +1,10 @@
1
+ /**
2
+ * Type of models stored in memory. Stored models can either be a model currently
3
+ * being trained ("working model") or a regular model saved in memory ("saved model").
4
+ * There can only be a single working model for a given task.
5
+ */
6
+ export var ModelType;
7
+ (function (ModelType) {
8
+ ModelType["WORKING"] = "working";
9
+ ModelType["SAVED"] = "saved";
10
+ })(ModelType || (ModelType = {}));
@@ -1,4 +1,4 @@
1
- import { Task, WeightsContainer } from '.';
1
+ import type { Task, WeightsContainer } from './index.js';
2
2
  /**
3
3
  * Add task-parametrized Gaussian noise to and clip the weights update between the previous and current rounds.
4
4
  * The previous round's weights are the last weights pulled from server/peers.