@epfml/discojs 2.1.1 → 2.1.2-p20240506085037.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.d.ts +180 -0
- package/dist/aggregator/base.js +236 -0
- package/dist/aggregator/get.d.ts +16 -0
- package/dist/aggregator/get.js +31 -0
- package/dist/aggregator/index.d.ts +7 -0
- package/dist/aggregator/index.js +4 -0
- package/dist/aggregator/mean.d.ts +23 -0
- package/dist/aggregator/mean.js +69 -0
- package/dist/aggregator/secure.d.ts +27 -0
- package/dist/aggregator/secure.js +91 -0
- package/dist/async_informant.d.ts +15 -0
- package/dist/async_informant.js +42 -0
- package/dist/client/base.d.ts +76 -0
- package/dist/client/base.js +88 -0
- package/dist/client/decentralized/base.d.ts +32 -0
- package/dist/client/decentralized/base.js +192 -0
- package/dist/client/decentralized/index.d.ts +2 -0
- package/dist/client/decentralized/index.js +2 -0
- package/dist/client/decentralized/messages.d.ts +28 -0
- package/dist/client/decentralized/messages.js +44 -0
- package/dist/client/decentralized/peer.d.ts +40 -0
- package/dist/client/decentralized/peer.js +189 -0
- package/dist/client/decentralized/peer_pool.d.ts +12 -0
- package/dist/client/decentralized/peer_pool.js +44 -0
- package/dist/client/event_connection.d.ts +34 -0
- package/dist/client/event_connection.js +105 -0
- package/dist/client/federated/base.d.ts +54 -0
- package/dist/client/federated/base.js +151 -0
- package/dist/client/federated/index.d.ts +2 -0
- package/dist/client/federated/index.js +2 -0
- package/dist/client/federated/messages.d.ts +30 -0
- package/dist/client/federated/messages.js +24 -0
- package/dist/client/index.d.ts +8 -0
- package/dist/client/index.js +8 -0
- package/dist/client/local.d.ts +3 -0
- package/dist/client/local.js +3 -0
- package/dist/client/messages.d.ts +30 -0
- package/dist/client/messages.js +26 -0
- package/dist/client/types.d.ts +2 -0
- package/dist/client/types.js +4 -0
- package/dist/client/utils.d.ts +2 -0
- package/dist/client/utils.js +7 -0
- package/dist/dataset/data/data.d.ts +48 -0
- package/dist/dataset/data/data.js +72 -0
- package/dist/dataset/data/data_split.d.ts +8 -0
- package/dist/dataset/data/data_split.js +1 -0
- package/dist/dataset/data/image_data.d.ts +11 -0
- package/dist/dataset/data/image_data.js +38 -0
- package/dist/dataset/data/index.d.ts +6 -0
- package/dist/dataset/data/index.js +5 -0
- package/dist/dataset/data/preprocessing/base.d.ts +16 -0
- package/dist/dataset/data/preprocessing/base.js +1 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.js +40 -0
- package/dist/dataset/data/preprocessing/index.d.ts +4 -0
- package/dist/dataset/data/preprocessing/index.js +3 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.js +45 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.js +85 -0
- package/dist/dataset/data/tabular_data.d.ts +11 -0
- package/dist/dataset/data/tabular_data.js +25 -0
- package/dist/dataset/data/text_data.d.ts +11 -0
- package/dist/dataset/data/text_data.js +14 -0
- package/dist/{core/dataset → dataset}/data_loader/data_loader.d.ts +3 -5
- package/dist/dataset/data_loader/data_loader.js +2 -0
- package/dist/dataset/data_loader/image_loader.d.ts +20 -3
- package/dist/dataset/data_loader/image_loader.js +98 -23
- package/dist/dataset/data_loader/index.d.ts +5 -2
- package/dist/dataset/data_loader/index.js +4 -7
- package/dist/dataset/data_loader/tabular_loader.d.ts +34 -3
- package/dist/dataset/data_loader/tabular_loader.js +75 -15
- package/dist/dataset/data_loader/text_loader.d.ts +14 -0
- package/dist/dataset/data_loader/text_loader.js +25 -0
- package/dist/dataset/dataset.d.ts +5 -0
- package/dist/dataset/dataset.js +1 -0
- package/dist/dataset/dataset_builder.d.ts +60 -0
- package/dist/dataset/dataset_builder.js +142 -0
- package/dist/dataset/index.d.ts +5 -0
- package/dist/dataset/index.js +3 -0
- package/dist/default_tasks/cifar10/index.d.ts +2 -0
- package/dist/{core/default_tasks/cifar10.js → default_tasks/cifar10/index.js} +28 -36
- package/dist/default_tasks/cifar10/model.d.ts +434 -0
- package/dist/default_tasks/cifar10/model.js +2385 -0
- package/dist/default_tasks/geotags/index.d.ts +2 -0
- package/dist/default_tasks/geotags/index.js +65 -0
- package/dist/default_tasks/geotags/model.d.ts +593 -0
- package/dist/default_tasks/geotags/model.js +4715 -0
- package/dist/default_tasks/index.d.ts +8 -0
- package/dist/default_tasks/index.js +8 -0
- package/dist/default_tasks/lus_covid.d.ts +2 -0
- package/dist/default_tasks/lus_covid.js +89 -0
- package/dist/default_tasks/mnist.d.ts +2 -0
- package/dist/{core/default_tasks → default_tasks}/mnist.js +26 -34
- package/dist/default_tasks/simple_face/index.d.ts +2 -0
- package/dist/{core/default_tasks/simple_face.js → default_tasks/simple_face/index.js} +17 -22
- package/dist/default_tasks/simple_face/model.d.ts +513 -0
- package/dist/default_tasks/simple_face/model.js +4301 -0
- package/dist/default_tasks/skin_mnist.d.ts +2 -0
- package/dist/default_tasks/skin_mnist.js +80 -0
- package/dist/default_tasks/titanic.d.ts +2 -0
- package/dist/{core/default_tasks → default_tasks}/titanic.js +24 -33
- package/dist/default_tasks/wikitext.d.ts +2 -0
- package/dist/default_tasks/wikitext.js +38 -0
- package/dist/index.d.ts +18 -2
- package/dist/index.js +18 -6
- package/dist/{core/informant → informant}/graph_informant.d.ts +1 -1
- package/dist/informant/graph_informant.js +20 -0
- package/dist/informant/index.d.ts +1 -0
- package/dist/informant/index.js +1 -0
- package/dist/{core/logging → logging}/console_logger.d.ts +2 -2
- package/dist/logging/console_logger.js +22 -0
- package/dist/logging/index.d.ts +2 -0
- package/dist/logging/index.js +1 -0
- package/dist/{core/logging → logging}/logger.d.ts +3 -3
- package/dist/logging/logger.js +1 -0
- package/dist/memory/base.d.ts +119 -0
- package/dist/memory/base.js +9 -0
- package/dist/memory/empty.d.ts +20 -0
- package/dist/memory/empty.js +43 -0
- package/dist/memory/index.d.ts +3 -1
- package/dist/memory/index.js +3 -5
- package/dist/memory/model_type.d.ts +9 -0
- package/dist/memory/model_type.js +10 -0
- package/dist/{core/privacy.d.ts → privacy.d.ts} +1 -1
- package/dist/{core/privacy.js → privacy.js} +11 -16
- package/dist/serialization/index.d.ts +2 -0
- package/dist/serialization/index.js +2 -0
- package/dist/serialization/model.d.ts +5 -0
- package/dist/serialization/model.js +67 -0
- package/dist/{core/serialization → serialization}/weights.d.ts +2 -2
- package/dist/serialization/weights.js +37 -0
- package/dist/task/data_example.js +14 -0
- package/dist/task/digest.js +14 -0
- package/dist/{core/task → task}/display_information.d.ts +5 -3
- package/dist/task/display_information.js +46 -0
- package/dist/task/index.d.ts +7 -0
- package/dist/task/index.js +5 -0
- package/dist/task/label_type.d.ts +9 -0
- package/dist/task/label_type.js +28 -0
- package/dist/task/summary.js +13 -0
- package/dist/{core/task → task}/task.d.ts +7 -7
- package/dist/task/task.js +22 -0
- package/dist/task/task_handler.d.ts +5 -0
- package/dist/task/task_handler.js +20 -0
- package/dist/task/task_provider.d.ts +5 -0
- package/dist/task/task_provider.js +1 -0
- package/dist/{core/task → task}/training_information.d.ts +9 -10
- package/dist/task/training_information.js +88 -0
- package/dist/training/disco.d.ts +40 -0
- package/dist/training/disco.js +107 -0
- package/dist/training/index.d.ts +2 -0
- package/dist/training/index.js +1 -0
- package/dist/training/trainer/distributed_trainer.d.ts +20 -0
- package/dist/training/trainer/distributed_trainer.js +36 -0
- package/dist/training/trainer/local_trainer.d.ts +12 -0
- package/dist/training/trainer/local_trainer.js +19 -0
- package/dist/training/trainer/trainer.d.ts +33 -0
- package/dist/training/trainer/trainer.js +52 -0
- package/dist/{core/training → training}/trainer/trainer_builder.d.ts +5 -7
- package/dist/training/trainer/trainer_builder.js +43 -0
- package/dist/types.d.ts +8 -0
- package/dist/types.js +1 -0
- package/dist/utils/event_emitter.d.ts +40 -0
- package/dist/utils/event_emitter.js +57 -0
- package/dist/validation/index.d.ts +1 -0
- package/dist/validation/index.js +1 -0
- package/dist/validation/validator.d.ts +28 -0
- package/dist/validation/validator.js +132 -0
- package/dist/weights/aggregation.d.ts +21 -0
- package/dist/weights/aggregation.js +44 -0
- package/dist/weights/index.d.ts +2 -0
- package/dist/weights/index.js +2 -0
- package/dist/weights/weights_container.d.ts +68 -0
- package/dist/weights/weights_container.js +96 -0
- package/package.json +24 -15
- package/README.md +0 -53
- package/dist/core/async_buffer.d.ts +0 -41
- package/dist/core/async_buffer.js +0 -97
- package/dist/core/async_informant.d.ts +0 -20
- package/dist/core/async_informant.js +0 -69
- package/dist/core/client/base.d.ts +0 -33
- package/dist/core/client/base.js +0 -35
- package/dist/core/client/decentralized/base.d.ts +0 -32
- package/dist/core/client/decentralized/base.js +0 -212
- package/dist/core/client/decentralized/clear_text.d.ts +0 -14
- package/dist/core/client/decentralized/clear_text.js +0 -96
- package/dist/core/client/decentralized/index.d.ts +0 -4
- package/dist/core/client/decentralized/index.js +0 -9
- package/dist/core/client/decentralized/messages.d.ts +0 -41
- package/dist/core/client/decentralized/messages.js +0 -54
- package/dist/core/client/decentralized/peer.d.ts +0 -26
- package/dist/core/client/decentralized/peer.js +0 -210
- package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
- package/dist/core/client/decentralized/peer_pool.js +0 -92
- package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
- package/dist/core/client/decentralized/sec_agg.js +0 -190
- package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
- package/dist/core/client/decentralized/secret_shares.js +0 -39
- package/dist/core/client/decentralized/types.d.ts +0 -2
- package/dist/core/client/decentralized/types.js +0 -7
- package/dist/core/client/event_connection.d.ts +0 -37
- package/dist/core/client/event_connection.js +0 -158
- package/dist/core/client/federated/client.d.ts +0 -37
- package/dist/core/client/federated/client.js +0 -273
- package/dist/core/client/federated/index.d.ts +0 -2
- package/dist/core/client/federated/index.js +0 -7
- package/dist/core/client/federated/messages.d.ts +0 -38
- package/dist/core/client/federated/messages.js +0 -25
- package/dist/core/client/index.d.ts +0 -5
- package/dist/core/client/index.js +0 -11
- package/dist/core/client/local.d.ts +0 -8
- package/dist/core/client/local.js +0 -36
- package/dist/core/client/messages.d.ts +0 -28
- package/dist/core/client/messages.js +0 -33
- package/dist/core/client/utils.d.ts +0 -2
- package/dist/core/client/utils.js +0 -19
- package/dist/core/dataset/data/data.d.ts +0 -11
- package/dist/core/dataset/data/data.js +0 -20
- package/dist/core/dataset/data/data_split.d.ts +0 -5
- package/dist/core/dataset/data/data_split.js +0 -2
- package/dist/core/dataset/data/image_data.d.ts +0 -8
- package/dist/core/dataset/data/image_data.js +0 -64
- package/dist/core/dataset/data/index.d.ts +0 -5
- package/dist/core/dataset/data/index.js +0 -11
- package/dist/core/dataset/data/preprocessing.d.ts +0 -13
- package/dist/core/dataset/data/preprocessing.js +0 -33
- package/dist/core/dataset/data/tabular_data.d.ts +0 -8
- package/dist/core/dataset/data/tabular_data.js +0 -40
- package/dist/core/dataset/data_loader/data_loader.js +0 -10
- package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
- package/dist/core/dataset/data_loader/image_loader.js +0 -141
- package/dist/core/dataset/data_loader/index.d.ts +0 -3
- package/dist/core/dataset/data_loader/index.js +0 -9
- package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
- package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
- package/dist/core/dataset/dataset.d.ts +0 -2
- package/dist/core/dataset/dataset.js +0 -2
- package/dist/core/dataset/dataset_builder.d.ts +0 -18
- package/dist/core/dataset/dataset_builder.js +0 -96
- package/dist/core/dataset/index.d.ts +0 -4
- package/dist/core/dataset/index.js +0 -14
- package/dist/core/default_tasks/cifar10.d.ts +0 -2
- package/dist/core/default_tasks/geotags.d.ts +0 -2
- package/dist/core/default_tasks/geotags.js +0 -69
- package/dist/core/default_tasks/index.d.ts +0 -6
- package/dist/core/default_tasks/index.js +0 -15
- package/dist/core/default_tasks/lus_covid.d.ts +0 -2
- package/dist/core/default_tasks/lus_covid.js +0 -96
- package/dist/core/default_tasks/mnist.d.ts +0 -2
- package/dist/core/default_tasks/simple_face.d.ts +0 -2
- package/dist/core/default_tasks/titanic.d.ts +0 -2
- package/dist/core/index.d.ts +0 -18
- package/dist/core/index.js +0 -39
- package/dist/core/informant/graph_informant.js +0 -23
- package/dist/core/informant/index.d.ts +0 -3
- package/dist/core/informant/index.js +0 -9
- package/dist/core/informant/training_informant/base.d.ts +0 -31
- package/dist/core/informant/training_informant/base.js +0 -83
- package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
- package/dist/core/informant/training_informant/decentralized.js +0 -22
- package/dist/core/informant/training_informant/federated.d.ts +0 -14
- package/dist/core/informant/training_informant/federated.js +0 -32
- package/dist/core/informant/training_informant/index.d.ts +0 -4
- package/dist/core/informant/training_informant/index.js +0 -11
- package/dist/core/informant/training_informant/local.d.ts +0 -6
- package/dist/core/informant/training_informant/local.js +0 -20
- package/dist/core/logging/console_logger.js +0 -33
- package/dist/core/logging/index.d.ts +0 -3
- package/dist/core/logging/index.js +0 -9
- package/dist/core/logging/logger.js +0 -9
- package/dist/core/logging/trainer_logger.d.ts +0 -24
- package/dist/core/logging/trainer_logger.js +0 -59
- package/dist/core/memory/base.d.ts +0 -22
- package/dist/core/memory/base.js +0 -9
- package/dist/core/memory/empty.d.ts +0 -14
- package/dist/core/memory/empty.js +0 -75
- package/dist/core/memory/index.d.ts +0 -3
- package/dist/core/memory/index.js +0 -9
- package/dist/core/memory/model_type.d.ts +0 -4
- package/dist/core/memory/model_type.js +0 -9
- package/dist/core/serialization/index.d.ts +0 -2
- package/dist/core/serialization/index.js +0 -6
- package/dist/core/serialization/model.d.ts +0 -5
- package/dist/core/serialization/model.js +0 -55
- package/dist/core/serialization/weights.js +0 -64
- package/dist/core/task/data_example.js +0 -24
- package/dist/core/task/digest.js +0 -18
- package/dist/core/task/display_information.js +0 -49
- package/dist/core/task/index.d.ts +0 -6
- package/dist/core/task/index.js +0 -15
- package/dist/core/task/model_compile_data.d.ts +0 -6
- package/dist/core/task/model_compile_data.js +0 -22
- package/dist/core/task/summary.js +0 -19
- package/dist/core/task/task.js +0 -35
- package/dist/core/task/task_handler.d.ts +0 -5
- package/dist/core/task/task_handler.js +0 -53
- package/dist/core/task/task_provider.d.ts +0 -6
- package/dist/core/task/task_provider.js +0 -13
- package/dist/core/task/training_information.js +0 -66
- package/dist/core/training/disco.d.ts +0 -23
- package/dist/core/training/disco.js +0 -130
- package/dist/core/training/index.d.ts +0 -2
- package/dist/core/training/index.js +0 -7
- package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
- package/dist/core/training/trainer/distributed_trainer.js +0 -65
- package/dist/core/training/trainer/local_trainer.d.ts +0 -11
- package/dist/core/training/trainer/local_trainer.js +0 -34
- package/dist/core/training/trainer/round_tracker.d.ts +0 -30
- package/dist/core/training/trainer/round_tracker.js +0 -47
- package/dist/core/training/trainer/trainer.d.ts +0 -65
- package/dist/core/training/trainer/trainer.js +0 -160
- package/dist/core/training/trainer/trainer_builder.js +0 -95
- package/dist/core/training/training_schemes.d.ts +0 -5
- package/dist/core/training/training_schemes.js +0 -10
- package/dist/core/types.d.ts +0 -4
- package/dist/core/types.js +0 -2
- package/dist/core/validation/index.d.ts +0 -1
- package/dist/core/validation/index.js +0 -5
- package/dist/core/validation/validator.d.ts +0 -17
- package/dist/core/validation/validator.js +0 -104
- package/dist/core/weights/aggregation.d.ts +0 -7
- package/dist/core/weights/aggregation.js +0 -72
- package/dist/core/weights/index.d.ts +0 -2
- package/dist/core/weights/index.js +0 -7
- package/dist/core/weights/weights_container.d.ts +0 -19
- package/dist/core/weights/weights_container.js +0 -64
- package/dist/imports.d.ts +0 -2
- package/dist/imports.js +0 -7
- package/dist/memory/memory.d.ts +0 -26
- package/dist/memory/memory.js +0 -160
- package/dist/{core/task → task}/data_example.d.ts +1 -1
- package/dist/{core/task → task}/digest.d.ts +0 -0
- package/dist/{core/task → task}/summary.d.ts +1 -1
|
@@ -1,96 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.lusCovid = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
var __1 = require("..");
|
|
6
|
-
exports.lusCovid = {
|
|
7
|
-
getTask: function () {
|
|
8
|
-
return {
|
|
9
|
-
taskID: 'lus_covid',
|
|
10
|
-
displayInformation: {
|
|
11
|
-
taskTitle: 'COVID Lung Ultrasound',
|
|
12
|
-
summary: {
|
|
13
|
-
preview: 'Do you have a data of lung ultrasound images on patients <b>suspected of Lower Respiratory Tract infection (LRTI) during the COVID pandemic</b>? <br> Learn how to discriminate between COVID positive and negative patients by joining this task.',
|
|
14
|
-
overview: "Don’t have a dataset of your own? Download a sample of a few cases <a class='underline' href='https://drive.switch.ch/index.php/s/zM5ZrUWK3taaIly' target='_blank'>here</a>."
|
|
15
|
-
},
|
|
16
|
-
model: "We use a simplified* version of the <b>DeepChest model</b>: A deep learning model developed in our lab (<a class='underline' href='https://www.epfl.ch/labs/mlo/igh-intelligent-global-health/'>intelligent Global Health</a>.). On a cohort of 400 Swiss patients suspected of LRTI, the model obtained over 90% area under the ROC curve for this task. <br><br>*Simplified to ensure smooth running on your browser, the performance is minimally affected. Details of the adaptations are below <br>- <b>Removed</b>: positional embedding (i.e. we don’t take the anatomic position into consideration). Rather, the model now does mean pooling over the feature vector of the images for each patient <br>- <b>Replaced</b>: ResNet18 by Mobilenet",
|
|
17
|
-
tradeoffs: 'We are using a simpler version of DeepChest in order to be able to run it on the browser.',
|
|
18
|
-
dataFormatInformation: 'This model takes as input an image dataset. It consists on a set of lung ultrasound images per patient with its corresponding label of covid positive or negative. Moreover, to identify the images per patient you have to follow the follwing naming pattern: "patientId_*.png"',
|
|
19
|
-
dataExampleText: 'Below you can find an example of an expected lung image for patient 2 named: 2_QAID_1.masked.reshaped.squared.224.png',
|
|
20
|
-
dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/2_QAID_1.masked.reshaped.squared.224.png'
|
|
21
|
-
},
|
|
22
|
-
trainingInformation: {
|
|
23
|
-
modelID: 'lus-covid-model',
|
|
24
|
-
epochs: 15,
|
|
25
|
-
roundDuration: 10,
|
|
26
|
-
validationSplit: 0.2,
|
|
27
|
-
batchSize: 2,
|
|
28
|
-
modelCompileData: {
|
|
29
|
-
optimizer: 'sgd',
|
|
30
|
-
loss: 'binaryCrossentropy',
|
|
31
|
-
metrics: ['accuracy']
|
|
32
|
-
},
|
|
33
|
-
learningRate: 0.001,
|
|
34
|
-
IMAGE_H: 100,
|
|
35
|
-
IMAGE_W: 100,
|
|
36
|
-
preprocessingFunctions: [__1.data.ImagePreprocessing.Resize],
|
|
37
|
-
LABEL_LIST: ['COVID-Positive', 'COVID-Negative'],
|
|
38
|
-
dataType: 'image',
|
|
39
|
-
scheme: 'Decentralized',
|
|
40
|
-
noiseScale: undefined,
|
|
41
|
-
clippingRadius: 20,
|
|
42
|
-
decentralizedSecure: true,
|
|
43
|
-
minimumReadyPeers: 3,
|
|
44
|
-
maxShareValue: 100
|
|
45
|
-
}
|
|
46
|
-
};
|
|
47
|
-
},
|
|
48
|
-
getModel: function () {
|
|
49
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
50
|
-
var imageHeight, imageWidth, imageChannels, numOutputClasses, model;
|
|
51
|
-
return (0, tslib_1.__generator)(this, function (_a) {
|
|
52
|
-
imageHeight = 100;
|
|
53
|
-
imageWidth = 100;
|
|
54
|
-
imageChannels = 3;
|
|
55
|
-
numOutputClasses = 2;
|
|
56
|
-
model = __1.tf.sequential();
|
|
57
|
-
// In the first layer of our convolutional neural network we have
|
|
58
|
-
// to specify the input shape. Then we specify some parameters for
|
|
59
|
-
// the convolution operation that takes place in this layer.
|
|
60
|
-
model.add(__1.tf.layers.conv2d({
|
|
61
|
-
inputShape: [imageHeight, imageWidth, imageChannels],
|
|
62
|
-
kernelSize: 5,
|
|
63
|
-
filters: 8,
|
|
64
|
-
strides: 1,
|
|
65
|
-
activation: 'relu',
|
|
66
|
-
kernelInitializer: 'varianceScaling'
|
|
67
|
-
}));
|
|
68
|
-
// The MaxPooling layer acts as a sort of downsampling using max values
|
|
69
|
-
// in a region instead of averaging.
|
|
70
|
-
model.add(__1.tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));
|
|
71
|
-
// Repeat another conv2d + maxPooling stack.
|
|
72
|
-
// Note that we have more filters in the convolution.
|
|
73
|
-
model.add(__1.tf.layers.conv2d({
|
|
74
|
-
kernelSize: 5,
|
|
75
|
-
filters: 16,
|
|
76
|
-
strides: 1,
|
|
77
|
-
activation: 'relu',
|
|
78
|
-
kernelInitializer: 'varianceScaling'
|
|
79
|
-
}));
|
|
80
|
-
model.add(__1.tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));
|
|
81
|
-
// Now we flatten the output from the 2D filters into a 1D vector to prepare
|
|
82
|
-
// it for input into our last layer. This is common practice when feeding
|
|
83
|
-
// higher dimensional data to a final classification output layer.
|
|
84
|
-
model.add(__1.tf.layers.flatten());
|
|
85
|
-
// Our last layer is a dense layer which has 2 output units, one for each
|
|
86
|
-
// output class.
|
|
87
|
-
model.add(__1.tf.layers.dense({
|
|
88
|
-
units: numOutputClasses,
|
|
89
|
-
kernelInitializer: 'varianceScaling',
|
|
90
|
-
activation: 'softmax'
|
|
91
|
-
}));
|
|
92
|
-
return [2 /*return*/, model];
|
|
93
|
-
});
|
|
94
|
-
});
|
|
95
|
-
}
|
|
96
|
-
};
|
package/dist/core/index.d.ts
DELETED
|
@@ -1,18 +0,0 @@
|
|
|
1
|
-
export * as tf from '@tensorflow/tfjs';
|
|
2
|
-
export * as data from './dataset';
|
|
3
|
-
export * as serialization from './serialization';
|
|
4
|
-
export * as training from './training';
|
|
5
|
-
export * as privacy from './privacy';
|
|
6
|
-
export { GraphInformant, TrainingInformant, informant } from './informant';
|
|
7
|
-
export { Base as Client } from './client';
|
|
8
|
-
export * as client from './client';
|
|
9
|
-
export { WeightsContainer, aggregation } from './weights';
|
|
10
|
-
export { AsyncBuffer } from './async_buffer';
|
|
11
|
-
export { AsyncInformant } from './async_informant';
|
|
12
|
-
export { Logger, ConsoleLogger, TrainerLog } from './logging';
|
|
13
|
-
export { Memory, ModelType, ModelInfo, Path, ModelSource, Empty as EmptyMemory } from './memory';
|
|
14
|
-
export { Disco, TrainingSchemes } from './training';
|
|
15
|
-
export { Validator } from './validation';
|
|
16
|
-
export * from './task';
|
|
17
|
-
export * as defaultTasks from './default_tasks';
|
|
18
|
-
export * from './types';
|
package/dist/core/index.js
DELETED
|
@@ -1,39 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.defaultTasks = exports.Validator = exports.TrainingSchemes = exports.Disco = exports.EmptyMemory = exports.ModelType = exports.Memory = exports.TrainerLog = exports.ConsoleLogger = exports.Logger = exports.AsyncInformant = exports.AsyncBuffer = exports.aggregation = exports.WeightsContainer = exports.client = exports.Client = exports.informant = exports.TrainingInformant = exports.GraphInformant = exports.privacy = exports.training = exports.serialization = exports.data = exports.tf = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
exports.tf = (0, tslib_1.__importStar)(require("@tensorflow/tfjs"));
|
|
6
|
-
exports.data = (0, tslib_1.__importStar)(require("./dataset"));
|
|
7
|
-
exports.serialization = (0, tslib_1.__importStar)(require("./serialization"));
|
|
8
|
-
exports.training = (0, tslib_1.__importStar)(require("./training"));
|
|
9
|
-
exports.privacy = (0, tslib_1.__importStar)(require("./privacy"));
|
|
10
|
-
var informant_1 = require("./informant");
|
|
11
|
-
Object.defineProperty(exports, "GraphInformant", { enumerable: true, get: function () { return informant_1.GraphInformant; } });
|
|
12
|
-
Object.defineProperty(exports, "TrainingInformant", { enumerable: true, get: function () { return informant_1.TrainingInformant; } });
|
|
13
|
-
Object.defineProperty(exports, "informant", { enumerable: true, get: function () { return informant_1.informant; } });
|
|
14
|
-
var client_1 = require("./client");
|
|
15
|
-
Object.defineProperty(exports, "Client", { enumerable: true, get: function () { return client_1.Base; } });
|
|
16
|
-
exports.client = (0, tslib_1.__importStar)(require("./client"));
|
|
17
|
-
var weights_1 = require("./weights");
|
|
18
|
-
Object.defineProperty(exports, "WeightsContainer", { enumerable: true, get: function () { return weights_1.WeightsContainer; } });
|
|
19
|
-
Object.defineProperty(exports, "aggregation", { enumerable: true, get: function () { return weights_1.aggregation; } });
|
|
20
|
-
var async_buffer_1 = require("./async_buffer");
|
|
21
|
-
Object.defineProperty(exports, "AsyncBuffer", { enumerable: true, get: function () { return async_buffer_1.AsyncBuffer; } });
|
|
22
|
-
var async_informant_1 = require("./async_informant");
|
|
23
|
-
Object.defineProperty(exports, "AsyncInformant", { enumerable: true, get: function () { return async_informant_1.AsyncInformant; } });
|
|
24
|
-
var logging_1 = require("./logging");
|
|
25
|
-
Object.defineProperty(exports, "Logger", { enumerable: true, get: function () { return logging_1.Logger; } });
|
|
26
|
-
Object.defineProperty(exports, "ConsoleLogger", { enumerable: true, get: function () { return logging_1.ConsoleLogger; } });
|
|
27
|
-
Object.defineProperty(exports, "TrainerLog", { enumerable: true, get: function () { return logging_1.TrainerLog; } });
|
|
28
|
-
var memory_1 = require("./memory");
|
|
29
|
-
Object.defineProperty(exports, "Memory", { enumerable: true, get: function () { return memory_1.Memory; } });
|
|
30
|
-
Object.defineProperty(exports, "ModelType", { enumerable: true, get: function () { return memory_1.ModelType; } });
|
|
31
|
-
Object.defineProperty(exports, "EmptyMemory", { enumerable: true, get: function () { return memory_1.Empty; } });
|
|
32
|
-
var training_1 = require("./training");
|
|
33
|
-
Object.defineProperty(exports, "Disco", { enumerable: true, get: function () { return training_1.Disco; } });
|
|
34
|
-
Object.defineProperty(exports, "TrainingSchemes", { enumerable: true, get: function () { return training_1.TrainingSchemes; } });
|
|
35
|
-
var validation_1 = require("./validation");
|
|
36
|
-
Object.defineProperty(exports, "Validator", { enumerable: true, get: function () { return validation_1.Validator; } });
|
|
37
|
-
(0, tslib_1.__exportStar)(require("./task"), exports);
|
|
38
|
-
exports.defaultTasks = (0, tslib_1.__importStar)(require("./default_tasks"));
|
|
39
|
-
(0, tslib_1.__exportStar)(require("./types"), exports);
|
|
@@ -1,23 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.GraphInformant = void 0;
|
|
4
|
-
var immutable_1 = require("immutable");
|
|
5
|
-
var GraphInformant = /** @class */ (function () {
|
|
6
|
-
function GraphInformant() {
|
|
7
|
-
this.currentAccuracy = 0;
|
|
8
|
-
this.accuracyDataSeries = (0, immutable_1.Repeat)(0, GraphInformant.NB_EPOCHS_ON_GRAPH).toList();
|
|
9
|
-
}
|
|
10
|
-
GraphInformant.prototype.updateAccuracy = function (accuracy) {
|
|
11
|
-
this.accuracyDataSeries = this.accuracyDataSeries.shift().push(accuracy);
|
|
12
|
-
this.currentAccuracy = accuracy;
|
|
13
|
-
};
|
|
14
|
-
GraphInformant.prototype.data = function () {
|
|
15
|
-
return this.accuracyDataSeries;
|
|
16
|
-
};
|
|
17
|
-
GraphInformant.prototype.accuracy = function () {
|
|
18
|
-
return this.currentAccuracy;
|
|
19
|
-
};
|
|
20
|
-
GraphInformant.NB_EPOCHS_ON_GRAPH = 10;
|
|
21
|
-
return GraphInformant;
|
|
22
|
-
}());
|
|
23
|
-
exports.GraphInformant = GraphInformant;
|
|
@@ -1,9 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.informant = exports.TrainingInformant = exports.GraphInformant = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
var graph_informant_1 = require("./graph_informant");
|
|
6
|
-
Object.defineProperty(exports, "GraphInformant", { enumerable: true, get: function () { return graph_informant_1.GraphInformant; } });
|
|
7
|
-
var training_informant_1 = require("./training_informant");
|
|
8
|
-
Object.defineProperty(exports, "TrainingInformant", { enumerable: true, get: function () { return training_informant_1.Base; } });
|
|
9
|
-
exports.informant = (0, tslib_1.__importStar)(require("./training_informant"));
|
|
@@ -1,31 +0,0 @@
|
|
|
1
|
-
import { List } from 'immutable';
|
|
2
|
-
import { Task } from '../../task';
|
|
3
|
-
import { GraphInformant } from '../graph_informant';
|
|
4
|
-
export declare abstract class Base {
|
|
5
|
-
readonly task: Task;
|
|
6
|
-
private readonly nbrMessagesToShow;
|
|
7
|
-
private messages;
|
|
8
|
-
protected readonly trainingGraphInformant: GraphInformant;
|
|
9
|
-
protected readonly validationGraphInformant: GraphInformant;
|
|
10
|
-
protected currentRound: number;
|
|
11
|
-
protected currentNumberOfParticipants: number;
|
|
12
|
-
protected totalNumberOfParticipants: number;
|
|
13
|
-
protected averageNumberOfParticipants: number;
|
|
14
|
-
constructor(task: Task, nbrMessagesToShow?: number);
|
|
15
|
-
abstract update(statistics: Record<string, number>): void;
|
|
16
|
-
addMessage(msg: string): void;
|
|
17
|
-
getMessages(): string[];
|
|
18
|
-
round(): number;
|
|
19
|
-
participants(): number;
|
|
20
|
-
totalParticipants(): number;
|
|
21
|
-
averageParticipants(): number;
|
|
22
|
-
updateTrainingGraph(accuracy: number): void;
|
|
23
|
-
updateValidationGraph(accuracy: number): void;
|
|
24
|
-
trainingAccuracy(): number;
|
|
25
|
-
validationAccuracy(): number;
|
|
26
|
-
trainingAccuracyData(): List<number>;
|
|
27
|
-
validationAccuracyData(): List<number>;
|
|
28
|
-
isDecentralized(): boolean;
|
|
29
|
-
isFederated(): boolean;
|
|
30
|
-
static isTrainingInformant(raw: unknown): raw is Base;
|
|
31
|
-
}
|
|
@@ -1,83 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.Base = void 0;
|
|
4
|
-
var immutable_1 = require("immutable");
|
|
5
|
-
var graph_informant_1 = require("../graph_informant");
|
|
6
|
-
var Base = /** @class */ (function () {
|
|
7
|
-
function Base(task, nbrMessagesToShow) {
|
|
8
|
-
if (nbrMessagesToShow === void 0) { nbrMessagesToShow = 10; }
|
|
9
|
-
this.task = task;
|
|
10
|
-
this.nbrMessagesToShow = nbrMessagesToShow;
|
|
11
|
-
// written feedback
|
|
12
|
-
this.messages = (0, immutable_1.List)();
|
|
13
|
-
// graph feedback
|
|
14
|
-
this.trainingGraphInformant = new graph_informant_1.GraphInformant();
|
|
15
|
-
this.validationGraphInformant = new graph_informant_1.GraphInformant();
|
|
16
|
-
// statistics
|
|
17
|
-
this.currentRound = 0;
|
|
18
|
-
this.currentNumberOfParticipants = 0;
|
|
19
|
-
this.totalNumberOfParticipants = 0;
|
|
20
|
-
this.averageNumberOfParticipants = 0;
|
|
21
|
-
}
|
|
22
|
-
Base.prototype.addMessage = function (msg) {
|
|
23
|
-
if (this.messages.size >= this.nbrMessagesToShow) {
|
|
24
|
-
this.messages = this.messages.shift();
|
|
25
|
-
}
|
|
26
|
-
this.messages = this.messages.push(msg);
|
|
27
|
-
};
|
|
28
|
-
Base.prototype.getMessages = function () {
|
|
29
|
-
return this.messages.toArray();
|
|
30
|
-
};
|
|
31
|
-
Base.prototype.round = function () {
|
|
32
|
-
return this.currentRound;
|
|
33
|
-
};
|
|
34
|
-
Base.prototype.participants = function () {
|
|
35
|
-
return this.currentNumberOfParticipants;
|
|
36
|
-
};
|
|
37
|
-
Base.prototype.totalParticipants = function () {
|
|
38
|
-
return this.totalNumberOfParticipants;
|
|
39
|
-
};
|
|
40
|
-
Base.prototype.averageParticipants = function () {
|
|
41
|
-
return this.averageNumberOfParticipants;
|
|
42
|
-
};
|
|
43
|
-
Base.prototype.updateTrainingGraph = function (accuracy) {
|
|
44
|
-
this.trainingGraphInformant.updateAccuracy(accuracy);
|
|
45
|
-
};
|
|
46
|
-
Base.prototype.updateValidationGraph = function (accuracy) {
|
|
47
|
-
this.validationGraphInformant.updateAccuracy(accuracy);
|
|
48
|
-
};
|
|
49
|
-
Base.prototype.trainingAccuracy = function () {
|
|
50
|
-
return this.trainingGraphInformant.accuracy();
|
|
51
|
-
};
|
|
52
|
-
Base.prototype.validationAccuracy = function () {
|
|
53
|
-
return this.validationGraphInformant.accuracy();
|
|
54
|
-
};
|
|
55
|
-
Base.prototype.trainingAccuracyData = function () {
|
|
56
|
-
return this.trainingGraphInformant.data();
|
|
57
|
-
};
|
|
58
|
-
Base.prototype.validationAccuracyData = function () {
|
|
59
|
-
return this.validationGraphInformant.data();
|
|
60
|
-
};
|
|
61
|
-
Base.prototype.isDecentralized = function () {
|
|
62
|
-
return false;
|
|
63
|
-
};
|
|
64
|
-
Base.prototype.isFederated = function () {
|
|
65
|
-
return false;
|
|
66
|
-
};
|
|
67
|
-
Base.isTrainingInformant = function (raw) {
|
|
68
|
-
if (typeof raw !== 'object') {
|
|
69
|
-
return false;
|
|
70
|
-
}
|
|
71
|
-
if (raw === null) {
|
|
72
|
-
return false;
|
|
73
|
-
}
|
|
74
|
-
// TODO
|
|
75
|
-
var requiredFields = (0, immutable_1.Set)();
|
|
76
|
-
if (!(requiredFields.every(function (field) { return field in raw; }))) {
|
|
77
|
-
return false;
|
|
78
|
-
}
|
|
79
|
-
return true;
|
|
80
|
-
};
|
|
81
|
-
return Base;
|
|
82
|
-
}());
|
|
83
|
-
exports.Base = Base;
|
|
@@ -1,22 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.DecentralizedInformant = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
var _1 = require(".");
|
|
6
|
-
var DecentralizedInformant = /** @class */ (function (_super) {
|
|
7
|
-
(0, tslib_1.__extends)(DecentralizedInformant, _super);
|
|
8
|
-
function DecentralizedInformant() {
|
|
9
|
-
return _super !== null && _super.apply(this, arguments) || this;
|
|
10
|
-
}
|
|
11
|
-
DecentralizedInformant.prototype.update = function (statistics) {
|
|
12
|
-
this.currentRound += 1;
|
|
13
|
-
this.currentNumberOfParticipants = statistics.currentNumberOfParticipants;
|
|
14
|
-
this.totalNumberOfParticipants += this.currentNumberOfParticipants;
|
|
15
|
-
this.averageNumberOfParticipants = this.totalNumberOfParticipants / this.currentRound;
|
|
16
|
-
};
|
|
17
|
-
DecentralizedInformant.prototype.isDecentralized = function () {
|
|
18
|
-
return true;
|
|
19
|
-
};
|
|
20
|
-
return DecentralizedInformant;
|
|
21
|
-
}(_1.Base));
|
|
22
|
-
exports.DecentralizedInformant = DecentralizedInformant;
|
|
@@ -1,14 +0,0 @@
|
|
|
1
|
-
import { Base } from '.';
|
|
2
|
-
/**
|
|
3
|
-
* Class that collects information about the status of the training-loop of the model.
|
|
4
|
-
*/
|
|
5
|
-
export declare class FederatedInformant extends Base {
|
|
6
|
-
displayHeatmap: boolean;
|
|
7
|
-
/**
|
|
8
|
-
* Update the server statistics with the JSON received from the server
|
|
9
|
-
* For now it's just the JSON, but we might want to keep it as a dictionary
|
|
10
|
-
* @param receivedStatistics statistics received from the server.
|
|
11
|
-
*/
|
|
12
|
-
update(receivedStatistics: Record<string, number>): void;
|
|
13
|
-
isFederated(): boolean;
|
|
14
|
-
}
|
|
@@ -1,32 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.FederatedInformant = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
var _1 = require(".");
|
|
6
|
-
/**
|
|
7
|
-
* Class that collects information about the status of the training-loop of the model.
|
|
8
|
-
*/
|
|
9
|
-
var FederatedInformant = /** @class */ (function (_super) {
|
|
10
|
-
(0, tslib_1.__extends)(FederatedInformant, _super);
|
|
11
|
-
function FederatedInformant() {
|
|
12
|
-
var _this = _super !== null && _super.apply(this, arguments) || this;
|
|
13
|
-
_this.displayHeatmap = false;
|
|
14
|
-
return _this;
|
|
15
|
-
}
|
|
16
|
-
/**
|
|
17
|
-
* Update the server statistics with the JSON received from the server
|
|
18
|
-
* For now it's just the JSON, but we might want to keep it as a dictionary
|
|
19
|
-
* @param receivedStatistics statistics received from the server.
|
|
20
|
-
*/
|
|
21
|
-
FederatedInformant.prototype.update = function (receivedStatistics) {
|
|
22
|
-
this.currentRound = receivedStatistics.round;
|
|
23
|
-
this.currentNumberOfParticipants = receivedStatistics.currentNumberOfParticipants;
|
|
24
|
-
this.totalNumberOfParticipants = receivedStatistics.totalNumberOfParticipants;
|
|
25
|
-
this.averageNumberOfParticipants = receivedStatistics.averageNumberOfParticipants;
|
|
26
|
-
};
|
|
27
|
-
FederatedInformant.prototype.isFederated = function () {
|
|
28
|
-
return true;
|
|
29
|
-
};
|
|
30
|
-
return FederatedInformant;
|
|
31
|
-
}(_1.Base));
|
|
32
|
-
exports.FederatedInformant = FederatedInformant;
|
|
@@ -1,11 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.LocalInformant = exports.DecentralizedInformant = exports.FederatedInformant = exports.Base = void 0;
|
|
4
|
-
var base_1 = require("./base");
|
|
5
|
-
Object.defineProperty(exports, "Base", { enumerable: true, get: function () { return base_1.Base; } });
|
|
6
|
-
var federated_1 = require("./federated");
|
|
7
|
-
Object.defineProperty(exports, "FederatedInformant", { enumerable: true, get: function () { return federated_1.FederatedInformant; } });
|
|
8
|
-
var decentralized_1 = require("./decentralized");
|
|
9
|
-
Object.defineProperty(exports, "DecentralizedInformant", { enumerable: true, get: function () { return decentralized_1.DecentralizedInformant; } });
|
|
10
|
-
var local_1 = require("./local");
|
|
11
|
-
Object.defineProperty(exports, "LocalInformant", { enumerable: true, get: function () { return local_1.LocalInformant; } });
|
|
@@ -1,20 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.LocalInformant = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
var _1 = require(".");
|
|
6
|
-
var LocalInformant = /** @class */ (function (_super) {
|
|
7
|
-
(0, tslib_1.__extends)(LocalInformant, _super);
|
|
8
|
-
function LocalInformant(task, nbrMessagesToShow) {
|
|
9
|
-
var _this = _super.call(this, task, nbrMessagesToShow) || this;
|
|
10
|
-
_this.currentNumberOfParticipants = 1;
|
|
11
|
-
_this.averageNumberOfParticipants = 1;
|
|
12
|
-
_this.totalNumberOfParticipants = 1;
|
|
13
|
-
return _this;
|
|
14
|
-
}
|
|
15
|
-
LocalInformant.prototype.update = function (statistics) {
|
|
16
|
-
this.currentRound = statistics.currentRound;
|
|
17
|
-
};
|
|
18
|
-
return LocalInformant;
|
|
19
|
-
}(_1.Base));
|
|
20
|
-
exports.LocalInformant = LocalInformant;
|
|
@@ -1,33 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.ConsoleLogger = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
var chalk_1 = (0, tslib_1.__importDefault)(require("chalk"));
|
|
6
|
-
var logger_1 = require("./logger");
|
|
7
|
-
/**
|
|
8
|
-
* Same properties as Toaster but on the console
|
|
9
|
-
*
|
|
10
|
-
* @class Logger
|
|
11
|
-
*/
|
|
12
|
-
var ConsoleLogger = /** @class */ (function (_super) {
|
|
13
|
-
(0, tslib_1.__extends)(ConsoleLogger, _super);
|
|
14
|
-
function ConsoleLogger() {
|
|
15
|
-
return _super !== null && _super.apply(this, arguments) || this;
|
|
16
|
-
}
|
|
17
|
-
/**
|
|
18
|
-
* Logs success message on the console (in green)
|
|
19
|
-
* @param {String} message - message to be displayed
|
|
20
|
-
*/
|
|
21
|
-
ConsoleLogger.prototype.success = function (message) {
|
|
22
|
-
console.log(chalk_1.default.green(message));
|
|
23
|
-
};
|
|
24
|
-
/**
|
|
25
|
-
* Logs error message on the console (in red)
|
|
26
|
-
* @param message - message to be displayed
|
|
27
|
-
*/
|
|
28
|
-
ConsoleLogger.prototype.error = function (message) {
|
|
29
|
-
console.log(chalk_1.default.red(message));
|
|
30
|
-
};
|
|
31
|
-
return ConsoleLogger;
|
|
32
|
-
}(logger_1.Logger));
|
|
33
|
-
exports.ConsoleLogger = ConsoleLogger;
|
|
@@ -1,9 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.TrainerLog = exports.ConsoleLogger = exports.Logger = void 0;
|
|
4
|
-
var logger_1 = require("./logger");
|
|
5
|
-
Object.defineProperty(exports, "Logger", { enumerable: true, get: function () { return logger_1.Logger; } });
|
|
6
|
-
var console_logger_1 = require("./console_logger");
|
|
7
|
-
Object.defineProperty(exports, "ConsoleLogger", { enumerable: true, get: function () { return console_logger_1.ConsoleLogger; } });
|
|
8
|
-
var trainer_logger_1 = require("./trainer_logger");
|
|
9
|
-
Object.defineProperty(exports, "TrainerLog", { enumerable: true, get: function () { return trainer_logger_1.TrainerLog; } });
|
|
@@ -1,24 +0,0 @@
|
|
|
1
|
-
import { List } from 'immutable';
|
|
2
|
-
import { tf } from '..';
|
|
3
|
-
import { ConsoleLogger } from '.';
|
|
4
|
-
export declare class TrainerLog {
|
|
5
|
-
epochs: List<number>;
|
|
6
|
-
trainAccuracy: List<number>;
|
|
7
|
-
validationAccuracy: List<number>;
|
|
8
|
-
loss: List<number>;
|
|
9
|
-
add(epoch: number, logs?: tf.Logs): void;
|
|
10
|
-
}
|
|
11
|
-
/**
|
|
12
|
-
*
|
|
13
|
-
* @class TrainerLogger
|
|
14
|
-
*/
|
|
15
|
-
export declare class TrainerLogger extends ConsoleLogger {
|
|
16
|
-
readonly log: TrainerLog;
|
|
17
|
-
readonly saveTrainerLog: boolean;
|
|
18
|
-
constructor(saveTrainerLog?: boolean);
|
|
19
|
-
onEpochEnd(epoch: number, logs?: tf.Logs): void;
|
|
20
|
-
/**
|
|
21
|
-
* Display ram usage
|
|
22
|
-
*/
|
|
23
|
-
ramUsage(): void;
|
|
24
|
-
}
|
|
@@ -1,59 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.TrainerLogger = exports.TrainerLog = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
var immutable_1 = require("immutable");
|
|
6
|
-
var __1 = require("..");
|
|
7
|
-
var _1 = require(".");
|
|
8
|
-
var TrainerLog = /** @class */ (function () {
|
|
9
|
-
function TrainerLog() {
|
|
10
|
-
this.epochs = (0, immutable_1.List)();
|
|
11
|
-
this.trainAccuracy = (0, immutable_1.List)();
|
|
12
|
-
this.validationAccuracy = (0, immutable_1.List)();
|
|
13
|
-
this.loss = (0, immutable_1.List)();
|
|
14
|
-
}
|
|
15
|
-
TrainerLog.prototype.add = function (epoch, logs) {
|
|
16
|
-
this.epochs = this.epochs.push(epoch);
|
|
17
|
-
if (logs !== undefined) {
|
|
18
|
-
this.trainAccuracy = this.trainAccuracy.push(logs.acc);
|
|
19
|
-
this.validationAccuracy = this.validationAccuracy.push(logs.val_acc);
|
|
20
|
-
this.loss = this.loss.push(logs.loss);
|
|
21
|
-
}
|
|
22
|
-
};
|
|
23
|
-
return TrainerLog;
|
|
24
|
-
}());
|
|
25
|
-
exports.TrainerLog = TrainerLog;
|
|
26
|
-
/**
|
|
27
|
-
*
|
|
28
|
-
* @class TrainerLogger
|
|
29
|
-
*/
|
|
30
|
-
var TrainerLogger = /** @class */ (function (_super) {
|
|
31
|
-
(0, tslib_1.__extends)(TrainerLogger, _super);
|
|
32
|
-
// TODO: pass savaTrainerLog as false in browser, used for benchmarking
|
|
33
|
-
function TrainerLogger(saveTrainerLog) {
|
|
34
|
-
if (saveTrainerLog === void 0) { saveTrainerLog = true; }
|
|
35
|
-
var _this = _super.call(this) || this;
|
|
36
|
-
_this.saveTrainerLog = saveTrainerLog;
|
|
37
|
-
_this.log = new TrainerLog();
|
|
38
|
-
return _this;
|
|
39
|
-
}
|
|
40
|
-
TrainerLogger.prototype.onEpochEnd = function (epoch, logs) {
|
|
41
|
-
var _a, _b, _c;
|
|
42
|
-
// save logs
|
|
43
|
-
if (this.saveTrainerLog) {
|
|
44
|
-
this.log.add(epoch, logs);
|
|
45
|
-
}
|
|
46
|
-
// console output
|
|
47
|
-
var msg = "Epoch: " + epoch + "\nTrain: " + ((_a = logs === null || logs === void 0 ? void 0 : logs.acc) !== null && _a !== void 0 ? _a : 'undefined') + "\nValidation:" + ((_b = logs === null || logs === void 0 ? void 0 : logs.val_acc) !== null && _b !== void 0 ? _b : 'undefined') + "\nLoss:" + ((_c = logs === null || logs === void 0 ? void 0 : logs.loss) !== null && _c !== void 0 ? _c : 'undefined');
|
|
48
|
-
this.success("On epoch end:\n" + msg + "\n");
|
|
49
|
-
};
|
|
50
|
-
/**
|
|
51
|
-
* Display ram usage
|
|
52
|
-
*/
|
|
53
|
-
TrainerLogger.prototype.ramUsage = function () {
|
|
54
|
-
this.success("Training RAM usage is = " + __1.tf.memory().numBytes * 0.000001 + " MB");
|
|
55
|
-
this.success("Number of allocated tensors = " + __1.tf.memory().numTensors);
|
|
56
|
-
};
|
|
57
|
-
return TrainerLogger;
|
|
58
|
-
}(_1.ConsoleLogger));
|
|
59
|
-
exports.TrainerLogger = TrainerLogger;
|
|
@@ -1,22 +0,0 @@
|
|
|
1
|
-
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
import { TaskID } from '..';
|
|
3
|
-
import { ModelType } from './model_type';
|
|
4
|
-
export declare type Path = string;
|
|
5
|
-
export interface ModelInfo {
|
|
6
|
-
type?: ModelType;
|
|
7
|
-
taskID: TaskID;
|
|
8
|
-
name: string;
|
|
9
|
-
}
|
|
10
|
-
export declare type ModelSource = ModelInfo | Path;
|
|
11
|
-
export declare abstract class Memory {
|
|
12
|
-
abstract getModel(source: ModelSource): Promise<tf.LayersModel>;
|
|
13
|
-
abstract deleteModel(source: ModelSource): Promise<void>;
|
|
14
|
-
abstract loadModel(source: ModelSource): Promise<void>;
|
|
15
|
-
abstract getModelMetadata(source: ModelSource): Promise<object | undefined>;
|
|
16
|
-
abstract updateWorkingModel(source: ModelSource, model: tf.LayersModel): Promise<void>;
|
|
17
|
-
abstract saveWorkingModel(source: ModelSource): Promise<void>;
|
|
18
|
-
abstract downloadModel(source: ModelSource): Promise<void>;
|
|
19
|
-
abstract contains(source: ModelSource): Promise<boolean>;
|
|
20
|
-
abstract pathFor(source: ModelSource): Path | undefined;
|
|
21
|
-
abstract infoFor(source: ModelSource): ModelInfo | undefined;
|
|
22
|
-
}
|
package/dist/core/memory/base.js
DELETED