@epfml/discojs 2.0.0 → 2.1.2-p20240506085037.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.d.ts +180 -0
- package/dist/aggregator/base.js +236 -0
- package/dist/aggregator/get.d.ts +16 -0
- package/dist/aggregator/get.js +31 -0
- package/dist/aggregator/index.d.ts +7 -0
- package/dist/aggregator/index.js +4 -0
- package/dist/aggregator/mean.d.ts +23 -0
- package/dist/aggregator/mean.js +69 -0
- package/dist/aggregator/secure.d.ts +27 -0
- package/dist/aggregator/secure.js +91 -0
- package/dist/async_informant.d.ts +15 -0
- package/dist/async_informant.js +42 -0
- package/dist/client/base.d.ts +76 -0
- package/dist/client/base.js +88 -0
- package/dist/client/decentralized/base.d.ts +32 -0
- package/dist/client/decentralized/base.js +192 -0
- package/dist/client/decentralized/index.d.ts +2 -0
- package/dist/client/decentralized/index.js +2 -0
- package/dist/client/decentralized/messages.d.ts +28 -0
- package/dist/client/decentralized/messages.js +44 -0
- package/dist/client/decentralized/peer.d.ts +40 -0
- package/dist/client/decentralized/peer.js +189 -0
- package/dist/client/decentralized/peer_pool.d.ts +12 -0
- package/dist/client/decentralized/peer_pool.js +44 -0
- package/dist/client/event_connection.d.ts +34 -0
- package/dist/client/event_connection.js +105 -0
- package/dist/client/federated/base.d.ts +54 -0
- package/dist/client/federated/base.js +151 -0
- package/dist/client/federated/index.d.ts +2 -0
- package/dist/client/federated/index.js +2 -0
- package/dist/client/federated/messages.d.ts +30 -0
- package/dist/client/federated/messages.js +24 -0
- package/dist/client/index.d.ts +8 -0
- package/dist/client/index.js +8 -0
- package/dist/client/local.d.ts +3 -0
- package/dist/client/local.js +3 -0
- package/dist/client/messages.d.ts +30 -0
- package/dist/client/messages.js +26 -0
- package/dist/client/types.d.ts +2 -0
- package/dist/client/types.js +4 -0
- package/dist/client/utils.d.ts +2 -0
- package/dist/client/utils.js +7 -0
- package/dist/dataset/data/data.d.ts +48 -0
- package/dist/dataset/data/data.js +72 -0
- package/dist/dataset/data/data_split.d.ts +8 -0
- package/dist/dataset/data/data_split.js +1 -0
- package/dist/dataset/data/image_data.d.ts +11 -0
- package/dist/dataset/data/image_data.js +38 -0
- package/dist/dataset/data/index.d.ts +6 -0
- package/dist/dataset/data/index.js +5 -0
- package/dist/dataset/data/preprocessing/base.d.ts +16 -0
- package/dist/dataset/data/preprocessing/base.js +1 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.js +40 -0
- package/dist/dataset/data/preprocessing/index.d.ts +4 -0
- package/dist/dataset/data/preprocessing/index.js +3 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.js +45 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.js +85 -0
- package/dist/dataset/data/tabular_data.d.ts +11 -0
- package/dist/dataset/data/tabular_data.js +25 -0
- package/dist/dataset/data/text_data.d.ts +11 -0
- package/dist/dataset/data/text_data.js +14 -0
- package/dist/{core/dataset → dataset}/data_loader/data_loader.d.ts +3 -5
- package/dist/dataset/data_loader/data_loader.js +2 -0
- package/dist/dataset/data_loader/image_loader.d.ts +20 -3
- package/dist/dataset/data_loader/image_loader.js +98 -23
- package/dist/dataset/data_loader/index.d.ts +5 -2
- package/dist/dataset/data_loader/index.js +4 -7
- package/dist/dataset/data_loader/tabular_loader.d.ts +34 -3
- package/dist/dataset/data_loader/tabular_loader.js +75 -15
- package/dist/dataset/data_loader/text_loader.d.ts +14 -0
- package/dist/dataset/data_loader/text_loader.js +25 -0
- package/dist/dataset/dataset.d.ts +5 -0
- package/dist/dataset/dataset.js +1 -0
- package/dist/dataset/dataset_builder.d.ts +60 -0
- package/dist/dataset/dataset_builder.js +142 -0
- package/dist/dataset/index.d.ts +5 -0
- package/dist/dataset/index.js +3 -0
- package/dist/default_tasks/cifar10/index.d.ts +2 -0
- package/dist/default_tasks/cifar10/index.js +60 -0
- package/dist/default_tasks/cifar10/model.d.ts +434 -0
- package/dist/default_tasks/cifar10/model.js +2385 -0
- package/dist/default_tasks/geotags/index.d.ts +2 -0
- package/dist/default_tasks/geotags/index.js +65 -0
- package/dist/default_tasks/geotags/model.d.ts +593 -0
- package/dist/default_tasks/geotags/model.js +4715 -0
- package/dist/default_tasks/index.d.ts +8 -0
- package/dist/default_tasks/index.js +8 -0
- package/dist/default_tasks/lus_covid.d.ts +2 -0
- package/dist/default_tasks/lus_covid.js +89 -0
- package/dist/default_tasks/mnist.d.ts +2 -0
- package/dist/default_tasks/mnist.js +61 -0
- package/dist/default_tasks/simple_face/index.d.ts +2 -0
- package/dist/default_tasks/simple_face/index.js +48 -0
- package/dist/default_tasks/simple_face/model.d.ts +513 -0
- package/dist/default_tasks/simple_face/model.js +4301 -0
- package/dist/default_tasks/skin_mnist.d.ts +2 -0
- package/dist/default_tasks/skin_mnist.js +80 -0
- package/dist/default_tasks/titanic.d.ts +2 -0
- package/dist/default_tasks/titanic.js +88 -0
- package/dist/default_tasks/wikitext.d.ts +2 -0
- package/dist/default_tasks/wikitext.js +38 -0
- package/dist/index.d.ts +18 -2
- package/dist/index.js +18 -6
- package/dist/{core/informant → informant}/graph_informant.d.ts +1 -1
- package/dist/informant/graph_informant.js +20 -0
- package/dist/informant/index.d.ts +1 -0
- package/dist/informant/index.js +1 -0
- package/dist/{core/logging → logging}/console_logger.d.ts +2 -2
- package/dist/logging/console_logger.js +22 -0
- package/dist/logging/index.d.ts +2 -0
- package/dist/logging/index.js +1 -0
- package/dist/{core/logging → logging}/logger.d.ts +3 -3
- package/dist/logging/logger.js +1 -0
- package/dist/memory/base.d.ts +119 -0
- package/dist/memory/base.js +9 -0
- package/dist/memory/empty.d.ts +20 -0
- package/dist/memory/empty.js +43 -0
- package/dist/memory/index.d.ts +3 -1
- package/dist/memory/index.js +3 -5
- package/dist/memory/model_type.d.ts +9 -0
- package/dist/memory/model_type.js +10 -0
- package/dist/{core/privacy.d.ts → privacy.d.ts} +1 -1
- package/dist/{core/privacy.js → privacy.js} +11 -16
- package/dist/serialization/index.d.ts +2 -0
- package/dist/serialization/index.js +2 -0
- package/dist/serialization/model.d.ts +5 -0
- package/dist/serialization/model.js +67 -0
- package/dist/{core/serialization → serialization}/weights.d.ts +2 -2
- package/dist/serialization/weights.js +37 -0
- package/dist/task/data_example.js +14 -0
- package/dist/task/digest.d.ts +5 -0
- package/dist/task/digest.js +14 -0
- package/dist/{core/task → task}/display_information.d.ts +5 -3
- package/dist/task/display_information.js +46 -0
- package/dist/task/index.d.ts +7 -0
- package/dist/task/index.js +5 -0
- package/dist/task/label_type.d.ts +9 -0
- package/dist/task/label_type.js +28 -0
- package/dist/task/summary.js +13 -0
- package/dist/task/task.d.ts +12 -0
- package/dist/task/task.js +22 -0
- package/dist/task/task_handler.d.ts +5 -0
- package/dist/task/task_handler.js +20 -0
- package/dist/task/task_provider.d.ts +5 -0
- package/dist/task/task_provider.js +1 -0
- package/dist/{core/task → task}/training_information.d.ts +9 -10
- package/dist/task/training_information.js +88 -0
- package/dist/training/disco.d.ts +40 -0
- package/dist/training/disco.js +107 -0
- package/dist/training/index.d.ts +2 -0
- package/dist/training/index.js +1 -0
- package/dist/training/trainer/distributed_trainer.d.ts +20 -0
- package/dist/training/trainer/distributed_trainer.js +36 -0
- package/dist/training/trainer/local_trainer.d.ts +12 -0
- package/dist/training/trainer/local_trainer.js +19 -0
- package/dist/training/trainer/trainer.d.ts +33 -0
- package/dist/training/trainer/trainer.js +52 -0
- package/dist/{core/training → training}/trainer/trainer_builder.d.ts +5 -7
- package/dist/training/trainer/trainer_builder.js +43 -0
- package/dist/types.d.ts +8 -0
- package/dist/types.js +1 -0
- package/dist/utils/event_emitter.d.ts +40 -0
- package/dist/utils/event_emitter.js +57 -0
- package/dist/validation/index.d.ts +1 -0
- package/dist/validation/index.js +1 -0
- package/dist/validation/validator.d.ts +28 -0
- package/dist/validation/validator.js +132 -0
- package/dist/weights/aggregation.d.ts +21 -0
- package/dist/weights/aggregation.js +44 -0
- package/dist/weights/index.d.ts +2 -0
- package/dist/weights/index.js +2 -0
- package/dist/weights/weights_container.d.ts +68 -0
- package/dist/weights/weights_container.js +96 -0
- package/package.json +25 -16
- package/README.md +0 -53
- package/dist/core/async_buffer.d.ts +0 -41
- package/dist/core/async_buffer.js +0 -97
- package/dist/core/async_informant.d.ts +0 -20
- package/dist/core/async_informant.js +0 -69
- package/dist/core/client/base.d.ts +0 -33
- package/dist/core/client/base.js +0 -35
- package/dist/core/client/decentralized/base.d.ts +0 -32
- package/dist/core/client/decentralized/base.js +0 -212
- package/dist/core/client/decentralized/clear_text.d.ts +0 -14
- package/dist/core/client/decentralized/clear_text.js +0 -96
- package/dist/core/client/decentralized/index.d.ts +0 -4
- package/dist/core/client/decentralized/index.js +0 -9
- package/dist/core/client/decentralized/messages.d.ts +0 -41
- package/dist/core/client/decentralized/messages.js +0 -54
- package/dist/core/client/decentralized/peer.d.ts +0 -26
- package/dist/core/client/decentralized/peer.js +0 -210
- package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
- package/dist/core/client/decentralized/peer_pool.js +0 -92
- package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
- package/dist/core/client/decentralized/sec_agg.js +0 -190
- package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
- package/dist/core/client/decentralized/secret_shares.js +0 -39
- package/dist/core/client/decentralized/types.d.ts +0 -2
- package/dist/core/client/decentralized/types.js +0 -7
- package/dist/core/client/event_connection.d.ts +0 -37
- package/dist/core/client/event_connection.js +0 -158
- package/dist/core/client/federated/client.d.ts +0 -37
- package/dist/core/client/federated/client.js +0 -273
- package/dist/core/client/federated/index.d.ts +0 -2
- package/dist/core/client/federated/index.js +0 -7
- package/dist/core/client/federated/messages.d.ts +0 -38
- package/dist/core/client/federated/messages.js +0 -25
- package/dist/core/client/index.d.ts +0 -5
- package/dist/core/client/index.js +0 -11
- package/dist/core/client/local.d.ts +0 -8
- package/dist/core/client/local.js +0 -36
- package/dist/core/client/messages.d.ts +0 -28
- package/dist/core/client/messages.js +0 -33
- package/dist/core/client/utils.d.ts +0 -2
- package/dist/core/client/utils.js +0 -19
- package/dist/core/dataset/data/data.d.ts +0 -11
- package/dist/core/dataset/data/data.js +0 -20
- package/dist/core/dataset/data/data_split.d.ts +0 -5
- package/dist/core/dataset/data/data_split.js +0 -2
- package/dist/core/dataset/data/image_data.d.ts +0 -8
- package/dist/core/dataset/data/image_data.js +0 -64
- package/dist/core/dataset/data/index.d.ts +0 -5
- package/dist/core/dataset/data/index.js +0 -11
- package/dist/core/dataset/data/preprocessing.d.ts +0 -13
- package/dist/core/dataset/data/preprocessing.js +0 -33
- package/dist/core/dataset/data/tabular_data.d.ts +0 -8
- package/dist/core/dataset/data/tabular_data.js +0 -40
- package/dist/core/dataset/data_loader/data_loader.js +0 -10
- package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
- package/dist/core/dataset/data_loader/image_loader.js +0 -141
- package/dist/core/dataset/data_loader/index.d.ts +0 -3
- package/dist/core/dataset/data_loader/index.js +0 -9
- package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
- package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
- package/dist/core/dataset/dataset.d.ts +0 -2
- package/dist/core/dataset/dataset.js +0 -2
- package/dist/core/dataset/dataset_builder.d.ts +0 -18
- package/dist/core/dataset/dataset_builder.js +0 -96
- package/dist/core/dataset/index.d.ts +0 -4
- package/dist/core/dataset/index.js +0 -14
- package/dist/core/index.d.ts +0 -18
- package/dist/core/index.js +0 -41
- package/dist/core/informant/graph_informant.js +0 -23
- package/dist/core/informant/index.d.ts +0 -3
- package/dist/core/informant/index.js +0 -9
- package/dist/core/informant/training_informant/base.d.ts +0 -31
- package/dist/core/informant/training_informant/base.js +0 -83
- package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
- package/dist/core/informant/training_informant/decentralized.js +0 -22
- package/dist/core/informant/training_informant/federated.d.ts +0 -14
- package/dist/core/informant/training_informant/federated.js +0 -32
- package/dist/core/informant/training_informant/index.d.ts +0 -4
- package/dist/core/informant/training_informant/index.js +0 -11
- package/dist/core/informant/training_informant/local.d.ts +0 -6
- package/dist/core/informant/training_informant/local.js +0 -20
- package/dist/core/logging/console_logger.js +0 -33
- package/dist/core/logging/index.d.ts +0 -3
- package/dist/core/logging/index.js +0 -9
- package/dist/core/logging/logger.js +0 -9
- package/dist/core/logging/trainer_logger.d.ts +0 -24
- package/dist/core/logging/trainer_logger.js +0 -59
- package/dist/core/memory/base.d.ts +0 -22
- package/dist/core/memory/base.js +0 -9
- package/dist/core/memory/empty.d.ts +0 -14
- package/dist/core/memory/empty.js +0 -75
- package/dist/core/memory/index.d.ts +0 -3
- package/dist/core/memory/index.js +0 -9
- package/dist/core/memory/model_type.d.ts +0 -4
- package/dist/core/memory/model_type.js +0 -9
- package/dist/core/serialization/index.d.ts +0 -2
- package/dist/core/serialization/index.js +0 -6
- package/dist/core/serialization/model.d.ts +0 -5
- package/dist/core/serialization/model.js +0 -55
- package/dist/core/serialization/weights.js +0 -64
- package/dist/core/task/data_example.js +0 -24
- package/dist/core/task/display_information.js +0 -49
- package/dist/core/task/index.d.ts +0 -3
- package/dist/core/task/index.js +0 -8
- package/dist/core/task/model_compile_data.d.ts +0 -6
- package/dist/core/task/model_compile_data.js +0 -22
- package/dist/core/task/summary.js +0 -19
- package/dist/core/task/task.d.ts +0 -10
- package/dist/core/task/task.js +0 -31
- package/dist/core/task/training_information.js +0 -66
- package/dist/core/tasks/cifar10.d.ts +0 -3
- package/dist/core/tasks/cifar10.js +0 -65
- package/dist/core/tasks/geotags.d.ts +0 -3
- package/dist/core/tasks/geotags.js +0 -67
- package/dist/core/tasks/index.d.ts +0 -6
- package/dist/core/tasks/index.js +0 -10
- package/dist/core/tasks/lus_covid.d.ts +0 -3
- package/dist/core/tasks/lus_covid.js +0 -87
- package/dist/core/tasks/mnist.d.ts +0 -3
- package/dist/core/tasks/mnist.js +0 -60
- package/dist/core/tasks/simple_face.d.ts +0 -2
- package/dist/core/tasks/simple_face.js +0 -41
- package/dist/core/tasks/titanic.d.ts +0 -3
- package/dist/core/tasks/titanic.js +0 -88
- package/dist/core/training/disco.d.ts +0 -23
- package/dist/core/training/disco.js +0 -130
- package/dist/core/training/index.d.ts +0 -2
- package/dist/core/training/index.js +0 -7
- package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
- package/dist/core/training/trainer/distributed_trainer.js +0 -65
- package/dist/core/training/trainer/local_trainer.d.ts +0 -11
- package/dist/core/training/trainer/local_trainer.js +0 -34
- package/dist/core/training/trainer/round_tracker.d.ts +0 -30
- package/dist/core/training/trainer/round_tracker.js +0 -47
- package/dist/core/training/trainer/trainer.d.ts +0 -65
- package/dist/core/training/trainer/trainer.js +0 -160
- package/dist/core/training/trainer/trainer_builder.js +0 -95
- package/dist/core/training/training_schemes.d.ts +0 -5
- package/dist/core/training/training_schemes.js +0 -10
- package/dist/core/types.d.ts +0 -4
- package/dist/core/types.js +0 -2
- package/dist/core/validation/index.d.ts +0 -1
- package/dist/core/validation/index.js +0 -5
- package/dist/core/validation/validator.d.ts +0 -17
- package/dist/core/validation/validator.js +0 -104
- package/dist/core/weights/aggregation.d.ts +0 -8
- package/dist/core/weights/aggregation.js +0 -96
- package/dist/core/weights/index.d.ts +0 -2
- package/dist/core/weights/index.js +0 -7
- package/dist/core/weights/weights_container.d.ts +0 -19
- package/dist/core/weights/weights_container.js +0 -64
- package/dist/imports.d.ts +0 -2
- package/dist/imports.js +0 -7
- package/dist/memory/memory.d.ts +0 -26
- package/dist/memory/memory.js +0 -160
- package/dist/{core/task → task}/data_example.d.ts +1 -1
- package/dist/{core/task → task}/summary.d.ts +1 -1
|
@@ -1,41 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.task = void 0;
|
|
4
|
-
var __1 = require("..");
|
|
5
|
-
exports.task = {
|
|
6
|
-
taskID: 'simple_face',
|
|
7
|
-
displayInformation: {
|
|
8
|
-
taskTitle: 'Simple Face',
|
|
9
|
-
summary: {
|
|
10
|
-
preview: 'Can you detect if the person in a picture is a child or an adult?',
|
|
11
|
-
overview: 'Simple face is a small subset of face_task from Kaggle'
|
|
12
|
-
},
|
|
13
|
-
limitations: 'The training data is limited to small images of size 200x200.',
|
|
14
|
-
tradeoffs: 'Training success strongly depends on label distribution',
|
|
15
|
-
dataFormatInformation: '',
|
|
16
|
-
dataExampleText: 'Below you find an example',
|
|
17
|
-
dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/simple_face-example.png'
|
|
18
|
-
},
|
|
19
|
-
trainingInformation: {
|
|
20
|
-
modelID: 'simple_face-model',
|
|
21
|
-
epochs: 50,
|
|
22
|
-
modelURL: 'https://storage.googleapis.com/deai-313515.appspot.com/models/mobileNetV2_35_alpha_2_classes/model.json',
|
|
23
|
-
roundDuration: 1,
|
|
24
|
-
validationSplit: 0.2,
|
|
25
|
-
batchSize: 10,
|
|
26
|
-
preprocessingFunctions: [__1.data.ImagePreprocessing.Normalize],
|
|
27
|
-
learningRate: 0.001,
|
|
28
|
-
modelCompileData: {
|
|
29
|
-
optimizer: 'sgd',
|
|
30
|
-
loss: 'categoricalCrossentropy',
|
|
31
|
-
metrics: ['accuracy']
|
|
32
|
-
},
|
|
33
|
-
dataType: 'image',
|
|
34
|
-
IMAGE_H: 200,
|
|
35
|
-
IMAGE_W: 200,
|
|
36
|
-
LABEL_LIST: ['child', 'adult'],
|
|
37
|
-
scheme: 'Federated',
|
|
38
|
-
noiseScale: undefined,
|
|
39
|
-
clippingRadius: undefined
|
|
40
|
-
}
|
|
41
|
-
};
|
|
@@ -1,88 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.model = exports.task = void 0;
|
|
4
|
-
var __1 = require("..");
|
|
5
|
-
exports.task = {
|
|
6
|
-
taskID: 'titanic',
|
|
7
|
-
displayInformation: {
|
|
8
|
-
taskTitle: 'Titanic',
|
|
9
|
-
summary: {
|
|
10
|
-
preview: "Test our platform by using a publicly available <b>tabular</b> dataset. <br><br> Download the passenger list from the Titanic shipwreck here: <a class='underline text-primary-dark dark:text-primary-light' href='https://github.com/epfml/disco/raw/develop/example_training_data/titanic_train.csv'>titanic_train.csv</a> (more info <a class='underline text-primary-dark dark:text-primary-light' href='https://www.kaggle.com/c/titanic'>here</a>). <br> This model predicts the type of person most likely to survive/die in the historic ship accident, based on their characteristics (sex, age, class etc.).",
|
|
11
|
-
overview: 'We all know the unfortunate story of the Titanic: this flamboyant new transatlantic boat that sunk in 1912 in the North Atlantic Ocean. Today, we revist this tragedy by trying to predict the survival odds of the passenger given some basic features.'
|
|
12
|
-
},
|
|
13
|
-
model: 'The current model does not normalize the given data and applies only a very simple pre-processing of the data.',
|
|
14
|
-
tradeoffs: 'We are using a small model for this task: 4 fully connected layers with few neurons. This allows fast training but can yield to reduced accuracy.',
|
|
15
|
-
dataFormatInformation: 'This model takes as input a CSV file with 12 columns. The features are general information about the passenger (sex, age, name, etc.) and specific related Titanic data such as the ticket class bought by the passenger, its cabin number, etc.<br><br>pclass: A proxy for socio-economic status (SES)<br>1st = Upper<br>2nd = Middle<br>3rd = Lower<br><br>age: Age is fractional if less than 1. If the age is estimated, it is in the form of xx.5<br><br>sibsp: The dataset defines family relations in this way:<br>Sibling = brother, sister, stepbrother, stepsister<br>Spouse = husband, wife (mistresses and fiancés were ignored)<br><br>parch: The dataset defines family relations in this way:<br>Parent = mother, father<br>Child = daughter, son, stepdaughter, stepson<br>Some children travelled only with a nanny, therefore parch=0 for them.<br><br>The first line of the CSV contains the header:<br> PassengerId, Survived, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin, Embarked<br><br>Each susequent row contains the corresponding data.',
|
|
16
|
-
dataExampleText: 'Below one can find an example of a datapoint taken as input by our model. In this datapoint, the person is young man named Owen Harris that unfortunnalty perished with the Titanic. He boarded the boat in South Hamptons and was a 3rd class passenger. On the testing & validation page, the data should not contain the label column (Survived).',
|
|
17
|
-
dataExample: [
|
|
18
|
-
{ columnName: 'PassengerId', columnData: '1' },
|
|
19
|
-
{ columnName: 'Survived', columnData: '0' },
|
|
20
|
-
{ columnName: 'Name', columnData: 'Braund, Mr. Owen Harris' },
|
|
21
|
-
{ columnName: 'Sex', columnData: 'male' },
|
|
22
|
-
{ columnName: 'Age', columnData: '22' },
|
|
23
|
-
{ columnName: 'SibSp', columnData: '1' },
|
|
24
|
-
{ columnName: 'Parch', columnData: '0' },
|
|
25
|
-
{ columnName: 'Ticket', columnData: '1/5 21171' },
|
|
26
|
-
{ columnName: 'Fare', columnData: '7.25' },
|
|
27
|
-
{ columnName: 'Cabin', columnData: 'E46' },
|
|
28
|
-
{ columnName: 'Embarked', columnData: 'S' },
|
|
29
|
-
{ columnName: 'Pclass', columnData: '3' }
|
|
30
|
-
],
|
|
31
|
-
headers: [
|
|
32
|
-
'PassengerId',
|
|
33
|
-
'Survived',
|
|
34
|
-
'Name',
|
|
35
|
-
'Sex',
|
|
36
|
-
'Age',
|
|
37
|
-
'SibSp',
|
|
38
|
-
'Parch',
|
|
39
|
-
'Ticket',
|
|
40
|
-
'Fare',
|
|
41
|
-
'Cabin',
|
|
42
|
-
'Embarked',
|
|
43
|
-
'Pclass'
|
|
44
|
-
]
|
|
45
|
-
},
|
|
46
|
-
trainingInformation: {
|
|
47
|
-
modelID: 'titanic-model',
|
|
48
|
-
epochs: 20,
|
|
49
|
-
roundDuration: 10,
|
|
50
|
-
validationSplit: 0,
|
|
51
|
-
batchSize: 30,
|
|
52
|
-
preprocessingFunctions: [],
|
|
53
|
-
modelCompileData: {
|
|
54
|
-
optimizer: 'rmsprop',
|
|
55
|
-
loss: 'binaryCrossentropy',
|
|
56
|
-
metrics: ['accuracy']
|
|
57
|
-
},
|
|
58
|
-
dataType: 'tabular',
|
|
59
|
-
inputColumns: [
|
|
60
|
-
'PassengerId',
|
|
61
|
-
'Age',
|
|
62
|
-
'SibSp',
|
|
63
|
-
'Parch',
|
|
64
|
-
'Fare',
|
|
65
|
-
'Pclass'
|
|
66
|
-
],
|
|
67
|
-
outputColumns: [
|
|
68
|
-
'Survived'
|
|
69
|
-
],
|
|
70
|
-
scheme: 'Federated',
|
|
71
|
-
noiseScale: undefined,
|
|
72
|
-
clippingRadius: undefined
|
|
73
|
-
}
|
|
74
|
-
};
|
|
75
|
-
/**
 * Builds the (uncompiled) Titanic classifier: three ReLU dense layers followed
 * by a single sigmoid unit, matching the `binaryCrossentropy` loss configured
 * in `task.trainingInformation.modelCompileData`.
 *
 * The input width is derived from `task.trainingInformation.inputColumns`
 * instead of being hard-coded, so adding or removing a feature column cannot
 * silently desynchronize the model from the task definition.
 *
 * @returns a new `tf.Sequential` model
 */
