@epfml/discojs 2.0.0 → 2.1.2-p20240506085037.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (334) hide show
  1. package/dist/aggregator/base.d.ts +180 -0
  2. package/dist/aggregator/base.js +236 -0
  3. package/dist/aggregator/get.d.ts +16 -0
  4. package/dist/aggregator/get.js +31 -0
  5. package/dist/aggregator/index.d.ts +7 -0
  6. package/dist/aggregator/index.js +4 -0
  7. package/dist/aggregator/mean.d.ts +23 -0
  8. package/dist/aggregator/mean.js +69 -0
  9. package/dist/aggregator/secure.d.ts +27 -0
  10. package/dist/aggregator/secure.js +91 -0
  11. package/dist/async_informant.d.ts +15 -0
  12. package/dist/async_informant.js +42 -0
  13. package/dist/client/base.d.ts +76 -0
  14. package/dist/client/base.js +88 -0
  15. package/dist/client/decentralized/base.d.ts +32 -0
  16. package/dist/client/decentralized/base.js +192 -0
  17. package/dist/client/decentralized/index.d.ts +2 -0
  18. package/dist/client/decentralized/index.js +2 -0
  19. package/dist/client/decentralized/messages.d.ts +28 -0
  20. package/dist/client/decentralized/messages.js +44 -0
  21. package/dist/client/decentralized/peer.d.ts +40 -0
  22. package/dist/client/decentralized/peer.js +189 -0
  23. package/dist/client/decentralized/peer_pool.d.ts +12 -0
  24. package/dist/client/decentralized/peer_pool.js +44 -0
  25. package/dist/client/event_connection.d.ts +34 -0
  26. package/dist/client/event_connection.js +105 -0
  27. package/dist/client/federated/base.d.ts +54 -0
  28. package/dist/client/federated/base.js +151 -0
  29. package/dist/client/federated/index.d.ts +2 -0
  30. package/dist/client/federated/index.js +2 -0
  31. package/dist/client/federated/messages.d.ts +30 -0
  32. package/dist/client/federated/messages.js +24 -0
  33. package/dist/client/index.d.ts +8 -0
  34. package/dist/client/index.js +8 -0
  35. package/dist/client/local.d.ts +3 -0
  36. package/dist/client/local.js +3 -0
  37. package/dist/client/messages.d.ts +30 -0
  38. package/dist/client/messages.js +26 -0
  39. package/dist/client/types.d.ts +2 -0
  40. package/dist/client/types.js +4 -0
  41. package/dist/client/utils.d.ts +2 -0
  42. package/dist/client/utils.js +7 -0
  43. package/dist/dataset/data/data.d.ts +48 -0
  44. package/dist/dataset/data/data.js +72 -0
  45. package/dist/dataset/data/data_split.d.ts +8 -0
  46. package/dist/dataset/data/data_split.js +1 -0
  47. package/dist/dataset/data/image_data.d.ts +11 -0
  48. package/dist/dataset/data/image_data.js +38 -0
  49. package/dist/dataset/data/index.d.ts +6 -0
  50. package/dist/dataset/data/index.js +5 -0
  51. package/dist/dataset/data/preprocessing/base.d.ts +16 -0
  52. package/dist/dataset/data/preprocessing/base.js +1 -0
  53. package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +13 -0
  54. package/dist/dataset/data/preprocessing/image_preprocessing.js +40 -0
  55. package/dist/dataset/data/preprocessing/index.d.ts +4 -0
  56. package/dist/dataset/data/preprocessing/index.js +3 -0
  57. package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +13 -0
  58. package/dist/dataset/data/preprocessing/tabular_preprocessing.js +45 -0
  59. package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +13 -0
  60. package/dist/dataset/data/preprocessing/text_preprocessing.js +85 -0
  61. package/dist/dataset/data/tabular_data.d.ts +11 -0
  62. package/dist/dataset/data/tabular_data.js +25 -0
  63. package/dist/dataset/data/text_data.d.ts +11 -0
  64. package/dist/dataset/data/text_data.js +14 -0
  65. package/dist/{core/dataset → dataset}/data_loader/data_loader.d.ts +3 -5
  66. package/dist/dataset/data_loader/data_loader.js +2 -0
  67. package/dist/dataset/data_loader/image_loader.d.ts +20 -3
  68. package/dist/dataset/data_loader/image_loader.js +98 -23
  69. package/dist/dataset/data_loader/index.d.ts +5 -2
  70. package/dist/dataset/data_loader/index.js +4 -7
  71. package/dist/dataset/data_loader/tabular_loader.d.ts +34 -3
  72. package/dist/dataset/data_loader/tabular_loader.js +75 -15
  73. package/dist/dataset/data_loader/text_loader.d.ts +14 -0
  74. package/dist/dataset/data_loader/text_loader.js +25 -0
  75. package/dist/dataset/dataset.d.ts +5 -0
  76. package/dist/dataset/dataset.js +1 -0
  77. package/dist/dataset/dataset_builder.d.ts +60 -0
  78. package/dist/dataset/dataset_builder.js +142 -0
  79. package/dist/dataset/index.d.ts +5 -0
  80. package/dist/dataset/index.js +3 -0
  81. package/dist/default_tasks/cifar10/index.d.ts +2 -0
  82. package/dist/default_tasks/cifar10/index.js +60 -0
  83. package/dist/default_tasks/cifar10/model.d.ts +434 -0
  84. package/dist/default_tasks/cifar10/model.js +2385 -0
  85. package/dist/default_tasks/geotags/index.d.ts +2 -0
  86. package/dist/default_tasks/geotags/index.js +65 -0
  87. package/dist/default_tasks/geotags/model.d.ts +593 -0
  88. package/dist/default_tasks/geotags/model.js +4715 -0
  89. package/dist/default_tasks/index.d.ts +8 -0
  90. package/dist/default_tasks/index.js +8 -0
  91. package/dist/default_tasks/lus_covid.d.ts +2 -0
  92. package/dist/default_tasks/lus_covid.js +89 -0
  93. package/dist/default_tasks/mnist.d.ts +2 -0
  94. package/dist/default_tasks/mnist.js +61 -0
  95. package/dist/default_tasks/simple_face/index.d.ts +2 -0
  96. package/dist/default_tasks/simple_face/index.js +48 -0
  97. package/dist/default_tasks/simple_face/model.d.ts +513 -0
  98. package/dist/default_tasks/simple_face/model.js +4301 -0
  99. package/dist/default_tasks/skin_mnist.d.ts +2 -0
  100. package/dist/default_tasks/skin_mnist.js +80 -0
  101. package/dist/default_tasks/titanic.d.ts +2 -0
  102. package/dist/default_tasks/titanic.js +88 -0
  103. package/dist/default_tasks/wikitext.d.ts +2 -0
  104. package/dist/default_tasks/wikitext.js +38 -0
  105. package/dist/index.d.ts +18 -2
  106. package/dist/index.js +18 -6
  107. package/dist/{core/informant → informant}/graph_informant.d.ts +1 -1
  108. package/dist/informant/graph_informant.js +20 -0
  109. package/dist/informant/index.d.ts +1 -0
  110. package/dist/informant/index.js +1 -0
  111. package/dist/{core/logging → logging}/console_logger.d.ts +2 -2
  112. package/dist/logging/console_logger.js +22 -0
  113. package/dist/logging/index.d.ts +2 -0
  114. package/dist/logging/index.js +1 -0
  115. package/dist/{core/logging → logging}/logger.d.ts +3 -3
  116. package/dist/logging/logger.js +1 -0
  117. package/dist/memory/base.d.ts +119 -0
  118. package/dist/memory/base.js +9 -0
  119. package/dist/memory/empty.d.ts +20 -0
  120. package/dist/memory/empty.js +43 -0
  121. package/dist/memory/index.d.ts +3 -1
  122. package/dist/memory/index.js +3 -5
  123. package/dist/memory/model_type.d.ts +9 -0
  124. package/dist/memory/model_type.js +10 -0
  125. package/dist/{core/privacy.d.ts → privacy.d.ts} +1 -1
  126. package/dist/{core/privacy.js → privacy.js} +11 -16
  127. package/dist/serialization/index.d.ts +2 -0
  128. package/dist/serialization/index.js +2 -0
  129. package/dist/serialization/model.d.ts +5 -0
  130. package/dist/serialization/model.js +67 -0
  131. package/dist/{core/serialization → serialization}/weights.d.ts +2 -2
  132. package/dist/serialization/weights.js +37 -0
  133. package/dist/task/data_example.js +14 -0
  134. package/dist/task/digest.d.ts +5 -0
  135. package/dist/task/digest.js +14 -0
  136. package/dist/{core/task → task}/display_information.d.ts +5 -3
  137. package/dist/task/display_information.js +46 -0
  138. package/dist/task/index.d.ts +7 -0
  139. package/dist/task/index.js +5 -0
  140. package/dist/task/label_type.d.ts +9 -0
  141. package/dist/task/label_type.js +28 -0
  142. package/dist/task/summary.js +13 -0
  143. package/dist/task/task.d.ts +12 -0
  144. package/dist/task/task.js +22 -0
  145. package/dist/task/task_handler.d.ts +5 -0
  146. package/dist/task/task_handler.js +20 -0
  147. package/dist/task/task_provider.d.ts +5 -0
  148. package/dist/task/task_provider.js +1 -0
  149. package/dist/{core/task → task}/training_information.d.ts +9 -10
  150. package/dist/task/training_information.js +88 -0
  151. package/dist/training/disco.d.ts +40 -0
  152. package/dist/training/disco.js +107 -0
  153. package/dist/training/index.d.ts +2 -0
  154. package/dist/training/index.js +1 -0
  155. package/dist/training/trainer/distributed_trainer.d.ts +20 -0
  156. package/dist/training/trainer/distributed_trainer.js +36 -0
  157. package/dist/training/trainer/local_trainer.d.ts +12 -0
  158. package/dist/training/trainer/local_trainer.js +19 -0
  159. package/dist/training/trainer/trainer.d.ts +33 -0
  160. package/dist/training/trainer/trainer.js +52 -0
  161. package/dist/{core/training → training}/trainer/trainer_builder.d.ts +5 -7
  162. package/dist/training/trainer/trainer_builder.js +43 -0
  163. package/dist/types.d.ts +8 -0
  164. package/dist/types.js +1 -0
  165. package/dist/utils/event_emitter.d.ts +40 -0
  166. package/dist/utils/event_emitter.js +57 -0
  167. package/dist/validation/index.d.ts +1 -0
  168. package/dist/validation/index.js +1 -0
  169. package/dist/validation/validator.d.ts +28 -0
  170. package/dist/validation/validator.js +132 -0
  171. package/dist/weights/aggregation.d.ts +21 -0
  172. package/dist/weights/aggregation.js +44 -0
  173. package/dist/weights/index.d.ts +2 -0
  174. package/dist/weights/index.js +2 -0
  175. package/dist/weights/weights_container.d.ts +68 -0
  176. package/dist/weights/weights_container.js +96 -0
  177. package/package.json +25 -16
  178. package/README.md +0 -53
  179. package/dist/core/async_buffer.d.ts +0 -41
  180. package/dist/core/async_buffer.js +0 -97
  181. package/dist/core/async_informant.d.ts +0 -20
  182. package/dist/core/async_informant.js +0 -69
  183. package/dist/core/client/base.d.ts +0 -33
  184. package/dist/core/client/base.js +0 -35
  185. package/dist/core/client/decentralized/base.d.ts +0 -32
  186. package/dist/core/client/decentralized/base.js +0 -212
  187. package/dist/core/client/decentralized/clear_text.d.ts +0 -14
  188. package/dist/core/client/decentralized/clear_text.js +0 -96
  189. package/dist/core/client/decentralized/index.d.ts +0 -4
  190. package/dist/core/client/decentralized/index.js +0 -9
  191. package/dist/core/client/decentralized/messages.d.ts +0 -41
  192. package/dist/core/client/decentralized/messages.js +0 -54
  193. package/dist/core/client/decentralized/peer.d.ts +0 -26
  194. package/dist/core/client/decentralized/peer.js +0 -210
  195. package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
  196. package/dist/core/client/decentralized/peer_pool.js +0 -92
  197. package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
  198. package/dist/core/client/decentralized/sec_agg.js +0 -190
  199. package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
  200. package/dist/core/client/decentralized/secret_shares.js +0 -39
  201. package/dist/core/client/decentralized/types.d.ts +0 -2
  202. package/dist/core/client/decentralized/types.js +0 -7
  203. package/dist/core/client/event_connection.d.ts +0 -37
  204. package/dist/core/client/event_connection.js +0 -158
  205. package/dist/core/client/federated/client.d.ts +0 -37
  206. package/dist/core/client/federated/client.js +0 -273
  207. package/dist/core/client/federated/index.d.ts +0 -2
  208. package/dist/core/client/federated/index.js +0 -7
  209. package/dist/core/client/federated/messages.d.ts +0 -38
  210. package/dist/core/client/federated/messages.js +0 -25
  211. package/dist/core/client/index.d.ts +0 -5
  212. package/dist/core/client/index.js +0 -11
  213. package/dist/core/client/local.d.ts +0 -8
  214. package/dist/core/client/local.js +0 -36
  215. package/dist/core/client/messages.d.ts +0 -28
  216. package/dist/core/client/messages.js +0 -33
  217. package/dist/core/client/utils.d.ts +0 -2
  218. package/dist/core/client/utils.js +0 -19
  219. package/dist/core/dataset/data/data.d.ts +0 -11
  220. package/dist/core/dataset/data/data.js +0 -20
  221. package/dist/core/dataset/data/data_split.d.ts +0 -5
  222. package/dist/core/dataset/data/data_split.js +0 -2
  223. package/dist/core/dataset/data/image_data.d.ts +0 -8
  224. package/dist/core/dataset/data/image_data.js +0 -64
  225. package/dist/core/dataset/data/index.d.ts +0 -5
  226. package/dist/core/dataset/data/index.js +0 -11
  227. package/dist/core/dataset/data/preprocessing.d.ts +0 -13
  228. package/dist/core/dataset/data/preprocessing.js +0 -33
  229. package/dist/core/dataset/data/tabular_data.d.ts +0 -8
  230. package/dist/core/dataset/data/tabular_data.js +0 -40
  231. package/dist/core/dataset/data_loader/data_loader.js +0 -10
  232. package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
  233. package/dist/core/dataset/data_loader/image_loader.js +0 -141
  234. package/dist/core/dataset/data_loader/index.d.ts +0 -3
  235. package/dist/core/dataset/data_loader/index.js +0 -9
  236. package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
  237. package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
  238. package/dist/core/dataset/dataset.d.ts +0 -2
  239. package/dist/core/dataset/dataset.js +0 -2
  240. package/dist/core/dataset/dataset_builder.d.ts +0 -18
  241. package/dist/core/dataset/dataset_builder.js +0 -96
  242. package/dist/core/dataset/index.d.ts +0 -4
  243. package/dist/core/dataset/index.js +0 -14
  244. package/dist/core/index.d.ts +0 -18
  245. package/dist/core/index.js +0 -41
  246. package/dist/core/informant/graph_informant.js +0 -23
  247. package/dist/core/informant/index.d.ts +0 -3
  248. package/dist/core/informant/index.js +0 -9
  249. package/dist/core/informant/training_informant/base.d.ts +0 -31
  250. package/dist/core/informant/training_informant/base.js +0 -83
  251. package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
  252. package/dist/core/informant/training_informant/decentralized.js +0 -22
  253. package/dist/core/informant/training_informant/federated.d.ts +0 -14
  254. package/dist/core/informant/training_informant/federated.js +0 -32
  255. package/dist/core/informant/training_informant/index.d.ts +0 -4
  256. package/dist/core/informant/training_informant/index.js +0 -11
  257. package/dist/core/informant/training_informant/local.d.ts +0 -6
  258. package/dist/core/informant/training_informant/local.js +0 -20
  259. package/dist/core/logging/console_logger.js +0 -33
  260. package/dist/core/logging/index.d.ts +0 -3
  261. package/dist/core/logging/index.js +0 -9
  262. package/dist/core/logging/logger.js +0 -9
  263. package/dist/core/logging/trainer_logger.d.ts +0 -24
  264. package/dist/core/logging/trainer_logger.js +0 -59
  265. package/dist/core/memory/base.d.ts +0 -22
  266. package/dist/core/memory/base.js +0 -9
  267. package/dist/core/memory/empty.d.ts +0 -14
  268. package/dist/core/memory/empty.js +0 -75
  269. package/dist/core/memory/index.d.ts +0 -3
  270. package/dist/core/memory/index.js +0 -9
  271. package/dist/core/memory/model_type.d.ts +0 -4
  272. package/dist/core/memory/model_type.js +0 -9
  273. package/dist/core/serialization/index.d.ts +0 -2
  274. package/dist/core/serialization/index.js +0 -6
  275. package/dist/core/serialization/model.d.ts +0 -5
  276. package/dist/core/serialization/model.js +0 -55
  277. package/dist/core/serialization/weights.js +0 -64
  278. package/dist/core/task/data_example.js +0 -24
  279. package/dist/core/task/display_information.js +0 -49
  280. package/dist/core/task/index.d.ts +0 -3
  281. package/dist/core/task/index.js +0 -8
  282. package/dist/core/task/model_compile_data.d.ts +0 -6
  283. package/dist/core/task/model_compile_data.js +0 -22
  284. package/dist/core/task/summary.js +0 -19
  285. package/dist/core/task/task.d.ts +0 -10
  286. package/dist/core/task/task.js +0 -31
  287. package/dist/core/task/training_information.js +0 -66
  288. package/dist/core/tasks/cifar10.d.ts +0 -3
  289. package/dist/core/tasks/cifar10.js +0 -65
  290. package/dist/core/tasks/geotags.d.ts +0 -3
  291. package/dist/core/tasks/geotags.js +0 -67
  292. package/dist/core/tasks/index.d.ts +0 -6
  293. package/dist/core/tasks/index.js +0 -10
  294. package/dist/core/tasks/lus_covid.d.ts +0 -3
  295. package/dist/core/tasks/lus_covid.js +0 -87
  296. package/dist/core/tasks/mnist.d.ts +0 -3
  297. package/dist/core/tasks/mnist.js +0 -60
  298. package/dist/core/tasks/simple_face.d.ts +0 -2
  299. package/dist/core/tasks/simple_face.js +0 -41
  300. package/dist/core/tasks/titanic.d.ts +0 -3
  301. package/dist/core/tasks/titanic.js +0 -88
  302. package/dist/core/training/disco.d.ts +0 -23
  303. package/dist/core/training/disco.js +0 -130
  304. package/dist/core/training/index.d.ts +0 -2
  305. package/dist/core/training/index.js +0 -7
  306. package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
  307. package/dist/core/training/trainer/distributed_trainer.js +0 -65
  308. package/dist/core/training/trainer/local_trainer.d.ts +0 -11
  309. package/dist/core/training/trainer/local_trainer.js +0 -34
  310. package/dist/core/training/trainer/round_tracker.d.ts +0 -30
  311. package/dist/core/training/trainer/round_tracker.js +0 -47
  312. package/dist/core/training/trainer/trainer.d.ts +0 -65
  313. package/dist/core/training/trainer/trainer.js +0 -160
  314. package/dist/core/training/trainer/trainer_builder.js +0 -95
  315. package/dist/core/training/training_schemes.d.ts +0 -5
  316. package/dist/core/training/training_schemes.js +0 -10
  317. package/dist/core/types.d.ts +0 -4
  318. package/dist/core/types.js +0 -2
  319. package/dist/core/validation/index.d.ts +0 -1
  320. package/dist/core/validation/index.js +0 -5
  321. package/dist/core/validation/validator.d.ts +0 -17
  322. package/dist/core/validation/validator.js +0 -104
  323. package/dist/core/weights/aggregation.d.ts +0 -8
  324. package/dist/core/weights/aggregation.js +0 -96
  325. package/dist/core/weights/index.d.ts +0 -2
  326. package/dist/core/weights/index.js +0 -7
  327. package/dist/core/weights/weights_container.d.ts +0 -19
  328. package/dist/core/weights/weights_container.js +0 -64
  329. package/dist/imports.d.ts +0 -2
  330. package/dist/imports.js +0 -7
  331. package/dist/memory/memory.d.ts +0 -26
  332. package/dist/memory/memory.js +0 -160
  333. package/dist/{core/task → task}/data_example.d.ts +1 -1
  334. package/dist/{core/task → task}/summary.d.ts +1 -1
