@epfml/discojs 2.0.0 → 2.1.2-p20240506085037.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.d.ts +180 -0
- package/dist/aggregator/base.js +236 -0
- package/dist/aggregator/get.d.ts +16 -0
- package/dist/aggregator/get.js +31 -0
- package/dist/aggregator/index.d.ts +7 -0
- package/dist/aggregator/index.js +4 -0
- package/dist/aggregator/mean.d.ts +23 -0
- package/dist/aggregator/mean.js +69 -0
- package/dist/aggregator/secure.d.ts +27 -0
- package/dist/aggregator/secure.js +91 -0
- package/dist/async_informant.d.ts +15 -0
- package/dist/async_informant.js +42 -0
- package/dist/client/base.d.ts +76 -0
- package/dist/client/base.js +88 -0
- package/dist/client/decentralized/base.d.ts +32 -0
- package/dist/client/decentralized/base.js +192 -0
- package/dist/client/decentralized/index.d.ts +2 -0
- package/dist/client/decentralized/index.js +2 -0
- package/dist/client/decentralized/messages.d.ts +28 -0
- package/dist/client/decentralized/messages.js +44 -0
- package/dist/client/decentralized/peer.d.ts +40 -0
- package/dist/client/decentralized/peer.js +189 -0
- package/dist/client/decentralized/peer_pool.d.ts +12 -0
- package/dist/client/decentralized/peer_pool.js +44 -0
- package/dist/client/event_connection.d.ts +34 -0
- package/dist/client/event_connection.js +105 -0
- package/dist/client/federated/base.d.ts +54 -0
- package/dist/client/federated/base.js +151 -0
- package/dist/client/federated/index.d.ts +2 -0
- package/dist/client/federated/index.js +2 -0
- package/dist/client/federated/messages.d.ts +30 -0
- package/dist/client/federated/messages.js +24 -0
- package/dist/client/index.d.ts +8 -0
- package/dist/client/index.js +8 -0
- package/dist/client/local.d.ts +3 -0
- package/dist/client/local.js +3 -0
- package/dist/client/messages.d.ts +30 -0
- package/dist/client/messages.js +26 -0
- package/dist/client/types.d.ts +2 -0
- package/dist/client/types.js +4 -0
- package/dist/client/utils.d.ts +2 -0
- package/dist/client/utils.js +7 -0
- package/dist/dataset/data/data.d.ts +48 -0
- package/dist/dataset/data/data.js +72 -0
- package/dist/dataset/data/data_split.d.ts +8 -0
- package/dist/dataset/data/data_split.js +1 -0
- package/dist/dataset/data/image_data.d.ts +11 -0
- package/dist/dataset/data/image_data.js +38 -0
- package/dist/dataset/data/index.d.ts +6 -0
- package/dist/dataset/data/index.js +5 -0
- package/dist/dataset/data/preprocessing/base.d.ts +16 -0
- package/dist/dataset/data/preprocessing/base.js +1 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.js +40 -0
- package/dist/dataset/data/preprocessing/index.d.ts +4 -0
- package/dist/dataset/data/preprocessing/index.js +3 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.js +45 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.js +85 -0
- package/dist/dataset/data/tabular_data.d.ts +11 -0
- package/dist/dataset/data/tabular_data.js +25 -0
- package/dist/dataset/data/text_data.d.ts +11 -0
- package/dist/dataset/data/text_data.js +14 -0
- package/dist/{core/dataset → dataset}/data_loader/data_loader.d.ts +3 -5
- package/dist/dataset/data_loader/data_loader.js +2 -0
- package/dist/dataset/data_loader/image_loader.d.ts +20 -3
- package/dist/dataset/data_loader/image_loader.js +98 -23
- package/dist/dataset/data_loader/index.d.ts +5 -2
- package/dist/dataset/data_loader/index.js +4 -7
- package/dist/dataset/data_loader/tabular_loader.d.ts +34 -3
- package/dist/dataset/data_loader/tabular_loader.js +75 -15
- package/dist/dataset/data_loader/text_loader.d.ts +14 -0
- package/dist/dataset/data_loader/text_loader.js +25 -0
- package/dist/dataset/dataset.d.ts +5 -0
- package/dist/dataset/dataset.js +1 -0
- package/dist/dataset/dataset_builder.d.ts +60 -0
- package/dist/dataset/dataset_builder.js +142 -0
- package/dist/dataset/index.d.ts +5 -0
- package/dist/dataset/index.js +3 -0
- package/dist/default_tasks/cifar10/index.d.ts +2 -0
- package/dist/default_tasks/cifar10/index.js +60 -0
- package/dist/default_tasks/cifar10/model.d.ts +434 -0
- package/dist/default_tasks/cifar10/model.js +2385 -0
- package/dist/default_tasks/geotags/index.d.ts +2 -0
- package/dist/default_tasks/geotags/index.js +65 -0
- package/dist/default_tasks/geotags/model.d.ts +593 -0
- package/dist/default_tasks/geotags/model.js +4715 -0
- package/dist/default_tasks/index.d.ts +8 -0
- package/dist/default_tasks/index.js +8 -0
- package/dist/default_tasks/lus_covid.d.ts +2 -0
- package/dist/default_tasks/lus_covid.js +89 -0
- package/dist/default_tasks/mnist.d.ts +2 -0
- package/dist/default_tasks/mnist.js +61 -0
- package/dist/default_tasks/simple_face/index.d.ts +2 -0
- package/dist/default_tasks/simple_face/index.js +48 -0
- package/dist/default_tasks/simple_face/model.d.ts +513 -0
- package/dist/default_tasks/simple_face/model.js +4301 -0
- package/dist/default_tasks/skin_mnist.d.ts +2 -0
- package/dist/default_tasks/skin_mnist.js +80 -0
- package/dist/default_tasks/titanic.d.ts +2 -0
- package/dist/default_tasks/titanic.js +88 -0
- package/dist/default_tasks/wikitext.d.ts +2 -0
- package/dist/default_tasks/wikitext.js +38 -0
- package/dist/index.d.ts +18 -2
- package/dist/index.js +18 -6
- package/dist/{core/informant → informant}/graph_informant.d.ts +1 -1
- package/dist/informant/graph_informant.js +20 -0
- package/dist/informant/index.d.ts +1 -0
- package/dist/informant/index.js +1 -0
- package/dist/{core/logging → logging}/console_logger.d.ts +2 -2
- package/dist/logging/console_logger.js +22 -0
- package/dist/logging/index.d.ts +2 -0
- package/dist/logging/index.js +1 -0
- package/dist/{core/logging → logging}/logger.d.ts +3 -3
- package/dist/logging/logger.js +1 -0
- package/dist/memory/base.d.ts +119 -0
- package/dist/memory/base.js +9 -0
- package/dist/memory/empty.d.ts +20 -0
- package/dist/memory/empty.js +43 -0
- package/dist/memory/index.d.ts +3 -1
- package/dist/memory/index.js +3 -5
- package/dist/memory/model_type.d.ts +9 -0
- package/dist/memory/model_type.js +10 -0
- package/dist/{core/privacy.d.ts → privacy.d.ts} +1 -1
- package/dist/{core/privacy.js → privacy.js} +11 -16
- package/dist/serialization/index.d.ts +2 -0
- package/dist/serialization/index.js +2 -0
- package/dist/serialization/model.d.ts +5 -0
- package/dist/serialization/model.js +67 -0
- package/dist/{core/serialization → serialization}/weights.d.ts +2 -2
- package/dist/serialization/weights.js +37 -0
- package/dist/task/data_example.js +14 -0
- package/dist/task/digest.d.ts +5 -0
- package/dist/task/digest.js +14 -0
- package/dist/{core/task → task}/display_information.d.ts +5 -3
- package/dist/task/display_information.js +46 -0
- package/dist/task/index.d.ts +7 -0
- package/dist/task/index.js +5 -0
- package/dist/task/label_type.d.ts +9 -0
- package/dist/task/label_type.js +28 -0
- package/dist/task/summary.js +13 -0
- package/dist/task/task.d.ts +12 -0
- package/dist/task/task.js +22 -0
- package/dist/task/task_handler.d.ts +5 -0
- package/dist/task/task_handler.js +20 -0
- package/dist/task/task_provider.d.ts +5 -0
- package/dist/task/task_provider.js +1 -0
- package/dist/{core/task → task}/training_information.d.ts +9 -10
- package/dist/task/training_information.js +88 -0
- package/dist/training/disco.d.ts +40 -0
- package/dist/training/disco.js +107 -0
- package/dist/training/index.d.ts +2 -0
- package/dist/training/index.js +1 -0
- package/dist/training/trainer/distributed_trainer.d.ts +20 -0
- package/dist/training/trainer/distributed_trainer.js +36 -0
- package/dist/training/trainer/local_trainer.d.ts +12 -0
- package/dist/training/trainer/local_trainer.js +19 -0
- package/dist/training/trainer/trainer.d.ts +33 -0
- package/dist/training/trainer/trainer.js +52 -0
- package/dist/{core/training → training}/trainer/trainer_builder.d.ts +5 -7
- package/dist/training/trainer/trainer_builder.js +43 -0
- package/dist/types.d.ts +8 -0
- package/dist/types.js +1 -0
- package/dist/utils/event_emitter.d.ts +40 -0
- package/dist/utils/event_emitter.js +57 -0
- package/dist/validation/index.d.ts +1 -0
- package/dist/validation/index.js +1 -0
- package/dist/validation/validator.d.ts +28 -0
- package/dist/validation/validator.js +132 -0
- package/dist/weights/aggregation.d.ts +21 -0
- package/dist/weights/aggregation.js +44 -0
- package/dist/weights/index.d.ts +2 -0
- package/dist/weights/index.js +2 -0
- package/dist/weights/weights_container.d.ts +68 -0
- package/dist/weights/weights_container.js +96 -0
- package/package.json +25 -16
- package/README.md +0 -53
- package/dist/core/async_buffer.d.ts +0 -41
- package/dist/core/async_buffer.js +0 -97
- package/dist/core/async_informant.d.ts +0 -20
- package/dist/core/async_informant.js +0 -69
- package/dist/core/client/base.d.ts +0 -33
- package/dist/core/client/base.js +0 -35
- package/dist/core/client/decentralized/base.d.ts +0 -32
- package/dist/core/client/decentralized/base.js +0 -212
- package/dist/core/client/decentralized/clear_text.d.ts +0 -14
- package/dist/core/client/decentralized/clear_text.js +0 -96
- package/dist/core/client/decentralized/index.d.ts +0 -4
- package/dist/core/client/decentralized/index.js +0 -9
- package/dist/core/client/decentralized/messages.d.ts +0 -41
- package/dist/core/client/decentralized/messages.js +0 -54
- package/dist/core/client/decentralized/peer.d.ts +0 -26
- package/dist/core/client/decentralized/peer.js +0 -210
- package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
- package/dist/core/client/decentralized/peer_pool.js +0 -92
- package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
- package/dist/core/client/decentralized/sec_agg.js +0 -190
- package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
- package/dist/core/client/decentralized/secret_shares.js +0 -39
- package/dist/core/client/decentralized/types.d.ts +0 -2
- package/dist/core/client/decentralized/types.js +0 -7
- package/dist/core/client/event_connection.d.ts +0 -37
- package/dist/core/client/event_connection.js +0 -158
- package/dist/core/client/federated/client.d.ts +0 -37
- package/dist/core/client/federated/client.js +0 -273
- package/dist/core/client/federated/index.d.ts +0 -2
- package/dist/core/client/federated/index.js +0 -7
- package/dist/core/client/federated/messages.d.ts +0 -38
- package/dist/core/client/federated/messages.js +0 -25
- package/dist/core/client/index.d.ts +0 -5
- package/dist/core/client/index.js +0 -11
- package/dist/core/client/local.d.ts +0 -8
- package/dist/core/client/local.js +0 -36
- package/dist/core/client/messages.d.ts +0 -28
- package/dist/core/client/messages.js +0 -33
- package/dist/core/client/utils.d.ts +0 -2
- package/dist/core/client/utils.js +0 -19
- package/dist/core/dataset/data/data.d.ts +0 -11
- package/dist/core/dataset/data/data.js +0 -20
- package/dist/core/dataset/data/data_split.d.ts +0 -5
- package/dist/core/dataset/data/data_split.js +0 -2
- package/dist/core/dataset/data/image_data.d.ts +0 -8
- package/dist/core/dataset/data/image_data.js +0 -64
- package/dist/core/dataset/data/index.d.ts +0 -5
- package/dist/core/dataset/data/index.js +0 -11
- package/dist/core/dataset/data/preprocessing.d.ts +0 -13
- package/dist/core/dataset/data/preprocessing.js +0 -33
- package/dist/core/dataset/data/tabular_data.d.ts +0 -8
- package/dist/core/dataset/data/tabular_data.js +0 -40
- package/dist/core/dataset/data_loader/data_loader.js +0 -10
- package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
- package/dist/core/dataset/data_loader/image_loader.js +0 -141
- package/dist/core/dataset/data_loader/index.d.ts +0 -3
- package/dist/core/dataset/data_loader/index.js +0 -9
- package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
- package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
- package/dist/core/dataset/dataset.d.ts +0 -2
- package/dist/core/dataset/dataset.js +0 -2
- package/dist/core/dataset/dataset_builder.d.ts +0 -18
- package/dist/core/dataset/dataset_builder.js +0 -96
- package/dist/core/dataset/index.d.ts +0 -4
- package/dist/core/dataset/index.js +0 -14
- package/dist/core/index.d.ts +0 -18
- package/dist/core/index.js +0 -41
- package/dist/core/informant/graph_informant.js +0 -23
- package/dist/core/informant/index.d.ts +0 -3
- package/dist/core/informant/index.js +0 -9
- package/dist/core/informant/training_informant/base.d.ts +0 -31
- package/dist/core/informant/training_informant/base.js +0 -83
- package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
- package/dist/core/informant/training_informant/decentralized.js +0 -22
- package/dist/core/informant/training_informant/federated.d.ts +0 -14
- package/dist/core/informant/training_informant/federated.js +0 -32
- package/dist/core/informant/training_informant/index.d.ts +0 -4
- package/dist/core/informant/training_informant/index.js +0 -11
- package/dist/core/informant/training_informant/local.d.ts +0 -6
- package/dist/core/informant/training_informant/local.js +0 -20
- package/dist/core/logging/console_logger.js +0 -33
- package/dist/core/logging/index.d.ts +0 -3
- package/dist/core/logging/index.js +0 -9
- package/dist/core/logging/logger.js +0 -9
- package/dist/core/logging/trainer_logger.d.ts +0 -24
- package/dist/core/logging/trainer_logger.js +0 -59
- package/dist/core/memory/base.d.ts +0 -22
- package/dist/core/memory/base.js +0 -9
- package/dist/core/memory/empty.d.ts +0 -14
- package/dist/core/memory/empty.js +0 -75
- package/dist/core/memory/index.d.ts +0 -3
- package/dist/core/memory/index.js +0 -9
- package/dist/core/memory/model_type.d.ts +0 -4
- package/dist/core/memory/model_type.js +0 -9
- package/dist/core/serialization/index.d.ts +0 -2
- package/dist/core/serialization/index.js +0 -6
- package/dist/core/serialization/model.d.ts +0 -5
- package/dist/core/serialization/model.js +0 -55
- package/dist/core/serialization/weights.js +0 -64
- package/dist/core/task/data_example.js +0 -24
- package/dist/core/task/display_information.js +0 -49
- package/dist/core/task/index.d.ts +0 -3
- package/dist/core/task/index.js +0 -8
- package/dist/core/task/model_compile_data.d.ts +0 -6
- package/dist/core/task/model_compile_data.js +0 -22
- package/dist/core/task/summary.js +0 -19
- package/dist/core/task/task.d.ts +0 -10
- package/dist/core/task/task.js +0 -31
- package/dist/core/task/training_information.js +0 -66
- package/dist/core/tasks/cifar10.d.ts +0 -3
- package/dist/core/tasks/cifar10.js +0 -65
- package/dist/core/tasks/geotags.d.ts +0 -3
- package/dist/core/tasks/geotags.js +0 -67
- package/dist/core/tasks/index.d.ts +0 -6
- package/dist/core/tasks/index.js +0 -10
- package/dist/core/tasks/lus_covid.d.ts +0 -3
- package/dist/core/tasks/lus_covid.js +0 -87
- package/dist/core/tasks/mnist.d.ts +0 -3
- package/dist/core/tasks/mnist.js +0 -60
- package/dist/core/tasks/simple_face.d.ts +0 -2
- package/dist/core/tasks/simple_face.js +0 -41
- package/dist/core/tasks/titanic.d.ts +0 -3
- package/dist/core/tasks/titanic.js +0 -88
- package/dist/core/training/disco.d.ts +0 -23
- package/dist/core/training/disco.js +0 -130
- package/dist/core/training/index.d.ts +0 -2
- package/dist/core/training/index.js +0 -7
- package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
- package/dist/core/training/trainer/distributed_trainer.js +0 -65
- package/dist/core/training/trainer/local_trainer.d.ts +0 -11
- package/dist/core/training/trainer/local_trainer.js +0 -34
- package/dist/core/training/trainer/round_tracker.d.ts +0 -30
- package/dist/core/training/trainer/round_tracker.js +0 -47
- package/dist/core/training/trainer/trainer.d.ts +0 -65
- package/dist/core/training/trainer/trainer.js +0 -160
- package/dist/core/training/trainer/trainer_builder.js +0 -95
- package/dist/core/training/training_schemes.d.ts +0 -5
- package/dist/core/training/training_schemes.js +0 -10
- package/dist/core/types.d.ts +0 -4
- package/dist/core/types.js +0 -2
- package/dist/core/validation/index.d.ts +0 -1
- package/dist/core/validation/index.js +0 -5
- package/dist/core/validation/validator.d.ts +0 -17
- package/dist/core/validation/validator.js +0 -104
- package/dist/core/weights/aggregation.d.ts +0 -8
- package/dist/core/weights/aggregation.js +0 -96
- package/dist/core/weights/index.d.ts +0 -2
- package/dist/core/weights/index.js +0 -7
- package/dist/core/weights/weights_container.d.ts +0 -19
- package/dist/core/weights/weights_container.js +0 -64
- package/dist/imports.d.ts +0 -2
- package/dist/imports.js +0 -7
- package/dist/memory/memory.d.ts +0 -26
- package/dist/memory/memory.js +0 -160
- package/dist/{core/task → task}/data_example.d.ts +1 -1
- package/dist/{core/task → task}/summary.d.ts +1 -1
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
export { cifar10 } from './cifar10/index.js';
|
|
2
|
+
export { geotags } from './geotags/index.js';
|
|
3
|
+
export { lusCovid } from './lus_covid.js';
|
|
4
|
+
export { mnist } from './mnist.js';
|
|
5
|
+
export { simpleFace } from './simple_face/index.js';
|
|
6
|
+
export { skinMnist } from './skin_mnist.js';
|
|
7
|
+
export { titanic } from './titanic.js';
|
|
8
|
+
export { wikitext } from './wikitext.js';
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
export { cifar10 } from './cifar10/index.js';
|
|
2
|
+
export { geotags } from './geotags/index.js';
|
|
3
|
+
export { lusCovid } from './lus_covid.js';
|
|
4
|
+
export { mnist } from './mnist.js';
|
|
5
|
+
export { simpleFace } from './simple_face/index.js';
|
|
6
|
+
export { skinMnist } from './skin_mnist.js';
|
|
7
|
+
export { titanic } from './titanic.js';
|
|
8
|
+
export { wikitext } from './wikitext.js';
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
import { data, models } from '../index.js';
|
|
3
|
+
/**
 * Default task provider for COVID lung-ultrasound classification.
 *
 * `getTask()` returns a fresh task descriptor (display texts and training
 * configuration). `getModel()` builds and compiles a small CNN whose input
 * shape and number of output units are derived from that same training
 * configuration, so the model always agrees with the task definition.
 */