function model() {
    var inputColumns = exports.task.trainingInformation.inputColumns;
    // Fall back to the historical width (6 features) if no columns are declared.
    var inputSize = inputColumns !== undefined ? inputColumns.length : 6;
    var model = __1.tf.sequential();
    model.add(__1.tf.layers.dense({
        inputShape: [inputSize],
        units: 124,
        activation: 'relu',
        kernelInitializer: 'leCunNormal'
    }));
    model.add(__1.tf.layers.dense({ units: 64, activation: 'relu' }));
    model.add(__1.tf.layers.dense({ units: 32, activation: 'relu' }));
    model.add(__1.tf.layers.dense({ units: 1, activation: 'sigmoid' }));
    return model;
}
exports.model = model;
|
|
@@ -1,23 +0,0 @@
|
|
|
1
|
-
import { Client, data, Logger, Task, TrainingInformant, TrainingSchemes, Memory } from '..';
|
|
2
|
-
import { TrainerLog } from '../logging/trainer_logger';
|
|
3
|
-
/**
 * Options accepted by the {@link Disco} constructor. Every field is optional;
 * the constructor derives a default for any field left undefined.
 */
interface DiscoOptions {
    /** Training scheme; when omitted it is read from the task's training information. */
    scheme?: TrainingSchemes;
    /** Client to train with; when omitted it is built from `url` and the scheme. */
    client?: Client;
    /** Server address, only consulted when `client` is omitted. */
    url?: string | URL;
    /** Informant handed to the trainer; when omitted one matching the scheme is created. */
    informant?: TrainingInformant;
    /** Destination of user-facing messages; defaults to a console logger. */
    logger?: Logger;
    /** Model storage handed to the trainer; defaults to an empty memory. */
    memory?: Memory;
}
/**
 * Handles the training loop and server communication, and provides the user
 * with feedback.
 */