@@ -0,0 +1,2 @@
1
+ import type { TaskProvider } from '../index.js';
2
+ export declare const skinMnist: TaskProvider;
@@ -0,0 +1,80 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { data, models } from '../index.js';
3
+ export const skinMnist = {
4
+ getTask() {
5
+ return {
6
+ id: 'skin_mnist',
7
+ displayInformation: {
8
+ taskTitle: 'Skin disease classification',
9
+ summary: {
10
+ preview: 'Can you determine the skin disease from the dermatoscopic images?',
11
+ overview: 'HAM10000 "Human Against Machine with 10000 training images" dataset is a large collection of multi-source dermatoscopic images of pigmented lesions from Kaggle'
12
+ },
13
+ limitations: 'The training data is limited to small images of size 28x28, similarly to the MNIST dataset.',
14
+ tradeoffs: 'Training success strongly depends on label distribution',
15
+ dataFormatInformation: '',
16
+ dataExampleText: 'Below you find an example',
17
+ dataExampleImage: 'http://walidbn.com/ISIC_0024306.jpg'
18
+ },
19
+ trainingInformation: {
20
+ modelID: 'skin_mnist-model',
21
+ epochs: 50,
22
+ roundDuration: 1,
23
+ validationSplit: 0.1,
24
+ batchSize: 32,
25
+ preprocessingFunctions: [data.ImagePreprocessing.Normalize],
26
+ dataType: 'image',
27
+ IMAGE_H: 28,
28
+ IMAGE_W: 28,
29
+ LABEL_LIST: [
30
+ 'Melanocytic nevi',
31
+ 'Melanoma',
32
+ 'Benign keratosis-like lesions',
33
+ 'Basal cell carcinoma',
34
+ 'Actinic keratoses',
35
+ 'Vascular lesions',
36
+ 'Dermatofibroma'
37
+ ],
38
+ scheme: 'federated',
39
+ noiseScale: undefined,
40
+ clippingRadius: undefined
41
+ }
42
+ };
43
+ },
44
+ getModel() {
45
+ const numClasses = 7;
46
+ const size = 28;
47
+ const model = tf.sequential();
48
+ model.add(tf.layers.conv2d({
49
+ inputShape: [size, size, 3],
50
+ filters: 256,
51
+ kernelSize: 3,
52
+ activation: 'relu'
53
+ }));
54
+ model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));
55
+ model.add(tf.layers.dropout({ rate: 0.3 }));
56
+ model.add(tf.layers.conv2d({
57
+ filters: 128,
58
+ kernelSize: 3,
59
+ activation: 'relu'
60
+ }));
61
+ model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));
62
+ model.add(tf.layers.dropout({ rate: 0.3 }));
63
+ model.add(tf.layers.conv2d({
64
+ filters: 64,
65
+ kernelSize: 3,
66
+ activation: 'relu'
67
+ }));
68
+ model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));
69
+ model.add(tf.layers.dropout({ rate: 0.3 }));
70
+ model.add(tf.layers.flatten());
71
+ model.add(tf.layers.dense({ units: 32 }));
72
+ model.add(tf.layers.dense({ units: numClasses, activation: 'softmax' }));
73
+ model.compile({
74
+ optimizer: tf.train.adam(0.001),
75
+ loss: 'categoricalCrossentropy',
76
+ metrics: ['accuracy']
77
+ });
78
+ return Promise.resolve(new models.TFJS(model));
79
+ }
80
+ };
@@ -0,0 +1,2 @@
1
+ import type { TaskProvider } from '../index.js';
2
+ export declare const titanic: TaskProvider;
@@ -0,0 +1,88 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { data, models } from '../index.js';
3
+ export const titanic = {
4
+ getTask() {
5
+ return {
6
+ id: 'titanic',
7
+ displayInformation: {
8
+ taskTitle: 'Titanic',
9
+ summary: {
10
+ preview: "Test our platform by using a publicly available <b>tabular</b> dataset. <br><br> Download the passenger list from the Titanic shipwreck here: <a class='underline text-primary-dark dark:text-primary-light' href='https://github.com/epfml/disco/raw/develop/example_training_data/titanic_train.csv'>titanic_train.csv</a> (more info <a class='underline text-primary-dark dark:text-primary-light' href='https://www.kaggle.com/c/titanic'>here</a>). <br> This model predicts the type of person most likely to survive/die in the historic ship accident, based on their characteristics (sex, age, class etc.).",
11
+ overview: 'We all know the unfortunate story of the Titanic: this flamboyant new transatlantic boat that sunk in 1912 in the North Atlantic Ocean. Today, we revist this tragedy by trying to predict the survival odds of the passenger given some basic features.'
12
+ },
13
+ model: 'The current model does not normalize the given data and applies only a very simple pre-processing of the data.',
14
+ tradeoffs: 'We are using a small model for this task: 4 fully connected layers with few neurons. This allows fast training but can yield to reduced accuracy.',
15
+ dataFormatInformation: 'This model takes as input a CSV file with 12 columns. The features are general information about the passenger (sex, age, name, etc.) and specific related Titanic data such as the ticket class bought by the passenger, its cabin number, etc.<br><br>pclass: A proxy for socio-economic status (SES)<br>1st = Upper<br>2nd = Middle<br>3rd = Lower<br><br>age: Age is fractional if less than 1. If the age is estimated, it is in the form of xx.5<br><br>sibsp: The dataset defines family relations in this way:<br>Sibling = brother, sister, stepbrother, stepsister<br>Spouse = husband, wife (mistresses and fiancés were ignored)<br><br>parch: The dataset defines family relations in this way:<br>Parent = mother, father<br>Child = daughter, son, stepdaughter, stepson<br>Some children travelled only with a nanny, therefore parch=0 for them.<br><br>The first line of the CSV contains the header:<br> PassengerId, Survived, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin, Embarked<br><br>Each susequent row contains the corresponding data.',
16
+ dataExampleText: 'Below one can find an example of a datapoint taken as input by our model. In this datapoint, the person is young man named Owen Harris that unfortunnalty perished with the Titanic. He boarded the boat in South Hamptons and was a 3rd class passenger. On the testing & validation page, the data should not contain the label column (Survived).',
17
+ dataExample: [
18
+ { columnName: 'PassengerId', columnData: '1' },
19
+ { columnName: 'Survived', columnData: '0' },
20
+ { columnName: 'Name', columnData: 'Braund, Mr. Owen Harris' },
21
+ { columnName: 'Sex', columnData: 'male' },
22
+ { columnName: 'Age', columnData: '22' },
23
+ { columnName: 'SibSp', columnData: '1' },
24
+ { columnName: 'Parch', columnData: '0' },
25
+ { columnName: 'Ticket', columnData: '1/5 21171' },
26
+ { columnName: 'Fare', columnData: '7.25' },
27
+ { columnName: 'Cabin', columnData: 'E46' },
28
+ { columnName: 'Embarked', columnData: 'S' },
29
+ { columnName: 'Pclass', columnData: '3' }
30
+ ],
31
+ headers: [
32
+ 'PassengerId',
33
+ 'Survived',
34
+ 'Name',
35
+ 'Sex',
36
+ 'Age',
37
+ 'SibSp',
38
+ 'Parch',
39
+ 'Ticket',
40
+ 'Fare',
41
+ 'Cabin',
42
+ 'Embarked',
43
+ 'Pclass'
44
+ ]
45
+ },
46
+ trainingInformation: {
47
+ modelID: 'titanic-model',
48
+ epochs: 20,
49
+ roundDuration: 10,
50
+ validationSplit: 0.2,
51
+ batchSize: 30,
52
+ preprocessingFunctions: [data.TabularPreprocessing.Sanitize],
53
+ dataType: 'tabular',
54
+ inputColumns: [
55
+ 'Age',
56
+ 'SibSp',
57
+ 'Parch',
58
+ 'Fare',
59
+ 'Pclass'
60
+ ],
61
+ outputColumns: [
62
+ 'Survived'
63
+ ],
64
+ scheme: 'federated', // secure aggregation not yet implemented for FeAI
65
+ noiseScale: undefined,
66
+ clippingRadius: undefined
67
+ }
68
+ };
69
+ },
70
+ getModel() {
71
+ const model = tf.sequential();
72
+ model.add(tf.layers.dense({
73
+ inputShape: [5],
74
+ units: 124,
75
+ activation: 'relu',
76
+ kernelInitializer: 'leCunNormal'
77
+ }));
78
+ model.add(tf.layers.dense({ units: 64, activation: 'relu' }));
79
+ model.add(tf.layers.dense({ units: 32, activation: 'relu' }));
80
+ model.add(tf.layers.dense({ units: 1, activation: 'sigmoid' }));
81
+ model.compile({
82
+ optimizer: tf.train.sgd(0.001),
83
+ loss: 'binaryCrossentropy',
84
+ metrics: ['accuracy']
85
+ });
86
+ return Promise.resolve(new models.TFJS(model));
87
+ }
88
+ };
@@ -0,0 +1,2 @@
1
+ import type { TaskProvider } from '../index.js';
2
+ export declare const wikitext: TaskProvider;
@@ -0,0 +1,38 @@
1
+ import { data, models } from '../index.js';
2
+ export const wikitext = {
3
+ getTask() {
4
+ return {
5
+ id: 'wikitext-103',
6
+ displayInformation: {
7
+ taskTitle: 'Language modelling on wikitext',
8
+ summary: {
9
+ preview: 'In this challenge, we ask you to do next word prediction on a dataset of Wikipedia articles.',
10
+ overview: 'Wikitext-103-raw is a dataset comprising unprocessed text excerpts from Wikipedia articles, designed for tasks related to natural language processing and language modeling.'
11
+ },
12
+ limitations: 'The dataset may contain noise, inconsistencies, and unstructured content due to its raw nature, potentially posing challenges for certain NLP tasks.',
13
+ tradeoffs: 'The raw format may lack structured annotations and may require additional preprocessing for specific applications.',
14
+ dataFormatInformation: 'The dataset is organized as a large text file, with each line representing a segment of raw text from Wikipedia articles.',
15
+ dataExampleText: 'An example excerpt from the dataset could be: "The history of artificial intelligence dates back to ancient times, with philosophical discussions on the nature of thought and reasoning."'
16
+ },
17
+ trainingInformation: {
18
+ dataType: 'text',
19
+ modelID: 'wikitext-103-raw-model',
20
+ preprocessingFunctions: [data.TextPreprocessing.Tokenize, data.TextPreprocessing.LeftPadding],
21
+ validationSplit: 0.2, // TODO: is this used somewhere? because train, eval and test are already split in dataset
22
+ epochs: 5,
23
+ scheme: 'federated',
24
+ noiseScale: undefined,
25
+ decentralizedSecure: true,
26
+ minimumReadyPeers: 3,
27
+ maxShareValue: 100,
28
+ roundDuration: 10,
29
+ batchSize: 16,
30
+ tokenizer: 'Xenova/gpt2',
31
+ maxSequenceLength: 128
32
+ }
33
+ };
34
+ },
35
+ getModel() {
36
+ return Promise.resolve(new models.GPT());
37
+ }
38
+ };
package/dist/index.d.ts CHANGED
@@ -1,2 +1,18 @@
1
- export * from './core';
2
- export * as browser from './imports';
1
+ export * as data from './dataset/index.js';
2
+ export * as serialization from './serialization/index.js';
3
+ export * as training from './training/index.js';
4
+ export * as privacy from './privacy.js';
5
+ export { GraphInformant } from './informant/index.js';
6
+ export * as client from './client/index.js';
7
+ export * as aggregator from './aggregator/index.js';
8
+ export { WeightsContainer, aggregation } from './weights/index.js';
9
+ export { AsyncInformant } from './async_informant.js';
10
+ export { Logger, ConsoleLogger } from './logging/index.js';
11
+ export { Memory, ModelType, type ModelInfo, type Path, type ModelSource, Empty as EmptyMemory } from './memory/index.js';
12
+ export { Disco, RoundLogs } from './training/index.js';
13
+ export { Validator } from './validation/index.js';
14
+ export { Model, EpochLogs } from './models/index.js';
15
+ export * as models from './models/index.js';
16
+ export * from './task/index.js';
17
+ export * as defaultTasks from './default_tasks/index.js';
18
+ export * from './types.js';
package/dist/index.js CHANGED
@@ -1,6 +1,18 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.browser = void 0;
4
- var tslib_1 = require("tslib");
5
- (0, tslib_1.__exportStar)(require("./core"), exports);
6
- exports.browser = (0, tslib_1.__importStar)(require("./imports"));
1
+ export * as data from './dataset/index.js';
2
+ export * as serialization from './serialization/index.js';
3
+ export * as training from './training/index.js';
4
+ export * as privacy from './privacy.js';
5
+ export { GraphInformant } from './informant/index.js';
6
+ export * as client from './client/index.js';
7
+ export * as aggregator from './aggregator/index.js';
8
+ export { WeightsContainer, aggregation } from './weights/index.js';
9
+ export { AsyncInformant } from './async_informant.js';
10
+ export { ConsoleLogger } from './logging/index.js';
11
+ export { Memory, ModelType, Empty as EmptyMemory } from './memory/index.js';
12
+ export { Disco } from './training/index.js';
13
+ export { Validator } from './validation/index.js';
14
+ export { Model } from './models/index.js';
15
+ export * as models from './models/index.js';
16
+ export * from './task/index.js';
17
+ export * as defaultTasks from './default_tasks/index.js';
18
+ export * from './types.js';
@@ -1,4 +1,4 @@
1
- import { List } from 'immutable';
1
+ import { type List } from 'immutable';
2
2
  export declare class GraphInformant {
3
3
  static readonly NB_EPOCHS_ON_GRAPH = 10;
4
4
  private currentAccuracy;
@@ -0,0 +1,20 @@
1
+ import { Repeat } from 'immutable';
2
+ export class GraphInformant {
3
+ static NB_EPOCHS_ON_GRAPH = 10;
4
+ currentAccuracy;
5
+ accuracyDataSeries;
6
+ constructor() {
7
+ this.currentAccuracy = 0;
8
+ this.accuracyDataSeries = Repeat(0, GraphInformant.NB_EPOCHS_ON_GRAPH).toList();
9
+ }
10
+ updateAccuracy(accuracy) {
11
+ this.accuracyDataSeries = this.accuracyDataSeries.shift().push(accuracy);
12
+ this.currentAccuracy = accuracy;
13
+ }
14
+ data() {
15
+ return this.accuracyDataSeries;
16
+ }
17
+ accuracy() {
18
+ return this.currentAccuracy;
19
+ }
20
+ }
@@ -0,0 +1 @@
1
+ export { GraphInformant } from './graph_informant.js';
@@ -0,0 +1 @@
1
+ export { GraphInformant } from './graph_informant.js';
@@ -1,10 +1,10 @@
1
- import { Logger } from './logger';
1
+ import { Logger } from "./logger.js";
2
2
  /**
3
3
  * Same properties as Toaster but on the console
4
4
  *
5
5
  * @class Logger
6
6
  */
7
- export declare class ConsoleLogger extends Logger {
7
+ export declare class ConsoleLogger implements Logger {
8
8
  /**
9
9
  * Logs success message on the console (in green)
10
10
  * @param {String} message - message to be displayed
@@ -0,0 +1,22 @@
1
+ import chalk from "chalk";
2
/**
 * Logger implementation that writes colored messages to the console.
 *
 * @class ConsoleLogger
 */
export class ConsoleLogger {
  /**
   * Logs a success message on the console (in green, via `console.log`).
   * @param {string} message - message to be displayed
   */
  success(message) {
    console.log(chalk.green(message));
  }

  /**
   * Logs an error message on the console (in red, via `console.error`).
   * @param {string} message - message to be displayed
   */
  error(message) {
    console.error(chalk.red(message));
  }
}
@@ -0,0 +1,2 @@
1
+ export { Logger } from './logger.js';
2
+ export { ConsoleLogger } from './console_logger.js';
@@ -0,0 +1 @@
1
+ export { ConsoleLogger } from './console_logger.js';
@@ -1,12 +1,12 @@
1
/**
 * Minimal logging interface: one method per message severity.
 */
export interface Logger {
  /**
   * Logs a success message (e.g. ConsoleLogger prints it in green).
   * @param message - message to be displayed
   */
  success(message: string): void;
  /**
   * Logs an error message (e.g. ConsoleLogger prints it in red).
   * @param message - message to be displayed
   */
  error(message: string): void;
}
@@ -0,0 +1 @@
1
+ export {};
@@ -0,0 +1,119 @@
1
+ import type { Model, TaskID } from '../index.js';
2
+ import type { ModelType } from './model_type.js';
3
/**
 * Model path which uniquely identifies a model in memory.
 */
export type Path = string;
/**
 * Model information which uniquely identifies a model in memory.
 */
export interface ModelInfo {
  /**
   * The model's type: "working" or "saved" model.
   */
  type?: ModelType;
  /**
   * The model's version, allowing several saved models of the same task to
   * coexist without causing ID conflicts.
   */
  version?: number;
  /**
   * The model's corresponding task.
   */
  taskID: TaskID;
  /**
   * The model's name.
   */
  name: string;
}
/**
 * A model source uniquely identifies a model stored in memory, either by its
 * path or by its model information.
 */
export type ModelSource = ModelInfo | Path;
/**
 * Represents a model memory system, providing functions to fetch, save, delete and update models.
 * Stored models can either be a model currently being trained ("working model") or a regular model
 * saved in memory ("saved model"). There can only be a single working model for a given task.
 */
export declare abstract class Memory {
  /**
   * Fetches the model identified by the given model source.
   * @param source The model source
   * @returns The model
   */
  abstract getModel(source: ModelSource): Promise<Model>;
  /**
   * Removes the model identified by the given model source from memory.
   * @param source The model source
   */
  abstract deleteModel(source: ModelSource): Promise<void>;
  /**
   * Replaces the corresponding working model with the saved model identified by the given model source.
   * @param source The model source
   */
  abstract loadModel(source: ModelSource): Promise<void>;
  /**
   * Fetches metadata for the model identified by the given model source.
   * If the model does not exist in memory, returns undefined.
   * @param source The model source
   * @returns The model metadata or undefined
   */
  abstract getModelMetadata(source: ModelSource): Promise<object | undefined>;
  /**
   * Replaces the working model identified by the given source with the newly provided model.
   * @param source The model source
   * @param model The new model
   */
  abstract updateWorkingModel(source: ModelSource, model: Model): Promise<void>;
  /**
   * Creates a saved model copy from the working model identified by the given model source.
   * @param source The model source
   * @returns The saved model's path, or undefined if no path is available
   */
  abstract saveWorkingModel(source: ModelSource): Promise<Path | undefined>;
  /**
   * Saves the newly provided model to the given model source.
   * @param source The model source
   * @param model The new model
   * @returns The saved model's path, or undefined if no path is available
   */
  abstract saveModel(source: ModelSource, model: Model): Promise<Path | undefined>;
  /**
   * Moves the model identified by the model source to a file system. This is platform-dependent.
   * @param source The model source
   */
  abstract downloadModel(source: ModelSource): Promise<void>;
  /**
   * Checks whether the model memory contains the model identified by the given source.
   * @param source The model source
   * @returns True if the memory contains the model, false otherwise
   */
  abstract contains(source: ModelSource): Promise<boolean>;
  /**
   * Computes the path in memory corresponding to the given model source, be it a path or model information.
   * This is used to easily switch between model path and information, which are both unique model identifiers
   * with a one-to-one correspondence. Returns undefined instead if no path could be inferred from the given
   * model source.
   * @param source The model source
   * @returns The model path, or undefined
   */
  abstract pathFor(source: ModelSource): Path | undefined;
  /**
   * Computes the model information corresponding to the given model source, be it a path or model information.
   * This is used to easily switch between model path and information, which are both unique model identifiers
   * with a one-to-one correspondence. Returns undefined instead if no unique model information could be inferred
   * from the given model source.
   * @param source The model source
   * @returns The model information, or undefined
   */
  abstract infoFor(source: ModelSource): ModelInfo | undefined;
  /**
   * Computes the lowest version a model source can have without conflicting with model versions currently in memory.
   * @param source The model source
   * @returns The duplicated model source, or undefined
   */
  abstract duplicateSource(source: ModelSource): Promise<ModelSource | undefined>;
}
@@ -0,0 +1,9 @@
1
// only used browser-side
// TODO: replace IO type

/**
 * Base class of model memory systems, which fetch, save, delete and update models.
 * A stored model is either the model currently being trained for a task (the
 * single "working model" of that task) or a regular "saved model".
 *
 * At runtime this class defines no members of its own; concrete subclasses
 * (such as `Empty`) supply the behaviour.
 */
class Memory {}

export { Memory };
@@ -0,0 +1,20 @@
1
+ import type { Model } from '../index.js';
2
+ import type { ModelInfo, Path } from './base.js';
3
+ import { Memory } from './base.js';
4
/**
 * Represents an empty model memory: it never holds a model, lookups miss,
 * and writes are accepted without storing anything.
 */
export declare class Empty extends Memory {
  /** Resolves to `undefined`; no metadata is ever available. */
  getModelMetadata(): Promise<undefined>;
  /** Resolves to `false`; no model is ever contained. */
  contains(): Promise<boolean>;
  /** Always rejects with an `Error('empty')`. */
  getModel(): Promise<Model>;
  /** Always rejects with an `Error('empty')`. */
  loadModel(): Promise<void>;
  /** Resolves without doing anything. */
  updateWorkingModel(): Promise<void>;
  /** Resolves to `undefined`; nothing is saved. */
  saveWorkingModel(): Promise<undefined>;
  /** Resolves to `undefined`; nothing is saved. */
  saveModel(): Promise<undefined>;
  /** Resolves without doing anything. */
  deleteModel(): Promise<void>;
  /** Always rejects with an `Error('empty')`. */
  downloadModel(): Promise<void>;
  /** Always throws an `Error('empty')` synchronously. */
  pathFor(): Path;
  /** Always throws an `Error('empty')` synchronously. */
  infoFor(): ModelInfo;
  /** Resolves to `undefined`. */
  duplicateSource(): Promise<undefined>;
}
@@ -0,0 +1,43 @@
1
+ import { Memory } from './base.js';
2
/**
 * Represents an empty model memory.
 *
 * Queries answer "nothing here" (`undefined` / `false`), writes are silently
 * accepted, reads of an actual model reject with `Error('empty')`, and the
 * synchronous path/info conversions throw `Error('empty')`.
 */
export class Empty extends Memory {
  async getModelMetadata() {
    return undefined;
  }

  async contains() {
    return false;
  }

  async getModel() {
    throw new Error('empty');
  }

  async loadModel() {
    throw new Error('empty');
  }

  async updateWorkingModel() {
    // nothing to do
  }

  async saveWorkingModel() {
    return undefined;
  }

  async saveModel() {
    return undefined;
  }

  async deleteModel() {
    // nothing to do
  }

  async downloadModel() {
    throw new Error('empty');
  }

  // Synchronous by contract, so these throw rather than reject.
  pathFor() {
    throw new Error('empty');
  }

  infoFor() {
    throw new Error('empty');
  }

  async duplicateSource() {
    return undefined;
  }
}
@@ -1 +1,3 @@
1
- export { IndexedDB } from './memory';
1
+ export { Empty } from './empty.js';
2
+ export { Memory, type ModelInfo, type Path, type ModelSource } from './base.js';
3
+ export { ModelType } from './model_type.js';
@@ -1,5 +1,3 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.IndexedDB = void 0;
4
- var memory_1 = require("./memory");
5
- Object.defineProperty(exports, "IndexedDB", { enumerable: true, get: function () { return memory_1.IndexedDB; } });
1
+ export { Empty } from './empty.js';
2
+ export { Memory } from './base.js';
3
+ export { ModelType } from './model_type.js';
@@ -0,0 +1,9 @@
1
/**
 * Type of models stored in memory: either the model currently being trained
 * ("working model") or a regular model saved in memory ("saved model").
 * A given task has at most one working model.
 */
export declare enum ModelType {
  /** The model currently being trained for a task (at most one per task). */
  WORKING = "working",
  /** A regular model saved in memory. */
  SAVED = "saved"
}
@@ -0,0 +1,10 @@
1
/**
 * Type of models stored in memory: either the model currently being trained
 * ("working model") or a regular model saved in memory ("saved model").
 * A given task has at most one working model.
 */
export var ModelType = {
  WORKING: "working",
  SAVED: "saved",
};
@@ -1,4 +1,4 @@
1
- import { Task, WeightsContainer } from '.';
1
+ import type { Task, WeightsContainer } from './index.js';
2
2
  /**
3
3
  * Add task-parametrized Gaussian noise to and clip the weights update between the previous and current rounds.
4
4
  * The previous round's weights are the last weights pulled from server/peers.