export const lusCovid = {
    /**
     * @returns {object} a newly built task descriptor for the `lus_covid` task
     */
    getTask() {
        return {
            id: 'lus_covid',
            displayInformation: {
                taskTitle: 'COVID Lung Ultrasound',
                summary: {
                    preview: 'Do you have a data of lung ultrasound images on patients <b>suspected of Lower Respiratory Tract infection (LRTI) during the COVID pandemic</b>? <br> Learn how to discriminate between COVID positive and negative patients by joining this task.',
                    overview: "Don’t have a dataset of your own? Download a sample of a few cases <a class='underline' href='https://drive.switch.ch/index.php/s/zM5ZrUWK3taaIly' target='_blank'>here</a>."
                },
                // NOTE(review): this text describes a DeepChest/MobileNet model, while
                // getModel() below builds a small two-block CNN — confirm the copy is intended.
                model: "We use a simplified* version of the <b>DeepChest model</b>: A deep learning model developed in our lab (<a class='underline' href='https://www.epfl.ch/labs/mlo/igh-intelligent-global-health/'>intelligent Global Health</a>.). On a cohort of 400 Swiss patients suspected of LRTI, the model obtained over 90% area under the ROC curve for this task. <br><br>*Simplified to ensure smooth running on your browser, the performance is minimally affected. Details of the adaptations are below <br>- <b>Removed</b>: positional embedding (i.e. we don’t take the anatomic position into consideration). Rather, the model now does mean pooling over the feature vector of the images for each patient <br>- <b>Replaced</b>: ResNet18 by Mobilenet",
                tradeoffs: 'We are using a simpler version of DeepChest in order to be able to run it on the browser.',
                dataFormatInformation: 'This model takes as input an image dataset. It consists of a set of lung ultrasound images per patient with its corresponding label of covid positive or negative. Moreover, to identify the images per patient you have to follow the following naming pattern: "patientId_*.png"',
                dataExampleText: 'Below you can find an example of an expected lung image for patient 2 named: 2_QAID_1.masked.reshaped.squared.224.png',
                dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/2_QAID_1.masked.reshaped.squared.224.png'
            },
            trainingInformation: {
                modelID: 'lus-covid-model',
                epochs: 50,
                roundDuration: 2,
                validationSplit: 0,
                batchSize: 5,
                IMAGE_H: 100,
                IMAGE_W: 100,
                preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
                LABEL_LIST: ['COVID-Positive', 'COVID-Negative'],
                dataType: 'image',
                scheme: 'federated',
                noiseScale: undefined,
                clippingRadius: 20,
                decentralizedSecure: true,
                minimumReadyPeers: 2,
                maxShareValue: 100
            }
        };
    },
    // Model architecture from tensorflow.js docs:
    // https://codelabs.developers.google.com/codelabs/tfjs-training-classfication/index.html#4
    /**
     * Builds and compiles the classification CNN for this task.
     *
     * @returns {Promise<models.TFJS>} the compiled model wrapped for discojs
     */
    async getModel() {
        // Take the image size and class count from the task definition instead of
        // repeating literals. This keeps the input layer in sync with the Resize
        // preprocessing (IMAGE_H x IMAGE_W) and the output layer in sync with
        // LABEL_LIST. `lusCovid` is referenced by name rather than via `this`,
        // so the method also works when called unbound.
        const { IMAGE_H, IMAGE_W, LABEL_LIST } = lusCovid.getTask().trainingInformation;
        const imageChannels = 3;
        const numOutputClasses = LABEL_LIST.length;
        const model = tf.sequential();
        // In the first layer of our convolutional neural network we have
        // to specify the input shape. Then we specify some parameters for
        // the convolution operation that takes place in this layer.
        model.add(tf.layers.conv2d({
            inputShape: [IMAGE_H, IMAGE_W, imageChannels],
            kernelSize: 5,
            filters: 8,
            strides: 1,
            activation: 'relu',
            kernelInitializer: 'varianceScaling'
        }));
        // The MaxPooling layer acts as a sort of downsampling using max values
        // in a region instead of averaging.
        model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));
        // Repeat the conv2d + maxPooling block.
        // Note that we have more filters in the convolution.
        model.add(tf.layers.conv2d({
            kernelSize: 5,
            filters: 16,
            strides: 1,
            activation: 'relu',
            kernelInitializer: 'varianceScaling'
        }));
        model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));
        // Now we flatten the output from the 2D filters into a 1D vector to prepare
        // it for input into our last layer. This is common practice when feeding
        // higher dimensional data to a final classification output layer.
        model.add(tf.layers.flatten());
        // Our last layer is a dense layer with one output unit per label.
        model.add(tf.layers.dense({
            units: numOutputClasses,
            kernelInitializer: 'varianceScaling',
            activation: 'softmax'
        }));
        model.compile({
            optimizer: 'sgd',
            loss: 'binaryCrossentropy',
            metrics: ['accuracy']
        });
        // Inside an async function the value is wrapped in a Promise automatically.
        return new models.TFJS(model);
    }
};
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
import { models } from '../index.js';
|
|
3
|
+
/**
 * Default task provider for MNIST hand-written digit classification.
 *
 * `getTask()` returns a fresh task descriptor (display texts and training
 * configuration). `getModel()` builds and compiles a small CNN for 28x28x3
 * inputs with one softmax output per digit label.
 */