export declare class Disco {
    private readonly client;
    private readonly trainer;
    readonly task: Task;
    readonly logger: Logger;
    readonly memory: Memory;
    /**
     * @param task task to train on
     * @param options optional collaborators; missing ones are defaulted
     * @throws Error when no client can be determined, or when the client or
     *   informant was set up for a different task
     */
    constructor(task: Task, options: DiscoOptions);
    /**
     * Preprocesses the given split, connects the client and trains the model.
     * Without a validation set, the training set is used for validation.
     */
    fit(dataTuple: data.DataSplit): Promise<void>;
    /** Stops training. Does not disconnect the client. */
    pause(): Promise<void>;
    /** Stops training, then disconnects the client. */
    close(): Promise<void>;
    /** Returns the trainer's log. */
    logs(): Promise<TrainerLog>;
}
export {};
|
|
@@ -1,130 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.Disco = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
var __1 = require("..");
|
|
6
|
-
var trainer_builder_1 = require("./trainer/trainer_builder");
|
|
7
|
-
// Handles the training loop, server communication & provides the user with feedback.
var Disco = /** @class */ (function () {
    /**
     * Builds a Disco for `task`, defaulting every option left undefined:
     * - `scheme`: looked up from `task.trainingInformation.scheme`;
     * - `client`: built from `url` according to the scheme;
     * - `informant`: the informant matching the scheme;
     * - `logger`: a `ConsoleLogger`;
     * - `memory`: an `EmptyMemory`.
     *
     * The caller's `options` object is left untouched (it used to be mutated).
     *
     * @param task task to train on
     * @param options optional collaborators, see above
     * @throws Error if neither `client` nor `url` is given, or if the client or
     *   informant was set up for a different task
     */
    function Disco(task, options) {
        // Shallow copy: defaults must not leak back into the caller's object.
        var opts = (0, tslib_1.__assign)({}, options);
        if (opts.scheme === undefined) {
            var schemeName = task.trainingInformation.scheme;
            opts.scheme = __1.TrainingSchemes[schemeName];
            // Tasks spell the scheme 'Federated' / 'Decentralized' / 'Local' while the
            // enum keys are upper-case (FEDERATED, ...); without this fallback the
            // lookup yields undefined and every task silently falls back to a local client.
            if (opts.scheme === undefined && typeof schemeName === 'string') {
                opts.scheme = __1.TrainingSchemes[schemeName.toUpperCase()];
            }
        }
        if (opts.client === undefined) {
            if (opts.url === undefined) {
                throw new Error('could not determine client from given parameters');
            }
            var url = typeof opts.url === 'string' ? new URL(opts.url) : opts.url;
            switch (opts.scheme) {
                case __1.TrainingSchemes.FEDERATED:
                    opts.client = new __1.client.federated.Client(url, task);
                    break;
                case __1.TrainingSchemes.DECENTRALIZED:
                    // Previously a federated client was built here, so decentralized
                    // tasks never used the peer-to-peer client.
                    opts.client = new __1.client.decentralized.ClearText(url, task);
                    break;
                default:
                    opts.client = new __1.client.Local(url, task);
                    break;
            }
        }
        if (opts.informant === undefined) {
            switch (opts.scheme) {
                case __1.TrainingSchemes.FEDERATED:
                    opts.informant = new __1.informant.FederatedInformant(task);
                    break;
                case __1.TrainingSchemes.DECENTRALIZED:
                    opts.informant = new __1.informant.DecentralizedInformant(task);
                    break;
                default:
                    opts.informant = new __1.informant.LocalInformant(task);
                    break;
            }
        }
        if (opts.logger === undefined) {
            opts.logger = new __1.ConsoleLogger();
        }
        if (opts.memory === undefined) {
            opts.memory = new __1.EmptyMemory();
        }
        if (opts.client.task !== task) {
            throw new Error('client not setup for given task');
        }
        if (opts.informant.task.taskID !== task.taskID) {
            throw new Error('informant not setup for given task');
        }
        this.task = task;
        this.client = opts.client;
        this.memory = opts.memory;
        this.logger = opts.logger;
        var trainerBuilder = new trainer_builder_1.TrainerBuilder(this.memory, this.task, opts.informant);
        // The second argument selects distributed training for every non-local scheme.
        this.trainer = trainerBuilder.build(this.client, opts.scheme !== __1.TrainingSchemes.LOCAL);
    }
    /**
     * Preprocesses the given split, connects the client and trains the model.
     * When `dataTuple.validation` is undefined, the training set is also used
     * for validation.
     */
    Disco.prototype.fit = function (dataTuple) {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            var trainDataset, valDataset;
            return (0, tslib_1.__generator)(this, function (_a) {
                switch (_a.label) {
                    case 0:
                        this.logger.success('Thank you for your contribution. Data preprocessing has started');
                        trainDataset = dataTuple.train.batch().preprocess();
                        valDataset = dataTuple.validation !== undefined
                            ? dataTuple.validation.batch().preprocess()
                            : trainDataset;
                        return [4 /*yield*/, this.client.connect()];
                    case 1:
                        _a.sent();
                        return [4 /*yield*/, this.trainer];
                    case 2: return [4 /*yield*/, (_a.sent()).trainModel(trainDataset.dataset, valDataset.dataset)];
                    case 3:
                        _a.sent();
                        return [2 /*return*/];
                }
            });
        });
    };
    // Stops the training function. Does not disconnect the client.
    Disco.prototype.pause = function () {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            return (0, tslib_1.__generator)(this, function (_a) {
                switch (_a.label) {
                    case 0: return [4 /*yield*/, this.trainer];
                    case 1: return [4 /*yield*/, (_a.sent()).stopTraining()];
                    case 2:
                        _a.sent();
                        this.logger.success('Training was successfully interrupted.');
                        return [2 /*return*/];
                }
            });
        });
    };
    // Stops training, then disconnects the client.
    Disco.prototype.close = function () {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            return (0, tslib_1.__generator)(this, function (_a) {
                switch (_a.label) {
                    case 0: return [4 /*yield*/, this.pause()];
                    case 1:
                        _a.sent();
                        return [4 /*yield*/, this.client.disconnect()];
                    case 2:
                        _a.sent();
                        return [2 /*return*/];
                }
            });
        });
    };
    // Resolves to the trainer's log once the trainer has been built.
    Disco.prototype.logs = function () {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            return (0, tslib_1.__generator)(this, function (_a) {
                switch (_a.label) {
                    case 0: return [4 /*yield*/, this.trainer];
                    case 1: return [2 /*return*/, (_a.sent()).getTrainerLog()];
                }
            });
        });
    };
    return Disco;
}());
exports.Disco = Disco;
|
|
@@ -1,7 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.TrainingSchemes = exports.Disco = void 0;
/**
 * Exposes `source[name]` as `exports[name]` through an enumerable live getter,
 * the same shape tsc emits for `export { X } from '...'`.
 */
