@epfml/discojs 2.1.1 → 2.1.2-p20240506085037.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.d.ts +180 -0
- package/dist/aggregator/base.js +236 -0
- package/dist/aggregator/get.d.ts +16 -0
- package/dist/aggregator/get.js +31 -0
- package/dist/aggregator/index.d.ts +7 -0
- package/dist/aggregator/index.js +4 -0
- package/dist/aggregator/mean.d.ts +23 -0
- package/dist/aggregator/mean.js +69 -0
- package/dist/aggregator/secure.d.ts +27 -0
- package/dist/aggregator/secure.js +91 -0
- package/dist/async_informant.d.ts +15 -0
- package/dist/async_informant.js +42 -0
- package/dist/client/base.d.ts +76 -0
- package/dist/client/base.js +88 -0
- package/dist/client/decentralized/base.d.ts +32 -0
- package/dist/client/decentralized/base.js +192 -0
- package/dist/client/decentralized/index.d.ts +2 -0
- package/dist/client/decentralized/index.js +2 -0
- package/dist/client/decentralized/messages.d.ts +28 -0
- package/dist/client/decentralized/messages.js +44 -0
- package/dist/client/decentralized/peer.d.ts +40 -0
- package/dist/client/decentralized/peer.js +189 -0
- package/dist/client/decentralized/peer_pool.d.ts +12 -0
- package/dist/client/decentralized/peer_pool.js +44 -0
- package/dist/client/event_connection.d.ts +34 -0
- package/dist/client/event_connection.js +105 -0
- package/dist/client/federated/base.d.ts +54 -0
- package/dist/client/federated/base.js +151 -0
- package/dist/client/federated/index.d.ts +2 -0
- package/dist/client/federated/index.js +2 -0
- package/dist/client/federated/messages.d.ts +30 -0
- package/dist/client/federated/messages.js +24 -0
- package/dist/client/index.d.ts +8 -0
- package/dist/client/index.js +8 -0
- package/dist/client/local.d.ts +3 -0
- package/dist/client/local.js +3 -0
- package/dist/client/messages.d.ts +30 -0
- package/dist/client/messages.js +26 -0
- package/dist/client/types.d.ts +2 -0
- package/dist/client/types.js +4 -0
- package/dist/client/utils.d.ts +2 -0
- package/dist/client/utils.js +7 -0
- package/dist/dataset/data/data.d.ts +48 -0
- package/dist/dataset/data/data.js +72 -0
- package/dist/dataset/data/data_split.d.ts +8 -0
- package/dist/dataset/data/data_split.js +1 -0
- package/dist/dataset/data/image_data.d.ts +11 -0
- package/dist/dataset/data/image_data.js +38 -0
- package/dist/dataset/data/index.d.ts +6 -0
- package/dist/dataset/data/index.js +5 -0
- package/dist/dataset/data/preprocessing/base.d.ts +16 -0
- package/dist/dataset/data/preprocessing/base.js +1 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.js +40 -0
- package/dist/dataset/data/preprocessing/index.d.ts +4 -0
- package/dist/dataset/data/preprocessing/index.js +3 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.js +45 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.js +85 -0
- package/dist/dataset/data/tabular_data.d.ts +11 -0
- package/dist/dataset/data/tabular_data.js +25 -0
- package/dist/dataset/data/text_data.d.ts +11 -0
- package/dist/dataset/data/text_data.js +14 -0
- package/dist/{core/dataset → dataset}/data_loader/data_loader.d.ts +3 -5
- package/dist/dataset/data_loader/data_loader.js +2 -0
- package/dist/dataset/data_loader/image_loader.d.ts +20 -3
- package/dist/dataset/data_loader/image_loader.js +98 -23
- package/dist/dataset/data_loader/index.d.ts +5 -2
- package/dist/dataset/data_loader/index.js +4 -7
- package/dist/dataset/data_loader/tabular_loader.d.ts +34 -3
- package/dist/dataset/data_loader/tabular_loader.js +75 -15
- package/dist/dataset/data_loader/text_loader.d.ts +14 -0
- package/dist/dataset/data_loader/text_loader.js +25 -0
- package/dist/dataset/dataset.d.ts +5 -0
- package/dist/dataset/dataset.js +1 -0
- package/dist/dataset/dataset_builder.d.ts +60 -0
- package/dist/dataset/dataset_builder.js +142 -0
- package/dist/dataset/index.d.ts +5 -0
- package/dist/dataset/index.js +3 -0
- package/dist/default_tasks/cifar10/index.d.ts +2 -0
- package/dist/{core/default_tasks/cifar10.js → default_tasks/cifar10/index.js} +28 -36
- package/dist/default_tasks/cifar10/model.d.ts +434 -0
- package/dist/default_tasks/cifar10/model.js +2385 -0
- package/dist/default_tasks/geotags/index.d.ts +2 -0
- package/dist/default_tasks/geotags/index.js +65 -0
- package/dist/default_tasks/geotags/model.d.ts +593 -0
- package/dist/default_tasks/geotags/model.js +4715 -0
- package/dist/default_tasks/index.d.ts +8 -0
- package/dist/default_tasks/index.js +8 -0
- package/dist/default_tasks/lus_covid.d.ts +2 -0
- package/dist/default_tasks/lus_covid.js +89 -0
- package/dist/default_tasks/mnist.d.ts +2 -0
- package/dist/{core/default_tasks → default_tasks}/mnist.js +26 -34
- package/dist/default_tasks/simple_face/index.d.ts +2 -0
- package/dist/{core/default_tasks/simple_face.js → default_tasks/simple_face/index.js} +17 -22
- package/dist/default_tasks/simple_face/model.d.ts +513 -0
- package/dist/default_tasks/simple_face/model.js +4301 -0
- package/dist/default_tasks/skin_mnist.d.ts +2 -0
- package/dist/default_tasks/skin_mnist.js +80 -0
- package/dist/default_tasks/titanic.d.ts +2 -0
- package/dist/{core/default_tasks → default_tasks}/titanic.js +24 -33
- package/dist/default_tasks/wikitext.d.ts +2 -0
- package/dist/default_tasks/wikitext.js +38 -0
- package/dist/index.d.ts +18 -2
- package/dist/index.js +18 -6
- package/dist/{core/informant → informant}/graph_informant.d.ts +1 -1
- package/dist/informant/graph_informant.js +20 -0
- package/dist/informant/index.d.ts +1 -0
- package/dist/informant/index.js +1 -0
- package/dist/{core/logging → logging}/console_logger.d.ts +2 -2
- package/dist/logging/console_logger.js +22 -0
- package/dist/logging/index.d.ts +2 -0
- package/dist/logging/index.js +1 -0
- package/dist/{core/logging → logging}/logger.d.ts +3 -3
- package/dist/logging/logger.js +1 -0
- package/dist/memory/base.d.ts +119 -0
- package/dist/memory/base.js +9 -0
- package/dist/memory/empty.d.ts +20 -0
- package/dist/memory/empty.js +43 -0
- package/dist/memory/index.d.ts +3 -1
- package/dist/memory/index.js +3 -5
- package/dist/memory/model_type.d.ts +9 -0
- package/dist/memory/model_type.js +10 -0
- package/dist/{core/privacy.d.ts → privacy.d.ts} +1 -1
- package/dist/{core/privacy.js → privacy.js} +11 -16
- package/dist/serialization/index.d.ts +2 -0
- package/dist/serialization/index.js +2 -0
- package/dist/serialization/model.d.ts +5 -0
- package/dist/serialization/model.js +67 -0
- package/dist/{core/serialization → serialization}/weights.d.ts +2 -2
- package/dist/serialization/weights.js +37 -0
- package/dist/task/data_example.js +14 -0
- package/dist/task/digest.js +14 -0
- package/dist/{core/task → task}/display_information.d.ts +5 -3
- package/dist/task/display_information.js +46 -0
- package/dist/task/index.d.ts +7 -0
- package/dist/task/index.js +5 -0
- package/dist/task/label_type.d.ts +9 -0
- package/dist/task/label_type.js +28 -0
- package/dist/task/summary.js +13 -0
- package/dist/{core/task → task}/task.d.ts +7 -7
- package/dist/task/task.js +22 -0
- package/dist/task/task_handler.d.ts +5 -0
- package/dist/task/task_handler.js +20 -0
- package/dist/task/task_provider.d.ts +5 -0
- package/dist/task/task_provider.js +1 -0
- package/dist/{core/task → task}/training_information.d.ts +9 -10
- package/dist/task/training_information.js +88 -0
- package/dist/training/disco.d.ts +40 -0
- package/dist/training/disco.js +107 -0
- package/dist/training/index.d.ts +2 -0
- package/dist/training/index.js +1 -0
- package/dist/training/trainer/distributed_trainer.d.ts +20 -0
- package/dist/training/trainer/distributed_trainer.js +36 -0
- package/dist/training/trainer/local_trainer.d.ts +12 -0
- package/dist/training/trainer/local_trainer.js +19 -0
- package/dist/training/trainer/trainer.d.ts +33 -0
- package/dist/training/trainer/trainer.js +52 -0
- package/dist/{core/training → training}/trainer/trainer_builder.d.ts +5 -7
- package/dist/training/trainer/trainer_builder.js +43 -0
- package/dist/types.d.ts +8 -0
- package/dist/types.js +1 -0
- package/dist/utils/event_emitter.d.ts +40 -0
- package/dist/utils/event_emitter.js +57 -0
- package/dist/validation/index.d.ts +1 -0
- package/dist/validation/index.js +1 -0
- package/dist/validation/validator.d.ts +28 -0
- package/dist/validation/validator.js +132 -0
- package/dist/weights/aggregation.d.ts +21 -0
- package/dist/weights/aggregation.js +44 -0
- package/dist/weights/index.d.ts +2 -0
- package/dist/weights/index.js +2 -0
- package/dist/weights/weights_container.d.ts +68 -0
- package/dist/weights/weights_container.js +96 -0
- package/package.json +24 -15
- package/README.md +0 -53
- package/dist/core/async_buffer.d.ts +0 -41
- package/dist/core/async_buffer.js +0 -97
- package/dist/core/async_informant.d.ts +0 -20
- package/dist/core/async_informant.js +0 -69
- package/dist/core/client/base.d.ts +0 -33
- package/dist/core/client/base.js +0 -35
- package/dist/core/client/decentralized/base.d.ts +0 -32
- package/dist/core/client/decentralized/base.js +0 -212
- package/dist/core/client/decentralized/clear_text.d.ts +0 -14
- package/dist/core/client/decentralized/clear_text.js +0 -96
- package/dist/core/client/decentralized/index.d.ts +0 -4
- package/dist/core/client/decentralized/index.js +0 -9
- package/dist/core/client/decentralized/messages.d.ts +0 -41
- package/dist/core/client/decentralized/messages.js +0 -54
- package/dist/core/client/decentralized/peer.d.ts +0 -26
- package/dist/core/client/decentralized/peer.js +0 -210
- package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
- package/dist/core/client/decentralized/peer_pool.js +0 -92
- package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
- package/dist/core/client/decentralized/sec_agg.js +0 -190
- package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
- package/dist/core/client/decentralized/secret_shares.js +0 -39
- package/dist/core/client/decentralized/types.d.ts +0 -2
- package/dist/core/client/decentralized/types.js +0 -7
- package/dist/core/client/event_connection.d.ts +0 -37
- package/dist/core/client/event_connection.js +0 -158
- package/dist/core/client/federated/client.d.ts +0 -37
- package/dist/core/client/federated/client.js +0 -273
- package/dist/core/client/federated/index.d.ts +0 -2
- package/dist/core/client/federated/index.js +0 -7
- package/dist/core/client/federated/messages.d.ts +0 -38
- package/dist/core/client/federated/messages.js +0 -25
- package/dist/core/client/index.d.ts +0 -5
- package/dist/core/client/index.js +0 -11
- package/dist/core/client/local.d.ts +0 -8
- package/dist/core/client/local.js +0 -36
- package/dist/core/client/messages.d.ts +0 -28
- package/dist/core/client/messages.js +0 -33
- package/dist/core/client/utils.d.ts +0 -2
- package/dist/core/client/utils.js +0 -19
- package/dist/core/dataset/data/data.d.ts +0 -11
- package/dist/core/dataset/data/data.js +0 -20
- package/dist/core/dataset/data/data_split.d.ts +0 -5
- package/dist/core/dataset/data/data_split.js +0 -2
- package/dist/core/dataset/data/image_data.d.ts +0 -8
- package/dist/core/dataset/data/image_data.js +0 -64
- package/dist/core/dataset/data/index.d.ts +0 -5
- package/dist/core/dataset/data/index.js +0 -11
- package/dist/core/dataset/data/preprocessing.d.ts +0 -13
- package/dist/core/dataset/data/preprocessing.js +0 -33
- package/dist/core/dataset/data/tabular_data.d.ts +0 -8
- package/dist/core/dataset/data/tabular_data.js +0 -40
- package/dist/core/dataset/data_loader/data_loader.js +0 -10
- package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
- package/dist/core/dataset/data_loader/image_loader.js +0 -141
- package/dist/core/dataset/data_loader/index.d.ts +0 -3
- package/dist/core/dataset/data_loader/index.js +0 -9
- package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
- package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
- package/dist/core/dataset/dataset.d.ts +0 -2
- package/dist/core/dataset/dataset.js +0 -2
- package/dist/core/dataset/dataset_builder.d.ts +0 -18
- package/dist/core/dataset/dataset_builder.js +0 -96
- package/dist/core/dataset/index.d.ts +0 -4
- package/dist/core/dataset/index.js +0 -14
- package/dist/core/default_tasks/cifar10.d.ts +0 -2
- package/dist/core/default_tasks/geotags.d.ts +0 -2
- package/dist/core/default_tasks/geotags.js +0 -69
- package/dist/core/default_tasks/index.d.ts +0 -6
- package/dist/core/default_tasks/index.js +0 -15
- package/dist/core/default_tasks/lus_covid.d.ts +0 -2
- package/dist/core/default_tasks/lus_covid.js +0 -96
- package/dist/core/default_tasks/mnist.d.ts +0 -2
- package/dist/core/default_tasks/simple_face.d.ts +0 -2
- package/dist/core/default_tasks/titanic.d.ts +0 -2
- package/dist/core/index.d.ts +0 -18
- package/dist/core/index.js +0 -39
- package/dist/core/informant/graph_informant.js +0 -23
- package/dist/core/informant/index.d.ts +0 -3
- package/dist/core/informant/index.js +0 -9
- package/dist/core/informant/training_informant/base.d.ts +0 -31
- package/dist/core/informant/training_informant/base.js +0 -83
- package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
- package/dist/core/informant/training_informant/decentralized.js +0 -22
- package/dist/core/informant/training_informant/federated.d.ts +0 -14
- package/dist/core/informant/training_informant/federated.js +0 -32
- package/dist/core/informant/training_informant/index.d.ts +0 -4
- package/dist/core/informant/training_informant/index.js +0 -11
- package/dist/core/informant/training_informant/local.d.ts +0 -6
- package/dist/core/informant/training_informant/local.js +0 -20
- package/dist/core/logging/console_logger.js +0 -33
- package/dist/core/logging/index.d.ts +0 -3
- package/dist/core/logging/index.js +0 -9
- package/dist/core/logging/logger.js +0 -9
- package/dist/core/logging/trainer_logger.d.ts +0 -24
- package/dist/core/logging/trainer_logger.js +0 -59
- package/dist/core/memory/base.d.ts +0 -22
- package/dist/core/memory/base.js +0 -9
- package/dist/core/memory/empty.d.ts +0 -14
- package/dist/core/memory/empty.js +0 -75
- package/dist/core/memory/index.d.ts +0 -3
- package/dist/core/memory/index.js +0 -9
- package/dist/core/memory/model_type.d.ts +0 -4
- package/dist/core/memory/model_type.js +0 -9
- package/dist/core/serialization/index.d.ts +0 -2
- package/dist/core/serialization/index.js +0 -6
- package/dist/core/serialization/model.d.ts +0 -5
- package/dist/core/serialization/model.js +0 -55
- package/dist/core/serialization/weights.js +0 -64
- package/dist/core/task/data_example.js +0 -24
- package/dist/core/task/digest.js +0 -18
- package/dist/core/task/display_information.js +0 -49
- package/dist/core/task/index.d.ts +0 -6
- package/dist/core/task/index.js +0 -15
- package/dist/core/task/model_compile_data.d.ts +0 -6
- package/dist/core/task/model_compile_data.js +0 -22
- package/dist/core/task/summary.js +0 -19
- package/dist/core/task/task.js +0 -35
- package/dist/core/task/task_handler.d.ts +0 -5
- package/dist/core/task/task_handler.js +0 -53
- package/dist/core/task/task_provider.d.ts +0 -6
- package/dist/core/task/task_provider.js +0 -13
- package/dist/core/task/training_information.js +0 -66
- package/dist/core/training/disco.d.ts +0 -23
- package/dist/core/training/disco.js +0 -130
- package/dist/core/training/index.d.ts +0 -2
- package/dist/core/training/index.js +0 -7
- package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
- package/dist/core/training/trainer/distributed_trainer.js +0 -65
- package/dist/core/training/trainer/local_trainer.d.ts +0 -11
- package/dist/core/training/trainer/local_trainer.js +0 -34
- package/dist/core/training/trainer/round_tracker.d.ts +0 -30
- package/dist/core/training/trainer/round_tracker.js +0 -47
- package/dist/core/training/trainer/trainer.d.ts +0 -65
- package/dist/core/training/trainer/trainer.js +0 -160
- package/dist/core/training/trainer/trainer_builder.js +0 -95
- package/dist/core/training/training_schemes.d.ts +0 -5
- package/dist/core/training/training_schemes.js +0 -10
- package/dist/core/types.d.ts +0 -4
- package/dist/core/types.js +0 -2
- package/dist/core/validation/index.d.ts +0 -1
- package/dist/core/validation/index.js +0 -5
- package/dist/core/validation/validator.d.ts +0 -17
- package/dist/core/validation/validator.js +0 -104
- package/dist/core/weights/aggregation.d.ts +0 -7
- package/dist/core/weights/aggregation.js +0 -72
- package/dist/core/weights/index.d.ts +0 -2
- package/dist/core/weights/index.js +0 -7
- package/dist/core/weights/weights_container.d.ts +0 -19
- package/dist/core/weights/weights_container.js +0 -64
- package/dist/imports.d.ts +0 -2
- package/dist/imports.js +0 -7
- package/dist/memory/memory.d.ts +0 -26
- package/dist/memory/memory.js +0 -160
- package/dist/{core/task → task}/data_example.d.ts +1 -1
- package/dist/{core/task → task}/digest.d.ts +0 -0
- package/dist/{core/task → task}/summary.d.ts +1 -1
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
export { cifar10 } from './cifar10/index.js';
|
|
2
|
+
export { geotags } from './geotags/index.js';
|
|
3
|
+
export { lusCovid } from './lus_covid.js';
|
|
4
|
+
export { mnist } from './mnist.js';
|
|
5
|
+
export { simpleFace } from './simple_face/index.js';
|
|
6
|
+
export { skinMnist } from './skin_mnist.js';
|
|
7
|
+
export { titanic } from './titanic.js';
|
|
8
|
+
export { wikitext } from './wikitext.js';
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
export { cifar10 } from './cifar10/index.js';
|
|
2
|
+
export { geotags } from './geotags/index.js';
|
|
3
|
+
export { lusCovid } from './lus_covid.js';
|
|
4
|
+
export { mnist } from './mnist.js';
|
|
5
|
+
export { simpleFace } from './simple_face/index.js';
|
|
6
|
+
export { skinMnist } from './skin_mnist.js';
|
|
7
|
+
export { titanic } from './titanic.js';
|
|
8
|
+
export { wikitext } from './wikitext.js';
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
import { data, models } from '../index.js';
|
|
3
|
+
export const lusCovid = {
|
|
4
|
+
getTask() {
|
|
5
|
+
return {
|
|
6
|
+
id: 'lus_covid',
|
|
7
|
+
displayInformation: {
|
|
8
|
+
taskTitle: 'COVID Lung Ultrasound',
|
|
9
|
+
summary: {
|
|
10
|
+
preview: 'Do you have a data of lung ultrasound images on patients <b>suspected of Lower Respiratory Tract infection (LRTI) during the COVID pandemic</b>? <br> Learn how to discriminate between COVID positive and negative patients by joining this task.',
|
|
11
|
+
overview: "Don’t have a dataset of your own? Download a sample of a few cases <a class='underline' href='https://drive.switch.ch/index.php/s/zM5ZrUWK3taaIly' target='_blank'>here</a>."
|
|
12
|
+
},
|
|
13
|
+
model: "We use a simplified* version of the <b>DeepChest model</b>: A deep learning model developed in our lab (<a class='underline' href='https://www.epfl.ch/labs/mlo/igh-intelligent-global-health/'>intelligent Global Health</a>.). On a cohort of 400 Swiss patients suspected of LRTI, the model obtained over 90% area under the ROC curve for this task. <br><br>*Simplified to ensure smooth running on your browser, the performance is minimally affected. Details of the adaptations are below <br>- <b>Removed</b>: positional embedding (i.e. we don’t take the anatomic position into consideration). Rather, the model now does mean pooling over the feature vector of the images for each patient <br>- <b>Replaced</b>: ResNet18 by Mobilenet",
|
|
14
|
+
tradeoffs: 'We are using a simpler version of DeepChest in order to be able to run it on the browser.',
|
|
15
|
+
dataFormatInformation: 'This model takes as input an image dataset. It consists on a set of lung ultrasound images per patient with its corresponding label of covid positive or negative. Moreover, to identify the images per patient you have to follow the follwing naming pattern: "patientId_*.png"',
|
|
16
|
+
dataExampleText: 'Below you can find an example of an expected lung image for patient 2 named: 2_QAID_1.masked.reshaped.squared.224.png',
|
|
17
|
+
dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/2_QAID_1.masked.reshaped.squared.224.png'
|
|
18
|
+
},
|
|
19
|
+
trainingInformation: {
|
|
20
|
+
modelID: 'lus-covid-model',
|
|
21
|
+
epochs: 50,
|
|
22
|
+
roundDuration: 2,
|
|
23
|
+
validationSplit: 0,
|
|
24
|
+
batchSize: 5,
|
|
25
|
+
IMAGE_H: 100,
|
|
26
|
+
IMAGE_W: 100,
|
|
27
|
+
preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
|
|
28
|
+
LABEL_LIST: ['COVID-Positive', 'COVID-Negative'],
|
|
29
|
+
dataType: 'image',
|
|
30
|
+
scheme: 'federated',
|
|
31
|
+
noiseScale: undefined,
|
|
32
|
+
clippingRadius: 20,
|
|
33
|
+
decentralizedSecure: true,
|
|
34
|
+
minimumReadyPeers: 2,
|
|
35
|
+
maxShareValue: 100
|
|
36
|
+
}
|
|
37
|
+
};
|
|
38
|
+
},
|
|
39
|
+
// Model architecture from tensorflow.js docs:
|
|
40
|
+
// https://codelabs.developers.google.com/codelabs/tfjs-training-classfication/index.html#4
|
|
41
|
+
async getModel() {
|
|
42
|
+
const imageHeight = 100;
|
|
43
|
+
const imageWidth = 100;
|
|
44
|
+
const imageChannels = 3;
|
|
45
|
+
const numOutputClasses = 2;
|
|
46
|
+
const model = tf.sequential();
|
|
47
|
+
// In the first layer of our convolutional neural network we have
|
|
48
|
+
// to specify the input shape. Then we specify some parameters for
|
|
49
|
+
// the convolution operation that takes place in this layer.
|
|
50
|
+
model.add(tf.layers.conv2d({
|
|
51
|
+
inputShape: [imageHeight, imageWidth, imageChannels],
|
|
52
|
+
kernelSize: 5,
|
|
53
|
+
filters: 8,
|
|
54
|
+
strides: 1,
|
|
55
|
+
activation: 'relu',
|
|
56
|
+
kernelInitializer: 'varianceScaling'
|
|
57
|
+
}));
|
|
58
|
+
// The MaxPooling layer acts as a sort of downsampling using max values
|
|
59
|
+
// in a region instead of averaging.
|
|
60
|
+
model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));
|
|
61
|
+
// Repeat the conv2d + maxPooling block.
|
|
62
|
+
// Note that we have more filters in the convolution.
|
|
63
|
+
model.add(tf.layers.conv2d({
|
|
64
|
+
kernelSize: 5,
|
|
65
|
+
filters: 16,
|
|
66
|
+
strides: 1,
|
|
67
|
+
activation: 'relu',
|
|
68
|
+
kernelInitializer: 'varianceScaling'
|
|
69
|
+
}));
|
|
70
|
+
model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));
|
|
71
|
+
// Now we flatten the output from the 2D filters into a 1D vector to prepare
|
|
72
|
+
// it for input into our last layer. This is common practice when feeding
|
|
73
|
+
// higher dimensional data to a final classification output layer.
|
|
74
|
+
model.add(tf.layers.flatten());
|
|
75
|
+
// Our last layer is a dense layer which has 2 output units, one for each
|
|
76
|
+
// output class.
|
|
77
|
+
model.add(tf.layers.dense({
|
|
78
|
+
units: numOutputClasses,
|
|
79
|
+
kernelInitializer: 'varianceScaling',
|
|
80
|
+
activation: 'softmax'
|
|
81
|
+
}));
|
|
82
|
+
model.compile({
|
|
83
|
+
optimizer: 'sgd',
|
|
84
|
+
loss: 'binaryCrossentropy',
|
|
85
|
+
metrics: ['accuracy']
|
|
86
|
+
});
|
|
87
|
+
return Promise.resolve(new models.TFJS(model));
|
|
88
|
+
}
|
|
89
|
+
};
|
|
@@ -1,12 +1,9 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
var __1 = require("..");
|
|
6
|
-
exports.mnist = {
|
|
7
|
-
getTask: function () {
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
import { models } from '../index.js';
|
|
3
|
+
export const mnist = {
|
|
4
|
+
getTask() {
|
|
8
5
|
return {
|
|
9
|
-
|
|
6
|
+
id: 'mnist',
|
|
10
7
|
displayInformation: {
|
|
11
8
|
taskTitle: 'MNIST',
|
|
12
9
|
summary: {
|
|
@@ -25,17 +22,12 @@ exports.mnist = {
|
|
|
25
22
|
roundDuration: 10,
|
|
26
23
|
validationSplit: 0.2,
|
|
27
24
|
batchSize: 30,
|
|
28
|
-
modelCompileData: {
|
|
29
|
-
optimizer: 'rmsprop',
|
|
30
|
-
loss: 'categoricalCrossentropy',
|
|
31
|
-
metrics: ['accuracy']
|
|
32
|
-
},
|
|
33
25
|
dataType: 'image',
|
|
34
26
|
IMAGE_H: 28,
|
|
35
27
|
IMAGE_W: 28,
|
|
36
28
|
preprocessingFunctions: [],
|
|
37
29
|
LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
|
|
38
|
-
scheme: '
|
|
30
|
+
scheme: 'decentralized',
|
|
39
31
|
noiseScale: undefined,
|
|
40
32
|
clippingRadius: 20,
|
|
41
33
|
decentralizedSecure: true,
|
|
@@ -44,26 +36,26 @@ exports.mnist = {
|
|
|
44
36
|
}
|
|
45
37
|
};
|
|
46
38
|
},
|
|
47
|
-
getModel
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
});
|
|
39
|
+
getModel() {
|
|
40
|
+
const model = tf.sequential();
|
|
41
|
+
model.add(tf.layers.conv2d({
|
|
42
|
+
inputShape: [28, 28, 3],
|
|
43
|
+
kernelSize: 3,
|
|
44
|
+
filters: 16,
|
|
45
|
+
activation: 'relu'
|
|
46
|
+
}));
|
|
47
|
+
model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
|
|
48
|
+
model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
|
|
49
|
+
model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
|
|
50
|
+
model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
|
|
51
|
+
model.add(tf.layers.flatten({}));
|
|
52
|
+
model.add(tf.layers.dense({ units: 64, activation: 'relu' }));
|
|
53
|
+
model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));
|
|
54
|
+
model.compile({
|
|
55
|
+
optimizer: 'rmsprop',
|
|
56
|
+
loss: 'categoricalCrossentropy',
|
|
57
|
+
metrics: ['accuracy']
|
|
67
58
|
});
|
|
59
|
+
return Promise.resolve(new models.TFJS(model));
|
|
68
60
|
}
|
|
69
61
|
};
|
|
@@ -1,12 +1,10 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
exports.simpleFace = {
|
|
7
|
-
getTask: function () {
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
import { data, models } from '../../index.js';
|
|
3
|
+
import baseModel from './model.js';
|
|
4
|
+
export const simpleFace = {
|
|
5
|
+
getTask() {
|
|
8
6
|
return {
|
|
9
|
-
|
|
7
|
+
id: 'simple_face',
|
|
10
8
|
displayInformation: {
|
|
11
9
|
taskTitle: 'Simple Face',
|
|
12
10
|
summary: {
|
|
@@ -22,32 +20,29 @@ exports.simpleFace = {
|
|
|
22
20
|
trainingInformation: {
|
|
23
21
|
modelID: 'simple_face-model',
|
|
24
22
|
epochs: 50,
|
|
25
|
-
modelURL: 'https://storage.googleapis.com/deai-313515.appspot.com/models/mobileNetV2_35_alpha_2_classes/model.json',
|
|
26
23
|
roundDuration: 1,
|
|
27
24
|
validationSplit: 0.2,
|
|
28
25
|
batchSize: 10,
|
|
29
|
-
preprocessingFunctions: [
|
|
30
|
-
learningRate: 0.001,
|
|
31
|
-
modelCompileData: {
|
|
32
|
-
optimizer: 'sgd',
|
|
33
|
-
loss: 'categoricalCrossentropy',
|
|
34
|
-
metrics: ['accuracy']
|
|
35
|
-
},
|
|
26
|
+
preprocessingFunctions: [data.ImagePreprocessing.Normalize],
|
|
36
27
|
dataType: 'image',
|
|
37
28
|
IMAGE_H: 200,
|
|
38
29
|
IMAGE_W: 200,
|
|
39
30
|
LABEL_LIST: ['child', 'adult'],
|
|
40
|
-
scheme: '
|
|
31
|
+
scheme: 'federated', // secure aggregation not yet implemented for federated
|
|
41
32
|
noiseScale: undefined,
|
|
42
33
|
clippingRadius: undefined
|
|
43
34
|
}
|
|
44
35
|
};
|
|
45
36
|
},
|
|
46
|
-
getModel
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
37
|
+
async getModel() {
|
|
38
|
+
const model = await tf.loadLayersModel({
|
|
39
|
+
load: async () => Promise.resolve(baseModel),
|
|
40
|
+
});
|
|
41
|
+
model.compile({
|
|
42
|
+
optimizer: tf.train.sgd(0.001),
|
|
43
|
+
loss: 'categoricalCrossentropy',
|
|
44
|
+
metrics: ['accuracy']
|
|
51
45
|
});
|
|
46
|
+
return new models.TFJS(model);
|
|
52
47
|
}
|
|
53
48
|
};
|