@epfml/discojs 2.1.1 → 2.1.2-p20240506085037.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (334)
  1. package/dist/aggregator/base.d.ts +180 -0
  2. package/dist/aggregator/base.js +236 -0
  3. package/dist/aggregator/get.d.ts +16 -0
  4. package/dist/aggregator/get.js +31 -0
  5. package/dist/aggregator/index.d.ts +7 -0
  6. package/dist/aggregator/index.js +4 -0
  7. package/dist/aggregator/mean.d.ts +23 -0
  8. package/dist/aggregator/mean.js +69 -0
  9. package/dist/aggregator/secure.d.ts +27 -0
  10. package/dist/aggregator/secure.js +91 -0
  11. package/dist/async_informant.d.ts +15 -0
  12. package/dist/async_informant.js +42 -0
  13. package/dist/client/base.d.ts +76 -0
  14. package/dist/client/base.js +88 -0
  15. package/dist/client/decentralized/base.d.ts +32 -0
  16. package/dist/client/decentralized/base.js +192 -0
  17. package/dist/client/decentralized/index.d.ts +2 -0
  18. package/dist/client/decentralized/index.js +2 -0
  19. package/dist/client/decentralized/messages.d.ts +28 -0
  20. package/dist/client/decentralized/messages.js +44 -0
  21. package/dist/client/decentralized/peer.d.ts +40 -0
  22. package/dist/client/decentralized/peer.js +189 -0
  23. package/dist/client/decentralized/peer_pool.d.ts +12 -0
  24. package/dist/client/decentralized/peer_pool.js +44 -0
  25. package/dist/client/event_connection.d.ts +34 -0
  26. package/dist/client/event_connection.js +105 -0
  27. package/dist/client/federated/base.d.ts +54 -0
  28. package/dist/client/federated/base.js +151 -0
  29. package/dist/client/federated/index.d.ts +2 -0
  30. package/dist/client/federated/index.js +2 -0
  31. package/dist/client/federated/messages.d.ts +30 -0
  32. package/dist/client/federated/messages.js +24 -0
  33. package/dist/client/index.d.ts +8 -0
  34. package/dist/client/index.js +8 -0
  35. package/dist/client/local.d.ts +3 -0
  36. package/dist/client/local.js +3 -0
  37. package/dist/client/messages.d.ts +30 -0
  38. package/dist/client/messages.js +26 -0
  39. package/dist/client/types.d.ts +2 -0
  40. package/dist/client/types.js +4 -0
  41. package/dist/client/utils.d.ts +2 -0
  42. package/dist/client/utils.js +7 -0
  43. package/dist/dataset/data/data.d.ts +48 -0
  44. package/dist/dataset/data/data.js +72 -0
  45. package/dist/dataset/data/data_split.d.ts +8 -0
  46. package/dist/dataset/data/data_split.js +1 -0
  47. package/dist/dataset/data/image_data.d.ts +11 -0
  48. package/dist/dataset/data/image_data.js +38 -0
  49. package/dist/dataset/data/index.d.ts +6 -0
  50. package/dist/dataset/data/index.js +5 -0
  51. package/dist/dataset/data/preprocessing/base.d.ts +16 -0
  52. package/dist/dataset/data/preprocessing/base.js +1 -0
  53. package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +13 -0
  54. package/dist/dataset/data/preprocessing/image_preprocessing.js +40 -0
  55. package/dist/dataset/data/preprocessing/index.d.ts +4 -0
  56. package/dist/dataset/data/preprocessing/index.js +3 -0
  57. package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +13 -0
  58. package/dist/dataset/data/preprocessing/tabular_preprocessing.js +45 -0
  59. package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +13 -0
  60. package/dist/dataset/data/preprocessing/text_preprocessing.js +85 -0
  61. package/dist/dataset/data/tabular_data.d.ts +11 -0
  62. package/dist/dataset/data/tabular_data.js +25 -0
  63. package/dist/dataset/data/text_data.d.ts +11 -0
  64. package/dist/dataset/data/text_data.js +14 -0
  65. package/dist/{core/dataset → dataset}/data_loader/data_loader.d.ts +3 -5
  66. package/dist/dataset/data_loader/data_loader.js +2 -0
  67. package/dist/dataset/data_loader/image_loader.d.ts +20 -3
  68. package/dist/dataset/data_loader/image_loader.js +98 -23
  69. package/dist/dataset/data_loader/index.d.ts +5 -2
  70. package/dist/dataset/data_loader/index.js +4 -7
  71. package/dist/dataset/data_loader/tabular_loader.d.ts +34 -3
  72. package/dist/dataset/data_loader/tabular_loader.js +75 -15
  73. package/dist/dataset/data_loader/text_loader.d.ts +14 -0
  74. package/dist/dataset/data_loader/text_loader.js +25 -0
  75. package/dist/dataset/dataset.d.ts +5 -0
  76. package/dist/dataset/dataset.js +1 -0
  77. package/dist/dataset/dataset_builder.d.ts +60 -0
  78. package/dist/dataset/dataset_builder.js +142 -0
  79. package/dist/dataset/index.d.ts +5 -0
  80. package/dist/dataset/index.js +3 -0
  81. package/dist/default_tasks/cifar10/index.d.ts +2 -0
  82. package/dist/{core/default_tasks/cifar10.js → default_tasks/cifar10/index.js} +28 -36
  83. package/dist/default_tasks/cifar10/model.d.ts +434 -0
  84. package/dist/default_tasks/cifar10/model.js +2385 -0
  85. package/dist/default_tasks/geotags/index.d.ts +2 -0
  86. package/dist/default_tasks/geotags/index.js +65 -0
  87. package/dist/default_tasks/geotags/model.d.ts +593 -0
  88. package/dist/default_tasks/geotags/model.js +4715 -0
  89. package/dist/default_tasks/index.d.ts +8 -0
  90. package/dist/default_tasks/index.js +8 -0
  91. package/dist/default_tasks/lus_covid.d.ts +2 -0
  92. package/dist/default_tasks/lus_covid.js +89 -0
  93. package/dist/default_tasks/mnist.d.ts +2 -0
  94. package/dist/{core/default_tasks → default_tasks}/mnist.js +26 -34
  95. package/dist/default_tasks/simple_face/index.d.ts +2 -0
  96. package/dist/{core/default_tasks/simple_face.js → default_tasks/simple_face/index.js} +17 -22
  97. package/dist/default_tasks/simple_face/model.d.ts +513 -0
  98. package/dist/default_tasks/simple_face/model.js +4301 -0
  99. package/dist/default_tasks/skin_mnist.d.ts +2 -0
  100. package/dist/default_tasks/skin_mnist.js +80 -0
  101. package/dist/default_tasks/titanic.d.ts +2 -0
  102. package/dist/{core/default_tasks → default_tasks}/titanic.js +24 -33
  103. package/dist/default_tasks/wikitext.d.ts +2 -0
  104. package/dist/default_tasks/wikitext.js +38 -0
  105. package/dist/index.d.ts +18 -2
  106. package/dist/index.js +18 -6
  107. package/dist/{core/informant → informant}/graph_informant.d.ts +1 -1
  108. package/dist/informant/graph_informant.js +20 -0
  109. package/dist/informant/index.d.ts +1 -0
  110. package/dist/informant/index.js +1 -0
  111. package/dist/{core/logging → logging}/console_logger.d.ts +2 -2
  112. package/dist/logging/console_logger.js +22 -0
  113. package/dist/logging/index.d.ts +2 -0
  114. package/dist/logging/index.js +1 -0
  115. package/dist/{core/logging → logging}/logger.d.ts +3 -3
  116. package/dist/logging/logger.js +1 -0
  117. package/dist/memory/base.d.ts +119 -0
  118. package/dist/memory/base.js +9 -0
  119. package/dist/memory/empty.d.ts +20 -0
  120. package/dist/memory/empty.js +43 -0
  121. package/dist/memory/index.d.ts +3 -1
  122. package/dist/memory/index.js +3 -5
  123. package/dist/memory/model_type.d.ts +9 -0
  124. package/dist/memory/model_type.js +10 -0
  125. package/dist/{core/privacy.d.ts → privacy.d.ts} +1 -1
  126. package/dist/{core/privacy.js → privacy.js} +11 -16
  127. package/dist/serialization/index.d.ts +2 -0
  128. package/dist/serialization/index.js +2 -0
  129. package/dist/serialization/model.d.ts +5 -0
  130. package/dist/serialization/model.js +67 -0
  131. package/dist/{core/serialization → serialization}/weights.d.ts +2 -2
  132. package/dist/serialization/weights.js +37 -0
  133. package/dist/task/data_example.js +14 -0
  134. package/dist/task/digest.js +14 -0
  135. package/dist/{core/task → task}/display_information.d.ts +5 -3
  136. package/dist/task/display_information.js +46 -0
  137. package/dist/task/index.d.ts +7 -0
  138. package/dist/task/index.js +5 -0
  139. package/dist/task/label_type.d.ts +9 -0
  140. package/dist/task/label_type.js +28 -0
  141. package/dist/task/summary.js +13 -0
  142. package/dist/{core/task → task}/task.d.ts +7 -7
  143. package/dist/task/task.js +22 -0
  144. package/dist/task/task_handler.d.ts +5 -0
  145. package/dist/task/task_handler.js +20 -0
  146. package/dist/task/task_provider.d.ts +5 -0
  147. package/dist/task/task_provider.js +1 -0
  148. package/dist/{core/task → task}/training_information.d.ts +9 -10
  149. package/dist/task/training_information.js +88 -0
  150. package/dist/training/disco.d.ts +40 -0
  151. package/dist/training/disco.js +107 -0
  152. package/dist/training/index.d.ts +2 -0
  153. package/dist/training/index.js +1 -0
  154. package/dist/training/trainer/distributed_trainer.d.ts +20 -0
  155. package/dist/training/trainer/distributed_trainer.js +36 -0
  156. package/dist/training/trainer/local_trainer.d.ts +12 -0
  157. package/dist/training/trainer/local_trainer.js +19 -0
  158. package/dist/training/trainer/trainer.d.ts +33 -0
  159. package/dist/training/trainer/trainer.js +52 -0
  160. package/dist/{core/training → training}/trainer/trainer_builder.d.ts +5 -7
  161. package/dist/training/trainer/trainer_builder.js +43 -0
  162. package/dist/types.d.ts +8 -0
  163. package/dist/types.js +1 -0
  164. package/dist/utils/event_emitter.d.ts +40 -0
  165. package/dist/utils/event_emitter.js +57 -0
  166. package/dist/validation/index.d.ts +1 -0
  167. package/dist/validation/index.js +1 -0
  168. package/dist/validation/validator.d.ts +28 -0
  169. package/dist/validation/validator.js +132 -0
  170. package/dist/weights/aggregation.d.ts +21 -0
  171. package/dist/weights/aggregation.js +44 -0
  172. package/dist/weights/index.d.ts +2 -0
  173. package/dist/weights/index.js +2 -0
  174. package/dist/weights/weights_container.d.ts +68 -0
  175. package/dist/weights/weights_container.js +96 -0
  176. package/package.json +24 -15
  177. package/README.md +0 -53
  178. package/dist/core/async_buffer.d.ts +0 -41
  179. package/dist/core/async_buffer.js +0 -97
  180. package/dist/core/async_informant.d.ts +0 -20
  181. package/dist/core/async_informant.js +0 -69
  182. package/dist/core/client/base.d.ts +0 -33
  183. package/dist/core/client/base.js +0 -35
  184. package/dist/core/client/decentralized/base.d.ts +0 -32
  185. package/dist/core/client/decentralized/base.js +0 -212
  186. package/dist/core/client/decentralized/clear_text.d.ts +0 -14
  187. package/dist/core/client/decentralized/clear_text.js +0 -96
  188. package/dist/core/client/decentralized/index.d.ts +0 -4
  189. package/dist/core/client/decentralized/index.js +0 -9
  190. package/dist/core/client/decentralized/messages.d.ts +0 -41
  191. package/dist/core/client/decentralized/messages.js +0 -54
  192. package/dist/core/client/decentralized/peer.d.ts +0 -26
  193. package/dist/core/client/decentralized/peer.js +0 -210
  194. package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
  195. package/dist/core/client/decentralized/peer_pool.js +0 -92
  196. package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
  197. package/dist/core/client/decentralized/sec_agg.js +0 -190
  198. package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
  199. package/dist/core/client/decentralized/secret_shares.js +0 -39
  200. package/dist/core/client/decentralized/types.d.ts +0 -2
  201. package/dist/core/client/decentralized/types.js +0 -7
  202. package/dist/core/client/event_connection.d.ts +0 -37
  203. package/dist/core/client/event_connection.js +0 -158
  204. package/dist/core/client/federated/client.d.ts +0 -37
  205. package/dist/core/client/federated/client.js +0 -273
  206. package/dist/core/client/federated/index.d.ts +0 -2
  207. package/dist/core/client/federated/index.js +0 -7
  208. package/dist/core/client/federated/messages.d.ts +0 -38
  209. package/dist/core/client/federated/messages.js +0 -25
  210. package/dist/core/client/index.d.ts +0 -5
  211. package/dist/core/client/index.js +0 -11
  212. package/dist/core/client/local.d.ts +0 -8
  213. package/dist/core/client/local.js +0 -36
  214. package/dist/core/client/messages.d.ts +0 -28
  215. package/dist/core/client/messages.js +0 -33
  216. package/dist/core/client/utils.d.ts +0 -2
  217. package/dist/core/client/utils.js +0 -19
  218. package/dist/core/dataset/data/data.d.ts +0 -11
  219. package/dist/core/dataset/data/data.js +0 -20
  220. package/dist/core/dataset/data/data_split.d.ts +0 -5
  221. package/dist/core/dataset/data/data_split.js +0 -2
  222. package/dist/core/dataset/data/image_data.d.ts +0 -8
  223. package/dist/core/dataset/data/image_data.js +0 -64
  224. package/dist/core/dataset/data/index.d.ts +0 -5
  225. package/dist/core/dataset/data/index.js +0 -11
  226. package/dist/core/dataset/data/preprocessing.d.ts +0 -13
  227. package/dist/core/dataset/data/preprocessing.js +0 -33
  228. package/dist/core/dataset/data/tabular_data.d.ts +0 -8
  229. package/dist/core/dataset/data/tabular_data.js +0 -40
  230. package/dist/core/dataset/data_loader/data_loader.js +0 -10
  231. package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
  232. package/dist/core/dataset/data_loader/image_loader.js +0 -141
  233. package/dist/core/dataset/data_loader/index.d.ts +0 -3
  234. package/dist/core/dataset/data_loader/index.js +0 -9
  235. package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
  236. package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
  237. package/dist/core/dataset/dataset.d.ts +0 -2
  238. package/dist/core/dataset/dataset.js +0 -2
  239. package/dist/core/dataset/dataset_builder.d.ts +0 -18
  240. package/dist/core/dataset/dataset_builder.js +0 -96
  241. package/dist/core/dataset/index.d.ts +0 -4
  242. package/dist/core/dataset/index.js +0 -14
  243. package/dist/core/default_tasks/cifar10.d.ts +0 -2
  244. package/dist/core/default_tasks/geotags.d.ts +0 -2
  245. package/dist/core/default_tasks/geotags.js +0 -69
  246. package/dist/core/default_tasks/index.d.ts +0 -6
  247. package/dist/core/default_tasks/index.js +0 -15
  248. package/dist/core/default_tasks/lus_covid.d.ts +0 -2
  249. package/dist/core/default_tasks/lus_covid.js +0 -96
  250. package/dist/core/default_tasks/mnist.d.ts +0 -2
  251. package/dist/core/default_tasks/simple_face.d.ts +0 -2
  252. package/dist/core/default_tasks/titanic.d.ts +0 -2
  253. package/dist/core/index.d.ts +0 -18
  254. package/dist/core/index.js +0 -39
  255. package/dist/core/informant/graph_informant.js +0 -23
  256. package/dist/core/informant/index.d.ts +0 -3
  257. package/dist/core/informant/index.js +0 -9
  258. package/dist/core/informant/training_informant/base.d.ts +0 -31
  259. package/dist/core/informant/training_informant/base.js +0 -83
  260. package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
  261. package/dist/core/informant/training_informant/decentralized.js +0 -22
  262. package/dist/core/informant/training_informant/federated.d.ts +0 -14
  263. package/dist/core/informant/training_informant/federated.js +0 -32
  264. package/dist/core/informant/training_informant/index.d.ts +0 -4
  265. package/dist/core/informant/training_informant/index.js +0 -11
  266. package/dist/core/informant/training_informant/local.d.ts +0 -6
  267. package/dist/core/informant/training_informant/local.js +0 -20
  268. package/dist/core/logging/console_logger.js +0 -33
  269. package/dist/core/logging/index.d.ts +0 -3
  270. package/dist/core/logging/index.js +0 -9
  271. package/dist/core/logging/logger.js +0 -9
  272. package/dist/core/logging/trainer_logger.d.ts +0 -24
  273. package/dist/core/logging/trainer_logger.js +0 -59
  274. package/dist/core/memory/base.d.ts +0 -22
  275. package/dist/core/memory/base.js +0 -9
  276. package/dist/core/memory/empty.d.ts +0 -14
  277. package/dist/core/memory/empty.js +0 -75
  278. package/dist/core/memory/index.d.ts +0 -3
  279. package/dist/core/memory/index.js +0 -9
  280. package/dist/core/memory/model_type.d.ts +0 -4
  281. package/dist/core/memory/model_type.js +0 -9
  282. package/dist/core/serialization/index.d.ts +0 -2
  283. package/dist/core/serialization/index.js +0 -6
  284. package/dist/core/serialization/model.d.ts +0 -5
  285. package/dist/core/serialization/model.js +0 -55
  286. package/dist/core/serialization/weights.js +0 -64
  287. package/dist/core/task/data_example.js +0 -24
  288. package/dist/core/task/digest.js +0 -18
  289. package/dist/core/task/display_information.js +0 -49
  290. package/dist/core/task/index.d.ts +0 -6
  291. package/dist/core/task/index.js +0 -15
  292. package/dist/core/task/model_compile_data.d.ts +0 -6
  293. package/dist/core/task/model_compile_data.js +0 -22
  294. package/dist/core/task/summary.js +0 -19
  295. package/dist/core/task/task.js +0 -35
  296. package/dist/core/task/task_handler.d.ts +0 -5
  297. package/dist/core/task/task_handler.js +0 -53
  298. package/dist/core/task/task_provider.d.ts +0 -6
  299. package/dist/core/task/task_provider.js +0 -13
  300. package/dist/core/task/training_information.js +0 -66
  301. package/dist/core/training/disco.d.ts +0 -23
  302. package/dist/core/training/disco.js +0 -130
  303. package/dist/core/training/index.d.ts +0 -2
  304. package/dist/core/training/index.js +0 -7
  305. package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
  306. package/dist/core/training/trainer/distributed_trainer.js +0 -65
  307. package/dist/core/training/trainer/local_trainer.d.ts +0 -11
  308. package/dist/core/training/trainer/local_trainer.js +0 -34
  309. package/dist/core/training/trainer/round_tracker.d.ts +0 -30
  310. package/dist/core/training/trainer/round_tracker.js +0 -47
  311. package/dist/core/training/trainer/trainer.d.ts +0 -65
  312. package/dist/core/training/trainer/trainer.js +0 -160
  313. package/dist/core/training/trainer/trainer_builder.js +0 -95
  314. package/dist/core/training/training_schemes.d.ts +0 -5
  315. package/dist/core/training/training_schemes.js +0 -10
  316. package/dist/core/types.d.ts +0 -4
  317. package/dist/core/types.js +0 -2
  318. package/dist/core/validation/index.d.ts +0 -1
  319. package/dist/core/validation/index.js +0 -5
  320. package/dist/core/validation/validator.d.ts +0 -17
  321. package/dist/core/validation/validator.js +0 -104
  322. package/dist/core/weights/aggregation.d.ts +0 -7
  323. package/dist/core/weights/aggregation.js +0 -72
  324. package/dist/core/weights/index.d.ts +0 -2
  325. package/dist/core/weights/index.js +0 -7
  326. package/dist/core/weights/weights_container.d.ts +0 -19
  327. package/dist/core/weights/weights_container.js +0 -64
  328. package/dist/imports.d.ts +0 -2
  329. package/dist/imports.js +0 -7
  330. package/dist/memory/memory.d.ts +0 -26
  331. package/dist/memory/memory.js +0 -160
  332. package/dist/{core/task → task}/data_example.d.ts +1 -1
  333. package/dist/{core/task → task}/digest.d.ts +0 -0
  334. package/dist/{core/task → task}/summary.d.ts +1 -1