function reexport(name, source) {
    Object.defineProperty(exports, name, {
        enumerable: true,
        get: function () { return source[name]; }
    });
}
// Require order is kept as emitted (disco first): these modules may be part of
// a require cycle, so reordering could change what each one sees at load time.
var discoModule = require("./disco");
reexport("Disco", discoModule);
var trainingSchemesModule = require("./training_schemes");
reexport("TrainingSchemes", trainingSchemesModule);
|
|
@@ -1,20 +0,0 @@
|
|
|
1
|
-
import { tf, Client, Memory, Task, TrainingInformant } from '../..';
|
|
2
|
-
import { Trainer } from './trainer';
|
|
3
|
-
/**
 * Trainer that trains a model in a distributed way on a given dataset: besides
 * local training, it exchanges weights through a client at the end of each round.
 */
export declare class DistributedTrainer extends Trainer {
    private readonly client;
    private readonly previousRoundModel;
    /**
     * Takes the same arguments as {@link Trainer}, plus:
     *
     * @param previousRoundModel model holding the weights from the previous round
     * @param client client that takes care of communicating weights
     */
    constructor(task: Task, trainingInformant: TrainingInformant, memory: Memory, model: tf.LayersModel, previousRoundModel: tf.LayersModel, client: Client);
    /**
     * Callback invoked each time a round ends: the current and previous-round
     * weights are handed to the client, and the model is updated with the
     * weights the client returns.
     */
    onRoundEnd(accuracy: number): Promise<void>;
    /** Callback invoked once training is over. */
    onTrainEnd(): Promise<void>;
}
|
|
@@ -1,65 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.DistributedTrainer = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
var __1 = require("../..");
|
|
6
|
-
var trainer_1 = require("./trainer");
|
|
7
|
-
/**
 * Trainer for distributed training: on top of the training loop inherited from
 * `Trainer`, it hands its weights to a `Client` at the end of every round and
 * adopts the aggregated weights it gets back.
 */