export const mnist = {
    /**
     * @returns {object} a newly built task descriptor for the `mnist` task
     */
    getTask() {
        return {
            id: 'mnist',
            displayInformation: {
                taskTitle: 'MNIST',
                summary: {
                    preview: "Test our platform by using a publicly available <b>image</b> dataset. <br><br> Download the classic MNIST imagebank of hand-written numbers <a class='underline text-primary-dark dark:text-primary-light' href='https://www.kaggle.com/scolianni/mnistasjpg'>here</a>. <br> This model learns to identify hand written numbers.",
                    overview: 'The MNIST handwritten digit classification problem is a standard dataset used in computer vision and deep learning. Although the dataset is effectively solved, we use it to test our Decentralised Learning algorithms and platform.'
                },
                model: 'The current model is a very simple CNN and its main goal is to test the app and the Decentralized Learning functionality.',
                tradeoffs: 'We are using a simple model, first a 2d convolutional layer > max pooling > 2d convolutional layer > max pooling > convolutional layer > 2 dense layers.',
                dataFormatInformation: 'This model is trained on images corresponding to digits 0 to 9. You can upload each digit image of your dataset in the box corresponding to its label. The model takes images of size 28x28 as input.',
                dataExampleText: 'Below you can find an example of an expected image representing the digit 9.',
                // Served over HTTPS like the other default tasks' example images.
                // A plain-http URL would be mixed content on an https page and be
                // blocked or upgraded by the browser.
                dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/9-mnist-example.png'
            },
            trainingInformation: {
                modelID: 'mnist-model',
                epochs: 10,
                roundDuration: 10,
                validationSplit: 0.2,
                batchSize: 30,
                dataType: 'image',
                IMAGE_H: 28,
                IMAGE_W: 28,
                preprocessingFunctions: [],
                LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
                scheme: 'decentralized',
                noiseScale: undefined,
                clippingRadius: 20,
                decentralizedSecure: true,
                minimumReadyPeers: 3,
                maxShareValue: 100
            }
        };
    },
    /**
     * Builds and compiles the digit-classification CNN.
     *
     * @returns {Promise<models.TFJS>} the compiled model wrapped for discojs
     */
    getModel() {
        const model = tf.sequential();
        model.add(tf.layers.conv2d({
            inputShape: [28, 28, 3],
            kernelSize: 3,
            filters: 16,
            activation: 'relu'
        }));
        model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
        model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
        model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
        model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
        model.add(tf.layers.flatten({}));
        model.add(tf.layers.dense({ units: 64, activation: 'relu' }));
        // One softmax unit per entry of LABEL_LIST (digits 0-9).
        model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));
        model.compile({
            optimizer: 'rmsprop',
            loss: 'categoricalCrossentropy',
            metrics: ['accuracy']
        });
        return Promise.resolve(new models.TFJS(model));
    }
};
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
import { data, models } from '../../index.js';
|
|
3
|
+
import baseModel from './model.js';
|
|
4
|
+
/**
 * Default task provider for the "Simple Face" child-vs-adult image task.
 *
 * `getTask()` assembles the descriptor from a display section and a training
 * section. `getModel()` restores the bundled pretrained topology/weights
 * (`baseModel`) and compiles it with plain SGD.
 */