@@ -0,0 +1,8 @@
1
+ export { cifar10 } from './cifar10/index.js';
2
+ export { geotags } from './geotags/index.js';
3
+ export { lusCovid } from './lus_covid.js';
4
+ export { mnist } from './mnist.js';
5
+ export { simpleFace } from './simple_face/index.js';
6
+ export { skinMnist } from './skin_mnist.js';
7
+ export { titanic } from './titanic.js';
8
+ export { wikitext } from './wikitext.js';
@@ -0,0 +1,8 @@
1
+ export { cifar10 } from './cifar10/index.js';
2
+ export { geotags } from './geotags/index.js';
3
+ export { lusCovid } from './lus_covid.js';
4
+ export { mnist } from './mnist.js';
5
+ export { simpleFace } from './simple_face/index.js';
6
+ export { skinMnist } from './skin_mnist.js';
7
+ export { titanic } from './titanic.js';
8
+ export { wikitext } from './wikitext.js';
@@ -0,0 +1,2 @@
1
+ import type { TaskProvider } from '../index.js';
2
+ export declare const lusCovid: TaskProvider;
@@ -0,0 +1,89 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { data, models } from '../index.js';
3
+ export const lusCovid = {
4
+ getTask() {
5
+ return {
6
+ id: 'lus_covid',
7
+ displayInformation: {
8
+ taskTitle: 'COVID Lung Ultrasound',
9
+ summary: {
10
+ preview: 'Do you have a data of lung ultrasound images on patients <b>suspected of Lower Respiratory Tract infection (LRTI) during the COVID pandemic</b>? <br> Learn how to discriminate between COVID positive and negative patients by joining this task.',
11
+ overview: "Don’t have a dataset of your own? Download a sample of a few cases <a class='underline' href='https://drive.switch.ch/index.php/s/zM5ZrUWK3taaIly' target='_blank'>here</a>."
12
+ },
13
+ model: "We use a simplified* version of the <b>DeepChest model</b>: A deep learning model developed in our lab (<a class='underline' href='https://www.epfl.ch/labs/mlo/igh-intelligent-global-health/'>intelligent Global Health</a>.). On a cohort of 400 Swiss patients suspected of LRTI, the model obtained over 90% area under the ROC curve for this task. <br><br>*Simplified to ensure smooth running on your browser, the performance is minimally affected. Details of the adaptations are below <br>- <b>Removed</b>: positional embedding (i.e. we don’t take the anatomic position into consideration). Rather, the model now does mean pooling over the feature vector of the images for each patient <br>- <b>Replaced</b>: ResNet18 by Mobilenet",
14
+ tradeoffs: 'We are using a simpler version of DeepChest in order to be able to run it on the browser.',
15
+ dataFormatInformation: 'This model takes as input an image dataset. It consists on a set of lung ultrasound images per patient with its corresponding label of covid positive or negative. Moreover, to identify the images per patient you have to follow the follwing naming pattern: "patientId_*.png"',
16
+ dataExampleText: 'Below you can find an example of an expected lung image for patient 2 named: 2_QAID_1.masked.reshaped.squared.224.png',
17
+ dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/2_QAID_1.masked.reshaped.squared.224.png'
18
+ },
19
+ trainingInformation: {
20
+ modelID: 'lus-covid-model',
21
+ epochs: 50,
22
+ roundDuration: 2,
23
+ validationSplit: 0,
24
+ batchSize: 5,
25
+ IMAGE_H: 100,
26
+ IMAGE_W: 100,
27
+ preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
28
+ LABEL_LIST: ['COVID-Positive', 'COVID-Negative'],
29
+ dataType: 'image',
30
+ scheme: 'federated',
31
+ noiseScale: undefined,
32
+ clippingRadius: 20,
33
+ decentralizedSecure: true,
34
+ minimumReadyPeers: 2,
35
+ maxShareValue: 100
36
+ }
37
+ };
38
+ },
39
+ // Model architecture from tensorflow.js docs:
40
+ // https://codelabs.developers.google.com/codelabs/tfjs-training-classfication/index.html#4
41
+ async getModel() {
42
+ const imageHeight = 100;
43
+ const imageWidth = 100;
44
+ const imageChannels = 3;
45
+ const numOutputClasses = 2;
46
+ const model = tf.sequential();
47
+ // In the first layer of our convolutional neural network we have
48
+ // to specify the input shape. Then we specify some parameters for
49
+ // the convolution operation that takes place in this layer.
50
+ model.add(tf.layers.conv2d({
51
+ inputShape: [imageHeight, imageWidth, imageChannels],
52
+ kernelSize: 5,
53
+ filters: 8,
54
+ strides: 1,
55
+ activation: 'relu',
56
+ kernelInitializer: 'varianceScaling'
57
+ }));
58
+ // The MaxPooling layer acts as a sort of downsampling using max values
59
+ // in a region instead of averaging.
60
+ model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));
61
+ // Repeat the conv2d + maxPooling block.
62
+ // Note that we have more filters in the convolution.
63
+ model.add(tf.layers.conv2d({
64
+ kernelSize: 5,
65
+ filters: 16,
66
+ strides: 1,
67
+ activation: 'relu',
68
+ kernelInitializer: 'varianceScaling'
69
+ }));
70
+ model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));
71
+ // Now we flatten the output from the 2D filters into a 1D vector to prepare
72
+ // it for input into our last layer. This is common practice when feeding
73
+ // higher dimensional data to a final classification output layer.
74
+ model.add(tf.layers.flatten());
75
+ // Our last layer is a dense layer which has 2 output units, one for each
76
+ // output class.
77
+ model.add(tf.layers.dense({
78
+ units: numOutputClasses,
79
+ kernelInitializer: 'varianceScaling',
80
+ activation: 'softmax'
81
+ }));
82
+ model.compile({
83
+ optimizer: 'sgd',
84
+ loss: 'binaryCrossentropy',
85
+ metrics: ['accuracy']
86
+ });
87
+ return Promise.resolve(new models.TFJS(model));
88
+ }
89
+ };
@@ -0,0 +1,2 @@
1
+ import type { TaskProvider } from '../index.js';
2
+ export declare const mnist: TaskProvider;
@@ -1,12 +1,9 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.mnist = void 0;
4
- var tslib_1 = require("tslib");
5
- var __1 = require("..");
6
- exports.mnist = {
7
- getTask: function () {
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { models } from '../index.js';
3
+ export const mnist = {
4
+ getTask() {
8
5
  return {
9
- taskID: 'mnist',
6
+ id: 'mnist',
10
7
  displayInformation: {
11
8
  taskTitle: 'MNIST',
12
9
  summary: {
@@ -25,17 +22,12 @@ exports.mnist = {
25
22
  roundDuration: 10,
26
23
  validationSplit: 0.2,
27
24
  batchSize: 30,
28
- modelCompileData: {
29
- optimizer: 'rmsprop',
30
- loss: 'categoricalCrossentropy',
31
- metrics: ['accuracy']
32
- },
33
25
  dataType: 'image',
34
26
  IMAGE_H: 28,
35
27
  IMAGE_W: 28,
36
28
  preprocessingFunctions: [],
37
29
  LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
38
- scheme: 'Decentralized',
30
+ scheme: 'decentralized',
39
31
  noiseScale: undefined,
40
32
  clippingRadius: 20,
41
33
  decentralizedSecure: true,
@@ -44,26 +36,26 @@ exports.mnist = {
44
36
  }
45
37
  };
46
38
  },
47
- getModel: function () {
48
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
49
- var model;
50
- return (0, tslib_1.__generator)(this, function (_a) {
51
- model = __1.tf.sequential();
52
- model.add(__1.tf.layers.conv2d({
53
- inputShape: [28, 28, 3],
54
- kernelSize: 3,
55
- filters: 16,
56
- activation: 'relu'
57
- }));
58
- model.add(__1.tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
59
- model.add(__1.tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
60
- model.add(__1.tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
61
- model.add(__1.tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
62
- model.add(__1.tf.layers.flatten({}));
63
- model.add(__1.tf.layers.dense({ units: 64, activation: 'relu' }));
64
- model.add(__1.tf.layers.dense({ units: 10, activation: 'softmax' }));
65
- return [2 /*return*/, model];
66
- });
39
+ getModel() {
40
+ const model = tf.sequential();
41
+ model.add(tf.layers.conv2d({
42
+ inputShape: [28, 28, 3],
43
+ kernelSize: 3,
44
+ filters: 16,
45
+ activation: 'relu'
46
+ }));
47
+ model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
48
+ model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
49
+ model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
50
+ model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
51
+ model.add(tf.layers.flatten({}));
52
+ model.add(tf.layers.dense({ units: 64, activation: 'relu' }));
53
+ model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));
54
+ model.compile({
55
+ optimizer: 'rmsprop',
56
+ loss: 'categoricalCrossentropy',
57
+ metrics: ['accuracy']
67
58
  });
59
+ return Promise.resolve(new models.TFJS(model));
68
60
  }
69
61
  };
@@ -0,0 +1,2 @@
1
+ import type { TaskProvider } from '../../index.js';
2
+ export declare const simpleFace: TaskProvider;
@@ -1,12 +1,10 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.simpleFace = void 0;
4
- var tslib_1 = require("tslib");
5
- var __1 = require("..");
6
- exports.simpleFace = {
7
- getTask: function () {
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { data, models } from '../../index.js';
3
+ import baseModel from './model.js';
4
+ export const simpleFace = {
5
+ getTask() {
8
6
  return {
9
- taskID: 'simple_face',
7
+ id: 'simple_face',
10
8
  displayInformation: {
11
9
  taskTitle: 'Simple Face',
12
10
  summary: {
@@ -22,32 +20,29 @@ exports.simpleFace = {
22
20
  trainingInformation: {
23
21
  modelID: 'simple_face-model',
24
22
  epochs: 50,
25
- modelURL: 'https://storage.googleapis.com/deai-313515.appspot.com/models/mobileNetV2_35_alpha_2_classes/model.json',
26
23
  roundDuration: 1,
27
24
  validationSplit: 0.2,
28
25
  batchSize: 10,
29
- preprocessingFunctions: [__1.data.ImagePreprocessing.Normalize],
30
- learningRate: 0.001,
31
- modelCompileData: {
32
- optimizer: 'sgd',
33
- loss: 'categoricalCrossentropy',
34
- metrics: ['accuracy']
35
- },
26
+ preprocessingFunctions: [data.ImagePreprocessing.Normalize],
36
27
  dataType: 'image',
37
28
  IMAGE_H: 200,
38
29
  IMAGE_W: 200,
39
30
  LABEL_LIST: ['child', 'adult'],
40
- scheme: 'Federated',
31
+ scheme: 'federated', // secure aggregation not yet implemented for federated
41
32
  noiseScale: undefined,
42
33
  clippingRadius: undefined
43
34
  }
44
35
  };
45
36
  },
46
- getModel: function () {
47
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
48
- return (0, tslib_1.__generator)(this, function (_a) {
49
- throw new Error('Not implemented');
50
- });
37
+ async getModel() {
38
+ const model = await tf.loadLayersModel({
39
+ load: async () => Promise.resolve(baseModel),
40
+ });
41
+ model.compile({
42
+ optimizer: tf.train.sgd(0.001),
43
+ loss: 'categoricalCrossentropy',
44
+ metrics: ['accuracy']
51
45
  });
46
+ return new models.TFJS(model);
52
47
  }
53
48
  };