var DistributedTrainer = /** @class */ (function (Base) {
    tslib_1.__extends(DistributedTrainer, Base);
    /**
     * Takes the same arguments as `Trainer`, plus the model holding the weights
     * from the previous round and the client used to communicate weights.
     */
    function DistributedTrainer(task, trainingInformant, memory, model, previousRoundModel, client) {
        var self = Base.call(this, task, trainingInformant, memory, model) || this;
        self.previousRoundModel = previousRoundModel;
        self.client = client;
        return self;
    }
    /**
     * Round-end hook. Sends the current and previous-round weights to the client.
     * Then it records the pre-aggregation weights as the new "previous round"
     * model, loads the aggregated weights into the model, and persists the model
     * as the working model in memory. `accuracy` is not used here.
     */
    DistributedTrainer.prototype.onRoundEnd = function (accuracy) {
        var self = this;
        var localWeights;
        // The executor runs synchronously and converts a synchronous throw into a
        // rejection, so the weights are captured before this method returns.
        return new Promise(function (resolve) {
            localWeights = __1.WeightsContainer.from(self.model);
            var priorWeights = __1.WeightsContainer.from(self.previousRoundModel);
            resolve(self.client.onRoundEndCommunication(localWeights, priorWeights, self.roundTracker.round, self.trainingInformant));
        }).then(function (aggregated) {
            self.previousRoundModel.setWeights(localWeights.weights);
            self.model.setWeights(aggregated.weights);
            var modelInfo = { taskID: self.task.taskID, name: self.trainingInformation.modelID };
            return self.memory.updateWorkingModel(modelInfo, self.model);
        }).then(function () {
            return undefined;
        });
    };
    /**
     * Train-end hook: sends the final weights through the client, then runs the
     * base `Trainer` end-of-training logic.
     */
    DistributedTrainer.prototype.onTrainEnd = function () {
        var self = this;
        return new Promise(function (resolve) {
            resolve(self.client.onTrainEndCommunication(__1.WeightsContainer.from(self.model), self.trainingInformant));
        }).then(function () {
            return Base.prototype.onTrainEnd.call(self);
        }).then(function () {
            return undefined;
        });
    };
    return DistributedTrainer;
}(trainer_1.Trainer));
exports.DistributedTrainer = DistributedTrainer;
|
|
@@ -1,11 +0,0 @@
|
|
|
1
|
-
import { tf } from '../..';
|
|
2
|
-
import { Trainer } from './trainer';
|
|
3
|
-
/**
 * Trains a model locally (alone) on a given dataset, without any collaborators.
 */
