@epfml/discojs 2.1.1 → 2.1.2-p20240506085037.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.d.ts +180 -0
- package/dist/aggregator/base.js +236 -0
- package/dist/aggregator/get.d.ts +16 -0
- package/dist/aggregator/get.js +31 -0
- package/dist/aggregator/index.d.ts +7 -0
- package/dist/aggregator/index.js +4 -0
- package/dist/aggregator/mean.d.ts +23 -0
- package/dist/aggregator/mean.js +69 -0
- package/dist/aggregator/secure.d.ts +27 -0
- package/dist/aggregator/secure.js +91 -0
- package/dist/async_informant.d.ts +15 -0
- package/dist/async_informant.js +42 -0
- package/dist/client/base.d.ts +76 -0
- package/dist/client/base.js +88 -0
- package/dist/client/decentralized/base.d.ts +32 -0
- package/dist/client/decentralized/base.js +192 -0
- package/dist/client/decentralized/index.d.ts +2 -0
- package/dist/client/decentralized/index.js +2 -0
- package/dist/client/decentralized/messages.d.ts +28 -0
- package/dist/client/decentralized/messages.js +44 -0
- package/dist/client/decentralized/peer.d.ts +40 -0
- package/dist/client/decentralized/peer.js +189 -0
- package/dist/client/decentralized/peer_pool.d.ts +12 -0
- package/dist/client/decentralized/peer_pool.js +44 -0
- package/dist/client/event_connection.d.ts +34 -0
- package/dist/client/event_connection.js +105 -0
- package/dist/client/federated/base.d.ts +54 -0
- package/dist/client/federated/base.js +151 -0
- package/dist/client/federated/index.d.ts +2 -0
- package/dist/client/federated/index.js +2 -0
- package/dist/client/federated/messages.d.ts +30 -0
- package/dist/client/federated/messages.js +24 -0
- package/dist/client/index.d.ts +8 -0
- package/dist/client/index.js +8 -0
- package/dist/client/local.d.ts +3 -0
- package/dist/client/local.js +3 -0
- package/dist/client/messages.d.ts +30 -0
- package/dist/client/messages.js +26 -0
- package/dist/client/types.d.ts +2 -0
- package/dist/client/types.js +4 -0
- package/dist/client/utils.d.ts +2 -0
- package/dist/client/utils.js +7 -0
- package/dist/dataset/data/data.d.ts +48 -0
- package/dist/dataset/data/data.js +72 -0
- package/dist/dataset/data/data_split.d.ts +8 -0
- package/dist/dataset/data/data_split.js +1 -0
- package/dist/dataset/data/image_data.d.ts +11 -0
- package/dist/dataset/data/image_data.js +38 -0
- package/dist/dataset/data/index.d.ts +6 -0
- package/dist/dataset/data/index.js +5 -0
- package/dist/dataset/data/preprocessing/base.d.ts +16 -0
- package/dist/dataset/data/preprocessing/base.js +1 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.js +40 -0
- package/dist/dataset/data/preprocessing/index.d.ts +4 -0
- package/dist/dataset/data/preprocessing/index.js +3 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.js +45 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.js +85 -0
- package/dist/dataset/data/tabular_data.d.ts +11 -0
- package/dist/dataset/data/tabular_data.js +25 -0
- package/dist/dataset/data/text_data.d.ts +11 -0
- package/dist/dataset/data/text_data.js +14 -0
- package/dist/{core/dataset → dataset}/data_loader/data_loader.d.ts +3 -5
- package/dist/dataset/data_loader/data_loader.js +2 -0
- package/dist/dataset/data_loader/image_loader.d.ts +20 -3
- package/dist/dataset/data_loader/image_loader.js +98 -23
- package/dist/dataset/data_loader/index.d.ts +5 -2
- package/dist/dataset/data_loader/index.js +4 -7
- package/dist/dataset/data_loader/tabular_loader.d.ts +34 -3
- package/dist/dataset/data_loader/tabular_loader.js +75 -15
- package/dist/dataset/data_loader/text_loader.d.ts +14 -0
- package/dist/dataset/data_loader/text_loader.js +25 -0
- package/dist/dataset/dataset.d.ts +5 -0
- package/dist/dataset/dataset.js +1 -0
- package/dist/dataset/dataset_builder.d.ts +60 -0
- package/dist/dataset/dataset_builder.js +142 -0
- package/dist/dataset/index.d.ts +5 -0
- package/dist/dataset/index.js +3 -0
- package/dist/default_tasks/cifar10/index.d.ts +2 -0
- package/dist/{core/default_tasks/cifar10.js → default_tasks/cifar10/index.js} +28 -36
- package/dist/default_tasks/cifar10/model.d.ts +434 -0
- package/dist/default_tasks/cifar10/model.js +2385 -0
- package/dist/default_tasks/geotags/index.d.ts +2 -0
- package/dist/default_tasks/geotags/index.js +65 -0
- package/dist/default_tasks/geotags/model.d.ts +593 -0
- package/dist/default_tasks/geotags/model.js +4715 -0
- package/dist/default_tasks/index.d.ts +8 -0
- package/dist/default_tasks/index.js +8 -0
- package/dist/default_tasks/lus_covid.d.ts +2 -0
- package/dist/default_tasks/lus_covid.js +89 -0
- package/dist/default_tasks/mnist.d.ts +2 -0
- package/dist/{core/default_tasks → default_tasks}/mnist.js +26 -34
- package/dist/default_tasks/simple_face/index.d.ts +2 -0
- package/dist/{core/default_tasks/simple_face.js → default_tasks/simple_face/index.js} +17 -22
- package/dist/default_tasks/simple_face/model.d.ts +513 -0
- package/dist/default_tasks/simple_face/model.js +4301 -0
- package/dist/default_tasks/skin_mnist.d.ts +2 -0
- package/dist/default_tasks/skin_mnist.js +80 -0
- package/dist/default_tasks/titanic.d.ts +2 -0
- package/dist/{core/default_tasks → default_tasks}/titanic.js +24 -33
- package/dist/default_tasks/wikitext.d.ts +2 -0
- package/dist/default_tasks/wikitext.js +38 -0
- package/dist/index.d.ts +18 -2
- package/dist/index.js +18 -6
- package/dist/{core/informant → informant}/graph_informant.d.ts +1 -1
- package/dist/informant/graph_informant.js +20 -0
- package/dist/informant/index.d.ts +1 -0
- package/dist/informant/index.js +1 -0
- package/dist/{core/logging → logging}/console_logger.d.ts +2 -2
- package/dist/logging/console_logger.js +22 -0
- package/dist/logging/index.d.ts +2 -0
- package/dist/logging/index.js +1 -0
- package/dist/{core/logging → logging}/logger.d.ts +3 -3
- package/dist/logging/logger.js +1 -0
- package/dist/memory/base.d.ts +119 -0
- package/dist/memory/base.js +9 -0
- package/dist/memory/empty.d.ts +20 -0
- package/dist/memory/empty.js +43 -0
- package/dist/memory/index.d.ts +3 -1
- package/dist/memory/index.js +3 -5
- package/dist/memory/model_type.d.ts +9 -0
- package/dist/memory/model_type.js +10 -0
- package/dist/{core/privacy.d.ts → privacy.d.ts} +1 -1
- package/dist/{core/privacy.js → privacy.js} +11 -16
- package/dist/serialization/index.d.ts +2 -0
- package/dist/serialization/index.js +2 -0
- package/dist/serialization/model.d.ts +5 -0
- package/dist/serialization/model.js +67 -0
- package/dist/{core/serialization → serialization}/weights.d.ts +2 -2
- package/dist/serialization/weights.js +37 -0
- package/dist/task/data_example.js +14 -0
- package/dist/task/digest.js +14 -0
- package/dist/{core/task → task}/display_information.d.ts +5 -3
- package/dist/task/display_information.js +46 -0
- package/dist/task/index.d.ts +7 -0
- package/dist/task/index.js +5 -0
- package/dist/task/label_type.d.ts +9 -0
- package/dist/task/label_type.js +28 -0
- package/dist/task/summary.js +13 -0
- package/dist/{core/task → task}/task.d.ts +7 -7
- package/dist/task/task.js +22 -0
- package/dist/task/task_handler.d.ts +5 -0
- package/dist/task/task_handler.js +20 -0
- package/dist/task/task_provider.d.ts +5 -0
- package/dist/task/task_provider.js +1 -0
- package/dist/{core/task → task}/training_information.d.ts +9 -10
- package/dist/task/training_information.js +88 -0
- package/dist/training/disco.d.ts +40 -0
- package/dist/training/disco.js +107 -0
- package/dist/training/index.d.ts +2 -0
- package/dist/training/index.js +1 -0
- package/dist/training/trainer/distributed_trainer.d.ts +20 -0
- package/dist/training/trainer/distributed_trainer.js +36 -0
- package/dist/training/trainer/local_trainer.d.ts +12 -0
- package/dist/training/trainer/local_trainer.js +19 -0
- package/dist/training/trainer/trainer.d.ts +33 -0
- package/dist/training/trainer/trainer.js +52 -0
- package/dist/{core/training → training}/trainer/trainer_builder.d.ts +5 -7
- package/dist/training/trainer/trainer_builder.js +43 -0
- package/dist/types.d.ts +8 -0
- package/dist/types.js +1 -0
- package/dist/utils/event_emitter.d.ts +40 -0
- package/dist/utils/event_emitter.js +57 -0
- package/dist/validation/index.d.ts +1 -0
- package/dist/validation/index.js +1 -0
- package/dist/validation/validator.d.ts +28 -0
- package/dist/validation/validator.js +132 -0
- package/dist/weights/aggregation.d.ts +21 -0
- package/dist/weights/aggregation.js +44 -0
- package/dist/weights/index.d.ts +2 -0
- package/dist/weights/index.js +2 -0
- package/dist/weights/weights_container.d.ts +68 -0
- package/dist/weights/weights_container.js +96 -0
- package/package.json +24 -15
- package/README.md +0 -53
- package/dist/core/async_buffer.d.ts +0 -41
- package/dist/core/async_buffer.js +0 -97
- package/dist/core/async_informant.d.ts +0 -20
- package/dist/core/async_informant.js +0 -69
- package/dist/core/client/base.d.ts +0 -33
- package/dist/core/client/base.js +0 -35
- package/dist/core/client/decentralized/base.d.ts +0 -32
- package/dist/core/client/decentralized/base.js +0 -212
- package/dist/core/client/decentralized/clear_text.d.ts +0 -14
- package/dist/core/client/decentralized/clear_text.js +0 -96
- package/dist/core/client/decentralized/index.d.ts +0 -4
- package/dist/core/client/decentralized/index.js +0 -9
- package/dist/core/client/decentralized/messages.d.ts +0 -41
- package/dist/core/client/decentralized/messages.js +0 -54
- package/dist/core/client/decentralized/peer.d.ts +0 -26
- package/dist/core/client/decentralized/peer.js +0 -210
- package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
- package/dist/core/client/decentralized/peer_pool.js +0 -92
- package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
- package/dist/core/client/decentralized/sec_agg.js +0 -190
- package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
- package/dist/core/client/decentralized/secret_shares.js +0 -39
- package/dist/core/client/decentralized/types.d.ts +0 -2
- package/dist/core/client/decentralized/types.js +0 -7
- package/dist/core/client/event_connection.d.ts +0 -37
- package/dist/core/client/event_connection.js +0 -158
- package/dist/core/client/federated/client.d.ts +0 -37
- package/dist/core/client/federated/client.js +0 -273
- package/dist/core/client/federated/index.d.ts +0 -2
- package/dist/core/client/federated/index.js +0 -7
- package/dist/core/client/federated/messages.d.ts +0 -38
- package/dist/core/client/federated/messages.js +0 -25
- package/dist/core/client/index.d.ts +0 -5
- package/dist/core/client/index.js +0 -11
- package/dist/core/client/local.d.ts +0 -8
- package/dist/core/client/local.js +0 -36
- package/dist/core/client/messages.d.ts +0 -28
- package/dist/core/client/messages.js +0 -33
- package/dist/core/client/utils.d.ts +0 -2
- package/dist/core/client/utils.js +0 -19
- package/dist/core/dataset/data/data.d.ts +0 -11
- package/dist/core/dataset/data/data.js +0 -20
- package/dist/core/dataset/data/data_split.d.ts +0 -5
- package/dist/core/dataset/data/data_split.js +0 -2
- package/dist/core/dataset/data/image_data.d.ts +0 -8
- package/dist/core/dataset/data/image_data.js +0 -64
- package/dist/core/dataset/data/index.d.ts +0 -5
- package/dist/core/dataset/data/index.js +0 -11
- package/dist/core/dataset/data/preprocessing.d.ts +0 -13
- package/dist/core/dataset/data/preprocessing.js +0 -33
- package/dist/core/dataset/data/tabular_data.d.ts +0 -8
- package/dist/core/dataset/data/tabular_data.js +0 -40
- package/dist/core/dataset/data_loader/data_loader.js +0 -10
- package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
- package/dist/core/dataset/data_loader/image_loader.js +0 -141
- package/dist/core/dataset/data_loader/index.d.ts +0 -3
- package/dist/core/dataset/data_loader/index.js +0 -9
- package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
- package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
- package/dist/core/dataset/dataset.d.ts +0 -2
- package/dist/core/dataset/dataset.js +0 -2
- package/dist/core/dataset/dataset_builder.d.ts +0 -18
- package/dist/core/dataset/dataset_builder.js +0 -96
- package/dist/core/dataset/index.d.ts +0 -4
- package/dist/core/dataset/index.js +0 -14
- package/dist/core/default_tasks/cifar10.d.ts +0 -2
- package/dist/core/default_tasks/geotags.d.ts +0 -2
- package/dist/core/default_tasks/geotags.js +0 -69
- package/dist/core/default_tasks/index.d.ts +0 -6
- package/dist/core/default_tasks/index.js +0 -15
- package/dist/core/default_tasks/lus_covid.d.ts +0 -2
- package/dist/core/default_tasks/lus_covid.js +0 -96
- package/dist/core/default_tasks/mnist.d.ts +0 -2
- package/dist/core/default_tasks/simple_face.d.ts +0 -2
- package/dist/core/default_tasks/titanic.d.ts +0 -2
- package/dist/core/index.d.ts +0 -18
- package/dist/core/index.js +0 -39
- package/dist/core/informant/graph_informant.js +0 -23
- package/dist/core/informant/index.d.ts +0 -3
- package/dist/core/informant/index.js +0 -9
- package/dist/core/informant/training_informant/base.d.ts +0 -31
- package/dist/core/informant/training_informant/base.js +0 -83
- package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
- package/dist/core/informant/training_informant/decentralized.js +0 -22
- package/dist/core/informant/training_informant/federated.d.ts +0 -14
- package/dist/core/informant/training_informant/federated.js +0 -32
- package/dist/core/informant/training_informant/index.d.ts +0 -4
- package/dist/core/informant/training_informant/index.js +0 -11
- package/dist/core/informant/training_informant/local.d.ts +0 -6
- package/dist/core/informant/training_informant/local.js +0 -20
- package/dist/core/logging/console_logger.js +0 -33
- package/dist/core/logging/index.d.ts +0 -3
- package/dist/core/logging/index.js +0 -9
- package/dist/core/logging/logger.js +0 -9
- package/dist/core/logging/trainer_logger.d.ts +0 -24
- package/dist/core/logging/trainer_logger.js +0 -59
- package/dist/core/memory/base.d.ts +0 -22
- package/dist/core/memory/base.js +0 -9
- package/dist/core/memory/empty.d.ts +0 -14
- package/dist/core/memory/empty.js +0 -75
- package/dist/core/memory/index.d.ts +0 -3
- package/dist/core/memory/index.js +0 -9
- package/dist/core/memory/model_type.d.ts +0 -4
- package/dist/core/memory/model_type.js +0 -9
- package/dist/core/serialization/index.d.ts +0 -2
- package/dist/core/serialization/index.js +0 -6
- package/dist/core/serialization/model.d.ts +0 -5
- package/dist/core/serialization/model.js +0 -55
- package/dist/core/serialization/weights.js +0 -64
- package/dist/core/task/data_example.js +0 -24
- package/dist/core/task/digest.js +0 -18
- package/dist/core/task/display_information.js +0 -49
- package/dist/core/task/index.d.ts +0 -6
- package/dist/core/task/index.js +0 -15
- package/dist/core/task/model_compile_data.d.ts +0 -6
- package/dist/core/task/model_compile_data.js +0 -22
- package/dist/core/task/summary.js +0 -19
- package/dist/core/task/task.js +0 -35
- package/dist/core/task/task_handler.d.ts +0 -5
- package/dist/core/task/task_handler.js +0 -53
- package/dist/core/task/task_provider.d.ts +0 -6
- package/dist/core/task/task_provider.js +0 -13
- package/dist/core/task/training_information.js +0 -66
- package/dist/core/training/disco.d.ts +0 -23
- package/dist/core/training/disco.js +0 -130
- package/dist/core/training/index.d.ts +0 -2
- package/dist/core/training/index.js +0 -7
- package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
- package/dist/core/training/trainer/distributed_trainer.js +0 -65
- package/dist/core/training/trainer/local_trainer.d.ts +0 -11
- package/dist/core/training/trainer/local_trainer.js +0 -34
- package/dist/core/training/trainer/round_tracker.d.ts +0 -30
- package/dist/core/training/trainer/round_tracker.js +0 -47
- package/dist/core/training/trainer/trainer.d.ts +0 -65
- package/dist/core/training/trainer/trainer.js +0 -160
- package/dist/core/training/trainer/trainer_builder.js +0 -95
- package/dist/core/training/training_schemes.d.ts +0 -5
- package/dist/core/training/training_schemes.js +0 -10
- package/dist/core/types.d.ts +0 -4
- package/dist/core/types.js +0 -2
- package/dist/core/validation/index.d.ts +0 -1
- package/dist/core/validation/index.js +0 -5
- package/dist/core/validation/validator.d.ts +0 -17
- package/dist/core/validation/validator.js +0 -104
- package/dist/core/weights/aggregation.d.ts +0 -7
- package/dist/core/weights/aggregation.js +0 -72
- package/dist/core/weights/index.d.ts +0 -2
- package/dist/core/weights/index.js +0 -7
- package/dist/core/weights/weights_container.d.ts +0 -19
- package/dist/core/weights/weights_container.js +0 -64
- package/dist/imports.d.ts +0 -2
- package/dist/imports.js +0 -7
- package/dist/memory/memory.d.ts +0 -26
- package/dist/memory/memory.js +0 -160
- package/dist/{core/task → task}/data_example.d.ts +1 -1
- package/dist/{core/task → task}/digest.d.ts +0 -0
- package/dist/{core/task → task}/summary.d.ts +1 -1
|
@@ -1,26 +1,101 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
1
|
+
import { Range } from 'immutable';
|
|
2
|
+
import * as tf from '@tensorflow/tfjs';
|
|
3
|
+
import { ImageData } from '../data/index.js';
|
|
4
|
+
import { DataLoader } from '../data_loader/index.js';
|
|
5
|
+
/**
|
|
6
|
+
* Image data loader whose instantiable implementation is delegated by the platform-dependent Disco subprojects, namely,
|
|
7
|
+
* @epfml/discojs-web and @epfml/discojs-node.
|
|
8
|
+
* Load labels and correctly match them with their respective images, with the following constraints:
|
|
9
|
+
* 1. Images are given as 1 image/1 file;
|
|
10
|
+
* 2. Labels are given as multiple labels/1 file, each label file can contain a different amount of labels.
|
|
11
|
+
*/
|
|
12
|
+
export class ImageLoader extends DataLoader {
    task;
    constructor(task) {
        super();
        this.task = task;
    }
    /**
     * Loads a single image as a one-element TF.js dataset.
     * @param image Platform-dependent image source (handled by `readImageFrom`).
     * @param config Optional data config; when `labels` is set, its first entry is attached as `ys`.
     * @returns A tf.data.Dataset wrapping the single (optionally labelled) image tensor.
     */
    async load(image, config) {
        let tensorContainer;
        if (config?.labels === undefined) {
            tensorContainer = await this.readImageFrom(image, config?.channels);
        }
        else {
            tensorContainer = {
                xs: await this.readImageFrom(image, config?.channels),
                ys: config.labels[0]
            };
        }
        return tf.data.array([tensorContainer]);
    }
    /**
     * Builds an ImageData set by lazily reading the images selected by `indices`.
     * @param images All available image sources.
     * @param labels Per-image labels, indexed in parallel with `images` (ignored when config has no labels).
     * @param indices Indices into `images` selecting (and ordering) the samples to emit.
     * @param config Optional data config.
     */
    async buildDataset(images, labels, indices, config) {
        // Can't use arrow function for generator and need access to 'this'
        // eslint-disable-next-line
        const self = this;
        async function* dataGenerator() {
            const withLabels = config?.labels !== undefined;
            let index = 0;
            while (index < indices.length) {
                const sample = await self.readImageFrom(images[indices[index]], config?.channels);
                const label = withLabels ? labels[indices[index]] : undefined;
                const value = withLabels ? { xs: sample, ys: label } : sample;
                index++;
                yield value;
            }
        }
        // @ts-expect-error: For some reasons typescript refuses async generator but tensorflow do work with them
        const dataset = tf.data.generator(dataGenerator);
        return await ImageData.init(dataset, this.task, indices.length);
    }
    /**
     * Loads all images, mapping label strings to one-hot vectors via the task's
     * LABEL_LIST, optionally shuffling (default on) and splitting off a
     * validation set when `config.validationSplit` is a non-zero number.
     * @param images All image sources to load.
     * @param config Optional data config (labels, shuffle, validationSplit, channels).
     * @returns A DataSplit with `train` and, if split, `validation` datasets.
     * @throws If LABEL_LIST is missing/invalid or an input label is not in LABEL_LIST.
     */
    async loadAll(images, config) {
        let labels = [];
        const indices = Range(0, images.length).toArray();
        if (config?.labels !== undefined) {
            const labelList = this.task.trainingInformation?.LABEL_LIST;
            if (labelList === undefined || !Array.isArray(labelList)) {
                throw new Error('LABEL_LIST should be specified in the task training information');
            }
            const numberOfClasses = labelList.length;
            // Map label strings to integer
            const label_to_int = new Map(labelList.map((label_name, idx) => [label_name, idx]));
            // Strict equality (was `!=`); a size mismatch means LABEL_LIST had duplicates.
            if (label_to_int.size !== numberOfClasses) {
                throw new Error("Input labels aren't matching the task LABEL_LIST");
            }
            labels = config.labels.map(label_name => {
                const label_int = label_to_int.get(label_name);
                if (label_int === undefined) {
                    throw new Error(`Found input label ${label_name} not specified in task LABEL_LIST`);
                }
                return label_int;
            });
            labels = await tf.oneHot(tf.tensor1d(labels, 'int32'), numberOfClasses).array();
        }
        if (config?.shuffle === undefined || config?.shuffle) {
            this.shuffle(indices);
        }
        if (config?.validationSplit === undefined || config?.validationSplit === 0) {
            const dataset = await this.buildDataset(images, labels, indices, config);
            return {
                train: dataset,
                validation: undefined
            };
        }
        const trainSize = Math.floor(images.length * (1 - config.validationSplit));
        const trainIndices = indices.slice(0, trainSize);
        const valIndices = indices.slice(trainSize);
        const trainDataset = await this.buildDataset(images, labels, trainIndices, config);
        const valDataset = await this.buildDataset(images, labels, valIndices, config);
        return {
            train: trainDataset,
            validation: valDataset
        };
    }
    /**
     * Unbiased in-place Fisher–Yates shuffle.
     * Fix: the previous loop drew j from [0, i) (`Math.random() * i`), so an
     * element could never stay at its own position (for i >= 1), biasing the
     * permutation distribution. Drawing j from [0, i] restores uniformity.
     */
    shuffle(array) {
        for (let i = array.length - 1; i > 0; i--) {
            const j = Math.floor(Math.random() * (i + 1));
            const swap = array[i];
            array[i] = array[j];
            array[j] = swap;
        }
    }
}
|
|
@@ -1,2 +1,5 @@
|
|
|
1
|
-
export {
|
|
2
|
-
export {
|
|
1
|
+
export type { DataConfig } from './data_loader.js';
|
|
2
|
+
export { DataLoader } from './data_loader.js';
|
|
3
|
+
export { ImageLoader } from './image_loader.js';
|
|
4
|
+
export { TabularLoader } from './tabular_loader.js';
|
|
5
|
+
export { TextLoader } from './text_loader.js';
|
|
@@ -1,7 +1,4 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
Object.defineProperty(exports, "WebImageLoader", { enumerable: true, get: function () { return image_loader_1.WebImageLoader; } });
|
|
6
|
-
var tabular_loader_1 = require("./tabular_loader");
|
|
7
|
-
Object.defineProperty(exports, "WebTabularLoader", { enumerable: true, get: function () { return tabular_loader_1.WebTabularLoader; } });
|
|
1
|
+
export { DataLoader } from './data_loader.js';
|
|
2
|
+
export { ImageLoader } from './image_loader.js';
|
|
3
|
+
export { TabularLoader } from './tabular_loader.js';
|
|
4
|
+
export { TextLoader } from './text_loader.js';
|
|
@@ -1,4 +1,35 @@
|
|
|
1
|
-
import {
|
|
2
|
-
|
|
3
|
-
|
|
1
|
+
import type { Task } from '../../index.js';
|
|
2
|
+
import type { Dataset, DataSplit } from '../index.js';
|
|
3
|
+
import type { DataConfig } from './index.js';
|
|
4
|
+
import { DataLoader } from './index.js';
|
|
5
|
+
/**
|
|
6
|
+
* Tabular data loader whose instantiable implementation is delegated by the platform-dependent Disco subprojects, namely,
|
|
7
|
+
* @epfml/discojs-web and @epfml/discojs-node. Loads data from files whose entries are line-separated and consist of
|
|
8
|
+
* character-separated features and label(s). Such files typically have the .csv extension.
|
|
9
|
+
*/
|
|
10
|
+
export declare abstract class TabularLoader<Source> extends DataLoader<Source> {
    // Task whose training information (features/labels) drives CSV parsing.
    private readonly task;
    // Column separator used when parsing sources; defaults to ',' (see constructor).
    readonly delimiter: string;
    constructor(task: Task, delimiter?: string);
    /**
     * Creates a CSV dataset object based off the given source.
     * @param source File object, URL string or local file system path.
     * @param csvConfig Object expected by TF.js to create a CSVDataset.
     * @returns The CSVDataset object built upon the given source.
     */
    abstract loadDatasetFrom(source: Source, csvConfig: Record<string, unknown>): Promise<Dataset>;
    /**
     * Expects delimiter-separated tabular data made of N columns. The data may be
     * potentially split among several sources. Every source should contain N-1
     * feature columns and 1 single label column.
     * @param source List of File objects, URLs or file system paths.
     * @param config
     * @returns A TF.js dataset built upon read tabular data stored in the given sources.
     */
    load(source: Source, config?: DataConfig): Promise<Dataset>;
    /**
     * Creates the CSV datasets based off the given sources, then fuses them into a single CSV
     * dataset.
     */
    loadAll(sources: Source[], config: DataConfig): Promise<DataSplit>;
}
|
|
@@ -1,16 +1,76 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
1
|
+
import { List, Map, Set } from 'immutable';
|
|
2
|
+
import { TabularData } from '../index.js';
|
|
3
|
+
import { DataLoader } from './index.js';
|
|
4
|
+
// Window size from which the dataset shuffling will sample
|
|
5
|
+
const BUFFER_SIZE = 1000;
|
|
6
|
+
/**
|
|
7
|
+
* Tabular data loader whose instantiable implementation is delegated by the platform-dependent Disco subprojects, namely,
|
|
8
|
+
* @epfml/discojs-web and @epfml/discojs-node. Loads data from files whose entries are line-separated and consist of
|
|
9
|
+
* character-separated features and label(s). Such files typically have the .csv extension.
|
|
10
|
+
*/
|
|
11
|
+
export class TabularLoader extends DataLoader {
    task;
    delimiter;
    constructor(task, delimiter = ',') {
        super();
        this.task = task;
        this.delimiter = delimiter;
    }
    /**
     * Reads one delimiter-separated tabular source into a TF.js dataset.
     * Each entry is flattened into `{ xs, ys }` value arrays based on the
     * feature/label columns declared in `config`. Shuffles by default unless
     * `config.shuffle` is explicitly falsy.
     * @param source File object, URL or file system path.
     * @param config Data config holding `features`, `labels` and `shuffle`.
     * @returns A TF.js dataset built upon the parsed tabular data.
     */
    async load(source, config) {
        if (config?.features === undefined) {
            // TODO @s314cy
            throw new Error('Not implemented');
        }
        // Features are optional inputs; labels are required targets.
        const featureEntries = Set(config.features)
            .map((feature) => [feature, { required: false, isLabel: false }]);
        const labelEntries = Set(config.labels)
            .map((label) => [label, { required: true, isLabel: true }]);
        const columnConfigs = Map(featureEntries).merge(labelEntries);
        const csvConfig = {
            hasHeader: true,
            columnConfigs: columnConfigs.toObject(),
            configuredColumnsOnly: true,
            delimiter: this.delimiter
        };
        const rows = await this.loadDatasetFrom(source, csvConfig);
        // Flatten keyed column objects into plain value arrays.
        const dataset = rows.map((row) => {
            if (typeof row !== 'object') {
                throw new TypeError('Expected TensorContainerObject');
            }
            if (!('xs' in row) || !('ys' in row)) {
                return row;
            }
            const { xs, ys } = row;
            return {
                xs: Object.values(xs),
                ys: Object.values(ys)
            };
        });
        const shouldShuffle = config?.shuffle === undefined || config?.shuffle;
        return shouldShuffle ? dataset.shuffle(BUFFER_SIZE) : dataset;
    }
    /**
     * Loads every source without shuffling, concatenates them into one
     * dataset, then shuffles the fused result only when explicitly requested.
     */
    async loadAll(sources, config) {
        const perSource = await Promise.all(
            sources.map((source) => this.load(source, { ...config, shuffle: false }))
        );
        let fused = List(perSource).reduce((acc, d) => acc.concatenate(d));
        if (config?.shuffle === true) {
            fused = fused.shuffle(BUFFER_SIZE);
        }
        const data = await TabularData.init(fused, this.task);
        // TODO: Implement validation split for tabular data (tricky due to streaming)
        return {
            train: data
        };
    }
}
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
import type { Task } from '../../index.js';
|
|
2
|
+
import type { DataSplit, Dataset } from '../index.js';
|
|
3
|
+
import { DataLoader, DataConfig } from './index.js';
|
|
4
|
+
/**
|
|
5
|
+
* Text data loader whose instantiable implementation is delegated by the platform-dependent Disco subprojects, namely,
|
|
6
|
+
* @epfml/discojs-web and @epfml/discojs-node.
|
|
7
|
+
*/
|
|
8
|
+
export declare abstract class TextLoader<S> extends DataLoader<S> {
    // Task the loaded text data is associated with.
    private readonly task;
    constructor(task: Task);
    // Platform-dependent read of one text source into a TF.js dataset.
    abstract loadDatasetFrom(source: S): Promise<Dataset>;
    // Loads one source; shuffles by default unless `config.shuffle` is explicitly falsy.
    load(source: S, config?: DataConfig): Promise<Dataset>;
    // Loads and concatenates all sources into a single `train` split (no validation split).
    loadAll(sources: S[], config?: DataConfig): Promise<DataSplit>;
}
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import { TextData } from '../index.js';
|
|
2
|
+
import { DataLoader } from './index.js';
|
|
3
|
+
/**
|
|
4
|
+
* Text data loader whose instantiable implementation is delegated by the platform-dependent Disco subprojects, namely,
|
|
5
|
+
* @epfml/discojs-web and @epfml/discojs-node.
|
|
6
|
+
*/
|
|
7
|
+
export class TextLoader extends DataLoader {
    task;
    constructor(task) {
        super();
        this.task = task;
    }
    /**
     * Loads a single text source. Shuffles by default; pass an explicitly
     * falsy `config.shuffle` to preserve source order.
     * @param source Platform-dependent text source.
     * @param config Optional data config.
     * @returns The (possibly shuffled) TF.js dataset.
     */
    async load(source, config) {
        const dataset = await this.loadDatasetFrom(source);
        if (config?.shuffle === undefined || config?.shuffle) {
            // 1st arg: Stream shuffling buffer size
            return dataset.shuffle(1000, undefined, true);
        }
        return dataset;
    }
    /**
     * Loads every source in parallel, concatenates them in order, and wraps
     * the result as the `train` split of a TextData instance.
     */
    async loadAll(sources, config) {
        const datasets = await Promise.all(
            sources.map((src) => this.load(src, config))
        );
        const concatenated = datasets.reduce((acc, d) => acc.concatenate(d));
        return {
            train: await TextData.init(concatenated, this.task)
        };
    }
}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
export {};
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
import type { Task } from '../index.js';
|
|
2
|
+
import type { DataSplit } from './data/index.js';
|
|
3
|
+
import type { DataConfig, DataLoader } from './data_loader/data_loader.js';
|
|
4
|
+
/**
|
|
5
|
+
* Incrementally builds a dataset from the provided file sources. The sources may
|
|
6
|
+
* either be file blobs or regular file system paths.
|
|
7
|
+
*/
|
|
8
|
+
export declare class DatasetBuilder<Source> {
    /**
     * The data loader used to load the data contained in the provided files.
     */
    private readonly dataLoader;
    /**
     * The task for which the dataset should be built.
     */
    readonly task: Task;
    /**
     * The buffer of unlabelled file sources.
     */
    private _sources;
    /**
     * The buffer of labelled file sources.
     */
    private labelledSources;
    /**
     * Whether a dataset was already produced.
     */
    private _built;
    constructor(
    /**
     * The data loader used to load the data contained in the provided files.
     */
    dataLoader: DataLoader<Source>, 
    /**
     * The task for which the dataset should be built.
     */
    task: Task);
    /**
     * Adds the given file sources to the builder's buffer. Sources may be provided a label in the case
     * of supervised learning.
     * @param sources The array of file sources
     * @param label The file sources label
     */
    addFiles(sources: Source[], label?: string): void;
    /**
     * Clears the file sources buffers. If a label is provided, only the file sources
     * corresponding to the given label will be removed.
     * @param label The file sources label
     */
    clearFiles(label?: string): void;
    // Resets the built flag — presumably invoked when sources change after a build; TODO confirm against dataset_builder.js
    private resetBuiltState;
    // Derives the label array for the buffered labelled sources — implementation not visible here; verify against dataset_builder.js
    private getLabels;
    // Consumes the buffered sources through the data loader to produce a DataSplit.
    build(config?: DataConfig): Promise<DataSplit>;
    /**
     * Whether the dataset builder has already been consumed to produce a dataset.
     */
    get built(): boolean;
    // Total number of buffered sources — assumed to count both labelled and unlabelled buffers; TODO confirm
    get size(): number;
    // Flat view of the buffered file sources.
    get sources(): Source[];
}
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
import { Map } from 'immutable';
|
|
2
|
+
/**
 * Incrementally builds a dataset from the provided file sources. The sources may
 * either be file blobs or regular file system paths.
 */
export class DatasetBuilder {
    // The data loader used to load the data contained in the provided files.
    dataLoader;
    // The task for which the dataset should be built.
    task;
    // The buffer of unlabelled file sources.
    _sources;
    // The buffer of labelled file sources (immutable Map: label -> sources).
    labelledSources;
    // Whether a dataset was already produced.
    // TODO useless, responsibility on callers
    _built;
    /**
     * @param dataLoader The data loader used to load the data contained in the provided files.
     * @param task The task for which the dataset should be built.
     */
    constructor(dataLoader, task) {
        this.dataLoader = dataLoader;
        this.task = task;
        this._sources = [];
        this.labelledSources = Map();
        this._built = false;
    }
    /**
     * Adds the given file sources to the builder's buffer. Sources may be provided a label in the case
     * of supervised learning.
     * @param sources The array of file sources
     * @param label The file sources label
     */
    addFiles(sources, label) {
        if (this.built) {
            this.resetBuiltState();
        }
        if (label === undefined) {
            this._sources = this._sources.concat(sources);
            return;
        }
        const existing = this.labelledSources.get(label);
        const next = existing === undefined ? sources : existing.concat(sources);
        this.labelledSources = this.labelledSources.set(label, next);
    }
    /**
     * Clears the file sources buffers. If a label is provided, only the file sources
     * corresponding to the given label will be removed.
     * @param label The file sources label
     */
    clearFiles(label) {
        if (this.built) {
            this.resetBuiltState();
        }
        if (label === undefined) {
            this._sources = [];
            return;
        }
        this.labelledSources = this.labelledSources.delete(label);
    }
    // Called whenever files are added or removed: the latest version of the
    // buffers has not yet been built into a dataset.
    resetBuiltState() {
        this._built = false;
    }
    // Duplicates each label once per buffered source: label A with sources
    // [img1, img2, img3] yields [A, A, A].
    getLabels() {
        const labels = [];
        this.labelledSources.forEach((sources, label) => {
            for (let i = 0; i < sources.length; i++) {
                labels.push(label);
            }
        });
        return labels.flat();
    }
    /**
     * Consumes the buffered sources and loads them into a data split.
     * @param config Optional data configuration merged over task-derived defaults
     * @throws Error when neither or both of the source buffers are non-empty
     */
    async build(config) {
        const hasUnlabelled = this._sources.length > 0;
        const hasLabelled = this.labelledSources.size > 0;
        // Require that at least one source collection is non-empty, but not both
        if (hasUnlabelled === hasLabelled) {
            throw new Error('Please provide dataset input files');
        }
        let dataTuple;
        if (hasUnlabelled) {
            const defaultConfig = config?.inference === true
                ? {
                    // Inferring model, no labels needed
                    features: this.task.trainingInformation.inputColumns,
                    shuffle: false
                }
                : {
                    // Labels are contained in the given sources
                    features: this.task.trainingInformation.inputColumns,
                    labels: this.task.trainingInformation.outputColumns,
                    shuffle: false
                };
            dataTuple = await this.dataLoader.loadAll(this._sources, { ...defaultConfig, ...config });
        }
        else {
            // Labels are inferred from the file selection boxes
            const defaultConfig = {
                labels: this.getLabels(),
                shuffle: false
            };
            const flattened = this.labelledSources.valueSeq().toArray().flat();
            dataTuple = await this.dataLoader.loadAll(flattened, { ...defaultConfig, ...config });
        }
        // TODO @s314cy: Support .csv labels for image datasets (supervised training or testing)
        this._built = true;
        return dataTuple;
    }
    /**
     * Whether the dataset builder has already been consumed to produce a dataset.
     */
    get built() {
        return this._built;
    }
    get size() {
        // NOTE(review): for labelled sources this counts labels, not files — confirm intended.
        return Math.max(this._sources.length, this.labelledSources.size);
    }
    get sources() {
        if (this._sources.length > 0) {
            return this._sources;
        }
        return this.labelledSources.valueSeq().toArray().flat();
    }
}
|
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
export type { Dataset } from './dataset.js';
|
|
2
|
+
export { DatasetBuilder } from './dataset_builder.js';
|
|
3
|
+
export { ImageLoader, TabularLoader, DataLoader, TextLoader } from './data_loader/index.js';
|
|
4
|
+
export type { DataSplit } from './data/index.js';
|
|
5
|
+
export { Data, TabularData, ImageData, TextData, ImagePreprocessing, TabularPreprocessing, TextPreprocessing, IMAGE_PREPROCESSING, TABULAR_PREPROCESSING, TEXT_PREPROCESSING } from './data/index.js';
|
|
@@ -0,0 +1,3 @@
|
|
|
1
|
+
export { DatasetBuilder } from './dataset_builder.js';
|
|
2
|
+
export { ImageLoader, TabularLoader, DataLoader, TextLoader } from './data_loader/index.js';
|
|
3
|
+
export { Data, TabularData, ImageData, TextData, ImagePreprocessing, TabularPreprocessing, TextPreprocessing, IMAGE_PREPROCESSING, TABULAR_PREPROCESSING, TEXT_PREPROCESSING } from './data/index.js';
|
|
@@ -1,12 +1,10 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
exports.cifar10 = {
|
|
7
|
-
getTask: function () {
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
import { data, models } from '../../index.js';
|
|
3
|
+
import baseModel from './model.js';
|
|
4
|
+
export const cifar10 = {
|
|
5
|
+
getTask() {
|
|
8
6
|
return {
|
|
9
|
-
|
|
7
|
+
id: 'cifar10',
|
|
10
8
|
displayInformation: {
|
|
11
9
|
taskTitle: 'CIFAR10',
|
|
12
10
|
summary: {
|
|
@@ -25,17 +23,12 @@ exports.cifar10 = {
|
|
|
25
23
|
roundDuration: 10,
|
|
26
24
|
validationSplit: 0.2,
|
|
27
25
|
batchSize: 10,
|
|
28
|
-
modelCompileData: {
|
|
29
|
-
optimizer: 'sgd',
|
|
30
|
-
loss: 'categoricalCrossentropy',
|
|
31
|
-
metrics: ['accuracy']
|
|
32
|
-
},
|
|
33
26
|
dataType: 'image',
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
27
|
+
preprocessingFunctions: [data.ImagePreprocessing.Resize],
|
|
28
|
+
IMAGE_H: 224,
|
|
29
|
+
IMAGE_W: 224,
|
|
37
30
|
LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
|
|
38
|
-
scheme: '
|
|
31
|
+
scheme: 'decentralized',
|
|
39
32
|
noiseScale: undefined,
|
|
40
33
|
clippingRadius: 20,
|
|
41
34
|
decentralizedSecure: true,
|
|
@@ -44,25 +37,24 @@ exports.cifar10 = {
|
|
|
44
37
|
}
|
|
45
38
|
};
|
|
46
39
|
},
|
|
47
|
-
getModel
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
}
|
|
65
|
-
});
|
|
40
|
+
async getModel() {
|
|
41
|
+
const mobilenet = await tf.loadLayersModel({
|
|
42
|
+
load: async () => Promise.resolve(baseModel),
|
|
43
|
+
});
|
|
44
|
+
const x = mobilenet.getLayer('global_average_pooling2d_1');
|
|
45
|
+
const predictions = tf.layers
|
|
46
|
+
.dense({ units: 10, activation: 'softmax', name: 'denseModified' })
|
|
47
|
+
.apply(x.output);
|
|
48
|
+
const model = tf.model({
|
|
49
|
+
inputs: mobilenet.input,
|
|
50
|
+
outputs: predictions,
|
|
51
|
+
name: 'modelModified'
|
|
52
|
+
});
|
|
53
|
+
model.compile({
|
|
54
|
+
optimizer: 'sgd',
|
|
55
|
+
loss: 'categoricalCrossentropy',
|
|
56
|
+
metrics: ['accuracy']
|
|
66
57
|
});
|
|
58
|
+
return new models.TFJS(model);
|
|
67
59
|
}
|
|
68
60
|
};
|