export const simpleFace = {
    /**
     * @returns {object} a newly built task descriptor for the `simple_face` task
     */
    getTask() {
        const summary = {
            preview: 'Can you detect if the person in a picture is a child or an adult?',
            overview: 'Simple face is a small subset of face_task from Kaggle'
        };
        const displayInformation = {
            taskTitle: 'Simple Face',
            summary,
            limitations: 'The training data is limited to small images of size 200x200.',
            tradeoffs: 'Training success strongly depends on label distribution',
            dataFormatInformation: '',
            dataExampleText: 'Below you find an example',
            dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/simple_face-example.png'
        };
        const trainingInformation = {
            modelID: 'simple_face-model',
            epochs: 50,
            roundDuration: 1,
            validationSplit: 0.2,
            batchSize: 10,
            preprocessingFunctions: [data.ImagePreprocessing.Normalize],
            dataType: 'image',
            IMAGE_H: 200,
            IMAGE_W: 200,
            LABEL_LIST: ['child', 'adult'],
            // secure aggregation not yet implemented for federated
            scheme: 'federated',
            noiseScale: undefined,
            clippingRadius: undefined
        };
        return { id: 'simple_face', displayInformation, trainingInformation };
    },
    /**
     * Loads the bundled base model through an in-memory IO handler and
     * compiles it for training.
     *
     * @returns {Promise<models.TFJS>} the compiled model wrapped for discojs
     */
    async getModel() {
        // In-memory handler: hands the bundled model artifacts straight to tfjs.
        const ioHandler = { load: async () => baseModel };
        const layersModel = await tf.loadLayersModel(ioHandler);
        const compileArgs = {
            optimizer: tf.train.sgd(0.001),
            loss: 'categoricalCrossentropy',
            metrics: ['accuracy']
        };
        layersModel.compile(compileArgs);
        return new models.TFJS(layersModel);
    }
};
|