export declare class LocalTrainer extends Trainer {
    /**
     * Callback called every time a round is over. Without collaborators there is
     * no weight exchange: the current model is only saved as the working model
     * in memory.
     */
    onRoundEnd(accuracy: number): Promise<void>;
    /**
     * Runs the base epoch-end logic, then reports the epoch index to the
     * training informant as the current round.
     */
    protected onEpochEnd(epoch: number, logs?: tf.Logs): void;
}
|
|
@@ -1,34 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.LocalTrainer = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
var trainer_1 = require("./trainer");
|
|
6
|
-
/**
 * Trains a model locally (alone) on a given dataset, without any collaborators.
 */
var LocalTrainer = /** @class */ (function (Base) {
    tslib_1.__extends(LocalTrainer, Base);
    function LocalTrainer() {
        return (Base !== null && Base.apply(this, arguments)) || this;
    }
    /**
     * Round-end hook. With no collaborators, the only action is saving the
     * current model as the working model in memory. `accuracy` is not used.
     */
    LocalTrainer.prototype.onRoundEnd = function (accuracy) {
        var self = this;
        return new Promise(function (resolve) {
            var modelInfo = { taskID: self.task.taskID, name: self.trainingInformation.modelID };
            resolve(self.memory.updateWorkingModel(modelInfo, self.model));
        }).then(function () {
            return undefined;
        });
    };
    /**
     * Epoch-end hook: runs the base logic, then reports the epoch index to the
     * training informant as the current round.
     */
    LocalTrainer.prototype.onEpochEnd = function (epoch, logs) {
        Base.prototype.onEpochEnd.call(this, epoch, logs);
        var update = { currentRound: epoch };
        this.trainingInformant.update(update);
    };
    return LocalTrainer;
}(trainer_1.Trainer));
exports.LocalTrainer = LocalTrainer;
|
|
@@ -1,30 +0,0 @@
|
|
|
1
|
-
/**
 * Keeps track of the current batch so that the trainer can query when a round has ended.
 *
 * @remarks
 * In distributed training, a client trains locally for a certain number of batches
 * before sharing its weights with the server or its neighbors; this is what we call
 * a round. The length of a round, in batches, is `roundDuration`.
 *
 * The batch counter kept here is cumulative, whereas the batch index given to
 * onBatchEnd is not (it resets to 0 after each epoch).
 */
export declare class RoundTracker {
    /** Number of rounds that have ended so far; incremented by `roundHasEnded`. */
    round: number;
    /** Cumulative number of batches processed, across epochs. */
    batch: number;
    /** Length of a round, in batches. */
    roundDuration: number;
    constructor(roundDuration: number);
    /**
     * Increments the cumulative batch counter; to be called inside onBatchEnd.
     * The batch index passed to onBatchEnd is not used, since it is not cumulative.
     */
    updateBatch(): void;
    /**
     * Returns true if a local round has ended, false otherwise.
     *
     * @remarks
     * Returns true when `batch` is non-zero and `batch % roundDuration === 0`.
     * This is not a pure query: every call that returns true also increments
     * `round`, so it should be called once per batch.
     */
    roundHasEnded(): boolean;
}
|
|
@@ -1,47 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.RoundTracker = void 0;
|
|
4
|
-
/**
 * Tracks the cumulative batch count so the trainer can ask whether a round
 * (a block of `roundDuration` batches) has just finished.
 *
 * The batch index received by onBatchEnd restarts at 0 every epoch; the
 * counter kept here never resets, which is what makes the modulo test work.
 */
var RoundTracker = /** @class */ (function () {
    function RoundTracker(roundDuration) {
        this.round = 0;
        this.batch = 0;
        this.roundDuration = roundDuration;
    }
    /**
     * Advances the cumulative batch counter by one; call it from onBatchEnd.
     */
    RoundTracker.prototype.updateBatch = function () {
        this.batch = this.batch + 1;
    };
    /**
     * Reports whether the current batch closes a round: true when the batch
     * counter is non-zero and a multiple of `roundDuration`.
     *
     * Side effect: each time it reports true, `round` is incremented.
     */
    RoundTracker.prototype.roundHasEnded = function () {
        var atBoundary = this.batch !== 0 && this.batch % this.roundDuration === 0;
        if (atBoundary) {
            this.round = this.round + 1;
        }
        return atBoundary;
    };
    return RoundTracker;
}());
exports.RoundTracker = RoundTracker;
|
|
@@ -1,65 +0,0 @@
|
|
|
1
|
-
import { tf, Memory, Task, TrainingInformant, TrainingInformation } from '../..';
|
|
2
|
-
import { RoundTracker } from './round_tracker';
|
|
3
|
-
import { TrainerLog } from '../../logging/trainer_logger';
|
|
4
|
-
/**
 * Abstract class whose role is to train a model on a given dataset, either locally
 * (alone) or in a distributed way with collaborators. The Trainer works as follows:
 *
 * 1. Call trainModel(dataset, valDataset) to start training.
 * 2. Once a batch ends, onBatchEnd is triggered, which then calls onRoundEnd once
 *    the round has ended.
 *
 * Subclasses implement onRoundEnd to specify what to do when a round ends, such as
 * a communication step with collaborators. Round boundaries are detected with the
 * roundTracker object.
 */
export declare abstract class Trainer {
    readonly task: Task;
    readonly trainingInformant: TrainingInformant;
    readonly memory: Memory;
    readonly model: tf.LayersModel;
    readonly trainingInformation: TrainingInformation;
    readonly roundTracker: RoundTracker;
    private stopTrainingRequested;
    private readonly trainerLogger;
    /**
     * Constructs the training manager.
     * @param task - the trained task
     * @param trainingInformant - the training informant, which receives progress updates
     * @param memory - the memory in which the working model is saved
     * @param model - the model to train
     */
    constructor(task: Task, trainingInformant: TrainingInformant, memory: Memory, model: tf.LayersModel);
    /**
     * Called every time a round ends; implemented by the local and distributed trainers.
     */
    protected abstract onRoundEnd(accuracy: number): Promise<void>;
    /**
     * onBatchEnd callback; when a round has ended, it calls onRoundEnd.
     */
    protected onBatchEnd(_: number, logs?: tf.Logs): Promise<void>;
    /**
     * Updates the training graph. This is done on epoch end because validation
     * accuracy is not available in onBatchEnd.
     */
    protected onEpochEnd(epoch: number, logs?: tf.Logs): void;
    /**
     * Called when training ends.
     */
    protected onTrainEnd(logs?: tf.Logs): Promise<void>;
    /**
     * Requests that training stop. Meant to be used by the Disco instance or any
     * class that manages the trainer.
     */
    stopTraining(): Promise<void>;
    /**
     * Starts training the model.
     * @param dataset - the training dataset
     * @param valDataset - the validation dataset
     */
    trainModel(dataset: tf.data.Dataset<tf.TensorContainer>, valDataset: tf.data.Dataset<tf.TensorContainer>): Promise<void>;
    /**
     * Formats an accuracy value for display. Per the method name, it rounds
     * `accuracy` to `decimalsToRound` decimal places; the default count is
     * defined by the implementation (confirm there).
     */
    protected roundDecimals(accuracy: number, decimalsToRound?: number): number;
    /**
     * Clears any pending stop-training request.
     */
    protected resetStopTrainerState(): void;
    /**
     * Stops the training of the model if a stop has been requested.
     */
    protected stopTrainModelIfRequested(): void;
    /**
     * Returns the log recorded by this trainer.
     */
    getTrainerLog(): TrainerLog;
}
|